Gradio - PyTorch MNIST Digit Recognizer
I watched this YouTube video, https://www.youtube.com/watch?v=jx9iyQZhSwI, which shows how to use Gradio with a model trained on the MNIST dataset in TensorFlow. I have read that PyTorch models can also be used with Gradio, but I am having trouble implementing it. Does anyone have an idea how to do this? My PyTorch code for the CNN:
import torch.nn as nn

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=16,
                kernel_size=5,
                stride=1,
                padding=2,
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # fully connected layer, output 10 classes
        self.out = nn.Linear(32 * 7 * 7, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
        x = x.view(x.size(0), -1)
        output = self.out(x)
        return output, x  # return x for visualization
From the video, I gather that I need to change the function that Gradio calls:
def predict_image(img):
    img_3d=img.reshape(-1,28,28)
    im_resize=img_3d/255.0
    prediction=CNN(im_resize)
    pred=np.argmax(prediction)
    return pred
Solution 1:[1]
I'm sorry if I misunderstood your question, but as I understand it, you get an error when trying to predict the digit with your function predict_image.
Here are two possible hints. Maybe you have implemented them already; I can't tell from such a small code snippet.
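First, have you set your model to evaluation mode?
model.eval()
Here model stands for your trained CNN instance (eval() has to be called on an instance, not on the CNN class itself). Do this after you have finished training and want to run inputs through the model without training it.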
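A minimal sketch of what evaluation mode changes, assuming a hypothetical toy model (the name `toy` and its layers are made up for illustration and are not the CNN from the question). Layers such as Dropout behave differently in training and evaluation mode. The CNN in the question contains only convolution, ReLU, pooling and linear layers, so its outputs would be the same in both modes, but calling eval() before prediction is still a good habit.

```python
import torch
import torch.nn as nn

# Hypothetical toy model, used only to show the effect of eval():
# Dropout randomly zeroes activations in train mode and is a no-op in eval mode.
toy = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.ones(1, 4)

toy.eval()  # switch every submodule to inference behaviour
with torch.no_grad():  # no gradient tracking is needed for prediction
    first = toy(x)
    second = toy(x)

# In eval mode the same input produces the same output on every call.
same_in_eval = torch.equal(first, second)
print(same_in_eval)
```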
Second, you may need to add a fourth dimension to your input tensor "im_resize". A convolutional model like yours normally expects four dimensions: the batch size, the number of channels, the height and the width of your input. In addition, I cannot tell whether your input is of the datatype torch.Tensor. If not, convert your array into a tensor first.
You can add a batch dimension to your input tensor by using
im_resize = im_resize.unsqueeze(0)
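Putting both hints together, the prediction function could look roughly like the sketch below. It assumes that Gradio hands the function a single 28x28 grayscale NumPy array with values from 0 to 255, and that trained weights have already been loaded into the instance (here called `model`, a name chosen for illustration). The model in this sketch is untrained, so the digit it returns is meaningless; only the shapes and types are the point.

```python
import numpy as np
import torch
import torch.nn as nn


class CNN(nn.Module):
    # Same architecture as in the question, written compactly.
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.out = nn.Linear(32 * 7 * 7, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)
        return self.out(x), x


model = CNN()  # assumption: trained weights would be loaded into this instance
model.eval()   # call eval() on the instance, not on the class


def predict_image(img):
    # Assumed input: one 28x28 grayscale image as a NumPy array, values 0..255.
    arr = np.asarray(img, dtype=np.float32).reshape(-1, 28, 28) / 255.0
    x = torch.from_numpy(arr).unsqueeze(0)  # (1, 28, 28) -> (1, 1, 28, 28)
    with torch.no_grad():
        logits, _ = model(x)  # forward returns (output, flattened features)
    return int(logits.argmax(dim=1).item())  # plain Python int in 0..9


digit = predict_image(np.zeros((28, 28), dtype=np.uint8))
print(digit)
```

A function like this can then be passed as the fn argument of gr.Interface, as in the video. The array shape that the input component delivers depends on the Gradio version and component used, so the reshape step may need adjusting.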
I hope that I understand your question correctly and was able to help you.
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
Solution | Source |
---|---|
Solution 1 | Odin |