Benjamin Bossan
Initial commit
31a1df6
raw
history blame
5.12 kB
from __future__ import annotations
from pathlib import Path
from uuid import uuid4
from skops import card
from skops.card._model_card import PlotSection, split_subsection_names
from streamlit.runtime.uploaded_file_manager import UploadedFile
class Task:
    """Base class for a reversible operation on a model card.

    The ``do`` and ``undo`` methods defined here only raise
    ``NotImplementedError``; concrete tasks are expected to override both.
    """

    def __init__(self, model_card: card.Card) -> None:
        # The card object that the task operates on.
        self.model_card = model_card

    def do(self) -> None:
        """Apply the operation. Not implemented on the base class."""
        raise NotImplementedError()

    def undo(self) -> None:
        """Revert the operation. Not implemented on the base class."""
        raise NotImplementedError()
class TaskState:
    """Undo/redo history kept as two stacks of tasks.

    ``done_list`` holds the tasks that are currently applied and
    ``undone_list`` holds the tasks that were reverted and can be re-applied.
    In both lists the most recent entry is the last element.
    """

    def __init__(self) -> None:
        self.done_list: list[Task] = []
        self.undone_list: list[Task] = []

    def undo(self) -> None:
        """Revert the most recently applied task, if there is one."""
        if self.done_list:
            latest = self.done_list.pop()
            latest.undo()
            self.undone_list.append(latest)

    def redo(self) -> None:
        """Re-apply the most recently reverted task, if there is one."""
        if self.undone_list:
            latest = self.undone_list.pop()
            latest.do()
            self.done_list.append(latest)

    def add(self, task: Task) -> None:
        """Execute ``task`` and record it as the newest applied task.

        Adding a task empties the redo stack, so previously undone tasks can
        no longer be re-applied.
        """
        task.do()
        self.done_list.append(task)
        self.undone_list.clear()

    def reset(self) -> None:
        """Forget the whole history without calling ``do`` or ``undo``."""
        # Clear in place so references to the list objects stay valid.
        for history in (self.done_list, self.undone_list):
            history.clear()
class AddSectionTask(Task):
    """Task that adds a new section to the model card.

    The section is stored under ``key``, which is the given title followed by
    a space and the first six characters of a freshly generated UUID
    (presumably to keep keys unique when titles repeat -- confirm with the
    caller).
    """

    def __init__(
        self,
        model_card: card.Card,
        title: str,
        content: str,
    ) -> None:
        self.model_card = model_card
        self.title = title
        self.content = content
        random_suffix = str(uuid4())[:6]
        self.key = title + " " + random_suffix

    def do(self) -> None:
        """Add the section, then set its title without the random suffix."""
        new_entry = {self.key: self.content}
        self.model_card.add(**new_entry)
        section = self.model_card.select(self.key)
        # Only the last name returned by split_subsection_names becomes the
        # section's own title.
        name_parts = split_subsection_names(self.title)
        section.title = name_parts[-1]

    def undo(self) -> None:
        """Delete the section that ``do`` added."""
        self.model_card.delete(self.key)
class AddFigureTask(Task):
    """Task that adds a new figure section to the model card.

    Works like adding a text section, except that the entry is created with
    ``add_plot`` and the resulting section is flagged with ``is_fig = True``.
    ``content`` is passed to ``add_plot`` unchanged (presumably the path of
    the image file -- confirm against the skops API).
    """

    def __init__(
        self,
        model_card: card.Card,
        title: str,
        content: str,
    ) -> None:
        self.model_card = model_card
        self.title = title
        self.content = content
        # Key = title + space + first six characters of a new UUID.
        random_suffix = str(uuid4())[:6]
        self.key = title + " " + random_suffix

    def do(self) -> None:
        """Add the plot, set its title and mark the section as a figure."""
        new_entry = {self.key: self.content}
        self.model_card.add_plot(**new_entry)
        section = self.model_card.select(self.key)
        # Only the last name returned by split_subsection_names becomes the
        # section's own title.
        name_parts = split_subsection_names(self.title)
        section.title = name_parts[-1]
        # Extra attribute set on the section object by this module.
        section.is_fig = True  # type: ignore

    def undo(self) -> None:
        """Delete the figure section that ``do`` added."""
        self.model_card.delete(self.key)
class DeleteSectionTask(Task):
    """Task that hides a section of the model card.

    The section is not removed from the card; only its ``visible`` attribute
    is switched off, and switched back on when the task is undone.
    """

    def __init__(
        self,
        model_card: card.Card,
        key: str,
    ) -> None:
        self.model_card = model_card
        self.key = key

    def _set_visible(self, flag: bool) -> None:
        # Look the section up on every call rather than caching it.
        target = self.model_card.select(self.key)
        target.visible = flag

    def do(self) -> None:
        """Mark the section as not visible."""
        self._set_visible(False)

    def undo(self) -> None:
        """Mark the section as visible again."""
        self._set_visible(True)
class UpdateSectionTask(Task):
    """Task that changes the title and content of an existing section.

    Both the old and the new name/content are stored, so the change can be
    applied with ``do`` and reverted with ``undo``.
    """

    def __init__(
        self,
        model_card: card.Card,
        key: str,
        old_name: str,
        new_name: str,
        old_content: str,
        new_content: str,
    ) -> None:
        self.model_card = model_card
        self.key = key
        self.old_name, self.new_name = old_name, new_name
        self.old_content, self.new_content = old_content, new_content

    def _apply(self, name: str, content: str) -> None:
        # Shared by do() and undo(): write a title/content pair to the section.
        target = self.model_card.select(self.key)
        # Only the last name returned by split_subsection_names is the title.
        target.title = split_subsection_names(name)[-1]
        target.content = content

    def do(self) -> None:
        """Give the section its new title and content."""
        self._apply(self.new_name, self.new_content)

    def undo(self) -> None:
        """Restore the section's old title and content."""
        self._apply(self.old_name, self.old_content)
class UpdateFigureTask(Task):
    """Task that changes the title and, optionally, the image of a figure.

    Parameters
    ----------
    model_card
        The card that contains the figure section.
    key
        Key under which the section is selected from the card.
    old_name, new_name
        Section names before and after the update; only the last name
        returned by ``split_subsection_names`` is used as the title.
    data
        Newly uploaded image, or a falsy value to keep the current image.
    path
        File the new image is written to. Must be set when ``data`` is given.
    """

    def __init__(
        self,
        model_card: card.Card,
        key: str,
        old_name: str,
        new_name: str,
        data: UploadedFile | None,
        path: Path | None,
    ) -> None:
        self.model_card = model_card
        self.key = key
        self.old_name = old_name
        self.new_name = new_name
        # Content of the section at the time the task is created; this is
        # what undo() restores.
        self.old_data = self.model_card.select(self.key).content
        self.path = path
        # Without an upload, new_data equals old_data, which do() and undo()
        # treat as "image unchanged".
        self.new_data = data if data else self.old_data

    def do(self) -> None:
        """Apply the new title and, if a new image was given, the new image.

        All steps that can fail (writing the file, formatting the plot
        section) run before the section is modified, so a failure does not
        leave the card with a changed title that is not part of any history.
        """
        section = self.model_card.select(self.key)
        new_title = split_subsection_names(self.new_name)[-1]

        if self.new_data == self.old_data:  # image is same
            section.title = self.title = new_title
            return

        # write figure
        # note: this can still be the same image if the image is a file, there
        # is no test to check, e.g., the hash of the image
        with open(self.path, "wb") as f:
            f.write(self.new_data.getvalue())
        new_content = PlotSection(
            alt_text=self.new_data.name,
            path=self.path,
        ).format()

        # Nothing above raised: now it is safe to touch the section.
        section.title = self.title = new_title
        section.content = new_content

    def undo(self) -> None:
        """Restore the old title and, if the image was replaced, the old content.

        The file written by ``do`` is removed before the section is modified,
        so a failing removal leaves the section untouched.
        """
        section = self.model_card.select(self.key)
        old_title = split_subsection_names(self.old_name)[-1]

        if self.new_data == self.old_data:  # image is same
            section.title = old_title
            return

        self.path.unlink(missing_ok=True)
        section.title = old_title
        section.content = self.old_data