import logging
import os

from huggingface_hub import hf_hub_download

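# Download the model config, weights, and the reference model/inference code from the
# Hugging Face Hub; a user access token, if present in the "token" environment variable,
# is forwarded to hf_hub_download. The two Python modules are then copied into the
# working directory so they can be imported below.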
yaml_file_path = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL",
                                 filename="Prithvi_EO_V2_300M_TL_config.yaml", token=os.environ.get("token"))
checkpoint = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL",
                             filename='Prithvi_EO_V2_300M_TL.pt', token=os.environ.get("token"))
model_def = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL",
                            filename='prithvi_mae.py', token=os.environ.get("token"))
model_inference = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL",
                                  filename='inference.py', token=os.environ.get("token"))
os.system(f'cp {model_def} .')
os.system(f'cp {model_inference} .')

import torch
import yaml
import numpy as np
import gradio as gr
from einops import rearrange
from functools import partial

from prithvi_mae import PrithviMAE
from inference import process_channel_group, read_geotiff, save_geotiff, _convert_np_uint8, load_example, run_model

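# No-data sentinel values for HLS reflectance data and the percentiles used for contrast
# stretching of the RGB previews.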
NO_DATA = -9999
NO_DATA_FLOAT = 0.0001
PERCENTILES = (0.1, 99.9)


def extract_rgb_imgs(input_img, rec_img, mask_img, channels, mean, std):
    """ Extract RGB images (original, masked, reconstructed) for each timestamp.

    Args:
        input_img: input torch.Tensor with shape (C, T, H, W).
        rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
        mask_img: mask torch.Tensor with shape (C, T, H, W).
        channels: list of indices representing RGB channels.
        mean: list of mean values for each band.
        std: list of std values for each band.

    Returns:
        list of uint8 RGB arrays with shape (H, W, 3): originals, then masked, then reconstructed.
    """
    rgb_orig_list = []
    rgb_mask_list = []
    rgb_pred_list = []

    for t in range(input_img.shape[1]):
        rgb_orig, rgb_pred = process_channel_group(orig_img=input_img[:, t, :, :],
                                                   new_img=rec_img[:, t, :, :],
                                                   channels=channels,
                                                   mean=mean,
                                                   std=std)

        rgb_mask = mask_img[channels, t, :, :] * rgb_orig

        rgb_orig_list.append(_convert_np_uint8(rgb_orig).transpose(1, 2, 0))
        rgb_mask_list.append(_convert_np_uint8(rgb_mask).transpose(1, 2, 0))
        rgb_pred_list.append(_convert_np_uint8(rgb_pred).transpose(1, 2, 0))

    outputs = rgb_orig_list + rgb_mask_list + rgb_pred_list

    return outputs


def predict_on_images(data_files: list, mask_ratio: float = None, *, yaml_file_path: str, checkpoint: str):

    try:
        data_files = [x.name for x in data_files]
        print('Path extracted from example')
    except (AttributeError, TypeError):
        print('Files submitted through UI')

    print('Input files:', data_files)

    with open(yaml_file_path, 'r') as f:
        config = yaml.safe_load(f)

    batch_size = 8
    bands = config['DATA']['BANDS']
    num_frames = len(data_files)
    mean = config['DATA']['MEAN']
    std = config['DATA']['STD']
    coords_encoding = config['MODEL']['COORDS_ENCODING']
    img_size = config['DATA']['INPUT_SIZE'][-1]

    mask_ratio = mask_ratio or config['DATA']['MASK_RATIO']

    if num_frames > 4:
        logging.warning("The model was trained with only four timestamps.")

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    print(f"Using {device} device.\n")

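    # Load the GeoTIFF stack, normalise it with the band statistics, and read the temporal
    # and location metadata; the matching coordinate encodings are disabled below when that
    # metadata is missing.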
    input_data, temporal_coords, location_coords, meta_data = load_example(file_paths=data_files, mean=mean, std=std)

    if len(temporal_coords) != num_frames and 'time' in coords_encoding:
        coords_encoding.pop('time')
    if not len(location_coords) and 'location' in coords_encoding:
        coords_encoding.pop('location')

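    # Build the MAE from the config; num_frames follows the number of uploaded images,
    # so it can differ from the four-timestamp pretraining setup.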
    model = PrithviMAE(img_size=config['DATA']['INPUT_SIZE'][-2:],
                       patch_size=config['MODEL']['PATCH_SIZE'],
                       num_frames=num_frames,
                       in_chans=len(bands),
                       embed_dim=config['MODEL']['EMBED_DIM'],
                       depth=config['MODEL']['DEPTH'],
                       num_heads=config['MODEL']['NUM_HEADS'],
                       decoder_embed_dim=config['MODEL']['DECODER_EMBED_DIM'],
                       decoder_depth=config['MODEL']['DECODER_DEPTH'],
                       decoder_num_heads=config['MODEL']['DECODER_NUM_HEADS'],
                       mlp_ratio=config['MODEL']['MLP_RATIO'],
                       norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                       norm_pix_loss=config['MODEL']['NORM_PIX_LOSS'],
                       coords_encoding=coords_encoding,
                       coords_scale_learn=config['MODEL']['COORDS_SCALE_LEARN'])

    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\n--> Model has {total_params:,} parameters.\n")

    model.to(device)

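    # Positional embeddings in the checkpoint are sized for the pretraining input shape,
    # so they are dropped and the freshly initialised ones are kept (hence strict=False).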
    state_dict = torch.load(checkpoint, map_location=device, weights_only=False)

    for k in list(state_dict.keys()):
        if 'pos_embed' in k:
            del state_dict[k]
    model.load_state_dict(state_dict, strict=False)
    print(f"Loaded checkpoint from {checkpoint}")

    model.eval()

    # HLS bands B04/B03/B02 correspond to Red/Green/Blue and are used for the RGB previews.
    channels = [bands.index(b) for b in ['B04', 'B03', 'B02']]

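    # Reflect-pad height and width up to a multiple of the model input size so the scene tiles evenly.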
    original_h, original_w = input_data.shape[-2:]
    pad_h = (img_size - original_h % img_size) % img_size
    pad_w = (img_size - original_w % img_size) % img_size
    input_data = np.pad(input_data, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode='reflect')

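    # Tile the padded scene into non-overlapping img_size x img_size windows and stack them as a batch.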
    batch = torch.tensor(input_data, device='cpu')
    windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
    h1, w1 = windows.shape[3:5]
    windows = rearrange(windows, 'b c t h1 w1 h w -> (b h1 w1) c t h w', h=img_size, w=img_size)

    # Split the windows into roughly batch_size-sized mini-batches.
    num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
    windows = torch.tensor_split(windows, num_batches, dim=0)

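    # Run the MAE reconstruction on each mini-batch; every window reuses the scene-level
    # temporal and location coordinates.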
    rec_imgs = []
    mask_imgs = []
    for x in windows:
        temp_coords = torch.Tensor([temporal_coords] * len(x))
        loc_coords = torch.Tensor([location_coords[0]] * len(x))
        rec_img, mask_img = run_model(model, x, temp_coords, loc_coords, mask_ratio, device)
        rec_imgs.append(rec_img)
        mask_imgs.append(mask_img)

    rec_imgs = torch.concat(rec_imgs, dim=0)
    mask_imgs = torch.concat(mask_imgs, dim=0)

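    # Stitch the reconstructed and mask windows back into the full padded scene, then crop
    # off the padding to recover the original spatial size.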
    rec_imgs = rearrange(rec_imgs, '(b h1 w1) c t h w -> b c t (h1 h) (w1 w)',
                         h=img_size, w=img_size, b=1, c=len(bands), t=num_frames, h1=h1, w1=w1)
    mask_imgs = rearrange(mask_imgs, '(b h1 w1) c t h w -> b c t (h1 h) (w1 w)',
                          h=img_size, w=img_size, b=1, c=len(bands), t=num_frames, h1=h1, w1=w1)

    rec_imgs_full = rec_imgs[..., :original_h, :original_w]
    mask_imgs_full = mask_imgs[..., :original_h, :original_w]
    batch_full = batch[..., :original_h, :original_w]

    for d in meta_data:
        d.update(count=3, dtype='uint8', compress='lzw', nodata=0)

    outputs = extract_rgb_imgs(batch_full[0, ...], rec_imgs_full[0, ...], mask_imgs_full[0, ...],
                               channels, mean, std)

    print("Done!")

    return outputs

func = partial(predict_on_images, yaml_file_path=yaml_file_path, checkpoint=checkpoint)
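
# A minimal usage sketch (not executed by the demo): calling the prediction function directly
# on local HLS GeoTIFFs, one file per timestamp. The file names below are hypothetical placeholders.
#
#   example_files = ["HLS_t1.tif", "HLS_t2.tif", "HLS_t3.tif"]
#   rgb_outputs = func(example_files, mask_ratio=0.75)
#   # -> 3 * len(example_files) uint8 RGB arrays: originals, masked inputs, reconstructions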

with gr.Blocks() as demo:

    gr.Markdown(value='# Prithvi-EO-2.0 image reconstruction demo')
    gr.Markdown(value='''
Prithvi-EO-2.0 is the second-generation EO foundation model developed by the IBM and NASA team.
The temporal ViT was trained on 4.2M Harmonized Landsat Sentinel-2 (HLS) samples with four timestamps each, using the Masked AutoEncoder learning strategy.
The model includes spatial and temporal attention across multiple patches and timestamps.
Additionally, temporal and location information is added to the model input via embeddings.
More info about the model is available [here](https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL).\n

This demo showcases image reconstruction over one to four timestamps.
The model randomly masks out a proportion of each image and then reconstructs it based on the unmasked portion.\n
The user needs to provide HLS GeoTIFF images including the following channels in reflectance units: Blue, Green, Red, Narrow NIR, SWIR 1, SWIR 2.
The location information is extracted from the tif files, while the temporal information can optionally be provided in the filename in the format `<date>T<time>` or `<year><julian day>T<time>` (HLS format).
We recommend submitting images of 224 to 1000 pixels per side for faster processing. Some example images are provided at the end of this page.
''')
    with gr.Row():
        with gr.Column():
            inp_files = gr.Files(elem_id='files')

            btn = gr.Button("Submit")

    with gr.Row():
        gr.Markdown(value='## Original images')
    with gr.Row():
        gr.Markdown(value='T1')
        gr.Markdown(value='T2')
        gr.Markdown(value='T3')
    with gr.Row():
        out1_orig_t1 = gr.Image(image_mode='RGB')
        out2_orig_t2 = gr.Image(image_mode='RGB')
        out3_orig_t3 = gr.Image(image_mode='RGB')

    with gr.Row():
        gr.Markdown(value='## Masked images')
    with gr.Row():
        gr.Markdown(value='T1')
        gr.Markdown(value='T2')
        gr.Markdown(value='T3')
    with gr.Row():
        out4_masked_t1 = gr.Image(image_mode='RGB')
        out5_masked_t2 = gr.Image(image_mode='RGB')
        out6_masked_t3 = gr.Image(image_mode='RGB')

    with gr.Row():
        gr.Markdown(value='## Reconstructed images')
    with gr.Row():
        gr.Markdown(value='T1')
        gr.Markdown(value='T2')
        gr.Markdown(value='T3')
    with gr.Row():
        out7_pred_t1 = gr.Image(image_mode='RGB')
        out8_pred_t2 = gr.Image(image_mode='RGB')
        out9_pred_t3 = gr.Image(image_mode='RGB')

    btn.click(fn=func,
              inputs=inp_files,
              outputs=[out1_orig_t1,
                       out2_orig_t2,
                       out3_orig_t3,
                       out4_masked_t1,
                       out5_masked_t2,
                       out6_masked_t3,
                       out7_pred_t1,
                       out8_pred_t2,
                       out9_pred_t3])

    with gr.Row():
        gr.Examples(examples=[[[os.path.join(os.path.dirname(__file__), "examples/HLS.L30.T13REN.2018013T172747.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif"),
                                os.path.join(os.path.dirname(__file__), "examples/HLS.L30.T13REN.2018029T172738.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif"),
                                os.path.join(os.path.dirname(__file__), "examples/HLS.L30.T13REN.2018061T172724.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif")]],
                              [[os.path.join(os.path.dirname(__file__), "examples/HLS.L30.T17RMP.2018004T155509.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif"),
                                os.path.join(os.path.dirname(__file__), "examples/HLS.L30.T17RMP.2018036T155452.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif"),
                                os.path.join(os.path.dirname(__file__), "examples/HLS.L30.T17RMP.2018068T155438.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif")]],
                              [[os.path.join(os.path.dirname(__file__), "examples/HLS.L30.T18TVL.2018029T154533.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif"),
                                os.path.join(os.path.dirname(__file__), "examples/HLS.L30.T18TVL.2018141T154435.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif"),
                                os.path.join(os.path.dirname(__file__), "examples/HLS.L30.T18TVL.2018189T154446.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif")]]],
                    inputs=inp_files,
                    outputs=[out1_orig_t1,
                             out2_orig_t2,
                             out3_orig_t3,
                             out4_masked_t1,
                             out5_masked_t2,
                             out6_masked_t3,
                             out7_pred_t1,
                             out8_pred_t2,
                             out9_pred_t3],
                    fn=func,
                    cache_examples=True)


demo.launch()