|
import argparse |
|
import logging |
|
from diffusers import AmusedPipeline |
|
import os |
|
from peft import PeftModel |
|
from diffusers import UVit2DModel |
|
|
|
# Module-level logger named after this module; main() reports its progress through it.
logger = logging.getLogger(__name__)
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of the image generation script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            script entry point relies on.

    Returns:
        argparse.Namespace: parsed options with the attributes
        ``pretrained_model_name_or_path``, ``revision``, ``variant``,
        ``style_descriptor``, ``load_transformer_from``,
        ``load_transformer_lora_from``, ``device``, ``batch_size`` and
        ``write_images_to``.

    Raises:
        SystemExit: raised by argparse when a required option is missing or
            a value cannot be converted.
    """
    parser = argparse.ArgumentParser()

    # --- which pretrained pipeline to load ---
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
    )

    # --- prompt styling ---
    parser.add_argument(
        "--style_descriptor",
        type=str,
        default="[V]",
        help="Style token inserted into every prompt as 'in <style_descriptor> style'.",
    )

    # --- optional replacement / adapter weights for the transformer ---
    parser.add_argument(
        "--load_transformer_from",
        type=str,
        required=False,
        default=None,
        help="Optional path or identifier of a UVit2DModel to pass to the pipeline as its transformer.",
    )
    parser.add_argument(
        "--load_transformer_lora_from",
        type=str,
        required=False,
        default=None,
        help="Optional path of a PEFT (LoRA) adapter to load onto the pipeline transformer.",
    )

    # --- runtime and output ---
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device the pipeline is moved to before generating images.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Number of prompts passed to the pipeline per call.",
    )
    parser.add_argument(
        "--write_images_to",
        type=str,
        required=True,
        help="Directory the generated PNG files are written to (created if missing).",
    )

    return parser.parse_args(argv)
|
|
|
def _build_prompts(style_descriptor):
    """Return the fixed list of evaluation prompts with ``style_descriptor`` filled in."""
    subjects = (
        "A chihuahua",
        "A tabby cat",
        "A portrait of chihuahua",
        "An apple on the table",
        "A banana on the table",
        "A church on the street",
        "A church in the mountain",
        "A church in the field",
        "A church on the beach",
        "A chihuahua walking on the street",
        "A tabby cat walking on the street",
        "A portrait of tabby cat",
        "An apple on the dish",
        "A banana on the dish",
        "A human walking on the street",
        "A temple on the street",
        "A temple in the mountain",
        "A temple in the field",
        "A temple on the beach",
        "A chihuahua walking in the forest",
        "A tabby cat walking in the forest",
        "A portrait of human face",
        "An apple on the ground",
        "A banana on the ground",
        "A human walking in the forest",
        "A cabin on the street",
        "A cabin in the mountain",
        "A cabin in the field",
        "A cabin on the beach",
    )
    return [f"{subject} in {style_descriptor} style" for subject in subjects]


def main(args):
    """Run the pipeline over the built-in prompts and save each image as a PNG.

    Args:
        args: Namespace-like object providing ``pretrained_model_name_or_path``,
            ``revision``, ``variant``, ``style_descriptor``,
            ``load_transformer_from``, ``load_transformer_lora_from``,
            ``device``, ``batch_size`` and ``write_images_to``.

    Raises:
        ValueError: if ``args.batch_size`` is smaller than 1.

    Images are written to ``args.write_images_to`` as ``<prompt>.png``.
    """
    # Fail before any model is loaded: a zero step makes range() raise and a
    # negative step would silently generate nothing.
    if args.batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {args.batch_size}")

    prompts = _build_prompts(args.style_descriptor)

    logger.warning(f"generating image for {prompts}")
    logger.warning("loading models")

    # Only override the pipeline's transformer when a replacement was requested.
    pipe_args = {}
    if args.load_transformer_from is not None:
        pipe_args["transformer"] = UVit2DModel.from_pretrained(args.load_transformer_from)

    pipe = AmusedPipeline.from_pretrained(
        pretrained_model_name_or_path=args.pretrained_model_name_or_path,
        revision=args.revision,
        variant=args.variant,
        **pipe_args,
    )

    if args.load_transformer_lora_from is not None:
        # Load the adapter from the LoRA path itself. The previous code passed
        # args.load_transformer_from here, which is None (TypeError) or the
        # wrong directory whenever only/also a LoRA path is supplied.
        pipe.transformer = PeftModel.from_pretrained(
            pipe.transformer, args.load_transformer_lora_from, is_trainable=False
        )

    pipe.to(args.device)

    logger.warning("generating images")

    os.makedirs(args.write_images_to, exist_ok=True)

    for batch_start in range(0, len(prompts), args.batch_size):
        batch_prompts = prompts[batch_start:batch_start + args.batch_size]
        images = pipe(batch_prompts).images

        # Assumes the pipeline returns one image per prompt, in prompt order.
        for prompt, image in zip(batch_prompts, images):
            # NOTE(review): the prompt text is used verbatim as the file name, so a
            # style descriptor containing a path separator would change the target path.
            image.save(os.path.join(args.write_images_to, prompt + ".png"))
|
|
|
if __name__ == "__main__":
    # Script entry point: read the command-line options, then run generation.
    cli_args = parse_args()
    main(cli_args)