Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -60,9 +60,10 @@ if text:
|
|
60 |
return_tensors='pt',
|
61 |
return_length=True
|
62 |
)
|
63 |
-
|
64 |
-
|
65 |
-
|
|
|
66 |
|
67 |
_, preds = torch.max(outputs, 1)
|
68 |
|
|
|
60 |
return_tensors='pt',
|
61 |
return_length=True
|
62 |
)
|
63 |
+
input_ids = encoded_dict['length'].unsqueeze(0)
|
64 |
+
attn_mask = torch.arange(input_ids.size(1)).to(device)
|
65 |
+
attn_mask = attn_mask[None, :] < input_ids_len[:, None]
|
66 |
+
outputs = model(encoded_dict['input_ids'], attn_mask)
|
67 |
|
68 |
_, preds = torch.max(outputs, 1)
|
69 |
|