Ray: RLLib: Pytorch + PPO + RNN: KeyError: 'seq_lens' in the batch dictionary

Created on 18 Feb 2020 · 16 Comments · Source: ray-project/ray

What is the problem?

Installed ray with the nightly wheel.
I wrote a custom env, model, and action distribution.
I attempt to train it with PPO, but there is a KeyError on one of the internal objects used by RLlib (the train batch dict is missing "seq_lens", which is used to mask padded timesteps of recurrent models when backpropagating):
2020-02-18 16:11:55,261 ERROR trial_runner.py:513 -- Trial PPO_test_49f2c33a: Error processing event.
Traceback (most recent call last):
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/tune/trial_runner.py", line 459, in _process_trial
    result = self.trial_executor.fetch_result(trial)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/tune/ray_trial_executor.py", line 377, in fetch_result
    result = ray.get(trial_future[0], DEFAULT_GET_TIMEOUT)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/worker.py", line 1522, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(KeyError): ray::PPO.train() (pid=13094, ip=10.0.2.217)
  File "python/ray/_raylet.pyx", line 447, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 425, in ray._raylet.execute_task.function_executor
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/agents/trainer.py", line 477, in train
    raise e
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/agents/trainer.py", line 463, in train
    result = Trainable.train(self)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/tune/trainable.py", line 254, in train
    result = self._train()
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/agents/trainer_template.py", line 122, in _train
    fetches = self.optimizer.step()
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/optimizers/sync_samples_optimizer.py", line 71, in step
    self.standardize_fields)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/utils/sgd.py", line 111, in do_minibatch_sgd
    }, minibatch.count)))[policy_id]
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/evaluation/rollout_worker.py", line 619, in learn_on_batch
    info_out[pid] = policy.learn_on_batch(batch)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/policy/torch_policy.py", line 100, in learn_on_batch
    loss_out = self._loss(self, self.model, self.dist_class, train_batch)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/agents/ppo/ppo_torch_policy.py", line 112, in ppo_surrogate_loss
    print(train_batch["seq_lens"])
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/utils/tracking_dict.py", line 22, in __getitem__
    value = dict.__getitem__(self, key)
KeyError: 'seq_lens'
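For context on why the loss needs that key: RLlib pads recurrent rollouts to a common max sequence length and then uses "seq_lens" to mask out the padded timesteps when averaging the loss. Below is a minimal, illustrative sketch of that masking idea (not RLlib's actual code; masked_mean and its arguments are hypothetical names):

import torch

def masked_mean(loss_per_step, seq_lens, max_seq_len):
    """Average per-timestep losses, ignoring padded steps.

    loss_per_step: (B * T,) flattened per-timestep losses (T == max_seq_len).
    seq_lens: (B,) true length of each sequence in the padded batch.
    """
    # mask[b, t] is 1.0 for real steps, 0.0 for padding.
    t = torch.arange(max_seq_len).unsqueeze(0)               # (1, T)
    mask = (t < seq_lens.unsqueeze(1)).float().reshape(-1)   # (B * T,)
    return torch.sum(loss_per_step * mask) / torch.sum(mask)

Without "seq_lens" in the train batch, ppo_surrogate_loss cannot build such a mask, which is exactly where the KeyError above is raised.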

Reproduction (REQUIRED)

import ray
from ray import tune
from ray.rllib.agents import ppo
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.annotations import override
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.policy.policy import TupleActions
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.tune.registry import register_env
import gym
from gym.spaces import Discrete, Box, Dict, MultiDiscrete
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.nn import Parameter
from torch import Tensor


def _make_f32_array(number):
    return np.array(number, dtype="float32")


class TorchMultiCategorical(ActionDistribution):
    """MultiCategorical distribution for MultiDiscrete action spaces."""

    @override(ActionDistribution)
    def __init__(self, inputs, model):
        input_lens = model.dist_input_lens
        inputs_splitted = inputs.split(input_lens, dim=1)
        self.cats = [
            torch.distributions.categorical.Categorical(logits=input_)
            for input_ in inputs_splitted
        ]

    @override(ActionDistribution)
    def sample(self):
        arr = [cat.sample() for cat in self.cats]
        ret = torch.stack(arr, dim=1)
        return ret

    @override(ActionDistribution)
    def logp(self, actions):
        # If a tensor is provided, unbind it into a list of per-sub-action tensors
        if isinstance(actions, torch.Tensor):
            actions = torch.unbind(actions, dim=1)
        logps = torch.stack([cat.log_prob(act) for cat, act in zip(self.cats, actions)])
        return torch.sum(logps, dim=0)

    @override(ActionDistribution)
    def multi_entropy(self):
        return torch.stack([cat.entropy() for cat in self.cats], dim=1)

    @override(ActionDistribution)
    def entropy(self):
        return torch.sum(self.multi_entropy(), dim=1)

    @override(ActionDistribution)
    def multi_kl(self, other):
        return torch.stack(
            [
                torch.distributions.kl.kl_divergence(cat, oth_cat)
                for cat, oth_cat in zip(self.cats, other.cats)
            ],
            dim=1,
        )

    @override(ActionDistribution)
    def kl(self, other):
        return torch.sum(self.multi_kl(other), dim=1)

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(action_space, model_config):
        return np.sum(action_space.nvec)


class ReproEnv(gym.Env):
    def __init__(self, config):
        self.cur_pos = 0
        self.window_size = config["window_size"]
        self.need_reset = False
        self.action_space = MultiDiscrete([3, 2, 51, 10, 2])
        self.observation_space = Dict(
            {
                "lob": Box(low=-np.inf, high=np.inf, shape=(self.window_size, 40)),
                "unallocated_wealth": Box(low=0, high=1, shape=()),
                "taker_fees": Box(low=-1, high=1, shape=()),
                "maker_fees": Box(low=-1, high=1, shape=()),
                "order": Dict(
                    {
                        "side": Discrete(3),
                        "type": Discrete(2),
                        "size": Box(low=0, high=1, shape=()),
                        "price": Box(low=-np.inf, high=np.inf, shape=()),
                        "filled": Box(low=0, high=1, shape=()),
                    }
                ),
                "position": Dict(
                    {
                        "side": Discrete(3),
                        "size": Box(low=0, high=1, shape=()),
                        "entry_price": Box(low=-np.inf, high=np.inf, shape=()),
                        "unrealized_pnl": Box(low=-100, high=np.inf, shape=()),
                    }
                ),
            }
        )

    def reset(self):
        self.cur_pos = 0
        self.need_reset = False
        return self.step([0, 0, 0, 0, 0])[0]  # Noop

    def step(self, action):
        if self.need_reset:
            raise Exception("You need to reset this environment!")
        self.cur_pos += 1
        assert action in self.action_space, action
        if self.cur_pos >= 1000:
            done = True
        else:
            done = False
        info = {}
        if done:
            self.need_reset = True
        observation = {
            "lob": np.zeros((self.window_size, 40)),
            "taker_fees": _make_f32_array(0),
            "maker_fees": _make_f32_array(0),
            "unallocated_wealth": _make_f32_array(0),
            "order": {
                "side": 0,
                "type": 0,
                "size": _make_f32_array(0),
                "price": _make_f32_array(0),
                "filled": _make_f32_array(0),
            },
            "position": {
                "side": 0,
                "size": _make_f32_array(0),
                "entry_price": _make_f32_array(0),
                "unrealized_pnl": _make_f32_array(0),
            },
        }
        assert observation in self.observation_space, observation
        return observation, 0, done, info


class CNN(nn.Module):
    def __init__(self, dropout=0.2):
        super(CNN, self).__init__()

        self.conv1 = nn.Conv2d(1, 16, kernel_size=(1, 2), stride=(1, 2))
        self.conv2 = nn.Conv2d(16, 16, kernel_size=(4, 1))
        self.conv3 = nn.Conv2d(16, 16, kernel_size=(4, 1))

        self.conv4 = nn.Conv2d(16, 32, kernel_size=(1, 2), stride=(1, 2))
        self.conv5 = nn.Conv2d(32, 32, kernel_size=(4, 1))
        self.conv6 = nn.Conv2d(32, 32, kernel_size=(4, 1))

        self.conv7 = nn.Conv2d(32, 64, kernel_size=(1, 10))
        self.conv8 = nn.Conv2d(64, 64, kernel_size=(4, 1))
        self.conv9 = nn.Conv2d(64, 64, kernel_size=(4, 1))
        # Pad to preserve the length in the time domain
        self.pad = nn.ZeroPad2d((0, 0, 0, 3))
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        x = F.leaky_relu(self.conv1(x))
        x = F.leaky_relu(self.conv2(self.pad(x)))
        x = F.leaky_relu(self.conv3(self.pad(x)))
        x = F.leaky_relu(self.conv4(x))
        x = F.leaky_relu(self.conv5(self.pad(x)))
        x = F.leaky_relu(self.conv6(self.pad(x)))
        x = F.leaky_relu(self.conv7(x))
        x = F.leaky_relu(self.conv8(self.pad(x)))
        x = F.leaky_relu(self.conv9(self.pad(x)))
        x = self.dropout(x)
        return x


class TestNet(TorchModelV2, nn.Module):
    def init_hidden(self, hidden_size):
        h0 = self._value_branch[0].weight.new(1, hidden_size).zero_()
        c0 = self._value_branch[0].weight.new(1, hidden_size).zero_()
        return (h0, c0)

    def __init__(self, obs_space, action_space, num_outputs, config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, config, name)
        nn.Module.__init__(self)
        model_config = config["custom_options"]
        print("Model config:")
        print(model_config)
        dropout = model_config["dropout"]
        window_size = model_config["window_size"]
        self.cnn = CNN(dropout=dropout)
        print(f"Dropout: {dropout}")
        print(f"Window size: {window_size}")

        # Value function
        self._value_branch = nn.Sequential(
            nn.Linear(6, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 1),
        )
        # Policy: Signal
        self.long_lstm = nn.LSTM(64 * window_size, 64, batch_first=True)
        self.short_lstm = nn.LSTM(64 * window_size, 64, batch_first=True)

        self.long_025_lstm = nn.LSTM(64, 32, batch_first=True)
        self.long_025_fc = nn.Linear(32, 1)

        self.long_050_lstm = nn.LSTM(64, 32, batch_first=True)
        self.long_050_fc = nn.Linear(32, 1)

        self.long_075_lstm = nn.LSTM(64, 32, batch_first=True)
        self.long_075_fc = nn.Linear(32, 1)

        self.short_025_lstm = nn.LSTM(64, 32, batch_first=True)
        self.short_025_fc = nn.Linear(32, 1)

        self.short_050_lstm = nn.LSTM(64, 32, batch_first=True)
        self.short_050_fc = nn.Linear(32, 1)

        self.short_075_lstm = nn.LSTM(64, 32, batch_first=True)
        self.short_075_fc = nn.Linear(32, 1)
        # Policy: Brain
        self.dumb_fc = nn.Linear(6, 68)

        self.dist_input_lens = [3, 2, 51, 10, 2]
        self._cur_value = None

    @override(TorchModelV2)
    def get_initial_state(self):
        # make hidden states on same device as model
        long_lstm_h, long_lstm_c = self.init_hidden(64)
        short_lstm_h, short_lstm_c = self.init_hidden(64)

        long_025_lstm_h, long_025_lstm_c = self.init_hidden(32)
        long_050_lstm_h, long_050_lstm_c = self.init_hidden(32)
        long_075_lstm_h, long_075_lstm_c = self.init_hidden(32)

        short_025_lstm_h, short_025_lstm_c = self.init_hidden(32)
        short_050_lstm_h, short_050_lstm_c = self.init_hidden(32)
        short_075_lstm_h, short_075_lstm_c = self.init_hidden(32)

        initial_state = [
            long_lstm_h,
            long_lstm_c,
            short_lstm_h,
            short_lstm_c,
            long_025_lstm_h,
            long_025_lstm_c,
            long_050_lstm_h,
            long_050_lstm_c,
            long_075_lstm_h,
            long_075_lstm_c,
            short_025_lstm_h,
            short_025_lstm_c,
            short_050_lstm_h,
            short_050_lstm_c,
            short_075_lstm_h,
            short_075_lstm_c,
        ]
        return initial_state

    @override(TorchModelV2)
    def value_function(self):
        assert self._cur_value is not None, "must call forward() first"
        return self._cur_value

    @override(TorchModelV2)
    def forward(self, input_dict, hidden_state, seq_lens):
        # if seq_lens is None:
        #     raise Exception("seq_lens is None")
        lob = input_dict["obs"]["lob"]
        batch_size, window_length, features = lob.size()
        # assert list(hidden_state[0].size()) == [1, 1, 64]
        # Unpack the hidden_state
        long_lstm_h = hidden_state[0]
        long_lstm_c = hidden_state[1]
        short_lstm_h = hidden_state[2]
        short_lstm_c = hidden_state[3]

        long_025_lstm_h = hidden_state[4]
        long_025_lstm_c = hidden_state[5]
        long_050_lstm_h = hidden_state[6]
        long_050_lstm_c = hidden_state[7]
        long_075_lstm_h = hidden_state[8]
        long_075_lstm_c = hidden_state[9]

        short_025_lstm_h = hidden_state[10]
        short_025_lstm_c = hidden_state[11]
        short_050_lstm_h = hidden_state[12]
        short_050_lstm_c = hidden_state[13]
        short_075_lstm_h = hidden_state[14]
        short_075_lstm_c = hidden_state[15]
        # Build the tuples
        long_lstm_hidden = (long_lstm_h.view(1, -1, 64), long_lstm_c.view(1, -1, 64))
        short_lstm_hidden = (short_lstm_h.view(1, -1, 64), short_lstm_c.view(1, -1, 64))
        long_025_lstm_hidden = (
            long_025_lstm_h.view(1, -1, 32),
            long_025_lstm_c.view(1, -1, 32),
        )
        long_050_lstm_hidden = (
            long_050_lstm_h.view(1, -1, 32),
            long_050_lstm_c.view(1, -1, 32),
        )
        long_075_lstm_hidden = (
            long_075_lstm_h.view(1, -1, 32),
            long_075_lstm_c.view(1, -1, 32),
        )
        short_025_lstm_hidden = (
            short_025_lstm_h.view(1, -1, 32),
            short_025_lstm_c.view(1, -1, 32),
        )
        short_050_lstm_hidden = (
            short_050_lstm_h.view(1, -1, 32),
            short_050_lstm_c.view(1, -1, 32),
        )
        short_075_lstm_hidden = (
            short_075_lstm_h.view(1, -1, 32),
            short_075_lstm_c.view(1, -1, 32),
        )

        c_in = lob.view(batch_size, 1, window_length, features)
        c_out = self.cnn(c_in)
        # Embeddings from the CNN, reshaped to be consumed by the LSTMs
        embeddings = c_out.view(batch_size, 1, -1)

        long_out, long_lstm_hidden = self.long_lstm(embeddings, long_lstm_hidden)
        short_out, short_lstm_hidden = self.short_lstm(embeddings, short_lstm_hidden)
        # Now on to the tail LSTMs
        long_025_out, long_025_lstm_hidden = self.long_025_lstm(
            F.leaky_relu(long_out), long_025_lstm_hidden
        )
        long_050_out, long_050_lstm_hidden = self.long_050_lstm(
            F.leaky_relu(long_out), long_050_lstm_hidden
        )
        long_075_out, long_075_lstm_hidden = self.long_075_lstm(
            F.leaky_relu(long_out), long_075_lstm_hidden
        )

        short_025_out, short_025_lstm_hidden = self.short_025_lstm(
            F.leaky_relu(short_out), short_025_lstm_hidden
        )
        short_050_out, short_050_lstm_hidden = self.short_050_lstm(
            F.leaky_relu(short_out), short_050_lstm_hidden
        )
        short_075_out, short_075_lstm_hidden = self.short_075_lstm(
            F.leaky_relu(short_out), short_075_lstm_hidden
        )
        # Reshape the outputs of the tail LSTMs into (batch, hidden_size)
        long_025_out = long_025_out.view(batch_size, -1)
        long_050_out = long_050_out.view(batch_size, -1)
        long_075_out = long_075_out.view(batch_size, -1)

        short_025_out = short_025_out.view(batch_size, -1)
        short_050_out = short_050_out.view(batch_size, -1)
        short_075_out = short_075_out.view(batch_size, -1)
        # Fully connected at the end
        long_025_q = self.long_025_fc(F.leaky_relu(long_025_out))
        long_050_q = self.long_050_fc(F.leaky_relu(long_050_out))
        long_075_q = self.long_075_fc(F.leaky_relu(long_075_out))

        short_025_q = self.short_025_fc(F.leaky_relu(short_025_out))
        short_050_q = self.short_050_fc(F.leaky_relu(short_050_out))
        short_075_q = self.short_075_fc(F.leaky_relu(short_075_out))
        quantiles = [
            long_025_q,
            long_050_q,
            long_075_q,
            short_025_q,
            short_050_q,
            short_075_q,
        ]
        quantiles = torch.cat(quantiles, dim=1).view(batch_size, 6)
        new_hidden_state = [
            long_lstm_hidden[0],
            long_lstm_hidden[1],
            short_lstm_hidden[0],
            short_lstm_hidden[1],
            long_025_lstm_hidden[0],
            long_025_lstm_hidden[1],
            long_050_lstm_hidden[0],
            long_050_lstm_hidden[1],
            long_075_lstm_hidden[0],
            long_075_lstm_hidden[1],
            short_025_lstm_hidden[0],
            short_025_lstm_hidden[1],
            short_050_lstm_hidden[0],
            short_050_lstm_hidden[1],
            short_075_lstm_hidden[0],
            short_075_lstm_hidden[1],
        ]
        assert list(new_hidden_state[0].size()) == [
            1,
            list(new_hidden_state[0].size())[1],
            64,
        ], new_hidden_state[0].size()
        # Value function
        self._cur_value = self._value_branch(quantiles).squeeze(1)
        logits = self.dumb_fc(quantiles)
        return logits, new_hidden_state


ModelCatalog.register_custom_action_dist("torchmulticategorical", TorchMultiCategorical)
ModelCatalog.register_custom_model("test", TestNet)
register_env("test", lambda config: ReproEnv(config))

ray.init()
tune.run(
    ppo.PPOTrainer,
    config={
        "num_workers": 1,
        "env": "test",
        "log_level": "INFO",
        "use_pytorch": True,
        "num_gpus": 1,
        "vf_share_layers": True,
        "env_config": {"window_size": 100,},
        "model": {
            "custom_action_dist": "torchmulticategorical",
            "custom_model": "test",
            "custom_options": {
                "window_size": 100,
                "dropout": 0.2,
                "use_learned_hidden": True,
            },
        },
    },
)

  • [x] I have verified my script runs in a clean environment and reproduces the issue.
  • [x] I have verified the issue also occurs with the latest wheels.
Labels: P1, bug, rllib

All 16 comments

I think this is because we don't support RNNs with PyTorch yet, cc @sven1977

From the doc: "Similarly, you can create and register custom PyTorch models for use with PyTorch-based algorithms (e.g., A2C, PG, QMIX). See these examples of fully connected, convolutional, and recurrent torch models."
It looks like you do!
Here is a PyTorch RNN in the codebase: https://github.com/ray-project/ray/blob/master/rllib/agents/qmix/model.py

I'll assign myself and take a look. Thanks for filing this!

Yes, we'll have to fix LSTM support for torch.

I have the same problem right now and would be happy to look into it. What has to be done to make it work?

Thanks for offering your help, everyone!
I'm taking a look right now. ... I'll keep you posted over the next few hours/days on what I find.

Great, thanks! This is my current minimal example, in case it helps to reproduce the error (latest wheel, Python 3.6):

import argparse
import math
import gym
from gym import spaces
from gym.utils import seeding
import numpy as np

from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils import try_import_torch
from ray.rllib.models import ModelCatalog

torch, nn = try_import_torch()

class CartPoleStatelessEnv(gym.Env):
    # Omitted here for brevity; see
    # https://github.com/ray-project/ray/blob/5cebee68d681bebfd59255b811338d39e2cc2e7d/rllib/examples/cartpole_lstm.py
    pass

def _get_size(obs_space):
    return get_preprocessor(obs_space)(obs_space).size

class RNNModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        self.obs_size = _get_size(obs_space)
        self.rnn_hidden_dim = model_config["lstm_cell_size"]
        self.fc1 = nn.Linear(self.obs_size, self.rnn_hidden_dim)
        self.rnn = nn.GRUCell(self.rnn_hidden_dim, self.rnn_hidden_dim)
        self.fc2 = nn.Linear(self.rnn_hidden_dim, num_outputs)

        self.value_branch = nn.Linear(self.rnn_hidden_dim, 1)
        self._cur_value = None

    @override(TorchModelV2)
    def get_initial_state(self):
        # make hidden states on same device as model
        h = [self.fc1.weight.new(1, self.rnn_hidden_dim).zero_().squeeze(0)]
        return h

    @override(TorchModelV2)
    def value_function(self):
        assert self._cur_value is not None, "must call forward() first"
        return self._cur_value

    @override(TorchModelV2)
    def forward(self, input_dict, hidden_state, seq_lens):
        x = nn.functional.relu(self.fc1(input_dict["obs_flat"].float()))
        h_in = hidden_state[0].reshape(-1, self.rnn_hidden_dim)
        h = self.rnn(x, h_in)
        q = self.fc2(h)
        self._cur_value = self.value_branch(h).squeeze(1)
        return q, [h]


if __name__ == "__main__":
    import ray
    from ray import tune

    ModelCatalog.register_custom_model("rnnmodel", RNNModel)
    tune.register_env("cartpole_stateless", lambda _: CartPoleStatelessEnv())

    ray.init()

    tune.run(
        "PPO",
        stop={"episode_reward_mean": 200},
        config={
            "use_pytorch": True,
            "model": {
                "custom_model": "rnnmodel",
                "lstm_use_prev_action_reward": "store_true",
                "lstm_cell_size": 20,
                "custom_options": {}
            },
            "num_sgd_iter": 5,
            "vf_share_layers": True,
            "vf_loss_coeff": 0.0001,
            "env": "cartpole_stateless",
        }
    )

With the error

Traceback (most recent call last):                                                                                                                                                                                 
  File "[...]/ray/tune/trial_runner.py", line 467, in _process_trial                                                                     
    result = self.trial_executor.fetch_result(trial)                                                                                                                                                               
  File "[...]/ray/tune/ray_trial_executor.py", line 381, in fetch_result                                                                 
    result = ray.get(trial_future[0], DEFAULT_GET_TIMEOUT)                                                                                                                                                         
  File "[...]/ray/worker.py", line 1505, in get                                                                                          
    raise value.as_instanceof_cause()                                                                                                                                                                              
ray.exceptions.RayTaskError(KeyError): ray::PPO.train() (pid=21658, ip=128.232.69.20)                                                                                                                              
  File "python/ray/_raylet.pyx", line 445, in ray._raylet.execute_task                                                                                                                                             
  File "python/ray/_raylet.pyx", line 423, in ray._raylet.execute_task.function_executor                                                                                                                           
  File "[...]/ray/rllib/agents/trainer.py", line 504, in train                                                                           
    raise e                                                                                                                                                                                                        
  File "[...]/ray/rllib/agents/trainer.py", line 490, in train                                                                           
    result = Trainable.train(self)                                                                                                                                                                                 
  File "[...]/ray/tune/trainable.py", line 261, in train                                                                                 
    result = self._train()                                                                                                                                                                                         
  File "[...]/ray/rllib/agents/trainer_template.py", line 150, in _train                                                                 
    fetches = self.optimizer.step()                                                                                                                                                                                
  File "[...]/ray/rllib/optimizers/sync_samples_optimizer.py", line 71, in step                                                          
    self.standardize_fields)                                                                                                                                                                                       
  File "[...]/ray/rllib/utils/sgd.py", line 115, in do_minibatch_sgd                                                                     
    }, minibatch.count)))[policy_id]                                                                                                                                                                               
  File "[...]/ray/rllib/evaluation/rollout_worker.py", line 632, in learn_on_batch
    info_out[pid] = policy.learn_on_batch(batch)
  File "[...]/ray/rllib/policy/torch_policy.py", line 132, in learn_on_batch
    loss_out = self._loss(self, self.model, self.dist_class, train_batch)
  File "[...]/ray/rllib/agents/ppo/ppo_torch_policy.py", line 120, in ppo_surrogate_loss
    max_seq_len = torch.max(train_batch["seq_lens"])
  File "[...]/ray/rllib/utils/tracking_dict.py", line 22, in __getitem__
    value = dict.__getitem__(self, key)
KeyError: 'seq_lens'

Yeah, I think I know what it is: in rllib/policy/tf_policy.py, we call _get_loss_inputs_dict (the worker calls Policy.learn_on_batch, which then calls Policy._build_learn_on_batch in the TF case), and that calculates and adds the seq_lens. In PyTorch, we don't seem to do anything comparable. I'll fix this now.
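A hedged sketch of the missing piece described here (not the actual patch): before the loss runs, the torch policy would need to chop each rollout into bounded-length sequences and attach the resulting lengths under "seq_lens". "dones" is a standard SampleBatch field; add_seq_lens and max_seq_len below are made-up names for illustration:

import numpy as np

def add_seq_lens(train_batch, max_seq_len=20):
    # Start a new sequence at every episode boundary ("dones") or
    # whenever the current chunk reaches max_seq_len.
    seq_lens, count = [], 0
    for done in train_batch["dones"]:
        count += 1
        if done or count == max_seq_len:
            seq_lens.append(count)
            count = 0
    if count > 0:
        seq_lens.append(count)
    train_batch["seq_lens"] = np.array(seq_lens, dtype=np.int32)
    return train_batch

The real fix also has to pad each chunk up to max_seq_len and collect the initial RNN state for each sequence, which this sketch omits.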

Almost there. Just a few remaining flaws in the PPO loss concerning valid_mask.
Apologies for the docs mentioning that we generically support PyTorch + LSTMs: we don't (yet)! There will be a PR (probably tomorrow) that fixes this for at least the standard PG algos (PPO, PG, A2C/A3C), provided one uses a custom torch Model. Making the "use_lstm" auto-wrapping functionality work will be a follow-up PR.

There was quite a bit missing for this to work. I'll do a WIP PR later today and post it here. I got the example running, but CartPole doesn't seem to learn with PPO + torch + LSTM, so I'll have to take a closer look.

Here is the WIP PR that makes the above minimal example run: https://github.com/ray-project/ray/pull/7797
Will add more tests and make sure the CartPole example learns as well.

Ok, #7797 is learning the RepeatInitialEnv example using PPO + torch + a custom torch model.
Check out this example script (included in the PR; you need this PR for the code below to run and learn). Will be merged within the next few days.

import argparse

import ray
import ray.rllib.agents.ppo as ppo
from ray.rllib.examples.cartpole_lstm import CartPoleStatelessEnv
from ray.rllib.examples.custom_keras_rnn_model import RepeatInitialEnv
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.torch.recurrent_torch_model import RecurrentTorchModel
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils import try_import_torch
from ray.rllib.models import ModelCatalog
import ray.tune as tune

torch, nn = try_import_torch()

parser = argparse.ArgumentParser()
parser.add_argument("--run", type=str, default="PPO")
parser.add_argument("--env", type=str, default="repeat_initial")
parser.add_argument("--stop", type=int, default=90)
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument("--lstm-cell-size", type=int, default=32)


class RNNModel(RecurrentTorchModel):
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)

        self.obs_size = get_preprocessor(obs_space)(obs_space).size
        self.rnn_hidden_dim = model_config["lstm_cell_size"]
        self.fc1 = nn.Linear(self.obs_size, self.rnn_hidden_dim)
        self.rnn = nn.GRUCell(self.rnn_hidden_dim, self.rnn_hidden_dim)
        self.fc2 = nn.Linear(self.rnn_hidden_dim, num_outputs)

        self.value_branch = nn.Linear(self.rnn_hidden_dim, 1)
        self._cur_value = None

    @override(ModelV2)
    def get_initial_state(self):
        # make hidden states on same device as model
        h = [self.fc1.weight.new(1, self.rnn_hidden_dim).zero_().squeeze(0)]
        return h

    @override(ModelV2)
    def value_function(self):
        assert self._cur_value is not None, "must call forward() first"
        return self._cur_value

    @override(RecurrentTorchModel)
    def forward_rnn(self, inputs, state, seq_lens):
        """Feeds `inputs` (B x T x ..) through the Gru Unit.

        Returns the resulting outputs as a sequence (B x T x ...).
        Values are stored in self._cur_value in simple (B) shape (where B
        contains both the B and T dims!).

        Returns:
            NN Outputs (B x T x ...) as sequence.
            The state batch as a List of one item (only one hidden state b/c
                Gru).
        """
        x = nn.functional.relu(self.fc1(inputs))
        h = state[0]
        outs = []
        for i in range(torch.max(seq_lens)):
            h = self.rnn(x[:, i], h)
            outs.append(h)
        # Stack along the time axis so the output is (B, T, hidden),
        # matching the docstring; `h` is the final hidden state.
        outs = torch.stack(outs, 1)
        q = self.fc2(outs)
        self._cur_value = torch.reshape(self.value_branch(outs), [-1])
        return q, [h]


if __name__ == "__main__":
    args = parser.parse_args()

    ray.init(num_cpus=args.num_cpus or None)
    ModelCatalog.register_custom_model("rnn", RNNModel)
    tune.register_env("repeat_initial", lambda c: RepeatInitialEnv())
    tune.register_env("cartpole_stateless", lambda c: CartPoleStatelessEnv())

    config = {
        "num_workers": 0,
        "num_envs_per_worker": 20,
        "gamma": 0.9,
        "entropy_coeff": 0.001,
        "use_pytorch": True,
        "model": {
            "custom_model": "rnn",
            "lstm_use_prev_action_reward": "store_true",
            "lstm_cell_size": args.lstm_cell_size,
            "custom_options": {}
        },
        "lr": 0.0003,
        "num_sgd_iter": 5,
        "vf_loss_coeff": 1e-5,
        "env": args.env,
    }

    tune.run(
        args.run,
        stop={"episode_reward_mean": args.stop},
        config=config,
    )

@justinglibert (see post above).
@janblumenkamp and @justinglibert Thanks for your help, guys!

Awesome, thank you very much Sven! I will try it!

Updated the PR once more. Switched out the GRU for an LSTM, and this works much better now on our test envs.
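For reference, a minimal sketch of what that GRU-to-LSTM swap could look like in the RNNModel above (illustrative only; the merged example may differ in detail). An LSTM carries two state tensors (h and c), so get_initial_state returns two entries, and the whole padded sequence can be fed through nn.LSTM in one call instead of a per-timestep loop:

class LSTMModel(RecurrentTorchModel):
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)
        self.obs_size = get_preprocessor(obs_space)(obs_space).size
        self.cell_size = model_config["lstm_cell_size"]
        self.fc1 = nn.Linear(self.obs_size, self.cell_size)
        self.lstm = nn.LSTM(self.cell_size, self.cell_size, batch_first=True)
        self.fc2 = nn.Linear(self.cell_size, num_outputs)
        self.value_branch = nn.Linear(self.cell_size, 1)
        self._cur_value = None

    @override(ModelV2)
    def get_initial_state(self):
        # Two state tensors now: hidden state h and cell state c.
        h = self.fc1.weight.new(1, self.cell_size).zero_().squeeze(0)
        c = self.fc1.weight.new(1, self.cell_size).zero_().squeeze(0)
        return [h, c]

    @override(ModelV2)
    def value_function(self):
        assert self._cur_value is not None, "must call forward() first"
        return self._cur_value

    @override(RecurrentTorchModel)
    def forward_rnn(self, inputs, state, seq_lens):
        x = nn.functional.relu(self.fc1(inputs))  # (B, T, cell)
        # nn.LSTM expects states shaped (num_layers, B, cell).
        h_in = state[0].unsqueeze(0)
        c_in = state[1].unsqueeze(0)
        outs, (h, c) = self.lstm(x, (h_in, c_in))
        self._cur_value = torch.reshape(self.value_branch(outs), [-1])
        return self.fc2(outs), [h.squeeze(0), c.squeeze(0)]

Padded timesteps still flow through the LSTM here; they are ignored later by the "seq_lens" mask in the loss, so this stays correct without packed sequences.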

The PR that fixes this problem (https://github.com/ray-project/ray/pull/7797) has been merged.

Please see this example code here for a learning example with PPO:
https://github.com/ray-project/ray/blob/master/rllib/examples/custom_torch_rnn_model.py

Closing this issue.
