How to split a Keras model with a non-sequential architecture, like ResNet, into sub-models?

My model is a ResNet-152 and I want to cut it into two sub-models. The problem is with the second one: I can't figure out how to build a model that goes from an intermediate layer to the output.

I tried the code from this answer, but it doesn't work for me. Here is my code:

from keras.layers import Input
from keras.models import Model

def getLayerIndexByName(model, layername):
    for idx, layer in enumerate(model.layers):
        if layer.name == layername:
            return idx

idx = getLayerIndexByName(resnet, 'res3a_branch2a')

input_shape = resnet.layers[idx].get_input_shape_at(0) # in my case: (None, 55, 55, 256)

layer_input = Input(shape=input_shape[1:]) # drop the batch dimension; Keras adds it back

# create the new nodes for each layer in the path
x = layer_input
for layer in resnet.layers[idx:]:
    x = layer(x)

# create the model
new_model = Model(layer_input, x)

And I get this error:

ValueError: Input 0 is incompatible with layer res3a_branch1: expected axis -1 of input shape to have value 256 but got shape (None, 28, 28, 512).

I also tried this function:

import random
from keras.models import Model

def split(model, start, end):
    confs = model.get_config()
    kept_layers = set()
    for i, l in enumerate(confs['layers']):
        if i == 0:
            confs['layers'][0]['config']['batch_input_shape'] = model.layers[start].input_shape
            if i != start:
                confs['layers'][0]['name'] += str(random.randint(0, 100000000)) # rename the input layer to avoid conflicts on merge
                confs['layers'][0]['config']['name'] = confs['layers'][0]['name']
        elif i < start or i > end:
            continue
        kept_layers.add(l['name'])
    # filter layers
    layers = [l for l in confs['layers'] if l['name'] in kept_layers]
    layers[1]['inbound_nodes'][0][0][0] = layers[0]['name']
    # set conf
    confs['layers'] = layers
    confs['input_layers'][0][0] = layers[0]['name']
    confs['output_layers'][0][0] = layers[-1]['name']
    # create new model
    submodel = Model.from_config(confs)
    for l in submodel.layers:
        orig_l = model.get_layer(l.name)
        if orig_l is not None:
            l.set_weights(orig_l.get_weights())
    return submodel

And I get this error:

ValueError: Unknown layer: Scale

This happens because my ResNet-152 contains a custom Scale layer.

Here is a complete version that reproduces the problem:

import resnet   # pip install resnet
from keras.models import Model
from keras.layers import Input

def getLayerIndexByName(model, layername):
    for idx, layer in enumerate(model.layers):
        if layer.name == layername:
            return idx


resnet = resnet.ResNet152(weights='imagenet')

idx = getLayerIndexByName(resnet, 'res3a_branch2a')

model1 = Model(inputs=resnet.input, outputs=resnet.get_layer('res3a_branch2a').output)

input_shape = resnet.layers[idx].get_input_shape_at(0) # get the input shape of desired layer
print(input_shape[1:])
layer_input = Input(shape=input_shape[1:]) # a new input tensor to be able to feed the desired layer

# create the new nodes for each layer in the path
x = layer_input
for layer in resnet.layers[idx:]:
    x = layer(x)

# create the model
model2 = Model(layer_input, x)

model2.summary()

Here is the error:

ValueError: Input 0 is incompatible with layer res3a_branch1: expected axis -1 of input shape to have value 256 but got shape (None, 28, 28, 512)


Solution 1:[1]

As I mentioned in the comments, the ResNet model does not have a linear architecture: it has skip connections, and a layer may be connected to multiple layers. Therefore you can't simply loop over the model's layers one after another and apply each layer on the output of the previous one, as you can for models with a linear architecture.
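Here is a framework-free sketch of why the loop fails, tracking shapes only. ToyConv and the four toy layers are hypothetical stand-ins (not Keras objects) for the start of ResNet's res3a block, listed roughly in the order a flat model.layers list stores them, with batch-norm and activation layers omitted. The shortcut convolution res3a_branch1 comes after the main-branch convolutions, but it actually reads the block's input.

```python
import math


class ToyConv:
    """Hypothetical shape-only stand-in for a Keras Conv2D layer (channels last)."""

    def __init__(self, name, in_ch, out_ch, stride=1):
        self.name = name
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.stride = stride

    def __call__(self, shape):
        h, w, c = shape
        if c != self.in_ch:
            raise ValueError(
                f"Input 0 is incompatible with layer {self.name}: expected axis -1 "
                f"of input shape to have value {self.in_ch} but got shape {shape}"
            )
        # spatial size shrinks by the stride (rounded up); channels become out_ch
        return (math.ceil(h / self.stride), math.ceil(w / self.stride), self.out_ch)


# Toy layers in roughly the order a flat model.layers list keeps them.
block = [
    ToyConv("res3a_branch2a", 256, 128, stride=2),
    ToyConv("res3a_branch2b", 128, 128),
    ToyConv("res3a_branch2c", 128, 512),
    ToyConv("res3a_branch1", 256, 512, stride=2),  # shortcut: reads the BLOCK input
]


def naive_loop(layers, shape):
    """Feed each layer the previous layer's output, like the loop in the question."""
    for layer in layers:
        shape = layer(shape)
    return shape
```

Running naive_loop(block, (55, 55, 256)) raises the same kind of error as in the question (without the batch dimension): res3a_branch1 expects 256 channels but receives the 512-channel output of res3a_branch2c. Called on the block input it was built for, block[3]((55, 55, 256)) returns (28, 28, 512) as intended.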

Instead, you need to find how the layers are connected and traverse that connectivity graph to construct a sub-model of the original model. Here is the approach I came up with:

  1. Specify the last layer of your sub-model.
  2. Starting from that layer, find all the layers connected to it (i.e. the layers that feed into it).
  3. Get the outputs of those connected layers.
  4. Apply the last layer on the collected outputs.

Step 3 implies recursion: to get the outputs of the connected layers (call them X), we first need to find their own connected layers (Y), get the outputs of Y, and then apply X on those outputs. Further, to find the connected layers you need to know a bit about the internals of Keras, which has been covered in this answer. So we arrive at this solution:

from keras.applications.resnet50 import ResNet50
from keras import models
from keras import layers

resnet = ResNet50()

# this is the split point, i.e. the starting layer in our sub-model
starting_layer_name = 'activation_46'

# create a new input layer for our sub-model we want to construct
new_input = layers.Input(batch_shape=resnet.get_layer(starting_layer_name).get_input_shape_at(0))

layer_outputs = {}
def get_output_of_layer(layer):
    # if we have already applied this layer on its input(s) tensors,
    # just return its already computed output
    if layer.name in layer_outputs:
        return layer_outputs[layer.name]

    # if this is the starting layer, then apply it on the input tensor
    if layer.name == starting_layer_name:
        out = layer(new_input)
        layer_outputs[layer.name] = out
        return out

    # find all the layers whose outputs this layer consumes
    prev_layers = []
    for node in layer._inbound_nodes:
        prev_layers.extend(node.inbound_layers)

    # get the outputs of those connected layers (recursively)
    pl_outs = []
    for pl in prev_layers:
        pl_outs.append(get_output_of_layer(pl))

    # apply this layer on the collected outputs
    out = layer(pl_outs[0] if len(pl_outs) == 1 else pl_outs)
    layer_outputs[layer.name] = out
    return out

# note that we start from the last layer of our desired sub-model.
# this layer could be any layer of the original model as long as it is
# reachable from the starting layer
new_output = get_output_of_layer(resnet.layers[-1])

# create the sub-model
model = models.Model(new_input, new_output)
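The recursion above does not depend on Keras, so here is a minimal pure-Python sketch of it under toy assumptions: a hypothetical GRAPH dict maps each layer name to a (function, inbound layer names) pair, and build_submodel is an illustrative helper, not a Keras API. Like get_output_of_layer, it feeds the new input to the starting layer, otherwise recurses into the inbound layers first, and memoizes outputs by layer name, so the skip connection into "add" does not cause act_in to be applied twice.

```python
from collections import Counter

# Hypothetical toy graph: name -> (function, names of inbound layers).
# "act_in" is the split point; its own inbound layer lies outside the sub-model.
GRAPH = {
    "act_in": (lambda x: max(x, 0.0), ["earlier_layer"]),
    "conv_a": (lambda x: 2.0 * x, ["act_in"]),
    "conv_b": (lambda x: x + 1.0, ["conv_a"]),
    "add": (lambda xs: sum(xs), ["conv_b", "act_in"]),  # skip connection
    "act_out": (lambda x: max(x, 0.0), ["add"]),
}

calls = Counter()  # how many times each toy layer has been applied


def build_submodel(graph, start, end):
    """Return a callable computing `end` from a new input fed into `start`."""

    def run(new_input):
        outputs = {}  # memo: layer name -> computed output

        def output_of(name):
            if name in outputs:  # already applied: reuse its output
                return outputs[name]
            fn, inbound = graph[name]
            if name == start:  # the split point consumes the new input
                args = [new_input]
            else:  # recurse: compute every inbound layer first
                args = [output_of(prev) for prev in inbound]
            calls[name] += 1
            out = fn(args[0] if len(args) == 1 else args)
            outputs[name] = out
            return out

        return output_of(end)

    return run
```

For example, build_submodel(GRAPH, "act_in", "act_out")(3.0) applies each of the five toy layers exactly once and returns 10.0. Memoizing by layer name is also one reason behind note 1 below: a shared layer would get a single memo entry for two different uses.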

Important notes:

  1. This solution assumes that each layer in the original model has been used only once, i.e. it does not work for Siamese networks where a layer may be shared and therefore might be applied more than once on different input tensors.

  2. If you want a proper split of a model into multiple sub-models, it makes sense to use as split points (e.g. starting_layer_name in the code above) only layers that are NOT inside a branch. In ResNet, for example, the activation layers right after the merge layers are a good option, but the res3a_branch2a you selected is not, since it lies inside a branch. To get a better view of the original architecture of the model, you can always plot its diagram with the plot_model() utility function:

    from keras.applications.resnet50 import ResNet50
    from keras.utils import plot_model
    
    resnet = ResNet50()
    plot_model(resnet, to_file='resnet_model.png')  # requires pydot and graphviz
    
  3. Since constructing a sub-model creates new nodes on the original layers, don't try to construct another sub-model that overlaps with the previous one in the same run of the code above (non-overlapping sub-models are fine); otherwise, you may encounter errors.
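Whether a layer is a clean split point can be checked mechanically. The sketch below is a hypothetical helper working on a toy representation, not on a Keras model: order lists layer names in the order a flat model.layers list would store them (a topological order), and inbound maps each layer to the layers it reads from. Cutting right after layer k is clean when no later layer reads directly from a layer before k, so k's output is the only tensor crossing the cut.

```python
def clean_split_points(order, inbound):
    """Return the layers after which the graph can be cut cleanly.

    order:   layer names in topological order (like a flat model.layers list)
    inbound: dict mapping each layer name to the names of the layers it reads
    """
    pos = {name: i for i, name in enumerate(order)}
    # earliest position each layer reads from (its own position if it reads nothing)
    earliest_read = [
        min((pos[p] for p in inbound.get(name, [])), default=i)
        for i, name in enumerate(order)
    ]
    points = []
    for k, name in enumerate(order):
        # clean if no layer after k reads directly from a layer before k
        if all(earliest_read[j] >= k for j in range(k + 1, len(order))):
            points.append(name)
    return points


# Hypothetical toy graph shaped like one ResNet block followed by a classifier.
order = ["input", "act_prev", "res3a_branch2a", "res3a_branch2b", "res3a_branch2c",
         "res3a_branch1", "add_3a", "act_3a", "dense"]
inbound = {
    "act_prev": ["input"],
    "res3a_branch2a": ["act_prev"],
    "res3a_branch2b": ["res3a_branch2a"],
    "res3a_branch2c": ["res3a_branch2b"],
    "res3a_branch1": ["act_prev"],  # shortcut around the three branch2 layers
    "add_3a": ["res3a_branch2c", "res3a_branch1"],
    "act_3a": ["add_3a"],
    "dense": ["act_3a"],
}
```

On this toy block, clean_split_points(order, inbound) returns input, act_prev, add_3a, act_3a and dense. The branch layers are excluded because some later layer (for example the shortcut res3a_branch1) still reads from a layer before the cut. Cutting after add_3a corresponds to starting the second sub-model at the activation that follows the merge.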

Solution 2:[2]

I had a similar problem when slicing an Inception CNN for transfer learning, where I wanted to make only the layers after a certain point trainable.

def get_layers_above(cutoff_layer,model):

  def get_next_level(layer,model):
    def wrap_list(val):
      if type(val) is list:
        return val
      return [val] 
    r=[]
    for output_t in wrap_list(layer.output):
      r+=[x for x in model.layers if output_t.name in [y.name for y in wrap_list(x.input)]]
    return r

  visited=set()
  to_visit=set([cutoff_layer])

  while to_visit:
    layer=to_visit.pop()
    if layer in visited:  # already expanded via another branch
      continue
    visited.add(layer)
    to_visit.update(get_next_level(layer,model))
  return list(visited)

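I went with an iterative instead of a recursive solution because a set-based traversal seems like a safer choice for a network with many converging branches. (Strictly speaking it is not breadth-first, since set.pop() returns an arbitrary element, but the visiting order does not matter here.)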

It can be used like this (with InceptionV3, for example):

import tensorflow as tf

model = tf.keras.applications.InceptionV3(include_top=False,weights='imagenet',input_shape=(299,299,3))
layers=get_layers_above(model.get_layer('mixed9'),model)
print([l.name for l in layers])

Output:

 ['batch_normalization_89',
 'conv2d_93',
 'activation_86',
 'activation_91',
 'mixed10',
 'activation_88',
 'batch_normalization_85',
 'activation_93',
 'batch_normalization_90',
 'conv2d_87',
 'conv2d_86',
 'batch_normalization_86',
 'activation_85',
 'conv2d_91',
 'batch_normalization_91',
 'batch_normalization_87',
 'activation_90',
 'mixed9',
 'batch_normalization_92',
 'batch_normalization_88',
 'activation_87',
 'concatenate_1',
 'activation_89',
 'conv2d_88',
 'conv2d_92',
 'average_pooling2d_8',
 'activation_92',
 'mixed9_1',
 'conv2d_89',
 'conv2d_85',
 'conv2d_90',
 'batch_normalization_93']
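The traversal in get_layers_above does not depend on Keras itself. Below is a minimal pure-Python sketch of the same iterative worklist idea on a hypothetical consumers dict (layer name to the names of the layers that read its output); layers_above is an illustrative name, not part of the answer. The visited check means a layer reached through several converging branches is expanded only once.

```python
def layers_above(cutoff, consumers):
    """Return `cutoff` plus every layer reachable from it (iterative worklist)."""
    visited = set()
    to_visit = {cutoff}
    while to_visit:
        name = to_visit.pop()  # arbitrary element; the order does not matter
        if name in visited:  # already reached through another branch
            continue
        visited.add(name)
        to_visit.update(consumers.get(name, []))
    return visited


# Hypothetical toy graph: layer name -> names of the layers that consume its output.
consumers = {
    "mixed8": ["conv_1x1", "conv_3x3", "pool"],
    "conv_1x1": ["mixed9"],
    "conv_3x3": ["mixed9"],
    "pool": ["mixed9"],
    "mixed9": ["conv_a", "conv_b"],
    "conv_a": ["mixed10"],
    "conv_b": ["mixed10"],
    "mixed10": [],
}

above = layers_above("mixed9", consumers)
# transfer learning: only the cutoff layer and everything after it stay trainable
trainable = {name: name in above for name in consumers}
```

With the real function, the transfer-learning step would then amount to setting each layer's trainable attribute according to whether it appears in the returned list.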

Solution 3:[3]

If there is a layer at index middle that is connected only to the previous layer (index middle - 1), and none of the layers after it are connected directly to layers before it, we can use the fact that every model stores its layers as a list and create the two partial models this way:

import keras

# first part: from the original input up to layer middle - 1
model1 = keras.models.Model(inputs=model.input, outputs=model.layers[middle - 1].output)

# second part: a fresh input shaped like the output of layer middle - 1
new_input = keras.Input(shape=model.layers[middle - 1].output_shape[1:])
# layers is a dict in the form {name : output}
layers = {}
layers[model.layers[middle - 1].name] = new_input
for layer in model.layers[middle:]:
    # look up this layer's inputs in the new graph by the original layer names
    if isinstance(layer.input, list):
        x = []
        for layer_input in layer.input:
            x.append(layers[layer_input.name.split('/')[0]])
    else:
        x = layers[layer.input.name.split('/')[0]]
    y = layer(x)
    layers[layer.name] = y
model2 = keras.Model(inputs=[new_input], outputs=[y])

Then it is easy to check that model2.predict(model1.predict(x)) gives the same results as model.predict(x).
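The same split-and-check can be sketched without Keras. Here a hypothetical toy "model" is a list of (name, function, inbound names) entries in topological order; run_from is an illustrative helper that mirrors the loop above, storing outputs in a dict keyed by layer name and looking up each layer's inputs there. The middle index is chosen so that no layer after the cut reads from a layer before it.

```python
# Hypothetical toy model: (name, function, names of inbound layers), in topological order.
MODEL = [
    ("input", None, []),
    ("dense_1", lambda x: 3 * x + 1, ["input"]),
    ("branch_a", lambda x: x - 2, ["dense_1"]),
    ("branch_b", lambda x: 2 * x, ["dense_1"]),
    ("add", lambda xs: xs[0] + xs[1], ["branch_a", "branch_b"]),
    ("out", lambda x: x * x, ["add"]),
]


def run_from(layers, start_name, value):
    """Treat `value` as the output of `start_name` and run the given layers."""
    outputs = {start_name: value}  # {layer name: output}, as in the answer
    y = value
    for name, fn, inbound in layers:
        xs = [outputs[p] for p in inbound]
        y = fn(xs[0] if len(xs) == 1 else xs)
        outputs[name] = y
    return y


middle = 2  # dense_1 (index 1) is the only layer whose output crosses the cut


def full(x):
    return run_from(MODEL[1:], "input", x)


def part1(x):
    return run_from(MODEL[1:middle], "input", x)


def part2(h):
    return run_from(MODEL[middle:], MODEL[middle - 1][0], h)
```

Here part2(part1(x)) equals full(x) for any input, e.g. both give 361 for x = 2. With middle = 3, by contrast, part2 would fail with a KeyError, because branch_b still needs the output of dense_1 from the first part.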

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1
Solution 2 FlashDD
Solution 3