blumenstiel committed on
Commit
540601d
·
1 Parent(s): 58a73c7

Switched to config.json

Browse files
Files changed (1) hide show
  1. app.py +18 -28
app.py CHANGED
@@ -10,8 +10,8 @@ from huggingface_hub import hf_hub_download
10
 
11
  # pull files from hub
12
  token = os.environ.get("HF_TOKEN", None)
13
- yaml_file_path = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL",
14
- filename="Prithvi_EO_V2_300M_TL_config.yaml", token=token)
15
  checkpoint = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL",
16
  filename='Prithvi_EO_V2_300M_TL.pt', token=token)
17
  model_def = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL",
@@ -67,7 +67,7 @@ def extract_rgb_imgs(input_img, rec_img, mask_img, channels, mean, std):
67
  return outputs
68
 
69
 
70
- def predict_on_images(data_files: list, yaml_file_path: str, checkpoint: str, mask_ratio: float = None):
71
  try:
72
  data_files = [x.name for x in data_files]
73
  print('Path extracted from example')
@@ -77,18 +77,17 @@ def predict_on_images(data_files: list, yaml_file_path: str, checkpoint: str, ma
77
  # Get parameters --------
78
  print('This is the printout', data_files)
79
 
80
- with open(yaml_file_path, 'r') as f:
81
- config = yaml.safe_load(f)
82
 
83
  batch_size = 8
84
- bands = config['DATA']['BANDS']
85
  num_frames = len(data_files)
86
- mean = config['DATA']['MEAN']
87
- std = config['DATA']['STD']
88
- coords_encoding = config['MODEL']['COORDS_ENCODING']
89
- img_size = config['DATA']['INPUT_SIZE'][-1]
90
-
91
- mask_ratio = mask_ratio or config['DATA']['MASK_RATIO']
92
 
93
  assert num_frames <= 4, "Demo only supports up to four timestamps"
94
 
@@ -110,21 +109,12 @@ def predict_on_images(data_files: list, yaml_file_path: str, checkpoint: str, ma
110
 
111
  # Create model and load checkpoint -------------------------------------------------------------
112
 
113
- model = PrithviMAE(img_size=config['DATA']['INPUT_SIZE'][-2:],
114
- patch_size=config['MODEL']['PATCH_SIZE'],
115
- num_frames=num_frames,
116
- in_chans=len(bands),
117
- embed_dim=config['MODEL']['EMBED_DIM'],
118
- depth=config['MODEL']['DEPTH'],
119
- num_heads=config['MODEL']['NUM_HEADS'],
120
- decoder_embed_dim=config['MODEL']['DECODER_EMBED_DIM'],
121
- decoder_depth=config['MODEL']['DECODER_DEPTH'],
122
- decoder_num_heads=config['MODEL']['DECODER_NUM_HEADS'],
123
- mlp_ratio=config['MODEL']['MLP_RATIO'],
124
- norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
125
- norm_pix_loss=config['MODEL']['NORM_PIX_LOSS'],
126
- coords_encoding=coords_encoding,
127
- coords_scale_learn=config['MODEL']['COORDS_SCALE_LEARN'])
128
 
129
  total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
130
  print(f"\n--> Model has {total_params:,} parameters.\n")
@@ -196,7 +186,7 @@ def predict_on_images(data_files: list, yaml_file_path: str, checkpoint: str, ma
196
  return outputs
197
 
198
 
199
- run_inference = partial(predict_on_images, yaml_file_path=yaml_file_path,checkpoint=checkpoint)
200
 
201
  with gr.Blocks() as demo:
202
 
 
10
 
11
  # pull files from hub
12
  token = os.environ.get("HF_TOKEN", None)
13
+ config_path = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL",
14
+ filename="config.json", token=token)
15
  checkpoint = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL",
16
  filename='Prithvi_EO_V2_300M_TL.pt', token=token)
17
  model_def = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL",
 
67
  return outputs
68
 
69
 
70
+ def predict_on_images(data_files: list, config_path: str, checkpoint: str, mask_ratio: float = None):
71
  try:
72
  data_files = [x.name for x in data_files]
73
  print('Path extracted from example')
 
77
  # Get parameters --------
78
  print('This is the printout', data_files)
79
 
80
+ with open(config_path, 'r') as f:
81
+ config = yaml.safe_load(f)['pretrained_cfg']
82
 
83
  batch_size = 8
84
+ bands = config['bands']
85
  num_frames = len(data_files)
86
+ mean = config['mean']
87
+ std = config['std']
88
+ coords_encoding = config['coords_encoding']
89
+ img_size = config['img_size']
90
+ mask_ratio = mask_ratio or config['mask_ratio']
 
91
 
92
  assert num_frames <= 4, "Demo only supports up to four timestamps"
93
 
 
109
 
110
  # Create model and load checkpoint -------------------------------------------------------------
111
 
112
+ config.update(
113
+ num_frames=num_frames,
114
+ coords_encoding=coords_encoding,
115
+ )
116
+
117
+ model = PrithviMAE(**config)
 
 
 
 
 
 
 
 
 
118
 
119
  total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
120
  print(f"\n--> Model has {total_params:,} parameters.\n")
 
186
  return outputs
187
 
188
 
189
+ run_inference = partial(predict_on_images, config_path=config_path,checkpoint=checkpoint)
190
 
191
  with gr.Blocks() as demo:
192