Spaces:
Paused
Paused
import os | |
from typing import Union | |
from PIL import Image | |
from fastapi import FastAPI | |
from fastapi.staticfiles import StaticFiles | |
from fastapi.responses import FileResponse | |
from sentence_transformers import SentenceTransformer | |
import uvicorn | |
import vecs | |
# Connection string for the vecs vector store. The DB_URL environment
# variable takes precedence; otherwise fall back to a local Postgres on
# port 54322 with postgres/postgres credentials.
_DEFAULT_DB_URL = "postgresql://postgres:postgres@localhost:54322/postgres"
DB_CONNECTION = os.getenv("DB_URL", _DEFAULT_DB_URL)
# The FastAPI application instance; start() launches it through uvicorn
# under the import string "image_search.main:app".
app = FastAPI()
def seed():
    """Create, populate and index the ``image_vectors`` collection.

    Each image in ``./images`` (one/two/three/four.jpg) is embedded with the
    ``clip-ViT-B-32`` SentenceTransformer model. The embedding is upserted
    as a 512-dimensional vector keyed by the image's file name, with
    ``{"type": "jpg"}`` metadata. An index is then built on the collection.

    If a collection named ``image_vectors`` already exists, nothing is
    changed.

    Returns:
        dict: ``{"message": ...}`` saying which of the two outcomes happened.
    """
    # NOTE(review): no route decorator is visible on this function;
    # presumably it is meant to be exposed as an HTTP endpoint — confirm.
    image_names = ("one.jpg", "two.jpg", "three.jpg", "four.jpg")

    vx = vecs.create_client(DB_CONNECTION)
    try:
        # Test existence by name. vecs documents that get_collection()
        # raises when the collection is missing. A Collection's truthiness
        # also appears to follow its row count. Either way, testing the
        # return value of get_collection() cannot reliably tell whether the
        # collection exists.
        if any(c.name == "image_vectors" for c in vx.list_collections()):
            return {"message": "Collection already exists."}

        # 512 is the embedding size produced by clip-ViT-B-32.
        images = vx.create_collection(name="image_vectors", dimension=512)
        model = SentenceTransformer('clip-ViT-B-32')

        records = []
        for name in image_names:
            # Open each image in a context manager so its file handle is
            # released as soon as it has been encoded.
            with Image.open(f"./images/{name}") as img:
                records.append((name, model.encode(img), {"type": "jpg"}))

        images.upsert(vectors=records)
        print("Inserted images")

        # Index the collection for fast search performance.
        images.create_index()
        return {"message": "Collection created and indexed."}
    finally:
        # Release the client's database connections on every exit path.
        vx.disconnect()
# CLIP model used to embed text queries. It is loaded on the first search
# and then reused, instead of being reloaded from disk on every request.
_search_model = None


def _get_search_model():
    """Return the shared clip-ViT-B-32 model, loading it on first use."""
    global _search_model
    if _search_model is None:
        _search_model = SentenceTransformer('clip-ViT-B-32')
    return _search_model


def search(query: Union[str, None] = None):
    """Find the indexed image whose embedding best matches a text query.

    The query is embedded with ``clip-ViT-B-32``. The ``image_vectors``
    collection is then asked for the single nearest vector whose metadata
    has ``type == "jpg"``.

    Args:
        query: Free-text description of the wanted image. When ``None``,
            no model or database work is done.

    Returns:
        dict: ``{"result": match, "query": query}``. ``match`` is the first
        entry returned by the collection's ``query`` call, or ``None`` when
        there is no query or no match. (vecs presumably returns record ids
        by default, i.e. the image file name — confirm.)
    """
    # NOTE(review): no route decorator is visible on this function;
    # presumably it is meant to be exposed as an HTTP endpoint — confirm.
    if query is None:
        # Nothing to embed; encoding None would fail inside the model.
        return {"result": None, "query": query}

    vx = vecs.create_client(DB_CONNECTION)
    try:
        # Raises (from vecs) if the collection has not been seeded yet.
        images = vx.get_collection(name="image_vectors")
        text_emb = _get_search_model().encode(query)
        # Query the collection, filtering metadata for "type" = "jpg".
        results = images.query(
            query_vector=text_emb,
            limit=1,
            filters={"type": {"$eq": "jpg"}},
        )
    finally:
        # Release the client's database connections on every exit path.
        vx.disconnect()

    result = results[0] if results else None
    return {"result": result, "query": query}
# Static file routes. Starlette matches routes in registration order, so the
# catch-all "/" mount must remain the last one registered.
# /images/* -> files under ./images
_image_files = StaticFiles(directory="images")
app.mount("/images", _image_files, name="images")
# /* -> files under ./static; html=True serves index.html for directories.
_site_files = StaticFiles(directory="static", html=True)
app.mount("/", _site_files, name="static")
def index() -> FileResponse:
    """Serve the front-end page ``static/index.html`` as HTML.

    NOTE(review): no route decorator is visible here. If this is meant to
    be GET "/", the "/" static mount registered earlier would match first —
    confirm.
    """
    page_path = "static/index.html"
    return FileResponse(media_type="text/html", path=page_path)
def start():
    """Launched with `poetry run start` at root level.

    Runs ``image_search.main:app`` under uvicorn, bound to all interfaces
    on port 7860, with auto-reload on source changes enabled.
    """
    app_import_path = "image_search.main:app"
    uvicorn.run(
        app_import_path,
        host="0.0.0.0",
        port=7860,
        reload=True,
    )