incubator-mxnet: Get HybridBlock layer shape at runtime

Created on 3 Jan 2018 · 18 comments · Source: apache/incubator-mxnet

Dear all,

I am trying to build a custom pooling layer (for both NDArray and Symbol) and I need to know the input shape at runtime. According to the documentation, HybridBlock has an "infer_shape" function, but I can't make it work. Any pointers on what I am doing wrong?

MXNet version

1.0.0, built from conda, Python 3.

Minimum reproducible example

For example:

import mxnet as mx
import mxnet.ndarray as nd
from mxnet.gluon import HybridBlock

class runtime_shape(HybridBlock):


    def __init__(self,  **kwards):
        HybridBlock.__init__(self,**kwards)


    def hybrid_forward(self,F,_input):

        print (self.infer_shape(_input))

        return _input

xx = nd.random_uniform(shape=[5,5,16,16])

mynet = runtime_shape()
mynet.hybrid_forward(nd,xx)

Error Message:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-41-3f539a940958> in <module>()
----> 1 mynet.hybrid_forward(nd,xx)

<ipython-input-38-afc9785b716d> in hybrid_forward(self, F, _input)
     17     def hybrid_forward(self,F,_input):
     18 
---> 19         print (self.infer_shape(_input))
     20 
     21         return _input

/home/dia021/anaconda2/lib/python2.7/site-packages/mxnet/gluon/block.pyc in infer_shape(self, *args)
    460     def infer_shape(self, *args):
    461         """Infers shape of Parameters from inputs."""
--> 462         self._infer_attrs('infer_shape', 'shape', *args)
    463 
    464     def infer_type(self, *args):

/home/dia021/anaconda2/lib/python2.7/site-packages/mxnet/gluon/block.pyc in _infer_attrs(self, infer_fn, attr, *args)
    448     def _infer_attrs(self, infer_fn, attr, *args):
    449         """Generic infer attributes."""
--> 450         inputs, out = self._get_graph(*args)
    451         args, _ = _flatten(args)
    452         arg_attrs, _, aux_attrs = getattr(out, infer_fn)(

/home/dia021/anaconda2/lib/python2.7/site-packages/mxnet/gluon/block.pyc in _get_graph(self, *args)
    369             params = {i: j.var() for i, j in self._reg_params.items()}
    370             with self.name_scope():
--> 371                 out = self.hybrid_forward(symbol, *grouped_inputs, **params)  # pylint: disable=no-value-for-parameter
    372             out, self._out_format = _flatten(out)
    373 

/home/dia021/anaconda2/lib/python2.7/site-packages/mxnet/gluon/block.pyc in __exit__(self, ptype, value, trace)
     78         if self._block._empty_prefix:
     79             return
---> 80         self._name_scope.__exit__(ptype, value, trace)
     81         self._name_scope = None
     82         _BlockScope._current = self._old_scope

AttributeError: 'NoneType' object has no attribute '__exit__'

Labels: Pending Requester Info

All 18 comments

Are you trying to get the input shape at runtime?
If so, infer_shape is not the right API.

You may want to try something like:

import mxnet as mx
import mxnet.ndarray as nd
from mxnet.gluon import HybridBlock

class runtime_shape(HybridBlock):

    def __init__(self,  **kwards):
        HybridBlock.__init__(self,**kwards)

    def hybrid_forward(self,F,_input):
        print('input shape: {}'.format(_input.shape))
        return _input

xx = nd.random_uniform(shape=[5,5,16,16])
mynet = runtime_shape()
mynet.hybrid_forward(nd,xx)

Which returns:
input shape: (5L, 5L, 16L, 16L)

@nswamy please apply labels:
Triaged
Pending Requester Info
Request For Information

Hi @lupesko, thank you very much for your reply. Perhaps this is something I am not understanding well (how to properly use the hybridize call in a unified way with both Symbol and NDArray). I've just started learning Gluon + MXNet.

Your solution works (I actually use it in my code) when the _input variable is an NDArray. But I want a unified way of finding the shape at runtime whether I feed in an NDArray or a Symbol. As I understand it, the whole point of HybridBlock is to be usable with both Symbol and NDArray (and it is the Symbol case I have trouble with).

This modification of your example:

xx = nd.random_uniform(shape=[5,5,16,16])
mynet = runtime_shape()

mynet.hybridize() # This is the modification line

mynet.hybrid_forward(nd,xx)

Results in the following error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-13-dd8cecdcffd1> in <module>()
----> 1 temp = mynet(xx)

/home/dia021/anaconda2/lib/python2.7/site-packages/mxnet/gluon/block.pyc in __call__(self, *args)
    302     def __call__(self, *args):
    303         """Calls forward. Only accepts positional arguments."""
--> 304         return self.forward(*args)
    305 
    306     def forward(self, *args):

/home/dia021/anaconda2/lib/python2.7/site-packages/mxnet/gluon/block.pyc in forward(self, x, *args)
    504                 try:
    505                     if self._active:
--> 506                         return self._call_cached_op(x, *args)
    507                     params = {i: j.data(ctx) for i, j in self._reg_params.items()}
    508                 except DeferredInitializationError:

/home/dia021/anaconda2/lib/python2.7/site-packages/mxnet/gluon/block.pyc in _call_cached_op(self, *args)
    412     def _call_cached_op(self, *args):
    413         if self._cached_op is None:
--> 414             self._build_cache(*args)
    415 
    416         args, fmt = _flatten(args)

/home/dia021/anaconda2/lib/python2.7/site-packages/mxnet/gluon/block.pyc in _build_cache(self, *args)
    377 
    378     def _build_cache(self, *args):
--> 379         inputs, out = self._get_graph(*args)
    380         input_idx = {var.name: i for i, var in enumerate(inputs)}
    381         self._cached_op = ndarray.CachedOp(out)

/home/dia021/anaconda2/lib/python2.7/site-packages/mxnet/gluon/block.pyc in _get_graph(self, *args)
    369             params = {i: j.var() for i, j in self._reg_params.items()}
    370             with self.name_scope():
--> 371                 out = self.hybrid_forward(symbol, *grouped_inputs, **params)  # pylint: disable=no-value-for-parameter
    372             out, self._out_format = _flatten(out)
    373 

<ipython-input-10-451e3d4c189d> in hybrid_forward(self, F, _input)
      9 
     10     def hybrid_forward(self,F,_input):
---> 11         print('input shape: {}'.format(_input.shape))
     12         return _input
     13 

AttributeError: 'Symbol' object has no attribute 'shape'

Another example of the same error can be seen in this Gluon example (from here)

import mxnet as mx
import mxnet.ndarray as nd

from mxnet import gluon

class Net(gluon.HybridBlock):
    def __init__(self, **kwargs):
        super(Net, self).__init__(**kwargs)
        with self.name_scope():
            self.fc1 = gluon.nn.Dense(256)
            self.fc2 = gluon.nn.Dense(128)
            self.fc3 = gluon.nn.Dense(2)

    def hybrid_forward(self, F, x):
        # F is a function space that depends on the type of x
        # If x's type is NDArray, then F will be mxnet.nd
        # If x's type is Symbol, then F will be mxnet.sym
        print('type(x): {}, F: {}'.format(
                type(x).__name__, F.__name__))

        # ----------------------------------------------------------------------------------------------
        # This is the line that I added, to infer the shape, based on your proposal
        print('input shape: {}'.format(x.shape)) 
        # ----------------------------------------------------------------------------------------------
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


# First run with NDArray works:
net = Net()
net.collect_params().initialize()
x = nd.random_normal(shape=(1, 512))
print('=== 1st forward ===')
y = net(x)
print('=== 2nd forward ===')
y = net(x)

Output:

=== 1st forward ===
type(x): NDArray, F: mxnet.ndarray
input shape: (1L, 512L)
=== 2nd forward ===
type(x): NDArray, F: mxnet.ndarray
input shape: (1L, 512L)

But if you call hybridize, you get an error:

net.hybridize()
print('=== 1st forward ===')
y = net(x)
print('=== 2nd forward ===')
y = net(x)

Error:

=== 1st forward ===
type(x): Symbol, F: mxnet.symbol
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-6-984f1343eb63> in <module>()
      1 net.hybridize()
      2 print('=== 1st forward ===')
----> 3 y = net(x)
      4 print('=== 2nd forward ===')
      5 y = net(x)

/home/dia021/anaconda2/lib/python2.7/site-packages/mxnet/gluon/block.pyc in __call__(self, *args)
    302     def __call__(self, *args):
    303         """Calls forward. Only accepts positional arguments."""
--> 304         return self.forward(*args)
    305 
    306     def forward(self, *args):

/home/dia021/anaconda2/lib/python2.7/site-packages/mxnet/gluon/block.pyc in forward(self, x, *args)
    504                 try:
    505                     if self._active:
--> 506                         return self._call_cached_op(x, *args)
    507                     params = {i: j.data(ctx) for i, j in self._reg_params.items()}
    508                 except DeferredInitializationError:

/home/dia021/anaconda2/lib/python2.7/site-packages/mxnet/gluon/block.pyc in _call_cached_op(self, *args)
    412     def _call_cached_op(self, *args):
    413         if self._cached_op is None:
--> 414             self._build_cache(*args)
    415 
    416         args, fmt = _flatten(args)

/home/dia021/anaconda2/lib/python2.7/site-packages/mxnet/gluon/block.pyc in _build_cache(self, *args)
    377 
    378     def _build_cache(self, *args):
--> 379         inputs, out = self._get_graph(*args)
    380         input_idx = {var.name: i for i, var in enumerate(inputs)}
    381         self._cached_op = ndarray.CachedOp(out)

/home/dia021/anaconda2/lib/python2.7/site-packages/mxnet/gluon/block.pyc in _get_graph(self, *args)
    369             params = {i: j.var() for i, j in self._reg_params.items()}
    370             with self.name_scope():
--> 371                 out = self.hybrid_forward(symbol, *grouped_inputs, **params)  # pylint: disable=no-value-for-parameter
    372             out, self._out_format = _flatten(out)
    373 

<ipython-input-3-53434f883f09> in hybrid_forward(self, F, x)
     20 
     21 
---> 22         print('input shape: {}'.format(x.shape))
     23 
     24         x = F.relu(self.fc1(x))

AttributeError: 'Symbol' object has no attribute 'shape'

To make what I need more concrete: I want to implement pyramid scene parsing pooling (from here) as a custom layer (inheriting from HybridBlock), in which I need to apply pooling at 4 scales (along with 2D convolutions). To my understanding, I need to know the input layer size so I can derive correct kernel/stride sizes for the F.Pooling operator. The point is, I want to test/debug with NDArray and then call hybridize for efficient computation.

Again thank you very much for your answer.

Any updates on this issue?

infer_shape might be a misnomer: what it actually does is infer the shapes of the block's Parameters from the inputs, whereas what you expect, like the infer_shape that Symbol provides, is inference of output shapes.
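For reference, output-shape inference is available on a Symbol itself; a minimal sketch (the Pooling call here is only an illustration):

import mxnet as mx

data = mx.sym.Variable('data')
out = mx.sym.Pooling(data, kernel=(2, 2), stride=(2, 2), pool_type='max')
# Symbol.infer_shape returns (arg_shapes, out_shapes, aux_shapes)
arg_shapes, out_shapes, aux_shapes = out.infer_shape(data=(5, 5, 16, 16))
print(out_shapes)   # [(5, 5, 8, 8)]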

In general, in a HybridBlock you shouldn't depend on the exact shape of the input, because that indicates the computation graph may change with the shape, which makes it unsuitable for hybridization. That said, there are plenty of ways to get around depending on the exact value of the shape (e.g. special placeholders in reshape, *_like operators such as zeros_like).
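For illustration, a minimal sketch of such shape-agnostic operators (standard mxnet.ndarray calls; the same names exist under mxnet.symbol):

import mxnet as mx
from mxnet import nd

x = nd.zeros((5, 5, 16, 16))

# 0 copies the corresponding input dimension, -1 is inferred from the rest,
# so no Python-side access to x.shape is needed
y = nd.reshape(x, shape=(0, 0, -1))   # shape (5, 5, 256)

# *_like operators take their shape from another array/symbol
z = nd.zeros_like(x)                  # shape (5, 5, 16, 16)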

If you could share exactly how you depend on the shape, others and I might be able to help more.

@szha thank you very much for your reply. I am trying to implement the pyramid scene parsing pooling layer, from here. For this implementation one needs to perform max pooling (or another pooling type) at a set of scales: a) global pooling, b) 1/2, c) 1/4, d) 1/8, then rescale and concatenate the results. Full details are in the arXiv paper. My implementation of this module (I am interested only in the pooling module, not the whole network architecture) is the following (it works with NDArray, but I cannot call hybridize):

from mxnet import gluon
from mxnet.gluon import  HybridBlock
from mxnet.ndarray import NDArray
from conv2Dnormed import *

class PSP_Pooling(HybridBlock):

    """
    Pyramid Scene Parsing pooling layer, as defined in Zhao et al. 2017 (https://arxiv.org/abs/1612.01105)        
    This is only the pyramid pooling module. 
    INPUT:
        layer of size Nbatch, Nchannel, H, W
    OUTPUT:
        layer of size Nbatch,  Nchannel, H, W. 

    """

    def __init__(self, _nfilters, _norm_type = 'BatchNorm', **kwards):
        HybridBlock.__init__(self,**kwards)

        self.nfilters = _nfilters

        with self.name_scope():

            self.conv1 = gluon.nn.Conv2D(self.nfilters//4, kernel_size=(1,1), padding=0, prefix="_conv1_")
            self.conv2 = gluon.nn.Conv2D(self.nfilters//4, kernel_size=(1,1), padding=0, prefix="_conv2_")
            self.conv3 = gluon.nn.Conv2D(self.nfilters//4, kernel_size=(1,1), padding=0, prefix="_conv3_")
            self.conv4 = gluon.nn.Conv2D(self.nfilters//4, kernel_size=(1,1), padding=0, prefix="_conv4_")

        # Hopefully this creates a list of functions by reference 
        self.convs = [ self.conv1, self.conv2, self.conv3, self.conv4 ]

        # This is a custom Conv2D followed by either BatchNormalization or InstanceNormalization layer
        self.conv_norm_final = Conv2DNormed(channels = self.nfilters,
                                            kernel_size=(1,1),
                                            padding=(0,0),
                                            _norm_type=_norm_type)



    def hybrid_forward(self,F,_input):

        # This works ONLY for NDArray input :( 

        # Also this if statement could be slowing down the performance. 
        if isinstance(_input,NDArray):
            layer_size = _input.shape[2]
        else :
             raise NotImplementedError
#             layer_size = _input.infer_shape()
#             print layer_size


        p = [_input]

        for i in range(4):
            # @@@@@@@@@@@@@ Important part @@@@@@@@@@@@@@@
            # This is where I need to know the layer size, so I can apply appropriate pooling
            pool_size = layer_size // (2**i)  # integer division so kernel/stride stay ints
            # @@@@@@@@@@@@@ @@@@@@@@@@@@@@@@@@@@@
            x = F.Pooling(_input,kernel=[pool_size,pool_size],stride=[pool_size,pool_size],pool_type='max')
            x = F.UpSampling(x,sample_type='nearest',scale=pool_size)
            x = self.convs[i](x)
            p += [x]

        out = F.concat(p[0],p[1],p[2],p[3],p[4],dim=1)

        out = self.conv_norm_final(out)
        out = F.relu(out)

        return out

Thank you very much for your time.

Dear all, do we have any progress on this?
Thank you very much for your time.

Any progress? I also have an operation that reshapes a Symbol/NDArray depending on its shape. How can I get the shape of the output Tensor?

Update: I used the special params (0, -1, -2, -3, -4) of Reshape to get rid of using the Tensor's specific shape. Thanks!
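For reference, a minimal sketch of what those special values do (see the Reshape operator documentation):

import mxnet as mx
from mxnet import nd

a = nd.zeros((2, 3, 4, 5))

# 0: copy that input dimension, -1: infer from the remaining elements
print(nd.reshape(a, shape=(0, 0, -1)).shape)            # (2, 3, 20)

# -3: merge two consecutive dimensions, -2: copy all remaining dimensions
print(nd.reshape(a, shape=(-3, -2)).shape)              # (6, 4, 5)

# -4: split one input dimension into the two that follow it (may contain -1)
print(nd.reshape(a, shape=(0, 0, -4, 2, -1, 0)).shape)  # (2, 3, 2, 2, 5)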

Me too! I need a way to resize the output of a HybridBlock to the same size as the input for it. This is for an image segmentation task. I'm using the new-fangled BilinearResize2D layer, but this only accepts a width and height parameter. If only it could alternatively accept a Symbol as an argument and resize based on this...

Hi to all,

@safrooze gave the solution in this topic on the discussion forum. The trick is to override the forward function and record the layer shape there: forward always receives an NDArray, so its shape is available before HybridBlock.forward builds the cached graph. Example:

import mxnet as mx
from mxnet import gluon, nd

class GetShape(gluon.HybridBlock):
    def __init__(self,nchannels=0, kernel_size=(3,3), **kwards):
        gluon.HybridBlock.__init__(self,**kwards)

        self.layer_shape = None

        with self.name_scope():
            self.conv = gluon.nn.Conv2D(nchannels,kernel_size=kernel_size)



    def forward(self,x):
        self.layer_shape = x.shape

        return gluon.HybridBlock.forward(self,x)

    def hybrid_forward(self,F,x):
        print (self.layer_shape)
        out = self.conv(x)
        return out

ctx = mx.cpu()  # or mx.gpu()
mynet = GetShape(nchannels=12)
mynet.hybridize()

mynet.initialize(mx.init.Xavier(), ctx=ctx)
xx = nd.random.uniform(shape=[32, 8, 128, 128])
out = mynet(xx)
# prints (32, 8, 128, 128)

Thank you to the community, I am closing this.

Nope, it doesn't work if we feed another HybridBlock as input. See forum discussion.

@szha

If you could share exactly how you depend on the shape, others and I might be able to help more.

I depend on the shape as follows:

I have y with shape (batch, c, h*w) and x with shape (batch, c, h, w).

h and w are not determined. Before I call net.hybridize(), I can use y.reshape(0, 0, *x.shape[2:]); however, after hybridizing, the Symbol doesn't have the shape attribute, so how can I achieve the same purpose?

How could I do this? Thank you very much.

Hi @ShoufaChen, your problem is easy, as there exist "special" dimension values for exactly this; see the NDArray reshape operation. The same exists for Symbol:

class CustomNet(HybridBlock):
    def __init__(self,**kwards):
        HybridBlock.__init__(self,**kwards)

        with self.name_scope(): 
              # do something here 

    def hybrid_forward(self, F, input):
        # reshape the input using the product of its last two dimensions:
        # 0 keeps a dimension, -3 merges the next two dimensions
        b = F.reshape(input, shape=[0, 0, -3])
        # now b has dimensions input.shape[0], input.shape[1], input.shape[2] * input.shape[3]
        return b

you can use this trick inside your custom layer definition. Hope this helps.

edit: an ipython example

In [1]: import mxnet as mx 

In [2]: from mxnet import nd

In [3]: a = nd.random.uniform(shape=[5,3,12,12])

In [4]: b= nd.reshape(a,shape=(0,0,-3))

In [5]: b.shape
Out[5]: (5, 3, 144)

@feevos
thank you very much for your reply.
Maybe I didn't state my problem clearly. In fact, I want to 'split' a dimension, rather than merge the last two dimensions.

Is there any solution?

@ShoufaChen I am sorry, I didn't understand your question well in the beginning. Unfortunately I do not know a solution to the reverse problem you are after.

@feevos Thank you all the same.

@ShoufaChen At the moment you only have one good option: define a self.shape in your block and add a set_shape() function that you call from the top level of the training loop, using the data's shape, before you pass the data into the block.
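A minimal sketch of that pattern, for illustration (ShapeAwareBlock and set_shape are made-up names, not an existing Gluon API):

from mxnet import gluon

class ShapeAwareBlock(gluon.HybridBlock):
    def __init__(self, **kwargs):
        super(ShapeAwareBlock, self).__init__(**kwargs)
        self.shape = None   # filled in from the training loop

    def set_shape(self, shape):
        # call this from the top level with x.shape before the forward pass
        self.shape = shape

    def hybrid_forward(self, F, y):
        # self.shape is a plain Python tuple, so it can be used even when F is mxnet.symbol
        h, w = self.shape[2], self.shape[3]
        return F.reshape(y, shape=(0, 0, h, w))

# in the training loop:
# net.set_shape(x.shape)
# out = net(y)
# note: after hybridize(), the cached graph is built with the h, w seen at that time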

@feevos Did the suggestions on this thread help resolve your issue?

If yes, please feel free to close the issue :)

Thanks!
