The following code snippet shows that Stax MLPs can be defined w.r.t. unbatched examples (`input_shape = (1,)`), while convnets seem to require a batch dimension (though its size can be `-1`). Is this intended behavior?
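```python
# Imports as of this thread (2019). In newer JAX, stax lives in
# jax.example_libraries.stax and init functions are called as net_init(rng, in_shape).
from jax.experimental import stax
from jax.experimental.stax import (Conv, Dense, Flatten, LogSoftmax,
                                   MaxPool, Relu)

# Works
net_init, net_apply = stax.serial(
    Conv(32, (3, 3), padding='SAME'), Relu,
    Conv(64, (3, 3), padding='SAME'), Relu,
    MaxPool((2, 2)), Flatten,
    Dense(128), Relu,
    Dense(10), LogSoftmax,
)
# Initialize parameters, not committing to a batch shape
in_shape = (-1, 28, 28, 1)
out_shape, net_params = net_init(in_shape)

# Works: an MLP initialized on a single unbatched example
net_init, net_apply = stax.serial(
    Dense(40), Relu,
    Dense(40), Relu,
    Dense(1)
)
in_shape = (1,)
out_shape, net_params = net_init(in_shape)

# Doesn't work: the input shape has no batch dimension
net_init, net_apply = stax.serial(
    Conv(32, (3, 3), padding='SAME'), Relu,
    Conv(64, (3, 3), padding='SAME'), Relu,
    MaxPool((2, 2)), Flatten,
    Dense(128), Relu,
    Dense(10), LogSoftmax,
)
in_shape = (28, 28, 1)
out_shape, net_params = net_init(in_shape)  # raises IndexError
```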
The last one returns the following error:

```
IndexError Traceback (most recent call last)
9 # Initialize parameters, not committing to a batch shape
10 in_shape = (28, 28, 1)
---> 11 out_shape, net_params = net_init(in_shape)

google3/third_party/py/jax/experimental/stax.py in init_fun(input_shape)
269 params = []
270 for init_fun in init_funs:
--> 271 input_shape, param = init_fun(input_shape)
272 params.append(param)
273 return input_shape, params

google3/third_party/py/jax/experimental/stax.py in init_fun(input_shape)
109 kernel_shape = [out_chan if c == 'O' else
110 input_shape[lhs_spec.index('C')] if c == 'I' else
--> 111 next(filter_shape_iter) for c in rhs_spec]
112 output_shape = lax.conv_general_shape_tuple(
113 input_shape, kernel_shape, strides, padding, dimension_numbers)

IndexError: tuple index out of range
```
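The `IndexError` comes from the kernel-shape comprehension in the second frame. The plain-Python sketch below copies that comprehension into a hypothetical helper, `kernel_shape_for`, assuming the `('NHWC', 'HWIO', 'NHWC')` layout that `stax.Conv` uses. Because `'C'` sits at index 3 of `'NHWC'`, the input-channel lookup reads `input_shape[3]`, which only exists when a leading batch entry is present.

```python
def kernel_shape_for(input_shape, out_chan=32, filter_shape=(3, 3),
                     lhs_spec='NHWC', rhs_spec='HWIO'):
    """Hypothetical helper mirroring the comprehension in the traceback."""
    filter_shape_iter = iter(filter_shape)
    # 'O' -> output channels; 'I' -> input channels looked up in input_shape
    # at the position of 'C' in lhs_spec; spatial letters -> filter_shape entries.
    return [out_chan if c == 'O' else
            input_shape[lhs_spec.index('C')] if c == 'I' else
            next(filter_shape_iter) for c in rhs_spec]

print(kernel_shape_for((-1, 28, 28, 1)))  # [3, 3, 1, 32]

try:
    kernel_shape_for((28, 28, 1))  # no batch entry: input_shape[3] is out of range
except IndexError as err:
    print("IndexError:", err)  # IndexError: tuple index out of range
```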
Actually, rereading the issue, I think this is originally-intended behavior, but could be revised.
`stax.Conv` (which perhaps should be called `Conv2D`) requires a batch dimension essentially because the underlying XLA HLO (and the corresponding `lax` function) requires one: see https://www.tensorflow.org/xla/operation_semantics#conv_convolution. Notice how lhs and rhs must have rank n+2 for n spatial dimensions: +1 for a channel dimension and +1 for a batch dimension.
We could revise the stax layer and/or the underlying lax primitive to work without a batch dimension (probably the latter so that it's easier to inherit the behavior in the former). Would that be useful to you? I'm guessing that, in a vmap world, we should take seriously the fact that we can remove batch dimensions from all our library code, including stax.
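To make the rank n+2 requirement concrete, here is a minimal sketch at the `lax` level, assuming a current JAX install (`jax.lax.conv_general_dilated`), not the 2019-era code in this thread. A single HWC image has rank 3, so it needs a singleton batch axis before it can be convolved with a rank-4 HWIO kernel; passing the rank-3 image directly fails with a shape error.

```python
import jax.numpy as jnp
from jax import lax

# One 28x28 single-channel image with no batch dimension (HWC, rank 3).
image = jnp.ones((28, 28, 1))
# HWIO kernel: 3x3 spatial window, 1 input channel, 32 output channels (rank 4).
kernel = jnp.ones((3, 3, 1, 32))

# Same layout that stax.Conv uses.
dn = ('NHWC', 'HWIO', 'NHWC')

# n = 2 spatial dims, so both operands must have rank n + 2 = 4:
# add a singleton batch axis to the image first.
out = lax.conv_general_dilated(
    image[None], kernel,
    window_strides=(1, 1), padding='SAME',
    dimension_numbers=dn)
print(out.shape)     # (1, 28, 28, 32)

# Dropping the singleton axis recovers an "unbatched" result.
print(out[0].shape)  # (28, 28, 32)
```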
I am fine with refactoring my code to use batched convs (that’s reasonable, given that the underlying primitive is vectorized for efficiency). But yeah, this was surprising, because I was operating under the assumption that the “vmap world” you mention should support defining nets that don’t take the batch dimension into account.
(In reply to Matthew Johnson, Fri, Feb 15, 2019 at 6:44 AM.)
If we removed the batch dimension from stax, it's not obvious to me how to define batch norm.
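One possible answer, not proposed in this thread, is to write batch norm per example and get the batch statistics through a named `vmap` axis. This is a minimal sketch assuming a JAX version where `jax.vmap` accepts `axis_name` and collectives such as `lax.pmean` work under `vmap`. The function `batch_norm_single` and the axis name `'batch'` are hypothetical, and this is not stax's own `BatchNorm` layer.

```python
import jax
import jax.numpy as jnp
from jax import lax

def batch_norm_single(x, eps=1e-5):
    """Hypothetical per-example batch norm (no learned scale/offset).

    x holds one example's features; the mean and variance are averaged
    across the vmapped axis named 'batch' (a name chosen for this sketch).
    """
    mean = lax.pmean(x, axis_name='batch')
    var = lax.pmean((x - mean) ** 2, axis_name='batch')
    return (x - mean) / jnp.sqrt(var + eps)

batch = jnp.arange(12.0).reshape(4, 3)  # 4 examples, 3 features each
normed = jax.vmap(batch_norm_single, axis_name='batch')(batch)
print(normed.mean(axis=0))  # close to zero for every feature
```

The per-example function never mentions a batch dimension; the batch only appears when `vmap` introduces the named axis.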
Agreed, there definitely should be the flexibility to define batched nets (for models whose forward pass requires minibatches). But given the behavior of the MLP, a user can easily assume that Conv would automatically prepend a singleton batch dimension under the hood if ndims == 3. It is confusing if some primitives assume batching while others do not.
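The behavior suggested here (prepending a singleton batch dimension for a single example) can be sketched outside of stax as a layer combinator. `AddSingletonBatch` is a hypothetical helper, not part of stax; the sketch assumes current JAX, where stax lives in `jax.example_libraries.stax` and `init_fun` takes a PRNG key before the input shape.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax
from jax.example_libraries.stax import Conv, Dense, Flatten, Relu

def AddSingletonBatch(layer):
    """Hypothetical combinator: lets a batch-requiring stax layer take one example."""
    init, apply = layer

    def init_fun(rng, input_shape):
        # Give the wrapped layer a batch entry of 1, then strip it from the output shape.
        out_shape, params = init(rng, (1,) + tuple(input_shape))
        return out_shape[1:], params

    def apply_fun(params, inputs, **kwargs):
        # Add a leading batch axis, run the wrapped layer, and drop the axis again.
        return apply(params, inputs[None], **kwargs)[0]

    return init_fun, apply_fun

net_init, net_apply = AddSingletonBatch(stax.serial(
    Conv(8, (3, 3), padding='SAME'), Relu, Flatten, Dense(10)))

out_shape, params = net_init(jax.random.PRNGKey(0), (28, 28, 1))
print(out_shape)                                       # (10,)
print(net_apply(params, jnp.ones((28, 28, 1))).shape)  # (10,)
```

Because the wrapped net is written for one example, `jax.vmap(lambda x: net_apply(params, x))` can then map it over a batch.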
(In reply to Peter Hawkins, Fri, Feb 15, 2019 at 7:02 AM.)
The catch with allowing a batch dimension to be omitted in that way would be that then we would have an ambiguity when we support conv layers with different numbers of spatial dimensions. We could fix that by requiring, say, the conv layer to have an explicitly specified spatial dimension (e.g., Conv2D instead of simply Conv.)
Btw, `stax.Conv` already is really a `Conv2D`, with `stax.GeneralConv` allowing arbitrary spatial dimensions:

```python
Conv = functools.partial(GeneralConv, ('NHWC', 'HWIO', 'NHWC'))
```
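A minimal sketch of `GeneralConv` with one spatial dimension, assuming current JAX (`jax.example_libraries.stax`, with a PRNG key passed to `init_fun`); the `('NWC', 'WIO', 'NWC')` layout and all sizes are illustrative choices. The input here is rank 3 but *batched* (batch, width, channels), whereas a rank-3 input to `stax.Conv` would be one unbatched image. That is exactly the ambiguity raised above about letting the batch dimension be omitted.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries.stax import GeneralConv

# 1-D convolution: 16 output channels, window of width 5, 'SAME' padding.
conv1d_init, conv1d_apply = GeneralConv(('NWC', 'WIO', 'NWC'), 16, (5,),
                                        padding='SAME')

# A batch of 8 sequences, each of width 100 with 3 channels (rank 3, batched).
out_shape, params = conv1d_init(jax.random.PRNGKey(0), (8, 100, 3))
W, b = params
print(tuple(int(d) for d in out_shape))  # (8, 100, 16)
print(W.shape)                           # (5, 3, 16), in 'WIO' order

y = conv1d_apply(params, jnp.ones((8, 100, 3)))
print(y.shape)                           # (8, 100, 16)
```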
I don’t have an opinion on GeneralConv vs. Conv2D etc., but in the short term, an error message explaining why Conv requires a batch dimension would help with debugging. Some libraries (DyNet) do support autobatching, and given vmap, I assumed that this was also the case in JAX.
(In reply to Peter Hawkins, Fri, Feb 15, 2019 at 7:08 AM.)
Just to clarify, you can vmap over functions that call lax.conv just fine; AIUI this is mainly a question of the (experimental) Stax API layer.
To put a finer point on it: we like to think in terms of _expressing_ and _transforming_ functions. Expressing means just writing Python+NumPy(+lax) code. Transformations are things like automatic differentiation, vmap, jit, etc.
This is actually an issue about expressing an _unbatched_ conv: our stax.Conv layer currently _requires_ a batch dimension (as does the underlying lax.conv function).
To contrast, it's not an issue about transforming (autobatching) convs: you can add batch dimensions to your heart's content. But you have to start with a minimum of one batch dimension. As you point out, that's something peculiar to conv and not shared by operations like dot.
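The expressing/transforming distinction can be sketched as follows, assuming current JAX. `jnp.dot` can be written for a single unbatched vector, and `vmap` stacks as many batch dimensions on top of it as you like. A per-example `lax` convolution, by contrast, has to carry a singleton batch axis internally before `vmap` can map over it. The names `w`, `kernel`, `dense_single`, and `conv_single` are illustrative.

```python
import jax
import jax.numpy as jnp
from jax import lax

w = jnp.ones((3, 4))

def dense_single(x):      # x: (3,), no batch dimension at all
    return jnp.dot(x, w)  # -> (4,)

# vmap adds one batch dimension, or two when nested.
print(jax.vmap(dense_single)(jnp.ones((5, 3))).shape)                # (5, 4)
print(jax.vmap(jax.vmap(dense_single))(jnp.ones((2, 5, 3))).shape)  # (2, 5, 4)

kernel = jnp.ones((3, 3, 1, 8))  # HWIO

def conv_single(img):     # img: (28, 28, 1)
    # lax convolution needs at least one batch axis, so add and remove one here.
    out = lax.conv_general_dilated(img[None], kernel, (1, 1), 'SAME',
                                   dimension_numbers=('NHWC', 'HWIO', 'NHWC'))
    return out[0]         # -> (28, 28, 8)

print(jax.vmap(conv_single)(jnp.ones((5, 28, 28, 1))).shape)  # (5, 28, 28, 8)
```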
To summarize, I'd revise your statement to say that JAX supports autobatching just fine (vmap transformations can add arbitrary batch dimensions), but our stax.Conv and lax.conv have this peculiarity inherited from XLA that you can't directly express a convolution with no batch dimensions.
In any case, I think we agree that we should figure out a way to tweak stax.Conv and/or lax.conv to enable expressing unbatched convolution operations.
We've got a plan! Will update this issue as we make progress.
+1 on making Stax operate on single examples.
+1, just got confused by `net_apply(net_params, batch_X)` working, but `vmap(partial(net_apply, net_params))(batch_X)` failing with a cryptic shape error for a simple conv net in stax...
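A minimal sketch of why the `vmap` version fails and one workaround, assuming current JAX (`jax.example_libraries.stax`, with a PRNG key passed to `net_init`). Under `vmap`, `net_apply` sees one rank-3 example, so `Conv` (and `Flatten`, which reshapes to `(inputs.shape[0], -1)`) no longer receive the batch axis they expect. Wrapping the per-example call to add and remove a singleton batch axis, via the hypothetical helper `apply_single`, makes the vmapped result match the batched call.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.example_libraries import stax
from jax.example_libraries.stax import Conv, Dense, Flatten, LogSoftmax, Relu

net_init, net_apply = stax.serial(
    Conv(8, (3, 3), padding='SAME'), Relu, Flatten, Dense(10), LogSoftmax)
# Parameter shapes do not depend on the batch size, so a batch of 1 suffices.
_, net_params = net_init(jax.random.PRNGKey(0), (1, 28, 28, 1))

batch_X = jax.random.normal(jax.random.PRNGKey(1), (5, 28, 28, 1))
batched = net_apply(net_params, batch_X)  # Conv sees a rank-4 input: fine

def apply_single(params, x):
    """Hypothetical per-example wrapper: x is one (28, 28, 1) image."""
    return net_apply(params, x[None])[0]

per_example = jax.vmap(partial(apply_single, net_params))(batch_X)
print(per_example.shape)                                # (5, 10)
print(bool(jnp.allclose(per_example, batched, atol=1e-4)))  # True
```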
+1 looking forward to having JAX fully batch agnostic! :)