LasRuinasCirculares committed
Commit 861bb01
1 Parent(s): 2a5821d
Upload 7 files

- .gitattributes +1 -0
- knowledge_conflict_entity_based/.DS_Store +0 -0
- knowledge_conflict_entity_based/entity_substitute.py +126 -0
- knowledge_conflict_entity_based/requirements.txt +5 -0
- knowledge_conflict_entity_based/result/.DS_Store +0 -0
- knowledge_conflict_entity_based/result/entity_info.json +3 -0
- knowledge_conflict_entity_based/run.sh +2 -0
- knowledge_conflict_entity_based/setup.sh +6 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+knowledge_conflict_entity_based/result/entity_info.json filter=lfs diff=lfs merge=lfs -text
knowledge_conflict_entity_based/.DS_Store
ADDED
Binary file (6.15 kB)
knowledge_conflict_entity_based/entity_substitute.py
ADDED
@@ -0,0 +1,126 @@
+import spacy
+import zstandard as zstd
+import json
+import typing
+import os
+from tqdm import tqdm
+import multiprocessing
+import random
+from langdetect import detect
+import argparse
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--input_dir', type=str, help='Path to the input file')
+args = parser.parse_args()
+input_dir = args.input_dir
+
+
+def is_english(text):
+    try:
+        lang = detect(text)
+        return lang == 'en'
+    except:
+        return False
+
+# Tag one batch of texts, keep the most frequent entity of each text, and
+# append one JSON record per text to the shared output file under the lock.
+def process_text(texts, model, out_f, lock):
+    for text in texts:
+        doc = model(text)
+        freq_cnt = {}
+        for e in doc.ents:
+            if e not in freq_cnt:
+                freq_cnt[e] = 0
+            freq_cnt[e] += 1
+        if len(freq_cnt) == 0:
+            continue
+        sorted_freq = sorted(freq_cnt.items(), key=lambda x: x[1])
+        most_freq = sorted_freq[-1][0]
+        data = {'text': text, 'main_entity': most_freq.text, 'label': most_freq.label_, 'id': most_freq.kb_id_}
+        json_data = json.dumps(data)
+        with lock:
+            out_f.write(json_data + '\n')
+            out_f.flush()
+
+# Run NER / entity linking over all texts, one worker process per 1000-text
+# chunk, writing records to result/temp_store_data.json.
+def run_ner_linking(texts: typing.List[str], ner_model_path: str):
+    nlp = spacy.load(ner_model_path)
+    out_f = open('result/temp_store_data.json', 'w', encoding='utf-8')
+    lock = multiprocessing.Lock()
+    processes = []
+
+    for i in tqdm(range(0, len(texts), 1000)):
+        p = multiprocessing.Process(target=process_text, args=(texts[i:i+1000], nlp, out_f, lock))
+        processes.append(p)
+        p.start()
+
+    for p in processes:
+        p.join()
+
+    out_f.close()
+    return
+
+# Stage 1: collect English Wikipedia documents from the RedPajama .zst chunks under --input_dir.
+wikipedia_out_path = 'result/wikipedia.json'
+subdirectories = [f.path for f in os.scandir(input_dir) if f.is_dir()]
+wikipedia_data = []
+for sub_dir in subdirectories:
+    chunk_dir = sub_dir + '/'
+    zst_files = [f for f in os.listdir(chunk_dir) if f.endswith('.zst')]
+    for file in tqdm(zst_files):
+        with open(chunk_dir + file, 'rb') as compressed_file:
+            decompressor = zstd.ZstdDecompressor()
+            with decompressor.stream_reader(compressed_file) as reader:
+                decompressed_data = reader.read()
+                for line in decompressed_data.splitlines():
+                    data = json.loads(line)
+                    # print(data)
+                    if data['meta']['redpajama_set_name'] == 'RedPajamaWikipedia':
+                        if is_english(data['text']):
+                            wikipedia_data.append(data)
+
+with open(wikipedia_out_path, 'w', encoding='utf-8') as f:
+    for data in wikipedia_data:
+        json_data = json.dumps(data)
+        f.write(json_data + '\n')
+
+# Stage 2: tag the collected texts with the downloaded NER / entity-linking model.
+wikipedia_data = []
+ner_model_path = 'kc-ner-model'
+with open(wikipedia_out_path, 'r', encoding='utf-8') as f:
+    for line in tqdm(f):
+        data = json.loads(line)
+        wikipedia_data.append(data['text'])
+run_ner_linking(wikipedia_data, ner_model_path)
+
+# Stage 3: group main entities by NER label and substitute a random entity of
+# the same type (avoiding the original entity and its aliases) into each text.
+entity_info_path = 'result/entity_info.json'
+with open(entity_info_path, 'r', encoding='utf-8') as f:
+    entity_info = json.load(f)
+all_original_data = []
+
+category = {}
+all_data = []
+with open('result/temp_store_data.json', 'r', encoding='utf-8') as f:
+    for line in f:
+        data = json.loads(line)
+        all_data.append(data)
+        if data['label'] not in category:
+            category[data['label']] = []
+        category[data['label']].append(data['main_entity'])
+
+with open('result/processed_data.json', 'w', encoding='utf-8') as f:
+    for data in tqdm(all_data):
+        text = data['text']
+        main_entity = [data['main_entity']]
+        if data['id'] in entity_info:
+            main_entity.extend(entity_info[data['id']]['aliases'])
+        if len(category[data['label']]) == 1:
+            continue
+        replaced_entity = random.sample(category[data['label']], 1)
+        while replaced_entity[0] in main_entity:
+            replaced_entity = random.sample(category[data['label']], 1)
+        for entity in main_entity:
+            text = text.replace(entity, replaced_entity[0])
+        data = {
+            'text': text,
+            'original_main_entity': main_entity,
+            'replaced_entity': replaced_entity[0]
+        }
+        json_data = json.dumps(data)
+        f.write(json_data + '\n')
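For orientation only (not part of the commit): a minimal sketch of reading back the records that entity_substitute.py writes to result/processed_data.json. The file path and field names come from the script above; nothing else is assumed.

import json

# Each line of result/processed_data.json holds one substituted passage plus
# the original entity (and its aliases) and the entity swapped in.
with open('result/processed_data.json', 'r', encoding='utf-8') as f:
    for line in f:
        record = json.loads(line)
        print(record['original_main_entity'][0], '->', record['replaced_entity'])
        # record['text'] is the passage with the substitution applied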
knowledge_conflict_entity_based/requirements.txt
ADDED
@@ -0,0 +1,5 @@
+spacy==2.2.4
+langdetect
+zstandard
+tqdm
+wget
knowledge_conflict_entity_based/result/.DS_Store
ADDED
Binary file (6.15 kB)
knowledge_conflict_entity_based/result/entity_info.json
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:423a217fa602456b961b6169b0bac15659ec85c90de2b261ca924c0ebe7d04a4
+size 742977816
knowledge_conflict_entity_based/run.sh
ADDED
@@ -0,0 +1,2 @@
+### The processed data will be written to result/processed_data.json
+python entity_substitute.py --input_dir /opt/data/private/szc/ml-knowledge-conflicts-main/test
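For reference, a small sketch (not part of the commit) of what the script scans under --input_dir, inferred from the loading loop in entity_substitute.py: immediate subdirectories, each holding .zst-compressed files of JSON lines with a meta.redpajama_set_name field. The path below just reuses the placeholder from run.sh.

import os

input_dir = '/opt/data/private/szc/ml-knowledge-conflicts-main/test'  # placeholder path from run.sh
# List the chunk files entity_substitute.py would pick up: one level of
# subdirectories, *.zst files inside each.
for sub_dir in (entry.path for entry in os.scandir(input_dir) if entry.is_dir()):
    for name in os.listdir(sub_dir):
        if name.endswith('.zst'):
            print(os.path.join(sub_dir, name))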
knowledge_conflict_entity_based/setup.sh
ADDED
@@ -0,0 +1,6 @@
+pip install -r requirements.txt
+
+# Download the SpaCy Named Entity Recognizer (NER) and Entity Linker (EL) model
+# See https://spacy.io/usage/linguistic-features#named-entities and https://v2.spacy.io/usage/training#entity-linker
+wget https://docs-assets.developer.apple.com/ml-research/models/kc-ner/model.gz -O kc-ner-model.gz
+mkdir -p kc-ner-model && tar -xvzf kc-ner-model.gz -C kc-ner-model
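An optional sanity check (not part of the commit) that the extracted model directory loads the way entity_substitute.py expects; it assumes setup.sh was run inside knowledge_conflict_entity_based/ so that kc-ner-model sits next to the script.

import spacy

# entity_substitute.py calls spacy.load('kc-ner-model'); verify that this works
# and that the model produces linked entities (kb_id_) on a small example.
nlp = spacy.load('kc-ner-model')
doc = nlp('Barack Obama was born in Hawaii.')
for ent in doc.ents:
    print(ent.text, ent.label_, ent.kb_id_)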