zetavg committed
Commit • f79f0d5
1 parent: 94ece44

update
llama_lora/ui/inference_ui.py CHANGED
@@ -6,7 +6,7 @@ import transformers
 from transformers import GenerationConfig
 
 from ..globals import Global
-from ..models import get_model_with_lora, get_tokenizer, get_device
+from ..models import get_base_model, get_model_with_lora, get_tokenizer, get_device
 from ..utils.data import (
     get_available_template_names,
     get_available_lora_model_names,
@@ -37,7 +37,7 @@ def do_inference(
     prompter = Prompter(prompt_template)
     prompt = prompter.generate_prompt(variables)
 
-    if "/" not in lora_model_name:
+    if "/" not in lora_model_name and lora_model_name != "None":
         path_of_available_lora_model = get_path_of_available_lora_model(
             lora_model_name)
         if path_of_available_lora_model:
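
Note on the hunk above: adapter names without a "/" are treated as local checkpoints and resolved to a filesystem path, so the new literal "None" choice (which also contains no "/") has to be excluded from that lookup. A minimal, self-contained sketch of the guard; get_path_of_available_lora_model is stubbed here, and the stub body and example path are illustrative assumptions, not the repo's actual code:

# Sketch of the lookup guard: local names resolve to a path, Hub IDs
# (containing "/") and the "None" sentinel pass through unchanged.
def get_path_of_available_lora_model(name):
    known = {"my-local-lora": "/data/lora_models/my-local-lora"}  # assumed example
    return known.get(name)

def resolve_lora_name(lora_model_name):
    if "/" not in lora_model_name and lora_model_name != "None":
        path = get_path_of_available_lora_model(lora_model_name)
        if path:
            return path
    return lora_model_name

print(resolve_lora_name("my-local-lora"))        # -> /data/lora_models/my-local-lora
print(resolve_lora_name("tloen/alpaca-lora-7b")) # -> unchanged Hub ID
print(resolve_lora_name("None"))                 # -> unchanged sentinel string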
@@ -50,7 +50,10 @@ def do_inference(
             yield message
             return
 
-    model = get_model_with_lora(lora_model_name)
+    if lora_model_name == "None":
+        model = get_base_model()
+    else:
+        model = get_model_with_lora(lora_model_name)
     tokenizer = get_tokenizer()
 
     inputs = tokenizer(prompt, return_tensors="pt")
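
Note on the hunk above: this is the behavioral core of the commit. The dropdown's literal string "None" (not Python's None) now selects the plain base model, while any other value loads a LoRA-augmented model as before. A runnable sketch under the assumption that get_base_model returns the cached base model and get_model_with_lora wraps it with an adapter; both loaders are stubbed here for illustration:

# Sketch of the model dispatch added in do_inference.
def get_base_model():
    return "base-model"  # stand-in for the cached base model

def get_model_with_lora(lora_model_name):
    return f"base-model + {lora_model_name}"  # stand-in for the adapter-wrapped model

def select_model(lora_model_name):
    # "None" is a UI dropdown choice, compared as a string on purpose.
    if lora_model_name == "None":
        return get_base_model()
    return get_model_with_lora(lora_model_name)

print(select_model("None"))                  # -> base-model
print(select_model("tloen/alpaca-lora-7b"))  # -> base-model + tloen/alpaca-lora-7b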
@@ -131,6 +134,7 @@ def reload_selections(current_lora_model, current_prompt_template):
 
     default_lora_models = ["tloen/alpaca-lora-7b"]
     available_lora_models = default_lora_models + get_available_lora_model_names()
+    available_lora_models = available_lora_models + ["None"]
 
     current_lora_model = current_lora_model or next(
         iter(available_lora_models), None)
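
Note on the hunk above: "None" is appended after the defaults and the locally available adapters, so next(iter(available_lora_models), None) still picks a real adapter as the initial selection. A small sketch with get_available_lora_model_names stubbed; the local adapter name is an illustrative assumption:

# Sketch of the dropdown assembly in reload_selections.
def get_available_lora_model_names():
    return ["my-local-lora"]  # stand-in for adapters found on disk

default_lora_models = ["tloen/alpaca-lora-7b"]
available_lora_models = default_lora_models + get_available_lora_model_names()
available_lora_models = available_lora_models + ["None"]
print(available_lora_models)
# -> ['tloen/alpaca-lora-7b', 'my-local-lora', 'None']

current_lora_model = None or next(iter(available_lora_models), None)
print(current_lora_model)  # -> 'tloen/alpaca-lora-7b', not the "None" entry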