# attr-cond-gan / app.py — author: buio — "fixed /content" (commit 177c91e)
# -*- coding: utf-8 -*-
"""evaluate_gan_gradio.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1ckZU76dq3XWcpa5PpQF8a6qJwkTttg8v
# ⚙️ Setup
"""
#!pip install gradio -q
#!pip install wget -q
#!pip install tensorflow_addons -q
"""## Fix random seeds"""
SEED = 11
import os
os.environ['PYTHONHASHSEED']=str(SEED)
import random
import numpy as np
import tensorflow as tf
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)
"""## Imports"""
import gradio as gr
import wget
import pandas as pd
import shutil
"""## Download CelebA attributes
We'll use face images from the CelebA dataset, resized to 64x64.
"""
#Download labels from public github, they have been processed in a 0,1 csv file
os.makedirs("content/celeba_gan")
wget.download(url="https://github.com/buoi/conditional-face-GAN/blob/main/list_attr_celeba01.csv.zip?raw=true", out="content/celeba_gan/list_attr_celeba01.csv.zip")
shutil.unpack_archive(filename="content/celeba_gan/list_attr_celeba01.csv.zip", extract_dir="content/celeba_gan")
"""## Dataset preprocessing functions"""
# image utils functions
def conv_range(in_range=(-1, 1), out_range=(0, 255)):
    """Build a function that linearly remaps values from *in_range* to *out_range*.

    The mapping is affine: the center of ``in_range`` lands on the center of
    ``out_range`` and the spans are rescaled accordingly. Means and spans are
    computed once at build time, so the returned closure is cheap to call.
    """
    src_center, dst_center = np.mean(in_range), np.mean(out_range)
    src_width, dst_width = np.ptp(in_range), np.ptp(out_range)

    def _remap(img):
        """Shift/scale *img* from the source range into the target range."""
        normalized = (img - src_center) / src_width
        return normalized * dst_width + dst_center

    return _remap
def crop128(img):
    """Return the 128x128 center crop of a batch of 218x178 CelebA frames.

    ``img`` is indexed as (batch, height, width, ...); only the spatial axes
    are sliced, any trailing channel axis passes through untouched.
    """
    rows = slice(45, 173)   # 218 -> 128 vertically
    cols = slice(25, 153)   # 178 -> 128 horizontally
    # (A 64x64 center crop, img[:, 77:141, 57:121], was used previously.)
    return img[:, rows, cols]
def resize64(img):
    """Bilinear-resize an image batch to 64x64, with antialiasing enabled."""
    target_hw = (64, 64)
    return tf.image.resize(img, target_hw, method='bilinear', antialias=True)
"""# πŸ“‰ Evaluate model
## Load trained GAN
"""
#wandb artifacts
#sagan40 v18
#keras_metadata_url = "https://api.wandb.ai/artifactsV2/gcp-us/buianifolli/QXJ0aWZhY3Q6Mjg3MzA4NTY=/5f09f68e9bb5b09efbc37ad76cdcdbb0"
#saved_model_url = "https://api.wandb.ai/artifactsV2/gcp-us/buianifolli/QXJ0aWZhY3Q6Mjg3NDY1OTU=/2676cd88ef1866d6e572916e413a933e"
#variables_url = "https://api.wandb.ai/artifactsV2/gcp-us/buianifolli/QXJ0aWZhY3Q6Mjg3NDY1OTU=/5cab1cb7351f0732ea137fb2d2e0d4ec"
#index_url = "https://api.wandb.ai/artifactsV2/gcp-us/buianifolli/QXJ0aWZhY3Q6Mjg3NDY1OTU=/480b55762c3358f868b8cce53984736b"
#sagan10 v16
keras_metadata_url = "https://api.wandb.ai/artifactsV2/gcp-us/buianifolli/QXJ0aWZhY3Q6MjYxMDQwMDE=/392d036bf91d3648eb5a2fa74c1eb716"
saved_model_url = "https://api.wandb.ai/artifactsV2/gcp-us/buianifolli/QXJ0aWZhY3Q6MjYxMzQ0Mjg=/a5f8608efcc5dafbe780babcffbc79a9"
variables_url = "https://api.wandb.ai/artifactsV2/gcp-us/buianifolli/QXJ0aWZhY3Q6MjYxMzQ0Mjg=/a62bf0c4bf7047c0a31df7d2cfdb54f0"
index_url = "https://api.wandb.ai/artifactsV2/gcp-us/buianifolli/QXJ0aWZhY3Q6MjYxMzQ0Mjg=/de6539a7f0909d1dafa89571c7df43d1"
#download model
gan_path = "content/gan_model/"
try:
os.remove(gan_path+"keras_metadata.pb")
os.remove(gan_path+"saved_model.pb")
os.remove(gan_path+"variables/variables.data-00000-of-00001")
os.remove(gan_path+"variables/variables.index")
except FileNotFoundError:
pass
os.makedirs(gan_path,exist_ok =True)
os.makedirs(gan_path+"/variables",exist_ok =True)
import wget
wget.download(keras_metadata_url, gan_path+"keras_metadata.pb",)
wget.download(saved_model_url, gan_path+"saved_model.pb")
wget.download(variables_url, gan_path+"variables/variables.data-00000-of-00001")
wget.download(index_url, gan_path+"variables/variables.index")
gan = tf.keras.models.load_model(gan_path)
# Value-range tag for images ('11' appears to mean [-1, 1], matching the
# *0.5+0.5 rescaling below); re-assigned with the same value in the dataset
# section further down.
IMAGE_RANGE='11'
# Discriminator input is (batch, H, W, C); take H as the square image side.
IMAGE_SIZE = gan.discriminator.input_shape[1]
if IMAGE_SIZE == 64:
    IMAGE_SHAPE = (64,64,3)
elif IMAGE_SIZE == 218:
    IMAGE_SHAPE = (218,178,3)  # full aligned CelebA frame is 218x178
# NOTE(review): any other IMAGE_SIZE leaves IMAGE_SHAPE undefined; nothing in
# this file reads IMAGE_SHAPE afterwards, but a NameError would surface if a
# later revision did — confirm intended sizes.
try:
    # Conditional generator: two inputs (latent, attributes), so input_shape
    # is a list of two shape tuples.
    LATENT_DIM = gan.generator.input_shape[0][1]
    N_ATTRIBUTES = gan.generator.input_shape[1][1]
except TypeError:
    # Unconditional generator: input_shape is a single tuple; indexing it
    # with [0][1] raises TypeError, so fall back to no attribute vector.
    LATENT_DIM = gan.generator.input_shape[1]
    N_ATTRIBUTES =0
"""## πŸ’Ύ Dataset"""
#@title Select Attributes {form-width: "50%", display-mode: "both" }
#NUMBER_OF_ATTRIBUTES = "10" #@param [0, 2, 10, 12, 40]
#N_ATTRIBUTES = int(NUMBER_OF_ATTRIBUTES)
IMAGE_RANGE = '11'
BATCH_SIZE = 64 #@param {type: "number"}
if N_ATTRIBUTES == 2:
LABELS = ["Male", "Smiling"]
elif N_ATTRIBUTES == 10:
LABELS = [
"Mouth_Slightly_Open", "Wearing_Lipstick", "High_Cheekbones", "Male", "Smiling",
"Heavy_Makeup", "Wavy_Hair", "Oval_Face", "Pointy_Nose", "Arched_Eyebrows"]
elif N_ATTRIBUTES == 12:
LABELS = ['Wearing_Lipstick','Mouth_Slightly_Open','Male','Smiling',
'High_Cheekbones','Heavy_Makeup','Attractive','Young',
'No_Beard','Black_Hair','Arched_Eyebrows','Big_Nose']
elif N_ATTRIBUTES == 40:
LABELS = [
'5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive',
'Bags_Under_Eyes', 'Bald', 'Bangs', 'Big_Lips', 'Big_Nose',
'Black_Hair', 'Blond_Hair', 'Blurry', 'Brown_Hair', 'Bushy_Eyebrows',
'Chubby', 'Double_Chin', 'Eyeglasses', 'Goatee', 'Gray_Hair',
'Heavy_Makeup', 'High_Cheekbones', 'Male', 'Mouth_Slightly_Open',
'Mustache', 'Narrow_Eyes', 'No_Beard', 'Oval_Face', 'Pale_Skin',
'Pointy_Nose', 'Receding_Hairline', 'Rosy_Cheeks', 'Sideburns',
'Smiling', 'Straight_Hair', 'Wavy_Hair', 'Wearing_Earrings',
'Wearing_Hat', 'Wearing_Lipstick', 'Wearing_Necklace',
'Wearing_Necktie', 'Young']
else:
LABELS = ["Male", "Smiling"]# just for dataset creation
# Take labels and a list of image locations in memory
df = pd.read_csv(r"content/celeba_gan/list_attr_celeba01.csv")
attr_list = df[LABELS].values.tolist()
def gen_img(attributes):
    """Generate one face image from the selected attribute indices.

    Parameters
    ----------
    attributes : list
        Indices into LABELS of the attributes to switch on, as produced by
        a CheckboxGroup input with type='index'. An empty selection yields
        an all-zero conditioning vector.

    Returns
    -------
    numpy.ndarray
        A single generated image, rescaled from the generator's [-1, 1]
        output range to [0, 1] for display.
    """
    # Binary conditioning vector: 1 at each selected attribute position.
    attr = np.zeros((1, N_ATTRIBUTES))
    for idx in attributes:
        attr[0, int(idx)] = 1
    # One random latent vector -> one sample per request.
    latent = tf.random.normal(shape=(1, LATENT_DIM))
    generated = gan.generator((latent, attr))
    # Shift [-1, 1] -> [0, 1]. (Removed the leftover debug print that dumped
    # the output shape to stdout on every request.)
    return (generated * 0.5 + 0.5).numpy()[0]
# Gradio UI: one checkbox per GAN attribute -> a single generated face image.
# NOTE(review): gr.inputs.CheckboxGroup and the layout= argument belong to the
# legacy Gradio API (removed in Gradio 3+) — confirm the pinned gradio version
# still supports them before upgrading dependencies.
iface = gr.Interface(
    gen_img,
    # type='index' makes the callback receive the selected positions, which
    # gen_img uses directly as indices into the conditioning vector.
    gr.inputs.CheckboxGroup([LABELS[i] for i in range(N_ATTRIBUTES)], type='index'),
    "image",
    layout='unaligned'
)
iface.launch()