# Andrew Smith
# Update with instructions on running locally
# e359b32
import os
from typing import Union
from PIL import Image
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from sentence_transformers import SentenceTransformer
import uvicorn
import vecs
# Postgres connection string handed to the vecs client. The DB_URL
# environment variable wins; otherwise the localhost default below is used.
DB_CONNECTION = os.getenv(
    "DB_URL",
    "postgresql://postgres:postgres@localhost:54322/postgres",
)
# FastAPI application object; the route decorators and static mounts in this
# module attach to it.
app = FastAPI()
@app.get("/seed")
def seed():
    """Create, populate and index the ``image_vectors`` collection.

    Encodes the four sample images under ``./images`` with the
    ``clip-ViT-B-32`` model and stores one vector per image in a
    512-dimension collection, keyed by file name and tagged with
    ``{"type": "jpg"}`` metadata.

    Returns:
        A dict with a single ``"message"`` key: ``"Collection already
        exists."`` when the collection is already present, otherwise
        ``"Collection created and indexed."``.
    """
    # create vector store client
    vx = vecs.create_client(DB_CONNECTION)

    # vecs documents get_collection() as raising CollectionNotFound when the
    # collection is missing, which would abort the very first seed. Checking
    # the listed collection names works whether or not it exists yet.
    if any(coll.name == "image_vectors" for coll in vx.list_collections()):
        return {"message": "Collection already exists."}

    # create a collection of vectors with 512 dimensions
    images = vx.create_collection(name="image_vectors", dimension=512)

    # Load CLIP model
    model = SentenceTransformer('clip-ViT-B-32')

    # Encode each sample image into a (id, vector, metadata) record.
    records = []
    for file_name in ("one.jpg", "two.jpg", "three.jpg", "four.jpg"):
        # The context manager closes the image file once it has been encoded.
        with Image.open(f"./images/{file_name}") as img:
            records.append((file_name, model.encode(img), {"type": "jpg"}))

    images.upsert(vectors=records)
    print("Inserted images")

    # index the collection for fast search performance
    images.create_index()
    return {"message": "Collection created and indexed."}
# Lazily-loaded CLIP model shared by all search requests.
_search_model = None


def _get_search_model():
    """Return the CLIP model, loading it on first use and reusing it afterwards."""
    global _search_model
    if _search_model is None:
        _search_model = SentenceTransformer('clip-ViT-B-32')
    return _search_model


@app.get("/search")
def search(query: Union[str, None] = None):
    """Return the stored image that best matches a text query.

    Args:
        query: Free-text description to search for (``?query=...``).

    Returns:
        A dict with ``"result"`` (the first entry returned by the collection
        query) and ``"query"`` (the text that was searched for).

    Raises:
        HTTPException: 400 when ``query`` is not supplied, 404 when the
            collection query returns no match.
    """
    # Imported locally so this handler carries its own dependency.
    from fastapi import HTTPException

    # Without a query there is nothing to encode; fail with a clear client
    # error instead of letting the model choke on None.
    if query is None:
        raise HTTPException(
            status_code=400, detail="The 'query' parameter is required.")

    # create vector store client
    vx = vecs.create_client(DB_CONNECTION)
    images = vx.get_collection(name="image_vectors")

    # Encode text query (the model is loaded once, not on every request)
    text_emb = _get_search_model().encode(query)

    # query the collection filtering metadata for "type" = "jpg"
    results = images.query(
        query_vector=text_emb,
        limit=1,
        filters={"type": {"$eq": "jpg"}},
    )

    # An empty result would otherwise surface as an IndexError (HTTP 500).
    if not results:
        raise HTTPException(status_code=404, detail="No matching image found.")

    return {"result": results[0], "query": query}
@app.get("/")
def index() -> FileResponse:
    """Serve the front-end page ``static/index.html`` as HTML."""
    return FileResponse(path="static/index.html", media_type="text/html")


# Mounts are registered after the explicit routes: routes are matched in
# registration order and a mount at "/" matches every path, so a route added
# after it (as index() originally was) could never be reached.
app.mount("/images", StaticFiles(directory="images"), name="images")
app.mount("/", StaticFiles(directory="static", html=True), name="static")
def start():
    """Entry point for `poetry run start`, run from the project root.

    Serves ``image_search.main:app`` with uvicorn on host ``0.0.0.0``,
    port 7860, with reload turned on.
    """
    server_options = {"host": "0.0.0.0", "port": 7860, "reload": True}
    uvicorn.run("image_search.main:app", **server_options)