"""PixDiet: a Gradio chat app that streams nutrition advice about a meal image.

The script loads a vision-language model at import time, builds a Gradio
``Blocks`` UI, and persists each user's conversation through the helpers
imported from ``db_client``.
"""

import os
from datetime import datetime
from threading import Thread

import torch
import spaces
import gradio as gr
import pytz
from dotenv import load_dotenv
from PIL import Image
from transformers import (
    AutoProcessor,
    BitsAndBytesConfig,
    LlavaForConditionalGeneration,
)
from transformers import TextIteratorStreamer, AutoModelForCausalLM, CodeGenTokenizerFast as Tokenizer

# Supabase-backed conversation storage
from db_client import get_user_history, update_user_history, delete_user_history

load_dotenv()

# When True, a different (moondream) checkpoint is loaded instead of the
# quantized Pixtral model.
TESTING = False
IS_LOGGED_IN = True
USER_ID = "jeremie.feron@gmail.com"

# Hugging Face model id
model_id = "blanchon/pixtral-nutrition-2"

# 4-bit NF4 quantization with double quantization, computing in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

if TESTING:
    model_id = "vikhyatk/moondream1"
    model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
    processor = Tokenizer.from_pretrained(model_id)
else:
    model = LlavaForConditionalGeneration.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        quantization_config=bnb_config,
    )
    processor = AutoProcessor.from_pretrained(model_id)

# Chat template used by ``processor.apply_chat_template``.
# Kept un-indented on purpose: stripping spaces from the string afterwards
# would also delete the spaces inside the Jinja tags and make it unparsable.
processor.chat_template = """
{%- for message in messages %}
{%- if message.role == "user" %}
[INST]
{%- for item in message.content %}
{%- if item.type == "text" %}
{{ item.text }}
{%- elif item.type == "image" %}
\n[IMG]
{%- endif %}
{%- endfor %}
[/INST]
{%- elif message.role == "assistant" %}
{%- for item in message.content %}
{%- if item.type == "text" %}
{{ item.text }}
{%- endif %}
{%- endfor %}
{%- endif %}
{%- endfor %}
"""

# In the TESTING branch ``processor`` is already a tokenizer and has no
# ``.tokenizer`` attribute, so fall back to the object itself.
tokenizer = getattr(processor, "tokenizer", processor)
tokenizer.pad_token = tokenizer.eos_token

# Question asked on the user's behalf when the message box is left empty.
DEFAULT_QUESTION = (
    "Identify the food/ingredients in this image. Is this a healthy meal? "
    "Can you think of how to improve it?"
)


def _build_prompt(user_text, current_time):
    """Return the text prompt for the current turn.

    Args:
        user_text: What the user typed; may be empty or ``None``.
        current_time: Already formatted local time to embed in the prompt.

    Returns:
        The prompt string. The user's own question is used when one was
        typed; otherwise ``DEFAULT_QUESTION`` is asked.
    """
    question = (user_text or "").strip() or DEFAULT_QUESTION
    return f"Current time: {current_time}. You are a nutrition expert. {question}"


def _to_rgb(image_input):
    """Return ``image_input`` (PIL image or array) as an RGB PIL image."""
    if isinstance(image_input, Image.Image):
        return image_input.convert("RGB")
    return Image.fromarray(image_input).convert("RGB")


@spaces.GPU
def bot_streaming(chatbot, image_input, max_new_tokens=250):
    """Stream the model's answer into the last chatbot turn.

    Args:
        chatbot: Chat history as a list of ``[user, assistant]`` pairs; the
            last pair holds the message being answered and is updated in
            place while tokens arrive.
        image_input: Meal image as a PIL image or array, or ``None``.
        max_new_tokens: Upper bound on the number of generated tokens.

    Yields:
        ``chatbot`` after each streamed chunk of text.
    """
    # Copy so the appended turn never mutates whatever the storage layer
    # returned; ``or []`` covers a user for whom nothing is returned.
    messages = list(get_user_history(USER_ID) or [])

    current_time = datetime.now(pytz.timezone("Europe/Paris")).strftime("%I:%M%p")
    text_input = _build_prompt(chatbot[-1][0], current_time)

    # Current user turn: the text prompt plus an image placeholder if needed.
    content = [{"type": "text", "text": text_input}]
    images = []
    if image_input is not None:
        images.append(_to_rgb(image_input))
        content.append({"type": "image"})
    messages.append({"role": "user", "content": content})

    # NOTE(review): stored history may contain image placeholders from earlier
    # turns while only the current image is passed below - confirm the
    # processor accepts that mismatch.
    texts = processor.apply_chat_template(messages)

    processor_kwargs = {"text": texts, "return_tensors": "pt"}
    if images:
        processor_kwargs["images"] = images
    # Follow the model's actual placement instead of assuming "cuda".
    inputs = processor(**processor_kwargs).to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, skip_special_tokens=True, skip_prompt=True
    )
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)

    # Generation runs in a worker thread so this generator can relay tokens
    # to the UI as soon as the streamer produces them.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    response = ""
    for new_text in streamer:
        response += new_text
        chatbot[-1][1] = response
        yield chatbot

    thread.join()

    # Debug output
    print('*' * 60)
    print('*' * 60)
    print('BOT_STREAMING_CONV_START')
    for i, (request, answer) in enumerate(chatbot[:-1], 1):
        print(f'Q{i}:\n {request}')
        print(f'A{i}:\n {answer}')
    print('New_Q:\n', text_input)
    print('New_A:\n', response)
    print('BOT_STREAMING_CONV_END')

    if IS_LOGGED_IN:
        new_history = messages + [
            {"role": "assistant", "content": [{"type": "text", "text": response}]}
        ]
        update_user_history(USER_ID, new_history)


# Define the HTML content for the header
html = f"""

🍽️ PixDiet

PixDiet is your AI nutrition expert. Upload an image of your meal and chat with our AI to get personalized advice on your diet, meal composition, and ways to improve your nutrition.
Alan AI Logo Mistral AI Logo
"""

# LaTeX delimiters recognised by the chatbot component
latex_delimiters_set = [
    {"left": "\\(", "right": "\\)", "display": False},
    {"left": "\\begin{equation}", "right": "\\end{equation}", "display": True},
    {"left": "\\begin{align}", "right": "\\end{align}", "display": True},
    {"left": "\\begin{alignat}", "right": "\\end{alignat}", "display": True},
    {"left": "\\begin{gather}", "right": "\\end{gather}", "display": True},
    {"left": "\\begin{CD}", "right": "\\end{CD}", "display": True},
    {"left": "\\[", "right": "\\]", "display": True},
]

# Create the Gradio interface
with gr.Blocks(title="PixDiet", theme=gr.themes.Ocean()) as demo:
    gr.HTML(html)
    with gr.Row():
        with gr.Column(scale=3):
            image_input = gr.Image(label="Upload your meal image", height=350, type="pil")
            gr.Examples(
                examples=[
                    ["./examples/mistral_breakfast.jpeg", ""],
                    ["./examples/mistral_desert.jpeg", ""],
                    ["./examples/mistral_snacks.jpeg", ""],
                    ["./examples/mistral_pasta.jpeg", ""],
                ],
                inputs=[image_input, gr.Textbox(visible=False)],
            )
        with gr.Column(scale=7):
            chatbot = gr.Chatbot(
                label="Chat with PixDiet",
                layout="panel",
                height=600,
                show_copy_button=True,
                latex_delimiters=latex_delimiters_set,
            )
            text_input = gr.Textbox(
                label="Ask about your meal",
                placeholder="(Optional) Enter your message here...",
                lines=1,
                container=False,
            )
            with gr.Row():
                send_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Delete my historic", variant="huggingface")

    def submit_chat(chatbot, text_input):
        """Append the user's message as a new, unanswered turn and clear the box."""
        response = ''
        chatbot.append((text_input, response))
        return chatbot, ''

    def clear_chat():
        """Delete the stored history and reset the chat, image and text widgets."""
        delete_user_history(USER_ID)
        return [], None, ""

    # Both the button and the Enter key first record the message, then stream
    # the answer into the chatbot.
    send_click_event = send_btn.click(
        submit_chat, [chatbot, text_input], [chatbot, text_input]
    ).then(bot_streaming, [chatbot, image_input], chatbot)
    submit_event = text_input.submit(
        submit_chat, [chatbot, text_input], [chatbot, text_input]
    ).then(bot_streaming, [chatbot, image_input], chatbot)
    clear_btn.click(clear_chat, outputs=[chatbot, image_input, text_input])

if __name__ == "__main__":
    demo.launch(debug=False, share=False, show_api=False)