Converting a tf.data.Dataset to a PyTorch Dataset?

I'm working on a project where all the data comes preprocessed as a TensorFlow dataset, which looks like this:

<MapDataset shapes: {input_ids: (128,), input_mask: (128,), label_ids: (), segment_ids: (128,)}, types: {input_ids: tf.int64, input_mask: tf.int64, label_ids: tf.int64, segment_ids: tf.int64}>

The script that I have is in PyTorch and takes in a Dataset object which looks like this:

Dataset({
    features: ['attention_mask', 'input_ids', 'label', 'sentence', 'token_type_ids'],
    num_rows: 12
})

Is there any way to convert one to the other? I'm quite new to both of these APIs, so I'm not too sure how they work. Could I convert one to the other using a dict?

Thank you



Solution 1:[1]

I use tfds.as_numpy(dataset) as the dataloader for my model training. To convert the data passed to my model, I use torch.as_tensor(data, device=<device>) inside my model's forward function.

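import tensorflow_datasets as tfds
import torch
import torch.nn as nn

def train_dataloader(batch_size):
    # split='train' makes tfds.load return a single tf.data.Dataset
    # rather than a dict of splits, so .batch() can be called on it.
    return tfds.as_numpy(tfds.load('mnist', split='train').batch(batch_size))

class Model(nn.Module):
    def forward(self, x):
        # Convert the incoming NumPy batch to a tensor on the GPU.
        x = torch.as_tensor(x, device='cuda')
        ...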
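
If the training script expects a PyTorch dataset with the column names from the question, one option is to wrap the NumPy iterator in a torch.utils.data.IterableDataset and rename the keys on the fly. The following is only a minimal sketch under some assumptions: the KEY_MAP pairing of TensorFlow feature names to the target column names is a guess based on the two printouts in the question, NumpyDictDataset is a hypothetical helper name, and fake_examples is a stand-in for tfds.as_numpy(tf_dataset) so that the snippet runs without TensorFlow installed.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, IterableDataset

# Assumed mapping from the TF feature names to the column names the
# PyTorch script expects (inferred from the question, not confirmed).
KEY_MAP = {
    "input_ids": "input_ids",
    "input_mask": "attention_mask",
    "segment_ids": "token_type_ids",
    "label_ids": "label",
}


class NumpyDictDataset(IterableDataset):
    """Hypothetical wrapper: iterable of NumPy dicts -> dicts of tensors."""

    def __init__(self, make_iterable, key_map=KEY_MAP):
        # make_iterable is a zero-argument callable so every epoch gets a
        # fresh iterator, e.g. lambda: tfds.as_numpy(tf_dataset).
        self.make_iterable = make_iterable
        self.key_map = key_map

    def __iter__(self):
        for example in self.make_iterable():
            # Rename each key and convert its value to a torch tensor.
            yield {
                new: torch.as_tensor(np.asarray(example[old]))
                for old, new in self.key_map.items()
            }


def fake_examples(n=4, seq_len=128):
    """Stand-in for tfds.as_numpy(tf_dataset): yields dicts of NumPy arrays."""
    for i in range(n):
        yield {
            "input_ids": np.full((seq_len,), i, dtype=np.int64),
            "input_mask": np.ones((seq_len,), dtype=np.int64),
            "segment_ids": np.zeros((seq_len,), dtype=np.int64),
            "label_ids": np.array(i % 2, dtype=np.int64),
        }


dataset = NumpyDictDataset(fake_examples)
loader = DataLoader(dataset, batch_size=2)  # default collate stacks each key
batch = next(iter(loader))
```

With a real TensorFlow dataset, fake_examples would be replaced by something like lambda: tfds.as_numpy(tf_dataset), leaving the TensorFlow dataset unbatched so that the DataLoader handles batching. Because the data is streamed, the wrapper has no length and cannot be shuffled by the DataLoader; any shuffling would have to happen on the TensorFlow side.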

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 Jaideep Heer