Sparse Categorical Crossentropy shape problem with Keras
I have a multiclass problem where an image can be one of three classes (Masked, UnMasked, Hybrid).
I am using image_dataset_from_directory from the Keras preprocessing module, which makes things easier.
Load dataset
import tensorflow as tf

# Aliases used throughout the question's code.
tfk = tf.keras
tfkl = tf.keras.layers

# TRAINING_PATH, TESTING_PATH, IMAGE_SIZE, VALIDATION_SPLIT, BATCH_SIZE and SEED
# are the asker's own configuration constants.
def load_from_directory(shuffle=False):
    train_ds = tfk.preprocessing.image_dataset_from_directory(
        directory=TRAINING_PATH,
        image_size=IMAGE_SIZE,
        validation_split=VALIDATION_SPLIT,
        batch_size=BATCH_SIZE,
        seed=SEED,
        subset='training',
        label_mode='int',
        shuffle=shuffle
    )
    val_ds = tfk.preprocessing.image_dataset_from_directory(
        directory=TRAINING_PATH,
        image_size=IMAGE_SIZE,
        validation_split=VALIDATION_SPLIT,
        batch_size=BATCH_SIZE,
        seed=SEED,
        subset='validation',
        label_mode='int',
        shuffle=False
    )
    test_ds = tfk.preprocessing.image_dataset_from_directory(
        directory=TESTING_PATH,
        labels=None,
        image_size=IMAGE_SIZE,
        batch_size=BATCH_SIZE,
        seed=SEED,
        label_mode='int',
        shuffle=False
    )
    return train_ds, val_ds, test_ds
The Keras documentation for image_dataset_from_directory says the following about label_mode: 'int': means that the labels are encoded as integers (e.g. for sparse_categorical_crossentropy loss) ... My folder has the following structure:
- MaskDataset
    - training
        - 0/
            - img1, img2, …, imgN
        - 1/
            - img1, img2, …, imgN
        - 2/
            - img1, img2, …, imgN
    - testing
        - img1, img2, …, imgN
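To make the two label modes concrete: with label_mode='int', image_dataset_from_directory gives each class folder an index in alphanumeric order of the folder names (unless class_names is passed), so 0/, 1/ and 2/ become labels 0, 1 and 2, and a batch of labels has shape (batch_size,). With label_mode='categorical', each label is instead a one-hot vector, giving shape (batch_size, num_classes). The NumPy sketch below only mimics those documented shapes with a hypothetical batch; it does not call Keras.

```python
import numpy as np

# Hypothetical folder listing, mirroring training/0, training/1, training/2 above.
folder_names = ["2", "0", "1"]
# Class order is the alphanumeric order of the folder names.
class_names = sorted(folder_names)
num_classes = len(class_names)

# Hypothetical batch: the class folder each of five images came from.
batch_folders = ["0", "2", "1", "1", "0"]

# label_mode='int': one integer per image, shape (batch_size,).
int_labels = np.array([class_names.index(f) for f in batch_folders], dtype=np.int32)

# label_mode='categorical': one one-hot row per image, shape (batch_size, num_classes).
onehot_labels = np.eye(num_classes, dtype=np.float32)[int_labels]

print(int_labels.shape, int_labels)  # (5,) [0 2 1 1 0]
print(onehot_labels.shape)           # (5, 3)
```

The softmax layer outputs shape (batch_size, 3), which matches the one-hot labels but not the integer ones.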
For the model, my final layers are the following:
x = tfkl.SeparableConv2D(1024, 3, padding='same')(x)
x = tfkl.BatchNormalization()(x)
x = tfkl.Activation('relu')(x)
# GlobalAveragePooling + Dropout
x = tfkl.GlobalAveragePooling2D()(x)
x = tfkl.Dropout(0.5)(x)
# Softmax
outputs = tfkl.Dense(units=len(CLASS_NAMES), activation='softmax')(x)
model = tfk.Model(inputs, outputs)
With this setup, training fails with the following error: "ValueError: Shapes (None, 3) and (None, 1) are incompatible"
If I change label_mode to 'categorical' when loading the dataset, switch the loss from SparseCategoricalCrossentropy to CategoricalCrossentropy, and switch the accuracy metric from SparseCategoricalAccuracy to CategoricalAccuracy, it works. However, I would like to understand why I cannot use the SparseCategoricalCrossentropy loss here, and how to fix it.
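The loss function itself is not what differs between the two setups: sparse categorical cross-entropy on integer labels computes the same value as categorical cross-entropy on the one-hot version of those labels. The sparse form simply picks the predicted probability of the true class instead of multiplying by a one-hot vector. The sketch below re-implements that math in NumPy with hypothetical predictions; it is not Keras's code, and Keras's own implementation also guards against log(0), which this sketch omits.

```python
import numpy as np

def sparse_categorical_crossentropy(y_true_int, y_pred):
    """Mean of -log p[true class], with integer labels of shape (batch,)."""
    probs_of_true = y_pred[np.arange(len(y_true_int)), y_true_int]
    return float(np.mean(-np.log(probs_of_true)))

def categorical_crossentropy(y_true_onehot, y_pred):
    """Mean of -sum(y * log p), with one-hot labels of shape (batch, num_classes)."""
    return float(np.mean(-np.sum(y_true_onehot * np.log(y_pred), axis=1)))

# Hypothetical softmax outputs for 3 images and 3 classes.
y_pred = np.array([[0.7, 0.2, 0.1],
                   [0.1, 0.8, 0.1],
                   [0.3, 0.3, 0.4]])
y_int = np.array([0, 1, 2])
y_onehot = np.eye(3)[y_int]

sparse = sparse_categorical_crossentropy(y_int, y_pred)
dense = categorical_crossentropy(y_onehot, y_pred)
print(np.isclose(sparse, dense))  # True
```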
Edit: added the model compile and fit code.
def compile_model(model, plot=False):
    model.compile(
        optimizer=tf.optimizers.Adam(1e-3),
        loss=tf.losses.SparseCategoricalCrossentropy(name='loss'),
        metrics=[
            tfk.metrics.SparseCategoricalAccuracy(name='accuracy'),
            tfk.metrics.Precision(name='precision'),
            tfk.metrics.Recall(name='recall'),
        ]
    )
    model.summary()
    if plot:
        tfk.utils.plot_model(model, show_shapes=True)

def train_model(model, debug_mode=False):
    callbacks = [tfk.callbacks.EarlyStopping(patience=5, monitor='val_loss', restore_best_weights=True)]
    if debug_mode:
        callbacks.append(tfk.callbacks.ModelCheckpoint(filepath='model.{epoch:02d}-{val_loss:.2f}.h5', save_best_only=True, monitor='val_loss'))
        callbacks.append(tfk.callbacks.TensorBoard(log_dir='./tensorboard'))
    history = model.fit(
        x=train_ds,
        validation_data=val_ds,
        epochs=100,
        callbacks=callbacks,
        # steps_per_epoch=len(train_ds),
        # validation_steps=len(val_ds),
    )
    return history
Solution 1:
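tf.keras.metrics.SparseCategoricalAccuracy() is meant for sparse (integer) labels, while tf.keras.metrics.Precision() and tf.keras.metrics.Recall() are meant for categorical (one-hot) labels: they compare y_true with y_pred element by element, so y_true must have the same shape as the (None, 3) softmax output. The sparse loss is fine; it is these two metrics that raise the shape error.
If you want to use metrics without 'Sparse' in their name, such as Precision and Recall, your labels have to be one-hot encoded.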
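To see where the (None, 1) in the error comes from: with integer labels of shape (batch,) and predictions of shape (batch, 3), Keras reshapes y_true to (batch, 1) before handing it to the metrics. With its default settings (thresholds=0.5, no top_k or class_id), Precision then thresholds every predicted probability at 0.5 and compares it element-wise with y_true, which requires identical shapes. The NumPy function below is a simplified, hypothetical mirror of that behaviour (including the divide-by-zero guard returning 0.0), not TensorFlow's actual implementation.

```python
import numpy as np

def precision_like_keras(y_true, y_pred, threshold=0.5):
    """Simplified mirror of tf.keras.metrics.Precision with default settings.

    Each predicted probability is thresholded and compared element-wise
    with y_true, so both arrays must have the same shape.
    """
    if y_true.shape != y_pred.shape:
        raise ValueError(f"Shapes {y_pred.shape} and {y_true.shape} are incompatible")
    predicted_pos = y_pred > threshold
    actual_pos = y_true.astype(bool)
    true_pos = np.sum(predicted_pos & actual_pos)
    false_pos = np.sum(predicted_pos & ~actual_pos)
    denom = true_pos + false_pos
    return float(true_pos / denom) if denom else 0.0

# Hypothetical softmax outputs for 3 images and 3 classes.
y_pred = np.array([[0.7, 0.2, 0.1],
                   [0.1, 0.8, 0.1],
                   [0.3, 0.6, 0.1]])
y_int = np.array([0, 1, 2])

# Integer labels, reshaped to (batch, 1) as described above: shapes do not match.
try:
    precision_like_keras(y_int.reshape(-1, 1), y_pred)
except ValueError as err:
    print(err)  # Shapes (3, 3) and (3, 1) are incompatible

# One-hot labels have the same shape as y_pred, so the comparison works.
y_onehot = np.eye(3)[y_int]
print(precision_like_keras(y_onehot, y_pred))  # 2 true positives, 1 false positive -> 2/3
```

In TensorFlow itself, the equivalent conversion of an integer label batch is tf.one_hot(labels, depth), the same call used in the helper below.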
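Update:
import tensorflow as tf

num_class = 3  # Masked, UnMasked, Hybrid

def get_img_and_onehot_class(img_path, label):
    # `class` is a reserved word in Python, so the label argument is named `label`.
    img = tf.io.read_file(img_path)
    img = tf.io.decode_jpeg(img, channels=3)
    # Other preprocessing of the image goes here.
    return img, tf.one_hot(label, num_class)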
Once you have one-hot labels, use the categorical loss and metrics:
# Instantiate the loss; passing the class itself would not work.
loss = tf.keras.losses.CategoricalCrossentropy()
METRICS = [
    tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
    tf.keras.metrics.Precision(name='precision'),
    tf.keras.metrics.Recall(name='recall'),
]
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
    loss=loss,
    metrics=METRICS)
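Switching from SparseCategoricalAccuracy to CategoricalAccuracy does not change the reported number: the sparse version compares the argmax of the prediction with the integer label, and the categorical version compares it with the argmax of the one-hot label. The NumPy sketch below shows this with hypothetical predictions; it restates the metrics' definitions rather than calling Keras.

```python
import numpy as np

def sparse_categorical_accuracy(y_true_int, y_pred):
    # Integer labels: compare the predicted class index with the label directly.
    return float(np.mean(np.argmax(y_pred, axis=1) == y_true_int))

def categorical_accuracy(y_true_onehot, y_pred):
    # One-hot labels: recover the label index with argmax first.
    return float(np.mean(np.argmax(y_pred, axis=1) == np.argmax(y_true_onehot, axis=1)))

# Hypothetical softmax outputs for 4 images and 3 classes.
y_pred = np.array([[0.7, 0.2, 0.1],
                   [0.1, 0.8, 0.1],
                   [0.3, 0.6, 0.1],
                   [0.2, 0.2, 0.6]])
y_int = np.array([0, 1, 2, 2])
y_onehot = np.eye(3)[y_int]

print(sparse_categorical_accuracy(y_int, y_pred))  # 0.75
print(categorical_accuracy(y_onehot, y_pred))      # 0.75
```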
Ref:#72217176
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow