Spaces:
Running
on
T4
Running
on
T4
Update app.py
Browse files
app.py
CHANGED
@@ -52,7 +52,7 @@ for i in range(args.n_layer):
|
|
52 |
state_chn[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
|
53 |
|
54 |
state_wyw[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
|
55 |
-
state_wyw[i*3+1] =
|
56 |
state_wyw[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
|
57 |
|
58 |
def generate_prompt(instruction, input=""):
|
|
|
52 |
state_chn[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
|
53 |
|
54 |
state_wyw[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
|
55 |
+
state_wyw[i*3+1] = state_wyw_raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
|
56 |
state_wyw[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
|
57 |
|
58 |
def generate_prompt(instruction, input=""):
|