Stable-baselines: How can I get the parameters of the trained policy?

Created on 6 Mar 2019 · 3 comments · Source: hill-a/stable-baselines

Sorry for the vague question. Essentially, I want to train two models in two environments, then average their network parameters to produce a new model and see how it performs on both environments. Any advice on how I can go about doing this using stable-baselines?

All 3 comments

Hello,
I recommend taking a look at the save and load methods.
For SAC, for instance, what you are looking for (the weights as NumPy matrices) can be found here.

I'll throw in some sample code I used once. You will need to adapt the tensor names to the policy and architecture you used:

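def get_var(model, name):
    """Return the current value of one of the model's TF variables as a NumPy array."""
    # Graph tensor names carry an output index, e.g. "model/pi/w:0"
    if not name.endswith(":0"):
        name = name + ":0"
    return model.sess.graph.get_tensor_by_name(name).eval(session=model.sess)

# model is a trained stable-baselines model; adapt these names to your policy
w0 = get_var(model, "model/pi_fc0/w")
b0 = get_var(model, "model/pi_fc0/b")
w1 = get_var(model, "model/pi/w")
b1 = get_var(model, "model/pi/b")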
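For the averaging asked about in the question, one option is to read the same tensor names from both trained models and interpolate them element-wise. This is a minimal sketch, assuming both models were built from the same policy class and architecture (so they share tensor names and shapes) and that each getter reads one tensor from one model, the way get_var does. average_named_parameters is a hypothetical helper, not part of stable-baselines, and the demo uses plain floats as stand-ins for the NumPy arrays a real getter would return.

```python
from typing import Any, Callable, Dict, Iterable


def average_named_parameters(
    get_a: Callable[[str], Any],
    get_b: Callable[[str], Any],
    names: Iterable[str],
    weight: float = 0.5,
) -> Dict[str, Any]:
    """Return {name: weight * a + (1 - weight) * b} for every tensor name.

    get_a / get_b each read one named parameter from model A / model B and
    return values supporting + and scalar * (NumPy arrays or plain floats).
    """
    if not 0.0 <= weight <= 1.0:
        raise ValueError("weight must be in [0, 1]")
    averaged = {}
    for name in names:
        value_a = get_a(name)
        value_b = get_b(name)
        # Refuse to rely on broadcasting: both tensors must have the same shape.
        # Plain numbers have no .shape, so both sides compare as None.
        if getattr(value_a, "shape", None) != getattr(value_b, "shape", None):
            raise ValueError(f"shape mismatch for {name!r}")
        # weight=0.5 gives the plain element-wise mean of the two models.
        averaged[name] = weight * value_a + (1.0 - weight) * value_b
    return averaged


if __name__ == "__main__":
    # Stand-in getters; with real models these would read from each model's session.
    fake_a = {"model/pi/w": 1.0, "model/pi/b": 0.0}
    fake_b = {"model/pi/w": 3.0, "model/pi/b": 2.0}
    print(average_named_parameters(fake_a.get, fake_b.get, ["model/pi/w", "model/pi/b"]))
    # → {'model/pi/w': 2.0, 'model/pi/b': 1.0}
```

Exposing weight (rather than hard-coding 0.5) makes it easy to also evaluate interpolations between the two models, not just the midpoint.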

My previous link no longer works after the new PR...
So, in the following code snippet, params is a list of NumPy matrices that represents the weights of the policy and the critic:
https://github.com/hill-a/stable-baselines/blob/aa62f0241bbc3c6ec651628e228241bbd18e63eb/stable_baselines/sac/sac.py#L522
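Because params in that snippet is a plain list in a fixed variable order, two models built from the same policy class should produce lists that line up position by position. This is a minimal sketch that averages two such lists, assuming identical ordering and shapes; average_param_lists is a hypothetical helper, and writing the averaged list back into a model is not shown (the save and load methods mentioned above are the natural place to look).

```python
from typing import Any, List, Sequence


def average_param_lists(params_a: Sequence[Any], params_b: Sequence[Any]) -> List[Any]:
    """Average two parameter lists position by position.

    Assumes both lists come from models with the same architecture, so entry i
    of each list is the same variable (same role and same shape).
    """
    if len(params_a) != len(params_b):
        raise ValueError(f"length mismatch: {len(params_a)} vs {len(params_b)}")
    averaged = []
    for i, (a, b) in enumerate(zip(params_a, params_b)):
        # NumPy arrays expose .shape; plain numbers do not, so both give None.
        shape_a = getattr(a, "shape", None)
        shape_b = getattr(b, "shape", None)
        if shape_a != shape_b:
            raise ValueError(f"shape mismatch at index {i}: {shape_a} vs {shape_b}")
        averaged.append((a + b) / 2)
    return averaged


if __name__ == "__main__":
    # Floats stand in for the weight matrices of two trained models.
    print(average_param_lists([1.0, 2.0], [3.0, 6.0]))  # → [2.0, 4.0]
```

Checking lengths and shapes up front matters here: with NumPy arrays, a mismatched pair could otherwise broadcast silently instead of failing.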