qiuhuachuan committed
Commit: b75fb56 · Parent(s): 3038223
Update README.md

README.md CHANGED
---
license: mit
language:
- en
pipeline_tag: text-classification
tags:
- text-classification
---

<div align="center">
<h1>
Facilitating NSFW Text Detection in Open-Domain Dialogue Systems via Knowledge Distillation

[...]

We report the classification results of the BERT model in the following table. [...]

**NOTICE:** You can directly use our trained checkpoint on the Hugging Face Hub.

For context-level detection, the input should follow the format `[user] {user utterance} [SEP] [bot] {bot response}`, where `{user utterance}` and `{bot response}` are replaced with the corresponding content (see the sketch below).
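A minimal sketch of assembling such an input; the helper function and sample strings are our own illustration, not part of the repository:

```Python
# Hypothetical helper: format one dialogue turn for context-level detection.
def build_context_input(user_utterance: str, bot_response: str) -> str:
    return f'[user] {user_utterance} [SEP] [bot] {bot_response}'


text = build_context_input('Want to try something wild tonight?',
                           'Let us keep this conversation friendly and safe.')
print(text)
# -> [user] Want to try something wild tonight? [SEP] [bot] Let us keep this conversation friendly and safe.
```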

1. Download the checkpoint:

```Bash
git lfs install
git clone https://huggingface.co/qiuhuachuan/NSFW-detector
```
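Alternatively, as a suggestion of our own rather than part of the original instructions, the checkpoint can be fetched with the `huggingface_hub` library, assuming it is installed:

```Python
# Assumption: huggingface_hub is installed (pip install huggingface_hub).
from huggingface_hub import snapshot_download

# Download the full model repository into a local directory.
snapshot_download(repo_id='qiuhuachuan/NSFW-detector',
                  local_dir='./NSFW-detector')
```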

2. Modify the `text` parameter in `local_use.py` and run it:

```Python
from typing import Optional

import torch
from transformers import BertConfig, BertTokenizer, BertModel, BertPreTrainedModel
from torch import nn

label_mapping = {0: 'NSFW', 1: 'SFW'}

config = BertConfig.from_pretrained('qiuhuachuan/NSFW-detector',
                                    num_labels=2,
                                    finetuning_task='text classification')
tokenizer = BertTokenizer.from_pretrained('qiuhuachuan/NSFW-detector',
                                          use_fast=False,
                                          never_split=['[user]', '[bot]'])
# Reuse the reserved [unused1]/[unused2] vocab slots for the
# [user]/[bot] role tokens used in context-level inputs.
tokenizer.vocab['[user]'] = tokenizer.vocab.pop('[unused1]')
tokenizer.vocab['[bot]'] = tokenizer.vocab.pop('[unused2]')


class BertForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.bert = BertModel.from_pretrained('bert-base-cased')
        classifier_dropout = (config.classifier_dropout
                              if config.classifier_dropout is not None else
                              config.hidden_dropout_prob)
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                token_type_ids: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.Tensor] = None,
                head_mask: Optional[torch.Tensor] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None):

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # we use the [CLS] embedding as the sequence representation
        cls = outputs[0][:, 0, :]
        cls = self.dropout(cls)
        logits = self.classifier(cls)

        return logits


# Load the fine-tuned weights from the cloned repository.
model = BertForSequenceClassification(config=config)
model.load_state_dict(torch.load('./NSFW-detector/pytorch_model.bin'))
model.cuda()
model.eval()

text = '''I'm open to exploring a variety of toys, including vibrators, wands, and clamps. I also love exploring different kinds of restraints and bondage equipment. I'm open to trying out different kinds of toys and exploring different levels of intensity.'''
result = tokenizer.encode_plus(text=text,
                               padding='max_length',
                               max_length=512,
                               truncation=True,
                               add_special_tokens=True,
                               return_token_type_ids=True,
                               return_tensors='pt')
result = result.to('cuda')

with torch.no_grad():
    logits = model(**result)
    predictions = logits.argmax(dim=-1)
    pred_label_idx = predictions.item()
    pred_label = label_mapping[pred_label_idx]
    print('predicted label is:', pred_label)
```
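Continuing from the script above (it assumes `model`, `tokenizer`, and `label_mapping` are already defined), a context-level input can be classified in exactly the same way; the dialogue strings here are our own illustration:

```Python
# Context-level input in the `[user] ... [SEP] [bot] ...` format,
# using the role tokens registered with the tokenizer above.
text = '[user] Want to try something wild tonight? [SEP] [bot] Let us keep this conversation friendly and safe.'

result = tokenizer.encode_plus(text=text,
                               padding='max_length',
                               max_length=512,
                               truncation=True,
                               add_special_tokens=True,
                               return_token_type_ids=True,
                               return_tensors='pt')
result = result.to('cuda')

with torch.no_grad():
    logits = model(**result)
    pred_label = label_mapping[logits.argmax(dim=-1).item()]
    print('predicted label is:', pred_label)
```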

## Citation

If our work is useful for your own research, you can cite us with the following BibTeX entry:

```
[...]
  primaryClass={cs.CL},
  url={https://arxiv.org/abs/2309.09749}
}
```