Stable-baselines: Converting a model into PyTorch

Created on 14 Jun 2019 · 7 comments · Source: hill-a/stable-baselines

Hi,

I'm trying to load a pre-trained model and convert it into a PyTorch model, but I can't get it to work and was wondering if someone could help me.

I'm able to load the pre-trained model using Stable Baselines and copy the weights over to a PyTorch model. But when I play the game, the PyTorch agent does not get the same score as the baselines agent, and I am not sure why. It could potentially be because the baselines agent does some extra pre-processing behind the scenes besides just normalising the state to the 0-1 range? Is anyone able to help me?

I've made a colab to demonstrate the problem here: https://colab.research.google.com/drive/1-IIjA1oKUjg5eoctajpl06OoHzU-5-_9


Hello,

> It could potentially be because the baselines agent does some extra pre-processing behind the scenes besides just normalising the state to the 0-1 range?

There is no such pre-processing, to my knowledge. However, I would recommend first trying a simpler network (e.g. on CartPole-v1), because things get tricky with convolutions (PyTorch and TensorFlow use different conventions).
I would also suggest checking all shapes and names before assigning the weights, using .named_parameters() for PyTorch.
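
A minimal sketch of the layout conversions involved, assuming the TensorFlow weights have already been extracted as NumPy arrays and keyed by the PyTorch parameter names (random arrays stand in for them here; `SmallNet`, `tf_to_torch`, and the layer sizes are hypothetical). TF stores a dense kernel as (in, out) and a conv kernel as (kh, kw, in, out), whereas PyTorch's `nn.Linear.weight` is (out, in) and `nn.Conv2d.weight` is (out, in, kh, kw). Asserting each converted shape against `.named_parameters()` before copying catches most mistakes early.

```python
import numpy as np
import torch
import torch.nn as nn


class SmallNet(nn.Module):
    """Hypothetical network used only to illustrate weight copying."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4)
        self.fc1 = nn.Linear(512, 6)


def tf_to_torch(name, array):
    """Convert a TF-layout kernel or bias (NumPy array) to PyTorch layout."""
    t = torch.from_numpy(array)
    if name.endswith("weight") and t.dim() == 4:
        # TF conv kernel (kh, kw, in, out) -> PyTorch (out, in, kh, kw)
        return t.permute(3, 2, 0, 1).contiguous()
    if name.endswith("weight") and t.dim() == 2:
        # TF dense kernel (in, out) -> PyTorch (out, in)
        return t.t().contiguous()
    # Biases are 1-D in both frameworks
    return t


# Random stand-ins for weights that would be read from the TF model
rng = np.random.default_rng(0)
tf_weights = {
    "conv1.weight": rng.standard_normal((8, 8, 4, 32)).astype(np.float32),
    "conv1.bias": rng.standard_normal(32).astype(np.float32),
    "fc1.weight": rng.standard_normal((512, 6)).astype(np.float32),
    "fc1.bias": rng.standard_normal(6).astype(np.float32),
}

net = SmallNet()
with torch.no_grad():
    for name, param in net.named_parameters():
        converted = tf_to_torch(name, tf_weights[name])
        # Fail loudly instead of silently copying a mis-shaped tensor
        assert converted.shape == param.shape, (name, converted.shape, param.shape)
        param.copy_(converted)
        print(name, tuple(param.shape))
```

A shape check cannot catch a missing transpose on a square dense kernel (e.g. 64x64), so comparing both models' outputs on the same input is still worthwhile.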

EDIT: I made it work with CartPole and will share the notebook soon.

Update: I made it work with CartPole, you can find the notebook here: https://colab.research.google.com/drive/1R-wHO2gLQScx46EIjqj7Sj6MjK-i5Hey

I will try to make it work with the CNN if I have some time this weekend.

Update: I'm working on the CNN now; it seems that the problem comes from the first fully connected layer (the conv layers output the right thing).

The problem comes from the reshape between the conv output and the fully connected layer.
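
To see why the reshape matters: TensorFlow's conv output is NHWC (batch, height, width, channels) while PyTorch's is NCHW, so flattening the same activations yields the features in a different order, and the first dense layer's weights no longer line up with their inputs. A small check, using a random tensor as a stand-in for the 7x7x64 conv output of this network:

```python
import torch

torch.manual_seed(0)

# Stand-in conv output in PyTorch layout: (batch, channels, height, width)
x_nchw = torch.randn(1, 64, 7, 7)

# The same activations in TensorFlow layout: (batch, height, width, channels)
x_nhwc = x_nchw.permute(0, 2, 3, 1)

# Flatten each the way its framework would
flat_torch = x_nchw.reshape(1, -1)  # feature order: c, h, w
flat_tf = x_nhwc.reshape(1, -1)     # feature order: h, w, c

# Same values, different order, so dense weights that expect flat_tf's
# order receive mismatched inputs when fed flat_torch
print(torch.equal(flat_torch, flat_tf))  # False
print(torch.equal(flat_torch.sort().values, flat_tf.sort().values))  # True
```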

@p-christ, I solved the issue by doing this before flattening:

# shape before flattening
# tf: (?, 7, 7, 64)  (NHWC)
# pytorch: [1, 64, 7, 7]  (NCHW)
# reorder to NHWC so the flattened features match TF's order
x = x.permute(0, 2, 3, 1).contiguous()
x = x.view(x.size(0), -1)

@p-christ I made a working colab notebook: https://colab.research.google.com/drive/1XwCWeZPnogjz7SLW2kLFXEJGmynQPI-4

The problem came from TensorFlow/PyTorch differences, not from SB.

Closing the issue.

Thanks a lot!