zaidmehdi commited on
Commit
46afa74
1 Parent(s): 10f12ce

retraining models for evaluation

Browse files
Files changed (1) hide show
  1. src/classifier.ipynb +137 -3
src/classifier.ipynb CHANGED
@@ -10,9 +10,18 @@
10
  },
11
  {
12
  "cell_type": "code",
13
- "execution_count": 6,
14
  "metadata": {},
15
- "outputs": [],
 
 
 
 
 
 
 
 
 
16
  "source": [
17
  "import pickle\n",
18
  "\n",
@@ -723,7 +732,7 @@
723
  },
724
  {
725
  "cell_type": "code",
726
- "execution_count": 7,
727
  "metadata": {},
728
  "outputs": [],
729
  "source": [
@@ -776,6 +785,131 @@
776
  "source": [
777
  "## 3. Evaluating the Performance"
778
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
779
  }
780
  ],
781
  "metadata": {
 
10
  },
11
  {
12
  "cell_type": "code",
13
+ "execution_count": 1,
14
  "metadata": {},
15
+ "outputs": [
16
+ {
17
+ "name": "stderr",
18
+ "output_type": "stream",
19
+ "text": [
20
+ "/home/mehdi/miniconda3/envs/adc/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
21
+ " from .autonotebook import tqdm as notebook_tqdm\n"
22
+ ]
23
+ }
24
+ ],
25
  "source": [
26
  "import pickle\n",
27
  "\n",
 
732
  },
733
  {
734
  "cell_type": "code",
735
+ "execution_count": 4,
736
  "metadata": {},
737
  "outputs": [],
738
  "source": [
 
785
  "source": [
786
  "## 3. Evaluating the Performance"
787
  ]
788
+ },
789
+ {
790
+ "cell_type": "markdown",
791
+ "metadata": {},
792
+ "source": [
793
+ "First, let's retrain the models with the best parameters we obtained."
794
+ ]
795
+ },
796
+ {
797
+ "cell_type": "code",
798
+ "execution_count": 5,
799
+ "metadata": {},
800
+ "outputs": [
801
+ {
802
+ "data": {
803
+ "text/html": [
804
+ "<style>#sk-container-id-1 {color: black;}#sk-container-id-1 pre{padding: 0;}#sk-container-id-1 div.sk-toggleable {background-color: white;}#sk-container-id-1 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-1 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-1 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-1 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-1 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-1 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-1 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-1 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-1 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-1 div.sk-item {position: relative;z-index: 1;}#sk-container-id-1 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-1 div.sk-item::before, #sk-container-id-1 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-1 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-1 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-1 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-1 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-1 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-1 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-1 div.sk-label-container {text-align: center;}#sk-container-id-1 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-1 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>LogisticRegression(class_weight=&#x27;balanced&#x27;, max_iter=1000,\n",
805
+ " multi_class=&#x27;multinomial&#x27;, random_state=2024)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">LogisticRegression</label><div class=\"sk-toggleable__content\"><pre>LogisticRegression(class_weight=&#x27;balanced&#x27;, max_iter=1000,\n",
806
+ " multi_class=&#x27;multinomial&#x27;, random_state=2024)</pre></div></div></div></div></div>"
807
+ ],
808
+ "text/plain": [
809
+ "LogisticRegression(class_weight='balanced', max_iter=1000,\n",
810
+ " multi_class='multinomial', random_state=2024)"
811
+ ]
812
+ },
813
+ "execution_count": 5,
814
+ "metadata": {},
815
+ "output_type": "execute_result"
816
+ }
817
+ ],
818
+ "source": [
819
+ "lr_model = LogisticRegression(multi_class='multinomial', \n",
820
+ " class_weight=\"balanced\", \n",
821
+ " max_iter=1000, \n",
822
+ " random_state=2024)\n",
823
+ "lr_model.fit(X_train, y_train)"
824
+ ]
825
+ },
826
+ {
827
+ "cell_type": "code",
828
+ "execution_count": 6,
829
+ "metadata": {},
830
+ "outputs": [
831
+ {
832
+ "data": {
833
+ "text/html": [
834
+ "<style>#sk-container-id-2 {color: black;}#sk-container-id-2 pre{padding: 0;}#sk-container-id-2 div.sk-toggleable {background-color: white;}#sk-container-id-2 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-2 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-2 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-2 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-2 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-2 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-2 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-2 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-2 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-2 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-2 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-2 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-2 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-2 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-2 div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-2 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-2 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-2 div.sk-item {position: relative;z-index: 1;}#sk-container-id-2 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-2 div.sk-item::before, #sk-container-id-2 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-2 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-2 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-2 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-2 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-2 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-2 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-2 div.sk-label-container {text-align: center;}#sk-container-id-2 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-2 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-2\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>RandomForestClassifier(class_weight=&#x27;balanced&#x27;, max_depth=8, n_estimators=400,\n",
835
+ " random_state=2024)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-2\" type=\"checkbox\" checked><label for=\"sk-estimator-id-2\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">RandomForestClassifier</label><div class=\"sk-toggleable__content\"><pre>RandomForestClassifier(class_weight=&#x27;balanced&#x27;, max_depth=8, n_estimators=400,\n",
836
+ " random_state=2024)</pre></div></div></div></div></div>"
837
+ ],
838
+ "text/plain": [
839
+ "RandomForestClassifier(class_weight='balanced', max_depth=8, n_estimators=400,\n",
840
+ " random_state=2024)"
841
+ ]
842
+ },
843
+ "execution_count": 6,
844
+ "metadata": {},
845
+ "output_type": "execute_result"
846
+ }
847
+ ],
848
+ "source": [
849
+ "rf_model = RandomForestClassifier(class_weight=\"balanced\", \n",
850
+ " random_state=2024,\n",
851
+ " n_estimators=400,\n",
852
+ " max_depth=8)\n",
853
+ "rf_model.fit(X_train, y_train)"
854
+ ]
855
+ },
856
+ {
857
+ "cell_type": "code",
858
+ "execution_count": 7,
859
+ "metadata": {},
860
+ "outputs": [
861
+ {
862
+ "data": {
863
+ "text/html": [
864
+ "<style>#sk-container-id-3 {color: black;}#sk-container-id-3 pre{padding: 0;}#sk-container-id-3 div.sk-toggleable {background-color: white;}#sk-container-id-3 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-3 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-3 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-3 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-3 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-3 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-3 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-3 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-3 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-3 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-3 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-3 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-3 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-3 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-3 div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-3 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-3 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-3 div.sk-item {position: relative;z-index: 1;}#sk-container-id-3 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-3 div.sk-item::before, #sk-container-id-3 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-3 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-3 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-3 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-3 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-3 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-3 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-3 div.sk-label-container {text-align: center;}#sk-container-id-3 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-3 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-3\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>XGBClassifier(base_score=None, booster=None, callbacks=None,\n",
865
+ " colsample_bylevel=None, colsample_bynode=None,\n",
866
+ " colsample_bytree=None, device=&#x27;cuda&#x27;, early_stopping_rounds=None,\n",
867
+ " enable_categorical=False, eval_metric=None, feature_types=None,\n",
868
+ " gamma=None, grow_policy=None, importance_type=None,\n",
869
+ " interaction_constraints=None, learning_rate=0.1, max_bin=None,\n",
870
+ " max_cat_threshold=None, max_cat_to_onehot=None,\n",
871
+ " max_delta_step=None, max_depth=7, max_leaves=None,\n",
872
+ " min_child_weight=None, missing=nan, monotone_constraints=None,\n",
873
+ " multi_strategy=None, n_estimators=450, n_jobs=None,\n",
874
+ " num_parallel_tree=None, objective=&#x27;multi:softprob&#x27;, ...)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-3\" type=\"checkbox\" checked><label for=\"sk-estimator-id-3\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">XGBClassifier</label><div class=\"sk-toggleable__content\"><pre>XGBClassifier(base_score=None, booster=None, callbacks=None,\n",
875
+ " colsample_bylevel=None, colsample_bynode=None,\n",
876
+ " colsample_bytree=None, device=&#x27;cuda&#x27;, early_stopping_rounds=None,\n",
877
+ " enable_categorical=False, eval_metric=None, feature_types=None,\n",
878
+ " gamma=None, grow_policy=None, importance_type=None,\n",
879
+ " interaction_constraints=None, learning_rate=0.1, max_bin=None,\n",
880
+ " max_cat_threshold=None, max_cat_to_onehot=None,\n",
881
+ " max_delta_step=None, max_depth=7, max_leaves=None,\n",
882
+ " min_child_weight=None, missing=nan, monotone_constraints=None,\n",
883
+ " multi_strategy=None, n_estimators=450, n_jobs=None,\n",
884
+ " num_parallel_tree=None, objective=&#x27;multi:softprob&#x27;, ...)</pre></div></div></div></div></div>"
885
+ ],
886
+ "text/plain": [
887
+ "XGBClassifier(base_score=None, booster=None, callbacks=None,\n",
888
+ " colsample_bylevel=None, colsample_bynode=None,\n",
889
+ " colsample_bytree=None, device='cuda', early_stopping_rounds=None,\n",
890
+ " enable_categorical=False, eval_metric=None, feature_types=None,\n",
891
+ " gamma=None, grow_policy=None, importance_type=None,\n",
892
+ " interaction_constraints=None, learning_rate=0.1, max_bin=None,\n",
893
+ " max_cat_threshold=None, max_cat_to_onehot=None,\n",
894
+ " max_delta_step=None, max_depth=7, max_leaves=None,\n",
895
+ " min_child_weight=None, missing=nan, monotone_constraints=None,\n",
896
+ " multi_strategy=None, n_estimators=450, n_jobs=None,\n",
897
+ " num_parallel_tree=None, objective='multi:softprob', ...)"
898
+ ]
899
+ },
900
+ "execution_count": 7,
901
+ "metadata": {},
902
+ "output_type": "execute_result"
903
+ }
904
+ ],
905
+ "source": [
906
+ "xgb_model = xgb.XGBClassifier(device=\"cuda\", \n",
907
+ " seed=2024,\n",
908
+ " n_estimators=450,\n",
909
+ " max_depth=7,\n",
910
+ " learning_rate=0.1)\n",
911
+ "xgb_model.fit(X_train, y_train_encoded)"
912
+ ]
913
  }
914
  ],
915
  "metadata": {