Spaces:
Sleeping
Sleeping
import streamlit as st | |
import json | |
from streamlit_shortcuts import add_keyboard_shortcuts | |
import random | |
import requests | |
# Configure the Streamlit page to use the "wide" layout.
st.set_page_config(layout="wide")
# import streamlit_authenticator as stauth | |
# import yaml | |
# from yaml.loader import SafeLoader | |
# with open('auth.yaml') as file: | |
# config = yaml.load(file, Loader=SafeLoader) | |
# authenticator = stauth.Authenticate( | |
# config['credentials'], | |
# config['cookie']['name'], | |
# config['cookie']['key'], | |
# config['cookie']['expiry_days'], | |
# config['preauthorized'] | |
# ) | |
# Local JSON file read by load_data(); save_data() writes next to it using
# the same name with a "_graded" suffix.
file_path = 'grid_eval_gpt4o.json'
def fetch_data(fetch_url, num_samples=10, timeout=30):
    """Request a batch of samples from the backend.

    Sends a POST to ``fetch_url`` with the JSON body
    ``{"num_samples": num_samples}``.

    Args:
        fetch_url: Full URL of the endpoint to POST to.
        num_samples: Number of samples to ask for.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The decoded JSON body when the server answers with HTTP 200,
        otherwise ``None``. Every failure is reported with ``st.error``.
    """
    payload = {'num_samples': num_samples}
    headers = {'Content-Type': 'application/json'}
    try:
        # Without an explicit timeout requests waits indefinitely, which
        # would hang the whole script on an unresponsive backend.
        response = requests.post(
            fetch_url, json=payload, headers=headers, timeout=timeout
        )
    except requests.RequestException as exc:
        st.error(f"Failed to fetch data: {exc}")
        return None

    if response.status_code != 200:
        st.error(f"Failed to fetch data: {response.status_code}")
        return None

    try:
        return response.json()
    except ValueError:
        # A 200 response whose body is not JSON is treated like any other
        # failed fetch instead of crashing the app.
        st.error("Failed to fetch data: response was not valid JSON")
        return None
def update_data(query, new_grade_result, timeout=30):
    """Send the grading result for one query to the backend.

    Issues a PUT to ``URL + "update"`` with the JSON body
    ``{"query": query, "newGradeResult": new_grade_result}``.

    Args:
        query: The query text identifying the sample being graded.
        new_grade_result: Payload describing the new grade, e.g. the value
            returned by ``get_new_grade_result_from_data``.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        None. Failures are reported with ``st.error``.
    """
    payload = {
        'query': query,
        'newGradeResult': new_grade_result,
    }
    headers = {'Content-Type': 'application/json'}
    try:
        # Without an explicit timeout requests waits indefinitely, which
        # would freeze the UI in the middle of a button handler.
        response = requests.put(
            URL + "update", json=payload, headers=headers, timeout=timeout
        )
    except requests.RequestException as exc:
        st.error(f"Failed to update data: {exc}")
        return
    if response.status_code != 200:
        st.error(f"Failed to update data: {response.status_code}")
def get_new_grade_result_from_data(data):
    """Reduce a query record to the fields sent when updating its grade.

    Args:
        data: Query record with a ``'status'`` entry and a ``'results'``
            list whose items each have ``'url'`` and ``'agree'``.

    Returns:
        A dict with the record's ``'status'`` and, under ``'results'``, one
        ``{'url': ..., 'agree': ...}`` dict per result, in the same order.
    """
    status = data['status']
    verdicts = []
    for entry in data['results']:
        verdicts.append({'url': entry['url'], 'agree': entry['agree']})
    return {'status': status, 'results': verdicts}
# Samples endpoint: https://synthetic-data-framework.vercel.app/samples
# Load your data | |
def load_data():
    """Read and return the JSON content of the file named by ``file_path``."""
    with open(file_path, 'r') as handle:
        return json.load(handle)
def save_data(data):
    """Write ``data`` as indented JSON next to the input file.

    The output name is the part of ``file_path`` before its first ".json"
    followed by "_graded.json". That prefix is also printed to stdout.
    """
    stem = file_path.split(".json")[0]
    print(stem)
    target = f"{stem}_graded.json"
    with open(target, 'w') as handle:
        json.dump(data, handle, indent=4)
def download_json(data):
    """Serialize ``data`` to a JSON string indented by four spaces."""
    serialized = json.dumps(data, indent=4)
    return serialized
# Number of samples requested from the backend per batch.
NUM_SAMPLES = 5
# Base URL of the backend; endpoint names are appended to it.
URL = "https://synthetic-data-framework.vercel.app/"

# Streamlit re-executes this script on every interaction. Reuse the batch
# already stored for this session instead of issuing a new POST each time;
# only fetch when the session has no usable batch yet.
data = st.session_state.get('data')
if data is None:
    data = fetch_data(URL + "samples", num_samples=NUM_SAMPLES)
    if data is None:
        # fetch_data has already shown the error; nothing below can run
        # without a batch, so halt this script run here.
        st.stop()
    st.session_state.data = data
def refresh_data():
    """Replace the session's batch with freshly fetched samples.

    On success the new batch is stored in ``st.session_state.data`` and the
    current index and graded counter are reset to 0. If the fetch fails the
    existing batch and position are left untouched.
    """
    # fetch_data is a plain function and has no ``clear`` attribute; only
    # clear a cache when a caching wrapper that provides one is present.
    clear_cache = getattr(fetch_data, "clear", None)
    if callable(clear_cache):
        clear_cache()

    fresh = fetch_data(URL + "samples", num_samples=NUM_SAMPLES)
    if fresh is None:
        # fetch_data already reported the failure; keep the current batch
        # rather than overwriting it with None.
        return

    # Give the new batch the same default as the initial one: a result
    # without a reviewer verdict counts as agreed.
    for query in fresh:
        for result in query['results']:
            result.setdefault('agree', True)

    st.session_state.data = fresh
    st.session_state.current_query_index = 0
    st.session_state.graded_queries = 0
# Results that carry no reviewer verdict yet start out as agreed.
for query in data:
    for result in query['results']:
        result.setdefault('agree', True)

# Seed the per-session state; keys that already exist are left alone.
for state_key, initial_value in (
    ('current_query_index', 0),
    ('data', data),
    ('graded_queries', 0),
):
    if state_key not in st.session_state:
        st.session_state[state_key] = initial_value
def truncate_text(text, length=250):
    """Return ``text`` cut to ``length`` characters plus '...' if longer.

    Text whose length is at most ``length`` is returned unchanged.
    """
    if len(text) > length:
        return text[:length] + '...'
    return text
# CSS for the "rounded-box" class used by the <div> wrappers that
# display_query() emits around each result; injected via st.markdown.
result_box_style = """
<style>
.rounded-box {
border: 1px solid #ddd;
border-radius: 10px;
padding: 0.01px;
margin-bottom: 10px;
}
</style>
"""
# Navigation to next query | |
def next_query():
    """Advance to the next query, if there is one, and rerun the script.

    The bound is taken from ``st.session_state.data``, the same list the
    rest of the navigation uses, so it stays correct after the batch has
    been replaced.
    """
    last_index = len(st.session_state.data) - 1
    if st.session_state.current_query_index < last_index:
        st.session_state.current_query_index += 1
    # Rerun even when already on the last query so that a status change made
    # by the caller just before this call is rendered.
    st.rerun()
# Bind the "s" key to the button labelled "Skip".
add_keyboard_shortcuts({'s': 'Skip'})
# Display current query and its results | |
def display_query():
    """Render the navigation bar, the current query and its results.

    Reads the module-level ``current_query`` and the batch, index and graded
    counter kept in ``st.session_state``. Button handlers update the session
    state and/or ``current_query`` and trigger a rerun; the "Reject"
    checkboxes write each result's ``'agree'`` flag.
    """
    global current_query
    queries = st.session_state.data
    total = len(queries)

    st.session_state.graded_queries = sum(
        query.get('status', None) == 'graded' for query in queries
    )
    print(f"Current Query Index: {st.session_state.current_query_index} | Graded Queries: {st.session_state.graded_queries} | Total Queries: {len(st.session_state.data)} | Current Query Status {current_query.get('status', None)}")

    # --- Navigation bar ---------------------------------------------------
    nav_left, nav_right = st.columns([4, 2], gap="small")
    with nav_left:
        if st.button('Previous'):
            if st.session_state.current_query_index > 0:
                st.session_state.current_query_index -= 1
                st.rerun()
        st.progress((st.session_state.current_query_index + 1) / total)
    with nav_right:
        next_col, skip_col, junk_col, download_col, renew_col = st.columns(
            [2, 2, 2, 3, 3], gap="small"
        )
        with next_col:
            if st.button('Next'):
                if st.session_state.current_query_index < total - 1:
                    st.session_state.current_query_index += 1
                    st.rerun()
        with skip_col:
            if st.button('Skip'):
                current_query['status'] = 'skipped'
                update_data(current_query['query'], get_new_grade_result_from_data(current_query))
                next_query()
        with junk_col:
            if st.button('Junk'):
                current_query['status'] = 'nonsense'
                update_data(current_query['query'], get_new_grade_result_from_data(current_query))
                next_query()
        with download_col:
            st.download_button(
                label="Download",
                data=download_json(st.session_state.data),
                file_name="graded_data.json",
                mime="application/json"
            )
        with renew_col:
            if st.button('Renew'):
                refresh_data()
                st.rerun()

    # --- Jump to a 1-based index ------------------------------------------
    index = st.text_input(f"At index {st.session_state.current_query_index + 1}. Graded: {st.session_state.graded_queries}/{len(st.session_state.data)}", placeholder="Go to index:")
    if index:
        try:
            target = int(index) - 1
        except ValueError:
            st.error("Please enter a valid integer.")
        else:
            if target < 0 or target >= total:
                st.error("Invalid index.")
            elif target != st.session_state.current_query_index:
                # Only rerun on an actual move: the typed value is still
                # present on the next run, so rerunning for the index we are
                # already on would repeat forever.
                st.session_state.current_query_index = target
                st.rerun()

    if st.session_state.graded_queries >= total:
        st.success(f"{total} Queries graded!")

    # --- Current query ----------------------------------------------------
    st.markdown(result_box_style, unsafe_allow_html=True)
    st.header(f"Query: {current_query['query']}")
    status_color = 'green' if current_query.get('status', None) == "graded" else 'red'
    st.markdown(f"{current_query['metadata']['grid_pos_str']} | Query Grade: <b style='color: {status_color};'>{current_query.get('status', None)}</b>", unsafe_allow_html = True)
    with st.expander("Model's Query Gen Reasoning Trace"):
        st.markdown(f"{current_query['metadata']['reasoning_trace'][0]}")

    # --- Results ----------------------------------------------------------
    st.subheader("Results:")
    for index, result in enumerate(current_query['results']):
        st.markdown("<div class='rounded-box'>", unsafe_allow_html=True)
        text_col, grade_col = st.columns([3, 2], gap="small")
        with text_col:
            st.markdown(f"<h5>{result['title']}</h5>", unsafe_allow_html=True)
            st.markdown(f"[<span style='font-size: 0.8em;'>{truncate_text(result['url'], length = 50)}</span>]({result['url']}) | {result['published_date']}", unsafe_allow_html=True)
            # The excerpt is cut to the length of the model trace shown
            # beside it.
            st.markdown(f"{truncate_text(result['text'], length = len(result['model_trace']))}")
        with grade_col:
            grade_color = 'green' if result['grade'].lower() == 'yes' else 'red'
            st.markdown(f"Model Grade: <b style='color: {grade_color};'>{result['grade']}</b>", unsafe_allow_html=True)
            st.write(result['model_trace'])
            # The key includes the query text so checkbox state cannot carry
            # over between queries. A result with no verdict defaults to
            # agreed, and the flag follows the checkbox in both directions
            # so un-checking "Reject" restores agreement.
            rejected = st.checkbox(
                "Reject",
                value=not result.get('agree', True),
                key=f"verify-{current_query['query']}-{index}",
            )
            result['agree'] = not rejected
        st.markdown("</div>", unsafe_allow_html=True)
    st.markdown("<div class='rounded-box'>", unsafe_allow_html=True)
# Show current query and its results | |
# Look up the query the session currently points at and render the page.
current_query = st.session_state.data[st.session_state.current_query_index]
display_query()

# Right-aligned button that marks the query as graded and moves on.
_, done_col = st.columns([5, 1], gap="small")
with done_col:
    mark_done_clicked = st.button('Mark Done and Go to Next')
    if mark_done_clicked:
        current_query['status'] = 'graded'
        update_data(current_query['query'], get_new_grade_result_from_data(current_query))
        next_query()

# Register the remaining single-key shortcuts, one call per binding.
for shortcut_key, button_label in (
    ('j', 'Junk'),
    ('p', 'Previous'),
    ('n', 'Next'),
    ('d', 'Download'),
    ('r', 'Renew'),
):
    add_keyboard_shortcuts({shortcut_key: button_label})