Example 3, no zero contour

If there is no contour, the initial zero finding will fail to converge to a zero value. After each Newton step, the code checks whether the function value is within the set tolerance of zero; if it is not, it takes another step. If an input has not converged after a fixed number of steps (5 * max_newton, which is 25 steps for the solver used in this example), the zero-finding code sets all output values for that input to NaN.
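The convergence check described above can be sketched as follows. This is a minimal illustration, not the library's implementation: the function name `newton_to_zero`, the step rule (the minimum-norm Newton step `x - f(x) * grad / |grad|^2`), and the default `tol` and `max_steps` values are assumptions chosen for the example.

```python
import jax
import jax.numpy as jnp


def newton_to_zero(f, x0, tol=1e-5, max_steps=25):
    """Hypothetical sketch: Newton steps from x0 toward the level set f(x) = 0.

    Each step is x <- x - f(x) * g / |g|^2 with g = grad f(x), which moves
    along the gradient direction onto the zero level set. If |f(x)| is not
    within `tol` of zero after `max_steps` steps, the result is set to NaN.
    """
    grad_f = jax.grad(f)

    def step(x, _):
        g = grad_f(x)
        x_new = x - f(x) * g / jnp.sum(g**2)
        return x_new, None

    # Run a fixed number of steps (no early exit), as a jitted solver would
    x_final, _ = jax.lax.scan(step, x0, None, length=max_steps)
    converged = jnp.abs(f(x_final)) < tol
    # Mark a non-converged input with NaN instead of raising an error
    return jnp.where(converged, x_final, jnp.nan)
```

On an always-positive function such as `lambda x: jnp.sum(x**2) + 1.0`, the iteration can never reach `|f| < tol`, so the sketch returns NaN for that input, mirroring the silent-failure behavior described above.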

Note

The NaN assignment is done with a threshold cut on the path's function value, set at 20 times the tolerance. If the path eventually finds its way to a contour, it will return non-NaN values.
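A point-by-point version of that threshold cut might look like the sketch below. The helper name `threshold_cut` and the assumption that the cut is applied to each point on the path independently are illustrative; only the factor of 20 comes from the text above.

```python
import jax.numpy as jnp


def threshold_cut(path, values, tol, factor=20.0):
    """Hypothetical sketch of a NaN threshold cut on a traced path.

    path:   (N, 2) array of points along the path
    values: (N,) array of f evaluated at those points
    Points whose |value| exceeds factor * tol are replaced by NaN;
    points that have reached the contour are kept unchanged.
    """
    keep = jnp.abs(values) <= factor * tol
    path_cut = jnp.where(keep[:, None], path, jnp.nan)
    values_cut = jnp.where(keep, values, jnp.nan)
    return path_cut, values_cut
```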

If you would rather the code raise an error when any of the input points fail, set the keyword silent_fail=False. With this set, the code raises a ValueError indicating that it did not find a contour for at least one of the inputs and reports the indices of those inputs.
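If you are working with the NaN-filled output of a silent run, you can recover the failed indices and raise an error yourself on the host. The helpers below are hypothetical: they assume a failed input shows up as a row that is entirely NaN (as in the `output[0]['value']` array in this example), the default `max_newton=5` is taken from the 25 iterations reported in this example's traceback, and the message only imitates the wording of the library's error.

```python
import numpy as np


def failed_input_indices(values):
    """Indices of inputs whose traced values are all NaN.

    values: array of shape (n_inputs, n_points), e.g. output[0]['value'].
    """
    values = np.asarray(values)
    return np.nonzero(np.all(np.isnan(values), axis=1))[0]


def raise_if_failed(values, max_newton=5):
    """Hypothetical host-side check that raises in the style of silent_fail=False."""
    failed = failed_input_indices(values)
    if failed.size > 0:
        raise ValueError(
            f"No zero contour found after 5 * max_newton ({5 * max_newton}) iterations\n"
            f"Index of failed input(s): {failed}"
        )
```

Doing the check after the solver returns keeps the error in ordinary Python, outside any jitted code.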

Note

Just because one initial point fails to find a contour does not mean another point will also fail. Newton's method will get stuck if its path travels through any point where the gradient is zero (e.g. a maximum or minimum of the function). Try to make sure your initial points are not near any extremum of the function.
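You can check a starting point for a vanishing gradient directly with `jax.grad`. The helper `near_extremum` and its `grad_tol` threshold are hypothetical; the test function is the always-positive `f` defined below, redefined here so the snippet runs on its own. Its maximum sits at the origin, where the gradient is zero.

```python
import jax
import jax.numpy as jnp


def f(pos):
    # Same always-positive test function used in this example
    r = jnp.sqrt(jnp.sum(pos**2, axis=0) + 1e-15)
    return jnp.sinc(r) + 0.5


grad_f = jax.grad(f)


def near_extremum(pos, grad_tol=1e-3):
    """Hypothetical check: flag a starting point whose gradient is ~zero.

    A Newton step divides by |grad f|^2, so points where this is tiny
    (e.g. a maximum or minimum) give huge or undefined steps.
    """
    return bool(jnp.linalg.norm(grad_f(pos)) < grad_tol)
```

For example, `near_extremum(jnp.array([0.0, 0.0]))` flags the maximum at the origin, while the initial guess `[0.0, -0.6]` used below has a clearly nonzero gradient.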

%load_ext autoreload
%autoreload 2
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

import matplotlib.pyplot as plt
import matplotlib.colors as colors

from jax_zero_contour import ZeroSolver

Let’s make a function that is always positive.

@jax.tree_util.Partial
def f(pos):
    # avoid r=0 so the grad is finite
    r = jnp.sqrt(jnp.sum(pos**2, axis=0) + 1e-15)
    return jnp.sinc(r) + 0.5

n = 1024
x = jnp.linspace(-2, 2, n)
y = jnp.linspace(-2, 2, n)
X, Y = jnp.meshgrid(x, y)
z = f(jnp.stack([X, Y]))
plt.imshow(
    z,
    extent=(x.min(), x.max(), y.min(), y.max()),
    norm=colors.SymLogNorm(linthresh=0.1, vmin=-1, vmax=2),
    cmap='PuOr_r',
    origin='lower',
    interpolation='nearest'
)
plt.colorbar();
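Because `f` depends only on the radius, a quick numerical check of its radial profile confirms there is no zero contour anywhere. This sketch samples r on an assumed range of [0, 10]; the negative lobes of the normalized sinc shrink as r grows, so the first lobe sets the overall floor.

```python
import jax.numpy as jnp

# Radial profile of the test function: f depends only on r = |pos|
r = jnp.linspace(0.0, 10.0, 100_001)
profile = jnp.sinc(r) + 0.5

# jnp.sinc is the normalized sinc, sin(pi r) / (pi r), whose global minimum
# is about -0.217 (near r = 1.43), so f never drops below about 0.283
f_min = float(profile.min())
r_at_min = float(r[jnp.argmin(profile)])
print(f"min f = {f_min:.4f} at r = {r_at_min:.3f}")
```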
zs = ZeroSolver()
init_guess_1 = jnp.array([[0.0, -0.6], [0.0, 1.0]])
output = zs.zero_contour_finder(
    f,
    init_guess_1
)
print(output[0]['value'])
[[nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]]

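As expected, the zero-finding code was unable to identify any zero values from either initial point, and all the output values are NaN. If we want the code to raise an error instead, we can set the keyword silent_fail=False. Because this check runs inside a jax.debug.callback, the ValueError is raised in the callback and reaches the caller wrapped in an XlaRuntimeError.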

init_guess_1 = jnp.array([[0.0, -0.6], [0.0, 1.0]])
output = zs.zero_contour_finder(
    f,
    init_guess_1,
    silent_fail=False
)
ERROR:2025-06-24 13:37:47,883:jax._src.debugging:98: jax.debug.callback failed
Traceback (most recent call last):
  File "/opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/jax/_src/debugging.py", line 96, in debug_callback_impl
    callback(*args)
  File "/opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/jax/_src/debugging.py", line 336, in _flat_callback
    callback(*args, **kwargs)
  File "/opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/jax_zero_contour/zero_contour_finder.py", line 274, in excepting_message
    raise ValueError(f'No zero contour found after 5 * max_newton ({5 * self.max_newton}) iterations')
ValueError: No zero contour found after 5 * max_newton (25) iterations
Index of failed input(s): [0 1]
---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
Cell In[8], line 2
      1 init_guess_1 = jnp.array([[0.0, -0.6], [0.0, 1.0]])
----> 2 output = zs.zero_contour_finder(
      3     f,
      4     init_guess_1,
      5     silent_fail=False
      6 )

    [... skipping hidden 5 frame]

File /opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py:1297, in ExecuteReplicated.__call__(self, *args)
   1294 if (self.ordered_effects or self.has_unordered_effects
   1295     or self.has_host_callbacks):
   1296   input_bufs = self._add_tokens_to_inputs(input_bufs)
-> 1297   results = self.xla_executable.execute_sharded(
   1298       input_bufs, with_tokens=True
   1299   )
   1301   result_token_bufs = results.disassemble_prefix_into_single_device_arrays(
   1302       len(self.ordered_effects))
   1303   sharded_runtime_token = results.consume_token()

XlaRuntimeError: INTERNAL: CpuCallback error calling callback: Traceback (most recent call last):
  [... internal frames omitted ...]
ValueError: No zero contour found after 5 * max_newton (25) iterations