Displaying images from each class of a batched TensorFlow dataset
I'm working on an assignment that involves building a computer vision model with 6 different classes.
I've loaded my dataset as per this example:
https://keras.io/examples/vision/image_classification_from_scratch/
but now want to visualise the data by showing 6 examples of each of the 6 classes. I keep running into indexing errors and am not sure how to resolve this! Any help would be much appreciated. My code is as follows:
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "model-data",
    validation_split=0.2,
    subset="training",
    seed=888,
    image_size=image_size,
    batch_size=batch_size,
)
for a in range(6):
    axes = []
    fig = plt.figure()
    for images, labels in train_ds.take(-1):
        items = np.where(labels == a)[0]
        plt.title("6 Examples of " + class_names[a])
        for i in range(6):
            axes.append(fig.add_subplot(2, 3, i + 1))
            plt.imshow(images[items[i]].numpy().astype("uint8"))
    fig.tight_layout()
    plt.show()
Solution 1:[1]
The example below visualizes a dataset with five classes. It plots one image from each class; each image is the first element of a batch whose first label belongs to that class.
import random

import matplotlib.pyplot as plt
import tensorflow as tf

# Pick a random number of batches to scan (between 1 and the number of batches in train_ds)
num_batch_trainds = train_ds.cardinality().numpy()
random_batch = random.randint(1, num_batch_trainds)
print(random_batch)
plt.figure(figsize=(10, 10))
# For each class, keep only the batches whose first label is that class and plot their first image.
# Note: if no batch starts with a given class, that class is not plotted.
for i in range(len(class_names)):
    filtered_ds = train_ds.filter(lambda x, l: tf.math.equal(l[0], i))
    for image, label in filtered_ds.take(random_batch):
        # Later batches overwrite the same subplot, so one image per class remains
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(image[0].numpy().astype('uint8'))
        plt.title(class_names[label.numpy()[0]])
        plt.axis('off')
Please take a look at the colab.
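The indexing errors in the question most likely come from `items[i]`: `np.where` only returns the positions of class `a` within a single batch, and a batch can easily hold fewer than six images of that class. One way around this, shown here as a rough sketch rather than a definitive fix, is to walk the batches once and keep the first six images seen for every class before plotting anything. The helper names `collect_examples_per_class` and `plot_examples` are hypothetical and not part of TensorFlow or Keras; the sketch assumes integer labels (the default `label_mode="int"`) and batches supplied as NumPy arrays, for example via `train_ds.as_numpy_iterator()`.

```python
def collect_examples_per_class(batches, num_classes, per_class):
    """Gather up to `per_class` images for each class from an iterable of batches.

    `batches` yields (images, labels) pairs; labels are integer class indices.
    Hypothetical helper for illustration, not a TensorFlow API.
    """
    found = {c: [] for c in range(num_classes)}
    for images, labels in batches:
        for image, label in zip(images, labels):
            bucket = found[int(label)]
            if len(bucket) < per_class:
                bucket.append(image)
        # Stop scanning as soon as every class has enough examples.
        if all(len(bucket) == per_class for bucket in found.values()):
            break
    return found


def plot_examples(found, class_names, rows=2, cols=3):
    """Draw one figure per class with the collected images in a rows x cols grid."""
    # Imported lazily so the collection helper works without a plotting backend.
    import matplotlib.pyplot as plt

    for c, images in found.items():
        fig = plt.figure()
        fig.suptitle(f"{len(images)} examples of {class_names[c]}")
        for i, image in enumerate(images[: rows * cols]):
            ax = fig.add_subplot(rows, cols, i + 1)
            ax.imshow(image.astype("uint8"))
            ax.axis("off")
        fig.tight_layout()
    plt.show()


# Assumed usage with the dataset from the question:
# found = collect_examples_per_class(
#     train_ds.as_numpy_iterator(), num_classes=len(class_names), per_class=6
# )
# plot_examples(found, class_names)
```

Because the images are collected across batches, a class that is rare in any single batch still ends up with six examples as long as the training split contains at least six of them. A class with fewer images simply gets a shorter list instead of raising an `IndexError`.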
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
Solution | Source |
---|---|
Solution 1 | Tfer3 |