Spaces:
Sleeping
Sleeping
File size: 7,882 Bytes
c26dfd8 a363f1b c26dfd8 a363f1b c26dfd8 a363f1b 0ce1a9f a363f1b c26dfd8 007cc3d c26dfd8 b346648 c26dfd8 a2154af f8a0305 c26dfd8 93defe7 bbf38ab c26dfd8 ea9cb69 c26dfd8 e99bd71 c26dfd8 93defe7 ea9cb69 c26dfd8 b346648 c26dfd8 93defe7 c26dfd8 ea9cb69 c26dfd8 8728d12 c26dfd8 8728d12 c6d16d4 6e4855e ea9cb69 666f878 ea9cb69 3ac03d8 93defe7 cef64b2 93defe7 f8a0305 dbe4a3e 02f41f3 93defe7 bbf38ab 93defe7 4b9ef74 007cc3d 4b9ef74 0b2092a 93defe7 12eb16f 93defe7 d987918 c26dfd8 9c8c84b c26dfd8 02f41f3 f4d58e4 007cc3d aad9d90 c26dfd8 02f41f3 c26dfd8 02f41f3 c26dfd8 2582c02 c26dfd8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
from __future__ import annotations
import logging
import os
import colorama
import commentjson as cjson
from modules import config
from ..index_func import *
from ..presets import *
from ..utils import *
from .base_model import BaseLLMModel, ModelType
def get_model(
    model_name,
    lora_model_path=None,
    access_key=None,
    temperature=None,
    top_p=None,
    system_prompt=None,
    user_name="",
    original_model=None,
) -> tuple:
    """Instantiate the client for ``model_name`` and build the matching UI updates.

    The backend is chosen from ``ModelType.get_type(model_name)``; each
    client class is imported lazily inside its branch, so only the selected
    backend's module is loaded.

    Args:
        model_name: Name of the model selected in the UI.
        lora_model_path: LoRA choice for LLaMA models. ``""`` means no choice
            has been made yet (the LoRA selector is revealed instead of
            loading a model); ``"No LoRA"`` loads the base model without an
            adapter; any other value is passed to the LLaMA client as a path.
        access_key: API key supplied by the user. Some backends (OpenAI,
            OpenAI Instruct, Google PaLM, XMChat, MiniMax) prefer a key from
            an environment variable when one is set.
        temperature: Forwarded to the OpenAI client.
        top_p: Forwarded to the OpenAI client.
        system_prompt: Forwarded to the clients that accept it.
        user_name: Forwarded to the clients that accept it.
        original_model: The previously active model. It is returned as-is
            when no new client is created (load failure, or waiting for a
            LoRA choice); its history is copied onto whichever model is
            returned.

    Returns:
        A 6-tuple ``(model, msg, chatbot_update, lora_selector_update,
        access_key, masked_key)``: ``msg`` is a status message (an error
        message if loading failed) and ``masked_key`` is ``access_key``
        passed through ``hide_middle_chars``.

    Exceptions raised while loading are logged and reported through ``msg``
    instead of being propagated.
    """
    msg = i18n("模型设置为了:") + f" {model_name}"
    model_type = ModelType.get_type(model_name)
    lora_selector_visibility = False
    lora_choices = ["No LoRA"]
    dont_change_lora_selector = False
    if model_type != ModelType.OpenAI:
        # NOTE(review): presumably non-OpenAI backends cannot use OpenAI
        # embeddings, hence the switch to local ones — confirm.
        config.local_embedding = True
    # Keep the previous model unless a branch below creates a new client.
    model = original_model
    chatbot = gr.Chatbot.update(label=model_name)
    try:
        if model_type == ModelType.OpenAI:
            logging.info(f"正在加载OpenAI模型: {model_name}")
            from .OpenAI import OpenAIClient
            access_key = os.environ.get("OPENAI_API_KEY", access_key)
            model = OpenAIClient(
                model_name=model_name,
                api_key=access_key,
                system_prompt=system_prompt,
                temperature=temperature,
                top_p=top_p,
                user_name=user_name,
            )
        elif model_type == ModelType.OpenAIInstruct:
            logging.info(f"正在加载OpenAI Instruct模型: {model_name}")
            from .OpenAIInstruct import OpenAI_Instruct_Client
            access_key = os.environ.get("OPENAI_API_KEY", access_key)
            model = OpenAI_Instruct_Client(
                model_name, api_key=access_key, user_name=user_name)
        elif model_type == ModelType.ChatGLM:
            logging.info(f"正在加载ChatGLM模型: {model_name}")
            from .ChatGLM import ChatGLM_Client
            model = ChatGLM_Client(model_name, user_name=user_name)
        elif model_type == ModelType.LLaMA:
            if lora_model_path == "":
                # No LoRA chosen yet: reveal the selector instead of loading.
                msg = f"现在请为 {model_name} 选择LoRA模型"
                logging.info(msg)
                lora_selector_visibility = True
                if os.path.isdir("lora"):
                    lora_choices = ["No LoRA"] + get_file_names_by_pinyin("lora", filetypes=[""])
            else:
                logging.info(f"正在加载LLaMA模型: {model_name} + {lora_model_path}")
                from .LLaMA import LLaMA_Client
                dont_change_lora_selector = True
                if lora_model_path == "No LoRA":
                    lora_model_path = None
                    msg += " + No LoRA"
                else:
                    msg += f" + {lora_model_path}"
                model = LLaMA_Client(
                    model_name, lora_model_path, user_name=user_name)
        elif model_type == ModelType.XMChat:
            from .XMChat import XMChat
            # Prefer a non-empty key from the environment, otherwise keep the
            # user's key (an unset variable used to overwrite it with None).
            access_key = os.environ.get("XMCHAT_API_KEY") or access_key
            model = XMChat(api_key=access_key, user_name=user_name)
        elif model_type == ModelType.StableLM:
            from .StableLM import StableLM_Client
            model = StableLM_Client(model_name, user_name=user_name)
        elif model_type == ModelType.MOSS:
            from .MOSS import MOSS_Client
            model = MOSS_Client(model_name, user_name=user_name)
        elif model_type == ModelType.YuanAI:
            from .inspurai import Yuan_Client
            model = Yuan_Client(model_name, api_key=access_key,
                                user_name=user_name, system_prompt=system_prompt)
        elif model_type == ModelType.Minimax:
            from .minimax import MiniMax_Client
            # Same fallback rule as XMChat above.
            access_key = os.environ.get("MINIMAX_API_KEY") or access_key
            model = MiniMax_Client(
                model_name, api_key=access_key, user_name=user_name, system_prompt=system_prompt)
        elif model_type == ModelType.ChuanhuAgent:
            from .ChuanhuAgent import ChuanhuAgent_Client
            model = ChuanhuAgent_Client(model_name, access_key, user_name=user_name)
            msg = i18n("启用的工具:") + ", ".join([i.name for i in model.tools])
        elif model_type == ModelType.GooglePaLM:
            from .GooglePaLM import Google_PaLM_Client
            access_key = os.environ.get("GOOGLE_PALM_API_KEY", access_key)
            model = Google_PaLM_Client(
                model_name, access_key, user_name=user_name)
        elif model_type == ModelType.LangchainChat:
            from .Azure import Azure_OpenAI_Client
            model = Azure_OpenAI_Client(model_name, user_name=user_name)
        elif model_type == ModelType.Midjourney:
            from .midjourney import Midjourney_Client
            mj_proxy_api_secret = os.getenv("MIDJOURNEY_PROXY_API_SECRET")
            model = Midjourney_Client(
                model_name, mj_proxy_api_secret, user_name=user_name)
        elif model_type == ModelType.Spark:
            from .spark import Spark_Client
            model = Spark_Client(model_name, os.getenv("SPARK_APPID"), os.getenv(
                "SPARK_API_KEY"), os.getenv("SPARK_API_SECRET"), user_name=user_name)
        elif model_type == ModelType.Claude:
            from .Claude import Claude_Client
            model = Claude_Client(model_name="claude-2", api_secret=os.getenv("CLAUDE_API_SECRET"))
        elif model_type == ModelType.Unknown:
            raise ValueError(f"未知模型: {model_name}")
        logging.info(msg)
    except Exception as e:
        # Report the failure to the UI instead of crashing the callback.
        logging.exception("Failed to load model %s", model_name)
        msg = f"{STANDARD_ERROR_MSG}: {e}"
    masked_key = hide_middle_chars(access_key)
    if original_model is not None and model is not None:
        # Carry the conversation over so switching models keeps the chat.
        model.history = original_model.history
        model.history_file_path = original_model.history_file_path
    if dont_change_lora_selector:
        lora_update = gr.update()
    else:
        lora_update = gr.Dropdown.update(choices=lora_choices, visible=lora_selector_visibility)
    return model, msg, chatbot, lora_update, access_key, masked_key
if __name__ == "__main__":
    # Manual smoke test: load a local model and exercise billing, Q&A,
    # conversation memory and retry.
    with open("config.json", "r", encoding="utf-8") as f:
        # Needed only when testing an OpenAI model; unused by the run below.
        openai_api_key = cjson.load(f)["openai_api_key"]
    logging.basicConfig(level=logging.DEBUG)
    # get_model returns a 6-tuple; the client is its first element and the
    # status / error message its second.
    client, load_msg, *_ = get_model(model_name="chatglm-6b-int4")
    if client is None:
        raise SystemExit(load_msg)
    chatbot = []
    stream = False
    # Billing
    logging.info(colorama.Back.GREEN + "测试账单功能" + colorama.Back.RESET)
    logging.info(client.billing_info())
    # Question answering
    logging.info(colorama.Back.GREEN + "测试问答" + colorama.Back.RESET)
    question = "巴黎是中国的首都吗?"
    for i in client.predict(inputs=question, chatbot=chatbot, stream=stream):
        logging.info(i)
    logging.info(f"测试问答后history : {client.history}")
    # Memory: the model should recall the previous question
    logging.info(colorama.Back.GREEN + "测试记忆力" + colorama.Back.RESET)
    question = "我刚刚问了你什么问题?"
    for i in client.predict(inputs=question, chatbot=chatbot, stream=stream):
        logging.info(i)
    logging.info(f"测试记忆力后history : {client.history}")
    # Retry
    logging.info(colorama.Back.GREEN + "测试重试功能" + colorama.Back.RESET)
    for i in client.retry(chatbot=chatbot, stream=stream):
        logging.info(i)
    logging.info(f"重试后history : {client.history}")
    # Summarisation test (disabled):
    # print(colorama.Back.GREEN + "测试总结功能" + colorama.Back.RESET)
    # chatbot, msg = client.reduce_token_size(chatbot=chatbot)
    # print(chatbot, msg)
    # print(f"总结后history: {client.history}")
|