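"""Gradio demo for RGB -> X: estimate intrinsic channels (albedo, normal,
roughness, metallic, irradiance) from a single photograph with the
zheng95z/rgb-to-x diffusion pipeline."""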
import spaces
import os
from typing import cast
import gradio as gr
from PIL import Image
import torch
import torchvision
from diffusers import DDIMScheduler
from load_image import load_exr_image, load_ldr_image
from pipeline_rgb2x import StableDiffusionAOVMatEstPipeline
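
# Allow OpenCV to read EXR files (disabled by default for security reasons).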
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
current_directory = os.path.dirname(os.path.abspath(__file__))
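
# Download (or load from the local model_cache) the RGB->X material-estimation
# pipeline in fp16 and move it to the GPU.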
_pipe = StableDiffusionAOVMatEstPipeline.from_pretrained(
"zheng95z/rgb-to-x",
torch_dtype=torch.float16,
cache_dir=os.path.join(current_directory, "model_cache"),
).to("cuda")
pipe = cast(StableDiffusionAOVMatEstPipeline, _pipe)
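# Use DDIM with zero-terminal-SNR rescaling and trailing timestep spacing.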
pipe.scheduler = DDIMScheduler.from_config(
pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
)
pipe.set_progress_bar_config(disable=True)
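
# On ZeroGPU Spaces, @spaces.GPU allocates a GPU for the duration of each call.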
@spaces.GPU
def generate(
photo,
seed: int,
inference_step: int,
num_samples: int,
) -> list[Image.Image]:
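    """Estimate the intrinsic channels (AOVs) of a photo and return them as
    (image, caption) pairs for the output gallery."""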
generator = torch.Generator(device="cuda").manual_seed(seed)
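    # Load the image as a CUDA tensor; EXR files are tonemapped and clamped,
    # PNG/JPEG files are loaded as sRGB (see the load_image helpers).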
    if photo.name.endswith(".exr"):
        photo = load_exr_image(photo.name, tonemaping=True, clamp=True).to("cuda")
    elif photo.name.endswith((".png", ".jpg", ".jpeg")):
        photo = load_ldr_image(photo.name, from_srgb=True).to("cuda")
    else:
        raise gr.Error("Unsupported file type; please upload an .exr, .png, .jpg, or .jpeg image.")
    # Resize so the longest side is at most 1000 px and both dimensions are
    # multiples of 8, as required by the latent-diffusion UNet/VAE.
    # photo is a CHW tensor; compute a resize that preserves the aspect ratio.
    old_height = photo.shape[1]
    old_width = photo.shape[2]
    ratio = old_height / old_width
    max_side = 1000
    if old_height > old_width:
        new_height = max_side
        new_width = int(new_height / ratio)
    else:
        new_width = max_side
        new_height = int(new_width * ratio)
    # Floor both dimensions to the nearest multiple of 8.
    new_width = new_width // 8 * 8
    new_height = new_height // 8 * 8
photo = torchvision.transforms.Resize((new_height, new_width))(photo)
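
    # One text prompt per intrinsic channel (AOV) supported by the pipeline.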
required_aovs = ["albedo", "normal", "roughness", "metallic", "irradiance"]
prompts = {
"albedo": "Albedo (diffuse basecolor)",
"normal": "Camera-space Normal",
"roughness": "Roughness",
"metallic": "Metallicness",
"irradiance": "Irradiance (diffuse lighting)",
}
return_list = []
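    # Sample each AOV num_samples times; resize every result back to the
    # original resolution and caption it for the gallery.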
for i in range(num_samples):
for aov_name in required_aovs:
prompt = prompts[aov_name]
generated_image = pipe(
prompt=prompt,
photo=photo,
num_inference_steps=inference_step,
height=new_height,
width=new_width,
generator=generator,
required_aovs=[aov_name],
).images[0][0] # type: ignore
generated_image = torchvision.transforms.Resize((old_height, old_width))(
generated_image
)
            return_list.append((generated_image, f"Generated {aov_name} {i}"))
return return_list
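
# Build the UI: photo and parameters on the left, result gallery on the right.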
with gr.Blocks() as demo:
with gr.Row():
gr.Markdown("## Model RGB -> X (Realistic image -> Intrinsic channels)")
with gr.Row():
# Input side
with gr.Column():
gr.Markdown("### Given Image")
            photo = gr.File(label="Photo", file_types=[".exr", ".png", ".jpg", ".jpeg"])
gr.Markdown("### Parameters")
run_button = gr.Button(value="Run")
with gr.Accordion("Advanced options", open=False):
seed = gr.Slider(
label="Seed",
minimum=-1,
maximum=2147483647,
step=1,
randomize=True,
)
inference_step = gr.Slider(
label="Inference Step",
minimum=1,
maximum=100,
step=1,
value=50,
)
num_samples = gr.Slider(
label="Samples",
minimum=1,
maximum=100,
step=1,
value=1,
)
# Output side
with gr.Column():
gr.Markdown("### Output Gallery")
result_gallery = gr.Gallery(
label="Output",
show_label=False,
elem_id="gallery",
columns=2,
)
    # The example row must supply every input of `generate`, otherwise caching
    # fails when Gradio calls the function with only the photo argument.
    examples = gr.Examples(
        examples=[
            [
                "rgb2x/example/Castlereagh_corridor_photo.png",
                42,  # fixed seed so the cached example is reproducible
                50,  # inference steps (slider default)
                1,  # samples (slider default)
            ]
        ],
        inputs=[photo, seed, inference_step, num_samples],
        outputs=[result_gallery],
        fn=generate,
        cache_mode="eager",
        cache_examples=True,
    )
run_button.click(
fn=generate,
inputs=[photo, seed, inference_step, num_samples],
outputs=result_gallery,
queue=True,
)
if __name__ == "__main__":
demo.launch(debug=False, share=False, show_api=False)