Space: Running on Zero

Commit 0e7e92c · wondervictor committed "update" · 1 parent: b62a9c0

Files changed:
- app.py (+2, -2)
- autoregressive/models/generate.py (+1, -1)
- model.py (+9, -6)
app.py CHANGED

@@ -54,8 +54,8 @@ hf_hub_download(repo_id="facebook/dinov2-small", filename="pytorch_model.bin", l
 DESCRIPTION = "# [ControlAR: Controllable Image Generation with Autoregressive Models](https://arxiv.org/abs/2410.02705) \n ### The first row in outputs is the input image and condition. The second row is the images generated by ControlAR. \n ### You can run locally by following the instruction on our [Github Repo](https://github.com/hustvl/ControlAR)."
 SHOW_DUPLICATE_BUTTON = os.getenv("SHOW_DUPLICATE_BUTTON") == "1"
 model = Model()
-device = "cuda"
-model.to(device)
+# device = "cuda"
+# model.to(device)
 with gr.Blocks(css="style.css") as demo:
     gr.Markdown(DESCRIPTION)
     gr.DuplicateButton(
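The app.py change stops placing the model on CUDA at import time. On a ZeroGPU Space ("Running on Zero") no GPU is attached while the module loads; a device exists only inside a handler decorated with spaces.GPU, so placement has to happen there. Below is a minimal sketch of that pattern, assuming the spaces package that ZeroGPU Spaces provide; TinyModel and its run method are illustrative stand-ins, not this Space's real code.

# Minimal sketch of the ZeroGPU placement pattern. Assumption: the `spaces`
# package available on ZeroGPU Spaces; TinyModel is an illustrative stand-in.
import gradio as gr
import spaces
import torch

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)

    def run(self, prompt: str) -> str:
        x = torch.randn(1, 4, device=self.proj.weight.device)
        return f"{prompt}: {self.proj(x).sum().item():.3f}"

model = TinyModel()  # built on CPU at import time; no GPU is attached yet

@spaces.GPU  # ZeroGPU attaches a GPU only for the duration of this call
def generate(prompt: str) -> str:
    model.to("cuda")  # safe here: CUDA exists inside the decorated handler
    return model.run(prompt)

with gr.Blocks() as demo:
    box = gr.Textbox(label="Prompt")
    out = gr.Textbox(label="Output")
    box.submit(generate, inputs=box, outputs=out)

demo.launch()

Outside the ZeroGPU runtime the spaces import is unavailable, so this sketch applies only there.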
autoregressive/models/generate.py CHANGED

@@ -145,7 +145,7 @@ def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_int
     print(condition)
     condition = torch.ones_like(condition)
     condition = model.adapter_mlp(condition)
-    print(condition)
+    #print(condition)
     if model.model_type == 'c2i':
        if cfg_scale > 1.0:
            cond_null = torch.ones_like(cond) * model.num_classes
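This hunk only silences a debug print of the condition tensor. If that trace is still occasionally wanted, one alternative (a sketch, not what the repository does) is to route it through the standard logging module so it can be re-enabled with an environment variable instead of by uncommenting code. The logger name and the LOGLEVEL variable are illustrative choices.

# Sketch: a toggleable tensor trace via `logging` instead of a commented-out
# print. Logger name and LOGLEVEL are illustrative, not from the repo.
import logging
import os

import torch

logging.basicConfig(level=os.getenv("LOGLEVEL", "WARNING"))
logger = logging.getLogger("controlar.generate")

def trace_tensor(name: str, tensor: torch.Tensor) -> None:
    # Log a compact summary rather than dumping the whole tensor.
    logger.debug("%s: shape=%s dtype=%s", name, tuple(tensor.shape), tensor.dtype)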
model.py CHANGED

@@ -44,7 +44,7 @@ class Model:

     def __init__(self):
         self.device = torch.device(
-            "cuda
+            "cuda")
         self.base_model_id = ""
         self.task_name = ""
         self.vq_model = self.load_vq()

@@ -63,7 +63,7 @@ class Model:
     def load_vq(self):
         vq_model = VQ_models["VQ-16"](codebook_size=16384,
                                       codebook_embed_dim=8)
-        vq_model.to('cuda')
+        # vq_model.to('cuda')
         vq_model.eval()
         checkpoint = torch.load(f"checkpoints/vq_ds16_t2i.pt",
                                 map_location="cpu")

@@ -82,11 +82,13 @@ class Model:
             cls_token_num=120,
             model_type='t2i',
             condition_type=condition_type,
-        ).to(device='
+        ).to(device='cpu', dtype=precision)

         model_weight = load_file(gpt_ckpt)
-        #
+        # print("prev:", model_weight['adapter.model.embeddings.patch_embeddings.projection.weight'])
+        gpt_model.load_state_dict(model_weight, strict=True)
         gpt_model.eval()
+        print("loaded:", gpt_model.adapter.model.embeddings.patch_embeddings.projection.weight)
         print("gpt model is loaded")
         return gpt_model

@@ -121,8 +123,9 @@ class Model:
         image = resize_image_to_16_multiple(image, 'canny')
         W, H = image.size
         print(W, H)
-        self.
-        self.
+        print("before cuda", self.gpt_model_canny.adapter.model.embeddings.patch_embeddings.projection.weight)
+        self.t5_model.model.to('cuda')
+        self.gpt_model_canny.to('cuda')

         condition_img = self.get_control_canny(np.array(image), low_threshold,
                                                high_threshold)
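Taken together, the model.py hunks follow one pattern: construct and load every checkpoint on CPU (map_location="cpu", .to(device='cpu', ...)), verify the weights actually loaded (strict=True plus the before/after weight prints), and defer .to('cuda') to the request path, where ZeroGPU has attached a device. A condensed sketch of that load-on-CPU, move-on-request pattern follows; DemoNet, load_model, and ensure_on are hypothetical names, not ControlAR's API.

# Condensed sketch of the load-on-CPU, move-on-request pattern used above.
# DemoNet, load_model, and ensure_on are illustrative, not ControlAR API.
from typing import Optional

import torch

class DemoNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

def load_model(ckpt_path: Optional[str] = None) -> torch.nn.Module:
    model = DemoNet()  # constructed on CPU; no .to('cuda') at load time
    if ckpt_path is not None:
        state = torch.load(ckpt_path, map_location="cpu")  # weights stay on CPU
        model.load_state_dict(state, strict=True)  # fail loudly on any mismatch
    model.eval()
    return model

def ensure_on(model: torch.nn.Module, device: str) -> torch.nn.Module:
    # Called from the request handler, where the GPU actually exists.
    return model.to(device)

model = load_model()
# ...later, inside the GPU-scoped request handler:
model = ensure_on(model, "cuda" if torch.cuda.is_available() else "cpu")

Moving a module to the GPU per request costs a host-to-device copy each call, but on ZeroGPU that is the trade required for sharing hardware across Spaces.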