ClownRat commited on
Commit
9b4dadd
•
1 Parent(s): ee906b7

update demo.

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -61,12 +61,13 @@ The service is a research preview intended for non-commercial use only, subject
61
 
62
 
63
  class Chat:
64
- def __init__(self, model_path, conv_mode, model_base=None, load_8bit=False, load_4bit=False):
65
  # disable_torch_init()
66
  model_name = get_model_name_from_path(model_path)
67
  self.tokenizer, self.model, processor, context_len = load_pretrained_model(
68
  model_path, model_base, model_name,
69
  load_8bit, load_4bit,
 
70
  offload_folder="save_folder")
71
  self.processor = processor
72
  self.conv_mode = conv_mode
@@ -247,7 +248,7 @@ if __name__ == '__main__':
247
 
248
  handler = Chat(model_path, conv_mode=conv_mode, load_8bit=False, load_4bit=True)
249
  # handler.model.to(dtype=torch.float16)
250
- handler = handler.model.to(device)
251
 
252
  if not os.path.exists("temp"):
253
  os.makedirs("temp")
 
61
 
62
 
63
  class Chat:
64
+ def __init__(self, model_path, conv_mode, model_base=None, load_8bit=False, load_4bit=False, device='cuda'):
65
  # disable_torch_init()
66
  model_name = get_model_name_from_path(model_path)
67
  self.tokenizer, self.model, processor, context_len = load_pretrained_model(
68
  model_path, model_base, model_name,
69
  load_8bit, load_4bit,
70
+ device=device,
71
  offload_folder="save_folder")
72
  self.processor = processor
73
  self.conv_mode = conv_mode
 
248
 
249
  handler = Chat(model_path, conv_mode=conv_mode, load_8bit=False, load_4bit=True)
250
  # handler.model.to(dtype=torch.float16)
251
+ # handler = handler.model.to(device)
252
 
253
  if not os.path.exists("temp"):
254
  os.makedirs("temp")