Displaying images loaded with PyTorch DataLoader

I am working with some lidar data images that I cannot post here due to a reputation restriction on posting images. Read directly with GDAL, they display as expected. However, when I load the same images using PyTorch's ImageFolder and DataLoader, with conversion to tensors as the only transform, the displayed images look as if they had been heavily thresholded, and I cannot locate the cause.

Below is how I display the image when reading it directly with GDAL:

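from osgeo import gdal
import matplotlib.pyplot as plt

# image_path: path to one raster file (called `dir` originally, which shadows the builtin)
dataset = gdal.Open(image_path)

print(dataset.RasterCount)  # number of bands in the file
img = dataset.GetRasterBand(1).ReadAsArray()  # first band as a NumPy array

plt.figure()
plt.imshow(img)
print(img.shape)
plt.show()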

and here is how I am using the data loader and displaying the thresholded image:

import os
import torch
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

# The only transform is the conversion to a tensor
data_transforms = {
    'train': transforms.Compose([
        transforms.ToTensor(),
    ]),
    'val': transforms.Compose([
        transforms.ToTensor(),
    ]),
}

image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

dataloders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                             batch_size=1,
                                             shuffle=True,
                                             num_workers=2)
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

# Each batch is an (inputs, labels) pair; inputs has shape (1, C, H, W)
for inputs, labels in dataloders["train"]:
    plt.figure()
    print(inputs.shape)
    plt.imshow(inputs.squeeze()[0, :, :])  # first channel of the only image
    plt.show()
    break

Any help on an alternative way to display the images or any mistakes I am making would be greatly appreciated.
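
One possible cause worth ruling out (an assumption, since the files themselves are not shown): by default, torchvision's ImageFolder opens each file with PIL and converts it to 8-bit RGB. For 16-bit or floating-point single-band rasters, that conversion can clip values above 255 instead of rescaling them. The NumPy-only sketch below uses a made-up ramp as a stand-in for a lidar band and shows how clipping saturates most pixels while min-max rescaling keeps the gradient. The function names are illustrative and not part of any library.

```python
import numpy as np


def clip_to_uint8(band):
    """Mimic an 8-bit conversion that clips values instead of rescaling them."""
    return np.clip(band, 0, 255).astype(np.uint8)


def min_max_to_uint8(band):
    """Rescale the full value range of the band into 0..255."""
    band = band.astype(np.float64)
    lo, hi = band.min(), band.max()
    if hi == lo:
        # A constant band carries no contrast; return all zeros.
        return np.zeros(band.shape, dtype=np.uint8)
    return np.round((band - lo) / (hi - lo) * 255).astype(np.uint8)


# Hypothetical 16-bit band: a smooth ramp from 0 to 4000 (not real lidar data).
band = np.linspace(0, 4000, 64 * 64).reshape(64, 64).astype(np.uint16)

clipped = clip_to_uint8(band)
rescaled = min_max_to_uint8(band)

# Fraction of pixels stuck at the maximum value after each conversion.
saturated_fraction = float(np.mean(clipped == 255))
rescaled_saturated_fraction = float(np.mean(rescaled == 255))
print(saturated_fraction, rescaled_saturated_fraction)
```

If this turns out to be the cause, ImageFolder accepts a `loader` argument, so a custom function that reads the band (for example with GDAL) and rescales it could replace the default loader.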



Solution 1:[1]

If you want to visualize images loaded by a DataLoader, I suggest this script:

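import numpy as np
import matplotlib.pyplot as plt

# train_data_loader: your training DataLoader
for batch in train_data_loader:
    inputs, targets = batch
    for img in inputs:
        image = img.cpu().numpy()
        # reorder (C, H, W) -> (H, W, C), the layout plt.imshow expects
        image = image.transpose(1, 2, 0)
        # normalise each channel to [0, 1]
        data_min = np.min(image, axis=(0, 1), keepdims=True)
        data_max = np.max(image, axis=(0, 1), keepdims=True)
        # guard against division by zero for constant channels
        scaled_data = (image - data_min) / np.maximum(data_max - data_min, 1e-12)
        # show image
        plt.imshow(scaled_data)
        plt.show()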
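
The transpose-and-normalise step can also be wrapped in a small reusable helper. The following is a minimal NumPy-only sketch. The name `to_displayable` is hypothetical, and it assumes the input is a single image in (C, H, W) layout, which is what `ToTensor` produces. It scales each channel to [0, 1] independently, leaves constant channels at zero, and drops the channel axis for single-channel images so that `plt.imshow` can apply a colormap.

```python
import numpy as np


def to_displayable(chw):
    """Convert a (C, H, W) array to (H, W, C), or (H, W) if C == 1, scaled to [0, 1]."""
    arr = np.asarray(chw, dtype=np.float64)
    # Move the channel axis last: (C, H, W) -> (H, W, C).
    hwc = np.transpose(arr, (1, 2, 0))
    # Per-channel minimum and maximum over the two spatial axes.
    lo = hwc.min(axis=(0, 1), keepdims=True)
    hi = hwc.max(axis=(0, 1), keepdims=True)
    # Use a span of 1 for constant channels to avoid dividing by zero.
    span = np.where(hi > lo, hi - lo, 1.0)
    scaled = (hwc - lo) / span
    # A single-channel image is returned as (H, W).
    return scaled[:, :, 0] if scaled.shape[2] == 1 else scaled


# Hypothetical 3-channel, 4x5 image whose last channel is constant.
chw = np.arange(3 * 4 * 5, dtype=np.float32).reshape(3, 4, 5)
chw[2] = 7.0
img = to_displayable(chw)
print(img.shape)
```

With a tensor taken from a batch, the call would look like `plt.imshow(to_displayable(img.cpu().numpy()))`.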

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 bart-khalid