Some parameters are not getting saved when saving a model in PyTorch

I have built an encoder-decoder model with attention for morphological inflection generation. I can train the model and predict on test data, but after loading a saved model the predictions are completely wrong. There is no error during saving or loading; it looks like some parameters are not being saved.

I have tried to save and load the model using both techniques (a minimal sketch of each is shown below):

  1. using state_dict(), e.g. torch.save(encoder.state_dict(), 'path')

  2. saving the complete model, e.g. torch.save(encoder, 'path')
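
For reference, a minimal sketch of the two techniques (the file names are placeholders, and encoder is assumed to be an instance of the Encoder class shown below):

import torch

# technique 1: save/load only the parameters (recommended)
torch.save(encoder.state_dict(), 'encoder.pt')
encoder.load_state_dict(torch.load('encoder.pt'))

# technique 2: pickle the whole module, class definition and all
torch.save(encoder, 'encoder_full.pt')
encoder = torch.load('encoder_full.pt')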

I have tried saving the different classes one by one, and also making a superclass that instantiates all of those classes and then saving just the superclass, but nothing seems to work.

Encoder class

import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


class Encoder(nn.Module):
    def __init__(self, vocab_size, embedding_size, encoder_hid_dem, decoder_hid_dem, bidirectional, dropout):
        super().__init__()

        self.encoder_hid_dem = encoder_hid_dem
        self.encoder_n_direction = 2 if bidirectional else 1
        self.bias = False

        self.embedding_layer = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_size, padding_idx=0)
        self.GRU_layer = nn.GRU(input_size=embedding_size, hidden_size=encoder_hid_dem, batch_first=True, bidirectional=bidirectional)
        self.fc = nn.Linear(encoder_hid_dem * self.encoder_n_direction, decoder_hid_dem)
        self.dropout = nn.Dropout(dropout)


    def forward(self, input_word):
        # print(input_word.size())
        #[batch_size    src_sent_lent]

        embed_out = self.embedding_layer(input_word)
        #[BATCH_SIZE    src_sent_lent   embedding_dim]

        embed_out = F.relu(embed_out)
        embed_out = self.dropout(embed_out)

        self.batch = embed_out.size()[0]

        # hidden =  self.init_hidden()
        GRU_out,hidden = self.GRU_layer(embed_out)


        # print(GRU_out.size())
        # print(hidden.size())

        #[BATCH_SIZE    src_sent_len    n_direction*hid_dem]
        #[n_layer*n_direction   batch_size    hid_dem]

        #where the first hid_dim elements in the third axis are the hidden states from the top layer forward RNN, and the last hid_dim elements are hidden states from the top layer backward RNN

        #hidden is stacked [forward_1, backward_1, forward_2, backward_2, ...]
        #hidden [-2, :, : ] is the last of the forwards RNN 
        #hidden [-1, :, : ] is the last of the backwards RNN

        GRU_out = F.relu(GRU_out)
        hidden = torch.tanh(self.fc(torch.cat((hidden[-2,:,:],hidden[-1,:,:]),dim=1)))

        # print(GRU_out.size())
        # print(hidden.size())

        #outputs = [batch_size    src sent len, encoder_hid_dim * n_direction]
        #hidden = [batch size, dec hid dim]
        return GRU_out,hidden

    def init_hidden(self):
        # Unused (forward() lets the GRU default to a zero initial state).
        # Note: self.device is never set on this module, so calling this would raise AttributeError.
        return Variable(torch.eye(1, self.encoder_hid_dem)).unsqueeze(1).repeat(2, self.batch, 1).to(self.device)
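
As a quick sanity check of the encoder shapes, here is a small sketch with placeholder hyperparameters (none of the numbers come from the original post):

enc = Encoder(vocab_size=50, embedding_size=300, encoder_hid_dem=256,
              decoder_hid_dem=256, bidirectional=True, dropout=0.1)
x = torch.randint(1, 50, (4, 7))   # [batch_size=4, src_sent_len=7]
out, hid = enc(x)
print(out.shape)                   # torch.Size([4, 7, 512]) -> [batch, src_len, enc_hid*2]
print(hid.shape)                   # torch.Size([4, 256])    -> [batch, dec_hid_dim]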

Attention class

class Attention(nn.Module):
    def __init__(self,encoder_hid_dem,decoder_hid_dem,bidirectional):
        super().__init__()
        self.enc_hid_dim = encoder_hid_dem
        self.dec_hid_dim = decoder_hid_dem
        self.encoder_n_direction = 2 if bidirectional else 1

        self.attn = nn.Linear((encoder_hid_dem * self.encoder_n_direction) + decoder_hid_dem, decoder_hid_dem)
        self.v = nn.Parameter(torch.rand(decoder_hid_dem))

    def forward(self, hidden, encoder_outputs):

        #hidden = [batch size, dec hid dim]
        #encoder_outputs = [batch_size    ,src sent len, enc hid dim * encoder_n_direction]

        batch_size = encoder_outputs.shape[0]
        src_len    = encoder_outputs.shape[1]

        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)

        #hidden          = [batch size, src sent len, dec hid dim]
        #encoder_outputs = [batch size, src sent len, enc hid dim * encoder_n_direction]

        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2))) 
        #energy = [batch size, src sent len, dec hid dim]

        energy = energy.permute(0, 2, 1)
        #energy = [batch size, dec hid dim, src sent len]

        #v = [dec hid dim]
        v = self.v.repeat(batch_size, 1).unsqueeze(1)
        #v = [batch size, 1, dec hid dim]

        attention = torch.bmm(v, energy).squeeze(1)
        #attention= [batch size, src len]

        return F.softmax(attention, dim=1)
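
Continuing the placeholder tensors from the encoder check above, the attention module maps a decoder hidden state plus the encoder outputs to one normalized weight per source position:

attn = Attention(encoder_hid_dem=256, decoder_hid_dem=256, bidirectional=True)
scores = attn(hid, out)   # hid: [4, 256], out: [4, 7, 512]
print(scores.shape)       # torch.Size([4, 7]); each row sums to 1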


Decoder class

class Decoder(nn.Module):
    def __init__(self, decoder_hid_dem, encoder_hid_dem, vocab_size,embedding_dim,attention,decoder_input_size,linear_input_size,bidirectional,dropout):
        super().__init__()      
        self.encoder_hid_dem = encoder_hid_dem
        self.decoder_hid_dem = decoder_hid_dem
        self.attention = attention
        self.output_dim = vocab_size

        self.decoder_n_direction = 2 if bidirectional else 1

        self.GRU_layer_out = nn.GRU(decoder_input_size,decoder_hid_dem)
        self.out_layer = nn.Linear(in_features=linear_input_size, out_features=vocab_size)
        self.dropout = nn.Dropout(dropout)
        #self.GRU_layer_out.bias = torch.nn.Parameter(torch.zeros(decoder_input_size))

    def forward(self, feature, hidden,actual_word,encoder_outputs):

        feature = feature.unsqueeze(1)
        # print('decoder')
        # print(feature.size())
        #[batch_size    src_sent_lent=1   feat_size=6]

        # print(hidden.size())
        # [batch_size     dec_hid_dim]


        # print(actual_word.size())
        # [batch_size    src_sent_lent=1   embedding_dim]

        # print(encoder_outputs.size())
        # outputs = [batch_size    src sent len, encoder_hid_dim * encoder_n_directional]


        a = self.attention(hidden,encoder_outputs)
        #  print(a.size())
        # [batch_size    src_sent_len]

        a = a.unsqueeze(1)
        #a = [batch size, 1, src len] 

        weighted = torch.bmm(a,encoder_outputs)
        # print(weighted.size())
        # weighted = [batch size, 1, enc_hid_dim * encoder_n_direction]
        # if len(actual_word.size()) != 0:
        input_char = torch.cat((actual_word,feature,weighted),2) 
        # else:
        #     input_char = torch.cat((feature,weighted),2)

        input_char=input_char.permute(1,0,2)
        #  print(input_char.size())
        # [1    BATCH_SIZE      decoder_input_size]

        hidden = hidden.unsqueeze(0)
        # print(hidden.size())
        #[1 batch_size decoder_hid_dem]

        output, hidden = self.GRU_layer_out(input_char, hidden)

        # print(output.size())
        # [sent_len=1   batch_size  decoder_n_direction*decoder_hid_dem]
        # print(hidden.size())
        # [n_layer*n_direction    BATCH_SIZE      hid_dem]



        output = F.leaky_relu(output)
        output = self.dropout(output)

        output = torch.cat((output.squeeze(0),weighted.squeeze(1),actual_word.squeeze(1)),dim=1)
        pre_out = self.out_layer(output)
        predicted_output = F.log_softmax(pre_out, dim=1)

        # print(predicted_output.size())
        # [ batch_size vacab_size ]
        return predicted_output, hidden.squeeze(0)  

    def init_hidden(self, batch):
        # Unused; also note it returns an (h, c) pair as for an LSTM,
        # while nn.GRU expects a single hidden tensor (and self.device is never set).
        return (Variable(torch.eye(1, self.decoder_hid_dem)).unsqueeze(1).repeat(1, batch, 1).to(self.device),
                Variable(torch.eye(1, self.decoder_hid_dem)).unsqueeze(1).repeat(1, batch, 1).to(self.device))
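
The two size arguments are the easiest place to go wrong, so here is a hedged sketch (continuing the placeholder numbers above) of how decoder_input_size and linear_input_size are derived: the GRU input is actual_word, feature and weighted concatenated (300 + 6 + 512), and the output layer input is output, weighted and actual_word concatenated (256 + 512 + 300):

dec = Decoder(decoder_hid_dem=256, encoder_hid_dem=256, vocab_size=50,
              embedding_dim=300, attention=attn,
              decoder_input_size=300 + 6 + 512,    # embedding + features + enc_hid*2
              linear_input_size=256 + 512 + 300,   # dec_hid + enc_hid*2 + embedding
              bidirectional=True, dropout=0.1)
feature = torch.rand(4, 6)           # 6 placeholder features per word
actual_word = torch.rand(4, 1, 300)  # previous character embedding
pred, hid2 = dec(feature, hid, actual_word, out)
print(pred.shape)   # torch.Size([4, 50]) -> [batch, vocab_size]
print(hid2.shape)   # torch.Size([4, 256])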



Seq2Seq class

class Seq2Seq(nn.Module):
    def __init__(self,encoder,decoder,device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self,input_word,output_word,features_word,teaching_forcing_ratio,limit):
        #print(input_word)
        #print(input_word.size())
        input_word = input_word.to(self.device)
        output_word = output_word.to(self.device)
        features_word = features_word.to(self.device)

        batch_size = input_word.size()[0]
        if limit == 0:
            max_len = input_word.size()[1]
        else:
            max_len = limit
        vocabsize = self.decoder.output_dim

        # char_to_index is a global char-to-id mapping defined elsewhere in the training script
        actual_word = self.encoder.embedding_layer(torch.tensor(char_to_index['<sos>']).view(1, -1).to(self.device)).repeat(batch_size, 1, 1)
        encoder_outputs,hidden = self.encoder(input_word)
        features=features_word[:,:]

        predicted_word = torch.zeros(max_len,batch_size,vocabsize).to(self.device)

        for t in range(1,max_len):
            output,hidden=self.decoder(features, hidden,actual_word,encoder_outputs)
            #print(output.size())
            predicted_word[t] = output 
            topv, topi = output.topk(1)
            bs = topi.size()[0]
            temp2 = torch.zeros(0, 1, 300).to(self.device)  # 300 must match embedding_size
            for row in range(bs):
                index = topi[row][0].item()
                temp = self.encoder.embedding_layer(torch.tensor(index).view(1, -1).to(self.device))
                temp2 = torch.cat((temp2,temp))

            teacher_force = random.random() < teaching_forcing_ratio
            if teacher_force:
                actual_word = self.encoder.embedding_layer(output_word[:, t]).unsqueeze(1)
            else:
                actual_word = temp2

        return predicted_word

and this code is used to save and load the model:

torch.save(model.state_dict(), 'model.pt')
model.load_state_dict(torch.load('model.pt'))

I want the model, when run with pre-trained weights, to predict correctly according to those weights.



Solution 1:[1]

Your code for saving/loading the parameters is wrong. Loading and saving model parameters is pretty straightforward. In your case, it should be:

# loading (map_location keeps the tensors on CPU even if they were saved from GPU)
saved_params = torch.load(
    filename, map_location=lambda storage, loc: storage
)
s2s.load_state_dict(saved_params)

# saving
params = s2s.state_dict()
torch.save(params, filename)
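
One detail the answer does not mention, but which is a common cause of this exact symptom (the weights load correctly, yet predictions are still wrong): the model must be put into evaluation mode before predicting, otherwise the Dropout layers keep randomly zeroing activations. A sketch, reusing s2s from above and the forward signature from the question:

s2s.load_state_dict(saved_params)
s2s.eval()             # disables dropout for deterministic inference
with torch.no_grad():  # no gradients needed at prediction time
    predicted = s2s(input_word, output_word, features_word, 0.0, limit)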

[Update]

You need to make the Seq2Seq class a subclass of PyTorch's nn.Module, just like your encoder/decoder classes. Otherwise, you can't use the state_dict() method. You can think of the Seq2Seq class as a container that holds your whole network, even though it has no learnable weights of its own.
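
Since nn.Module registers any sub-module assigned as an attribute, a Seq2Seq built this way exposes every encoder and decoder weight through a single state_dict(). A quick way to verify (a sketch, assuming encoder, decoder and device are already constructed):

s2s = Seq2Seq(encoder, decoder, device)
print(list(s2s.state_dict().keys())[:3])
# e.g. ['encoder.embedding_layer.weight',
#       'encoder.GRU_layer.weight_ih_l0',
#       'encoder.GRU_layer.weight_hh_l0']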

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

[1] Solution 1