Stable-baselines: [question] How to evaluate PPO2 with MlpLnLstmPolicy trained on SubprocVecEnv having nminibatches > 1?

Created on 3 Apr 2020 · 7 Comments · Source: hill-a/stable-baselines

The code below fails with the following error:

ValueError: Cannot feed value of shape (1, 4) for Tensor 'input/Ob:0', which has shape '(4, 4)'

It looks like when training with 4 environments, the model.predict(obs) method only accepts a batch of 4 observations.

On the other hand, when I tried evaluating on multiple environments, the evaluation code rejected them:

AssertionError: You must pass only one environment for evaluation

import gym

from stable_baselines.common.policies import MlpLnLstmPolicy
from stable_baselines.common.vec_env import DummyVecEnv, VecNormalize
from stable_baselines.common.vec_env import SubprocVecEnv
from stable_baselines import PPO2
from stable_baselines.common.callbacks import EvalCallback

if __name__ == '__main__':
    N_ENV = 4
    NMINIBATCHES = 4

    def _make_env():
        return gym.make('CartPole-v1')

    train_env = SubprocVecEnv([_make_env for _ in range(N_ENV)])
    train_env = VecNormalize(train_env, norm_obs=True, norm_reward=False)

    eval_env = DummyVecEnv([_make_env])
    eval_env = VecNormalize(eval_env, norm_obs=True, norm_reward=False)

    eval_callback = EvalCallback(eval_env, best_model_save_path='./model',
                                 log_path='./logs', eval_freq=100,
                                 deterministic=True, render=False,
                                 n_eval_episodes=1)

    model = PPO2(MlpLnLstmPolicy, train_env, verbose=1,
                 policy_kwargs={'n_lstm': 16},
                 nminibatches=NMINIBATCHES)

    model.learn(total_timesteps=4000, callback=eval_callback)

All 7 comments

This is a limitation of LSTM policies (mentioned here, although it is a bit hidden): they require the same number of environments during evaluation as during training.

One other thing: I think you need to use the same VecNormalize for both training and evaluation to get the right normalization statistics (@araffin ?)

Got it. So I guess currently when working with LSTM policies, the only choice is to run it with nminibatches=1 and a single environment if we want to deploy our model to interact with an environment?

Otherwise, model.predict() has to be called with obs of shape (n_envs, obs_size), where n_envs is the number of training environments (equal to nminibatches in this example).
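To make the shape constraint concrete: as far as I can tell from stable-baselines 2.x, PPO2 builds the acting (predict) model of a recurrent policy with a fixed batch of n_envs rows, splits each rollout into minibatches of n_envs // nminibatches whole environment trajectories, and asserts that n_envs is a multiple of nminibatches. The helper below is hypothetical (the name is made up) and only reproduces that arithmetic; it does not call the library.

```python
def recurrent_ppo2_batch_sizes(n_envs, n_steps, nminibatches):
    """Hypothetical helper: batch sizes a recurrent PPO2 policy is built with.

    Mirrors (as I understand it) the arithmetic in stable-baselines 2.x PPO2;
    it does not import or call stable-baselines.
    """
    if n_envs % nminibatches != 0:
        raise ValueError("for recurrent policies, n_envs must be a multiple of nminibatches")
    return {
        # rows every model.predict() call must contain
        "predict_batch": n_envs,
        # transitions per gradient minibatch
        "train_batch": n_envs * n_steps // nminibatches,
        # whole environment trajectories per minibatch (keeps LSTM sequences intact)
        "envs_per_minibatch": n_envs // nminibatches,
    }


# Setup from the issue: 4 envs, nminibatches=4, PPO2's default n_steps=128
print(recurrent_ppo2_batch_sizes(n_envs=4, n_steps=128, nminibatches=4))
# {'predict_batch': 4, 'train_batch': 128, 'envs_per_minibatch': 1}

# Doubling the environments doubles the predict batch; nminibatches is unchanged
print(recurrent_ppo2_batch_sizes(n_envs=8, n_steps=128, nminibatches=4)["predict_batch"])
# 8
```

In the issue N_ENV and NMINIBATCHES are both 4, which is why the two values are easy to confuse.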

Good catch on VecNormalize, I would like to know too.

I would advise against training with a single environment (known to be unstable).

if we want to deploy our model to interact with an environment?

An alternative is to pad your observations with zeros until they match the expected shape. araffin has a nice example here.
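A minimal sketch of the zero-padding workaround, assuming the evaluation env is a VecEnv with a single environment (observations of shape (1, *obs_shape)) and the model was trained on n_envs environments. pad_obs and unpad_actions are hypothetical names, not stable-baselines functions. The LSTM and the per-sample layer normalization in MlpLnLstmPolicy should process each batch row independently, so the dummy rows ought not to change the action computed for row 0; only that action is sent to the environment.

```python
import numpy as np


def pad_obs(obs, n_envs):
    """Zero-pad a (1, *obs_shape) batch from a single eval env to (n_envs, *obs_shape).

    Hypothetical helper, not part of stable-baselines.
    """
    obs = np.asarray(obs)
    padded = np.zeros((n_envs,) + obs.shape[1:], dtype=obs.dtype)
    padded[0] = obs[0]  # row 0 is the real observation, the rest are dummies
    return padded


def unpad_actions(actions):
    """Keep only the action computed for row 0, i.e. the real environment."""
    return np.asarray(actions)[:1]


# CartPole-like observation from a DummyVecEnv with one env: shape (1, 4)
obs = np.array([[0.1, -0.2, 0.03, 0.4]])
batch = pad_obs(obs, n_envs=4)
print(batch.shape)  # (4, 4)
print(unpad_actions(np.array([1, 0, 0, 0])))  # [1]
```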

I think you need to use same VecNormalize for both training and evaluation to have the right normalization statistics

Yes, EvalCallback takes care of syncing the normalization statistics, but you should pass training=False to the evaluation VecNormalize (cf. rl zoo).
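To illustrate why the evaluation env needs the training statistics and training=False, here is a toy numpy stand-in for the observation normalization that VecNormalize performs. RunningObsNormalizer and sync_stats are hypothetical names; the update uses the standard parallel mean/variance merge and clipping is omitted. In stable-baselines itself the corresponding pieces are EvalCallback syncing the eval env from the training env, plus training=False on the evaluation VecNormalize.

```python
import numpy as np


class RunningObsNormalizer:
    """Hypothetical stand-in for the observation part of VecNormalize (no clipping)."""

    def __init__(self, shape, training=True, epsilon=1e-8):
        self.mean = np.zeros(shape)
        self.var = np.ones(shape)
        self.count = 1e-4  # small prior count so the first update is well defined
        self.training = training
        self.epsilon = epsilon

    def _update(self, batch):
        # Merge the batch moments into the running moments (parallel variance formula)
        b_mean, b_var, b_count = batch.mean(axis=0), batch.var(axis=0), batch.shape[0]
        delta = b_mean - self.mean
        total = self.count + b_count
        m2 = self.var * self.count + b_var * b_count + delta ** 2 * self.count * b_count / total
        self.mean = self.mean + delta * b_count / total
        self.var = m2 / total
        self.count = total

    def __call__(self, batch):
        batch = np.asarray(batch, dtype=np.float64)
        if self.training:  # with training=False the statistics stay frozen
            self._update(batch)
        return (batch - self.mean) / np.sqrt(self.var + self.epsilon)


def sync_stats(source, target):
    """Copy running statistics from the training normalizer into the eval one."""
    target.mean = source.mean.copy()
    target.var = source.var.copy()
    target.count = source.count


rng = np.random.default_rng(0)
train_norm = RunningObsNormalizer((4,))
for _ in range(50):
    train_norm(rng.normal(5.0, 2.0, size=(4, 4)))  # 4 training envs per step

eval_norm = RunningObsNormalizer((4,), training=False)
sync_stats(train_norm, eval_norm)
frozen_mean = eval_norm.mean.copy()
normalized = eval_norm(np.full((1, 4), 5.0))  # near 0: uses the training statistics
```

Without the sync, a fresh evaluation normalizer would pass the raw value 5.0 through almost unchanged, which is not what the policy saw during training.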

It seems we may need to update the evaluate_policy function for the LSTM case, but I'm afraid of making the code too complex :/

Thanks for the help; it works after I hacked evaluate_policy and EvalCallback following your code samples.

To be more specific, I made a new evaluate_lstm_policy() that adds the following to evaluate_policy()

zero_completed_obs = np.zeros((model.n_envs,) + env.observation_space.shape)  # batch = number of training envs; equals NMINIBATCHES only in this example
zero_completed_obs[0, :] = obs
obs = zero_completed_obs

and made a new EvalCallback that calls evaluate_lstm_policy().
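Put together, an evaluation loop along these lines might look like the sketch below. The name evaluate_lstm_policy comes from the comment above, but the body is a reconstruction, not the code that was actually used. It is written against duck-typed model and env objects so it can be checked without TensorFlow, and it assumes a stable-baselines-style model.predict(obs, state=..., deterministic=...) returning (actions, state) plus a model.n_envs attribute equal to the number of training environments.

```python
import numpy as np


def evaluate_lstm_policy(model, env, n_eval_episodes=10, deterministic=True):
    """Hypothetical evaluate_policy variant for recurrent policies.

    Assumes `env` is a VecEnv with exactly one environment and `model.predict`
    expects `model.n_envs` observation rows per call (the training env count).
    Returns the mean and standard deviation of the episode rewards.
    """
    n_rows = model.n_envs
    episode_rewards = []
    for _ in range(n_eval_episodes):
        obs = env.reset()
        state = None  # None makes predict() start from the initial LSTM state
        done = [False]
        episode_reward = 0.0
        while not done[0]:
            # Zero-pad the single observation to n_rows; only row 0 is real
            padded = np.zeros((n_rows,) + obs.shape[1:], dtype=obs.dtype)
            padded[0] = obs[0]
            actions, state = model.predict(padded, state=state, deterministic=deterministic)
            # Send only the action for the real environment
            obs, reward, done, _info = env.step(actions[:1])
            episode_reward += float(reward[0])
        episode_rewards.append(episode_reward)
    return float(np.mean(episode_rewards)), float(np.std(episode_rewards))
```

To use it during training, one option is an EvalCallback subclass whose _on_step calls this function in place of evaluate_policy; that wiring is not shown here.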

This looks good and simple enough to merge into the project to solve the evaluation problem.
