{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "a44dd6024769456a8262a17b0ce6a2ed": { "model_module": "@jupyter-widgets/controls", "model_name": "ButtonModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ButtonModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ButtonView", "button_style": "success", "description": "✔ Done", "disabled": true, "icon": "", "layout": "IPY_MODEL_49441085d85a4f219a6ccbf2a197f527", "style": "IPY_MODEL_f084b7dfcae445a58d36a9c21971793c", "tooltip": "" } }, "49441085d85a4f219a6ccbf2a197f527": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": "50px", "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, 
"width": null } }, "f084b7dfcae445a58d36a9c21971793c": { "model_module": "@jupyter-widgets/controls", "model_name": "ButtonStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ButtonStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "button_color": null, "font_weight": "" } } } } }, "cells": [
  {
   "cell_type": "code",
   "source": [
    "#@title Mount Google Drive\n",
    "from google.colab import drive\n",
    "from IPython.display import clear_output\n",
    "from IPython.display import display\n",
    "import ipywidgets as widgets\n",
    "import os\n",
    "\n",
    "def inf(msg, style, wdth):\n",
    "    \"\"\"Display `msg` as a disabled (non-clickable) status button of minimum width `wdth`.\"\"\"\n",
    "    badge = widgets.Button(description=msg, disabled=True, button_style=style, layout=widgets.Layout(min_width=wdth))\n",
    "    display(badge)\n",
    "\n",
    "Shared_Drive = \"\" #@param {type:\"string\"}\n",
    "#@markdown - If you're not using a shared drive, leave this empty\n",
    "\n",
    "print(\"\u001b[0;33mConnecting...\")\n",
    "drive.mount('/content/gdrive')\n",
    "clear_output()\n",
    "\n",
    "# Path of the selected drive, relative to /content/gdrive. The named shared\n",
    "# drive itself is checked (not just the Shareddrives folder), so a mistyped\n",
    "# name falls back to MyDrive with a warning instead of a nonexistent path.\n",
    "Shared_Drive = Shared_Drive.strip()\n",
    "if Shared_Drive and os.path.isdir(\"/content/gdrive/Shareddrives/\" + Shared_Drive):\n",
    "    mainpth = \"Shareddrives/\" + Shared_Drive\n",
    "else:\n",
    "    if Shared_Drive:\n",
    "        print(f\"\u001b[0;31mShared drive '{Shared_Drive}' not found, using MyDrive instead.\u001b[0m\")\n",
    "    mainpth = \"MyDrive\"\n",
    "\n",
    "inf('\\u2714 Done', 'success', '50px')"
   ],
   "metadata": { "id": "fCR2boKCTx0z", "cellView": "form", "outputId": "baf6303f-9850-4dd2-a6d3-86871ac8aef5", "colab": { "base_uri": "https://localhost:8080/", "height": 49, "referenced_widgets": [ "a44dd6024769456a8262a17b0ce6a2ed", "49441085d85a4f219a6ccbf2a197f527", "f084b7dfcae445a58d36a9c21971793c" ] } },
   "execution_count": null,
   "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "Button(button_style='success', description='✔ Done', disabled=True, layout=Layout(min_width='50px'), style=But…" ], "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "a44dd6024769456a8262a17b0ce6a2ed" } }, "metadata": 
{} } ] },
  {
   "cell_type": "code",
   "source": [
    "#@title Install Required Dependencies\n",
    "# `%pip` (unlike `!pip`) installs into the environment of the running kernel.\n",
    "# Without -U, packages that are already installed are left at their current version.\n",
    "# pytorch-lightning is presumably needed so Lightning-saved VAE checkpoints can be unpickled.\n",
    "%pip install -q torch safetensors pytorch-lightning"
   ],
   "metadata": { "id": "5S88gkUJzeqG" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "#@title Choose the File(s) to Convert\n",
    "file_path = \"\" #@param {type:\"string\"}\n",
    "#@markdown - Copy and paste the path to an embedding or VAE file that you are converting, or a directory containing several files\n",
    "#@markdown - For example: /content/gdrive/MyDrive/myembedding.pt or /content/gdrive/MyDrive/my_directory\n",
    "#@markdown - Pickle files must be in .pt format\n",
    "verbose = True #@param {type:\"boolean\"}\n",
    "#@markdown - Print training details for each file while converting\n",
    "\n",
    "# Pasted paths often carry stray surrounding whitespace.\n",
    "file_path = file_path.strip()"
   ],
   "metadata": { "id": "7aLFC6c4O5EW", "cellView": "form" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "#@title Define Converter Functions\n",
    "import os\n",
    "import pickle\n",
    "from typing import Any, Dict\n",
    "\n",
    "import torch\n",
    "from safetensors.torch import save_file\n",
    "\n",
    "# Kept so existing references to `device` keep working. Files are loaded onto\n",
    "# the CPU instead: they are only written back to disk, so placing them on the\n",
    "# GPU would just consume VRAM.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "SUPPORTED_MODEL_TYPES = ('embedding', 'vae')\n",
    "\n",
    "\n",
    "def process_pt_files(path: str, model_type: str, verbose=True) -> None:\n",
    "    \"\"\"Convert one .pt file, or every .pt file directly inside a directory.\n",
    "\n",
    "    Args:\n",
    "        path: a .pt file, or a directory whose top-level .pt files are converted\n",
    "            (subdirectories are not searched; the extension match ignores case).\n",
    "        model_type: 'embedding' or 'vae'.\n",
    "        verbose: print training details for each file.\n",
    "\n",
    "    In directory mode a file that fails to convert is reported and skipped, so\n",
    "    one bad file does not stop the rest of the batch.\n",
    "    \"\"\"\n",
    "    if os.path.isdir(path):\n",
    "        # Validate once up front; otherwise every file would be skipped with the same error.\n",
    "        if model_type not in SUPPORTED_MODEL_TYPES:\n",
    "            raise Exception(f\"model_type `{model_type}` is not supported!\")\n",
    "        file_names = sorted(\n",
    "            name for name in os.listdir(path)\n",
    "            if name.lower().endswith('.pt') and os.path.isfile(os.path.join(path, name))\n",
    "        )\n",
    "        if not file_names:\n",
    "            print(f\"No .pt files found in {path}.\")\n",
    "        for file_name in file_names:\n",
    "            try:\n",
    "                process_file(os.path.join(path, file_name), model_type, verbose)\n",
    "            except Exception as err:\n",
    "                print(f\"Skipping {file_name}: {err}\")\n",
    "                print()\n",
    "    elif os.path.isfile(path) and path.lower().endswith('.pt'):\n",
    "        process_file(path, model_type, verbose)\n",
    "    else:\n",
    "        print(f\"{path} is not a valid directory or .pt file.\")\n",
    "\n",
    "\n",
    "def _load_pt(file_path: str) -> Any:\n",
    "    \"\"\"Load a pickled .pt file onto the CPU.\n",
    "\n",
    "    PyTorch's restricted ``weights_only`` unpickler is tried first. A file it\n",
    "    rejects (e.g. a checkpoint that stores training objects) is loaded again\n",
    "    with full unpickling, which can execute arbitrary code embedded in the\n",
    "    file -- only convert files from sources you trust.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        return torch.load(file_path, map_location='cpu', weights_only=True)\n",
    "    except pickle.UnpicklingError:\n",
    "        print(f\"Warning: {file_path} needs full unpickling; only convert files you trust.\")\n",
    "        # weights_only=False is explicit because newer PyTorch releases default it to True.\n",
    "        return torch.load(file_path, map_location='cpu', weights_only=False)\n",
    "\n",
    "\n",
    "def process_file(file_path: str, model_type: str, verbose: bool) -> None:\n",
    "    \"\"\"Convert one .pt file to a .safetensors file in the same directory.\n",
    "\n",
    "    ``name.pt`` (any case) becomes ``name.safetensors``; any other file name\n",
    "    gets ``.safetensors`` appended. With ``verbose``, the input path, training\n",
    "    details and output path are printed.\n",
    "\n",
    "    Raises:\n",
    "        Exception: if ``model_type`` is not 'embedding' or 'vae'.\n",
    "    \"\"\"\n",
    "    # Pick the extractor before loading so a bad model_type fails without reading the file.\n",
    "    if model_type == 'embedding':\n",
    "        extract = process_embedding_file\n",
    "    elif model_type == 'vae':\n",
    "        extract = process_vae_file\n",
    "    else:\n",
    "        raise Exception(f\"model_type `{model_type}` is not supported!\")\n",
    "\n",
    "    if verbose:\n",
    "        print(file_path)\n",
    "\n",
    "    model = _load_pt(file_path)\n",
    "    s_model = extract(model, verbose)\n",
    "\n",
    "    # save_file rejects non-contiguous tensors.\n",
    "    s_model = {name: tensor.contiguous() for name, tensor in s_model.items()}\n",
    "\n",
    "    root, ext = os.path.splitext(file_path)\n",
    "    new_file_path = (root if ext.lower() == '.pt' else file_path) + '.safetensors'\n",
    "    save_file(s_model, new_file_path)\n",
    "\n",
    "    if verbose:\n",
    "        print(f\"Saved {new_file_path}\")\n",
    "        print()\n",
    "\n",
    "\n",
    "def process_embedding_file(model: Dict[str, Any], verbose: bool) -> Dict[str, torch.Tensor]:\n",
    "    \"\"\"Return an embedding's vectors as ``{'emb_params': tensor}``.\n",
    "\n",
    "    The vectors are read from ``model['string_to_param']['*']``. With\n",
    "    ``verbose``, the checkpoint name, step count and tensor shape are printed\n",
    "    when available.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``model['string_to_param']['*']`` is missing.\n",
    "    \"\"\"\n",
    "    string_to_param = model.get('string_to_param') if isinstance(model, dict) else None\n",
    "    if string_to_param is None or '*' not in string_to_param:\n",
    "        raise ValueError(\"expected model['string_to_param']['*']; this does not look like an embedding file\")\n",
    "    model_tensors = string_to_param['*']\n",
    "    s_model = {'emb_params': model_tensors}\n",
    "\n",
    "    if verbose:\n",
    "        # Print the training information, if it exists\n",
    "        checkpoint_name = model.get('sd_checkpoint_name')\n",
    "        if checkpoint_name is not None:\n",
    "            print(f\"Trained on {checkpoint_name}.\")\n",
    "        else:\n",
    "            print(\"Checkpoint name not found in the model.\")\n",
    "\n",
    "        step = model.get('step')\n",
    "        if step is not None:\n",
    "            print(f\"Trained for {step} steps.\")\n",
    "        else:\n",
    "            print(\"Step not found in the model.\")\n",
    "        print(f\"Dimensions of embedding tensor: {model_tensors.shape}\")\n",
    "\n",
    "    return s_model\n",
    "\n",
    "\n",
    "def process_vae_file(model: Dict[str, Any], verbose: bool) -> Dict[str, torch.Tensor]:\n",
    "    \"\"\"Return the tensors of a VAE state dict.\n",
    "\n",
    "    Uses ``model['state_dict']`` when present; otherwise the loaded dict itself\n",
    "    is treated as the state dict. Non-tensor entries are dropped because\n",
    "    safetensors can only store tensors.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the loaded object is not a dict or contains no tensors.\n",
    "    \"\"\"\n",
    "    if not isinstance(model, dict):\n",
    "        raise ValueError(f\"expected a dict, got {type(model).__name__}\")\n",
    "    state_dict = model.get('state_dict', model)\n",
    "    s_model = {key: value for key, value in state_dict.items() if isinstance(value, torch.Tensor)}\n",
    "    if not s_model:\n",
    "        raise ValueError(\"no tensors found in the VAE state dict\")\n",
    "\n",
    "    if verbose:\n",
    "        # Print the training information, if it exists\n",
    "        step = model.get('step', model.get('global_step'))\n",
    "        if step is not None:\n",
    "            print(f\"Trained for {step} steps.\")\n",
    "        else:\n",
    "            print(\"Step not found in the model.\")\n",
    "        skipped = len(state_dict) - len(s_model)\n",
    "        if skipped:\n",
    "            print(f\"Skipped {skipped} non-tensor entries.\")\n",
    "\n",
    "    return s_model"
   ],
   "metadata": { "id": "UwH1lXmGw9XP" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Convert the file(s)\n",
    "\n",
    "Run whichever of the following two code blocks corresponds to the type of file you are converting.\n",
    "\n",
    "Each converted `.safetensors` file is saved in the same directory as the original, under the same name."
   ],
   "metadata": { "id": "LqEl4sM0sMPG" }
  },
  {
   "cell_type": "code",
   "source": [
    "#@title Convert the Embedding(s)\n",
    "process_pt_files(file_path, 'embedding', verbose=verbose)"
   ],
   "metadata": { "id": "4LEWGfjiUeG1", "cellView": "form" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "#@title Convert the VAE(s)\n",
    "process_pt_files(file_path, 'vae', verbose=verbose)"
   ],
   "metadata": { "id": "Jil7A1ckyiHA", "cellView": "form" },
   "execution_count": null,
   "outputs": []
  }
 ]
}