feat: improve inference demo
tools/inference/inference_pipeline.ipynb CHANGED
@@ -41,10 +41,10 @@
 "outputs": [],
 "source": [
 "# Install required libraries\n",
- "
- "
- "
- "
+ "#!pip install -q transformers\n",
+ "#!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git\n",
+ "#!pip install -q git+https://github.com/borisdayma/dalle-mini.git\n",
+ "#!pip install -q wandb"
 ]
 },
 {
@@ -70,8 +70,8 @@
 "# Model references\n",
 "\n",
 "# dalle-mini\n",
- "DALLE_MODEL = \"dalle-mini/dalle-mini/model-
- "DALLE_COMMIT_ID = None
+ "DALLE_MODEL = \"dalle-mini/dalle-mini/model-mehdx7dg:latest\" # can be wandb artifact or 🤗 Hub or local folder\n",
+ "DALLE_COMMIT_ID = None\n",
 "\n",
 "# VQGAN model\n",
 "VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n",
@@ -91,13 +91,20 @@
 "import jax\n",
 "import jax.numpy as jnp\n",
 "\n",
+ "# check how many devices are available\n",
+ "jax.local_device_count()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
 "# type used for computation - use bfloat16 on TPU's\n",
 "dtype = jnp.bfloat16 if jax.local_device_count() == 8 else jnp.float32\n",
 "\n",
- "# TODO
- "# - we currently have an issue with model.generate() in bfloat16\n",
- "# - https://github.com/google/jax/pull/9089 should fix it\n",
- "# - remove below line and test on TPU with next release of JAX\n",
+ "# TODO: fix issue with bfloat16\n",
 "dtype = jnp.float32"
 ]
 },
@@ -115,35 +122,18 @@
 "outputs": [],
 "source": [
 "# Load models & tokenizer\n",
- "from dalle_mini.model import DalleBart\n",
+ "from dalle_mini.model import DalleBart, DalleBartTokenizer\n",
 "from vqgan_jax.modeling_flax_vqgan import VQModel\n",
- "from transformers import
+ "from transformers import CLIPProcessor, FlaxCLIPModel\n",
 "import wandb\n",
 "\n",
 "# Load dalle-mini\n",
- "
- "
- "
- "
- "
- "
- " \"flax_model.msgpack\",\n",
- " \"merges.txt\",\n",
- " \"special_tokens_map.json\",\n",
- " \"tokenizer.json\",\n",
- " \"tokenizer_config.json\",\n",
- " \"vocab.json\",\n",
- " ]\n",
- " for f in model_files:\n",
- " artifact.get_path(f).download(\"model\")\n",
- " model = DalleBart.from_pretrained(\"model\", dtype=dtype, abstract_init=True)\n",
- " tokenizer = AutoTokenizer.from_pretrained(\"model\")\n",
- "else:\n",
- " # local folder or 🤗 Hub\n",
- " model = DalleBart.from_pretrained(\n",
- " DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=dtype, abstract_init=True\n",
- " )\n",
- " tokenizer = AutoTokenizer.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)\n",
+ "model = DalleBart.from_pretrained(\n",
+ " DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=dtype, abstract_init=True\n",
+ ")\n",
+ "tokenizer = DalleBartTokenizer.from_pretrained(\n",
+ " DALLE_MODEL, revision=DALLE_COMMIT_ID\n",
+ ")\n",
 "\n",
 "# Load VQGAN\n",
 "vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n",
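The hunk above replaces the wandb-artifact download branch with a single `from_pretrained` path. A minimal sketch of how that path is typically used end to end, with the parameters replicated afterwards; the model reference and the `replicate` import are assumptions, not part of this diff:

```python
import jax.numpy as jnp
from dalle_mini.model import DalleBart, DalleBartTokenizer
from flax.jax_utils import replicate  # assumed import, not shown in the diff

DALLE_MODEL = "dalle-mini/dalle-mini"  # placeholder reference (wandb artifact, 🤗 Hub or local folder)
DALLE_COMMIT_ID = None
dtype = jnp.float32

# load once on the host ...
model = DalleBart.from_pretrained(
    DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=dtype, abstract_init=True
)
tokenizer = DalleBartTokenizer.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)

# ... then make the weights visible to every device for pmapped generation
params = replicate(model.params)
```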
@@ -210,7 +200,8 @@
 " prng_key=key,\n",
 " params=params,\n",
 " top_k=top_k,\n",
- " top_p=top_p
+ " top_p=top_p,\n",
+ " max_length=257\n",
 " )\n",
 "\n",
 "\n",
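For context on the new `top_p` and `max_length=257` arguments, a sketch of the kind of pmapped generation wrapper they sit inside; `model` and `params` are assumed to come from the loading cell, and the decorator and hyper-parameter values are assumptions rather than the notebook's exact code:

```python
from functools import partial
import jax

# sampling hyper-parameters (assumed values; the notebook defines its own)
top_k, top_p = None, None

# each device samples its own sequence from its own PRNG key;
# max_length=257 leaves room for 256 image tokens plus the BOS token
@partial(jax.pmap, axis_name="batch")
def p_generate(tokenized_prompt, key, params):
    return model.generate(
        **tokenized_prompt,
        prng_key=key,
        params=params,
        top_k=top_k,
        top_p=top_p,
        max_length=257,
    ).sequences
```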
@@ -233,7 +224,7 @@
 "id": "HmVN6IBwapBA"
 },
 "source": [
- "Keys are passed to the model on each device to generate unique
+ "Keys are passed to the model on each device to generate unique inference per device."
 ]
 },
 {
@@ -247,7 +238,7 @@
 "import random\n",
 "\n",
 "# create a random key\n",
- "seed = random.randint(0, 2
+ "seed = random.randint(0, 2**32 - 1)\n",
 "key = jax.random.PRNGKey(seed)"
 ]
 },
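Since keys are passed to the model on each device, the single key above still has to be turned into one key per device before the pmapped call. A minimal sketch, assuming `key` from the cell above and Flax's `shard_prng_key` helper:

```python
import jax
from flax.training.common_utils import shard_prng_key

# derive one PRNG key per local device so every device samples a different image
per_device_keys = shard_prng_key(key)

# explicitly, the same idea:
subkeys = jax.random.split(key, jax.local_device_count())
```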
@@ -299,7 +290,7 @@
 },
 "outputs": [],
 "source": [
- "prompt = \"a
+ "prompt = \"a waterfall under the sunset\""
 ]
 },
 {
@@ -316,27 +307,19 @@
 },
 {
 "cell_type": "markdown",
- "metadata": {
- "id": "iFVOyYboP0L-"
- },
+ "metadata": {},
 "source": [
- "We
+ "We tokenize the prompt."
 ]
 },
 {
 "cell_type": "code",
 "execution_count": null,
- "metadata": {
- "id": "Rii_FJ7POw1y"
- },
+ "metadata": {},
 "outputs": [],
 "source": [
- "# repeat the prompt on each device\n",
- "repeated_prompts = [processed_prompt] * jax.device_count()\n",
- "\n",
- "# tokenize\n",
 "tokenized_prompt = tokenizer(\n",
- "
+ " processed_prompt,\n",
 " return_tensors=\"jax\",\n",
 " padding=\"max_length\",\n",
 " truncation=True,\n",
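As a usage note for the tokenization step, a sketch of the full call shape; `processed_prompt` comes from the earlier text-processing cell and the `max_length` value is an assumption:

```python
# tokenize the preprocessed prompt into fixed-length JAX arrays
tokenized_prompt = tokenizer(
    processed_prompt,
    return_tensors="jax",
    padding="max_length",
    truncation=True,
    max_length=64,  # assumed value
)
# tokenized_prompt behaves like a dict with "input_ids" and "attention_mask"
```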
@@ -360,24 +343,18 @@
 },
 {
 "cell_type": "markdown",
- "metadata": {
- "id": "2wiDtG3_SH2u"
- },
+ "metadata": {},
 "source": [
- "Finally we
+ "Finally we replicate it onto each device."
 ]
 },
 {
 "cell_type": "code",
 "execution_count": null,
- "metadata": {
- "id": "AImyrxHtR9TG"
- },
+ "metadata": {},
 "outputs": [],
 "source": [
- "
- "\n",
- "tokenized_prompt = shard(tokenized_prompt)"
+ "tokenized_prompt = replicate(tokenized_prompt)"
 ]
 },
 {
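The cell now uses `replicate` instead of `shard`, since the same prompt should be visible on every device rather than split across them. A minimal sketch with the assumed import location:

```python
from flax.jax_utils import replicate  # assumed import, not shown in the diff

# replicate() adds a leading device axis, so every device sees the same
# tokenized prompt during the pmapped generation step
tokenized_prompt = replicate(tokenized_prompt)
```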
@@ -455,6 +432,8 @@
 },
 "outputs": [],
 "source": [
+ "from flax.training.common_utils import shard\n",
+ "\n",
 "# get clip scores\n",
 "clip_inputs = processor(\n",
 " text=[prompt] * jax.device_count(),\n",
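The `shard` import moves here because only the CLIP-scoring batch is split across devices. A sketch of how that cell might continue; the CLIP checkpoint name and the `images` variable are assumptions:

```python
import jax
from flax.training.common_utils import shard
from transformers import CLIPProcessor, FlaxCLIPModel

# load CLIP for ranking the generated images (placeholder checkpoint)
clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# get clip scores: one copy of the prompt per device, one image per device
clip_inputs = processor(
    text=[prompt] * jax.device_count(),
    images=images,  # decoded PIL images from the VQGAN step
    return_tensors="np",
    padding="max_length",
    truncation=True,
).data
clip_inputs = shard(clip_inputs)  # split the batch across devices
```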