Spaces:
Sleeping
Sleeping
from __future__ import annotations | |
from typing import TYPE_CHECKING, List | |
import logging | |
import json | |
import commentjson as cjson | |
import os | |
import sys | |
import requests | |
import urllib3 | |
from tqdm import tqdm | |
import colorama | |
from duckduckgo_search import ddg | |
import asyncio | |
import aiohttp | |
from enum import Enum | |
from .presets import * | |
from .llama_func import * | |
from .utils import * | |
from . import shared | |
from .config import retrieve_proxy | |
from .base_model import BaseLLMModel, ModelType | |
class OpenAIClient(BaseLLMModel):
    """Chat client for the OpenAI chat-completion HTTP API.

    Builds request payloads from the conversation history held by
    ``BaseLLMModel``, supports both streaming and one-shot answers, and
    exposes the monthly usage/billing endpoint.
    """

    def __init__(
        self, model_name, api_key, system_prompt=INITIAL_SYSTEM_PROMPT, temperature=1.0, top_p=1.0
    ) -> None:
        super().__init__(model_name=model_name, temperature=temperature, top_p=top_p, system_prompt=system_prompt)
        self.api_key = api_key
        # Endpoints come from the shared global state so a custom api-host
        # configured elsewhere is honoured.
        self.completion_url = shared.state.completion_url
        self.usage_api_url = shared.state.usage_api_url
        self.headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

    def get_answer_stream_iter(self):
        """Yield progressively longer partial answer strings while streaming."""
        response = self._get_response(stream=True)
        if response is not None:
            partial_text = ""
            # Renamed from ``iter`` (shadowed the builtin).
            for delta in self._decode_chat_response(response):
                partial_text += delta
                yield partial_text
        else:
            yield standard_error_msg + general_error_msg

    def get_answer_at_once(self):
        """Return ``(content, total_token_count)`` from one blocking request."""
        response = self._get_response()
        if response is None:
            # Fix: the original dereferenced response.text unconditionally,
            # crashing with AttributeError whenever the request failed.
            return standard_error_msg + general_error_msg, 0
        data = json.loads(response.text)
        content = data["choices"][0]["message"]["content"]
        total_token_count = data["usage"]["total_tokens"]
        return content, total_token_count

    def count_token(self, user_input):
        """Count tokens for *user_input*; include the system prompt on first use."""
        # ``count_token`` here resolves to the module-level helper, not this method.
        input_token_count = count_token(construct_user(user_input))
        if self.system_prompt is not None and len(self.all_token_counts) == 0:
            # The system prompt is only counted once, with the first message.
            system_prompt_token_count = count_token(construct_system(self.system_prompt))
            return input_token_count + system_prompt_token_count
        return input_token_count

    def set_system_prompt(self, new_system_prompt):
        """Replace the system prompt used for subsequent requests."""
        self.system_prompt = new_system_prompt

    def billing_info(self):
        """Return a markdown string with this month's API usage in USD.

        On failure, returns a human-readable error string instead of raising.
        """
        try:
            curr_time = datetime.datetime.now()
            last_day_of_month = get_last_day_of_month(curr_time).strftime("%Y-%m-%d")
            first_day_of_month = curr_time.replace(day=1).strftime("%Y-%m-%d")
            usage_url = f"{self.usage_api_url}?start_date={first_day_of_month}&end_date={last_day_of_month}"
            try:
                usage_data = self._get_billing_data(usage_url)
            except Exception as e:
                logging.error("获取API使用情况失败:" + str(e))
                return "**获取API使用情况失败**"
            # total_usage is reported by the API in cents; convert to dollars.
            rounded_usage = "{:.5f}".format(usage_data["total_usage"] / 100)
            return f"**本月使用金额** \u3000 ${rounded_usage}"
        except requests.exceptions.ConnectTimeout:
            return standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
        except requests.exceptions.ReadTimeout:
            return standard_error_msg + read_timeout_prompt + error_retrieve_prompt
        except Exception as e:
            logging.error("获取API使用情况失败:" + str(e))
            return standard_error_msg + error_retrieve_prompt

    # NOTE(review): upstream wraps this with a multi-account decorator; when
    # multi-account mode is disabled that decorator is a no-op.
    def _get_response(self, stream=False):
        """POST the chat payload; return the ``requests.Response`` or None on failure.

        Parameters
        ----------
        stream : bool
            When True, request server-sent-events streaming and use the
            shorter streaming timeout.
        """
        openai_api_key = self.api_key
        system_prompt = self.system_prompt
        history = self.history
        logging.debug(colorama.Fore.YELLOW + f"{history}" + colorama.Fore.RESET)
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {openai_api_key}",
        }

        if system_prompt is not None:
            # The system message always precedes the conversation history.
            history = [construct_system(system_prompt), *history]

        payload = {
            "model": self.model_name,
            "messages": history,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "n": 1,
            "stream": stream,
            "presence_penalty": 0,
            "frequency_penalty": 0,
        }

        timeout = timeout_streaming if stream else TIMEOUT_ALL

        # If a custom api-host is configured, log it; the request below always
        # goes to shared.state.completion_url either way.
        if shared.state.completion_url != COMPLETION_URL:
            logging.info(f"使用自定义API URL: {shared.state.completion_url}")

        with retrieve_proxy():
            try:
                response = requests.post(
                    shared.state.completion_url,
                    headers=headers,
                    json=payload,
                    stream=stream,
                    timeout=timeout,
                )
            except Exception as e:
                # Fix: was a bare ``except:`` that also swallowed
                # KeyboardInterrupt/SystemExit and hid the failure reason.
                logging.error(f"Request to completion endpoint failed: {e}")
                return None
        return response

    def _get_billing_data(self, usage_url):
        """GET *usage_url* and return the decoded JSON body.

        Raises
        ------
        Exception
            If the endpoint responds with a non-200 status code.
        """
        with retrieve_proxy():
            response = requests.get(
                usage_url,
                headers=self.headers,
                timeout=TIMEOUT_ALL,
            )

        if response.status_code == 200:
            return response.json()
        raise Exception(
            f"API request failed with status code {response.status_code}: {response.text}"
        )

    def _decode_chat_response(self, response):
        """Yield content deltas from a server-sent-events streaming response."""
        for chunk in response.iter_lines():
            if not chunk:
                continue
            decoded = chunk.decode()
            chunk_length = len(decoded)
            try:
                # Each SSE line is prefixed with "data: " (6 characters).
                parsed = json.loads(decoded[6:])
            except json.JSONDecodeError:
                print(f"JSON解析错误,收到的内容: {decoded}")
                continue
            if chunk_length > 6 and "delta" in parsed["choices"][0]:
                # Fix: use .get() so chunks without "finish_reason" don't raise.
                if parsed["choices"][0].get("finish_reason") == "stop":
                    break
                try:
                    yield parsed["choices"][0]["delta"]["content"]
                except (KeyError, IndexError):
                    # e.g. the initial role-only delta carries no "content".
                    continue
def get_model(model_name, access_key=None, temprature=None, top_p=None, system_prompt=None) -> BaseLLMModel:
    """Factory: build the LLM client appropriate for *model_name*.

    ``temprature`` keeps its (misspelled) name for backward compatibility
    with existing keyword callers.

    Raises
    ------
    ValueError
        If *model_name* maps to a model type this factory does not support.
        (The original fell through and raised UnboundLocalError instead.)
    """
    model_type = ModelType.get_type(model_name)
    if model_type == ModelType.OpenAI:
        # Fix: forward only the arguments the caller actually provided, so
        # OpenAIClient's own defaults apply instead of being clobbered by
        # None (the original always passed all three positionally).
        kwargs = {}
        if system_prompt is not None:
            kwargs["system_prompt"] = system_prompt
        if temprature is not None:
            kwargs["temperature"] = temprature
        if top_p is not None:
            kwargs["top_p"] = top_p
        return OpenAIClient(model_name, access_key, **kwargs)
    raise ValueError(f"Unsupported model type for {model_name!r}")
if __name__ == "__main__":
    # Manual smoke test: exercises billing, Q&A, memory and retry against a
    # live API key read from config.json.
    with open("config.json", "r") as f:
        openai_api_key = cjson.load(f)["openai_api_key"]

    client = OpenAIClient("gpt-3.5-turbo", openai_api_key)
    chatbot = []
    stream = False

    def banner(text):
        # Green banner separating the individual smoke tests.
        print(colorama.Back.GREEN + text + colorama.Back.RESET)

    # Billing
    banner("测试账单功能")
    print(client.billing_info())

    # Basic Q&A
    banner("测试问答")
    question = "巴黎是中国的首都吗?"
    for partial in client.predict(inputs=question, chatbot=chatbot, stream=stream):
        print(partial)
    print(f"测试问答后history : {client.history}")

    # Conversation memory
    banner("测试记忆力")
    question = "我刚刚问了你什么问题?"
    for partial in client.predict(inputs=question, chatbot=chatbot, stream=stream):
        print(partial)
    print(f"测试记忆力后history : {client.history}")

    # Retry
    banner("测试重试功能")
    for partial in client.retry(chatbot=chatbot, stream=stream):
        print(partial)
    print(f"重试后history : {client.history}")

    # # Summarisation (disabled)
    # banner("测试总结功能")
    # chatbot, msg = client.reduce_token_size(chatbot=chatbot)
    # print(chatbot, msg)
    # print(f"总结后history: {client.history}")