Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Make server/component frontend pair
Browse files- client/dist/exBERT.html +1 -183
- client/dist/main.js +0 -0
- client/src/exBERT.html +1 -183
- client/src/ts/api/mainApi.ts +1 -53
- client/src/ts/api/responses.ts +0 -4
- client/src/ts/data/FaissSearchWrapper.ts +0 -127
- client/src/ts/data/TokenWrapper.ts +0 -6
- client/src/ts/etc/SpacyInfo.ts +0 -61
- client/src/ts/etc/types.ts +10 -61
- client/src/ts/main.ts +7 -104
- client/src/ts/uiConfig.ts +7 -78
- client/src/ts/vis/AttentionConnector.ts +6 -6
- client/src/ts/vis/CorpusHistogram.ts +0 -253
- client/src/ts/vis/CorpusInspector.ts +0 -150
- client/src/ts/vis/CorpusMatManager.ts +0 -321
- client/src/ts/vis/{myMain.ts → attentionVis.ts} +251 -447
- server/main.py +3 -4
- server/swagger.yaml +1 -76
- server/transformer_formatter.py +13 -3
- server/utils/path_fixes.py +1 -1
client/dist/exBERT.html
CHANGED
@@ -34,189 +34,7 @@
|
|
34 |
|
35 |
</div>
|
36 |
|
37 |
-
<div
|
38 |
-
<div class="left-half vpartial-95 scrolling">
|
39 |
-
|
40 |
-
<div class="text-center">
|
41 |
-
|
42 |
-
<div id="sentence-input">
|
43 |
-
<form>
|
44 |
-
<div class="form-group">
|
45 |
-
<label for="form-sentence-a"> Input Sentence </label>
|
46 |
-
<input id="form-sentence-a" type="text" name="sent-a-input"> </p>
|
47 |
-
</div>
|
48 |
-
<div class="padding"></div>
|
49 |
-
<button class="btn btn-primary" id="update-sentence" type="button">Update</button>
|
50 |
-
</form>
|
51 |
-
</div>
|
52 |
-
<hr />
|
53 |
-
|
54 |
-
<div id="connector-container">
|
55 |
-
|
56 |
-
<div class="connector-controls">
|
57 |
-
<div class="left-control-half">
|
58 |
-
<div id="model-selection">
|
59 |
-
<label for="model-options">Select model: </label>
|
60 |
-
<select id="model-option-selector" name="model-options"></select>
|
61 |
-
</div>
|
62 |
-
<div class="slide-container">
|
63 |
-
<div>
|
64 |
-
<label for="my-range">
|
65 |
-
Display top <span id="my-range-value">…</span>% of attention
|
66 |
-
</label>
|
67 |
-
<input type="range" min="0" max="100" value="70" class="slider" id="my-range"> <br>
|
68 |
-
</div>
|
69 |
-
</div>
|
70 |
-
|
71 |
-
<div id="layer-selection">
|
72 |
-
<div class="input-description">
|
73 |
-
Layer:
|
74 |
-
</div>
|
75 |
-
|
76 |
-
<div class="layer-select btn-group btn-group-toggle" data-toggle="buttons"
|
77 |
-
id="layer-select"> </div>
|
78 |
-
</div>
|
79 |
-
<div id="cls-toggle">
|
80 |
-
<div class="input-description">
|
81 |
-
Hide Special Tokens
|
82 |
-
</div>
|
83 |
-
|
84 |
-
<label class="switch">
|
85 |
-
<input type="checkbox" checked='checked'>
|
86 |
-
<span class="short-slider round"></span>
|
87 |
-
</label>
|
88 |
-
</div>
|
89 |
-
</div>
|
90 |
-
|
91 |
-
<div class="head-control">
|
92 |
-
<div id="selected-head-display">
|
93 |
-
<div class="input-description">
|
94 |
-
Selected heads:
|
95 |
-
</div>
|
96 |
-
<div id="selected-heads"></div>
|
97 |
-
</div>
|
98 |
-
|
99 |
-
<div class="select-input" id="head-all-or-none">
|
100 |
-
<button id="select-all-heads">Select all heads</button>
|
101 |
-
<button id="select-no-heads">Unselect all heads</button>
|
102 |
-
</div>
|
103 |
-
|
104 |
-
<div id="usage-info">
|
105 |
-
<p> You focus on one token by <b>click</b>.<br />
|
106 |
-
You can mask any token by <b>double click</b>.</p>
|
107 |
-
<p>You can select and de-select a head by a <b>click</b> on the heatmap columns</p>
|
108 |
-
|
109 |
-
</div>
|
110 |
-
</div>
|
111 |
-
|
112 |
-
</div>
|
113 |
-
|
114 |
-
<div id=vis-break></div>
|
115 |
-
|
116 |
-
|
117 |
-
<div class="text-center" id="atn-container">
|
118 |
-
<div id="head-info-box"></div>
|
119 |
-
<svg id="left-att-heads"></svg>
|
120 |
-
<div id="left-tokens"></div>
|
121 |
-
<svg id="atn-display"></svg>
|
122 |
-
<div id="right-tokens"></div>
|
123 |
-
<svg id="right-att-heads"></svg>
|
124 |
-
</div>
|
125 |
-
|
126 |
-
</div>
|
127 |
-
|
128 |
-
<!-- Part II of HTML -->
|
129 |
-
<hr />
|
130 |
-
|
131 |
-
<div id="corpus-selection-description">
|
132 |
-
<header>
|
133 |
-
<!-- Search <span class="inline-select" id="corpus-select">corpus</span> -->
|
134 |
-
Search <select id="corpus-select"></select>
|
135 |
-
</header>
|
136 |
-
</div>
|
137 |
-
|
138 |
-
<div id="corpus-querying">
|
139 |
-
<form>
|
140 |
-
<button class="btn btn-primary" id="search-contexts" type="button">by Context</button>
|
141 |
-
<button class="btn btn-primary" id="search-embeddings" type="button">by Embedding</button>
|
142 |
-
</form>
|
143 |
-
</div>
|
144 |
-
|
145 |
-
<div id="histograms">
|
146 |
-
<div id="matched-histogram">
|
147 |
-
<svg class="histogram" id="matched-histogram-container"></svg>
|
148 |
-
<div class="pos-selector">
|
149 |
-
<span id="match-kind">Matched</span> Word Summary:
|
150 |
-
<div id="matched-meta-select" class="btn-group btn-group-toggle" data-toggle="buttons">
|
151 |
-
<label class="btn btn-secondary active" value="pos">
|
152 |
-
<input type="radio" name="options" id="option1" autocomplete="off" value="pos"> POS
|
153 |
-
</label>
|
154 |
-
<label class="btn btn-secondary" value="dep">
|
155 |
-
<input type="radio" name="options" id="option2" autocomplete="off"> DEP
|
156 |
-
</label>
|
157 |
-
<label class="btn btn-secondary" value="is_ent">
|
158 |
-
<input type="radio" name="options" id="option3" autocomplete="off"> ENT
|
159 |
-
</label>
|
160 |
-
</div>
|
161 |
-
</div>
|
162 |
-
</div>
|
163 |
-
|
164 |
-
<div id="max-att-histogram">
|
165 |
-
<svg class="histogram" id="max-att-histogram-container"></svg>
|
166 |
-
<div class="pos-selector">
|
167 |
-
Max Attention Summary:
|
168 |
-
|
169 |
-
<div id="max-att-meta-select" class="btn-group btn-group-toggle" data-toggle="buttons">
|
170 |
-
<label class="btn btn-secondary active" value="pos">
|
171 |
-
<input type="radio" name="options" autocomplete="off" value="pos"> POS
|
172 |
-
</label>
|
173 |
-
<label class="btn btn-secondary" value="dep">
|
174 |
-
<input type="radio" name="options" autocomplete="off"> DEP
|
175 |
-
</label>
|
176 |
-
<label class="btn btn-secondary" value="is_ent">
|
177 |
-
<input type="radio" name="options" autocomplete="off"> ENT
|
178 |
-
</label>
|
179 |
-
<label class="btn btn-secondary" value="offset">
|
180 |
-
<input type="radio" name="options" autocomplete="off"> OFFSET
|
181 |
-
</label>
|
182 |
-
</div>
|
183 |
-
<!-- <select name="position-meta-dropdown" id="position-meta-dropdown">
|
184 |
-
<option value="offset">OFFSET</option>
|
185 |
-
</select> -->
|
186 |
-
</div>
|
187 |
-
</div>
|
188 |
-
</div>
|
189 |
-
</div>
|
190 |
-
|
191 |
-
</div>
|
192 |
-
|
193 |
-
<div class="vertical-separator"></div>
|
194 |
-
|
195 |
-
|
196 |
-
<div class="right-half">
|
197 |
-
<div id="corpus-vis">
|
198 |
-
<div id="corpus-control-buttons">
|
199 |
-
<button class="btn btn-xs btn-secondary" id="minus-left" type="button">+</button>
|
200 |
-
<button class="btn btn-xs btn-danger" id="kill-left" type="button">-</button>
|
201 |
-
<span>←||→</span>
|
202 |
-
<button class="btn btn-xs btn-danger" id="kill-right" type="button">-</button>
|
203 |
-
<button class="btn btn-xs btn-secondary" id="plus-right" type="button">+</button>
|
204 |
-
<button class="btn btn-xs btn-info" id="mat-refresh" type="button">↻</button>
|
205 |
-
</div>
|
206 |
-
|
207 |
-
|
208 |
-
<div class="vpartial-90 scrolling">
|
209 |
-
<div class="whitespace"></div>
|
210 |
-
<div id="corpus-msg-box"></div>
|
211 |
-
<div id="main-corpus-vis">
|
212 |
-
<div id="corpus-mat-container"></div>
|
213 |
-
<div id="corpus-similar-sentences-div"></div>
|
214 |
-
</div>
|
215 |
-
</div>
|
216 |
-
</div>
|
217 |
-
|
218 |
-
</div>
|
219 |
-
</div>
|
220 |
|
221 |
<script src="vendor.js"></script>
|
222 |
<script src="main.js"></script>
|
|
|
34 |
|
35 |
</div>
|
36 |
|
37 |
+
<div id="attention-vis"></div>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
<script src="vendor.js"></script>
|
40 |
<script src="main.js"></script>
|
client/dist/main.js
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
client/src/exBERT.html
CHANGED
@@ -34,189 +34,7 @@
|
|
34 |
|
35 |
</div>
|
36 |
|
37 |
-
<div
|
38 |
-
<div class="left-half vpartial-95 scrolling">
|
39 |
-
|
40 |
-
<div class="text-center">
|
41 |
-
|
42 |
-
<div id="sentence-input">
|
43 |
-
<form>
|
44 |
-
<div class="form-group">
|
45 |
-
<label for="form-sentence-a"> Input Sentence </label>
|
46 |
-
<input id="form-sentence-a" type="text" name="sent-a-input"> </p>
|
47 |
-
</div>
|
48 |
-
<div class="padding"></div>
|
49 |
-
<button class="btn btn-primary" id="update-sentence" type="button">Update</button>
|
50 |
-
</form>
|
51 |
-
</div>
|
52 |
-
<hr />
|
53 |
-
|
54 |
-
<div id="connector-container">
|
55 |
-
|
56 |
-
<div class="connector-controls">
|
57 |
-
<div class="left-control-half">
|
58 |
-
<div id="model-selection">
|
59 |
-
<label for="model-options">Select model: </label>
|
60 |
-
<select id="model-option-selector" name="model-options"></select>
|
61 |
-
</div>
|
62 |
-
<div class="slide-container">
|
63 |
-
<div>
|
64 |
-
<label for="my-range">
|
65 |
-
Display top <span id="my-range-value">…</span>% of attention
|
66 |
-
</label>
|
67 |
-
<input type="range" min="0" max="100" value="70" class="slider" id="my-range"> <br>
|
68 |
-
</div>
|
69 |
-
</div>
|
70 |
-
|
71 |
-
<div id="layer-selection">
|
72 |
-
<div class="input-description">
|
73 |
-
Layer:
|
74 |
-
</div>
|
75 |
-
|
76 |
-
<div class="layer-select btn-group btn-group-toggle" data-toggle="buttons"
|
77 |
-
id="layer-select"> </div>
|
78 |
-
</div>
|
79 |
-
<div id="cls-toggle">
|
80 |
-
<div class="input-description">
|
81 |
-
Hide Special Tokens
|
82 |
-
</div>
|
83 |
-
|
84 |
-
<label class="switch">
|
85 |
-
<input type="checkbox" checked='checked'>
|
86 |
-
<span class="short-slider round"></span>
|
87 |
-
</label>
|
88 |
-
</div>
|
89 |
-
</div>
|
90 |
-
|
91 |
-
<div class="head-control">
|
92 |
-
<div id="selected-head-display">
|
93 |
-
<div class="input-description">
|
94 |
-
Selected heads:
|
95 |
-
</div>
|
96 |
-
<div id="selected-heads"></div>
|
97 |
-
</div>
|
98 |
-
|
99 |
-
<div class="select-input" id="head-all-or-none">
|
100 |
-
<button id="select-all-heads">Select all heads</button>
|
101 |
-
<button id="select-no-heads">Unselect all heads</button>
|
102 |
-
</div>
|
103 |
-
|
104 |
-
<div id="usage-info">
|
105 |
-
<p> You focus on one token by <b>click</b>.<br />
|
106 |
-
You can mask any token by <b>double click</b>.</p>
|
107 |
-
<p>You can select and de-select a head by a <b>click</b> on the heatmap columns</p>
|
108 |
-
|
109 |
-
</div>
|
110 |
-
</div>
|
111 |
-
|
112 |
-
</div>
|
113 |
-
|
114 |
-
<div id=vis-break></div>
|
115 |
-
|
116 |
-
|
117 |
-
<div class="text-center" id="atn-container">
|
118 |
-
<div id="head-info-box"></div>
|
119 |
-
<svg id="left-att-heads"></svg>
|
120 |
-
<div id="left-tokens"></div>
|
121 |
-
<svg id="atn-display"></svg>
|
122 |
-
<div id="right-tokens"></div>
|
123 |
-
<svg id="right-att-heads"></svg>
|
124 |
-
</div>
|
125 |
-
|
126 |
-
</div>
|
127 |
-
|
128 |
-
<!-- Part II of HTML -->
|
129 |
-
<hr />
|
130 |
-
|
131 |
-
<div id="corpus-selection-description">
|
132 |
-
<header>
|
133 |
-
<!-- Search <span class="inline-select" id="corpus-select">corpus</span> -->
|
134 |
-
Search <select id="corpus-select"></select>
|
135 |
-
</header>
|
136 |
-
</div>
|
137 |
-
|
138 |
-
<div id="corpus-querying">
|
139 |
-
<form>
|
140 |
-
<button class="btn btn-primary" id="search-contexts" type="button">by Context</button>
|
141 |
-
<button class="btn btn-primary" id="search-embeddings" type="button">by Embedding</button>
|
142 |
-
</form>
|
143 |
-
</div>
|
144 |
-
|
145 |
-
<div id="histograms">
|
146 |
-
<div id="matched-histogram">
|
147 |
-
<svg class="histogram" id="matched-histogram-container"></svg>
|
148 |
-
<div class="pos-selector">
|
149 |
-
<span id="match-kind">Matched</span> Word Summary:
|
150 |
-
<div id="matched-meta-select" class="btn-group btn-group-toggle" data-toggle="buttons">
|
151 |
-
<label class="btn btn-secondary active" value="pos">
|
152 |
-
<input type="radio" name="options" id="option1" autocomplete="off" value="pos"> POS
|
153 |
-
</label>
|
154 |
-
<label class="btn btn-secondary" value="dep">
|
155 |
-
<input type="radio" name="options" id="option2" autocomplete="off"> DEP
|
156 |
-
</label>
|
157 |
-
<label class="btn btn-secondary" value="is_ent">
|
158 |
-
<input type="radio" name="options" id="option3" autocomplete="off"> ENT
|
159 |
-
</label>
|
160 |
-
</div>
|
161 |
-
</div>
|
162 |
-
</div>
|
163 |
-
|
164 |
-
<div id="max-att-histogram">
|
165 |
-
<svg class="histogram" id="max-att-histogram-container"></svg>
|
166 |
-
<div class="pos-selector">
|
167 |
-
Max Attention Summary:
|
168 |
-
|
169 |
-
<div id="max-att-meta-select" class="btn-group btn-group-toggle" data-toggle="buttons">
|
170 |
-
<label class="btn btn-secondary active" value="pos">
|
171 |
-
<input type="radio" name="options" autocomplete="off" value="pos"> POS
|
172 |
-
</label>
|
173 |
-
<label class="btn btn-secondary" value="dep">
|
174 |
-
<input type="radio" name="options" autocomplete="off"> DEP
|
175 |
-
</label>
|
176 |
-
<label class="btn btn-secondary" value="is_ent">
|
177 |
-
<input type="radio" name="options" autocomplete="off"> ENT
|
178 |
-
</label>
|
179 |
-
<label class="btn btn-secondary" value="offset">
|
180 |
-
<input type="radio" name="options" autocomplete="off"> OFFSET
|
181 |
-
</label>
|
182 |
-
</div>
|
183 |
-
<!-- <select name="position-meta-dropdown" id="position-meta-dropdown">
|
184 |
-
<option value="offset">OFFSET</option>
|
185 |
-
</select> -->
|
186 |
-
</div>
|
187 |
-
</div>
|
188 |
-
</div>
|
189 |
-
</div>
|
190 |
-
|
191 |
-
</div>
|
192 |
-
|
193 |
-
<div class="vertical-separator"></div>
|
194 |
-
|
195 |
-
|
196 |
-
<div class="right-half">
|
197 |
-
<div id="corpus-vis">
|
198 |
-
<div id="corpus-control-buttons">
|
199 |
-
<button class="btn btn-xs btn-secondary" id="minus-left" type="button">+</button>
|
200 |
-
<button class="btn btn-xs btn-danger" id="kill-left" type="button">-</button>
|
201 |
-
<span>←||→</span>
|
202 |
-
<button class="btn btn-xs btn-danger" id="kill-right" type="button">-</button>
|
203 |
-
<button class="btn btn-xs btn-secondary" id="plus-right" type="button">+</button>
|
204 |
-
<button class="btn btn-xs btn-info" id="mat-refresh" type="button">↻</button>
|
205 |
-
</div>
|
206 |
-
|
207 |
-
|
208 |
-
<div class="vpartial-90 scrolling">
|
209 |
-
<div class="whitespace"></div>
|
210 |
-
<div id="corpus-msg-box"></div>
|
211 |
-
<div id="main-corpus-vis">
|
212 |
-
<div id="corpus-mat-container"></div>
|
213 |
-
<div id="corpus-similar-sentences-div"></div>
|
214 |
-
</div>
|
215 |
-
</div>
|
216 |
-
</div>
|
217 |
-
|
218 |
-
</div>
|
219 |
-
</div>
|
220 |
|
221 |
<script src="vendor.js"></script>
|
222 |
<script src="main.js"></script>
|
|
|
34 |
|
35 |
</div>
|
36 |
|
37 |
+
<div id="attention-vis"></div>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
<script src="vendor.js"></script>
|
40 |
<script src="main.js"></script>
|
client/src/ts/api/mainApi.ts
CHANGED
@@ -134,56 +134,4 @@ export class API {
|
|
134 |
|
135 |
return checkDemoAPI(toSend, url, payload)
|
136 |
}
|
137 |
-
|
138 |
-
/**
|
139 |
-
*
|
140 |
-
* @param embedding Embedding of the word
|
141 |
-
* @param layer In the l'th layer
|
142 |
-
* @param k how many results to retrieve
|
143 |
-
*/
|
144 |
-
getNearestEmbeddings(model: string, corpus: string, embedding: number[], layer: number, heads: number[], k = 10, hashObj: {} | null = null): Promise<rsp.NearestNeighborResponse> {
|
145 |
-
const toSend = {
|
146 |
-
model: model,
|
147 |
-
corpus: corpus,
|
148 |
-
embedding: embedding,
|
149 |
-
layer: layer,
|
150 |
-
heads: heads,
|
151 |
-
k: k,
|
152 |
-
}
|
153 |
-
|
154 |
-
const url = makeUrl(this.baseURL + '/k-nearest-embeddings', toSend);
|
155 |
-
console.log("--- GET " + url);
|
156 |
-
|
157 |
-
if (hashObj != null) {
|
158 |
-
const key = hash.sha1(toSend)
|
159 |
-
d3.json(url).then(r => {
|
160 |
-
hashObj[key] = r;
|
161 |
-
})
|
162 |
-
}
|
163 |
-
|
164 |
-
return checkDemoAPI(toSend, url)
|
165 |
-
}
|
166 |
-
|
167 |
-
getNearestContexts(model: string, corpus: string, context: number[], layer: number, heads: number[], k = 10, hashObj: {} | null = null): Promise<rsp.NearestNeighborResponse> {
|
168 |
-
const toSend = {
|
169 |
-
model: model,
|
170 |
-
corpus: corpus,
|
171 |
-
context: context,
|
172 |
-
layer: layer,
|
173 |
-
heads: heads,
|
174 |
-
k: k,
|
175 |
-
}
|
176 |
-
|
177 |
-
const url = makeUrl(this.baseURL + '/k-nearest-contexts', toSend);
|
178 |
-
console.log("--- GET " + url);
|
179 |
-
|
180 |
-
if (hashObj != null) {
|
181 |
-
const key = hash.sha1(toSend)
|
182 |
-
d3.json(url).then(r => {
|
183 |
-
hashObj[key] = r;
|
184 |
-
})
|
185 |
-
}
|
186 |
-
|
187 |
-
return checkDemoAPI(toSend, url)
|
188 |
-
}
|
189 |
-
};
|
|
|
134 |
|
135 |
return checkDemoAPI(toSend, url, payload)
|
136 |
}
|
137 |
+
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
client/src/ts/api/responses.ts
CHANGED
@@ -16,8 +16,4 @@ export interface ModelDetailResponse extends BaseResponse {
|
|
16 |
|
17 |
export interface AttentionDetailsResponse extends BaseResponse {
|
18 |
payload: tp.AttentionResponse
|
19 |
-
}
|
20 |
-
|
21 |
-
export interface NearestNeighborResponse extends BaseResponse {
|
22 |
-
payload: tp.FaissSearchResults[]
|
23 |
}
|
|
|
16 |
|
17 |
export interface AttentionDetailsResponse extends BaseResponse {
|
18 |
payload: tp.AttentionResponse
|
|
|
|
|
|
|
|
|
19 |
}
|
client/src/ts/data/FaissSearchWrapper.ts
DELETED
@@ -1,127 +0,0 @@
|
|
1 |
-
import * as tp from '../etc/types'
|
2 |
-
import * as d3 from 'd3'
|
3 |
-
import 'd3-array'
|
4 |
-
import * as R from 'ramda'
|
5 |
-
import {SpacyInfo} from '../etc/SpacyInfo'
|
6 |
-
import {initZero} from '../etc/xramda'
|
7 |
-
|
8 |
-
// If value is not a string, don't try to make lowercase
|
9 |
-
const makeStringLower = R.ifElse(R.is(String), R.toLower, R.identity)
|
10 |
-
|
11 |
-
function argMax(array:number[]) {
|
12 |
-
return [].map.call(array, (x, i) => [x, i]).reduce((r, a) => (a[0] > r[0] ? a : r))[1];
|
13 |
-
}
|
14 |
-
|
15 |
-
|
16 |
-
export class FaissSearchResultWrapper {
|
17 |
-
data: tp.FaissSearchResults[]
|
18 |
-
|
19 |
-
options = {
|
20 |
-
showNext: false
|
21 |
-
}
|
22 |
-
|
23 |
-
constructor(data: tp.FaissSearchResults[], showNext=false) {
|
24 |
-
this.data = data
|
25 |
-
this.options.showNext = showNext
|
26 |
-
}
|
27 |
-
|
28 |
-
get matchAtt() {
|
29 |
-
return this.showNext() ? "matched_att_plus_1" : "matched_att"
|
30 |
-
}
|
31 |
-
|
32 |
-
get matchIdx() {
|
33 |
-
return this.showNext() ? "next_index" : "index"
|
34 |
-
}
|
35 |
-
|
36 |
-
/**
|
37 |
-
* Add position info interpretable by the histogram
|
38 |
-
*
|
39 |
-
* @param countObj Represents the inforrmation to be displayed by the histogram
|
40 |
-
*/
|
41 |
-
countPosInfo() {
|
42 |
-
const attOffsets = this.data.map((d,i) => +d[this.matchAtt].out.offset_to_max)
|
43 |
-
|
44 |
-
const ctObj = {
|
45 |
-
offset: initZero(attOffsets)
|
46 |
-
}
|
47 |
-
|
48 |
-
attOffsets.forEach(v => {
|
49 |
-
Object.keys(ctObj).forEach((k) => {
|
50 |
-
ctObj[k][v] += 1
|
51 |
-
})
|
52 |
-
})
|
53 |
-
|
54 |
-
return ctObj
|
55 |
-
}
|
56 |
-
|
57 |
-
countMaxAttKeys(indexOffset=0) {
|
58 |
-
// The keys in the below object dictate what we count
|
59 |
-
const countObj = {
|
60 |
-
pos: initZero(SpacyInfo.TotalMetaOptions.pos),
|
61 |
-
dep: initZero(SpacyInfo.TotalMetaOptions.dep),
|
62 |
-
is_ent: initZero(SpacyInfo.TotalMetaOptions.is_ent),
|
63 |
-
}
|
64 |
-
|
65 |
-
// Confusing: Show MATCHED WORD attentions, but NEXT WORD distribution
|
66 |
-
const getMaxToken = (d: tp.FaissSearchResults) => d.tokens[argMax(d.matched_att.out.att)]
|
67 |
-
|
68 |
-
this.data.forEach((d, i) => {
|
69 |
-
const maxMatch = getMaxToken(d)
|
70 |
-
|
71 |
-
Object.keys(countObj).forEach(k => {
|
72 |
-
const val = makeStringLower(String(maxMatch[k]))
|
73 |
-
countObj[k][val] += 1;
|
74 |
-
})
|
75 |
-
})
|
76 |
-
|
77 |
-
const newCountObj = Object.assign(countObj, this.countPosInfo())
|
78 |
-
return newCountObj
|
79 |
-
}
|
80 |
-
|
81 |
-
countMatchedKeys(indexOffset=0) {
|
82 |
-
// The keys in the below object dictate what we count
|
83 |
-
const countObj = {
|
84 |
-
pos: initZero(SpacyInfo.TotalMetaOptions.pos),
|
85 |
-
dep: initZero(SpacyInfo.TotalMetaOptions.dep),
|
86 |
-
is_ent: initZero(SpacyInfo.TotalMetaOptions.is_ent),
|
87 |
-
}
|
88 |
-
|
89 |
-
this.data.forEach(d => {
|
90 |
-
// Confusing: Show MATCHED WORD attentions, but NEXT WORD distribution
|
91 |
-
const match = d.tokens[d[this.matchIdx] + indexOffset]
|
92 |
-
|
93 |
-
Object.keys(countObj).forEach(k => {
|
94 |
-
const val = makeStringLower(String(match[k]))
|
95 |
-
countObj[k][val] += 1;
|
96 |
-
})
|
97 |
-
})
|
98 |
-
|
99 |
-
return countObj
|
100 |
-
}
|
101 |
-
|
102 |
-
getMatchedHistogram(indexOffset=0) {
|
103 |
-
const totalHist = this.countMatchedKeys(indexOffset)
|
104 |
-
const filterZeros = (val, key) => val != 0;
|
105 |
-
const nonZero = R.map(R.pickBy(filterZeros), totalHist)
|
106 |
-
|
107 |
-
return nonZero
|
108 |
-
}
|
109 |
-
|
110 |
-
getMaxAttHistogram() {
|
111 |
-
// const totalHist = this.countPosInfo()
|
112 |
-
const newHist = this.countMaxAttKeys()
|
113 |
-
const filterZeros = (val, key) => val != 0;
|
114 |
-
const nonZero = R.map(R.pickBy(filterZeros), newHist)
|
115 |
-
|
116 |
-
return nonZero
|
117 |
-
}
|
118 |
-
|
119 |
-
showNext(): boolean
|
120 |
-
showNext(v:boolean): this
|
121 |
-
showNext(v?) {
|
122 |
-
if (v == null) return this.options.showNext
|
123 |
-
|
124 |
-
this.options.showNext = v
|
125 |
-
return this
|
126 |
-
}
|
127 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
client/src/ts/data/TokenWrapper.ts
CHANGED
@@ -8,12 +8,6 @@ import * as R from 'ramda'
|
|
8 |
*/
|
9 |
const emptyFullResponse: tp.FullSingleTokenInfo[] = [{
|
10 |
text: '[SEP]',
|
11 |
-
embeddings: [],
|
12 |
-
contexts: [],
|
13 |
-
bpe_token: '',
|
14 |
-
bpe_pos: '',
|
15 |
-
bpe_dep: '',
|
16 |
-
bpe_is_ent: null,
|
17 |
topk_words: [],
|
18 |
topk_probs: []
|
19 |
}]
|
|
|
8 |
*/
|
9 |
const emptyFullResponse: tp.FullSingleTokenInfo[] = [{
|
10 |
text: '[SEP]',
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
topk_words: [],
|
12 |
topk_probs: []
|
13 |
}]
|
client/src/ts/etc/SpacyInfo.ts
DELETED
@@ -1,61 +0,0 @@
|
|
1 |
-
import * as tp from './types'
|
2 |
-
import * as d3 from 'd3'
|
3 |
-
import * as R from 'ramda'
|
4 |
-
import {COLORS200} from '../etc/colors'
|
5 |
-
|
6 |
-
export class SpacyInfo {
|
7 |
-
colorScale:tp.ColorMetaScale
|
8 |
-
|
9 |
-
constructor(){
|
10 |
-
this.colorScale = this.createColorScales();
|
11 |
-
}
|
12 |
-
|
13 |
-
static EnglishMetaOptions: tp.MetaOptions = {
|
14 |
-
pos: ['punct', 'sym', 'x', 'adj', 'verb', 'conj', 'num', 'et', 'adv', 'x', 'adp', 'noun', 'propn', 'part', 'pron', 'space', 'intj'],
|
15 |
-
dep: ['root', 'ROOT', 'acl', 'acomp', 'advcl', 'advmod', 'agent', 'amod', 'appos', 'attr', 'aux', 'auxpass', 'case', 'cc', 'ccomp', 'compound', 'conj', 'cop', 'csubj',
|
16 |
-
'csubjpass', 'dative', 'dep', 'det', 'dobj', 'expl', 'intj', 'mark', 'meta', 'neg', 'nn', 'nounmod', 'npmod', 'nsubj', 'nsubjpass', 'nummod', 'oprd',
|
17 |
-
'obj', 'obl', 'parataxis', 'pcomp', 'pobj', 'poss', 'preconj', 'predet', 'prep', 'prt', 'punct', 'quantmod', 'relcl', 'root', 'xcomp', 'npadvmod'],
|
18 |
-
is_ent: [true, false],
|
19 |
-
ents: ['person', 'norp', 'fac', 'org', 'gpe', 'loc', 'product', 'event', 'work_of_art', 'law', 'language', 'date', 'time', 'percent', 'money', 'quantity', 'ordinal',
|
20 |
-
'cardinal'],
|
21 |
-
}
|
22 |
-
|
23 |
-
/**
|
24 |
-
* Obsolete. Represents the information that is included when trained on the universal corpus
|
25 |
-
*/
|
26 |
-
static UniversalMetaOptions: tp.MetaOptions = {
|
27 |
-
pos: ['adj', 'adp', 'adv', 'aux', 'conj', 'cconj', 'det', 'intj', 'noun', 'num', 'part', 'pron', 'propn', 'punct', 'sconj', 'sym', 'verb', 'x', 'space'],
|
28 |
-
dep: ['acl', 'advcl', 'advmod', 'amod', 'appos', 'aux', 'case', 'cc', 'ccomp', 'clf', 'compound', 'conj', 'cop', 'csubj', 'dep', 'det', 'discourse',
|
29 |
-
'dislocated', 'expl', 'fixed', 'flat', 'goeswith', 'iobj', 'list', 'mark', 'nmod', 'nsubj', 'nummod', 'obj', 'obl', 'orphan', 'parataxis', 'punct', 'reparandum',
|
30 |
-
'root', 'vocative', 'xcomp'],
|
31 |
-
is_ent: [true, false],
|
32 |
-
ents: ['person', 'norp', 'fac', 'org', 'gpe', 'loc', 'product', 'event', 'work_of_art', 'law', 'language', 'date', 'time', 'percent', 'money', 'quantity', 'ordinal',
|
33 |
-
'cardinal'],
|
34 |
-
}
|
35 |
-
|
36 |
-
static TotalMetaOptions: tp.MetaOptions = {
|
37 |
-
pos: R.union(SpacyInfo.EnglishMetaOptions.pos, SpacyInfo.UniversalMetaOptions.pos),
|
38 |
-
dep: SpacyInfo.EnglishMetaOptions.dep,
|
39 |
-
is_ent: SpacyInfo.EnglishMetaOptions.is_ent,
|
40 |
-
ents: SpacyInfo.EnglishMetaOptions.ents,
|
41 |
-
}
|
42 |
-
|
43 |
-
createColorScales(): tp.ColorMetaScale{
|
44 |
-
const toScale = (keys: Array<number|string|boolean>) => {
|
45 |
-
const obj = R.zipObj(R.map(String, keys), COLORS200.slice(0, keys.length))
|
46 |
-
return k => R.propOr("black", k, obj)
|
47 |
-
}
|
48 |
-
|
49 |
-
const myColors = {
|
50 |
-
pos: toScale(SpacyInfo.TotalMetaOptions.pos),
|
51 |
-
dep: toScale(SpacyInfo.TotalMetaOptions.dep),
|
52 |
-
is_ent: toScale(SpacyInfo.TotalMetaOptions.is_ent),
|
53 |
-
ents: toScale(SpacyInfo.TotalMetaOptions.ents),
|
54 |
-
offset: d3.scaleOrdinal().range(['black'])
|
55 |
-
}
|
56 |
-
|
57 |
-
return <tp.ColorMetaScale><unknown>myColors
|
58 |
-
}
|
59 |
-
}
|
60 |
-
|
61 |
-
export const spacyColors = new SpacyInfo();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
client/src/ts/etc/types.ts
CHANGED
@@ -11,26 +11,21 @@ export type ModelInfo = {
|
|
11 |
nheads: number
|
12 |
}
|
13 |
|
14 |
-
type AbstractAttentionResponse<T> = {
|
15 |
-
aa: T
|
16 |
-
}
|
17 |
|
18 |
/**
|
19 |
* ATTENTION RESULTS FROM BACKEND
|
20 |
*
|
21 |
* These are the results that are encased in the 'aa' and 'ab' keys returned
|
22 |
*/
|
|
|
|
|
|
|
|
|
23 |
export type AttentionResponse = AbstractAttentionResponse<AttentionMetaResult>
|
24 |
export type AttentionMetaResult = AbstractAttentionResult<FullSingleTokenInfo[]>
|
25 |
|
26 |
export type FullSingleTokenInfo = {
|
27 |
text: string,
|
28 |
-
embeddings: number[],
|
29 |
-
contexts: number[],
|
30 |
-
bpe_token: string,
|
31 |
-
bpe_pos: string,
|
32 |
-
bpe_dep: string,
|
33 |
-
bpe_is_ent: boolean,
|
34 |
topk_words: string[],
|
35 |
topk_probs: number[]
|
36 |
}
|
@@ -57,28 +52,6 @@ interface MatchedAttentions {
|
|
57 |
out: MatchedTokAtt,
|
58 |
}
|
59 |
|
60 |
-
export interface FaissSearchResults {
|
61 |
-
sentence: string
|
62 |
-
index: number
|
63 |
-
next_index: number
|
64 |
-
match: string
|
65 |
-
match_plus_1: string
|
66 |
-
matched_att: MatchedAttentions
|
67 |
-
matched_att_plus_1: MatchedAttentions
|
68 |
-
tokens: TokenFaissMatch[]
|
69 |
-
}
|
70 |
-
|
71 |
-
export interface TokenFaissMatch {
|
72 |
-
token: string
|
73 |
-
pos: string
|
74 |
-
dep: string
|
75 |
-
is_ent: string
|
76 |
-
is_match: boolean
|
77 |
-
is_next_word: boolean
|
78 |
-
inward: number[]
|
79 |
-
outward: number[]
|
80 |
-
}
|
81 |
-
|
82 |
/**
|
83 |
* EVENT TYPES
|
84 |
*/
|
@@ -102,10 +75,7 @@ export type HeadBoxEvent = {
|
|
102 |
* MISCELLANEOUS TYPES
|
103 |
*/
|
104 |
|
105 |
-
export type SentenceOptions = "ab" | "ba" | "aa" | "bb" | "all";
|
106 |
export type SideOptions = "left" | "right"
|
107 |
-
export type SimpleMeta = "pos" | "dep" | "is_ent"
|
108 |
-
export type TokenOptions = "a" | "b" | "all"
|
109 |
|
110 |
export enum Toggled {
|
111 |
ADDED = 0,
|
@@ -113,35 +83,14 @@ export enum Toggled {
|
|
113 |
}
|
114 |
|
115 |
export enum NormBy {
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
}
|
120 |
-
|
121 |
-
export interface AbstractMetaOptions {
|
122 |
-
pos: string[],
|
123 |
-
dep: string[],
|
124 |
-
is_ent: any,
|
125 |
-
ents: string[],
|
126 |
-
}
|
127 |
-
|
128 |
-
export interface MetaOptions extends AbstractMetaOptions {
|
129 |
-
is_ent: boolean[],
|
130 |
-
}
|
131 |
-
|
132 |
-
export interface ColorMetaOptions extends AbstractMetaOptions {
|
133 |
-
is_ent: string[] // Representing hex colors
|
134 |
-
}
|
135 |
-
|
136 |
-
export interface ColorMetaScale {
|
137 |
-
pos: (d: string) => string,
|
138 |
-
dep: (d: string) => string,
|
139 |
-
is_ent: (d: string) => string,
|
140 |
-
ents: (d: string) => string,
|
141 |
-
offset?: (d: string) => string,
|
142 |
}
|
143 |
|
144 |
export enum ModelKind {
|
145 |
Bidirectional = "bidirectional",
|
146 |
Autoregressive = "autoregressive"
|
147 |
-
}
|
|
|
|
|
|
11 |
nheads: number
|
12 |
}
|
13 |
|
|
|
|
|
|
|
14 |
|
15 |
/**
|
16 |
* ATTENTION RESULTS FROM BACKEND
|
17 |
*
|
18 |
* These are the results that are encased in the 'aa' and 'ab' keys returned
|
19 |
*/
|
20 |
+
type AbstractAttentionResponse<T> = {
|
21 |
+
aa: T
|
22 |
+
}
|
23 |
+
|
24 |
export type AttentionResponse = AbstractAttentionResponse<AttentionMetaResult>
|
25 |
export type AttentionMetaResult = AbstractAttentionResult<FullSingleTokenInfo[]>
|
26 |
|
27 |
export type FullSingleTokenInfo = {
|
28 |
text: string,
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
topk_words: string[],
|
30 |
topk_probs: number[]
|
31 |
}
|
|
|
52 |
out: MatchedTokAtt,
|
53 |
}
|
54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
/**
|
56 |
* EVENT TYPES
|
57 |
*/
|
|
|
75 |
* MISCELLANEOUS TYPES
|
76 |
*/
|
77 |
|
|
|
78 |
export type SideOptions = "left" | "right"
|
|
|
|
|
79 |
|
80 |
export enum Toggled {
|
81 |
ADDED = 0,
|
|
|
83 |
}
|
84 |
|
85 |
export enum NormBy {
|
86 |
+
ROW = 0,
|
87 |
+
COL,
|
88 |
+
ALL
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
}
|
90 |
|
91 |
export enum ModelKind {
|
92 |
Bidirectional = "bidirectional",
|
93 |
Autoregressive = "autoregressive"
|
94 |
+
}
|
95 |
+
export type TokenOptions = "a" | "b" | "all"
|
96 |
+
export type SentenceOptions = "ab" | "ba" | "aa" | "bb" | "all";
|
client/src/ts/main.ts
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
-
import { MainGraphic } from './vis/
|
|
|
2 |
import { API, emptyTokenDisplay } from './api/mainApi'
|
3 |
import * as _ from 'lodash'
|
4 |
import { TokenWrapper } from './data/TokenWrapper'
|
@@ -8,112 +9,14 @@ import "!file-loader?name=exBERT.html!../exBERT.html";
|
|
8 |
import "!file-loader?name=index.html!../index.html";
|
9 |
import "../css/main.scss"
|
10 |
|
11 |
-
|
12 |
function doMySvg() {
|
13 |
-
|
|
|
14 |
};
|
15 |
|
16 |
-
/**
|
17 |
-
* Create the static files needed for the demo. Save the keys and file paths to a json object that is then written to a file
|
18 |
-
*
|
19 |
-
* This will print the object after every call. When the key length is the expected length, right click in chrome and select "save as global variable"
|
20 |
-
*
|
21 |
-
* Then, in the console, type "copy(temp1)". Use sublime text (it is the best for handling large files) to paste this into the code and save it as ____.json
|
22 |
-
*
|
23 |
-
* @param sentence - The sentence to analyze
|
24 |
-
* @param maskInd - Which index to mask in the sentence. Atm, can only record one masking
|
25 |
-
* @param outDictPath - Where to save the object of hashkey: filepath
|
26 |
-
*/
|
27 |
-
function createDemos(sentence, maskInd: number, modelName: string, corpusName: string, outDictPath) {
|
28 |
-
const api = new API()
|
29 |
-
const layers = _.range(12)
|
30 |
-
|
31 |
-
const L = 0
|
32 |
-
|
33 |
-
const contentHash = {} // Map hash -> contents
|
34 |
-
|
35 |
-
// Get the base return for all page initializations
|
36 |
-
_.range(12).forEach(L => {
|
37 |
-
api.getMetaAttentions(modelName, sentence, L, contentHash).then(r0 => {
|
38 |
-
const tokCapsule = new TokenWrapper(r0.payload);
|
39 |
-
|
40 |
-
// Unmasked response:
|
41 |
-
api.updateMaskedAttentions(modelName, tokCapsule.a, sentence, L, contentHash).then(r1 => {
|
42 |
-
// Masked word and searching responses:
|
43 |
-
tokCapsule.a.mask(maskInd)
|
44 |
-
api.updateMaskedAttentions(modelName, tokCapsule.a, sentence, L, contentHash).then(r2 => {
|
45 |
-
// Get search results by embedding
|
46 |
-
const embedding = r2['aa']['left'][maskInd].embeddings
|
47 |
-
api.getNearestEmbeddings(modelName, corpusName, embedding, L, _.range(12), 50, contentHash).then(x => {
|
48 |
-
})
|
49 |
-
|
50 |
-
// Get search results by context
|
51 |
-
const context = r2['aa']['left'][maskInd].contexts
|
52 |
-
api.getNearestContexts(modelName, corpusName, context, L, _.range(12), 50, contentHash).then(x => {
|
53 |
-
console.log(Object.keys(contentHash).length);
|
54 |
-
console.log(contentHash);
|
55 |
-
})
|
56 |
-
})
|
57 |
-
})
|
58 |
-
})
|
59 |
-
})
|
60 |
-
}
|
61 |
-
|
62 |
-
/**
|
63 |
-
*
|
64 |
-
* Observe how the demo creation process works.
|
65 |
-
*
|
66 |
-
* If desired to mask multiple words in the input for demo purposes, try looping over the mask inds and masking each one individually
|
67 |
-
*
|
68 |
-
* @param sentence The demo sentence
|
69 |
-
* @param maskInd Desired index to mask (can currently only accept a single mask index)
|
70 |
-
* @param outDictPath
|
71 |
-
*/
|
72 |
-
function inspectDemos(sentence, maskInd: number, modelName: string, corpusName: string, outDictPath) {
|
73 |
-
const api = new API()
|
74 |
-
|
75 |
-
const contentHash = {}
|
76 |
-
|
77 |
-
// Get the base return for all page initializations
|
78 |
-
_.range(1).forEach(L => {
|
79 |
-
api.getMetaAttentions(modelName, sentence, L, "").then(r0 => {
|
80 |
-
const tokCapsule = new TokenWrapper(r0.payload);
|
81 |
-
|
82 |
-
// Unmasked response:
|
83 |
-
api.updateMaskedAttentions(modelName, tokCapsule.a, sentence, L, emptyTokenDisplay).then(r1 => {
|
84 |
-
// Masked word and searching responses:
|
85 |
-
tokCapsule.a.mask(maskInd)
|
86 |
-
api.updateMaskedAttentions(modelName, tokCapsule.a, sentence, L, emptyTokenDisplay).then(r2 => {
|
87 |
-
console.log(r2);
|
88 |
-
// Get search results by embedding
|
89 |
-
const embedding = r2['aa']['left'][maskInd].embeddings
|
90 |
-
api.getNearestEmbeddings(modelName, corpusName, embedding, L, _.range(12), 50, contentHash).then(x => {
|
91 |
-
})
|
92 |
-
|
93 |
-
// Get search results by context
|
94 |
-
const context = r2['aa']['left'][maskInd].contexts
|
95 |
-
api.getNearestContexts(modelName, corpusName, context, L, _.range(12), 50).then(x => {
|
96 |
-
})
|
97 |
-
})
|
98 |
-
})
|
99 |
-
})
|
100 |
-
})
|
101 |
-
}
|
102 |
-
|
103 |
-
function replTest() {
|
104 |
-
// Tester.testAttWrapperConstructor()
|
105 |
-
// Tester.testUpdateMaskedAttention()
|
106 |
-
// Tester.testNjAray();
|
107 |
-
// Tester.testRandomArrayCreation();
|
108 |
-
// Tester.testFaissWrapper();
|
109 |
-
// Tester.testD3Ordinal();
|
110 |
-
// Tester.testFaissSearchResultsHist();
|
111 |
-
// Tester.testReadingJSON();
|
112 |
-
}
|
113 |
-
|
114 |
window.onload = () => {
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
console.log("Done loading window");
|
119 |
}
|
|
|
1 |
+
import { MainGraphic } from './vis/attentionVis'
|
2 |
+
import * as d3 from 'd3'
|
3 |
import { API, emptyTokenDisplay } from './api/mainApi'
|
4 |
import * as _ from 'lodash'
|
5 |
import { TokenWrapper } from './data/TokenWrapper'
|
|
|
9 |
import "!file-loader?name=index.html!../index.html";
|
10 |
import "../css/main.scss"
|
11 |
|
|
|
12 |
function doMySvg() {
|
13 |
+
const base = document.getElementById('static-init')
|
14 |
+
return new MainGraphic(base)
|
15 |
};
|
16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
window.onload = () => {
|
18 |
+
const base = document.getElementById('attention-vis')
|
19 |
+
//@ts-ignore
|
20 |
+
const mainVis = new MainGraphic(base)
|
21 |
console.log("Done loading window");
|
22 |
}
|
client/src/ts/uiConfig.ts
CHANGED
@@ -16,16 +16,11 @@ interface URLParameters {
|
|
16 |
sentence?: string
|
17 |
model?: string
|
18 |
modelKind?: string
|
19 |
-
corpus?: string
|
20 |
layer?: number
|
21 |
heads?: number[]
|
22 |
threshold?: number
|
23 |
tokenInd?: number | 'null'
|
24 |
tokenSide?: tp.SideOptions
|
25 |
-
metaMatch?: tp.SimpleMeta | null
|
26 |
-
metaMax?: tp.SimpleMeta | null
|
27 |
-
displayInspector?: InspectorOptions
|
28 |
-
offsetIdxs?: number[]
|
29 |
maskInds?: number[]
|
30 |
hideClsSep?: boolean
|
31 |
}
|
@@ -34,19 +29,23 @@ export class UIConfig {
|
|
34 |
|
35 |
private _conf: URLParameters = {}
|
36 |
private _headSet: Set<number>;
|
37 |
-
attType:
|
38 |
_nHeads: number | null;
|
39 |
_nLayers: number | null;
|
40 |
private _token: tp.TokenEvent;
|
41 |
|
42 |
constructor() {
|
43 |
-
this._nHeads = 12;
|
44 |
this._nLayers = null;
|
45 |
-
this.attType = 'aa'
|
46 |
this.fromURL()
|
47 |
this.toURL(false)
|
48 |
}
|
49 |
|
|
|
|
|
|
|
|
|
50 |
|
51 |
fromURL() {
|
52 |
const params = URLHandler.parameters
|
@@ -55,17 +54,12 @@ export class UIConfig {
|
|
55 |
model: params['model'] || 'bert-base-cased',
|
56 |
modelKind: params['modelKind'] || tp.ModelKind.Bidirectional,
|
57 |
sentence: params['sentence'] || "The girl ran to a local pub to escape the din of her city.",
|
58 |
-
corpus: params['corpus'] || 'woz',
|
59 |
layer: params['layer'] || 1,
|
60 |
heads: this._initHeads(params['heads']),
|
61 |
threshold: params['threshold'] || 0.7,
|
62 |
tokenInd: params['tokenInd'] || null,
|
63 |
tokenSide: params['tokenSide'] || null,
|
64 |
maskInds: params['maskInds'] || [9],
|
65 |
-
metaMatch: params['metaMatch'] || "pos",
|
66 |
-
metaMax: params['metaMax'] || "pos",
|
67 |
-
displayInspector: params['displayInspector'] || null,
|
68 |
-
offsetIdxs: this._initOffsetIdxs(params['offsetIdxs']),
|
69 |
hideClsSep: truthy(params['hideClsSep']) || true,
|
70 |
}
|
71 |
|
@@ -73,20 +67,6 @@ export class UIConfig {
|
|
73 |
|
74 |
}
|
75 |
|
76 |
-
toURL(updateHistory = false) {
|
77 |
-
URLHandler.updateUrl(this._conf, updateHistory)
|
78 |
-
}
|
79 |
-
|
80 |
-
private _initOffsetIdxs(v: (string | number)[] | null) {
|
81 |
-
if (v == null) {
|
82 |
-
return [-1, 0, 1]
|
83 |
-
}
|
84 |
-
else {
|
85 |
-
const numberArr = R.map(toNumber, v);
|
86 |
-
return numberArr;
|
87 |
-
}
|
88 |
-
}
|
89 |
-
|
90 |
private _initHeads(v: number[] | null) {
|
91 |
if (v == null || v.length < 1) {
|
92 |
this.selectAllHeads()
|
@@ -237,26 +217,6 @@ export class UIConfig {
|
|
237 |
return this
|
238 |
}
|
239 |
|
240 |
-
metaMatch(): tp.SimpleMeta;
|
241 |
-
metaMatch(val: tp.SimpleMeta): this;
|
242 |
-
metaMatch(val?) {
|
243 |
-
if (val == null) return this._conf.metaMax;
|
244 |
-
|
245 |
-
this._conf.metaMax = val;
|
246 |
-
this.toURL();
|
247 |
-
return this;
|
248 |
-
}
|
249 |
-
|
250 |
-
metaMax(): tp.SimpleMeta;
|
251 |
-
metaMax(val: tp.SimpleMeta): this;
|
252 |
-
metaMax(val?) {
|
253 |
-
if (val == null) return this._conf.metaMatch;
|
254 |
-
|
255 |
-
this._conf.metaMatch = val;
|
256 |
-
this.toURL();
|
257 |
-
return this;
|
258 |
-
}
|
259 |
-
|
260 |
maskInds(): number[];
|
261 |
maskInds(val: number[]): this;
|
262 |
maskInds(val?) {
|
@@ -267,28 +227,6 @@ export class UIConfig {
|
|
267 |
return this;
|
268 |
}
|
269 |
|
270 |
-
displayInspector(): InspectorOptions;
|
271 |
-
displayInspector(val: InspectorOptions): this;
|
272 |
-
displayInspector(val?) {
|
273 |
-
if (val == null) return this._conf.displayInspector;
|
274 |
-
|
275 |
-
this._conf.displayInspector = val;
|
276 |
-
this.toURL();
|
277 |
-
return this;
|
278 |
-
}
|
279 |
-
|
280 |
-
offsetIdxs(): number[];
|
281 |
-
offsetIdxs(val: number[]): this;
|
282 |
-
offsetIdxs(val?) {
|
283 |
-
if (val == null) return this._conf.offsetIdxs;
|
284 |
-
|
285 |
-
// convert to numbers
|
286 |
-
|
287 |
-
this._conf.offsetIdxs = R.map(toNumber, val);
|
288 |
-
this.toURL();
|
289 |
-
return this;
|
290 |
-
}
|
291 |
-
|
292 |
hideClsSep(): boolean;
|
293 |
hideClsSep(val: boolean): this;
|
294 |
hideClsSep(val?) {
|
@@ -341,13 +279,4 @@ export class UIConfig {
|
|
341 |
get matchHistogramDescription() {
|
342 |
return this.modelKind() == tp.ModelKind.Autoregressive ? "Next" : "Matched"
|
343 |
}
|
344 |
-
|
345 |
-
corpus(): string;
|
346 |
-
corpus(val: string): this;
|
347 |
-
corpus(val?) {
|
348 |
-
if (val == null) return this._conf.corpus
|
349 |
-
this._conf.corpus = val
|
350 |
-
this.toURL();
|
351 |
-
return this
|
352 |
-
}
|
353 |
}
|
|
|
16 |
sentence?: string
|
17 |
model?: string
|
18 |
modelKind?: string
|
|
|
19 |
layer?: number
|
20 |
heads?: number[]
|
21 |
threshold?: number
|
22 |
tokenInd?: number | 'null'
|
23 |
tokenSide?: tp.SideOptions
|
|
|
|
|
|
|
|
|
24 |
maskInds?: number[]
|
25 |
hideClsSep?: boolean
|
26 |
}
|
|
|
29 |
|
30 |
private _conf: URLParameters = {}
|
31 |
private _headSet: Set<number>;
|
32 |
+
attType: "aa"
|
33 |
_nHeads: number | null;
|
34 |
_nLayers: number | null;
|
35 |
private _token: tp.TokenEvent;
|
36 |
|
37 |
constructor() {
|
38 |
+
this._nHeads = 12;
|
39 |
this._nLayers = null;
|
40 |
+
this.attType = 'aa'
|
41 |
this.fromURL()
|
42 |
this.toURL(false)
|
43 |
}
|
44 |
|
45 |
+
toURL(updateHistory = false) {
|
46 |
+
URLHandler.updateUrl(this._conf, updateHistory)
|
47 |
+
}
|
48 |
+
|
49 |
|
50 |
fromURL() {
|
51 |
const params = URLHandler.parameters
|
|
|
54 |
model: params['model'] || 'bert-base-cased',
|
55 |
modelKind: params['modelKind'] || tp.ModelKind.Bidirectional,
|
56 |
sentence: params['sentence'] || "The girl ran to a local pub to escape the din of her city.",
|
|
|
57 |
layer: params['layer'] || 1,
|
58 |
heads: this._initHeads(params['heads']),
|
59 |
threshold: params['threshold'] || 0.7,
|
60 |
tokenInd: params['tokenInd'] || null,
|
61 |
tokenSide: params['tokenSide'] || null,
|
62 |
maskInds: params['maskInds'] || [9],
|
|
|
|
|
|
|
|
|
63 |
hideClsSep: truthy(params['hideClsSep']) || true,
|
64 |
}
|
65 |
|
|
|
67 |
|
68 |
}
|
69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
private _initHeads(v: number[] | null) {
|
71 |
if (v == null || v.length < 1) {
|
72 |
this.selectAllHeads()
|
|
|
217 |
return this
|
218 |
}
|
219 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
220 |
maskInds(): number[];
|
221 |
maskInds(val: number[]): this;
|
222 |
maskInds(val?) {
|
|
|
227 |
return this;
|
228 |
}
|
229 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
230 |
hideClsSep(): boolean;
|
231 |
hideClsSep(val: boolean): this;
|
232 |
hideClsSep(val?) {
|
|
|
279 |
get matchHistogramDescription() {
|
280 |
return this.modelKind() == tp.ModelKind.Autoregressive ? "Next" : "Matched"
|
281 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
282 |
}
|
client/src/ts/vis/AttentionConnector.ts
CHANGED
@@ -61,11 +61,11 @@ export class AttentionGraph extends VComponent<AttentionData>{
|
|
61 |
// Define whether to use the 'j' or 'i' attribute to calculate opacities
|
62 |
private scaleIdx(): "i" | "j" {
|
63 |
switch (this.normBy) {
|
64 |
-
case tp.NormBy.
|
65 |
return 'j'
|
66 |
-
case tp.NormBy.
|
67 |
return 'i'
|
68 |
-
case tp.NormBy.
|
69 |
return 'i'
|
70 |
|
71 |
}
|
@@ -165,7 +165,7 @@ export class AttentionGraph extends VComponent<AttentionData>{
|
|
165 |
|
166 |
// Group normalization
|
167 |
switch (this.normBy){
|
168 |
-
case tp.NormBy.
|
169 |
arr = this.edgeData.extent(1);
|
170 |
this.opacityScales = [];
|
171 |
arr.forEach((v, i) => {
|
@@ -176,7 +176,7 @@ export class AttentionGraph extends VComponent<AttentionData>{
|
|
176 |
)
|
177 |
})
|
178 |
break;
|
179 |
-
case tp.NormBy.
|
180 |
arr = this.edgeData.extent(0);
|
181 |
this.opacityScales = [];
|
182 |
arr.forEach((v, i) => {
|
@@ -187,7 +187,7 @@ export class AttentionGraph extends VComponent<AttentionData>{
|
|
187 |
)
|
188 |
})
|
189 |
break;
|
190 |
-
case tp.NormBy.
|
191 |
const maxIn = d3.max(this.plotData.map((d) => d.v))
|
192 |
for (let i = 0; i < this._data.length; i++) {
|
193 |
this.opacityScales.push(d3.scaleLinear()
|
|
|
61 |
// Define whether to use the 'j' or 'i' attribute to calculate opacities
|
62 |
private scaleIdx(): "i" | "j" {
|
63 |
switch (this.normBy) {
|
64 |
+
case tp.NormBy.COL:
|
65 |
return 'j'
|
66 |
+
case tp.NormBy.ROW:
|
67 |
return 'i'
|
68 |
+
case tp.NormBy.ALL:
|
69 |
return 'i'
|
70 |
|
71 |
}
|
|
|
165 |
|
166 |
// Group normalization
|
167 |
switch (this.normBy){
|
168 |
+
case tp.NormBy.ROW:
|
169 |
arr = this.edgeData.extent(1);
|
170 |
this.opacityScales = [];
|
171 |
arr.forEach((v, i) => {
|
|
|
176 |
)
|
177 |
})
|
178 |
break;
|
179 |
+
case tp.NormBy.COL:
|
180 |
arr = this.edgeData.extent(0);
|
181 |
this.opacityScales = [];
|
182 |
arr.forEach((v, i) => {
|
|
|
187 |
)
|
188 |
})
|
189 |
break;
|
190 |
+
case tp.NormBy.ALL:
|
191 |
const maxIn = d3.max(this.plotData.map((d) => d.v))
|
192 |
for (let i = 0; i < this._data.length; i++) {
|
193 |
this.opacityScales.push(d3.scaleLinear()
|
client/src/ts/vis/CorpusHistogram.ts
DELETED
@@ -1,253 +0,0 @@
|
|
1 |
-
import {VComponent} from './VisComponent'
|
2 |
-
import {spacyColors} from '../etc/SpacyInfo'
|
3 |
-
import {SVG} from '../etc/SVGplus'
|
4 |
-
import * as d3 from 'd3'
|
5 |
-
import * as R from 'ramda'
|
6 |
-
import { D3Sel } from '../etc/Util';
|
7 |
-
import { SimpleEventHandler } from '../etc/SimpleEventHandler';
|
8 |
-
|
9 |
-
interface MarginInfo {
|
10 |
-
top: number,
|
11 |
-
bottom: number,
|
12 |
-
right: number,
|
13 |
-
left: number
|
14 |
-
}
|
15 |
-
|
16 |
-
// Dependent on the options in the response
|
17 |
-
type MatchedMetaSelections = "pos" | "dep" | "ent"
|
18 |
-
|
19 |
-
interface MatchedMetaCount {
|
20 |
-
pos: number
|
21 |
-
dep: number
|
22 |
-
is_ent: number
|
23 |
-
}
|
24 |
-
|
25 |
-
interface MaxAttMetaCount {
|
26 |
-
offset: number
|
27 |
-
}
|
28 |
-
|
29 |
-
type MatchedDataInterface = MatchedMetaCount
|
30 |
-
type MaxAttDataInterface = MaxAttMetaCount
|
31 |
-
type DataInterface = MatchedDataInterface | MaxAttDataInterface
|
32 |
-
|
33 |
-
interface CountedHist {
|
34 |
-
label: string,
|
35 |
-
count: number
|
36 |
-
}
|
37 |
-
|
38 |
-
type RenderDataInterface = CountedHist[]
|
39 |
-
|
40 |
-
|
41 |
-
/**
|
42 |
-
* Data formatting functions
|
43 |
-
*/
|
44 |
-
const toRenderData = (obj: {[s: string]: number}): RenderDataInterface => Object.keys(obj).map((k, i) => {
|
45 |
-
return {label: k, count: obj[k]}
|
46 |
-
})
|
47 |
-
|
48 |
-
const toStringOrNum = (a:string) => {
|
49 |
-
const na = +a
|
50 |
-
if (isNaN(na)) {
|
51 |
-
return a
|
52 |
-
}
|
53 |
-
return na
|
54 |
-
}
|
55 |
-
|
56 |
-
const sortByLabel = R.sortBy(R.compose(toStringOrNum, R.prop('label')))
|
57 |
-
const sortByCount = R.sortBy(R.prop('count'))
|
58 |
-
|
59 |
-
const toOrderedRender = R.compose(
|
60 |
-
R.reverse,
|
61 |
-
// @ts-ignore -- TODO: fix
|
62 |
-
sortByCount,
|
63 |
-
toRenderData
|
64 |
-
)
|
65 |
-
|
66 |
-
export class CorpusHistogram<T> extends VComponent<T> {
|
67 |
-
|
68 |
-
css_name = ''
|
69 |
-
|
70 |
-
static events = {}
|
71 |
-
|
72 |
-
_current = {
|
73 |
-
chart: {
|
74 |
-
height: null,
|
75 |
-
width: null
|
76 |
-
}
|
77 |
-
}
|
78 |
-
|
79 |
-
// D3 COMPONENTS
|
80 |
-
svg: D3Sel
|
81 |
-
|
82 |
-
options: {
|
83 |
-
margin: MarginInfo
|
84 |
-
barWidth: number
|
85 |
-
width: number
|
86 |
-
height: number
|
87 |
-
val: string
|
88 |
-
xLabelRot: number
|
89 |
-
xLabelOffset: number
|
90 |
-
yLabelOffset: number
|
91 |
-
}
|
92 |
-
|
93 |
-
axes = {
|
94 |
-
x: d3.scaleBand(),
|
95 |
-
y: d3.scaleLinear(),
|
96 |
-
}
|
97 |
-
|
98 |
-
|
99 |
-
constructor(d3parent: D3Sel, eventHandler?: SimpleEventHandler, options={}) {
|
100 |
-
super(d3parent, eventHandler)
|
101 |
-
this.options = {
|
102 |
-
margin: {
|
103 |
-
top: 10,
|
104 |
-
right: 30,
|
105 |
-
bottom: 50,
|
106 |
-
left: 40
|
107 |
-
},
|
108 |
-
barWidth: 25,
|
109 |
-
width: 185,
|
110 |
-
height: 230,
|
111 |
-
val: "pos", // Change Default, pass through constructor
|
112 |
-
xLabelRot: 45,
|
113 |
-
xLabelOffset: 15,
|
114 |
-
yLabelOffset: 5,
|
115 |
-
|
116 |
-
}
|
117 |
-
this.superInitSVG()
|
118 |
-
}
|
119 |
-
|
120 |
-
meta():MatchedMetaSelections
|
121 |
-
meta(val:MatchedMetaSelections): this
|
122 |
-
meta(val?) {
|
123 |
-
if (val == null) {
|
124 |
-
return this.options.val;
|
125 |
-
}
|
126 |
-
|
127 |
-
this.options.val = val;
|
128 |
-
this.update(this._data)
|
129 |
-
|
130 |
-
return this;
|
131 |
-
}
|
132 |
-
|
133 |
-
_init() {}
|
134 |
-
|
135 |
-
private createXAxis() {
|
136 |
-
const self = this;
|
137 |
-
const op = this.options;
|
138 |
-
const width = op.width - op.margin.left - op.margin.right
|
139 |
-
|
140 |
-
this.axes.x
|
141 |
-
.domain(R.map(R.prop('label'), self.renderData))
|
142 |
-
.rangeRound([0, width])
|
143 |
-
.padding(0.1)
|
144 |
-
|
145 |
-
this._current.chart.width = width;
|
146 |
-
}
|
147 |
-
|
148 |
-
private createYAxis() {
|
149 |
-
const self = this;
|
150 |
-
const op = this.options;
|
151 |
-
const height = op.height - op.margin.top - op.margin.bottom
|
152 |
-
|
153 |
-
this.axes.y
|
154 |
-
.domain([0, +d3.max(R.map(R.prop('count'), self.renderData))])
|
155 |
-
.rangeRound([height, 0])
|
156 |
-
|
157 |
-
this._current.chart.height = height;
|
158 |
-
}
|
159 |
-
|
160 |
-
private createAxes() {
|
161 |
-
this.createXAxis()
|
162 |
-
this.createYAxis()
|
163 |
-
}
|
164 |
-
|
165 |
-
_wrangle(data: DataInterface) {
|
166 |
-
const out = data[this.options.val]
|
167 |
-
return toOrderedRender(out)
|
168 |
-
}
|
169 |
-
|
170 |
-
width():number
|
171 |
-
width(val:number):this
|
172 |
-
width(val?) {
|
173 |
-
if (val == null) {
|
174 |
-
return this.options.width;
|
175 |
-
}
|
176 |
-
this.options.width = val;
|
177 |
-
this.updateWidth();
|
178 |
-
this.createXAxis();
|
179 |
-
return this;
|
180 |
-
}
|
181 |
-
|
182 |
-
height():number
|
183 |
-
height(val:number):this
|
184 |
-
height(val?) {
|
185 |
-
if (val == null) {
|
186 |
-
return this.options.height;
|
187 |
-
}
|
188 |
-
|
189 |
-
this.options.height = val;
|
190 |
-
this.updateHeight();
|
191 |
-
this.createYAxis();
|
192 |
-
return this;
|
193 |
-
}
|
194 |
-
|
195 |
-
private updateWidth() {
|
196 |
-
this.svg.attr('width', this.options.width)
|
197 |
-
}
|
198 |
-
|
199 |
-
private updateHeight() {
|
200 |
-
this.svg.attr('height', this.options.height)
|
201 |
-
}
|
202 |
-
|
203 |
-
private figWidth(data: RenderDataInterface) {
|
204 |
-
const op = this.options;
|
205 |
-
return (data.length * op.barWidth) + op.margin.left + op.margin.right
|
206 |
-
}
|
207 |
-
|
208 |
-
_render(data:RenderDataInterface) {
|
209 |
-
const self = this;
|
210 |
-
const op = this.options;
|
211 |
-
const curr = this._current;
|
212 |
-
|
213 |
-
this.parent.html('')
|
214 |
-
this.svg = this.parent
|
215 |
-
|
216 |
-
this.createAxes();
|
217 |
-
this.width(this.figWidth(data));
|
218 |
-
this.updateHeight();
|
219 |
-
|
220 |
-
// Initialize axes
|
221 |
-
const g = self.svg.append("g")
|
222 |
-
.attr("transform", SVG.translate({x: op.margin.left, y:op.margin.top}))
|
223 |
-
|
224 |
-
// Hack to allow clearing this histograms to work
|
225 |
-
self.base = g
|
226 |
-
|
227 |
-
// Fix below for positional changing
|
228 |
-
const axisBottom = g.append("g")
|
229 |
-
.attr("transform", SVG.translate({x: 0, y:curr.chart.height}))
|
230 |
-
.call(d3.axisBottom(self.axes.x))
|
231 |
-
|
232 |
-
if (op.val != "offset") {
|
233 |
-
axisBottom
|
234 |
-
.selectAll("text")
|
235 |
-
.attr("y", op.yLabelOffset) // Move below the axis
|
236 |
-
.attr("x", op.xLabelOffset) // Offset to the right a bit
|
237 |
-
.attr("transform", SVG.rotate(op.xLabelRot))
|
238 |
-
}
|
239 |
-
|
240 |
-
g.append("g")
|
241 |
-
.call(d3.axisLeft(self.axes.y))
|
242 |
-
|
243 |
-
g.selectAll(".bar")
|
244 |
-
.data(data)
|
245 |
-
.join('rect')
|
246 |
-
.attr("class", "bar")
|
247 |
-
.attr("x", function(d) { return self.axes.x(d.label); })
|
248 |
-
.attr("y", function(d) { return self.axes.y(d.count); })
|
249 |
-
.attr("width", self.axes.x.bandwidth())
|
250 |
-
.attr("height", function(d) { return curr.chart.height - self.axes.y(d.count); })
|
251 |
-
.style('fill', k => spacyColors.colorScale[op.val](k.label))
|
252 |
-
}
|
253 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
client/src/ts/vis/CorpusInspector.ts
DELETED
@@ -1,150 +0,0 @@
|
|
1 |
-
import * as d3 from "d3";
|
2 |
-
import * as R from 'ramda'
|
3 |
-
import 'd3-selection-multi'
|
4 |
-
import {d3S, D3Sel} from "../etc/Util";
|
5 |
-
import { VComponent } from "./VisComponent";
|
6 |
-
import { SimpleEventHandler } from "../etc/SimpleEventHandler";
|
7 |
-
import * as tp from "../etc/types"
|
8 |
-
import '../etc/xd3'
|
9 |
-
|
10 |
-
// Helpers
|
11 |
-
const currMatchIdx = (elem) => +(<Element>elem.parentNode).getAttribute('matchidx')
|
12 |
-
const currRowNum = (elem) => +(<Element>elem.parentNode).getAttribute('rownum')
|
13 |
-
const backgroundColor = x => `rgba(128, 0, 150, ${0.6*x})`
|
14 |
-
|
15 |
-
export class CorpusInspector extends VComponent<tp.FaissSearchResults[]>{
|
16 |
-
css_name = 'corpus-inspector';
|
17 |
-
_current: {};
|
18 |
-
|
19 |
-
_data: tp.FaissSearchResults[]; // The passed data
|
20 |
-
|
21 |
-
static events = {
|
22 |
-
rowMouseOver: "CorpusInspector_rowMouseOver",
|
23 |
-
rowMouseOut: "CorpusInspector_rowMouseOut",
|
24 |
-
rowClick: "CorpusInspector_rowClick",
|
25 |
-
rowDblClick: "CorpusInspector_rowDblClick",
|
26 |
-
cellMouseOver: "CorpusInspector_cellMouseOver",
|
27 |
-
cellMouseOut: "CorpusInspector_cellMouseOut",
|
28 |
-
cellClick: "CorpusInspector_cellClick",
|
29 |
-
cellDblClick: "CorpusInspector_cellDblClick",
|
30 |
-
}
|
31 |
-
|
32 |
-
options = {
|
33 |
-
showNext: false
|
34 |
-
}
|
35 |
-
|
36 |
-
// COMPONENTS
|
37 |
-
inspectorRows: D3Sel
|
38 |
-
inspectorCells: D3Sel
|
39 |
-
scaler = d3.scalePow().range([0,0.9]).exponent(2)
|
40 |
-
|
41 |
-
constructor(d3Parent: D3Sel, eventHandler?:SimpleEventHandler, options: {} = {}) {
|
42 |
-
super(d3Parent, eventHandler)
|
43 |
-
this.superInitHTML(options)
|
44 |
-
this._init()
|
45 |
-
}
|
46 |
-
|
47 |
-
private createRows() {
|
48 |
-
const data = this._data
|
49 |
-
|
50 |
-
this.inspectorRows = this.base.selectAll(".inspector-row")
|
51 |
-
.data(data)
|
52 |
-
.join('div')
|
53 |
-
.classed('inspector-row', true)
|
54 |
-
.attrs({
|
55 |
-
matchIdx: d => d.index,
|
56 |
-
rowNum: (d, i) => i,
|
57 |
-
})
|
58 |
-
.on("mouseover", (d, i) => {
|
59 |
-
this.eventHandler.trigger(CorpusInspector.events.rowMouseOver, {})
|
60 |
-
})
|
61 |
-
}
|
62 |
-
|
63 |
-
private addTooltip() {
|
64 |
-
this.inspectorCells = this.inspectorCells
|
65 |
-
.classed('celltooltip', true)
|
66 |
-
.append('span')
|
67 |
-
.classed('tooltiptext', true)
|
68 |
-
.html((d, i, n) => {
|
69 |
-
const entityStr = d.is_ent ? "<br>Entity" : ""
|
70 |
-
const att = (<Element>n[i].parentNode).getAttribute('att').slice(0, 7)
|
71 |
-
const attStr = `<br>Attention: ${att}`
|
72 |
-
|
73 |
-
return `POS: ${d.pos.toLowerCase()}<br>DEP: ${d.dep.toLowerCase()}` + entityStr + attStr
|
74 |
-
})
|
75 |
-
}
|
76 |
-
|
77 |
-
private createCells() {
|
78 |
-
const self = this
|
79 |
-
|
80 |
-
this.inspectorCells = this.inspectorRows.selectAll('.inspector-cell')
|
81 |
-
.data((d:tp.FaissSearchResults) => d.tokens)
|
82 |
-
.join('div')
|
83 |
-
.classed('inspector-cell', true)
|
84 |
-
.attr('index-offset', (d, i, n:HTMLElement[]) => {
|
85 |
-
const matchIdx = currMatchIdx(n[i])
|
86 |
-
return i - matchIdx
|
87 |
-
})
|
88 |
-
.attrs({
|
89 |
-
pos: d => d.pos.toLowerCase(),
|
90 |
-
dep: d => d.dep.toLowerCase(),
|
91 |
-
is_ent: d => d.is_ent
|
92 |
-
})
|
93 |
-
.text(d => d.token.replace("\u0120", " "))
|
94 |
-
.classed('matched-cell', d => d.is_match)
|
95 |
-
.classed('next-cell', function(d) {
|
96 |
-
return self.showNext() && d.is_next_word
|
97 |
-
})
|
98 |
-
.classed('gray-cell', function(d, i) {
|
99 |
-
const idx = +currMatchIdx(this)
|
100 |
-
return self.showNext() && i > idx
|
101 |
-
})
|
102 |
-
|
103 |
-
// Highlight the cells appropriately
|
104 |
-
this.inspectorCells.each((d,i,n) => {
|
105 |
-
const idx = currMatchIdx(n[i])
|
106 |
-
if (i == idx) {
|
107 |
-
const att = d.inward
|
108 |
-
const maxAtt = +d3.max(att)
|
109 |
-
const currRow = currRowNum(n[i])
|
110 |
-
const scaler = self.scaler.domain([0, maxAtt])
|
111 |
-
|
112 |
-
d3.selectAll(`.inspector-row[rownum='${currRow}']`)
|
113 |
-
.selectAll(`.inspector-cell`)
|
114 |
-
.style('background', (d, i) => {
|
115 |
-
return backgroundColor(scaler(att[i]))
|
116 |
-
})
|
117 |
-
.attr('att', (d, i) => att[i])
|
118 |
-
}
|
119 |
-
})
|
120 |
-
|
121 |
-
self.addTooltip()
|
122 |
-
}
|
123 |
-
|
124 |
-
private updateData() {
|
125 |
-
this.createRows()
|
126 |
-
this.createCells()
|
127 |
-
}
|
128 |
-
|
129 |
-
_init() {}
|
130 |
-
|
131 |
-
_wrangle(data: tp.FaissSearchResults[]) {
|
132 |
-
this._data = data
|
133 |
-
return data;
|
134 |
-
}
|
135 |
-
|
136 |
-
_render(data: tp.FaissSearchResults[]) {
|
137 |
-
// Remember that this._data is defined in wrangle which should always be called before render
|
138 |
-
// as is defined in the update function
|
139 |
-
this.updateData()
|
140 |
-
}
|
141 |
-
|
142 |
-
showNext(): boolean
|
143 |
-
showNext(v:boolean): this
|
144 |
-
showNext(v?) {
|
145 |
-
if (v == null) return this.options.showNext
|
146 |
-
|
147 |
-
this.options.showNext = v
|
148 |
-
return this
|
149 |
-
}
|
150 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
client/src/ts/vis/CorpusMatManager.ts
DELETED
@@ -1,321 +0,0 @@
|
|
1 |
-
import * as d3 from 'd3'
|
2 |
-
import * as R from 'ramda'
|
3 |
-
import * as tp from '../etc/types'
|
4 |
-
import { D3Sel } from '../etc/Util'
|
5 |
-
import { VComponent } from '../vis/VisComponent'
|
6 |
-
import { SimpleEventHandler } from "../etc/SimpleEventHandler";
|
7 |
-
import { SVG } from "../etc/SVGplus"
|
8 |
-
import { spacyColors } from "../etc/SpacyInfo"
|
9 |
-
import "../etc/xd3"
|
10 |
-
|
11 |
-
// Need additoinal height information to render boxes
|
12 |
-
interface BaseDataInterface extends tp.FaissSearchResults {
|
13 |
-
height: number
|
14 |
-
}
|
15 |
-
type DataInterface = BaseDataInterface[]
|
16 |
-
|
17 |
-
interface ColorMetaBaseData {
|
18 |
-
pos: string
|
19 |
-
dep: string
|
20 |
-
is_ent: boolean
|
21 |
-
token: string
|
22 |
-
}
|
23 |
-
|
24 |
-
type DisplayOptions = "pos" | "dep" | "ent"
|
25 |
-
|
26 |
-
function managerData2MatData(dataIn: DataInterface, indexOffset = 0, toPick = ['pos']) {
|
27 |
-
|
28 |
-
const outOfRangeObj: ColorMetaBaseData = {
|
29 |
-
pos: null,
|
30 |
-
dep: null,
|
31 |
-
is_ent: null,
|
32 |
-
token: null,
|
33 |
-
}
|
34 |
-
|
35 |
-
const chooseProps = R.pick(toPick)
|
36 |
-
|
37 |
-
const dataOut = dataIn.map(d => {
|
38 |
-
const wordIdx = d.index + indexOffset;
|
39 |
-
if ((wordIdx < 0) || (wordIdx >= d.tokens.length)) {
|
40 |
-
return R.assoc('height', d.height, outOfRangeObj)
|
41 |
-
}
|
42 |
-
|
43 |
-
const newObj = chooseProps(d.tokens[wordIdx])
|
44 |
-
|
45 |
-
return R.assoc('height', d.height, newObj)
|
46 |
-
})
|
47 |
-
|
48 |
-
return dataOut
|
49 |
-
}
|
50 |
-
|
51 |
-
|
52 |
-
export class CorpusMatManager extends VComponent<DataInterface>{
|
53 |
-
css_name = 'corpus-mat-container'
|
54 |
-
options = {
|
55 |
-
cellWidth: 10,
|
56 |
-
toPick: ['pos'],
|
57 |
-
idxs: [-1, 0, 1],
|
58 |
-
divHover: {
|
59 |
-
width: 60,
|
60 |
-
height: 40
|
61 |
-
}
|
62 |
-
}
|
63 |
-
|
64 |
-
static events = {
|
65 |
-
mouseOver: "CorpusMatManager_MouseOver",
|
66 |
-
mouseOut: "CorpusMatManager_MouseOut",
|
67 |
-
click: "CorpusMatManager_Click",
|
68 |
-
dblClick: "CorpusMatManager_DblClick",
|
69 |
-
rectMouseOver: "CorpusMatManager_RectMouseOver",
|
70 |
-
rectMouseOut: "CorpusMatManager_RectMouseOut",
|
71 |
-
rectClick: "CorpusMatManager_RectClick",
|
72 |
-
rectDblClick: "CorpusMatManager_RectDblClick",
|
73 |
-
}
|
74 |
-
|
75 |
-
// The d3 components that are saved to make rendering faster
|
76 |
-
corpusMats: D3Sel
|
77 |
-
rowGroups: D3Sel
|
78 |
-
divHover: D3Sel
|
79 |
-
|
80 |
-
_current = {}
|
81 |
-
rowCssName = 'index-match-results'
|
82 |
-
cellCssName = 'index-cell-result'
|
83 |
-
|
84 |
-
_data: DataInterface
|
85 |
-
|
86 |
-
static colorScale: tp.ColorMetaScale = spacyColors.colorScale;
|
87 |
-
|
88 |
-
// Selections
|
89 |
-
constructor(d3parent: D3Sel, eventHandler?: SimpleEventHandler, options = {}) {
|
90 |
-
super(d3parent, eventHandler)
|
91 |
-
this.idxs = [-1, 0, 1];
|
92 |
-
this.superInitHTML(options)
|
93 |
-
this._init()
|
94 |
-
}
|
95 |
-
|
96 |
-
get idxs() {
|
97 |
-
return this.options.idxs;
|
98 |
-
}
|
99 |
-
|
100 |
-
set idxs(val: number[]) {
|
101 |
-
this.options.idxs = val
|
102 |
-
}
|
103 |
-
|
104 |
-
// Create static dom elements
|
105 |
-
_init() {
|
106 |
-
const self = this;
|
107 |
-
this.corpusMats = this.base.selectAll('.corpus-mat')
|
108 |
-
this.rowGroups = this.corpusMats.selectAll(`.${this.rowCssName}`)
|
109 |
-
this.divHover = this.base.append('div')
|
110 |
-
.classed('mat-hover-display', true)
|
111 |
-
.classed('text-center', true)
|
112 |
-
.style('width', String(this.options.divHover.width) + 'px')
|
113 |
-
.style('height', String(this.options.divHover.height) + 'px')
|
114 |
-
|
115 |
-
this.divHover.append('p')
|
116 |
-
}
|
117 |
-
|
118 |
-
pick(val: DisplayOptions) {
|
119 |
-
this.options.toPick = [val]
|
120 |
-
this.redraw()
|
121 |
-
}
|
122 |
-
|
123 |
-
addRight() {
|
124 |
-
const addedIdx = R.last(this.idxs) + 1;
|
125 |
-
this.idxs.push(addedIdx)
|
126 |
-
this.addCorpusMat(addedIdx, "right")
|
127 |
-
}
|
128 |
-
|
129 |
-
addLeft() {
|
130 |
-
const addedIdx = this.idxs[0] - 1;
|
131 |
-
const addDecrementedHead: (x: number[]) => number[] = x => R.insert(0, R.head(x) - 1)(x)
|
132 |
-
this.idxs = addDecrementedHead(this.idxs)
|
133 |
-
this.addCorpusMat(addedIdx, "left")
|
134 |
-
}
|
135 |
-
|
136 |
-
killRight() {
|
137 |
-
this.kill(Math.max(...this.idxs))
|
138 |
-
}
|
139 |
-
|
140 |
-
killLeft() {
|
141 |
-
this.kill(Math.min(...this.idxs))
|
142 |
-
}
|
143 |
-
|
144 |
-
/**
|
145 |
-
* Remove edge value from contained indexes
|
146 |
-
*
|
147 |
-
* @param d Index to remove
|
148 |
-
*/
|
149 |
-
kill(d: number) {
|
150 |
-
if (d != 0) {
|
151 |
-
if (d == Math.min(...this.idxs) || d == Math.max(...this.idxs)) {
|
152 |
-
this.idxs = R.without([d], this.idxs)
|
153 |
-
this.base.selectAll(`.offset-${d}`).remove()
|
154 |
-
}
|
155 |
-
}
|
156 |
-
}
|
157 |
-
|
158 |
-
_wrangle(data: DataInterface) {
|
159 |
-
return data
|
160 |
-
}
|
161 |
-
|
162 |
-
data(val?: DataInterface) {
|
163 |
-
if (val == null) {
|
164 |
-
return this._data;
|
165 |
-
}
|
166 |
-
|
167 |
-
this._data = val;
|
168 |
-
this._updateData();
|
169 |
-
return this;
|
170 |
-
}
|
171 |
-
|
172 |
-
/**
|
173 |
-
* The main rendering code, called whenever the data changes.
|
174 |
-
*/
|
175 |
-
private _updateData() {
|
176 |
-
const self = this;
|
177 |
-
const op = this.options;
|
178 |
-
|
179 |
-
this.base.selectAll('.corpus-mat').remove()
|
180 |
-
|
181 |
-
this.idxs.forEach((idxOffset, i) => {
|
182 |
-
self.addCorpusMat(idxOffset)
|
183 |
-
})
|
184 |
-
}
|
185 |
-
|
186 |
-
/**
|
187 |
-
* Add another word's meta information matrix column to either side of the index
|
188 |
-
*
|
189 |
-
* @param idxOffset Distance of word from matched word in the sentence
|
190 |
-
* @param toThe Indicates adding to the "left" or to the "right" of the index
|
191 |
-
*/
|
192 |
-
addCorpusMat(idxOffset: number, toThe: "right" | "left" = "right") {
|
193 |
-
const self = this;
|
194 |
-
const op = this.options;
|
195 |
-
const boxWidth = op.cellWidth * op.toPick.length;
|
196 |
-
const boxHeight = R.sum(R.map(R.prop('height'), this._data))
|
197 |
-
|
198 |
-
let corpusMat;
|
199 |
-
|
200 |
-
if (toThe == "right") {
|
201 |
-
corpusMat = this.base.append('div')
|
202 |
-
}
|
203 |
-
else if (toThe == "left") {
|
204 |
-
corpusMat = this.base.insert('div', ":first-child")
|
205 |
-
}
|
206 |
-
else {
|
207 |
-
throw Error("toThe must have argument of 'left' or 'right'")
|
208 |
-
}
|
209 |
-
|
210 |
-
corpusMat = corpusMat
|
211 |
-
.data([idxOffset])
|
212 |
-
.attr('class', `corpus-mat offset-${idxOffset}`)
|
213 |
-
.attr('offset', idxOffset)
|
214 |
-
.append('svg')
|
215 |
-
.attrs({
|
216 |
-
width: boxWidth,
|
217 |
-
height: boxHeight,
|
218 |
-
})
|
219 |
-
.on('mouseover', function (d, i) {
|
220 |
-
self.eventHandler.trigger(CorpusMatManager.events.mouseOver, { idx: i, offset: d, val: self.options.toPick[0] })
|
221 |
-
})
|
222 |
-
.on('mouseout', (d, i) => {
|
223 |
-
this.eventHandler.trigger(CorpusMatManager.events.mouseOut, { idx: i, offset: d })
|
224 |
-
})
|
225 |
-
|
226 |
-
this.addRowGroup(corpusMat)
|
227 |
-
}
|
228 |
-
|
229 |
-
/**
|
230 |
-
*
|
231 |
-
* @param mat The base div on which to add matrices and rows
|
232 |
-
*/
|
233 |
-
addRowGroup(mat: D3Sel) {
|
234 |
-
const self = this;
|
235 |
-
const op = this.options;
|
236 |
-
|
237 |
-
const heights = R.map(R.prop('height'), this._data)
|
238 |
-
|
239 |
-
const [heightSum, rawHeightList] = R.mapAccum((x, y) => [R.add(x, y), R.add(x, y)], 0, heights)
|
240 |
-
const fixList: (x: number[]) => number[] = R.compose(R.dropLast(1),
|
241 |
-
// @ts-ignore
|
242 |
-
R.prepend(0)
|
243 |
-
)
|
244 |
-
const heightList = fixList(rawHeightList)
|
245 |
-
|
246 |
-
const rowGroup = mat.selectAll(`.${self.rowCssName}`)
|
247 |
-
.data(d => managerData2MatData(self._data, d, op.toPick))
|
248 |
-
.join("g")
|
249 |
-
.attr("class", (d, i) => {
|
250 |
-
return `${self.rowCssName} ${self.rowCssName}-${i}`
|
251 |
-
})
|
252 |
-
.attr("row-num", (d,i) => i)
|
253 |
-
.attr("height", d => d.height)
|
254 |
-
.attr("transform", (d, i) => {
|
255 |
-
const out = SVG.translate({
|
256 |
-
x: 0,
|
257 |
-
y: heightList[i],
|
258 |
-
})
|
259 |
-
return out
|
260 |
-
})
|
261 |
-
|
262 |
-
op.toPick.forEach(prop => {
|
263 |
-
self.addRect(rowGroup, 0, prop)
|
264 |
-
})
|
265 |
-
}
|
266 |
-
|
267 |
-
addRect(g: D3Sel, xShift: number, prop: string) {
|
268 |
-
const self = this
|
269 |
-
const op = this.options
|
270 |
-
|
271 |
-
const rects = g.append('rect')
|
272 |
-
.attrs({
|
273 |
-
width: op.cellWidth,
|
274 |
-
height: d => d.height - 3,
|
275 |
-
transform: (d, i) => {
|
276 |
-
return SVG.translate({
|
277 |
-
x: xShift,
|
278 |
-
y: 1.5,
|
279 |
-
})
|
280 |
-
},
|
281 |
-
})
|
282 |
-
.style('fill', d => CorpusMatManager.colorScale[prop](d[prop]))
|
283 |
-
|
284 |
-
|
285 |
-
const getBaseX = () => (<HTMLElement>self.base.node()).getBoundingClientRect().left
|
286 |
-
const getBaseY = () => (<HTMLElement>self.base.node()).getBoundingClientRect().top
|
287 |
-
|
288 |
-
g.on('mouseover', function (d, i) {
|
289 |
-
self.divHover.style('visibility', 'visible')
|
290 |
-
// Get offset
|
291 |
-
const col = d3.select(this.parentNode.parentNode) // Column
|
292 |
-
const offset = +col.attr('offset')
|
293 |
-
self.eventHandler.trigger(CorpusMatManager.events.rectMouseOver, {idx: i, offset: offset})
|
294 |
-
})
|
295 |
-
.on('mouseout', function (d, i) {
|
296 |
-
self.divHover.style('visibility', 'hidden')
|
297 |
-
const col = d3.select(this.parentNode.parentNode) // Column
|
298 |
-
const offset = +col.attr('offset')
|
299 |
-
self.eventHandler.trigger(CorpusMatManager.events.rectMouseOut, {idx: i, offset: offset})
|
300 |
-
})
|
301 |
-
.on('mousemove', function(d, i) {
|
302 |
-
const mouse = d3.mouse(self.base.node())
|
303 |
-
const divOffset = [3, 3]
|
304 |
-
const left = mouse[0] + getBaseX() - (op.divHover.width + divOffset[0])
|
305 |
-
const top = mouse[1] + getBaseY() - (op.divHover.height + divOffset[1])
|
306 |
-
self.divHover
|
307 |
-
.style('left', String(left) + 'px')
|
308 |
-
.style('top', String(top) + 'px')
|
309 |
-
.selectAll('p')
|
310 |
-
.text(d[prop])
|
311 |
-
})
|
312 |
-
}
|
313 |
-
|
314 |
-
/**
|
315 |
-
* @param data Data to display
|
316 |
-
*/
|
317 |
-
_render(data: DataInterface) {
|
318 |
-
this._updateData();
|
319 |
-
}
|
320 |
-
|
321 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
client/src/ts/vis/{myMain.ts → attentionVis.ts}
RENAMED
@@ -1,3 +1,7 @@
|
|
|
|
|
|
|
|
|
|
1 |
import * as d3 from 'd3';
|
2 |
import * as _ from "lodash"
|
3 |
import * as R from 'ramda'
|
@@ -9,19 +13,13 @@ import { UIConfig } from '../uiConfig'
|
|
9 |
import { TextTokens, LeftTextToken, RightTextToken } from './TextToken'
|
10 |
import { AttentionHeadBox, getAttentionInfo } from './AttentionHeadBox'
|
11 |
import { AttentionGraph } from './AttentionConnector'
|
12 |
-
import { CorpusInspector } from './CorpusInspector'
|
13 |
import { TokenWrapper, sideToLetter } from '../data/TokenWrapper'
|
14 |
import { AttentionWrapper, makeFromMetaResponse } from '../data/AttentionCapsule'
|
15 |
import { SimpleEventHandler } from '../etc/SimpleEventHandler'
|
16 |
-
import { CorpusMatManager } from '../vis/CorpusMatManager'
|
17 |
-
import { CorpusHistogram } from '../vis/CorpusHistogram'
|
18 |
-
import { FaissSearchResultWrapper } from '../data/FaissSearchWrapper'
|
19 |
import { D3Sel, Sel } from '../etc/Util';
|
20 |
-
import { from, fromEvent
|
21 |
import { switchMap, map, tap } from 'rxjs/operators'
|
22 |
import { BaseType } from "d3";
|
23 |
-
import { SimpleMeta } from "../etc/types";
|
24 |
-
import ChangeEvent = JQuery.ChangeEvent;
|
25 |
|
26 |
|
27 |
function isNullToken(tok: tp.TokenEvent) {
|
@@ -69,8 +67,199 @@ function setSelDisabled(attr: boolean, sel: D3Sel) {
|
|
69 |
sel.attr('disabled', val)
|
70 |
}
|
71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
|
73 |
export class MainGraphic {
|
|
|
74 |
api: API
|
75 |
uiConf: UIConfig
|
76 |
attCapsule: AttentionWrapper
|
@@ -79,75 +268,17 @@ export class MainGraphic {
|
|
79 |
vizs: any // Contains vis components wrapped around parent sel
|
80 |
eventHandler: SimpleEventHandler // Orchestrates events raised from components
|
81 |
|
82 |
-
constructor() {
|
83 |
-
this.api = new API()
|
84 |
-
this.uiConf = new UIConfig()
|
85 |
-
this.skeletonInit()
|
86 |
-
this.mainInit();
|
87 |
-
}
|
88 |
-
|
89 |
/**
|
90 |
-
* Functions that can be called without any information of a response.
|
91 |
*
|
92 |
-
*
|
93 |
*/
|
94 |
-
|
95 |
-
this.
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
modelSelector: d3.select("#model-option-selector"),
|
100 |
-
corpusSelector: d3.select("#corpus-select"),
|
101 |
-
atnHeads: {
|
102 |
-
left: d3.select("#left-att-heads"),
|
103 |
-
right: d3.select("#right-att-heads"),
|
104 |
-
headInfo: d3.select("#head-info-box")
|
105 |
-
.classed('mat-hover-display', true)
|
106 |
-
.classed('text-center', true)
|
107 |
-
.style('width', String(70) + 'px')
|
108 |
-
.style('height', String(30) + 'px')
|
109 |
-
.style('visibillity', 'hidden')
|
110 |
-
},
|
111 |
-
form: {
|
112 |
-
sentenceA: d3.select("#form-sentence-a"),
|
113 |
-
button: d3.select("#update-sentence"),
|
114 |
-
},
|
115 |
-
tokens: {
|
116 |
-
left: d3.select("#left-tokens"),
|
117 |
-
right: d3.select("#right-tokens"),
|
118 |
-
},
|
119 |
-
clsToggle: d3.select("#cls-toggle").select(".switch"),
|
120 |
-
layerCheckboxes: d3.select("#layer-select"),
|
121 |
-
headCheckboxes: d3.select("#head-select"),
|
122 |
-
contextQuery: d3.select("#search-contexts"),
|
123 |
-
embeddingQuery: d3.select("#search-embeddings"),
|
124 |
-
selectedHeads: d3.select("#selected-heads"),
|
125 |
-
headSelectAll: d3.select("#select-all-heads"),
|
126 |
-
headSelectNone: d3.select("#select-no-heads"),
|
127 |
-
testCheckbox: d3.select("#simple-embed-query"),
|
128 |
-
threshSlider: d3.select("#my-range"),
|
129 |
-
corpusInspector: d3.select("#corpus-similar-sentences-div"),
|
130 |
-
corpusMatManager: d3.select("#corpus-mat-container"),
|
131 |
-
corpusMsgBox: d3.select("#corpus-msg-box"),
|
132 |
-
histograms: {
|
133 |
-
matchedWordDescription: d3.select("#match-kind"),
|
134 |
-
matchedWord: d3.select("#matched-histogram-container"),
|
135 |
-
maxAtt: d3.select("#max-att-histogram-container"),
|
136 |
-
},
|
137 |
-
buttons: {
|
138 |
-
killLeft: d3.select("#kill-left"),
|
139 |
-
addLeft: d3.select("#minus-left"),
|
140 |
-
addRight: d3.select("#plus-right"),
|
141 |
-
killRight: d3.select("#kill-right"),
|
142 |
-
refresh: d3.select("#mat-refresh")
|
143 |
-
},
|
144 |
-
metaSelector: {
|
145 |
-
matchedWord: d3.select("#matched-meta-select"),
|
146 |
-
maxAtt: d3.select("#max-att-meta-select")
|
147 |
-
}
|
148 |
-
}
|
149 |
|
150 |
-
this.eventHandler = new SimpleEventHandler(<Element>this.
|
151 |
|
152 |
this.vizs = {
|
153 |
leftHeads: new AttentionHeadBox(this.sels.atnHeads.left, this.eventHandler, { side: "left", }),
|
@@ -157,15 +288,11 @@ export class MainGraphic {
|
|
157 |
right: new RightTextToken(this.sels.tokens.right, this.eventHandler),
|
158 |
},
|
159 |
attentionSvg: new AttentionGraph(this.sels.atnDisplay, this.eventHandler),
|
160 |
-
corpusInspector: new CorpusInspector(this.sels.corpusInspector, this.eventHandler),
|
161 |
-
corpusMatManager: new CorpusMatManager(this.sels.corpusMatManager, this.eventHandler, { idxs: this.uiConf.offsetIdxs() }),
|
162 |
-
histograms: {
|
163 |
-
matchedWord: new CorpusHistogram(this.sels.histograms.matchedWord, this.eventHandler),
|
164 |
-
maxAtt: new CorpusHistogram(this.sels.histograms.maxAtt, this.eventHandler),
|
165 |
-
},
|
166 |
}
|
167 |
|
168 |
this._bindEventHandler()
|
|
|
|
|
169 |
}
|
170 |
|
171 |
private mainInit() {
|
@@ -183,23 +310,14 @@ export class MainGraphic {
|
|
183 |
// Wrap postInit into function so asynchronous call does not mess with necessary inits
|
184 |
const postResponseDisplayCleanup = () => {
|
185 |
this._toggleTokenSel()
|
186 |
-
|
187 |
-
const toDisplay = this.uiConf.displayInspector()
|
188 |
-
this._searchDisabler()
|
189 |
-
|
190 |
-
if (toDisplay == 'context') {
|
191 |
-
this._queryContext()
|
192 |
-
} else if (toDisplay == 'embeddings') {
|
193 |
-
this._queryEmbeddings()
|
194 |
-
}
|
195 |
}
|
196 |
|
197 |
let normBy
|
198 |
if ((this.uiConf.modelKind() == tp.ModelKind.Autoregressive) && (!this.uiConf.hideClsSep())) {
|
199 |
-
normBy = tp.NormBy.
|
200 |
}
|
201 |
else {
|
202 |
-
normBy = tp.NormBy.
|
203 |
}
|
204 |
this.vizs.attentionSvg.normBy = normBy
|
205 |
|
@@ -363,35 +481,9 @@ export class MainGraphic {
|
|
363 |
unselectHead(e.head)
|
364 |
}
|
365 |
|
366 |
-
this._searchDisabler()
|
367 |
this._renderHeadSummary();
|
368 |
this.renderSvg();
|
369 |
})
|
370 |
-
|
371 |
-
this.eventHandler.bind(CorpusMatManager.events.mouseOver, (e: { val: "pos" | "dep" | "is_ent", offset: number }) => {
|
372 |
-
// Uncomment the below if you want to modify the whole column
|
373 |
-
// const selector = `.inspector-cell[index-offset='${e.offset}']`
|
374 |
-
// d3.selectAll(selector).classed("hovered-col", true)
|
375 |
-
})
|
376 |
-
|
377 |
-
this.eventHandler.bind(CorpusMatManager.events.mouseOut, (e: { offset: number, idx: number }) => {
|
378 |
-
// Uncomment the below if you want to modify the whole column
|
379 |
-
// const selector = `.inspector-cell[index-offset='${e.offset}']`
|
380 |
-
// d3.selectAll(selector).classed("hovered-col", false)
|
381 |
-
})
|
382 |
-
|
383 |
-
this.eventHandler.bind(CorpusMatManager.events.rectMouseOver, (e: { offset: number, idx: number }) => {
|
384 |
-
const row = d3.select(`.inspector-row[rownum='${e.idx}']`)
|
385 |
-
const word = row.select(`.inspector-cell[index-offset='${e.offset}']`)
|
386 |
-
word.classed("hovered-col", true)
|
387 |
-
})
|
388 |
-
|
389 |
-
this.eventHandler.bind(CorpusMatManager.events.rectMouseOut, (e: { offset: number, idx: number }) => {
|
390 |
-
const row = d3.select(`.inspector-row[rownum='${e.idx}']`)
|
391 |
-
const word = row.select(`.inspector-cell[index-offset='${e.offset}']`)
|
392 |
-
word.classed("hovered-col", false)
|
393 |
-
})
|
394 |
-
|
395 |
}
|
396 |
|
397 |
private _toggleTokenSel() {
|
@@ -421,8 +513,6 @@ export class MainGraphic {
|
|
421 |
this.grayToggle(+e.ind)
|
422 |
this.markNextToggle(+e.ind, this.tokCapsule.a.length())
|
423 |
}
|
424 |
-
|
425 |
-
this._searchDisabler()
|
426 |
}
|
427 |
|
428 |
/** Gray all tokens that have index greater than ind */
|
@@ -463,189 +553,15 @@ export class MainGraphic {
|
|
463 |
|
464 |
}
|
465 |
|
466 |
-
private _initModelSelection() {
|
467 |
-
const self = this
|
468 |
-
|
469 |
-
// Below are the available models. Will need to choose 3 to be available ONLY
|
470 |
-
const data = [
|
471 |
-
{ name: "bert-base-cased", kind: tp.ModelKind.Bidirectional },
|
472 |
-
{ name: "bert-base-uncased", kind: tp.ModelKind.Bidirectional },
|
473 |
-
{ name: "distilbert-base-uncased", kind: tp.ModelKind.Bidirectional },
|
474 |
-
{ name: "distilroberta-base", kind: tp.ModelKind.Bidirectional },
|
475 |
-
// { name: "roberta-base", kind: tp.ModelKind.Bidirectional },
|
476 |
-
{ name: "gpt2", kind: tp.ModelKind.Autoregressive },
|
477 |
-
// { name: "gpt2-medium", kind: tp.ModelKind.Autoregressive },
|
478 |
-
// { name: "distilgpt2", kind: tp.ModelKind.Autoregressive },
|
479 |
-
]
|
480 |
-
|
481 |
-
const names = R.map(R.prop('name'))(data)
|
482 |
-
const kinds = R.map(R.prop('kind'))(data)
|
483 |
-
const kindmap = R.zipObj(names, kinds)
|
484 |
-
|
485 |
-
this.sels.modelSelector.selectAll('.model-option')
|
486 |
-
.data(data)
|
487 |
-
.join('option')
|
488 |
-
.classed('model-option', true)
|
489 |
-
.property('value', d => d.name)
|
490 |
-
.attr("modelkind", d => d.kind)
|
491 |
-
.text(d => d.name)
|
492 |
-
|
493 |
-
this.sels.modelSelector.property('value', this.uiConf.model());
|
494 |
-
|
495 |
-
this.sels.modelSelector.on('change', function () {
|
496 |
-
const me = d3.select(this)
|
497 |
-
const mname = me.property('value')
|
498 |
-
self.uiConf.model(mname);
|
499 |
-
self.uiConf.modelKind(kindmap[mname]);
|
500 |
-
if (kindmap[mname] == tp.ModelKind.Autoregressive) {
|
501 |
-
console.log("RESETTING MASK INDS");
|
502 |
-
self.uiConf.maskInds([])
|
503 |
-
}
|
504 |
-
self.mainInit();
|
505 |
-
})
|
506 |
-
}
|
507 |
-
|
508 |
-
private _initCorpusSelection() {
|
509 |
-
const data = [
|
510 |
-
{ code: "woz", display: "Wizard of Oz" },
|
511 |
-
{ code: "wiki", display: "Wikipedia" },
|
512 |
-
]
|
513 |
-
|
514 |
-
const self = this
|
515 |
-
self.sels.corpusSelector.selectAll('option')
|
516 |
-
.data(data)
|
517 |
-
.join('option')
|
518 |
-
.property('value', d => d.code)
|
519 |
-
.text(d => d.display)
|
520 |
-
|
521 |
-
this.sels.corpusSelector.on('change', function () {
|
522 |
-
const me = d3.select(this)
|
523 |
-
self.uiConf.corpus(me.property('value'))
|
524 |
-
console.log(self.uiConf.corpus());
|
525 |
-
})
|
526 |
-
|
527 |
-
|
528 |
-
}
|
529 |
-
|
530 |
private _staticInits() {
|
531 |
this._initSentenceForm();
|
532 |
this._initModelSelection();
|
533 |
-
this._initCorpusSelection();
|
534 |
-
this._initQueryForm();
|
535 |
-
this._initAdder();
|
536 |
this._renderHeadSummary();
|
537 |
-
this._initMetaSelectors();
|
538 |
this._initToggle();
|
539 |
this.renderAttHead();
|
540 |
this.renderTokens();
|
541 |
}
|
542 |
|
543 |
-
private _initAdder() {
|
544 |
-
const updateUrlOffsetIdxs = () => {
|
545 |
-
this.uiConf.offsetIdxs(this.vizs.corpusMatManager.idxs)
|
546 |
-
}
|
547 |
-
|
548 |
-
const fixCorpusMatHeights = () => {
|
549 |
-
const newWrapped = this._wrapResults(this.vizs.corpusMatManager.data())
|
550 |
-
this.vizs.corpusMatManager.data(newWrapped.data)
|
551 |
-
updateUrlOffsetIdxs()
|
552 |
-
}
|
553 |
-
|
554 |
-
this.sels.buttons.addRight.on('click', () => {
|
555 |
-
this.vizs.corpusMatManager.addRight()
|
556 |
-
updateUrlOffsetIdxs()
|
557 |
-
})
|
558 |
-
|
559 |
-
this.sels.buttons.addLeft.on('click', () => {
|
560 |
-
this.vizs.corpusMatManager.addLeft()
|
561 |
-
updateUrlOffsetIdxs()
|
562 |
-
})
|
563 |
-
|
564 |
-
this.sels.buttons.killRight.on('click', () => {
|
565 |
-
this.vizs.corpusMatManager.killRight()
|
566 |
-
updateUrlOffsetIdxs()
|
567 |
-
})
|
568 |
-
|
569 |
-
this.sels.buttons.killLeft.on('click', () => {
|
570 |
-
this.vizs.corpusMatManager.killLeft()
|
571 |
-
updateUrlOffsetIdxs()
|
572 |
-
})
|
573 |
-
|
574 |
-
this.sels.buttons.refresh.on('click', () => {
|
575 |
-
fixCorpusMatHeights();
|
576 |
-
})
|
577 |
-
|
578 |
-
const onresize = () => {
|
579 |
-
if (this.sels.corpusInspector.text() != '') fixCorpusMatHeights();
|
580 |
-
}
|
581 |
-
|
582 |
-
window.onresize = onresize
|
583 |
-
}
|
584 |
-
|
585 |
-
private _initMetaSelectors() {
|
586 |
-
this._initMatchedWordSelector(this.sels.metaSelector.matchedWord)
|
587 |
-
this._initMaxAttSelector(this.sels.metaSelector.maxAtt)
|
588 |
-
}
|
589 |
-
|
590 |
-
private _initMaxAttSelector(sel: D3Sel) {
|
591 |
-
const self = this;
|
592 |
-
|
593 |
-
const chooseSelected = (value) => {
|
594 |
-
const ms = sel.selectAll('label')
|
595 |
-
ms.classed('active', false)
|
596 |
-
const el = sel.selectAll(`label[value=${value}]`)
|
597 |
-
el.classed('active', true)
|
598 |
-
}
|
599 |
-
|
600 |
-
chooseSelected(this.uiConf.metaMax())
|
601 |
-
|
602 |
-
const el = sel.selectAll('label')
|
603 |
-
el.on('click', function () {
|
604 |
-
const val = <SimpleMeta>d3.select(this).attr('value');
|
605 |
-
|
606 |
-
// Do toggle
|
607 |
-
sel.selectAll('.active').classed('active', false)
|
608 |
-
d3.select(this).classed('active', true)
|
609 |
-
self.uiConf.metaMax(val)
|
610 |
-
self.vizs.histograms.maxAtt.meta(val)
|
611 |
-
})
|
612 |
-
}
|
613 |
-
|
614 |
-
private _initMatchedWordSelector(sel: D3Sel) {
|
615 |
-
const self = this;
|
616 |
-
|
617 |
-
const chooseSelected = (value) => {
|
618 |
-
const ms = sel.selectAll('label')
|
619 |
-
ms.classed('active', false)
|
620 |
-
const el = sel.selectAll(`label[value=${value}]`)
|
621 |
-
el.classed('active', true)
|
622 |
-
}
|
623 |
-
|
624 |
-
chooseSelected(this.uiConf.metaMatch())
|
625 |
-
|
626 |
-
const el = sel.selectAll('label')
|
627 |
-
el.on('click', function () {
|
628 |
-
const val = <SimpleMeta>d3.select(this).attr('value')
|
629 |
-
|
630 |
-
// Do toggle
|
631 |
-
sel.selectAll('.active').classed('active', false)
|
632 |
-
d3.select(this).classed('active', true)
|
633 |
-
self.uiConf.metaMatch(val)
|
634 |
-
self._updateCorpusInspectorFromMeta(val)
|
635 |
-
})
|
636 |
-
}
|
637 |
-
|
638 |
-
private _disableSearching(attr: boolean) {
|
639 |
-
setSelDisabled(attr, this.sels.contextQuery)
|
640 |
-
setSelDisabled(attr, this.sels.embeddingQuery)
|
641 |
-
}
|
642 |
-
|
643 |
-
private _updateCorpusInspectorFromMeta(val: tp.SimpleMeta) {
|
644 |
-
this.vizs.corpusInspector.showNext(this.uiConf.showNext)
|
645 |
-
this.vizs.corpusMatManager.pick(val)
|
646 |
-
this.vizs.histograms.matchedWord.meta(val)
|
647 |
-
}
|
648 |
-
|
649 |
private _initSentenceForm() {
|
650 |
const self = this;
|
651 |
|
@@ -697,165 +613,11 @@ export class MainGraphic {
|
|
697 |
inputBox.on('keypress', onEnterSubmit)
|
698 |
}
|
699 |
|
700 |
-
private _getSearchEmbeds() {
|
701 |
-
const savedToken = this.uiConf.token();
|
702 |
-
const out = this.vizs.tokens[savedToken.side].getEmbedding(savedToken.ind)
|
703 |
-
return out.embeddings
|
704 |
-
}
|
705 |
-
|
706 |
-
private _getSearchContext() {
|
707 |
-
const savedToken = this.uiConf.token();
|
708 |
-
const out = this.vizs.tokens[savedToken.side].getEmbedding(savedToken.ind)
|
709 |
-
return out.contexts
|
710 |
-
}
|
711 |
-
|
712 |
-
private _searchEmbeddings() {
|
713 |
-
const self = this;
|
714 |
-
console.log("SEARCHING EMBEDDINGS");
|
715 |
-
const embed = this._getSearchEmbeds()
|
716 |
-
const layer = self.uiConf.layer()
|
717 |
-
const heads = self.uiConf.heads()
|
718 |
-
const k = 50
|
719 |
-
self.vizs.corpusInspector.showNext(self.uiConf.showNext)
|
720 |
-
|
721 |
-
this.sels.body.style("cursor", "progress")
|
722 |
-
self.api.getNearestEmbeddings(self.uiConf.model(), self.uiConf.corpus(), embed, layer, heads, k)
|
723 |
-
.then((val: rsp.NearestNeighborResponse) => {
|
724 |
-
if (val.status == 406) {
|
725 |
-
self.leaveCorpusMsg(`Embeddings are not available for model '${self.uiConf.model()}' and corpus '${self.uiConf.corpus()}' at this time.`)
|
726 |
-
}
|
727 |
-
else {
|
728 |
-
const v = val.payload
|
729 |
-
|
730 |
-
self.vizs.corpusInspector.unhideView()
|
731 |
-
self.vizs.corpusMatManager.unhideView()
|
732 |
-
|
733 |
-
// Get heights of corpus inspector rows.
|
734 |
-
self.vizs.corpusInspector.update(v)
|
735 |
-
const wrappedVals = self._wrapResults(v)
|
736 |
-
const countedVals = wrappedVals.getMatchedHistogram()
|
737 |
-
const offsetVals = wrappedVals.getMaxAttHistogram()
|
738 |
-
|
739 |
-
self.vizs.corpusMatManager.update(wrappedVals.data)
|
740 |
-
self.sels.histograms.matchedWordDescription.text(this.uiConf.matchHistogramDescription)
|
741 |
-
console.log("MATCHER: ", self.sels.histograms.matchedWord);
|
742 |
-
self.vizs.histograms.matchedWord.update(countedVals)
|
743 |
-
self.vizs.histograms.maxAtt.update(offsetVals)
|
744 |
-
self.uiConf.displayInspector('embeddings')
|
745 |
-
this._updateCorpusInspectorFromMeta(this.uiConf.metaMatch())
|
746 |
-
}
|
747 |
-
this.sels.body.style("cursor", "default")
|
748 |
-
})
|
749 |
-
}
|
750 |
-
|
751 |
-
private _searchContext() {
|
752 |
-
const self = this;
|
753 |
-
console.log("SEARCHING CONTEXTS");
|
754 |
-
const context = this._getSearchContext()
|
755 |
-
const layer = self.uiConf.layer()
|
756 |
-
const heads = self.uiConf.heads()
|
757 |
-
const k = 50
|
758 |
-
self.vizs.corpusInspector.showNext(self.uiConf.showNext)
|
759 |
-
|
760 |
-
this.sels.body.style("cursor", "progress")
|
761 |
-
|
762 |
-
self.api.getNearestContexts(self.uiConf.model(), self.uiConf.corpus(), context, layer, heads, k)
|
763 |
-
.then((val: rsp.NearestNeighborResponse) => {
|
764 |
-
// Get heights of corpus inspector rows.
|
765 |
-
if (val.status == 406) {
|
766 |
-
console.log("Contexts are not available!");
|
767 |
-
self.leaveCorpusMsg(`Contexts are not available for model '${self.uiConf.model()}' and corpus '${self.uiConf.corpus()}' at this time.`)
|
768 |
-
}
|
769 |
-
else {
|
770 |
-
const v = val.payload;
|
771 |
-
console.log("HIDING");
|
772 |
-
|
773 |
-
self.vizs.corpusInspector.update(v)
|
774 |
-
|
775 |
-
Sel.hideElement(self.sels.corpusMsgBox)
|
776 |
-
self.vizs.corpusInspector.unhideView()
|
777 |
-
self.vizs.corpusMatManager.unhideView()
|
778 |
-
|
779 |
-
const wrappedVals = self._wrapResults(v)
|
780 |
-
const countedVals = wrappedVals.getMatchedHistogram()
|
781 |
-
const offsetVals = wrappedVals.getMaxAttHistogram()
|
782 |
-
self.vizs.corpusMatManager.update(wrappedVals.data)
|
783 |
-
|
784 |
-
self.vizs.histograms.matchedWord.update(countedVals)
|
785 |
-
self.vizs.histograms.maxAtt.update(offsetVals)
|
786 |
-
|
787 |
-
self.uiConf.displayInspector('context')
|
788 |
-
this._updateCorpusInspectorFromMeta(this.uiConf.metaMatch())
|
789 |
-
self.vizs.histograms.maxAtt.meta(self.uiConf.metaMax())
|
790 |
-
}
|
791 |
-
this.sels.body.style("cursor", "default")
|
792 |
-
})
|
793 |
-
}
|
794 |
-
|
795 |
-
private _queryContext() {
|
796 |
-
const self = this;
|
797 |
-
|
798 |
-
if (this.uiConf.hasToken()) {
|
799 |
-
this._searchContext();
|
800 |
-
} else {
|
801 |
-
console.log("Was told to show inspector but was not given a selected token embedding")
|
802 |
-
}
|
803 |
-
}
|
804 |
-
|
805 |
-
private _queryEmbeddings() {
|
806 |
-
const self = this;
|
807 |
-
|
808 |
-
if (this.uiConf.hasToken()) {
|
809 |
-
console.log("token: ", this.uiConf.token());
|
810 |
-
this._searchEmbeddings();
|
811 |
-
} else {
|
812 |
-
console.log("Was told to show inspector but was not given a selected token embedding")
|
813 |
-
}
|
814 |
-
}
|
815 |
-
|
816 |
-
private _searchingDisabled() {
|
817 |
-
return (this.uiConf.heads().length == 0) || (!this.uiConf.hasToken())
|
818 |
-
}
|
819 |
-
|
820 |
-
private _searchDisabler() {
|
821 |
-
this._disableSearching(this._searchingDisabled())
|
822 |
-
}
|
823 |
-
|
824 |
-
private _initQueryForm() {
|
825 |
-
const self = this;
|
826 |
-
this._searchDisabler()
|
827 |
-
|
828 |
-
this.sels.contextQuery.on("click", () => {
|
829 |
-
self._queryContext()
|
830 |
-
})
|
831 |
-
|
832 |
-
this.sels.embeddingQuery.on("click", () => {
|
833 |
-
self._queryEmbeddings()
|
834 |
-
})
|
835 |
-
}
|
836 |
-
|
837 |
private _renderHeadSummary() {
|
838 |
this.sels.selectedHeads
|
839 |
.html(R.join(', ', this.uiConf.heads().map(h => h + 1)))
|
840 |
}
|
841 |
|
842 |
-
// Modify faiss results with corresponding heights
|
843 |
-
private _wrapResults(returnedFaissResults: tp.FaissSearchResults[]) {
|
844 |
-
|
845 |
-
const rows = d3.selectAll('.inspector-row')
|
846 |
-
|
847 |
-
// Don't just use offsetHeight since that rounds to the nearest integer
|
848 |
-
const heights = rows.nodes().map((n: HTMLElement) => n.getBoundingClientRect().height)
|
849 |
-
|
850 |
-
const newVals = returnedFaissResults.map((v, i) => {
|
851 |
-
return R.assoc('height', heights[i], v)
|
852 |
-
})
|
853 |
-
|
854 |
-
const wrappedVals = new FaissSearchResultWrapper(newVals, this.uiConf.showNext)
|
855 |
-
|
856 |
-
return wrappedVals
|
857 |
-
}
|
858 |
-
|
859 |
private initLayers(nLayers: number) {
|
860 |
const self = this;
|
861 |
let hasActive = false;
|
@@ -932,14 +694,12 @@ export class MainGraphic {
|
|
932 |
|
933 |
this.sels.headSelectAll.on("click", function () {
|
934 |
self.uiConf.selectAllHeads();
|
935 |
-
self._searchDisabler()
|
936 |
self.renderSvg()
|
937 |
self.renderAttHead()
|
938 |
})
|
939 |
|
940 |
this.sels.headSelectNone.on("click", function () {
|
941 |
self.uiConf.selectNoHeads();
|
942 |
-
self._searchDisabler();
|
943 |
self.renderSvg()
|
944 |
self.renderAttHead()
|
945 |
Sel.setHidden(".atn-curve")
|
@@ -961,6 +721,48 @@ export class MainGraphic {
|
|
961 |
})
|
962 |
}
|
963 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
964 |
renderAttHead() {
|
965 |
const heads = _.range(0, this.uiConf._nHeads)
|
966 |
const focusAtt = this.attCapsule.att
|
@@ -1020,3 +822,5 @@ export class MainGraphic {
|
|
1020 |
this.render();
|
1021 |
}
|
1022 |
}
|
|
|
|
|
|
1 |
+
/**
|
2 |
+
 * Renders only the top-left (attention) panel of exBERT; no information from the embeddings or the contexts is shown.
|
3 |
+
*/
|
4 |
+
|
5 |
import * as d3 from 'd3';
|
6 |
import * as _ from "lodash"
|
7 |
import * as R from 'ramda'
|
|
|
13 |
import { TextTokens, LeftTextToken, RightTextToken } from './TextToken'
|
14 |
import { AttentionHeadBox, getAttentionInfo } from './AttentionHeadBox'
|
15 |
import { AttentionGraph } from './AttentionConnector'
|
|
|
16 |
import { TokenWrapper, sideToLetter } from '../data/TokenWrapper'
|
17 |
import { AttentionWrapper, makeFromMetaResponse } from '../data/AttentionCapsule'
|
18 |
import { SimpleEventHandler } from '../etc/SimpleEventHandler'
|
|
|
|
|
|
|
19 |
import { D3Sel, Sel } from '../etc/Util';
|
20 |
+
import { from, fromEvent } from 'rxjs'
|
21 |
import { switchMap, map, tap } from 'rxjs/operators'
|
22 |
import { BaseType } from "d3";
|
|
|
|
|
23 |
|
24 |
|
25 |
function isNullToken(tok: tp.TokenEvent) {
|
|
|
67 |
sel.attr('disabled', val)
|
68 |
}
|
69 |
|
70 |
+
/**
 * Build the static DOM skeleton of the attention view inside `base` and
 * return handles (d3 selections) to the parts the rest of the class drives.
 *
 * Elements are appended in a fixed order: the sentence form first, then the
 * connector container holding the controls and the attention display.
 *
 * @param base Selection of the element into which the skeleton is appended
 * @returns Object of d3 selections keyed by role (form inputs, token columns,
 *          head boxes, sliders, buttons, ...)
 */
function createStaticSkeleton(base: D3Sel) {

    /**
     * Top level sections
     */
    const sentenceInput = base.append('div')
        .attr("id", "sentence-input")

    const connectorContainer = base.append('div')
        .attr('id', 'connector-container')

    // The static HTML this replaces marked this element with the *class*
    // `connector-controls`; keep the class next to the new id so both
    // class- and id-based CSS selectors match.
    const atnControls = connectorContainer.append('div')
        .attr("id", "connector-controls")
        .classed("connector-controls", true)

    const atnContainer = connectorContainer.append('div')
        .attr("id", "atn-container")
        .classed("text-center", true)

    /**
     * Sentence Input
     */
    const formGroup = sentenceInput.append('form')
        .append('div').classed('form-group', true)

    formGroup.append('label')
        .attr('for', "form-sentence-a")
        .text(' Input Sentence ')

    const sentenceA = formGroup.append('input')
        .attr('id', 'form-sentence-a')
        .attr('type', 'text')
        .attr('name', 'sent-a-input')

    sentenceInput.append('div')
        .classed('padding', true)

    const formButton = sentenceInput.append('button')
        .attr('class', 'btn btn-primary')
        .attr('id', "update-sentence")
        .attr('type', 'button')
        .text("Update")

    /**
     * Connector Controls
     */
    const leftControlHalf = atnControls.append('div')
        .classed('left-control-half', true)

    const rightControlHalf = atnControls.append('div')
        .attr('class', 'right-control-half head-control')

    const modelSelection = leftControlHalf.append('div')
        .attr('id', 'model-selection')

    // A label's `for` must reference the control's *id*, not its name
    modelSelection.append('label')
        .attr('for', 'model-option-selector').text('Select model')

    const modelSelector = modelSelection.append('select')
        .attr('id', 'model-option-selector')
        .attr('name', 'model-options')

    const slideContainer = leftControlHalf.append('div')
        .classed('slide-container', true)

    slideContainer.append('label')
        .attr('for', 'my-range')
        .html("Display top <span id=\"my-range-value\">...</span>% of attention")

    const threshSlider = slideContainer.append('input')
        .attr('type', 'range')
        .attr('min', '0')
        .attr('max', '100')
        .attr('value', '70')
        .classed('slider', true)
        .attr('id', 'my-range')

    const layerSelection = leftControlHalf.append('div')
        .attr('id', 'layer-selection')

    layerSelection.append('div')
        .classed('input-description', true)
        .text("Layer: ")

    const layerCheckboxes = layerSelection.append('div')
        .attr('class', 'layer-select btn-group btn-group-toggle')
        .attr('data-toggle', 'buttons')
        .attr('id', 'layer-select')

    // NOTE(review): the returned `clsToggle` handle is this wrapper div, not
    // the checkbox/slider inside it -- confirm that is what the toggle
    // initialisation expects.
    const clsToggle = leftControlHalf.append('div')
        .attr('id', 'cls-toggle')

    clsToggle.append('div')
        .attr('class', 'input-description')
        .text("Hide Special Tokens")

    const clsSwitch = clsToggle.append('label')
        .attr('class', 'switch')

    clsSwitch.append('input').attr('type', 'checkbox')
        .attr('checked', 'checked')

    clsSwitch.append('span')
        .attr('class', 'short-slider round')

    const selectedHeadDisplay = rightControlHalf.append('div')
        .attr('id', 'selected-head-display')

    selectedHeadDisplay.append('div')
        .classed('input-description', true)
        .text('Selected heads:')

    // Hand out the inner div: the head summary is written with `.html(...)`,
    // which on the wrapper would erase the "Selected heads:" label.
    const selectedHeads = selectedHeadDisplay.append('div')
        .attr('id', 'selected-heads')

    const headButtons = rightControlHalf.append('div')
        .classed('select-input', true)
        .attr('id', 'head-all-or-none')

    const headSelectAll = headButtons.append('button').attr('id', 'select-all-heads').text("Select all heads")
    const headSelectNone = headButtons.append('button').attr('id', 'select-no-heads').text("Unselect all heads")

    const infoContainer = rightControlHalf.append('div')
        .attr('id', 'usage-info')

    infoContainer.append('p').html("You focus on one token by <b>click</b>.<br /> You can mask any token by <b>double click</b>.")
    infoContainer.append('p').html("You can select and de-select a head by a <b>click</b> on the heatmap columns")

    connectorContainer.append('div').attr('id', 'vis-break')

    /**
     * For main attention vis
     */
    const headInfoBox = atnContainer.append('div')
        .attr('id', "head-info-box")
        .classed('mat-hover-display', true)
        .classed('text-center', true)
        .style('width', '70px')
        .style('height', '30px')
        // Was misspelled 'visibillity', which browsers ignore, leaving the
        // hover box visible before any hover happened.
        .style('visibility', 'hidden')

    const headBoxLeft = atnContainer.append('svg')
        .attr('id', 'left-att-heads')

    const tokensLeft = atnContainer.append('div')
        .attr("id", "left-tokens")

    const atnDisplay = atnContainer.append('svg')
        .attr("id", "atn-display")

    const tokensRight = atnContainer.append('div')
        .attr("id", "right-tokens")

    const headBoxRight = atnContainer.append('svg')
        .attr('id', 'right-att-heads')

    /**
     * Return an object that provides handles to the important parts here
     */
    const sels = {
        body: d3.select('body'),
        atnContainer: atnContainer,
        atnDisplay: atnDisplay,
        atnHeads: {
            left: headBoxLeft,
            right: headBoxRight,
            headInfo: headInfoBox
        },
        form: {
            sentenceA: sentenceA,
            button: formButton
        },
        tokens: {
            left: tokensLeft,
            right: tokensRight
        },
        modelSelector: modelSelector,
        clsToggle: clsToggle,
        layerCheckboxes: layerCheckboxes,
        selectedHeads: selectedHeads,
        headSelectAll: headSelectAll,
        headSelectNone: headSelectNone,
        threshSlider: threshSlider,
    }
    return sels
}
|
260 |
|
261 |
export class MainGraphic {
|
262 |
+
base: D3Sel
|
263 |
api: API
|
264 |
uiConf: UIConfig
|
265 |
attCapsule: AttentionWrapper
|
|
|
268 |
vizs: any // Contains vis components wrapped around parent sel
|
269 |
eventHandler: SimpleEventHandler // Orchestrates events raised from components
|
270 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
271 |
/**
|
|
|
272 |
*
|
273 |
+
 * @param baseDiv 'div' HTML element into which everything below will be rendered
|
274 |
*/
|
275 |
+
constructor(baseDiv: Element) {
|
276 |
+
this.base = d3.select(baseDiv)
|
277 |
+
this.api = new API()
|
278 |
+
this.uiConf = new UIConfig()
|
279 |
+
this.sels = createStaticSkeleton(this.base)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
280 |
|
281 |
+
this.eventHandler = new SimpleEventHandler(<Element>this.base.node());
|
282 |
|
283 |
this.vizs = {
|
284 |
leftHeads: new AttentionHeadBox(this.sels.atnHeads.left, this.eventHandler, { side: "left", }),
|
|
|
288 |
right: new RightTextToken(this.sels.tokens.right, this.eventHandler),
|
289 |
},
|
290 |
attentionSvg: new AttentionGraph(this.sels.atnDisplay, this.eventHandler),
|
|
|
|
|
|
|
|
|
|
|
|
|
291 |
}
|
292 |
|
293 |
this._bindEventHandler()
|
294 |
+
|
295 |
+
this.mainInit()
|
296 |
}
|
297 |
|
298 |
private mainInit() {
|
|
|
310 |
// Wrap postInit into function so asynchronous call does not mess with necessary inits
|
311 |
const postResponseDisplayCleanup = () => {
|
312 |
this._toggleTokenSel()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
313 |
}
|
314 |
|
315 |
let normBy
|
316 |
if ((this.uiConf.modelKind() == tp.ModelKind.Autoregressive) && (!this.uiConf.hideClsSep())) {
|
317 |
+
normBy = tp.NormBy.COL
|
318 |
}
|
319 |
else {
|
320 |
+
normBy = tp.NormBy.ALL
|
321 |
}
|
322 |
this.vizs.attentionSvg.normBy = normBy
|
323 |
|
|
|
481 |
unselectHead(e.head)
|
482 |
}
|
483 |
|
|
|
484 |
this._renderHeadSummary();
|
485 |
this.renderSvg();
|
486 |
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
487 |
}
|
488 |
|
489 |
private _toggleTokenSel() {
|
|
|
513 |
this.grayToggle(+e.ind)
|
514 |
this.markNextToggle(+e.ind, this.tokCapsule.a.length())
|
515 |
}
|
|
|
|
|
516 |
}
|
517 |
|
518 |
/** Gray all tokens that have index greater than ind */
|
|
|
553 |
|
554 |
}
|
555 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
556 |
private _staticInits() {
|
557 |
this._initSentenceForm();
|
558 |
this._initModelSelection();
|
|
|
|
|
|
|
559 |
this._renderHeadSummary();
|
|
|
560 |
this._initToggle();
|
561 |
this.renderAttHead();
|
562 |
this.renderTokens();
|
563 |
}
|
564 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
565 |
private _initSentenceForm() {
|
566 |
const self = this;
|
567 |
|
|
|
613 |
inputBox.on('keypress', onEnterSubmit)
|
614 |
}
|
615 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
616 |
/**
 * Write the currently selected heads into the head summary element as a
 * comma-separated list, each head shown as its index plus one.
 */
private _renderHeadSummary() {
    const shownHeads = this.uiConf.heads().map(h => h + 1)
    const summary = shownHeads.join(', ')
    this.sels.selectedHeads.html(summary)
}
|
620 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
621 |
private initLayers(nLayers: number) {
|
622 |
const self = this;
|
623 |
let hasActive = false;
|
|
|
694 |
|
695 |
this.sels.headSelectAll.on("click", function () {
|
696 |
self.uiConf.selectAllHeads();
|
|
|
697 |
self.renderSvg()
|
698 |
self.renderAttHead()
|
699 |
})
|
700 |
|
701 |
this.sels.headSelectNone.on("click", function () {
|
702 |
self.uiConf.selectNoHeads();
|
|
|
703 |
self.renderSvg()
|
704 |
self.renderAttHead()
|
705 |
Sel.setHidden(".atn-curve")
|
|
|
721 |
})
|
722 |
}
|
723 |
|
724 |
+
/**
 * Populate the model dropdown, preselect the configured model, and
 * re-initialise the view whenever the user picks a different one.
 */
private _initModelSelection() {
    const self = this

    // Models offered in the dropdown. Commented entries are candidates that
    // are currently switched off; only a few should be exposed at a time.
    const models = [
        { name: "bert-base-cased", kind: tp.ModelKind.Bidirectional },
        { name: "bert-base-uncased", kind: tp.ModelKind.Bidirectional },
        { name: "distilbert-base-uncased", kind: tp.ModelKind.Bidirectional },
        { name: "distilroberta-base", kind: tp.ModelKind.Bidirectional },
        // { name: "roberta-base", kind: tp.ModelKind.Bidirectional },
        { name: "gpt2", kind: tp.ModelKind.Autoregressive },
        // { name: "gpt2-medium", kind: tp.ModelKind.Autoregressive },
        // { name: "distilgpt2", kind: tp.ModelKind.Autoregressive },
    ]

    // Lookup table: model name -> model kind
    const kindByName = R.zipObj(
        R.map(R.prop('name'))(models),
        R.map(R.prop('kind'))(models),
    )

    // One <option> per model
    const options = this.sels.modelSelector.selectAll('.model-option')
        .data(models)
        .join('option')

    options
        .classed('model-option', true)
        .property('value', d => d.name)
        .attr("modelkind", d => d.kind)
        .text(d => d.name)

    // Show the model already held in the UI configuration
    this.sels.modelSelector.property('value', this.uiConf.model());

    // `function` (not an arrow) so that `this` is the <select> element
    this.sels.modelSelector.on('change', function () {
        const chosenName = d3.select(this).property('value')
        const chosenKind = kindByName[chosenName]

        self.uiConf.model(chosenName);
        self.uiConf.modelKind(chosenKind);

        // Mask indices are cleared when switching to an autoregressive model
        if (chosenKind == tp.ModelKind.Autoregressive) {
            console.log("RESETTING MASK INDS");
            self.uiConf.maskInds([])
        }

        self.mainInit();
    })
}
|
765 |
+
|
766 |
renderAttHead() {
|
767 |
const heads = _.range(0, this.uiConf._nHeads)
|
768 |
const focusAtt = this.attCapsule.att
|
|
|
822 |
this.render();
|
823 |
}
|
824 |
}
|
825 |
+
|
826 |
+
|
server/main.py
CHANGED
@@ -7,7 +7,6 @@ from flask import render_template, redirect, send_from_directory
|
|
7 |
import utils.path_fixes as pf
|
8 |
from utils.f import ifnone
|
9 |
|
10 |
-
from data_processing import from_model
|
11 |
from model_api import get_details
|
12 |
|
13 |
app = connexion.FlaskApp(__name__, static_folder="client/dist", specification_dir=".")
|
@@ -101,7 +100,7 @@ def get_attentions_and_preds(**request):
|
|
101 |
sentence = request["sentence"]
|
102 |
layer = int(request["layer"])
|
103 |
|
104 |
-
deets = details.
|
105 |
|
106 |
payload_out = deets.to_json(layer)
|
107 |
|
@@ -152,14 +151,14 @@ def update_masked_attention(**request):
|
|
152 |
mask = payload["mask"]
|
153 |
layer = int(payload["layer"])
|
154 |
|
155 |
-
MASK = details.
|
156 |
mask_tokens = lambda toks, maskinds: [
|
157 |
t if i not in maskinds else ifnone(MASK, t) for (i, t) in enumerate(toks)
|
158 |
]
|
159 |
|
160 |
token_inputs = mask_tokens(tokens, mask)
|
161 |
|
162 |
-
deets = details.
|
163 |
payload_out = deets.to_json(layer)
|
164 |
|
165 |
return {
|
|
|
7 |
import utils.path_fixes as pf
|
8 |
from utils.f import ifnone
|
9 |
|
|
|
10 |
from model_api import get_details
|
11 |
|
12 |
app = connexion.FlaskApp(__name__, static_folder="client/dist", specification_dir=".")
|
|
|
100 |
sentence = request["sentence"]
|
101 |
layer = int(request["layer"])
|
102 |
|
103 |
+
deets = details.from_sentence(sentence)
|
104 |
|
105 |
payload_out = deets.to_json(layer)
|
106 |
|
|
|
151 |
mask = payload["mask"]
|
152 |
layer = int(payload["layer"])
|
153 |
|
154 |
+
MASK = details.tok.mask_token
|
155 |
mask_tokens = lambda toks, maskinds: [
|
156 |
t if i not in maskinds else ifnone(MASK, t) for (i, t) in enumerate(toks)
|
157 |
]
|
158 |
|
159 |
token_inputs = mask_tokens(tokens, mask)
|
160 |
|
161 |
+
deets = details.from_tokens(token_inputs, sentence)
|
162 |
payload_out = deets.to_json(layer)
|
163 |
|
164 |
return {
|
server/swagger.yaml
CHANGED
@@ -32,7 +32,7 @@ paths:
|
|
32 |
/attend+meta:
|
33 |
get:
|
34 |
tags: [All]
|
35 |
-
operationId: main.
|
36 |
summary: Get the attention information, BERT Embeddings, and spacy meta info for an input sentence
|
37 |
parameters:
|
38 |
- name: model
|
@@ -66,81 +66,6 @@ paths:
|
|
66 |
200:
|
67 |
description: Update BERT's masked behavior for passed tokens
|
68 |
|
69 |
-
/k-nearest-embeddings:
|
70 |
-
get:
|
71 |
-
tags: [All]
|
72 |
-
operationId: main.nearest_embedding_search
|
73 |
-
summary: Search for the nearest embeddings to a token sent from the frontend by layer
|
74 |
-
parameters:
|
75 |
-
- name: model
|
76 |
-
description: Which model to get information from
|
77 |
-
in: query
|
78 |
-
type: string
|
79 |
-
- name: corpus
|
80 |
-
description: Which corpus to search
|
81 |
-
in: query
|
82 |
-
type: string
|
83 |
-
- name: embedding
|
84 |
-
description: Query vector on which to search the dataset
|
85 |
-
in: query
|
86 |
-
type: array
|
87 |
-
items:
|
88 |
-
type: number
|
89 |
-
- name: layer
|
90 |
-
description: Which layer to search the nearest for
|
91 |
-
in: query
|
92 |
-
type: number
|
93 |
-
- name: heads
|
94 |
-
description: List of heads to search for
|
95 |
-
in: query
|
96 |
-
type: array
|
97 |
-
items:
|
98 |
-
type: number
|
99 |
-
- name: k
|
100 |
-
description: How many nearest neighbors to grab
|
101 |
-
in: query
|
102 |
-
type: number
|
103 |
-
responses:
|
104 |
-
200:
|
105 |
-
description: Return related embeddings and associated metadata
|
106 |
-
|
107 |
-
/k-nearest-contexts:
|
108 |
-
get:
|
109 |
-
tags: [All]
|
110 |
-
operationId: main.nearest_context_search
|
111 |
-
summary: Search for the nearest embeddings BY SELECTED HEADS to a token sent from the frontend by layer
|
112 |
-
parameters:
|
113 |
-
- name: model
|
114 |
-
description: Which model to get information from
|
115 |
-
in: query
|
116 |
-
type: string
|
117 |
-
- name: corpus
|
118 |
-
description: Which corpus to search
|
119 |
-
in: query
|
120 |
-
type: string
|
121 |
-
- name: context
|
122 |
-
description: Query vector on which to search the dataset
|
123 |
-
in: query
|
124 |
-
type: array
|
125 |
-
items:
|
126 |
-
type: number
|
127 |
-
- name: layer
|
128 |
-
description: Which layer to search the nearest for
|
129 |
-
in: query
|
130 |
-
type: number
|
131 |
-
- name: heads
|
132 |
-
description: List of heads to search for
|
133 |
-
in: query
|
134 |
-
type: array
|
135 |
-
items:
|
136 |
-
type: number
|
137 |
-
- name: k
|
138 |
-
description: How many nearest neighbors to grab
|
139 |
-
in: query
|
140 |
-
type: number
|
141 |
-
responses:
|
142 |
-
200:
|
143 |
-
description: Return related embeddings by that head and the associated metadata
|
144 |
|
145 |
definitions:
|
146 |
maskPayload:
|
|
|
32 |
/attend+meta:
|
33 |
get:
|
34 |
tags: [All]
|
35 |
+
operationId: main.get_attentions_and_preds
|
36 |
summary: Get the attention information, BERT Embeddings, and spacy meta info for an input sentence
|
37 |
parameters:
|
38 |
- name: model
|
|
|
66 |
200:
|
67 |
description: Update BERT's masked behavior for passed tokens
|
68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
|
70 |
definitions:
|
71 |
maskPayload:
|
server/transformer_formatter.py
CHANGED
@@ -65,9 +65,19 @@ class TransformerOutputFormatter:
|
|
65 |
self.topk_probs = topk_probs
|
66 |
self.model_config = model_config
|
67 |
|
68 |
-
|
69 |
-
|
70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
|
72 |
self.__len = len(tokens)# Get the number of tokens in the input
|
73 |
assert self.__len == self.attentions[0].shape[-1], "Attentions don't represent the passed tokens!"
|
|
|
65 |
self.topk_probs = topk_probs
|
66 |
self.model_config = model_config
|
67 |
|
68 |
+
try:
|
69 |
+
# GPT vals
|
70 |
+
self.n_layer = self.model_config.n_layer
|
71 |
+
self.n_head = self.model_config.n_head
|
72 |
+
self.hidden_dim = self.model_config.n_embd
|
73 |
+
except AttributeError:
|
74 |
+
try:
|
75 |
+
# BERT vals
|
76 |
+
self.n_layer = self.model_config.num_hidden_layers
|
77 |
+
self.n_head = self.model_config.num_attention_heads
|
78 |
+
self.hidden_dim = self.model_config.hidden_size
|
79 |
+
except AttributeError: raise
|
80 |
+
|
81 |
|
82 |
self.__len = len(tokens)# Get the number of tokens in the input
|
83 |
assert self.__len == self.attentions[0].shape[-1], "Attentions don't represent the passed tokens!"
|
server/utils/path_fixes.py
CHANGED
@@ -5,7 +5,7 @@ FAISS_LAYER_PATTERN = 'layer_*.faiss'
|
|
5 |
LAYER_TEMPLATE = 'layer_{:02d}.faiss'
|
6 |
|
7 |
ROOT_DIR = Path(os.path.abspath(__file__)).parent.parent.parent
|
8 |
-
CORPORA =
|
9 |
DATA_DIR = ROOT_DIR / 'server' / 'data'
|
10 |
DATASET_DIR = Path.home() / 'Datasets'
|
11 |
ROOT_DIR = Path(os.path.abspath(__file__)).parent.parent.parent
|
|
|
5 |
LAYER_TEMPLATE = 'layer_{:02d}.faiss'
|
6 |
|
7 |
ROOT_DIR = Path(os.path.abspath(__file__)).parent.parent.parent
|
8 |
+
CORPORA = ROOT_DIR / "corpora"
|
9 |
DATA_DIR = ROOT_DIR / 'server' / 'data'
|
10 |
DATASET_DIR = Path.home() / 'Datasets'
|
11 |
ROOT_DIR = Path(os.path.abspath(__file__)).parent.parent.parent
|