vodkaslime's picture
branding and renaming
ff68c49 unverified
raw
history blame contribute delete
No virus
3.94 kB
import os
import shutil
import subprocess
import streamlit as st
import uuid
from git import Repo
import huggingface_hub
# Git repository that holds the conversion backend; cloned by init().
BACKEND_REPO_URL = "https://github.com/vodkaslime/ctranslate2-converter-backend"
# All working paths live under the current user's home directory.
HOME_DIR = os.path.expanduser("~")
# Local checkout location of the backend repository.
BACKEND_DIR = os.path.join(HOME_DIR, "backend")
# Entry-point script of the backend, run as a subprocess for each conversion.
BACKEND_SCRIPT = os.path.join(BACKEND_DIR, "main.py")
# Parent directory for the per-conversion output directories.
MODEL_ROOT_DIR = os.path.join(HOME_DIR, "models")
# Page title shown at the top of the Streamlit app.
st.title(":wave: Tabby Model Converter")
@st.cache_resource
def init():
    """Clone the converter backend and install its Python requirements.

    Does nothing when ``BACKEND_DIR`` already exists. Otherwise clones
    ``BACKEND_REPO_URL`` into ``BACKEND_DIR`` and runs ``pip install -r`` on
    the checkout's ``requirements.txt``.

    Any failure is reported with ``st.error`` instead of being raised. The
    (possibly partial) checkout is removed first, so the existence check
    above does not mistake a broken checkout for a finished installation.
    """
    if os.path.exists(BACKEND_DIR):
        return

    requirements_file = os.path.join(BACKEND_DIR, "requirements.txt")
    try:
        Repo.clone_from(BACKEND_REPO_URL, BACKEND_DIR)
        subprocess.check_call(["pip", "install", "-r", requirements_file])
    except Exception as e:
        # The clone can fail before the directory is ever created. Without
        # ignore_errors, rmtree would raise FileNotFoundError here, hide the
        # real error and skip the st.error call below.
        shutil.rmtree(BACKEND_DIR, ignore_errors=True)
        st.error(f"error initializing backend: {e}")
def convert_and_upload_model(
    model,
    output_dir,
    inference_mode,
    prompt_template,
    huggingface_token,
    upload_mode,
    new_model,
):
    """Convert a model with the backend script and upload the result.

    Every failure is reported in the Streamlit UI (``st.error``) and ends the
    function early; nothing is raised to the caller for the handled cases.

    Args:
        model: Source model name, passed to the backend as ``--model``.
        output_dir: Directory passed to the backend as ``--output_dir`` and
            uploaded afterwards.
        inference_mode: Value passed to the backend as ``--inference_mode``.
        prompt_template: Optional; passed as ``--prompt_template`` only when
            non-empty.
        huggingface_token: Token used to log in and to upload.
        upload_mode: When equal to ``"new repo"``, the target repo is created
            before uploading; any other value uploads to ``new_model`` as is.
        new_model: Repo id the converted folder is uploaded to.

    Returns:
        None in all cases.
    """
    # Verify parameters
    if not model:
        st.error("Must provide a model name")
        return
    if not new_model:
        st.error("Must provide a new model name where the conversion will upload to")
        return
    if not huggingface_token:
        st.error("Must provide a huggingface token")
        return

    command = [
        "python",
        BACKEND_SCRIPT,
        "--model",
        model,
        "--output_dir",
        output_dir,
        "--inference_mode",
        inference_mode,
    ]
    if prompt_template:
        command += ["--prompt_template", prompt_template]

    # Login on behalf of user. A rejected token is reported like every other
    # failure here instead of surfacing as an uncaught traceback.
    # NOTE(review): login() presumably stores the token for the whole
    # process/machine, which would be shared between users - confirm.
    try:
        huggingface_hub.login(huggingface_token)
    except Exception as e:
        st.error(f"Error logging in to Hugging Face: {e}")
        return

    # Handle model conversion
    try:
        with st.spinner("Converting model"):
            subprocess.check_call(command)
    except subprocess.CalledProcessError as e:
        st.error(f"Error converting model to Tabby compatible format: {e}")
        link = f"https://huggingface.co/{model}"
        st.warning(
            f"Note: do you have access to the model? If not, visit the model page at {link} and request for access"
        )
        return
    st.success("Model successfully converted")

    # Handle model upload
    try:
        with st.spinner("Uploading converted model"):
            # Pass the token explicitly so the upload uses this user's
            # credentials rather than whatever login is currently stored.
            api = huggingface_hub.HfApi(token=huggingface_token)
            if upload_mode == "new repo":
                api.create_repo(new_model)
            api.upload_folder(folder_path=output_dir, repo_id=new_model)
    except Exception as e:
        st.error(f"Error uploading model: {e}")
        return
    st.success("Model successfully uploaded.")
def clean_up(output_dir):
    """Remove the working directory of a conversion run.

    Args:
        output_dir: Path of the directory to delete.

    Returns:
        None. A missing directory is treated as "nothing to clean" and is
        skipped silently (it is never created when the run stops at
        parameter validation). A removal failure is shown with ``st.error``;
        the success message appears only after an actual successful removal.
    """
    if not os.path.exists(output_dir):
        return

    try:
        with st.spinner("Cleaning up"):
            shutil.rmtree(output_dir)
    except Exception as e:
        st.error(f"Error removing work dir: {e}")
        # Do not claim success right after reporting a failure.
        return
    st.success("Cleaning up finished")
# Make sure the conversion backend is cloned and installed before the UI runs.
init()

# --- Input widgets -------------------------------------------------------
model = st.text_input("Model name", placeholder="Salesforce/codet5p-220m")
inference_mode = st.radio(
    "Inference mode",
    ("causallm", "seq2seq"),
)
prompt_template = st.text_input("Prompt template")
huggingface_token = st.text_input(
    "Hugging face token (must be writable token)", type="password"
)
upload_mode = st.radio(
    "Choose if you want to create a new model repo or push a commit to existing repo",
    ("new repo", "existing repo"),
)
new_model = st.text_input(
    "The new model name that the model is going to be converted to",
    placeholder="TabbyML/T5P-220M",
)
convert_button = st.button("Convert model", use_container_width=True)

# --- Conversion run ------------------------------------------------------
if convert_button:
    # Each click gets its own uniquely named working directory.
    # (Named run_id rather than id so the builtin id() is not shadowed.)
    run_id = uuid.uuid4()
    output_dir = os.path.join(MODEL_ROOT_DIR, str(run_id))
    try:
        # Try converting and uploading model
        convert_and_upload_model(
            model,
            output_dir,
            inference_mode,
            prompt_template,
            huggingface_token,
            upload_mode,
            new_model,
        )
    finally:
        # Clean up the conversion even when the step above raised, so
        # converted model files are not left behind on disk.
        clean_up(output_dir)