CindyBSydney commited on
Commit
ca1f33d
1 Parent(s): 247beb3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -0
app.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision.transforms as transforms
3
+ import torchvision.models as models
4
+ from joblib import load
5
+ from PIL import Image
6
+ import gradio as gr
7
+ import matplotlib.pyplot as plt
8
+ import io
9
+
10
+ # Transformation and device setup
11
+ device = torch.device("cpu")
12
+ data_transforms = transforms.Compose([
13
+ transforms.Resize(224),
14
+ transforms.CenterCrop(224),
15
+ transforms.ToTensor(),
16
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
17
+ ])
18
+
19
+ # Load the Isolation Forest model
20
+ clf = load('Models/Anomaly_MSI_MSS_Isolation_Forest_model.joblib')
21
+
22
+ # Load feature extractor
23
+ feature_extractor_path = 'Models/feature_extractor.pth'
24
+ feature_extractor = models.resnet50(weights=None)
25
+ feature_extractor.fc = nn.Sequential()
26
+ feature_extractor.load_state_dict(torch.load(feature_extractor_path, map_location=device))
27
+ feature_extractor.to(device)
28
+ feature_extractor.eval()
29
+
30
+ # Load gastric classification model
31
+ GASTRIC_MODEL_PATH = 'Gastric_Models/the_resnet_50_model.pth'
32
+ model_ft = torch.load(GASTRIC_MODEL_PATH, map_location=device)
33
+ model_ft.to(device)
34
+ model_ft.eval()
35
+
36
+ # Anomaly detection and classification function
37
+ def classify_image(uploaded_image):
38
+ image = Image.open(io.BytesIO(uploaded_image.read())).convert('RGB')
39
+ input_image = data_transforms(image).unsqueeze(0).to(device)
40
+
41
+ # Anomaly detection
42
+ if is_anomaly(clf, feature_extractor):
43
+ return "Anomaly detected. Image will not be classified.", None
44
+
45
+ # Classification
46
+ with torch.no_grad():
47
+ outputs = model_ft(input_image)
48
+ probabilities = torch.nn.functional.softmax(outputs, dim=1)
49
+ _, predicted = torch.max(outputs, 1)
50
+
51
+ predicted_class_index = predicted.item()
52
+ class_names = ['abnormal', 'normal']
53
+ predicted_class_name = class_names[predicted_class_index]
54
+ predicted_probability = probabilities[0][predicted_class_index].item() * 100
55
+
56
+ return f"Class: {predicted_class_name}, Probability: {predicted_probability:.2f}%", None
57
+
58
+ # Create Gradio interface
59
+ iface = gr.Interface(
60
+ fn=classify_image,
61
+ inputs=File(type="filepath"),
62
+ [gr.outputs.Text(), gr.outputs.Image(plot=True)],
63
+ title="Gastric Image Classifier",
64
+ description="Upload a gastric image to classify it as normal or abnormal."
65
+ )
66
+
67
+ # Run the Gradio app
68
+ iface.launch()