A memory leak occurs on CPU when sampling from a Dirichlet distribution in a guide over many training iterations. The leak only occurs when the parameters of the Dirichlet distribution are defined using pyro.param.
OS: Ubuntu 18.04.2
Python: 3.6.7
PyTorch: 1.1.0
Pyro: 0.3.3
import torch
from torch.distributions import constraints
import pyro
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
import pyro.distributions as dist
a = 20
b = 100000
def model(x):
    pass

def guide(x):
    alpha = pyro.param('alpha', torch.ones(a, b), constraint=constraints.positive)
    with pyro.plate('sample_x', a):
        pyro.sample('x', dist.Dirichlet(alpha))
elbo = Trace_ELBO(max_plate_nesting=1, vectorize_particles=True)
opt = Adam({'lr': 0.0001})
svi = SVI(model, guide, opt, loss=elbo)
for epoch in range(500):
    svi.step(1)
Changing alpha in the guide to
alpha = torch.ones(a, b) / b
stops the leak from occurring (ignore errors).
@ajrcampbell Thanks for the clear bug report! I'll try to reproduce.
This appears to be a PyTorch issue. I'll first try fixing torch._dirichlet_grad() upstream and then try to push a workaround patch to Pyro if possible. To reproduce the PyTorch issue it suffices to:
import resource, torch
concentration = torch.ones(20, 100000)
total = concentration.sum(-1, True).expand_as(concentration)
x = concentration / total
for epoch in range(500):
    torch._dirichlet_grad(x, concentration, total)
    print('maxrss = {}'.format(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss))
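If it helps when checking other suspect calls, the same measurement can be wrapped in a small helper. This is only a rough sketch: `maxrss_growth` is a hypothetical name, not part of PyTorch or Pyro, and it uses nothing beyond the standard-library `resource` module (Unix only). Since `ru_maxrss` is a high-water mark, it only reveals growth beyond the process's previous peak.

```python
import resource


def maxrss_growth(fn, iterations=500):
    """Return how much the peak resident set size grew while calling fn repeatedly.

    Hypothetical helper for illustration. ru_maxrss is reported in kilobytes on
    Linux and bytes on macOS, and it never decreases, so the result is >= 0.
    """
    before = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    for _ in range(iterations):
        fn()
    after = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    return after - before


if __name__ == "__main__":
    # Stand-in workload; replace the lambda with the call under suspicion,
    # e.g. lambda: torch._dirichlet_grad(x, concentration, total).
    growth = maxrss_growth(lambda: sum(range(1000)), iterations=100)
    print('maxrss growth = {}'.format(growth))
```

A growth figure that keeps increasing as `iterations` is raised would point to a leak, while a figure that levels off is more likely ordinary allocator warm-up.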
I believe this can be closed now.
How can I use the fix?
You can try installing PyTorch from source using the current master branch, or use the nightly builds, which are published every day (in which case you will have to wait until tomorrow, I suppose).