Primal Dual Interior Points

jaxnnls.pdip.centering_params(s, z, ds_a, dz_a)

Compute the duality gap and the centering-correction (cc) term used in the predictor-corrector PDIP step.
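The sketch below treats the cc term as the centering parameter from Mehrotra's predictor-corrector method. That reading is an assumption. The code uses NumPy instead of jax.numpy so it runs without JAX. The return order (sigma, mu), the cubic heuristic sigma = (mu_aff / mu)^3, and the helper max_step are all illustrative assumptions and are not taken from jaxnnls.

```python
import numpy as np


def max_step(v, dv):
    """Largest alpha <= 1 with v + alpha * dv >= 0 (hypothetical helper; v > 0)."""
    neg = dv < 0                                # only decreasing entries can hit zero
    if not np.any(neg):
        return 1.0
    return float(min(1.0, np.min(-v[neg] / dv[neg])))


def centering_params(s, z, ds_a, dz_a):
    """Sketch: duality gap mu and Mehrotra centering parameter sigma."""
    n = s.shape[0]
    mu = np.dot(s, z) / n                       # current average duality gap
    # how far the affine (predictor) direction can go while s, z stay >= 0
    alpha = min(max_step(s, ds_a), max_step(z, dz_a))
    mu_aff = np.dot(s + alpha * ds_a, z + alpha * dz_a) / n  # gap after predictor
    sigma = (mu_aff / mu) ** 3                  # Mehrotra's cubic heuristic
    return sigma, mu
```

A sigma near 0 means the predictor step alone makes good progress, so little recentering is needed. A sigma near 1 pulls the next iterate back toward the central path.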

jaxnnls.pdip.converged_check(inputs)

Check if the PDIP algorithm has converged

Parameters:

inputs (tuple) – Tuple of the current state (Q, q, x, s, z, solver_tol, converged, pdip_iter)

Returns:

True if converged or MAX_ITER reached, False otherwise

Return type:

bool

jaxnnls.pdip.factorize_kkt(Q, s, z)

Factorize and cache the matrices needed to solve the KKT system quickly.

Parameters:
  • Q (jax.numpy.array) – (n, n) positive definite matrix.

  • s (jax.numpy.array) – (n,) slack vector

  • z (jax.numpy.array) – (n,) dual vector

Returns:

  • P_inv_vec (jax.numpy.array) – (n,) elementwise ratio z / s (the inverse of the diagonal matrix P)

  • L_H (jax.numpy.array) – Cholesky factor of the matrix H = Q + diag(P_inv_vec)
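Here is a minimal NumPy sketch of this factorization. It follows the definitions above, with P_inv_vec = z / s and H = Q + diag(P_inv_vec). It assumes s > 0 and that H is positive definite. It is not the library's code.

```python
import numpy as np


def factorize_kkt(Q, s, z):
    """Sketch: cache z / s and the Cholesky factor of H = Q + diag(z / s)."""
    P_inv_vec = z / s                           # inverse of the diagonal P = diag(s / z)
    H = Q + np.diag(P_inv_vec)                  # Q with P_inv_vec added down the diagonal
    L_H = np.linalg.cholesky(H)                 # lower-triangular, H = L_H @ L_H.T
    return P_inv_vec, L_H
```

In a predictor-corrector method, the predictor and corrector solves share the same matrix H. Factorizing once per iteration and reusing L_H for both solves therefore avoids a second factorization.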

jaxnnls.pdip.initialize(Q, q)

Initialize primal and dual variables

Parameters:
  • Q (jax.numpy.array) – (n, n) positive definite matrix.

  • q (jax.numpy.array) – (n,) vector

Returns:

  • x (jax.numpy.array) – (n,) Initial primal variable

  • s (jax.numpy.array) – (n,) Initial slack variable

  • z (jax.numpy.array) – (n,) Initial dual variable

Find the largest step length alpha <= 1 such that x + alpha * dx >= 0.

Parameters:
  • x (jax.numpy.array) – (n,) vector

  • dx (jax.numpy.array) – (n,) search direction for x

Returns:

Maximum alpha <=1 such that x + alpha * dx >= 0

Return type:

float
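The ratio test behind this line search can be sketched in NumPy as follows. The name max_feasible_step is hypothetical and used only for illustration. The sketch assumes x is strictly positive.

```python
import numpy as np


def max_feasible_step(x, dx):
    """Largest alpha <= 1 such that x + alpha * dx >= 0 (x assumed strictly positive)."""
    decreasing = dx < 0                         # only these components can reach zero
    if not np.any(decreasing):
        return 1.0
    # component i reaches zero at alpha_i = -x_i / dx_i; the smallest one binds
    return float(min(1.0, np.min(-x[decreasing] / dx[decreasing])))
```

Interior point methods commonly scale this value by a factor slightly below 1 (for example 0.99) so that iterates stay strictly positive. This documentation does not say whether jaxnnls applies such a factor.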

jaxnnls.pdip.pdip_pc_step(inputs)

One step of the predictor-corrector PDIP algorithm.

Parameters:

inputs (tuple) – Tuple of the current state (Q, q, x, s, z, solver_tol, converged, pdip_iter)

Returns:

Updated state (Q, q, x, s, z, solver_tol, converged, pdip_iter)

Return type:

tuple

jaxnnls.pdip.solve_kkt_rhs(s, z, P_inv_vec, L_H, v1, v2, v3)

Solve the KKT system for the given right-hand side.

Parameters:
  • s (jax.numpy.array) – (n,) slack vector

  • z (jax.numpy.array) – (n,) dual vector

  • P_inv_vec (jax.numpy.array) – (n,) inverse of the diag P matrix

  • L_H (jax.numpy.array) – The Cholesky decomposition of the H matrix

  • v1 (jax.numpy.array) – (n,) negative residual 1

  • v2 (jax.numpy.array) – (n,) negative residual 2

  • v3 (jax.numpy.array) – (n,) negative residual 3

Returns:

  • dx (jax.numpy.array) – (n,) step direction for the primal variable

  • ds (jax.numpy.array) – (n,) step direction for the slack variable

  • dz (jax.numpy.array) – (n,) step direction for the dual variable
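The parameters above suggest block elimination onto H. In the NumPy sketch below, the three right-hand sides are assumed to enter the Newton system as Q dx - dz = v1, z * ds + s * dz = v2, and -dx + ds = v3. This sign and ordering convention is consistent with H = Q + diag(z / s), but the documentation does not confirm it.

```python
import numpy as np


def solve_kkt_rhs(s, z, P_inv_vec, L_H, v1, v2, v3):
    """Sketch: solve the assumed Newton system
        Q dx - dz       = v1
        z * ds + s * dz = v2
        -dx + ds        = v3
    using the cached Cholesky factor L_H of H = Q + diag(z / s)."""
    # eliminating ds = v3 + dx and dz = (v2 - z * ds) / s leaves H dx = rhs
    rhs = v1 + v2 / s - P_inv_vec * v3
    y = np.linalg.solve(L_H, rhs)               # forward substitution: L_H y = rhs
    dx = np.linalg.solve(L_H.T, y)              # back substitution: L_H^T dx = y
    ds = v3 + dx
    dz = (v2 - z * ds) / s
    return dx, ds, dz
```

In JAX, the two triangular solves could use jax.scipy.linalg.solve_triangular or jax.scipy.linalg.cho_solve. This sketch uses np.linalg.solve only to stay dependency-free.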

jaxnnls.pdip.solve_nnls(Q, q)

Solve the non-negative least squares problem.

Parameters:
  • Q (jax.numpy.array) – (n, n) positive definite matrix.

  • q (jax.numpy.array) – (n,) vector

Returns:

  • x (jax.numpy.array) – (n,) solution x to Qx=q such that x >= 0

  • s (jax.numpy.array) – (n,) slack variable at the solution

  • z (jax.numpy.array) – (n,) dual variable at the solution

  • converged (int) – 1 if the algorithm converged, 0 otherwise

  • pdip_iter (int) – The number of PDIP iterations taken
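The following self-contained NumPy sketch shows how these pieces fit together in a predictor-corrector PDIP loop for NNLS. It rests on several assumptions:

- the objective is 0.5 x^T Q x - q^T x, so that without the constraint the solution satisfies Qx = q;
- the slack convention is s = x;
- iteration starts from all ones;
- convergence uses a simple tolerance test.

solve_nnls_sketch and its helpers are hypothetical and do not reproduce the jaxnnls implementation.

```python
import numpy as np


def _max_step(v, dv):
    # largest alpha <= 1 keeping v + alpha * dv >= 0
    neg = dv < 0
    return 1.0 if not np.any(neg) else float(min(1.0, np.min(-v[neg] / dv[neg])))


def _newton(Q, s, z, v1, v2, v3):
    # solve Q dx - dz = v1, z*ds + s*dz = v2, -dx + ds = v3 via H = Q + diag(z / s)
    H = Q + np.diag(z / s)
    dx = np.linalg.solve(H, v1 + (v2 - z * v3) / s)
    ds = v3 + dx
    dz = (v2 - z * ds) / s
    return dx, ds, dz


def solve_nnls_sketch(Q, q, tol=1e-8, max_iter=100):
    """Predictor-corrector PDIP for min 0.5 x'Qx - q'x s.t. x >= 0 (a sketch)."""
    n = q.shape[0]
    x, s, z = np.ones(n), np.ones(n), np.ones(n)   # strictly positive start
    it = 0
    while True:
        r1 = Q @ x - q - z                      # stationarity residual
        r3 = s - x                              # slack definition s = x
        mu = s @ z / n                          # average duality gap
        if max(np.abs(r1).max(), np.abs(r3).max(), mu) < tol:
            return x, s, z, 1, it
        if it == max_iter:
            return x, s, z, 0, it
        # predictor: affine-scaling direction that drives s * z toward 0
        dx_a, ds_a, dz_a = _newton(Q, s, z, -r1, -s * z, -r3)
        alpha_a = min(_max_step(s, ds_a), _max_step(z, dz_a))
        sigma = ((s + alpha_a * ds_a) @ (z + alpha_a * dz_a) / n / mu) ** 3
        # corrector: recenter toward sigma * mu and cancel the ds_a * dz_a term
        v2 = sigma * mu - s * z - ds_a * dz_a
        dx, ds, dz = _newton(Q, s, z, -r1, v2, -r3)
        alpha = min(1.0, 0.99 * min(_max_step(s, ds), _max_step(z, dz)))
        x, s, z = x + alpha * dx, s + alpha * ds, z + alpha * dz
        it += 1
```

Like solve_nnls, the sketch returns (x, s, z, converged, pdip_iter). For example, with Q = I and q = (1, -2), it should approach x = (1, 0), the projection of the unconstrained solution onto x >= 0.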

Implicit Differentiation

jaxnnls.diff_qp.diff_nnls(Q, z, s, lam, dl_dz)

Compute implicit derivatives of the NNLS solution from the KKT conditions, given the upstream gradient dl_dz.
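This sketch keeps the same assumed objective and also assumes strict complementarity at the solution. Under those conditions, implicit differentiation of the KKT conditions reduces to a linear solve on the free set {i : x_i > 0}. The NumPy code below computes this exact (unrelaxed) vector-Jacobian product. nnls_vjp_active_set is a hypothetical helper that illustrates the technique; it is not what diff_nnls computes.

```python
import numpy as np


def nnls_vjp_active_set(Q, x, g, tol=1e-10):
    """Sketch: exact VJP of the solution of min 0.5 x'Qx - q'x s.t. x >= 0.

    Assumes strict complementarity, so the free set F = {i : x_i > tol} is
    locally constant, x_F = Q_FF^{-1} q_F, and x_i = 0 elsewhere.
    g is the upstream gradient dl/dx; returns (dl/dq, dl/dQ).
    """
    free = x > tol
    u = np.zeros_like(x)
    # dl/dq_F = Q_FF^{-1} g_F (Q symmetric); active entries get zero gradient
    u[free] = np.linalg.solve(Q[np.ix_(free, free)], g[free])
    # differentiating x_F = Q_FF^{-1} q_F w.r.t. Q gives -u x^T
    # (entries of Q treated as independent, no symmetrization)
    dl_dQ = -np.outer(u, x)
    return u, dl_dQ
```

The exact gradient is zero on the active set. That is one reason to differentiate a relaxed problem instead, as solve_nnls_primal does through target_kappa.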

jaxnnls.diff_qp.solve_nnls_primal_backward(res, input_grad)

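Custom backward pass for solve_nnls_primal: turns the saved residuals res and the incoming gradient input_grad into gradients with respect to the inputs.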

jaxnnls.diff_qp.solve_nnls_primal_forward(Q, q, target_kappa=0.001)

Custom forward pass for solve_nnls_primal: solves the problem and saves the residuals needed by the backward pass.

jaxnnls.diff_qp.solve_nnls_primal(Q, q, target_kappa=0.001)

Solve the non-negative least squares problem and support differentiation through a (relaxed) gradient.

Parameters:
  • Q (jax.numpy.array) – (n, n) positive definite matrix.

  • q (jax.numpy.array) – (n,) vector

  • target_kappa (float) – target relaxation parameter used for the gradient

Returns:

(n,) solution x to Qx=q such that x >= 0

Return type:

jax.numpy.array

Relaxed Primal Dual Interior Points

jaxnnls.pdip_relaxed.converged_check_relaxed(inputs)

Check if the relaxed PDIP algorithm has converged

Parameters:

inputs (tuple) – Tuple of the current state (Q, q, x, s, z, solver_tol, converged, pdip_iter, target_kappa)

Returns:

True if converged or MAX_ITER reached, False otherwise

Return type:

bool

jaxnnls.pdip_relaxed.pdip_pc_step_relaxed(inputs)

One step of the relaxed predictor-corrector PDIP algorithm.

Parameters:

inputs (tuple) – Tuple of the current state (Q, q, x, s, z, solver_tol, converged, pdip_iter, target_kappa)

Returns:

Updated state (Q, q, x, s, z, solver_tol, converged, pdip_iter, target_kappa)

Return type:

tuple

jaxnnls.pdip_relaxed.solve_relaxed_nnls(Q, q, x, s, z, target_kappa=0.001)

Solve the relaxed non-negative least squares problem.

Parameters:
  • Q (jax.numpy.array) – (n, n) positive definite matrix.

  • q (jax.numpy.array) – (n,) vector

  • x (jax.numpy.array) – (n,) primal vector

  • s (jax.numpy.array) – (n,) slack vector

  • z (jax.numpy.array) – (n,) dual vector

  • target_kappa (float) – target relaxation parameter

Returns:

  • x (jax.numpy.array) – (n,) relaxed solution x to Qx=q such that x >= 0

  • s (jax.numpy.array) – (n,) slack variable at the relaxed solution

  • z (jax.numpy.array) – (n,) dual variable at the relaxed solution

  • converged (int) – 1 if the algorithm converged, 0 otherwise

  • pdip_iter (int) – The number of relaxed PDIP iterations taken
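One common way to relax NNLS holds the complementarity product at s * z = target_kappa instead of driving it to zero. This sketch assumes that reading; the documentation does not confirm it for jaxnnls. With s = x, the relaxed condition is the optimality condition of a log-barrier problem. The hypothetical solve_relaxed_sketch below solves that problem by Newton's method with backtracking. Unlike solve_relaxed_nnls, it does not warm-start from given x, s, and z.

```python
import numpy as np


def solve_relaxed_sketch(Q, q, kappa=1e-3, eps=1e-12, max_iter=100):
    """Sketch: minimize f(x) = 0.5 x'Qx - q'x - kappa * sum(log x) by Newton's method.

    At the minimizer the multiplier z = Qx - q satisfies x * z = kappa
    elementwise, i.e. complementarity x * z = 0 is relaxed to x * z = kappa.
    """
    def f(x):
        return 0.5 * x @ Q @ x - q @ x - kappa * np.sum(np.log(x))

    x = np.ones(q.shape[0])                     # strictly positive start
    for _ in range(max_iter):
        grad = Q @ x - q - kappa / x
        hess = Q + np.diag(kappa / x**2)
        dx = -np.linalg.solve(hess, grad)
        decrement_sq = -grad @ dx               # squared Newton decrement
        if decrement_sq / 2 <= eps:
            break
        # backtracking: stay strictly positive, then require sufficient decrease
        t = 1.0
        for _ in range(60):
            x_new = x + t * dx
            if np.all(x_new > 0) and f(x_new) <= f(x) - 0.25 * t * decrement_sq:
                break
            t *= 0.5
        x = x_new
    return x, Q @ x - q
```

As kappa shrinks, the relaxed solution approaches the exact NNLS solution. A larger kappa keeps every component strictly positive, so the solution varies smoothly with Q and q.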