chansung committed on
Commit
602e36b
1 Parent(s): a254af1

Update llama2.py

Browse files
Files changed (1) hide show
  1. llama2.py +40 -108
llama2.py CHANGED
@@ -2,6 +2,7 @@ import os
2
  import json
3
  import requests
4
  import sseclient
 
5
 
6
  from pingpong import PingPong
7
  from pingpong.pingpong import PPManager
@@ -9,36 +10,14 @@ from pingpong.pingpong import PromptFmt
9
  from pingpong.pingpong import UIFmt
10
  from pingpong.gradio import GradioChatUIFmt
11
 
12
- class LLaMA2ChatPromptFmt(PromptFmt):
13
- @classmethod
14
- def ctx(cls, context):
15
- if context is None or context == "":
16
- return ""
17
- else:
18
- return f"""<<SYS>>
19
- {context}
20
- <</SYS>>
21
- """
22
 
23
- @classmethod
24
- def prompt(cls, pingpong, truncate_size):
25
- ping = pingpong.ping[:truncate_size]
26
- pong = "" if pingpong.pong is None else pingpong.pong[:truncate_size]
27
- return f"""[INST] {ping} [/INST] {pong}"""
28
 
29
- class LLaMA2ChatPPManager(PPManager):
30
- def build_prompts(self, from_idx: int=0, to_idx: int=-1, fmt: PromptFmt=LLaMA2ChatPromptFmt, truncate_size: int=None):
31
- if to_idx == -1 or to_idx >= len(self.pingpongs):
32
- to_idx = len(self.pingpongs)
33
-
34
- results = fmt.ctx(self.ctx)
35
-
36
- for idx, pingpong in enumerate(self.pingpongs[from_idx:to_idx]):
37
- results += fmt.prompt(pingpong, truncate_size=truncate_size)
38
-
39
- return results
40
-
41
- class GradioLLaMA2ChatPPManager(LLaMA2ChatPPManager):
42
  def build_uis(self, from_idx: int=0, to_idx: int=-1, fmt: UIFmt=GradioChatUIFmt):
43
  if to_idx == -1 or to_idx >= len(self.pingpongs):
44
  to_idx = len(self.pingpongs)
@@ -48,86 +27,39 @@ class GradioLLaMA2ChatPPManager(LLaMA2ChatPPManager):
48
  for pingpong in self.pingpongs[from_idx:to_idx]:
49
  results.append(fmt.ui(pingpong))
50
 
51
- return results
52
 
53
- async def gen_text(
54
- prompt,
55
- hf_model='meta-llama/Llama-2-70b-chat-hf',
56
- hf_token=None,
57
  parameters=None
58
  ):
59
- if hf_token is None:
60
- raise ValueError("Hugging Face Token is not set")
61
-
62
- if parameters is None:
63
- parameters = {
64
- 'max_new_tokens': 512,
65
- 'do_sample': True,
66
- 'return_full_text': False,
67
- 'temperature': 1.0,
68
- 'top_k': 50,
69
- # 'top_p': 1.0,
70
- 'repetition_penalty': 1.2
71
- }
72
-
73
- url = f'https://api-inference.huggingface.co/models/{hf_model}'
74
- headers={
75
- 'Authorization': f'Bearer {hf_token}',
76
- 'Content-type': 'application/json'
77
- }
78
- data = {
79
- 'inputs': prompt,
80
- 'stream': True,
81
- 'options': {
82
- 'use_cache': False,
83
- },
84
- 'parameters': parameters
85
- }
86
-
87
- r = requests.post(
88
- url,
89
- headers=headers,
90
- data=json.dumps(data),
91
- stream=True
92
- )
93
-
94
- client = sseclient.SSEClient(r)
95
- for event in client.events():
96
- yield json.loads(event.data)['token']['text']
97
-
98
- def gen_text_none_stream(
99
- prompt,
100
- hf_model='meta-llama/Llama-2-70b-chat-hf',
101
- hf_token=None,
102
- ):
103
- parameters = {
104
- 'max_new_tokens': 64,
105
- 'do_sample': True,
106
- 'return_full_text': False,
107
- 'temperature': 0.7,
108
- 'top_k': 10,
109
- # 'top_p': 1.0,
110
- 'repetition_penalty': 1.2
111
- }
112
-
113
- url = f'https://api-inference.huggingface.co/models/{hf_model}'
114
- headers={
115
- 'Authorization': f'Bearer {hf_token}',
116
- 'Content-type': 'application/json'
117
- }
118
- data = {
119
- 'inputs': prompt,
120
- 'stream': False,
121
- 'options': {
122
- 'use_cache': False,
123
- },
124
- 'parameters': parameters
125
- }
126
-
127
- r = requests.post(
128
- url,
129
- headers=headers,
130
- data=json.dumps(data),
131
- )
132
-
133
- return json.loads(r.text)[0]["generated_text"]
 
2
  import json
3
  import requests
4
  import sseclient
5
+ import google.generativeai as palm_api
6
 
7
  from pingpong import PingPong
8
  from pingpong.pingpong import PPManager
 
10
  from pingpong.pingpong import UIFmt
11
  from pingpong.gradio import GradioChatUIFmt
12
 
 
 
 
 
 
 
 
 
 
 
13
 
14
# Read the PaLM API key from the environment and configure the client at
# import time; importing this module raises when the variable is absent.
palm_api_token = os.getenv("PALM_API_TOKEN")
if palm_api_token is None:
    raise ValueError("PaLM API Token is not set")

palm_api.configure(api_key=palm_api_token)
19
 
20
+ class GradioPaLMChatPPManager(PPManager):
 
 
 
 
 
 
 
 
 
 
 
 
21
  def build_uis(self, from_idx: int=0, to_idx: int=-1, fmt: UIFmt=GradioChatUIFmt):
22
  if to_idx == -1 or to_idx >= len(self.pingpongs):
23
  to_idx = len(self.pingpongs)
 
27
  for pingpong in self.pingpongs[from_idx:to_idx]:
28
  results.append(fmt.ui(pingpong))
29
 
30
+ return results
31
 
32
def gen_text(
    prompt,
    palm,
    parameters=None
):
    """Send ``prompt`` to PaLM chat and return the response plus its text.

    Args:
        prompt: Message to send.
        palm: Existing conversation object exposing ``reply(prompt)``. When
            ``None``, a new conversation is started via ``palm_api.chat``.
        parameters: Optional dict of keyword arguments forwarded to
            ``palm_api.chat`` when a new conversation is started. When
            ``None``, the defaults below are used. It is not used on the
            ``palm.reply`` path, which is given only the message.

    Returns:
        A ``(response, response_txt)`` tuple. ``response_txt`` is
        ``response.last``, or a fixed "blocked" message when the first
        entry of ``response.filters`` has ``reason == 2``.
    """
    if parameters is None:
        # NOTE(review): the previous default was 'models/text-bison-001',
        # which is a text-generation model; the chat endpoint's documented
        # default is 'models/chat-bison-001' - confirm against the model list.
        parameters = {
            'model': 'models/chat-bison-001',
            'temperature': 0.7,
            'candidate_count': 1,
            'top_k': 40,
            'top_p': 0.95,
        }

    if palm is None:
        # Forward the generation settings; previously they were built and
        # then silently dropped, so the call always ran with library defaults.
        response = palm_api.chat(messages=[prompt], **parameters)
    else:
        response = palm.reply(prompt)

    filters = response.filters
    # Only the first filter entry is inspected, as in the original logic.
    # NOTE(review): reason code 2 presumably maps to a "blocked: other"
    # enum value in the PaLM API - confirm.
    if len(filters) > 0 and filters[0]['reason'] == 2:
        response_txt = "your request is blocked for some reasons"
    else:
        response_txt = response.last

    return response, response_txt
65
+