Ray: [RLlib] Loading keras weights in a Model and using the Tune API

Created on 9 May 2020 · 12 comments · Source: ray-project/ray

What is your question?

Similar to #7046, I'd like to load the weights of my keras model that was trained outside of RLlib. But I'd also like to use Tune for hyperparameter search. How do I load the pretrained model and also use tune.run for my experiments?
Ray version and other system information (Python version, TensorFlow version, OS):
Ray: 0.8.2, TF: 2.0, Linux

Labels: question, tune

All 12 comments

cc @richardliaw the new tune functional API could work here. I guess we need a way of specifying the resource requirements that the RLlib Trainer currently defines via a static method.

Hypothetically, you can pass in a function to resources_per_trial.

You mean an API like this?

def train(config, checkpoint=None):
    trainer = PPOTrainer(config=config)
    if checkpoint:
        # restore the pretrained state on the local policy, then broadcast to all workers
        trainer.workers.local_worker().get_policy().set_state(...)
        trainer.workers.sync_weights()
    while True:
        yield trainer.train()


tune.run(train, config=config, resources_per_trial=PPOTrainer.resource_request)

yeah

Hi, is there any update on this issue?

I've been trying to load the keras weights when building the model, but it seems like the experiments using tune still start from scratch.

Can you post what you're doing, the error you're getting, and the behavior you're expecting?

@richardliaw

Sure.

I'm doing RL with a custom environment and a custom model using tune.run(agent, **run_kwargs) with ray 0.8.6 and tf = try_import_tf() (so graph mode under tf2).

We have working code and the agent learns to solve the problem.

Then we used an external dataset to train the model in tf2 with model.fit; the trained model achieved satisfying performance, say reward > 0.5 on a reward range of [0, 1]. The objective is to load this model into the RL runs to boost training.

The custom model is roughly defined as follows:

class CustomModel(ABC, TFModelV2):
    def __init__(self, obs_space, action_space, num_outputs, model_config: Dict, name: str):
        super(CustomModel, self).__init__(obs_space, action_space, num_outputs, model_config, name)

        # create the model and register its variables
        self.model = self.build_model(model_config["custom_model_config"], name)
        self.register_variables(self.model.variables)
        self.model.summary()

    def build_model(self, model_config, name):
        inputs = ...
        logits, value = BackboneModel(
            model_config=model_config,
            name=name,
        )(inputs)

        model = tf.keras.Model(inputs, [logits, value])
        return model

So the problem is how to load the weights from an h5 file saved by the TensorFlow ModelCheckpoint callback. We have the weights of both the entire model and of the backbone alone.

checkpoint_callback = tf.keras.callbacks.ModelCheckpoint("weights.h5", save_weights_only=True)

We've tried:

loading the entire model before/after register_variables:

# variant 1: load the weights before register_variables
self.model = self.build_model(model_config["custom_model_config"], name)
self.model.load_weights("weights.h5")
self.register_variables(self.model.variables)

# variant 2: load the weights after register_variables
self.model = self.build_model(model_config["custom_model_config"], name)
self.register_variables(self.model.variables)
self.model.load_weights("weights.h5")

For these methods, I printed the weights after load_weights and the values did seem to change.
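
Roughly, the check was something like the following minimal sketch (assuming numpy is imported as np and the snippet runs right after building the model):

before = [w.copy() for w in self.model.get_weights()]
self.model.load_weights("weights.h5")
after = self.model.get_weights()
# if loading worked, at least some of the arrays should differ from their initial values
print("any weights changed:", any(not np.allclose(b, a) for b, a in zip(before, after)))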

Then we tried loading only the backbone model inside build_model:

    def build_model(self, model_config, name):
        inputs = ...
        backbone = BackboneModel(
            model_config=model_config,
            name=name,
        )
        logits, value = backbone(inputs)
        backbone.load_weights("backbone_weights.h5")

        model = tf.keras.Model(inputs, [logits, value])
        return model

However, none of these methods made a noticeable difference to the RL runs.

We suspect the weights are not actually being restored properly. So the question is: what is the correct way to load a keras model for experiments run via tune.run?

We cannot use restore here, as I believe it expects an RLlib checkpoint, which is different from keras checkpoints or h5 files?

Later I will try to reproduce this problem using the CartPole environment and a toy network, to check whether the loading still fails.

Ideally, we would like to just pass a file path via the config and define a function like the one below on the custom model, so that all workers load the given weights. Is this possible?

    def import_from_h5(self, import_file):
        """Override this to define custom weight loading behavior from h5 files.

        See https://github.com/ray-project/ray/issues/7046
        """
        self.model.load_weights(import_file)
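
Roughly, what we have in mind is something like the following sketch (hypothetical names; the pretrained_weights key is our own, not an existing RLlib option):

# pass the weight file path through the model config
config["model"] = {
    "custom_model": "custom_model",
    "custom_model_config": {"pretrained_weights": "/path/to/weights.h5"},
}

# and inside CustomModel.__init__, after register_variables:
weights_path = model_config["custom_model_config"].get("pretrained_weights")
if weights_path:
    self.model.load_weights(weights_path)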

Thanks ;)

Hmm, yeah, I think this is probably not exactly related to Tune. It sounds like the training initialization is somehow overriding your model weights. Two tips:

  1. Try setting lr = 0 - see if it matches the performance you expect.
  2. Try printing or watching the model weights during training - checking that the model weights don't change dramatically will help.
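
For tip 2, a rough sketch (with trainer being your PPOTrainer and numpy imported as np) that prints the norm of each policy weight after every training iteration:

policy = trainer.get_policy()
for i in range(5):
    trainer.train()
    # policy.get_weights() returns a dict of variable name -> numpy array
    for name, arr in policy.get_weights().items():
        print(i, name, float(np.linalg.norm(arr)))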

https://github.com/keras-team/keras/issues/2378#issuecomment-394126640

@richardliaw Thanks! Will give it a try!

So finally I have a working script using a custom env and network.
I suspect the problem comes from saving/loading the weights via the h5 file, as mentioned in #7046.

So with a numpy-based approach this works: the second tune run, which loads the saved weights, starts with a reward > 100, and with lr=0 the optimizer does not influence the weights.


import ray
import ray.rllib.agents.ppo as ppo
from ray.rllib.models import ModelCatalog
from ray.rllib.models.tf.misc import normc_initializer
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils import try_import_tf
import gym
from gym.spaces import Discrete, Box
from ray import tune
import os
import shutil
import logging
import numpy as np

tf = try_import_tf()

# Change this directory
out_dir = "/Users/mathpluscode/ray_results/load_tune"


class MyKerasModel(TFModelV2):
    """Custom model for policy gradient algorithms."""

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        super(MyKerasModel, self).__init__(obs_space, action_space, num_outputs, model_config, name)
        self.inputs = tf.keras.layers.Input(shape=obs_space.shape, name="observations")
        layer_1 = tf.keras.layers.Dense(
            16, name="layer1", activation=tf.nn.relu, kernel_initializer=normc_initializer(1.0)
        )(self.inputs)
        layer_out = tf.keras.layers.Dense(
            num_outputs, name="out", activation=None, kernel_initializer=normc_initializer(0.01)
        )(layer_1)
        value_out = tf.keras.layers.Dense(
            1, name="value", activation=None, kernel_initializer=normc_initializer(0.01)
        )(layer_1)
        self.base_model = tf.keras.Model(self.inputs, [layer_out, value_out])

        self.register_variables(self.base_model.variables)

    def forward(self, input_dict, state, seq_lens):
        model_out, self._value_out = self.base_model(input_dict["obs"])
        return model_out, state

    def value_function(self):
        return tf.reshape(self._value_out, [-1])

    def import_from_h5(self, import_file):
        # Override this to define custom weight loading behavior from h5 files.
        self.base_model.load_weights(import_file)


class SimpleCorridor(gym.Env):
    def __init__(self, config):
        self.end_pos = config["corridor_length"]
        self.cur_pos = 0
        self.action_space = Discrete(2)
        self.observation_space = Box(0.0, self.end_pos, shape=(1,))

    def reset(self):
        self.cur_pos = 0
        return [self.cur_pos]

    def step(self, action):
        if action == 0 and self.cur_pos > 0:
            self.cur_pos -= 1
        elif action == 1:
            self.cur_pos += 1
        done = self.cur_pos >= self.end_pos
        return [self.cur_pos], 1 if done else 0, done, {}


def get_weights_from_agent(agent):
    return agent.get_weights()["default_policy"]["default_policy/value/kernel"][0]


def train_fn_save(config, reporter):
    agent = ppo.PPOTrainer(config=config)
    print("SAVE: after init, before train", get_weights_from_agent(agent))
    for _ in range(10):
        result = agent.train()
        reporter(**result)
    print("SAVE: after train, before save", get_weights_from_agent(agent))
    np.save(os.path.join(out_dir, "weights.npy"), agent.get_policy().get_weights())
    print("SAVE: after save", get_weights_from_agent(agent))
    agent.stop()


def train_fn_load(config, reporter):
    agent = ppo.PPOTrainer(config=config)

    weight_path = os.path.join(out_dir, "weights.npy")
    print("LOAD: after init, before load", get_weights_from_agent(agent))

    weight = np.load(weight_path, allow_pickle=True).item()
    agent.workers.local_worker().get_policy().set_weights(weight)
    agent.workers.sync_weights()

    print("LOAD: before train, after load", get_weights_from_agent(agent))
    for _ in range(5):
        result = agent.train()
        reporter(**result)
    agent.stop()


if __name__ == "__main__":
    ray.init(logging_level=logging.ERROR)

    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir)
    assert os.path.exists(out_dir)

    ModelCatalog.register_custom_model("keras_model", MyKerasModel)
    config = {
        "log_level": "ERROR",
        "lr": 0.01,
        "num_workers": 0,
        "model": {"custom_model": "keras_model"},
        "env": "CartPole-v0",
        "env_config": {"corridor_length": 5},
    }

    agent = ppo.PPOTrainer

    # generate weights
    analysis = tune.run(train_fn_save, config=config)
    print(analysis.get_best_trial("episode_reward_mean"))

    # load weights
    config["lr"] = 0.0  # remove learning
    tune.run(train_fn_load, config=config)

    # load weights to agent
    agent_instance = ppo.PPOTrainer(config)
    weight = np.load(os.path.join(out_dir, "weights.npy"), allow_pickle=True).item()
    agent_instance.get_policy().set_weights(weight)
    print("weight after agent save_weights", get_weights_from_agent(agent_instance))
    result = agent_instance.train()
    print(result["episode_reward_mean"])

And here is a working example with an h5 file; the problem was that the keras save/load calls were not running under the policy's graph and session.

import ray
import ray.rllib.agents.ppo as ppo
from ray.rllib.models import ModelCatalog
from ray.rllib.models.tf.misc import normc_initializer
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils import try_import_tf
import gym
from gym.spaces import Discrete, Box
from ray import tune
import os
import shutil
import logging

tf = try_import_tf()

# Change this directory
out_dir = "/Users/mathpluscode/ray_results/load_tune"


class MyKerasModel(TFModelV2):
    """Custom model for policy gradient algorithms."""

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        super(MyKerasModel, self).__init__(obs_space, action_space, num_outputs, model_config, name)
        self.inputs = tf.keras.layers.Input(shape=obs_space.shape, name="observations")
        layer_1 = tf.keras.layers.Dense(
            16, name="layer1", activation=tf.nn.relu, kernel_initializer=normc_initializer(1.0)
        )(self.inputs)
        layer_out = tf.keras.layers.Dense(
            num_outputs, name="out", activation=None, kernel_initializer=normc_initializer(0.01)
        )(layer_1)
        value_out = tf.keras.layers.Dense(
            1, name="value", activation=None, kernel_initializer=normc_initializer(0.01)
        )(layer_1)
        self.base_model = tf.keras.Model(self.inputs, [layer_out, value_out])

        self.register_variables(self.base_model.variables)

    def forward(self, input_dict, state, seq_lens):
        model_out, self._value_out = self.base_model(input_dict["obs"])
        return model_out, state

    def value_function(self):
        return tf.reshape(self._value_out, [-1])

    def import_from_h5(self, import_file):
        # Override this to define custom weight loading behavior from h5 files.
        self.base_model.load_weights(import_file)


class SimpleCorridor(gym.Env):
    def __init__(self, config):
        self.end_pos = config["corridor_length"]
        self.cur_pos = 0
        self.action_space = Discrete(2)
        self.observation_space = Box(0.0, self.end_pos, shape=(1,))

    def reset(self):
        self.cur_pos = 0
        return [self.cur_pos]

    def step(self, action):
        if action == 0 and self.cur_pos > 0:
            self.cur_pos -= 1
        elif action == 1:
            self.cur_pos += 1
        done = self.cur_pos >= self.end_pos
        return [self.cur_pos], 1 if done else 0, done, {}


def get_weights_from_agent(agent):
    return agent.get_weights()["default_policy"]["default_policy/value/kernel"][0]


def train_fn_save(config, reporter):
    agent = ppo.PPOTrainer(config=config)
    print("SAVE: after init, before train", get_weights_from_agent(agent))
    for _ in range(10):
        result = agent.train()
        reporter(**result)
    print("SAVE: after train, before save", get_weights_from_agent(agent))
    with agent.get_policy()._sess.graph.as_default():
        with agent.get_policy()._sess.as_default():
            agent.get_policy().model.base_model.save_weights(os.path.join(out_dir, "weights.h5"))
    print("SAVE: after save", get_weights_from_agent(agent))
    agent.stop()


def train_fn_load(config, reporter):
    agent = ppo.PPOTrainer(config=config)
    print("LOAD: after init, before load", get_weights_from_agent(agent))

    agent.workers.local_worker().get_policy().import_model_from_h5(os.path.join(out_dir, "weights.h5"))
    agent.workers.sync_weights()

    print("LOAD: before train, after load", get_weights_from_agent(agent))
    for _ in range(5):
        result = agent.train()
        reporter(**result)
    agent.stop()


if __name__ == "__main__":
    ray.init(logging_level=logging.ERROR)

    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir)
    assert os.path.exists(out_dir)

    ModelCatalog.register_custom_model("keras_model", MyKerasModel)
    config = {
        "log_level": "ERROR",
        "lr": 0.01,
        "num_workers": 0,
        "model": {"custom_model": "keras_model"},
        "env": "CartPole-v0",
        "env_config": {"corridor_length": 5},
    }

    agent = ppo.PPOTrainer

    # generate weights
    analysis = tune.run(train_fn_save, config=config)
    print(analysis.get_best_trial("episode_reward_mean"))

    # load weights
    config["lr"] = 0.0  # remove learning
    tune.run(train_fn_load, config=config)

Update: instead of defining our own trainable function, subclassing PPOTrainer might be better, as it inherits all of the Trainer's functionality. We only need to load the weights inside the __init__ function.

class LoadablePPOTrainer(ppo.PPOTrainer):
    def __init__(self, config, **kwargs):
        super(LoadablePPOTrainer, self).__init__(config, **kwargs)
        self.workers.local_worker().get_policy().import_model_from_h5(os.path.join(out_dir, "weights.h5"))
        self.workers.sync_weights()

tune.run(LoadablePPOTrainer, stop={"episodes_total": 50}, config=config)
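
The hardcoded path could also come from the config instead, e.g. roughly (the pretrained_h5 entry in env_config is just our own key, not an RLlib option):

class LoadablePPOTrainer(ppo.PPOTrainer):
    def __init__(self, config, **kwargs):
        super(LoadablePPOTrainer, self).__init__(config, **kwargs)
        weights_file = config.get("env_config", {}).get("pretrained_h5")
        if weights_file:
            self.workers.local_worker().get_policy().import_model_from_h5(weights_file)
            self.workers.sync_weights()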

nice!
