Ignite: set global_step from checkpoint

Created on 3 Mar 2020 · 7 comments · Source: pytorch/ignite

❓ Questions/Help/Support

Hello, how am I supposed to set the global_step from a checkpoint for handlers?

Specifically, I would like most of the handlers in the tensorboard logger to use engine.state.iteration as the global step.
Since engine.state.iteration is reset when I call engine.run, I had to create my own counter (like the following), which basically adds an offset to the current iteration, and pass it to all the global_step_transform arguments of OutputHandler.

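class global_step_counter:
    # Adds a fixed offset (the iteration reached before resuming) to the current iteration
    def __init__(self, resume_from_iter):
        self.resume_from_iter = resume_from_iter

    def set_starting_val(self, new_val):
        self.resume_from_iter = new_val

    def __call__(self, engine, event_name=None):
        # Handlers call global_step_transform as (engine, event_name)
        return engine.state.iteration + self.resume_from_iter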
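ignite's output handlers call `global_step_transform` with two arguments, `(engine, event_name)`, so an offset counter like the one above should accept both. Below is a minimal sketch of the same idea written as a closure. `offset_global_step` is a hypothetical helper, and the engine is a stand-in object built with `SimpleNamespace`, not a real ignite `Engine`:

```python
from types import SimpleNamespace


def offset_global_step(resume_from_iter):
    """Hypothetical helper: returns a global_step_transform that adds a fixed offset."""

    def transform(engine, event_name):
        # event_name is accepted to match the (engine, event_name) signature,
        # but this sketch always offsets the iteration counter
        return engine.state.iteration + resume_from_iter

    return transform


# Stand-in for an engine whose iteration counter restarted at 0 after run()
fake_engine = SimpleNamespace(state=SimpleNamespace(iteration=5))
step = offset_global_step(resume_from_iter=1000)
print(step(fake_engine, "ITERATION_COMPLETED"))  # prints 1005
```

The offset has to be tracked by hand here. That manual bookkeeping is the inconsistency the rest of the question is about.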

However, I realized that the following handlers do not provide a global_step_transform option:

  • OptimizerParamsHandler
  • WeightsScalarHandler
  • GradsScalarHandler
  • WeightsHistHandler
  • GradsHistHandler

Instead, they use get_event_attrib_value to get the global_step, which makes their global step inconsistent with the other handlers.
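Roughly, `get_event_attrib_value` looks up which state counter belongs to the event that fired. Iteration-level events map to `state.iteration`, and epoch-level events map to `state.epoch`. The following is a simplified pure-Python sketch of that lookup. It uses plain strings in place of ignite's `Events` enum and a stand-in state object, and it is an illustration, not ignite's source:

```python
from types import SimpleNamespace

# Simplified event -> counter table (strings stand in for ignite's Events enum)
EVENT_TO_ATTR = {
    "STARTED": "epoch",
    "EPOCH_STARTED": "epoch",
    "ITERATION_STARTED": "iteration",
    "ITERATION_COMPLETED": "iteration",
    "EPOCH_COMPLETED": "epoch",
    "COMPLETED": "epoch",
}


def get_event_attrib_value(state, event_name):
    # Return the counter (iteration or epoch) associated with the event
    return getattr(state, EVENT_TO_ATTR[event_name])


trainer_state = SimpleNamespace(iteration=1200, epoch=3)
print(get_event_attrib_value(trainer_state, "ITERATION_COMPLETED"))  # 1200
print(get_event_attrib_value(trainer_state, "EPOCH_COMPLETED"))  # 3
```

A handler attached to the trainer therefore logs with the trainer's own counter. If that counter is restored correctly from a checkpoint, these handlers stay consistent without any transform.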

Thank you


Hi @Luhan-Cheng, could you detail your use-case please?

We provide global_step_transform to transform the global_step for evaluators, because their internal state is unrelated to the trainer's state. For example, if we attach a handler to an evaluator without global_step_transform, like this:

tb_logger.attach(
    evaluator, 
    log_handler=OutputHandler(
        tag="training",
        metric_names=["nll", "accuracy"],
        # For illustration purposes, without global_step_transform
        # global_step_transform=global_step_from_engine(trainer)
    ),
    event_name=Events.EPOCH_COMPLETED
)

values will always be logged with a global step equal to 1, because the evaluator runs only once over the validation dataset (1 epoch).
To use the epoch/iteration counts of the trainer instead, we set global_step_transform=global_step_from_engine(trainer), as in all our examples: https://pytorch.org/ignite/master/contrib/handlers.html#ignite.contrib.handlers.tensorboard_logger.TensorboardLogger
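The idea behind `global_step_from_engine(trainer)` is to ignore the evaluator's own counters and read the trainer's counters instead. The sketch below is a simplified illustration of that behavior. `global_step_from` is a hypothetical helper, the engines are stand-in objects, and ignite ships its own implementation, which this does not reproduce:

```python
from types import SimpleNamespace

EVENT_TO_ATTR = {"ITERATION_COMPLETED": "iteration", "EPOCH_COMPLETED": "epoch"}


def global_step_from(other_engine):
    """Hypothetical helper: build a transform that reads another engine's counter."""

    def wrapper(_engine, event_name):
        # The first argument (the engine the handler is attached to) is ignored
        return getattr(other_engine.state, EVENT_TO_ATTR[event_name])

    return wrapper


trainer = SimpleNamespace(state=SimpleNamespace(iteration=4000, epoch=10))
evaluator = SimpleNamespace(state=SimpleNamespace(iteration=50, epoch=1))

transform = global_step_from(trainer)
print(transform(evaluator, "EPOCH_COMPLETED"))  # 10, not the evaluator's 1
```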

The other handlers you cited,

  • OptimizerParamsHandler
  • WeightsScalarHandler
  • GradsScalarHandler
  • WeightsHistHandler
  • GradsHistHandler

are normally attached to the trainer, so their global step is automatically taken from the trainer's state.

> Specifically, I would like most of the handlers in the tensorboard logger to use engine.state.iteration as the global step.

So I would say this should work out of the box: handlers attached to the trainer already take their global step from the trainer's state.

> Since engine.state.iteration is reset when I call engine.run, I had to create my own counter (like the following), which basically adds an offset to the current iteration, and pass it to all the global_step_transform arguments of OutputHandler.

It seems you would like to resume training? You can use engine.load_state_dict to set up the starting iteration/epoch.

Thank you @vfdev-5 for the fast response.
Yes, I would like to resume training. I did call trainer.load_state_dict, and it set trainer.state.iteration to the final iteration of the last run. However, as soon as I call trainer.run(...), the iteration is reset to 0 in https://github.com/pytorch/ignite/blob/3ab78a21fff39933a8623f3070fe9fe4770723bc/ignite/engine/engine.py#L607

Please provide a code snippet of your use case; it would be helpful.

The state is reset when either 1) it was not provided previously, or 2) the training is already done.
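The reset rule can be modeled as follows, assuming that a run counts as "done" when `iteration == epoch_length * max_epochs`. The `is_done` and `will_reset_state` functions and the stand-in state objects are a simplified sketch of the behavior described above, not ignite's actual code:

```python
from types import SimpleNamespace


def is_done(state):
    # A run is complete once all epoch_length * max_epochs iterations have happened
    return state.iteration == state.epoch_length * state.max_epochs


def will_reset_state(state):
    # run() starts from scratch if there is no previous state or the previous run finished
    return state is None or is_done(state)


finished = SimpleNamespace(iteration=8, epoch_length=8, max_epochs=1)
resumable = SimpleNamespace(iteration=8, epoch_length=8, max_epochs=2)

print(will_reset_state(None))       # True
print(will_reset_state(finished))   # True: 8 == 8 * 1
print(will_reset_state(resumable))  # False: 8 != 8 * 2
```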

Training function

def train():
    args = parse_argument()
    config = load_config(str(args.config))
    logging.info("Load configuration {}".format(
        json.dumps(config, indent=4, sort_keys=False)))
    device = torch.device('cuda' if torch.cuda.is_available()
                          and config['train']['cuda'] else 'cpu')
    image_dataset = DataLoader(Dataset(
        config['dataset_path'], image_num=config['train']['image_num']), batch_size=config['train']['batch_size'])
    model = VAE(z_channel=config['model']['z_channel'],
                encoder_channels=config['model']['encoder_channels']).to(device)
    logging.debug(f'Model: {model}')
    logging.info('Encoder summary')
    sample_image_batch = next(iter(image_dataset)).to(device)
    summary(model.encode, sample_image_batch.shape[1:])
    with torch.no_grad():
        sample_output_shape = model.encode(sample_image_batch).shape
    logging.info('Decoder summary')
    z_shape = (config['model']['z_channel'], *sample_output_shape[-2:])
    summary(model.decode, z_shape)
    lr_scheduler, optimizer = create_learning_scheduler_and_optimizer(
        config, model)
    logdir = Path(config['train']['log_dir'])

    trainer, tb_logger = create_train_engine(
        model, optimizer, loss_function, device, lr_scheduler, logdir)

    # An empty graph is written to the log; profiler execution seems to fail, but it is not supposed to affect the result
    #tb_logger.writer.add_graph(model.encode, input_to_model=sample_image_batch)
    rand_z = torch.Tensor(np.random.rand(*z_shape)).to(device)
    tb_logger.writer.add_graph(model.decode, input_to_model=rand_z.unsqueeze(0))

    evaluator = create_eval_engine(model, device)

    resume = config['train']['resume']
    to_save = {'model': model, 'optimizer': optimizer,
               'lr_scheduler': lr_scheduler, 'trainer': trainer}
    ckpt_dir = logdir / 'checkpoints'
    if resume and ckpt_dir.exists() and len(list(ckpt_dir.iterdir())) != 0:
        checkpoint_path = max(ckpt_dir.iterdir(),
                              key=lambda x: x.stat().st_mtime).resolve()
        print(f"loading checkpoint from {checkpoint_path} ......")
        checkpoint = torch.load(checkpoint_path)
        for k, v in to_save.items():
            v.load_state_dict(checkpoint[k])
            print(f"loaded {k}")

        ckpt_files = [f for f in ckpt_dir.iterdir()]
        logging.warning(f"finished loading, now removing files {ckpt_files}")
        for i in ckpt_files:
            i.unlink()

    attach_checkpoints_saver(
        model, optimizer, lr_scheduler, trainer, ckpt_dir, to_save)

    trainer.run(image_dataset, max_epochs=config['train']['epoch_num'])
    tb_logger.close()

The trainer is created by:


def create_train_engine(model, optimizer, loss_func, device, lr_scheduler, log_dir):
    def update_model(engine, image_batch):
        model.train()
        image_batch = image_batch.to(device)
        optimizer.zero_grad()
        out, mu, logvar, z = model(image_batch)
        loss, MSE, KLD = loss_func(out, image_batch, mu, logvar)
        loss.backward()
        optimizer.step()
        out = {
            'loss': {
                'total': loss.item(),
                'mse': MSE.item(),
                'kld': KLD.item()
            },
            'repr': {
                'mu': mu,
                'logvar': logvar,
                'z': z
            }
        }
        return out
    trainer = Engine(update_model)

    trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)
    if device.type == 'cuda':
        GpuInfo().attach(trainer, name='gpu')

    def concat_var(varname, output_transform):
        concat_var = VariableAccumulation(lambda accumulator, output:  torch.cat(
            [accumulator, output], dim=0) if accumulator.ndim != 0 else output, output_transform=output_transform)
        concat_var.attach(trainer, f'concat_{varname}')
        return concat_var

    concat_mu = concat_var('mu', lambda x: x['repr']['mu'])
    concat_logvar = concat_var('logvar', lambda x: x['repr']['logvar'])
    concat_image = concat_var('image', lambda x: trainer.state.batch)
    @trainer.on(Events.EPOCH_COMPLETED)
    def logging_each_epoch(trainer):
        #        print(f"concat_mu is {trainer.state.metrics['concat_mu'][0].shape}")
        logging.debug(
            f"concat_mu.accumulator.shape = {concat_mu.accumulator.shape}")
        logging.debug(
            f"concat_logvar.accumulator.shape = {concat_logvar.accumulator.shape}")
        logging.debug(
            f"concat_image.accumulator.shape = {concat_image.accumulator.shape}")

        return
    @trainer.on(Events.ITERATION_COMPLETED)
    def logging_each_iteration(trainer):
        logging.info(f"loss : {trainer.state.output['loss']}")
        return

    tb_logger = TensorboardLogger(log_dir=str(log_dir / 'tb_events'))
    tb_logger.attach(
        trainer,
        log_handler=OutputHandler(
            tag='training_loss', output_transform=lambda x: x['loss']
        ),
        event_name=Events.ITERATION_COMPLETED
    )
    tb_logger.attach(
        trainer,
        log_handler=OutputHandler(tag='training_metrics', metric_names='all'),
        event_name=Events.ITERATION_COMPLETED
    )
    tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(
        optimizer), event_name=Events.ITERATION_COMPLETED(every=1))

    tb_logger.attach(trainer, log_handler=WeightsScalarHandler(model), event_name=Events.ITERATION_COMPLETED(every=10))
    tb_logger.attach(trainer, log_handler=GradsScalarHandler(model), event_name=Events.ITERATION_COMPLETED(every=10))

    tb_logger.attach(trainer, log_handler=WeightsHistHandler(model), event_name=Events.EPOCH_COMPLETED(every=10))
    tb_logger.attach(trainer, log_handler=GradsHistHandler(model), event_name=Events.EPOCH_COMPLETED(every=10))

    @trainer.on(Events.COMPLETED)
    def log_embedding(trainer):
        #        tb_logger.writer.add_embedding(trainer.state.output['repr']['z'].view(1, -1), global_step=trainer.state.iteration, label_img=trainer.state.batch)
        mu_embedding = torch.flatten(concat_mu.accumulator, start_dim=1)
        tb_logger.writer.add_embedding(mu_embedding, global_step=trainer.state.iteration, label_img=concat_image.accumulator)
        return


    ProgressBar(persist=True).attach(trainer, metric_names=[
        'gpu:0 mem(%)', 'gpu:0 util(%)'] if device.type == 'cuda' else None)
    tb_logger.attach(
        trainer,
        log_handler=OutputHandler(
            tag="training/loss", output_transform=lambda output: output['loss']
        ),
        event_name=Events.ITERATION_COMPLETED(every=1),
    )
    return trainer, tb_logger

When I run tensorboard --logdir logdir/ --host localhost, the scalar plots show the global_step being reset:
[screenshot of the TensorBoard scalar plots]

Here is a minimal example, @vfdev-5:

from collections import OrderedDict

import numpy as np
from ignite.engine import Engine

# Saved state: 1 epoch of 8 iterations, already at iteration 8
init_state = OrderedDict([('seed', 12), ('epoch_length', 8), ('max_epochs', 1), ('iteration', 8)])


def dataset():
    yield np.random.random()


# The process function simply returns the batch
e = Engine(lambda engine, batch: batch)
e.load_state_dict(init_state)
print(f'e.state before run: {e.state}')
e.run(dataset(), epoch_length=1)
print(f'e.state after run: {e.state}')

Output:

$ python test.py
e.state before run: State:
    iteration: 8
    epoch: 1
    epoch_length: 8
    max_epochs: 1
    output: <class 'NoneType'>
    batch: <class 'NoneType'>
    metrics: <class 'dict'>
    dataloader: <class 'NoneType'>
    seed: 12

e.state after run: State:
    iteration: 1
    epoch: 1
    epoch_length: 1
    max_epochs: 1
    output: 0.15416284237967237
    batch: 0.15416284237967237
    metrics: <class 'dict'>
    dataloader: <class 'generator'>
    seed: 12

You can see that the iteration is indeed reset.

@Luhan-Cheng thanks for the last snippet.
Yes, the state is reset because of this line:

init_state = OrderedDict([('seed', 12), ('epoch_length', 8), ('max_epochs', 1), ('iteration', 8)])

This init_state describes a run of 1 epoch of 8 iterations that has already reached the 8th iteration. In other words, the run is finished, so run() starts a new one from scratch.

If you would like to continue, you need to set more epochs, for example:

from collections import OrderedDict

import numpy as np
from ignite.engine import Engine

# Same saved state, but max_epochs=2: one more epoch remains after iteration 8
init_state = OrderedDict([('seed', 12), ('epoch_length', 8), ('max_epochs', 2), ('iteration', 8)])


def dataset():
    # Infinite generator: the saved epoch_length (8) bounds each epoch
    while True:
        yield np.random.random()


e = Engine(lambda engine, batch: batch)
e.load_state_dict(init_state)
print(f'e.state before run: {e.state}')
e.run(dataset())
print(f'e.state after run: {e.state}')

This gives:

e.state before run: State:
    iteration: 8
    epoch: 1
    epoch_length: 8
    max_epochs: 2
    output: <class 'NoneType'>
    batch: <class 'NoneType'>
    metrics: <class 'dict'>
    dataloader: <class 'NoneType'>
    seed: 12

e.state after run: State:
    iteration: 16
    epoch: 2
    epoch_length: 8
    max_epochs: 2
    output: 0.03666430642108576
    batch: 0.03666430642108576
    metrics: <class 'dict'>
    dataloader: <class 'generator'>
    seed: 12
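The counters in this output can be checked with a little arithmetic. Starting from iteration 8 with `epoch_length=8`, the engine is at the end of epoch 1. With `max_epochs=2`, it runs one more epoch of 8 iterations and ends at iteration 16, epoch 2. The toy loop below (a hypothetical `resume_counters` helper, no ignite involved) advances the counters the same way. It deliberately leaves out the reset rule for finished runs:

```python
def resume_counters(iteration, epoch_length, max_epochs):
    """Toy model: advance (iteration, epoch) from a restored state until max_epochs."""
    epoch = iteration // epoch_length  # epochs already completed
    while epoch < max_epochs:
        epoch += 1
        for _ in range(epoch_length):
            iteration += 1  # one training step
    return iteration, epoch


print(resume_counters(iteration=8, epoch_length=8, max_epochs=2))  # (16, 2)
```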

Thank you @vfdev-5! This solved my issue.
