import os
import uuid
import json
import subprocess

import yaml
from huggingface_hub import snapshot_download, delete_repo, metadata_update

# Credentials and the dataset to train on are passed in via environment variables.
HF_TOKEN = os.environ.get("HF_TOKEN")
HF_DATASET = os.environ.get("DATA_PATH")
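# Illustrative invocation (placeholder values, not taken from the source):
#   DATA_PATH=<username/dataset-repo> HF_TOKEN=<write-scoped-token> python <this_script>.py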
|
|
|
def download_dataset(hf_dataset_path: str):
    # Snapshot the dataset repo from the Hub into a unique temporary directory.
    random_id = str(uuid.uuid4())
    snapshot_download(
        repo_id=hf_dataset_path,
        token=HF_TOKEN,
        local_dir=f"/tmp/{random_id}",
        repo_type="dataset",
    )
    return f"/tmp/{random_id}"
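
# The processing step below assumes the dataset repo bundles an ai-toolkit
# config.yaml alongside the images, plus an optional metadata.jsonl whose
# entries map "file_name" to a caption in "prompt".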
|
|
|
def process_dataset(dataset_dir: str):
    # A config.yaml with the training configuration must be present.
    if not os.path.exists(os.path.join(dataset_dir, "config.yaml")):
        raise ValueError("config.yaml does not exist")

    # Convert metadata.jsonl captions into per-image .txt files (the caption
    # format the trainer reads), then drop the jsonl so only images and
    # caption files remain in the folder.
    if os.path.exists(os.path.join(dataset_dir, "metadata.jsonl")):
        metadata = []
        with open(os.path.join(dataset_dir, "metadata.jsonl"), "r") as f:
            for line in f:
                if len(line.strip()) > 0:
                    metadata.append(json.loads(line))

        for item in metadata:
            txt_path = os.path.join(dataset_dir, item["file_name"])
            txt_path = txt_path.rsplit(".", 1)[0] + ".txt"
            with open(txt_path, "w") as f:
                f.write(item["prompt"])

        os.remove(os.path.join(dataset_dir, "metadata.jsonl"))

    # Point the training config at the local dataset folder.
    with open(os.path.join(dataset_dir, "config.yaml"), "r") as f:
        config = yaml.safe_load(f)

    config["config"]["process"][0]["datasets"][0]["folder_path"] = dataset_dir

    with open(os.path.join(dataset_dir, "config.yaml"), "w") as f:
        yaml.dump(config, f)

    return dataset_dir
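
# Training itself is delegated to ai-toolkit's run.py, launched as a subprocess
# so the main script can wait on it and run the publishing steps afterwards.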
|
|
|
def run_training(hf_dataset_path: str):
    dataset_dir = download_dataset(hf_dataset_path)
    dataset_dir = process_dataset(dataset_dir)

    # Clone ai-toolkit and pull in its submodules; check=True aborts early if the clone fails.
    commands = "git clone https://github.com/ostris/ai-toolkit.git ai-toolkit && cd ai-toolkit && git submodule update --init --recursive"
    subprocess.run(commands, shell=True, check=True)

    # Launch training on the processed config as a child process the caller can wait on.
    commands = f"python run.py {os.path.join(dataset_dir, 'config.yaml')}"
    process = subprocess.Popen(commands, shell=True, cwd="ai-toolkit", env=os.environ)

    return process, dataset_dir
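
# Once training completes, the script tags the model repo named in the config
# and removes the now-unneeded dataset repo.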
|
|
|
if __name__ == "__main__":
    process, dataset_dir = run_training(HF_DATASET)
    process.wait()

    # Read the target model repo id back from the training config.
    with open(os.path.join(dataset_dir, "config.yaml"), "r") as f:
        config = yaml.safe_load(f)
    repo_id = config["config"]["process"][0]["save"]["hf_repo_id"]

    # Tag the model card so the trained LoRA is discoverable on the Hub.
    metadata = {
        "tags": [
            "autotrain",
            "spacerunner",
            "text-to-image",
            "flux",
            "lora",
            "diffusers",
            "template:sd-lora",
        ]
    }
    metadata_update(repo_id, metadata, token=HF_TOKEN, repo_type="model", overwrite=True)

    # Clean up the temporary dataset repo after training.
    delete_repo(HF_DATASET, token=HF_TOKEN, repo_type="dataset", missing_ok=True)
|