Model.fit() Validation Accuracy different than Model.predict()
I have created a CNN for binary classification in Keras with the following code:
def neural_network():
    classifier = Sequential()
    # Adding a first convolutional layer
    classifier.add(Convolution2D(48, 3, input_shape = (320, 320, 3), activation = 'relu'))
    classifier.add(MaxPooling2D())
    # Adding a second convolutional layer
    classifier.add(Convolution2D(48, 3, activation = 'relu'))
    classifier.add(MaxPooling2D())
    #Flattening
    classifier.add(Flatten())
    #Full connected
    classifier.add(Dense(256, activation = 'relu'))
    #Full connected
    classifier.add(Dense(256, activation = 'sigmoid'))
    #Full connected
    classifier.add(Dense(1, activation = 'sigmoid'))
    # Compiling the CNN
    classifier.compile(optimizer = 'adam', loss = 'binary_crossentropy', metrics = ['accuracy'])
    classifier.summary()

    train_datagen = ImageDataGenerator(rescale = 1./255,
                                       shear_range = 0.2,
                                       horizontal_flip = True,
                                       vertical_flip=True,
                                       brightness_range=[0.5, 1.5])
    test_datagen = ImageDataGenerator(rescale = 1./255)
    training_set = train_datagen.flow_from_directory('/content/drive/My Drive/data_sep/train',
                                                     target_size = (320, 320),
                                                     batch_size = 32,
                                                     class_mode = 'binary')
    test_set = test_datagen.flow_from_directory('/content/drive/My Drive/data_sep/validate',
                                                target_size = (320, 320),
                                                batch_size = 32,
                                                class_mode = 'binary')

    es = EarlyStopping(
        monitor="val_accuracy",
        mode="max",
        patience
        baseline=None,
        restore_best_weights=True,
    )
    filepath = "/content/drive/My Drive/data_sep/weightsbestval.hdf5"
    checkpoint = ModelCheckpoint(filepath, monitor='val_accuracy', verbose=1, save_best_only=True, mode='max')
    callbacks_list = [checkpoint]
    history = classifier.fit(training_set,
                             epochs = 10,
                             validation_data = test_set,
                             callbacks= es
                             )
    best_score = max(history.history['val_accuracy'])

    from sklearn.metrics import classification_report
    predictions =(classifier.predict(test_set) > 0.5).astype("int32")
    newlist = predictions.tolist()
    finallist = []
    for number in newlist:
        finallist.append(number[0])
    predicted_classes = np.asarray(finallist)
    true_classes = test_set.classes
    class_labels = list(test_set.class_indices.keys())
    report = classification_report(true_classes, predicted_classes, target_names=class_labels)
    accuracy = metrics.accuracy_score(true_classes, predicted_classes)
    print(true_classes)
    print(predicted_classes)
    print(class_labels)
    correct = 0
    for i in range(len(true_classes)):
        if (true_classes[i] == predicted_classes[i]):
            correct = correct + 1
    print(correct)
    print((correct*1.0)/(len(true_classes)*1.0))
    print(report)
    return best_score
When I run the model, model.fit() reports a validation accuracy of 81.90%. But when I then compute the accuracy myself from model.predict() on the same validation set, I get 40%. I have added a callback that restores the best weights, so what could be the problem here?
Solution 1:[1]
What fixed it for me was creating a second ImageDataGenerator for the validation directory, this time with shuffle = False:
test2_datagen = ImageDataGenerator(rescale = 1./255)
test2_set = test2_datagen.flow_from_directory('/content/drive/My Drive/data_sep/validate',
                                              target_size = (320, 320),
                                              batch_size = 32,
                                              class_mode = 'binary',
                                              shuffle = False)
I then used test2_set for the prediction. flow_from_directory shuffles by default, so the predictions from test_set come back in a different order than test_set.classes, and comparing the two element by element gives a meaningless accuracy. I am posting this answer in case anyone has the same problem.
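To see why the order matters, here is a minimal sketch that uses only the Python standard library, no Keras. The lists are hypothetical stand-ins: true_classes plays the role of test_set.classes, and the "model" is assumed to classify every image correctly, so any drop in accuracy comes purely from misaligned ordering.

```python
import random

def accuracy(true_labels, predicted_labels):
    """Fraction of positions where the two label sequences agree."""
    matches = sum(t == p for t, p in zip(true_labels, predicted_labels))
    return matches / len(true_labels)

rng = random.Random(0)  # fixed seed so the run is repeatable

# Stand-in for test_set.classes: labels in directory (file) order.
true_classes = [0] * 500 + [1] * 500

# Stand-in for a model that gets every image right.
# With shuffle=False the predictions come back in file order.
ordered_predictions = list(true_classes)

# With shuffling enabled the images are fed in a random order, so the same
# (correct) predictions arrive permuted relative to true_classes.
shuffled_predictions = list(true_classes)
rng.shuffle(shuffled_predictions)

aligned_accuracy = accuracy(true_classes, ordered_predictions)
misaligned_accuracy = accuracy(true_classes, shuffled_predictions)

print("aligned:   ", aligned_accuracy)
print("misaligned:", misaligned_accuracy)
```

Even with a perfect classifier, the misaligned comparison lands near chance level for balanced classes. That is the same kind of gap as the one between the val_accuracy reported by fit() and the hand-computed figure in the question.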
Solution 2:[2]
test2_set = test2_datagen.flow_from_directory('/content/drive/My Drive/data_sep/validate',
                                              target_size= (320, 320),
                                              batch_size= 32,
                                              class_mode= 'binary',
                                              shuffle= False)
Note the lowercase shuffle parameter. Keyword arguments are case-sensitive, so with a capitalised Shuffle this code will fail.
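The failure is ordinary Python keyword matching and can be reproduced without Keras. In the sketch below, flow_from_directory_stub is a hypothetical stand-in that only mimics a few keyword names. It assumes the real method likewise declares shuffle as a named parameter and does not accept arbitrary extra keywords.

```python
def flow_from_directory_stub(directory, batch_size=32, shuffle=True):
    """Hypothetical stand-in that only mirrors a few keyword names."""
    return {"directory": directory, "batch_size": batch_size, "shuffle": shuffle}

# The lowercase keyword matches the parameter and is accepted.
config = flow_from_directory_stub("validate", shuffle=False)
print(config["shuffle"])

# The capitalised keyword matches no parameter, so Python rejects the call.
try:
    flow_from_directory_stub("validate", Shuffle=False)
    error_message = None
except TypeError as exc:
    error_message = str(exc)
print(error_message)
```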
Solution 3:[3]
Since you are saving the best model in this line:
checkpoint = ModelCheckpoint(filepath, monitor='val_accuracy', verbose=1, save_best_only=True, mode='max')
load that model in your code and then predict with it. Note that the file is only written if checkpoint is actually passed to fit(); the code in the question passes only es.
from keras.models import load_model
loaded_model = load_model('data_sep/weightsbestval.hdf5')
Then compile and evaluate it:
loaded_model.compile(optimizer = 'adam', loss = 'binary_crossentropy', metrics = ['accuracy'])
score = loaded_model.evaluate(X_test, Y_test, verbose=0)
print("%s: %.2f%%" % (loaded_model.metrics_names[1], score[1]*100))
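An evaluate-style accuracy compares each batch's predictions with the labels delivered in that same batch, so it does not depend on the order in which samples arrive. The sketch below illustrates this with the standard library only. The toy data, make_batches, toy_predict and batchwise_accuracy are all hypothetical stand-ins, not Keras APIs.

```python
import random

def batchwise_accuracy(batches, predict_fn):
    """Compare predictions with the labels that arrive in the same batch."""
    correct = 0
    total = 0
    for inputs, labels in batches:
        predictions = predict_fn(inputs)
        correct += sum(p == y for p, y in zip(predictions, labels))
        total += len(labels)
    return correct / total

def make_batches(samples, batch_size):
    """Split (input, label) pairs into (inputs, labels) batches."""
    batches = []
    for start in range(0, len(samples), batch_size):
        chunk = samples[start:start + batch_size]
        batches.append(([x for x, _ in chunk], [y for _, y in chunk]))
    return batches

# Toy data: the input is a number, the label is 1 when it is at least 0.5.
rng = random.Random(1)
samples = [(x, int(x >= 0.5)) for x in (rng.random() for _ in range(200))]

def toy_predict(inputs):
    # Stand-in classifier with a deliberately offset threshold (imperfect).
    return [int(x >= 0.6) for x in inputs]

in_order = batchwise_accuracy(make_batches(samples, 32), toy_predict)

shuffled_samples = list(samples)
rng.shuffle(shuffled_samples)
shuffled = batchwise_accuracy(make_batches(shuffled_samples, 32), toy_predict)

print(in_order, shuffled)
```

Both runs give the same accuracy because every prediction stays paired with its own label. The mismatch in the question only appears when predictions from a shuffled iterator are compared against labels stored in file order.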
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | asimplecoder |
| Solution 2 | Zac Dair |
| Solution 3 | Prajot Kuvalekar |