sauc-abadal-lloret commited on
Commit
b2babb6
1 Parent(s): 95416b8

First commit

Browse files
Files changed (4) hide show
  1. app.py +50 -0
  2. class_names.txt +100 -0
  3. pytorch_model.bin +3 -0
  4. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import torch
4
+ import gradio as gr
5
+ from torch import nn
6
+
7
+
8
+ LABELS = Path('class_names.txt').read_text().splitlines()
9
+
10
+ model = nn.Sequential(
11
+ nn.Conv2d(1, 32, 3, padding='same'),
12
+ nn.ReLU(),
13
+ nn.MaxPool2d(2),
14
+ nn.Conv2d(32, 64, 3, padding='same'),
15
+ nn.ReLU(),
16
+ nn.MaxPool2d(2),
17
+ nn.Conv2d(64, 128, 3, padding='same'),
18
+ nn.ReLU(),
19
+ nn.MaxPool2d(2),
20
+ nn.Flatten(),
21
+ nn.Linear(1152, 256),
22
+ nn.ReLU(),
23
+ nn.Linear(256, len(LABELS)),
24
+ )
25
+ state_dict = torch.load('pytorch_model.bin', map_location='cpu')
26
+ model.load_state_dict(state_dict, strict=False)
27
+ model.eval()
28
+
29
+ def predict(im):
30
+ x = torch.tensor(im, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.
31
+
32
+ with torch.no_grad():
33
+ out = model(x)
34
+
35
+ probabilities = torch.nn.functional.softmax(out[0], dim=0)
36
+
37
+ values, indices = torch.topk(probabilities, 5)
38
+
39
+ return {LABELS[i]: v.item() for i, v in zip(indices, values)}
40
+
41
+ interface = gr.Interface(
42
+ predict,
43
+ inputs="sketchpad",
44
+ outputs='label',
45
+ theme="huggingface",
46
+ title="Sketch Recognition",
47
+ description="Who wants to play Pictionary? Draw a common object like a shovel or a laptop, and the algorithm will guess in real time!",
48
+ article = "<p style='text-align: center'>Sketch Recognition | Demo Model</p>",
49
+ live=True)
50
+ interface.launch(debug=True)
class_names.txt ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ airplane
2
+ alarm_clock
3
+ anvil
4
+ apple
5
+ axe
6
+ baseball
7
+ baseball_bat
8
+ basketball
9
+ beard
10
+ bed
11
+ bench
12
+ bicycle
13
+ bird
14
+ book
15
+ bread
16
+ bridge
17
+ broom
18
+ butterfly
19
+ camera
20
+ candle
21
+ car
22
+ cat
23
+ ceiling_fan
24
+ cell_phone
25
+ chair
26
+ circle
27
+ clock
28
+ cloud
29
+ coffee_cup
30
+ cookie
31
+ cup
32
+ diving_board
33
+ donut
34
+ door
35
+ drums
36
+ dumbbell
37
+ envelope
38
+ eye
39
+ eyeglasses
40
+ face
41
+ fan
42
+ flower
43
+ frying_pan
44
+ grapes
45
+ hammer
46
+ hat
47
+ headphones
48
+ helmet
49
+ hot_dog
50
+ ice_cream
51
+ key
52
+ knife
53
+ ladder
54
+ laptop
55
+ light_bulb
56
+ lightning
57
+ line
58
+ lollipop
59
+ microphone
60
+ moon
61
+ mountain
62
+ moustache
63
+ mushroom
64
+ pants
65
+ paper_clip
66
+ pencil
67
+ pillow
68
+ pizza
69
+ power_outlet
70
+ radio
71
+ rainbow
72
+ rifle
73
+ saw
74
+ scissors
75
+ screwdriver
76
+ shorts
77
+ shovel
78
+ smiley_face
79
+ snake
80
+ sock
81
+ spider
82
+ spoon
83
+ square
84
+ star
85
+ stop_sign
86
+ suitcase
87
+ sun
88
+ sword
89
+ syringe
90
+ t-shirt
91
+ table
92
+ tennis_racquet
93
+ tent
94
+ tooth
95
+ traffic_light
96
+ tree
97
+ triangle
98
+ umbrella
99
+ wheel
100
+ wristwatch
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:effb6ea6f1593c09e8247944028ed9c309b5ff1cef82ba38b822bee2ca4d0f3c
3
+ size 1656903
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ torch
2
+ gradio