Wuvin commited on
Commit
e94e18f
1 Parent(s): 6d00af9
gradio_app/custom_models/mvimg_prediction.py CHANGED
@@ -15,6 +15,7 @@ checkpoint_path = "ckpt/img2mvimg/unet_state_dict.pth"
15
  trainer, pipeline = load_pipeline(training_config, checkpoint_path)
16
 
17
  def predict(img_list: List[Image.Image], guidance_scale=2., **kwargs):
 
18
  pipeline = pipeline.to("cuda")
19
  if isinstance(img_list, Image.Image):
20
  img_list = [img_list]
 
15
  trainer, pipeline = load_pipeline(training_config, checkpoint_path)
16
 
17
  def predict(img_list: List[Image.Image], guidance_scale=2., **kwargs):
18
+ global pipeline
19
  pipeline = pipeline.to("cuda")
20
  if isinstance(img_list, Image.Image):
21
  img_list = [img_list]
gradio_app/custom_models/normal_prediction.py CHANGED
@@ -10,6 +10,7 @@ checkpoint_path = "ckpt/image2normal/unet_state_dict.pth"
10
  trainer, pipeline = load_pipeline(training_config, checkpoint_path)
11
 
12
  def predict_normals(image: List[Image.Image], guidance_scale=2., do_rotate=True, num_inference_steps=30, **kwargs):
 
13
  pipeline = pipeline.to("cuda")
14
 
15
  img_list = image if isinstance(image, list) else [image]
 
10
  trainer, pipeline = load_pipeline(training_config, checkpoint_path)
11
 
12
  def predict_normals(image: List[Image.Image], guidance_scale=2., do_rotate=True, num_inference_steps=30, **kwargs):
13
+ global pipeline
14
  pipeline = pipeline.to("cuda")
15
 
16
  img_list = image if isinstance(image, list) else [image]