more layers
Browse files
app.py
CHANGED
@@ -16,9 +16,21 @@ def predict(img):
|
|
16 |
topk_values, topk_indices = torch.topk(output, 2) # Get the top 2 classes
|
17 |
return [str(k) for k in topk_indices[0].tolist()]
|
18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
inputs=sp,
|
24 |
-
outputs=['label','label']).launch()
|
|
|
16 |
topk_values, topk_indices = torch.topk(output, 2) # Get the top 2 classes
|
17 |
return [str(k) for k in topk_indices[0].tolist()]
|
18 |
|
19 |
+
with gr.Blocks() as iface:
|
20 |
+
gr.Markdown("# MNIST + Gradio End to End")
|
21 |
+
gr.HTML("Shows end to end MNIST training with Gradio interface")
|
22 |
+
with gr.Row():
|
23 |
+
with gr.Column():
|
24 |
+
sp = gr.Sketchpad(shape=(28, 28))
|
25 |
+
with gr.Row():
|
26 |
+
with gr.Column():
|
27 |
+
pred_button = gr.Button("Predict")
|
28 |
+
with gr.Column():
|
29 |
+
clear = gr.Button("Clear")
|
30 |
+
with gr.Column():
|
31 |
+
label1 = gr.Label(label='1st Pred')
|
32 |
+
label2 = gr.Label(label='2nd Pred')
|
33 |
|
34 |
+
pred_button.click(predict, inputs=sp, outputs=[label1,label2])
|
35 |
+
clear.click(lambda: None, None, sp, queue=False)
|
36 |
+
iface.launch()
|
|
|
|
mnist.pth
CHANGED
Binary files a/mnist.pth and b/mnist.pth differ
|
|
model.py
CHANGED
@@ -5,11 +5,13 @@ class Net(nn.Module):
|
|
5 |
def __init__(self):
|
6 |
super(Net, self).__init__()
|
7 |
self.fc1 = nn.Linear(28*28, 128) # MNIST images are 28x28
|
8 |
-
self.fc2 = nn.Linear(128,
|
9 |
-
self.fc3 = nn.Linear(
|
|
|
10 |
|
11 |
def forward(self, x):
|
12 |
x = x.view(x.shape[0], -1) # Flatten the input
|
13 |
x = torch.relu(self.fc1(x))
|
14 |
x = torch.relu(self.fc2(x))
|
15 |
-
|
|
|
|
5 |
def __init__(self):
|
6 |
super(Net, self).__init__()
|
7 |
self.fc1 = nn.Linear(28*28, 128) # MNIST images are 28x28
|
8 |
+
self.fc2 = nn.Linear(128, 128)
|
9 |
+
self.fc3 = nn.Linear(128, 64)
|
10 |
+
self.fc4 = nn.Linear(64, 10) # There are 10 classes (0 through 9)
|
11 |
|
12 |
def forward(self, x):
|
13 |
x = x.view(x.shape[0], -1) # Flatten the input
|
14 |
x = torch.relu(self.fc1(x))
|
15 |
x = torch.relu(self.fc2(x))
|
16 |
+
x = torch.relu(self.fc3(x))
|
17 |
+
return self.fc4(x)
|
requirements.txt
CHANGED
@@ -1,5 +1,3 @@
|
|
1 |
-
gradio==3.29.0
|
2 |
-
numpy==1.23.5
|
3 |
Pillow==9.1.0
|
4 |
torch==2.0.1
|
5 |
torchvision==0.15.2
|
|
|
|
|
|
|
1 |
Pillow==9.1.0
|
2 |
torch==2.0.1
|
3 |
torchvision==0.15.2
|