import gradio as gr
import pytorch_lightning as pl
import torch
import torch.nn as nn
from transformers import BertTokenizerFast as BertTokenizer, BertModel

BERT_MODEL_NAME = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)

# Labels the classifier head predicts, in output-column order.
LABEL_COLUMNS = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

# Length of every sequence fed to the model, including [CLS] and [SEP].
MAX_TOKEN_COUNT = 300


class ToxicCommentTagger(pl.LightningModule):
    """BERT encoder followed by a linear multi-label classification head."""

    def __init__(self, n_classes: int, n_training_steps=None, n_warmup_steps=None):
        """Build the encoder and the classification head.

        Args:
            n_classes: Number of independent labels the head outputs.
            n_training_steps: Stored on the instance; not used in this file.
            n_warmup_steps: Stored on the instance; not used in this file.
        """
        super().__init__()
        self.bert = BertModel.from_pretrained(BERT_MODEL_NAME, return_dict=True)
        self.classifier = nn.Linear(self.bert.config.hidden_size, n_classes)
        self.n_training_steps = n_training_steps
        self.n_warmup_steps = n_warmup_steps
        self.criterion = nn.BCELoss()

    def forward(self, input_ids, attention_mask, labels=None):
        """Return ``(loss, probabilities)`` for a batch.

        The probabilities are the sigmoid of the linear head applied to BERT's
        pooled output. ``loss`` is the BCE loss against ``labels`` when they
        are given, and the integer ``0`` otherwise.
        """
        output = self.bert(input_ids, attention_mask=attention_mask)
        output = self.classifier(output.pooler_output)
        output = torch.sigmoid(output)
        loss = 0
        if labels is not None:
            loss = self.criterion(output, labels)
        return loss, output


def _chunk_token_ids(token_ids, cls_id, sep_id, pad_id, chunk_len=MAX_TOKEN_COUNT):
    """Split raw token ids into fixed-length ``[CLS] ... [SEP] PAD...`` rows.

    Args:
        token_ids: Token ids of the text, without special tokens or padding.
        cls_id: Id placed at the start of every row.
        sep_id: Id placed directly after the last real token of every row.
        pad_id: Id used to fill each row up to ``chunk_len``.
        chunk_len: Total row length, special tokens included.

    Returns:
        ``(id_rows, mask_rows)``: two equally shaped lists of lists. The mask
        is 1 over ``[CLS]``, the real tokens and ``[SEP]``, and 0 over padding.
    """
    body_len = chunk_len - 2  # room left once [CLS] and [SEP] are added
    id_rows, mask_rows = [], []
    # max(..., 1) guarantees one "[CLS] [SEP]" row for empty input, so the
    # model always receives a non-empty batch.
    for start in range(0, max(len(token_ids), 1), body_len):
        row = [cls_id, *token_ids[start:start + body_len], sep_id]
        pad_len = chunk_len - len(row)
        mask_rows.append([1] * len(row) + [0] * pad_len)
        id_rows.append(row + [pad_id] * pad_len)
    return id_rows, mask_rows


def predict(model, tokenizer, sentence):
    """Score ``sentence`` for every label in ``LABEL_COLUMNS``.

    Text longer than one model input is split into consecutive chunks of
    ``MAX_TOKEN_COUNT - 2`` tokens. Each chunk is wrapped in ``[CLS]``/``[SEP]``,
    padded to ``MAX_TOKEN_COUNT`` and scored independently; the result keeps
    the highest score seen for each label across the chunks.

    Args:
        model: Callable taking ``input_ids`` and ``attention_mask`` keyword
            tensors and returning ``(loss, probabilities)``.
        tokenizer: Tokenizer providing ``encode_plus`` and the
            ``cls_token_id`` / ``sep_token_id`` / ``pad_token_id`` attributes.
        sentence: Text to classify.

    Returns:
        Dict mapping each label in ``LABEL_COLUMNS`` to its maximum score.
    """
    # Tokenize WITHOUT padding: padding has to come after chunking, otherwise
    # [SEP] ends up behind the pad tokens and an extra all-padding chunk gets
    # scored alongside the real text.
    encoding = tokenizer.encode_plus(
        sentence,
        add_special_tokens=False,
        return_token_type_ids=False,
        return_attention_mask=False,
    )

    id_rows, mask_rows = _chunk_token_ids(
        list(encoding['input_ids']),
        tokenizer.cls_token_id,
        tokenizer.sep_token_id,
        tokenizer.pad_token_id,
    )
    input_ids = torch.tensor(id_rows, dtype=torch.long)
    attention_mask = torch.tensor(mask_rows, dtype=torch.long)

    # Inference only: no autograd graph, so the result converts to NumPy even
    # when the model passed in has not been frozen.
    with torch.no_grad():
        _, test_prediction = model(input_ids=input_ids, attention_mask=attention_mask)

    # Per-label maximum over the chunk axis.
    best_scores = test_prediction.detach().cpu().numpy().max(axis=0)
    return dict(zip(LABEL_COLUMNS, best_scores))


model = ToxicCommentTagger.load_from_checkpoint(
    './best-checkpoint.ckpt',
    n_classes=len(LABEL_COLUMNS)
)
model.eval()
model.freeze()

# Shown verbatim in the JSON output panel.
# NOTE(review): these names do not match LABEL_COLUMNS, which are the labels
# toxicity() can actually return - confirm which list is intended.
all_categories = {'all_categories': [
    'toxicity',
    'obscene',
    'threat',
    'insult',
    'identity_attack',
    'sexual_explicit'
]}

# Each example is [sentence, threshold in percent].
examples = [
    ['Yes indeed. She sort of reminds me of the elder lady that played the part in the movie "Titanic" who was telling her story!!! And I wouldn\'t have wanted to cover who I really am!! I would be proud!!!! WE should be proud of our race no matter what it is!!', 50],
    ['The trans women reading this tweet right now is beautiful', 50],
    ["Question: These 4 broads who criticize America, what country did they flee to get here? And now they want to make OUR America like THEIR former HELL HOLE. I don't think so!!!!!!!!!! Let them explain their GRATITUDE for letting them in OUR country.", 50],
    ['Well being subtle you could not have been born because black people don\'t matter', 50],
    ["For starters bend over the one in pink and kick that ass and pussy to get a taste until she's begging for a dick inside her.", 50]
]


def toxicity(sentence, threshold):
    """Return the labels scoring above ``threshold`` percent, plus ``all_categories``.

    Args:
        sentence: Text to classify.
        threshold: Cut-off on a 0-100 scale; it is divided by 100 before being
            compared with the model scores.

    Returns:
        Tuple of (list of label names whose score exceeds the cut-off,
        the ``all_categories`` dict).
    """
    predicts = predict(model, tokenizer, sentence)
    cutoff = threshold / 100
    return [label for label, score in predicts.items() if score > cutoff], all_categories


gr.Interface(
    fn=toxicity,
    inputs=[
        gr.Textbox(placeholder="Enter sentence here..."),
        gr.Slider(0, 100)
    ],
    outputs=[
        'text',
        gr.JSON(all_categories)
    ],
    examples=examples
).launch()