Use of PyTorch Dataset for model inference on GPU
I am running T5-base-grammar-correction for grammar correction on the text column of my dataframe:
from happytransformer import HappyTextToText
from happytransformer import TTSettings
from tqdm.notebook import tqdm
tqdm.pandas()
happy_tt = HappyTextToText("T5", "./t5-base-grammar-correction")
beam_settings = TTSettings(num_beams=5, min_length=1, max_length=30)
def grammer_pipeline(text):
    text = "gec: " + text
    result = happy_tt.generate_text(text, args=beam_settings)
    return result.text
df['new_text'] = df['original_text'].progress_apply(grammer_pipeline)
The pandas apply call runs and gives the required results, but it is quite slow.
I also get the warning below when executing the code:
/home/.local/lib/python3.6/site-packages/transformers/pipelines/base.py:908: UserWarning: You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset
UserWarning,
I have access to a GPU. Can somebody provide some pointers on how to speed up execution and use the full capabilities of the GPU?
--------------------------------EDIT---------------------------------
I tried using a PyTorch Dataset as shown below, but processing is still slow:
from torch.utils.data import Dataset, DataLoader

class GramData(Dataset):
    def __init__(self, text):
        self.text = text
        self.len = text.shape[0]
    def __len__(self):
        return self.len
    def __getitem__(self, idx):
        text = self.text[idx]
        text = "gec: " + text
        # the model is still called on one string at a time here
        result = happy_tt.generate_text(text, args=beam_settings)
        return result.text

TD = GramData(df.original_text)
final_data = DataLoader(dataset=TD,
                        batch_size=10,
                        shuffle=False
                        )
list_modified = []
for (idx, batch) in enumerate(final_data):
    list_modified.append(batch)
flat_list = [item for sublist in list_modified for item in sublist]
df["new_text"] = flat_list
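A likely reason the Dataset version is no faster: `__getitem__` still calls `generate_text` on a single string, so the DataLoader only groups outputs that were already generated one by one, and the GPU never sees a real batch. Below is a minimal sketch of batching at the generation step instead. It assumes the checkpoint in `./t5-base-grammar-correction` can be loaded directly with the transformers `AutoTokenizer` and `AutoModelForSeq2SeqLM` classes. The helper names `chunks`, `correct_in_batches`, and `make_t5_generate_batch`, and the batch size of 16, are hypothetical and not part of happytransformer.

```python
from typing import Callable, Iterator, List, Sequence


def chunks(items: Sequence, batch_size: int) -> Iterator[Sequence]:
    """Yield consecutive slices of `items` with at most `batch_size` elements."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def correct_in_batches(
    texts: Sequence[str],
    generate_batch: Callable[[List[str]], List[str]],
    batch_size: int = 16,
    prefix: str = "gec: ",
) -> List[str]:
    """Add the task prefix and send whole batches to `generate_batch`."""
    results: List[str] = []
    for batch in chunks(texts, batch_size):
        prompts = [prefix + text for text in batch]
        results.extend(generate_batch(prompts))
    return results


def make_t5_generate_batch(
    model_path: str = "./t5-base-grammar-correction",
    device: str = "cuda",
) -> Callable[[List[str]], List[str]]:
    """Hypothetical helper: build a batched generate function on the GPU.

    Assumes the checkpoint loads with plain transformers classes.
    Imports are local so the batching helpers above work without a model.
    """
    import torch
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to(device).eval()

    def generate_batch(prompts: List[str]) -> List[str]:
        # Pad to the longest prompt so the batch is one tensor on the GPU.
        enc = tokenizer(
            prompts, return_tensors="pt", padding=True, truncation=True
        ).to(device)
        with torch.no_grad():
            # Same beam settings as the TTSettings in the question.
            out = model.generate(**enc, num_beams=5, min_length=1, max_length=30)
        return tokenizer.batch_decode(out, skip_special_tokens=True)

    return generate_batch
```

Under those assumptions, usage would look like `df["new_text"] = correct_in_batches(df["original_text"].tolist(), make_t5_generate_batch())`. The batch size is a tuning knob: larger values keep the GPU busier but need more memory, especially with `num_beams=5`.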
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow