Spaces:
No application file
No application file
Tatiana
committed on
Commit
•
dd3dbad
1
Parent(s):
aa3e28c
files added
Browse files
task2.py
CHANGED
@@ -1,17 +1,45 @@
|
|
1 |
from transformers import BertTokenizer, BertForSequenceClassification
|
2 |
import torch
|
3 |
from sklearn.preprocessing import LabelEncoder
|
|
|
|
|
|
|
4 |
|
5 |
-
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
|
|
8 |
|
|
|
9 |
loaded_model = BertForSequenceClassification.from_pretrained(loaded_model_path)
|
10 |
loaded_tokenizer = BertTokenizer.from_pretrained(loaded_tokenizer_path)
|
11 |
|
12 |
-
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
def predict_class(user_input, model=loaded_model, tokenizer=loaded_tokenizer, label_encoder=label_encoder, max_length=128):
|
17 |
if not user_input:
|
|
|
1 |
from transformers import BertTokenizer, BertForSequenceClassification
|
2 |
import torch
|
3 |
from sklearn.preprocessing import LabelEncoder
|
4 |
+
from transformers import BertTokenizer, BertForSequenceClassification
|
5 |
+
import torch
|
6 |
+
from sklearn.preprocessing import LabelEncoder
|
7 |
|
8 |
+
labels = ['мода', 'спорт', 'технологии', 'финансы', 'крипта']
|
9 |
+
label_encoder = LabelEncoder()
|
10 |
+
label_encoder.fit(labels)
|
11 |
+
|
12 |
+
# Load the saved model and tokenizer in Streamlit
|
13 |
+
loaded_model_path = "rubert-base-cased"
|
14 |
+
# BertTokenizer.from_pretrained expects a name or path string, not a model
# instance, so point the tokenizer at the same checkpoint as the model.
loaded_tokenizer_path = loaded_model_path
|
15 |
|
16 |
+
# Initialize the model and tokenizer
|
17 |
loaded_model = BertForSequenceClassification.from_pretrained(loaded_model_path)
|
18 |
loaded_tokenizer = BertTokenizer.from_pretrained(loaded_tokenizer_path)
|
19 |
|
20 |
+
# Create a model with the BertForSequenceClassification architecture
|
21 |
+
# Pass the number of classes the model will classify into via the `num_labels` argument
|
22 |
+
# The bare constructor requires a config object and has no `num_labels`
# keyword, so build the architecture via from_pretrained, which sizes the
# classification head from `num_labels`.
model = BertForSequenceClassification.from_pretrained(loaded_model_path, num_labels=len(labels))
|
23 |
+
|
24 |
+
# Load the weights from the saved file
|
25 |
+
weights_path = "model_weights_epoch_8.pt"
|
26 |
+
state_dict = torch.load(weights_path, map_location='cpu') # Use 'cuda' instead of 'cpu' if you are running on a GPU
|
27 |
+
model.load_state_dict(state_dict)
|
28 |
+
|
29 |
+
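# Example usage of the loaded model.
# NOTE(review): in the visible part of this file predict_class is defined further down, so calling it here raises NameError when run as a script; move this example below the definition.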
|
30 |
+
user_input = "Ваш текст для классификации"
|
31 |
+
predicted_class = predict_class(user_input, model=model, tokenizer=loaded_tokenizer, label_encoder=label_encoder)
|
32 |
+
print(predicted_class)
|
33 |
+
|
34 |
+
|
35 |
+
# # Load the saved model and tokenizer in Streamlit
|
36 |
+
# loaded_model_path = "nlp_project/model"
|
37 |
+
# loaded_tokenizer_path = "nlp_project/tokenizer"
|
38 |
+
|
39 |
+
# loaded_model = BertForSequenceClassification.from_pretrained(loaded_model_path)
|
40 |
+
# loaded_tokenizer = BertTokenizer.from_pretrained(loaded_tokenizer_path)
|
41 |
+
|
42 |
+
|
43 |
|
44 |
def predict_class(user_input, model=loaded_model, tokenizer=loaded_tokenizer, label_encoder=label_encoder, max_length=128):
|
45 |
if not user_input:
|