Jax: Optimizing multiple parameters via jax optimizer

Created on 20 Feb 2019  ·  2 comments  ·  Source: google/jax

I have a loss function that takes multiple parameters, as follows, and I would like to optimize the loss w.r.t. each of p1, p2, …, pn.

def loss(p1, p2, ..pn, **kwargs):
    ...

Currently I am initializing an optimizer for each of these parameters, and updating them in a loop as follows:

    # len(opt_updates) == len(opt_states) == number of params
    params = [optimizers.get_params(s) for s in opt_states]
    updates = []
    for i, opt_state in enumerate(opt_states):
        g = grad(loss, i)(*params, **kwargs)  # gradient w.r.t. the i-th parameter
        updates.append(opt_updates[i](i, g, opt_state))

I am new to JAX, so I am wondering whether this is the right way to use the optimizer when optimizing an objective with multiple parameters. In particular, I was confused by this statement in the README:

The parameters being optimized can be ndarrays or arbitrarily-nested list/tuple/dict structures, so you can store your parameters however you'd like.

While I can pass in a list of parameters to the init and update functions, it doesn't do the right thing and throws an exception:

>   update = lambda g, state: pack(update_fun(i, g, *state))
E   TypeError: update_fun() missing 1 required positional argument: 'v'

All 2 comments

Thanks for the issue report!

You should be able to optimize over nested list/tuple/dict structures; you shouldn't need one optimizer per parameter.

Could you please give a complete and self-contained example that produces the error you gave? It's hard to debug what's going wrong from the error alone.

As an example of using multiple optimizer states, you could take a look at this example, which uses a tuple of two values as the optimizer state:
https://github.com/google/jax/blob/master/examples/advi.py#L120
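In other words, a single optimizer can be applied once to a whole pytree of parameters. As a minimal sketch of that idea using plain gradient descent (no optimizers module; the parameter names and learning rate here are illustrative, not from the thread):

```python
import jax
import jax.numpy as jnp

def loss(params):
    # params is a dict pytree; jax.grad returns a dict with the same structure
    return jnp.sum(params["x"] ** 2) + jnp.sum(params["y"] ** 2)

params = {"x": jnp.ones(3), "y": -jnp.ones(3)}
for _ in range(500):
    grads = jax.grad(loss)(params)
    # one tree_map updates every leaf; no per-parameter optimizer needed
    params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
```

After the loop, every leaf of `params` has been driven toward zero by the same update rule, which is exactly the behavior the README statement describes.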

Thanks for sharing this example; it helped me fix the issue and clean up my code. I'll note the resolution here in case it is helpful to other users. The following simple example reproduces the problem:

def loss(x, y):
    return np.sum(x ** 2 + y ** 2)


def step(i, opt_state, opt_update):
    params = optimizers.get_params(opt_state)
    g = grad(loss)(*params)  # bug: grad differentiates w.r.t. the first argument only
    return opt_update(i, g, opt_state)


init_params = (np.array([1., 1., 1.]), np.array([-1., -1., -1.]))
opt_init, opt_update = optimizers.momentum(step_size=1e-2, mass=0.9)
opt_state = opt_init(init_params)
for i in range(1000):
    opt_state = step(i, opt_state, opt_update)

assert np.allclose(optimizers.get_params(opt_state), np.zeros(3))

The issue is with the call grad(loss)(*params): by default, grad differentiates only with respect to the first positional argument, so the gradient doesn't match the structure of the parameters and the argument list to update_fun ends up wrong. If we instead write the loss to take the whole parameter collection as its first argument (rather than a variable number of positional arguments), it works fine.

def loss(params):
    x, y = params[0], params[1]
    return np.sum(x ** 2 + y ** 2)


def step(i, opt_state, opt_update):
    params = optimizers.get_params(opt_state)
    g = grad(loss)(params)
    return opt_update(i, g, opt_state)
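For completeness, here is a self-contained version of the fixed example. Note that the optimizers module has since moved (in current JAX it lives at jax.example_libraries.optimizers, and the factory returns an (init, update, get_params) triple rather than exposing a module-level get_params); treat the exact import path as an assumption about your JAX version:

```python
import jax.numpy as jnp
from jax import grad
# in 2019-era JAX this was jax.experimental.optimizers
from jax.example_libraries import optimizers

def loss(params):
    x, y = params
    return jnp.sum(x ** 2 + y ** 2)

# newer API: the factory returns an (init, update, get_params) triple
opt_init, opt_update, get_params = optimizers.momentum(step_size=1e-2, mass=0.9)

def step(i, opt_state):
    g = grad(loss)(get_params(opt_state))  # gradient pytree matches the param pytree
    return opt_update(i, g, opt_state)

opt_state = opt_init((jnp.array([1., 1., 1.]), jnp.array([-1., -1., -1.])))
for i in range(1000):
    opt_state = step(i, opt_state)
```

Because the loss takes the parameter tuple as a single argument, grad returns a tuple with the same structure, and one optimizer handles both arrays.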