chansung commited on
Commit
d0461f7
1 Parent(s): c22fa26

Update palmapi.py

Browse files
Files changed (1) hide show
  1. palmapi.py +42 -8
palmapi.py CHANGED
@@ -17,7 +17,39 @@ if palm_api_token is None:
17
  else:
18
  palm_api.configure(api_key=palm_api_token)
19
 
20
- class GradioPaLMChatPPManager(PPManager):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  def build_uis(self, from_idx: int=0, to_idx: int=-1, fmt: UIFmt=GradioChatUIFmt):
22
  if to_idx == -1 or to_idx >= len(self.pingpongs):
23
  to_idx = len(self.pingpongs)
@@ -47,15 +79,17 @@ def gen_text(
47
  'top_p': top_p,
48
  }
49
 
50
- if palm is None:
51
- response = palm_api.chat(**parameters, messages=[prompt])
52
- else:
53
- palm.temperature = parameters['temperature']
54
- palm.top_k = parameters['top_k']
55
- palm.top_p = parameters['top_p']
56
 
57
- response = palm.reply(prompt)
58
 
 
 
59
  if len(response.filters) > 0 and \
60
  response.filters[0]['reason'] == 2:
61
  response_txt = "your request is blocked for some reasons"
 
17
  else:
18
  palm_api.configure(api_key=palm_api_token)
19
 
20
+ class PaLMChatPromptFmt(PromptFmt):
21
+ @classmethod
22
+ def ctx(cls, context):
23
+ pass
24
+
25
+ @classmethod
26
+ def prompt(cls, pingpong, truncate_size):
27
+ ping = pingpong.ping[:truncate_size]
28
+ pong = "" if pingpong.pong is None else pingpong.pong[:truncate_size]
29
+ return [
30
+ {
31
+ "author": "USER",
32
+ "content": ping
33
+ },
34
+ {
35
+ "author": "AI",
36
+ "content": pong
37
+ },
38
+ ]
39
+
40
+ class PaLMChatPPManager(PPManager):
41
+ def build_prompts(self, from_idx: int=0, to_idx: int=-1, fmt: PromptFmt=PaLMChatPromptFmt, truncate_size: int=None):
42
+ results = []
43
+
44
+ if to_idx == -1 or to_idx >= len(self.pingpongs):
45
+ to_idx = len(self.pingpongs)
46
+
47
+ for idx, pingpong in enumerate(self.pingpongs[from_idx:to_idx]):
48
+ results += fmt.prompt(pingpong, truncate_size=truncate_size)
49
+
50
+ return results
51
+
52
+ class GradioPaLMChatPPManager(PaLMChatPPManager):
53
  def build_uis(self, from_idx: int=0, to_idx: int=-1, fmt: UIFmt=GradioChatUIFmt):
54
  if to_idx == -1 or to_idx >= len(self.pingpongs):
55
  to_idx = len(self.pingpongs)
 
79
  'top_p': top_p,
80
  }
81
 
82
+ # if palm is None:
83
+ # response = palm_api.chat(**parameters, messages=[prompt])
84
+ # else:
85
+ # palm.temperature = parameters['temperature']
86
+ # palm.top_k = parameters['top_k']
87
+ # palm.top_p = parameters['top_p']
88
 
89
+ # response = palm.reply(prompt)
90
 
91
+ response = palm_api.chat(**parameters, messages=prompt)
92
+
93
  if len(response.filters) > 0 and \
94
  response.filters[0]['reason'] == 2:
95
  response_txt = "your request is blocked for some reasons"