Spaces: Running on Zero
wondervictor committed: Update model.py
model.py CHANGED
@@ -153,29 +153,29 @@ class Model:
         qzshape = [len(c_indices), 8, H // 16, W // 16]
         t1 = time.time()
         print(caption_embs.device)
-        index_sample = generate(
-            self.gpt_model,
-            c_indices,
-            (H // 16) * (W // 16),
-            c_emb_masks,
-            condition=condition_img,
-            cfg_scale=cfg_scale,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-            sample_logits=True,
-            control_strength=control_strength,
-        )
-        sampling_time = time.time() - t1
-        print(f"Full sampling takes about {sampling_time:.2f} seconds.")
-
-        t2 = time.time()
-        print(index_sample.shape)
-        samples = self.vq_model.decode_code(
-            index_sample, qzshape)  # output value is between [-1, 1]
-        decoder_time = time.time() - t2
-        print(f"decoder takes about {decoder_time:.2f} seconds.")
-
+        # index_sample = generate(
+        #     self.gpt_model,
+        #     c_indices,
+        #     (H // 16) * (W // 16),
+        #     c_emb_masks,
+        #     condition=condition_img,
+        #     cfg_scale=cfg_scale,
+        #     temperature=temperature,
+        #     top_k=top_k,
+        #     top_p=top_p,
+        #     sample_logits=True,
+        #     control_strength=control_strength,
+        # )
+        # sampling_time = time.time() - t1
+        # print(f"Full sampling takes about {sampling_time:.2f} seconds.")
+
+        # t2 = time.time()
+        # print(index_sample.shape)
+        # samples = self.vq_model.decode_code(
+        #     index_sample, qzshape)  # output value is between [-1, 1]
+        # decoder_time = time.time() - t2
+        # print(f"decoder takes about {decoder_time:.2f} seconds.")
+        samples = condition_img[0:1]
         samples = torch.cat((condition_img[0:1], samples), dim=0)
         samples = 255 * (samples * 0.5 + 0.5)
         samples = [
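For reference, the post-processing kept at the end of this hunk rescales decoder output from [-1, 1] to [0, 255]. A minimal sketch of that mapping, with illustrative tensor values (the real samples come from vq_model.decode_code):

import torch

# Illustrative decoder output in [-1, 1]; values chosen to show the endpoints and midpoint.
decoded = torch.tensor([[-1.0, 0.0, 1.0]])

# Same rescaling as the kept line `samples = 255 * (samples * 0.5 + 0.5)`.
pixels = 255 * (decoded * 0.5 + 0.5)
print(pixels)  # tensor([[  0.0000, 127.5000, 255.0000]])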
@@ -247,29 +247,31 @@ class Model:
         c_emb_masks = new_emb_masks
         qzshape = [len(c_indices), 8, H // 16, W // 16]
         t1 = time.time()
-        index_sample = generate(
-            self.gpt_model,
-            c_indices,
-            (H // 16) * (W // 16),
-            c_emb_masks,
-            condition=condition_img,
-            cfg_scale=cfg_scale,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-            sample_logits=True,
-            control_strength=control_strength,
-        )
-        sampling_time = time.time() - t1
-        print(f"Full sampling takes about {sampling_time:.2f} seconds.")
-
-        t2 = time.time()
-        print(index_sample.shape)
-        samples = self.vq_model.decode_code(index_sample, qzshape)
-        decoder_time = time.time() - t2
-        print(f"decoder takes about {decoder_time:.2f} seconds.")
-        condition_img = condition_img.cpu()
-        samples = samples.cpu()
+        # index_sample = generate(
+        #     self.gpt_model,
+        #     c_indices,
+        #     (H // 16) * (W // 16),
+        #     c_emb_masks,
+        #     condition=condition_img,
+        #     cfg_scale=cfg_scale,
+        #     temperature=temperature,
+        #     top_k=top_k,
+        #     top_p=top_p,
+        #     sample_logits=True,
+        #     control_strength=control_strength,
+        # )
+        # sampling_time = time.time() - t1
+        # print(f"Full sampling takes about {sampling_time:.2f} seconds.")
+
+        # t2 = time.time()
+        # print(index_sample.shape)
+        # samples = self.vq_model.decode_code(index_sample, qzshape)
+        # decoder_time = time.time() - t2
+        # print(f"decoder takes about {decoder_time:.2f} seconds.")
+        # condition_img = condition_img.cpu()
+        # samples = samples.cpu()
+
+        samples = condition_img[0:1]
         samples = torch.cat((condition_img[0:1], samples), dim=0)
         samples = 255 * (samples * 0.5 + 0.5)
         samples = [
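With the sampling and decoding calls commented out in both hunks, `samples` falls back to the condition image, so the following torch.cat yields a two-image batch containing the condition image twice. A minimal sketch, assuming a 1x3x512x512 condition tensor (the shape is illustrative; the real tensor comes from the model's preprocessing):

import torch

# Hypothetical stand-in for condition_img.
condition_img = torch.zeros(1, 3, 512, 512)

samples = condition_img[0:1]                               # fallback introduced by this commit
samples = torch.cat((condition_img[0:1], samples), dim=0)  # batch of 2 identical images
print(samples.shape)  # torch.Size([2, 3, 512, 512])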
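Both hunks keep `qzshape = [len(c_indices), 8, H // 16, W // 16]` and pass `(H // 16) * (W // 16)` as the token count to generate, i.e. the VQ latent grid is downsampled 16x in each dimension with 8 latent channels. A quick sketch of that arithmetic, assuming a 512x512 input purely for illustration:

H, W = 512, 512                       # assumed resolution, not from the repo
latent_h, latent_w = H // 16, W // 16
num_tokens = latent_h * latent_w      # length that would be passed to generate(): 32 * 32 = 1024
qzshape = [1, 8, latent_h, latent_w]  # batch of 1, 8 latent channels, 32 x 32 grid
print(num_tokens, qzshape)            # 1024 [1, 8, 32, 32]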