Ray: [rllib] Allow for JAX framework

Created on 2 Jun 2020 · 8 Comments · Source: ray-project/ray

What is the problem?

I've been using JAX as my framework for a little while now. I just upgraded to the nightly build (due to some unrelated issues) and now RLlib is telling me I need to install TensorFlow or Torch.

I tried setting {'framework': 'jax', ...} in my trainer config, but this results in another error: RLlib doesn't recognize any framework other than one of [tf|tfe|torch|auto].

Ray version: ray-0.9.0.dev0, Python 3.8, Ubuntu Linux 20.04 LTS

Script to reproduce:


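import ray
from ray import tune
from ray.rllib.policy.policy import Policy as BasePolicy
from ray.rllib.agents.trainer_template import build_trainer


class Policy(BasePolicy):
    """Framework-agnostic policy that samples random actions."""

    def compute_actions(self, obs_batch, state_batches=None, **kwargs):
        # One random action per observation; no RNN state, no extra info.
        actions = [self.action_space.sample() for _ in obs_batch]
        return actions, [], {}

    def get_weights(self):
        # Nothing is learned, so there are no weights to export.
        return None

    def set_weights(self, weights):
        pass

    def learn_on_batch(self, sample_batch):
        # No training; return empty learner stats.
        return {}


trainer = build_trainer(
    name='foo',
    default_policy=Policy,
)

ray.init()
tune.run(
    trainer,
    config={
        'framework': 'jax',  # not accepted by this nightly build
        'env': 'FrozenLake-v0',
    },
    stop={'training_iteration': 1},
)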

If we cannot run your script, we cannot fix your issue.

  • [x] I have verified my script runs in a clean environment and reproduces the issue.
  • [x] I have verified the issue also occurs with the latest wheels.
Labels: P1, bug, rllib

All 8 comments

@sven1977 should we support framework=None or something like this?

Yes, that would probably be more sustainable.

Am I right in thinking that as long as my Policy-derived class implements all required methods it doesn't depend on any specific framework?

And if that's the case, why pick a default framework at all?

Would RLlib allow for mixing of frameworks?

What would go wrong if I used TensorFlow to collect experience, PyTorch to learn, and JAX/NumPy to run trials?

Of course, we can think of many reasons not to use such a setup, but to what extent does RLlib's functionality depend on a choice of framework?

Yeah, I think the framework parameter might have been a bit heavy-handed. We do allow mixing of policies with different frameworks.

I made a PR to remove the framework import checking, which should allow any kind of policy to be used no matter what the setting is.

Yeah, let's allow framework=None as well, in which case, RLlib shouldn't check anything.
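A minimal sketch of what "skip all checks when framework is None" could look like. `check_framework` and `_REQUIRED_MODULE` are hypothetical names and this is not RLlib's actual validation code; it only assumes the tf|tfe|torch|None set the maintainers mention in this thread, and uses the standard-library `importlib.util.find_spec` to test whether a package is installed without importing it.

```python
import importlib.util
from typing import Optional

# Hypothetical mapping from a framework setting to the package it needs.
# Illustration only; not RLlib's real check.
_REQUIRED_MODULE = {"tf": "tensorflow", "tfe": "tensorflow", "torch": "torch"}


def check_framework(framework: Optional[str]) -> None:
    """Validate `framework`, skipping every check when it is None."""
    if framework is None:
        # No framework requested: trust the custom policy to bring its own.
        return
    if framework not in _REQUIRED_MODULE:
        raise ValueError(
            f"framework must be one of {sorted(_REQUIRED_MODULE)} or None, "
            f"got {framework!r}")
    module = _REQUIRED_MODULE[framework]
    # find_spec returns None when the package is not installed.
    if importlib.util.find_spec(module) is None:
        raise ImportError(f"framework={framework!r} requires {module!r}")


check_framework(None)  # passes: nothing is checked
```

With this shape, `framework=None` never touches TensorFlow or PyTorch, while an unknown value such as 'jax' still fails fast with a `ValueError`.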

@sven1977 This might be a silly question, but why would we want to check the framework at all?

I feel that the latest changes @ericl made in #8748 are a better setup, i.e. dropping the framework checks altogether. This means that the framework config setting is just a hint that allows you to write some conditional logic if a specific value is set.

For instance, I might want to implement some logic if config['framework'] == 'jax', which shouldn't cause any exceptions elsewhere in the codebase.

Ok, this works now. Just explicitly use None as your framework in the config.
@KristianHolsheimer your point is valid, but we do want to apply type checking here, as internally we really only support tf|torch|tfe|None. If you want, just create a new key in your Trainer's config for now, e.g. jax=True, and check its value. We will look into adding JAX very soon and will probably have some rudimentary support for it in the near future (generic default Policy/Models).

config={
    'framework': None,
    'env': 'FrozenLake-v0',
},
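The suggested workaround can be sketched in plain Python. `backend_name` is a hypothetical helper that a custom policy might call with its config, and `jax` is a user-defined key, not an RLlib option. Depending on the RLlib version, undeclared top-level config keys may be rejected, so this assumes the key is also declared in the trainer's default config (for instance through `build_trainer`'s `default_config` argument).

```python
from typing import Any, Dict, Optional

# framework=None: RLlib skips its framework checks (per the maintainers).
# "jax" is a hypothetical, user-defined flag that RLlib does not interpret.
config: Dict[str, Any] = {
    "framework": None,
    "env": "FrozenLake-v0",
    "jax": True,
}


def backend_name(config: Dict[str, Any]) -> str:
    """Choose the code path a custom policy should take."""
    if config.get("jax", False):
        return "jax"
    framework: Optional[str] = config.get("framework")
    # No framework set and no custom flag: use a framework-free path.
    return framework if framework is not None else "generic"


print(backend_name(config))  # jax
```

Keeping the flag separate from `framework` leaves RLlib's type check on `framework` intact, which is the trade-off the maintainers preferred over treating `framework` as a free-form hint.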

Closing this issue.

Good to hear that JAX is on your road map. Let me know if I can help.

I shared a couple testing scripts in #8776
