Keras: Using callbacks with the scikit-learn API?

Created on 16 Oct 2016 · 7 Comments · Source: keras-team/keras

Hey there,

I would like to use scikit-learn's GridSearchCV with Keras, but still keep using callbacks.

How can I do this?

Thank you,

Keiron

I'm joining in on this. It doesn't appear to work the way I'd expect. Passing simple EarlyStopping and ReduceLROnPlateau callbacks to the KerasClassifier constructor results in the following error when GridSearchCV clones the estimator:

RuntimeError                              Traceback (most recent call last)
<ipython-input-159-1a0430e264a4> in <module>()
     25 start = datetime.datetime.now()
     26 
---> 27 grid_result = grid.fit(xtrain, ytrain)
     28 
     29 end = datetime.datetime.now()

/home/nobody/anaconda3/lib/python3.5/site-packages/sklearn/model_selection/_search.py in fit(self, X, y, groups)
    938             train/test set.
    939         """
--> 940         return self._fit(X, y, groups, ParameterGrid(self.param_grid))
    941 
    942 

/home/nobody/anaconda3/lib/python3.5/site-packages/sklearn/model_selection/_search.py in _fit(self, X, y, groups, parameter_iterable)
    547                                      n_candidates * n_splits))
    548 
--> 549         base_estimator = clone(self.estimator)
    550         pre_dispatch = self.pre_dispatch
    551 

/home/nobody/anaconda3/lib/python3.5/site-packages/sklearn/base.py in clone(estimator, safe)
    124             raise RuntimeError('Cannot clone object %s, as the constructor '
    125                                'does not seem to set parameter %s' %
--> 126                                (estimator, name))
    127 
    128     return new_object

RuntimeError: Cannot clone object <keras.wrappers.scikit_learn.KerasClassifier object at 0x2b132c647400>, as the constructor does not seem to set parameter callbacks
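The traceback ends in sklearn.base.clone, which rebuilds an estimator from get_params(deep=False) and then checks that the new object reports back the very same parameter objects. A minimal sketch of that check, assuming a recent scikit-learn: CopyingWrapper, StoringWrapper and DummyCallback are hypothetical classes, not the real KerasClassifier or a real Keras callback. The first wrapper returns deep copies from get_params(), which is one way a wrapper can trip this error; whether the old Keras wrapper failed for exactly this reason is an assumption here.

```python
import copy

from sklearn.base import clone


class DummyCallback:
    """Hypothetical stand-in for a Keras callback object."""


class CopyingWrapper:
    """Hypothetical wrapper whose get_params() hands out deep copies."""

    def __init__(self, **params):
        self.params = params

    def get_params(self, deep=True):
        # Returning copies breaks clone(): the rebuilt object no longer
        # reports the *same* parameter objects it was constructed with.
        return copy.deepcopy(self.params)

    def set_params(self, **params):
        self.params.update(params)
        return self


class StoringWrapper(CopyingWrapper):
    """Same hypothetical wrapper, but get_params() returns the stored objects."""

    def get_params(self, deep=True):
        return dict(self.params)  # new dict, same parameter objects


def try_clone(estimator):
    """Return the clone, or the RuntimeError message if clone() refuses."""
    try:
        return clone(estimator)
    except RuntimeError as exc:
        return str(exc)


# Fails on the list-valued "callbacks" parameter (an int such as epochs=10
# survives deepcopy as the same object, so it passes the identity check).
print(try_clone(CopyingWrapper(epochs=10, callbacks=[DummyCallback()])))
print(type(try_clone(StoringWrapper(epochs=10, callbacks=[DummyCallback()]))).__name__)
# → StoringWrapper
```

The exact wording of the RuntimeError varies between scikit-learn versions. Even when cloning succeeds, clone() deep-copies non-estimator parameters such as callback objects, so every cloned estimator gets its own copy; passing callbacks to fit() instead avoids cloning them at all.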

I was able to solve this for my similar use case (GridSearchCV) by passing the callbacks through the fit_params kwarg of GridSearchCV, i.e.:
grid_search = GridSearchCV(est, params, fit_params={'callbacks': [MyCallback()]})

This solution does not work for learning_curve, which does not accept a fit_params argument.

This issue has been automatically marked as stale because it has not had recent activity. It will be closed after 30 days if no further activity occurs, but feel free to re-open a closed issue if needed.

This is a temporary solution that will stop working soon (see here): passing fit_params as a GridSearchCV constructor argument was deprecated in scikit-learn 0.19 and will be removed in 0.21.

This part of the link mentioned by @baharian was especially helpful: "Pass fit parameters to the fit method instead."

grid_search = grid_search.fit(X_train, y_train, callbacks=[MyCallback()]) worked for me.
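This works because GridSearchCV.fit forwards extra keyword arguments to the wrapped estimator's fit(), for every fold and for the final refit. A minimal sketch of that forwarding, assuming a recent scikit-learn with metadata routing left at its default (disabled): ToyEstimator and RecordingCallback are hypothetical stand-ins for KerasClassifier and a Keras callback, used so the example runs without Keras.

```python
import numpy as np
from sklearn.base import BaseEstimator
from sklearn.model_selection import GridSearchCV


class RecordingCallback:
    """Hypothetical stand-in for a Keras callback: it only records fit calls."""

    def __init__(self):
        self.calls = []

    def on_fit(self, params):
        self.calls.append(params)


class ToyEstimator(BaseEstimator):
    """Hypothetical estimator whose fit() accepts callbacks, like KerasClassifier.fit."""

    def __init__(self, alpha=1.0):
        # Store constructor arguments unchanged so clone() accepts the estimator.
        self.alpha = alpha

    def fit(self, X, y, callbacks=None):
        for cb in callbacks or []:
            cb.on_fit(self.get_params())
        self.mean_ = float(np.mean(y))
        return self

    def score(self, X, y):
        # Higher is better: negative squared error of the constant prediction alpha * mean_.
        return -float(np.mean((y - self.alpha * self.mean_) ** 2))


X = np.arange(30, dtype=float).reshape(-1, 1)
y = np.ones(30)

cb = RecordingCallback()
grid = GridSearchCV(ToyEstimator(), {"alpha": [0.5, 1.0, 2.0]}, cv=3)
grid.fit(X, y, callbacks=[cb])  # callbacks go to fit(), not the constructor

print(len(cb.calls))  # 3 candidates x 3 folds + 1 refit
# → 10
print(grid.best_params_)
# → {'alpha': 1.0}
```

With the default n_jobs the fits run sequentially in one process, so the same callback object sees every fold and the refit in turn; with parallel workers each process would receive its own copy.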