jmourad committed on
Commit
ece0fb3
•
1 Parent(s): c8c7ed1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -39
app.py CHANGED
@@ -1,15 +1,15 @@
1
- import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  import torch
4
- from clip_interrogator import Config, Interrogator
5
- import random
6
  import re
 
7
  import requests
8
  import shutil
 
 
9
  from PIL import Image
 
10
 
11
-
12
- # Definir la función para generar prompt desde imagen
13
  config = Config()
14
  config.device = 'cuda' if torch.cuda.is_available() else 'cpu'
15
  config.blip_offload = False if torch.cuda.is_available() else True
@@ -19,6 +19,7 @@ config.blip_num_beams = 64
19
  config.clip_model_name = "ViT-H-14/laion2b_s32b_b79k"
20
  ci = Interrogator(config)
21
 
 
22
  def get_prompt_from_image(image, mode):
23
  image = image.convert('RGB')
24
  if mode == 'best':
@@ -31,29 +32,18 @@ def get_prompt_from_image(image, mode):
31
  prompt = ci.interrogate_negative(image)
32
  return prompt
33
 
34
-
35
- # Definir la función para generar prompt desde texto
36
- model = AutoModelForSeq2SeqLM.from_pretrained('Helsinki-NLP/opus-mt-zh-en').eval()
37
- tokenizer = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-zh-en')
38
-
39
- def translate(text):
40
- with torch.no_grad():
41
- encoded = tokenizer([text], return_tensors='pt')
42
- sequences = model.generate(**encoded)
43
- return tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
44
-
45
  text_pipe = pipeline('text-generation', model='succinctly/text2image-prompt-generator')
46
 
47
  def text_generate(input):
48
  seed = random.randint(100, 1000000)
49
  set_seed(seed)
50
- text_in_english = translate(input)
51
  for count in range(6):
52
- sequences = text_pipe(text_in_english, max_length=random.randint(60, 90), num_return_sequences=8)
53
  list = []
54
  for sequence in sequences:
55
  line = sequence['generated_text'].strip()
56
- if line != text_in_english and len(line) > (len(text_in_english) + 4) and line.endswith((':', '-', 'β€”')) is False:
57
  list.append(line)
58
 
59
  result = "\n".join(list)
@@ -64,25 +54,24 @@ def text_generate(input):
64
  if count == 5:
65
  return result
66
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
- # Definir la función que permite al usuario cargar una imagen desde una URL
69
- def load_image_from_url(url):
70
- response = requests.get(url, stream=True)
71
- if response.status_code == 200:
72
- with open('./image.jpg', 'wb') as f:
73
- response.raw.decode_content = True
74
- shutil.copyfileobj(response.raw, f)
75
- return Image.open('./image.jpg')
76
- else:
77
- raise ValueError("No se pudo cargar la imagen")
78
 
 
 
79
 
80
- # Crear la interfaz de usuario de Gradio
81
- with gr.Interface(
82
- [get_prompt_from_image, text_generate],
83
- [
84
- gr.inputs.Image(type='pil', label='Imagen'),
85
- gr.inputs.Radio(['best', 'fast', 'classic', 'negative'], value='best', label='Modo'),
86
- gr.inputs.Textbox(lines=6, label='Texto de entrada'),
87
- ],
88
- [
 
1
+ # Importar bibliotecas
 
2
  import torch
 
 
3
  import re
4
+ import random
5
  import requests
6
  import shutil
7
+ from clip_interrogator import Config, Interrogator
8
+ from transformers import pipeline, set_seed, AutoTokenizer, AutoModelForSeq2SeqLM
9
  from PIL import Image
10
+ import gradio as gr
11
 
12
+ # Configurar CLIP
 
13
  config = Config()
14
  config.device = 'cuda' if torch.cuda.is_available() else 'cpu'
15
  config.blip_offload = False if torch.cuda.is_available() else True
 
19
  config.clip_model_name = "ViT-H-14/laion2b_s32b_b79k"
20
  ci = Interrogator(config)
21
 
22
+ # Función para generar prompt desde imagen
23
  def get_prompt_from_image(image, mode):
24
  image = image.convert('RGB')
25
  if mode == 'best':
 
32
  prompt = ci.interrogate_negative(image)
33
  return prompt
34
 
35
+ # Función para generar texto
 
 
 
 
 
 
 
 
 
 
36
  text_pipe = pipeline('text-generation', model='succinctly/text2image-prompt-generator')
37
 
38
  def text_generate(input):
39
  seed = random.randint(100, 1000000)
40
  set_seed(seed)
 
41
  for count in range(6):
42
+ sequences = text_pipe(input, max_length=random.randint(60, 90), num_return_sequences=8)
43
  list = []
44
  for sequence in sequences:
45
  line = sequence['generated_text'].strip()
46
+ if line != input and len(line) > (len(input) + 4) and line.endswith((':', '-', 'β€”')) is False:
47
  list.append(line)
48
 
49
  result = "\n".join(list)
 
54
  if count == 5:
55
  return result
56
 
57
+ # Crear interfaz gradio
58
+ with gr.Blocks() as block:
59
+ with gr.Column():
60
+ gr.HTML('<h1>MidJourney / SD2 Helper Tool</h1>')
61
+ with gr.Tab('Generate from Image'):
62
+ with gr.Row():
63
+ input_image = gr.Image(type='pil')
64
+ with gr.Column():
65
+ input_mode = gr.Radio(['best', 'fast', 'classic', 'negative'], value='best', label='Mode')
66
+ img_btn = gr.Button('Discover Image Prompt')
67
+ output_image = gr.Textbox(lines=6, label='Generated Prompt')
68
 
69
+ with gr.Tab('Generate from Text'):
70
+ input_text = gr.Textbox(lines=6, label='Your Idea', placeholder='Enter your content here...')
71
+ output_text = gr.Textbox(lines=6, label='Generated Prompt')
72
+ text_btn = gr.Button('Generate Prompt')
 
 
 
 
 
 
73
 
74
+ img_btn.click(fn=get_prompt_from_image, inputs=[input_image, input_mode], outputs=output_image)
75
+ text_btn.click(fn=text_generate, inputs=input_text, outputs=output_text)
76
 
77
+ block.queue(max_size=64).launch(show_api=False, enable_queue=True, debug=True, share=True, server_name='0.0.0.0')