How to add "Crop" in order to concatenate the skip connections of the Encoder and Decoder levels as described in the UNET paper

I have implemented the UNET paper in the code below. Here is the architecture:

[Image: U-Net architecture diagram]

The problem is that at level 4 the Encoder produces features of shape 512 x 64 x 64, while the Decoder produces features of shape 512 x 56 x 56. Looking closely, I found the gray "copy and crop" arrows in the figure, but the paper does not explain how the crop is done. It only mentions cropping twice:
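The mismatch comes from the unpadded 3x3 convolutions: each one trims one pixel from every border. A plain-Python sketch that traces the height (= width) through the network, assuming the question's layout (four levels, two 3x3 convolutions with padding=0 per block, 2x2 max pooling, 2x2 transposed convolutions); `unet_valid_sizes` is a hypothetical helper, not part of any library:

```python
def unet_valid_sizes(input_size=572, levels=4):
    """Trace H (= W) through a U-Net whose 3x3 convolutions use padding=0."""
    size = input_size
    skip_sizes = []                  # spatial size of each stored skip connection
    for _ in range(levels):
        size -= 4                    # two 3x3 convs, each trims 1 pixel per border
        skip_sizes.append(size)
        size //= 2                   # 2x2 max pooling with stride 2
    size -= 4                        # BottleNeck double conv
    bottleneck = size
    pairs = []                       # (skip size, upsampled decoder size) per level
    for skip in reversed(skip_sizes):
        size *= 2                    # 2x2 transposed conv with stride 2
        pairs.append((skip, size))
        size -= 4                    # double conv after concatenation
    return skip_sizes, bottleneck, pairs, size


skips, bottom, pairs, out = unet_valid_sizes()
print(skips)   # [568, 280, 136, 64]
print(bottom)  # 28
print(pairs)   # [(64, 56), (136, 104), (280, 200), (568, 392)]
print(out)     # 388
```

The first pair, (64, 56), is exactly the level-4 mismatch from the question. Each skip connection would need to be cropped to the second number of its pair before `torch.cat`, and the final size of 388 matches the 388 x 388 output map in the paper's figure.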

Every step in the expansive path consists of an upsampling of the feature map followed by a 2x2 convolution (“up-convolution”) that halves the number of feature channels, a concatenation with the correspondingly cropped feature map from the contracting path, and two 3x3 convolutions, each followed by a ReLU. The cropping is necessary due to the loss of border pixels in every convolution.
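One reading of "correspondingly cropped" is a centre crop: cut the middle window of the encoder map so that its height and width match the upsampled decoder map. The paper does not spell out the crop position, so centring is an assumption here (a natural one, since each unpadded convolution trims all borders evenly). `center_crop` is a hypothetical helper name, not a PyTorch function:

```python
import torch


def center_crop(feature_map: torch.Tensor, target_h: int, target_w: int) -> torch.Tensor:
    """Cut the central target_h x target_w window out of an (N, C, H, W) tensor."""
    _, _, h, w = feature_map.shape
    top = (h - target_h) // 2    # rows dropped above the window
    left = (w - target_w) // 2   # columns dropped left of the window
    return feature_map[:, :, top:top + target_h, left:left + target_w]


# Level-4 shapes from the question: encoder skip 512x64x64, upsampled decoder map 512x56x56
skip = torch.randn(1, 512, 64, 64)
up = torch.randn(1, 512, 56, 56)

cropped = center_crop(skip, up.shape[2], up.shape[3])   # keeps rows/cols 4..59
merged = torch.cat((cropped, up), dim=1)                # channels stack: 512 + 512
print(tuple(merged.shape))
```

Because plain slicing is used, the crop is differentiable and adds no parameters; gradients simply flow back only into the kept window of the skip connection.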

Could someone please explain how I can make them compatible? The code up to the BottleNeck works perfectly, but since I am stuck at the cropping step, I could not test the logic and code flow of the Decoder.
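If matching the paper's unpadded convolutions is not a requirement, a common substitute technique (not what the paper does) is to use padding=1 in every 3x3 convolution, as the padding comment in ConvolutionBlock hints. Each block then keeps H and W, the upsampled map already has the skip's size, and `torch.cat` works with no crop. A minimal one-level sketch with small, made-up channel counts; `same_double_conv` is a hypothetical name:

```python
import torch
import torch.nn as nn


def same_double_conv(in_ch, out_ch):
    # padding=1 makes each 3x3 conv keep H and W (unlike the paper's unpadded convs)
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), nn.ReLU(),
        nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1), nn.ReLU(),
    )


x = torch.randn(1, 1, 64, 64)                   # even H and W so the max pool halves exactly
skip = same_double_conv(1, 8)(x)                # (1, 8, 64, 64): stored for the skip connection
down = nn.MaxPool2d(2)(skip)                    # (1, 8, 32, 32)
bottom = same_double_conv(8, 16)(down)          # (1, 16, 32, 32)
up = nn.ConvTranspose2d(16, 8, kernel_size=2, stride=2)(bottom)   # (1, 8, 64, 64)
merged = torch.cat((skip, up), dim=1)           # sizes already match, no crop needed
print(tuple(merged.shape))
```

With four levels the input height and width should be divisible by 16, and the output keeps the input size instead of shrinking to 388 x 388 for a 572 x 572 input.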

'''
The whole UNet is divided into 3 parts: Encoder -> BottleNeck -> Decoder. There are skip connections between the Nth level of the Encoder and the Nth level of the Decoder.

The basic entity is the "ConvolutionBlock": 3x3 Convolution -> ReLU -> 3x3 Convolution -> ReLU (upsampling in the Decoder uses a 2x2 Transposed Convolution)
Between Encoder levels there is 2x2 Max Pooling
'''

import torch
import torch.nn as nn

class ConvolutionBlock(nn.Module):
    '''
    The basic Convolution Block: Convolution -> ReLU -> Convolution -> ReLU
    '''
    def __init__(self, input_features, out_features):
        '''
        args:
            input_features: number of channels entering the block
            out_features: number of channels produced by both convolutions
        note:
            BatchNorm was introduced after UNET, so the paper does not use it. It might still be useful
        '''
        super().__init__()
        self.network = nn.Sequential(
            nn.Conv2d(input_features, out_features, kernel_size = 3, padding = 0), # padding = 0 (the default) shrinks H and W by 2; padding = 1 would keep them unchanged
            nn.ReLU(),
            nn.Conv2d(out_features, out_features, kernel_size = 3, padding = 0),
            nn.ReLU(),
            )

    def forward(self, feature_map_x):
        '''
        feature_map_x could be the image itself or the output of the previous block
        '''
        return self.network(feature_map_x)


class Encoder(nn.Module):
    '''
    Contracting path: one ConvolutionBlock per level, each followed by 2x2 max pooling
    '''
    def __init__(self, image_channels:int = 3, blockwise_features = (64, 128, 256, 512)):
        '''
        In UNET, the number of features starts at 64 and doubles at every level until the BottleNeck
        args:
            image_channels: channels in the input image, typically 1 or 3 (rarely 4)
            blockwise_features: output features of each block, i.e. the first ConvolutionBlock outputs 64 features, the second 128 and so on
        '''
        super().__init__()
        repeat = len(blockwise_features) # one ConvolutionBlock per entry in blockwise_features

        self.layers = nn.ModuleList()

        for i in range(repeat):
            if i == 0:
                in_filters = image_channels
                out_filters = blockwise_features[0]
            else:
                in_filters = blockwise_features[i-1]
                out_filters = blockwise_features[i]

            self.layers.append(ConvolutionBlock(in_filters, out_filters))

        self.maxpool = nn.MaxPool2d(kernel_size = 2, stride = 2)  # MaxPool2d has no learnable parameters, so a single instance can be reused at every level
        # https://datascience.stackexchange.com/questions/11699/backprop-through-max-pooling-layers

    def forward(self, feature_map_x):
        skip_connections = [] # the i-th level of Encoder features is concatenated with the i-th level of the Decoder before its ConvolutionBlock

        for layer in self.layers:
            feature_map_x = layer(feature_map_x)
            skip_connections.append(feature_map_x)
            feature_map_x = self.maxpool(feature_map_x) # apply max pooling AFTER storing the skip connection

        return feature_map_x, skip_connections


class BottleNeck(nn.Module):
    '''
    ConvolutionBlock without Max Pooling
    '''
    def __init__(self, input_features = 512, output_features = 1024):
        super().__init__()
        self.layer = ConvolutionBlock(input_features, output_features)

    def forward(self, feature_map_x):
        return self.layer(feature_map_x)


class Decoder(nn.Module):
    '''
    Expansive path: the mirror image of the Encoder
    '''
    def __init__(self, blockwise_features = (512, 256, 128, 64)):
        '''
        Do exactly the opposite of the Encoder
        '''
        super().__init__()

        self.upsample_layers = nn.ModuleList()
        self.conv_layers = nn.ModuleList()

        for feature in blockwise_features:

            self.upsample_layers.append(nn.ConvTranspose2d(in_channels = feature*2, out_channels = feature, kernel_size = 2, stride = 2))  # 1024 -> 512, then 512 -> 256, ...

            self.conv_layers.append(ConvolutionBlock(feature*2, feature)) # after concatenating (512 + 512 -> 1024), apply the double-conv block

    def forward(self, feature_map_x, skip_connections):
        '''
        Steps go as:
        1. Upsample
        2. Concatenate the skip connection
        3. Apply ConvolutionBlock
        '''

        for i, conv_block in enumerate(self.conv_layers): # 4 levels: 4 skip connections, 4 upsamplings, 4 double-conv blocks

            feature_map_x = self.upsample_layers[i](feature_map_x) # step 1
            feature_map_x = torch.cat((skip_connections[-i-1], feature_map_x), dim = 1) # step 2: fails without cropping (64x64 skip vs 56x56 upsampled map at the first level)
            feature_map_x = conv_block(feature_map_x) # step 3

        return feature_map_x


# ------------------- TESTING CODE -------------------
image = torch.randn(1, 1, 572, 572) # batch of 1 grayscale image, as described in the paper
enc = Encoder(1)
feat, skip = enc(image)   # feat: (1, 512, 32, 32)
bot = BottleNeck()
feat = bot(feat)          # feat: (1, 1024, 28, 28)

# dec = Decoder()
# feats = dec(feat, skip) # the error starts here, in the Decoder's torch.cat

# -----------------------------
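A sketch of the Decoder with cropping added, assuming a centre crop and assuming the double-conv block is what the "Use double Conv block" comment intended for `conv_layers`. `CroppingDecoder` and `double_conv` are hypothetical names, and the channel widths in the shape check are shrunk so it runs quickly on a CPU:

```python
import torch
import torch.nn as nn


def double_conv(in_ch: int, out_ch: int) -> nn.Sequential:
    # Same layout as the question's ConvolutionBlock: two unpadded 3x3 convs, each followed by ReLU
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=0),
        nn.ReLU(),
        nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=0),
        nn.ReLU(),
    )


class CroppingDecoder(nn.Module):
    def __init__(self, blockwise_features=(512, 256, 128, 64)):
        super().__init__()
        self.upsample_layers = nn.ModuleList()
        self.conv_layers = nn.ModuleList()
        for feature in blockwise_features:
            # 2x2 up-convolution: doubles H and W, halves the channels (1024 -> 512 at the first level)
            self.upsample_layers.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
            # after concatenation there are feature (skip) + feature (upsampled) channels
            self.conv_layers.append(double_conv(feature * 2, feature))

    @staticmethod
    def crop(skip: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Centre-crop skip (N, C, H, W) to the H and W of target
        h, w = target.shape[2], target.shape[3]
        top = (skip.shape[2] - h) // 2
        left = (skip.shape[3] - w) // 2
        return skip[:, :, top:top + h, left:left + w]

    def forward(self, x, skip_connections):
        # skip_connections is ordered shallowest -> deepest (as the question's Encoder returns it),
        # so it is walked in reverse to pair the deepest skip with the first upsampling
        for up, conv, skip in zip(self.upsample_layers, self.conv_layers, reversed(skip_connections)):
            x = up(x)                                      # step 1: upsample
            x = torch.cat((self.crop(skip, x), x), dim=1)  # step 2: crop the skip, then concatenate
            x = conv(x)                                    # step 3: double conv
        return x


# Shape check with shrunken channel counts (the real model uses 512, 256, 128, 64)
dec = CroppingDecoder(blockwise_features=(8, 4, 2, 1))
bottleneck = torch.randn(1, 16, 28, 28)
skips = [torch.randn(1, c, s, s) for c, s in zip((1, 2, 4, 8), (568, 280, 136, 64))]
out = dec(bottleneck, skips)
print(tuple(out.shape))  # (1, 1, 388, 388)
```

With the default widths, the question's `feat` and `skip` could be passed straight in as `CroppingDecoder()(feat, skip)`, which should give a tensor of shape (1, 64, 388, 388); the paper's final 1x1 convolution to class scores would come after it.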




Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
