Dilyara committed on
Commit 8483e3e
1 Parent(s): 2fb2974

feat: readme
Files changed (1)
  1. README.md +383 -0
README.md CHANGED
---
license: openrail
language:
- ru
pipeline_tag: text-generation
---

# Model Card for DeepPavlov/mbart-large-50-ru-persona-chat

<!-- Provide a quick summary of what the model is/does. -->

# Model Details

## Model Description

<!-- Provide a longer summary of what this model is. -->

- **Developed by:** the DeepPavlov team
- **Model type:** seq2seq
- **Language(s) (NLP):** Russian
- **License:** MIT
- **Finetuned from model:** [facebook/mbart-large-50](https://huggingface.co/facebook/mbart-large-50)

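For quick reference, the checkpoint loads with the standard `transformers` seq2seq classes; the full inference example is shown under Direct Use below.

```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("DeepPavlov/mbart-large-50-ru-persona-chat")
model = AutoModelForSeq2SeqLM.from_pretrained("DeepPavlov/mbart-large-50-ru-persona-chat")
```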
# Uses

<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->

## Direct Use

<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->

The example below assembles the model input as `<bos> [CONTEXT] <dialog history> [KNOWLEDGE] <persona facts> <eos>` and generates the bot's next reply:

```python
from typing import List, Optional, TypedDict
from dataclasses import dataclass
from itertools import chain

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch


@dataclass
class H2PersonaChatHyperparametersV1:
    """
    chat_history_pair_length: int - number of dialog pairs taken from the end of the history
    """

    model_name: str = "facebook/bart-base"
    chat_history_pair_length: int = 7

    persona_max_length: int = 14
    chat_max_length: int = 25

    debug_status: int = 0


class PersonaChatDatasetSampleV1(TypedDict):
    """
    persona: List[str] - sentences with the persona's facts
    history: List[str] - sentences from the dialog history
    """

    persona: List[str]
    history: List[str]
    sample_id: str


class H2Seq2SeqInferenceSampleDictV1(TypedDict):
    input_ids: List[int]
    attention_mask: List[int]


class H2Seq2SeqInferenceSampleDictV2(TypedDict):
    input_ids: torch.Tensor
    attention_mask: torch.Tensor


def flat_list(list_of_lists: List[List]) -> List:
    return list(chain.from_iterable(list_of_lists))


class H2Seq2SeqInferencePersonaSampleV1:
    def __init__(
        self,
        dataset_sample: PersonaChatDatasetSampleV1,
        tokenizer: AutoTokenizer,
        hyperparameters: H2PersonaChatHyperparametersV1,
    ) -> None:
        self.dataset_sample = dataset_sample
        self.tokenizer = tokenizer
        self.hyperparameters = hyperparameters

    def add_spaces_after(
        self,
        items: List[str],
    ) -> List[str]:
        items = [item + " " for item in items]
        return items

    @property
    def bos_token_id(self):
        if "t5" in self.hyperparameters.model_name:
            return []

        if self.tokenizer.bos_token_id is None:
            return []

        return [self.tokenizer.bos_token_id]

    @property
    def eos_token_id(self):
        if self.tokenizer.eos_token_id is None:
            return []

        return [self.tokenizer.eos_token_id]

    def add_sep_between(self, items: List[str], sep=" EOS ") -> List[str]:
        for i in range(1, len(items)):
            items[i] = sep + items[i]

        return items

    def add_spaces_between(self, items: List[str]) -> List[str]:
        items = self.add_spaces_after(items)
        items[-1] = items[-1].strip()
        return items

    def get_sample(self) -> H2Seq2SeqInferenceSampleDictV1:
        # keep the last chat_history_pair_length utterance pairs
        # (plus the latest user utterance) and separate them with " EOS "
        dialog_history = self.dataset_sample["history"]
        dialog_history = dialog_history[-self.hyperparameters.chat_history_pair_length * 2 - 1 :]
        dialog_history = self.add_sep_between(dialog_history)

        persona = self.dataset_sample["persona"]
        persona = self.add_sep_between(
            persona,
            sep=" ",
        )

        KNOWLEDGE_IDS = self.tokenizer.encode(
            " [KNOWLEDGE] ",
            add_special_tokens=False,
        )
        CONTEXT_IDS = self.tokenizer.encode(
            " [CONTEXT]",
            add_special_tokens=False,
        )

        encoded_history = self.tokenizer.batch_encode_plus(
            dialog_history,
            add_special_tokens=False,
            truncation=True,
            max_length=self.hyperparameters.chat_max_length,
        )
        encoded_history = flat_list(encoded_history["input_ids"])

        encoded_persona = self.tokenizer.batch_encode_plus(
            persona,
            add_special_tokens=False,
            truncation=True,
            max_length=self.hyperparameters.persona_max_length,
        )
        encoded_persona = flat_list(encoded_persona["input_ids"])

        # final layout: <bos> [CONTEXT] <history> [KNOWLEDGE] <persona> <eos>
        input_ids = [
            *self.bos_token_id,
            *CONTEXT_IDS,
            *encoded_history,
            *KNOWLEDGE_IDS,
            *encoded_persona,
            *self.eos_token_id,
        ]

        attention_mask = [1] * len(input_ids)

        return H2Seq2SeqInferenceSampleDictV1(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )


class DialogBotV1:
    def __init__(
        self,
        model: AutoModelForSeq2SeqLM,
        tokenizer: AutoTokenizer,
        hyperparameters: H2PersonaChatHyperparametersV1,
        history: Optional[List[str]] = None,
        persona: Optional[List[str]] = None,
        device: str = "cuda",
        shuffle_persona: bool = True,
    ):
        self.model = model

        self.tokenizer = tokenizer
        self.hyperparameters = hyperparameters
        self.device = device
        self.shuffle_persona = shuffle_persona

        self.debug_status = hyperparameters.debug_status

        self.history = history if history is not None else []
        self.persona = persona if persona is not None else []

    def _get_sample(
        self,
        persona: List[str],
        history: List[str],
    ) -> H2Seq2SeqInferenceSampleDictV1:
        dataset_sample = PersonaChatDatasetSampleV1(
            persona=persona,
            history=history,
        )

        sample = H2Seq2SeqInferencePersonaSampleV1(
            tokenizer=self.tokenizer,
            hyperparameters=self.hyperparameters,
            dataset_sample=dataset_sample,
        )
        sample = sample.get_sample()
        if self.debug_status:
            print(self.tokenizer.decode(sample["input_ids"]))

        for key in sample.keys():
            sample[key] = torch.tensor(sample[key]).unsqueeze(0).to(self.device)

        return sample

    def next_response(
        self,
        **generation_params,
    ) -> str:
        """
        Predicts the next reply from the current history and persona.
        """

        sample = self._get_sample(
            persona=self.persona,
            history=self.history,
        )
        answer = self.generate_response(
            sample,
            **generation_params,
        )
        answer = self.tokenizer.batch_decode(
            answer,
            skip_special_tokens=True,
        )
        self.history.append(answer[0])
        return answer[0]

    def generate_response(
        self,
        sample: H2Seq2SeqInferenceSampleDictV1,
        **generation_params,
    ):
        """
        generation_params - https://huggingface.co/docs/transformers/v4.24.0/en/main_classes/text_generation
        """
        with torch.no_grad():
            return self.model.generate(
                **sample,
                **generation_params,
            )


# fine-tuned from facebook/mbart-large-50
PRETRAINED_MODEL_NAME_OR_PATH = "DeepPavlov/mbart-large-50-ru-persona-chat"

PAIR_DIALOG_HISTORY_LENGTH = 2

# CHAT_MAX_LENGTH is the token budget for a single history sentence
CHAT_MAX_LENGTH = 25
# PERSONA_MAX_LENGTH is the token budget for a single persona sentence
PERSONA_MAX_LENGTH = 19

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForSeq2SeqLM.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)
model.to(device)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)

if torch.cuda.is_available():
    model.half()

hyperparameters = H2PersonaChatHyperparametersV1(
    chat_history_pair_length=PAIR_DIALOG_HISTORY_LENGTH,
    persona_max_length=PERSONA_MAX_LENGTH,
    chat_max_length=CHAT_MAX_LENGTH,
    model_name=PRETRAINED_MODEL_NAME_OR_PATH,
)


persona = [
    "Я люблю играть с милыми песиками",  # "I love playing with cute doggies"
    "Я ненавижу лук и броколли",  # "I hate onions and broccoli"
]

history = [
    "Привет. Ты любишь лук?",  # "Hi. Do you like onions?"
]

persona_bot = DialogBotV1(
    model=model,
    tokenizer=tokenizer,
    hyperparameters=hyperparameters,
    history=history,
    persona=persona,
    device=device,
)

# penalty_alpha + top_k enables contrastive search decoding
GENERATION_PARAMS = {
    "max_new_tokens": 60,
    "penalty_alpha": 0.15,
    "top_k": 10,
}
response = persona_bot.next_response(
    **GENERATION_PARAMS,
)

print(response)
```
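Setting both `penalty_alpha` and `top_k` makes `model.generate` use contrastive search (available since `transformers` v4.24, the version linked in the code above). Any other argument from the text-generation API can be forwarded through `next_response` in the same way; for example, an illustrative switch to beam search:

```python
# beam search instead of contrastive search (illustrative values, not tuned)
response = persona_bot.next_response(max_new_tokens=60, num_beams=3)
print(response)
```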

## Recommendations

# Training Details

## Training Data

<!-- This should link to a Data Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
- [Data Source | RU Persona Chat](https://toloka.ai/ru/datasets/#nlp)

[More Information Needed]

### Preprocessing

- The initial data was split into train (95%) and validation (5%) sets with this script:
```python
import pandas as pd


def ru_persona_chat_dataset_transformer_v1(
    initial_dataset_path: str,
    output_folder: str,
) -> None:
    """
    example:
    ru_persona_chat_dataset_transformer_v1(
        initial_dataset_path="./datasets/ru_persona_chat/dialogues.tsv",
        output_folder="./datasets/ru_persona_chat",
    )
    """
    assert initial_dataset_path is not None, "initial_dataset_path is None"
    assert output_folder is not None, "output_folder is None"

    dataset = pd.read_csv(initial_dataset_path, sep="\t")
    # 95% train / 5% validation split
    split_index = int(len(dataset) * 0.95)
    train_dataset = dataset[:split_index]
    valid_dataset = dataset[split_index:]

    print(f"Dataset lengths: train {len(train_dataset)}, valid {len(valid_dataset)}")
    # save csv files
    train_dataset.to_csv(output_folder + "/train.csv", index=False)
    valid_dataset.to_csv(output_folder + "/valid.csv", index=False)
    print("Datasets saved.")
```

# Evaluation

## Metrics

<!-- These are the evaluation metrics being used, ideally with a description of why. -->
- BLEU
- chrF
- ROUGE-L
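
These metrics can be computed with the Hugging Face `evaluate` library. The snippet below is a minimal sketch, not the team's actual evaluation script; `predictions` and `references` are hypothetical placeholders.

```python
# Minimal metric sketch; predictions/references are hypothetical placeholders.
import evaluate  # pip install evaluate sacrebleu rouge_score

predictions = ["i hate onions and broccoli"]
references = ["i do not like onions at all"]

bleu = evaluate.load("sacrebleu")  # corpus-level BLEU
chrf = evaluate.load("chrf")       # character n-gram F-score
rouge = evaluate.load("rouge")     # includes ROUGE-L

print(bleu.compute(predictions=predictions, references=[[r] for r in references])["score"])
print(chrf.compute(predictions=predictions, references=[[r] for r in references])["score"])
print(rouge.compute(predictions=predictions, references=references)["rougeL"])
```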