T5ForSQG | T5 Search Query Generation
Collection
6 items
•
Updated
class T5ForSQG:
    """Generate search queries for a topic with a T5 sequence-to-sequence model.

    Holds a ``T5ForConditionalGeneration`` model and a ``T5Tokenizer``, both
    loaded from the same path.
    """

    def __init__(self, model_path):
        """Load the model and its tokenizer from ``model_path``."""
        self.model = T5ForConditionalGeneration.from_pretrained(model_path)
        self.tokenizer = T5Tokenizer.from_pretrained(model_path)

    def make_queries(self, topic, n=1, device='cpu', batch_size=16):
        """Sample queries for ``topic`` and return the distinct ones.

        The prompt ``'make queries: ' + topic`` is repeated ``n`` times and
        decoded with sampling, so up to ``n`` different strings can come back;
        duplicates are dropped.

        Args:
            topic: Text appended to the ``'make queries: '`` prompt prefix.
            n: Number of samples to draw. Values <= 0 yield an empty list.
            device: Torch device the model and the input tensors are moved to.
            batch_size: Upper bound on the number of prompts per generation
                batch.

        Returns:
            A list of unique decoded strings, in first-seen order. It may hold
            fewer than ``n`` items when sampling repeats itself.
        """
        if n <= 0:
            # Nothing requested; also avoids building a DataLoader with
            # batch_size=0.
            return []

        frame = pd.DataFrame(
            {
                'topic': ['make queries: ' + topic] * n,
                # One empty placeholder target per row. The original `[[]*n]`
                # evaluated to `[[]]` (length 1), which does not match the
                # `n` rows of the 'topic' column.
                'queries': [[] for _ in range(n)],
            },
            index=range(n),
        )
        # NOTE(review): the two 64s are presumably the source/target max token
        # lengths expected by YourDataSetClass -- confirm against its
        # definition.
        ds = YourDataSetClass(frame, self.tokenizer, 64, 64, 'topic', 'queries')
        loader = DataLoader(
            ds,
            batch_size=min(n, batch_size),
            shuffle=False,
            num_workers=0,
        )

        # The inputs below are moved to `device`; the model has to live on the
        # same device or generate() fails with a device mismatch.
        self.model.to(device)
        self.model.eval()

        # Dict keys act as an insertion-ordered set of decoded strings.
        unique_queries = {}
        with torch.no_grad():
            for data in loader:
                ids = data['source_ids'].to(device, dtype=torch.long)
                mask = data['source_mask'].to(device, dtype=torch.long)
                generated_ids = self.model.generate(
                    input_ids=ids,
                    attention_mask=mask,
                    max_length=64,
                    num_beams=1,
                    repetition_penalty=2.5,
                    length_penalty=1.0,
                    do_sample=True,
                    temperature=1.5,
                    top_k=10,
                    top_p=0.95,
                )
                for sequence in generated_ids:
                    text = self.tokenizer.decode(
                        sequence,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=True,
                    )
                    unique_queries.setdefault(text, None)
        return list(unique_queries)