from evaluate_model import compute_metrics
from datasets import load_from_disk, Dataset
from transformers import AutoTokenizer
import os
import pickle
from transformers import AutoModelForTokenClassification
# from transformers import DataCollatorForTokenClassification
from utils import tokenize_and_align_labels
from rich import print
import huggingface_hub
import torch
import json
from tqdm import tqdm

# _ = load_dotenv(find_dotenv()) # read local .env file
hf_token = os.environ['HF_TOKEN']
huggingface_hub.login(hf_token)

checkpoint = 'elshehawy/finer-ord-transformers'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

data_path = './data/merged_dataset/'

# Evaluate on the test split; keep only the first 16 examples for a quick run.
test = load_from_disk(data_path)['test']
test = Dataset.from_dict(test[:16])

feature_path = './data/ner_feature.pickle'

# The pickled ClassLabel feature, used to map NER tag ids back to string labels.
with open(feature_path, 'rb') as f:
    ner_feature = pickle.load(f)

    
# data_collator  = DataCollatorForTokenClassification(tokenizer=tokenizer)
    
ner_model = AutoModelForTokenClassification.from_pretrained(checkpoint)

# tokenized_dataset.set_format('torch')

def collate_fn(data):
    """Collate a batch of tokenized examples into parallel lists of fields."""
    input_ids = [element['input_ids'] for element in data]
    attention_mask = [element['attention_mask'] for element in data]
    token_type_ids = [element['token_type_ids'] for element in data]
    labels = [element['labels'] for element in data]

    return input_ids, token_type_ids, attention_mask, labels


# Put the model in evaluation mode for inference.
ner_model = ner_model.eval()



def get_metrics_trf(data):
    """Tokenize `data`, run the NER model over it in batches, and return the metrics."""
    data = Dataset.from_dict(data)

    tokenized_data = data.map(
        tokenize_and_align_labels,
        batched=True,
        batch_size=None,
        remove_columns=data.column_names[2:],
        fn_kwargs={'tokenizer': tokenizer}
    )

    loader = torch.utils.data.DataLoader(tokenized_data, batch_size=16, collate_fn=collate_fn)
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    ner_model.to(device)

    y_true, logits = [], []
    for input_ids, token_type_ids, attention_mask, labels in tqdm(loader):
        with torch.no_grad():
            logits.extend(
                ner_model(
                    input_ids=torch.tensor(input_ids).to(device),
                    token_type_ids=torch.tensor(token_type_ids).to(device),
                    attention_mask=torch.tensor(attention_mask).to(device)
                ).logits.cpu().numpy()
            )
            y_true.extend(labels)

    all_metrics = compute_metrics((logits, y_true))
    return all_metrics

    # with open('./metrics/trf/metrics.json', 'w') as f:
    #     json.dump(all_metrics, f)
    

def find_orgs_in_data(tokens, labels):
    """Collect contiguous ORG spans from BIO labels and return them as strings."""
    orgs = []
    prev_tok_id = 0
    for i, (token, label) in enumerate(zip(tokens, labels)):
        if label == 'B-ORG':
            # Start a new ORG span.
            orgs.append([token])
            prev_tok_id = i

        # Extend the latest span only if this token directly follows the previous ORG token.
        if label == 'I-ORG' and (i - 1) == prev_tok_id and orgs:
            orgs[-1].append(token)
            prev_tok_id = i

    return [tokenizer.convert_tokens_to_string(org) for org in orgs]



def store_sample_data(data):
    """Build a list of {id, text, orgs} records, one per sentence in `data`."""
    data = Dataset.from_dict(data)
    test_data = []

    for sent in data:
        labels = [ner_feature.feature.int2str(l) for l in sent['ner_tags']]
        sent_orgs = find_orgs_in_data(sent['tokens'], labels)

        sent_text = tokenizer.convert_tokens_to_string(sent['tokens'])
        test_data.append({
            'id': sent['id'],
            'text': sent_text,
            'orgs': sent_orgs
        })

    return test_data
    # with open('./data/sample_data.json', 'w') as f:
    #     json.dump(test_data, f)
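

# Example entry point: a minimal sketch added for illustration, not part of the
# original pipeline. It assumes the 16-example `test` slice loaded above exposes
# the 'id', 'tokens', and 'ner_tags' columns these helpers expect, and it reuses
# the output path from the commented-out dump above.
if __name__ == '__main__':
    metrics = get_metrics_trf(test[:])
    print(metrics)

    sample_data = store_sample_data(test[:])
    with open('./data/sample_data.json', 'w') as f:
        json.dump(sample_data, f)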