# ner-demo-evaluate / evaluate_data.py
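"""Evaluate a fine-tuned token-classification model on a slice of the test
set and export per-sentence ORG annotations as JSON for the demo."""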
from evaluate_model import compute_metrics
from datasets import load_from_disk, Dataset
from transformers import AutoTokenizer, AutoModelForTokenClassification
import os
import pickle
from utils import tokenize_and_align_labels
from rich import print
import huggingface_hub
import torch
import json
from tqdm import tqdm

# _ = load_dotenv(find_dotenv())  # read local .env file

# Authenticate with the Hugging Face Hub; HF_TOKEN must be set in the environment.
hf_token = os.environ['HF_TOKEN']
huggingface_hub.login(hf_token)
checkpoint = 'elshehawy/finer-ord-transformers'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

data_path = './data/merged_dataset/'
test = load_from_disk(data_path)['test']
# Keep only the first 105 examples for a quick evaluation run.
test = Dataset.from_dict(test[:105])
# Pickled dataset feature (a sequence of class labels) used to map integer
# NER tags back to their string names.
feature_path = './data/ner_feature.pickle'
with open(feature_path, 'rb') as f:
    ner_feature = pickle.load(f)

ner_model = AutoModelForTokenClassification.from_pretrained(checkpoint)

# Tokenize and align the word-level NER labels with the sub-word tokens.
tokenized_test = test.map(
    tokenize_and_align_labels,
    batched=True,
    batch_size=None,
    remove_columns=test.column_names[2:],
    fn_kwargs={'tokenizer': tokenizer},
)

def collate_fn(batch):
    """Collect each field of the batch into plain Python lists.

    No padding is applied here, so `tokenize_and_align_labels` is assumed to
    yield sequences of uniform length within a batch.
    """
    input_ids = [element['input_ids'] for element in batch]
    attention_mask = [element['attention_mask'] for element in batch]
    token_type_ids = [element['token_type_ids'] for element in batch]
    labels = [element['labels'] for element in batch]
    return input_ids, token_type_ids, attention_mask, labels


loader = torch.utils.data.DataLoader(tokenized_test, batch_size=16, collate_fn=collate_fn)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Move the model to the device once, instead of on every batch.
ner_model = ner_model.to(device).eval()

def get_metrics_trf():
    """Run the model over the test loader and compute NER metrics."""
    print(device)
    y_true, logits = [], []
    for input_ids, token_type_ids, attention_mask, labels in tqdm(loader):
        with torch.no_grad():
            # Collect per-token logits for the whole batch on the CPU.
            logits.extend(
                ner_model(
                    input_ids=torch.tensor(input_ids).to(device),
                    token_type_ids=torch.tensor(token_type_ids).to(device),
                    attention_mask=torch.tensor(attention_mask).to(device),
                ).logits.cpu().numpy()
            )
        y_true.extend(labels)

    all_metrics = compute_metrics((logits, y_true))
    return all_metrics

def find_orgs(tokens, labels):
    """Extract ORG entity strings from IOB-tagged tokens."""
    orgs = []
    prev_tok_id = -2  # sentinel: a leading I-ORG with no B-ORG is ignored
    for i, (token, label) in enumerate(zip(tokens, labels)):
        if label == 'B-ORG':
            # Start a new organization span.
            orgs.append([token])
            prev_tok_id = i
        elif label == 'I-ORG' and (i - 1) == prev_tok_id:
            # Extend the current span only if this token directly follows
            # the previous ORG token.
            orgs[-1].append(token)
            prev_tok_id = i
    return [tokenizer.convert_tokens_to_string(org) for org in orgs]
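
# Hypothetical usage sketch (tokens and labels invented for illustration);
# with a WordPiece-style tokenizer this yields roughly:
#   find_orgs(['Apple', 'Inc', 'sued', 'Samsung'],
#             ['B-ORG', 'I-ORG', 'O', 'B-ORG'])
#   -> ['Apple Inc', 'Samsung']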

def store_sample_data():
    """Dump id, raw text, and gold ORG spans for each test sentence."""
    test_data = []
    for sent in test:
        # Map integer tag ids back to their string labels (e.g. 'B-ORG').
        labels = [ner_feature.feature.int2str(l) for l in sent['ner_tags']]
        sent_orgs = find_orgs(sent['tokens'], labels)
        sent_text = tokenizer.convert_tokens_to_string(sent['tokens'])
        test_data.append({
            'id': sent['id'],
            'text': sent_text,
            'orgs': sent_orgs,
        })

    with open('data/sample_data.json', 'w') as f:
        json.dump(test_data, f)
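
# Hypothetical entry point, added as a sketch: runs the evaluation and the
# sample-data export, and persists the metrics to ./metrics/trf/metrics.json.
# Assumes the metrics object returned by compute_metrics is JSON-serializable.
if __name__ == '__main__':
    all_metrics = get_metrics_trf()
    os.makedirs('./metrics/trf', exist_ok=True)
    with open('./metrics/trf/metrics.json', 'w') as f:
        json.dump(all_metrics, f)
    store_sample_data()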