kittchy
[ADD] image_vector_search
30099ac unverified
raw
history blame contribute delete
No virus
2.16 kB
import uuid
from dataclasses import dataclass
from typing import Any

import japanese_clip as ja_clip
import torch
from PIL import Image

from db_session import get_db
from s3_session import Bucket
@dataclass
class MLModel:
    """Japanese CLIP wrapper that stores image embeddings and searches them by text.

    ``save`` embeds an image, uploads it to the bucket and inserts the
    embedding into the ``embedding`` collection. ``search`` embeds a text
    prompt and runs a ``$vectorSearch`` aggregation over that collection.

    Any of the four components may be injected through the constructor;
    components left as ``None`` are created in ``__post_init__``.
    """

    # Defaults for components that are not injected. These are plain class
    # attributes (no annotation), so the dataclass does not treat them as fields.
    _MODEL_NAME = "rinna/japanese-clip-vit-b-16"
    _CACHE_DIR = "/tmp/japanese_clip"
    _DEVICE = "cpu"

    tokenizer: Any = None
    model: Any = None
    preprocess: Any = None
    bucket: Any = None

    def __post_init__(self):
        """Create every component the caller did not supply."""
        if self.tokenizer is None:
            self.tokenizer = ja_clip.load_tokenizer()
        if self.model is None or self.preprocess is None:
            # ``ja_clip.load`` returns the model and its preprocessing
            # transform together; keep whichever one the caller injected.
            model, preprocess = ja_clip.load(
                self._MODEL_NAME, cache_dir=self._CACHE_DIR, device=self._DEVICE
            )
            if self.model is None:
                self.model = model
            if self.preprocess is None:
                self.preprocess = preprocess
        if self.bucket is None:
            self.bucket = Bucket()

    def save(self, image_path: str) -> Any:
        """Embed the image at *image_path*, upload it and store its embedding.

        Args:
            image_path: Path of the image file to open.

        Returns:
            The ``inserted_id`` reported by the ``insert_one`` call.
        """
        # The context manager closes the underlying file even if
        # preprocessing, inference or the upload raises.
        with Image.open(image_path) as pil_image:
            image = self.preprocess(pil_image).unsqueeze(0).to(self._DEVICE)
            # Inference only: skip autograd bookkeeping.
            with torch.no_grad():
                image_features = self.model.get_image_features(image)
            image_uuid = str(uuid.uuid4())

            # media upload (the image must still be open here)
            self.bucket.upload_file(pil_image, image_uuid)

        # db insert
        db = get_db()
        result = db["embedding"].insert_one(
            {"uuid": image_uuid, "vectorField": image_features[0].tolist()}
        )
        return result.inserted_id

    def search(self, prompt: str, *, limit: int = 10, num_candidates: int = 150) -> list:
        """Return presigned URLs of stored images matching *prompt*.

        Args:
            prompt: Text query to embed.
            limit: Maximum number of documents requested from the vector
                search. Must be at least 1.
            num_candidates: Candidate pool size passed as ``numCandidates``.
                Raised to *limit* when smaller, because the candidate pool
                cannot be smaller than the number of requested results.

        Returns:
            One ``get_presigned_url`` result per document, in the order the
            aggregation yields them.

        Raises:
            ValueError: If *limit* is less than 1.
        """
        if limit < 1:
            raise ValueError("limit must be a positive integer")

        db = get_db()
        encodings = ja_clip.tokenize(
            texts=[prompt], max_seq_len=77, device=self._DEVICE, tokenizer=self.tokenizer
        )
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            text_features = self.model.get_text_features(**encodings)

        pipeline = [
            {
                "$vectorSearch": {
                    "index": "vector_index",
                    "path": "vectorField",
                    "queryVector": text_features[0].tolist(),
                    "numCandidates": max(num_candidates, limit),
                    "limit": limit,
                }
            },
            {
                "$project": {
                    "_id": {"$toString": "$_id"},
                    "uuid": 1,
                    "score": {"$meta": "vectorSearchScore"},
                }
            },
        ]
        result = db["embedding"].aggregate(pipeline)
        return [self.bucket.get_presigned_url(doc["uuid"]) for doc in result]