Apex: amp + checkpoint loading = problems

Created on 6 Mar 2019 · 41 Comments · Source: NVIDIA/apex

Hi,
as you know, I have been experimenting with amp for a while now. Today I stumbled upon some very unexpected behavior. My FP16 models (trained with amp) do just as well as the FP32 models by themselves. But I usually also ensemble my models by doing something like this:

results = []
for checkpoint in checkpoints:
    network.load(checkpoint)
    results.append(network(data))
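
A slightly more concrete sketch of that loop in plain PyTorch, without amp. The helper name ensemble_predict and the choice to average softmax probabilities are assumptions made for illustration; the snippet above only collects the raw outputs.

```python
import torch


def ensemble_predict(network, state_dicts, data):
    """Average the softmax outputs of one network over several sets of weights.

    `state_dicts` is an iterable of already-loaded state dicts
    (for example the results of torch.load on each checkpoint file).
    """
    network.eval()
    probabilities = []
    with torch.no_grad():
        for state_dict in state_dicts:
            # Overwrite parameters and buffers in place, then predict.
            network.load_state_dict(state_dict)
            probabilities.append(torch.softmax(network(data), dim=1))
    # Stack to (num_checkpoints, batch, classes) and average over checkpoints.
    return torch.stack(probabilities).mean(dim=0)
```

Reusing one network object and swapping weights in place keeps memory usage flat, which is the same pattern the repro script below follows.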

Interestingly, the performance drops quite a bit if I am doing that with amp enabled. To illustrate this, I created a minimalistic example with mnist:

from copy import deepcopy
import torch
import matplotlib
matplotlib.use("agg")
from torch.backends import cudnn
from apex import amp
import argparse
from torch import cuda
from torch import nn
from urllib import request
import gzip
import pickle
import os
import numpy as np


def load(mnist_file):
    init()
    with open(mnist_file, 'rb') as f:
        mnist = pickle.load(f)
    data_tr = mnist["training_images"].reshape(60000, 1, 28, 28)
    data_te = mnist["test_images"].reshape(10000, 1, 28, 28)
    return data_tr, mnist["training_labels"], data_te, mnist["test_labels"]


filename = [
["training_images","train-images-idx3-ubyte.gz"],
["test_images","t10k-images-idx3-ubyte.gz"],
["training_labels","train-labels-idx1-ubyte.gz"],
["test_labels","t10k-labels-idx1-ubyte.gz"]
]


def download_mnist():
    base_url = "http://yann.lecun.com/exdb/mnist/"
    for name in filename:
        print("Downloading "+name[1]+"...")
        request.urlretrieve(base_url+name[1], name[1])
    print("Download complete.")


def save_mnist():
    mnist = {}
    for name in filename[:2]:
        with gzip.open(name[1], 'rb') as f:
            mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1,28*28)
    for name in filename[-2:]:
        with gzip.open(name[1], 'rb') as f:
            mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=8)
    with open("mnist.pkl", 'wb') as f:
        pickle.dump(mnist,f)
    print("Save complete.")


def init():
    if not os.path.isfile("mnist.pkl"):
        download_mnist()
        save_mnist()


def poly_lr(epoch, max_epochs, initial_lr, exponent=0.9):
    return initial_lr * (1 - epoch / max_epochs)**exponent


class GlobalAveragePool(nn.Module):
    def forward(self, x):
        axes = range(2, len(x.shape))
        for a in axes[::-1]:
            x = x.mean(a, keepdim=False)
        return x


def get_default_network_config():
    """
    returns a dictionary that contains pointers to conv, nonlin and norm ops and the default kwargs I like to use
    :return:
    """
    props = {}
    props['conv_op'] = nn.Conv2d
    props['conv_op_kwargs'] = {'stride': 1, 'dilation': 1, 'bias': True} # kernel size will be set by network!
    props['nonlin'] = nn.LeakyReLU
    props['nonlin_kwargs'] = {'negative_slope': 1e-2, 'inplace': True}
    props['norm_op'] = nn.BatchNorm2d
    props['norm_op_kwargs'] = {'eps': 1e-5, 'affine': True, 'momentum': 0.1}
    props['dropout_op'] = nn.Dropout2d
    props['dropout_op_kwargs'] = {'p': 0.0, 'inplace': True}
    return props


class ConvDropoutNormReLU(nn.Module):
    def __init__(self, input_channels, output_channels, kernel_size, network_props):
        """
        if network_props['dropout_op'] is None then no dropout
        if network_props['norm_op'] is None then no norm
        :param input_channels:
        :param output_channels:
        :param kernel_size:
        :param network_props:
        """
        super(ConvDropoutNormReLU, self).__init__()

        network_props = deepcopy(network_props)  # network_props is a dict and mutable, so we deepcopy to be safe.

        self.conv = network_props['conv_op'](input_channels, output_channels, kernel_size,
                                             padding=[(i - 1) // 2 for i in kernel_size],
                                             **network_props['conv_op_kwargs'])

        # maybe dropout
        if network_props['dropout_op'] is not None:
            self.do = network_props['dropout_op'](**network_props['dropout_op_kwargs'])
        else:
            self.do = lambda x: x

        if network_props['norm_op'] is not None:
            self.norm = network_props['norm_op'](output_channels, **network_props['norm_op_kwargs'])
        else:
            self.norm = lambda x: x

        self.nonlin = network_props['nonlin'](**network_props['nonlin_kwargs'])

        self.all = nn.Sequential(self.conv, self.do, self.norm, self.nonlin)

    def forward(self, x):
        return self.all(x)


class StackedConvLayers(nn.Module):
    def __init__(self, input_channels, output_channels, kernel_size, network_props, num_convs, first_stride=None):
        """
        if network_props['dropout_op'] is None then no dropout
        if network_props['norm_op'] is None then no norm
        :param input_channels:
        :param output_channels:
        :param kernel_size:
        :param network_props:
        """
        super(StackedConvLayers, self).__init__()

        network_props = deepcopy(network_props)  # network_props is a dict and mutable, so we deepcopy to be safe.
        network_props_first = deepcopy(network_props)

        if first_stride is not None:
            network_props_first['conv_op_kwargs']['stride'] = first_stride

        self.convs = nn.Sequential(
            ConvDropoutNormReLU(input_channels, output_channels, kernel_size, network_props_first),
            *[ConvDropoutNormReLU(output_channels, output_channels, kernel_size, network_props) for _ in range(num_convs - 1)]
        )

    def forward(self, x):
        return self.convs(x)


class SimpleNetwork(nn.Module):
    def __init__(self, props=None):
        super(SimpleNetwork, self).__init__()
        if props is None:
            props = get_default_network_config()
        self.stage1 = StackedConvLayers(1, 16, (3, 3), props, 2, 1)
        self.stage2 = StackedConvLayers(16, 32, (3, 3), props, 2, 2)
        self.stage3 = StackedConvLayers(32, 64, (3, 3), props, 3, 2)
        self.stage4 = StackedConvLayers(64, 128, (3, 3), props, 3, 2)
        self.pool = GlobalAveragePool()
        self.fc = nn.Linear(128, 10, False)

    def forward(self, x):
        return self.fc(self.pool(self.stage4(self.stage3(self.stage2(self.stage1(x))))))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, required=False, default=None)
    parser.add_argument("--test_only", action="store_true", default=False)
    parser.add_argument("-s", help="output filename for trained model")
    parser.add_argument("-test_fnames", required=False, nargs='+')

    args = parser.parse_args()
    seed = args.seed
    test_only = args.test_only

    # seeding
    np.random.seed(seed)
    cuda.manual_seed(np.random.randint(10000))
    cuda.manual_seed_all(np.random.randint(10000))
    cudnn.deterministic = True
    cudnn.benchmark = False

    amp_handle = amp.init()

    data_tr, target_tr, data_te, target_te = load("mnist.pkl")

    data_tr = torch.from_numpy(data_tr).float().cuda()
    target_tr = torch.from_numpy(target_tr).long().cuda()
    data_te = torch.from_numpy(data_te).float().cuda()
    target_te = torch.from_numpy(target_te).long().cuda()

    network = SimpleNetwork().cuda()

    batch_size = 512

    if not test_only:
        optimizer = torch.optim.Adam(network.parameters(), 1e-3, amsgrad=True, weight_decay=1e-5)

        epochs = 30

        loss = torch.nn.CrossEntropyLoss()

        network.train()
        for epoch in range(epochs):
            print(epoch)
            optimizer.param_groups[0]['lr'] = poly_lr(epoch, epochs, 1e-3, 0.9)

            for _ in range(60000 // batch_size):
                optimizer.zero_grad()
                idxs = np.random.choice(60000, batch_size)
                data = data_tr[idxs]
                target = target_tr[idxs]

                out = network(data)

                l = loss(out, target)

                with amp_handle.scale_loss(l, optimizer) as scaled_loss:
                    scaled_loss.backward()

                optimizer.step()

        torch.save(network.state_dict(), args.s)

        with torch.no_grad():
            network.eval()
            out = network(data_te)

            _, amax = out.max(dim=1)
            acc = (amax == target_te).float().mean()
            print("accuracy on test: ", acc)
    else:
        if not isinstance(args.test_fnames, list):
            args.test_fnames = [args.test_fnames]

        for f in args.test_fnames:
            network.load_state_dict(torch.load(f, map_location=torch.device('cuda', torch.cuda.current_device())))

            with torch.no_grad():
                network.eval()
                out = network(data_te)

                _, amax = out.max(dim=1)
                acc = (amax == target_te).float().mean()
                print("file", f, "accuracy on test: ", acc)

I just hacked this together, so please ignore any potential ugliness in the code.

Here is how you can reproduce the problem:
First, train the network several times and save to different output files:

python train_mnist.py --seed 1 -s mnist_seed1.model

accuracy on test: tensor(0.9959, device='cuda:0')

python train_mnist.py --seed 2 -s mnist_seed2.model

accuracy on test: tensor(0.9955, device='cuda:0')

python train_mnist.py --seed 3 -s mnist_seed3.model

accuracy on test: tensor(0.9949, device='cuda:0')

Now that you have the trained models, you can run the testing by passing the filenames to the script like this:
python train_mnist.py --test_only -test_fnames mnist_seed1.model

file mnist_seed1.model accuracy on test: tensor(0.9959, device='cuda:0')

python train_mnist.py --test_only -test_fnames mnist_seed2.model

file mnist_seed2.model accuracy on test: tensor(0.9955, device='cuda:0')

python train_mnist.py --test_only -test_fnames mnist_seed3.model

file mnist_seed3.model accuracy on test: tensor(0.9949, device='cuda:0')

The script also supports giving it several model checkpoints at once and it will test all of them one after the other. Although I am not ensembling here, this is the same procedure that I do in my ensembling code and the same issue appears here as well:
python train_mnist.py --test_only -test_fnames mnist_seed1.model mnist_seed2.model mnist_seed3.model

file mnist_seed1.model accuracy on test: tensor(0.9959, device='cuda:0')
file mnist_seed2.model accuracy on test: tensor(0.1135, device='cuda:0')
file mnist_seed3.model accuracy on test: tensor(0.1029, device='cuda:0')

If you look into the script (line 240+), it is doing nothing different than before, except loading new checkpoints with network.load_state_dict between test set predictions. We are seeing a big drop in performance from the second checkpoint onwards.

To demonstrate that this is not a problem with the files themselves, I ran it in a different order with the same result:
python train_mnist.py --test_only -test_fnames mnist_seed3.model mnist_seed1.model mnist_seed2.model

file mnist_seed3.model accuracy on test: tensor(0.9949, device='cuda:0')
file mnist_seed1.model accuracy on test: tensor(0.1036, device='cuda:0')
file mnist_seed2.model accuracy on test: tensor(0.1010, device='cuda:0')

I can fix this issue in this particular script by not initializing amp when I am running just the testing (replace amp_handle = amp.init() with

    if not test_only:
        amp_handle = amp.init()

). After replacing that, testing multiple checkpoints runs nicely:

python1 train_mnist.py --test_only -test_fnames mnist_seed1.model mnist_seed2.model mnist_seed3.model

file mnist_seed1.model accuracy on test: tensor(0.9959, device='cuda:0')
file mnist_seed2.model accuracy on test: tensor(0.9954, device='cuda:0')
file mnist_seed3.model accuracy on test: tensor(0.9949, device='cuda:0')

I am not sure what is going on here, but I think it is rather important to understand it. It took me a good 3 hours to finally figure out what was causing my severe performance regression today. Do you have any idea how this issue could be solved? I need to be able to load checkpoints during and after my trainings and rely on this to work :-)

Best,
Fabian

checkpointing

All 41 comments

I uploaded the checkpoints to my dropbox so that you don't have to train yourself:

https://www.dropbox.com/s/bm2tn8v0725ska7/mnist_checkpoints.zip?dl=0

Thank you for the thorough repro, this is definitely something we want to support and it's not immediately obvious what the problem is. I'll look into it over the next few days, I have a lot of work in flight at the moment.

In the meantime, I strongly recommend moving to the new API. It's more versatile, probably a bit faster, and will be the official API moving forward.

Hi,
thanks for the hint. I swapped to the new api (api_refactor branch) and reran my script:

python1 train_mnist.py --seed 4 -s mnist_seed4.model

Selected optimization level O1: Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
{'enabled': True, 'opt_level': 'O1', 'cast_model_type': None, 'patch_torch_functions': True, 'keep_batchnorm_fp32': None, 'master_weights': None, 'loss_scale': 'dynamic'}
enabled : True
opt_level : O1
cast_model_type : None
patch_torch_functions : True
keep_batchnorm_fp32 : None
master_weights : None
loss_scale : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled : True
opt_level : O1
cast_model_type : None
patch_torch_functions : True
keep_batchnorm_fp32 : None
master_weights : None
loss_scale : dynamic
accuracy on test: tensor(0.9957, device='cuda:0')

python train_mnist.py -test_fnames mnist_seed1.model mnist_seed2.model mnist_seed3.model mnist_seed4.model --test_only

Selected optimization level O1: Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
{'enabled': True, 'opt_level': 'O1', 'cast_model_type': None, 'patch_torch_functions': True, 'keep_batchnorm_fp32': None, 'master_weights': None, 'loss_scale': 'dynamic'}
enabled : True
opt_level : O1
cast_model_type : None
patch_torch_functions : True
keep_batchnorm_fp32 : None
master_weights : None
loss_scale : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled : True
opt_level : O1
cast_model_type : None
patch_torch_functions : True
keep_batchnorm_fp32 : None
master_weights : None
loss_scale : dynamic
file mnist_seed1.model accuracy on test: tensor(0.9959, device='cuda:0')
file mnist_seed2.model accuracy on test: tensor(0.1135, device='cuda:0')
file mnist_seed3.model accuracy on test: tensor(0.1029, device='cuda:0')
file mnist_seed4.model accuracy on test: tensor(0.1034, device='cuda:0')

Problem still persists though, but I think you expected that already :-)
Thanks for your efforts! Apex is great and I love it!

Thanks for switching, I hope you like the new API! Like I said, this is important and I'll look at it soon, but I need to work on other things (i.e. a GAN example) today and probably tomorrow.

No worries =) Take your time
I like the new API! The documentation seems to not be ready yet, but thanks to the webinar I know what to do (so far).
Best,
Fabian

The documentation isn't complete to my satisfaction, but I've got a decent amount posted.
https://nvidia.github.io/apex/amp.html#
https://nvidia.github.io/apex/amp.html#transition-guide-for-old-api-users
Is there anything major you think is missing aside from checkpointing instructions?

Is there anything major you think is missing aside from checkpointing instructions?

So you are saying that the checkpoint loading problem is just something that is missing from the documentation and that I should be doing differently?

Anyways, I will convert all my models to use the new API tomorrow. If there are any questions that are not answered by the documentation I will get back to you.

No, I think you may have uncovered a legitimate problem. I'll keep you updated.

Still need to look into this, I've had some fires come up and I also have a GTC talk next week, but this is important

No worries! Now that I know what caused my drops in performance I can work around it in the meantime

Hi, I wanted to stop by and ask what the status is. Have you had the opportunity to look into this yet?
Best,
Fabian

Unfortunately I haven't had a chance to look at this yet. The first 2.5 weeks last month were busy with GTC and in the last couple weeks I was working on updating the Amp backend to handle arbitrary models/optimizers/losses. As of yesterday the update is merged (https://github.com/NVIDIA/apex/pull/232). Check the new guidance if it's relevant to your use case.

I have to spend the next few days enabling some fused optimizers to work with https://github.com/NVIDIA/apex/pull/232, but I'm planning to dedicate next week to a thorough investigation of checkpointing.

Great, thank you so much! Let me know if I can help out!

Okay, I finally fixed this. The issue was that different sets of gradients were being passed into different processes. I did something like:

if partial_run:
    for p in model.module.part_that_only_runs_sometimes.parameters(): 
        p.grad = torch.zeros_like(p)

I'm also having some trouble with checkpoint loading, though not sure if it's related to this one. After loading a checkpoint the code simply hangs in the training loop indefinitely. When I kill the process, it's always stuck here: https://i.imgur.com/X6QiH3e.png

The code works perfectly fine without the checkpoint resuming. Not sure how to go about debugging, but glad to help in whatever way possible. The code is a bit complicated but I can share relevant parts privately if it helps. It may be related to @203 since I think the code is getting stuck in the backward pass, though I am calling set_device.

To make things even weirder, it looks like the model is able to execute backward passes a few times but then hangs. I put a print statement after the call to backward, and I get this output (the number is the rank of the GPU, I have 4):

after backward 3
after backward 1
after backward 0
after backward 2
after backward 3
after backward 1

I think it's related to delay_allreduce and requires_grad=False but not sure how.

Is there any update yet? I'm running into the same problem and cannot figure out what is causing this.
I tried to restore the model and continued to train it but it seemed that I was training from scratch not from checkpoint.

Hm. I found a workaround by changing the order of amp init and checkpoint loading.

  • amp.init -> load checkpoint: failed
  • load checkpoint -> amp.init: successful

Hope it's helpful.

@FabianIsensee Thanks for the great script to reproduce this issue!
I changed a bit of your original code to be compatible with the latest apex version.
You can find the code here.

I trained the models using different seeds from scratch, tested them using

python repro.py --test_only -test_fnames mnist_seedX.model

where I changed X to the seeds and got the following standalone accuracies:

file mnist_seed1.model accuracy on test:  tensor(0.9951, device='cuda:0')
file mnist_seed2.model accuracy on test:  tensor(0.9950, device='cuda:0')
file mnist_seed3.model accuracy on test:  tensor(0.9962, device='cuda:0')

After that I tried to run all the models sequentially to reproduce your issue using

python repro.py --test_only -test_fnames mnist_seed1.model mnist_seed2.model mnist_seed3.model

and got the same accuracies:

Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
file mnist_seed1.model accuracy on test:  tensor(0.9951, device='cuda:0')
Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
file mnist_seed2.model accuracy on test:  tensor(0.9950, device='cuda:0')
Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
file mnist_seed3.model accuracy on test:  tensor(0.9962, device='cuda:0')

Changing the model order did not change anything and the models still performed well.

Could you check my code and compare it to yours?
Maybe I'm missing something or this issue was due to an amp.initialize call before the model was completely loaded (e.g. model.load_state_dict not yet called) as @npmhung also described?

Hi @ptrblck ,
thank you for your reply! Indeed, the error is gone in your code while it is still present in mine. You are also right by saying it has to do with the initialization (the snippets are modified from your version of the code):

This will result in the problem that I initially described

        network = amp.initialize(network, opt_level="O1")
        for f in args.test_fnames:
            network.load_state_dict(torch.load(f, map_location=torch.device('cuda', torch.cuda.current_device())))

            with torch.no_grad():
                network.eval()
                out = network(data_te)

                _, amax = out.max(dim=1)
                acc = (amax == target_te).float().mean()
                print("file", f, "accuracy on test: ", acc)

This is how you do it in your code. This snippet will give the correct test set results

        for f in args.test_fnames:
            network.load_state_dict(torch.load(f, map_location=torch.device('cuda', torch.cuda.current_device())))
            network = amp.initialize(network, opt_level="O1")

            with torch.no_grad():
                network.eval()
                out = network(data_te)

                _, amax = out.max(dim=1)
                acc = (amax == target_te).float().mean()
                print("file", f, "accuracy on test: ", acc)

So the question is: why do I need to re-initialize amp after loading a checkpoint?
Best,
Fabian
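
One way to narrow the question down is to check whether load_state_dict really changed the tensors the model holds. A minimal sketch in plain PyTorch; the helper name mismatched_keys is made up for illustration and is not part of apex or PyTorch.

```python
import torch


def mismatched_keys(model, state_dict, atol=0.0):
    """Return names of checkpoint entries that differ from what `model` holds."""
    current = model.state_dict()
    bad = []
    for name, expected in state_dict.items():
        if name not in current:
            bad.append(name)
            continue
        actual = current[name]
        # Compare in FP32 on the CPU so dtype or device differences
        # (for example FP16 weights) do not break the comparison itself.
        same_shape = actual.shape == expected.shape
        if not same_shape or not torch.allclose(
            actual.float().cpu(), expected.float().cpu(), atol=atol, rtol=0.0
        ):
            bad.append(name)
    return bad
```

Calling mismatched_keys(network, torch.load(f)) right after network.load_state_dict(...) should return an empty list. If it does and the accuracy still collapses, the stale state presumably lives somewhere other than the module's own parameters and buffers.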

I would generally recommend to first initialize the complete module, call the mixed-precision routines (in this case amp.initialize), and then use DDP if desired, as stated in the docs:

amp.initialize should be called after you have finished constructing your model(s) and optimizer(s), but before you send your model through any DistributedDataParallel wrapper.

I'm hesitant to give a complete explanation for this issue, as I'm still trying to figure out what went wrong and how to prevent it from recurring.

I think your current workflow is a good example of a use case for repeated calls to amp.initialize(), since you are loading the models sequentially and not all at once (which might often not be possible due to memory limitation). I'm also looking into this issue, as it doesn't seem to be provided out of the box (although it's apparently working currently ;) )

Currently, amp.initialize should only be called once, although it can process an arbitrary number of models and optimizers (see the corresponding Advanced Amp Usage topic). If you think your use case requires amp.initialize to be called more than once, let us know.

Best,
ptrblck

Hi,
I am not sure if checkpoint loading should be considered part of initializing the complete module. We are not changing the network in any way, all we need to do is load all the parameters (and batch norm running stats). This is not uncommon: It often happens that experimenters want to load some checkpoint during training. Initializing amp over and over again is certainly not ideal and may introduce unwanted and unforeseeable side effects. For example, you would have to put the already initialized optimizer back into amp.initialize all the time:

(not real code, just something I came up with)

network = Network()
optimizer = Adam(network.parameters())
network, optimizer = amp.initialize(network, optimizer, opt_level="O1")

# ... run training up to some point ...

# ah dang, something bad happened, let's go back to some checkpoint
network.load_state_dict(XX)
# now, according to the experiment in the posts above we would have to re-initialize amp to ensure it really works. 
network, optimizer = amp.initialize(network, optimizer, opt_level="O1")

I really don't know what will happen if we do this (putting the optimizer back into initialize).

My uneducated guess is that this entire issue is just some (probably easy to fix) bug. Due to amp using master parameters for fp16 training (at least in O1) something might get messed up on checkpoint loading (some parameters being loaded while others aren't? Or the optimizer/network pointing to a wrong set of parameters?) and that there must be an easy fix for this (provided that the person looking at it knows how amp works - unfortunately I don't :-D )
Best,
Fabian

Hi,
you are making valid points about restoring the model during training, and that's an interesting use case!

I assume if you are restoring the model during training, you would also store the optimizer's state_dict and reload it?

Hi,
yes you are right, that I would do as well.
I just think that there are plenty of cases where either the network, the optimizer or both need to be loaded from checkpoints. Initializing amp each time is something that ideally should not be needed. Even if re-initializing several times works, this would be yet another thing one could simply forget to do and then wonder why the results are not as expected.

This seems to be fixed now, right? I do AMP initialize, followed by DDP, followed by model weight loading (since I save the model after the DDP call, it can't be loaded into the pre-DDP model), and I don't think I see this issue.
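
For the opposite direction (loading a checkpoint saved from the DDP-wrapped model into an unwrapped one), a common approach is to rename the keys. A minimal sketch, assuming the keys carry the "module." prefix that DistributedDataParallel adds because it holds the original model as its .module attribute; the helper name strip_module_prefix is made up for illustration.

```python
def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of `state_dict` whose keys no longer start with `prefix`."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
```

Usage would look like model.load_state_dict(strip_module_prefix(torch.load(path))). Saving model.module.state_dict() in the first place avoids the renaming altogether.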

I can confirm that models saved with 'O2' which are loaded before initialization of apex fail miserably. However, when you first initialize apex and then load the checkpoint, everything seems to be fine.

@jonasteuwen
I get a loss spike if I load the model state dict before initializing apex

@hadaev8 Do you also see this loss spike if you swap the order, i.e. initialize your model using amp before loading the state dict?

Which opt_level are you currently using?

@ptrblck I tried both orders; opt_level is O2.

@hadaev8 Thanks for the information. Are you also seeing this loss spike without using apex at all, i.e. just PyTorch?

@ptrblck
No. Also, O1 has no loss spike.

@hadaev8 We are currently working to merge a checkpointing PR, which might fix your issue. I'll ping you once it's merged, so that you could try to run your code again.

Thanks

@ptrblck I'm observing the same issue (restarting training on O2 spikes loss, but I can load O2-trained checkpoint in O1 and continue training without problem), could you broadcast on this thread when you guys merge the PR?

@ZhongxiaYan Sure, I'll post here once we merged it, so you can try it out.

It seems to be related to the optimizer.
If I create a brand new one instead of loading it, the loss does not spike.
With the old Adam, the val loss after one iteration is 0.2; with a new one it is 0.15 (around the same as before saving the model).

@FabianIsensee, @dave-epstein, @jonasteuwen, @hadaev8, @ZhongxiaYan
Checkpointing just got merged into our master branch.
Check out the README to see an example usage.
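
A minimal sketch of what saving and restoring all three pieces of state could look like. It relies only on amp.state_dict() and amp.load_state_dict(), the two calls mentioned in this thread; the function names, the dictionary keys and the amp_module parameter are assumptions made for illustration, and amp.initialize is assumed to have been called before either function runs.

```python
import torch


def save_checkpoint(path, model, optimizer, amp_module=None):
    """Save model, optimizer and (optionally) amp state into one file."""
    checkpoint = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    if amp_module is not None:
        # For apex this would be the imported `amp` module itself.
        checkpoint["amp"] = amp_module.state_dict()
    torch.save(checkpoint, path)


def restore_checkpoint(path, model, optimizer, amp_module=None):
    """Load the file written by save_checkpoint back into existing objects."""
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    if amp_module is not None and "amp" in checkpoint:
        amp_module.load_state_dict(checkpoint["amp"])
    return checkpoint
```

Passing the amp module in as an argument keeps the sketch usable without apex installed; with apex it would be called as save_checkpoint(path, network, optimizer, amp).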

Is amp.state_dict() universal across all places the module is imported? For example, can I call amp.initialize in file A, train the model with amp in file B, and save the state dict in file C, with accessing amp by importing it separately in each file? Or do I have to import once and pass the imported module around between files?

It seems to be related to the optimizer.
If I create a brand new one instead of loading it, the loss does not spike.
With the old Adam, the val loss after one iteration is 0.2; with a new one it is 0.15 (around the same as before saving the model).

Similar experience for me. I'm using the new APEX api with O2 level (using defaults for O2 configuration) for a 3d segmentation model. If I use a fresh optimizer, training seems to resume back to where it left off. I am also using the new amp.state_dict() option to load the state of amp. Below is my load_checkpoint function commenting out optimizer state loading, which allows for more "normal" resume-training-conditions. I quote "normal" because my loss and metric are not as good as the end of training before resuming. Can't figure out why yet.

This model is a modified DeepLabV3+ that uses torch.utils.checkpoint to help manage GPU memory, "ranger" as the optimizer https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer, Nvidia's partial convolutions https://github.com/NVIDIA/partialconv, and Tversky Loss as the criterion. I mention this because after 80 epochs with an input size of (1,3,320,320,320) and 90 iterations per epoch, I run into NaNs and I am trying to figure out why. Might add it as a new issue if I determine that it is related to APEX. The loss function is also included below in case someone might be able to spot something odd.

def load_checkpoint(model, optimizer, amp, fname):
    '''
    Loads the checkpoint to CPU.
    '''
    ckpt = torch.load(fname, map_location=lambda storage, loc: storage)
    model.load_state_dict(ckpt['mod_state_dict'])
    # optimizer.load_state_dict(ckpt['opt_state_dict'])
    amp.load_state_dict(ckpt['amp_state_dict'])
    epoch = ckpt['epoch']
    running_loss = ckpt['running_loss']
    running_metric = ckpt['running_metric']
    return model, optimizer, amp, epoch, running_loss, running_metric


class TverskyLoss3d(nn.Module):
    def __init__(self, num_classes, metric_type='tversky', tversky_beta=0.4,
                 device=1):
        super().__init__()
        self.num_classes = num_classes
        self.device = device
        self.eps = 1e-6
        if metric_type == 'dice':
            self.alpha, self.beta = 0.5, 0.5
        elif metric_type == 'tanimoto':
            self.alpha, self.beta = 1, 1
        elif metric_type == 'tversky':
            self.beta = tversky_beta
            self.alpha = 1 - self.beta

    def forward(self, pred, target):
        pred, target = pred.cuda(self.device), target.cuda(self.device)
        pred_soft = pred.exp()
        ones = pred_soft*0 + 1.0

        target_one_hot = one_hot(target, num_classes=self.num_classes,
                                 device=pred.device, dtype=pred.dtype)

        dims = (0,2,3,4)
        intersection = (pred_soft * target_one_hot).sum(dims)
        fps = torch.sum(pred_soft * (ones - target_one_hot), dims)
        fns = torch.sum((ones - pred_soft) * target_one_hot, dims)

        numerator = intersection
        denominator = intersection + self.alpha * fps + self.beta * fns
        # print(f'numerator: {numerator}, denominator: {denominator}')
        tversky_loss = (numerator / (denominator + self.eps)).sum()
        # print(f'tversky_loss: {tversky_loss}')
        return self.num_classes - tversky_loss

@dave-epstein amp.state_dict() will use the global _amp_state and thus should be accessible in each script, where you've imported amp.
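
The reason separate imports share state is general Python behavior: a module is loaded once per process and cached in sys.modules, so every later import returns the same object. A small stdlib-only sketch; fake_amp and its _state dictionary are hypothetical stand-ins, not apex code.

```python
import importlib
import sys
import types

# Build a throwaway module holding one piece of module-level state,
# standing in for apex.amp and its global _amp_state.
fake_amp = types.ModuleType("fake_amp")
fake_amp._state = {"initialized": False}
sys.modules["fake_amp"] = fake_amp

# "File A" imports the module and mutates the state.
module_seen_by_a = importlib.import_module("fake_amp")
module_seen_by_a._state["initialized"] = True

# "File C" imports it again and gets the very same module object.
module_seen_by_c = importlib.import_module("fake_amp")
print(module_seen_by_c is module_seen_by_a)
print(module_seen_by_c._state["initialized"])
```

This holds within a single Python process only. Separate processes, such as the per-GPU workers of a distributed run, each have their own copy of the module and its state.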

@jpcenteno80 Do you have a reproducible code snippet for the loss spike after restoring (or the checkpoint)?
If you are running into NaN values in the forward pass, you could use the anomaly detection mode in PyTorch to narrow down the function that created these invalid values.
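
As far as I know, anomaly detection (torch.autograd.set_detect_anomaly(True)) reports NaNs produced by backward functions. For non-finite values that first appear in the forward pass, forward hooks are a possible complement. A minimal sketch; the helper name find_nonfinite_modules is made up for illustration.

```python
import torch


def find_nonfinite_modules(model, inputs):
    """Run one forward pass and list leaf modules whose output has NaN or inf.

    Names are returned in execution order, so the first entry is the
    earliest layer that produced a non-finite value.
    """
    offenders = []
    handles = []

    def make_hook(name):
        def hook(module, args, output):
            if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
                offenders.append(name)
        return hook

    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # leaf modules only
            handles.append(module.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():
            model(inputs)
    finally:
        # Always remove the hooks so the model is left unchanged.
        for handle in handles:
            handle.remove()
    return offenders
```

Later layers usually show up in the list as well, because NaN and inf propagate; only the first name points at the origin.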

Thanks @ptrblck. I have been able to track down my NaNs in the loss to the output of an nn.Conv3d layer in which one value is -inf. My amp setup is O2 with the amp default settings. The output of the nn.Conv3d layer is in FP16 when training with amp O2. The problematic output value of the nn.Conv3d layer is -66610.1094 in FP32. However, when it gets cast to FP16, it turns into -inf. This output then goes to an nn.GroupNorm layer, which ends up outputting NaNs at certain indices.

Is there a way of clipping the output of the nn.Conv3d layer so that the output stays within the FP16 range of -6.55e4 to 6.55e4? I wonder if that would slow things down and be inefficient. Do you have any other suggestions?

This only happens far into training, at the same nn.conv3d layer each time, when validation loss and metrics are "converging". The input to the problem nn.Conv3d layer is within a "normal" range, with values between -2000 to 2000, depending on the original input to the network.

# recreated operation in FP32 after saving conv3d weights and input tensors while training
conv3d = nn.Conv3d(364, 512, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
conv3d_inp.size() #torch.Size([1, 364, 19, 19, 19])
conv3d_out = conv3d(conv3d_inp)
conv3d_out.min() #tensor(-66610.1094, device='cuda:0')

# cast to FP16
conv3d_out.half().min() #tensor(-inf, device='cuda:0', dtype=torch.float16)
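
On the clipping question, a minimal sketch of a saturating cast in plain PyTorch. The helper name to_half_saturating is made up for illustration, and nothing here is an apex feature.

```python
import torch

# Largest finite FP16 value (65504.0).
FP16_MAX = torch.finfo(torch.float16).max


def to_half_saturating(x):
    """Cast to FP16, saturating at the largest finite value instead of inf."""
    # Clamp while the values are still FP32, then cast down.
    return x.float().clamp(min=-FP16_MAX, max=FP16_MAX).half()
```

The clamp only helps while the values are still FP32. If the layer already emits FP16, as described above for O2, the overflow has happened before any clamp can see it, so the layer itself would have to run in FP32 for this to apply. NaNs pass through clamp unchanged.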

So, I load the model, then the amp state dict, and still have problems with loss spikes. Any ideas?
