Bangla-PoS-Taggers / helper /alignment_mappers.py
musfiqdehan's picture
Update alignment_mappers.py with new model name
0a0f809
"""
This module contains the helper functions to get the word alignment mapping between two sentences.
"""
import torch
import itertools
import transformers
from transformers import logging
# Set the verbosity to error, so that the warning messages are not printed
logging.set_verbosity_warning()
logging.set_verbosity_error()
def select_model(model_name):
"""
Select Model
"""
if model_name == "Google-mBERT (Base-Multilingual)":
model_name="bert-base-multilingual-cased"
elif model_name == "Neulab-AwesomeAlign (Bn-En-0.5M)":
model_name="musfiqdehan/bn-en-word-aligner"
elif model_name == "BUET-BanglaBERT (Large)":
model_name="csebuetnlp/banglabert_large"
elif model_name == "SagorSarker-BanglaBERT (Base)":
model_name="sagorsarker/bangla-bert-base"
elif model_name == "SentenceTransformers-LaBSE (Multilingual)":
model_name="sentence-transformers/LaBSE"
return model_name
def get_alignment_mapping(source="", target="", model_name=""):
"""
Get Aligned Words
"""
model_name = select_model(model_name)
model = transformers.BertModel.from_pretrained(model_name)
tokenizer = transformers.BertTokenizer.from_pretrained(model_name)
# pre-processing
sent_src, sent_tgt = source.strip().split(), target.strip().split()
token_src, token_tgt = [tokenizer.tokenize(word) for word in sent_src], [
tokenizer.tokenize(word) for word in sent_tgt]
wid_src, wid_tgt = [tokenizer.convert_tokens_to_ids(x) for x in token_src], [
tokenizer.convert_tokens_to_ids(x) for x in token_tgt]
ids_src, ids_tgt = tokenizer.prepare_for_model(list(itertools.chain(*wid_src)), return_tensors='pt', model_max_length=tokenizer.model_max_length, truncation=True)['input_ids'], tokenizer.prepare_for_model(list(itertools.chain(*wid_tgt)), return_tensors='pt', truncation=True, model_max_length=tokenizer.model_max_length)['input_ids']
sub2word_map_src = []
for i, word_list in enumerate(token_src):
sub2word_map_src += [i for x in word_list]
sub2word_map_tgt = []
for i, word_list in enumerate(token_tgt):
sub2word_map_tgt += [i for x in word_list]
# alignment
align_layer = 8
threshold = 1e-3
model.eval()
with torch.no_grad():
out_src = model(ids_src.unsqueeze(0), output_hidden_states=True)[
2][align_layer][0, 1:-1]
out_tgt = model(ids_tgt.unsqueeze(0), output_hidden_states=True)[
2][align_layer][0, 1:-1]
dot_prod = torch.matmul(out_src, out_tgt.transpose(-1, -2))
softmax_srctgt = torch.nn.Softmax(dim=-1)(dot_prod)
softmax_tgtsrc = torch.nn.Softmax(dim=-2)(dot_prod)
softmax_inter = (softmax_srctgt > threshold) * \
(softmax_tgtsrc > threshold)
align_subwords = torch.nonzero(softmax_inter, as_tuple=False)
align_words = set()
for i, j in align_subwords:
align_words.add((sub2word_map_src[i], sub2word_map_tgt[j]))
return sent_src, sent_tgt, align_words
def get_word_mapping(source="", target="", model_name=""):
"""
Get Word Aligned Mapping Words
"""
sent_src, sent_tgt, align_words = get_alignment_mapping(
source=source, target=target, model_name=model_name)
result = []
for i, j in sorted(align_words):
result.append(f'bn:({sent_src[i]}) -> en:({sent_tgt[j]})')
return result
def get_word_index_mapping(source="", target="", model_name=""):
"""
Get Word Aligned Mapping Index
"""
sent_src, sent_tgt, align_words = get_alignment_mapping(
source=source, target=target, model_name=model_name)
result = []
for i, j in sorted(align_words):
result.append(f'bn:({i}) -> en:({j})')
return result