
How to pickle a Keras custom layer?

I wrote a custom layer class that extends the Layer class, and I want to pickle the training history for further analysis. But when I reload the pickled object from the file, Python raises an error:

Unknown Layer: Attention.

So, how can I fix it?

I have tried get_config, __getstate__ and __setstate__, but none of them worked. I just want to pickle the keras history, but not the model, so please don't point me to the model-saving methods with the custom_objects parameter.

This problem occurs because dumping the history also dumps the model, but not completely: the custom layer class itself is not saved with it. So when loading, Keras cannot find the custom class.

I've noticed that the keras.callbacks.History object has a model attribute, and the incomplete dump of that attribute is the cause of this problem.
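To illustrate the underlying constraint without Keras, here is a minimal stdlib-only sketch. pickle stores instances of user-defined classes by reference (module and class name), not by copying the class code, so every class reachable from the pickled object must be findable at load time. CustomLayer and HistoryLike are hypothetical stand-ins for the custom layer and keras.callbacks.History, not Keras classes; Keras reports the failure as "Unknown layer" rather than a pickle error because it goes through its own deserialization, but the idea is similar.

```python
import pickle

class CustomLayer:
    """Hypothetical stand-in for a user-defined layer such as Attention."""
    def __init__(self, units):
        self.units = units

class HistoryLike:
    """Hypothetical stand-in for keras.callbacks.History: metrics plus a model reference."""
    def __init__(self, history, model):
        self.history = history
        self.model = model

hist = HistoryLike({"loss": [0.9, 0.7]}, model=CustomLayer(10))
data = pickle.dumps(hist)  # succeeds: CustomLayer is recorded by module + name

# Simulate a fresh session in which the custom class was never defined.
del CustomLayer
try:
    pickle.loads(data)
    load_error = None
except AttributeError as exc:  # raised because CustomLayer can no longer be found in __main__
    load_error = exc
print("load failed:", load_error)
```

Note that dumping works fine; only loading fails, which matches the symptom in the question.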

And you said:

I just want to pickle the keras history, but not the model

So here is a workaround:

hist = model.fit(X, Y, ...)
hist.model = None

By setting the model attribute to None, you can dump and load your history object successfully!
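The same stdlib-only stand-ins can sketch why the workaround helps: once model is None, the pickled bytes no longer reference the custom class, so loading succeeds even where that class is undefined. CustomLayer and HistoryLike are again hypothetical stand-ins, not the real Keras classes.

```python
import pickle

class CustomLayer:
    """Hypothetical stand-in for a user-defined Keras layer."""
    def __init__(self, units):
        self.units = units

class HistoryLike:
    """Hypothetical stand-in for keras.callbacks.History."""
    def __init__(self, history, model):
        self.history = history
        self.model = model

hist = HistoryLike({"loss": [0.9, 0.7]}, model=CustomLayer(10))
hist.model = None  # the workaround: drop the only reference to the custom class
data = pickle.dumps(hist)

# Simulate loading in a session where CustomLayer is not defined.
del CustomLayer
hist_reloaded = pickle.loads(data)
print(hist_reloaded.history)  # {'loss': [0.9, 0.7]}
```

Keep in mind that setting hist.model = None mutates the History object, so do it only after you no longer need hist.model.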

Here is the MCVE (minimal, complete, and verifiable example):

from keras.models import Sequential
from keras.layers import Conv2D, Dense, Flatten, Layer
import keras.backend as K
import numpy as np
import pickle

# MyLayer from https://keras.io/layers/writing-your-own-keras-layers/
class MyLayer(Layer):

    def __init__(self, output_dim, **kwargs):
        # Initialize the base Layer first, then store layer-specific attributes
        super().__init__(**kwargs)
        self.output_dim = output_dim

    def build(self, input_shape):
        # Create a trainable weight variable for this layer.
        self.kernel = self.add_weight(name='kernel', 
                                      shape=(input_shape[1], self.output_dim),
                                      initializer='uniform',
                                      trainable=True)
        super(MyLayer, self).build(input_shape)  # Be sure to call this at the end

    def call(self, x):
        return K.dot(x, self.kernel)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.output_dim)

model = Sequential()
model.add(Conv2D(filters=32, kernel_size=(3,3), input_shape=(28,28,3), activation='sigmoid'))
model.add(Flatten())
model.add(MyLayer(10))
model.add(Dense(3, activation='softmax'))

model.compile(loss='sparse_categorical_crossentropy', metrics=['accuracy'], optimizer='adam')

model.summary()

X = np.random.randn(64, 28, 28, 3)
Y = np.random.randint(0, high=2, size=(64,1))

hist = model.fit(X, Y, batch_size=8)

hist.model = None

with open('hist.pkl', 'wb') as f:
    pickle.dump(hist, f)

with open('hist.pkl', 'rb') as f:
    hist_reloaded = pickle.load(f)

print(hist.history)
print(hist_reloaded.history)

The output:

{'acc': [0.484375], 'loss': [6.140302091836929]}

{'acc': [0.484375], 'loss': [6.140302091836929]}
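If only the metrics are needed for analysis, a simpler alternative is to pickle the hist.history dictionary itself instead of the whole History object. It typically holds just metric names mapped to lists of numbers, so no Keras class is needed to load it. A minimal sketch, using the values printed above in place of a real hist.history; the file name hist_dict.pkl is arbitrary.

```python
import pickle

# Values copied from the output above; in practice this would be hist.history
history_dict = {'acc': [0.484375], 'loss': [6.140302091836929]}

with open('hist_dict.pkl', 'wb') as f:
    pickle.dump(history_dict, f)

with open('hist_dict.pkl', 'rb') as f:
    history_reloaded = pickle.load(f)

print(history_reloaded)
```

This leaves the History object (and its model reference) untouched.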

PS: If you want to save a Keras model with a custom layer, this should be helpful.
