Pytorch-lightning: How to save checkpoints within lightning_logs?

Created on 22 Mar 2020  ·  3 Comments  ·  Source: PyTorchLightning/pytorch-lightning

I'm currently doing checkpointing as follows:

checkpoint_callback = pl.callbacks.ModelCheckpoint(
    filepath=os.path.join(os.getcwd(), 'checkpoints/{epoch}-{val_loss:.2f}'),
    verbose=True,
    monitor='val_loss',
    mode='min',
    save_top_k=-1,
    period=1
)

trainer = pl.Trainer(
    default_save_path=os.path.join(os.getcwd(), 'log_files_are_stored_here'),
    gpus=1,
    max_epochs=2,
    checkpoint_callback=checkpoint_callback
)

This creates the following folder structure:

├── checkpoints  # all the .pth files are saved here
└── log_files_are_stored_here
    └── lightning_logs
        ├── version_0
        ├── version_1
        └── version_2

How can I get the .pth files for each version to be saved in the respective version folders like so?:

└── log_files_are_stored_here
    └── lightning_logs
        ├── version_0
        │   └── checkpoints  # save the .pth files here
        ├── version_1
        │   └── checkpoints  # save the .pth files here
        └── version_2
            └── checkpoints  # save the .pth files here
Labels: question, won't fix


All 3 comments

Hi,

this is how I do it:

from pathlib import Path

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TestTubeLogger

tt_logger = TestTubeLogger(save_dir=str(log_dir / "tt_logs"), name=run_name)

checkpoint_dir = (
    Path(tt_logger.save_dir)
    / tt_logger.experiment.name
    / f"version_{tt_logger.experiment.version}"
    / "checkpoints"
)
filepath = checkpoint_dir / "{epoch}-{val_loss:.4f}"
checkpoint_cb = ModelCheckpoint(filepath=str(filepath))

trainer = pl.Trainer(
    logger=tt_logger,
    checkpoint_callback=checkpoint_cb,
    ...
)
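The snippet above reads TestTube-specific `experiment` attributes; the same path can be built for any logger that exposes a save dir, a name, and a version. A minimal sketch (the helper name and attribute fallbacks are my own assumptions, not a Lightning API):

```python
from pathlib import Path


def versioned_checkpoint_dir(logger) -> Path:
    """Build <save_dir>/<name>/version_<n>/checkpoints from a logger.

    Hypothetical helper: assumes the logger exposes `save_dir` plus a
    name and version, either directly or via `logger.experiment`
    (as TestTubeLogger does above).
    """
    exp = getattr(logger, "experiment", logger)
    name = getattr(logger, "name", None) or exp.name
    version = getattr(logger, "version", None)
    if version is None:
        version = exp.version
    return Path(logger.save_dir) / name / f"version_{version}" / "checkpoints"
```

Passing `str(versioned_checkpoint_dir(logger) / "{epoch}-{val_loss:.4f}")` as `filepath` then reproduces the layout asked for in the question.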

Hi, here is the TensorBoard version, inspired by @chris-clem's snippet.

Any idea how to get rid of the "HACK"?

import os

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

tb_logger = TensorBoardLogger(save_dir='logs/tb/')

# HACK: create the default log dir up front, so that accessing
# tb_logger.log_dir does not crash in self._get_next_version()
os.makedirs('logs/tb/default', exist_ok=True)

mcp = ModelCheckpoint(filepath=f'{tb_logger.log_dir}/' + '{epoch}_vl_{val_loss:.2f}')
trainer = Trainer(logger=tb_logger, checkpoint_callback=mcp)
...
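One way to avoid the HACK might be to compute the next version directory yourself, creating the root as you go, instead of touching `tb_logger.log_dir` before its directory exists. A standalone sketch of roughly what `_get_next_version()` does (the helper is hypothetical, not part of Lightning):

```python
import os


def next_version_dir(save_dir: str, name: str = 'default') -> str:
    """Return '<save_dir>/<name>/version_<n>' for the next unused n.

    Creating the root first sidesteps the crash the HACK above
    works around; scan logic mirrors the logger's version counting.
    """
    root = os.path.join(save_dir, name)
    os.makedirs(root, exist_ok=True)
    versions = [
        int(d.split('_')[1])
        for d in os.listdir(root)
        if d.startswith('version_') and d.split('_')[1].isdigit()
    ]
    return os.path.join(root, f'version_{max(versions, default=-1) + 1}')
```

The returned path can be fed straight into `ModelCheckpoint(filepath=...)`, so the logger's `log_dir` is never read before the directory exists.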

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
