Example 2, discontinuous function

For our second example we will look at a discontinuous function whose zero contours terminate in end points. We will use jax.numpy.where to add the discontinuity; as a result, we need to make sure forward_mode_differentiation is specified when we use the value_and_grad_wrapper.

%load_ext autoreload
%autoreload 2
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

import matplotlib.pyplot as plt
import matplotlib.colors as colors

from jax_zero_contour import ZeroSolver
@jax.tree_util.Partial
def f(pos):
    """Radial sinc surface with a constant 0.5 step added where theta >= 0.

    Parameters
    ----------
    pos : array
        Coordinates stacked along the leading axis; ``pos[0]`` and
        ``pos[1]`` are used as the x and y components respectively.

    Returns
    -------
    array
        ``sinc(r) + 0.5`` where the polar angle ``arctan2(y, x)`` is
        non-negative and ``sinc(r)`` elsewhere, with ``r`` the square root
        of the sum of squares over the leading axis.
    """
    # The tiny offset keeps the radius away from exactly zero so that the
    # gradient of the square root stays finite at the origin.
    radius = jnp.sqrt(jnp.sum(pos**2, axis=0) + 1e-15)
    base = jnp.sinc(radius)
    raised = base + 0.5

    # The step in the function value is introduced purely by this selection.
    angle = jnp.arctan2(pos[1], pos[0])
    in_raised_region = angle >= 0
    return jnp.where(in_raised_region, raised, base)
# Evaluate f on a regular n-by-n grid spanning [-3, 3] along both axes.
n = 1024
axis_limits = (-3, 3)
x = jnp.linspace(*axis_limits, n)
y = jnp.linspace(*axis_limits, n)
X, Y = jnp.meshgrid(x, y)
z = f(jnp.stack([X, Y]))

# Show the sampled values as an image whose axes carry the grid coordinates.
grid_extent = (x.min(), x.max(), y.min(), y.max())
colour_norm = colors.SymLogNorm(linthresh=0.1, vmin=-1, vmax=2)
plt.imshow(
    z,
    extent=grid_extent,
    norm=colour_norm,
    cmap='PuOr_r',
    origin='lower',
    interpolation='nearest',
)
plt.colorbar();
_images/9ba7401758445b78cc98beab33cb62edddcb367602f917419d7211a4d30c04dd.png

As expected, the zero contours are not closed loops in this case, but instead half circles.

Note

Because we used jnp.where inside our function definition, we need to set forward_mode_differentiation to True to avoid NaN values in the gradient.

# Forward-mode differentiation is requested because f contains jnp.where.
zs = ZeroSolver(forward_mode_differentiation=True)
# Three starting guesses, one per row: first component 0.0 and second
# component -0.6, -1.6 and -2.6.
init_guess = jnp.array([[0.0, -0.6], [0.0, -1.6], [0.0, -2.6]])
# NOTE(review): delta and N are presumably the step size along the contour
# and the number of steps taken -- confirm against the jax_zero_contour docs.
paths, stopping_reason = zs.zero_contour_finder(
    f,
    init_guess,
    delta=0.01,
    N=500
)
# Print the stopping codes returned by the solver for the three guesses.
print(stopping_reason)
[[1 1]
 [1 1]
 [1 1]]

In all cases, the zero finder terminated with [1, 1], indicating that it found an end point in both directions.

plt.figure(figsize=(12, 4))

# Left panel: the sampled function with the traced paths drawn on top.
plt.subplot(121)
image_options = dict(
    extent=(x.min(), x.max(), y.min(), y.max()),
    norm=colors.SymLogNorm(linthresh=0.1, vmin=-1, vmax=2),
    cmap='PuOr_r',
    origin='lower',
    interpolation='nearest',
)
plt.imshow(z, **image_options)
plt.colorbar()
plt.plot(*paths['path'].T)
# Mark each initial guess with an 'x' in colours C0, C1 and C2.
for guess_index, guess in enumerate(init_guess):
    plt.plot(*guess, 'x', ms=10, color=f'C{guess_index}')

# Right panel: the function value recorded along each path.
plt.subplot(122)
plt.ylabel('function value on contour')
plt.xlabel('path index')
plt.plot(paths['value'].T);
_images/126dd15d0aedc39ff8853a53ff912bd2f09a000640a6eaf2b913ed636335a2ff.png

We can see that the contour end points are correctly identified in the face of a discontinuity in the function.