Spaces:
Running
on
Zero
Running
on
Zero
wondervictor
commited on
Commit
·
b962858
1
Parent(s):
d8de5a4
add requirements
Browse files
model.py
CHANGED
@@ -59,7 +59,7 @@ class Model:
|
|
59 |
map_location="cpu")
|
60 |
vq_model.load_state_dict(checkpoint["model"])
|
61 |
del checkpoint
|
62 |
-
print(
|
63 |
return vq_model
|
64 |
|
65 |
def load_gpt(self, condition_type='canny'):
|
@@ -76,14 +76,14 @@ class Model:
|
|
76 |
model_weight = load_file(gpt_ckpt)
|
77 |
gpt_model.load_state_dict(model_weight, strict=False)
|
78 |
gpt_model.eval()
|
79 |
-
print(
|
80 |
return gpt_model
|
81 |
|
82 |
def load_t5(self):
|
83 |
precision = torch.bfloat16
|
84 |
t5_model = T5Embedder(
|
85 |
device=self.device,
|
86 |
-
local_cache=
|
87 |
cache_dir='checkpoints/flan-t5-xl',
|
88 |
dir_or_name='flan-t5-xl',
|
89 |
torch_dtype=precision,
|
|
|
59 |
map_location="cpu")
|
60 |
vq_model.load_state_dict(checkpoint["model"])
|
61 |
del checkpoint
|
62 |
+
print("image tokenizer is loaded")
|
63 |
return vq_model
|
64 |
|
65 |
def load_gpt(self, condition_type='canny'):
|
|
|
76 |
model_weight = load_file(gpt_ckpt)
|
77 |
gpt_model.load_state_dict(model_weight, strict=False)
|
78 |
gpt_model.eval()
|
79 |
+
print("gpt model is loaded")
|
80 |
return gpt_model
|
81 |
|
82 |
def load_t5(self):
|
83 |
precision = torch.bfloat16
|
84 |
t5_model = T5Embedder(
|
85 |
device=self.device,
|
86 |
+
local_cache=False,
|
87 |
cache_dir='checkpoints/flan-t5-xl',
|
88 |
dir_or_name='flan-t5-xl',
|
89 |
torch_dtype=precision,
|