# person-thumbs-up / inference.py
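"""Generate sample images from a Stable Diffusion pipeline with LoRA attention weights loaded from the Hugging Face Hub."""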
import os
from huggingface_hub import model_info
import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler


def main():
    # Available LoRA checkpoints on the Hugging Face Hub and their local model directories.
    REPOS = {
        "tom_cruise_plain": {"hub_model_id": "asrimanth/person-thumbs-up-plain-lora", "model_dir": "/l/vision/v5/sragas/easel_ai/models_plain/"},
        "tom_cruise": {"hub_model_id": "asrimanth/person-thumbs-up-lora", "model_dir": "/l/vision/v5/sragas/easel_ai/models/"},
        "tom_cruise_no_cap": {"hub_model_id": "asrimanth/person-thumbs-up-lora-no-cap", "model_dir": "/l/vision/v5/sragas/easel_ai/models_no_cap/"},
        "srimanth_plain": {"hub_model_id": "asrimanth/srimanth-thumbs-up-lora-plain", "model_dir": "/l/vision/v5/sragas/easel_ai/models_srimanth_plain/"},
    }
    N_IMAGES = 50
    current_repo_id = "tom_cruise_no_cap"
    SAVE_DIR = f"./results/{current_repo_id}/"
    os.makedirs(SAVE_DIR, exist_ok=True)

    current_repo = REPOS[current_repo_id]
    print(f"{'-'*20} CURRENT REPO: {current_repo_id} {'-'*20}")
    hub_model_id = current_repo["hub_model_id"]
    model_dir = current_repo["model_dir"]  # Not used below; kept for reference.
    # Read the base model this LoRA was trained on from the Hub model card.
    info = model_info(hub_model_id)
    model_base = info.cardData["base_model"]
    print(f"Base model is: {model_base}")

    # Load the base pipeline in fp16, swap in the DPM-Solver multistep scheduler,
    # and attach the LoRA attention weights from the Hub.
    pipe = StableDiffusionPipeline.from_pretrained(
        model_base, torch_dtype=torch.float16, cache_dir="/l/vision/v5/sragas/hf_models/"
    )
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.unet.load_attn_procs(hub_model_id)
    pipe.to("cuda")
    # One seeded generator per image so runs are reproducible.
    generators = [torch.Generator("cuda").manual_seed(i) for i in range(N_IMAGES)]
    prompt = "<tom_cruise> showing #thumbsup"
    print(f"Running inference on '{prompt}' for {N_IMAGES} images.")

    for i in range(N_IMAGES):
        image = pipe(prompt, generator=generators[i], num_inference_steps=25).images[0]
        image.save(f"{SAVE_DIR}out_{i}.png")


if __name__ == "__main__":
    main()