Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import _thread as thread | |
import base64 | |
import datetime | |
import hashlib | |
import hmac | |
import json | |
from collections import deque | |
from urllib.parse import urlparse | |
import ssl | |
from datetime import datetime | |
from time import mktime | |
from urllib.parse import urlencode | |
from wsgiref.handlers import format_date_time | |
from threading import Condition | |
import websocket | |
import logging | |
from .base_model import BaseLLMModel, CallbackToIterator | |
class Ws_Param(object):
    """Build the signed WebSocket URL for the iFlytek Spark chat API.

    Adapted from the official iFlytek demo. The handshake is authenticated
    through three query parameters appended to ``Spark_url``:
    ``authorization``, ``date`` and ``host``. ``authorization`` is a
    base64-encoded header string containing an HMAC-SHA256 signature
    (keyed with ``APISecret``) over the host, the date and the HTTP
    request line.
    """

    def __init__(self, APPID, APIKey, APISecret, Spark_url):
        """Store the credentials and split the endpoint URL.

        Args:
            APPID: application id (stored only; ``create_url`` does not use it).
            APIKey: API key embedded in the authorization string.
            APISecret: secret used as the HMAC-SHA256 key.
            Spark_url: full WebSocket endpoint, e.g.
                ``wss://spark-api.xf-yun.com/v3.1/chat``.
        """
        self.APPID = APPID
        self.APIKey = APIKey
        self.APISecret = APISecret
        parsed = urlparse(Spark_url)  # parse once, reuse both parts
        self.host = parsed.netloc
        self.path = parsed.path
        self.Spark_url = Spark_url

    def create_url(self):
        """Return ``Spark_url`` with the signed authentication query string.

        Returns:
            str: ``Spark_url + "?" + urlencode({authorization, date, host})``.
        """
        import time

        # RFC 1123 date in GMT. Format the current POSIX timestamp directly:
        # converting a naive local datetime back with mktime() is ambiguous
        # during the repeated hour of a DST fall-back and can be off by one
        # hour, which makes the server reject the signature.
        date = format_date_time(time.time())

        # String to sign: host line, date line and HTTP request line.
        signature_origin = (
            f"host: {self.host}\n"
            f"date: {date}\n"
            f"GET {self.path} HTTP/1.1"
        )
        signature_sha = hmac.new(
            self.APISecret.encode("utf-8"),
            signature_origin.encode("utf-8"),
            digestmod=hashlib.sha256,
        ).digest()
        signature_sha_base64 = base64.b64encode(signature_sha).decode("utf-8")

        authorization_origin = (
            f'api_key="{self.APIKey}", algorithm="hmac-sha256", '
            'headers="host date request-line", '
            f'signature="{signature_sha_base64}"'
        )
        authorization = base64.b64encode(
            authorization_origin.encode("utf-8")
        ).decode("utf-8")

        # The authentication parameters travel in the query string.
        params = {"authorization": authorization, "date": date, "host": self.host}
        return self.Spark_url + "?" + urlencode(params)
class Spark_Client(BaseLLMModel):
    """Streaming chat client for iFlytek Spark over its WebSocket API.

    ``get_answer_stream_iter`` opens the socket on a background thread. The
    WebSocket callbacks push raw frames into a ``CallbackToIterator``, which
    the generator consumes and turns into ``(answer, total_tokens)`` updates.
    """

    def __init__(self, model_name, appid, api_key, api_secret, user_name="") -> None:
        """Store credentials and select the endpoint for the model version.

        Args:
            model_name: model name; the substrings "2.0" / "3.0" select the
                v2.1 / v3.1 endpoints, anything else uses v1.1.
            appid: iFlytek APP ID.
            api_key: iFlytek API Key.
            api_secret: iFlytek API Secret.
            user_name: forwarded to ``BaseLLMModel`` as ``user``.

        Raises:
            Exception: if any of ``appid``, ``api_key`` or ``api_secret`` is None.
        """
        super().__init__(model_name=model_name, user=user_name)
        self.api_key = api_key
        self.appid = appid
        self.api_secret = api_secret
        if None in [self.api_key, self.appid, self.api_secret]:
            raise Exception("请在配置文件或者环境变量中设置讯飞的API Key、APP ID和API Secret")
        # The branches must be mutually exclusive. With two independent `if`
        # statements, a "2.0" model fell into the `else` of the "3.0" check
        # and was silently sent to the v1.1 endpoint.
        if "2.0" in self.model_name:
            self.spark_url = "wss://spark-api.xf-yun.com/v2.1/chat"
            self.domain = "generalv2"
        elif "3.0" in self.model_name:
            self.spark_url = "wss://spark-api.xf-yun.com/v3.1/chat"
            self.domain = "generalv3"
        else:
            self.spark_url = "wss://spark-api.xf-yun.com/v1.1/chat"
            self.domain = "general"

    def on_error(self, ws, error):
        """WebSocket error callback: report the error to the consumer.

        ``error`` may be an exception object rather than a str, so it is
        formatted with ``str()`` (plain ``+`` concatenation raised TypeError).
        The message is wrapped in a JSON frame with a non-zero
        ``header.code``. ``get_answer_stream_iter`` can then parse it like a
        server frame and raise, instead of failing inside ``json.loads``.
        """
        ws.iterator.callback(
            json.dumps({"header": {"code": -1, "message": "出现了错误:" + str(error)}})
        )

    def on_close(self, ws, one, two):
        """WebSocket close callback: end the message stream.

        ``one`` / ``two`` receive the close status code and reason; they are
        unused. Finishing the iterator lets the consuming loop stop even when
        the connection closes before a status-2 frame arrives. Without it,
        the loop would keep waiting for frames that never come.
        NOTE(review): finish() may run twice (here and on the status-2 path).
        This assumes CallbackToIterator.finish() is idempotent; confirm in
        base_model.
        """
        ws.iterator.finish()

    def on_open(self, ws):
        """WebSocket open callback: send the request from a new thread."""
        thread.start_new_thread(self.run, (ws,))

    def run(self, ws, *args):
        """Serialize ``gen_params()`` to JSON and send it over ``ws``."""
        data = json.dumps(self.gen_params())
        ws.send(data)

    def on_message(self, ws, message):
        """WebSocket message callback: queue the raw frame for the consumer."""
        ws.iterator.callback(message)

    def gen_params(self):
        """Build the request payload from the app id and the chat history.

        Returns:
            dict: ``header`` (app id, fixed uid "1234"), ``parameter.chat``
            (domain, temperature, max_tokens=4096, auditing) and
            ``payload.message.text`` set to ``self.history``. The history
            is presumably a list of role/content dicts; confirm in
            BaseLLMModel.
        """
        data = {
            "header": {"app_id": self.appid, "uid": "1234"},
            "parameter": {
                "chat": {
                    "domain": self.domain,
                    # NOTE(review): confirm against the current Spark API docs
                    # that `random_threshold` is still honoured for temperature.
                    "random_threshold": self.temperature,
                    "max_tokens": 4096,
                    "auditing": "default",
                }
            },
            "payload": {"message": {"text": self.history}},
        }
        return data

    def get_answer_stream_iter(self):
        """Stream the model's reply.

        Yields:
            tuple: ``(answer, total_tokens)``. ``answer`` is the reply
            accumulated so far. ``total_tokens`` is the last value reported
            under ``payload.usage`` (0 until one arrives).

        Raises:
            Exception: when a frame carries a non-zero ``header.code``. This
            includes connection errors reported through ``on_error``.
        """
        wsParam = Ws_Param(self.appid, self.api_key, self.api_secret, self.spark_url)
        websocket.enableTrace(False)
        wsUrl = wsParam.create_url()
        ws = websocket.WebSocketApp(
            wsUrl,
            on_message=self.on_message,
            on_error=self.on_error,
            on_close=self.on_close,
            on_open=self.on_open,
        )
        ws.appid = self.appid
        ws.domain = self.domain
        # The callbacks above push frames into this iterator; we consume them below.
        ws.iterator = CallbackToIterator()
        # Run the socket on its own thread so this generator can consume frames.
        # NOTE(review): certificate verification is disabled (as in the
        # official demo), so the TLS peer is not authenticated. Consider
        # removing `sslopt` once CA bundles are available in deployment.
        thread.start_new_thread(
            ws.run_forever, (), {"sslopt": {"cert_reqs": ssl.CERT_NONE}}
        )

        answer = ""
        total_tokens = 0
        try:
            for message in ws.iterator:
                data = json.loads(message)
                code = data["header"]["code"]
                if code != 0:
                    raise Exception(f"请求错误: {code}, {data}")
                choices = data["payload"]["choices"]
                status = choices["status"]
                content = choices["text"][0]["content"]
                if "usage" in data["payload"]:
                    total_tokens = data["payload"]["usage"]["text"]["total_tokens"]
                answer += content
                if status == 2:
                    # Status 2 is treated as the end of the reply.
                    ws.iterator.finish()
                    ws.close()
                yield answer, total_tokens
        finally:
            # Also runs on errors and when the caller stops iterating early
            # (GeneratorExit), so the socket and its thread do not linger.
            # Closing an already-closed WebSocketApp is harmless.
            ws.close()