"""Gradio demo for an RWKV model: completion, insertion and head-less classification."""

import os

import rwkv_rs
import numpy as np
import huggingface_hub
import tokenizers
import gradio as gr

# Prefer a local checkpoint; otherwise download one from the Hugging Face Hub.
model_path = "./rnn.safetensors"
if not os.path.exists(model_path):
    model_path = huggingface_hub.hf_hub_download(
        repo_id="mrsteyk/RWKV-LM-safetensors",
        filename="RWKV-4-Pile-7B-Instruct-test1-20230124.rnn.safetensors",
    )
assert model_path is not None

model = rwkv_rs.Rwkv(model_path)
tokenizer = tokenizers.Tokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

# Token id that ends generation when sampled (and is suppressed while fewer
# than ``min_tokens`` tokens have been produced).
STOP_TOKEN = 0

# Button updates appended to every generator output by ``generator_wrap``.
# GT: generation running -> hide the action button, show "Stop".
GT = [
    gr.Button.update(visible=False),
    gr.Button.update(visible=True),
]
# GF: generation finished -> show the action button, hide "Stop".
GF = [
    gr.Button.update(visible=True),
    gr.Button.update(visible=False),
]


def _prefill(tokens, state):
    """Feed the prompt ``tokens`` into ``state``.

    Yields ``(decoded_prefix, None)`` after each token except the last so the
    UI can show progress, and returns (via ``yield from``) the logits that
    ``model.forward_token`` produces for the last prompt token.
    ``tokens`` must be non-empty.
    """
    for idx, token in enumerate(tokens[:-1]):
        model.forward_token_preproc(token, state)
        yield (tokenizer.decode(tokens[:idx + 1]), None)
    return model.forward_token(tokens[-1], state)


def _pick_token(logits, counts, suppress_stop, alpha_f, alpha_p):
    """Return the arg-max token id after frequency/presence penalties.

    ``counts`` holds how often each token id has been generated so far.
    Each logit is reduced by ``count * alpha_f`` plus ``alpha_p`` if the token
    was generated at least once.  When ``suppress_stop`` is true the stop
    token's logit is forced to -100 first.
    """
    # Work on a private float copy so the model's buffer is never mutated.
    scores = np.array(logits, dtype=np.float64)
    if suppress_stop:
        scores[STOP_TOKEN] = -100
    # Only the ids covered by ``counts`` are penalised.
    n = min(len(counts), len(scores))
    seen = counts[:n]
    scores[:n] -= seen * alpha_f + (seen > 0) * alpha_p
    return int(np.argmax(scores))


def _generate_tokens(logits, state, max_tokens, min_tokens, alpha_f, alpha_p):
    """Greedily generate up to ``max_tokens`` token ids, one per ``next()``.

    ``logits`` are the logits for the first token to generate.  The model is
    advanced lazily, only when another token is actually requested, so no
    forward pass is spent after the final token or after the consumer stops
    iterating.  Generation ends early when ``STOP_TOKEN`` is picked; the stop
    token itself is not yielded.
    """
    counts = np.zeros(tokenizer.get_vocab_size(), dtype=np.int64)
    token = None
    for step in range(max_tokens):
        if step > 0:
            logits = model.forward_token(token, state)
        token = _pick_token(logits, counts, step < min_tokens, alpha_f, alpha_p)
        counts[token] += 1
        if token == STOP_TOKEN:
            return
        yield token


def complete_fn(inpt, max_tokens, min_tokens, alpha_f, alpha_p):
    """Stream a completion of ``inpt``.

    Generator yielding ``(output_text, error_box_update)`` pairs: first a pair
    that hides the error box, then the prompt as it is consumed, then the
    growing completion.  Any exception is printed and reported through the
    error box instead of propagating.
    """
    try:
        state = rwkv_rs.State(model)
        text = inpt
        tokens = tokenizer.encode(inpt).ids
        yield (None, gr.Text.update(visible=False))

        if not tokens:
            # Nothing to condition on; report it instead of an IndexError.
            yield ("Error...", gr.Text.update(value="Input must contain at least one token!", visible=True))
            return

        logits = yield from _prefill(tokens, state)
        yield (text, None)

        for token in _generate_tokens(logits, state, int(max_tokens), min_tokens, alpha_f, alpha_p):
            tokens.append(token)
            text = tokenizer.decode(tokens)
            yield (text, None)
        yield (text, None)
    except Exception as e:
        print(e)
        yield ("Error...", gr.Text.update(value=str(e), visible=True))


def insert_fn(inpt: str, max_tokens, min_tokens, alpha_f, alpha_p, num_tokens_insert):
    """Stream text generated at the single ``<|INSERT|>`` marker in ``inpt``.

    The text before the marker is the prompt.  Generation stops as soon as the
    most recently generated tokens equal the first ``num_tokens_insert``
    tokens of the text after the marker; the remainder of that text is then
    appended.  If that never happens the trailing text is not appended.

    Generator yielding ``(output_text, error_box_update)`` pairs; exceptions
    are printed and reported through the error box.
    """
    try:
        if inpt.count("<|INSERT|>") != 1:
            yield ("Error...", gr.Text.update(value="Exactly one replace is allowed!", visible=True))
            return

        state = rwkv_rs.State(model)
        text, end = inpt.split("<|INSERT|>")
        # Sliders may deliver floats; slice bounds must be integers.
        num_tokens_insert = int(num_tokens_insert)
        tokens = tokenizer.encode(text).ids
        tokens_end = tokenizer.encode(end).ids
        tokens_i = tokens_end[:num_tokens_insert]
        # Sliding window over the most recent generated tokens.
        ins = [0] * len(tokens_i)
        yield (None, gr.Text.update(visible=False))

        if not tokens:
            yield ("Error...", gr.Text.update(value="Text before <|INSERT|> must contain at least one token!", visible=True))
            return

        logits = yield from _prefill(tokens, state)
        yield (text, None)

        for token in _generate_tokens(logits, state, int(max_tokens), min_tokens, alpha_f, alpha_p):
            tokens.append(token)
            ins = ins[1:] + [token]
            matched = ins == tokens_i
            if matched:
                # The generated text has reached the start of the trailing
                # part; append what is left of it and stop.
                tokens += tokens_end[num_tokens_insert:]
            text = tokenizer.decode(tokens)
            yield (text, None)
            if matched:
                break
        yield (text, None)
    except Exception as e:
        print(e)
        yield ("Error...", gr.Text.update(value=str(e), visible=True))


def softmax(x):
    """Return the softmax of ``x`` (shifted by its maximum for stability)."""
    x = np.asarray(x, dtype=np.float64)
    e = np.exp(x - np.max(x))
    return e / e.sum()


def _log_softmax(x):
    """Return ``log(softmax(x))`` without underflowing to ``-inf``."""
    x = np.asarray(x, dtype=np.float64)
    shifted = x - np.max(x)
    return shifted - np.log(np.sum(np.exp(shifted)))


def classify_fn_inner2(inpt, clas):
    """Return the mean negative log-likelihood of ``inpt`` given class ``clas``.

    The model is primed with ``"This is an example of {clas} text:"`` and then
    fed ``" {inpt}\\n"``; each token after the first is scored from the logits
    produced by the token before it.  Returns ``0.0`` when there is no token
    pair to score.
    """
    state = rwkv_rs.State(model)
    for token in tokenizer.encode(f"This is an example of {clas} text:").ids:
        model.forward_token_preproc(token, state)

    tokens = tokenizer.encode(f" {inpt}\n").ids
    n_scored = len(tokens) - 1
    if n_scored <= 0:
        # A single token leaves nothing to predict; avoid dividing by zero.
        return 0.0

    log_prob = 0.0
    for current, following in zip(tokens, tokens[1:]):
        log_prob += _log_softmax(model.forward_token(current, state))[following]
    return -log_prob / n_scored


def classify_fn(inpt, clas, clasneg):
    """Score ``inpt`` against a positive and a negative class description.

    Returns ``{"+": p_pos, "-": p_neg}`` where the two values are the softmax
    of the negated per-class losses.
    """
    loss_pos = classify_fn_inner2(inpt, clas)
    loss_neg = classify_fn_inner2(inpt, clasneg)
    p_pos, p_neg = softmax([-loss_pos, -loss_neg])
    return {"+": float(p_pos), "-": float(p_neg)}


def generator_wrap(l, fn):
    """Wrap generator function ``fn`` so its outputs also toggle the buttons.

    ``l`` is the number of values ``fn`` yields per step.  Every yielded item
    is extended with ``GT`` (running); once ``fn`` finishes, or raises an
    ``Exception``, the last item is re-emitted with ``GF`` (idle).
    """
    def wrap(*args):
        last = [None] * l
        try:
            for item in fn(*args):
                last = list(item)
                yield last + GT
        except Exception:
            yield last + GF
            raise
        # Deliberately not in a ``finally``: yielding while the generator is
        # being closed (GeneratorExit) raises RuntimeError.
        yield last + GF
    return wrap


# NOTE(review): widget nesting below was reconstructed from a
# whitespace-collapsed source -- confirm it matches the intended layout.
with gr.Blocks() as app:
    gr.Markdown(f"Running on `{model_path}`")
    error_box = gr.Text(label="Error", visible=False)

    with gr.Tab("Complete"):
        with gr.Row():
            inpt = gr.TextArea(label="Input")
            out = gr.TextArea(label="Output")
        complete = gr.Button("Complete", variant="primary")
        c_stop = gr.Button("Stop", variant="stop", visible=False)

    with gr.Tab("Insert"):
        gr.Markdown("Use `<|INSERT|>` to indicate a place to replace, if insert fails - end text won't be concatenated")
        with gr.Row():
            inpt_i = gr.TextArea(label="Input")
            out_i = gr.TextArea(label="Output")
        num_tokens_insert = gr.Slider(
            label="Number of tokens to compare for ending (from the beginning of 2nd part)",
            minimum=1,
            maximum=2048,
            value=1024,
            step=1,
        )
        insert = gr.Button("Insert", variant="primary")
        i_stop = gr.Button("Stop", variant="stop", visible=False)

    with gr.Tab("Classification W/O head"):
        gr.Markdown("This is an experimental classification attempt based on [this Twitter post](https://twitter.com/aicrumb/status/1625239547268280321)\n\nSettings at the bottom do no affect this example.")
        with gr.Row():
            inpt_c = gr.TextArea(label="Input")
            out_c = gr.Label(label="Output")
        clas = gr.Textbox(label="+ NL class/example to check against.")
        clasneg = gr.Textbox(label="- NL class/example to check against.")
        classify = gr.Button("Classify", variant="primary")

    # Generation settings shared by the "Complete" and "Insert" tabs.
    with gr.Column():
        max_tokens = gr.Slider(label="Max Tokens", minimum=1, maximum=4096, step=1, value=767)
        min_tokens = gr.Slider(label="Min Tokens", minimum=0, maximum=4096, step=1)
        alpha_f = gr.Slider(label="Alpha Frequency", minimum=0, maximum=100, step=0.01)
        alpha_p = gr.Slider(label="Alpha Presence", minimum=0, maximum=100, step=0.01)

    # ``c`` / ``i`` are the running events, kept so the Stop buttons can
    # cancel them.
    c = complete.click(
        generator_wrap(2, complete_fn),
        [inpt, max_tokens, min_tokens, alpha_f, alpha_p],
        [out, error_box, complete, c_stop],
    )
    c_stop.click(
        lambda: (complete.update(visible=True), c_stop.update(visible=False)),
        inputs=None,
        outputs=[complete, c_stop],
        cancels=[c],
        queue=False,
    )

    i = insert.click(
        generator_wrap(2, insert_fn),
        [inpt_i, max_tokens, min_tokens, alpha_f, alpha_p, num_tokens_insert],
        [out_i, error_box, insert, i_stop],
    )
    i_stop.click(
        lambda: (insert.update(visible=True), i_stop.update(visible=False)),
        inputs=None,
        outputs=[insert, i_stop],
        cancels=[i],
        queue=False,
    )

    classify.click(classify_fn, [inpt_c, clas, clasneg], [out_c])

app.queue(concurrency_count=2)
app.launch()