hysts's picture
hysts HF staff
gradio==5.0.0b10
023e539
#!/usr/bin/env python
import ast
import os
import datasets
import gradio as gr
import PIL.Image
# Markdown rendered at the top of the page.
DESCRIPTION = "# [MMMU](https://huggingface.co/datasets/MMMU/MMMU) dataset viewer\n"


def _env_flag(name: str) -> bool:
    """Return True only when environment variable *name* is "true" (any case).

    An unset variable is treated as "false".
    """
    raw = os.getenv(name, "false")
    return raw.lower() == "true"


# Whether the corresponding accordions start expanded.
SHOW_ANSWER = _env_flag("SHOW_ANSWER")
SHOW_QUESTION_DETAILS = _env_flag("SHOW_QUESTION_DETAILS")
# Dataset configuration names, one per subject. The order here is the order
# offered in the "Subject" dropdown; each name is passed as `name=` when the
# dataset is loaded.
SUBJECTS = [
    "Accounting",
    "Agriculture",
    "Architecture_and_Engineering",
    "Art",
    "Art_Theory",
    "Basic_Medical_Science",
    "Biology",
    "Chemistry",
    "Clinical_Medicine",
    "Computer_Science",
    "Design",
    "Diagnostics_and_Laboratory_Medicine",
    "Economics",
    "Electronics",
    "Energy_and_Power",
    "Finance",
    "Geography",
    "History",
    "Literature",
    "Manage",
    "Marketing",
    "Materials",
    "Math",
    "Mechanical_Engineering",
    "Music",
    "Pharmacy",
    "Physics",
    "Psychology",
    "Public_Health",
    "Sociology",
]
# Validation split of every subject, loaded eagerly at import time and keyed
# by subject name.
ds = {
    subject_name: datasets.load_dataset("MMMU/MMMU", name=subject_name, split="validation")
    for subject_name in SUBJECTS
}
def set_default_subject() -> str:
    """Return the subject name the "Subject" dropdown is set to on page load."""
    default_subject = "Accounting"
    return default_subject
def get_images(subject: str, question_index: int) -> list[PIL.Image.Image]:
    """Collect the images attached to one question.

    Args:
        subject: Key into the module-level ``ds`` mapping.
        question_index: Row index within that subject's dataset.

    Returns:
        The values of ``image_1`` .. ``image_7`` in order, stopping at the
        first slot that is ``None``. Empty when ``image_1`` is ``None``.
    """
    # Look the row up once. The previous version re-evaluated
    # ds[subject][question_index] for every image slot, repeating the same
    # row access up to seven times per call.
    row = ds[subject][question_index]

    images = []
    for image_id in range(1, 8):
        image = row[f"image_{image_id}"]
        if image is None:
            # Slots past the first empty one are not inspected.
            break
        images.append(image)
    return images
def update_subject(
    subject: str,
) -> tuple[
    gr.Textbox,  # Number of Questions
    gr.Slider,  # Question Index
    gr.Gallery,  # Images
    gr.Textbox,  # Question
    gr.Textbox,  # Options
    gr.Textbox,  # Answer
    gr.Textbox,  # Explanation
    gr.Textbox,  # Topic Difficulty
    gr.Textbox,  # Question Type
    gr.Textbox,  # Subfield
]:
    """Build the component updates for a newly selected subject.

    Args:
        subject: Key into the module-level ``ds`` mapping.

    Returns:
        A textbox carrying the number of questions in the subject, a slider
        ranging over the valid question indices and reset to 0, followed by
        everything ``update_question(subject, 0)`` returns.
    """
    total = len(ds[subject])
    count_box = gr.Textbox(value=total)
    index_slider = gr.Slider(label="Question Index", minimum=0, maximum=total - 1, step=1, value=0)
    return (count_box, index_slider, *update_question(subject, 0))
def update_question(subject: str, question_index: int) -> tuple[
    gr.Gallery,  # Images
    gr.Textbox,  # Question
    gr.Textbox,  # Options
    gr.Textbox,  # Answer
    gr.Textbox,  # Explanation
    gr.Textbox,  # Topic Difficulty
    gr.Textbox,  # Question Type
    gr.Textbox,  # Subfield
]:
    """Build the component updates that display a single question.

    Args:
        subject: Key into the module-level ``ds`` mapping.
        question_index: Row index within that subject's dataset.

    Returns:
        Updates for the gallery and the seven text fields, in the order
        listed in the return annotation.
    """
    row = ds[subject][question_index]

    # The "options" field holds a Python-literal sequence serialized as a
    # string; each entry is shown on its own line, lettered from "A".
    parsed_options = ast.literal_eval(row["options"])
    lettered = (f"{chr(code)}. {text}" for code, text in enumerate(parsed_options, start=65))
    options_text = "\n".join(lettered)

    gallery_images = get_images(subject, question_index)
    # Never more than two gallery columns, and fewer when fewer images exist.
    column_count = min(len(gallery_images), 2)

    gallery = gr.Gallery(value=gallery_images, columns=column_count)
    text_fields = (
        gr.Textbox(value=row["question"]),  # Question
        gr.Textbox(value=options_text),  # Options
        gr.Textbox(value=row["answer"]),  # Answer
        gr.Textbox(value=row["explanation"]),  # Explanation
        gr.Textbox(value=row["topic_difficulty"]),  # Topic Difficulty
        gr.Textbox(value=row["question_type"]),  # Question Type
        gr.Textbox(value=row["subfield"]),  # Subfield
    )
    return (gallery, *text_fields)
# Page layout and event wiring.
# NOTE(review): this file arrived with its indentation stripped, so the
# container nesting below is reconstructed. In particular, whether the two
# accordions sit inside the gr.Group or beside it could not be determined;
# confirm against the deployed layout.
with gr.Blocks(css_paths="style.css") as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        # The dropdown starts on the last subject, while demo.load (below)
        # sets it to the value returned by set_default_subject().
        # Presumably the resulting value change is what fires subject.change
        # and fills in the page on first load; confirm.
        subject = gr.Dropdown(label="Subject", choices=SUBJECTS, value=SUBJECTS[-1])
        question_count = gr.Textbox(label="Number of Questions")

    with gr.Group():
        question_index = gr.Slider(label="Question Index")
        with gr.Row():
            with gr.Column():
                question = gr.Textbox(label="Question")
                options = gr.Textbox(label="Options")
            with gr.Column():
                images = gr.Gallery(label="Images", object_fit="scale-down")

    with gr.Accordion("Answer and Explanation", open=SHOW_ANSWER):
        with gr.Row():
            answer = gr.Textbox(label="Answer")
            explanation = gr.Textbox(label="Explanation")

    with gr.Accordion("Question Details", open=SHOW_QUESTION_DETAILS):
        with gr.Row():
            topic_difficulty = gr.Textbox(label="Topic Difficulty")
            question_type = gr.Textbox(label="Question Type")
            subfield = gr.Textbox(label="Subfield")

    # Components refreshed whenever the displayed question changes, in the
    # order update_question() returns them.
    question_outputs = [
        images,
        question,
        options,
        answer,
        explanation,
        topic_difficulty,
        question_type,
        subfield,
    ]

    # Changing the subject also refreshes the question count and the slider.
    subject.change(
        fn=update_subject,
        inputs=subject,
        outputs=[question_count, question_index, *question_outputs],
        queue=False,
        api_name=False,
    )
    question_index.input(
        fn=update_question,
        inputs=[subject, question_index],
        outputs=question_outputs,
        queue=False,
        api_name=False,
    )
    demo.load(fn=set_default_subject, outputs=subject, queue=False, api_name=False)
if __name__ == "__main__":
    # Configure the queue first, then start the server.
    demo.queue(api_open=False)
    demo.launch(show_api=False)