Some parameters are not getting saved when saving a model in PyTorch
I have built an encoder-decoder model with attention for morphological inflection generation. I am able to train the model and predict on test data, but I get wrong predictions after loading a saved model. I do not get any error during saving or loading, but when I load a saved model its predictions are completely wrong. It looks like some parameters are not getting saved.
I have tried to save and load the model using both techniques:
using state_dict(), e.g. torch.save(encoder.state_dict(), 'path')
saving the complete model, e.g. torch.save(encoder, 'path')
I have tried saving the different classes one by one, and also making a superclass that instantiates all of those classes and then saving just the superclass,
but nothing seems to work.
Encoder class
class Encoder(nn.Module):
    """Embed a source character sequence and encode it with a GRU.

    Args:
        vocab_size: number of rows in the embedding table (index 0 is padding).
        embedding_size: width of each embedding vector.
        encoder_hid_dem: hidden size of the GRU, per direction.
        decoder_hid_dem: width of the initial decoder state returned by
            ``forward``.
        bidirectional: whether the GRU reads the sequence in both directions.
        dropout: dropout probability applied to the embedded input.
    """

    def __init__(self, vocab_size, embedding_size, encoder_hid_dem,
                 decoder_hid_dem, bidirectional, dropout):
        # nn.Module must be initialised before any layer is assigned to
        # ``self``; only registered sub-modules show up in state_dict().
        super().__init__()
        self.encoder_hid_dem = encoder_hid_dem
        self.encoder_n_direction = 2 if bidirectional else 1
        self.bias = False
        self.embedding_layer = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_size,
            padding_idx=0,
        )
        self.GRU_layer = nn.GRU(
            input_size=embedding_size,
            hidden_size=encoder_hid_dem,
            batch_first=True,
            bidirectional=bidirectional,
        )
        self.fc = nn.Linear(
            encoder_hid_dem * self.encoder_n_direction, decoder_hid_dem
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_word):
        """Encode ``input_word`` of shape [batch_size, src_len].

        Returns:
            A pair ``(GRU_out, hidden)`` where ``GRU_out`` is
            [batch_size, src_len, encoder_hid_dem * n_direction] and
            ``hidden`` is [batch_size, decoder_hid_dem].
        """
        # [batch_size, src_len, embedding_dim]
        embed_out = self.embedding_layer(input_word)
        embed_out = self.dropout(F.relu(embed_out))
        # Remembered so that init_hidden() can size its tensor.
        self.batch = embed_out.size(0)

        # GRU_out: [batch_size, src_len, n_direction * hid_dem]
        # hidden:  [n_layer * n_direction, batch_size, hid_dem]
        GRU_out, hidden = self.GRU_layer(embed_out)
        GRU_out = F.relu(GRU_out)

        if self.encoder_n_direction == 2:
            # hidden[-2] is the final forward state, hidden[-1] the final
            # backward state; the decoder is seeded from both.
            final_state = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)
        else:
            # A unidirectional GRU has a single final state.
            final_state = hidden[-1, :, :]
        hidden = torch.tanh(self.fc(final_state))
        return GRU_out, hidden

    def init_hidden(self):
        """Return an initial GRU state of shape [n_direction, batch, hid_dem].

        The batch size is the one seen by the most recent ``forward`` call,
        so ``forward`` must have run at least once.
        """
        # Follow the module's own parameters instead of a separately stored
        # device attribute, so the state lands wherever the weights live.
        device = next(self.parameters()).device
        first_unit = torch.eye(1, self.encoder_hid_dem, device=device)
        return first_unit.unsqueeze(1).repeat(
            self.encoder_n_direction, self.batch, 1
        )
Attention class
class Attention(nn.Module):
    """Additive attention over the encoder outputs.

    Args:
        encoder_hid_dem: hidden size of the encoder GRU, per direction.
        decoder_hid_dem: hidden size of the decoder.
        bidirectional: whether the encoder is bidirectional, which doubles
            the width of the encoder outputs.
    """

    def __init__(self, encoder_hid_dem, decoder_hid_dem, bidirectional):
        # Required before assigning layers/parameters so they are registered
        # and therefore saved by state_dict().
        super().__init__()
        self.enc_hid_dim = encoder_hid_dem
        self.dec_hid_dim = decoder_hid_dem
        self.encoder_n_direction = 2 if bidirectional else 1
        self.attn = nn.Linear(
            (encoder_hid_dem * self.encoder_n_direction) + decoder_hid_dem,
            decoder_hid_dem,
        )
        self.v = nn.Parameter(torch.rand(decoder_hid_dem))

    def forward(self, hidden, encoder_outputs):
        """Return attention weights of shape [batch_size, src_len].

        Args:
            hidden: decoder state, [batch_size, dec_hid_dim].
            encoder_outputs: [batch_size, src_len,
                enc_hid_dim * encoder_n_direction].
        """
        batch_size = encoder_outputs.shape[0]
        src_len = encoder_outputs.shape[1]

        # Copy the decoder state once per source position:
        # [batch_size, src_len, dec_hid_dim]
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)

        # energy: [batch_size, src_len, dec_hid_dim]
        energy = torch.tanh(
            self.attn(torch.cat((hidden, encoder_outputs), dim=2))
        )
        # -> [batch_size, dec_hid_dim, src_len] so it can be multiplied by v
        energy = energy.permute(0, 2, 1)

        # v: [dec_hid_dim] -> [batch_size, 1, dec_hid_dim]
        v = self.v.repeat(batch_size, 1).unsqueeze(1)

        # One score per source position: [batch_size, src_len]
        attention = torch.bmm(v, energy).squeeze(1)
        return F.softmax(attention, dim=1)
Decoder class
class Decoder(nn.Module):
    """Single-step GRU decoder that attends over the encoder outputs.

    Args:
        decoder_hid_dem: hidden size of the decoder GRU.
        encoder_hid_dem: hidden size of the encoder GRU, per direction.
        vocab_size: number of output classes.
        embedding_dim: width of the embedded previous character (kept for
            interface compatibility; not used directly here).
        attention: callable mapping ``(hidden, encoder_outputs)`` to
            attention weights of shape [batch_size, src_len].
        decoder_input_size: width of the concatenated GRU input
            (previous character + features + attended context).
        linear_input_size: width of the concatenated input to the output
            layer (GRU output + attended context + previous character).
        bidirectional: whether the encoder is bidirectional (kept for
            interface compatibility; not used directly here).
        dropout: dropout probability applied to the GRU output.
    """

    def __init__(self, decoder_hid_dem, encoder_hid_dem, vocab_size,
                 embedding_dim, attention, decoder_input_size,
                 linear_input_size, bidirectional, dropout):
        # Required before assigning layers so they are registered and saved
        # by state_dict().
        super().__init__()
        self.decoder_hid_dem = decoder_hid_dem
        self.encoder_hid_dem = encoder_hid_dem
        # When ``attention`` is an nn.Module, assigning it here registers it
        # as a sub-module, so its weights are saved with the decoder.
        self.attention = attention
        self.output_dim = vocab_size
        self.GRU_layer_out = nn.GRU(decoder_input_size, decoder_hid_dem)
        self.out_layer = nn.Linear(
            in_features=linear_input_size, out_features=vocab_size
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, feature, hidden, actual_word, encoder_outputs):
        """Predict the next character.

        Args:
            feature: morphological features, [batch_size, feat_size].
            hidden: previous decoder state, [batch_size, dec_hid_dim].
            actual_word: embedded previous character,
                [batch_size, 1, embedding_dim].
            encoder_outputs: [batch_size, src_len,
                encoder_hid_dim * encoder_n_direction].

        Returns:
            ``(log_probs, hidden)`` with ``log_probs`` of shape
            [batch_size, vocab_size] and ``hidden`` of shape
            [batch_size, dec_hid_dim].
        """
        # [batch_size, 1, feat_size]
        feature = feature.unsqueeze(1)

        # Attention weights: [batch_size, src_len] -> [batch_size, 1, src_len]
        a = self.attention(hidden, encoder_outputs).unsqueeze(1)

        # Context vector: [batch_size, 1, enc_hid_dim * encoder_n_direction]
        weighted = torch.bmm(a, encoder_outputs)

        # [batch_size, 1, decoder_input_size]
        input_char = torch.cat((actual_word, feature, weighted), 2)
        # The GRU is time-major (batch_first=False), so move the length-1
        # time axis to the front: [1, batch_size, decoder_input_size].
        input_char = input_char.permute(1, 0, 2)

        # [1, batch_size, dec_hid_dim]
        hidden = hidden.unsqueeze(0)

        # output: [1, batch_size, dec_hid_dim]
        # hidden: [1, batch_size, dec_hid_dim]
        output, hidden = self.GRU_layer_out(input_char, hidden)
        output = self.dropout(F.leaky_relu(output))

        # The output layer sees the GRU output together with the context
        # and the previous character: [batch_size, linear_input_size]
        output = torch.cat(
            (output.squeeze(0), weighted.squeeze(1), actual_word.squeeze(1)),
            dim=1,
        )
        pre_out = self.out_layer(output)
        predicted_output = F.log_softmax(pre_out, dim=1)
        return predicted_output, hidden.squeeze(0)

    def init_hidden(self, batch):
        """Return a pair of initial states, each [1, batch, dec_hid_dim]."""
        # Follow the module's own parameters instead of a separately stored
        # device attribute.
        device = next(self.parameters()).device

        def _state():
            first_unit = torch.eye(1, self.decoder_hid_dem, device=device)
            return first_unit.unsqueeze(1).repeat(1, batch, 1)

        return (_state(), _state())
seq2seq class
class Seq2Seq(nn.Module):
    """Container that wires an encoder and an attention decoder together.

    It holds no learnable weights of its own; registering ``encoder`` and
    ``decoder`` as sub-modules is what makes ``state_dict()`` cover the whole
    network.

    Args:
        encoder: module returning ``(encoder_outputs, hidden)`` and exposing
            an ``embedding_layer``.
        decoder: module called as
            ``decoder(features, hidden, actual_word, encoder_outputs)`` and
            exposing ``output_dim``.
        device: device the inputs are moved to.
        sos_index: vocabulary index of the start-of-sequence symbol. When
            ``None`` it is looked up as ``char_to_index['<sos>']``.
    """

    def __init__(self, encoder, decoder, device, sos_index=None):
        # Required before assigning sub-modules so they are registered and
        # saved by state_dict().
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        self.sos_index = sos_index

    def forward(self, input_word, output_word, features_word,
                teaching_forcing_ratio, limit):
        """Decode up to ``limit`` steps.

        Args:
            input_word: source indices, [batch_size, src_len].
            output_word: target indices, [batch_size, >= limit].
            features_word: per-example features passed to the decoder.
            teaching_forcing_ratio: probability, per step, of feeding the
                gold target character instead of the model's own prediction.
            limit: number of time steps in the returned tensor.

        Returns:
            Tensor [limit, batch_size, vocab_size]; row 0 is left as zeros
            because decoding starts at step 1.
        """
        input_word = input_word.to(self.device)
        output_word = output_word.to(self.device)
        features_word = features_word.to(self.device)

        batch_size = input_word.size()[0]
        max_len = limit
        vocabsize = self.decoder.output_dim

        embed = self.encoder.embedding_layer
        sos = self.sos_index
        if sos is None:
            sos = char_to_index['<sos>']
        # First decoder input is the embedded <sos> symbol:
        # [batch_size, 1, embedding_dim]
        actual_word = embed(
            torch.tensor(sos).view(1, -1).to(self.device)
        ).repeat(batch_size, 1, 1)

        encoder_outputs, hidden = self.encoder(input_word)
        predicted_word = torch.zeros(max_len, batch_size, vocabsize).to(self.device)

        for t in range(1, max_len):
            output, hidden = self.decoder(
                features_word, hidden, actual_word, encoder_outputs
            )
            predicted_word[t] = output

            if random.random() < teaching_forcing_ratio:
                # Teacher forcing: feed the gold character for this step.
                actual_word = embed(output_word[:, t]).unsqueeze(1)
            else:
                # Greedy decoding: feed back the most likely character.
                # topi is [batch_size, 1], so a single lookup yields
                # [batch_size, 1, embedding_dim] for any embedding width.
                _, topi = output.topk(1)
                actual_word = embed(topi)

        return predicted_word
This is the code I use to save and load the model:
# Write the model's state_dict (not the whole pickled object) to 'model.pt'.
torch.save(model.state_dict(), 'model.pt')
When I run my model with the pre-trained weights, I expect it to predict correctly according to those weights.
Solution 1:[1]
The code you provided for saving/loading parameters is wrong. Loading and saving model parameters is pretty straightforward. In your case, it should be:
# loading
# map_location keeps every tensor on the CPU regardless of where it was saved.
saved_params = torch.load(
    filename, map_location=lambda storage, loc: storage
)
# torch.load only returns the dictionary; it must still be copied into the model.
s2s.load_state_dict(saved_params)
# For inference, switch to eval mode so dropout is disabled.
s2s.eval()
# saving
params = s2s.state_dict()
torch.save(params, filename)
You need to make the Seq2Seq class a derived class of PyTorch's nn.Module, just like your encoder/decoder classes.
Otherwise, you can't use the state_dict() method.
You can think of the Seq2Seq class as a container that holds your whole network, although it does not have any learnable weights itself.
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
Solution | Source |
Solution 1 |