masanorihirano committed on
Commit
0d4eedd
1 Parent(s): d91928f
Files changed (1)
  1. app.py +14 -5
app.py CHANGED
@@ -11,8 +11,9 @@ from fastchat.serve.inference import compress_module
 from fastchat.serve.inference import raise_warning_for_old_weights
 from huggingface_hub import Repository
 from huggingface_hub import hf_hub_download
+from huggingface_hub import snapshot_download
 from peft import LoraConfig
-from peft import PeftModel
+from peft import get_peft_model
 from peft import set_peft_model_state_dict
 from transformers import AutoModelForCausalLM
 from transformers import GenerationConfig
@@ -63,7 +64,12 @@ try:
 except Exception:
     pass
 
-checkpoint_name = hf_hub_download(repo_id=LORA_WEIGHTS, filename="adapter_model.bin", use_auth_token=HF_TOKEN)
+resume_from_checkpoint = snapshot_download(
+    repo_id=LORA_WEIGHTS, use_auth_token=HF_TOKEN
+)
+checkpoint_name = hf_hub_download(
+    repo_id=LORA_WEIGHTS, filename="adapter_model.bin", use_auth_token=HF_TOKEN
+)
 if device == "cuda":
     model = AutoModelForCausalLM.from_pretrained(
         BASE_MODEL, load_in_8bit=True, device_map="auto", torch_dtype=torch.float16
@@ -83,12 +89,15 @@ else:
         low_cpu_mem_usage=True,
         torch_dtype=torch.float16,
     )
+
+config = LoraConfig.from_pretrained(resume_from_checkpoint)
+model = get_peft_model(model, config)
 adapters_weights = torch.load(checkpoint_name)
 set_peft_model_state_dict(model, adapters_weights)
 raise_warning_for_old_weights(BASE_MODEL, model)
 compress_module(model, device)
-if device == "cuda" or device == "mps":
-    model = model.to(device)
+# if device == "cuda" or device == "mps":
+#     model = model.to(device)
 
 
 def generate_prompt(instruction: str, input: Optional[str] = None):
@@ -308,5 +317,5 @@ with gr.Blocks(
     clear_button.click(reset_textbox, [], [instruction, inputs, outputs], queue=False)
 
 demo.queue(max_size=20, concurrency_count=NUM_THREADS, api_open=False).launch(
-    server_name="0.0.0.0", server_port=7860
+    share=True, server_name="0.0.0.0", server_port=7860
 )
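
The core of this commit is the adapter-loading path: instead of restoring adapter_model.bin directly into the bare base model, app.py now downloads the whole adapter repo with snapshot_download, builds a LoraConfig from it, wraps the base model with get_peft_model, and only then restores the adapter weights. Below is a minimal standalone sketch of that flow; BASE_MODEL, LORA_WEIGHTS, and HF_TOKEN mirror the variables used in app.py, but the values shown here are placeholders, not the Space's actual configuration, and the adapter repo is assumed to contain the usual peft layout (adapter_config.json plus adapter_model.bin).

# Sketch of the adapter-loading flow introduced by this commit (placeholder repo IDs).
import torch
from huggingface_hub import hf_hub_download, snapshot_download
from peft import LoraConfig, get_peft_model, set_peft_model_state_dict
from transformers import AutoModelForCausalLM

BASE_MODEL = "base-org/base-model"      # placeholder, set in app.py
LORA_WEIGHTS = "user/lora-adapter"      # placeholder, set in app.py
HF_TOKEN = None                         # token string if the adapter repo is private

# Fetch the adapter repo (for adapter_config.json) and the weights file itself.
resume_from_checkpoint = snapshot_download(repo_id=LORA_WEIGHTS, use_auth_token=HF_TOKEN)
checkpoint_name = hf_hub_download(
    repo_id=LORA_WEIGHTS, filename="adapter_model.bin", use_auth_token=HF_TOKEN
)

# Load the base model, wrap it with the adapter's LoRA config, then inject the weights.
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL, low_cpu_mem_usage=True, torch_dtype=torch.float16
)
config = LoraConfig.from_pretrained(resume_from_checkpoint)
model = get_peft_model(model, config)
adapters_weights = torch.load(checkpoint_name, map_location="cpu")
set_peft_model_state_dict(model, adapters_weights)

An equivalent shortcut would be PeftModel.from_pretrained(model, LORA_WEIGHTS), which handles the config download, wrapping, and weight loading in one call; this commit instead keeps the explicit steps and drops the now-unused PeftModel import.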