norbench_SA / sentiment_wrapper.py
AnnaPalatkina's picture
m
9761446
raw
history blame
2.92 kB
from transformers import BertModel, BertTokenizer, AdamW, get_linear_schedule_with_warmup
from sklearn.metrics import classification_report, f1_score
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
from config import params
from torch import nn
import pandas as pd
import numpy as np
import warnings
import random
import torch
import os
# Inference device: the first visible CUDA GPU when available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
class Dataset(Dataset):
    """Map-style dataset that tokenizes raw texts for BERT on demand.

    Each item is a dict with the original text (as ``str``) plus
    ``input_ids`` and ``attention_mask`` 1-D tensors padded/truncated to
    ``max_len``.

    Note: this class name shadows ``torch.utils.data.Dataset`` (its base
    class) within this module; the name is kept for compatibility with
    existing callers.

    Args:
        texts: indexable collection of texts; each element is converted
            with ``str()`` before tokenization.
        max_len: fixed token length every encoded text is padded/truncated to.
    """

    def __init__(self, texts, max_len):
        self.texts = texts
        self.tokenizer = BertTokenizer.from_pretrained(params['pretrained_model_name'])
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, item):
        text = str(self.texts[item])
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            # `padding='max_length'` is the supported replacement for the
            # deprecated `pad_to_max_length=True`.
            padding='max_length',
            return_attention_mask=True,
            truncation=True,
            return_tensors='pt',
        )
        return {
            'text': text,
            # return_tensors='pt' adds a leading batch dimension of 1;
            # flatten so the DataLoader can stack items into (batch, max_len).
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
        }
class SentimentClassifier(nn.Module):
    """BERT encoder followed by dropout and a linear classification head.

    The encoder checkpoint and dropout probability come from
    ``params['pretrained_model_name']`` and ``params['dropout']``.

    Args:
        n_classes: number of output classes (width of the linear head).
    """

    def __init__(self, n_classes):
        super().__init__()
        self.bert = BertModel.from_pretrained(params['pretrained_model_name'])
        self.drop = nn.Dropout(params['dropout'])
        self.out = nn.Linear(self.bert.config.hidden_size, n_classes)

    def forward(self, input_ids, attention_mask):
        """Return unnormalized class scores for a batch of token ids.

        The second element of BERT's tuple output (the pooled output) is
        used as the sequence representation; no softmax is applied.
        """
        _, pooled = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=False,
        )
        return self.out(self.drop(pooled))
class PredictionModel:
    """Inference wrapper around a fine-tuned ``SentimentClassifier``.

    Weights are read from ``params['path_to_model_bin']`` on the first call
    to :meth:`predict` and reused on later calls. The model is moved to the
    module-level ``device`` so that it matches the device of the input batches.

    Args:
        n_classes: number of output classes of the checkpoint (default 6).
    """

    def __init__(self, n_classes=6):
        self.model = SentimentClassifier(n_classes=n_classes)
        # Not used during prediction; kept as a public attribute for callers.
        self.loss_fn = nn.CrossEntropyLoss().to(device)
        self._weights_loaded = False

    def create_data_loader(self, X_test, max_len, batch_size):
        """Wrap ``X_test`` in a ``Dataset`` and return a ``DataLoader``.

        The loader is not shuffled (DataLoader default), so batches keep
        the input order.
        """
        ds = Dataset(
            texts=np.array(X_test),
            max_len=max_len,
        )
        return DataLoader(ds, batch_size=batch_size)

    def _load_weights(self):
        """Load the checkpoint once and place the model on ``device``."""
        if getattr(self, "_weights_loaded", False):
            return
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        # NOTE: torch.load unpickles the file; only point it at trusted checkpoints.
        state_dict = torch.load(params['path_to_model_bin'], map_location=device)
        self.model.load_state_dict(state_dict)
        # Input batches are moved to `device` in predict(); the model must live
        # there too or a CUDA host fails with a device-mismatch error.
        self.model.to(device)
        self._weights_loaded = True

    def predict(self, X_test: list):
        """Return the predicted class index for each text in ``X_test``.

        Args:
            X_test: texts to classify.

        Returns:
            list of int class indices, one per input, in input order
            (empty list for empty input).
        """
        data_loader = self.create_data_loader(X_test, params['max_length'], params['batch_size'])
        self._load_weights()
        self.model.eval()
        y_pred = []
        with torch.no_grad():
            for batch in data_loader:
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                logits = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                )
                y_pred.extend(logits.argmax(dim=1).tolist())
        return y_pred