Benjamin Bossan
Improve layout
00bfdf9
raw
history blame
8.35 kB
from __future__ import annotations
import reprlib
from pathlib import Path
from tempfile import mkdtemp
import streamlit as st
from huggingface_hub import hf_hub_download
from skops import card
from skops.card._model_card import PlotSection, split_subsection_names
from utils import iterate_key_section_content, process_card_for_rendering
from tasks import AddSectionTask, AddFigureTask, DeleteSectionTask, TaskState, UpdateFigureTask, UpdateSectionTask
# Abbreviating repr: used to keep long section titles short when they are
# embedded in button labels (long strings are shortened to about 24 chars).
arepr = reprlib.Repr()
arepr.maxstring = 24

# Two fresh scratch directories, created once when this module is imported.
# The first holds temporary files (e.g. uploaded figures), the second the hf repo.
tmp_path, hf_path = (Path(mkdtemp(prefix="skops-")) for _ in range(2))
def load_model_card_from_repo(repo_id: str) -> card.Card:
    """Download ``README.md`` from a Hugging Face Hub repo and parse it.

    Parameters
    ----------
    repo_id : str
        Identifier of the Hub repository whose README is fetched.

    Returns
    -------
    card.Card
        The model card produced by ``skops.card.parse_modelcard``.
    """
    print("downloading model card")
    return card.parse_modelcard(hf_hub_download(repo_id, "README.md"))
def sanitize_upload_filename(name: str) -> str:
    """Turn a client-supplied upload name into a bare, safe file name.

    The name of a file uploaded through ``st.file_uploader`` is reported by
    the browser and must be treated as untrusted. Joined onto a directory
    unchecked, a name such as ``"../x"`` escapes that directory, and an
    absolute name such as ``"/etc/x"`` replaces it entirely
    (``Path("/tmp") / "/etc/x" == Path("/etc/x")``). Only the final path
    component is kept; spaces are replaced with underscores as before.

    Parameters
    ----------
    name : str
        Raw file name reported for the upload.

    Returns
    -------
    str
        A file name without any directory component. Falls back to
        ``"upload"`` when nothing usable remains (e.g. ``""`` or ``".."``).
    """
    # Treat backslashes as separators too, so Windows-style paths are reduced
    # to their last component on every platform.
    base = Path(name.replace("\\", "/")).name
    if base in {"", ".", ".."}:
        base = "upload"
    return base.replace(" ", "_")


def _update_model_card(
    model_card: card.Card, key: str, section_name: str, content: str, is_fig: bool,
) -> None:
    """Queue an update task for one section if its form values changed.

    Used as the ``on_click`` callback of the "Update" button of a section or
    figure form. Nothing is queued when neither title nor content changed.

    Parameters
    ----------
    model_card : card.Card
        The card the queued task operates on.
    key : str
        Form key. The current widget values are read from
        ``st.session_state[f"{key}.title"]`` and
        ``st.session_state[f"{key}.content"]``.
    section_name : str
        Full (possibly nested) name of the section being edited.
    content : str
        Section content at the time the form was rendered.
    is_fig : bool
        Whether the section is a figure (file upload) rather than text.
    """
    # This is a very roundabout way to update the model card but it's necessary
    # because of how streamlit handles session state. Basically, there have to
    # be "key" arguments, which have to be retrieved from the session_state, as
    # they are up-to-date. Just getting the Python variables is not enough, as
    # they can be out of date.
    # key names must match with those used in form
    new_title = st.session_state[f"{key}.title"]
    new_content = st.session_state[f"{key}.content"]

    # determine if title is the same; only the last name component is editable
    old_title_split = split_subsection_names(section_name)
    new_title_split = old_title_split[:-1] + [new_title]
    is_title_same = old_title_split == new_title_split

    # determine if content is the same
    if is_fig:
        if isinstance(new_content, PlotSection):
            is_content_same = content == new_content
        else:
            # no upload means the figure was left unchanged
            is_content_same = not bool(new_content)
    else:
        is_content_same = content == new_content

    if is_title_same and is_content_same:
        return

    if is_fig:
        fpath = None
        if new_content:  # new figure uploaded
            # never trust the client's file name when building a local path
            fname = sanitize_upload_filename(new_content.name)
            fpath = tmp_path / fname
        task = UpdateFigureTask(
            model_card,
            key=key,
            old_name=section_name,
            new_name=new_title,
            data=new_content,
            path=fpath,
        )
    else:
        task = UpdateSectionTask(
            model_card,
            key=key,
            old_name=section_name,
            new_name=new_title,
            old_content=content,
            new_content=new_content,
        )
    st.session_state.task_state.add(task)
def _add_section(model_card: card.Card, key: str) -> None:
    """Queue a task adding a placeholder text section named ``<key>/Untitled``.

    Used as the callback of the "add section below" button.
    """
    placeholder = AddSectionTask(
        model_card,
        title=f"{key}/Untitled",
        content="[More Information Needed]",
    )
    st.session_state.task_state.add(placeholder)
def _add_figure(model_card: card.Card, key: str) -> None:
    """Queue a task adding a placeholder figure section named ``<key>/Untitled``.

    Used as the callback of the "add figure below" button. The figure content
    starts out as ``"cat.png"``; presumably a stand-in until the user uploads
    an image through the figure form (TODO confirm).
    """
    st.session_state.task_state.add(
        AddFigureTask(model_card, title=f"{key}/Untitled", content="cat.png")
    )
def _delete_section(model_card: card.Card, key: str) -> None:
    """Queue a task removing the section identified by ``key``.

    Used as the callback of a section's "delete" button.
    """
    st.session_state.task_state.add(DeleteSectionTask(model_card, key=key))
def _add_section_form(
    model_card: card.Card, key: str, section_name: str, old_title: str, content: str
) -> None:
    """Render the editing form of a text section.

    The form has a title input prefilled with ``old_title`` and a text area
    prefilled with ``content``. Its "Update" button calls
    ``_update_model_card`` with ``is_fig=False``.
    """
    # The widget keys are how _update_model_card reads the up-to-date values
    # back out of st.session_state, so they must match the names used there.
    title_key = f"{key}.title"
    content_key = f"{key}.content"
    with st.form(key, clear_on_submit=False):
        st.header(section_name)
        st.text_input("Section name", value=old_title, key=title_key)
        st.text_area("Content", value=content, key=content_key)
        st.form_submit_button(
            "Update",
            on_click=_update_model_card,
            args=(model_card, key, section_name, content, False),
        )
def _add_fig_form(
    model_card: card.Card, key: str, section_name: str, old_title: str, content: str
) -> None:
    """Render the editing form of a figure section.

    The form has a title input prefilled with ``old_title`` and an image
    uploader. ``content`` is not displayed; it is only passed on to the
    callback, which compares it to the new value. The "Update" button calls
    ``_update_model_card`` with ``is_fig=True``.
    """
    # The widget keys are how _update_model_card reads the up-to-date values
    # back out of st.session_state, so they must match the names used there.
    title_key = f"{key}.title"
    upload_key = f"{key}.content"
    with st.form(key, clear_on_submit=False):
        st.header(section_name)
        st.text_input("Section name", value=old_title, key=title_key)
        st.file_uploader("Upload image", key=upload_key)
        st.form_submit_button(
            "Update",
            on_click=_update_model_card,
            args=(model_card, key, section_name, content, True),
        )
def create_form_from_section(
    model_card: card.Card, key: str, section_name: str, content: str, is_fig: bool = False
) -> None:
    """Render the editing UI for a single section of the model card.

    Draws the section's form (image upload for figures, text area otherwise)
    followed by a row of three buttons. These queue tasks to delete the
    section, or to add a new text or figure section below it.

    Parameters
    ----------
    model_card : card.Card
        The card the queued tasks operate on.
    key : str
        Section key; also the prefix of every widget key created here.
    section_name : str
        Full, possibly nested, name of the section.
    content : str
        Current content of the section.
    is_fig : bool, default False
        Whether to render the figure form instead of the text form.
    """
    # Only the last component of a nested section name is editable.
    old_title = split_subsection_names(section_name)[-1]

    render_form = _add_fig_form if is_fig else _add_section_form
    render_form(
        model_card=model_card,
        key=key,
        section_name=section_name,
        old_title=old_title,
        content=content,
    )

    # (label, callback, widget key) for each button, left to right; the
    # delete label abbreviates long titles via arepr.
    button_specs = (
        (f"delete '{arepr.repr(old_title)}'", _delete_section, f"{key}.delete"),
        ("add section below", _add_section, f"{key}.add"),
        ("add figure below", _add_figure, f"{key}.fig"),
    )
    columns = st.columns([4, 2, 2])
    for column, (label, callback, widget_key) in zip(columns, button_specs):
        with column:
            st.button(
                label,
                on_click=callback,
                args=(model_card, key),
                key=widget_key,
            )
def display_sections(model_card: card.Card) -> None:
    """Render one editing form (plus action buttons) per section of the card."""
    # NOTE(review): relies on the card's private ``_data`` attribute.
    sections = iterate_key_section_content(model_card._data)
    for section_key, name, section_content, is_figure in sections:
        create_form_from_section(model_card, section_key, name, section_content, is_figure)
def display_model_card(model_card: card.Card) -> None:
    """Render the model card as markdown, with its metadata in an expander."""
    # Split the metadata off the rendered text so it can be shown separately.
    metadata, body = process_card_for_rendering(model_card.render())
    with st.expander("show metadata"):
        st.text(metadata)
    # NOTE(review): raw HTML contained in the card is rendered as-is.
    st.markdown(body, unsafe_allow_html=True)
def reset_model_card() -> None:
    """Undo every edit recorded in the task history.

    Used as the callback of the "Reset" button. Does nothing when there is
    no task history or no model card in the session state.
    """
    if "task_state" not in st.session_state:
        return
    if "model_card" not in st.session_state:
        # Nothing to reset without a card. Deleting the key here, as before,
        # always raised KeyError because the key is known to be absent.
        return

    while st.session_state.task_state.done_list:
        st.session_state.task_state.undo()
def delete_model_card() -> None:
    """Drop the current model card and reset the task history, if present.

    Used as the callback of the "Delete" button.
    """
    state = st.session_state
    if "model_card" in state:
        del state["model_card"]
    if "task_state" in state:
        state.task_state.reset()
def undo_last():
    """Revert the most recent edit; ``on_click`` callback of the UNDO button.

    The card is deliberately not rendered here. Streamlit runs widget
    callbacks before rerunning the script, so anything a callback draws ends
    up above the rest of the page. ``edit_input_form`` already renders the
    card at the end of the rerun, so drawing it here showed it twice.
    """
    st.session_state.task_state.undo()
def redo_last():
    """Re-apply the most recently undone edit; callback of the REDO button.

    The card is deliberately not rendered here. Streamlit runs widget
    callbacks before rerunning the script, so anything a callback draws ends
    up above the rest of the page. ``edit_input_form`` already renders the
    card at the end of the rerun, so drawing it here showed it twice.
    """
    st.session_state.task_state.redo()
def add_download_model_card_button():
    """Render a "Save (md)" button that downloads the rendered model card.

    The button is disabled, with empty data, when no model card is stored
    in the session state.
    """
    model_card = st.session_state.get("model_card")
    download_disabled = not bool(model_card)
    # Only render an actual card: calling ``render()`` on the ``None``
    # returned for a missing card raised AttributeError, even though the
    # button was meant to be shown disabled.
    data = "" if download_disabled else model_card.render()
    st.download_button("Save (md)", data=data, disabled=download_disabled)
def edit_input_form():
    """Render the model card editor.

    The sidebar holds the history controls (UNDO/REDO with their counts,
    Reset), the Save/Delete buttons and, when a card is loaded, one editing
    form per section. The rendered card is shown in the main area.
    """
    session = st.session_state
    if "task_state" not in session:
        session.task_state = TaskState()

    with st.sidebar:
        tasks = session.task_state
        undo_col, redo_col, reset_col, *_ = st.columns([2, 2, 2, 2])
        with undo_col:
            st.button(
                f"UNDO ({len(tasks.done_list)})",
                on_click=undo_last,
                disabled=not tasks.done_list,
            )
        with redo_col:
            st.button(
                f"REDO ({len(tasks.undone_list)})",
                on_click=redo_last,
                disabled=not tasks.undone_list,
            )
        with reset_col:
            st.button("Reset", on_click=reset_model_card)

        save_col, delete_col, *_ = st.columns([2, 2, 2, 2])
        with save_col:
            add_download_model_card_button()
        with delete_col:
            st.button("Delete", on_click=delete_model_card)

        if "model_card" in session:
            display_sections(session.model_card)

    if "model_card" in session:
        display_model_card(session.model_card)