wondervictor commited on
Commit
b962858
·
1 Parent(s): d8de5a4

add requirements

Browse files
Files changed (1) hide show
  1. model.py +3 -3
model.py CHANGED
@@ -59,7 +59,7 @@ class Model:
59
  map_location="cpu")
60
  vq_model.load_state_dict(checkpoint["model"])
61
  del checkpoint
62
- print(f"image tokenizer is loaded")
63
  return vq_model
64
 
65
  def load_gpt(self, condition_type='canny'):
@@ -76,14 +76,14 @@ class Model:
76
  model_weight = load_file(gpt_ckpt)
77
  gpt_model.load_state_dict(model_weight, strict=False)
78
  gpt_model.eval()
79
- print(f"gpt model is loaded")
80
  return gpt_model
81
 
82
  def load_t5(self):
83
  precision = torch.bfloat16
84
  t5_model = T5Embedder(
85
  device=self.device,
86
- local_cache=True,
87
  cache_dir='checkpoints/flan-t5-xl',
88
  dir_or_name='flan-t5-xl',
89
  torch_dtype=precision,
 
59
  map_location="cpu")
60
  vq_model.load_state_dict(checkpoint["model"])
61
  del checkpoint
62
+ print("image tokenizer is loaded")
63
  return vq_model
64
 
65
  def load_gpt(self, condition_type='canny'):
 
76
  model_weight = load_file(gpt_ckpt)
77
  gpt_model.load_state_dict(model_weight, strict=False)
78
  gpt_model.eval()
79
+ print("gpt model is loaded")
80
  return gpt_model
81
 
82
  def load_t5(self):
83
  precision = torch.bfloat16
84
  t5_model = T5Embedder(
85
  device=self.device,
86
+ local_cache=False,
87
  cache_dir='checkpoints/flan-t5-xl',
88
  dir_or_name='flan-t5-xl',
89
  torch_dtype=precision,