I'm training a simple feedforward neural network that outputs the parameters of a Beta distribution; I draw samples from that distribution and train with BCELoss. The problem is that RAM usage increases with the number of training iterations, and the machine soon starts thrashing, leaving me no choice but to force a shutdown (code below). GPU usage is stable.
Here is the network architecture:
```python
from pyro.distributions.torch import Beta
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.autograd import Variable


class WNetBeta(nn.Module):
    def __init__(self, in_ch, out_ch, C0=32):
        super(WNetBeta, self).__init__()
        self.conv0 = conv(in_ch, C0, (1, 3, 3), padding=(0, 1, 1))
        self.resblock1 = ResBlock(C0, C0)
        self.resblock2 = ResBlock(C0, C0)
        self.tconv1 = tconv(C0, C0)
        self.avgsample = nn.AvgPool3d((1, 2, 2), stride=(1, 2, 2))
        self.resblock3 = ResBlock(C0, C0)
        self.resblock4 = ResBlock(C0, C0)
        self.tconv2 = tconv(C0, C0)
        self.pconv1 = pconv(C0, out_ch, (1, 3, 3), padding=(0, 1, 1))

    def forward(self, x):
        data = self.conv0(x)
        data = self.resblock1(data)
        data = self.resblock2(data)
        data = self.tconv1(data)
        data = self.avgsample(data)
        data = self.resblock3(data)
        data = self.resblock4(data)
        data = self.tconv2(data)
        out = self.pconv1(data)
        ## get alpha, beta
        logalpha = out[:, 0]
        logbeta = out[:, 1]
        alpha = torch.exp(logalpha)
        beta = torch.exp(logbeta)
        # Get samples now
        T = 10
        alphar = alpha.unsqueeze(1).repeat(1, T, 1, 1, 1, 1)
        betar = beta.unsqueeze(1).repeat(1, T, 1, 1, 1, 1)
        m = Beta(alphar, betar)
        p = m.rsample().mean(1)
        return p, logalpha, logbeta
```
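The repeat-then-average step in forward can also be expressed by passing a sample_shape to rsample, which draws independent samples along a new leading dimension without materializing repeated parameter tensors. The sketch below is only an illustration under assumptions: it uses PyTorch's own torch.distributions.Beta so it runs on the CPU without Pyro, and mean_of_beta_samples is a hypothetical helper name, not part of either library.

```python
import torch
from torch.distributions import Beta


def mean_of_beta_samples(alpha: torch.Tensor, beta: torch.Tensor, num_samples: int = 10) -> torch.Tensor:
    """Average `num_samples` reparameterized Beta samples (hypothetical helper).

    rsample((num_samples,)) returns a tensor of shape (num_samples, *alpha.shape),
    so averaging over dim 0 gives a result with the same shape as alpha.
    """
    dist = Beta(alpha, beta)
    samples = dist.rsample((num_samples,))
    return samples.mean(0)


if __name__ == "__main__":
    torch.manual_seed(0)
    # Shapes chosen for illustration only: (batch, depth, height, width)
    log_alpha = torch.zeros(1, 2, 4, 4, requires_grad=True)
    log_beta = torch.zeros(1, 2, 4, 4, requires_grad=True)
    p = mean_of_beta_samples(log_alpha.exp(), log_beta.exp(), num_samples=10)
    print(p.shape)  # torch.Size([1, 2, 4, 4])
    p.sum().backward()  # gradients flow back through rsample to log_alpha
    print(log_alpha.grad is not None)  # True
```

Note that repeat(1, T, 1, 1, 1, 1) on a 5-D tensor makes PyTorch treat the tensor as if it had an extra leading dimension, so the output shape of the original code and of this sketch can differ; it is worth checking both against whole_label before swapping one for the other.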
And here is my training loop:
```python
import os

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.optim import Adam
from torch.utils.data import DataLoader

# EPOCHS, optim_name and Loader are defined elsewhere (not shown in this post)
model = WNetBeta(4, 2).cuda()
# define optimizer
optim = Adam(model.parameters(), lr=2e-5, weight_decay=1e-7)
#optim = SGD(model.parameters(), lr=1e-5, momentum=0.3)
if os.path.exists(optim_name):
    optim.load_state_dict(torch.load(optim_name))
# dataloader
train_ds = Loader('train')
train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=0)
val_ds = Loader('val')
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=0)
# loss function
loss_fn = nn.BCELoss()
# training loop
for i in range(EPOCHS):
    model.train()
    for j, (patch, whole_label) in enumerate(train_loader):
        # get patch
        optim.zero_grad()
        patch = Variable(patch).cuda()
        whole_label = Variable(whole_label).cuda()
        p, log_alpha, log_beta = model(patch)
        loss_val = loss_fn(p, whole_label)
        loss_val.backward()
        # Step
        optim.step()
```
Thank you in advance.
Edit: I'm not posting the definitions of the other modules (ResBlock, pconv, etc.) because they are just normal convolution layers with different activations and so on, so I don't think they're relevant and they would only clutter things up.
This problem doesn't occur with a "baseline" that simply outputs a logit trained with BCELoss, so a faulty data loader is an unlikely cause of the memory leak.
This appears to be pure PyTorch code (no Pyro), so the appropriate place to post your question would be https://discuss.pytorch.org/
This problem, however, doesn't occur when I use PyTorch's Beta distribution. But since rsample doesn't have a backward implementation for the GPU in PyTorch (which makes it slow), I used Pyro's Beta distribution. So this seems like a Pyro issue.
The Beta used is from Pyro, not PyTorch.
If that's the case, you need to make your code snippet much clearer. Show the relevant import statements, etc.
I apologize for not making it clear enough. I have updated the post. Please let me know if I should add/modify anything else.
Thanks for clarifying!
@neerajprad any idea what might be going on here?
@rohitrango can you try reproducing your issue with PyTorch 0.4.0?
@rohitrango - A lot of distributions are broken in PyTorch 0.4.1. Could you try with PyTorch 0.4.0 instead?
> This problem, however doesn't occur when I use the Beta distribution of PyTorch. But since rsample doesn't have the backward implementation for GPU in PyTorch (and hence making it slow), I used Pyro's Beta distribution.

Pyro's Beta distribution (like most other Pyro distributions that have a torch counterpart) is merely a wrapper over torch.distributions.Beta. The reason it does not throw an error on the GPU is that we have patched torch._standard_gamma to transfer the parameters to the CPU, do the required operation there, and transfer the result back to the GPU. So any memory leak you see with the earlier version will likely also be present in the torch version.
I'm afraid that PyTorch 0.4.0 doesn't solve the problem either. I'm using htop to check the memory usage, and it just grows as the iterations continue.
Edit: What's more interesting is that this memory leak happens only while training. While testing, I didn't see any memory issues.
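To turn the htop observation into numbers that can be posted alongside a smaller reproduction, one option is to log the process's peak resident set size every few iterations. This is a minimal sketch using the standard-library resource module (Unix only); train_step is a hypothetical stand-in for one iteration of the training loop above, and track_memory and peak_rss_mib are illustrative names, not existing APIs.

```python
import resource
import sys
from typing import Callable, List


def peak_rss_mib() -> float:
    """Peak resident set size of this process in MiB (Unix only)."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is reported in bytes on macOS and in kilobytes on Linux.
    if sys.platform == "darwin":
        return peak / (1024 * 1024)
    return peak / 1024


def track_memory(train_step: Callable[[], None], iterations: int, every: int = 10) -> List[float]:
    """Call a hypothetical `train_step` repeatedly, recording peak RSS every `every` iterations.

    Because this is a peak value it can only stay flat or grow: a series that
    keeps rising throughout training is consistent with a leak, while one that
    flattens after a warm-up phase is not.
    """
    readings = []
    for it in range(1, iterations + 1):
        train_step()
        if it % every == 0:
            readings.append(peak_rss_mib())
            print(f"iter {it}: peak RSS {readings[-1]:.1f} MiB")
    return readings
```

For example, track_memory(one_iteration, iterations=200, every=20), where one_iteration wraps the body of the inner loop, runs only a bounded number of steps, which avoids driving the machine into a forced shutdown.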
@rohitrango Can you create a smaller example to replicate the problem? Here are my suggestions:
- train for just a few iterations (to avoid a shutdown)
- drop the repeat before sampling: I don't understand why you need repeat here (instead of feeding sample_shape into rsample), but I have seen many problems with repeat/expand + log_prob in PyTorch 0.4.1
- add .contiguous() at various places

> This problem, however doesn't occur when I use the Beta distribution of PyTorch.
Do you mean that when you run this on the CPU using torch.distributions.Beta you don't see a memory leak, but when you use the Pyro version on the CPU, you find this to be an issue?
If so, can you try running this using the torch-1.0 branch of Pyro (installation instructions) and the current PyTorch release and confirm if you still see this issue?
@rohitrango Just to clarify what might be going on: Pyro's Beta doesn't actually implement GPU reparametrization. Instead, import pyro.distributions triggers a patch to PyTorch's underlying gradients that performs a GPU->CPU->GPU conversion for that computation. I'm curious whether you could get the same effect by using the PyTorch Beta but including import pyro.distributions.torch_patch at the top of your script.
Hi,
I tried using torch-1.0 and its corresponding pyro distribution. The same problem still occurs.
Also, I tried using torch.distributions.Beta and converting the tensors to the CPU and back to the GPU after sampling for the backward pass. The same problem occurs without Pyro, so I think it's a PyTorch problem. I didn't notice the difference earlier, but with a bigger example it's clear.
I observe similar behavior on PyTorch 0.4.1 and 1.0. Why does this op need to transfer from GPU to CPU and vice versa? Is there some native op that is missing a GPU implementation?
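The CPU round trip described above can be sketched with plain torch.distributions.Beta. This is an illustration under assumptions, not Pyro's patch: beta_rsample_via_cpu is a hypothetical helper, and it relies on .cpu() and .to(device) being recorded by autograd, so gradients from the loss flow back to the original (possibly CUDA) parameters. Since the poster reports that this approach shows the same memory growth, it is better read as a compact starting point for a reproduction than as a fix.

```python
import torch
from torch.distributions import Beta


def beta_rsample_via_cpu(alpha: torch.Tensor, beta: torch.Tensor, num_samples: int) -> torch.Tensor:
    """Draw reparameterized Beta samples on the CPU and return them on alpha's device.

    Hypothetical helper: both .cpu() and .to() are differentiable, so the
    backward pass copies gradients back to the device of alpha and beta.
    """
    device = alpha.device
    dist = Beta(alpha.cpu(), beta.cpu())
    samples = dist.rsample((num_samples,))  # shape: (num_samples, *alpha.shape)
    return samples.to(device)


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log_alpha = torch.zeros(2, 3, device=device, requires_grad=True)
    log_beta = torch.zeros(2, 3, device=device, requires_grad=True)
    # Average over the sample dimension, then train against a dummy target with BCE.
    p = beta_rsample_via_cpu(log_alpha.exp(), log_beta.exp(), num_samples=10).mean(0)
    target = torch.full_like(p, 0.5)
    loss = torch.nn.functional.binary_cross_entropy(p, target)
    loss.backward()
    print(p.device, log_alpha.grad.device)
```

Running the main block in a loop together with a memory logger would show whether the growth follows the sampling step alone, independently of the network.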
Yes. Dirichlet grad is not supported on GPU.
Closing this - let us track this issue upstream in https://github.com/pytorch/pytorch/issues/11030 where it belongs.
When we do Dirichlet resampling, we run into the same "CUDA out of memory" issue. Has anyone solved this problem yet?

> When we do the dirichlet resampling, we meet the same issue that "cuda out of memory"
I am not sure if it is related to the above issue. Regardless, the right place for further discussing this (please cc. Pyro developers) should be either https://github.com/pytorch/pytorch/issues/11030, or a new issue in pytorch with your specific code example.