import io

import safetensors
import streamlit.file_util
from safetensors.torch import serialize
from streamlit.runtime.uploaded_file_manager import UploadedFile

from tools import lora_tools, torch_tools

# https://huggingface.co/docs/hub/spaces-config-reference

streamlit.title("Lora and Embedding Tools")

# Precision every tool below passes on when producing its output weights.
output_dtype = streamlit.radio("Save Precision", ["float16", "float32", "bfloat16"], index=0)

streamlit.container()
col1, col2 = streamlit.columns(2, gap="medium")


def completed_download_callback():
    """Clear every tool's download-button placeholder.

    NOTE(review): no download button in this script passes this function as a
    callback (there is no ``on_click=``), so it is currently unused.
    """
    ui_filedownload_rescale.empty()
    ui_filedownload_stripclip.empty()
    ui_filedownload_ckpt.empty()


def _offer_download(placeholder, data, file_name, label):
    """Render a primary download button for ``data`` inside ``placeholder``.

    Args:
        placeholder: The ``streamlit.empty()`` slot reserved for a tool, so the
            button appears in that tool's column.
        data: Serialized bytes (the output of ``safetensors.torch.save``).
        file_name: File name suggested to the browser for the download.
        label: Caption shown on the button.
    """
    # BytesIO(data) starts at position 0, so no explicit write/seek is needed.
    placeholder.download_button(label, io.BytesIO(data), file_name, type="primary")


def _base_name(uploaded):
    """Return the uploaded file's name with its last extension removed."""
    return uploaded.name.rsplit(".", 1)[0]


with col1:
    # - A tool for rescaling the strength of Lora weights
    streamlit.html("\n\nRescale Lora Strength\n\n")
    ui_fileupload_rescale = streamlit.file_uploader("Upload a safetensors lora", key="fileupload_rescale", type=[".safetensors"])  # type: UploadedFile | None
    new_scale_factor = streamlit.number_input("Scale Factor", value=1.0, step=0.01, max_value=100.0, min_value=0.01)
    # Preallocate the download button slot here so it renders in this column;
    # the button itself is added after processing below.
    ui_filedownload_rescale = streamlit.empty()

with col2:
    # - A tool for removing CLIP parameters from a Lora file
    streamlit.html("\n\nRemove CLIP Parameters\n\n")
    ui_fileupload_stripclip = streamlit.file_uploader("Upload a safetensors lora", key="fileupload_stripclip", type=[".safetensors"])  # type: UploadedFile | None
    # Preallocate download button
    ui_filedownload_stripclip = streamlit.empty()

streamlit.html("\n")

# - A tool for converting a .ckpt file to a .safetensors file
# NOTE(review): the 700MB limit in this header is not enforced by this script;
# presumably it comes from Streamlit's server.maxUploadSize setting - confirm.
streamlit.html("\n\nConvert CKPT to Safetensors (700MB max)\n\n")
ui_fileupload_ckpt = streamlit.file_uploader("Upload a .ckpt file", key="fileupload_convertckpt", type=[".ckpt", ".pt", ".pth"])  # type: UploadedFile | None
# Preallocate download button
ui_filedownload_ckpt = streamlit.empty()

# ! Rescale Lora
if ui_fileupload_rescale and ui_fileupload_rescale.name is not None:
    lora_metadata = lora_tools.read_safetensors_metadata(ui_fileupload_rescale)
    # The same upload stream is consumed twice; rewind in case the metadata
    # reader left the cursor advanced (a no-op if lora_tools seeks itself).
    ui_fileupload_rescale.seek(0)
    new_weights = lora_tools.rescale_lora_alpha(ui_fileupload_rescale, output_dtype, new_scale_factor)
    new_lora_data = safetensors.torch.save(new_weights, lora_metadata)
    output_name = f"{_base_name(ui_fileupload_rescale)}_rescaled_{new_scale_factor:.2f}.safetensors"
    ui_fileupload_rescale.close()
    del ui_fileupload_rescale
    _offer_download(ui_filedownload_rescale, new_lora_data, output_name, "Download Rescaled Weights")

# ! Remove CLIP Parameters
if ui_fileupload_stripclip and ui_fileupload_stripclip.name is not None:
    lora_metadata = lora_tools.read_safetensors_metadata(ui_fileupload_stripclip)
    # Rewind before the second read of the same stream (see note above).
    ui_fileupload_stripclip.seek(0)
    stripped_weights = lora_tools.remove_clip_weights(ui_fileupload_stripclip, output_dtype)
    stripped_lora_data = safetensors.torch.save(stripped_weights, lora_metadata)
    output_name = f"{_base_name(ui_fileupload_stripclip)}_noclip.safetensors"
    ui_fileupload_stripclip.close()
    del ui_fileupload_stripclip
    _offer_download(ui_filedownload_stripclip, stripped_lora_data, output_name, "Download Stripped Weights")

# ! Convert Checkpoint to Safetensors
if ui_fileupload_ckpt and ui_fileupload_ckpt.name is not None:
    # SECURITY NOTE(review): .ckpt/.pt/.pth files are commonly pickle archives.
    # If torch_tools deserializes them with torch.load without weights_only=True,
    # a malicious upload could execute arbitrary code here - verify torch_tools.
    converted_weights = torch_tools.convert_ckpt_to_safetensors(ui_fileupload_ckpt, output_dtype)
    converted_lora_data = safetensors.torch.save(converted_weights)
    output_name = f"{_base_name(ui_fileupload_ckpt)}.safetensors"
    ui_fileupload_ckpt.close()
    del ui_fileupload_ckpt
    _offer_download(ui_filedownload_ckpt, converted_lora_data, output_name, "Download Converted Weights")