jinlinyi commited on
Commit
97b28a1
1 Parent(s): 0e491a3

update to new codebase

Browse files
Files changed (1) hide show
  1. app.py +25 -94
app.py CHANGED
@@ -1,26 +1,26 @@
1
 
2
  import os
3
- os.system(f"pip install -U openmim")
4
- os.system(f"mim install mmcv")
5
- os.system(f"pip install git+https://github.com/jinlinyi/PerspectiveFields.git@dev#egg=perspective2d")
 
6
 
7
 
8
  import gradio as gr
9
  import cv2
10
  import copy
11
- import torch
12
- from PIL import Image, ImageDraw
13
- from glob import glob
14
  import numpy as np
15
  import os.path as osp
16
- from detectron2.config import get_cfg
17
- from detectron2.data.detection_utils import read_image
18
- from perspective2d.utils.predictor import VisualizationDemo
19
- import perspective2d.modeling # noqa
20
- from perspective2d.config import get_perspective2d_cfg_defaults
21
- from perspective2d.utils import draw_from_r_p_f_cx_cy
22
  from datetime import datetime
23
 
 
 
 
 
 
 
 
 
24
 
25
 
26
 
@@ -51,24 +51,6 @@ article = """
51
 
52
 
53
 
54
-
55
- def setup_cfg(args):
56
- cfgs = []
57
- configs = args['config_file'].split('#')
58
- weights_id = args['opts'].index('MODEL.WEIGHTS') + 1
59
- weights = args['opts'][weights_id].split('#')
60
- for i, conf in enumerate(configs):
61
- if len(conf) != 0:
62
- tmp_opts = copy.deepcopy(args['opts'])
63
- tmp_opts[weights_id] = weights[i]
64
- cfg = get_cfg()
65
- get_perspective2d_cfg_defaults(cfg)
66
- cfg.merge_from_file(conf)
67
- cfg.merge_from_list(tmp_opts)
68
- cfg.freeze()
69
- cfgs.append(cfg)
70
- return cfgs
71
-
72
  def resize_fix_aspect_ratio(img, field, target_width=None, target_height=None):
73
  height = img.shape[0]
74
  width = img.shape[1]
@@ -98,37 +80,26 @@ def resize_fix_aspect_ratio(img, field, target_width=None, target_height=None):
98
  return img, field
99
 
100
 
101
- def inference(img, model_type):
102
- img_h = img.shape[0]
103
  if model_type is None:
104
  return None, ""
105
- perspective_cfg_list = setup_cfg(model_zoo[model_type])
106
- demo = VisualizationDemo(cfg_list=perspective_cfg_list)
107
-
108
- # img = read_image(image_path, format="BGR")
109
- img = img[..., ::-1] # rgb->bgr
110
- pred = demo.run_on_image(img)
111
  field = {
112
  'up': pred['pred_gravity_original'].cpu().detach(),
113
  'lati': pred['pred_latitude_original'].cpu().detach(),
114
  }
115
- img, field = resize_fix_aspect_ratio(img, field, 640)
116
  if not model_zoo[model_type]['param']:
117
- pred_vis = demo.draw(
118
- image=img,
119
- latimap=field['lati'],
120
- gravity=field['up'],
121
- latimap_format=pred['pred_latitude_original_mode'],
122
- ).get_image()
123
  param = "Not Implemented"
124
  else:
125
- if 'pred_general_vfov' not in pred.keys():
126
- pred['pred_general_vfov'] = pred['pred_vfov']
127
- if 'pred_rel_cx' not in pred.keys():
128
- pred['pred_rel_cx'] = torch.FloatTensor([0])
129
- if 'pred_rel_cy' not in pred.keys():
130
- pred['pred_rel_cy'] = torch.FloatTensor([0])
131
-
132
  r_p_f_rad = np.radians(
133
  [
134
  pred['pred_roll'].cpu().item(),
@@ -143,14 +114,14 @@ def inference(img, model_type):
143
  param = f"roll {pred['pred_roll'].cpu().item() :.2f}\npitch {pred['pred_pitch'].cpu().item() :.2f}\nvertical fov {pred['pred_general_vfov'].cpu().item() :.2f}\nfocal_length {pred['pred_rel_focal'].cpu().item()*img_h :.2f}\n"
144
  param += f"principal point {pred['pred_rel_cx'].cpu().item() :.2f} {pred['pred_rel_cy'].cpu().item() :.2f}"
145
  pred_vis = draw_from_r_p_f_cx_cy(
146
- img[:,:,::-1],
147
  *r_p_f_rad,
148
  *cx_cy,
149
  'rad',
150
  up_color=(0,1,0),
151
  )
152
  print(f"""time {datetime.now().strftime("%H:%M:%S")}
153
- img.shape {img.shape}
154
  model_type {model_type}
155
  param {param}
156
  """
@@ -163,46 +134,6 @@ for img_name in glob('assets/imgs/*.*g'):
163
  print(examples)
164
 
165
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
166
- model_zoo = {
167
-
168
- 'NEW:Paramnet-360Cities-edina-centered': {
169
- 'weights': ['https://www.dropbox.com/s/z2dja70bgy007su/paramnet_360cities_edina_rpf.pth'],
170
- 'opts': ['MODEL.WEIGHTS', 'models/paramnet_360cities_edina_rpf.pth', 'MODEL.DEVICE', device,],
171
- 'config_file': 'models/paramnet_360cities_edina_rpf.yaml',
172
- 'param': True,
173
- },
174
-
175
- 'NEW:Paramnet-360Cities-edina-uncentered': {
176
- 'weights': ['https://www.dropbox.com/s/nt29e1pi83mm1va/paramnet_360cities_edina_rpfpp.pth'],
177
- 'opts': ['MODEL.WEIGHTS', 'models/paramnet_360cities_edina_rpfpp.pth', 'MODEL.DEVICE', device,],
178
- 'config_file': 'models/paramnet_360cities_edina_rpfpp.yaml',
179
- 'param': True,
180
- },
181
-
182
- 'PersNet-360Cities': {
183
- 'weights': ['https://www.dropbox.com/s/czqrepqe7x70b7y/cvpr2023.pth'],
184
- 'opts': ['MODEL.WEIGHTS', 'models/cvpr2023.pth', 'MODEL.DEVICE', device,],
185
- 'config_file': 'models/cvpr2023.yaml',
186
- 'param': False,
187
- },
188
- 'PersNet_Paramnet-GSV-uncentered': {
189
- 'weights': ['https://www.dropbox.com/s/ufdadxigewakzlz/paramnet_gsv_rpfpp.pth'],
190
- 'opts': ['MODEL.WEIGHTS', 'models/paramnet_gsv_rpfpp.pth', 'MODEL.DEVICE', device,],
191
- 'config_file': 'models/paramnet_gsv_rpfpp.yaml',
192
- 'param': True,
193
- },
194
- # trained on GSV dataset, predicts Perspective Fields + camera parameters (roll, pitch, fov), assuming centered principal point
195
- 'PersNet_Paramnet-GSV-centered': {
196
- 'weights': ['https://www.dropbox.com/s/g6xwbgnkggapyeu/paramnet_gsv_rpf.pth'],
197
- 'opts': ['MODEL.WEIGHTS', 'models/paramnet_gsv_rpf.pth', 'MODEL.DEVICE', device,],
198
- 'config_file': 'models/paramnet_gsv_rpf.yaml',
199
- 'param': True,
200
- },
201
- }
202
- for model_id in model_zoo:
203
- html = model_zoo[model_id]['weights'][0]
204
- if not os.path.exists(os.path.join('models', html.split('/')[-1])):
205
- os.system(f"wget -P models/ {html}")
206
 
207
  info = """Select model\n"""
208
  gr.Interface(
 
1
 
2
  import os
3
+ try:
4
+ import perspective2d
5
+ except:
6
+ os.system(f"pip install git+https://github.com/jinlinyi/PerspectiveFields.git@v1.0.0")
7
 
8
 
9
  import gradio as gr
10
  import cv2
11
  import copy
 
 
 
12
  import numpy as np
13
  import os.path as osp
 
 
 
 
 
 
14
  from datetime import datetime
15
 
16
+ import torch
17
+ from PIL import Image, ImageDraw
18
+ from glob import glob
19
+
20
+ from perspective2d import PerspectiveFields
21
+ from perspective2d.utils import draw_perspective_fields, draw_from_r_p_f_cx_cy
22
+ from perspective2d.perspectivefields import model_zoo
23
+
24
 
25
 
26
 
 
51
 
52
 
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  def resize_fix_aspect_ratio(img, field, target_width=None, target_height=None):
55
  height = img.shape[0]
56
  width = img.shape[1]
 
80
  return img, field
81
 
82
 
83
+ def inference(img_rgb, model_type):
 
84
  if model_type is None:
85
  return None, ""
86
+ pf_model = PerspectiveFields(model_type).eval().to(device)
87
+ pred = pf_model.inference(img_bgr=img_rgb[...,::-1])
88
+ img_h = img_rgb.shape[0]
 
 
 
89
  field = {
90
  'up': pred['pred_gravity_original'].cpu().detach(),
91
  'lati': pred['pred_latitude_original'].cpu().detach(),
92
  }
93
+ img_rgb, field = resize_fix_aspect_ratio(img_rgb, field, 640)
94
  if not model_zoo[model_type]['param']:
95
+ pred_vis = draw_perspective_fields(
96
+ img_rgb,
97
+ field['up'],
98
+ torch.deg2rad(field['lati']),
99
+ color=(0,1,0),
100
+ )
101
  param = "Not Implemented"
102
  else:
 
 
 
 
 
 
 
103
  r_p_f_rad = np.radians(
104
  [
105
  pred['pred_roll'].cpu().item(),
 
114
  param = f"roll {pred['pred_roll'].cpu().item() :.2f}\npitch {pred['pred_pitch'].cpu().item() :.2f}\nvertical fov {pred['pred_general_vfov'].cpu().item() :.2f}\nfocal_length {pred['pred_rel_focal'].cpu().item()*img_h :.2f}\n"
115
  param += f"principal point {pred['pred_rel_cx'].cpu().item() :.2f} {pred['pred_rel_cy'].cpu().item() :.2f}"
116
  pred_vis = draw_from_r_p_f_cx_cy(
117
+ img_rgb,
118
  *r_p_f_rad,
119
  *cx_cy,
120
  'rad',
121
  up_color=(0,1,0),
122
  )
123
  print(f"""time {datetime.now().strftime("%H:%M:%S")}
124
+ img.shape {img_rgb.shape}
125
  model_type {model_type}
126
  param {param}
127
  """
 
134
  print(examples)
135
 
136
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
  info = """Select model\n"""
139
  gr.Interface(