So I have a custom layer SpatialTransformer:
import tensorflow as tf
from keras.layers import Layer

class SpatialTransformer(Layer):
    def __init__(self, localization_net, output_size, **kwargs):
        self.locnet = localization_net
        self.output_size = output_size
        super(SpatialTransformer, self).__init__(**kwargs)

    def get_config(self):
        config = {'localization_net': self.locnet, 'output_size': self.output_size}
        base_config = super(SpatialTransformer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def build(self, input_shape):
        self.locnet.build(input_shape)
        self.trainable_weights = self.locnet.trainable_weights
        # self.regularizers = self.locnet.regularizers  # NOT SURE ABOUT THIS, THERE IS NO SUCH ATTRIBUTE ON self.locnet ANY MORE
        self.constraints = self.locnet.constraints
    def compute_output_shape(self, input_shape):
        output_size = self.output_size
        return (None,
                int(output_size[0]),
                int(output_size[1]),
                int(input_shape[-1]))

    def call(self, X, mask=None):
        affine_transformation = self.locnet.call(X)
        output = self._transform(affine_transformation, X, self.output_size)
        return output

    def _repeat(self, x, num_repeats):
        ones = tf.ones((1, num_repeats), dtype='int32')
        x = tf.reshape(x, shape=(-1, 1))
        x = tf.matmul(x, ones)
        return tf.reshape(x, [-1])
    def _interpolate(self, image, x, y, output_size):
        # Bilinear sampling of `image` at the normalized coordinates x, y.
        batch_size = tf.shape(image)[0]
        height = tf.shape(image)[1]
        width = tf.shape(image)[2]
        num_channels = tf.shape(image)[3]

        x = tf.cast(x, dtype='float32')
        y = tf.cast(y, dtype='float32')
        height_float = tf.cast(height, dtype='float32')
        width_float = tf.cast(width, dtype='float32')

        output_height = output_size[0]
        output_width = output_size[1]

        # Map coordinates from [-1, 1] to pixel space.
        x = .5 * (x + 1.0) * width_float
        y = .5 * (y + 1.0) * height_float

        x0 = tf.cast(tf.floor(x), 'int32')
        x1 = x0 + 1
        y0 = tf.cast(tf.floor(y), 'int32')
        y1 = y0 + 1

        max_y = tf.cast(height - 1, dtype='int32')
        max_x = tf.cast(width - 1, dtype='int32')
        zero = tf.zeros([], dtype='int32')
        x0 = tf.clip_by_value(x0, zero, max_x)
        x1 = tf.clip_by_value(x1, zero, max_x)
        y0 = tf.clip_by_value(y0, zero, max_y)
        y1 = tf.clip_by_value(y1, zero, max_y)

        flat_image_dimensions = width * height
        pixels_batch = tf.range(batch_size) * flat_image_dimensions
        flat_output_dimensions = output_height * output_width
        base = self._repeat(pixels_batch, flat_output_dimensions)
        base_y0 = base + y0 * width
        base_y1 = base + y1 * width
        indices_a = base_y0 + x0
        indices_b = base_y1 + x0
        indices_c = base_y0 + x1
        indices_d = base_y1 + x1

        flat_image = tf.reshape(image, shape=(-1, num_channels))
        flat_image = tf.cast(flat_image, dtype='float32')
        pixel_values_a = tf.gather(flat_image, indices_a)
        pixel_values_b = tf.gather(flat_image, indices_b)
        pixel_values_c = tf.gather(flat_image, indices_c)
        pixel_values_d = tf.gather(flat_image, indices_d)

        x0 = tf.cast(x0, 'float32')
        x1 = tf.cast(x1, 'float32')
        y0 = tf.cast(y0, 'float32')
        y1 = tf.cast(y1, 'float32')

        area_a = tf.expand_dims(((x1 - x) * (y1 - y)), 1)
        area_b = tf.expand_dims(((x1 - x) * (y - y0)), 1)
        area_c = tf.expand_dims(((x - x0) * (y1 - y)), 1)
        area_d = tf.expand_dims(((x - x0) * (y - y0)), 1)

        output = tf.add_n([area_a * pixel_values_a,
                           area_b * pixel_values_b,
                           area_c * pixel_values_c,
                           area_d * pixel_values_d])
        return output
    def _meshgrid(self, height, width):
        x_linspace = tf.linspace(-1., 1., width)
        y_linspace = tf.linspace(-1., 1., height)
        x_coordinates, y_coordinates = tf.meshgrid(x_linspace, y_linspace)
        x_coordinates = tf.reshape(x_coordinates, [-1])
        y_coordinates = tf.reshape(y_coordinates, [-1])
        ones = tf.ones_like(x_coordinates)
        indices_grid = tf.concat([x_coordinates, y_coordinates, ones], 0)
        return indices_grid
    def _transform(self, affine_transformation, input_shape, output_size):
        # Apply the predicted 2x3 affine transform to a sampling grid, then
        # bilinearly sample the input at the transformed grid locations.
        batch_size = tf.shape(input_shape)[0]
        height = tf.shape(input_shape)[1]
        width = tf.shape(input_shape)[2]
        num_channels = tf.shape(input_shape)[3]

        affine_transformation = tf.reshape(affine_transformation, shape=(batch_size, 2, 3))
        affine_transformation = tf.reshape(affine_transformation, (-1, 2, 3))
        affine_transformation = tf.cast(affine_transformation, 'float32')

        width = tf.cast(width, dtype='float32')
        height = tf.cast(height, dtype='float32')
        output_height = output_size[0]
        output_width = output_size[1]

        indices_grid = self._meshgrid(output_height, output_width)
        indices_grid = tf.expand_dims(indices_grid, 0)
        indices_grid = tf.reshape(indices_grid, [-1])  # flatten?
        indices_grid = tf.tile(indices_grid, tf.stack([batch_size]))
        indices_grid = tf.reshape(indices_grid, (batch_size, 3, -1))

        transformed_grid = tf.matmul(affine_transformation, indices_grid)
        x_s = tf.slice(transformed_grid, [0, 0, 0], [-1, 1, -1])
        y_s = tf.slice(transformed_grid, [0, 1, 0], [-1, 1, -1])
        x_s_flatten = tf.reshape(x_s, [-1])
        y_s_flatten = tf.reshape(y_s, [-1])

        transformed_image = self._interpolate(input_shape,
                                              x_s_flatten,
                                              y_s_flatten,
                                              output_size)
        transformed_image = tf.reshape(transformed_image, shape=(batch_size,
                                                                 output_height,
                                                                 output_width,
                                                                 num_channels))
        return transformed_image
This is the first layer of the model 'model', and it takes its input from another model, 'locnet'. I have saved both models after training them.
When I try to load the model that contains this custom layer, I always get an error about the __init__() arguments of SpatialTransformer.
To load the model, I am using:
locnet=load_model('locnet1.h5')
model=load_model('model.h5',custom_objects={"SpatialTransformer": SpatialTransformer})
This gives the following error:
Traceback (most recent call last):
  File "stn_test.py", line 119, in <module>
    model=load_model('model.h5',custom_objects={"SpatialTransformer": SpatialTransformer})
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/keras/models.py", line 240, in load_model
    model = model_from_config(model_config, custom_objects=custom_objects)
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/keras/models.py", line 304, in model_from_config
    return layer_module.deserialize(config, custom_objects=custom_objects)
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/keras/layers/__init__.py", line 54, in deserialize
    printable_module_name='layer')
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/keras/utils/generic_utils.py", line 140, in deserialize_keras_object
    list(custom_objects.items())))
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/keras/models.py", line 1202, in from_config
    layer = layer_module.deserialize(conf, custom_objects=custom_objects)
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/keras/layers/__init__.py", line 54, in deserialize
    printable_module_name='layer')
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/keras/utils/generic_utils.py", line 141, in deserialize_keras_object
    return cls.from_config(config['config'])
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/keras/engine/topology.py", line 1231, in from_config
    return cls(**config)
TypeError: __init__() missing 2 required positional arguments: 'localization_net' and 'output_size'
Please help!
I've had a similar issue: https://github.com/fchollet/keras/issues/6900
If you have the network architecture defined in code, try loading only the weights into the predefined model: model.load_weights('locnet1.h5'). For me, this works.
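To make that concrete, here is a minimal sketch, assuming build_locnet() and build_model() are your own (hypothetical) helper functions that recreate the exact architectures used for training; they are not Keras APIs:

# Rebuild the architectures in code, then load only the weights from the
# saved files; no layer config has to be deserialized this way.
locnet = build_locnet()                  # hypothetical helper that redefines locnet
locnet.load_weights('locnet1.h5')

model = build_model(locnet)              # internally creates SpatialTransformer(locnet, output_size)
model.load_weights('model.h5')           # load_weights should also read a file saved with model.save()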
I also had this problem. In my case, the argument I was passing to the __init__ function of the custom layer was a pandas DataFrame. There is probably an issue with how it gets serialized/deserialized, because I got the same error message about missing positional arguments when trying to load the model. When I changed the argument from a pandas DataFrame to a standard Python list, the error went away.
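For reference, a tiny sketch of that kind of change, using a hypothetical layer and argument name (lookup_values) purely for illustration:

# Hypothetical example: store a plain Python list in the config instead of a
# pandas DataFrame, since only JSON-serializable types survive save/load.
class MyLayer(Layer):
    def __init__(self, lookup_values, **kwargs):
        # lookup_values is a plain list (e.g. df['col'].tolist()), not a DataFrame
        self.lookup_values = list(lookup_values)
        super(MyLayer, self).__init__(**kwargs)

    def get_config(self):
        config = {'lookup_values': self.lookup_values}
        base_config = super(MyLayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))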
In my case this was useful: https://github.com/keras-team/keras/issues/5401
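For the SpatialTransformer above, the approach discussed there boils down to making every __init__ argument recoverable from the config. One possible sketch, modeled on how Keras' own wrapper layers nest a sub-layer's config; it is not verified against this exact layer, so in particular check afterwards that the locnet weights were restored correctly:

    def get_config(self):
        # Nest the locnet's own config so it can be rebuilt at load time.
        config = {'localization_net': {'class_name': self.locnet.__class__.__name__,
                                       'config': self.locnet.get_config()},
                  'output_size': self.output_size}
        base_config = super(SpatialTransformer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config, custom_objects=None):
        from keras.layers import deserialize as deserialize_layer
        localization_net = deserialize_layer(config.pop('localization_net'),
                                             custom_objects=custom_objects)
        return cls(localization_net=localization_net, **config)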
Useful for me too.
I solved it with load_weights, thanks!