blumenstiel committed
Commit 1e019a9 · 1 Parent(s): 58ac12b

Made timestamps flexible and updated description

Files changed (1):
  app.py +54 -261
app.py CHANGED
@@ -1,6 +1,14 @@
- import logging
+
  import os
+ import torch
+ import yaml
+ import numpy as np
+ import gradio as gr
+ from einops import rearrange
+ from functools import partial
  from huggingface_hub import hf_hub_download
+ from prithvi_mae import PrithviMAE
+ from inference import process_channel_group, _convert_np_uint8, load_example, run_model

  # pull files from hub
  token = os.environ.get("HF_TOKEN", None)
@@ -15,206 +23,6 @@ model_inference = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-2.0-30
  os.system(f'cp {model_def} .')
  os.system(f'cp {model_inference} .')

- import os
- import torch
- import yaml
- import numpy as np
- import gradio as gr
- from einops import rearrange
- from functools import partial
- from prithvi_mae import PrithviMAE
- from inference import process_channel_group, read_geotiff, save_geotiff, _convert_np_uint8, load_example, run_model
-
-
- NO_DATA = -9999
- NO_DATA_FLOAT = 0.0001
- PERCENTILES = (0.1, 99.9)
-
-
- # def process_channel_group(orig_img, new_img, channels, data_mean, data_std):
- #     """ Process *orig_img* and *new_img* for RGB visualization. Each band is rescaled back to the
- #         original range using *data_mean* and *data_std* and then lowest and highest percentiles are
- #         removed to enhance contrast. Data is rescaled to (0, 1) range and stacked channels_first.
- #     Args:
- #         orig_img: torch.Tensor representing original image (reference) with shape = (bands, H, W).
- #         new_img: torch.Tensor representing image with shape = (bands, H, W).
- #         channels: list of indices representing RGB channels.
- #         data_mean: list of mean values for each band.
- #         data_std: list of std values for each band.
- #     Returns:
- #         torch.Tensor with shape (num_channels, height, width) for original image
- #         torch.Tensor with shape (num_channels, height, width) for the other image
- #     """
- #
- #     stack_c = [], []
- #
- #     for c in channels:
- #         orig_ch = orig_img[c, ...]
- #         valid_mask = torch.ones_like(orig_ch, dtype=torch.bool)
- #         valid_mask[orig_ch == NO_DATA_FLOAT] = False
- #
- #         # Back to original data range
- #         orig_ch = (orig_ch * data_std[c]) + data_mean[c]
- #         new_ch = (new_img[c, ...] * data_std[c]) + data_mean[c]
- #
- #         # Rescale (enhancing contrast)
- #         min_value, max_value = np.percentile(orig_ch[valid_mask], PERCENTILES)
- #
- #         orig_ch = torch.clamp((orig_ch - min_value) / (max_value - min_value), 0, 1)
- #         new_ch = torch.clamp((new_ch - min_value) / (max_value - min_value), 0, 1)
- #
- #         # No data as zeros
- #         orig_ch[~valid_mask] = 0
- #         new_ch[~valid_mask] = 0
- #
- #         stack_c[0].append(orig_ch)
- #         stack_c[1].append(new_ch)
- #
- #     # Channels first
- #     stack_orig = torch.stack(stack_c[0], dim=0)
- #     stack_rec = torch.stack(stack_c[1], dim=0)
- #
- #     return stack_orig, stack_rec
- #
- #
- # def read_geotiff(file_path: str):
- #     """ Read all bands from *file_path* and returns image + meta info.
- #     Args:
- #         file_path: path to image file.
- #     Returns:
- #         np.ndarray with shape (bands, height, width)
- #         meta info dict
- #     """
- #
- #     with rasterio.open(file_path) as src:
- #         img = src.read()
- #         meta = src.meta
- #         coords = src.lnglat()
- #
- #     return img, meta, coords
- #
- #
- # def save_geotiff(image, output_path: str, meta: dict):
- #     """ Save multi-band image in Geotiff file.
- #     Args:
- #         image: np.ndarray with shape (bands, height, width)
- #         output_path: path where to save the image
- #         meta: dict with meta info.
- #     """
- #
- #     with rasterio.open(output_path, "w", **meta) as dest:
- #         for i in range(image.shape[0]):
- #             dest.write(image[i, :, :], i + 1)
- #
- #     return
- #
- #
- # def _convert_np_uint8(float_image: torch.Tensor):
- #
- #     image = float_image.numpy() * 255.0
- #     image = image.astype(dtype=np.uint8)
- #     image = image.transpose((1, 2, 0))
- #
- #     return image
- #
- #
- # def load_example(file_paths: List[str], mean: List[float], std: List[float]):
- #     """ Build an input example by loading images in *file_paths*.
- #     Args:
- #         file_paths: list of file paths .
- #         mean: list containing mean values for each band in the images in *file_paths*.
- #         std: list containing std values for each band in the images in *file_paths*.
- #     Returns:
- #         np.array containing created example
- #         list of meta info for each image in *file_paths*
- #     """
- #
- #     imgs = []
- #     metas = []
- #
- #     for file in file_paths:
- #         img, meta = read_geotiff(file)
- #         img = img[:6]*10000 if img[:6].mean() <= 2 else img[:6]
- #
- #         # Rescaling (don't normalize on nodata)
- #         img = np.moveaxis(img, 0, -1)  # channels last for rescaling
- #         img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std)
- #
- #         imgs.append(img)
- #         metas.append(meta)
- #
- #     imgs = np.stack(imgs, axis=0)  # num_frames, H, W, C
- #     imgs = np.moveaxis(imgs, -1, 0).astype('float32')  # C, num_frames, H, W
- #     imgs = np.expand_dims(imgs, axis=0)  # add batch dim
- #
- #     return imgs, metas
- #
- #
- # def run_model(model: torch.nn.Module, input_data: torch.Tensor, mask_ratio: float, device: torch.device):
- #     """ Run *model* with *input_data* and create images from output tokens (mask, reconstructed + visible).
- #     Args:
- #         model: MAE model to run.
- #         input_data: torch.Tensor with shape (B, C, T, H, W).
- #         mask_ratio: mask ratio to use.
- #         device: device where model should run.
- #     Returns:
- #         3 torch.Tensor with shape (B, C, T, H, W).
- #     """
- #
- #     with torch.no_grad():
- #         x = input_data.to(device)
- #
- #         _, pred, mask = model(x, mask_ratio)
- #
- #     # Create mask and prediction images (un-patchify)
- #     mask_img = model.unpatchify(mask.unsqueeze(-1).repeat(1, 1, pred.shape[-1])).detach().cpu()
- #     pred_img = model.unpatchify(pred).detach().cpu()
- #
- #     # Mix visible and predicted patches
- #     rec_img = input_data.clone()
- #     rec_img[mask_img == 1] = pred_img[mask_img == 1]  # binary mask: 0 is keep, 1 is remove
- #
- #     # Switch zeros/ones in mask images so masked patches appear darker in plots (better visualization)
- #     mask_img = (~(mask_img.to(torch.bool))).to(torch.float)
- #
- #     return rec_img, mask_img
- #
- #
- # def save_rgb_imgs(input_img, rec_img, mask_img, channels, mean, std, output_dir, meta_data):
- #     """ Wrapper function to save Geotiff images (original, reconstructed, masked) per timestamp.
- #     Args:
- #         input_img: input torch.Tensor with shape (C, T, H, W).
- #         rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
- #         mask_img: mask torch.Tensor with shape (C, T, H, W).
- #         channels: list of indices representing RGB channels.
- #         mean: list of mean values for each band.
- #         std: list of std values for each band.
- #         output_dir: directory where to save outputs.
- #         meta_data: list of dicts with geotiff meta info.
- #     """
- #
- #     for t in range(input_img.shape[1]):
- #         rgb_orig, rgb_pred = process_channel_group(orig_img=input_img[:, t, :, :],
- #                                                    new_img=rec_img[:, t, :, :],
- #                                                    channels=channels, data_mean=mean,
- #                                                    data_std=std)
- #
- #         rgb_mask = mask_img[channels, t, :, :] * rgb_orig
- #
- #         # Saving images
- #
- #         save_geotiff(image=_convert_np_uint8(rgb_orig),
- #                      output_path=os.path.join(output_dir, f"original_rgb_t{t}.tiff"),
- #                      meta=meta_data[t])
- #
- #         save_geotiff(image=_convert_np_uint8(rgb_pred),
- #                      output_path=os.path.join(output_dir, f"predicted_rgb_t{t}.tiff"),
- #                      meta=meta_data[t])
- #
- #         save_geotiff(image=_convert_np_uint8(rgb_mask),
- #                      output_path=os.path.join(output_dir, f"masked_rgb_t{t}.tiff"),
- #                      meta=meta_data[t])
-

  def extract_rgb_imgs(input_img, rec_img, mask_img, channels, mean, std):
      """ Wrapper function to save Geotiff images (original, reconstructed, masked) per timestamp.
@@ -245,15 +53,21 @@ def extract_rgb_imgs(input_img, rec_img, mask_img, channels, mean, std):
          rgb_orig_list.append(_convert_np_uint8(rgb_orig).transpose(1, 2, 0))
          rgb_mask_list.append(_convert_np_uint8(rgb_mask).transpose(1, 2, 0))
          rgb_pred_list.append(_convert_np_uint8(rgb_pred).transpose(1, 2, 0))
-
+
+     # Add white dummy image values for missing timestamps
+     dummy = np.ones((20, 20), dtype=np.uint8) * 255
+     num_dummies = 4 - len(rgb_orig_list)
+     if num_dummies:
+         rgb_orig_list.extend([dummy] * num_dummies)
+         rgb_mask_list.extend([dummy] * num_dummies)
+         rgb_pred_list.extend([dummy] * num_dummies)
+
      outputs = rgb_orig_list + rgb_mask_list + rgb_pred_list

      return outputs


- def predict_on_images(data_files: list, mask_ratio: float, yaml_file_path: str, checkpoint: str):
-
-
+ def predict_on_images(data_files: list, yaml_file_path: str, checkpoint: str, mask_ratio: float = None):
      try:
          data_files = [x.name for x in data_files]
          print('Path extracted from example')
@@ -276,9 +90,7 @@ predict_on_images(data_files: list, mask_ratio: float, yaml_file_path: str,

      mask_ratio = mask_ratio or config['DATA']['MASK_RATIO']

-     if num_frames > 4:
-         # TODO: Check if we can limit this via UI
-         logging.warning("Model was only trained with only four timestamps.")
+     assert num_frames <= 4, "Demo only supports up to four timestamps"

      if torch.cuda.is_available():
          device = torch.device('cuda')
@@ -384,7 +196,7 @@ def predict_on_images(data_files: list, mask_ratio: float, yaml_file_path: str,
      return outputs


- func = partial(predict_on_images, yaml_file_path=yaml_file_path,checkpoint=checkpoint)
+ run_inference = partial(predict_on_images, yaml_file_path=yaml_file_path,checkpoint=checkpoint)

  with gr.Blocks() as demo:

@@ -397,10 +209,14 @@ Additionally, temporal and location information is added to the model input via
  More info about the model is available [here](https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL).\n

  This demo showcases the image reconstruction over one to four timestamps.
- The model randomly masks out some proportion of the images and then reconstructing them based on the not masked portion of the images.\n
+ The model randomly masks out some proportion of the images and reconstructs them based on the unmasked portion of the images.
+ The reconstructed images are merged with the visible unmasked patches.
+ We recommend submitting images of size 224 to ~1000 pixels for faster processing time.
+ Images bigger than 224x224 are processed using a sliding window approach which can lead to artefacts between patches.\n
+
  The user needs to provide the HLS geotiff images, including the following channels in reflectance units: Blue, Green, Red, Narrow NIR, SWIR, SWIR 2.
  Optionally, the location information is extracted from the tif files while the temporal information can be provided in the filename in the format `<date>T<time>` or `<year><julian day>T<time>` (HLS format).
- We recommend submitting images of size 224 to 1000 pixels for faster processing time. Some example images are provided at the end of this page.
+ Some example images are provided at the end of this page.
  ''')
  with gr.Row():
      with gr.Column():
@@ -409,48 +225,26 @@ We recommend submitting images of size 224 to 1000 pixels for faster processing
              btn = gr.Button("Submit")
      with gr.Row():
          gr.Markdown(value='## Original images')
-     with gr.Row():
-         gr.Markdown(value='T1')
-         gr.Markdown(value='T2')
-         gr.Markdown(value='T3')
-     with gr.Row():
-         out1_orig_t1 = gr.Image(image_mode='RGB')
-         out2_orig_t2 = gr.Image(image_mode='RGB')
-         out3_orig_t3 = gr.Image(image_mode='RGB')
-     with gr.Row():
          gr.Markdown(value='## Masked images')
-     with gr.Row():
-         gr.Markdown(value='T1')
-         gr.Markdown(value='T2')
-         gr.Markdown(value='T3')
-     with gr.Row():
-         out4_masked_t1 = gr.Image(image_mode='RGB')
-         out5_masked_t2 = gr.Image(image_mode='RGB')
-         out6_masked_t3 = gr.Image(image_mode='RGB')
-     with gr.Row():
-         gr.Markdown(value='## Reonstructed images')
-     with gr.Row():
-         gr.Markdown(value='T1')
-         gr.Markdown(value='T2')
-         gr.Markdown(value='T3')
-     with gr.Row():
-         out7_pred_t1 = gr.Image(image_mode='RGB')
-         out8_pred_t2 = gr.Image(image_mode='RGB')
-         out9_pred_t3 = gr.Image(image_mode='RGB')
-
-
-     btn.click(fn=func,
-               # inputs=[inp_files, inp_slider],
+         gr.Markdown(value='## Visible and reconstructed images')
+
+     original = []
+     masked = []
+     predicted = []
+     timestamps = []
+     for t in range(4):
+         timestamps.append(gr.Column(visible=t == 0))
+         with timestamps[t]:
+             with gr.Row():
+                 gr.Markdown(value=f"Timestamp {t+1}")
+             with gr.Row():
+                 original.append(gr.Image(image_mode='RGB'))
+                 masked.append(gr.Image(image_mode='RGB'))
+                 predicted.append(gr.Image(image_mode='RGB'))
+
+     btn.click(fn=run_inference,
                inputs=inp_files,
-               outputs=[out1_orig_t1,
-                        out2_orig_t2,
-                        out3_orig_t3,
-                        out4_masked_t1,
-                        out5_masked_t2,
-                        out6_masked_t3,
-                        out7_pred_t1,
-                        out8_pred_t2,
-                        out9_pred_t3])
+               outputs=original + masked + predicted)

      with gr.Row():
          gr.Examples(examples=[[[os.path.join(os.path.dirname(__file__), "examples/HLS.L30.T13REN.2018013T172747.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif"),
@@ -463,17 +257,16 @@ We recommend submitting images of size 224 to 1000 pixels for faster processing
                                  os.path.join(os.path.dirname(__file__), "examples/HLS.L30.T18TVL.2018141T154435.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif"),
                                  os.path.join(os.path.dirname(__file__), "examples/HLS.L30.T18TVL.2018189T154446.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif")]]],
                      inputs=inp_files,
-                     outputs=[out1_orig_t1,
-                              out2_orig_t2,
-                              out3_orig_t3,
-                              out4_masked_t1,
-                              out5_masked_t2,
-                              out6_masked_t3,
-                              out7_pred_t1,
-                              out8_pred_t2,
-                              out9_pred_t3],
-                     fn=func,
+                     outputs=original + masked + predicted,
+                     fn=run_inference,
                      cache_examples=True
                      )

+     def update_visibility(files):
+         timestamps = [gr.Column(visible=t < len(files)) for t in range(4)]
+
+         return timestamps
+
+     inp_files.change(update_visibility, inp_files, timestamps)
+
  demo.launch()
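
The core of the "flexible timestamps" change is a standard Gradio pattern: build one `gr.Column` per possible timestamp up front, start all but the first hidden, and toggle visibility from the file input's `.change()` event. Below is a minimal, self-contained sketch of that pattern; the component names are illustrative, not taken from `app.py`:

```python
import gradio as gr

MAX_T = 4  # the demo supports one to four timestamps

with gr.Blocks() as sketch:
    inp_files = gr.File(file_count="multiple")
    columns = []
    for t in range(MAX_T):
        with gr.Column(visible=(t == 0)) as col:  # only timestamp 1 starts visible
            gr.Markdown(f"Timestamp {t + 1}")
            gr.Image(image_mode="RGB")
        columns.append(col)

    def update_visibility(files):
        # Show one column per uploaded file; hide the rest.
        n = len(files) if files else 0
        return [gr.Column(visible=(t < n)) for t in range(MAX_T)]

    inp_files.change(update_visibility, inp_files, columns)

sketch.launch()
```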
 
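The `<year><julian day>T<time>` stamp mentioned in the updated description is visible in the example filenames, e.g. `2018013T172747`. As a hedged illustration of that HLS naming convention (the demo's actual parsing lives in the downloaded `inference.py`, which is not part of this diff), such a stamp can be decoded with `datetime.strptime` and the `%j` day-of-year directive:

```python
import re
from datetime import datetime

def parse_hls_timestamp(filename: str) -> datetime:
    """Parse a <year><julian day>T<time> stamp, e.g. '2018013T172747'."""
    match = re.search(r"(\d{7}T\d{6})", filename)
    if match is None:
        raise ValueError(f"No HLS timestamp in {filename!r}")
    # %j is the zero-padded day of the year (001-366)
    return datetime.strptime(match.group(1), "%Y%jT%H%M%S")

print(parse_hls_timestamp("HLS.L30.T13REN.2018013T172747.v2.0.B02.tif"))
# -> 2018-01-13 17:27:47
```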
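Why the white 20x20 dummies: `btn.click` is wired to a fixed list of twelve image outputs (four original, four masked, four reconstructed), so a run with fewer timestamps must still return twelve values. The placeholders fill the unused slots while the hidden columns keep them off-screen. The same logic as a stand-alone helper, where `pad_outputs` is a hypothetical name (the commit inlines this in `extract_rgb_imgs`):

```python
import numpy as np

def pad_outputs(orig, masked, pred, max_t=4):
    # White placeholder image for timestamps that were not provided.
    dummy = np.ones((20, 20), dtype=np.uint8) * 255
    num_dummies = max_t - len(orig)
    if num_dummies:
        orig = orig + [dummy] * num_dummies
        masked = masked + [dummy] * num_dummies
        pred = pred + [dummy] * num_dummies
    return orig + masked + pred
```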