rwkv-rs / app.py
Alexandr "MrSteyk" German
remove array
73ae988
raw
history blame contribute delete
No virus
7.53 kB
import os
import rwkv_rs
import numpy as np
import huggingface_hub
import tokenizers
import gradio as gr
# Use a local copy of the weights when present, otherwise fetch them from the Hub.
model_path = "./rnn.safetensors"
if not os.path.exists(model_path):
    model_path = huggingface_hub.hf_hub_download(
        repo_id="mrsteyk/RWKV-LM-safetensors",
        filename="RWKV-4-Pile-7B-Instruct-test1-20230124.rnn.safetensors",
    )
# NOTE(review): hf_hub_download presumably raises on failure rather than
# returning None, so this check looks purely defensive.
assert model_path is not None
model = rwkv_rs.Rwkv(model_path)
tokenizer = tokenizers.Tokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

# Visibility updates for a (start button, stop button) pair; generator_wrap
# appends one of these to every output it streams:
#   GT -> generation running: hide the start button, show the stop button.
#   GF -> generation finished: show the start button, hide the stop button.
GT = [gr.Button.update(visible=flag) for flag in (False, True)]
GF = [gr.Button.update(visible=flag) for flag in (True, False)]
def complete_fn(inpt, max_tokens, min_tokens, alpha_f, alpha_p):
    """Stream a greedy continuation of ``inpt`` from the RWKV model.

    Yields ``(output_text, error_box_update)`` pairs for the Gradio outputs:
    first an update hiding the error box, then the prompt as it is fed
    through the model, then the prompt plus each newly generated token.

    Args:
        inpt: Prompt text.
        max_tokens: Maximum number of tokens to generate (cast to int).
        min_tokens: While fewer than this many tokens have been generated,
            token 0 (which ends generation) is suppressed.
        alpha_f: Frequency penalty, subtracted from a token's logit once per
            time that token has already been generated.
        alpha_p: Presence penalty, subtracted once from the logit of any token
            that has already been generated.

    Errors are not raised: they are printed and reported through the error box.
    """
    try:
        state = rwkv_rs.State(model)
        text = inpt
        tokens = tokenizer.encode(inpt).ids
        if not tokens:
            raise ValueError("Input is empty - nothing to complete.")
        # How often each vocabulary entry has been generated so far (as floats
        # so the penalty below is a single vectorized expression).
        counts = np.zeros(tokenizer.get_vocab_size(), dtype=np.float64)
        yield (None, gr.Text.update(visible=False))
        # Feed all but the last prompt token without computing logits.
        for pos in range(len(tokens) - 1):
            model.forward_token_preproc(tokens[pos], state)
            yield (tokenizer.decode(tokens[:pos + 1]), None)
        logits = model.forward_token(tokens[-1], state)
        yield (text, None)
        max_tokens = int(max_tokens)
        # `step` is a dedicated counter: the old code reused `i` for the
        # penalty loop, which clobbered it and broke the last-step check.
        for step in range(max_tokens):
            # Private writable float copy of the model's logits.
            logits = np.array(logits, dtype=np.float64)
            if step < min_tokens:
                # Token 0 ends generation (see the break below); keep it
                # unlikely until min_tokens tokens exist.
                logits[0] = -100
            logits -= counts * alpha_f + (counts > 0) * alpha_p
            token = int(np.argmax(logits))
            counts[token] += 1
            if token == 0:
                break
            tokens.append(token)
            text = tokenizer.decode(tokens)
            yield (text, None)
            if step == max_tokens - 1:
                # Last token: skip a forward pass whose logits would go unused.
                break
            logits = model.forward_token(token, state)
        yield (text, None)
    except Exception as e:
        print(e)
        yield ("Error...", gr.Text.update(value=str(e), visible=True))
def insert_fn(inpt: str, max_tokens, min_tokens, alpha_f, alpha_p, num_tokens_insert):
    """Stream a generation that tries to fill the ``<|INSERT|>`` gap in ``inpt``.

    The text before the marker is the prompt; the model then generates
    greedily with the same penalties as completion. Insertion succeeds once
    the most recently generated tokens equal the first ``num_tokens_insert``
    tokens of the text after the marker: the remaining tokens of that text
    are appended and generation stops. Otherwise generation stops at token 0
    or after ``max_tokens`` tokens, and the text after the marker is not
    appended.

    Args:
        inpt: Text containing exactly one ``<|INSERT|>`` marker.
        max_tokens: Maximum number of tokens to generate (cast to int).
        min_tokens: While fewer than this many tokens have been generated,
            token 0 (which ends generation) is suppressed.
        alpha_f: Frequency penalty per previous generation of a token.
        alpha_p: Presence penalty for any already-generated token.
        num_tokens_insert: How many leading tokens of the second part must be
            reproduced to count as a successful insertion (cast to int).

    Yields ``(output_text, error_box_update)`` pairs; errors are printed and
    reported through the error box instead of raised.
    """
    try:
        if inpt.count("<|INSERT|>") != 1:
            yield ("Error...", gr.Text.update(value="Exactly one replace is allowed!", visible=True))
            return
        state = rwkv_rs.State(model)
        text, end = inpt.split("<|INSERT|>")
        tokens = tokenizer.encode(text).ids
        if not tokens:
            raise ValueError("Text before <|INSERT|> is empty - nothing to condition on.")
        # How often each vocabulary entry has been generated so far.
        counts = np.zeros(tokenizer.get_vocab_size(), dtype=np.float64)
        tokens_end = tokenizer.encode(end).ids
        # Slider values may arrive as floats; slicing needs an int.
        num_tokens_insert = int(num_tokens_insert)
        tokens_i = tokens_end[:num_tokens_insert]
        # Sliding window of the latest generated tokens. Pad with None, not a
        # real token id, so the padding can never produce a false match.
        ins = [None] * len(tokens_i)
        yield (None, gr.Text.update(visible=False))
        # Feed all but the last prompt token without computing logits.
        for pos in range(len(tokens) - 1):
            model.forward_token_preproc(tokens[pos], state)
            yield (tokenizer.decode(tokens[:pos + 1]), None)
        logits = model.forward_token(tokens[-1], state)
        yield (text, None)
        max_tokens = int(max_tokens)
        for step in range(max_tokens):
            # Private writable float copy of the model's logits.
            logits = np.array(logits, dtype=np.float64)
            if step < min_tokens:
                # Token 0 ends generation (see the break below).
                logits[0] = -100
            logits -= counts * alpha_f + (counts > 0) * alpha_p
            token = int(np.argmax(logits))
            counts[token] += 1
            if token == 0:
                break
            tokens.append(token)
            ins = ins[1:] + [token]
            matched = ins == tokens_i
            if matched:
                # The model reproduced the start of the second part: splice in
                # the rest of it verbatim.
                tokens += tokens_end[num_tokens_insert:]
            text = tokenizer.decode(tokens)
            yield (text, None)
            if matched or step == max_tokens - 1:
                break
            logits = model.forward_token(token, state)
        yield (text, None)
    except Exception as e:
        print(e)
        yield ("Error...", gr.Text.update(value=str(e), visible=True))
def classify_fn_inner2(inpt, clas):
    """Return the mean negative log-likelihood of ``inpt`` under a class prompt.

    The model is primed with ``"This is an example of {clas} text:"`` and then
    scored token by token on ``inpt`` (with a leading space and a trailing
    newline). Every target token is scored, including the first one, which is
    predicted directly from the class prompt. Lower values mean the text fits
    the class description better.
    """
    state = rwkv_rs.State(model)
    prompt = tokenizer.encode(f"This is an example of {clas} text:").ids
    target = tokenizer.encode(f" {inpt}\n").ids
    # Feed all but the last prompt token without logits; the last one produces
    # the distribution over the first target token.
    for tok in prompt[:-1]:
        model.forward_token_preproc(tok, state)
    logits = model.forward_token(prompt[-1], state)
    nll = 0.0
    for pos, tok in enumerate(target):
        z = np.asarray(logits, dtype=np.float64)
        z = z - z.max()
        # log(softmax(z))[tok], computed in log space so very small
        # probabilities cannot underflow to log(0) = -inf.
        nll -= z[tok] - np.log(np.exp(z).sum())
        if pos + 1 < len(target):
            logits = model.forward_token(tok, state)
    # target is never empty (the encoded string always contains "\n"), so
    # this division is safe, unlike the old len(tokens) - 1 denominator.
    return nll / len(target)
def softmax(x):
    """Return the softmax of ``x`` (array or list of numbers) as a NumPy array.

    The maximum is subtracted before exponentiating so large inputs do not
    overflow; the shift does not change the result.
    """
    exps = np.exp(np.subtract(x, np.max(x)))
    return exps / np.sum(exps)
def classify_fn(inpt, clas, clasneg):
    """Score ``inpt`` against a positive and a negative class description.

    Each description's mean negative log-likelihood (from classify_fn_inner2)
    is turned into a confidence with a softmax over the two negated losses,
    so the class with the lower loss gets the larger share.

    Returns:
        ``{"+": p_pos, "-": p_neg}`` as builtin floats summing to 1, for a
        ``gr.Label`` output.
    """
    loss_pos = classify_fn_inner2(inpt, clas)
    loss_neg = classify_fn_inner2(inpt, clasneg)
    p_pos, p_neg = softmax([-loss_pos, -loss_neg])
    # Convert NumPy scalars to builtin floats for the UI component.
    return {"+": float(p_pos), "-": float(p_neg)}
def generator_wrap(l, fn):
    """Wrap generator ``fn`` so each output also carries button visibility updates.

    Every tuple yielded by ``fn`` (expected to hold ``l`` items) is forwarded
    as a list with ``GT`` appended while ``fn`` is running. When ``fn``
    finishes or raises, the last outputs are yielded once more with ``GF``
    appended so the buttons are restored; an exception is re-raised after that.

    If the wrapper is closed early (presumably when the event is cancelled via
    a stop button), nothing more is yielded: a generator that yields while
    handling GeneratorExit raises RuntimeError.
    """
    def wrap(*args):
        last = [None] * l
        try:
            for out in fn(*args):
                last = list(out)
                yield last + GT
        except GeneratorExit:
            # Closed by the consumer; yielding now is not allowed.
            raise
        except BaseException:
            yield last + GF
            raise
        yield last + GF
    return wrap
with gr.Blocks() as app:
    gr.Markdown(f"Running on `{model_path}`")
    # Initially hidden box that the generation handlers use to report errors.
    error_box = gr.Text(label="Error", visible=False)

    with gr.Tab("Complete"):
        with gr.Row():
            inpt = gr.TextArea(label="Input")
            out = gr.TextArea(label="Output")
        complete = gr.Button("Complete", variant="primary")
        c_stop = gr.Button("Stop", variant="stop", visible=False)

    with gr.Tab("Insert"):
        gr.Markdown("Use `<|INSERT|>` to indicate a place to replace, if insert fails - end text won't be concatenated")
        with gr.Row():
            inpt_i = gr.TextArea(label="Input")
            out_i = gr.TextArea(label="Output")
        num_tokens_insert = gr.Slider(label="Number of tokens to compare for ending (from the beginning of 2nd part)", minimum=1, maximum=2048, value=1024, step=1)
        insert = gr.Button("Insert", variant="primary")
        i_stop = gr.Button("Stop", variant="stop", visible=False)

    with gr.Tab("Classification W/O head"):
        gr.Markdown("This is an experimental classification attempt based on [this Twitter post](https://twitter.com/aicrumb/status/1625239547268280321)\n\nSettings at the bottom do not affect this example.")
        with gr.Row():
            inpt_c = gr.TextArea(label="Input")
            out_c = gr.Label(label="Output")
        clas = gr.Textbox(label="+ NL class/example to check against.")
        clasneg = gr.Textbox(label="- NL class/example to check against.")
        classify = gr.Button("Classify", variant="primary")

    # Generation settings used by the Complete and Insert handlers below.
    with gr.Column():
        max_tokens = gr.Slider(label="Max Tokens", minimum=1, maximum=4096, step=1, value=767)
        min_tokens = gr.Slider(label="Min Tokens", minimum=0, maximum=4096, step=1)
        alpha_f = gr.Slider(label="Alpha Frequency", minimum=0, maximum=100, step=0.01)
        alpha_p = gr.Slider(label="Alpha Presence", minimum=0, maximum=100, step=0.01)

    # Streaming handlers: generator_wrap(2, ...) appends the start/stop button
    # visibility updates to the two (text, error) outputs. Each stop button
    # cancels its running event and restores the buttons itself, outside the queue.
    c = complete.click(generator_wrap(2, complete_fn), [inpt, max_tokens, min_tokens, alpha_f, alpha_p], [out, error_box, complete, c_stop])
    c_stop.click(lambda: (complete.update(visible=True), c_stop.update(visible=False)), inputs=None, outputs=[complete, c_stop], cancels=[c], queue=False)
    i = insert.click(generator_wrap(2, insert_fn), [inpt_i, max_tokens, min_tokens, alpha_f, alpha_p, num_tokens_insert], [out_i, error_box, insert, i_stop])
    i_stop.click(lambda: (insert.update(visible=True), i_stop.update(visible=False)), inputs=None, outputs=[insert, i_stop], cancels=[i], queue=False)
    classify.click(classify_fn, [inpt_c, clas, clasneg], [out_c])

app.queue(concurrency_count=2)
app.launch()