Keras: Simultaneous training of shared layers with two data sets

Created on 9 Aug 2016 · 5 comments · Source: keras-team/keras

I'm working on a semi-supervised multi-task learning problem. I have a labelled data set (~2k instances) and an unlabelled data set (~17k instances). The data points in both sets have the same shape.

My multi-task architecture simultaneously reconstructs and classifies the labelled input, which works very well. I would now like to add the unlabelled data by omitting the classification branch and evaluating only the reconstruction cost for those instances. Training should then run on both data sets at the same time, so that supervised training can guide the unsupervised feature learning.

Unfortunately, my current naive implementation cannot handle the different numbers of instances in the two data sets. Executing the code below leads to the following error:
'Exception: All input arrays (x) should have the same number of samples.'

Simply repeating the labelled data leads to severe over-fitting on the classification task.

Here's the code:

labelled_input   = Input(shape=(240*66,), name='l_input') #~2k instances
unlabelled_input = Input(shape=(240*66,), name='u_input') #~17k instances

encoder = [Dense(1000, activation='relu'), Dense(500, activation='relu')]
decoder = [Dense(1000), Dense(240*66, name='x_hat')]
classifier = [Dense(65, activation='softmax', name='y_pred')]

"""
[...]
Defines encode(), decode() and classify() which simply 
apply the respective layers to the argument
"""

# Processing labelled data points
l_h     = encode(labelled_input)
l_x_hat = decode(l_h)
# Dummy activation to help distinguish the outputs
l_x_hat = Activation('linear', name='l_x_hat')(l_x_hat)
y_pred  = classify(l_h)

# Processing unlabelled data points
u_h     = encode(unlabelled_input)
u_x_hat = decode(u_h)
u_x_hat = Activation('linear', name='u_x_hat')(u_x_hat)

model = Model(input=[labelled_input, unlabelled_input], output=[l_x_hat, u_x_hat, y_pred])
model.compile(optimizer='nadam', 
              loss={'l_x_hat' : 'mse',  'u_x_hat' : 'mse', 'y_pred' : 'categorical_crossentropy'})

model.fit({'l_input': x_train, 'u_input': u_x_train},                      #inputs
          {'l_x_hat': x_train, 'u_x_hat': u_x_train, 'y_pred': y_train},   #outputs
          nb_epoch=nb_epoch, batch_size=batch_size)

All 5 comments

One solution is to have two separate models that share layers. Then, you can alternate updating the supervised and unsupervised models.

Here's a rough sketch of how that might work with your code:

model_unsupervised = Model(input=unlabelled_input, output=u_x_hat)
model_supervised = Model(input=labelled_input, output=[l_x_hat, y_pred])

Then, you can either use model.fit on each to train for a full epoch, or model.train_on_batch to do batch updates to each in an alternating fashion.
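
For example, a minimal sketch of the alternating batch-update loop, assuming both models are compiled as below; next_batch(), next_batch_xy() and nb_batches_per_epoch are hypothetical names, not part of your code:

model_unsupervised.compile(optimizer='nadam', loss='mse')
model_supervised.compile(optimizer='nadam',
                         loss={'l_x_hat': 'mse', 'y_pred': 'categorical_crossentropy'})

for epoch in range(nb_epoch):
    for _ in range(nb_batches_per_epoch):
        # Unsupervised step: reconstruct a mini-batch of unlabelled data
        u_batch = next_batch(u_x_train, batch_size)                     # hypothetical helper
        model_unsupervised.train_on_batch(u_batch, u_batch)

        # Supervised step: reconstruct and classify a mini-batch of labelled data
        x_batch, y_batch = next_batch_xy(x_train, y_train, batch_size)  # hypothetical helper
        model_supervised.train_on_batch(x_batch,
                                        {'l_x_hat': x_batch, 'y_pred': y_batch})

Because the labelled set is much smaller, its batches will be revisited far more often per pass over the unlabelled data, so you may want to run the supervised step less frequently or down-weight its losses to limit the over-fitting you mention.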

Also -- you might get a better response/discussion on the keras google group rather than filing a github issue for this type of thing.

What do input_motion and h mean? Why are encoder and decoder not symmetrical?

@Imorton-zd I've updated the file, sorry for the confusion

@jonathan-schwarz @fchollet: I am facing the exact same problem. Did you figure out a solution?

This issue has been automatically marked as stale because it has not had recent activity. It will be closed after 30 days if no further activity occurs, but feel free to re-open a closed issue if needed.

