zetavg commited on
Commit
f79f0d5
1 Parent(s): 94ece44
Files changed (1) hide show
  1. llama_lora/ui/inference_ui.py +7 -3
llama_lora/ui/inference_ui.py CHANGED
@@ -6,7 +6,7 @@ import transformers
6
  from transformers import GenerationConfig
7
 
8
  from ..globals import Global
9
- from ..models import get_model_with_lora, get_tokenizer, get_device
10
  from ..utils.data import (
11
  get_available_template_names,
12
  get_available_lora_model_names,
@@ -37,7 +37,7 @@ def do_inference(
37
  prompter = Prompter(prompt_template)
38
  prompt = prompter.generate_prompt(variables)
39
 
40
- if "/" not in lora_model_name:
41
  path_of_available_lora_model = get_path_of_available_lora_model(
42
  lora_model_name)
43
  if path_of_available_lora_model:
@@ -50,7 +50,10 @@ def do_inference(
50
  yield message
51
  return
52
 
53
- model = get_model_with_lora(lora_model_name)
 
 
 
54
  tokenizer = get_tokenizer()
55
 
56
  inputs = tokenizer(prompt, return_tensors="pt")
@@ -131,6 +134,7 @@ def reload_selections(current_lora_model, current_prompt_template):
131
 
132
  default_lora_models = ["tloen/alpaca-lora-7b"]
133
  available_lora_models = default_lora_models + get_available_lora_model_names()
 
134
 
135
  current_lora_model = current_lora_model or next(
136
  iter(available_lora_models), None)
 
6
  from transformers import GenerationConfig
7
 
8
  from ..globals import Global
9
+ from ..models import get_base_model, get_model_with_lora, get_tokenizer, get_device
10
  from ..utils.data import (
11
  get_available_template_names,
12
  get_available_lora_model_names,
 
37
  prompter = Prompter(prompt_template)
38
  prompt = prompter.generate_prompt(variables)
39
 
40
+ if "/" not in lora_model_name and lora_model_name != "None":
41
  path_of_available_lora_model = get_path_of_available_lora_model(
42
  lora_model_name)
43
  if path_of_available_lora_model:
 
50
  yield message
51
  return
52
 
53
+ if lora_model_name == "None":
54
+ model = get_base_model()
55
+ else:
56
+ model = get_model_with_lora(lora_model_name)
57
  tokenizer = get_tokenizer()
58
 
59
  inputs = tokenizer(prompt, return_tensors="pt")
 
134
 
135
  default_lora_models = ["tloen/alpaca-lora-7b"]
136
  available_lora_models = default_lora_models + get_available_lora_model_names()
137
+ available_lora_models = available_lora_models + ["None"]
138
 
139
  current_lora_model = current_lora_model or next(
140
  iter(available_lora_models), None)