File size: 5,412 Bytes
0701931
 
 
 
 
 
17f0613
0701931
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79ee030
0701931
 
 
 
79ee030
0701931
 
 
 
17f0613
0701931
 
17f0613
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a6bbbcb
17f0613
a6bbbcb
17f0613
a6bbbcb
17f0613
a6bbbcb
17f0613
a6bbbcb
17f0613
 
 
 
 
 
 
a6bbbcb
 
17f0613
 
a6bbbcb
17f0613
 
 
a6bbbcb
17f0613
 
 
 
 
 
 
 
0701931
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np 
from sklearn.metrics import f1_score
from sklearn.metrics import confusion_matrix

# Human-readable names for the product category codes produced by ``encoder``.
_CATEGORY_NAMES = {
    '10': "Livres anciens",
    '40': "Jeux import",
    "50": "accessoires jeux consoles",
    "60": "consoles rétro",
    "1140": "figurines jeu et jeux de roles",
    "1160": "cartes à collectionner",
    "1180": "figurines miniatures",
    "1280": "jouet enfant",
    "1281": "jouets enfants",
    "1300": "Modèles réduits et accessoires",
    "1301": "chaussettes enfant",
    "1302": "jeux d'extérieur",
    "1320": "Accessoire puériculture",
    "1560": "Cuisine et accessoire maison",
    "1920": "literie",
    "1940": "ingrédients culinaires",
    "2060": "Déco Maison",
    "2220": "accessoires animalerie",
    "2280": "Magazines anciens",
    "2403": "Lots de livres anciens",
    "2462": "consoles et accessoires occasion",
    "2522": "papeterie",
    "2582": "mobilier de jardin",
    "2583": "piscine et accessoires",
    "2585": "Le Jardin",
    "2705": "livres",
    "2905": "jeux en téléchargement",
}

# Shown instead of crashing when a category code has no entry in _CATEGORY_NAMES.
_UNKNOWN_CATEGORY = "catégorie inconnue"


def _build_model_inputs(frame):
    """Assemble the six model inputs from *frame*, in the order ``model.predict`` expects.

    Each embedding column holds one vector per row; the vectors are stacked
    along a new leading axis and cast to float32. The two normalised length
    columns are passed through as 1-D arrays. Works for the whole test set
    and for a one-row slice (``data.iloc[[i]]``) alike.
    """
    def stacked(column):
        return np.stack(frame[column].values).astype(np.float32)

    return [
        stacked("embeddings_desi"),
        stacked("embeddings_desc"),
        stacked("embedding_vit"),
        frame["designation_length_normalized"].values,
        frame["description_length_normalized"].values,
        stacked("embeddings_desi_Flaubert"),
    ]


def _column_percentages(conf_matrix):
    """Return *conf_matrix* with every column rescaled to percentages of its total.

    With scikit-learn's layout (rows = true class, columns = predicted class)
    each cell becomes the share, in %, of the column's predictions whose true
    class is the row's class. Columns with no predictions stay at 0 instead of
    turning into NaN through a division by zero.
    """
    counts = np.asarray(conf_matrix, dtype=np.float64)
    col_sums = counts.sum(axis=0, keepdims=True)
    shares = np.divide(counts, col_sums, out=np.zeros_like(counts), where=col_sums != 0)
    return shares * 100


def presentation_modele(st, data, model, class_labels, y_test, encoder):
    """Render the Streamlit page that presents the model and its predictions.

    Args:
        st: the Streamlit module (passed in by the caller).
        data: test DataFrame holding the embedding/length feature columns plus
            ``designation`` and ``tr_description`` for display.
        model: object exposing ``predict(list_of_inputs)`` returning per-class
            scores (argmax over the last axis gives the class id).
        class_labels: tick labels for the confusion-matrix heatmap.
        y_test: true class ids aligned row-by-row with ``data``.
        encoder: label encoder whose ``inverse_transform`` maps class ids back
            to category codes.
    """
    st.write('Notre modèle prend les embeddings de Camembert pour les descriptions et designations (séparément), les embeddings de FlauBert pour les designations, les embeddings VIT pour les images et les tailles des champs de texte.')

    st.image("model.png", use_column_width=True)

    # Positional view of the labels, so that ``ligne`` indexes rows even when
    # y_test is a pandas Series carrying a non-default index.
    y_true_ids = np.asarray(y_test)

    # --- Whole test set: weighted F1 and normalised confusion matrix -------
    if st.button("Prédire"):
        y_pred = model.predict(_build_model_inputs(data))
        y_pred_ids = np.argmax(y_pred, axis=-1)

        weighted_f1_score = f1_score(y_test, y_pred_ids, average='weighted')
        st.write("weighted F1 score:", weighted_f1_score)

        conf_matrix = confusion_matrix(y_test, y_pred_ids)
        normalized_conf_matrix = _column_percentages(conf_matrix)

        st.title("Matrice de Confusion Normalisée")
        # Use an explicit figure and close it afterwards: Streamlit re-runs the
        # script on every interaction, and unclosed pyplot figures accumulate.
        fig, ax = plt.subplots(figsize=(10, 10))
        sns.heatmap(normalized_conf_matrix, annot=True, cmap='Blues', fmt='.0f',
                    xticklabels=class_labels,
                    yticklabels=class_labels,
                    linewidths=1.5,
                    ax=ax)
        ax.tick_params(axis='x', labelrotation=45)
        ax.set_xlabel('Prédictions')
        ax.set_ylabel('Réelles')
        ax.set_title('Matrice de Confusion')
        st.pyplot(fig)
        plt.close(fig)

    # --- Single row: prediction vs. ground truth ---------------------------
    st.dataframe(data.head(10))
    # Bound the selector by the data size so .iloc never goes out of range.
    ligne = st.number_input(label="Prédire la ligne:", min_value=0,
                            max_value=max(len(data) - 1, 0), value=0)
    if st.button("Obtenir La prédiction"):
        pred = model.predict(_build_model_inputs(data.iloc[[ligne]]))
        pred_ids = np.argmax(pred, axis=-1)
        pred_labels = encoder.inverse_transform(pred_ids)
        true_labels = encoder.inverse_transform([y_true_ids[ligne]])
        val_pred = pred_labels[0]
        val_true = true_labels[0]

        is_correct = val_pred == val_true
        icon, color = ("check.png", "green") if is_correct else ("uncheck.png", "red")
        col1, col2, _ = st.columns([1, 5, 18])
        with col1:
            st.image(icon)
        with col2:
            st.write(f":{color}[Prédiction: {pred_labels} Réel: {true_labels}]")

        st.text("")
        st.text(f"""
        
        designation: {data['designation'].values[ligne]}
        description: {data['tr_description'].values[ligne]}
        """)

        st.write(f"{val_pred}: {_CATEGORY_NAMES.get(f'{val_pred}', _UNKNOWN_CATEGORY)}")
        if not is_correct:
            st.write(f"{val_true}: {_CATEGORY_NAMES.get(f'{val_true}', _UNKNOWN_CATEGORY)}")