# mps / mps-api.py
# huynhdoo's picture
# Upload folder using huggingface_hub
# 5125874 verified
# raw
# history blame
# 2.33 kB
from modal import Image, App, Secret, web_endpoint, Volume, enter, method, build
from typing import Dict
import sys
# Container image: Debian slim with Python 3.12 plus the packages used by this
# file — the vector store (chromadb), the embedding backend
# (sentence-transformers) and the SQLite replacement (pysqlite3-binary).
model_image = Image.debian_slim(python_version="3.12").pip_install(
    "chromadb",
    "sentence-transformers",
    "pysqlite3-binary",
)
# Utilities
with model_image.imports():
    import os

    # Hot-swap the SQLite version: import pysqlite3, then re-register it in
    # sys.modules under the name "sqlite3" so that any later `import sqlite3`
    # resolves to pysqlite3 instead of the standard-library module.
    __import__("pysqlite3")
    sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
# Application initialization
# Modal application named "mps-api"; model_image is passed as the app's image.
app = App("mps-api", image=model_image)

# Existing Modal volume "mps" (create_if_missing=False: it is never created
# from here) and the container path under which VECTORDB mounts it.
vol = Volume.from_name("mps", create_if_missing=False)
data_path = "/data"
############
# MAIN CLASS
############
@app.cls(timeout=30*60,  # 30 minutes
         volumes={data_path: vol})
class VECTORDB:
    """Semantic search over the persisted "MPS" Chroma collection.

    The Modal volume is mounted at ``data_path``; the Chroma database is read
    from ``<data_path>/db``. ``init`` prepares the encoder and the collection,
    ``query`` runs a filtered nearest-neighbour search against it.
    """

    @enter()
    @build()
    def init(self):
        """Load the sentence encoder and open the "MPS" collection.

        Sets ``self.embedding_function`` and ``self.chroma_collection``, and
        prints the model name and the number of stored documents.
        """
        # Load encoder
        from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
        model_name = "Lajavaness/sentence-camembert-large"
        self.embedding_function = SentenceTransformerEmbeddingFunction(model_name=model_name)
        print(f"Embedding model loaded: {model_name}")

        # Load vector database
        import chromadb
        DB_PATH = data_path + "/db"
        COLLECTION_NAME = "MPS"
        chroma_client = chromadb.PersistentClient(path=DB_PATH)
        # get_collection (not get_or_create): the collection must already exist
        # in the mounted volume.
        self.chroma_collection = chroma_client.get_collection(
            name=COLLECTION_NAME,
            embedding_function=self.embedding_function,
        )
        print(f"{self.chroma_collection.count()} documents loaded.")

    @method()
    def query(self, query, origins, n_results=10):
        """Return the nearest documents to ``query`` among the given origins.

        Args:
            query: Text to search for; it is sent to Chroma as a single query text.
            origins: Values accepted for the ``origin`` metadata field
                (used in a ``$in`` filter).
            n_results: Maximum number of hits to return. Defaults to 10.

        Returns:
            A ``(documents, metadatas, distances)`` tuple holding the result
            lists for the single query text. All three are empty when
            ``origins`` is empty.
        """
        # Chroma's filter validation rejects an empty ``$in`` list, and a filter
        # on an empty set of origins cannot match anything, so answer directly
        # instead of letting the request fail.
        if not origins:
            return [], [], []

        results = self.chroma_collection.query(
            query_texts=[query],
            n_results=n_results,
            where={"origin": {"$in": origins}},
            include=['documents', 'metadatas', 'distances'])

        # Exactly one query text was submitted, so each field holds one result row.
        documents = results['documents'][0]
        metadatas = results['metadatas'][0]
        distances = results['distances'][0]
        return documents, metadatas, distances
###########
# ENDPOINTS
###########
@app.function(timeout=30*60)
@web_endpoint(method="POST")
def query(query: Dict):
    """POST endpoint: forward a search request to VECTORDB and return the hits.

    Args:
        query: Request payload. Its ``"query"`` and ``"origins"`` entries are
            passed on to ``VECTORDB.query``.

    Returns:
        A dict with the keys ``"documents"``, ``"metadatas"`` and
        ``"distances"``, filled from the three values returned by the remote call.
    """
    # Log query
    print(f"Incoming query: {query}...")

    # Instantiate the vector store and resolve its remote query entry point
    store = VECTORDB()
    remote_search = store.query.remote

    # Run query
    docs, metas, dists = remote_search(query['query'], query['origins'])
    return {"documents" : docs, "metadatas" : metas, "distances" : dists}