Spaces:
Sleeping
Sleeping
Xudong Liu
committed on
Commit
•
d987918
1
Parent(s):
c51b92e
增加对Anthropic的Claude大模型的支持 (#919)
Browse files* 增加了对Anthropic的Claude模型的支持
* 增加对Anthropic的Claude大模型的支持
- config_example.json +1 -0
- modules/config.py +3 -0
- modules/models/Claude.py +55 -0
- modules/models/base_model.py +3 -0
- modules/models/models.py +3 -0
- modules/presets.py +4 -2
- requirements.txt +2 -0
config_example.json
CHANGED
@@ -14,6 +14,7 @@
|
|
14 |
"spark_appid": "", // 你的 讯飞星火大模型 API AppID,用于讯飞星火大模型对话模型
|
15 |
"spark_api_key": "", // 你的 讯飞星火大模型 API Key,用于讯飞星火大模型对话模型
|
16 |
"spark_api_secret": "", // 你的 讯飞星火大模型 API Secret,用于讯飞星火大模型对话模型
|
|
|
17 |
|
18 |
|
19 |
//== Azure ==
|
|
|
14 |
"spark_appid": "", // 你的 讯飞星火大模型 API AppID,用于讯飞星火大模型对话模型
|
15 |
"spark_api_key": "", // 你的 讯飞星火大模型 API Key,用于讯飞星火大模型对话模型
|
16 |
"spark_api_secret": "", // 你的 讯飞星火大模型 API Secret,用于讯飞星火大模型对话模型
|
17 |
+
"claude_api_secret": "", // 你的 Claude API Secret,用于 Claude 对话模型
|
18 |
|
19 |
|
20 |
//== Azure ==
|
modules/config.py
CHANGED
@@ -128,6 +128,9 @@ os.environ["SPARK_APPID"] = spark_appid
|
|
128 |
spark_api_secret = config.get("spark_api_secret", "")
|
129 |
os.environ["SPARK_API_SECRET"] = spark_api_secret
|
130 |
|
|
|
|
|
|
|
131 |
load_config_to_environ(["openai_api_type", "azure_openai_api_key", "azure_openai_api_base_url",
|
132 |
"azure_openai_api_version", "azure_deployment_name", "azure_embedding_deployment_name", "azure_embedding_model_name"])
|
133 |
|
|
|
128 |
spark_api_secret = config.get("spark_api_secret", "")
|
129 |
os.environ["SPARK_API_SECRET"] = spark_api_secret
|
130 |
|
131 |
+
claude_api_secret = config.get("claude_api_secret", "")
|
132 |
+
os.environ["CLAUDE_API_SECRET"] = claude_api_secret
|
133 |
+
|
134 |
load_config_to_environ(["openai_api_type", "azure_openai_api_key", "azure_openai_api_base_url",
|
135 |
"azure_openai_api_version", "azure_deployment_name", "azure_embedding_deployment_name", "azure_embedding_model_name"])
|
136 |
|
modules/models/Claude.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT

from ..presets import *
from ..utils import *

from .base_model import BaseLLMModel


class Claude_Client(BaseLLMModel):
    """Chat client backed by Anthropic's Claude models.

    Uses the legacy text-completions endpoint
    (``Anthropic.completions.create``) rather than the Messages API.
    """

    def __init__(self, model_name, api_secret) -> None:
        """Create a Claude client.

        Args:
            model_name: Claude model identifier, passed through to the API.
            api_secret: Anthropic API key; must not be ``None``.

        Raises:
            Exception: when no API secret was configured in the config file
                or environment.
        """
        super().__init__(model_name=model_name)
        self.api_secret = api_secret
        # Fail fast with a clear message when the key is missing entirely.
        # (Was `if None in [self.api_secret]` — same check, clearer form.)
        if self.api_secret is None:
            raise Exception("请在配置文件或者环境变量中设置Claude的API Secret")
        self.claude_client = Anthropic(api_key=self.api_secret)

    def _build_prompt(self):
        """Return the prompt string sent to the completions endpoint.

        Shared by the streaming and one-shot paths, which previously
        duplicated this logic verbatim. The returned string is byte-identical
        to what the original code produced.
        """
        history = self.history
        if self.system_prompt is not None:
            # construct_system comes from ..utils; presumably wraps the
            # system prompt as a message entry — TODO confirm its shape.
            history = [construct_system(self.system_prompt), *history]
        # NOTE(review): interpolating the list embeds the Python repr of the
        # message entries in the prompt instead of a Human/Assistant
        # transcript. Kept as-is to preserve behavior, but formatting each
        # turn with HUMAN_PROMPT/AI_PROMPT would likely improve answers.
        return f"{HUMAN_PROMPT}{history}{AI_PROMPT}"

    def get_answer_stream_iter(self):
        """Yield the accumulated answer text as chunks stream from the API."""
        completion = self.claude_client.completions.create(
            model=self.model_name,
            # TODO(review): hard-coded cap; consider exposing as a setting.
            max_tokens_to_sample=300,
            prompt=self._build_prompt(),
            stream=True,
        )
        if completion is not None:
            partial_text = ""
            for chunk in completion:
                partial_text += chunk.completion
                # Yield the running total so the UI can repaint in place.
                yield partial_text
        else:
            yield STANDARD_ERROR_MSG + GENERAL_ERROR_MSG

    def get_answer_at_once(self):
        """Return ``(answer_text, length)`` from a single blocking call.

        ``length`` is the character count of the completion, used by the
        caller as a rough token/size estimate.
        """
        completion = self.claude_client.completions.create(
            model=self.model_name,
            max_tokens_to_sample=300,
            prompt=self._build_prompt(),
        )
        if completion is not None:
            return completion.completion, len(completion.completion)
        else:
            return "获取资源错误", 0
modules/models/base_model.py
CHANGED
@@ -145,6 +145,7 @@ class ModelType(Enum):
|
|
145 |
Midjourney = 11
|
146 |
Spark = 12
|
147 |
OpenAIInstruct = 13
|
|
|
148 |
|
149 |
@classmethod
|
150 |
def get_type(cls, model_name: str):
|
@@ -179,6 +180,8 @@ class ModelType(Enum):
|
|
179 |
model_type = ModelType.LangchainChat
|
180 |
elif "星火大模型" in model_name_lower:
|
181 |
model_type = ModelType.Spark
|
|
|
|
|
182 |
else:
|
183 |
model_type = ModelType.LLaMA
|
184 |
return model_type
|
|
|
145 |
Midjourney = 11
|
146 |
Spark = 12
|
147 |
OpenAIInstruct = 13
|
148 |
+
Claude = 14
|
149 |
|
150 |
@classmethod
|
151 |
def get_type(cls, model_name: str):
|
|
|
180 |
model_type = ModelType.LangchainChat
|
181 |
elif "星火大模型" in model_name_lower:
|
182 |
model_type = ModelType.Spark
|
183 |
+
elif "claude" in model_name_lower:
|
184 |
+
model_type = ModelType.Claude
|
185 |
else:
|
186 |
model_type = ModelType.LLaMA
|
187 |
return model_type
|
modules/models/models.py
CHANGED
@@ -116,6 +116,9 @@ def get_model(
|
|
116 |
from .spark import Spark_Client
|
117 |
model = Spark_Client(model_name, os.getenv("SPARK_APPID"), os.getenv(
|
118 |
"SPARK_API_KEY"), os.getenv("SPARK_API_SECRET"), user_name=user_name)
|
|
|
|
|
|
|
119 |
elif model_type == ModelType.Unknown:
|
120 |
raise ValueError(f"未知模型: {model_name}")
|
121 |
logging.info(msg)
|
|
|
116 |
from .spark import Spark_Client
|
117 |
model = Spark_Client(model_name, os.getenv("SPARK_APPID"), os.getenv(
|
118 |
"SPARK_API_KEY"), os.getenv("SPARK_API_SECRET"), user_name=user_name)
|
119 |
+
elif model_type == ModelType.Claude:
|
120 |
+
from .Claude import Claude_Client
|
121 |
+
model = Claude_Client(model_name="claude-2", api_secret=os.getenv("CLAUDE_API_SECRET"))
|
122 |
elif model_type == ModelType.Unknown:
|
123 |
raise ValueError(f"未知模型: {model_name}")
|
124 |
logging.info(msg)
|
modules/presets.py
CHANGED
@@ -74,7 +74,8 @@ ONLINE_MODELS = [
|
|
74 |
"minimax-abab5-chat",
|
75 |
"midjourney",
|
76 |
"讯飞星火大模型V2.0",
|
77 |
-
"讯飞星火大模型V1.5"
|
|
|
78 |
]
|
79 |
|
80 |
LOCAL_MODELS = [
|
@@ -125,7 +126,8 @@ MODEL_TOKEN_LIMIT = {
|
|
125 |
"gpt-4-0613": 8192,
|
126 |
"gpt-4-32k": 32768,
|
127 |
"gpt-4-32k-0314": 32768,
|
128 |
-
"gpt-4-32k-0613": 32768
|
|
|
129 |
}
|
130 |
|
131 |
TOKEN_OFFSET = 1000 # 模型的token上限减去这个值,得到软上限。到达软上限之后,自动尝试减少token占用。
|
|
|
74 |
"minimax-abab5-chat",
|
75 |
"midjourney",
|
76 |
"讯飞星火大模型V2.0",
|
77 |
+
"讯飞星火大模型V1.5",
|
78 |
+
"Claude"
|
79 |
]
|
80 |
|
81 |
LOCAL_MODELS = [
|
|
|
126 |
"gpt-4-0613": 8192,
|
127 |
"gpt-4-32k": 32768,
|
128 |
"gpt-4-32k-0314": 32768,
|
129 |
+
"gpt-4-32k-0613": 32768,
|
130 |
+
"Claude": 4096
|
131 |
}
|
132 |
|
133 |
TOKEN_OFFSET = 1000 # 模型的token上限减去这个值,得到软上限。到达软上限之后,自动尝试减少token占用。
|
requirements.txt
CHANGED
@@ -30,3 +30,5 @@ python-docx
|
|
30 |
websocket_client
|
31 |
pydantic==1.10.8
|
32 |
google-search-results
|
|
|
|
|
|
30 |
websocket_client
|
31 |
pydantic==1.10.8
|
32 |
google-search-results
|
33 |
+
anthropic==0.3.11
|
34 |
+
|