Spaces:
Build error
Build error
import gradio as gr | |
import pandas as pd | |
import torch | |
import torch.nn.functional as F | |
from detect import detect | |
from huggingface_hub import hf_hub_download | |
from torchvision.transforms import Compose, Normalize, Resize, ToTensor | |
from transformers.models.auto.modeling_auto import \ | |
AutoModelForImageClassification | |
# Preprocessing: resize to 224x224, convert to a float tensor in [0, 1], then
# map to [-1, 1] via Normalize(mean=0.5, std=0.5). Built once at import time
# instead of on every request.
_NORM_MEAN = [0.5, 0.5, 0.5]
_NORM_STD = [0.5, 0.5, 0.5]
_TRANSFORMS = Compose(
    [
        Resize((224, 224)),
        ToTensor(),
        Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)


def run(image, auto_crop):
    """Classify a character image and return the image the model actually saw.

    Args:
        image: Input image; the Gradio input component requests type="pil".
        auto_crop: If true, the image is first passed through ``detect``
            (presumably a face-detection crop; its return is assumed to be an
            image accepted by the torchvision transforms — TODO confirm).

    Returns:
        ``(confidences, display_image)`` where ``confidences`` maps each label
        name to its softmax probability, and ``display_image`` is the resized
        input as a (224, 224, C) float array with values in [0, 1].
    """
    if auto_crop:
        image = detect(image)

    pixel_values = _TRANSFORMS(image).unsqueeze(0)  # add a batch dimension

    # Inference only: avoid building an autograd graph for every request.
    with torch.no_grad():
        logits = model(pixel_values=pixel_values).logits[0]
    prediction = F.softmax(logits, dim=0)
    confidences = dict(zip(labels, prediction.tolist()))

    # Undo Normalize exactly (x * std + mean) to recover the [0, 1] image for
    # display. The previous clamp to the tensor's own min/max was a no-op, and
    # the min-max stretch that followed distorted the image's contrast.
    mean = torch.tensor(_NORM_MEAN).view(-1, 1, 1)
    std = torch.tensor(_NORM_STD).view(-1, 1, 1)
    display = (pixel_values[0] * std + mean).clamp(0.0, 1.0)
    return confidences, display.permute(1, 2, 0).numpy()
# Load model
# Download the fine-tuned training checkpoint; its "state_dict" entry holds the
# network weights.
ckpt_path = hf_hub_download(
    "bwconrad/beit-base-patch16-224-pt22k-ft22k-dafre",
    "beit-base-patch16-224-pt22k-ft22k-dafre.ckpt",
)
# NOTE(review): torch.load unpickles the file and can execute arbitrary code
# embedded in it. Only load checkpoints from a trusted repository (or switch to
# weights_only=True / safetensors if the checkpoint format allows it).
ckpt = torch.load(ckpt_path, map_location=torch.device("cpu"))["state_dict"]
# Build the BEiT architecture with a 3263-way classification head. The
# pretrained head has a different size, hence ignore_mismatched_sizes; all
# weights are replaced by the checkpoint below.
model = AutoModelForImageClassification.from_pretrained(
    "microsoft/beit-base-patch16-224-pt22k-ft22k",
    num_labels=3263,
    ignore_mismatched_sizes=True,
    image_size=224,
)
# Checkpoint keys carry a leading "net." prefix (presumably the attribute name of
# the network in the training wrapper). Strip only that leading prefix: a plain
# str.replace would also remove "net." occurring anywhere else in a key.
_CKPT_PREFIX = "net."
new_state_dict = {
    (key[len(_CKPT_PREFIX):] if key.startswith(_CKPT_PREFIX) else key): value
    for key, value in ckpt.items()
}
model.load_state_dict(new_state_dict, strict=True)
model.eval()  # inference only: make sure dropout etc. are disabled
# Load label names: the CSV has no header row and one "id,name" row per class.
# Row order is assumed to match the classifier's output indices.
labels = pd.read_csv("classid_classname.csv", names=["id", "name"])["name"].tolist()
# Make names human readable: underscores -> spaces, then title case.
labels = [name.replace("_", " ").title() for name in labels]
# Fail fast at startup if the label file and the classifier head disagree,
# rather than mislabeling or dropping classes at request time.
if len(labels) != model.config.num_labels:
    raise ValueError(
        f"classid_classname.csv has {len(labels)} labels but the model has "
        f"{model.config.num_labels} output classes"
    )
# Run app
description = """
A character classification model trained on the DAF:re dataset which consists of 3263 characters from anime, manga and video game series.
A list of all characters can be found [here](https://github.com/bwconrad/dafre/blob/main/app/classid_classname.csv).
Model training code can be found [here](https://github.com/bwconrad/dafre).
The model is trained and performs best on head and shoulder portrait images.
Users can manually crop images through the UI or check the `auto_crop` box to let a face detection model do the cropping.
"""
# Web UI: a PIL image plus an auto-crop checkbox in; top-5 label confidences
# and the preprocessed image (as returned by `run`) out.
app = gr.Interface(
    title="Anime Character Classification",
    description=description,
    fn=run,
    # NOTE(review): `tool="select"` (manual cropping) and `.style()` are Gradio
    # 3.x APIs that, to my knowledge, were removed in Gradio 4 — pin gradio<4 or
    # migrate these components before upgrading.
    inputs=[gr.Image(type="pil", tool="select"), gr.Checkbox(label="auto_crop")],
    outputs=[gr.Label(num_top_classes=5), gr.Image().style(height=224, width=224)],
    allow_flagging="never",
    examples=[
        ["rei.jpg", False],
        ["futaba.jpg", False],
        ["yotsuba.jpg", True],
    ],
)

# Only start the server when run as a script, so the module can be imported
# (e.g. to reuse `run`) without launching the UI.
if __name__ == "__main__":
    app.launch()