BlinkDL committed on
Commit
b761794
1 Parent(s): 8919796

Update app.py

Files changed (1)
  1. app.py +95 -5
app.py CHANGED
@@ -25,24 +25,36 @@ pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
 
 args = model.args
 eng_name = 'rwkv-x060-eng_single_round_qa-7B-20240516-ctx2048'
-chn_name = 'rwkv-x060-chn_single_round_qa-7B-20240516-ctx2048'
 eng_file = hf_hub_download(repo_id="BlinkDL/temp-latest-training-models", filename=f"{eng_name}.pth")
-chn_file = hf_hub_download(repo_id="BlinkDL/temp-latest-training-models", filename=f"{chn_name}.pth")
 state_eng_raw = torch.load(eng_file)
-state_chn_raw = torch.load(chn_file)
 state_eng = [None] * args.n_layer * 3
+
+chn_name = 'rwkv-x060-chn_single_round_qa-7B-20240516-ctx2048'
+chn_file = hf_hub_download(repo_id="BlinkDL/temp-latest-training-models", filename=f"{chn_name}.pth")
+state_chn_raw = torch.load(chn_file)
 state_chn = [None] * args.n_layer * 3
+
+wyw_name = 'rwkv-x060-chn_文言文和古典名著_single_round_qa-7B-20240601-ctx2048'
+wyw_file = hf_hub_download(repo_id="BlinkDL/temp-latest-training-models", filename=f"{wyw_name}.pth")
+state_wyw_raw = torch.load(wyw_file)
+state_wyw = [None] * args.n_layer * 3
+
 for i in range(args.n_layer):
     dd = model.strategy[i]
     dev = dd.device
     atype = dd.atype
     state_eng[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
-    state_chn[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
     state_eng[i*3+1] = state_eng_raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
-    state_chn[i*3+1] = state_chn_raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
     state_eng[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
+
+    state_chn[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
+    state_chn[i*3+1] = state_chn_raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
     state_chn[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
 
+    state_wyw[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
+    state_wyw[i*3+1] = state_wyw_raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
+    state_wyw[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
+
 def generate_prompt(instruction, input=""):
     instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
     input = input.strip().replace('\r\n','\n').replace('\n\n','\n')
@@ -208,6 +220,56 @@ def evaluate_chn(
     torch.cuda.empty_cache()
     yield out_str.strip()
 
+def evaluate_wyw(
+    ctx,
+    token_count=gen_limit,
+    temperature=1.0,
+    top_p=0.3,
+    presencePenalty=0.3,
+    countPenalty=0.3,
+):
+    args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
+                         alpha_frequency = countPenalty,
+                         alpha_presence = presencePenalty,
+                         token_ban = [], # ban the generation of some tokens
+                         token_stop = [0]) # stop generation whenever you see any token here
+    ctx = qa_prompt(ctx)
+    all_tokens = []
+    out_last = 0
+    out_str = ''
+    occurrence = {}
+    state = copy.deepcopy(state_wyw)
+    for i in range(int(token_count)):
+        out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state)
+        for n in occurrence:
+            out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
+
+        token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
+        if token in args.token_stop:
+            break
+        all_tokens += [token]
+        for xxx in occurrence:
+            occurrence[xxx] *= penalty_decay
+        if token not in occurrence:
+            occurrence[token] = 1
+        else:
+            occurrence[token] += 1
+
+        tmp = pipeline.decode(all_tokens[out_last:])
+        if '\ufffd' not in tmp:
+            out_str += tmp
+            yield out_str.strip()
+            out_last = i + 1
+
+    gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
+    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+    print(f'{timestamp} - vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
+    del out
+    del state
+    gc.collect()
+    torch.cuda.empty_cache()
+    yield out_str.strip()
+
 examples = [
     ["Assistant: How can we craft an engaging story featuring vampires on Mars? Let's think step by step and provide an expert response.", gen_limit, 1, 0.3, 0.5, 0.5],
     ["Assistant: How can we persuade Elon Musk to follow you on Twitter? Let's think step by step and provide an expert response.", gen_limit, 1, 0.3, 0.5, 0.5],
@@ -242,6 +304,14 @@ examples_chn = [
     ["用HTML编写一个简单的网站。当用户点击按钮时,从4个笑话的列表中随机显示一个笑话。", gen_limit_long, 1, 0.2, 0.3, 0.3],
 ]
 
+examples_wyw = [
+    ["我和前男友分手了", gen_limit_long, 1, 0.2, 0.3, 0.3],
+    ["量子计算机的原理", gen_limit_long, 1, 0.2, 0.3, 0.3],
+    ["李白和杜甫的结拜故事", gen_limit_long, 1, 0.2, 0.3, 0.3],
+    ["林黛玉和伏地魔的关系是什么?", gen_limit_long, 1, 0.2, 0.3, 0.3],
+    ["我被同事陷害了,帮我写一篇文言文骂他", gen_limit_long, 1, 0.2, 0.3, 0.3],
+]
+
 ##########################################################################
 
 with gr.Blocks(title=title) as demo:
@@ -307,6 +377,26 @@ with gr.Blocks(title=title) as demo:
         clear.click(lambda: None, [], [output])
         data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])
 
+    with gr.Tab("=== WenYanWen Q/A ==="):
+        gr.Markdown(f"This is [RWKV-6](https://huggingface.co/BlinkDL/rwkv-6-world) state-tuned to [WenYanWen 文言文 Q/A](https://huggingface.co/BlinkDL/temp-latest-training-models/blob/main/{wyw_name}.pth). RWKV is a 100% attention-free RNN [RWKV-LM](https://github.com/BlinkDL/RWKV-LM), and we have [300+ Github RWKV projects](https://github.com/search?o=desc&p=1&q=rwkv&s=updated&type=Repositories). Demo limited to ctxlen {ctx_limit}.")
+        with gr.Row():
+            with gr.Column():
+                prompt = gr.Textbox(lines=2, label="Prompt", value="我和前男友分手了")
+                token_count = gr.Slider(10, gen_limit_long, label="Max Tokens", step=10, value=gen_limit_long)
+                temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
+                top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.2)
+                presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0.3)
+                count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.3)
+            with gr.Column():
+                with gr.Row():
+                    submit = gr.Button("Submit", variant="primary")
+                    clear = gr.Button("Clear", variant="secondary")
+                output = gr.Textbox(label="Output", lines=30)
+        data = gr.Dataset(components=[prompt, token_count, temperature, top_p, presence_penalty, count_penalty], samples=examples_wyw, samples_per_page=50, label="Examples", headers=["Prompt", "Max Tokens", "Temperature", "Top P", "Presence Penalty", "Count Penalty"])
+        submit.click(evaluate_wyw, [prompt, token_count, temperature, top_p, presence_penalty, count_penalty], [output])
+        clear.click(lambda: None, [], [output])
+        data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])
+
 
 demo.queue(concurrency_count=1, max_size=10)
 demo.launch(share=False)
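Note: the three state-initialization blocks in the first hunk repeat the same per-layer pattern (slots 0 and 2 are zero-initialized; slot 1 carries the tuned time_state tensor). A minimal refactor sketch of that pattern, assuming app.py's globals (torch, model, args) and a hypothetical make_init_state helper that is not part of this commit:

def make_init_state(raw):
    # Build one per-language init state from a raw state-tuned checkpoint,
    # following the same 3-slots-per-layer layout as the loop above.
    state = [None] * args.n_layer * 3
    for i in range(args.n_layer):
        dd = model.strategy[i]
        state[i*3+0] = torch.zeros(args.n_embd, dtype=dd.atype, requires_grad=False, device=dd.device).contiguous()
        state[i*3+1] = raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dd.device).requires_grad_(False).contiguous()
        state[i*3+2] = torch.zeros(args.n_embd, dtype=dd.atype, requires_grad=False, device=dd.device).contiguous()
    return state

state_eng = make_init_state(state_eng_raw)
state_chn = make_init_state(state_chn_raw)
state_wyw = make_init_state(state_wyw_raw)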
 
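Each evaluate_* function is a generator: it deep-copies its language's tuned state (so one request cannot mutate the shared copy) and yields the cumulative decoded text whenever the new tokens decode without a dangling UTF-8 fragment. A minimal usage sketch outside Gradio, assuming app.py's globals are initialized:

out = ""
# Stream a completion from the WenYanWen-tuned state; the prompt is one of
# the examples above ("the principles of quantum computers").
for out in evaluate_wyw("量子计算机的原理", token_count=200):
    pass  # each yield is the full text generated so far
print(out)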