File size: 2,274 Bytes
aeae383 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
from typing import List, Dict, Tuple
from transformers import T5TokenizerFast, T5ForConditionalGeneration
import string
from typing import List
# Constants
MODEL_NAME = 't5-small'  # base checkpoint used to build the tokenizer
SOURCE_MAX_TOKEN_LEN = 300  # max encoded length of "answer <sep> context" input
TARGET_MAX_TOKEN_LEN = 80  # max length of the generated output sequence
SEP_TOKEN = '<sep>'  # custom separator between answer and question/context
TOKENIZER_LEN = 32101  # expected len(tokenizer) after adding SEP_TOKEN — unused in this chunk; verify against model embedding size
class QuestionAnswerGenerator:
    """Generate question/answer pairs from a context passage with a
    fine-tuned T5 seq2seq model.

    The model is expected to emit text of the form
    ``"<answer> <sep> <question>"``; the public helpers split on
    ``SEP_TOKEN`` to recover the parts.
    """

    def __init__(self):
        # The tokenizer must know the custom separator so it is kept as a
        # single token rather than split into sub-word pieces.
        self.tokenizer = T5TokenizerFast.from_pretrained(MODEL_NAME)
        self.tokenizer.add_tokens(SEP_TOKEN)
        self.tokenizer_len = len(self.tokenizer)
        # NOTE(review): the checkpoint is assumed to already have its token
        # embeddings resized to TOKENIZER_LEN — confirm upstream; nothing
        # here calls resize_token_embeddings().
        self.model = T5ForConditionalGeneration.from_pretrained("fahmiaziz/QAModel")

    def generate(self, answer: str, context: str) -> str:
        """Return a question (from *context*) whose answer is *answer*.

        Previously this hard-unpacked ``split(SEP_TOKEN)`` into exactly two
        names, raising ``ValueError`` whenever the model output contained no
        separator (or more than one). It now degrades gracefully, mirroring
        ``generate_qna``.
        """
        model_output = self._model_predict(answer, context)
        # maxsplit=1 tolerates extra separators; a missing separator means
        # the whole output is treated as the question.
        parts = model_output.split(SEP_TOKEN, 1)
        return parts[1] if len(parts) > 1 else parts[0]

    def generate_qna(self, context: str) -> Tuple[str, str]:
        """Generate an ``(answer, question)`` pair from *context*.

        A ``'[MASK]'`` placeholder asks the model to choose its own answer
        span; the output is split on ``SEP_TOKEN`` into answer and question.
        """
        answer_mask = '[MASK]'
        model_output = self._model_predict(answer_mask, context)
        parts = model_output.split(SEP_TOKEN, 1)
        if len(parts) < 2:
            # Malformed output without a separator: treat the whole text as
            # the question and leave the answer empty.
            return '', parts[0]
        return parts[0], parts[1]

    def _model_predict(self, answer: str, context: str) -> str:
        """Run the model on ``"<answer> <sep> <context>"`` and return the
        decoded generation as a single string."""
        source_encoding = self.tokenizer(
            f'{answer} {SEP_TOKEN} {context}',
            max_length=SOURCE_MAX_TOKEN_LEN,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors='pt',
        )
        generated_ids = self.model.generate(
            input_ids=source_encoding['input_ids'],
            attention_mask=source_encoding['attention_mask'],
            num_beams=16,
            max_length=TARGET_MAX_TOKEN_LEN,
            repetition_penalty=2.5,
            length_penalty=1.0,
            early_stopping=True,
            use_cache=True,
        )
        # Decode in generation order. The original folded the decoded
        # strings into a *set* before joining, so with more than one
        # returned sequence the concatenation order was nondeterministic.
        return ''.join(
            self.tokenizer.decode(
                ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )
            for ids in generated_ids
        )
|