Upload sd_token_similarity_calculator.ipynb
Browse files- sd_token_similarity_calculator.ipynb +223 -172
sd_token_similarity_calculator.ipynb
CHANGED
@@ -118,8 +118,7 @@
|
|
118 |
],
|
119 |
"metadata": {
|
120 |
"id": "Ch9puvwKH1s3",
|
121 |
-
"collapsed": true
|
122 |
-
"cellView": "form"
|
123 |
},
|
124 |
"execution_count": null,
|
125 |
"outputs": []
|
@@ -133,7 +132,7 @@
|
|
133 |
"tokenizer = AutoTokenizer.from_pretrained(\"openai/clip-vit-large-patch14\", clean_up_tokenization_spaces = False)\n",
|
134 |
"\n",
|
135 |
"# @markdown Write name of token to match against\n",
|
136 |
-
"token_name = \"
|
137 |
"\n",
|
138 |
"prompt = token_name\n",
|
139 |
"# @markdown (optional) Mix the token with something else\n",
|
@@ -298,8 +297,10 @@
|
|
298 |
"source": [
|
299 |
"# @title ⚡+🖼️ -> 📝 Token-Sampling Image interrogator\n",
|
300 |
"#-----#\n",
|
|
|
301 |
"import shelve\n",
|
302 |
"db_vocab = shelve.open(VOCAB_FILENAME)\n",
|
|
|
303 |
"# @markdown # What do you want to to mimic?\n",
|
304 |
"use = '🖼️image_encoding from image' # @param ['📝text_encoding from prompt', '🖼️image_encoding from image']\n",
|
305 |
"# @markdown --------------------------\n",
|
@@ -317,7 +318,7 @@
|
|
317 |
" return list(uploaded.keys())\n",
|
318 |
"#Get image\n",
|
319 |
"# You can use \"http://images.cocodataset.org/val2017/000000039769.jpg\" for testing\n",
|
320 |
-
"image_url = \"\" # @param {\"type\":\"string\",\"placeholder\":\"leave empty for local upload (scroll down to see it)\"}\n",
|
321 |
"colab_image_path = \"\" # @param {\"type\":\"string\",\"placeholder\": \"eval. as '/content/sd_tokens/' + **your input**\"}\n",
|
322 |
"# @markdown --------------------------\n",
|
323 |
"from PIL import Image\n",
|
@@ -360,13 +361,12 @@
|
|
360 |
"#-----#\n",
|
361 |
"# @markdown # The output...\n",
|
362 |
"must_start_with = \"\" # @param {\"type\":\"string\",\"placeholder\":\"write a text\"}\n",
|
363 |
-
"must_contain = \"
|
364 |
"must_end_with = \"\" # @param {\"type\":\"string\",\"placeholder\":\"write a text\"}\n",
|
365 |
-
"token_B = must_contain\n",
|
366 |
"# @markdown -----\n",
|
367 |
"# @markdown # Use a range of tokens from the vocab.json (slow method)\n",
|
368 |
"start_search_at_index = 1700 # @param {type:\"slider\", min:0, max: 49407, step:100}\n",
|
369 |
-
"# @markdown The lower the start_index, the more similiar the sampled tokens will be to the target token assigned in the '⚡ Get similiar tokens' cell\"\n",
|
370 |
"start_search_at_ID = start_search_at_index\n",
|
371 |
"search_range = 100 # @param {type:\"slider\", min:100, max: 2000, step:0}\n",
|
372 |
"restrictions = 'None' # @param [\"None\", \"Suffix only\", \"Prefix only\"]\n",
|
@@ -378,186 +378,238 @@
|
|
378 |
"_enable = False # param {\"type\":\"boolean\"}\n",
|
379 |
"prompt_items = \"\" # param {\"type\":\"string\",\"placeholder\":\"{item1|item2|...}\"}\n",
|
380 |
"#-----#\n",
|
381 |
-
"name_B = must_contain\n",
|
382 |
"#-----#\n",
|
383 |
"START = start_search_at_ID\n",
|
384 |
-
"RANGE = min(search_range ,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
385 |
"#-----#\n",
|
386 |
-
"
|
387 |
-
"
|
388 |
-
"import re\n",
|
389 |
"#-----#\n",
|
390 |
-
"
|
391 |
-
"
|
392 |
-
"
|
393 |
-
"
|
394 |
-
"
|
395 |
-
"
|
396 |
-
"
|
397 |
" #-----#\n",
|
398 |
-
"
|
399 |
-
"
|
400 |
-
"
|
401 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
402 |
" continue\n",
|
403 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
404 |
" if restrictions == \"Prefix only\":\n",
|
|
|
|
|
|
|
405 |
" continue\n",
|
406 |
-
"
|
407 |
-
"
|
408 |
-
"
|
409 |
-
"
|
410 |
-
"
|
411 |
-
"
|
412 |
-
"
|
413 |
-
"
|
414 |
-
"
|
415 |
-
"
|
416 |
-
"
|
417 |
-
"
|
418 |
-
"
|
419 |
-
"
|
420 |
-
"
|
421 |
-
"
|
422 |
-
"
|
423 |
-
" sim_CB = torch.nn.functional.cosine_similarity(text_features, image_features) * logit_scale\n",
|
424 |
-
" #-----#\n",
|
425 |
-
" if(use == '📝text_encoding from prompt'):\n",
|
426 |
-
" ids_CB = processor.tokenizer(text=name_CB, padding=use_token_padding, return_tensors=\"pt\")\n",
|
427 |
-
" text_features = model.get_text_features(**ids_CB)\n",
|
428 |
-
" text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
|
429 |
-
" sim_CB = torch.nn.functional.cosine_similarity(text_features, text_features_A)\n",
|
430 |
-
" #-----#\n",
|
431 |
-
" #-----#\n",
|
432 |
-
" if restrictions == \"Prefix only\":\n",
|
433 |
" result = sim_CB\n",
|
|
|
|
|
|
|
|
|
|
|
434 |
" result = result.item()\n",
|
435 |
" dots[index] = result\n",
|
436 |
-
" continue\n",
|
437 |
-
" #-----#\n",
|
438 |
-
" if(use == '🖼️image_encoding from image'):\n",
|
439 |
-
" name_BC = must_start_with + name_B + name_C + must_end_with\n",
|
440 |
-
" ids_BC = processor.tokenizer(text=name_BC, padding=use_token_padding, return_tensors=\"pt\")\n",
|
441 |
-
" text_features = model.get_text_features(**ids_BC)\n",
|
442 |
-
" text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
|
443 |
-
" logit_scale = model.logit_scale.exp()\n",
|
444 |
-
" torch.matmul(text_features, image_features.t()) * logit_scale\n",
|
445 |
-
" sim_BC = torch.nn.functional.cosine_similarity(text_features, image_features) * logit_scale\n",
|
446 |
-
" #-----#\n",
|
447 |
-
" if(use == '📝text_encoding from prompt'):\n",
|
448 |
-
" name_BC = must_start_with + name_B + name_C + must_end_with\n",
|
449 |
-
" ids_BC = processor.tokenizer(text=name_BC, padding=use_token_padding, return_tensors=\"pt\")\n",
|
450 |
-
" text_features = model.get_text_features(**ids_BC)\n",
|
451 |
-
" text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
|
452 |
-
" sim_BC = torch.nn.functional.cosine_similarity(text_features, text_features_A)\n",
|
453 |
-
" #-----#\n",
|
454 |
-
" result = sim_CB\n",
|
455 |
-
" if(sim_BC > sim_CB):\n",
|
456 |
-
" is_BC[index] = 1\n",
|
457 |
-
" result = sim_BC\n",
|
458 |
-
" #-----#\n",
|
459 |
-
" #result = absolute_value(result.item())\n",
|
460 |
-
" result = result.item()\n",
|
461 |
-
" dots[index] = result\n",
|
462 |
-
"#----#\n",
|
463 |
-
"sorted, indices = torch.sort(dots,dim=0 , descending=True)\n",
|
464 |
-
"# @markdown ----------\n",
|
465 |
-
"# @markdown # Print options\n",
|
466 |
-
"list_size = 100 # @param {type:'number'}\n",
|
467 |
-
"print_ID = False # @param {type:\"boolean\"}\n",
|
468 |
-
"print_Similarity = True # @param {type:\"boolean\"}\n",
|
469 |
-
"print_Name = True # @param {type:\"boolean\"}\n",
|
470 |
-
"print_Divider = True # @param {type:\"boolean\"}\n",
|
471 |
-
"#----#\n",
|
472 |
-
"if (print_Divider):\n",
|
473 |
-
" print('//---//')\n",
|
474 |
-
"#----#\n",
|
475 |
-
"print('')\n",
|
476 |
-
"print(f'These token pairings within the range ID = {START} to ID = {START + RANGE} most closely match the text_encoding for {prompt_A} : ')\n",
|
477 |
-
"print('')\n",
|
478 |
-
"#----#\n",
|
479 |
-
"aheads = \"{\"\n",
|
480 |
-
"trails = \"{\"\n",
|
481 |
-
"tmp = \"\"\n",
|
482 |
-
"#----#\n",
|
483 |
-
"max_sim_ahead = 0\n",
|
484 |
-
"max_sim_trail = 0\n",
|
485 |
-
"sim = 0\n",
|
486 |
-
"max_name_ahead = ''\n",
|
487 |
-
"max_name_trail = ''\n",
|
488 |
-
"#----#\n",
|
489 |
-
"for index in range(min(list_size,RANGE)):\n",
|
490 |
-
" id = START + indices[index].item()\n",
|
491 |
-
" name = db_vocab[f'{id}']\n",
|
492 |
-
" #-----#\n",
|
493 |
-
" if (name.find('</w>')<=-1):\n",
|
494 |
-
" name = name + '-'\n",
|
495 |
-
" else:\n",
|
496 |
-
" name = name.replace('</w>', ' ')\n",
|
497 |
-
" if(is_BC[index]>0):\n",
|
498 |
-
" trails = trails + name + \"|\"\n",
|
499 |
-
" else:\n",
|
500 |
-
" aheads = aheads + name + \"|\"\n",
|
501 |
" #----#\n",
|
502 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
503 |
" #----#\n",
|
504 |
-
" if(
|
505 |
-
"
|
506 |
-
" max_sim_ahead = sim\n",
|
507 |
-
" max_name_ahead = name\n",
|
508 |
-
" else:\n",
|
509 |
-
" if sim>max_sim_trail:\n",
|
510 |
-
" max_sim_trail = sim\n",
|
511 |
-
" max_name_trail = name\n",
|
512 |
-
"#------#\n",
|
513 |
-
"trails = (trails + \"&&&&\").replace(\"|&&&&\", \"}\").replace(\"</w>\", \" \").replace(\"{&&&&\", \"\")\n",
|
514 |
-
"aheads = (aheads + \"&&&&\").replace(\"|&&&&\", \"}\").replace(\"</w>\", \" \").replace(\"{&&&&\", \"\")\n",
|
515 |
-
"max_sim_ahead=max_sim_ahead\n",
|
516 |
-
"max_sim_ahead=max_sim_trail\n",
|
517 |
-
"#-----#\n",
|
518 |
-
"print(f\"place these items ahead of prompt : {aheads}\")\n",
|
519 |
-
"print(\"\")\n",
|
520 |
-
"print(f\"place these items behind the prompt : {trails}\")\n",
|
521 |
-
"print(\"\")\n",
|
522 |
-
"print(f\"max_similarity = {max_sim_ahead} % when using '{max_name_ahead + must_contain}' \")\n",
|
523 |
-
"print(\"\")\n",
|
524 |
-
"print(f\"max_similarity = {max_sim_trail} % when using '{must_contain + max_name_trail}' \")\n",
|
525 |
-
"#-----#\n",
|
526 |
-
"#STEP 2\n",
|
527 |
-
"import random\n",
|
528 |
-
"names = {}\n",
|
529 |
-
"NUM_PERMUTATIONS = 4\n",
|
530 |
-
"#-----#\n",
|
531 |
-
"dots = torch.zeros(NUM_PERMUTATIONS)\n",
|
532 |
-
"for index in range(NUM_PERMUTATIONS):\n",
|
533 |
-
" name = must_start_with\n",
|
534 |
-
" if index == 0 : name = name + must_contain\n",
|
535 |
-
" if index == 1 : name = name + max_name_ahead + must_contain\n",
|
536 |
-
" if index == 2 : name = name + must_contain + max_name_trail\n",
|
537 |
-
" if index == 3 : name = name + max_name_ahead + must_contain + max_name_trail\n",
|
538 |
-
" name = name + must_end_with\n",
|
539 |
-
" #----#\n",
|
540 |
-
" ids = processor.tokenizer(text=name, padding=use_token_padding, return_tensors=\"pt\")\n",
|
541 |
" #----#\n",
|
|
|
|
|
|
|
542 |
" if(use == '🖼️image_encoding from image'):\n",
|
543 |
-
"
|
544 |
-
"
|
545 |
-
"
|
546 |
-
"
|
547 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
548 |
" #-----#\n",
|
549 |
-
"
|
550 |
-
"
|
551 |
-
"
|
552 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
553 |
" #-----#\n",
|
554 |
-
"
|
555 |
-
"
|
556 |
-
"
|
557 |
-
"
|
558 |
-
"
|
559 |
-
"
|
560 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
561 |
" print(f'similiarity = {round(sorted[index].item(),2)} %')\n",
|
562 |
" print('------')\n",
|
563 |
"#------#\n",
|
@@ -565,8 +617,7 @@
|
|
565 |
],
|
566 |
"metadata": {
|
567 |
"collapsed": true,
|
568 |
-
"id": "fi0jRruI0-tu"
|
569 |
-
"cellView": "form"
|
570 |
},
|
571 |
"execution_count": null,
|
572 |
"outputs": []
|
|
|
118 |
],
|
119 |
"metadata": {
|
120 |
"id": "Ch9puvwKH1s3",
|
121 |
+
"collapsed": true
|
|
|
122 |
},
|
123 |
"execution_count": null,
|
124 |
"outputs": []
|
|
|
132 |
"tokenizer = AutoTokenizer.from_pretrained(\"openai/clip-vit-large-patch14\", clean_up_tokenization_spaces = False)\n",
|
133 |
"\n",
|
134 |
"# @markdown Write name of token to match against\n",
|
135 |
+
"token_name = \"dogs\" # @param {type:'string',\"placeholder\":\"leave empty for random value token\"}\n",
|
136 |
"\n",
|
137 |
"prompt = token_name\n",
|
138 |
"# @markdown (optional) Mix the token with something else\n",
|
|
|
297 |
"source": [
|
298 |
"# @title ⚡+🖼️ -> 📝 Token-Sampling Image interrogator\n",
|
299 |
"#-----#\n",
|
300 |
+
"NUM_TOKENS = 49407\n",
|
301 |
"import shelve\n",
|
302 |
"db_vocab = shelve.open(VOCAB_FILENAME)\n",
|
303 |
+
"print(f'using the tokens found in {VOCAB_FILENAME}.db as the vocab')\n",
|
304 |
"# @markdown # What do you want to to mimic?\n",
|
305 |
"use = '🖼️image_encoding from image' # @param ['📝text_encoding from prompt', '🖼️image_encoding from image']\n",
|
306 |
"# @markdown --------------------------\n",
|
|
|
318 |
" return list(uploaded.keys())\n",
|
319 |
"#Get image\n",
|
320 |
"# You can use \"http://images.cocodataset.org/val2017/000000039769.jpg\" for testing\n",
|
321 |
+
"image_url = \"http://images.cocodataset.org/val2017/000000039769.jpg\" # @param {\"type\":\"string\",\"placeholder\":\"leave empty for local upload (scroll down to see it)\"}\n",
|
322 |
"colab_image_path = \"\" # @param {\"type\":\"string\",\"placeholder\": \"eval. as '/content/sd_tokens/' + **your input**\"}\n",
|
323 |
"# @markdown --------------------------\n",
|
324 |
"from PIL import Image\n",
|
|
|
361 |
"#-----#\n",
|
362 |
"# @markdown # The output...\n",
|
363 |
"must_start_with = \"\" # @param {\"type\":\"string\",\"placeholder\":\"write a text\"}\n",
|
364 |
+
"must_contain = \" pet \" # @param {\"type\":\"string\",\"placeholder\":\"write a text\"}\n",
|
365 |
"must_end_with = \"\" # @param {\"type\":\"string\",\"placeholder\":\"write a text\"}\n",
|
|
|
366 |
"# @markdown -----\n",
|
367 |
"# @markdown # Use a range of tokens from the vocab.json (slow method)\n",
|
368 |
"start_search_at_index = 1700 # @param {type:\"slider\", min:0, max: 49407, step:100}\n",
|
369 |
+
"# @markdown The lower the start_index, the more similiar the sampled tokens will be to the target token assigned in the '⚡ Get similiar tokens' cell\". If the cell was not run, then it will use tokens ordered by similarity to the \"girl\\</w>\" token\n",
|
370 |
"start_search_at_ID = start_search_at_index\n",
|
371 |
"search_range = 100 # @param {type:\"slider\", min:100, max: 2000, step:0}\n",
|
372 |
"restrictions = 'None' # @param [\"None\", \"Suffix only\", \"Prefix only\"]\n",
|
|
|
378 |
"_enable = False # param {\"type\":\"boolean\"}\n",
|
379 |
"prompt_items = \"\" # param {\"type\":\"string\",\"placeholder\":\"{item1|item2|...}\"}\n",
|
380 |
"#-----#\n",
|
|
|
381 |
"#-----#\n",
|
382 |
"START = start_search_at_ID\n",
|
383 |
+
"RANGE = min(search_range , max(1,NUM_TOKENS - start_search_at_ID))\n",
|
384 |
+
"#-----#\n",
|
385 |
+
"import math, random\n",
|
386 |
+
"CHUNK = math.floor(NUM_TOKENS/(RANGE*100))\n",
|
387 |
+
"\n",
|
388 |
+
"ITERS = 3\n",
|
389 |
+
"#-----#\n",
|
390 |
+
"#LOOP START\n",
|
391 |
+
"#-----#\n",
|
392 |
+
"\n",
|
393 |
+
"results_sim = torch.zeros(ITERS+1)\n",
|
394 |
+
"results_name = {}\n",
|
395 |
+
"\n",
|
396 |
+
"# Check if original solution is best\n",
|
397 |
+
"best_sim = 0\n",
|
398 |
+
"name = must_start_with + must_contain + must_end_with\n",
|
399 |
+
"ids = processor.tokenizer(text=name, padding=use_token_padding, return_tensors=\"pt\")\n",
|
400 |
+
"text_features = model.get_text_features(**ids)\n",
|
401 |
+
"text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
|
402 |
+
"#------#\n",
|
403 |
+
"if(use == '🖼️image_encoding from image'):\n",
|
404 |
+
" logit_scale = model.logit_scale.exp()\n",
|
405 |
+
" torch.matmul(text_features, image_features.t()) * logit_scale\n",
|
406 |
+
" sim = torch.nn.functional.cosine_similarity(text_features, image_features) * logit_scale\n",
|
407 |
"#-----#\n",
|
408 |
+
"if(use == '📝text_encoding from prompt'):\n",
|
409 |
+
" sim = torch.nn.functional.cosine_similarity(text_features, text_features_A)\n",
|
|
|
410 |
"#-----#\n",
|
411 |
+
"best_sim = sim\n",
|
412 |
+
"name_B = must_contain\n",
|
413 |
+
"#-----#\n",
|
414 |
+
"for iter in range(ITERS):\n",
|
415 |
+
" dots = torch.zeros(RANGE)\n",
|
416 |
+
" is_trail = torch.zeros(RANGE)\n",
|
417 |
+
" import re\n",
|
418 |
" #-----#\n",
|
419 |
+
"\n",
|
420 |
+
" _start = START + iter*CHUNK + iter*random.randint(1,CHUNK)\n",
|
421 |
+
" results_name[iter] = name_B\n",
|
422 |
+
" results_sim[iter] = best_sim\n",
|
423 |
+
"\n",
|
424 |
+
" for index in range(RANGE):\n",
|
425 |
+
" id_C = min(_start + index, NUM_TOKENS)\n",
|
426 |
+
" name_C = db_vocab[f'{id_C}']\n",
|
427 |
+
" is_Prefix = 0\n",
|
428 |
+
" #Skip if non-AZ characters are found\n",
|
429 |
+
" #???\n",
|
430 |
+
" #-----#\n",
|
431 |
+
" # Decide if we should process prefix/suffix tokens\n",
|
432 |
+
" if name_C.find('</w>')<=-1:\n",
|
433 |
+
" is_Prefix = 1\n",
|
434 |
+
" if restrictions != \"Prefix only\":\n",
|
435 |
+
" continue\n",
|
436 |
+
" else:\n",
|
437 |
+
" if restrictions == \"Prefix only\":\n",
|
438 |
+
" continue\n",
|
439 |
+
" #-----#\n",
|
440 |
+
" # Decide if char-size is within range\n",
|
441 |
+
" if len(name_C) < min_char_size:\n",
|
442 |
" continue\n",
|
443 |
+
" if len(name_C) > min_char_size + char_range:\n",
|
444 |
+
" continue\n",
|
445 |
+
" #-----#\n",
|
446 |
+
" name_CB = must_start_with + name_C + name_B + must_end_with\n",
|
447 |
+
" if is_Prefix>0:\n",
|
448 |
+
" name_CB = must_start_with + ' ' + name_C + '-' + name_B + ' ' + must_end_with\n",
|
449 |
+
" #-----#\n",
|
450 |
+
" if(use == '🖼️image_encoding from image'):\n",
|
451 |
+
" ids_CB = processor.tokenizer(text=name_CB, padding=use_token_padding, return_tensors=\"pt\")\n",
|
452 |
+
" text_features = model.get_text_features(**ids_CB)\n",
|
453 |
+
" text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
|
454 |
+
" logit_scale = model.logit_scale.exp()\n",
|
455 |
+
" torch.matmul(text_features, image_features.t()) * logit_scale\n",
|
456 |
+
" sim_CB = torch.nn.functional.cosine_similarity(text_features, image_features) * logit_scale\n",
|
457 |
+
" #-----#\n",
|
458 |
+
" if(use == '📝text_encoding from prompt'):\n",
|
459 |
+
" ids_CB = processor.tokenizer(text=name_CB, padding=use_token_padding, return_tensors=\"pt\")\n",
|
460 |
+
" text_features = model.get_text_features(**ids_CB)\n",
|
461 |
+
" text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
|
462 |
+
" sim_CB = torch.nn.functional.cosine_similarity(text_features, text_features_A)\n",
|
463 |
+
" #-----#\n",
|
464 |
+
" #-----#\n",
|
465 |
" if restrictions == \"Prefix only\":\n",
|
466 |
+
" result = sim_CB\n",
|
467 |
+
" result = result.item()\n",
|
468 |
+
" dots[index] = result\n",
|
469 |
" continue\n",
|
470 |
+
" #-----#\n",
|
471 |
+
" if(use == '🖼️image_encoding from image'):\n",
|
472 |
+
" name_BC = must_start_with + name_B + name_C + must_end_with\n",
|
473 |
+
" ids_BC = processor.tokenizer(text=name_BC, padding=use_token_padding, return_tensors=\"pt\")\n",
|
474 |
+
" text_features = model.get_text_features(**ids_BC)\n",
|
475 |
+
" text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
|
476 |
+
" logit_scale = model.logit_scale.exp()\n",
|
477 |
+
" torch.matmul(text_features, image_features.t()) * logit_scale\n",
|
478 |
+
" sim_BC = torch.nn.functional.cosine_similarity(text_features, image_features) * logit_scale\n",
|
479 |
+
" #-----#\n",
|
480 |
+
" if(use == '📝text_encoding from prompt'):\n",
|
481 |
+
" name_BC = must_start_with + name_B + name_C + must_end_with\n",
|
482 |
+
" ids_BC = processor.tokenizer(text=name_BC, padding=use_token_padding, return_tensors=\"pt\")\n",
|
483 |
+
" text_features = model.get_text_features(**ids_BC)\n",
|
484 |
+
" text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
|
485 |
+
" sim_BC = torch.nn.functional.cosine_similarity(text_features, text_features_A)\n",
|
486 |
+
" #-----#\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
487 |
" result = sim_CB\n",
|
488 |
+
" if(sim_BC > sim_CB):\n",
|
489 |
+
" is_trail[index] = 1\n",
|
490 |
+
" result = sim_BC\n",
|
491 |
+
" #-----#\n",
|
492 |
+
" #result = absolute_value(result.item())\n",
|
493 |
" result = result.item()\n",
|
494 |
" dots[index] = result\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
495 |
" #----#\n",
|
496 |
+
" sorted, indices = torch.sort(dots,dim=0 , descending=True)\n",
|
497 |
+
" # @markdown ----------\n",
|
498 |
+
" # @markdown # Print options\n",
|
499 |
+
" list_size = 100 # @param {type:'number'}\n",
|
500 |
+
" print_ID = False # @param {type:\"boolean\"}\n",
|
501 |
+
" print_Similarity = True # @param {type:\"boolean\"}\n",
|
502 |
+
" print_Name = True # @param {type:\"boolean\"}\n",
|
503 |
+
" print_Divider = True # @param {type:\"boolean\"}\n",
|
504 |
" #----#\n",
|
505 |
+
" if (print_Divider):\n",
|
506 |
+
" print('//---//')\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
507 |
" #----#\n",
|
508 |
+
" print('')\n",
|
509 |
+
"\n",
|
510 |
+
" used_reference = f'the text_encoding for {prompt_A}'\n",
|
511 |
" if(use == '🖼️image_encoding from image'):\n",
|
512 |
+
" used_reference = 'the image input'\n",
|
513 |
+
" print(f'These token pairings within the range ID = {START} to ID = {START + RANGE} most closely match {used_reference}: ')\n",
|
514 |
+
" print('')\n",
|
515 |
+
" #----#\n",
|
516 |
+
" aheads = \"{\"\n",
|
517 |
+
" trails = \"{\"\n",
|
518 |
+
" tmp = \"\"\n",
|
519 |
+
" #----#\n",
|
520 |
+
" max_sim_ahead = 0\n",
|
521 |
+
" max_sim_trail = 0\n",
|
522 |
+
" sim = 0\n",
|
523 |
+
" max_name_ahead = ''\n",
|
524 |
+
" max_name_trail = ''\n",
|
525 |
+
" #----#\n",
|
526 |
+
" for index in range(min(list_size,RANGE)):\n",
|
527 |
+
" id = START + indices[index].item()\n",
|
528 |
+
" name = db_vocab[f'{id}']\n",
|
529 |
+
" #-----#\n",
|
530 |
+
" if (name.find('</w>')<=-1):\n",
|
531 |
+
" name = name + '-'\n",
|
532 |
+
" if(is_trail[index]>0):\n",
|
533 |
+
" trails = trails + name + \"|\"\n",
|
534 |
+
" else:\n",
|
535 |
+
" aheads = aheads + name + \"|\"\n",
|
536 |
+
" #----#\n",
|
537 |
+
" sim = sorted[index].item()\n",
|
538 |
+
" #----#\n",
|
539 |
+
" if(is_trail[index]>0):\n",
|
540 |
+
" if sim>max_sim_trail:\n",
|
541 |
+
" max_sim_trail = sim\n",
|
542 |
+
" max_name_trail = name\n",
|
543 |
+
" max_name_trail = max_name_trail.strip()\n",
|
544 |
+
"\n",
|
545 |
+
" else:\n",
|
546 |
+
" if sim>max_sim_ahead:\n",
|
547 |
+
" max_sim_ahead = sim\n",
|
548 |
+
" max_name_ahead = name\n",
|
549 |
+
" #------#\n",
|
550 |
+
" trails = (trails + \"&&&&\").replace(\"|&&&&\", \"}\").replace(\"</w>\", \" \").replace(\"{&&&&\", \"\")\n",
|
551 |
+
" aheads = (aheads + \"&&&&\").replace(\"|&&&&\", \"}\").replace(\"</w>\", \" \").replace(\"{&&&&\", \"\")\n",
|
552 |
" #-----#\n",
|
553 |
+
" print(f\"place these items ahead of prompt : {aheads}\")\n",
|
554 |
+
" print(\"\")\n",
|
555 |
+
" print(f\"place these items behind the prompt : {trails}\")\n",
|
556 |
+
" print(\"\")\n",
|
557 |
+
"\n",
|
558 |
+
" tmp = must_start_with + ' ' + max_name_ahead + name_B + ' ' + must_end_with\n",
|
559 |
+
" tmp = tmp.strip()\n",
|
560 |
+
" print(f\"max_similarity_ahead = {round(max_sim_ahead,2)} % when using '{tmp}' \")\n",
|
561 |
+
" print(\"\")\n",
|
562 |
+
" tmp = must_start_with + ' ' + name_B + max_name_trail + ' ' + must_end_with\n",
|
563 |
+
" tmp = tmp.strip()\n",
|
564 |
+
" print(f\"max_similarity_trail = {round(max_sim_trail,2)} % when using '{tmp}' \")\n",
|
565 |
" #-----#\n",
|
566 |
+
" #STEP 2\n",
|
567 |
+
" import random\n",
|
568 |
+
" names = {}\n",
|
569 |
+
" NUM_PERMUTATIONS = 4\n",
|
570 |
+
" #-----#\n",
|
571 |
+
" dots = torch.zeros(NUM_PERMUTATIONS)\n",
|
572 |
+
" for index in range(NUM_PERMUTATIONS):\n",
|
573 |
+
" name_inner = ''\n",
|
574 |
+
" if index == 0 : name_inner = name_B\n",
|
575 |
+
" if index == 1 : name_inner = max_name_ahead\n",
|
576 |
+
" if index == 2 : name_inner = name_B + max_name_trail\n",
|
577 |
+
" if index == 3 : name_inner = max_name_ahead + name_B + max_name_trail\n",
|
578 |
+
" name = must_start_with + name_inner + must_end_with\n",
|
579 |
+
" #----#\n",
|
580 |
+
" ids = processor.tokenizer(text=name, padding=use_token_padding, return_tensors=\"pt\")\n",
|
581 |
+
" #----#\n",
|
582 |
+
" sim = 0\n",
|
583 |
+
" if(use == '🖼️image_encoding from image'):\n",
|
584 |
+
" text_features = model.get_text_features(**ids)\n",
|
585 |
+
" text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
|
586 |
+
" logit_scale = model.logit_scale.exp()\n",
|
587 |
+
" torch.matmul(text_features, image_features.t()) * logit_scale\n",
|
588 |
+
" sim = torch.nn.functional.cosine_similarity(text_features, image_features) * logit_scale\n",
|
589 |
+
" #-----#\n",
|
590 |
+
" if(use == '📝text_encoding from prompt'):\n",
|
591 |
+
" text_features = model.get_text_features(**ids)\n",
|
592 |
+
" text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
|
593 |
+
" sim = torch.nn.functional.cosine_similarity(text_features, text_features_A)\n",
|
594 |
+
" #-----#\n",
|
595 |
+
" dots[index] = sim\n",
|
596 |
+
" names[index] = name_inner\n",
|
597 |
+
" #------#\n",
|
598 |
+
" sorted, indices = torch.sort(dots,dim=0 , descending=True)\n",
|
599 |
+
" #------#\n",
|
600 |
+
" best_sim = dots[indices[0].item()]\n",
|
601 |
+
" name_B = names[indices[0].item()].replace('</w>', ' ') #Update name_B with best value\n",
|
602 |
+
"#--------#\n",
|
603 |
+
"#store the final value\n",
|
604 |
+
"results_name[iter] = name_B\n",
|
605 |
+
"results_sim[iter] = best_sim\n",
|
606 |
+
"\n",
|
607 |
+
"sorted, indices = torch.sort(results_sim,dim=0 , descending=True)\n",
|
608 |
+
"\n",
|
609 |
+
"print('')\n",
|
610 |
+
"for index in range(ITERS+1):\n",
|
611 |
+
" name_inner = results_name[indices[index].item()]\n",
|
612 |
+
" print(must_start_with + name_inner + must_end_with)\n",
|
613 |
" print(f'similiarity = {round(sorted[index].item(),2)} %')\n",
|
614 |
" print('------')\n",
|
615 |
"#------#\n",
|
|
|
617 |
],
|
618 |
"metadata": {
|
619 |
"collapsed": true,
|
620 |
+
"id": "fi0jRruI0-tu"
|
|
|
621 |
},
|
622 |
"execution_count": null,
|
623 |
"outputs": []
|