Model Description

The model predicts whether two appraisals are aligned and was trained on the ALOE dataset.

Input: two appraisals, each given as token IDs plus an attention mask (see the `forward` method of the `SNN` class).

Output: the cosine similarity between the mean-pooled embeddings of the two appraisals (one score per pair).

Model architecture: a Siamese network whose two branches share a single all-mpnet-base-v2 encoder.

Developed by: Jiamin Yang

Model Performance

| F1   | Recall | Precision |
|------|--------|-----------|
| 0.46 | 0.45   | 0.46      |

Getting Started

import torch
from torch import nn
from transformers import AutoModel, AutoTokenizer

class SNN(nn.Module):
    """Siamese network that scores how well two appraisals align.

    Both appraisals are encoded by the same pretrained transformer (shared
    weights), mean-pooled over non-padding tokens, and compared with cosine
    similarity.

    Args:
        model_name: Model id or local path passed to
            ``AutoModel.from_pretrained``.
        device: Device to place the encoder on. ``None`` (the default) uses
            ``"cuda"`` when ``torch.cuda.is_available()`` and ``"cpu"``
            otherwise, so the model can be built on CPU-only machines instead
            of unconditionally requiring a GPU.
    """

    def __init__(self, model_name, device=None):
        super().__init__()
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        # The encoder is put in training mode here (dropout active); call
        # .eval() on the SNN before running inference.
        self.model = AutoModel.from_pretrained(model_name).to(device).train()
        # eps guards the cosine denominator against near-zero embeddings.
        self.cos = torch.nn.CosineSimilarity(dim=1, eps=1e-4)

    def mean_pooling(self, token_embeddings, attention_mask):
        """Average token embeddings over the positions where the mask is 1.

        Args:
            token_embeddings: Per-token embeddings, ``(batch, seq_len, hidden)``.
            attention_mask: ``(batch, seq_len)`` mask; 1 for real tokens,
                0 for padding.

        Returns:
            ``(batch, hidden)`` mean-pooled embeddings. A row whose mask is all
            zeros pools to zeros, because the clamp keeps the divisor non-zero.
        """
        # Broadcast the mask across the hidden dimension so padding positions
        # contribute nothing to the sum.
        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        summed = torch.sum(token_embeddings * mask, 1)
        counts = torch.clamp(mask.sum(1), min=1e-9)
        return summed / counts

    def forward(self, input_ids_a, attention_a, input_ids_b, attention_b):
        """Return the cosine similarity between the pooled encodings of a and b.

        Args:
            input_ids_a: Token ids of the first batch of appraisals.
            attention_a: Attention mask matching ``input_ids_a``.
            input_ids_b: Token ids of the second batch of appraisals.
            attention_b: Attention mask matching ``input_ids_b``.

        Returns:
            Tensor of cosine similarities, one per (a, b) pair.
        """
        # Index 0 of the encoder output holds the per-token hidden states.
        tokens_a = self.model(input_ids_a, attention_mask=attention_a)[0]
        tokens_b = self.model(input_ids_b, attention_mask=attention_b)[0]

        pooled_a = self.mean_pooling(tokens_a, attention_a)
        pooled_b = self.mean_pooling(tokens_b, attention_b)

        return self.cos(pooled_a, pooled_b)

MODEL_NAME = 'sentence-transformers/all-mpnet-base-v2'
checkpoint_path = 'your_path_to/empathy-appraisal-alignment.pt'

# Use the GPU when present; otherwise fall back to CPU instead of crashing.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = SNN(MODEL_NAME).to(device)

# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a trusted source, or pass weights_only=True
# if your torch version supports it and the checkpoint holds only tensors and
# plain containers.
checkpoint = torch.load(checkpoint_path, map_location=device)
state_dict = checkpoint['model_state_dict']

# Whether this buffer appears in a state dict varies across library versions.
# Drop it from the checkpoint only when the freshly built model does not
# expect it, so loading works either way instead of failing with a KeyError
# or a missing-key error.
position_ids_key = 'model.embeddings.position_ids'
if position_ids_key not in model.state_dict():
    state_dict.pop(position_ids_key, None)

model.load_state_dict(state_dict)
model.eval()  # disable dropout for inference

# use the model
target = ["I'm so sad that my cat died yesterday."]
observer = ["It's ok to feel sad."]

target_encodings = tokenizer(target, padding=True, truncation=True, return_tensors='pt')
target_input_ids = target_encodings['input_ids'].to(device)
target_attention_mask = target_encodings['attention_mask'].to(device)
observer_encodings = tokenizer(observer, padding=True, truncation=True, return_tensors='pt')
observer_input_ids = observer_encodings['input_ids'].to(device)
observer_attention_mask = observer_encodings['attention_mask'].to(device)

# Inference only: skip building the autograd graph to save memory and time.
with torch.no_grad():
    output = model(target_input_ids, target_attention_mask, observer_input_ids, observer_attention_mask)
print(output)  # one cosine similarity per pair, e.g. tensor([0.5755])
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference API
Unable to determine this model's library. Check the docs.

Dataset used to train Blablablab/empathy-appraisal-alignment