Spaces:
Runtime error
Runtime error
import io | |
import os | |
import shutil | |
import zipfile | |
import gradio as gr | |
import requests | |
from huggingface_hub import create_repo, upload_folder, whoami | |
from convert import convert_full_checkpoint | |
# Working directory for every conversion job; it is wiped at the start of each run.
MODELS_DIR = "models/"
# Where the downloaded source checkpoint is written.
CKPT_FILE = os.path.join(MODELS_DIR, "model.ckpt")
# Output directory for the converted diffusers model.
HF_MODEL_DIR = os.path.join(MODELS_DIR, "diffusers_model")
# Archive of HF_MODEL_DIR offered for download.
ZIP_FILE = os.path.join(MODELS_DIR, "model.zip")
def download_ckpt(url, out_path, timeout=60, chunk_size=8192):
    """Stream the resource at *url* into the local file *out_path*.

    The request is sent and its status checked before *out_path* is opened,
    so an HTTP error does not create (or truncate) the output file.

    Args:
        url: http(s) URL to download.
        out_path: destination file path; overwritten on success.
        timeout: seconds passed to ``requests.get``. With ``stream=True`` the
            read timeout applies to gaps between received bytes, not to the
            whole transfer, so large downloads are not cut off while data flows.
        chunk_size: number of bytes requested per ``iter_content`` chunk.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if the server stops responding for *timeout* seconds.
    """
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(out_path, "wb") as out_file:
            for chunk in response.iter_content(chunk_size=chunk_size):
                out_file.write(chunk)
def zip_model(model_path, zip_path):
    """Write every file under *model_path* into the zip archive *zip_path*.

    Member names are made relative to the parent of *model_path*, so each
    entry keeps the model directory's own name as its leading component.
    Files are stored without compression (ZIP_STORED); empty directories
    are not recorded.
    """
    archive_root = os.path.join(model_path, "..")
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as archive:
        for dirpath, _subdirs, filenames in os.walk(model_path):
            for name in filenames:
                full_path = os.path.join(dirpath, name)
                archive.write(full_path, os.path.relpath(full_path, archive_root))
def download_checkpoint_and_config(ckpt_url, config_url):
    """Validate the URLs, fetch the config and download the checkpoint.

    Args:
        ckpt_url: http(s) URL of the checkpoint; surrounding whitespace is ignored.
        config_url: http(s) URL of a config YAML, or "" to use the local
            ``original_config.yaml``; surrounding whitespace is ignored.

    Returns:
        ``(CKPT_FILE, config_file)``. ``config_file`` is an ``io.BytesIO`` with
        the downloaded config bytes, or a text-mode handle on
        ``original_config.yaml``. The caller owns the stream and must close it.

    Raises:
        ValueError: if either URL is not an http(s) URL (an empty config URL
            is allowed).
        requests.HTTPError: if the config or checkpoint request fails.
    """
    ckpt_url = ckpt_url.strip()
    config_url = config_url.strip()
    if not ckpt_url.startswith(("http://", "https://")):
        raise ValueError("Invalid checkpoint URL")

    # The config is resolved before the (large) checkpoint download so that a
    # bad config URL fails fast.
    # NOTE(review): the two branches return different stream types (bytes vs
    # text); presumably convert_full_checkpoint accepts both - confirm.
    if config_url.startswith(("http://", "https://")):
        response = requests.get(config_url, timeout=60)
        response.raise_for_status()
        config_file = io.BytesIO(response.content)
    elif config_url != "":
        raise ValueError("Invalid config URL")
    else:
        config_file = open("original_config.yaml", "r")

    try:
        download_ckpt(ckpt_url, CKPT_FILE)
    except BaseException:
        # Don't leak the config handle when the checkpoint download fails.
        config_file.close()
        raise
    return CKPT_FILE, config_file
def convert_and_download(ckpt_url, config_url, scheduler_type, extract_ema):
    """Convert a remote checkpoint to diffusers format and zip the result.

    Args:
        ckpt_url: http(s) URL of the checkpoint.
        config_url: http(s) URL of a config YAML, or "" for the local
            ``original_config.yaml``.
        scheduler_type: scheduler name forwarded to ``convert_full_checkpoint``.
        extract_ema: ``"EMA"`` to extract EMA weights; any other value selects
            non-EMA weights.

    Returns:
        Path of the zip archive (``ZIP_FILE``).
    """
    # Start from a clean working directory so files from a previous run
    # cannot end up in this archive.
    shutil.rmtree(MODELS_DIR, ignore_errors=True)
    os.makedirs(HF_MODEL_DIR)
    ckpt_path, config_file = download_checkpoint_and_config(ckpt_url, config_url)
    # Ensure the config stream is closed even if conversion raises.
    with config_file:
        convert_full_checkpoint(
            ckpt_path,
            config_file,
            scheduler_type=scheduler_type,
            extract_ema=(extract_ema == "EMA"),
            output_path=HF_MODEL_DIR,
        )
    zip_model(HF_MODEL_DIR, ZIP_FILE)
    return ZIP_FILE
def convert_and_upload(
    ckpt_url, config_url, scheduler_type, extract_ema, token, model_name
):
    """Convert a remote checkpoint and upload it to ``<user>/<model_name>``.

    Args:
        ckpt_url: http(s) URL of the checkpoint.
        config_url: http(s) URL of a config YAML, or "" for the local
            ``original_config.yaml``.
        scheduler_type: scheduler name forwarded to ``convert_full_checkpoint``.
        extract_ema: ``"EMA"`` to extract EMA weights; any other value selects
            non-EMA weights.
        token: Hugging Face token with write access.
        model_name: repository name created under the token owner's namespace;
            surrounding whitespace is ignored.

    Returns:
        A Markdown status string: a success message with the repo link, or
        ``"#### Error: ..."``. Exceptions are reported this way rather than
        raised, so they are shown in the UI.
    """
    shutil.rmtree(MODELS_DIR, ignore_errors=True)
    os.makedirs(HF_MODEL_DIR)
    try:
        model_name = model_name.strip()
        if not model_name:
            raise ValueError("Please provide a model name")
        # Resolve the token's owner before the potentially huge checkpoint
        # download, so an invalid token fails fast.
        username = whoami(token)["name"]
        repo_name = f"{username}/{model_name}"
        ckpt_path, config_file = download_checkpoint_and_config(ckpt_url, config_url)
        with config_file:
            repo_url = create_repo(repo_name, token=token, exist_ok=True)
            convert_full_checkpoint(
                ckpt_path,
                config_file,
                scheduler_type=scheduler_type,
                extract_ema=(extract_ema == "EMA"),
                output_path=HF_MODEL_DIR,
            )
        upload_folder(
            repo_id=repo_name,
            folder_path=HF_MODEL_DIR,
            token=token,
            commit_message="Upload diffusers weights",
        )
    except Exception as e:
        return f"#### Error: {e}"
    return f"#### Success! Model uploaded to [{repo_url}]({repo_url})"
# Centered banner image shown at the top of the page (rendered via gr.HTML).
# The name's spelling is kept as-is because the UI code references it.
TTILE_IMAGE = """
<div
  style="
    display: block;
    margin-left: auto;
    margin-right: auto;
    width: 50%;
  "
>
  <img src="https://github.com/huggingface/diffusers/raw/main/docs/source/imgs/diffusers_library.jpg" alt="Diffusers library"/>
</div>
"""
# Page heading rendered with gr.HTML. Raw HTML does not interpret Markdown,
# so the file extension is wrapped in <code> (backticks would show literally).
TITLE = """
<div
  style="
    display: inline-flex;
    align-items: center;
    text-align: center;
    max-width: 1400px;
    gap: 0.8rem;
    font-size: 2.2rem;
  "
>
  <h1 style="font-weight: 900; margin-bottom: 10px; margin-top: 10px;">
    Convert Stable Diffusion <code>.ckpt</code> files to Hugging Face Diffusers 🔥
  </h1>
</div>
"""
# Gradio UI: collect the checkpoint/config URLs and conversion options, then
# either convert-and-download a zip or convert-and-upload to the Hub,
# depending on the selected radio option.
with gr.Blocks() as interface:
    gr.HTML(TTILE_IMAGE)
    gr.HTML(TITLE)
    gr.Markdown("We will perform all of the checkpoint surgery for you, and create a clean diffusers model!")
    gr.Markdown("This converter will also remove any pickled code from third-party checkpoints.")

    with gr.Row():
        with gr.Column(scale=50):
            # Backticks stop Markdown from treating "<model>" as an HTML tag
            # (which would make it disappear from the rendered header).
            gr.Markdown("### 1. Paste a URL to your `<model>.ckpt` file")
            ckpt_url = gr.Textbox(
                max_lines=1,
                label="URL to <model>.ckpt",
                placeholder="https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned.ckpt",
            )
        with gr.Column(scale=50):
            gr.Markdown("### (Optional) paste a URL to your `<config>.yaml` file")
            config_url = gr.Textbox(
                max_lines=1,
                label="URL to <config>.yaml",
                placeholder="https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-inference.yaml",
            )
            gr.Markdown(
                "**If you don't provide a config file, we'll try to use"
                " [v1-inference.yaml](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-inference.yaml).**"
            )

    with gr.Accordion("Advanced Settings"):
        scheduler_type = gr.Dropdown(
            label="Choose a scheduler type (if not sure, keep the PNDM default)",
            choices=["PNDM", "K-LMS", "Euler", "EulerAncestral", "DDIM"],
            value="PNDM",
        )
        extract_ema = gr.Radio(
            label=(
                "EMA weights usually yield higher quality images for inference."
                " Non-EMA weights are usually better to continue fine-tuning."
            ),
            choices=["EMA", "Non-EMA"],
            value="EMA",
            interactive=True,
        )

    gr.Markdown("### 2. Choose what to do with the converted model")
    # type="index": handlers receive the position of the selected choice.
    model_choice = gr.Radio(
        show_label=False,
        choices=[
            "Download the model as an archive",
            "Host the model on the Hugging Face Hub",
            # "Submit a PR with the model for an existing Hub repository",
        ],
        type="index",
        value="Download the model as an archive",
        interactive=True,
    )

    download_panel = gr.Column(visible=True)
    upload_panel = gr.Column(visible=False)
    # pr_panel = gr.Column(visible=False)
    # TODO: the "Submit a PR" option (index 2) is disabled; to re-enable it,
    # add pr_panel to the outputs of the visibility handler below.

    # Show only the panel matching the selected option.
    model_choice.change(
        fn=lambda i: (gr.update(visible=(i == 0)), gr.update(visible=(i == 1))),
        inputs=model_choice,
        outputs=[download_panel, upload_panel],
    )

    with download_panel:
        gr.Markdown("### 3. Convert and download")
        down_btn = gr.Button("Convert")
        output_file = gr.File(
            label="Download the converted model",
            type="binary",
            interactive=False,
            visible=True,
        )
        down_btn.click(
            fn=convert_and_download,
            inputs=[ckpt_url, config_url, scheduler_type, extract_ema],
            outputs=output_file,
        )

    with upload_panel:
        gr.Markdown("### 3. Convert and host on the Hub")
        gr.Markdown(
            "This will create a new repository if it doesn't exist yet, and upload the model to the Hugging Face Hub.\n\n"
            "Paste a WRITE token from [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)"
            " and make up a model name."
        )
        up_token = gr.Textbox(
            max_lines=1,
            label="Hugging Face token",
            # Mask the secret so it is not displayed on screen.
            type="password",
        )
        up_model_name = gr.Textbox(
            max_lines=1,
            label="Hub model name (e.g. `artistic-diffusion-v1`)",
            placeholder="my-awesome-model",
        )
        upload_btn = gr.Button("Convert and upload")
        # gr.Group rather than gr.Box: gr.Box was removed in Gradio 4, while
        # gr.Group exists in both 3.x and 4.x.
        with gr.Group():
            output_text = gr.Markdown()
        upload_btn.click(
            fn=convert_and_upload,
            inputs=[
                ckpt_url,
                config_url,
                scheduler_type,
                extract_ema,
                up_token,
                up_model_name,
            ],
            outputs=output_text,
        )

    # with pr_panel:
    #     gr.Markdown("### 3. Convert and submit as a PR")
    #     gr.Markdown(
    #         "This will open a Pull Request on the original model repository, if it already exists on the Hub.\n\n"
    #         "Paste a write-access token from [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)"
    #         " and paste an existing model id from the Hub in the `username/model-name` form."
    #     )
    #     pr_token = gr.Textbox(
    #         max_lines=1,
    #         label="Hugging Face token",
    #     )
    #     pr_model_name = gr.Textbox(
    #         max_lines=1,
    #         label="Hub model name (e.g. `diffuser/artistic-diffusion-v1`)",
    #         placeholder="diffuser/my-awesome-model",
    #     )
    #
    #     btn = gr.Button("Convert and open a PR")
    #     output = gr.Markdown(label="Output")
# Run one job at a time: both convert handlers delete and rewrite the same
# fixed MODELS_DIR paths, so concurrent jobs would clobber each other's files.
# NOTE(review): `concurrency_count` was removed from `Blocks.queue()` in
# Gradio 4 (replaced by `default_concurrency_limit`); confirm the pinned
# Gradio version still accepts it.
interface.queue(concurrency_count=1)
interface.launch()