Spaces:
Runtime error
Runtime error
import json | |
import logging | |
import os | |
from typing import Optional | |
from hbutils.system import TemporaryDirectory | |
from huggingface_hub import hf_hub_url | |
from tqdm.auto import tqdm | |
from .draw import _DEFAULT_INFER_MODEL, draw_with_workdir | |
from ..dataset import save_recommended_tags | |
from ..utils import get_hf_fs, download_file | |
def draw_to_directory(workdir: str, export_dir: str, step: int, n_repeats: int = 2,
                      pretrained_model: str = _DEFAULT_INFER_MODEL, clip_skip: int = 2,
                      image_width: int = 512, image_height: int = 768, infer_steps: int = 30,
                      lora_alpha: float = 0.85, sample_method: str = 'DPM++ 2M Karras',
                      model_hash: Optional[str] = None, *, max_retries: Optional[int] = None) -> list:
    """
    Draw sample images for one training step of ``workdir`` and save them into ``export_dir``.

    For every drawing produced by ``draw_with_workdir`` two files are written:
    ``<name>.png`` (saved with the drawing's ``pnginfo``) and ``<name>_info.txt``
    (containing the drawing's ``preview_info``).

    :param workdir: Training work directory passed to ``draw_with_workdir``.
    :param export_dir: Output directory; created if it does not exist.
    :param step: Model step to draw with (passed as ``model_steps``).
    :param n_repeats: Initial repeat count; incremented by one after every ``RuntimeError``.
    :param pretrained_model: Base model used for inference.
    :param clip_skip: CLIP skip value forwarded to ``draw_with_workdir``.
    :param image_width: Width of generated images.
    :param image_height: Height of generated images.
    :param infer_steps: Number of inference steps.
    :param lora_alpha: LoRA weight forwarded to ``draw_with_workdir``.
    :param sample_method: Sampler name forwarded to ``draw_with_workdir``.
    :param model_hash: Hash of the base model. When not given, it is looked up in
        ``KNOWN_MODEL_HASHES`` by ``pretrained_model`` (and may stay ``None``).
    :param max_retries: Maximum number of retries after a ``RuntimeError``.
        ``None`` (default) retries forever, matching the historical behaviour.
    :return: Paths of the saved PNG files, in drawing order.
    :raises RuntimeError: The last error from ``draw_with_workdir`` once ``max_retries`` is exhausted.
    """
    from ..publish.export import KNOWN_MODEL_HASHES

    model_hash = model_hash or KNOWN_MODEL_HASHES.get(pretrained_model)
    os.makedirs(export_dir, exist_ok=True)

    # On RuntimeError, retry with a larger n_repeats. Presumably draw_with_workdir
    # raises when too few usable images were produced (TODO confirm); the exact
    # cause is not visible from this module.
    retries = 0
    while True:
        try:
            drawings = draw_with_workdir(
                workdir, model_steps=step, n_repeats=n_repeats,
                pretrained_model=pretrained_model,
                width=image_width, height=image_height, infer_steps=infer_steps,
                lora_alpha=lora_alpha, clip_skip=clip_skip, sample_method=sample_method,
                model_hash=model_hash,
            )
        except RuntimeError as err:
            retries += 1
            if max_retries is not None and retries > max_retries:
                raise
            n_repeats += 1
            logging.warning(f'Drawing failed ({err!r}), retrying with n_repeats={n_repeats} ...')
        else:
            break

    all_image_files = []
    for draw in drawings:
        img_file = os.path.join(export_dir, f'{draw.name}.png')
        draw.image.save(img_file, pnginfo=draw.pnginfo)
        all_image_files.append(img_file)

        with open(os.path.join(export_dir, f'{draw.name}_info.txt'), 'w', encoding='utf-8') as f:
            print(draw.preview_info, file=f)

    return all_image_files
def draw_with_repo(repository: str, export_dir: str, step: Optional[int] = None, n_repeats: int = 2,
                   pretrained_model: str = _DEFAULT_INFER_MODEL, clip_skip: int = 2,
                   image_width: int = 512, image_height: int = 768, infer_steps: int = 30,
                   lora_alpha: float = 0.85, sample_method: str = 'DPM++ 2M Karras',
                   model_hash: Optional[str] = None):
    """
    Download one step's raw checkpoint files from a Hugging Face repository and draw samples from them.

    The repository must contain ``meta.json``. The files under ``<step>/raw/`` are
    downloaded into ``<tmp>/ckpts``. Recommended tags are then regenerated with
    ``save_recommended_tags``, and images are rendered into ``export_dir`` via
    ``draw_to_directory``. The temporary work directory is removed afterwards.

    The tag source is a ``gchar`` character when the last ``_``-separated segment of
    the checkpoint name is a key of ``GAME_CHARS`` and the character lookup succeeds.
    Otherwise, the repository id is used.

    :param repository: Hugging Face repository id holding the trained model.
    :param export_dir: Directory receiving the drawn images.
    :param step: Step to draw; defaults to ``meta['best_step']`` when ``None``.
    :param n_repeats: Forwarded to ``draw_to_directory``.
    :param pretrained_model: Forwarded to ``draw_to_directory``.
    :param clip_skip: Forwarded to ``draw_to_directory``.
    :param image_width: Forwarded to ``draw_to_directory``.
    :param image_height: Forwarded to ``draw_to_directory``.
    :param infer_steps: Forwarded to ``draw_to_directory``.
    :param lora_alpha: Forwarded to ``draw_to_directory``.
    :param sample_method: Forwarded to ``draw_to_directory``.
    :param model_hash: Forwarded to ``draw_to_directory``.
    :return: Whatever ``draw_to_directory`` returns.
    :raises ValueError: When ``meta.json`` is missing, or no raw files exist for the chosen step.
    """
    import posixpath

    from ..publish import find_steps_in_workdir

    hf_fs = get_hf_fs()
    meta_file = f'{repository}/meta.json'
    if not hf_fs.exists(meta_file):
        raise ValueError(f'Invalid repository or no model found - {repository!r}.')
    logging.info(f'Model repository {repository!r} found.')
    meta = json.loads(hf_fs.read_text(meta_file))

    # An explicit `is None` check, so that a requested step of 0 is not replaced by best_step.
    if step is None:
        step = meta['best_step']
    logging.info(f'Using step {step} ...')

    raw_files = hf_fs.glob(f'{repository}/{step}/raw/*')
    if not raw_files:
        raise ValueError(f'No raw model files found for step {step!r} in {repository!r}.')

    with TemporaryDirectory() as workdir:
        logging.info('Downloading models ...')
        ckpt_dir = os.path.join(workdir, 'ckpts')
        os.makedirs(ckpt_dir, exist_ok=True)
        for remote_file in tqdm(raw_files):
            # Hub paths are always '/'-separated; os.path.relpath would yield '\\' on Windows
            # and produce an invalid filename for hf_hub_url.
            rel_file = posixpath.relpath(remote_file, repository)
            local_file = os.path.join(ckpt_dir, posixpath.basename(rel_file))
            download_file(hf_hub_url(repository, filename=rel_file), local_file)

        logging.info(f'Regenerating tags for {workdir!r} ...')
        pt_name, _ = find_steps_in_workdir(workdir)
        # Checkpoint names look like '<name>_<game>'; when there is no '_', name is ''.
        name, _, game_name = pt_name.rpartition('_')

        from gchar.games.dispatch.access import GAME_CHARS
        ch = GAME_CHARS[game_name].get(name) if game_name in GAME_CHARS else None
        source = repository if ch is None else ch

        logging.info(f'Regenerate tags for {source!r}, on {workdir!r}.')
        save_recommended_tags(source, name=pt_name, workdir=workdir, ds_size=meta["dataset"]['type'])

        # Must run inside the `with`: the downloaded checkpoints live in the temporary workdir.
        logging.info('Drawing ...')
        return draw_to_directory(
            workdir, export_dir, step,
            n_repeats=n_repeats, pretrained_model=pretrained_model, clip_skip=clip_skip,
            image_width=image_width, image_height=image_height, infer_steps=infer_steps,
            lora_alpha=lora_alpha, sample_method=sample_method, model_hash=model_hash,
        )