Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -4,7 +4,7 @@ import clueai
|
|
4 |
import torch
|
5 |
from transformers import T5Tokenizer, T5ForConditionalGeneration
|
6 |
tokenizer = T5Tokenizer.from_pretrained("ClueAI/ChatYuan-large-v2")
|
7 |
-
model = T5ForConditionalGeneration.from_pretrained("ClueAI/ChatYuan-large-v2")
|
8 |
# 使用
|
9 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
10 |
model.to(device)
|
@@ -25,11 +25,11 @@ def answer(text, sample=True, top_p=0.9, temperature=0.7):
|
|
25 |
top_p:0-1之间,生成的内容越多样'''
|
26 |
text = preprocess(text)
|
27 |
encoding = tokenizer(text=[text], truncation=True, padding=True, max_length=1024, return_tensors="pt").to(device)
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
out=model.generate(**encoding, **generate_config)
|
33 |
out_text = tokenizer.batch_decode(out["sequences"], skip_special_tokens=True)
|
34 |
return postprocess(out_text[0])
|
35 |
|
|
|
4 |
import torch
|
5 |
from transformers import T5Tokenizer, T5ForConditionalGeneration
|
6 |
tokenizer = T5Tokenizer.from_pretrained("ClueAI/ChatYuan-large-v2")
|
7 |
+
model = T5ForConditionalGeneration.from_pretrained("ClueAI/ChatYuan-large-v2")
|
8 |
# 使用
|
9 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
10 |
model.to(device)
|
|
|
25 |
top_p:0-1之间,生成的内容越多样'''
|
26 |
text = preprocess(text)
|
27 |
encoding = tokenizer(text=[text], truncation=True, padding=True, max_length=1024, return_tensors="pt").to(device)
|
28 |
+
if not sample:
|
29 |
+
out = model.generate(**encoding, return_dict_in_generate=True, output_scores=False, max_new_tokens=1024, num_beams=1, length_penalty=0.6)
|
30 |
+
else:
|
31 |
+
out = model.generate(**encoding, return_dict_in_generate=True, output_scores=False, max_new_tokens=1024, do_sample=True, top_p=top_p, temperature=temperature, no_repeat_ngram_size=3)
|
32 |
+
#out=model.generate(**encoding, **generate_config)
|
33 |
out_text = tokenizer.batch_decode(out["sequences"], skip_special_tokens=True)
|
34 |
return postprocess(out_text[0])
|
35 |
|