zaidmehdi committed on
Commit
cf5faeb
1 Parent(s): 370a710

randomized hyperparameter search for xgboost

Browse files
Files changed (1) hide show
  1. src/classifier.ipynb +83 -5
src/classifier.ipynb CHANGED
@@ -10,7 +10,7 @@
10
  },
11
  {
12
  "cell_type": "code",
13
- "execution_count": 36,
14
  "metadata": {},
15
  "outputs": [],
16
  "source": [
@@ -23,8 +23,10 @@
23
  "from sklearn.ensemble import RandomForestClassifier\n",
24
  "from sklearn.linear_model import LogisticRegression\n",
25
  "from sklearn.model_selection import RandomizedSearchCV\n",
 
26
  "import torch\n",
27
- "from transformers import AutoModel, AutoTokenizer"
 
28
  ]
29
  },
30
  {
@@ -533,6 +535,16 @@
533
  " pickle.dump(data_hidden, f)"
534
  ]
535
  },
 
 
 
 
 
 
 
 
 
 
536
  {
537
  "cell_type": "markdown",
538
  "metadata": {},
@@ -549,7 +561,7 @@
549
  },
550
  {
551
  "cell_type": "code",
552
- "execution_count": 28,
553
  "metadata": {},
554
  "outputs": [
555
  {
@@ -558,7 +570,7 @@
558
  "((21000, 768), (21000,))"
559
  ]
560
  },
561
- "execution_count": 28,
562
  "metadata": {},
563
  "output_type": "execute_result"
564
  }
@@ -695,6 +707,72 @@
695
  "print(\"Best Score:\", rf_search.best_score_)"
696
  ]
697
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
698
  {
699
  "cell_type": "markdown",
700
  "metadata": {},
@@ -719,7 +797,7 @@
719
  "name": "python",
720
  "nbconvert_exporter": "python",
721
  "pygments_lexer": "ipython3",
722
- "version": "3.11.7"
723
  }
724
  },
725
  "nbformat": 4,
 
10
  },
11
  {
12
  "cell_type": "code",
13
+ "execution_count": 6,
14
  "metadata": {},
15
  "outputs": [],
16
  "source": [
 
23
  "from sklearn.ensemble import RandomForestClassifier\n",
24
  "from sklearn.linear_model import LogisticRegression\n",
25
  "from sklearn.model_selection import RandomizedSearchCV\n",
26
+ "from sklearn.preprocessing import LabelEncoder\n",
27
  "import torch\n",
28
+ "from transformers import AutoModel, AutoTokenizer\n",
29
+ "import xgboost as xgb"
30
  ]
31
  },
32
  {
 
535
  " pickle.dump(data_hidden, f)"
536
  ]
537
  },
538
+ {
539
+ "cell_type": "code",
540
+ "execution_count": 2,
541
+ "metadata": {},
542
+ "outputs": [],
543
+ "source": [
544
+ "with open(\"../data/data_hidden.pkl\", \"rb\") as f:\n",
545
+ " data_hidden = pickle.load(f)"
546
+ ]
547
+ },
548
  {
549
  "cell_type": "markdown",
550
  "metadata": {},
 
561
  },
562
  {
563
  "cell_type": "code",
564
+ "execution_count": 3,
565
  "metadata": {},
566
  "outputs": [
567
  {
 
570
  "((21000, 768), (21000,))"
571
  ]
572
  },
573
+ "execution_count": 3,
574
  "metadata": {},
575
  "output_type": "execute_result"
576
  }
 
707
  "print(\"Best Score:\", rf_search.best_score_)"
708
  ]
709
  },
710
+ {
711
+ "cell_type": "markdown",
712
+ "metadata": {},
713
+ "source": [
714
+ "#### 2.3.3 XGBoost"
715
+ ]
716
+ },
717
+ {
718
+ "cell_type": "markdown",
719
+ "metadata": {},
720
+ "source": [
721
+ "XGBoost expects integer class labels, so we first need to encode the target variable."
722
+ ]
723
+ },
724
+ {
725
+ "cell_type": "code",
726
+ "execution_count": 7,
727
+ "metadata": {},
728
+ "outputs": [],
729
+ "source": [
730
+ "label_encoder = LabelEncoder()\n",
731
+ "y_train_encoded = label_encoder.fit_transform(y_train)\n",
732
+ "y_test_encoded = label_encoder.transform(y_test)"
733
+ ]
734
+ },
735
+ {
736
+ "cell_type": "code",
737
+ "execution_count": null,
738
+ "metadata": {},
739
+ "outputs": [],
740
+ "source": [
741
+ "xgb_model = xgb.XGBClassifier(device=\"cuda\", seed=2024)\n",
742
+ "parameters = {\n",
743
+ " \"n_estimators\" : [100, 150, 200, 300, 400, 450, 500],\n",
744
+ " \"max_depth\" : [3, 4, 5, 6, 7, 8],\n",
745
+ " \"learning_rate\": [0.1, 0.05, 0.01, 0.005, 0.001]\n",
746
+ "}\n",
747
+ "xgb_search = RandomizedSearchCV(estimator=xgb_model, param_distributions=parameters,\n",
748
+ " scoring=\"f1_macro\", cv=5, n_iter=20)\n",
749
+ "xgb_search.fit(X_train, y_train_encoded)"
750
+ ]
751
+ },
752
+ {
753
+ "cell_type": "code",
754
+ "execution_count": null,
755
+ "metadata": {},
756
+ "outputs": [],
757
+ "source": [
758
+ "print(\"Best Parameters:\", xgb_search.best_params_)\n",
759
+ "print(\"Best Score (Macro Average F1):\", xgb_search.best_score_)"
760
+ ]
761
+ },
762
+ {
763
+ "cell_type": "markdown",
764
+ "metadata": {},
765
+ "source": [
766
+ "#### 2.3.4 LightGBM"
767
+ ]
768
+ },
769
+ {
770
+ "cell_type": "code",
771
+ "execution_count": null,
772
+ "metadata": {},
773
+ "outputs": [],
774
+ "source": []
775
+ },
776
  {
777
  "cell_type": "markdown",
778
  "metadata": {},
 
797
  "name": "python",
798
  "nbconvert_exporter": "python",
799
  "pygments_lexer": "ipython3",
800
+ "version": "3.1.0"
801
  }
802
  },
803
  "nbformat": 4,