Stable-baselines: Improve docs and tests for model.save() / model.load()

Created on 3 May 2019 · 7 Comments · Source: hill-a/stable-baselines

  • [ ] Document which state is preserved and not preserved when model is reloaded
  • [ ] Add tests to ensure that the state that is documented to be preserved is actually preserved

Original question:

Is there a test to make sure that if I stop training, save the model, reload it from file, and continue training, the result will be identical compared to what would happen if I didn't stop training at all? I am mostly concerned about obscure state such as optimizer's internal variables.
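Here is a minimal sketch of the check being asked about. It uses a toy stand-in rather than a real model. `ToyLearner` is hypothetical, not a stable-baselines class: it runs SGD with momentum on a noisy 1-D quadratic and keeps all of its training state on the instance (weight, optimizer momentum, step counter and RNG), so serializing that state captures everything. The test compares one uninterrupted run against a run that is saved, reloaded and continued. A real test for stable-baselines would replace the toy with `model.save()` / `load()` and compare the resulting weights.

```python
import pickle
import random


class ToyLearner:
    """Hypothetical stand-in for an RL model: SGD with momentum on a noisy 1-D quadratic."""

    def __init__(self, seed=0, lr=0.1, momentum=0.9):
        self.w = 5.0                    # model parameter
        self.velocity = 0.0             # optimizer-internal state
        self.num_timesteps = 0          # step counter
        self.lr = lr
        self.momentum = momentum
        self.rng = random.Random(seed)  # source of "environment" noise

    def learn(self, total_timesteps):
        for _ in range(total_timesteps):
            target = 1.0 + self.rng.gauss(0.0, 0.1)
            grad = 2.0 * (self.w - target)
            self.velocity = self.momentum * self.velocity + grad
            self.w -= self.lr * self.velocity
            self.num_timesteps += 1
        return self

    def save(self):
        """Serialize the complete training state (what a full checkpoint should contain)."""
        return pickle.dumps(self.__dict__)

    @classmethod
    def load(cls, blob):
        learner = cls.__new__(cls)
        learner.__dict__.update(pickle.loads(blob))
        return learner


def test_resume_matches_uninterrupted():
    uninterrupted = ToyLearner(seed=42).learn(100)

    interrupted = ToyLearner(seed=42).learn(40)
    resumed = ToyLearner.load(interrupted.save()).learn(60)

    # Exact equality is expected because every piece of state was restored.
    assert resumed.w == uninterrupted.w
    assert resumed.velocity == uninterrupted.velocity
    assert resumed.num_timesteps == uninterrupted.num_timesteps


if __name__ == "__main__":
    test_resume_matches_uninterrupted()
```

Requiring exact equality, rather than approximate equality, is deliberate. Any state that `save()` forgets (for example, dropping `velocity` from the checkpoint) makes the assertion fail.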

Labels: documentation, enhancement, question

All 7 comments

Hello,

> I am mostly concerned about obscure state such as optimizer's internal variables.

That is a good question. I think several things are not currently stored, and they prevent a perfect recovery from previous training:

  • the content of the replay buffer is not saved for off-policy methods (for disk-usage reasons)
  • the state of the optimizer (e.g. Adam) is not saved
  • the learning rate schedule may be reset when reloading a model / resuming training
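The optimizer and schedule points can be made concrete with a small, self-contained sketch. It does not touch stable-baselines. `ToyAdam`, `linear_schedule` and `train` are hypothetical helpers built on the textbook Adam update, and they minimize (w - 1)^2. The replay-buffer point is not modelled here. Carrying the whole optimizer over reproduces the uninterrupted run exactly. Keeping only the weight resets Adam's moment estimates and its step counter, and with it the learning-rate schedule.

```python
import copy
import math


def linear_schedule(initial_lr, total_steps):
    """Return lr(step): decays linearly from initial_lr to 0 over total_steps."""
    def lr_at(step):
        return initial_lr * max(0.0, 1.0 - step / total_steps)
    return lr_at


class ToyAdam:
    """Textbook Adam update for one scalar parameter (hypothetical, not a library class)."""

    def __init__(self, beta1=0.9, beta2=0.999, eps=1e-8):
        self.beta1, self.beta2, self.eps = beta1, beta2, eps
        self.m = 0.0  # first-moment (mean) estimate
        self.v = 0.0  # second-moment (uncentered variance) estimate
        self.t = 0    # update counter: drives bias correction and the lr schedule

    def step(self, w, grad, lr):
        self.t += 1
        self.m = self.beta1 * self.m + (1 - self.beta1) * grad
        self.v = self.beta2 * self.v + (1 - self.beta2) * grad * grad
        m_hat = self.m / (1 - self.beta1 ** self.t)
        v_hat = self.v / (1 - self.beta2 ** self.t)
        return w - lr * m_hat / (math.sqrt(v_hat) + self.eps)


TOTAL_STEPS = 100
lr_at = linear_schedule(0.1, TOTAL_STEPS)


def train(w, opt, n_steps):
    """Minimize f(w) = (w - 1)^2; the schedule is indexed by the optimizer's counter."""
    for _ in range(n_steps):
        grad = 2.0 * (w - 1.0)
        w = opt.step(w, grad, lr_at(opt.t))
    return w


# 1) Uninterrupted run.
w_full = train(5.0, ToyAdam(), TOTAL_STEPS)

# 2) Interrupted run with the complete optimizer state carried over
#    (deepcopy stands in for a save/load that stores everything).
opt = ToyAdam()
w_half = train(5.0, opt, TOTAL_STEPS // 2)
w_restored = train(w_half, copy.deepcopy(opt), TOTAL_STEPS // 2)

# 3) Interrupted run where only the weight survives: Adam's moments and its
#    counter start over, so the learning-rate schedule restarts at 0.1 as well.
w_weights_only = train(w_half, ToyAdam(), TOTAL_STEPS // 2)

print(w_full == w_restored)  # True: identical sequence of updates
print(w_full, w_weights_only)
```

How far the weights-only result drifts depends on the problem. The sketch only shows that it is no longer guaranteed to match the uninterrupted run.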

However, despite those caveats, I have not (yet) experienced a huge drop in performance when resuming training.

Hm. Does the following sound good?

  1. Close this issue
  2. Open more issues:

    1. "Document which state is preserved and not preserved when model is reloaded"

    2. "Add tests to ensure that the state that is documented to be preserved is actually preserved"

I don't know; I would maybe keep this issue open. What do you think, @hill-a, @erniejunior or @AdamGleave?

I think it makes sense to keep this issue open but maybe break it into a task list.

Created a task list and renamed the issue to reflect its current contents.

Any update on resuming training after an unexpected crash?
I also found that setting reset_num_timesteps=False in the learn function does not produce a continuous TensorBoard curve after resuming. Any suggestions?

This is my code:

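import os

import gym
from stable_baselines import TRPO
from stable_baselines.common.policies import MlpPolicy

env = gym.make('CartPole-v1')
os.makedirs('./tmp_models', exist_ok=True)

# First training run, logging to TensorBoard.
trpo = TRPO(
    MlpPolicy, env,
    timesteps_per_batch=512,
    max_kl=0.001,
    cg_iters=10,
    cg_damping=1e-3,
    entcoeff=0.0,
    gamma=0.99,
    lam=1.0,
    vf_stepsize=1e-4,
    vf_iters=3,
    tensorboard_log='./tmp_tb_logs',
    verbose=1, n_cpu_tf_sess=None)
trpo.learn(total_timesteps=int(5e4))
trpo.save('./tmp_models/trpo_cartpole')

del trpo

# Reload and resume without resetting the timestep counter.
trpo = TRPO.load('./tmp_models/trpo_cartpole', env=env)
trpo.learn(total_timesteps=int(10e4), reset_num_timesteps=False)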
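A toy model of the step counter helps explain whether a resumed run lines up on the TensorBoard x-axis. This sketch uses neither stable-baselines nor TensorBoard, and `logged_steps` is a hypothetical helper. It assumes each logged point is placed at the value of a running timestep counter that advances by 512 steps per update, mirroring `timesteps_per_batch=512` in the snippet above. Under that assumption, the curve continues only if the counter comes back from the saved file with its old value. If the counter is restored as 0, the second run's points land on top of the first run's.

```python
def logged_steps(start_step, n_updates, steps_per_update):
    """x-axis values a logger would use if it plots against a running timestep counter."""
    return [start_step + (i + 1) * steps_per_update for i in range(n_updates)]


STEPS_PER_UPDATE = 512  # mirrors timesteps_per_batch=512 in the snippet above

first_run = logged_steps(0, 4, STEPS_PER_UPDATE)
print(first_run)  # [512, 1024, 1536, 2048]

# Counter restored from the saved model: the second run extends the first curve.
resumed = logged_steps(first_run[-1], 4, STEPS_PER_UPDATE)
print(resumed)    # [2560, 3072, 3584, 4096]

# Counter comes back as 0: the second run is drawn over the first one.
restarted = logged_steps(0, 4, STEPS_PER_UPDATE)
print(restarted)  # [512, 1024, 1536, 2048]
```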

+1 for continuing TensorBoard logging when reset_num_timesteps=False.
