zaidmehdi committed on
Commit 76f50ce
Parent: 07f51a8

tokenizing data and extracting features from last hidden layer

Files changed (1)
1. src/classifier.ipynb +209 -14
src/classifier.ipynb CHANGED
@@ -10,12 +10,26 @@
 },
 {
 "cell_type": "code",
- "execution_count": 14,
+ "execution_count": 1,
 "metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/mehdi/miniconda3/envs/adc/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n"
+ ]
+ }
+ ],
 "source": [
+ "import pickle\n",
+ "\n",
+ "from datasets import DatasetDict, Dataset\n",
 "import matplotlib.pyplot as plt\n",
- "import pandas as pd"
+ "import pandas as pd\n",
+ "import torch\n",
+ "from transformers import AutoModel, AutoTokenizer"
 ]
 },
 {
@@ -27,7 +41,7 @@
 },
 {
 "cell_type": "code",
- "execution_count": 5,
+ "execution_count": 2,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -37,7 +51,7 @@
 },
 {
 "cell_type": "code",
- "execution_count": 6,
+ "execution_count": 3,
 "metadata": {},
 "outputs": [
 {
@@ -123,7 +137,7 @@
 "4 Algeria dz_El-Oued "
 ]
 },
- "execution_count": 6,
+ "execution_count": 3,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -134,7 +148,7 @@
 },
 {
 "cell_type": "code",
- "execution_count": 7,
+ "execution_count": 4,
 "metadata": {},
 "outputs": [
 {
@@ -220,7 +234,7 @@
 "4 Libya ly_Misrata "
 ]
 },
- "execution_count": 7,
+ "execution_count": 4,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -231,7 +245,7 @@
 },
 {
 "cell_type": "code",
- "execution_count": 12,
+ "execution_count": 5,
 "metadata": {},
 "outputs": [
 {
@@ -249,7 +263,7 @@
 " dtype: int64)"
 ]
 },
- "execution_count": 12,
+ "execution_count": 5,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -267,7 +281,7 @@
 },
 {
 "cell_type": "code",
- "execution_count": 32,
+ "execution_count": 6,
 "metadata": {},
 "outputs": [
 {
@@ -276,7 +290,7 @@
 "Text(0.5, 1.0, 'Value counts of country label in train data')"
 ]
 },
- "execution_count": 32,
+ "execution_count": 6,
 "metadata": {},
 "output_type": "execute_result"
 },
@@ -299,7 +313,7 @@
 },
 {
 "cell_type": "code",
- "execution_count": 33,
+ "execution_count": 7,
 "metadata": {},
 "outputs": [
 {
@@ -308,7 +322,7 @@
 "Text(0.5, 1.0, 'Value counts of country label in test data')"
 ]
 },
- "execution_count": 33,
+ "execution_count": 7,
 "metadata": {},
 "output_type": "execute_result"
 },
@@ -343,6 +357,187 @@
 "## 2. Training the Classifier"
 ]
 },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For this classifier, we will convert the tweets into vector embeddings using the AraBART model. We will use the last hidden layer of the model to extract the features."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2.1 Data Preparation\n",
+ "The first step is to prepare our data by tokenizing it so it can be used with the AraBART model."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "First, we load the model and its tokenizer."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.device"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "device = torch.device(\"cpu\")\n",
+ "type(device)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model = AutoModel.from_pretrained(\"moussaKam/AraBART\").to(device)\n",
+ "tokenizer = AutoTokenizer.from_pretrained(\"moussaKam/AraBART\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Next, we convert the datasets into a DatasetDict object."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mapper = {\"#2_tweet\": \"tweet\", \"#3_country_label\": \"label\"}\n",
+ "columns_to_keep = [\"tweet\", \"label\"]\n",
+ "\n",
+ "df_train = df_train.rename(columns=mapper)[columns_to_keep]\n",
+ "df_test = df_test.rename(columns=mapper)[columns_to_keep]\n",
+ "\n",
+ "train_dataset = Dataset.from_pandas(df_train)\n",
+ "test_dataset = Dataset.from_pandas(df_test)\n",
+ "data = DatasetDict({'train': train_dataset, 'test': test_dataset})"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Then, we tokenize the dataset."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
+ }
+ ],
+ "source": [
+ "def tokenize(batch):\n",
+ " return tokenizer(batch[\"tweet\"], padding=True)\n",
+ "\n",
+ "data_encoded = data.map(tokenize, batched=True, batch_size=None)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['tweet', 'label', 'input_ids', 'attention_mask']"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "data_encoded[\"train\"].column_names"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2.2 Feature Extraction"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def extract_hidden_states(batch):\n",
+ " inputs = {k: v.to(device) for k, v in batch.items()\n",
+ " if k in tokenizer.model_input_names}\n",
+ " with torch.no_grad():\n",
+ " last_hidden_state = model(**inputs).last_hidden_state\n",
+ "\n",
+ " return {\"hidden_state\": last_hidden_state[:,0].cpu().numpy()}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
+ }
+ ],
+ "source": [
+ "data_encoded.set_format(\"torch\", columns=[\"input_ids\", \"attention_mask\", \"label\"])\n",
+ "data_hidden = data_encoded.map(extract_hidden_states, batched=True, batch_size=50)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open(\"../data/data_hidden.pkl\", \"wb\") as f:\n",
+ " pickle.dump(data_hidden, f)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2.3 Model Training"
+ ]
+ },
 {
 "cell_type": "markdown",
 "metadata": {},