When running a model using apex + DDP, my loss spikes dramatically after the model restarts from a checkpoint.
If I disable apex, it works fine.
Currently, I've set up apex this way:
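```python
import torch
from apex import amp
from torch.nn.parallel import DistributedDataParallel as DDP

# model, gpu_nb and path are defined elsewhere
optimizer = torch.optim.Adam(model.parameters())
LR = ...  # LR scheduler wrapping optimizer
torch.cuda.set_device(gpu_nb)
model.cuda(gpu_nb)
# apex
model, optimizer = amp.initialize(model, optimizer, opt_level='O2')
# ddp
model = DDP(model)
# restore state
ckpt = torch.load(path)
model.load_state_dict(ckpt['state_dict'])
optimizer.load_state_dict(ckpt['opt_dict'])
LR.load_state_dict(ckpt['lr_dict'])
amp.load_state_dict(ckpt['amp'])
# continue ....
```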
Actual code is here:

@ptrblck
Thanks for the code @williamFalcon!
We'll try to reproduce and debug it.
Do you see this loss spike only using LightningDistributedDataParallel or also PyTorch's DDP?
Lightning's DDP is PyTorch's DDP, except that it routes the forward call to training_step or validation_step. Otherwise it's the same.
You can replicate this by doing the following:
```python
model = MNISTModel()
trainer = Trainer(gpus=[0, 1], use_amp=True)
trainer.fit(model)
```
Run the above script using a SLURM script on a node with 2 GPUs (I used 2 V100s with 32 GB each). Set the walltime to 10 minutes (so the loss has time to go down). At 7 minutes it'll resubmit itself and you'll see the problem.
MNIST might be too trivial, though. I'd try CIFAR-10 instead.
Here's a full working example to replicate it. Make sure to install Lightning from master:
```python
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100
import torchvision.transforms as transforms
from test_tube import Experiment, SlurmCluster, HyperOptArgumentParser
import numpy as np
import pytorch_lightning as pl

PORT = np.random.randint(12000, 20000, 1)[0]
SEED = 2334
torch.manual_seed(SEED)
np.random.seed(SEED)

"""
To run in interactive node:
python issue.py

To run as a cluster job (submits job to cluster):
python issue.py --cluster
"""


class CIFAR100LM(pl.LightningModule):

    def __init__(self, save_path):
        super(CIFAR100LM, self).__init__()
        self.save_path = save_path
        self.l1 = torch.nn.Linear(32 * 32 * 3, 1028)
        self.l2 = torch.nn.Linear(1028, 2048)
        self.l3 = torch.nn.Linear(2048, 100)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        x = self.l3(x)
        return x

    def training_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self.forward(x)
        return {'loss': F.cross_entropy(y_hat, y)}

    def validation_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self.forward(x)
        return {'val_loss': F.cross_entropy(y_hat, y)}

    def validation_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'avg_val_loss': avg_loss}

    def configure_optimizers(self):
        # can return multiple optimizers and learning_rate schedulers
        return torch.optim.Adam(self.parameters(), lr=0.0002)

    @pl.data_loader
    def tng_dataloader(self):
        return DataLoader(CIFAR100(self.save_path, train=True, download=True,
                                   transform=transforms.ToTensor()), batch_size=32)

    @pl.data_loader
    def val_dataloader(self):
        return DataLoader(CIFAR100(self.save_path, train=True, download=True,
                                   transform=transforms.ToTensor()), batch_size=32)


def main(*args, **kwargs):
    log_path = '/log/path'
    exp = Experiment(
        name='apex_bug_test',
        save_dir=log_path,
        autosave=False,
        description='amp test'
    )
    exp.save()

    data_path = '/data/CIFAR100'
    model = CIFAR100LM(data_path)
    trainer = pl.Trainer(
        gpus=[0, 1],
        use_amp=True,
        experiment=exp,
        min_nb_epochs=200,
        distributed_backend='ddp'
    )
    trainer.fit(model)


def launch_cluster_job(args):
    # enable cluster training
    slurm_log_path = '/log/path'
    cluster = SlurmCluster(
        log_path=slurm_log_path,
        hyperparam_optimizer=args
    )

    # email for cluster coms (placeholder: put your own address here)
    cluster.notify_job_status(email='your_email@example.com', on_done=True, on_fail=True)

    # configure cluster
    cluster.per_experiment_nb_nodes = 1
    cluster.per_experiment_nb_gpus = 2
    cluster.job_time = '00:10:00'
    cluster.memory_mb_per_node = 0
    cluster.per_experiment_nb_cpus = 2

    # any modules for code to run in env
    cluster.add_command('source activate your_conda_env')
    cluster.add_command('export NCCL_SOCKET_IFNAME=^docker0,lo')
    cluster.add_command('export NCCL_DEBUG=INFO')
    cluster.add_command('export PYTHONFAULTHANDLER=1')
    cluster.add_command(f'export MASTER_PORT={PORT}')
    cluster.load_modules(['NCCL/2.4.7-1-cuda.10.0'])
    cluster.python_cmd = 'python'
    cluster.add_slurm_cmd(cmd='constraint', value='volta32gb', comment='use 32gb gpus')
    cluster.add_slurm_cmd(cmd='ntasks-per-node', value=2, comment='1 task per gpu')

    # name of exp
    job_display_name = 'apex_bug_test'

    # run hopt
    print('submitting jobs...')
    cluster.optimize_parallel_cluster_gpu(
        main,
        nb_trials=1,
        job_name=job_display_name
    )


if __name__ == '__main__':
    parser = HyperOptArgumentParser()
    parser.add_argument('--cluster', dest='cluster', action='store_true')
    args = parser.parse_args()

    if args.cluster:
        launch_cluster_job(args)
    else:
        main()
```

This is from the above example: orange is before checkpointing, blue is after resuming training.
OK, digging into this more...
Here is the loss after 3 reloads (each color is a different reload).

However, when tracking the training accuracy, I see that it stays high after each reload, even while the loss spikes.

This suggests that the model loads correctly, but the loss scaling is off (which has to do with amp.load_state_dict()).
In this simple model it's not a problem, but in more complex models whose losses are sensitive to scaling, the losses turn into NaNs after the model restarts.
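To make concrete what "the scaling" refers to, here is a minimal pure-Python sketch of dynamic loss scaling: the loss is multiplied by a scale before backward, the scale is halved (and the step skipped) when gradients overflow, and it grows again after a run of clean steps. The current scale and the step counter are the kind of state a resumed run needs back, which is what amp.load_state_dict() restores. This is not apex's implementation; the class name, method names and defaults below are assumptions for illustration.

```python
class DynamicLossScaler:
    """Hypothetical minimal dynamic loss scaler (illustration only, not apex's API)."""

    def __init__(self, init_scale=2.0 ** 16, backoff=0.5, growth=2.0, growth_interval=2000):
        self.scale = init_scale
        self.backoff = backoff
        self.growth = growth
        self.growth_interval = growth_interval
        self.good_steps = 0  # steps since the last overflow

    def update(self, found_overflow: bool) -> bool:
        """Adjust the scale after backward; return True if the optimizer step should run."""
        if found_overflow:
            # inf/NaN in the scaled gradients: skip this step and shrink the scale
            self.scale *= self.backoff
            self.good_steps = 0
            return False
        self.good_steps += 1
        if self.good_steps >= self.growth_interval:
            # a long run without overflow: try a larger scale
            self.scale *= self.growth
            self.good_steps = 0
        return True

    def state_dict(self) -> dict:
        return {"scale": self.scale, "good_steps": self.good_steps}

    def load_state_dict(self, state: dict) -> None:
        self.scale = state["scale"]
        self.good_steps = state["good_steps"]


scaler = DynamicLossScaler(init_scale=8.0, growth_interval=2)
scaler.update(found_overflow=True)   # scale 8.0 -> 4.0, step skipped
scaler.update(found_overflow=False)
scaler.update(found_overflow=False)  # two clean steps: scale 4.0 -> 8.0

resumed = DynamicLossScaler()
resumed.load_state_dict(scaler.state_dict())  # continue from the saved scale
```

In this sketch, a run that resumed without load_state_dict would start again from init_scale and might skip a few steps while backing off; a skipped step leaves the weights unchanged.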
@williamFalcon, @ptrblck did you figure this out? We are having the same problem.
It turns out the problem is related to reloading the optimizer. When you call amp.initialize, optimizer.state needs to be an empty dictionary; otherwise this problem occurs. As a workaround, we empty optimizer.state before amp.initialize and use a short warmup for the Adam optimizer to recover its moving averages (thanks to @yaroslavvb for the warmup suggestion).
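Why a warmup helps: Adam creates its per-parameter state lazily on the first step, so emptying optimizer.state throws away the moving averages. The single-scalar Adam below is a simplified pure-Python illustration of the update rule torch.optim.Adam uses (no weight decay, no amsgrad); the class name is hypothetical. It shows that with fresh state, the bias-corrected first update moves a parameter by roughly lr, whatever the size of its gradient.

```python
import math


class ScalarAdam:
    """Hypothetical single-scalar Adam mirroring torch.optim.Adam's update (illustration only)."""

    def __init__(self, lr=2e-4, betas=(0.9, 0.999), eps=1e-8):
        self.lr = lr
        self.betas = betas
        self.eps = eps
        self.state = {}  # empty until the first step, like a freshly created torch optimizer

    def step(self, name: str, param: float, grad: float) -> float:
        """Return the updated parameter, creating its state lazily on first use."""
        st = self.state.setdefault(name, {"step": 0, "exp_avg": 0.0, "exp_avg_sq": 0.0})
        beta1, beta2 = self.betas
        st["step"] += 1
        st["exp_avg"] = beta1 * st["exp_avg"] + (1 - beta1) * grad
        st["exp_avg_sq"] = beta2 * st["exp_avg_sq"] + (1 - beta2) * grad * grad
        # bias correction: at step 1, m_hat == grad and v_hat == grad**2
        m_hat = st["exp_avg"] / (1 - beta1 ** st["step"])
        v_hat = st["exp_avg_sq"] / (1 - beta2 ** st["step"])
        return param - self.lr * m_hat / (math.sqrt(v_hat) + self.eps)


small = ScalarAdam(lr=0.1).step("w", 0.0, 1e-3)  # tiny gradient, fresh state
large = ScalarAdam(lr=0.1).step("w", 0.0, 10.0)  # huge gradient, fresh state
print(small, large)  # both updates are close to -lr
```

With restored state, the same gradient would instead be divided by the long-run RMS of past gradients, so after a reset every weight briefly takes near-full-lr steps until the averages rebuild; ramping the LR up keeps those first steps small.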
@ibeltagy Could you please explain how you're warming up the optimizer? All that is coming to mind for me is calling optim.step(), but that would actually apply the gradients to the model, which seems like exactly what I don't want to happen. Are you saving the original state of your model, stepping the optimizer a few times, and then restoring the model back to the original?
I'm dealing with a model that explodes after the first step upon reload, so aside from reloading and training with opt level O0, I'm not sure what to do.
I mean slowly increasing the learning rate from zero to the value you want.
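A minimal sketch of such a linear warmup, assuming a hypothetical warmup length: the function returns a multiplier on the base learning rate, which is the form torch.optim.lr_scheduler.LambdaLR expects for its lr_lambda argument.

```python
def linear_warmup(step: int, warmup_steps: int) -> float:
    """Multiplier on the base LR: ramps linearly from 1/warmup_steps up to 1.0."""
    if warmup_steps <= 0:
        return 1.0  # no warmup requested
    return min(1.0, (step + 1) / warmup_steps)


base_lr = 2e-4    # same LR as configure_optimizers in the example above
warmup_steps = 4  # hypothetical; a real run would likely use hundreds of steps
schedule = [base_lr * linear_warmup(s, warmup_steps) for s in range(6)]
print(schedule)
```

In PyTorch this could be wrapped as torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda s: linear_warmup(s, 500)), calling scheduler.step() once per iteration; the 500 is again only a placeholder.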
We also encountered the same issue with PyTorch's DDP. @ptrblck The loss becomes quite large when we reload the checkpoint.
Same issue here. The loss spikes after loading a checkpoint. Not loading the optimizer helps sometimes, but not always. NVIDIA needs to fix this ASAP.
I don't see quick responses from them on GitHub.
Is it time to look for an alternative to apex?
I found an ugly hack, but it seems to work. It goes like this:
```python
import collections

import torch
from apex import amp

# model, f_opt, f_amp, opt_level and random_input are defined elsewhere
# load optimizer from file
optimizer = torch.optim.AdamW(...)
optimizer_state_dict = torch.load(f_opt)
optimizer.load_state_dict(optimizer_state_dict)
# then remove optimizer state to make amp happy
# (an empty defaultdict(dict) is what a fresh torch optimizer's .state looks like)
optimizer.state = collections.defaultdict(dict)
# init amp and load it from file
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
amp.load_state_dict(torch.load(f_amp))
# forward, backward, optimizer step, zero_grad
# (note: this dummy step really does update the model weights once)
loss = model(random_input)
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
model.zero_grad()
# then load optimizer state_dict again (this time without removing optimizer.state)
optimizer.load_state_dict(optimizer_state_dict)
```
So amp expects one optimizer step before the checkpoint can be loaded smoothly?
Has anyone looked into the source code to find the root cause?
No, amp.initialize expects an optimizer with an empty optimizer.state. On the first optimizer.step, the optimizer initializes its optimizer.state, and that state gets registered with apex somehow. At that point, replacing optimizer.state with the state from the checkpoint seems to work.
Unfortunately, I cannot get the workaround to work. I may have to disable fp16 entirely, unless I find that the loss spike does no lasting harm to the model.
I am grateful for fp16, but this does seem like a nearly show-stopping issue! It should be possible to restart from a checkpoint and continue where you left off, right? I'm a bit surprised this isn't a bigger issue. Are folks not using fp16?
I don't remember how we solved this, but we did in PyTorch Lightning. You could try running it there with fp16 enabled.
@williamFalcon, I couldn't find the relevant code (it's not here: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/trainer/training_io.py#L374), and I also didn't find a call to amp.load_state_dict (which is needed to restore the loss scale).
@daniel347x FWIW, I've found the severity of the bug to be somewhat task dependent. For some tasks, the loss spike disappears quickly, and it's indiscernible whether or not it had a lasting effect on convergence. For others, it will legitimately ruin the model.
My current best rule of thumb is that classification tasks are most affected: the more possible classes there are, the worse the spike affects the model.
Could this be because the model params are actually FP16 when they are saved? After loading, the optimizer gets the FP16 params as FP32, which causes a loss in precision?
I have also encountered the same problem, and the loss diverged at opt level O2. But I found @ibeltagy's hack useful.
I suspect the problem is that scale_loss and amp.initialize don't correctly handle an optimizer whose state has already been loaded. Maybe a simple fix can be made there.
@ptrblck Any updates on this?
We're now using native AMP with PyTorch 1.6+ in PyTorch Lightning. I would just switch to that.
Adding optimizer.load_state_dict right before the first optimizer.step works for me too. (I basically added manual checkpoint loading in the optimizer_step of my LightningModule.)
For general PyTorch users, this is what I have:
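```python
import torch
from apex import amp

# optimizer_ckpt, optimizer and loss come from your training setup
state_dict = torch.load(optimizer_ckpt)
optimizer.load_state_dict(state_dict)
# ....
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
# load the state again right before the first optimizer.step() after resuming
optimizer.load_state_dict(state_dict)
optimizer.step()
optimizer.zero_grad()
```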
It seems the reason is that when you load the first time, the saved states are cast to fp16, because at that point the optimizer states are not properly initialized yet due to amp's lazy init. After the first call to amp.scale_loss, the optimizer states are properly initialized and recast to fp32, and the precision lost in the fp16 round trip causes the problem.
However, if I trigger amp_lazy_init earlier, the loss still spikes (though not as badly).
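The precision argument can be checked with the standard library alone. torch.optim.Optimizer.load_state_dict casts floating-point state tensors to the dtype of the parameters they belong to, so if those parameters are still the fp16 model weights (as described above, before amp's lazy init), Adam's exp_avg_sq is rounded to half precision. The sketch below round-trips values through IEEE 754 half precision using struct's 'e' format. The moment values are made up for illustration, the step formula omits bias correction, and none of this is apex code.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def adam_update(exp_avg: float, exp_avg_sq: float, lr: float = 2e-4, eps: float = 1e-8) -> float:
    """Size of one Adam update (bias correction omitted for simplicity)."""
    return lr * exp_avg / (math.sqrt(exp_avg_sq) + eps)


# Made-up moments for a weight whose gradients are around 1e-5.
exp_avg = 1e-5
exp_avg_sq = 1e-10

fp32_step = adam_update(exp_avg, exp_avg_sq)
# Only exp_avg_sq is rounded here, to isolate the effect: it underflows to 0.0.
fp16_step = adam_update(exp_avg, to_fp16(exp_avg_sq))
print(f"fp32 state: {fp32_step:.3e}, fp16-rounded state: {fp16_step:.3e}")
```

Half precision cannot represent magnitudes below about 3e-8 (they round to zero), so a second-moment estimate of 1e-10 vanishes and the denominator collapses to eps. For this weight, the update grows roughly a thousandfold, which is consistent with a loss spike right after resuming.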