multimodalart HF staff commited on
Commit
1677fe8
1 Parent(s): 3d6f220

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -1
app.py CHANGED
@@ -152,6 +152,14 @@ def run_captioning(images, concept_sentence, *captions):
152
  if is_spaces:
153
  run_captioning = spaces.GPU()(run_captioning)
154
 
 
 
 
 
 
 
 
 
155
  def start_training(
156
  lora_name,
157
  concept_sentence,
@@ -167,6 +175,7 @@ def start_training(
167
  profile: Union[gr.OAuthProfile, None],
168
  oauth_token: Union[gr.OAuthToken, None],
169
  ):
 
170
  if not lora_name:
171
  raise gr.Error("You forgot to insert your LoRA name! This name has to be unique.")
172
 
@@ -218,7 +227,9 @@ def start_training(
218
  config["config"]["process"][0]["train"]["disable_sampling"] = True
219
 
220
  if(use_more_advanced_options):
221
- config["config"]["process"] = more_advanced_options
 
 
222
 
223
  # Save the updated config
224
  # generate a random name for the config
 
152
  if is_spaces:
153
  run_captioning = spaces.GPU()(run_captioning)
154
 
155
+ def recursive_update(d, u):
156
+ for k, v in u.items():
157
+ if isinstance(v, dict) and v:
158
+ d[k] = recursive_update(d.get(k, {}), v)
159
+ else:
160
+ d[k] = v
161
+ return d
162
+
163
  def start_training(
164
  lora_name,
165
  concept_sentence,
 
175
  profile: Union[gr.OAuthProfile, None],
176
  oauth_token: Union[gr.OAuthToken, None],
177
  ):
178
+
179
  if not lora_name:
180
  raise gr.Error("You forgot to insert your LoRA name! This name has to be unique.")
181
 
 
227
  config["config"]["process"][0]["train"]["disable_sampling"] = True
228
 
229
  if(use_more_advanced_options):
230
+ more_advanced_options_dict = yaml.safe_load(more_advanced_options)
231
+ config["config"]["process"][0] = recursive_update(config["config"]["process"][0], more_advanced_options_dict)
232
+ print(config)
233
 
234
  # Save the updated config
235
  # generate a random name for the config