car_gan / app.py
norsu's picture
Update app.py
09b0913
raw
history blame
1.44 kB
import gradio as gr
import torch
from torch import nn
from torchvision.transforms.functional import to_pil_image
import torchvision.utils as vutils
class Generator(nn.Module):
    """DCGAN-style generator built from transposed convolutions.

    Four upsampling stages (ConvTranspose2d -> BatchNorm2d -> ReLU) are followed
    by a final ConvTranspose2d to 3 channels and a Tanh, so outputs lie in
    [-1, 1]. For a (N, 128, 1, 1) latent input the output is (N, 3, 64, 64).

    Args:
        ngpu: Stored on the instance; not used by ``forward``.
    """

    # (in_channels, out_channels, stride, padding) of each upsampling stage;
    # every stage uses a 4x4 kernel.
    _STAGES = (
        (128, 512, 1, 0),
        (512, 256, 2, 1),
        (256, 128, 2, 1),
        (128, 64, 2, 1),
    )

    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu

        # Layer order must stay exactly as-is: the Sequential indices become the
        # state_dict keys ("main.0.weight", "main.1.running_mean", ...) that a
        # saved checkpoint is matched against.
        layers = []
        for c_in, c_out, stride, pad in self._STAGES:
            layers += [
                nn.ConvTranspose2d(c_in, c_out, 4, stride, pad, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]
        layers += [
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Run the latent tensor ``x`` through the generator stack."""
        return self.main(x)
# Build the generator and load the trained weights onto the CPU.
model = Generator(ngpu=0)
# The checkpoint is consumed by load_state_dict, i.e. it is a mapping of tensors,
# so weights_only=True is sufficient and prevents unpickling arbitrary objects
# from the file. NOTE(review): weights_only requires torch >= 1.13 — confirm the
# deployed version.
model.load_state_dict(torch.load('car_gen.pth', map_location='cpu', weights_only=True))
# This app only runs inference: use BatchNorm running statistics from the start.
model.eval()
def generate(button, num_images=32, nrow=8):
    """Sample latent noise, run the generator, and tile the results into one image.

    Args:
        button: Value Gradio passes from the input component; ignored.
        num_images: Number of latent vectors to sample (default 32).
        nrow: Number of images per row in the grid (default 8).

    Returns:
        A PIL image of the grid produced by ``make_grid(..., normalize=True)``.
    """
    # BatchNorm must use its running statistics, not per-batch ones.
    model.eval()
    # Match the noise depth to the generator's first layer rather than hard-coding it.
    latent_dim = model.main[0].in_channels
    noise = torch.randn(num_images, latent_dim, 1, 1)
    with torch.inference_mode():
        # inference_mode already disables autograd tracking, so no .detach() is needed.
        predictions = model(noise).cpu()
        generated_grid = vutils.make_grid(predictions, nrow=nrow, padding=2, normalize=True)
    return to_pil_image(generated_grid)
# Gradio UI: one button input that triggers `generate`; the sampled grid is shown as an image.
Interface = gr.Interface(
    title='CarGAN',
    fn=generate,
    # A button's caption is its `value`. `label` does not set the caption and is
    # not an accepted gr.Button argument in recent Gradio releases.
    inputs=gr.Button(value='Generate', size='lg'),
    outputs=gr.Image(type='pil'),
)
Interface.launch()