import os import json import requests import sseclient import google.generativeai as palm_api from pingpong import PingPong from pingpong.pingpong import PPManager from pingpong.pingpong import PromptFmt from pingpong.pingpong import UIFmt from pingpong.gradio import GradioChatUIFmt palm_api_token = os.getenv("PALM_API_TOKEN") if palm_api_token is None: raise ValueError("PaLM API Token is not set") else: palm_api.configure(api_key=palm_api_token) class PaLMChatPromptFmt(PromptFmt): @classmethod def ctx(cls, context): pass @classmethod def prompt(cls, pingpong, truncate_size): ping = pingpong.ping[:truncate_size] pong = pingpong.pong if pong is None: return [ { "author": "USER", "content": ping }, ] else: pong = pong[:truncate_size] return [ { "author": "USER", "content": ping }, { "author": "AI", "content": pong }, ] class PaLMChatPPManager(PPManager): def build_prompts(self, from_idx: int=0, to_idx: int=-1, fmt: PromptFmt=PaLMChatPromptFmt, truncate_size: int=None): results = [] if to_idx == -1 or to_idx >= len(self.pingpongs): to_idx = len(self.pingpongs) for idx, pingpong in enumerate(self.pingpongs[from_idx:to_idx]): results += fmt.prompt(pingpong, truncate_size=truncate_size) return results class GradioPaLMChatPPManager(PaLMChatPPManager): def build_uis(self, from_idx: int=0, to_idx: int=-1, fmt: UIFmt=GradioChatUIFmt): if to_idx == -1 or to_idx >= len(self.pingpongs): to_idx = len(self.pingpongs) results = [] for pingpong in self.pingpongs[from_idx:to_idx]: results.append(fmt.ui(pingpong)) return results def gen_text( prompt, parameters=None ): if parameters is None: temperature = 0.7 top_k = 40 top_p = 0.95 parameters = { # 'model': 'models/text-bison-001', 'candidate_count': 1, 'temperature': temperature, 'top_k': top_k, 'top_p': top_p, } # if palm is None: # response = palm_api.chat(**parameters, messages=[prompt]) # else: # palm.temperature = parameters['temperature'] # palm.top_k = parameters['top_k'] # palm.top_p = parameters['top_p'] # response = palm.reply(prompt) response = palm_api.chat(**parameters, messages=prompt) if len(response.filters) > 0 and \ response.filters[0]['reason'] == 2: response_txt = "your request is blocked for some reasons" else: response_txt = response.last return response, response_txt