# rowankwang's picture
# redis
# b799b9a
import streamlit as st
import json
from streamlit_shortcuts import add_keyboard_shortcuts
import random
import requests
# Local JSON file read by load_data() and written (with a "_graded" suffix) by save_data().
file_path = "grid_eval_gpt4o.json"

# Kept ahead of every other st.* call; older Streamlit releases require
# set_page_config to be the first one.
st.set_page_config(layout="wide")

# Login is disabled for now. To turn it back on, restore the block below
# (needs streamlit_authenticator, PyYAML and an auth.yaml file):
# import streamlit_authenticator as stauth
# import yaml
# from yaml.loader import SafeLoader
# with open('auth.yaml') as file:
#     config = yaml.load(file, Loader=SafeLoader)
# authenticator = stauth.Authenticate(
#     config['credentials'],
#     config['cookie']['name'],
#     config['cookie']['key'],
#     config['cookie']['expiry_days'],
#     config['preauthorized']
# )
@st.cache_data()
def fetch_data(fetch_url, num_samples=10, timeout=30):
    """POST ``{'num_samples': num_samples}`` to *fetch_url* and return the decoded JSON body.

    Results are memoised by ``st.cache_data`` per argument set (the cache is shared
    across sessions), so reruns do not hit the network again.

    Args:
        fetch_url: Endpoint to POST to.
        num_samples: Number of samples to request.
        timeout: Seconds to wait for the server. On expiry ``requests`` raises, and since
            a raised exception is not cached, the next rerun simply retries. Without a
            timeout a stalled server would hang the app indefinitely.

    Returns:
        The parsed JSON on HTTP 200; otherwise ``None`` after showing ``st.error``.
        Note that a ``None`` result is cached like any other return value.
    """
    payload = {
        'num_samples': num_samples,
    }
    headers = {
        'Content-Type': 'application/json',
    }
    response = requests.post(fetch_url, json=payload, headers=headers, timeout=timeout)
    if response.status_code == 200:
        return response.json()
    st.error(f"Failed to fetch data: {response.status_code}")
    return None
def update_data(query, new_grade_result, timeout=30):
    """Send a query's new grade result to the backend's ``update`` endpoint via HTTP PUT.

    Args:
        query: The query text (presumably the backend's key for the record — verify).
        new_grade_result: Payload built by ``get_new_grade_result_from_data``.
        timeout: Seconds to wait before ``requests`` raises a timeout; previously the
            request could block the app forever.

    Returns:
        True when the server answered 200, otherwise False after showing ``st.error``.
        (Older callers that ignore the return value are unaffected.)
    """
    payload = {
        'query': query,
        'newGradeResult': new_grade_result,
    }
    headers = {
        'Content-Type': 'application/json',
    }
    response = requests.put(URL + "update", json=payload, headers=headers, timeout=timeout)
    if response.status_code != 200:
        # NOTE(review): callers follow this with next_query() -> st.rerun(), which likely
        # clears this message before it is seen; consider checking the return value.
        st.error(f"Failed to update data: {response.status_code}")
        return False
    return True
def get_new_grade_result_from_data(data):
    """Build the ``newGradeResult`` payload sent by ``update_data`` from a query dict.

    Keeps the query's ``status`` and, for each result, only its ``url`` and ``agree``
    flag. A result with no ``agree`` key counts as agreed (``True``), the same default
    the app applies to freshly loaded data; batches that skipped that defaulting used
    to raise ``KeyError`` here.

    Args:
        data: A query dict with ``'status'`` and a ``'results'`` list of dicts that
            each carry a ``'url'``.

    Returns:
        ``{'status': ..., 'results': [{'url': ..., 'agree': ...}, ...]}``
    """
    return {
        'status': data['status'],
        'results': [
            {'url': result['url'], 'agree': result.get('agree', True)}
            for result in data['results']
        ],
    }
# Remote sample endpoint: https://synthetic-data-framework.vercel.app/samples
# Load your data
@st.cache_data()
def load_data():
    """Read and return the JSON document stored at ``file_path`` (memoised by st.cache_data)."""
    # JSON is UTF-8 by spec; don't depend on the platform's default locale encoding.
    with open(file_path, 'r', encoding='utf-8') as file:
        return json.load(file)
def save_data(data, path=None):
    """Write *data* as 4-space-indented JSON next to the source file, with a ``_graded`` suffix.

    ``grid_eval_gpt4o.json`` becomes ``grid_eval_gpt4o_graded.json``.

    Args:
        data: JSON-serialisable object to write.
        path: Source file name to derive the output name from; defaults to the
            module-level ``file_path``.
    """
    source = file_path if path is None else path
    # Strip only a trailing ".json". The old split('.json')[0] cut at the first
    # ".json" anywhere in the path, e.g. inside a directory name.
    suffix = '.json'
    stem = source[:-len(suffix)] if source.endswith(suffix) else source
    with open(f"{stem}_graded.json", 'w', encoding='utf-8') as file:
        json.dump(data, file, indent=4)
def download_json(data):
    """Serialise *data* to a JSON string indented by four spaces (used for st.download_button)."""
    encoder = json.JSONEncoder(indent=4)
    return encoder.encode(data)
# Number of samples requested from the backend per batch.
NUM_SAMPLES = 5
# Backend base URL; the "samples" and "update" endpoints are appended to it.
URL = "https://synthetic-data-framework.vercel.app/"
data = fetch_data(URL + "samples", num_samples=NUM_SAMPLES)
if data is None:
    # fetch_data() returns None on a non-200 response, and st.cache_data keeps that
    # None like any other result. Clear it so the next rerun retries the request
    # instead of reusing the failure for the rest of the session.
    fetch_data.clear()
def refresh_data():
    """Drop the cached sample batch, fetch a new one and reset navigation state.

    If the new fetch fails or returns no samples, the current batch stays in place.
    """
    # Clear every cached entry. The batch was cached under
    # (URL + "samples", num_samples=NUM_SAMPLES), so the previous
    # fetch_data.clear(URL, NUM_SAMPLES) never invalidated it and "Renew" kept
    # returning the same samples.
    fetch_data.clear()
    new_data = fetch_data(URL + "samples", num_samples=NUM_SAMPLES)
    if not new_data:
        # fetch_data() reports HTTP failures itself. Don't keep a cached failure, and
        # don't replace the batch being graded with None / an empty list.
        fetch_data.clear()
        return
    # The module-level "agree" defaulting only ran on the original batch; apply it here too.
    for query in new_data:
        for result in query['results']:
            result.setdefault('agree', True)
    st.session_state.data = new_data
    st.session_state.current_query_index = 0
    st.session_state.graded_queries = 0
# Every result starts out as "agree", so untouched results are submitted as accepted.
# `data` is None when the fetch failed; there is nothing to default in that case.
for query in data or []:
    for result in query['results']:
        if 'agree' not in result:
            result['agree'] = True

# State management for current query index
if 'current_query_index' not in st.session_state:
    st.session_state.current_query_index = 0
if 'data' not in st.session_state:
    if not data:
        # Nothing to grade. For None, fetch_data() has already shown the error. An empty
        # batch would otherwise crash as soon as the first query is looked up.
        if data is not None:
            st.warning("No samples to grade.")
        st.stop()
    st.session_state.data = data
if 'graded_queries' not in st.session_state:
    st.session_state.graded_queries = 0
def truncate_text(text, length=250):
    """Shorten *text* for display.

    Text of at most *length* characters is returned as-is; longer text is cut to its
    first *length* characters followed by ``'...'``.
    """
    if len(text) > length:
        return text[:length] + '...'
    return text
# CSS defining the `.rounded-box` class; display_query() injects it via
# st.markdown(..., unsafe_allow_html=True).
# NOTE(review): each st.markdown call renders as its own element, so the separate
# "<div class='rounded-box'>" / "</div>" calls probably don't wrap the widgets drawn
# between them — verify the box is actually visible.
result_box_style = """
<style>
.rounded-box {
border: 1px solid #ddd;
border-radius: 10px;
padding: 0.01px;
margin-bottom: 10px;
}
</style>
"""
# Navigation to next query
def next_query():
    """Move to the next query of the batch being graded (if there is one), then rerun."""
    # Bound by the batch actually displayed (session state). The module-level `data`
    # comes from st.cache_data, which is shared across sessions, so it can differ
    # from this session's batch (e.g. after someone presses "Renew").
    if st.session_state.current_query_index < len(st.session_state.data) - 1:
        st.session_state.current_query_index += 1
    st.rerun()

add_keyboard_shortcuts({
    's': 'Skip',
})
# Display current query and its results
def display_query(query=None):
    """Render the grading page for one query.

    Widgets are drawn in this order:
    1. the navigation bar;
    2. the "go to index" box;
    3. a banner once every query is graded;
    4. the query header;
    5. one row per search result with its "Reject" checkbox.

    Args:
        query: Query dict to render. Defaults to the module-level ``current_query``,
            so existing ``display_query()`` calls behave as before.
    """
    if query is None:
        query = current_query
    queries = st.session_state.data
    st.session_state.graded_queries = sum(q.get('status', None) == 'graded' for q in queries)
    print(f"Current Query Index: {st.session_state.current_query_index} | Graded Queries: {st.session_state.graded_queries} | Total Queries: {len(queries)} | Current Query Status {query.get('status', None)}")

    _render_navigation_bar(query, queries)
    _render_index_jump(queries)

    # Compare against the batch shown in this session, not the module-level `data`.
    if st.session_state.graded_queries >= len(queries):
        # save_data(st.session_state.data)
        st.success(f"{len(queries)} Queries graded!")

    st.markdown(result_box_style, unsafe_allow_html=True)
    _render_query_header(query)
    _render_results(query)
    # NOTE(review): opens a `.rounded-box` div that is never closed; kept as before.
    st.markdown("<div class='rounded-box'>", unsafe_allow_html=True)


def _escape_html(value):
    """Return ``str(value)`` with HTML special characters escaped for unsafe_allow_html markdown."""
    import html
    return html.escape(str(value))


def _grade_and_advance(query, status):
    """Set *query*'s status, push its grade result to the backend and move to the next query."""
    query['status'] = status
    update_data(query['query'], get_new_grade_result_from_data(query))
    next_query()


def _render_navigation_bar(query, queries):
    """Draw Previous + progress on the left and Next/Skip/Junk/Download/Renew on the right."""
    index = st.session_state.current_query_index
    left, right = st.columns([4, 2], gap="small")
    with left:
        # st.button() is evaluated first, so the button is always drawn.
        if st.button('Previous') and index > 0:
            st.session_state.current_query_index = index - 1
            st.rerun()
        st.progress((index + 1) / len(queries))
    with right:
        next_col, skip_col, junk_col, download_col, renew_col = st.columns([2, 2, 2, 3, 3], gap="small")
        with next_col:
            if st.button('Next') and index < len(queries) - 1:
                st.session_state.current_query_index = index + 1
                st.rerun()
        with skip_col:
            if st.button('Skip'):
                _grade_and_advance(query, 'skipped')
        with junk_col:
            if st.button('Junk'):
                _grade_and_advance(query, 'nonsense')
        with download_col:
            st.download_button(
                label="Download",
                data=download_json(queries),
                file_name="graded_data.json",
                mime="application/json",
            )
        with renew_col:
            if st.button('Renew'):
                refresh_data()
                st.rerun()


def _render_index_jump(queries):
    """Text box that jumps to a 1-based query index."""
    current = st.session_state.current_query_index
    # The label embeds the position and graded count. Presumably, because the input has
    # no explicit key, a changed label makes Streamlit treat it as a new, empty widget
    # after a successful jump — verify if the label is ever changed.
    raw = st.text_input(f"At index {current + 1}. Graded: {st.session_state.graded_queries}/{len(queries)}", placeholder="Go to index:")
    if not raw:
        return
    try:
        target = int(raw) - 1
    except ValueError:
        st.error("Please enter a valid integer.")
        return
    if target < 0 or target >= len(queries):
        st.error("Invalid index.")
    elif target != current:
        # Only rerun on an actual move. "Jumping" to the current index leaves the label
        # (and so the typed value) unchanged, which used to trigger st.rerun() forever.
        st.session_state.current_query_index = target
        st.rerun()


def _render_query_header(query):
    """Show the query text, its grid position and status, and the model's reasoning trace."""
    st.header(f"Query: {query['query']}")
    status = query.get('status', None)
    status_color = 'green' if status == "graded" else 'red'
    st.markdown(
        f"{_escape_html(query['metadata']['grid_pos_str'])} | Query Grade: "
        f"<b style='color: {status_color};'>{_escape_html(status)}</b>",
        unsafe_allow_html=True,
    )
    with st.expander("Model's Query Gen Reasoning Trace"):
        st.markdown(f"{query['metadata']['reasoning_trace'][0]}")


def _render_results(query):
    """One row per search result: title/link/text on the left, model grade + Reject on the right."""
    st.subheader("Results:")
    for position, result in enumerate(query['results']):
        st.markdown("<div class='rounded-box'>", unsafe_allow_html=True)
        text_col, grade_col = st.columns([3, 2], gap="small")
        with text_col:
            # Titles, URLs and dates come from the remote backend and go into raw-HTML
            # markdown, so escape them.
            st.markdown(f"<h5>{_escape_html(result['title'])}</h5>", unsafe_allow_html=True)
            st.markdown(f"[<span style='font-size: 0.8em;'>{_escape_html(truncate_text(result['url'], length=50))}</span>]({result['url']}) | {_escape_html(result['published_date'])}", unsafe_allow_html=True)
            st.markdown(f"{truncate_text(result['text'], length=len(result['model_trace']))}")
        with grade_col:
            grade_color = 'green' if result['grade'].lower() == 'yes' else 'red'
            st.markdown(f"Model Grade: <b style='color: {grade_color};'>{_escape_html(result['grade'])}</b>", unsafe_allow_html=True)
            st.write(result['model_trace'])
            # Key on the query too, so checkbox state never carries over to another
            # query or batch. Assign (rather than only ever clear) so un-ticking
            # "Reject" restores agreement.
            rejected = st.checkbox(
                "Reject",
                value=not result.get('agree', True),
                key=f"verify-{query['query']}-{position}",
            )
            result['agree'] = not rejected
        st.markdown("</div>", unsafe_allow_html=True)
# Pick the query at the current position and render it.
current_query = st.session_state.data[st.session_state.current_query_index]
display_query()

# "Mark Done" sits in a narrow right-hand column under the results.
_, done_col = st.columns([5, 1], gap="small")
with done_col:
    if st.button('Mark Done and Go to Next'):
        current_query['status'] = 'graded'
        grade_payload = get_new_grade_result_from_data(current_query)
        update_data(current_query['query'], grade_payload)
        next_query()

# Remaining keyboard shortcuts, registered one call per key in the original order.
for shortcut, button_label in (
    ('j', 'Junk'),
    ('p', 'Previous'),
    ('n', 'Next'),
    ('d', 'Download'),
    ('r', 'Renew'),
):
    add_keyboard_shortcuts({shortcut: button_label})