File size: 5,963 Bytes
9cc6120 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 |
import os
import uuid
from base64 import b64encode
from datetime import datetime
from mimetypes import guess_type
from pathlib import Path
import gradio as gr
from huggingface_hub import InferenceClient
from pandas import DataFrame
from feedback import save_feedback
# Shared inference client. When BASE_URL is set (e.g. a custom endpoint),
# the client is pointed at that URL and no model id is passed; otherwise the
# model id comes from MODEL, defaulting to Llama 3.2 11B Vision Instruct.
client = InferenceClient(
    base_url=os.getenv("BASE_URL"),
    model=(
        None
        if os.getenv("BASE_URL")
        else os.getenv("MODEL", "meta-llama/Llama-3.2-11B-Vision-Instruct")
    ),
    token=os.getenv("HF_TOKEN"),
)
def add_user_message(history, message):
    """Append a multimodal submission to *history* and lock the input box.

    Each uploaded file becomes its own user entry (``{"path": ...}``),
    followed by the text entry. Blank text (``""`` or whitespace only) is not
    appended, so a files-only submission does not add an empty user message.

    Args:
        history: Chat history as a list of message dicts; mutated in place.
        message: MultimodalTextbox value with ``"files"`` and ``"text"`` keys.

    Returns:
        ``(history, textbox_update)`` where the update clears the textbox and
        makes it non-interactive.
    """
    for file_path in message.get("files") or []:
        history.append({"role": "user", "content": {"path": file_path}})
    text = message.get("text")
    # Skip empty submissions of text so no blank bubble / empty text part
    # is produced when only files were uploaded.
    if text and text.strip():
        history.append({"role": "user", "content": text})
    return history, gr.MultimodalTextbox(value=None, interactive=False)
def _format_history_as_messages(history: list):
messages = []
current_role = None
current_message_content = []
for entry in history:
content = entry["content"]
if entry["role"] != current_role:
if current_role is not None:
messages.append(
{"role": current_role, "content": current_message_content}
)
current_role = entry["role"]
current_message_content = []
if isinstance(content, tuple): # Handle file paths
for path in content:
data_uri = _convert_path_to_data_uri(path)
current_message_content.append(
{"type": "image_url", "image_url": {"url": data_uri}}
)
elif isinstance(content, str): # Handle text
current_message_content.append({"type": "text", "text": content})
if current_role is not None:
messages.append({"role": current_role, "content": current_message_content})
return messages
def _convert_path_to_data_uri(path) -> str:
mime_type, _ = guess_type(path)
with open(path, "rb") as image_file:
data = image_file.read()
data_uri = f"data:{mime_type};base64," + b64encode(data).decode("utf-8")
return data_uri
def _is_file_safe(path) -> bool:
try:
return Path(path).is_file()
except Exception:
return False
def _process_content(content) -> str | list[str]:
if isinstance(content, str) and _is_file_safe(content):
return _convert_path_to_data_uri(content)
elif isinstance(content, list):
return _convert_path_to_data_uri(content[0])
return content
def respond_system_message(history: list) -> list:
    """Query the model with *history* and append its reply.

    Despite the name, the appended message has role ``"assistant"``.

    Args:
        history: Chat history as a list of message dicts; mutated in place.

    Returns:
        The same *history* list with one assistant ``gr.ChatMessage`` added.
    """
    messages = _format_history_as_messages(history)
    response = client.chat.completions.create(
        messages=messages,
        max_tokens=2000,
        stream=False,
    )
    # NOTE(review): the completion's content is typed as optional by the
    # client. Fall back to "" so the chatbot never receives a None content.
    content = response.choices[0].message.content or ""
    history.append(gr.ChatMessage(role="assistant", content=content))
    return history
def wrangle_like_data(x: gr.LikeData, history) -> tuple[list, DataFrame]:
    """Record a like/dislike on *history* and flatten it into a DataFrame.

    The message at the liked index gets ``metadata = {"title": "liked"}``
    or ``{"title": "disliked"}``. Every message then gets a ``"rating"`` key
    derived from its metadata title: ``1`` for liked, ``-1`` for disliked,
    ``None`` otherwise.

    Args:
        x: Like event data. ``x.index`` may be an int or a sequence whose
            first element is the message index; ``x.liked`` is truthy for a
            like.
        history: Chat history as a list of message dicts; mutated in place.

    Returns:
        ``(history, DataFrame)``: the updated history, and one row per
        message containing every key except ``"metadata"``.
    """
    # Messages-type chatbots report a plain int index; tuple-style ones a
    # sequence. Accept both.
    liked_index = x.index if isinstance(x.index, int) else x.index[0]
    ratings = {"liked": 1, "disliked": -1}
    rows = []
    for idx, message in enumerate(history):
        if idx == liked_index:
            message["metadata"] = {"title": "liked" if x.liked else "disliked"}
        # Unrated messages may have missing or None metadata.
        title = (message.get("metadata") or {}).get("title")
        message["rating"] = ratings.get(title)
        rows.append({k: v for k, v in message.items() if k != "metadata"})
    return history, DataFrame(data=rows)
def submit_conversation(dataframe, session_id):
    """Submit the conversation to the dataset repo via ``save_feedback``.

    File entries in the ``"content"`` column are inlined with
    ``_process_content`` before saving. The saved record also carries a
    timestamp, *session_id* and a freshly generated ``conversation_id``.

    Args:
        dataframe: One row per chat message; ``None`` or an empty frame means
            there is nothing to submit.
        session_id: Identifier of the session submitting the data.

    Returns:
        ``(cleared_dataframe, [])`` in every case, so the UI can reset the
        feedback table and the chat history.
    """
    if dataframe is None or dataframe.empty:
        gr.Info("No messages to submit because the conversation was empty")
        return (gr.Dataframe(value=None, interactive=False), [])
    # assign() returns a copy, leaving the caller's DataFrame untouched.
    records = dataframe.assign(content=dataframe["content"].apply(_process_content))
    conversation_data = {
        "conversation": records.to_dict(orient="records"),
        # NOTE(review): naive local time; consider a timezone-aware
        # timestamp if submissions are compared across hosts.
        "timestamp": datetime.now().isoformat(),
        "session_id": session_id,
        "conversation_id": str(uuid.uuid4()),
    }
    save_feedback(input_object=conversation_data)
    gr.Info(f"Submitted {len(records)} messages to the dataset")
    return (gr.Dataframe(value=None, interactive=False), [])
with gr.Blocks() as demo:
    ##############################
    # Chatbot
    ##############################
    # Hidden per-session id; replaced with a fresh uuid by demo.load below.
    session_id = gr.Textbox(
        interactive=False,
        value=str(uuid.uuid4()),
        visible=False,
    )
    chatbot = gr.Chatbot(
        elem_id="chatbot",
        bubble_full_width=False,
        type="messages",
    )
    chat_input = gr.MultimodalTextbox(
        interactive=True,
        file_count="multiple",
        placeholder="Enter message or upload file...",
        show_label=False,
        submit_btn=True,
    )
    # Event chain: append the user's message and lock the input, then ask
    # the model for a reply, then unlock the input again.
    chat_msg = chat_input.submit(
        fn=add_user_message, inputs=[chatbot, chat_input], outputs=[chatbot, chat_input]
    )
    bot_msg = chat_msg.then(
        respond_system_message, chatbot, chatbot, api_name="bot_response"
    )
    # Re-enable with the same component type that add_user_message used to
    # disable the input.
    bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])

    ##############################
    # Deal with feedback
    ##############################
    # Holds the flattened, rated conversation built by wrangle_like_data.
    dataframe = gr.DataFrame()
    chatbot.like(
        fn=wrangle_like_data,
        inputs=[chatbot],
        outputs=[chatbot, dataframe],
        like_user_message=False,
    )
    # Submitting saves the table and clears both the table and the chat.
    gr.Button(
        value="Submit conversation",
    ).click(
        fn=submit_conversation,
        inputs=[dataframe, session_id],
        outputs=[dataframe, chatbot],
    )
    demo.load(
        lambda: str(uuid.uuid4()),
        inputs=[],
        outputs=[session_id],
    )

demo.launch()
|