Raymundo Gonzalez commited on
Commit
9325b9d
1 Parent(s): 4b50bb8

Agregando los archivos al repositorio

Browse files
Files changed (4) hide show
  1. app.py +43 -0
  2. assets/logo.png +0 -0
  3. requirements.txt +1 -0
  4. utils.py +36 -0
app.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from utils import carga_modelo, genera
3
+
4
+ ## Página principal
5
+ st.title("Generador de mariposas")
6
+ st.write(
7
+ "Este es un generador de mariposas que utiliza una red neuronal generativa (GAN)"
8
+ )
9
+
10
+ ## Barra lateral
11
+ st.sidebar.subheader("¡Esta mariposa no existe!, ¿Puedes creerlo?")
12
+ st.sidebar.image("assets/logo.png", width=200)
13
+ st.sidebar.caption("Demo creado en vivo.")
14
+
15
+ ## Cargamos el modelo
16
+ repo_id = "ceyda/butterfly_cropped_uniq1K_512"
17
+ modelo_gan = carga_modelo(repo_id)
18
+
19
+ # Numero de mariposas a generar
20
+ n_mariposas = 4
21
+
22
+
23
+ def corre():
24
+ with st.spinner("Generando mariposa..."):
25
+ ims = genera(modelo_gan, n_mariposas)
26
+ st.session_state["ims"] = ims
27
+
28
+
29
+ if "ims" not in st.session_state:
30
+ st.session_state["ims"] = None
31
+ corre()
32
+
33
+ ims = st.session_state["ims"]
34
+
35
+ corre_boton = st.button(
36
+ "Generar mariposas", on_click=corre, help="Genera mariposas aleatorias."
37
+ )
38
+
39
+ if ims is not None:
40
+ cols = st.columns(n_mariposas)
41
+ for j, im in enumerate(ims):
42
+ i = j % n_mariposas
43
+ cols[i].image(im, use_column_width=True)
assets/logo.png ADDED
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ git+https://github.com/huggingface/community-events.git@3fea10c5d5a50c69f509e34cd580fe9139905d04#egg=huggan
utils.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN
4
+
5
+
6
+ def carga_modelo(model_name="ceyda/butterfly_cropped_uniq1K_512", model_version=None):
7
+ """
8
+ Loads a pre-trained LightweightGAN model from Hugging Face Model Hub.
9
+
10
+ Args:
11
+ model_name (str): The name of the pre-trained model to load. Defaults to "ceyda/butterfly_cropped_uniq1K_512".
12
+ model_version (str): The version of the pre-trained model to load. Defaults to None.
13
+
14
+ Returns:
15
+ LightweightGAN: The loaded pre-trained model.
16
+ """
17
+ gan = LightweightGAN.from_pretrained(model_name, version=model_version)
18
+ gan.eval()
19
+ return gan
20
+
21
+
22
+ def genera(gan, batch_size=1):
23
+ """
24
+ Generates images using the given GAN model.
25
+
26
+ Args:
27
+ gan (nn.Module): The GAN model to use for generating images.
28
+ batch_size (int, optional): The number of images to generate in each batch. Defaults to 1.
29
+
30
+ Returns:
31
+ numpy.ndarray: A numpy array of generated images.
32
+ """
33
+ with torch.no_grad():
34
+ ims = gan.G(torch.randn(batch_size, gan.latent_dim)).clamp_(0.0, 1.0) * 255
35
+ ims = ims.permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
36
+ return ims