Tuchuanhuhuhu committed
Commit: 9813f91 (parent: 6a49812)

feat: add GPT model fine-tuning support
ChuanhuChatbot.py CHANGED
@@ -5,6 +5,7 @@ logging.basicConfig(
     format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
 )
 
+import colorama
 import gradio as gr
 
 from modules import config
@@ -15,6 +16,7 @@ from modules.overwrites import *
 from modules.webui import *
 from modules.repo import *
 from modules.models.models import get_model
+from modules.train_func import handle_dataset_selection, handle_dataset_clear, upload_to_openai, start_training, get_training_status, add_to_models
 
 logging.getLogger("httpx").setLevel(logging.WARNING)
 
@@ -34,6 +36,7 @@ with gr.Blocks(theme=small_and_beautiful_theme) as demo:
     assert type(my_api_key)==str
     user_api_key = gr.State(my_api_key)
     current_model = gr.State(create_new_model)
+    openai_ft_file_id = gr.State("")
 
     topic = gr.State(i18n("未命名对话历史记录"))
 
@@ -188,14 +191,17 @@ with gr.Blocks(theme=small_and_beautiful_theme) as demo:
             with gr.Tab(label=i18n("训练")):
                 with gr.Column(variant="panel"):
                     dataset_preview_json = gr.JSON(label=i18n("数据集预览"), readonly=True)
-                    upload_dataset_btn = gr.UploadButton(label = i18n("上传数据集"), file_types=[".xlsx", ".jsonl"])
+                    dataset_selection = gr.Files(label = i18n("选择数据集"), file_types=[".xlsx", ".jsonl"], file_count="single")
+                    upload_to_openai_btn = gr.Button(i18n("上传到OpenAI"), interactive=False)
+
                 with gr.Column(variant="panel"):
+                    openai_ft_suffix = gr.Textbox(label=i18n("模型名称后缀"), value="", lines=1, placeholder=i18n("可选,用于区分不同的模型"))
                     openai_train_epoch_slider = gr.Slider(label=i18n("训练轮数"), minimum=1, maximum=100, value=3, step=1, interactive=True)
                     openai_start_train_btn = gr.Button(i18n("开始训练"))
                 with gr.Column(variant="panel"):
                     openai_train_status = gr.Markdown(label=i18n("训练状态"), value=i18n("未开始训练"))
                     openai_status_refresh_btn = gr.Button(i18n("刷新状态"))
-                    add_to_models_btn = gr.Button(i18n("添加到模型列表"), interactive=False)
+                    add_to_models_btn = gr.Button(i18n("添加训练好的模型到模型列表"), interactive=False)
 
             with gr.Tab(label=i18n("高级")):
                 gr.HTML(get_html("appearance_switcher.html").format(label=i18n("切换亮暗色主题")), elem_classes="insert-block")
@@ -485,6 +491,14 @@ with gr.Blocks(theme=small_and_beautiful_theme) as demo:
    historyFileSelectDropdown.change(**load_history_from_file_args)
    downloadFile.change(upload_chat_history, [current_model, downloadFile, user_name], [saveFileName, systemPromptTxt, chatbot])
 
+    # Train
+    dataset_selection.upload(handle_dataset_selection, dataset_selection, [dataset_preview_json, upload_to_openai_btn, status_display])
+    dataset_selection.clear(handle_dataset_clear, [], [dataset_preview_json, upload_to_openai_btn])
+    upload_to_openai_btn.click(upload_to_openai, [dataset_selection], [openai_ft_file_id, status_display], show_progress=True)
+    openai_start_train_btn.click(start_training, [openai_ft_file_id, openai_ft_suffix, openai_train_epoch_slider], [openai_train_status])
+    openai_status_refresh_btn.click(get_training_status, [], [openai_train_status, add_to_models_btn])
+    add_to_models_btn.click(add_to_models, [], [model_select_dropdown, status_display], show_progress=True)
+
    # Advanced
    max_context_length_slider.change(set_token_upper_limit, [current_model, max_context_length_slider], None)
    temperature_slider.change(set_temperature, [current_model, temperature_slider], None)
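Note on the wiring above: each training step gates the next by returning gr.update(interactive=...) for the following button, so the flow is select dataset, upload to OpenAI, start training, refresh status, add model. A minimal self-contained sketch of that gating pattern, mirroring the .upload/.clear events used above (component and handler names here are illustrative, not from the repo):

import gradio as gr

def on_pick(file):
    # A file was chosen: preview its name and unlock the next step.
    return f"Selected: {file.name}", gr.update(interactive=True)

def on_clear():
    # Selection cleared: reset the preview and lock the button again.
    return gr.update(value=None), gr.update(interactive=False)

with gr.Blocks() as demo:
    picker = gr.Files(file_count="single")
    preview = gr.Textbox(label="Preview")
    next_btn = gr.Button("Next step", interactive=False)
    picker.upload(on_pick, picker, [preview, next_btn])
    picker.clear(on_clear, [], [preview, next_btn])

demo.launch()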
modules/index_func.py CHANGED
@@ -1,7 +1,7 @@
 import os
 import logging
 
-import colorama
+import hashlib
 import PyPDF2
 from tqdm import tqdm
 
@@ -10,19 +10,6 @@ from modules.utils import *
 from modules.config import local_embedding
 
 
-def get_index_name(file_src):
-    file_paths = [x.name for x in file_src]
-    file_paths.sort(key=lambda x: os.path.basename(x))
-
-    md5_hash = hashlib.md5()
-    for file_path in file_paths:
-        with open(file_path, "rb") as f:
-            while chunk := f.read(8192):
-                md5_hash.update(chunk)
-
-    return md5_hash.hexdigest()
-
-
 def get_documents(file_src):
     from langchain.schema import Document
     from langchain.text_splitter import TokenTextSplitter
@@ -113,7 +100,7 @@ def construct_index(
     embedding_limit = None if embedding_limit == 0 else embedding_limit
     separator = " " if separator == "" else separator
 
-    index_name = get_index_name(file_src)
+    index_name = get_file_hash(file_src)
     index_path = f"./index/{index_name}"
     if local_embedding:
         from langchain.embeddings.huggingface import HuggingFaceEmbeddings
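Keying the index directory to a digest of the file bytes means re-uploading the same documents hits the cached index, while any content change produces a fresh directory and forces a rebuild. A hypothetical wrapper illustrating that caching effect (assumes the repo root is on the import path; file_src is a list of gradio file objects, as in construct_index):

import os
from modules.utils import get_file_hash  # helper introduced by this commit

def load_or_build_index_path(file_src):
    # Identical file contents hash to the same directory name.
    index_path = f"./index/{get_file_hash(file_src)}"
    if os.path.exists(index_path):
        return index_path  # cache hit: reuse the stored embeddings
    os.makedirs(index_path)  # cache miss: construct_index() builds here
    return index_path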
modules/train_func.py ADDED
@@ -0,0 +1,116 @@
+import os
+import logging
+import traceback
+
+import openai
+import gradio as gr
+import ujson as json
+
+import modules.presets as presets
+from modules.utils import get_file_hash
+from modules.presets import i18n
+
+def excel_to_jsonl(filepath, preview=False):
+    jsonl = []
+    with open(filepath, "rb") as f:
+        import pandas as pd
+        df = pd.read_excel(f)
+        for row in df.iterrows():
+            jsonl.append(row[1].to_dict())
+            if preview:
+                break
+    return jsonl
+
+def jsonl_save_to_disk(jsonl, filepath):
+    file_hash = get_file_hash(file_paths = [filepath])
+    os.makedirs("files", exist_ok=True)
+    save_path = f"files/{file_hash}.jsonl"
+    with open(save_path, "w") as f:
+        f.write("\n".join([json.dumps(i, ensure_ascii=False) for i in jsonl]))
+    return save_path
+
+def handle_dataset_selection(file_src):
+    logging.info(f"Loading dataset {file_src.name}...")
+    preview = ""
+    if file_src.name.endswith(".jsonl"):
+        with open(file_src.name, "r") as f:
+            preview = f.readline()
+    else:
+        preview = excel_to_jsonl(file_src.name)[0]
+    return preview, gr.update(interactive=True), "预估数据集 token 数量: 这个功能还没实现"
+
+def upload_to_openai(file_src):
+    openai.api_key = os.getenv("OPENAI_API_KEY")
+    dspath = file_src.name
+    msg = ""
+    logging.info(f"Uploading dataset {dspath}...")
+    if dspath.endswith(".xlsx"):
+        jsonl = excel_to_jsonl(dspath)
+        tmp_jsonl = []
+        for i in jsonl:
+            if "提问" in i and "答案" in i:
+                if "系统" in i :
+                    tmp_jsonl.append({
+                        "messages":[
+                            {"role": "system", "content": i["系统"]},
+                            {"role": "user", "content": i["提问"]},
+                            {"role": "assistant", "content": i["答案"]}
+                        ]
+                    })
+                else:
+                    tmp_jsonl.append({
+                        "messages":[
+                            {"role": "user", "content": i["提问"]},
+                            {"role": "assistant", "content": i["答案"]}
+                        ]
+                    })
+            else:
+                logging.warning(f"跳过一行数据,因为没有找到提问和答案: {i}")
+        jsonl = tmp_jsonl
+        dspath = jsonl_save_to_disk(jsonl, dspath)
+    try:
+        uploaded = openai.File.create(
+            file=open(dspath, "rb"),
+            purpose='fine-tune'
+        )
+        return uploaded.id, f"上传成功,文件ID: {uploaded.id}"
+    except Exception as e:
+        traceback.print_exc()
+        return "", f"上传失败,原因:{ e }"
+
+def build_event_description(id, status, trained_tokens, name=i18n("暂时未知")):
+    # convert to markdown
+    return f"""
+#### 训练任务 {id}
+
+模型名称:{name}
+
+状态:{status}
+
+已经训练了 {trained_tokens} 个token
+"""
+
+def start_training(file_id, suffix, epochs):
+    openai.api_key = os.getenv("OPENAI_API_KEY")
+    try:
+        job = openai.FineTuningJob.create(training_file=file_id, model="gpt-3.5-turbo", suffix=suffix, hyperparameters={"n_epochs": epochs})
+        return build_event_description(job.id, job.status, job.trained_tokens)
+    except Exception as e:
+        traceback.print_exc()
+        if "is not ready" in str(e):
+            return "训练出错,因为文件还没准备好。OpenAI 需要一点时间准备文件,过几分钟再来试试。"
+        return f"训练失败,原因:{ e }"
+
+def get_training_status():
+    openai.api_key = os.getenv("OPENAI_API_KEY")
+    active_jobs = [build_event_description(job["id"], job["status"], job["trained_tokens"], job["fine_tuned_model"]) for job in openai.FineTuningJob.list(limit=10)["data"] if job["status"] != "cancelled"]
+    return "\n\n".join(active_jobs), gr.update(interactive=True) if len(active_jobs) > 0 else gr.update(interactive=False)
+
+def handle_dataset_clear():
+    return gr.update(value=None), gr.update(interactive=False)
+
+def add_to_models():
+    openai.api_key = os.getenv("OPENAI_API_KEY")
+    succeeded_jobs = [job for job in openai.FineTuningJob.list(limit=10)["data"] if job["status"] == "succeeded"]
+    presets.MODELS.extend([job["fine_tuned_model"] for job in succeeded_jobs])
+    return gr.update(choices=presets.MODELS), f"成功添加了 {len(succeeded_jobs)} 个模型。"
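For reference, upload_to_openai maps each spreadsheet row onto the chat-format JSONL that OpenAI fine-tuning expects: one JSON object per line with a messages array, where the 系统 (system) column, if present, becomes the system turn and 提问 (question) / 答案 (answer) become the user and assistant turns. A converted row looks like this (contents illustrative):

{"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What is fine-tuning?"}, {"role": "assistant", "content": "Fine-tuning continues training a base model on your own examples."}]}

The "is not ready" branch in start_training reflects that OpenAI validates uploaded files asynchronously: a freshly uploaded file ID can take a few minutes before a fine-tuning job will accept it, which is why the UI asks the user to retry rather than treating it as a hard failure.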
modules/utils.py CHANGED
@@ -5,14 +5,11 @@ import logging
 import commentjson as json
 import os
 import datetime
-from datetime import timezone
-import hashlib
 import csv
 import requests
 import re
 import html
-import sys
-import subprocess
+import hashlib
 
 import gradio as gr
 from pypinyin import lazy_pinyin
@@ -241,7 +238,7 @@ def convert_bot_before_marked(chat_message):
     code_block_pattern = re.compile(r"```(.*?)(?:```|$)", re.DOTALL)
     code_blocks = code_block_pattern.findall(chat_message)
     non_code_parts = code_block_pattern.split(chat_message)[::2]
-    result = []
+    result = []
     for non_code, code in zip(non_code_parts, code_blocks + [""]):
         if non_code.strip():
             result.append(non_code)
@@ -670,3 +667,16 @@ def auth_from_conf(username, password):
         return False
     except:
         return False
+
+def get_file_hash(file_src=None, file_paths=None):
+    if file_src:
+        file_paths = [x.name for x in file_src]
+    file_paths.sort(key=lambda x: os.path.basename(x))
+
+    md5_hash = hashlib.md5()
+    for file_path in file_paths:
+        with open(file_path, "rb") as f:
+            while chunk := f.read(8192):
+                md5_hash.update(chunk)
+
+    return md5_hash.hexdigest()
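get_file_hash generalizes the removed get_index_name: it accepts either gradio file objects via file_src (their .name temp paths are hashed) or plain paths via file_paths, and it streams each file in 8 KB chunks so large uploads are hashed without loading them wholly into memory. Both call shapes from this commit, in a small runnable sketch (the literal path is illustrative and must exist):

from modules.utils import get_file_hash  # assumes the repo root on sys.path

# Plain paths, as jsonl_save_to_disk does:
print(get_file_hash(file_paths=["files/dataset.jsonl"]))

# Gradio file objects, as construct_index does: anything with a .name
# attribute pointing at a real file works.
class FakeUpload:
    name = "files/dataset.jsonl"

print(get_file_hash(file_src=[FakeUpload()]))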
requirements.txt CHANGED
@@ -21,7 +21,8 @@ duckduckgo-search
 arxiv
 wikipedia
 google.generativeai
-openai
+openai>=0.27.9
 unstructured
 google-api-python-client
 tabulate
+ujson
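The openai>=0.27.9 floor tracks the fine-tuning job API that modules/train_func.py now calls, and ujson backs its import ujson as json. Existing environments can be brought up to date with:

pip install --upgrade "openai>=0.27.9" ujson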