|
|
|
import os |
|
import torch |
|
import yaml |
|
import numpy as np |
|
import gradio as gr |
|
from einops import rearrange |
|
from functools import partial |
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
import shutil

# Hugging Face Hub repository that hosts the config, weights and model code.
REPO_ID = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL"

# Optional Hub access token read from the environment (None when unset).
token = os.environ.get("HF_TOKEN", None)

config_path = hf_hub_download(repo_id=REPO_ID, filename="config.json", token=token)
checkpoint = hf_hub_download(repo_id=REPO_ID, filename='Prithvi_EO_V2_300M_TL.pt', token=token)
model_def = hf_hub_download(repo_id=REPO_ID, filename='prithvi_mae.py', token=token)
model_inference = hf_hub_download(repo_id=REPO_ID, filename='inference.py', token=token)

# The model definition and inference helpers ship as plain modules in the model repo.
# Copy them into the current working directory, presumably so the
# `from prithvi_mae import ...` / `from inference import ...` lines that follow can
# resolve them (assumes the working directory is importable — confirm for your launcher).
# shutil.copy avoids building a shell command from a path and raises on failure,
# whereas os.system('cp ...') silently returned a non-zero exit status.
shutil.copy(model_def, '.')
shutil.copy(model_inference, '.')
|
|
|
from prithvi_mae import PrithviMAE |
|
from inference import process_channel_group, _convert_np_uint8, load_example, run_model |
|
|
|
def extract_rgb_imgs(input_img, rec_img, mask_img, channels, mean, std, num_slots=4):
    """Build per-timestamp RGB previews (original, masked, reconstructed) for the UI.

    Nothing is written to disk; the images are returned for display.

    Args:
        input_img: input torch.Tensor with shape (C, T, H, W).
        rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
        mask_img: mask torch.Tensor with shape (C, T, H, W).
        channels: list of band indices used as the R, G, B channels.
        mean: list of mean values for each band.
        std: list of std values for each band.
        num_slots: number of image slots per group in the UI. Groups with fewer
            timestamps are padded with 20x20 white placeholder images so the
            output length matches a fixed set of output components.

    Returns:
        A flat list laid out as ``originals + masked + reconstructed``. Each group
        holds one image per timestamp (converted with ``_convert_np_uint8`` and
        transposed to channel-last), followed by placeholders up to ``num_slots``.
    """
    rgb_orig_list, rgb_mask_list, rgb_pred_list = [], [], []

    for t in range(input_img.shape[1]):
        rgb_orig, rgb_pred = process_channel_group(orig_img=input_img[:, t, :, :],
                                                   new_img=rec_img[:, t, :, :],
                                                   channels=channels,
                                                   mean=mean,
                                                   std=std)

        # Apply the mask of the selected RGB bands to the original image.
        rgb_mask = mask_img[channels, t, :, :] * rgb_orig

        # Channel-first -> channel-last, as expected by image widgets.
        rgb_orig_list.append(_convert_np_uint8(rgb_orig).transpose(1, 2, 0))
        rgb_mask_list.append(_convert_np_uint8(rgb_mask).transpose(1, 2, 0))
        rgb_pred_list.append(_convert_np_uint8(rgb_pred).transpose(1, 2, 0))

    # Pad every group to the same number of slots; never a negative count.
    num_dummies = max(0, num_slots - len(rgb_orig_list))
    dummy = np.full((20, 20), 255, dtype=np.uint8)
    for group in (rgb_orig_list, rgb_mask_list, rgb_pred_list):
        group.extend([dummy] * num_dummies)

    return rgb_orig_list + rgb_mask_list + rgb_pred_list
|
|
|
|
|
def _load_model(config: dict, checkpoint: str, device: torch.device) -> torch.nn.Module:
    """Build a PrithviMAE from ``config``, load ``checkpoint`` into it and set eval mode.

    Checkpoint keys containing ``pos_embed`` are removed and the rest is loaded with
    ``strict=False``. The model therefore keeps whatever positional embeddings
    ``PrithviMAE(**config)`` initialised (presumably sized for this run's
    ``num_frames``/``img_size`` — confirm against prithvi_mae.py).
    """
    model = PrithviMAE(**config)

    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\n--> Model has {total_params:,} parameters.\n")

    model.to(device)

    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects. It is
    # acceptable only because the file comes from the model's own Hub repo; consider
    # weights_only=True if the checkpoint is a plain state dict.
    state_dict = torch.load(checkpoint, map_location=device, weights_only=False)

    # Delete in place rather than rebuilding the mapping, so attributes attached to
    # the loaded object (e.g. a `_metadata` attribute, if present) are preserved.
    for key in [k for k in state_dict if 'pos_embed' in k]:
        del state_dict[key]
    model.load_state_dict(state_dict, strict=False)
    print(f"Loaded checkpoint from {checkpoint}")

    model.eval()
    return model


def predict_on_images(data_files: list, config_path: str, checkpoint: str, mask_ratio: float = None):
    """Run masked reconstruction over one to four HLS GeoTIFF timestamps.

    Images larger than the model's ``img_size`` are reflect-padded to a multiple of
    it, cut into ``img_size`` x ``img_size`` windows, reconstructed in batches and
    stitched back together before cropping to the original size.

    Args:
        data_files: uploaded files, either path strings / ``os.PathLike`` objects or
            file-like objects exposing the path as ``.name``.
        config_path: path to the model config; its ``pretrained_cfg`` section supplies
            bands, normalisation stats and model hyperparameters.
        checkpoint: path to the model weights.
        mask_ratio: fraction of patches to mask; ``None`` uses the config value.

    Returns:
        The flat image list produced by ``extract_rgb_imgs``
        (originals + masked + reconstructed).

    Raises:
        gr.Error: if no files, or more than four files, are supplied.
    """
    if not data_files:
        raise gr.Error("Please upload between one and four GeoTIFF files.")

    # Gradio may hand over plain paths or tempfile-like objects carrying `.name`.
    data_files = [os.fspath(f) if isinstance(f, (str, os.PathLike)) else f.name for f in data_files]
    print('Input files:', data_files)

    num_frames = len(data_files)
    if num_frames > 4:
        # Explicit error instead of `assert`, which is stripped under `python -O`.
        raise gr.Error("Demo only supports up to four timestamps")

    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)['pretrained_cfg']

    batch_size = 8
    bands = config['bands']
    mean = config['mean']
    std = config['std']
    coords_encoding = config['coords_encoding']
    img_size = config['img_size']
    # `mask_ratio or default` would silently replace an explicit 0.0.
    if mask_ratio is None:
        mask_ratio = config['mask_ratio']

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using {device} device.\n")

    input_data, temporal_coords, location_coords, _meta_data = load_example(
        file_paths=data_files, mean=mean, std=std)

    # Disable an encoding when the inputs lack the matching metadata. This assumes
    # coords_encoding is a list of names such as 'time'/'location'. Build a filtered
    # copy: list.pop('time') would raise TypeError because list.pop takes an index.
    if len(temporal_coords) != num_frames:
        coords_encoding = [c for c in coords_encoding if c != 'time']
    if not len(location_coords):
        coords_encoding = [c for c in coords_encoding if c != 'location']

    config.update(
        num_frames=num_frames,
        coords_encoding=coords_encoding,
    )

    model = _load_model(config, checkpoint, device)
    channels = [bands.index(b) for b in ['B04', 'B03', 'B02']]

    # Reflect-pad bottom/right up to the next multiple of img_size. The modulo form
    # yields 0 for sizes that are already multiples; the previous formula added a
    # whole extra tile there (e.g. 224x224 became 448x448 -> 4 windows).
    original_h, original_w = input_data.shape[-2:]
    pad_h = (-original_h) % img_size
    pad_w = (-original_w) % img_size
    input_data = np.pad(input_data, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode='reflect')

    # Cut the padded (b, c, t, H, W) array into non-overlapping img_size windows.
    batch = torch.tensor(input_data, device='cpu')
    windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
    h1, w1 = windows.shape[3:5]
    windows = rearrange(windows, 'b c t h1 w1 h w -> (b h1 w1) c t h w', h=img_size, w=img_size)

    # Ceil-divide so no chunk exceeds batch_size (floor division let chunks grow to
    # nearly twice batch_size).
    num_batches = max(1, -(-windows.shape[0] // batch_size))
    windows = torch.tensor_split(windows, num_batches, dim=0)

    rec_imgs = []
    mask_imgs = []
    for x in windows:
        # Only build coordinate tensors for encodings that remain enabled. In
        # particular, location_coords may be empty, and indexing [0] would raise.
        temp_coords = torch.Tensor([temporal_coords] * len(x)) if 'time' in coords_encoding else None
        loc_coords = torch.Tensor([location_coords[0]] * len(x)) if 'location' in coords_encoding else None
        rec_img, mask_img = run_model(model, x, temp_coords, loc_coords, mask_ratio, device)
        rec_imgs.append(rec_img)
        mask_imgs.append(mask_img)

    rec_imgs = torch.concat(rec_imgs, dim=0)
    mask_imgs = torch.concat(mask_imgs, dim=0)

    # Stitch the windows back into full (padded) images.
    rec_imgs = rearrange(rec_imgs, '(b h1 w1) c t h w -> b c t (h1 h) (w1 w)',
                         h=img_size, w=img_size, b=1, c=len(bands), t=num_frames, h1=h1, w1=w1)
    mask_imgs = rearrange(mask_imgs, '(b h1 w1) c t h w -> b c t (h1 h) (w1 w)',
                          h=img_size, w=img_size, b=1, c=len(bands), t=num_frames, h1=h1, w1=w1)

    # Drop the padding again.
    rec_imgs_full = rec_imgs[..., :original_h, :original_w]
    mask_imgs_full = mask_imgs[..., :original_h, :original_w]
    batch_full = batch[..., :original_h, :original_w]

    outputs = extract_rgb_imgs(batch_full[0, ...], rec_imgs_full[0, ...], mask_imgs_full[0, ...],
                               channels, mean, std)

    print("Done!")

    return outputs
|
|
|
|
|
# Bind the downloaded model assets so the Gradio callbacks only need the uploaded files.
run_inference = partial(predict_on_images, config_path=config_path, checkpoint=checkpoint)

# Example GeoTIFF time series shipped in the `examples` folder next to this script.
EXAMPLES_DIR = os.path.join(os.path.dirname(__file__), "examples")
EXAMPLE_SERIES = [
    [
        "Mexico_HLS.S30.T13REM.2018026T173609.v2.0_cropped.tif",
        "Mexico_HLS.S30.T13REM.2018106T172859.v2.0_cropped.tif",
        "Mexico_HLS.S30.T13REM.2018201T172901.v2.0_cropped.tif",
        "Mexico_HLS.S30.T13REM.2018266T173029.v2.0_cropped.tif",
    ],
    [
        "Alaska_HLS.S30.T06VUN.2020305T212629.v2.0_cropped.tif",
        "Alaska_HLS.S30.T06VUN.2021044T212601.v2.0_cropped.tif",
        "Alaska_HLS.S30.T06VUN.2021067T213531.v2.0_cropped.tif",
        # NOTE(review): identical to the previous entry; the fourth Alaska timestamp
        # was probably meant to be a different file — confirm against examples/.
        "Alaska_HLS.S30.T06VUN.2021067T213531.v2.0_cropped.tif",
    ],
    [
        "Florida_HLS.S30.T17RMP.2019119T155911.v2.0_cropped.tif",
        "Florida_HLS.S30.T17RMP.2019249T155901.v2.0_cropped.tif",
        "Florida_HLS.S30.T17RMP.2019349T160651.v2.0_cropped.tif",
        "Florida_HLS.S30.T17RMP.2020039T160419.v2.0_cropped.tif",
    ],
]

with gr.Blocks() as demo:

    gr.Markdown(value='# Prithvi-EO-2.0 image reconstruction demo')
    gr.Markdown(value='''
Prithvi-EO-2.0 is the second generation EO foundation model developed by the IBM and NASA team.
The temporal ViT is trained on 4.2M Harmonised Landsat Sentinel 2 (HLS) samples with four timestamps each, using the Masked AutoEncoder learning strategy.
The model includes spatial and temporal attention across multiple patches and timestamps.
Additionally, temporal and location information is added to the model input via embeddings.
More information about the model is available [here](https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL).\n

This demo showcases the image reconstruction over one to four timestamps.
The model randomly masks out some proportion of the images and reconstructs them based on the unmasked portion of the images.
The reconstructed images are merged with the visible unmasked patches.
We recommend submitting images of size 224 to ~1000 pixels for faster processing time.
Images bigger than 224x224 are processed using a sliding window approach which can lead to artefacts between patches.\n

The user needs to provide the HLS geotiff images, including the following channels in reflectance units: Blue, Green, Red, Narrow NIR, SWIR, SWIR 2.
Optionally, the location information is extracted from the tif files while the temporal information can be provided in the filename in the format `<date>T<time>` or `<year><julian day>T<time>` (HLS format).
Some example images are provided at the end of this page.
''')

    with gr.Row():
        with gr.Column():
            inp_files = gr.Files(elem_id='files')
            btn = gr.Button("Submit")

    with gr.Row():
        gr.Markdown(value='## Input time series')
        gr.Markdown(value='## Masked images')
        gr.Markdown(value='## Reconstructed images*')

    # One result row per timestamp; only the first is visible until files are uploaded.
    original = []
    masked = []
    predicted = []
    timestamps = []
    for t in range(4):
        column = gr.Column(visible=t == 0)
        timestamps.append(column)
        with column:
            with gr.Row():
                original.append(gr.Image(image_mode='RGB', show_label=False, show_fullscreen_button=False))
                masked.append(gr.Image(image_mode='RGB', show_label=False, show_fullscreen_button=False))
                predicted.append(gr.Image(image_mode='RGB', show_label=False, show_fullscreen_button=False))

    # Raw string: '\*' keeps the Markdown escape without an invalid Python escape sequence.
    gr.Markdown(value=r'\* The reconstructed images include the ground truth unmasked patches.')

    btn.click(fn=run_inference,
              inputs=inp_files,
              outputs=original + masked + predicted)

    with gr.Row():
        gr.Examples(
            # Each example row holds a single input value: the list of file paths.
            examples=[[[os.path.join(EXAMPLES_DIR, name) for name in series]]
                      for series in EXAMPLE_SERIES],
            inputs=inp_files,
            outputs=original + masked + predicted,
            fn=run_inference,
            cache_examples=True,
        )

    def update_visibility(files):
        """Show one result row per uploaded file.

        When the upload is cleared (``files`` is None or empty), only the first row
        stays visible, matching the initial layout. Previously ``len(None)`` raised
        a TypeError in that case.
        """
        num_visible = len(files) if files else 1
        return [gr.Column(visible=t < num_visible) for t in range(4)]

    inp_files.change(update_visibility, inp_files, timestamps)

demo.launch()
|
|