Is this the correct way of computing the average accuracy?
I am fairly new to coding and am confused about the difference between average accuracy and overall accuracy. I have written a function that calculates accuracy for a batch; I add up its results during the epoch and then divide the sum by len(dataloader) at the end of each epoch. Is this the correct way to calculate average accuracy? If not, could someone explain how to do this correctly?
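import torch
import torch.nn as nn

# Assumption: `device` is not defined in the original snippet, so one is picked here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def accuracy(predictions, labels):
    # Fraction of samples in the batch whose highest-scoring class matches the label
    classes = torch.argmax(predictions, dim=1)
    return torch.mean((classes == labels).float())

def train(model, optimizer, dataloader):
    # Setting model to train mode
    model.train()
    acc = 0.0
    running_loss = 0.0
    loss_fc = nn.CrossEntropyLoss()
    for i, (img, label) in enumerate(dataloader):
        # Move images and labels to the selected device
        img, label = img.to(device), label.to(device)
        y_pred = model(img)
        optimizer.zero_grad()
        loss = loss_fc(y_pred, label)
        loss.backward()
        optimizer.step()
        # Update loss and accuracy; a separate accumulator keeps the loss tensor
        # from being overwritten, and `label` replaces the undefined `s_label`
        running_loss += loss.item()
        acc += accuracy(y_pred, label).item()
    # Mean of the per-batch values over the epoch
    running_loss /= len(dataloader)
    acc /= len(dataloader)
    return running_loss, acc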
Solution 1:[1]
I'm not sure what you mean by overall and average accuracy. Typically, accuracy is calculated at the end of each epoch: you pass the accuracy function your predictions and the actual labels, and it returns the proportion you got right as a decimal between 0 and 1.
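To make the distinction in the question concrete, here is a minimal sketch in plain Python. Lists of class indices stand in for tensors, and the function names and sample batches are made up for illustration. Dividing the summed batch accuracies by len(dataloader) gives the mean of the per-batch accuracies, while dividing the total number of correct predictions by the total number of samples gives the accuracy over the whole epoch. The two agree when every batch has the same size and can differ when the last batch is smaller.

```python
def batch_accuracy(pred_classes, labels):
    """Fraction of correct predictions in one batch."""
    correct = sum(p == y for p, y in zip(pred_classes, labels))
    return correct / len(labels)


def epoch_accuracies(batches):
    """Return (mean of per-batch accuracies, overall accuracy) for one epoch."""
    batch_accs = []
    total_correct = 0
    total_samples = 0
    for pred_classes, labels in batches:
        batch_accs.append(batch_accuracy(pred_classes, labels))
        total_correct += sum(p == y for p, y in zip(pred_classes, labels))
        total_samples += len(labels)
    mean_of_batches = sum(batch_accs) / len(batch_accs)
    overall = total_correct / total_samples
    return mean_of_batches, overall


# Hypothetical epoch: one full batch of 4 samples and a final batch of 1 sample.
batches = [
    ([0, 1, 1, 0], [0, 1, 0, 1]),  # 2 of 4 correct
    ([1], [1]),                    # 1 of 1 correct
]
mean_of_batches, overall = epoch_accuracies(batches)
print(mean_of_batches, overall)
```

With these made-up batches the mean of batch accuracies is (0.5 + 1.0) / 2 = 0.75, while the overall accuracy is 3 / 5 = 0.6, because the single-sample batch counts as much as the four-sample batch in the first figure.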
I haven't seen any use for averaging the accuracy across every epoch during training. That metric would be heavily influenced by how fast your model learns rather than by how well it eventually performs. For example, a model that needs many epochs to do well will probably look worse on this average than one that converges in fewer epochs.
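A small numeric sketch of that point, using invented per-epoch accuracies for two hypothetical models (the numbers are illustrative only, not measurements):

```python
# Hypothetical per-epoch accuracies, made up purely for illustration.
slow_learner = [0.20, 0.40, 0.60, 0.80, 0.95]  # starts poorly, ends higher
fast_learner = [0.70, 0.80, 0.85, 0.88, 0.90]  # starts well, ends slightly lower


def mean(values):
    """Arithmetic mean of a non-empty list of numbers."""
    return sum(values) / len(values)


slow_avg, slow_final = mean(slow_learner), slow_learner[-1]
fast_avg, fast_final = mean(fast_learner), fast_learner[-1]

print(f"slow learner: average over epochs {slow_avg:.3f}, final {slow_final:.2f}")
print(f"fast learner: average over epochs {fast_avg:.3f}, final {fast_final:.2f}")
```

In this invented case the slow learner has the lower average across epochs but the higher final accuracy, so ranking by the cross-epoch average would favour the model that ends up performing worse.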
If you take a look at the accuracy_score metric from scikit-learn, it should help clear things up for you.
Link: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html
Hope this helps!
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | mrw |