Upload sd_token_similarity_calculator.ipynb
Browse files
sd_token_similarity_calculator.ipynb
CHANGED
@@ -59,9 +59,11 @@
|
|
59 |
"\n",
|
60 |
"\n",
|
61 |
"def token_similarity(A, B):\n",
|
62 |
-
"
|
|
|
63 |
" _A = LA.vector_norm(A, ord=2)\n",
|
64 |
" _B = LA.vector_norm(B, ord=2)\n",
|
|
|
65 |
" #----#\n",
|
66 |
" result = torch.dot(A,B)/(_A*_B)\n",
|
67 |
" #similarity_pcnt = absolute_value(result.item()*100)\n",
|
@@ -70,6 +72,7 @@
|
|
70 |
" result = f'{similarity_pcnt_aprox} %'\n",
|
71 |
" return result\n",
|
72 |
"\n",
|
|
|
73 |
"def similarity(id_A , id_B):\n",
|
74 |
" #Tensors\n",
|
75 |
" A = token[id_A]\n",
|
@@ -81,11 +84,39 @@
|
|
81 |
"#print(token[8922].shape) #dimension of the token\n",
|
82 |
"\n",
|
83 |
"mix_with = \"\"\n",
|
84 |
-
"mix_method = \"None\""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
],
|
86 |
"metadata": {
|
87 |
"id": "Ch9puvwKH1s3",
|
88 |
-
"collapsed": true
|
|
|
89 |
},
|
90 |
"execution_count": null,
|
91 |
"outputs": []
|
@@ -119,6 +150,7 @@
|
|
119 |
" R = torch.rand(768)\n",
|
120 |
" _R = LA.vector_norm(R, ord=2)\n",
|
121 |
" A = R*(_A/_R)\n",
|
|
|
122 |
"\n",
|
123 |
"\n",
|
124 |
"mix_with = \"\" # @param {\"type\":\"string\",\"placeholder\":\"(optional) write something else\"}\n",
|
@@ -138,6 +170,26 @@
|
|
138 |
" R = torch.rand(768)\n",
|
139 |
" _R = LA.vector_norm(R, ord=2)\n",
|
140 |
" C = R*(_C/_R)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
141 |
"\n",
|
142 |
"if (mix_method == \"None\"):\n",
|
143 |
" print(\"No operation\")\n",
|
@@ -145,14 +197,15 @@
|
|
145 |
"if (mix_method == \"Average\"):\n",
|
146 |
" A = w*A + (1-w)*C\n",
|
147 |
" _A = LA.vector_norm(A, ord=2)\n",
|
148 |
-
" print(\"Tokenized prompt tensor A has been recalculated as A = w*A + (1-w)*C , where C is
|
149 |
"\n",
|
150 |
"if (mix_method == \"Subtract\"):\n",
|
151 |
-
" tmp =
|
152 |
-
" _tmp =
|
153 |
-
" A =
|
|
|
154 |
" _A = LA.vector_norm(A, ord=2)\n",
|
155 |
-
" print(\"Tokenized prompt tensor A has been recalculated as A =
|
156 |
"\n",
|
157 |
"#OPTIONAL : Add/subtract + normalize above result with another token. Leave field empty to get a random value tensor\n",
|
158 |
"\n",
|
@@ -166,14 +219,6 @@
|
|
166 |
" result = result.item()\n",
|
167 |
" dots[index] = result\n",
|
168 |
"\n",
|
169 |
-
"name_A = \"A of random type\"\n",
|
170 |
-
"if (id_A>-1):\n",
|
171 |
-
" name_A = vocab[id_A]\n",
|
172 |
-
"\n",
|
173 |
-
"name_C = \"token C of random type\"\n",
|
174 |
-
"if (id_C>-1):\n",
|
175 |
-
" name_C = vocab[id_C]\n",
|
176 |
-
"\n",
|
177 |
"\n",
|
178 |
"sorted, indices = torch.sort(dots,dim=0 , descending=True)\n",
|
179 |
"#----#\n",
|
@@ -194,11 +239,11 @@
|
|
194 |
"\n",
|
195 |
"\n",
|
196 |
"if (print_Divider):\n",
|
197 |
-
" print('//---//')
|
198 |
"\n",
|
199 |
-
"print('')
|
200 |
-
"print('Here is the result : ')
|
201 |
-
"print('')
|
202 |
"\n",
|
203 |
"for index in range(list_size):\n",
|
204 |
" id = indices[index].item()\n",
|
@@ -207,14 +252,15 @@
|
|
207 |
" if (print_ID):\n",
|
208 |
" print(f'ID = {id}') # IDs\n",
|
209 |
" if (print_Similarity):\n",
|
210 |
-
" print(f'similiarity = {round(sorted[index].item()*100,2)} %')
|
211 |
" if (print_Divider):\n",
|
212 |
" print('--------')\n",
|
213 |
"\n",
|
214 |
"#Print the sorted list from above result"
|
215 |
],
|
216 |
"metadata": {
|
217 |
-
"id": "iWeFnT1gAx6A"
|
|
|
218 |
},
|
219 |
"execution_count": null,
|
220 |
"outputs": []
|
@@ -250,7 +296,8 @@
|
|
250 |
],
|
251 |
"metadata": {
|
252 |
"id": "QQOjh5BvnG8M",
|
253 |
-
"collapsed": true
|
|
|
254 |
},
|
255 |
"execution_count": null,
|
256 |
"outputs": []
|
|
|
59 |
"\n",
|
60 |
"\n",
|
61 |
"def token_similarity(A, B):\n",
|
62 |
+
"\n",
|
63 |
+
" #Vector length#\n",
|
64 |
" _A = LA.vector_norm(A, ord=2)\n",
|
65 |
" _B = LA.vector_norm(B, ord=2)\n",
|
66 |
+
"\n",
|
67 |
" #----#\n",
|
68 |
" result = torch.dot(A,B)/(_A*_B)\n",
|
69 |
" #similarity_pcnt = absolute_value(result.item()*100)\n",
|
|
|
72 |
" result = f'{similarity_pcnt_aprox} %'\n",
|
73 |
" return result\n",
|
74 |
"\n",
|
75 |
+
"\n",
|
76 |
"def similarity(id_A , id_B):\n",
|
77 |
" #Tensors\n",
|
78 |
" A = token[id_A]\n",
|
|
|
84 |
"#print(token[8922].shape) #dimension of the token\n",
|
85 |
"\n",
|
86 |
"mix_with = \"\"\n",
|
87 |
+
"mix_method = \"None\"\n",
|
88 |
+
"\n",
|
89 |
+
"#-------------#\n",
|
90 |
+
"# UNUSED\n",
|
91 |
+
"\n",
|
92 |
+
"# Get the 10 lowest values from a tensor as a string\n",
|
93 |
+
"def get_valleys (A):\n",
|
94 |
+
" sorted, indices = torch.sort(A,dim=0 , descending=False)\n",
|
95 |
+
" result = \"{\"\n",
|
96 |
+
" for index in range(10):\n",
|
97 |
+
" id = indices[index].item()\n",
|
98 |
+
" result = result + f\"{id}\"\n",
|
99 |
+
" if(index<9):\n",
|
100 |
+
" result = result + \",\"\n",
|
101 |
+
" result = result + \"}\"\n",
|
102 |
+
" return result\n",
|
103 |
+
"\n",
|
104 |
+
"# Get the 10 highest values from a tensor as a string\n",
|
105 |
+
"def get_peaks (A):\n",
|
106 |
+
" sorted, indices = torch.sort(A,dim=0 , descending=True)\n",
|
107 |
+
" result = \"{\"\n",
|
108 |
+
" for index in range(10):\n",
|
109 |
+
" id = indices[index].item()\n",
|
110 |
+
" result = result + f\"{id}\"\n",
|
111 |
+
" if(index<9):\n",
|
112 |
+
" result = result + \",\"\n",
|
113 |
+
" result = result + \"}\"\n",
|
114 |
+
" return result"
|
115 |
],
|
116 |
"metadata": {
|
117 |
"id": "Ch9puvwKH1s3",
|
118 |
+
"collapsed": true,
|
119 |
+
"cellView": "form"
|
120 |
},
|
121 |
"execution_count": null,
|
122 |
"outputs": []
|
|
|
150 |
" R = torch.rand(768)\n",
|
151 |
" _R = LA.vector_norm(R, ord=2)\n",
|
152 |
" A = R*(_A/_R)\n",
|
153 |
+
" name_A = 'random_A'\n",
|
154 |
"\n",
|
155 |
"\n",
|
156 |
"mix_with = \"\" # @param {\"type\":\"string\",\"placeholder\":\"(optional) write something else\"}\n",
|
|
|
170 |
" R = torch.rand(768)\n",
|
171 |
" _R = LA.vector_norm(R, ord=2)\n",
|
172 |
" C = R*(_C/_R)\n",
|
173 |
+
" name_C = 'random_C'\n",
|
174 |
+
"\n",
|
175 |
+
"name_A = \"A of random type\"\n",
|
176 |
+
"if (id_A>-1):\n",
|
177 |
+
" name_A = vocab[id_A]\n",
|
178 |
+
"\n",
|
179 |
+
"name_C = \"token C of random type\"\n",
|
180 |
+
"if (id_C>-1):\n",
|
181 |
+
" name_C = vocab[id_C]\n",
|
182 |
+
"\n",
|
183 |
+
"# Peaks feature\n",
|
184 |
+
"#peaks_A = get_valleys(A)\n",
|
185 |
+
"#peaks_C = get_valleys(C)\n",
|
186 |
+
"#print(f\"The elementwise top 10 highest values for A is at indices {peaks_A}\")\n",
|
187 |
+
"#print(\"-------\")\n",
|
188 |
+
"#print(f\"The elementwise top 10 highest values for C is at indices {peaks_C}\")\n",
|
189 |
+
"#print(\"-------\")\n",
|
190 |
+
"#//------//\n",
|
191 |
+
"\n",
|
192 |
+
"print(f\"The similarity between A '{name_A}' and C '{name_C}' is {token_similarity(A, C)}\")\n",
|
193 |
"\n",
|
194 |
"if (mix_method == \"None\"):\n",
|
195 |
" print(\"No operation\")\n",
|
|
|
197 |
"if (mix_method == \"Average\"):\n",
|
198 |
" A = w*A + (1-w)*C\n",
|
199 |
" _A = LA.vector_norm(A, ord=2)\n",
|
200 |
+
" print(f\"Tokenized prompt tensor A '{name_A}' token has been recalculated as A = w*A + (1-w)*C , where C is '{name_C}' token , for w = {w} \")\n",
|
201 |
"\n",
|
202 |
"if (mix_method == \"Subtract\"):\n",
|
203 |
+
" tmp = w*A - (1-w)*C\n",
|
204 |
+
" _tmp = LA.vector_norm(tmp, ord=2)\n",
|
205 |
+
" A = (_A/_tmp)*tmp\n",
|
206 |
+
" #//---//\n",
|
207 |
" _A = LA.vector_norm(A, ord=2)\n",
|
208 |
+
" print(f\"Tokenized prompt tensor A '{name_A}' token has been recalculated as A = _A * norm(w*A - (1-w)*C) , where C is '{name_C}' token , for w = {w} \")\n",
|
209 |
"\n",
|
210 |
"#OPTIONAL : Add/subtract + normalize above result with another token. Leave field empty to get a random value tensor\n",
|
211 |
"\n",
|
|
|
219 |
" result = result.item()\n",
|
220 |
" dots[index] = result\n",
|
221 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
222 |
"\n",
|
223 |
"sorted, indices = torch.sort(dots,dim=0 , descending=True)\n",
|
224 |
"#----#\n",
|
|
|
239 |
"\n",
|
240 |
"\n",
|
241 |
"if (print_Divider):\n",
|
242 |
+
" print('//---//')\n",
|
243 |
"\n",
|
244 |
+
"print('')\n",
|
245 |
+
"print('Here is the result : ')\n",
|
246 |
+
"print('')\n",
|
247 |
"\n",
|
248 |
"for index in range(list_size):\n",
|
249 |
" id = indices[index].item()\n",
|
|
|
252 |
" if (print_ID):\n",
|
253 |
" print(f'ID = {id}') # IDs\n",
|
254 |
" if (print_Similarity):\n",
|
255 |
+
" print(f'similiarity = {round(sorted[index].item()*100,2)} %')\n",
|
256 |
" if (print_Divider):\n",
|
257 |
" print('--------')\n",
|
258 |
"\n",
|
259 |
"#Print the sorted list from above result"
|
260 |
],
|
261 |
"metadata": {
|
262 |
+
"id": "iWeFnT1gAx6A",
|
263 |
+
"cellView": "form"
|
264 |
},
|
265 |
"execution_count": null,
|
266 |
"outputs": []
|
|
|
296 |
],
|
297 |
"metadata": {
|
298 |
"id": "QQOjh5BvnG8M",
|
299 |
+
"collapsed": true,
|
300 |
+
"cellView": "form"
|
301 |
},
|
302 |
"execution_count": null,
|
303 |
"outputs": []
|