rwkv-rs / app.py
Alexandr "MrSteyk" German
blergh
caff12e
raw
history blame
8.91 kB
import os
import rwkv_rs
import numpy as np
import huggingface_hub
import tokenizers
import gradio as gr
# Use the local checkpoint when it exists; otherwise download it from the
# Hugging Face Hub (hf_hub_download returns the path of the downloaded file).
model_path = "./rnn.safetensors"
model_path = (
    model_path
    if os.path.exists(model_path)
    else huggingface_hub.hf_hub_download(
        repo_id="mrsteyk/RWKV-LM-safetensors",
        filename="RWKV-4-Pile-7B-Instruct-test1-20230124.rnn.safetensors",
    )
)
assert model_path is not None

# Shared model and tokenizer used by every request handler in this file.
model = rwkv_rs.Rwkv(model_path)
tokenizer = tokenizers.Tokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

# Pairs of button updates appended to handler outputs by generator_wrap, in
# the order (start button, stop button):
#   GT - generation running: hide the start button, show the stop button.
#   GF - generation finished: show the start button, hide the stop button.
GT = [gr.Button.update(visible=shown) for shown in (False, True)]
GF = [gr.Button.update(visible=shown) for shown in (True, False)]
def complete_fn(inpt, max_tokens, min_tokens, alpha_f, alpha_p):
    """Stream a greedy (argmax) completion of ``inpt``.

    The prompt is fed to the model one token at a time, yielding the decoded
    prefix after each step so the caller can show progress. Afterwards up to
    ``max_tokens`` tokens are generated, yielding the full decoded text after
    every new token.

    Args:
        inpt: Prompt text. Must encode to at least one token.
        max_tokens: Upper bound on the number of generated tokens (cast to int).
        min_tokens: Number of initial generation steps during which token id 0
            (the id this function stops on) is suppressed.
        alpha_f: Frequency penalty, multiplied by how many times a token has
            already been generated.
        alpha_p: Presence penalty, subtracted once from every token that has
            already been generated.

    Yields:
        ``(text, error_update)`` tuples. ``error_update`` is a ``gr.Text``
        update on the first yield (hiding the error box) and on failure
        (showing the message); otherwise it is ``None``.
    """
    try:
        state = rwkv_rs.State(model)
        text = inpt
        # Sparse {token id: times generated}. Tokens that were never generated
        # receive a zero penalty, so only the generated ones need visiting.
        counts = {}
        tokens = tokenizer.encode(inpt).ids
        if not tokens:
            # Without this, tokens[-1] below fails with an opaque IndexError.
            raise ValueError("Input must contain at least one token.")
        yield (None, gr.Text.update(visible=False))

        # Feed every prompt token except the last without reading logits.
        for i in range(len(tokens) - 1):
            model.forward_token_preproc(tokens[i], state)
            yield (tokenizer.decode(tokens[:i + 1]), None)
        logits = model.forward_token(tokens[-1], state)
        yield (text, None)

        max_tokens = int(max_tokens)
        for step in range(max_tokens):
            if step < min_tokens:
                # Forbid stopping before min_tokens tokens were produced.
                logits[0] = -100
            # The penalty loop uses its own variable names so that `step`
            # stays intact for the last-iteration check further down.
            for seen, times in counts.items():
                logits[seen] -= (times * alpha_f) + alpha_p
            token = int(np.argmax(logits))
            counts[token] = counts.get(token, 0) + 1
            if token == 0:
                break
            tokens.append(token)
            text = tokenizer.decode(tokens)
            yield (text, None)
            if step == max_tokens - 1:
                # Nothing will consume further logits; skip the forward pass.
                break
            logits = model.forward_token(token, state)
        yield (text, None)
    except Exception as e:
        # Handler boundary: log the failure and surface it in the error box.
        print(e)
        yield ("Error...", gr.Text.update(value=str(e), visible=True))
def insert_fn(inpt: str, max_tokens, min_tokens, alpha_f, alpha_p, num_tokens_insert):
    """Stream a greedy fill-in for the single ``<|INSERT|>`` marker in ``inpt``.

    The text before the marker is used as the prompt. Generation continues
    until the most recently generated tokens equal the first
    ``num_tokens_insert`` tokens of the text after the marker; at that point
    the remaining tokens of that text are appended and generation stops. If
    the match never happens, the trailing text is not appended.

    Args:
        inpt: Text containing exactly one ``<|INSERT|>`` marker; the part
            before the marker must encode to at least one token.
        max_tokens: Upper bound on the number of generated tokens (cast to int).
        min_tokens: Number of initial generation steps during which token id 0
            (the id this function stops on) is suppressed.
        alpha_f: Frequency penalty, multiplied by how many times a token has
            already been generated.
        alpha_p: Presence penalty, subtracted once from every token that has
            already been generated.
        num_tokens_insert: How many leading tokens of the trailing text must be
            reproduced before the rest of it is appended (cast to int).

    Yields:
        ``(text, error_update)`` tuples. ``error_update`` is a ``gr.Text``
        update on the first yield (hiding the error box) and on failure
        (showing the message); otherwise it is ``None``.
    """
    try:
        if inpt.count("<|INSERT|>") != 1:
            yield ("Error...", gr.Text.update(value="Exactly one replace is allowed!", visible=True))
            return
        state = rwkv_rs.State(model)
        text, end = inpt.split("<|INSERT|>")
        # Sparse {token id: times generated}. Tokens that were never generated
        # receive a zero penalty, so only the generated ones need visiting.
        counts = {}
        tokens = tokenizer.encode(text).ids
        if not tokens:
            # Without this, tokens[-1] below fails with an opaque IndexError.
            raise ValueError("Text before <|INSERT|> must contain at least one token.")
        # Slice bounds must be integers even if the UI hands over a float.
        num_tokens_insert = int(num_tokens_insert)
        tokens_end = tokenizer.encode(end).ids
        tokens_i = tokens_end[:num_tokens_insert]
        # Sliding window over the latest generated tokens, compared to tokens_i.
        ins = [0] * len(tokens_i)
        yield (None, gr.Text.update(visible=False))

        # Feed every prompt token except the last without reading logits.
        for i in range(len(tokens) - 1):
            model.forward_token_preproc(tokens[i], state)
            yield (tokenizer.decode(tokens[:i + 1]), None)
        logits = model.forward_token(tokens[-1], state)
        yield (text, None)

        max_tokens = int(max_tokens)
        for step in range(max_tokens):
            if step < min_tokens:
                # Forbid stopping before min_tokens tokens were produced.
                logits[0] = -100
            # The penalty loop uses its own variable names so that `step`
            # stays intact for the last-iteration check further down.
            for seen, times in counts.items():
                logits[seen] -= (times * alpha_f) + alpha_p
            token = int(np.argmax(logits))
            counts[token] = counts.get(token, 0) + 1
            if token == 0:
                break
            tokens.append(token)
            ins = ins[1:] + [token]
            matched = ins == tokens_i
            if matched:
                # The start of the trailing text was reproduced: attach the
                # rest of it and finish.
                tokens += tokens_end[num_tokens_insert:]
            text = tokenizer.decode(tokens)
            yield (text, None)
            if matched or step == max_tokens - 1:
                # Nothing will consume further logits; skip the forward pass.
                break
            logits = model.forward_token(token, state)
        yield (text, None)
    except Exception as e:
        # Handler boundary: log the failure and surface it in the error box.
        print(e)
        yield ("Error...", gr.Text.update(value=str(e), visible=True))
# def classify_fn_inner(inpt, clas):
# state = rwkv_rs.State(model)
# tokens = tokenizer.encode(f"This is an example of {clas} text: {inpt}").ids
# for i in tokens[:-2]:
# model.forward_token_preproc(i, state)
# # state_2 = state.copy()
# logit_x_1 = softmax(model.forward_token(tokens[-2], state))
# logit_y_1 = softmax(model.forward_token(tokens[-1], state))
# # shapep = logit_x_1.shape[0] * 0.9
# # s = np.sort(logit_y_1)[::-1]
# # c = s[np.argmax(np.cumsum(s) > 0.9)]
# # logit_y_1[logit_y_1 < c] = 0
# loss_1 = -np.sum(logit_y_1 * np.log(logit_x_1)) / logit_x_1.shape[0]
# # I forgor that I do not return the preproc shit...
# # logit_x_2 = model.forward_token_preproc(tokens[-2], state_2)
# # logit_y_2 = model.forward_token_preproc(tokens[-1], state_2)
# return (loss_1, None)
def classify_fn_inner2(inpt, clas):
    """Score how well ``inpt`` follows a prompt describing class ``clas``.

    The model is first primed with ``"This is an example of {clas} text:"``.
    Then every token of ``" {inpt}\\n"`` except the last is fed in, and the
    softmax probability the model assigned to the token that actually comes
    next is collected.

    Returns:
        The negated mean of those probabilities.
    """
    state = rwkv_rs.State(model)

    # Prime the state with the class prompt; logits are not needed here.
    for prompt_token in tokenizer.encode(f"This is an example of {clas} text:").ids:
        model.forward_token_preproc(prompt_token, state)

    body_ids = tokenizer.encode(f" {inpt}\n").ids
    # One logits vector per body token except the last, in order.
    raw_logits = [model.forward_token(token_id, state) for token_id in body_ids[:-1]]
    distributions = list(map(softmax, raw_logits))

    # Probability assigned to each token that actually follows.
    picked = [dist[target] for dist, target in zip(distributions, body_ids[1:])]
    return -np.sum(picked) / len(distributions)
def softmax(x):
    """Return the softmax of ``x``.

    The maximum is subtracted before exponentiating so large inputs do not
    overflow.
    """
    shifted = np.subtract(x, np.max(x))
    weights = np.exp(shifted)
    return weights / weights.sum()
# TODO: maybe make a function with pos/neg inputs?
def classify_fn(inpt, clas, clasneg):
    """Compare ``inpt`` against a positive and a negative class description.

    Both descriptions are scored with ``classify_fn_inner2``; the two scores
    are negated and passed through a softmax so they sum to 1.

    Returns:
        A dict with the normalised score for ``clas`` under ``"+"`` and for
        ``clasneg`` under ``"-"``.
    """
    raw_pos = classify_fn_inner2(inpt, clas)
    raw_neg = classify_fn_inner2(inpt, clasneg)
    # Debug trace: raw scores first, normalised scores on the same line.
    print(raw_pos, raw_neg, end=' | ')
    # The raw scores are negated before normalising, so the description with
    # the lower raw score receives the larger share.
    share_pos, share_neg = softmax([-raw_pos, -raw_neg])
    print(share_pos, share_neg)
    return {"+": share_pos, "-": share_neg}
def generator_wrap(l, fn):
    """Wrap generator function ``fn`` so each result also carries button updates.

    Every item produced by ``fn`` is converted to a list and extended with
    ``GT`` (generation running). When ``fn`` finishes or raises, one more item
    is produced: the last values seen, extended with ``GF`` (generation
    finished).

    Args:
        l: Number of values ``fn`` yields per item; used to build the
            ``None`` placeholder emitted if ``fn`` yields nothing.
        fn: Generator function to wrap.

    Returns:
        A generator function accepting the same positional arguments as ``fn``.
    """
    def wrap(*args):
        last_i = [None] * l
        try:
            for i in fn(*args):
                last_i = list(i)
                yield last_i + GT
        except GeneratorExit:
            # The consumer closed the generator (for example on cancel).
            # Yielding here would raise "generator ignored GeneratorExit",
            # so just let the close proceed.
            raise
        except Exception:
            # Restore the buttons first, then let the error propagate.
            yield last_i + GF
            raise
        else:
            yield last_i + GF
    return wrap
# Gradio UI: three tabs (completion, insertion, classification) plus a shared
# column of generation settings, followed by the event wiring.
# NOTE(review): the nesting of widgets inside tabs/rows below was inferred from
# how the widgets are grouped - confirm against the intended layout.
with gr.Blocks() as app:
    gr.Markdown(f"Running on `{model_path}`")
    # Hidden by default; the generation handlers make it visible on failure.
    error_box = gr.Text(label="Error", visible=False)
    with gr.Tab("Complete"):
        with gr.Row():
            inpt = gr.TextArea(label="Input")
            out = gr.TextArea(label="Output")
        complete = gr.Button("Complete", variant="primary")
        # Shown only while a completion is running (see GT / GF).
        c_stop = gr.Button("Stop", variant="stop", visible=False)
    with gr.Tab("Insert"):
        gr.Markdown("Use `<|INSERT|>` to indicate a place to replace, if insert fails - end text won't be concatenated")
        with gr.Row():
            inpt_i = gr.TextArea(label="Input")
            out_i = gr.TextArea(label="Output")
        num_tokens_insert = gr.Slider(label="Number of tokens to compare for ending (from the beginning of 2nd part)", minimum=1, maximum=2048, value=1024, step=1)
        insert = gr.Button("Insert", variant="primary")
        # Shown only while an insertion is running (see GT / GF).
        i_stop = gr.Button("Stop", variant="stop", visible=False)
    with gr.Tab("Classification W/O head"):
        gr.Markdown("This is an experimental classification attempt based on [this Twitter post](https://twitter.com/aicrumb/status/1625239547268280321)\n\nSettings at the bottom do no affect this example.")
        with gr.Row():
            inpt_c = gr.TextArea(label="Input")
            out_c = gr.Label(label="Output")
        clas = gr.Textbox(label="+ NL class/example to check against.")
        clasneg = gr.Textbox(label="- NL class/example to check against.")
        classify = gr.Button("Classify", variant="primary")
    # Generation settings passed to both complete_fn and insert_fn.
    with gr.Column():
        max_tokens = gr.Slider(label="Max Tokens", minimum=1, maximum=4096, step=1, value=767)
        min_tokens = gr.Slider(label="Min Tokens", minimum=0, maximum=4096, step=1)
        alpha_f = gr.Slider(label="Alpha Frequency", minimum=0, maximum=100, step=0.01)
        alpha_p = gr.Slider(label="Alpha Presence", minimum=0, maximum=100, step=0.01)
    # generator_wrap(2, ...) appends two button updates to the handler's two
    # values, hence four outputs: text, error box, start button, stop button.
    c = complete.click(generator_wrap(2, complete_fn), [inpt, max_tokens, min_tokens, alpha_f, alpha_p], [out, error_box, complete, c_stop])
    # Stop: cancel the running completion event and restore the buttons.
    c_stop.click(lambda: (complete.update(visible=True), c_stop.update(visible=False)), inputs=None, outputs=[complete, c_stop], cancels=[c], queue=False)
    i = insert.click(generator_wrap(2, insert_fn), [inpt_i, max_tokens, min_tokens, alpha_f, alpha_p, num_tokens_insert], [out_i, error_box, insert, i_stop])
    # Stop: cancel the running insertion event and restore the buttons.
    i_stop.click(lambda: (insert.update(visible=True), i_stop.update(visible=False)), inputs=None, outputs=[insert, i_stop], cancels=[i], queue=False)
    # Classification is a plain (non-streaming) call producing the label dict.
    classify.click(classify_fn, [inpt_c, clas, clasneg], [out_c])
# Enable the request queue before starting the server.
app.queue(concurrency_count=2)
app.launch()