Spaces:
Runtime error
Runtime error
from __future__ import annotations | |
import reprlib | |
from pathlib import Path | |
from tempfile import mkdtemp | |
import streamlit as st | |
from huggingface_hub import hf_hub_download | |
from skops import card | |
from skops.card._model_card import PlotSection, split_subsection_names | |
from utils import iterate_key_section_content, process_card_for_rendering | |
from tasks import AddSectionTask, AddFigureTask, DeleteSectionTask, TaskState, UpdateFigureTask, UpdateSectionTask | |
# Abbreviating repr used for the delete-button labels: strings longer than
# 24 characters are shortened by reprlib.
arepr = reprlib.Repr()
arepr.maxstring = 24

# Scratch directories created once, when the module is imported.
_SKOPS_TMP_PREFIX = "skops-"
tmp_path = Path(mkdtemp(prefix=_SKOPS_TMP_PREFIX))  # temporary files (e.g. uploaded figures)
hf_path = Path(mkdtemp(prefix=_SKOPS_TMP_PREFIX))  # hf repo
def load_model_card_from_repo(repo_id: str) -> card.Card:
    """Fetch ``README.md`` from the Hub repo ``repo_id`` and parse it as a model card."""
    print("downloading model card")
    readme_file = hf_hub_download(repo_id, "README.md")
    return card.parse_modelcard(readme_file)
def _update_model_card(
    model_card: card.Card, key: str, section_name: str, content: str, is_fig: bool,
) -> None:
    """Queue an update task for a section whose edit form was submitted.

    This is the ``on_click`` callback of the "Update" button of both the text
    section form and the figure form. Nothing is queued when neither the
    title nor the content changed.

    Parameters
    ----------
    model_card : card.Card
        The model card the queued task operates on.
    key : str
        Section key. The form widgets store their current values in
        ``st.session_state`` under ``f"{key}.title"`` and ``f"{key}.content"``.
    section_name : str
        Full name of the section before the edit; only its last component
        (as split by ``split_subsection_names``) is editable.
    content : str
        Section content before the edit.
    is_fig : bool
        Whether this is a figure section, whose new content comes from a file
        uploader rather than a text area.
    """
    # This is a very roundabout way to update the model card but it's necessary
    # because of how streamlit handles session state. Basically, there have to
    # be "key" arguments, which have to be retrieved from the session_state, as
    # they are up-to-date. Just getting the Python variables is not enough, as
    # they can be out of date.

    # key names must match with those used in form
    new_title = st.session_state[f"{key}.title"]
    new_content = st.session_state[f"{key}.content"]

    # determine if title is the same (only the last name component can change)
    old_title_split = split_subsection_names(section_name)
    new_title_split = old_title_split[:-1] + [new_title]
    is_title_same = old_title_split == new_title_split

    # determine if content is the same
    if is_fig:
        if isinstance(new_content, PlotSection):
            is_content_same = content == new_content
        else:
            # nothing uploaded -> the figure itself is unchanged
            is_content_same = not bool(new_content)
    else:
        is_content_same = content == new_content

    if is_title_same and is_content_same:
        return

    if is_fig:
        fpath = None
        if new_content:  # new figure uploaded
            # The upload's file name is supplied by the client: keep only its
            # final path component so the figure cannot be written outside
            # tmp_path (e.g. via "../" segments).
            fname = Path(new_content.name).name.replace(" ", "_")
            if fname in ("", ".", ".."):
                fname = "uploaded_figure"
            fpath = tmp_path / fname
        task = UpdateFigureTask(
            model_card,
            key=key,
            old_name=section_name,
            new_name=new_title,
            data=new_content,
            path=fpath,
        )
    else:
        task = UpdateSectionTask(
            model_card,
            key=key,
            old_name=section_name,
            new_name=new_title,
            old_content=content,
            new_content=new_content,
        )
    st.session_state.task_state.add(task)
def _add_section(model_card: card.Card, key: str) -> None:
    """Queue a task adding a placeholder text section titled ``f"{key}/Untitled"``.

    Callback of the "add section below" button.
    """
    new_task = AddSectionTask(
        model_card,
        title=f"{key}/Untitled",
        content="[More Information Needed]",
    )
    st.session_state.task_state.add(new_task)
def _add_figure(model_card: card.Card, key: str) -> None:
    """Queue a task adding a placeholder figure section titled ``f"{key}/Untitled"``.

    Callback of the "add figure below" button.
    """
    # NOTE(review): "cat.png" is presumably a default image shipped with the
    # app — confirm it exists next to the app.
    new_task = AddFigureTask(
        model_card,
        title=f"{key}/Untitled",
        content="cat.png",
    )
    st.session_state.task_state.add(new_task)
def _delete_section(model_card: card.Card, key: str) -> None:
    """Queue a task removing the section identified by ``key`` (delete-button callback)."""
    removal = DeleteSectionTask(model_card, key=key)
    st.session_state.task_state.add(removal)
def _add_section_form(
    model_card: card.Card, key: str, section_name: str, old_title: str, content: str
) -> None:
    """Render the edit form of a text section.

    The form has a title input, a content text area and an "Update" button
    whose callback (``_update_model_card``) queues the edit.
    """
    with st.form(key, clear_on_submit=False):
        st.header(section_name)
        # Widgets are keyed so their current values are available in
        # st.session_state when the submit callback runs.
        st.text_input("Section name", value=old_title, key=f"{key}.title")
        st.text_area("Content", value=content, key=f"{key}.content")
        st.form_submit_button(
            "Update",
            on_click=_update_model_card,
            # last element: is_fig=False for a text section
            args=(model_card, key, section_name, content, False),
        )
def _add_fig_form(
    model_card: card.Card, key: str, section_name: str, old_title: str, content: str
) -> None:
    """Render the edit form of a figure section.

    The form has a title input, an image file uploader and an "Update" button
    whose callback (``_update_model_card``) queues the edit.
    """
    with st.form(key, clear_on_submit=False):
        st.header(section_name)
        # Keyed widgets: their values are read back from st.session_state by
        # the submit callback.
        st.text_input("Section name", value=old_title, key=f"{key}.title")
        st.file_uploader("Upload image", key=f"{key}.content")
        st.form_submit_button(
            "Update",
            on_click=_update_model_card,
            # last element: is_fig=True for a figure section
            args=(model_card, key, section_name, content, True),
        )
def create_form_from_section(
    model_card: card.Card, key: str, section_name: str, content: str, is_fig: bool = False
) -> None:
    """Render the edit form for one section, followed by a row of action buttons.

    The form is a figure form when ``is_fig`` is true and a text form otherwise.
    Below it sit three buttons: delete this section, add a text section, and
    add a figure section.
    """
    old_title = split_subsection_names(section_name)[-1]

    render_form = _add_fig_form if is_fig else _add_section_form
    render_form(
        model_card=model_card,
        key=key,
        section_name=section_name,
        old_title=old_title,
        content=content,
    )

    # (label, callback, widget-key suffix) for each column, left to right.
    button_specs = (
        (f"delete '{arepr.repr(old_title)}'", _delete_section, "delete"),
        ("add section below", _add_section, "add"),
        ("add figure below", _add_figure, "fig"),
    )
    columns = st.columns([4, 2, 2])
    for column, (label, callback, suffix) in zip(columns, button_specs):
        with column:
            st.button(
                label,
                on_click=callback,
                args=(model_card, key),
                key=f"{key}.{suffix}",
            )
def display_sections(model_card: card.Card) -> None:
    """Render one edit form per section yielded by ``iterate_key_section_content``."""
    section_entries = iterate_key_section_content(model_card._data)
    for sec_key, sec_name, sec_content, sec_is_fig in section_entries:
        create_form_from_section(
            model_card, sec_key, sec_name, sec_content, sec_is_fig
        )
def display_model_card(model_card: card.Card) -> None:
    """Show the rendered model card, with its metadata tucked into an expander."""
    # Separate the metadata from the markdown body before displaying either.
    metadata, body = process_card_for_rendering(model_card.render())
    with st.expander("show metadata"):
        st.text(metadata)
    # NOTE(review): HTML in the card is rendered unescaped; the card may come
    # from an arbitrary Hub repo and/or user edits.
    st.markdown(body, unsafe_allow_html=True)
def reset_model_card() -> None:
    """Undo every task in the edit history (callback of the sidebar "Reset" button).

    Does nothing when no task history has been created yet.
    """
    if "task_state" not in st.session_state:
        return
    # An earlier branch here ran `del st.session_state["model_card"]` only when
    # "model_card" was *absent*, which always raised KeyError. Reset only
    # rewinds the task history; the loaded card stays in the session.
    while st.session_state.task_state.done_list:
        st.session_state.task_state.undo()
def delete_model_card() -> None:
    """Remove the loaded model card from the session and reset the task state.

    Callback of the sidebar "Delete" button; both steps are skipped silently
    when the corresponding session entry does not exist.
    """
    st.session_state.pop("model_card", None)
    if "task_state" in st.session_state:
        st.session_state.task_state.reset()
def undo_last():
    """Undo the most recent edit task (callback of the sidebar UNDO button).

    The card is deliberately not rendered here. Output written inside a
    Streamlit callback appears at the top of the page, and ``edit_input_form``
    already renders the card on the rerun that follows, so rendering here
    showed the card twice.
    """
    st.session_state.task_state.undo()
def redo_last():
    """Redo the most recently undone edit task (callback of the sidebar REDO button).

    The card is deliberately not rendered here. Output written inside a
    Streamlit callback appears at the top of the page, and ``edit_input_form``
    already renders the card on the rerun that follows, so rendering here
    showed the card twice.
    """
    st.session_state.task_state.redo()
def add_download_model_card_button():
    """Render the "Save (md)" button that downloads the rendered model card.

    The button is disabled when no model card is in the session state.
    """
    model_card = st.session_state.get("model_card")
    download_disabled = not bool(model_card)
    # Render only when a card is present. Previously `model_card.render()` ran
    # unconditionally and raised AttributeError on None, crashing the page
    # whenever no card was loaded (e.g. right after "Delete").
    data = "" if download_disabled else model_card.render()
    st.download_button(
        "Save (md)", data=data, disabled=download_disabled
    )
def edit_input_form():
    """Render the model-card editing UI.

    Shows UNDO/REDO/Reset and Save/Delete controls plus one edit form per
    section, then the rendered card when one is loaded. The task history
    (``TaskState``) is created on first use.
    """
    # Create the task history once per session; later reruns reuse it.
    if "task_state" not in st.session_state:
        st.session_state.task_state = TaskState()

    # NOTE(review): the indentation of this function was lost in SOURCE; the
    # extent of the `with st.sidebar:` scope below (controls and section forms
    # in the sidebar, rendered card in the main area) is reconstructed —
    # confirm against the original.
    with st.sidebar:
        # First row: UNDO / REDO (labels show the history sizes) and Reset.
        # Four columns are requested; the unused extra is discarded via *_.
        col_0, col_1, col_2, *_ = st.columns([2, 2, 2, 2])
        undo_disabled = not bool(st.session_state.task_state.done_list)
        redo_disabled = not bool(st.session_state.task_state.undone_list)
        with col_0:
            name = f"UNDO ({len(st.session_state.task_state.done_list)})"
            st.button(name, on_click=undo_last, disabled=undo_disabled)
        with col_1:
            name = f"REDO ({len(st.session_state.task_state.undone_list)})"
            st.button(name, on_click=redo_last, disabled=redo_disabled)
        with col_2:
            st.button("Reset", on_click=reset_model_card)

        # Second row: download ("Save (md)") and Delete.
        col_0, col_1, *_ = st.columns([2, 2, 2, 2])
        with col_0:
            add_download_model_card_button()
        with col_1:
            st.button("Delete", on_click=delete_model_card)

        # One edit form per section, only when a card is loaded.
        if "model_card" in st.session_state:
            display_sections(st.session_state.model_card)

    # Rendered card, only when a card is loaded.
    if "model_card" in st.session_state:
        display_model_card(st.session_state.model_card)