How can I save an object containing Keras models?
Here is my code skeleton:
import tensorflow as tf

def build_model(x, y):
    model = tf.keras.models.Sequential()
    model.add(tf.keras.layers.Dense(1, activation='relu'))
    model.compile(loss='mean_squared_error', optimizer='adam')
    model.fit(x, y)
    return model

class MultiModel(object):
    def __init__(self, x1, y1, x2, y2):
        self.model1 = build_model(x1, y1)
        self.model2 = build_model(x2, y2)

if __name__ == '__main__':
    # code that builds x1, x2, y1, y2
    mm = MultiModel(x1, y1, x2, y2)  # How to save mm?
The issue is that I don't know how to save the mm object, which contains several Keras models.
The Keras built-in save method only saves a single Keras model, so it cannot be applied to mm directly. The pickle module cannot serialize _thread.RLock objects, which Keras models hold internally, so it fails as well.
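A minimal reproduction of the pickle failure, using only the standard library. The hypothetical `HoldsLock` class below is a stand-in for a Keras model (it is not Keras code); all it shares with one is that it owns a `threading.RLock`, which is enough for pickle to refuse it.

```python
import pickle
import threading


class HoldsLock:
    """Hypothetical stand-in for a Keras model: it owns an RLock."""

    def __init__(self):
        self.lock = threading.RLock()
        self.weights = [0.5, -1.0]


def try_pickle(obj):
    """Return None if pickling succeeds, otherwise the TypeError message."""
    try:
        pickle.dumps(obj)
        return None
    except TypeError as exc:
        return str(exc)


error = try_pickle(HoldsLock())
print(error)  # e.g. "cannot pickle '_thread.RLock' object" (wording varies by Python version)
```

Pickle walks every attribute of the object, so a single unpicklable member (here the lock) makes the whole object unpicklable, and the same applies to any container holding such an object.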
It may be possible to save each model independently with the Keras save method, then regroup them and save them as a whole, but I do not know how to proceed.
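The "save each model separately, then regroup" idea can be sketched with pickle's `__getstate__`/`__setstate__` hooks on the container. This is a TensorFlow-free sketch under stated assumptions: `StandInModel` and its `export_payload`/`from_payload` methods are hypothetical placeholders for a Keras model and its serialization. With real Keras models the payload would have to come from Keras's own serialization (for example `model.to_json()` together with `model.get_weights()`), which this sketch does not show.

```python
import pickle
import threading


class StandInModel:
    """Hypothetical stand-in for a Keras model: unpicklable because it owns an RLock."""

    def __init__(self, weights):
        self._lock = threading.RLock()
        self.weights = list(weights)

    # Hypothetical export/import hooks; for a real Keras model these would wrap
    # Keras serialization (architecture + weights) and its inverse.
    def export_payload(self):
        return {"weights": self.weights}

    @classmethod
    def from_payload(cls, payload):
        return cls(payload["weights"])


class MultiModel:
    def __init__(self, model1, model2):
        self.model1 = model1
        self.model2 = model2

    def __getstate__(self):
        # Copy the instance dict, then swap each sub-model for a plain payload.
        state = self.__dict__.copy()
        for name in ("model1", "model2"):
            state[name] = state[name].export_payload()
        return state

    def __setstate__(self, state):
        # Rebuild each sub-model from its payload, then restore the attributes.
        for name in ("model1", "model2"):
            state[name] = StandInModel.from_payload(state[name])
        self.__dict__.update(state)


mm = MultiModel(StandInModel([1.0, 2.0]), StandInModel([3.0]))
restored = pickle.loads(pickle.dumps(mm))
print(restored.model1.weights, restored.model2.weights)  # [1.0, 2.0] [3.0]
```

The design keeps the models' own serialization in one place per sub-model, and the container stays an ordinary picklable object; the trade-off is that every model attribute has to be listed in both hooks.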
Any ideas ?
Solution 1:[1]
Since it took me a long time to find a working solution to a similar problem (pickling an object that contains objects containing Keras models), I am posting it here. The solution is adriangb's, found in LifeSaver:
import pickle

from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense
# The next two imports come from TensorFlow's internal (private) tensorflow.python.keras
# package, so their location may differ between TensorFlow versions.
from tensorflow.python.keras.layers import deserialize, serialize
from tensorflow.python.keras.saving import saving_utils


def unpack(model, training_config, weights):
    # Rebuild the architecture from its serialized config
    restored_model = deserialize(model)
    if training_config is not None:
        # Re-compile with the original loss, optimizer and metrics
        restored_model.compile(
            **saving_utils.compile_args_from_training_config(
                training_config
            )
        )
    # Restore the trained weights
    restored_model.set_weights(weights)
    return restored_model


# Hotfix function: tells pickle how to reduce any Keras Model
def make_keras_picklable():
    def __reduce__(self):
        model_metadata = saving_utils.model_metadata(self)
        training_config = model_metadata.get("training_config", None)
        model = serialize(self)
        weights = self.get_weights()
        # On load, pickle calls unpack(model, training_config, weights)
        return (unpack, (model, training_config, weights))

    cls = Model
    cls.__reduce__ = __reduce__


# Run the function
make_keras_picklable()

# Create the model
model = Sequential()
model.add(Dense(1, input_dim=42, activation='sigmoid'))
model.compile(optimizer='Nadam', loss='binary_crossentropy', metrics=['accuracy'])

# Save
with open('model.pkl', 'wb') as f:
    pickle.dump(model, f)
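To see why this hotfix also covers the original question, here is a TensorFlow-free sketch of the same mechanism. `StandInModel`, `unpack` and `make_stand_in_picklable` are hypothetical stand-ins for the Keras class and functions above, not Keras code. Once `__reduce__` is patched on the class, pickle uses it for every instance it encounters, including instances nested inside another object, so a plain container like the question's `MultiModel` pickles with no extra code.

```python
import pickle
import threading


class StandInModel:
    """Hypothetical stand-in for a Keras Model: owns an RLock, so it is not picklable."""

    def __init__(self, config, weights):
        self._lock = threading.RLock()
        self.config = config
        self.weights = weights


def unpack(config, weights):
    # Module-level rebuild function: pickle stores a reference to it by name.
    return StandInModel(config, weights)


def make_stand_in_picklable():
    # Same trick as make_keras_picklable(): patch __reduce__ on the class so
    # pickle saves (rebuild function, arguments) instead of the raw attributes.
    def __reduce__(self):
        return (unpack, (self.config, self.weights))

    StandInModel.__reduce__ = __reduce__


class MultiModel:
    # Plain container, as in the question; it needs no pickling code of its own.
    def __init__(self, model1, model2):
        self.model1 = model1
        self.model2 = model2


make_stand_in_picklable()
mm = MultiModel(StandInModel({"units": 1}, [0.1]), StandInModel({"units": 1}, [0.2]))
restored = pickle.loads(pickle.dumps(mm))
print(restored.model1.weights, restored.model2.weights)  # [0.1] [0.2]
```

By analogy, with the real hotfix the question's object should be savable by calling make_keras_picklable() before pickle.dump(mm, f). As with unpack here, the rebuild function must be importable under the same module name when the file is loaded.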
Thanks, Adrian!
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | |