Apex: freezing batchnorm during training may lead to RuntimeError: expected scalar type Half but found Float

Created on 11 Jan 2019 · 1 comment · Source: NVIDIA/apex


The forward pass on its own works fine, but the error is raised when calling backward() on the loss.

Here is a short demo.

import torch
from torchvision import models
from apex.fp16_utils import network_to_half

def fix_bn(m):
    # Put every BatchNorm* module into eval mode (use running statistics)
    classname = m.__class__.__name__
    if 'BatchNorm' in classname:
        m.eval()

model = models.resnet50(pretrained=True)
model.cuda()
# network_to_half casts the model to fp16 but keeps batchnorm layers in fp32
model = network_to_half(model)
model.train()
model.apply(fix_bn)  # freeze batchnorm
x = torch.randn(8, 3, 224, 224).cuda().half()
output = model(x)                # forward pass succeeds
output_mean = torch.mean(output)
output_mean.backward()           # RuntimeError: expected scalar type Half but found Float
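The error message says a Float tensor met a Half one. One way to look for the culprit is to list the batchnorm modules that are in eval mode but still hold fp32 parameters or buffers while receiving fp16 inputs. This is a minimal diagnostic sketch: the function name `frozen_bn_dtype_mismatches` is hypothetical (not part of apex or PyTorch), and the `BN_TYPES` tuple assumes the model only uses the standard batchnorm classes.

```python
from itertools import chain

import torch
import torch.nn as nn

# Standard batchnorm classes (assumption); extend this tuple for other
# norm layers, e.g. nn.SyncBatchNorm, if the model uses them.
BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def frozen_bn_dtype_mismatches(model: nn.Module,
                               input_dtype: torch.dtype = torch.float16) -> list:
    """Return the names of batchnorm modules that are in eval mode but hold
    floating-point parameters or buffers whose dtype differs from input_dtype."""
    mismatches = []
    for name, module in model.named_modules():
        if not isinstance(module, BN_TYPES) or module.training:
            continue  # only frozen (eval-mode) batchnorm layers are checked
        tensors = chain(module.parameters(recurse=False),
                        module.buffers(recurse=False))
        # Integer buffers such as num_batches_tracked are ignored.
        if any(t.is_floating_point() and t.dtype != input_dtype for t in tensors):
            mismatches.append(name)
    return mismatches
```

In the repro, calling `print(frozen_bn_dtype_mismatches(model))` right after `model.apply(fix_bn)` is expected to list the ResNet's batchnorm layers, because network_to_half leaves them in fp32.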

Most helpful comment

Please also cast the frozen batchnorm modules to half:

def fix_bn(m):
    classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
        m.eval().half()

The reason is that for regular training it is faster to use cuDNN batch norm, which requires the batchnorm weights to be in fp32, so network_to_half deliberately does not convert batchnorm modules to half. However, cuDNN does not support batchnorm backward in eval mode, which is what you are doing. In that case the PyTorch implementation is used instead, and it requires the weights to have the same type as the inputs (fp16 here), so the batchnorm modules must be cast to half as well.
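The same fix can be sketched as a standalone helper that selects modules by type with isinstance instead of matching on the class name. This is a sketch under stated assumptions: the name `freeze_bn_as_half` is hypothetical, and the `BN_TYPES` tuple assumes the model uses only the standard batchnorm classes (class-name matching in fix_bn would also catch other `*BatchNorm*` classes).

```python
import torch
import torch.nn as nn

# Standard batchnorm classes (assumption); extend for other norm layers.
BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def freeze_bn_as_half(model: nn.Module) -> nn.Module:
    """Put every batchnorm module in eval mode and cast its floating-point
    parameters and buffers to fp16, like fix_bn in the comment above."""
    for module in model.modules():
        if isinstance(module, BN_TYPES):
            module.eval()  # use running statistics instead of batch statistics
            module.half()  # weights and running stats now match fp16 inputs
    return model
```

Since `model.train()` recursively switches every submodule, including the frozen batchnorm layers, back to training mode, the helper has to be re-applied after each such call, e.g. `model.train(); freeze_bn_as_half(model)`.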
