Spaces:
Runtime error
Runtime error
import matplotlib.pyplot as plt | |
import seaborn as sns | |
import numpy as np | |
from sklearn.metrics import f1_score | |
from sklearn.metrics import confusion_matrix | |
# Category code (as a string) -> French category name displayed to the user.
_CATEGORY_LABELS = {
    '10': "Livres anciens",
    '40': "Jeux import",
    "50": "accessoires jeux consoles",
    "60": "consoles rétro",
    "1140": "figurines jeu et jeux de roles",
    "1160": "cartes à collectionner",
    "1180": "figurines miniatures",
    "1280": "jouet enfant",
    "1281": "jouets enfants",
    "1300": "Modèles réduits et accessoires",
    "1301": "chaussettes enfant",
    "1302": "jeux d'extérieur",
    "1320": "Accessoire puériculture",
    "1560": "Cuisine et accessoire maison",
    "1920": "literie",
    "1940": "ingrédients culinaires",
    "2060": "Déco Maison",
    "2220": "accessoires animalerie",
    "2280": "Magazines anciens",
    "2403": "Lots de livres anciens",
    "2462": "consoles et accessoires occasion",
    "2522": "papeterie",
    "2582": "mobilier de jardin",
    "2583": "piscine et accessoires",
    "2585": "Le Jardin",
    "2705": "livres",
    "2905": "jeux en téléchargement",
}


def _category_label(code):
    """Return the display name for category *code*, or a placeholder if unknown."""
    return _CATEGORY_LABELS.get(str(code), "catégorie inconnue")


def _build_model_inputs(frame):
    """Assemble the six model inputs from the rows of *frame*, in predict() order.

    Embedding columns presumably hold one fixed-size vector per row (np.stack
    requires equal shapes); they are stacked along a new leading row axis.
    The two normalized-length columns hold one scalar per row. Everything is
    cast to float32 so the full-set and single-row paths feed identical dtypes.
    """
    def stacked(column):
        return np.stack(frame[column].values).astype(np.float32)

    def scalar(column):
        return np.asarray(frame[column].values, dtype=np.float32)

    return [
        stacked("embeddings_desi"),
        stacked("embeddings_desc"),
        stacked("embedding_vit"),
        scalar("designation_length_normalized"),
        scalar("description_length_normalized"),
        stacked("embeddings_desi_Flaubert"),
    ]


def _normalize_by_column(conf_matrix):
    """Return *conf_matrix* as percentages of each column's total.

    In an sklearn confusion matrix columns are predicted classes, so each cell
    becomes the share of that class's predictions whose true class is the row.
    Columns with no predictions yield 0 instead of NaN (avoids 0/0).
    """
    col_sums = conf_matrix.sum(axis=0, keepdims=True)
    normalized = np.zeros(conf_matrix.shape, dtype=float)
    np.divide(conf_matrix * 100.0, col_sums, out=normalized, where=col_sums != 0)
    return normalized


def _show_confusion_matrix(st, normalized, class_labels):
    """Draw the normalized confusion matrix as an annotated heatmap in *st*."""
    st.title("Matrice de Confusion Normalisée")
    # Explicit figure: Streamlit deprecates st.pyplot(plt) (global figure),
    # and closing it stops figures from piling up across reruns.
    fig, ax = plt.subplots(figsize=(10, 10))
    sns.heatmap(normalized, annot=True, cmap='Blues', fmt='.0f',
                xticklabels=class_labels,
                yticklabels=class_labels,
                linewidths=1.5,
                ax=ax)
    ax.tick_params(axis='x', labelrotation=45)
    ax.set_xlabel('Prédictions')
    ax.set_ylabel('Réelles')
    ax.set_title('Matrice de Confusion')
    st.pyplot(fig)
    plt.close(fig)


def presentation_modele(st, data, model, class_labels, y_test, encoder):
    """Render the model-presentation page and its two prediction widgets.

    The "Prédire" button scores every row of *data* and shows the weighted F1
    score plus a column-normalized confusion matrix. The "Obtenir La
    prédiction" button scores one user-selected row and shows the predicted
    vs. true category together with that row's designation and description.

    Args:
        st: The ``streamlit`` module (or an object exposing the same calls).
        data: DataFrame with the columns read by ``_build_model_inputs`` plus
            ``designation`` and ``tr_description``.
        model: Object whose ``predict`` takes the list built by
            ``_build_model_inputs`` and returns class scores along the last
            axis (presumably a Keras model — confirm).
        class_labels: Tick labels for both heatmap axes.
        y_test: Encoded true labels, aligned *positionally* with the rows of
            *data* (list, array, or Series with any index).
        encoder: Fitted label encoder exposing ``inverse_transform``
            (presumably sklearn ``LabelEncoder`` — confirm).
    """
    # NOTE(review): the text says FlauBert embeds the descriptions, but the
    # model input column is `embeddings_desi_Flaubert` (designation) — confirm.
    st.write('Notre modèle prend les embeddings de Camembert pour les descriptions et designations (séparemment), les embeddings de FlauBert pour les descriptions, les embeddings VIT pour les images et les tailles des champs de texte.')
    st.image("model.png", use_column_width=True)
    # TODO: also show an excerpt of the embeddings.

    # Whole-set evaluation: weighted F1 and normalized confusion matrix.
    if st.button("Prédire"):
        y_pred_ids = np.argmax(model.predict(_build_model_inputs(data)), axis=-1)
        weighted_f1_score = f1_score(y_test, y_pred_ids, average='weighted')
        st.write("weighted F1 score:", weighted_f1_score)
        conf_matrix = confusion_matrix(y_test, y_pred_ids)
        _show_confusion_matrix(st, _normalize_by_column(conf_matrix), class_labels)

    st.dataframe(data.head(10))
    # Positional access so a Series with a shuffled index still lines up with
    # data.iloc / .values[ligne] used below.
    y_true_ids = np.asarray(y_test)
    # Only offer rows that exist in both data and y_test (was hard-coded to 1000).
    last_row = max(min(len(data), len(y_true_ids)) - 1, 0)
    ligne = st.number_input(label="Prédire la ligne:", min_value=0, max_value=last_row, value=0)
    if not st.button("Obtenir La prédiction"):
        return

    pred_ids = np.argmax(model.predict(_build_model_inputs(data.iloc[[ligne]])), axis=-1)
    val_pred = encoder.inverse_transform(pred_ids)[0]
    val_true = encoder.inverse_transform([y_true_ids[ligne]])[0]
    is_correct = val_pred == val_true

    col1, col2, _ = st.columns([1, 5, 18])
    with col1:
        st.image("check.png" if is_correct else "uncheck.png")
    with col2:
        colour = "green" if is_correct else "red"
        # Scalars, not array reprs: brackets inside :colour[...] garble the markup.
        st.write(f":{colour}[Prédiction: {val_pred} Réel: {val_true}]")
    st.text("")
    st.text(
        "\n"
        f"designation: {data['designation'].values[ligne]}\n"
        f"description: {data['tr_description'].values[ligne]}\n"
    )
    st.write(f"{val_pred}: {_category_label(val_pred)}")
    if not is_correct:
        st.write(f"{val_true}: {_category_label(val_true)}")