Gabriel Ichcanziho commited on
Commit
58ab09b
1 Parent(s): 1c3a9db
Files changed (4) hide show
  1. app.py +56 -0
  2. assets/logo.png +0 -0
  3. requirements.txt +2 -0
  4. utils.py +2 -2
app.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from utils import carga_modelo, genera
3
+
4
+ # Página principal
5
+ st.title("Butterfly GAN (GAN de mariposas)")
6
+ st.write(
7
+ "Modelo Light-GAN entrenado con 1000 imágenes de mariposas tomadas de la colección del Museo Smithsonian."
8
+ )
9
+
10
+ # Barra lateral
11
+ st.sidebar.subheader("¡Estas mariposas no existen! 🤯.")
12
+ st.sidebar.image("assets/logo.png", width=200)
13
+ st.sidebar.caption(
14
+ f"[Modelo](https://huggingface.co/ceyda/butterfly_cropped_uniq1K_512) y [Dataset]("
15
+ f"https://huggingface.co/datasets/huggan/smithsonian_butterflies_subset) usados."
16
+ )
17
+ st.sidebar.caption(f"*Disclaimers:*")
18
+ st.sidebar.caption(
19
+ "* Este demo creada a partir del curso de Platzi: Curso de Experimentación en Machine Learning con Hugging Face."
20
+ )
21
+
22
+ # Cargamos modelo
23
+ repo_id = "ceyda/butterfly_cropped_uniq1K_512"
24
+ version_modelo = "57d36a15546909557d9f967f47713236c8288838"
25
+ modelo_gan = carga_modelo(repo_id, version_modelo)
26
+
27
+ # Generamos 4 mariposas
28
+ n_mariposas = 4
29
+
30
+ # Función que genera mariposas y lo guarda como un estado de la sesión
31
+ def corre():
32
+ with st.spinner("Generando, espera un poco..."):
33
+ ims = genera(modelo_gan, n_mariposas)
34
+ st.session_state["ims"] = ims
35
+
36
+
37
+ # Si no hay una imagen generada entonces generala
38
+ if "ims" not in st.session_state:
39
+ st.session_state["ims"] = None
40
+ corre()
41
+
42
+ # ims contiene las imágenes generadas
43
+ ims = st.session_state["ims"]
44
+
45
+ # Si la usuaria da click en el botón entonces corremos la función genera()
46
+ corre_boton = st.button(
47
+ "Genera mariposas, porfa.",
48
+ on_click=corre,
49
+ help="Estamos en pleno vuelo, puede tardar.",
50
+ )
51
+
52
+ if ims is not None:
53
+ cols = st.columns(n_mariposas)
54
+ for j, im in enumerate(ims):
55
+ i = j % n_mariposas
56
+ cols[i].image(im, use_column_width=True)
assets/logo.png ADDED
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ git+https://github.com/huggingface/community-events.git@3fea10c5d5a50c69f509e34cd580fe9139905d04#egg=huggan
2
+ transformers
utils.py CHANGED
@@ -3,14 +3,14 @@ import torch
3
  from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN
4
 
5
 
6
- ## Cargamos el modelo desde el Hub de Hugging Face
7
  def carga_modelo(model_name="ceyda/butterfly_cropped_uniq1K_512", model_version=None):
8
  gan = LightweightGAN.from_pretrained(model_name, version=model_version)
9
  gan.eval()
10
  return gan
11
 
12
 
13
- ## Usamos el modelo GAN para generar imágenes
14
  def genera(gan, batch_size=1):
15
  with torch.no_grad():
16
  ims = gan.G(torch.randn(batch_size, gan.latent_dim)).clamp_(0.0, 1.0) * 255
 
3
  from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN
4
 
5
 
6
+ # Cargamos el modelo desde el Hub de Hugging Face
7
  def carga_modelo(model_name="ceyda/butterfly_cropped_uniq1K_512", model_version=None):
8
  gan = LightweightGAN.from_pretrained(model_name, version=model_version)
9
  gan.eval()
10
  return gan
11
 
12
 
13
+ # Usamos el modelo GAN para generar imágenes
14
  def genera(gan, batch_size=1):
15
  with torch.no_grad():
16
  ims = gan.G(torch.randn(batch_size, gan.latent_dim)).clamp_(0.0, 1.0) * 255