hysts HF staff commited on
Commit
faa4cbd
1 Parent(s): 4af6431

Stop using cpu_offload

Browse files
Files changed (1) hide show
  1. model.py +23 -13
model.py CHANGED
@@ -2,6 +2,7 @@
2
  # The original license file is LICENSE.ControlNet in this repo.
3
  from __future__ import annotations
4
 
 
5
  import pathlib
6
  import sys
7
 
@@ -24,7 +25,6 @@ from annotator.mlsd import apply_mlsd
24
  from annotator.openpose import apply_openpose
25
  from annotator.uniformer import apply_uniformer
26
  from annotator.util import HWC3, resize_image
27
- from share import *
28
 
29
  CONTROLNET_MODEL_IDS = {
30
  'canny': 'lllyasviel/sd-controlnet-canny',
@@ -47,6 +47,8 @@ class Model:
47
  def __init__(self,
48
  base_model_id: str = 'runwayml/stable-diffusion-v1-5',
49
  task_name: str = 'canny'):
 
 
50
  self.base_model_id = ''
51
  self.task_name = ''
52
  self.pipe = self.load_pipe(base_model_id, task_name)
@@ -55,33 +57,41 @@ class Model:
55
  if base_model_id == self.base_model_id and task_name == self.task_name:
56
  return self.pipe
57
  model_id = CONTROLNET_MODEL_IDS[task_name]
58
- controlnet = ControlNetModel.from_pretrained(model_id,
59
- torch_dtype=torch.float16)
60
  pipe = StableDiffusionControlNetPipeline.from_pretrained(
61
- base_model_id,
62
- safety_checker=None,
63
- controlnet=controlnet,
64
- torch_dtype=torch.float16)
65
  pipe.scheduler = UniPCMultistepScheduler.from_config(
66
  pipe.scheduler.config)
67
  pipe.enable_xformers_memory_efficient_attention()
68
- pipe.enable_model_cpu_offload()
 
 
69
  self.base_model_id = base_model_id
70
  self.task_name = task_name
71
  return pipe
72
 
73
  def set_base_model(self, base_model_id: str) -> str:
74
- self.pipe = self.load_pipe(base_model_id, self.task_name)
 
 
 
 
 
 
 
 
75
  return self.base_model_id
76
 
77
  def load_controlnet_weight(self, task_name: str) -> None:
78
  if task_name == self.task_name:
79
  return
 
 
 
80
  model_id = CONTROLNET_MODEL_IDS[task_name]
81
- controlnet = ControlNetModel.from_pretrained(model_id,
82
- torch_dtype=torch.float16)
83
- from accelerate import cpu_offload_with_hook
84
- cpu_offload_with_hook(controlnet, torch.device('cuda:0'))
85
  self.pipe.controlnet = controlnet
86
  self.task_name = task_name
87
 
 
2
  # The original license file is LICENSE.ControlNet in this repo.
3
  from __future__ import annotations
4
 
5
+ import gc
6
  import pathlib
7
  import sys
8
 
 
25
  from annotator.openpose import apply_openpose
26
  from annotator.uniformer import apply_uniformer
27
  from annotator.util import HWC3, resize_image
 
28
 
29
  CONTROLNET_MODEL_IDS = {
30
  'canny': 'lllyasviel/sd-controlnet-canny',
 
47
  def __init__(self,
48
  base_model_id: str = 'runwayml/stable-diffusion-v1-5',
49
  task_name: str = 'canny'):
50
+ self.device = torch.device(
51
+ 'cuda:0' if torch.cuda.is_available() else 'cpu')
52
  self.base_model_id = ''
53
  self.task_name = ''
54
  self.pipe = self.load_pipe(base_model_id, task_name)
 
57
  if base_model_id == self.base_model_id and task_name == self.task_name:
58
  return self.pipe
59
  model_id = CONTROLNET_MODEL_IDS[task_name]
60
+ controlnet = ControlNetModel.from_pretrained(model_id)
 
61
  pipe = StableDiffusionControlNetPipeline.from_pretrained(
62
+ base_model_id, safety_checker=None, controlnet=controlnet)
 
 
 
63
  pipe.scheduler = UniPCMultistepScheduler.from_config(
64
  pipe.scheduler.config)
65
  pipe.enable_xformers_memory_efficient_attention()
66
+ pipe.to(self.device)
67
+ torch.cuda.empty_cache()
68
+ gc.collect()
69
  self.base_model_id = base_model_id
70
  self.task_name = task_name
71
  return pipe
72
 
73
  def set_base_model(self, base_model_id: str) -> str:
74
+ if not base_model_id or base_model_id == self.base_model_id:
75
+ return self.base_model_id
76
+ del self.pipe
77
+ torch.cuda.empty_cache()
78
+ gc.collect()
79
+ try:
80
+ self.pipe = self.load_pipe(base_model_id, self.task_name)
81
+ except Exception:
82
+ self.pipe = self.load_pipe(self.base_model_id, self.task_name)
83
  return self.base_model_id
84
 
85
  def load_controlnet_weight(self, task_name: str) -> None:
86
  if task_name == self.task_name:
87
  return
88
+ del self.pipe.controlnet
89
+ torch.cuda.empty_cache()
90
+ gc.collect()
91
  model_id = CONTROLNET_MODEL_IDS[task_name]
92
+ controlnet = ControlNetModel.from_pretrained(model_id).to(self.device)
93
+ torch.cuda.empty_cache()
94
+ gc.collect()
 
95
  self.pipe.controlnet = controlnet
96
  self.task_name = task_name
97