# Gradio app: CIFAR-10 image classification with GradCAM and misclassified-image views.
import gradio as gr
import torch
from Utilities.model import Net
from Utilities import config
from Utilities.utils import generate_confidences, generate_gradcam, generate_missclassified_imgs
# Input components, passed positionally to `generate_gradio_output`.
inputs = [
    gr.Image(shape=(32, 32), label="Input Image"),
    gr.Slider(minimum=1, maximum=10, step=1, label="Number of Top Predictions to Display"),
    # gradio 3.x checkboxes take `value=`, not the legacy `default=` kwarg
    # (consistent with the Sliders below, which already use `value=`).
    gr.Checkbox(value=False, label="Show GradCAM"),
    gr.Slider(minimum=-2, maximum=-1, step=1, value=-1, label="GradCAM Layer (from the end)"),
    gr.Slider(minimum=0, maximum=1, value=0.5, label="GradCAM Heatmap Opacity"),
    gr.Checkbox(label="Show Incorrect Predictions"),
    gr.Slider(minimum=5, maximum=50, step=5, label="Number of Incorrect Predictions to Display"),
]
# Build the classifier with hyperparameters taken from the project config.
model = Net(
    num_classes=config.NUM_CLASSES,
    dropout_percentage=config.DROPOUT_PERCENTAGE,
    norm=config.LAYER_NORM,
    criterion=config.CRITERION,
    learning_rate=config.LEARNING_RATE,
    weight_decay=config.WEIGHT_DECAY,
)

# Restore trained weights onto the configured device (config.ACCELERATOR).
device = torch.device(config.ACCELERATOR)
state_dict = torch.load(config.MODEL_SAVE_PATH, map_location=device)
model.load_state_dict(state_dict)

# Attach the saved prediction store (used by the misclassified-images plot;
# exact contents live in Utilities — presumably cached test-set predictions).
model.pred_store = torch.load(config.PRED_STORE_PATH, map_location=device)
def generate_gradio_output(
    input_img,
    num_top_preds,
    show_gradcam,
    gradcam_layer,
    gradcam_opacity,
    show_misclassified,
    num_misclassified,
):
    """Run one inference pass and produce all three Gradio output panels.

    Parameters mirror the `inputs` component list, in order. Returns a
    tuple of (class confidences, GradCAM visualization, misclassified-
    images plot), matching the `outputs` component list.
    """
    # Top-k confidences, plus the preprocessed image reused for GradCAM.
    processed_img, confidences = generate_confidences(
        model=model,
        input_img=input_img,
        num_top_preds=num_top_preds,
    )

    # GradCAM heatmap blended over the original image at the given opacity.
    gradcam_visualization = generate_gradcam(
        model=model,
        org_img=input_img,
        input_img=processed_img,
        show_gradcam=show_gradcam,
        gradcam_layer=gradcam_layer,
        gradcam_opacity=gradcam_opacity,
    )

    # Plot of images the model got wrong (drawn from model.pred_store).
    misclassified_plot = generate_missclassified_imgs(
        model=model,
        show_misclassified=show_misclassified,
        num_misclassified=num_misclassified,
    )

    return confidences, gradcam_visualization, misclassified_plot
# Output components, in the order returned by `generate_gradio_output`.
# Note: gradio 3.x `.style()` returns the component itself, so the list
# below is equivalent to chaining the call inline.
gradcam_output = gr.Image(shape=(32, 32), label="GradCAM Visualization")
gradcam_output.style(width=256, height=256, visible=True)
outputs = [
    gr.Label(visible=True, scale=0.5, label="Classification Confidences"),
    gradcam_output,
    gr.Plot(visible=True, label="Misclassified Images"),
]
# Example rows: (filename, top-k, show GradCAM, layer, opacity, show misses, n misses).
_example_rows = [
    ("cat.jpeg", 3, True, -2, 0.68, True, 40),
    ("horse.jpg", 3, True, -2, 0.59, True, 25),
    ("bird.webp", 10, True, -1, 0.55, True, 20),
    ("dog1.jpg", 10, True, -1, 0.33, True, 45),
    ("frog1.webp", 5, True, -1, 0.64, True, 40),
    ("deer.webp", 1, True, -2, 0.45, True, 20),
    ("airplane.png", 3, True, -2, 0.43, True, 40),
    ("shipp.jpg", 7, True, -1, 0.6, True, 30),
    ("car.jpg", 2, True, -1, 0.68, True, 30),
    ("truck1.jpg", 5, True, -2, 0.51, True, 35),
]
# Prefix each filename with the configured example-image directory.
examples = [[config.EXAMPLE_IMG_PATH + fname, *params] for fname, *params in _example_rows]
# UI copy shown at the top of the Gradio interface.
title = "Image Classification (CIFAR10 - 10 Classes) with GradCAM"
# Fixed user-facing typo: "archiecture" -> "architecture".
description = """A simple Gradio interface to visualize the output of a CNN trained on CIFAR10 dataset with GradCAM and Misclassified images.
The architecture is inspired from David Page's (myrtle.ai) DAWNBench winning model architecture.
Please input the image and select the number of top predictions to display - you will see the top predictions and their corresponding confidence scores.
You can also select whether to show GradCAM for the particular image (utilizes the gradients of the classification score with respect to the final convolutional feature map, to identify the parts of an input image that most impact the classification score).
You need to select the model layer where the gradients need to be plugged from - this affects how much of the image is used to compute the GradCAM.
You can also select whether to show misclassified images - these are the images that the model misclassified.
Some examples are provided in the examples tab.
"""
# Assemble the interface from the pieces defined above and serve it.
demo = gr.Interface(
    fn=generate_gradio_output,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    examples=examples,
)
demo.launch()