data-silence
commited on
Commit
•
d79d747
1
Parent(s):
9ba7b64
Update README.md
Browse files
README.md
CHANGED
@@ -61,11 +61,36 @@ classification of news categories politics, society and conflicts.
|
|
61 |
Example of how to use the model:
|
62 |
|
63 |
```python
|
|
|
|
|
64 |
import torch
|
65 |
from transformers import AutoTokenizer
|
66 |
from huggingface_hub import hf_hub_download
|
67 |
|
68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
categories = ['climate', 'conflicts', 'culture', 'economy', 'gloss', 'health',
|
70 |
'politics', 'science', 'society', 'sports', 'travel']
|
71 |
|
|
|
61 |
Example of how to use the model:
|
62 |
|
63 |
```python
|
64 |
+
import torch.nn as nn
|
65 |
+
from transformers import BertModel
|
66 |
import torch
|
67 |
from transformers import AutoTokenizer
|
68 |
from huggingface_hub import hf_hub_download
|
69 |
|
70 |
|
71 |
+
class BiLSTMClassifier(nn.Module):
|
72 |
+
def __init__(self, hidden_dim, output_dim, n_layers, dropout):
|
73 |
+
super(BiLSTMClassifier, self).__init__()
|
74 |
+
self.bert = BertModel.from_pretrained("bert-base-multilingual-cased")
|
75 |
+
self.lstm = nn.LSTM(self.bert.config.hidden_size, hidden_dim, num_layers=n_layers,
|
76 |
+
bidirectional=True, dropout=dropout, batch_first=True)
|
77 |
+
self.fc = nn.Linear(hidden_dim * 2, output_dim)
|
78 |
+
self.dropout = nn.Dropout(dropout)
|
79 |
+
|
80 |
+
def forward(self, input_ids, attention_mask, labels=None):
|
81 |
+
with torch.no_grad():
|
82 |
+
embedded = self.bert(input_ids=input_ids, attention_mask=attention_mask)[0]
|
83 |
+
lstm_out, _ = self.lstm(embedded)
|
84 |
+
pooled = torch.mean(lstm_out, dim=1)
|
85 |
+
logits = self.fc(self.dropout(pooled))
|
86 |
+
|
87 |
+
if labels is not None:
|
88 |
+
loss_fn = nn.CrossEntropyLoss()
|
89 |
+
loss = loss_fn(logits, labels)
|
90 |
+
return {"loss": loss, "logits": logits} # Возвращаем словарь
|
91 |
+
return logits # Возвращаем логиты, если метки не переданы
|
92 |
+
|
93 |
+
|
94 |
categories = ['climate', 'conflicts', 'culture', 'economy', 'gloss', 'health',
|
95 |
'politics', 'science', 'society', 'sports', 'travel']
|
96 |
|