Ray: [rllib] Is it possible to add extra inputs to the model?

Created on 17 Sep 2019  ·  4 Comments  ·  Source: ray-project/ray

The current implementation uses the observation as the only input to the model of the RL agent's policy, and this seems hard to modify.

I am trying to add an extra input to the model (for example, some extra information as a second input stream), but the code is so complex that I don't know how to do it.

It's easy to build a model with multiple input placeholders. The question is how to feed such data through policy.compute_actions or agent.compute_action.

Thanks a lot for your reply!

question

All 4 comments

At least partially related: it is not clear where to retrieve the required (recurrent) state from when using policy.compute_actions or policy.compute_single_action with recurrent networks.

What are you trying to do? Usually, you can put this information in the observation. Observations can be arbitrarily complex with Dict spaces. For states, the new states are returned by the compute action call.
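One way to read the suggestion above is to pack the extra input into a dict-valued observation. The following is only a minimal sketch under stated assumptions: `ExtraInputEnv`, `extra_fn`, and the keys `"obs"` and `"extra"` are hypothetical names, not part of RLlib. To keep the sketch dependency-free it omits the matching `gym.spaces.Dict` observation space that a real env would also have to declare.

```python
class ExtraInputEnv:
    """Hypothetical wrapper: every observation also carries an extra input stream."""

    def __init__(self, env, extra_fn):
        self.env = env
        # Hypothetical callable mapping the raw observation to the extra input.
        self.extra_fn = extra_fn

    def _pack(self, obs):
        # The keys are illustrative; they must match the Dict observation
        # space declared for the env (declaration omitted in this sketch).
        return {"obs": obs, "extra": self.extra_fn(obs)}

    def reset(self):
        return self._pack(self.env.reset())

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        return self._pack(obs), reward, done, info
```

A custom model can then read both entries from the observation instead of needing a second input argument on the compute-action call.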

Same as @FedericoFontana: I have a trained agent that uses an LSTM, and I want to load one of its checkpoints and run it against a test env. When I call the model.compute_action method, I get this error: ValueError: Must pass in RNN state batches for placeholders [<tf.Tensor 'default_policy/Placeholder:0' shape=(?, 256) dtype=float32>, <tf.Tensor 'default_policy/Placeholder_1:0' shape=(?, 256) dtype=float32>], got []

Here's my code:

    model.restore(
        '/home/ray_results/PPO/PPO_CustomEnv_ab41bcf8_0_2020-02-11_14-09-00h8cckrwu/checkpoint_50/checkpoint-50')

    env = CustomEnv(
        data=df_test, flatten_obs=True, lag_window=1)
    obs = env.reset()
    for _ in range(1000):
        action = model.compute_action(obs, prev_action=0, prev_reward=0) # This line throws the error
        obs, reward, done, info = env.step(action)
        env.render()

Which hidden state am I supposed to pass here, since it's the first run?

I'm using TF 1.15.

You need to pass the state argument to the compute-action call. The policy has a method that returns the initial state; see the rollout.py script.
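As a rough sketch of that advice, the loop from the question could thread the recurrent state through each call. It assumes the interface of the Ray releases from around the time of this thread: `agent.get_policy().get_initial_state()` returns the initial RNN state, and `agent.compute_action(...)` returns an `(action, state_out, extra_info)` tuple once `state` is passed. The helper name `run_episode` is hypothetical, and newer Ray versions may expose a different API.

```python
def run_episode(agent, env, max_steps=1000):
    """Roll out one episode, feeding the RNN state back in on every step."""
    # Assumed API: the policy supplies the initial recurrent state
    # (for an LSTM, typically a list holding the h and c vectors).
    state = agent.get_policy().get_initial_state()
    obs = env.reset()
    # Initial previous action and reward, as in the question's code.
    prev_action, prev_reward = 0, 0.0
    total_reward = 0.0
    for _ in range(max_steps):
        # Assumed return shape when `state` is given:
        # (action, new_state, extra_info).
        action, state, _ = agent.compute_action(
            obs, state=state, prev_action=prev_action, prev_reward=prev_reward)
        obs, reward, done, info = env.step(action)
        prev_action, prev_reward = action, reward
        total_reward += reward
        if done:
            break
    return total_reward
```

With the objects from the question, this would be called as `run_episode(model, env)` after `model.restore(...)`.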
