pechaut commited on
Commit
17f0613
1 Parent(s): 79ee030

Update streamlit_presentation/modele.py

Browse files
Files changed (1) hide show
  1. streamlit_presentation/modele.py +77 -2
streamlit_presentation/modele.py CHANGED
@@ -4,7 +4,7 @@ import numpy as np
4
  from sklearn.metrics import f1_score
5
  from sklearn.metrics import confusion_matrix
6
 
7
- def presentation_modele(st,data, model,class_labels, y_test):
8
 
9
  st.write('Notre modèle prend les embeddings de Camembert pour les descriptions et designations (séparemment), les embeddings de FlauBert pour les descriptions, les embeddings VIT pour les images et les tailles des champs de texte.')
10
 
@@ -48,8 +48,83 @@ def presentation_modele(st,data, model,class_labels, y_test):
48
  plt.ylabel('Réelles')
49
  plt.title('Matrice de Confusion')
50
  st.pyplot(plt)
51
-
52
  #afficher la matrice de conf.
53
  st.dataframe(data.head(10))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
 
 
4
  from sklearn.metrics import f1_score
5
  from sklearn.metrics import confusion_matrix
6
 
7
+ def presentation_modele(st,data, model,class_labels, y_test,encoder):
8
 
9
  st.write('Notre modèle prend les embeddings de Camembert pour les descriptions et designations (séparemment), les embeddings de FlauBert pour les descriptions, les embeddings VIT pour les images et les tailles des champs de texte.')
10
 
 
48
  plt.ylabel('Réelles')
49
  plt.title('Matrice de Confusion')
50
  st.pyplot(plt)
51
+
52
  #afficher la matrice de conf.
53
  st.dataframe(data.head(10))
54
+ ligne = st.number_input(label="Prédire la ligne:",min_value=0, max_value=1000, value=0)
55
+ if st.button("Obtenir La prédiction"):
56
+ X1_test = data["embeddings_desi"].iloc[[ligne]]
57
+ X1_test = np.stack(X1_test).astype(np.float32)
58
+ X2_test = data["embeddings_desc"].iloc[[ligne]]
59
+ X2_test = np.stack(X2_test).astype(np.float32)
60
+ X3_test = data["embedding_vit"].iloc[[ligne]]
61
+ X3_test = np.stack(X3_test).astype(np.float32)
62
+ X4_test = data["designation_length_normalized"].iloc[[ligne]]
63
+ X5_test = data["description_length_normalized"].iloc[[ligne]]
64
+ X6_test = data["embeddings_desi_Flaubert"].iloc[[ligne]]
65
+ X6_test = np.stack(X6_test).astype(np.float32)
66
+
67
+ pred = model.predict([X1_test, X2_test,X3_test,X4_test,X5_test,X6_test])
68
+ pred_ids = np.argmax(pred, axis=-1)
69
+ val_pred = encoder.inverse_transform(pred_ids)[0]
70
+ val_true = encoder.inverse_transform([y_test[ligne]])[0]
71
+
72
+ if(val_pred == val_true):
73
+ col1,col2,_ = st.columns([1,5,18])
74
+ with col1:
75
+ st.image("check.png")
76
+ with col2:
77
+ st.write(f":green[Prédiction: {encoder.inverse_transform(pred_ids)} Réel: {encoder.inverse_transform([y_test[ligne]])}]")
78
+ else:
79
+
80
+ col1,col2,_ = st.columns([1,5,18])
81
+ with col1:
82
+ st.image("uncheck.png")
83
+ with col2:
84
+ st.write(f":red[Prédiction: {encoder.inverse_transform(pred_ids)} Réel: {encoder.inverse_transform([y_test[ligne]])}]")
85
+ st.text("")
86
+ st.text(f"""
87
+
88
+ designation: {data['designation'].values[ligne]}
89
+ description: {data['tr_description'].values[ligne]}
90
+ """)
91
+
92
+
93
+ cat_dict = {
94
+ '10':"Livres anciens",
95
+ '40':"Jeux import",
96
+ "50" : "accessoires jeux consoles ?",
97
+ "60": "consoles rétro",
98
+ "1140" :"figurines",
99
+ "1160": "cartes à collectionner",
100
+ "1180": "figurine miniatures",
101
+ "1280": "jouet enfant",
102
+ "1281": "jouet enfants",
103
+ "1300": "Modèles réduits et accessoires",
104
+ "1301": "vêtements enfant",
105
+ "1302": "jeux d'extérieur",
106
+ "1320": "Accessoire puériculture",
107
+ "1560": "Cuisine et accessoire maison",
108
+ "1920": "literie",
109
+ "1940": "ingrédients culinaires",
110
+ "2060": "Déco Maison",
111
+ "2220": "accessoires animalerie",
112
+ "2280": "Magazines",
113
+ "2403": "livres anciens",
114
+ "2462": "consoles et accessoires occasion",
115
+ "2522": "papeterie",
116
+ "2582": "?? La maison",
117
+ "2583": "piscine et accessoires",
118
+ "2585": "Le Jardin",
119
+ "2705": "livres",
120
+ "2905": "jeux en téléchargement (cf désignation) ?",
121
+ }
122
+ print(val_pred)
123
+ print(val_true)
124
+
125
+ st.write(f"{val_pred}: {cat_dict[f'{val_pred}']}")
126
+ if(val_pred != val_true):
127
+ st.write(f"{val_true}: {cat_dict[f'{val_true}']}")
128
+
129
 
130