Apex: Apply apex to an already trained normal model and continue the training

Created on 8 Apr 2019 · 23 Comments · Source: NVIDIA/apex

Can I train a PyTorch model for, let's say, 4 epochs in the normal fashion while saving the model, and then continue training with amp applied?

checkpointing

All 23 comments

I don't know the best way to do this off the top of my head, but I'm taking a detailed look at amp checkpointing (saving/restoring models) this week and next week, and I'll include this use case on my list of things to support.

I'm pretty sure that if you load existing FP32 weights into your model before AMPing it, it'll just work. I've done that successfully.
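
A minimal sketch of that ordering, assuming the FP32 run saved a plain model.state_dict(): the weights are loaded first and amp.initialize is called afterwards. The helper name resume_with_amp and the amp_module parameter are hypothetical; amp_module is meant to be apex.amp, passed in as an argument only so the sketch has no hard dependency on apex.

```python
def resume_with_amp(model, optimizer, fp32_state_dict, amp_module, opt_level="O1"):
    """Load FP32 weights first, then hand the model and optimizer to amp.

    `amp_module` is expected to be `apex.amp`. Passing it in is a choice made
    for this sketch, not something apex requires.
    """
    # 1. Restore the weights saved by the ordinary FP32 training run.
    model.load_state_dict(fp32_state_dict)
    # 2. Only now let amp set up the model and optimizer for mixed precision.
    model, optimizer = amp_module.initialize(model, optimizer, opt_level=opt_level)
    return model, optimizer


# Intended usage (needs torch, apex and a CUDA device; MyModel and the
# file name are made up for illustration):
#   from apex import amp
#   model = MyModel().cuda()
#   optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#   state = torch.load("fp32_epoch4.pt", map_location="cpu")
#   model, optimizer = resume_with_amp(model, optimizer, state, amp)
```

Only the model weights are restored here. Whether the FP32 optimizer state can safely be carried over as well is the open question raised in the next comment.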

I think that restoring the optimizer state can be problematic, even if it's before initializing AMP (memory fuzzy).

In NVIDIA/tacotron2, you seem to be able to go from FP16 -> full precision and back with no problem. OTOH I just ran into an issue where the first few updates wreck the previous loss, wiping out many epochs' worth of optimization. Seems like a checkpointing issue...

@mcarilli do you have any ETA on this? I really think this is a critical feature. As of now, I cannot use AMP as it gives me the following divergence issues after checkpoint reloading:

[image: plot showing the divergence after reloading a checkpoint]

In that case experiments ended up recovering, but sometimes they never reach the performance they were at during the checkpoint. For now I switched back to FP16_Optimizer because it does not have the same problem, but I miss the ability to accumulate gradients over several batches, which limits the performance I can get in some tasks.

Thanks!

I agree it's a critical feature. @ptrblck and I decided on a strategy then got preempted with a bunch of work debugging/optimizing some internal models. I'm hoping to implement it next week, barring further interrupts...I'll keep you updated if that changes.

@mcarilli any luck with the checkpointing?
Do you know if there is a temporary workaround I could use in the meantime?

Any progress?

@Rhuax, @rwightman, @apsears, @glample, @hadaev8
Checkpointing just got merged into our master branch.
Check out the README for an example of its usage.
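
For readers without the README at hand, here is a rough sketch of the save/restore pattern being described, under the assumption that amp exposes state_dict() and load_state_dict() as used later in this thread. The helper names, the amp_module parameter (meant to be apex.amp) and the dictionary keys are illustrative choices, not part of apex.

```python
def build_checkpoint(model, optimizer, amp_module):
    """Collect everything needed to resume a mixed-precision run."""
    return {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "amp": amp_module.state_dict(),  # loss scaler state
    }


def restore_checkpoint(checkpoint, model, optimizer, amp_module, opt_level="O1"):
    """Re-initialize amp first, then load all three state dicts."""
    # amp.initialize has to run before amp.load_state_dict, otherwise there
    # is no amp state to load into.
    model, optimizer = amp_module.initialize(model, optimizer, opt_level=opt_level)
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    amp_module.load_state_dict(checkpoint["amp"])
    return model, optimizer


# Intended usage (needs torch, apex and a CUDA device; file name is made up):
#   from apex import amp
#   torch.save(build_checkpoint(model, optimizer, amp), "amp_checkpoint.pt")
#   ...
#   checkpoint = torch.load("amp_checkpoint.pt")
#   model, optimizer = restore_checkpoint(checkpoint, model, optimizer, amp)
```

Restoring with the same opt_level that was used when saving is the safer choice in this sketch.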

@ptrblck
Could you give an example of how to load it inside a function?

import os
import torch
from apex import amp

def load_checkpoint(checkpoint_path, model, optimizer, warm_start_rus):
    assert os.path.isfile(checkpoint_path)
    print("Loading checkpoint '{}'".format(checkpoint_path))
    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint_dict['state_dict'])
    optimizer.load_state_dict(checkpoint_dict['optimizer'])
    amp.load_state_dict(checkpoint_dict['amp'])
    # Assumes the checkpoint also stores these two values under these keys.
    learning_rate = checkpoint_dict['learning_rate']
    iteration = checkpoint_dict['iteration']
    print("Loaded checkpoint '{}' from iteration {}".format(
        checkpoint_path, iteration))
    return model, optimizer, learning_rate, iteration

I need to pass amp inside, right?

@hadaev8 This should work if you've already called amp.initialize on your model and optimizer.

@ptrblck

AttributeError: 'AmpState' object has no attribute 'loss_scalers'

Any advice?

@hadaev8 Could you post a reproducible code snippet so that we can debug it?

My code is based on this repo:
https://github.com/NVIDIA/tacotron2
It's not easy to make a snippet.

Thanks for the update.
amp.state_dict() tries to return the loss_scaler(s), which aren't initialized in your code snippet.
Make sure to create a model and call amp.initialize() on it first.

import torch.nn as nn
from apex import amp

model = nn.Linear(1, 1).cuda()
model = amp.initialize(model, opt_level='O1')
amp.state_dict()

Anyway, we should improve the error message and raise a proper warning.

@ptrblck
I found my problem: I was trying to load the state dict before calling amp.initialize. Thanks.

Hey, @ptrblck! Thanks for following up with everyone in this thread. FYI, I get the same checkpointing problem with opt_level='O2', but it works with opt_level='O1'.

@giacaglia What kind of error are you seeing? Do you see a bump in the loss(es)?
If so, did you also restore the optimizer's state_dict as well as amp's?

Yes, I did see a big bump in the losses. I saved the optimizer's state_dict as well as amp's. The problem doesn't occur when I set opt_level to O1. I can try to make the problem reproducible if that is of any help

Sure! That would be helpful for debugging, although we recommend using O1. ;)

@ptrblck Can you confirm you recommend O1? Aren't the vast bulk of the speed improvements due to the O2 optimizations?

@daniel347x
Yes, we recommend O1 for the typical use case. You might of course compare different opt levels for your use case. :)

@ptrendx I get the same AttributeError: 'AmpState' object has no attribute 'loss_scalers' error as @hadaev8, but I am saving the state after a few batches. It may be relevant that I have the enabled parameter of amp.initialize set to False. Could you please initialize the loss_scalers with some dummy data in that case, so that the same checkpointing code works whether AMP is enabled or not? A workaround is to enable AMP but use O0.
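
Until something like that is supported, one possible user-side guard is sketched below. It assumes the failure always surfaces as the AttributeError quoted above when amp.state_dict() is called without initialized loss scalers. The helper name amp_state_or_none and the amp_module parameter (meant to be apex.amp) are hypothetical.

```python
def amp_state_or_none(amp_module):
    """Return amp's state dict, or None if amp has no loss scalers yet.

    Assumption: an uninitialized or disabled amp raises AttributeError
    ("'AmpState' object has no attribute 'loss_scalers'") from state_dict().
    """
    try:
        return amp_module.state_dict()
    except AttributeError:
        return None


# Intended usage (needs torch and apex):
#   checkpoint = {"model": model.state_dict(),
#                 "optimizer": optimizer.state_dict(),
#                 "amp": amp_state_or_none(amp)}
#   ...
#   if checkpoint.get("amp") is not None:
#       amp.load_state_dict(checkpoint["amp"])
```

This keeps a single checkpointing code path, at the cost of silently skipping the amp state when it is unavailable.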
