"""Modal app that serves similarity queries against a persisted Chroma collection."""

import sys
from typing import Dict

from modal import (
    App,
    Image,
    Secret,
    Volume,
    build,
    enter,
    method,
    web_endpoint,
)

# Container image: Debian slim on Python 3.12 with the vector-store packages.
model_image = Image.debian_slim(python_version="3.12").pip_install(
    "chromadb",
    "sentence-transformers",
    "pysqlite3-binary",
)

# Utilities
with model_image.imports():
    import os

    # Register pysqlite3 under the name "sqlite3" so that later
    # `import sqlite3` statements resolve to it.
    # NOTE(review): presumably needed because chromadb wants a newer SQLite
    # than the base image provides -- confirm.
    # This must run before chromadb is imported anywhere.
    __import__("pysqlite3")
    sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")

# Application initialization
app = App("mps-api", image=model_image)
# create_if_missing=False: the "mps" volume has to exist already.
vol = Volume.from_name("mps", create_if_missing=False)
data_path = "/data"  # mount point of `vol` inside the container


############
# MAIN CLASS
############
@app.cls(timeout=30 * 60, volumes={data_path: vol})
class VECTORDB:
    """Modal class holding the embedding function and the "MPS" Chroma collection."""

    @enter()
    @build()
    def init(self):
        """Load the sentence encoder and open the persisted collection.

        Registered for both the ``enter`` and ``build`` Modal lifecycle
        hooks. Sets ``self.embedding_function`` and
        ``self.chroma_collection`` and prints progress messages.
        """
        # Imported lazily: these packages are installed in the container image.
        import chromadb
        from chromadb.utils.embedding_functions import (
            SentenceTransformerEmbeddingFunction,
        )

        # Load encoder
        encoder_name = "Lajavaness/sentence-camembert-large"
        self.embedding_function = SentenceTransformerEmbeddingFunction(
            model_name=encoder_name
        )
        print(f"Embedding model loaded: {encoder_name}")

        # Load vector database from the mounted volume
        client = chromadb.PersistentClient(path=f"{data_path}/db")
        self.chroma_collection = client.get_collection(
            name="MPS",
            embedding_function=self.embedding_function,
        )
        print(f"{self.chroma_collection.count()} documents loaded.")

    @method()
    def query(self, query, origins):
        """Return the 10 nearest entries for ``query`` restricted to ``origins``.

        Args:
            query: Text to search for (passed as the single query text).
            origins: Values the "origin" metadata field is allowed to take
                (used with the ``$in`` operator).

        Returns:
            A ``(documents, metadatas, distances)`` tuple for the single
            query text.
        """
        hits = self.chroma_collection.query(
            query_texts=[query],
            n_results=10,
            where={"origin": {"$in": origins}},
            include=["documents", "metadatas", "distances"],
        )
        # Only one query text was submitted, so take the first entry of each field.
        return hits["documents"][0], hits["metadatas"][0], hits["distances"][0]


###########
# ENDPOINTS
###########
@app.function(timeout=30 * 60)
@web_endpoint(method="POST")
def query(query: Dict):
    """POST endpoint: run a vector search and return the matches.

    Args:
        query: Request body; its "query" and "origins" entries are forwarded
            to ``VECTORDB.query``.

    Returns:
        A dict with "documents", "metadatas" and "distances" keys.
    """
    # Log query
    print(f"Incoming query: {query}...")

    # Obtain the remote handle first, then read the request fields.
    run_remote = VECTORDB().query.remote
    docs, metas, dists = run_remote(query["query"], query["origins"])

    return {"documents": docs, "metadatas": metas, "distances": dists}