Ahsen Khaliq committed on
Commit 37f7a9f
1 Parent(s): 910b4e8

Create app.py

Files changed (1)
  1. app.py +76 -0
app.py ADDED
import json
import os

import gradio as gr
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image

# Device on which to run the model
# Set to "cuda" to load on GPU
device = "cpu"

# Pick a pretrained model
model_name = "omnivore_swinB"
model = torch.hub.load("facebookresearch/omnivore:main", model=model_name)

# Set to eval mode and move to desired device
model = model.to(device)
model = model.eval()

# Download the ImageNet class-index file and build an id-to-label mapping
os.system("wget https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json")

with open("imagenet_class_index.json", "r") as f:
    imagenet_classnames = json.load(f)

# Create an id to label name mapping
imagenet_id_to_classname = {}
for k, v in imagenet_classnames.items():
    imagenet_id_to_classname[k] = v[1]

# Download an example image for the demo
os.system("wget -O library.jpg https://upload.wikimedia.org/wikipedia/commons/thumb/c/c5/13-11-02-olb-by-RalfR-03.jpg/800px-13-11-02-olb-by-RalfR-03.jpg")


def inference(img):
    # The Gradio input is a filepath, so load it as an RGB PIL image first
    image = Image.open(img).convert("RGB")
    image_transform = T.Compose(
        [
            T.Resize(224),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    image = image_transform(image)

    # The model expects inputs of shape: B x C x T x H x W
    image = image[None, :, None, ...].to(device)

    with torch.no_grad():
        prediction = model(image, input_type="image")
        prediction = F.softmax(prediction, dim=1)
    pred_classes = prediction.topk(k=5).indices

    pred_class_names = [imagenet_id_to_classname[str(i.item())] for i in pred_classes[0]]
    return "Top 5 predicted labels: %s" % ", ".join(pred_class_names)


inputs = gr.inputs.Image(type='filepath')
outputs = gr.outputs.Textbox(label="Output")

title = "Omnivore"

description = "Gradio demo for Revisiting Weakly Supervised Pre-Training of Visual Perception Models. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."

article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2201.08371' target='_blank'>Revisiting Weakly Supervised Pre-Training of Visual Perception Models</a> | <a href='https://github.com/facebookresearch/SWAG' target='_blank'>Github Repo</a></p>"


gr.Interface(inference, inputs, outputs, title=title, description=description, article=article, examples=[['library.jpg']]).launch(enable_queue=True, cache_examples=True)
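For a quick sanity check outside the Gradio UI, the inference function can also be called directly on the example image downloaded above. This is a minimal sketch, not part of the committed file: it assumes the wget calls succeeded, that library.jpg is in the working directory, and that it is run before the blocking launch() call (or in a session where the definitions above are already loaded).

# Sanity check: classify the downloaded example image directly.
# Assumes library.jpg exists and that model / inference() are defined as in app.py above.
result = inference("library.jpg")
print(result)  # Prints a line of the form: "Top 5 predicted labels: <label1>, ..., <label5>"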