Jax: apply constraints in optimizer

Created on 25 Apr 2020 · 7 Comments · Source: google/jax

How can I apply constraints in the optimizer, i.e. modify the params each time after the opt_update function runs?

question

All 7 comments

There are some functions in optimizers.py for learning rate schedules, but for other parameters I don't think there's an easy way.

Thanks for the question! Are you thinking of something like, as one example, projected gradient algorithms?

In general, I recommend just treating JAX like you would treat NumPy here: just write it! The JAX API is meant to be just the NumPy API plus function transformations (grad, jit, vmap, etc.). So: how would you write such an algorithm in NumPy?
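For reference, a projected gradient method alternates an ordinary gradient step with a Euclidean projection $P_C$ onto the feasible set $C$. Writing $\eta$ for the step size (this notation is ours, not the thread's), one iteration is:

$$
x_{k+1} = P_C\big(x_k - \eta \nabla f(x_k)\big), \qquad P_C(z) = \arg\min_{u \in C} \lVert u - z \rVert_2 .
$$

For the nonnegative orthant $C = \{x : x \ge 0\}$ used later in this thread, $P_C(z) = \max(z, 0)$ elementwise, which is what `nonnegative_projector` computes.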

It's true that we have example libraries like jax.experimental.optimizers, but I wouldn't think of those as a framework. They're meant to be inspirational snippets for you to fork for your own purposes. Don't feel constrained by them! Perhaps you can adapt them to constrained optimization, or perhaps for the constrained optimization case you need something different.

cc @mblondel @froystig who might have some related work and expertise.

@mattjj Yes, something like projected gradient. Right now I'm doing something hacky: I call optimizers.unpack_optimizer_state, update the parameters stored in the first item of the JoinPoint.subtree tuple, and then call optimizers.pack_optimizer_state. But this relies on private details of the optimizer state (namely, that the parameters are always the first item of the tuple).

If we want to think of the projector as problem instance data (rather than being built into the optimization routine), maybe you could write things like:

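import jax.numpy as np
# note: jax.experimental.optimizers was later moved to jax.example_libraries.optimizers
from jax.experimental.optimizers import optimizer

@optimizer
def pgd(step_size, projector):
  # the optimizer state is just the current iterate x (no momentum buffers)
  def init(x0):
    return x0
  # one projected gradient step: take a gradient step, then project back
  def update(i, g, x):
    return projector(x - step_size * g)
  def get_params(x):
    return x
  return init, update, get_params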

The idea is that projector does the projection you want, like Euclidean projection onto the nonnegative orthant or whatever, operating only on arrays. (The optimizer decorator just serves to map the triple of functions that work on arrays to a triple of functions that work on nested lists/tuples/dicts of arrays via mapping over them.)
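Because the projector is just a function on arrays, other convex constraints can be swapped in the same way. A minimal sketch of two more projectors; the names `box_projector` and `l2_ball_projector` and their default bounds are hypothetical, chosen only for illustration:

```python
import jax.numpy as jnp

# Hypothetical example projectors: each maps an array to the closest point
# (in Euclidean distance) of a convex set, operating only on arrays.

def box_projector(x, lower=0.0, upper=1.0):
  # Projection onto the box {x : lower <= x <= upper} is an elementwise clip.
  return jnp.clip(x, lower, upper)

def l2_ball_projector(x, radius=1.0):
  # Projection onto {x : ||x||_2 <= radius}: rescale x only if it lies outside.
  norm = jnp.linalg.norm(x)
  scale = jnp.minimum(1.0, radius / jnp.maximum(norm, 1e-12))
  return x * scale

print(box_projector(jnp.array([-0.5, 0.3, 2.0])))   # clipped into [0, 1]
print(l2_ball_projector(jnp.array([3.0, 4.0])))     # rescaled to unit norm
```

Either could be passed as the `projector` argument of `pgd`; to change the bounds, `functools.partial(box_projector, lower=..., upper=...)` keeps the one-argument signature.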

You can then use it like this:

# problem instance from Matlab docs:
# https://www.mathworks.com/help/matlab/ref/lsqnonneg.html
A = np.array([[0.0372, 0.2869],
              [0.6861, 0.7071],
              [0.6233, 0.6245],
              [0.6344, 0.6170]])
y = np.array([0.8587, 0.1781, 0.0747, 0.8405])

# unconstrained least squares, for comparison
import numpy as onp  # no lstsq in jax.numpy.linalg yet
unc_x, *_ = onp.linalg.lstsq(A, y)
print(unc_x)  # [-2.5627463  3.1107764]  (x[0] < 0, so the constraint matters)

# nonnegative least squares (NNLS) via pgd
from jax import grad

def loss(A, x, y):
  prediction = np.dot(A, x)
  return np.sum((prediction - y) ** 2)

def nonnegative_projector(x):
  # Euclidean projection onto the nonnegative orthant
  return np.maximum(x, 0)

x_init = np.ones(2)
init, update, get_params = pgd(1e-2, nonnegative_projector)

opt_state = init(x_init)
for i in range(1000):
  x = get_params(opt_state)
  g = grad(loss, 1)(A, x, y)  # gradient w.r.t. x (argument 1)
  opt_state = update(i, g, opt_state)
x = get_params(opt_state)

# check that the gradient is componentwise nonnegative at the solution
optimality = nonnegative_projector(-grad(loss, 1)(A, x, y))
print(np.allclose(optimality, 0., atol=1e-4, rtol=1e-4))  # True
print(x)  # [0.        0.6929333]
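One caveat on the `optimality` check: it only tests that the gradient is componentwise nonnegative, and that already holds at the starting point `np.ones(2)` for this problem. A stricter test, standard for projected gradient methods, is the fixed-point residual $\lVert x - P_C(x - t\nabla f(x)) \rVert$ for some $t > 0$, which for a convex loss over a convex set is zero exactly at a minimizer. A self-contained sketch; `fixed_point_residual` and `pgd_step` are hypothetical helper names, and the loop is the plain projected gradient iteration from this thread with the step wrapped in `jax.jit`:

```python
import jax
import jax.numpy as jnp
from jax import grad

# Problem data from the thread (Matlab lsqnonneg example).
A = jnp.array([[0.0372, 0.2869],
               [0.6861, 0.7071],
               [0.6233, 0.6245],
               [0.6344, 0.6170]])
y = jnp.array([0.8587, 0.1781, 0.0747, 0.8405])

def loss(A, x, y):
  return jnp.sum((jnp.dot(A, x) - y) ** 2)

def nonnegative_projector(x):
  return jnp.maximum(x, 0)

def fixed_point_residual(x, step_size=1.0):
  # Hypothetical helper: ||x - P(x - t * grad f(x))|| for a fixed t > 0.
  # For a convex loss over a convex set it vanishes exactly at a minimizer.
  g = grad(loss, 1)(A, x, y)
  return jnp.linalg.norm(x - nonnegative_projector(x - step_size * g))

@jax.jit
def pgd_step(x):
  # Same projected gradient step as in the thread, compiled once.
  return nonnegative_projector(x - 1e-2 * grad(loss, 1)(A, x, y))

x = jnp.ones(2)
for i in range(1000):
  x = pgd_step(x)

# The gradient-sign check already passes at the (non-optimal) starting point...
g0 = grad(loss, 1)(A, jnp.ones(2), y)
print(jnp.allclose(nonnegative_projector(-g0), 0.))  # True
# ...while the fixed-point residual separates the two points.
print(fixed_point_residual(jnp.ones(2)))  # far from zero
print(fixed_point_residual(x))            # close to zero
```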

WDYT? Is this what you have in mind?

By the way, jax.numpy.linalg.lstsq is probably coming in #2744.
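In JAX versions that ship `jax.numpy.linalg.lstsq` (current releases do), the NumPy detour for the unconstrained comparison is no longer needed. A minimal sketch; like NumPy's version, it returns the solution, residuals, rank, and singular values:

```python
import jax.numpy as jnp

# Problem data from the thread (Matlab lsqnonneg example).
A = jnp.array([[0.0372, 0.2869],
               [0.6861, 0.7071],
               [0.6233, 0.6245],
               [0.6344, 0.6170]])
y = jnp.array([0.8587, 0.1781, 0.0747, 0.8405])

# Unconstrained least squares directly in jax.numpy.linalg.
unc_x, resid, rank, sv = jnp.linalg.lstsq(A, y)
print(unc_x)  # should roughly match the NumPy result printed earlier
```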

If you're not going to use momentum and stuff, then having init and get_params is a bit heavy-handed and you could get rid of them. Also, if you're not going to use tuples/lists/dicts for decision variables then you don't need @optimizer either. At that point, you might as well ignore optimizers.py entirely and just write things directly with grad:

x = x_init
for i in range(1000):
  g = grad(loss, 1)(A, x, y)
  x = nonnegative_projector(x - 1e-2 * g)

optimality = nonnegative_projector(-grad(loss, 1)(A, x, y))
print(np.allclose(optimality, 0., atol=1e-4, rtol=1e-4))  # True
print(x)  # [0.        0.6929333]
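If the plain Python loop gets slow, one option (our suggestion, not something proposed in the thread) is to compile the whole iteration with `jax.jit` and `lax.fori_loop`. A sketch; `projected_gradient_descent` is a hypothetical helper, and `loss` here takes only `x` and closes over `A` and `y`:

```python
import jax
import jax.numpy as jnp
from jax import grad, lax

# Problem data from the thread (Matlab lsqnonneg example).
A = jnp.array([[0.0372, 0.2869],
               [0.6861, 0.7071],
               [0.6233, 0.6245],
               [0.6344, 0.6170]])
y = jnp.array([0.8587, 0.1781, 0.0747, 0.8405])

def loss(x):
  return jnp.sum((A @ x - y) ** 2)

def nonnegative_projector(x):
  return jnp.maximum(x, 0.0)

def projected_gradient_descent(loss_fn, projector, x0, step_size, num_steps):
  # Hypothetical helper: runs num_steps iterations of
  # x <- projector(x - step_size * grad(loss_fn)(x)) as one compiled loop.
  grad_fn = grad(loss_fn)

  def body(i, x):
    return projector(x - step_size * grad_fn(x))

  return lax.fori_loop(0, num_steps, body, x0)

# loss_fn and projector are Python callables, so they are marked static for jit.
solve = jax.jit(projected_gradient_descent, static_argnums=(0, 1))
x = solve(loss, nonnegative_projector, jnp.ones(2), 1e-2, 1000)
print(x)  # should agree with the plain loop
```

Keeping the loss and projector as arguments preserves the idea of treating the projector as problem data, while the loop itself runs inside a single compiled computation.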

Much cleaner!

I just want to make sure you are getting something out of using any of optimizers.py. If you don't need it, don't use it!

I think we covered this issue, so I'm going to close it. Please open more as questions arise!
