Spaces:
Running
on
T4
Running
on
T4
Update app.py
Browse files
app.py
CHANGED
@@ -25,24 +25,36 @@ pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
|
|
25 |
|
26 |
args = model.args
|
27 |
eng_name = 'rwkv-x060-eng_single_round_qa-7B-20240516-ctx2048'
|
28 |
-
chn_name = 'rwkv-x060-chn_single_round_qa-7B-20240516-ctx2048'
|
29 |
eng_file = hf_hub_download(repo_id="BlinkDL/temp-latest-training-models", filename=f"{eng_name}.pth")
|
30 |
-
chn_file = hf_hub_download(repo_id="BlinkDL/temp-latest-training-models", filename=f"{chn_name}.pth")
|
31 |
state_eng_raw = torch.load(eng_file)
|
32 |
-
state_chn_raw = torch.load(chn_file)
|
33 |
state_eng = [None] * args.n_layer * 3
|
|
|
|
|
|
|
|
|
34 |
state_chn = [None] * args.n_layer * 3
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
for i in range(args.n_layer):
|
36 |
dd = model.strategy[i]
|
37 |
dev = dd.device
|
38 |
atype = dd.atype
|
39 |
state_eng[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
|
40 |
-
state_chn[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
|
41 |
state_eng[i*3+1] = state_eng_raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
|
42 |
-
state_chn[i*3+1] = state_chn_raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
|
43 |
state_eng[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
|
|
|
|
|
|
|
44 |
state_chn[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
|
45 |
|
|
|
|
|
|
|
|
|
46 |
def generate_prompt(instruction, input=""):
|
47 |
instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
|
48 |
input = input.strip().replace('\r\n','\n').replace('\n\n','\n')
|
@@ -208,6 +220,56 @@ def evaluate_chn(
|
|
208 |
torch.cuda.empty_cache()
|
209 |
yield out_str.strip()
|
210 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
211 |
examples = [
|
212 |
["Assistant: How can we craft an engaging story featuring vampires on Mars? Let's think step by step and provide an expert response.", gen_limit, 1, 0.3, 0.5, 0.5],
|
213 |
["Assistant: How can we persuade Elon Musk to follow you on Twitter? Let's think step by step and provide an expert response.", gen_limit, 1, 0.3, 0.5, 0.5],
|
@@ -242,6 +304,14 @@ examples_chn = [
|
|
242 |
["用HTML编写一个简单的网站。当用户点击按钮时,从4个笑话的列表中随机显示一个笑话。", gen_limit_long, 1, 0.2, 0.3, 0.3],
|
243 |
]
|
244 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
245 |
##########################################################################
|
246 |
|
247 |
with gr.Blocks(title=title) as demo:
|
@@ -307,6 +377,26 @@ with gr.Blocks(title=title) as demo:
|
|
307 |
clear.click(lambda: None, [], [output])
|
308 |
data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])
|
309 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
310 |
|
311 |
demo.queue(concurrency_count=1, max_size=10)
|
312 |
demo.launch(share=False)
|
|
|
25 |
|
26 |
args = model.args
|
27 |
eng_name = 'rwkv-x060-eng_single_round_qa-7B-20240516-ctx2048'
|
|
|
28 |
eng_file = hf_hub_download(repo_id="BlinkDL/temp-latest-training-models", filename=f"{eng_name}.pth")
|
|
|
29 |
state_eng_raw = torch.load(eng_file)
|
|
|
30 |
state_eng = [None] * args.n_layer * 3
|
31 |
+
|
32 |
+
chn_name = 'rwkv-x060-chn_single_round_qa-7B-20240516-ctx2048'
|
33 |
+
chn_file = hf_hub_download(repo_id="BlinkDL/temp-latest-training-models", filename=f"{chn_name}.pth")
|
34 |
+
state_chn_raw = torch.load(chn_file)
|
35 |
state_chn = [None] * args.n_layer * 3
|
36 |
+
|
37 |
+
wyw_name = 'rwkv-x060-chn_文言文和古典名著_single_round_qa-7B-20240601-ctx2048'
|
38 |
+
wyw_file = hf_hub_download(repo_id="BlinkDL/temp-latest-training-models", filename=f"{wyw_name}.pth")
|
39 |
+
state_wyw_raw = torch.load(wyw_file)
|
40 |
+
state_wyw = [None] * args.n_layer * 3
|
41 |
+
|
42 |
for i in range(args.n_layer):
|
43 |
dd = model.strategy[i]
|
44 |
dev = dd.device
|
45 |
atype = dd.atype
|
46 |
state_eng[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
|
|
|
47 |
state_eng[i*3+1] = state_eng_raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
|
|
|
48 |
state_eng[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
|
49 |
+
|
50 |
+
state_chn[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
|
51 |
+
state_chn[i*3+1] = state_chn_raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
|
52 |
state_chn[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
|
53 |
|
54 |
+
state_wyw[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
|
55 |
+
state_wyw[i*3+1] = state_wyw_raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()  # use the WenYanWen checkpoint (state_wyw_raw), not the Chinese Q/A one
|
56 |
+
state_wyw[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
|
57 |
+
|
58 |
def generate_prompt(instruction, input=""):
|
59 |
instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
|
60 |
input = input.strip().replace('\r\n','\n').replace('\n\n','\n')
|
|
|
220 |
torch.cuda.empty_cache()
|
221 |
yield out_str.strip()
|
222 |
|
223 |
+
def evaluate_wyw(
|
224 |
+
ctx,
|
225 |
+
token_count=gen_limit,
|
226 |
+
temperature=1.0,
|
227 |
+
top_p=0.3,
|
228 |
+
presencePenalty=0.3,
|
229 |
+
countPenalty=0.3,
|
230 |
+
):
|
231 |
+
args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
|
232 |
+
alpha_frequency = countPenalty,
|
233 |
+
alpha_presence = presencePenalty,
|
234 |
+
token_ban = [], # ban the generation of some tokens
|
235 |
+
token_stop = [0]) # stop generation whenever you see any token here
|
236 |
+
ctx = qa_prompt(ctx)
|
237 |
+
all_tokens = []
|
238 |
+
out_last = 0
|
239 |
+
out_str = ''
|
240 |
+
occurrence = {}
|
241 |
+
state = copy.deepcopy(state_wyw)
|
242 |
+
for i in range(int(token_count)):
|
243 |
+
out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state)
|
244 |
+
for n in occurrence:
|
245 |
+
out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
|
246 |
+
|
247 |
+
token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
|
248 |
+
if token in args.token_stop:
|
249 |
+
break
|
250 |
+
all_tokens += [token]
|
251 |
+
for xxx in occurrence:
|
252 |
+
occurrence[xxx] *= penalty_decay
|
253 |
+
if token not in occurrence:
|
254 |
+
occurrence[token] = 1
|
255 |
+
else:
|
256 |
+
occurrence[token] += 1
|
257 |
+
|
258 |
+
tmp = pipeline.decode(all_tokens[out_last:])
|
259 |
+
if '\ufffd' not in tmp:
|
260 |
+
out_str += tmp
|
261 |
+
yield out_str.strip()
|
262 |
+
out_last = i + 1
|
263 |
+
|
264 |
+
gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
|
265 |
+
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
266 |
+
print(f'{timestamp} - vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
|
267 |
+
del out
|
268 |
+
del state
|
269 |
+
gc.collect()
|
270 |
+
torch.cuda.empty_cache()
|
271 |
+
yield out_str.strip()
|
272 |
+
|
273 |
examples = [
|
274 |
["Assistant: How can we craft an engaging story featuring vampires on Mars? Let's think step by step and provide an expert response.", gen_limit, 1, 0.3, 0.5, 0.5],
|
275 |
["Assistant: How can we persuade Elon Musk to follow you on Twitter? Let's think step by step and provide an expert response.", gen_limit, 1, 0.3, 0.5, 0.5],
|
|
|
304 |
["用HTML编写一个简单的网站。当用户点击按钮时,从4个笑话的列表中随机显示一个笑话。", gen_limit_long, 1, 0.2, 0.3, 0.3],
|
305 |
]
|
306 |
|
307 |
+
examples_wyw = [
|
308 |
+
["我和前男友分手了", gen_limit_long, 1, 0.2, 0.3, 0.3],
|
309 |
+
["量子计算机的原理", gen_limit_long, 1, 0.2, 0.3, 0.3],
|
310 |
+
["李白和杜甫的结拜故事", gen_limit_long, 1, 0.2, 0.3, 0.3],
|
311 |
+
["林黛玉和伏地魔的关系是什么?", gen_limit_long, 1, 0.2, 0.3, 0.3],
|
312 |
+
["我被同事陷害了,帮我写一篇文言文骂他", gen_limit_long, 1, 0.2, 0.3, 0.3],
|
313 |
+
]
|
314 |
+
|
315 |
##########################################################################
|
316 |
|
317 |
with gr.Blocks(title=title) as demo:
|
|
|
377 |
clear.click(lambda: None, [], [output])
|
378 |
data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])
|
379 |
|
380 |
+
with gr.Tab("=== WenYanWen Q/A ==="):
|
381 |
+
gr.Markdown(f"This is [RWKV-6](https://huggingface.co/BlinkDL/rwkv-6-world) state-tuned to [WenYanWen 文言文 Q/A](https://huggingface.co/BlinkDL/temp-latest-training-models/blob/main/{wyw_name}.pth). RWKV is a 100% attention-free RNN [RWKV-LM](https://github.com/BlinkDL/RWKV-LM), and we have [300+ Github RWKV projects](https://github.com/search?o=desc&p=1&q=rwkv&s=updated&type=Repositories). Demo limited to ctxlen {ctx_limit}.")
|
382 |
+
with gr.Row():
|
383 |
+
with gr.Column():
|
384 |
+
prompt = gr.Textbox(lines=2, label="Prompt", value="我和前男友分手了")
|
385 |
+
token_count = gr.Slider(10, gen_limit_long, label="Max Tokens", step=10, value=gen_limit_long)
|
386 |
+
temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
|
387 |
+
top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.2)
|
388 |
+
presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0.3)
|
389 |
+
count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.3)
|
390 |
+
with gr.Column():
|
391 |
+
with gr.Row():
|
392 |
+
submit = gr.Button("Submit", variant="primary")
|
393 |
+
clear = gr.Button("Clear", variant="secondary")
|
394 |
+
output = gr.Textbox(label="Output", lines=30)
|
395 |
+
# Show the WenYanWen examples in this tab (was examples_chn, a copy-paste slip from the Chinese Q/A tab).
data = gr.Dataset(components=[prompt, token_count, temperature, top_p, presence_penalty, count_penalty], samples=examples_wyw, samples_per_page=50, label="Examples", headers=["Prompt", "Max Tokens", "Temperature", "Top P", "Presence Penalty", "Count Penalty"])
|
396 |
+
submit.click(evaluate_wyw, [prompt, token_count, temperature, top_p, presence_penalty, count_penalty], [output])
|
397 |
+
clear.click(lambda: None, [], [output])
|
398 |
+
data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])
|
399 |
+
|
400 |
|
401 |
demo.queue(concurrency_count=1, max_size=10)
|
402 |
demo.launch(share=False)
|