Commit 35fba55 by zetavg
Parent: 87a0e23

update

Files changed:
- llama_lora/globals.py +1 -1
- llama_lora/ui/inference_ui.py +90 -79
- llama_lora/utils/data.py +14 -0
llama_lora/globals.py
CHANGED

@@ -28,7 +28,7 @@ class Global:
     # UI related
     ui_title: str = "LLaMA-LoRA"
     ui_emoji: str = "🦙🎛️"
-    ui_subtitle: str = "Toolkit for
+    ui_subtitle: str = "Toolkit for evaluating and fine-tuning LLaMA models with low-rank adaptation (LoRA)."
     ui_show_sys_info: bool = True
     ui_dev_mode: bool = False
llama_lora/ui/inference_ui.py
CHANGED

@@ -7,14 +7,17 @@ from transformers import GenerationConfig
 
 from ..globals import Global
 from ..models import get_model_with_lora, get_tokenizer, get_device
-from ..utils.data import
+from ..utils.data import (
+    get_available_template_names,
+    get_available_lora_model_names,
+    get_path_of_available_lora_model)
 from ..utils.prompter import Prompter
 from ..utils.callbacks import Iteratorize, Stream
 
 device = get_device()
 
 
-def inference(
+def do_inference(
     lora_model_name,
     prompt_template,
     variable_0, variable_1, variable_2, variable_3,

@@ -27,85 +30,93 @@ def inference(
     max_new_tokens=128,
     stream_output=False,
     progress=gr.Progress(track_tqdm=True),
-    **kwargs,
 ):
-        "max_new_tokens": max_new_tokens,
-    }
-
-    if stream_output:
-        # Stream the reply 1 token at a time.
-        # This is based on the trick of using 'stopping_criteria' to create an iterator,
-        # from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.
-
-        def generate_with_callback(callback=None, **kwargs):
-            kwargs.setdefault(
-                "stopping_criteria", transformers.StoppingCriteriaList()
-            )
-            kwargs["stopping_criteria"].append(
-                Stream(callback_func=callback)
-            )
-            with torch.no_grad():
-                model.generate(**kwargs)
-
-                # new_tokens = len(output) - len(input_ids[0])
-                decoded_output = tokenizer.decode(output)
-
-                if output[-1] in [tokenizer.eos_token_id]:
-                    break
-
-                yield prompter.get_response(decoded_output)
-        return  # early return for stream_output
-
-    # Without streaming
-    with torch.no_grad():
-        generation_output = model.generate(
-            input_ids=input_ids,
-            generation_config=generation_config,
-            return_dict_in_generate=True,
-            output_scores=True,
-            max_new_tokens=max_new_tokens,
-        )
-    s = generation_output.sequences[0]
-    output = tokenizer.decode(s)
-    yield prompter.get_response(output)
+    try:
+        variables = [variable_0, variable_1, variable_2, variable_3,
+                     variable_4, variable_5, variable_6, variable_7]
+        prompter = Prompter(prompt_template)
+        prompt = prompter.generate_prompt(variables)
+
+        if "/" not in lora_model_name:
+            path_of_available_lora_model = get_path_of_available_lora_model(
+                lora_model_name)
+            if path_of_available_lora_model:
+                lora_model_name = path_of_available_lora_model
+
+        if Global.ui_dev_mode:
+            message = f"Currently in UI dev mode, not running actual inference.\n\nLoRA model: {lora_model_name}\n\nYour prompt is:\n\n{prompt}"
+            print(message)
+            time.sleep(1)
+            yield message
+            return
+
+        model = get_model_with_lora(lora_model_name)
+        tokenizer = get_tokenizer()
+
+        inputs = tokenizer(prompt, return_tensors="pt")
+        input_ids = inputs["input_ids"].to(device)
+        generation_config = GenerationConfig(
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            repetition_penalty=repetition_penalty,
+            num_beams=num_beams,
+        )
+
+        generate_params = {
+            "input_ids": input_ids,
+            "generation_config": generation_config,
+            "return_dict_in_generate": True,
+            "output_scores": True,
+            "max_new_tokens": max_new_tokens,
+        }
+
+        if stream_output:
+            # Stream the reply 1 token at a time.
+            # This is based on the trick of using 'stopping_criteria' to create an iterator,
+            # from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.
+
+            def generate_with_callback(callback=None, **kwargs):
+                kwargs.setdefault(
+                    "stopping_criteria", transformers.StoppingCriteriaList()
+                )
+                kwargs["stopping_criteria"].append(
+                    Stream(callback_func=callback)
+                )
+                with torch.no_grad():
+                    model.generate(**kwargs)
+
+            def generate_with_streaming(**kwargs):
+                return Iteratorize(
+                    generate_with_callback, kwargs, callback=None
+                )
+
+            with generate_with_streaming(**generate_params) as generator:
+                for output in generator:
+                    # new_tokens = len(output) - len(input_ids[0])
+                    decoded_output = tokenizer.decode(output)
+
+                    if output[-1] in [tokenizer.eos_token_id]:
+                        break
+
+                    yield prompter.get_response(decoded_output)
+            return  # early return for stream_output
+
+        # Without streaming
+        with torch.no_grad():
+            generation_output = model.generate(
+                input_ids=input_ids,
+                generation_config=generation_config,
+                return_dict_in_generate=True,
+                output_scores=True,
+                max_new_tokens=max_new_tokens,
+            )
+        s = generation_output.sequences[0]
+        output = tokenizer.decode(s)
+        yield prompter.get_response(output)
 
+    except Exception as e:
+        raise gr.Error(e)
 
 
 def reload_selections(current_lora_model, current_prompt_template):

@@ -119,7 +130,7 @@ def reload_selections(current_lora_model, current_prompt_template):
         iter(available_template_names_with_none), None)
 
     default_lora_models = ["tloen/alpaca-lora-7b"]
-    available_lora_models = default_lora_models
+    available_lora_models = default_lora_models + get_available_lora_model_names()
 
     current_lora_model = current_lora_model or next(
         iter(available_lora_models), None)

@@ -263,7 +274,7 @@ def inference_ui():
                 variable_0, variable_1, variable_2, variable_3, variable_4, variable_5, variable_6, variable_7])
 
         generate_event = generate_btn.click(
-            fn=
+            fn=do_inference,
             inputs=[
                 lora_model,
                 prompt_template,
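Note on the streaming branch above: it depends on the Stream and Iteratorize helpers imported from ..utils.callbacks, which this commit does not touch or show. The sketch below is a rough, non-authoritative reconstruction of how such helpers are typically built, following the text-generation-webui trick linked in the code comments; the repo's actual implementations in llama_lora/utils/callbacks.py may differ.

# A minimal sketch, assuming Stream/Iteratorize follow the linked pattern.
import queue
import threading

import transformers


class Stream(transformers.StoppingCriteria):
    """Stopping criterion that never stops generation; it only reports progress.

    model.generate() invokes it after each generated token, so the callback
    receives the token ids produced so far on every step.
    """

    def __init__(self, callback_func=None):
        self.callback_func = callback_func

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        if self.callback_func is not None:
            self.callback_func(input_ids[0])
        return False  # never request a stop


class Iteratorize:
    """Run `func(callback=..., **kwargs)` in a thread; iterate over callback values."""

    def __init__(self, func, kwargs=None, callback=None):
        # `callback` is kept only for signature compatibility with the call
        # site above; it is unused in this sketch.
        self.q = queue.Queue()
        self.sentinel = object()

        def _enqueue(value):
            self.q.put(value)

        def _worker():
            try:
                func(callback=_enqueue, **(kwargs or {}))
            finally:
                self.q.put(self.sentinel)  # always unblock the consumer

        threading.Thread(target=_worker, daemon=True).start()

    def __iter__(self):
        return self

    def __next__(self):
        value = self.q.get()
        if value is self.sentinel:
            raise StopIteration
        return value

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # A fuller implementation would also signal the worker thread to stop
        # generating when the consumer exits early (e.g. after hitting EOS).
        return False

This is what lets do_inference() iterate over partial outputs and yield a decoded response per token while generation runs in the background.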
llama_lora/utils/data.py
CHANGED

@@ -38,6 +38,20 @@ def get_available_dataset_names():
     return [filename for filename in all_files if fnmatch.fnmatch(filename, "*.json") or fnmatch.fnmatch(filename, "*.jsonl")]
 
 
+def get_available_lora_model_names():
+    datasets_directory_path = os.path.join(Global.data_dir, "lora_models")
+    all_items = os.listdir(datasets_directory_path)
+    return [item for item in all_items if os.path.isdir(os.path.join(datasets_directory_path, item))]
+
+
+def get_path_of_available_lora_model(name):
+    datasets_directory_path = os.path.join(Global.data_dir, "lora_models")
+    path = os.path.join(datasets_directory_path, name)
+    if os.path.isdir(path):
+        return path
+    return None
+
+
 def get_dataset_content(name):
     file_name = os.path.join(Global.data_dir, "datasets", name)
     if not os.path.exists(file_name):