Tatiana commited on
Commit
dd3dbad
1 Parent(s): aa3e28c

files added

Browse files
Files changed (1) hide show
  1. task2.py +34 -6
task2.py CHANGED
@@ -1,17 +1,45 @@
1
  from transformers import BertTokenizer, BertForSequenceClassification
2
  import torch
3
  from sklearn.preprocessing import LabelEncoder
 
 
 
4
 
5
- #Загрузка сохраненной модели и токенизатора в Streamlit
6
- loaded_model_path = "nlp_project/model"
7
- loaded_tokenizer_path = "nlp_project/tokenizer"
 
 
 
 
8
 
 
9
  loaded_model = BertForSequenceClassification.from_pretrained(loaded_model_path)
10
  loaded_tokenizer = BertTokenizer.from_pretrained(loaded_tokenizer_path)
11
 
12
- labels = ['мода', 'спорт', 'технологии', 'финансы', 'крипта']
13
- label_encoder = LabelEncoder()
14
- label_encoder.fit(labels)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  def predict_class(user_input, model=loaded_model, tokenizer=loaded_tokenizer, label_encoder=label_encoder, max_length=128):
17
  if not user_input:
 
1
  from transformers import BertTokenizer, BertForSequenceClassification
2
  import torch
3
  from sklearn.preprocessing import LabelEncoder
4
+ from transformers import BertTokenizer, BertForSequenceClassification
5
+ import torch
6
+ from sklearn.preprocessing import LabelEncoder
7
 
8
+ labels = ['мода', 'спорт', 'технологии', 'финансы', 'крипта']
9
+ label_encoder = LabelEncoder()
10
+ label_encoder.fit(labels)
11
+
12
+ # Загрузка сохраненной модели и токенизатора в Streamlit
13
+ loaded_model_path = "rubert-base-cased"
14
+ loaded_tokenizer_path = BertForSequenceClassification.from_pretrained(loaded_model_path)
15
 
16
+ # Инициализация модели и токенизатора
17
  loaded_model = BertForSequenceClassification.from_pretrained(loaded_model_path)
18
  loaded_tokenizer = BertTokenizer.from_pretrained(loaded_tokenizer_path)
19
 
20
+ # Создание модели с архитектурой BertForSequenceClassification
21
+ # Передайте в аргумент `num_labels` количество классов, для которых модель будет выполнять классификацию
22
+ model = BertForSequenceClassification(num_labels=len(labels))
23
+
24
+ # Загрузка весов из сохраненного файла
25
+ weights_path = "model_weights_epoch_8.pt"
26
+ state_dict = torch.load(weights_path, map_location='cpu') # Укажите 'cuda' вместо 'cpu', если используете GPU
27
+ model.load_state_dict(state_dict)
28
+
29
+ # Пример использования загруженной модели
30
+ user_input = "Ваш текст для классификации"
31
+ predicted_class = predict_class(user_input, model=model, tokenizer=loaded_tokenizer, label_encoder=label_encoder)
32
+ print(predicted_class)
33
+
34
+
35
+ # #Загрузка сохраненной модели и токенизатора в Streamlit
36
+ # loaded_model_path = "nlp_project/model"
37
+ # loaded_tokenizer_path = "nlp_project/tokenizer"
38
+
39
+ # loaded_model = BertForSequenceClassification.from_pretrained(loaded_model_path)
40
+ # loaded_tokenizer = BertTokenizer.from_pretrained(loaded_tokenizer_path)
41
+
42
+
43
 
44
  def predict_class(user_input, model=loaded_model, tokenizer=loaded_tokenizer, label_encoder=label_encoder, max_length=128):
45
  if not user_input: