codeShare commited on
Commit
6eeabcf
1 Parent(s): cd1aec4

Upload sd_token_similarity_calculator.ipynb

Browse files
Files changed (1) hide show
  1. sd_token_similarity_calculator.ipynb +223 -172
sd_token_similarity_calculator.ipynb CHANGED
@@ -118,8 +118,7 @@
118
  ],
119
  "metadata": {
120
  "id": "Ch9puvwKH1s3",
121
- "collapsed": true,
122
- "cellView": "form"
123
  },
124
  "execution_count": null,
125
  "outputs": []
@@ -133,7 +132,7 @@
133
  "tokenizer = AutoTokenizer.from_pretrained(\"openai/clip-vit-large-patch14\", clean_up_tokenization_spaces = False)\n",
134
  "\n",
135
  "# @markdown Write name of token to match against\n",
136
- "token_name = \"banana\" # @param {type:'string',\"placeholder\":\"leave empty for random value token\"}\n",
137
  "\n",
138
  "prompt = token_name\n",
139
  "# @markdown (optional) Mix the token with something else\n",
@@ -298,8 +297,10 @@
298
  "source": [
299
  "# @title ⚡+🖼️ -> 📝 Token-Sampling Image interrogator\n",
300
  "#-----#\n",
 
301
  "import shelve\n",
302
  "db_vocab = shelve.open(VOCAB_FILENAME)\n",
 
303
  "# @markdown # What do you want to to mimic?\n",
304
  "use = '🖼️image_encoding from image' # @param ['📝text_encoding from prompt', '🖼️image_encoding from image']\n",
305
  "# @markdown --------------------------\n",
@@ -317,7 +318,7 @@
317
  " return list(uploaded.keys())\n",
318
  "#Get image\n",
319
  "# You can use \"http://images.cocodataset.org/val2017/000000039769.jpg\" for testing\n",
320
- "image_url = \"\" # @param {\"type\":\"string\",\"placeholder\":\"leave empty for local upload (scroll down to see it)\"}\n",
321
  "colab_image_path = \"\" # @param {\"type\":\"string\",\"placeholder\": \"eval. as '/content/sd_tokens/' + **your input**\"}\n",
322
  "# @markdown --------------------------\n",
323
  "from PIL import Image\n",
@@ -360,13 +361,12 @@
360
  "#-----#\n",
361
  "# @markdown # The output...\n",
362
  "must_start_with = \"\" # @param {\"type\":\"string\",\"placeholder\":\"write a text\"}\n",
363
- "must_contain = \"banana \" # @param {\"type\":\"string\",\"placeholder\":\"write a text\"}\n",
364
  "must_end_with = \"\" # @param {\"type\":\"string\",\"placeholder\":\"write a text\"}\n",
365
- "token_B = must_contain\n",
366
  "# @markdown -----\n",
367
  "# @markdown # Use a range of tokens from the vocab.json (slow method)\n",
368
  "start_search_at_index = 1700 # @param {type:\"slider\", min:0, max: 49407, step:100}\n",
369
- "# @markdown The lower the start_index, the more similiar the sampled tokens will be to the target token assigned in the '⚡ Get similiar tokens' cell\"\n",
370
  "start_search_at_ID = start_search_at_index\n",
371
  "search_range = 100 # @param {type:\"slider\", min:100, max: 2000, step:0}\n",
372
  "restrictions = 'None' # @param [\"None\", \"Suffix only\", \"Prefix only\"]\n",
@@ -378,186 +378,238 @@
378
  "_enable = False # param {\"type\":\"boolean\"}\n",
379
  "prompt_items = \"\" # param {\"type\":\"string\",\"placeholder\":\"{item1|item2|...}\"}\n",
380
  "#-----#\n",
381
- "name_B = must_contain\n",
382
  "#-----#\n",
383
  "START = start_search_at_ID\n",
384
- "RANGE = min(search_range , 49407 - start_search_at_ID)\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
385
  "#-----#\n",
386
- "dots = torch.zeros(RANGE)\n",
387
- "is_BC = torch.zeros(RANGE)\n",
388
- "import re\n",
389
  "#-----#\n",
390
- "for index in range(RANGE):\n",
391
- " id_C = START + index\n",
392
- " name_C = db_vocab[f'{id_C}']\n",
393
- " is_Prefix = 0\n",
394
- " #Skip if non-AZ characters are found\n",
395
- " if re.search(\"\\W/g\" , name_C.replace('</w>', '')):\n",
396
- " continue\n",
397
  " #-----#\n",
398
- " # Decide if we should process prefix/suffix tokens\n",
399
- " if name_C.find('</w>')<=-1:\n",
400
- " is_Prefix = 1\n",
401
- " if restrictions != \"Prefix only\":\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
402
  " continue\n",
403
- " else:\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
404
  " if restrictions == \"Prefix only\":\n",
 
 
 
405
  " continue\n",
406
- " #-----#\n",
407
- " # Decide if char-size is within range\n",
408
- " if len(name_C) < min_char_size:\n",
409
- " continue\n",
410
- " if len(name_C) > min_char_size + char_range:\n",
411
- " continue\n",
412
- " #-----#\n",
413
- " name_CB = must_start_with + name_C + name_B + must_end_with\n",
414
- " if is_Prefix>0:\n",
415
- " name_CB = must_start_with + ' ' + name_C.strip() + '-' + name_B.strip() + ' ' + must_end_with\n",
416
- " #-----#\n",
417
- " if(use == '🖼️image_encoding from image'):\n",
418
- " ids_CB = processor.tokenizer(text=name_CB, padding=use_token_padding, return_tensors=\"pt\")\n",
419
- " text_features = model.get_text_features(**ids_CB)\n",
420
- " text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
421
- " logit_scale = model.logit_scale.exp()\n",
422
- " torch.matmul(text_features, image_features.t()) * logit_scale\n",
423
- " sim_CB = torch.nn.functional.cosine_similarity(text_features, image_features) * logit_scale\n",
424
- " #-----#\n",
425
- " if(use == '📝text_encoding from prompt'):\n",
426
- " ids_CB = processor.tokenizer(text=name_CB, padding=use_token_padding, return_tensors=\"pt\")\n",
427
- " text_features = model.get_text_features(**ids_CB)\n",
428
- " text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
429
- " sim_CB = torch.nn.functional.cosine_similarity(text_features, text_features_A)\n",
430
- " #-----#\n",
431
- " #-----#\n",
432
- " if restrictions == \"Prefix only\":\n",
433
  " result = sim_CB\n",
 
 
 
 
 
434
  " result = result.item()\n",
435
  " dots[index] = result\n",
436
- " continue\n",
437
- " #-----#\n",
438
- " if(use == '🖼️image_encoding from image'):\n",
439
- " name_BC = must_start_with + name_B + name_C + must_end_with\n",
440
- " ids_BC = processor.tokenizer(text=name_BC, padding=use_token_padding, return_tensors=\"pt\")\n",
441
- " text_features = model.get_text_features(**ids_BC)\n",
442
- " text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
443
- " logit_scale = model.logit_scale.exp()\n",
444
- " torch.matmul(text_features, image_features.t()) * logit_scale\n",
445
- " sim_BC = torch.nn.functional.cosine_similarity(text_features, image_features) * logit_scale\n",
446
- " #-----#\n",
447
- " if(use == '📝text_encoding from prompt'):\n",
448
- " name_BC = must_start_with + name_B + name_C + must_end_with\n",
449
- " ids_BC = processor.tokenizer(text=name_BC, padding=use_token_padding, return_tensors=\"pt\")\n",
450
- " text_features = model.get_text_features(**ids_BC)\n",
451
- " text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
452
- " sim_BC = torch.nn.functional.cosine_similarity(text_features, text_features_A)\n",
453
- " #-----#\n",
454
- " result = sim_CB\n",
455
- " if(sim_BC > sim_CB):\n",
456
- " is_BC[index] = 1\n",
457
- " result = sim_BC\n",
458
- " #-----#\n",
459
- " #result = absolute_value(result.item())\n",
460
- " result = result.item()\n",
461
- " dots[index] = result\n",
462
- "#----#\n",
463
- "sorted, indices = torch.sort(dots,dim=0 , descending=True)\n",
464
- "# @markdown ----------\n",
465
- "# @markdown # Print options\n",
466
- "list_size = 100 # @param {type:'number'}\n",
467
- "print_ID = False # @param {type:\"boolean\"}\n",
468
- "print_Similarity = True # @param {type:\"boolean\"}\n",
469
- "print_Name = True # @param {type:\"boolean\"}\n",
470
- "print_Divider = True # @param {type:\"boolean\"}\n",
471
- "#----#\n",
472
- "if (print_Divider):\n",
473
- " print('//---//')\n",
474
- "#----#\n",
475
- "print('')\n",
476
- "print(f'These token pairings within the range ID = {START} to ID = {START + RANGE} most closely match the text_encoding for {prompt_A} : ')\n",
477
- "print('')\n",
478
- "#----#\n",
479
- "aheads = \"{\"\n",
480
- "trails = \"{\"\n",
481
- "tmp = \"\"\n",
482
- "#----#\n",
483
- "max_sim_ahead = 0\n",
484
- "max_sim_trail = 0\n",
485
- "sim = 0\n",
486
- "max_name_ahead = ''\n",
487
- "max_name_trail = ''\n",
488
- "#----#\n",
489
- "for index in range(min(list_size,RANGE)):\n",
490
- " id = START + indices[index].item()\n",
491
- " name = db_vocab[f'{id}']\n",
492
- " #-----#\n",
493
- " if (name.find('</w>')<=-1):\n",
494
- " name = name + '-'\n",
495
- " else:\n",
496
- " name = name.replace('</w>', ' ')\n",
497
- " if(is_BC[index]>0):\n",
498
- " trails = trails + name + \"|\"\n",
499
- " else:\n",
500
- " aheads = aheads + name + \"|\"\n",
501
  " #----#\n",
502
- " sim = sorted[index].item()\n",
 
 
 
 
 
 
 
503
  " #----#\n",
504
- " if(is_BC[index]>0):\n",
505
- " if sim>max_sim_ahead:\n",
506
- " max_sim_ahead = sim\n",
507
- " max_name_ahead = name\n",
508
- " else:\n",
509
- " if sim>max_sim_trail:\n",
510
- " max_sim_trail = sim\n",
511
- " max_name_trail = name\n",
512
- "#------#\n",
513
- "trails = (trails + \"&&&&\").replace(\"|&&&&\", \"}\").replace(\"</w>\", \" \").replace(\"{&&&&\", \"\")\n",
514
- "aheads = (aheads + \"&&&&\").replace(\"|&&&&\", \"}\").replace(\"</w>\", \" \").replace(\"{&&&&\", \"\")\n",
515
- "max_sim_ahead=max_sim_ahead\n",
516
- "max_sim_ahead=max_sim_trail\n",
517
- "#-----#\n",
518
- "print(f\"place these items ahead of prompt : {aheads}\")\n",
519
- "print(\"\")\n",
520
- "print(f\"place these items behind the prompt : {trails}\")\n",
521
- "print(\"\")\n",
522
- "print(f\"max_similarity = {max_sim_ahead} % when using '{max_name_ahead + must_contain}' \")\n",
523
- "print(\"\")\n",
524
- "print(f\"max_similarity = {max_sim_trail} % when using '{must_contain + max_name_trail}' \")\n",
525
- "#-----#\n",
526
- "#STEP 2\n",
527
- "import random\n",
528
- "names = {}\n",
529
- "NUM_PERMUTATIONS = 4\n",
530
- "#-----#\n",
531
- "dots = torch.zeros(NUM_PERMUTATIONS)\n",
532
- "for index in range(NUM_PERMUTATIONS):\n",
533
- " name = must_start_with\n",
534
- " if index == 0 : name = name + must_contain\n",
535
- " if index == 1 : name = name + max_name_ahead + must_contain\n",
536
- " if index == 2 : name = name + must_contain + max_name_trail\n",
537
- " if index == 3 : name = name + max_name_ahead + must_contain + max_name_trail\n",
538
- " name = name + must_end_with\n",
539
- " #----#\n",
540
- " ids = processor.tokenizer(text=name, padding=use_token_padding, return_tensors=\"pt\")\n",
541
  " #----#\n",
 
 
 
542
  " if(use == '🖼️image_encoding from image'):\n",
543
- " text_features = model.get_text_features(**ids)\n",
544
- " text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
545
- " logit_scale = model.logit_scale.exp()\n",
546
- " torch.matmul(text_features, image_features.t()) * logit_scale\n",
547
- " sim = torch.nn.functional.cosine_similarity(text_features, image_features) * logit_scale\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
548
  " #-----#\n",
549
- " if(use == '📝text_encoding from prompt'):\n",
550
- " text_features = model.get_text_features(**ids)\n",
551
- " text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
552
- " sim = torch.nn.functional.cosine_similarity(text_features, text_features_A)\n",
 
 
 
 
 
 
 
 
553
  " #-----#\n",
554
- " dots[index] = sim\n",
555
- " names[index] = name\n",
556
- "#------#\n",
557
- "sorted, indices = torch.sort(dots,dim=0 , descending=True)\n",
558
- "#------#\n",
559
- "for index in range(NUM_PERMUTATIONS):\n",
560
- " print(names[indices[index].item()])\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
561
  " print(f'similiarity = {round(sorted[index].item(),2)} %')\n",
562
  " print('------')\n",
563
  "#------#\n",
@@ -565,8 +617,7 @@
565
  ],
566
  "metadata": {
567
  "collapsed": true,
568
- "id": "fi0jRruI0-tu",
569
- "cellView": "form"
570
  },
571
  "execution_count": null,
572
  "outputs": []
 
118
  ],
119
  "metadata": {
120
  "id": "Ch9puvwKH1s3",
121
+ "collapsed": true
 
122
  },
123
  "execution_count": null,
124
  "outputs": []
 
132
  "tokenizer = AutoTokenizer.from_pretrained(\"openai/clip-vit-large-patch14\", clean_up_tokenization_spaces = False)\n",
133
  "\n",
134
  "# @markdown Write name of token to match against\n",
135
+ "token_name = \"dogs\" # @param {type:'string',\"placeholder\":\"leave empty for random value token\"}\n",
136
  "\n",
137
  "prompt = token_name\n",
138
  "# @markdown (optional) Mix the token with something else\n",
 
297
  "source": [
298
  "# @title ⚡+🖼️ -> 📝 Token-Sampling Image interrogator\n",
299
  "#-----#\n",
300
+ "NUM_TOKENS = 49407\n",
301
  "import shelve\n",
302
  "db_vocab = shelve.open(VOCAB_FILENAME)\n",
303
+ "print(f'using the tokens found in {VOCAB_FILENAME}.db as the vocab')\n",
304
  "# @markdown # What do you want to to mimic?\n",
305
  "use = '🖼️image_encoding from image' # @param ['📝text_encoding from prompt', '🖼️image_encoding from image']\n",
306
  "# @markdown --------------------------\n",
 
318
  " return list(uploaded.keys())\n",
319
  "#Get image\n",
320
  "# You can use \"http://images.cocodataset.org/val2017/000000039769.jpg\" for testing\n",
321
+ "image_url = \"http://images.cocodataset.org/val2017/000000039769.jpg\" # @param {\"type\":\"string\",\"placeholder\":\"leave empty for local upload (scroll down to see it)\"}\n",
322
  "colab_image_path = \"\" # @param {\"type\":\"string\",\"placeholder\": \"eval. as '/content/sd_tokens/' + **your input**\"}\n",
323
  "# @markdown --------------------------\n",
324
  "from PIL import Image\n",
 
361
  "#-----#\n",
362
  "# @markdown # The output...\n",
363
  "must_start_with = \"\" # @param {\"type\":\"string\",\"placeholder\":\"write a text\"}\n",
364
+ "must_contain = \" pet \" # @param {\"type\":\"string\",\"placeholder\":\"write a text\"}\n",
365
  "must_end_with = \"\" # @param {\"type\":\"string\",\"placeholder\":\"write a text\"}\n",
 
366
  "# @markdown -----\n",
367
  "# @markdown # Use a range of tokens from the vocab.json (slow method)\n",
368
  "start_search_at_index = 1700 # @param {type:\"slider\", min:0, max: 49407, step:100}\n",
369
+ "# @markdown The lower the start_index, the more similiar the sampled tokens will be to the target token assigned in the '⚡ Get similiar tokens' cell\". If the cell was not run, then it will use tokens ordered by similarity to the \"girl\\</w>\" token\n",
370
  "start_search_at_ID = start_search_at_index\n",
371
  "search_range = 100 # @param {type:\"slider\", min:100, max: 2000, step:0}\n",
372
  "restrictions = 'None' # @param [\"None\", \"Suffix only\", \"Prefix only\"]\n",
 
378
  "_enable = False # param {\"type\":\"boolean\"}\n",
379
  "prompt_items = \"\" # param {\"type\":\"string\",\"placeholder\":\"{item1|item2|...}\"}\n",
380
  "#-----#\n",
 
381
  "#-----#\n",
382
  "START = start_search_at_ID\n",
383
+ "RANGE = min(search_range , max(1,NUM_TOKENS - start_search_at_ID))\n",
384
+ "#-----#\n",
385
+ "import math, random\n",
386
+ "CHUNK = math.floor(NUM_TOKENS/(RANGE*100))\n",
387
+ "\n",
388
+ "ITERS = 3\n",
389
+ "#-----#\n",
390
+ "#LOOP START\n",
391
+ "#-----#\n",
392
+ "\n",
393
+ "results_sim = torch.zeros(ITERS+1)\n",
394
+ "results_name = {}\n",
395
+ "\n",
396
+ "# Check if original solution is best\n",
397
+ "best_sim = 0\n",
398
+ "name = must_start_with + must_contain + must_end_with\n",
399
+ "ids = processor.tokenizer(text=name, padding=use_token_padding, return_tensors=\"pt\")\n",
400
+ "text_features = model.get_text_features(**ids)\n",
401
+ "text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
402
+ "#------#\n",
403
+ "if(use == '🖼️image_encoding from image'):\n",
404
+ " logit_scale = model.logit_scale.exp()\n",
405
+ " torch.matmul(text_features, image_features.t()) * logit_scale\n",
406
+ " sim = torch.nn.functional.cosine_similarity(text_features, image_features) * logit_scale\n",
407
  "#-----#\n",
408
+ "if(use == '📝text_encoding from prompt'):\n",
409
+ " sim = torch.nn.functional.cosine_similarity(text_features, text_features_A)\n",
 
410
  "#-----#\n",
411
+ "best_sim = sim\n",
412
+ "name_B = must_contain\n",
413
+ "#-----#\n",
414
+ "for iter in range(ITERS):\n",
415
+ " dots = torch.zeros(RANGE)\n",
416
+ " is_trail = torch.zeros(RANGE)\n",
417
+ " import re\n",
418
  " #-----#\n",
419
+ "\n",
420
+ " _start = START + iter*CHUNK + iter*random.randint(1,CHUNK)\n",
421
+ " results_name[iter] = name_B\n",
422
+ " results_sim[iter] = best_sim\n",
423
+ "\n",
424
+ " for index in range(RANGE):\n",
425
+ " id_C = min(_start + index, NUM_TOKENS)\n",
426
+ " name_C = db_vocab[f'{id_C}']\n",
427
+ " is_Prefix = 0\n",
428
+ " #Skip if non-AZ characters are found\n",
429
+ " #???\n",
430
+ " #-----#\n",
431
+ " # Decide if we should process prefix/suffix tokens\n",
432
+ " if name_C.find('</w>')<=-1:\n",
433
+ " is_Prefix = 1\n",
434
+ " if restrictions != \"Prefix only\":\n",
435
+ " continue\n",
436
+ " else:\n",
437
+ " if restrictions == \"Prefix only\":\n",
438
+ " continue\n",
439
+ " #-----#\n",
440
+ " # Decide if char-size is within range\n",
441
+ " if len(name_C) < min_char_size:\n",
442
  " continue\n",
443
+ " if len(name_C) > min_char_size + char_range:\n",
444
+ " continue\n",
445
+ " #-----#\n",
446
+ " name_CB = must_start_with + name_C + name_B + must_end_with\n",
447
+ " if is_Prefix>0:\n",
448
+ " name_CB = must_start_with + ' ' + name_C + '-' + name_B + ' ' + must_end_with\n",
449
+ " #-----#\n",
450
+ " if(use == '🖼️image_encoding from image'):\n",
451
+ " ids_CB = processor.tokenizer(text=name_CB, padding=use_token_padding, return_tensors=\"pt\")\n",
452
+ " text_features = model.get_text_features(**ids_CB)\n",
453
+ " text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
454
+ " logit_scale = model.logit_scale.exp()\n",
455
+ " torch.matmul(text_features, image_features.t()) * logit_scale\n",
456
+ " sim_CB = torch.nn.functional.cosine_similarity(text_features, image_features) * logit_scale\n",
457
+ " #-----#\n",
458
+ " if(use == '📝text_encoding from prompt'):\n",
459
+ " ids_CB = processor.tokenizer(text=name_CB, padding=use_token_padding, return_tensors=\"pt\")\n",
460
+ " text_features = model.get_text_features(**ids_CB)\n",
461
+ " text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
462
+ " sim_CB = torch.nn.functional.cosine_similarity(text_features, text_features_A)\n",
463
+ " #-----#\n",
464
+ " #-----#\n",
465
  " if restrictions == \"Prefix only\":\n",
466
+ " result = sim_CB\n",
467
+ " result = result.item()\n",
468
+ " dots[index] = result\n",
469
  " continue\n",
470
+ " #-----#\n",
471
+ " if(use == '🖼️image_encoding from image'):\n",
472
+ " name_BC = must_start_with + name_B + name_C + must_end_with\n",
473
+ " ids_BC = processor.tokenizer(text=name_BC, padding=use_token_padding, return_tensors=\"pt\")\n",
474
+ " text_features = model.get_text_features(**ids_BC)\n",
475
+ " text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
476
+ " logit_scale = model.logit_scale.exp()\n",
477
+ " torch.matmul(text_features, image_features.t()) * logit_scale\n",
478
+ " sim_BC = torch.nn.functional.cosine_similarity(text_features, image_features) * logit_scale\n",
479
+ " #-----#\n",
480
+ " if(use == '📝text_encoding from prompt'):\n",
481
+ " name_BC = must_start_with + name_B + name_C + must_end_with\n",
482
+ " ids_BC = processor.tokenizer(text=name_BC, padding=use_token_padding, return_tensors=\"pt\")\n",
483
+ " text_features = model.get_text_features(**ids_BC)\n",
484
+ " text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
485
+ " sim_BC = torch.nn.functional.cosine_similarity(text_features, text_features_A)\n",
486
+ " #-----#\n",
 
 
 
 
 
 
 
 
 
 
487
  " result = sim_CB\n",
488
+ " if(sim_BC > sim_CB):\n",
489
+ " is_trail[index] = 1\n",
490
+ " result = sim_BC\n",
491
+ " #-----#\n",
492
+ " #result = absolute_value(result.item())\n",
493
  " result = result.item()\n",
494
  " dots[index] = result\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
495
  " #----#\n",
496
+ " sorted, indices = torch.sort(dots,dim=0 , descending=True)\n",
497
+ " # @markdown ----------\n",
498
+ " # @markdown # Print options\n",
499
+ " list_size = 100 # @param {type:'number'}\n",
500
+ " print_ID = False # @param {type:\"boolean\"}\n",
501
+ " print_Similarity = True # @param {type:\"boolean\"}\n",
502
+ " print_Name = True # @param {type:\"boolean\"}\n",
503
+ " print_Divider = True # @param {type:\"boolean\"}\n",
504
  " #----#\n",
505
+ " if (print_Divider):\n",
506
+ " print('//---//')\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
507
  " #----#\n",
508
+ " print('')\n",
509
+ "\n",
510
+ " used_reference = f'the text_encoding for {prompt_A}'\n",
511
  " if(use == '🖼️image_encoding from image'):\n",
512
+ " used_reference = 'the image input'\n",
513
+ " print(f'These token pairings within the range ID = {START} to ID = {START + RANGE} most closely match {used_reference}: ')\n",
514
+ " print('')\n",
515
+ " #----#\n",
516
+ " aheads = \"{\"\n",
517
+ " trails = \"{\"\n",
518
+ " tmp = \"\"\n",
519
+ " #----#\n",
520
+ " max_sim_ahead = 0\n",
521
+ " max_sim_trail = 0\n",
522
+ " sim = 0\n",
523
+ " max_name_ahead = ''\n",
524
+ " max_name_trail = ''\n",
525
+ " #----#\n",
526
+ " for index in range(min(list_size,RANGE)):\n",
527
+ " id = START + indices[index].item()\n",
528
+ " name = db_vocab[f'{id}']\n",
529
+ " #-----#\n",
530
+ " if (name.find('</w>')<=-1):\n",
531
+ " name = name + '-'\n",
532
+ " if(is_trail[index]>0):\n",
533
+ " trails = trails + name + \"|\"\n",
534
+ " else:\n",
535
+ " aheads = aheads + name + \"|\"\n",
536
+ " #----#\n",
537
+ " sim = sorted[index].item()\n",
538
+ " #----#\n",
539
+ " if(is_trail[index]>0):\n",
540
+ " if sim>max_sim_trail:\n",
541
+ " max_sim_trail = sim\n",
542
+ " max_name_trail = name\n",
543
+ " max_name_trail = max_name_trail.strip()\n",
544
+ "\n",
545
+ " else:\n",
546
+ " if sim>max_sim_ahead:\n",
547
+ " max_sim_ahead = sim\n",
548
+ " max_name_ahead = name\n",
549
+ " #------#\n",
550
+ " trails = (trails + \"&&&&\").replace(\"|&&&&\", \"}\").replace(\"</w>\", \" \").replace(\"{&&&&\", \"\")\n",
551
+ " aheads = (aheads + \"&&&&\").replace(\"|&&&&\", \"}\").replace(\"</w>\", \" \").replace(\"{&&&&\", \"\")\n",
552
  " #-----#\n",
553
+ " print(f\"place these items ahead of prompt : {aheads}\")\n",
554
+ " print(\"\")\n",
555
+ " print(f\"place these items behind the prompt : {trails}\")\n",
556
+ " print(\"\")\n",
557
+ "\n",
558
+ " tmp = must_start_with + ' ' + max_name_ahead + name_B + ' ' + must_end_with\n",
559
+ " tmp = tmp.strip()\n",
560
+ " print(f\"max_similarity_ahead = {round(max_sim_ahead,2)} % when using '{tmp}' \")\n",
561
+ " print(\"\")\n",
562
+ " tmp = must_start_with + ' ' + name_B + max_name_trail + ' ' + must_end_with\n",
563
+ " tmp = tmp.strip()\n",
564
+ " print(f\"max_similarity_trail = {round(max_sim_trail,2)} % when using '{tmp}' \")\n",
565
  " #-----#\n",
566
+ " #STEP 2\n",
567
+ " import random\n",
568
+ " names = {}\n",
569
+ " NUM_PERMUTATIONS = 4\n",
570
+ " #-----#\n",
571
+ " dots = torch.zeros(NUM_PERMUTATIONS)\n",
572
+ " for index in range(NUM_PERMUTATIONS):\n",
573
+ " name_inner = ''\n",
574
+ " if index == 0 : name_inner = name_B\n",
575
+ " if index == 1 : name_inner = max_name_ahead\n",
576
+ " if index == 2 : name_inner = name_B + max_name_trail\n",
577
+ " if index == 3 : name_inner = max_name_ahead + name_B + max_name_trail\n",
578
+ " name = must_start_with + name_inner + must_end_with\n",
579
+ " #----#\n",
580
+ " ids = processor.tokenizer(text=name, padding=use_token_padding, return_tensors=\"pt\")\n",
581
+ " #----#\n",
582
+ " sim = 0\n",
583
+ " if(use == '🖼️image_encoding from image'):\n",
584
+ " text_features = model.get_text_features(**ids)\n",
585
+ " text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
586
+ " logit_scale = model.logit_scale.exp()\n",
587
+ " torch.matmul(text_features, image_features.t()) * logit_scale\n",
588
+ " sim = torch.nn.functional.cosine_similarity(text_features, image_features) * logit_scale\n",
589
+ " #-----#\n",
590
+ " if(use == '📝text_encoding from prompt'):\n",
591
+ " text_features = model.get_text_features(**ids)\n",
592
+ " text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
593
+ " sim = torch.nn.functional.cosine_similarity(text_features, text_features_A)\n",
594
+ " #-----#\n",
595
+ " dots[index] = sim\n",
596
+ " names[index] = name_inner\n",
597
+ " #------#\n",
598
+ " sorted, indices = torch.sort(dots,dim=0 , descending=True)\n",
599
+ " #------#\n",
600
+ " best_sim = dots[indices[0].item()]\n",
601
+ " name_B = names[indices[0].item()].replace('</w>', ' ') #Update name_B with best value\n",
602
+ "#--------#\n",
603
+ "#store the final value\n",
604
+ "results_name[iter] = name_B\n",
605
+ "results_sim[iter] = best_sim\n",
606
+ "\n",
607
+ "sorted, indices = torch.sort(results_sim,dim=0 , descending=True)\n",
608
+ "\n",
609
+ "print('')\n",
610
+ "for index in range(ITERS+1):\n",
611
+ " name_inner = results_name[indices[index].item()]\n",
612
+ " print(must_start_with + name_inner + must_end_with)\n",
613
  " print(f'similiarity = {round(sorted[index].item(),2)} %')\n",
614
  " print('------')\n",
615
  "#------#\n",
 
617
  ],
618
  "metadata": {
619
  "collapsed": true,
620
+ "id": "fi0jRruI0-tu"
 
621
  },
622
  "execution_count": null,
623
  "outputs": []