elshehawy commited on
Commit
37b391d
β€’
1 Parent(s): 10eaeda

fix UnboundLocalError

Browse files
Files changed (1) hide show
  1. evaluate_data.py +2 -1
evaluate_data.py CHANGED
@@ -55,6 +55,7 @@ def collate_fn(data):
55
 
56
  loader = torch.utils.data.DataLoader(tokenized_test, batch_size=32, collate_fn=collate_fn)
57
  device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
 
58
 
59
  ner_model = ner_model.eval()
60
 
@@ -64,7 +65,7 @@ def get_metrics_trf():
64
  y_true, logits = [], []
65
 
66
  for input_ids, token_type_ids, attention_mask, labels in tqdm(loader):
67
- ner_model = ner_model.to(device)
68
  with torch.no_grad():
69
  logits.extend(
70
  ner_model(
 
55
 
56
  loader = torch.utils.data.DataLoader(tokenized_test, batch_size=32, collate_fn=collate_fn)
57
  device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
58
+ print(device)
59
 
60
  ner_model = ner_model.eval()
61
 
 
65
  y_true, logits = [], []
66
 
67
  for input_ids, token_type_ids, attention_mask, labels in tqdm(loader):
68
+ ner_model.to(device)
69
  with torch.no_grad():
70
  logits.extend(
71
  ner_model(