PROJECT_PATH = 'cleaned_code'

import sys
sys.path.append(PROJECT_PATH)

import json
import pickle
from typing import Optional

import torch
from transformers import AutoConfig, AutoTokenizer, HfArgumentParser

from src import BertForSemanticEmbedding, getLabelModel
from src import DataTrainingArguments, ModelArguments, CustomTrainingArguments, read_yaml_config
from src import dataset_classification_type
from src import SemSupDataset

device = 'cuda' if torch.cuda.is_available() else 'cpu'
def compute_tok_score_cart(doc_reps, doc_input_ids, qry_reps, qry_input_ids, qry_attention_mask):
    """COIL-style token interaction: for every query token, take the max dot product
    over doc tokens with an exact vocabulary-id match, then sum over query tokens."""
    qry_input_ids = qry_input_ids.unsqueeze(2).unsqueeze(3)  # Q * LQ * 1 * 1
    doc_input_ids = doc_input_ids.unsqueeze(0).unsqueeze(1)  # 1 * 1 * D * LD
    exact_match = doc_input_ids == qry_input_ids  # Q * LQ * D * LD
    exact_match = exact_match.float()
    # 16 is the hard-coded token projection dimension (config.token_dim)
    scores_no_masking = torch.matmul(
        qry_reps.view(-1, 16),                  # (Q * LQ) * d
        doc_reps.view(-1, 16).transpose(0, 1)   # d * (D * LD)
    )
    scores_no_masking = scores_no_masking.view(
        *qry_reps.shape[:2], *doc_reps.shape[:2])  # Q * LQ * D * LD
    # Keep only exact-match pairs and take the best-matching doc token per query token
    scores, _ = (scores_no_masking * exact_match).max(dim=3)  # Q * LQ * D
    # Mask padding and sum over query tokens, skipping [CLS] at position 0
    tok_scores = (scores * qry_attention_mask.reshape(-1, qry_attention_mask.shape[-1]).unsqueeze(2))[:, 1:].sum(1)
    return tok_scores
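
# A minimal shape sanity-check for compute_tok_score_cart. All sizes below are
# invented for illustration; 16 mirrors the hard-coded token dimension above.
def _example_tok_score_shapes():
    Q, LQ, D, LD = 3, 8, 2, 10
    qry_reps = torch.randn(Q, LQ, 16)
    qry_input_ids = torch.randint(0, 100, (Q, LQ))
    qry_attention_mask = torch.ones(Q, LQ)
    doc_reps = torch.randn(D, LD, 16)
    doc_input_ids = torch.randint(0, 100, (D, LD))
    tok_scores = compute_tok_score_cart(doc_reps, doc_input_ids, qry_reps, qry_input_ids, qry_attention_mask)
    assert tok_scores.shape == (Q, D)  # one aggregated score per (query, document) pair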
def coil_fast_eval_forward(
        input_ids: Optional[torch.Tensor] = None,
        doc_reps=None,
        logits: Optional[torch.Tensor] = None,
        desc_input_ids=None,
        desc_attention_mask=None,
        lab_reps=None,
        label_embeddings=None,
):
    # Token-level COIL scores between the document and every label description
    tok_scores = compute_tok_score_cart(
        doc_reps, input_ids,
        lab_reps, desc_input_ids.reshape(-1, desc_input_ids.shape[-1]), desc_attention_mask
    )
    # Dense scores: document embedding against every label embedding
    logits = (logits.unsqueeze(0) @ label_embeddings.T)
    # Rearrange tok_scores (labels x docs) so it lines up with logits (docs x labels)
    new_tok_scores = torch.zeros(logits.shape, device=logits.device)
    for i in range(tok_scores.shape[1]):
        stride = tok_scores.shape[0] // tok_scores.shape[1]
        new_tok_scores[i] = tok_scores[i * stride: i * stride + stride, i]
    # Final score per label = dense score + token-interaction score
    return (logits + new_tok_scores).squeeze()
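
# Illustrative end-to-end sketch of coil_fast_eval_forward with made-up sizes:
# each label's final score is its dense dot-product score plus its exact-match
# token-interaction score.
def _example_combined_scores():
    num_labels, hidden, desc_len, doc_len = 4, 6, 5, 7
    doc_emb = torch.randn(hidden)                       # plays the role of `logits`
    label_embeddings = torch.randn(num_labels, hidden)
    lab_reps = torch.randn(num_labels, desc_len, 16)
    desc_input_ids = torch.randint(0, 100, (num_labels, desc_len))
    desc_attention_mask = torch.ones(num_labels, desc_len)
    doc_reps = torch.randn(1, doc_len, 16)              # a single document
    doc_input_ids = torch.randint(0, 100, (1, doc_len))
    scores = coil_fast_eval_forward(doc_input_ids, doc_reps, doc_emb,
                                    desc_input_ids, desc_attention_mask,
                                    lab_reps, label_embeddings)
    assert scores.shape == (num_labels,)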
class DemoModel:
    def __init__(self):
        self.label_list = [x.strip() for x in open(f'{PROJECT_PATH}/datasets/Amzn13K/all_labels.txt')]
        unseen_label_list = [x.strip() for x in open(f'{PROJECT_PATH}/datasets/Amzn13K/unseen_labels_split6500_2.txt')]
        num_labels = len(self.label_list)
        self.label_list.sort()  # For consistency
        l2i = {v: i for i, v in enumerate(self.label_list)}
        unseen_label_indexes = [l2i[x] for x in unseen_label_list]
        # Maps each BERT vocabulary id to its COIL cluster id
        self.coil_cluster_map = json.load(open(f'{PROJECT_PATH}/bert_coil_map_dict_lemma255K_isotropic.json'))
        # Five precomputed sets of label-description representations;
        # classify() ensembles predictions across all five.
        self.all_lab_reps = []
        self.all_label_embeddings = []
        self.all_desc_input_ids_orig = []
        self.all_desc_input_ids = []
        self.all_desc_attention_mask = []
        for i in range(1, 6):
            lab_reps, label_embeddings, desc_input_ids_orig, desc_input_ids, desc_attention_mask = pickle.load(
                open(f'{PROJECT_PATH}/precomputed/Amzn13K/amzn_base_labels_data1_{i}.pkl', 'rb'))
            self.all_lab_reps.append(lab_reps.to(device))
            self.all_label_embeddings.append(label_embeddings.to(device))
            self.all_desc_input_ids_orig.append(desc_input_ids_orig.to(device))
            self.all_desc_input_ids.append(desc_input_ids.to(device))
            self.all_desc_attention_mask.append(desc_attention_mask.to(device))
        ARGS_FILE = f'{PROJECT_PATH}/configs/ablation_amzn_eda.yml'
        parser = HfArgumentParser((ModelArguments, DataTrainingArguments, CustomTrainingArguments))
        self.model_args, self.data_args, self.training_args = parser.parse_dict(
            read_yaml_config(ARGS_FILE, output_dir='demo_tmp', extra_args={}))
        config = AutoConfig.from_pretrained(
            self.model_args.config_name if self.model_args.config_name else self.model_args.model_name_or_path,
            finetuning_task=self.data_args.task_name,
            cache_dir=self.model_args.cache_dir,
            revision=self.model_args.model_revision,
            use_auth_token=True if self.model_args.use_auth_token else None,
        )
        config.model_name_or_path = self.model_args.model_name_or_path
        config.problem_type = dataset_classification_type[self.data_args.task_name]
        config.negative_sampling = self.model_args.negative_sampling
        config.semsup = self.model_args.semsup
        config.encoder_model_type = self.model_args.encoder_model_type
        config.arch_type = self.model_args.arch_type
        config.coil = self.model_args.coil
        config.token_dim = self.model_args.token_dim
        config.colbert = self.model_args.colbert
        label_model, label_tokenizer = getLabelModel(self.data_args, self.model_args)
        config.label_hidden_size = label_model.config.hidden_size
        model = BertForSemanticEmbedding(config)
        model.label_model = label_model
        model.label_tokenizer = label_tokenizer
        model.config.label2id = {l: i for i, l in enumerate(self.label_list)}
        model.config.id2label = {i: label for label, i in model.config.label2id.items()}
        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
        model.to(device)
        model.eval()
        torch.set_grad_enabled(False)
        model.load_state_dict(torch.load(f'{PROJECT_PATH}/ckpt/Amzn13K/amzn_main_model.bin', map_location=device))
        self.model = model
        # Decode a human-readable description for every label in every set; if one
        # set lacks a description for a label, borrow it from another set.
        self.extracted_descs = [self.extract_descriptions(adi) for adi in self.all_desc_input_ids_orig]
        tot_len = len(self.all_desc_input_ids_orig)
        for i in range(len(self.all_desc_input_ids_orig[0])):
            for j in range(tot_len):
                if self.extracted_descs[j][i] == "":
                    for k in range(tot_len):
                        if self.extracted_descs[k][i] != '':
                            self.extracted_descs[j][i] = self.extracted_descs[k][i]
                            break
    def extract_descriptions(self, input_ids):
        """Decode tokenized label prompts and keep only the 'description is ...' span."""
        descs = self.tokenizer.batch_decode(input_ids, skip_special_tokens=True)
        new_descs = []
        for desc in descs:
            a = desc.find('description is')
            if a == -1:
                # There is no description to use; fall back to an empty string
                new_descs.append("")
                continue
            # The description ends where the next prompt field begins
            b = min([desc.find(x, a) if desc.find(x, a) != -1 else 99999999999
                     for x in ['label is', 'parents are', 'children are']])
            if b == 99999999999:
                new_descs.append(desc[a:].strip())
            else:
                new_descs.append(desc[a:b].strip())
        return new_descs
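
    # Example of the trimming above (prompt text invented for illustration):
    #   decoded: "label is dog food. description is dry food for adult dogs. parents are pet supplies."
    #   result:  "description is dry food for adult dogs."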
    def classify(self, text, unseen_labels=None):
        self.model.eval()
        with torch.no_grad():
            item = self.tokenizer(text, padding='max_length', max_length=self.data_args.max_seq_length, truncation=True)
            item = {k: torch.tensor(v, device=device).unsqueeze(0) for k, v in item.items()}
            outputs_doc, logits = self.model.forward_input_encoder(**item)
            doc_reps = self.model.tok_proj(outputs_doc.last_hidden_state)
            # Map document token ids to COIL cluster ids so exact matching happens
            # at the cluster level
            input_ids = torch.tensor([self.coil_cluster_map[str(x.item())] for x in item['input_ids'][0]]).to(device).unsqueeze(0)
            all_logits = []
            for adi, ada, alr, ale in zip(self.all_desc_input_ids, self.all_desc_attention_mask,
                                          self.all_lab_reps, self.all_label_embeddings):
                all_logits.append(coil_fast_eval_forward(input_ids, doc_reps, logits, adi, ada, alr, ale))
            # Ensemble the five description sets by averaging sigmoid probabilities;
            # torch.sigmoid (rather than scipy's expit) keeps the result a torch
            # tensor, which torch.topk below requires
            final_logits = sum(torch.sigmoid(x.cpu()) for x in all_logits) / len(all_logits)
            # For each label, note which description set scored it highest
            max_indices = torch.argmax(torch.stack(all_logits), dim=0).cpu().tolist()
            outs = torch.topk(final_logits, k=50)
            preds_dic = dict()
            descs_dic = dict()
            for i, v in zip(outs.indices, outs.values):
                preds_dic[self.label_list[i]] = v.item()
                descs_dic[self.label_list[i]] = self.extracted_descs[max_indices[i]][i]
            return preds_dic, descs_dic
if __name__ == '__main__':
    model = DemoModel()
    model.classify('Hello')
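    # Hedged usage sketch: the product text below is invented. classify() returns
    # a {label: probability} dict and a {label: matched description} dict.
    preds, descs = model.classify('A durable stainless steel water bottle for hiking and camping.')
    for label in list(preds)[:10]:
        print(f'{label}\t{preds[label]:.3f}\t{descs[label]}')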