Hello, I wonder how I am supposed to set the global_step from a checkpoint for handlers?
Specifically, I would like most of the handlers in the tensorboard logger to use engine.state.iteration as the global step.
Since engine.state.iteration is reset when I call engine.run, I had to create my own counter (shown below), which adds an offset to the current iteration, and pass it to the global_step_transform argument of every output handler.
class global_step_counter:
    def __init__(self, resume_from_iter):
        self.resume_from_iter = resume_from_iter
        return

    def set_starting_val(self, new_val):
        self.resume_from_iter = new_val
        return

    def __call__(self, engine):
        return engine.state.iteration + self.resume_from_iter
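For readers who want to try the idea without a full training setup, here is a minimal sketch of such an offset counter. It assumes the logger calls the transform as transform(engine, event_name), which is how global_step_from_engine appears to be used; the class name GlobalStepCounter and the stand-in engine built from SimpleNamespace are hypothetical and only serve the illustration.

```python
from types import SimpleNamespace


class GlobalStepCounter:
    """Callable that offsets the engine's iteration by a fixed amount."""

    def __init__(self, resume_from_iter=0):
        self.resume_from_iter = resume_from_iter

    def __call__(self, engine, event_name=None):
        # event_name is accepted and ignored, on the assumption that the
        # logger passes it as a second positional argument.
        return engine.state.iteration + self.resume_from_iter


# Stand-in for an ignite Engine: only `state.iteration` is read here.
fake_engine = SimpleNamespace(state=SimpleNamespace(iteration=5))

counter = GlobalStepCounter(resume_from_iter=100)
step = counter(fake_engine, "iteration_completed")
print(step)
```

With a real trainer, an instance of this class would be passed as global_step_transform to each OutputHandler.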
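However, I realized that the following handlers do not provide a global_step_transform option: OptimizerParamsHandler, WeightsScalarHandler, GradsScalarHandler, WeightsHistHandler, GradsHistHandler.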
Thank you
Hi @Luhan-Cheng, could you detail your use-case please?
We provide global_step_transform to transform the global step for evaluators, since their internal state is not related to the trainer's. This means that if we attach a handler to an evaluator without global_step_transform, like this:
tb_logger.attach(
    evaluator,
    log_handler=OutputHandler(
        tag="training",
        metric_names=["nll", "accuracy"],
        # For illustration purposes, without global_step_transform
        # global_step_transform=global_step_from_engine(trainer)
    ),
    event_name=Events.EPOCH_COMPLETED
)
we will always log values with a global step equal to 1, because the evaluator runs once over the validation dataset (1 epoch).
To provide the additional information from the trainer, which counts training epochs/iterations, we set global_step_transform=global_step_from_engine(trainer), as in all our examples: https://pytorch.org/ignite/master/contrib/handlers.html#ignite.contrib.handlers.tensorboard_logger.TensorboardLogger
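The difference can be illustrated with a small sketch. The function step_from_engine below is a simplified, hypothetical stand-in for global_step_from_engine (it always reads the epoch and is not ignite's implementation), and both engines are plain SimpleNamespace objects with made-up counter values.

```python
from types import SimpleNamespace


def step_from_engine(engine):
    """Simplified stand-in: always report the given engine's epoch."""
    def wrapper(_, event_name=None):
        # Ignore the engine the handler is attached to and read the
        # epoch of the engine captured in the closure instead.
        return engine.state.epoch
    return wrapper


# Hypothetical states: the trainer is in epoch 7, while the evaluator
# has just finished its single pass over the validation data.
trainer = SimpleNamespace(state=SimpleNamespace(epoch=7, iteration=700))
evaluator = SimpleNamespace(state=SimpleNamespace(epoch=1, iteration=20))

default_step = evaluator.state.epoch                     # no transform
transformed_step = step_from_engine(trainer)(evaluator)  # with transform
print(default_step, transformed_step)
```

Without the transform, every validation point would land on step 1; with it, the points follow the trainer's epoch counter.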
Other handlers you cited
OptimizerParamsHandler
WeightsScalarHandler
GradsScalarHandler
WeightsHistHandler
GradsHistHandler
are normally attached to the trainer, so their global step is automatically the one defined by the training.
> Specifically, I would like most of the handlers in the tensorboard logger to use engine.state.iteration as the global step.
So I would say this should work out of the box.
> Since engine.state.iteration is reset when I call engine.run, I had to create my own counter, which adds an offset to the current iteration, and pass it to the global_step_transform argument of every output handler.
It seems you would like to resume the training? You can use engine.load_state_dict to set up the starting iteration/epoch.
Thank you @vfdev-5 for the fast response.
Yes, I would like to resume training. I did run trainer.load_state_dict, and it set trainer.state.iteration to the final iteration of the last run. However, as soon as I call trainer.run(...), it resets the iteration to 0 in https://github.com/pytorch/ignite/blob/3ab78a21fff39933a8623f3070fe9fe4770723bc/ignite/engine/engine.py#L607
Please provide a code snippet of your use-case. It would be helpful.
The state is reset when either 1) no state was provided previously or 2) training is done.
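As a rough illustration of the second condition, the sketch below restates "training is done" as a check on a state dict. This is a toy restatement under the assumption that a run counts as finished once iteration reaches epoch_length * max_epochs; the helper run_is_done is hypothetical and is not ignite's code, and the state values are chosen for illustration only.

```python
from collections import OrderedDict


def run_is_done(state):
    """Toy check: True once every planned iteration has been executed."""
    return state["iteration"] >= state["epoch_length"] * state["max_epochs"]


# A state that has used up its single planned epoch of 8 iterations.
finished = OrderedDict(
    [("seed", 12), ("epoch_length", 8), ("max_epochs", 1), ("iteration", 8)]
)
# The same progress, but with a second epoch still to run.
resumable = OrderedDict(
    [("seed", 12), ("epoch_length", 8), ("max_epochs", 2), ("iteration", 8)]
)

finished_done = run_is_done(finished)
resumable_done = run_is_done(resumable)
print(finished_done, resumable_done)
```

Under this reading, loading the first state and calling run would start from scratch, while the second would continue from iteration 8.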
Training function
def train():
    args = parse_argument()
    config = load_config(str(args.config))
    logging.info("Load configuration {}".format(
        json.dumps(config, indent=4, sort_keys=False)))
    device = torch.device('cuda' if torch.cuda.is_available()
                          and config['train']['cuda'] else 'cpu')
    image_dataset = DataLoader(Dataset(
        config['dataset_path'], image_num=config['train']['image_num']), batch_size=config['train']['batch_size'])
    model = VAE(z_channel=config['model']['z_channel'],
                encoder_channels=config['model']['encoder_channels']).to(device)
    logging.debug(f'Model: {model}')
    logging.info('Encoder summary')
    sample_image_batch = next(iter(image_dataset)).to(device)
    summary(model.encode, sample_image_batch.shape[1:])
    with torch.no_grad():
        sample_output_shape = model.encode(sample_image_batch).shape
    logging.info('Decoder summary')
    z_shape = (config['model']['z_channel'], *sample_output_shape[-2:])
    summary(model.decode, z_shape)
    lr_scheduler, optimizer = create_learning_scheduler_and_optimizer(
        config, model)
    logdir = Path(config['train']['log_dir'])
    trainer, tb_logger = create_train_engine(
        model, optimizer, loss_function, device, lr_scheduler, logdir)
    # empty graph written to log. profiler execution seems to fail but it is not supposed to affect the result
    #tb_logger.writer.add_graph(model.encode, input_to_model=sample_image_batch)
    rand_z = torch.Tensor(np.random.rand(*z_shape)).to(device)
    tb_logger.writer.add_graph(model.decode, input_to_model=rand_z.unsqueeze(0))
    evaluator = create_eval_engine(model, device)
    resume = config['train']['resume']
    to_save = {'model': model, 'optimizer': optimizer,
               'lr_scheduler': lr_scheduler, 'trainer': trainer}
    ckpt_dir = logdir / 'checkpoints'
    if resume and ckpt_dir.exists() and len(list(ckpt_dir.iterdir())) != 0:
        checkpoint_path = max(ckpt_dir.iterdir(),
                              key=lambda x: x.stat().st_mtime).resolve()
        print(f"loading checkpoint from {checkpoint_path} ......")
        checkpoint = torch.load(checkpoint_path)
        for k, v in to_save.items():
            v.load_state_dict(checkpoint[k])
            print(f"loaded {k}")
        ckpt_files = [f for f in ckpt_dir.iterdir()]
        logging.warning(f"finished loading, now removing files {ckpt_files}")
        for i in ckpt_files:
            i.unlink()
    attach_checkpoints_saver(
        model, optimizer, lr_scheduler, trainer, ckpt_dir, to_save)
    trainer.run(image_dataset, max_epochs=config['train']['epoch_num'])
    tb_logger.close()
The trainer is created through:
def create_train_engine(model, optimizer, loss_func, device, lr_scheduler, log_dir):
    def update_model(engine, image_batch):
        model.train()
        image_batch = image_batch.to(device)
        optimizer.zero_grad()
        out, mu, logvar, z = model(image_batch)
        loss, MSE, KLD = loss_func(out, image_batch, mu, logvar)
        loss.backward()
        optimizer.step()
        out = {
            'loss': {
                'total': loss.item(),
                'mse': MSE.item(),
                'kld': KLD.item()
            },
            'repr': {
                'mu': mu,
                'logvar': logvar,
                'z': z
            }
        }
        return out

    trainer = Engine(update_model)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)
    if device.type == 'cuda':
        GpuInfo().attach(trainer, name='gpu')

    def concat_var(varname, output_transform):
        concat_var = VariableAccumulation(lambda accumulator, output: torch.cat(
            [accumulator, output], dim=0) if accumulator.ndim != 0 else output, output_transform=output_transform)
        concat_var.attach(trainer, f'concat_{varname}')
        return concat_var

    concat_mu = concat_var('mu', lambda x: x['repr']['mu'])
    concat_logvar = concat_var('logvar', lambda x: x['repr']['logvar'])
    concat_image = concat_var('image', lambda x: trainer.state.batch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def logging_each_epoch(trainer):
        # print(f"concat_mu is {trainer.state.metrics['concat_mu'][0].shape}")
        logging.debug(
            f"concat_mu.accumulator.shape = {concat_mu.accumulator.shape}")
        logging.debug(
            f"concat_logvar.accumulator.shape = {concat_logvar.accumulator.shape}")
        logging.debug(
            f"concat_image.accumulator.shape = {concat_image.accumulator.shape}")
        return

    @trainer.on(Events.ITERATION_COMPLETED)
    def logging_each_iteration(trainer):
        logging.info(f"loss : {trainer.state.output['loss']}")
        return

    tb_logger = TensorboardLogger(log_dir=str(log_dir / 'tb_events'))
    tb_logger.attach(
        trainer,
        log_handler=OutputHandler(
            tag='training_loss', output_transform=lambda x: x['loss']
        ),
        event_name=Events.ITERATION_COMPLETED
    )
    tb_logger.attach(
        trainer,
        log_handler=OutputHandler(tag='training_metrics', metric_names='all'),
        event_name=Events.ITERATION_COMPLETED
    )
    tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(
        optimizer), event_name=Events.ITERATION_COMPLETED(every=1))
    tb_logger.attach(trainer, log_handler=WeightsScalarHandler(model), event_name=Events.ITERATION_COMPLETED(every=10))
    tb_logger.attach(trainer, log_handler=GradsScalarHandler(model), event_name=Events.ITERATION_COMPLETED(every=10))
    tb_logger.attach(trainer, log_handler=WeightsHistHandler(model), event_name=Events.EPOCH_COMPLETED(every=10))
    tb_logger.attach(trainer, log_handler=GradsHistHandler(model), event_name=Events.EPOCH_COMPLETED(every=10))

    @trainer.on(Events.COMPLETED)
    def log_embedding(trainer):
        # tb_logger.writer.add_embedding(trainer.state.output['repr']['z'].view(1, -1), global_step=trainer.state.iteration, label_img=trainer.state.batch)
        mu_embedding = torch.flatten(concat_mu.accumulator, start_dim=1)
        tb_logger.writer.add_embedding(mu_embedding, global_step=trainer.state.iteration, label_img=concat_image.accumulator)
        return

    ProgressBar(persist=True).attach(trainer, metric_names=[
        'gpu:0 mem(%)', 'gpu:0 util(%)'] if device.type == 'cuda' else None)
    tb_logger.attach(
        trainer,
        log_handler=OutputHandler(
            tag="training/loss", output_transform=lambda output: output['loss']
        ),
        event_name=Events.ITERATION_COMPLETED(every=1),
    )
    return trainer, tb_logger
When I run tensorboard --logdir logdir/ --host localhost, the scalar plots show the global_step being reset.

Here is a minimal example @vfdev-5
from collections import OrderedDict
import numpy as np
from ignite.engine import Engine
init_state = OrderedDict([('seed', 12), ('epoch_length', 8), ('max_epochs', 1), ('iteration', 8)])
def dataset():
    yield np.random.random()

e = Engine(lambda x, b: b)
e.load_state_dict(init_state)
print(f'e.state before run: {e.state}')
e.run(dataset(), epoch_length=1)
print(f'e.state after run: {e.state}')
Output:
$ python test.py
e.state before run: State:
    iteration: 8
    epoch: 1
    epoch_length: 8
    max_epochs: 1
    output: <class 'NoneType'>
    batch: <class 'NoneType'>
    metrics: <class 'dict'>
    dataloader: <class 'NoneType'>
    seed: 12
e.state after run: State:
    iteration: 1
    epoch: 1
    epoch_length: 1
    max_epochs: 1
    output: 0.15416284237967237
    batch: 0.15416284237967237
    metrics: <class 'dict'>
    dataloader: <class 'generator'>
    seed: 12
You can see that the iteration is indeed reset.
@Luhan-Cheng thanks for the last snippet.
Yes, the state is reset because of
init_state = OrderedDict([('seed', 12), ('epoch_length', 8), ('max_epochs', 1), ('iteration', 8)])
This init_state describes a run of 1 epoch of 8 iterations that is already at the 8th iteration, i.e. the run is already complete, so calling run starts a new one from scratch.
If you would like to continue, you need to set more epochs. For example:
from collections import OrderedDict
import numpy as np
from ignite.engine import Engine
init_state = OrderedDict([('seed', 12), ('epoch_length', 8), ('max_epochs', 2), ('iteration', 8)])
def dataset():
    while True:
        yield np.random.random()

e = Engine(lambda x, b: b)
e.load_state_dict(init_state)
print(f'e.state before run: {e.state}')
e.run(dataset())
print(f'e.state after run: {e.state}')
This gives:
e.state before run: State:
    iteration: 8
    epoch: 1
    epoch_length: 8
    max_epochs: 2
    output: <class 'NoneType'>
    batch: <class 'NoneType'>
    metrics: <class 'dict'>
    dataloader: <class 'NoneType'>
    seed: 12
e.state after run: State:
    iteration: 16
    epoch: 2
    epoch_length: 8
    max_epochs: 2
    output: 0.03666430642108576
    batch: 0.03666430642108576
    metrics: <class 'dict'>
    dataloader: <class 'generator'>
    seed: 12
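When the state comes from a checkpoint of a finished run, the same adjustment could be applied to the saved state dict before loading it. The helper extend_training below is a hypothetical sketch, not part of ignite; it only copies the dict and raises max_epochs, and the result would then be passed to Engine.load_state_dict before calling run.

```python
from collections import OrderedDict


def extend_training(state_dict, extra_epochs):
    """Return a copy of an engine state dict with max_epochs raised."""
    if extra_epochs < 1:
        raise ValueError("extra_epochs must be at least 1 for the run to continue")
    extended = OrderedDict(state_dict)  # leave the checkpointed dict untouched
    extended["max_epochs"] = state_dict["max_epochs"] + extra_epochs
    return extended


# State as it would be saved at the end of a 1-epoch run of 8 iterations.
saved = OrderedDict(
    [("seed", 12), ("epoch_length", 8), ("max_epochs", 1), ("iteration", 8)]
)
resumed = extend_training(saved, extra_epochs=1)
print(resumed["max_epochs"], resumed["iteration"])
```

The iteration counter is kept as saved, so handlers attached to the trainer would continue logging from iteration 8 rather than from 0.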
Thank you @vfdev-5! This solved my issue.