"""Evaluate a fine-tuned token-classification model on the merged dataset's test
split and export each test sentence with its gold ORG entities as sample data."""

import json
import os
import pickle

import huggingface_hub
import torch
from datasets import load_from_disk
from rich import print
from tqdm import tqdm
from transformers import AutoModelForTokenClassification, AutoTokenizer
# from transformers import DataCollatorForTokenClassification

from evaluate_model import compute_metrics
from utils import tokenize_and_align_labels

# _ = load_dotenv(find_dotenv())  # read local .env file

# Authenticate with the Hugging Face Hub using the token from the environment.
hf_token = os.environ['HF_TOKEN']
huggingface_hub.login(hf_token)

checkpoint = 'elshehawy/finer-ord-transformers'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

data_path = './data/merged_dataset/'
test = load_from_disk(data_path)['test']

# ClassLabel feature used to map integer NER tags back to their string names.
feature_path = './data/ner_feature.pickle'
with open(feature_path, 'rb') as f:
    ner_feature = pickle.load(f)

# data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

ner_model = AutoModelForTokenClassification.from_pretrained(checkpoint)


# Tokenize the test split and align the word-level NER labels with sub-word
# tokens; drop the original columns the model does not consume.
tokenized_test = test.map(
    tokenize_and_align_labels,
    batched=True,
    batch_size=None,
    remove_columns=test.column_names[2:],
    fn_kwargs={'tokenizer': tokenizer},
)

# tokenized_dataset.set_format('torch')


def collate_fn(data):
    """Gather each field of the batch into a plain Python list; tensors are
    built later, inside the evaluation loop."""
    input_ids = [element['input_ids'] for element in data]
    attention_mask = [element['attention_mask'] for element in data]
    token_type_ids = [element['token_type_ids'] for element in data]
    labels = [element['labels'] for element in data]

    return input_ids, token_type_ids, attention_mask, labels


loader = torch.utils.data.DataLoader(tokenized_test, batch_size=32, collate_fn=collate_fn)

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print(device)

# Move the model to the target device once and switch it to evaluation mode.
ner_model.to(device)
ner_model.eval()



def get_metrics_trf():
    """Run the fine-tuned model over the test loader and compute token-level metrics."""
    y_true, logits = [], []

    for input_ids, token_type_ids, attention_mask, labels in tqdm(loader):
        with torch.no_grad():
            logits.extend(
                ner_model(
                    input_ids=torch.tensor(input_ids).to(device),
                    token_type_ids=torch.tensor(token_type_ids).to(device),
                    attention_mask=torch.tensor(attention_mask).to(device),
                ).logits.cpu().numpy()
            )
        y_true.extend(labels)

    all_metrics = compute_metrics((logits, y_true))

    # Optionally persist the metrics to disk:
    # with open('./metrics/trf/metrics.json', 'w') as f:
    #     json.dump(all_metrics, f)

    return all_metrics


def find_orgs(tokens, labels):
    """Group consecutive B-ORG/I-ORG tokens and return each span as a detokenized string."""
    orgs = []
    prev_tok_id = -2  # index of the last token added to the current ORG span

    for i, (token, label) in enumerate(zip(tokens, labels)):
        if label == 'B-ORG':
            # Start a new organization span.
            orgs.append([token])
            prev_tok_id = i
        elif label == 'I-ORG' and orgs and (i - 1) == prev_tok_id:
            # Continue the current span only if it directly follows the previous ORG token.
            orgs[-1].append(token)
            prev_tok_id = i

    return [tokenizer.convert_tokens_to_string(org) for org in orgs]



def store_sample_data():
    """Dump each test sentence with its text and gold ORG entities to a JSON file."""
    test_data = []

    for sent in test:
        # Map integer NER tags back to their string labels (e.g. 'B-ORG').
        labels = [ner_feature.feature.int2str(l) for l in sent['ner_tags']]
        sent_orgs = find_orgs(sent['tokens'], labels)

        sent_text = tokenizer.convert_tokens_to_string(sent['tokens'])
        test_data.append({
            'id': sent['id'],
            'text': sent_text,
            'orgs': sent_orgs,
        })

    with open('data/sample_data.json', 'w') as f:
        json.dump(test_data, f)
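

# Optional entry point (a suggested addition, not required by the functions above):
# compute the test-set metrics and export the ORG sample data when this file is
# executed directly; both steps can also be imported and run individually.
if __name__ == '__main__':
    all_metrics = get_metrics_trf()
    print(all_metrics)
    store_sample_data()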