Spaces:
Running
Running
Pedro Cuenca
commited on
Commit
•
4b8c3a8
1
Parent(s):
eb912a1
* JIT outside the loop.
Browse filesMy tests yesterday were wrong: there is a noticeable performance
improvement doing it this way. Even so, JIT runs twice, we could cut
times in half (for this test) if we could make it run once.
encoding/vqgan-jax-encoding.ipynb
CHANGED
@@ -363,20 +363,21 @@
|
|
363 |
},
|
364 |
{
|
365 |
"cell_type": "code",
|
366 |
-
"execution_count":
|
367 |
-
"id": "
|
368 |
"metadata": {},
|
369 |
"outputs": [],
|
370 |
"source": [
|
371 |
"def encode(model, batch):\n",
|
|
|
372 |
" _, indices = model.encode(batch)\n",
|
373 |
" return indices"
|
374 |
]
|
375 |
},
|
376 |
{
|
377 |
"cell_type": "code",
|
378 |
-
"execution_count":
|
379 |
-
"id": "
|
380 |
"metadata": {},
|
381 |
"outputs": [],
|
382 |
"source": [
|
@@ -969,15 +970,15 @@
|
|
969 |
},
|
970 |
{
|
971 |
"cell_type": "markdown",
|
972 |
-
"id": "
|
973 |
"metadata": {},
|
974 |
"source": [
|
975 |
-
"It works! Let's wrap it and run the whole process on the 10k images subset."
|
976 |
]
|
977 |
},
|
978 |
{
|
979 |
"cell_type": "markdown",
|
980 |
-
"id": "
|
981 |
"metadata": {},
|
982 |
"source": [
|
983 |
"## 10k encoding"
|
@@ -993,8 +994,18 @@
|
|
993 |
},
|
994 |
{
|
995 |
"cell_type": "code",
|
996 |
-
"execution_count":
|
997 |
-
"id": "
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
998 |
"metadata": {},
|
999 |
"outputs": [],
|
1000 |
"source": [
|
@@ -1004,10 +1015,11 @@
|
|
1004 |
" superbatches = superbatch_generator(dataloader)\n",
|
1005 |
" \n",
|
1006 |
" # TODO: save to disk as we go, do not accumulate everything in RAM\n",
|
1007 |
-
"#
|
|
|
1008 |
" results = None\n",
|
1009 |
" for superbatch in tqdm(superbatches):\n",
|
1010 |
-
" encoded =
|
1011 |
" encoded = encoded.reshape(encoded.shape[0] * encoded.shape[1], -1)\n",
|
1012 |
" results = np.concatenate((results, encoded), axis=0) if results is not None else encoded\n",
|
1013 |
" return results"
|
@@ -1015,15 +1027,15 @@
|
|
1015 |
},
|
1016 |
{
|
1017 |
"cell_type": "code",
|
1018 |
-
"execution_count":
|
1019 |
-
"id": "
|
1020 |
"metadata": {},
|
1021 |
"outputs": [
|
1022 |
{
|
1023 |
"name": "stderr",
|
1024 |
"output_type": "stream",
|
1025 |
"text": [
|
1026 |
-
"16it [
|
1027 |
]
|
1028 |
}
|
1029 |
],
|
|
|
363 |
},
|
364 |
{
|
365 |
"cell_type": "code",
|
366 |
+
"execution_count": 76,
|
367 |
+
"id": "fd26cdce",
|
368 |
"metadata": {},
|
369 |
"outputs": [],
|
370 |
"source": [
|
371 |
"def encode(model, batch):\n",
|
372 |
+
"# print(\"jitting encode function\")\n",
|
373 |
" _, indices = model.encode(batch)\n",
|
374 |
" return indices"
|
375 |
]
|
376 |
},
|
377 |
{
|
378 |
"cell_type": "code",
|
379 |
+
"execution_count": 18,
|
380 |
+
"id": "c49181e1",
|
381 |
"metadata": {},
|
382 |
"outputs": [],
|
383 |
"source": [
|
|
|
970 |
},
|
971 |
{
|
972 |
"cell_type": "markdown",
|
973 |
+
"id": "48896d5f",
|
974 |
"metadata": {},
|
975 |
"source": [
|
976 |
+
"It works! Let's wrap it up and run the whole process on the 10k images subset."
|
977 |
]
|
978 |
},
|
979 |
{
|
980 |
"cell_type": "markdown",
|
981 |
+
"id": "029d35d9",
|
982 |
"metadata": {},
|
983 |
"source": [
|
984 |
"## 10k encoding"
|
|
|
994 |
},
|
995 |
{
|
996 |
"cell_type": "code",
|
997 |
+
"execution_count": 45,
|
998 |
+
"id": "04b1568b",
|
999 |
+
"metadata": {},
|
1000 |
+
"outputs": [],
|
1001 |
+
"source": [
|
1002 |
+
"from functools import partial"
|
1003 |
+
]
|
1004 |
+
},
|
1005 |
+
{
|
1006 |
+
"cell_type": "code",
|
1007 |
+
"execution_count": 78,
|
1008 |
+
"id": "bfa3073b",
|
1009 |
"metadata": {},
|
1010 |
"outputs": [],
|
1011 |
"source": [
|
|
|
1015 |
" superbatches = superbatch_generator(dataloader)\n",
|
1016 |
" \n",
|
1017 |
" # TODO: save to disk as we go, do not accumulate everything in RAM\n",
|
1018 |
+
"# p_encoder = pmap(partial(encode, model), in_axes=(0,), donate_argnums=(0))\n",
|
1019 |
+
" p_encoder = pmap(lambda batch: encode(model, batch))\n",
|
1020 |
" results = None\n",
|
1021 |
" for superbatch in tqdm(superbatches):\n",
|
1022 |
+
" encoded = p_encoder(superbatch.numpy())\n",
|
1023 |
" encoded = encoded.reshape(encoded.shape[0] * encoded.shape[1], -1)\n",
|
1024 |
" results = np.concatenate((results, encoded), axis=0) if results is not None else encoded\n",
|
1025 |
" return results"
|
|
|
1027 |
},
|
1028 |
{
|
1029 |
"cell_type": "code",
|
1030 |
+
"execution_count": 79,
|
1031 |
+
"id": "d8d4da18",
|
1032 |
"metadata": {},
|
1033 |
"outputs": [
|
1034 |
{
|
1035 |
"name": "stderr",
|
1036 |
"output_type": "stream",
|
1037 |
"text": [
|
1038 |
+
"16it [00:41, 2.61s/it]\n"
|
1039 |
]
|
1040 |
}
|
1041 |
],
|