Use different metrics in tf.keras.metrics for a multi-class classification model
I am using the TensorFlow Federated framework for a multi-class classification problem. I am following the tutorials, and most of them use tf.keras.metrics.SparseCategoricalAccuracy to measure the model's accuracy.
I wanted to explore other measures such as AUC, recall, F1, and precision, but I am getting errors.
The code and the error message are provided below.
import tensorflow as tf
import tensorflow_federated as tff

def create_keras_model():
    initializer = tf.keras.initializers.Zeros()
    return tf.keras.models.Sequential([
        tf.keras.layers.Input(shape=(8,)),
        tf.keras.layers.Dense(64),
        tf.keras.layers.Dense(4, kernel_initializer=initializer),
        tf.keras.layers.Softmax(),
    ])

def model_fn():
    keras_model = create_keras_model()
    # train_data: list of client tf.data.Datasets, defined elsewhere
    return tff.learning.from_keras_model(
        keras_model,
        input_spec=train_data[0].element_spec,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy(),
                 tf.keras.metrics.Recall()]
    )
The error
ValueError: Shapes (None, 4) and (None,) are incompatible
Is it because this is a multi-class problem that I cannot use these metrics? If so, is there another metric I can use to evaluate my multi-class classification model?
Solution 1:[1]
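tf.keras.metrics.SparseCategoricalAccuracy() is meant for sparse labels: one integer class index per sample, shape (batch,). tf.keras.metrics.Recall() is meant for categorical (one-hot) labels with the same shape as the predictions, here (batch, 4), so feeding it integer labels produces the shape error above.
You have to use one-hot labels if you want to use the metrics whose names do not include 'Sparse'.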
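A minimal NumPy sketch of why the shapes clash and what one-hot encoding changes. The label and probability values are made up for illustration. `elementwise_recall` is a hypothetical helper that mimics, as far as we understand it, how `tf.keras.metrics.Recall` scores one-hot labels at its default 0.5 threshold (every sample/class entry counted as a separate binary prediction); it is not the Keras implementation.

```python
import numpy as np

NUM_CLASSES = 4  # matches the 4-unit softmax output in the question

# Sparse (integer) labels: shape (batch,)
y_true_sparse = np.array([0, 2, 1, 3])
# Softmax outputs of the model: shape (batch, NUM_CLASSES); made-up values
y_pred = np.array([
    [0.70, 0.10, 0.10, 0.10],
    [0.20, 0.10, 0.60, 0.10],
    [0.30, 0.40, 0.20, 0.10],
    [0.10, 0.10, 0.20, 0.60],
])

# (4,) vs (4, 4): the same kind of mismatch as (None,) vs (None, 4)
print(y_true_sparse.shape, y_pred.shape)

# One-hot encode: row i is all zeros except a 1 in column y_true_sparse[i]
y_true_onehot = np.eye(NUM_CLASSES)[y_true_sparse]
print(y_true_onehot.shape)  # now the same shape as y_pred


def elementwise_recall(y_true, y_pred, threshold=0.5):
    """Hypothetical helper: recall over all (sample, class) entries as binary predictions."""
    predicted_pos = y_pred > threshold
    actual_pos = y_true == 1
    tp = np.sum(predicted_pos & actual_pos)
    fn = np.sum(~predicted_pos & actual_pos)
    return tp / (tp + fn)


print(elementwise_recall(y_true_onehot, y_pred))
```

Here the third sample's true class only gets probability 0.40, so it is the one missed entry and recall comes out at 3 / 4.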
Update: for example, the labels can be one-hot encoded when the data are loaded:
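import tensorflow as tf

num_class = 4

def get_img_and_onehot_class(img_path, label):
    # `class` is a reserved word in Python, so the label argument is named `label`.
    img = tf.io.read_file(img_path)
    img = tf.io.decode_jpeg(img, channels=3)
    # Other preprocessing of the image goes here.
    return img, tf.one_hot(label, num_class)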
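The question's data are 8 tabular features rather than images, so the same idea would apply to each client's `tf.data.Dataset` through `Dataset.map`. A minimal sketch, assuming each element is a `(features, label)` tuple with an integer label; the toy dataset and the `to_one_hot` name are hypothetical.

```python
import tensorflow as tf

NUM_CLASSES = 4


def to_one_hot(features, label):
    # Assumes (features, integer_label) elements; adjust the signature if the
    # dataset yields a dict such as OrderedDict(x=..., y=...) instead.
    return features, tf.one_hot(tf.cast(label, tf.int32), NUM_CLASSES)


# Hypothetical toy client dataset: 3 samples with 8 features each, as in the
# question's Input(shape=(8,)), and integer labels in [0, 4).
client_ds = tf.data.Dataset.from_tensor_slices(
    (tf.zeros([3, 8]), tf.constant([0, 2, 3]))
)
client_ds = client_ds.map(to_one_hot).batch(3)

features, labels = next(iter(client_ds))
print(labels.shape)  # (3, 4)
```

With labels mapped this way, the question's `model_fn` would also need `tf.keras.losses.CategoricalCrossentropy()` and `tf.keras.metrics.CategoricalAccuracy()` instead of the sparse versions, and `input_spec` should come from a mapped dataset so that the label spec reflects the one-hot shape.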
Once you have one-hot labels, switch to the categorical loss and metrics:
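import tensorflow as tf

# Labels are one-hot vectors, so use the non-sparse loss (instantiated) and metrics.
loss = tf.keras.losses.CategoricalCrossentropy()
METRICS = [
    tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
    tf.keras.metrics.Precision(name='precision'),
    tf.keras.metrics.Recall(name='recall'),
]
# `model` is your Keras model.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
    loss=loss,
    metrics=METRICS)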
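On the question's follow-up about F1 and other multi-class metrics: precision, recall, and F1 are defined per class (one class against the rest) and are then averaged. The plain NumPy sketch below shows macro averaging independently of TensorFlow; the label arrays and the `per_class_prf` helper are hypothetical. Within Keras, `tf.keras.metrics.Precision` and `tf.keras.metrics.Recall` also accept a `class_id` argument for tracking a single class when labels are one-hot.

```python
import numpy as np


def per_class_prf(y_true, y_pred, num_classes):
    """Hypothetical helper: one-vs-rest precision, recall and F1 per class."""
    precision, recall, f1 = [], [], []
    for c in range(num_classes):
        tp = np.sum((y_pred == c) & (y_true == c))
        fp = np.sum((y_pred == c) & (y_true != c))
        fn = np.sum((y_pred != c) & (y_true == c))
        p = tp / (tp + fp) if tp + fp > 0 else 0.0  # 0 if class c is never predicted
        r = tp / (tp + fn) if tp + fn > 0 else 0.0  # 0 if class c never occurs
        precision.append(p)
        recall.append(r)
        f1.append(2 * p * r / (p + r) if p + r > 0 else 0.0)
    return np.array(precision), np.array(recall), np.array(f1)


# Hypothetical sparse labels and predictions (argmax of the softmax output).
y_true = np.array([0, 0, 1, 1, 2, 3])
y_pred = np.array([0, 1, 1, 1, 2, 2])

p, r, f1 = per_class_prf(y_true, y_pred, num_classes=4)
# Macro averaging: unweighted mean over the classes.
print("macro precision:", p.mean())
print("macro recall:   ", r.mean())
print("macro F1:       ", f1.mean())
```

Macro averaging gives every class equal weight, so a class the model never predicts (class 3 here) pulls the averages down even if it is rare.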
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
Solution | Source |
---|---|
Solution 1 |