unixpickle commited on
Commit
120e140
1 Parent(s): ddd296c
Files changed (1) hide show
  1. app.py +16 -4
app.py CHANGED
@@ -4,7 +4,7 @@ import torch.nn.functional as F
4
  import torchvision.transforms as transforms
5
  from PIL import Image
6
 
7
- from constants import MAKES_MODELS, PRICE_BIN_LABELS
8
 
9
  model = torch.jit.load("mobilenetv2_432000_calib.pt")
10
  model.eval()
@@ -26,13 +26,24 @@ def classify(img: Image.Image):
26
  price_bins = dict(
27
  zip(PRICE_BIN_LABELS, F.softmax(outputs["price_bin"], dim=-1)[0].tolist())
28
  )
 
 
 
 
 
 
29
  make_models = dict(
30
  zip(
31
  ([f"{make} {model}" for make, model in MAKES_MODELS] + ["Unknown"]),
32
  F.softmax(outputs["make_model"], dim=-1)[0].tolist(),
33
  )
34
  )
35
- return f"${int(round(outputs['price_median'].item()))}", price_bins, make_models
 
 
 
 
 
36
 
37
 
38
  iface = gr.Interface(
@@ -40,8 +51,9 @@ iface = gr.Interface(
40
  inputs=gr.Image(shape=(224, 224), type="pil"),
41
  outputs=[
42
  gr.Text(label="Price Prediction"),
43
- gr.Label(label="Price Bin"),
44
- gr.Label(label="Make/Model"),
 
45
  ],
46
  )
47
  iface.launch()
 
4
  import torchvision.transforms as transforms
5
  from PIL import Image
6
 
7
+ from constants import MAKES_MODELS, PRICE_BIN_LABELS, YEARS
8
 
9
  model = torch.jit.load("mobilenetv2_432000_calib.pt")
10
  model.eval()
 
26
  price_bins = dict(
27
  zip(PRICE_BIN_LABELS, F.softmax(outputs["price_bin"], dim=-1)[0].tolist())
28
  )
29
+ years = dict(
30
+ zip(
31
+ [str(year) for year in YEARS] + ["Unknown"],
32
+ F.softmax(outputs["year"], dim=-1)[0].tolist(),
33
+ )
34
+ )
35
  make_models = dict(
36
  zip(
37
  ([f"{make} {model}" for make, model in MAKES_MODELS] + ["Unknown"]),
38
  F.softmax(outputs["make_model"], dim=-1)[0].tolist(),
39
  )
40
  )
41
+ return (
42
+ f"${int(round(outputs['price_median'].item()))}",
43
+ price_bins,
44
+ years,
45
+ make_models,
46
+ )
47
 
48
 
49
  iface = gr.Interface(
 
51
  inputs=gr.Image(shape=(224, 224), type="pil"),
52
  outputs=[
53
  gr.Text(label="Price Prediction"),
54
+ gr.Label(label="Price Bin", num_top_classes=5),
55
+ gr.Label(label="Year", num_top_classes=5),
56
+ gr.Label(label="Make/Model", num_top_classes=10),
57
  ],
58
  )
59
  iface.launch()