Junlinh's picture
Update app.py
60efb1b
raw
history blame contribute delete
No virus
1.64 kB
import functools

import gradio as gr
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from timm.models import create_model
# Checkpoint files for the two regressors, resolved relative to the current
# working directory exactly as in the original per-call loading code.
_MACHINE_CKPT = "./machine_full_best.tar"
_HUMAN_CKPT = "./human_full_best.tar"

# Preprocessing pipeline, unchanged from the original per-call construction.
# NOTE(review): RandomResizedCrop is a random transform even with
# scale=(1, 1) (the aspect ratio is still sampled). Confirm whether a
# deterministic resize was intended before changing it, since the checkpoints
# may have been trained with this exact pipeline.
_TRANSFORM = transforms.Compose([
    transforms.RandomResizedCrop(224, (1, 1)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


@functools.lru_cache(maxsize=None)  # only two fixed paths are ever used
def _load_model(checkpoint_path):
    """Build a single-output ResNet-50 and load weights from *checkpoint_path*.

    The model is cached per path, so each checkpoint is read from disk once
    per process instead of on every request.
    """
    model = create_model(
        'resnet50',
        drop_rate=0.5,
        num_classes=1,
    )
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['state_dict'])
    # Inference mode: disables the drop_rate=0.5 dropout configured above.
    model.eval()
    return model


def predict(input_img):
    """Score an image with the MachineMem and HumanMem models.

    Args:
        input_img: Image array as delivered by ``gr.Image()``. It is cast to
            uint8 and converted to 3-channel RGB before preprocessing.

    Returns:
        ``'MachineMem score = <m>, HumanMem score = <h>.'``, where each score
        is the model's single output rounded to 3 decimals. If no image was
        supplied, a short prompt to upload one is returned instead.
    """
    if input_img is None:
        return 'Please upload an image.'

    # Force RGB so grayscale (2-D) or RGBA (4-channel) arrays match the
    # 3-channel Normalize above.
    image = Image.fromarray(np.uint8(input_img)).convert('RGB')
    batch = _TRANSFORM(image).unsqueeze(0)  # add the batch dimension

    machine_model = _load_model(_MACHINE_CKPT)
    human_model = _load_model(_HUMAN_CKPT)

    # Pure inference: skip autograd graph construction.
    with torch.no_grad():
        machine_score = round(machine_model(batch).item(), 3)
        human_score = round(human_model(batch).item(), 3)

    return f'MachineMem score = {machine_score}, HumanMem score = {human_score}.'
# Numeric stems of the sample image files offered as clickable examples.
# The numbering has gaps; only these files are listed, in this order.
_EXAMPLE_IDS = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 18, 19, 20,
                21, 22, 24, 25, 26, 27, 28, 30, 32, 35, 36, 37)
_EXAMPLE_FILES = [f"{idx}.jpg" for idx in _EXAMPLE_IDS]

# Image in, text out: `predict` maps the uploaded image to a score string.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(),
    outputs="text",
    examples=_EXAMPLE_FILES,
)
demo.launch(debug=True)