Duplicate layers when reusing a PyTorch model
I am trying to reuse some of the ResNet layers for a custom architecture and ran into an issue I can't figure out. Here is a simplified example; when I run:
import torch
import torch.nn as nn
from torchsummary import summary
from torchvision import models
def convrelu(in_channels, out_channels, kernel, padding):
    """Build a convolution followed by an in-place ReLU.

    Args:
        in_channels: number of input channels of the convolution.
        out_channels: number of output channels of the convolution.
        kernel: kernel size, passed straight through to ``nn.Conv2d``.
        padding: padding, passed straight through to ``nn.Conv2d``.

    Returns:
        ``nn.Sequential(nn.Conv2d(...), nn.ReLU(inplace=True))``.
    """
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel, padding=padding),
        nn.ReLU(inplace=True),
    )
class ResNetUNet(nn.Module):
    """Minimal repro: reuse the first three children of a ResNet-18 as ``layer0``.

    Because ``base_model`` is stored as an attribute, the whole ResNet is
    registered as a child module, and its first three layers are registered a
    second time through ``layer0``. Anything that recursively walks the
    children (such as ``torchsummary.summary``) therefore reaches those layers
    through two different parents.
    """

    def __init__(self):
        super().__init__()
        self.base_model = models.resnet18(pretrained=False)
        # Plain Python list of the ResNet's top-level children; a list is not
        # registered as a submodule container.
        self.base_layers = list(self.base_model.children())
        # The first three children: Conv2d, BatchNorm2d, ReLU.
        self.layer0 = nn.Sequential(*self.base_layers[:3])

    def forward(self, x):
        """Print the input shape (debug aid) and apply ``layer0`` to ``x``."""
        print(x.shape)
        output = self.layer0(x)
        return output
# Instantiate the network, move it onto the GPU (Module.cuda() returns the
# module itself), then print its per-layer summary for a 3x224x224 input.
base_model = ResNetUNet()
base_model = base_model.cuda()
summary(base_model, (3, 224, 224))
I get:
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 112, 112] 9,408
Conv2d-2 [-1, 64, 112, 112] 9,408
BatchNorm2d-3 [-1, 64, 112, 112] 128
BatchNorm2d-4 [-1, 64, 112, 112] 128
ReLU-5 [-1, 64, 112, 112] 0
ReLU-6 [-1, 64, 112, 112] 0
================================================================
Total params: 19,072
Trainable params: 19,072
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 36.75
Params size (MB): 0.07
Estimated Total Size (MB): 37.40
----------------------------------------------------------------
This duplicates each layer (there are 2 convs, 2 batchnorms, 2 ReLUs) instead of listing each layer once. If I print out `self.base_layers[:3]`
I get:
[Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False), BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), ReLU(inplace=True)]
which shows just three layers without duplicates. Why is it duplicating my layers?
I am using PyTorch version 1.4.0.
Solution 1:[1]
Your layers aren't actually being invoked twice. This is an artifact of how `summary` is implemented.
The simple reason is that `summary` recursively iterates over all the children of your module and registers a forward hook for each of them. Since you have repeated children (in both `base_model` and `layer0`), those repeated modules get multiple hooks registered. When `summary` runs the forward pass, every hook on each module is invoked, so the shared layers are reported more than once.
For your toy example, a solution is simply not to assign `base_model` as an attribute, since it isn't used during `forward` anyway. This prevents `base_model` from ever being added as a child.
class ResNetUNet(nn.Module):
    """Reuse the first three children of a ResNet-18 without registering the ResNet.

    The torchvision model is held only in local variables inside ``__init__``,
    so it never becomes a child of this module. Each reused layer is therefore
    reachable through a single parent (``layer0``), and a recursive walk of the
    children visits it only once.
    """

    def __init__(self):
        super().__init__()
        base_model = models.resnet18(pretrained=False)
        base_layers = list(base_model.children())
        # The first three children: Conv2d, BatchNorm2d, ReLU.
        self.layer0 = nn.Sequential(*base_layers[:3])

    def forward(self, x):
        """Apply the reused ResNet stem (``layer0``) to ``x``."""
        return self.layer0(x)
Another solution is to create a modified version of `summary` that doesn't register hooks on the same module multiple times. Below is an augmented `summary` that uses a set named `already_registered` to keep track of modules that already have hooks, so that no module gets more than one.
from collections import OrderedDict
import torch
import torch.nn as nn
import numpy as np
def summary(model, input_size, batch_size=-1, device="cuda"):
    """Print a per-layer table of output shapes and parameter counts for ``model``.

    A forward hook is attached to every submodule that is not an
    ``nn.Sequential``/``nn.ModuleList`` and is not ``model`` itself. Then one
    forward pass is run on random input (batch of 2), and the recorded rows are
    printed, followed by rough memory estimates (4 bytes per number).

    A submodule reachable through several parents (e.g. a layer that belongs
    both to a stored backbone and to an ``nn.Sequential`` built from that
    backbone's children) is hooked only once, so it is reported only once.

    Args:
        model: the module to summarize; it must already live on ``device``.
        input_size: shape of one input sample without the batch dimension,
            e.g. ``(3, 224, 224)``, or a list of such tuples for models that
            take several inputs.
        batch_size: value written into the batch position of the reported
            shapes and used for the input-size estimate (``-1`` = unspecified).
        device: ``"cuda"`` or ``"cpu"`` (case-insensitive). It selects the
            tensor type of the random input; CPU tensors are used when CUDA is
            unavailable.

    Returns:
        None. The summary is printed to stdout.

    Raises:
        AssertionError: if ``device`` is neither ``"cuda"`` nor ``"cpu"``.
    """
    # Per-layer records, filled in by the hooks in the order layers execute.
    layer_info = OrderedDict()
    hooks = []
    # keep track of registered modules so that we don't add multiple hooks
    already_registered = set()

    def register_hook(module):
        def hook(module, inputs, output):
            m_key = "%s-%i" % (type(module).__name__, len(layer_info) + 1)
            entry = OrderedDict()
            layer_info[m_key] = entry
            entry["input_shape"] = list(inputs[0].size())
            entry["input_shape"][0] = batch_size
            if isinstance(output, (list, tuple)):
                entry["output_shape"] = [[-1] + list(o.size())[1:] for o in output]
            else:
                entry["output_shape"] = list(output.size())
                entry["output_shape"][0] = batch_size

            # Plain ints, so totals stay ints even when no layer has weights.
            params = 0
            if hasattr(module, "weight") and hasattr(module.weight, "size"):
                params += module.weight.numel()
                entry["trainable"] = module.weight.requires_grad
            if hasattr(module, "bias") and hasattr(module.bias, "size"):
                params += module.bias.numel()
            entry["nb_params"] = params

        # model.apply() visits a module once per parent that holds it; hooking
        # only on the first visit prevents duplicate rows.
        if (
            not isinstance(module, (nn.Sequential, nn.ModuleList))
            and module is not model
            and module not in already_registered
        ):
            already_registered.add(module)
            hooks.append(module.register_forward_hook(hook))

    device = device.lower()
    assert device in [
        "cuda",
        "cpu",
    ], "Input device is not valid, please specify 'cuda' or 'cpu'"

    if device == "cuda" and torch.cuda.is_available():
        dtype = torch.cuda.FloatTensor
    else:
        dtype = torch.FloatTensor

    # multiple inputs to the network
    if isinstance(input_size, tuple):
        input_size = [input_size]

    # batch_size of 2 for batchnorm
    x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size]

    model.apply(register_hook)
    try:
        model(*x)
    finally:
        # Always detach the hooks so the caller's model is left unmodified,
        # even when the forward pass raises.
        for h in hooks:
            h.remove()

    print("----------------------------------------------------------------")
    line_new = "{:>20} {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #")
    print(line_new)
    print("================================================================")
    total_params = 0
    total_output = 0
    trainable_params = 0
    for layer, entry in layer_info.items():
        line_new = "{:>20} {:>25} {:>15}".format(
            layer,
            str(entry["output_shape"]),
            "{0:,}".format(entry["nb_params"]),
        )
        total_params += entry["nb_params"]
        shape = entry["output_shape"]
        # Multi-output layers store a list of shapes; count each output.
        # abs() neutralises the -1 placeholder in the batch position.
        shapes = shape if shape and isinstance(shape[0], list) else [shape]
        total_output += sum(abs(int(np.prod(s))) for s in shapes)
        if entry.get("trainable"):
            trainable_params += entry["nb_params"]
        print(line_new)

    # assume 4 bytes/number (float on cuda).
    # Sum element counts per input (np.prod over the whole list would multiply
    # the shapes of different inputs together).
    input_elems = sum(int(np.prod(in_size)) for in_size in input_size)
    total_input_size = abs(input_elems * batch_size * 4. / (1024 ** 2.))
    total_output_size = abs(2. * total_output * 4. / (1024 ** 2.))  # x2 for gradients
    total_params_size = abs(total_params * 4. / (1024 ** 2.))
    total_size = total_params_size + total_output_size + total_input_size

    print("================================================================")
    print("Total params: {0:,}".format(total_params))
    print("Trainable params: {0:,}".format(trainable_params))
    print("Non-trainable params: {0:,}".format(total_params - trainable_params))
    print("----------------------------------------------------------------")
    print("Input size (MB): %0.2f" % total_input_size)
    print("Forward/backward pass size (MB): %0.2f" % total_output_size)
    print("Params size (MB): %0.2f" % total_params_size)
    print("Estimated Total Size (MB): %0.2f" % total_size)
    print("----------------------------------------------------------------")
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
Solution | Source |
---|---|
Solution 1 |