Remove 3 timestep restriction
#12
by
carlosgomes98
- opened
- Prithvi_run_inference.py +8 -4
- README.md +1 -1
Prithvi_run_inference.py
CHANGED
@@ -252,7 +252,7 @@ def main(data_files: List[str], yaml_file_path: str, checkpoint: str, output_dir
|
|
252 |
params = yaml.safe_load(f)
|
253 |
|
254 |
# data related
|
255 |
-
num_frames =
|
256 |
img_size = params['img_size']
|
257 |
bands = params['bands']
|
258 |
mean = params['data_mean']
|
@@ -272,8 +272,9 @@ def main(data_files: List[str], yaml_file_path: str, checkpoint: str, output_dir
|
|
272 |
|
273 |
mask_ratio = params['mask_ratio'] if mask_ratio is None else mask_ratio
|
274 |
|
275 |
-
|
276 |
-
|
|
|
277 |
|
278 |
if torch.cuda.is_available():
|
279 |
device = torch.device('cuda')
|
@@ -310,7 +311,10 @@ def main(data_files: List[str], yaml_file_path: str, checkpoint: str, output_dir
|
|
310 |
model.to(device)
|
311 |
|
312 |
state_dict = torch.load(checkpoint, map_location=device)
|
313 |
-
|
|
|
|
|
|
|
314 |
print(f"Loaded checkpoint from {checkpoint}")
|
315 |
|
316 |
# Running model --------------------------------------------------------------------------------
|
|
|
252 |
params = yaml.safe_load(f)
|
253 |
|
254 |
# data related
|
255 |
+
num_frames = len(data_files)
|
256 |
img_size = params['img_size']
|
257 |
bands = params['bands']
|
258 |
mean = params['data_mean']
|
|
|
272 |
|
273 |
mask_ratio = params['mask_ratio'] if mask_ratio is None else mask_ratio
|
274 |
|
275 |
+
print(f"\nTreating {len(data_files)} files as {len(data_files)} time steps from the same location\n")
|
276 |
+
if len(data_files) != 3:
|
277 |
+
print("The original model was trained for 3 time steps (expecting 3 files). \nResults with different numbers of timesteps may vary")
|
278 |
|
279 |
if torch.cuda.is_available():
|
280 |
device = torch.device('cuda')
|
|
|
311 |
model.to(device)
|
312 |
|
313 |
state_dict = torch.load(checkpoint, map_location=device)
|
314 |
+
# discard fixed pos_embedding weight
|
315 |
+
del state_dict['pos_embed']
|
316 |
+
del state_dict['decoder_pos_embed']
|
317 |
+
model.load_state_dict(state_dict, strict=False)
|
318 |
print(f"Loaded checkpoint from {checkpoint}")
|
319 |
|
320 |
# Running model --------------------------------------------------------------------------------
|
README.md
CHANGED
@@ -33,7 +33,7 @@ The model follows the [original MAE repo](https://github.com/facebookresearch/ma
|
|
33 |
4. adding infrared bands besides RGB
|
34 |
|
35 |
### Inference and demo
|
36 |
-
There is an inference script (`Prithvi_run_inference.py`) that allows to run the image reconstruction on a set of
|
37 |
|
38 |
```
|
39 |
python Prithvi_run_inference.py --data_files t1.tif t2.tif t3.tif --yaml_file_path /path/to/yaml/Prithvi_100.yaml --checkpoint /path/to/checkpoint/Prithvi_100.pth --output_dir /path/to/out/dir/ --mask_ratio 0.5
|
|
|
33 |
4. adding infrared bands besides RGB
|
34 |
|
35 |
### Inference and demo
|
36 |
+
There is an inference script (`Prithvi_run_inference.py`) that allows you to run the image reconstruction on a set of HLS images assumed to be from the same location at different time steps (see example below). These should be provided in chronological order in geotiff format, including the channels described above (Blue, Green, Red, Narrow NIR, SWIR 1, SWIR 2) in reflectance units. There is also a **demo** that leverages the same code [here](https://huggingface.co/spaces/ibm-nasa-geospatial/Prithvi-100M-demo).
|
37 |
|
38 |
```
|
39 |
python Prithvi_run_inference.py --data_files t1.tif t2.tif t3.tif --yaml_file_path /path/to/yaml/Prithvi_100.yaml --checkpoint /path/to/checkpoint/Prithvi_100.pth --output_dir /path/to/out/dir/ --mask_ratio 0.5
|