Exched committed on
Commit 71d4c29 • 1 Parent(s): 8fd40bb

Update app.py

Files changed (1)
  1. app.py +109 -107
app.py CHANGED
@@ -1,59 +1,51 @@
+import os
+import torch
+from PIL import Image
 import gradio as gr
+from transformers import BlipProcessor, BlipForConditionalGeneration
+from langchain_huggingface import HuggingFaceEndpoint
+from langchain_core.prompts import PromptTemplate
+from langchain_core.output_parsers import StrOutputParser
 from random import randint
-from all_models import models
-from externalmod import gr_Interface_load
 import asyncio
 from threading import RLock
-lock = RLock()
+from externalmod import gr_Interface_load

-def load_fn(models):
-    global models_load
-    models_load = {}
+# Define model IDs
+llm_model_id = "mistralai/Mistral-7B-Instruct-v0.3"
+blip_model_id = "Salesforce/blip-image-captioning-large"
+model_name = "John6666/wai-ani-hentai-pony-v3-sdxl"

-    for model in models:
-        if model not in models_load.keys():
-            try:
-                m = gr_Interface_load(f'models/{model}')
-            except Exception as error:
-                print(error)
-                m = gr.Interface(lambda: None, ['text'], ['image'])
-            models_load.update({model: m})
+# Initialize BLIP processor and model
+processor = BlipProcessor.from_pretrained(blip_model_id)
+model = BlipForConditionalGeneration.from_pretrained(blip_model_id)

+# Initialize the model loading function
+lock = RLock()
+model_load = None

-load_fn(models)
+def load_fn(model):
+    global model_load
+    try:
+        model_load = gr_Interface_load(f'models/{model}')
+    except Exception as error:
+        print(error)
+        model_load = gr.Interface(lambda: None, ['text'], ['image'])

+load_fn(model_name)

-num_models = 6
-default_models = models[:num_models]
-timeout = 300
-
-def extend_choices(choices):
-    return choices[:num_models] + (num_models - len(choices[:num_models])) * ['NA']
-
-
-def update_imgbox(choices):
-    choices_plus = extend_choices(choices[:num_models])
-    return [gr.Image(None, label = m, visible = (m != 'NA')) for m in choices_plus]
-
-
-def update_imgbox_gallery(choices):
-    choices_plus = extend_choices(choices[:num_models])
-    return [gr.Gallery(None, label = m, visible = (m != 'NA')) for m in choices_plus]
-
-
-async def infer(model_str, prompt, timeout):
-    from PIL import Image
+async def infer(prompt, timeout):
     noise = ""
     rand = randint(1, 500)
     for i in range(rand):
         noise += " "
-    task = asyncio.create_task(asyncio.to_thread(models_load[model_str], f'{prompt} {noise}'))
+    task = asyncio.create_task(asyncio.to_thread(model_load, f'{prompt} {noise}'))
     await asyncio.sleep(0)
     try:
         result = await asyncio.wait_for(task, timeout=timeout)
     except (Exception, asyncio.TimeoutError) as e:
         print(e)
-        print(f"Task timed out: {model_str}")
+        print(f"Task timed out: {model_name}")
         if not task.done(): task.cancel()
         result = None
     if task.done() and result is not None:
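
Both the old and the new `infer` use the same timeout pattern: the blocking Space call is pushed onto a worker thread with `asyncio.to_thread`, and `asyncio.wait_for` caps how long the caller waits. A minimal standalone sketch of that pattern (`slow_model` is a stand-in for the loaded model, not part of this commit):

    import asyncio
    import time

    def slow_model(prompt):
        # Stand-in for the blocking model call (model_load(prompt) in this commit).
        time.sleep(2)
        return f"image for: {prompt}"

    async def infer_with_timeout(prompt, timeout=300):
        # Run the blocking call in a worker thread so the event loop stays free,
        # then cap the wait.
        task = asyncio.create_task(asyncio.to_thread(slow_model, prompt))
        try:
            return await asyncio.wait_for(task, timeout=timeout)
        except asyncio.TimeoutError:
            if not task.done():
                task.cancel()
            return None

    print(asyncio.run(infer_with_timeout("test prompt", timeout=5)))

Note that `wait_for` already cancels the task when the timeout fires, and cancelling a task backed by `to_thread` does not stop the thread that is already running the call; the timeout only unblocks the caller.
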
@@ -62,97 +54,107 @@ async def infer(model_str, prompt, timeout):
         return image
     return None

-def gen_fn(model_str, prompt):
-    if model_str == 'NA':
-        return None
+def gen_fn(prompt):
     try:
         loop = asyncio.new_event_loop()
-        result = loop.run_until_complete(infer(model_str, prompt, timeout))
+        result = loop.run_until_complete(infer(prompt, timeout=300))
     except (Exception, asyncio.CancelledError) as e:
         print(e)
-        print(f"Task aborted: {model_str}")
+        print(f"Task aborted: {model_name}")
         result = None
     finally:
         loop.close()
     return result

-
-def add_gallery(image, model_str, gallery):
+def add_gallery(image, gallery):
     if gallery is None: gallery = []
     with lock:
-        if image is not None: gallery.insert(0, (image, model_str))
+        if image is not None: gallery.insert(0, (image, model_name))
     return gallery

-
-def gen_fn_gallery(model_str, prompt, gallery):
+def gen_fn_gallery(prompt, gallery):
     if gallery is None: gallery = []
-    if model_str == 'NA':
-        yield gallery
     try:
         loop = asyncio.new_event_loop()
-        result = loop.run_until_complete(infer(model_str, prompt, timeout))
+        result = loop.run_until_complete(infer(prompt, timeout=300))
         with lock:
             if result: gallery.insert(0, result)
     except (Exception, asyncio.CancelledError) as e:
         print(e)
-        print(f"Task aborted: {model_str}")
+        print(f"Task aborted: {model_name}")
     finally:
         loop.close()
     yield gallery

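
Both versions of `gen_fn` and `gen_fn_gallery` drive the coroutine from a plain synchronous Gradio callback by creating a private event loop. A minimal sketch of that bridge, reusing `infer_with_timeout` from the sketch above:

    import asyncio

    def gen_fn_sync(prompt):
        # A fresh loop per call, so this works inside a synchronous callback.
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(infer_with_timeout(prompt, timeout=300))
        except (Exception, asyncio.CancelledError) as e:
            print(e)
            return None
        finally:
            loop.close()

`asyncio.run` would be the more idiomatic spelling today; the explicit `new_event_loop`/`close` pair in the commit does the same bookkeeping by hand.
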
-
-CSS="""
-#container { max-width: 1200px; margin: 0 auto; !important; }
-.output { width=112px; height=112px; !important; }
-.gallery { width=100%; min_height=768px; !important; }
-.guide { text-align: center; !important; }
-"""
-
-with gr.Blocks(theme='Nymbo/Nymbo_Theme', fill_width=True, css=CSS) as demo:
-    gr.HTML(
-        """
-        <div>
-        <p> <center>For simultaneous generations without hidden queue check out <a href="https://huggingface.co/spaces/Yntec/ToyWorld">Toy World</a>! For more options like single model x6 check out <a href="https://huggingface.co/spaces/John6666/Diffusion80XX4sg">Diffusion80XX4sg</a> by John6666!</center>
-        </p></div>
-        """
-    )
-    with gr.Tab('Huggingface Diffusion'):
-        with gr.Column(scale=2):
-            txt_input = gr.Textbox(label='Your prompt:', lines=4)
-            with gr.Row():
-                gen_button = gr.Button(f'Generate up to {int(num_models)} images from 1 to {int(num_models)*3} minutes total', scale=2)
-                stop_button = gr.Button('Stop', variant='secondary', interactive=False, scale=1)
-            gen_button.click(lambda: gr.update(interactive = True), None, stop_button)
-            gr.Markdown("Scroll down to see more images and select models.", elem_classes="guide")
-
-        with gr.Column(scale=1):
-            with gr.Group():
-                with gr.Row():
-                    output = [gr.Image(label=m, show_download_button=True, elem_classes="output", interactive=False, min_width=80, show_share_button=False, visible=True) for m in default_models]
-                    #output = [gr.Image(label=m, show_download_button=True, elem_classes="output", interactive=False, show_share_button=True) for m in default_models]
-                    #output = [gr.Gallery(label=m, show_download_button=True, elem_classes="output", interactive=False, show_share_button=True, container=True, format="png", object_fit="cover") for m in default_models]
-                current_models = [gr.Textbox(m, visible=False) for m in default_models]
-
-        with gr.Column(scale=2):
-            gallery = gr.Gallery(label="Output", show_download_button=True, elem_classes="gallery",
-                                 interactive=False, show_share_button=True, container=True, format="png",
-                                 preview=True, object_fit="cover", columns=2, rows=2)
-
-        for m, o in zip(current_models, output):
-            #gen_event = gen_button.click(gen_fn, [m, txt_input], o)
-            #gen_event = gen_button.click(gen_fn_gallery, [m, txt_input, o], o)
-            gen_event = gr.on(triggers=[gen_button.click, txt_input.submit], fn=gen_fn, inputs=[m, txt_input], outputs=[o])
-            o.change(add_gallery, [o, m, gallery], [gallery])
-        stop_button.click(lambda: gr.update(interactive = False), None, stop_button, cancels = [gen_event])
-
-        with gr.Column(scale=4):
-            with gr.Accordion('Model selection'):
-                model_choice = gr.CheckboxGroup(models, label = f'Choose up to {int(num_models)} different models from the {len(models)} available!', value=default_models, interactive=True)
-                model_choice.change(update_imgbox, model_choice, output)
-                #model_choice.change(update_imgbox_gallery, model_choice, output)
-                model_choice.change(extend_choices, model_choice, current_models)
-
-    gr.Markdown("Based on the [TestGen](https://huggingface.co/spaces/derwahnsinn/TestGen) Space by derwahnsinn, the [SpacIO](https://huggingface.co/spaces/RdnUser77/SpacIO_v1) Space by RdnUser77 and Omnibus's Maximum Multiplier!")
-
-demo.queue()
-demo.launch()
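
The removed Tab above bound one generate handler to several triggers at once with `gr.on`. A minimal sketch of that wiring, with the Space's models swapped for an illustrative `echo` handler:

    import gradio as gr

    def echo(text):
        return text

    with gr.Blocks() as demo:
        txt = gr.Textbox(label="Prompt")
        btn = gr.Button("Generate")
        out = gr.Textbox(label="Result")
        # One handler fires on both the button click and the textbox submit.
        gr.on(triggers=[btn.click, txt.submit], fn=echo, inputs=[txt], outputs=[out])

    demo.launch()
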
+def generate_caption(image, min_len=30, max_len=100):
+    try:
+        inputs = processor(image, return_tensors="pt")
+        out = model.generate(**inputs, min_length=min_len, max_length=max_len)
+        caption = processor.decode(out[0], skip_special_tokens=True)
+        return caption
+    except Exception as e:
+        return 'Unable to generate caption.'
+
+def get_llm_hf_inference(model_id=llm_model_id, max_new_tokens=128, temperature=0.1):
+    try:
+        llm = HuggingFaceEndpoint(
+            repo_id=model_id,
+            max_new_tokens=max_new_tokens,
+            temperature=temperature,
+            token=os.getenv("HF_TOKEN")
+        )
+    except Exception as e:
+        print(f"Error loading model: {e}")
+        llm = None
+    return llm
+
+def get_response(system_message, chat_history, user_text, max_new_tokens=256):
+    hf = get_llm_hf_inference(max_new_tokens=max_new_tokens, temperature=0.1)
+    if hf is None:
+        return "Error with model inference.", chat_history
+
+    prompt = PromptTemplate.from_template(
+        "[INST] {system_message}\nCurrent Conversation:\n{chat_history}\n\nUser: {user_text}.\n [/INST]\nAI:"
+    )
+    chat = prompt | hf.bind(skip_prompt=True) | StrOutputParser(output_key='content')
+    response = chat.invoke(input=dict(system_message=system_message, user_text=user_text, chat_history=chat_history))
+    response = response.split("AI:")[-1]
+
+    chat_history.append({'role': 'user', 'content': user_text})
+    chat_history.append({'role': 'assistant', 'content': response})
+    return response, chat_history
+
+def chat_function(user_text, uploaded_image, system_message, chat_history):
+    # If an image is uploaded, generate a caption for it
+    if uploaded_image:
+        caption = generate_caption(uploaded_image)
+        chat_history.append({'role': 'user', 'content': f'![uploaded image](data:image/png;base64,{uploaded_image})'})
+        chat_history.append({'role': 'assistant', 'content': caption})
+        # Return the updated chat history
+        return chat_history, chat_history
+
+    # If no image is uploaded, generate a response from the chat model
+    response, updated_history = get_response(system_message, chat_history, user_text)
+    return response, updated_history
+
+def gradio_interface():
+    with gr.Blocks() as demo:
+        gr.Markdown("# Personal HuggingFace ChatBot")
+
+        with gr.Row():
+            with gr.Column():
+                txt_input = gr.Textbox(label='Enter your text here', lines=4)
+                img_input = gr.Image(label='Upload an image', type='pil')
+                system_message = gr.Textbox(label='System Message', value="You are a friendly AI conversing with a human user.")
+                chat_history = gr.State(value=[{'role': 'assistant', 'content': 'Hello, there! How can I help you today?'}])
+
+        submit_btn = gr.Button('Submit')
+        response_output = gr.Markdown()
+        gallery_output = gr.Gallery(label="Generated Images", show_download_button=True, elem_classes="gallery", interactive=False, show_share_button=True, container=True)
+
+        submit_btn.click(chat_function, inputs=[txt_input, img_input, system_message, chat_history], outputs=[response_output, chat_history])
+        img_input.change(lambda img: add_gallery(gen_fn("Generate image of a fantasy scene"), gallery_output), inputs=[img_input], outputs=[gallery_output])
+
+    demo.launch()
+
+gradio_interface()
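
The new `get_response` composes its pipeline with LangChain's pipe operator: prompt template into endpoint into output parser. A minimal sketch of that chain, assuming a valid `HF_TOKEN` in the environment and the same model ID as the commit (the question text is illustrative):

    import os
    from langchain_huggingface import HuggingFaceEndpoint
    from langchain_core.prompts import PromptTemplate
    from langchain_core.output_parsers import StrOutputParser

    llm = HuggingFaceEndpoint(
        repo_id="mistralai/Mistral-7B-Instruct-v0.3",
        max_new_tokens=128,
        temperature=0.1,
        token=os.getenv("HF_TOKEN"),  # passed as in the commit; requires a valid token
    )
    prompt = PromptTemplate.from_template("[INST] {question} [/INST]")
    chain = prompt | llm | StrOutputParser()
    print(chain.invoke({"question": "In one sentence, what is image captioning?"}))

Each pipe step feeds the next runnable: the template fills `{question}`, the endpoint generates a completion, and the parser returns it as a plain string.
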