Alexandr "MrSteyk" German committed on
Commit
caff12e
1 Parent(s): dfb402a
Files changed (1) hide show
  1. app.py +68 -0
app.py CHANGED
@@ -104,6 +104,64 @@ def insert_fn(inpt: str, max_tokens, min_tokens, alpha_f, alpha_p, num_tokens_in
104
  print(e)
105
  yield ("Error...", gr.Text.update(value=str(e), visible=True))
106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  def generator_wrap(l, fn):
108
  def wrap(*args):
109
  last_i = list([None] * l)
@@ -134,6 +192,14 @@ with gr.Blocks() as app:
134
  num_tokens_insert = gr.Slider(label="Number of tokens to compare for ending (from the beginning of 2nd part)", minimum=1, maximum=2048, value=1024, step=1)
135
  insert = gr.Button("Insert", variant="primary")
136
  i_stop = gr.Button("Stop", variant="stop", visible=False)
 
 
 
 
 
 
 
 
137
 
138
  with gr.Column():
139
  max_tokens = gr.Slider(label="Max Tokens", minimum=1, maximum=4096, step=1, value=767)
@@ -147,5 +213,7 @@ with gr.Blocks() as app:
147
  i = insert.click(generator_wrap(2, insert_fn), [inpt_i, max_tokens, min_tokens, alpha_f, alpha_p, num_tokens_insert], [out_i, error_box, insert, i_stop])
148
  i_stop.click(lambda: (insert.update(visible=True), i_stop.update(visible=False)), inputs=None, outputs=[insert, i_stop], cancels=[i], queue=False)
149
 
 
 
150
  app.queue(concurrency_count=2)
151
  app.launch()
 
104
  print(e)
105
  yield ("Error...", gr.Text.update(value=str(e), visible=True))
106
 
107
+ # def classify_fn_inner(inpt, clas):
108
+ # state = rwkv_rs.State(model)
109
+ # tokens = tokenizer.encode(f"This is an example of {clas} text: {inpt}").ids
110
+ # for i in tokens[:-2]:
111
+ # model.forward_token_preproc(i, state)
112
+ # # state_2 = state.copy()
113
+
114
+ # logit_x_1 = softmax(model.forward_token(tokens[-2], state))
115
+ # logit_y_1 = softmax(model.forward_token(tokens[-1], state))
116
+ # # shapep = logit_x_1.shape[0] * 0.9
117
+ # # s = np.sort(logit_y_1)[::-1]
118
+ # # c = s[np.argmax(np.cumsum(s) > 0.9)]
119
+ # # logit_y_1[logit_y_1 < c] = 0
120
+ # loss_1 = -np.sum(logit_y_1 * np.log(logit_x_1)) / logit_x_1.shape[0]
121
+
122
+ # # I forgor that I do not return the preproc shit...
123
+ # # logit_x_2 = model.forward_token_preproc(tokens[-2], state_2)
124
+ # # logit_y_2 = model.forward_token_preproc(tokens[-1], state_2)
125
+ # return (loss_1, None)
126
+
127
def classify_fn_inner2(inpt, clas):
    """Score how well `inpt` reads as a continuation of a class prompt.

    Primes the RWKV recurrent state with "This is an example of {clas} text:",
    then measures the negated mean probability the model assigns to each
    actual next token of " {inpt}\n" — a smaller score means a better fit.
    """
    state = rwkv_rs.State(model)
    # Feed the class prompt through the model purely to build up state;
    # the logits for these positions are not needed.
    for tok in tokenizer.encode(f"This is an example of {clas} text:").ids:
        model.forward_token_preproc(tok, state)

    cont_tokens = tokenizer.encode(f" {inpt}\n").ids
    # One logit vector per position except the last (which has no target).
    raw_logits = [model.forward_token(tok, state) for tok in cont_tokens[:-1]]
    probs = [softmax(vec) for vec in raw_logits]
    # NOTE(review): this averages the raw probabilities of the target tokens,
    # not their log-probabilities, so it is not a standard cross-entropy.
    # Presumably intentional (contrast the commented-out np.log variant in this
    # file) — confirm before "fixing".
    score = -np.sum([p[t] for p, t in zip(probs, cont_tokens[1:])]) / len(probs)
    return score
141
+
142
def softmax(x):
    """Numerically stable softmax.

    Shifts the input by its maximum before exponentiating (avoids overflow),
    then normalizes so the result sums to 1. Accepts any array-like.
    """
    shifted = np.exp(x - np.max(x))
    return shifted / shifted.sum()
145
+
146
# TODO: maybe make a function with pos/neg inputs?
def classify_fn(inpt, clas, clasneg):
    """Binary zero-shot classification between two natural-language classes.

    Scores `inpt` against the positive class `clas` and the negative class
    `clasneg`, then softmaxes the negated scores (the score closer to zero
    wins) into a probability pair. Returns a dict shaped for a gradio Label:
    {"+": p_positive, "-": p_negative}. Also prints the raw and normalized
    scores to stdout for debugging.
    """
    pos_score = classify_fn_inner2(inpt, clas)
    neg_score = classify_fn_inner2(inpt, clasneg)
    print(pos_score, neg_score, end=' | ')
    # Negate before softmax so the smaller loss maps to the larger probability.
    pos_score, neg_score = softmax([-pos_score, -neg_score])
    print(pos_score, neg_score)
    return ({"+": pos_score, "-": neg_score})
164
+
165
  def generator_wrap(l, fn):
166
  def wrap(*args):
167
  last_i = list([None] * l)
 
192
  num_tokens_insert = gr.Slider(label="Number of tokens to compare for ending (from the beginning of 2nd part)", minimum=1, maximum=2048, value=1024, step=1)
193
  insert = gr.Button("Insert", variant="primary")
194
  i_stop = gr.Button("Stop", variant="stop", visible=False)
195
+ with gr.Tab("Classification W/O head"):
196
+ gr.Markdown("This is an experimental classification attempt based on [this Twitter post](https://twitter.com/aicrumb/status/1625239547268280321)\n\nSettings at the bottom do no affect this example.")
197
+ with gr.Row():
198
+ inpt_c = gr.TextArea(label="Input")
199
+ out_c = gr.Label(label="Output")
200
+ clas = gr.Textbox(label="+ NL class/example to check against.")
201
+ clasneg = gr.Textbox(label="- NL class/example to check against.")
202
+ classify = gr.Button("Classify", variant="primary")
203
 
204
  with gr.Column():
205
  max_tokens = gr.Slider(label="Max Tokens", minimum=1, maximum=4096, step=1, value=767)
 
213
  i = insert.click(generator_wrap(2, insert_fn), [inpt_i, max_tokens, min_tokens, alpha_f, alpha_p, num_tokens_insert], [out_i, error_box, insert, i_stop])
214
  i_stop.click(lambda: (insert.update(visible=True), i_stop.update(visible=False)), inputs=None, outputs=[insert, i_stop], cancels=[i], queue=False)
215
 
216
+ classify.click(classify_fn, [inpt_c, clas, clasneg], [out_c])
217
+
218
  app.queue(concurrency_count=2)
219
  app.launch()