Examples
In this example we will look at the basic usage of jaxnnls to solve a non-negative least squares (NNLS) problem.
Warning
While the algorithm can sometimes work with JAX's default 32-bit precision, it is recommended that you enable 64-bit precision. The Cholesky decompositions used can become unstable at lower precision and lead to nan results.
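If you want to confirm that 64-bit mode actually took effect, a minimal sketch (using only standard JAX calls) is to check the default floating-point dtype of a freshly created array:
import jax
# enable 64-bit mode before creating any arrays
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
# prints float64 when 64-bit mode is active, float32 otherwise
print(jnp.zeros(1).dtype)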
Basic usage
To begin we will write a function that randomly generates a non-trivial NNLS system.
import jax
# enable 64-bit mode
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import jaxnnls
# adjust the print options to make the output easier to read
jnp.set_printoptions(suppress=True)
def generate_random_qp(key, nx):
    # split the random key
    key_q, key_mask, key_x, key_z = jax.random.split(key, 4)
    # make a positive definite Q matrix
    Q = jax.random.normal(key_q, (nx, nx))
    Q = Q.T @ Q
    # make the primal and dual variables (all positive)
    x = jnp.abs(jax.random.normal(key_x, (nx,)))
    z = jnp.abs(jax.random.normal(key_z, (nx,)))
    # mask out 50% of the values to zero
    mask = jax.random.choice(key_mask, jnp.array([True, False]), (nx,))
    x = jnp.where(mask, x, 0)
    z = jnp.where(mask, 0, z)
    # make the "observed" vector that has x as its NNLS solution
    q = Q @ x - z
    return Q, q, x, z
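To see why the masked x returned by this function is an NNLS solution, note that the construction satisfies the optimality (KKT) conditions of the quadratic-program form of NNLS, which is the form this example assumes jaxnnls solves:

$$
\min_{x \ge 0} \; \tfrac{1}{2} x^\top Q x - q^\top x
\qquad\Longrightarrow\qquad
Q x - q = z, \quad x \ge 0, \quad z \ge 0, \quad x_i z_i = 0 .
$$

The mask enforces complementarity (each element has either x_i = 0 or z_i = 0), and setting q = Q @ x - z enforces stationarity, so x and z form a primal-dual optimal pair by construction.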
Now let’s make a JAX random key and generate an example system with a 5x5 Q matrix.
_key = jax.random.key(0)
key, _key = jax.random.split(_key)
Q, q, x, z = generate_random_qp(key, 5)
Next, let’s find the unconstrained solution using jnp.linalg.solve.
print(jnp.linalg.solve(Q, q))
[-0.49537474 -4.19172978 2.4482249 2.57614299 -3.93564659]
We can clearly see that this leads to negative values in the solution. Now let’s take the same system but use the NNLS solver. If you are only interested in the primal solution (i.e. x), you can use jaxnnls.solve_nnls_primal. If you want both the primal and dual solutions (along with some extra diagnostic information) you should use jaxnnls.solve_nnls.
jit_solve_nnls_primal = jax.jit(jaxnnls.solve_nnls_primal)
x_solve = jit_solve_nnls_primal(Q, q)
print(x_solve)
[0.24240651 0. 0.20491032 0.05982262 0. ]
Now we can see the solution being found is all positive as desired. We can also check this against the known solution.
print(jnp.allclose(x, x_solve))
True
Solving a batch of problems
The solver is fully compatible with vmap for solving a set of problems at the same time. First we will generate a set of random problems with known solutions.
key, _key = jax.random.split(_key)
Qs, qs, xs, zs = jax.vmap(generate_random_qp, in_axes=(0, None))(jax.random.split(key, 20), 5)
Now we will jit and vmap the solver and apply it to our set of problems.
batch_nnls = jax.jit(jax.vmap(jaxnnls.solve_nnls_primal, in_axes=(0, 0)))
batch_xs = batch_nnls(Qs, qs)
print(batch_xs)
[[0. 0. 0. 0. 0.75465726]
[0. 0.2847717 0.26074245 0.45316464 0. ]
[0.9906791 0. 0.63885713 0.44994772 0. ]
[0. 0.252907 0. 0. 0.79185027]
[3.32383184 1.09270692 0. 0.13081915 0. ]
[0.95259472 3.39704618 0. 0. 0.91332539]
[0.16156309 0.24036181 0.75923941 0. 0.52760814]
[0.01559403 0.13851779 0. 1.52683184 0. ]
[0.65993688 1.0579556 0.60431438 0.93401873 0. ]
[0. 0.21490957 0.09996912 0.61789174 0. ]
[1.35246462 0. 1.22011375 0.41823024 0. ]
[0. 0.80765402 0. 1.47898561 0.09715588]
[0. 0.36266893 2.44237877 0.09065782 0. ]
[0. 0. 1.60061884 0. 0.28679851]
[0. 0. 0.60722848 0.86527315 0. ]
[0. 0.53199946 2.24017961 0.74248151 0.16909305]
[0. 0. 0.01502807 0.9199673 0.20331264]
[0. 0.46073631 0. 0.14101541 0.14167207]
[1.18345423 0. 0. 1.17427011 1.03613143]
[0.54701621 0.24974786 0. 0. 1.66293739]]
We see that all the solutions are indeed positive as expected. Now let’s check if they match the known solutions.
print(jnp.allclose(xs, batch_xs))
True
Differentiating an NNLS problem
If we are only looking at the primal solution with jaxnnls.solve_nnls_primal, we can use automatic differentiation. For this example we will set up a simple loss function and calculate the gradients of that loss with respect to both Q and q.
def loss(Q, q, target_kappa=0):
    x = jaxnnls.solve_nnls_primal(Q, q, target_kappa=target_kappa)
    x_bar = jnp.ones_like(x)
    residual = x - x_bar
    return jnp.dot(residual, residual)
loss_and_grad = jax.jit(jax.value_and_grad(loss, argnums=(0, 1)))
l, (dl_dQ, dl_dq) = loss_and_grad(Q, q)
print(l)
4.090049007977832
print(dl_dQ)
[[0.19053153 0. 0.09559879 0.14735349 0. ]
[0. 0. 0. 0. 0. ]
[0.09559879 0. 0.02547621 0.10840558 0. ]
[0.14735349 0. 0.10840558 0.06112563 0. ]
[0. 0. 0. 0. 0. ]]
print(dl_dq)
[-0.78600008 -0. -0.12432857 -1.02178115 -0. ]
In the above example we set target_kappa=0. This means no smoothing will be applied to the gradients. In general, when dealing with constrained solvers like this, the gradients can be discontinuous. In this example we see that we only have non-zero gradient values for the elements of x that are non-zero when solved.
If we were aiming to minimize our loss using a gradient descent method, we would only be able to move a subset of our parameters at a time because of this. By increasing the target_kappa value these discontinuities will be smoothed out, providing more useful information for gradient-based optimizers.
l_kappa, (dl_dQ_kappa, dl_dq_kappa) = loss_and_grad(Q, q, target_kappa=1e-3)
print(l_kappa)
4.090049007977832
print(dl_dQ_kappa)
[[0.2432692 0.12769956 0.13826567 0.10431171 0.00039145]
[0.12769956 0.02300606 0.09821804 0.03492767 0.00032613]
[0.13826567 0.09821804 0.06365574 0.07083384 0.00015223]
[0.10431171 0.03492767 0.07083384 0.03579759 0.00022219]
[0.00039145 0.00032613 0.00015223 0.00022219 0.0000003 ]]
print(dl_dq_kappa)
[-0.96254868 -0.91475926 -0.30862555 -0.59715493 -0.00042679]
We can see that the loss value has not changed as the smoothing is only applied to the gradients. As for the two gradients, all the values have become non-zero. Now if gradient descent were applied, all the values would move rather than just a subset of them.
For more information about the smoothing process please refer to the qpax paper.
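To make this concrete, here is a minimal sketch of a plain gradient descent loop on q using the loss_and_grad function defined above. The step size and iteration count are arbitrary illustrative choices, and for simplicity only q is updated:
# minimal sketch: plain gradient descent on q with smoothed gradients
step_size = 1e-2  # illustrative learning rate
q_new = q
for i in range(100):
    _, (_, dl_dq) = loss_and_grad(Q, q_new, target_kappa=1e-3)
    # with smoothing, every entry of q_new moves at each step,
    # not just the entries tied to the currently non-zero elements of x
    q_new = q_new - step_size * dl_dq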
Diagnostic information
In all the examples above we used jaxnnls.solve_nnls_primal as we were only interested in the primal solution. If you want the dual solution or more diagnostic information, the jaxnnls.solve_nnls function is available.
x, s, z, converged, number_iterations = jaxnnls.solve_nnls(Q, q)
The outputs are:
x: the primal solution
s: the slack variable (will be the same as x if the algorithm converged)
z: the dual solution
converged: flag that is 1 if the algorithm converged and 0 otherwise
number_iterations: the number of steps the algorithm took to converge
Note
The code will run a maximum of 50 steps before stopping and reporting it did not converge (see the sketch at the end of this section for one way to check the converged flag).
Note
Automatic differentiation is only available for jaxnnls.solve_nnls_primal, not for jaxnnls.solve_nnls.
print(x)
[0.24240651 0. 0.20491032 0.05982262 0. ]
print(s)
[0.24240651 0. 0.20491032 0.05982262 0. ]
print(z)
[0. 0.02085643 0. 0. 1.43297883]
print(converged)
1
print(number_iterations)
10
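If you need to guard against hitting the iteration limit mentioned in the note above, a minimal sketch is to check the converged flag before using the result. The batch variant assumes jaxnnls.solve_nnls can be vmapped in the same way as jaxnnls.solve_nnls_primal:
# single problem: converged is 1 on success, 0 otherwise
x, s, z, converged, number_iterations = jaxnnls.solve_nnls(Q, q)
if not bool(converged):
    print("warning: solver hit the iteration limit without converging")
# batch of problems: vmap the full solver and check every flag
batch_solve_full = jax.jit(jax.vmap(jaxnnls.solve_nnls, in_axes=(0, 0)))
_, _, _, converged_batch, _ = batch_solve_full(Qs, qs)
print(jnp.all(converged_batch == 1))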