Pyro: Make AbstractInfer classes lazily consume traces

Created on 30 Jan 2019 · 7 comments · Source: pyro-ppl/pyro

Our TracePosterior subclasses, such as MCMC, return a generator over traces. The TracePosterior.run method greedily consumes this generator and holds a list of num_samples traces in memory.

If the user does not need the raw traces, but only the joint or marginal distribution over some of the sample sites, holding all the raw traces in memory is wasteful. By reducing each trace to some terminal representation as it is generated, rather than keeping every trace, we can slow the growth of memory consumption with num_samples for models that generate large tensors or whose input data is large. This will also be useful with num_chains > 1 on CUDA, where tensors must be copied from worker processes to the main process: https://github.com/uber/pyro/pull/1694#issuecomment-456530927.
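As a rough illustration of reducing traces as they are generated, the sketch below treats each trace as a plain dict from site name to value. That is an assumption (real Pyro traces carry much more per site), and trace_stream and streaming_mean are hypothetical helpers. Only a running total and count for one site are kept, so memory does not grow with num_samples, whereas collecting the generator into a list would keep every trace alive.

```python
from typing import Any, Dict, Iterator


def trace_stream(num_samples: int) -> Iterator[Dict[str, Any]]:
    # Hypothetical stand-in for TracePosterior._traces: each "trace" is a dict
    # mapping site names to values, plus a large payload we never need.
    for i in range(num_samples):
        yield {"z": float(i), "big_data": [0.0] * 1000}


def streaming_mean(traces: Iterator[Dict[str, Any]], site: str) -> float:
    # Reduce each trace to the one value we need as soon as it is produced;
    # only a running total and a count are retained, never the traces.
    total, count = 0.0, 0
    for tr in traces:
        total += tr[site]
        count += 1
    if count == 0:
        raise ValueError("no traces to reduce")
    return total / count


print(streaming_mean(trace_stream(5), "z"))  # 2.0
```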

This issue proposes changing the interface so that calling .run no longer consumes the traces greedily. Instead, traces would be consumed, and the relevant site values populated into an EmpiricalDistribution, only when .marginal is called on these classes. Likewise, TracePredictive would only start running inference when .marginal is called. We can still keep the .run method for users who need access to the raw traces.
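A minimal sketch of the proposed interface, assuming a hypothetical LazyPosterior class (not Pyro's API) and using a dict of per-site value lists in place of EmpiricalDistribution: .run only records its arguments, and traces are generated and reduced only when .marginal is called.

```python
from typing import Any, Callable, Dict, Iterator, List


class LazyPosterior:
    """Hypothetical sketch: .run() only records arguments; traces are
    generated and reduced when .marginal() is called."""

    def __init__(self, trace_fn: Callable[..., Iterator[Dict[str, Any]]]):
        self._trace_fn = trace_fn
        self._args = ()
        self._kwargs = {}

    def run(self, *args, **kwargs) -> "LazyPosterior":
        # No traces are consumed here; just remember how to produce them.
        self._args, self._kwargs = args, kwargs
        return self

    def marginal(self, sites: List[str]) -> Dict[str, List[Any]]:
        # Consume the generator now, keeping only the requested site values.
        values: Dict[str, List[Any]] = {s: [] for s in sites}
        for tr in self._trace_fn(*self._args, **self._kwargs):
            for s in sites:
                values[s].append(tr[s])
        return values


consumed = []


def toy_traces(n):
    # Toy trace generator that records which traces were actually produced.
    for i in range(n):
        consumed.append(i)
        yield {"a": i, "b": 2 * i}


post = LazyPosterior(toy_traces).run(3)
assert consumed == []        # nothing generated by .run
print(post.marginal(["b"]))  # {'b': [0, 2, 4]}
```

The design choice is that the cost of generating traces is paid once, at the point where the caller says which sites it actually needs.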

This will require a few changes to the EmpiricalDistribution and AbstractInfer classes, but I wanted to open this issue for discussion first. @eb8680 - I think you had some additional ideas on refactoring these classes; please feel free to add your thoughts.

Beyond efficiency, this change has some additional benefits (we should make sure that the refactor addresses these points):

  • SVI shouldn't need a separate num_samples kwarg. This has led to some confusion, as users sometimes keep the default num_samples=10 and simply resample from these 10 traces when computing the posterior predictive. With lazy evaluation, the number of samples will be driven by the num_samples arg to TracePredictive.
  • It should handle the case of mismatched cond_indep_stacks between train and test, as mentioned by @ahmadsalim in #1770.
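For the first point, one way to picture it (a sketch, not TracePredictive's actual API) is an unbounded stream of posterior draws with no num_samples of its own, from which the predictive step takes exactly as many draws as it needs. posterior_samples and predictive are hypothetical names, and the toy model y = z + 1 is made up for illustration.

```python
import itertools
import random
from typing import Dict, Iterator, List


def posterior_samples(seed: int = 0) -> Iterator[Dict[str, float]]:
    # Hypothetical endless sampler standing in for a fitted guide: it has no
    # num_samples of its own and yields one latent draw per iteration.
    rng = random.Random(seed)
    while True:
        yield {"z": rng.gauss(0.0, 1.0)}


def predictive(samples: Iterator[Dict[str, float]],
               num_samples: int) -> List[Dict[str, float]]:
    # The consumer decides how many draws to take, so the posterior object
    # never needs its own num_samples argument.
    return [{"y": s["z"] + 1.0} for s in itertools.islice(samples, num_samples)]


preds = predictive(posterior_samples(), num_samples=4)
print(len(preds))  # 4
```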

All 7 comments

+1. My use case: I use an AutoGuide, and SVI itself is fast, but generating 10000 samples from the posterior (to build marginal distributions) takes a while compared to the SVI time. I believe it would be faster if we generated only the site values instead of full traces. Currently, I do it this way:

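from pyro import poutine

# model and guide are defined elsewhere; this loop runs inside a generator
for i in range(10000):
    # sample latents from the guide, then replay the model against them
    guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
    model_poutine = poutine.trace(poutine.replay(model, trace=guide_trace))
    yield model_poutine.get_trace(*args, **kwargs), 1.0, 1.0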

It would be faster if I could do the following in the run method:

for i in range(10000):
    yield guide(*args, **kwargs), 1.0, 1.0

Here guide is an AutoGuide instance; calling it returns a dictionary that maps each latent site to its sampled value. Of course, I could already use the latter version to generate marginal distributions myself, but it would be nice to have official support for it.
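A rough sketch of the second approach in plain Python: fake_autoguide is a hypothetical stand-in for an AutoGuide call that returns {site name: value}, and collect_marginals (also hypothetical) accumulates those dicts into per-site sample lists without building any traces.

```python
import random
from typing import Callable, Dict, List


def fake_autoguide(rng: random.Random) -> Dict[str, float]:
    # Hypothetical stand-in for calling an AutoGuide: returns a dict that maps
    # each latent site name to one sampled value.
    return {"loc": rng.gauss(0.0, 1.0), "scale": abs(rng.gauss(1.0, 0.1))}


def collect_marginals(guide: Callable[[], Dict[str, float]],
                      num_samples: int) -> Dict[str, List[float]]:
    # Build per-site sample lists directly from the guide's outputs,
    # skipping trace construction and model replay entirely.
    samples: Dict[str, List[float]] = {}
    for _ in range(num_samples):
        for name, value in guide().items():
            samples.setdefault(name, []).append(value)
    return samples


rng = random.Random(0)
marginals = collect_marginals(lambda: fake_autoguide(rng), num_samples=1000)
print(sorted(marginals), len(marginals["loc"]))  # ['loc', 'scale'] 1000
```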

@neerajprad this seems like a reasonable change (and is how things used to work in Pyro very early on).

@fehiepsi I agree the first version is unnecessarily slow. We currently do it this way because the return values of the model and the guide may differ. We could simply switch to the second version and drop support for computing marginals over return values, or, once we switch to lazy aggregation, run the first version only when the user requests the return-value marginal. The other big source of slowdown is the Trace data structure, which inherits from networkx.DiGraph but should probably just be an OrderedDict.
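The lazy option could look roughly like the sketch below, where the expensive model replay only happens if the return-value marginal is requested. draw, guide and replay_model are hypothetical names; replay_model stands in for tracing poutine.replay(model, trace=guide_trace), and the guide here returns site values rather than a trace.

```python
from typing import Any, Callable, Dict, Iterator, Tuple

Sample = Dict[str, Any]


def draw(guide: Callable[[], Sample],
         replay_model: Callable[[Sample], Any],
         num_samples: int,
         want_return_value: bool = False) -> Iterator[Tuple[Sample, Any]]:
    # Fast path: only run the guide. Slow path: additionally replay the model
    # against the guide's latent values to obtain its return value.
    for _ in range(num_samples):
        latents = guide()
        ret = replay_model(latents) if want_return_value else None
        yield latents, ret


calls = {"model": 0}


def toy_guide() -> Sample:
    return {"z": 1.0}


def toy_replay(latents: Sample) -> float:
    # Counts how often the (expensive) model replay actually runs.
    calls["model"] += 1
    return latents["z"] * 2


list(draw(toy_guide, toy_replay, 3))
print(calls["model"])  # 0
list(draw(toy_guide, toy_replay, 3, want_return_value=True))
print(calls["model"])  # 3
```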

@eb8680 Yeah, I used minipyro to run some inference (while debugging MCMC on GPU) and it works pretty well. I hope we can switch to OrderedDict if it is a big source of slowdown.


This will also greatly help with debugging a bunch of multiprocessing errors that arise out of sharing the trace data structure amongst different processes.
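Part of the appeal is that plain containers are easy to move between processes: multiprocessing pickles the objects it passes between processes, and an OrderedDict of per-site dicts round-trips through pickle unchanged. The site fields below are a hypothetical subset of what a Pyro trace stores.

```python
import pickle
from collections import OrderedDict

# Hypothetical flat trace: an OrderedDict of site name -> site-info dict,
# standing in for an OrderedDict-based Trace.nodes.
nodes = OrderedDict()
nodes["x"] = {"type": "sample", "value": 1.5, "is_observed": False}
nodes["obs"] = {"type": "sample", "value": 0.3, "is_observed": True}

# multiprocessing pickles objects it sends between processes, so a plain
# round trip like this is what a worker-to-main transfer has to survive.
payload = pickle.dumps(nodes)
restored = pickle.loads(payload)
print(list(restored))  # ['x', 'obs']
```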

@eb8680 - How much effort do you think it would take to drop our dependency on networkx? I can take a look and start refactoring. Are there any unknowns we should be aware of?


The most minimal refactoring would be to copy the current Trace API exactly with something like the following:

from collections import OrderedDict, namedtuple


class Trace(namedtuple("_trace", ["nodes", "edges", "graph", "graph_type"])):

    def __new__(cls, graph_type="flat"):
        # namedtuple fields are set at construction, so override __new__ (not __init__)
        return super(Trace, cls).__new__(cls, nodes=OrderedDict(), edges=OrderedDict(),
                                         graph=OrderedDict(), graph_type=graph_type)

    def add_node(self, site_name, **kwargs):
        # duplicate checking logic here
        ...
        self.nodes[site_name] = kwargs.copy()

    def copy(self):
        # preserve current shallow copy semantics
        new_tr = Trace(graph_type=self.graph_type)
        new_tr.nodes.update(self.nodes)
        new_tr.edges.update(self.edges)
        new_tr.graph.update(self.graph)
        return new_tr

    ... # plus all the other methods of Trace that only depend on Trace.nodes/edges

I think the only place where functionality beyond the above is used is in TraceGraph_ELBO, but that could be ported fairly easily.
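One piece a port of TraceGraph_ELBO would likely need (an assumption about what it currently takes from networkx) is a topological ordering of the trace's sites. The sketch below implements Kahn's algorithm, assuming edges maps each source site to an OrderedDict of its successors, mirroring networkx's adjacency layout; topological_sort is a hypothetical helper, not an existing Pyro function.

```python
from collections import OrderedDict, deque
from typing import Dict, List


def topological_sort(nodes: "OrderedDict[str, dict]",
                     edges: Dict[str, "OrderedDict[str, dict]"]) -> List[str]:
    # Kahn's algorithm over an adjacency map edges[src] -> {dst: attrs}.
    # Insertion order of `nodes` breaks ties, so the result is deterministic.
    indegree = OrderedDict((name, 0) for name in nodes)
    for src, succs in edges.items():
        for dst in succs:
            indegree[dst] += 1
    ready = deque(name for name, deg in indegree.items() if deg == 0)
    order = []
    while ready:
        name = ready.popleft()
        order.append(name)
        for dst in edges.get(name, {}):
            indegree[dst] -= 1
            if indegree[dst] == 0:
                ready.append(dst)
    if len(order) != len(nodes):
        raise ValueError("trace graph contains a cycle")
    return order


nodes = OrderedDict((n, {}) for n in ["c", "a", "b"])
edges = {"a": OrderedDict(b={}), "b": OrderedDict(c={})}
print(topological_sort(nodes, edges))  # ['a', 'b', 'c']
```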


Great, I'll start taking a look at this then.

Closing this in favor of #1930.
