# app.py — added by CindyBSydney in commit ca1f33d ("Create app.py"), 2.28 kB.
import torch
import torchvision.transforms as transforms
import torchvision.models as models
from joblib import load
from PIL import Image
import gradio as gr
import matplotlib.pyplot as plt
import io
# Transformation and device setup.
# Inference is pinned to the CPU; the models and input tensors below are moved here.
device = torch.device("cpu")

# Preprocessing applied to each uploaded image in classify_image:
#   1. scale the shorter side to 224 px,
#   2. centre-crop to 224x224,
#   3. convert to a float tensor in [0, 1],
#   4. normalise each RGB channel with the standard ImageNet mean/std values.
_preprocessing_steps = [
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
data_transforms = transforms.Compose(_preprocessing_steps)
# Load the Isolation Forest anomaly detector persisted with joblib.
# NOTE(review): joblib.load unpickles arbitrary objects — only load trusted model files.
clf = load('Models/Anomaly_MSI_MSS_Isolation_Forest_model.joblib')

# Load feature extractor: a ResNet-50 backbone whose classification head is
# replaced by a parameter-free identity, so the forward pass returns the pooled
# backbone features instead of class logits.
feature_extractor_path = 'Models/feature_extractor.pth'
feature_extractor = models.resnet50(weights=None)
# `nn` is never imported in this file; reference torch.nn explicitly.
# Identity has no parameters, just like the empty Sequential the checkpoint was
# saved with, so the state_dict keys still match.
feature_extractor.fc = torch.nn.Identity()
feature_extractor.load_state_dict(torch.load(feature_extractor_path, map_location=device))
feature_extractor.to(device)
feature_extractor.eval()  # inference mode (fixes batch-norm statistics)

# Load gastric classification model.
# The checkpoint appears to hold a whole pickled nn.Module (it is used directly
# as a model below), not a state_dict.
# NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which
# rejects full pickled modules; on those versions pass weights_only=False
# (trusted file only) or re-save the checkpoint as a state_dict.
GASTRIC_MODEL_PATH = 'Gastric_Models/the_resnet_50_model.pth'
model_ft = torch.load(GASTRIC_MODEL_PATH, map_location=device)
model_ft.to(device)
model_ft.eval()
# Anomaly detection and classification function.

# Output labels, indexed by the classifier's output column.
# NOTE(review): order assumed to match the training-time label encoding — confirm.
_CLASS_NAMES = ('abnormal', 'normal')


def _load_image(uploaded_image):
    """Open an uploaded file as an RGB PIL image.

    Accepts a filesystem path (``str``/``os.PathLike`` — what
    ``gr.File(type="filepath")`` passes), a file object whose ``.name`` is an
    existing path, or any readable binary file-like object.
    """
    import os

    if isinstance(uploaded_image, (str, os.PathLike)):
        source = uploaded_image
    elif isinstance(getattr(uploaded_image, "name", None), str) and os.path.isfile(uploaded_image.name):
        # Reopen by path: a temp-file wrapper's read position may already be at EOF.
        source = uploaded_image.name
    else:
        source = io.BytesIO(uploaded_image.read())
    # convert() returns a fully loaded copy, so the source handle can be closed.
    with Image.open(source) as img:
        return img.convert('RGB')


def _is_anomaly(detector, extractor, input_tensor):
    """Return True when *detector* labels the image's extracted features an outlier.

    Args:
        detector: fitted estimator whose scikit-learn style ``predict`` returns
            -1 for outliers (the IsolationForest convention).
        extractor: module mapping the preprocessed batch to feature vectors.
        input_tensor: preprocessed single-image batch built in ``classify_image``.
    """
    with torch.no_grad():
        features = extractor(input_tensor)
    # Flatten to (n_samples, n_features), the 2-D layout scikit-learn expects.
    # NOTE(review): assumes the detector was fit on features from this same
    # extractor and preprocessing — confirm against the training script.
    features = features.reshape(features.shape[0], -1).cpu().numpy()
    return bool(detector.predict(features)[0] == -1)


def classify_image(uploaded_image):
    """Screen an uploaded image for anomalies, then classify it.

    Args:
        uploaded_image: path string or file-like object from the Gradio File
            input; ``None`` when the user submits without uploading.

    Returns:
        tuple: (result text, None). The second element fills the interface's
        second output slot, which this function does not populate.
    """
    if uploaded_image is None:
        return "Please upload an image.", None

    image = _load_image(uploaded_image)
    input_image = data_transforms(image).unsqueeze(0).to(device)  # add batch dimension

    # Anomaly detection: the detector must be given this image's features,
    # otherwise out-of-distribution uploads would be classified anyway.
    if _is_anomaly(clf, feature_extractor, input_image):
        return "Anomaly detected. Image will not be classified.", None

    # Classification
    with torch.no_grad():
        outputs = model_ft(input_image)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        _, predicted = torch.max(outputs, 1)
    predicted_class_index = predicted.item()
    predicted_class_name = _CLASS_NAMES[predicted_class_index]
    predicted_probability = probabilities[0][predicted_class_index].item() * 100
    return f"Class: {predicted_class_name}, Probability: {predicted_probability:.2f}%", None
# Create Gradio interface: one file-upload input; a text result plus a second
# output slot that classify_image currently always fills with None.
iface = gr.Interface(
    fn=classify_image,
    # `File` was never imported; use the gr namespace. With type="filepath"
    # Gradio passes the upload to classify_image as a path string.
    inputs=gr.File(type="filepath"),
    # Must be the `outputs=` keyword: a positional argument after keyword
    # arguments is a SyntaxError. gr.outputs.* was removed in Gradio 4;
    # gr.Plot replaces the old Image(plot=True) output.
    outputs=[gr.Textbox(), gr.Plot()],
    title="Gastric Image Classifier",
    description="Upload a gastric image to classify it as normal or abnormal.",
)

# Run the Gradio app only when executed as a script, not when imported.
if __name__ == "__main__":
    iface.launch()