Spaces:
Sleeping
Sleeping
from datetime import datetime | |
import gradio as gr | |
import torch | |
import torchvision.transforms as T | |
import numpy as np | |
from PIL import Image | |
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# DINOv2 checkpoints available from the facebookresearch/dinov2 torch.hub repo
# (s/b/l/g = small/base/large/giant variants, all with 14-pixel patches).
DINOV2_CHECKPOINTS = (
    "dinov2_vits14",
    "dinov2_vitb14",
    "dinov2_vitl14",
    "dinov2_vitg14",
)
# Select checkpoint: index 1 is the ViT-B/14 backbone.
dinov2_ckpt = DINOV2_CHECKPOINTS[1]

# torch.hub.load fetches (and caches) the hub repo on first use.
dinov2 = torch.hub.load("facebookresearch/dinov2", dinov2_ckpt)
dinov2.to(device)
# The model is only used for inference. nn.Module instances start in training
# mode, so switch every submodule (dropout, etc.) to evaluation behavior.
dinov2.eval()
print(f"Loaded {dinov2_ckpt} on {device}")
# Standard ImageNet per-channel mean / standard deviation (R, G, B).
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)

# Preprocessing for the backbone:
#   1. resize to a fixed 224x224 (224 is divisible by 14, the patch size in
#      the *14 checkpoint names),
#   2. PIL image -> float tensor with values in [0, 1],
#   3. per-channel normalization with the statistics above.
transform_image = T.Compose(
    [
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
def predict(image):
    """Compute the DINOv2 embedding of a single image.

    Args:
        image: A PIL Image object of any mode (RGB, RGBA, L, P, ...), or
            ``None`` when the Gradio input is empty.

    Returns:
        A dict ``{"embedding": [...]}`` containing row 0 of the model output
        for this one-image batch, converted to Python floats via ``tolist()``.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Gradio's Image component can deliver None when nothing was uploaded;
        # report it in the UI instead of failing inside the transform.
        raise gr.Error("Please upload an image.")

    # Force exactly three channels *before* transforming. T.Normalize is
    # configured with 3-element mean/std, so RGBA or grayscale tensors would
    # fail there; slicing channels after the transform is too late.
    rgb_image = image.convert("RGB")
    transformed_img = transform_image(rgb_image).unsqueeze(0).to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        embedding = dinov2(transformed_img)
    print(embedding.shape)
    # Drop the batch dimension and move to host memory as plain Python floats.
    embedding = embedding[0].cpu().tolist()

    # Log a timestamp for each request.
    formatted_datetime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(formatted_datetime)
    return {"embedding": embedding}
# UI wiring: one PIL image goes in, the JSON dict returned by `predict` comes
# out. The selected checkpoint name is shown as the app description.
_image_input = gr.Image(type="pil")
_json_output = gr.JSON()

interface = gr.Interface(
    fn=predict,
    inputs=[_image_input],
    outputs=[_json_output],
    title="DINOv2 Embedding",
    description=dinov2_ckpt,
)

# Start serving immediately; this runs whenever the module is executed.
interface.launch()