Spaces:
Runtime error
Runtime error
# -*- coding: utf-8 -*- | |
""" | |
Created on Tue Apr 26 21:02:31 2022 | |
@author: pc | |
""" | |
import pickle | |
import numpy as np | |
import torch | |
import gradio as gr | |
import sys | |
import subprocess | |
import os | |
from typing import Tuple | |
import PIL.Image | |
# Fetch the StyleGAN3 code only when it is not already present. Re-cloning on every
# start is wasted work and fails anyway once the directory exists. The old
# os.system() call also ignored that failure silently.
if not os.path.isdir("stylegan3"):
    subprocess.run(
        ["git", "clone", "https://github.com/NVlabs/stylegan3"],
        check=True,  # fail loudly here rather than later during unpickling
    )
# Make the cloned repository importable. Presumably this is needed so the pickled
# generator's classes can be resolved when the checkpoint is loaded (TODO confirm).
sys.path.append("stylegan3")
# Markdown shown in the Gradio UI description; it embeds the example image ex.png.
DESCRIPTION = '''This model generates healthy MR Brain Images.
![Example](ex.png)
'''
def make_transform(translate: Tuple[float,float], angle: float):
    """Build a 3x3 homogeneous 2-D rotation + translation matrix.

    Args:
        translate: (x, y) offsets written into the last column of rows 0 and 1.
        angle: rotation angle in degrees (converted to radians below).

    Returns:
        A 3x3 float NumPy array ``[[c, s, tx], [-s, c, ty], [0, 0, 1]]``, where
        ``c`` and ``s`` are the cosine and sine of ``angle``.
    """
    radians = angle / 360.0 * np.pi * 2
    sin_a = np.sin(radians)
    cos_a = np.cos(radians)
    tx, ty = translate[0], translate[1]
    # Start from the identity so the bottom row stays [0, 0, 1].
    matrix = np.eye(3)
    matrix[:2, :] = [
        [cos_a, sin_a, tx],
        [-sin_a, cos_a, ty],
    ]
    return matrix
# Run inference on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Pretrained generator checkpoint, read from the working directory.
network_pkl = 'braingan-400.pkl'
# NOTE(review): pickle.load can execute arbitrary code. Only ever point this at a
# trusted checkpoint file.
with open(network_pkl, 'rb') as f:
    # Keep only the 'G_ema' entry of the unpickled checkpoint dict.
    G = pickle.load(f)['G_ema']
# Switch to inference mode and move the weights to the chosen device.
# (Module.eval() and Module.to() both return the module itself.)
G.eval().to(device)
def predict(Seed, noise_mode, truncation_psi, trans_x, trans_y, angle):
    """Generate one image from the module-level generator ``G``.

    Args:
        Seed: seed for the latent vector. It is cast to ``int`` because
            ``np.random.RandomState`` raises ``TypeError`` on float seeds, and
            slider widgets may deliver numbers as floats.
        noise_mode: forwarded unchanged to ``G`` as ``noise_mode`` (e.g. 'const').
        truncation_psi: forwarded unchanged to ``G`` as ``truncation_psi``.
        trans_x, trans_y: translation fed into the synthesis input transform.
        angle: rotation in degrees, combined with the translation.
            The translation and angle are only used when ``G.synthesis`` has an
            ``input`` submodule.

    Returns:
        A 512x512 ``PIL.Image.Image`` built from channel 0 of the generated batch's
        first image.
    """
    # Deterministic latent for this seed, plus an all-zero conditioning label.
    z = torch.from_numpy(np.random.RandomState(int(Seed)).randn(1, G.z_dim)).to(device)
    label = torch.zeros([1, G.c_dim], device=device)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        # Build an inverse rotation/translation matrix and pass it to the generator.
        # The generator expects this matrix already inverted, to avoid potentially
        # failing numerical operations inside the network.
        if hasattr(G.synthesis, 'input'):
            m = np.linalg.inv(make_transform((trans_x, trans_y), angle))
            G.synthesis.input.transform.copy_(torch.from_numpy(m))
        img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)

    # NCHW -> NHWC, then rescale to uint8. This assumes G outputs values roughly in
    # [-1, 1] (TODO confirm); out-of-range values are clamped.
    img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    # Keep only channel 0 of the first image, giving a grayscale 2-D array.
    return PIL.Image.fromarray(img[0].cpu().numpy()[:, :, 0]).resize((512, 512))
# Noise-mode choices offered in the UI; the selected value is passed to predict().
noises = ['const', 'random', 'none']
interface = gr.Interface(
    fn=predict,
    title="Brain MR Image Generation with StyleGAN-2",
    description=DESCRIPTION,
    article="Author: S.Serdar Helli",
    # Order must match predict(Seed, noise_mode, truncation_psi, trans_x, trans_y, angle).
    inputs=[
        # step=1 keeps seeds integral (np.random.RandomState needs an int seed).
        gr.inputs.Slider(minimum=0, maximum=2**12, step=1, label='Seed'),
        gr.inputs.Radio(choices=noises, default='const', label='Noise Mods'),
        gr.inputs.Slider(0, 2, step=0.05, default=1, label='Truncation psi'),
        gr.inputs.Slider(-1, 1, step=0.05, default=0, label='Translate X'),
        gr.inputs.Slider(-1, 1, step=0.05, default=0, label='Translate Y'),
        gr.inputs.Slider(-180, 180, step=5, default=0, label='Angle'),
    ],
    # predict() returns a PIL.Image, so declare the output type as "pil", not "numpy".
    outputs=gr.outputs.Image(type="pil", label="Output"),
)
interface.launch(debug=True)