Pytorch-lightning: Checkpoints based on validation_step or validation_epoch_end

Created on 28 Sep 2020  ·  4 comments  ·  Source: PyTorchLightning/pytorch-lightning

Somewhere I found an example along these lines:

def validation_step(self, batch, batch_idx):
    ....
    return {'val_loss': loss, ....}

def validation_epoch_end(self, outputs):
    # `outputs` is the list of dicts returned by validation_step
    avg_val_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
    ....
    return {'val_loss': avg_val_loss, ....}

What does the automatic checkpointing use to decide whether it has found a better checkpoint?

My average val loss is getting better, but I am not getting a new checkpoint (the green line is run 292).

To avoid ambiguity, it would be nice to change the name. Where are all the places I would have to change the name 'val_loss' if I were to make it 'avg_val_loss'?
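
For what it's worth, here is a minimal sketch of the rename, assuming the 2020-era API where the averaged value is returned from validation_epoch_end and an explicit ModelCheckpoint is handed to the Trainer via checkpoint_callback (the exact arguments are my assumption, not something confirmed in this thread):

    import torch
    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import ModelCheckpoint

    class MyModel(pl.LightningModule):
        # ... forward, training_step, configure_optimizers, dataloaders ...

        def validation_step(self, batch, batch_idx):
            loss = ...  # compute the per-batch validation loss
            return {'val_loss': loss}

        def validation_epoch_end(self, outputs):
            # Average over the whole validation epoch and expose it under the new key.
            avg_val_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
            return {'avg_val_loss': avg_val_loss, 'log': {'avg_val_loss': avg_val_loss}}

    # Monitor the renamed key instead of the default 'val_loss'.
    checkpoint_callback = ModelCheckpoint(monitor='avg_val_loss', mode='min', verbose=True)
    trainer = pl.Trainer(checkpoint_callback=checkpoint_callback)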

btw, Lightning is amazing! I made excellent progress on a monster transformer model and never had to worry about figuring out checkpointing, DDP, multi-GPU, etc.

Labels: question, won't fix

All 4 comments

My training loss is not averaged, and it's very erratic.

I suspect the checkpoint is based on the non-averaged val loss. Since it isn't logged, I can only speculate that it is just as erratic as the training loss, so I would like the checkpoint to be based on the average val loss.

Given that my model takes days to train, is there any way to change this to the average validation loss without restarting everything?
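
Not confirmed by anyone in this thread, but one way might be to resume from the most recent checkpoint with a callback that monitors the averaged metric; the checkpoint path below is hypothetical, and resume_from_checkpoint / checkpoint_callback are assumptions based on the Trainer API of that period:

    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import ModelCheckpoint

    # Hypothetical path to the latest checkpoint of the long-running job.
    last_ckpt = 'lightning_logs/version_292/checkpoints/epoch=41.ckpt'

    # Assumes validation_epoch_end now returns/logs 'avg_val_loss'.
    checkpoint_callback = ModelCheckpoint(monitor='avg_val_loss', mode='min', verbose=True)

    trainer = pl.Trainer(
        resume_from_checkpoint=last_ckpt,        # restores weights, optimizer state and epoch count
        checkpoint_callback=checkpoint_callback,
    )
    # trainer.fit(model)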

If I use

    checkpoint_callback = ModelCheckpoint(
        filepath='/path/to/store/weights.ckpt',
        save_best_only=True,
        verbose=True,
        monitor='val_loss',
        mode='min'
    )

how do I specify the same path that Lightning uses by default (i.e. f"/lightning_logs/version_{num}/epoch_{epoch}")?
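
One possibility (an assumption on my part, not an answer from the maintainers) is to create the logger first and reuse its versioned directory for the checkpoint files; TensorBoardLogger.log_dir and the '{epoch}' placeholder in filepath are what I believe the API of that release provided:

    import os
    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import ModelCheckpoint
    from pytorch_lightning.loggers import TensorBoardLogger

    # Creating the logger up front makes the version_{num} directory known.
    logger = TensorBoardLogger(save_dir=os.getcwd(), name='lightning_logs')

    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(logger.log_dir, '{epoch}'),  # e.g. lightning_logs/version_0/epoch=3.ckpt
        monitor='val_loss',
        mode='min',
        verbose=True,
    )

    trainer = pl.Trainer(logger=logger, checkpoint_callback=checkpoint_callback)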

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!
