How do I solve a "KeyError: 8" while using PyTorch?
    import torch
    from torch.utils.data import (TensorDataset, DataLoader, RandomSampler,
                                  SequentialSampler)

    def data_loader(train_inputs, val_inputs, train_labels, val_labels, batch_size=50):
        """
        Convert train and validation sets to torch.Tensors and load them to DataLoader.
        """
        # Convert data type to torch.Tensor
        train_inputs, val_inputs, train_labels, val_labels =\
            tuple(torch.tensor(data) for data in
                  [train_inputs, val_inputs, train_labels, val_labels])

        # Specify batch_size (note: this overrides the batch_size argument)
        batch_size = 50

        # Create DataLoader for training data
        train_data = TensorDataset(train_inputs, train_labels)
        train_sampler = RandomSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler,
                                      batch_size=batch_size)

        # Create DataLoader for validation data
        val_data = TensorDataset(val_inputs, val_labels)
        val_sampler = SequentialSampler(val_data)
        val_dataloader = DataLoader(val_data, sampler=val_sampler, batch_size=batch_size)

        return train_dataloader, val_dataloader
The code works fine when the train_inputs and val_inputs tensors are of type int64, but fails when the type is int32.
Can someone tell me what's wrong here?
ERROR:
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
File ~\Anaconda3\lib\site-packages\pandas\core\indexes\base.py:3621, in Index.get_loc(self, key, method, tolerance)
3620 try:
-> 3621 return self._engine.get_loc(casted_key)
3622 except KeyError as err:
File ~\Anaconda3\lib\site-packages\pandas\_libs\index.pyx:136, in pandas._libs.index.IndexEngine.get_loc()
File ~\Anaconda3\lib\site-packages\pandas\_libs\index.pyx:163, in pandas._libs.index.IndexEngine.get_loc()
File pandas\_libs\hashtable_class_helper.pxi:2131, in pandas._libs.hashtable.Int64HashTable.get_item()
File pandas\_libs\hashtable_class_helper.pxi:2140, in pandas._libs.hashtable.Int64HashTable.get_item()
KeyError: 8
The above exception was the direct cause of the following exception:
KeyError Traceback (most recent call last)
Input In [31], in <cell line: 6>()
2 train_inputs, val_inputs, train_labels, val_labels = train_test_split(
3 input_ids, labels, test_size=0.1, random_state=42)
5 # Load data to PyTorch DataLoader
----> 6 train_dataloader, val_dataloader = data_loader(train_inputs, val_inputs, train_labels, val_labels, batch_size=50)
Input In [28], in data_loader(train_inputs, val_inputs, train_labels, val_labels, batch_size)
6 """Convert train and validation sets to torch.Tensors and load them to
7 DataLoader.
8 """
10 # Convert data type to torch.Tensor
11 train_inputs, val_inputs, train_labels, val_labels =\
---> 12 tuple(torch.tensor(data) for data in
13 [train_inputs, val_inputs, train_labels, val_labels])
15 # Specify batch_size
16 batch_size = 50
Input In [28], in <genexpr>(.0)
6 """Convert train and validation sets to torch.Tensors and load them to
7 DataLoader.
8 """
10 # Convert data type to torch.Tensor
11 train_inputs, val_inputs, train_labels, val_labels =\
---> 12 tuple(torch.tensor(data) for data in
13 [train_inputs, val_inputs, train_labels, val_labels])
15 # Specify batch_size
16 batch_size = 50
File ~\Anaconda3\lib\site-packages\pandas\core\series.py:958, in Series.__getitem__(self, key)
955 return self._values[key]
957 elif key_is_scalar:
--> 958 return self._get_value(key)
960 if is_hashable(key):
961 # Otherwise index.get_value will raise InvalidIndexError
962 try:
963 # For labels that don't resolve as scalars like tuples and frozensets
File ~\Anaconda3\lib\site-packages\pandas\core\series.py:1069, in Series._get_value(self, label, takeable)
1066 return self._values[label]
1068 # Similar to Index.get_value, but we do not fall back to positional
-> 1069 loc = self.index.get_loc(label)
1070 return self.index._get_values_for_loc(self, loc, label)
File ~\Anaconda3\lib\site-packages\pandas\core\indexes\base.py:3623, in Index.get_loc(self, key, method, tolerance)
3621 return self._engine.get_loc(casted_key)
3622 except KeyError as err:
-> 3623 raise KeyError(key) from err
3624 except TypeError:
3625 # If we have a listlike key, _check_indexing_error will raise
3626 # InvalidIndexError. Otherwise we fall through and re-raise
3627 # the TypeError.
3628 self._check_indexing_error(key)
KeyError: 8
Solution 1:[1]
I was using the same code on my data set and had the same issue. I did two things: I changed random_state to a value other than 42 (which probably wasn't what fixed it), and I converted my labels to an np.array. Now it works.
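A possible explanation for why converting to an array helps, consistent with the traceback ending in Series.__getitem__: if the labels are a pandas Series, the split leaves each part with a shuffled integer index that has gaps, and an integer key is then looked up as a label rather than a position. The sketch below only illustrates that behaviour; the `labels` values and index are made up, and it assumes the failing argument was a Series, which the question does not confirm.

```python
import pandas as pd

# Hypothetical labels; the index mimics what a shuffled split leaves behind.
labels = pd.Series([0, 1, 1, 0], index=[3, 0, 2, 5])

# Integer keys on an integer index are looked up by label, not by position.
# Label 1 is not in the index, so this raises KeyError.
try:
    labels[1]
    lookup_failed = False
except KeyError:
    lookup_failed = True

# Converting to a NumPy array drops the index, so access becomes positional.
labels_array = labels.to_numpy()
second_label = int(labels_array[1])
```

Under that assumption, passing `labels_array` (rather than the Series) to torch.tensor avoids the label lookups entirely, which matches what the answer reports.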
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
Solution | Source |
---|---|
Solution 1 | granolagirl |