Spaces:
Sleeping
Sleeping
import os | |
import pandas as pd | |
import torch | |
import torch.nn.functional as F | |
import gradio as gr | |
from model import DistMult | |
from PIL import Image | |
from torchvision import transforms | |
import json | |
from tqdm import tqdm | |
# Per-channel mean and standard deviation handed to ``transforms.Normalize``
# in ``evaluate``. The values match the ImageNet RGB statistics commonly used
# with torchvision models.
_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN = [
    0.485,
    0.456,
    0.406,
]
_DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD = [
    0.229,
    0.224,
    0.225,
]
def generate_target_list(data, entity2id):
    """Collect the distinct entity ids that appear as tails of image -> id rows.

    Args:
        data: DataFrame with at least the columns ``datatype_h``,
            ``datatype_t`` and ``t``.
        entity2id: Mapping from entity key (a string) to integer entity id.
            Each selected ``t`` value is normalised with
            ``str(int(float(value)))`` before the lookup, so ``12``, ``12.0``
            and ``"12.0"`` all resolve to the key ``"12"``.

    Returns:
        A ``torch.long`` tensor of shape ``(num_unique_ids, 1)`` holding the
        mapped ids in order of first appearance in ``data``.

    Raises:
        KeyError: If a normalised ``t`` value is missing from ``entity2id``.
    """
    mask = (data["datatype_h"] == "image") & (data["datatype_t"] == "id")
    tails = list(data.loc[mask, "t"])

    # A dict keeps insertion order and gives O(1) membership, so duplicates
    # are dropped without rescanning a list for every row.
    unique_ids = {}
    for raw in tqdm(tails):
        unique_ids.setdefault(entity2id[str(int(float(raw)))], None)

    return torch.tensor(list(unique_ids), dtype=torch.long).unsqueeze(-1)
# Load the lookup tables, build the target list and restore the trained model.
# The JSON files are opened in context managers so the handles are closed
# promptly, and read as UTF-8 (the encoding the JSON spec mandates) rather
# than whatever the platform default happens to be.
with open('entity2id_subtree.json', 'r', encoding='utf-8') as _fh:
    entity2id = json.load(_fh)
# Reverse lookup: integer entity id -> original entity key.
id2entity = {v: k for k, v in entity2id.items()}
datacsv = pd.read_csv('dataset_subtree.csv', low_memory=False)
num_ent_id = len(entity2id)
# Entity ids the classifier can predict; built by generate_target_list above.
target_list = generate_target_list(datacsv, entity2id)
with open('overall_id_to_name.json', encoding='utf-8') as _fh:
    overall_id_to_name = json.load(_fh)

# Build the model on CPU and switch it to evaluation mode.
model = DistMult(num_ent_id, target_list, torch.device('cpu'))
model.eval()
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source; consider weights_only=True if the installed PyTorch version
# supports it and the checkpoint contains plain tensors.
ckpt = torch.load('species_class_model.pt', map_location=torch.device('cpu'))
# NOTE(review): strict=False silently ignores missing/unexpected keys; check
# the value returned by load_state_dict if predictions look wrong.
model.load_state_dict(ckpt['model'], strict=False)
print('ckpt loaded...')
def evaluate(img):
    """Classify one image and return its top predictions with probabilities.

    Args:
        img: Image in a form accepted by ``transforms.ToPILImage``
            (presumably the array Gradio's image input supplies; confirm
            against the installed Gradio version).

    Returns:
        dict mapping a display name from ``overall_id_to_name`` to its softmax
        probability, for the highest-scoring classes (at most five). If two
        predicted ids share a display name, the later one overwrites the
        earlier entry.
    """
    transform_steps = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((448, 448)),
        transforms.ToTensor(),
        transforms.Normalize(
            _DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN,
            _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD,
        ),
    ])

    # Add a leading batch dimension of size 1 to both inputs.
    h = transform_steps(img).unsqueeze(0)
    # NOTE(review): relation index 3 is hard-coded; presumably it is the
    # image -> id relation the checkpoint was trained with. TODO confirm.
    r = torch.tensor([3]).unsqueeze(0)

    # Inference only: no_grad avoids building an autograd graph per request.
    with torch.no_grad():
        logits = model.forward_ce(h, r, triple_type=('image', 'id'))
        outputs = F.softmax(logits, dim=-1)

    # topk raises if k exceeds the number of classes, so clamp it.
    k = min(5, outputs.size(-1))
    predictions = torch.topk(outputs, k=k, dim=-1).indices.squeeze(0).tolist()

    result = {}
    for i in predictions:
        # Assumes the model's output columns are aligned with target_list rows
        # (the model was constructed with target_list). TODO confirm.
        pred_label = target_list[i].item()
        label = overall_id_to_name[str(id2entity[pred_label])]
        result[label] = outputs[0, i].item()
    return result
# Gradio interface: route an uploaded image through `evaluate` and show the
# returned {name: probability} mapping in a label component.
# NOTE(review): the input is requested with shape=(200, 200) while evaluate()
# resizes to 448x448; confirm the smaller input size is intended.
_image_input = gr.inputs.Image(shape=(200, 200))
species_model = gr.Interface(
    fn=evaluate,
    inputs=_image_input,
    outputs="label",
    title='Camera Trap Species Classification demo',
)
species_model.launch(server_port=8977, share=True, debug=True)