Upload 243 files
This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitattributes +3 -0
- img/docker_logs.png +0 -0
- img/langchain+chatglm.png +3 -0
- img/langchain+chatglm2.png +0 -0
- img/qr_code_36.jpg +0 -0
- img/qr_code_37.jpg +0 -0
- img/qr_code_38.jpg +0 -0
- img/qr_code_39.jpg +0 -0
- img/vue_0521_0.png +0 -0
- img/vue_0521_1.png +3 -0
- img/vue_0521_2.png +3 -0
- img/webui_0419.png +0 -0
- img/webui_0510_0.png +0 -0
- img/webui_0510_1.png +0 -0
- img/webui_0510_2.png +0 -0
- img/webui_0521_0.png +0 -0
- loader/RSS_loader.py +54 -0
- loader/__init__.py +14 -0
- loader/__pycache__/__init__.cpython-310.pyc +0 -0
- loader/__pycache__/__init__.cpython-311.pyc +0 -0
- loader/__pycache__/dialogue.cpython-310.pyc +0 -0
- loader/__pycache__/image_loader.cpython-310.pyc +0 -0
- loader/__pycache__/image_loader.cpython-311.pyc +0 -0
- loader/__pycache__/pdf_loader.cpython-310.pyc +0 -0
- loader/dialogue.py +131 -0
- loader/image_loader.py +42 -0
- loader/pdf_loader.py +58 -0
- models/__init__.py +4 -0
- models/__pycache__/__init__.cpython-310.pyc +0 -0
- models/__pycache__/chatglm_llm.cpython-310.pyc +0 -0
- models/__pycache__/fastchat_openai_llm.cpython-310.pyc +0 -0
- models/__pycache__/llama_llm.cpython-310.pyc +0 -0
- models/__pycache__/moss_llm.cpython-310.pyc +0 -0
- models/__pycache__/shared.cpython-310.pyc +0 -0
- models/base/__init__.py +13 -0
- models/base/__pycache__/__init__.cpython-310.pyc +0 -0
- models/base/__pycache__/base.cpython-310.pyc +0 -0
- models/base/__pycache__/remote_rpc_model.cpython-310.pyc +0 -0
- models/base/base.py +41 -0
- models/base/lavis_blip2_multimodel.py +26 -0
- models/base/remote_rpc_model.py +33 -0
- models/chatglm_llm.py +83 -0
- models/fastchat_openai_llm.py +137 -0
- models/llama_llm.py +185 -0
- models/loader/__init__.py +2 -0
- models/loader/__pycache__/__init__.cpython-310.pyc +0 -0
- models/loader/__pycache__/args.cpython-310.pyc +0 -0
- models/loader/__pycache__/loader.cpython-310.pyc +0 -0
- models/loader/args.py +55 -0
- models/loader/loader.py +447 -0
.gitattributes
CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+img/langchain+chatglm.png filter=lfs diff=lfs merge=lfs -text
+img/vue_0521_1.png filter=lfs diff=lfs merge=lfs -text
+img/vue_0521_2.png filter=lfs diff=lfs merge=lfs -text
img/docker_logs.png
ADDED
img/langchain+chatglm.png
ADDED
img/langchain+chatglm2.png
ADDED
img/qr_code_36.jpg
ADDED
img/qr_code_37.jpg
ADDED
img/qr_code_38.jpg
ADDED
img/qr_code_39.jpg
ADDED
img/vue_0521_0.png
ADDED
img/vue_0521_1.png
ADDED
img/vue_0521_2.png
ADDED
img/webui_0419.png
ADDED
img/webui_0510_0.png
ADDED
img/webui_0510_1.png
ADDED
img/webui_0510_2.png
ADDED
img/webui_0521_0.png
ADDED
loader/RSS_loader.py
ADDED
@@ -0,0 +1,54 @@
from langchain.docstore.document import Document
import feedparser
import html2text
import ssl
import time


class RSS_Url_loader:
    def __init__(self, urls=None, interval=60):
        '''Accepts either a list of urls or a single url string.'''
        self.urls = []
        self.interval = interval
        if urls is not None:
            try:
                if isinstance(urls, str):
                    urls = [urls]
                elif isinstance(urls, list):
                    pass
                else:
                    raise TypeError('urls must be a list or a string.')
                self.urls = urls
            except:
                Warning('urls must be a list or a string.')

    # The scheduled version may need to pull in other classes; not exposed for now.
    def scheduled_execution(self):
        while True:
            docs = self.load()
            return docs
            time.sleep(self.interval)

    def load(self):
        if hasattr(ssl, '_create_unverified_context'):
            ssl._create_default_https_context = ssl._create_unverified_context
        documents = []
        for url in self.urls:
            parsed = feedparser.parse(url)
            for entry in parsed.entries:
                if "content" in entry:
                    data = entry.content[0].value
                else:
                    data = entry.description or entry.summary
                data = html2text.html2text(data)
                metadata = {"title": entry.title, "link": entry.link}
                documents.append(Document(page_content=data, metadata=metadata))
        return documents

if __name__ == "__main__":
    # The urls should eventually come from the config file or from the user interface.
    urls = ["https://www.zhihu.com/rss", "https://www.36kr.com/feed"]
    loader = RSS_Url_loader(urls)
    docs = loader.load()
    for doc in docs:
        print(doc)
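One thing to note: scheduled_execution above returns after the first load(), so the time.sleep(self.interval) line is never reached. A minimal periodic-polling sketch (a hypothetical helper, not part of this commit) that hands each batch of documents to a caller-supplied callback:

import time

def poll_rss(loader, handle_docs, interval=60):
    """Repeatedly load the configured feeds and pass the documents to a callback."""
    while True:
        handle_docs(loader.load())
        time.sleep(interval)

# usage: poll_rss(RSS_Url_loader("https://www.36kr.com/feed"), print, interval=300)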
loader/__init__.py
ADDED
@@ -0,0 +1,14 @@
from .image_loader import UnstructuredPaddleImageLoader
from .pdf_loader import UnstructuredPaddlePDFLoader
from .dialogue import (
    Person,
    Dialogue,
    Turn,
    DialogueLoader
)

__all__ = [
    "UnstructuredPaddleImageLoader",
    "UnstructuredPaddlePDFLoader",
    "DialogueLoader",
]
loader/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (414 Bytes)
loader/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (531 Bytes)
loader/__pycache__/dialogue.cpython-310.pyc
ADDED
Binary file (4.95 kB)
loader/__pycache__/image_loader.cpython-310.pyc
ADDED
Binary file (2.23 kB)
loader/__pycache__/image_loader.cpython-311.pyc
ADDED
Binary file (3.94 kB)
loader/__pycache__/pdf_loader.cpython-310.pyc
ADDED
Binary file (2.57 kB)
loader/dialogue.py
ADDED
@@ -0,0 +1,131 @@
import json
from abc import ABC
from typing import List
from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader


class Person:
    def __init__(self, name, age):
        self.name = name
        self.age = age


class Dialogue:
    """
    Build an abstract dialogue model using classes and methods to represent different dialogue elements.
    This class serves as a fundamental framework for constructing dialogue models.
    """

    def __init__(self, file_path: str):
        self.file_path = file_path
        self.turns = []

    def add_turn(self, turn):
        """
        Create an instance of a conversation participant
        :param turn:
        :return:
        """
        self.turns.append(turn)

    def parse_dialogue(self):
        """
        The parse_dialogue function reads the specified dialogue file and parses each dialogue turn line by line.
        For each turn, the function extracts the name of the speaker and the message content from the text,
        creating a Turn instance. If the speaker is not already present in the participants dictionary,
        a new Person instance is created. Finally, the parsed Turn instance is added to the Dialogue object.

        Please note that this sample code assumes that each line in the file follows a specific format:
        <speaker>:\r\n<message>\r\n\r\n. If your file has a different format or includes other metadata,
        you may need to adjust the parsing logic accordingly.
        """
        participants = {}
        speaker_name = None
        message = None

        with open(self.file_path, encoding='utf-8') as file:
            lines = file.readlines()
            for i, line in enumerate(lines):
                line = line.strip()
                if not line:
                    continue

                if speaker_name is None:
                    speaker_name, _ = line.split(':', 1)
                elif message is None:
                    message = line
                    if speaker_name not in participants:
                        participants[speaker_name] = Person(speaker_name, None)

                    speaker = participants[speaker_name]
                    turn = Turn(speaker, message)
                    self.add_turn(turn)

                    # Reset speaker_name and message for the next turn
                    speaker_name = None
                    message = None

    def display(self):
        for turn in self.turns:
            print(f"{turn.speaker.name}: {turn.message}")

    def export_to_file(self, file_path):
        with open(file_path, 'w', encoding='utf-8') as file:
            for turn in self.turns:
                file.write(f"{turn.speaker.name}: {turn.message}\n")

    def to_dict(self):
        dialogue_dict = {"turns": []}
        for turn in self.turns:
            turn_dict = {
                "speaker": turn.speaker.name,
                "message": turn.message
            }
            dialogue_dict["turns"].append(turn_dict)
        return dialogue_dict

    def to_json(self):
        dialogue_dict = self.to_dict()
        return json.dumps(dialogue_dict, ensure_ascii=False, indent=2)

    def participants_to_export(self):
        """
        participants_to_export
        :return:
        """
        participants = set()
        for turn in self.turns:
            participants.add(turn.speaker.name)
        return ', '.join(participants)


class Turn:
    def __init__(self, speaker, message):
        self.speaker = speaker
        self.message = message


class DialogueLoader(BaseLoader, ABC):
    """Load dialogue."""

    def __init__(self, file_path: str):
        """Initialize with dialogue."""
        self.file_path = file_path
        dialogue = Dialogue(file_path=file_path)
        dialogue.parse_dialogue()
        self.dialogue = dialogue

    def load(self) -> List[Document]:
        """Load from dialogue."""
        documents = []
        participants = self.dialogue.participants_to_export()

        for turn in self.dialogue.turns:
            metadata = {"source": f"Dialogue File:{self.dialogue.file_path},"
                                  f"speaker:{turn.speaker.name},"
                                  f"participant:{participants}"}
            turn_document = Document(page_content=turn.message, metadata=metadata.copy())
            documents.append(turn_document)

        return documents
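As the parse_dialogue docstring notes, the loader expects alternating speaker and message lines separated by blank lines. A small sketch of a compatible input file and the resulting documents, assuming the loader package's dependencies are installed (file name and contents are illustrative only):

from loader.dialogue import DialogueLoader

sample = (
    "Alice:\n"
    "How do I rebuild the vector store?\n"
    "\n"
    "Bob:\n"
    "Delete the index folder and re-run the ingest script.\n"
    "\n"
)
with open("sample_dialogue.txt", "w", encoding="utf-8") as f:
    f.write(sample)

loader = DialogueLoader("sample_dialogue.txt")
for doc in loader.load():
    # each turn becomes one Document; metadata records the file, the speaker and all participants
    print(doc.metadata["source"], "->", doc.page_content)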
loader/image_loader.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Loader that loads image files."""
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
from langchain.document_loaders.unstructured import UnstructuredFileLoader
|
5 |
+
from paddleocr import PaddleOCR
|
6 |
+
import os
|
7 |
+
import nltk
|
8 |
+
from configs.model_config import NLTK_DATA_PATH
|
9 |
+
|
10 |
+
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
11 |
+
|
12 |
+
class UnstructuredPaddleImageLoader(UnstructuredFileLoader):
|
13 |
+
"""Loader that uses unstructured to load image files, such as PNGs and JPGs."""
|
14 |
+
|
15 |
+
def _get_elements(self) -> List:
|
16 |
+
def image_ocr_txt(filepath, dir_path="tmp_files"):
|
17 |
+
full_dir_path = os.path.join(os.path.dirname(filepath), dir_path)
|
18 |
+
if not os.path.exists(full_dir_path):
|
19 |
+
os.makedirs(full_dir_path)
|
20 |
+
filename = os.path.split(filepath)[-1]
|
21 |
+
ocr = PaddleOCR(use_angle_cls=True, lang="ch", use_gpu=False, show_log=False)
|
22 |
+
result = ocr.ocr(img=filepath)
|
23 |
+
|
24 |
+
ocr_result = [i[1][0] for line in result for i in line]
|
25 |
+
txt_file_path = os.path.join(full_dir_path, "%s.txt" % (filename))
|
26 |
+
with open(txt_file_path, 'w', encoding='utf-8') as fout:
|
27 |
+
fout.write("\n".join(ocr_result))
|
28 |
+
return txt_file_path
|
29 |
+
|
30 |
+
txt_file_path = image_ocr_txt(self.file_path)
|
31 |
+
from unstructured.partition.text import partition_text
|
32 |
+
return partition_text(filename=txt_file_path, **self.unstructured_kwargs)
|
33 |
+
|
34 |
+
|
35 |
+
if __name__ == "__main__":
|
36 |
+
import sys
|
37 |
+
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
38 |
+
filepath = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base", "samples", "content", "test.jpg")
|
39 |
+
loader = UnstructuredPaddleImageLoader(filepath, mode="elements")
|
40 |
+
docs = loader.load()
|
41 |
+
for doc in docs:
|
42 |
+
print(doc)
|
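The comprehension `[i[1][0] for line in result for i in line]` assumes PaddleOCR returns one list per image, where each detection is a pair of (bounding box, (text, confidence)); the comprehension keeps only the recognized text. A small sketch of that unpacking on a made-up, PaddleOCR-style result value:

# A made-up PaddleOCR-style result: one image, two detected text boxes.
result = [
    [
        [[[10, 10], [80, 10], [80, 30], [10, 30]], ("knowledge base", 0.98)],
        [[[10, 40], [120, 40], [120, 60], [10, 60]], ("langchain-ChatGLM", 0.95)],
    ]
]
ocr_result = [i[1][0] for line in result for i in line]
print(ocr_result)  # ['knowledge base', 'langchain-ChatGLM']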
loader/pdf_loader.py
ADDED
@@ -0,0 +1,58 @@
"""Loader that loads PDF files."""
from typing import List

from langchain.document_loaders.unstructured import UnstructuredFileLoader
from paddleocr import PaddleOCR
import os
import fitz
import nltk
from configs.model_config import NLTK_DATA_PATH

nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path

class UnstructuredPaddlePDFLoader(UnstructuredFileLoader):
    """Loader that uses unstructured to load PDF files, OCRing embedded images with PaddleOCR."""

    def _get_elements(self) -> List:
        def pdf_ocr_txt(filepath, dir_path="tmp_files"):
            full_dir_path = os.path.join(os.path.dirname(filepath), dir_path)
            if not os.path.exists(full_dir_path):
                os.makedirs(full_dir_path)
            ocr = PaddleOCR(use_angle_cls=True, lang="ch", use_gpu=False, show_log=False)
            doc = fitz.open(filepath)
            txt_file_path = os.path.join(full_dir_path, f"{os.path.split(filepath)[-1]}.txt")
            img_name = os.path.join(full_dir_path, 'tmp.png')
            with open(txt_file_path, 'w', encoding='utf-8') as fout:
                for i in range(doc.page_count):
                    page = doc[i]
                    text = page.get_text("")
                    fout.write(text)
                    fout.write("\n")

                    img_list = page.get_images()
                    for img in img_list:
                        pix = fitz.Pixmap(doc, img[0])
                        if pix.n - pix.alpha >= 4:
                            pix = fitz.Pixmap(fitz.csRGB, pix)
                        pix.save(img_name)

                        result = ocr.ocr(img_name)
                        ocr_result = [i[1][0] for line in result for i in line]
                        fout.write("\n".join(ocr_result))
            if os.path.exists(img_name):
                os.remove(img_name)
            return txt_file_path

        txt_file_path = pdf_ocr_txt(self.file_path)
        from unstructured.partition.text import partition_text
        return partition_text(filename=txt_file_path, **self.unstructured_kwargs)


if __name__ == "__main__":
    import sys
    sys.path.append(os.path.dirname(os.path.dirname(__file__)))
    filepath = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base", "samples", "content", "test.pdf")
    loader = UnstructuredPaddlePDFLoader(filepath, mode="elements")
    docs = loader.load()
    for doc in docs:
        print(doc)
models/__init__.py
ADDED
@@ -0,0 +1,4 @@
from .chatglm_llm import ChatGLM
from .llama_llm import LLamaLLM
from .moss_llm import MOSSLLM
from .fastchat_openai_llm import FastChatOpenAILLM
models/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (338 Bytes)
models/__pycache__/chatglm_llm.cpython-310.pyc
ADDED
Binary file (2.66 kB)
models/__pycache__/fastchat_openai_llm.cpython-310.pyc
ADDED
Binary file (4.45 kB)
models/__pycache__/llama_llm.cpython-310.pyc
ADDED
Binary file (6.45 kB)
models/__pycache__/moss_llm.cpython-310.pyc
ADDED
Binary file (3.88 kB)
models/__pycache__/shared.cpython-310.pyc
ADDED
Binary file (1.48 kB)
models/base/__init__.py
ADDED
@@ -0,0 +1,13 @@
from models.base.base import (
    AnswerResult,
    BaseAnswer
)
from models.base.remote_rpc_model import (
    RemoteRpcModel
)

__all__ = [
    "AnswerResult",
    "BaseAnswer",
    "RemoteRpcModel",
]
models/base/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (334 Bytes)
models/base/__pycache__/base.cpython-310.pyc
ADDED
Binary file (1.79 kB)
models/base/__pycache__/remote_rpc_model.cpython-310.pyc
ADDED
Binary file (1.59 kB)
models/base/base.py
ADDED
@@ -0,0 +1,41 @@
from abc import ABC, abstractmethod
from typing import Optional, List
import traceback
from collections import deque
from queue import Queue
from threading import Thread

import torch
import transformers
from models.loader import LoaderCheckPoint


class AnswerResult:
    """
    Message entity
    """
    history: List[List[str]] = []
    llm_output: Optional[dict] = None


class BaseAnswer(ABC):
    """Upper-level wrapper that gives result generation a unified API."""

    @property
    @abstractmethod
    def _check_point(self) -> LoaderCheckPoint:
        """Return _check_point of llm."""

    @property
    @abstractmethod
    def _history_len(self) -> int:
        """Return _history_len of llm."""

    @abstractmethod
    def set_history_len(self, history_len: int) -> None:
        """Set _history_len of llm."""

    def generatorAnswer(self, prompt: str,
                        history: List[List[str]] = [],
                        streaming: bool = False):
        pass
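BaseAnswer only fixes the interface; each backend implements generatorAnswer as a generator that yields AnswerResult objects. A minimal in-memory stub (illustrative only, not part of this commit, and assuming the repo's packages are importable) shows the expected contract:

from models.base import BaseAnswer, AnswerResult

class EchoLLM(BaseAnswer):
    """Toy backend that echoes the prompt; illustrates the generatorAnswer contract."""

    def __init__(self):
        self._len = 3

    @property
    def _check_point(self):
        return None  # no real LoaderCheckPoint behind this stub

    @property
    def _history_len(self) -> int:
        return self._len

    def set_history_len(self, history_len: int) -> None:
        self._len = history_len

    def generatorAnswer(self, prompt, history=[], streaming=False):
        result = AnswerResult()
        result.history = history + [[prompt, prompt]]
        result.llm_output = {"answer": prompt}
        yield result

for r in EchoLLM().generatorAnswer("ping"):
    print(r.llm_output["answer"])  # -> ping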
models/base/lavis_blip2_multimodel.py
ADDED
@@ -0,0 +1,26 @@
from abc import ABC, abstractmethod
import torch

from models.base import (BaseAnswer,
                         AnswerResult)


class MultimodalAnswerResult(AnswerResult):
    image: str = None


class LavisBlip2Multimodal(BaseAnswer, ABC):

    @property
    @abstractmethod
    def _blip2_instruct(self) -> any:
        """Return _blip2_instruct of blip2."""

    @property
    @abstractmethod
    def _image_blip2_vis_processors(self) -> dict:
        """Return _image_blip2_vis_processors of blip2 image processors."""

    @abstractmethod
    def set_image_path(self, image_path: str):
        """Set image_path."""
models/base/remote_rpc_model.py
ADDED
@@ -0,0 +1,33 @@
from abc import ABC, abstractmethod
import torch

from models.base import (BaseAnswer,
                         AnswerResult)


class MultimodalAnswerResult(AnswerResult):
    image: str = None


class RemoteRpcModel(BaseAnswer, ABC):

    @property
    @abstractmethod
    def _api_key(self) -> str:
        """Return _api_key of client."""

    @property
    @abstractmethod
    def _api_base_url(self) -> str:
        """Return _api_base_url of the client host."""

    @abstractmethod
    def set_api_key(self, api_key: str):
        """Set api_key."""

    @abstractmethod
    def set_api_base_url(self, api_base_url: str):
        """Set api_base_url."""

    @abstractmethod
    def call_model_name(self, model_name):
        """Call model name of client."""
models/chatglm_llm.py
ADDED
@@ -0,0 +1,83 @@
from abc import ABC
from langchain.llms.base import LLM
from typing import Optional, List
from models.loader import LoaderCheckPoint
from models.base import (BaseAnswer,
                         AnswerResult)


class ChatGLM(BaseAnswer, LLM, ABC):
    max_token: int = 10000
    temperature: float = 0.01
    top_p = 0.9
    checkPoint: LoaderCheckPoint = None
    # history = []
    history_len: int = 10

    def __init__(self, checkPoint: LoaderCheckPoint = None):
        super().__init__()
        self.checkPoint = checkPoint

    @property
    def _llm_type(self) -> str:
        return "ChatGLM"

    @property
    def _check_point(self) -> LoaderCheckPoint:
        return self.checkPoint

    @property
    def _history_len(self) -> int:
        return self.history_len

    def set_history_len(self, history_len: int = 10) -> None:
        self.history_len = history_len

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        print(f"__call:{prompt}")
        response, _ = self.checkPoint.model.chat(
            self.checkPoint.tokenizer,
            prompt,
            history=[],
            max_length=self.max_token,
            temperature=self.temperature
        )
        print(f"response:{response}")
        print(f"+++++++++++++++++++++++++++++++++++")
        return response

    def generatorAnswer(self, prompt: str,
                        history: List[List[str]] = [],
                        streaming: bool = False):

        if streaming:
            history += [[]]
            for inum, (stream_resp, _) in enumerate(self.checkPoint.model.stream_chat(
                    self.checkPoint.tokenizer,
                    prompt,
                    history=history[-self.history_len:-1] if self.history_len > 1 else [],
                    max_length=self.max_token,
                    temperature=self.temperature
            )):
                # self.checkPoint.clear_torch_cache()
                history[-1] = [prompt, stream_resp]
                answer_result = AnswerResult()
                answer_result.history = history
                answer_result.llm_output = {"answer": stream_resp}
                yield answer_result
        else:
            response, _ = self.checkPoint.model.chat(
                self.checkPoint.tokenizer,
                prompt,
                history=history[-self.history_len:] if self.history_len > 0 else [],
                max_length=self.max_token,
                temperature=self.temperature
            )
            self.checkPoint.clear_torch_cache()
            history += [[prompt, response]]
            answer_result = AnswerResult()
            answer_result.history = history
            answer_result.llm_output = {"answer": response}
            yield answer_result
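When streaming=True, generatorAnswer yields one AnswerResult per partial response, each carrying the accumulated history. A hedged sketch of how this class is typically driven (it assumes a LoaderCheckPoint whose reload_model() can actually find ChatGLM weights; the path is a placeholder):

from models.loader import LoaderCheckPoint
from models.chatglm_llm import ChatGLM

checkpoint = LoaderCheckPoint({"model_name": "chatglm-6b",
                               "model_path": "/path/to/chatglm-6b"})
checkpoint.reload_model()            # loads config, weights and tokenizer

llm = ChatGLM(checkPoint=checkpoint)
llm.set_history_len(5)

history = []
for result in llm.generatorAnswer("Hello", history=history, streaming=True):
    # each yield carries the accumulated history plus the latest partial answer
    print(result.llm_output["answer"])
history = result.history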
models/fastchat_openai_llm.py
ADDED
@@ -0,0 +1,137 @@
from abc import ABC
import requests
from typing import Optional, List
from langchain.llms.base import LLM

from models.loader import LoaderCheckPoint
from models.base import (RemoteRpcModel,
                         AnswerResult)
from typing import (
    Collection,
    Dict
)


def _build_message_template() -> Dict[str, str]:
    """
    :return: message structure
    """
    return {
        "role": "",
        "content": "",
    }


class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
    api_base_url: str = "http://localhost:8000/v1"
    model_name: str = "chatglm-6b"
    max_token: int = 10000
    temperature: float = 0.01
    top_p = 0.9
    checkPoint: LoaderCheckPoint = None
    history = []
    history_len: int = 10

    def __init__(self, checkPoint: LoaderCheckPoint = None):
        super().__init__()
        self.checkPoint = checkPoint

    @property
    def _llm_type(self) -> str:
        return "FastChat"

    @property
    def _check_point(self) -> LoaderCheckPoint:
        return self.checkPoint

    @property
    def _history_len(self) -> int:
        return self.history_len

    def set_history_len(self, history_len: int = 10) -> None:
        self.history_len = history_len

    @property
    def _api_key(self) -> str:
        pass

    @property
    def _api_base_url(self) -> str:
        return self.api_base_url

    def set_api_key(self, api_key: str):
        pass

    def set_api_base_url(self, api_base_url: str):
        self.api_base_url = api_base_url

    def call_model_name(self, model_name):
        self.model_name = model_name

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        print(f"__call:{prompt}")
        try:
            import openai
            # Not support yet
            openai.api_key = "EMPTY"
            openai.api_base = self.api_base_url
        except ImportError:
            raise ValueError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )
        # create a chat completion
        completion = openai.ChatCompletion.create(
            model=self.model_name,
            messages=self.build_message_list(prompt)
        )
        print(f"response:{completion.choices[0].message.content}")
        print(f"+++++++++++++++++++++++++++++++++++")
        return completion.choices[0].message.content

    # Convert the conversation history array into the message format
    def build_message_list(self, query) -> Collection[Dict[str, str]]:
        build_message_list: Collection[Dict[str, str]] = []
        history = self.history[-self.history_len:] if self.history_len > 0 else []
        for i, (old_query, response) in enumerate(history):
            user_build_message = _build_message_template()
            user_build_message['role'] = 'user'
            user_build_message['content'] = old_query
            system_build_message = _build_message_template()
            system_build_message['role'] = 'system'
            system_build_message['content'] = response
            build_message_list.append(user_build_message)
            build_message_list.append(system_build_message)

        user_build_message = _build_message_template()
        user_build_message['role'] = 'user'
        user_build_message['content'] = query
        build_message_list.append(user_build_message)
        return build_message_list

    def generatorAnswer(self, prompt: str,
                        history: List[List[str]] = [],
                        streaming: bool = False):

        try:
            import openai
            # Not support yet
            openai.api_key = "EMPTY"
            openai.api_base = self.api_base_url
        except ImportError:
            raise ValueError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )
        # create a chat completion
        completion = openai.ChatCompletion.create(
            model=self.model_name,
            messages=self.build_message_list(prompt)
        )

        history += [[prompt, completion.choices[0].message.content]]
        answer_result = AnswerResult()
        answer_result.history = history
        answer_result.llm_output = {"answer": completion.choices[0].message.content}

        yield answer_result
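This wrapper does not load any weights itself; it talks to a FastChat server that exposes the OpenAI-compatible /v1 API. A minimal sketch of wiring it up, assuming such a server is already running on localhost:8000 and serves a model registered as "chatglm-6b":

from models.fastchat_openai_llm import FastChatOpenAILLM

llm = FastChatOpenAILLM()
llm.set_api_base_url("http://localhost:8000/v1")   # FastChat's OpenAI-compatible endpoint
llm.call_model_name("chatglm-6b")                  # model name registered with the FastChat controller

for result in llm.generatorAnswer("What does this project do?", history=[], streaming=False):
    print(result.llm_output["answer"])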
models/llama_llm.py
ADDED
@@ -0,0 +1,185 @@
from abc import ABC

from langchain.llms.base import LLM
import random
import torch
import transformers
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
from typing import Optional, List, Dict, Any
from models.loader import LoaderCheckPoint
from models.base import (BaseAnswer,
                         AnswerResult)


class InvalidScoreLogitsProcessor(LogitsProcessor):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 5] = 5e4
        return scores


class LLamaLLM(BaseAnswer, LLM, ABC):
    checkPoint: LoaderCheckPoint = None
    # history = []
    history_len: int = 3
    max_new_tokens: int = 500
    num_beams: int = 1
    temperature: float = 0.5
    top_p: float = 0.4
    top_k: int = 10
    repetition_penalty: float = 1.2
    encoder_repetition_penalty: int = 1
    min_length: int = 0
    logits_processor: LogitsProcessorList = None
    stopping_criteria: Optional[StoppingCriteriaList] = None
    eos_token_id: Optional[int] = [2]

    state: object = {'max_new_tokens': 50,
                     'seed': 1,
                     'temperature': 0, 'top_p': 0.1,
                     'top_k': 40, 'typical_p': 1,
                     'repetition_penalty': 1.2,
                     'encoder_repetition_penalty': 1,
                     'no_repeat_ngram_size': 0,
                     'min_length': 0,
                     'penalty_alpha': 0,
                     'num_beams': 1,
                     'length_penalty': 1,
                     'early_stopping': False, 'add_bos_token': True, 'ban_eos_token': False,
                     'truncation_length': 2048, 'custom_stopping_strings': '',
                     'cpu_memory': 0, 'auto_devices': False, 'disk': False, 'cpu': False, 'bf16': False,
                     'load_in_8bit': False, 'wbits': 'None', 'groupsize': 'None', 'model_type': 'None',
                     'pre_layer': 0, 'gpu_memory_0': 0}

    def __init__(self, checkPoint: LoaderCheckPoint = None):
        super().__init__()
        self.checkPoint = checkPoint

    @property
    def _llm_type(self) -> str:
        return "LLamaLLM"

    @property
    def _check_point(self) -> LoaderCheckPoint:
        return self.checkPoint

    def encode(self, prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
        input_ids = self.checkPoint.tokenizer.encode(str(prompt), return_tensors='pt',
                                                     add_special_tokens=add_special_tokens)
        # This is a hack for making replies more creative.
        if not add_bos_token and input_ids[0][0] == self.checkPoint.tokenizer.bos_token_id:
            input_ids = input_ids[:, 1:]

        # Llama adds this extra token when the first character is '\n', and this
        # compromises the stopping criteria, so we just remove it
        if type(self.checkPoint.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871:
            input_ids = input_ids[:, 1:]

        # Handling truncation
        if truncation_length is not None:
            input_ids = input_ids[:, -truncation_length:]

        return input_ids.cuda()

    def decode(self, output_ids):
        reply = self.checkPoint.tokenizer.decode(output_ids, skip_special_tokens=True)
        return reply

    # Convert the conversation history array into text format
    def history_to_text(self, query, history):
        """
        Soft prompt built from the conversation history.
        history_to_text converts the history array into the required text format; the formatted
        history text is then turned into a vector representation with self.encode, and the history
        vector is concatenated with the vector of the current input.
        :return:
        """
        formatted_history = ''
        history = history[-self.history_len:] if self.history_len > 0 else []
        if len(history) > 0:
            for i, (old_query, response) in enumerate(history):
                formatted_history += "### Human:{}\n### Assistant:{}\n".format(old_query, response)
        formatted_history += "### Human:{}\n### Assistant:".format(query)
        return formatted_history

    def prepare_inputs_for_generation(self,
                                      input_ids: torch.LongTensor):
        """
        Pre-build the attention mask and the tensor of position indices for the input sequence.
        # TODO no approach worked out yet
        :return:
        """

        mask_positions = torch.zeros((1, input_ids.shape[1]), dtype=input_ids.dtype).to(self.checkPoint.model.device)

        attention_mask = self.get_masks(input_ids, input_ids.device)

        position_ids = self.get_position_ids(
            input_ids,
            device=input_ids.device,
            mask_positions=mask_positions
        )

        return input_ids, position_ids, attention_mask

    @property
    def _history_len(self) -> int:
        return self.history_len

    def set_history_len(self, history_len: int = 10) -> None:
        self.history_len = history_len

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        print(f"__call:{prompt}")
        if self.logits_processor is None:
            self.logits_processor = LogitsProcessorList()
            self.logits_processor.append(InvalidScoreLogitsProcessor())

        gen_kwargs = {
            "max_new_tokens": self.max_new_tokens,
            "num_beams": self.num_beams,
            "top_p": self.top_p,
            "do_sample": True,
            "top_k": self.top_k,
            "repetition_penalty": self.repetition_penalty,
            "encoder_repetition_penalty": self.encoder_repetition_penalty,
            "min_length": self.min_length,
            "temperature": self.temperature,
            "eos_token_id": self.checkPoint.tokenizer.eos_token_id,
            "logits_processor": self.logits_processor}

        # Encode the prompt into token ids
        input_ids = self.encode(prompt, add_bos_token=self.state['add_bos_token'], truncation_length=self.max_new_tokens)
        # input_ids, position_ids, attention_mask = self.prepare_inputs_for_generation(input_ids=filler_input_ids)

        gen_kwargs.update({'inputs': input_ids})
        # Attention mask
        # gen_kwargs.update({'attention_mask': attention_mask})
        # gen_kwargs.update({'position_ids': position_ids})
        if self.stopping_criteria is None:
            self.stopping_criteria = transformers.StoppingCriteriaList()
        # Observe the output
        gen_kwargs.update({'stopping_criteria': self.stopping_criteria})

        output_ids = self.checkPoint.model.generate(**gen_kwargs)
        new_tokens = len(output_ids[0]) - len(input_ids[0])
        reply = self.decode(output_ids[0][-new_tokens:])
        print(f"response:{reply}")
        print(f"+++++++++++++++++++++++++++++++++++")
        return reply

    def generatorAnswer(self, prompt: str,
                        history: List[List[str]] = [],
                        streaming: bool = False):

        # TODO: a chat module and attention handling still need to be implemented. _call is the
        # API exposed through langchain's LLM extension and defaults to no-prompt mode; to work
        # with the attention model, refer to the chat_glm implementation.
        softprompt = self.history_to_text(prompt, history=history)
        response = self._call(prompt=softprompt, stop=['\n###'])

        answer_result = AnswerResult()
        answer_result.history = history + [[prompt, response]]
        answer_result.llm_output = {"answer": response}
        yield answer_result
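history_to_text produces the "### Human / ### Assistant" soft prompt that _call feeds to the model. A small sketch of the string it builds, assuming the class can be instantiated without a loaded checkpoint just to inspect the prompt format (the example questions are illustrative):

from models.llama_llm import LLamaLLM

llm = LLamaLLM(checkPoint=None)      # no weights are needed just to inspect the prompt format
llm.set_history_len(3)

history = [["What is LangChain?", "A framework for building LLM applications."]]
print(llm.history_to_text("And what does this repo add?", history=history))
# ### Human:What is LangChain?
# ### Assistant:A framework for building LLM applications.
# ### Human:And what does this repo add?
# ### Assistant: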
models/loader/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .loader import *
models/loader/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (182 Bytes)
models/loader/__pycache__/args.cpython-310.pyc
ADDED
Binary file (1.73 kB)
models/loader/__pycache__/loader.cpython-310.pyc
ADDED
Binary file (11.1 kB)
models/loader/args.py
ADDED
@@ -0,0 +1,55 @@
import argparse
import os
from configs.model_config import *


# Additional argparse types
def path(string):
    if not string:
        return ''
    s = os.path.expanduser(string)
    if not os.path.exists(s):
        raise argparse.ArgumentTypeError(f'No such file or directory: "{string}"')
    return s


def file_path(string):
    if not string:
        return ''
    s = os.path.expanduser(string)
    if not os.path.isfile(s):
        raise argparse.ArgumentTypeError(f'No such file: "{string}"')
    return s


def dir_path(string):
    if not string:
        return ''
    s = os.path.expanduser(string)
    if not os.path.isdir(s):
        raise argparse.ArgumentTypeError(f'No such directory: "{string}"')
    return s


parser = argparse.ArgumentParser(prog='langchain-ChatGLM',
                                 description='About langchain-ChatGLM, local knowledge based ChatGLM with langchain | '
                                             '基于本地知识库的 ChatGLM 问答')

parser.add_argument('--no-remote-model', action='store_true', help='Do not fetch a remote model at loader checkpoint; '
                                                                    'add `--no-remote-model` when loading a local model.')
parser.add_argument('--model-name', type=str, default=LLM_MODEL, help='Name of the model to load by default.')
parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
parser.add_argument("--lora-dir", type=str, default=LORA_DIR, help="Path to directory with all the loras")

# Accelerate/transformers
parser.add_argument('--load-in-8bit', action='store_true', default=LOAD_IN_8BIT,
                    help='Load the model with 8-bit precision.')
parser.add_argument('--bf16', action='store_true', default=BF16,
                    help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')

args = parser.parse_args([])
# Generates a dict with a default value for each argument
DEFAULT_ARGS = vars(args)
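Because parse_args is called with an empty list, DEFAULT_ARGS is simply the dict of default values, and other modules can copy and override it before handing it to the loader. A short sketch of that pattern (key names follow the flags defined above):

from models.loader.args import DEFAULT_ARGS

params = dict(DEFAULT_ARGS)                  # copy of the argparse defaults
params.update({"no_remote_model": True,      # prefer a locally configured model path
               "load_in_8bit": False})
print(params["model_name"], params["lora_dir"])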
models/loader/loader.py
ADDED
@@ -0,0 +1,447 @@
import gc
import json
import os
import re
import time
from pathlib import Path
from typing import Optional, List, Dict, Tuple, Union
import torch
import transformers
from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
                          AutoTokenizer, LlamaTokenizer)
from configs.model_config import LLM_DEVICE


class LoaderCheckPoint:
    """
    Load a custom model CheckPoint.
    """
    # do not fetch a remote model at loader checkpoint
    no_remote_model: bool = False
    # model name
    model_name: str = None
    tokenizer: object = None
    # full path to the model
    model_path: str = None
    model: object = None
    model_config: object = None
    lora_names: set = []
    lora_dir: str = None
    ptuning_dir: str = None
    use_ptuning_v2: bool = False
    # If the project fails to start with 8-bit quantized loading enabled, pick a matching cuda version,
    # see https://github.com/TimDettmers/bitsandbytes/issues/156
    # Another possible cause is that bitsandbytes was installed against a cuda version from PATH that
    # does not match the one torch depends on, e.g. PATH contains cuda 10.2 and cuda 11.2, bitsandbytes
    # picked 10.2 while torch and friends were installed against 11.2.
    # The fix is to clean the mismatched cuda version out of PATH; a once-and-for-all recipe:
    # 0. run `pip uninstall bitsandbytes`
    # 1. delete the PATH entries from .bashrc
    # 2. run `echo $PATH >> .bashrc`
    # 3. remove the mismatched cuda paths from PATH in .bashrc
    # 4. run `source .bashrc`
    # 5. run `pip install bitsandbytes` again

    load_in_8bit: bool = False
    is_llamacpp: bool = False
    bf16: bool = False
    params: object = None
    # custom device map
    device_map: Optional[Dict[str, int]] = None
    # defaults to cuda; falls back to multi-GPU if cuda is unsupported, and to cpu if multi-GPU is unsupported
    llm_device = LLM_DEVICE

    def __init__(self, params: dict = None):
        """
        Model initialization.
        :param params:
        """
        self.model = None
        self.tokenizer = None
        self.params = params or {}
        self.model_name = params.get('model_name', False)
        self.model_path = params.get('model_path', None)
        self.no_remote_model = params.get('no_remote_model', False)
        self.lora = params.get('lora', '')
        self.use_ptuning_v2 = params.get('use_ptuning_v2', False)
        self.lora_dir = params.get('lora_dir', '')
        self.ptuning_dir = params.get('ptuning_dir', 'ptuning-v2')
        self.load_in_8bit = params.get('load_in_8bit', False)
        self.bf16 = params.get('bf16', False)

    def _load_model_config(self, model_name):

        if self.model_path:
            checkpoint = Path(f'{self.model_path}')
        else:
            if not self.no_remote_model:
                checkpoint = model_name
            else:
                raise ValueError(
                    "本地模型local_model_path未配置路径"
                )

        model_config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)

        return model_config

    def _load_model(self, model_name):
        """
        Load the model from a custom location.
        :param model_name:
        :return:
        """
        print(f"Loading {model_name}...")
        t0 = time.time()

        if self.model_path:
            checkpoint = Path(f'{self.model_path}')
        else:
            if not self.no_remote_model:
                checkpoint = model_name
            else:
                raise ValueError(
                    "本地模型local_model_path未配置路径"
                )

        self.is_llamacpp = len(list(Path(f'{checkpoint}').glob('ggml*.bin'))) > 0
        if 'chatglm' in model_name.lower():
            LoaderClass = AutoModel
        else:
            LoaderClass = AutoModelForCausalLM

        # Load the model in simple 16-bit mode by default
        # If loading succeeds but inference raises
        # RuntimeError: CUDA error: CUBLAS_STATUS_ALLOC_FAILED when calling `cublasCreate(handle)`,
        # GPU memory is still insufficient; consider --load-in-8bit or set the default model to `chatglm-6b-int8`
        if not any([self.llm_device.lower() == "cpu",
                    self.load_in_8bit, self.is_llamacpp]):

            if torch.cuda.is_available() and self.llm_device.lower().startswith("cuda"):
                # Decide whether to use multi-GPU deployment based on the number of GPUs on this machine
                num_gpus = torch.cuda.device_count()
                if num_gpus < 2 and self.device_map is None:
                    model = (
                        LoaderClass.from_pretrained(checkpoint,
                                                    config=self.model_config,
                                                    torch_dtype=torch.bfloat16 if self.bf16 else torch.float16,
                                                    trust_remote_code=True)
                        .half()
                        .cuda()
                    )
                else:
                    from accelerate import dispatch_model

                    model = LoaderClass.from_pretrained(checkpoint,
                                                        config=self.model_config,
                                                        torch_dtype=torch.bfloat16 if self.bf16 else torch.float16,
                                                        trust_remote_code=True).half()
                    # device_map can be passed in to customize the placement on each card
                    if self.device_map is None:
                        if 'chatglm' in model_name.lower():
                            self.device_map = self.chatglm_auto_configure_device_map(num_gpus)
                        elif 'moss' in model_name.lower():
                            self.device_map = self.moss_auto_configure_device_map(num_gpus, model_name)
                        else:
                            self.device_map = self.chatglm_auto_configure_device_map(num_gpus)

                    model = dispatch_model(model, device_map=self.device_map)
            else:
                model = (
                    LoaderClass.from_pretrained(
                        checkpoint,
                        config=self.model_config,
                        trust_remote_code=True)
                    .float()
                    .to(self.llm_device)
                )

        elif self.is_llamacpp:

            try:
                from models.extensions.llamacpp_model_alternative import LlamaCppModel

            except ImportError as exc:
                raise ValueError(
                    "Could not import depend python package "
                    "Please install it with `pip install llama-cpp-python`."
                ) from exc

            model_file = list(checkpoint.glob('ggml*.bin'))[0]
            print(f"llama.cpp weights detected: {model_file}\n")

            model, tokenizer = LlamaCppModel.from_pretrained(model_file)
            return model, tokenizer

        elif self.load_in_8bit:
            try:
                from accelerate import init_empty_weights
                from accelerate.utils import get_balanced_memory, infer_auto_device_map
                from transformers import BitsAndBytesConfig

            except ImportError as exc:
                raise ValueError(
                    "Could not import depend python package "
                    "Please install it with `pip install transformers` "
                    "`pip install bitsandbytes``pip install accelerate`."
                ) from exc

            params = {"low_cpu_mem_usage": True}

            if not self.llm_device.lower().startswith("cuda"):
                raise SystemError("8bit 模型需要 CUDA 支持,或者改用量化后模型!")
            else:
                params["device_map"] = 'auto'
                params["trust_remote_code"] = True
                params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True,
                                                                   llm_int8_enable_fp32_cpu_offload=False)

            with init_empty_weights():
                model = LoaderClass.from_config(self.model_config, trust_remote_code=True)
            model.tie_weights()
            if self.device_map is not None:
                params['device_map'] = self.device_map
            else:
                params['device_map'] = infer_auto_device_map(
                    model,
                    dtype=torch.int8,
                    no_split_module_classes=model._no_split_modules
                )
            try:

                model = LoaderClass.from_pretrained(checkpoint, **params)
            except ImportError as exc:
                raise ValueError(
                    "如果开启了8bit量化加载,项目无法启动,参考此位置,选择合适的cuda版本,https://github.com/TimDettmers/bitsandbytes/issues/156"
                ) from exc
        # Custom
        else:

            print(
                "Warning: self.llm_device is False.\nThis means that no use GPU bring to be load CPU mode\n")
            params = {"low_cpu_mem_usage": True, "torch_dtype": torch.float32, "trust_remote_code": True}
            model = LoaderClass.from_pretrained(checkpoint, **params).to(self.llm_device, dtype=float)

        # Loading the tokenizer
        if type(model) is transformers.LlamaForCausalLM:
            tokenizer = LlamaTokenizer.from_pretrained(checkpoint, clean_up_tokenization_spaces=True)
            # Leaving this here until the LLaMA tokenizer gets figured out.
            # For some people this fixes things, for others it causes an error.
            try:
                tokenizer.eos_token_id = 2
                tokenizer.bos_token_id = 1
                tokenizer.pad_token_id = 0
            except Exception as e:
                print(e)
                pass
        else:
            tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)

        print(f"Loaded the model in {(time.time() - t0):.2f} seconds.")
        return model, tokenizer

    def chatglm_auto_configure_device_map(self, num_gpus: int) -> Dict[str, int]:
        # transformer.word_embeddings takes 1 layer
        # transformer.final_layernorm and lm_head take 1 layer
        # transformer.layers take 28 layers
        # 30 layers in total, distributed across num_gpus cards
        num_trans_layers = 28
        per_gpu_layers = 30 / num_gpus

        # bugfix: PEFT-loaded LoRA models name the layers differently
        if self.lora:
            layer_prefix = 'base_model.model.transformer'
        else:
            layer_prefix = 'transformer'

        # bugfix: on linux, torch.embedding receives weight and input on different devices, causing a RuntimeError
        # on windows, model.device is set to transformer.word_embeddings.device
        # on linux, model.device is set to lm_head.device
        # when chat or stream_chat is called, input_ids is placed on model.device
        # if transformer.word_embeddings.device differs from model.device, a RuntimeError is raised
        # so transformer.word_embeddings, transformer.final_layernorm and lm_head are all placed on the first card

        encode = ""
        if 'chatglm2' in self.model_name:
            device_map = {
                f"{layer_prefix}.embedding.word_embeddings": 0,
                f"{layer_prefix}.rotary_pos_emb": 0,
                f"{layer_prefix}.output_layer": 0,
                f"{layer_prefix}.encoder.final_layernorm": 0,
                f"base_model.model.output_layer": 0
            }
            encode = ".encoder"
        else:
            device_map = {f'{layer_prefix}.word_embeddings': 0,
                          f'{layer_prefix}.final_layernorm': 0, 'lm_head': 0,
                          f'base_model.model.lm_head': 0, }
        used = 2
        gpu_target = 0
        for i in range(num_trans_layers):
            if used >= per_gpu_layers:
                gpu_target += 1
                used = 0
            assert gpu_target < num_gpus
            device_map[f'{layer_prefix}{encode}.layers.{i}'] = gpu_target
            used += 1

        return device_map

    def moss_auto_configure_device_map(self, num_gpus: int, model_name) -> Dict[str, int]:
        try:

            from accelerate import init_empty_weights
            from accelerate.utils import get_balanced_memory, infer_auto_device_map
            from transformers.dynamic_module_utils import get_class_from_dynamic_module
            from transformers.modeling_utils import no_init_weights
            from transformers.utils import ContextManagers
        except ImportError as exc:
            raise ValueError(
                "Could not import depend python package "
                "Please install it with `pip install transformers` "
                "`pip install bitsandbytes``pip install accelerate`."
            ) from exc

        if self.model_path:
            checkpoint = Path(f'{self.model_path}')
        else:
            if not self.no_remote_model:
                checkpoint = model_name
            else:
                raise ValueError(
                    "本地模型local_model_path未配置路径"
                )

        cls = get_class_from_dynamic_module(class_reference="fnlp/moss-moon-003-sft--modeling_moss.MossForCausalLM",
                                            pretrained_model_name_or_path=checkpoint)

        with ContextManagers([no_init_weights(_enable=True), init_empty_weights()]):
            model = cls(self.model_config)
        max_memory = get_balanced_memory(model, dtype=torch.int8 if self.load_in_8bit else None,
                                         low_zero=False, no_split_module_classes=model._no_split_modules)
        device_map = infer_auto_device_map(
            model, dtype=torch.float16 if not self.load_in_8bit else torch.int8, max_memory=max_memory,
            no_split_module_classes=model._no_split_modules)
        device_map["transformer.wte"] = 0
        device_map["transformer.drop"] = 0
        device_map["transformer.ln_f"] = 0
        device_map["lm_head"] = 0
        return device_map

    def _add_lora_to_model(self, lora_names):

        try:

            from peft import PeftModel

        except ImportError as exc:
            raise ValueError(
                "Could not import depend python package. "
                "Please install it with `pip install peft``pip install accelerate`."
            ) from exc
        # LoRAs currently loaded
        prior_set = set(self.lora_names)
        # LoRAs that need to be loaded
        added_set = set(lora_names) - prior_set
        # LoRAs to remove
        removed_set = prior_set - set(lora_names)
        self.lora_names = list(lora_names)

        # Nothing to do = skip.
        if len(added_set) == 0 and len(removed_set) == 0:
            return

        # Only adding, and already peft? Do it the easy way.
        if len(removed_set) == 0 and len(prior_set) > 0:
            print(f"Adding the LoRA(s) named {added_set} to the model...")
            for lora in added_set:
                self.model.load_adapter(Path(f"{self.lora_dir}/{lora}"), lora)
            return

        # If removing anything, disable all and re-add.
        if len(removed_set) > 0:
            self.model.disable_adapter()

        if len(lora_names) > 0:
            print("Applying the following LoRAs to {}: {}".format(self.model_name, ', '.join(lora_names)))
            params = {}
            if self.llm_device.lower() != "cpu":
                params['dtype'] = self.model.dtype
                if hasattr(self.model, "hf_device_map"):
                    params['device_map'] = {"base_model.model." + k: v for k, v in self.model.hf_device_map.items()}
                elif self.load_in_8bit:
                    params['device_map'] = {'': 0}
                self.model.resize_token_embeddings(len(self.tokenizer))

            self.model = PeftModel.from_pretrained(self.model, Path(f"{self.lora_dir}/{lora_names[0]}"), **params)

            for lora in lora_names[1:]:
                self.model.load_adapter(Path(f"{self.lora_dir}/{lora}"), lora)

            if not self.load_in_8bit and self.llm_device.lower() != "cpu":

                if not hasattr(self.model, "hf_device_map"):
                    if torch.has_mps:
                        device = torch.device('mps')
                        self.model = self.model.to(device)
                    else:
                        self.model = self.model.cuda()

    def clear_torch_cache(self):
        gc.collect()
        if self.llm_device.lower() != "cpu":
            if torch.has_mps:
                try:
                    from torch.mps import empty_cache
                    empty_cache()
                except Exception as e:
                    print(e)
                    print(
                        "如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,以支持及时清理 torch 产生的内存占用。")
            elif torch.has_cuda:
                device_id = "0" if torch.cuda.is_available() else None
                CUDA_DEVICE = f"{self.llm_device}:{device_id}" if device_id else self.llm_device
                with torch.cuda.device(CUDA_DEVICE):
                    torch.cuda.empty_cache()
                    torch.cuda.ipc_collect()
            else:
                print("未检测到 cuda 或 mps,暂不支持清理显存")

    def unload_model(self):
        del self.model
        del self.tokenizer
        self.model = self.tokenizer = None
        self.clear_torch_cache()

    def set_model_path(self, model_path):
        self.model_path = model_path

    def reload_model(self):
        self.unload_model()
        self.model_config = self._load_model_config(self.model_name)

        if self.use_ptuning_v2:
            try:
                prefix_encoder_file = open(Path(f'{self.ptuning_dir}/config.json'), 'r')
                prefix_encoder_config = json.loads(prefix_encoder_file.read())
                prefix_encoder_file.close()
                self.model_config.pre_seq_len = prefix_encoder_config['pre_seq_len']
                self.model_config.prefix_projection = prefix_encoder_config['prefix_projection']
            except Exception as e:
                print("加载PrefixEncoder config.json失败")

        self.model, self.tokenizer = self._load_model(self.model_name)

        if self.lora:
            self._add_lora_to_model([self.lora])

        if self.use_ptuning_v2:
            try:
                prefix_state_dict = torch.load(Path(f'{self.ptuning_dir}/pytorch_model.bin'))
                new_prefix_state_dict = {}
                for k, v in prefix_state_dict.items():
                    if k.startswith("transformer.prefix_encoder."):
                        new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
                self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
                self.model.transformer.prefix_encoder.float()
            except Exception as e:
                print("加载PrefixEncoder模型参数失败")

        self.model = self.model.eval()
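chatglm_auto_configure_device_map spreads the 28 transformer layers plus the embedding and output modules over the available cards, pinning the embeddings, final layernorm and lm_head to GPU 0 so that input_ids and model.device stay on the same card. A worked sketch of what the loop produces for two GPUs on a non-chatglm2, non-LoRA checkpoint (traced by hand from the code above; treat it as illustrative):

# For a 2-GPU machine and a plain ChatGLM checkpoint the method yields, in effect:
device_map = {
    "transformer.word_embeddings": 0,
    "transformer.final_layernorm": 0,
    "lm_head": 0,
    "base_model.model.lm_head": 0,
}
device_map.update({f"transformer.layers.{i}": 0 for i in range(0, 13)})   # layers 0-12 on GPU 0
device_map.update({f"transformer.layers.{i}": 1 for i in range(13, 28)})  # layers 13-27 on GPU 1
# GPU 0 ends up with the embeddings, final layernorm and lm_head plus 13 layers,
# GPU 1 with the remaining 15 layers, matching per_gpu_layers = 30 / 2.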