Remove 3 timestep restriction

#12
Files changed (2)
  1. Prithvi_run_inference.py +8 -4
  2. README.md +1 -1
Prithvi_run_inference.py CHANGED
```diff
@@ -252,7 +252,7 @@ def main(data_files: List[str], yaml_file_path: str, checkpoint: str, output_dir
         params = yaml.safe_load(f)
 
     # data related
-    num_frames = params['num_frames']
+    num_frames = len(data_files)
     img_size = params['img_size']
     bands = params['bands']
     mean = params['data_mean']
@@ -272,8 +272,9 @@ def main(data_files: List[str], yaml_file_path: str, checkpoint: str, output_dir
 
     mask_ratio = params['mask_ratio'] if mask_ratio is None else mask_ratio
 
-    # We must have *num_frames* files to build one example!
-    assert len(data_files) == num_frames, "File list must be equal to expected number of frames."
+    print(f"\nTreating {len(data_files)} files as {len(data_files)} time steps from the same location\n")
+    if len(data_files) != 3:
+        print("The original model was trained for 3 time steps (expecting 3 files).\nResults with a different number of time steps may vary.")
 
     if torch.cuda.is_available():
         device = torch.device('cuda')
@@ -310,7 +311,10 @@ def main(data_files: List[str], yaml_file_path: str, checkpoint: str, output_dir
     model.to(device)
 
     state_dict = torch.load(checkpoint, map_location=device)
-    model.load_state_dict(state_dict)
+    # discard the fixed positional embedding weights, whose shape depends on num_frames
+    del state_dict['pos_embed']
+    del state_dict['decoder_pos_embed']
+    model.load_state_dict(state_dict, strict=False)
     print(f"Loaded checkpoint from {checkpoint}")
 
     # Running model --------------------------------------------------------------------------------
```
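The deleted keys are what let the old `num_frames` assertion go: the fixed sin-cos positional embeddings are the checkpoint tensors whose shape depends on the number of input frames, so a 3-frame checkpoint would fail a strict load into a model built for a different `num_frames`. Dropping `pos_embed` and `decoder_pos_embed` and loading with `strict=False` keeps the model's freshly initialized embeddings for the requested frame count. A minimal sketch of this pattern, assuming a PyTorch model whose state dict contains those two keys (`load_checkpoint_flexible` is an illustrative helper name, not part of the repo):

```python
import torch


def load_checkpoint_flexible(model, checkpoint_path, device):
    """Illustrative sketch: load a Prithvi checkpoint into a model built
    for a different number of frames by discarding the positional
    embeddings, whose shape depends on num_frames."""
    state_dict = torch.load(checkpoint_path, map_location=device)
    # The checkpoint's sin-cos embeddings were generated for 3 frames;
    # drop them so the model's freshly initialized ones are kept.
    for key in ("pos_embed", "decoder_pos_embed"):
        state_dict.pop(key, None)
    # strict=False tolerates exactly those missing keys; any other
    # mismatch should still be surfaced loudly.
    result = model.load_state_dict(state_dict, strict=False)
    assert not result.unexpected_keys, result.unexpected_keys
    return model
```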
README.md CHANGED
````diff
@@ -33,7 +33,7 @@ The model follows the [original MAE repo](https://github.com/facebookresearch/ma
 4. adding infrared bands besides RGB
 
 ### Inference and demo
-There is an inference script (`Prithvi_run_inference.py`) that allows to run the image reconstruction on a set of three HLS images (see example below). These images have to be geotiff format, including the channels described above (Blue, Green, Red, Narrow NIR, SWIR 1, SWIR 2) in reflectance units. There is also a **demo** that leverages the same code [here](https://huggingface.co/spaces/ibm-nasa-geospatial/Prithvi-100M-demo).
+There is an inference script (`Prithvi_run_inference.py`) that allows running the image reconstruction on a set of HLS images assumed to be from the same location at different time steps (see example below). These should be provided in chronological order in geotiff format, including the channels described above (Blue, Green, Red, Narrow NIR, SWIR 1, SWIR 2) in reflectance units. There is also a **demo** that leverages the same code [here](https://huggingface.co/spaces/ibm-nasa-geospatial/Prithvi-100M-demo).
 
 ```
 python Prithvi_run_inference.py --data_files t1.tif t2.tif t3.tif --yaml_file_path /path/to/yaml/Prithvi_100.yaml --checkpoint /path/to/checkpoint/Prithvi_100.pth --output_dir /path/to/out/dir/ --mask_ratio 0.5
````
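With the restriction removed, the same command accepts any number of chronologically ordered files. An illustrative invocation with four time steps, reusing the README's placeholder paths (the script will print a note that the model was trained on 3 time steps):

```
python Prithvi_run_inference.py --data_files t1.tif t2.tif t3.tif t4.tif --yaml_file_path /path/to/yaml/Prithvi_100.yaml --checkpoint /path/to/checkpoint/Prithvi_100.pth --output_dir /path/to/out/dir/ --mask_ratio 0.5
```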