# Author: mrcuddle
# Last commit: "Update app.py" (c4a29da, verified)
import gradio as gr
import subprocess
import os
import logging
from pathlib import Path
import spaces
@spaces.GPU()
def merge_and_upload(base_model, model_to_merge, scaling_factor, weight_drop_prob, repo_name, token, commit_message):
    """Merge two Hub models with ``hf_merge.py`` and upload the result.

    The actual merge/upload work is delegated to the ``hf_merge.py`` script,
    which is run as a child process with the same Python interpreter that is
    running this app.

    Args:
        base_model: Hub repo id of the base model (positional arg to the script).
        model_to_merge: Hub repo id of the model merged into the base.
        scaling_factor: Forwarded to the script's ``-lambda`` option.
        weight_drop_prob: Forwarded to the script's ``-p`` option.
        repo_name: Target repo id, forwarded via ``--repo``.
        token: Hugging Face write token, forwarded via ``--token``.
        commit_message: Forwarded via ``--commit-message``; a blank value
            falls back to "Upload merged model".

    Returns:
        A 3-tuple ``(repo_url, status_message, log_output)``. ``repo_url`` is
        ``None`` when input validation fails, the script cannot be launched,
        or the script exits with a non-zero status.
    """
    import sys

    # basicConfig is a no-op once the root logger already has handlers.
    logging.basicConfig(level=logging.INFO)

    # Text boxes can deliver None or values with stray copy/paste whitespace.
    base_model = (base_model or "").strip()
    model_to_merge = (model_to_merge or "").strip()
    repo_name = (repo_name or "").strip()
    token = (token or "").strip()
    commit_message = (commit_message or "").strip() or "Upload merged model"

    missing = [
        label
        for label, value in (
            ("Base Model", base_model),
            ("Merge Model", model_to_merge),
            ("New Model", repo_name),
            ("HF write token", token),
        )
        if not value
    ]
    if missing:
        return None, f"Missing required field(s): {', '.join(missing)}", ""

    # Locate the script next to this file so the call does not depend on the CWD.
    script = Path(__file__).resolve().parent / "hf_merge.py"
    # NOTE(review): the token travels on argv and is visible to other local
    # processes (e.g. `ps`); consider an env var if hf_merge.py supports one.
    command = [
        sys.executable, str(script),
        base_model,
        model_to_merge,
        "-p", str(weight_drop_prob),
        "-lambda", str(scaling_factor),
        "--token", token,
        "--repo", repo_name,
        "--commit-message", commit_message,
        "-U",
    ]

    try:
        result = subprocess.run(command, capture_output=True, text=True)
    except OSError as err:
        logging.exception("Failed to launch hf_merge.py")
        return None, f"Error in merging models: {err}", str(err)

    log_output = f"{result.stdout}\n{result.stderr}\n"
    if result.stdout:
        logging.info(result.stdout)
    if result.stderr:
        # stderr may hold warnings/progress output even on success, so only
        # escalate to ERROR when the script actually failed.
        logging.log(logging.ERROR if result.returncode else logging.WARNING, result.stderr)

    if result.returncode != 0:
        return None, f"Error in merging models: {result.stderr}", log_output

    # Assumes hf_merge.py (run with -U) uploaded to `repo_name` on success.
    repo_url = f"https://huggingface.co/{repo_name}"
    return repo_url, "Model merged and uploaded successfully!", log_output
# Define the Gradio interface
with gr.Blocks(theme="Ytheme/Minecraft", fill_width=True, delete_cache=(60, 3600)) as demo:
    # One Markdown component so the bullet items render as a single list.
    gr.Markdown(
        "# SuperMario Safetensors Merger\n\n"
        "Combine any two models using a Super Mario merge (DARE)\n\n"
        "Based on: https://github.com/martyn/safetensors-merge-supermario\n\n"
        "Works with:\n\n"
        "* Stable Diffusion (1.5, XL/XL Turbo)\n"
        "* LLMs (Mistral, Llama, etc) (also works with Llava, Vision models)\n"
        "* LoRAs (must be same size)\n"
        "* Any two homologous models\n"
    )
    with gr.Column():
        with gr.Row():
            # type="password" keeps the write token masked in the browser.
            token = gr.Textbox(label="Your HF write token", placeholder="hf_...", value="", max_lines=1, type="password")
        with gr.Row():
            base_model = gr.Textbox(label="Base Model", placeholder="meta-llama/Llama-3.2-11B-Vision-Instruct", info="Safetensors format")
        with gr.Row():
            model_to_merge = gr.Textbox(label="Merge Model", placeholder="Qwen/Qwen2.5-Coder-7B-Instruct", info="Safetensors or .bin")
        with gr.Row():
            repo_name = gr.Textbox(label="New Model", placeholder="Llama-Qwen-Vision_Instruct", info="your-username/new-model-name", value="", max_lines=1)
        with gr.Row():
            scaling_factor = gr.Slider(minimum=0, maximum=10, value=3.0, label="Scaling Factor")
        with gr.Row():
            weight_drop_prob = gr.Slider(minimum=0, maximum=1, value=0.3, label="Weight Drop Probability")
        with gr.Row():
            commit_message = gr.Textbox(label="Commit Message", value="Upload merged model", max_lines=1)
        repo_url = gr.Markdown(label="Repository URL")
        output = gr.Textbox(label="Output")
        # merge_and_upload returns (repo_url, status, logs); the logs need a
        # third output component or Gradio discards them.
        logs = gr.Textbox(label="Logs", lines=10, max_lines=20)
        gr.Button("Merge").click(
            merge_and_upload,
            inputs=[base_model, model_to_merge, scaling_factor, weight_drop_prob, repo_name, token, commit_message],
            outputs=[repo_url, output, logs],
        )

if __name__ == "__main__":
    demo.launch()