Pedro Cuenca commited on
Commit
4b8c3a8
1 Parent(s): eb912a1

* JIT outside the loop.

Browse files

My tests yesterday were wrong: there is a noticeable performance
improvement doing it this way. Even so, JIT runs twice, we could cut
times in half (for this test) if we could make it run once.

Files changed (1) hide show
  1. encoding/vqgan-jax-encoding.ipynb +26 -14
encoding/vqgan-jax-encoding.ipynb CHANGED
@@ -363,20 +363,21 @@
363
  },
364
  {
365
  "cell_type": "code",
366
- "execution_count": 18,
367
- "id": "c8b1c229",
368
  "metadata": {},
369
  "outputs": [],
370
  "source": [
371
  "def encode(model, batch):\n",
 
372
  " _, indices = model.encode(batch)\n",
373
  " return indices"
374
  ]
375
  },
376
  {
377
  "cell_type": "code",
378
- "execution_count": 19,
379
- "id": "f2aafe7a",
380
  "metadata": {},
381
  "outputs": [],
382
  "source": [
@@ -969,15 +970,15 @@
969
  },
970
  {
971
  "cell_type": "markdown",
972
- "id": "03643ba1",
973
  "metadata": {},
974
  "source": [
975
- "It works! Let's wrap it and run the whole process on the 10k images subset."
976
  ]
977
  },
978
  {
979
  "cell_type": "markdown",
980
- "id": "1c65d943",
981
  "metadata": {},
982
  "source": [
983
  "## 10k encoding"
@@ -993,8 +994,18 @@
993
  },
994
  {
995
  "cell_type": "code",
996
- "execution_count": 195,
997
- "id": "f69e2073",
 
 
 
 
 
 
 
 
 
 
998
  "metadata": {},
999
  "outputs": [],
1000
  "source": [
@@ -1004,10 +1015,11 @@
1004
  " superbatches = superbatch_generator(dataloader)\n",
1005
  " \n",
1006
  " # TODO: save to disk as we go, do not accumulate everything in RAM\n",
1007
- "# encoder = pmap(lambda batch: encode(model, batch))\n",
 
1008
  " results = None\n",
1009
  " for superbatch in tqdm(superbatches):\n",
1010
- " encoded = pmap(lambda batch: encode(model, batch))(superbatch.numpy())\n",
1011
  " encoded = encoded.reshape(encoded.shape[0] * encoded.shape[1], -1)\n",
1012
  " results = np.concatenate((results, encoded), axis=0) if results is not None else encoded\n",
1013
  " return results"
@@ -1015,15 +1027,15 @@
1015
  },
1016
  {
1017
  "cell_type": "code",
1018
- "execution_count": 199,
1019
- "id": "e9a5565e",
1020
  "metadata": {},
1021
  "outputs": [
1022
  {
1023
  "name": "stderr",
1024
  "output_type": "stream",
1025
  "text": [
1026
- "16it [03:38, 13.64s/it]\n"
1027
  ]
1028
  }
1029
  ],
 
363
  },
364
  {
365
  "cell_type": "code",
366
+ "execution_count": 76,
367
+ "id": "fd26cdce",
368
  "metadata": {},
369
  "outputs": [],
370
  "source": [
371
  "def encode(model, batch):\n",
372
+ "# print(\"jitting encode function\")\n",
373
  " _, indices = model.encode(batch)\n",
374
  " return indices"
375
  ]
376
  },
377
  {
378
  "cell_type": "code",
379
+ "execution_count": 18,
380
+ "id": "c49181e1",
381
  "metadata": {},
382
  "outputs": [],
383
  "source": [
 
970
  },
971
  {
972
  "cell_type": "markdown",
973
+ "id": "48896d5f",
974
  "metadata": {},
975
  "source": [
976
+ "It works! Let's wrap it up and run the whole process on the 10k images subset."
977
  ]
978
  },
979
  {
980
  "cell_type": "markdown",
981
+ "id": "029d35d9",
982
  "metadata": {},
983
  "source": [
984
  "## 10k encoding"
 
994
  },
995
  {
996
  "cell_type": "code",
997
+ "execution_count": 45,
998
+ "id": "04b1568b",
999
+ "metadata": {},
1000
+ "outputs": [],
1001
+ "source": [
1002
+ "from functools import partial"
1003
+ ]
1004
+ },
1005
+ {
1006
+ "cell_type": "code",
1007
+ "execution_count": 78,
1008
+ "id": "bfa3073b",
1009
  "metadata": {},
1010
  "outputs": [],
1011
  "source": [
 
1015
  " superbatches = superbatch_generator(dataloader)\n",
1016
  " \n",
1017
  " # TODO: save to disk as we go, do not accumulate everything in RAM\n",
1018
+ "# p_encoder = pmap(partial(encode, model), in_axes=(0,), donate_argnums=(0))\n",
1019
+ " p_encoder = pmap(lambda batch: encode(model, batch))\n",
1020
  " results = None\n",
1021
  " for superbatch in tqdm(superbatches):\n",
1022
+ " encoded = p_encoder(superbatch.numpy())\n",
1023
  " encoded = encoded.reshape(encoded.shape[0] * encoded.shape[1], -1)\n",
1024
  " results = np.concatenate((results, encoded), axis=0) if results is not None else encoded\n",
1025
  " return results"
 
1027
  },
1028
  {
1029
  "cell_type": "code",
1030
+ "execution_count": 79,
1031
+ "id": "d8d4da18",
1032
  "metadata": {},
1033
  "outputs": [
1034
  {
1035
  "name": "stderr",
1036
  "output_type": "stream",
1037
  "text": [
1038
+ "16it [00:41, 2.61s/it]\n"
1039
  ]
1040
  }
1041
  ],