Keras: trainable doesn't work correctly

Created on 20 Jun 2018 · 12 comments · Source: keras-team/keras

from keras.layers import Input, Dense
from keras.models import Model

x = Input(shape=(32,))
layer1 = Dense(32)
layer2 = Dense(32)
y = layer2(layer1(x))
model_1 = Model(x, y)
layer1.trainable = True
layer2.trainable = False
model_1.compile(optimizer='rmsprop', loss='mse')
model_2 = Model(x, y)
layer1.trainable = False
layer2.trainable = True
model_2.compile(optimizer='rmsprop', loss='mse')
model_3 = Model(x,y)
layer1.trainable = True
layer2.trainable = True
model_3.compile(optimizer='rmsprop', loss='mse')

model_1.summary()
model_2.summary()
model_3.summary()

After compiling model_3, the summaries of model_1 and model_2 report the wrong total number of parameters:

model 1 summary:
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_15 (InputLayer)        (None, 32)                0         
_________________________________________________________________
dense_26 (Dense)             (None, 32)                1056      
_________________________________________________________________
dense_27 (Dense)             (None, 32)                1056      
=================================================================
Total params: 1,056
Trainable params: 1,056
Non-trainable params: 0
_________________________________________________________________
model 2 summary:
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_15 (InputLayer)        (None, 32)                0         
_________________________________________________________________
dense_26 (Dense)             (None, 32)                1056      
_________________________________________________________________
dense_27 (Dense)             (None, 32)                1056      
=================================================================
Total params: 1,056
Trainable params: 1,056
Non-trainable params: 0
_________________________________________________________________
model 3 summary:
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/keras/engine/training.py:478: UserWarning: Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ?
  'Discrepancy between trainable weights and collected trainable'
=================================================================
input_15 (InputLayer)        (None, 32)                0         
_________________________________________________________________
dense_26 (Dense)             (None, 32)                1056      
_________________________________________________________________
dense_27 (Dense)             (None, 32)                1056      
=================================================================
Total params: 2,112
Trainable params: 2,112
Non-trainable params: 0
_________________________________________________________________

All 12 comments

@rnorm When you modify layer.trainable, you need to call model.compile again for the change to take effect; then it should work correctly.
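
For illustration, a minimal sketch of that workflow (a single toy Dense layer; the names are hypothetical, not from this issue):

from keras.layers import Input, Dense
from keras.models import Model

inp = Input(shape=(32,))
dense = Dense(32)
m = Model(inp, dense(inp))
m.compile(optimizer='rmsprop', loss='mse')   # dense is trainable at this point

dense.trainable = False                      # flipping the flag alone is not enough...
m.compile(optimizer='rmsprop', loss='mse')   # ...recompile so training actually skips dense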

@yanboliang thanks.
I want to build up 3 models with shared layers.
In model 1, layer 1 is trainable and layer 2 is not.
In model 2, layer 2 is trainable and layer 1 is not.
In model 3, both are trainable.
That is why after modifying the trainable property I only compile one model.

So if I understand your suggestion correctly, the same shared layer in different models must have the same trainable property, i.e., all true or all false?

Yep.

@yanboliang if so, could you recommend some solution for what I'm trying to do?

I don't understand your scenario very clearly. In model 1, layer 2 is non-trainable (never updated), while in model 2, layer 2 is trainable (will be updated), so how can they share variables/parameters?
Usually we freeze (set as non-trainable) layers in the context of fine-tuning a model, or when using fixed embeddings for a text input. In that case you first finish training a model, then change some of its layers to non-trainable, and train the updated model on a new dataset.
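
As a rough sketch of that usual fine-tuning workflow (layer names, data, and hyper-parameters below are placeholders, not taken from this issue):

from keras.layers import Input, Dense
from keras.models import Model

# Stage 1: train the whole model.
inp = Input(shape=(32,))
feature = Dense(64, activation='relu')
head = Dense(1)
model = Model(inp, head(feature(inp)))
model.compile(optimizer='rmsprop', loss='mse')
# model.fit(x_train, y_train, epochs=10)

# Stage 2: freeze the feature layer, recompile, then fine-tune on the new dataset.
feature.trainable = False
model.compile(optimizer='rmsprop', loss='mse')  # recompile so the freeze takes effect
# model.fit(x_new, y_new, epochs=10)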

@yanboliang Here is a hypothetical example:
I have one network with two layers, but for some reason I don't want to optimize the entire model directly. Instead, I want to train the two layers separately. Specifically, in a for loop: train layer 1 while fixing layer 2, and then train layer 2 while fixing layer 1. To do that, I build two models - model 1 and model 2 - sharing the same architecture and the same weights. Then, in a for loop, I want to train model 1 while freezing its second layer and then train model 2 while freezing its first layer.

For your case, I think you should call fit after compiling each model:

model_1 = Model(x, y)
layer1.trainable = True
layer2.trainable = False
model_1.compile(optimizer='rmsprop', loss='mse')
model_1.fit(...)

model_2 = Model(x, y)
layer1.trainable = False
layer2.trainable = True
model_2.compile(optimizer='rmsprop', loss='mse')
model_2.fit(...)

model_3 = Model(x,y)
layer1.trainable = True
layer2.trainable = True
model_3.compile(optimizer='rmsprop', loss='mse')
model_3.fit(...)

@yanboliang but I need to put that into a for loop and train batch by batch. As a result, each model would have to be recompiled in every iteration.

If you don't recompile the model, the modification to layer.trainable will not take effect.
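
Combining the two suggestions above, a sketch of the alternating scheme with a recompile inside the loop (this is slow and also resets the optimizer state on every compile; the random data is only a placeholder):

import numpy as np
from keras.layers import Input, Dense
from keras.models import Model

x = Input(shape=(32,))
layer1 = Dense(32)
layer2 = Dense(32)
model = Model(x, layer2(layer1(x)))

data = np.random.random((128, 32))     # placeholder batch
labels = np.random.random((128, 32))

for step in range(10):
    # Update layer1 while layer2 stays fixed.
    layer1.trainable, layer2.trainable = True, False
    model.compile(optimizer='rmsprop', loss='mse')  # flags are read at compile time
    model.train_on_batch(data, labels)

    # Update layer2 while layer1 stays fixed.
    layer1.trainable, layer2.trainable = False, True
    model.compile(optimizer='rmsprop', loss='mse')
    model.train_on_batch(data, labels)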

I think the summary produced is wrong. For the first two models, it should be

Total params: 2,112
Trainable params: 1,056
Non-trainable params: 1,056

Could you confirm that this is a bug?

I also have the same problem. Is there any way to do it? I've seen this solved by freezing a whole model and then combining it with another (as in GANs). But is it possible to do the same with only a couple of layers?
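
For reference, a rough sketch of the GAN-style pattern mentioned here (hypothetical shapes; the discriminator is compiled before it is frozen, so it can still be trained on its own):

from keras.layers import Input, Dense
from keras.models import Model

# Discriminator, compiled while still trainable.
d_in = Input(shape=(32,))
discriminator = Model(d_in, Dense(1, activation='sigmoid')(d_in))
discriminator.compile(optimizer='rmsprop', loss='binary_crossentropy')

# Generator.
g_in = Input(shape=(16,))
generator = Model(g_in, Dense(32)(g_in))

# Combined model: freeze the discriminator *before* compiling it, so only
# the generator's weights are updated when training the combined model.
discriminator.trainable = False
combined = Model(g_in, discriminator(generator(g_in)))
combined.compile(optimizer='rmsprop', loss='binary_crossentropy')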

I have a very similar problem to rnorm's: https://github.com/keras-team/keras/issues/8259
