FredZhang7 committed · Commit a6e2920 · Parent: 643560b
Update README.md

README.md CHANGED

---

Find the v1 (TensorFlow) model on [this page](https://github.com/FredZhang7/tfjs-node-tiny/releases/tag/text-classification).
The license for the v1 model is Apache 2.0.

<br>

Models tested for v2: roberta, xlm-roberta, bert-small, bert-base-cased/uncased, bert-multilingual-cased/uncased, and albert-large-v2.
Of these, I chose bert-multilingual-cased because, given the same amount of resources, it performs better than the others on this particular task.
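
For reference, a minimal sketch of how one might line up such candidates for comparison. The Hub IDs below are my assumptions for the checkpoints named above (e.g. `prajjwal1/bert-small` for bert-small), not necessarily the exact ones used in the v2 experiments:

```python
from transformers import AutoConfig

# Assumed Hugging Face Hub IDs for the candidates listed above.
candidates = [
    "roberta-base",
    "xlm-roberta-base",
    "prajjwal1/bert-small",
    "bert-base-cased",
    "bert-base-uncased",
    "bert-base-multilingual-cased",
    "bert-base-multilingual-uncased",
    "albert-large-v2",
]

# Compare rough model sizes (layers, hidden size) without downloading weights.
for name in candidates:
    config = AutoConfig.from_pretrained(name)
    print(f"{name}: {config.num_hidden_layers} layers, hidden size {config.hidden_size}")
```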

<br>

## PyTorch
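
A minimal single-text example: it loads the v3 checkpoint from the Hugging Face Hub and prints the index of the predicted class.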

```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

text = "hello world!"

# Load the v3 tokenizer and model, moving the model to the GPU when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("FredZhang7/one-for-all-toxicity-v3")
model = AutoModelForSequenceClassification.from_pretrained("FredZhang7/one-for-all-toxicity-v3").to(device)

# Tokenize and pad/truncate to the fixed length of 208 tokens.
encoding = tokenizer.encode_plus(
    text,
    add_special_tokens=True,
    max_length=208,
    padding="max_length",
    truncation=True,
    return_tensors="pt"
)
print("device:", device)
input_ids = encoding["input_ids"].to(device)
attention_mask = encoding["attention_mask"].to(device)

# Run inference without gradient tracking and keep the highest-scoring class.
with torch.no_grad():
    outputs = model(input_ids, attention_mask=attention_mask)
    logits = outputs.logits
    predicted_labels = torch.argmax(logits, dim=1)

print(predicted_labels)
```
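
The same model handles batches directly. A minimal sketch, reusing `tokenizer`, `model`, and `device` from the snippet above; the example inputs are made up:

```python
# Classify several texts in one forward pass (example inputs are made up).
texts = ["hello world!", "thank you, this is great"]

batch = tokenizer(
    texts,
    add_special_tokens=True,
    max_length=208,
    padding="max_length",
    truncation=True,
    return_tensors="pt"
).to(device)

with torch.no_grad():
    logits = model(**batch).logits

# One predicted class index per input text.
print(torch.argmax(logits, dim=1).tolist())
```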
## Attribution
- If you distribute, remix, adapt, or build upon One-for-all Toxicity v3, please credit "AIstrova Technologies Inc." in your README.md, application description, research, or website.