import spaces  # ZeroGPU requires `spaces` to be imported before torch

import os

# Enable OpenEXR support in OpenCV. The flag has to be set before cv2 is
# first imported (load_image may import cv2), so it precedes the local imports.
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"

from typing import cast

import gradio as gr
import torch
import torchvision
from diffusers import DDIMScheduler
from PIL import Image

from load_image import load_exr_image, load_ldr_image
from pipeline_rgb2x import StableDiffusionAOVMatEstPipeline
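
# Model setup: weights are cached next to this file and loaded in fp16 on the GPU.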
current_directory = os.path.dirname(os.path.abspath(__file__))

_pipe = StableDiffusionAOVMatEstPipeline.from_pretrained(
    "zheng95z/rgb-to-x",
    torch_dtype=torch.float16,
    cache_dir=os.path.join(current_directory, "model_cache"),
).to("cuda")
pipe = cast(StableDiffusionAOVMatEstPipeline, _pipe)

# Zero-terminal-SNR DDIM with trailing timestep spacing.
pipe.scheduler = DDIMScheduler.from_config(
    pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
)
pipe.set_progress_bar_config(disable=True)
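
# `@spaces.GPU` requests a GPU for the duration of each call on Hugging Face
# ZeroGPU Spaces; outside of Spaces the decorator should be a harmless pass-through.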
@spaces.GPU
def generate(
    photo,
    seed: int,
    inference_step: int,
    num_samples: int,
) -> list[tuple[Image.Image, str]]:
    generator = torch.Generator(device="cuda").manual_seed(seed)

    # Load the uploaded file as a (C, H, W) float tensor on the GPU.
    if photo.name.endswith(".exr"):
        photo = load_exr_image(photo.name, tonemaping=True, clamp=True).to("cuda")
    elif photo.name.endswith((".png", ".jpg", ".jpeg")):
        photo = load_ldr_image(photo.name, from_srgb=True).to("cuda")
    else:
        raise gr.Error("Unsupported file type; upload an .exr, .png, or .jpg image.")
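
    # Stable Diffusion's VAE downsamples by 8x, so both spatial dimensions
    # must be multiples of 8; the longer side is also normalized to 1000 px.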
    old_height = photo.shape[1]
    old_width = photo.shape[2]
    aspect_ratio = old_height / old_width
    max_side = 1000
    if old_height > old_width:
        new_height = max_side
        new_width = int(new_height / aspect_ratio)
    else:
        new_width = max_side
        new_height = int(new_width * aspect_ratio)
    new_width = new_width // 8 * 8
    new_height = new_height // 8 * 8

    photo = torchvision.transforms.Resize((new_height, new_width))(photo)
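
    # The model is prompted with one text description per intrinsic channel
    # (AOV: albedo, normal, roughness, metallic, irradiance).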
    required_aovs = ["albedo", "normal", "roughness", "metallic", "irradiance"]
    prompts = {
        "albedo": "Albedo (diffuse basecolor)",
        "normal": "Camera-space Normal",
        "roughness": "Roughness",
        "metallic": "Metallicness",
        "irradiance": "Irradiance (diffuse lighting)",
    }
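
    # Run one diffusion pass per (sample, AOV) pair and collect
    # (image, caption) tuples for the gallery.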
    return_list = []
    for i in range(num_samples):
        for aov_name in required_aovs:
            prompt = prompts[aov_name]
            generated_image = pipe(
                prompt=prompt,
                photo=photo,
                num_inference_steps=inference_step,
                height=new_height,
                width=new_width,
                generator=generator,
                required_aovs=[aov_name],
            ).images[0][0]

            # Resize the prediction back to the input resolution.
            generated_image = torchvision.transforms.Resize((old_height, old_width))(
                generated_image
            )
            return_list.append((generated_image, f"Generated {aov_name} {i}"))

    return return_list
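
# Gradio UI: photo upload and sampling parameters on the left, the generated
# intrinsic channels in a gallery on the right.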
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("## Model RGB -> X (Realistic image -> Intrinsic channels)")
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Given Image")
            photo = gr.File(label="Photo", file_types=[".exr", ".png", ".jpg", ".jpeg"])

            gr.Markdown("### Parameters")
            run_button = gr.Button(value="Run")
            with gr.Accordion("Advanced options", open=False):
                seed = gr.Slider(
                    label="Seed",
                    minimum=-1,
                    maximum=2147483647,
                    step=1,
                    randomize=True,
                )
                inference_step = gr.Slider(
                    label="Inference Step",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=50,
                )
                num_samples = gr.Slider(
                    label="Samples",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=1,
                )
        with gr.Column():
            gr.Markdown("### Output Gallery")
            result_gallery = gr.Gallery(
                label="Output",
                show_label=False,
                elem_id="gallery",
                columns=2,
            )
            # Cached examples invoke `generate` directly, so an example value
            # is needed for every input of the function, not just the photo.
            examples = gr.Examples(
                examples=[
                    [
                        "rgb2x/example/Castlereagh_corridor_photo.png",
                        42,
                        50,
                        1,
                    ]
                ],
                inputs=[photo, seed, inference_step, num_samples],
                outputs=[result_gallery],
                fn=generate,
                cache_mode="eager",
                cache_examples=True,
            )
    run_button.click(
        fn=generate,
        inputs=[photo, seed, inference_step, num_samples],
        outputs=result_gallery,
        queue=True,
    )


if __name__ == "__main__":
    demo.launch(debug=False, share=False, show_api=False)