elshehawy committed on
Commit
894b24d
•
1 Parent(s): 6c47a7b

add application files

Files changed (37)
  1. .gitignore +5 -0
  2. app.py +100 -0
  3. data/merged_dataset/dataset_dict.json +1 -0
  4. data/merged_dataset/orig_test/cache-eeafde0b6770e328.arrow +3 -0
  5. data/merged_dataset/orig_test/data-00000-of-00001.arrow +3 -0
  6. data/merged_dataset/orig_test/dataset_info.json +34 -0
  7. data/merged_dataset/orig_test/state.json +13 -0
  8. data/merged_dataset/orig_train/cache-45d1543dc33c36be.arrow +3 -0
  9. data/merged_dataset/orig_train/data-00000-of-00001.arrow +3 -0
  10. data/merged_dataset/orig_train/dataset_info.json +34 -0
  11. data/merged_dataset/orig_train/state.json +13 -0
  12. data/merged_dataset/orig_validation/cache-afff9bbc07b5bee3.arrow +3 -0
  13. data/merged_dataset/orig_validation/data-00000-of-00001.arrow +3 -0
  14. data/merged_dataset/orig_validation/dataset_info.json +34 -0
  15. data/merged_dataset/orig_validation/state.json +13 -0
  16. data/merged_dataset/test/cache-3a6709085dd0f520.arrow +3 -0
  17. data/merged_dataset/test/cache-50fbc051d6b536f8.arrow +3 -0
  18. data/merged_dataset/test/cache-7344e423192cdf30.arrow +3 -0
  19. data/merged_dataset/test/cache-861a0fd50d74bfe1.arrow +3 -0
  20. data/merged_dataset/test/data-00000-of-00001.arrow +3 -0
  21. data/merged_dataset/test/dataset_info.json +34 -0
  22. data/merged_dataset/test/state.json +13 -0
  23. data/merged_dataset/train/cache-f8f6a910898e33f3.arrow +3 -0
  24. data/merged_dataset/train/data-00000-of-00001.arrow +3 -0
  25. data/merged_dataset/train/dataset_info.json +34 -0
  26. data/merged_dataset/train/state.json +13 -0
  27. data/merged_dataset/validation/cache-a70cdc1f600f2440.arrow +3 -0
  28. data/merged_dataset/validation/cache-c442280565074102.arrow +3 -0
  29. data/merged_dataset/validation/data-00000-of-00001.arrow +3 -0
  30. data/merged_dataset/validation/dataset_info.json +34 -0
  31. data/merged_dataset/validation/state.json +13 -0
  32. data/ner_feature.pickle +3 -0
  33. data/sample_data.json +0 -0
  34. evaluate_model.py +62 -0
  35. metrics.py +78 -0
  36. requirements.txt +7 -0
  37. utils.py +83 -0
.gitignore ADDED
@@ -0,0 +1,5 @@
+ .ipynb_checkpoints/
+ Untitled.ipynb
+ __pycache__/
+ evaluate_trf.ipynb
+ test.json
app.py ADDED
@@ -0,0 +1,100 @@
+ from metrics import calc_metrics
+ import gradio as gr
+ from openai import OpenAI
+ import os
+
+ from transformers import pipeline
+ # from dotenv import load_dotenv, find_dotenv
+ import huggingface_hub
+ import json
+ # from simcse import SimCSE  # used for the GPT similarity metrics
+ from evaluate_data import store_sample_data, get_metrics_trf
+
+ store_sample_data()
+
+
+ with open('./data/sample_data.json', 'r') as f:
+     # sample_data = [
+     #     {'id': "", 'text': "", 'orgs': ["", ""]}
+     # ]
+     sample_data = json.load(f)
+
+ # _ = load_dotenv(find_dotenv())  # read local .env file
+ hf_token = os.environ['HF_TOKEN']
+ huggingface_hub.login(hf_token)
+
+ pipe = pipeline("token-classification", model="elshehawy/finer-ord-transformers", aggregation_strategy="first")
+
+ llm_model = 'gpt-3.5-turbo-0125'
+ # openai.api_key = os.environ['OPENAI_API_KEY']
+
+ client = OpenAI(
+     api_key=os.environ.get("OPENAI_API_KEY"),
+ )
+
+
+ def get_completion(prompt, model=llm_model):
+     messages = [{"role": "user", "content": prompt}]
+     response = client.chat.completions.create(
+         messages=messages,
+         model=model,
+         temperature=0,
+     )
+     return response.choices[0].message.content
+
+
+ def find_orgs_gpt(sentence):
+     prompt = f"""
+ In the context of named entity recognition (NER), find all organizations in the text delimited by triple backticks.
+
+ text:
+ ```
+ {sentence}
+ ```
+ You should output only a list of organizations and follow this output format exactly: ["org_1", "org_2", "org_3"]
+ """
+
+     sent_orgs_str = get_completion(prompt)
+     sent_orgs = json.loads(sent_orgs_str)
+
+     return sent_orgs
+
+
+ # def find_orgs_trf(sentence):
+ #     org_list = []
+ #     for ent in pipe(sentence):
+ #         if ent['entity_group'] == 'ORG':
+ #             org_list.append(ent['word'])
+ #     return list(set(org_list))
+
+
+ true_orgs = [sent['orgs'] for sent in sample_data]
+
+ predicted_orgs_gpt = [find_orgs_gpt(sent['text']) for sent in sample_data]
+ # predicted_orgs_trf = [find_orgs_trf(sent['text']) for sent in sample_data]
+
+ all_metrics = {}
+
+ # sim_model = SimCSE('sentence-transformers/all-MiniLM-L6-v2')
+ # all_metrics['gpt'] = calc_metrics(true_orgs, predicted_orgs_gpt, sim_model)
+ all_metrics['trf'] = get_metrics_trf()
+
+
+ example = """
+ My latest exclusive for The Hill: Conservative frustration over Republican efforts to force a House vote on reauthorizing the Export-Import Bank boiled over Wednesday during a contentious GOP meeting.
+ """
+
+
+ def find_orgs(sentence, choice):
+     return all_metrics
+
+
+ radio_btn = gr.Radio(choices=['GPT', 'iSemantics'], value='iSemantics', label='Available models', show_label=True)
+ textbox = gr.Textbox(label="Enter your text", placeholder=str(all_metrics), lines=8)
+
+ iface = gr.Interface(fn=find_orgs, inputs=[textbox, radio_btn], outputs="text", examples=[[example, 'iSemantics']])
+ iface.launch(share=True)
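
Note that json.loads in find_orgs_gpt assumes the completion is a bare JSON list, but chat models sometimes wrap the list in extra prose. A minimal defensive parser sketch, purely illustrative (the helper name extract_org_list is hypothetical and not part of this commit):

import json
import re

def extract_org_list(completion):
    # Pull the first JSON-style list out of a completion, tolerating surrounding prose.
    match = re.search(r'\[.*?\]', completion, flags=re.DOTALL)
    if match is None:
        return []
    try:
        return json.loads(match.group(0))
    except json.JSONDecodeError:
        return []

# extract_org_list('Sure! ["The Hill", "GOP"]') -> ['The Hill', 'GOP']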
data/merged_dataset/dataset_dict.json ADDED
@@ -0,0 +1 @@
+ {"splits": ["train", "validation", "test", "orig_train", "orig_validation", "orig_test"]}
data/merged_dataset/orig_test/cache-eeafde0b6770e328.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e4b20ac141827d2e067e67afe6bb6efe6fdabf3d227c33b0764aff545c15ee6c
+ size 953224
data/merged_dataset/orig_test/data-00000-of-00001.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fc88467edb4babd0a7fc480903eed43b359e3755b5eecc87780fb33864530237
+ size 437856
data/merged_dataset/orig_test/dataset_info.json ADDED
@@ -0,0 +1,34 @@
+ {
+   "citation": "",
+   "description": "",
+   "features": {
+     "id": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "tokens": {
+       "feature": {
+         "dtype": "string",
+         "_type": "Value"
+       },
+       "_type": "Sequence"
+     },
+     "ner_tags": {
+       "feature": {
+         "names": [
+           "O",
+           "B-PER",
+           "I-PER",
+           "B-LOC",
+           "I-LOC",
+           "B-ORG",
+           "I-ORG"
+         ],
+         "_type": "ClassLabel"
+       },
+       "_type": "Sequence"
+     }
+   },
+   "homepage": "",
+   "license": ""
+ }
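
The ner_tags feature is a Sequence of ClassLabel, so integer tags map positionally to the names above. A quick round-trip sketch with the datasets feature types:

from datasets import ClassLabel, Sequence

ner_tags = Sequence(ClassLabel(names=['O', 'B-PER', 'I-PER', 'B-LOC', 'I-LOC', 'B-ORG', 'I-ORG']))
print(ner_tags.feature.str2int('B-ORG'))  # 5
print(ner_tags.feature.int2str(5))        # 'B-ORG'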
data/merged_dataset/orig_test/state.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "_data_files": [
+     {
+       "filename": "data-00000-of-00001.arrow"
+     }
+   ],
+   "_fingerprint": "0b33bf3dd398a19a",
+   "_format_columns": null,
+   "_format_kwargs": {},
+   "_format_type": null,
+   "_output_all_columns": false,
+   "_split": null
+ }
data/merged_dataset/orig_train/cache-45d1543dc33c36be.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:83332abe5b5e7d05e3ae4376018429896530b916ab3ff74eb8ca7aef94497961
+ size 3009552
data/merged_dataset/orig_train/data-00000-of-00001.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ac320be565428a08ad7a3c43d03ad14810775cb0620b47659321228b17a22148
+ size 1371040
data/merged_dataset/orig_train/dataset_info.json ADDED
@@ -0,0 +1,34 @@
+ {
+   "citation": "",
+   "description": "",
+   "features": {
+     "id": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "tokens": {
+       "feature": {
+         "dtype": "string",
+         "_type": "Value"
+       },
+       "_type": "Sequence"
+     },
+     "ner_tags": {
+       "feature": {
+         "names": [
+           "O",
+           "B-PER",
+           "I-PER",
+           "B-LOC",
+           "I-LOC",
+           "B-ORG",
+           "I-ORG"
+         ],
+         "_type": "ClassLabel"
+       },
+       "_type": "Sequence"
+     }
+   },
+   "homepage": "",
+   "license": ""
+ }
data/merged_dataset/orig_train/state.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "_data_files": [
+     {
+       "filename": "data-00000-of-00001.arrow"
+     }
+   ],
+   "_fingerprint": "74ec65c2b682826d",
+   "_format_columns": null,
+   "_format_kwargs": {},
+   "_format_type": null,
+   "_output_all_columns": false,
+   "_split": null
+ }
data/merged_dataset/orig_validation/cache-afff9bbc07b5bee3.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:94d972da8072255a5df65632859af45a7ce025dd587dc066106ea8e7224b0a1f
+ size 387592
data/merged_dataset/orig_validation/data-00000-of-00001.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:758e0d9bfceefd51b3cef856c8e15786ce0493da10bdf231f27e067b6b66caec
+ size 174712
data/merged_dataset/orig_validation/dataset_info.json ADDED
@@ -0,0 +1,34 @@
+ {
+   "citation": "",
+   "description": "",
+   "features": {
+     "id": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "tokens": {
+       "feature": {
+         "dtype": "string",
+         "_type": "Value"
+       },
+       "_type": "Sequence"
+     },
+     "ner_tags": {
+       "feature": {
+         "names": [
+           "O",
+           "B-PER",
+           "I-PER",
+           "B-LOC",
+           "I-LOC",
+           "B-ORG",
+           "I-ORG"
+         ],
+         "_type": "ClassLabel"
+       },
+       "_type": "Sequence"
+     }
+   },
+   "homepage": "",
+   "license": ""
+ }
data/merged_dataset/orig_validation/state.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "_data_files": [
+     {
+       "filename": "data-00000-of-00001.arrow"
+     }
+   ],
+   "_fingerprint": "2b90f959ed79ba44",
+   "_format_columns": null,
+   "_format_kwargs": {},
+   "_format_type": null,
+   "_output_all_columns": false,
+   "_split": null
+ }
data/merged_dataset/test/cache-3a6709085dd0f520.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:28be4e971c5d73f16b208f4b15d5965cfc81fb8936ce1c711fedc6fff5b3479a
+ size 953224
data/merged_dataset/test/cache-50fbc051d6b536f8.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:141c98d5d03e91f39e198897f147e0c2c6fa2c7a4c55174993392ec512599b34
+ size 953224
data/merged_dataset/test/cache-7344e423192cdf30.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:855fbd5d5477353be5930ee9ed4435238d847ef0971abe8106056e8d93639cd8
+ size 953240
data/merged_dataset/test/cache-861a0fd50d74bfe1.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:004f623f011a702e9eff454113818978e90f497b8ad806a8f86fa011868a0831
+ size 12304024
data/merged_dataset/test/data-00000-of-00001.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bc107fe614a3ee59fd5b302dc0a56896e63f2a3106fd88b5c52d4fd88b77a0fe
+ size 437856
data/merged_dataset/test/dataset_info.json ADDED
@@ -0,0 +1,34 @@
+ {
+   "citation": "",
+   "description": "",
+   "features": {
+     "id": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "tokens": {
+       "feature": {
+         "dtype": "string",
+         "_type": "Value"
+       },
+       "_type": "Sequence"
+     },
+     "ner_tags": {
+       "feature": {
+         "names": [
+           "O",
+           "B-PER",
+           "I-PER",
+           "B-LOC",
+           "I-LOC",
+           "B-ORG",
+           "I-ORG"
+         ],
+         "_type": "ClassLabel"
+       },
+       "_type": "Sequence"
+     }
+   },
+   "homepage": "",
+   "license": ""
+ }
data/merged_dataset/test/state.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "_data_files": [
+     {
+       "filename": "data-00000-of-00001.arrow"
+     }
+   ],
+   "_fingerprint": "538471187ad5b763",
+   "_format_columns": null,
+   "_format_kwargs": {},
+   "_format_type": null,
+   "_output_all_columns": false,
+   "_split": null
+ }
data/merged_dataset/train/cache-f8f6a910898e33f3.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1060a3d1000579dc92f65efb200632efdafa80f5d750f0c2298d82193e648f3e
+ size 3009552
data/merged_dataset/train/data-00000-of-00001.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1546be0dd9960d920988ce2bb6883fc567db03c2c80d0f8678d4bf95001a1a5f
+ size 1371040
data/merged_dataset/train/dataset_info.json ADDED
@@ -0,0 +1,34 @@
+ {
+   "citation": "",
+   "description": "",
+   "features": {
+     "id": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "tokens": {
+       "feature": {
+         "dtype": "string",
+         "_type": "Value"
+       },
+       "_type": "Sequence"
+     },
+     "ner_tags": {
+       "feature": {
+         "names": [
+           "O",
+           "B-PER",
+           "I-PER",
+           "B-LOC",
+           "I-LOC",
+           "B-ORG",
+           "I-ORG"
+         ],
+         "_type": "ClassLabel"
+       },
+       "_type": "Sequence"
+     }
+   },
+   "homepage": "",
+   "license": ""
+ }
data/merged_dataset/train/state.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "_data_files": [
+     {
+       "filename": "data-00000-of-00001.arrow"
+     }
+   ],
+   "_fingerprint": "13b20c4adf67dcf4",
+   "_format_columns": null,
+   "_format_kwargs": {},
+   "_format_type": null,
+   "_output_all_columns": false,
+   "_split": null
+ }
data/merged_dataset/validation/cache-a70cdc1f600f2440.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a0f9fa1d8fcd428b779d15f9f386d8b463c9f542dadb0056a16e3eb6b817cb5a
+ size 387592
data/merged_dataset/validation/cache-c442280565074102.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ca8b2d2e4610afe0cf6daf0ca37d02414e1df7e6a486c80e0ea2b25bf7808807
+ size 387592
data/merged_dataset/validation/data-00000-of-00001.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0fb30ccc26dc3d3172ffd54077d11217436ed169738251bf51fdb82908497868
+ size 174712
data/merged_dataset/validation/dataset_info.json ADDED
@@ -0,0 +1,34 @@
+ {
+   "citation": "",
+   "description": "",
+   "features": {
+     "id": {
+       "dtype": "string",
+       "_type": "Value"
+     },
+     "tokens": {
+       "feature": {
+         "dtype": "string",
+         "_type": "Value"
+       },
+       "_type": "Sequence"
+     },
+     "ner_tags": {
+       "feature": {
+         "names": [
+           "O",
+           "B-PER",
+           "I-PER",
+           "B-LOC",
+           "I-LOC",
+           "B-ORG",
+           "I-ORG"
+         ],
+         "_type": "ClassLabel"
+       },
+       "_type": "Sequence"
+     }
+   },
+   "homepage": "",
+   "license": ""
+ }
data/merged_dataset/validation/state.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "_data_files": [
+     {
+       "filename": "data-00000-of-00001.arrow"
+     }
+   ],
+   "_fingerprint": "f95fe8e7a800be97",
+   "_format_columns": null,
+   "_format_kwargs": {},
+   "_format_type": null,
+   "_output_all_columns": false,
+   "_split": null
+ }
data/ner_feature.pickle ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8ecfd61f261845f22d0b83a72263f7326514d78d71d3c52534ede75671dacc70
+ size 286
data/sample_data.json ADDED
The diff for this file is too large to render. See raw diff
 
evaluate_model.py ADDED
@@ -0,0 +1,62 @@
+ import evaluate
+ import numpy as np
+ import pickle
+
+ metric = evaluate.load("seqeval")
+ with open('./data/ner_feature.pickle', 'rb') as f:
+     ner_feature = pickle.load(f)
+
+ label_names = ner_feature.feature.names
+ # label2id = {label: ner_feature.feature.str2int(label) for label in label_names}
+ # id2label = {v: k for k, v in label2id.items()}
+
+
+ def compute_metrics(eval_preds):
+     """
+     compute_metrics() first takes the argmax of the logits to convert them to predictions
+     (the logits and the probabilities are in the same order, so we don't need to apply the softmax).
+     It then converts both labels and predictions from integers to strings,
+     removes all the values where the label is -100, and passes the results to metric.compute():
+     """
+     logits, labels = eval_preds
+     predictions = np.argmax(logits, axis=-1)
+
+     # Remove ignored index (special tokens) and convert to labels
+     true_labels = [[label_names[l] for l in label if l != -100] for label in labels]
+     true_predictions = [
+         [label_names[p] for (p, l) in zip(prediction, label) if l != -100]
+         for prediction, label in zip(predictions, labels)
+     ]
+     all_metrics = metric.compute(predictions=true_predictions, references=true_labels)
+
+     return {
+         # organization metrics
+         'org_precision': all_metrics['ORG']['precision'],
+         'org_recall': all_metrics['ORG']['recall'],
+         'org_f1': all_metrics['ORG']['f1'],
+
+         # person metrics
+         'per_precision': all_metrics['PER']['precision'],
+         'per_recall': all_metrics['PER']['recall'],
+         'per_f1': all_metrics['PER']['f1'],
+
+         # location metrics
+         'loc_precision': all_metrics['LOC']['precision'],
+         'loc_recall': all_metrics['LOC']['recall'],
+         'loc_f1': all_metrics['LOC']['f1'],
+
+         # overall metrics
+         'precision': all_metrics['overall_precision'],
+         'recall': all_metrics['overall_recall'],
+         'f1': all_metrics['overall_f1'],
+         'accuracy': all_metrics['overall_accuracy']
+     }
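
compute_metrics follows the (logits, labels) contract of the transformers Trainer's compute_metrics hook. Its core can be sanity-checked without a model; a toy sketch that mirrors the function body, using a local label list instead of the pickled feature:

import evaluate
import numpy as np

metric = evaluate.load('seqeval')
label_names = ['O', 'B-PER', 'I-PER', 'B-LOC', 'I-LOC', 'B-ORG', 'I-ORG']

logits = np.full((1, 3, 7), 0.1)   # one sentence, three tokens
logits[0, 0, 5] = 9.0              # predict B-ORG for the first token
labels = np.array([[5, -100, 0]])  # middle token is a special token

predictions = np.argmax(logits, axis=-1)
true_labels = [[label_names[l] for l in label if l != -100] for label in labels]
true_preds = [
    [label_names[p] for p, l in zip(pred, label) if l != -100]
    for pred, label in zip(predictions, labels)
]
print(metric.compute(predictions=true_preds, references=true_labels)['overall_f1'])  # 1.0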
metrics.py ADDED
@@ -0,0 +1,78 @@
+ def calc_recall(true_pos, false_neg, eps=1e-8):
+     return true_pos / (true_pos + false_neg + eps)
+
+
+ def calc_precision(true_pos, false_pos, eps=1e-8):
+     return true_pos / (true_pos + false_pos + eps)
+
+
+ def calc_f1_score(precision, recall, eps=1e-8):
+     return (2 * precision * recall) / (precision + recall + eps)
+
+
+ def calc_metrics(true, predicted, model, threshold=0.95, eps=1e-8):
+     true_pos = 0
+     false_pos = 0
+     false_neg = 0
+
+     false_pos_ids = []
+     false_neg_ids = []
+
+     for j, (true_ents, pred_ents) in enumerate(zip(true, predicted)):
+         # no true entities: every prediction is a false positive
+         if len(true_ents) == 0:
+             false_pos += len(pred_ents)
+             if len(pred_ents) > 0:
+                 false_pos_ids.append(j)
+             continue
+
+         # no predictions: every true entity is a false negative
+         if len(pred_ents) == 0:
+             false_neg += len(true_ents)
+             false_neg_ids.append(j)
+             continue
+
+         similarities = model.similarity(true_ents, pred_ents, device='cuda')
+
+         # a true entity counts as found if any prediction is similar enough
+         for row in similarities:
+             if (row >= threshold).any():
+                 true_pos += 1
+             else:
+                 false_neg += 1
+                 false_neg_ids.append(j)
+
+         # a prediction that matches no true entity is a false positive
+         for col in similarities.T:
+             if not (col >= threshold).any():
+                 false_pos += 1
+                 false_pos_ids.append(j)
+
+     recall = calc_recall(true_pos, false_neg)
+     precision = calc_precision(true_pos, false_pos)
+     f1_score = calc_f1_score(precision, recall, eps=eps)
+
+     return {
+         # 'true_pos': true_pos,
+         # 'false_pos': false_pos,
+         # 'false_neg': false_neg,
+         'recall': recall,
+         'precision': precision,
+         'f1': f1_score,
+         # 'false_pos_ids': list(set(false_pos_ids)),
+         # 'false_neg_ids': list(set(false_neg_ids))
+     }
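
calc_metrics only needs model to expose similarity(true_ents, pred_ents, device=...) returning a 2-D array of pairwise scores, which is the interface of the SimCSE wrapper referenced in the commented-out app.py code. A toy sketch with a stub that scores exact string matches (the class name ExactMatchModel is hypothetical):

import numpy as np
from metrics import calc_metrics

class ExactMatchModel:
    # Stub with the expected interface: 1.0 for identical strings, else 0.0.
    def similarity(self, queries, keys, device=None):
        return np.array([[float(q == k) for k in keys] for q in queries])

true = [['The Hill', 'GOP'], []]
pred = [['The Hill'], ['Acme Corp']]
print(calc_metrics(true, pred, ExactMatchModel()))
# one match, one missed entity, one spurious prediction -> precision, recall, f1 all ~0.5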
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ openai
+ transformers[torch]
+ tqdm==4.66.1
+ datasets==2.18.0
+ evaluate
+ seqeval
+ rich
utils.py ADDED
@@ -0,0 +1,83 @@
+ def find_broken_examples(data):
+     splits = list(data.keys())
+     broken = []
+
+     for s in splits:
+         for i, tokens in enumerate(data[s]['tokens']):
+             for token in tokens:
+                 if not token.isprintable():
+                     broken.append(s + '-' + str(i))
+
+     return broken
+
+
+ def update_data(examples, split, broken_ids):
+     new_tags = []
+     new_tokens = []
+     for id_ in examples['id']:
+         sent_id = split + '-' + id_
+         if sent_id in broken_ids:
+             continue
+
+         new_tokens.append(examples['tokens'][int(id_)])
+         new_tags.append(examples['ner_tags'][int(id_)])
+
+     assert len(new_tokens) == len(new_tags)
+     assert len(new_tokens[-1]) == len(new_tags[-1])
+
+     return {
+         'id': [str(i) for i in range(len(new_tokens))],
+         'tokens': new_tokens,
+         'ner_tags': new_tags
+     }
+
+
+ def align_labels_with_tokens(labels, word_ids):
+     new_labels = []
+     current_word = None
+     for word_id in word_ids:
+         if word_id != current_word:
+             # Start of a new word!
+             current_word = word_id
+             label = -100 if word_id is None else labels[word_id]
+             new_labels.append(label)
+         elif word_id is None:
+             # Special token
+             new_labels.append(-100)
+         else:
+             # Same word as the previous token: only the first sub-token keeps a label
+             new_labels.append(-100)
+
+     return new_labels
+
+
+ def tokenize_and_align_labels(examples, tokenizer):
+     tokenized_inputs = tokenizer(
+         examples["tokens"], truncation=True, is_split_into_words=True, padding='max_length'
+     )
+     all_labels = examples["ner_tags"]
+     new_labels = []
+     word_ids = []
+     for i, labels in enumerate(all_labels):
+         word_ids.append(tokenized_inputs.word_ids(i))
+         new_labels.append(align_labels_with_tokens(labels, word_ids[i]))
+
+     tokenized_inputs["labels"] = new_labels
+     tokenized_inputs['word_ids'] = word_ids
+
+     return tokenized_inputs
+
+
+ # def model_init(checkpoint, id2label, label2id):
+ #     model = AutoModelForTokenClassification.from_pretrained(
+ #         checkpoint,
+ #         id2label=id2label,
+ #         label2id=label2id
+ #     )
+ #     return model
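
word_ids here follows the convention of transformers fast tokenizers: None for special tokens and a repeated word index for sub-tokens of the same word, so only the first sub-token of each word keeps its label. A toy call, independent of any tokenizer:

from utils import align_labels_with_tokens

# '[CLS] Ap ##ple Inc [SEP]' over the words ['Apple', 'Inc'] tagged B-ORG, I-ORG
word_ids = [None, 0, 0, 1, None]
labels = [5, 6]  # B-ORG, I-ORG
print(align_labels_with_tokens(labels, word_ids))
# -> [-100, 5, -100, 6, -100]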