yeq6x commited on
Commit
d81c6ca
1 Parent(s): 2c835e6

set_lora_device

Browse files
Files changed (1) hide show
  1. scripts/process_utils.py +4 -4
scripts/process_utils.py CHANGED
@@ -50,10 +50,10 @@ def initialize(_use_local=False, use_gpu=False, use_dotenv=False):
50
  sotai_gen_pipe = initialize_sotai_model()
51
  refine_gen_pipe = initialize_refine_model()
52
 
53
- def load_lora(pipeline, lora_path, alpha=0.75):
54
- pipeline.load_lora_weights(lora_path)
55
- pipeline.fuse_lora(lora_scale=alpha)
56
- pipeline.set_lora_device(device)
57
 
58
  def initialize_sotai_model():
59
  global device, torch_dtype
 
50
  sotai_gen_pipe = initialize_sotai_model()
51
  refine_gen_pipe = initialize_refine_model()
52
 
53
+ def load_lora(pipeline, lora_path, adapter_name, alpha=0.75):
54
+ pipeline.load_lora_weights(lora_path, adapter_name)
55
+ pipeline.fuse_lora(lora_scale=alpha, adapter_names=[adapter_name])
56
+ pipeline.set_lora_device(adapter_names=[adapter_name], device=device)
57
 
58
  def initialize_sotai_model():
59
  global device, torch_dtype