Eugeoter commited on
Commit
32f1452
1 Parent(s): aa88621
Files changed (1) hide show
  1. app.py +19 -22
app.py CHANGED
@@ -6,8 +6,10 @@ from PIL import Image
6
  from huggingface_hub import hf_hub_download
7
  from utils import utils, tools, preprocess
8
 
 
 
9
  VAE_PATH = "madebyollin/sdxl-vae-fp16-fix"
10
- REPO_ID = "Pbihao/ControlNeXt"
11
  UNET_FILENAME = "ControlAny-SDXL/anime_canny/unet.safetensors"
12
  CONTROLNET_FILENAME = "ControlAny-SDXL/anime_canny/controlnet.safetensors"
13
  CACHE_DIR = None
@@ -19,17 +21,17 @@ DEFAULT_NEGATIVE_PROMPT = "worst quality, abstract, clumsy pose, deformed hand,
19
  def ui():
20
  device = "cuda" if torch.cuda.is_available() else "cpu"
21
  model_file = hf_hub_download(
22
- repo_id='neta-art/neta-xl-2.0',
23
- filename='neta-xl-v2.fp16.safetensors',
24
  cache_dir=CACHE_DIR,
25
  )
26
  unet_file = hf_hub_download(
27
- repo_id=REPO_ID,
28
  filename=UNET_FILENAME,
29
  cache_dir=CACHE_DIR,
30
  )
31
  controlnet_file = hf_hub_download(
32
- repo_id=REPO_ID,
33
  filename=CONTROLNET_FILENAME,
34
  cache_dir=CACHE_DIR,
35
  )
@@ -45,7 +47,6 @@ def ui():
45
  use_safetensors=True,
46
  )
47
 
48
- preprocessors = ['canny']
49
  schedulers = ['Euler A', 'UniPC', 'Euler', 'DDIM', 'DDPM']
50
 
51
  css = """
@@ -57,7 +58,7 @@ def ui():
57
 
58
  with gr.Blocks(css=css) as demo:
59
  gr.Markdown(f"""
60
- # [ControlNeXt-SDXL](https://github.com/dvlab-research/ControlNeXt) Demo
61
  Base model: [Neta-Art-XL-2.0](https://civitai.com/models/410737/neta-art-xl)
62
  """)
63
  with gr.Row():
@@ -80,13 +81,11 @@ def ui():
80
  show_download_button=True,
81
  show_share_button=True,
82
  )
83
- with gr.Row():
84
- processor = gr.Dropdown(
85
- label='Image Preprocessor',
86
- choices=preprocessors,
87
- value='canny',
88
- )
89
- process_button = gr.Button("Process", variant='primary', min_width=96, scale=0)
90
  with gr.Row():
91
  scheduler = gr.Dropdown(
92
  label='Scheduler',
@@ -142,7 +141,6 @@ def ui():
142
  seed,
143
  ):
144
  pipeline.scheduler = tools.get_scheduler(scheduler, pipeline.scheduler.config)
145
-
146
  generator = torch.Generator(device=device).manual_seed(max(0, min(seed, np.iinfo(np.int32).max))) if seed != -1 else None
147
 
148
  if control_image is None:
@@ -167,13 +165,12 @@ def ui():
167
 
168
  def process(
169
  image,
170
- processor,
 
171
  ):
172
- if image is None:
173
- raise gr.Error('Please upload an image.')
174
- processor = preprocess.get_extractor(processor)
175
- image = processor(image)
176
- return image
177
 
178
  generate_button.click(
179
  fn=generate,
@@ -183,7 +180,7 @@ def ui():
183
 
184
  process_button.click(
185
  fn=process,
186
- inputs=[control_image, processor],
187
  outputs=[control_image],
188
  )
189
 
 
6
  from huggingface_hub import hf_hub_download
7
  from utils import utils, tools, preprocess
8
 
9
+ BASE_MODEL_REPO_ID = "neta-art/neta-xl-2.0"
10
+ BASE_MODEL_FILENAME = "neta-xl-v2.fp16.safetensors"
11
  VAE_PATH = "madebyollin/sdxl-vae-fp16-fix"
12
+ CONTROLNEXT_REPO_ID = "Pbihao/ControlNeXt"
13
  UNET_FILENAME = "ControlAny-SDXL/anime_canny/unet.safetensors"
14
  CONTROLNET_FILENAME = "ControlAny-SDXL/anime_canny/controlnet.safetensors"
15
  CACHE_DIR = None
 
21
  def ui():
22
  device = "cuda" if torch.cuda.is_available() else "cpu"
23
  model_file = hf_hub_download(
24
+ repo_id=BASE_MODEL_REPO_ID,
25
+ filename=BASE_MODEL_FILENAME,
26
  cache_dir=CACHE_DIR,
27
  )
28
  unet_file = hf_hub_download(
29
+ repo_id=CONTROLNEXT_REPO_ID,
30
  filename=UNET_FILENAME,
31
  cache_dir=CACHE_DIR,
32
  )
33
  controlnet_file = hf_hub_download(
34
+ repo_id=CONTROLNEXT_REPO_ID,
35
  filename=CONTROLNET_FILENAME,
36
  cache_dir=CACHE_DIR,
37
  )
 
47
  use_safetensors=True,
48
  )
49
 
 
50
  schedulers = ['Euler A', 'UniPC', 'Euler', 'DDIM', 'DDPM']
51
 
52
  css = """
 
58
 
59
  with gr.Blocks(css=css) as demo:
60
  gr.Markdown(f"""
61
+ # [ControlNeXt-SDXL](https://github.com/dvlab-research/ControlNeXt) Demo (Anime Canny)
62
  Base model: [Neta-Art-XL-2.0](https://civitai.com/models/410737/neta-art-xl)
63
  """)
64
  with gr.Row():
 
81
  show_download_button=True,
82
  show_share_button=True,
83
  )
84
+ with gr.Accordion(label='Preprocess', open=True):
85
+ with gr.Row():
86
+ threshold1 = gr.Slider(minimum=-1, maximum=255, step=1, value=100, label='Threshold 1', info='-1 for auto')
87
+ threshold2 = gr.Slider(minimum=-1, maximum=255, step=1, value=200, label='Threshold 2', info='-1 for auto')
88
+ process_button = gr.Button("Process", variant='primary', min_width=96, scale=0)
 
 
89
  with gr.Row():
90
  scheduler = gr.Dropdown(
91
  label='Scheduler',
 
141
  seed,
142
  ):
143
  pipeline.scheduler = tools.get_scheduler(scheduler, pipeline.scheduler.config)
 
144
  generator = torch.Generator(device=device).manual_seed(max(0, min(seed, np.iinfo(np.int32).max))) if seed != -1 else None
145
 
146
  if control_image is None:
 
165
 
166
  def process(
167
  image,
168
+ threshold1,
169
+ threshold2,
170
  ):
171
+ threshold1 = None if threshold1 == -1 else threshold1
172
+ threshold2 = None if threshold2 == -1 else threshold2
173
+ return preprocess.canny_extractor(image, threshold1, threshold2)
 
 
174
 
175
  generate_button.click(
176
  fn=generate,
 
180
 
181
  process_button.click(
182
  fn=process,
183
+ inputs=[control_image, threshold1, threshold2],
184
  outputs=[control_image],
185
  )
186