|
import re
from collections import Counter

import pandas as pd
import torch
from nltk.corpus import stopwords
from sklearn.datasets import fetch_20newsgroups
from torch import nn
|
|
|
|
|
|
|
def cos_dist(x, y):
    """Mean cosine distance between two batches, each flattened to 2-D."""
    cos = nn.CosineSimilarity(dim=1, eps=1e-6)
    batch_size = x.size(0)
    # Cosine similarity lies in [-1, 1], so 1 - cos lies in [0, 2]; the clamp
    # only guards against small negative values from floating-point error.
    c = torch.clamp(1 - cos(x.view(batch_size, -1), y.view(batch_size, -1)),
                    min=0)
    return c.mean()
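# A minimal sanity check for cos_dist (illustrative, not part of the original
# file): identical batches give a distance of ~0, and the result is always a
# scalar in [0, 2].
#
#   x = torch.randn(8, 16)
#   assert cos_dist(x, x).item() < 1e-5
#   assert 0.0 <= cos_dist(x, torch.randn(8, 16)).item() <= 2.0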
|
|
|
|
|
|
|
|
|
def tag_mapping(tags):
    """
    Create a dictionary and a mapping of tags, sorted by frequency.
    """
    dico = Counter(tags)
    tag_to_id, id_to_tag = create_mapping(dico)
    print("Found %i unique named entity tags" % len(dico))
    return dico, tag_to_id, id_to_tag
|
|
|
|
|
def create_mapping(dico):
    """
    Create a mapping (item to ID / ID to item) from a dictionary.
    Items are ordered by decreasing frequency, then alphabetically on ties.
    """
    sorted_items = sorted(dico.items(), key=lambda x: (-x[1], x[0]))
    id_to_item = {i: v[0] for i, v in enumerate(sorted_items)}
    item_to_id = {v: k for k, v in id_to_item.items()}
    return item_to_id, id_to_item
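# Illustrative example (not in the original file): three tags, ordered by
# frequency first and alphabetically on ties.
#
#   dico, tag_to_id, id_to_tag = tag_mapping(['B', 'O', 'O', 'I', 'O'])
#   # dico      == Counter({'O': 3, 'B': 1, 'I': 1})
#   # tag_to_id == {'O': 0, 'B': 1, 'I': 2}
#   # id_to_tag == {0: 'O', 1: 'B', 2: 'I'}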
|
|
|
|
|
|
|
|
|
def clean_str(string):
    """
    Tokenization/string cleaning for all datasets except for SST.
    Original taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py
    """
    string = re.sub(r"[^A-Za-z0-9(),!?'`]", " ", string)
    string = re.sub(r"'s", " 's", string)
    string = re.sub(r"'ve", " 've", string)
    string = re.sub(r"n't", " n't", string)
    string = re.sub(r"'re", " 're", string)
    string = re.sub(r"'d", " 'd", string)
    string = re.sub(r"'ll", " 'll", string)
    string = re.sub(r",", " , ", string)
    string = re.sub(r"!", " ! ", string)
    # Plain replacement strings here: templates such as " \( " would insert a
    # literal backslash into the output text.
    string = re.sub(r"\(", " ( ", string)
    string = re.sub(r"\)", " ) ", string)
    string = re.sub(r"\?", " ? ", string)
    string = re.sub(r"\s{2,}", " ", string)
    return string.strip().lower()
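# Illustrative behaviour (not in the original file):
#
#   clean_str('He said: "It\'s (really) great!"')
#   # -> "he said it 's ( really ) great !"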
|
|
|
|
|
def clean_doc(x, word_freq):
    """Remove stop words and rare words; map out-of-vocabulary words to <UNK>."""
    stop_words = set(stopwords.words('english'))
    clean_docs = []
    # At most the 50k most frequent words form the in-vocabulary set.
    most_commons = dict(word_freq.most_common(min(len(word_freq), 50000)))
    for doc_content in x:
        doc_words = []
        cleaned = clean_str(doc_content.strip())
        for word in cleaned.split():
            # Drop stop words and words seen fewer than 5 times overall.
            if word not in stop_words and word_freq[word] >= 5:
                if word in most_commons:
                    doc_words.append(word)
                else:
                    doc_words.append("<UNK>")
        doc_str = ' '.join(doc_words).strip()
        clean_docs.append(doc_str)
    return clean_docs
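# Typical usage (illustrative; `docs` is a hypothetical list of raw document
# strings, and nltk's stopword list must already be available via
# nltk.download('stopwords')):
#
#   word_freq = Counter(w for d in docs for w in clean_str(d).split())
#   cleaned = clean_doc(docs, word_freq)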
|
|
|
|
|
|
|
def load_dataset(dataset):
    # Fail fast on an unrecognized name instead of hitting a NameError at the
    # return statement below.
    supported = {'sst', '20news', '20news-15', '20news-5', 'wos', 'wos-100', 'wos-34', 'agnews'}
    if dataset not in supported:
        raise ValueError("Unknown dataset: %r" % dataset)

    if dataset == 'sst':
        df_train = pd.read_csv("./dataset/sst/SST-2/train.tsv", delimiter='\t', header=0)
        df_val = pd.read_csv("./dataset/sst/SST-2/dev.tsv", delimiter='\t', header=0)
        df_test = pd.read_csv("./dataset/sst/SST-2/sst-test.tsv", delimiter='\t', header=None,
                              names=['sentence', 'label'])

        train_sentences = df_train.sentence.values
        val_sentences = df_val.sentence.values
        test_sentences = df_test.sentence.values
        train_labels = df_train.label.values
        val_labels = df_val.label.values
        test_labels = df_test.label.values
|
|
|
|
|
    if dataset == '20news':
        VALIDATION_SPLIT = 0.8
        newsgroups_train = fetch_20newsgroups('dataset/20news', subset='train', shuffle=True,
                                              random_state=0)
        print(newsgroups_train.target_names)
        print(len(newsgroups_train.data))

        newsgroups_test = fetch_20newsgroups('dataset/20news', subset='test', shuffle=False)
        print(len(newsgroups_test.data))

        # Hold out the last 20% of the (shuffled) training set for validation.
        train_len = int(VALIDATION_SPLIT * len(newsgroups_train.data))

        train_sentences = newsgroups_train.data[:train_len]
        val_sentences = newsgroups_train.data[train_len:]
        test_sentences = newsgroups_test.data
        train_labels = newsgroups_train.target[:train_len]
        val_labels = newsgroups_train.target[train_len:]
        test_labels = newsgroups_test.target
|
|
|
|
|
|
|
    if dataset == '20news-15':
        VALIDATION_SPLIT = 0.8
        # The 15 in-domain categories; the remaining 5 are held out (see '20news-5').
        cats = ['alt.atheism',
                'comp.graphics',
                'comp.os.ms-windows.misc',
                'comp.sys.ibm.pc.hardware',
                'comp.sys.mac.hardware',
                'comp.windows.x',
                'rec.autos',
                'rec.motorcycles',
                'rec.sport.baseball',
                'rec.sport.hockey',
                'misc.forsale',
                'sci.crypt',
                'sci.electronics',
                'sci.med',
                'sci.space']
        newsgroups_train = fetch_20newsgroups('dataset/20news', subset='train', shuffle=True,
                                              categories=cats, random_state=0)
        print(newsgroups_train.target_names)
        print(len(newsgroups_train.data))

        newsgroups_test = fetch_20newsgroups('dataset/20news', subset='test', shuffle=False,
                                             categories=cats)
        print(len(newsgroups_test.data))

        train_len = int(VALIDATION_SPLIT * len(newsgroups_train.data))

        train_sentences = newsgroups_train.data[:train_len]
        val_sentences = newsgroups_train.data[train_len:]
        test_sentences = newsgroups_test.data
        train_labels = newsgroups_train.target[:train_len]
        val_labels = newsgroups_train.target[train_len:]
        test_labels = newsgroups_test.target
|
|
|
|
|
    if dataset == '20news-5':
        # The 5 categories held out from '20news-15'; only a test split is defined.
        cats = ['soc.religion.christian',
                'talk.politics.guns',
                'talk.politics.mideast',
                'talk.politics.misc',
                'talk.religion.misc']

        newsgroups_test = fetch_20newsgroups('dataset/20news', subset='test', shuffle=False,
                                             categories=cats)
        print(newsgroups_test.target_names)
        print(len(newsgroups_test.data))

        train_sentences = None
        val_sentences = None
        test_sentences = newsgroups_test.data
        train_labels = None
        val_labels = None
        test_labels = newsgroups_test.target
|
|
|
    if dataset == 'wos':
        TESTING_SPLIT = 0.6
        VALIDATION_SPLIT = 0.8
        with open('./dataset/WebOfScience/WOS46985/X.txt', 'r') as read_file:
            x_all = [str(x) for x in read_file.readlines()]
        print(len(x_all))

        with open('./dataset/WebOfScience/WOS46985/Y.txt', 'r') as read_file:
            y_all = [int(y) for y in read_file.readlines()]
        print(len(y_all))
        print(max(y_all), min(y_all))

        # All classes are kept for this split.
        x_in = x_all
        y_in = y_all

        # 60% of the corpus is train+val; of that, 80% train and 20% validation.
        train_val_len = int(TESTING_SPLIT * len(x_in))
        train_len = int(VALIDATION_SPLIT * train_val_len)

        train_sentences = x_in[:train_len]
        val_sentences = x_in[train_len:train_val_len]
        test_sentences = x_in[train_val_len:]

        train_labels = y_in[:train_len]
        val_labels = y_in[train_len:train_val_len]
        test_labels = y_in[train_val_len:]

        print(len(train_labels))
        print(len(val_labels))
        print(len(test_labels))
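        # Worked example of the split arithmetic above (illustrative, assuming
        # the standard 46,985-document WOS46985 corpus): int(0.6 * 46985) =
        # 28191 train+val documents, of which int(0.8 * 28191) = 22552 are
        # train and 5639 validation, leaving 18794 documents for test.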
|
|
|
|
|
    if dataset == 'wos-100':
        TESTING_SPLIT = 0.6
        VALIDATION_SPLIT = 0.8
        with open('./dataset/WebOfScience/WOS46985/X.txt', 'r') as read_file:
            x_all = [str(x) for x in read_file.readlines()]
        print(len(x_all))

        with open('./dataset/WebOfScience/WOS46985/Y.txt', 'r') as read_file:
            y_all = [int(y) for y in read_file.readlines()]
        print(len(y_all))
        print(max(y_all), min(y_all))

        # Keep only documents whose label is one of the first 100 classes.
        x_in = []
        y_in = []
        for i in range(len(x_all)):
            if y_all[i] in range(100):
                x_in.append(x_all[i])
                y_in.append(y_all[i])

        train_val_len = int(TESTING_SPLIT * len(x_in))
        train_len = int(VALIDATION_SPLIT * train_val_len)

        train_sentences = x_in[:train_len]
        val_sentences = x_in[train_len:train_val_len]
        test_sentences = x_in[train_val_len:]

        train_labels = y_in[:train_len]
        val_labels = y_in[train_len:train_val_len]
        test_labels = y_in[train_val_len:]

        print(len(train_labels))
        print(len(val_labels))
        print(len(test_labels))
|
|
|
    if dataset == 'wos-34':
        TESTING_SPLIT = 0.6
        with open('./dataset/WebOfScience/WOS46985/X.txt', 'r') as read_file:
            x_all = [str(x) for x in read_file.readlines()]
        print(len(x_all))

        with open('./dataset/WebOfScience/WOS46985/Y.txt', 'r') as read_file:
            y_all = [int(y) for y in read_file.readlines()]
        print(len(y_all))
        print(max(y_all), min(y_all))

        # Keep only documents from the classes held out of 'wos-100';
        # only a test split is defined for them.
        x_in = []
        y_in = []
        for i in range(len(x_all)):
            if y_all[i] not in range(100):
                x_in.append(x_all[i])
                y_in.append(y_all[i])

        train_val_len = int(TESTING_SPLIT * len(x_in))

        train_sentences = None
        val_sentences = None
        test_sentences = x_in[train_val_len:]

        train_labels = None
        val_labels = None
        test_labels = y_in[train_val_len:]

        print(len(test_labels))
|
|
|
    if dataset == 'agnews':
        VALIDATION_SPLIT = 0.8

        train_df = pd.read_csv('./dataset/agnews/train.csv', header=None)
        train_df.rename(columns={0: 'label', 1: 'title', 2: 'sentence'}, inplace=True)
        print(train_df.dtypes)

        # Labels in the CSV are 1-4; shift them to 0-based class IDs.
        train_in_df_sentence = [str(s) for s in train_df.sentence.values]
        train_in_df_label = [label - 1 for label in train_df.label.values]

        test_df = pd.read_csv('./dataset/agnews/test.csv', header=None)
        test_df.rename(columns={0: 'label', 1: 'title', 2: 'sentence'}, inplace=True)

        test_in_df_sentence = [str(s) for s in test_df.sentence.values]
        test_in_df_label = [label - 1 for label in test_df.label.values]

        train_len = int(VALIDATION_SPLIT * len(train_in_df_sentence))

        train_sentences = train_in_df_sentence[:train_len]
        val_sentences = train_in_df_sentence[train_len:]
        test_sentences = test_in_df_sentence
        train_labels = train_in_df_label[:train_len]
        val_labels = train_in_df_label[train_len:]
        test_labels = test_in_df_label
        print(len(train_sentences))
        print(len(val_sentences))
        print(len(test_sentences))
|
|
|
|
|
return train_sentences, val_sentences, test_sentences, train_labels, val_labels, test_labels |
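# Typical usage (illustrative; assumes the corresponding files exist under
# ./dataset/):
#
#   train_x, val_x, test_x, train_y, val_y, test_y = load_dataset('20news')
#   print(len(train_x), len(val_x), len(test_x))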
|
|
|
|