YiJina / app.py
Tonic's picture
improve pass metadata
930288d
raw
history blame
9.27 kB
# app.py
import spaces
from torch.nn import DataParallel
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
from huggingface_hub import InferenceClient
from openai import OpenAI
from langchain_community.document_loaders import UnstructuredFileLoader
from langchain_chroma import Chroma
from chromadb import Documents, EmbeddingFunction, Embeddings
from chromadb.config import Settings
from chromadb import HttpClient
import os
import re
import uuid
import gradio as gr
import torch
import torch.nn.functional as F
from dotenv import load_dotenv
from utils import load_env_variables, parse_and_route
from globalvars import API_BASE, intention_prompt, tasks, system_message, model_name , metadata_prompt
load_dotenv()

# CUDA allocator / launch settings. CUDA_LAUNCH_BLOCKING=1 makes kernel
# launches synchronous (a debugging aid) at a throughput cost.
os.environ.update(
    {
        "PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:30",
        "CUDA_LAUNCH_BLOCKING": "1",
        "CUDA_CACHE_DISABLE": "1",
    }
)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### Utils
# Hugging Face token (model download) and Yi API key (chat completions).
hf_token, yi_token = load_env_variables()
def clear_cuda_cache():
    """Release unoccupied memory cached by PyTorch's CUDA caching allocator."""
    return torch.cuda.empty_cache()
# OpenAI-compatible client for the Yi API at API_BASE.
# NOTE(review): not referenced elsewhere in the visible code; `intention_client`
# below is constructed identically.
client = OpenAI(base_url=API_BASE, api_key=yi_token)
class EmbeddingGenerator:
    """Embed text with a Hugging Face encoder, enriched by two LLM calls.

    For each input, ``compute_embeddings`` asks ``intention_client`` (an
    OpenAI-compatible client, model ``yi-large``) to (1) route the text to one
    of ``tasks`` and (2) extract metadata. It then encodes the text, prefixed
    with the task instruction, and returns an L2-normalised mean-pooled
    embedding together with that metadata.
    """

    # Matches `"key": "value"` pairs in the metadata reply. Whitespace around
    # the colon is optional, so compact and pretty-printed JSON both match.
    _METADATA_PATTERN = re.compile(r'"(\w+)"\s*:\s*"([^"]+)"')

    def __init__(self, model_name: str, token: str, intention_client):
        """Load tokenizer and model, placing the model on CUDA when available.

        Args:
            model_name: Hugging Face model id passed to ``from_pretrained``.
            token: Hugging Face access token.
            intention_client: OpenAI-compatible client used for the intention
                and metadata chat completions.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # trust_remote_code: the model repo may ship its own modeling code.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=token, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(model_name, token=token, trust_remote_code=True).to(self.device)
        self.intention_client = intention_client

    def clear_cuda_cache(self):
        """Release unoccupied memory cached by PyTorch's CUDA allocator."""
        torch.cuda.empty_cache()

    def _chat(self, system_prompt: str, user_text: str) -> str:
        """Run one yi-large chat completion and return its text ("" if empty)."""
        completion = self.intention_client.chat.completions.create(
            model="yi-large",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_text},
            ],
        )
        # openai>=1.0 returns pydantic objects: the message exposes `.content`
        # as an attribute and is not subscriptable like the old dict responses.
        return completion.choices[0].message.content or ""

    @spaces.GPU
    def compute_embeddings(self, input_text: str):
        """Embed ``input_text`` and attach LLM-extracted metadata.

        Args:
            input_text: Text to embed. Both queries and documents go through
                this method elsewhere in the app.

        Returns:
            ``[{"embedding": list[float], "metadata": dict[str, str]}]``, one
            entry for the single input. Returns an ``"Error: ..."`` string
            instead when the routed task is not a key of ``tasks``.
        """
        # 1) Route the text to a task and look up that task's description.
        parsed_task = parse_and_route(self._chat(intention_prompt, input_text))
        # First key, or None when routing produced nothing; None falls into the
        # "not found" path below instead of raising IndexError.
        selected_task = next(iter(parsed_task), None)
        try:
            task_description = tasks[selected_task]
        except KeyError:
            print(f"Selected task not found: {selected_task}")
            return f"Error: Task '{selected_task}' not found. Please select a valid task."
        # Apply the instruction prefix so the encoder sees the routed task.
        # NOTE(review): documents are embedded through this same path and so
        # also receive the prefix; confirm this suits the chosen model.
        query_prefix = f"Instruct: {task_description}\nQuery: "
        queries = [query_prefix + input_text]

        # 2) Extract metadata for the same text.
        metadata = self.extract_metadata(self._chat(metadata_prompt, input_text))

        # 3) Encode, then mean-pool over real (non-padding) tokens only.
        with torch.no_grad():
            inputs = self.tokenizer(
                queries, return_tensors="pt", padding=True, truncation=True, max_length=4096
            ).to(self.device)
            hidden = self.model(**inputs).last_hidden_state
            attention_mask = inputs.get("attention_mask")
            if attention_mask is None:
                query_embeddings = hidden.mean(dim=1)
            else:
                mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
                query_embeddings = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        # Unit-length vectors, so dot product equals cosine similarity.
        query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
        embeddings_list = query_embeddings.detach().cpu().numpy().tolist()

        embeddings_with_metadata = [{"embedding": emb, "metadata": metadata} for emb in embeddings_list]
        self.clear_cuda_cache()
        return embeddings_with_metadata

    def extract_metadata(self, metadata_output: str):
        """Parse ``"key": "value"`` string pairs out of the metadata reply.

        Args:
            metadata_output: Raw text returned by the metadata prompt.

        Returns:
            Dict of key -> value; a later duplicate key overwrites an earlier
            one. Empty dict when nothing matches.
        """
        return dict(self._METADATA_PATTERN.findall(metadata_output))
class MyEmbeddingFunction(EmbeddingFunction):
    """Adapter exposing ``EmbeddingGenerator`` to Chroma and LangChain.

    ``__call__`` keeps its original ``(embeddings, metadatas)`` tuple return.
    ``embed_documents`` and ``embed_query`` implement the LangChain
    ``Embeddings`` interface, which ``langchain_chroma.Chroma`` calls during
    ``similarity_search``.
    """

    def __init__(self, embedding_generator: EmbeddingGenerator):
        self.embedding_generator = embedding_generator

    def _embed_with_metadata(self, texts):
        """Embed each text; return parallel lists (embeddings, metadatas).

        Raises:
            ValueError: if ``compute_embeddings`` reports an error string.
        """
        embeddings, metadatas = [], []
        for text in texts:
            result = self.embedding_generator.compute_embeddings(text)
            if isinstance(result, str):
                # compute_embeddings signals an unknown task with an error string.
                raise ValueError(result)
            # result is a list of {"embedding": ..., "metadata": ...} dicts.
            for item in result:
                embeddings.append(item["embedding"])
                metadatas.append(item["metadata"])
        return embeddings, metadatas

    def __call__(self, input: Documents) -> (Embeddings, list):
        # NOTE(review): chromadb's EmbeddingFunction protocol expects only
        # Embeddings, and some chromadb versions validate __call__'s return
        # value; the tuple is kept for backward compatibility. Confirm this
        # against the installed chromadb before relying on it.
        return self._embed_with_metadata(input)

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """LangChain ``Embeddings.embed_documents``: one vector per text."""
        return self._embed_with_metadata(texts)[0]

    def embed_query(self, text: str) -> list[float]:
        """LangChain ``Embeddings.embed_query``: the vector for a single text."""
        return self._embed_with_metadata([text])[0][0]
def load_documents(file_path: str, mode: str = "elements"):
    """Load ``file_path`` with Unstructured and return each document's text.

    Args:
        file_path: Path of the file to parse.
        mode: Unstructured loader mode, forwarded to ``UnstructuredFileLoader``.

    Returns:
        List of ``page_content`` strings, one per loaded document.
    """
    loader = UnstructuredFileLoader(file_path, mode=mode)
    return [element.page_content for element in loader.load()]
def initialize_chroma(collection_name: str, embedding_function: MyEmbeddingFunction):
    """Connect to the local Chroma server, wipe it, and create a fresh collection.

    Args:
        collection_name: Name of the collection to create.
        embedding_function: Accepted for call compatibility; this function does
            not use it, and the collection is created without one.

    Returns:
        ``(client, collection)``.
    """
    settings = Settings(anonymized_telemetry=False, allow_reset=True)
    chroma = HttpClient(host='localhost', port=8000, settings=settings)
    # Destructive: drops all existing data on the server (needs allow_reset=True).
    chroma.reset()
    return chroma, chroma.create_collection(collection_name)
def add_documents_to_chroma(client, collection, documents: list, embedding_function: "MyEmbeddingFunction"):
    """Embed each document and add it, with its extracted metadata, to ``collection``.

    Args:
        client: Chroma client; unused, kept for call compatibility.
        collection: Collection exposing ``add(ids=, documents=, embeddings=, metadatas=)``.
        documents: Document texts to embed and store.
        embedding_function: Object whose ``embedding_generator.compute_embeddings``
            returns ``[{"embedding": ..., "metadata": ...}]`` or an error string.
    """
    generator = embedding_function.embedding_generator
    for doc in documents:
        result = generator.compute_embeddings(doc)
        if isinstance(result, str):
            # Unknown task is reported as an "Error: ..." string; skip this
            # document instead of aborting the whole batch.
            print(f"Skipping document: {result}")
            continue
        for item in result:
            record = {
                # uuid4: random ids (uuid1 embeds the host's MAC address).
                "ids": [str(uuid.uuid4())],
                "documents": [doc],
                "embeddings": [item["embedding"]],
            }
            # Chroma's metadata validation rejects empty dicts, so only send
            # metadata when the extraction found something.
            if item["metadata"]:
                record["metadatas"] = [item["metadata"]]
            collection.add(**record)
def query_chroma(client, collection_name: str, query_text: str, embedding_function: MyEmbeddingFunction):
    """Similarity-search ``collection_name`` for ``query_text`` via LangChain's Chroma store.

    Args:
        client: Chroma client that holds the collection.
        collection_name: Collection to search.
        query_text: Free-text query.
        embedding_function: Used by the LangChain store to embed the query.

    Returns:
        The documents returned by ``similarity_search`` (default ``k``).
    """
    store = Chroma(
        client=client,
        collection_name=collection_name,
        embedding_function=embedding_function,
    )
    return store.similarity_search(query_text)
# Initialize clients
# OpenAI-compatible client for the Yi API: routing, metadata and chat replies.
intention_client = OpenAI(base_url=API_BASE, api_key=yi_token)
embedding_generator = EmbeddingGenerator(
    model_name=model_name,
    token=hf_token,
    intention_client=intention_client,
)
embedding_function = MyEmbeddingFunction(embedding_generator=embedding_generator)
# NOTE: initialize_chroma() calls reset() on the server, so importing this
# module wipes any previously stored documents.
chroma_client, chroma_collection = initialize_chroma(
    collection_name="Tonic-instruct",
    embedding_function=embedding_function,
)
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a retrieval-augmented chat reply (generator for ``gr.ChatInterface``).

    Retrieved document text is prepended to the new user message, earlier turns
    from ``history`` are replayed, and the reply is streamed from yi-large.

    Args:
        message: The new user message.
        history: Prior ``(user, assistant)`` turns; empty sides are skipped.
        system_message: System prompt for the chat model.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The response text accumulated so far, after each non-empty chunk.
    """
    retrieved_text = query_documents(message)
    messages = [{"role": "system", "content": system_message}]
    for turn in history:
        if turn[0]:
            messages.append({"role": "user", "content": turn[0]})
        if turn[1]:
            messages.append({"role": "assistant", "content": turn[1]})
    messages.append({"role": "user", "content": f"{retrieved_text}\n\n{message}"})

    # intention_client is an openai.OpenAI client. It has no `chat_completion`
    # method (that is huggingface_hub.InferenceClient's API).
    stream = intention_client.chat.completions.create(
        model="yi-large",
        messages=messages,
        # Gradio sliders may deliver floats; the API expects an integer.
        max_tokens=int(max_tokens),
        stream=True,
        temperature=temperature,
        top_p=top_p,
    )
    response = ""
    # `chunk`, not `message`, so the user's message argument is not shadowed.
    for chunk in stream:
        # Some chunks carry no choices, and the final delta's content is None.
        if not chunk.choices:
            continue
        token = chunk.choices[0].delta.content
        if token:
            response += token
            yield response
def upload_documents(files):
    """Parse each uploaded file and add its text to the Chroma collection.

    Args:
        files: Files from ``gr.File(file_count="multiple")``, or None when
            nothing was uploaded.

    Returns:
        A status message for the UI.
    """
    if not files:
        return "No files were uploaded."
    for file in files:
        # Depending on the Gradio version, items are tempfile-like objects or
        # path strings; both forms yield a filesystem path here.
        file_path = getattr(file, "name", file)
        # "single" (UnstructuredFileLoader's default) gives one text per file.
        # The helper returns plain strings, which compute_embeddings needs.
        documents = load_documents(file_path, mode="single")
        add_documents_to_chroma(chroma_client, chroma_collection, documents, embedding_function)
    return "Documents uploaded and processed successfully!"
def query_documents(query):
    """Return the text of the documents retrieved from Chroma for ``query``.

    Args:
        query: Free-text search string.

    Returns:
        The matches' ``page_content`` joined by blank lines ("" when none).
    """
    # query_chroma needs the client, the collection name and the embedding
    # function in addition to the query text.
    results = query_chroma(chroma_client, chroma_collection.name, query, embedding_function)
    # LangChain Document objects expose their text as `page_content`.
    return "\n\n".join(result.page_content for result in results)
# Gradio UI: one tab ingests files into Chroma, the other chats over them
# and exposes direct retrieval.
with gr.Blocks() as demo:
    with gr.Tab("Upload Documents"):
        # NOTE(review): Gradio documents "file", "image", "audio", "video",
        # "text" and explicit extensions (e.g. ".pdf") for file_types; confirm
        # that "document" is accepted by the installed Gradio version.
        document_upload = gr.File(file_count="multiple", file_types=["document"])
        upload_button = gr.Button("Upload and Process")
        # The status textbox is created inline, so it renders where this call runs.
        upload_button.click(upload_documents, inputs=document_upload, outputs=gr.Text())
    with gr.Tab("Ask Questions"):
        with gr.Row():
            # Streaming RAG chat. The additional inputs map positionally onto
            # respond(..., system_message, max_tokens, temperature, top_p).
            chat_interface = gr.ChatInterface(
                respond,
                additional_inputs=[
                    gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
                    gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
                    gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
                    gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
                ],
            )
        # Direct retrieval: shows the text query_documents returns for a query.
        # NOTE(review): nesting of these widgets (inside the tab, after the row)
        # is reconstructed; confirm the intended layout.
        query_input = gr.Textbox(label="Query")
        query_button = gr.Button("Query")
        query_output = gr.Textbox()
        query_button.click(query_documents, inputs=query_input, outputs=query_output)
if __name__ == "__main__":
    # Start the Gradio server for the UI defined in `demo`.
    demo.launch()