hate-speech / app.py
fnavales's picture
Update app.py
2a04af3
raw
history blame
5.17 kB
import gradio as gr
import torch.nn as nn
import torch
from transformers import BertTokenizerFast as BertTokenizer, BertModel
import pytorch_lightning as pl
# Hugging Face model id shared by the tokenizer and the BERT encoder.
BERT_MODEL_NAME = 'bert-base-uncased'
# Label names, matched positionally to the classifier's output columns.
LABEL_COLUMNS = [
    'toxic',
    'severe_toxic',
    'obscene',
    'threat',
    'insult',
    'identity_hate',
]
# Length of each model input window, including the [CLS] and [SEP] tokens.
MAX_TOKEN_COUNT = 300
tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)
class ToxicCommentTagger(pl.LightningModule):
    """BERT encoder with a linear multi-label head that outputs sigmoid probabilities.

    Args:
        n_classes: width of the classifier head (one output per label).
        n_training_steps: kept on the instance; not used for inference here.
        n_warmup_steps: kept on the instance; not used for inference here.
    """

    def __init__(self, n_classes: int, n_training_steps=None, n_warmup_steps=None):
        super().__init__()
        self.bert = BertModel.from_pretrained(BERT_MODEL_NAME, return_dict=True)
        encoder_width = self.bert.config.hidden_size
        self.classifier = nn.Linear(encoder_width, n_classes)
        self.n_training_steps = n_training_steps
        self.n_warmup_steps = n_warmup_steps
        # BCELoss (not BCEWithLogitsLoss) because forward() already applies sigmoid.
        self.criterion = nn.BCELoss()

    def forward(self, input_ids, attention_mask, labels=None):
        """Return ``(loss, probabilities)`` for a batch.

        ``probabilities`` is the sigmoid of the classifier applied to BERT's
        pooled output; ``loss`` is ``0`` when *labels* is ``None``.
        """
        encoded = self.bert(input_ids, attention_mask=attention_mask)
        probabilities = torch.sigmoid(self.classifier(encoded.pooler_output))
        loss = 0 if labels is None else self.criterion(probabilities, labels)
        return loss, probabilities
def predict(model, tokenizer, sentence, *, chunk_size=None, labels=None):
    """Score *sentence* against every label, splitting long text into windows.

    The sentence is tokenized without special tokens, padding or truncation,
    then cut into windows of ``chunk_size - 2`` tokens. Each window is wrapped
    as ``[CLS] ... [SEP]``, right-padded to ``chunk_size`` and all windows are
    scored as one batch. The score reported for a label is the maximum over
    all windows, so a toxic passage anywhere in a long text is not diluted.

    Args:
        model: callable accepting ``input_ids`` and ``attention_mask`` keyword
            tensors and returning ``(loss, probabilities)`` with probabilities
            shaped ``(n_windows, n_outputs)`` (e.g. ``ToxicCommentTagger``).
        tokenizer: tokenizer exposing ``encode`` and the ``cls_token_id``,
            ``sep_token_id`` and ``pad_token_id`` attributes.
        sentence: raw text to classify.
        chunk_size: tokens per window including [CLS]/[SEP]; defaults to
            ``MAX_TOKEN_COUNT``. Must be greater than 2.
        labels: label names matched positionally to the model's output
            columns; defaults to ``LABEL_COLUMNS``.

    Returns:
        dict mapping each label to its highest probability across windows.

    Raises:
        ValueError: if ``chunk_size`` leaves no room for text tokens.
    """
    if chunk_size is None:
        chunk_size = MAX_TOKEN_COUNT
    if labels is None:
        labels = LABEL_COLUMNS
    body_size = chunk_size - 2  # room left after [CLS] and [SEP]
    if body_size < 1:
        raise ValueError(f"chunk_size must be greater than 2, got {chunk_size}")

    # Tokenize *without* padding: padding before chunking placed [SEP] after
    # the padding and produced an extra all-padding window whose scores leaked
    # into the max below. No max_length either, as that would truncate.
    token_ids = tokenizer.encode(sentence, add_special_tokens=False)
    windows = [
        token_ids[start:start + body_size]
        for start in range(0, len(token_ids), body_size)
    ] or [[]]  # empty input is still scored as a bare [CLS] [SEP]

    input_rows, mask_rows = [], []
    for window in windows:
        ids = [tokenizer.cls_token_id, *window, tokenizer.sep_token_id]
        pad_len = chunk_size - len(ids)
        input_rows.append(ids + [tokenizer.pad_token_id] * pad_len)
        mask_rows.append([1] * len(ids) + [0] * pad_len)

    with torch.no_grad():
        _, probabilities = model(
            input_ids=torch.tensor(input_rows, dtype=torch.long),
            attention_mask=torch.tensor(mask_rows, dtype=torch.int),
        )
    # Keep the worst (highest) score per label across all windows.
    best = probabilities.max(dim=0).values.detach().cpu().numpy()
    return dict(zip(labels, best))
# Restore the fine-tuned tagger; n_classes sizes the classifier head with one
# output per entry in LABEL_COLUMNS.
model = ToxicCommentTagger.load_from_checkpoint(
    checkpoint_path='./best-checkpoint.ckpt',
    n_classes=len(LABEL_COLUMNS),
)
model.eval()    # switch to inference mode
model.freeze()  # freeze parameters for inference
# Category legend returned with every prediction and shown in the JSON output.
# Derived from LABEL_COLUMNS so it lists exactly the labels predict() can emit
# and can never drift from the model's actual outputs.
all_categories = {'all_categories': list(LABEL_COLUMNS)}
# Sample rows for the Gradio form: each pairs a sentence with a slider
# threshold of 50.
examples = [
    [text, 50]
    for text in (
        'Yes indeed. She sort of reminds me of the elder lady that played the part in the movie "Titanic" who was telling her story!!! And I wouldn\'t have wanted to cover who I really am!! I would be proud!!!! WE should be proud of our race no matter what it is!!',
        'The trans women reading this tweet right now is beautiful',
        "Question: These 4 broads who criticize America, what country did they flee to get here? And now they want to make OUR America like THEIR former HELL HOLE. I don't think so!!!!!!!!!! Let them explain their GRATITUDE for letting them in OUR country.",
        'Well being subtle you could not have been born because black people don\'t matter',
        "For starters bend over the one in pink and kick that ass and pussy to get a taste until she's begging for a dick inside her.",
    )
]
def toxicity(sentence, threshold):
    """Gradio callback: flag labels whose score exceeds *threshold* percent.

    Args:
        sentence: text entered by the user.
        threshold: cut-off on a 0-100 scale (from the slider); divided by 100
            before comparing with the model probabilities.

    Returns:
        ``(flagged_labels, all_categories)`` for the two output components.
    """
    scores = predict(model, tokenizer, sentence)
    cutoff = threshold / 100
    flagged = [label for label, score in scores.items() if score > cutoff]
    return flagged, all_categories
# Web UI: sentence + threshold slider in, flagged labels + category legend out.
gr.Interface(
    fn=toxicity,
    inputs=[
        gr.Textbox(placeholder="Enter sentence here..."),
        gr.Slider(minimum=0, maximum=100),
    ],
    outputs=['text', gr.JSON(value=all_categories)],
    examples=examples,
).launch()