# Hugging Face Space status: Running on Zero
import gradio as gr | |
import spaces | |
import torch | |
import torch.nn.functional as F | |
from safetensors.numpy import save_file, load_file | |
from omegaconf import OmegaConf | |
from transformers import AutoConfig | |
import cv2 | |
from PIL import Image | |
import numpy as np | |
import json | |
import os | |
# | |
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipelineLegacy, StableDiffusionInpaintPipeline, DDIMScheduler, AutoencoderKL | |
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DDIMScheduler | |
from diffusers import DDIMScheduler, DDPMScheduler, DPMSolverMultistepScheduler | |
from diffusers.image_processor import VaeImageProcessor | |
# | |
from models.pipeline_mimicbrush import MimicBrushPipeline | |
from models.ReferenceNet import ReferenceNet | |
from models.depth_guider import DepthGuider | |
from mimicbrush import MimicBrush_RefNet | |
from data_utils import * | |
from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot_download | |
from huggingface_hub import snapshot_download | |
# Fetch the pretrained weights from the Hugging Face Hub into local folders.
# Each entry is (hub repo id, local target directory, message printed once done).
for _repo_id, _local_dir, _done_message in (
    ("xichenhku/cleansd", "./cleansd", '=== Pretrained SD weights downloaded ==='),
    ("xichenhku/MimicBrush", "./MimicBrush", '=== MimicBrush weights downloaded ==='),
):
    snapshot_download(repo_id=_repo_id, local_dir=_local_dir)
    print(_done_message)

# Disabled alternative kept for reference: pull the same weights from ModelScope
# with ms_snapshot_download('xichen/cleansd', cache_dir='./modelscope') and
# ms_snapshot_download('xichen/MimicBrush', cache_dir='./modelscope').

# Inference configuration; its `model_path` section is read below.
val_configs = OmegaConf.load('./configs/inference.yaml')
# === import Depth Anything === | |
import sys | |
sys.path.append("./depthanything") | |
from torchvision.transforms import Compose | |
from depthanything.fast_import import depth_anything_model | |
from depthanything.depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet | |
# Preprocessing pipeline applied to the image dict ({'image': array}) before
# it is passed on as the depth input (see `infer_single`).
transform = Compose([
    Resize(
        width=518,
        height=518,
        resize_target=False,
        keep_aspect_ratio=True,
        ensure_multiple_of=14,
        resize_method='lower_bound',
        image_interpolation_method=cv2.INTER_CUBIC,
    ),
    NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    PrepareForNet(),
])

# Deserialize the depth checkpoint onto the CPU so that loading does not
# require a CUDA device at import time; `load_state_dict` copies the values
# into the model's existing parameters, so the model's own device is unchanged.
depth_anything_model.load_state_dict(
    torch.load(val_configs.model_path.depth_model, map_location="cpu")
)
# === checkpoint locations, all taken from the `model_path` config section ===
(
    base_model_path,
    vae_model_path,
    image_encoder_path,
    ref_model_path,
    mimicbrush_ckpt,
) = (
    val_configs.model_path.pretrained_imitativer_path,
    val_configs.model_path.pretrained_vae_name_or_path,
    val_configs.model_path.image_encoder_path,
    val_configs.model_path.pretrained_reference_path,
    val_configs.model_path.mimicbrush_ckpt_path,
)

# Device string handed to MimicBrush_RefNet below.
device = "cuda"
def pad_img_to_square(original_image, is_mask=False):
    """Centre a PIL image on a square canvas whose side is the longer edge.

    Args:
        original_image: PIL image; its ``size`` is read and it is pasted
            onto the new canvas.
        is_mask: if true the padding colour is black, otherwise white.

    Returns:
        ``original_image`` itself when it is already square, otherwise a new
        RGB image with the original pasted in the middle of the short axis.
    """
    width, height = original_image.size
    if width == height:
        return original_image

    side = max(width, height)
    offset = (side - min(width, height)) // 2
    fill = "black" if is_mask else "white"
    canvas = Image.new("RGB", (side, side), fill)

    # Portrait input is shifted horizontally, landscape input vertically.
    anchor = (offset, 0) if height > width else (0, offset)
    canvas.paste(original_image, anchor)
    return canvas
def collage_region(low, high, mask):
    """Black out the masked region of ``high`` and return it as a PIL image.

    Args:
        low: image-like input; only its array shape matters because its
            values are multiplied by zero before use.
        high: image-like input whose pixels are kept where the mask is off.
        mask: image-like input; values above 128 mark the region to blank.

    Returns:
        PIL image equal to ``high`` outside the mask and zero inside it.
    """
    selected = (np.array(mask) > 128).astype(np.uint8)
    # The "low" layer is deliberately zeroed, so masked pixels become black.
    blank = (np.array(low).astype(np.uint8) * 0).astype(np.uint8)
    kept = np.array(high).astype(np.uint8)
    merged = blank * selected + kept * (1 - selected)
    return Image.fromarray(merged)
def resize_image_keep_aspect_ratio(image, target_size = 512):
    """Resize an array image so that its longer side equals ``target_size``.

    Args:
        image: array whose first two dimensions are (height, width).
        target_size: length given to the longer side.

    Returns:
        The output of ``cv2.resize`` for the computed (width, height); the
        shorter side is scaled by the same factor and truncated to an int.
    """
    height, width = image.shape[:2]
    # max() picks the height for portrait input and the width otherwise,
    # which is exactly the divisor each branch needs.
    ratio = target_size / max(height, width)
    if height > width:
        new_size = (int(width * ratio), target_size)
    else:
        new_size = (target_size, int(height * ratio))
    return cv2.resize(image, new_size)
def crop_padding_and_resize(ori_image, square_image):
    """Scale ``square_image`` up to cover ``ori_image`` and trim the excess.

    This is the counterpart of padding an image to a square: the square
    result is enlarged so it covers the original height and width, and the
    surplus rows (landscape original) or columns (otherwise) are cut away,
    split between both sides.

    Args:
        ori_image: 3-D array (height, width, channels) giving the target size.
        square_image: 3-D array to be enlarged and cropped.

    Returns:
        The cropped slice of the enlarged ``square_image``.
    """
    ori_height, ori_width, _ = ori_image.shape
    square_h, square_w = square_image.shape[0], square_image.shape[1]

    scale = max(ori_height / square_h, ori_width / square_w)
    enlarged = cv2.resize(square_image, (int(square_w * scale), int(square_h * scale)))
    big_h, big_w = enlarged.shape[0], enlarged.shape[1]

    # Total surplus along the padded axis, shared between both borders.
    excess = max(big_h - ori_height, big_w - ori_width)
    lead = excess // 2
    trail = excess - lead

    if ori_height < ori_width:
        return enlarged[lead:big_h - trail, :, :]
    return enlarged[:, lead:big_w - trail, :]
def vis_mask(image, mask):
    """Overlay a mask on an image: whiten the masked area and outline it.

    Args:
        image: array image to draw on.
        mask: 3-channel mask array (values up to 255); only channel 0 is used.

    Returns:
        The blended array (50% white inside the mask, unchanged outside)
        with the external mask contours drawn on top.
    """
    channel = mask[:, :, 0]
    contours, _ = cv2.findContours(channel, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Outline settings: a fixed white colour with a fourth component of 0.5.
    outline_opacity = 0.5
    outline_thickness = 5
    outline_color = np.concatenate([[255, 255, 255], [outline_opacity]])

    inside = np.stack([channel, channel, channel], -1) > 128
    alpha = 0.5
    all_white = np.ones_like(image) * 255
    blended = all_white * alpha + image * (1 - alpha)
    image = blended * inside + image * (1 - inside)

    cv2.polylines(image, contours, True, outline_color, outline_thickness, cv2.LINE_AA)
    return image
# DDIM sampler configuration used by the pipeline.
noise_scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
    steps_offset=1,
)

# VAE and UNet are loaded separately and cast to half precision.
vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=torch.float16)
# The UNet is requested with 13 input channels; mismatched checkpoint
# tensor sizes are tolerated rather than raising.
unet = UNet2DConditionModel.from_pretrained(
    base_model_path,
    subfolder="unet",
    in_channels=13,
    low_cpu_mem_usage=False,
    ignore_mismatched_sizes=True,
).to(dtype=torch.float16)

# Assemble the pipeline around the components above; no feature extractor
# and no safety checker are attached.
pipe = MimicBrushPipeline.from_pretrained(
    base_model_path,
    torch_dtype=torch.float16,
    scheduler=noise_scheduler,
    vae=vae,
    unet=unet,
    feature_extractor=None,
    safety_checker=None,
)

depth_guider = DepthGuider()
referencenet = ReferenceNet.from_pretrained(
    ref_model_path,
    subfolder="unet",
).to(dtype=torch.float16)

# Top-level model object whose `generate` method is called by `infer_single`.
mimicbrush_model = MimicBrush_RefNet(
    pipe,
    image_encoder_path,
    mimicbrush_ckpt,
    depth_anything_model,
    depth_guider,
    referencenet,
    device,
)

# Turns the PIL mask into a binarized, single-channel tensor (no normalization).
mask_processor = VaeImageProcessor(
    vae_scale_factor=1,
    do_normalize=False,
    do_binarize=True,
    do_convert_grayscale=True,
)
def infer_single(ref_image, target_image, target_mask, seed = -1, num_inference_steps=50, guidance_scale = 5, enable_shape_control = False):
    """Run one MimicBrush generation for a reference/target/mask triple.

    Args:
        ref_image: RGB np.array of the reference image.
        target_image: RGB np.array of the image to edit.
        target_mask: 0/1 single-channel np.array marking the region to edit.
        seed: forwarded unchanged to ``mimicbrush_model.generate``.
        num_inference_steps: forwarded unchanged to ``generate``.
        guidance_scale: forwarded unchanged to ``generate``.
        enable_shape_control: when false the depth input is zeroed before
            it is handed to the model.

    Returns:
        Tuple ``(pred, depth_vis)``: the first generated sample as a uint8
        array and a 512x512 uint8 INFERNO colour-mapped rendering of the
        depth tensor returned by the model (channel order reversed after
        the OpenCV colour map).
    """
    # All three inputs are padded to squares so they stay aligned.
    ref_image = pad_img_to_square(Image.fromarray(ref_image.astype(np.uint8)))
    target_image = pad_img_to_square(Image.fromarray(target_image.astype(np.uint8)))

    # Expand the 0/1 mask to a 3-channel 0/255 image; padding is black.
    mask_rgb = np.stack([target_mask.astype(np.uint8)] * 3, -1) * 255
    target_mask = pad_img_to_square(Image.fromarray(mask_rgb), True)

    # The depth input comes from the unmasked target, so capture it before
    # the masked region is blanked out.
    depth_image = np.array(target_image)
    target_image = collage_region(target_image, target_image, target_mask)

    depth_image = transform({'image': depth_image})['image']
    # NOTE(review): the division by 255 happens after the transform's
    # normalization step; kept exactly as is since the model is fed this way.
    depth_image = torch.from_numpy(depth_image).unsqueeze(0) / 255
    if not enable_shape_control:
        depth_image = depth_image * 0

    mask_pt = mask_processor.preprocess(target_mask, height=512, width=512)

    pred, depth_pred = mimicbrush_model.generate(pil_image=ref_image, depth_image = depth_image, num_samples=1, num_inference_steps=num_inference_steps,
                                                 seed=seed, image=target_image, mask_image=mask_pt, strength=1.0, guidance_scale=guidance_scale)

    depth_pred = F.interpolate(depth_pred, size=(512, 512), mode='bilinear', align_corners=True)[0][0]
    # Min-max normalize to [0, 255]. A constant map would divide by zero and
    # produce NaNs (undefined after the uint8 cast), so it is rendered as zeros.
    depth_min = depth_pred.min()
    depth_span = depth_pred.max() - depth_min
    if depth_span > 0:
        depth_pred = (depth_pred - depth_min) / depth_span * 255.0
    else:
        depth_pred = torch.zeros_like(depth_pred)
    depth_pred = depth_pred.detach().cpu().numpy().astype(np.uint8)
    depth_pred = cv2.applyColorMap(depth_pred, cv2.COLORMAP_INFERNO)[:, :, ::-1]

    pred = np.array(pred[0]).astype(np.uint8)
    return pred, depth_pred.astype(np.uint8)
def inference_single_image(ref_image,
                           tar_image,
                           tar_mask,
                           ddim_steps,
                           scale,
                           seed,
                           enable_shape_control,
                           ):
    """Resolve the seed and delegate to ``infer_single``.

    Args:
        ref_image: reference image array.
        tar_image: target image array.
        tar_mask: target mask array.
        ddim_steps: passed as ``num_inference_steps``.
        scale: passed as ``guidance_scale``.
        seed: ``-1`` requests a random seed drawn from ``[0, 10000)``;
            any other value is used as given.
        enable_shape_control: passed through unchanged.

    Returns:
        The ``(pred, depth_pred)`` pair produced by ``infer_single``.
    """
    chosen_seed = np.random.randint(10000) if seed == -1 else seed
    return infer_single(
        ref_image,
        tar_image,
        tar_mask,
        chosen_seed,
        num_inference_steps=ddim_steps,
        guidance_scale=scale,
        enable_shape_control=enable_shape_control,
    )
def run_local(base,
              ref,
              *args):
    """Gradio callback: edit the masked region of the source image.

    Args:
        base: image-editor value, a mapping with a ``"background"`` PIL image
            and a ``"layers"`` list; the first layer holds the drawn mask.
        ref: PIL reference image.
        *args: forwarded positionally to ``inference_single_image``
            (ddim_steps, scale, seed, enable_shape_control).

    Returns:
        List of four uint8 arrays: the blended result, the colour-mapped
        depth prediction, the source with the mask visualized, and the
        3-channel 0/255 mask.

    Raises:
        gr.Error: if the source image, the reference image, or the mask is
            missing (instead of failing with a raw attribute/index error).
    """
    # Validate the inputs first so the user gets a readable message.
    if base is None or base["background"] is None:
        raise gr.Error('Please upload a source image.')
    if ref is None:
        raise gr.Error('Please upload a reference image.')
    layers = base["layers"]
    if not layers:
        raise gr.Error('No mask for the background image.')

    image = np.asarray(base["background"].convert("RGB"))
    # The mask is read from the last channel of the first layer
    # (presumably the alpha channel of an RGBA layer — confirm).
    mask = np.asarray(layers[0])[:, :, -1]
    mask = np.where(mask > 128, 1, 0).astype(np.uint8)

    ref_image = np.asarray(ref.convert("RGB"))

    if mask.sum() == 0:
        raise gr.Error('No mask for the background image.')

    mask_3 = np.stack([mask, mask, mask], -1).astype(np.uint8) * 255
    # Feather the mask edge with repeated small Gaussian blurs so the
    # generated region blends smoothly into the source.
    mask_alpha = mask_3.copy()
    for _ in range(10):
        mask_alpha = cv2.GaussianBlur(mask_alpha, (3, 3), 0)

    synthesis, depth_pred = inference_single_image(ref_image.copy(), image.copy(), mask.copy(), *args)

    # The model works on padded squares; map both outputs back to the source size.
    synthesis = crop_padding_and_resize(image, synthesis)
    depth_pred = crop_padding_and_resize(image, depth_pred)

    mask_3_bin = mask_alpha / 255
    synthesis = synthesis * mask_3_bin + image * (1 - mask_3_bin)

    vis_source = vis_mask(image, mask_3).astype(np.uint8)
    return [synthesis.astype(np.uint8), depth_pred.astype(np.uint8), vis_source, mask_3]
# Gradio front end.
# NOTE(review): the source reached this review without indentation, so the
# nesting of the widgets below is a reconstruction — confirm the layout.
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# MimicBrush: Zero-shot Image Editing with Reference Imitation ")
        with gr.Row():
            baseline_gallery = gr.Gallery(
                label='Output',
                show_label=True,
                elem_id="gallery",
                columns=1,
                height=768,
            )
            with gr.Accordion("Advanced Option", open=True):
                num_samples = 1
                ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=50, step=1)
                scale = gr.Slider(label="Guidance Scale", minimum=-30.0, maximum=30.0, value=5.0, step=0.1)
                seed = gr.Slider(label="Seed", minimum=-1, maximum=999999999, step=1, value=-1)
                enable_shape_control = gr.Checkbox(
                    label='Keep the original shape',
                    value=False,
                    interactive=True,
                )

                gr.Markdown("### Tutorial")
                gr.Markdown("1. Upload the source image and the reference image")
                gr.Markdown("2. Select the \"draw button\" to mask the to-edit region on the source image ")
                gr.Markdown("3. Click generate ")
                gr.Markdown("#### You shoud click \"keep the original shape\" to conduct texture transfer ")

        gr.Markdown("# Upload the source image and reference image")
        gr.Markdown("### Tips: you could adjust the brush size")

        with gr.Row():
            # Source editor: a fixed white brush is used to paint the mask.
            base = gr.ImageEditor(
                label="Source",
                type="pil",
                brush=gr.Brush(colors=["#FFFFFF"], default_size=30, color_mode="fixed"),
                layers=False,
                interactive=True,
            )
            ref = gr.Image(label="Reference", sources="upload", type="pil", height=512)
        run_local_button = gr.Button(value="Run")

        with gr.Row():
            # Each example is (source path, reference path, shape-control flag);
            # the pairs below are (example id, flag) in display order.
            gr.Examples(
                examples=[
                    [
                        f'./demo_example/{case_id}_source.png',
                        f'./demo_example/{case_id}_reference.png',
                        keep_shape,
                    ]
                    for case_id, keep_shape in (
                        ("005", 0),
                        ("004", 0),
                        ("000", 0),
                        ("003", 0),
                        ("006", 0),
                        ("001", 1),
                        ("002", 1),
                        ("007", 1),
                    )
                ],
                inputs=[base, ref, enable_shape_control],
                cache_examples=False,
                examples_per_page=100,
            )

    # The input order must match run_local(base, ref, *args), where *args is
    # (ddim_steps, scale, seed, enable_shape_control).
    run_local_button.click(
        fn=run_local,
        inputs=[base, ref, ddim_steps, scale, seed, enable_shape_control],
        outputs=[baseline_gallery],
    )

demo.launch(server_name="0.0.0.0")