blanchon commited on
Commit
d16ff9a
1 Parent(s): 6e41a07

Add Seafoam theme to the project

Browse files
Files changed (2) hide show
  1. app.py +132 -59
  2. theme.py +56 -0
app.py CHANGED
@@ -1,36 +1,45 @@
1
  import torch
2
  import spaces
3
  from transformers import (
4
- AutoProcessor,
5
  BitsAndBytesConfig,
 
 
 
 
 
 
 
 
6
  LlavaForConditionalGeneration,
7
  )
8
  from PIL import Image
9
  import gradio as gr
10
  from threading import Thread
11
- from transformers import TextIteratorStreamer, AutoModelForCausalLM, CodeGenTokenizerFast as Tokenizer
12
  from dotenv import load_dotenv
13
- import os
14
  # Import Supabase functions
15
  from db_client import get_user_history, update_user_history, delete_user_history
 
16
  # Add these imports
17
  from datetime import datetime
18
  import pytz
19
  from gradio.components import LoginButton
20
  from typing import Optional
21
 
 
 
22
 
23
  load_dotenv()
24
 
25
  # Add TESTING variable
26
- TESTING = False
27
 
28
  IS_LOGGED_IN = False
29
  USER_ID = None
30
 
31
  # Hugging Face model id
32
- # model_id = "mistral-community/pixtral-12b"
33
- model_id = "blanchon/pixtral-nutrition-3e-05_r-3_epochs-constant_7a9a"
34
 
35
  # BitsAndBytesConfig int-4 config
36
  bnb_config = BitsAndBytesConfig(
@@ -76,10 +85,11 @@ processor.chat_template = """
76
  </s>
77
  {%- endif %}
78
  {%- endfor %}
79
- """.replace(' ', "")
80
 
81
  processor.tokenizer.pad_token = processor.tokenizer.eos_token
82
 
 
83
  @spaces.GPU
84
  def bot_streaming(chatbot, image_input, max_new_tokens=250):
85
  # Preprocess inputs
@@ -88,7 +98,7 @@ def bot_streaming(chatbot, image_input, max_new_tokens=250):
88
  text_input = chatbot[-1][0]
89
 
90
  # Get current time in Paris timezone
91
- paris_tz = pytz.timezone('Europe/Paris')
92
  current_time = datetime.now(paris_tz).strftime("%I:%M%p")
93
 
94
  if text_input != "":
@@ -104,15 +114,16 @@ def bot_streaming(chatbot, image_input, max_new_tokens=250):
104
  else:
105
  image = Image.fromarray(image_input).convert("RGB")
106
  images.append(image)
107
- messages.append({
108
- "role": "user",
109
- "content": [{"type": "text", "text": text_input}, {"type": "image"}]
110
- })
 
 
111
  else:
112
- messages.append({
113
- "role": "user",
114
- "content": [{"type": "text", "text": text_input}]
115
- })
116
 
117
  # Apply chat template
118
  texts = processor.apply_chat_template(messages)
@@ -131,7 +142,7 @@ def bot_streaming(chatbot, image_input, max_new_tokens=250):
131
 
132
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
133
  thread.start()
134
-
135
  response = ""
136
  for new_text in streamer:
137
  response += new_text
@@ -141,34 +152,59 @@ def bot_streaming(chatbot, image_input, max_new_tokens=250):
141
  thread.join()
142
 
143
  # Debug output
144
- print('*'*60)
145
- print('*'*60)
146
- print('BOT_STREAMING_CONV_START')
147
  for i, (request, answer) in enumerate(chatbot[:-1], 1):
148
- print(f'Q{i}:\n {request}')
149
- print(f'A{i}:\n {answer}')
150
- print('New_Q:\n', text_input)
151
- print('New_A:\n', response)
152
- print('BOT_STREAMING_CONV_END')
153
 
154
-
155
  if IS_LOGGED_IN:
156
- new_history = messages + [{"role": "assistant", "content": [{"type": "text", "text": response}]}]
 
 
157
  update_user_history(USER_ID, new_history)
158
 
 
 
 
159
  # Define the HTML content for the header
160
- html = f"""
161
- <p align="center" style="font-size: 2.5em; line-height: 1;">
 
162
  <span style="display: inline-block; vertical-align: middle;">🍽️</span>
163
  <span style="display: inline-block; vertical-align: middle;">PixDiet</span>
164
  </p>
165
- <center><font size=3><b>PixDiet</b> is your AI nutrition expert. Upload an image of your meal and chat with our AI to get personalized advice on your diet, meal composition, and ways to improve your nutrition.</font></center>
166
- <div style="display: flex; justify-content: center; align-items: center; margin-top: 20px;">
167
- <img src="https://zozs97eh0bkqexza.public.blob.vercel-storage.com/alan-VD7bRf1rKuEBL6EDAjw0eLGVodhoh8.png" alt="Alan AI Logo" style="height: 50px; margin-right: 20px;">
168
- <img src="https://seeklogo.com/images/M/mistral-ai-icon-logo-B3319DCA6B-seeklogo.com.png" alt="Mistral AI Logo" style="height: 50px;">
 
 
 
 
 
169
  </div>
170
  """
171
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  # Define LaTeX delimiters
173
  latex_delimiters_set = [
174
  {"left": "\\(", "right": "\\)", "display": False},
@@ -177,39 +213,71 @@ latex_delimiters_set = [
177
  {"left": "\\begin{alignat}", "right": "\\end{alignat}", "display": True},
178
  {"left": "\\begin{gather}", "right": "\\end{gather}", "display": True},
179
  {"left": "\\begin{CD}", "right": "\\end{CD}", "display": True},
180
- {"left": "\\[", "right": "\\]", "display": True}
181
  ]
182
 
183
  # Create the Gradio interface
184
- with gr.Blocks(title="PixDiet", theme=gr.themes.Ocean()) as demo:
 
 
185
  gr.HTML(html)
186
-
187
-
188
  with gr.Row():
189
  with gr.Column(scale=3):
190
- image_input = gr.Image(label="Upload your meal image", height=350, type="pil")
 
 
 
 
 
191
  gr.Examples(
192
  examples=[
193
- ["./examples/mistral_breakfast.jpeg", ""],
194
- ["./examples/mistral_desert.jpeg", ""],
195
- ["./examples/mistral_snacks.jpeg", ""],
196
- ["./examples/mistral_pasta.jpeg", ""],
197
-
 
 
 
 
 
 
 
 
 
 
 
198
  ],
199
- inputs=[image_input, gr.Textbox(visible=False)]
200
  )
201
  with gr.Column(scale=7):
202
- chatbot = gr.Chatbot(label="Chat with PixDiet", layout="panel", height=600, show_copy_button=True, latex_delimiters=latex_delimiters_set)
203
- text_input = gr.Textbox(label="Ask about your meal", placeholder="(Optional) Enter your message here...", lines=1, container=False)
 
 
 
 
 
 
 
 
 
 
 
204
  with gr.Row():
205
  send_btn = gr.Button("Send", variant="primary", visible=True)
206
- login_button = LoginButton(visible=True)
207
- clear_btn = gr.Button("Delete my historic", variant="stop", visible=True)
 
 
 
 
208
 
209
  def submit_chat(chatbot, text_input):
210
- response = ''
211
  chatbot.append((text_input, response))
212
- return chatbot, ''
213
 
214
  def clear_chat():
215
  if USER_ID:
@@ -218,7 +286,7 @@ with gr.Blocks(title="PixDiet", theme=gr.themes.Ocean()) as demo:
218
 
219
  def user_logged_in(data, user: Optional[gr.OAuthProfile]):
220
  global IS_LOGGED_IN, USER_ID
221
-
222
  print("login")
223
  print(data)
224
 
@@ -233,7 +301,6 @@ with gr.Blocks(title="PixDiet", theme=gr.themes.Ocean()) as demo:
233
  USER_ID = "john doe"
234
 
235
  IS_LOGGED_IN = True
236
-
237
 
238
  def get_profile(profile) -> dict:
239
  print(dir(profile))
@@ -243,16 +310,22 @@ with gr.Blocks(title="PixDiet", theme=gr.themes.Ocean()) as demo:
243
  "name": profile.get("name"),
244
  }
245
 
246
- send_click_event = send_btn.click(submit_chat, [chatbot, text_input], [chatbot, text_input]).then(
247
- bot_streaming, [chatbot, image_input], chatbot
248
- )
249
- submit_event = text_input.submit(submit_chat, [chatbot, text_input], [chatbot, text_input]).then(
250
- bot_streaming, [chatbot, image_input], chatbot
251
- )
252
  clear_btn.click(clear_chat, outputs=[chatbot, image_input, text_input])
253
 
254
  # Add login event handler
255
- login_button.click(user_logged_in, inputs=[login_button], outputs=[login_button, send_btn, clear_btn])
 
 
 
 
 
 
256
 
257
  if __name__ == "__main__":
258
  demo.launch(debug=False, share=False, show_api=False)
 
1
  import torch
2
  import spaces
3
  from transformers import (
 
4
  BitsAndBytesConfig,
5
+ )
6
+ from transformers import (
7
+ TextIteratorStreamer,
8
+ AutoModelForCausalLM,
9
+ CodeGenTokenizerFast as Tokenizer,
10
+ )
11
+ from transformers import (
12
+ AutoProcessor,
13
  LlavaForConditionalGeneration,
14
  )
15
  from PIL import Image
16
  import gradio as gr
17
  from threading import Thread
 
18
  from dotenv import load_dotenv
19
+
20
  # Import Supabase functions
21
  from db_client import get_user_history, update_user_history, delete_user_history
22
+
23
  # Add these imports
24
  from datetime import datetime
25
  import pytz
26
  from gradio.components import LoginButton
27
  from typing import Optional
28
 
29
+ from theme import Seafoam
30
+
31
 
32
  load_dotenv()
33
 
34
  # Add TESTING variable
35
+ TESTING = False
36
 
37
  IS_LOGGED_IN = False
38
  USER_ID = None
39
 
40
  # Hugging Face model id
41
+ model_id = "mistral-community/pixtral-12b"
42
+ # model_id = "blanchon/pixtral-nutrition-3e-05_r-3_epochs-constant_7a9a"
43
 
44
  # BitsAndBytesConfig int-4 config
45
  bnb_config = BitsAndBytesConfig(
 
85
  </s>
86
  {%- endif %}
87
  {%- endfor %}
88
+ """.replace(" ", "")
89
 
90
  processor.tokenizer.pad_token = processor.tokenizer.eos_token
91
 
92
+
93
  @spaces.GPU
94
  def bot_streaming(chatbot, image_input, max_new_tokens=250):
95
  # Preprocess inputs
 
98
  text_input = chatbot[-1][0]
99
 
100
  # Get current time in Paris timezone
101
+ paris_tz = pytz.timezone("Europe/Paris")
102
  current_time = datetime.now(paris_tz).strftime("%I:%M%p")
103
 
104
  if text_input != "":
 
114
  else:
115
  image = Image.fromarray(image_input).convert("RGB")
116
  images.append(image)
117
+ messages.append(
118
+ {
119
+ "role": "user",
120
+ "content": [{"type": "text", "text": text_input}, {"type": "image"}],
121
+ }
122
+ )
123
  else:
124
+ messages.append(
125
+ {"role": "user", "content": [{"type": "text", "text": text_input}]}
126
+ )
 
127
 
128
  # Apply chat template
129
  texts = processor.apply_chat_template(messages)
 
142
 
143
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
144
  thread.start()
145
+
146
  response = ""
147
  for new_text in streamer:
148
  response += new_text
 
152
  thread.join()
153
 
154
  # Debug output
155
+ print("*" * 60)
156
+ print("*" * 60)
157
+ print("BOT_STREAMING_CONV_START")
158
  for i, (request, answer) in enumerate(chatbot[:-1], 1):
159
+ print(f"Q{i}:\n {request}")
160
+ print(f"A{i}:\n {answer}")
161
+ print("New_Q:\n", text_input)
162
+ print("New_A:\n", response)
163
+ print("BOT_STREAMING_CONV_END")
164
 
 
165
  if IS_LOGGED_IN:
166
+ new_history = messages + [
167
+ {"role": "assistant", "content": [{"type": "text", "text": response}]}
168
+ ]
169
  update_user_history(USER_ID, new_history)
170
 
171
+
172
+ seafoam = Seafoam()
173
+
174
  # Define the HTML content for the header
175
+ html = """
176
+ <!-- Foreground content -->
177
+ <p align="center" style="font-size: 2.5em; line-height: 1; ">
178
  <span style="display: inline-block; vertical-align: middle;">🍽️</span>
179
  <span style="display: inline-block; vertical-align: middle;">PixDiet</span>
180
  </p>
181
+ <center>
182
+ <font size=3><b>PixDiet</b> is your AI nutrition expert. Upload an image of your meal and chat with our AI to get personalized advice on your diet, meal composition, and ways to improve your nutrition.</font>
183
+ </center>
184
+ <!-- Background image positioned behind everything -->
185
+ <div style="display: flex; flex-direction: column; justify-content: center; align-items: center; margin-top: 20px; width: 100%;">
186
+ <div style="display: flex; justify-content: center; width: 100%;">
187
+ <img src="https://dropshare.blanchon.xyz/public/dropshare/alan.png" alt="Alan AI Logo" style="height: 50px; margin-right: 20px;">
188
+ <img src="https://dropshare.blanchon.xyz/public/dropshare/mistral-ai-icon-logo-B3319DCA6B-seeklogo.com.png" alt="Mistral AI Logo" style="height: 50px;">
189
+ </div>
190
  </div>
191
  """
192
 
193
+ footer_html = """
194
+ <!-- Footer content -->
195
+ <div style="display: flex; flex-direction: column; justify-content: center; align-items: center; margin-top: 20px; width: 100%;">
196
+ <div style="display: flex; justify-content: center; width: 100%;">
197
+ <img src="https://dropshare.blanchon.xyz/public/dropshare//VariantVariant6-Photoroom.png" alt="Background Image"
198
+ style="height: 100px; width: 100%; object-fit: scale-down;">
199
+ </div>
200
+ <div>
201
+ Made with ❤️ during the Mistral AI x Alan Hackathon.
202
+ </div>
203
+
204
+ </div>
205
+ """
206
+
207
+
208
  # Define LaTeX delimiters
209
  latex_delimiters_set = [
210
  {"left": "\\(", "right": "\\)", "display": False},
 
213
  {"left": "\\begin{alignat}", "right": "\\end{alignat}", "display": True},
214
  {"left": "\\begin{gather}", "right": "\\end{gather}", "display": True},
215
  {"left": "\\begin{CD}", "right": "\\end{CD}", "display": True},
216
+ {"left": "\\[", "right": "\\]", "display": True},
217
  ]
218
 
219
  # Create the Gradio interface
220
+ with gr.Blocks(
221
+ title="PixDiet", theme=seafoam, css="footer{display:none !important}"
222
+ ) as demo:
223
  gr.HTML(html)
224
+
 
225
  with gr.Row():
226
  with gr.Column(scale=3):
227
+ about_you = gr.Textbox(
228
+ label="About you", placeholder="Add information about you here..."
229
+ )
230
+ image_input = gr.Image(
231
+ label="Upload your meal image", height=350, type="pil"
232
+ )
233
  gr.Examples(
234
  examples=[
235
+ [
236
+ "./examples/mistral_breakfast.jpeg",
237
+ "John, 45 years old, 80kg, lactose intolerant. Training for his first triathlon.",
238
+ ],
239
+ [
240
+ "./examples/mistral_desert.jpeg",
241
+ "Emma, 26 years old, 55kg, iron deficiency. Training for her first Ironman competition.",
242
+ ],
243
+ [
244
+ "./examples/mistral_snacks.jpeg",
245
+ "Paul, 34 years old, 62kg, no known pathologies. Focused on improving strength for weightlifting competitions.",
246
+ ],
247
+ [
248
+ "./examples/mistral_pasta.jpeg",
249
+ "Carla, 52 years old, 58kg, no known pathologies. Currently training for her first marathon.",
250
+ ],
251
  ],
252
+ inputs=[image_input, about_you],
253
  )
254
  with gr.Column(scale=7):
255
+ chatbot = gr.Chatbot(
256
+ label="Chat with PixDiet",
257
+ layout="panel",
258
+ height=700,
259
+ show_copy_button=True,
260
+ latex_delimiters=latex_delimiters_set,
261
+ )
262
+ text_input = gr.Textbox(
263
+ label="Ask about your meal",
264
+ placeholder="(Optional) Enter your message here...",
265
+ lines=1,
266
+ container=False,
267
+ )
268
  with gr.Row():
269
  send_btn = gr.Button("Send", variant="primary", visible=True)
270
+ login_button = LoginButton(visible=True, value="Login")
271
+ clear_btn = gr.Button(
272
+ "Delete my historic",
273
+ variant="stop",
274
+ visible=True,
275
+ )
276
 
277
  def submit_chat(chatbot, text_input):
278
+ response = ""
279
  chatbot.append((text_input, response))
280
+ return chatbot, ""
281
 
282
  def clear_chat():
283
  if USER_ID:
 
286
 
287
  def user_logged_in(data, user: Optional[gr.OAuthProfile]):
288
  global IS_LOGGED_IN, USER_ID
289
+
290
  print("login")
291
  print(data)
292
 
 
301
  USER_ID = "john doe"
302
 
303
  IS_LOGGED_IN = True
 
304
 
305
  def get_profile(profile) -> dict:
306
  print(dir(profile))
 
310
  "name": profile.get("name"),
311
  }
312
 
313
+ send_click_event = send_btn.click(
314
+ submit_chat, [chatbot, text_input], [chatbot, text_input]
315
+ ).then(bot_streaming, [chatbot, image_input], chatbot)
316
+ submit_event = text_input.submit(
317
+ submit_chat, [chatbot, text_input], [chatbot, text_input]
318
+ ).then(bot_streaming, [chatbot, image_input], chatbot)
319
  clear_btn.click(clear_chat, outputs=[chatbot, image_input, text_input])
320
 
321
  # Add login event handler
322
+ login_button.click(
323
+ user_logged_in,
324
+ inputs=[login_button],
325
+ outputs=[login_button, send_btn, clear_btn],
326
+ )
327
+
328
+ gr.HTML(footer_html)
329
 
330
  if __name__ == "__main__":
331
  demo.launch(debug=False, share=False, show_api=False)
theme.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from typing import Iterable
3
+ from gradio.themes.base import Base
4
+ from gradio.themes.utils import colors, fonts, sizes
5
+
6
+
7
+ class Seafoam(Base):
8
+ def __init__(
9
+ self,
10
+ *,
11
+ primary_hue: colors.Color | str = colors.emerald,
12
+ secondary_hue: colors.Color | str = colors.blue,
13
+ neutral_hue: colors.Color | str = colors.blue,
14
+ spacing_size: sizes.Size | str = sizes.spacing_md,
15
+ radius_size: sizes.Size | str = sizes.radius_md,
16
+ text_size: sizes.Size | str = sizes.text_md,
17
+ font: fonts.Font | str | Iterable[fonts.Font | str] = (
18
+ fonts.GoogleFont("Quicksand"),
19
+ "ui-sans-serif",
20
+ "sans-serif",
21
+ ),
22
+ font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
23
+ fonts.GoogleFont("IBM Plex Mono"),
24
+ "ui-monospace",
25
+ "monospace",
26
+ ),
27
+ ):
28
+ super().__init__(
29
+ primary_hue=primary_hue,
30
+ secondary_hue=secondary_hue,
31
+ neutral_hue=neutral_hue,
32
+ spacing_size=spacing_size,
33
+ radius_size=radius_size,
34
+ text_size=text_size,
35
+ font=font,
36
+ font_mono=font_mono,
37
+ )
38
+ super().set(
39
+ # Simpler, subtle background
40
+ body_background_fill="*primary_50",
41
+ body_background_fill_dark="*primary_900",
42
+ button_primary_background_fill="linear-gradient(90deg, *primary_300, *secondary_400)",
43
+ button_primary_background_fill_hover="linear-gradient(90deg, *primary_200, *secondary_300)",
44
+ button_primary_text_color="white",
45
+ button_primary_background_fill_dark="linear-gradient(90deg, *primary_600, *secondary_800)",
46
+ slider_color="*secondary_300",
47
+ slider_color_dark="*secondary_600",
48
+ block_title_text_weight="600",
49
+ block_border_width="3px",
50
+ block_shadow="*shadow_drop_lg",
51
+ button_primary_shadow="*shadow_drop_lg",
52
+ button_large_padding="32px",
53
+ )
54
+
55
+
56
+ seafoam = Seafoam()