Spaces:
Sleeping
Sleeping
tokenizing data and extracting features from last hidden layer
Browse files- src/classifier.ipynb +209 -14
src/classifier.ipynb
CHANGED
@@ -10,12 +10,26 @@
|
|
10 |
},
|
11 |
{
|
12 |
"cell_type": "code",
|
13 |
-
"execution_count":
|
14 |
"metadata": {},
|
15 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
"source": [
|
|
|
|
|
|
|
17 |
"import matplotlib.pyplot as plt\n",
|
18 |
-
"import pandas as pd"
|
|
|
|
|
19 |
]
|
20 |
},
|
21 |
{
|
@@ -27,7 +41,7 @@
|
|
27 |
},
|
28 |
{
|
29 |
"cell_type": "code",
|
30 |
-
"execution_count":
|
31 |
"metadata": {},
|
32 |
"outputs": [],
|
33 |
"source": [
|
@@ -37,7 +51,7 @@
|
|
37 |
},
|
38 |
{
|
39 |
"cell_type": "code",
|
40 |
-
"execution_count":
|
41 |
"metadata": {},
|
42 |
"outputs": [
|
43 |
{
|
@@ -123,7 +137,7 @@
|
|
123 |
"4 Algeria dz_El-Oued "
|
124 |
]
|
125 |
},
|
126 |
-
"execution_count":
|
127 |
"metadata": {},
|
128 |
"output_type": "execute_result"
|
129 |
}
|
@@ -134,7 +148,7 @@
|
|
134 |
},
|
135 |
{
|
136 |
"cell_type": "code",
|
137 |
-
"execution_count":
|
138 |
"metadata": {},
|
139 |
"outputs": [
|
140 |
{
|
@@ -220,7 +234,7 @@
|
|
220 |
"4 Libya ly_Misrata "
|
221 |
]
|
222 |
},
|
223 |
-
"execution_count":
|
224 |
"metadata": {},
|
225 |
"output_type": "execute_result"
|
226 |
}
|
@@ -231,7 +245,7 @@
|
|
231 |
},
|
232 |
{
|
233 |
"cell_type": "code",
|
234 |
-
"execution_count":
|
235 |
"metadata": {},
|
236 |
"outputs": [
|
237 |
{
|
@@ -249,7 +263,7 @@
|
|
249 |
" dtype: int64)"
|
250 |
]
|
251 |
},
|
252 |
-
"execution_count":
|
253 |
"metadata": {},
|
254 |
"output_type": "execute_result"
|
255 |
}
|
@@ -267,7 +281,7 @@
|
|
267 |
},
|
268 |
{
|
269 |
"cell_type": "code",
|
270 |
-
"execution_count":
|
271 |
"metadata": {},
|
272 |
"outputs": [
|
273 |
{
|
@@ -276,7 +290,7 @@
|
|
276 |
"Text(0.5, 1.0, 'Value counts of country label in train data')"
|
277 |
]
|
278 |
},
|
279 |
-
"execution_count":
|
280 |
"metadata": {},
|
281 |
"output_type": "execute_result"
|
282 |
},
|
@@ -299,7 +313,7 @@
|
|
299 |
},
|
300 |
{
|
301 |
"cell_type": "code",
|
302 |
-
"execution_count":
|
303 |
"metadata": {},
|
304 |
"outputs": [
|
305 |
{
|
@@ -308,7 +322,7 @@
|
|
308 |
"Text(0.5, 1.0, 'Value counts of country label in test data')"
|
309 |
]
|
310 |
},
|
311 |
-
"execution_count":
|
312 |
"metadata": {},
|
313 |
"output_type": "execute_result"
|
314 |
},
|
@@ -343,6 +357,187 @@
|
|
343 |
"## 2. Training the Classifier"
|
344 |
]
|
345 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
346 |
{
|
347 |
"cell_type": "markdown",
|
348 |
"metadata": {},
|
|
|
10 |
},
|
11 |
{
|
12 |
"cell_type": "code",
|
13 |
+
"execution_count": 1,
|
14 |
"metadata": {},
|
15 |
+
"outputs": [
|
16 |
+
{
|
17 |
+
"name": "stderr",
|
18 |
+
"output_type": "stream",
|
19 |
+
"text": [
|
20 |
+
"/home/mehdi/miniconda3/envs/adc/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
21 |
+
" from .autonotebook import tqdm as notebook_tqdm\n"
|
22 |
+
]
|
23 |
+
}
|
24 |
+
],
|
25 |
"source": [
|
26 |
+
"import pickle\n",
|
27 |
+
"\n",
|
28 |
+
"from datasets import DatasetDict, Dataset\n",
|
29 |
"import matplotlib.pyplot as plt\n",
|
30 |
+
"import pandas as pd\n",
|
31 |
+
"import torch\n",
|
32 |
+
"from transformers import AutoModel, AutoTokenizer"
|
33 |
]
|
34 |
},
|
35 |
{
|
|
|
41 |
},
|
42 |
{
|
43 |
"cell_type": "code",
|
44 |
+
"execution_count": 2,
|
45 |
"metadata": {},
|
46 |
"outputs": [],
|
47 |
"source": [
|
|
|
51 |
},
|
52 |
{
|
53 |
"cell_type": "code",
|
54 |
+
"execution_count": 3,
|
55 |
"metadata": {},
|
56 |
"outputs": [
|
57 |
{
|
|
|
137 |
"4 Algeria dz_El-Oued "
|
138 |
]
|
139 |
},
|
140 |
+
"execution_count": 3,
|
141 |
"metadata": {},
|
142 |
"output_type": "execute_result"
|
143 |
}
|
|
|
148 |
},
|
149 |
{
|
150 |
"cell_type": "code",
|
151 |
+
"execution_count": 4,
|
152 |
"metadata": {},
|
153 |
"outputs": [
|
154 |
{
|
|
|
234 |
"4 Libya ly_Misrata "
|
235 |
]
|
236 |
},
|
237 |
+
"execution_count": 4,
|
238 |
"metadata": {},
|
239 |
"output_type": "execute_result"
|
240 |
}
|
|
|
245 |
},
|
246 |
{
|
247 |
"cell_type": "code",
|
248 |
+
"execution_count": 5,
|
249 |
"metadata": {},
|
250 |
"outputs": [
|
251 |
{
|
|
|
263 |
" dtype: int64)"
|
264 |
]
|
265 |
},
|
266 |
+
"execution_count": 5,
|
267 |
"metadata": {},
|
268 |
"output_type": "execute_result"
|
269 |
}
|
|
|
281 |
},
|
282 |
{
|
283 |
"cell_type": "code",
|
284 |
+
"execution_count": 6,
|
285 |
"metadata": {},
|
286 |
"outputs": [
|
287 |
{
|
|
|
290 |
"Text(0.5, 1.0, 'Value counts of country label in train data')"
|
291 |
]
|
292 |
},
|
293 |
+
"execution_count": 6,
|
294 |
"metadata": {},
|
295 |
"output_type": "execute_result"
|
296 |
},
|
|
|
313 |
},
|
314 |
{
|
315 |
"cell_type": "code",
|
316 |
+
"execution_count": 7,
|
317 |
"metadata": {},
|
318 |
"outputs": [
|
319 |
{
|
|
|
322 |
"Text(0.5, 1.0, 'Value counts of country label in test data')"
|
323 |
]
|
324 |
},
|
325 |
+
"execution_count": 7,
|
326 |
"metadata": {},
|
327 |
"output_type": "execute_result"
|
328 |
},
|
|
|
357 |
"## 2. Training the Classifier"
|
358 |
]
|
359 |
},
|
360 |
+
{
|
361 |
+
"cell_type": "markdown",
|
362 |
+
"metadata": {},
|
363 |
+
"source": [
|
364 |
+
"For this classifier, we will convert the tweets into vector embeddings using the AraBART model. We will use the last hidden layer of the model to extract the features."
|
365 |
+
]
|
366 |
+
},
|
367 |
+
{
|
368 |
+
"cell_type": "markdown",
|
369 |
+
"metadata": {},
|
370 |
+
"source": [
|
371 |
+
"### 2.1 Data Preparation\n",
|
372 |
+
"The first step is to prepare our data by tokenizing it for use with the AraBART model."
|
373 |
+
]
|
374 |
+
},
|
375 |
+
{
|
376 |
+
"cell_type": "markdown",
|
377 |
+
"metadata": {},
|
378 |
+
"source": [
|
379 |
+
"First, we load the model and its tokenizer."
|
380 |
+
]
|
381 |
+
},
|
382 |
+
{
|
383 |
+
"cell_type": "code",
|
384 |
+
"execution_count": 8,
|
385 |
+
"metadata": {},
|
386 |
+
"outputs": [
|
387 |
+
{
|
388 |
+
"data": {
|
389 |
+
"text/plain": [
|
390 |
+
"torch.device"
|
391 |
+
]
|
392 |
+
},
|
393 |
+
"execution_count": 8,
|
394 |
+
"metadata": {},
|
395 |
+
"output_type": "execute_result"
|
396 |
+
}
|
397 |
+
],
|
398 |
+
"source": [
|
399 |
+
"device = torch.device(\"cpu\")\n",
|
400 |
+
"type(device)"
|
401 |
+
]
|
402 |
+
},
|
403 |
+
{
|
404 |
+
"cell_type": "code",
|
405 |
+
"execution_count": 9,
|
406 |
+
"metadata": {},
|
407 |
+
"outputs": [],
|
408 |
+
"source": [
|
409 |
+
"model = AutoModel.from_pretrained(\"moussaKam/AraBART\").to(device)\n",
|
410 |
+
"tokenizer = AutoTokenizer.from_pretrained(\"moussaKam/AraBART\")"
|
411 |
+
]
|
412 |
+
},
|
413 |
+
{
|
414 |
+
"cell_type": "markdown",
|
415 |
+
"metadata": {},
|
416 |
+
"source": [
|
417 |
+
"Next, we convert the datasets into a DatasetDict object."
|
418 |
+
]
|
419 |
+
},
|
420 |
+
{
|
421 |
+
"cell_type": "code",
|
422 |
+
"execution_count": 10,
|
423 |
+
"metadata": {},
|
424 |
+
"outputs": [],
|
425 |
+
"source": [
|
426 |
+
"mapper = {\"#2_tweet\": \"tweet\", \"#3_country_label\": \"label\"}\n",
|
427 |
+
"columns_to_keep = [\"tweet\", \"label\"]\n",
|
428 |
+
"\n",
|
429 |
+
"df_train = df_train.rename(columns=mapper)[columns_to_keep]\n",
|
430 |
+
"df_test = df_test.rename(columns=mapper)[columns_to_keep]\n",
|
431 |
+
"\n",
|
432 |
+
"train_dataset = Dataset.from_pandas(df_train)\n",
|
433 |
+
"test_dataset = Dataset.from_pandas(df_test)\n",
|
434 |
+
"data = DatasetDict({'train': train_dataset, 'test': test_dataset})"
|
435 |
+
]
|
436 |
+
},
|
437 |
+
{
|
438 |
+
"cell_type": "markdown",
|
439 |
+
"metadata": {},
|
440 |
+
"source": [
|
441 |
+
"Then, we tokenize the dataset."
|
442 |
+
]
|
443 |
+
},
|
444 |
+
{
|
445 |
+
"cell_type": "code",
|
446 |
+
"execution_count": 11,
|
447 |
+
"metadata": {},
|
448 |
+
"outputs": [
|
449 |
+
{
|
450 |
+
"name": "stderr",
|
451 |
+
"output_type": "stream",
|
452 |
+
"text": [
|
453 |
+
" \r"
|
454 |
+
]
|
455 |
+
}
|
456 |
+
],
|
457 |
+
"source": [
|
458 |
+
"def tokenize(batch):\n",
|
459 |
+
" return tokenizer(batch[\"tweet\"], padding=True)\n",
|
460 |
+
"\n",
|
461 |
+
"data_encoded = data.map(tokenize, batched=True, batch_size=None)"
|
462 |
+
]
|
463 |
+
},
|
464 |
+
{
|
465 |
+
"cell_type": "code",
|
466 |
+
"execution_count": 12,
|
467 |
+
"metadata": {},
|
468 |
+
"outputs": [
|
469 |
+
{
|
470 |
+
"data": {
|
471 |
+
"text/plain": [
|
472 |
+
"['tweet', 'label', 'input_ids', 'attention_mask']"
|
473 |
+
]
|
474 |
+
},
|
475 |
+
"execution_count": 12,
|
476 |
+
"metadata": {},
|
477 |
+
"output_type": "execute_result"
|
478 |
+
}
|
479 |
+
],
|
480 |
+
"source": [
|
481 |
+
"data_encoded[\"train\"].column_names"
|
482 |
+
]
|
483 |
+
},
|
484 |
+
{
|
485 |
+
"cell_type": "markdown",
|
486 |
+
"metadata": {},
|
487 |
+
"source": [
|
488 |
+
"### 2.2 Feature Extraction"
|
489 |
+
]
|
490 |
+
},
|
491 |
+
{
|
492 |
+
"cell_type": "code",
|
493 |
+
"execution_count": 13,
|
494 |
+
"metadata": {},
|
495 |
+
"outputs": [],
|
496 |
+
"source": [
|
497 |
+
"def extract_hidden_states(batch):\n",
|
498 |
+
" inputs = {k:v.to(device) for k,v in batch.items()\n",
|
499 |
+
" if k in tokenizer.model_input_names}\n",
|
500 |
+
" with torch.no_grad():\n",
|
501 |
+
" last_hidden_state = model(**inputs).last_hidden_state\n",
|
502 |
+
"\n",
|
503 |
+
" return{\"hidden_state\": last_hidden_state[:,0].cpu().numpy()}"
|
504 |
+
]
|
505 |
+
},
|
506 |
+
{
|
507 |
+
"cell_type": "code",
|
508 |
+
"execution_count": 14,
|
509 |
+
"metadata": {},
|
510 |
+
"outputs": [
|
511 |
+
{
|
512 |
+
"name": "stderr",
|
513 |
+
"output_type": "stream",
|
514 |
+
"text": [
|
515 |
+
" \r"
|
516 |
+
]
|
517 |
+
}
|
518 |
+
],
|
519 |
+
"source": [
|
520 |
+
"data_encoded.set_format(\"torch\", columns=[\"input_ids\", \"attention_mask\", \"label\"])\n",
|
521 |
+
"data_hidden = data_encoded.map(extract_hidden_states, batched=True, batch_size=50)"
|
522 |
+
]
|
523 |
+
},
|
524 |
+
{
|
525 |
+
"cell_type": "code",
|
526 |
+
"execution_count": 16,
|
527 |
+
"metadata": {},
|
528 |
+
"outputs": [],
|
529 |
+
"source": [
|
530 |
+
"with open(\"../data/data_hidden.pkl\", \"wb\") as f:\n",
|
531 |
+
" pickle.dump(data_hidden, f)"
|
532 |
+
]
|
533 |
+
},
|
534 |
+
{
|
535 |
+
"cell_type": "markdown",
|
536 |
+
"metadata": {},
|
537 |
+
"source": [
|
538 |
+
"### 2.3 Model Training"
|
539 |
+
]
|
540 |
+
},
|
541 |
{
|
542 |
"cell_type": "markdown",
|
543 |
"metadata": {},
|