System information
Experiment
I have been experimenting with PBT for training a convnet in PyTorch. Everything is working fine, but I am a bit frustrated with the checkpointing. My _save and _restore methods are similar to those in https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/pbt_convnet_example.py
I am running an experiment with 4 samples, each on a single GPU. My stopping condition is an accuracy threshold.
Questions
1 - Is it possible to stop all the trials as soon as one reaches the stopping condition? For example, in the training below, the blue trial reached the stopping condition and the orange should be killed.

2 - The top performers models are saved in temp directories during training. How is it possible for me to recover them if the training crashes / I want to kill it early? I understand that I could checkpoint every epoch with checkpoint_freq but it seems a bit sub-optimal. Is there a way to save the current best model in the trial local directory?
1 - Is it possible to stop all the trials as soon as one reaches the stopping condition? For example, in the training below, the blue trial reached the stopping condition and the orange should be killed.
see https://ray.readthedocs.io/en/latest/tune-usage.html#custom-stopping-criteria, pass in a stateful custom stopping criteria function as in the example
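For reference, a minimal sketch of such a stateful stopping callable (the metric name and 0.96 threshold simply mirror the example posted later in this thread; the class name is illustrative):

```python
# Sketch of a stateful stopping criterion: once any trial reports an accuracy
# above the threshold, every subsequent result (from any trial) returns True,
# so all remaining trials are stopped as well.
class AccuracyStopper:
    def __init__(self, threshold=0.96):
        self.threshold = threshold
        self.should_stop = False

    def __call__(self, trial_id, result):
        if result["mean_accuracy"] > self.threshold:
            self.should_stop = True
        return self.should_stop

# used as e.g. tune.run(MyTrainable, stop=AccuracyStopper(), ...)
```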
Is there a way to save the current best model in the trial local directory?
Set checkpoint_score_attr to whatever metric you want to use to determine how to score the checkpoints. Set keep_checkpoints_num so that the worst checkpoints are deleted.
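In other words, something along these lines (a sketch; `MyTrainable`, the metric name, and the checkpoint count are placeholders):

```python
# Sketch: rank checkpoints by mean_accuracy and only keep the best ones on
# disk, deleting the worst as new checkpoints are written.
analysis = tune.run(
    MyTrainable,                            # placeholder trainable
    checkpoint_score_attr="mean_accuracy",  # metric used to score checkpoints
    keep_checkpoints_num=3,                 # prune everything but the 3 best
)
```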
@ujvl Thanks for the answers. I wasn't aware of checkpoint_score_attr.
IMO both features mentioned are needed by many users. Maybe we can change the current PBT examples to demo the usage? I can draft a PR (or just modify the existing one: https://github.com/ray-project/ray/pull/6533). cc @richardliaw
Yeah, that sounds good (to modify existing one).
Feel free to reopen this if you see any other issues.
I tried checkpointing with checkpoint_score_attr, but I think it's only generating checkpoints in memory; I found nothing written to the local directory. I was trying this on my desktop.
cc @ujvl @richardliaw to confirm whether this is expected. Thanks.
That sounds incorrect; did you try this without a scheduler?
This was with PBT. Here's the complete code:
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# __tutorial_imports_begin__
import argparse
import os
import numpy as np
import torch
import torch.optim as optim
from torchvision import datasets
from ray.tune.examples.mnist_pytorch import train, test, ConvNet,\
get_data_loaders
import ray
from ray import tune
from ray.tune.schedulers import PopulationBasedTraining
from ray.tune.util import validate_save_restore
from ray.tune.trial import ExportFormat
# __tutorial_imports_end__
# __trainable_begin__
class PytorchTrainble(tune.Trainable):
    """Train a Pytorch ConvNet with Trainable and PopulationBasedTraining
    scheduler. The example reuse some of the functions in mnist_pytorch,
    and is a good demo for how to add the tuning function without
    changing the original training code.
    """

    def _setup(self, config):
        self.train_loader, self.test_loader = get_data_loaders()
        self.model = ConvNet()
        self.optimizer = optim.SGD(
            self.model.parameters(),
            lr=config.get("lr", 0.01),
            momentum=config.get("momentum", 0.9))

    def _train(self):
        train(self.model, self.optimizer, self.train_loader)
        acc = test(self.model, self.test_loader)
        return {"mean_accuracy": acc}

    def _save(self, checkpoint_dir):
        checkpoint_path = os.path.join(checkpoint_dir, "model.pth")
        torch.save(self.model.state_dict(), checkpoint_path)
        return checkpoint_path

    def _restore(self, checkpoint_path):
        self.model.load_state_dict(torch.load(checkpoint_path))

    def _export_model(self, export_formats, export_dir):
        if export_formats == [ExportFormat.MODEL]:
            path = os.path.join(export_dir, "exported_convnet.pt")
            torch.save(self.model.state_dict(), path)
            return {export_formats[0]: path}
        else:
            raise ValueError("unexpected formats: " + str(export_formats))

    def reset_config(self, new_config):
        for param_group in self.optimizer.param_groups:
            if "lr" in new_config:
                param_group["lr"] = new_config["lr"]
            if "momentum" in new_config:
                param_group["momentum"] = new_config["momentum"]
        self.config = new_config
        return True
# __trainable_end__
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing")
    args, _ = parser.parse_known_args()

    ray.init()

    datasets.MNIST("~/data", train=True, download=True)

    # check if PytorchTrainble will save/restore correctly before execution
    validate_save_restore(PytorchTrainble)
    validate_save_restore(PytorchTrainble, use_object_store=True)

    # __pbt_begin__
    scheduler = PopulationBasedTraining(
        time_attr="training_iteration",
        metric="mean_accuracy",
        mode="max",
        perturbation_interval=5,
        hyperparam_mutations={
            # distribution for resampling
            "lr": lambda: np.random.uniform(0.0001, 1),
            # allow perturbations within this set of categorical values
            "momentum": [0.8, 0.9, 0.99],
        })
    # __pbt_end__

    # __tune_begin__
    class Stopper:
        def __init__(self):
            self.should_stop = False

        def stop(self, trial_id, result):
            max_iter = 5 if args.smoke_test else 100
            if not self.should_stop and result["mean_accuracy"] > 0.96:
                self.should_stop = True
            return self.should_stop or result["training_iteration"] >= max_iter

    stopper = Stopper()

    analysis = tune.run(
        PytorchTrainble,
        name="pbt_test",
        scheduler=scheduler,
        reuse_actors=True,
        verbose=1,
        stop=stopper.stop,
        export_formats=[ExportFormat.MODEL],
        checkpoint_score_attr="mean_accuracy",
        keep_checkpoints_num=7,
        num_samples=4,
        config={
            "lr": tune.uniform(0.001, 1),
            "momentum": tune.uniform(0.001, 1),
        })
    # __tune_end__
Ah, try setting checkpoint_freq=3. That should trigger the checkpointing mechanism.
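With the script above, that would mean adding the flag to the existing tune.run call, roughly like this (a sketch reusing the names already defined in the posted script):

```python
# Same tune.run call as in the script above, with checkpoint_freq added so
# checkpoints are periodically written to disk; checkpoint_score_attr and
# keep_checkpoints_num then score and prune those on-disk checkpoints.
analysis = tune.run(
    PytorchTrainble,
    name="pbt_test",
    scheduler=scheduler,
    reuse_actors=True,
    verbose=1,
    stop=stopper.stop,
    export_formats=[ExportFormat.MODEL],
    checkpoint_freq=3,                      # write a checkpoint every 3 iterations
    checkpoint_score_attr="mean_accuracy",  # metric used to rank checkpoints
    keep_checkpoints_num=7,                 # keep only the 7 best on disk
    num_samples=4,
    config={
        "lr": tune.uniform(0.001, 1),
        "momentum": tune.uniform(0.001, 1),
    })
```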
Thanks. That makes sense. Sorry for the false alarm~
On second thought, checkpoint_freq may not always help. E.g. when we set checkpoint_freq=5, it will miss the best models that occur between checkpoint intervals.
With checkpoint_score_attr and keep_checkpoints_num, the tuning process should be able to compare the current result with past results and keep the best checkpoints across training iterations. The comparison cost should be insignificant as long as we don't write a checkpoint at every iteration.
Maybe when checkpoint_score_attr and keep_checkpoints_num are set, checkpoint_freq could be ignored.
Appreciate your feedback. @richardliaw @ujvl Thanks.
You can implement your own checkpoint policy by doing this comparison within _train and returning "SHOULD_CHECKPOINT": True ad hoc in the result. This will force the checkpoint, and of course you can then disable checkpoint_freq as you mentioned.
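A hedged sketch of what that could look like for the _train method in the script above, tracking the best accuracy seen so far inside the trainable (the lowercase "should_checkpoint" string is what Tune's SHOULD_CHECKPOINT constant refers to; double-check the exact key for your Tune version):

```python
def _train(self):
    train(self.model, self.optimizer, self.train_loader)
    acc = test(self.model, self.test_loader)
    result = {"mean_accuracy": acc}
    # Ad-hoc checkpoint policy: only ask Tune to checkpoint when the accuracy
    # improves on the best value this trial has seen so far.
    if acc > getattr(self, "_best_acc", float("-inf")):
        self._best_acc = acc
        result["should_checkpoint"] = True  # key behind the SHOULD_CHECKPOINT constant
    return result
```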
Thanks for the reply @ujvl
Your suggestion should work. I'm thinking about two things:
1. Should keeping the top checkpoints work with only checkpoint_score_attr and keep_checkpoints_num? Today it only takes effect when checkpoint_freq is set, and I'm not sure if that's right.
2. Otherwise, users have to set all of checkpoint_score_attr, keep_checkpoints_num, and checkpoint_freq rather than just checkpoint_freq. This is fine for me.