qiuhuachuan committed
Commit b75fb56
1 Parent(s): 3038223

Update README.md

Files changed (1)
  1. README.md +112 -1
README.md CHANGED
@@ -1,3 +1,11 @@
+ ---
+ license: mit
+ language:
+ - en
+ pipeline_tag: text-classification
+ tags:
+ - text-classification
+ ---
<div align="center">
<h1>
Facilitating NSFW Text Detection in Open-Domain Dialogue Systems via Knowledge Distillation
@@ -62,6 +70,109 @@ We report the classification results of the BERT model in the following table. W

**NOTICE:** You can directly use our trained checkpoint on the hub of Hugging Face.

+ For context-level detection, the input should follow the format `[user] {user utterance} [SEP] [bot] {bot response}`, where `{user utterance}` and `{bot response}` are filled in with the corresponding content.
+
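For instance, a context-level input for one exchange can be assembled with plain string formatting (a minimal sketch; the two utterances below are invented for illustration):

```Python
# Minimal sketch: assemble a context-level input for the detector.
# The [user]/[bot] markers and the [SEP] separator follow the format
# described above; the utterance strings are invented examples.
user_utterance = 'What are you up to this weekend?'
bot_response = 'I am planning to go hiking if the weather holds.'

text = f'[user] {user_utterance} [SEP] [bot] {bot_response}'
print(text)
# -> [user] What are you up to this weekend? [SEP] [bot] I am planning to go hiking if the weather holds.
```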
+ 1. Download the checkpoint
+
+ ```Bash
+ git lfs install
+ git clone https://huggingface.co/qiuhuachuan/NSFW-detector
+ ```
+
+ 2. Modify the `text` parameter in `local_use.py` and execute it.
+
+ ```Python
+ from typing import Optional
+
+ import torch
+ from transformers import BertConfig, BertTokenizer, BertModel, BertPreTrainedModel
+ from torch import nn
+
+ label_mapping = {0: 'NSFW', 1: 'SFW'}
+
+ config = BertConfig.from_pretrained('qiuhuachuan/NSFW-detector',
+                                     num_labels=2,
+                                     finetuning_task='text classification')
+ tokenizer = BertTokenizer.from_pretrained('qiuhuachuan/NSFW-detector',
+                                           use_fast=False,
+                                           never_split=['[user]', '[bot]'])
+ # Repurpose two unused vocab slots as the dialogue role markers.
+ tokenizer.vocab['[user]'] = tokenizer.vocab.pop('[unused1]')
+ tokenizer.vocab['[bot]'] = tokenizer.vocab.pop('[unused2]')
+
+
+ class BertForSequenceClassification(BertPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+         self.num_labels = config.num_labels
+         self.config = config
+
+         self.bert = BertModel.from_pretrained('bert-base-cased')
+         classifier_dropout = (config.classifier_dropout
+                               if config.classifier_dropout is not None else
+                               config.hidden_dropout_prob)
+         self.dropout = nn.Dropout(classifier_dropout)
+         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(self,
+                 input_ids: Optional[torch.Tensor] = None,
+                 attention_mask: Optional[torch.Tensor] = None,
+                 token_type_ids: Optional[torch.Tensor] = None,
+                 position_ids: Optional[torch.Tensor] = None,
+                 head_mask: Optional[torch.Tensor] = None,
+                 inputs_embeds: Optional[torch.Tensor] = None,
+                 labels: Optional[torch.Tensor] = None,
+                 output_attentions: Optional[bool] = None,
+                 output_hidden_states: Optional[bool] = None,
+                 return_dict: Optional[bool] = None):
+
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.bert(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         # We classify on the raw [CLS] hidden state rather than the pooler output.
+         cls = outputs[0][:, 0, :]
+         cls = self.dropout(cls)
+         logits = self.classifier(cls)
+
+         return logits
+
+
+ model = BertForSequenceClassification(config=config)
+ # Load the fine-tuned weights from the cloned checkpoint.
+ model.load_state_dict(torch.load('./NSFW-detector/pytorch_model.bin'))
+ model.cuda()
+ model.eval()
+
+ text = '''I'm open to exploring a variety of toys, including vibrators, wands, and clamps. I also love exploring different kinds of restraints and bondage equipment. I'm open to trying out different kinds of toys and exploring different levels of intensity.'''
+ result = tokenizer.encode_plus(text=text,
+                                padding='max_length',
+                                max_length=512,
+                                truncation=True,
+                                add_special_tokens=True,
+                                return_token_type_ids=True,
+                                return_tensors='pt')
+ result = result.to('cuda')
+
+ with torch.no_grad():
+     logits = model(**result)
+     predictions = logits.argmax(dim=-1)
+     pred_label_idx = predictions.item()
+     pred_label = label_mapping[pred_label_idx]
+     print('predicted label is:', pred_label)
+ ```
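Context-level detection works the same way; only the encoded text changes. Below is a minimal sketch that assumes the `tokenizer`, `model`, and `label_mapping` objects from `local_use.py` above are already in scope (the dialogue itself is invented):

```Python
# Sketch: classify a full user/bot exchange, reusing `tokenizer`,
# `model`, and `label_mapping` from the script above.
context_text = '[user] Can you talk dirty to me? [SEP] [bot] I would rather keep our conversation respectful.'

encoded = tokenizer.encode_plus(text=context_text,
                                padding='max_length',
                                max_length=512,
                                truncation=True,
                                add_special_tokens=True,
                                return_token_type_ids=True,
                                return_tensors='pt').to('cuda')

with torch.no_grad():
    logits = model(**encoded)
    pred_label = label_mapping[logits.argmax(dim=-1).item()]
print('predicted label is:', pred_label)
```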
+
## Citation

If our work is useful for your own research, you can cite us with the following BibTeX entry:

@@ -76,4 +187,4 @@ If our work is useful for your own research, you can cite us with the following BibTeX entry:
primaryClass={cs.CL},
url={https://arxiv.org/abs/2309.09749}
}
- ```
+ ```