boris committed on
Commit 335110d
2 parents: fdbe19f, 2ef2966

Merge pull request #91 from borisdayma/feat-inf

dev/inference/samples.txt ADDED
@@ -0,0 +1,103 @@
+ white snow covered mountain under blue sky during daytime
+ aerial view of the beach at night
+ aerial view of the beach during daytime
+ a beautiful sunset at a beach with a shell on the shore
+ a farmhouse surrounded by beautiful flowers
+ a photo of a fantasy version of New York City
+ a picture of fantasy kingdoms
+ a volcano erupting in the middle of San Francisco
+ big wave destroying a city
+ Paris in a far future, futuristic Paris
+ sunset over green mountains
+ the last sunrise on earth
+ underwater cathedral
+ painting of an oneiric forest glade surrounded by tall trees
+ real painting of an alien from Monet
+ a graphite sketch of a gothic cathedral
+ a graphite sketch of Elon Musk
+ still life in the style of Kandinsky
+ still life in the style of Picasso
+ a colorful stairway to heaven
+ a background consisting of colors blue, green, and red
+ the communist statue of liberty
+ robots taking control over humans
+ epic sword fight
+ an avocado armchair
+ an armchair in the shape of an avocado
+ logo of an avocado armchair
+ an avocado armchair flying into space
+ a cute avocado armchair singing karaoke on stage in front of a crowd of strawberry shaped lamps
+ an illustration of an avocado in a christmas sweater staring at its reflection in a mirror
+ illustration of an avocado armchair
+ illustration of an avocado armchair getting married to a pineapple
+ a muscular banana sitting upright on a bench smoking watching a banana on television, high definition photography
+ Mohammed Ali and Mike Tyson in a hypothetical match
+ Pele and Maradona in a hypothetical match
+ view of mars from space
+ illustration of an astronaut in a space suit playing guitar
+ a clown wearing a spacesuit floating in space
+ a picture of the eiffel tower on the moon
+ watercolor of the Eiffel tower on the moon
+ a photo of the French flag on the planet Saturn
+ the moon is a skull
+ a dog playing with a ball
+ a cat sits on top of an alligator
+ a rat holding a red lightsaber in a white background
+ A unicorn is passing by a rainbow in a field of flowers
+ a dog eating worthlessness
+ an elephant made of carrots
+ an elephant on a unicycle during a circus
+ photography of a penguin watching television
+ rat wearing a crown
+ a portrait of a nightmare creature watching at you
+ a white room full of a black substance
+ happy, happiness
+ sad, sadness
+ the representation of infinity
+ a cute pikachu teapot
+ a picture of a castle from minecraft
+ an illustration of pikachu sitting on a bench
+ mario eating an avocado while walking his baby koala
+ star wars concept art
+ a cartoon of a superhero bear
+ an illustration of a cute skeleton wearing a blue hoodie
+ illustration of a baby shark swimming around corals
+ Cartoon of a carrot with big eyes
+ logo of a robot wearing glasses and reading a book
+ a bottle of coca-cola on a table
+ a cactus lifting weights
+ a living room with two white armchairs and a painting of the colosseum. The painting is mounted above a modern fireplace.
+ a long line of alternating green and red blocks
+ a long line of green blocks on a beach at sunset
+ a long line of peaches on a beach at sunset
+ a peanut
+ a photo of a camera from the future
+ a restaurant menu
+ a skeleton with the shape of a spider
+ looking into the sky, 10 airplanes are seen overhead
+ shelves filled with books and alchemy potion bottles
+ this is a detailed high-resolution scan of a human brain
+ a collection of glasses is sitting on a table
+ a cross-section view of a walnut
+ a painting of a capybara sitting on a mountain during fall in surrealist style
+ a pentagonal green clock
+ a photo of san francisco golden gate bridge
+ a pixel art illustration of an eagle sitting in a field in the afternoon
+ a professional high-quality emoji of a lovestruck cup of boba
+ a small red block sitting on a large green block
+ a storefront that has the word 'openai' written on it
+ a tattoo of a black broccoli
+ a variety of clocks is sitting on a table
+ an emoji of a baby fox wearing a blue hat, blue gloves, red shirt, and red pants
+ an emoji of a baby penguin wearing a blue hat, blue gloves, red shirt, and green pants
+ an extreme close-up view of a capybara sitting in a field
+ an illustration of a baby cucumber with a mustache playing chess
+ an illustration of a baby daikon radish in a tutu walking a dog
+ an illustration of a baby hedgehog in a cape staring at its reflection in a mirror
+ an illustration of a baby panda with headphones holding an umbrella in the rain
+ an illustration of an avocado in a beanie riding a motorcycle
+ urinals are lined up in a jungle
+ a human face
+ a person is holding a phone and a water bottle, running a marathon
+ a photograph of Ellen G. White
+ Young woman riding her bike through the forest
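
A note for readers: the notebook added below consumes this file in fixed-size batches, padding the prompt list with a placeholder so its length is a multiple of the batch size. A minimal sketch of that consumption logic (batch_size and padding_item mirror the notebook's config cells; this sketch is illustrative and not part of the commit):

batch_size = 8
padding_item = 'NONE'  # placeholder prompt, skipped when logging results

with open('samples.txt', encoding='utf8') as f:
    samples = [l.strip() for l in f.readlines()]
# pad to a multiple of batch_size: 103 prompts -> 104, i.e. 13 batches of 8
samples.extend([padding_item] * (-len(samples) % batch_size))
batches = [samples[i:i + batch_size] for i in range(0, len(samples), batch_size)]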
dev/inference/wandb-backend.ipynb ADDED
@@ -0,0 +1,385 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4ff2a984-b8b2-4a69-89cf-0d16da2393c8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import tempfile\n",
+ "from functools import partial\n",
+ "import random\n",
+ "import numpy as np\n",
+ "from PIL import Image\n",
+ "from tqdm.notebook import tqdm\n",
+ "import jax\n",
+ "import jax.numpy as jnp\n",
+ "from flax.training.common_utils import shard, shard_prng_key\n",
+ "from flax.jax_utils import replicate\n",
+ "import wandb\n",
+ "from dalle_mini.model import CustomFlaxBartForConditionalGeneration\n",
+ "from vqgan_jax.modeling_flax_vqgan import VQModel\n",
+ "from transformers import BartTokenizer, CLIPProcessor, FlaxCLIPModel\n",
+ "from dalle_mini.text import TextNormalizer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "23e00271-941c-4e1b-b6a9-107a1b77324d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "run_ids = ['3kaut6e8']\n",
+ "ENTITY, PROJECT = 'wandb', 'hf-flax-dalle-mini'\n",
+ "VQGAN_REPO, VQGAN_COMMIT_ID = 'dalle-mini/vqgan_imagenet_f16_16384', None\n",
+ "normalize_text = False\n",
+ "latest_only = True # log only latest or all versions\n",
+ "suffix = '' # mainly for duplicate inference runs with a deleted version\n",
+ "add_clip_32 = True"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "92f4557c-fd7f-4edc-81c2-de0b0a10c270",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# alternate configuration: running this cell overrides the one above\n",
+ "run_ids = ['k76r0v39']\n",
+ "ENTITY, PROJECT = 'dalle-mini', 'dalle-mini' # used only for training run\n",
+ "VQGAN_REPO, VQGAN_COMMIT_ID = 'dalle-mini/vqgan_imagenet_f16_16384', None\n",
+ "normalize_text = True\n",
+ "latest_only = True # log only latest or all versions\n",
+ "suffix = '' # mainly for duplicate inference runs with a deleted version\n",
+ "add_clip_32 = False"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "93b2e24b-f0e5-4abe-a3ec-0aa834cc3bf3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "batch_size = 8\n",
+ "num_images = 128\n",
+ "top_k = 8\n",
+ "text_normalizer = TextNormalizer() if normalize_text else None\n",
+ "padding_item = 'NONE'\n",
+ "seed = random.randint(0, 2**32-1)\n",
+ "key = jax.random.PRNGKey(seed)\n",
+ "api = wandb.Api()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c6a878fa-4bf5-4978-abb5-e235841d765b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n",
+ "clip = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch16\")\n",
+ "processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch16\")\n",
+ "clip_params = replicate(clip.params)\n",
+ "vqgan_params = replicate(vqgan.params)\n",
+ "\n",
+ "if add_clip_32:\n",
+ "    clip32 = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
+ "    processor32 = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
+ "    clip32_params = replicate(clip32.params)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a500dd07-dbc3-477d-80d4-2b73a3b83ef3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@partial(jax.pmap, axis_name=\"batch\")\n",
+ "def p_decode(indices, params):\n",
+ "    return vqgan.decode_code(indices, params=params)\n",
+ "\n",
+ "@partial(jax.pmap, axis_name=\"batch\")\n",
+ "def p_clip(inputs, params):\n",
+ "    logits = clip(params=params, **inputs).logits_per_image\n",
+ "    return logits\n",
+ "\n",
+ "if add_clip_32:\n",
+ "    @partial(jax.pmap, axis_name=\"batch\")\n",
+ "    def p_clip32(inputs, params):\n",
+ "        logits = clip32(params=params, **inputs).logits_per_image\n",
+ "        return logits"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ebf4f7bf-2efa-46cc-b3f4-2d7a54f7b2cb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "clip_params['logit_scale']"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e57797ab-0b3a-4490-be58-03d8d1c23fe9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open('samples.txt', encoding='utf8') as f:\n",
+ "    samples = [l.strip() for l in f.readlines()]\n",
+ "    # make list multiple of batch_size by adding elements\n",
+ "    samples_to_add = [padding_item] * (-len(samples) % batch_size)\n",
+ "    samples.extend(samples_to_add)\n",
+ "    # reshape\n",
+ "    samples = [samples[i:i+batch_size] for i in range(0, len(samples), batch_size)]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f3e02d9d-4ee1-49e7-a7bc-4d8b139e9614",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def get_artifact_versions(run_id, latest_only=False):\n",
+ "    try:\n",
+ "        if latest_only:\n",
+ "            return [api.artifact(type='bart_model', name=f'{ENTITY}/{PROJECT}/model-{run_id}:latest')]\n",
+ "        else:\n",
+ "            return api.artifact_versions(type_name='bart_model', name=f'{ENTITY}/{PROJECT}/model-{run_id}', per_page=10000)\n",
+ "    except Exception:\n",
+ "        # no model artifact logged for this run yet\n",
+ "        return []"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f0d7ed17-7abb-4a31-ab3c-a12b9039a570",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def get_training_config(run_id):\n",
+ "    training_run = api.run(f'{ENTITY}/{PROJECT}/{run_id}')\n",
+ "    config = training_run.config\n",
+ "    return config"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7e784a43-626d-4e8d-9e47-a23775b2f35f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# retrieve inference run details\n",
+ "def get_last_inference_version(run_id):\n",
+ "    try:\n",
+ "        inference_run = api.run(f'dalle-mini/dalle-mini/{run_id}-clip16{suffix}')\n",
+ "        return inference_run.summary.get('version', None)\n",
+ "    except Exception:\n",
+ "        # inference run does not exist yet\n",
+ "        return None"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d1cc9993-1bfc-4ec6-a004-c056189c42ac",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# compile functions - needed only once per run\n",
+ "def pmap_model_function(model):\n",
+ "\n",
+ "    @partial(jax.pmap, axis_name=\"batch\")\n",
+ "    def _generate(tokenized_prompt, key, params):\n",
+ "        return model.generate(\n",
+ "            **tokenized_prompt,\n",
+ "            do_sample=True,\n",
+ "            num_beams=1,\n",
+ "            prng_key=key,\n",
+ "            params=params\n",
+ "        )\n",
+ "\n",
+ "    return _generate"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "23b2444c-67a9-44d7-abd1-187ed83a9431",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "run_id = run_ids[0]\n",
+ "# TODO: turn everything into a class"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "bba70f33-af8b-4eb3-9973-7be672301a0b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "artifact_versions = get_artifact_versions(run_id, latest_only)\n",
+ "last_inference_version = get_last_inference_version(run_id)\n",
+ "training_config = get_training_config(run_id)\n",
+ "run = None\n",
+ "p_generate = None\n",
+ "model_files = ['config.json', 'flax_model.msgpack', 'merges.txt', 'special_tokens_map.json', 'tokenizer.json', 'tokenizer_config.json', 'vocab.json']\n",
+ "for artifact in artifact_versions:\n",
+ "    print(f'Processing artifact: {artifact.name}')\n",
+ "    version = int(artifact.version[1:])\n",
+ "    results = []\n",
+ "    if add_clip_32:\n",
+ "        results32 = []\n",
+ "    columns = ['Caption'] + [f'Image {i+1}' for i in range(top_k)] + [f'Score {i+1}' for i in range(top_k)]\n",
+ "\n",
+ "    if latest_only:\n",
+ "        assert last_inference_version is None or version > last_inference_version\n",
+ "    else:\n",
+ "        if last_inference_version is None:\n",
+ "            # we should start from v0\n",
+ "            assert version == 0\n",
+ "        elif version <= last_inference_version:\n",
+ "            print(f'v{version} has already been logged (versions logged up to v{last_inference_version})')\n",
+ "        else:\n",
+ "            # check we are logging the correct version\n",
+ "            assert version == last_inference_version + 1\n",
+ "\n",
+ "    # start/resume corresponding run\n",
+ "    if run is None:\n",
+ "        run = wandb.init(job_type='inference', entity='dalle-mini', project='dalle-mini', config=training_config, id=f'{run_id}-clip16{suffix}', resume='allow')\n",
+ "\n",
+ "    # work in temporary directory\n",
+ "    with tempfile.TemporaryDirectory() as tmp:\n",
+ "\n",
+ "        # download model files\n",
+ "        artifact = run.use_artifact(artifact)\n",
+ "        for f in model_files:\n",
+ "            artifact.get_path(f).download(tmp)\n",
+ "\n",
+ "        # load tokenizer and model\n",
+ "        tokenizer = BartTokenizer.from_pretrained(tmp)\n",
+ "        model = CustomFlaxBartForConditionalGeneration.from_pretrained(tmp)\n",
+ "        model_params = replicate(model.params)\n",
+ "\n",
+ "        # pmap model function needs to happen only once per model config\n",
+ "        if p_generate is None:\n",
+ "            p_generate = pmap_model_function(model)\n",
+ "\n",
+ "        # process one batch of captions\n",
+ "        for batch in tqdm(samples):\n",
+ "            processed_prompts = [text_normalizer(x) for x in batch] if normalize_text else list(batch)\n",
+ "\n",
+ "            # repeat the prompts to distribute over each device and tokenize\n",
+ "            processed_prompts = processed_prompts * jax.device_count()\n",
+ "            tokenized_prompt = tokenizer(processed_prompts, return_tensors='jax', padding='max_length', truncation=True, max_length=128).data\n",
+ "            tokenized_prompt = shard(tokenized_prompt)\n",
+ "\n",
+ "            # generate images\n",
+ "            images = []\n",
+ "            for i in tqdm(range(num_images // jax.device_count()), desc='Generating Images'):\n",
+ "                key, subkey = jax.random.split(key)\n",
+ "                encoded_images = p_generate(tokenized_prompt, shard_prng_key(subkey), model_params)\n",
+ "                encoded_images = encoded_images.sequences[..., 1:]\n",
+ "                decoded_images = p_decode(encoded_images, vqgan_params)\n",
+ "                decoded_images = decoded_images.clip(0., 1.).reshape((-1, 256, 256, 3))\n",
+ "                for img in decoded_images:\n",
+ "                    images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))\n",
+ "\n",
+ "            # get clip scores\n",
+ "            print('Calculating CLIP scores')\n",
+ "            clip_inputs = processor(text=batch, images=images, return_tensors='np', padding='max_length', max_length=77, truncation=True).data\n",
+ "            # each shard will have one prompt; images need to be reorganized to be associated with the correct shard\n",
+ "            images_per_prompt_indices = np.asarray(range(0, len(images), batch_size))\n",
+ "            clip_inputs['pixel_values'] = jnp.concatenate(list(clip_inputs['pixel_values'][images_per_prompt_indices + i] for i in range(batch_size)))\n",
+ "            clip_inputs = shard(clip_inputs)\n",
+ "            logits = p_clip(clip_inputs, clip_params)\n",
+ "            logits = logits.reshape(-1, num_images)\n",
+ "            top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
+ "            logits = jax.device_get(logits)\n",
+ "            # add to results table\n",
+ "            for i, (idx, scores, sample) in enumerate(zip(top_scores, logits, batch)):\n",
+ "                if sample == padding_item: continue\n",
+ "                cur_images = [images[x] for x in images_per_prompt_indices + i]\n",
+ "                top_images = [wandb.Image(cur_images[x]) for x in idx]\n",
+ "                top_scores = [scores[x] for x in idx]\n",
+ "                results.append([sample] + top_images + top_scores)\n",
+ "\n",
+ "            # get clip 32 scores - TODO: this should be refactored as it is the same code as above\n",
+ "            if add_clip_32:\n",
+ "                print('Calculating CLIP 32 scores')\n",
+ "                clip_inputs = processor32(text=batch, images=images, return_tensors='np', padding='max_length', max_length=77, truncation=True).data\n",
+ "                # each shard will have one prompt; images need to be reorganized to be associated with the correct shard\n",
+ "                images_per_prompt_indices = np.asarray(range(0, len(images), batch_size))\n",
+ "                clip_inputs['pixel_values'] = jnp.concatenate(list(clip_inputs['pixel_values'][images_per_prompt_indices + i] for i in range(batch_size)))\n",
+ "                clip_inputs = shard(clip_inputs)\n",
+ "                logits = p_clip32(clip_inputs, clip32_params)\n",
+ "                logits = logits.reshape(-1, num_images)\n",
+ "                top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
+ "                logits = jax.device_get(logits)\n",
+ "                # add to results table\n",
+ "                for i, (idx, scores, sample) in enumerate(zip(top_scores, logits, batch)):\n",
+ "                    if sample == padding_item: continue\n",
+ "                    cur_images = [images[x] for x in images_per_prompt_indices + i]\n",
+ "                    top_images = [wandb.Image(cur_images[x]) for x in idx]\n",
+ "                    top_scores = [scores[x] for x in idx]\n",
+ "                    results32.append([sample] + top_images + top_scores)\n",
+ "\n",
+ "    # log results\n",
+ "    table = wandb.Table(columns=columns, data=results)\n",
+ "    run.log({'Samples': table, 'version': version})\n",
+ "    wandb.finish()\n",
+ "    run = None  # a finished run cannot log again; re-init on the next version\n",
+ "\n",
+ "    if add_clip_32:\n",
+ "        run = wandb.init(job_type='inference', entity='dalle-mini', project='dalle-mini', config=training_config, id=f'{run_id}-clip32{suffix}', resume='allow')\n",
+ "        table = wandb.Table(columns=columns, data=results32)\n",
+ "        run.log({'Samples': table, 'version': version})\n",
+ "        wandb.finish()\n",
+ "        run = None # ensure we don't log on this run"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4e4c7d0c-2848-4f88-b967-82fd571534f1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# TODO: not implemented\n",
+ "def log_runs(runs):\n",
+ "    for run in tqdm(runs):\n",
+ "        log_run(run)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
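
A note on the least obvious step above: generation emits one image per prompt per sampling round, so the images for prompt i land at positions i, i + batch_size, i + 2*batch_size, ... of the images list, while CLIP scoring wants each prompt's images contiguous so that each shard holds a single prompt. A minimal standalone sketch of that index arithmetic with toy sizes (string labels stand in for the actual images; not part of the commit):

import numpy as np

batch_size = 4  # prompts per batch (8 in the notebook)
rounds = 3      # sampling rounds, num_images // jax.device_count() in the notebook

# generation order: round 0 for prompts 0..3, then round 1, and so on
images = [f'p{i % batch_size}_r{i // batch_size}' for i in range(batch_size * rounds)]

# stride batch_size selects every round of prompt 0; adding i shifts to prompt i
images_per_prompt_indices = np.arange(0, len(images), batch_size)
regrouped = [images[j] for i in range(batch_size) for j in images_per_prompt_indices + i]
print(regrouped)  # ['p0_r0', 'p0_r1', 'p0_r2', 'p1_r0', ...]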