alex6095 commited on
Commit
b13a6c0
1 Parent(s): 245f241

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -60,9 +60,10 @@ if text:
60
  return_tensors='pt',
61
  return_length=True
62
  )
63
-
64
-
65
- outputs = model(encoded_dict['input_ids'], encoded_dict['length'].unsqueeze(0))
 
66
 
67
  _, preds = torch.max(outputs, 1)
68
 
 
60
  return_tensors='pt',
61
  return_length=True
62
  )
63
+ input_ids = encoded_dict['length'].unsqueeze(0)
64
+ attn_mask = torch.arange(input_ids.size(1)).to(device)
65
+ attn_mask = attn_mask[None, :] < input_ids_len[:, None]
66
+ outputs = model(encoded_dict['input_ids'], attn_mask)
67
 
68
  _, preds = torch.max(outputs, 1)
69