Combine the outputs of a sub-model's intermediate layer and a parent model
I am trying to get a toy example working. There is a simple sub-model:
Model: "sub_model"
_________________________________________________________________
Layer (type)                Output Shape              Param #
=================================================================
sub_middle (Dense)          (None, 100)               12900
sub_last (Dense)            (None, 100)               10100
=================================================================
which is embedded in a parent model:
Model: "model"
_________________________________________________________________
Layer (type)                Output Shape              Param #
=================================================================
sub_model (Sequential)      (None, 100)               23000
last (Dense)                (None, 2)                 202
=================================================================
Is it possible to create a third model that outputs both the sub_middle layer's output and the parent model's output?
Motivation: The sub-model could be a predefined CNN such as VGG, with a custom head implemented in the parent model. I would like to create the third model to compute the gradients of the parent model's output with respect to the output of an arbitrary layer inside the predefined CNN.
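For context, this kind of gradient is usually computed with tf.GradientTape on a model that returns both the intermediate activation and the final prediction (the same pattern as the Grad-CAM example on keras.io). The sketch below is a minimal illustration on a hypothetical flat functional model, so that the intermediate tensor is trivially reachable; the layer sizes mirror the toy example, and picking class index 0 is an arbitrary assumption.

```python
import tensorflow as tf

# Hypothetical flat stand-in for "backbone + custom head" (no nesting), so the
# intermediate tensor and the final output live in one connected graph.
inputs = tf.keras.Input(shape=(128,))
middle = tf.keras.layers.Dense(100, name="sub_middle")(inputs)
hidden = tf.keras.layers.Dense(100, name="sub_last")(middle)
outputs = tf.keras.layers.Dense(2, name="last")(hidden)

# The "third model": returns the intermediate activation and the prediction.
grad_model = tf.keras.Model(inputs, [middle, outputs])

x = tf.random.normal((4, 128))
with tf.GradientTape() as tape:
    middle_out, preds = grad_model(x)
    score = preds[:, 0]  # class index 0, chosen arbitrarily for illustration
# Gradient of the (implicitly batch-summed) score w.r.t. the intermediate activation.
grads = tape.gradient(score, middle_out)
print(grads.shape)  # (4, 100)
```

No tape.watch call is needed here: middle_out is produced by operations on trainable weights inside the tape, so the path from it to score is recorded.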
The example code:
import tensorflow as tf
import numpy as np

sub_model = tf.keras.Sequential(
    [
        tf.keras.layers.Input(shape=(128,)),
        tf.keras.layers.Dense(100, name='sub_middle'),
        tf.keras.layers.Dense(100, name='sub_last'),
    ],
    name="sub_model",
)

model = tf.keras.Sequential(
    [
        tf.keras.layers.Input(shape=(128,), name='input'),
        sub_model,
        tf.keras.layers.Dense(2, name='last'),
    ],
    name="model",
)

# Works fine: sub_model's own input -> its intermediate sub_middle output
grad_model = tf.keras.models.Model(
    [model.get_layer('sub_model').input],
    [model.get_layer('sub_model').get_layer('sub_middle').output],
)

# Works fine: parent input -> parent input and parent output
grad_model = tf.keras.models.Model([model.input], [model.input, model.output])

# ValueError: Graph disconnected
grad_model = tf.keras.models.Model(
    [model.get_layer('sub_model').input],
    [model.get_layer('sub_model').get_layer('sub_middle').output, model.output],
)
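The error arises because model.output was traced from the parent's own Input (the one named 'input'), while model.get_layer('sub_model').input is the separate Input declared inside sub_model; no path connects the two. One possible workaround, sketched here under the assumption of TF 2.x tf.keras as in the question (not checked against every Keras release), is to call the parent's head layer last on the sub-model's own output, so both requested outputs are traced from the same input. Calling an existing layer on a new tensor reuses its weights, so nothing is copied.

```python
import numpy as np
import tensorflow as tf

# Same toy models as in the question (shape passed as a tuple).
sub_model = tf.keras.Sequential(
    [
        tf.keras.layers.Input(shape=(128,)),
        tf.keras.layers.Dense(100, name="sub_middle"),
        tf.keras.layers.Dense(100, name="sub_last"),
    ],
    name="sub_model",
)
model = tf.keras.Sequential(
    [
        tf.keras.layers.Input(shape=(128,), name="input"),
        sub_model,
        tf.keras.layers.Dense(2, name="last"),
    ],
    name="model",
)

sub = model.get_layer("sub_model")
# Re-apply the parent's (already built) head layer to the sub-model's own
# output. This adds a new node to `last` that shares its weights, but whose
# input traces back to sub_model's Input instead of the parent's Input.
head_out = model.get_layer("last")(sub.outputs[0])

grad_model = tf.keras.Model(
    inputs=sub.inputs[0],
    outputs=[sub.get_layer("sub_middle").output, head_out],
)

x = np.random.rand(4, 128).astype("float32")
middle_out, preds = grad_model(x)
print(middle_out.shape, preds.shape)  # (4, 100) (4, 2)
```

Since last is shared, the second output of grad_model should match model(x) for the same input, and grad_model can be used inside a tf.GradientTape to get gradients of the head's output with respect to sub_middle. With a deeper head, each head layer after sub_model would need to be re-applied in order the same way.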
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow