feat(log_inference_samples): cleanup
tools/inference/log_inference_samples.ipynb
CHANGED
@@ -100,11 +100,12 @@
  "outputs": [],
  "source": [
  "vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n",
- "clip = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch16\")\n",
- "processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch16\")\n",
- "clip_params = replicate(clip.params)\n",
  "vqgan_params = replicate(vqgan.params)\n",
  "\n",
+ "clip16 = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch16\")\n",
+ "processor16 = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch16\")\n",
+ "clip16_params = replicate(clip16.params)\n",
+ "\n",
  "if add_clip_32:\n",
  " clip32 = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
  " processor32 = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
@@ -123,8 +124,8 @@
  " return vqgan.decode_code(indices, params=params)\n",
  "\n",
  "@partial(jax.pmap, axis_name=\"batch\")\n",
- "def p_clip(inputs, params):\n",
- " logits = clip(params=params, **inputs).logits_per_image\n",
+ "def p_clip16(inputs, params):\n",
+ " logits = clip16(params=params, **inputs).logits_per_image\n",
  " return logits\n",
  "\n",
  "if add_clip_32:\n",
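Note: the renamed p_clip16 follows the usual JAX data-parallel pattern of replicating parameters across devices and sharding inputs along the leading axis. Below is a minimal runnable sketch of that pattern with a toy scoring function standing in for the real CLIP forward pass; every name and shape in it is made up for illustration.

# Minimal sketch of the replicate / shard / pmap pattern behind p_clip16.
# toy_score stands in for clip16(params=params, **inputs).logits_per_image.
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard

n_dev = jax.device_count()

params = {"w": jnp.ones((4,))}   # stand-in for clip16.params
params = replicate(params)       # one copy per device, like clip16_params

@partial(jax.pmap, axis_name="batch")
def toy_score(inputs, params):
    # stand-in for the CLIP forward pass returning one logit per image
    return inputs["pixel_values"] @ params["w"]

# two "images" per device, four features each; shard() splits the leading axis across devices
inputs = {"pixel_values": np.ones((n_dev * 2, 4), dtype=np.float32)}
logits = toy_score(shard(inputs), params)
print(logits.shape)  # (n_dev, 2)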
@@ -229,7 +230,7 @@
  "outputs": [],
  "source": [
  "run_id = run_ids[0]\n",
- "# TODO: turn everything into a class"
+ "# TODO: turn everything into a class or loop over runs"
  ]
  },
  {
@@ -248,10 +249,8 @@
  "for artifact in artifact_versions:\n",
  " print(f'Processing artifact: {artifact.name}')\n",
  " version = int(artifact.version[1:])\n",
- "
- "
- " results32 = []\n",
- " columns = ['Caption'] + [f'Image {i+1}' for i in range(top_k)] + [f'Score {i+1}' for i in range(top_k)]\n",
+ " results16, results32 = [], []\n",
+ " columns = ['Caption'] + [f'Image {i+1}' for i in range(top_k)]\n",
  " \n",
  " if latest_only:\n",
  " assert last_inference_version is None or version > last_inference_version\n",
@@ -307,34 +306,13 @@
  " for img in decoded_images:\n",
  " images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))\n",
  "\n",
- "
- "
- " clip_inputs = processor(text=batch, images=images, return_tensors='np', padding='max_length', max_length=77, truncation=True).data\n",
- " # each shard will have one prompt, images need to be reorganized to be associated to the correct shard\n",
- " images_per_prompt_indices = np.asarray(range(0, len(images), batch_size))\n",
- " clip_inputs['pixel_values'] = jnp.concatenate(list(clip_inputs['pixel_values'][images_per_prompt_indices + i] for i in range(batch_size)))\n",
- " clip_inputs = shard(clip_inputs)\n",
- " logits = p_clip(clip_inputs, clip_params)\n",
- " logits = logits.reshape(-1, num_images)\n",
- " top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
- " logits = jax.device_get(logits)\n",
- " # add to results table\n",
- " for i, (idx, scores, sample) in enumerate(zip(top_scores, logits, batch)):\n",
- " if sample == padding_item: continue\n",
- " cur_images = [images[x] for x in images_per_prompt_indices + i]\n",
- " top_images = [wandb.Image(cur_images[x]) for x in idx]\n",
- " top_scores = [scores[x] for x in idx]\n",
- " results.append([sample] + top_images + top_scores)\n",
- " \n",
- " # get clip 32 scores - TODO: this should be refactored as it is same code as above\n",
- " if add_clip_32:\n",
- " print('Calculating CLIP 32 scores')\n",
- " clip_inputs = processor32(text=batch, images=images, return_tensors='np', padding='max_length', max_length=77, truncation=True).data\n",
+ " def add_clip_results(results, processor, p_clip, clip_params): \n",
+ " clip_inputs = processor(text=batch, images=images, return_tensors='np', padding='max_length', max_length=77, truncation=True).data\n",
  " # each shard will have one prompt, images need to be reorganized to be associated to the correct shard\n",
  " images_per_prompt_indices = np.asarray(range(0, len(images), batch_size))\n",
  " clip_inputs['pixel_values'] = jnp.concatenate(list(clip_inputs['pixel_values'][images_per_prompt_indices + i] for i in range(batch_size)))\n",
  " clip_inputs = shard(clip_inputs)\n",
- " logits = p_clip32(clip_inputs, clip32_params)\n",
+ " logits = p_clip(clip_inputs, clip32_params)\n",
  " logits = logits.reshape(-1, num_images)\n",
  " top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
  " logits = jax.device_get(logits)\n",
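For reference, a toy sketch of the indexing that add_clip_results relies on: generation appends one image per prompt at every step, so all images for prompt i sit at images_per_prompt_indices + i, and the top-k images per prompt come from argsort on the reshaped logits. The sizes and scores below are invented for illustration.

# Toy illustration of the per-prompt regrouping and top-k selection above.
import numpy as np

batch_size, num_images, top_k = 2, 3, 2   # 2 prompts, 3 images per prompt

# Generation appends one image per prompt at each step: index = step * batch_size + prompt
images = [f"p{p}_img{s}" for s in range(num_images) for p in range(batch_size)]
# ['p0_img0', 'p1_img0', 'p0_img1', 'p1_img1', 'p0_img2', 'p1_img2']

# All images belonging to prompt i are at images_per_prompt_indices + i
images_per_prompt_indices = np.asarray(range(0, len(images), batch_size))  # [0, 2, 4]
for i in range(batch_size):
    print(i, [images[x] for x in images_per_prompt_indices + i])
# 0 ['p0_img0', 'p0_img1', 'p0_img2']
# 1 ['p1_img0', 'p1_img1', 'p1_img2']

# Top-k selection per prompt, highest score first
logits = np.array([[0.1, 0.9, 0.5],
                   [0.7, 0.2, 0.4]])          # shape (batch_size, num_images)
top_scores = logits.argsort()[:, -top_k:][..., ::-1]
print(top_scores)  # [[1 2]
                   #  [0 2]]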
@@ -342,13 +320,24 @@
  " for i, (idx, scores, sample) in enumerate(zip(top_scores, logits, batch)):\n",
  " if sample == padding_item: continue\n",
  " cur_images = [images[x] for x in images_per_prompt_indices + i]\n",
- " top_images = [wandb.Image(cur_images[x]) for x in idx]\n",
- " top_scores = [scores[x] for x in idx]\n",
- " results32.append([sample] + top_images + top_scores)\n",
+ " top_images = [wandb.Image(cur_images[x], caption=f'Score: {scores[x]:.2f}') for x in idx]\n",
+ " results.append([sample] + top_images)\n",
+ " \n",
+ " # get clip scores\n",
+ " pbar.set_description('Calculating CLIP 16 scores')\n",
+ " add_clip_results(results16, processor16, p_clip16, clip16_params)\n",
+ " \n",
+ " # get clip 32 scores\n",
+ " if add_clip_32:\n",
+ " pbar.set_description('Calculating CLIP 32 scores')\n",
+ " add_clip_results(results32, processor32, p_clip32, clip32_params)\n",
+ "\n",
  " pbar.close()\n",
  "\n",
+ " \n",
+ "\n",
  " # log results\n",
- " table = wandb.Table(columns=columns, data=results)\n",
+ " table = wandb.Table(columns=columns, data=results16)\n",
  " run.log({'Samples': table, 'version': version})\n",
  " wandb.finish()\n",
  " \n",
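For context, the Score columns disappear because each CLIP score now travels as the caption of its wandb.Image. A small sketch of the resulting table layout, run in offline mode with a placeholder project name, prompt, image and scores:

# Hedged sketch of the logged table: one row per caption, scores in the image captions.
# Project name, prompt, scores and the image itself are placeholders.
import numpy as np
import wandb
from PIL import Image

top_k = 2
columns = ["Caption"] + [f"Image {i+1}" for i in range(top_k)]

run = wandb.init(project="dalle-mini-inference", mode="offline")  # offline: no login needed
placeholder = Image.fromarray(np.zeros((8, 8, 3), dtype=np.uint8))
row = ["a toy prompt"] + [
    wandb.Image(placeholder, caption=f"Score: {s:.2f}") for s in (29.1, 27.4)
]
table = wandb.Table(columns=columns, data=[row])
run.log({"Samples": table, "version": 1})
wandb.finish()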
@@ -359,19 +348,6 @@
  " wandb.finish()\n",
  " run = None # ensure we don't log on this run"
  ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "4e4c7d0c-2848-4f88-b967-82fd571534f1",
- "metadata": {},
- "outputs": [],
- "source": [
- "# TODO: not implemented\n",
- "def log_runs(runs):\n",
- " for run in tqdm(runs):\n",
- " log_run(run)"
- ]
  }
  ],
  "metadata": {