Can anyone explain what `out = self(images)` does in the code below?
I am not able to understand it: if the prediction is calculated in the `forward` method, why is there a need for `out = self(images)`, and what does it do? I am a bit confused about this code.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

input_size = 28 * 28  # flattened 28x28 MNIST image
num_classes = 10      # digits 0-9


def accuracy(outputs, labels):
    # Assumed helper (not shown in the question): fraction of correct predictions
    _, preds = torch.max(outputs, dim=1)
    return (preds == labels).float().mean()


class MnistModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, xb):
        xb = xb.reshape(-1, 784)
        out = self.linear(xb)
        return out

    def training_step(self, batch):
        images, labels = batch
        out = self(images)                   # Generate predictions
        loss = F.cross_entropy(out, labels)  # Calculate loss
        return loss

    def validation_step(self, batch):
        images, labels = batch
        out = self(images)                   # Generate predictions
        loss = F.cross_entropy(out, labels)  # Calculate loss
        acc = accuracy(out, labels)          # Calculate accuracy
        return {'val_loss': loss, 'val_acc': acc}

    def validation_epoch_end(self, outputs):
        batch_losses = [x['val_loss'] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()  # Combine losses
        batch_accs = [x['val_acc'] for x in outputs]
        epoch_acc = torch.stack(batch_accs).mean()     # Combine accuracies
        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}

    def epoch_end(self, epoch, result):
        print("Epoch [{}], val_loss: {:.4f}, val_acc: {:.4f}".format(epoch, result['val_loss'], result['val_acc']))


model = MnistModel()
```
Solution 1:[1]
In Python, `self` refers to the instance that you have created from a class (similar to `this` in Java and C++). An instance is callable, meaning it can be called like a function, if its class defines a `__call__` method.
Example:
```python
class A:
    def __init__(self):
        pass

    def __call__(self, x, y):
        return x + y


a = A()
print(a(3, 4))  # Prints 7
```
In your case, the `__call__` method is implemented in the superclass `nn.Module`, so `self(images)` invokes `nn.Module.__call__` on your model instance.
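To see the mechanism in isolation, here is a pure-Python sketch that does not use PyTorch. `TinyModule` and `Doubler` are hypothetical names invented for this illustration, and `TinyModule.__call__` is a drastic simplification of `nn.Module.__call__` (the real one also runs hooks and does other bookkeeping). The point is only that a subclass which defines `forward` can be called like a function, including from inside its own methods via `self(...)`.

```python
class TinyModule:
    """Hypothetical, greatly simplified stand-in for torch.nn.Module (illustration only)."""

    def __call__(self, *args, **kwargs):
        # Calling the instance, e.g. model(x) or self(x), lands here ...
        # ... and is passed on to whatever forward() the subclass defines.
        return self.forward(*args, **kwargs)

    def forward(self, *args, **kwargs):
        raise NotImplementedError("subclasses must define forward()")


class Doubler(TinyModule):
    # Plays the role of MnistModel: it only defines forward(), never __call__.
    def forward(self, xb):
        return [2 * v for v in xb]

    def training_step(self, batch):
        images, labels = batch
        out = self(images)  # TinyModule.__call__ -> Doubler.forward
        return out


model = Doubler()
print(model([1, 2, 3]))                       # [2, 4, 6]
print(model.training_step(([1, 2], [0, 1])))  # [2, 4]
```

`Doubler` never defines `__call__`, yet `self(images)` works because the method is inherited from the base class, exactly as `MnistModel` inherits it from `nn.Module`.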
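`out` itself is nothing special: it is an ordinary variable that holds the tensor the module returns for `images` (here, the class scores, or logits, produced by `self.linear`). That tensor is then passed on to the next step, which in `training_step` is `F.cross_entropy` and in `validation_step` is also `accuracy`.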
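As a quick check, the following sketch shows that `out` is just a regular tensor of shape `(batch_size, num_classes)` that you can print, inspect, or hand to the loss. It assumes PyTorch is installed, MNIST-sized inputs (784 pixels, 10 classes), and random data standing in for real images and labels.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MnistModel(nn.Module):
    # Trimmed copy of the question's model; 784 inputs and 10 classes are assumed.
    def __init__(self, input_size=784, num_classes=10):
        super().__init__()
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, xb):
        xb = xb.reshape(-1, 784)
        return self.linear(xb)


torch.manual_seed(0)
model = MnistModel()
images = torch.randn(4, 1, 28, 28)   # fake batch of 4 "images"
labels = torch.tensor([3, 1, 4, 1])  # fake labels

out = model(images)                  # an ordinary tensor holding the logits
print(type(out).__name__, tuple(out.shape))  # Tensor (4, 10)
loss = F.cross_entropy(out, labels)  # out is simply consumed by the loss here
print(loss.dim())                    # 0 (a scalar tensor)
```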
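For instances of `nn.Module` (and of classes that inherit from it, such as `MnistModel`), `__call__` runs the `forward` method you defined and returns its result, so `self(images)` effectively computes `self.forward(images)`. It also runs any registered hooks around `forward`, which is why PyTorch recommends calling the module instance rather than calling `forward` directly.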
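The following sketch compares calling the model with calling `forward` directly. It assumes PyTorch is installed and uses a trimmed-down `MnistModel` with random input. Both calls return the same tensor, but a forward hook registered with `register_forward_hook` only fires when the call goes through `nn.Module.__call__`.

```python
import torch
import torch.nn as nn


class MnistModel(nn.Module):
    # Minimal version of the question's model (784 inputs, 10 classes assumed).
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(784, 10)

    def forward(self, xb):
        return self.linear(xb.reshape(-1, 784))


model = MnistModel()
calls = []
# The hook records the output shape; it runs inside nn.Module.__call__.
model.register_forward_hook(lambda module, inputs, output: calls.append(tuple(output.shape)))

images = torch.randn(2, 1, 28, 28)
via_call = model(images)             # goes through __call__: runs forward() AND the hook
via_forward = model.forward(images)  # calls forward() directly: the hook is skipped

print(torch.equal(via_call, via_forward))  # True: same computation
print(calls)                               # [(2, 10)]: the hook fired only once
```

With no hooks registered the two calls behave the same, which is why the difference is easy to miss; writing `self(images)` keeps the code correct if hooks are added later.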
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
Solution | Source |
---|---|
Solution 1 | S.B |