It's been noted a few times that TFPolicyGraph is hard to understand for new developers. I think this is due to a few reasons:
Similar problems apply to the TorchPolicyGraph (though there one solution may be to just extend PolicyGraph directly, which is done for QMIX and seems cleaner).
This issue is to solicit suggestions on how to improve this. A few ideas:
cc @hartikainen @richardliaw @gehring @pcmoritz @joneswong @eugenevinitsky
@ericl I think all the problem points are correct. In general, I find that several of the classes are trying to do too many things, which 1) bloats the class, making it hard to parse, and 2) prevents code from being placed in more intuitive places.
Also, the nomenclature of classes inside rllib is throwing me off at times. For example, the structure and design of the Agent class (and its subclasses) feels very unnatural to me. This is because it doesn't look like any other agent implementation I've ever seen and is far different from what I would implement if I were asked to implement an agent from scratch. After spending some time going through the code, I understand that it is structured that way to fit the Trainable API, but some parts of it still feel weird to me (e.g., Agent has three initialization methods: _setup, _init, __init__).
I know you want to focus on TFPolicyGraph, but I think properly simplifying that class will involve thinking about the whole flow, and will possibly require some refactoring in several classes.
As for your specific suggestions:
Break up the TFPolicyGraph class into smaller pieces (e.g., separate functions) that are combined together with some "builder" function.
100% agree on breaking things up. I think very specific, light classes make for the most readable code. I'm a bit apprehensive about the builder approach though.
The reason for this is that a major source of frustration for me comes from encountering a function call which executes code that I can't quickly locate. As a result, understanding some aspect of a class requires a much larger effort than it should, and often requires an understanding of many more classes than you really want to know about.
This can be caused by lots of things (sometimes my own stupidity gets in the way), but I commonly see it happen when class constructors are carried around, when the code relies on the side effects of methods called elsewhere (e.g., much earlier in time and from some class not relevant to what I'm trying to debug/understand), or when functions get assigned dynamically.
I would be afraid that a builder-like pattern would make the code harder to parse, though I am still open to the idea.
Instead of overrides, pass explicit functions into the TFPolicyGraph constructor?
I think if the classes are more specialized and leaner, overrides should be fine. Related to my last point, I would expect that passing functions will make the code hard to read for first-time users.
Better documentation and examples?
That is always good!
For the agent class in particular, I think one tweak we could make is renaming _init(self) to _init(self, config, env_creator), which at least makes the input state explicit rather than implicit through self. Any other suggestions?
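For illustration, here is a rough sketch of what that signature change could look like in a subclass (the body and the _make_evaluator helper are hypothetical, just to show the inputs becoming explicit rather than read off self):

class DQNAgent(Agent):
    # before: def _init(self) -- reads self.config and self.env_creator implicitly
    def _init(self, config, env_creator):
        # the inputs the subclass depends on are now visible in the signature
        # (_make_evaluator is a made-up helper, just for illustration)
        self.local_evaluator = self._make_evaluator(env_creator, config)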
I was also thinking about a build-like pattern for Agent itself, though I still don't know if it's a net win. The code would certainly be clearer at a glance, but you would be adding another abstraction. It could look something like this:
def initial_state(config):
    return {
        "last_target_update": 0,
    }

def nstep_post(traj):
    return traj

def update_target_if_needed(state, optimizer, policy_graph, config):
    if optimizer.num_sampled - state["last_target_update"] > config["target_upd_freq"]:
        state["last_target_update"] = optimizer.num_sampled
        policy_graph.update_target()

DQNAgent = Agent.build(
    name="DQN",
    default_config=DEFAULT_CONFIG,
    policy_graph=DQNPolicyGraph,
    postprocessing=nstep_post,
    initial_state=initial_state,
    optimizer=SyncSamplesOptimizer,
    after_optimizer_step=update_target_if_needed)
It would be even better to split DQNPolicyGraph into DQNLoss and DQNPolicy. I don't know how hard this would be though. Perhaps if DQNLoss took in DQNPolicy as a constructor parameter.
Concretely, I think one approach would be to split TFPolicyGraph into TFLoss and TFPolicy (these two classes provide the common utilities for building up learn_on_batch(), compute_actions(), and so on). This should be a pretty straightforward refactoring if TFLoss can have a reference to the policy object.
For example, for the simple policy grad algorithm there would be three classes:
pg_policy = PGPolicy(obs_space, action_space, ...)
pg_loss = PGLoss(pg_policy)
pg_policy_graph = TFPolicyGraph(PGPolicy, PGLoss, compute_advantages_postprocessor)
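To make the split a bit more concrete, here is a minimal sketch of a loss class that holds a reference to the policy (the interfaces and the action_dist.logp accessor are assumptions for illustration, not the actual RLlib API):

import tensorflow as tf

class PGLoss:
    """Policy-gradient loss defined in terms of an existing policy's outputs."""
    def __init__(self, policy, actions, advantages):
        # keep a reference to the policy so the loss can reuse its action distribution
        self.policy = policy
        self.loss = -tf.reduce_mean(policy.action_dist.logp(actions) * advantages)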
Note that we still need to preserve the concept of a policy graph class, which is just to avoid "boiling the ocean" in this refactor. A more aggressive later refactoring could also get rid of that concept entirely.
I'm using rllib on a problem with continuous time. I ran into some of the complexity described in this discussion while implementing the equivalent of pcontinues from DeepMind's trfl (trfl/value_ops.py) for DQNPolicyGraph, so I figured I would add my experience to the discussion. As a disclaimer, I haven't worked with rllib for very long, so maybe there are workarounds for these issues I'm just not aware of.
The first, sort of trivial issue is that DQNAgent uses the DQNPolicyGraph by default. I tried to keep my changes local to DQNPolicyGraph but ended up duplicating DQNAgent to make using a custom graph simple. The build-like pattern described above would seem to solve this problem, but any other way of overriding the defaults would work as well.
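For illustration, the kind of small override I was hoping for would look something like this, assuming the agent exposes its graph through a _policy_graph class attribute (an assumption about the layout; in practice I ended up copying the agent):

class CustomDQNAgent(DQNAgent):
    # reuse everything from DQNAgent, only swap in the modified graph
    _agent_name = "CustomDQN"
    _policy_graph = CustomDQNPolicyGraph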
Second, I thought the most convenient way to pass the information about environment timesteps needed for computing the proper discounting would be to add fields like env_time_t, env_time_tp1 to the sample_batch, which is passed to postprocess_trajectory. However, it seems that the fields in a batch are manipulated by multiple classes on the path from the environment to compute_td_error, so to avoid having to change too much I just included the time as a part of the observations. Changing the batch handling to something more flexible might not be in the scope of simplifying the policy graph code, so maybe this idea is more for the "aggressive later refactoring".
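For context, a minimal sketch of the discounting I was after, roughly following trfl's pcontinues idea and assuming the per-transition elapsed time dt is available (hypothetical helper, not existing rllib code):

import numpy as np

def continuous_time_targets(rewards, next_values, dt, gamma):
    # the discount for each transition depends on how much env time elapsed,
    # instead of a fixed gamma per step (analogous to trfl's pcontinues argument)
    pcontinues = gamma ** np.asarray(dt)
    return rewards + pcontinues * next_values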
I was also thinking about a build-like pattern for Agent itself, though I still don't know if it's a net win.
In general, I really like the move towards more explicit passing of arguments (e.g., that Agent.build example and the changes to _init()). In a perfect world, I would split the config dicts such that all relevant parameters are passed as keyword arguments. Being able to see what parameters can be configured just by looking at the constructor is very transparent, while packed configs are very opaque.
Just to illustrate what I am referring to:
# seeing what can be configured is tedious
class SomeAgentV1:
    def __init__(self, config):
        self.config = config

    # this function could be somewhere else, i.e., specified by a super/sub-class
    def setup(self):
        # do stuff with param1 and param2, maybe build the model here
        ...
        do_something_with_param1(self.config["my_agent_param1"])
        ...
        do_something_with_param2(self.config["my_agent_param2"])
        ...
        # I probably have to go through the function's code just to know what to specify.
        self.model = self.my_model_constructor(self.config["model_config"])
# seeing what can be configured is easy
class SomeAgentV2:
    def __init__(self, my_agent_param1, my_agent_param2, model_config):
        self.param1 = my_agent_param1
        self.param2 = my_agent_param2
        # if I look up this function's signature, I will know right away what the config can be
        self.model = self.my_model_constructor(**model_config)

    def setup(self):
        # do stuff with param1 and param2, maybe build the model here instead
        ...
        do_something_with_param1(self.param1)
        ...
        do_something_with_param2(self.param2)
However, I'm not suggesting this is something worth refactoring everything for. As you say, let's keep the ocean from boiling ;)
Otherwise, at a high level, it still isn't completely clear to me why Agent subclasses must have that level of complexity. I think part of it comes from confusion about what Trainable is (at an intuitive level). I think tune's Trainable API is great, but its purpose differs from that of an agent API. Splitting the training logic from the agent, e.g., having an Agent and an AgentTrainable/RLLibTrainable class, could go a long way in making the code clear and accessible for users who wish to contribute agents.
As for the builder approach specifically, it looks like this is an alternative to splitting the Agent class. I think splitting things will make for code that is easier to follow, but the Agent.build would likely be easier to integrate. Those ideas aren't necessarily mutually exclusive, though I feel that if the classes were lighter, the builder wouldn't feel as useful. What do you think?
Concretely, I think one approach would be to split TFPolicyGraph into TFLoss and TFPolicy
I think this is a good idea!
Ultimately, I would like the concept of an "Agent" to go away. For one, it's confusing when dealing with multi-agent environments, where the actual agents are entities in the environment.
The right name is probably something like "Algorithm" or "Trainer", which makes the following read nicely:
trainer = IMPALATrainer(config, env)
trainer.train()
trainer.get_policy("pol1")
Re: passing configs explicitly: I'm not sure this is practical for the top-level config, given the huge number of common configs an agent has (and in particular, the need for some agents to override the default values of some of these configs). We should definitely try to pass them explicitly to sub-components though. For example, this is already done for PolicyEvaluator and subclasses of PolicyOptimizer, but not for models.
Re: splitting Agent: I'm not sure what this means, given how small most agent subclasses are (if you ignore all the boilerplate configuration validation). We could probably actually eliminate the need for subclasses if the agent superclass got a bit more functionality.
Ultimately, I would like the concept of an "Agent" to go away. For one, it's confusing when dealing with multi-agent environments, where the actual agents are entities in the environment.
I think that is a good solution! I agree with you. This would likely make the code clearer and more intuitive.
Re: splitting Agent
I think with what you just proposed, this isn't relevant anymore. I was trying to solve the same problem but you are probably right, there wouldn't be much left if you try to separate the idea of an agent from the "Trainer".
Re: passing configs explicitly
I completely agree! I'm sorry, I was a little unclear. I think it is completely fair to have a packed high-level config passed around by the high-level API. My comment was mostly about how and where the config is unpacked: instead of unpacking it "dynamically" inside various functions, configs could be unpacked in the constructor call or some other function call (which is quite convenient with the my_function(**config) syntax).
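Just a tiny sketch of what I mean by unpacking at the boundary (all names here are made up):

# the packed config can still travel through the high-level API...
config = {"optimizer": {"lr": 1e-4, "momentum": 0.9}}
# ...but is unpacked exactly once, at the sub-component's constructor, so the
# constructor signature documents what can actually be configured
optimizer = MyOptimizer(**config["optimizer"])  # MyOptimizer(lr=..., momentum=...) is hypothetical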
Here's a shot at a build-like pattern for trainer: https://gist.github.com/ericl/0d3502f204c7612a429bfd3c3aba0307
A plus I think is that it makes the overall training workflow more clear at a glance:
PPOTrainer = build_trainer(
    "PPO",
    default_config=DEFAULT_CONFIG,
    policy_graph=PPOPolicyGraph,
    make_optimizer=make_optimizer,
    validate_config=validate_config,
    after_optimizer_step=update_kl,
    before_train_step=warn_about_obs_filter,
    after_train_result=warn_about_bad_reward_scales)
This is something that could also be achieved by using more helper methods in the trainer subclasses though. I guess the build-like pattern is kind of a forcing function for that.
So here's what I'm thinking are possible next steps:
Adopt the build_trainer pattern, and document it as the recommended way to add a new algorithm / modify a training workflow. This could be done incrementally.
Here's a start at modularization (only implemented for the simple pg example): https://github.com/ray-project/ray/pull/4795/files
You can take a look at the patch for more details, but at a high level it lets you compose policy graphs and trainers much more concisely. Also, the loss function is defined "eagerly", so there is no need for placeholders (the placeholders are auto-created based on the policy outputs).
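To give a rough idea of what that composition could look like (the builder and argument names below are illustrative sketches based on this thread, not necessarily what the patch uses verbatim), a policy-gradient policy graph might be assembled from a plain loss function:

import tensorflow as tf

def policy_gradient_loss(policy, batch):
    # the loss is written directly against batch tensors; the framework can
    # auto-create matching placeholders from the policy and batch spec
    return -tf.reduce_mean(
        policy.action_dist.logp(batch["actions"]) * batch["advantages"])

PGPolicyGraph = build_tf_policy_graph(  # hypothetical builder name
    name="PGPolicyGraph",
    loss_fn=policy_gradient_loss,
    postprocess_fn=compute_advantages_postprocessor)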