Zero Finder

Find and follow a zero-value contour of any 2D function written in JAX.

jax_zero_contour.zero_contour_finder.split_curves(a, threshold)

Given a set of sorted points, split it into multiple arrays wherever the distance between adjacent points is larger than the given threshold. Used to split an array into separate contours for plotting.

Parameters:
  • a (jnp.array) – Sorted list of positions (see the sort_by_distance function)

  • threshold (float) – If adjacent points are more than this distance apart, split the list at that position.

Returns:

List of split arrays. If the first and last points of a sub-array are within the threshold of each other, the first point is repeated at the end of the array (i.e. the contour is closed).

Return type:

list of jnp.arrays
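
The splitting rule described above can be illustrated with a minimal sketch. The function name split_by_gap is hypothetical and this is not the library's implementation; in particular, the "more than two points" guard and the use of <= for the closing test are assumptions made for the example.

```python
import jax.numpy as jnp


def split_by_gap(points, threshold):
    """Hypothetical re-implementation of the behaviour described above."""
    # Distance between each pair of adjacent points.
    gaps = jnp.linalg.norm(jnp.diff(points, axis=0), axis=1)
    # Indices at which a new sub-array should start.
    breaks = [int(i) + 1 for i in jnp.nonzero(gaps > threshold)[0]]
    pieces = jnp.split(points, breaks)
    result = []
    for piece in pieces:
        # Assumption: only pieces with more than two points can be closed.
        ends_close = jnp.linalg.norm(piece[0] - piece[-1]) <= threshold
        if piece.shape[0] > 2 and bool(ends_close):
            # Repeat the first point at the end to close the contour.
            piece = jnp.concatenate([piece, piece[:1]], axis=0)
        result.append(piece)
    return result


# Two well-separated groups: a unit square and a short line segment.
square = jnp.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
segment = jnp.array([[10.0, 10.0], [11.0, 10.0], [12.0, 10.0]])
curves = split_by_gap(jnp.concatenate([square, segment]), threshold=1.5)
```

Here the square's end points are one unit apart, so it is returned closed with five rows, while the segment's end points are two units apart and it is returned unchanged.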

class jax_zero_contour.zero_contour_finder.ZeroSolver(tol=1e-06, max_newton=5, forward_mode_differentiation=False)
__init__(tol=1e-06, max_newton=5, forward_mode_differentiation=False)

A class for finding the zero contours of a function.

Parameters:
  • tol (float, optional) – Newton steps are used to bring each proposed point on the contour to within this tolerance of zero, by default 1e-6.

  • max_newton (int, optional) – The maximum number of Newton steps to run inside the path integrator, by default 5. To get from the initial guess to a point on the contour, 5 * max_newton steps are used.

  • forward_mode_differentiation (bool, optional) – If True, use forward-mode auto-differentiation; otherwise use reverse mode. By default False.
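
To make the roles of tol, max_newton and forward_mode_differentiation concrete, here is a minimal sketch of a Newton iteration that pulls a point onto a zero contour. The helper newton_to_contour and the test function f are hypothetical, and the update shown is the standard minimum-norm Newton step along the gradient; the source does not confirm that the library uses exactly this form. The function is assumed to take a length-2 array.

```python
import jax
import jax.numpy as jnp


def f(pos):
    # Unit circle: the zero contour is x^2 + y^2 = 1.
    return pos[0] ** 2 + pos[1] ** 2 - 1.0


def newton_to_contour(func, pos, tol=1e-6, max_steps=25, forward=False):
    """Hypothetical helper: move `pos` onto the zero contour of `func`."""
    # Forward and reverse mode return the same gradient; only the cost differs.
    grad_fn = jax.jacfwd(func) if forward else jax.grad(func)
    for _ in range(max_steps):
        value = func(pos)
        if abs(float(value)) < tol:
            break
        g = grad_fn(pos)
        # Minimum-norm Newton update along the gradient direction.
        pos = pos - value * g / jnp.dot(g, g)
    return pos


start = jnp.array([2.0, 1.0])
on_contour_rev = newton_to_contour(f, start)
on_contour_fwd = newton_to_contour(f, start, forward=True)
```

The default max_steps=25 mirrors the 5 * max_newton budget described for the initial guess. For a scalar function of two inputs the choice of differentiation mode affects performance only, not the result.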

static path_reduce(paths)

A helper function to remove the NaN values from a contour path dictionary. Because the size of the output depends on the inputs, this function cannot be jit’ed.

Parameters:

paths (dict) – the output path dictionary from the zero_contour_finder method

Returns:

paths – the paths object with the jax.numpy.nan values removed

Return type:

dict
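
The reason this step cannot be jit-compiled can be seen in a minimal sketch of NaN removal for a single contour. The helper drop_nan_points is hypothetical and only illustrates the idea; how path_reduce organises its output across several guesses is not shown here.

```python
import jax.numpy as jnp


def drop_nan_points(path, value):
    """Hypothetical helper: keep only the finite points of one contour."""
    keep = jnp.isfinite(value)
    # Boolean-mask indexing yields a data-dependent shape,
    # which is why this cannot be traced by jax.jit.
    return path[keep], value[keep]


# One contour padded with NaN after its two valid points.
path = jnp.array([[0.0, 1.0], [1.0, 0.0], [jnp.nan, jnp.nan], [jnp.nan, jnp.nan]])
value = jnp.array([1e-8, -2e-8, jnp.nan, jnp.nan])
clean_path, clean_value = drop_nan_points(path, value)
```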

zero_contour_finder(f, init_guess, delta=0.1, N=1000, silent_fail=True)

Find the zero contour of a 2D function.

After a path hits an endpoint or closes, any further points on the contour are set to jax.numpy.nan. The final output is shifted so that the finite parts of the contour are brought to the front of the array. The points in the resulting paths are ordered.

Any points along the contour whose function value exceeds 20 times the tolerance are also set to jax.numpy.nan.
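
As a rough illustration of this masking rule, the sketch below replaces points whose function value is too far from zero with NaN. It assumes the comparison is made on the absolute value; the variable names are illustrative only.

```python
import jax.numpy as jnp

tol = 1e-6
# Function values at three points on a candidate path.
value = jnp.array([2e-7, -5e-6, 3e-4])
path = jnp.array([[0.0, 1.0], [0.1, 0.99], [0.2, 0.5]])

# Assumption: a point is rejected when |f| exceeds 20 * tol.
bad = jnp.abs(value) > 20 * tol
masked_value = jnp.where(bad, jnp.nan, value)
masked_path = jnp.where(bad[:, None], jnp.nan, path)
```

Only the third point exceeds the 2e-5 limit here, so it is the only one replaced.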

Parts of this code use jax.lax.cond to stop the calculation early when certain termination conditions are satisfied. As a result it should not be combined with jax.vmap.

Parameters:
  • f (function) – The function whose zero contours you want to find. It should take one positional argument that is an array of shape (1, 2) (e.g. jnp.array([x_value, y_value])) and return a single value.

  • init_guess (jax.numpy.array) – Initial guesses for points near the zero contour, one guess per row.

  • delta (float, optional) – The step size to take along the contour when searching for a new point, by default 0.1.

  • N (int, optional) – The total number of steps to take in each direction from the starting point(s). The final path will be 2N+1 in size (N points in the forward direction, N points in the reverse direction, with the initial point in the middle). By default 1000.

  • silent_fail (bool, optional) – If False, the code will raise an exception if any of the initial points do not lead to a zero value, to within the tolerance, in the allowed number of max_newton steps. If True, the code will continue anyway. By default True.
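
The meaning of delta and N, and the 2N+1 layout of each path, can be sketched with a simple predictor-corrector tracer. This is an illustration under stated assumptions, not the library's integrator: the functions step and trace are hypothetical, the predictor is a plain tangent step, the corrector is a fixed number of Newton steps, and f is assumed to take a length-2 array.

```python
import jax
import jax.numpy as jnp


def f(pos):
    # Unit circle: the zero contour is x^2 + y^2 = 1.
    return pos[0] ** 2 + pos[1] ** 2 - 1.0


grad_f = jax.grad(f)


def step(pos, delta):
    g = grad_f(pos)
    # Tangent to the contour: the gradient rotated by 90 degrees.
    tangent = jnp.array([-g[1], g[0]]) / jnp.linalg.norm(g)
    pos = pos + delta * tangent  # predictor: move along the tangent
    for _ in range(5):  # corrector: Newton steps back towards f = 0
        g = grad_f(pos)
        pos = pos - f(pos) * g / jnp.dot(g, g)
    return pos


def trace(start, delta=0.1, n_steps=10):
    forward, backward = [start], [start]
    for _ in range(n_steps):
        forward.append(step(forward[-1], delta))
        backward.append(step(backward[-1], -delta))
    # Reverse branch (reversed), then the start point, then the forward branch.
    return jnp.stack(backward[:0:-1] + forward)


path = trace(jnp.array([1.0, 0.0]), delta=0.1, n_steps=10)
```

With n_steps=10 the result has 21 rows, with the starting point in the middle row, matching the 2N+1 layout described for N.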

Returns:

  • paths (dict) – The return dictionary has two keys. “path”: jax.numpy.array with shape (number of guesses, 2N+1, 2) containing the contour path for each guess. “value”: jax.numpy.array with shape (number of guesses, 2N+1) containing the function value at each point on the path.

  • stop_output (jax.numpy.array) – List containing the stopping conditions for each guess