Pyro: [bug] JIT plus .mask() errors when batch_size changes

Created on 18 Mar 2019 · 5 Comments · Source: pyro-ppl/pyro

Issue Description

When using JIT, there is typically no problem passing in a minibatch with a smaller batch size (as would happen for the last minibatch in a dataloader). However, when the guide contains a masked distribution, the step fails with a RuntimeError complaining that tensor sizes do not match.
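
To illustrate how the smaller final minibatch arises, here is a minimal sketch using a standard PyTorch DataLoader. The dataset size of 25 rows is an arbitrary assumption chosen for illustration and is not part of the original report. Passing drop_last=True is one possible user-side way to keep the batch size constant, at the cost of skipping the leftover rows.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 25 fake rows with 200 features; 25 is not a multiple of the batch size.
dataset = TensorDataset(torch.randn(25, 200))

# Default behavior: the last minibatch holds the 5 leftover rows.
sizes = [batch[0].shape[0] for batch in DataLoader(dataset, batch_size=10)]

# drop_last=True skips the incomplete final minibatch entirely.
sizes_dropped = [
    batch[0].shape[0]
    for batch in DataLoader(dataset, batch_size=10, drop_last=True)
]

print(sizes, sizes_dropped)
```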

Environment

  • OS: reproduced on Mac 10.13.6 and Ubuntu 16.04
  • Python version: 3.6.5
  • PyTorch version: 1.0.0
  • Pyro version: 0.3.0

Code Snippet

The error can be replicated using this code snippet. (The model and data have no real meaning... I was just going for a minimal snippet to reproduce the error.)

With USE_JIT=False, the code runs without error.
With USE_JIT=True, the code produces an error.
With USE_JIT=True, if you manually comment out the .mask() statement in the guide, the code runs without error.

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO, JitTrace_ELBO
from pyro.optim import Adam
import torch

USE_JIT = True

def model(x):
    with pyro.plate("data", size=x.shape[0]):
        z = pyro.sample("z", dist.Normal(0., 5.).expand(x.shape).to_event(1))
        x = pyro.sample("obs", dist.Normal(z, 1.).to_event(1), obs=x)

def guide(x):
    offset = pyro.param("offset", torch.tensor([0.]))
    with pyro.plate("data", size=x.shape[0]):
        masking = dist.Bernoulli(logits=x.sum(dim=-1, keepdim=False)).sample()
        z = pyro.sample("z", dist.Normal(offset, 0.1).expand(x.shape).to_event(1).mask(masking))

pyro.clear_param_store()
pyro.enable_validation(True)

loss = JitTrace_ELBO() if USE_JIT else Trace_ELBO()
svi = SVI(model, guide, optim=Adam({'lr': 1e-2}), loss=loss)

offset_param = 5.

for i in range(5):
    normalizer = 0.
    epoch_loss = 0.
    for _ in range(200):
        # generate a fake minibatch of data
        x = torch.randn((10, 200))
        x = x + 2 * offset_param * (x.sum(dim=-1, keepdim=True) > 0).float()
        # train
        epoch_loss += svi.step(x)
        normalizer += x.shape[0]

    print(f'epoch {i} loss = {epoch_loss/normalizer}')

print(f'\noffset = {pyro.param("offset").detach()}')

# now pass in a minibatch of data with smaller size
x = torch.randn((5, 200))
svi.step(x)  # induces error

The error message is the following:

RuntimeError: 
The size of tensor a (10) must match the size of tensor b (5) at non-singleton dimension 0 (infer_size at /Users/administrator/nightlies/pytorch-1.0.0/wheel_build_dirs/wheel_3.6/pytorch/aten/src/ATen/ExpandUtils.cpp:22)
frame #0: c10::Error::Error(c10::SourceLocation, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > const&) + 135 (0x116583f47 in libc10.dylib)
frame #1: at::infer_size(c10::ArrayRef<long long>, c10::ArrayRef<long long>) + 523 (0x11b66a43b in libcaffe2.dylib)
frame #2: at::TensorIterator::compute_shape() + 466 (0x11b8a52a2 in libcaffe2.dylib)
frame #3: at::TensorIterator::Builder::build() + 39 (0x11b8a4697 in libcaffe2.dylib)
frame #4: at::TensorIterator::binary_op(at::Tensor&, at::Tensor const&, at::Tensor const&) + 454 (0x11b8a4216 in libcaffe2.dylib)
frame #5: at::native::sub_out(at::Tensor&, at::Tensor const&, at::Tensor const&, c10::Scalar) + 745 (0x11b732d19 in libcaffe2.dylib)
frame #6: at::native::sub_(at::Tensor&, at::Tensor const&, c10::Scalar) + 48 (0x11b7335b0 in libcaffe2.dylib)
frame #7: at::TypeDefault::sub_(at::Tensor&, at::Tensor const&, c10::Scalar) const + 247 (0x11bb356f7 in libcaffe2.dylib)
frame #8: torch::autograd::VariableType::sub_(at::Tensor&, at::Tensor const&, c10::Scalar) const + 1101 (0x1246a5add in libtorch.1.dylib)
frame #9: std::__1::__function::__func<torch::jit::(anonymous namespace)::$_494, std::__1::allocator<torch::jit::(anonymous namespace)::$_494>, int (std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue> >&)>::operator()(std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue> >&) + 163 (0x124bb5293 in libtorch.1.dylib)
frame #10: torch::jit::InterpreterStateImpl::runImpl(std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue> >&) + 245 (0x124c2a2e5 in libtorch.1.dylib)
frame #11: torch::jit::InterpreterStateImpl::run(std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue> >&) + 28 (0x124c229dc in libtorch.1.dylib)
frame #12: torch::jit::GraphExecutorImpl::run(std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue> >&) + 4389 (0x124bf1ad5 in libtorch.1.dylib)
frame #13: torch::jit::script::Method::run(std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue> >&) + 216 (0x116156d48 in libtorch_python.dylib)
frame #14: torch::jit::invokeScriptMethodFromPython(torch::jit::script::Method&, torch::jit::tuple_slice, pybind11::kwargs) + 163 (0x116156bb3 in libtorch_python.dylib)
frame #15: void pybind11::cpp_function::initialize<torch::jit::script::initJitScriptBindings(_object*)::$_21, pybind11::object, pybind11::args, pybind11::kwargs, pybind11::name, pybind11::is_method, pybind11::sibling>(torch::jit::script::initJitScriptBindings(_object*)::$_21&&, pybind11::object (*)(pybind11::args, pybind11::kwargs), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&)::'lambda'(pybind11::detail::function_call&)::__invoke(pybind11::detail::function_call&) + 269 (0x1161580cd in libtorch_python.dylib)
frame #16: pybind11::cpp_function::dispatcher(_object*, _object*, _object*) + 3348 (0x115dadb34 in libtorch_python.dylib)
<omitting python frames>
:
operation failed in interpreter:
/anaconda3/lib/python3.6/site-packages/pyro/distributions/util.py(189): scale_and_mask
/anaconda3/lib/python3.6/site-packages/pyro/distributions/score_parts.py(23): scale_and_mask
/anaconda3/lib/python3.6/site-packages/pyro/distributions/torch_distribution.py(278): score_parts
/anaconda3/lib/python3.6/site-packages/pyro/poutine/trace_struct.py(192): compute_score_parts
/anaconda3/lib/python3.6/site-packages/pyro/infer/enum.py(52): get_importance_trace
/anaconda3/lib/python3.6/site-packages/pyro/infer/trace_elbo.py(52): _get_trace
/anaconda3/lib/python3.6/site-packages/pyro/infer/elbo.py(163): _get_traces
/anaconda3/lib/python3.6/site-packages/pyro/infer/trace_elbo.py(170): loss_and_surrogate_loss
/anaconda3/lib/python3.6/site-packages/pyro/poutine/messenger.py(27): _wraps
/anaconda3/lib/python3.6/site-packages/pyro/ops/jit.py(84): compiled
/anaconda3/lib/python3.6/site-packages/torch/jit/__init__.py(635): trace
/anaconda3/lib/python3.6/site-packages/pyro/ops/jit.py(87): __call__
/anaconda3/lib/python3.6/site-packages/pyro/infer/trace_elbo.py(203): loss_and_surrogate_loss
/anaconda3/lib/python3.6/site-packages/pyro/infer/trace_elbo.py(212): loss_and_grads
/anaconda3/lib/python3.6/site-packages/pyro/infer/svi.py(99): step
<ipython-input-7-a390225aa054>(36): <module>
/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py(3267): run_code
/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py(3185): run_ast_nodes
/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py(3020): run_cell_async
/anaconda3/lib/python3.6/site-packages/IPython/core/async_helpers.py(67): _pseudo_sync_runner
/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py(2845): _run_cell
/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py(2819): run_cell
/anaconda3/lib/python3.6/site-packages/ipykernel/zmqshell.py(537): run_cell
/anaconda3/lib/python3.6/site-packages/ipykernel/ipkernel.py(208): do_execute
/anaconda3/lib/python3.6/site-packages/ipykernel/kernelbase.py(399): execute_request
/anaconda3/lib/python3.6/site-packages/ipykernel/kernelbase.py(233): dispatch_shell
/anaconda3/lib/python3.6/site-packages/ipykernel/kernelbase.py(283): dispatcher
/anaconda3/lib/python3.6/site-packages/tornado/stack_context.py(276): null_wrapper
/anaconda3/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py(432): _run_callback
/anaconda3/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py(480): _handle_recv
/anaconda3/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py(450): _handle_events
/anaconda3/lib/python3.6/site-packages/tornado/stack_context.py(276): null_wrapper
/anaconda3/lib/python3.6/site-packages/tornado/platform/asyncio.py(117): _handle_events
/anaconda3/lib/python3.6/asyncio/events.py(145): _run
/anaconda3/lib/python3.6/asyncio/base_events.py(1432): _run_once
/anaconda3/lib/python3.6/asyncio/base_events.py(422): run_forever
/anaconda3/lib/python3.6/site-packages/tornado/platform/asyncio.py(127): start
/anaconda3/lib/python3.6/site-packages/ipykernel/kernelapp.py(486): start
/anaconda3/lib/python3.6/site-packages/traitlets/config/application.py(658): launch_instance
/anaconda3/lib/python3.6/site-packages/ipykernel_launcher.py(16): <module>
/anaconda3/lib/python3.6/runpy.py(85): _run_code
/anaconda3/lib/python3.6/runpy.py(193): _run_module_as_main

Labels: bug, jit

All 5 comments

Thanks for the clear bug report @sjfleming !

This appears to be a bug in scale_and_mask(). I am confused, since I thought we fixed this, and we even have a test_masked_fill() that exercises resizing. @neerajprad any ideas?

Thanks for the bug report. We'll add this as a test to test_jit.py, and debug further.

I have filed an issue for this in PyTorch. The problem is with the ~ operator, which doesn't seem to generalize to different tensor shapes under JIT, and we use ~mask in the scale_and_mask function. I suppose we could work around this by passing the inverted mask directly, but I don't think there is a backward-compatible workaround that we could put in Pyro's dev branch.

@neerajprad would mask == 0 or 255 - mask serve as a workaround for ~mask? Nice sleuthing!

(mask == 0) works great! I can add that with a test.
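
The workaround discussed above can be sketched as follows. This is only an illustration under stated assumptions: mask_log_prob is a hypothetical helper, not Pyro's actual scale_and_mask, and the behavior shown is for a recent PyTorch release rather than 1.0.0. It builds the inverted mask with a comparison (mask == 0) instead of ~mask, traces the function at batch size 10, and then calls it at batch size 5.

```python
import torch

def mask_log_prob(log_prob, mask):
    # Hypothetical helper (not Pyro's scale_and_mask): zero out the
    # log-probability wherever the mask is 0.
    inverted = mask == 0  # comparison used in place of `~mask`
    return log_prob.masked_fill(inverted, 0.0)

# Trace with a batch of 10, as in the training loop of the report.
example_inputs = (torch.full((10,), -2.0), torch.ones(10, dtype=torch.bool))
traced = torch.jit.trace(mask_log_prob, example_inputs)

# Call the traced function with a smaller batch of 5.
small_mask = torch.tensor([True, False, True, False, True])
small = traced(torch.full((5,), -2.0), small_mask)
print(small)
```

The masked entries come back as zeros and the unmasked ones keep their original value, so the traced graph is not tied to the batch size it was traced with.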
