Displaying images from each class of a batched TensorFlow dataset

I'm working on an assignment that involves building a computer vision (CV) model with 6 classes.

I've loaded my dataset as per this example:

https://keras.io/examples/vision/image_classification_from_scratch/

Now I want to visualise the data by showing 6 examples of each of the 6 classes, but I keep running into indexing errors and am not sure how to resolve them. Any help would be much appreciated. My code is as follows:

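```python
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

# image_size and batch_size are set earlier, as in the Keras example
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "model-data",
    validation_split=0.2,
    subset="training",
    seed=888,
    image_size=image_size,
    batch_size=batch_size,
)
class_names = train_ds.class_names  # subdirectory names, in label order

for a in range(6):
    axes = []
    fig = plt.figure()
    for images, labels in train_ds.take(-1):  # take(-1) iterates over every batch
        # Positions of class `a` within the current batch only
        items = np.where(labels == a)[0]
        plt.title("6 Examples of " + class_names[a])

        for i in range(6):
            axes.append(fig.add_subplot(2, 3, i + 1))
            # IndexError whenever this batch holds fewer than 6 images of class `a`
            plt.imshow(images[items[i]].numpy().astype("uint8"))
        fig.tight_layout()
        plt.show()
```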
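The IndexError comes from reading `items[i]` for every `i` in `range(6)` inside each batch: `np.where` only looks at the current batch, which may contain fewer than six images of class `a`, or none at all. One way around this, sketched below, is to accumulate images per class across batches and stop as soon as every class has enough. It assumes each batch can be turned into NumPy arrays (for example by iterating `train_ds.as_numpy_iterator()`); `collect_examples` is a hypothetical helper name, not part of TensorFlow or Keras.

```python
import numpy as np


def collect_examples(batches, num_classes, per_class):
    """Gather up to `per_class` images for each class from an iterable of
    (images, labels) batches, stopping early once every class is full."""
    picked = {c: [] for c in range(num_classes)}
    for images, labels in batches:
        images = np.asarray(images)
        labels = np.asarray(labels)
        for c in range(num_classes):
            need = per_class - len(picked[c])
            if need <= 0:
                continue  # this class already has enough examples
            # Indices of class c in *this* batch; there may be fewer than `need`
            idx = np.where(labels == c)[0][:need]
            picked[c].extend(images[j] for j in idx)
        if all(len(v) >= per_class for v in picked.values()):
            break  # no need to read the remaining batches
    return picked
```

For the question's data this might be called as `collect_examples(train_ds.as_numpy_iterator(), num_classes=6, per_class=6)`. A class with fewer than six images in the whole training split ends up with a shorter list instead of raising an error.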


Solution 1

In the example below, I visualize a dataset with five classes, plotting one image from each class.

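```python
import random

import matplotlib.pyplot as plt
import tensorflow as tf

# Assumes train_ds from the question; its class names index its integer labels
class_names = train_ds.class_names

# Pick a random count in [1, number of batches] (random.randint is inclusive)
num_batch_trainds = train_ds.cardinality().numpy()
random_batch = random.randint(1, num_batch_trainds)
print(random_batch)

plt.figure(figsize=(10, 10))
# For each class, keep only the batches whose *first* label is that class and
# plot the first image of each kept batch. Every kept batch redraws the same
# subplot, so the image shown comes from the random_batch-th match (or the
# last match if there are fewer). Classes with no match are left blank.
for i in range(len(class_names)):  # 3x3 grid: room for up to 9 classes
    filtered_ds = train_ds.filter(lambda x, l: tf.math.equal(l[0], i))
    for image, label in filtered_ds.take(random_batch):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(image[0].numpy().astype("uint8"))
        plt.title(class_names[label.numpy()[0]])
        plt.axis("off")
plt.show()
```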
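The loop above draws one image per class into a 3x3 grid. For the layout the question asks for (six examples of each of six classes), a grid with one row per class may be easier to read. The sketch below assumes `examples` is a dict mapping each class index to a list of HxWx3 image arrays with 0-255 values; `plot_class_grid` is a hypothetical helper, and it returns the Figure rather than calling `plt.show()` so the caller decides when to display or save it.

```python
import matplotlib.pyplot as plt
import numpy as np


def plot_class_grid(examples, class_names, per_class):
    """Plot up to `per_class` images per class, one row per class."""
    num_classes = len(class_names)
    fig, axes = plt.subplots(
        num_classes,
        per_class,
        figsize=(2 * per_class, 2 * num_classes),
        squeeze=False,  # always return a 2-D array of Axes
    )
    for c in range(num_classes):
        images = examples.get(c, [])
        for j in range(per_class):
            ax = axes[c, j]
            ax.axis("off")  # hide ticks; empty cells stay blank
            if j < len(images):
                ax.imshow(np.asarray(images[j]).astype("uint8"))
            if j == 0:
                # Label each row once, above its first cell
                ax.set_title(class_names[c])
    fig.tight_layout()
    return fig
```

With the question's data, one might build `examples` from `train_ds` and then call `plot_class_grid(examples, train_ds.class_names, per_class=6)` followed by `plt.show()`. Classes with fewer images than `per_class` simply leave the rest of their row empty.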

Please take a look at the Colab notebook.

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution 1: Tfer3