Pyro: Memory leak using Dirichlet distribution in guide

Created on 7 May 2019 · 5 Comments · Source: pyro-ppl/pyro

Issue Description

A memory leak occurs on CPU when sampling from a Dirichlet distribution in a guide over many training iterations. The leak only occurs when the parameters of the Dirichlet distribution are defined using pyro.param.

Environment

OS: Ubuntu 18.04.2
Python: 3.6.7
PyTorch: 1.1.0
Pyro: 0.3.3

Code Snippet

import torch
from torch.distributions import constraints

import pyro
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
import pyro.distributions as dist


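a = 20      # plate size: number of Dirichlet distributions
b = 100000  # number of categories per distribution

def model(x):
    # Empty model: the guide alone is enough to trigger the leak.
    pass

def guide(x):
    # Learnable concentration parameters registered via pyro.param.
    alpha = pyro.param('alpha', torch.ones(a, b), constraint=constraints.positive)
    with pyro.plate('sample_x', a):
        pyro.sample('x', dist.Dirichlet(alpha))

elbo = Trace_ELBO(max_plate_nesting=1, vectorize_particles=True)
opt = Adam({'lr': 0.0001})
svi = SVI(model, guide, opt, loss=elbo)

# Process memory grows as the training steps accumulate.
for epoch in range(500):
    svi.step(1)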

Changing alpha in the guide to

alpha = torch.ones(a, b) / b 

stops the leak from occurring (ignore errors).

bug

All 5 comments

@ajrcampbell Thanks for the clear bug report! I'll try to reproduce.

This appears to be a PyTorch issue. I'll first try fixing torch._dirichlet_grad() upstream and then try to push a workaround patch to Pyro if possible. To reproduce the PyTorch issue, it suffices to run:

import resource, torch
concentration = torch.ones(20, 100000)
total = concentration.sum(-1, True).expand_as(concentration)
x = concentration / total
for epoch in range(500):
    torch._dirichlet_grad(x, concentration, total)
    print('maxrss = {}'.format(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss))
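
For anyone who wants to check other code paths the same way, the maxrss logging above can be wrapped in a small helper. This is only a sketch: peak_rss_growth and its warmup parameter are hypothetical names that are not part of Pyro or PyTorch, and the resource module is only available on Unix. Note that ru_maxrss is the peak resident set size, reported in kilobytes on Linux and in bytes on macOS, so it never decreases and only tells you about growth.

```python
import resource


def peak_rss_growth(step, iterations, warmup=10):
    """Return how much the process's peak RSS grew over `iterations` calls.

    `step` is any zero-argument callable. The first `warmup` calls are run
    before the baseline is taken so that one-time allocations are excluded.
    (Hypothetical helper for illustration, not a Pyro or PyTorch API.)
    """
    for _ in range(warmup):
        step()
    baseline = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss

    for _ in range(iterations):
        step()
    final = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss

    # ru_maxrss is a high-water mark, so the difference is never negative.
    return final - baseline
```

With the snippet from the issue, something like peak_rss_growth(lambda: svi.step(1), 500) should report steady growth when the leak is present and stay near zero otherwise.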

I believe this can be closed now.

How can I use the fix?

You can try installing PyTorch from source using the current master branch, or use the nightly builds, which are published every day (in which case you will have to wait until tomorrow, I suppose).
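
After upgrading, it can help to confirm which PyTorch build is actually being imported. A minimal sketch follows; installed_torch_version is a hypothetical helper name, and this thread does not state which version string first contains the fix, so the check only tells you what is installed.

```python
import importlib.util


def installed_torch_version():
    """Return the installed PyTorch version string, or None if torch is absent.

    (Hypothetical helper for illustration only.)
    """
    if importlib.util.find_spec("torch") is None:
        return None
    import torch
    return str(torch.__version__)


print("PyTorch version:", installed_torch_version())
```

If the printed version is still 1.1.0, the environment is still using the release from the original report.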
