ZhangCheng commited on
Commit
c868079
1 Parent(s): f307b88

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +9 -9
README.md CHANGED
@@ -26,31 +26,31 @@ trained_tokenizer_path = 'ZhangCheng/T5-Base-Fine-Tuned-for-Question-Generation'
26
 
27
  class QuestionGeneration:
28
 
29
- def __init__(self):
30
  self.model = T5ForConditionalGeneration.from_pretrained(trained_model_path)
31
  self.tokenizer = T5Tokenizer.from_pretrained(trained_tokenizer_path)
32
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
33
  self.model = self.model.to(self.device)
34
  self.model.eval()
35
 
36
- def generate(self, answer:str, context:str):
37
  input_text = '<answer> %s <context> %s ' % (answer, context)
38
  encoding = self.tokenizer.encode_plus(
39
  input_text,
40
  return_tensors='pt'
41
  )
42
- input_ids = encoding['input_ids'].to(self.device)
43
- attention_mask = encoding['attention_mask'].to(self.device)
44
  outputs = self.model.generate(
45
- input_ids = input_ids,
46
- attention_mask = attention_mask
47
  )
48
  question = self.tokenizer.decode(
49
  outputs[0],
50
- skip_special_tokens = True,
51
- clean_up_tokenization_spaces = True
52
  )
53
- return {'question': question, 'answer': answer}
54
 
55
  if __name__ == "__main__":
56
  context = 'ZhangCheng fine-tuned T5 on SQuAD dataset for question generation.'
 
26
 
27
  class QuestionGeneration:
28
 
29
+ def __init__(self, model_dir=None):
30
  self.model = T5ForConditionalGeneration.from_pretrained(trained_model_path)
31
  self.tokenizer = T5Tokenizer.from_pretrained(trained_tokenizer_path)
32
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
33
  self.model = self.model.to(self.device)
34
  self.model.eval()
35
 
36
+ def generate(self, answer: str, context: str):
37
  input_text = '<answer> %s <context> %s ' % (answer, context)
38
  encoding = self.tokenizer.encode_plus(
39
  input_text,
40
  return_tensors='pt'
41
  )
42
+ input_ids = encoding['input_ids']
43
+ attention_mask = encoding['attention_mask']
44
  outputs = self.model.generate(
45
+ input_ids=input_ids,
46
+ attention_mask=attention_mask
47
  )
48
  question = self.tokenizer.decode(
49
  outputs[0],
50
+ skip_special_tokens=True,
51
+ clean_up_tokenization_spaces=True
52
  )
53
+ return {'question': question, 'answer': answer, 'context': context}
54
 
55
  if __name__ == "__main__":
56
  context = 'ZhangCheng fine-tuned T5 on SQuAD dataset for question generation.'