# NOTE: page-scrape residue, not part of the original script: "Spaces: Build error"
# -*- coding: utf-8 -*-
"""evaluate_gan_gradio.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1ckZU76dq3XWcpa5PpQF8a6qJwkTttg8v

# ⚙️ Setup
"""
#!pip install gradio -q
#!pip install wget -q
#!pip install tensorflow_addons -q
"""## Fix random seeds"""
# One seed shared by every random number generator seeded below.
SEED = 11

import os

# NOTE: Python reads PYTHONHASHSEED at interpreter start-up, so this assignment
# only affects Python processes launched after it, not this process's hashing.
os.environ['PYTHONHASHSEED'] = str(SEED)

import random

import numpy as np
import tensorflow as tf

# Seed the Python, NumPy and TensorFlow generators with the same value.
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

"""## Imports"""
import gradio as gr
import wget
import pandas as pd
import shutil

"""## Download CelebA attributes
We'll use face images from the CelebA dataset, resized to 64x64.
"""
# Download the labels from a public GitHub repository; they have already been
# processed into a 0/1 CSV file.
_CELEBA_DIR = "content/celeba_gan"
_ATTR_ZIP = _CELEBA_DIR + "/list_attr_celeba01.csv.zip"

# exist_ok=True so that re-running the script does not fail on the existing folder.
os.makedirs(_CELEBA_DIR, exist_ok=True)

# wget.download() does not overwrite: if the target exists it saves the new
# copy under a renamed file. Skip the download when the archive is present.
if not os.path.exists(_ATTR_ZIP):
    wget.download(
        url="https://github.com/buoi/conditional-face-GAN/blob/main/list_attr_celeba01.csv.zip?raw=true",
        out=_ATTR_ZIP,
    )

shutil.unpack_archive(filename=_ATTR_ZIP, extract_dir=_CELEBA_DIR)

"""## Dataset preprocessing functions"""
# image utils functions
def conv_range(in_range=(-1, 1), out_range=(0, 255)):
    """Return a function that linearly maps values from one range to another.

    The mapping is built from the midpoint (``np.mean``) and the width
    (``np.ptp``) of each range, so the order of the two endpoints inside a
    range is ignored.

    Args:
        in_range: pair of numbers bounding the input values.
        out_range: pair of numbers bounding the output values.

    Returns:
        A one-argument function that converts a scalar or array from
        ``in_range`` to ``out_range``.

    Raises:
        ValueError: if ``in_range`` has zero width (the conversion would
            divide by zero).
    """
    # compute means and spans once, so the returned closure only does arithmetic
    in_mean, out_mean = np.mean(in_range), np.mean(out_range)
    in_span, out_span = np.ptp(in_range), np.ptp(out_range)
    if in_span == 0:
        raise ValueError(f"in_range must have non-zero width, got {in_range!r}")

    def convert_img_range(in_img):
        """Map ``in_img`` from ``in_range`` to ``out_range``."""
        # centre and normalise to [-0.5, 0.5], then scale and shift to the target
        normalised = (in_img - in_mean) / in_span
        return normalised * out_span + out_mean

    return convert_img_range
def crop128(img):
    """Crop a 128x128 window out of axes 1 and 2 of ``img``.

    Axis 0 is kept whole and any axes after the third are untouched. For
    218x178 inputs the window (45:173, 25:153) is the centre crop.

    Args:
        img: array-like supporting NumPy-style slicing with at least 3 axes;
            presumably a batch laid out as (batch, height, width, ...).

    Returns:
        The sliced view/array with 128 entries along axes 1 and 2 (fewer if
        the input is smaller than the window).
    """
    # return img[:, 77:141, 57:121]  # 64,64 center crop
    return img[:, 45:173, 25:153]  # 128,128 center crop
def resize64(img):
    """Resize ``img`` to 64x64 using antialiased bilinear interpolation.

    Args:
        img: image or batch of images accepted by ``tf.image.resize``.

    Returns:
        The resized tensor produced by ``tf.image.resize``.
    """
    return tf.image.resize(img, (64, 64), method='bilinear', antialias=True)


"""# Evaluate model
## Load trained GAN
"""
# wandb artifacts
# sagan40 v18
#keras_metadata_url = "https://api.wandb.ai/artifactsV2/gcp-us/buianifolli/QXJ0aWZhY3Q6Mjg3MzA4NTY=/5f09f68e9bb5b09efbc37ad76cdcdbb0"
#saved_model_url = "https://api.wandb.ai/artifactsV2/gcp-us/buianifolli/QXJ0aWZhY3Q6Mjg3NDY1OTU=/2676cd88ef1866d6e572916e413a933e"
#variables_url = "https://api.wandb.ai/artifactsV2/gcp-us/buianifolli/QXJ0aWZhY3Q6Mjg3NDY1OTU=/5cab1cb7351f0732ea137fb2d2e0d4ec"
#index_url = "https://api.wandb.ai/artifactsV2/gcp-us/buianifolli/QXJ0aWZhY3Q6Mjg3NDY1OTU=/480b55762c3358f868b8cce53984736b"
# sagan10 v16
keras_metadata_url = "https://api.wandb.ai/artifactsV2/gcp-us/buianifolli/QXJ0aWZhY3Q6MjYxMDQwMDE=/392d036bf91d3648eb5a2fa74c1eb716"
saved_model_url = "https://api.wandb.ai/artifactsV2/gcp-us/buianifolli/QXJ0aWZhY3Q6MjYxMzQ0Mjg=/a5f8608efcc5dafbe780babcffbc79a9"
variables_url = "https://api.wandb.ai/artifactsV2/gcp-us/buianifolli/QXJ0aWZhY3Q6MjYxMzQ0Mjg=/a62bf0c4bf7047c0a31df7d2cfdb54f0"
index_url = "https://api.wandb.ai/artifactsV2/gcp-us/buianifolli/QXJ0aWZhY3Q6MjYxMzQ0Mjg=/de6539a7f0909d1dafa89571c7df43d1"

# download model
gan_path = "content/gan_model/"

os.makedirs(gan_path, exist_ok=True)
os.makedirs(gan_path + "variables", exist_ok=True)

import wget


def _download_fresh(url, destination):
    """Download ``url`` to ``destination``, replacing any previous copy.

    The old file is removed first because wget.download() does not overwrite
    an existing target; it saves the new copy under a renamed file instead.
    Each file is removed independently, so one missing file does not leave
    the others stale.
    """
    try:
        os.remove(destination)
    except FileNotFoundError:
        pass
    wget.download(url, destination)


# The four files that make up the saved model directory loaded below.
_download_fresh(keras_metadata_url, gan_path + "keras_metadata.pb")
_download_fresh(saved_model_url, gan_path + "saved_model.pb")
_download_fresh(variables_url, gan_path + "variables/variables.data-00000-of-00001")
_download_fresh(index_url, gan_path + "variables/variables.index")

gan = tf.keras.models.load_model(gan_path)
IMAGE_RANGE = '11'

# Second entry of the discriminator's input shape; compared against image
# heights below, so the shape is presumably (batch, height, width, channels).
IMAGE_SIZE = gan.discriminator.input_shape[1]
if IMAGE_SIZE == 64:
    IMAGE_SHAPE = (64, 64, 3)
elif IMAGE_SIZE == 218:
    IMAGE_SHAPE = (218, 178, 3)
else:
    # Previously IMAGE_SHAPE was left undefined for any other size.
    # NOTE(review): assumes a single-input discriminator; confirm.
    IMAGE_SHAPE = tuple(gan.discriminator.input_shape[1:])

try:
    # Two-input generator: input_shape is a sequence of shapes,
    # (latent shape, attribute shape).
    LATENT_DIM = gan.generator.input_shape[0][1]
    N_ATTRIBUTES = gan.generator.input_shape[1][1]
except TypeError:
    # Single-input generator: input_shape is one shape whose first entry is
    # not subscriptable, so there are no conditioning attributes.
    LATENT_DIM = gan.generator.input_shape[1]
    N_ATTRIBUTES = 0

"""## 💾 Dataset"""
#@title Select Attributes {form-width: "50%", display-mode: "both" }
#NUMBER_OF_ATTRIBUTES = "10" #@param [0, 2, 10, 12, 40]
#N_ATTRIBUTES = int(NUMBER_OF_ATTRIBUTES)
IMAGE_RANGE = '11'
BATCH_SIZE = 64 #@param {type: "number"}

# Attribute names for each supported number of conditioning attributes.
# The position of a name in its list is the index set in the attribute vector
# passed to the generator.
# NOTE(review): the order presumably has to match training; confirm.
_LABELS_BY_COUNT = {
    2: ["Male", "Smiling"],
    10: [
        "Mouth_Slightly_Open", "Wearing_Lipstick", "High_Cheekbones", "Male", "Smiling",
        "Heavy_Makeup", "Wavy_Hair", "Oval_Face", "Pointy_Nose", "Arched_Eyebrows"],
    12: [
        'Wearing_Lipstick', 'Mouth_Slightly_Open', 'Male', 'Smiling',
        'High_Cheekbones', 'Heavy_Makeup', 'Attractive', 'Young',
        'No_Beard', 'Black_Hair', 'Arched_Eyebrows', 'Big_Nose'],
    40: [
        '5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive',
        'Bags_Under_Eyes', 'Bald', 'Bangs', 'Big_Lips', 'Big_Nose',
        'Black_Hair', 'Blond_Hair', 'Blurry', 'Brown_Hair', 'Bushy_Eyebrows',
        'Chubby', 'Double_Chin', 'Eyeglasses', 'Goatee', 'Gray_Hair',
        'Heavy_Makeup', 'High_Cheekbones', 'Male', 'Mouth_Slightly_Open',
        'Mustache', 'Narrow_Eyes', 'No_Beard', 'Oval_Face', 'Pale_Skin',
        'Pointy_Nose', 'Receding_Hairline', 'Rosy_Cheeks', 'Sideburns',
        'Smiling', 'Straight_Hair', 'Wavy_Hair', 'Wearing_Earrings',
        'Wearing_Hat', 'Wearing_Lipstick', 'Wearing_Necklace',
        'Wearing_Necktie', 'Young'],
}

# Any other attribute count falls back to two labels, just for dataset creation.
LABELS = _LABELS_BY_COUNT.get(N_ATTRIBUTES, ["Male", "Smiling"])

# Load the attribute table and keep the selected label columns as a list of rows.
df = pd.read_csv(r"content/celeba_gan/list_attr_celeba01.csv")
attr_list = df[LABELS].values.tolist()
def gen_img(attributes):
    """Generate one image from the GAN for the selected attributes.

    Args:
        attributes: iterable of attribute indices (positions in ``LABELS``)
            to switch on; ``None`` or an empty iterable selects none.

    Returns:
        The first (and only) generated image as a NumPy array, rescaled with
        ``x * 0.5 + 0.5``. That maps [-1, 1] to [0, 1], presumably matching
        IMAGE_RANGE = '11'; confirm the generator's output range.
    """
    num_img = 1
    random_latent_vectors = tf.random.normal(shape=(num_img, LATENT_DIM))

    if N_ATTRIBUTES:
        # One-hot style conditioning vector: 1 for each selected attribute.
        attr = np.zeros((1, N_ATTRIBUTES))
        for a in attributes or ():
            attr[0, int(a)] = 1
        generated_images = gan.generator((random_latent_vectors, attr))
    else:
        # N_ATTRIBUTES == 0 means the generator takes only the latent vector,
        # so no attribute tensor must be passed.
        generated_images = gan.generator(random_latent_vectors)

    generated_images = (generated_images * 0.5 + 0.5).numpy()
    print(generated_images[0].shape)
    return generated_images[0]
# Web UI: one checkbox per conditioning attribute, one generated image out.
# NOTE(review): `gr.inputs.CheckboxGroup` and the `layout` argument belong to
# the legacy Gradio API; confirm the pinned Gradio version still provides them.
iface = gr.Interface(
    fn=gen_img,
    # type='index' makes the callback receive the positions of the ticked
    # boxes rather than their names, which is what gen_img expects.
    inputs=gr.inputs.CheckboxGroup(
        [LABELS[i] for i in range(N_ATTRIBUTES)],
        type='index',
    ),
    outputs="image",
    layout='unaligned',
)
iface.launch()