Example 3, no zero contour
If there is no contour, the initial zero finding will fail to converge to a zero value. After each Newton’s step the code checks whether the function value is within the set tolerance of zero; if it is not, it takes another step. If an input has not converged after a fixed number of steps (5 * max_newton), the zero finding code sets all of the values for that particular input to NaN.
Note
The NaN cut is a threshold on the path’s function value, set to 20 times the tolerance. If the path eventually finds its way to a contour, it will return non-NaN values.
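A minimal sketch of this iterate-then-cut logic for a single starting point is below. It stops early once |f| is within tol of zero, caps the number of steps, and applies the 20 × tol threshold cut described above. The update rule (the minimum-norm Newton step pos - f(pos) g / |g|² for a scalar function of a 2D position, with g the gradient), the helper name newton_to_zero, and the default tol and max_steps values are assumptions for illustration, not the jax_zero_contour implementation.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def newton_to_zero(f, pos0, tol=1e-6, max_steps=25):
    """Hypothetical sketch: Newton steps toward f(pos) = 0 for f: R^2 -> R.

    Assumed update: the minimum-norm Newton step pos - f(pos) * g / |g|^2,
    with g = grad f(pos). Not necessarily what jax_zero_contour does.
    """
    grad_f = jax.grad(f)

    def keep_going(state):
        i, pos = state
        # continue while under the step budget and not yet within tol of zero
        return (i < max_steps) & (jnp.abs(f(pos)) > tol)

    def newton_step(state):
        i, pos = state
        g = grad_f(pos)
        # a zero gradient gives NaN here, which also ends the loop
        return i + 1, pos - f(pos) * g / jnp.sum(g**2)

    _, pos = jax.lax.while_loop(keep_going, newton_step, (0, pos0))
    # threshold cut at 20 * tol: a path that never reached a contour becomes NaN
    return jnp.where(jnp.abs(f(pos)) < 20 * tol, pos, jnp.nan)


def circle(p):
    # has a zero contour on the unit circle
    return jnp.sum(p**2) - 1.0


def always_positive(p):
    # has no zero contour at all
    return jnp.sum(p**2) + 1.0


starts = jnp.array([[0.0, 0.5], [0.3, -0.2]])
print(jax.vmap(lambda p: newton_to_zero(circle, p))(starts))  # points on r = 1
print(newton_to_zero(always_positive, jnp.array([0.0, 0.5])))  # all NaN
```

Because the loop runs inside jax.lax.while_loop, the sketch can be mapped over a batch of starting points with jax.vmap; under vmap the loop keeps running until every point in the batch has stopped, while already-converged points keep their values.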
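If you would rather have the code throw an error when any of the input points fail, set the keyword silent_fail=False. With this set, the code will raise a ValueError indicating that it did not find a contour for at least one of the inputs, along with the index of those inputs. Because the error is raised inside a JAX callback, it reaches the caller wrapped in an XlaRuntimeError, as the traceback below shows.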
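If you keep the default silent failure, a similar check can be run after the solver returns. The sketch below assumes, based on the printed output in this example, that output[0]['value'] holds one row per input point and that a failed input is an all-NaN row. The helper name raise_on_failed_inputs and its message are hypothetical, not part of jax_zero_contour.

```python
import jax.numpy as jnp


def raise_on_failed_inputs(values):
    """Hypothetical post-hoc check, assuming one row of `values` per input point."""
    # an input failed if every value in its row is NaN
    failed = jnp.where(jnp.all(jnp.isnan(values), axis=1))[0]
    if failed.size > 0:
        raise ValueError(f'No zero contour found; index of failed input(s): {failed}')
    return values


values = jnp.array([[0.1, 0.2, 0.3], [jnp.nan, jnp.nan, jnp.nan]])
try:
    raise_on_failed_inputs(values)
except ValueError as err:
    print(err)  # reports index 1 as the failed input
```

Unlike silent_fail=False, this check runs eagerly in Python after the solver returns, so the ValueError is raised directly rather than from inside a compiled callback.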
Note
Just because one initial point fails to find a contour does not mean another point will also fail. Newton’s method will get stuck if its path passes through any point where the gradient is zero (e.g. a maximum or minimum of the function), so try to make sure your initial points are not near any extremum of the function.
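One way to act on this advice is to check the gradient magnitude at each starting point before calling the solver. The helper flag_flat_starts and the cutoff grad_tol=1e-3 are hypothetical choices, not part of jax_zero_contour; the sketch assumes starting points are stored one per row, as in init_guess_1, and reuses the always-positive test function from this example.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def f(pos):
    # same test function as in this example; avoid r=0 so the grad is finite
    r = jnp.sqrt(jnp.sum(pos**2, axis=0) + 1e-15)
    return jnp.sinc(r) + 0.5


def flag_flat_starts(fun, points, grad_tol=1e-3):
    """Hypothetical helper: True where |grad fun| < grad_tol at a starting point.

    `points` has shape (n_points, 2), one starting point per row.
    """
    grads = jax.vmap(jax.grad(fun))(points)
    return jnp.linalg.norm(grads, axis=1) < grad_tol


starts = jnp.array([[0.0, 0.0], [0.0, -0.6], [0.0, 1.0]])
# the origin sits on the peak of sinc, so only the first point is flagged
print(flag_flat_starts(f, starts))
```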
%load_ext autoreload
%autoreload 2
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from jax_zero_contour import ZeroSolver
Let’s make a function that is always positive.
@jax.tree_util.Partial
def f(pos):
    # avoid r=0 so the grad is finite
    r = jnp.sqrt(jnp.sum(pos**2, axis=0) + 1e-15)
    return jnp.sinc(r) + 0.5
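To see why this function has no zero contour: the normalized sinc computed by jnp.sinc, sin(πr)/(πr), has a global minimum of about -0.217 near r ≈ 1.43, so adding 0.5 keeps f above roughly 0.28 everywhere. The quick numerical check below samples the same radial profile on a fine grid; the grid range and spacing are arbitrary choices.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# radial profile of f: sinc(r) + 0.5, sampled densely over 0 <= r <= 10
# (for r > 10, |sinc(r)| <= 1 / (pi * r) < 0.032, so the profile stays positive)
r = jnp.linspace(0.0, 10.0, 100_001)
profile = jnp.sinc(r) + 0.5

i_min = jnp.argmin(profile)
print(r[i_min], profile[i_min])  # minimum near r = 1.43, value about 0.283
```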
n = 1024
x = jnp.linspace(-2, 2, n)
y = jnp.linspace(-2, 2, n)
X, Y = jnp.meshgrid(x, y)
z = f(jnp.stack([X, Y]))
plt.imshow(
    z,
    extent=(x.min(), x.max(), y.min(), y.max()),
    norm=colors.SymLogNorm(linthresh=0.1, vmin=-1, vmax=2),
    cmap='PuOr_r',
    origin='lower',
    interpolation='nearest'
)
plt.colorbar();

zs = ZeroSolver()
init_guess_1 = jnp.array([[0.0, -0.6], [0.0, 1.0]])
output = zs.zero_contour_finder(
    f,
    init_guess_1
)
print(output[0]['value'])
[[nan nan nan ... nan nan nan]
[nan nan nan ... nan nan nan]]
As expected, the zero finder was unable to identify any zero values from the initial points, and all of the output values are NaN. If we want the code to raise an error instead, we can set the keyword silent_fail=False.
init_guess_1 = jnp.array([[0.0, -0.6], [0.0, 1.0]])
output = zs.zero_contour_finder(
    f,
    init_guess_1,
    silent_fail=False
)
ERROR:2025-06-24 13:37:47,883:jax._src.debugging:98: jax.debug.callback failed
Traceback (most recent call last):
File "/opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/jax/_src/debugging.py", line 96, in debug_callback_impl
callback(*args)
File "/opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/jax/_src/debugging.py", line 336, in _flat_callback
callback(*args, **kwargs)
File "/opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/jax_zero_contour/zero_contour_finder.py", line 274, in excepting_message
raise ValueError(f'No zero contour found after 5 * max_newton ({5 * self.max_newton}) iterations')
ValueError: No zero contour found after 5 * max_newton (25) iterations
Index of failed input(s): [0 1]
---------------------------------------------------------------------------
XlaRuntimeError Traceback (most recent call last)
Cell In[8], line 2
1 init_guess_1 = jnp.array([[0.0, -0.6], [0.0, 1.0]])
----> 2 output = zs.zero_contour_finder(
3 f,
4 init_guess_1,
5 silent_fail=False
6 )
[... skipping hidden 5 frame]
File /opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py:1297, in ExecuteReplicated.__call__(self, *args)
1294 if (self.ordered_effects or self.has_unordered_effects
1295 or self.has_host_callbacks):
1296 input_bufs = self._add_tokens_to_inputs(input_bufs)
-> 1297 results = self.xla_executable.execute_sharded(
1298 input_bufs, with_tokens=True
1299 )
1301 result_token_bufs = results.disassemble_prefix_into_single_device_arrays(
1302 len(self.ordered_effects))
1303 sharded_runtime_token = results.consume_token()
XlaRuntimeError: INTERNAL: CpuCallback error calling callback: Traceback (most recent call last):
[... skipping internal frames]
File "/opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/jax_zero_contour/zero_contour_finder.py", line 274, in excepting_message
ValueError: No zero contour found after 5 * max_newton (25) iterations