Example 1, smooth function

The zero_contour_finder function will find points along the zero contour of a 2D function defined with Jax. This example will walk though its basic usage and what each of the input parameters are.

Imports

We begin with some basic imports for Jax, Matplotlib, and the zero contour functions.

Note

While it is necessary to use 64 bit precision to use the zero contour functions, it can be useful to help avoid floating point errors for very rapidly changing functions.

%load_ext autoreload
%autoreload 2
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

import matplotlib.pyplot as plt
import matplotlib.colors as colors

from jax_zero_contour import ZeroSolver

For our first test we will create a function that will have two distinct closed contours. This function must take in one position value jnp.array([x, y]) and return a single number.

@jax.tree_util.Partial
def f(p):
    return 0.01 * ((p + 0.5)**4 - 30 * p**2 - 20 * p).sum(axis=0) + 2.0

For reference let’s plot the function evaluated on an evenly spaced grid. This is not used to find the contours, only for visualization purposes. For the plot we will use a diverging symmetric log colorbar that is linear near the value of zero, this will help to visualize where the contour of interest are in the function.

n = 1024
x = jnp.linspace(-8, 8, n)
y = jnp.linspace(-8, 8, n)
X, Y = jnp.meshgrid(x, y)
z = f(jnp.stack([X, Y]))
plt.imshow(
    z,
    extent=(x.min(), x.max(), y.min(), y.max()),
    norm=colors.SymLogNorm(linthresh=0.1, vmin=-175, vmax=175),
    cmap='PuOr_r',
    origin='lower',
    interpolation='nearest'
)
plt.colorbar();
_images/ccf4a3cff30fb52aa49759f4f5b5c87dd6a3a0950497898afd85d2bd2f5fa2f4.png

Finding the zero contours

There are two steps to finding a zero contour:

  1. Create a ZeroSolver class instance

  2. Use the zero_contour_finder method on this class to find our contours

The zero_contour_finder method has two required inputs:

  • f: The function we can to find the contours for. Ths function must be wrapped in jax.tree_util.Partial as it is being included as the input to a jited function.

  • init_guess: Initial guesses for a points near the zero contour, one guess per row.

Note

The zero_solver class takes in the keyword forward_mode_differentiation on creation, if set to true it will use forward mode auto-differentiation when evaluating all gradients. This can be useful if your target function make use of jax.numpy.where and can’t be reverse auto-differentiated.

zs = ZeroSolver()

Looking at the plot an input point of (-2.0, 0.0) should find the larger contour and the point (4.0, 4.0) should find the smaller one.

init_guess = jnp.array([[-2.0, 0.0], [4.0, 4.0]])
paths, stopping_reason = zs.zero_contour_finder(
    f,
    init_guess
)
print(stopping_reason)
[[2 2]
 [2 2]]

Both the forward pass and backward pass for both guesses returned a stopping condition of 2.

Note

There is one row per initial guess, with the first column being the stopping condition for traversing the contour in the forward direction, and the second column being for the reverse direction.

Let’s check what that means by looking at the stopping_conditions look up table:

zs.stopping_conditions
{0: 'none', 1: 'end_point', 2: 'closed_loop'}

2 indicates that the contours have both formed a “closed_loop” before stopping. When this happens the resulting path is filled with NaN values after the close is detected. These NaN values are always moved to the end of the path array for easy trimming if that is needed. As a result, If a contour closes on the forward pass the values from the reverse pass are not used.

Note

The ZeroSolver.path_reduce helper method is provided to remove these NaN values, but because each contour have a length that is dependent on the inputs, this method is not compatible with jit.

Now we can plot the resulting contours along side the initial guesses and the value of each point along the contours. The path object that is returned is a dictionary with two keys:

  • "path": array with shape (number of guesses, 2N+1, 2) with the contours paths for each guess

  • "value": array with shape (number of guesses, 2N+1) with the function values at each point on the contour

plt.figure(figsize=(12, 4))
plt.subplot(121)
plt.imshow(
    z,
    extent=(x.min(), x.max(), y.min(), y.max()),
    norm=colors.SymLogNorm(linthresh=0.1, vmin=-175, vmax=175),
    cmap='PuOr_r',
    origin='lower',
    interpolation='nearest'
)
plt.plot(*paths['path'].T)
plt.plot(*init_guess[0], 'x', ms=10, color='C0')
plt.plot(*init_guess[1], 'x', ms=10, color='C1')
plt.colorbar()

plt.subplot(122)
plt.xlabel('path index')
plt.ylabel('function value on contour')
plt.plot(paths['value'].T);
_images/6b47bd0295c509561b077083519476d306ab6fb15b3a8ddd940f79b554731765.png

In both cases we see that the value along the contours are within the default tolerance of \(10^{-6}\). Before we move on to our next example, let’s take a closer look at the generated path. For clarity we will look at the smaller path, but this time plot the scatter points:

plt.plot(*paths['path'][1].T, '.', ms=1, color='C1')
plt.gca().set_aspect(1);
_images/3613847af75d3319dc00ad12e08d0d80e2652cc8e69901c208929b8b660140a5.png

From this plot we can see that the generated point on the contour are very close to being evenly spaced. The spacing of these points is controlled by the delta parameter that defaults to 0.1. Let’s try adjusting this to get some points that are closer together.

paths_close, stopping_reason_close = zs.zero_contour_finder(
    f,
    init_guess,
    delta=0.01
)
plt.plot(*paths_close['path'][1].T, '.', ms=1, color='C1')
plt.gca().set_aspect(1);
_images/b5ca99d63bef1bc2055f5446d5861bc4da6c1a9ae2c22b6b19a3326879044496.png

The points are so close together that on this plot it looks like a solid line. If we take a look at the stopping conditions for the smaller spacing we will see that the larger contour did not close this time.

print(stopping_reason_close)
[[0 0]
 [0 2]]
plt.plot(*paths_close['path'][0].T, '.', ms=1, color='C0')
plt.gca().set_aspect(1);
_images/ff64e6a4c14dbf4f9cc20670895f01d0ebc712637c9b006b0e4243c3d5016343.png

This is because by default the zero_contour_finder function will trace the contour both clockwise and counterclockwise from the initial point up to 1000 steps in each direction before stopping. With the smaller step size more points needed to be taken to get around the contour. If we increase the number of steps we can make it close again.

Note

The two stopping conditions for each initial guess are for the path in each direction. Typical examples are:

  • [2, 2]: The path close during the first (clockwise) search and the second (counterclockwise) search

  • [0, 2]: The path did not close in the first search, but did close during the second (counterclockwise) search

  • [0, 0]: The path did not close

  • [1, 1]: The path hit endpoints in both directions

paths_close, stopping_reason_close = zs.zero_contour_finder(
    f,
    init_guess,
    delta=0.01,
    N=2300
)
print(stopping_reason_close)
[[0 2]
 [2 2]]
plt.plot(*paths_close['path'][0].T, '.', ms=1, color='C0')
plt.gca().set_aspect(1);
_images/d66efe2e3cbd05e94844ecf37f03fd7ac490c694bd0dd18f79f1f81b6b607fea.png

Tolerance

The optional keywords for the ZeroSolver function are:

  • tol: Newton’s steps are used to bring each proposed point on the contour to be within this tolerance of zero, by default 1e-6.

  • max_newton: The maximum number of Newton’s steps to run, by default 25.

  • forward_mode_differentiation: Flag used to set either reverse-mode or forward-mode auto-differentiation

If you need to find the contour to a higher tolerance you can adjust the tol value to be smaller (down to ~a few times your floating point precision). By construction, each new proposed point on the contour should be close to zero, so very few Newton’s steps should be needed to bring it back in line, but in the face of a function discontinuity or an place where the function’s gradient is large, it can take more steps. max_newton controls the allowed maximum number of steps that can be taken to avoid infinite loops.

As a test let’s re-run the above contours with a smaller tolerance value.

Note

For 64 bit precision the smallest tolerance value that can be used is typically around \(10^{-14}\)

zs = ZeroSolver(tol=1e-14)
paths_tol, stopping_reason_tol = zs.zero_contour_finder(
    f,
    init_guess
)
print(stopping_reason_tol)
[[2 2]
 [2 2]]
plt.figure(figsize=(12, 4))
plt.subplot(121)
plt.imshow(
    z,
    extent=(x.min(), x.max(), y.min(), y.max()),
    norm=colors.SymLogNorm(linthresh=0.1, vmin=-175, vmax=175),
    cmap='PuOr_r',
    origin='lower',
    interpolation='nearest'
)

plt.plot(*paths_tol['path'].T)
plt.plot(*init_guess[0], 'x', ms=10, color='C0')
plt.plot(*init_guess[1], 'x', ms=10, color='C1')
plt.colorbar()

plt.subplot(122)
plt.xlabel('path index')
plt.ylabel('function value on contour')
plt.plot(paths_tol['value'].T);
_images/2a130b7f17b7b85b1b685f0cced65a30d6b72f6c95ad5b12b2ff0b766ed9f63e.png

We can now see the value of the function along the contour is far smaller than before.

Stopping conditions

As pointed out in the previous section there are three different stopping conditions:

  • 0: none

  • 1: end point

  • 2: closed loop

None

The maximum number of steps were taken before reaching either stopping condition below.

End point

The zero value contour is not always closed, this can be caused by several thing such as a discontinuity in the function. This stopping condition is triggered when the proposed next point in the path is more than 2 * delta away from the previous point (i.e. the Newton’s step brings the point away after the Euler step).

Closed loop

When a contour loops back to it’s starting point the contour is marked as being closed. This stopping condition is triggered when the proposed next point is less than 1.1 * delta away from the starting point of the path.

After either an endpoint is reached or a contour is closed all remaining steps are recorded as NaN to avoid tracing over the same contour many times over. Once all paths have terminated the path integration is halted early to not waste compute time.