Hey there,
I would like to use scikit-learn's grid search (GridSearchCV) with Keras while still being able to use callbacks.
How can I do this?
Thank you,
Keiron
I'm joining in on this. It doesn't work the way I'd expect: including simple EarlyStopping and ReduceLROnPlateau callbacks results in the following:
RuntimeError Traceback (most recent call last)
<ipython-input-159-1a0430e264a4> in <module>()
25 start = datetime.datetime.now()
26
---> 27 grid_result = grid.fit(xtrain, ytrain)
28
29 end = datetime.datetime.now()
/home/nobody/anaconda3/lib/python3.5/site-packages/sklearn/model_selection/_search.py in fit(self, X, y, groups)
938 train/test set.
939 """
--> 940 return self._fit(X, y, groups, ParameterGrid(self.param_grid))
941
942
/home/nobody/anaconda3/lib/python3.5/site-packages/sklearn/model_selection/_search.py in _fit(self, X, y, groups, parameter_iterable)
547 n_candidates * n_splits))
548
--> 549 base_estimator = clone(self.estimator)
550 pre_dispatch = self.pre_dispatch
551
/home/nobody/anaconda3/lib/python3.5/site-packages/sklearn/base.py in clone(estimator, safe)
124 raise RuntimeError('Cannot clone object %s, as the constructor '
125 'does not seem to set parameter %s' %
--> 126 (estimator, name))
127
128 return new_object
RuntimeError: Cannot clone object <keras.wrappers.scikit_learn.KerasClassifier object at 0x2b132c647400>, as the constructor does not seem to set parameter callbacks
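A minimal sketch of why clone() fails here. It assumes (as far as I know) that the old keras.wrappers.scikit_learn wrappers keep extra constructor kwargs in sk_params and return a deep copy of them from get_params(). KerasLikeWrapper and EarlyStoppingStub are hypothetical stand-ins so the sketch runs with scikit-learn alone. clone() rebuilds the estimator from get_params() and then checks that the new object's get_params() returns the parameters it was given unchanged. Plain values like epochs=10 pass that check, but a deep-copied list of callback objects does not.

```python
import copy

from sklearn.base import clone


class EarlyStoppingStub:
    """Hypothetical placeholder for a Keras callback object."""


class KerasLikeWrapper:
    """Hypothetical stand-in for the old KerasClassifier: extra kwargs are
    stored in sk_params and get_params() returns a deep copy of them
    (assumed to mirror the real wrapper's behaviour)."""

    def __init__(self, build_fn=None, **sk_params):
        self.build_fn = build_fn
        self.sk_params = sk_params

    def get_params(self, **params):
        res = copy.deepcopy(self.sk_params)
        res.update({"build_fn": self.build_fn})
        return res

    def set_params(self, **params):
        self.sk_params.update(params)
        return self


# Plain values survive: deepcopy returns the same int object.
ok = clone(KerasLikeWrapper(epochs=10))

# A list of callbacks does not: get_params() hands back a fresh copy,
# so clone()'s sanity check sees a different object and raises.
error = None
try:
    clone(KerasLikeWrapper(epochs=10, callbacks=[EarlyStoppingStub()]))
except RuntimeError as exc:
    error = exc

print(error)
```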
I was able to solve this for my similar use case (GridSearchCV) by passing the callback through the fit_params kwarg of GridSearchCV, i.e.:
grid_search = GridSearchCV(est, params, fit_params={'callbacks': [MyCallback()]})
This solution does not work for learning_curve, which does not implement the fit_params argument.
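One possible workaround for a version of learning_curve that has no way to forward fit parameters, sketched under the assumption that you can subclass the wrapper: override fit() so it builds its own callbacks, instead of receiving them through the constructor or a fit_params argument. The constructor parameters stay unchanged, so clone() keeps working. ThresholdClassifier, ThresholdClassifierWithCallbacks, RecordingCallback and FIT_LOG are hypothetical stand-ins; a Keras version would presumably subclass KerasClassifier and pass real callbacks to super().fit().

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.model_selection import learning_curve

FIT_LOG = []  # hypothetical: one entry per fit() that received a callback


class RecordingCallback:
    """Hypothetical callback that logs itself when a fit finishes."""

    def on_fit_end(self):
        FIT_LOG.append(self)


class ThresholdClassifier(ClassifierMixin, BaseEstimator):
    """Hypothetical stand-in for a wrapper whose fit() accepts callbacks."""

    def __init__(self, threshold=0.5):
        self.threshold = threshold

    def fit(self, X, y, callbacks=None):
        self.classes_ = np.unique(y)
        for cb in callbacks or []:
            cb.on_fit_end()
        return self

    def predict(self, X):
        # Predict class 1 when the first feature exceeds the threshold.
        return (np.asarray(X, dtype=float)[:, 0] > self.threshold).astype(int)


class ThresholdClassifierWithCallbacks(ThresholdClassifier):
    """Creates fresh callbacks inside fit(); nothing extra goes through
    __init__, so get_params()/clone() are unaffected."""

    def fit(self, X, y):
        return super().fit(X, y, callbacks=[RecordingCallback()])


X = [[float(i)] for i in range(8)]
y = [0, 0, 0, 0, 1, 1, 1, 1]

sizes, train_scores, test_scores = learning_curve(
    ThresholdClassifierWithCallbacks(threshold=3.5), X, y, cv=2, train_sizes=[2, 4]
)
print(len(FIT_LOG))  # 2 folds x 2 training sizes = 4 fits
```

Creating the callbacks inside fit() also means each fit gets fresh callback state rather than sharing one object across folds.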
This is a temporary solution that will stop working soon (see here): passing fit_params as a constructor argument was deprecated in version 0.19 and will be removed in version 0.21.
This part from the link mentioned by @baharian was especially helpful: "Pass fit parameters to the fit method instead."

grid_search = grid_search.fit(X_train, y_train, callbacks=[MyCallback()]) worked for me.
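A self-contained sketch of what that call relies on. GridSearchCV.fit forwards extra keyword arguments to the estimator's own fit() for every candidate and fold, plus the final refit, so the callbacks never pass through the constructor and clone() never has to copy them. ThresholdClassifier and CountingCallback are hypothetical stand-ins for KerasClassifier and a Keras callback, so the example runs with scikit-learn and NumPy only. With the default n_jobs=None the fits run in-process, so the same callback object is handed to every fit.

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.model_selection import GridSearchCV


class CountingCallback:
    """Hypothetical callback: counts how many fits it was handed to."""

    def __init__(self):
        self.calls = 0

    def on_fit_end(self):
        self.calls += 1


class ThresholdClassifier(ClassifierMixin, BaseEstimator):
    """Hypothetical stand-in for KerasClassifier: `callbacks` is an argument
    of fit(), not a constructor parameter, so clone() never copies it."""

    def __init__(self, threshold=0.5):
        self.threshold = threshold

    def fit(self, X, y, callbacks=None):
        self.classes_ = np.unique(y)
        for cb in callbacks or []:
            cb.on_fit_end()
        return self

    def predict(self, X):
        # Predict class 1 when the first feature exceeds the threshold.
        return (np.asarray(X, dtype=float)[:, 0] > self.threshold).astype(int)


X = [[float(i)] for i in range(8)]
y = [0, 0, 0, 0, 1, 1, 1, 1]

cb = CountingCallback()
grid_search = GridSearchCV(ThresholdClassifier(), {"threshold": [2.5, 3.5]}, cv=2)
# Extra keyword arguments to fit() are forwarded to the estimator's fit().
grid_search = grid_search.fit(X, y, callbacks=[cb])

print(grid_search.best_params_)  # {'threshold': 3.5}
print(cb.calls)  # 2 candidates x 2 folds + 1 refit = 5
```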