Plot the Decision Boundary of a Neural Network in PyTorch

I've been trying, without success, to plot the decision boundary of my neural network. I use it for binary classification, with a sigmoid activation in the output layer. I found many posts about plotting the decision boundary of a scikit-learn classifier, but none about a neural network built in PyTorch. Below is my neural network:

import torch

class NeuralNetwork(torch.nn.Module):
  def __init__(self):
    super(NeuralNetwork, self).__init__()
    self.fc1 = torch.nn.Linear(23, 16)
    self.fc2 = torch.nn.Linear(16, 14)
    self.fc3 = torch.nn.Linear(14, 10)
    self.fc4 = torch.nn.Linear(10, 5)
    self.fc5 = torch.nn.Linear(5, 1)

  def forward(self, x):
    x = torch.relu(self.fc1(x))
    x = torch.relu(self.fc2(x))
    x = torch.relu(self.fc3(x))
    x = torch.relu(self.fc4(x))
    x = torch.sigmoid(self.fc5(x))
    return x

model_1 = NeuralNetwork().double()

CUDA = torch.cuda.is_available()
if CUDA:
  model_1.cuda()

criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model_1.parameters(), lr=1e-2, momentum=0.9)

model_1.train()

Accuracy = []  # per-epoch training accuracy
Cost = []      # per-epoch training loss

for epoch in range(10001):

  if CUDA:
    inputs = X_train.cuda()
    label = Y_train.cuda()
  else:
    inputs = X_train
    label = Y_train

  prediction = model_1(inputs)
  loss = criterion(prediction, label)
  accuracy = ((prediction > 0.5) == label).float().mean().item()

  Cost.append(loss.item())
  Accuracy.append(accuracy)

  if epoch % 1000 == 0:
    print("Epoch:", epoch, ",", "Loss:", loss.item(), ",", "Accuracy:", accuracy)

  # Backpropagation process
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

model_1.eval()

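X_test = torch.from_numpy(X[27000:,:]).double()
Y_test = torch.from_numpy(y[27000:,:]).double()

with torch.no_grad():

  # The model may live on the GPU: evaluate there, then bring the predictions back to the CPU
  y_pred = model_1(X_test.cuda() if CUDA else X_test).cpu()
  print("Accuracy: ", ((y_pred > 0.5) == Y_test).float().mean().item())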

Here is my attempt at generating a similar plot:

import numpy as np
import matplotlib.pyplot as plt

# I've chosen two features because the data is multi-dimensional
X0 = X_test[:,[0,]].reshape(3000)
X5 = X_test[:,[5,]].reshape(3000)
X0, X5 = np.meshgrid(X0, X5)

xx, yy = np.meshgrid(X0, X5)
grid = np.c_[xx.ravel(), yy.ravel()]
probs = y_pred

f, ax = plt.subplots(figsize=(8, 6))
contour = ax.contourf(xx, yy, probs, 25, cmap="RdBu",
                      vmin=0, vmax=1)
ax_c = f.colorbar(contour)
ax_c.set_label("$P(y = 1)$")
ax_c.set_ticks([0, .25, .5, .75, 1])

ax.scatter(X0, X5, c=Y_test, s=50, cmap="RdBu", vmin=-.2, vmax=1.2,
           edgecolor="white", linewidth=1)

ax.set(aspect="equal",
       xlim=(-5, 5), ylim=(-5, 5),
       xlabel="$X0$", ylabel="$X5$")

but unfortunately I get the following error:

TypeError                                 Traceback (most recent call last)

<ipython-input-52-fb941749621a> in <module>()
  1 f, ax = plt.subplots(figsize=(8, 6))
  2 contour = ax.contourf(xx, yy, probs, 25, cmap="RdBu",
----> 3                       vmin=0, vmax=1)
  4 ax_c = f.colorbar(contour)
  5 ax_c.set_label("$P(y = 1)$")

5 frames

/usr/local/lib/python3.6/dist-packages/matplotlib/contour.py in _check_xyz(self, args, kwargs)
1549             raise TypeError("Input z must be a 2D array.")
1550         elif z.shape[0] < 2 or z.shape[1] < 2:
-> 1551             raise TypeError("Input z must be at least a 2x2 array.")
1552         else:
1553             Ny, Nx = z.shape

TypeError: Input z must be at least a 2x2 array.
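The error comes from the shape of `probs`. `contourf(X, Y, Z)` expects `Z` to be a 2-D array with the same shape as the grid, i.e. one value per grid point, but `y_pred` holds one probability per test sample, shape `(3000, 1)`, so matplotlib sees a 3000×1 array and rejects it. The grid coordinates should also be evenly spaced values (e.g. from `np.linspace`) rather than the raw, unsorted test values, and only one `np.meshgrid` call is needed; applying it a second time flattens the already-gridded arrays and builds a far larger grid. A minimal sketch of the shape bookkeeping, using a hypothetical stand-in function `toy_prob` in place of the network:

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # off-screen backend so the sketch runs without a display
import matplotlib.pyplot as plt


def toy_prob(points):
    """Hypothetical stand-in for the network: returns P(y=1) for each row of `points`."""
    return 1.0 / (1.0 + np.exp(-(points[:, 0] - points[:, 1])))


# Evenly spaced coordinates for the two plotted features
x0_range = np.linspace(-5, 5, 200)
x5_range = np.linspace(-5, 5, 150)
xx, yy = np.meshgrid(x0_range, x5_range)  # a single call: both arrays have shape (150, 200)

grid = np.c_[xx.ravel(), yy.ravel()]      # (30000, 2): one row per grid point
probs = toy_prob(grid).reshape(xx.shape)  # back to (150, 200), one value per grid point

fig, ax = plt.subplots(figsize=(8, 6))
contour = ax.contourf(xx, yy, probs, 25, cmap="RdBu", vmin=0, vmax=1)
fig.colorbar(contour)
```

With the real network, `toy_prob(grid)` would be replaced by a forward pass of the model. Since that model takes 23 input features, the other 21 features also need values at every grid point before the forward pass can run.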

I would greatly appreciate your help, thanks in advance.



Solution 1:[1]

You could define a mesh (grid) of points covering the region you want to plot and run the model on every point. Neighboring points that receive different predicted classes lie on opposite sides of the boundary, so connecting those points gives an approximate decision boundary. However, this can be computationally expensive if the area to plot is large or a fine mesh is desired, because the number of forward passes grows with the number of grid points.
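A minimal sketch of that mesh approach for a network like the one in the question. Because the model takes 23 input features, a 2-D plot has to hold the other 21 features fixed; here they are set to their mean over the data, which is an assumption (a median or a representative sample would also work). The data `X` and the small untrained network are random placeholders, not the asker's data or model, and `decision_grid` is a hypothetical helper. Drawing the 0.5 contour of P(y = 1) performs the "connecting the dots" step.

```python
import numpy as np
import torch
import matplotlib
matplotlib.use("Agg")  # off-screen backend so the sketch runs without a display
import matplotlib.pyplot as plt

torch.manual_seed(0)
rng = np.random.default_rng(0)

# Placeholder data and model (hypothetical): 23 features, sigmoid output as in the question
X = rng.normal(size=(500, 23))
model = torch.nn.Sequential(
    torch.nn.Linear(23, 16), torch.nn.ReLU(),
    torch.nn.Linear(16, 1), torch.nn.Sigmoid(),
).double()


def decision_grid(model, X, f1, f2, steps=200, pad=0.5):
    """Evaluate P(y=1) on a steps x steps grid over features f1 and f2.

    All other features are held at their mean over X (an assumption)."""
    x1 = np.linspace(X[:, f1].min() - pad, X[:, f1].max() + pad, steps)
    x2 = np.linspace(X[:, f2].min() - pad, X[:, f2].max() + pad, steps)
    xx, yy = np.meshgrid(x1, x2)

    # Every grid point starts as the mean sample; then the two plotted features are overwritten
    points = np.tile(X.mean(axis=0), (xx.size, 1))
    points[:, f1] = xx.ravel()
    points[:, f2] = yy.ravel()

    model.eval()
    with torch.no_grad():
        device = next(model.parameters()).device  # works whether the model is on CPU or GPU
        probs = model(torch.from_numpy(points).to(device)).cpu().numpy()
    return xx, yy, probs.reshape(xx.shape)


xx, yy, probs = decision_grid(model, X, f1=0, f2=5)

fig, ax = plt.subplots(figsize=(8, 6))
filled = ax.contourf(xx, yy, probs, 25, cmap="RdBu", vmin=0, vmax=1)
fig.colorbar(filled, label="$P(y = 1)$")
# The 0.5 level is the decision boundary for a 0.5 threshold
ax.contour(xx, yy, probs, levels=[0.5], colors="k")
ax.set(xlabel="$X0$", ylabel="$X5$")
```

With the question's objects this would be something like `decision_grid(model_1, X, f1=0, f2=5)`, where `X` is the NumPy feature array. The `steps` parameter trades resolution against the cost mentioned above (`steps**2` rows in one forward pass). If test points are scattered on top, some may appear on the "wrong" side simply because their other 21 features differ from the mean values used for the grid.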

Here are some references.

Some code on Stack Exchange

A similar question on Stack Overflow

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 Wey Shi