|
from modal import Image, App, Secret, web_endpoint, Volume, enter, method, build |
|
from typing import Dict |
|
import sys |
|
|
|
# Container image: slim Debian with Python 3.12 plus the vector-search stack.
# pysqlite3-binary provides the SQLite build that replaces the stdlib sqlite3
# module further down (presumably because chromadb needs a newer SQLite than
# the base image ships — confirm).
model_image = Image.debian_slim(python_version="3.12").pip_install(
    "chromadb",
    "sentence-transformers",
    "pysqlite3-binary",
)
|
|
|
|
|
# Imports that only have to succeed inside the container image. Modal's
# `Image.imports()` context suppresses ImportError when this module is loaded
# somewhere those packages are absent (e.g. the local machine running
# `modal deploy`).
with model_image.imports():

    import os

    # Register pysqlite3 (from the pysqlite3-binary package installed in
    # `model_image`) under the name "sqlite3", so later `import sqlite3`
    # statements get it instead of the stdlib module. This must execute before
    # chromadb is imported; in this file chromadb is only imported inside
    # VECTORDB's methods, i.e. after this swap has run.
    __import__("pysqlite3")

    sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
|
|
|
|
|
# The Modal application; functions and classes registered on it default to
# running in `model_image`.
app = App("mps-api", image=model_image)

# Persistent volume holding the Chroma database. It must already exist:
# create_if_missing=False means Modal will not create it on lookup.
vol = Volume.from_name("mps", create_if_missing=False)

# Mount point of `vol` inside the containers that attach it.
data_path = "/data"
|
|
|
|
|
|
|
|
|
@app.cls(timeout=30*60,
         volumes={data_path: vol})
class VECTORDB:
    """Semantic search over the "MPS" Chroma collection stored on `vol`.

    The collection lives in a Chroma PersistentClient database at
    ``<data_path>/db``. Queries are embedded with a sentence-transformers model
    that presumably must match the one used to build the collection (TODO
    confirm against the indexing script).
    """

    # Hugging Face model id used to embed query texts.
    MODEL_NAME = "Lajavaness/sentence-camembert-large"
    # Name of the pre-built Chroma collection to open.
    COLLECTION_NAME = "MPS"

    def _load_embedding_function(self):
        """Instantiate the embedding function and store it on ``self``."""
        # Imported lazily so the sqlite3 -> pysqlite3 swap has already happened.
        from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction

        self.embedding_function = SentenceTransformerEmbeddingFunction(model_name=self.MODEL_NAME)
        print(f"Embedding model loaded: {self.MODEL_NAME}")

    @build()
    def download_model(self):
        """Image-build step: fetch the embedding model so its weights are baked
        into the image.

        This step deliberately does not touch the Chroma database. That data
        lives on the volume, which is only needed at container start.
        """
        self._load_embedding_function()

    @enter()
    def init(self):
        """Container-start hook: load the embedding model and open the collection.

        Raises whatever chromadb raises if the collection is missing from the
        volume.
        """
        self._load_embedding_function()

        import chromadb

        db_path = f"{data_path}/db"
        chroma_client = chromadb.PersistentClient(path=db_path)
        self.chroma_collection = chroma_client.get_collection(
            name=self.COLLECTION_NAME,
            embedding_function=self.embedding_function,
        )
        print(f"{self.chroma_collection.count()} documents loaded.")

    @method()
    def query(self, query, origins=None, n_results=10):
        """Return the nearest documents to ``query``.

        Args:
            query: Free-text search string. It is embedded with the collection's
                embedding function.
            origins: Restrict hits to documents whose "origin" metadata is in
                this list. ``None`` disables the filter. An empty list matches
                nothing, so empty results are returned without querying
                (Chroma rejects an empty ``$in`` list).
            n_results: Maximum number of hits to return (default 10).

        Returns:
            A ``(documents, metadatas, distances)`` tuple of parallel lists,
            one entry per hit.
        """
        if origins is not None and len(origins) == 0:
            return [], [], []

        where = None if origins is None else {"origin": {"$in": origins}}
        results = self.chroma_collection.query(
            query_texts=[query],
            n_results=n_results,
            where=where,
            include=['documents', 'metadatas', 'distances'],
        )

        # A single query text was sent, so take the first (only) result row.
        documents = results['documents'][0]
        metadatas = results['metadatas'][0]
        distances = results['distances'][0]
        return documents, metadatas, distances
|
|
|
|
|
|
|
|
|
@app.function(timeout=30*60)
@web_endpoint(method="POST")
def query(query: Dict):
    """HTTP POST search endpoint backed by ``VECTORDB.query``.

    The JSON body must contain a "query" key (the search text) and an
    "origins" key. "origins" is forwarded as-is and presumably holds a list of
    "origin" metadata values to filter on. A missing key raises KeyError.

    Returns a dict with "documents", "metadatas" and "distances": the parallel
    result lists produced by the vector search.
    """
    print(f"Incoming query: {query}...")

    searcher = VECTORDB()
    search_text = query['query']
    allowed_origins = query['origins']
    documents, metadatas, distances = searcher.query.remote(search_text, allowed_origins)

    return dict(documents=documents, metadatas=metadatas, distances=distances)
|
|