codeShare commited on
Commit
606fac8
1 Parent(s): bfc742a

Upload sd_token_similarity_calculator.ipynb

Browse files
Files changed (1) hide show
  1. sd_token_similarity_calculator.ipynb +45 -42
sd_token_similarity_calculator.ipynb CHANGED
@@ -118,10 +118,29 @@
118
  ],
119
  "metadata": {
120
  "id": "Ch9puvwKH1s3",
121
- "collapsed": true
 
 
 
 
122
  },
123
- "execution_count": null,
124
- "outputs": []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  },
126
  {
127
  "cell_type": "code",
@@ -132,7 +151,7 @@
132
  "tokenizer = AutoTokenizer.from_pretrained(\"openai/clip-vit-large-patch14\", clean_up_tokenization_spaces = False)\n",
133
  "\n",
134
  "# @markdown Write name of token to match against\n",
135
- "token_name = \"banana \" # @param {type:'string',\"placeholder\":\"leave empty for random value token\"}\n",
136
  "\n",
137
  "prompt = token_name\n",
138
  "# @markdown (optional) Mix the token with something else\n",
@@ -368,7 +387,10 @@
368
  "start_search_at_index = 0 # @param {type:\"slider\", min:0, max: 49407, step:100}\n",
369
  "# @markdown The lower the start_index, the more similiar the sampled tokens will be to the target token assigned in the '⚡ Get similiar tokens' cell\". If the cell was not run, then it will use tokens ordered by similarity to the \"girl\\</w>\" token\n",
370
  "start_search_at_ID = start_search_at_index\n",
371
- "search_range = 100 # @param {type:\"slider\", min:10, max: 200, step:0}\n",
 
 
 
372
  "iterations = 5 # @param {type:\"slider\", min:1, max: 20, step:0}\n",
373
  "restrictions = 'None' # @param [\"None\", \"Suffix only\", \"Prefix only\"]\n",
374
  "#markdown Limit char size of included token <----- Disabled\n",
@@ -384,15 +406,11 @@
384
  "RANGE = min(search_range , max(1,NUM_TOKENS - start_search_at_ID))\n",
385
  "#-----#\n",
386
  "import math, random\n",
387
- "CHUNK = math.floor(NUM_TOKENS/RANGE)\n",
388
- "\n",
389
  "ITERS = iterations\n",
390
  "#-----#\n",
391
  "#LOOP START\n",
392
  "#-----#\n",
393
- "\n",
394
- "\n",
395
- "\n",
396
  "# Check if original solution is best\n",
397
  "best_sim = 0\n",
398
  "name = must_start_with + must_contain + must_end_with\n",
@@ -400,6 +418,7 @@
400
  "text_features = model.get_text_features(**ids)\n",
401
  "text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
402
  "#------#\n",
 
403
  "if(use == '🖼️image_encoding from image'):\n",
404
  " logit_scale = model.logit_scale.exp()\n",
405
  " torch.matmul(text_features, image_features.t()) * logit_scale\n",
@@ -411,7 +430,8 @@
411
  "best_sim = sim\n",
412
  "best_name = name\n",
413
  "name_B = must_contain\n",
414
- "results_sim = torch.zeros(ITERS+1)\n",
 
415
  "results_name_B = {}\n",
416
  "results_name = {}\n",
417
  "#-----#\n",
@@ -420,17 +440,10 @@
420
  " is_trail = torch.zeros(RANGE)\n",
421
  " import re\n",
422
  " #-----#\n",
 
423
  "\n",
424
- " _start = START + iter*CHUNK + iter*random.randint(1,CHUNK)\n",
425
- " results_name[iter] = best_name\n",
426
- " results_sim[iter] = best_sim\n",
427
- " results_name_B[iter] = name_B\n",
428
- " #-----#\n",
429
- " sorted, indices = torch.sort(results_sim,dim=0 , descending=True)\n",
430
- " name_B = results_name_B[indices[0].item()].replace('</w>', ' ') #Update name_B with best value\n",
431
- "\n",
432
- " for index in range(RANGE):\n",
433
- " id_C = min(_start + index, NUM_TOKENS)\n",
434
  " name_C = db_vocab[f'{id_C}']\n",
435
  " is_Prefix = 0\n",
436
  " #Skip if non-AZ characters are found\n",
@@ -573,17 +586,15 @@
573
  " #-----#\n",
574
  " #STEP 2\n",
575
  " import random\n",
576
- " names = {}\n",
577
- " name_inners = {}\n",
578
- " NUM_PERMUTATIONS = 4\n",
579
  " #-----#\n",
580
- " dots = torch.zeros(NUM_PERMUTATIONS)\n",
581
  " for index in range(NUM_PERMUTATIONS):\n",
582
  " name_inner = ''\n",
583
  " if index == 0 : name_inner = name_B\n",
584
- " if index == 1 : name_inner = max_name_ahead\n",
585
- " if index == 2 : name_inner = name_B + max_name_trail\n",
586
- " if index == 3 : name_inner = max_name_ahead + name_B + max_name_trail\n",
 
 
587
  " name = must_start_with + name_inner + must_end_with\n",
588
  " #----#\n",
589
  " ids = processor.tokenizer(text=name, padding=use_token_padding, return_tensors=\"pt\")\n",
@@ -601,25 +612,17 @@
601
  " text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
602
  " sim = torch.nn.functional.cosine_similarity(text_features, text_features_A)\n",
603
  " #-----#\n",
604
- " dots[index] = sim\n",
605
- " names[index] = name\n",
606
- " name_inners[index] = name_inner\n",
607
  " #------#\n",
608
- " sorted, indices = torch.sort(dots,dim=0 , descending=True)\n",
609
- " #------#\n",
610
- " best_sim = dots[indices[0].item()]\n",
611
- " best_name = names[indices[0].item()]\n",
612
- " name_B = name_inners[indices[0].item()].replace('</w>', ' ') #Update name_B with best value\n",
613
  "#--------#\n",
614
- "#store the final value\n",
615
- "results_name[iter+1] = best_name\n",
616
- "results_sim[iter+1] = best_sim\n",
617
- "results_name_B[iter+1] = name_B\n",
618
  "\n",
 
619
  "sorted, indices = torch.sort(results_sim,dim=0 , descending=True)\n",
620
  "\n",
621
- "print('')\n",
622
- "for index in range(ITERS+1):\n",
623
  " name_inner = results_name[indices[index].item()]\n",
624
  " print(must_start_with + name_inner + must_end_with)\n",
625
  " print(f'similiarity = {round(sorted[index].item(),2)} %')\n",
 
118
  ],
119
  "metadata": {
120
  "id": "Ch9puvwKH1s3",
121
+ "collapsed": true,
122
+ "outputId": "033c251a-2043-40e7-9500-4da870ffa7fd",
123
+ "colab": {
124
+ "base_uri": "https://localhost:8080/"
125
+ }
126
  },
127
+ "execution_count": 1,
128
+ "outputs": [
129
+ {
130
+ "output_type": "stream",
131
+ "name": "stdout",
132
+ "text": [
133
+ "Cloning into 'sd_tokens'...\n",
134
+ "remote: Enumerating objects: 20, done.\u001b[K\n",
135
+ "remote: Counting objects: 100% (17/17), done.\u001b[K\n",
136
+ "remote: Compressing objects: 100% (17/17), done.\u001b[K\n",
137
+ "remote: Total 20 (delta 4), reused 0 (delta 0), pack-reused 3 (from 1)\u001b[K\n",
138
+ "Unpacking objects: 100% (20/20), 310.37 KiB | 2.10 MiB/s, done.\n",
139
+ "Filtering content: 100% (3/3), 160.82 MiB | 26.64 MiB/s, done.\n",
140
+ "/content/sd_tokens\n"
141
+ ]
142
+ }
143
+ ]
144
  },
145
  {
146
  "cell_type": "code",
 
151
  "tokenizer = AutoTokenizer.from_pretrained(\"openai/clip-vit-large-patch14\", clean_up_tokenization_spaces = False)\n",
152
  "\n",
153
  "# @markdown Write name of token to match against\n",
154
+ "token_name = \" blanket \" # @param {type:'string',\"placeholder\":\"leave empty for random value token\"}\n",
155
  "\n",
156
  "prompt = token_name\n",
157
  "# @markdown (optional) Mix the token with something else\n",
 
387
  "start_search_at_index = 0 # @param {type:\"slider\", min:0, max: 49407, step:100}\n",
388
  "# @markdown The lower the start_index, the more similiar the sampled tokens will be to the target token assigned in the '⚡ Get similiar tokens' cell\". If the cell was not run, then it will use tokens ordered by similarity to the \"girl\\</w>\" token\n",
389
  "start_search_at_ID = start_search_at_index\n",
390
+ "search_range = 1000 # @param {type:\"slider\", min:10, max: 1000, step:10}\n",
391
+ "\n",
392
+ "samples_per_iter = 10 # @param {type:\"slider\", min:10, max: 100, step:10}\n",
393
+ "\n",
394
  "iterations = 5 # @param {type:\"slider\", min:1, max: 20, step:0}\n",
395
  "restrictions = 'None' # @param [\"None\", \"Suffix only\", \"Prefix only\"]\n",
396
  "#markdown Limit char size of included token <----- Disabled\n",
 
406
  "RANGE = min(search_range , max(1,NUM_TOKENS - start_search_at_ID))\n",
407
  "#-----#\n",
408
  "import math, random\n",
409
+ "NUM_PERMUTATIONS = 4\n",
 
410
  "ITERS = iterations\n",
411
  "#-----#\n",
412
  "#LOOP START\n",
413
  "#-----#\n",
 
 
 
414
  "# Check if original solution is best\n",
415
  "best_sim = 0\n",
416
  "name = must_start_with + must_contain + must_end_with\n",
 
418
  "text_features = model.get_text_features(**ids)\n",
419
  "text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
420
  "#------#\n",
421
+ "sim = 0\n",
422
  "if(use == '🖼️image_encoding from image'):\n",
423
  " logit_scale = model.logit_scale.exp()\n",
424
  " torch.matmul(text_features, image_features.t()) * logit_scale\n",
 
430
  "best_sim = sim\n",
431
  "best_name = name\n",
432
  "name_B = must_contain\n",
433
+ "#------#\n",
434
+ "results_sim = torch.zeros(ITERS*NUM_PERMUTATIONS)\n",
435
  "results_name_B = {}\n",
436
  "results_name = {}\n",
437
  "#-----#\n",
 
440
  " is_trail = torch.zeros(RANGE)\n",
441
  " import re\n",
442
  " #-----#\n",
443
+ " _start = START + iter*RANGE\n",
444
  "\n",
445
+ " for index in range(samples_per_iter):\n",
446
+ " id_C = min(_start + index, NUM_TOKENS) + random.randint(0,RANGE)\n",
 
 
 
 
 
 
 
 
447
  " name_C = db_vocab[f'{id_C}']\n",
448
  " is_Prefix = 0\n",
449
  " #Skip if non-AZ characters are found\n",
 
586
  " #-----#\n",
587
  " #STEP 2\n",
588
  " import random\n",
 
 
 
589
  " #-----#\n",
 
590
  " for index in range(NUM_PERMUTATIONS):\n",
591
  " name_inner = ''\n",
592
  " if index == 0 : name_inner = name_B\n",
593
+ " if index == 1: name_inner = max_name_ahead\n",
594
+ " if index == 2: name_inner = name_B + max_name_trail\n",
595
+ " if index == 3: name_inner = max_name_ahead + name_B + max_name_trail\n",
596
+ " if name_inner == '': name_inner = max_name_ahead + name_B + max_name_trail\n",
597
+ "\n",
598
  " name = must_start_with + name_inner + must_end_with\n",
599
  " #----#\n",
600
  " ids = processor.tokenizer(text=name, padding=use_token_padding, return_tensors=\"pt\")\n",
 
612
  " text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
613
  " sim = torch.nn.functional.cosine_similarity(text_features, text_features_A)\n",
614
  " #-----#\n",
615
+ " results_name[iter*NUM_PERMUTATIONS + index] = name\n",
616
+ " results_sim[iter*NUM_PERMUTATIONS + index] = sim\n",
617
+ " results_name_B[iter*NUM_PERMUTATIONS + index] = name_inner.replace('</w>',' ')\n",
618
  " #------#\n",
619
+ " name_B = results_name_B[iter*NUM_PERMUTATIONS + random.randint(0,3)]\n",
 
 
 
 
620
  "#--------#\n",
 
 
 
 
621
  "\n",
622
+ "print('')\n",
623
  "sorted, indices = torch.sort(results_sim,dim=0 , descending=True)\n",
624
  "\n",
625
+ "for index in range(ITERS*NUM_PERMUTATIONS):\n",
 
626
  " name_inner = results_name[indices[index].item()]\n",
627
  " print(must_start_with + name_inner + must_end_with)\n",
628
  " print(f'similiarity = {round(sorted[index].item(),2)} %')\n",