Akash James
Default CUDA and Post Processing as True; Potential CPU fix
c9fd122
from __future__ import print_function
import torch
import process_stylization
from photo_wct import PhotoWCT
import gradio as gr
from datetime import datetime
# Load the PhotoWCT stylization network once at import time; run() moves it
# to the requested device (GPU or CPU) on each call.
model_path = './models/photo_wct.pth'
p_wct = PhotoWCT()
# map_location='cpu' lets a checkpoint that was saved with CUDA tensors load
# on CPU-only machines too; without it torch.load would try to restore the
# tensors onto a CUDA device and fail when none is available.
p_wct.load_state_dict(torch.load(model_path, map_location='cpu'))
def run(content_img, style_img, cuda, post_processing, fast):
    """Stylize ``content_img`` with the look of ``style_img``.

    This is the callback wired into the Gradio interface; its parameters
    follow the order of the interface inputs.

    Args:
        content_img: Image from the "Content Image" input.
        style_img: Image from the style-image input.
        cuda: Whether the user asked for GPU execution. If
            ``torch.cuda.is_available()`` is False the request is ignored
            and the model runs on CPU instead of raising.
        post_processing: Forwarded unchanged to
            ``process_stylization.stylization_gradio``.
        fast: Index of the selected "Algorithm" radio button. ``0`` selects
            ``GIFSmoothing`` (the fast guided-filter option); any other value
            selects ``Propagator`` (the slow smoothing option).

    Returns:
        Whatever ``process_stylization.stylization_gradio`` returns; the
        interface displays it in a ``gr.Image(type="pil")`` output.
    """
    print("[TimeStamp] {}".format(datetime.now().strftime("%d/%m/%Y %H:%M:%S")))

    # The smoothing backends are imported lazily, so only the selected one's
    # module is loaded.
    if fast == 0:
        from photo_gif import GIFSmoothing
        p_pro = GIFSmoothing(r=35, eps=0.001)
    else:
        from photo_smooth import Propagator
        p_pro = Propagator()

    # The "Use CUDA" checkbox defaults to True, so fall back to CPU on hosts
    # without a usable GPU instead of crashing inside .cuda(), and pass the
    # *effective* device choice on to the stylization routine.
    use_cuda = bool(cuda) and torch.cuda.is_available()
    if cuda and not use_cuda:
        print("[Warning] CUDA requested but not available; running on CPU.")
    if use_cuda:
        p_wct.cuda(0)
    else:
        p_wct.to('cpu')

    output_img = process_stylization.stylization_gradio(
        stylization_module=p_wct,
        smoothing_module=p_pro,
        content_image=content_img,
        style_image=style_img,
        cuda=use_cuda,
        post_processing=post_processing,
    )
    return output_img
if __name__ == '__main__':
    # Build the web UI; the order of `inputs` matches run()'s parameters.
    style = gr.Interface(
        fn=run,
        inputs=[
            gr.Image(label='Content Image'),
            gr.Image(label='Style Image'),
            gr.Checkbox(value=True, label='Use CUDA'),
            gr.Checkbox(value=True, label='Post Processing'),
            # type="index" hands run() the position of the selected choice
            # (0 = fast guided filtering, 1 = slow smoothing) as `fast`.
            # interactive=False keeps the selection fixed at the default.
            gr.Radio(
                choices=["Guided Image Filtering (Fast)", "Photorealistic Smoothing (Slow)"],
                value="Guided Image Filtering (Fast)",
                type="index",
                label="Algorithm",
                interactive=False,
            ),
        ],
        outputs=[
            gr.Image(type="pil", label="Result"),
        ],
    )
    # Enable request queuing before starting the server.
    style.queue()
    style.launch()