ModelCheckpoint cannot save checkpoints whose filename template references a metric with a slash in its name. I use grouped metrics in TensorBoard and would like my checkpoint filenames to include my loss, which is logged as val/loss. However, ModelCheckpoint uses os.path.split, which also splits at the slash inside the metric name: https://github.com/PyTorchLightning/pytorch-lightning/blob/6ac0958166c66ed599c96737b587232b7a33d89e/pytorch_lightning/callbacks/model_checkpoint.py#L258
If I try to use
ModelCheckpoint("root/dir/{epoch}_{val/loss:.5f}")
The above evaluates to
self.dirpath = "root/dir/{epoch}_{val"
self.filename = "loss:.5f}"
This inevitably causes failure when attempting to format the output path.
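The failure can be reproduced with the standard library alone. This is a simplified sketch, not Lightning's code; it assumes the filename half is later filled in with str.format.

```python
import os

# The filename template from above, with a metric name containing "/"
template = "root/dir/{epoch}_{val/loss:.5f}"

# os.path.split cuts at the last "/", which here is inside the placeholder
dirpath, filename = os.path.split(template)
print(dirpath)   # root/dir/{epoch}_{val
print(filename)  # loss:.5f}

# The right-hand half is no longer a valid format string
try:
    filename.format(**{"epoch": 4, "val/loss": 0.03})
except ValueError as err:
    print("formatting failed:", err)  # Single '}' encountered in format string
```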
To reproduce: as above, log a metric with a slash, then use it in the ModelCheckpoint filename.
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

class Module(pl.LightningModule):
    ...
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = self.loss_fn(logits, y)
        # The slash groups the metric under "val" in TensorBoard
        self.log('val/loss', loss, on_epoch=True)
        return loss
    ...

def main():
    trainer = pl.Trainer(checkpoint_callback=ModelCheckpoint("{epoch}_{val/loss:.5f}"))
Expected behavior: the path should be split only along file path boundaries, ignoring slashes inside the placeholders that are yet to be formatted.
Per the previous example, we'd expect:
self.dirpath = "root/dir"
self.filename = "{epoch}_{val/loss:.5f}"
Hi! Thanks for your contribution, great first issue!
Cool, I was actually here to open this issue :)
I also encountered this problem. I think it's quite common to use the forward slash to group metrics in tensorboard, but this obviously clashes with this formatting.
I also tried to escape it, but with little success.
Edit: A negative look-ahead regex that only splits at slashes outside the braces should do the job (assuming there are no nested brackets), but it feels like a bit of overkill.
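A minimal sketch of that idea: split_checkpoint_path is a hypothetical helper, not part of Lightning, and it assumes placeholders are never nested. Unlike os.path.split, it keeps a slash that sits inside a {...} placeholder.

```python
import re
from typing import Tuple

# A "/" that can reach a "}" without crossing a "{" is inside a placeholder,
# so the negative look-ahead skips it. Assumes placeholders are not nested.
_SEP_OUTSIDE_BRACES = re.compile(r"/(?![^{]*\})")


def split_checkpoint_path(path: str) -> Tuple[str, str]:
    """Split path into (dirpath, filename), ignoring slashes inside {...}."""
    separators = list(_SEP_OUTSIDE_BRACES.finditer(path))
    if not separators:
        # No real path separator: the whole string is the filename
        return "", path
    last = separators[-1].start()
    return path[:last], path[last + 1:]


print(split_checkpoint_path("root/dir/{epoch}_{val/loss:.5f}"))
# ('root/dir', '{epoch}_{val/loss:.5f}')
```

For a template without a directory part it returns an empty dirpath; it does not try to mirror every os.path.split edge case (such as a bare leading "/").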
Thanks for raising the issue! Would you like to submit a PR for the fix?
@ozen Am I correct that the issue now is that checkpoint names automatically include the metric name in addition to the value?
Could we make that configurable, so I can enable/disable automatic name insertion?
@its-dron exactly. Making it configurable could be ugly though. Maybe the slashes can be automatically converted to something like underscores.
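A rough sketch of the underscore idea, assuming the metric name is inserted in front of each placeholder. insert_metric_names and its replacement parameter are hypothetical, not Lightning API. Only the inserted name is sanitized; the placeholder keeps the original key, so the value is still looked up under val/loss.

```python
import re

# {name} or {name:spec}; assumes no nested or escaped braces
_PLACEHOLDER = re.compile(r"\{([^{}:]+)(?::[^{}]*)?\}")


def insert_metric_names(template: str, replacement: str = "_") -> str:
    """Prefix each placeholder with its metric name, replacing "/" in the
    inserted name only, so the formatted file name contains no separators."""

    def _prefix(match):
        name = match.group(1)
        return name.replace("/", replacement) + "=" + match.group(0)

    return _PLACEHOLDER.sub(_prefix, template)


template = insert_metric_names("{epoch}_{val/loss:.5f}")
print(template)  # epoch={epoch}_val_loss={val/loss:.5f}
print(template.format(**{"epoch": 4, "val/loss": 0.03}))  # epoch=4_val_loss=0.03000
```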
Personally, I'd prefer to not have the auto-inserted metric names at all. I'd prefer to have fine-grain control over what my files are named instead of having the library decide for me.
But that seems like a drastic change to drop at this point.
@rohitgr7 could you reopen this issue?
Is it still an issue, or is something not working as expected?
@rohitgr7 it's an issue. The way file name formatting is implemented in the ModelCheckpoint._format_checkpoint_name() method prevents using metrics that contain a slash.
Interesting, I actually added a test for this
https://github.com/PyTorchLightning/pytorch-lightning/blob/b459fd26ac773484e4c97c12e4bab221bb1609b0/tests/checkpointing/test_model_checkpoint.py#L225-L229
Not sure why it's not working. Can you post an example of what's not working on your side? I'll have a look :)
assert ckpt_name == 'epoch=4_val/loss=0.03000.ckpt' will pass.
But this won't create a checkpoint file named "epoch=4_val/loss=0.03000.ckpt"; it will create a directory named "epoch=4_val" and a file named "loss=0.03000.ckpt" inside that directory. If there is more than one metric with a slash, more nested directories are created.
This is the obvious result of using a slash in the name. But the way file name formatting works, i.e. automatically adding the metric name to the file name, leaves the user no way to prevent this.
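The directory behavior can be seen with the standard library. This is a sketch, assuming (as described above) that the saving code creates any missing parent directories before writing; the empty file stands in for a real checkpoint.

```python
import os
import tempfile

with tempfile.TemporaryDirectory() as root:
    # The formatted checkpoint name from the test above
    ckpt_path = os.path.join(root, "epoch=4_val/loss=0.03000.ckpt")
    # Create the parent directory, then write an empty placeholder file
    os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
    with open(ckpt_path, "wb"):
        pass
    top_level = os.listdir(root)
    nested = os.listdir(os.path.join(root, "epoch=4_val"))

print(top_level)  # ['epoch=4_val']
print(nested)     # ['loss=0.03000.ckpt']
```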
I'm still a little confused. What do you want your filename to look like? Don't you want to inject the metric values into it?
We're not talking about metric values. It injects metric names, which may contain a slash, and that creates directories instead of a file. I'm sorry for repeating this for the third time, but I don't know how else to explain it.
Yeah, I get the directory issue; AFAIK this is expected behavior when you have slashes in the file path.
Your statement "But the way file name formatting works, i.e. adding the metric name to the file name, leaves the user no way to prevent this" got me confused.
What do you suggest? How should this work if the filename contains slashes: raise a warning/exception, replace them with another character, or something else?
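I think there are two workable solutions: (1) automatically replace the slashes in the inserted metric names, or (2) make the automatic insertion of metric names optional.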
The latter could be as simple as (at this line):
if auto_insert_metric_name:
    filename = filename.replace(group, name + "={" + name)
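To show how such a flag would change the output, here is a simplified, hypothetical stand-in for the formatting logic. format_filename is not Lightning's function; only the replace line comes from the snippet above, and the placeholder regex is an assumption.

```python
import re
from typing import Dict


def format_filename(
    filename: str, metrics: Dict[str, float], auto_insert_metric_name: bool = True
) -> str:
    """Hypothetical, simplified stand-in for ModelCheckpoint's name formatting."""
    # Capture "{name" for each placeholder, up to its ":" format spec or "}"
    groups = re.findall(r"(\{.*?)[:}]", filename)
    for group in groups:
        name = group[1:]
        if auto_insert_metric_name:
            # "{val/loss:.5f}" -> "val/loss={val/loss:.5f}"
            filename = filename.replace(group, name + "={" + name)
    return filename.format(**metrics)


metrics = {"epoch": 4, "val/loss": 0.03}
print(format_filename("{epoch}_{val/loss:.5f}", metrics))
# epoch=4_val/loss=0.03000
print(format_filename("{epoch}_{val/loss:.5f}", metrics, auto_insert_metric_name=False))
# 4_0.03000
```

With the flag off, the template alone decides the name, so a slash in a metric name no longer leaks into the file path.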
Feel free to send a PR with your suggestion!
I'm also interested in a solution to this problem.
I think f-strings handle this neatly: by default f'{metric}' just outputs the metric's value, while f'{metric=}' prints both the name and the value. But I imagine this would be an unacceptable breaking change at this point.
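For reference, the f-string behavior described above (Python 3.8+). Two caveats: a metric name like val/loss is not a valid Python identifier, so f'{val/loss=}' would be parsed as a division of two variables, and str.format templates have no "=" specifier at all, so ModelCheckpoint would have to emulate the syntax itself.

```python
# Requires Python 3.8+, where f-strings gained the "=" specifier
val_loss = 0.03

plain = f"{val_loss}"            # value only
named = f"{val_loss=}"           # name plus repr of the value
named_spec = f"{val_loss=:.5f}"  # "=" combined with a format spec
print(plain, named, named_spec)  # 0.03 val_loss=0.03 val_loss=0.03000

# In a str.format template the "=" just becomes part of the field name,
# so the keyword lookup fails
try:
    "{val_loss=}".format(val_loss=0.03)
except KeyError as err:
    print("KeyError:", err)
```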
@its-dron's option 1 would be fine for my purposes, since the existing syntax just works without any changes and I don't really care about the specific names of my checkpoints anyway.
As a general solution, option 2 sounds better as it gives the freedom to completely customize the file names. Are you suggesting a single argument on the constructor of ModelCheckpoint to disable all metric names? Or do you have something in mind to specify it metric by metric?
@its-dron
I would prefer option 2.
@EliaCereda
A single argument (e.g. auto_insert_metric_name, prepend_metric_names or filename_with_names) on the constructor of ModelCheckpoint sounds good.
If that argument is false, the user should be able to choose between f'{metric}' and f'{metric=}', which is actually quite nice, and could easily be specified metric by metric.
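A minimal sketch of how per-metric f'{metric=}'-style syntax could be emulated in a str.format template. expand_name_specifiers is hypothetical and not part of Lightning; it assumes no nested or escaped braces. Fields written as {name=} get a name= prefix, while plain {name} fields stay value-only.

```python
import re

# {name}, {name=}, {name:spec} or {name=:spec}; assumes no nested braces
_FIELD = re.compile(r"\{([^{}:=]+)(=?)(:[^{}]*)?\}")


def expand_name_specifiers(template: str) -> str:
    """Hypothetical: rewrite "{name=...}" fields into plain str.format
    fields prefixed with "name=", leaving "{name}" fields value-only."""

    def _expand(match):
        name, wants_name, spec = match.group(1), match.group(2), match.group(3) or ""
        field = "{" + name + spec + "}"
        return name + "=" + field if wants_name else field

    return _FIELD.sub(_expand, template)


template = expand_name_specifiers("{epoch=}_{val/loss:.5f}")
print(template)  # epoch={epoch}_{val/loss:.5f}
print(template.format(**{"epoch": 4, "val/loss": 0.03}))  # epoch=4_0.03000
```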
BTW, slashes aren't just common; PL itself introduces them in metric names when using multiple dataloaders.
I can confirm this bug. I also use TensorBoard for logging and therefore call self.log('val/accuracy', val_acc) at the end of my validation_epoch_end.
I use these parameters for ModelCheckpoint
save_top_k: 3
monitor: val/accuracy
dirpath: saved_models/
filename: '{epoch}_{val/accuracy:.4f}'
A directory called epoch=0_val is created, with a checkpoint named accuracy=0.0000.ckpt inside it.
I would like the checkpoint to be named epoch=0_val_accuracy=0.0000.ckpt and placed inside the specified dirpath. How can I solve this? I am using Lightning 1.0.5.