niuyazhe commited on
Commit
d2f4f1c
1 Parent(s): 794ea91

feature(nyz): add naive Mistral-7B-Instruct-v0.1 demo

Browse files
README.md CHANGED
@@ -29,13 +29,9 @@ QUESTION_LANG=cn QUESTION_LLM='chatgpt' QUESTION_LLM_KEY=<your API key> python3
29
  ```shell
30
  QUESTION_LANG=en QUESTION_LLM='chatgpt' QUESTION_LLM_KEY=<your API key> python3 -u app.py
31
  ```
32
- ### LLaMA2-7b + 中文
33
  ```shell
34
- QUESTION_LANG=cn QUESTION_LLM='llama2-7b' python3 -u app.py
35
- ```
36
- ### LLaMA2-7b + 英文
37
- ```shell
38
- QUESTION_LANG=en QUESTION_LLM='llama2-7b' python3 -u app.py
39
  ```
40
  ## :technologist: 为什么制作这个游戏
41
 
@@ -50,9 +46,9 @@ QUESTION_LANG=en QUESTION_LLM='llama2-7b' python3 -u app.py
50
  - [x] 支持自定义关卡
51
  - [ ] 在线试玩链接
52
  - [ ] Hugging Face Space 链接
53
- - [ ] 支持LLaMA2-7B(英文)
54
- - [ ] 支持Mistral-7B(英文)
55
  - [ ] 支持Baichuan2-7B(中文)
 
56
  - [ ] LLM 推理速度优化
57
 
58
 
 
29
  ```shell
30
  QUESTION_LANG=en QUESTION_LLM='chatgpt' QUESTION_LLM_KEY=<your API key> python3 -u app.py
31
  ```
32
+ ### Mistral-7B-Instruct-v0.1 + 英文
33
  ```shell
34
+ QUESTION_LANG=en QUESTION_LLM='mistral-7b' python3 -u app.py
 
 
 
 
35
  ```
36
  ## :technologist: 为什么制作这个游戏
37
 
 
46
  - [x] 支持自定义关卡
47
  - [ ] 在线试玩链接
48
  - [ ] Hugging Face Space 链接
49
+ - [x] 支持Mistral-7B-Instruct-v0.1(英文)
 
50
  - [ ] 支持Baichuan2-7B(中文)
51
+ - [ ] 支持LLaMA2-7B(英文)
52
  - [ ] LLM 推理速度优化
53
 
54
 
app.py CHANGED
@@ -11,7 +11,7 @@ _QUESTIONS = list_ordered_questions()
11
  _LANG = os.environ.get('QUESTION_LANG', 'cn')
12
  assert _LANG in ['cn', 'en'], _LANG
13
  _LLM = os.environ.get('QUESTION_LLM', 'chatgpt')
14
- assert _LLM in ['chatgpt', 'llama2-7b'], _LLM
15
  _LLM_KEY = os.environ.get('QUESTION_LLM_KEY', None)
16
 
17
  if _LANG == "cn":
 
11
  _LANG = os.environ.get('QUESTION_LANG', 'cn')
12
  assert _LANG in ['cn', 'en'], _LANG
13
  _LLM = os.environ.get('QUESTION_LLM', 'chatgpt')
14
+ assert _LLM in ['chatgpt', 'mistral-7b'], _LLM
15
  _LLM_KEY = os.environ.get('QUESTION_LLM_KEY', None)
16
 
17
  if _LANG == "cn":
llmriddles/llms/__init__.py CHANGED
@@ -1,2 +1,3 @@
1
- from .chatgpt import ask_chatgpt
2
  from .base import register_llm, get_llm_fn
 
 
 
 
1
  from .base import register_llm, get_llm_fn
2
+ from .chatgpt import ask_chatgpt
3
+ from .mistral import ask_mistral_7b_instruct
llmriddles/llms/llm_client.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import requests
3
+ import logging
4
+ import argparse
5
+
6
+
7
class LLMFlaskClient:
    """Minimal HTTP client for the LLM Flask service in ``llm_server.py``.

    Sends user text to the ``ask_llm_for_answer`` endpoint and returns the
    model's answer, retrying transient network failures up to ``max_retry``
    times.
    """

    def __init__(self, ip: str, port: int, max_retry: int = 3):
        """
        :param ip: host/IP where the Flask LLM server listens
        :param port: TCP port of the server
        :param max_retry: number of request attempts before giving up
        """
        self.ip = ip
        self.port = port

        self.url_prefix_format = 'http://{}:{}/'
        self.url = self.url_prefix_format.format(self.ip, self.port)
        self.max_retry = max_retry

        self.logger = logging.getLogger()
        # Install the stream handler only once: the previous code appended a
        # new handler on every client instantiation, duplicating log lines.
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            handler.setFormatter(logging.Formatter("%(message)s"))
            self.logger.addHandler(handler)

    def _request(self, name: str, data: dict):
        """POST ``data`` as JSON to ``self.url + name`` with retries.

        Returns the server's ``output`` field on success; raises on a
        server-reported error (``code != 0``) or after ``max_retry``
        consecutive network failures.
        """
        for _ in range(self.max_retry):
            try:
                self.logger.info(f'{name}\ndata: {data}')
                # Bounded timeout so a hung server cannot block the caller
                # forever (the original call had no timeout at all).
                response = requests.post(self.url + name, json=data, timeout=60).json()
            except Exception as e:
                # Bug fix: logging uses lazy %-formatting; the original
                # `warning('error: ', repr(e))` never rendered the exception.
                self.logger.warning('error: %s', repr(e))
                time.sleep(1)
                continue
            if response['code'] == 0:
                return response['output']
            else:
                raise Exception(response['error_msg'])
        raise Exception("Web service failed. Please retry or contact with manager")

    def run(self, message: str) -> str:
        """Ask the LLM for an answer; returns an ``Error: ...`` string on failure."""
        try:
            return self._request('ask_llm_for_answer', {'user_text': message})
        except Exception as e:
            return f"Error: {repr(e)}"
40
+
41
+
42
if __name__ == "__main__":
    # Command-line smoke test: send one prompt to a running llm_server
    # instance and print the answer.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--ip', required=True)
    arg_parser.add_argument('-p', '--port', required=True)
    arg_parser.add_argument('--debug', action='store_true')
    cli_args = arg_parser.parse_args()

    # Verbose request logging only when --debug is given.
    log_level = logging.INFO if cli_args.debug else logging.WARNING
    logging.getLogger().setLevel(log_level)

    demo_client = LLMFlaskClient(cli_args.ip, cli_args.port)
    print(demo_client.run('Please concatenate string "1+" and "1=3". Only give me the result without "".'))
llmriddles/llms/llm_server.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer
2
+ from flask import Flask, request
3
+ import argparse
4
+ import logging
5
+
6
+
7
class LLMInstance:
    """Wraps a local HuggingFace causal LM (e.g. Mistral-7B-Instruct) for single-turn Q&A."""

    def __init__(self, model_path: str, device: str = "cuda"):
        # Load model weights and tokenizer from a local path or hub id,
        # then move the model to the requested device (GPU by default).
        self.model = AutoModelForCausalLM.from_pretrained(model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model.to(device)
        self.device = device

    def query(self, message):
        """Generate an answer for ``message``.

        Returns a JSON-serializable dict with ``code``/``ret``/``error_msg``/
        ``output`` keys so the Flask endpoint can return it directly; any
        exception is reported in-band via ``code=1`` instead of propagating.
        """
        try:
            # Single-turn chat: one user message, no history.
            messages = [
                {"role": "user", "content": message},
            ]
            # apply_chat_template wraps the message in the model's chat
            # format; for Mistral-Instruct this is "<s>[INST] ... [/INST]".
            encodeds = self.tokenizer.apply_chat_template(messages, return_tensors="pt")
            model_inputs = encodeds.to(self.device)

            # NOTE(review): do_sample=True makes answers non-deterministic
            # across identical queries — presumably intentional for a game.
            generated_ids = self.model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
            decoded = self.tokenizer.batch_decode(generated_ids)

            # output is the string decoded[0] after "[/INST]". There may exist "</s>", delete it.
            output = decoded[0].split("[/INST]")[1].split("</s>")[0]
            return {
                'code': 0,
                'ret': True,
                'error_msg': None,
                'output': output
            }
        except Exception as e:
            # Any failure (bad template, OOM, parse error) is converted to an
            # error payload so the HTTP layer always gets a valid response.
            return {
                'code': 1,
                'ret': False,
                'error_msg': str(e),
                'output': None
            }
42
+
43
+
44
def create_app(core):
    """Build the Flask app exposing ``core`` (an LLMInstance) over HTTP."""
    flask_app = Flask(__name__)

    @flask_app.route('/ask_llm_for_answer', methods=['POST'])
    def ask_llm_for_answer():
        # Forward the posted user text to the model; the returned dict is
        # serialized to JSON by Flask.
        return core.query(request.json['user_text'])

    return flask_app
53
+
54
+
55
if __name__ == "__main__":
    # Launch the Flask service that serves LLMInstance over HTTP.
    parser = argparse.ArgumentParser()
    # Help-text fix: this path points at the chat/instruct LLM, not a reward model.
    # (NOTE: required=True makes the default unreachable; kept for interface
    # compatibility.)
    parser.add_argument('-m', '--model_path', required=True, default='Mistral-7B-Instruct-v0.1', help='the model path of the instruct LLM')
    parser.add_argument('--ip', default='0.0.0.0')
    # Bug fix: command-line values arrive as str; the WSGI server needs an
    # int port, so any user-supplied --port previously failed at app.run().
    parser.add_argument('-p', '--port', type=int, default=8001)
    parser.add_argument('--debug', action='store_true')
    args = parser.parse_args()

    if args.debug:
        logging.getLogger().setLevel(logging.DEBUG)
    else:
        logging.getLogger().setLevel(logging.INFO)
    # Plain-message console logging for request traces.
    logging.getLogger().addHandler(logging.StreamHandler())
    logging.getLogger().handlers[0].setFormatter(logging.Formatter("%(message)s"))

    core = LLMInstance(args.model_path)
    app = create_app(core)
    app.run(host=args.ip, port=args.port)
llmriddles/llms/mistral.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import lru_cache
2
+
3
+ from .base import register_llm
4
+ from .llm_client import LLMFlaskClient
5
+
6
+
7
@lru_cache()
def _get_mistral_7b_instruct_server(host: str, port: int):
    """Start the local Mistral Flask server exactly once (memoized via lru_cache).

    Bug fix: ``app.run`` blocks forever, so the original inline call would
    hang the first question and never return to the caller.  The server is
    now started in a daemon thread, and this function returns once the port
    accepts connections (bounded wait — model loading can take minutes).
    """
    import socket
    import threading
    import time

    def _serve():
        from .llm_server import LLMInstance, create_app
        core = LLMInstance('Mistral-7B-Instruct-v0.1')
        app = create_app(core)
        app.run(host=host, port=port)

    # Daemon thread so the server dies with the main process.
    threading.Thread(target=_serve, daemon=True).start()

    # 0.0.0.0 is a bind address, not a connect address — probe loopback.
    probe_host = '127.0.0.1' if host == '0.0.0.0' else host
    for _ in range(600):
        try:
            with socket.create_connection((probe_host, port), timeout=1):
                break
        except OSError:
            time.sleep(1)
13
+
14
+
15
def ask_mistral_7b_instruct(message: str, **kwargs):
    """Answer ``message`` with the local Mistral-7B-Instruct service.

    Brings the Flask server up on first use (memoized by the helper), then
    proxies the question through an HTTP client and strips whitespace from
    the reply.
    """
    server_host, server_port = '0.0.0.0', 8001
    _get_mistral_7b_instruct_server(server_host, server_port)
    answer = LLMFlaskClient(server_host, server_port).run(message)
    return answer.strip()


# Make this backend selectable via QUESTION_LLM='mistral-7b'.
register_llm('mistral-7b', ask_mistral_7b_instruct)
requirements.txt CHANGED
@@ -2,4 +2,6 @@ hbutils>=0.9.1
2
  tqdm
3
  requests>=2.20
4
  gradio==4.1.1
5
- openai>=1
 
 
 
2
  tqdm
3
  requests>=2.20
4
  gradio==4.1.1
5
+ openai>=1
6
+ flask
7
+ transformers