masanorihirano committed on
Commit
4f91616
1 Parent(s): df97cfa

enable 8bit

Browse files
Files changed (1) hide show
  1. app.py +9 -6
app.py CHANGED
@@ -59,31 +59,34 @@ except Exception:
59
  if device == "cuda":
60
  model = AutoModelForCausalLM.from_pretrained(
61
  BASE_MODEL,
62
- load_in_8bit=False,
63
- torch_dtype=torch.float16,
64
  device_map="auto",
65
  )
66
- model = PeftModel.from_pretrained(model, LORA_WEIGHTS, torch_dtype=torch.float16)
67
  elif device == "mps":
68
  model = AutoModelForCausalLM.from_pretrained(
69
  BASE_MODEL,
70
  device_map={"": device},
71
- torch_dtype=torch.float16,
72
  )
73
  model = PeftModel.from_pretrained(
74
  model,
75
  LORA_WEIGHTS,
76
  device_map={"": device},
77
- torch_dtype=torch.float16,
78
  )
79
  else:
80
  model = AutoModelForCausalLM.from_pretrained(
81
- BASE_MODEL, device_map={"": device}, low_cpu_mem_usage=True
 
 
 
82
  )
83
  model = PeftModel.from_pretrained(
84
  model,
85
  LORA_WEIGHTS,
86
  device_map={"": device},
 
87
  )
88
 
89
 
 
59
  if device == "cuda":
60
  model = AutoModelForCausalLM.from_pretrained(
61
  BASE_MODEL,
62
+ load_in_8bit=True,
 
63
  device_map="auto",
64
  )
65
+ model = PeftModel.from_pretrained(model, LORA_WEIGHTS, load_in_8bit=True)
66
  elif device == "mps":
67
  model = AutoModelForCausalLM.from_pretrained(
68
  BASE_MODEL,
69
  device_map={"": device},
70
+ load_in_8bit=True,
71
  )
72
  model = PeftModel.from_pretrained(
73
  model,
74
  LORA_WEIGHTS,
75
  device_map={"": device},
76
+ load_in_8bit=True,
77
  )
78
  else:
79
  model = AutoModelForCausalLM.from_pretrained(
80
+ BASE_MODEL,
81
+ device_map={"": device},
82
+ low_cpu_mem_usage=True,
83
+ load_in_8bit=True,
84
  )
85
  model = PeftModel.from_pretrained(
86
  model,
87
  LORA_WEIGHTS,
88
  device_map={"": device},
89
+ load_in_8bit=True,
90
  )
91
 
92