Keras: Slow training due to 'on_batch_end()' and 'TimeDistributedDense()'

Created on 4 May 2016 · 5 comments · Source: keras-team/keras

Hi,
I just finished training my first RNN with Keras, but I have two small technical problems: I get several warning messages because apparently something is slowing down my training phase.
The first problem is 'TimeDistributedDense()'. Apparently it is deprecated and I am supposed to use 'TimeDistributed(Dense(...))' instead, as suggested by the warning message itself. When I look through the Keras documentation this method is cited in the 'Core' section, but the code line:

from keras.layers.core import TimeDistributed

gives me an error. How am I supposed to import TimeDistributed?

The second problem is 'on_batch_end()'. The original warning message is the following:

Warning (from warnings module):
File "/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/keras/callbacks.py", line 66
% delta_t_median)
UserWarning: Method on_batch_end() is slow compared to the batch update (0.122654). Check your callbacks

I found another post here about this problem, but people were saying it is due to heavy custom callbacks, and I didn't define any callbacks of my own. So what could be the problem?

All 5 comments

First problem: from keras.layers.wrappers import TimeDistributed.
Second problem: can you provide the script you are running?
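For reference, a minimal sketch of the replacement (the layer sizes here are placeholders, not taken from your model):

from keras.models import Sequential
from keras.layers import GRU, Dense
from keras.layers.wrappers import TimeDistributed

model = Sequential()
model.add(GRU(32, input_dim=16, return_sequences=True))
# old, deprecated: model.add(TimeDistributedDense(4, activation='softmax'))
model.add(TimeDistributed(Dense(4, activation='softmax')))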

Thank you for your answer. Sorry about the first question, but I had checked the wrappers section and there were only scikit-learn wrappers, so...
Anyway, the part of the script related to the RNN is the following:

CLASSIFICATION

# imports assumed for this snippet (NPZ and y are prepared earlier in the script)
from sklearn.model_selection import train_test_split  # sklearn.cross_validation on older scikit-learn
from keras.preprocessing import sequence
from keras.models import Sequential
from keras.layers import GRU, Dropout, Dense
from keras.layers.wrappers import TimeDistributed

# params
in_neurons = 225
hidden_neurons_1 = 21
hidden_neurons_2 = 21
out_neurons = 4

# split dataset into training set and test set
X_train, X_test, y_train, y_test = train_test_split(NPZ, y, test_size=0.3, random_state=0)

# pad sequences to a common length
X_train = sequence.pad_sequences(X_train, dtype='float32')
y_train = sequence.pad_sequences(y_train, dtype='int32')

model = Sequential()

model.add(GRU(hidden_neurons_1, input_dim=in_neurons, return_sequences=True))
model.add(Dropout(0.2))
model.add(GRU(hidden_neurons_2, return_sequences=True))
model.add(Dropout(0.2))
model.add(TimeDistributed(Dense(out_neurons, activation='softmax')))

model.compile(loss='mse', optimizer='rmsprop', metrics=['accuracy'])

model.summary()

model.fit(X_train, y_train, batch_size=1, nb_epoch=20)

# save model architecture and weights
json_string = model.to_json()
open('my_model_architecture.json', 'w').write(json_string)
model.save_weights('my_model_weights.h5')
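For completeness, reloading the saved architecture and weights later should look roughly like this (same file names as above):

from keras.models import model_from_json

model = model_from_json(open('my_model_architecture.json').read())
model.load_weights('my_model_weights.h5')
# recompile before training or evaluating the reloaded model
model.compile(loss='mse', optimizer='rmsprop', metrics=['accuracy'])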

There is no problem, actually, except that you are running with a very strange choice of parameters.

The callback that is slow in this case is the one logging to stdout. You're setting batch_size=1, which you should never do since it is insanely inefficient, especially for such a small network, and at the same time you are setting verbose=1 (the default). This triggers one flush to stdout for every batch, i.e. in this case one flush per sample processed. Because you have a tiny model, running it on a single sample is basically instant, and flushing to stdout takes a significant amount of time compared to that.

Thanks! I tried changing both parameters, and apparently only verbose=1 is the problem: even after increasing the batch_size I still get the same warning message. But it doesn't really make sense, because with verbose=0 I cannot check how my training is going.

You could try verbose=2, which gives you one line of output per epoch.
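For example (the batch size of 32 is just an illustrative value):

model.fit(X_train, y_train, batch_size=32, nb_epoch=20, verbose=2)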
