import os
from typing import Optional

import google.generativeai as palm_api

from pingpong import PingPong
from pingpong.pingpong import PPManager
from pingpong.pingpong import PromptFmt
from pingpong.pingpong import UIFmt
from pingpong.gradio import GradioChatUIFmt


# Configure the PaLM client with the API token from the environment.
palm_api_token = os.getenv("PALM_API_TOKEN")
if palm_api_token is None:
    raise ValueError("PALM_API_TOKEN environment variable is not set")
palm_api.configure(api_key=palm_api_token)

class PaLMChatPromptFmt(PromptFmt):
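    """Formats a single PingPong exchange as the author/content message
    dicts expected by the PaLM chat API."""
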
    @classmethod
    def ctx(cls, context):
        # Context is not embedded in the prompt itself; it is passed to the
        # PaLM API separately through the "context" field (see gen_text).
        pass

    @classmethod
    def prompt(cls, pingpong, truncate_size):
        # Slicing with truncate_size=None keeps the full text.
        ping = pingpong.ping[:truncate_size]
        pong = pingpong.pong

        # An empty or missing pong means the model has not replied yet,
        # so only the user message is emitted.
        if pong is None or pong.strip() == "":
            return [
                {
                    "author": "USER",
                    "content": ping
                },
            ]
        else:
            pong = pong[:truncate_size]

            return [
                {
                    "author": "USER",
                    "content": ping
                },
                {
                    "author": "AI",
                    "content": pong
                },
            ]

class PaLMChatPPManager(PPManager):
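    """PPManager that turns the recorded conversation into a PaLM chat
    message list."""
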
    def build_prompts(self, from_idx: int = 0, to_idx: int = -1, fmt: PromptFmt = PaLMChatPromptFmt, truncate_size: Optional[int] = None):
        # Flatten the selected ping/pong exchanges into one PaLM message list.
        results = []

        if to_idx == -1 or to_idx >= len(self.pingpongs):
            to_idx = len(self.pingpongs)

        for pingpong in self.pingpongs[from_idx:to_idx]:
            results += fmt.prompt(pingpong, truncate_size=truncate_size)

        return results

class GradioPaLMChatPPManager(PaLMChatPPManager):
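    """PaLM chat manager that also renders the conversation for a Gradio
    Chatbot component."""
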
    def build_uis(self, from_idx: int = 0, to_idx: int = -1, fmt: UIFmt = GradioChatUIFmt):
        if to_idx == -1 or to_idx >= len(self.pingpongs):
            to_idx = len(self.pingpongs)

        results = []

        for pingpong in self.pingpongs[from_idx:to_idx]:
            results.append(fmt.ui(pingpong))

        return results

async def gen_text(
    prompt,
    parameters=None
):
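    """Sends `prompt` (a list of PaLM chat messages) to the PaLM chat model.

    Returns the full ChatResponse and the text of its last reply.
    """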
    if parameters is None:
        # Default sampling settings for the PaLM chat model.
        temperature = 0.7
        top_k = 40
        top_p = 0.95

        parameters = {
            'model': 'models/chat-bison-001',
            'candidate_count': 1,
            'context': "",
            'temperature': temperature,
            'top_k': top_k,
            'top_p': top_p,
        }

    response = await palm_api.chat_async(**parameters, messages=prompt)

    # .last holds the text of the most recent model reply.
    response_txt = response.last

    return response, response_txt
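

# --- Illustrative usage sketch (not part of the original API surface). ---
# It assumes PingPong(ping, pong) is the constructor exposed by the pingpong
# package and that the manager can be created without arguments; it also
# requires a valid PALM_API_TOKEN and network access to actually run.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        ppm = GradioPaLMChatPPManager()
        ppm.pingpongs.append(PingPong("Hello, who are you?", ""))

        # Build the PaLM-style message list from the conversation so far.
        messages = ppm.build_prompts()
        _, response_txt = await gen_text(messages)

        # Record the model's reply and render the Gradio chat history.
        ppm.pingpongs[-1].pong = response_txt
        print(response_txt)
        print(ppm.build_uis())

    asyncio.run(_demo())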