blumenstiel commited on
Commit
f5ac567
Β·
1 Parent(s): ad034fe

Added demo code

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.tif filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM ubuntu:22.04
2
+
3
+
4
+ RUN apt-get update && apt-get install --no-install-recommends -y \
5
+ build-essential \
6
+ python3.9 \
7
+ python3-pip \
8
+ git \
9
+ && apt-get clean && rm -rf /var/lib/apt/lists/*
10
+
11
+ WORKDIR /code
12
+
13
+ COPY ./requirements.txt /code/requirements.txt
14
+
15
+ # Set up a new user named "user" with user ID 1000
16
+ RUN useradd -m -u 1000 user
17
+ # Switch to the "user" user
18
+ USER user
19
+ # Set home to the user's home directory
20
+ ENV HOME=/home/user \
21
+ PATH=/home/user/.local/bin:$PATH \
22
+ PYTHONPATH=$HOME/app \
23
+ PYTHONUNBUFFERED=1 \
24
+ GRADIO_ALLOW_FLAGGING=never \
25
+ GRADIO_NUM_PORTS=1 \
26
+ GRADIO_SERVER_NAME=0.0.0.0 \
27
+ GRADIO_THEME=huggingface \
28
+ SYSTEM=spaces
29
+
30
+ RUN pip3 install --no-cache-dir --upgrade -r /code/requirements.txt
31
+
32
+ # Set the working directory to the user's home directory
33
+ WORKDIR $HOME/app
34
+
35
+ # Copy the current directory contents into the container at $HOME/app setting the owner to the user
36
+ COPY --chown=user . $HOME/app
37
+
38
+ CMD ["python3", "app.py"]
app.py ADDED
@@ -0,0 +1,478 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from huggingface_hub import hf_hub_download
4
+
5
+ # pull files from hub
6
+ yaml_file_path = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL",
7
+ filename="Prithvi_EO_V2_300M_TL_config.yaml", token=os.environ.get("token"))
8
+ checkpoint = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL",
9
+ filename='Prithvi_EO_V2_300M_TL.pt', token=os.environ.get("token"))
10
+ model_def = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL",
11
+ filename='prithvi_mae.py', token=os.environ.get("token"))
12
+ model_inference = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL",
13
+ filename='inference.py', token=os.environ.get("token"))
14
+ os.system(f'cp {model_def} .')
15
+ os.system(f'cp {model_inference} .')
16
+
17
+ import os
18
+ import torch
19
+ import yaml
20
+ import numpy as np
21
+ import gradio as gr
22
+ from einops import rearrange
23
+ from functools import partial
24
+ from prithvi_mae import PrithviMAE
25
+ from inference import process_channel_group, read_geotiff, save_geotiff, _convert_np_uint8, load_example, run_model
26
+
27
+
28
+ NO_DATA = -9999
29
+ NO_DATA_FLOAT = 0.0001
30
+ PERCENTILES = (0.1, 99.9)
31
+
32
+
33
+ # def process_channel_group(orig_img, new_img, channels, data_mean, data_std):
34
+ # """ Process *orig_img* and *new_img* for RGB visualization. Each band is rescaled back to the
35
+ # original range using *data_mean* and *data_std* and then lowest and highest percentiles are
36
+ # removed to enhance contrast. Data is rescaled to (0, 1) range and stacked channels_first.
37
+ # Args:
38
+ # orig_img: torch.Tensor representing original image (reference) with shape = (bands, H, W).
39
+ # new_img: torch.Tensor representing image with shape = (bands, H, W).
40
+ # channels: list of indices representing RGB channels.
41
+ # data_mean: list of mean values for each band.
42
+ # data_std: list of std values for each band.
43
+ # Returns:
44
+ # torch.Tensor with shape (num_channels, height, width) for original image
45
+ # torch.Tensor with shape (num_channels, height, width) for the other image
46
+ # """
47
+ #
48
+ # stack_c = [], []
49
+ #
50
+ # for c in channels:
51
+ # orig_ch = orig_img[c, ...]
52
+ # valid_mask = torch.ones_like(orig_ch, dtype=torch.bool)
53
+ # valid_mask[orig_ch == NO_DATA_FLOAT] = False
54
+ #
55
+ # # Back to original data range
56
+ # orig_ch = (orig_ch * data_std[c]) + data_mean[c]
57
+ # new_ch = (new_img[c, ...] * data_std[c]) + data_mean[c]
58
+ #
59
+ # # Rescale (enhancing contrast)
60
+ # min_value, max_value = np.percentile(orig_ch[valid_mask], PERCENTILES)
61
+ #
62
+ # orig_ch = torch.clamp((orig_ch - min_value) / (max_value - min_value), 0, 1)
63
+ # new_ch = torch.clamp((new_ch - min_value) / (max_value - min_value), 0, 1)
64
+ #
65
+ # # No data as zeros
66
+ # orig_ch[~valid_mask] = 0
67
+ # new_ch[~valid_mask] = 0
68
+ #
69
+ # stack_c[0].append(orig_ch)
70
+ # stack_c[1].append(new_ch)
71
+ #
72
+ # # Channels first
73
+ # stack_orig = torch.stack(stack_c[0], dim=0)
74
+ # stack_rec = torch.stack(stack_c[1], dim=0)
75
+ #
76
+ # return stack_orig, stack_rec
77
+ #
78
+ #
79
+ # def read_geotiff(file_path: str):
80
+ # """ Read all bands from *file_path* and returns image + meta info.
81
+ # Args:
82
+ # file_path: path to image file.
83
+ # Returns:
84
+ # np.ndarray with shape (bands, height, width)
85
+ # meta info dict
86
+ # """
87
+ #
88
+ # with rasterio.open(file_path) as src:
89
+ # img = src.read()
90
+ # meta = src.meta
91
+ # coords = src.lnglat()
92
+ #
93
+ # return img, meta, coords
94
+ #
95
+ #
96
+ # def save_geotiff(image, output_path: str, meta: dict):
97
+ # """ Save multi-band image in Geotiff file.
98
+ # Args:
99
+ # image: np.ndarray with shape (bands, height, width)
100
+ # output_path: path where to save the image
101
+ # meta: dict with meta info.
102
+ # """
103
+ #
104
+ # with rasterio.open(output_path, "w", **meta) as dest:
105
+ # for i in range(image.shape[0]):
106
+ # dest.write(image[i, :, :], i + 1)
107
+ #
108
+ # return
109
+ #
110
+ #
111
+ # def _convert_np_uint8(float_image: torch.Tensor):
112
+ #
113
+ # image = float_image.numpy() * 255.0
114
+ # image = image.astype(dtype=np.uint8)
115
+ # image = image.transpose((1, 2, 0))
116
+ #
117
+ # return image
118
+ #
119
+ #
120
+ # def load_example(file_paths: List[str], mean: List[float], std: List[float]):
121
+ # """ Build an input example by loading images in *file_paths*.
122
+ # Args:
123
+ # file_paths: list of file paths .
124
+ # mean: list containing mean values for each band in the images in *file_paths*.
125
+ # std: list containing std values for each band in the images in *file_paths*.
126
+ # Returns:
127
+ # np.array containing created example
128
+ # list of meta info for each image in *file_paths*
129
+ # """
130
+ #
131
+ # imgs = []
132
+ # metas = []
133
+ #
134
+ # for file in file_paths:
135
+ # img, meta = read_geotiff(file)
136
+ # img = img[:6]*10000 if img[:6].mean() <= 2 else img[:6]
137
+ #
138
+ # # Rescaling (don't normalize on nodata)
139
+ # img = np.moveaxis(img, 0, -1) # channels last for rescaling
140
+ # img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std)
141
+ #
142
+ # imgs.append(img)
143
+ # metas.append(meta)
144
+ #
145
+ # imgs = np.stack(imgs, axis=0) # num_frames, H, W, C
146
+ # imgs = np.moveaxis(imgs, -1, 0).astype('float32') # C, num_frames, H, W
147
+ # imgs = np.expand_dims(imgs, axis=0) # add batch dim
148
+ #
149
+ # return imgs, metas
150
+ #
151
+ #
152
+ # def run_model(model: torch.nn.Module, input_data: torch.Tensor, mask_ratio: float, device: torch.device):
153
+ # """ Run *model* with *input_data* and create images from output tokens (mask, reconstructed + visible).
154
+ # Args:
155
+ # model: MAE model to run.
156
+ # input_data: torch.Tensor with shape (B, C, T, H, W).
157
+ # mask_ratio: mask ratio to use.
158
+ # device: device where model should run.
159
+ # Returns:
160
+ # 3 torch.Tensor with shape (B, C, T, H, W).
161
+ # """
162
+ #
163
+ # with torch.no_grad():
164
+ # x = input_data.to(device)
165
+ #
166
+ # _, pred, mask = model(x, mask_ratio)
167
+ #
168
+ # # Create mask and prediction images (un-patchify)
169
+ # mask_img = model.unpatchify(mask.unsqueeze(-1).repeat(1, 1, pred.shape[-1])).detach().cpu()
170
+ # pred_img = model.unpatchify(pred).detach().cpu()
171
+ #
172
+ # # Mix visible and predicted patches
173
+ # rec_img = input_data.clone()
174
+ # rec_img[mask_img == 1] = pred_img[mask_img == 1] # binary mask: 0 is keep, 1 is remove
175
+ #
176
+ # # Switch zeros/ones in mask images so masked patches appear darker in plots (better visualization)
177
+ # mask_img = (~(mask_img.to(torch.bool))).to(torch.float)
178
+ #
179
+ # return rec_img, mask_img
180
+ #
181
+ #
182
+ # def save_rgb_imgs(input_img, rec_img, mask_img, channels, mean, std, output_dir, meta_data):
183
+ # """ Wrapper function to save Geotiff images (original, reconstructed, masked) per timestamp.
184
+ # Args:
185
+ # input_img: input torch.Tensor with shape (C, T, H, W).
186
+ # rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
187
+ # mask_img: mask torch.Tensor with shape (C, T, H, W).
188
+ # channels: list of indices representing RGB channels.
189
+ # mean: list of mean values for each band.
190
+ # std: list of std values for each band.
191
+ # output_dir: directory where to save outputs.
192
+ # meta_data: list of dicts with geotiff meta info.
193
+ # """
194
+ #
195
+ # for t in range(input_img.shape[1]):
196
+ # rgb_orig, rgb_pred = process_channel_group(orig_img=input_img[:, t, :, :],
197
+ # new_img=rec_img[:, t, :, :],
198
+ # channels=channels, data_mean=mean,
199
+ # data_std=std)
200
+ #
201
+ # rgb_mask = mask_img[channels, t, :, :] * rgb_orig
202
+ #
203
+ # # Saving images
204
+ #
205
+ # save_geotiff(image=_convert_np_uint8(rgb_orig),
206
+ # output_path=os.path.join(output_dir, f"original_rgb_t{t}.tiff"),
207
+ # meta=meta_data[t])
208
+ #
209
+ # save_geotiff(image=_convert_np_uint8(rgb_pred),
210
+ # output_path=os.path.join(output_dir, f"predicted_rgb_t{t}.tiff"),
211
+ # meta=meta_data[t])
212
+ #
213
+ # save_geotiff(image=_convert_np_uint8(rgb_mask),
214
+ # output_path=os.path.join(output_dir, f"masked_rgb_t{t}.tiff"),
215
+ # meta=meta_data[t])
216
+
217
+
218
+ def extract_rgb_imgs(input_img, rec_img, mask_img, channels, mean, std):
219
+ """ Wrapper function to save Geotiff images (original, reconstructed, masked) per timestamp.
220
+ Args:
221
+ input_img: input torch.Tensor with shape (C, T, H, W).
222
+ rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
223
+ mask_img: mask torch.Tensor with shape (C, T, H, W).
224
+ channels: list of indices representing RGB channels.
225
+ mean: list of mean values for each band.
226
+ std: list of std values for each band.
227
+ output_dir: directory where to save outputs.
228
+ meta_data: list of dicts with geotiff meta info.
229
+ """
230
+ rgb_orig_list = []
231
+ rgb_mask_list = []
232
+ rgb_pred_list = []
233
+
234
+ for t in range(input_img.shape[1]):
235
+ rgb_orig, rgb_pred = process_channel_group(orig_img=input_img[:, t, :, :],
236
+ new_img=rec_img[:, t, :, :],
237
+ channels=channels,
238
+ mean=mean,
239
+ std=std)
240
+
241
+ rgb_mask = mask_img[channels, t, :, :] * rgb_orig
242
+
243
+ # extract images
244
+ rgb_orig_list.append(_convert_np_uint8(rgb_orig).transpose(1, 2, 0))
245
+ rgb_mask_list.append(_convert_np_uint8(rgb_mask).transpose(1, 2, 0))
246
+ rgb_pred_list.append(_convert_np_uint8(rgb_pred).transpose(1, 2, 0))
247
+
248
+ outputs = rgb_orig_list + rgb_mask_list + rgb_pred_list
249
+
250
+ return outputs
251
+
252
+
253
+ def predict_on_images(data_files: list, mask_ratio: float, yaml_file_path: str, checkpoint: str):
254
+
255
+
256
+ try:
257
+ data_files = [x.name for x in data_files]
258
+ print('Path extracted from example')
259
+ except:
260
+ print('Files submitted through UI')
261
+
262
+ # Get parameters --------
263
+ print('This is the printout', data_files)
264
+
265
+ with open(yaml_file_path, 'r') as f:
266
+ config = yaml.safe_load(f)
267
+
268
+ batch_size = 8
269
+ bands = config['DATA']['BANDS']
270
+ num_frames = len(data_files)
271
+ mean = config['DATA']['MEAN']
272
+ std = config['DATA']['STD']
273
+ coords_encoding = config['MODEL']['COORDS_ENCODING']
274
+ img_size = config['DATA']['INPUT_SIZE'][-1]
275
+
276
+ mask_ratio = mask_ratio or config['DATA']['MASK_RATIO']
277
+
278
+ if num_frames > 4:
279
+ # TODO: Check if we can limit this via UI
280
+ logging.warning("Model was only trained with only four timestamps.")
281
+
282
+ if torch.cuda.is_available():
283
+ device = torch.device('cuda')
284
+ else:
285
+ device = torch.device('cpu')
286
+
287
+ print(f"Using {device} device.\n")
288
+
289
+ # Loading data ---------------------------------------------------------------------------------
290
+
291
+ input_data, temporal_coords, location_coords, meta_data = load_example(file_paths=data_files, mean=mean, std=std)
292
+
293
+ if len(temporal_coords) != num_frames and 'time' in coords_encoding:
294
+ coords_encoding.pop('time')
295
+ if not len(location_coords) and 'location' in coords_encoding:
296
+ coords_encoding.pop('location')
297
+
298
+ # Create model and load checkpoint -------------------------------------------------------------
299
+
300
+ model = PrithviMAE(img_size=config['DATA']['INPUT_SIZE'][-2:],
301
+ patch_size=config['MODEL']['PATCH_SIZE'],
302
+ num_frames=num_frames,
303
+ in_chans=len(bands),
304
+ embed_dim=config['MODEL']['EMBED_DIM'],
305
+ depth=config['MODEL']['DEPTH'],
306
+ num_heads=config['MODEL']['NUM_HEADS'],
307
+ decoder_embed_dim=config['MODEL']['DECODER_EMBED_DIM'],
308
+ decoder_depth=config['MODEL']['DECODER_DEPTH'],
309
+ decoder_num_heads=config['MODEL']['DECODER_NUM_HEADS'],
310
+ mlp_ratio=config['MODEL']['MLP_RATIO'],
311
+ norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
312
+ norm_pix_loss=config['MODEL']['NORM_PIX_LOSS'],
313
+ coords_encoding=coords_encoding,
314
+ coords_scale_learn=config['MODEL']['COORDS_SCALE_LEARN'])
315
+
316
+ total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
317
+ print(f"\n--> Model has {total_params:,} parameters.\n")
318
+
319
+ model.to(device)
320
+
321
+ state_dict = torch.load(checkpoint, map_location=device, weights_only=False)
322
+ # discard fixed pos_embedding weight
323
+ for k in list(state_dict.keys()):
324
+ if 'pos_embed' in k:
325
+ del state_dict[k]
326
+ model.load_state_dict(state_dict, strict=False)
327
+ print(f"Loaded checkpoint from {checkpoint}")
328
+
329
+ # Running model --------------------------------------------------------------------------------
330
+
331
+ model.eval()
332
+ channels = [bands.index(b) for b in ['B04', 'B03', 'B02']] # BGR -> RGB
333
+
334
+ # Reflect pad if not divisible by img_size
335
+ original_h, original_w = input_data.shape[-2:]
336
+ pad_h = img_size - (original_h % img_size)
337
+ pad_w = img_size - (original_w % img_size)
338
+ input_data = np.pad(input_data, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode='reflect')
339
+
340
+ # Build sliding window
341
+ batch = torch.tensor(input_data, device='cpu')
342
+ windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
343
+ h1, w1 = windows.shape[3:5]
344
+ windows = rearrange(windows, 'b c t h1 w1 h w -> (b h1 w1) c t h w', h=img_size, w=img_size)
345
+
346
+ # Split into batches if number of windows > batch_size
347
+ num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
348
+ windows = torch.tensor_split(windows, num_batches, dim=0)
349
+
350
+ # Run model
351
+ rec_imgs = []
352
+ mask_imgs = []
353
+ for x in windows:
354
+ temp_coords = torch.Tensor([temporal_coords] * len(x))
355
+ loc_coords = torch.Tensor([location_coords[0]] * len(x))
356
+ rec_img, mask_img = run_model(model, x, temp_coords, loc_coords, mask_ratio, device)
357
+ rec_imgs.append(rec_img)
358
+ mask_imgs.append(mask_img)
359
+
360
+ rec_imgs = torch.concat(rec_imgs, dim=0)
361
+ mask_imgs = torch.concat(mask_imgs, dim=0)
362
+
363
+ # Build images from patches
364
+ rec_imgs = rearrange(rec_imgs, '(b h1 w1) c t h w -> b c t (h1 h) (w1 w)',
365
+ h=img_size, w=img_size, b=1, c=len(bands), t=num_frames, h1=h1, w1=w1)
366
+ mask_imgs = rearrange(mask_imgs, '(b h1 w1) c t h w -> b c t (h1 h) (w1 w)',
367
+ h=img_size, w=img_size, b=1, c=len(bands), t=num_frames, h1=h1, w1=w1)
368
+
369
+ # Cut padded images back to original size
370
+ rec_imgs_full = rec_imgs[..., :original_h, :original_w]
371
+ mask_imgs_full = mask_imgs[..., :original_h, :original_w]
372
+ batch_full = batch[..., :original_h, :original_w]
373
+
374
+ # Build RGB images
375
+ for d in meta_data:
376
+ d.update(count=3, dtype='uint8', compress='lzw', nodata=0)
377
+
378
+ outputs = extract_rgb_imgs(batch_full[0, ...], rec_imgs_full[0, ...], mask_imgs_full[0, ...],
379
+ channels, mean, std)
380
+
381
+ print("Done!")
382
+
383
+ return outputs
384
+
385
+
386
+ func = partial(predict_on_images, yaml_file_path=yaml_file_path,checkpoint=checkpoint)
387
+
388
+ with gr.Blocks() as demo:
389
+
390
+ gr.Markdown(value='# Prithvi-EO-2.0 image reconstruction demo')
391
+ gr.Markdown(value='''
392
+ Prithvi-EO-2.0 is the second generation EO foundation model developed by the IBM and NASA team.
393
+ The temporal ViT is train on 4.2M Harmonised Landsat Sentinel 2 (HLS) samples with four timestamps each, using the Masked AutoEncoder learning strategy.
394
+ The model includes spatial and temporal attention across multiple patches and timestamps.
395
+ Additionally, temporal and location information is added to the model input via embeddings.
396
+ More info about the model are available [here](https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL).\n
397
+
398
+ This demo showcases the image reconstruction over one to four timestamps.
399
+ The model randomly masks out some proportion of the images and then reconstructing them based on the not masked portion of the images.\n
400
+ The user needs to provide the HLS geotiff images, including the following channels in reflectance units: Blue, Green, Red, Narrow NIR, SWIR, SWIR 2.
401
+ Optionally, the location information is extracted from the tif files while the temporal information can be provided in the filename in the format `<date>T<time>` or `<year><julian day>T<time>` (HLS format).
402
+ We recommend submitting images of size 224 to 1000 pixels for faster processing time. Some example images are provided at the end of this page.
403
+ ''')
404
+ with gr.Row():
405
+ with gr.Column():
406
+ inp_files = gr.Files(elem_id='files')
407
+ # inp_slider = gr.Slider(0, 100, value=50, label="Mask ratio", info="Choose ratio of masking between 0 and 100", elem_id='slider'),
408
+ btn = gr.Button("Submit")
409
+ with gr.Row():
410
+ gr.Markdown(value='## Original images')
411
+ with gr.Row():
412
+ gr.Markdown(value='T1')
413
+ gr.Markdown(value='T2')
414
+ gr.Markdown(value='T3')
415
+ with gr.Row():
416
+ out1_orig_t1 = gr.Image(image_mode='RGB')
417
+ out2_orig_t2 = gr.Image(image_mode='RGB')
418
+ out3_orig_t3 = gr.Image(image_mode='RGB')
419
+ with gr.Row():
420
+ gr.Markdown(value='## Masked images')
421
+ with gr.Row():
422
+ gr.Markdown(value='T1')
423
+ gr.Markdown(value='T2')
424
+ gr.Markdown(value='T3')
425
+ with gr.Row():
426
+ out4_masked_t1 = gr.Image(image_mode='RGB')
427
+ out5_masked_t2 = gr.Image(image_mode='RGB')
428
+ out6_masked_t3 = gr.Image(image_mode='RGB')
429
+ with gr.Row():
430
+ gr.Markdown(value='## Reonstructed images')
431
+ with gr.Row():
432
+ gr.Markdown(value='T1')
433
+ gr.Markdown(value='T2')
434
+ gr.Markdown(value='T3')
435
+ with gr.Row():
436
+ out7_pred_t1 = gr.Image(image_mode='RGB')
437
+ out8_pred_t2 = gr.Image(image_mode='RGB')
438
+ out9_pred_t3 = gr.Image(image_mode='RGB')
439
+
440
+
441
+ btn.click(fn=func,
442
+ # inputs=[inp_files, inp_slider],
443
+ inputs=inp_files,
444
+ outputs=[out1_orig_t1,
445
+ out2_orig_t2,
446
+ out3_orig_t3,
447
+ out4_masked_t1,
448
+ out5_masked_t2,
449
+ out6_masked_t3,
450
+ out7_pred_t1,
451
+ out8_pred_t2,
452
+ out9_pred_t3])
453
+
454
+ with gr.Row():
455
+ gr.Examples(examples=[[[os.path.join(os.path.dirname(__file__), "examples/HLS.L30.T13REN.2018013T172747.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif"),
456
+ os.path.join(os.path.dirname(__file__), "examples/HLS.L30.T13REN.2018029T172738.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif"),
457
+ os.path.join(os.path.dirname(__file__), "examples/HLS.L30.T13REN.2018061T172724.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif")]],
458
+ [[os.path.join(os.path.dirname(__file__), "examples/HLS.L30.T17RMP.2018004T155509.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif"),
459
+ os.path.join(os.path.dirname(__file__), "examples/HLS.L30.T17RMP.2018036T155452.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif"),
460
+ os.path.join(os.path.dirname(__file__), "examples/HLS.L30.T17RMP.2018068T155438.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif")]],
461
+ [[os.path.join(os.path.dirname(__file__), "examples/HLS.L30.T18TVL.2018029T154533.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif"),
462
+ os.path.join(os.path.dirname(__file__), "examples/HLS.L30.T18TVL.2018141T154435.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif"),
463
+ os.path.join(os.path.dirname(__file__), "examples/HLS.L30.T18TVL.2018189T154446.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif")]]],
464
+ inputs=inp_files,
465
+ outputs=[out1_orig_t1,
466
+ out2_orig_t2,
467
+ out3_orig_t3,
468
+ out4_masked_t1,
469
+ out5_masked_t2,
470
+ out6_masked_t3,
471
+ out7_pred_t1,
472
+ out8_pred_t2,
473
+ out9_pred_t3],
474
+ fn=func,
475
+ cache_examples=True
476
+ )
477
+
478
+ demo.launch()
examples/HLS.L30.T13REN.2018013T172747.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif ADDED

Git LFS Details

  • SHA256: a2e1f9d91fedf9b286aaeef5197f4715f3caf2851187356d598d9fe78beb7c6b
  • Pointer size: 132 Bytes
  • Size of remote file: 3.24 MB
examples/HLS.L30.T13REN.2018029T172738.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif ADDED

Git LFS Details

  • SHA256: 92b5e2072f9b72fee207b8aec2f91f5c42f42f60950c8ca10d9022192d2cfb1a
  • Pointer size: 132 Bytes
  • Size of remote file: 3.24 MB
examples/HLS.L30.T13REN.2018061T172724.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif ADDED

Git LFS Details

  • SHA256: 24feb1904fc62268494c9c0d8628124a41621cb4ee705d82cbce7586121c91c5
  • Pointer size: 132 Bytes
  • Size of remote file: 3.24 MB
examples/HLS.L30.T17RMP.2018004T155509.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif ADDED

Git LFS Details

  • SHA256: 757da6d6cb8f34def8c7dc779181ed535521315b68396f5b7daa6e99cceae247
  • Pointer size: 132 Bytes
  • Size of remote file: 3.24 MB
examples/HLS.L30.T17RMP.2018036T155452.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif ADDED

Git LFS Details

  • SHA256: b8b4bfff07672d765d2350b324631a588578740d5582c92189ad5ada198a88c5
  • Pointer size: 132 Bytes
  • Size of remote file: 3.24 MB
examples/HLS.L30.T17RMP.2018068T155438.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif ADDED

Git LFS Details

  • SHA256: 1a22d7d5b06f069a62ff3a76f0719158eac10f9cc9c5623e1d4db6e511495097
  • Pointer size: 132 Bytes
  • Size of remote file: 3.24 MB
examples/HLS.L30.T18TVL.2018029T154533.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif ADDED

Git LFS Details

  • SHA256: 9b40ca063665849268b792e21a54069af900dade98c0e728f9586c639f221f0b
  • Pointer size: 132 Bytes
  • Size of remote file: 3.36 MB
examples/HLS.L30.T18TVL.2018141T154435.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif ADDED

Git LFS Details

  • SHA256: b79534c8002421598d00dde32fe69da18b8af7fe1a31fab04e4581406bdc442e
  • Pointer size: 132 Bytes
  • Size of remote file: 3.36 MB
examples/HLS.L30.T18TVL.2018189T154446.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif ADDED

Git LFS Details

  • SHA256: 7e276d768c6b7d461cc26020165641512391a59182db1bc50546e2f868b4ff17
  • Pointer size: 132 Bytes
  • Size of remote file: 3.36 MB
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ timm
4
+ rasterio
5
+ einops
6
+ huggingface_hub
7
+ gradio