Keras: After Flatten(), both batch and output shape of a tensor become None

Created on 13 Dec 2017 · 16 comments · Source: keras-team/keras

If I use Flatten() with the functional API, both the batch and output shape become None, whereas in a Sequential model it prints the correct output shape. I need both the batch and output size later on, and I have to use the functional API because of my model's complexity. Is this an issue in Keras?
@fchollet @farizrahman4u @Dref360

from keras.layers import Input, InputLayer, Conv2D, Flatten
from keras.models import Sequential

# functional API
input = Input(batch_shape=[64, 224, 224, 3])

test = Conv2D(96, kernel_size=(11, 11), strides=(4, 4), activation='relu', name='conv1')(input)
print 'b4 flatten shape ', test.get_shape()

test = Flatten()(test)

print 'after flatten shape ', test.get_shape()

# Sequential model
alexnet = Sequential()

alexnet.add(InputLayer(batch_input_shape=(64, 224, 224, 3)))

alexnet.add(Conv2D(96, kernel_size=(11, 11), strides=(4, 4),
                   activation='relu', name='conv1'))

alexnet.add(Flatten())

print 'Sequential model shape ', alexnet.output_shape

The output I got:

b4 flatten shape (64, 30, 30, 96)
after flatten shape (?, ?)
Sequential model shape (64, 86400)

Most helpful comment

My current workaround is to query the shape before Flatten() is applied, like this:

# imports
import numpy as np
import keras
from keras.layers import *

# toy net
x = Input(shape = (12,100,10))
x = Dense(32)(x)
f = Flatten()(x)

# shape inference
shape_before_flatten = x.shape.as_list()[1:] # [1:] to skip None
shape_flatten = np.prod(shape_before_flatten) # value of shape in the non-batch dimension

# print
print("x = "+ str(x))
print("shape before Flatten() = " + str(shape_before_flatten))
print("shape after Flatten() = " + str(f.shape.as_list()))
print("shape_flatten in the non-batch dimension: " + str(shape_flatten))


x = Tensor("dense_17/BiasAdd:0", shape=(?, 12, 100, 32), dtype=float32)
shape before Flatten() = [12, 100, 32]
shape after Flatten() = [None, None]
shape_flatten in the non-batch dimension: 38400

All 16 comments

My current workaround is to use Reshape((-1,)) instead, which doesn't have this problem. Here's a minimal example for reproducing this bug:

from keras.layers import Input, Flatten, Reshape

x = Input(batch_shape=(16, 10, 10))
print(x)

x = Input(batch_shape=(16, 10, 10))
x = Flatten()(x)
print(x)

x = Input(batch_shape=(16, 10, 10))
x = Reshape((-1,))(x)
print(x)

Prints:

Tensor("input_1:0", shape=(16, 10, 10), dtype=float32)
Tensor("flatten_1/Reshape:0", shape=(?, ?), dtype=float32) # from Flatten
Tensor("reshape_1/Reshape:0", shape=(16, 100), dtype=float32) # from Reshape

Hi, I am using Keras 2.1.5 and am getting similar issues with Flatten() and Reshape not behaving as expected. The output of the same commands as above is:

Tensor("input_5:0", shape=(16, 10, 10), dtype=float32)
Tensor("flatten_15/Reshape:0", shape=(?, ?), dtype=float32)
Tensor("reshape_27/Reshape:0", shape=(?, ?), dtype=float32)

The workaround suggested in the previous comment doesn't work. My current workaround is:

x = Input(batch_shape=(16, 10, 10))
x = keras.layers.Reshape((100,))(x)
print(x)

which outputs

Tensor("reshape_29/Reshape:0", shape=(16, 100), dtype=float32)

Hi, I am using Keras 2.1.5 and am getting similar issues with Flatten() and Reshape not behaving as expected. The output of the same commands as above is:

Tensor("input_5:0", shape=(16, 10, 10), dtype=float32)
Tensor("flatten_15/Reshape:0", shape=(?, ?), dtype=float32)
Tensor("reshape_27/Reshape:0", shape=(?, ?), dtype=float32)

The workaround suggested in the previous comment doesn't work. My current workaround is:

x = Input(batch_shape=(16, 10, 10))
x = keras.layers.Reshape((100,))(x)
print(x)

which outputs

Tensor("reshape_29/Reshape:0", shape=(16, 100), dtype=float32)

Same issue here. Now I need to explicitly note down the shape and use it in Reshape.

+1
default_encoding: ANSI_X3.4-1968
ipython_version: 6.4.0
os_name: posix
platform: Linux-3.10.0-693.21.1.el7.x86_64-x86_64-with-Ubuntu-16.04-xenial
sys_platform: linux
sys_version: 3.5.2 [GCC 5.4.0 20160609]
keras version: 2.1.6
python version: 3.5
tensorflow backend version: 1.8.0

My current workaround is to query the shape before Flatten() is applied, like this:

# imports
import numpy as np
import keras
from keras.layers import *

# toy net
x = Input(shape = (12,100,10))
x = Dense(32)(x)
f = Flatten()(x)

# shape inference
shape_before_flatten = x.shape.as_list()[1:] # [1:] to skip None
shape_flatten = np.prod(shape_before_flatten) # value of shape in the non-batch dimension

# print
print("x = "+ str(x))
print("shape before Flatten() = " + str(shape_before_flatten))
print("shape after Flatten() = " + str(f.shape.as_list()))
print("shape_flatten in the non-batch dimension: " + str(shape_flatten))


x = Tensor("dense_17/BiasAdd:0", shape=(?, 12, 100, 32), dtype=float32)
shape before Flatten() = [12, 100, 32]
shape after Flatten() = [None, None]
shape_flatten in the non-batch dimension: 38400
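As a possible follow-up to this workaround (not from the thread; just a sketch continuing the snippet above), the precomputed shape_flatten can be fed into a Reshape layer in place of Flatten, so downstream layers see a fully defined feature dimension:

# continue from the snippet above: replace Flatten() with an explicit Reshape
f = Reshape((int(shape_flatten),))(x)
print(f.shape.as_list())               # expected: [None, 38400]

# downstream layers now see a defined last dimension
y = Dense(10, activation='relu')(f)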

I'm also experiencing this issue. Any updates?

I am also experiencing the exact same issue.

Does it affect the model structure, or is it just a print issue?

Hi, I am using Keras 2.1.5 and am getting similar issues with Flatten() and Reshape not behaving as expected. The output of the same commands as above is:

Tensor("input_5:0", shape=(16, 10, 10), dtype=float32)
Tensor("flatten_15/Reshape:0", shape=(?, ?), dtype=float32)
Tensor("reshape_27/Reshape:0", shape=(?, ?), dtype=float32)

The workaround suggested in the previous comment doesn't work. My current workaround is:

x = Input(batch_shape=(16, 10, 10))
x = keras.layers.Reshape((100,))(x)
print(x)

which outputs

Tensor("reshape_29/Reshape:0", shape=(16, 100), dtype=float32)

This helped me. Thanks a ton!!!

Same problem. Reshape((-1,)) doesn't work either.
Keras 2.2.4, TensorFlow 1.12

Any updates? I have the same problem: after Flatten(), the tensor's shape becomes (?, ?).

Any updates? I have the same problem: after Flatten(), the tensor's shape becomes (?, ?).

I have a workaround here:

x = Input(shape=(1, 10, 10))
shape = np.prod(x.shape[1:])
x = Reshape((shape,))(x)

This code will return shape=(?, 100).

Any updates? I have the same problem: after Flatten(), the tensor's shape becomes (?, ?).

I have a workaround here:

x = Input(shape=(1, 10, 10))
shape = np.prod(x.shape[1:])
x = Reshape((shape,))(x)

This code will return shape=(?, 100).

Not working with:

x = Input(shape=(1, 10, 10))
shape = np.prod(x.shape[1:])
x = Reshape((shape,))(x)
x = Dense(10, activation='relu')(x)

This works:

x = Input(shape=(32, 10, 10))
shape = np.prod(x.shape[1:])
x = Reshape((shape,))(x)
print(x.shape)

output:
(?, 3200)

When adding subsequent layers, we can add an additional Reshape layer:

x = Input(shape=(1, 10, 10))
shape = np.prod(x.shape[1:])
x = Reshape((shape,))(x)
x = Reshape((-1,))(x)
x = Dense(10, activation='relu')(x)

This works:

x = Input(shape=(32, 10, 10))
shape = np.prod(x.shape[1:])
x = Reshape((shape,))(x)
print(x.shape)

output:
(?, 3200)

When adding subsequent layers, we can add an additional Reshape layer:

x = Input(shape=(1, 10, 10))
shape = np.prod(x.shape[1:])
x = Reshape((shape,))(x)
x = Reshape((-1,))(x)
x = Dense(10, activation='relu')(x)

The above solution will not work! A quick fix for adding a following layer is:
shape = np.prod(x.shape[1:])
reshaper = keras.layers.Lambda(lambda x: tf.keras.backend.reshape(x, shape=(self.batch_size, shape)))
x = reshaper(x)
x = Dense(10, activation='relu')(x)
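For completeness, here is a self-contained sketch of that Lambda-based fix (not from the thread verbatim; it assumes the batch size is known up front, here 16 as in the earlier examples):

import numpy as np
import keras.backend as K
from keras.layers import Input, Dense, Lambda

batch_size = 16                                  # assumed fixed batch size
x = Input(batch_shape=(batch_size, 10, 10))
flat_size = int(np.prod(K.int_shape(x)[1:]))     # 10 * 10 = 100

# reshape with an explicit (batch, features) shape so the static shape survives
reshaper = Lambda(lambda t: K.reshape(t, (batch_size, flat_size)))
x = reshaper(x)
x = Dense(10, activation='relu')(x)
print(x.shape)                                   # expected: (16, 10)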

