wondervictor committed on
Commit
0e7e92c
·
1 Parent(s): b62a9c0
Files changed (3) hide show
  1. app.py +2 -2
  2. autoregressive/models/generate.py +1 -1
  3. model.py +9 -6
app.py CHANGED
@@ -54,8 +54,8 @@ hf_hub_download(repo_id="facebook/dinov2-small", filename="pytorch_model.bin", l
54
  DESCRIPTION = "# [ControlAR: Controllable Image Generation with Autoregressive Models](https://arxiv.org/abs/2410.02705) \n ### The first row in outputs is the input image and condition. The second row is the images generated by ControlAR. \n ### You can run locally by following the instruction on our [Github Repo](https://github.com/hustvl/ControlAR)."
55
  SHOW_DUPLICATE_BUTTON = os.getenv("SHOW_DUPLICATE_BUTTON") == "1"
56
  model = Model()
57
- device = "cuda"
58
- model.to(device)
59
  with gr.Blocks(css="style.css") as demo:
60
  gr.Markdown(DESCRIPTION)
61
  gr.DuplicateButton(
 
54
  DESCRIPTION = "# [ControlAR: Controllable Image Generation with Autoregressive Models](https://arxiv.org/abs/2410.02705) \n ### The first row in outputs is the input image and condition. The second row is the images generated by ControlAR. \n ### You can run locally by following the instruction on our [Github Repo](https://github.com/hustvl/ControlAR)."
55
  SHOW_DUPLICATE_BUTTON = os.getenv("SHOW_DUPLICATE_BUTTON") == "1"
56
  model = Model()
57
+ # device = "cuda"
58
+ # model.to(device)
59
  with gr.Blocks(css="style.css") as demo:
60
  gr.Markdown(DESCRIPTION)
61
  gr.DuplicateButton(
autoregressive/models/generate.py CHANGED
@@ -145,7 +145,7 @@ def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_int
145
  print(condition)
146
  condition = torch.ones_like(condition)
147
  condition = model.adapter_mlp(condition)
148
- print(condition)
149
  if model.model_type == 'c2i':
150
  if cfg_scale > 1.0:
151
  cond_null = torch.ones_like(cond) * model.num_classes
 
145
  print(condition)
146
  condition = torch.ones_like(condition)
147
  condition = model.adapter_mlp(condition)
148
+ #print(condition)
149
  if model.model_type == 'c2i':
150
  if cfg_scale > 1.0:
151
  cond_null = torch.ones_like(cond) * model.num_classes
model.py CHANGED
@@ -44,7 +44,7 @@ class Model:
44
 
45
  def __init__(self):
46
  self.device = torch.device(
47
- "cuda:0")
48
  self.base_model_id = ""
49
  self.task_name = ""
50
  self.vq_model = self.load_vq()
@@ -63,7 +63,7 @@ class Model:
63
  def load_vq(self):
64
  vq_model = VQ_models["VQ-16"](codebook_size=16384,
65
  codebook_embed_dim=8)
66
- vq_model.to('cuda')
67
  vq_model.eval()
68
  checkpoint = torch.load(f"checkpoints/vq_ds16_t2i.pt",
69
  map_location="cpu")
@@ -82,11 +82,13 @@ class Model:
82
  cls_token_num=120,
83
  model_type='t2i',
84
  condition_type=condition_type,
85
- ).to(device='cuda', dtype=precision)
86
 
87
  model_weight = load_file(gpt_ckpt)
88
- # gpt_model.load_state_dict(model_weight, strict=True)
 
89
  gpt_model.eval()
 
90
  print("gpt model is loaded")
91
  return gpt_model
92
 
@@ -121,8 +123,9 @@ class Model:
121
  image = resize_image_to_16_multiple(image, 'canny')
122
  W, H = image.size
123
  print(W, H)
124
- self.t5_model.model.to(self.device)
125
- self.gpt_model_canny.to(self.device)
 
126
 
127
  condition_img = self.get_control_canny(np.array(image), low_threshold,
128
  high_threshold)
 
44
 
45
  def __init__(self):
46
  self.device = torch.device(
47
+ "cuda")
48
  self.base_model_id = ""
49
  self.task_name = ""
50
  self.vq_model = self.load_vq()
 
63
  def load_vq(self):
64
  vq_model = VQ_models["VQ-16"](codebook_size=16384,
65
  codebook_embed_dim=8)
66
+ # vq_model.to('cuda')
67
  vq_model.eval()
68
  checkpoint = torch.load(f"checkpoints/vq_ds16_t2i.pt",
69
  map_location="cpu")
 
82
  cls_token_num=120,
83
  model_type='t2i',
84
  condition_type=condition_type,
85
+ ).to(device='cpu', dtype=precision)
86
 
87
  model_weight = load_file(gpt_ckpt)
88
+ # print("prev:", model_weight['adapter.model.embeddings.patch_embeddings.projection.weight'])
89
+ gpt_model.load_state_dict(model_weight, strict=True)
90
  gpt_model.eval()
91
+ print("loaded:", gpt_model.adapter.model.embeddings.patch_embeddings.projection.weight)
92
  print("gpt model is loaded")
93
  return gpt_model
94
 
 
123
  image = resize_image_to_16_multiple(image, 'canny')
124
  W, H = image.size
125
  print(W, H)
126
+ print("before cuda", self.gpt_model_canny.adapter.model.embeddings.patch_embeddings.projection.weight)
127
+ self.t5_model.model.to('cuda')
128
+ self.gpt_model_canny.to('cuda')
129
 
130
  condition_img = self.get_control_canny(np.array(image), low_threshold,
131
  high_threshold)