masanorihirano committed
Commit: 7a42c18
Parent(s): bed8c52

update

Files changed:
- app.py (+7 -29)
- pyproject.toml (+1 -1)
app.py CHANGED

@@ -9,16 +9,12 @@ from typing import Union
 import gradio as gr
 import requests
 import torch
+import transformers
 from fastchat.conversation import Conversation
-from fastchat.conversation import
-from fastchat.conversation import get_conv_template
-from fastchat.conversation import register_conv_template
-from fastchat.model.model_adapter import BaseAdapter
-from fastchat.model.model_adapter import load_model
-from fastchat.model.model_adapter import model_adapters
+from fastchat.conversation import get_default_conv_template
 from fastchat.serve.cli import SimpleChatIO
-from fastchat.serve.inference import compress_module
 from fastchat.serve.inference import generate_stream
+from fastchat.serve.inference import load_model
 from huggingface_hub import Repository
 from huggingface_hub import snapshot_download
 from peft import LoraConfig
@@ -30,24 +26,8 @@ from transformers import LlamaTokenizer
 from transformers import PreTrainedModel
 from transformers import PreTrainedTokenizerBase
 
-
-class LLaMAdapter(BaseAdapter):
-    "Model adapater for vicuna-v1.1"
-
-    def match(self, model_path: str):
-        return "llama" in model_path
-
-    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
-        tokenizer = LlamaTokenizer.from_pretrained(model_path, use_fast=False)
-        model = LlamaForCausalLM.from_pretrained(
-            model_path,
-            low_cpu_mem_usage=True,
-            **from_pretrained_kwargs,
-        )
-        return model, tokenizer
-
-
-model_adapters.insert(-1, LLaMAdapter())
+transformers.AutoTokenizer.from_pretrained = LlamaTokenizer.from_pretrained
+transformers.AutoModelForCausalLM.from_pretrained = LlamaForCausalLM.from_pretrained
 
 
 def load_lora_model(
@@ -67,12 +47,10 @@ def load_lora_model(
         device=device,
         num_gpus=num_gpus,
         max_gpu_memory=max_gpu_memory,
-        load_8bit=
+        load_8bit=True,
         cpu_offloading=cpu_offloading,
         debug=debug,
     )
-    if load_8bit:
-        compress_module(model)
     if lora_weight is not None:
         # model = PeftModelForCausalLM.from_pretrained(model, model_path, **kwargs)
         config = LoraConfig.from_pretrained(lora_weight)
@@ -217,7 +195,7 @@ def evaluate(
         gr.update(interactive=True),
     )
 
-    conv =
+    conv = get_default_conv_template(BASE_MODEL).copy()
 
     conv.append_message(conv.roles[0], instruction)
     conv.append_message(conv.roles[1], None)
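The core of the app.py change replaces the custom adapter registration with a monkey-patch: transformers' Auto entry points are redirected to the Llama classes, so fastchat's generic load_model resolves them without the deleted LLaMAdapter. A minimal sketch of the effect (the checkpoint path is a placeholder):

import transformers
from transformers import LlamaForCausalLM, LlamaTokenizer

# Redirect the Auto classes; any later Auto* call now returns the Llama
# implementations directly, which is what the deleted adapter used to do.
transformers.AutoTokenizer.from_pretrained = LlamaTokenizer.from_pretrained
transformers.AutoModelForCausalLM.from_pretrained = LlamaForCausalLM.from_pretrained

tokenizer = transformers.AutoTokenizer.from_pretrained("some/llama-checkpoint")
model = transformers.AutoModelForCausalLM.from_pretrained(
    "some/llama-checkpoint", low_cpu_mem_usage=True
)

The other functional change swaps the hand-built conversation setup for fastchat's template lookup. A hedged sketch of the flow the new line relies on, assuming the fastchat 0.2.3 API the diff imports (the model name and instruction text are placeholders; BASE_MODEL in the diff is the app's configured model path):

from fastchat.conversation import get_default_conv_template

conv = get_default_conv_template("llama-7b").copy()  # template matched by model name
conv.append_message(conv.roles[0], "instruction text")  # user turn
conv.append_message(conv.roles[1], None)  # empty slot for the model's reply
prompt = conv.get_prompt()  # serialized prompt later fed to generate_stream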
pyproject.toml CHANGED

@@ -15,7 +15,7 @@ huggingface-hub = "^0.14.1"
 sentencepiece = "^0.1.99"
 bitsandbytes = "^0.38.1"
 accelerate = "^0.19.0"
-fschat = "0.2.
+fschat = "0.2.3"
 transformers = "4.28.1"
 
 
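The fschat pin matches the API surface app.py now uses: in fschat 0.2.3, load_model still lives in fastchat.serve.inference and accepts the keyword set the diff passes. A sketch of that call under those assumptions (the model path is a placeholder):

from fastchat.serve.inference import load_model

model, tokenizer = load_model(
    "some/llama-checkpoint",
    device="cuda",
    num_gpus=1,
    max_gpu_memory=None,
    load_8bit=True,  # the commit hard-codes 8-bit loading
    cpu_offloading=False,
    debug=False,
)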
|