radames commited on
Commit
c962acd
1 Parent(s): 0e674dd

add midas depth and env for MAX_KEYFRAME

Browse files
Files changed (1) hide show
  1. app.py +20 -4
app.py CHANGED
@@ -18,6 +18,7 @@ from skimage import exposure
18
  import src.import_util # noqa: F401
19
  from ControlNet.annotator.canny import CannyDetector
20
  from ControlNet.annotator.hed import HEDdetector
 
21
  from ControlNet.annotator.util import HWC3
22
  from ControlNet.cldm.model import create_model, load_state_dict
23
  from gmflow_module.gmflow.gmflow import GMFlow
@@ -61,7 +62,7 @@ class ProcessingState(Enum):
61
  KEY_IMGS = 2
62
 
63
 
64
- MAX_KEYFRAME = 8
65
 
66
 
67
  class GlobalState:
@@ -111,6 +112,12 @@ class GlobalState:
111
  load_state_dict(huggingface_hub.hf_hub_download(
112
  'lllyasviel/ControlNet', 'models/control_sd15_canny.pth'),
113
  location=device))
 
 
 
 
 
 
114
  model.to(device)
115
  sd_model_path = model_dict[sd_model]
116
  if len(sd_model_path) > 0:
@@ -162,6 +169,15 @@ class GlobalState:
162
 
163
  self.detector = apply_canny
164
 
 
 
 
 
 
 
 
 
 
165
 
166
  global_state = GlobalState()
167
  global_video_path = None
@@ -716,7 +732,7 @@ with block:
716
  value=0,
717
  step=1)
718
  with gr.Row():
719
- control_type = gr.Dropdown(['HED', 'canny'],
720
  label='Control type',
721
  value='HED')
722
  low_threshold = gr.Slider(label='Canny low threshold',
@@ -756,14 +772,14 @@ with block:
756
  interval = gr.Slider(
757
  label='Key frame frequency (K)',
758
  minimum=1,
759
- maximum=1,
760
  value=1,
761
  step=1,
762
  info='Uniformly sample the key frames every K frames')
763
  keyframe_count = gr.Slider(
764
  label='Number of key frames',
765
  minimum=1,
766
- maximum=1,
767
  value=1,
768
  step=1,
769
  info='To avoid overload, maximum 8 key frames')
 
18
  import src.import_util # noqa: F401
19
  from ControlNet.annotator.canny import CannyDetector
20
  from ControlNet.annotator.hed import HEDdetector
21
+ from ControlNet.annotator.midas import MidasDetector
22
  from ControlNet.annotator.util import HWC3
23
  from ControlNet.cldm.model import create_model, load_state_dict
24
  from gmflow_module.gmflow.gmflow import GMFlow
 
62
  KEY_IMGS = 2
63
 
64
 
65
+ MAX_KEYFRAME = float(os.environ.get('MAX_KEYFRAME', 8))
66
 
67
 
68
  class GlobalState:
 
112
  load_state_dict(huggingface_hub.hf_hub_download(
113
  'lllyasviel/ControlNet', 'models/control_sd15_canny.pth'),
114
  location=device))
115
+ elif control_type == 'depth':
116
+ model.load_state_dict(
117
+ load_state_dict(huggingface_hub.hf_hub_download(
118
+ 'lllyasviel/ControlNet', 'models/control_sd15_depth.pth'),
119
+ location=device))
120
+
121
  model.to(device)
122
  sd_model_path = model_dict[sd_model]
123
  if len(sd_model_path) > 0:
 
169
 
170
  self.detector = apply_canny
171
 
172
+ elif control_type == 'depth':
173
+ midas = MidasDetector()
174
+
175
+ def apply_midas(x):
176
+ detected_map, _ = midas(x)
177
+ return detected_map
178
+
179
+ self.detector = apply_midas
180
+
181
 
182
  global_state = GlobalState()
183
  global_video_path = None
 
732
  value=0,
733
  step=1)
734
  with gr.Row():
735
+ control_type = gr.Dropdown(['HED', 'canny', 'depth'],
736
  label='Control type',
737
  value='HED')
738
  low_threshold = gr.Slider(label='Canny low threshold',
 
772
  interval = gr.Slider(
773
  label='Key frame frequency (K)',
774
  minimum=1,
775
+ maximum=MAX_KEYFRAME,
776
  value=1,
777
  step=1,
778
  info='Uniformly sample the key frames every K frames')
779
  keyframe_count = gr.Slider(
780
  label='Number of key frames',
781
  minimum=1,
782
+ maximum=MAX_KEYFRAME,
783
  value=1,
784
  step=1,
785
  info='To avoid overload, maximum 8 key frames')