I would like to do some model-based RL with stable-baselines. From what I've read, it seems that to do this I just have to train a model f to predict the next state given the current state and action, i.e. s_{t+1} = f(s_t, a_t). I want to do this so that the agent can plan several steps into the future.
Anyway, I need access to the observations and actions inside PPO2's `learn` method. After each rollout, I want to use the collected data to train f in a supervised manner. I am trying to do this with a callback. Is this possible?
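The idea above (fit f on collected transitions, then apply it repeatedly to look several steps ahead) can be sketched as follows. This is only a minimal illustration, assuming a linear least-squares model as a stand-in for whatever network you would actually use for f; the function names are hypothetical and not part of stable-baselines.

```python
import numpy as np


def fit_linear_dynamics(s, a, s_next):
    """Least-squares fit of s_next ~ [s, a, 1] @ W, a linear stand-in for f."""
    n = len(s)
    s = np.asarray(s, dtype=np.float64).reshape(n, -1)
    a = np.asarray(a, dtype=np.float64).reshape(n, -1)
    targets = np.asarray(s_next, dtype=np.float64).reshape(n, -1)
    features = np.hstack([s, a, np.ones((n, 1))])  # append a bias column
    W, *_ = np.linalg.lstsq(features, targets, rcond=None)
    return W


def predict_next(W, s, a):
    """One model step: s_{t+1} = f(s_t, a_t)."""
    x = np.concatenate([np.ravel(s), np.ravel(a), [1.0]])
    return x @ W


def imagine_rollout(W, s0, actions):
    """Apply f repeatedly to predict s_1..s_H for a candidate action sequence."""
    states = [np.ravel(s0).astype(np.float64)]
    for a in actions:
        states.append(predict_next(W, states[-1], a))
    return np.stack(states)  # shape (H + 1, obs_dim), including s0
```

For planning, one simple option would be to call `imagine_rollout` on several candidate action sequences and score the predicted states, but how you use the imagined trajectories is a separate design choice.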
Thanks!
I am not sure if this is doable with callbacks alone, but modifying the existing code should not be too difficult. The environment steps used for learning are performed here, and the surrounding code collects the states, actions, etc. for the next update, which sounds like the part you want to modify.
PS: I am not a model-based RL guy, so I cannot comment on the idea you suggested.
Thank you for your quick reply! Sounds good, thanks!
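Hi, you can do that through callbacks. You just need to override `_on_rollout_end` in your callback (a subclass of `BaseCallback`) along these lines:
```python
import numpy as np

def _on_rollout_end(self):
    # Runner locals passed in via callback.update_locals() during collection
    mb_obs = self.locals["mb_obs"]
    mb_rewards = self.locals["mb_rewards"]
    # The callback has no self.obs; the list items are already typed arrays
    mb_obs = np.asarray(mb_obs)
    mb_rewards = np.asarray(mb_rewards, dtype=np.float32)
```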
The code that collects the transitions is here:
Notice that while collecting transitions, the runner calls `callback.update_locals`, which passes the runner's local variables (e.g. `mb_obs`, `mb_rewards`) to the callback. You can then access them through `self.locals`.
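Building on that, here is a minimal sketch of turning the collected lists into (s_t, a_t, s_{t+1}) training pairs for f. It assumes the PPO2 runner's ordering, where `mb_dones[t]` is recorded before the environment step and therefore marks `mb_obs[t]` as the first observation of a new episode (VecEnvs reset automatically), so pairs that cross an episode boundary are dropped. The last step of each rollout is also skipped, since its next observation is not in `mb_obs`. The function name `rollout_to_transitions` is hypothetical, not part of stable-baselines.

```python
import numpy as np


def rollout_to_transitions(mb_obs, mb_actions, mb_dones):
    """Turn per-step rollout lists into (s_t, a_t, s_{t+1}) arrays.

    Assumes mb_dones[t] is True when mb_obs[t] starts a new episode,
    i.e. mb_obs[t] is a reset observation, not a true successor.
    """
    obs = np.asarray(mb_obs)                  # (n_steps, n_envs, *obs_shape)
    actions = np.asarray(mb_actions)          # (n_steps, n_envs, ...)
    dones = np.asarray(mb_dones, dtype=bool)  # (n_steps, n_envs)
    # Pair step t with step t + 1, unless obs[t + 1] begins a new episode
    valid = ~dones[1:]                        # (n_steps - 1, n_envs)
    s = obs[:-1][valid]
    a = actions[:-1][valid]
    s_next = obs[1:][valid]
    return s, a, s_next
```

Inside `_on_rollout_end` it could be called as `rollout_to_transitions(self.locals["mb_obs"], self.locals["mb_actions"], self.locals["mb_dones"])`, assuming those keys are present in your stable-baselines version; the resulting arrays can then be used to fit f.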
Note: This is in the latest version of SB2; I'm not sure whether it has been pushed to PyPI yet.
@PartiallyTyped thanks for the response. In fact, I had tried printing the locals but didn't see the obs in there. That makes sense if it's not on PyPI yet.
It’s not on PyPI yet; patch #787 was completed only recently (^^’).
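Thanks again for commenting. In case anyone wants to use the new patch, you can install it with the following (the leading `!` is for a Jupyter/Colab cell; drop it in a regular shell):
```
!pip install git+https://github.com/hill-a/stable-baselines.git
```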
This is actually already in the documentation ;)
https://stable-baselines.readthedocs.io/en/master/guide/install.html#bleeding-edge-version