import os import sys import torch import gradio as gr import numpy as np import torchvision.transforms as transforms from torch.autograd import Variable from network.Transformer import Transformer import logging logger = logging.getLogger(__name__) MAX_DIMENSION = 1280 MODEL_PATH = "models" COLOUR_MODEL = "RGB" STYLE_SHINKAI = "Makoto Shinkai" STYLE_HOSODA = "Mamoru Hosoda" STYLE_MIYAZAKI = "Hayao Miyazaki" STYLE_KON = "Satoshi Kon" DEFAULT_STYLE = STYLE_SHINKAI STYLE_CHOICE_LIST = [STYLE_SHINKAI, STYLE_HOSODA, STYLE_MIYAZAKI, STYLE_KON] shinkai_model = Transformer() hosoda_model = Transformer() miyazaki_model = Transformer() kon_model = Transformer() shinkai_model.load_state_dict( torch.load(os.path.join(MODEL_PATH, "shinkai_makoto.pth")) ) hosoda_model.load_state_dict( torch.load(os.path.join(MODEL_PATH, "hosoda_mamoru.pth")) ) miyazaki_model.load_state_dict( torch.load(os.path.join(MODEL_PATH, "miyazaki_hayao.pth")) ) kon_model.load_state_dict( torch.load(os.path.join(MODEL_PATH, "kon_satoshi.pth")) ) shinkai_model.eval() hosoda_model.eval() miyazaki_model.eval() kon_model.eval() disable_gpu = True def get_model(style): if style == STYLE_SHINKAI: return shinkai_model elif style == STYLE_HOSODA: return hosoda_model elif style == STYLE_MIYAZAKI: return miyazaki_model elif style == STYLE_KON: return kon_model else: logger.warning( f"Style {style} not found. Defaulting to Makoto Shinkai" ) return shinkai_model def validate_image_size(img): logger.info(f"Image Height: {img.height}, Image Width: {img.width}") if img.height > MAX_DIMENSION or img.width > MAX_DIMENSION: raise RuntimeError( "Image size is too large. Please use an image less than {MAX_DIMENSION}px on both width and height" ) def inference(img, style): validate_image_size(img) # load image input_image = img.convert(COLOUR_MODEL) input_image = np.asarray(input_image) # RGB -> BGR input_image = input_image[:, :, [2, 1, 0]] input_image = transforms.ToTensor()(input_image).unsqueeze(0) # preprocess, (-1, 1) input_image = -1 + 2 * input_image if disable_gpu: input_image = Variable(input_image).float() else: input_image = Variable(input_image).cuda() # forward model = get_model(style) output_image = model(input_image) output_image = output_image[0] # BGR -> RGB output_image = output_image[[2, 1, 0], :, :] output_image = output_image.data.cpu().float() * 0.5 + 0.5 return transforms.ToPILImage()(output_image) title = "Anime Background GAN" description = "Gradio Demo for CartoonGAN by Chen Et. Al. Models are Shinkai Makoto, Hosoda Mamoru, Kon Satoshi, and Miyazaki Hayao." article = "
CartoonGAN Whitepaper from Chen et.al
Original Implementation from Yijunmaverick