PixDiet / app.py
blanchon's picture
Add again bnb
1a0ce51
raw
history blame
11.2 kB
import spaces
from transformers import (
TextIteratorStreamer,
)
from transformers import (
AutoProcessor,
BitsAndBytesConfig,
LlavaForConditionalGeneration,
)
from PIL import Image
import gradio as gr
from threading import Thread
from dotenv import load_dotenv
# Import Supabase functions
from db_client import get_user_history, update_user_history, delete_user_history
# Add these imports
from datetime import datetime
import pytz
from gradio.components import LoginButton
from typing import Optional
from transformers import AutoModelForCausalLM, CodeGenTokenizerFast as Tokenizer
import torch
from theme import Seafoam
load_dotenv()

# Runtime flags.
# TESTING selects the alternate (non-quantized, remote-code) model branch at load time.
TESTING = False
# Login state; both values are overwritten by the login handler of the UI.
IS_LOGGED_IN = False
USER_ID = None

# Hugging Face model id of the fine-tuned checkpoint served by the app.
# (Base model previously used here: "mistral-community/pixtral-12b".)
model_id = "blanchon/PixDiet-pixtral-nutrition-v2"

# bitsandbytes settings: 4-bit NF4 weights with nested (double) quantization,
# computing in bfloat16.
_bnb_options = {
    "load_in_4bit": True,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_use_double_quant": True,
    "bnb_4bit_compute_dtype": torch.bfloat16,
}
bnb_config = BitsAndBytesConfig(**_bnb_options)
# Load the model together with the object used to build its inputs.
if not TESTING:
    # Default path: LLaVA-style checkpoint loaded with the 4-bit config above,
    # with device placement chosen automatically.
    model = LlavaForConditionalGeneration.from_pretrained(
        model_id,
        quantization_config=bnb_config,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    processor = AutoProcessor.from_pretrained(model_id)
else:
    # Testing path: swap the checkpoint and use a plain tokenizer in place of
    # the multimodal processor.
    model_id = "vikhyatk/moondream1"
    model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
    processor = Tokenizer.from_pretrained(model_id)
# Chat template (Jinja) used by ``processor.apply_chat_template``.
# User turns render as ``<s>[INST] ... [/INST]`` with a ``[IMG]`` marker per
# image item; assistant turns render their text followed by ``</s>``.
# The ``{%-`` markers strip the whitespace/newline preceding each tag.
# The template is written without indentation on purpose: the previous
# ``.replace(" ", "")`` also deleted the spaces *inside* the Jinja tags
# (``{%- for message in messages %}`` -> ``{%-formessageinmessages%}``),
# which Jinja rejects as an unknown tag.
processor.chat_template = """
{%- for message in messages %}
{%- if message.role == "user" %}
<s>[INST]
{%- for item in message.content %}
{%- if item.type == "text" %}
{{ item.text }}
{%- elif item.type == "image" %}
\n[IMG]
{%- endif %}
{%- endfor %}
[/INST]
{%- elif message.role == "assistant" %}
{%- for item in message.content %}
{%- if item.type == "text" %}
{{ item.text }}
{%- endif %}
{%- endfor %}
</s>
{%- endif %}
{%- endfor %}
"""
# Reuse EOS as the padding token. In TESTING mode ``processor`` is a bare
# tokenizer with no ``.tokenizer`` attribute, so fall back to the object itself.
_tokenizer = getattr(processor, "tokenizer", processor)
_tokenizer.pad_token = _tokenizer.eos_token
def _strip_image_items(message):
    """Return ``message`` without its ``{"type": "image"}`` content items.

    Only the current upload is handed to the processor, so image placeholders
    replayed from stored history would produce ``[IMG]`` markers with no
    matching image. Non-dict messages, and messages whose ``content`` is not a
    list, are returned unchanged.
    """
    if not isinstance(message, dict) or not isinstance(message.get("content"), list):
        return message
    kept_items = [
        item
        for item in message["content"]
        if not (isinstance(item, dict) and item.get("type") == "image")
    ]
    return {**message, "content": kept_items}


@spaces.GPU
def bot_streaming(chatbot, image_input, max_new_tokens=250):
    """Generate the answer for the latest chat turn and stream it into ``chatbot``.

    Args:
        chatbot: Chat history as ``[[user_text, bot_text], ...]``. The user
            text of the last entry is read, and its bot slot is overwritten
            with the growing response.
        image_input: The uploaded meal image (a PIL image, or an array accepted
            by ``Image.fromarray``), or None when no image was uploaded.
        max_new_tokens: Maximum number of tokens to generate.

    Yields:
        ``chatbot`` after each streamed chunk of generated text.

    Side effects:
        When ``IS_LOGGED_IN`` is set, the stored history for ``USER_ID`` is
        replaced with the previous history plus this question/answer pair.
    """
    # NOTE(review): USER_ID / IS_LOGGED_IN are module globals shared by every
    # visitor of the app; per-session state (e.g. gr.State) would isolate users.
    # get_user_history is also called when USER_ID is None — confirm it handles that.
    # Copy so appending the new turn never mutates what the DB client returned.
    messages = list(get_user_history(USER_ID) or [])
    # Prompt-only view of the history without stale image placeholders.
    prompt_messages = [_strip_image_items(message) for message in messages]
    images = []
    user_text = chatbot[-1][0]

    # Get current time in Paris timezone
    paris_tz = pytz.timezone("Europe/Paris")
    current_time = datetime.now(paris_tz).strftime("%I:%M%p")
    if isinstance(user_text, str) and user_text.strip():
        # Forward the user's own question (previously it was discarded and the
        # default prompt was always sent).
        text_input = f"Current time: {current_time}. You are a nutrition expert. {user_text}"
    else:
        text_input = f"Current time: {current_time}. You are a nutrition expert. Identify the food/ingredients in this image. Is this a healthy meal? Can you think of how to improve it?"

    # Build the current user turn; an image item is added only with an upload.
    if image_input is not None:
        if isinstance(image_input, Image.Image):
            image = image_input.convert("RGB")
        else:
            image = Image.fromarray(image_input).convert("RGB")
        images.append(image)
        current_message = {
            "role": "user",
            "content": [{"type": "text", "text": text_input}, {"type": "image"}],
        }
    else:
        current_message = {
            "role": "user",
            "content": [{"type": "text", "text": text_input}],
        }
    messages.append(current_message)
    prompt_messages.append(current_message)

    # Apply chat template, then tokenize (and encode the image, if any).
    texts = processor.apply_chat_template(prompt_messages)
    if not images:
        inputs = processor(text=texts, return_tensors="pt").to("cuda")
    else:
        inputs = processor(text=texts, images=images, return_tensors="pt").to("cuda")

    # Run generation in a background thread and consume the streamer here so
    # partial output can be yielded to the UI as it arrives.
    streamer = TextIteratorStreamer(
        processor.tokenizer, skip_special_tokens=True, skip_prompt=True
    )
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    response = ""
    for new_text in streamer:
        response += new_text
        chatbot[-1][1] = response
        yield chatbot
    thread.join()

    # Debug output
    print("*" * 60)
    print("*" * 60)
    print("BOT_STREAMING_CONV_START")
    for i, (request, answer) in enumerate(chatbot[:-1], 1):
        print(f"Q{i}:\n {request}")
        print(f"A{i}:\n {answer}")
    print("New_Q:\n", text_input)
    print("New_A:\n", response)
    print("BOT_STREAMING_CONV_END")

    # Persist the conversation (stored format unchanged: full history + new turn).
    if IS_LOGGED_IN:
        new_history = messages + [
            {"role": "assistant", "content": [{"type": "text", "text": response}]}
        ]
        update_user_history(USER_ID, new_history)
# Theme instance from the local ``theme`` module, passed to gr.Blocks.
seafoam: Seafoam = Seafoam()
# Header markup rendered at the top of the page: title, tagline, partner logos.
html = """
<!-- Foreground content -->
<p align="center" style="font-size: 2.5em; line-height: 1; ">
<span style="display: inline-block; vertical-align: middle;">🍽️</span>
<span style="display: inline-block; vertical-align: middle;">PixDiet</span>
</p>
<center>
<font size=3><b>PixDiet</b> is your AI nutrition expert. Upload an image of your meal and chat with our AI to get personalized advice on your diet, meal composition, and ways to improve your nutrition.</font>
</center>
<!-- Background image positioned behind everything -->
<div style="display: flex; flex-direction: column; justify-content: center; align-items: center; margin-top: 20px; width: 100%;">
<div style="display: flex; justify-content: center; width: 100%;">
<img src="https://dropshare.blanchon.xyz/public/dropshare/alan.png" alt="Alan AI Logo" style="height: 50px; margin-right: 20px;">
<img src="https://dropshare.blanchon.xyz/public/dropshare/mistral-ai-icon-logo-B3319DCA6B-seeklogo.com.png" alt="Mistral AI Logo" style="height: 50px;">
</div>
</div>
"""

# Footer markup rendered at the bottom of the page: banner image and credit line.
footer_html = """
<!-- Footer content -->
<div style="display: flex; flex-direction: column; justify-content: center; align-items: center; margin-top: 20px; width: 100%;">
<div style="display: flex; justify-content: center; width: 100%;">
<img src="https://dropshare.blanchon.xyz/public/dropshare//VariantVariant6-Photoroom.png" alt="Background Image"
style="height: 100px; width: 100%; object-fit: scale-down;">
</div>
<div>
Made with ❤️ during the Mistral AI x Alan Hackathon.
</div>
</div>
"""

# LaTeX delimiters recognised by the chatbot, as (left, right, display) triples;
# ``display`` is True for block-level math and False for inline math.
_LATEX_DELIMITER_SPECS = (
    ("\\(", "\\)", False),
    ("\\begin{equation}", "\\end{equation}", True),
    ("\\begin{align}", "\\end{align}", True),
    ("\\begin{alignat}", "\\end{alignat}", True),
    ("\\begin{gather}", "\\end{gather}", True),
    ("\\begin{CD}", "\\end{CD}", True),
    ("\\[", "\\]", True),
)
latex_delimiters_set = [
    {"left": left, "right": right, "display": display}
    for left, right, display in _LATEX_DELIMITER_SPECS
]
# Create the Gradio interface.
# NOTE(review): the "About you" textbox is filled by the examples but is not an
# input of bot_streaming, so its content never reaches the model.
# NOTE(review): login state lives in module globals (IS_LOGGED_IN, USER_ID), so
# it is shared by every visitor of the app.
with gr.Blocks(
    title="PixDiet", theme=seafoam, css="footer{display:none !important}"
) as demo:
    gr.HTML(html)
    with gr.Row():
        # Left column: user profile, meal image and ready-made examples.
        with gr.Column(scale=3):
            about_you = gr.Textbox(
                label="About you",
                placeholder="Add information about you here...",
                lines=3,
                interactive=True,
            )
            image_input = gr.Image(
                label="Upload your meal image", height=350, type="pil"
            )
            gr.Examples(
                examples=[
                    [
                        "./examples/mistral_breakfast.jpeg",
                        "John, 45 years old, 80kg, lactose intolerant. Training for his first triathlon.",
                    ],
                    [
                        "./examples/mistral_desert.jpeg",
                        "Emma, 26 years old, 55kg, iron deficiency. Training for her first Ironman competition.",
                    ],
                    [
                        "./examples/mistral_snacks.jpeg",
                        "Paul, 34 years old, 62kg, no known pathologies. Focused on improving strength for weightlifting competitions.",
                    ],
                    [
                        "./examples/mistral_pasta.jpeg",
                        "Carla, 52 years old, 58kg, no known pathologies. Currently training for her first marathon.",
                    ],
                ],
                inputs=[image_input, about_you],
            )
        # Right column: conversation, message box and action buttons.
        with gr.Column(scale=7):
            chatbot = gr.Chatbot(
                label="Chat with PixDiet",
                layout="panel",
                height=700,
                show_copy_button=True,
                latex_delimiters=latex_delimiters_set,
                type=None,
            )
            text_input = gr.Textbox(
                label="Ask about your meal",
                placeholder="(Optional) Enter your message here...",
                lines=1,
                container=False,
                interactive=True,
            )
            with gr.Row():
                send_btn = gr.Button("Send", variant="primary", visible=True)
                login_button = LoginButton(visible=True, value="Login")
                clear_btn = gr.Button(
                    "Delete my history",
                    variant="stop",
                    visible=True,
                )

    def submit_chat(chatbot, text_input):
        """Append the user's message as a new turn with an empty reply.

        Returns the updated chat history and an empty string to clear the textbox.
        """
        response = ""
        chatbot.append((text_input, response))
        return chatbot, ""

    def clear_chat():
        """Delete the stored history of the current user (if any) and reset the UI.

        Returns values for (chatbot, image_input, text_input).
        """
        if USER_ID:
            delete_user_history(USER_ID)
        return [], None, ""

    def user_logged_in(data, user: Optional[gr.OAuthProfile]):
        """Record the logged-in user in the module-level login state.

        Args:
            data: Value of the login button (only printed for debugging).
            user: Hugging Face OAuth profile, or None when no profile is available.

        Returns:
            One no-op update per wired output (login_button, send_btn, clear_btn).
        """
        global IS_LOGGED_IN, USER_ID
        print("login")
        print(data)
        profile = get_profile(user)
        print(profile)
        username = profile["username"]
        # Fallback identity when no username is available (previously unreachable:
        # a None profile crashed inside get_profile).
        USER_ID = username if username is not None else "john doe"
        IS_LOGGED_IN = True
        # Print after assignment so the newly set id is logged, not the old one.
        print(f"User logged in: {USER_ID}")
        return gr.update(), gr.update(), gr.update()

    def get_profile(profile) -> dict:
        """Extract username, profile and name from an OAuth profile mapping.

        Returns a dict with those three keys; all values are None when
        ``profile`` is None.
        """
        if profile is None:
            return {"username": None, "profile": None, "name": None}
        print(dir(profile))
        return {
            "username": profile.get("username"),
            "profile": profile.get("profile"),
            "name": profile.get("name"),
        }

    # Sending (button or Enter) first appends the user turn, then streams the reply.
    send_click_event = send_btn.click(
        submit_chat, [chatbot, text_input], [chatbot, text_input]
    ).then(bot_streaming, [chatbot, image_input], chatbot)
    submit_event = text_input.submit(
        submit_chat, [chatbot, text_input], [chatbot, text_input]
    ).then(bot_streaming, [chatbot, image_input], chatbot)
    clear_btn.click(clear_chat, outputs=[chatbot, image_input, text_input])
    # Add login event handler
    login_button.click(
        user_logged_in,
        inputs=[login_button],
        outputs=[login_button, send_btn, clear_btn],
    )
    gr.HTML(footer_html)
if __name__ == "__main__":
    # Launch locally without a public share link, with the API page hidden
    # and debug mode off.
    demo.launch(share=False, show_api=False, debug=False)