File size: 2,283 Bytes
ca1f33d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
import torchvision.transforms as transforms
import torchvision.models as models
from joblib import load
from PIL import Image
import gradio as gr
import matplotlib.pyplot as plt
import io

# All inference in this app runs on the CPU.
device = torch.device("cpu")

# Per-channel normalization statistics (the ImageNet mean/std values that
# torchvision documents for its pretrained models).
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

# Preprocessing pipeline: resize the shorter side to 224, take the central
# 224x224 crop, convert to a float tensor, then normalize each channel.
data_transforms = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(_NORMALIZE_MEAN, _NORMALIZE_STD),
    ]
)

# Load the Isolation Forest anomaly detector.
# NOTE(security): joblib.load unpickles the file; only load trusted model files.
clf = load('Models/Anomaly_MSI_MSS_Isolation_Forest_model.joblib')

# Load the feature extractor: a ResNet-50 whose final fully connected layer is
# replaced by a parameter-free identity, so the forward pass returns the
# pooled backbone features instead of class logits. The checkpoint is a
# state_dict saved from the same architecture.
feature_extractor_path = 'Models/feature_extractor.pth'
feature_extractor = models.resnet50(weights=None)
# torch.nn is referenced through the already-imported `torch` package
# (a bare `nn` name is never imported in this file).
feature_extractor.fc = torch.nn.Identity()
feature_extractor.load_state_dict(torch.load(feature_extractor_path, map_location=device))
feature_extractor.to(device)
feature_extractor.eval()

# Load the gastric classification model. This file holds a whole pickled
# module (not a state_dict), since the result is used directly as the model.
# NOTE(review): torch>=2.6 defaults torch.load(..., weights_only=True), which
# rejects fully pickled modules; on such versions pass weights_only=False
# here, and only for a trusted file, because it unpickles arbitrary objects.
GASTRIC_MODEL_PATH = 'Gastric_Models/the_resnet_50_model.pth'
model_ft = torch.load(GASTRIC_MODEL_PATH, map_location=device)
model_ft.to(device)
model_ft.eval()

def _load_rgb_image(uploaded_image):
    """Open an uploaded image and return it as an RGB PIL image.

    Accepts a filesystem path (what ``gr.File(type="filepath")`` passes), a
    binary file-like object exposing ``read()``, or an already-decoded PIL
    image.
    """
    if isinstance(uploaded_image, Image.Image):
        return uploaded_image.convert('RGB')
    if hasattr(uploaded_image, 'read'):
        source = io.BytesIO(uploaded_image.read())
    else:
        source = uploaded_image
    # convert() returns a fully loaded copy, so the file can be closed here.
    with Image.open(source) as image:
        return image.convert('RGB')


def is_anomaly(clf, feature_extractor, input_image):
    """Return True when ``clf`` flags the image's extracted features as an outlier.

    Args:
        clf: fitted scikit-learn outlier detector exposing ``predict``.
        feature_extractor: torch module mapping ``input_image`` to a 2-D
            (batch, features) tensor.
        input_image: preprocessed image batch containing a single image.

    Returns:
        True if the first (only) sample is predicted as an outlier.
    """
    with torch.no_grad():
        features = feature_extractor(input_image)
    # NOTE(review): assumes clf was fitted on features produced by this same
    # extractor and preprocessing; confirm against the training script.
    predictions = clf.predict(features.cpu().numpy())
    # scikit-learn outlier detectors (IsolationForest) return -1 for outliers
    # and +1 for inliers.
    return bool(predictions[0] == -1)


# Anomaly detection and classification function
def classify_image(uploaded_image):
    """Screen an uploaded image for anomalies, then classify it.

    Args:
        uploaded_image: the Gradio upload: a file path, a binary file-like
            object, or a PIL image. ``None`` when nothing was uploaded.

    Returns:
        A ``(message, None)`` tuple: the result text and an empty value for
        the interface's second (plot) output.
    """
    if uploaded_image is None:
        return "No image uploaded.", None

    image = _load_rgb_image(uploaded_image)
    input_image = data_transforms(image).unsqueeze(0).to(device)

    # Anomaly detection: refuse to classify images the detector flags.
    if is_anomaly(clf, feature_extractor, input_image):
        return "Anomaly detected. Image will not be classified.", None

    # Classification
    with torch.no_grad():
        outputs = model_ft(input_image)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
    # Softmax is monotonic, so the argmax over probabilities equals the argmax
    # over the raw logits.
    confidence, predicted = probabilities.max(dim=1)

    # NOTE(review): label order assumed to match the training label encoding.
    class_names = ['abnormal', 'normal']
    predicted_class_name = class_names[predicted.item()]
    predicted_probability = confidence.item() * 100

    return f"Class: {predicted_class_name}, Probability: {predicted_probability:.2f}%", None

# Create Gradio interface.
iface = gr.Interface(
    fn=classify_image,
    # type="filepath" makes Gradio pass the uploaded file's path (a str) to
    # classify_image; file_types limits the picker to image files.
    inputs=gr.File(type="filepath", file_types=["image"]),
    # classify_image returns (text, None): the result message plus an empty
    # placeholder for a plot.
    outputs=[gr.Textbox(), gr.Plot()],
    title="Gastric Image Classifier",
    description="Upload a gastric image to classify it as normal or abnormal."
)

# Run the Gradio app (starts the server when this module is executed/imported).
iface.launch()