zetavg committed on
Commit
35fba55
•
1 Parent(s): 87a0e23
llama_lora/globals.py CHANGED
@@ -28,7 +28,7 @@ class Global:
     # UI related
     ui_title: str = "LLaMA-LoRA"
     ui_emoji: str = "🦙🎛️"
-    ui_subtitle: str = "Toolkit for examining and fine-tuning LLaMA models using low-rank adaptation (LoRA)."
+    ui_subtitle: str = "Toolkit for evaluating and fine-tuning LLaMA models with low-rank adaptation (LoRA)."
     ui_show_sys_info: bool = True
     ui_dev_mode: bool = False
 
llama_lora/ui/inference_ui.py CHANGED
@@ -7,14 +7,17 @@ from transformers import GenerationConfig
 
 from ..globals import Global
 from ..models import get_model_with_lora, get_tokenizer, get_device
-from ..utils.data import get_available_template_names
+from ..utils.data import (
+    get_available_template_names,
+    get_available_lora_model_names,
+    get_path_of_available_lora_model)
 from ..utils.prompter import Prompter
 from ..utils.callbacks import Iteratorize, Stream
 
 device = get_device()
 
 
-def inference(
+def do_inference(
     lora_model_name,
     prompt_template,
     variable_0, variable_1, variable_2, variable_3,
@@ -27,85 +30,93 @@ def inference(
     max_new_tokens=128,
     stream_output=False,
     progress=gr.Progress(track_tqdm=True),
-    **kwargs,
 ):
-    variables = [variable_0, variable_1, variable_2, variable_3,
-                 variable_4, variable_5, variable_6, variable_7]
-    prompter = Prompter(prompt_template)
-    prompt = prompter.generate_prompt(variables)
-
-    if Global.ui_dev_mode:
-        message = f"Currently in UI dev mode, not running actual inference.\n\nYour prompt is:\n\n{prompt}"
-        print(message)
-        time.sleep(1)
-        yield message
-        return
-
-    model = get_model_with_lora(lora_model_name)
-    tokenizer = get_tokenizer()
-
-    inputs = tokenizer(prompt, return_tensors="pt")
-    input_ids = inputs["input_ids"].to(device)
-    generation_config = GenerationConfig(
-        temperature=temperature,
-        top_p=top_p,
-        top_k=top_k,
-        repetition_penalty=repetition_penalty,
-        num_beams=num_beams,
-        **kwargs,
-    )
-
-    generate_params = {
-        "input_ids": input_ids,
-        "generation_config": generation_config,
-        "return_dict_in_generate": True,
-        "output_scores": True,
-        "max_new_tokens": max_new_tokens,
-    }
-
-    if stream_output:
-        # Stream the reply 1 token at a time.
-        # This is based on the trick of using 'stopping_criteria' to create an iterator,
-        # from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.
-
-        def generate_with_callback(callback=None, **kwargs):
-            kwargs.setdefault(
-                "stopping_criteria", transformers.StoppingCriteriaList()
-            )
-            kwargs["stopping_criteria"].append(
-                Stream(callback_func=callback)
-            )
-            with torch.no_grad():
-                model.generate(**kwargs)
-
-        def generate_with_streaming(**kwargs):
-            return Iteratorize(
-                generate_with_callback, kwargs, callback=None
-            )
-
-        with generate_with_streaming(**generate_params) as generator:
-            for output in generator:
-                # new_tokens = len(output) - len(input_ids[0])
-                decoded_output = tokenizer.decode(output)
-
-                if output[-1] in [tokenizer.eos_token_id]:
-                    break
-
-                yield prompter.get_response(decoded_output)
-        return  # early return for stream_output
-
-    # Without streaming
-    with torch.no_grad():
-        generation_output = model.generate(
-            input_ids=input_ids,
-            generation_config=generation_config,
-            return_dict_in_generate=True,
-            output_scores=True,
-            max_new_tokens=max_new_tokens,
-        )
-    s = generation_output.sequences[0]
-    output = tokenizer.decode(s)
-    yield prompter.get_response(output)
+    try:
+        variables = [variable_0, variable_1, variable_2, variable_3,
+                     variable_4, variable_5, variable_6, variable_7]
+        prompter = Prompter(prompt_template)
+        prompt = prompter.generate_prompt(variables)
+
+        if "/" not in lora_model_name:
+            path_of_available_lora_model = get_path_of_available_lora_model(
+                lora_model_name)
+            if path_of_available_lora_model:
+                lora_model_name = path_of_available_lora_model
+
+        if Global.ui_dev_mode:
+            message = f"Currently in UI dev mode, not running actual inference.\n\nLoRA model: {lora_model_name}\n\nYour prompt is:\n\n{prompt}"
+            print(message)
+            time.sleep(1)
+            yield message
+            return
+
+        model = get_model_with_lora(lora_model_name)
+        tokenizer = get_tokenizer()
+
+        inputs = tokenizer(prompt, return_tensors="pt")
+        input_ids = inputs["input_ids"].to(device)
+        generation_config = GenerationConfig(
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            repetition_penalty=repetition_penalty,
+            num_beams=num_beams,
+        )
+
+        generate_params = {
+            "input_ids": input_ids,
+            "generation_config": generation_config,
+            "return_dict_in_generate": True,
+            "output_scores": True,
+            "max_new_tokens": max_new_tokens,
+        }
+
+        if stream_output:
+            # Stream the reply 1 token at a time.
+            # This is based on the trick of using 'stopping_criteria' to create an iterator,
+            # from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.
+
+            def generate_with_callback(callback=None, **kwargs):
+                kwargs.setdefault(
+                    "stopping_criteria", transformers.StoppingCriteriaList()
+                )
+                kwargs["stopping_criteria"].append(
+                    Stream(callback_func=callback)
+                )
+                with torch.no_grad():
+                    model.generate(**kwargs)
+
+            def generate_with_streaming(**kwargs):
+                return Iteratorize(
+                    generate_with_callback, kwargs, callback=None
+                )
+
+            with generate_with_streaming(**generate_params) as generator:
+                for output in generator:
+                    # new_tokens = len(output) - len(input_ids[0])
+                    decoded_output = tokenizer.decode(output)
+
+                    if output[-1] in [tokenizer.eos_token_id]:
+                        break
+
+                    yield prompter.get_response(decoded_output)
+            return  # early return for stream_output
+
+        # Without streaming
+        with torch.no_grad():
+            generation_output = model.generate(
+                input_ids=input_ids,
+                generation_config=generation_config,
+                return_dict_in_generate=True,
+                output_scores=True,
+                max_new_tokens=max_new_tokens,
+            )
+        s = generation_output.sequences[0]
+        output = tokenizer.decode(s)
+        yield prompter.get_response(output)
+
+    except Exception as e:
+        raise gr.Error(e)
 
 
 def reload_selections(current_lora_model, current_prompt_template):
@@ -119,7 +130,7 @@ def reload_selections(current_lora_model, current_prompt_template):
         iter(available_template_names_with_none), None)
 
     default_lora_models = ["tloen/alpaca-lora-7b"]
-    available_lora_models = default_lora_models
+    available_lora_models = default_lora_models + get_available_lora_model_names()
 
     current_lora_model = current_lora_model or next(
         iter(available_lora_models), None)
@@ -263,7 +274,7 @@ def inference_ui():
                           variable_0, variable_1, variable_2, variable_3, variable_4, variable_5, variable_6, variable_7])
 
         generate_event = generate_btn.click(
-            fn=inference,
+            fn=do_inference,
            inputs=[
                lora_model,
                prompt_template,
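
Note on the streaming path added in do_inference: when stream_output is enabled, the handler leans on the Stream and Iteratorize helpers imported from ..utils.callbacks, which turn the callback-style stopping_criteria hook of model.generate into a plain Python iterator that the Gradio handler can yield from. The sketch below illustrates that pattern with simplified stand-in classes of the same names; it is an illustrative approximation, and the actual implementations in llama_lora/utils/callbacks.py may differ in detail.

# Simplified stand-ins for the Stream / Iteratorize helpers used above.
# Not the project's actual code; a minimal sketch of the callback-to-iterator trick.
from queue import Queue
from threading import Thread

import transformers


class Stream(transformers.StoppingCriteria):
    """Invokes callback_func with the tokens generated so far on every step."""

    def __init__(self, callback_func=None):
        self.callback_func = callback_func

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        if self.callback_func is not None:
            self.callback_func(input_ids[0])
        return False  # never stops generation by itself


class Iteratorize:
    """Runs a callback-based function in a thread and exposes its callbacks as an iterator."""

    def __init__(self, func, kwargs=None, callback=None):
        # `callback` is accepted only to mirror the call site's signature; unused here.
        self.q = Queue()
        self.sentinel = object()

        def _enqueue(value):
            self.q.put(value)

        def _run():
            try:
                func(callback=_enqueue, **(kwargs or {}))
            finally:
                self.q.put(self.sentinel)  # signal completion

        Thread(target=_run, daemon=True).start()

    def __iter__(self):
        return self

    def __next__(self):
        value = self.q.get()
        if value is self.sentinel:
            raise StopIteration
        return value

    # Context-manager support, matching `with generate_with_streaming(...) as generator:`.
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        return False

Each value the iterator yields is the growing sequence of token ids, which do_inference decodes with the tokenizer and passes through prompter.get_response before yielding it to the UI.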
llama_lora/utils/data.py CHANGED
@@ -38,6 +38,20 @@ def get_available_dataset_names():
     return [filename for filename in all_files if fnmatch.fnmatch(filename, "*.json") or fnmatch.fnmatch(filename, "*.jsonl")]
 
 
+def get_available_lora_model_names():
+    datasets_directory_path = os.path.join(Global.data_dir, "lora_models")
+    all_items = os.listdir(datasets_directory_path)
+    return [item for item in all_items if os.path.isdir(os.path.join(datasets_directory_path, item))]
+
+
+def get_path_of_available_lora_model(name):
+    datasets_directory_path = os.path.join(Global.data_dir, "lora_models")
+    path = os.path.join(datasets_directory_path, name)
+    if os.path.isdir(path):
+        return path
+    return None
+
+
 def get_dataset_content(name):
     file_name = os.path.join(Global.data_dir, "datasets", name)
     if not os.path.exists(file_name):
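
For reference, here is how the two helpers added above fit together with the name-resolution step in do_inference. This is a hypothetical usage sketch: the data directory layout (data/lora_models/my-finetune) and the direct assignment to Global.data_dir are illustrative assumptions, not part of this commit.

# Hypothetical usage of the new helpers; assumes a data/lora_models/ directory exists.
from llama_lora.globals import Global
from llama_lora.utils.data import (
    get_available_lora_model_names,
    get_path_of_available_lora_model,
)

Global.data_dir = "./data"  # normally configured at app startup

# Locally trained LoRA models are the directories under data/lora_models/.
print(get_available_lora_model_names())  # e.g. ['my-finetune']

# do_inference resolves bare names (no "/") against that directory first;
# names like "tloen/alpaca-lora-7b" are left as-is and treated as hub model IDs.
name = "my-finetune"
if "/" not in name:
    local_path = get_path_of_available_lora_model(name)
    if local_path:
        name = local_path  # e.g. './data/lora_models/my-finetune'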