
shunxing1234 committed
Commit 7d07c9a
1 Parent(s): 4b56aa6

Update README_zh.md

Files changed (1)
  1. README_zh.md +3 -24
README_zh.md CHANGED
@@ -47,40 +47,19 @@ AquilaChat-7B v0.8 in the FlagEval large model evaluation ("subjective + objective")
 ```python
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
-
-device = torch.device("cuda:1")
-
+device = torch.device("cuda")
 model_info = "BAAI/AquilaChat-7B"
 tokenizer = AutoTokenizer.from_pretrained(model_info, trust_remote_code=True)
 model = AutoModelForCausalLM.from_pretrained(model_info, trust_remote_code=True)
 model.eval()
 model.to(device)
-
 text = "请给出10个要到北京旅游的理由。"
-
 tokens = tokenizer.encode_plus(text)['input_ids'][:-1]
-
 tokens = torch.tensor(tokens)[None,].to(device)
-
-
+stop_tokens = ["###", "[UNK]", "</s>"]
 with torch.no_grad():
-    out = model.generate(tokens, do_sample=True, max_length=512, eos_token_id=100007)[0]
-
+    out = model.generate(tokens, do_sample=True, max_length=512, eos_token_id=100007, bad_words_ids=[[tokenizer.encode(token)[0] for token in stop_tokens]])[0]
     out = tokenizer.decode(out.cpu().numpy().tolist())
-    if "###" in out:
-        special_index = out.index("###")
-        out = out[: special_index]
-
-    if "[UNK]" in out:
-        special_index = out.index("[UNK]")
-        out = out[:special_index]
-
-    if "</s>" in out:
-        special_index = out.index("</s>")
-        out = out[: special_index]
-
-    if len(out) > 0 and out[0] == " ":
-        out = out[1:]
     print(out)
 ```
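
The edit replaces post-decode string truncation with generation-time filtering through `bad_words_ids`. One subtlety of that `transformers` parameter is worth noting: each inner list is a token-id sequence banned as a whole, so `[[id_a, id_b, id_c]]` forbids only that exact three-token run, not the three ids individually. A minimal sketch of the per-token variant, assuming the same `tokenizer`, `model`, and `tokens` as in the snippet above (the prompt asks for ten reasons to visit Beijing):

```python
# Sketch: ban each stop token independently, one single-id inner list per token.
# Caveat: a string that tokenizes into several ids (e.g. "###") is only
# approximated by its first id here, mirroring the encode(...)[0] trick above.
stop_tokens = ["###", "[UNK]", "</s>"]
bad_words_ids = [[tokenizer.encode(token)[0]] for token in stop_tokens]

with torch.no_grad():
    out = model.generate(tokens, do_sample=True, max_length=512,
                         eos_token_id=100007, bad_words_ids=bad_words_ids)[0]
    print(tokenizer.decode(out.cpu().numpy().tolist()))
```

Banning ids only steers sampling away from the stop strings, whereas the removed code cut the decoded text at their first occurrence; the two are not strictly equivalent, so keeping a string-level fallback is defensible. The removed if-chains fold into one loop (`truncate_at` is a hypothetical name):

```python
def truncate_at(text, stop_tokens=("###", "[UNK]", "</s>")):
    # Hypothetical helper: cut at the first occurrence of each stop token,
    # then drop a single leading space, exactly as the removed code did.
    for token in stop_tokens:
        index = text.find(token)
        if index != -1:
            text = text[:index]
    return text[1:] if text.startswith(" ") else text
```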