codeShare committed (verified)
Commit 1028385 · Parent: 8ad9fee

Upload sd_token_similarity_calculator.ipynb

Files changed (1): sd_token_similarity_calculator.ipynb (+116 −19)
sd_token_similarity_calculator.ipynb CHANGED
@@ -14,6 +14,15 @@
  }
  },
  "cells": [
+ {
+ "cell_type": "markdown",
+ "source": [
+ "This Notebook is a Stable Diffusion tool which allows you to find similar tokens from the SD 1.5 vocab.json that you can use for text-to-image generation"
+ ],
+ "metadata": {
+ "id": "L7JTcbOdBPfh"
+ }
+ },
  {
  "cell_type": "code",
  "source": [
@@ -23,7 +32,42 @@
  "from torch import linalg as LA\n",
  "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
  "%cd /content/sd_tokens\n",
- "token = torch.load('sd15_tensors.pt', map_location=device, weights_only=True)"
+ "token = torch.load('sd15_tensors.pt', map_location=device, weights_only=True)\n",
+ "#-----#\n",
+ "\n",
+ "#Import the vocab.json\n",
+ "import json\n",
+ "import pandas as pd\n",
+ "with open('vocab.json', 'r') as f:\n",
+ " data = json.load(f)\n",
+ "\n",
+ "_df = pd.DataFrame({'count': data})['count']\n",
+ "\n",
+ "vocab = {\n",
+ " value: key for key, value in _df.items()\n",
+ "}\n",
+ "#-----#\n",
+ "\n",
+ "# Define functions/constants\n",
+ "NUM_TOKENS = 49407\n",
+ "\n",
+ "def absolute_value(x):\n",
+ " return max(x, -x)\n",
+ "\n",
+ "def similarity(id_A , id_B):\n",
+ " #Tensors\n",
+ " A = token[id_A]\n",
+ " B = token[id_B]\n",
+ " #Tensor vector length (2nd order, i.e. (a^2 + b^2 + ...)^(1/2))\n",
+ " _A = LA.vector_norm(A, ord=2)\n",
+ " _B = LA.vector_norm(B, ord=2)\n",
+ " #----#\n",
+ " result = torch.dot(A,B)/(_A*_B)\n",
+ " similarity_pcnt = absolute_value(result.item()*100)\n",
+ " similarity_pcnt_aprox = round(similarity_pcnt, 3)\n",
+ " result = f'{similarity_pcnt_aprox} %'\n",
+ " return result\n",
+ "#----#"
  ],
  "metadata": {
  "id": "Ch9puvwKH1s3"
@@ -34,7 +78,8 @@
  {
  "cell_type": "code",
  "source": [
- "print(token[100].shape) #dimension of the tokens"
+ "print(vocab[12432]) #the vocab item for ID 12432\n",
+ "print(token[12432].shape) #dimension of the token"
  ],
  "metadata": {
  "id": "S_Yh9gH_OUA1"
@@ -42,36 +87,88 @@
  "execution_count": null,
  "outputs": []
  },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Get the IDs from a prompt text.\n",
+ "\n",
+ "The prompt will be enclosed with the <|start-of-text|> and <|end-of-text|> tokens"
+ ],
+ "metadata": {
+ "id": "f1-jS7YJApiO"
+ }
+ },
  {
  "cell_type": "code",
  "source": [
- "def absolute_value(x):\n",
- " return max(x, -x)\n",
  "\n",
- "def similarity(id_A , id_B):\n",
- " #Tensors\n",
- " A = token[id_A]\n",
- " B = token[id_B]\n",
+ "from transformers import AutoTokenizer\n",
+ "tokenizer = AutoTokenizer.from_pretrained(\"openai/clip-vit-large-patch14\", clean_up_tokenization_spaces = False)\n",
+ "prompt= \"blah\" # @param {type:'string'}\n",
+ "tokenizer_output = tokenizer(text = prompt)\n",
+ "input_ids = tokenizer_output['input_ids']\n",
+ "print(input_ids)"
+ ],
+ "metadata": {
+ "id": "RPdkYzT2_X85"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "#Produce a list of IDs that are most similar to the prompt ID at position 1\n",
  "\n",
- " #Tensor vector length (2nd order, i.e (a^2 + b^2 + ....)^(1/2)\n",
- " _A = LA.vector_norm(A, ord=2)\n",
- " _B = LA.vector_norm(B, ord=2)\n",
+ "id_A = input_ids[1]\n",
+ "A = token[id_A]\n",
+ "_A = LA.vector_norm(A, ord=2)\n",
+ "dots = torch.zeros(NUM_TOKENS)\n",
  "\n",
+ "for index in range(NUM_TOKENS):\n",
+ " id_B = index\n",
+ " B = token[id_B]\n",
+ " _B = LA.vector_norm(B, ord=2)\n",
  " result = torch.dot(A,B)/(_A*_B)\n",
- " similarity_pcnt = absolute_value(result.item()*100)\n",
- "\n",
- " similarity_pcnt_aprox = round(similarity_pcnt, 3)\n",
+ " result = absolute_value(result.item())\n",
+ " dots[index] = result\n",
  "\n",
- " result = f'{similarity_pcnt_aprox} %'\n",
- "\n",
- " return result"
+ "sorted, indices = torch.sort(dots,dim=0 , descending=True)\n",
+ "#----#\n",
+ "print(f'Calculated all cosine-similarities between ID = {id_A} and the rest of the IDs as a 1x{sorted.shape[0]} tensor')\n",
+ "print(f'Calculated indices as a 1x{indices.shape[0]} tensor')"
  ],
  "metadata": {
- "id": "fxquCxFaUxAZ"
+ "id": "juxsvco9B0iV"
  },
- "execution_count": 16,
+ "execution_count": null,
  "outputs": []
  },
+ {
+ "cell_type": "code",
+ "source": [
+ "list_size = 10 # @param {type:'number'}\n",
+ "for index in range(list_size):\n",
+ " print(f'{vocab[indices[index]]}') # vocab item\n",
+ " print(f'ID = {indices[index]}') # IDs\n",
+ " print(f'similarity = {round(sorted[index].item()*100,2)} %') # % value\n",
+ " print('--------')\n"
+ ],
+ "metadata": {
+ "id": "YIEmLAzbHeuo"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Find the most similar tokens for a given input"
+ ],
+ "metadata": {
+ "id": "qqZ5DvfLBJnw"
+ }
+ },
  {
  "cell_type": "markdown",
  "source": [
 