Keras: Get the last sequence of return_sequences=True

Created on 14 Dec 2017 · 2 Comments · Source: keras-team/keras

I have a sequence-to-sequence model where the encoder LSTM has return_sequences=True. I want to get only the last element of the returned sequence (the output at the final timestep). How can I do this?

P.S. I know that with return_sequences=False I would get only the last output, but I can't use that here.
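To make "the last element of the sequence" concrete: with return_sequences=True an LSTM returns an array of shape (batch, timesteps, units), and the wanted value is the slice at the final timestep. A minimal NumPy sketch of that indexing, with an arbitrary made-up array standing in for real LSTM outputs:

```python
import numpy as np

# Stand-in for the output of LSTM(..., return_sequences=True):
# shape (batch=2, timesteps=3, units=4). The values are arbitrary.
seq_out = np.arange(2 * 3 * 4, dtype=float).reshape(2, 3, 4)

# Keep every sample and every unit, but only the final timestep.
last_step = seq_out[:, -1, :]

print(seq_out.shape)    # (2, 3, 4)
print(last_step.shape)  # (2, 4)
```

The timestep axis is dropped, so each sample is left with one vector of length `units`, which is the same shape an LSTM with return_sequences=False would return.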

All 2 comments

from keras.models import Model
from keras.layers import Input, LSTM, RepeatVector
import numpy as np

# One training sample with 3 timesteps and 1 feature: shape (1, 3, 1)
train_x = np.array([1.0, 1.0, 1.0]).reshape((1, 3, 1))
train_y = np.asarray([[1, 1, 1]])

# Encoder: with return_sequences=False the LSTM returns only its last output, shape (None, 1)
inputs1 = Input(shape=(3, 1))
seqs = LSTM(1, return_sequences=False)(inputs1)
# Decoder: repeat the encoding 3 times and map it to a 3-dim output, shape (None, 3)
x = RepeatVector(3)(seqs)
x = LSTM(3)(x)

model1 = Model(inputs=inputs1, outputs=x)
model1.compile(optimizer='adam', loss='mean_squared_error')

model2 = Model(inputs=inputs1, outputs=seqs)  # the output is the encoder part
model2.compile(optimizer='adam', loss='mean_squared_error')

model1.fit(train_x, train_y)

# No weight copy is needed: model2 reuses the same encoder LSTM layer object
# as model1, so it already holds the trained encoder weights.
print(model2.predict(train_x))

If I understand what you want, the sample code above is a sequence-to-sequence model that takes an input of 3 timesteps (1 feature each) and produces a 3-dimensional output.

The first model, model1, fits the data and updates the weights during training.
After training, we build a second model, model2, from the encoder part of model1 and get the encoder output with predict(). Because model2 reuses the same encoder layer, it already has the trained weights.
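A quick way to confirm that the second model already uses the trained encoder weights is to look the encoder layer up in both models. This is a minimal sketch of the same model written against tf.keras, assuming TensorFlow 2.x is installed; the layer name "encoder" is a hypothetical label added only so the layer can be retrieved by name:

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

train_x = np.array([1.0, 1.0, 1.0]).reshape((1, 3, 1))
train_y = np.asarray([[1.0, 1.0, 1.0]])

inputs1 = keras.Input(shape=(3, 1))
# Hypothetical name "encoder" so the layer can be looked up in either model.
seqs = layers.LSTM(1, name="encoder")(inputs1)
x = layers.RepeatVector(3)(seqs)
x = layers.LSTM(3)(x)

model1 = keras.Model(inputs=inputs1, outputs=x)
model1.compile(optimizer="adam", loss="mean_squared_error")
model2 = keras.Model(inputs=inputs1, outputs=seqs)  # encoder only

model1.fit(train_x, train_y, epochs=1, verbose=0)

# Both models hold the very same layer object, so training model1 also
# updates the weights that model2 uses when predicting.
shared = model1.get_layer("encoder") is model2.get_layer("encoder")
encoded = model2.predict(train_x, verbose=0)

print(shared)         # True
print(encoded.shape)  # (1, 1)
```

This works because calling a layer object on tensors in the functional API reuses that layer's weights wherever the layer appears.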

You can use the functional API and a Lambda layer like this, where lstm_layer is the output of the LSTM with return_sequences=True:
last_timestep = Lambda(lambda x: x[:, -1, :])(lstm_layer)
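The one-liner above can be expanded into a self-contained sketch, assuming TensorFlow 2.x with tf.keras; the layer sizes and the random input are arbitrary. It also asks the LSTM for return_state=True, used here purely as a cross-check: for an unmasked LSTM the final hidden state state_h should equal the sliced last timestep.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

timesteps, features, units = 3, 1, 4

inputs = keras.Input(shape=(timesteps, features))
# Full sequence, shape (batch, timesteps, units), plus the final states h and c.
lstm_layer, state_h, state_c = layers.LSTM(
    units, return_sequences=True, return_state=True
)(inputs)
# Keep only the output at the last timestep, shape (batch, units).
last_timestep = layers.Lambda(lambda x: x[:, -1, :])(lstm_layer)

model = keras.Model(inputs=inputs, outputs=[lstm_layer, last_timestep, state_h])

x = np.random.rand(2, timesteps, features).astype("float32")
full_seq, last, h = model.predict(x, verbose=0)

print(full_seq.shape)  # (2, 3, 4)
print(last.shape)      # (2, 4)
```

Because the full sequence and the sliced last step come from one LSTM call, the encoder can keep return_sequences=True (for example, to feed an attention layer) while the last output is still available to the rest of the model.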
