How can I save an object containing Keras models?

Here is my code skeleton:

import tensorflow as tf


def build_model(x, y):
    model = tf.keras.models.Sequential()
    model.add(tf.keras.layers.Dense(1, activation='relu'))
    model.compile(loss='mean_squared_error', optimizer='adam')
    model.fit(x, y)
    return model

class MultiModel(object):
    def __init__(self, x1, y1, x2, y2):
        self.model1 = build_model(x1, y1)
        self.model2 = build_model(x2, y2)

if __name__ == '__main__':
    # code that builds x1, x2, y1, y2
    mm = MultiModel(x1, y1, x2, y2)  # How do I save mm?

The issue is that I don't know how to save the mm object, which contains several Keras models.

The built-in Keras save method can only save a single Keras model, so it cannot be used directly in this case. The pickle module cannot serialize _thread.RLock objects, so it is not usable either.
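The pickle failure can be reproduced without TensorFlow by pickling any plain object that holds a threading.RLock. The sketch below uses a hypothetical Holder class as a stand-in for a Keras model; the exact error wording depends on the Python version.

```python
import pickle
import threading


class Holder:
    """Hypothetical stand-in for an object that, like the models here, holds a lock."""

    def __init__(self):
        self.lock = threading.RLock()  # pickle has no way to serialize this
        self.weights = [0.5, -1.0]


def try_pickle(obj):
    """Return None if obj can be pickled, otherwise the error message."""
    try:
        pickle.dumps(obj)
        return None
    except (TypeError, pickle.PicklingError) as exc:
        return str(exc)


error = try_pickle(Holder())
print(error)  # mentions _thread.RLock; wording varies by Python version
```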

It may be possible to save each model independently with the Keras save method, then regroup them and save them as a whole, but I do not know how to proceed.
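The "save each model separately, then regroup" idea can be sketched with pickle's __getstate__/__setstate__ hooks on the container class. This is a minimal sketch under assumptions: FakeModel is a hypothetical stand-in that holds a lock like a Keras model, and its get_config/get_weights/from_config/set_weights methods only mimic the names of the corresponding Keras methods. With real Keras models, the compile settings (optimizer, loss) would also have to be restored after rebuilding.

```python
import pickle
import threading


class FakeModel:
    """Hypothetical stand-in for a Keras model: unpicklable lock, picklable config/weights."""

    def __init__(self, units, weights=None):
        self.units = units
        self.weights = list(weights) if weights is not None else [0.0] * units
        self._lock = threading.RLock()  # this is what breaks plain pickling

    def get_config(self):
        return {"units": self.units}

    def get_weights(self):
        return list(self.weights)

    @classmethod
    def from_config(cls, config):
        return cls(config["units"])

    def set_weights(self, weights):
        self.weights = list(weights)


class MultiModel:
    def __init__(self, model1, model2):
        self.model1 = model1
        self.model2 = model2

    def __getstate__(self):
        # Save each model independently as plain (config, weights) data.
        return {
            name: (model.get_config(), model.get_weights())
            for name, model in vars(self).items()
        }

    def __setstate__(self, state):
        # Regroup: rebuild each model from its config, then restore its weights.
        for name, (config, weights) in state.items():
            model = FakeModel.from_config(config)
            model.set_weights(weights)
            setattr(self, name, model)


mm = MultiModel(FakeModel(2, [1.0, 2.0]), FakeModel(1, [3.0]))
restored = pickle.loads(pickle.dumps(mm))
print(restored.model1.weights, restored.model2.weights)  # [1.0, 2.0] [3.0]
```

Only plain data (configs and weight lists) ever reaches pickle, so the lock is never serialized; each model gets a fresh lock when it is rebuilt.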

Any ideas?



Solution 1:[1]

It took me a long time to find a working solution to a similar problem (pickling an object that contains objects which contain Keras models), so I am posting it here. The solution is adriangb's, found in LifeSaver:

import pickle

from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense
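# Note: tensorflow.python.keras is a private TensorFlow module, so these two
# imports may need adjusting depending on your TensorFlow/Keras version.
from tensorflow.python.keras.layers import deserialize, serialize
from tensorflow.python.keras.saving import saving_utils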


# Rebuild a model from its serialized architecture, training config and weights.
def unpack(model, training_config, weights):
    restored_model = deserialize(model)
    if training_config is not None:
        restored_model.compile(
            **saving_utils.compile_args_from_training_config(
                training_config
            )
        )
    restored_model.set_weights(weights)
    return restored_model

# Hotfix function
def make_keras_picklable():

    # Tell pickle to rebuild the model by calling unpack() with plain, picklable data.
    def __reduce__(self):
        model_metadata = saving_utils.model_metadata(self)
        training_config = model_metadata.get("training_config", None)
        model = serialize(self)
        weights = self.get_weights()
        return (unpack, (model, training_config, weights))

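    # Patch the base Model class so every Keras model (including Sequential)
    # is pickled through the __reduce__ defined above.
    Model.__reduce__ = __reduce__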

# Run the function
make_keras_picklable()

# Create the model
model = Sequential()
model.add(Dense(1, input_dim=42, activation='sigmoid'))
model.compile(optimizer='Nadam', loss='binary_crossentropy', metrics=['accuracy'])

# Save
with open('model.pkl', 'wb') as f:
    pickle.dump(model, f)
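The code above pickles a single model. Because make_keras_picklable patches __reduce__ on the Model class itself, an object that merely contains models (such as MultiModel from the question) should then pickle with a plain pickle.dump, with no pickling code of its own. The stdlib-only sketch below demonstrates that mechanism; FakeModel, unpack and make_picklable here are hypothetical stand-ins, not Keras APIs.

```python
import pickle
import threading


class FakeModel:
    """Hypothetical stand-in for a Keras model: the lock makes it unpicklable by default."""

    def __init__(self, config):
        self.config = dict(config)
        self.weights = []
        self._lock = threading.RLock()

    def get_weights(self):
        return list(self.weights)

    def set_weights(self, weights):
        self.weights = list(weights)


def unpack(config, weights):
    """Rebuild a FakeModel; must be module-level so pickle can look it up by name."""
    restored = FakeModel(config)
    restored.set_weights(weights)
    return restored


def make_picklable():
    """Patch FakeModel after the fact, mirroring make_keras_picklable."""

    def __reduce__(self):
        # Pickle stores unpack plus these plain-data arguments instead of the object.
        return (unpack, (self.config, self.get_weights()))

    FakeModel.__reduce__ = __reduce__


class MultiModel:
    """Container from the question; needs no pickling code of its own."""

    def __init__(self, model1, model2):
        self.model1 = model1
        self.model2 = model2


make_picklable()

model1 = FakeModel({"units": 1})
model1.set_weights([0.1])
model2 = FakeModel({"units": 3})
model2.set_weights([0.2, 0.3, 0.4])

mm = MultiModel(model1, model2)
restored = pickle.loads(pickle.dumps(mm))
print(restored.model2.get_weights())  # [0.2, 0.3, 0.4]
```

When loading in a fresh process, the module that defines unpack must be importable, because pickle records only a reference to the function (its module and name), not its code.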

Thanks, Adrian!

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1