Keras: TypeError: can't pickle _thread.RLock objects

Created on 25 Jun 2018 · 11 Comments · Source: keras-team/keras

  • [X] Check that you are up-to-date with the master branch of Keras. You can update with:
    pip install git+git://github.com/keras-team/keras.git --upgrade --no-deps

  • [X] If running on TensorFlow, check that you are up-to-date with the latest version. The installation instructions can be found here.

  • [NA] If running on Theano, check that you are up-to-date with the master branch of Theano. You can update with:
    pip install git+git://github.com/Theano/Theano.git --upgrade --no-deps

  • [X] (or just copy the script here if it is short).

NOTE: I am not using Lambda functions, which are the culprit in some other reports about pickling thread objects. Also, I do not see any Lambda calls in keras.wrappers.scikit_learn.

I would like to pickle a simple extension of a KerasClassifier as part of my sklearn pipeline, but it's throwing a tantrum about threading stuff. I'm happy to simply single thread the thing for the time being, so a solution involving that would be appreciated if nothing else.

from sklearn.datasets import load_iris
import numpy as np
import sklearn
import sklearn.preprocessing  # make the submodule used in fit() available explicitly
import keras
import keras.wrappers.scikit_learn

class LogisticRegression(keras.wrappers.scikit_learn.KerasClassifier):
    def __init__(self, n_epochs=100, **kwargs):
        self.n_epochs = n_epochs
        super().__init__(**kwargs)

    def fit(self, X, y,**kwargs):
        # get the shape of X and one hot y
        self.input_shape = X.shape[-1]
        self.label_encoder = sklearn.preprocessing.LabelEncoder()
        self.label_encoder.fit(y)
        self.output_shape = len(self.label_encoder.classes_)
        label_encoded = self.label_encoder.transform(y).reshape((-1,1))
        y_onehot = sklearn.preprocessing.OneHotEncoder().fit_transform(label_encoded).toarray()
        super().fit(X,y_onehot,epochs=self.n_epochs,verbose=1,**kwargs)
        return self

    def check_params(self, params):
        # deliberately skip the wrapper's sk_params validation
        pass

    def __call__(self): # the build_fn thing
        # create model
        model = keras.models.Sequential()
        model.add(keras.layers.Dense(self.output_shape, input_dim=self.input_shape, kernel_initializer="normal", activation="softmax"))
        # Compile model
        model.compile(loss='categorical_crossentropy', optimizer='adam')
        return model

data = load_iris()
clf = LogisticRegression(1)
clf.fit(data.data, data.target)

import pickle
with open("blah.p","wb") as f:
    pickle.dump(clf, f)

Result:

TypeError: can't pickle _thread.RLock objects
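
For context, the error itself can be reproduced without Keras. The sketch below is only an illustration, not the Keras internals: HoldsLock and unpicklable_attrs are hypothetical names, and the lock stands in for whatever thread-related object the fitted model ends up holding. It shows pickle rejecting an object that owns a threading.RLock, plus a small helper that lists which attributes are to blame.

```python
import pickle
import threading


class HoldsLock:
    """Hypothetical stand-in for a fitted wrapper whose model owns a lock."""

    def __init__(self):
        self.n_epochs = 1
        self.model = threading.RLock()  # assumption: the model holds a lock somewhere


def unpicklable_attrs(obj):
    """Return the names of attributes in obj.__dict__ that pickle rejects."""
    bad = []
    for name, value in vars(obj).items():
        try:
            pickle.dumps(value)
        except Exception:
            bad.append(name)
    return bad


obj = HoldsLock()
try:
    pickle.dumps(obj)
    error = None
except TypeError as exc:
    error = exc

# The exact wording of the message depends on the Python version.
print(type(error).__name__, error)
print(unpicklable_attrs(obj))
```

Running a helper like this on the fitted classifier should point at the attribute that needs custom handling.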

Most helpful comment

I experienced this error when passing a keras.layers.Input object as an sk_params argument to keras.wrappers.scikit_learn.KerasRegressor. Creating the keras.layers.Input object inside the function solved the problem for me.
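
The reason this helps can be sketched without Keras, under the assumption that the wrapper simply stores build_fn and sk_params as attributes, so everything passed in gets pickled along with it. Wrapper and build_with_shape are hypothetical names, and a threading.RLock stands in for the Input object.

```python
import pickle
import threading


class Wrapper:
    """Hypothetical stand-in for KerasRegressor: keeps build_fn and sk_params."""

    def __init__(self, build_fn, **sk_params):
        self.build_fn = build_fn
        self.sk_params = sk_params


def build_with_shape(input_dim):
    # The unpicklable object (a lock here, an Input object in Keras) is
    # created inside the function, so it is never stored on the wrapper.
    handle = threading.RLock()
    return {"input_dim": input_dim, "handle": handle}


# Only plain values are stored in sk_params: this wrapper pickles fine.
good = Wrapper(build_with_shape, input_dim=4)
restored = pickle.loads(pickle.dumps(good))

# The unpicklable object is stored in sk_params: pickling fails.
bad = Wrapper(build_with_shape, input_dim=threading.RLock())
try:
    pickle.dumps(bad)
    bad_error = None
except TypeError as exc:
    bad_error = exc
```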

All 11 comments

I tried putting this in there, to force the whole thing onto a single thread, but it doesn't care, and I still get the error.

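import tensorflow as tf
import keras

# tf.ConfigProto and tf.Session are TensorFlow 1.x APIs.
# Restrict TensorFlow to a single thread on a single CPU device.
config = tf.ConfigProto(intra_op_parallelism_threads=1,
                        inter_op_parallelism_threads=1,
                        allow_soft_placement=False,
                        device_count={'CPU': 1})
session = tf.Session(config=config)
keras.backend.set_session(session)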

I've just devised the following hacky non-thread-safe workaround that allows me to pickle keras stuff. I would still very much appreciate a more robust solution however.

Note the __getstate__ and __setstate__ methods which allow custom pickling. If keras.models.save_model and h5py.File were capable of accepting arbitrary file-like objects instead of only string filenames, this could be fixed to not write stuff to the local directory, and would fix some of the thread-safety problems.

import os  # needed for os.remove in __getstate__ and __setstate__

class LogisticRegression(keras.wrappers.scikit_learn.KerasClassifier):
    TEMP_MODEL_FILE_ = "a_temporary_file_for_storing_a_keras_model.h5"
    MODEL_ATTR_ = "the_model_all_serialized_and_turned_into_an_hdf5_file_and_stuff"

    def __init__(self, n_epochs=100, **kwargs):
        self.n_epochs = n_epochs
        super().__init__(**kwargs)

    def fit(self, X, y,**kwargs):
        # get the shape of X and one hot y
        self.input_shape = X.shape[-1]
        self.label_encoder = sklearn.preprocessing.LabelEncoder()
        self.label_encoder.fit(y)
        self.output_shape = len(self.label_encoder.classes_)
        label_encoded = self.label_encoder.transform(y).reshape((-1,1))
        y_onehot = sklearn.preprocessing.OneHotEncoder().fit_transform(label_encoded).toarray()
        super().fit(X,y_onehot,epochs=self.n_epochs,verbose=1,**kwargs)
        return self

    def check_params(self, params):
        # deliberately skip the wrapper's sk_params validation
        pass

    def __call__(self): # the build_fn thing
        # create model
        model = keras.models.Sequential()
        model.add(keras.layers.Dense(self.output_shape, input_dim=self.input_shape, kernel_initializer="normal", activation="softmax"))
        # Compile model
        model.compile(loss='categorical_crossentropy', optimizer='adam')
        return model

    def __getstate__(self):
        d = dict(self.__dict__)
        self.model.save(LogisticRegression.TEMP_MODEL_FILE_)
        with open(LogisticRegression.TEMP_MODEL_FILE_,"rb") as f:
            serial_model_data = f.read()
        d[LogisticRegression.MODEL_ATTR_] = serial_model_data
        os.remove(LogisticRegression.TEMP_MODEL_FILE_)
        del d["model"]
        return d

    def __setstate__(self, d):
        with open(LogisticRegression.TEMP_MODEL_FILE_, "wb") as f:
            f.write(d[LogisticRegression.MODEL_ATTR_])
        del d[LogisticRegression.MODEL_ATTR_]
        self.model = keras.models.load_model(LogisticRegression.TEMP_MODEL_FILE_)
        os.remove(LogisticRegression.TEMP_MODEL_FILE_)
        self.__dict__.update(d)
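
A possible refinement of the workaround above, sketched under assumptions: tempfile.mkstemp gives every call its own temporary path, which avoids two threads or processes colliding on one fixed filename. FakeModel and Wrapper are hypothetical stand-ins so the example runs without Keras; with a real model you would presumably call model.save(path) and keras.models.load_model(path) as in the class above.

```python
import os
import pickle
import tempfile
import threading


class FakeModel:
    """Hypothetical stand-in for a Keras model: it holds a lock, so pickle rejects it."""

    def __init__(self, weights):
        self.weights = weights
        self._lock = threading.RLock()

    def save(self, path):
        with open(path, "w") as f:
            f.write(",".join(str(w) for w in self.weights))

    @staticmethod
    def load(path):
        with open(path) as f:
            return FakeModel([float(w) for w in f.read().split(",")])


class Wrapper:
    MODEL_ATTR_ = "_serialized_model_bytes"

    def __init__(self, model):
        self.model = model

    def __getstate__(self):
        state = dict(self.__dict__)
        model = state.pop("model")
        fd, path = tempfile.mkstemp(suffix=".h5")  # unique path per call
        os.close(fd)
        try:
            model.save(path)
            with open(path, "rb") as f:
                state[Wrapper.MODEL_ATTR_] = f.read()
        finally:
            os.remove(path)  # clean up even if saving fails
        return state

    def __setstate__(self, state):
        data = state.pop(Wrapper.MODEL_ATTR_)
        fd, path = tempfile.mkstemp(suffix=".h5")
        os.close(fd)
        try:
            with open(path, "wb") as f:
                f.write(data)
            self.model = FakeModel.load(path)
        finally:
            os.remove(path)
        self.__dict__.update(state)


original = Wrapper(FakeModel([1.0, 2.5]))
clone = pickle.loads(pickle.dumps(original))
```

This still round-trips through the filesystem, so it only addresses the filename collision, not the wish for file-like object support mentioned above.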

As of right now Keras models are not picklable. Check in when this PR is merged:
https://github.com/keras-team/keras/pull/10483

I am using Lambda functions, and the problem arises. How can I solve it?

As of right now Keras models are not picklable. Check in when this PR is merged:

#10483

Are they still not picklable?

#10483 was closed in favor of #11030, so models should now be picklable. Why does this error still arise?

Is there any progress on this?

Getting same error

Getting the error when trying to joblib.dump() a tf.keras model.
