# person-thumbs-up / push_to_hf.py
# Author: Srimanth Agastyaraju (initial commit 5372b88)
# Pushes a locally fine-tuned LoRA model folder (plus a generated model card) to the Hugging Face Hub.
import os
from pathlib import Path
from huggingface_hub import create_repo, upload_folder
from PIL import Image
def save_model_card(repo_id: str, images=None, base_model: str = "", dataset_name: str = "", repo_folder=None):
    """Write a Hub model card (README.md) and example images into *repo_folder*.

    Args:
        repo_id: Hub repository id, shown in the card title.
        images: Optional iterable of PIL images; each is saved as
            ``image_<i>.png`` in *repo_folder* and embedded in the card.
            ``None`` or empty means no example images.
        base_model: Identifier of the base model the LoRA weights were
            trained on (e.g. ``"runwayml/stable-diffusion-v1-5"``).
            NOTE: the original code defaulted this to the *type* ``str``
            itself — a bug; it now defaults to an empty string.
        dataset_name: Human-readable name of the fine-tuning dataset.
        repo_folder: Existing local folder to write README.md and images into.
    """
    img_str = ""
    # "or []" guards against the images=None default, which previously
    # raised TypeError from enumerate(None).
    for i, image in enumerate(images or []):
        image.save(os.path.join(repo_folder, f"image_{i}.png"))
        img_str += f"![img_{i}](./image_{i}.png)\n"

    # YAML front matter consumed by the Hub for licensing/tags/inference.
    yaml = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
tags:
- stable-diffusion
- stable-diffusion-diffusers
- text-to-image
- diffusers
- lora
inference: true
---
"""
    model_card = f"""
# LoRA text2image fine-tuning - {repo_id}
These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n
{img_str}
"""
    # Explicit encoding so the card renders correctly regardless of locale.
    with open(os.path.join(repo_folder, "README.md"), "w", encoding="utf-8") as f:
        f.write(yaml + model_card)
if __name__ == "__main__":
    # Experiment name -> Hub repo id and local folder holding the trained weights.
    REPOS = {
        "tom_cruise_plain": {"hub_model_id": "person-thumbs-up-plain-lora", "output_dir": "/l/vision/v5/sragas/easel_ai/models_plain/"},
        "tom_cruise": {"hub_model_id": "person-thumbs-up-lora", "output_dir": "/l/vision/v5/sragas/easel_ai/models/"},
        "tom_cruise_no_cap": {"hub_model_id": "person-thumbs-up-lora-no-cap", "output_dir": "/l/vision/v5/sragas/easel_ai/models_no_cap/"},
    }
    current_repo_id = "tom_cruise"
    current_repo = REPOS[current_repo_id]
    print(f"{'-'*20} CURRENT REPO: {current_repo_id} {'-'*20}")

    hub_model_id = current_repo["hub_model_id"]
    output_dir = current_repo["output_dir"]
    pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"

    # Collect example images for the model card. sorted() makes the
    # image_<i>.png numbering deterministic (os.listdir order is arbitrary),
    # endswith(".png") avoids matching names like "foo.png.bak", and
    # os.path.join replaces fragile string concatenation.
    images = [
        Image.open(os.path.join(output_dir, fname))
        for fname in sorted(os.listdir(output_dir))
        if fname.endswith(".png")
    ]

    # Create (or reuse, exist_ok=True) the Hub repo; token=None falls back
    # to the locally cached credentials.
    repo_id = create_repo(
        repo_id=hub_model_id or Path(output_dir).name, exist_ok=True, token=None
    ).repo_id

    save_model_card(
        repo_id,
        images=images,
        base_model=pretrained_model_name_or_path,
        dataset_name="Custom dataset",
        repo_folder=output_dir,
    )

    # Upload everything except intermediate training checkpoints.
    upload_folder(
        repo_id=repo_id,
        folder_path=output_dir,
        commit_message="End of training",
        ignore_patterns=["step_*", "epoch_*"],
    )