ClueAI committed on
Commit 6dbcbb7
1 Parent(s): 4654d38

Update app.py

Files changed (1):
  1. app.py +6 -6
app.py CHANGED
@@ -4,7 +4,7 @@ import clueai
 import torch
 from transformers import T5Tokenizer, T5ForConditionalGeneration
 tokenizer = T5Tokenizer.from_pretrained("ClueAI/ChatYuan-large-v2")
-model = T5ForConditionalGeneration.from_pretrained("ClueAI/ChatYuan-large-v2").half()
+model = T5ForConditionalGeneration.from_pretrained("ClueAI/ChatYuan-large-v2")
 # usage
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 model.to(device)
@@ -25,11 +25,11 @@ def answer(text, sample=True, top_p=0.9, temperature=0.7):
     top_p: between 0 and 1; higher values make the output more diverse'''
     text = preprocess(text)
     encoding = tokenizer(text=[text], truncation=True, padding=True, max_length=1024, return_tensors="pt").to(device)
-    # if not sample:
-    #     out = model.generate(**encoding, return_dict_in_generate=True, output_scores=False, max_new_tokens=1024, num_beams=1, length_penalty=0.6)
-    # else:
-    #     out = model.generate(**encoding, return_dict_in_generate=True, output_scores=False, max_new_tokens=1024, do_sample=True, top_p=top_p, temperature=temperature, no_repeat_ngram_size=3)
-    out = model.generate(**encoding, **generate_config)
+    if not sample:
+        out = model.generate(**encoding, return_dict_in_generate=True, output_scores=False, max_new_tokens=1024, num_beams=1, length_penalty=0.6)
+    else:
+        out = model.generate(**encoding, return_dict_in_generate=True, output_scores=False, max_new_tokens=1024, do_sample=True, top_p=top_p, temperature=temperature, no_repeat_ngram_size=3)
+    # out = model.generate(**encoding, **generate_config)
     out_text = tokenizer.batch_decode(out["sequences"], skip_special_tokens=True)
     return postprocess(out_text[0])
 
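
The first hunk's substantive change is dropping `.half()`, so the model now loads in full fp32 precision. The commit message doesn't say why, but a plausible motivation is CPU compatibility: fp16 inference is generally unsupported or very slow on CPU, which matters for a Space that may run without a GPU. A minimal sketch of a device-conditional variant (an assumption, not what this commit does):

import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Hypothetical alternative: cast to fp16 only when a CUDA device is
# available, since fp16 matmuls are poorly supported on CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = T5Tokenizer.from_pretrained("ClueAI/ChatYuan-large-v2")
model = T5ForConditionalGeneration.from_pretrained("ClueAI/ChatYuan-large-v2")
if device.type == 'cuda':
    model = model.half()  # fp16 halves GPU memory use at some precision cost
model.to(device)
model.eval()  # inference only: disables dropout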
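
The second hunk restores the two decoding branches that had been commented out in favor of a shared `generate_config`: with `sample=False` the call decodes greedily (`num_beams=1`, with `length_penalty=0.6` passed through), and with `sample=True` it uses nucleus sampling controlled by `top_p` and `temperature`, plus `no_repeat_ngram_size=3` to curb repetition. A usage sketch, assuming the rest of app.py (including the `preprocess` and `postprocess` helpers referenced in the hunk) has already run:

# Both branches pass return_dict_in_generate=True, so generate() returns
# an output object whose "sequences" field holds the generated token ids;
# that is why answer() indexes out["sequences"] before batch_decode().
prompt = "你好,请介绍一下你自己"  # "Hello, please introduce yourself"
greedy_reply = answer(prompt, sample=False)
sampled_reply = answer(prompt, sample=True, top_p=0.9, temperature=0.7)
print(sampled_reply)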