masanorihirano committed
Commit • 0d4eedd
Parent(s): d91928f
update
app.py CHANGED
@@ -11,8 +11,9 @@ from fastchat.serve.inference import compress_module
 from fastchat.serve.inference import raise_warning_for_old_weights
 from huggingface_hub import Repository
 from huggingface_hub import hf_hub_download
+from huggingface_hub import snapshot_download
 from peft import LoraConfig
-from peft import
+from peft import get_peft_model
 from peft import set_peft_model_state_dict
 from transformers import AutoModelForCausalLM
 from transformers import GenerationConfig
@@ -63,7 +64,12 @@ try:
 except Exception:
     pass
 
-
+resume_from_checkpoint = snapshot_download(
+    repo_id=LORA_WEIGHTS, use_auth_token=HF_TOKEN
+)
+checkpoint_name = hf_hub_download(
+    repo_id=LORA_WEIGHTS, filename="adapter_model.bin", use_auth_token=HF_TOKEN
+)
 if device == "cuda":
     model = AutoModelForCausalLM.from_pretrained(
         BASE_MODEL, load_in_8bit=True, device_map="auto", torch_dtype=torch.float16
@@ -83,12 +89,15 @@ else:
         low_cpu_mem_usage=True,
         torch_dtype=torch.float16,
     )
+
+config = LoraConfig.from_pretrained(resume_from_checkpoint)
+model = get_peft_model(model, config)
 adapters_weights = torch.load(checkpoint_name)
 set_peft_model_state_dict(model, adapters_weights)
 raise_warning_for_old_weights(BASE_MODEL, model)
 compress_module(model, device)
-if device == "cuda" or device == "mps":
-    model = model.to(device)
+# if device == "cuda" or device == "mps":
+#     model = model.to(device)
 
 
 def generate_prompt(instruction: str, input: Optional[str] = None):
@@ -308,5 +317,5 @@ with gr.Blocks(
     clear_button.click(reset_textbox, [], [instruction, inputs, outputs], queue=False)
 
 demo.queue(max_size=20, concurrency_count=NUM_THREADS, api_open=False).launch(
-    server_name="0.0.0.0", server_port=7860
+    share=True, server_name="0.0.0.0", server_port=7860
 )
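In short, the commit downloads the LoRA adapter from the Hub, wraps the base model with its configuration, and then injects the adapter weights. Below is a minimal sketch of that flow, assuming placeholder values for BASE_MODEL, LORA_WEIGHTS, and HF_TOKEN (the app defines these elsewhere) and omitting the app's 8-bit and device-specific loading branches.

import torch
from huggingface_hub import hf_hub_download, snapshot_download
from peft import LoraConfig, get_peft_model, set_peft_model_state_dict
from transformers import AutoModelForCausalLM

BASE_MODEL = "some-org/base-model"     # placeholder, not the app's actual value
LORA_WEIGHTS = "some-org/lora-adapter" # placeholder adapter repo id
HF_TOKEN = None                        # or a token if the adapter repo is private

# snapshot_download fetches the whole adapter repo (including adapter_config.json),
# while hf_hub_download fetches only the serialized LoRA weights file.
resume_from_checkpoint = snapshot_download(repo_id=LORA_WEIGHTS, use_auth_token=HF_TOKEN)
checkpoint_name = hf_hub_download(
    repo_id=LORA_WEIGHTS, filename="adapter_model.bin", use_auth_token=HF_TOKEN
)

# Load the base model, wrap it with the LoRA configuration, then load the
# downloaded adapter weights into the wrapped model.
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, torch_dtype=torch.float16)
config = LoraConfig.from_pretrained(resume_from_checkpoint)
model = get_peft_model(model, config)
adapters_weights = torch.load(checkpoint_name, map_location="cpu")
set_peft_model_state_dict(model, adapters_weights)
model.eval()

The remaining edits are smaller: the explicit model.to(device) move is commented out, presumably because the 8-bit device_map="auto" path and compress_module already handle placement, and launch() gains share=True so Gradio serves the demo through a temporary public link in addition to 0.0.0.0:7860.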