How do I solve "KeyError: 8" while using PyTorch?

import torch
from torch.utils.data import (TensorDataset, DataLoader, RandomSampler,
                              SequentialSampler)

def data_loader(train_inputs, val_inputs, train_labels, val_labels, batch_size=50):
    """
       Convert train and validation sets to torch.Tensors and load them to DataLoader.
    """

    # Convert data type to torch.Tensor
    train_inputs, val_inputs, train_labels, val_labels =\
    tuple(torch.tensor(data) for data in
        [train_inputs, val_inputs, train_labels, val_labels])

    # Specify batch_size
    batch_size = 50

    # Create DataLoader for training data
    train_data = TensorDataset(train_inputs, train_labels)
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, 
        batch_size=batch_size)

    # Create DataLoader for validation data
    val_data = TensorDataset(val_inputs, val_labels)
    val_sampler = SequentialSampler(val_data)
    val_dataloader = DataLoader(val_data, sampler=val_sampler, batch_size=batch_size)

    return train_dataloader, val_dataloader

The code works fine when the train_inputs and val_inputs tensors are of type int64, but fails when the type is int32.

Can someone tell me what's wrong here?

ERROR:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
File ~\Anaconda3\lib\site-packages\pandas\core\indexes\base.py:3621, in Index.get_loc(self, key, method, tolerance)
   3620 try:
-> 3621     return self._engine.get_loc(casted_key)
   3622 except KeyError as err:

File ~\Anaconda3\lib\site-packages\pandas\_libs\index.pyx:136, in pandas._libs.index.IndexEngine.get_loc()

File ~\Anaconda3\lib\site-packages\pandas\_libs\index.pyx:163, in pandas._libs.index.IndexEngine.get_loc()

File pandas\_libs\hashtable_class_helper.pxi:2131, in pandas._libs.hashtable.Int64HashTable.get_item()

File pandas\_libs\hashtable_class_helper.pxi:2140, in pandas._libs.hashtable.Int64HashTable.get_item()

KeyError: 8

The above exception was the direct cause of the following exception:

KeyError                                  Traceback (most recent call last)
Input In [31], in <cell line: 6>()
      2 train_inputs, val_inputs, train_labels, val_labels = train_test_split(
      3     input_ids, labels, test_size=0.1, random_state=42)
      5 # Load data to PyTorch DataLoader
----> 6 train_dataloader, val_dataloader = data_loader(train_inputs, val_inputs, train_labels, val_labels, batch_size=50)

Input In [28], in data_loader(train_inputs, val_inputs, train_labels, val_labels, batch_size)
      6 """Convert train and validation sets to torch.Tensors and load them to
      7 DataLoader.
      8 """
     10 # Convert data type to torch.Tensor
     11 train_inputs, val_inputs, train_labels, val_labels =\
---> 12 tuple(torch.tensor(data) for data in
     13       [train_inputs, val_inputs, train_labels, val_labels])
     15 # Specify batch_size
     16 batch_size = 50

Input In [28], in <genexpr>(.0)
      6 """Convert train and validation sets to torch.Tensors and load them to
      7 DataLoader.
      8 """
     10 # Convert data type to torch.Tensor
     11 train_inputs, val_inputs, train_labels, val_labels =\
---> 12 tuple(torch.tensor(data) for data in
     13       [train_inputs, val_inputs, train_labels, val_labels])
     15 # Specify batch_size
     16 batch_size = 50

File ~\Anaconda3\lib\site-packages\pandas\core\series.py:958, in Series.__getitem__(self, key)
    955     return self._values[key]
    957 elif key_is_scalar:
--> 958     return self._get_value(key)
    960 if is_hashable(key):
    961     # Otherwise index.get_value will raise InvalidIndexError
    962     try:
    963         # For labels that don't resolve as scalars like tuples and frozensets

File ~\Anaconda3\lib\site-packages\pandas\core\series.py:1069, in Series._get_value(self, label, takeable)
   1066     return self._values[label]
   1068 # Similar to Index.get_value, but we do not fall back to positional
-> 1069 loc = self.index.get_loc(label)
   1070 return self.index._get_values_for_loc(self, loc, label)

File ~\Anaconda3\lib\site-packages\pandas\core\indexes\base.py:3623, in Index.get_loc(self, key, method, tolerance)
   3621     return self._engine.get_loc(casted_key)
   3622 except KeyError as err:
-> 3623     raise KeyError(key) from err
   3624 except TypeError:
   3625     # If we have a listlike key, _check_indexing_error will raise
   3626     #  InvalidIndexError. Otherwise we fall through and re-raise
   3627     #  the TypeError.
   3628     self._check_indexing_error(key)

KeyError: 8


Solution 1:[1]

I was using the same code on my dataset and ran into the same issue. I changed two things: I set random_state to a value other than 42 (which probably wasn't what fixed it), and I converted my labels to an np.array. Now it works.
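The traceback in the question hints at why converting to np.array could help: the torch.tensor(data) call ends up in Series.__getitem__ with a scalar key, which pandas resolves by index label rather than by position. After train_test_split, a Series keeps its original index, which is now shuffled and has gaps, so a label such as 8 may simply not be there. The sketch below illustrates that behaviour under assumptions: as_array is a hypothetical helper name, and the labels Series is made-up data, not the asker's.

```python
import numpy as np
import pandas as pd


def as_array(data):
    """Return the values of a Series, list, or array as a NumPy array.

    Hypothetical helper: np.asarray drops any pandas index, so the
    result is addressed purely by position.
    """
    return np.asarray(data)


# Made-up labels whose index has gaps, as it would after a shuffled split.
labels = pd.Series([1, 0, 1], index=[12, 3, 40])

try:
    labels[0]  # label-based lookup: no row carries the label 0
except KeyError as err:
    print("KeyError:", err)

# Positional access works once the index is gone.
label_array = as_array(labels)
print(label_array[0])
```

The resulting arrays could then be passed to torch.tensor inside data_loader in place of the original Series.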

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 granolagirl