JaxNNLS

This package solves non-negative least squares (NNLS) problems of the following form:

\[\begin{aligned} \underset{x}{\text{minimize}} & \quad \frac{1}{2}x^TQx - q^Tx \\ \text{subject to} & \quad x \geq 0 \end{aligned}\]

where Q is positive definite. Equivalently, it solves the following system in the least-squares sense:

\[\begin{aligned} \underset{x}{\text{solve}} & \quad Ax = b \\ \text{subject to} & \quad x \geq 0 \end{aligned}\]

when you set

\[\begin{aligned} Q&=A^TA \\ q&=A^Tb \end{aligned}\]
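The equivalence comes from expanding the least-squares objective: one half of the squared norm of Ax - b equals the quadratic objective above plus the constant one half of b^Tb, which does not depend on x. The snippet below is a minimal numerical check of that identity in plain JAX. The matrix `A`, the vector `b`, and the test point `x` are made-up values for illustration, and nothing here calls JaxNNLS itself.

```python
import jax.numpy as jnp

# Illustrative data (not from the package): a tall 3x2 matrix with full column rank.
A = jnp.array([[1.0, 0.0],
               [0.0, 1.0],
               [1.0, 1.0]])
b = jnp.array([1.0, -1.0, 0.5])

# The substitution from the text.
Q = A.T @ A  # positive definite because A has full column rank
q = A.T @ b


def qp_objective(x):
    """Quadratic-program form: 1/2 x^T Q x - q^T x."""
    return 0.5 * x @ Q @ x - q @ x


def ls_objective(x):
    """Least-squares form: 1/2 ||A x - b||^2."""
    r = A @ x - b
    return 0.5 * r @ r


# At any x the two objectives differ by the same constant, 1/2 b^T b,
# so they share the same minimizer over x >= 0.
x = jnp.array([0.3, 0.7])
gap = ls_objective(x) - qp_objective(x)
constant = 0.5 * b @ b
```

Because the gap is the same for every x, minimizing either objective over x >= 0 gives the same solution. Note that Q is positive definite only when A has full column rank.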

This solver can be combined with JAX’s jit and vmap functionality, as well as differentiated with reverse-mode grad.
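As a rough illustration of what composing a solver with these transformations looks like, the sketch below applies `jax.jit`, `jax.vmap`, and `jax.grad` to a small NNLS solver. It deliberately does not call JaxNNLS: `toy_nnls` is a hypothetical stand-in (projected gradient descent with a fixed number of steps) written only for this example, and the package's own solver function would take its place in real use.

```python
import jax
import jax.numpy as jnp

NUM_STEPS = 200  # fixed iteration count keeps the loop traceable by jit/vmap/grad


def toy_nnls(Q, q):
    """Hypothetical stand-in solver: projected gradient descent on
    1/2 x^T Q x - q^T x subject to x >= 0. NOT the JaxNNLS algorithm."""
    step = 1.0 / jnp.trace(Q)  # trace(Q) >= largest eigenvalue of Q, so this step is safe

    def body(x, _):
        x = jnp.maximum(x - step * (Q @ x - q), 0.0)  # gradient step, then clip to x >= 0
        return x, None

    x, _ = jax.lax.scan(body, jnp.zeros_like(q), None, length=NUM_STEPS)
    return x


Q = jnp.array([[2.0, 1.0], [1.0, 2.0]])
q = jnp.array([1.5, -0.5])

# jit: compile the solver once, then reuse it.
x = jax.jit(toy_nnls)(Q, q)

# vmap: solve a batch of problems that share Q but have different q.
qs = jnp.array([[1.5, -0.5], [1.0, 1.0]])
xs = jax.vmap(toy_nnls, in_axes=(None, 0))(Q, qs)

# grad: reverse-mode derivative of a scalar function of the solution with respect to q.
dq = jax.grad(lambda q_: jnp.sum(toy_nnls(Q, q_)))(q)
```

Here `in_axes=(None, 0)` tells `vmap` to broadcast `Q` and map over the leading axis of `qs`. The same pattern extends to batching over `Q` as well.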

The NNLS problem is solved with a primal-dual interior point algorithm. The code is a modification of the qpax package, specialized to the NNLS case. The simplifications available in this special case make the resulting code significantly faster when Q is large.
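For readers unfamiliar with the method, the following is a minimal sketch of a generic primal-dual interior-point iteration for this problem. It is a textbook-style illustration written for this page, not the algorithm as implemented in JaxNNLS or qpax. The function names, the fixed centering parameter `sigma`, the 0.99 step fraction, the all-ones starting point, and the fixed iteration count are all assumptions chosen for simplicity.

```python
import jax.numpy as jnp


def max_step(v, dv, fraction=0.99):
    """Largest step in (0, 1] that keeps v + step * dv strictly positive."""
    ratios = jnp.where(dv < 0, -v / dv, jnp.inf)
    return jnp.minimum(1.0, fraction * jnp.min(ratios))


def nnls_interior_point(Q, q, num_iters=20, sigma=0.1):
    """Textbook-style primal-dual interior-point sketch for
    min 1/2 x^T Q x - q^T x  s.t.  x >= 0  (illustrative, not the package code)."""
    n = q.shape[0]
    x = jnp.ones(n)  # primal iterate, kept strictly positive
    z = jnp.ones(n)  # dual iterate (multipliers for x >= 0), kept strictly positive
    for _ in range(num_iters):
        r_dual = Q @ x - q - z  # stationarity residual: Q x - q - z = 0 at a solution
        mu = x @ z / n          # average complementarity gap x_i * z_i
        # Newton step on the perturbed KKT system (x_i * z_i = sigma * mu),
        # with dz eliminated: (Q + diag(z / x)) dx = sigma * mu / x - z - r_dual.
        dx = jnp.linalg.solve(Q + jnp.diag(z / x), sigma * mu / x - z - r_dual)
        dz = Q @ dx + r_dual
        # One shared step length that keeps both x and z positive.
        step = jnp.minimum(max_step(x, dx), max_step(z, dz))
        x = x + step * dx
        z = z + step * dz
    return x, z


Q = jnp.array([[2.0, 1.0], [1.0, 2.0]])
q = jnp.array([1.5, -0.5])
x, z = nnls_interior_point(Q, q)  # x approaches [0.75, 0], z approaches [0, 1.25]
```

Each iteration solves one linear system of the same size as Q. The only inequality constraints are the bounds x >= 0, so the system stays this small.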

As with the qpax code, derivative smoothing can be applied to the gradients.

Link to the documentation.