Primal Dual Interior Points
- jaxnnls.pdip.centering_params(s, z, ds_a, dz_a)
Compute the duality gap and the centering (cc) term used in the predictor-corrector PDIP step
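A minimal sketch of how a duality gap and centering term are often computed in Mehrotra-style predictor-corrector methods. Whether jaxnnls uses exactly this heuristic is an assumption; the names `max_step` and `centering_params_sketch`, and the return order, are hypothetical.

```python
import jax.numpy as jnp


def max_step(v, dv):
    # Largest alpha in [0, 1] that keeps v + alpha * dv >= 0.
    ratios = jnp.where(dv < 0, -v / dv, jnp.inf)
    return jnp.minimum(1.0, jnp.min(ratios))


def centering_params_sketch(s, z, ds_a, dz_a):
    n = s.shape[0]
    # Duality gap measure: average complementarity product s_i * z_i.
    mu = jnp.dot(s, z) / n
    # Longest affine (predictor) step keeping both s and z non-negative.
    alpha = jnp.minimum(max_step(s, ds_a), max_step(z, dz_a))
    # Gap that the affine step alone would reach.
    mu_aff = jnp.dot(s + alpha * ds_a, z + alpha * dz_a) / n
    # Assumed Mehrotra heuristic: cube of the predicted gap reduction.
    sigma = (mu_aff / mu) ** 3
    return sigma, mu
```

A small `sigma` means the affine step already shrinks the gap well, so little centering is applied in the corrector step.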
- jaxnnls.pdip.converged_check(inputs)
Check if the PDIP algorithm has converged
- Parameters:
inputs (tuple) – Tuple of the current state (Q, q, x, s, z, solver_tol, converged, pdip_iter)
- Returns:
True if converged or MAX_ITER reached, False otherwise
- Return type:
bool
- jaxnnls.pdip.factorize_kkt(Q, s, z)
Factorize and cache the matrix values needed to solve the KKT conditions quickly
- Parameters:
Q (jax.numpy.array) – (n, n) positive definite matrix.
s (jax.numpy.array) – (n,) slack vector
z (jax.numpy.array) – (n,) dual vector
- Returns:
P_inv_vec (jax.numpy.array) – (n,) ratio of z and s (inverse of the diag P matrix)
L_H (jax.numpy.array) – The Cholesky decomposition of the H matrix (Q with P_inv_vec added down the diag)
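The two return values described above can be sketched as follows. This is an illustration built only from the descriptions in this entry, not the library's source; `factorize_kkt_sketch` and `solve_with_factor` are hypothetical names.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve


def factorize_kkt_sketch(Q, s, z):
    # Elementwise ratio z / s: the inverse of the diagonal P matrix.
    P_inv_vec = z / s
    # H is Q with P_inv_vec added down the diagonal.
    H = Q + jnp.diag(P_inv_vec)
    # Lower-triangular Cholesky factor, so that L_H @ L_H.T equals H.
    L_H = jnp.linalg.cholesky(H)
    return P_inv_vec, L_H


def solve_with_factor(L_H, rhs):
    # Reuse the cached factor to solve H @ y = rhs without refactorizing.
    return cho_solve((L_H, True), rhs)
```

Caching `L_H` pays off when several right-hand sides are solved against the same `H`, as happens in a predictor step followed by a corrector step.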
- jaxnnls.pdip.initialize(Q, q)
Initialize primal and dual variables
- Parameters:
Q (jax.numpy.array) – (n, n) positive definite matrix.
q (jax.numpy.array) – (n,) vector
- Returns:
x (jax.numpy.array) – (n,) Initial primal variable
s (jax.numpy.array) – (n,) Initial slack variable
z (jax.numpy.array) – (n,) Initial dual variable
- jaxnnls.pdip.ort_line_search(x, dx)
Maximum alpha <=1 such that x + alpha * dx >= 0
- Parameters:
x (jax.numpy.array) – (n,) vector
dx (jax.numpy.array) – (n,) step direction for vector x
- Returns:
Maximum alpha <=1 such that x + alpha * dx >= 0
- Return type:
float
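A minimal sketch of the rule stated above, assuming `x` is already non-negative. The name `ort_line_search_sketch` is hypothetical and the library's implementation may differ.

```python
import jax.numpy as jnp


def ort_line_search_sketch(x, dx):
    # Only components moving toward zero (dx < 0) can limit the step;
    # each of them reaches zero at alpha = -x / dx.
    ratios = jnp.where(dx < 0, -x / dx, jnp.inf)
    # Cap at 1 so that a full step is taken when nothing blocks it.
    return jnp.minimum(1.0, jnp.min(ratios))
```

For example, with `x = [1, 2]` and `dx = [-2, 1]` the first component hits zero at alpha = 0.5, which becomes the returned step length.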
- jaxnnls.pdip.pdip_pc_step(inputs)
One step of the predictor-corrector PDIP algorithm.
- Parameters:
inputs (tuple) – Tuple of the current state (Q, q, x, s, z, solver_tol, converged, pdip_iter)
- Returns:
Updated state (Q, q, x, s, z, solver_tol, converged, pdip_iter)
- Return type:
tuple
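A stopping check and a step function that both take a state tuple fit the shape expected by `jax.lax.while_loop`. The sketch below shows that pattern on a toy Newton iteration for a square root, standing in for a PDIP step. All names (`done_check`, `step`, `run`) and the `MAX_ITER` value are hypothetical. Note that `while_loop` keeps iterating while its condition is True, so a check that returns True on convergence has to be negated.

```python
import jax
import jax.numpy as jnp

MAX_ITER = 50  # hypothetical iteration cap


def done_check(state):
    # True if converged or MAX_ITER reached, False otherwise.
    a, x, tol, converged, it = state
    return jnp.logical_or(converged == 1, it >= MAX_ITER)


def step(state):
    # One Newton update for x**2 = a (a stand-in for one solver step).
    a, x, tol, converged, it = state
    x_new = 0.5 * (x + a / x)
    converged = (jnp.abs(x_new - x) < tol).astype(jnp.int32)
    return a, x_new, tol, converged, it + 1


def run(a):
    a = jnp.asarray(a, dtype=jnp.float32)
    # Explicit dtypes keep the carried state identical across iterations.
    init = (a, a, jnp.float32(1e-6), jnp.int32(0), jnp.int32(0))
    return jax.lax.while_loop(
        lambda st: jnp.logical_not(done_check(st)), step, init
    )
```

Carrying `converged` and the iteration counter inside the state lets the caller read both after the loop finishes.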
- jaxnnls.pdip.solve_kkt_rhs(s, z, P_inv_vec, L_H, v1, v2, v3)
Solve the KKT conditions for a given right-hand side
- Parameters:
s (jax.numpy.array) – (n,) slack vector
z (jax.numpy.array) – (n,) dual vector
P_inv_vec (jax.numpy.array) – (n,) inverse of the diag P matrix
L_H (jax.numpy.array) – The Cholesky decomposition of the H matrix
v1 (jax.numpy.array) – (n,) negative residual 1
v2 (jax.numpy.array) – (n,) negative residual 2
v3 (jax.numpy.array) – (n,) negative residual 3
- Returns:
dx (jax.numpy.array) – (n,) step direction for primal vector
ds (jax.numpy.array) – (n,) step direction for slack vector
dz (jax.numpy.array) – (n,) step direction for dual vector
- jaxnnls.pdip.solve_nnls(Q, q)
Solve the non-negative least squares problem.
- Parameters:
Q (jax.numpy.array) – (n, n) positive definite matrix.
q (jax.numpy.array) – (n,) vector
- Returns:
x (jax.numpy.array) – (n,) solution x to Qx=q such that x >= 0
s (jax.numpy.array) – (n,) slack variable at the solution
z (jax.numpy.array) – (n,) dual variable at the solution
converged (int) – 1 if the algorithm converged, 0 otherwise
pdip_iter (int) – The number of PDIP iterations taken
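One way to sanity-check a returned `x` is to compare it against a simple independent reference. The sketch below uses projected gradient descent, which is not the PDIP algorithm documented here. It assumes the problem is to minimize 0.5 x^T Q x - q^T x subject to x >= 0, an interpretation of "solution x to Qx=q such that x >= 0" that this page does not state explicitly. The function name and `num_iter` are hypothetical.

```python
import jax
import jax.numpy as jnp


def nnls_projected_gradient(Q, q, num_iter=500):
    # Step size 1 / L, where L is the largest eigenvalue of Q.
    step = 1.0 / jnp.max(jnp.linalg.eigvalsh(Q))

    def body(_, x):
        # Gradient step on the assumed objective, then project onto x >= 0.
        return jnp.maximum(x - step * (Q @ x - q), 0.0)

    return jax.lax.fori_loop(0, num_iter, body, jnp.zeros_like(q))
```

With `Q = [[2, 1], [1, 2]]` and `q = [2, -2]`, the unconstrained solution has a negative second component, so the constrained solution clamps it to zero.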
Implicit Differentiation
- jaxnnls.diff_qp.diff_nnls(Q, z, s, lam, dl_dz)
Implicit derivatives
- jaxnnls.diff_qp.solve_nnls_primal_backward(res, input_grad)
Custom backwards pass derivative
- jaxnnls.diff_qp.solve_nnls_primal_forward(Q, q, target_kappa=0.001)
Custom forward pass derivative
- jaxnnls.diff_qp.solve_nnls_primal(Q, q, target_kappa=0.001)
Solve the non-negative least squares problem with the ability to take a (relaxed) gradient.
- Parameters:
Q (jax.numpy.array) – (n, n) positive definite matrix.
q (jax.numpy.array) – (n,) vector
target_kappa (float) – target relaxation parameter used for the gradient
- Returns:
(n,) solution x to Qx=q such that x >= 0
- Return type:
jax.numpy.array
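The forward/backward pair above follows the shape of JAX's `jax.custom_vjp` mechanism. The sketch below shows that pattern with implicit differentiation on a stand-in problem, the unconstrained linear solve Q x = q. It is not the NNLS derivative itself, and all function names are hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def solve_primal(Q, q):
    # Stand-in "solver": unconstrained solution of Q x = q.
    return jnp.linalg.solve(Q, q)


def solve_primal_forward(Q, q):
    x = jnp.linalg.solve(Q, q)
    # Return the output plus the residuals the backward pass needs.
    return x, (Q, x)


def solve_primal_backward(res, input_grad):
    Q, x = res
    # Differentiate Q x = q implicitly instead of unrolling the solver:
    # dq = Q^{-T} g and dQ = -dq x^T.
    dq = jnp.linalg.solve(Q.T, input_grad)
    dQ = -jnp.outer(dq, x)
    return dQ, dq


solve_primal.defvjp(solve_primal_forward, solve_primal_backward)
```

With the rule registered, `jax.grad` of any scalar loss of `solve_primal(Q, q)` calls the backward function rather than differentiating through the solver's internals.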
Relaxed Primal Dual Interior Points
- jaxnnls.pdip_relaxed.converged_check_relaxed(inputs)
Check if the relaxed PDIP algorithm has converged
- Parameters:
inputs (tuple) – Tuple of the current state (Q, q, x, s, z, solver_tol, converged, pdip_iter, target_kappa)
- Returns:
True if converged or MAX_ITER reached, False otherwise
- Return type:
bool
- jaxnnls.pdip_relaxed.pdip_pc_step_relaxed(inputs)
One step of the relaxed predictor-corrector PDIP algorithm.
- Parameters:
inputs (tuple) – Tuple of the current state (Q, q, x, s, z, solver_tol, converged, pdip_iter, target_kappa)
- Returns:
Updated state (Q, q, x, s, z, solver_tol, converged, pdip_iter, target_kappa)
- Return type:
tuple
- jaxnnls.pdip_relaxed.solve_relaxed_nnls(Q, q, x, s, z, target_kappa=0.001)
Solve the relaxed non-negative least squares problem.
- Parameters:
Q (jax.numpy.array) – (n, n) positive definite matrix.
q (jax.numpy.array) – (n,) vector
x (jax.numpy.array) – (n,) primal vector
s (jax.numpy.array) – (n,) slack vector
z (jax.numpy.array) – (n,) dual vector
target_kappa (float) – target relaxation parameter
- Returns:
x (jax.numpy.array) – (n,) relaxed solution x to Qx=q such that x >= 0
s (jax.numpy.array) – (n,) slack variable at the relaxed solution
z (jax.numpy.array) – (n,) dual variable at the relaxed solution
converged (int) – 1 if the algorithm converged, 0 otherwise
pdip_iter (int) – The number of relaxed PDIP iterations taken