Fixing BatchNorm may lead to RuntimeError: expected scalar type Half but found Float
The forward pass works fine, but the error occurs when calling backward() on the loss.
Here is a short demo:
import torch
import torch.nn as nn
from torch.nn import init
from torchvision import models
from torch.autograd import Variable
from apex.fp16_utils import *
def fix_bn(m):
    classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
        m.eval()
model = models.resnet50(pretrained=True)
model.cuda()
model = network_to_half(model)
model.train()
model.apply(fix_bn) # fix batchnorm
input = Variable(torch.FloatTensor(8, 3, 224, 224).cuda().half())
output = model(input)
output_mean = torch.mean(output)
output_mean.backward()
Please do this instead:
def fix_bn(m):
    classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
        # put BN in eval mode AND cast its weights to half to match the fp16 inputs
        m.eval().half()
The reason is that for regular training it is better (performance-wise) to use cuDNN batch norm, which requires its weights to be in fp32; therefore, network_to_half does not convert batch norm modules to half. However, cuDNN does not support batch norm backward in eval mode, which is what you are doing. In that case the PyTorch implementation is used instead, and it requires the weights to be of the same type as the inputs.
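A minimal CPU-only sketch of the dtype mismatch described above, assuming plain PyTorch only (no apex, no GPU). The manual "cast to half, then put BatchNorm back to fp32" step is a rough, hypothetical stand-in for what network_to_half does to BatchNorm layers, and the helper `mismatched_bn_layers` is a hypothetical illustration, not an apex or PyTorch API. It also swaps the class-name string check for an `isinstance` test against `nn.modules.batchnorm._BatchNorm` (a private PyTorch base class of BatchNorm1d/2d/3d). It only inspects dtypes; it does not run backward.

```python
import torch
import torch.nn as nn


def fix_bn(m: nn.Module) -> None:
    # Freeze BatchNorm (use running stats) and cast its floating-point
    # weights/buffers to half so they match fp16 inputs.
    if isinstance(m, nn.modules.batchnorm._BatchNorm):
        m.eval().half()


def mismatched_bn_layers(model: nn.Module, dtype: torch.dtype) -> list:
    """Hypothetical helper: names of BatchNorm layers whose floating-point
    parameters/buffers are not `dtype`."""
    bad = []
    for name, m in model.named_modules():
        if isinstance(m, nn.modules.batchnorm._BatchNorm):
            tensors = list(m.parameters(recurse=False)) + [
                # skip the integer num_batches_tracked buffer
                b for b in m.buffers(recurse=False) if b.is_floating_point()
            ]
            if any(t.dtype != dtype for t in tensors):
                bad.append(name)
    return bad


model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

# Rough stand-in for network_to_half (assumption): cast everything to half,
# then restore BatchNorm layers to fp32.
model.half()
for m in model.modules():
    if isinstance(m, nn.modules.batchnorm._BatchNorm):
        m.float()
model.train()

print(mismatched_bn_layers(model, torch.float16))  # ['1'] -> BN is still fp32
model.apply(fix_bn)
print(mismatched_bn_layers(model, torch.float16))  # [] -> BN now matches fp16
print(model[1].training)  # False
```

The `isinstance` check avoids accidentally matching unrelated modules whose class name merely contains "BatchNorm".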