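# Segment a street scene with SegFormer (Cityscapes classes), mask out cars, roads,
# sidewalks and traffic lights/signs, and inpaint the masked regions with an LCM
# Stable Diffusion inpainting pipeline; served as a Gradio demo.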
import gradio as gr
import torch
from PIL import Image, ImageFilter, ImageEnhance, ImageDraw
from diffusers import LCMScheduler, StableDiffusionInpaintPipeline
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
from tqdm import tqdm
import numpy as np
from datetime import datetime
# ideally:
# preferred_device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
# but segformer does not work on mps lolololol
seg_preferred_device = "cuda" if torch.cuda.is_available() else "cpu"
seg_preferred_dtype = torch.float32 # torch.float16 if seg_preferred_device == 'cuda' else torch.float32
inpaint_preferred_device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
inpaint_preferred_dtype = torch.float32 if inpaint_preferred_device == 'cpu' else torch.float16
torch.backends.cuda.matmul.allow_tf32 = True
print(f"backends: {torch._dynamo.list_backends()}")
preferred_backend = "aot_eager" if inpaint_preferred_device == "mps" else "inductor"
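# SegFormer-B0 fine-tuned on Cityscapes at 768x768 -- the smallest, fastest SegFormer variant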
seg_model_img_size = 768
seg_model_size = 0
seg_feature_extractor = SegformerFeatureExtractor.from_pretrained(f"nvidia/segformer-b{seg_model_size}-finetuned-cityscapes-{seg_model_img_size}-{seg_model_img_size}")
seg_model = SegformerForSemanticSegmentation.from_pretrained(
    f"nvidia/segformer-b{seg_model_size}-finetuned-cityscapes-{seg_model_img_size}-{seg_model_img_size}"
).to(seg_preferred_device).to(seg_preferred_dtype)
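# LCM_Dreamshaper_v7 is a latent-consistency distillation of Dreamshaper v7, so a
# handful of inference steps (4 by default below) is enough; the safety checker is disabled.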
inpainting_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
    "SimianLuo/LCM_Dreamshaper_v7",
    torch_dtype=inpaint_preferred_dtype,
    safety_checker=None,
).to(inpaint_preferred_device)
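# torch.compile is lazy: each module is compiled on its first forward pass and may be
# recompiled when input shapes change, hence the per-size warm-up loop further down.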
inpainting_pipeline.text_encoder = torch.compile(inpainting_pipeline.text_encoder, backend=preferred_backend)
inpainting_pipeline.unet = torch.compile(inpainting_pipeline.unet, backend=preferred_backend)
inpainting_pipeline.vae = torch.compile(inpainting_pipeline.vae, backend=preferred_backend)
seg_model = torch.compile(seg_model, backend=preferred_backend)
seg_working_size = (seg_model_img_size, seg_model_img_size)
default_inpainting_prompt = "award-winning photo of a leafy pedestrian mall full of people, with multiracial genderqueer joggers and bicyclists and wheelchair users talking and laughing"
seg_vocabulary = seg_model.config.label2id
print(f"vocab: {seg_vocabulary}")
ban_cars_mask = [0] * len(seg_vocabulary)
banned_classes = ["car", "road", "sidewalk", "traffic light", "traffic sign"]
for c in banned_classes:
    ban_cars_mask[seg_vocabulary[c]] = 1
ban_cars_mask = np.array(ban_cars_mask, dtype=np.uint8)
def get_seg_mask(img):
    inputs = seg_feature_extractor(images=img, return_tensors="pt").to(seg_preferred_device).to(seg_preferred_dtype)
    outputs = seg_model(**inputs)
    logits = outputs.logits[0]
    # argmax over the class dimension gives a per-pixel class map; the lookup table
    # turns it into a 0/255 mask over the banned classes
    mask = Image.fromarray(ban_cars_mask[torch.argmax(logits, dim=0).cpu().numpy()] * 255)
    # blur then crank the contrast to widen the mask edges and re-binarize it
    blurred_widened_mask = ImageEnhance.Contrast(mask.filter(ImageFilter.GaussianBlur(2))).enhance(9000)
    return blurred_widened_mask
def app(img, prompt, num_inference_steps, seed, inpaint_size):
    start_time = datetime.now().timestamp()
    img = np.array(Image.fromarray(img).resize(seg_working_size))
    mask = get_seg_mask(img)
    # mask.save("mask.jpg")
    mask_time = datetime.now().timestamp()
    #print(prompt.__class__, img.__class__, mask.__class__, img.shape, mask.shape, mask.dtype, img.dtype)
    overlay_img = inpainting_pipeline(
        prompt=prompt,
        image=Image.fromarray(img).resize((inpaint_size, inpaint_size)),
        mask_image=mask.resize((inpaint_size, inpaint_size)),
        strength=1,  # fully repaint the masked region
        num_inference_steps=num_inference_steps,
        height=inpaint_size,
        width=inpaint_size,
        generator=torch.manual_seed(int(seed)),
    ).images[0]
    #overlay_img.save("overlay_raw.jpg")
    end_time = datetime.now().timestamp()
    draw = ImageDraw.Draw(overlay_img)
    # insert a line break after every fifth word so the prompt fits on the image
    words = prompt.split(" ")
    prompt = " ".join(w if (i + 1) % 5 else w + "\n" for i, w in enumerate(words))
    # overlay timings and the prompt on the output image, in green
    draw.text((10, 50), "\n".join([
        f"Total duration: {int(1000 * (end_time - start_time))}ms",
        f"Inference steps: {num_inference_steps}",
        f"Segmentation {int(1000 * (mask_time - start_time))}ms / inpainting {int(1000 * (end_time - mask_time))}ms",
        f"<{prompt}>"
    ]), fill=(0, 255, 0))
    #overlay_img.save("overlay_with_text.jpg")
    return overlay_img
# Warm-up on a blank frame, first to trigger compilation and then to get representative timings,
# at both sizes so the first real request at either size does not pay the compile cost.
for size in [384, 512]:
    for i in range(2):
        for j in tqdm(range(3 ** i)):  # 1 compile pass, then 3 timed passes
            app(np.zeros((1024, 1024, 3), dtype=np.uint8), default_inpainting_prompt, 4, 42, size).save("zeros_inpainting_oneshot.jpg")
# ideally:
# iface = gr.Interface(app, gr.Image(sources=["webcam"], streaming=True), "image", live=True)
iface = gr.Interface(
    app,
    [
        gr.Image(),                                        # img
        gr.Textbox(value=default_inpainting_prompt),       # prompt
        gr.Number(minimum=1, maximum=8, value=4),          # num_inference_steps
        gr.Number(value=42),                               # seed
        gr.Number(value=512, maximum=seg_model_img_size),  # inpaint_size
    ],
    "image",
)
iface.launch(share=True)