Cannot load custom policy.
Loading a model without an environment, this model cannot be trained until it has a valid environment.
Traceback (most recent call last):
File "/snap/pycharm-professional/89/helpers/pydev/pydevd.py", line 1664, in <module>
main()
File "/snap/pycharm-professional/89/helpers/pydev/pydevd.py", line 1658, in main
globals = debugger.run(setup['file'], None, None, is_module)
File "/snap/pycharm-professional/89/helpers/pydev/pydevd.py", line 1068, in run
pydev_imports.execfile(file, globals, locals) # execute the script
File "/snap/pycharm-professional/89/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
exec(compile(contents+"\n", file, 'exec'), glob, loc)
File "/SRC/pathway/rl-EXTERNAL/stable-baselines/train_test_load_custompolicy.py", line 58, in <module>
model = A2C.load("a2c_lunar")
File "/usr/local/HAMMER/DYNAMIC/SRC/pathway/rl-EXTERNAL/stable-baselines/stable_baselines/common/base_class.py", line 361, in load
model.setup_model()
File "/usr/local/HAMMER/DYNAMIC/SRC/pathway/rl-EXTERNAL/stable-baselines/stable_baselines/a2c/a2c.py", line 102, in setup_model
n_batch_step, reuse=False)
File "/SRC/pathway/rl-EXTERNAL/stable-baselines/train_test_load_custompolicy.py", line 16, in __init__
super(CustomPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=0,
TypeError: super(type, obj): obj must be an instance or subtype of type
Code example
import gym
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import A2C
from stable_baselines.common.policies import ActorCriticPolicy
import tensorflow as tf
# Build the LunarLander environment, then wrap it in a single-env DummyVecEnv
lander_env = gym.make('LunarLander-v2')
env = DummyVecEnv([lambda: lander_env])
class CustomPolicy(ActorCriticPolicy):
    """Actor-critic policy with one shared dense layer and two separate heads.

    The flattened, pre-processed observation goes through a shared 64-unit
    ReLU layer ('latent_fc'). Its output feeds a 64-unit ReLU policy branch
    ('pi_fc') and a 64-unit ReLU value branch ('vf_fc'); a final 1-unit dense
    layer ('vf') on the value branch produces the value function.
    """

    def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=False, **kwargs):
        """Build the network inside the "model" variable scope.

        Extra keyword arguments are accepted but are not forwarded to the
        base class.
        """
        # The explicit two-argument super() form is kept on purpose: it is the
        # call that appears in the reported traceback.
        super(CustomPolicy, self).__init__(
            sess, ob_space, ac_space, n_env, n_steps, n_batch,
            n_lstm=0, reuse=reuse, scale=True,
        )

        with tf.variable_scope("model", reuse=reuse):
            relu = tf.nn.relu
            flat_obs = tf.layers.flatten(self.processed_x)

            # Shared trunk, then one branch per head.
            shared = relu(tf.layers.dense(flat_obs, 64, name='latent_fc'))
            pi_latent = relu(tf.layers.dense(shared, 64, name='pi_fc'))
            vf_latent = relu(tf.layers.dense(shared, 64, name='vf_fc'))

            value_fn = tf.layers.dense(vf_latent, 1, name='vf')

            distribution_outputs = self.pdtype.proba_distribution_from_latent(
                pi_latent, vf_latent, init_scale=0.01)
            self.proba_distribution, self.policy, self.q_value = distribution_outputs

        self.value_fn = value_fn
        self.initial_state = None
        self._setup_init()

    def step(self, obs, state=None, mask=None, deterministic=True):
        """Run action, value and neglogp for ``obs``.

        ``state``, ``mask`` and ``deterministic`` are accepted but not used.
        Returns ``(action, value, self.initial_state, neglogp)``.
        """
        fetches = [self.action, self._value, self.neglogp]
        action, value, neglogp = self.sess.run(fetches, {self.obs_ph: obs})
        return action, value, self.initial_state, neglogp

    def proba_step(self, obs, state=None, mask=None):
        """Evaluate ``self.policy_proba`` for ``obs``; ``state`` and ``mask`` are unused."""
        feed = {self.obs_ph: obs}
        return self.sess.run(self.policy_proba, feed)

    def value(self, obs, state=None, mask=None):
        """Evaluate ``self._value`` for ``obs``; ``state`` and ``mask`` are unused."""
        feed = {self.obs_ph: obs}
        return self.sess.run(self._value, feed)
# Build the A2C agent around the custom policy
model = A2C(CustomPolicy, env, verbose=1, ent_coef=0.1)

# Short training run
model.learn(total_timesteps=500)

# Save the agent under the name "a2c_lunar"
model.save("a2c_lunar")

# Drop the trained agent so the next step really exercises loading
del model

# Reload the agent; note that neither env nor policy is passed here
model = A2C.load("a2c_lunar")

# Run the reloaded agent for 1000 steps, rendering after every step
obs = env.reset()
for _ in range(1000):
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    env.render()
System Info
Loading works fine for me without a custom policy, i.e. the code below works fine:
import gym
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import A2C
# Build the LunarLander environment, then wrap it in a single-env DummyVecEnv
lander_env = gym.make('LunarLander-v2')
env = DummyVecEnv([lambda: lander_env])

# Train an A2C agent that uses the built-in MlpPolicy
model = A2C(MlpPolicy, env, verbose=1, ent_coef=0.1)
model.learn(total_timesteps=5000)
model.save("a2c_lunar")

# Drop the trained agent so the next step really exercises loading
del model

# Reload the agent from the saved name
model = A2C.load("a2c_lunar")

# Run the reloaded agent for 1000 steps, rendering after every step
obs = env.reset()
for _ in range(1000):
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    env.render()
I found that if I put the custom policy class in common/policies.py, it loads just fine.
But with the exact same class code in train.py, I get the error shown above.
Hi, thank you for reporting the bug. I think the documentation is missing for that.
I'll take a look today.
@pathway When using a custom policy, you need to pass it explicitly when loading the model.
Otherwise, it can't find the definition of your policy (that's why it works when you put the policy definition in common/policies.py). I'll add a warning to the documentation.
import gym
from stable_baselines.common.policies import FeedForwardPolicy, register_policy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import A2C
# Custom MLP policy: three hidden layers of 128 units each
class CustomPolicy(FeedForwardPolicy):
    """FeedForwardPolicy preset with ``layers=[128, 128, 128]`` and "mlp" feature extraction."""

    def __init__(self, *args, **kwargs):
        # Caller-supplied arguments are forwarded unchanged; the layer sizes
        # and the feature extraction mode are fixed by this subclass.
        super(CustomPolicy, self).__init__(
            *args,
            layers=[128, 128, 128],
            feature_extraction="mlp",
            **kwargs
        )
# Build the CartPole environment, then wrap it in a single-env DummyVecEnv
cartpole_env = gym.make('CartPole-v1')
env = DummyVecEnv([lambda: cartpole_env])

model = A2C(CustomPolicy, env, verbose=1)

# Train the agent, then save it under the name "test_save"
model.learn(total_timesteps=1000)
model.save("test_save")

del model

# Reload, passing the custom policy class (and the env) explicitly
model = A2C.load("test_save", policy=CustomPolicy, env=env)
Update:
The custom policy is saved as 'policy': __main__.CustomPolicy, whereas the default policies are saved as stable_baselines.common.policies.MlpPolicy. I am currently working on finding a fix, but I'm not sure one exists...
I found another solution, but it is not as clean as explicitly passing the policy: you need to specify __module__ = None in the definition of your custom policy.
The following code works but is not the recommended solution ;)
# Custom MLP policy: two hidden layers of 8 units each
class CustomPolicy(FeedForwardPolicy):
    """FeedForwardPolicy preset with ``layers=[8, 8]`` and "mlp" feature extraction."""

    # Workaround from this thread (not the recommended solution): clearing
    # __module__ lets the saved model load without passing the policy explicitly.
    __module__ = None

    def __init__(self, *args, **kwargs):
        # Caller-supplied arguments are forwarded unchanged; the layer sizes
        # and the feature extraction mode are fixed by this subclass.
        super(CustomPolicy, self).__init__(
            *args,
            layers=[8, 8],
            feature_extraction="mlp",
            **kwargs
        )
Most helpful comment
@pathway When using a custom policy, you need to pass it explicitly when loading the model.
Otherwise, it can't find the definition of your policy (that's why it works when you put the policy definition in common/policies.py). I'll add a warning to the documentation.