|
import gradio as gr |
|
import torch |
|
from torch import nn |
|
from torchvision.transforms.functional import to_pil_image |
|
import torchvision.utils as vutils |
|
|
|
|
|
|
|
class Generator(nn.Module):
    """Transposed-convolution generator mapping latent vectors to images.

    The network is one flat ``nn.Sequential`` of five ``ConvTranspose2d``
    stages. The first four are each followed by ``BatchNorm2d`` and ``ReLU``;
    the last is followed by ``Tanh``, so outputs lie in ``[-1, 1]``.

    For an input of shape ``(N, nz, 1, 1)`` the spatial size grows
    1 -> 4 -> 8 -> 16 -> 32 -> 64, giving an output of shape
    ``(N, nc, 64, 64)``.

    Args:
        ngpu: Stored on the instance as ``self.ngpu``; not otherwise used by
            this class.
        nz: Number of channels of the latent input (default 128).
        ngf: Base feature-map width; the hidden stages use ``ngf * 8``,
            ``ngf * 4``, ``ngf * 2`` and ``ngf`` channels (default 64).
        nc: Number of channels of the generated image (default 3).

    With the default arguments the layer shapes and ``state_dict`` keys are
    identical to the original hard-coded architecture.
    """

    def __init__(self, ngpu, nz=128, ngf=64, nc=3):
        super().__init__()
        self.ngpu = ngpu
        # Keep this a single flat Sequential in exactly this order: the
        # state_dict keys ("main.0.weight", "main.1.weight", ...) are derived
        # from the layer positions.
        self.main = nn.Sequential(
            # (N, nz, 1, 1) -> (N, ngf*8, 4, 4)
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # -> (N, ngf*4, 8, 8)
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # -> (N, ngf*2, 16, 16)
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # -> (N, ngf, 32, 32)
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # -> (N, nc, 64, 64), squashed into [-1, 1] by Tanh
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, x):
        """Run ``x`` through the layer stack and return the generated batch."""
        return self.main(x)
|
|
|
# Build the generator and restore its weights from 'car_gen.pth'.
# map_location='cpu' remaps the checkpoint's tensors onto the CPU as they are
# loaded.
model = Generator(ngpu=0)

# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source; consider weights_only=True if the installed torch supports it.
model.load_state_dict(torch.load('car_gen.pth', map_location='cpu'))
|
|
|
|
|
def generate(button):
    """Sample a batch of images from the generator and tile them into a grid.

    Args:
        button: Value supplied by the UI input component; it is not used.

    Returns:
        A PIL image holding a grid of 32 generated samples, 8 per row.
    """
    num_images = 32       # samples drawn per call
    latent_dim = 128      # channel count of the noise fed to the model
    images_per_row = 8    # grid width passed to make_grid

    # Put BatchNorm layers into inference behaviour before sampling.
    model.eval()

    noise = torch.randn(num_images, latent_dim, 1, 1)

    # Everything that touches the model output stays inside inference_mode;
    # no autograd state is recorded there, so an explicit .detach() is not
    # needed.
    with torch.inference_mode():
        predictions = model(noise).cpu()
        # normalize=True rescales the batch into [0, 1] for display.
        generated_grid = vutils.make_grid(
            predictions, nrow=images_per_row, padding=2, normalize=True
        )
        return to_pil_image(generated_grid)
|
|
|
|
|
# Gradio UI: one button as the input component and one image as the output.
# `generate` receives the input component's value and returns a PIL image,
# which matches the output component's type='pil'.
Interface = gr.Interface(
    title='CarGAN',
    fn=generate,
    inputs=gr.Button(label='Generate', size='lg'),
    outputs=gr.Image(type='pil'),
)

# Start the web app when this script is executed.
Interface.launch()
|
|