NeverlandPeter commited on
Commit
43d3d85
1 Parent(s): cbf69e8
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -20,7 +20,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
 
21
  ctx_limit = 2500
22
  gen_limit = 500
23
- ENABLE_VISUAL = True
24
 
25
  ########################## text rwkv ################################################################
26
  from rwkv.utils import PIPELINE, PIPELINE_ARGS
@@ -35,13 +35,13 @@ args = model_v6.args
35
  eng_name = 'rwkv-x060-eng_single_round_qa-3B-20240430-ctx1024'
36
  chn_name = 'rwkv-x060-chn_single_round_qa-3B-20240505-ctx1024'
37
 
38
- # state_eng_raw = torch.load(f'/mnt/e/RWKV-Runner/models/{eng_name}.pth')
39
- # state_chn_raw = torch.load(f'/mnt/e/RWKV-Runner/models/{chn_name}.pth')
40
 
41
  eng_file = hf_hub_download(repo_id="BlinkDL/temp-latest-training-models", filename=f"{eng_name}.pth")
42
  chn_file = hf_hub_download(repo_id="BlinkDL/temp-latest-training-models", filename=f"{chn_name}.pth")
43
- state_eng_raw = torch.load(eng_file)
44
- state_chn_raw = torch.load(chn_file)
45
 
46
  state_eng = [None] * args.n_layer * 3
47
  state_chn = [None] * args.n_layer * 3
 
20
 
21
  ctx_limit = 2500
22
  gen_limit = 500
23
+ ENABLE_VISUAL = False
24
 
25
  ########################## text rwkv ################################################################
26
  from rwkv.utils import PIPELINE, PIPELINE_ARGS
 
35
  eng_name = 'rwkv-x060-eng_single_round_qa-3B-20240430-ctx1024'
36
  chn_name = 'rwkv-x060-chn_single_round_qa-3B-20240505-ctx1024'
37
 
38
+ # state_eng_raw = torch.load(f'/mnt/e/RWKV-Runner/models/{eng_name}.pth', map_location=torch.device('cpu'))
39
+ # state_chn_raw = torch.load(f'/mnt/e/RWKV-Runner/models/{chn_name}.pth', map_location=torch.device('cpu'))
40
 
41
  eng_file = hf_hub_download(repo_id="BlinkDL/temp-latest-training-models", filename=f"{eng_name}.pth")
42
  chn_file = hf_hub_download(repo_id="BlinkDL/temp-latest-training-models", filename=f"{chn_name}.pth")
43
+ state_eng_raw = torch.load(eng_file, map_location=torch.device('cpu'))
44
+ state_chn_raw = torch.load(chn_file, map_location=torch.device('cpu'))
45
 
46
  state_eng = [None] * args.n_layer * 3
47
  state_chn = [None] * args.n_layer * 3