Pytorch-lightning: How to save checkpoints within lightning_logs?

Created on 22 Mar 2020  ·  3 Comments  ·  Source: PyTorchLightning/pytorch-lightning

I'm currently doing checkpointing as follows:

checkpoint_callback = pl.callbacks.ModelCheckpoint(
    filepath=os.path.join(os.getcwd(), 'checkpoints/{epoch}-{val_loss:.2f}'),
    verbose=True,
    monitor='val_loss',
    mode='min',
    save_top_k=-1,
    period=1
)


trainer = pl.Trainer(
    default_save_path=os.path.join(os.getcwd(), 'log_files_are_stored_here'),
    gpus=1,
    max_epochs=2,
    checkpoint_callback=checkpoint_callback
)

This creates the following folder structure:

├── checkpoints # all the .pth files are saved here
└── log_files_are_stored_here
    └── lightning_logs 
       ├── version_0
       ├── version_1
       ├── version_2

How can I get the .pth files for each version to be saved in the respective version folders, like so:

└── log_files_are_stored_here
    └── lightning_logs 
       ├── version_0
            └── checkpoints #  save the .pth files here
       ├── version_1
            └── checkpoints #  save the .pth files here
       ├── version_2
            └── checkpoints #  save the .pth files here
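One generic way to get this layout (a sketch, not an official PyTorch Lightning API; the helper name is hypothetical) is to build the version-specific checkpoint directory yourself from the logger's save directory, experiment name, and version number, and pass it to `ModelCheckpoint`:

```python
import os

def version_checkpoint_dir(save_dir: str, name: str, version: int) -> str:
    # Mirror the folder layout Lightning loggers use:
    # <save_dir>/<name>/version_<n>/checkpoints
    return os.path.join(save_dir, name, f"version_{version}", "checkpoints")

# Example: resolve the directory for version 0 of the default logger name.
ckpt_dir = version_checkpoint_dir("log_files_are_stored_here", "lightning_logs", 0)
print(ckpt_dir)
```

You would then pass `os.path.join(ckpt_dir, '{epoch}-{val_loss:.2f}')` as the `filepath` argument to `ModelCheckpoint`.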
Labels: question, won't fix


All 3 comments

Hi,

this is how I do it:

tt_logger = TestTubeLogger(save_dir=str(log_dir / "tt_logs"), name=run_name)

checkpoint_dir = (
    Path(tt_logger.save_dir)
    / tt_logger.experiment.name
    / f"version_{tt_logger.experiment.version}"
    / "checkpoints"
)
filepath = checkpoint_dir / "{epoch}-{val_loss:.4f}"
checkpoint_cb = ModelCheckpoint(filepath=str(filepath))

trainer = pl.Trainer(
        logger=tt_logger,
        checkpoint_callback=checkpoint_cb,
        ...
    )

Hi, here is the TensorBoard version, inspired by @chris-clem's snippet.

Any idea how to get rid of the "HACK"?

tb_logger = TensorBoardLogger(save_dir='logs/tb/')

# HACK: to avoid tb_logger crashing in self._get_next_version() when I access tb_logger.log_dir
os.makedirs('logs/tb/default', exist_ok=True)

mcp = ModelCheckpoint(filepath=f'{tb_logger.log_dir}/' + '{epoch}_vl_{val_loss:.2f}')
trainer = Trainer(logger=tb_logger, checkpoint_callback=mcp)
...
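One way to make that workaround less brittle (a sketch; `'default'` is TensorBoardLogger's default experiment name, and the exact crash behavior may differ across Lightning versions) is to build the directory from the same values the logger was constructed with, instead of a hard-coded literal path:

```python
import os

save_dir = 'logs/tb/'
name = 'default'  # TensorBoardLogger's default experiment name

# Pre-create the directory that _get_next_version() scans, so that
# accessing tb_logger.log_dir before training does not crash.
os.makedirs(os.path.join(save_dir, name), exist_ok=True)
```

If you later change `save_dir` or pass a custom `name` to the logger, the `makedirs` call stays in sync automatically.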

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

