from dataclasses import dataclass
from typing import Any
import uuid

import japanese_clip as ja_clip
from PIL import Image
import torch

from db_session import get_db
from s3_session import Bucket

# Settings shared by every method below; kept in one place so the device,
# checkpoint and collection name cannot drift apart between save and search.
_MODEL_NAME = "rinna/japanese-clip-vit-b-16"
_CACHE_DIR = "/tmp/japanese_clip"
_DEVICE = "cpu"
_MAX_SEQ_LEN = 77
_COLLECTION = "embedding"


@dataclass
class MLModel:
    """Image/text embedding service built on the ``japanese_clip`` model.

    ``save`` embeds an image, uploads it through ``bucket`` and stores the
    embedding in the ``embedding`` collection returned by ``get_db()``.
    ``search`` embeds a text prompt and runs a ``$vectorSearch`` aggregation
    against that same collection.

    Any component left as ``None`` is created in ``__post_init__``; a
    component passed explicitly to the constructor is kept as given.
    """

    # Tokenizer handed to ``ja_clip.tokenize``.
    tokenizer: Any = None
    # Object exposing ``get_image_features`` and ``get_text_features``.
    model: Any = None
    # Callable applied to a PIL image before it is fed to ``model``.
    preprocess: Any = None
    # Object exposing ``upload_file`` and ``get_presigned_url``.
    bucket: Any = None

    def __post_init__(self):
        """Load every component that was not supplied by the caller."""
        if self.tokenizer is None:
            self.tokenizer = ja_clip.load_tokenizer()

        # ``ja_clip.load`` returns the model and its preprocess callable
        # together, so one call fills in whichever of the two is missing.
        if self.model is None or self.preprocess is None:
            model, preprocess = ja_clip.load(
                _MODEL_NAME, cache_dir=_CACHE_DIR, device=_DEVICE
            )
            if self.model is None:
                self.model = model
            if self.preprocess is None:
                self.preprocess = preprocess

        if self.bucket is None:
            self.bucket = Bucket()

    def save(self, image_path: str) -> Any:
        """Embed, upload and index the image stored at ``image_path``.

        Args:
            image_path: Path of an image file readable by ``PIL.Image.open``.

        Returns:
            The ``inserted_id`` reported by the ``insert_one`` call.
        """
        image_uuid = str(uuid.uuid4())

        # The context manager closes the underlying file once the image has
        # been embedded and uploaded.
        # NOTE(review): assumes ``upload_file`` finishes reading the image
        # before it returns - confirm against the Bucket implementation.
        with Image.open(image_path) as pillow_image:
            image = self.preprocess(pillow_image).unsqueeze(0).to(_DEVICE)
            # Inference only: skip building the autograd graph.
            with torch.no_grad():
                image_features = self.model.get_image_features(image)

            # media upload
            self.bucket.upload_file(pillow_image, image_uuid)

        # db insert
        db = get_db()
        result = db[_COLLECTION].insert_one(
            {"uuid": image_uuid, "vectorField": image_features[0].tolist()}
        )
        return result.inserted_id

    def search(
        self, prompt: str, *, limit: int = 10, num_candidates: int = 150
    ) -> list:
        """Return presigned URLs for the stored images closest to ``prompt``.

        Args:
            prompt: Text query to embed.
            limit: Maximum number of results requested from the vector search.
            num_candidates: Number of candidates the vector search considers;
                must be at least ``limit``.

        Returns:
            One ``bucket.get_presigned_url`` result per matching document, in
            the order the aggregation yields them.

        Raises:
            ValueError: If ``limit`` is below 1 or ``num_candidates`` is
                smaller than ``limit``.
        """
        if limit < 1:
            raise ValueError("limit must be at least 1")
        if num_candidates < limit:
            raise ValueError("num_candidates must be at least limit")

        db = get_db()
        encodings = ja_clip.tokenize(
            texts=[prompt],
            max_seq_len=_MAX_SEQ_LEN,
            device=_DEVICE,
            tokenizer=self.tokenizer,
        )
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            text_features = self.model.get_text_features(**encodings)

        pipeline = [
            {
                "$vectorSearch": {
                    "index": "vector_index",
                    "path": "vectorField",
                    "queryVector": text_features[0].tolist(),
                    "numCandidates": num_candidates,
                    "limit": limit,
                }
            },
            {
                "$project": {
                    "_id": {"$toString": "$_id"},
                    "uuid": 1,
                    "score": {"$meta": "vectorSearchScore"},
                }
            },
        ]
        result = db[_COLLECTION].aggregate(pipeline)
        return [self.bucket.get_presigned_url(doc["uuid"]) for doc in result]