chat2plot / main.py
RobertoHdez's picture
Create main.py
5ece554 verified
import streamlit as st
import numpy as np
import os
import pandas as pd
import matplotlib.pyplot as plt
import pandas as pd
import plotly.express as px
import plotly.express
import os
from groq import Groq
from scipy.stats import *
import statsmodels
cache_dic = {"have high paying fare passengers better survived rate":
"""st.write("Comparing the survival rate of passengers with high paying fares to those with low paying fares.")
st.write("First, let's filter the dataframe to only include passengers with fares greater than the median fare.")
high_fare_passengers = df[df['Fare'] > df['Fare'].median()]
st.write("Next, let's calculate the survival rate for these high fare passengers.")
high_fare_survival_rate = (high_fare_passengers['Survived'].sum()) / len(high_fare_passengers)
st.write("Now, let's filter the dataframe to only include passengers with fares less than or equal to the median fare.")
low_fare_passengers = df[df['Fare'] <= df['Fare'].median()]
st.write("Next, let's calculate the survival rate for these low fare passengers.")
low_fare_survival_rate = (low_fare_passengers['Survived'].sum()) / len(low_fare_passengers)
st.write("Finally, let's create a bar chart to visualize the survival rates for high and low fare passengers.")
import plotly.express as px
fig = px.bar(x=['High Fare', 'Low Fare'], y=[high_fare_survival_rate, low_fare_survival_rate])
st.write(fig)
st.write("The survival rate for high fare passengers is {:.2f}% and for low fare passengers is {:.2f}%.".format(high_fare_survival_rate*100, low_fare_survival_rate*100))
st.write("It appears that passengers with high paying fares had a higher survival rate than those with low paying fares.")""",
"List of examples" : """st.write("Here is a few examples to show!")""",
"does sex impact survived rate": """
st.write("Let's analyze the impact of sex on the survival rate.")
st.write("First, let's calculate the survival rate for each sex group:")
survived_male = (df[df['Sex'] == 'male']['Survived'].sum()) / df[df['Sex'] == 'male']['Survived'].count()
survived_female = (df[df['Sex'] == 'female']['Survived'].sum()) / df[df['Sex'] == 'female']['Survived'].count()
st.write("The survival rate for males is: ", survived_male)
st.write("The survival rate for females is: ", survived_female)
st.write("Now, let's visualize the survival rate for each sex group:")
fig = px.bar(x=['male', 'female'], y=[survived_male, survived_female], title='Survival Rate by Sex')
st.plotly_chart(fig)
st.write("From the graph, we can see that the survival rate for females is higher than for males. This suggests that being a female passenger had a positive impsitive impact on the survival rate.")""",
"provide Histogram plot of fares splitted by sex":"""
st.write("Histogram plot of fares splitted by sex")
import plotly.express as px
fig = px.histogram(df, x="Fare", color="Sex", nbins=50)
st.write(fig)
""",
"distribution of the number of siblings and spouses on board":
"""
st.write("Distribution of the number of siblings and spouses on board:")
st.write(df['SibSp'].value_counts().sort_values(ascending=False))
""",
"Which is the correlation between age and survived?":
"""
st.write("Calculating the correlation between age and survived...")
correlation = df['Age'].corr(df['Survived'])
st.write("The correlation between age and survived is: ", correlation)
""",
"represent the distribution of fares":
"""
import plotly.express as px
fig = px.histogram(df, x="Fare", nbins=50)
st.write(fig)
""",
"provide violin plot of the ages splitted by sex":
"""
st.write("Violin plot of the ages splitted by sex:")
import plotly.express as px
fig = px.violin(df, x="Sex", y="Age", color="Sex", box=True, points="all")
st.write(fig)
""",
"does Embarked column infuence survived?":
"""
st.write("Let's analyze the relationship between the Embarked column and the Survived column.")
st.write("First, let's count the number of survived and not survived passengers for each embarkment site:")
embarked_survived = df.groupby(['Embarked', 'Survived']).size().reset_index(name='count')
st.write(embarked_survived)
st.write("Now, let's create a bar chart to visualize the results:")
import plotly.express as px
fig = px.bar(embarked_survived, x='Embarked', y='count', color='Survived', barmode='group')
st.write(fig)
st.write("From the chart, we can see that the survival rate is higher for passengers who embarked at Cherbourg (C) and Queenstown (Q), and lower for passengers who embarked at Southampton (S).")
"""
}
df = pd.read_csv(r"train.csv")
st.title("Welcome to chat2plot!")
content6 = """
Estás trabajando con un dataframe de pandas `df` que se carga desde un archivo csv. Este es el resultado de `print(df.head())`:
{df.head}
Las columnas de df son :
{df.columns}
Descripción de las columnas:
Sex es el sexo del pasajero o pasajera, los valores son 'male' o 'female', así que debes tener esto en cuenta en cualquier pregunta relacionada con esta columna.
SibSp se refiere a los hermanos o esposos a bordo del pasajero o pasajera
Fare hace referencia al precio del billete. Es lo que el pasajero ha pagado.
Embarked se refiere al sitio desde donde ha subido el pasajero.
Instrucciones:
Ten en cuenta los nombres de las columnas y las letras mayúsculas, corrígelo o toma la columna más parecida. Por ejemplo: si el usuario escribe 'fare' en lugar de 'Fare', toma la columna 'Fare'.
Eres desarrollador senior de Python con conocimientos de Streamlit. En base a la pregunta del usuario, has de desarrollar la respuesta paso a paso. Cualquier variable de python tiene que ir con st.write(), recuerda no añadir texto, sólo el codigo de python. Recuerda ordenar los resultados de los cálculos en orden descedente del valor
Ten en cuenta que este output se va a ejecutar mediante exec(), por lo que necesito que el texto que proporciones con el desarrollo vayan con st.write(), también en el codigo de python para imprimir los resultados, excepto en las gráficas
Y recuerda no dejar texto sin st.write() excepto las gráficas (px....). Si añades alguna gráfica, no lo hagas con pyplot sino con plotly.express y recuerda no añadir st.write() y declarar la gráfica como fig para poder representarla
Además, necesito que incluyas alguna conclusión interesante en base al input, mediante st.write(), no lo hagas mediante # porque no se leerá. No describas gráficas, solo los cálculos y recuerda no inventarte ninguna conclusión, añade conclusiones en base a los datos.
Nota: evita símbolos especiales como "$" que pueden alterar el texto.
Nota: para comparar columnas, limítate a realizar la gráfica para la comparación, no realices contrastes de hipótesis.
Te resumo los pasos a seguir:
1. Analiza la pregunta, qué está pidiendo el usuario.
2. Analiza cómo puedes responder lo que pide el usuario con pandas o plotly.express. Ten en cuenta el formato de las columnas, si son variables numéricas, categóricas, etc. Corrige el input si es necesario para que se corresponda los valores con los valores del dataframe. Cualquier gráfica has de hacerla con plotly.express.
3. Una vez tengas código y compruebes que esto responde la pregunta del usuario. Luego, ejecuta ese código y obtén los resultados. Además, necesito que incluyas alguna conclusión interesante en base a la respuesta mediante st.write() IMPORTANTE CIERRA SIEMPRE LOS PARENTESIS. Describe los resultados de los cálculos realizados y nunca de las gráficas
4. Comprueba que todo el texto es coherente y no se presenta ninguna incongruencia.
5. Finalmente, cuando tengas la respuesta y los cálculos hechos, devuelve la respuesta desarrollada paso a paso para que pueda ser ejecutada en Streamlit. Todo el texto y el código tiene que ir con st.write() para que pueda verlo el usuario. Comprueba que todo el texto es coherente y no se presenta ninguna incongruencia y corrige lo que sea necesario. Además, asegúrate de que el código no tenga errores de sintaxis o de lógica
Recuerda que sólo debes devolver el paso 5 como respuesta, los pasos previos los debes procesar internamente. Es importante que las gráficas se realicen con plotly.express y que se están tomando los datos necesarios para responder la pregunta
"""
def prompting(input):
if ((input ==None) or (input =="")):
return """st.write("waiting for question")"""
client = Groq(
api_key=os.getenv('API_KEY') ,
)
chat_completion = client.chat.completions.create(
messages=[{
"role": "system",
"content": content6,
},
{
"role": "user",
"content": input
}
],
model="llama3-8b-8192", temperature=0
)
mode = chat_completion.choices[0].message.content
return mode
import streamlit as st
st.empty()
examples = ["List of examples","have high paying fare passengers better survived rate","does sex impact survived rate","provide Histogram plot of fares splitted by sex","distribution of the number of siblings and spouses on board", "Which is the correlation between age and survived?","represent the distribution of fares", "provide violin plot of the ages splitted by sex","does Embarked column infuence survived?"]
seleccion = st.selectbox("Pick an example", examples)
if seleccion:
exec(cache_dic[seleccion])
entrada_codigo = st.text_input("Or just try by yourself!", placeholder="Write your code right here!")
st.empty()
try:
try:
if not (entrada_codigo in st.session_state):
exec(prompting(entrada_codigo))
st.session_state[entrada_codigo]= prompting(entrada_codigo)
else:
exec(st.session_state[entrada_codigo])
except Exception as e2:
print(e2)
except Exception as e:
st.text("Oooooooooooooooooooooooops, try again!")
st.text(f"Error al ejecutar el código: {e}")
st.title("Next steps!")
st.markdown("- Render frontend components")
st.markdown("- Extend plot capabilities so user can request more specific plots and details")