tttoaster commited on
Commit
06429f7
1 Parent(s): cd7d79e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -15
app.py CHANGED
@@ -159,19 +159,27 @@ class LLMService:
159
  vae = AutoencoderKL.from_pretrained(args.diffusion_path, subfolder="vae").to(self.vit_sd_device, dtype=self.dtype)
160
 
161
 
162
- unet = UNet2DConditionModel.from_pretrained(args.diffusion_path, subfolder="unet").to(dtype=self.dtype)
163
 
164
  sd_adapter_cfg = OmegaConf.load(args.sd_adapter)
165
 
166
- self.sd_adapter = hydra.utils.instantiate(sd_adapter_cfg, unet=unet).eval().to(dtype=self.dtype)
 
 
 
 
 
 
 
 
167
 
168
  self.sd_adapter.init_pipe(vae=vae,
169
  scheduler=noise_scheduler,
170
- visual_encoder=self.visual_encoder.cpu(),
171
  image_transform=self.image_transform,
172
  discrete_model=None,
173
  dtype=self.dtype,
174
- device="cpu")
175
 
176
  print('Init sd adapter pipe done.')
177
 
@@ -336,11 +344,11 @@ def generate(text_list, image_list, max_new_tokens, force_boi, force_bbox):
336
  generated_text = generated_text.replace(EOI_TOKEN, IMG_FLAG).replace(service.tokenizer.eos_token, '')
337
 
338
  if output['has_img_output']:
339
- print('loading visual encoder and llm to CPU, and sd to GPU')
340
- a = time.time()
341
- service.agent = service.agent.cpu()
342
- service.sd_adapter = service.sd_adapter.to(service.vit_sd_device, dtype=service.dtype)
343
- print("Loading finished: ", time.time() - a)
344
 
345
  img_gen_feat = output['img_gen_feat'].to(service.vit_sd_device, dtype=service.dtype)
346
 
@@ -350,12 +358,12 @@ def generate(text_list, image_list, max_new_tokens, force_boi, force_bbox):
350
  image_base64 = encode_image(generated_image)
351
  gen_imgs_base64_list.append(image_base64)
352
 
353
- print('loading visual encoder and llm to GPU, and sd to CPU')
354
- a = time.time()
355
- service.sd_adapter = service.sd_adapter.cpu()
356
- service.visual_encoder = service.visual_encoder.to(service.vit_sd_device, dtype=service.dtype)
357
- service.agent = service.agent.to(service.vit_sd_device, dtype=service.dtype)
358
- print("Loading finished: ", time.time() - a)
359
 
360
  if args.has_bbox:
361
  bboxes = extract_box(generated_text)
 
159
  vae = AutoencoderKL.from_pretrained(args.diffusion_path, subfolder="vae").to(self.vit_sd_device, dtype=self.dtype)
160
 
161
 
162
+ unet = UNet2DConditionModel.from_pretrained(args.diffusion_path, subfolder="unet").to(self.vit_sd_device, dtype=self.dtype)
163
 
164
  sd_adapter_cfg = OmegaConf.load(args.sd_adapter)
165
 
166
+ self.sd_adapter = hydra.utils.instantiate(sd_adapter_cfg, unet=unet).eval().to(self.vit_sd_device, dtype=self.dtype)
167
+
168
+ # self.sd_adapter.init_pipe(vae=vae,
169
+ # scheduler=noise_scheduler,
170
+ # visual_encoder=self.visual_encoder.cpu(),
171
+ # image_transform=self.image_transform,
172
+ # discrete_model=None,
173
+ # dtype=self.dtype,
174
+ # device="cpu")
175
 
176
  self.sd_adapter.init_pipe(vae=vae,
177
  scheduler=noise_scheduler,
178
+ visual_encoder=self.visual_encoder,
179
  image_transform=self.image_transform,
180
  discrete_model=None,
181
  dtype=self.dtype,
182
+ device=self.vit_sd_device)
183
 
184
  print('Init sd adapter pipe done.')
185
 
 
344
  generated_text = generated_text.replace(EOI_TOKEN, IMG_FLAG).replace(service.tokenizer.eos_token, '')
345
 
346
  if output['has_img_output']:
347
+ # print('loading visual encoder and llm to CPU, and sd to GPU')
348
+ # a = time.time()
349
+ # service.agent = service.agent.cpu()
350
+ # service.sd_adapter = service.sd_adapter.to(service.vit_sd_device, dtype=service.dtype)
351
+ # print("Loading finished: ", time.time() - a)
352
 
353
  img_gen_feat = output['img_gen_feat'].to(service.vit_sd_device, dtype=service.dtype)
354
 
 
358
  image_base64 = encode_image(generated_image)
359
  gen_imgs_base64_list.append(image_base64)
360
 
361
+ # print('loading visual encoder and llm to GPU, and sd to CPU')
362
+ # a = time.time()
363
+ # service.sd_adapter = service.sd_adapter.cpu()
364
+ # service.visual_encoder = service.visual_encoder.to(service.vit_sd_device, dtype=service.dtype)
365
+ # service.agent = service.agent.to(service.vit_sd_device, dtype=service.dtype)
366
+ # print("Loading finished: ", time.time() - a)
367
 
368
  if args.has_bbox:
369
  bboxes = extract_box(generated_text)