Upload sd_token_similarity_calculator.ipynb
sd_token_similarity_calculator.ipynb
CHANGED
@@ -118,10 +118,29 @@
       ],
       "metadata": {
         "id": "Ch9puvwKH1s3",
-        "collapsed": true
+        "collapsed": true,
+        "outputId": "033c251a-2043-40e7-9500-4da870ffa7fd",
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        }
       },
-      "execution_count": null,
-      "outputs": []
+      "execution_count": 1,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "Cloning into 'sd_tokens'...\n",
+            "remote: Enumerating objects: 20, done.\u001b[K\n",
+            "remote: Counting objects: 100% (17/17), done.\u001b[K\n",
+            "remote: Compressing objects: 100% (17/17), done.\u001b[K\n",
+            "remote: Total 20 (delta 4), reused 0 (delta 0), pack-reused 3 (from 1)\u001b[K\n",
+            "Unpacking objects: 100% (20/20), 310.37 KiB | 2.10 MiB/s, done.\n",
+            "Filtering content: 100% (3/3), 160.82 MiB | 26.64 MiB/s, done.\n",
+            "/content/sd_tokens\n"
+          ]
+        }
+      ]
     },
     {
       "cell_type": "code",
@@ -132,7 +151,7 @@
         "tokenizer = AutoTokenizer.from_pretrained(\"openai/clip-vit-large-patch14\", clean_up_tokenization_spaces = False)\n",
         "\n",
         "# @markdown Write name of token to match against\n",
-        "token_name = \"
+        "token_name = \" blanket \" # @param {type:'string',\"placeholder\":\"leave empty for random value token\"}\n",
         "\n",
         "prompt = token_name\n",
         "# @markdown (optional) Mix the token with something else\n",
@@ -368,7 +387,10 @@
         "start_search_at_index = 0 # @param {type:\"slider\", min:0, max: 49407, step:100}\n",
         "# @markdown The lower the start_index, the more similiar the sampled tokens will be to the target token assigned in the '⚡ Get similiar tokens' cell\". If the cell was not run, then it will use tokens ordered by similarity to the \"girl\\</w>\" token\n",
         "start_search_at_ID = start_search_at_index\n",
-        "search_range =
+        "search_range = 1000 # @param {type:\"slider\", min:10, max: 1000, step:10}\n",
+        "\n",
+        "samples_per_iter = 10 # @param {type:\"slider\", min:10, max: 100, step:10}\n",
+        "\n",
         "iterations = 5 # @param {type:\"slider\", min:1, max: 20, step:0}\n",
         "restrictions = 'None' # @param [\"None\", \"Suffix only\", \"Prefix only\"]\n",
         "#markdown Limit char size of included token <----- Disabled\n",
@@ -384,15 +406,11 @@
         "RANGE = min(search_range , max(1,NUM_TOKENS - start_search_at_ID))\n",
         "#-----#\n",
         "import math, random\n",
-        "
-        "\n",
+        "NUM_PERMUTATIONS = 4\n",
         "ITERS = iterations\n",
         "#-----#\n",
         "#LOOP START\n",
         "#-----#\n",
-        "\n",
-        "\n",
-        "\n",
         "# Check if original solution is best\n",
         "best_sim = 0\n",
         "name = must_start_with + must_contain + must_end_with\n",
@@ -400,6 +418,7 @@
         "text_features = model.get_text_features(**ids)\n",
         "text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
         "#------#\n",
+        "sim = 0\n",
         "if(use == '🖼️image_encoding from image'):\n",
         "  logit_scale = model.logit_scale.exp()\n",
         "  torch.matmul(text_features, image_features.t()) * logit_scale\n",
@@ -411,7 +430,8 @@
         "best_sim = sim\n",
         "best_name = name\n",
         "name_B = must_contain\n",
-        "
+        "#------#\n",
+        "results_sim = torch.zeros(ITERS*NUM_PERMUTATIONS)\n",
         "results_name_B = {}\n",
         "results_name = {}\n",
         "#-----#\n",
@@ -420,17 +440,10 @@
         "  is_trail = torch.zeros(RANGE)\n",
         "  import re\n",
         "  #-----#\n",
+        "  _start = START + iter*RANGE\n",
         "\n",
-        "
-        "
-        "  results_sim[iter] = best_sim\n",
-        "  results_name_B[iter] = name_B\n",
-        "  #-----#\n",
-        "  sorted, indices = torch.sort(results_sim,dim=0 , descending=True)\n",
-        "  name_B = results_name_B[indices[0].item()].replace('</w>', ' ') #Update name_B with best value\n",
-        "\n",
-        "  for index in range(RANGE):\n",
-        "    id_C = min(_start + index, NUM_TOKENS)\n",
+        "  for index in range(samples_per_iter):\n",
+        "    id_C = min(_start + index, NUM_TOKENS) + random.randint(0,RANGE)\n",
        "    name_C = db_vocab[f'{id_C}']\n",
         "    is_Prefix = 0\n",
         "    #Skip if non-AZ characters are found\n",
@@ -573,17 +586,15 @@
         "    #-----#\n",
         "    #STEP 2\n",
         "    import random\n",
-        "    names = {}\n",
-        "    name_inners = {}\n",
-        "    NUM_PERMUTATIONS = 4\n",
         "    #-----#\n",
-        "    dots = torch.zeros(NUM_PERMUTATIONS)\n",
         "    for index in range(NUM_PERMUTATIONS):\n",
         "      name_inner = ''\n",
         "      if index == 0 : name_inner = name_B\n",
-        "      if index == 1
-        "      if index == 2
-        "      if index == 3
+        "      if index == 1: name_inner = max_name_ahead\n",
+        "      if index == 2: name_inner = name_B + max_name_trail\n",
+        "      if index == 3: name_inner = max_name_ahead + name_B + max_name_trail\n",
+        "      if name_inner == '': name_inner = max_name_ahead + name_B + max_name_trail\n",
+        "\n",
         "      name = must_start_with + name_inner + must_end_with\n",
         "      #----#\n",
         "      ids = processor.tokenizer(text=name, padding=use_token_padding, return_tensors=\"pt\")\n",
@@ -601,25 +612,17 @@
         "      text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
         "      sim = torch.nn.functional.cosine_similarity(text_features, text_features_A)\n",
         "      #-----#\n",
-        "
-        "
-        "
+        "      results_name[iter*NUM_PERMUTATIONS + index] = name\n",
+        "      results_sim[iter*NUM_PERMUTATIONS + index] = sim\n",
+        "      results_name_B[iter*NUM_PERMUTATIONS + index] = name_inner.replace('</w>',' ')\n",
         "    #------#\n",
-        "
-        "    #------#\n",
-        "    best_sim = dots[indices[0].item()]\n",
-        "    best_name = names[indices[0].item()]\n",
-        "    name_B = name_inners[indices[0].item()].replace('</w>', ' ') #Update name_B with best value\n",
+        "    name_B = results_name_B[iter*NUM_PERMUTATIONS + random.randint(0,3)]\n",
         "#--------#\n",
-        "#store the final value\n",
-        "results_name[iter+1] = best_name\n",
-        "results_sim[iter+1] = best_sim\n",
-        "results_name_B[iter+1] = name_B\n",
         "\n",
+        "print('')\n",
         "sorted, indices = torch.sort(results_sim,dim=0 , descending=True)\n",
         "\n",
-        "
-        "for index in range(ITERS+1):\n",
+        "for index in range(ITERS*NUM_PERMUTATIONS):\n",
         "  name_inner = results_name[indices[index].item()]\n",
         "  print(must_start_with + name_inner + must_end_with)\n",
         "  print(f'similiarity = {round(sorted[index].item(),2)} %')\n",
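
For readers skimming the diff: the change replaces a full scan of every token ID in the window with random sampling (`samples_per_iter` IDs per round), scores four name permutations per round, and seeds the next round with a random permutation rather than the best one. Below is a minimal, self-contained sketch of that loop under stated assumptions: `encode` and `vocab` are hypothetical stand-ins for the notebook's CLIP text encoder (`model.get_text_features`) and its `db_vocab` token database, and the prefix/suffix picks are simplified; this is an illustration of the search strategy, not the notebook's actual code.

```python
import random
import torch

def encode(text: str) -> torch.Tensor:
    # Hypothetical stand-in embedding: a deterministic unit vector per string.
    g = torch.Generator().manual_seed(abs(hash(text)) % (2**31))
    v = torch.randn(768, generator=g)
    return v / v.norm()

NUM_TOKENS = 49407
vocab = {i: f"token{i}</w>" for i in range(NUM_TOKENS + 1)}  # toy vocabulary

ITERS, NUM_PERMUTATIONS = 5, 4
search_range, samples_per_iter = 1000, 10

target = encode("blanket")   # plays the role of text_features_A
name_B = "blanket "          # current inner name being refined
results_sim = torch.zeros(ITERS * NUM_PERMUTATIONS)
results_name = {}

for it in range(ITERS):
    start = it * search_range
    # Sample a handful of random IDs from the current window
    # instead of scanning every ID in it.
    sampled = [min(start + random.randint(0, search_range), NUM_TOKENS)
               for _ in range(samples_per_iter)]
    # Simplified prefix/suffix candidates (the notebook ranks these
    # by similarity; here we just take two sampled tokens).
    ahead = vocab[sampled[0]].replace("</w>", " ")
    trail = vocab[sampled[1]].replace("</w>", " ")
    permutations = [name_B, ahead, name_B + trail, ahead + name_B + trail]
    for index, inner in enumerate(permutations):
        sim = torch.dot(encode(inner), target).item()  # cosine of unit vectors
        results_sim[it * NUM_PERMUTATIONS + index] = sim
        results_name[it * NUM_PERMUTATIONS + index] = inner
    # Seed the next round with a *random* permutation from this round,
    # not necessarily the best one, so the search does not stall.
    name_B = results_name[it * NUM_PERMUTATIONS + random.randint(0, 3)]

# Rank everything tried, best first.
scores, indices = torch.sort(results_sim, descending=True)
for i in range(ITERS * NUM_PERMUTATIONS):
    print(results_name[indices[i].item()],
          f"similarity = {round(scores[i].item(), 2)}")
```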