RamAnanth1 commited on
Commit
a976ca4
1 Parent(s): 2bb71d8

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +9 -9
model.py CHANGED
@@ -90,15 +90,15 @@ class Model:
90
  @torch.inference_mode()
91
  def process_sketch(self, input_img, type_in, color_back, prompt, neg_prompt, fix_sample, scale, con_strength, base_model):
92
  global current_base
93
- if current_base != base_model:
94
- ckpt = os.path.join("models", base_model)
95
- pl_sd = torch.load(ckpt, map_location="cpu")
96
- if "state_dict" in pl_sd:
97
- sd = pl_sd["state_dict"]
98
- else:
99
- sd = pl_sd
100
- model.load_state_dict(sd, strict=False) #load_model_from_config(config, os.path.join("models", base_model)).to(device)
101
- current_base = base_model
102
  con_strength = int((1-con_strength)*50)
103
  if fix_sample == 'True':
104
  seed_everything(42)
 
90
  @torch.inference_mode()
91
  def process_sketch(self, input_img, type_in, color_back, prompt, neg_prompt, fix_sample, scale, con_strength, base_model):
92
  global current_base
93
+ # if current_base != base_model:
94
+ # ckpt = os.path.join("models", base_model)
95
+ # pl_sd = torch.load(ckpt, map_location="cpu")
96
+ # if "state_dict" in pl_sd:
97
+ # sd = pl_sd["state_dict"]
98
+ # else:
99
+ # sd = pl_sd
100
+ # model.load_state_dict(sd, strict=False) #load_model_from_config(config, os.path.join("models", base_model)).to(device)
101
+ # current_base = base_model
102
  con_strength = int((1-con_strength)*50)
103
  if fix_sample == 'True':
104
  seed_everything(42)