# astra/src/utils.py (uploaded by suryadev1; commit 6a34fd4 "v1"; 12.1 kB)
import torch
from torch import nn
from torch.nn import functional as F
import pandas as pd
from collections import Counter
import numpy as np
from sklearn.datasets import fetch_20newsgroups
from collections import Counter, defaultdict
from nltk.corpus import stopwords
from sklearn.model_selection import train_test_split
import re
from sklearn.utils import shuffle
def cos_dist(x, y):
    """Return the mean cosine distance between paired rows of two batches.

    Both tensors are reshaped with ``view`` to ``(x.size(0), -1)``. The
    row-wise cosine similarity is computed with ``eps=1e-6``. The distance
    ``1 - similarity`` is clamped below at 0 and averaged over the batch.

    Args:
        x: tensor whose first dimension is the batch size.
        y: tensor with the same number of elements per row as ``x``.

    Returns:
        0-dim tensor holding the mean distance.
    """
    n_rows = x.size(0)
    flat_x = x.view(n_rows, -1)
    flat_y = y.view(n_rows, -1)
    similarity = nn.CosineSimilarity(dim=1, eps=1e-6)(flat_x, flat_y)
    return (1 - similarity).clamp(min=0).mean()
def tag_mapping(tags):
    """Count tags and build frequency-ordered tag <-> id lookups.

    Args:
        tags: iterable of hashable tag values.

    Returns:
        ``(dico, tag_to_id, id_to_tag)`` where ``dico`` is a ``Counter`` of
        tag frequencies and the two dicts are produced by ``create_mapping``
        (most frequent tag gets id 0).
    """
    tag_counts = Counter(tags)
    tag_to_id, id_to_tag = create_mapping(tag_counts)
    print(f"Found {len(tag_counts)} unique named entity tags")
    return tag_counts, tag_to_id, id_to_tag
def create_mapping(dico):
    """Build bidirectional item <-> id lookups from an item -> count mapping.

    Items are ranked by decreasing count, with ties broken by ascending item
    value. Rank 0 is the most frequent item.

    Args:
        dico: mapping from item to its count.

    Returns:
        ``(item_to_id, id_to_item)``.
    """
    ranked = sorted(dico.items(), key=lambda pair: (-pair[1], pair[0]))
    id_to_item = {}
    item_to_id = {}
    for rank, (item, _count) in enumerate(ranked):
        id_to_item[rank] = item
        item_to_id[item] = rank
    return item_to_id, id_to_item
def clean_str(string):
    """Normalize and tokenize a raw text string.

    Adapted from
    https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py
    (originally used for all datasets except SST).

    Steps:
      1. Replace every character outside ``[A-Za-z0-9(),!?'`]`` with a space.
      2. Split English contractions ('s, 've, n't, 're, 'd, 'll) off their
         stem, e.g. ``"it's"`` -> ``"it 's"``.
      3. Surround each of ``, ! ( ) ?`` with spaces.
      4. Collapse whitespace runs, strip, and lowercase.

    The original applied ``re.sub`` with replacement templates such as
    ``" \\( "``. ``re`` keeps unknown non-letter escapes verbatim in a
    template, so a literal backslash leaked into the output (``"\\("``).
    The fix pads the bare punctuation character instead.

    Args:
        string: raw input text.

    Returns:
        The cleaned, space-separated, lowercased string.
    """
    string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
    # Literal substitutions: str.replace is equivalent to re.sub here and
    # avoids replacement-template escaping pitfalls.
    for contraction in ("'s", "'ve", "n't", "'re", "'d", "'ll"):
        string = string.replace(contraction, " " + contraction)
    for mark in ",!()?":
        string = string.replace(mark, " " + mark + " ")
    string = re.sub(r"\s{2,}", " ", string)
    return string.strip().lower()
def clean_doc(x, word_freq, min_freq=5, max_vocab_size=50000):
    """Clean documents and restrict them to a frequency-based vocabulary.

    Each document is stripped, passed through ``clean_str`` and split on
    whitespace. NLTK English stop words are dropped, and so are words whose
    count in ``word_freq`` is below ``min_freq``. Every remaining word is
    kept if it is among the ``max_vocab_size`` most common words of
    ``word_freq``; otherwise it is replaced by ``"<UNK>"``.

    Args:
        x: iterable of raw document strings.
        word_freq: ``collections.Counter`` of word counts. It must provide
            ``most_common`` and return 0 for unseen words.
        min_freq: minimum count for a word to be kept at all (default 5).
        max_vocab_size: number of most frequent words kept verbatim
            (default 50000); ``None`` keeps every word in ``word_freq``.

    Returns:
        list with one space-joined cleaned string per input document.
    """
    stop_words = set(stopwords.words('english'))
    # Only membership is needed, so a set of the top words suffices;
    # most_common(n) already returns everything when n >= len(word_freq).
    vocab = {word for word, _count in word_freq.most_common(max_vocab_size)}
    clean_docs = []
    for doc_content in x:
        doc_words = []
        for word in clean_str(doc_content.strip()).split():
            if word in stop_words or word_freq[word] < min_freq:
                continue
            doc_words.append(word if word in vocab else "<UNK>")
        clean_docs.append(' '.join(doc_words))
    return clean_docs
# Fraction of the 20 Newsgroups / AG News training data used for training;
# the remaining tail becomes the validation set.
_TRAIN_FRACTION = 0.8
# WOS: the first 60% of the (unshuffled) file is train+val and the rest is
# test; train+val is then split 80/20 into train and validation.
_WOS_TRAIN_VAL_FRACTION = 0.6
_WOS_TRAIN_FRACTION = 0.8
_WOS_X_PATH = './dataset/WebOfScience/WOS46985/X.txt'
_WOS_Y_PATH = './dataset/WebOfScience/WOS46985/Y.txt'
_NEWS20_HOME = 'dataset/20news'
_NEWS20_15_CATEGORIES = [
    'alt.atheism',
    'comp.graphics',
    'comp.os.ms-windows.misc',
    'comp.sys.ibm.pc.hardware',
    'comp.sys.mac.hardware',
    'comp.windows.x',
    'rec.autos',
    'rec.motorcycles',
    'rec.sport.baseball',
    'rec.sport.hockey',
    'misc.forsale',
    'sci.crypt',
    'sci.electronics',
    'sci.med',
    'sci.space',
]
_NEWS20_5_CATEGORIES = [
    'soc.religion.christian',
    'talk.politics.guns',
    'talk.politics.mideast',
    'talk.politics.misc',
    'talk.religion.misc',
]


def _load_sst():
    """Load the SST-2 train/dev/test TSV files (sentences and labels as arrays)."""
    df_train = pd.read_csv("./dataset/sst/SST-2/train.tsv", delimiter='\t', header=0)
    df_val = pd.read_csv("./dataset/sst/SST-2/dev.tsv", delimiter='\t', header=0)
    df_test = pd.read_csv("./dataset/sst/SST-2/sst-test.tsv", delimiter='\t',
                          header=None, names=['sentence', 'label'])
    return (df_train.sentence.values, df_val.sentence.values, df_test.sentence.values,
            df_train.label.values, df_val.label.values, df_test.label.values)


def _load_20news(categories=None):
    """Split the 20 Newsgroups train subset into train/val and add the test subset."""
    # data_home must be passed by keyword: recent scikit-learn releases make
    # every fetch_20newsgroups parameter keyword-only.
    newsgroups_train = fetch_20newsgroups(data_home=_NEWS20_HOME, subset='train', shuffle=True,
                                          categories=categories, random_state=0)
    print(newsgroups_train.target_names)
    print(len(newsgroups_train.data))
    newsgroups_test = fetch_20newsgroups(data_home=_NEWS20_HOME, subset='test', shuffle=False,
                                         categories=categories)
    print(len(newsgroups_test.data))
    train_len = int(_TRAIN_FRACTION * len(newsgroups_train.data))
    return (newsgroups_train.data[:train_len], newsgroups_train.data[train_len:],
            newsgroups_test.data,
            newsgroups_train.target[:train_len], newsgroups_train.target[train_len:],
            newsgroups_test.target)


def _load_20news_5():
    """Load only the test subset of five 20 Newsgroups categories (no train/val)."""
    newsgroups_test = fetch_20newsgroups(data_home=_NEWS20_HOME, subset='test', shuffle=False,
                                         categories=_NEWS20_5_CATEGORIES)
    print(newsgroups_test.target_names)
    print(len(newsgroups_test.data))
    return None, None, newsgroups_test.data, None, None, newsgroups_test.target


def _read_wos():
    """Read WOS46985 document lines (X.txt) and integer labels (Y.txt)."""
    with open(_WOS_X_PATH, 'r') as read_file:
        x_all = read_file.readlines()  # lines keep their trailing newline
    print(len(x_all))
    with open(_WOS_Y_PATH, 'r') as read_file:
        y_all = [int(line) for line in read_file]
    print(len(y_all))
    print(max(y_all), min(y_all))
    if len(x_all) != len(y_all):
        raise ValueError(
            f"WOS X.txt has {len(x_all)} lines but Y.txt has {len(y_all)}")
    return x_all, y_all


def _load_wos(keep_label, test_only=False):
    """Load WOS documents whose label satisfies ``keep_label`` and split them in file order."""
    x_all, y_all = _read_wos()
    pairs = [(x, y) for x, y in zip(x_all, y_all) if keep_label(y)]
    x_in = [x for x, _ in pairs]
    y_in = [y for _, y in pairs]
    train_val_len = int(_WOS_TRAIN_VAL_FRACTION * len(x_in))
    train_len = int(_WOS_TRAIN_FRACTION * train_val_len)
    test_sentences = x_in[train_val_len:]
    test_labels = y_in[train_val_len:]
    if test_only:
        print(len(test_labels))
        return None, None, test_sentences, None, None, test_labels
    train_labels = y_in[:train_len]
    val_labels = y_in[train_len:train_val_len]
    print(len(train_labels))
    print(len(val_labels))
    print(len(test_labels))
    return (x_in[:train_len], x_in[train_len:train_val_len], test_sentences,
            train_labels, val_labels, test_labels)


def _read_agnews_csv(path):
    """Read an AG News CSV, naming columns 0/1/2 as label/title/sentence."""
    df = pd.read_csv(path, header=None)
    df.rename(columns={0: 'label', 1: 'title', 2: 'sentence'}, inplace=True)
    return df


def _load_agnews():
    """Load AG News: train.csv split into train/val, plus test.csv."""
    train_df = _read_agnews_csv('./dataset/agnews/train.csv')
    print(train_df.dtypes)
    # Labels are shifted down by one (the CSV labels appear to be 1-based).
    all_train_sentences = [str(s) for s in train_df.sentence.values]
    all_train_labels = [label - 1 for label in train_df.label.values]
    test_df = _read_agnews_csv('./dataset/agnews/test.csv')
    test_sentences = [str(s) for s in test_df.sentence.values]
    test_labels = [label - 1 for label in test_df.label.values]
    train_len = int(_TRAIN_FRACTION * len(all_train_sentences))
    train_sentences = all_train_sentences[:train_len]
    val_sentences = all_train_sentences[train_len:]
    print(len(train_sentences))
    print(len(val_sentences))
    print(len(test_sentences))
    return (train_sentences, val_sentences, test_sentences,
            all_train_labels[:train_len], all_train_labels[train_len:], test_labels)


def load_dataset(dataset):
    """Load a text-classification dataset by name.

    Supported names: 'sst', '20news', '20news-15', '20news-5', 'wos',
    'wos-100', 'wos-34' and 'agnews'. The '20news-5' and 'wos-34' variants
    provide only a test split; their train/val entries are ``None``.

    Args:
        dataset: dataset name.

    Returns:
        ``(train_sentences, val_sentences, test_sentences,
        train_labels, val_labels, test_labels)``.

    Raises:
        ValueError: if ``dataset`` is not a supported name (previously this
            surfaced as an ``UnboundLocalError``).
    """
    loaders = {
        'sst': _load_sst,
        '20news': _load_20news,
        '20news-15': lambda: _load_20news(_NEWS20_15_CATEGORIES),
        '20news-5': _load_20news_5,
        'wos': lambda: _load_wos(lambda label: True),
        'wos-100': lambda: _load_wos(lambda label: 0 <= label < 100),
        'wos-34': lambda: _load_wos(lambda label: not 0 <= label < 100, test_only=True),
        'agnews': _load_agnews,
    }
    if dataset not in loaders:
        raise ValueError(
            f"Unknown dataset {dataset!r}; expected one of {sorted(loaders)}")
    return loaders[dataset]()