wondervictor commited on
Commit
340cc7b
·
verified ·
1 Parent(s): 8b9ace1

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +46 -44
model.py CHANGED
@@ -153,29 +153,29 @@ class Model:
153
  qzshape = [len(c_indices), 8, H // 16, W // 16]
154
  t1 = time.time()
155
  print(caption_embs.device)
156
- index_sample = generate(
157
- self.gpt_model,
158
- c_indices,
159
- (H // 16) * (W // 16),
160
- c_emb_masks,
161
- condition=condition_img,
162
- cfg_scale=cfg_scale,
163
- temperature=temperature,
164
- top_k=top_k,
165
- top_p=top_p,
166
- sample_logits=True,
167
- control_strength=control_strength,
168
- )
169
- sampling_time = time.time() - t1
170
- print(f"Full sampling takes about {sampling_time:.2f} seconds.")
171
-
172
- t2 = time.time()
173
- print(index_sample.shape)
174
- samples = self.vq_model.decode_code(
175
- index_sample, qzshape) # output value is between [-1, 1]
176
- decoder_time = time.time() - t2
177
- print(f"decoder takes about {decoder_time:.2f} seconds.")
178
 
 
 
 
 
 
 
 
179
  samples = torch.cat((condition_img[0:1], samples), dim=0)
180
  samples = 255 * (samples * 0.5 + 0.5)
181
  samples = [
@@ -247,29 +247,31 @@ class Model:
247
  c_emb_masks = new_emb_masks
248
  qzshape = [len(c_indices), 8, H // 16, W // 16]
249
  t1 = time.time()
250
- index_sample = generate(
251
- self.gpt_model,
252
- c_indices,
253
- (H // 16) * (W // 16),
254
- c_emb_masks,
255
- condition=condition_img,
256
- cfg_scale=cfg_scale,
257
- temperature=temperature,
258
- top_k=top_k,
259
- top_p=top_p,
260
- sample_logits=True,
261
- control_strength=control_strength,
262
- )
263
- sampling_time = time.time() - t1
264
- print(f"Full sampling takes about {sampling_time:.2f} seconds.")
 
 
 
 
 
 
 
 
265
 
266
- t2 = time.time()
267
- print(index_sample.shape)
268
- samples = self.vq_model.decode_code(index_sample, qzshape)
269
- decoder_time = time.time() - t2
270
- print(f"decoder takes about {decoder_time:.2f} seconds.")
271
- condition_img = condition_img.cpu()
272
- samples = samples.cpu()
273
  samples = torch.cat((condition_img[0:1], samples), dim=0)
274
  samples = 255 * (samples * 0.5 + 0.5)
275
  samples = [
 
153
  qzshape = [len(c_indices), 8, H // 16, W // 16]
154
  t1 = time.time()
155
  print(caption_embs.device)
156
+ # index_sample = generate(
157
+ # self.gpt_model,
158
+ # c_indices,
159
+ # (H // 16) * (W // 16),
160
+ # c_emb_masks,
161
+ # condition=condition_img,
162
+ # cfg_scale=cfg_scale,
163
+ # temperature=temperature,
164
+ # top_k=top_k,
165
+ # top_p=top_p,
166
+ # sample_logits=True,
167
+ # control_strength=control_strength,
168
+ # )
169
+ # sampling_time = time.time() - t1
170
+ # print(f"Full sampling takes about {sampling_time:.2f} seconds.")
 
 
 
 
 
 
 
171
 
172
+ # t2 = time.time()
173
+ # print(index_sample.shape)
174
+ # samples = self.vq_model.decode_code(
175
+ # index_sample, qzshape) # output value is between [-1, 1]
176
+ # decoder_time = time.time() - t2
177
+ # print(f"decoder takes about {decoder_time:.2f} seconds.")
178
+ samples = condition_img[0:1]
179
  samples = torch.cat((condition_img[0:1], samples), dim=0)
180
  samples = 255 * (samples * 0.5 + 0.5)
181
  samples = [
 
247
  c_emb_masks = new_emb_masks
248
  qzshape = [len(c_indices), 8, H // 16, W // 16]
249
  t1 = time.time()
250
+ # index_sample = generate(
251
+ # self.gpt_model,
252
+ # c_indices,
253
+ # (H // 16) * (W // 16),
254
+ # c_emb_masks,
255
+ # condition=condition_img,
256
+ # cfg_scale=cfg_scale,
257
+ # temperature=temperature,
258
+ # top_k=top_k,
259
+ # top_p=top_p,
260
+ # sample_logits=True,
261
+ # control_strength=control_strength,
262
+ # )
263
+ # sampling_time = time.time() - t1
264
+ # print(f"Full sampling takes about {sampling_time:.2f} seconds.")
265
+
266
+ # t2 = time.time()
267
+ # print(index_sample.shape)
268
+ # samples = self.vq_model.decode_code(index_sample, qzshape)
269
+ # decoder_time = time.time() - t2
270
+ # print(f"decoder takes about {decoder_time:.2f} seconds.")
271
+ # condition_img = condition_img.cpu()
272
+ # samples = samples.cpu()
273
 
274
+ samples = condition_img[0:1]
 
 
 
 
 
 
275
  samples = torch.cat((condition_img[0:1], samples), dim=0)
276
  samples = 255 * (samples * 0.5 + 0.5)
277
  samples = [