HansBug committed
Commit 08cbdf8
1 Parent(s): 9abc992

dev(hansbug): fix this

app.py CHANGED
@@ -10,6 +10,18 @@ _QUESTIONS = list_ordered_questions()
 _LANG = os.environ.get('QUESTION_LANG', 'cn')
 _LLM = os.environ.get('QUESTION_LLM', 'chatgpt')
 
+
+def _need_api_key():
+    return _LLM == 'chatgpt'
+
+
+def _get_api_key_cfgs(api_key):
+    if _LLM == 'chatgpt':
+        return {'api_key': api_key}
+    else:
+        return {}
+
+
 if __name__ == '__main__':
     with gr.Blocks() as demo:
         with gr.Row():
@@ -20,6 +32,8 @@ if __name__ == '__main__':
                 gr_submit = gr.Button('Submit', interactive=False)
 
             with gr.Column():
+                gr_api_key = gr.Text(placeholder='Your API Key', label='API Key', type='password',
+                                     visible=_need_api_key())
                 gr_predict = gr.Label(label='Correctness')
                 gr_explanation = gr.TextArea(label='Explanation')
                 gr_next = gr.Button('Next')
@@ -47,8 +61,15 @@ if __name__ == '__main__':
         )
 
 
-        def _submit_answer(qs_text: str):
-            executor = QuestionExecutor(_QUESTIONS[_QUESTION_ID], _LANG)
+        def _submit_answer(qs_text: str, api_key: str):
+            if _need_api_key() and not api_key:
+                return '---', {}, 'Please Enter API Key Before Submitting Question.', \
+                    gr.Button('Next', interactive=False)
+
+            executor = QuestionExecutor(
+                _QUESTIONS[_QUESTION_ID], _LANG,
+                llm=_LLM, llm_cfgs=_get_api_key_cfgs(api_key) if _need_api_key() else {}
+            )
             answer_text, correctness, explanation = executor.check(qs_text)
             labels = {'Correct': 1.0} if correctness else {'Wrong': 1.0}
             if correctness:
@@ -59,7 +80,7 @@ if __name__ == '__main__':
 
         gr_submit.click(
             _submit_answer,
-            inputs=[gr_question],
+            inputs=[gr_question, gr_api_key],
             outputs=[gr_answer, gr_predict, gr_explanation, gr_next],
         )
 
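Sketch (not part of the commit): the same path Submit now takes, driven from a script instead of the UI. The `llmriddles.questions` import location for `list_ordered_questions` and the `OPENAI_API_KEY` environment variable are assumptions used only for illustration.

import os

# Assumed import paths; app.py's own imports are not shown in this diff.
from llmriddles.questions import list_ordered_questions
from llmriddles.questions.executor import QuestionExecutor

_LLM = os.environ.get('QUESTION_LLM', 'chatgpt')
api_key = os.environ.get('OPENAI_API_KEY', '')  # placeholder key source

# Mirror _submit_answer: build the per-request LLM config only when a key is needed.
executor = QuestionExecutor(
    list_ordered_questions()[0], os.environ.get('QUESTION_LANG', 'cn'),
    llm=_LLM, llm_cfgs={'api_key': api_key} if _LLM == 'chatgpt' else {},
)
answer_text, correctness, explanation = executor.check('your answer prompt here')
print(correctness, explanation)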
llmriddles/llms/base.py CHANGED
@@ -1,11 +1,11 @@
 from typing import Callable, Dict
 
-_LLMS: Dict[str, Callable[[str], str]] = {}
+_LLMS: Dict[str, Callable] = {}
 
 
-def register_llm(name: str, llm_ask_fn: Callable[[str], str]):
+def register_llm(name: str, llm_ask_fn: Callable):
     _LLMS[name] = llm_ask_fn
 
 
-def get_llm_fn(name: str) -> Callable[[str], str]:
+def get_llm_fn(name: str) -> Callable:
     return _LLMS[name]
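The type hints are widened from `Callable[[str], str]` to bare `Callable` because backends may now take extra keyword arguments such as `api_key`. A self-contained sketch of the registry contract, using a hypothetical `_echo_llm` backend for illustration:

from typing import Callable, Dict

_LLMS: Dict[str, Callable] = {}


def register_llm(name: str, llm_ask_fn: Callable):
    # Register a backend under a short name; extra kwargs are allowed now.
    _LLMS[name] = llm_ask_fn


def get_llm_fn(name: str) -> Callable:
    return _LLMS[name]


def _echo_llm(message: str, api_key: str = '') -> str:
    # Hypothetical backend, for illustration only.
    return f'echo(key_given={bool(api_key)}): {message}'


register_llm('echo', _echo_llm)
print(get_llm_fn('echo')('hello', api_key='sk-test'))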
llmriddles/llms/chatgpt.py CHANGED
@@ -1,31 +1,25 @@
-import os
 from functools import lru_cache
 
-import openai
+from openai import OpenAI
 
 from .base import register_llm
 
 
 @lru_cache()
-def _setup_openai():
-    current_path = os.path.dirname(os.path.realpath(__file__))
-    parent_dir = os.path.dirname(current_path)
-    if 'OPENAI_KEY' in os.environ:
-        openai.api_key = os.environ['OPENAI_KEY']
-    else:
-        openai.api_key_path = f'{parent_dir}/.key'
+def _get_openai_client(api_key):
+    return OpenAI(api_key=api_key)
 
 
-def ask_chatgpt(message: str):
-    _setup_openai()
+def ask_chatgpt(message: str, api_key: str):
+    client = _get_openai_client(api_key)
 
-    response = openai.ChatCompletion.create(
+    response = client.chat.completions.create(
         model="gpt-3.5-turbo",
         messages=[
             {"role": "user", "content": message}
         ],
     )
-    return response["choices"][0]["message"]["content"].strip()
+    return response.choices[0].message.content.strip()
 
 
 register_llm('chatgpt', ask_chatgpt)
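Usage sketch for the rewritten backend (requires openai>=1, a valid key, and network access; `OPENAI_API_KEY` as the key source is an assumption). Because `_get_openai_client` is wrapped in `@lru_cache()`, one `OpenAI` client is built per distinct key string.

import os

from llmriddles.llms.chatgpt import ask_chatgpt

# Placeholder key source; the app itself passes the key typed into the UI.
print(ask_chatgpt('Say hello in one word.',
                  api_key=os.environ.get('OPENAI_API_KEY', '')))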
llmriddles/questions/executor.py CHANGED
@@ -5,17 +5,18 @@ from ..llms import get_llm_fn
 
 
 class QuestionExecutor:
-    def __init__(self, question: Question, lang: str = 'cn', llm: str = 'chatgpt'):
+    def __init__(self, question: Question, lang: str = 'cn', llm: str = 'chatgpt', llm_cfgs=None):
         self.question = question
         self.lang = lang
         self.llm = llm
+        self.llm_cfgs = dict(llm_cfgs or {})
 
     @property
     def question_text(self):
         return self.question.texts[self.lang]
 
     def check(self, qs_text: str) -> Tuple[str, bool, str]:
-        answer_text = get_llm_fn(self.llm)(qs_text)
+        answer_text = get_llm_fn(self.llm)(qs_text, **self.llm_cfgs)
         correct, explanation = self.check_answer(answer_text)
         return answer_text, correct, explanation
 
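How `llm_cfgs` travels: whatever dict the constructor receives is forwarded as keyword arguments to the registered LLM function inside `check()`. A sketch using a hypothetical `echo` backend so it runs without any API key; the `list_ordered_questions` import path is an assumption.

from llmriddles.llms.base import register_llm
from llmriddles.questions import list_ordered_questions
from llmriddles.questions.executor import QuestionExecutor

# Hypothetical backend: echoes the prompt plus whatever cfg it was given.
register_llm('echo', lambda message, tag='': f'[{tag}] {message}')

executor = QuestionExecutor(list_ordered_questions()[0], 'cn',
                            llm='echo', llm_cfgs={'tag': 'demo'})
answer_text, correct, explanation = executor.check('hello')
print(answer_text)  # -> '[demo] hello'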
requirements.txt CHANGED
@@ -2,4 +2,4 @@ hbutils>=0.9.1
 tqdm
 requests>=2.20
 gradio==4.1.1
-openai<1
+openai>=1
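The pin flips from the pre-1.0 SDK to the 1.x client because chatgpt.py now imports `OpenAI` directly. A quick environment sanity check (sketch):

import openai
from openai import OpenAI  # ImportError here means openai<1 is still installed

print(openai.__version__)  # expected to be >= 1.0.0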