codeShare committed on
Commit 540a0c2 · verified · 1 Parent(s): a78db43

Upload sd_token_similarity_calculator.ipynb

Files changed (1)
  1. sd_token_similarity_calculator.ipynb +179 -12
sd_token_similarity_calculator.ipynb CHANGED
@@ -116,10 +116,23 @@
   "metadata": {
   "id": "Ch9puvwKH1s3",
   "collapsed": true,
-  "cellView": "form"
+  "cellView": "form",
+  "outputId": "9a9d4274-a633-464b-e1fb-06a33f3dd873",
+  "colab": {
+  "base_uri": "https://localhost:8080/"
+  }
   },
-  "execution_count": null,
-  "outputs": []
+  "execution_count": 59,
+  "outputs": [
+  {
+  "output_type": "stream",
+  "name": "stdout",
+  "text": [
+  "fatal: destination path 'sd_tokens' already exists and is not an empty directory.\n",
+  "/content/sd_tokens\n"
+  ]
+  }
+  ]
   },
   {
   "cell_type": "code",
@@ -128,7 +141,8 @@
   "from transformers import AutoTokenizer\n",
   "tokenizer = AutoTokenizer.from_pretrained(\"openai/clip-vit-large-patch14\", clean_up_tokenization_spaces = False)\n",
   "\n",
-  "prompt= \"banana\" # @param {type:'string'}\n",
+  "# @markdown Write name of token to match against\n",
+  "prompt= \"banana\" # @param {type:'string',\"placeholder\":\"leave empty for random value token\"}\n",
   "\n",
   "tokenizer_output = tokenizer(text = prompt)\n",
   "input_ids = tokenizer_output['input_ids']\n",
@@ -152,11 +166,15 @@
   " A = R*(_A/_R)\n",
   " name_A = 'random_A'\n",
   "\n",
-  "\n",
-  "mix_with = \"\" # @param {\"type\":\"string\",\"placeholder\":\"(optional) write something else\"}\n",
+  "# @markdown (optional) Mix the token with something else\n",
+  "mix_with = \"\" # @param {\"type\":\"string\",\"placeholder\":\"leave empty for random value token\"}\n",
   "mix_method = \"None\" # @param [\"None\" , \"Average\", \"Subtract\"] {allow-input: true}\n",
   "w = 0.5 # @param {type:\"slider\", min:0, max:1, step:0.01}\n",
   "\n",
+  "# @markdown Limit char size of included token\n",
+  "min_char_size = 3 # @param {type:\"slider\", min:0, max: 50, step:1}\n",
+  "char_range = 5 # @param {type:\"slider\", min:0, max: 50, step:1}\n",
+  "\n",
   "tokenizer_output = tokenizer(text = mix_with)\n",
   "input_ids = tokenizer_output['input_ids']\n",
   "id_C = input_ids[1]\n",
@@ -205,7 +223,7 @@
   " A = (_A/_tmp)*tmp\n",
   " #//---//\n",
   " _A = LA.vector_norm(A, ord=2)\n",
-  " print(f\"Tokenized prompt tensor A '{name_A}' token has been recalculated as A = _A * norm(w*A - (1-w)*C) , where C is '{name_C}' token , for w = {w} \")\n",
+  " print(f\"Tokenized prompt tensor A '{name_A}' token has been recalculated as A = _A*norm(w*A - (1-w)*C) , where C is '{name_C}' token , for w = {w} \")\n",
   "\n",
   "#OPTIONAL : Add/subtract + normalize above result with another token. Leave field empty to get a random value tensor\n",
   "\n",
@@ -231,6 +249,7 @@
   "\n",
   "#Produce a list id IDs that are most similiar to the prompt ID at positiion 1 based on above result\n",
   "\n",
+  "# @markdown Set print options\n",
   "list_size = 100 # @param {type:'number'}\n",
   "print_ID = False # @param {type:\"boolean\"}\n",
   "print_Similarity = True # @param {type:\"boolean\"}\n",
@@ -259,8 +278,7 @@
   "#Print the sorted list from above result"
   ],
   "metadata": {
-  "id": "iWeFnT1gAx6A",
-  "cellView": "form"
+  "id": "iWeFnT1gAx6A"
   },
   "execution_count": null,
   "outputs": []
@@ -270,7 +288,7 @@
   "source": [
   "# @title 💫 Compare Text encodings\n",
   "\n",
-  "prompt_A = \"\" # @param {\"type\":\"string\",\"placeholder\":\"Write a prompt\"}\n",
+  "prompt_A = \"banana\" # @param {\"type\":\"string\",\"placeholder\":\"Write a prompt\"}\n",
   "prompt_B = \"\" # @param {\"type\":\"string\",\"placeholder\":\"Write a prompt\"}\n",
   "use_token_padding = True # @param {type:\"boolean\"}\n",
   "\n",
@@ -283,6 +301,7 @@
   "ids_A = processor.tokenizer(text=prompt_A, padding=use_token_padding, return_tensors=\"pt\")\n",
   "text_encoding_A = model.get_text_features(**ids_A)\n",
   "\n",
+  "\n",
   "ids_B = processor.tokenizer(text=prompt_B, padding=use_token_padding, return_tensors=\"pt\")\n",
   "text_encoding_B = model.get_text_features(**ids_B)\n",
   "\n",
@@ -296,8 +315,156 @@
   ],
   "metadata": {
   "id": "QQOjh5BvnG8M",
-  "collapsed": true,
-  "cellView": "form"
+  "collapsed": true
+  },
+  "execution_count": null,
+  "outputs": []
+  },
+  {
+  "cell_type": "code",
+  "source": [
+  "# @title 🪐 Find similiar prompt\n",
+  "# @markdown Prompt A to match against\n",
+  "prompt_A = \"photo of a banana\" # @param {\"type\":\"string\",\"placeholder\":\"Write a prompt\"}\n",
+  "# @markdown Set conditions for the output\n",
+  "must_start_with = \"bendy \" # @param {\"type\":\"string\",\"placeholder\":\"write a text\"}\n",
+  "must_contain = \"yellow\" # @param {\"type\":\"string\",\"placeholder\":\"write a text\"}\n",
+  "must_end_with = \" on a table\" # @param {\"type\":\"string\",\"placeholder\":\"write a text\"}\n",
+  "\n",
+  "token_B = must_contain\n",
+  "\n",
+  "# @markdown Limit the search\n",
+  "use_token_padding = True # @param {type:\"boolean\"}\n",
+  "start_search_at_ID = 12500 # @param {type:\"slider\", min:0, max: 49407, step:100}\n",
+  "search_range = 500 # @param {type:\"slider\", min:0, max: 2000, step:100}\n",
+  "restrictions = 'Suffix only' # @param [\"None\", \"Suffix only\", \"Prefix only\"]\n",
+  "\n",
+  "# @markdown Limit char size of included token\n",
+  "min_char_size = 3 # @param {type:\"slider\", min:0, max: 50, step:1}\n",
+  "char_range = 5 # @param {type:\"slider\", min:0, max: 50, step:1}\n",
+  "\n",
+  "\n",
+  "#Tokenize input B\n",
+  "from transformers import AutoTokenizer\n",
+  "tokenizer = AutoTokenizer.from_pretrained(\"openai/clip-vit-large-patch14\", clean_up_tokenization_spaces = False)\n",
+  "tokenizer_output = tokenizer(text = token_B)\n",
+  "input_ids = tokenizer_output['input_ids']\n",
+  "#-----#\n",
+  "name_B = must_contain\n",
+  "#-----#\n",
+  "\n",
+  "from transformers import CLIPProcessor, CLIPModel\n",
+  "processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-large-patch14\" , clean_up_tokenization_spaces = True)\n",
+  "model = CLIPModel.from_pretrained(\"openai/clip-vit-large-patch14\")\n",
+  "#-------#\n",
+  "ids_A = processor.tokenizer(text=prompt_A, padding=use_token_padding, return_tensors=\"pt\")\n",
+  "text_encoding_A = model.get_text_features(**ids_A)\n",
+  "A = text_encoding_A[0]\n",
+  "_A = LA.vector_norm(A, ord=2)\n",
+  "name_A = prompt_A\n",
+  "print(f'a text_encoding was created for the prompt \"{prompt_A}\" ')\n",
+  "print('')\n",
+  "#----#\n",
+  "\n",
+  "START = start_search_at_ID\n",
+  "RANGE = min(search_range , 49407 - start_search_at_ID)\n",
+  "\n",
+  "dots = torch.zeros(RANGE)\n",
+  "is_BC = torch.zeros(RANGE)\n",
+  "for index in range(RANGE):\n",
+  " id_C = START + index\n",
+  " C = token[id_C]\n",
+  " _C = LA.vector_norm(C, ord=2)\n",
+  " name_C = vocab[id_C]\n",
+  "\n",
+  " # Decide if we should process prefix/suffix tokens\n",
+  " if name_C.find('</w>')<=-1:\n",
+  " if restrictions != \"Prefix only\":\n",
+  " continue\n",
+  " else:\n",
+  " if restrictions == \"Prefix only\":\n",
+  " continue\n",
+  " #-----#\n",
+  "\n",
+  " # Decide if char-size is within range\n",
+  " if len(name_C) < min_char_size:\n",
+  " continue\n",
+  " if len(name_C) > min_char_size + char_range:\n",
+  " continue\n",
+  " #-----#\n",
+  "\n",
+  " name_CB = must_start_with + name_C + name_B + must_end_with\n",
+  " if restrictions == \"Prefix only\":\n",
+  " name_CB = must_start_with + name_C + '-' + name_B + must_end_with\n",
+  " #-----#\n",
+  " ids_CB = processor.tokenizer(text=name_CB, padding=use_token_padding, return_tensors=\"pt\")\n",
+  " text_encoding_CB = model.get_text_features(**ids_CB)\n",
+  " CB = text_encoding_CB[0]\n",
+  " _CB = LA.vector_norm(CB, ord=2)\n",
+  " sim_CB = torch.dot(A,CB)/(_A*_CB)\n",
+  " #-----#\n",
+  " if restrictions == \"Prefix only\":\n",
+  " result = sim_CB\n",
+  " result = result.item()\n",
+  " dots[index] = result\n",
+  " continue\n",
+  " #-----#\n",
+  " name_BC = must_start_with + name_B + name_C + must_end_with\n",
+  " ids_BC = processor.tokenizer(text=name_BC, padding=use_token_padding, return_tensors=\"pt\")\n",
+  " text_encoding_BC = model.get_text_features(**ids_BC)\n",
+  " BC = text_encoding_BC[0]\n",
+  " _BC = LA.vector_norm(BC, ord=2)\n",
+  " sim_BC = torch.dot(A,BC)/(_A*_BC)\n",
+  " #-----#\n",
+  "\n",
+  " result = sim_CB\n",
+  " if(sim_BC > sim_CB):\n",
+  " is_BC[index] = 1\n",
+  " result = sim_BC\n",
+  "\n",
+  " #result = absolute_value(result.item())\n",
+  " result = result.item()\n",
+  " dots[index] = result\n",
+  "#----#\n",
+  "\n",
+  "\n",
+  "\n",
+  "sorted, indices = torch.sort(dots,dim=0 , descending=True)\n",
+  "\n",
+  "# @markdown Print options\n",
+  "list_size = 100 # @param {type:'number'}\n",
+  "print_ID = False # @param {type:\"boolean\"}\n",
+  "print_Similarity = True # @param {type:\"boolean\"}\n",
+  "print_Name = True # @param {type:\"boolean\"}\n",
+  "print_Divider = True # @param {type:\"boolean\"}\n",
+  "\n",
+  "\n",
+  "if (print_Divider):\n",
+  " print('//---//')\n",
+  "\n",
+  "print('')\n",
+  "print(f'These token pairings within the range ID = {START} to ID = {START + RANGE} most closely match the text_encoding for the prompt \"{prompt_A}\" : ')\n",
+  "print('')\n",
+  "\n",
+  "for index in range(min(list_size,RANGE)):\n",
+  " id = START + indices[index].item()\n",
+  " if (print_Name):\n",
+  " if(is_BC[index]>0):\n",
+  " print(must_start_with + name_B + vocab[id] + must_end_with)\n",
+  " else:\n",
+  " if restrictions == \"Prefix only\":\n",
+  " print(must_start_with + vocab[id] + '-' + name_B + must_end_with)\n",
+  " else:\n",
+  " print(must_start_with + vocab[id] + name_B + must_end_with)\n",
+  " if (print_ID):\n",
+  " print(f'ID = {id}') # IDs\n",
+  " if (print_Similarity):\n",
+  " print(f'similiarity = {round(sorted[index].item()*100,2)} %')\n",
+  " if (print_Divider):\n",
+  " print('--------')"
+  ],
+  "metadata": {
+  "id": "uDtcm-l8UCJk"
   },
   "execution_count": null,
   "outputs": []
 