Examples

These examples cover the basic usage of jaxnnls to solve non-negative least squares (NNLS) problems, along with batching, differentiation, and diagnostic output.

Warning

While the algorithm can sometimes work with JAX’s default 32-bit precision, it is recommended that you enable 64-bit precision. The Cholesky decompositions used can become unstable at lower precision and lead to NaN results.
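
The example below enables this with jax.config.update; as an alternative sketch, the JAX_ENABLE_X64 environment variable can be set before JAX is imported:

import os

# equivalent to jax.config.update("jax_enable_x64", True),
# but must happen before JAX is imported
os.environ["JAX_ENABLE_X64"] = "True"

import jax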

Basic usage

To begin, we will write a function that randomly generates a non-trivial NNLS system.

import jax

# enable 64 bit mode
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import jaxnnls

# adjust the print options to make the output easier to read
jnp.set_printoptions(suppress=True)


def generate_random_qp(key, nx):
    # split the random key
    key_q, key_mask, key_x, key_z = jax.random.split(key, 4)
    # make a positive definite Q matrix
    Q = jax.random.normal(key_q, (nx, nx))
    Q = Q.T @ Q
    # make the primal and dual variables (all positive)
    x = jnp.abs(jax.random.normal(key_x, (nx,)))
    z = jnp.abs(jax.random.normal(key_z, (nx,)))
    # mask out 50% of the values to zero
    mask = jax.random.choice(key_mask, jnp.array([True, False]), (nx,))
    x = jnp.where(mask, x, 0)
    z = jnp.where(mask, 0, z)
    # make the "observed" vector that has x as its NNLS solution
    q = Q @ x - z
    return Q, q, x, z

Now let’s make a JAX random key and generate an example system with a 5x5 Q matrix.

_key = jax.random.key(0)
key, _key = jax.random.split(_key)

Q, q, x, z = generate_random_qp(key, 5)
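
By construction, x and z are complementary (they are never non-zero at the same index) and Q @ x - q recovers z; together with x >= 0 and z >= 0, these are the optimality conditions that make x the NNLS solution of this system. A quick sanity check (these lines are only for illustration, not part of the library):

# complementarity: x and z are never non-zero at the same index
print(jnp.dot(x, z))
0.0
# the construction q = Q @ x - z means Q @ x - q gives back the dual variable z
print(jnp.allclose(Q @ x - q, z))
True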

Next, let’s find the unconstrained solution using jnp.linalg.solve.

print(jnp.linalg.solve(Q, q))
[-0.49537474 -4.19172978  2.4482249   2.57614299 -3.93564659]

We can clearly see that this leads to negative values in the solution. Now let’s take the same system but use the NNLS solver. If you are only interested in the primal solution (i.e. x), you can use jaxnnls.solve_nnls_primal. If you want both the primal and dual solutions (along with some extra diagnostic information), use jaxnnls.solve_nnls.

jit_solve_nnls_primal = jax.jit(jaxnnls.solve_nnls_primal)
x_solve = jit_solve_nnls_primal(Q, q)
print(x_solve)
[0.24240651 0.         0.20491032 0.05982262 0.        ]

Now we can see that the solution found is non-negative, as desired. We can also check it against the known solution.

print(jnp.allclose(x, x_solve))
True

Solving a batch of problems

The solver is fully compatible with vmap for solving a set of problems at the same time. First, we will generate a set of random problems with known solutions.

key, _key = jax.random.split(_key)

Qs, qs, xs, zs = jax.vmap(generate_random_qp, in_axes=(0, None))(jax.random.split(key, 20), 5)

Now we will jit and vmap the solver and apply it to our set of problems.

batch_nnls = jax.jit(jax.vmap(jaxnnls.solve_nnls_primal, in_axes=(0, 0)))
batch_xs = batch_nnls(Qs, qs)
print(batch_xs)
[[0.         0.         0.         0.         0.75465726]
 [0.         0.2847717  0.26074245 0.45316464 0.        ]
 [0.9906791  0.         0.63885713 0.44994772 0.        ]
 [0.         0.252907   0.         0.         0.79185027]
 [3.32383184 1.09270692 0.         0.13081915 0.        ]
 [0.95259472 3.39704618 0.         0.         0.91332539]
 [0.16156309 0.24036181 0.75923941 0.         0.52760814]
 [0.01559403 0.13851779 0.         1.52683184 0.        ]
 [0.65993688 1.0579556  0.60431438 0.93401873 0.        ]
 [0.         0.21490957 0.09996912 0.61789174 0.        ]
 [1.35246462 0.         1.22011375 0.41823024 0.        ]
 [0.         0.80765402 0.         1.47898561 0.09715588]
 [0.         0.36266893 2.44237877 0.09065782 0.        ]
 [0.         0.         1.60061884 0.         0.28679851]
 [0.         0.         0.60722848 0.86527315 0.        ]
 [0.         0.53199946 2.24017961 0.74248151 0.16909305]
 [0.         0.         0.01502807 0.9199673  0.20331264]
 [0.         0.46073631 0.         0.14101541 0.14167207]
 [1.18345423 0.         0.         1.17427011 1.03613143]
 [0.54701621 0.24974786 0.         0.         1.66293739]]

We see that all the solutions are indeed non-negative, as expected. Now let’s check that they match the known solutions.

print(jnp.allclose(xs, batch_xs))
True
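
Note that in_axes follows the usual vmap semantics, so other batching patterns also work. For example, a single Q matrix can be shared across a batch of q vectors. Here is a minimal sketch using a hypothetical batch of right-hand sides:

key, _key = jax.random.split(_key)
# hypothetical batch of 20 right-hand sides that all share the single Q from above
qs_shared = jax.random.normal(key, (20, 5))

batch_shared_Q = jax.jit(jax.vmap(jaxnnls.solve_nnls_primal, in_axes=(None, 0)))
xs_shared = batch_shared_Q(Q, qs_shared)
print(xs_shared.shape)
(20, 5)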

Differentiating an NNLS problem

If we are only looking at the primal solution with jaxnnls.solve_nnls_primal, we can use automatic differentiation. For this example we will set up a simple loss function and calculate the gradients of that loss with respect to both Q and q.

def loss(Q, q, target_kappa=0):
    x = jaxnnls.solve_nnls_primal(Q, q, target_kappa=target_kappa)
    x_bar = jnp.ones_like(x)
    residual = x - x_bar
    return jnp.dot(residual, residual)


loss_and_grad = jax.jit(jax.value_and_grad(loss, argnums=(0, 1)))

l, (dl_dQ, dl_dq) = loss_and_grad(Q, q)
print(l)
4.090049007977832
print(dl_dQ)
[[0.19053153 0.         0.09559879 0.14735349 0.        ]
 [0.         0.         0.         0.         0.        ]
 [0.09559879 0.         0.02547621 0.10840558 0.        ]
 [0.14735349 0.         0.10840558 0.06112563 0.        ]
 [0.         0.         0.         0.         0.        ]]
print(dl_dq)
[-0.78600008 -0.         -0.12432857 -1.02178115 -0.        ]

In the above example we set target_kappa=0, meaning no smoothing is applied to the gradients. In general, when dealing with constrained solvers like this, the gradients can be discontinuous. Here we see that the gradients are only non-zero for the elements of x that are non-zero in the solution.

If we were aiming to minimize our loss with a gradient descent method, this would mean we could only move a subset of our parameters at a time. Increasing the target_kappa value smooths out these discontinuities, providing more useful information for gradient-based optimizers.

l_kappa, (dl_dQ_kappa, dl_dq_kappa) = loss_and_grad(Q, q, target_kappa=1e-3)
print(l_kappa)
4.090049007977832
print(dl_dQ_kappa)
[[0.2432692  0.12769956 0.13826567 0.10431171 0.00039145]
 [0.12769956 0.02300606 0.09821804 0.03492767 0.00032613]
 [0.13826567 0.09821804 0.06365574 0.07083384 0.00015223]
 [0.10431171 0.03492767 0.07083384 0.03579759 0.00022219]
 [0.00039145 0.00032613 0.00015223 0.00022219 0.0000003 ]]
print(dl_dq_kappa)
[-0.96254868 -0.91475926 -0.30862555 -0.59715493 -0.00042679]

We can see that the loss value has not changed, as the smoothing is only applied to the gradients. As for the two gradients, all the values have become non-zero. Now, if gradient descent were applied, all the values would move rather than just a subset of them.

For more information about the smoothing process, please refer to the qpax paper.
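
As a rough sketch of how the smoothed gradients might be used, a single (hypothetical) gradient descent step on Q and q could look like the following. The learning rate is an arbitrary choice for illustration, and a real optimizer would also need to keep Q symmetric positive definite.

learning_rate = 1e-2  # hypothetical step size, chosen only for illustration

# one plain gradient descent step on both parameters
Q_new = Q - learning_rate * dl_dQ_kappa
q_new = q - learning_rate * dl_dq_kappa

# re-evaluate the loss at the updated parameters
l_new, _ = loss_and_grad(Q_new, q_new, target_kappa=1e-3)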

Diagnostic information

In all the examples above we used jaxnnls.solve_nnls_primal, as we were only interested in the primal solution. If you want the dual solution or more diagnostic information, the jaxnnls.solve_nnls function is available.

x, s, z, converged, number_iterations = jaxnnls.solve_nnls(Q, q)

The outputs are:

  • x: the primal solution

  • s: the slack variable (will be the same as x if the algorithm converged)

  • z: the dual solution

  • converged: flag that is 1 if the algorithm converged and 0 otherwise

  • number_iterations: the number of steps the algorithm took to converge

Note

The code will run a maximum of 50 steps before stopping and reporting it did not converge.

Note

Automatic differentiation is only available for jaxnnls.solve_nnls_primal, not for this version of the function.

print(x)
[0.24240651 0.         0.20491032 0.05982262 0.        ]
print(s)
[0.24240651 0.         0.20491032 0.05982262 0.        ]
print(z)
[0.         0.02085643 0.         0.         1.43297883]
print(converged)
1
print(number_iterations)
10
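
The converged flag can be used to guard against silently using a result from a solve that hit the iteration limit; a minimal sketch:

# mask out the solution if the solver did not converge
# (here converged == 1, so x_checked is simply x)
x_checked = jnp.where(converged == 1, x, jnp.nan)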