I installed Ray from the nightly wheel and wrote a custom env, model, and action distribution.
When I try to train it with PPO, I get a KeyError on one of the internal objects used by RLlib (the train batch dict is missing "seq_lens", which is used to mask the recurrent model during backpropagation). Traceback:
2020-02-18 16:11:55,261 ERROR trial_runner.py:513 -- Trial PPO_test_49f2c33a: Error processing event.
Traceback (most recent call last):
File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/tune/trial_runner.py", line 459, in _process_trial
result = self.trial_executor.fetch_result(trial)
File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/tune/ray_trial_executor.py", line 377, in fetch_result
result = ray.get(trial_future[0], DEFAULT_GET_TIMEOUT)
File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/worker.py", line 1522, in get
raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(KeyError): ray::PPO.train() (pid=13094, ip=10.0.2.217)
File "python/ray/_raylet.pyx", line 447, in ray._raylet.execute_task
File "python/ray/_raylet.pyx", line 425, in ray._raylet.execute_task.function_executor
File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/agents/trainer.py", line 477, in train
raise e
File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/agents/trainer.py", line 463, in train
result = Trainable.train(self)
File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/tune/trainable.py", line 254, in train
result = self._train()
File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/agents/trainer_template.py", line 122, in _train
fetches = self.optimizer.step()
File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/optimizers/sync_samples_optimizer.py", line 71, in step
self.standardize_fields)
File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/utils/sgd.py", line 111, in do_minibatch_sgd
}, minibatch.count)))[policy_id]
File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/evaluation/rollout_worker.py", line 619, in learn_on_batch
info_out[pid] = policy.learn_on_batch(batch)
File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/policy/torch_policy.py", line 100, in learn_on_batch
loss_out = self._loss(self, self.model, self.dist_class, train_batch)
File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/agents/ppo/ppo_torch_policy.py", line 112, in ppo_surrogate_loss
print(train_batch["seq_lens"])
File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/utils/tracking_dict.py", line 22, in __getitem__
value = dict.__getitem__(self, key)
KeyError: 'seq_lens'
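For context, this is roughly the (padded) train batch layout I would expect for a recurrent model; "seq_lens" is the key the loss fails to find. Illustrative sketch only, not RLlib's actual internals, and all shapes are made up:

import numpy as np

# Hypothetical shapes: B sequences, each padded to T timesteps.
B, T, obs_dim, state_dim = 4, 20, 40, 64
expected_train_batch = {
    "obs": np.zeros((B * T, obs_dim), np.float32),       # zero-padded observations
    "actions": np.zeros((B * T, 5), np.int64),           # one row per (padded) timestep
    "state_in_0": np.zeros((B, state_dim), np.float32),  # one initial RNN state per sequence
    "seq_lens": np.full((B,), T, np.int32),              # true length of each sequence
}

Full repro script: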
import ray
from ray import tune
from ray.rllib.agents import ppo
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.policy.policy import TupleActions
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.tune.registry import register_env
import gym
from gym.spaces import Discrete, Box, Dict, MultiDiscrete
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.nn import Parameter
from torch import Tensor
def _make_f32_array(number):
return np.array(number, dtype="float32")
class TorchMultiCategorical(ActionDistribution):
"""MultiCategorical distribution for MultiDiscrete action spaces."""
@override(ActionDistribution)
    def __init__(self, inputs, model):
        super().__init__(inputs, model)  # let the base class store inputs/model
        input_lens = model.dist_input_lens
        inputs_splitted = inputs.split(input_lens, dim=1)
self.cats = [
torch.distributions.categorical.Categorical(logits=input_)
for input_ in inputs_splitted
]
@override(ActionDistribution)
def sample(self):
arr = [cat.sample() for cat in self.cats]
ret = torch.stack(arr, dim=1)
return ret
@override(ActionDistribution)
def logp(self, actions):
        # If a tensor is provided, unstack it into a list
if isinstance(actions, torch.Tensor):
actions = torch.unbind(actions, dim=1)
logps = torch.stack([cat.log_prob(act) for cat, act in zip(self.cats, actions)])
return torch.sum(logps, dim=0)
@override(ActionDistribution)
def multi_entropy(self):
return torch.stack([cat.entropy() for cat in self.cats], dim=1)
@override(ActionDistribution)
def entropy(self):
return torch.sum(self.multi_entropy(), dim=1)
@override(ActionDistribution)
def multi_kl(self, other):
return torch.stack(
[
torch.distributions.kl.kl_divergence(cat, oth_cat)
for cat, oth_cat in zip(self.cats, other.cats)
],
dim=1,
)
@override(ActionDistribution)
def kl(self, other):
return torch.sum(self.multi_kl(other), dim=1)
@staticmethod
@override(ActionDistribution)
def required_model_output_shape(action_space, model_config):
return np.sum(action_space.nvec)
class ReproEnv(gym.Env):
def __init__(self, config):
self.cur_pos = 0
self.window_size = config["window_size"]
self.need_reset = False
self.action_space = MultiDiscrete([3, 2, 51, 10, 2])
self.observation_space = Dict(
{
"lob": Box(low=-np.inf, high=np.inf, shape=(self.window_size, 40)),
"unallocated_wealth": Box(low=0, high=1, shape=()),
"taker_fees": Box(low=-1, high=1, shape=()),
"maker_fees": Box(low=-1, high=1, shape=()),
"order": Dict(
{
"side": Discrete(3),
"type": Discrete(2),
"size": Box(low=0, high=1, shape=()),
"price": Box(low=-np.inf, high=np.inf, shape=()),
"filled": Box(low=0, high=1, shape=()),
}
),
"position": Dict(
{
"side": Discrete(3),
"size": Box(low=0, high=1, shape=()),
"entry_price": Box(low=-np.inf, high=np.inf, shape=()),
"unrealized_pnl": Box(low=-100, high=np.inf, shape=()),
}
),
}
)
def reset(self):
self.cur_pos = 0
self.need_reset = False
return self.step([0, 0, 0, 0, 0])[0] # Noop
def step(self, action):
if self.need_reset:
raise Exception("You need to reset this environment!")
self.cur_pos += 1
assert action in self.action_space, action
if self.cur_pos >= 1000:
done = True
else:
done = False
info = {}
if done:
self.need_reset = True
observation = {
"lob": np.zeros((self.window_size, 40)),
"taker_fees": _make_f32_array(0),
"maker_fees": _make_f32_array(0),
"unallocated_wealth": _make_f32_array(0),
"order": {
"side": 0,
"type": 0,
"size": _make_f32_array(0),
"price": _make_f32_array(0),
"filled": _make_f32_array(0),
},
"position": {
"side": 0,
"size": _make_f32_array(0),
"entry_price": _make_f32_array(0),
"unrealized_pnl": _make_f32_array(0),
},
}
assert observation in self.observation_space, observation
return observation, 0, done, info
class CNN(nn.Module):
def __init__(self, dropout=0.2):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(1, 16, kernel_size=(1, 2), stride=(1, 2))
self.conv2 = nn.Conv2d(16, 16, kernel_size=(4, 1))
self.conv3 = nn.Conv2d(16, 16, kernel_size=(4, 1))
self.conv4 = nn.Conv2d(16, 32, kernel_size=(1, 2), stride=(1, 2))
self.conv5 = nn.Conv2d(32, 32, kernel_size=(4, 1))
self.conv6 = nn.Conv2d(32, 32, kernel_size=(4, 1))
self.conv7 = nn.Conv2d(32, 64, kernel_size=(1, 10))
self.conv8 = nn.Conv2d(64, 64, kernel_size=(4, 1))
self.conv9 = nn.Conv2d(64, 64, kernel_size=(4, 1))
# Pad to preserve the length in the time domain
self.pad = nn.ZeroPad2d((0, 0, 0, 3))
self.dropout = nn.Dropout(p=dropout)
def forward(self, x):
x = F.leaky_relu(self.conv1(x))
x = F.leaky_relu(self.conv2(self.pad(x)))
x = F.leaky_relu(self.conv3(self.pad(x)))
x = F.leaky_relu(self.conv4(x))
x = F.leaky_relu(self.conv5(self.pad(x)))
x = F.leaky_relu(self.conv6(self.pad(x)))
x = F.leaky_relu(self.conv7(x))
x = F.leaky_relu(self.conv8(self.pad(x)))
x = F.leaky_relu(self.conv9(self.pad(x)))
x = self.dropout(x)
return x
class TestNet(TorchModelV2, nn.Module):
def init_hidden(self, hidden_size):
h0 = self._value_branch[0].weight.new(1, hidden_size).zero_()
c0 = self._value_branch[0].weight.new(1, hidden_size).zero_()
return (h0, c0)
def __init__(self, obs_space, action_space, num_outputs, config, name):
TorchModelV2.__init__(self, obs_space, action_space, num_outputs, config, name)
nn.Module.__init__(self)
model_config = config["custom_options"]
print("Model config:")
print(model_config)
dropout = model_config["dropout"]
window_size = model_config["window_size"]
self.cnn = CNN(dropout=dropout)
print(f"Dropout: {dropout}")
print(f"Window size: {window_size}")
# Value function
self._value_branch = nn.Sequential(
nn.Linear(6, 128),
nn.LeakyReLU(),
nn.Linear(128, 128),
nn.LeakyReLU(),
nn.Linear(128, 128),
nn.LeakyReLU(),
nn.Linear(128, 1),
)
# Policy: Signal
self.long_lstm = nn.LSTM(64 * window_size, 64, batch_first=True)
self.short_lstm = nn.LSTM(64 * window_size, 64, batch_first=True)
self.long_025_lstm = nn.LSTM(64, 32, batch_first=True)
self.long_025_fc = nn.Linear(32, 1)
self.long_050_lstm = nn.LSTM(64, 32, batch_first=True)
self.long_050_fc = nn.Linear(32, 1)
self.long_075_lstm = nn.LSTM(64, 32, batch_first=True)
self.long_075_fc = nn.Linear(32, 1)
self.short_025_lstm = nn.LSTM(64, 32, batch_first=True)
self.short_025_fc = nn.Linear(32, 1)
self.short_050_lstm = nn.LSTM(64, 32, batch_first=True)
self.short_050_fc = nn.Linear(32, 1)
self.short_075_lstm = nn.LSTM(64, 32, batch_first=True)
self.short_075_fc = nn.Linear(32, 1)
# Policy: Brain
self.dumb_fc = nn.Linear(6, 68)
self.dist_input_lens = [3, 2, 51, 10, 2]
self._cur_value = None
@override(TorchModelV2)
def get_initial_state(self):
# make hidden states on same device as model
long_lstm_h, long_lstm_c = self.init_hidden(64)
short_lstm_h, short_lstm_c = self.init_hidden(64)
long_025_lstm_h, long_025_lstm_c = self.init_hidden(32)
long_050_lstm_h, long_050_lstm_c = self.init_hidden(32)
long_075_lstm_h, long_075_lstm_c = self.init_hidden(32)
short_025_lstm_h, short_025_lstm_c = self.init_hidden(32)
short_050_lstm_h, short_050_lstm_c = self.init_hidden(32)
short_075_lstm_h, short_075_lstm_c = self.init_hidden(32)
initial_state = [
long_lstm_h,
long_lstm_c,
short_lstm_h,
short_lstm_c,
long_025_lstm_h,
long_025_lstm_c,
            long_050_lstm_h,
            long_050_lstm_c,
            long_075_lstm_h,
            long_075_lstm_c,
            short_025_lstm_h,
            short_025_lstm_c,
            short_050_lstm_h,
            short_050_lstm_c,
            short_075_lstm_h,
            short_075_lstm_c,
]
return initial_state
@override(TorchModelV2)
def value_function(self):
assert self._cur_value is not None, "must call forward() first"
return self._cur_value
def forward(self, input_dict, hidden_state, seq_lens):
# if seq_lens is None:
# raise Exception("seq_lens is None")
lob = input_dict["obs"]["lob"]
batch_size, window_length, features = lob.size()
# assert list(hidden_state[0].size()) == [1, 1, 64]
# Unpack the hidden_state
long_lstm_h = hidden_state[0]
long_lstm_c = hidden_state[1]
short_lstm_h = hidden_state[2]
short_lstm_c = hidden_state[3]
long_025_lstm_h = hidden_state[4]
long_025_lstm_c = hidden_state[5]
long_050_lstm_h = hidden_state[6]
long_050_lstm_c = hidden_state[7]
long_075_lstm_h = hidden_state[8]
long_075_lstm_c = hidden_state[9]
short_025_lstm_h = hidden_state[10]
short_025_lstm_c = hidden_state[11]
short_050_lstm_h = hidden_state[12]
short_050_lstm_c = hidden_state[13]
short_075_lstm_h = hidden_state[14]
short_075_lstm_c = hidden_state[15]
# Build the tuples
long_lstm_hidden = (long_lstm_h.view(1, -1, 64), long_lstm_c.view(1, -1, 64))
short_lstm_hidden = (short_lstm_h.view(1, -1, 64), short_lstm_c.view(1, -1, 64))
long_025_lstm_hidden = (
long_025_lstm_h.view(1, -1, 32),
long_025_lstm_c.view(1, -1, 32),
)
long_050_lstm_hidden = (
long_050_lstm_h.view(1, -1, 32),
long_050_lstm_c.view(1, -1, 32),
)
long_075_lstm_hidden = (
long_075_lstm_h.view(1, -1, 32),
long_075_lstm_c.view(1, -1, 32),
)
short_025_lstm_hidden = (
short_025_lstm_h.view(1, -1, 32),
short_025_lstm_c.view(1, -1, 32),
)
short_050_lstm_hidden = (
short_050_lstm_h.view(1, -1, 32),
short_050_lstm_c.view(1, -1, 32),
)
short_075_lstm_hidden = (
short_075_lstm_h.view(1, -1, 32),
short_075_lstm_c.view(1, -1, 32),
)
c_in = lob.view(batch_size, 1, window_length, features)
c_out = self.cnn(c_in)
        # Embeddings from the CNN, reshaped to be consumed by the LSTM
embeddings = c_out.view(batch_size, 1, -1)
long_out, long_lstm_hidden = self.long_lstm(embeddings, long_lstm_hidden)
short_out, short_lstm_hidden = self.short_lstm(embeddings, short_lstm_hidden)
# Now on to the tail LSTMs
long_025_out, long_025_lstm_hidden = self.long_025_lstm(
F.leaky_relu(long_out), long_025_lstm_hidden
)
long_050_out, long_050_lstm_hidden = self.long_050_lstm(
F.leaky_relu(long_out), long_050_lstm_hidden
)
long_075_out, long_075_lstm_hidden = self.long_075_lstm(
F.leaky_relu(long_out), long_075_lstm_hidden
)
short_025_out, short_025_lstm_hidden = self.short_025_lstm(
F.leaky_relu(short_out), short_025_lstm_hidden
)
short_050_out, short_050_lstm_hidden = self.short_050_lstm(
F.leaky_relu(short_out), short_050_lstm_hidden
)
short_075_out, short_075_lstm_hidden = self.short_075_lstm(
F.leaky_relu(short_out), short_075_lstm_hidden
)
# Reshape the outputs of the tail LSTMs into (batch, hidden_size)
long_025_out = long_025_out.view(batch_size, -1)
long_050_out = long_050_out.view(batch_size, -1)
long_075_out = long_075_out.view(batch_size, -1)
short_025_out = short_025_out.view(batch_size, -1)
short_050_out = short_050_out.view(batch_size, -1)
short_075_out = short_075_out.view(batch_size, -1)
# Fully connected at the end
long_025_q = self.long_025_fc(F.leaky_relu(long_025_out))
long_050_q = self.long_050_fc(F.leaky_relu(long_050_out))
long_075_q = self.long_075_fc(F.leaky_relu(long_075_out))
short_025_q = self.short_025_fc(F.leaky_relu(short_025_out))
short_050_q = self.short_050_fc(F.leaky_relu(short_050_out))
short_075_q = self.short_075_fc(F.leaky_relu(short_075_out))
quantiles = [
long_025_q,
long_050_q,
long_075_q,
short_025_q,
short_050_q,
short_075_q,
]
quantiles = torch.cat(quantiles, dim=1).view(batch_size, 6)
new_hidden_state = [
long_lstm_hidden[0],
long_lstm_hidden[1],
short_lstm_hidden[0],
short_lstm_hidden[1],
long_025_lstm_hidden[0],
long_025_lstm_hidden[1],
long_050_lstm_hidden[0],
long_050_lstm_hidden[1],
long_075_lstm_hidden[0],
long_075_lstm_hidden[1],
short_025_lstm_hidden[0],
short_025_lstm_hidden[1],
short_050_lstm_hidden[0],
short_050_lstm_hidden[1],
short_075_lstm_hidden[0],
short_075_lstm_hidden[1],
]
assert list(new_hidden_state[0].size()) == [
1,
list(new_hidden_state[0].size())[1],
64,
], new_hidden_state[0].size()
# Value function
self._cur_value = self._value_branch(quantiles).squeeze(1)
logits = self.dumb_fc(quantiles)
return logits, new_hidden_state
ModelCatalog.register_custom_action_dist("torchmulticategorical", TorchMultiCategorical)
ModelCatalog.register_custom_model("test", TestNet)
register_env("test", lambda config: ReproEnv(config))
ray.init()
tune.run(
ppo.PPOTrainer,
config={
"num_workers": 1,
"env": "test",
"log_level": "INFO",
"use_pytorch": True,
"num_gpus": 1,
"vf_share_layers": True,
"env_config": {"window_size": 100,},
"model": {
"custom_action_dist": "torchmulticategorical",
"custom_model": "test",
"custom_options": {
"window_size": 100,
"dropout": 0.2,
"use_learned_hidden": True,
},
},
},
)
I think this is because we don't support RNNs with PyTorch yet, cc @sven1977
From the doc: "Similarly, you can create and register custom PyTorch models for use with PyTorch-based algorithms (e.g., A2C, PG, QMIX). See these examples of fully connected, convolutional, and recurrent torch models."
It looks like you do!
Here is a PyTorch RNN in the codebase: https://github.com/ray-project/ray/blob/master/rllib/agents/qmix/model.py
I'll assign myself and take a look. Thanks for filing this!
Yes, we'll have to fix LSTM support for torch.
I have the same problem right now and would be happy to look into it. What has to be done to make it work?
Thanks for offering your help everyone!
I'm taking a look right now. ... I'll keep you posted over the next few hours/days on what I find.
Great, thanks! This is my current minimal example in case it helps reproduce the error (latest wheel, Python 3.6):
import argparse
import math
import gym
from gym import spaces
from gym.utils import seeding
import numpy as np
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils import try_import_torch
from ray.rllib.models import ModelCatalog
torch, nn = try_import_torch()
class CartPoleStatelessEnv(gym.Env):
    """Body omitted; see https://github.com/ray-project/ray/blob/5cebee68d681bebfd59255b811338d39e2cc2e7d/rllib/examples/cartpole_lstm.py"""
def _get_size(obs_space):
return get_preprocessor(obs_space)(obs_space).size
class RNNModel(TorchModelV2, nn.Module):
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
model_config, name)
nn.Module.__init__(self)
self.obs_size = _get_size(obs_space)
self.rnn_hidden_dim = model_config["lstm_cell_size"]
self.fc1 = nn.Linear(self.obs_size, self.rnn_hidden_dim)
self.rnn = nn.GRUCell(self.rnn_hidden_dim, self.rnn_hidden_dim)
self.fc2 = nn.Linear(self.rnn_hidden_dim, num_outputs)
self.value_branch = nn.Linear(self.rnn_hidden_dim, 1)
self._cur_value = None
@override(TorchModelV2)
def get_initial_state(self):
# make hidden states on same device as model
h = [self.fc1.weight.new(1, self.rnn_hidden_dim).zero_().squeeze(0)]
return h
@override(TorchModelV2)
def value_function(self):
assert self._cur_value is not None, "must call forward() first"
return self._cur_value
@override(TorchModelV2)
def forward(self, input_dict, hidden_state, seq_lens):
x = nn.functional.relu(self.fc1(input_dict["obs_flat"].float()))
h_in = hidden_state[0].reshape(-1, self.rnn_hidden_dim)
h = self.rnn(x, h_in)
q = self.fc2(h)
self._cur_value = self.value_branch(h).squeeze(1)
return q, [h]
if __name__ == "__main__":
import ray
from ray import tune
ModelCatalog.register_custom_model("rnnmodel", RNNModel)
tune.register_env("cartpole_stateless", lambda _: CartPoleStatelessEnv())
ray.init()
tune.run(
"PPO",
stop={"episode_reward_mean": 200},
config={
"use_pytorch": True,
"model": {
"custom_model": "rnnmodel",
"lstm_use_prev_action_reward": "store_true",
"lstm_cell_size": 20,
"custom_options": {}
},
"num_sgd_iter": 5,
"vf_share_layers": True,
"vf_loss_coeff": 0.0001,
"env": "cartpole_stateless",
}
)
This fails with the following error:
Traceback (most recent call last):
File "[...]/ray/tune/trial_runner.py", line 467, in _process_trial
result = self.trial_executor.fetch_result(trial)
File "[...]/ray/tune/ray_trial_executor.py", line 381, in fetch_result
result = ray.get(trial_future[0], DEFAULT_GET_TIMEOUT)
File "[...]/ray/worker.py", line 1505, in get
raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(KeyError): ray::PPO.train() (pid=21658, ip=128.232.69.20)
File "python/ray/_raylet.pyx", line 445, in ray._raylet.execute_task
File "python/ray/_raylet.pyx", line 423, in ray._raylet.execute_task.function_executor
File "[...]/ray/rllib/agents/trainer.py", line 504, in train
raise e
File "[...]/ray/rllib/agents/trainer.py", line 490, in train
result = Trainable.train(self)
File "[...]/ray/tune/trainable.py", line 261, in train
result = self._train()
File "[...]/ray/rllib/agents/trainer_template.py", line 150, in _train
fetches = self.optimizer.step()
File "[...]/ray/rllib/optimizers/sync_samples_optimizer.py", line 71, in step
self.standardize_fields)
File "[...]/ray/rllib/utils/sgd.py", line 115, in do_minibatch_sgd
}, minibatch.count)))[policy_id]
File "[...]/ray/rllib/evaluation/rollout_worker.py", line 632, in learn_on_batch
info_out[pid] = policy.learn_on_batch(batch)
File "[...]/ray/rllib/policy/torch_policy.py", line 132, in learn_on_batch
loss_out = self._loss(self, self.model, self.dist_class, train_batch)
File "[...]/ray/rllib/agents/ppo/ppo_torch_policy.py", line 120, in ppo_surrogate_loss
max_seq_len = torch.max(train_batch["seq_lens"])
File "[...]/ray/rllib/utils/tracking_dict.py", line 22, in __getitem__
value = dict.__getitem__(self, key)
KeyError: 'seq_lens'
Yeah, I think I know what it is: in rllib/policy/tf_policy.py we call _get_loss_inputs_dict (the worker calls Policy.learn_on_batch, which then calls Policy._build_learn_on_batch in the TF case), and that computes and adds the seq_lens. In the PyTorch path we don't do anything comparable. I'll fix this now.
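Roughly, what is missing on the torch side is the step that chops the rollout into fixed-length sequences and records each sequence's true length into the batch. A conceptual sketch (illustrative only, not the actual RLlib code; the function name is made up):

import numpy as np

def compute_seq_lens(episode_lengths, max_seq_len):
    # Chop each episode into chunks of at most max_seq_len timesteps and
    # record the true length of every chunk; this array is what the loss
    # later reads as train_batch["seq_lens"].
    seq_lens = []
    for ep_len in episode_lengths:
        while ep_len > 0:
            chunk = min(ep_len, max_seq_len)
            seq_lens.append(chunk)
            ep_len -= chunk
    return np.array(seq_lens, dtype=np.int32)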
Almost there. Just a few remaining issues in the PPO loss concerning the valid_mask.
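For reference, this is the kind of masking the torch PPO loss needs (a minimal sketch assuming [B, T]-shaped, zero-padded per-timestep losses and seq_lens as an int tensor; not the final code):

import torch

def sequence_mask(seq_lens, max_len):
    # Boolean mask of shape [B, T]: True for valid timesteps, False for padding.
    return torch.arange(max_len)[None, :] < seq_lens[:, None]

def masked_mean(per_timestep_loss, seq_lens, max_len):
    # Average the loss over valid timesteps only, ignoring the padded tail.
    valid_mask = sequence_mask(seq_lens, max_len).float()
    return torch.sum(per_timestep_loss * valid_mask) / torch.sum(valid_mask)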
Apologies for the docs claiming that we generically support PyTorch + LSTMs: we don't (yet)! There will be a PR (probably tomorrow) that fixes this for at least the standard PG algos (PPO, PG, A2C/A3C), provided one uses a custom torch Model. Making the "use_lstm" auto-wrapping functionality work will be a follow-up PR.
There was quite some stuff missing for this to work. I'll do a WIP PR later today and post it here. Got the example running, but CartPole doesn't seem to learn with PPO + torch + LSTM. Will have to take a closer look.
Here is the WIP PR that makes the above minimal example run: https://github.com/ray-project/ray/pull/7797
Will add more tests and make sure the CartPole example learns as well.
Ok, #7797 is learning the RepeatInitialEnv example using PPO + torch + a custom torch model.
Check out this example script (included in the PR; you need this PR for the code below to run and learn). Will be merged within the next few days.
import argparse
import ray
import ray.rllib.agents.ppo as ppo
from ray.rllib.examples.cartpole_lstm import CartPoleStatelessEnv
from ray.rllib.examples.custom_keras_rnn_model import RepeatInitialEnv
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.torch.recurrent_torch_model import RecurrentTorchModel
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils import try_import_torch
from ray.rllib.models import ModelCatalog
import ray.tune as tune
torch, nn = try_import_torch()
parser = argparse.ArgumentParser()
parser.add_argument("--run", type=str, default="PPO")
parser.add_argument("--env", type=str, default="repeat_initial")
parser.add_argument("--stop", type=int, default=90)
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument("--lstm-cell-size", type=int, default=32)
class RNNModel(RecurrentTorchModel):
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
super().__init__(obs_space, action_space, num_outputs, model_config,
name)
self.obs_size = get_preprocessor(obs_space)(obs_space).size
self.rnn_hidden_dim = model_config["lstm_cell_size"]
self.fc1 = nn.Linear(self.obs_size, self.rnn_hidden_dim)
self.rnn = nn.GRUCell(self.rnn_hidden_dim, self.rnn_hidden_dim)
self.fc2 = nn.Linear(self.rnn_hidden_dim, num_outputs)
self.value_branch = nn.Linear(self.rnn_hidden_dim, 1)
self._cur_value = None
@override(ModelV2)
def get_initial_state(self):
# make hidden states on same device as model
h = [self.fc1.weight.new(1, self.rnn_hidden_dim).zero_().squeeze(0)]
return h
@override(ModelV2)
def value_function(self):
assert self._cur_value is not None, "must call forward() first"
return self._cur_value
@override(RecurrentTorchModel)
def forward_rnn(self, inputs, state, seq_lens):
"""Feeds `inputs` (B x T x ..) through the Gru Unit.
Returns the resulting outputs as a sequence (B x T x ...).
Values are stored in self._cur_value in simple (B) shape (where B
contains both the B and T dims!).
Returns:
NN Outputs (B x T x ...) as sequence.
The state batch as a List of one item (only one hidden state b/c
Gru).
"""
x = nn.functional.relu(self.fc1(inputs))
h = state[0]
outs = []
for i in range(torch.max(seq_lens)):
h = self.rnn(x[:, i], h)
outs.append(h)
        # Stack along the time axis so the output keeps the documented (B, T, ...) layout.
        outs = torch.stack(outs, dim=1)
        q = self.fc2(outs)
        self._cur_value = torch.reshape(self.value_branch(outs), [-1])
        return q, [h]
if __name__ == "__main__":
args = parser.parse_args()
ray.init(num_cpus=args.num_cpus or None)
ModelCatalog.register_custom_model("rnn", RNNModel)
tune.register_env("repeat_initial", lambda c: RepeatInitialEnv())
tune.register_env("cartpole_stateless", lambda c: CartPoleStatelessEnv())
config = {
"num_workers": 0,
"num_envs_per_worker": 20,
"gamma": 0.9,
"entropy_coeff": 0.001,
"use_pytorch": True,
"model": {
"custom_model": "rnn",
"lstm_use_prev_action_reward": "store_true",
"lstm_cell_size": args.lstm_cell_size,
"custom_options": {}
},
"lr": 0.0003,
"num_sgd_iter": 5,
"vf_loss_coeff": 1e-5,
"env": args.env,
}
tune.run(
args.run,
stop={"episode_reward_mean": args.stop},
config=config,
)
@justinglibert (see post above).
@janblumenkamp and @justinglibert Thanks for your help, guys!
Awesome, thank you very much Sven! I will try it!
Updated the PR once more. Switched out the GRU for an LSTM, and this works much better now on our test envs.
The PR that fixes this problem (https://github.com/ray-project/ray/pull/7797) has been merged.
Please see this example code for a learning setup with PPO:
https://github.com/ray-project/ray/blob/master/rllib/examples/custom_torch_rnn_model.py
Closing this issue.