When using JIT, there is typically no problem passing in a minibatch with a smaller batch size (as happens for the last minibatch from a dataloader). However, when the guide contains a masked distribution, an error is raised complaining that tensor sizes do not match.
The error can be replicated using this code snippet. (The model and data have no real meaning... I was just going for a minimal snippet to reproduce the error.)
With USE_JIT=False, the code runs without error.
With USE_JIT=True, the code produces an error.
With USE_JIT=True, if you manually comment out the .mask() statement in the guide, the code runs without error.
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO, JitTrace_ELBO
from pyro.optim import Adam
import torch

USE_JIT = True


def model(x):
    with pyro.plate("data", size=x.shape[0]):
        z = pyro.sample("z", dist.Normal(0., 5.).expand(x.shape).to_event(1))
        x = pyro.sample("obs", dist.Normal(z, 1.).to_event(1), obs=x)


def guide(x):
    offset = pyro.param("offset", torch.Tensor([0.]))
    with pyro.plate("data", size=x.shape[0]):
        masking = dist.Bernoulli(logits=x.sum(dim=-1, keepdim=False)).sample()
        z = pyro.sample("z", dist.Normal(offset, 0.1).expand(x.shape).to_event(1).mask(masking))


pyro.clear_param_store()
pyro.enable_validation(True)

loss = JitTrace_ELBO() if USE_JIT else Trace_ELBO()
svi = SVI(model, guide, optim=Adam({'lr': 1e-2}), loss=loss)

offset_param = 5.
for i in range(5):
    normalizer = 0.
    epoch_loss = 0.
    for _ in range(200):
        # generate a fake minibatch of data
        x = torch.randn((10, 200))
        x = x + 2 * offset_param * (x.sum(dim=-1, keepdim=True) > 0).float()
        # train
        epoch_loss += svi.step(x)
        normalizer += x.shape[0]
    print(f'epoch {i} loss = {epoch_loss/normalizer}')
print(f'\noffset = {pyro.param("offset").detach()}')

# now pass in a minibatch of data with smaller size
x = torch.randn((5, 200))
svi.step(x)  # induces error
The error message is the following:
RuntimeError:
The size of tensor a (10) must match the size of tensor b (5) at non-singleton dimension 0 (infer_size at /Users/administrator/nightlies/pytorch-1.0.0/wheel_build_dirs/wheel_3.6/pytorch/aten/src/ATen/ExpandUtils.cpp:22)
frame #0: c10::Error::Error(c10::SourceLocation, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > const&) + 135 (0x116583f47 in libc10.dylib)
frame #1: at::infer_size(c10::ArrayRef<long long>, c10::ArrayRef<long long>) + 523 (0x11b66a43b in libcaffe2.dylib)
frame #2: at::TensorIterator::compute_shape() + 466 (0x11b8a52a2 in libcaffe2.dylib)
frame #3: at::TensorIterator::Builder::build() + 39 (0x11b8a4697 in libcaffe2.dylib)
frame #4: at::TensorIterator::binary_op(at::Tensor&, at::Tensor const&, at::Tensor const&) + 454 (0x11b8a4216 in libcaffe2.dylib)
frame #5: at::native::sub_out(at::Tensor&, at::Tensor const&, at::Tensor const&, c10::Scalar) + 745 (0x11b732d19 in libcaffe2.dylib)
frame #6: at::native::sub_(at::Tensor&, at::Tensor const&, c10::Scalar) + 48 (0x11b7335b0 in libcaffe2.dylib)
frame #7: at::TypeDefault::sub_(at::Tensor&, at::Tensor const&, c10::Scalar) const + 247 (0x11bb356f7 in libcaffe2.dylib)
frame #8: torch::autograd::VariableType::sub_(at::Tensor&, at::Tensor const&, c10::Scalar) const + 1101 (0x1246a5add in libtorch.1.dylib)
frame #9: std::__1::__function::__func<torch::jit::(anonymous namespace)::$_494, std::__1::allocator<torch::jit::(anonymous namespace)::$_494>, int (std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue> >&)>::operator()(std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue> >&) + 163 (0x124bb5293 in libtorch.1.dylib)
frame #10: torch::jit::InterpreterStateImpl::runImpl(std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue> >&) + 245 (0x124c2a2e5 in libtorch.1.dylib)
frame #11: torch::jit::InterpreterStateImpl::run(std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue> >&) + 28 (0x124c229dc in libtorch.1.dylib)
frame #12: torch::jit::GraphExecutorImpl::run(std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue> >&) + 4389 (0x124bf1ad5 in libtorch.1.dylib)
frame #13: torch::jit::script::Method::run(std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue> >&) + 216 (0x116156d48 in libtorch_python.dylib)
frame #14: torch::jit::invokeScriptMethodFromPython(torch::jit::script::Method&, torch::jit::tuple_slice, pybind11::kwargs) + 163 (0x116156bb3 in libtorch_python.dylib)
frame #15: void pybind11::cpp_function::initialize<torch::jit::script::initJitScriptBindings(_object*)::$_21, pybind11::object, pybind11::args, pybind11::kwargs, pybind11::name, pybind11::is_method, pybind11::sibling>(torch::jit::script::initJitScriptBindings(_object*)::$_21&&, pybind11::object (*)(pybind11::args, pybind11::kwargs), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&)::'lambda'(pybind11::detail::function_call&)::__invoke(pybind11::detail::function_call&) + 269 (0x1161580cd in libtorch_python.dylib)
frame #16: pybind11::cpp_function::dispatcher(_object*, _object*, _object*) + 3348 (0x115dadb34 in libtorch_python.dylib)
<omitting python frames>
:
operation failed in interpreter:
/anaconda3/lib/python3.6/site-packages/pyro/distributions/util.py(189): scale_and_mask
/anaconda3/lib/python3.6/site-packages/pyro/distributions/score_parts.py(23): scale_and_mask
/anaconda3/lib/python3.6/site-packages/pyro/distributions/torch_distribution.py(278): score_parts
/anaconda3/lib/python3.6/site-packages/pyro/poutine/trace_struct.py(192): compute_score_parts
/anaconda3/lib/python3.6/site-packages/pyro/infer/enum.py(52): get_importance_trace
/anaconda3/lib/python3.6/site-packages/pyro/infer/trace_elbo.py(52): _get_trace
/anaconda3/lib/python3.6/site-packages/pyro/infer/elbo.py(163): _get_traces
/anaconda3/lib/python3.6/site-packages/pyro/infer/trace_elbo.py(170): loss_and_surrogate_loss
/anaconda3/lib/python3.6/site-packages/pyro/poutine/messenger.py(27): _wraps
/anaconda3/lib/python3.6/site-packages/pyro/ops/jit.py(84): compiled
/anaconda3/lib/python3.6/site-packages/torch/jit/__init__.py(635): trace
/anaconda3/lib/python3.6/site-packages/pyro/ops/jit.py(87): __call__
/anaconda3/lib/python3.6/site-packages/pyro/infer/trace_elbo.py(203): loss_and_surrogate_loss
/anaconda3/lib/python3.6/site-packages/pyro/infer/trace_elbo.py(212): loss_and_grads
/anaconda3/lib/python3.6/site-packages/pyro/infer/svi.py(99): step
<ipython-input-7-a390225aa054>(36): <module>
/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py(3267): run_code
/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py(3185): run_ast_nodes
/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py(3020): run_cell_async
/anaconda3/lib/python3.6/site-packages/IPython/core/async_helpers.py(67): _pseudo_sync_runner
/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py(2845): _run_cell
/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py(2819): run_cell
/anaconda3/lib/python3.6/site-packages/ipykernel/zmqshell.py(537): run_cell
/anaconda3/lib/python3.6/site-packages/ipykernel/ipkernel.py(208): do_execute
/anaconda3/lib/python3.6/site-packages/ipykernel/kernelbase.py(399): execute_request
/anaconda3/lib/python3.6/site-packages/ipykernel/kernelbase.py(233): dispatch_shell
/anaconda3/lib/python3.6/site-packages/ipykernel/kernelbase.py(283): dispatcher
/anaconda3/lib/python3.6/site-packages/tornado/stack_context.py(276): null_wrapper
/anaconda3/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py(432): _run_callback
/anaconda3/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py(480): _handle_recv
/anaconda3/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py(450): _handle_events
/anaconda3/lib/python3.6/site-packages/tornado/stack_context.py(276): null_wrapper
/anaconda3/lib/python3.6/site-packages/tornado/platform/asyncio.py(117): _handle_events
/anaconda3/lib/python3.6/asyncio/events.py(145): _run
/anaconda3/lib/python3.6/asyncio/base_events.py(1432): _run_once
/anaconda3/lib/python3.6/asyncio/base_events.py(422): run_forever
/anaconda3/lib/python3.6/site-packages/tornado/platform/asyncio.py(127): start
/anaconda3/lib/python3.6/site-packages/ipykernel/kernelapp.py(486): start
/anaconda3/lib/python3.6/site-packages/traitlets/config/application.py(658): launch_instance
/anaconda3/lib/python3.6/site-packages/ipykernel_launcher.py(16): <module>
/anaconda3/lib/python3.6/runpy.py(85): _run_code
/anaconda3/lib/python3.6/runpy.py(193): _run_module_as_main
Thanks for the clear bug report @sjfleming !
This appears to be a bug in scale_and_mask(). I am confused, since I thought we fixed this, and we even have a test_masked_fill() that exercises resizing. @neerajprad any ideas?
Thanks for the bug report. We'll add this as a test to test_jit.py, and debug further.
I have filed an issue for this in PyTorch: the problem is with the ~ operator, which doesn't seem to generalize to different tensor shapes under JIT, and we use ~mask in the scale_and_mask function. I suppose we could work around it by passing the inverted mask directly, but I don't think there is a backward-compatible workaround that we could put in Pyro's dev branch.
@neerajprad would mask == 0 or 255 - mask serve as a workaround for ~mask? Nice sleuthing!
(mask == 0) works great! I can add that with a test.
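For anyone following along, here is a minimal sketch of the idea behind the (mask == 0) workaround, using plain torch.jit.trace rather than Pyro's actual scale_and_mask. The helper name masked_log_prob and the toy tensors are hypothetical, and the sketch assumes a recent PyTorch release where masks are bool tensors. It traces the function with a batch of 10 and then calls the traced version with a batch of 5, which is the situation from the report above.

```python
import torch


def masked_log_prob(log_prob, mask):
    # Hypothetical stand-in for the masking step, not Pyro's implementation.
    # (mask == 0) is True wherever the mask is off; those entries are zeroed.
    return log_prob.masked_fill(mask == 0, 0.0)


# Trace once with a batch of 10, mirroring a first full-size minibatch.
example_log_prob = torch.full((10,), -1.5)
example_mask = torch.arange(10) % 2 == 0
traced = torch.jit.trace(masked_log_prob, (example_log_prob, example_mask))

# Call the traced function with a smaller batch of 5.
small_log_prob = torch.full((5,), -1.5)
small_mask = torch.arange(5) % 2 == 0
result = traced(small_log_prob, small_mask)
print(result)
```

The comparison against a scalar does not depend on the size of the mask seen at trace time, so the traced function should accept the smaller batch. This only illustrates the inversion step; it says nothing about the rest of scale_and_mask.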