diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000000000000000000000000000000000000..105ce2da2d6447d11dfe32bfb846c3d5b199fc99
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
\ No newline at end of file
diff --git a/ControlNet/annotator/canny/__init__.py b/ControlNet/annotator/canny/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb0da951dc838ec9dec2131007e036113281800b
--- /dev/null
+++ b/ControlNet/annotator/canny/__init__.py
@@ -0,0 +1,6 @@
+import cv2
+class CannyDetector:
+ """Callable wrapper that runs OpenCV's Canny edge detector on an image."""
+ def __call__(self, img, low_threshold, high_threshold):
+ # The image and both thresholds are forwarded to cv2.Canny unchanged;
+ # whatever cv2.Canny returns is returned as-is.
+ return cv2.Canny(img, low_threshold, high_threshold)
diff --git a/ControlNet/annotator/ckpts/body_pose_model.pth b/ControlNet/annotator/ckpts/body_pose_model.pth
new file mode 100644
index 0000000000000000000000000000000000000000..9acb77e68f31906a8875f1daef2f3f7ef94acb1e
--- /dev/null
+++ b/ControlNet/annotator/ckpts/body_pose_model.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:25a948c16078b0f08e236bda51a385d855ef4c153598947c28c0d47ed94bb746
+size 209267595
diff --git a/ControlNet/annotator/ckpts/ckpts.txt b/ControlNet/annotator/ckpts/ckpts.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1978551fb2a9226814eaf58459f414fcfac4e69b
--- /dev/null
+++ b/ControlNet/annotator/ckpts/ckpts.txt
@@ -0,0 +1 @@
+Weights here.
\ No newline at end of file
diff --git a/ControlNet/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt b/ControlNet/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt
new file mode 100644
index 0000000000000000000000000000000000000000..a54fd8ca8d59181d9343d79eb3f6deb6c5319eba
--- /dev/null
+++ b/ControlNet/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:501f0c75b3bca7daec6b3682c5054c09b366765aef6fa3a09d03a5cb4b230853
+size 492757791
diff --git a/ControlNet/annotator/ckpts/hand_pose_model.pth b/ControlNet/annotator/ckpts/hand_pose_model.pth
new file mode 100644
index 0000000000000000000000000000000000000000..f23ccf3413cc8ac8581a82338a3037bc10d573f0
--- /dev/null
+++ b/ControlNet/annotator/ckpts/hand_pose_model.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b76b00d1750901abd07b9f9d8c98cc3385b8fe834a26d4b4f0aad439e75fc600
+size 147341049
diff --git a/ControlNet/annotator/ckpts/mlsd_large_512_fp32.pth b/ControlNet/annotator/ckpts/mlsd_large_512_fp32.pth
new file mode 100644
index 0000000000000000000000000000000000000000..7e00f54f47838ca7697555699c50dfa3e99880b5
--- /dev/null
+++ b/ControlNet/annotator/ckpts/mlsd_large_512_fp32.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5696f168eb2c30d4374bbfd45436f7415bb4d88da29bea97eea0101520fba082
+size 6341481
diff --git a/ControlNet/annotator/ckpts/network-bsds500.pth b/ControlNet/annotator/ckpts/network-bsds500.pth
new file mode 100644
index 0000000000000000000000000000000000000000..36cff8560c17530f48cbb9a43c6e9a0d6f704af3
--- /dev/null
+++ b/ControlNet/annotator/ckpts/network-bsds500.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:58a858782f5fa3e0ca3dc92e7a1a609add93987d77be3dfa54f8f8419d881a94
+size 58871680
diff --git a/ControlNet/annotator/ckpts/upernet_global_small.pth b/ControlNet/annotator/ckpts/upernet_global_small.pth
new file mode 100644
index 0000000000000000000000000000000000000000..88e019bbe64cbca0662ca839794e9dffb60f2ac5
--- /dev/null
+++ b/ControlNet/annotator/ckpts/upernet_global_small.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bebfa1264c10381e389d8065056baaadbdadee8ddc6e36770d1ec339dc84d970
+size 206313115
diff --git a/ControlNet/annotator/hed/__init__.py b/ControlNet/annotator/hed/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..56532c374df5c26f9ec53e2ac0dd924f4534bbdd
--- /dev/null
+++ b/ControlNet/annotator/hed/__init__.py
@@ -0,0 +1,132 @@
+import numpy as np
+import cv2
+import os
+import torch
+from einops import rearrange
+from annotator.util import annotator_ckpts_path
+class Network(torch.nn.Module):
+ """Edge network with five VGG-like conv stages and one side output per stage.
+ Each stage feeds a 1x1 "score" convolution; the five single-channel score
+ maps are resized to the input resolution, concatenated, and fused by a 1x1
+ convolution followed by a sigmoid.
+ """
+ def __init__(self, model_path):
+ """Build the layers and load weights from the checkpoint at ``model_path``."""
+ super().__init__()
+ # Stage 1: no pooling, 3 -> 64 channels.
+ self.netVggOne = torch.nn.Sequential(
+ torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False),
+ torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False)
+ )
+ # Stages 2-5 each start with a 2x2 max-pool that halves the resolution.
+ self.netVggTwo = torch.nn.Sequential(
+ torch.nn.MaxPool2d(kernel_size=2, stride=2),
+ torch.nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False),
+ torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False)
+ )
+ self.netVggThr = torch.nn.Sequential(
+ torch.nn.MaxPool2d(kernel_size=2, stride=2),
+ torch.nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False),
+ torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False),
+ torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False)
+ )
+ self.netVggFou = torch.nn.Sequential(
+ torch.nn.MaxPool2d(kernel_size=2, stride=2),
+ torch.nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False),
+ torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False),
+ torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False)
+ )
+ self.netVggFiv = torch.nn.Sequential(
+ torch.nn.MaxPool2d(kernel_size=2, stride=2),
+ torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False),
+ torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False),
+ torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False)
+ )
+ # One 1x1 conv per stage collapses that stage's features to a single channel.
+ self.netScoreOne = torch.nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, stride=1, padding=0)
+ self.netScoreTwo = torch.nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1, stride=1, padding=0)
+ self.netScoreThr = torch.nn.Conv2d(in_channels=256, out_channels=1, kernel_size=1, stride=1, padding=0)
+ self.netScoreFou = torch.nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0)
+ self.netScoreFiv = torch.nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0)
+ # Fuses the five stacked score maps into one map squashed to (0, 1).
+ self.netCombine = torch.nn.Sequential(
+ torch.nn.Conv2d(in_channels=5, out_channels=1, kernel_size=1, stride=1, padding=0),
+ torch.nn.Sigmoid()
+ )
+ # Checkpoint keys have 'module' replaced by 'net' so they line up with the
+ # attribute names above (netVggOne, netScoreOne, ...).
+ # NOTE(review): torch.load is called without map_location, so tensors are
+ # restored to the device they were saved from - confirm this is intended.
+ self.load_state_dict({strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.load(model_path).items()})
+ def forward(self, tenInput):
+ """Return the fused single-channel edge map at the input's spatial size.
+ The mean tensor is viewed as (1, 3, 1, 1), so ``tenInput`` must be a
+ 4-D tensor with 3 channels; HEDdetector passes it scaled to [0, 1].
+ """
+ # Rescale to the 0-255 range and subtract fixed per-channel offsets.
+ tenInput = tenInput * 255.0
+ tenInput = tenInput - torch.tensor(data=[104.00698793, 116.66876762, 122.67891434], dtype=tenInput.dtype, device=tenInput.device).view(1, 3, 1, 1)
+ # Each stage consumes the previous stage's output.
+ tenVggOne = self.netVggOne(tenInput)
+ tenVggTwo = self.netVggTwo(tenVggOne)
+ tenVggThr = self.netVggThr(tenVggTwo)
+ tenVggFou = self.netVggFou(tenVggThr)
+ tenVggFiv = self.netVggFiv(tenVggFou)
+ tenScoreOne = self.netScoreOne(tenVggOne)
+ tenScoreTwo = self.netScoreTwo(tenVggTwo)
+ tenScoreThr = self.netScoreThr(tenVggThr)
+ tenScoreFou = self.netScoreFou(tenVggFou)
+ tenScoreFiv = self.netScoreFiv(tenVggFiv)
+ # Bring every side output back to the input's (H, W) before stacking.
+ tenScoreOne = torch.nn.functional.interpolate(input=tenScoreOne, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
+ tenScoreTwo = torch.nn.functional.interpolate(input=tenScoreTwo, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
+ tenScoreThr = torch.nn.functional.interpolate(input=tenScoreThr, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
+ tenScoreFou = torch.nn.functional.interpolate(input=tenScoreFou, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
+ tenScoreFiv = torch.nn.functional.interpolate(input=tenScoreFiv, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
+ # Concatenate along the channel axis (5 channels) and fuse.
+ return self.netCombine(torch.cat([ tenScoreOne, tenScoreTwo, tenScoreThr, tenScoreFou, tenScoreFiv ], 1))
+class HEDdetector:
+ """Runs the HED ``Network`` on a numpy image and returns a uint8 edge map."""
+ def __init__(self):
+ """Fetch the checkpoint if it is not on disk, then load the network on CUDA in eval mode."""
+ remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/network-bsds500.pth"
+ modelpath = os.path.join(annotator_ckpts_path, "network-bsds500.pth")
+ if not os.path.exists(modelpath):
+ # Imported lazily so basicsr is only needed when a download is required.
+ from basicsr.utils.download_util import load_file_from_url
+ load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
+ self.netNetwork = Network(modelpath).cuda().eval()
+ def __call__(self, input_image):
+ """Return the edge map for ``input_image`` as a 2-D uint8 array.
+ ``input_image`` must be 3-dimensional; the rearrange pattern below
+ treats it as (h, w, c).
+ """
+ assert input_image.ndim == 3
+ # Reverse the channel axis; .copy() materialises the reversed view so
+ # torch.from_numpy receives an array without negative strides.
+ input_image = input_image[:, :, ::-1].copy()
+ with torch.no_grad():
+ image_hed = torch.from_numpy(input_image).float().cuda()
+ image_hed = image_hed / 255.0
+ image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
+ # [0] drops the batch axis of the network output.
+ edge = self.netNetwork(image_hed)[0]
+ edge = (edge.cpu().numpy() * 255.0).clip(0, 255).astype(np.uint8)
+ # edge[0] drops the single output channel, leaving (h, w).
+ return edge[0]
+def nms(x, t, s):
+ """Thin a response map to directional local maxima and binarise it.
+ Args:
+ x: 2-D response map (converted to float32 before blurring).
+ t: threshold applied to the surviving responses.
+ s: Gaussian sigma passed to cv2.GaussianBlur (kernel size is derived from it).
+ Returns:
+ uint8 array shaped like ``x`` holding 255 where a kept response exceeds ``t``, else 0.
+ """
+ x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
+ # Four 3x3 line-shaped kernels: horizontal, vertical and the two diagonals.
+ f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
+ f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
+ f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
+ f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
+ y = np.zeros_like(x)
+ for f in [f1, f2, f3, f4]:
+ # A pixel equal to its dilation is the maximum along that kernel's line;
+ # its value is copied into y, so y accumulates maxima from any direction.
+ np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
+ z = np.zeros_like(y, dtype=np.uint8)
+ z[y > t] = 255
+ return z
diff --git a/ControlNet/annotator/midas/__init__.py b/ControlNet/annotator/midas/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc5ac03eea6f5ba7968706f1863c8bc4f8aaaf6a
--- /dev/null
+++ b/ControlNet/annotator/midas/__init__.py
@@ -0,0 +1,38 @@
+import cv2
+import numpy as np
+import torch
+from einops import rearrange
+from .api import MiDaSInference
+class MidasDetector:
+ """Produces a depth image and a normal-map image from an input image using MiDaS."""
+ def __init__(self):
+ # Loads the "dpt_hybrid" MiDaS model and moves it to CUDA.
+ self.model = MiDaSInference(model_type="dpt_hybrid").cuda()
+ def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1):
+ """Return ``(depth_image, normal_image)`` as uint8 arrays.
+ Args:
+ input_image: 3-dimensional numpy image, treated as (h, w, c) by the rearrange below.
+ a: constant used as the z component of every normal before normalisation.
+ bg_th: pixels whose min-max normalised depth is below this get zero x/y gradient.
+ """
+ assert input_image.ndim == 3
+ image_depth = input_image
+ with torch.no_grad():
+ image_depth = torch.from_numpy(image_depth).float().cuda()
+ # Map 0..255 pixel values to -1..1.
+ image_depth = image_depth / 127.5 - 1.0
+ image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
+ depth = self.model(image_depth)[0]
+ # Min-max normalise a clone so ``depth`` keeps the raw values used for the gradients.
+ depth_pt = depth.clone()
+ depth_pt -= torch.min(depth_pt)
+ # NOTE(review): if the prediction is constant the max is 0 here and this divides by zero - confirm.
+ depth_pt /= torch.max(depth_pt)
+ depth_pt = depth_pt.cpu().numpy()
+ depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8)
+ depth_np = depth.cpu().numpy()
+ # Horizontal and vertical Sobel derivatives of the raw depth.
+ x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3)
+ y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3)
+ z = np.ones_like(x) * a
+ # Suppress gradients where the normalised depth is below bg_th.
+ x[depth_pt < bg_th] = 0
+ y[depth_pt < bg_th] = 0
+ normal = np.stack([x, y, z], axis=2)
+ # Scale every (x, y, z) vector to unit length.
+ normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5
+ # Map components from -1..1 to 0..255.
+ normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
+ return depth_image, normal_image
diff --git a/ControlNet/annotator/midas/api.py b/ControlNet/annotator/midas/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ab9f15bf96bbaffcee0e3e29fc9d3979d6c32e8
--- /dev/null
+++ b/ControlNet/annotator/midas/api.py
@@ -0,0 +1,169 @@
+# based on https://github.com/isl-org/MiDaS
+import cv2
+import os
+import torch
+import torch.nn as nn
+from torchvision.transforms import Compose
+from .midas.dpt_depth import DPTDepthModel
+from .midas.midas_net import MidasNet
+from .midas.midas_net_custom import MidasNet_small
+from .midas.transforms import Resize, NormalizeImage, PrepareForNet
+from annotator.util import annotator_ckpts_path
+ "dpt_large": os.path.join(annotator_ckpts_path, "dpt_large-midas-2f21e586.pt"),
+ "dpt_hybrid": os.path.join(annotator_ckpts_path, "dpt_hybrid-midas-501f0c75.pt"),
+ "midas_v21": "",
+ "midas_v21_small": "",
+remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt"
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore.
+ ``mode`` is ignored; the function returns ``self`` without modifying anything."""
+ return self
+def load_midas_transform(model_type):
+ """Build the preprocessing pipeline for ``model_type`` without loading a network.
+ Args:
+ model_type: one of "dpt_large", "dpt_hybrid", "midas_v21", "midas_v21_small".
+ Returns:
+ A torchvision ``Compose`` of Resize, the model's NormalizeImage and PrepareForNet.
+ Raises:
+ AssertionError: for any other ``model_type``.
+ """
+ # https://github.com/isl-org/MiDaS/blob/master/run.py
+ # load transform only
+ if model_type == "dpt_large": # DPT-Large
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+ elif model_type == "dpt_hybrid": # DPT-Hybrid
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+ elif model_type == "midas_v21":
+ net_w, net_h = 384, 384
+ resize_mode = "upper_bound"
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ elif model_type == "midas_v21_small":
+ net_w, net_h = 256, 256
+ resize_mode = "upper_bound"
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ else:
+ # NOTE(review): asserts are stripped under ``python -O``; execution would then
+ # reach the Compose below with net_w/net_h/resize_mode/normalization unbound.
+ assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
+ # Same pipeline shape for every model; only size, resize mode and normalisation differ.
+ transform = Compose(
+ [
+ Resize(
+ net_w,
+ net_h,
+ resize_target=None,
+ keep_aspect_ratio=True,
+ ensure_multiple_of=32,
+ resize_method=resize_mode,
+ image_interpolation_method=cv2.INTER_CUBIC,
+ ),
+ normalization,
+ PrepareForNet(),
+ ]
+ )
+ return transform
+def load_model(model_type):
+ """Instantiate the network for ``model_type`` together with its input transform.
+ Args:
+ model_type: key into ISL_PATHS; one of "dpt_large", "dpt_hybrid",
+ "midas_v21", "midas_v21_small".
+ Returns:
+ ``(model.eval(), transform)`` where ``transform`` is a torchvision ``Compose``.
+ Raises:
+ KeyError: if ``model_type`` is not a key of ISL_PATHS.
+ AssertionError: if the key exists but no branch below handles it.
+ """
+ # https://github.com/isl-org/MiDaS/blob/master/run.py
+ # load network
+ # NOTE(review): ISL_PATHS is defined outside the lines visible here.
+ model_path = ISL_PATHS[model_type]
+ if model_type == "dpt_large": # DPT-Large
+ model = DPTDepthModel(
+ path=model_path,
+ backbone="vitl16_384",
+ non_negative=True,
+ )
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+ elif model_type == "dpt_hybrid": # DPT-Hybrid
+ # Only this branch downloads a missing checkpoint; the others expect the file to exist.
+ if not os.path.exists(model_path):
+ from basicsr.utils.download_util import load_file_from_url
+ load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
+ model = DPTDepthModel(
+ path=model_path,
+ backbone="vitb_rn50_384",
+ non_negative=True,
+ )
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+ elif model_type == "midas_v21":
+ model = MidasNet(model_path, non_negative=True)
+ net_w, net_h = 384, 384
+ resize_mode = "upper_bound"
+ normalization = NormalizeImage(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+ )
+ elif model_type == "midas_v21_small":
+ model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
+ non_negative=True, blocks={'expand': True})
+ net_w, net_h = 256, 256
+ resize_mode = "upper_bound"
+ normalization = NormalizeImage(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+ )
+ else:
+ print(f"model_type '{model_type}' not implemented, use: --model_type large")
+ # NOTE(review): stripped under ``python -O``; the code below would then hit unbound names.
+ assert False
+ # Preprocessing pipeline matching the chosen model's input size and normalisation.
+ transform = Compose(
+ [
+ Resize(
+ net_w,
+ net_h,
+ resize_target=None,
+ keep_aspect_ratio=True,
+ ensure_multiple_of=32,
+ resize_method=resize_mode,
+ image_interpolation_method=cv2.INTER_CUBIC,
+ ),
+ normalization,
+ PrepareForNet(),
+ ]
+ )
+ return model.eval(), transform
+class MiDaSInference(nn.Module):
+ # Wraps a MiDaS model loaded via load_model and runs it without gradient tracking.
+ # NOTE(review): the two list literals below have lost their assignment headers in
+ # this patch; __init__ reads self.MODEL_TYPES_ISL, so the second list is presumably
+ # MODEL_TYPES_ISL - confirm against the original file.
+ "DPT_Large",
+ "DPT_Hybrid",
+ "MiDaS_small"
+ ]
+ "dpt_large",
+ "dpt_hybrid",
+ "midas_v21",
+ "midas_v21_small",
+ ]
+ def __init__(self, model_type):
+ # model_type must be listed in MODEL_TYPES_ISL; the transform returned by
+ # load_model is discarded, so callers must preprocess inputs themselves.
+ super().__init__()
+ assert (model_type in self.MODEL_TYPES_ISL)
+ model, _ = load_model(model_type)
+ self.model = model
+ # Replace train() on this instance so later mode switches do not change the model.
+ # NOTE(review): assigned on the instance, so the function is not bound - a call
+ # such as model.train(False) passes False as ``self`` and returns it.
+ self.model.train = disabled_train
+ def forward(self, x):
+ # Runs the wrapped model under torch.no_grad() and returns its output unchanged.
+ with torch.no_grad():
+ prediction = self.model(x)
+ return prediction
diff --git a/ControlNet/annotator/midas/midas/__init__.py b/ControlNet/annotator/midas/midas/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ControlNet/annotator/midas/midas/base_model.py b/ControlNet/annotator/midas/midas/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cf430239b47ec5ec07531263f26f5c24a2311cd
--- /dev/null
+++ b/ControlNet/annotator/midas/midas/base_model.py
@@ -0,0 +1,16 @@
+import torch
+class BaseModel(torch.nn.Module):
+ # Base class that adds checkpoint loading on top of torch.nn.Module.
+ def load(self, path):
+ """Load model from file.
+ The checkpoint is deserialised onto the CPU. When it contains an
+ "optimizer" entry, the weights are taken from its "model" entry;
+ otherwise the loaded object itself is used as the state dict.
+ Args:
+ path (str): file path
+ """
+ parameters = torch.load(path, map_location=torch.device('cpu'))
+ if "optimizer" in parameters:
+ parameters = parameters["model"]
+ self.load_state_dict(parameters)
diff --git a/ControlNet/annotator/midas/midas/blocks.py b/ControlNet/annotator/midas/midas/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..2145d18fa98060a618536d9a64fe6589e9be4f78
--- /dev/null
+++ b/ControlNet/annotator/midas/midas/blocks.py
@@ -0,0 +1,342 @@
+import torch
+import torch.nn as nn
+from .vit import (
+ _make_pretrained_vitb_rn50_384,
+ _make_pretrained_vitl16_384,
+ _make_pretrained_vitb16_384,
+ forward_vit,
+def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
+ """Build the backbone and the matching "scratch" projection convs for ``backbone``.
+ Args:
+ backbone: "vitl16_384", "vitb_rn50_384", "vitb16_384", "resnext101_wsl"
+ or "efficientnet_lite3".
+ features: base output channel count passed to _make_scratch.
+ use_pretrained: forwarded to the backbone constructor.
+ groups, expand: forwarded to _make_scratch.
+ exportable: forwarded to the efficientnet constructor only.
+ hooks, use_readout: forwarded to the ViT constructors only.
+ use_vit_only: forwarded to the vitb_rn50_384 constructor only.
+ Returns:
+ ``(pretrained, scratch)``.
+ Raises:
+ AssertionError: for an unknown ``backbone``.
+ """
+ # The list given to _make_scratch is the input channel count of each of its four convs.
+ if backbone == "vitl16_384":
+ pretrained = _make_pretrained_vitl16_384(
+ use_pretrained, hooks=hooks, use_readout=use_readout
+ )
+ scratch = _make_scratch(
+ [256, 512, 1024, 1024], features, groups=groups, expand=expand
+ ) # ViT-L/16 - 85.0% Top1 (backbone)
+ elif backbone == "vitb_rn50_384":
+ pretrained = _make_pretrained_vitb_rn50_384(
+ use_pretrained,
+ hooks=hooks,
+ use_vit_only=use_vit_only,
+ use_readout=use_readout,
+ )
+ scratch = _make_scratch(
+ [256, 512, 768, 768], features, groups=groups, expand=expand
+ ) # vitb_rn50_384 hybrid backbone
+ elif backbone == "vitb16_384":
+ pretrained = _make_pretrained_vitb16_384(
+ use_pretrained, hooks=hooks, use_readout=use_readout
+ )
+ scratch = _make_scratch(
+ [96, 192, 384, 768], features, groups=groups, expand=expand
+ ) # ViT-B/16 - 84.6% Top1 (backbone)
+ elif backbone == "resnext101_wsl":
+ pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
+ scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # resnext101_wsl
+ elif backbone == "efficientnet_lite3":
+ pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
+ scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
+ else:
+ print(f"Backbone '{backbone}' not implemented")
+ # NOTE(review): stripped under ``python -O``; the return would then raise on unbound names.
+ assert False
+ return pretrained, scratch
+def _make_scratch(in_shape, out_shape, groups=1, expand=False):
+ """Create a module holding four bias-free 3x3 convs (layer1_rn .. layer4_rn).
+ Args:
+ in_shape: sequence of four input channel counts, one per conv.
+ out_shape: base output channel count.
+ groups: ``groups`` argument for every conv.
+ expand: when True the output channels are out_shape * (1, 2, 4, 8);
+ otherwise every conv outputs ``out_shape`` channels.
+ Returns:
+ A bare ``nn.Module`` with the four convs attached as attributes.
+ """
+ scratch = nn.Module()
+ out_shape1 = out_shape
+ out_shape2 = out_shape
+ out_shape3 = out_shape
+ out_shape4 = out_shape
+ if expand==True:
+ out_shape1 = out_shape
+ out_shape2 = out_shape*2
+ out_shape3 = out_shape*4
+ out_shape4 = out_shape*8
+ # kernel 3, stride 1, padding 1: spatial size is preserved by every conv.
+ scratch.layer1_rn = nn.Conv2d(
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer2_rn = nn.Conv2d(
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer3_rn = nn.Conv2d(
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer4_rn = nn.Conv2d(
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ return scratch
+def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
+ """Fetch tf_efficientnet_lite3 through torch.hub and regroup it into four stages.
+ ``use_pretrained`` and ``exportable`` are forwarded to the hub entry point
+ as ``pretrained`` and ``exportable``.
+ """
+ efficientnet = torch.hub.load(
+ "rwightman/gen-efficientnet-pytorch",
+ "tf_efficientnet_lite3",
+ pretrained=use_pretrained,
+ exportable=exportable
+ )
+ return _make_efficientnet_backbone(efficientnet)
+def _make_efficientnet_backbone(effnet):
+ """Regroup an EfficientNet's stem and blocks into layer1..layer4 on a bare module.
+ ``effnet`` must expose ``conv_stem``, ``bn1``, ``act1`` and a sliceable ``blocks``.
+ """
+ pretrained = nn.Module()
+ # layer1 = stem + blocks[0:2]; the remaining layers take blocks[2:3], [3:5] and [5:9].
+ pretrained.layer1 = nn.Sequential(
+ effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
+ )
+ pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
+ pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
+ pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
+ return pretrained
+def _make_resnet_backbone(resnet):
+ """Expose a ResNet-style model as layer1..layer4 on a bare module.
+ layer1 bundles the stem (conv1, bn1, relu, maxpool) with the model's own
+ layer1; layer2-4 are the model's submodules reused as-is (not copied).
+ """
+ pretrained = nn.Module()
+ pretrained.layer1 = nn.Sequential(
+ resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
+ )
+ pretrained.layer2 = resnet.layer2
+ pretrained.layer3 = resnet.layer3
+ pretrained.layer4 = resnet.layer4
+ return pretrained
+def _make_pretrained_resnext101_wsl(use_pretrained):
+ """Fetch resnext101_32x8d_wsl through torch.hub and regroup it into four stages."""
+ # NOTE(review): ``use_pretrained`` is accepted but never used; the hub call is the same either way.
+ resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
+ return _make_resnet_backbone(resnet)
+class Interpolate(nn.Module):
+ """Interpolation module.
+ Wraps ``nn.functional.interpolate`` with a fixed scale factor, mode and
+ align_corners setting so it can be placed inside ``nn.Sequential``.
+ """
+ def __init__(self, scale_factor, mode, align_corners=False):
+ """Init.
+ Args:
+ scale_factor (float): scaling
+ mode (str): interpolation mode
+ align_corners (bool): forwarded to ``nn.functional.interpolate``
+ """
+ super(Interpolate, self).__init__()
+ self.interp = nn.functional.interpolate
+ self.scale_factor = scale_factor
+ self.mode = mode
+ self.align_corners = align_corners
+ def forward(self, x):
+ """Forward pass.
+ Args:
+ x (tensor): input
+ Returns:
+ tensor: interpolated data
+ """
+ x = self.interp(
+ x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
+ )
+ return x
+class ResidualConvUnit(nn.Module):
+ """Residual convolution module.
+ Two 3x3 convs (channel count unchanged), each preceded by ReLU, with the
+ result added back to the input.
+ """
+ def __init__(self, features):
+ """Init.
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+ self.conv1 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
+ )
+ self.conv2 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
+ )
+ self.relu = nn.ReLU(inplace=True)
+ def forward(self, x):
+ """Forward pass.
+ Args:
+ x (tensor): input
+ Returns:
+ tensor: output
+ """
+ # NOTE(review): self.relu is in-place, so this call also overwrites ``x``;
+ # the skip connection below therefore adds the activated input, and the
+ # caller's tensor is modified. Verify before changing - trained weights may rely on it.
+ out = self.relu(x)
+ out = self.conv1(out)
+ out = self.relu(out)
+ out = self.conv2(out)
+ return out + x
+class FeatureFusionBlock(nn.Module):
+ """Feature fusion block.
+ Optionally adds a refined second input to the first, refines the sum with
+ a residual unit and upsamples it by a factor of 2.
+ """
+ def __init__(self, features):
+ """Init.
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock, self).__init__()
+ self.resConfUnit1 = ResidualConvUnit(features)
+ self.resConfUnit2 = ResidualConvUnit(features)
+ def forward(self, *xs):
+ """Forward pass.
+ Args:
+ *xs: one or two tensors; when two are given, the second is passed
+ through resConfUnit1 and added to the first.
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+ if len(xs) == 2:
+ # NOTE(review): ``+=`` updates xs[0] in place, so the caller's tensor is modified.
+ output += self.resConfUnit1(xs[1])
+ output = self.resConfUnit2(output)
+ # Bilinear 2x upsampling.
+ output = nn.functional.interpolate(
+ output, scale_factor=2, mode="bilinear", align_corners=True
+ )
+ return output
+class ResidualConvUnit_custom(nn.Module):
+ """Residual convolution module.
+ Variant with a caller-supplied activation, optional batch norm after each
+ conv, and the skip addition routed through ``FloatFunctional``.
+ """
+ def __init__(self, features, activation, bn):
+ """Init.
+ Args:
+ features (int): number of features
+ activation: module applied before each conv
+ bn (bool): when True, a BatchNorm2d follows each conv
+ """
+ super().__init__()
+ self.bn = bn
+ self.groups=1
+ self.conv1 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+ )
+ self.conv2 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+ )
+ if self.bn==True:
+ self.bn1 = nn.BatchNorm2d(features)
+ self.bn2 = nn.BatchNorm2d(features)
+ self.activation = activation
+ self.skip_add = nn.quantized.FloatFunctional()
+ def forward(self, x):
+ """Forward pass.
+ Args:
+ x (tensor): input
+ Returns:
+ tensor: output
+ """
+ out = self.activation(x)
+ out = self.conv1(out)
+ if self.bn==True:
+ out = self.bn1(out)
+ out = self.activation(out)
+ out = self.conv2(out)
+ if self.bn==True:
+ out = self.bn2(out)
+ # NOTE(review): self.groups is fixed to 1 in __init__ and conv_merge is never
+ # defined in this class, so this branch is not taken as written.
+ if self.groups > 1:
+ out = self.conv_merge(out)
+ return self.skip_add.add(out, x)
+ # return out + x
+class FeatureFusionBlock_custom(nn.Module):
+ """Feature fusion block.
+ Like FeatureFusionBlock but built from ResidualConvUnit_custom, followed by
+ a 1x1 output conv that halves the channels when ``expand`` is True.
+ """
+ def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
+ """Init.
+ Args:
+ features (int): number of features
+ activation: activation module handed to both residual units
+ deconv (bool): stored on the instance; not read elsewhere in this class
+ bn (bool): forwarded to both residual units
+ expand (bool): when True, out_conv maps features -> features // 2
+ align_corners (bool): used for the bilinear upsampling in forward
+ """
+ super(FeatureFusionBlock_custom, self).__init__()
+ self.deconv = deconv
+ self.align_corners = align_corners
+ self.groups=1
+ self.expand = expand
+ out_features = features
+ if self.expand==True:
+ out_features = features//2
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
+ self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
+ self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
+ self.skip_add = nn.quantized.FloatFunctional()
+ def forward(self, *xs):
+ """Forward pass.
+ Args:
+ *xs: one or two tensors; when two are given, the second is refined
+ by resConfUnit1 and added to the first.
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+ if len(xs) == 2:
+ res = self.resConfUnit1(xs[1])
+ output = self.skip_add.add(output, res)
+ # output += res
+ output = self.resConfUnit2(output)
+ # Bilinear 2x upsampling, then the 1x1 projection.
+ output = nn.functional.interpolate(
+ output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
+ )
+ output = self.out_conv(output)
+ return output
diff --git a/ControlNet/annotator/midas/midas/dpt_depth.py b/ControlNet/annotator/midas/midas/dpt_depth.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e9aab5d2767dffea39da5b3f30e2798688216f1
--- /dev/null
+++ b/ControlNet/annotator/midas/midas/dpt_depth.py
@@ -0,0 +1,109 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .base_model import BaseModel
+from .blocks import (
+ FeatureFusionBlock,
+ FeatureFusionBlock_custom,
+ Interpolate,
+ _make_encoder,
+ forward_vit,
+def _make_fusion_block(features, use_bn):
+ """Return a FeatureFusionBlock_custom with the settings DPT uses.
+ Fixed settings: non-in-place ReLU, no deconv, no channel expansion,
+ align_corners=True; ``use_bn`` toggles batch norm in the residual units.
+ """
+ return FeatureFusionBlock_custom(
+ features,
+ nn.ReLU(False),
+ deconv=False,
+ bn=use_bn,
+ expand=False,
+ align_corners=True,
+ )
+class DPT(BaseModel):
+ def __init__(
+ self,
+ head,
+ features=256,
+ backbone="vitb_rn50_384",
+ readout="project",
+ channels_last=False,
+ use_bn=False,
+ ):
+ super(DPT, self).__init__()
+ self.channels_last = channels_last
+ hooks = {
+ "vitb_rn50_384": [0, 1, 8, 11],
+ "vitb16_384": [2, 5, 8, 11],
+ "vitl16_384": [5, 11, 17, 23],
+ }
+ # Instantiate backbone and reassemble blocks
+ self.pretrained, self.scratch = _make_encoder(
+ backbone,
+ features,
+ False, # Set to true of you want to train from scratch, uses ImageNet weights
+ groups=1,
+ expand=False,
+ exportable=False,
+ hooks=hooks[backbone],
+ use_readout=readout,
+ )
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
+ self.scratch.output_conv = head
+ def forward(self, x):
+ if self.channels_last == True:
+ x.contiguous(memory_format=torch.channels_last)
+ layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+ path_4 = self.scratch.refinenet4(layer_4_rn)
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+ out = self.scratch.output_conv(path_1)
+ return out
+class DPTDepthModel(DPT):
+ """DPT with a single-channel depth head; optionally loads weights from ``path``."""
+ def __init__(self, path=None, non_negative=True, **kwargs):
+ """Build the head, construct DPT with the remaining kwargs and load weights.
+ Args:
+ path: checkpoint path; nothing is loaded when it is None.
+ non_negative: when True the head ends with ReLU, otherwise Identity.
+ **kwargs: forwarded to DPT.__init__; ``features`` (default 256) also sizes the head.
+ """
+ features = kwargs["features"] if "features" in kwargs else 256
+ # Head: 3x3 conv halving the channels, 2x bilinear upsample, 3x3 conv to 32
+ # channels, ReLU, 1x1 conv to one channel, then the optional final ReLU.
+ head = nn.Sequential(
+ nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(True) if non_negative else nn.Identity(),
+ nn.Identity(),
+ )
+ super().__init__(head, **kwargs)
+ if path is not None:
+ self.load(path)
+ def forward(self, x):
+ # Drop the channel axis (dim 1) of the head output.
+ return super().forward(x).squeeze(dim=1)
diff --git a/ControlNet/annotator/midas/midas/midas_net.py b/ControlNet/annotator/midas/midas/midas_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a954977800b0a0f48807e80fa63041910e33c1f
--- /dev/null
+++ b/ControlNet/annotator/midas/midas/midas_net.py
@@ -0,0 +1,76 @@
+"""MidasNet: Network for monocular depth estimation trained by mixing several datasets.
+This file contains code that is adapted from
+import torch
+import torch.nn as nn
+from .base_model import BaseModel
+from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
+class MidasNet(BaseModel):
+ """Network for monocular depth estimation.
+ Uses the "resnext101_wsl" encoder, four FeatureFusionBlocks and a
+ single-channel output head.
+ """
+ def __init__(self, path=None, features=256, non_negative=True):
+ """Init.
+ Args:
+ path (str, optional): Path to saved model. Defaults to None.
+ features (int, optional): Number of features. Defaults to 256.
+ non_negative (bool, optional): When True the head ends with ReLU,
+ otherwise with Identity. Defaults to True.
+ """
+ print("Loading weights: ", path)
+ super(MidasNet, self).__init__()
+ # use_pretrained is True exactly when a checkpoint path is supplied.
+ use_pretrained = False if path is None else True
+ self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
+ self.scratch.refinenet4 = FeatureFusionBlock(features)
+ self.scratch.refinenet3 = FeatureFusionBlock(features)
+ self.scratch.refinenet2 = FeatureFusionBlock(features)
+ self.scratch.refinenet1 = FeatureFusionBlock(features)
+ # Head: 3x3 conv to 128 channels, 2x bilinear upsample, 3x3 conv to 32
+ # channels, ReLU, 1x1 conv to one channel, then the optional final ReLU.
+ self.scratch.output_conv = nn.Sequential(
+ nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
+ Interpolate(scale_factor=2, mode="bilinear"),
+ nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(True) if non_negative else nn.Identity(),
+ )
+ if path:
+ self.load(path)
+ def forward(self, x):
+ """Forward pass.
+ Args:
+ x (tensor): input data (image)
+ Returns:
+ tensor: depth
+ """
+ # Encoder stages run sequentially; each output is kept for the decoder.
+ layer_1 = self.pretrained.layer1(x)
+ layer_2 = self.pretrained.layer2(layer_1)
+ layer_3 = self.pretrained.layer3(layer_2)
+ layer_4 = self.pretrained.layer4(layer_3)
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+ # Decoder fuses from the deepest level upwards.
+ path_4 = self.scratch.refinenet4(layer_4_rn)
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+ out = self.scratch.output_conv(path_1)
+ # Drop the channel axis (dim 1) of the head output.
+ return torch.squeeze(out, dim=1)
diff --git a/ControlNet/annotator/midas/midas/midas_net_custom.py b/ControlNet/annotator/midas/midas/midas_net_custom.py
new file mode 100644
index 0000000000000000000000000000000000000000..50e4acb5e53d5fabefe3dde16ab49c33c2b7797c
--- /dev/null
+++ b/ControlNet/annotator/midas/midas/midas_net_custom.py
@@ -0,0 +1,128 @@
+"""MidasNet: Network for monocular depth estimation trained by mixing several datasets.
+This file contains code that is adapted from
+import torch
+import torch.nn as nn
+from .base_model import BaseModel
+from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
+class MidasNet_small(BaseModel):
+    """Network for monocular depth estimation.
+    """
+    def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
+        blocks=None):
+        """Init.
+        Args:
+            path (str, optional): Path to saved model. Defaults to None.
+            features (int, optional): Number of features. Defaults to 64.
+            backbone (str, optional): Backbone network for encoder. Defaults to efficientnet_lite3.
+            non_negative (bool, optional): End the output head with a ReLU instead of an identity. Defaults to True.
+            exportable (bool, optional): Forwarded to _make_encoder. Defaults to True.
+            channels_last (bool, optional): Convert the input to channels_last memory format in forward. Defaults to False.
+            align_corners (bool, optional): Forwarded to the feature fusion blocks. Defaults to True.
+            blocks (dict, optional): Block options; {'expand': True} doubles the feature count per decoder level. Defaults to {'expand': True}.
+        """
+        print("Loading weights: ", path)
+        super(MidasNet_small, self).__init__()
+        # A fresh dict per instance; a dict literal as default would be shared by every instance.
+        if blocks is None:
+            blocks = {'expand': True}
+        use_pretrained = False if path else True
+        self.channels_last = channels_last
+        self.blocks = blocks
+        self.backbone = backbone
+        self.groups = 1
+        features1 = features
+        features2 = features
+        features3 = features
+        features4 = features
+        self.expand = False
+        if "expand" in self.blocks and self.blocks['expand'] == True:
+            self.expand = True
+            features1 = features
+            features2 = features * 2
+            features3 = features * 4
+            features4 = features * 8
+        self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
+        self.scratch.activation = nn.ReLU(False)
+        self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+        self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+        self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+        self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
+        self.scratch.output_conv = nn.Sequential(
+            nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
+            Interpolate(scale_factor=2, mode="bilinear"),
+            nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
+            self.scratch.activation,
+            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+            nn.ReLU(True) if non_negative else nn.Identity(),
+            nn.Identity(),
+        )
+        if path:
+            self.load(path)
+    def forward(self, x):
+        """Forward pass.
+        Args:
+            x (tensor): input data (image)
+        Returns:
+            tensor: depth
+        """
+        if self.channels_last==True:
+            print("self.channels_last = ", self.channels_last)
+            # Tensor.contiguous returns the converted tensor; without the
+            # assignment the requested memory format is silently dropped.
+            x = x.contiguous(memory_format=torch.channels_last)
+        layer_1 = self.pretrained.layer1(x)
+        layer_2 = self.pretrained.layer2(layer_1)
+        layer_3 = self.pretrained.layer3(layer_2)
+        layer_4 = self.pretrained.layer4(layer_3)
+        layer_1_rn = self.scratch.layer1_rn(layer_1)
+        layer_2_rn = self.scratch.layer2_rn(layer_2)
+        layer_3_rn = self.scratch.layer3_rn(layer_3)
+        layer_4_rn = self.scratch.layer4_rn(layer_4)
+        path_4 = self.scratch.refinenet4(layer_4_rn)
+        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+        out = self.scratch.output_conv(path_1)
+        return torch.squeeze(out, dim=1)
+def fuse_model(m):
+    """Fuse Conv2d -> BatchNorm2d (-> ReLU) runs of m in place.
+    Modules are visited in m.named_modules() order; whenever the two modules
+    visited before the current one are a Conv2d followed by a BatchNorm2d they
+    are fused via torch.quantization.fuse_modules, together with the current
+    module when it is a ReLU.
+    Args:
+        m (nn.Module): model to fuse in place
+    """
+    # (type, name) of the two most recently visited modules, oldest first.
+    # The Identity instances are placeholders that never equal a module class.
+    older = (nn.Identity(), '')
+    newer = (nn.Identity(), '')
+    for name, module in m.named_modules():
+        conv_then_bn = older[0] == nn.Conv2d and newer[0] == nn.BatchNorm2d
+        if conv_then_bn and type(module) == nn.ReLU:
+            torch.quantization.fuse_modules(m, [older[1], newer[1], name], inplace=True)
+        elif conv_then_bn:
+            torch.quantization.fuse_modules(m, [older[1], newer[1]], inplace=True)
+        older, newer = newer, (type(module), name)
\ No newline at end of file
diff --git a/ControlNet/annotator/midas/midas/transforms.py b/ControlNet/annotator/midas/midas/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..350cbc11662633ad7f8968eb10be2e7de6e384e9
--- /dev/null
+++ b/ControlNet/annotator/midas/midas/transforms.py
@@ -0,0 +1,234 @@
+import numpy as np
+import cv2
+import math
+def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
+    """Resize the sample to ensure the given size. Keeps aspect ratio.
+    The "image", "disparity" and "mask" entries of sample are replaced in
+    place when the disparity map is smaller than size along either axis.
+    Args:
+        sample (dict): sample
+        size (tuple): minimum size, compared with the first two dims of sample["disparity"]
+        image_interpolation_method: OpenCV interpolation flag used for the image
+    Returns:
+        tuple: new size
+    """
+    shape = list(sample["disparity"].shape)
+    if shape[0] >= size[0] and shape[1] >= size[1]:
+        # Nothing to resize: report the unchanged size so both exits return
+        # the documented tuple (this path used to return the sample dict).
+        return tuple(shape)
+    # One common factor for both axes keeps the aspect ratio.
+    scale = max(size[0] / shape[0], size[1] / shape[1])
+    shape[0] = math.ceil(scale * shape[0])
+    shape[1] = math.ceil(scale * shape[1])
+    # cv2.resize expects (width, height), hence the reversed shape.
+    target = tuple(shape[::-1])
+    sample["image"] = cv2.resize(
+        sample["image"], target, interpolation=image_interpolation_method
+    )
+    sample["disparity"] = cv2.resize(
+        sample["disparity"], target, interpolation=cv2.INTER_NEAREST
+    )
+    sample["mask"] = cv2.resize(
+        sample["mask"].astype(np.float32),
+        target,
+        interpolation=cv2.INTER_NEAREST,
+    )
+    sample["mask"] = sample["mask"].astype(bool)
+    return tuple(shape)
+class Resize(object):
+    """Resize sample to given size (width, height).
+    """
+    def __init__(
+        self,
+        width,
+        height,
+        resize_target=True,
+        keep_aspect_ratio=False,
+        ensure_multiple_of=1,
+        resize_method="lower_bound",
+        image_interpolation_method=cv2.INTER_AREA,
+    ):
+        """Init.
+        Args:
+            width (int): desired output width
+            height (int): desired output height
+            resize_target (bool, optional):
+                True: Resize the full sample (image, mask, target).
+                False: Resize image only.
+                Defaults to True.
+            keep_aspect_ratio (bool, optional):
+                True: Keep the aspect ratio of the input sample.
+                Output sample might not have the given width and height, and
+                resize behaviour depends on the parameter 'resize_method'.
+                Defaults to False.
+            ensure_multiple_of (int, optional):
+                Output width and height is constrained to be multiple of this parameter.
+                Defaults to 1.
+            resize_method (str, optional):
+                "lower_bound": Output will be at least as large as the given size.
+                "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
+                "minimal": Scale as least as possible. (Output size might be smaller than given size.)
+                Defaults to "lower_bound".
+            image_interpolation_method (optional):
+                OpenCV interpolation flag used for the image. Defaults to cv2.INTER_AREA.
+        """
+        self.__width = width
+        self.__height = height
+        self.__resize_target = resize_target
+        self.__keep_aspect_ratio = keep_aspect_ratio
+        self.__multiple_of = ensure_multiple_of
+        self.__resize_method = resize_method
+        self.__image_interpolation_method = image_interpolation_method
+    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
+        """Round x to a multiple of ensure_multiple_of.
+        Rounds to nearest first, then rounds down if that exceeds max_val and
+        rounds up if the result is below min_val.
+        """
+        y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
+        if max_val is not None and y > max_val:
+            y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
+        if y < min_val:
+            y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
+        return y
+    def get_size(self, width, height):
+        """Return the (new_width, new_height) for an input of the given size.
+        Raises:
+            ValueError: if resize_method is not one of the documented values
+        """
+        # determine new height and width
+        scale_height = self.__height / height
+        scale_width = self.__width / width
+        if self.__keep_aspect_ratio:
+            if self.__resize_method == "lower_bound":
+                # scale such that output size is lower bound
+                if scale_width > scale_height:
+                    # fit width
+                    scale_height = scale_width
+                else:
+                    # fit height
+                    scale_width = scale_height
+            elif self.__resize_method == "upper_bound":
+                # scale such that output size is upper bound
+                if scale_width < scale_height:
+                    # fit width
+                    scale_height = scale_width
+                else:
+                    # fit height
+                    scale_width = scale_height
+            elif self.__resize_method == "minimal":
+                # scale as little as possible
+                if abs(1 - scale_width) < abs(1 - scale_height):
+                    # fit width
+                    scale_height = scale_width
+                else:
+                    # fit height
+                    scale_width = scale_height
+            else:
+                raise ValueError(
+                    f"resize_method {self.__resize_method} not implemented"
+                )
+        if self.__resize_method == "lower_bound":
+            new_height = self.constrain_to_multiple_of(
+                scale_height * height, min_val=self.__height
+            )
+            new_width = self.constrain_to_multiple_of(
+                scale_width * width, min_val=self.__width
+            )
+        elif self.__resize_method == "upper_bound":
+            new_height = self.constrain_to_multiple_of(
+                scale_height * height, max_val=self.__height
+            )
+            new_width = self.constrain_to_multiple_of(
+                scale_width * width, max_val=self.__width
+            )
+        elif self.__resize_method == "minimal":
+            new_height = self.constrain_to_multiple_of(scale_height * height)
+            new_width = self.constrain_to_multiple_of(scale_width * width)
+        else:
+            raise ValueError(f"resize_method {self.__resize_method} not implemented")
+        return (new_width, new_height)
+    def __call__(self, sample):
+        """Resize sample["image"] and, if resize_target is set, the target entries present in sample."""
+        width, height = self.get_size(
+            sample["image"].shape[1], sample["image"].shape[0]
+        )
+        # resize sample
+        sample["image"] = cv2.resize(
+            sample["image"],
+            (width, height),
+            interpolation=self.__image_interpolation_method,
+        )
+        if self.__resize_target:
+            if "disparity" in sample:
+                sample["disparity"] = cv2.resize(
+                    sample["disparity"],
+                    (width, height),
+                    interpolation=cv2.INTER_NEAREST,
+                )
+            if "depth" in sample:
+                sample["depth"] = cv2.resize(
+                    sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
+                )
+            # The mask is optional like the other targets; indexing it
+            # unconditionally raised KeyError for samples without one.
+            if "mask" in sample:
+                sample["mask"] = cv2.resize(
+                    sample["mask"].astype(np.float32),
+                    (width, height),
+                    interpolation=cv2.INTER_NEAREST,
+                )
+                sample["mask"] = sample["mask"].astype(bool)
+        return sample
+class NormalizeImage(object):
+    """Normalize image by given mean and std.
+    """
+    def __init__(self, mean, std):
+        self.__mean, self.__std = mean, std
+    def __call__(self, sample):
+        """Replace sample["image"] with (image - mean) / std and return the sample."""
+        centered = sample["image"] - self.__mean
+        sample["image"] = centered / self.__std
+        return sample
+class PrepareForNet(object):
+    """Prepare sample for usage as network input.
+    """
+    def __init__(self):
+        pass
+    def __call__(self, sample):
+        """Move the image channels to the front and make all known entries contiguous float32 arrays."""
+        channels_first = np.transpose(sample["image"], (2, 0, 1))
+        sample["image"] = np.ascontiguousarray(channels_first).astype(np.float32)
+        # Optional targets keep their layout; only dtype and contiguity change.
+        for key in ("mask", "disparity", "depth"):
+            if key in sample:
+                as_float = sample[key].astype(np.float32)
+                sample[key] = np.ascontiguousarray(as_float)
+        return sample
diff --git a/ControlNet/annotator/midas/midas/vit.py b/ControlNet/annotator/midas/midas/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea46b1be88b261b0dec04f3da0256f5f66f88a74
--- /dev/null
+++ b/ControlNet/annotator/midas/midas/vit.py
@@ -0,0 +1,491 @@
+import torch
+import torch.nn as nn
+import timm
+import types
+import math
+import torch.nn.functional as F
+class Slice(nn.Module):
+    """Keep only the entries from start_index onwards along dim 1."""
+    def __init__(self, start_index=1):
+        super().__init__()
+        self.start_index = start_index
+    def forward(self, x):
+        first_kept = self.start_index
+        return x[:, first_kept:]
+class AddReadout(nn.Module):
+    """Add the readout entry (index 0 of dim 1) to every entry from start_index onwards.
+    With start_index == 2 the readout is the mean of entries 0 and 1.
+    """
+    def __init__(self, start_index=1):
+        super().__init__()
+        self.start_index = start_index
+    def forward(self, x):
+        readout = x[:, 0]
+        if self.start_index == 2:
+            readout = (readout + x[:, 1]) / 2
+        remaining = x[:, self.start_index :]
+        return remaining + readout.unsqueeze(1)
+class ProjectReadout(nn.Module):
+    """Concatenate the readout entry (index 0 of dim 1) onto every remaining entry and project back to in_features."""
+    def __init__(self, in_features, start_index=1):
+        super().__init__()
+        self.start_index = start_index
+        self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
+    def forward(self, x):
+        remaining = x[:, self.start_index :]
+        # Broadcast the single readout entry to one copy per remaining entry.
+        readout = x[:, 0].unsqueeze(1).expand_as(remaining)
+        return self.project(torch.cat((remaining, readout), -1))
+class Transpose(nn.Module):
+    """Module wrapper around Tensor.transpose(dim0, dim1)."""
+    def __init__(self, dim0, dim1):
+        super().__init__()
+        self.dim0, self.dim1 = dim0, dim1
+    def forward(self, x):
+        return x.transpose(self.dim0, self.dim1)
+def forward_vit(pretrained, x):
+    """Run the wrapped model on x and return the four post-processed activations.
+    Args:
+        pretrained: wrapper exposing model, activations and act_postprocess1..4
+        x (tensor): 4-dimensional input whose last two dims are height and width
+    Returns:
+        tuple: (layer_1, layer_2, layer_3, layer_4)
+    """
+    _, _, h, w = x.shape
+    # The return value is not needed; the call is made so that
+    # pretrained.activations holds the entries read below.
+    pretrained.model.forward_flex(x)
+    postprocess = (
+        pretrained.act_postprocess1,
+        pretrained.act_postprocess2,
+        pretrained.act_postprocess3,
+        pretrained.act_postprocess4,
+    )
+    raw = [pretrained.activations[key] for key in ("1", "2", "3", "4")]
+    # Entries 0-1 of each post-processing chain run first.
+    staged = [post[0:2](act) for post, act in zip(postprocess, raw)]
+    # Entry 2 of each chain is skipped in favour of an unflatten sized from
+    # the actual input resolution.
+    patch = pretrained.model.patch_size
+    unflatten = nn.Sequential(
+        nn.Unflatten(2, torch.Size([h // patch[1], w // patch[0]]))
+    )
+    staged = [unflatten(t) if t.ndim == 3 else t for t in staged]
+    # The remaining entries (index 3 onwards) finish each level.
+    layer_1, layer_2, layer_3, layer_4 = (
+        post[3 : len(post)](t) for post, t in zip(postprocess, staged)
+    )
+    return layer_1, layer_2, layer_3, layer_4
+def _resize_pos_embed(self, posemb, gs_h, gs_w):
+    """Bilinearly resize the grid part of a position embedding to gs_h x gs_w.
+    The first self.start_index entries along dim 1 are kept unchanged and
+    prepended to the resized grid; the remaining entries are treated as a
+    square grid.
+    """
+    n_prefix = self.start_index
+    prefix = posemb[:, :n_prefix]
+    grid = posemb[0, n_prefix:]
+    old_side = int(math.sqrt(len(grid)))
+    # (N, C) -> (1, C, side, side) so F.interpolate can resize it.
+    grid = grid.reshape(1, old_side, old_side, -1).permute(0, 3, 1, 2)
+    grid = F.interpolate(grid, size=(gs_h, gs_w), mode="bilinear")
+    # Back to a flat (1, gs_h * gs_w, C) sequence.
+    grid = grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
+    return torch.cat([prefix, grid], dim=1)
+def forward_flex(self, x):
+    """Forward pass with the position embedding resized to the input resolution.
+    Args:
+        x (tensor): 4-dimensional input whose last two dims are height and width
+    Returns:
+        tensor: output of self.norm after all blocks
+    """
+    batch, _, h, w = x.shape
+    pos_embed = self._resize_pos_embed(
+        self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
+    )
+    embed = self.patch_embed
+    if hasattr(embed, "backbone"):
+        x = embed.backbone(x)
+        if isinstance(x, (list, tuple)):
+            x = x[-1]  # last feature if backbone outputs list/tuple of features
+    x = embed.proj(x).flatten(2).transpose(1, 2)
+    # Prefix tokens: always the class token, plus the distillation token when
+    # the model has one.
+    prefix = [self.cls_token.expand(batch, -1, -1)]
+    if getattr(self, "dist_token", None) is not None:
+        prefix.append(self.dist_token.expand(batch, -1, -1))
+    x = torch.cat((*prefix, x), dim=1)
+    x = self.pos_drop(x + pos_embed)
+    for blk in self.blocks:
+        x = blk(x)
+    return self.norm(x)
+# Module-level store written by every hook created through get_activation.
+activations = {}
+def get_activation(name):
+    """Return a forward hook that stores the hooked module's output in activations[name]."""
+    def _store_output(module, inputs, output):
+        activations[name] = output
+    return _store_output
+def get_readout_oper(vit_features, features, use_readout, start_index=1):
+    """Build one readout-handling module per entry of features.
+    Args:
+        vit_features (int): in_features passed to ProjectReadout (used for "project" only)
+        features (list): only its length is used; one module is returned per entry
+        use_readout (str): "ignore" (Slice), "add" (AddReadout) or "project" (ProjectReadout)
+        start_index (int, optional): forwarded to the readout modules. Defaults to 1.
+    Returns:
+        list: readout modules; for "ignore" and "add" every entry is the same module instance
+    Raises:
+        AssertionError: if use_readout is not one of the supported values
+    """
+    if use_readout == "ignore":
+        readout_oper = [Slice(start_index)] * len(features)
+    elif use_readout == "add":
+        readout_oper = [AddReadout(start_index)] * len(features)
+    elif use_readout == "project":
+        readout_oper = [
+            ProjectReadout(vit_features, start_index) for _ in features
+        ]
+    else:
+        # Raised explicitly: an `assert False` is stripped under `python -O`,
+        # which would fall through to an UnboundLocalError at the return.
+        raise AssertionError(
+            "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
+        )
+    return readout_oper
+def _make_vit_b16_backbone(
+    model,
+    features=[96, 192, 384, 768],
+    size=[384, 384],
+    hooks=[2, 5, 8, 11],
+    vit_features=768,
+    use_readout="ignore",
+    start_index=1,
+):
+    """Wrap a ViT model with activation hooks and per-level post-processing.
+    Args:
+        model: model exposing a `blocks` sequence; hooked and patched in place
+        features (list, optional): output channels of the four levels
+        size (list, optional): reference input size; divided by 16 for the fixed unflatten
+        hooks (list, optional): indices into model.blocks whose outputs are captured
+        vit_features (int, optional): channel width of the captured activations
+        use_readout (str, optional): forwarded to get_readout_oper
+        start_index (int, optional): forwarded to get_readout_oper and stored on the model
+    Returns:
+        nn.Module: wrapper with model, activations and act_postprocess1..4
+    """
+    pretrained = nn.Module()
+    pretrained.model = model
+    # Capture the outputs of the four selected blocks under the keys "1".."4".
+    for level in range(4):
+        pretrained.model.blocks[hooks[level]].register_forward_hook(
+            get_activation(str(level + 1))
+        )
+    # NOTE: this is the module-level dict, so every backbone built here
+    # shares the same activation store.
+    pretrained.activations = activations
+    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
+    def _head(level):
+        # Entries shared by all levels: readout handling, transpose,
+        # unflatten to a 2D grid, 1x1 projection to the level's width.
+        return [
+            readout_oper[level],
+            Transpose(1, 2),
+            nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+            nn.Conv2d(
+                in_channels=vit_features,
+                out_channels=features[level],
+                kernel_size=1,
+                stride=1,
+                padding=0,
+            ),
+        ]
+    def _upsample(channels, factor):
+        # Transposed convolution whose kernel size equals its stride.
+        return nn.ConvTranspose2d(
+            in_channels=channels,
+            out_channels=channels,
+            kernel_size=factor,
+            stride=factor,
+            padding=0,
+            bias=True,
+            dilation=1,
+            groups=1,
+        )
+    # 32, 48, 136, 384
+    pretrained.act_postprocess1 = nn.Sequential(*_head(0), _upsample(features[0], 4))
+    pretrained.act_postprocess2 = nn.Sequential(*_head(1), _upsample(features[1], 2))
+    pretrained.act_postprocess3 = nn.Sequential(*_head(2))
+    pretrained.act_postprocess4 = nn.Sequential(
+        *_head(3),
+        nn.Conv2d(
+            in_channels=features[3],
+            out_channels=features[3],
+            kernel_size=3,
+            stride=2,
+            padding=1,
+        ),
+    )
+    pretrained.model.start_index = start_index
+    pretrained.model.patch_size = [16, 16]
+    # We inject these functions into the VisionTransformer instance so that
+    # we can use it with interpolated position embeddings without modifying the library source.
+    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
+    pretrained.model._resize_pos_embed = types.MethodType(
+        _resize_pos_embed, pretrained.model
+    )
+    return pretrained
+def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
+    """Create the timm "vit_large_patch16_384" model and wrap it as a backbone.
+    Args:
+        pretrained: forwarded to timm.create_model
+        use_readout (str, optional): forwarded to _make_vit_b16_backbone
+        hooks (list, optional): block indices to capture. Defaults to [5, 11, 17, 23].
+    """
+    model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
+    hooks = [5, 11, 17, 23] if hooks is None else hooks
+    return _make_vit_b16_backbone(
+        model,
+        features=[256, 512, 1024, 1024],
+        hooks=hooks,
+        vit_features=1024,
+        use_readout=use_readout,
+    )
+def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
+    """Create the timm "vit_base_patch16_384" model and wrap it as a backbone.
+    Args:
+        pretrained: forwarded to timm.create_model
+        use_readout (str, optional): forwarded to _make_vit_b16_backbone
+        hooks (list, optional): block indices to capture. Defaults to [2, 5, 8, 11].
+    """
+    model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
+    hooks = [2, 5, 8, 11] if hooks is None else hooks
+    return _make_vit_b16_backbone(
+        model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
+    )
+def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
+    """Create the timm "vit_deit_base_patch16_384" model and wrap it as a backbone.
+    Args:
+        pretrained: forwarded to timm.create_model
+        use_readout (str, optional): forwarded to _make_vit_b16_backbone
+        hooks (list, optional): block indices to capture. Defaults to [2, 5, 8, 11].
+    """
+    model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
+    hooks = [2, 5, 8, 11] if hooks is None else hooks
+    return _make_vit_b16_backbone(
+        model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
+    )
+def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
+    """Create the timm "vit_deit_base_distilled_patch16_384" model and wrap it as a backbone.
+    The backbone is built with start_index=2.
+    Args:
+        pretrained: forwarded to timm.create_model
+        use_readout (str, optional): forwarded to _make_vit_b16_backbone
+        hooks (list, optional): block indices to capture. Defaults to [2, 5, 8, 11].
+    """
+    model = timm.create_model(
+        "vit_deit_base_distilled_patch16_384", pretrained=pretrained
+    )
+    hooks = [2, 5, 8, 11] if hooks is None else hooks
+    return _make_vit_b16_backbone(
+        model,
+        features=[96, 192, 384, 768],
+        hooks=hooks,
+        use_readout=use_readout,
+        start_index=2,
+    )
+def _make_vit_b_rn50_backbone(
+    model,
+    features=[256, 512, 768, 768],
+    size=[384, 384],
+    hooks=[0, 1, 8, 11],
+    vit_features=768,
+    use_vit_only=False,
+    use_readout="ignore",
+    start_index=1,
+):
+    """Wrap a hybrid (convolutional stem + ViT) model with hooks and post-processing.
+    Args:
+        model: model exposing `blocks` and `patch_embed.backbone.stages`; hooked and patched in place
+        features (list, optional): output channels of the four levels
+        size (list, optional): reference input size; divided by 16 for the fixed unflatten
+        hooks (list, optional): indices into model.blocks; the first two are only used with use_vit_only
+        vit_features (int, optional): channel width of the captured transformer activations
+        use_vit_only (bool, optional): take levels 1 and 2 from transformer blocks instead of the stem stages
+        use_readout (str, optional): forwarded to get_readout_oper
+        start_index (int, optional): forwarded to get_readout_oper and stored on the model
+    Returns:
+        nn.Module: wrapper with model, activations and act_postprocess1..4
+    """
+    pretrained = nn.Module()
+    pretrained.model = model
+    blocks = pretrained.model.blocks
+    if use_vit_only:
+        blocks[hooks[0]].register_forward_hook(get_activation("1"))
+        blocks[hooks[1]].register_forward_hook(get_activation("2"))
+    else:
+        stem_stages = pretrained.model.patch_embed.backbone.stages
+        stem_stages[0].register_forward_hook(get_activation("1"))
+        stem_stages[1].register_forward_hook(get_activation("2"))
+    blocks[hooks[2]].register_forward_hook(get_activation("3"))
+    blocks[hooks[3]].register_forward_hook(get_activation("4"))
+    # NOTE: this is the module-level dict, so every backbone built here
+    # shares the same activation store.
+    pretrained.activations = activations
+    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
+    def _head(level):
+        # Entries shared by all transformer levels: readout handling,
+        # transpose, unflatten to a 2D grid, 1x1 projection.
+        return [
+            readout_oper[level],
+            Transpose(1, 2),
+            nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+            nn.Conv2d(
+                in_channels=vit_features,
+                out_channels=features[level],
+                kernel_size=1,
+                stride=1,
+                padding=0,
+            ),
+        ]
+    def _upsample(channels, factor):
+        # Transposed convolution whose kernel size equals its stride.
+        return nn.ConvTranspose2d(
+            in_channels=channels,
+            out_channels=channels,
+            kernel_size=factor,
+            stride=factor,
+            padding=0,
+            bias=True,
+            dilation=1,
+            groups=1,
+        )
+    if use_vit_only:
+        pretrained.act_postprocess1 = nn.Sequential(*_head(0), _upsample(features[0], 4))
+        pretrained.act_postprocess2 = nn.Sequential(*_head(1), _upsample(features[1], 2))
+    else:
+        # Stem activations pass through unchanged. Three entries are kept so
+        # the [0:2] and [3:] slices taken in forward_vit remain valid.
+        pretrained.act_postprocess1 = nn.Sequential(
+            nn.Identity(), nn.Identity(), nn.Identity()
+        )
+        pretrained.act_postprocess2 = nn.Sequential(
+            nn.Identity(), nn.Identity(), nn.Identity()
+        )
+    pretrained.act_postprocess3 = nn.Sequential(*_head(2))
+    pretrained.act_postprocess4 = nn.Sequential(
+        *_head(3),
+        nn.Conv2d(
+            in_channels=features[3],
+            out_channels=features[3],
+            kernel_size=3,
+            stride=2,
+            padding=1,
+        ),
+    )
+    pretrained.model.start_index = start_index
+    pretrained.model.patch_size = [16, 16]
+    # We inject these functions into the VisionTransformer instance so that
+    # we can use it with interpolated position embeddings without modifying the library source.
+    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
+    pretrained.model._resize_pos_embed = types.MethodType(
+        _resize_pos_embed, pretrained.model
+    )
+    return pretrained
+def _make_pretrained_vitb_rn50_384(
+    pretrained, use_readout="ignore", hooks=None, use_vit_only=False
+):
+    """Create the timm "vit_base_resnet50_384" model and wrap it as a hybrid backbone.
+    Args:
+        pretrained: forwarded to timm.create_model
+        use_readout (str, optional): forwarded to _make_vit_b_rn50_backbone
+        hooks (list, optional): block indices to capture. Defaults to [0, 1, 8, 11].
+        use_vit_only (bool, optional): forwarded to _make_vit_b_rn50_backbone
+    """
+    model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
+    hooks = [0, 1, 8, 11] if hooks is None else hooks
+    return _make_vit_b_rn50_backbone(
+        model,
+        features=[256, 512, 768, 768],
+        size=[384, 384],
+        hooks=hooks,
+        use_vit_only=use_vit_only,
+        use_readout=use_readout,
+    )
diff --git a/ControlNet/annotator/midas/utils.py b/ControlNet/annotator/midas/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a9d3b5b66370fa98da9e067ba53ead848ea9a59
--- /dev/null
+++ b/ControlNet/annotator/midas/utils.py
@@ -0,0 +1,189 @@
+"""Utils for monoDepth."""
+import sys
+import re
+import numpy as np
+import cv2
+import torch
+def read_pfm(path):
+    """Read pfm file.
+    Args:
+        path (str): path to file
+    Returns:
+        tuple: (data, scale)
+    """
+    with open(path, "rb") as file:
+        # First header line: "PF" marks a 3-channel file, "Pf" a single-channel one.
+        magic = file.readline().rstrip().decode("ascii")
+        if magic not in ("PF", "Pf"):
+            raise Exception("Not a PFM file: " + path)
+        color = magic == "PF"
+        dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
+        if not dim_match:
+            raise Exception("Malformed PFM header.")
+        width, height = (int(group) for group in dim_match.groups())
+        # Third header line: the sign selects the byte order, the magnitude is the scale.
+        scale = float(file.readline().decode("ascii").rstrip())
+        endian = "<" if scale < 0 else ">"
+        if scale < 0:
+            scale = -scale
+        raw = np.fromfile(file, endian + "f")
+        shape = (height, width, 3) if color else (height, width)
+        # Flip the row order of what is stored in the file.
+        data = np.flipud(np.reshape(raw, shape))
+        return data, scale
+def write_pfm(path, image, scale=1):
+    """Write pfm file.
+    Args:
+        path (str): path to file
+        image (array): float32 data of shape H x W x 3, H x W x 1 or H x W
+        scale (int, optional): Scale. Defaults to 1.
+    Raises:
+        Exception: if image is not float32 or has an unsupported shape
+    """
+    # Validate before opening so a bad input does not create or truncate the file.
+    if image.dtype.name != "float32":
+        raise Exception("Image dtype must be float32.")
+    if len(image.shape) == 3 and image.shape[2] == 3:  # color image
+        color = True
+    elif (
+        len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
+    ):  # greyscale
+        color = False
+    else:
+        raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
+    # Row order is flipped on write.
+    image = np.flipud(image)
+    with open(path, "wb") as file:
+        # Encode the whole conditional: previously only "Pf\n" was encoded, so
+        # colour images wrote a str to the binary file and raised TypeError.
+        file.write(("PF\n" if color else "Pf\n").encode())
+        file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
+        endian = image.dtype.byteorder
+        # A negative scale marks little-endian data in the header.
+        if endian == "<" or endian == "=" and sys.byteorder == "little":
+            scale = -scale
+        file.write("%f\n".encode() % scale)
+        image.tofile(file)
+def read_image(path):
+    """Read image and output RGB image (0-1).
+    Args:
+        path (str): path to file
+    Returns:
+        array: RGB image (0-1)
+    Raises:
+        OSError: if the file cannot be read as an image
+    """
+    img = cv2.imread(path)
+    if img is None:
+        # cv2.imread signals failure by returning None instead of raising;
+        # fail here with the path rather than with an AttributeError below.
+        raise OSError("Could not read image: " + str(path))
+    if img.ndim == 2:
+        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
+    return img
+def resize_image(img):
+    """Resize image and make it fit for network.
+    The longer side is scaled to 384 and both sides are then rounded up to a
+    multiple of 32.
+    Args:
+        img (array): image
+    Returns:
+        tensor: data ready for network
+    """
+    height_orig, width_orig = img.shape[0], img.shape[1]
+    scale = max(width_orig, height_orig) / 384
+    height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
+    width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
+    resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
+    # Channels first, float, plus a leading batch dimension.
+    channels_first = np.transpose(resized, (2, 0, 1))
+    tensor = torch.from_numpy(channels_first).contiguous().float()
+    return tensor.unsqueeze(0)
+def resize_depth(depth, width, height):
+    """Resize depth map and bring to CPU (numpy).
+    Args:
+        depth (tensor): depth
+        width (int): image width
+        height (int): image height
+    Returns:
+        array: processed depth
+    """
+    # Only the first element along dim 0 is used.
+    first = depth[0, :, :, :]
+    depth_cpu = torch.squeeze(first).to("cpu")
+    return cv2.resize(
+        depth_cpu.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
+    )
+def write_depth(path, depth, bits=1):
+    """Write depth map to pfm and png file.
+    The png holds the depth rescaled to the full range of the chosen integer
+    type; a constant depth map is written as all zeros.
+    Args:
+        path (str): filepath without extension
+        depth (array): depth
+        bits (int, optional): bytes per png pixel; 1 writes uint8, 2 writes uint16,
+            any other value writes no png. Defaults to 1.
+    """
+    write_pfm(path + ".pfm", depth.astype(np.float32))
+    depth_min = depth.min()
+    depth_max = depth.max()
+    max_val = (2**(8*bits))-1
+    if depth_max - depth_min > np.finfo("float").eps:
+        out = max_val * (depth - depth_min) / (depth_max - depth_min)
+    else:
+        # ndarray has no `.type` attribute; `.dtype` is the element type.
+        out = np.zeros(depth.shape, dtype=depth.dtype)
+    if bits == 1:
+        cv2.imwrite(path + ".png", out.astype("uint8"))
+    elif bits == 2:
+        cv2.imwrite(path + ".png", out.astype("uint16"))
+    return
diff --git a/ControlNet/annotator/mlsd/__init__.py b/ControlNet/annotator/mlsd/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..42af28c682e781b30f691f65a475b53c9f3adc8b
--- /dev/null
+++ b/ControlNet/annotator/mlsd/__init__.py
@@ -0,0 +1,39 @@
+import cv2
+import numpy as np
+import torch
+import os
+from einops import rearrange
+from .models.mbv2_mlsd_tiny import MobileV2_MLSD_Tiny
+from .models.mbv2_mlsd_large import MobileV2_MLSD_Large
+from .utils import pred_lines
+from annotator.util import annotator_ckpts_path
+remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/mlsd_large_512_fp32.pth"
+class MLSDdetector:
+    """Draws the line segments predicted by MobileV2_MLSD_Large onto a blank map."""
+    def __init__(self):
+        """Load the checkpoint (downloading it first if it is missing) and move the model to CUDA in eval mode."""
+        model_path = os.path.join(annotator_ckpts_path, "mlsd_large_512_fp32.pth")
+        if not os.path.exists(model_path):
+            from basicsr.utils.download_util import load_file_from_url
+            load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
+        model = MobileV2_MLSD_Large()
+        model.load_state_dict(torch.load(model_path), strict=True)
+        self.model = model.cuda().eval()
+    def __call__(self, input_image, thr_v, thr_d):
+        """Detect lines in input_image and return them drawn in white on black.
+        Args:
+            input_image (array): 3-dimensional image array
+            thr_v: forwarded to pred_lines
+            thr_d: forwarded to pred_lines
+        Returns:
+            array: first channel of the drawn map; all zeros if detection failed before drawing
+        """
+        assert input_image.ndim == 3
+        img = input_image
+        img_output = np.zeros_like(img)
+        try:
+            with torch.no_grad():
+                lines = pred_lines(img, self.model, [img.shape[0], img.shape[1]], thr_v, thr_d)
+                for line in lines:
+                    x_start, y_start, x_end, y_end = [int(val) for val in line]
+                    cv2.line(img_output, (x_start, y_start), (x_end, y_end), [255, 255, 255], 1)
+        except Exception:
+            # Detection stays best-effort (whatever was drawn so far is still
+            # returned), but the failure is recorded instead of vanishing.
+            import logging
+            logging.getLogger(__name__).debug("MLSD line detection failed", exc_info=True)
+        return img_output[:, :, 0]
diff --git a/ControlNet/annotator/mlsd/models/mbv2_mlsd_large.py b/ControlNet/annotator/mlsd/models/mbv2_mlsd_large.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b9799e7573ca41549b3c3b13ac47b906b369603
--- /dev/null
+++ b/ControlNet/annotator/mlsd/models/mbv2_mlsd_large.py
@@ -0,0 +1,292 @@
+import os
+import sys
+import torch
+import torch.nn as nn
+import torch.utils.model_zoo as model_zoo
+from torch.nn import functional as F
+class BlockTypeA(nn.Module):
+    """Project two feature maps with 1x1 convs and concatenate them on channels.
+
+    ``conv1`` projects the second input ``b`` (in_c2 -> out_c2) and ``conv2``
+    projects the first input ``a`` (in_c1 -> out_c1). When ``upscale`` is true,
+    the projected ``b`` is bilinearly upsampled by 2 before concatenation.
+    """
+
+    def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale=True):
+        super().__init__()
+        branch_b = [
+            nn.Conv2d(in_c2, out_c2, kernel_size=1),
+            nn.BatchNorm2d(out_c2),
+            nn.ReLU(inplace=True),
+        ]
+        self.conv1 = nn.Sequential(*branch_b)
+        branch_a = [
+            nn.Conv2d(in_c1, out_c1, kernel_size=1),
+            nn.BatchNorm2d(out_c1),
+            nn.ReLU(inplace=True),
+        ]
+        self.conv2 = nn.Sequential(*branch_a)
+        self.upscale = upscale
+
+    def forward(self, a, b):
+        """Return ``cat(conv2(a), conv1(b))`` along dim 1, upsampling ``b`` if enabled."""
+        projected_b = self.conv1(b)
+        projected_a = self.conv2(a)
+        if self.upscale:
+            projected_b = F.interpolate(
+                projected_b, scale_factor=2.0, mode='bilinear', align_corners=True
+            )
+        return torch.cat((projected_a, projected_b), dim=1)
+class BlockTypeB(nn.Module):
+    """Residual 3x3 conv stage followed by a 3x3 conv that maps in_c -> out_c."""
+
+    def __init__(self, in_c, out_c):
+        super().__init__()
+        refine = [
+            nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
+            nn.BatchNorm2d(in_c),
+            nn.ReLU(),
+        ]
+        self.conv1 = nn.Sequential(*refine)
+        project = [
+            nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
+            nn.BatchNorm2d(out_c),
+            nn.ReLU(),
+        ]
+        self.conv2 = nn.Sequential(*project)
+
+    def forward(self, x):
+        """Return ``conv2(conv1(x) + x)``."""
+        residual_sum = self.conv1(x) + x
+        return self.conv2(residual_sum)
+class BlockTypeC(nn.Module):
+    """Output head: dilated 3x3 conv, plain 3x3 conv, then a 1x1 conv to out_c."""
+
+    def __init__(self, in_c, out_c):
+        super().__init__()
+        # padding equals dilation, so the 3x3 dilated conv keeps the spatial size.
+        dilated = [
+            nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5),
+            nn.BatchNorm2d(in_c),
+            nn.ReLU(),
+        ]
+        self.conv1 = nn.Sequential(*dilated)
+        plain = [
+            nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
+            nn.BatchNorm2d(in_c),
+            nn.ReLU(),
+        ]
+        self.conv2 = nn.Sequential(*plain)
+        self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1)
+
+    def forward(self, x):
+        """Apply conv1, conv2 and conv3 in sequence."""
+        return self.conv3(self.conv2(self.conv1(x)))
+def _make_divisible(v, divisor, min_value=None):
+    """
+    Round ``v`` to the nearest multiple of ``divisor``, never below ``min_value``.
+    Adapted from the TensorFlow MobileNet reference:
+    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
+    :param v: value to round (e.g. a channel count)
+    :param divisor: the result is a multiple of this unless ``min_value`` wins
+    :param min_value: lower bound for the result; defaults to ``divisor``
+    :return: the rounded value
+    """
+    lower_bound = divisor if min_value is None else min_value
+    nearest_multiple = int(v + divisor / 2) // divisor * divisor
+    result = max(lower_bound, nearest_multiple)
+    # Rounding down must not lose more than 10% of the requested value.
+    if result < 0.9 * v:
+        result += divisor
+    return result
+class ConvBNReLU(nn.Sequential):
+    """Conv2d -> BatchNorm2d -> ReLU6 with TFLite-style padding for stride 2."""
+
+    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
+        # TFLite pads differently from PyTorch: stride-2 convs use no built-in
+        # padding here and forward() pads the right/bottom edges instead.
+        padding = 0 if stride == 2 else (kernel_size - 1) // 2
+        super().__init__(
+            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
+            nn.BatchNorm2d(out_planes),
+            nn.ReLU6(inplace=True),
+        )
+        self.channel_pad = out_planes - in_planes
+        self.stride = stride
+        # Registered as a child module but skipped by forward().
+        self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
+
+    def forward(self, x):
+        """Pad right/bottom by one for stride 2, then run every non-pooling child."""
+        if self.stride == 2:
+            x = F.pad(x, (0, 1, 0, 1), "constant", 0)
+        for layer in self:
+            if isinstance(layer, nn.MaxPool2d):
+                continue
+            x = layer(x)
+        return x
+class InvertedResidual(nn.Module):
+    """MobileNetV2 inverted residual: optional 1x1 expand, 3x3 depthwise, 1x1 linear."""
+
+    def __init__(self, inp, oup, stride, expand_ratio):
+        super().__init__()
+        self.stride = stride
+        assert stride in [1, 2]
+        hidden_dim = int(round(inp * expand_ratio))
+        # The skip connection is only valid when the output shape matches the input.
+        self.use_res_connect = self.stride == 1 and inp == oup
+        # Pointwise expansion is skipped when the expand ratio is 1.
+        layers = [ConvBNReLU(inp, hidden_dim, kernel_size=1)] if expand_ratio != 1 else []
+        # Depthwise 3x3 (groups == channels).
+        layers.append(ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim))
+        # Linear pointwise projection (no activation).
+        layers.append(nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False))
+        layers.append(nn.BatchNorm2d(oup))
+        self.conv = nn.Sequential(*layers)
+
+    def forward(self, x):
+        """Return ``conv(x)``, plus ``x`` when the residual connection is enabled."""
+        if not self.use_res_connect:
+            return self.conv(x)
+        return x + self.conv(x)
+class MobileNetV2(nn.Module):
+    """Truncated MobileNetV2 feature extractor used as the M-LSD (large) backbone.
+
+    The stem convolution takes 4 input channels and only the first five
+    inverted-residual stages are built. The forward pass returns the outputs
+    of the layers whose indices are listed in ``self.fpn_selected``.
+    """
+
+    def __init__(self, pretrained=True):
+        """
+        Args:
+            pretrained (bool): If true, after random initialization copy the
+                shape-compatible tensors from the MobileNetV2 checkpoint
+                hosted on download.pytorch.org.
+        """
+        super(MobileNetV2, self).__init__()
+        block = InvertedResidual
+        input_channel = 32
+        last_channel = 1280
+        width_mult = 1.0
+        round_nearest = 8
+        # The final two stages of the standard network ([6, 160, 3, 2] and
+        # [6, 320, 1, 1]) are intentionally left out.
+        inverted_residual_setting = [
+            # t, c, n, s
+            [1, 16, 1, 1],
+            [6, 24, 2, 2],
+            [6, 32, 3, 2],
+            [6, 64, 4, 2],
+            [6, 96, 3, 1],
+        ]
+        # only check the first element, assuming user knows t,c,n,s are required
+        if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
+            raise ValueError("inverted_residual_setting should be non-empty "
+                             "or a 4-element list, got {}".format(inverted_residual_setting))
+        # building first layer
+        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
+        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
+        features = [ConvBNReLU(4, input_channel, stride=2)]
+        # building inverted residual blocks
+        for t, c, n, s in inverted_residual_setting:
+            output_channel = _make_divisible(c * width_mult, round_nearest)
+            for i in range(n):
+                # Only the first block of a stage applies the stage stride.
+                stride = s if i == 0 else 1
+                features.append(block(input_channel, output_channel, stride, expand_ratio=t))
+                input_channel = output_channel
+        self.features = nn.Sequential(*features)
+        # Indices into self.features whose outputs are returned by forward().
+        self.fpn_selected = [1, 3, 6, 10, 13]
+        # weight initialization
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out')
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.ones_(m.weight)
+                nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, 0, 0.01)
+                nn.init.zeros_(m.bias)
+        if pretrained:
+            self._load_pretrained_model()
+
+    def _forward_impl(self, x):
+        # This exists since TorchScript doesn't support inheritance, so the superclass method
+        # (this one) needs to have a name other than `forward` that can be accessed in a subclass
+        fpn_features = []
+        for i, f in enumerate(self.features):
+            # Nothing past the last selected layer contributes to the output.
+            if i > self.fpn_selected[-1]:
+                break
+            x = f(x)
+            if i in self.fpn_selected:
+                fpn_features.append(x)
+        c1, c2, c3, c4, c5 = fpn_features
+        return c1, c2, c3, c4, c5
+
+    def forward(self, x):
+        """Return the five selected feature maps ``(c1, c2, c3, c4, c5)``."""
+        return self._forward_impl(x)
+
+    def _load_pretrained_model(self):
+        """Copy checkpoint tensors whose name AND shape match this model.
+
+        Matching on the name alone makes ``load_state_dict`` raise on any
+        size mismatch, e.g. the stem conv here has 4 input channels, which a
+        checkpoint trained on 3-channel images would not match.
+        """
+        pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth')
+        state_dict = self.state_dict()
+        compatible = {
+            k: v for k, v in pretrain_dict.items()
+            if k in state_dict and v.shape == state_dict[k].shape
+        }
+        state_dict.update(compatible)
+        self.load_state_dict(state_dict)
+class MobileV2_MLSD_Large(nn.Module):
+    """M-LSD (large) network: MobileNetV2 backbone plus a top-down decoder.
+
+    Each decoder step fuses a backbone feature with the running decoder
+    feature (BlockTypeA) and refines it (BlockTypeB); a BlockTypeC head
+    produces 16 channels, of which channels 7 onwards are returned.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.backbone = MobileNetV2(pretrained=False)
+        ## A, B (c4 and c5 are fused without upsampling)
+        self.block15 = BlockTypeA(in_c1=64, in_c2=96, out_c1=64, out_c2=64, upscale=False)
+        self.block16 = BlockTypeB(128, 64)
+        ## A, B
+        self.block17 = BlockTypeA(in_c1=32, in_c2=64, out_c1=64, out_c2=64)
+        self.block18 = BlockTypeB(128, 64)
+        ## A, B
+        self.block19 = BlockTypeA(in_c1=24, in_c2=64, out_c1=64, out_c2=64)
+        self.block20 = BlockTypeB(128, 64)
+        ## A, B, C
+        self.block21 = BlockTypeA(in_c1=16, in_c2=64, out_c1=64, out_c2=64)
+        self.block22 = BlockTypeB(128, 64)
+        self.block23 = BlockTypeC(64, 16)
+
+    def forward(self, x):
+        """Run backbone and decoder; return head channels 7 and above."""
+        c1, c2, c3, c4, c5 = self.backbone(x)
+        fused = self.block16(self.block15(c4, c5))
+        fused = self.block18(self.block17(c3, fused))
+        fused = self.block20(self.block19(c2, fused))
+        fused = self.block22(self.block21(c1, fused))
+        head = self.block23(fused)
+        return head[:, 7:, :, :]
\ No newline at end of file
diff --git a/ControlNet/annotator/mlsd/models/mbv2_mlsd_tiny.py b/ControlNet/annotator/mlsd/models/mbv2_mlsd_tiny.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3ed633f2cc23ea1829a627fdb879ab39f641f83
--- /dev/null
+++ b/ControlNet/annotator/mlsd/models/mbv2_mlsd_tiny.py
@@ -0,0 +1,275 @@
+import os
+import sys
+import torch
+import torch.nn as nn
+import torch.utils.model_zoo as model_zoo
+from torch.nn import functional as F
+class BlockTypeA(nn.Module):
+    """Project two feature maps with 1x1 convs and concatenate them on channels.
+
+    ``conv1`` projects the second input ``b`` (in_c2 -> out_c2) and ``conv2``
+    projects the first input ``a`` (in_c1 -> out_c1). When ``upscale`` is true
+    (the default), the projected ``b`` is bilinearly upsampled by 2 first.
+    """
+
+    def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True):
+        super(BlockTypeA, self).__init__()
+        self.conv1 = nn.Sequential(
+            nn.Conv2d(in_c2, out_c2, kernel_size=1),
+            nn.BatchNorm2d(out_c2),
+            nn.ReLU(inplace=True)
+        )
+        self.conv2 = nn.Sequential(
+            nn.Conv2d(in_c1, out_c1, kernel_size=1),
+            nn.BatchNorm2d(out_c1),
+            nn.ReLU(inplace=True)
+        )
+        self.upscale = upscale
+
+    def forward(self, a, b):
+        """Return ``cat(conv2(a), conv1(b))`` along dim 1, upsampling ``b`` if enabled."""
+        b = self.conv1(b)
+        a = self.conv2(a)
+        # Honor the constructor flag; previously it was stored but the
+        # upsampling ran unconditionally. The default (True) is unchanged.
+        if self.upscale:
+            b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True)
+        return torch.cat((a, b), dim=1)
+class BlockTypeB(nn.Module):
+    """Residual 3x3 conv stage followed by a 3x3 conv that maps in_c -> out_c."""
+
+    def __init__(self, in_c, out_c):
+        super().__init__()
+        refine = [
+            nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
+            nn.BatchNorm2d(in_c),
+            nn.ReLU(),
+        ]
+        self.conv1 = nn.Sequential(*refine)
+        project = [
+            nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
+            nn.BatchNorm2d(out_c),
+            nn.ReLU(),
+        ]
+        self.conv2 = nn.Sequential(*project)
+
+    def forward(self, x):
+        """Return ``conv2(conv1(x) + x)``."""
+        residual_sum = self.conv1(x) + x
+        return self.conv2(residual_sum)
+class BlockTypeC(nn.Module):
+    """Output head: dilated 3x3 conv, plain 3x3 conv, then a 1x1 conv to out_c."""
+
+    def __init__(self, in_c, out_c):
+        super().__init__()
+        # padding equals dilation, so the 3x3 dilated conv keeps the spatial size.
+        dilated = [
+            nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5),
+            nn.BatchNorm2d(in_c),
+            nn.ReLU(),
+        ]
+        self.conv1 = nn.Sequential(*dilated)
+        plain = [
+            nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
+            nn.BatchNorm2d(in_c),
+            nn.ReLU(),
+        ]
+        self.conv2 = nn.Sequential(*plain)
+        self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1)
+
+    def forward(self, x):
+        """Apply conv1, conv2 and conv3 in sequence."""
+        return self.conv3(self.conv2(self.conv1(x)))
+def _make_divisible(v, divisor, min_value=None):
+    """
+    Round ``v`` to the nearest multiple of ``divisor``, never below ``min_value``.
+    Adapted from the TensorFlow MobileNet reference:
+    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
+    :param v: value to round (e.g. a channel count)
+    :param divisor: the result is a multiple of this unless ``min_value`` wins
+    :param min_value: lower bound for the result; defaults to ``divisor``
+    :return: the rounded value
+    """
+    lower_bound = divisor if min_value is None else min_value
+    nearest_multiple = int(v + divisor / 2) // divisor * divisor
+    result = max(lower_bound, nearest_multiple)
+    # Rounding down must not lose more than 10% of the requested value.
+    if result < 0.9 * v:
+        result += divisor
+    return result
+class ConvBNReLU(nn.Sequential):
+    """Conv2d -> BatchNorm2d -> ReLU6 with TFLite-style padding for stride 2."""
+
+    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
+        # TFLite pads differently from PyTorch: stride-2 convs use no built-in
+        # padding here and forward() pads the right/bottom edges instead.
+        padding = 0 if stride == 2 else (kernel_size - 1) // 2
+        super().__init__(
+            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
+            nn.BatchNorm2d(out_planes),
+            nn.ReLU6(inplace=True),
+        )
+        self.channel_pad = out_planes - in_planes
+        self.stride = stride
+        # Registered as a child module but skipped by forward().
+        self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
+
+    def forward(self, x):
+        """Pad right/bottom by one for stride 2, then run every non-pooling child."""
+        if self.stride == 2:
+            x = F.pad(x, (0, 1, 0, 1), "constant", 0)
+        for layer in self:
+            if isinstance(layer, nn.MaxPool2d):
+                continue
+            x = layer(x)
+        return x
+class InvertedResidual(nn.Module):
+    """MobileNetV2 inverted residual: optional 1x1 expand, 3x3 depthwise, 1x1 linear."""
+
+    def __init__(self, inp, oup, stride, expand_ratio):
+        super().__init__()
+        self.stride = stride
+        assert stride in [1, 2]
+        hidden_dim = int(round(inp * expand_ratio))
+        # The skip connection is only valid when the output shape matches the input.
+        self.use_res_connect = self.stride == 1 and inp == oup
+        # Pointwise expansion is skipped when the expand ratio is 1.
+        layers = [ConvBNReLU(inp, hidden_dim, kernel_size=1)] if expand_ratio != 1 else []
+        # Depthwise 3x3 (groups == channels).
+        layers.append(ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim))
+        # Linear pointwise projection (no activation).
+        layers.append(nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False))
+        layers.append(nn.BatchNorm2d(oup))
+        self.conv = nn.Sequential(*layers)
+
+    def forward(self, x):
+        """Return ``conv(x)``, plus ``x`` when the residual connection is enabled."""
+        if not self.use_res_connect:
+            return self.conv(x)
+        return x + self.conv(x)
+class MobileNetV2(nn.Module):
+    """Truncated MobileNetV2 feature extractor used as the M-LSD (tiny) backbone.
+
+    The stem convolution takes 4 input channels and only the first four
+    inverted-residual stages are built. The forward pass returns the outputs
+    of the layers whose indices are listed in ``self.fpn_selected``.
+    """
+
+    def __init__(self, pretrained=True):
+        """
+        Args:
+            pretrained (bool): Accepted for interface compatibility but
+                currently has no effect: the constructor never calls
+                ``_load_pretrained_model``.
+        """
+        super(MobileNetV2, self).__init__()
+        block = InvertedResidual
+        input_channel = 32
+        last_channel = 1280
+        width_mult = 1.0
+        round_nearest = 8
+        # The final three stages of the standard network ([6, 96, 3, 1],
+        # [6, 160, 3, 2] and [6, 320, 1, 1]) are intentionally left out.
+        inverted_residual_setting = [
+            # t, c, n, s
+            [1, 16, 1, 1],
+            [6, 24, 2, 2],
+            [6, 32, 3, 2],
+            [6, 64, 4, 2],
+        ]
+        # only check the first element, assuming user knows t,c,n,s are required
+        if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
+            raise ValueError("inverted_residual_setting should be non-empty "
+                             "or a 4-element list, got {}".format(inverted_residual_setting))
+        # building first layer
+        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
+        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
+        features = [ConvBNReLU(4, input_channel, stride=2)]
+        # building inverted residual blocks
+        for t, c, n, s in inverted_residual_setting:
+            output_channel = _make_divisible(c * width_mult, round_nearest)
+            for i in range(n):
+                # Only the first block of a stage applies the stage stride.
+                stride = s if i == 0 else 1
+                features.append(block(input_channel, output_channel, stride, expand_ratio=t))
+                input_channel = output_channel
+        self.features = nn.Sequential(*features)
+        # Indices into self.features whose outputs are returned by forward().
+        self.fpn_selected = [3, 6, 10]
+        # weight initialization
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out')
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.ones_(m.weight)
+                nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, 0, 0.01)
+                nn.init.zeros_(m.bias)
+        # NOTE(review): pretrained loading is disabled here, so `pretrained`
+        # is ignored; confirm this is intended before re-enabling it.
+
+    def _forward_impl(self, x):
+        # This exists since TorchScript doesn't support inheritance, so the superclass method
+        # (this one) needs to have a name other than `forward` that can be accessed in a subclass
+        fpn_features = []
+        for i, f in enumerate(self.features):
+            # Nothing past the last selected layer contributes to the output.
+            if i > self.fpn_selected[-1]:
+                break
+            x = f(x)
+            if i in self.fpn_selected:
+                fpn_features.append(x)
+        c2, c3, c4 = fpn_features
+        return c2, c3, c4
+
+    def forward(self, x):
+        """Return the three selected feature maps ``(c2, c3, c4)``."""
+        return self._forward_impl(x)
+
+    def _load_pretrained_model(self):
+        """Copy checkpoint tensors whose name AND shape match this model.
+
+        Matching on the name alone makes ``load_state_dict`` raise on any
+        size mismatch, e.g. the stem conv here has 4 input channels, which a
+        checkpoint trained on 3-channel images would not match.
+        """
+        pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth')
+        state_dict = self.state_dict()
+        compatible = {
+            k: v for k, v in pretrain_dict.items()
+            if k in state_dict and v.shape == state_dict[k].shape
+        }
+        state_dict.update(compatible)
+        self.load_state_dict(state_dict)
+class MobileV2_MLSD_Tiny(nn.Module):
+    """M-LSD (tiny) network: MobileNetV2 backbone plus a two-step decoder.
+
+    A BlockTypeC head produces 16 channels; channels 7 onwards are kept and
+    bilinearly upsampled by 2 before being returned.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.backbone = MobileNetV2(pretrained=True)
+        self.block12 = BlockTypeA(in_c1=32, in_c2=64, out_c1=64, out_c2=64)
+        self.block13 = BlockTypeB(128, 64)
+        self.block14 = BlockTypeA(in_c1=24, in_c2=64, out_c1=32, out_c2=32)
+        self.block15 = BlockTypeB(64, 64)
+        self.block16 = BlockTypeC(64, 16)
+
+    def forward(self, x):
+        """Run backbone and decoder; return the upsampled head channels 7 and above."""
+        c2, c3, c4 = self.backbone(x)
+        fused = self.block13(self.block12(c3, c4))
+        fused = self.block15(self.block14(c2, fused))
+        head = self.block16(fused)
+        head = head[:, 7:, :, :]
+        return F.interpolate(head, scale_factor=2.0, mode='bilinear', align_corners=True)
\ No newline at end of file
diff --git a/ControlNet/annotator/mlsd/utils.py b/ControlNet/annotator/mlsd/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae3cf9420a33a4abae27c48ac4b90938c7d63cc3
--- /dev/null
+++ b/ControlNet/annotator/mlsd/utils.py
@@ -0,0 +1,580 @@
+modified by lihaoweicv
+pytorch version
+Copyright 2021-present NAVER Corp.
+Apache License v2.0
+import os
+import numpy as np
+import cv2
+import torch
+from torch.nn import functional as F
+def deccode_output_score_and_ptss(tpMap, topk_n = 200, ksize = 5):
+    '''
+    Pick the strongest local maxima of the center heat map.
+    tpMap:
+    center: tpMap[1, 0, :, :]
+    displacement: tpMap[1, 1:5, :, :]
+    topk_n: maximum number of points returned; clamped to the number of
+    positions in the map so that small maps do not make torch.topk raise.
+    ksize: window size of the max-pooling used for non-maximum suppression.
+    returns: (ptss, scores, displacement) as numpy arrays, where ptss holds
+    (y, x) rows, scores the matching sigmoid scores, and displacement is the
+    4-channel map transposed to (h, w, 4).
+    '''
+    b, c, h, w = tpMap.shape
+    assert b==1, 'only support bsize==1'
+    displacement = tpMap[:, 1:5, :, :][0]
+    center = tpMap[:, 0, :, :]
+    heat = torch.sigmoid(center)
+    # A position survives only if it equals the maximum of its ksize window.
+    hmax = F.max_pool2d(heat, (ksize, ksize), stride=1, padding=(ksize-1)//2)
+    keep = (hmax == heat).float()
+    heat = (heat * keep).reshape(-1, )
+    k = min(topk_n, heat.numel())
+    scores, indices = torch.topk(heat, k, dim=-1, largest=True)
+    # Flat index -> (row, column); torch.div with rounding_mode='floor' is the
+    # same form pred_squares uses and avoids torch.floor_divide.
+    yy = torch.div(indices, w, rounding_mode='floor').unsqueeze(-1)
+    xx = torch.fmod(indices, w).unsqueeze(-1)
+    ptss = torch.cat((yy, xx), dim=-1)
+    ptss = ptss.detach().cpu().numpy()
+    scores = scores.detach().cpu().numpy()
+    displacement = displacement.detach().cpu().numpy()
+    displacement = displacement.transpose((1, 2, 0))
+    return ptss, scores, displacement
+def pred_lines(image, model,
+               input_shape=[512, 512],
+               score_thr=0.10,
+               dist_thr=20.0):
+    '''
+    Detect line segments in `image` with `model`.
+    image: array whose shape unpacks as (h, w, _).
+    model: callable applied to a float CUDA tensor built from the resized
+    image plus a constant ones channel, scaled with x / 127.5 - 1.
+    input_shape: [height, width] the image is resized to before inference.
+    score_thr: a center candidate is kept only if its score exceeds this.
+    dist_thr: a candidate is kept only if its start/end displacement length
+    exceeds this.
+    returns: float array of shape (n, 4) with rows (x_start, y_start, x_end,
+    y_end) scaled to the original image size; shape (0, 4) when no candidate
+    passes the thresholds.
+    '''
+    h, w, _ = image.shape
+    h_ratio = h / input_shape[0]
+    w_ratio = w / input_shape[1]
+    # cv2.resize takes (width, height).
+    resized = cv2.resize(image, (input_shape[1], input_shape[0]), interpolation=cv2.INTER_AREA)
+    resized_image = np.concatenate([resized,
+                                    np.ones([input_shape[0], input_shape[1], 1])], axis=-1)
+    resized_image = resized_image.transpose((2, 0, 1))
+    batch_image = np.expand_dims(resized_image, axis=0).astype('float32')
+    batch_image = (batch_image / 127.5) - 1.0
+    batch_image = torch.from_numpy(batch_image).float().cuda()
+    outputs = model(batch_image)
+    pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3)
+    start = vmap[:, :, :2]
+    end = vmap[:, :, 2:]
+    dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1))
+    segments_list = []
+    for center, score in zip(pts, pts_score):
+        y, x = center
+        distance = dist_map[y, x]
+        if score > score_thr and distance > dist_thr:
+            disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :]
+            segments_list.append([x + disp_x_start, y + disp_y_start,
+                                  x + disp_x_end, y + disp_y_end])
+    if not segments_list:
+        # Indexing columns of an empty 1-D array would raise IndexError;
+        # hand back a well-formed empty result instead.
+        return np.zeros((0, 4))
+    lines = 2 * np.array(segments_list)  # 256 > 512
+    lines[:, 0] = lines[:, 0] * w_ratio
+    lines[:, 1] = lines[:, 1] * h_ratio
+    lines[:, 2] = lines[:, 2] * w_ratio
+    lines[:, 3] = lines[:, 3] * h_ratio
+    return lines
+def pred_squares(image,
+ model,
+ input_shape=[512, 512],
+ params={'score': 0.06,
+ 'outside_ratio': 0.28,
+ 'inside_ratio': 0.45,
+ 'w_overlap': 0.0,
+ 'w_degree': 1.95,
+ 'w_length': 0.0,
+ 'w_area': 1.86,
+ 'w_center': 0.14}):
+ '''
+ Detect line segments with `model` and combine them into scored four-corner candidates.
+ image: array whose shape unpacks as (h, w, _).
+ model: callable applied to a float CUDA tensor built from the resized image plus a constant ones channel.
+ input_shape: shape = [height, width]
+ params: thresholds ('score', 'outside_ratio', 'inside_ratio') and score weights
+ ('w_overlap', 'w_degree', 'w_length', 'w_area', 'w_center').
+ returns: (new_segments, squares, score_array, inter_points); the coordinate arrays are
+ rescaled with * 2 / input_shape * original_shape, and new_segments, squares and
+ inter_points are each replaced by [] when their rescaling step raises.
+ '''
+ # NOTE(review): `input_shape` and `params` are mutable default arguments; they are only read in this function.
+ h, w, _ = image.shape
+ original_shape = [h, w]
+ # NOTE(review): cv2.resize receives (input_shape[0], input_shape[1]) here, while pred_lines passes
+ # (input_shape[1], input_shape[0]); the two agree only for a square input_shape - confirm the intended order.
+ resized_image = np.concatenate([cv2.resize(image, (input_shape[0], input_shape[1]), interpolation=cv2.INTER_AREA),
+ np.ones([input_shape[0], input_shape[1], 1])], axis=-1)
+ resized_image = resized_image.transpose((2, 0, 1))
+ batch_image = np.expand_dims(resized_image, axis=0).astype('float32')
+ batch_image = (batch_image / 127.5) - 1.0
+ batch_image = torch.from_numpy(batch_image).float().cuda()
+ outputs = model(batch_image)
+ pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3)
+ start = vmap[:, :, :2] # (x, y)
+ end = vmap[:, :, 2:] # (x, y)
+ dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1))
+ junc_list = []
+ segments_list = []
+ # Keep candidates whose score exceeds params['score'] and whose start/end displacement length exceeds 20.0.
+ for junc, score in zip(pts, pts_score):
+ y, x = junc
+ distance = dist_map[y, x]
+ if score > params['score'] and distance > 20.0:
+ junc_list.append([x, y])
+ disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :]
+ d_arrow = 1.0
+ x_start = x + d_arrow * disp_x_start
+ y_start = y + d_arrow * disp_y_start
+ x_end = x + d_arrow * disp_x_end
+ y_end = y + d_arrow * disp_y_end
+ segments_list.append([x_start, y_start, x_end, y_end])
+ segments = np.array(segments_list)
+ ####### post processing for squares
+ # 1. get unique lines
+ # Each segment is described by (distance of its line from the origin, angle in degrees) and quantized,
+ # so segments with equal quantized values land in the same accumulator cell.
+ point = np.array([[0, 0]])
+ point = point[0]
+ start = segments[:, :2]
+ end = segments[:, 2:]
+ diff = start - end
+ a = diff[:, 1]
+ b = -diff[:, 0]
+ c = a * start[:, 0] + b * start[:, 1]
+ d = np.abs(a * point[0] + b * point[1] - c) / np.sqrt(a ** 2 + b ** 2 + 1e-10)
+ theta = np.arctan2(diff[:, 0], diff[:, 1]) * 180 / np.pi
+ theta[theta < 0.0] += 180
+ hough = np.concatenate([d[:, None], theta[:, None]], axis=-1)
+ d_quant = 1
+ theta_quant = 2
+ hough[:, 0] //= d_quant
+ hough[:, 1] //= theta_quant
+ _, indices, counts = np.unique(hough, axis=0, return_index=True, return_counts=True)
+ # NOTE(review): the accumulator extent is hard-coded to 512 (distance) and 360 (angle) rather than
+ # derived from input_shape - confirm this is sufficient for the inputs used.
+ acc_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='float32')
+ idx_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='int32') - 1
+ yx_indices = hough[indices, :].astype('int32')
+ acc_map[yx_indices[:, 0], yx_indices[:, 1]] = counts
+ idx_map[yx_indices[:, 0], yx_indices[:, 1]] = indices
+ acc_map_np = acc_map
+ # acc_map = acc_map[None, :, :, None]
+ #
+ # ### fast suppression using tensorflow op
+ # acc_map = tf.constant(acc_map, dtype=tf.float32)
+ # max_acc_map = tf.keras.layers.MaxPool2D(pool_size=(5, 5), strides=1, padding='same')(acc_map)
+ # acc_map = acc_map * tf.cast(tf.math.equal(acc_map, max_acc_map), tf.float32)
+ # flatten_acc_map = tf.reshape(acc_map, [1, -1])
+ # topk_values, topk_indices = tf.math.top_k(flatten_acc_map, k=len(pts))
+ # _, h, w, _ = acc_map.shape
+ # y = tf.expand_dims(topk_indices // w, axis=-1)
+ # x = tf.expand_dims(topk_indices % w, axis=-1)
+ # yx = tf.concat([y, x], axis=-1)
+ ### fast suppression using pytorch op
+ acc_map = torch.from_numpy(acc_map_np).unsqueeze(0).unsqueeze(0)
+ _,_, h, w = acc_map.shape
+ max_acc_map = F.max_pool2d(acc_map,kernel_size=5, stride=1, padding=2)
+ acc_map = acc_map * ( (acc_map == max_acc_map).float() )
+ flatten_acc_map = acc_map.reshape([-1, ])
+ scores, indices = torch.topk(flatten_acc_map, len(pts), dim=-1, largest=True)
+ yy = torch.div(indices, w, rounding_mode='floor').unsqueeze(-1)
+ xx = torch.fmod(indices, w).unsqueeze(-1)
+ yx = torch.cat((yy, xx), dim=-1)
+ yx = yx.detach().cpu().numpy()
+ topk_values = scores.detach().cpu().numpy()
+ indices = idx_map[yx[:, 0], yx[:, 1]]
+ # Half-width of the 5x5 neighbourhood gathered around each accumulator peak.
+ basis = 5 // 2
+ merged_segments = []
+ for yx_pt, max_indice, value in zip(yx, indices, topk_values):
+ y, x = yx_pt
+ if max_indice == -1 or value == 0:
+ continue
+ segment_list = []
+ for y_offset in range(-basis, basis + 1):
+ for x_offset in range(-basis, basis + 1):
+ indice = idx_map[y + y_offset, x + x_offset]
+ cnt = int(acc_map_np[y + y_offset, x + x_offset])
+ if indice != -1:
+ segment_list.append(segments[indice])
+ if cnt > 1:
+ check_cnt = 1
+ current_hough = hough[indice]
+ for new_indice, new_hough in enumerate(hough):
+ if (current_hough == new_hough).all() and indice != new_indice:
+ segment_list.append(segments[new_indice])
+ check_cnt += 1
+ if check_cnt == cnt:
+ break
+ group_segments = np.array(segment_list).reshape([-1, 2])
+ sorted_group_segments = np.sort(group_segments, axis=0)
+ x_min, y_min = sorted_group_segments[0, :]
+ x_max, y_max = sorted_group_segments[-1, :]
+ deg = theta[max_indice]
+ if deg >= 90:
+ merged_segments.append([x_min, y_max, x_max, y_min])
+ else:
+ merged_segments.append([x_min, y_min, x_max, y_max])
+ # 2. get intersections
+ new_segments = np.array(merged_segments) # (x1, y1, x2, y2)
+ start = new_segments[:, :2] # (x1, y1)
+ end = new_segments[:, 2:] # (x2, y2)
+ new_centers = (start + end) / 2.0
+ diff = start - end
+ dist_segments = np.sqrt(np.sum(diff ** 2, axis=-1))
+ # ax + by = c
+ a = diff[:, 1]
+ b = -diff[:, 0]
+ c = a * start[:, 0] + b * start[:, 1]
+ # Pairwise determinants; 1e-10 is added below so parallel pairs do not divide by zero.
+ pre_det = a[:, None] * b[None, :]
+ det = pre_det - np.transpose(pre_det)
+ pre_inter_y = a[:, None] * c[None, :]
+ inter_y = (pre_inter_y - np.transpose(pre_inter_y)) / (det + 1e-10)
+ pre_inter_x = c[:, None] * b[None, :]
+ inter_x = (pre_inter_x - np.transpose(pre_inter_x)) / (det + 1e-10)
+ inter_pts = np.concatenate([inter_x[:, :, None], inter_y[:, :, None]], axis=-1).astype('int32')
+ # 3. get corner information
+ # 3.1 get distance
+ '''
+ dist_segments:
+ | dist(0), dist(1), dist(2), ...|
+ dist_inter_to_segment1:
+ | dist(inter,0), dist(inter,0), dist(inter,0), ... |
+ | dist(inter,1), dist(inter,1), dist(inter,1), ... |
+ ...
+ dist_inter_to_semgnet2:
+ | dist(inter,0), dist(inter,1), dist(inter,2), ... |
+ | dist(inter,0), dist(inter,1), dist(inter,2), ... |
+ ...
+ '''
+ dist_inter_to_segment1_start = np.sqrt(
+ np.sum(((inter_pts - start[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
+ dist_inter_to_segment1_end = np.sqrt(
+ np.sum(((inter_pts - end[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
+ dist_inter_to_segment2_start = np.sqrt(
+ np.sum(((inter_pts - start[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
+ dist_inter_to_segment2_end = np.sqrt(
+ np.sum(((inter_pts - end[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
+ # sort ascending
+ dist_inter_to_segment1 = np.sort(
+ np.concatenate([dist_inter_to_segment1_start, dist_inter_to_segment1_end], axis=-1),
+ axis=-1) # [n_batch, n_batch, 2]
+ dist_inter_to_segment2 = np.sort(
+ np.concatenate([dist_inter_to_segment2_start, dist_inter_to_segment2_end], axis=-1),
+ axis=-1) # [n_batch, n_batch, 2]
+ # 3.2 get degree
+ inter_to_start = new_centers[:, None, :] - inter_pts
+ deg_inter_to_start = np.arctan2(inter_to_start[:, :, 1], inter_to_start[:, :, 0]) * 180 / np.pi
+ deg_inter_to_start[deg_inter_to_start < 0.0] += 360
+ inter_to_end = new_centers[None, :, :] - inter_pts
+ deg_inter_to_end = np.arctan2(inter_to_end[:, :, 1], inter_to_end[:, :, 0]) * 180 / np.pi
+ deg_inter_to_end[deg_inter_to_end < 0.0] += 360
+ '''
+ B -- G
+ | |
+ C -- R
+ B : blue / G: green / C: cyan / R: red
+ 0 -- 1
+ | |
+ 3 -- 2
+ '''
+ # rename variables
+ deg1_map, deg2_map = deg_inter_to_start, deg_inter_to_end
+ # sort deg ascending
+ deg_sort = np.sort(np.concatenate([deg1_map[:, :, None], deg2_map[:, :, None]], axis=-1), axis=-1)
+ deg_diff_map = np.abs(deg1_map - deg2_map)
+ # we only consider the smallest degree of intersect
+ deg_diff_map[deg_diff_map > 180] = 360 - deg_diff_map[deg_diff_map > 180]
+ # define available degree range
+ deg_range = [60, 120]
+ corner_dict = {corner_info: [] for corner_info in range(4)}
+ inter_points = []
+ for i in range(inter_pts.shape[0]):
+ for j in range(i + 1, inter_pts.shape[1]):
+ # i, j > line index, always i < j
+ x, y = inter_pts[i, j, :]
+ deg1, deg2 = deg_sort[i, j, :]
+ deg_diff = deg_diff_map[i, j]
+ check_degree = deg_diff > deg_range[0] and deg_diff < deg_range[1]
+ outside_ratio = params['outside_ratio'] # over ratio >>> drop it!
+ inside_ratio = params['inside_ratio'] # over ratio >>> drop it!
+ check_distance = ((dist_inter_to_segment1[i, j, 1] >= dist_segments[i] and \
+ dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * outside_ratio) or \
+ (dist_inter_to_segment1[i, j, 1] <= dist_segments[i] and \
+ dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * inside_ratio)) and \
+ ((dist_inter_to_segment2[i, j, 1] >= dist_segments[j] and \
+ dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * outside_ratio) or \
+ (dist_inter_to_segment2[i, j, 1] <= dist_segments[j] and \
+ dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * inside_ratio))
+ if check_degree and check_distance:
+ corner_info = None
+ if (deg1 >= 0 and deg1 <= 45 and deg2 >= 45 and deg2 <= 120) or \
+ (deg2 >= 315 and deg1 >= 45 and deg1 <= 120):
+ corner_info, color_info = 0, 'blue'
+ elif (deg1 >= 45 and deg1 <= 125 and deg2 >= 125 and deg2 <= 225):
+ corner_info, color_info = 1, 'green'
+ elif (deg1 >= 125 and deg1 <= 225 and deg2 >= 225 and deg2 <= 315):
+ corner_info, color_info = 2, 'black'
+ elif (deg1 >= 0 and deg1 <= 45 and deg2 >= 225 and deg2 <= 315) or \
+ (deg2 >= 315 and deg1 >= 225 and deg1 <= 315):
+ corner_info, color_info = 3, 'cyan'
+ else:
+ corner_info, color_info = 4, 'red' # we don't use it
+ continue
+ corner_dict[corner_info].append([x, y, i, j])
+ inter_points.append([x, y])
+ square_list = []
+ connect_list = []
+ segments_list = []
+ # Chain corners 0 -> 1 -> 2 -> 3 -> 0: consecutive corners must share a line index (entries [2:]).
+ for corner0 in corner_dict[0]:
+ for corner1 in corner_dict[1]:
+ connect01 = False
+ for corner0_line in corner0[2:]:
+ if corner0_line in corner1[2:]:
+ connect01 = True
+ break
+ if connect01:
+ for corner2 in corner_dict[2]:
+ connect12 = False
+ for corner1_line in corner1[2:]:
+ if corner1_line in corner2[2:]:
+ connect12 = True
+ break
+ if connect12:
+ for corner3 in corner_dict[3]:
+ connect23 = False
+ for corner2_line in corner2[2:]:
+ if corner2_line in corner3[2:]:
+ connect23 = True
+ break
+ if connect23:
+ for corner3_line in corner3[2:]:
+ if corner3_line in corner0[2:]:
+ # SQUARE!!!
+ '''
+ 0 -- 1
+ | |
+ 3 -- 2
+ square_list:
+ order: 0 > 1 > 2 > 3
+ | x0, y0, x1, y1, x2, y2, x3, y3 |
+ | x0, y0, x1, y1, x2, y2, x3, y3 |
+ ...
+ connect_list:
+ order: 01 > 12 > 23 > 30
+ | line_idx01, line_idx12, line_idx23, line_idx30 |
+ | line_idx01, line_idx12, line_idx23, line_idx30 |
+ ...
+ segments_list:
+ order: 0 > 1 > 2 > 3
+ | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j |
+ | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j |
+ ...
+ '''
+ square_list.append(corner0[:2] + corner1[:2] + corner2[:2] + corner3[:2])
+ connect_list.append([corner0_line, corner1_line, corner2_line, corner3_line])
+ segments_list.append(corner0[2:] + corner1[2:] + corner2[2:] + corner3[2:])
+ # Helper: for the corner formed by lines (i, j), report whether the corner lies beyond the connecting
+ # segment ('outside') or within it ('inside'), with the sign factors used for cover and perimeter.
+ def check_outside_inside(segments_info, connect_idx):
+ # return 'outside or inside', min distance, cover_param, peri_param
+ if connect_idx == segments_info[0]:
+ check_dist_mat = dist_inter_to_segment1
+ else:
+ check_dist_mat = dist_inter_to_segment2
+ i, j = segments_info
+ min_dist, max_dist = check_dist_mat[i, j, :]
+ connect_dist = dist_segments[connect_idx]
+ if max_dist > connect_dist:
+ return 'outside', min_dist, 0, 1
+ else:
+ return 'inside', min_dist, -1, -1
+ top_square = None
+ # Scoring is best-effort: any exception raised in the block below is swallowed by its handler.
+ try:
+ map_size = input_shape[0] / 2
+ squares = np.array(square_list).reshape([-1, 4, 2])
+ score_array = []
+ connect_array = np.array(connect_list)
+ segments_array = np.array(segments_list).reshape([-1, 4, 2])
+ # get degree of corners:
+ squares_rollup = np.roll(squares, 1, axis=1)
+ squares_rolldown = np.roll(squares, -1, axis=1)
+ vec1 = squares_rollup - squares
+ normalized_vec1 = vec1 / (np.linalg.norm(vec1, axis=-1, keepdims=True) + 1e-10)
+ vec2 = squares_rolldown - squares
+ normalized_vec2 = vec2 / (np.linalg.norm(vec2, axis=-1, keepdims=True) + 1e-10)
+ inner_products = np.sum(normalized_vec1 * normalized_vec2, axis=-1) # [n_squares, 4]
+ squares_degree = np.arccos(inner_products) * 180 / np.pi # [n_squares, 4]
+ # get square score
+ overlap_scores = []
+ degree_scores = []
+ length_scores = []
+ for connects, segments, square, degree in zip(connect_array, segments_array, squares, squares_degree):
+ '''
+ 0 -- 1
+ | |
+ 3 -- 2
+ # segments: [4, 2]
+ # connects: [4]
+ '''
+ ###################################### OVERLAP SCORES
+ cover = 0
+ perimeter = 0
+ # check 0 > 1 > 2 > 3
+ square_length = []
+ for start_idx in range(4):
+ end_idx = (start_idx + 1) % 4
+ connect_idx = connects[start_idx] # segment idx of segment01
+ start_segments = segments[start_idx]
+ end_segments = segments[end_idx]
+ start_point = square[start_idx]
+ end_point = square[end_idx]
+ # check whether outside or inside
+ start_position, start_min, start_cover_param, start_peri_param = check_outside_inside(start_segments,
+ connect_idx)
+ end_position, end_min, end_cover_param, end_peri_param = check_outside_inside(end_segments, connect_idx)
+ cover += dist_segments[connect_idx] + start_cover_param * start_min + end_cover_param * end_min
+ perimeter += dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min
+ square_length.append(
+ dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min)
+ overlap_scores.append(cover / perimeter)
+ ######################################
+ ###################################### DEGREE SCORES
+ '''
+ deg0 vs deg2
+ deg1 vs deg3
+ '''
+ deg0, deg1, deg2, deg3 = degree
+ deg_ratio1 = deg0 / deg2
+ if deg_ratio1 > 1.0:
+ deg_ratio1 = 1 / deg_ratio1
+ deg_ratio2 = deg1 / deg3
+ if deg_ratio2 > 1.0:
+ deg_ratio2 = 1 / deg_ratio2
+ degree_scores.append((deg_ratio1 + deg_ratio2) / 2)
+ ######################################
+ ###################################### LENGTH SCORES
+ '''
+ len0 vs len2
+ len1 vs len3
+ '''
+ len0, len1, len2, len3 = square_length
+ len_ratio1 = len0 / len2 if len2 > len0 else len2 / len0
+ len_ratio2 = len1 / len3 if len3 > len1 else len3 / len1
+ length_scores.append((len_ratio1 + len_ratio2) / 2)
+ ######################################
+ overlap_scores = np.array(overlap_scores)
+ overlap_scores /= np.max(overlap_scores)
+ degree_scores = np.array(degree_scores)
+ # degree_scores /= np.max(degree_scores)
+ length_scores = np.array(length_scores)
+ ###################################### AREA SCORES
+ area_scores = np.reshape(squares, [-1, 4, 2])
+ area_x = area_scores[:, :, 0]
+ area_y = area_scores[:, :, 1]
+ correction = area_x[:, -1] * area_y[:, 0] - area_y[:, -1] * area_x[:, 0]
+ area_scores = np.sum(area_x[:, :-1] * area_y[:, 1:], axis=-1) - np.sum(area_y[:, :-1] * area_x[:, 1:], axis=-1)
+ area_scores = 0.5 * np.abs(area_scores + correction)
+ area_scores /= (map_size * map_size) # np.max(area_scores)
+ ######################################
+ ###################################### CENTER SCORES
+ # NOTE(review): the reference center hard-codes 256 rather than using map_size - confirm for non-512 inputs.
+ centers = np.array([[256 // 2, 256 // 2]], dtype='float32') # [1, 2]
+ # squares: [n, 4, 2]
+ square_centers = np.mean(squares, axis=1) # [n, 2]
+ # NOTE(review): np.sum has no axis argument here, so this is a single scalar over all squares rather
+ # than one distance per square - confirm whether axis=-1 was intended.
+ center2center = np.sqrt(np.sum((centers - square_centers) ** 2))
+ center_scores = center2center / (map_size / np.sqrt(2.0))
+ '''
+ score_w = [overlap, degree, area, center, length]
+ '''
+ # NOTE(review): score_w is not referenced again in this function; the weights used below come from params.
+ score_w = [0.0, 1.0, 10.0, 0.5, 1.0]
+ score_array = params['w_overlap'] * overlap_scores \
+ + params['w_degree'] * degree_scores \
+ + params['w_area'] * area_scores \
+ - params['w_center'] * center_scores \
+ + params['w_length'] * length_scores
+ best_square = []
+ # Sort squares and their scores from highest to lowest score.
+ sorted_idx = np.argsort(score_array)[::-1]
+ score_array = score_array[sorted_idx]
+ squares = squares[sorted_idx]
+ except Exception as e:
+ pass
+ '''return list
+ merged_lines, squares, scores
+ '''
+ # Rescale coordinates: x by 2 / input_shape[1] * original_shape[1], y by 2 / input_shape[0] * original_shape[0].
+ try:
+ new_segments[:, 0] = new_segments[:, 0] * 2 / input_shape[1] * original_shape[1]
+ new_segments[:, 1] = new_segments[:, 1] * 2 / input_shape[0] * original_shape[0]
+ new_segments[:, 2] = new_segments[:, 2] * 2 / input_shape[1] * original_shape[1]
+ new_segments[:, 3] = new_segments[:, 3] * 2 / input_shape[0] * original_shape[0]
+ except:
+ new_segments = []
+ try:
+ squares[:, :, 0] = squares[:, :, 0] * 2 / input_shape[1] * original_shape[1]
+ squares[:, :, 1] = squares[:, :, 1] * 2 / input_shape[0] * original_shape[0]
+ except:
+ squares = []
+ score_array = []
+ try:
+ inter_points = np.array(inter_points)
+ inter_points[:, 0] = inter_points[:, 0] * 2 / input_shape[1] * original_shape[1]
+ inter_points[:, 1] = inter_points[:, 1] * 2 / input_shape[0] * original_shape[0]
+ except:
+ inter_points = []
+ return new_segments, squares, score_array, inter_points
diff --git a/ControlNet/annotator/openpose/__init__.py b/ControlNet/annotator/openpose/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c26f1b37dae854f51da938da2fa67a8ef48ce5a
--- /dev/null
+++ b/ControlNet/annotator/openpose/__init__.py
@@ -0,0 +1,44 @@
+import os
+import torch
+import numpy as np
+from . import util
+from .body import Body
+from .hand import Hand
+from annotator.util import annotator_ckpts_path
+# Remote locations of the two OpenPose checkpoints; OpenposeDetector.__init__ passes
+# them to load_file_from_url as download sources.
+body_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/body_pose_model.pth"
+hand_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/hand_pose_model.pth"
+class OpenposeDetector:
+    """Callable annotator that draws detected body (and optionally hand) poses.
+
+    Construction loads the body and hand estimators from ``annotator_ckpts_path``,
+    downloading whichever checkpoint file is not present yet.
+    """
+
+    def __init__(self):
+        """Resolve both checkpoint paths, fetch missing files, build the estimators."""
+        body_modelpath = os.path.join(annotator_ckpts_path, "body_pose_model.pth")
+        hand_modelpath = os.path.join(annotator_ckpts_path, "hand_pose_model.pth")
+
+        # Check every checkpoint on its own: looking only at the hand file would
+        # leave a missing body checkpoint undownloaded whenever the hand file exists.
+        missing_urls = [
+            url
+            for url, local_path in ((body_model_path, body_modelpath),
+                                    (hand_model_path, hand_modelpath))
+            if not os.path.exists(local_path)
+        ]
+        if missing_urls:
+            # Imported lazily so basicsr is only needed when a download happens.
+            from basicsr.utils.download_util import load_file_from_url
+            for url in missing_urls:
+                load_file_from_url(url, model_dir=annotator_ckpts_path)
+
+        self.body_estimation = Body(body_modelpath)
+        self.hand_estimation = Hand(hand_modelpath)
+
+    def __call__(self, oriImg, hand=False):
+        """Detect poses in ``oriImg`` and render them on a black canvas.
+
+        Args:
+            oriImg: image array indexed as [row, column, channel]; the channel
+                axis is reversed before detection (presumably RGB -> BGR; confirm
+                against the caller).
+            hand: when true, hand keypoints are estimated for every hand box
+                proposed by ``util.handDetect`` and drawn as well.
+
+        Returns:
+            ``(canvas, info)`` where ``canvas`` has the shape and dtype of the
+            channel-reversed input and ``info`` is a dict holding ``candidate``
+            and ``subset`` converted to nested lists.
+        """
+        oriImg = oriImg[:, :, ::-1].copy()
+        with torch.no_grad():
+            candidate, subset = self.body_estimation(oriImg)
+            canvas = np.zeros_like(oriImg)
+            canvas = util.draw_bodypose(canvas, candidate, subset)
+            if hand:
+                hands_list = util.handDetect(candidate, subset, oriImg)
+                all_hand_peaks = []
+                for x, y, w, is_left in hands_list:
+                    peaks = self.hand_estimation(oriImg[y:y+w, x:x+w, :])
+                    # Peaks are relative to the crop; shift them back into image
+                    # coordinates, leaving 0 (the "not found" marker) untouched.
+                    peaks[:, 0] = np.where(peaks[:, 0] == 0, peaks[:, 0], peaks[:, 0] + x)
+                    peaks[:, 1] = np.where(peaks[:, 1] == 0, peaks[:, 1], peaks[:, 1] + y)
+                    all_hand_peaks.append(peaks)
+                canvas = util.draw_handpose(canvas, all_hand_peaks)
+            return canvas, dict(candidate=candidate.tolist(), subset=subset.tolist())
diff --git a/ControlNet/annotator/openpose/body.py b/ControlNet/annotator/openpose/body.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c3cf7a388b4ac81004524e64125e383bdd455bd
--- /dev/null
+++ b/ControlNet/annotator/openpose/body.py
@@ -0,0 +1,219 @@
+import cv2
+import numpy as np
+import math
+import time
+from scipy.ndimage.filters import gaussian_filter
+import matplotlib.pyplot as plt
+import matplotlib
+import torch
+from torchvision import transforms
+from . import util
+from .model import bodypose_model
+class Body(object):
+    """Body pose estimator wrapping ``bodypose_model``.
+
+    Calling an instance on an image returns the detected keypoint candidates and
+    their grouping into people.
+    """
+
+    def __init__(self, model_path):
+        """Build the network, move it to CUDA when available and load the weights.
+
+        Args:
+            model_path: path of the checkpoint passed to ``torch.load``.
+        """
+        self.model = bodypose_model()
+        if torch.cuda.is_available():
+            self.model = self.model.cuda()
+            print('cuda')
+        model_dict = util.transfer(self.model, torch.load(model_path))
+        self.model.load_state_dict(model_dict)
+        self.model.eval()
+
+    def __call__(self, oriImg):
+        """Run pose estimation on one image.
+
+        Args:
+            oriImg: image array indexed as [row, column, channel].
+
+        Returns:
+            ``(candidate, subset)``: ``candidate`` rows are ``x, y, score, id``;
+            ``subset`` is an ``n x 20`` array whose columns 0-17 index into
+            ``candidate`` (-1 when the part is absent), column 18 is the total
+            score and column 19 the number of parts of that person.
+        """
+        # scale_search = [0.5, 1.0, 1.5, 2.0]
+        scale_search = [0.5]
+        boxsize = 368
+        stride = 8
+        padValue = 128
+        thre1 = 0.1
+        thre2 = 0.05
+        multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
+        heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19))
+        paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))
+
+        for scale in multiplier:
+            imageToTest = cv2.resize(oriImg, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
+            imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue)
+            # HWC -> 1CHW, scaled to roughly [-0.5, 0.5)
+            im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
+            im = np.ascontiguousarray(im)
+
+            data = torch.from_numpy(im).float()
+            if torch.cuda.is_available():
+                data = data.cuda()
+            # data = data.permute([2, 0, 1]).unsqueeze(0).float()
+            with torch.no_grad():
+                Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data)
+            Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy()
+            Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy()
+
+            # extract outputs, resize, and remove padding
+            heatmap = np.transpose(np.squeeze(Mconv7_stage6_L2), (1, 2, 0))  # output 1 is heatmaps
+            heatmap = cv2.resize(heatmap, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC)
+            heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
+            heatmap = cv2.resize(heatmap, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC)
+
+            paf = np.transpose(np.squeeze(Mconv7_stage6_L1), (1, 2, 0))  # output 0 is PAFs
+            paf = cv2.resize(paf, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC)
+            paf = paf[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
+            paf = cv2.resize(paf, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC)
+
+            # Plain running average over scales. Adding heatmap_avg to itself here
+            # would double the accumulated maps on every scale after the first.
+            heatmap_avg += heatmap / len(multiplier)
+            paf_avg += paf / len(multiplier)
+
+        all_peaks = []
+        peak_counter = 0
+
+        for part in range(18):
+            map_ori = heatmap_avg[:, :, part]
+            one_heatmap = gaussian_filter(map_ori, sigma=3)
+
+            # Shifted copies of the smoothed map, used for a 4-neighbour
+            # local-maximum test.
+            map_left = np.zeros(one_heatmap.shape)
+            map_left[1:, :] = one_heatmap[:-1, :]
+            map_right = np.zeros(one_heatmap.shape)
+            map_right[:-1, :] = one_heatmap[1:, :]
+            map_up = np.zeros(one_heatmap.shape)
+            map_up[:, 1:] = one_heatmap[:, :-1]
+            map_down = np.zeros(one_heatmap.shape)
+            map_down[:, :-1] = one_heatmap[:, 1:]
+
+            peaks_binary = np.logical_and.reduce(
+                (one_heatmap >= map_left, one_heatmap >= map_right, one_heatmap >= map_up, one_heatmap >= map_down, one_heatmap > thre1))
+            peaks = list(zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0]))  # note reverse
+            peaks_with_score = [x + (map_ori[x[1], x[0]],) for x in peaks]
+            peak_id = range(peak_counter, peak_counter + len(peaks))
+            peaks_with_score_and_id = [peaks_with_score[i] + (peak_id[i],) for i in range(len(peak_id))]
+
+            all_peaks.append(peaks_with_score_and_id)
+            peak_counter += len(peaks)
+
+        # find connection in the specified sequence, center 29 is in the position 15
+        limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
+                   [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
+                   [1, 16], [16, 18], [3, 17], [6, 18]]
+        # the middle joints heatmap correspondence
+        mapIdx = [[31, 32], [39, 40], [33, 34], [35, 36], [41, 42], [43, 44], [19, 20], [21, 22], \
+                  [23, 24], [25, 26], [27, 28], [29, 30], [47, 48], [49, 50], [53, 54], [51, 52], \
+                  [55, 56], [37, 38], [45, 46]]
+
+        connection_all = []
+        special_k = []
+        mid_num = 10
+
+        for k in range(len(mapIdx)):
+            score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]]
+            candA = all_peaks[limbSeq[k][0] - 1]
+            candB = all_peaks[limbSeq[k][1] - 1]
+            nA = len(candA)
+            nB = len(candB)
+            indexA, indexB = limbSeq[k]
+            if (nA != 0 and nB != 0):
+                connection_candidate = []
+                for i in range(nA):
+                    for j in range(nB):
+                        vec = np.subtract(candB[j][:2], candA[i][:2])
+                        norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1])
+                        norm = max(0.001, norm)  # avoid division by zero for coincident peaks
+                        vec = np.divide(vec, norm)
+
+                        # mid_num evenly spaced sample points on the segment A -> B
+                        startend = list(zip(np.linspace(candA[i][0], candB[j][0], num=mid_num), \
+                                            np.linspace(candA[i][1], candB[j][1], num=mid_num)))
+
+                        vec_x = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 0] \
+                                          for I in range(len(startend))])
+                        vec_y = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 1] \
+                                          for I in range(len(startend))])
+
+                        score_midpts = np.multiply(vec_x, vec[0]) + np.multiply(vec_y, vec[1])
+                        score_with_dist_prior = sum(score_midpts) / len(score_midpts) + min(
+                            0.5 * oriImg.shape[0] / norm - 1, 0)
+                        criterion1 = len(np.nonzero(score_midpts > thre2)[0]) > 0.8 * len(score_midpts)
+                        criterion2 = score_with_dist_prior > 0
+                        if criterion1 and criterion2:
+                            connection_candidate.append(
+                                [i, j, score_with_dist_prior, score_with_dist_prior + candA[i][2] + candB[j][2]])
+
+                # Greedy matching: best-scoring candidates first, each endpoint used once.
+                connection_candidate = sorted(connection_candidate, key=lambda x: x[2], reverse=True)
+                connection = np.zeros((0, 5))
+                for c in range(len(connection_candidate)):
+                    i, j, s = connection_candidate[c][0:3]
+                    if (i not in connection[:, 3] and j not in connection[:, 4]):
+                        connection = np.vstack([connection, [candA[i][3], candB[j][3], s, i, j]])
+                        if (len(connection) >= min(nA, nB)):
+                            break
+
+                connection_all.append(connection)
+            else:
+                special_k.append(k)
+                connection_all.append([])
+
+        # last number in each row is the total parts number of that person
+        # the second last number in each row is the score of the overall configuration
+        subset = -1 * np.ones((0, 20))
+        candidate = np.array([item for sublist in all_peaks for item in sublist])
+
+        for k in range(len(mapIdx)):
+            if k not in special_k:
+                partAs = connection_all[k][:, 0]
+                partBs = connection_all[k][:, 1]
+                indexA, indexB = np.array(limbSeq[k]) - 1
+
+                for i in range(len(connection_all[k])):  # = 1:size(temp,1)
+                    found = 0
+                    subset_idx = [-1, -1]
+                    for j in range(len(subset)):  # 1:size(subset,1):
+                        if subset[j][indexA] == partAs[i] or subset[j][indexB] == partBs[i]:
+                            subset_idx[found] = j
+                            found += 1
+
+                    if found == 1:
+                        j = subset_idx[0]
+                        if subset[j][indexB] != partBs[i]:
+                            subset[j][indexB] = partBs[i]
+                            subset[j][-1] += 1
+                            subset[j][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
+                    elif found == 2:  # if found 2 and disjoint, merge them
+                        j1, j2 = subset_idx
+                        membership = ((subset[j1] >= 0).astype(int) + (subset[j2] >= 0).astype(int))[:-2]
+                        if len(np.nonzero(membership == 2)[0]) == 0:  # merge
+                            subset[j1][:-2] += (subset[j2][:-2] + 1)
+                            subset[j1][-2:] += subset[j2][-2:]
+                            subset[j1][-2] += connection_all[k][i][2]
+                            subset = np.delete(subset, j2, 0)
+                        else:  # as like found == 1
+                            subset[j1][indexB] = partBs[i]
+                            subset[j1][-1] += 1
+                            subset[j1][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
+
+                    # if find no partA in the subset, create a new subset
+                    elif not found and k < 17:
+                        row = -1 * np.ones(20)
+                        row[indexA] = partAs[i]
+                        row[indexB] = partBs[i]
+                        row[-1] = 2
+                        row[-2] = sum(candidate[connection_all[k][i, :2].astype(int), 2]) + connection_all[k][i][2]
+                        subset = np.vstack([subset, row])
+
+        # delete some rows of subset which has few parts occur
+        deleteIdx = []
+        for i in range(len(subset)):
+            if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4:
+                deleteIdx.append(i)
+        subset = np.delete(subset, deleteIdx, axis=0)
+
+        # subset: n*20 array, 0-17 is the index in candidate, 18 is the total score, 19 is the total parts
+        # candidate: x, y, score, id
+        return candidate, subset
+if __name__ == "__main__":
+    # Manual smoke test: estimate the pose on a sample image and show the overlay.
+    body_estimation = Body('../model/body_pose_model.pth')
+
+    test_image = '../images/ski.jpg'
+    oriImg = cv2.imread(test_image)  # B,G,R order
+    candidate, subset = body_estimation(oriImg)
+    canvas = util.draw_bodypose(oriImg, candidate, subset)
+    # Reorder channels B,G,R -> R,G,B for matplotlib before displaying.
+    rgb_canvas = canvas[:, :, [2, 1, 0]]
+    plt.imshow(rgb_canvas)
+    plt.show()
diff --git a/ControlNet/annotator/openpose/hand.py b/ControlNet/annotator/openpose/hand.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d0bf17165ad7eb225332b51f4a2aa16718664b2
--- /dev/null
+++ b/ControlNet/annotator/openpose/hand.py
@@ -0,0 +1,86 @@
+import cv2
+import json
+import numpy as np
+import math
+import time
+from scipy.ndimage.filters import gaussian_filter
+import matplotlib.pyplot as plt
+import matplotlib
+import torch
+from skimage.measure import label
+from .model import handpose_model
+from . import util
+class Hand(object):
+    """Hand keypoint estimator wrapping ``handpose_model``."""
+
+    def __init__(self, model_path):
+        """Create the network, move it to CUDA when available and load weights.
+
+        Args:
+            model_path: checkpoint path handed to ``torch.load``.
+        """
+        network = handpose_model()
+        if torch.cuda.is_available():
+            network = network.cuda()
+            print('cuda')
+        self.model = network
+        weights = util.transfer(self.model, torch.load(model_path))
+        self.model.load_state_dict(weights)
+        self.model.eval()
+
+    def __call__(self, oriImg):
+        """Estimate 21 hand keypoints on a hand crop.
+
+        Args:
+            oriImg: image array indexed as [row, column, channel].
+
+        Returns:
+            Array with one ``[x, y]`` row per keypoint (21 rows); a keypoint whose
+            smoothed heatmap never exceeds the threshold is reported as ``[0, 0]``.
+        """
+        scale_search = [0.5, 1.0, 1.5, 2.0]
+        boxsize = 368
+        stride = 8
+        padValue = 128
+        thre = 0.05
+
+        img_h, img_w = oriImg.shape[0], oriImg.shape[1]
+        multiplier = [x * boxsize / img_h for x in scale_search]
+        num_scales = len(multiplier)
+        heatmap_avg = np.zeros((img_h, img_w, 22))
+
+        for scale in multiplier:
+            resized = cv2.resize(oriImg, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
+            padded, pad = util.padRightDownCorner(resized, stride, padValue)
+            # HWC -> 1CHW, scaled to roughly [-0.5, 0.5)
+            im = np.transpose(np.float32(padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
+            im = np.ascontiguousarray(im)
+
+            data = torch.from_numpy(im).float()
+            if torch.cuda.is_available():
+                data = data.cuda()
+            with torch.no_grad():
+                output = self.model(data).cpu().numpy()
+
+            # Bring the network output back to the input resolution: upsample by
+            # the stride, crop the padding away, then resize to the original size.
+            heatmap = np.transpose(np.squeeze(output), (1, 2, 0))
+            heatmap = cv2.resize(heatmap, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC)
+            heatmap = heatmap[:padded.shape[0] - pad[2], :padded.shape[1] - pad[3], :]
+            heatmap = cv2.resize(heatmap, (img_w, img_h), interpolation=cv2.INTER_CUBIC)
+
+            heatmap_avg += heatmap / num_scales
+
+        all_peaks = []
+        for part in range(21):
+            map_ori = heatmap_avg[:, :, part]
+            one_heatmap = gaussian_filter(map_ori, sigma=3)
+            binary = np.ascontiguousarray(one_heatmap > thre, dtype=np.uint8)
+            if np.sum(binary) == 0:
+                # every value is below the threshold
+                all_peaks.append([0, 0])
+            else:
+                label_img, label_numbers = label(binary, return_num=True, connectivity=binary.ndim)
+                # Keep only the connected region with the largest summed response.
+                region_mass = [np.sum(map_ori[label_img == i]) for i in range(1, label_numbers + 1)]
+                max_index = np.argmax(region_mass) + 1
+                map_ori[label_img != max_index] = 0
+                y, x = util.npmax(map_ori)
+                all_peaks.append([x, y])
+        return np.array(all_peaks)
+if __name__ == "__main__":
+    # Manual smoke test: estimate hand keypoints on a sample image and display them.
+    hand_estimation = Hand('../model/hand_pose_model.pth')
+
+    test_image = '../images/hand.jpg'
+    oriImg = cv2.imread(test_image)  # B,G,R order
+    peaks = hand_estimation(oriImg)
+    # draw_handpose iterates over a list of per-hand peak arrays, so the single
+    # (21, 2) result has to be wrapped in a list.
+    canvas = util.draw_handpose(oriImg, [peaks], True)
+    cv2.imshow('', canvas)
+    cv2.waitKey(0)
\ No newline at end of file
diff --git a/ControlNet/annotator/openpose/model.py b/ControlNet/annotator/openpose/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..5dfc80de827a17beccb9b0f3f7588545be78c9de
--- /dev/null
+++ b/ControlNet/annotator/openpose/model.py
@@ -0,0 +1,219 @@
+import torch
+from collections import OrderedDict
+import torch
+import torch.nn as nn
+def make_layers(block, no_relu_layers):
+    """Turn an ordered layer specification into an ``nn.Sequential``.
+
+    Args:
+        block: mapping of layer name -> parameter list. Names containing
+            ``'pool'`` use ``[kernel_size, stride, padding]``; every other name
+            uses ``[in_channels, out_channels, kernel_size, stride, padding]``.
+        no_relu_layers: names of convolution layers that must NOT be followed
+            by a ReLU.
+
+    Returns:
+        ``nn.Sequential`` whose sub-modules keep the names from ``block``; each
+        inserted ReLU is named ``'relu_' + layer_name``.
+    """
+    named_modules = []
+    for layer_name, v in block.items():
+        if 'pool' not in layer_name:
+            named_modules.append((layer_name, nn.Conv2d(in_channels=v[0],
+                                                        out_channels=v[1],
+                                                        kernel_size=v[2],
+                                                        stride=v[3],
+                                                        padding=v[4])))
+            if layer_name not in no_relu_layers:
+                named_modules.append(('relu_'+layer_name, nn.ReLU(inplace=True)))
+            continue
+        named_modules.append((layer_name, nn.MaxPool2d(kernel_size=v[0],
+                                                       stride=v[1],
+                                                       padding=v[2])))
+    return nn.Sequential(OrderedDict(named_modules))
+class bodypose_model(nn.Module):
+    """Six-stage, two-branch body pose network.
+
+    A shared feature extractor (``model0``) feeds stage 1; every later stage
+    receives the previous stage's two branch outputs concatenated with the
+    shared features (38 + 19 + 128 = 185 channels). Branch ``_1`` outputs 38
+    channels and branch ``_2`` outputs 19 channels.
+    """
+
+    def __init__(self):
+        super(bodypose_model, self).__init__()
+
+        # these layers have no relu layer
+        # The last entry must be 'Mconv7_stage6_L2': listing 'Mconv7_stage6_L1'
+        # twice puts a ReLU on the final branch-2 output, unlike every other
+        # stage output in this list.
+        no_relu_layers = ['conv5_5_CPM_L1', 'conv5_5_CPM_L2', 'Mconv7_stage2_L1',\
+                          'Mconv7_stage2_L2', 'Mconv7_stage3_L1', 'Mconv7_stage3_L2',\
+                          'Mconv7_stage4_L1', 'Mconv7_stage4_L2', 'Mconv7_stage5_L1',\
+                          'Mconv7_stage5_L2', 'Mconv7_stage6_L1', 'Mconv7_stage6_L2']
+        blocks = {}
+        # Shared feature extractor; ends with 128 output channels.
+        block0 = OrderedDict([
+            ('conv1_1', [3, 64, 3, 1, 1]),
+            ('conv1_2', [64, 64, 3, 1, 1]),
+            ('pool1_stage1', [2, 2, 0]),
+            ('conv2_1', [64, 128, 3, 1, 1]),
+            ('conv2_2', [128, 128, 3, 1, 1]),
+            ('pool2_stage1', [2, 2, 0]),
+            ('conv3_1', [128, 256, 3, 1, 1]),
+            ('conv3_2', [256, 256, 3, 1, 1]),
+            ('conv3_3', [256, 256, 3, 1, 1]),
+            ('conv3_4', [256, 256, 3, 1, 1]),
+            ('pool3_stage1', [2, 2, 0]),
+            ('conv4_1', [256, 512, 3, 1, 1]),
+            ('conv4_2', [512, 512, 3, 1, 1]),
+            ('conv4_3_CPM', [512, 256, 3, 1, 1]),
+            ('conv4_4_CPM', [256, 128, 3, 1, 1])
+        ])
+
+        # Stage 1
+        block1_1 = OrderedDict([
+            ('conv5_1_CPM_L1', [128, 128, 3, 1, 1]),
+            ('conv5_2_CPM_L1', [128, 128, 3, 1, 1]),
+            ('conv5_3_CPM_L1', [128, 128, 3, 1, 1]),
+            ('conv5_4_CPM_L1', [128, 512, 1, 1, 0]),
+            ('conv5_5_CPM_L1', [512, 38, 1, 1, 0])
+        ])
+
+        block1_2 = OrderedDict([
+            ('conv5_1_CPM_L2', [128, 128, 3, 1, 1]),
+            ('conv5_2_CPM_L2', [128, 128, 3, 1, 1]),
+            ('conv5_3_CPM_L2', [128, 128, 3, 1, 1]),
+            ('conv5_4_CPM_L2', [128, 512, 1, 1, 0]),
+            ('conv5_5_CPM_L2', [512, 19, 1, 1, 0])
+        ])
+        blocks['block1_1'] = block1_1
+        blocks['block1_2'] = block1_2
+
+        self.model0 = make_layers(block0, no_relu_layers)
+
+        # Stages 2 - 6
+        for i in range(2, 7):
+            blocks['block%d_1' % i] = OrderedDict([
+                ('Mconv1_stage%d_L1' % i, [185, 128, 7, 1, 3]),
+                ('Mconv2_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+                ('Mconv3_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+                ('Mconv4_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+                ('Mconv5_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+                ('Mconv6_stage%d_L1' % i, [128, 128, 1, 1, 0]),
+                ('Mconv7_stage%d_L1' % i, [128, 38, 1, 1, 0])
+            ])
+
+            blocks['block%d_2' % i] = OrderedDict([
+                ('Mconv1_stage%d_L2' % i, [185, 128, 7, 1, 3]),
+                ('Mconv2_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+                ('Mconv3_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+                ('Mconv4_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+                ('Mconv5_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+                ('Mconv6_stage%d_L2' % i, [128, 128, 1, 1, 0]),
+                ('Mconv7_stage%d_L2' % i, [128, 19, 1, 1, 0])
+            ])
+
+        for k in blocks.keys():
+            blocks[k] = make_layers(blocks[k], no_relu_layers)
+
+        # Attribute names become the state_dict key prefixes, so they must stay as is.
+        self.model1_1 = blocks['block1_1']
+        self.model2_1 = blocks['block2_1']
+        self.model3_1 = blocks['block3_1']
+        self.model4_1 = blocks['block4_1']
+        self.model5_1 = blocks['block5_1']
+        self.model6_1 = blocks['block6_1']
+
+        self.model1_2 = blocks['block1_2']
+        self.model2_2 = blocks['block2_2']
+        self.model3_2 = blocks['block3_2']
+        self.model4_2 = blocks['block4_2']
+        self.model5_2 = blocks['block5_2']
+        self.model6_2 = blocks['block6_2']
+
+    def forward(self, x):
+        """Run all six stages.
+
+        Args:
+            x: input batch with 3 channels on dimension 1.
+
+        Returns:
+            ``(out6_1, out6_2)``: the stage-6 outputs of branch 1 (38 channels)
+            and branch 2 (19 channels).
+        """
+        out1 = self.model0(x)
+
+        out1_1 = self.model1_1(out1)
+        out1_2 = self.model1_2(out1)
+        out2 = torch.cat([out1_1, out1_2, out1], 1)
+
+        out2_1 = self.model2_1(out2)
+        out2_2 = self.model2_2(out2)
+        out3 = torch.cat([out2_1, out2_2, out1], 1)
+
+        out3_1 = self.model3_1(out3)
+        out3_2 = self.model3_2(out3)
+        out4 = torch.cat([out3_1, out3_2, out1], 1)
+
+        out4_1 = self.model4_1(out4)
+        out4_2 = self.model4_2(out4)
+        out5 = torch.cat([out4_1, out4_2, out1], 1)
+
+        out5_1 = self.model5_1(out5)
+        out5_2 = self.model5_2(out5)
+        out6 = torch.cat([out5_1, out5_2, out1], 1)
+
+        out6_1 = self.model6_1(out6)
+        out6_2 = self.model6_2(out6)
+
+        return out6_1, out6_2
+class handpose_model(nn.Module):
+    """Six-stage hand keypoint network with 22 output channels per stage.
+
+    ``model1_0`` extracts 128-channel features; ``model1_1`` produces the first
+    22-channel prediction; stages 2-6 each refine the previous prediction
+    concatenated with the shared features (22 + 128 = 150 channels).
+    """
+
+    def __init__(self):
+        super(handpose_model, self).__init__()
+
+        # output layers of every stage: no ReLU after these
+        no_relu_layers = ['conv6_2_CPM', 'Mconv7_stage2', 'Mconv7_stage3',\
+                          'Mconv7_stage4', 'Mconv7_stage5', 'Mconv7_stage6']
+
+        # stage 1: feature extractor followed by the first prediction head
+        feature_spec = OrderedDict([
+            ('conv1_1', [3, 64, 3, 1, 1]),
+            ('conv1_2', [64, 64, 3, 1, 1]),
+            ('pool1_stage1', [2, 2, 0]),
+            ('conv2_1', [64, 128, 3, 1, 1]),
+            ('conv2_2', [128, 128, 3, 1, 1]),
+            ('pool2_stage1', [2, 2, 0]),
+            ('conv3_1', [128, 256, 3, 1, 1]),
+            ('conv3_2', [256, 256, 3, 1, 1]),
+            ('conv3_3', [256, 256, 3, 1, 1]),
+            ('conv3_4', [256, 256, 3, 1, 1]),
+            ('pool3_stage1', [2, 2, 0]),
+            ('conv4_1', [256, 512, 3, 1, 1]),
+            ('conv4_2', [512, 512, 3, 1, 1]),
+            ('conv4_3', [512, 512, 3, 1, 1]),
+            ('conv4_4', [512, 512, 3, 1, 1]),
+            ('conv5_1', [512, 512, 3, 1, 1]),
+            ('conv5_2', [512, 512, 3, 1, 1]),
+            ('conv5_3_CPM', [512, 128, 3, 1, 1])
+        ])
+        head_spec = OrderedDict([
+            ('conv6_1_CPM', [128, 512, 1, 1, 0]),
+            ('conv6_2_CPM', [512, 22, 1, 1, 0])
+        ])
+
+        # Attribute names become the state_dict key prefixes; keep them and
+        # their registration order unchanged.
+        self.model1_0 = make_layers(feature_spec, no_relu_layers)
+        self.model1_1 = make_layers(head_spec, no_relu_layers)
+
+        # stage 2-6: five 7x7 convolutions, then two 1x1 convolutions
+        for i in range(2, 7):
+            stage_layers = [('Mconv1_stage%d' % i, [150, 128, 7, 1, 3])]
+            stage_layers += [('Mconv%d_stage%d' % (j, i), [128, 128, 7, 1, 3])
+                             for j in range(2, 6)]
+            stage_layers += [
+                ('Mconv6_stage%d' % i, [128, 128, 1, 1, 0]),
+                ('Mconv7_stage%d' % i, [128, 22, 1, 1, 0]),
+            ]
+            setattr(self, 'model%d' % i,
+                    make_layers(OrderedDict(stage_layers), no_relu_layers))
+
+    def forward(self, x):
+        """Return the stage-6 prediction (22 channels) for input batch ``x``."""
+        features = self.model1_0(x)
+        prediction = self.model1_1(features)
+        refinement_stages = (self.model2, self.model3, self.model4,
+                             self.model5, self.model6)
+        for stage in refinement_stages:
+            prediction = stage(torch.cat([prediction, features], 1))
+        return prediction
diff --git a/ControlNet/annotator/openpose/util.py b/ControlNet/annotator/openpose/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f91ae0e65abaf0cbd62d803f56498991141e61b
--- /dev/null
+++ b/ControlNet/annotator/openpose/util.py
@@ -0,0 +1,164 @@
+import math
+import numpy as np
+import matplotlib
+import cv2
+def padRightDownCorner(img, stride, padValue):
+    """Pad ``img`` at the bottom and right so both sides are multiples of ``stride``.
+
+    Args:
+        img: 3-D array indexed as [row, column, channel].
+        stride: required divisor of the padded height and width.
+        padValue: constant written into the added rows and columns.
+
+    Returns:
+        ``(img_padded, pad)`` where ``pad`` is ``[up, left, down, right]``; the
+        ``up`` and ``left`` amounts are always 0.
+    """
+    height, width = img.shape[0], img.shape[1]
+    pad = [
+        0,  # up
+        0,  # left
+        0 if (height % stride == 0) else stride - (height % stride),  # down
+        0 if (width % stride == 0) else stride - (width % stride),  # right
+    ]
+
+    def _filler(strip, reps):
+        # Constant strip with the shape of `strip`, repeated `reps` times per axis.
+        return np.tile(strip * 0 + padValue, reps)
+
+    img_padded = np.concatenate((_filler(img[0:1, :, :], (pad[0], 1, 1)), img), axis=0)
+    img_padded = np.concatenate((_filler(img_padded[:, 0:1, :], (1, pad[1], 1)), img_padded), axis=1)
+    img_padded = np.concatenate((img_padded, _filler(img_padded[-2:-1, :, :], (pad[2], 1, 1))), axis=0)
+    img_padded = np.concatenate((img_padded, _filler(img_padded[:, -2:-1, :], (1, pad[3], 1))), axis=1)
+
+    return img_padded, pad
+# transfer caffe model to pytorch which will match the layer name
+def transfer(model, model_weights):
+    """Re-key ``model_weights`` so it matches ``model.state_dict()``.
+
+    Each state-dict key has its first dot-separated component removed and the
+    remainder is looked up in ``model_weights``.
+
+    Raises:
+        KeyError: when a stripped key is not present in ``model_weights``.
+    """
+    return {
+        full_name: model_weights[full_name.partition('.')[2]]
+        for full_name in model.state_dict().keys()
+    }
+# draw the body keypoints and limbs
+def draw_bodypose(canvas, candidate, subset):
+    """Draw keypoints and limbs of every person onto ``canvas``.
+
+    Args:
+        canvas: image array to draw on.
+        candidate: array whose rows start with ``x, y`` of a keypoint.
+        subset: per-person rows whose first 18 entries index into ``candidate``
+            (-1 marks a missing part).
+
+    Returns:
+        The canvas with circles for keypoints and blended ellipses for the
+        first 17 limbs of ``limbSeq``.
+    """
+    stickwidth = 4
+    limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
+               [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
+               [1, 16], [16, 18], [3, 17], [6, 18]]
+
+    colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
+              [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
+              [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
+    num_people = len(subset)
+
+    # Pass 1: one filled circle per detected keypoint (18 parts, one colour each).
+    for part, color in enumerate(colors):
+        for person in range(num_people):
+            index = int(subset[person][part])
+            if index != -1:
+                x, y = candidate[index][0:2]
+                cv2.circle(canvas, (int(x), int(y)), 4, color, thickness=-1)
+
+    # Pass 2: one semi-transparent ellipse per limb; only the first 17 limbs are drawn.
+    for limb, color in zip(limbSeq[:17], colors):
+        part_columns = np.array(limb) - 1
+        for person in range(num_people):
+            index = subset[person][part_columns]
+            if -1 in index:
+                continue
+            overlay = canvas.copy()
+            rows = index.astype(int)
+            xs = candidate[rows, 0]
+            ys = candidate[rows, 1]
+            center = (int(np.mean(xs)), int(np.mean(ys)))
+            length = ((ys[0] - ys[1]) ** 2 + (xs[0] - xs[1]) ** 2) ** 0.5
+            angle = math.degrees(math.atan2(ys[0] - ys[1], xs[0] - xs[1]))
+            polygon = cv2.ellipse2Poly(center, (int(length / 2), stickwidth), int(angle), 0, 360, 1)
+            cv2.fillConvexPoly(overlay, polygon, color)
+            canvas = cv2.addWeighted(canvas, 0.4, overlay, 0.6, 0)
+    return canvas
+# images drawn by opencv do not look good.
+def draw_handpose(canvas, all_hand_peaks, show_number=False):
+    """Draw hand skeletons onto ``canvas`` in place and return it.
+
+    Args:
+        canvas: image array to draw on.
+        all_hand_peaks: iterable of per-hand arrays, one ``[x, y]`` row per keypoint.
+        show_number: when true, the keypoint index is written next to each point.
+
+    Returns:
+        The same ``canvas`` object.
+    """
+    edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \
+             [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
+    edge_count = float(len(edges))
+
+    for peaks in all_hand_peaks:
+        for edge_id, edge in enumerate(edges):
+            # Skip the edge unless both endpoints have non-zero x and y.
+            if np.sum(np.all(peaks[edge], axis=1) == 0) != 0:
+                continue
+            x1, y1 = peaks[edge[0]]
+            x2, y2 = peaks[edge[1]]
+            # Hue walks around the colour wheel with the edge index.
+            color = matplotlib.colors.hsv_to_rgb([edge_id / edge_count, 1.0, 1.0]) * 255
+            cv2.line(canvas, (x1, y1), (x2, y2), color, thickness=2)
+
+        for number, point in enumerate(peaks):
+            x, y = point
+            cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
+            if show_number:
+                cv2.putText(canvas, str(number), (x, y), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (0, 0, 0), lineType=cv2.LINE_AA)
+    return canvas
+# detect hand according to body pose keypoints
+# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
+def handDetect(candidate, subset, oriImg):
+    """Propose square hand boxes from shoulder/elbow/wrist keypoints.
+
+    Args:
+        candidate: array whose rows start with ``x, y`` of a keypoint.
+        subset: per-person rows of indices into ``candidate`` (-1 = missing).
+        oriImg: image array; only its height and width are read.
+
+    Returns:
+        ``[[x, y, w, is_left], ...]`` with integer top-left corner ``x, y`` and
+        side length ``w``; ``is_left`` is True for the left hand. For each
+        person the left hand comes before the right one. Boxes narrower than
+        20 pixels after clipping to the image are dropped.
+    """
+    # (shoulder, elbow, wrist) columns and the is_left flag, left arm first
+    arm_columns = (([5, 6, 7], True), ([2, 3, 4], False))
+    ratioWristElbow = 0.33
+    image_height, image_width = oriImg.shape[0:2]
+
+    detect_result = []
+    for person in subset.astype(int):
+        for columns, is_left in arm_columns:
+            joint_ids = person[columns]
+            # all three joints are needed
+            if np.sum(joint_ids == -1) != 0:
+                continue
+            shoulder_id, elbow_id, wrist_id = joint_ids
+            x1, y1 = candidate[shoulder_id][:2]
+            x2, y2 = candidate[elbow_id][:2]
+            x3, y3 = candidate[wrist_id][:2]
+
+            # box centre: wrist position extended along the forearm direction
+            x = x3 + ratioWristElbow * (x3 - x2)
+            y = y3 + ratioWristElbow * (y3 - y2)
+            distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
+            distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
+            width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
+
+            # centre -> top-left corner (width = height)
+            x -= width / 2
+            y -= width / 2
+
+            # clip the box to the image
+            if x < 0:
+                x = 0
+            if y < 0:
+                y = 0
+            width1 = image_width - x if x + width > image_width else width
+            width2 = image_height - y if y + width > image_height else width
+            width = min(width1, width2)
+
+            # keep only boxes of at least 20 pixels
+            if width >= 20:
+                detect_result.append([int(x), int(y), int(width), is_left])
+
+    return detect_result
+# get max index of 2d array
+def npmax(array):
+    """Return ``(row, column)`` of the largest entry of a 2-D array."""
+    best_column_per_row = array.argmax(axis=1)
+    row = array.max(axis=1).argmax()
+    return row, best_column_per_row[row]
diff --git a/ControlNet/annotator/uniformer/__init__.py b/ControlNet/annotator/uniformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6be429542e4908c2b7648e7ee7c9c5f8253e7c94
--- /dev/null
+++ b/ControlNet/annotator/uniformer/__init__.py
@@ -0,0 +1,23 @@
+import os
+from annotator.uniformer.mmseg.apis import init_segmentor, inference_segmentor, show_result_pyplot
+from annotator.uniformer.mmseg.core.evaluation import get_palette
+from annotator.util import annotator_ckpts_path
+# Remote location of the segmentation checkpoint; UniformerDetector.__init__ passes
+# it to load_file_from_url when the local file does not exist.
+checkpoint_file = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/upernet_global_small.pth"
+class UniformerDetector:
+    """Callable annotator that renders a segmentation result for an image."""
+
+    def __init__(self):
+        """Download the checkpoint if it is missing and build the segmentor on CUDA."""
+        ckpt_name = "upernet_global_small.pth"
+        modelpath = os.path.join(annotator_ckpts_path, ckpt_name)
+        if not os.path.exists(modelpath):
+            from basicsr.utils.download_util import load_file_from_url
+            load_file_from_url(checkpoint_file, model_dir=annotator_ckpts_path)
+
+        # The config lives next to the checkpoint directory, under uniformer/exp/.
+        annotator_root = os.path.dirname(annotator_ckpts_path)
+        config_file = os.path.join(annotator_root, "uniformer", "exp", "upernet_global_small", "config.py")
+        segmentor = init_segmentor(config_file, modelpath)
+        self.model = segmentor.cuda()
+
+    def __call__(self, img):
+        """Segment ``img`` and return the image rendered with the 'ade' palette."""
+        result = inference_segmentor(self.model, img)
+        return show_result_pyplot(self.model, img, result, get_palette('ade'), opacity=1)
diff --git a/ControlNet/annotator/uniformer/configs/_base_/datasets/ade20k.py b/ControlNet/annotator/uniformer/configs/_base_/datasets/ade20k.py
new file mode 100644
index 0000000000000000000000000000000000000000..efc8b4bb20c981f3db6df7eb52b3dc0744c94cc0
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/datasets/ade20k.py
@@ -0,0 +1,54 @@
+# dataset settings
+# Dataset class name and root directory; both are reused by every split in `data` below.
+dataset_type = 'ADE20KDataset'
+data_root = 'data/ade/ADEChallengeData2016'
+# Per-channel mean/std (three values each) expanded into the 'Normalize' steps via **img_norm_cfg.
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+# Shared by the 'RandomCrop' and 'Pad' steps of the training pipeline.
+crop_size = (512, 512)
+# Training pipeline: a list of step configs, each selected by its `type` key.
+# NOTE(review): no closing bracket for this list is visible before `test_pipeline` - confirm against the upstream file.
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations', reduce_zero_label=True),
+ dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_semantic_seg']),
+# Test pipeline: loading followed by one 'MultiScaleFlipAug' step with flipping
+# disabled and the multi-ratio option commented out.
+# NOTE(review): the closing bracket of this list also appears to be missing - confirm.
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(2048, 512),
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+# Split definitions: `train` uses train_pipeline; `val` and `test` both point at the
+# validation directories and use test_pipeline.
+data = dict(
+ samples_per_gpu=4,
+ workers_per_gpu=4,
+ train=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/training',
+ ann_dir='annotations/training',
+ pipeline=train_pipeline),
+ val=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/validation',
+ ann_dir='annotations/validation',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/validation',
+ ann_dir='annotations/validation',
+ pipeline=test_pipeline))
diff --git a/ControlNet/annotator/uniformer/configs/_base_/datasets/chase_db1.py b/ControlNet/annotator/uniformer/configs/_base_/datasets/chase_db1.py
new file mode 100644
index 0000000000000000000000000000000000000000..298594ea925f87f22b37094a2ec50e370aec96a0
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/datasets/chase_db1.py
@@ -0,0 +1,59 @@
+# dataset settings
+dataset_type = 'ChaseDB1Dataset'
+data_root = 'data/CHASE_DB1'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+img_scale = (960, 999)
+crop_size = (128, 128)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_semantic_seg'])
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=img_scale,
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img'])
+ ])
+]
+data = dict(
+ samples_per_gpu=4,
+ workers_per_gpu=4,
+ train=dict(
+ type='RepeatDataset',
+ times=40000,
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/training',
+ ann_dir='annotations/training',
+ pipeline=train_pipeline)),
+ val=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/validation',
+ ann_dir='annotations/validation',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/validation',
+ ann_dir='annotations/validation',
+ pipeline=test_pipeline))
diff --git a/ControlNet/annotator/uniformer/configs/_base_/datasets/cityscapes.py b/ControlNet/annotator/uniformer/configs/_base_/datasets/cityscapes.py
new file mode 100644
index 0000000000000000000000000000000000000000..f21867c63e1835f6fceb61f066e802fd8fd2a735
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/datasets/cityscapes.py
@@ -0,0 +1,54 @@
+# dataset settings
+dataset_type = 'CityscapesDataset'
+data_root = 'data/cityscapes/'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+crop_size = (512, 1024)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_semantic_seg']),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(2048, 1024),
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+ samples_per_gpu=2,
+ workers_per_gpu=2,
+ train=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='leftImg8bit/train',
+ ann_dir='gtFine/train',
+ pipeline=train_pipeline),
+ val=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='leftImg8bit/val',
+ ann_dir='gtFine/val',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='leftImg8bit/val',
+ ann_dir='gtFine/val',
+ pipeline=test_pipeline))
diff --git a/ControlNet/annotator/uniformer/configs/_base_/datasets/cityscapes_769x769.py b/ControlNet/annotator/uniformer/configs/_base_/datasets/cityscapes_769x769.py
new file mode 100644
index 0000000000000000000000000000000000000000..336c7b254fe392b4703039fec86a83acdbd2e1a5
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/datasets/cityscapes_769x769.py
@@ -0,0 +1,35 @@
+_base_ = './cityscapes.py'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+crop_size = (769, 769)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', img_scale=(2049, 1025), ratio_range=(0.5, 2.0)),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_semantic_seg']),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(2049, 1025),
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+ train=dict(pipeline=train_pipeline),
+ val=dict(pipeline=test_pipeline),
+ test=dict(pipeline=test_pipeline))
diff --git a/ControlNet/annotator/uniformer/configs/_base_/datasets/drive.py b/ControlNet/annotator/uniformer/configs/_base_/datasets/drive.py
new file mode 100644
index 0000000000000000000000000000000000000000..06e8ff606e0d2a4514ec8b7d2c6c436a32efcbf4
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/datasets/drive.py
@@ -0,0 +1,59 @@
+# dataset settings
+dataset_type = 'DRIVEDataset'
+data_root = 'data/DRIVE'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+img_scale = (584, 565)
+crop_size = (64, 64)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_semantic_seg'])
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=img_scale,
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img'])
+ ])
+]
+data = dict(
+ samples_per_gpu=4,
+ workers_per_gpu=4,
+ train=dict(
+ type='RepeatDataset',
+ times=40000,
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/training',
+ ann_dir='annotations/training',
+ pipeline=train_pipeline)),
+ val=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/validation',
+ ann_dir='annotations/validation',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/validation',
+ ann_dir='annotations/validation',
+ pipeline=test_pipeline))
diff --git a/ControlNet/annotator/uniformer/configs/_base_/datasets/hrf.py b/ControlNet/annotator/uniformer/configs/_base_/datasets/hrf.py
new file mode 100644
index 0000000000000000000000000000000000000000..242d790eb1b83e75cf6b7eaa7a35c674099311ad
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/datasets/hrf.py
@@ -0,0 +1,59 @@
+# dataset settings
+dataset_type = 'HRFDataset'
+data_root = 'data/HRF'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+img_scale = (2336, 3504)
+crop_size = (256, 256)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_semantic_seg'])
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=img_scale,
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img'])
+ ])
+]
+data = dict(
+ samples_per_gpu=4,
+ workers_per_gpu=4,
+ train=dict(
+ type='RepeatDataset',
+ times=40000,
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/training',
+ ann_dir='annotations/training',
+ pipeline=train_pipeline)),
+ val=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/validation',
+ ann_dir='annotations/validation',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/validation',
+ ann_dir='annotations/validation',
+ pipeline=test_pipeline))
diff --git a/ControlNet/annotator/uniformer/configs/_base_/datasets/pascal_context.py b/ControlNet/annotator/uniformer/configs/_base_/datasets/pascal_context.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff65bad1b86d7e3a5980bb5b9fc55798dc8df5f4
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/datasets/pascal_context.py
@@ -0,0 +1,60 @@
+# dataset settings
+dataset_type = 'PascalContextDataset'
+data_root = 'data/VOCdevkit/VOC2010/'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+img_scale = (520, 520)
+crop_size = (480, 480)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_semantic_seg']),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=img_scale,
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+ samples_per_gpu=4,
+ workers_per_gpu=4,
+ train=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='JPEGImages',
+ ann_dir='SegmentationClassContext',
+ split='ImageSets/SegmentationContext/train.txt',
+ pipeline=train_pipeline),
+ val=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='JPEGImages',
+ ann_dir='SegmentationClassContext',
+ split='ImageSets/SegmentationContext/val.txt',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='JPEGImages',
+ ann_dir='SegmentationClassContext',
+ split='ImageSets/SegmentationContext/val.txt',
+ pipeline=test_pipeline))
diff --git a/ControlNet/annotator/uniformer/configs/_base_/datasets/pascal_context_59.py b/ControlNet/annotator/uniformer/configs/_base_/datasets/pascal_context_59.py
new file mode 100644
index 0000000000000000000000000000000000000000..37585abab89834b95cd5bdd993b994fca1db65f6
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/datasets/pascal_context_59.py
@@ -0,0 +1,60 @@
+# dataset settings
+dataset_type = 'PascalContextDataset59'
+data_root = 'data/VOCdevkit/VOC2010/'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+img_scale = (520, 520)
+crop_size = (480, 480)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations', reduce_zero_label=True),
+ dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_semantic_seg']),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=img_scale,
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+ samples_per_gpu=4,
+ workers_per_gpu=4,
+ train=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='JPEGImages',
+ ann_dir='SegmentationClassContext',
+ split='ImageSets/SegmentationContext/train.txt',
+ pipeline=train_pipeline),
+ val=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='JPEGImages',
+ ann_dir='SegmentationClassContext',
+ split='ImageSets/SegmentationContext/val.txt',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='JPEGImages',
+ ann_dir='SegmentationClassContext',
+ split='ImageSets/SegmentationContext/val.txt',
+ pipeline=test_pipeline))
diff --git a/ControlNet/annotator/uniformer/configs/_base_/datasets/pascal_voc12.py b/ControlNet/annotator/uniformer/configs/_base_/datasets/pascal_voc12.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba1d42d0c5781f56dc177d860d856bb34adce555
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/datasets/pascal_voc12.py
@@ -0,0 +1,57 @@
+# dataset settings
+dataset_type = 'PascalVOCDataset'
+data_root = 'data/VOCdevkit/VOC2012'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+crop_size = (512, 512)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_semantic_seg']),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(2048, 512),
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+ samples_per_gpu=4,
+ workers_per_gpu=4,
+ train=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='JPEGImages',
+ ann_dir='SegmentationClass',
+ split='ImageSets/Segmentation/train.txt',
+ pipeline=train_pipeline),
+ val=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='JPEGImages',
+ ann_dir='SegmentationClass',
+ split='ImageSets/Segmentation/val.txt',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='JPEGImages',
+ ann_dir='SegmentationClass',
+ split='ImageSets/Segmentation/val.txt',
+ pipeline=test_pipeline))
diff --git a/ControlNet/annotator/uniformer/configs/_base_/datasets/pascal_voc12_aug.py b/ControlNet/annotator/uniformer/configs/_base_/datasets/pascal_voc12_aug.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f23b6717d53ad29f02dd15046802a2631a5076b
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/datasets/pascal_voc12_aug.py
@@ -0,0 +1,9 @@
+_base_ = './pascal_voc12.py'
+# dataset settings
+data = dict(
+ train=dict(
+ ann_dir=['SegmentationClass', 'SegmentationClassAug'],
+ split=[
+ 'ImageSets/Segmentation/train.txt',
+ 'ImageSets/Segmentation/aug.txt'
+ ]))
diff --git a/ControlNet/annotator/uniformer/configs/_base_/datasets/stare.py b/ControlNet/annotator/uniformer/configs/_base_/datasets/stare.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f71b25488cc11a6b4d582ac52b5a24e1ad1cf8e
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/datasets/stare.py
@@ -0,0 +1,59 @@
+# dataset settings
+dataset_type = 'STAREDataset'
+data_root = 'data/STARE'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+img_scale = (605, 700)
+crop_size = (128, 128)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_semantic_seg'])
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=img_scale,
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img'])
+ ])
+]
+data = dict(
+ samples_per_gpu=4,
+ workers_per_gpu=4,
+ train=dict(
+ type='RepeatDataset',
+ times=40000,
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/training',
+ ann_dir='annotations/training',
+ pipeline=train_pipeline)),
+ val=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/validation',
+ ann_dir='annotations/validation',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/validation',
+ ann_dir='annotations/validation',
+ pipeline=test_pipeline))
diff --git a/ControlNet/annotator/uniformer/configs/_base_/default_runtime.py b/ControlNet/annotator/uniformer/configs/_base_/default_runtime.py
new file mode 100644
index 0000000000000000000000000000000000000000..b564cc4e7e7d9a67dacaaddecb100e4d8f5c005b
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/default_runtime.py
@@ -0,0 +1,14 @@
+# yapf:disable
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook', by_epoch=False),
+ # dict(type='TensorboardLoggerHook')
+ ])
+# yapf:enable
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
+cudnn_benchmark = True
diff --git a/ControlNet/annotator/uniformer/configs/_base_/models/ann_r50-d8.py b/ControlNet/annotator/uniformer/configs/_base_/models/ann_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2cb653827e44e6015b3b83bc578003e614a6aa1
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/models/ann_r50-d8.py
@@ -0,0 +1,46 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='ANNHead',
+ in_channels=[1024, 2048],
+ in_index=[2, 3],
+ channels=512,
+ project_channels=256,
+ query_scales=(1, ),
+ key_pool_scales=(1, 3, 6, 8),
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/ControlNet/annotator/uniformer/configs/_base_/models/apcnet_r50-d8.py b/ControlNet/annotator/uniformer/configs/_base_/models/apcnet_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8f5316cbcf3896ba9de7ca2c801eba512f01d5e
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/models/apcnet_r50-d8.py
@@ -0,0 +1,44 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='APCHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ pool_scales=(1, 2, 3, 6),
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=dict(type='SyncBN', requires_grad=True),
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/ControlNet/annotator/uniformer/configs/_base_/models/ccnet_r50-d8.py b/ControlNet/annotator/uniformer/configs/_base_/models/ccnet_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..794148f576b9e215c3c6963e73dffe98204b7717
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/models/ccnet_r50-d8.py
@@ -0,0 +1,44 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='CCHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ recurrence=2,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/ControlNet/annotator/uniformer/configs/_base_/models/cgnet.py b/ControlNet/annotator/uniformer/configs/_base_/models/cgnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..eff8d9458c877c5db894957e0b1b4597e40da6ab
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/models/cgnet.py
@@ -0,0 +1,35 @@
+# model settings
+norm_cfg = dict(type='SyncBN', eps=1e-03, requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ backbone=dict(
+ type='CGNet',
+ norm_cfg=norm_cfg,
+ in_channels=3,
+ num_channels=(32, 64, 128),
+ num_blocks=(3, 21),
+ dilations=(2, 4),
+ reductions=(8, 16)),
+ decode_head=dict(
+ type='FCNHead',
+ in_channels=256,
+ in_index=2,
+ channels=256,
+ num_convs=0,
+ concat_input=False,
+ dropout_ratio=0,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ loss_decode=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=False,
+ loss_weight=1.0,
+ class_weight=[
+ 2.5959933, 6.7415504, 3.5354059, 9.8663225, 9.690899, 9.369352,
+ 10.289121, 9.953208, 4.3097677, 9.490387, 7.674431, 9.396905,
+ 10.347791, 6.3927646, 10.226669, 10.241062, 10.280587,
+ 10.396974, 10.055647
+ ])),
+ # model training and testing settings
+ train_cfg=dict(sampler=None),
+ test_cfg=dict(mode='whole'))
diff --git a/ControlNet/annotator/uniformer/configs/_base_/models/danet_r50-d8.py b/ControlNet/annotator/uniformer/configs/_base_/models/danet_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c934939fac48525f22ad86f489a041dd7db7d09
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/models/danet_r50-d8.py
@@ -0,0 +1,44 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='DAHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ pam_channels=64,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/ControlNet/annotator/uniformer/configs/_base_/models/deeplabv3_r50-d8.py b/ControlNet/annotator/uniformer/configs/_base_/models/deeplabv3_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7a43bee01422ad4795dd27874e0cd4bb6cbfecf
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/models/deeplabv3_r50-d8.py
@@ -0,0 +1,44 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='ASPPHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ dilations=(1, 12, 24, 36),
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/ControlNet/annotator/uniformer/configs/_base_/models/deeplabv3_unet_s5-d16.py b/ControlNet/annotator/uniformer/configs/_base_/models/deeplabv3_unet_s5-d16.py
new file mode 100644
index 0000000000000000000000000000000000000000..0cd262999d8b2cb8e14a5c32190ae73f479d8e81
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/models/deeplabv3_unet_s5-d16.py
@@ -0,0 +1,50 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained=None,
+ backbone=dict(
+ type='UNet',
+ in_channels=3,
+ base_channels=64,
+ num_stages=5,
+ strides=(1, 1, 1, 1, 1),
+ enc_num_convs=(2, 2, 2, 2, 2),
+ dec_num_convs=(2, 2, 2, 2),
+ downsamples=(True, True, True, True),
+ enc_dilations=(1, 1, 1, 1, 1),
+ dec_dilations=(1, 1, 1, 1),
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=norm_cfg,
+ act_cfg=dict(type='ReLU'),
+ upsample_cfg=dict(type='InterpConv'),
+ norm_eval=False),
+ decode_head=dict(
+ type='ASPPHead',
+ in_channels=64,
+ in_index=4,
+ channels=16,
+ dilations=(1, 12, 24, 36),
+ dropout_ratio=0.1,
+ num_classes=2,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=128,
+ in_index=3,
+ channels=64,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=2,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='slide', crop_size=256, stride=170))
diff --git a/ControlNet/annotator/uniformer/configs/_base_/models/deeplabv3plus_r50-d8.py b/ControlNet/annotator/uniformer/configs/_base_/models/deeplabv3plus_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..050e39e091d816df9028d23aa3ecf9db74e441e1
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/models/deeplabv3plus_r50-d8.py
@@ -0,0 +1,46 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='DepthwiseSeparableASPPHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ dilations=(1, 12, 24, 36),
+ c1_in_channels=256,
+ c1_channels=48,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/ControlNet/annotator/uniformer/configs/_base_/models/dmnet_r50-d8.py b/ControlNet/annotator/uniformer/configs/_base_/models/dmnet_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..d22ba52640bebd805b3b8d07025e276dfb023759
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/models/dmnet_r50-d8.py
@@ -0,0 +1,44 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='DMHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ filter_sizes=(1, 3, 5, 7),
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=dict(type='SyncBN', requires_grad=True),
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/ControlNet/annotator/uniformer/configs/_base_/models/dnl_r50-d8.py b/ControlNet/annotator/uniformer/configs/_base_/models/dnl_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..edb4c174c51e34c103737ba39bfc48bf831e561d
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/models/dnl_r50-d8.py
@@ -0,0 +1,46 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='DNLHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ dropout_ratio=0.1,
+ reduction=2,
+ use_scale=True,
+ mode='embedded_gaussian',
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/ControlNet/annotator/uniformer/configs/_base_/models/emanet_r50-d8.py b/ControlNet/annotator/uniformer/configs/_base_/models/emanet_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..26adcd430926de0862204a71d345f2543167f27b
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/models/emanet_r50-d8.py
@@ -0,0 +1,47 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='EMAHead',
+ in_channels=2048,
+ in_index=3,
+ channels=256,
+ ema_channels=512,
+ num_bases=64,
+ num_stages=3,
+ momentum=0.1,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/ControlNet/annotator/uniformer/configs/_base_/models/encnet_r50-d8.py b/ControlNet/annotator/uniformer/configs/_base_/models/encnet_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..be777123a886503172a95fe0719e956a147bbd68
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/models/encnet_r50-d8.py
@@ -0,0 +1,48 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='EncHead',
+ in_channels=[512, 1024, 2048],
+ in_index=(1, 2, 3),
+ channels=512,
+ num_codes=32,
+ use_se_loss=True,
+ add_lateral=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+ loss_se_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.2)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/ControlNet/annotator/uniformer/configs/_base_/models/fast_scnn.py b/ControlNet/annotator/uniformer/configs/_base_/models/fast_scnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..32fdeb659355a5ce5ef2cc7c2f30742703811cdf
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/models/fast_scnn.py
@@ -0,0 +1,57 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True, momentum=0.01)
+model = dict(
+ type='EncoderDecoder',
+ backbone=dict(
+ type='FastSCNN',
+ downsample_dw_channels=(32, 48),
+ global_in_channels=64,
+ global_block_channels=(64, 96, 128),
+ global_block_strides=(2, 2, 1),
+ global_out_channels=128,
+ higher_in_channels=64,
+ lower_in_channels=128,
+ fusion_out_channels=128,
+ out_indices=(0, 1, 2),
+ norm_cfg=norm_cfg,
+ align_corners=False),
+ decode_head=dict(
+ type='DepthwiseSeparableFCNHead',
+ in_channels=128,
+ channels=128,
+ concat_input=False,
+ num_classes=19,
+ in_index=-1,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)),
+ auxiliary_head=[
+ dict(
+ type='FCNHead',
+ in_channels=128,
+ channels=32,
+ num_convs=1,
+ num_classes=19,
+ in_index=-2,
+ norm_cfg=norm_cfg,
+ concat_input=False,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)),
+ dict(
+ type='FCNHead',
+ in_channels=64,
+ channels=32,
+ num_convs=1,
+ num_classes=19,
+ in_index=-3,
+ norm_cfg=norm_cfg,
+ concat_input=False,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)),
+ ],
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/ControlNet/annotator/uniformer/configs/_base_/models/fcn_hr18.py b/ControlNet/annotator/uniformer/configs/_base_/models/fcn_hr18.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3e299bc89ada56ca14bbffcbdb08a586b8ed9e9
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/models/fcn_hr18.py
@@ -0,0 +1,52 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://msra/hrnetv2_w18',
+ backbone=dict(
+ type='HRNet',
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ extra=dict(
+ stage1=dict(
+ num_modules=1,
+ num_branches=1,
+ block='BOTTLENECK',
+ num_blocks=(4, ),
+ num_channels=(64, )),
+ stage2=dict(
+ num_modules=1,
+ num_branches=2,
+ block='BASIC',
+ num_blocks=(4, 4),
+ num_channels=(18, 36)),
+ stage3=dict(
+ num_modules=4,
+ num_branches=3,
+ block='BASIC',
+ num_blocks=(4, 4, 4),
+ num_channels=(18, 36, 72)),
+ stage4=dict(
+ num_modules=3,
+ num_branches=4,
+ block='BASIC',
+ num_blocks=(4, 4, 4, 4),
+ num_channels=(18, 36, 72, 144)))),
+ decode_head=dict(
+ type='FCNHead',
+ in_channels=[18, 36, 72, 144],
+ in_index=(0, 1, 2, 3),
+ channels=sum([18, 36, 72, 144]),
+ input_transform='resize_concat',
+ kernel_size=1,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=-1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/ControlNet/annotator/uniformer/configs/_base_/models/fcn_r50-d8.py b/ControlNet/annotator/uniformer/configs/_base_/models/fcn_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e98f6cc918b6146fc6d613c6918e825ef1355c3
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/models/fcn_r50-d8.py
@@ -0,0 +1,45 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='FCNHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ num_convs=2,
+ concat_input=True,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/ControlNet/annotator/uniformer/configs/_base_/models/fcn_unet_s5-d16.py b/ControlNet/annotator/uniformer/configs/_base_/models/fcn_unet_s5-d16.py
new file mode 100644
index 0000000000000000000000000000000000000000..a33e7972877f902d0e7d18401ca675e3e4e60a18
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/models/fcn_unet_s5-d16.py
@@ -0,0 +1,51 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained=None,
+ backbone=dict(
+ type='UNet',
+ in_channels=3,
+ base_channels=64,
+ num_stages=5,
+ strides=(1, 1, 1, 1, 1),
+ enc_num_convs=(2, 2, 2, 2, 2),
+ dec_num_convs=(2, 2, 2, 2),
+ downsamples=(True, True, True, True),
+ enc_dilations=(1, 1, 1, 1, 1),
+ dec_dilations=(1, 1, 1, 1),
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=norm_cfg,
+ act_cfg=dict(type='ReLU'),
+ upsample_cfg=dict(type='InterpConv'),
+ norm_eval=False),
+ decode_head=dict(
+ type='FCNHead',
+ in_channels=64,
+ in_index=4,
+ channels=64,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=2,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=128,
+ in_index=3,
+ channels=64,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=2,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='slide', crop_size=256, stride=170))
diff --git a/ControlNet/annotator/uniformer/configs/_base_/models/fpn_r50.py b/ControlNet/annotator/uniformer/configs/_base_/models/fpn_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..86ab327db92e44c14822d65f1c9277cb007f17c1
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/models/fpn_r50.py
@@ -0,0 +1,36 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 1, 1),
+ strides=(1, 2, 2, 2),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ neck=dict(
+ type='FPN',
+ in_channels=[256, 512, 1024, 2048],
+ out_channels=256,
+ num_outs=4),
+ decode_head=dict(
+ type='FPNHead',
+ in_channels=[256, 256, 256, 256],
+ in_index=[0, 1, 2, 3],
+ feature_strides=[4, 8, 16, 32],
+ channels=128,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/ControlNet/annotator/uniformer/configs/_base_/models/fpn_uniformer.py b/ControlNet/annotator/uniformer/configs/_base_/models/fpn_uniformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8aae98c5991055bfcc08e82ccdc09f8b1d9f8a8d
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/models/fpn_uniformer.py
@@ -0,0 +1,35 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ backbone=dict(
+ type='UniFormer',
+ embed_dim=[64, 128, 320, 512],
+ layers=[3, 4, 8, 3],
+ head_dim=64,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.1),
+ neck=dict(
+ type='FPN',
+ in_channels=[64, 128, 320, 512],
+ out_channels=256,
+ num_outs=4),
+ decode_head=dict(
+ type='FPNHead',
+ in_channels=[256, 256, 256, 256],
+ in_index=[0, 1, 2, 3],
+ feature_strides=[4, 8, 16, 32],
+ channels=128,
+ dropout_ratio=0.1,
+ num_classes=150,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole')
+)
diff --git a/ControlNet/annotator/uniformer/configs/_base_/models/gcnet_r50-d8.py b/ControlNet/annotator/uniformer/configs/_base_/models/gcnet_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d2ad69f5c22adfe79d5fdabf920217628987166
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/models/gcnet_r50-d8.py
@@ -0,0 +1,46 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='GCHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ ratio=1 / 4.,
+ pooling_type='att',
+ fusion_types=('channel_add', ),
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/ControlNet/annotator/uniformer/configs/_base_/models/lraspp_m-v3-d8.py b/ControlNet/annotator/uniformer/configs/_base_/models/lraspp_m-v3-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..93258242a90695cc94a7c6bd41562d6a75988771
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/models/lraspp_m-v3-d8.py
@@ -0,0 +1,25 @@
+# model settings
+norm_cfg = dict(type='SyncBN', eps=0.001, requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ backbone=dict(
+ type='MobileNetV3',
+ arch='large',
+ out_indices=(1, 3, 16),
+ norm_cfg=norm_cfg),
+ decode_head=dict(
+ type='LRASPPHead',
+ in_channels=(16, 24, 960),
+ in_index=(0, 1, 2),
+ channels=128,
+ input_transform='multiple_select',
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ act_cfg=dict(type='ReLU'),
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/ControlNet/annotator/uniformer/configs/_base_/models/nonlocal_r50-d8.py b/ControlNet/annotator/uniformer/configs/_base_/models/nonlocal_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..5674a39854cafd1f2e363bac99c58ccae62f24da
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/models/nonlocal_r50-d8.py
@@ -0,0 +1,46 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='NLHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ dropout_ratio=0.1,
+ reduction=2,
+ use_scale=True,
+ mode='embedded_gaussian',
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/ControlNet/annotator/uniformer/configs/_base_/models/ocrnet_hr18.py b/ControlNet/annotator/uniformer/configs/_base_/models/ocrnet_hr18.py
new file mode 100644
index 0000000000000000000000000000000000000000..c60f62a7cdf3f5c5096a7a7e725e8268fddcb057
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/models/ocrnet_hr18.py
@@ -0,0 +1,68 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='CascadeEncoderDecoder',
+ num_stages=2,
+ pretrained='open-mmlab://msra/hrnetv2_w18',
+ backbone=dict(
+ type='HRNet',
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ extra=dict(
+ stage1=dict(
+ num_modules=1,
+ num_branches=1,
+ block='BOTTLENECK',
+ num_blocks=(4, ),
+ num_channels=(64, )),
+ stage2=dict(
+ num_modules=1,
+ num_branches=2,
+ block='BASIC',
+ num_blocks=(4, 4),
+ num_channels=(18, 36)),
+ stage3=dict(
+ num_modules=4,
+ num_branches=3,
+ block='BASIC',
+ num_blocks=(4, 4, 4),
+ num_channels=(18, 36, 72)),
+ stage4=dict(
+ num_modules=3,
+ num_branches=4,
+ block='BASIC',
+ num_blocks=(4, 4, 4, 4),
+ num_channels=(18, 36, 72, 144)))),
+ decode_head=[
+ dict(
+ type='FCNHead',
+ in_channels=[18, 36, 72, 144],
+ channels=sum([18, 36, 72, 144]),
+ in_index=(0, 1, 2, 3),
+ input_transform='resize_concat',
+ kernel_size=1,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=-1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ dict(
+ type='OCRHead',
+ in_channels=[18, 36, 72, 144],
+ in_index=(0, 1, 2, 3),
+ input_transform='resize_concat',
+ channels=512,
+ ocr_channels=256,
+ dropout_ratio=-1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ ],
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/ControlNet/annotator/uniformer/configs/_base_/models/ocrnet_r50-d8.py b/ControlNet/annotator/uniformer/configs/_base_/models/ocrnet_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..615aa3ff703942b6c22b2d6e9642504dd3e41ebd
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/models/ocrnet_r50-d8.py
@@ -0,0 +1,47 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='CascadeEncoderDecoder',
+ num_stages=2,
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=[
+ dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ dict(
+ type='OCRHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ ocr_channels=256,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
+ ],
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/ControlNet/annotator/uniformer/configs/_base_/models/pointrend_r50.py b/ControlNet/annotator/uniformer/configs/_base_/models/pointrend_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d323dbf9466d41e0800aa57ef84045f3d874bdf
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/models/pointrend_r50.py
@@ -0,0 +1,56 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='CascadeEncoderDecoder',
+ num_stages=2,
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 1, 1),
+ strides=(1, 2, 2, 2),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ neck=dict(
+ type='FPN',
+ in_channels=[256, 512, 1024, 2048],
+ out_channels=256,
+ num_outs=4),
+ decode_head=[
+ dict(
+ type='FPNHead',
+ in_channels=[256, 256, 256, 256],
+ in_index=[0, 1, 2, 3],
+ feature_strides=[4, 8, 16, 32],
+ channels=128,
+ dropout_ratio=-1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ dict(
+ type='PointHead',
+ in_channels=[256],
+ in_index=[0],
+ channels=256,
+ num_fcs=3,
+ coarse_pred_each_layer=True,
+ dropout_ratio=-1,
+ num_classes=19,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
+ ],
+ # model training and testing settings
+ train_cfg=dict(
+ num_points=2048, oversample_ratio=3, importance_sample_ratio=0.75),
+ test_cfg=dict(
+ mode='whole',
+ subdivision_steps=2,
+ subdivision_num_points=8196,
+ scale_factor=2))
diff --git a/ControlNet/annotator/uniformer/configs/_base_/models/psanet_r50-d8.py b/ControlNet/annotator/uniformer/configs/_base_/models/psanet_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..689513fa9d2a40f14bf0ae4ae61f38f0dcc1b3da
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/models/psanet_r50-d8.py
@@ -0,0 +1,49 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='PSAHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ mask_size=(97, 97),
+ psa_type='bi-direction',
+ compact=False,
+ shrink_factor=2,
+ normalization_factor=1.0,
+ psa_softmax=True,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/ControlNet/annotator/uniformer/configs/_base_/models/pspnet_r50-d8.py b/ControlNet/annotator/uniformer/configs/_base_/models/pspnet_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..f451e08ad2eb0732dcb806b1851eb978d4acf136
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/models/pspnet_r50-d8.py
@@ -0,0 +1,44 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='PSPHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ pool_scales=(1, 2, 3, 6),
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/ControlNet/annotator/uniformer/configs/_base_/models/pspnet_unet_s5-d16.py b/ControlNet/annotator/uniformer/configs/_base_/models/pspnet_unet_s5-d16.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcff9ec4f41fad158344ecd77313dc14564f3682
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/models/pspnet_unet_s5-d16.py
@@ -0,0 +1,50 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained=None,
+ backbone=dict(
+ type='UNet',
+ in_channels=3,
+ base_channels=64,
+ num_stages=5,
+ strides=(1, 1, 1, 1, 1),
+ enc_num_convs=(2, 2, 2, 2, 2),
+ dec_num_convs=(2, 2, 2, 2),
+ downsamples=(True, True, True, True),
+ enc_dilations=(1, 1, 1, 1, 1),
+ dec_dilations=(1, 1, 1, 1),
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=norm_cfg,
+ act_cfg=dict(type='ReLU'),
+ upsample_cfg=dict(type='InterpConv'),
+ norm_eval=False),
+ decode_head=dict(
+ type='PSPHead',
+ in_channels=64,
+ in_index=4,
+ channels=16,
+ pool_scales=(1, 2, 3, 6),
+ dropout_ratio=0.1,
+ num_classes=2,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=128,
+ in_index=3,
+ channels=64,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=2,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='slide', crop_size=256, stride=170))
diff --git a/ControlNet/annotator/uniformer/configs/_base_/models/upernet_r50.py b/ControlNet/annotator/uniformer/configs/_base_/models/upernet_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..10974962fdd7136031fd06de1700f497d355ceaa
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/models/upernet_r50.py
@@ -0,0 +1,44 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 1, 1),
+ strides=(1, 2, 2, 2),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='UPerHead',
+ in_channels=[256, 512, 1024, 2048],
+ in_index=[0, 1, 2, 3],
+ pool_scales=(1, 2, 3, 6),
+ channels=512,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/ControlNet/annotator/uniformer/configs/_base_/models/upernet_uniformer.py b/ControlNet/annotator/uniformer/configs/_base_/models/upernet_uniformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..41aa4db809dc6e2c508e98051f61807d07477903
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/models/upernet_uniformer.py
@@ -0,0 +1,43 @@
+# model settings
+norm_cfg = dict(type='BN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained=None,
+ backbone=dict(
+ type='UniFormer',
+ embed_dim=[64, 128, 320, 512],
+ layers=[3, 4, 8, 3],
+ head_dim=64,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.1),
+ decode_head=dict(
+ type='UPerHead',
+ in_channels=[64, 128, 320, 512],
+ in_index=[0, 1, 2, 3],
+ pool_scales=(1, 2, 3, 6),
+ channels=512,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=320,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
\ No newline at end of file
diff --git a/ControlNet/annotator/uniformer/configs/_base_/schedules/schedule_160k.py b/ControlNet/annotator/uniformer/configs/_base_/schedules/schedule_160k.py
new file mode 100644
index 0000000000000000000000000000000000000000..52603890b10f25faf8eec9f9e5a4468fae09b811
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/schedules/schedule_160k.py
@@ -0,0 +1,9 @@
+# optimizer
+optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
+optimizer_config = dict()
+# learning policy
+lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
+# runtime settings
+runner = dict(type='IterBasedRunner', max_iters=160000)
+checkpoint_config = dict(by_epoch=False, interval=16000)
+evaluation = dict(interval=16000, metric='mIoU')
diff --git a/ControlNet/annotator/uniformer/configs/_base_/schedules/schedule_20k.py b/ControlNet/annotator/uniformer/configs/_base_/schedules/schedule_20k.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf780a1b6f6521833c6a5859675147824efa599d
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/schedules/schedule_20k.py
@@ -0,0 +1,9 @@
+# optimizer
+optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
+optimizer_config = dict()
+# learning policy
+lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
+# runtime settings
+runner = dict(type='IterBasedRunner', max_iters=20000)
+checkpoint_config = dict(by_epoch=False, interval=2000)
+evaluation = dict(interval=2000, metric='mIoU')
diff --git a/ControlNet/annotator/uniformer/configs/_base_/schedules/schedule_40k.py b/ControlNet/annotator/uniformer/configs/_base_/schedules/schedule_40k.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdbf841abcb26eed87bf76ab816aff4bae0630ee
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/schedules/schedule_40k.py
@@ -0,0 +1,9 @@
+# optimizer
+optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
+optimizer_config = dict()
+# learning policy
+lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
+# runtime settings
+runner = dict(type='IterBasedRunner', max_iters=40000)
+checkpoint_config = dict(by_epoch=False, interval=4000)
+evaluation = dict(interval=4000, metric='mIoU')
diff --git a/ControlNet/annotator/uniformer/configs/_base_/schedules/schedule_80k.py b/ControlNet/annotator/uniformer/configs/_base_/schedules/schedule_80k.py
new file mode 100644
index 0000000000000000000000000000000000000000..c190cee6bdc7922b688ea75dc8f152fa15c24617
--- /dev/null
+++ b/ControlNet/annotator/uniformer/configs/_base_/schedules/schedule_80k.py
@@ -0,0 +1,9 @@
+# optimizer: SGD with momentum 0.9, base lr 0.01 and weight decay 5e-4
+optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
+# no additional optimizer options are set (empty dict)
+optimizer_config = dict()
+# learning policy: 'poly' schedule with power 0.9 and a floor of min_lr=1e-4;
+# by_epoch=False, i.e. not epoch-based
+lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
+# runtime settings: IterBasedRunner capped at 80000 iterations
+runner = dict(type='IterBasedRunner', max_iters=80000)
+# checkpoint interval of 8000 with by_epoch=False
+checkpoint_config = dict(by_epoch=False, interval=8000)
+# mIoU evaluation with interval=8000 (presumably iterations, matching the
+# runner -- confirm against the evaluation hook)
+evaluation = dict(interval=8000, metric='mIoU')
diff --git a/ControlNet/annotator/uniformer/exp/upernet_global_small/config.py b/ControlNet/annotator/uniformer/exp/upernet_global_small/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..01db96bf9b0be531aa0eaf62fee51543712f8670
--- /dev/null
+++ b/ControlNet/annotator/uniformer/exp/upernet_global_small/config.py
@@ -0,0 +1,38 @@
+# Experiment config for upernet_global_small. The keys below are layered on
+# top of the base files listed in ``_base_``.
+_base_ = [
+    '../../configs/_base_/models/upernet_uniformer.py',
+    '../../configs/_base_/datasets/ade20k.py',
+    '../../configs/_base_/default_runtime.py',
+    '../../configs/_base_/schedules/schedule_160k.py'
+]
+model = dict(
+    backbone=dict(
+        type='UniFormer',
+        embed_dim=[64, 128, 320, 512],
+        layers=[3, 4, 8, 3],
+        head_dim=64,
+        drop_path_rate=0.25,
+        windows=False,
+        hybrid=False
+    ),
+    # decode head consumes all four backbone stages (same widths as embed_dim)
+    decode_head=dict(
+        in_channels=[64, 128, 320, 512],
+        num_classes=150
+    ),
+    # auxiliary head uses the 320-channel stage only
+    auxiliary_head=dict(
+        in_channels=320,
+        num_classes=150
+    ))
+# AdamW optimizer, no weight decay for position embedding & layer norm in backbone
+# (_delete_=True: presumably replaces the base schedule's SGD entry instead of
+# merging into it -- see the mmcv Config docs)
+optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01,
+                 paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
+                                                 'relative_position_bias_table': dict(decay_mult=0.),
+                                                 'norm': dict(decay_mult=0.)}))
+# linear warmup for 1500 iterations, then a 'poly' schedule with power 1.0
+lr_config = dict(_delete_=True, policy='poly',
+                 warmup='linear',
+                 warmup_iters=1500,
+                 warmup_ratio=1e-6,
+                 power=1.0, min_lr=0.0, by_epoch=False)
\ No newline at end of file
diff --git a/ControlNet/annotator/uniformer/exp/upernet_global_small/run.sh b/ControlNet/annotator/uniformer/exp/upernet_global_small/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9fb22edfa7a32624ea08a63fe7d720c40db3b696
--- /dev/null
+++ b/ControlNet/annotator/uniformer/exp/upernet_global_small/run.sh
@@ -0,0 +1,10 @@
+#!/usr/bin/env bash
+# Launch 8-process distributed training with this directory's config.py.
+# Output (stdout+stderr) is appended to log.txt next to this script.
+# All path expansions are quoted so directories containing spaces work.
+work_path="$(dirname "$0")"
+PYTHONPATH="${work_path}/../../:${PYTHONPATH}" \
+python -m torch.distributed.launch --nproc_per_node=8 \
+    tools/train.py "${work_path}/config.py" \
+    --launcher pytorch \
+    --options model.backbone.pretrained_path='your_model_path/uniformer_small_in1k.pth' \
+    --work-dir "${work_path}/ckpt" \
+    2>&1 | tee -a "${work_path}/log.txt"
diff --git a/ControlNet/annotator/uniformer/exp/upernet_global_small/test.sh b/ControlNet/annotator/uniformer/exp/upernet_global_small/test.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d9a85e7a0d3b7c96b060f473d41254b37a382fcb
--- /dev/null
+++ b/ControlNet/annotator/uniformer/exp/upernet_global_small/test.sh
@@ -0,0 +1,10 @@
+#!/usr/bin/env bash
+# Launch 8-process distributed evaluation (mIoU) of ckpt/latest.pth using
+# test_config_h32.py from this directory; output is appended to log.txt.
+# All path expansions are quoted so directories containing spaces work.
+work_path="$(dirname "$0")"
+PYTHONPATH="${work_path}/../../:${PYTHONPATH}" \
+python -m torch.distributed.launch --nproc_per_node=8 \
+    tools/test.py "${work_path}/test_config_h32.py" \
+    "${work_path}/ckpt/latest.pth" \
+    --launcher pytorch \
+    --eval mIoU \
+    2>&1 | tee -a "${work_path}/log.txt"
diff --git a/ControlNet/annotator/uniformer/exp/upernet_global_small/test_config_g.py b/ControlNet/annotator/uniformer/exp/upernet_global_small/test_config_g.py
new file mode 100644
index 0000000000000000000000000000000000000000..e43737a98a3b174a9f2fe059c06d511144686459
--- /dev/null
+++ b/ControlNet/annotator/uniformer/exp/upernet_global_small/test_config_g.py
@@ -0,0 +1,38 @@
+# Test config (global variant): windows=False, hybrid=False. The keys below
+# are layered on top of the base files listed in ``_base_``.
+_base_ = [
+    '../../configs/_base_/models/upernet_uniformer.py',
+    '../../configs/_base_/datasets/ade20k.py',
+    '../../configs/_base_/default_runtime.py',
+    '../../configs/_base_/schedules/schedule_160k.py'
+]
+model = dict(
+    backbone=dict(
+        type='UniFormer',
+        embed_dim=[64, 128, 320, 512],
+        layers=[3, 4, 8, 3],
+        head_dim=64,
+        drop_path_rate=0.25,
+        windows=False,
+        hybrid=False,
+    ),
+    # decode head consumes all four backbone stages (same widths as embed_dim)
+    decode_head=dict(
+        in_channels=[64, 128, 320, 512],
+        num_classes=150
+    ),
+    # auxiliary head uses the 320-channel stage only
+    auxiliary_head=dict(
+        in_channels=320,
+        num_classes=150
+    ))
+# AdamW optimizer, no weight decay for position embedding & layer norm in backbone
+# (_delete_=True: presumably replaces the base schedule's SGD entry instead of
+# merging into it -- see the mmcv Config docs)
+optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01,
+                 paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
+                                                 'relative_position_bias_table': dict(decay_mult=0.),
+                                                 'norm': dict(decay_mult=0.)}))
+# linear warmup for 1500 iterations, then a 'poly' schedule with power 1.0
+lr_config = dict(_delete_=True, policy='poly',
+                 warmup='linear',
+                 warmup_iters=1500,
+                 warmup_ratio=1e-6,
+                 power=1.0, min_lr=0.0, by_epoch=False)
\ No newline at end of file
diff --git a/ControlNet/annotator/uniformer/exp/upernet_global_small/test_config_h32.py b/ControlNet/annotator/uniformer/exp/upernet_global_small/test_config_h32.py
new file mode 100644
index 0000000000000000000000000000000000000000..a31e3874f76f9f7b089ac8834d85df2441af9b0e
--- /dev/null
+++ b/ControlNet/annotator/uniformer/exp/upernet_global_small/test_config_h32.py
@@ -0,0 +1,39 @@
+# Test config (hybrid variant): hybrid=True with window_size=32. The keys
+# below are layered on top of the base files listed in ``_base_``.
+_base_ = [
+    '../../configs/_base_/models/upernet_uniformer.py',
+    '../../configs/_base_/datasets/ade20k.py',
+    '../../configs/_base_/default_runtime.py',
+    '../../configs/_base_/schedules/schedule_160k.py'
+]
+model = dict(
+    backbone=dict(
+        type='UniFormer',
+        embed_dim=[64, 128, 320, 512],
+        layers=[3, 4, 8, 3],
+        head_dim=64,
+        drop_path_rate=0.25,
+        windows=False,
+        hybrid=True,
+        window_size=32
+    ),
+    # decode head consumes all four backbone stages (same widths as embed_dim)
+    decode_head=dict(
+        in_channels=[64, 128, 320, 512],
+        num_classes=150
+    ),
+    # auxiliary head uses the 320-channel stage only
+    auxiliary_head=dict(
+        in_channels=320,
+        num_classes=150
+    ))
+# AdamW optimizer, no weight decay for position embedding & layer norm in backbone
+# (_delete_=True: presumably replaces the base schedule's SGD entry instead of
+# merging into it -- see the mmcv Config docs)
+optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01,
+                 paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
+                                                 'relative_position_bias_table': dict(decay_mult=0.),
+                                                 'norm': dict(decay_mult=0.)}))
+# linear warmup for 1500 iterations, then a 'poly' schedule with power 1.0
+lr_config = dict(_delete_=True, policy='poly',
+                 warmup='linear',
+                 warmup_iters=1500,
+                 warmup_ratio=1e-6,
+                 power=1.0, min_lr=0.0, by_epoch=False)
\ No newline at end of file
diff --git a/ControlNet/annotator/uniformer/exp/upernet_global_small/test_config_w32.py b/ControlNet/annotator/uniformer/exp/upernet_global_small/test_config_w32.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d9e06f029e46c14cb9ddb39319cabe86fef9b44
--- /dev/null
+++ b/ControlNet/annotator/uniformer/exp/upernet_global_small/test_config_w32.py
@@ -0,0 +1,39 @@
+# Test config (window variant): windows=True with window_size=32. The keys
+# below are layered on top of the base files listed in ``_base_``.
+_base_ = [
+    '../../configs/_base_/models/upernet_uniformer.py',
+    '../../configs/_base_/datasets/ade20k.py',
+    '../../configs/_base_/default_runtime.py',
+    '../../configs/_base_/schedules/schedule_160k.py'
+]
+model = dict(
+    backbone=dict(
+        type='UniFormer',
+        embed_dim=[64, 128, 320, 512],
+        layers=[3, 4, 8, 3],
+        head_dim=64,
+        drop_path_rate=0.25,
+        windows=True,
+        hybrid=False,
+        window_size=32
+    ),
+    # decode head consumes all four backbone stages (same widths as embed_dim)
+    decode_head=dict(
+        in_channels=[64, 128, 320, 512],
+        num_classes=150
+    ),
+    # auxiliary head uses the 320-channel stage only
+    auxiliary_head=dict(
+        in_channels=320,
+        num_classes=150
+    ))
+# AdamW optimizer, no weight decay for position embedding & layer norm in backbone
+# (_delete_=True: presumably replaces the base schedule's SGD entry instead of
+# merging into it -- see the mmcv Config docs)
+optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01,
+                 paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
+                                                 'relative_position_bias_table': dict(decay_mult=0.),
+                                                 'norm': dict(decay_mult=0.)}))
+# linear warmup for 1500 iterations, then a 'poly' schedule with power 1.0
+lr_config = dict(_delete_=True, policy='poly',
+                 warmup='linear',
+                 warmup_iters=1500,
+                 warmup_ratio=1e-6,
+                 power=1.0, min_lr=0.0, by_epoch=False)
\ No newline at end of file
diff --git a/ControlNet/annotator/uniformer/mmcv/__init__.py b/ControlNet/annotator/uniformer/mmcv/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..210a2989138380559f23045b568d0fbbeb918c03
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# flake8: noqa
+from .arraymisc import *
+from .fileio import *
+from .image import *
+from .utils import *
+from .version import *
+from .video import *
+from .visualization import *
+# The following modules are not imported to this level, so mmcv may be used
+# without PyTorch.
+# - runner
+# - parallel
+# - op
diff --git a/ControlNet/annotator/uniformer/mmcv/arraymisc/__init__.py b/ControlNet/annotator/uniformer/mmcv/arraymisc/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b4700d6139ae3d604ff6e542468cce4200c020c
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/arraymisc/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .quantization import dequantize, quantize
+__all__ = ['quantize', 'dequantize']
diff --git a/ControlNet/annotator/uniformer/mmcv/arraymisc/quantization.py b/ControlNet/annotator/uniformer/mmcv/arraymisc/quantization.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e47a3545780cf071a1ef8195efb0b7b662c8186
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/arraymisc/quantization.py
@@ -0,0 +1,55 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+def quantize(arr, min_val, max_val, levels, dtype=np.int64):
+    """Quantize an array of (-inf, inf) to [0, levels-1].
+
+    Values are first clipped to ``[min_val, max_val]`` and then mapped onto
+    ``levels`` equal-width bins.
+
+    Args:
+        arr (ndarray): Input array.
+        min_val (scalar): Minimum value to be clipped.
+        max_val (scalar): Maximum value to be clipped.
+        levels (int): Quantization levels; must be greater than 1.
+        dtype (np.type): The type of the quantized array.
+
+    Returns:
+        ndarray: Quantized array.
+
+    Raises:
+        ValueError: If ``levels`` is not an integer greater than 1, or if
+            ``min_val >= max_val``.
+    """
+    # numpy integer scalars are accepted as well as Python ints
+    if not (isinstance(levels, (int, np.integer)) and levels > 1):
+        raise ValueError(
+            f'levels must be an integer greater than 1, but got {levels}')
+    if min_val >= max_val:
+        raise ValueError(
+            f'min_val ({min_val}) must be smaller than max_val ({max_val})')
+
+    arr = np.clip(arr, min_val, max_val) - min_val
+    # arr == max_val would land in bin `levels`, so cap at the last bin
+    quantized_arr = np.minimum(
+        np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1)
+
+    return quantized_arr
+def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
+    """Dequantize an array.
+
+    Each bin index is mapped back to the centre of its bin inside
+    ``[min_val, max_val]``.
+
+    Args:
+        arr (ndarray): Input array of bin indices.
+        min_val (scalar): Minimum value to be clipped.
+        max_val (scalar): Maximum value to be clipped.
+        levels (int): Quantization levels; must be greater than 1.
+        dtype (np.type): The type of the dequantized array.
+
+    Returns:
+        ndarray: Dequantized array.
+
+    Raises:
+        ValueError: If ``levels`` is not an integer greater than 1, or if
+            ``min_val >= max_val``.
+    """
+    # numpy integer scalars are accepted as well as Python ints
+    if not (isinstance(levels, (int, np.integer)) and levels > 1):
+        raise ValueError(
+            f'levels must be an integer greater than 1, but got {levels}')
+    if min_val >= max_val:
+        raise ValueError(
+            f'min_val ({min_val}) must be smaller than max_val ({max_val})')
+
+    # asarray lets plain sequences through; ndarrays are passed unchanged.
+    # The +0.5 picks the bin centre.
+    dequantized_arr = (np.asarray(arr) + 0.5).astype(dtype) * (
+        max_val - min_val) / levels + min_val
+
+    return dequantized_arr
diff --git a/ControlNet/annotator/uniformer/mmcv/cnn/__init__.py b/ControlNet/annotator/uniformer/mmcv/cnn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7246c897430f0cc7ce12719ad8608824fc734446
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/cnn/__init__.py
@@ -0,0 +1,41 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .alexnet import AlexNet
+# yapf: disable
+# Re-export the building blocks provided by the ``bricks`` subpackage.
+from .bricks import (ContextBlock, Conv2d, Conv3d, ConvAWS2d, ConvModule,
+                     ConvTranspose2d, ConvTranspose3d, ConvWS2d,
+                     DepthwiseSeparableConvModule, GeneralizedAttention,
+                     HSigmoid, HSwish, Linear, MaxPool2d, MaxPool3d,
+                     NonLocal1d, NonLocal2d, NonLocal3d, Scale, Swish,
+                     build_activation_layer, build_conv_layer,
+                     build_norm_layer, build_padding_layer, build_plugin_layer,
+                     build_upsample_layer, conv_ws_2d, is_norm)
+# Registries listed in ``__all__`` below.
+# NOTE(review): NORM_LAYERS is assumed to be defined in bricks/registry.py
+# next to ACTIVATION_LAYERS and CONV_LAYERS -- confirm.
+from .bricks.registry import ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS
+from .builder import MODELS, build_model_from_cfg
+# yapf: enable
+from .resnet import ResNet, make_res_layer
+from .utils import (INITIALIZERS, Caffe2XavierInit, ConstantInit, KaimingInit,
+ NormalInit, PretrainedInit, TruncNormalInit, UniformInit,
+ XavierInit, bias_init_with_prob, caffe2_xavier_init,
+ constant_init, fuse_conv_bn, get_model_complexity_info,
+ initialize, kaiming_init, normal_init, trunc_normal_init,
+ uniform_init, xavier_init)
+from .vgg import VGG, make_vgg_layer
+# Public API of the ``cnn`` package.
+__all__ = [
+    'AlexNet', 'VGG', 'make_vgg_layer', 'ResNet', 'make_res_layer',
+    'constant_init', 'xavier_init', 'normal_init', 'trunc_normal_init',
+    'uniform_init', 'kaiming_init', 'caffe2_xavier_init',
+    'bias_init_with_prob', 'ConvModule', 'build_activation_layer',
+    'build_conv_layer', 'build_norm_layer', 'build_padding_layer',
+    'build_upsample_layer', 'build_plugin_layer', 'is_norm', 'NonLocal1d',
+    'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'HSigmoid', 'Swish', 'HSwish',
+    'GeneralizedAttention', 'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS',
+    'get_model_complexity_info', 'conv_ws_2d', 'ConvAWS2d', 'ConvWS2d',
+    'fuse_conv_bn', 'DepthwiseSeparableConvModule', 'Linear', 'Conv2d',
+    'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d', 'MaxPool3d', 'Conv3d',
+    'initialize', 'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
+    'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
+    'Caffe2XavierInit', 'MODELS', 'build_model_from_cfg'
+]
diff --git a/ControlNet/annotator/uniformer/mmcv/cnn/alexnet.py b/ControlNet/annotator/uniformer/mmcv/cnn/alexnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..89e36b8c7851f895d9ae7f07149f0e707456aab0
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/cnn/alexnet.py
@@ -0,0 +1,61 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+import torch.nn as nn
+class AlexNet(nn.Module):
+    """AlexNet backbone.
+
+    Args:
+        num_classes (int): number of classes for classification. When it is
+            not positive (default ``-1``) no classifier head is created and
+            ``forward`` returns the convolutional feature map.
+    """
+
+    def __init__(self, num_classes=-1):
+        super().__init__()
+        self.num_classes = num_classes
+
+        # Five conv stages; max-pooling follows stages 1, 2 and 5.
+        feature_layers = [
+            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=3, stride=2),
+            nn.Conv2d(64, 192, kernel_size=5, padding=2),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=3, stride=2),
+            nn.Conv2d(192, 384, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(384, 256, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(256, 256, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=3, stride=2),
+        ]
+        self.features = nn.Sequential(*feature_layers)
+
+        if self.num_classes > 0:
+            head_layers = [
+                nn.Dropout(),
+                nn.Linear(256 * 6 * 6, 4096),
+                nn.ReLU(inplace=True),
+                nn.Dropout(),
+                nn.Linear(4096, 4096),
+                nn.ReLU(inplace=True),
+                nn.Linear(4096, num_classes),
+            ]
+            self.classifier = nn.Sequential(*head_layers)
+
+    def init_weights(self, pretrained=None):
+        """Optionally load weights from a checkpoint.
+
+        Args:
+            pretrained (str | None): Checkpoint passed to ``load_checkpoint``
+                (non-strict). ``None`` keeps the default initialization.
+
+        Raises:
+            TypeError: If ``pretrained`` is neither a str nor None.
+        """
+        if pretrained is None:
+            # use default initializer
+            return
+        if not isinstance(pretrained, str):
+            raise TypeError('pretrained must be a str or None')
+        logger = logging.getLogger()
+        # imported lazily, only when a checkpoint is actually loaded
+        from ..runner import load_checkpoint
+        load_checkpoint(self, pretrained, strict=False, logger=logger)
+
+    def forward(self, x):
+        """Return class scores when a classifier exists, else the features."""
+        feats = self.features(x)
+        if self.num_classes <= 0:
+            return feats
+        # the classifier expects the flattened 256 x 6 x 6 feature map
+        flat = feats.view(feats.size(0), 256 * 6 * 6)
+        return self.classifier(flat)
diff --git a/ControlNet/annotator/uniformer/mmcv/cnn/bricks/__init__.py b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f33124ed23fc6f27119a37bcb5ab004d3572be0
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/__init__.py
@@ -0,0 +1,35 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .activation import build_activation_layer
+from .context_block import ContextBlock
+from .conv import build_conv_layer
+from .conv2d_adaptive_padding import Conv2dAdaptivePadding
+from .conv_module import ConvModule
+from .conv_ws import ConvAWS2d, ConvWS2d, conv_ws_2d
+from .depthwise_separable_conv_module import DepthwiseSeparableConvModule
+from .drop import Dropout, DropPath
+from .generalized_attention import GeneralizedAttention
+from .hsigmoid import HSigmoid
+from .hswish import HSwish
+from .non_local import NonLocal1d, NonLocal2d, NonLocal3d
+from .norm import build_norm_layer, is_norm
+from .padding import build_padding_layer
+from .plugin import build_plugin_layer
+from .scale import Scale
+from .swish import Swish
+from .upsample import build_upsample_layer
+from .wrappers import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d,
+ Linear, MaxPool2d, MaxPool3d)
+# Registries exported through ``__all__`` below.
+# NOTE(review): UPSAMPLE_LAYERS is assumed to be defined in registry.py next
+# to PLUGIN_LAYERS -- confirm.
+from .registry import PLUGIN_LAYERS, UPSAMPLE_LAYERS
+
+# Public API of the ``bricks`` package.
+__all__ = [
+    'ConvModule', 'build_activation_layer', 'build_conv_layer',
+    'build_norm_layer', 'build_padding_layer', 'build_upsample_layer',
+    'build_plugin_layer', 'is_norm', 'HSigmoid', 'HSwish', 'NonLocal1d',
+    'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'GeneralizedAttention',
+    'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale', 'ConvAWS2d', 'ConvWS2d',
+    'conv_ws_2d', 'DepthwiseSeparableConvModule', 'Swish', 'Linear',
+    'Conv2dAdaptivePadding', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d',
+    'ConvTranspose3d', 'MaxPool3d', 'Conv3d', 'Dropout', 'DropPath'
+]
diff --git a/ControlNet/annotator/uniformer/mmcv/cnn/bricks/activation.py b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/activation.py
new file mode 100644
index 0000000000000000000000000000000000000000..cab2712287d5ef7be2f079dcb54a94b96394eab5
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/activation.py
@@ -0,0 +1,92 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from annotator.uniformer.mmcv.utils import TORCH_VERSION, build_from_cfg, digit_version
+from .registry import ACTIVATION_LAYERS
+# Register the stock torch.nn activation classes in ACTIVATION_LAYERS.
+for module in [
+        nn.ReLU, nn.LeakyReLU, nn.PReLU, nn.RReLU, nn.ReLU6, nn.ELU,
+        nn.Sigmoid, nn.Tanh
+]:
+    ACTIVATION_LAYERS.register_module(module=module)
+class Clamp(nn.Module):
+    """Activation layer that clamps its input.
+
+    Every element of the feature map is limited to the closed interval
+    :math:`[min, max]`; see ``torch.clamp()`` for the exact semantics.
+
+    Args:
+        min (Number | optional): Lower bound of the clamping range.
+            Default to -1.
+        max (Number | optional): Upper bound of the clamping range.
+            Default to 1.
+    """
+
+    def __init__(self, min=-1., max=1.):
+        super().__init__()
+        self.min, self.max = min, max
+
+    def forward(self, x):
+        """Clamp ``x`` to ``[self.min, self.max]``.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: Clamped tensor.
+        """
+        clamped = torch.clamp(x, min=self.min, max=self.max)
+        return clamped
+class GELU(nn.Module):
+    r"""Gaussian Error Linear Unit activation.
+
+    .. math::
+        \text{GELU}(x) = x * \Phi(x)
+
+    where :math:`\Phi(x)` is the Cumulative Distribution Function for
+    Gaussian Distribution. The computation is delegated to ``F.gelu``.
+
+    Shape:
+        - Input: :math:`(N, *)` where `*` means, any number of additional
+          dimensions
+        - Output: :math:`(N, *)`, same shape as the input
+
+    .. image:: scripts/activation_images/GELU.png
+
+    Examples::
+
+        >>> m = nn.GELU()
+        >>> input = torch.randn(2)
+        >>> output = m(input)
+    """
+
+    def forward(self, input):
+        activated = F.gelu(input)
+        return activated
+# Register exactly one GELU implementation: the local fallback class on
+# parrots or PyTorch older than 1.4 (presumably because nn.GELU is not
+# available there), and the built-in nn.GELU otherwise.
+if (TORCH_VERSION == 'parrots'
+        or digit_version(TORCH_VERSION) < digit_version('1.4')):
+    ACTIVATION_LAYERS.register_module(module=GELU)
+else:
+    ACTIVATION_LAYERS.register_module(module=nn.GELU)
+def build_activation_layer(cfg):
+    """Create an activation layer from a config dict.
+
+    Args:
+        cfg (dict): The activation layer config, which should contain:
+
+            - type (str): Layer type.
+            - layer args: Args needed to instantiate an activation layer.
+
+    Returns:
+        nn.Module: Created activation layer.
+    """
+    # lookup and instantiation are delegated to the shared registry builder
+    activation = build_from_cfg(cfg, ACTIVATION_LAYERS)
+    return activation
diff --git a/ControlNet/annotator/uniformer/mmcv/cnn/bricks/context_block.py b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/context_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..d60fdb904c749ce3b251510dff3cc63cea70d42e
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/context_block.py
@@ -0,0 +1,125 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch import nn
+from ..utils import constant_init, kaiming_init
+from .registry import PLUGIN_LAYERS
+def last_zero_init(m):
+    """Zero-initialize ``m``, or only its last layer for an ``nn.Sequential``."""
+    target = m[-1] if isinstance(m, nn.Sequential) else m
+    constant_init(target, val=0)
+class ContextBlock(nn.Module):
+    """ContextBlock module in GCNet.
+
+    See 'GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond'
+    (https://arxiv.org/abs/1904.11492) for details.
+
+    Args:
+        in_channels (int): Channels of the input feature map.
+        ratio (float): Ratio of channels of transform bottleneck
+        pooling_type (str): Pooling method for context modeling.
+            Options are 'att' and 'avg', stand for attention pooling and
+            average pooling respectively. Default: 'att'.
+        fusion_types (Sequence[str]): Fusion method for feature fusion,
+            Options are 'channel_add', 'channel_mul', stand for channelwise
+            addition and multiplication respectively. Default: ('channel_add',)
+    """
+
+    _abbr_ = 'context_block'
+
+    def __init__(self,
+                 in_channels,
+                 ratio,
+                 pooling_type='att',
+                 fusion_types=('channel_add', )):
+        super().__init__()
+        assert pooling_type in ['avg', 'att']
+        assert isinstance(fusion_types, (list, tuple))
+        allowed_fusions = ['channel_add', 'channel_mul']
+        assert all(fusion in allowed_fusions for fusion in fusion_types)
+        assert len(fusion_types) > 0, 'at least one fusion should be used'
+        self.in_channels = in_channels
+        self.ratio = ratio
+        self.planes = int(in_channels * ratio)
+        self.pooling_type = pooling_type
+        self.fusion_types = fusion_types
+
+        if pooling_type == 'att':
+            self.conv_mask = nn.Conv2d(in_channels, 1, kernel_size=1)
+            self.softmax = nn.Softmax(dim=2)
+        else:
+            self.avg_pool = nn.AdaptiveAvgPool2d(1)
+
+        # Each requested fusion gets its own bottleneck; the other stays None.
+        self.channel_add_conv = (
+            self._make_transform() if 'channel_add' in fusion_types else None)
+        self.channel_mul_conv = (
+            self._make_transform() if 'channel_mul' in fusion_types else None)
+        self.reset_parameters()
+
+    def _make_transform(self):
+        """Build the 1x1 conv -> LayerNorm -> ReLU -> 1x1 conv bottleneck."""
+        return nn.Sequential(
+            nn.Conv2d(self.in_channels, self.planes, kernel_size=1),
+            nn.LayerNorm([self.planes, 1, 1]),
+            nn.ReLU(inplace=True),  # yapf: disable
+            nn.Conv2d(self.planes, self.in_channels, kernel_size=1))
+
+    def reset_parameters(self):
+        """Initialize the mask conv and zero the last layer of each transform."""
+        if self.pooling_type == 'att':
+            kaiming_init(self.conv_mask, mode='fan_in')
+            self.conv_mask.inited = True
+
+        for transform in (self.channel_add_conv, self.channel_mul_conv):
+            if transform is not None:
+                last_zero_init(transform)
+
+    def spatial_pool(self, x):
+        """Pool ``x`` over its spatial dimensions into a [N, C, 1, 1] context."""
+        batch, channel, height, width = x.size()
+        if self.pooling_type != 'att':
+            # [N, C, 1, 1]
+            return self.avg_pool(x)
+
+        # [N, 1, C, H * W]
+        flat_x = x.view(batch, channel, height * width).unsqueeze(1)
+        # [N, 1, H, W] -> [N, 1, H * W]
+        mask = self.conv_mask(x).view(batch, 1, height * width)
+        # softmax over the spatial positions, then [N, 1, H * W, 1]
+        mask = self.softmax(mask).unsqueeze(-1)
+        # [N, 1, C, 1]
+        context = torch.matmul(flat_x, mask)
+        # [N, C, 1, 1]
+        return context.view(batch, channel, 1, 1)
+
+    def forward(self, x):
+        """Fuse the pooled context into ``x``: multiplicative, then additive."""
+        # [N, C, 1, 1]
+        context = self.spatial_pool(x)
+
+        out = x
+        if self.channel_mul_conv is not None:
+            # [N, C, 1, 1]
+            gate = torch.sigmoid(self.channel_mul_conv(context))
+            out = out * gate
+        if self.channel_add_conv is not None:
+            # [N, C, 1, 1]
+            offset = self.channel_add_conv(context)
+            out = out + offset
+
+        return out
diff --git a/ControlNet/annotator/uniformer/mmcv/cnn/bricks/conv.py b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf54491997a48ac3e7fadc4183ab7bf3e831024c
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/conv.py
@@ -0,0 +1,44 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from torch import nn
+from .registry import CONV_LAYERS
+# Register the torch.nn convolution classes in CONV_LAYERS under explicit names.
+CONV_LAYERS.register_module('Conv1d', module=nn.Conv1d)
+CONV_LAYERS.register_module('Conv2d', module=nn.Conv2d)
+CONV_LAYERS.register_module('Conv3d', module=nn.Conv3d)
+# 'Conv' is a second registered name for nn.Conv2d.
+CONV_LAYERS.register_module('Conv', module=nn.Conv2d)
+def build_conv_layer(cfg, *args, **kwargs):
+    """Build convolution layer.
+
+    Args:
+        cfg (None or dict): The conv layer config, which should contain:
+
+            - type (str): Layer type.
+            - layer args: Args needed to instantiate a conv layer.
+
+            ``None`` is treated as ``dict(type='Conv2d')``.
+        args (argument list): Arguments passed to the `__init__`
+            method of the corresponding conv layer.
+        kwargs (keyword arguments): Keyword arguments passed to the `__init__`
+            method of the corresponding conv layer.
+
+    Returns:
+        nn.Module: Created conv layer.
+
+    Raises:
+        TypeError: If ``cfg`` is neither None nor a dict.
+        KeyError: If ``cfg`` has no "type" key or the type is not registered
+            in ``CONV_LAYERS``.
+    """
+    if cfg is None:
+        cfg_ = dict(type='Conv2d')
+    else:
+        if not isinstance(cfg, dict):
+            raise TypeError('cfg must be a dict')
+        if 'type' not in cfg:
+            raise KeyError('the cfg dict must contain the key "type"')
+        # copy so that popping "type" does not mutate the caller's dict
+        cfg_ = cfg.copy()
+
+    layer_type = cfg_.pop('type')
+    if layer_type not in CONV_LAYERS:
+        # message previously said "norm type" for a conv layer
+        raise KeyError(f'Unrecognized conv type {layer_type}')
+    conv_layer = CONV_LAYERS.get(layer_type)
+
+    # the remaining cfg entries are forwarded as extra keyword arguments
+    layer = conv_layer(*args, **kwargs, **cfg_)
+
+    return layer
diff --git a/ControlNet/annotator/uniformer/mmcv/cnn/bricks/conv2d_adaptive_padding.py b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/conv2d_adaptive_padding.py
new file mode 100644
index 0000000000000000000000000000000000000000..b45e758ac6cf8dfb0382d072fe09125bc7e9b888
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/conv2d_adaptive_padding.py
@@ -0,0 +1,62 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+from torch import nn
+from torch.nn import functional as F
+from .registry import CONV_LAYERS
+class Conv2dAdaptivePadding(nn.Conv2d):
+    """2D convolution with TensorFlow-style "same" padding.
+
+    The input is padded at call time (if needed) so that it is fully covered
+    by the filter and stride you specified. For stride 1, this will ensure
+    that output image size is same as input. For stride of 2, output
+    dimensions will be half, for example.
+
+    Args:
+        in_channels (int): Number of channels in the input image
+        out_channels (int): Number of channels produced by the convolution
+        kernel_size (int or tuple): Size of the convolving kernel
+        stride (int or tuple, optional): Stride of the convolution. Default: 1
+        padding (int or tuple, optional): Accepted for signature
+            compatibility; the parent ``nn.Conv2d`` is always built with
+            padding 0 and the real padding is computed in ``forward``.
+            Default: 0
+        dilation (int or tuple, optional): Spacing between kernel elements.
+            Default: 1
+        groups (int, optional): Number of blocked connections from input
+            channels to output channels. Default: 1
+        bias (bool, optional): If ``True``, adds a learnable bias to the
+            output. Default: ``True``
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 groups=1,
+                 bias=True):
+        # padding is fixed to 0 here; see forward()
+        super().__init__(in_channels, out_channels, kernel_size, stride, 0,
+                         dilation, groups, bias)
+
+    def forward(self, x):
+        """Pad ``x`` so the output is ceil(size / stride), then convolve."""
+        in_sizes = x.size()[-2:]
+        kernel_sizes = self.weight.size()[-2:]
+
+        # total padding per spatial axis, in (height, width) order
+        totals = []
+        for size, kernel, stride, dilation in zip(in_sizes, kernel_sizes,
+                                                  self.stride, self.dilation):
+            out_size = math.ceil(size / stride)
+            needed = (out_size - 1) * stride + (kernel - 1) * dilation + 1
+            totals.append(max(needed - size, 0))
+        pad_h, pad_w = totals
+
+        if pad_h > 0 or pad_w > 0:
+            # F.pad takes the last dimension first: (left, right, top, bottom);
+            # any odd pixel goes to the right/bottom side
+            left, top = pad_w // 2, pad_h // 2
+            x = F.pad(x, [left, pad_w - left, top, pad_h - top])
+        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding,
+                        self.dilation, self.groups)
diff --git a/ControlNet/annotator/uniformer/mmcv/cnn/bricks/conv_module.py b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/conv_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..e60e7e62245071c77b652093fddebff3948d7c3e
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/conv_module.py
@@ -0,0 +1,206 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+import torch.nn as nn
+from annotator.uniformer.mmcv.utils import _BatchNorm, _InstanceNorm
+from ..utils import constant_init, kaiming_init
+from .activation import build_activation_layer
+from .conv import build_conv_layer
+from .norm import build_norm_layer
+from .padding import build_padding_layer
+from .registry import PLUGIN_LAYERS
+class ConvModule(nn.Module):
+ """A conv block that bundles conv/norm/activation layers.
+ This block simplifies the usage of convolution layers, which are commonly
+ used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
+ It is based upon three build methods: `build_conv_layer()`,
+ `build_norm_layer()` and `build_activation_layer()`.
+ Besides, we add some additional features in this module.
+ 1. Automatically set `bias` of the conv layer.
+ 2. Spectral norm is supported.
+ 3. More padding modes are supported. Before PyTorch 1.5, nn.Conv2d only
+ supports zero and circular padding, and we add "reflect" padding mode.
+ Args:
+ in_channels (int): Number of channels in the input feature map.
+ Same as that in ``nn._ConvNd``.
+ out_channels (int): Number of channels produced by the convolution.
+ Same as that in ``nn._ConvNd``.
+ kernel_size (int | tuple[int]): Size of the convolving kernel.
+ Same as that in ``nn._ConvNd``.
+ stride (int | tuple[int]): Stride of the convolution.
+ Same as that in ``nn._ConvNd``.
+ padding (int | tuple[int]): Zero-padding added to both sides of
+ the input. Same as that in ``nn._ConvNd``.
+ dilation (int | tuple[int]): Spacing between kernel elements.
+ Same as that in ``nn._ConvNd``.
+ groups (int): Number of blocked connections from input channels to
+ output channels. Same as that in ``nn._ConvNd``.
+ bias (bool | str): If specified as `auto`, it will be decided by the
+ norm_cfg. Bias will be set as True if `norm_cfg` is None, otherwise
+ False. Default: "auto".
+ conv_cfg (dict): Config dict for convolution layer. Default: None,
+ which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer. Default: None.
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='ReLU').
+ inplace (bool): Whether to use inplace mode for activation.
+ Default: True.
+ with_spectral_norm (bool): Whether use spectral norm in conv module.
+ Default: False.
+ padding_mode (str): If the `padding_mode` has not been supported by
+ current `Conv2d` in PyTorch, we will use our own padding layer
+ instead. Currently, we support ['zeros', 'circular'] with official
+ implementation and ['reflect'] with our own implementation.
+ Default: 'zeros'.
+ order (tuple[str]): The order of conv/norm/activation layers. It is a
+ sequence of "conv", "norm" and "act". Common examples are
+ ("conv", "norm", "act") and ("act", "conv", "norm").
+ Default: ('conv', 'norm', 'act').
+ """
+ _abbr_ = 'conv_block'
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ bias='auto',
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=dict(type='ReLU'),
+ inplace=True,
+ with_spectral_norm=False,
+ padding_mode='zeros',
+ order=('conv', 'norm', 'act')):
+ super(ConvModule, self).__init__()
+ assert conv_cfg is None or isinstance(conv_cfg, dict)
+ assert norm_cfg is None or isinstance(norm_cfg, dict)
+ assert act_cfg is None or isinstance(act_cfg, dict)
+ # padding modes handed straight to the conv layer; any other mode gets
+ # a separate padding layer (see with_explicit_padding below)
+ official_padding_mode = ['zeros', 'circular']
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.inplace = inplace
+ self.with_spectral_norm = with_spectral_norm
+ self.with_explicit_padding = padding_mode not in official_padding_mode
+ self.order = order
+ # order must be a 3-tuple containing each of conv/norm/act exactly once
+ assert isinstance(self.order, tuple) and len(self.order) == 3
+ assert set(order) == set(['conv', 'norm', 'act'])
+ self.with_norm = norm_cfg is not None
+ self.with_activation = act_cfg is not None
+ # if the conv layer is before a norm layer, bias is unnecessary.
+ if bias == 'auto':
+ bias = not self.with_norm
+ self.with_bias = bias
+ if self.with_explicit_padding:
+ pad_cfg = dict(type=padding_mode)
+ self.padding_layer = build_padding_layer(pad_cfg, padding)
+ # reset padding to 0 for conv module
+ conv_padding = 0 if self.with_explicit_padding else padding
+ # build convolution layer
+ self.conv = build_conv_layer(
+ conv_cfg,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=conv_padding,
+ dilation=dilation,
+ groups=groups,
+ bias=bias)
+ # export the attributes of self.conv to a higher level for convenience
+ self.in_channels = self.conv.in_channels
+ self.out_channels = self.conv.out_channels
+ self.kernel_size = self.conv.kernel_size
+ self.stride = self.conv.stride
+ # keeps the caller-supplied padding, not conv_padding (which is 0 when
+ # an explicit padding layer is used)
+ self.padding = padding
+ self.dilation = self.conv.dilation
+ self.transposed = self.conv.transposed
+ self.output_padding = self.conv.output_padding
+ self.groups = self.conv.groups
+ # the attributes above are read before the conv is wrapped
+ if self.with_spectral_norm:
+ self.conv = nn.utils.spectral_norm(self.conv)
+ # build normalization layers
+ if self.with_norm:
+ # norm layer is after conv layer
+ if order.index('norm') > order.index('conv'):
+ norm_channels = out_channels
+ else:
+ norm_channels = in_channels
+ # the norm layer is registered under the name returned by the builder
+ # and is reachable through the ``norm`` property
+ self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels)
+ self.add_module(self.norm_name, norm)
+ if self.with_bias:
+ if isinstance(norm, (_BatchNorm, _InstanceNorm)):
+ warnings.warn(
+ 'Unnecessary conv bias before batch/instance norm')
+ else:
+ self.norm_name = None
+ # build activation layer
+ if self.with_activation:
+ # copy so that setdefault below does not mutate the caller's dict or
+ # the shared default argument
+ act_cfg_ = act_cfg.copy()
+ # nn.Tanh has no 'inplace' argument
+ if act_cfg_['type'] not in [
+ 'Tanh', 'PReLU', 'Sigmoid', 'HSigmoid', 'Swish'
+ ]:
+ act_cfg_.setdefault('inplace', inplace)
+ self.activate = build_activation_layer(act_cfg_)
+ # Use msra init by default
+ self.init_weights()
+ @property
+ def norm(self):
+ # the norm submodule registered in __init__, or None when norm_cfg was None
+ if self.norm_name:
+ return getattr(self, self.norm_name)
+ else:
+ return None
+ def init_weights(self):
+ # 1. It is mainly for customized conv layers with their own
+ # initialization manners by calling their own ``init_weights()``,
+ # and we do not want ConvModule to override the initialization.
+ # 2. For customized conv layers without their own initialization
+ # manners (that is, they don't have their own ``init_weights()``)
+ # and PyTorch's conv layers, they will be initialized by
+ # this method with default ``kaiming_init``.
+ # Note: For PyTorch's conv layers, they will be overwritten by our
+ # initialization implementation using default ``kaiming_init``.
+ if not hasattr(self.conv, 'init_weights'):
+ # LeakyReLU uses its configured negative slope (0.01 if unset);
+ # every other case is initialized as for plain ReLU
+ if self.with_activation and self.act_cfg['type'] == 'LeakyReLU':
+ nonlinearity = 'leaky_relu'
+ a = self.act_cfg.get('negative_slope', 0.01)
+ else:
+ nonlinearity = 'relu'
+ a = 0
+ kaiming_init(self.conv, a=a, nonlinearity=nonlinearity)
+ if self.with_norm:
+ constant_init(self.norm, 1, bias=0)
+ def forward(self, x, activate=True, norm=True):
+ # Apply the layers in ``self.order``. ``activate`` / ``norm`` let the
+ # caller skip the activation / normalization step for this call; the
+ # conv step (with its optional explicit padding) always runs.
+ for layer in self.order:
+ if layer == 'conv':
+ if self.with_explicit_padding:
+ x = self.padding_layer(x)
+ x = self.conv(x)
+ elif layer == 'norm' and norm and self.with_norm:
+ x = self.norm(x)
+ elif layer == 'act' and activate and self.with_activation:
+ x = self.activate(x)
+ return x
diff --git a/ControlNet/annotator/uniformer/mmcv/cnn/bricks/conv_ws.py b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/conv_ws.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3941e27874993418b3b5708d5a7485f175ff9c8
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/conv_ws.py
@@ -0,0 +1,148 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .registry import CONV_LAYERS
+def conv_ws_2d(input,
+               weight,
+               bias=None,
+               stride=1,
+               padding=0,
+               dilation=1,
+               groups=1,
+               eps=1e-5):
+    """Apply ``F.conv2d`` with a per-filter standardized kernel.
+    Every slice ``weight[i]`` along the first dimension is shifted by its
+    own mean and divided by ``std + eps`` before the convolution runs.
+    Args:
+        input (Tensor): Input forwarded to ``F.conv2d``.
+        weight (Tensor): Kernel to standardize; it is reshaped with
+            ``view(weight.size(0), -1)`` to compute the statistics.
+        bias (Tensor, optional): Bias forwarded to ``F.conv2d``.
+        stride, padding, dilation, groups: Forwarded to ``F.conv2d``.
+        eps (float): Added to the standard deviation before dividing.
+    Returns:
+        Tensor: Output of ``F.conv2d`` using the standardized kernel.
+    """
+    num_filters = weight.size(0)
+    per_filter = weight.view(num_filters, -1)
+    mu = per_filter.mean(dim=1, keepdim=True).view(num_filters, 1, 1, 1)
+    sigma = per_filter.std(dim=1, keepdim=True).view(num_filters, 1, 1, 1)
+    standardized = (weight - mu) / (sigma + eps)
+    return F.conv2d(input, standardized, bias, stride, padding, dilation,
+                    groups)
+class ConvWS2d(nn.Conv2d):
+    """``nn.Conv2d`` whose forward pass goes through ``conv_ws_2d``.
+    Args:
+        in_channels, out_channels, kernel_size, stride, padding, dilation,
+        groups, bias: Forwarded unchanged to ``nn.Conv2d``.
+        eps (float): Stored on the module and passed to ``conv_ws_2d`` on
+            every forward call. Default: 1e-5.
+    """
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 groups=1,
+                 bias=True,
+                 eps=1e-5):
+        conv_options = dict(
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+            bias=bias)
+        super().__init__(in_channels, out_channels, kernel_size,
+                         **conv_options)
+        self.eps = eps
+    def forward(self, x):
+        """Convolve ``x`` using this layer's parameters via ``conv_ws_2d``."""
+        return conv_ws_2d(
+            x,
+            self.weight,
+            bias=self.bias,
+            stride=self.stride,
+            padding=self.padding,
+            dilation=self.dilation,
+            groups=self.groups,
+            eps=self.eps)
+class ConvAWS2d(nn.Conv2d):
+    """AWS (Adaptive Weight Standardization)
+    This is a variant of Weight Standardization
+    (https://arxiv.org/pdf/1903.10520.pdf)
+    It is used in DetectoRS to avoid NaN
+    (https://arxiv.org/pdf/2006.02334.pdf)
+    The kernel is standardized per output filter and then rescaled with the
+    ``weight_gamma`` / ``weight_beta`` buffers before each convolution.
+    Args:
+        in_channels (int): Number of channels in the input image
+        out_channels (int): Number of channels produced by the convolution
+        kernel_size (int or tuple): Size of the conv kernel
+        stride (int or tuple, optional): Stride of the convolution. Default: 1
+        padding (int or tuple, optional): Zero-padding added to both sides of
+            the input. Default: 0
+        dilation (int or tuple, optional): Spacing between kernel elements.
+            Default: 1
+        groups (int, optional): Number of blocked connections from input
+            channels to output channels. Default: 1
+        bias (bool, optional): If set True, adds a learnable bias to the
+            output. Default: True
+    """
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 groups=1,
+                 bias=True):
+        super(ConvAWS2d, self).__init__(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+            bias=bias)
+        per_filter_shape = (self.out_channels, 1, 1, 1)
+        self.register_buffer('weight_gamma', torch.ones(per_filter_shape))
+        self.register_buffer('weight_beta', torch.zeros(per_filter_shape))
+    @staticmethod
+    def _filter_stats(weight):
+        """Return per-filter ``(mean, std)`` of ``weight``, each shaped
+        ``(-1, 1, 1, 1)``; ``std`` is ``sqrt(var + 1e-5)``."""
+        flat = weight.view(weight.size(0), -1)
+        mean = flat.mean(dim=1).view(-1, 1, 1, 1)
+        std = torch.sqrt(flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
+        return mean, std
+    def _get_weight(self, weight):
+        """Standardize ``weight`` per filter, then apply gamma and beta."""
+        mean, std = self._filter_stats(weight)
+        standardized = (weight - mean) / std
+        return self.weight_gamma * standardized + self.weight_beta
+    def forward(self, x):
+        """Convolve ``x`` with the adaptively standardized kernel."""
+        return F.conv2d(x, self._get_weight(self.weight), self.bias,
+                        self.stride, self.padding, self.dilation, self.groups)
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        """Override default load function.
+        AWS overrides the function _load_from_state_dict to recover
+        weight_gamma and weight_beta if they are missing. If weight_gamma and
+        weight_beta are found in the checkpoint, this function will return
+        after super()._load_from_state_dict. Otherwise, it will compute the
+        mean and std of the pretrained weights and store them in weight_beta
+        and weight_gamma.
+        """
+        # Sentinel: gamma is filled with -1 so that a positive mean after
+        # loading signals that the checkpoint supplied weight_gamma.
+        self.weight_gamma.data.fill_(-1)
+        local_missing_keys = []
+        super()._load_from_state_dict(state_dict, prefix, local_metadata,
+                                      strict, local_missing_keys,
+                                      unexpected_keys, error_msgs)
+        gamma_was_loaded = self.weight_gamma.data.mean() > 0
+        if not gamma_was_loaded:
+            # Recover gamma/beta from the loaded weights and stop reporting
+            # them as missing.
+            mean, std = self._filter_stats(self.weight.data)
+            self.weight_beta.data.copy_(mean)
+            self.weight_gamma.data.copy_(std)
+            local_missing_keys = [
+                k for k in local_missing_keys
+                if not (k.endswith('weight_gamma')
+                        or k.endswith('weight_beta'))
+            ]
+        missing_keys.extend(local_missing_keys)
diff --git a/ControlNet/annotator/uniformer/mmcv/cnn/bricks/depthwise_separable_conv_module.py b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/depthwise_separable_conv_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..722d5d8d71f75486e2db3008907c4eadfca41d63
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/depthwise_separable_conv_module.py
@@ -0,0 +1,96 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+from .conv_module import ConvModule
+class DepthwiseSeparableConvModule(nn.Module):
+    """Depthwise separable convolution module.
+    See https://arxiv.org/pdf/1704.04861.pdf for details.
+    A drop-in replacement for a ConvModule in which the single conv block is
+    split into a depthwise conv block followed by a pointwise (1x1) conv
+    block. Each of the two blocks is itself a ConvModule, so each may carry
+    its own norm and activation layers when ``norm_cfg`` / ``act_cfg`` (or
+    the block-specific overrides) are specified.
+    Args:
+        in_channels (int): Number of channels in the input feature map.
+            Same as that in ``nn._ConvNd``.
+        out_channels (int): Number of channels produced by the convolution.
+            Same as that in ``nn._ConvNd``.
+        kernel_size (int | tuple[int]): Size of the convolving kernel.
+            Same as that in ``nn._ConvNd``.
+        stride (int | tuple[int]): Stride of the convolution.
+            Same as that in ``nn._ConvNd``. Default: 1.
+        padding (int | tuple[int]): Zero-padding added to both sides of
+            the input. Same as that in ``nn._ConvNd``. Default: 0.
+        dilation (int | tuple[int]): Spacing between kernel elements.
+            Same as that in ``nn._ConvNd``. Default: 1.
+        norm_cfg (dict): Default norm config for both depthwise ConvModule and
+            pointwise ConvModule. Default: None.
+        act_cfg (dict): Default activation config for both depthwise ConvModule
+            and pointwise ConvModule. Default: dict(type='ReLU').
+        dw_norm_cfg (dict): Norm config of depthwise ConvModule. If it is
+            'default', it will be the same as `norm_cfg`. Default: 'default'.
+        dw_act_cfg (dict): Activation config of depthwise ConvModule. If it is
+            'default', it will be the same as `act_cfg`. Default: 'default'.
+        pw_norm_cfg (dict): Norm config of pointwise ConvModule. If it is
+            'default', it will be the same as `norm_cfg`. Default: 'default'.
+        pw_act_cfg (dict): Activation config of pointwise ConvModule. If it is
+            'default', it will be the same as `act_cfg`. Default: 'default'.
+        kwargs (optional): Other shared arguments for depthwise and pointwise
+            ConvModule. See ConvModule for ref. ``groups`` must not be given
+            because the depthwise block sets it to ``in_channels``.
+    """
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 norm_cfg=None,
+                 act_cfg=dict(type='ReLU'),
+                 dw_norm_cfg='default',
+                 dw_act_cfg='default',
+                 pw_norm_cfg='default',
+                 pw_act_cfg='default',
+                 **kwargs):
+        super().__init__()
+        assert 'groups' not in kwargs, 'groups should not be specified'
+        def _pick(specific, shared):
+            # The literal string 'default' means "inherit the shared config".
+            return specific if specific != 'default' else shared
+        depthwise_cfg = dict(
+            norm_cfg=_pick(dw_norm_cfg, norm_cfg),
+            act_cfg=_pick(dw_act_cfg, act_cfg))
+        pointwise_cfg = dict(
+            norm_cfg=_pick(pw_norm_cfg, norm_cfg),
+            act_cfg=_pick(pw_act_cfg, act_cfg))
+        # Depthwise block: one group per input channel, channel count kept.
+        self.depthwise_conv = ConvModule(
+            in_channels,
+            in_channels,
+            kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=in_channels,
+            **depthwise_cfg,
+            **kwargs)
+        # Pointwise block: 1x1 conv mapping in_channels to out_channels.
+        self.pointwise_conv = ConvModule(
+            in_channels, out_channels, 1, **pointwise_cfg, **kwargs)
+    def forward(self, x):
+        """Apply the depthwise block, then the pointwise block, to ``x``."""
+        return self.pointwise_conv(self.depthwise_conv(x))
diff --git a/ControlNet/annotator/uniformer/mmcv/cnn/bricks/drop.py b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/drop.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7b4fccd457a0d51fb10c789df3c8537fe7b67c1
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/drop.py
@@ -0,0 +1,65 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from annotator.uniformer.mmcv import build_from_cfg
+from .registry import DROPOUT_LAYERS
+def drop_path(x, drop_prob=0., training=False):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of
+    residual blocks).
+    We follow the implementation
+    https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py  # noqa: E501
+    Args:
+        x (Tensor): Input whose first dimension indexes the samples.
+        drop_prob (float): Probability of zeroing a whole sample. Default: 0.
+        training (bool): Paths are only dropped when True. Default: False.
+    Returns:
+        Tensor: ``x`` itself when ``drop_prob == 0`` or ``training`` is
+        False; otherwise a tensor in which every sample is either zeroed or
+        scaled by ``1 / (1 - drop_prob)``.
+    """
+    if drop_prob == 0. or not training:
+        return x
+    keep_prob = 1 - drop_prob
+    if keep_prob <= 0.:
+        # Every path is dropped. Dividing by a zero keep_prob would produce
+        # inf * 0 = NaN, so return zeros directly.
+        return torch.zeros_like(x)
+    # handle tensors with different dimensions, not just 4D tensors.
+    shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
+    # keep_prob + U[0, 1) floors to 1 with probability keep_prob, else 0.
+    random_tensor = keep_prob + torch.rand(
+        shape, dtype=x.dtype, device=x.device)
+    output = x.div(keep_prob) * random_tensor.floor()
+    return output
+class DropPath(nn.Module):
+    """Module wrapper around ``drop_path`` (Stochastic Depth per sample, for
+    use in the main path of residual blocks).
+    We follow the implementation
+    https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py  # noqa: E501
+    Args:
+        drop_prob (float): Probability of the path to be zeroed. Default: 0.1
+    """
+    def __init__(self, drop_prob=0.1):
+        super().__init__()
+        self.drop_prob = drop_prob
+    def forward(self, x):
+        """Apply ``drop_path`` using this module's ``training`` flag."""
+        return drop_path(
+            x, drop_prob=self.drop_prob, training=self.training)
+class Dropout(nn.Dropout):
+    """Thin wrapper over ``torch.nn.Dropout`` that exposes the ``p``
+    argument under the name ``drop_prob``, consistent with ``DropPath``.
+    Args:
+        drop_prob (float): Probability of the elements to be
+            zeroed. Default: 0.5.
+        inplace (bool): Do the operation inplace or not. Default: False.
+    """
+    def __init__(self, drop_prob=0.5, inplace=False):
+        super(Dropout, self).__init__(drop_prob, inplace)
+def build_dropout(cfg, default_args=None):
+    """Build a dropout layer from ``cfg`` via the ``DROPOUT_LAYERS`` registry.
+    Args:
+        cfg: Config forwarded to ``build_from_cfg``.
+        default_args: Optional defaults forwarded to ``build_from_cfg``.
+    Returns:
+        Whatever ``build_from_cfg`` returns for the given config.
+    """
+    dropout_layer = build_from_cfg(cfg, DROPOUT_LAYERS, default_args)
+    return dropout_layer
diff --git a/ControlNet/annotator/uniformer/mmcv/cnn/bricks/generalized_attention.py b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/generalized_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..988d9adf2f289ef223bd1c680a5ae1d3387f0269
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/generalized_attention.py
@@ -0,0 +1,412 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from ..utils import kaiming_init
+from .registry import PLUGIN_LAYERS
+class GeneralizedAttention(nn.Module):
+    """GeneralizedAttention module.
+    See 'An Empirical Study of Spatial Attention Mechanisms in Deep Networks'
+    (https://arxiv.org/abs/1904.05873) for details.
+    Args:
+        in_channels (int): Channels of the input feature map.
+        spatial_range (int): The spatial range. -1 indicates no spatial range
+            constraint. A value >= 0 is only supported when ``in_channels``
+            is 256 or 512. Default: -1.
+        num_heads (int): The head number of empirical_attention module.
+            Default: 9.
+        position_embedding_dim (int): The position embedding dimension.
+            Values <= 0 fall back to ``in_channels``. Default: -1.
+        position_magnitude (int): A multiplier acting on coord difference.
+            Default: 1.
+        kv_stride (int): The feature stride acting on key/value feature map.
+            Default: 2.
+        q_stride (int): The feature stride acting on query feature map.
+            Default: 1.
+        attention_type (str): A binary indicator string for indicating which
+            items in generalized empirical_attention module are used.
+            Default: '1111'.
+            - '1000' indicates 'query and key content' (appr - appr) item,
+            - '0100' indicates 'query content and relative position'
+              (appr - position) item,
+            - '0010' indicates 'key content only' (bias - appr) item,
+            - '0001' indicates 'relative position only' (bias - position) item.
+    Raises:
+        ValueError: If ``spatial_range >= 0`` and ``in_channels`` is neither
+            256 nor 512 (no constraint-map size is defined for other widths).
+    """
+    _abbr_ = 'gen_attention_block'
+    def __init__(self,
+                 in_channels,
+                 spatial_range=-1,
+                 num_heads=9,
+                 position_embedding_dim=-1,
+                 position_magnitude=1,
+                 kv_stride=2,
+                 q_stride=1,
+                 attention_type='1111'):
+        super(GeneralizedAttention, self).__init__()
+        # hard range means local range for non-local operation
+        self.position_embedding_dim = (
+            position_embedding_dim
+            if position_embedding_dim > 0 else in_channels)
+        self.position_magnitude = position_magnitude
+        self.num_heads = num_heads
+        self.in_channels = in_channels
+        self.spatial_range = spatial_range
+        self.kv_stride = kv_stride
+        self.q_stride = q_stride
+        # One boolean per attention term, in the order documented above.
+        self.attention_type = [bool(int(_)) for _ in attention_type]
+        self.qk_embed_dim = in_channels // num_heads
+        out_c = self.qk_embed_dim * num_heads
+        # Layers tagged with ``kaiming_init = True`` are picked up by
+        # ``init_weights`` at the end of the constructor.
+        if self.attention_type[0] or self.attention_type[1]:
+            self.query_conv = nn.Conv2d(
+                in_channels=in_channels,
+                out_channels=out_c,
+                kernel_size=1,
+                bias=False)
+            self.query_conv.kaiming_init = True
+        if self.attention_type[0] or self.attention_type[2]:
+            self.key_conv = nn.Conv2d(
+                in_channels=in_channels,
+                out_channels=out_c,
+                kernel_size=1,
+                bias=False)
+            self.key_conv.kaiming_init = True
+        self.v_dim = in_channels // num_heads
+        self.value_conv = nn.Conv2d(
+            in_channels=in_channels,
+            out_channels=self.v_dim * num_heads,
+            kernel_size=1,
+            bias=False)
+        self.value_conv.kaiming_init = True
+        if self.attention_type[1] or self.attention_type[3]:
+            self.appr_geom_fc_x = nn.Linear(
+                self.position_embedding_dim // 2, out_c, bias=False)
+            self.appr_geom_fc_x.kaiming_init = True
+            self.appr_geom_fc_y = nn.Linear(
+                self.position_embedding_dim // 2, out_c, bias=False)
+            self.appr_geom_fc_y.kaiming_init = True
+        if self.attention_type[2]:
+            # Uniform init in (-stdv, stdv].
+            stdv = 1.0 / math.sqrt(self.qk_embed_dim * 2)
+            appr_bias_value = -2 * stdv * torch.rand(out_c) + stdv
+            self.appr_bias = nn.Parameter(appr_bias_value)
+        if self.attention_type[3]:
+            stdv = 1.0 / math.sqrt(self.qk_embed_dim * 2)
+            geom_bias_value = -2 * stdv * torch.rand(out_c) + stdv
+            self.geom_bias = nn.Parameter(geom_bias_value)
+        self.proj_conv = nn.Conv2d(
+            in_channels=self.v_dim * num_heads,
+            out_channels=in_channels,
+            kernel_size=1,
+            bias=True)
+        self.proj_conv.kaiming_init = True
+        # gamma starts at zero, so the block initially returns its input.
+        self.gamma = nn.Parameter(torch.zeros(1))
+        if self.spatial_range >= 0:
+            # only works when non local is after 3*3 conv
+            if in_channels == 256:
+                max_len = 84
+            elif in_channels == 512:
+                max_len = 42
+            else:
+                # Previously this fell through and crashed with an
+                # UnboundLocalError on ``max_len``.
+                raise ValueError(
+                    'spatial_range >= 0 is only supported for in_channels '
+                    f'of 256 or 512, but got {in_channels}.')
+            local_constraint_map = self._build_local_constraint_map(
+                max_len, self.kv_stride, self.spatial_range)
+            # Kept as a uint8 parameter so existing checkpoints still load.
+            self.local_constraint_map = nn.Parameter(
+                torch.from_numpy(local_constraint_map).byte(),
+                requires_grad=False)
+        if self.q_stride > 1:
+            self.q_downsample = nn.AvgPool2d(
+                kernel_size=1, stride=self.q_stride)
+        else:
+            self.q_downsample = None
+        if self.kv_stride > 1:
+            self.kv_downsample = nn.AvgPool2d(
+                kernel_size=1, stride=self.kv_stride)
+        else:
+            self.kv_downsample = None
+        self.init_weights()
+    @staticmethod
+    def _build_local_constraint_map(max_len, kv_stride, spatial_range):
+        """Build the integer mask limiting attention to a local window.
+        Args:
+            max_len (int): Side length of the query grid covered by the map.
+            kv_stride (int): Stride of the key/value grid relative to it.
+            spatial_range (int): Half-width of the allowed window.
+        Returns:
+            numpy.ndarray: Integer array of shape
+            ``(max_len, max_len, max_len_kv, max_len_kv)`` holding 0 where a
+            query position may attend to a key position and 1 elsewhere.
+        """
+        max_len_kv = int((max_len - 1.0) / kv_stride + 1)
+        # ``np.int`` was removed from NumPy; the builtin ``int`` is the same
+        # dtype the alias used to resolve to.
+        local_constraint_map = np.ones(
+            (max_len, max_len, max_len_kv, max_len_kv), dtype=int)
+        for iy in range(max_len):
+            y_lo = max((iy - spatial_range) // kv_stride, 0)
+            y_hi = min((iy + spatial_range + 1) // kv_stride + 1, max_len)
+            for ix in range(max_len):
+                x_lo = max((ix - spatial_range) // kv_stride, 0)
+                x_hi = min((ix + spatial_range + 1) // kv_stride + 1, max_len)
+                local_constraint_map[iy, ix, y_lo:y_hi, x_lo:x_hi] = 0
+        return local_constraint_map
+    def get_position_embedding(self,
+                               h,
+                               w,
+                               h_kv,
+                               w_kv,
+                               q_stride,
+                               kv_stride,
+                               device,
+                               dtype,
+                               feat_dim,
+                               wave_length=1000):
+        """Compute sin/cos embeddings of query-key coordinate differences.
+        Returns:
+            tuple[Tensor, Tensor]: ``(embedding_x, embedding_y)`` built from
+            the width and height differences respectively; sin and cos parts
+            are concatenated along the last dimension.
+        """
+        # the default type of Tensor is float32, leading to type mismatch
+        # in fp16 mode. Cast it to support fp16 mode.
+        h_idxs = torch.linspace(0, h - 1, h).to(device=device, dtype=dtype)
+        h_idxs = h_idxs.view((h, 1)) * q_stride
+        w_idxs = torch.linspace(0, w - 1, w).to(device=device, dtype=dtype)
+        w_idxs = w_idxs.view((w, 1)) * q_stride
+        h_kv_idxs = torch.linspace(0, h_kv - 1, h_kv).to(
+            device=device, dtype=dtype)
+        h_kv_idxs = h_kv_idxs.view((h_kv, 1)) * kv_stride
+        w_kv_idxs = torch.linspace(0, w_kv - 1, w_kv).to(
+            device=device, dtype=dtype)
+        w_kv_idxs = w_kv_idxs.view((w_kv, 1)) * kv_stride
+        # (h, h_kv, 1)
+        h_diff = h_idxs.unsqueeze(1) - h_kv_idxs.unsqueeze(0)
+        h_diff *= self.position_magnitude
+        # (w, w_kv, 1)
+        w_diff = w_idxs.unsqueeze(1) - w_kv_idxs.unsqueeze(0)
+        w_diff *= self.position_magnitude
+        feat_range = torch.arange(0, feat_dim / 4).to(
+            device=device, dtype=dtype)
+        dim_mat = torch.Tensor([wave_length]).to(device=device, dtype=dtype)
+        dim_mat = dim_mat**((4. / feat_dim) * feat_range)
+        dim_mat = dim_mat.view((1, 1, -1))
+        embedding_x = torch.cat(
+            ((w_diff / dim_mat).sin(), (w_diff / dim_mat).cos()), dim=2)
+        embedding_y = torch.cat(
+            ((h_diff / dim_mat).sin(), (h_diff / dim_mat).cos()), dim=2)
+        return embedding_x, embedding_y
+    def forward(self, x_input):
+        """Apply the attention block and add the result to ``x_input``.
+        Returns:
+            Tensor: ``self.gamma * out + x_input`` where ``out`` is the
+            projected attention output.
+        """
+        num_heads = self.num_heads
+        # use empirical_attention
+        if self.q_downsample is not None:
+            x_q = self.q_downsample(x_input)
+        else:
+            x_q = x_input
+        n, _, h, w = x_q.shape
+        if self.kv_downsample is not None:
+            x_kv = self.kv_downsample(x_input)
+        else:
+            x_kv = x_input
+        _, _, h_kv, w_kv = x_kv.shape
+        if self.attention_type[0] or self.attention_type[1]:
+            proj_query = self.query_conv(x_q).view(
+                (n, num_heads, self.qk_embed_dim, h * w))
+            proj_query = proj_query.permute(0, 1, 3, 2)
+        if self.attention_type[0] or self.attention_type[2]:
+            proj_key = self.key_conv(x_kv).view(
+                (n, num_heads, self.qk_embed_dim, h_kv * w_kv))
+        if self.attention_type[1] or self.attention_type[3]:
+            position_embed_x, position_embed_y = self.get_position_embedding(
+                h, w, h_kv, w_kv, self.q_stride, self.kv_stride,
+                x_input.device, x_input.dtype, self.position_embedding_dim)
+            # (n, num_heads, w, w_kv, dim)
+            position_feat_x = self.appr_geom_fc_x(position_embed_x).\
+                view(1, w, w_kv, num_heads, self.qk_embed_dim).\
+                permute(0, 3, 1, 2, 4).\
+                repeat(n, 1, 1, 1, 1)
+            # (n, num_heads, h, h_kv, dim)
+            position_feat_y = self.appr_geom_fc_y(position_embed_y).\
+                view(1, h, h_kv, num_heads, self.qk_embed_dim).\
+                permute(0, 3, 1, 2, 4).\
+                repeat(n, 1, 1, 1, 1)
+            position_feat_x /= math.sqrt(2)
+            position_feat_y /= math.sqrt(2)
+        # accelerate for saliency only
+        if (np.sum(self.attention_type) == 1) and self.attention_type[2]:
+            appr_bias = self.appr_bias.\
+                view(1, num_heads, 1, self.qk_embed_dim).\
+                repeat(n, 1, 1, 1)
+            energy = torch.matmul(appr_bias, proj_key).\
+                view(n, num_heads, 1, h_kv * w_kv)
+            # The energy is query-independent here, so the query grid
+            # collapses to a single position.
+            h = 1
+            w = 1
+        else:
+            # (n, num_heads, h*w, h_kv*w_kv), query before key, 540mb for
+            if not self.attention_type[0]:
+                energy = torch.zeros(
+                    n,
+                    num_heads,
+                    h,
+                    w,
+                    h_kv,
+                    w_kv,
+                    dtype=x_input.dtype,
+                    device=x_input.device)
+            # attention_type[0]: appr - appr
+            # attention_type[1]: appr - position
+            # attention_type[2]: bias - appr
+            # attention_type[3]: bias - position
+            if self.attention_type[0] or self.attention_type[2]:
+                if self.attention_type[0] and self.attention_type[2]:
+                    appr_bias = self.appr_bias.\
+                        view(1, num_heads, 1, self.qk_embed_dim)
+                    energy = torch.matmul(proj_query + appr_bias, proj_key).\
+                        view(n, num_heads, h, w, h_kv, w_kv)
+                elif self.attention_type[0]:
+                    energy = torch.matmul(proj_query, proj_key).\
+                        view(n, num_heads, h, w, h_kv, w_kv)
+                elif self.attention_type[2]:
+                    appr_bias = self.appr_bias.\
+                        view(1, num_heads, 1, self.qk_embed_dim).\
+                        repeat(n, 1, 1, 1)
+                    energy += torch.matmul(appr_bias, proj_key).\
+                        view(n, num_heads, 1, 1, h_kv, w_kv)
+            if self.attention_type[1] or self.attention_type[3]:
+                if self.attention_type[1] and self.attention_type[3]:
+                    geom_bias = self.geom_bias.\
+                        view(1, num_heads, 1, self.qk_embed_dim)
+                    proj_query_reshape = (proj_query + geom_bias).\
+                        view(n, num_heads, h, w, self.qk_embed_dim)
+                    energy_x = torch.matmul(
+                        proj_query_reshape.permute(0, 1, 3, 2, 4),
+                        position_feat_x.permute(0, 1, 2, 4, 3))
+                    energy_x = energy_x.\
+                        permute(0, 1, 3, 2, 4).unsqueeze(4)
+                    energy_y = torch.matmul(
+                        proj_query_reshape,
+                        position_feat_y.permute(0, 1, 2, 4, 3))
+                    energy_y = energy_y.unsqueeze(5)
+                    energy += energy_x + energy_y
+                elif self.attention_type[1]:
+                    proj_query_reshape = proj_query.\
+                        view(n, num_heads, h, w, self.qk_embed_dim)
+                    proj_query_reshape = proj_query_reshape.\
+                        permute(0, 1, 3, 2, 4)
+                    position_feat_x_reshape = position_feat_x.\
+                        permute(0, 1, 2, 4, 3)
+                    position_feat_y_reshape = position_feat_y.\
+                        permute(0, 1, 2, 4, 3)
+                    energy_x = torch.matmul(proj_query_reshape,
+                                            position_feat_x_reshape)
+                    energy_x = energy_x.permute(0, 1, 3, 2, 4).unsqueeze(4)
+                    energy_y = torch.matmul(proj_query_reshape,
+                                            position_feat_y_reshape)
+                    energy_y = energy_y.unsqueeze(5)
+                    energy += energy_x + energy_y
+                elif self.attention_type[3]:
+                    geom_bias = self.geom_bias.\
+                        view(1, num_heads, self.qk_embed_dim, 1).\
+                        repeat(n, 1, 1, 1)
+                    position_feat_x_reshape = position_feat_x.\
+                        view(n, num_heads, w*w_kv, self.qk_embed_dim)
+                    position_feat_y_reshape = position_feat_y.\
+                        view(n, num_heads, h * h_kv, self.qk_embed_dim)
+                    energy_x = torch.matmul(position_feat_x_reshape, geom_bias)
+                    energy_x = energy_x.view(n, num_heads, 1, w, 1, w_kv)
+                    energy_y = torch.matmul(position_feat_y_reshape, geom_bias)
+                    energy_y = energy_y.view(n, num_heads, h, 1, h_kv, 1)
+                    energy += energy_x + energy_y
+            energy = energy.view(n, num_heads, h * w, h_kv * w_kv)
+        if self.spatial_range >= 0:
+            # The map is stored as uint8; ``masked_fill_`` is documented to
+            # take a boolean mask, so cast at the point of use.
+            cur_local_constraint_map = \
+                self.local_constraint_map[:h, :w, :h_kv, :w_kv].\
+                contiguous().\
+                view(1, 1, h*w, h_kv*w_kv).\
+                bool()
+            energy = energy.masked_fill_(cur_local_constraint_map,
+                                         float('-inf'))
+        attention = F.softmax(energy, 3)
+        proj_value = self.value_conv(x_kv)
+        proj_value_reshape = proj_value.\
+            view((n, num_heads, self.v_dim, h_kv * w_kv)).\
+            permute(0, 1, 3, 2)
+        out = torch.matmul(attention, proj_value_reshape).\
+            permute(0, 1, 3, 2).\
+            contiguous().\
+            view(n, self.v_dim * self.num_heads, h, w)
+        out = self.proj_conv(out)
+        # output is downsampled, upsample back to input size
+        if self.q_downsample is not None:
+            out = F.interpolate(
+                out,
+                size=x_input.shape[2:],
+                mode='bilinear',
+                align_corners=False)
+        out = self.gamma * out + x_input
+        return out
+    def init_weights(self):
+        """Run ``kaiming_init`` on every submodule tagged ``kaiming_init``."""
+        for m in self.modules():
+            if hasattr(m, 'kaiming_init') and m.kaiming_init:
+                kaiming_init(
+                    m,
+                    mode='fan_in',
+                    nonlinearity='leaky_relu',
+                    bias=0,
+                    distribution='uniform',
+                    a=1)
diff --git a/ControlNet/annotator/uniformer/mmcv/cnn/bricks/hsigmoid.py b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/hsigmoid.py
new file mode 100644
index 0000000000000000000000000000000000000000..30b1a3d6580cf0360710426fbea1f05acdf07b4b
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/hsigmoid.py
@@ -0,0 +1,34 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+from .registry import ACTIVATION_LAYERS
+class HSigmoid(nn.Module):
+    """Hard Sigmoid Module. Apply the hard sigmoid function:
+    Hsigmoid(x) = min(max((x + bias) / divisor, min_value), max_value)
+    Default: Hsigmoid(x) = min(max((x + 1) / 2, 0), 1)
+    Args:
+        bias (float): Bias of the input feature map. Default: 1.0.
+        divisor (float): Divisor of the input feature map; must be non-zero.
+            Default: 2.0.
+        min_value (float): Lower bound value. Default: 0.0.
+        max_value (float): Upper bound value. Default: 1.0.
+    Returns:
+        Tensor: The output tensor.
+    """
+    def __init__(self, bias=1.0, divisor=2.0, min_value=0.0, max_value=1.0):
+        super().__init__()
+        self.bias, self.divisor = bias, divisor
+        assert self.divisor != 0
+        self.min_value, self.max_value = min_value, max_value
+    def forward(self, x):
+        """Shift, scale and clamp ``x``; the input tensor is not modified."""
+        scaled = (x + self.bias) / self.divisor
+        # In-place clamp is safe: ``scaled`` is a fresh tensor, not ``x``.
+        scaled.clamp_(self.min_value, self.max_value)
+        return scaled
diff --git a/ControlNet/annotator/uniformer/mmcv/cnn/bricks/hswish.py b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/hswish.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e0c090ff037c99ee6c5c84c4592e87beae02208
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/hswish.py
@@ -0,0 +1,29 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+from .registry import ACTIVATION_LAYERS
+class HSwish(nn.Module):
+    """Hard Swish Module.
+    This module applies the hard swish function:
+    .. math::
+        Hswish(x) = x * ReLU6(x + 3) / 6
+    Args:
+        inplace (bool): Passed to the internal ``nn.ReLU6``, which operates
+            on the temporary ``x + 3``. Default: False.
+    Returns:
+        Tensor: The output tensor.
+    """
+    def __init__(self, inplace=False):
+        super().__init__()
+        self.act = nn.ReLU6(inplace=inplace)
+    def forward(self, x):
+        """Return ``x * ReLU6(x + 3) / 6``."""
+        gate = self.act(x + 3)
+        return x * gate / 6
diff --git a/ControlNet/annotator/uniformer/mmcv/cnn/bricks/non_local.py b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/non_local.py
new file mode 100644
index 0000000000000000000000000000000000000000..92d00155ef275c1201ea66bba30470a1785cc5d7
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/non_local.py
@@ -0,0 +1,306 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta
+import torch
+import torch.nn as nn
+from ..utils import constant_init, normal_init
+from .conv_module import ConvModule
+from .registry import PLUGIN_LAYERS
+class _NonLocalNd(nn.Module, metaclass=ABCMeta):
+    """Basic Non-local module.
+    This module is proposed in
+    "Non-local Neural Networks"
+    Paper reference: https://arxiv.org/abs/1711.07971
+    Code reference: https://github.com/AlexHex7/Non-local_pytorch
+    ``forward`` reads ``self.sub_sample`` in ``gaussian`` mode; that attribute
+    is not set here and is expected to be provided by the subclass.
+    Args:
+        in_channels (int): Channels of the input feature map.
+        reduction (int): Channel reduction ratio. Default: 2.
+        use_scale (bool): Whether to scale pairwise_weight by
+            `1/sqrt(inter_channels)` when the mode is `embedded_gaussian`.
+            Default: True.
+        conv_cfg (None | dict): The config dict for convolution layers.
+            If not specified, it will use `nn.Conv2d` for convolution layers.
+            Default: None.
+        norm_cfg (None | dict): The config dict for normalization layers.
+            Default: None. (This parameter is only applicable to conv_out.)
+        mode (str): Options are `gaussian`, `concatenation`,
+            `embedded_gaussian` and `dot_product`. Default: embedded_gaussian.
+        kwargs: Forwarded to ``init_weights`` (``std``, ``zeros_init``).
+    Raises:
+        ValueError: If ``mode`` is not one of the supported options.
+    """
+    def __init__(self,
+                 in_channels,
+                 reduction=2,
+                 use_scale=True,
+                 conv_cfg=None,
+                 norm_cfg=None,
+                 mode='embedded_gaussian',
+                 **kwargs):
+        super(_NonLocalNd, self).__init__()
+        self.in_channels = in_channels
+        self.reduction = reduction
+        self.use_scale = use_scale
+        # Never reduce below a single channel.
+        self.inter_channels = max(in_channels // reduction, 1)
+        self.mode = mode
+        if mode not in [
+                'gaussian', 'embedded_gaussian', 'dot_product', 'concatenation'
+        ]:
+            raise ValueError("Mode should be in 'gaussian', 'concatenation', "
+                             f"'embedded_gaussian' or 'dot_product', but got "
+                             f'{mode} instead.')
+        # g, theta, phi are defaulted as `nn.ConvNd`.
+        # Here we use ConvModule for potential usage.
+        self.g = ConvModule(
+            self.in_channels,
+            self.inter_channels,
+            kernel_size=1,
+            conv_cfg=conv_cfg,
+            act_cfg=None)
+        self.conv_out = ConvModule(
+            self.inter_channels,
+            self.in_channels,
+            kernel_size=1,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=None)
+        if self.mode != 'gaussian':
+            self.theta = ConvModule(
+                self.in_channels,
+                self.inter_channels,
+                kernel_size=1,
+                conv_cfg=conv_cfg,
+                act_cfg=None)
+            self.phi = ConvModule(
+                self.in_channels,
+                self.inter_channels,
+                kernel_size=1,
+                conv_cfg=conv_cfg,
+                act_cfg=None)
+        if self.mode == 'concatenation':
+            self.concat_project = ConvModule(
+                self.inter_channels * 2,
+                1,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                bias=False,
+                act_cfg=dict(type='ReLU'))
+        self.init_weights(**kwargs)
+    @staticmethod
+    def _conv_module_of(module):
+        """Return the ConvModule behind ``module``.
+        The NonLocal1d/2d/3d subclasses replace ``g`` / ``phi`` with
+        ``nn.Sequential(conv_module, max_pool)`` when ``sub_sample`` is set;
+        in that case the first element is the ConvModule.
+        """
+        if isinstance(module, nn.Sequential):
+            return module[0]
+        return module
+    def init_weights(self, std=0.01, zeros_init=True):
+        """Initialize the projection convs and the output conv.
+        Args:
+            std (float): Std used by ``normal_init``. Default: 0.01.
+            zeros_init (bool): If True, zero the output conv (or its norm
+                layer when ``norm_cfg`` is set); otherwise initialize it with
+                ``normal_init``. Default: True.
+        """
+        # Unwrap Sequential wrappers so that calling this method again after
+        # construction (with sub_sample enabled) does not fail on ``.conv``.
+        if self.mode != 'gaussian':
+            for m in [self.g, self.theta, self.phi]:
+                normal_init(self._conv_module_of(m).conv, std=std)
+        else:
+            normal_init(self._conv_module_of(self.g).conv, std=std)
+        if zeros_init:
+            if self.conv_out.norm_cfg is None:
+                constant_init(self.conv_out.conv, 0)
+            else:
+                constant_init(self.conv_out.norm, 0)
+        else:
+            if self.conv_out.norm_cfg is None:
+                normal_init(self.conv_out.conv, std=std)
+            else:
+                normal_init(self.conv_out.norm, std=std)
+    def gaussian(self, theta_x, phi_x):
+        """Softmax over the last dim of ``theta_x @ phi_x``."""
+        # NonLocal1d pairwise_weight: [N, H, H]
+        # NonLocal2d pairwise_weight: [N, HxW, HxW]
+        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
+        pairwise_weight = torch.matmul(theta_x, phi_x)
+        pairwise_weight = pairwise_weight.softmax(dim=-1)
+        return pairwise_weight
+    def embedded_gaussian(self, theta_x, phi_x):
+        """Like ``gaussian`` but optionally scaled by ``1/sqrt(channels)``."""
+        # NonLocal1d pairwise_weight: [N, H, H]
+        # NonLocal2d pairwise_weight: [N, HxW, HxW]
+        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
+        pairwise_weight = torch.matmul(theta_x, phi_x)
+        if self.use_scale:
+            # theta_x.shape[-1] is `self.inter_channels`
+            pairwise_weight /= theta_x.shape[-1]**0.5
+        pairwise_weight = pairwise_weight.softmax(dim=-1)
+        return pairwise_weight
+    def dot_product(self, theta_x, phi_x):
+        """``theta_x @ phi_x`` divided by the size of its last dim."""
+        # NonLocal1d pairwise_weight: [N, H, H]
+        # NonLocal2d pairwise_weight: [N, HxW, HxW]
+        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
+        pairwise_weight = torch.matmul(theta_x, phi_x)
+        pairwise_weight /= pairwise_weight.shape[-1]
+        return pairwise_weight
+    def concatenation(self, theta_x, phi_x):
+        """Pairwise weight from ``concat_project`` on tiled theta/phi."""
+        # NonLocal1d pairwise_weight: [N, H, H]
+        # NonLocal2d pairwise_weight: [N, HxW, HxW]
+        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
+        h = theta_x.size(2)
+        w = phi_x.size(3)
+        theta_x = theta_x.repeat(1, 1, 1, w)
+        phi_x = phi_x.repeat(1, 1, h, 1)
+        concat_feature = torch.cat([theta_x, phi_x], dim=1)
+        pairwise_weight = self.concat_project(concat_feature)
+        n, _, h, w = pairwise_weight.size()
+        pairwise_weight = pairwise_weight.view(n, h, w)
+        pairwise_weight /= pairwise_weight.shape[-1]
+        return pairwise_weight
+    def forward(self, x):
+        """Return ``x + conv_out(y)`` where ``y`` aggregates ``g(x)`` with the
+        pairwise weights selected by ``self.mode``."""
+        # Assume `reduction = 1`, then `inter_channels = C`
+        # or `inter_channels = C` when `mode="gaussian"`
+        # NonLocal1d x: [N, C, H]
+        # NonLocal2d x: [N, C, H, W]
+        # NonLocal3d x: [N, C, T, H, W]
+        n = x.size(0)
+        # NonLocal1d g_x: [N, H, C]
+        # NonLocal2d g_x: [N, HxW, C]
+        # NonLocal3d g_x: [N, TxHxW, C]
+        g_x = self.g(x).view(n, self.inter_channels, -1)
+        g_x = g_x.permute(0, 2, 1)
+        # NonLocal1d theta_x: [N, H, C], phi_x: [N, C, H]
+        # NonLocal2d theta_x: [N, HxW, C], phi_x: [N, C, HxW]
+        # NonLocal3d theta_x: [N, TxHxW, C], phi_x: [N, C, TxHxW]
+        if self.mode == 'gaussian':
+            theta_x = x.view(n, self.in_channels, -1)
+            theta_x = theta_x.permute(0, 2, 1)
+            if self.sub_sample:
+                phi_x = self.phi(x).view(n, self.in_channels, -1)
+            else:
+                phi_x = x.view(n, self.in_channels, -1)
+        elif self.mode == 'concatenation':
+            theta_x = self.theta(x).view(n, self.inter_channels, -1, 1)
+            phi_x = self.phi(x).view(n, self.inter_channels, 1, -1)
+        else:
+            theta_x = self.theta(x).view(n, self.inter_channels, -1)
+            theta_x = theta_x.permute(0, 2, 1)
+            phi_x = self.phi(x).view(n, self.inter_channels, -1)
+        # The pairwise function is the method named after the mode.
+        pairwise_func = getattr(self, self.mode)
+        # NonLocal1d pairwise_weight: [N, H, H]
+        # NonLocal2d pairwise_weight: [N, HxW, HxW]
+        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
+        pairwise_weight = pairwise_func(theta_x, phi_x)
+        # NonLocal1d y: [N, H, C]
+        # NonLocal2d y: [N, HxW, C]
+        # NonLocal3d y: [N, TxHxW, C]
+        y = torch.matmul(pairwise_weight, g_x)
+        # NonLocal1d y: [N, C, H]
+        # NonLocal2d y: [N, C, H, W]
+        # NonLocal3d y: [N, C, T, H, W]
+        y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels,
+                                                    *x.size()[2:])
+        output = x + self.conv_out(y)
+        return output
+class NonLocal1d(_NonLocalNd):
+    """1D Non-local module.
+    Args:
+        in_channels (int): Same as `NonLocalND`.
+        sub_sample (bool): Whether to apply max pooling after pairwise
+            function (Note that the `sub_sample` is applied on spatial only).
+            Default: False.
+        conv_cfg (None | dict): Same as `NonLocalND`.
+            Default: dict(type='Conv1d').
+    """
+    def __init__(self,
+                 in_channels,
+                 sub_sample=False,
+                 conv_cfg=dict(type='Conv1d'),
+                 **kwargs):
+        super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
+        self.sub_sample = sub_sample
+        if not sub_sample:
+            return
+        # The same pooling layer instance is appended after g and phi.
+        pool = nn.MaxPool1d(kernel_size=2)
+        self.g = nn.Sequential(self.g, pool)
+        if self.mode == 'gaussian':
+            # gaussian mode has no phi conv, so phi is just the pooling.
+            self.phi = pool
+        else:
+            self.phi = nn.Sequential(self.phi, pool)
+class NonLocal2d(_NonLocalNd):
+    """2D Non-local module.
+    Args:
+        in_channels (int): Same as `NonLocalND`.
+        sub_sample (bool): Whether to apply max pooling after pairwise
+            function (Note that the `sub_sample` is applied on spatial only).
+            Default: False.
+        conv_cfg (None | dict): Same as `NonLocalND`.
+            Default: dict(type='Conv2d').
+    """
+    _abbr_ = 'nonlocal_block'
+    def __init__(self,
+                 in_channels,
+                 sub_sample=False,
+                 conv_cfg=dict(type='Conv2d'),
+                 **kwargs):
+        super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
+        self.sub_sample = sub_sample
+        if not sub_sample:
+            return
+        # The same pooling layer instance is appended after g and phi.
+        pool = nn.MaxPool2d(kernel_size=(2, 2))
+        self.g = nn.Sequential(self.g, pool)
+        if self.mode == 'gaussian':
+            # gaussian mode has no phi conv, so phi is just the pooling.
+            self.phi = pool
+        else:
+            self.phi = nn.Sequential(self.phi, pool)
+class NonLocal3d(_NonLocalNd):
+    """3D Non-local module.
+    Args:
+        in_channels (int): Same as `NonLocalND`.
+        sub_sample (bool): Whether to apply max pooling after pairwise
+            function (Note that the `sub_sample` is applied on spatial only).
+            Default: False.
+        conv_cfg (None | dict): Same as `NonLocalND`.
+            Default: dict(type='Conv3d').
+    """
+    def __init__(self,
+                 in_channels,
+                 sub_sample=False,
+                 conv_cfg=dict(type='Conv3d'),
+                 **kwargs):
+        super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
+        self.sub_sample = sub_sample
+        if not sub_sample:
+            return
+        # Kernel (1, 2, 2): the first pooled dimension is left untouched.
+        pool = nn.MaxPool3d(kernel_size=(1, 2, 2))
+        self.g = nn.Sequential(self.g, pool)
+        if self.mode == 'gaussian':
+            # gaussian mode has no phi conv, so phi is just the pooling.
+            self.phi = pool
+        else:
+            self.phi = nn.Sequential(self.phi, pool)
diff --git a/ControlNet/annotator/uniformer/mmcv/cnn/bricks/norm.py b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..408f4b42731b19a3beeef68b6a5e610d0bbc18b3
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/norm.py
@@ -0,0 +1,144 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import inspect
+import torch.nn as nn
+from annotator.uniformer.mmcv.utils import is_tuple_of
+from annotator.uniformer.mmcv.utils.parrots_wrapper import SyncBatchNorm, _BatchNorm, _InstanceNorm
+from .registry import NORM_LAYERS
+# Register the torch normalization classes under short config names.
+# 'BN' and 'BN2d' both map to nn.BatchNorm2d; 'IN' and 'IN2d' both map to
+# nn.InstanceNorm2d.
+NORM_LAYERS.register_module('BN', module=nn.BatchNorm2d)
+NORM_LAYERS.register_module('BN1d', module=nn.BatchNorm1d)
+NORM_LAYERS.register_module('BN2d', module=nn.BatchNorm2d)
+NORM_LAYERS.register_module('BN3d', module=nn.BatchNorm3d)
+NORM_LAYERS.register_module('SyncBN', module=SyncBatchNorm)
+NORM_LAYERS.register_module('GN', module=nn.GroupNorm)
+NORM_LAYERS.register_module('LN', module=nn.LayerNorm)
+NORM_LAYERS.register_module('IN', module=nn.InstanceNorm2d)
+NORM_LAYERS.register_module('IN1d', module=nn.InstanceNorm1d)
+NORM_LAYERS.register_module('IN2d', module=nn.InstanceNorm2d)
+NORM_LAYERS.register_module('IN3d', module=nn.InstanceNorm3d)
+def infer_abbr(class_type):
+    """Derive a short abbreviation for a norm layer class.
+    The abbreviation is used by `build_norm_layer()` to name the created
+    layer (e.g. ``bn1``, ``gn``). Rules are tried in this order:
+    1. A class attribute ``_abbr_`` is returned as is.
+    2. Subclasses of _InstanceNorm, _BatchNorm, GroupNorm and LayerNorm map
+       to "in", "bn", "gn" and "ln" respectively.
+    3. A lower-cased class name containing "batch", "group", "layer" or
+       "instance" maps to "bn", "gn", "ln" and "in" respectively.
+    4. Anything else yields "norm_layer".
+    Args:
+        class_type (type): The norm layer type.
+    Returns:
+        str: The inferred abbreviation.
+    Raises:
+        TypeError: If ``class_type`` is not a class.
+    """
+    if not inspect.isclass(class_type):
+        raise TypeError(
+            f'class_type must be a type, but got {type(class_type)}')
+    if hasattr(class_type, '_abbr_'):
+        return class_type._abbr_
+    # Order matters: the instance-norm base is tested before the batch-norm
+    # base (the original code notes that IN is a subclass of BN).
+    by_base = (
+        (_InstanceNorm, 'in'),
+        (_BatchNorm, 'bn'),
+        (nn.GroupNorm, 'gn'),
+        (nn.LayerNorm, 'ln'),
+    )
+    for base, abbr in by_base:
+        if issubclass(class_type, base):
+            return abbr
+    # Fall back to keyword matching on the class name.
+    lowered = class_type.__name__.lower()
+    by_keyword = (
+        ('batch', 'bn'),
+        ('group', 'gn'),
+        ('layer', 'ln'),
+        ('instance', 'in'),
+    )
+    for keyword, abbr in by_keyword:
+        if keyword in lowered:
+            return abbr
+    return 'norm_layer'
+def build_norm_layer(cfg, num_features, postfix=''):
+    """Create a normalization layer from a config dict.
+    Args:
+        cfg (dict): The norm layer config, which should contain:
+            - type (str): Layer type, looked up in ``NORM_LAYERS``.
+            - layer args: Args needed to instantiate a norm layer.
+            - requires_grad (bool, optional): Value assigned to
+              ``requires_grad`` of every parameter of the layer.
+              Defaults to True.
+        num_features (int): Number of input channels.
+        postfix (int | str): Appended to the norm abbreviation to form the
+            layer name.
+    Returns:
+        (str, nn.Module): The layer name (abbreviation + postfix, e.g. bn1,
+        gn) and the created norm layer.
+    Raises:
+        TypeError: If ``cfg`` is not a dict.
+        KeyError: If ``cfg`` has no "type" key or the type is not registered.
+    """
+    if not isinstance(cfg, dict):
+        raise TypeError('cfg must be a dict')
+    if 'type' not in cfg:
+        raise KeyError('the cfg dict must contain the key "type"')
+    # Shallow copy so popping keys does not alter the caller's dict.
+    layer_args = cfg.copy()
+    layer_type = layer_args.pop('type')
+    if layer_type not in NORM_LAYERS:
+        raise KeyError(f'Unrecognized norm type {layer_type}')
+    norm_cls = NORM_LAYERS.get(layer_type)
+    abbr = infer_abbr(norm_cls)
+    assert isinstance(postfix, (int, str))
+    name = abbr + str(postfix)
+    requires_grad = layer_args.pop('requires_grad', True)
+    layer_args.setdefault('eps', 1e-5)
+    if layer_type == 'GN':
+        # GroupNorm receives the channel count by keyword and needs
+        # ``num_groups`` from the config.
+        assert 'num_groups' in layer_args
+        layer = norm_cls(num_channels=num_features, **layer_args)
+    else:
+        layer = norm_cls(num_features, **layer_args)
+        if layer_type == 'SyncBN' and hasattr(layer, '_specify_ddp_gpu_num'):
+            layer._specify_ddp_gpu_num(1)
+    for param in layer.parameters():
+        param.requires_grad = requires_grad
+    return name, layer
+def is_norm(layer, exclude=None):
+    """Tell whether ``layer`` is a normalization layer.
+    Args:
+        layer (nn.Module): The layer to be checked.
+        exclude (type | tuple[type]): Types that must not count as norm
+            layers even if they derive from a norm base class.
+    Returns:
+        bool: True if ``layer`` is an instance of _BatchNorm, _InstanceNorm,
+        GroupNorm or LayerNorm and is not an instance of an excluded type.
+    Raises:
+        TypeError: If ``exclude`` is neither None, a type, nor a tuple of
+            types.
+    """
+    if exclude is not None:
+        # Normalise a single type into a one-element tuple before checking.
+        exclude = exclude if isinstance(exclude, tuple) else (exclude, )
+        if not is_tuple_of(exclude, type):
+            raise TypeError(
+                f'"exclude" must be either None or type or a tuple of types, '
+                f'but got {type(exclude)}: {exclude}')
+    if exclude and isinstance(layer, exclude):
+        return False
+    norm_bases = (_BatchNorm, _InstanceNorm, nn.GroupNorm, nn.LayerNorm)
+    return isinstance(layer, norm_bases)
diff --git a/ControlNet/annotator/uniformer/mmcv/cnn/bricks/padding.py b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/padding.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4ac6b28a1789bd551c613a7d3e7b622433ac7ec
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/padding.py
@@ -0,0 +1,36 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+from .registry import PADDING_LAYERS
+# Register the 2D torch padding classes under short config names.
+PADDING_LAYERS.register_module('zero', module=nn.ZeroPad2d)
+PADDING_LAYERS.register_module('reflect', module=nn.ReflectionPad2d)
+PADDING_LAYERS.register_module('replicate', module=nn.ReplicationPad2d)
+def build_padding_layer(cfg, *args, **kwargs):
+    """Create a padding layer from a config dict.
+    Args:
+        cfg (dict): The padding layer config, which should contain:
+            - type (str): Layer type, looked up in ``PADDING_LAYERS``.
+            - layer args: Args needed to instantiate a padding layer.
+        *args: Positional arguments forwarded to the layer constructor.
+        **kwargs: Keyword arguments forwarded to the layer constructor,
+            together with the remaining entries of ``cfg``.
+    Returns:
+        nn.Module: Created padding layer.
+    Raises:
+        TypeError: If ``cfg`` is not a dict.
+        KeyError: If ``cfg`` has no "type" key or the type is not registered.
+    """
+    if not isinstance(cfg, dict):
+        raise TypeError('cfg must be a dict')
+    if 'type' not in cfg:
+        raise KeyError('the cfg dict must contain the key "type"')
+    # Shallow copy so popping "type" does not alter the caller's dict.
+    layer_args = cfg.copy()
+    padding_type = layer_args.pop('type')
+    if padding_type not in PADDING_LAYERS:
+        raise KeyError(f'Unrecognized padding type {padding_type}.')
+    padding_cls = PADDING_LAYERS.get(padding_type)
+    return padding_cls(*args, **kwargs, **layer_args)
diff --git a/ControlNet/annotator/uniformer/mmcv/cnn/bricks/plugin.py b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/plugin.py
new file mode 100644
index 0000000000000000000000000000000000000000..07c010d4053174dd41107aa654ea67e82b46a25c
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/plugin.py
@@ -0,0 +1,88 @@
+import inspect
+import platform
+from .registry import PLUGIN_LAYERS
+# Bind the name ``re`` exactly once: the third-party ``regex`` package on
+# Windows, the standard-library ``re`` module on every other platform.
+if platform.system() == 'Windows':
+    import regex as re
+else:
+    import re
+def infer_abbr(class_type):
+    """Derive a short abbreviation for a plugin layer class.
+    Rule 1: If the class has the attribute ``_abbr_``, return it.
+    Rule 2: Otherwise return the snake-case form of the class name, e.g.
+    ``FancyBlock`` becomes ``fancy_block``. The conversion is adapted from
+    the `inflection` library.
+    Args:
+        class_type (type): The plugin layer type.
+    Returns:
+        str: The inferred abbreviation.
+    Raises:
+        TypeError: If ``class_type`` is not a class.
+    """
+    if not inspect.isclass(class_type):
+        raise TypeError(
+            f'class_type must be a type, but got {type(class_type)}')
+    if hasattr(class_type, '_abbr_'):
+        return class_type._abbr_
+    # CamelCase -> snake_case in two passes: first split an upper-case run
+    # from a following capitalised word, then split a lower-case letter or
+    # digit from a following capital.
+    snake = class_type.__name__
+    for pattern in (r'([A-Z]+)([A-Z][a-z])', r'([a-z\d])([A-Z])'):
+        snake = re.sub(pattern, r'\1_\2', snake)
+    return snake.replace('-', '_').lower()
+def build_plugin_layer(cfg, postfix='', **kwargs):
+    """Create a plugin layer from a config dict.
+    Args:
+        cfg (dict): cfg should contain:
+            type (str): identify plugin layer type, looked up in
+                ``PLUGIN_LAYERS``.
+            layer args: args needed to instantiate a plugin layer.
+        postfix (int, str): appended to the plugin abbreviation to
+            create the layer name. Default: ''.
+        **kwargs: Extra keyword arguments forwarded to the layer
+            constructor together with the remaining entries of ``cfg``.
+    Returns:
+        tuple[str, nn.Module]:
+            name (str): abbreviation + postfix
+            layer (nn.Module): created plugin layer
+    Raises:
+        TypeError: If ``cfg`` is not a dict.
+        KeyError: If ``cfg`` has no "type" key or the type is not registered.
+    """
+    if not isinstance(cfg, dict):
+        raise TypeError('cfg must be a dict')
+    if 'type' not in cfg:
+        raise KeyError('the cfg dict must contain the key "type"')
+    # Shallow copy so popping "type" does not alter the caller's dict.
+    layer_args = cfg.copy()
+    layer_type = layer_args.pop('type')
+    if layer_type not in PLUGIN_LAYERS:
+        raise KeyError(f'Unrecognized plugin type {layer_type}')
+    plugin_cls = PLUGIN_LAYERS.get(layer_type)
+    abbr = infer_abbr(plugin_cls)
+    assert isinstance(postfix, (int, str))
+    return abbr + str(postfix), plugin_cls(**kwargs, **layer_args)
diff --git a/ControlNet/annotator/uniformer/mmcv/cnn/bricks/registry.py b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..39eabc58db4b5954478a2ac1ab91cea5e45ab055
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/registry.py
@@ -0,0 +1,16 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from annotator.uniformer.mmcv.utils import Registry
+# One ``Registry`` instance per family of building blocks. The string passed
+# to each constructor is that registry's label.
+CONV_LAYERS = Registry('conv layer')
+NORM_LAYERS = Registry('norm layer')
+ACTIVATION_LAYERS = Registry('activation layer')
+PADDING_LAYERS = Registry('padding layer')
+UPSAMPLE_LAYERS = Registry('upsample layer')
+PLUGIN_LAYERS = Registry('plugin layer')
+DROPOUT_LAYERS = Registry('drop out layers')
+POSITIONAL_ENCODING = Registry('position encoding')
+ATTENTION = Registry('attention')
+FEEDFORWARD_NETWORK = Registry('feed-forward Network')
+TRANSFORMER_LAYER = Registry('transformerLayer')
+TRANSFORMER_LAYER_SEQUENCE = Registry('transformer-layers sequence')
diff --git a/ControlNet/annotator/uniformer/mmcv/cnn/bricks/scale.py b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..c905fffcc8bf998d18d94f927591963c428025e2
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/scale.py
@@ -0,0 +1,21 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+class Scale(nn.Module):
+    """Multiply the input by a learnable factor.
+    The factor is stored as a float ``nn.Parameter`` created from ``scale``
+    and is multiplied with the input in ``forward``.
+    Args:
+        scale (float): Initial value of scale factor. Default: 1.0
+    """
+    def __init__(self, scale=1.0):
+        super().__init__()
+        initial = torch.tensor(scale, dtype=torch.float)
+        self.scale = nn.Parameter(initial)
+    def forward(self, x):
+        """Return ``x`` multiplied by the learnable factor."""
+        scaled = x * self.scale
+        return scaled
diff --git a/ControlNet/annotator/uniformer/mmcv/cnn/bricks/swish.py b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/swish.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2ca8ed7b749413f011ae54aac0cab27e6f0b51f
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/swish.py
@@ -0,0 +1,25 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from .registry import ACTIVATION_LAYERS
+# NOTE(review): this file imports ACTIVATION_LAYERS, but no registration of
+# Swish into it is visible here - confirm the class is registered.
+class Swish(nn.Module):
+    """Swish activation module.
+    Applies the swish function element-wise:
+    .. math::
+        Swish(x) = x * Sigmoid(x)
+    Returns:
+        Tensor: The output tensor.
+    """
+    def __init__(self):
+        super().__init__()
+    def forward(self, x):
+        """Return ``x`` gated by its own sigmoid."""
+        gate = torch.sigmoid(x)
+        return x * gate
diff --git a/ControlNet/annotator/uniformer/mmcv/cnn/bricks/transformer.py b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e61ae0dd941a7be00b3e41a3de833ec50470a45f
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/transformer.py
@@ -0,0 +1,595 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import warnings
+import torch
+import torch.nn as nn
+from annotator.uniformer.mmcv import ConfigDict, deprecated_api_warning
+from annotator.uniformer.mmcv.cnn import Linear, build_activation_layer, build_norm_layer
+from annotator.uniformer.mmcv.runner.base_module import BaseModule, ModuleList, Sequential
+from annotator.uniformer.mmcv.utils import build_from_cfg
+from .drop import build_dropout
+from .registry import (ATTENTION, FEEDFORWARD_NETWORK, POSITIONAL_ENCODING,
+                       TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE)
+# Avoid BC-breaking of importing MultiScaleDeformableAttention from this file.
+# The import is optional: on success an ImportWarning points users at the new
+# location; on failure a warning is emitted and the module still loads.
+try:
+    from annotator.uniformer.mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention  # noqa F401
+    warnings.warn(
+        ImportWarning(
+            '``MultiScaleDeformableAttention`` has been moved to '
+            '``mmcv.ops.multi_scale_deform_attn``, please change original path '  # noqa E501
+            '``from annotator.uniformer.mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention`` '  # noqa E501
+            'to ``from annotator.uniformer.mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention`` '  # noqa E501
+        ))
+except ImportError:
+    warnings.warn('Fail to import ``MultiScaleDeformableAttention`` from '
+                  '``mmcv.ops.multi_scale_deform_attn``, '
+                  'You should install ``mmcv-full`` if you need this module. ')
+def build_positional_encoding(cfg, default_args=None):
+    """Build a positional encoding from ``cfg`` using ``POSITIONAL_ENCODING``."""
+    registry = POSITIONAL_ENCODING
+    return build_from_cfg(cfg, registry, default_args)
+def build_attention(cfg, default_args=None):
+    """Build an attention module from ``cfg`` using ``ATTENTION``."""
+    registry = ATTENTION
+    return build_from_cfg(cfg, registry, default_args)
+def build_feedforward_network(cfg, default_args=None):
+    """Build a feed-forward network (FFN) from ``cfg`` using ``FEEDFORWARD_NETWORK``."""
+    registry = FEEDFORWARD_NETWORK
+    return build_from_cfg(cfg, registry, default_args)
+def build_transformer_layer(cfg, default_args=None):
+    """Build a transformer layer from ``cfg`` using ``TRANSFORMER_LAYER``."""
+    registry = TRANSFORMER_LAYER
+    return build_from_cfg(cfg, registry, default_args)
+def build_transformer_layer_sequence(cfg, default_args=None):
+    """Build a transformer encoder or decoder from ``cfg`` using ``TRANSFORMER_LAYER_SEQUENCE``."""
+    registry = TRANSFORMER_LAYER_SEQUENCE
+    return build_from_cfg(cfg, registry, default_args)
+# NOTE(review): no registration of this class into the ATTENTION registry is
+# visible in this file - confirm it is registered before relying on
+# ``build_attention(dict(type='MultiheadAttention', ...))``.
+class MultiheadAttention(BaseModule):
+    """A wrapper for ``torch.nn.MultiheadAttention``.
+    This module implements MultiheadAttention with identity connection,
+    and positional encoding is also passed as input.
+    Args:
+        embed_dims (int): The embedding dimension.
+        num_heads (int): Parallel attention heads.
+        attn_drop (float): A Dropout layer on attn_output_weights.
+            Default: 0.0.
+        proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
+            Default: 0.0.
+        dropout_layer (obj:`ConfigDict`): The dropout_layer used
+            when adding the shortcut. A falsy value selects ``nn.Identity``.
+        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+            Default: None.
+        batch_first (bool): When it is True, Key, Query and Value are shape of
+            (batch, n, embed_dim), otherwise (n, batch, embed_dim).
+            Default to False.
+        **kwargs: Forwarded to ``nn.MultiheadAttention``. The deprecated
+            key ``dropout`` is consumed here and overrides ``attn_drop`` and
+            ``dropout_layer['drop_prob']``.
+    """
+    def __init__(self,
+                 embed_dims,
+                 num_heads,
+                 attn_drop=0.,
+                 proj_drop=0.,
+                 dropout_layer=dict(type='Dropout', drop_prob=0.),
+                 init_cfg=None,
+                 batch_first=False,
+                 **kwargs):
+        super(MultiheadAttention, self).__init__(init_cfg)
+        if 'dropout' in kwargs:
+            warnings.warn('The arguments `dropout` in MultiheadAttention '
+                          'has been deprecated, now you can separately '
+                          'set `attn_drop`(float), proj_drop(float), '
+                          'and `dropout_layer`(dict) ')
+            attn_drop = kwargs.pop('dropout')
+            # Write into a private copy: ``dropout_layer`` may be the shared
+            # default dict of this signature or a dict owned by the caller,
+            # and mutating it would leak ``drop_prob`` into later instances.
+            dropout_layer = copy.deepcopy(dropout_layer)
+            dropout_layer['drop_prob'] = attn_drop
+        self.embed_dims = embed_dims
+        self.num_heads = num_heads
+        self.batch_first = batch_first
+        self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop,
+                                          **kwargs)
+        self.proj_drop = nn.Dropout(proj_drop)
+        self.dropout_layer = build_dropout(
+            dropout_layer) if dropout_layer else nn.Identity()
+    @deprecated_api_warning({'residual': 'identity'},
+                            cls_name='MultiheadAttention')
+    def forward(self,
+                query,
+                key=None,
+                value=None,
+                identity=None,
+                query_pos=None,
+                key_pos=None,
+                attn_mask=None,
+                key_padding_mask=None,
+                **kwargs):
+        """Forward function for `MultiheadAttention`.
+        **kwargs allow passing a more general data flow when combining
+        with other operations in `transformerlayer`.
+        Args:
+            query (Tensor): The input query with shape [num_queries, bs,
+                embed_dims] if self.batch_first is False, else
+                [bs, num_queries embed_dims].
+            key (Tensor): The key tensor with shape [num_keys, bs,
+                embed_dims] if self.batch_first is False, else
+                [bs, num_keys, embed_dims] .
+                If None, the ``query`` will be used. Defaults to None.
+            value (Tensor): The value tensor with same shape as `key`.
+                Same in `nn.MultiheadAttention.forward`. Defaults to None.
+                If None, the `key` will be used.
+            identity (Tensor): This tensor, with the same shape as `query`,
+                will be used for the identity link.
+                If None, `query` will be used. Defaults to None.
+            query_pos (Tensor): The positional encoding for query, with
+                the same shape as `query`. If not None, it will
+                be added to `query` before forward function.
+                Defaults to None.
+            key_pos (Tensor): The positional encoding for `key`, with the
+                same shape as `key`. Defaults to None. If not None, it will
+                be added to `key` before forward function. If None, and
+                `query_pos` has the same shape as `key`, then `query_pos`
+                will be used for `key_pos`. Defaults to None.
+            attn_mask (Tensor): ByteTensor mask with shape [num_queries,
+                num_keys]. Same in `nn.MultiheadAttention.forward`.
+                Defaults to None.
+            key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
+                Defaults to None.
+        Returns:
+            Tensor: forwarded results with shape
+            [num_queries, bs, embed_dims]
+            if self.batch_first is False, else
+            [bs, num_queries embed_dims].
+        """
+        # Resolve defaults before any positional encoding is added, so
+        # ``value`` and ``identity`` are taken from the raw inputs.
+        if key is None:
+            key = query
+        if value is None:
+            value = key
+        if identity is None:
+            identity = query
+        if key_pos is None:
+            if query_pos is not None:
+                # use query_pos if key_pos is not available
+                if query_pos.shape == key.shape:
+                    key_pos = query_pos
+                else:
+                    warnings.warn(f'position encoding of key is '
+                                  f'missing in {self.__class__.__name__}.')
+        if query_pos is not None:
+            query = query + query_pos
+        if key_pos is not None:
+            key = key + key_pos
+        # Because the dataflow('key', 'query', 'value') of
+        # ``torch.nn.MultiheadAttention`` is (num_query, batch,
+        # embed_dims), We should adjust the shape of dataflow from
+        # batch_first (batch, num_query, embed_dims) to num_query_first
+        # (num_query ,batch, embed_dims), and recover ``attn_output``
+        # from num_query_first to batch_first.
+        if self.batch_first:
+            query = query.transpose(0, 1)
+            key = key.transpose(0, 1)
+            value = value.transpose(0, 1)
+        # ``self.attn`` returns a tuple; only its first element is used.
+        out = self.attn(
+            query=query,
+            key=key,
+            value=value,
+            attn_mask=attn_mask,
+            key_padding_mask=key_padding_mask)[0]
+        if self.batch_first:
+            out = out.transpose(0, 1)
+        return identity + self.dropout_layer(self.proj_drop(out))
+# NOTE(review): BaseTransformerLayer builds its FFNs through the
+# FEEDFORWARD_NETWORK registry with type 'FFN', but no registration of this
+# class is visible in this file - confirm it is registered.
+class FFN(BaseModule):
+    """Feed-forward network (FFN) with an optional identity connection.
+    Args:
+        embed_dims (int): The feature dimension. Same as
+            `MultiheadAttention`. Defaults: 256.
+        feedforward_channels (int): The hidden dimension of FFNs.
+            Defaults: 1024.
+        num_fcs (int, optional): The number of fully-connected layers in
+            FFNs. Must be at least 2. Default: 2.
+        act_cfg (dict, optional): The activation config for FFNs.
+            Default: dict(type='ReLU', inplace=True)
+        ffn_drop (float, optional): Probability of an element to be
+            zeroed in FFN. Default 0.0.
+        add_identity (bool, optional): Whether to add the
+            identity connection. Default: `True`.
+        dropout_layer (obj:`ConfigDict`): The dropout_layer used
+            when adding the shortcut. A falsy value selects ``nn.Identity``.
+        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+            Default: None.
+    """
+    @deprecated_api_warning(
+        {
+            'dropout': 'ffn_drop',
+            'add_residual': 'add_identity'
+        },
+        cls_name='FFN')
+    def __init__(self,
+                 embed_dims=256,
+                 feedforward_channels=1024,
+                 num_fcs=2,
+                 act_cfg=dict(type='ReLU', inplace=True),
+                 ffn_drop=0.,
+                 dropout_layer=None,
+                 add_identity=True,
+                 init_cfg=None,
+                 **kwargs):
+        super().__init__(init_cfg)
+        assert num_fcs >= 2, 'num_fcs should be no less ' \
+            f'than 2. got {num_fcs}.'
+        self.embed_dims = embed_dims
+        self.feedforward_channels = feedforward_channels
+        self.num_fcs = num_fcs
+        self.act_cfg = act_cfg
+        # A single activation module instance is reused by every hidden block.
+        self.activate = build_activation_layer(act_cfg)
+        # ``num_fcs - 1`` hidden blocks; only the first one takes
+        # ``embed_dims`` input channels.
+        hidden_blocks = [
+            Sequential(
+                Linear(embed_dims if block_idx == 0 else feedforward_channels,
+                       feedforward_channels),
+                self.activate,
+                nn.Dropout(ffn_drop))
+            for block_idx in range(num_fcs - 1)
+        ]
+        # Project back to ``embed_dims`` and apply a final dropout.
+        output_layers = [
+            Linear(feedforward_channels, embed_dims),
+            nn.Dropout(ffn_drop),
+        ]
+        self.layers = Sequential(*hidden_blocks, *output_layers)
+        if dropout_layer:
+            self.dropout_layer = build_dropout(dropout_layer)
+        else:
+            self.dropout_layer = torch.nn.Identity()
+        self.add_identity = add_identity
+    @deprecated_api_warning({'residual': 'identity'}, cls_name='FFN')
+    def forward(self, x, identity=None):
+        """Forward function for `FFN`.
+        When ``add_identity`` is set, ``identity`` (or ``x`` if ``identity``
+        is None) is added to the output tensor.
+        """
+        out = self.dropout_layer(self.layers(x))
+        if not self.add_identity:
+            return out
+        shortcut = x if identity is None else identity
+        return shortcut + out
+class BaseTransformerLayer(BaseModule):
+    """Base `TransformerLayer` for vision transformer.
+    It can be built from `mmcv.ConfigDict` and support more flexible
+    customization, for example, using any number of `FFN or LN ` and
+    use different kinds of `attention` by specifying a list of `ConfigDict`
+    named `attn_cfgs`. It is worth mentioning that it supports `prenorm`
+    when you specifying `norm` as the first element of `operation_order`.
+    More details about the `prenorm`: `On Layer Normalization in the
+    Transformer Architecture`.
+    Args:
+        attn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )):
+            Configs for `self_attention` or `cross_attention` modules,
+            The order of the configs in the list should be consistent with
+            corresponding attentions in operation_order.
+            If it is a dict, all of the attention modules in operation_order
+            will be built with this config. Default: None.
+        ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )):
+            Configs for FFN, The order of the configs in the list should be
+            consistent with corresponding ffn in operation_order.
+            If it is a dict, all of the ffn modules in operation_order
+            will be built with this config.
+        operation_order (tuple[str]): The execution order of operation
+            in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
+            Support `prenorm` when you specifying first element as `norm`.
+            Default:None.
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='LN').
+        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+            Default: None.
+        batch_first (bool): Key, Query and Value are shape
+            of (batch, n, embed_dim)
+            or (n, batch, embed_dim). Default to False.
+    """
+    def __init__(self,
+                 attn_cfgs=None,
+                 ffn_cfgs=dict(
+                     type='FFN',
+                     embed_dims=256,
+                     feedforward_channels=1024,
+                     num_fcs=2,
+                     ffn_drop=0.,
+                     act_cfg=dict(type='ReLU', inplace=True),
+                 ),
+                 operation_order=None,
+                 norm_cfg=dict(type='LN'),
+                 init_cfg=None,
+                 batch_first=False,
+                 **kwargs):
+        # Work on private copies. ``ffn_cfgs`` may be the shared default dict
+        # of this signature, and both config arguments are written to below
+        # (deprecated kwargs, ``batch_first``, ``embed_dims``); mutating the
+        # originals would leak settings into later instances or the caller.
+        attn_cfgs = copy.deepcopy(attn_cfgs)
+        ffn_cfgs = copy.deepcopy(ffn_cfgs)
+        deprecated_args = dict(
+            feedforward_channels='feedforward_channels',
+            ffn_dropout='ffn_drop',
+            ffn_num_fcs='num_fcs')
+        for ori_name, new_name in deprecated_args.items():
+            if ori_name in kwargs:
+                warnings.warn(
+                    f'The arguments `{ori_name}` in BaseTransformerLayer '
+                    f'has been deprecated, now you should set `{new_name}` '
+                    f'and other FFN related arguments '
+                    f'to a dict named `ffn_cfgs`. ')
+                ffn_cfgs[new_name] = kwargs[ori_name]
+        super(BaseTransformerLayer, self).__init__(init_cfg)
+        self.batch_first = batch_first
+        # Every entry of ``operation_order`` must be one of the four known
+        # operation names.
+        assert set(operation_order) & set(
+            ['self_attn', 'norm', 'ffn', 'cross_attn']) == \
+            set(operation_order), f'The operation_order of' \
+            f' {self.__class__.__name__} should ' \
+            f'contains all four operation type ' \
+            f"{['self_attn', 'norm', 'ffn', 'cross_attn']}"
+        num_attn = operation_order.count('self_attn') + operation_order.count(
+            'cross_attn')
+        if isinstance(attn_cfgs, dict):
+            attn_cfgs = [copy.deepcopy(attn_cfgs) for _ in range(num_attn)]
+        else:
+            assert num_attn == len(attn_cfgs), f'The length ' \
+                f'of attn_cfg {num_attn} is ' \
+                f'not consistent with the number of attention' \
+                f'in operation_order {operation_order}.'
+        self.num_attn = num_attn
+        self.operation_order = operation_order
+        self.norm_cfg = norm_cfg
+        self.pre_norm = operation_order[0] == 'norm'
+        self.attentions = ModuleList()
+        index = 0
+        for operation_name in operation_order:
+            if operation_name in ['self_attn', 'cross_attn']:
+                # An attention config must agree with this layer on
+                # ``batch_first``; if absent it inherits the layer's value.
+                if 'batch_first' in attn_cfgs[index]:
+                    assert self.batch_first == attn_cfgs[index]['batch_first']
+                else:
+                    attn_cfgs[index]['batch_first'] = self.batch_first
+                attention = build_attention(attn_cfgs[index])
+                # Some custom attentions used as `self_attn`
+                # or `cross_attn` can have different behavior.
+                attention.operation_name = operation_name
+                self.attentions.append(attention)
+                index += 1
+        self.embed_dims = self.attentions[0].embed_dims
+        self.ffns = ModuleList()
+        num_ffns = operation_order.count('ffn')
+        if isinstance(ffn_cfgs, dict):
+            ffn_cfgs = ConfigDict(ffn_cfgs)
+        if isinstance(ffn_cfgs, dict):
+            ffn_cfgs = [copy.deepcopy(ffn_cfgs) for _ in range(num_ffns)]
+        assert len(ffn_cfgs) == num_ffns
+        for ffn_index in range(num_ffns):
+            if 'embed_dims' not in ffn_cfgs[ffn_index]:
+                # ``ffn_cfgs`` is a list here, so the per-FFN config has to
+                # be addressed by index before setting the key.
+                ffn_cfgs[ffn_index]['embed_dims'] = self.embed_dims
+            else:
+                assert ffn_cfgs[ffn_index]['embed_dims'] == self.embed_dims
+            self.ffns.append(
+                build_feedforward_network(ffn_cfgs[ffn_index],
+                                          dict(type='FFN')))
+        self.norms = ModuleList()
+        num_norms = operation_order.count('norm')
+        for _ in range(num_norms):
+            self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1])
+    def forward(self,
+                query,
+                key=None,
+                value=None,
+                query_pos=None,
+                key_pos=None,
+                attn_masks=None,
+                query_key_padding_mask=None,
+                key_padding_mask=None,
+                **kwargs):
+        """Forward function for `TransformerDecoderLayer`.
+        **kwargs contains some specific arguments of attentions.
+        Args:
+            query (Tensor): The input query with shape
+                [num_queries, bs, embed_dims] if
+                self.batch_first is False, else
+                [bs, num_queries embed_dims].
+            key (Tensor): The key tensor with shape [num_keys, bs,
+                embed_dims] if self.batch_first is False, else
+                [bs, num_keys, embed_dims] .
+            value (Tensor): The value tensor with same shape as `key`.
+            query_pos (Tensor): The positional encoding for `query`.
+                Default: None.
+            key_pos (Tensor): The positional encoding for `key`.
+                Default: None.
+            attn_masks (List[Tensor] | None): 2D Tensor used in
+                calculation of corresponding attention. The length of
+                it should equal to the number of `attention` in
+                `operation_order`. A single Tensor is copied for every
+                attention. Default: None.
+            query_key_padding_mask (Tensor): ByteTensor for `query`, with
+                shape [bs, num_queries]. Only used in `self_attn` layer.
+                Defaults to None.
+            key_padding_mask (Tensor): ByteTensor for `query`, with
+                shape [bs, num_keys]. Default: None.
+        Returns:
+            Tensor: forwarded results with shape [num_queries, bs, embed_dims].
+        """
+        norm_index = 0
+        attn_index = 0
+        ffn_index = 0
+        identity = query
+        if attn_masks is None:
+            attn_masks = [None for _ in range(self.num_attn)]
+        elif isinstance(attn_masks, torch.Tensor):
+            attn_masks = [
+                copy.deepcopy(attn_masks) for _ in range(self.num_attn)
+            ]
+            warnings.warn(f'Use same attn_mask in all attentions in '
+                          f'{self.__class__.__name__} ')
+        else:
+            assert len(attn_masks) == self.num_attn, f'The length of ' \
+                f'attn_masks {len(attn_masks)} must be equal ' \
+                f'to the number of attention in ' \
+                f'operation_order {self.num_attn}'
+        # Run the operations in the configured order. In pre-norm mode the
+        # output of the last attention (initially the input query) is passed
+        # as the identity for the residual connection; otherwise each module
+        # falls back to its own input.
+        for layer in self.operation_order:
+            if layer == 'self_attn':
+                temp_key = temp_value = query
+                query = self.attentions[attn_index](
+                    query,
+                    temp_key,
+                    temp_value,
+                    identity if self.pre_norm else None,
+                    query_pos=query_pos,
+                    key_pos=query_pos,
+                    attn_mask=attn_masks[attn_index],
+                    key_padding_mask=query_key_padding_mask,
+                    **kwargs)
+                attn_index += 1
+                identity = query
+            elif layer == 'norm':
+                query = self.norms[norm_index](query)
+                norm_index += 1
+            elif layer == 'cross_attn':
+                query = self.attentions[attn_index](
+                    query,
+                    key,
+                    value,
+                    identity if self.pre_norm else None,
+                    query_pos=query_pos,
+                    key_pos=key_pos,
+                    attn_mask=attn_masks[attn_index],
+                    key_padding_mask=key_padding_mask,
+                    **kwargs)
+                attn_index += 1
+                identity = query
+            elif layer == 'ffn':
+                query = self.ffns[ffn_index](
+                    query, identity if self.pre_norm else None)
+                ffn_index += 1
+        return query
+class TransformerLayerSequence(BaseModule):
+    """Base class for TransformerEncoder and TransformerDecoder in vision
+    transformer.
+    Holds a stack of transformer layers and applies them one after another.
+    Different kinds of `transformer_layer` may be mixed by passing a list of
+    configs.
+    Args:
+        transformerlayers (list[obj:`mmcv.ConfigDict`] |
+            obj:`mmcv.ConfigDict`): Config of transformerlayer
+            in TransformerCoder. If it is obj:`mmcv.ConfigDict`,
+            it would be repeated `num_layer` times to a
+            list[`mmcv.ConfigDict`]. Default: None.
+        num_layers (int): The number of `TransformerLayer`. Default: None.
+        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+            Default: None.
+    """
+    def __init__(self, transformerlayers=None, num_layers=None, init_cfg=None):
+        super().__init__(init_cfg)
+        if isinstance(transformerlayers, dict):
+            # One independent copy of the shared config per layer.
+            layer_cfgs = [
+                copy.deepcopy(transformerlayers) for _ in range(num_layers)
+            ]
+        else:
+            assert isinstance(transformerlayers, list) and \
+                len(transformerlayers) == num_layers
+            layer_cfgs = transformerlayers
+        self.num_layers = num_layers
+        self.layers = ModuleList()
+        for layer_cfg in layer_cfgs:
+            self.layers.append(build_transformer_layer(layer_cfg))
+        # Sequence-level attributes mirror those of the first layer.
+        first_layer = self.layers[0]
+        self.embed_dims = first_layer.embed_dims
+        self.pre_norm = first_layer.pre_norm
+    def forward(self,
+                query,
+                key,
+                value,
+                query_pos=None,
+                key_pos=None,
+                attn_masks=None,
+                query_key_padding_mask=None,
+                key_padding_mask=None,
+                **kwargs):
+        """Forward function for `TransformerCoder`.
+        Args:
+            query (Tensor): Input query with shape
+                `(num_queries, bs, embed_dims)`.
+            key (Tensor): The key tensor with shape
+                `(num_keys, bs, embed_dims)`.
+            value (Tensor): The value tensor with shape
+                `(num_keys, bs, embed_dims)`.
+            query_pos (Tensor): The positional encoding for `query`.
+                Default: None.
+            key_pos (Tensor): The positional encoding for `key`.
+                Default: None.
+            attn_masks (List[Tensor], optional): Each element is 2D Tensor
+                which is used in calculation of corresponding attention in
+                operation_order. Default: None.
+            query_key_padding_mask (Tensor): ByteTensor for `query`, with
+                shape [bs, num_queries]. Only used in self-attention
+                Default: None.
+            key_padding_mask (Tensor): ByteTensor for `query`, with
+                shape [bs, num_keys]. Default: None.
+        Returns:
+            Tensor:  results with shape [num_queries, bs, embed_dims].
+        """
+        # Every layer receives the same key/value and auxiliary arguments;
+        # only ``query`` is threaded from one layer to the next.
+        shared_kwargs = dict(
+            query_pos=query_pos,
+            key_pos=key_pos,
+            attn_masks=attn_masks,
+            query_key_padding_mask=query_key_padding_mask,
+            key_padding_mask=key_padding_mask,
+            **kwargs)
+        for layer in self.layers:
+            query = layer(query, key, value, **shared_kwargs)
+        return query
diff --git a/ControlNet/annotator/uniformer/mmcv/cnn/bricks/upsample.py b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/upsample.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1a353767d0ce8518f0d7289bed10dba0178ed12
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/upsample.py
@@ -0,0 +1,84 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+import torch.nn.functional as F
+from ..utils import xavier_init
+from .registry import UPSAMPLE_LAYERS
+# Both names map to ``nn.Upsample``; ``build_upsample_layer`` passes the
+# registered name to it as the ``mode`` argument.
+UPSAMPLE_LAYERS.register_module('nearest', module=nn.Upsample)
+UPSAMPLE_LAYERS.register_module('bilinear', module=nn.Upsample)
+class PixelShufflePack(nn.Module):
+    """Upsample with a convolution followed by ``F.pixel_shuffle``.
+    The convolution expands the channels to
+    ``out_channels * scale_factor ** 2`` and the pixel shuffle then
+    rearranges them into a spatially larger map.
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        scale_factor (int): Upsample ratio.
+        upsample_kernel (int): Kernel size of the conv layer that expands
+            the channels.
+    """
+    def __init__(self, in_channels, out_channels, scale_factor,
+                 upsample_kernel):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.scale_factor = scale_factor
+        self.upsample_kernel = upsample_kernel
+        expanded_channels = out_channels * scale_factor * scale_factor
+        conv_padding = (upsample_kernel - 1) // 2
+        self.upsample_conv = nn.Conv2d(
+            in_channels,
+            expanded_channels,
+            upsample_kernel,
+            padding=conv_padding)
+        self.init_weights()
+    def init_weights(self):
+        """Initialise the expansion conv with uniform Xavier init."""
+        xavier_init(self.upsample_conv, distribution='uniform')
+    def forward(self, x):
+        """Apply the expansion conv, then shuffle by ``scale_factor``."""
+        expanded = self.upsample_conv(x)
+        return F.pixel_shuffle(expanded, self.scale_factor)
+def build_upsample_layer(cfg, *args, **kwargs):
+    """Instantiate an upsample layer described by ``cfg``.
+    Args:
+        cfg (dict): The upsample layer config. It must contain ``type``
+            (str), the registered layer name; every other entry is passed
+            to the layer constructor as a keyword argument (for example
+            ``scale_factor``, which is not applicable to deconv).
+        args (argument list): Positional arguments passed to the
+            ``__init__`` method of the selected layer.
+        kwargs (keyword arguments): Keyword arguments passed to the
+            ``__init__`` method of the selected layer.
+    Returns:
+        nn.Module: Created upsample layer.
+    Raises:
+        TypeError: If ``cfg`` is not a dict.
+        KeyError: If ``cfg`` has no ``type`` key or the type is not
+            registered in ``UPSAMPLE_LAYERS``.
+    """
+    if not isinstance(cfg, dict):
+        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
+    if 'type' not in cfg:
+        raise KeyError(
+            f'the cfg dict must contain the key "type", but got {cfg}')
+    # Work on a copy so the caller's config keeps its ``type`` entry.
+    layer_args = cfg.copy()
+    layer_type = layer_args.pop('type')
+    if layer_type not in UPSAMPLE_LAYERS:
+        raise KeyError(f'Unrecognized upsample type {layer_type}')
+    layer_cls = UPSAMPLE_LAYERS.get(layer_type)
+    if layer_cls is nn.Upsample:
+        # ``nn.Upsample`` is registered under its interpolation mode name.
+        layer_args['mode'] = layer_type
+    return layer_cls(*args, **kwargs, **layer_args)
diff --git a/ControlNet/annotator/uniformer/mmcv/cnn/bricks/wrappers.py b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..8aebf67bf52355a513f21756ee74fe510902d075
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/cnn/bricks/wrappers.py
@@ -0,0 +1,180 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+r"""Modified from https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/wrappers.py  # noqa: E501
+Wrap some nn modules to support empty tensor input. Currently, these wrappers
+are mainly used in mask heads like fcn_mask_head and maskiou_heads since mask
+heads are trained on only positive RoIs.
+"""
+import math
+import torch
+import torch.nn as nn
+from torch.nn.modules.utils import _pair, _triple
+from .registry import CONV_LAYERS, UPSAMPLE_LAYERS
+# TORCH_VERSION is either the marker string 'parrots' or a ``(major, minor)``
+# tuple of ints that can be compared against version thresholds.
+if torch.__version__ == 'parrots':
+    # parrots reports no numeric version, so keep the marker string as is.
+    TORCH_VERSION = torch.__version__
+else:
+    # torch.__version__ could be 1.3.1+cu92, we only need the first two
+    # for comparison
+    TORCH_VERSION = tuple(int(x) for x in torch.__version__.split('.')[:2])
+def obsolete_torch_version(torch_version, version_threshold):
+    """Tell whether ``torch_version`` is 'parrots' or not above a threshold.
+    Args:
+        torch_version (str | tuple[int]): ``'parrots'`` or a version tuple.
+        version_threshold (tuple[int]): Highest version still considered
+            obsolete.
+    Returns:
+        bool: True for ``'parrots'`` or when
+            ``torch_version <= version_threshold``.
+    """
+    if torch_version == 'parrots':
+        return True
+    return torch_version <= version_threshold
+class NewEmptyTensorOp(torch.autograd.Function):
+    """Autograd function returning an uninitialised tensor of a given shape.
+    The result is created with ``x.new_empty(new_shape)``; the backward pass
+    returns an uninitialised gradient shaped like the original input.
+    """
+    @staticmethod
+    def forward(ctx, x, new_shape):
+        # Remember the input shape so backward can emit a matching gradient.
+        ctx.shape = x.shape
+        return x.new_empty(new_shape)
+    @staticmethod
+    def backward(ctx, grad):
+        input_grad = NewEmptyTensorOp.apply(grad, ctx.shape)
+        # ``new_shape`` is not a tensor and therefore receives no gradient.
+        return input_grad, None
+@CONV_LAYERS.register_module('Conv', force=True)
+class Conv2d(nn.Conv2d):
+    """``nn.Conv2d`` that also accepts empty inputs on old torch versions."""
+    def forward(self, x):
+        """Convolve ``x``; for empty input on torch <= 1.4 return an empty
+        tensor of the shape the convolution would have produced."""
+        needs_workaround = (
+            x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)))
+        if not needs_workaround:
+            return super().forward(x)
+        spatial = [
+            (size + 2 * pad - (dil * (kernel - 1) + 1)) // stride + 1
+            for size, kernel, pad, stride, dil in zip(
+                x.shape[-2:], self.kernel_size, self.padding, self.stride,
+                self.dilation)
+        ]
+        out_shape = [x.shape[0], self.out_channels, *spatial]
+        empty = NewEmptyTensorOp.apply(x, out_shape)
+        if not self.training:
+            return empty
+        # produce dummy gradient to avoid DDP warning.
+        dummy = sum(param.view(-1)[0] for param in self.parameters()) * 0.0
+        return empty + dummy
+@CONV_LAYERS.register_module('Conv3d', force=True)
+class Conv3d(nn.Conv3d):
+    """``nn.Conv3d`` that also accepts empty inputs on old torch versions."""
+    def forward(self, x):
+        """Convolve ``x``; for empty input on torch <= 1.4 return an empty
+        tensor of the shape the convolution would have produced."""
+        needs_workaround = (
+            x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)))
+        if not needs_workaround:
+            return super().forward(x)
+        spatial = [
+            (size + 2 * pad - (dil * (kernel - 1) + 1)) // stride + 1
+            for size, kernel, pad, stride, dil in zip(
+                x.shape[-3:], self.kernel_size, self.padding, self.stride,
+                self.dilation)
+        ]
+        out_shape = [x.shape[0], self.out_channels, *spatial]
+        empty = NewEmptyTensorOp.apply(x, out_shape)
+        if not self.training:
+            return empty
+        # produce dummy gradient to avoid DDP warning.
+        dummy = sum(param.view(-1)[0] for param in self.parameters()) * 0.0
+        return empty + dummy
+@UPSAMPLE_LAYERS.register_module('deconv', force=True)
+class ConvTranspose2d(nn.ConvTranspose2d):
+    """``nn.ConvTranspose2d`` that accepts empty inputs on old torch."""
+    def forward(self, x):
+        """Apply the transposed conv; for empty input on torch <= 1.4 return
+        an empty tensor of the shape the layer would have produced."""
+        needs_workaround = (
+            x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)))
+        if not needs_workaround:
+            return super().forward(x)
+        spatial = [
+            (size - 1) * stride - 2 * pad + (dil * (kernel - 1) + 1) + extra
+            for size, kernel, pad, stride, dil, extra in zip(
+                x.shape[-2:], self.kernel_size, self.padding, self.stride,
+                self.dilation, self.output_padding)
+        ]
+        out_shape = [x.shape[0], self.out_channels, *spatial]
+        empty = NewEmptyTensorOp.apply(x, out_shape)
+        if not self.training:
+            return empty
+        # produce dummy gradient to avoid DDP warning.
+        dummy = sum(param.view(-1)[0] for param in self.parameters()) * 0.0
+        return empty + dummy
+@UPSAMPLE_LAYERS.register_module('deconv3d', force=True)
+class ConvTranspose3d(nn.ConvTranspose3d):
+    """``nn.ConvTranspose3d`` that accepts empty inputs on old torch."""
+    def forward(self, x):
+        """Apply the transposed conv; for empty input on torch <= 1.4 return
+        an empty tensor of the shape the layer would have produced."""
+        needs_workaround = (
+            x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)))
+        if not needs_workaround:
+            return super().forward(x)
+        spatial = [
+            (size - 1) * stride - 2 * pad + (dil * (kernel - 1) + 1) + extra
+            for size, kernel, pad, stride, dil, extra in zip(
+                x.shape[-3:], self.kernel_size, self.padding, self.stride,
+                self.dilation, self.output_padding)
+        ]
+        out_shape = [x.shape[0], self.out_channels, *spatial]
+        empty = NewEmptyTensorOp.apply(x, out_shape)
+        if not self.training:
+            return empty
+        # produce dummy gradient to avoid DDP warning.
+        dummy = sum(param.view(-1)[0] for param in self.parameters()) * 0.0
+        return empty + dummy
+class MaxPool2d(nn.MaxPool2d):
+    """``nn.MaxPool2d`` that also accepts empty inputs on old torch."""
+    def forward(self, x):
+        """Pool ``x``; for empty input on torch <= 1.9 return an empty tensor
+        of the shape the pooling would have produced."""
+        # PyTorch 1.9 does not support empty tensor inference yet
+        needs_workaround = (
+            x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)))
+        if not needs_workaround:
+            return super().forward(x)
+        out_shape = list(x.shape[:2])
+        for size, kernel, pad, stride, dil in zip(
+                x.shape[-2:], _pair(self.kernel_size), _pair(self.padding),
+                _pair(self.stride), _pair(self.dilation)):
+            raw = (size + 2 * pad - (dil * (kernel - 1) + 1)) / stride + 1
+            if self.ceil_mode:
+                out_shape.append(math.ceil(raw))
+            else:
+                out_shape.append(math.floor(raw))
+        return NewEmptyTensorOp.apply(x, out_shape)
+class MaxPool3d(nn.MaxPool3d):
+    """``nn.MaxPool3d`` that also accepts empty inputs on old torch."""
+    def forward(self, x):
+        """Pool ``x``; for empty input on torch <= 1.9 return an empty tensor
+        of the shape the pooling would have produced."""
+        # PyTorch 1.9 does not support empty tensor inference yet
+        needs_workaround = (
+            x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)))
+        if not needs_workaround:
+            return super().forward(x)
+        out_shape = list(x.shape[:2])
+        for size, kernel, pad, stride, dil in zip(
+                x.shape[-3:], _triple(self.kernel_size),
+                _triple(self.padding), _triple(self.stride),
+                _triple(self.dilation)):
+            raw = (size + 2 * pad - (dil * (kernel - 1) + 1)) / stride + 1
+            if self.ceil_mode:
+                out_shape.append(math.ceil(raw))
+            else:
+                out_shape.append(math.floor(raw))
+        return NewEmptyTensorOp.apply(x, out_shape)
+class Linear(torch.nn.Linear):
+    """``torch.nn.Linear`` that also accepts empty inputs on old torch."""
+    def forward(self, x):
+        """Apply the linear map; for empty input on torch <= 1.5 return an
+        empty ``[x.shape[0], out_features]`` tensor instead."""
+        # empty tensor forward of Linear layer is supported in Pytorch 1.6
+        needs_workaround = (
+            x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 5)))
+        if not needs_workaround:
+            return super().forward(x)
+        empty = NewEmptyTensorOp.apply(x, [x.shape[0], self.out_features])
+        if not self.training:
+            return empty
+        # produce dummy gradient to avoid DDP warning.
+        dummy = sum(param.view(-1)[0] for param in self.parameters()) * 0.0
+        return empty + dummy
diff --git a/ControlNet/annotator/uniformer/mmcv/cnn/builder.py b/ControlNet/annotator/uniformer/mmcv/cnn/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..7567316c566bd3aca6d8f65a84b00e9e890948a7
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/cnn/builder.py
@@ -0,0 +1,30 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..runner import Sequential
+from ..utils import Registry, build_from_cfg
+def build_model_from_cfg(cfg, registry, default_args=None):
+    """Build a PyTorch model from one config dict or a list of them.
+    Unlike ``build_from_cfg``, a list of configs is accepted: every entry is
+    built separately and the results are wrapped in a ``Sequential``.
+    Args:
+        cfg (dict, list[dict]): The config of modules; it is either a config
+            dict or a list of config dicts. If cfg is a list, the built
+            modules will be wrapped with ``nn.Sequential``.
+        registry (:obj:`Registry`): A registry the module belongs to.
+        default_args (dict, optional): Default arguments to build the module.
+            Defaults to None.
+    Returns:
+        nn.Module: A built nn module.
+    """
+    if not isinstance(cfg, list):
+        return build_from_cfg(cfg, registry, default_args)
+    built = [build_from_cfg(item, registry, default_args) for item in cfg]
+    return Sequential(*built)
+# Registry named 'model' whose entries are built via ``build_model_from_cfg``.
+MODELS = Registry('model', build_func=build_model_from_cfg)
diff --git a/ControlNet/annotator/uniformer/mmcv/cnn/resnet.py b/ControlNet/annotator/uniformer/mmcv/cnn/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cb3ac057ee2d52c46fc94685b5d4e698aad8d5f
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/cnn/resnet.py
@@ -0,0 +1,316 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from .utils import constant_init, kaiming_init
+def conv3x3(in_planes, out_planes, stride=1, dilation=1):
+    """Create a bias-free 3x3 ``nn.Conv2d`` whose padding equals ``dilation``.
+    Args:
+        in_planes (int): Number of input channels.
+        out_planes (int): Number of output channels.
+        stride (int): Convolution stride. Default: 1.
+        dilation (int): Dilation, also used as the padding. Default: 1.
+    Returns:
+        nn.Conv2d: The configured convolution layer.
+    """
+    conv_kwargs = dict(
+        kernel_size=3,
+        stride=stride,
+        padding=dilation,
+        dilation=dilation,
+        bias=False)
+    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
+class BasicBlock(nn.Module):
+    """Residual block made of two 3x3 convolutions.
+    Args:
+        inplanes (int): Number of input channels.
+        planes (int): Number of channels produced by both convolutions.
+        stride (int): Stride of the first convolution. Default: 1.
+        dilation (int): Dilation of the first convolution. Default: 1.
+        downsample (nn.Module | None): Module applied to the input before
+            it is added back as the residual. Default: None.
+        style (str): Must be ``'pytorch'`` or ``'caffe'``.
+        with_cp (bool): Must be False for this block.
+    """
+    expansion = 1
+    def __init__(self,
+                 inplanes,
+                 planes,
+                 stride=1,
+                 dilation=1,
+                 downsample=None,
+                 style='pytorch',
+                 with_cp=False):
+        super().__init__()
+        assert style in ['pytorch', 'caffe']
+        self.conv1 = conv3x3(inplanes, planes, stride, dilation)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.downsample = downsample
+        self.stride = stride
+        self.dilation = dilation
+        assert not with_cp
+    def forward(self, x):
+        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + residual)``."""
+        out = self.relu(self.bn1(self.conv1(x)))
+        out = self.bn2(self.conv2(out))
+        if self.downsample is None:
+            shortcut = x
+        else:
+            shortcut = self.downsample(x)
+        out += shortcut
+        return self.relu(out)
+class Bottleneck(nn.Module):
+    """Residual bottleneck block: 1x1 conv, 3x3 conv, 1x1 conv."""
+    expansion = 4
+    def __init__(self,
+                 inplanes,
+                 planes,
+                 stride=1,
+                 dilation=1,
+                 downsample=None,
+                 style='pytorch',
+                 with_cp=False):
+        """Bottleneck block.
+        If style is "pytorch", the stride-two layer is the 3x3 conv layer, if
+        it is "caffe", the stride-two layer is the first 1x1 conv layer.
+        """
+        super().__init__()
+        assert style in ['pytorch', 'caffe']
+        # Decide which of the first two convs carries the stride.
+        if style == 'pytorch':
+            conv1_stride, conv2_stride = 1, stride
+        else:
+            conv1_stride, conv2_stride = stride, 1
+        out_planes = planes * self.expansion
+        self.conv1 = nn.Conv2d(
+            inplanes, planes, kernel_size=1, stride=conv1_stride, bias=False)
+        self.conv2 = nn.Conv2d(
+            planes,
+            planes,
+            kernel_size=3,
+            stride=conv2_stride,
+            padding=dilation,
+            dilation=dilation,
+            bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(out_planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+        self.dilation = dilation
+        self.with_cp = with_cp
+    def forward(self, x):
+        """Run the block, through ``cp.checkpoint`` when ``with_cp`` is set
+        and ``x`` requires grad."""
+        def _inner_forward(x):
+            out = self.relu(self.bn1(self.conv1(x)))
+            out = self.relu(self.bn2(self.conv2(out)))
+            out = self.bn3(self.conv3(out))
+            if self.downsample is None:
+                shortcut = x
+            else:
+                shortcut = self.downsample(x)
+            out += shortcut
+            return out
+        if self.with_cp and x.requires_grad:
+            out = cp.checkpoint(_inner_forward, x)
+        else:
+            out = _inner_forward(x)
+        return self.relu(out)
+def make_res_layer(block,
+                   inplanes,
+                   planes,
+                   blocks,
+                   stride=1,
+                   dilation=1,
+                   style='pytorch',
+                   with_cp=False):
+    """Stack ``blocks`` residual blocks into one ``nn.Sequential`` stage.
+    Only the first block receives ``stride`` and, when the stride is not 1
+    or the channel count changes, a 1x1 conv + BN ``downsample`` branch.
+    Args:
+        block (type): Block class; must expose an ``expansion`` attribute.
+        inplanes (int): Number of channels entering the stage.
+        planes (int): Base channel count of each block.
+        blocks (int): Number of blocks in the stage.
+        stride (int): Stride of the first block. Default: 1.
+        dilation (int): Dilation passed to every block. Default: 1.
+        style (str): Forwarded to every block. Default: 'pytorch'.
+        with_cp (bool): Forwarded to every block. Default: False.
+    Returns:
+        nn.Sequential: The assembled stage.
+    """
+    out_planes = planes * block.expansion
+    downsample = None
+    if stride != 1 or inplanes != out_planes:
+        downsample = nn.Sequential(
+            nn.Conv2d(
+                inplanes,
+                out_planes,
+                kernel_size=1,
+                stride=stride,
+                bias=False),
+            nn.BatchNorm2d(out_planes),
+        )
+    first_block = block(
+        inplanes,
+        planes,
+        stride,
+        dilation,
+        downsample,
+        style=style,
+        with_cp=with_cp)
+    stage = [first_block]
+    # Remaining blocks keep the resolution and take the expanded channels.
+    for _ in range(1, blocks):
+        stage.append(
+            block(out_planes, planes, 1, dilation, style=style,
+                  with_cp=with_cp))
+    return nn.Sequential(*stage)
+class ResNet(nn.Module):
+ """ResNet backbone.
+ Args:
+ depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
+ num_stages (int): Resnet stages, normally 4.
+ strides (Sequence[int]): Strides of the first block of each stage.
+ dilations (Sequence[int]): Dilation of each stage.
+ out_indices (Sequence[int]): Output from which stages.
+ style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+ layer is the 3x3 conv layer, otherwise the stride-two layer is
+ the first 1x1 conv layer.
+ frozen_stages (int): Stages to be frozen (all param fixed). -1 means
+ not freezing any parameters.
+ bn_eval (bool): Whether to set BN layers as eval mode, namely, freeze
+ running stats (mean and var).
+ bn_frozen (bool): Whether to freeze weight and bias of BN layers.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ """
+ # Maps depth -> (block class, number of blocks in each stage).
+ arch_settings = {
+ 18: (BasicBlock, (2, 2, 2, 2)),
+ 34: (BasicBlock, (3, 4, 6, 3)),
+ 50: (Bottleneck, (3, 4, 6, 3)),
+ 101: (Bottleneck, (3, 4, 23, 3)),
+ 152: (Bottleneck, (3, 8, 36, 3))
+ }
+ def __init__(self,
+ depth,
+ num_stages=4,
+ strides=(1, 2, 2, 2),
+ dilations=(1, 1, 1, 1),
+ out_indices=(0, 1, 2, 3),
+ style='pytorch',
+ frozen_stages=-1,
+ bn_eval=True,
+ bn_frozen=False,
+ with_cp=False):
+ super(ResNet, self).__init__()
+ if depth not in self.arch_settings:
+ raise KeyError(f'invalid depth {depth} for resnet')
+ assert num_stages >= 1 and num_stages <= 4
+ block, stage_blocks = self.arch_settings[depth]
+ # Keep only as many stages as requested.
+ stage_blocks = stage_blocks[:num_stages]
+ assert len(strides) == len(dilations) == num_stages
+ assert max(out_indices) < num_stages
+ self.out_indices = out_indices
+ self.style = style
+ self.frozen_stages = frozen_stages
+ self.bn_eval = bn_eval
+ self.bn_frozen = bn_frozen
+ self.with_cp = with_cp
+ # Stem: 7x7 stride-2 conv, BN, ReLU and a 3x3 stride-2 max pool.
+ self.inplanes = 64
+ self.conv1 = nn.Conv2d(
+ 3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+ self.bn1 = nn.BatchNorm2d(64)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ # Names of the stage submodules ('layer1', 'layer2', ...), in order.
+ self.res_layers = []
+ for i, num_blocks in enumerate(stage_blocks):
+ stride = strides[i]
+ dilation = dilations[i]
+ # Stage i uses 64 * 2**i base channels.
+ planes = 64 * 2**i
+ res_layer = make_res_layer(
+ block,
+ self.inplanes,
+ planes,
+ num_blocks,
+ stride=stride,
+ dilation=dilation,
+ style=self.style,
+ with_cp=with_cp)
+ # The next stage consumes this stage's expanded channel count.
+ self.inplanes = planes * block.expansion
+ layer_name = f'layer{i + 1}'
+ self.add_module(layer_name, res_layer)
+ self.res_layers.append(layer_name)
+ # Channel count produced by the last built stage.
+ self.feat_dim = block.expansion * 64 * 2**(len(stage_blocks) - 1)
+ def init_weights(self, pretrained=None):
+ """Load weights from `pretrained` (a str) or initialise from scratch.
+ With `pretrained=None`, conv layers get `kaiming_init` and BatchNorm
+ layers get `constant_init(m, 1)`. Any other type raises TypeError.
+ """
+ if isinstance(pretrained, str):
+ logger = logging.getLogger()
+ # Imported here rather than at module level.
+ from ..runner import load_checkpoint
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, nn.BatchNorm2d):
+ constant_init(m, 1)
+ else:
+ raise TypeError('pretrained must be a str or None')
+ def forward(self, x):
+ """Run the stem and all stages, collecting the `out_indices` outputs.
+ Returns a single tensor when exactly one output was collected,
+ otherwise a tuple of tensors.
+ """
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+ outs = []
+ for i, layer_name in enumerate(self.res_layers):
+ res_layer = getattr(self, layer_name)
+ x = res_layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+ if len(outs) == 1:
+ return outs[0]
+ else:
+ return tuple(outs)
+ def train(self, mode=True):
+ """Switch train/eval mode, then re-apply the BN and stage freezing.
+ NOTE(review): indentation is flattened in this patch, so the exact
+ nesting of the `bn_frozen` check below cannot be read from here;
+ confirm against the upstream mmcv file.
+ """
+ super(ResNet, self).train(mode)
+ if self.bn_eval:
+ for m in self.modules():
+ if isinstance(m, nn.BatchNorm2d):
+ m.eval()
+ if self.bn_frozen:
+ for params in m.parameters():
+ params.requires_grad = False
+ # When training with frozen stages, freeze the stem and layer1..layerN.
+ if mode and self.frozen_stages >= 0:
+ for param in self.conv1.parameters():
+ param.requires_grad = False
+ for param in self.bn1.parameters():
+ param.requires_grad = False
+ self.bn1.eval()
+ self.bn1.weight.requires_grad = False
+ self.bn1.bias.requires_grad = False
+ for i in range(1, self.frozen_stages + 1):
+ mod = getattr(self, f'layer{i}')
+ mod.eval()
+ for param in mod.parameters():
+ param.requires_grad = False
diff --git a/ControlNet/annotator/uniformer/mmcv/cnn/utils/__init__.py b/ControlNet/annotator/uniformer/mmcv/cnn/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a263e31c1e3977712827ca229bbc04910b4e928e
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/cnn/utils/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .flops_counter import get_model_complexity_info
+from .fuse_conv_bn import fuse_conv_bn
+from .sync_bn import revert_sync_batchnorm
+from .weight_init import (INITIALIZERS, Caffe2XavierInit, ConstantInit,
+ KaimingInit, NormalInit, PretrainedInit,
+ TruncNormalInit, UniformInit, XavierInit,
+ bias_init_with_prob, caffe2_xavier_init,
+ constant_init, initialize, kaiming_init, normal_init,
+ trunc_normal_init, uniform_init, xavier_init)
+# Public names re-exported from the submodules imported above.
+__all__ = [
+    'get_model_complexity_info', 'bias_init_with_prob', 'caffe2_xavier_init',
+    'constant_init', 'kaiming_init', 'normal_init', 'trunc_normal_init',
+    'uniform_init', 'xavier_init', 'fuse_conv_bn', 'initialize',
+    'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
+    'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
+    'Caffe2XavierInit', 'revert_sync_batchnorm'
+]
diff --git a/ControlNet/annotator/uniformer/mmcv/cnn/utils/flops_counter.py b/ControlNet/annotator/uniformer/mmcv/cnn/utils/flops_counter.py
new file mode 100644
index 0000000000000000000000000000000000000000..d10af5feca7f4b8c0ba359b7b1c826f754e048be
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/cnn/utils/flops_counter.py
@@ -0,0 +1,599 @@
+# Modified from flops-counter.pytorch by Vladislav Sovrasov
+# original repo: https://github.com/sovrasov/flops-counter.pytorch
+# MIT License
+# Copyright (c) 2018 Vladislav Sovrasov
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+import sys
+from functools import partial
+import numpy as np
+import torch
+import torch.nn as nn
+import annotator.uniformer.mmcv as mmcv
+def get_model_complexity_info(model,
+                              input_shape,
+                              print_per_layer_stat=True,
+                              as_strings=True,
+                              input_constructor=None,
+                              flush=False,
+                              ost=sys.stdout):
+    """Measure FLOPs and parameter count of ``model`` for one input shape.
+    The model is switched to eval mode, instrumented with FLOPs counting
+    hooks, run once on a single-sample batch, and un-instrumented again.
+    Optionally the per-layer statistics are printed.
+    Supported layers are listed as below:
+        - Convolutions: ``nn.Conv1d``, ``nn.Conv2d``, ``nn.Conv3d``.
+        - Activations: ``nn.ReLU``, ``nn.PReLU``, ``nn.ELU``,
+          ``nn.LeakyReLU``, ``nn.ReLU6``.
+        - Poolings: ``nn.MaxPool1d``, ``nn.MaxPool2d``, ``nn.MaxPool3d``,
+          ``nn.AvgPool1d``, ``nn.AvgPool2d``, ``nn.AvgPool3d``,
+          ``nn.AdaptiveMaxPool1d``, ``nn.AdaptiveMaxPool2d``,
+          ``nn.AdaptiveMaxPool3d``, ``nn.AdaptiveAvgPool1d``,
+          ``nn.AdaptiveAvgPool2d``, ``nn.AdaptiveAvgPool3d``.
+        - BatchNorms: ``nn.BatchNorm1d``, ``nn.BatchNorm2d``,
+          ``nn.BatchNorm3d``, ``nn.GroupNorm``, ``nn.InstanceNorm1d``,
+          ``InstanceNorm2d``, ``InstanceNorm3d``, ``nn.LayerNorm``.
+        - Linear: ``nn.Linear``.
+        - Deconvolution: ``nn.ConvTranspose2d``.
+        - Upsample: ``nn.Upsample``.
+    Args:
+        model (nn.Module): The model for complexity calculation.
+        input_shape (tuple): Input shape used for calculation, without the
+            batch dimension.
+        print_per_layer_stat (bool): Whether to print complexity information
+            for each layer in a model. Default: True.
+        as_strings (bool): Output FLOPs and params counts in a string form.
+            Default: True.
+        input_constructor (None | callable): If specified, it is called with
+            ``input_shape`` and must return a dict of keyword arguments for
+            the model. Otherwise an uninitialised tensor of shape
+            ``(1, *input_shape)`` is used. Default: None.
+        flush (bool): same as that in :func:`print`. Default: False.
+        ost (stream): same as ``file`` param in :func:`print`.
+            Default: sys.stdout.
+    Returns:
+        tuple[float | str]: If ``as_strings`` is set to True, it will return
+            FLOPs and parameter counts in a string format. otherwise, it will
+            return those in a float number format.
+    """
+    assert type(input_shape) is tuple
+    assert len(input_shape) >= 1
+    assert isinstance(model, nn.Module)
+    flops_model = add_flops_counting_methods(model)
+    flops_model.eval()
+    flops_model.start_flops_count()
+    if input_constructor:
+        model_kwargs = input_constructor(input_shape)
+        flops_model(**model_kwargs)
+    else:
+        batch_shape = (1, *input_shape)
+        try:
+            # Match dtype and device of the model's first parameter.
+            batch = torch.ones(()).new_empty(
+                batch_shape,
+                dtype=next(flops_model.parameters()).dtype,
+                device=next(flops_model.parameters()).device)
+        except StopIteration:
+            # Avoid StopIteration for models which have no parameters,
+            # like `nn.Relu()`, `nn.AvgPool2d`, etc.
+            batch = torch.ones(()).new_empty(batch_shape)
+        flops_model(batch)
+    flops_count, params_count = flops_model.compute_average_flops_cost()
+    if print_per_layer_stat:
+        print_model_with_flops(
+            flops_model, flops_count, params_count, ost=ost, flush=flush)
+    flops_model.stop_flops_count()
+    if not as_strings:
+        return flops_count, params_count
+    return flops_to_string(flops_count), params_to_string(params_count)
+def flops_to_string(flops, units='GFLOPs', precision=2):
+    """Format a FLOPs number as a string with a unit suffix.
+    Note that here one multiply-add counts as one FLOP.
+    Args:
+        flops (float): FLOPs number to be converted.
+        units (str | None): Converted FLOPs units. Options are None, 'GFLOPs',
+            'MFLOPs', 'KFLOPs', 'FLOPs'. If set to None, the largest unit
+            that yields a value of at least one is chosen. Any other value
+            leaves the number unscaled. Default: 'GFLOPs'.
+        precision (int): Digit number after the decimal point. Default: 2.
+    Returns:
+        str: The converted FLOPs number with units.
+    Examples:
+        >>> flops_to_string(1e9)
+        '1.0 GFLOPs'
+        >>> flops_to_string(2e5, 'MFLOPs')
+        '0.2 MFLOPs'
+        >>> flops_to_string(3e-9, None)
+        '3e-09 FLOPs'
+    """
+    # (unit name, power of ten), largest first.
+    scales = (('GFLOPs', 9), ('MFLOPs', 6), ('KFLOPs', 3))
+    if units is None:
+        for name, exponent in scales:
+            if flops // 10**exponent > 0:
+                return str(round(flops / 10.**exponent,
+                                 precision)) + ' ' + name
+    else:
+        for name, exponent in scales:
+            if units == name:
+                return str(round(flops / 10.**exponent,
+                                 precision)) + ' ' + units
+    return str(flops) + ' FLOPs'
+def params_to_string(num_params, units=None, precision=2):
+    """Format a parameter count as a string with a unit suffix.
+    Args:
+        num_params (float): Parameter number to be converted.
+        units (str | None): Converted units. Options are None, 'M',
+            'K' and ''. If set to None, a suitable unit is chosen
+            automatically (millions as ' M', thousands as ' k'). Any other
+            value leaves the number unscaled. Default: None.
+        precision (int): Digit number after the decimal point. Default: 2.
+    Returns:
+        str: The converted parameter number with units.
+    Examples:
+        >>> params_to_string(1e9)
+        '1000.0 M'
+        >>> params_to_string(2e5)
+        '200.0 k'
+        >>> params_to_string(3e-9)
+        '3e-09'
+    """
+    if units is None:
+        if num_params // 10**6 > 0:
+            scaled, suffix = num_params / 10**6, ' M'
+        elif num_params // 10**3:
+            scaled, suffix = num_params / 10**3, ' k'
+        else:
+            return str(num_params)
+        return str(round(scaled, precision)) + suffix
+    # Explicit unit requested: (unit name, divisor).
+    for name, divisor in (('M', 10.**6), ('K', 10.**3)):
+        if units == name:
+            return str(round(num_params / divisor, precision)) + ' ' + units
+    return str(num_params)
+def print_model_with_flops(model,
+                           total_flops,
+                           total_params,
+                           units='GFLOPs',
+                           precision=3,
+                           ost=sys.stdout,
+                           flush=False):
+    """Print a model with FLOPs for each layer.
+    Every module temporarily gets an ``extra_repr`` that prepends its
+    parameter count, FLOPs and their share of the totals; the original
+    ``extra_repr`` and all helper attributes are removed again afterwards.
+    Args:
+        model (nn.Module): The model to be printed.
+        total_flops (float): Total FLOPs of the model.
+        total_params (float): Total parameter counts of the model.
+        units (str | None): Converted FLOPs units. Default: 'GFLOPs'.
+        precision (int): Digit number after the decimal point. Default: 3.
+        ost (stream): same as `file` param in :func:`print`.
+            Default: sys.stdout.
+        flush (bool): same as that in :func:`print`. Default: False.
+    Example:
+        >>> class ExampleModel(nn.Module):
+        >>>     def __init__(self):
+        >>>         super().__init__()
+        >>>         self.conv1 = nn.Conv2d(3, 8, 3)
+        >>>         self.conv2 = nn.Conv2d(8, 256, 3)
+        >>>         self.conv3 = nn.Conv2d(256, 8, 3)
+        >>>         self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
+        >>>         self.flatten = nn.Flatten()
+        >>>         self.fc = nn.Linear(8, 1)
+        >>>     def forward(self, x):
+        >>>         x = self.conv1(x)
+        >>>         x = self.conv2(x)
+        >>>         x = self.conv3(x)
+        >>>         x = self.avg_pool(x)
+        >>>         x = self.flatten(x)
+        >>>         x = self.fc(x)
+        >>>         return x
+        >>> model = ExampleModel()
+        >>> x = (3, 16, 16)
+        to print the complexity information state for each layer, you can use
+        >>> get_model_complexity_info(model, x)
+        or directly use
+        >>> print_model_with_flops(model, 4579784.0, 37361)
+        ExampleModel(
+          0.037 M, 100.000% Params, 0.005 GFLOPs, 100.000% FLOPs,
+          (conv1): Conv2d(0.0 M, 0.600% Params, 0.0 GFLOPs, 0.959% FLOPs, 3, 8, kernel_size=(3, 3), stride=(1, 1))  # noqa: E501
+          (conv2): Conv2d(0.019 M, 50.020% Params, 0.003 GFLOPs, 58.760% FLOPs, 8, 256, kernel_size=(3, 3), stride=(1, 1))
+          (conv3): Conv2d(0.018 M, 49.356% Params, 0.002 GFLOPs, 40.264% FLOPs, 256, 8, kernel_size=(3, 3), stride=(1, 1))
+          (avg_pool): AdaptiveAvgPool2d(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.017% FLOPs, output_size=(1, 1))
+          (flatten): Flatten(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.000% FLOPs, )
+          (fc): Linear(0.0 M, 0.024% Params, 0.0 GFLOPs, 0.000% FLOPs, in_features=8, out_features=1, bias=True)
+        )
+    """
+    def accumulate_params(self):
+        # Supported leaves report their own count; containers add children.
+        if is_supported_instance(self):
+            return self.__params__
+        total = 0
+        for child in self.children():
+            total += child.accumulate_params()
+        return total
+    def accumulate_flops(self):
+        # FLOPs are averaged over the batches seen by the root ``model``.
+        if is_supported_instance(self):
+            return self.__flops__ / model.__batch_counter__
+        total = 0
+        for child in self.children():
+            total += child.accumulate_flops()
+        return total
+    def flops_repr(self):
+        accumulated_num_params = self.accumulate_params()
+        accumulated_flops_cost = self.accumulate_flops()
+        return ', '.join([
+            params_to_string(
+                accumulated_num_params, units='M', precision=precision),
+            '{:.3%} Params'.format(accumulated_num_params / total_params),
+            flops_to_string(
+                accumulated_flops_cost, units=units, precision=precision),
+            '{:.3%} FLOPs'.format(accumulated_flops_cost / total_flops),
+            self.original_extra_repr()
+        ])
+    def add_extra_repr(m):
+        # Bind the helpers to this module instance.
+        m.accumulate_flops = accumulate_flops.__get__(m)
+        m.accumulate_params = accumulate_params.__get__(m)
+        flops_extra_repr = flops_repr.__get__(m)
+        if m.extra_repr != flops_extra_repr:
+            m.original_extra_repr = m.extra_repr
+            m.extra_repr = flops_extra_repr
+            assert m.extra_repr != m.original_extra_repr
+    def del_extra_repr(m):
+        if hasattr(m, 'original_extra_repr'):
+            m.extra_repr = m.original_extra_repr
+            del m.original_extra_repr
+        if hasattr(m, 'accumulate_flops'):
+            del m.accumulate_flops
+        # Also drop the second helper so no bound method stays on the module.
+        if hasattr(m, 'accumulate_params'):
+            del m.accumulate_params
+    model.apply(add_extra_repr)
+    print(model, file=ost, flush=flush)
+    model.apply(del_extra_repr)
+def get_model_parameters_number(model):
+    """Count the elements of all parameters that require gradients.
+    Args:
+        model (nn.module): The model for parameter number calculation.
+    Returns:
+        int: Number of trainable parameter elements of the model.
+    """
+    num_params = 0
+    for param in model.parameters():
+        if param.requires_grad:
+            num_params += param.numel()
+    return num_params
+def add_flops_counting_methods(net_main_module):
+    """Attach the FLOPs counting methods to a module and reset its counters.
+    The functions are bound to the given object (not to its class) so each
+    one receives that object as ``self``.
+    Args:
+        net_main_module (nn.Module): Module to instrument.
+    Returns:
+        nn.Module: The same module, now carrying ``start_flops_count``,
+            ``stop_flops_count``, ``reset_flops_count`` and
+            ``compute_average_flops_cost``.
+    """
+    counting_methods = (
+        ('start_flops_count', start_flops_count),
+        ('stop_flops_count', stop_flops_count),
+        ('reset_flops_count', reset_flops_count),
+        ('compute_average_flops_cost', compute_average_flops_cost),
+    )
+    for method_name, func in counting_methods:
+        setattr(net_main_module, method_name, func.__get__(net_main_module))
+    net_main_module.reset_flops_count()
+    return net_main_module
+def compute_average_flops_cost(self):
+    """Compute average FLOPs cost.
+    Available as a method after ``add_flops_counting_methods()`` has been
+    called on the net object.
+    Returns:
+        tuple: The summed ``__flops__`` of all supported modules divided by
+            the batch counter, and the trainable parameter count.
+    """
+    batches_count = self.__batch_counter__
+    flops_sum = sum(
+        module.__flops__ for module in self.modules()
+        if is_supported_instance(module))
+    params_sum = get_model_parameters_number(self)
+    return flops_sum / batches_count, params_sum
+def start_flops_count(self):
+    """Activate the computation of mean flops consumption per image.
+    Available as a method after ``add_flops_counting_methods()`` has been
+    called on the net object. It should be called before running the network.
+    """
+    add_batch_counter_hook_function(self)
+    def add_flops_counter_hook_function(module):
+        # Only supported modules get a hook, and each at most once.
+        if not is_supported_instance(module):
+            return
+        if hasattr(module, '__flops_handle__'):
+            return
+        handle = module.register_forward_hook(
+            get_modules_mapping()[type(module)])
+        module.__flops_handle__ = handle
+    self.apply(add_flops_counter_hook_function)
+def stop_flops_count(self):
+ """Stop computing the mean flops consumption per image.
+ A method to stop computing the mean flops consumption per image, which will
+ be available after ``add_flops_counting_methods()`` is called on a desired
+ net object. It can be called to pause the computation at any time.
+ """
+ # Remove the batch counter hook first, then the per-module FLOPs hooks.
+ remove_batch_counter_hook_function(self)
+ self.apply(remove_flops_counter_hook_function)
+def reset_flops_count(self):
+ """Reset statistics computed so far.
+ A method to reset computed statistics, which will be available after
+ `add_flops_counting_methods()` is called on a desired net object.
+ """
+ # Reset the batch counter on this module, then the counters on every
+ # submodule via ``apply``.
+ add_batch_counter_variables_or_reset(self)
+ self.apply(add_flops_counter_variable_or_reset)
+# ---- Internal functions
+def empty_flops_counter_hook(module, input, output):
+ """Forward hook that adds zero FLOPs to ``module.__flops__``."""
+ module.__flops__ += 0
+def upsample_flops_counter_hook(module, input, output):
+    """Forward hook adding the element count of ``output[0]`` to
+    ``module.__flops__``."""
+    first_output = output[0]
+    dims = first_output.shape
+    # Multiply all dimensions of the first output entry together.
+    output_elements_count = dims[0]
+    for dim in dims[1:]:
+        output_elements_count *= dim
+    module.__flops__ += int(output_elements_count)
+def relu_flops_counter_hook(module, input, output):
+    """Forward hook that counts one flop per output element.
+    Args:
+        module: Object carrying a ``__flops__`` counter.
+        input: Unused.
+        output: Result of the layer; must provide ``numel()``.
+    """
+    module.__flops__ += int(output.numel())
+def linear_flops_counter_hook(module, input, output):
+    """Forward hook that counts the flops of a fully connected layer.
+    The count is the number of input elements times the size of the last
+    output dimension.
+    Args:
+        module: Object carrying a ``__flops__`` counter.
+        input: Positional inputs of the layer; only the first one is used.
+        output: Result of the layer; only ``output.shape[-1]`` is read.
+    """
+    first_input = input[0]
+    # pytorch checks dimensions, so here we don't care much
+    out_features = output.shape[-1]
+    input_elements = np.prod(first_input.shape)
+    module.__flops__ += int(input_elements * out_features)
+def pool_flops_counter_hook(module, input, output):
+    """Forward hook that counts one flop per element of the first input.
+    Args:
+        module: Object carrying a ``__flops__`` counter.
+        input: Positional inputs of the layer; only the first one is used.
+        output: Unused.
+    """
+    input_shape = input[0].shape
+    module.__flops__ += int(np.prod(input_shape))
+def norm_flops_counter_hook(module, input, output):
+    """Forward hook that counts the flops of a normalization layer.
+    One flop is counted per element of the first input, and twice that when
+    the module has a truthy ``affine`` or ``elementwise_affine`` attribute.
+    Args:
+        module: Object carrying a ``__flops__`` counter.
+        input: Positional inputs of the layer; only the first one is used.
+        output: Unused.
+    """
+    flops = np.prod(input[0].shape)
+    has_affine = getattr(module, 'affine', False)
+    if not has_affine:
+        has_affine = getattr(module, 'elementwise_affine', False)
+    if has_affine:
+        flops *= 2
+    module.__flops__ += int(flops)
+def deconv_flops_counter_hook(conv_module, input, output):
+    """Forward hook that accumulates the flops of a 2D transposed convolution.
+    The convolution cost is counted per *input* position (every input element
+    is multiplied with the whole kernel); the bias cost is one addition per
+    *output* element.
+    Args:
+        conv_module: Layer providing ``kernel_size``, ``in_channels``,
+            ``out_channels``, ``groups``, ``bias`` and a ``__flops__``
+            counter.
+        input: Positional inputs of the layer; only the first one is used and
+            its shape is unpacked as (batch, channels, height, width).
+        output: Result of the layer; its shape is unpacked the same way.
+    """
+    # Can have multiple inputs, getting the first one
+    input = input[0]
+    batch_size = input.shape[0]
+    input_height, input_width = input.shape[2:]
+    kernel_height, kernel_width = conv_module.kernel_size
+    in_channels = conv_module.in_channels
+    out_channels = conv_module.out_channels
+    groups = conv_module.groups
+    filters_per_channel = out_channels // groups
+    conv_per_position_flops = (
+        kernel_height * kernel_width * in_channels * filters_per_channel)
+    active_elements_count = batch_size * input_height * input_width
+    overall_conv_flops = conv_per_position_flops * active_elements_count
+    bias_flops = 0
+    if conv_module.bias is not None:
+        output_height, output_width = output.shape[2:]
+        # The bias is added once per output element, i.e. height x width
+        # (the width must not be replaced by the height a second time).
+        bias_flops = (
+            out_channels * batch_size * output_height * output_width)
+    overall_flops = overall_conv_flops + bias_flops
+    conv_module.__flops__ += int(overall_flops)
+def conv_flops_counter_hook(conv_module, input, output):
+    """Forward hook that accumulates the flops of an N-d convolution.
+    The convolution cost is counted per output position; when the layer has
+    a bias, one extra flop per output element is added.
+    Args:
+        conv_module: Layer providing ``kernel_size``, ``in_channels``,
+            ``out_channels``, ``groups``, ``bias`` and a ``__flops__``
+            counter.
+        input: Positional inputs of the layer; only the first one is used.
+        output: Result of the layer; ``output.shape[2:]`` is taken as the
+            spatial size.
+    """
+    # Can have multiple inputs, getting the first one
+    n_samples = input[0].shape[0]
+    spatial_dims = list(output.shape[2:])
+    kernel_volume = int(np.prod(list(conv_module.kernel_size)))
+    filters_per_group = conv_module.out_channels // conv_module.groups
+    flops_per_position = (
+        kernel_volume * conv_module.in_channels * filters_per_group)
+    n_positions = n_samples * int(np.prod(spatial_dims))
+    total_flops = flops_per_position * n_positions
+    if conv_module.bias is not None:
+        total_flops += conv_module.out_channels * n_positions
+    conv_module.__flops__ += int(total_flops)
+def batch_counter_hook(module, input, output):
+    """Forward hook that adds the batch size of a call to the module counter.
+    The batch size is ``len()`` of the first positional input. When there is
+    no positional input, a warning is printed and a batch size of 1 is used.
+    Args:
+        module: Object carrying a ``__batch_counter__`` counter.
+        input: Positional inputs of the layer.
+        output: Unused.
+    """
+    if len(input) > 0:
+        # Can have multiple inputs, getting the first one
+        batch_size = len(input[0])
+    else:
+        print('Warning! No positional inputs found for a module, '
+              'assuming batch size is 1.')
+        batch_size = 1
+    module.__batch_counter__ += batch_size
+def add_batch_counter_variables_or_reset(module):
+    """Create the ``__batch_counter__`` attribute on ``module`` or reset it.
+    Args:
+        module: Object that receives the counter; it is set to 0.
+    """
+    setattr(module, '__batch_counter__', 0)
+def add_batch_counter_hook_function(module):
+    """Register ``batch_counter_hook`` on ``module`` unless already done.
+    The returned hook handle is stored as ``__batch_counter_handle__``; the
+    presence of that attribute marks the module as already hooked.
+    Args:
+        module: Object providing ``register_forward_hook``.
+    """
+    if not hasattr(module, '__batch_counter_handle__'):
+        module.__batch_counter_handle__ = module.register_forward_hook(
+            batch_counter_hook)
+def remove_batch_counter_hook_function(module):
+    """Remove the batch counter hook from ``module`` if one is registered.
+    Args:
+        module: Object that may carry a ``__batch_counter_handle__``.
+    """
+    if not hasattr(module, '__batch_counter_handle__'):
+        return
+    # Detach the hook first, then drop the marker attribute so that the hook
+    # can be registered again later.
+    module.__batch_counter_handle__.remove()
+    del module.__batch_counter_handle__
+def add_flops_counter_variable_or_reset(module):
+    """Create or reset the flops bookkeeping attributes of a supported layer.
+    For a supported module, ``__flops__`` is set to 0 and ``__params__`` to
+    the result of ``get_model_parameters_number(module)``. A warning is
+    printed when either attribute already exists. Unsupported modules are
+    left untouched.
+    Args:
+        module: The layer to prepare.
+    """
+    if not is_supported_instance(module):
+        return
+    if hasattr(module, '__flops__') or hasattr(module, '__params__'):
+        # Keep a space before the class name and end the sentence after it,
+        # so the message does not read "...the moduleConv2d ptflops...".
+        print('Warning: variables __flops__ or __params__ are already '
+              'defined for the module ' + type(module).__name__ +
+              '. ptflops can affect your code!')
+    module.__flops__ = 0
+    module.__params__ = get_model_parameters_number(module)
+def is_supported_instance(module):
+    """Return whether the exact type of ``module`` has a flops hook.
+    Only ``type(module)`` itself is looked up in ``get_modules_mapping()``;
+    subclasses of a supported type are not matched.
+    """
+    return type(module) in get_modules_mapping()
+def remove_flops_counter_hook_function(module):
+    """Remove the flops counter hook from a supported layer, if registered.
+    Args:
+        module: Layer that may carry a ``__flops_handle__``.
+    """
+    if not is_supported_instance(module):
+        return
+    if not hasattr(module, '__flops_handle__'):
+        return
+    # Detach the hook, then drop the marker attribute so that the hook can
+    # be registered again later.
+    module.__flops_handle__.remove()
+    del module.__flops_handle__
+def get_modules_mapping():
+    """Return the table that maps supported layer types to flops hooks.
+    Returns:
+        dict: Layer class -> forward hook used to count its flops. A new
+        dict is built on every call.
+    """
+    # Each entry pairs a group of layer classes with the hook they share.
+    # Groups are listed in the order in which keys enter the mapping.
+    hook_groups = (
+        # convolutions
+        ((nn.Conv1d, nn.Conv2d, mmcv.cnn.bricks.Conv2d, nn.Conv3d,
+          mmcv.cnn.bricks.Conv3d), conv_flops_counter_hook),
+        # activations
+        ((nn.ReLU, nn.PReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6),
+         relu_flops_counter_hook),
+        # poolings
+        ((nn.MaxPool1d, nn.AvgPool1d, nn.AvgPool2d, nn.MaxPool2d,
+          mmcv.cnn.bricks.MaxPool2d, nn.MaxPool3d,
+          mmcv.cnn.bricks.MaxPool3d, nn.AvgPool3d, nn.AdaptiveMaxPool1d,
+          nn.AdaptiveAvgPool1d, nn.AdaptiveMaxPool2d, nn.AdaptiveAvgPool2d,
+          nn.AdaptiveMaxPool3d, nn.AdaptiveAvgPool3d),
+         pool_flops_counter_hook),
+        # normalizations
+        ((nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm,
+          nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d,
+          nn.LayerNorm), norm_flops_counter_hook),
+        # FC
+        ((nn.Linear, mmcv.cnn.bricks.Linear), linear_flops_counter_hook),
+        # Upscale
+        ((nn.Upsample, ), upsample_flops_counter_hook),
+        # Deconvolution
+        ((nn.ConvTranspose2d, mmcv.cnn.bricks.ConvTranspose2d),
+         deconv_flops_counter_hook),
+    )
+    mapping = {}
+    for layer_types, hook in hook_groups:
+        for layer_type in layer_types:
+            mapping[layer_type] = hook
+    return mapping
diff --git a/ControlNet/annotator/uniformer/mmcv/cnn/utils/fuse_conv_bn.py b/ControlNet/annotator/uniformer/mmcv/cnn/utils/fuse_conv_bn.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb7076f80bf37f7931185bf0293ffcc1ce19c8ef
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/cnn/utils/fuse_conv_bn.py
@@ -0,0 +1,59 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+def _fuse_conv_bn(conv, bn):
+    """Fold the statistics and affine parameters of ``bn`` into ``conv``.
+    ``conv`` is modified in place: its weight and bias are replaced by new
+    parameters that already contain the batch norm scaling and shift.
+    Args:
+        conv (nn.Module): Conv to be fused.
+        bn (nn.Module): BN to be fused.
+    Returns:
+        nn.Module: The same ``conv`` object with fused parameters.
+    """
+    weight = conv.weight
+    if conv.bias is not None:
+        bias = conv.bias
+    else:
+        # A conv without bias behaves like one with an all-zero bias.
+        bias = torch.zeros_like(bn.running_mean)
+    # Per-channel multiplier applied by the batch norm.
+    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
+    # Shape the multiplier so that it broadcasts over each output filter.
+    filter_scale = scale.reshape([conv.out_channels, 1, 1, 1])
+    conv.weight = nn.Parameter(weight * filter_scale)
+    conv.bias = nn.Parameter((bias - bn.running_mean) * scale + bn.bias)
+    return conv
+def fuse_conv_bn(module):
+    """Recursively fuse conv and bn in a module.
+    During inference, batch norm layers no longer update their statistics and
+    only apply the stored per-channel mean and variance, which makes it
+    possible to fold them into the preceding conv layers to save computation
+    and simplify the network structure.
+    Args:
+        module (nn.Module): Module to be fused.
+    Returns:
+        nn.Module: Fused module (the same object, modified in place).
+    """
+    bn_types = (nn.modules.batchnorm._BatchNorm, nn.SyncBatchNorm)
+    pending_conv = None
+    pending_name = None
+    for child_name, child in module.named_children():
+        if isinstance(child, bn_types):
+            # only fuse BN that is after Conv
+            if pending_conv is not None:
+                module._modules[pending_name] = _fuse_conv_bn(
+                    pending_conv, child)
+                # To reduce changes, set BN as Identity instead of deleting
+                # it.
+                module._modules[child_name] = nn.Identity()
+                pending_conv = None
+        elif isinstance(child, nn.Conv2d):
+            pending_conv, pending_name = child, child_name
+        else:
+            fuse_conv_bn(child)
+    return module
diff --git a/ControlNet/annotator/uniformer/mmcv/cnn/utils/sync_bn.py b/ControlNet/annotator/uniformer/mmcv/cnn/utils/sync_bn.py
new file mode 100644
index 0000000000000000000000000000000000000000..f78f39181d75bb85c53e8c7c8eaf45690e9f0bee
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/cnn/utils/sync_bn.py
@@ -0,0 +1,59 @@
+import torch
+import annotator.uniformer.mmcv as mmcv
+class _BatchNormXd(torch.nn.modules.batchnorm._BatchNorm):
+    """A general BatchNorm layer without input dimension check.
+    Reproduced from @kapily's work:
+    (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
+    BatchNorm1d, BatchNorm2d, BatchNorm3d, etc. differ only in
+    `_check_input_dim`, which performs tensor sanity checks. This class
+    bypasses that check, which is convenient when converting SyncBatchNorm.
+    """
+    def _check_input_dim(self, input):
+        """Accept an input of any dimensionality."""
+        return None
+def revert_sync_batchnorm(module):
+    """Helper function to convert all `SyncBatchNorm` (SyncBN) and
+    `mmcv.ops.sync_bn.SyncBatchNorm`(MMSyncBN) layers in the model to
+    `BatchNormXd` layers.
+    Adapted from @kapily's work:
+    (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
+    Args:
+        module (nn.Module): The module containing `SyncBatchNorm` layers.
+    Returns:
+        module_output: The converted module with `BatchNormXd` layers.
+    """
+    sync_bn_types = [torch.nn.modules.batchnorm.SyncBatchNorm]
+    # MMSyncBN is only considered when the mmcv package exposes ``ops``.
+    if hasattr(mmcv, 'ops'):
+        sync_bn_types.append(mmcv.ops.SyncBatchNorm)
+    converted = module
+    if isinstance(module, tuple(sync_bn_types)):
+        converted = _BatchNormXd(module.num_features, module.eps,
+                                 module.momentum, module.affine,
+                                 module.track_running_stats)
+        if module.affine:
+            # no_grad() may not be needed here but
+            # just to be consistent with `convert_sync_batchnorm()`
+            with torch.no_grad():
+                converted.weight = module.weight
+                converted.bias = module.bias
+        # Carry over the running statistics and the train/eval state.
+        converted.running_mean = module.running_mean
+        converted.running_var = module.running_var
+        converted.num_batches_tracked = module.num_batches_tracked
+        converted.training = module.training
+        # qconfig exists in quantized models
+        if hasattr(module, 'qconfig'):
+            converted.qconfig = module.qconfig
+    # Convert the children and attach them to the (possibly new) module.
+    for child_name, child in module.named_children():
+        converted.add_module(child_name, revert_sync_batchnorm(child))
+    del module
+    return converted
diff --git a/ControlNet/annotator/uniformer/mmcv/cnn/utils/weight_init.py b/ControlNet/annotator/uniformer/mmcv/cnn/utils/weight_init.py
new file mode 100644
index 0000000000000000000000000000000000000000..287a1d0bffe26e023029d48634d9b761deda7ba4
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/cnn/utils/weight_init.py
@@ -0,0 +1,684 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import math
+import warnings
+import numpy as np
+import torch
+import torch.nn as nn
+from torch import Tensor
+from annotator.uniformer.mmcv.utils import Registry, build_from_cfg, get_logger, print_log
+INITIALIZERS = Registry('initializer')
+def update_init_info(module, init_info):
+    """Update the `_params_init_info` in the module if the value of parameters
+    are changed.
+    A parameter counts as changed when the mean of its data differs from the
+    ``tmp_mean_value`` stored for it.
+    Args:
+        module (obj:`nn.Module`): The module of PyTorch with a user-defined
+            attribute `_params_init_info` which records the initialization
+            information.
+        init_info (str): The string that describes the initialization.
+    """
+    assert hasattr(
+        module,
+        '_params_init_info'), f'Can not find `_params_init_info` in {module}'
+    for name, param in module.named_parameters():
+        assert param in module._params_init_info, (
+            f'Find a new :obj:`Parameter` '
+            f'named `{name}` during executing the '
+            f'`init_weights` of '
+            f'`{module.__class__.__name__}`. '
+            f'Please do not add or '
+            f'replace parameters during executing '
+            f'the `init_weights`. ')
+        record = module._params_init_info[param]
+        current_mean = param.data.mean()
+        # The parameter has been changed during executing the
+        # `init_weights` of module
+        if record['tmp_mean_value'] != current_mean:
+            record['init_info'] = init_info
+            record['tmp_mean_value'] = current_mean
+def constant_init(module, val, bias=0):
+    """Fill the weight of ``module`` with ``val`` and its bias with ``bias``.
+    A missing or ``None`` weight/bias attribute is skipped.
+    Args:
+        module: Object that may have ``weight`` and ``bias`` tensors.
+        val: Value to fill the weight with.
+        bias: Value to fill the bias with. Defaults to 0.
+    """
+    weight = getattr(module, 'weight', None)
+    if weight is not None:
+        nn.init.constant_(weight, val)
+    bias_tensor = getattr(module, 'bias', None)
+    if bias_tensor is not None:
+        nn.init.constant_(bias_tensor, bias)
+def xavier_init(module, gain=1, bias=0, distribution='normal'):
+    """Xavier-initialize the weight of ``module`` and fill its bias.
+    A missing or ``None`` weight/bias attribute is skipped.
+    Args:
+        module: Object that may have ``weight`` and ``bias`` tensors.
+        gain: Scaling factor passed to the xavier initializer. Defaults to 1.
+        bias: Value to fill the bias with. Defaults to 0.
+        distribution (str): Either ``'uniform'`` or ``'normal'``.
+            Defaults to ``'normal'``.
+    """
+    assert distribution in ['uniform', 'normal']
+    weight = getattr(module, 'weight', None)
+    if weight is not None:
+        if distribution == 'uniform':
+            fill_weight = nn.init.xavier_uniform_
+        else:
+            fill_weight = nn.init.xavier_normal_
+        fill_weight(weight, gain=gain)
+    bias_tensor = getattr(module, 'bias', None)
+    if bias_tensor is not None:
+        nn.init.constant_(bias_tensor, bias)
+def normal_init(module, mean=0, std=1, bias=0):
+    """Draw the weight of ``module`` from a normal distribution and fill its
+    bias with a constant.
+    A missing or ``None`` weight/bias attribute is skipped.
+    Args:
+        module: Object that may have ``weight`` and ``bias`` tensors.
+        mean: Mean of the normal distribution. Defaults to 0.
+        std: Standard deviation of the normal distribution. Defaults to 1.
+        bias: Value to fill the bias with. Defaults to 0.
+    """
+    weight = getattr(module, 'weight', None)
+    if weight is not None:
+        nn.init.normal_(weight, mean, std)
+    bias_tensor = getattr(module, 'bias', None)
+    if bias_tensor is not None:
+        nn.init.constant_(bias_tensor, bias)
+def trunc_normal_init(module: nn.Module,
+                      mean: float = 0,
+                      std: float = 1,
+                      a: float = -2,
+                      b: float = 2,
+                      bias: float = 0) -> None:
+    """Fill the weight of ``module`` via ``trunc_normal_`` and its bias with
+    a constant.
+    A missing or ``None`` weight/bias attribute is skipped.
+    Args:
+        module (nn.Module): Module that may have ``weight`` and ``bias``.
+        mean (float): Mean of the normal distribution. Defaults to 0.
+        std (float): Standard deviation of the normal distribution.
+            Defaults to 1.
+        a (float): The minimum cutoff value. Defaults to -2.
+        b (float): The maximum cutoff value. Defaults to 2.
+        bias (float): Value to fill the bias with. Defaults to 0.
+    """
+    weight = getattr(module, 'weight', None)
+    if weight is not None:
+        trunc_normal_(weight, mean, std, a, b)  # type: ignore
+    bias_tensor = getattr(module, 'bias', None)
+    if bias_tensor is not None:
+        nn.init.constant_(bias_tensor, bias)  # type: ignore
+def uniform_init(module, a=0, b=1, bias=0):
+    """Draw the weight of ``module`` from a uniform distribution and fill
+    its bias with a constant.
+    A missing or ``None`` weight/bias attribute is skipped.
+    Args:
+        module: Object that may have ``weight`` and ``bias`` tensors.
+        a: Lower bound of the uniform distribution. Defaults to 0.
+        b: Upper bound of the uniform distribution. Defaults to 1.
+        bias: Value to fill the bias with. Defaults to 0.
+    """
+    weight = getattr(module, 'weight', None)
+    if weight is not None:
+        nn.init.uniform_(weight, a, b)
+    bias_tensor = getattr(module, 'bias', None)
+    if bias_tensor is not None:
+        nn.init.constant_(bias_tensor, bias)
+def kaiming_init(module,
+                 a=0,
+                 mode='fan_out',
+                 nonlinearity='relu',
+                 bias=0,
+                 distribution='normal'):
+    """Kaiming-initialize the weight of ``module`` and fill its bias.
+    A missing or ``None`` weight/bias attribute is skipped.
+    Args:
+        module: Object that may have ``weight`` and ``bias`` tensors.
+        a: Passed to the kaiming initializer as ``a``. Defaults to 0.
+        mode (str): Passed to the kaiming initializer as ``mode``.
+            Defaults to ``'fan_out'``.
+        nonlinearity (str): Passed to the kaiming initializer as
+            ``nonlinearity``. Defaults to ``'relu'``.
+        bias: Value to fill the bias with. Defaults to 0.
+        distribution (str): Either ``'uniform'`` or ``'normal'``.
+            Defaults to ``'normal'``.
+    """
+    assert distribution in ['uniform', 'normal']
+    weight = getattr(module, 'weight', None)
+    if weight is not None:
+        if distribution == 'uniform':
+            fill_weight = nn.init.kaiming_uniform_
+        else:
+            fill_weight = nn.init.kaiming_normal_
+        fill_weight(weight, a=a, mode=mode, nonlinearity=nonlinearity)
+    bias_tensor = getattr(module, 'bias', None)
+    if bias_tensor is not None:
+        nn.init.constant_(bias_tensor, bias)
+def caffe2_xavier_init(module, bias=0):
+    """Initialize ``module`` the way Caffe2's ``XavierFill`` does.
+    Args:
+        module: Object that may have ``weight`` and ``bias`` tensors.
+        bias: Value to fill the bias with. Defaults to 0.
+    """
+    # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
+    # Acknowledgment to FAIR's internal code
+    kaiming_kwargs = dict(
+        a=1,
+        mode='fan_in',
+        nonlinearity='leaky_relu',
+        bias=bias,
+        distribution='uniform')
+    kaiming_init(module, **kaiming_kwargs)
+def bias_init_with_prob(prior_prob):
+    """Initialize conv/fc bias value according to a given probability value.
+    Args:
+        prior_prob: The probability the bias should correspond to.
+    Returns:
+        float: ``-log((1 - prior_prob) / prior_prob)``.
+    """
+    odds_against = (1 - prior_prob) / prior_prob
+    return float(-np.log(odds_against))
+def _get_bases_name(m):
+    """Return the class names of the direct base classes of ``type(m)``."""
+    direct_bases = m.__class__.__bases__
+    return [base.__name__ for base in direct_bases]
+class BaseInit(object):
+    """Common argument handling shared by the initializer classes.
+    Args:
+        bias (int | float): the value to fill the bias. Defaults to 0.
+        bias_prob (float, optional): the probability for bias initialization.
+            When given, the bias value is derived from it with
+            ``bias_init_with_prob`` and ``bias`` is ignored. Defaults to None.
+        layer (str | list[str], optional): the layer will be initialized.
+            A single name is wrapped into a list. Defaults to None, which is
+            stored as an empty list.
+    Raises:
+        TypeError: If ``bias`` is not a number, ``bias_prob`` is not a float,
+            or ``layer`` is neither a str nor a list.
+    """
+    def __init__(self, *, bias=0, bias_prob=None, layer=None):
+        # Set to True by the caller to initialize every submodule instead of
+        # only the layers listed in ``layer``.
+        self.wholemodule = False
+        if not isinstance(bias, (int, float)):
+            raise TypeError(f'bias must be a number, but got a {type(bias)}')
+        if bias_prob is not None and not isinstance(bias_prob, float):
+            # Implicit string concatenation keeps source indentation out of
+            # the message (a backslash continuation inside the literal would
+            # embed it).
+            raise TypeError('bias_prob type must be float, '
+                            f'but got {type(bias_prob)}')
+        if layer is None:
+            layer = []
+        elif not isinstance(layer, (str, list)):
+            raise TypeError('layer must be a str or a list of str, '
+                            f'but got a {type(layer)}')
+        if bias_prob is not None:
+            self.bias = bias_init_with_prob(bias_prob)
+        else:
+            self.bias = bias
+        self.layer = [layer] if isinstance(layer, str) else layer
+    def _get_init_info(self):
+        """Return a short description of this initializer."""
+        info = f'{self.__class__.__name__}, bias={self.bias}'
+        return info
+class ConstantInit(BaseInit):
+    """Initialize module parameters with constant values.
+    Args:
+        val (int | float): the value to fill the weights in the module with
+        bias (int | float): the value to fill the bias. Defaults to 0.
+        bias_prob (float, optional): the probability for bias initialization.
+            Defaults to None.
+        layer (str | list[str], optional): the layer will be initialized.
+            Defaults to None.
+    """
+    def __init__(self, val, **kwargs):
+        super().__init__(**kwargs)
+        self.val = val
+    def __call__(self, module):
+        def _init_submodule(m):
+            if not self.wholemodule:
+                # Only touch submodules whose class name, or the name of one
+                # of their direct bases, is listed in ``self.layer``.
+                names = {m.__class__.__name__, *_get_bases_name(m)}
+                if names.isdisjoint(self.layer):
+                    return
+            constant_init(m, self.val, self.bias)
+        module.apply(_init_submodule)
+        if hasattr(module, '_params_init_info'):
+            update_init_info(module, init_info=self._get_init_info())
+    def _get_init_info(self):
+        return f'{self.__class__.__name__}: val={self.val}, bias={self.bias}'
+class XavierInit(BaseInit):
+    r"""Initialize module parameters with values according to the method
+    described in "Understanding the difficulty of training deep feedforward
+    neural networks" - Glorot, X. & Bengio, Y. (2010).
+    Args:
+        gain (int | float): an optional scaling factor. Defaults to 1.
+        bias (int | float): the value to fill the bias. Defaults to 0.
+        bias_prob (float, optional): the probability for bias initialization.
+            Defaults to None.
+        distribution (str): distribution either be ``'normal'``
+            or ``'uniform'``. Defaults to ``'normal'``.
+        layer (str | list[str], optional): the layer will be initialized.
+            Defaults to None.
+    """
+    def __init__(self, gain=1, distribution='normal', **kwargs):
+        super().__init__(**kwargs)
+        self.gain = gain
+        self.distribution = distribution
+    def __call__(self, module):
+        def _init_submodule(m):
+            if not self.wholemodule:
+                # Only touch submodules whose class name, or the name of one
+                # of their direct bases, is listed in ``self.layer``.
+                names = {m.__class__.__name__, *_get_bases_name(m)}
+                if names.isdisjoint(self.layer):
+                    return
+            xavier_init(m, self.gain, self.bias, self.distribution)
+        module.apply(_init_submodule)
+        if hasattr(module, '_params_init_info'):
+            update_init_info(module, init_info=self._get_init_info())
+    def _get_init_info(self):
+        return (f'{self.__class__.__name__}: gain={self.gain}, '
+                f'distribution={self.distribution}, bias={self.bias}')
+class NormalInit(BaseInit):
+    r"""Initialize module parameters with the values drawn from the normal
+    distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.
+    Args:
+        mean (int | float): the mean of the normal distribution.
+            Defaults to 0.
+        std (int | float): the standard deviation of the normal distribution.
+            Defaults to 1.
+        bias (int | float): the value to fill the bias. Defaults to 0.
+        bias_prob (float, optional): the probability for bias initialization.
+            Defaults to None.
+        layer (str | list[str], optional): the layer will be initialized.
+            Defaults to None.
+    """
+    def __init__(self, mean=0, std=1, **kwargs):
+        super().__init__(**kwargs)
+        self.mean = mean
+        self.std = std
+    def __call__(self, module):
+        def _init_submodule(m):
+            if not self.wholemodule:
+                # Only touch submodules whose class name, or the name of one
+                # of their direct bases, is listed in ``self.layer``.
+                names = {m.__class__.__name__, *_get_bases_name(m)}
+                if names.isdisjoint(self.layer):
+                    return
+            normal_init(m, self.mean, self.std, self.bias)
+        module.apply(_init_submodule)
+        if hasattr(module, '_params_init_info'):
+            update_init_info(module, init_info=self._get_init_info())
+    def _get_init_info(self):
+        return (f'{self.__class__.__name__}: mean={self.mean},'
+                f' std={self.std}, bias={self.bias}')
+class TruncNormalInit(BaseInit):
+    r"""Initialize module parameters with the values drawn from the normal
+    distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`, restricted
+    to the interval :math:`[a, b]`.
+    Args:
+        mean (float): the mean of the normal distribution. Defaults to 0.
+        std (float): the standard deviation of the normal distribution.
+            Defaults to 1.
+        a (float): The minimum cutoff value. Defaults to -2.
+        b (float): The maximum cutoff value. Defaults to 2.
+        bias (float): the value to fill the bias. Defaults to 0.
+        bias_prob (float, optional): the probability for bias initialization.
+            Defaults to None.
+        layer (str | list[str], optional): the layer will be initialized.
+            Defaults to None.
+    """
+    def __init__(self,
+                 mean: float = 0,
+                 std: float = 1,
+                 a: float = -2,
+                 b: float = 2,
+                 **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.mean = mean
+        self.std = std
+        self.a = a
+        self.b = b
+    def __call__(self, module: nn.Module) -> None:
+        def _init_submodule(m):
+            if not self.wholemodule:
+                # Only touch submodules whose class name, or the name of one
+                # of their direct bases, is listed in ``self.layer``.
+                names = {m.__class__.__name__, *_get_bases_name(m)}
+                if names.isdisjoint(self.layer):
+                    return
+            trunc_normal_init(m, self.mean, self.std, self.a, self.b,
+                              self.bias)
+        module.apply(_init_submodule)
+        if hasattr(module, '_params_init_info'):
+            update_init_info(module, init_info=self._get_init_info())
+    def _get_init_info(self):
+        return (f'{self.__class__.__name__}: a={self.a}, b={self.b},'
+                f' mean={self.mean}, std={self.std}, bias={self.bias}')
+class UniformInit(BaseInit):
+    r"""Initialize module parameters with values drawn from the uniform
+    distribution :math:`\mathcal{U}(a, b)`.
+    Args:
+        a (int | float): the lower bound of the uniform distribution.
+            Defaults to 0.
+        b (int | float): the upper bound of the uniform distribution.
+            Defaults to 1.
+        bias (int | float): the value to fill the bias. Defaults to 0.
+        bias_prob (float, optional): the probability for bias initialization.
+            Defaults to None.
+        layer (str | list[str], optional): the layer will be initialized.
+            Defaults to None.
+    """
+    def __init__(self, a=0, b=1, **kwargs):
+        super().__init__(**kwargs)
+        self.a = a
+        self.b = b
+    def __call__(self, module):
+        def _init_submodule(m):
+            if not self.wholemodule:
+                # Only touch submodules whose class name, or the name of one
+                # of their direct bases, is listed in ``self.layer``.
+                names = {m.__class__.__name__, *_get_bases_name(m)}
+                if names.isdisjoint(self.layer):
+                    return
+            uniform_init(m, self.a, self.b, self.bias)
+        module.apply(_init_submodule)
+        if hasattr(module, '_params_init_info'):
+            update_init_info(module, init_info=self._get_init_info())
+    def _get_init_info(self):
+        return (f'{self.__class__.__name__}: a={self.a},'
+                f' b={self.b}, bias={self.bias}')
+class KaimingInit(BaseInit):
+    r"""Initialize module parameters with the values according to the method
+    described in "Delving deep into rectifiers: Surpassing human-level
+    performance on ImageNet classification" - He, K. et al. (2015).
+    Args:
+        a (int | float): the negative slope of the rectifier used after this
+            layer (only used with ``'leaky_relu'``). Defaults to 0.
+        mode (str): either ``'fan_in'`` or ``'fan_out'``. Choosing
+            ``'fan_in'`` preserves the magnitude of the variance of the
+            weights in the forward pass. Choosing ``'fan_out'`` preserves the
+            magnitudes in the backwards pass. Defaults to ``'fan_out'``.
+        nonlinearity (str): the non-linear function (`nn.functional` name),
+            recommended to use only with ``'relu'`` or ``'leaky_relu'``.
+            Defaults to 'relu'.
+        bias (int | float): the value to fill the bias. Defaults to 0.
+        bias_prob (float, optional): the probability for bias initialization.
+            Defaults to None.
+        distribution (str): distribution either be ``'normal'`` or
+            ``'uniform'``. Defaults to ``'normal'``.
+        layer (str | list[str], optional): the layer will be initialized.
+            Defaults to None.
+    """
+    def __init__(self,
+                 a=0,
+                 mode='fan_out',
+                 nonlinearity='relu',
+                 distribution='normal',
+                 **kwargs):
+        super().__init__(**kwargs)
+        self.a = a
+        self.mode = mode
+        self.nonlinearity = nonlinearity
+        self.distribution = distribution
+    def __call__(self, module):
+        def _init_submodule(m):
+            if not self.wholemodule:
+                # Only touch submodules whose class name, or the name of one
+                # of their direct bases, is listed in ``self.layer``.
+                names = {m.__class__.__name__, *_get_bases_name(m)}
+                if names.isdisjoint(self.layer):
+                    return
+            kaiming_init(m, self.a, self.mode, self.nonlinearity,
+                         self.bias, self.distribution)
+        module.apply(_init_submodule)
+        if hasattr(module, '_params_init_info'):
+            update_init_info(module, init_info=self._get_init_info())
+    def _get_init_info(self):
+        return (f'{self.__class__.__name__}: a={self.a}, mode={self.mode}, '
+                f'nonlinearity={self.nonlinearity}, '
+                f'distribution ={self.distribution}, bias={self.bias}')
+class Caffe2XavierInit(KaimingInit):
+    """``KaimingInit`` preset that mirrors Caffe2's ``XavierFill``.
+    Remaining keyword arguments are forwarded to ``KaimingInit``.
+    """
+    # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
+    # Acknowledgment to FAIR's internal code
+    def __init__(self, **kwargs):
+        preset = dict(
+            a=1,
+            mode='fan_in',
+            nonlinearity='leaky_relu',
+            distribution='uniform')
+        super().__init__(**preset, **kwargs)
+    def __call__(self, module):
+        super().__call__(module)
+class PretrainedInit(object):
+    """Initialize module by loading a pretrained model.
+    Args:
+        checkpoint (str): the checkpoint file of the pretrained model that
+            should be loaded.
+        prefix (str, optional): the prefix of a sub-module in the pretrained
+            model. It is for loading a part of the pretrained model to
+            initialize. For example, if we would like to only load the
+            backbone of a detector model, we can set ``prefix='backbone.'``.
+            Defaults to None.
+        map_location (str): map tensors into proper locations.
+    """
+    def __init__(self, checkpoint, prefix=None, map_location=None):
+        self.checkpoint = checkpoint
+        self.prefix = prefix
+        self.map_location = map_location
+    def __call__(self, module):
+        # Imported here rather than at the top of the file.
+        from annotator.uniformer.mmcv.runner import (_load_checkpoint_with_prefix, load_checkpoint,
+                                                     load_state_dict)
+        logger = get_logger('mmcv')
+        if self.prefix is not None:
+            # Load only the part of the checkpoint below ``prefix``.
+            print_log(
+                f'load {self.prefix} in model from: {self.checkpoint}',
+                logger=logger)
+            state_dict = _load_checkpoint_with_prefix(
+                self.prefix, self.checkpoint, map_location=self.map_location)
+            load_state_dict(module, state_dict, strict=False, logger=logger)
+        else:
+            print_log(f'load model from: {self.checkpoint}', logger=logger)
+            load_checkpoint(
+                module,
+                self.checkpoint,
+                map_location=self.map_location,
+                strict=False,
+                logger=logger)
+        if hasattr(module, '_params_init_info'):
+            update_init_info(module, init_info=self._get_init_info())
+    def _get_init_info(self):
+        return f'{self.__class__.__name__}: load from {self.checkpoint}'
+def _initialize(module, cfg, wholemodule=False):
+    """Build the initializer described by ``cfg`` and apply it to ``module``.
+    Args:
+        module: The module to initialize.
+        cfg: Config handed to ``build_from_cfg`` together with the
+            ``INITIALIZERS`` registry.
+        wholemodule (bool): Stored on the built initializer before it is
+            called. Defaults to False.
+    """
+    initializer = build_from_cfg(cfg, INITIALIZERS)
+    # wholemodule flag is for override mode, there is no layer key in override
+    # and initializer will give init values for the whole module with the name
+    # in override.
+    initializer.wholemodule = wholemodule
+    initializer(module)
+def _initialize_override(module, override, cfg):
+    """Apply the ``override`` entries of an init config to named attributes.
+    Args:
+        module: The module whose attributes are initialized.
+        override (dict | list[dict]): Each entry must contain the key
+            ``"name"`` (attribute of ``module`` to initialize). An entry with
+            only ``"name"`` reuses ``cfg``; any other entry must also contain
+            ``"type"``.
+        cfg (dict): Initializer config used for name-only entries.
+    Raises:
+        TypeError: If ``override`` is neither a dict nor a list.
+        ValueError: If an entry lacks ``"name"``, or has extra keys but no
+            ``"type"``.
+        RuntimeError: If ``module`` has no attribute with the given name.
+    """
+    if not isinstance(override, (dict, list)):
+        # Implicit string concatenation keeps source indentation out of the
+        # message (a backslash continuation inside the literal would embed
+        # it).
+        raise TypeError('override must be a dict or a list of dict, '
+                        f'but got {type(override)}')
+    override = [override] if isinstance(override, dict) else override
+    for override_ in override:
+        # Work on a copy so that popping keys does not alter the caller's
+        # config.
+        cp_override = copy.deepcopy(override_)
+        name = cp_override.pop('name', None)
+        if name is None:
+            raise ValueError('`override` must contain the key "name", '
+                             f'but got {cp_override}')
+        # if override only has name key, it means use args in init_cfg
+        if not cp_override:
+            cp_override.update(cfg)
+        # if override has name key and other args except type key, it will
+        # raise error
+        elif 'type' not in cp_override.keys():
+            raise ValueError(
+                f'`override` need "type" key, but got {cp_override}')
+        if hasattr(module, name):
+            _initialize(getattr(module, name), cp_override, wholemodule=True)
+        else:
+            raise RuntimeError(f'module did not have attribute {name}, '
+                               f'but init_cfg is {cp_override}.')
+def initialize(module, init_cfg):
+    """Initialize a module.
+    Args:
+        module (``torch.nn.Module``): the module will be initialized.
+        init_cfg (dict | list[dict]): initialization configuration dict to
+            define initializer. OpenMMLab has implemented 6 initializers
+            including ``Constant``, ``Xavier``, ``Normal``, ``Uniform``,
+            ``Kaiming``, and ``Pretrained``.
+    Raises:
+        TypeError: If ``init_cfg`` is neither a dict nor a list.
+    Example:
+        >>> module = nn.Linear(2, 3, bias=True)
+        >>> init_cfg = dict(type='Constant', layer='Linear', val=1, bias=2)
+        >>> initialize(module, init_cfg)
+        >>> module = nn.Sequential(nn.Conv1d(3, 1, 3), nn.Linear(1, 2))
+        >>> # define key ``'layer'`` for initializing layer with different
+        >>> # configuration
+        >>> init_cfg = [dict(type='Constant', layer='Conv1d', val=1),
+        >>>             dict(type='Constant', layer='Linear', val=2)]
+        >>> initialize(module, init_cfg)
+        >>> # define key ``'override'`` to initialize some specific part in
+        >>> # module
+        >>> class FooNet(nn.Module):
+        >>>     def __init__(self):
+        >>>         super().__init__()
+        >>>         self.feat = nn.Conv2d(3, 16, 3)
+        >>>         self.reg = nn.Conv2d(16, 10, 3)
+        >>>         self.cls = nn.Conv2d(16, 5, 3)
+        >>> model = FooNet()
+        >>> init_cfg = dict(type='Constant', val=1, bias=2, layer='Conv2d',
+        >>>     override=dict(type='Constant', name='reg', val=3, bias=4))
+        >>> initialize(model, init_cfg)
+        >>> model = ResNet(depth=50)
+        >>> # Initialize weights with the pretrained model.
+        >>> init_cfg = dict(type='Pretrained',
+        >>>                 checkpoint='torchvision://resnet50')
+        >>> initialize(model, init_cfg)
+        >>> # Initialize weights of a sub-module with the specific part of
+        >>> # a pretrained model by using "prefix".
+        >>> url = ('http://download.openmmlab.com/mmdetection/v2.0/retinanet/'
+        >>>        'retinanet_r50_fpn_1x_coco/'
+        >>>        'retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth')
+        >>> init_cfg = dict(type='Pretrained',
+        >>>                 checkpoint=url, prefix='backbone.')
+    """
+    if not isinstance(init_cfg, (dict, list)):
+        # Implicit string concatenation keeps source indentation out of the
+        # message (a backslash continuation inside the literal would embed
+        # it).
+        raise TypeError('init_cfg must be a dict or a list of dict, '
+                        f'but got {type(init_cfg)}')
+    if isinstance(init_cfg, dict):
+        init_cfg = [init_cfg]
+    for cfg in init_cfg:
+        # should deeply copy the original config because cfg may be used by
+        # other modules, e.g., one init_cfg shared by multiple bottleneck
+        # blocks, the expected cfg will be changed after pop and will change
+        # the initialization behavior of other modules
+        cp_cfg = copy.deepcopy(cfg)
+        override = cp_cfg.pop('override', None)
+        _initialize(module, cp_cfg)
+        # Without an override, all attributes in module have the same
+        # initialization and nothing more is needed.
+        if override is not None:
+            # The override targets whole attributes, so the layer filter of
+            # the base config must not be carried over.
+            cp_cfg.pop('layer', None)
+            _initialize_override(module, override, cp_cfg)
+def _no_grad_trunc_normal_(tensor: Tensor, mean: float, std: float, a: float,
+                           b: float) -> Tensor:
+    """Fill ``tensor`` in place with truncated normal values, without
+    recording gradients.
+    Args:
+        tensor (Tensor): Tensor to fill.
+        mean (float): the mean of the normal distribution.
+        std (float): the standard deviation of the normal distribution.
+        a (float): the minimum cutoff value.
+        b (float): the maximum cutoff value.
+    Returns:
+        Tensor: The same ``tensor`` object.
+    """
+    # Method based on
+    # https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+    # Modified from
+    # https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
+    sqrt_two = math.sqrt(2.)
+    def _standard_normal_cdf(x):
+        # Computes standard normal cumulative distribution function
+        return (1. + math.erf(x / sqrt_two)) / 2.
+    mean_too_low = mean < a - 2 * std
+    mean_too_high = mean > b + 2 * std
+    if mean_too_low or mean_too_high:
+        warnings.warn(
+            'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
+            'The distribution of values may be incorrect.',
+            stacklevel=2)
+    with torch.no_grad():
+        # Values are generated by using a truncated uniform distribution and
+        # then using the inverse CDF for the normal distribution.
+        # Get upper and lower cdf values
+        cdf_at_a = _standard_normal_cdf((a - mean) / std)
+        cdf_at_b = _standard_normal_cdf((b - mean) / std)
+        # Uniformly fill tensor with values from [lower, upper], then
+        # translate to [2lower-1, 2upper-1].
+        tensor.uniform_(2 * cdf_at_a - 1, 2 * cdf_at_b - 1)
+        # Use inverse cdf transform for normal distribution to get truncated
+        # standard normal
+        tensor.erfinv_()
+        # Transform to proper mean, std
+        tensor.mul_(std * sqrt_two)
+        tensor.add_(mean)
+        # Clamp to ensure it's in the proper range
+        tensor.clamp_(min=a, max=b)
+    return tensor
+def trunc_normal_(tensor: Tensor,
+                  mean: float = 0.,
+                  std: float = 1.,
+                  a: float = -2.,
+                  b: float = 2.) -> Tensor:
+    r"""Fill the input Tensor with values drawn from a truncated normal
+    distribution.
+    The values follow the normal distribution
+    :math:`\mathcal{N}(\text{mean}, \text{std}^2)` restricted to
+    :math:`[a, b]`. The method used for generating the random values works
+    best when :math:`a \leq \text{mean} \leq b`.
+    Modified from
+    https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
+    Args:
+        tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`.
+        mean (float): the mean of the normal distribution.
+        std (float): the standard deviation of the normal distribution.
+        a (float): the minimum cutoff value.
+        b (float): the maximum cutoff value.
+    Returns:
+        Tensor: The input ``tensor``, filled in place.
+    """
+    filled = _no_grad_trunc_normal_(tensor, mean, std, a, b)
+    return filled
diff --git a/ControlNet/annotator/uniformer/mmcv/cnn/vgg.py b/ControlNet/annotator/uniformer/mmcv/cnn/vgg.py
new file mode 100644
index 0000000000000000000000000000000000000000..8778b649561a45a9652b1a15a26c2d171e58f3e1
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/cnn/vgg.py
@@ -0,0 +1,175 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+import torch.nn as nn
+from .utils import constant_init, kaiming_init, normal_init
def conv3x3(in_planes, out_planes, dilation=1):
    """Build a 3x3 ``nn.Conv2d`` whose padding equals its dilation.

    Args:
        in_planes (int): Number of input channels.
        out_planes (int): Number of output channels.
        dilation (int): Dilation of the kernel; also used as the padding.

    Returns:
        nn.Conv2d: The constructed convolution layer.
    """
    conv_kwargs = dict(kernel_size=3, padding=dilation, dilation=dilation)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
def make_vgg_layer(inplanes,
                   planes,
                   num_blocks,
                   dilation=1,
                   with_bn=False,
                   ceil_mode=False):
    """Assemble the modules of one VGG stage as a flat list.

    Each of the ``num_blocks`` blocks is ``conv3x3`` followed by an optional
    ``BatchNorm2d`` and an in-place ``ReLU``; the stage is closed by a
    2x2/stride-2 ``MaxPool2d``.

    Args:
        inplanes (int): Input channels of the first convolution.
        planes (int): Output channels of every convolution in the stage.
        num_blocks (int): Number of conv blocks in the stage.
        dilation (int): Dilation forwarded to ``conv3x3``.
        with_bn (bool): Whether to insert ``BatchNorm2d`` after each conv.
        ceil_mode (bool): ``ceil_mode`` of the closing ``MaxPool2d``.

    Returns:
        list[nn.Module]: The stage's modules in execution order.
    """
    modules = []
    in_channels = inplanes
    for _ in range(num_blocks):
        block = [conv3x3(in_channels, planes, dilation)]
        if with_bn:
            block.append(nn.BatchNorm2d(planes))
        block.append(nn.ReLU(inplace=True))
        modules += block
        # every block after the first consumes the stage's output width
        in_channels = planes
    pool = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=ceil_mode)
    modules.append(pool)
    return modules
class VGG(nn.Module):
    """VGG backbone.

    Args:
        depth (int): Depth of vgg, from {11, 13, 16, 19}.
        with_bn (bool): Use BatchNorm or not.
        num_classes (int): number of classes for classification. A
            classifier head is only built when this is positive.
        num_stages (int): VGG stages, normally 5.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int]): Output from which stages.
        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
            not freezing any parameters.
        bn_eval (bool): Whether to set BN layers as eval mode, namely, freeze
            running stats (mean and var).
        bn_frozen (bool): Whether to freeze weight and bias of BN layers.
        ceil_mode (bool): Forwarded to ``make_vgg_layer`` for every stage.
        with_last_pool (bool): If False, the last module of the final stage
            is dropped from the feature extractor.
    """

    # number of conv blocks per stage, keyed by network depth
    arch_settings = {
        11: (1, 1, 2, 2, 2),
        13: (2, 2, 2, 2, 2),
        16: (2, 2, 3, 3, 3),
        19: (2, 2, 4, 4, 4)
    }

    def __init__(self,
                 depth,
                 with_bn=False,
                 num_classes=-1,
                 num_stages=5,
                 dilations=(1, 1, 1, 1, 1),
                 out_indices=(0, 1, 2, 3, 4),
                 frozen_stages=-1,
                 bn_eval=True,
                 bn_frozen=False,
                 ceil_mode=False,
                 with_last_pool=True):
        super().__init__()
        if depth not in self.arch_settings:
            raise KeyError(f'invalid depth {depth} for vgg')
        assert 1 <= num_stages <= 5
        self.stage_blocks = self.arch_settings[depth][:num_stages]
        assert len(dilations) == num_stages
        assert max(out_indices) <= num_stages

        self.num_classes = num_classes
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.bn_eval = bn_eval
        self.bn_frozen = bn_frozen

        self.inplanes = 3
        self.range_sub_modules = []
        feature_modules = []
        cursor = 0
        for stage_idx, num_blocks in enumerate(self.stage_blocks):
            # conv + relu (+ bn) per block, plus the stage's pooling module
            stage_len = num_blocks * (2 + with_bn) + 1
            stage_end = cursor + stage_len
            # channel width doubles per stage and is capped at 512
            planes = min(64 * 2**stage_idx, 512)
            feature_modules.extend(
                make_vgg_layer(
                    self.inplanes,
                    planes,
                    num_blocks,
                    dilation=dilations[stage_idx],
                    with_bn=with_bn,
                    ceil_mode=ceil_mode))
            self.inplanes = planes
            self.range_sub_modules.append([cursor, stage_end])
            cursor = stage_end

        if not with_last_pool:
            # drop the trailing module and shrink the last stage's range
            feature_modules.pop(-1)
            self.range_sub_modules[-1][1] -= 1

        self.module_name = 'features'
        self.add_module(self.module_name, nn.Sequential(*feature_modules))

        if self.num_classes > 0:
            head = [
                nn.Linear(512 * 7 * 7, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, num_classes),
            ]
            self.classifier = nn.Sequential(*head)

    def init_weights(self, pretrained=None):
        """Initialise weights from a checkpoint path or from scratch.

        Args:
            pretrained (str | None): Checkpoint to load (non-strict) via
                ``load_checkpoint``. When None, conv layers use
                ``kaiming_init``, BN layers ``constant_init(m, 1)`` and
                linear layers ``normal_init(m, std=0.01)``.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if pretrained is None:
            for module in self.modules():
                if isinstance(module, nn.Conv2d):
                    kaiming_init(module)
                elif isinstance(module, nn.BatchNorm2d):
                    constant_init(module, 1)
                elif isinstance(module, nn.Linear):
                    normal_init(module, std=0.01)
            return
        if not isinstance(pretrained, str):
            raise TypeError('pretrained must be a str or None')
        logger = logging.getLogger()
        from ..runner import load_checkpoint
        load_checkpoint(self, pretrained, strict=False, logger=logger)

    def forward(self, x):
        """Run the feature stages (and the classifier, if present).

        Returns:
            The single collected output when exactly one was collected,
            otherwise a tuple of all collected outputs (stage outputs listed
            in ``out_indices`` followed by the classifier output, if any).
        """
        collected = []
        features = getattr(self, self.module_name)
        for stage_idx in range(len(self.stage_blocks)):
            for module_idx in range(*self.range_sub_modules[stage_idx]):
                x = features[module_idx](x)
            if stage_idx in self.out_indices:
                collected.append(x)
        if self.num_classes > 0:
            flat = x.view(x.size(0), -1)
            x = self.classifier(flat)
            collected.append(x)
        return collected[0] if len(collected) == 1 else tuple(collected)

    def train(self, mode=True):
        """Switch train/eval mode while honouring the BN and freeze options."""
        super().train(mode)
        if self.bn_eval:
            bn_layers = (m for m in self.modules()
                         if isinstance(m, nn.BatchNorm2d))
            for bn in bn_layers:
                bn.eval()
                if self.bn_frozen:
                    for bn_param in bn.parameters():
                        bn_param.requires_grad = False
        features = getattr(self, self.module_name)
        if not (mode and self.frozen_stages >= 0):
            return
        # frozen stages stay in eval mode and stop receiving gradients
        for stage_idx in range(self.frozen_stages):
            for module_idx in range(*self.range_sub_modules[stage_idx]):
                frozen = features[module_idx]
                frozen.eval()
                for frozen_param in frozen.parameters():
                    frozen_param.requires_grad = False
diff --git a/ControlNet/annotator/uniformer/mmcv/engine/__init__.py b/ControlNet/annotator/uniformer/mmcv/engine/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3193b7f664e19ce2458d81c836597fa22e4bb082
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/engine/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .test import (collect_results_cpu, collect_results_gpu, multi_gpu_test,
+ single_gpu_test)
# Public API of the engine package; the list literal must be closed with
# ``]`` for the module to parse.
__all__ = [
    'collect_results_cpu', 'collect_results_gpu', 'multi_gpu_test',
    'single_gpu_test'
]
diff --git a/ControlNet/annotator/uniformer/mmcv/engine/test.py b/ControlNet/annotator/uniformer/mmcv/engine/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..8dbeef271db634ec2dadfda3bc0b5ef9c7a677ff
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/engine/test.py
@@ -0,0 +1,202 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import pickle
+import shutil
+import tempfile
+import time
+import torch
+import torch.distributed as dist
+import annotator.uniformer.mmcv as mmcv
+from annotator.uniformer.mmcv.runner import get_dist_info
def single_gpu_test(model, data_loader):
    """Test model with a single gpu.

    Puts the model in eval mode, runs it over every batch of
    ``data_loader`` without gradients and shows a progress bar.

    Args:
        model (nn.Module): Model to be tested.
        data_loader (nn.Dataloader): Pytorch data loader.

    Returns:
        list: The prediction results.
    """
    model.eval()
    progress = mmcv.ProgressBar(len(data_loader.dataset))
    predictions = []
    for batch in data_loader:
        with torch.no_grad():
            batch_result = model(return_loss=False, **batch)
        predictions.extend(batch_result)
        # Assume result has the same length of batch_size
        # refer to https://github.com/open-mmlab/mmcv/issues/985
        for _ in range(len(batch_result)):
            progress.update()
    return predictions
def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
    """Test model with multiple gpus.

    Every rank runs the model over its share of ``data_loader``; the partial
    results are then merged either through gpu communication
    (``gpu_collect=True``, see ``collect_results_gpu``) or through files in
    ``tmpdir`` (see ``collect_results_cpu``).

    Args:
        model (nn.Module): Model to be tested.
        data_loader (nn.Dataloader): Pytorch data loader.
        tmpdir (str): Path of directory to save the temporary results from
            different gpus under cpu mode.
        gpu_collect (bool): Option to use either gpu or cpu to collect results.

    Returns:
        list: The prediction results.
    """
    model.eval()
    dataset = data_loader.dataset
    rank, world_size = get_dist_info()
    is_main = rank == 0
    if is_main:
        prog_bar = mmcv.ProgressBar(len(dataset))
    time.sleep(2)  # This line can prevent deadlock problem in some cases.

    local_results = []
    for batch in data_loader:
        with torch.no_grad():
            batch_result = model(return_loss=False, **batch)
        local_results.extend(batch_result)

        if is_main:
            # rank 0 advances the bar for all ranks at once, never past the
            # dataset size
            ticks = len(batch_result) * world_size
            if ticks + prog_bar.completed > len(dataset):
                ticks = len(dataset) - prog_bar.completed
            for _ in range(ticks):
                prog_bar.update()

    # collect results from all ranks
    if gpu_collect:
        return collect_results_gpu(local_results, len(dataset))
    return collect_results_cpu(local_results, len(dataset), tmpdir)
def collect_results_cpu(result_part, size, tmpdir=None):
    """Collect results under cpu mode.

    Every rank dumps its partial results into ``tmpdir``; after a barrier,
    rank 0 loads all parts, interleaves them and removes the directory.

    Args:
        result_part (list): Result list containing result parts
            to be collected.
        size (int): Size of the results, commonly equal to length of
            the results.
        tmpdir (str | None): temporal directory for collected results to
            store. If set to None, rank 0 creates a random one under
            ``.dist_test`` and broadcasts its name to the other ranks.

    Returns:
        list | None: The collected results on rank 0, None on other ranks.
    """
    rank, world_size = get_dist_info()
    if tmpdir is not None:
        mmcv.mkdir_or_exist(tmpdir)
    else:
        # Share the directory name through a fixed-size byte tensor that is
        # pre-filled with spaces (32 is whitespace) and stripped afterwards.
        MAX_LEN = 512
        dir_tensor = torch.full((MAX_LEN, ),
                                32,
                                dtype=torch.uint8,
                                device='cuda')
        if rank == 0:
            mmcv.mkdir_or_exist('.dist_test')
            created_dir = tempfile.mkdtemp(dir='.dist_test')
            encoded_dir = torch.tensor(
                bytearray(created_dir.encode()),
                dtype=torch.uint8,
                device='cuda')
            dir_tensor[:len(encoded_dir)] = encoded_dir
        dist.broadcast(dir_tensor, 0)
        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()

    # dump the part result to the dir
    mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl'))
    dist.barrier()

    # only rank 0 assembles the final list
    if rank != 0:
        return None

    # load results of all parts from tmp dir
    non_empty_parts = []
    for part_rank in range(world_size):
        loaded = mmcv.load(osp.join(tmpdir, f'part_{part_rank}.pkl'))
        # When data is severely insufficient, an empty part_result
        # on a certain gpu could makes the overall outputs empty.
        if loaded:
            non_empty_parts.append(loaded)

    # interleave the parts sample by sample
    ordered_results = [
        item for group in zip(*non_empty_parts) for item in group
    ]
    # the dataloader may pad some samples
    ordered_results = ordered_results[:size]
    # remove tmp dir
    shutil.rmtree(tmpdir)
    return ordered_results
def collect_results_gpu(result_part, size):
    """Collect results under gpu mode.

    Each rank pickles its partial results into a uint8 cuda tensor, the
    tensors are padded to a common length and exchanged with
    ``dist.all_gather``; rank 0 then unpickles and interleaves them.

    Args:
        result_part (list): Result list containing result parts
            to be collected.
        size (int): Size of the results, commonly equal to length of
            the results.

    Returns:
        list | None: The collected results on rank 0, None on other ranks.
    """
    rank, world_size = get_dist_info()

    # dump result part to tensor with pickle
    payload = torch.tensor(
        bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')

    # gather all result part tensor shape
    local_shape = torch.tensor(payload.shape, device='cuda')
    all_shapes = [local_shape.clone() for _ in range(world_size)]
    dist.all_gather(all_shapes, local_shape)

    # padding result part tensor to max length
    longest = torch.tensor(all_shapes).max()
    padded = torch.zeros(longest, dtype=torch.uint8, device='cuda')
    padded[:local_shape[0]] = payload
    received = [payload.new_zeros(longest) for _ in range(world_size)]

    # gather all result part
    dist.all_gather(received, padded)

    if rank != 0:
        return None

    non_empty_parts = []
    for buffer, shape in zip(received, all_shapes):
        # cut the padding off before unpickling
        raw = buffer[:shape[0]].cpu().numpy().tobytes()
        part_result = pickle.loads(raw)
        # When data is severely insufficient, an empty part_result
        # on a certain gpu could makes the overall outputs empty.
        if part_result:
            non_empty_parts.append(part_result)

    # interleave the parts sample by sample
    ordered_results = [
        item for group in zip(*non_empty_parts) for item in group
    ]
    # the dataloader may pad some samples
    return ordered_results[:size]
diff --git a/ControlNet/annotator/uniformer/mmcv/fileio/__init__.py b/ControlNet/annotator/uniformer/mmcv/fileio/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2051b85f7e59bff7bdbaa131849ce8cd31f059a4
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/fileio/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .file_client import BaseStorageBackend, FileClient
+from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
+from .io import dump, load, register_handler
+from .parse import dict_from_file, list_from_file
# Public API of the fileio package; the list literal must be closed with
# ``]`` for the module to parse.
__all__ = [
    'BaseStorageBackend', 'FileClient', 'load', 'dump', 'register_handler',
    'BaseFileHandler', 'JsonHandler', 'PickleHandler', 'YamlHandler',
    'list_from_file', 'dict_from_file'
]
diff --git a/ControlNet/annotator/uniformer/mmcv/fileio/file_client.py b/ControlNet/annotator/uniformer/mmcv/fileio/file_client.py
new file mode 100644
index 0000000000000000000000000000000000000000..950f0c1aeab14b8e308a7455ccd64a95b5d98add
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/fileio/file_client.py
@@ -0,0 +1,1148 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import inspect
+import os
+import os.path as osp
+import re
+import tempfile
+import warnings
+from abc import ABCMeta, abstractmethod
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Iterable, Iterator, Optional, Tuple, Union
+from urllib.request import urlopen
+import annotator.uniformer.mmcv as mmcv
+from annotator.uniformer.mmcv.utils.misc import has_method
+from annotator.uniformer.mmcv.utils.path import is_filepath
class BaseStorageBackend(metaclass=ABCMeta):
    """Abstract base class of storage backends.

    Concrete backends must implement two apis: ``get()``, which reads a file
    as a byte stream, and ``get_text()``, which reads a file as text.
    """

    # a flag to indicate whether the backend can create a symlink for a file
    _allow_symlink = False

    @property
    def name(self):
        """str: Name of the concrete backend class."""
        backend_cls = self.__class__
        return backend_cls.__name__

    @property
    def allow_symlink(self):
        """Value of the class-level ``_allow_symlink`` flag."""
        return self._allow_symlink

    @abstractmethod
    def get(self, filepath):
        """Read ``filepath`` as bytes; must be overridden."""
        return None

    @abstractmethod
    def get_text(self, filepath):
        """Read ``filepath`` as text; must be overridden."""
        return None
class CephBackend(BaseStorageBackend):
    """Ceph storage backend (for internal use).

    Args:
        path_mapping (dict|None): path mapping dict from local path to Petrel
            path. When ``path_mapping={'src': 'dst'}``, ``src`` in ``filepath``
            will be replaced by ``dst``. Default: None.

    .. warning::
        :class:`mmcv.fileio.file_client.CephBackend` will be deprecated,
        please use :class:`mmcv.fileio.file_client.PetrelBackend` instead.
    """

    def __init__(self, path_mapping=None):
        try:
            import ceph
        except ImportError:
            raise ImportError('Please install ceph to enable CephBackend.')

        warnings.warn(
            'CephBackend will be deprecated, please use PetrelBackend instead')
        self._client = ceph.S3Client()
        assert path_mapping is None or isinstance(path_mapping, dict)
        self.path_mapping = path_mapping

    def get(self, filepath):
        """Fetch ``filepath`` (after path mapping) as a ``memoryview``."""
        mapped = str(filepath)
        mapping = self.path_mapping
        if mapping is not None:
            # plain substring replacement, applied for every mapping entry
            for src, dst in mapping.items():
                mapped = mapped.replace(src, dst)
        return memoryview(self._client.Get(mapped))

    def get_text(self, filepath, encoding=None):
        """Text reading is not supported by this backend."""
        raise NotImplementedError
class PetrelBackend(BaseStorageBackend):
    """Petrel storage backend (for internal use).

    PetrelBackend supports reading and writing data to multiple clusters.
    If the file path contains the cluster name, PetrelBackend will read data
    from specified cluster or write data to it. Otherwise, PetrelBackend will
    access the default cluster.

    Args:
        path_mapping (dict, optional): Path mapping dict from local path to
            Petrel path. When ``path_mapping={'src': 'dst'}``, ``src`` in
            ``filepath`` will be replaced by ``dst``. Default: None.
        enable_mc (bool, optional): Whether to enable memcached support.
            Default: True.

    Examples:
        >>> filepath1 = 's3://path/of/file'
        >>> filepath2 = 'cluster-name:s3://path/of/file'
        >>> client = PetrelBackend()
        >>> client.get(filepath1)  # get data from default cluster
        >>> client.get(filepath2)  # get data from 'cluster-name' cluster
    """

    def __init__(self,
                 path_mapping: Optional[dict] = None,
                 enable_mc: bool = True):
        try:
            from petrel_client import client
        except ImportError:
            raise ImportError('Please install petrel_client to enable '
                              'PetrelBackend.')

        self._client = client.Client(enable_mc=enable_mc)
        assert isinstance(path_mapping, dict) or path_mapping is None
        self.path_mapping = path_mapping

    def _map_path(self, filepath: Union[str, Path]) -> str:
        """Map ``filepath`` to a string path whose prefix will be replaced by
        :attr:`self.path_mapping`.

        Note:
            The mapping is a plain substring replacement, so it must be
            applied exactly once per path; public methods therefore map their
            argument themselves and never pass an already-mapped path to
            another public method.

        Args:
            filepath (str): Path to be mapped.
        """
        filepath = str(filepath)
        if self.path_mapping is not None:
            for k, v in self.path_mapping.items():
                filepath = filepath.replace(k, v)
        return filepath

    def _format_path(self, filepath: str) -> str:
        """Convert a ``filepath`` to standard format of petrel oss.

        If the ``filepath`` is concatenated by ``os.path.join``, in a Windows
        environment, the ``filepath`` will be the format of
        's3://bucket_name\\image.jpg'. By invoking :meth:`_format_path`, the
        above ``filepath`` will be converted to 's3://bucket_name/image.jpg'.

        Args:
            filepath (str): Path to be formatted.
        """
        return re.sub(r'\\+', '/', filepath)

    def get(self, filepath: Union[str, Path]) -> memoryview:
        """Read data from a given ``filepath`` with 'rb' mode.

        Args:
            filepath (str or Path): Path to read data.

        Returns:
            memoryview: A memory view of expected bytes object to avoid
                copying. The memoryview object can be converted to bytes by
                ``value_buf.tobytes()``.
        """
        filepath = self._map_path(filepath)
        filepath = self._format_path(filepath)
        value = self._client.Get(filepath)
        value_buf = memoryview(value)
        return value_buf

    def get_text(self,
                 filepath: Union[str, Path],
                 encoding: str = 'utf-8') -> str:
        """Read data from a given ``filepath`` with 'r' mode.

        Args:
            filepath (str or Path): Path to read data.
            encoding (str): The encoding format used to open the ``filepath``.
                Default: 'utf-8'.

        Returns:
            str: Expected text reading from ``filepath``.
        """
        return str(self.get(filepath), encoding=encoding)

    def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
        """Save data to a given ``filepath``.

        Args:
            obj (bytes): Data to be saved.
            filepath (str or Path): Path to write data.
        """
        filepath = self._map_path(filepath)
        filepath = self._format_path(filepath)
        self._client.put(filepath, obj)

    def put_text(self,
                 obj: str,
                 filepath: Union[str, Path],
                 encoding: str = 'utf-8') -> None:
        """Save data to a given ``filepath``.

        Args:
            obj (str): Data to be written.
            filepath (str or Path): Path to write data.
            encoding (str): The encoding format used to encode the ``obj``.
                Default: 'utf-8'.
        """
        self.put(bytes(obj, encoding=encoding), filepath)

    def remove(self, filepath: Union[str, Path]) -> None:
        """Remove a file.

        Args:
            filepath (str or Path): Path to be removed.

        Raises:
            NotImplementedError: If the client has no ``delete`` method.
        """
        if not has_method(self._client, 'delete'):
            raise NotImplementedError(
                ('Current version of Petrel Python SDK has not supported '
                 'the `delete` method, please use a higher version or dev'
                 ' branch instead.'))

        filepath = self._map_path(filepath)
        filepath = self._format_path(filepath)
        self._client.delete(filepath)

    def exists(self, filepath: Union[str, Path]) -> bool:
        """Check whether a file path exists.

        Args:
            filepath (str or Path): Path to be checked whether exists.

        Returns:
            bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.

        Raises:
            NotImplementedError: If the client lacks ``contains`` or
                ``isdir``.
        """
        if not (has_method(self._client, 'contains')
                and has_method(self._client, 'isdir')):
            raise NotImplementedError(
                ('Current version of Petrel Python SDK has not supported '
                 'the `contains` and `isdir` methods, please use a higher '
                 'version or dev branch instead.'))

        filepath = self._map_path(filepath)
        filepath = self._format_path(filepath)
        return self._client.contains(filepath) or self._client.isdir(filepath)

    def isdir(self, filepath: Union[str, Path]) -> bool:
        """Check whether a file path is a directory.

        Args:
            filepath (str or Path): Path to be checked whether it is a
                directory.

        Returns:
            bool: Return ``True`` if ``filepath`` points to a directory,
            ``False`` otherwise.

        Raises:
            NotImplementedError: If the client has no ``isdir`` method.
        """
        if not has_method(self._client, 'isdir'):
            raise NotImplementedError(
                ('Current version of Petrel Python SDK has not supported '
                 'the `isdir` method, please use a higher version or dev'
                 ' branch instead.'))

        filepath = self._map_path(filepath)
        filepath = self._format_path(filepath)
        return self._client.isdir(filepath)

    def isfile(self, filepath: Union[str, Path]) -> bool:
        """Check whether a file path is a file.

        Args:
            filepath (str or Path): Path to be checked whether it is a file.

        Returns:
            bool: Return ``True`` if ``filepath`` points to a file, ``False``
            otherwise.

        Raises:
            NotImplementedError: If the client has no ``contains`` method.
        """
        if not has_method(self._client, 'contains'):
            raise NotImplementedError(
                ('Current version of Petrel Python SDK has not supported '
                 'the `contains` method, please use a higher version or '
                 'dev branch instead.'))

        filepath = self._map_path(filepath)
        filepath = self._format_path(filepath)
        return self._client.contains(filepath)

    def join_path(self, filepath: Union[str, Path],
                  *filepaths: Union[str, Path]) -> str:
        """Concatenate all file paths.

        Every component is mapped and formatted before joining, so callers
        must pass unmapped paths.

        Args:
            filepath (str or Path): Path to be concatenated.

        Returns:
            str: The result after concatenation.
        """
        filepath = self._format_path(self._map_path(filepath))
        if filepath.endswith('/'):
            filepath = filepath[:-1]
        formatted_paths = [filepath]
        for path in filepaths:
            formatted_paths.append(self._format_path(self._map_path(path)))
        return '/'.join(formatted_paths)

    @contextmanager
    def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]:
        """Download a file from ``filepath`` and return a temporary path.

        ``get_local_path`` is decorated by :meth:`contextlib.contextmanager`.
        It can be called with ``with`` statement, and when exiting from the
        ``with`` statement, the temporary path will be released.

        Args:
            filepath (str | Path): Download a file from ``filepath``.

        Examples:
            >>> client = PetrelBackend()
            >>> # After exiting from the ``with`` clause,
            >>> # the path will be removed
            >>> with client.get_local_path('s3://path/of/your/file') as path:
            ...     # do something here

        Yields:
            Iterable[str]: Only yield one temporary path.
        """
        # ``isfile`` and ``get`` map/format the path themselves; handing them
        # the raw path keeps ``path_mapping`` from being applied twice.
        assert self.isfile(filepath)
        # Create the file before entering ``try`` so that ``f`` is always
        # bound when the cleanup runs.
        f = tempfile.NamedTemporaryFile(delete=False)
        try:
            try:
                f.write(self.get(filepath))
            finally:
                # close even if the download fails, so removal is safe
                f.close()
            yield f.name
        finally:
            os.remove(f.name)

    def list_dir_or_file(self,
                         dir_path: Union[str, Path],
                         list_dir: bool = True,
                         list_file: bool = True,
                         suffix: Optional[Union[str, Tuple[str]]] = None,
                         recursive: bool = False) -> Iterator[str]:
        """Scan a directory to find the interested directories or files in
        arbitrary order.

        Note:
            Petrel has no concept of directories but it simulates the directory
            hierarchy in the filesystem through public prefixes. In addition,
            if the returned path ends with '/', it means the path is a public
            prefix which is a logical directory.

        Note:
            :meth:`list_dir_or_file` returns the path relative to ``dir_path``.
            In addition, the returned path of directory will not contains the
            suffix '/' which is consistent with other backends.

        Args:
            dir_path (str | Path): Path of the directory.
            list_dir (bool): List the directories. Default: True.
            list_file (bool): List the path of files. Default: True.
            suffix (str or tuple[str], optional):  File suffix
                that we are interested in. Default: None.
            recursive (bool): If set to True, recursively scan the
                directory. Default: False.

        Yields:
            Iterable[str]: A relative path to ``dir_path``.

        Raises:
            NotImplementedError: If the client has no ``list`` method.
            TypeError: If ``suffix`` is given together with ``list_dir=True``
                or is neither a str nor a tuple.
        """
        if not has_method(self._client, 'list'):
            raise NotImplementedError(
                ('Current version of Petrel Python SDK has not supported '
                 'the `list` method, please use a higher version or dev'
                 ' branch instead.'))

        dir_path = self._map_path(dir_path)
        dir_path = self._format_path(dir_path)
        if list_dir and suffix is not None:
            raise TypeError(
                '`list_dir` should be False when `suffix` is not None')

        if (suffix is not None) and not isinstance(suffix, (str, tuple)):
            raise TypeError('`suffix` must be a string or tuple of strings')

        # Petrel's simulated directory hierarchy assumes that directory paths
        # should end with `/`
        if not dir_path.endswith('/'):
            dir_path += '/'

        root = dir_path

        def _list_dir_or_file(dir_path, list_dir, list_file, suffix,
                              recursive):
            # ``dir_path`` is already mapped and always ends with '/', so
            # children are appended directly. Going through ``join_path``
            # would re-apply ``path_mapping`` to the mapped directory and to
            # the names returned by the client.
            for path in self._client.list(dir_path):
                child_path = dir_path + self._format_path(path)
                # the `self.isdir` is not used here to determine whether path
                # is a directory, because `self.isdir` relies on
                # `self._client.list`
                if path.endswith('/'):  # a directory path
                    if list_dir:
                        # get the relative path and exclude the last
                        # character '/'
                        rel_dir = child_path[len(root):-1]
                        yield rel_dir
                    if recursive:
                        yield from _list_dir_or_file(child_path, list_dir,
                                                     list_file, suffix,
                                                     recursive)
                else:  # a file path
                    rel_path = child_path[len(root):]
                    if (suffix is None
                            or rel_path.endswith(suffix)) and list_file:
                        yield rel_path

        return _list_dir_or_file(dir_path, list_dir, list_file, suffix,
                                 recursive)
class MemcachedBackend(BaseStorageBackend):
    """Memcached storage backend.

    Attributes:
        server_list_cfg (str): Config file for memcached server list.
        client_cfg (str): Config file for memcached client.
        sys_path (str | None): Additional path to be appended to `sys.path`.
            Default: None.
    """

    def __init__(self, server_list_cfg, client_cfg, sys_path=None):
        if sys_path is not None:
            # extend the import search path before trying to import ``mc``
            import sys
            sys.path.append(sys_path)
        try:
            import mc
        except ImportError:
            raise ImportError(
                'Please install memcached to enable MemcachedBackend.')

        self.server_list_cfg = server_list_cfg
        self.client_cfg = client_cfg
        get_instance = mc.MemcachedClient.GetInstance
        self._client = get_instance(self.server_list_cfg, self.client_cfg)
        # mc.pyvector serves as a pointer to a memory cache; it is reused as
        # the receive buffer of every ``get`` call
        self._mc_buffer = mc.pyvector()

    def get(self, filepath):
        """Fetch ``filepath`` into the shared buffer and convert it."""
        key = str(filepath)
        import mc
        self._client.Get(key, self._mc_buffer)
        return mc.ConvertBuffer(self._mc_buffer)

    def get_text(self, filepath, encoding=None):
        """Text reading is not supported by this backend."""
        raise NotImplementedError
class LmdbBackend(BaseStorageBackend):
    """Lmdb storage backend.

    Args:
        db_path (str): Lmdb database path.
        readonly (bool, optional): Lmdb environment parameter. If True,
            disallow any write operations. Default: True.
        lock (bool, optional): Lmdb environment parameter. If False, when
            concurrent access occurs, do not lock the database. Default: False.
        readahead (bool, optional): Lmdb environment parameter. If False,
            disable the OS filesystem readahead mechanism, which may improve
            random read performance when a database is larger than RAM.
            Default: False.
        **kwargs: Extra keyword arguments forwarded to ``lmdb.open``.

    Attributes:
        db_path (str): Lmdb database path.
    """

    def __init__(self,
                 db_path,
                 readonly=True,
                 lock=False,
                 readahead=False,
                 **kwargs):
        try:
            import lmdb
        except ImportError:
            raise ImportError('Please install lmdb to enable LmdbBackend.')

        self.db_path = str(db_path)
        open_options = dict(readonly=readonly, lock=lock, readahead=readahead)
        self._client = lmdb.open(self.db_path, **open_options, **kwargs)

    def get(self, filepath):
        """Get values according to the filepath.

        Args:
            filepath (str | obj:`Path`): Here, filepath is the lmdb key.
        """
        key = str(filepath).encode('ascii')
        with self._client.begin(write=False) as txn:
            return txn.get(key)

    def get_text(self, filepath, encoding=None):
        """Text reading is not supported by this backend."""
        raise NotImplementedError
class HardDiskBackend(BaseStorageBackend):
    """Raw hard disks storage backend."""

    _allow_symlink = True

    def get(self, filepath: Union[str, Path]) -> bytes:
        """Read data from a given ``filepath`` with 'rb' mode.

        Args:
            filepath (str or Path): Path to read data.

        Returns:
            bytes: Expected bytes object.
        """
        with open(filepath, 'rb') as stream:
            return stream.read()

    def get_text(self,
                 filepath: Union[str, Path],
                 encoding: str = 'utf-8') -> str:
        """Read data from a given ``filepath`` with 'r' mode.

        Args:
            filepath (str or Path): Path to read data.
            encoding (str): The encoding format used to open the ``filepath``.
                Default: 'utf-8'.

        Returns:
            str: Expected text reading from ``filepath``.
        """
        with open(filepath, 'r', encoding=encoding) as stream:
            return stream.read()

    def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
        """Write data to a given ``filepath`` with 'wb' mode.

        Note:
            ``put`` first passes the parent directory of ``filepath`` to
            ``mmcv.mkdir_or_exist``.

        Args:
            obj (bytes): Data to be written.
            filepath (str or Path): Path to write data.
        """
        parent_dir = osp.dirname(filepath)
        mmcv.mkdir_or_exist(parent_dir)
        with open(filepath, 'wb') as stream:
            stream.write(obj)

    def put_text(self,
                 obj: str,
                 filepath: Union[str, Path],
                 encoding: str = 'utf-8') -> None:
        """Write data to a given ``filepath`` with 'w' mode.

        Note:
            ``put_text`` first passes the parent directory of ``filepath``
            to ``mmcv.mkdir_or_exist``.

        Args:
            obj (str): Data to be written.
            filepath (str or Path): Path to write data.
            encoding (str): The encoding format used to open the ``filepath``.
                Default: 'utf-8'.
        """
        parent_dir = osp.dirname(filepath)
        mmcv.mkdir_or_exist(parent_dir)
        with open(filepath, 'w', encoding=encoding) as stream:
            stream.write(obj)

    def remove(self, filepath: Union[str, Path]) -> None:
        """Remove a file.

        Args:
            filepath (str or Path): Path to be removed.
        """
        os.remove(filepath)

    def exists(self, filepath: Union[str, Path]) -> bool:
        """Check whether a file path exists.

        Args:
            filepath (str or Path): Path to be checked whether exists.

        Returns:
            bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
        """
        return osp.exists(filepath)

    def isdir(self, filepath: Union[str, Path]) -> bool:
        """Check whether a file path is a directory.

        Args:
            filepath (str or Path): Path to be checked whether it is a
                directory.

        Returns:
            bool: Return ``True`` if ``filepath`` points to a directory,
            ``False`` otherwise.
        """
        return osp.isdir(filepath)

    def isfile(self, filepath: Union[str, Path]) -> bool:
        """Check whether a file path is a file.

        Args:
            filepath (str or Path): Path to be checked whether it is a file.

        Returns:
            bool: Return ``True`` if ``filepath`` points to a file, ``False``
            otherwise.
        """
        return osp.isfile(filepath)

    def join_path(self, filepath: Union[str, Path],
                  *filepaths: Union[str, Path]) -> str:
        """Concatenate all file paths.

        Thin wrapper around ``os.path.join``: the return value is the
        concatenation of ``filepath`` and any members of ``*filepaths``.

        Args:
            filepath (str or Path): Path to be concatenated.

        Returns:
            str: The result of concatenation.
        """
        components = (filepath, ) + filepaths
        return osp.join(*components)

    @contextmanager
    def get_local_path(
            self, filepath: Union[str, Path]) -> Iterable[Union[str, Path]]:
        """Only for unified API and do nothing."""
        yield filepath

    def list_dir_or_file(self,
                         dir_path: Union[str, Path],
                         list_dir: bool = True,
                         list_file: bool = True,
                         suffix: Optional[Union[str, Tuple[str]]] = None,
                         recursive: bool = False) -> Iterator[str]:
        """Scan a directory to find the interested directories or files in
        arbitrary order.

        Note:
            :meth:`list_dir_or_file` returns the path relative to ``dir_path``.

        Args:
            dir_path (str | Path): Path of the directory.
            list_dir (bool): List the directories. Default: True.
            list_file (bool): List the path of files. Default: True.
            suffix (str or tuple[str], optional):  File suffix
                that we are interested in. Default: None.
            recursive (bool): If set to True, recursively scan the
                directory. Default: False.

        Yields:
            Iterable[str]: A relative path to ``dir_path``.

        Raises:
            TypeError: If ``suffix`` is given together with ``list_dir=True``
                or is neither a str nor a tuple.
        """
        if list_dir and suffix is not None:
            raise TypeError('`suffix` should be None when `list_dir` is True')

        if (suffix is not None) and not isinstance(suffix, (str, tuple)):
            raise TypeError('`suffix` must be a string or tuple of strings')

        root = dir_path

        def _scan(current_dir):
            # the filter options are read from the enclosing call
            for entry in os.scandir(current_dir):
                visible_file = (not entry.name.startswith('.')
                                and entry.is_file())
                if visible_file:
                    rel_path = osp.relpath(entry.path, root)
                    if (suffix is None
                            or rel_path.endswith(suffix)) and list_file:
                        yield rel_path
                    continue
                if not osp.isdir(entry.path):
                    continue
                if list_dir:
                    yield osp.relpath(entry.path, root)
                if recursive:
                    yield from _scan(entry.path)

        return _scan(dir_path)
+class HTTPBackend(BaseStorageBackend):
+    """HTTP and HTTPS storage backend."""
+
+    def get(self, filepath):
+        """Return the raw bytes served at URL ``filepath``."""
+        # Close the response explicitly instead of leaving the connection to
+        # the garbage collector (assumes ``urlopen`` is
+        # ``urllib.request.urlopen`` -- TODO confirm against the file imports).
+        with urlopen(filepath) as response:
+            return response.read()
+
+    def get_text(self, filepath, encoding='utf-8'):
+        """Return the content served at URL ``filepath`` decoded as text.
+
+        Args:
+            filepath (str): URL to read.
+            encoding (str): Encoding used to decode the body.
+                Default: 'utf-8'.
+        """
+        with urlopen(filepath) as response:
+            value_buf = response.read()
+        return value_buf.decode(encoding)
+
+    @contextmanager
+    def get_local_path(self, filepath: str) -> Iterable[str]:
+        """Download a file from ``filepath`` to a temporary local path.
+
+        ``get_local_path`` is decorated by :meth:`contextlib.contextmanager`.
+        It can be called with a ``with`` statement, and when exiting from the
+        ``with`` statement, the temporary path will be released.
+
+        Args:
+            filepath (str): Download a file from ``filepath``.
+
+        Examples:
+            >>> client = HTTPBackend()
+            >>> # After exiting from the ``with`` clause,
+            >>> # the path will be removed
+            >>> with client.get_local_path('http://path/of/your/file') as path:
+            ...     # do something here
+
+        Yields:
+            Iterable[str]: Only yield one path.
+        """
+        # Create the file before entering ``try`` so the cleanup below never
+        # touches an unbound ``f`` when creation itself fails.
+        f = tempfile.NamedTemporaryFile(delete=False)
+        try:
+            try:
+                f.write(self.get(filepath))
+            finally:
+                # Always release the handle, even if the download failed, so
+                # the file can be removed afterwards.
+                f.close()
+            yield f.name
+        finally:
+            os.remove(f.name)
+class FileClient:
+    """A general file client to access files in different backends.
+
+    The client loads a file or text in a specified backend from its path
+    and returns it as a binary or text file. There are two ways to choose a
+    backend: the name of the backend and the prefix of the path. Although
+    both can be used to choose a storage backend, ``backend`` has a higher
+    priority: if both are set, the storage backend is chosen by the
+    ``backend`` argument. If both are ``None``, the disk backend is chosen.
+    Other backend accessors can also be registered with a given name,
+    prefixes, and backend class. In addition, the singleton pattern is used
+    to avoid repeated object creation: if the arguments are the same, the
+    same object is returned.
+
+    Args:
+        backend (str, optional): The storage backend type. Options are "disk",
+            "ceph", "memcached", "lmdb", "http" and "petrel". Default: None.
+        prefix (str, optional): The prefix of the registered storage backend.
+            Options are "s3", "http", "https". Default: None.
+
+    Examples:
+        >>> # only set backend
+        >>> file_client = FileClient(backend='petrel')
+        >>> # only set prefix
+        >>> file_client = FileClient(prefix='s3')
+        >>> # set both backend and prefix but use backend to choose client
+        >>> file_client = FileClient(backend='petrel', prefix='s3')
+        >>> # if the arguments are the same, the same object is returned
+        >>> file_client1 = FileClient(backend='petrel')
+        >>> file_client1 is file_client
+        True
+
+    Attributes:
+        client (:obj:`BaseStorageBackend`): The backend object.
+    """
+
+    _backends = {
+        'disk': HardDiskBackend,
+        'ceph': CephBackend,
+        'memcached': MemcachedBackend,
+        'lmdb': LmdbBackend,
+        'petrel': PetrelBackend,
+        'http': HTTPBackend,
+    }
+    # This collection is used to record the overridden backends, and when a
+    # backend appears in the collection, the singleton pattern is disabled for
+    # that backend, because if the singleton pattern is used, then the object
+    # returned will be the backend before overwriting
+    _overridden_backends = set()
+    _prefix_to_backends = {
+        's3': PetrelBackend,
+        'http': HTTPBackend,
+        'https': HTTPBackend,
+    }
+    _overridden_prefixes = set()
+
+    # Cache of created clients, keyed by the stringified constructor arguments.
+    _instances = {}
+
+    def __new__(cls, backend=None, prefix=None, **kwargs):
+        if backend is None and prefix is None:
+            backend = 'disk'
+        if backend is not None and backend not in cls._backends:
+            raise ValueError(
+                f'Backend {backend} is not supported. Currently supported ones'
+                f' are {list(cls._backends.keys())}')
+        if prefix is not None and prefix not in cls._prefix_to_backends:
+            raise ValueError(
+                f'prefix {prefix} is not supported. Currently supported ones '
+                f'are {list(cls._prefix_to_backends.keys())}')
+
+        # concatenate the arguments to a unique key for determining whether
+        # objects with the same arguments were created
+        arg_key = f'{backend}:{prefix}'
+        for key, value in kwargs.items():
+            arg_key += f':{key}:{value}'
+
+        # if a backend was overridden, it will create a new object
+        if (arg_key in cls._instances
+                and backend not in cls._overridden_backends
+                and prefix not in cls._overridden_prefixes):
+            _instance = cls._instances[arg_key]
+        else:
+            # create a new object and put it to _instance
+            _instance = super().__new__(cls)
+            if backend is not None:
+                _instance.client = cls._backends[backend](**kwargs)
+            else:
+                _instance.client = cls._prefix_to_backends[prefix](**kwargs)
+
+            cls._instances[arg_key] = _instance
+
+        return _instance
+
+    @property
+    def name(self):
+        """Name reported by the wrapped backend object."""
+        return self.client.name
+
+    @property
+    def allow_symlink(self):
+        """``allow_symlink`` flag reported by the wrapped backend object."""
+        return self.client.allow_symlink
+
+    @staticmethod
+    def parse_uri_prefix(uri: Union[str, Path]) -> Optional[str]:
+        """Parse the prefix of a uri.
+
+        Args:
+            uri (str | Path): Uri to be parsed that contains the file prefix.
+
+        Examples:
+            >>> FileClient.parse_uri_prefix('s3://path/of/your/file')
+            's3'
+
+        Returns:
+            str | None: Return the prefix of uri if the uri contains '://'
+            else ``None``.
+        """
+        assert is_filepath(uri)
+        uri = str(uri)
+        if '://' not in uri:
+            return None
+        # Split on the first '://' only: a second occurrence later in the
+        # uri (e.g. inside a query string) must not break the unpacking.
+        prefix, _ = uri.split('://', 1)
+        # In the case of PetrelBackend, the prefix may contain the cluster
+        # name like clusterName:s3; keep only the part after the last ':'.
+        if ':' in prefix:
+            prefix = prefix.rsplit(':', 1)[-1]
+        return prefix
+
+    @classmethod
+    def infer_client(cls,
+                     file_client_args: Optional[dict] = None,
+                     uri: Optional[Union[str, Path]] = None) -> 'FileClient':
+        """Infer a suitable file client based on the URI and arguments.
+
+        ``file_client_args`` takes precedence; ``uri`` is only inspected when
+        ``file_client_args`` is None.
+
+        Args:
+            file_client_args (dict, optional): Arguments to instantiate a
+                FileClient. Default: None.
+            uri (str | Path, optional): Uri to be parsed that contains the file
+                prefix. Default: None.
+
+        Examples:
+            >>> uri = 's3://path/of/your/file'
+            >>> file_client = FileClient.infer_client(uri=uri)
+            >>> file_client_args = {'backend': 'petrel'}
+            >>> file_client = FileClient.infer_client(file_client_args)
+
+        Returns:
+            FileClient: Instantiated FileClient object.
+        """
+        assert file_client_args is not None or uri is not None
+        if file_client_args is None:
+            file_prefix = cls.parse_uri_prefix(uri)  # type: ignore
+            return cls(prefix=file_prefix)
+        else:
+            return cls(**file_client_args)
+
+    @classmethod
+    def _register_backend(cls, name, backend, force=False, prefixes=None):
+        """Validate and store ``backend`` under ``name`` (and ``prefixes``).
+
+        Raises:
+            TypeError: If ``name`` is not a string or ``backend`` is not a
+                subclass of ``BaseStorageBackend``.
+            KeyError: If the name or a prefix is already registered and
+                ``force`` is False.
+        """
+        if not isinstance(name, str):
+            raise TypeError('the backend name should be a string, '
+                            f'but got {type(name)}')
+        if not inspect.isclass(backend):
+            raise TypeError(
+                f'backend should be a class but got {type(backend)}')
+        if not issubclass(backend, BaseStorageBackend):
+            raise TypeError(
+                f'backend {backend} is not a subclass of BaseStorageBackend')
+        if not force and name in cls._backends:
+            raise KeyError(
+                f'{name} is already registered as a storage backend, '
+                'add "force=True" if you want to override it')
+
+        if name in cls._backends and force:
+            cls._overridden_backends.add(name)
+        cls._backends[name] = backend
+
+        if prefixes is not None:
+            if isinstance(prefixes, str):
+                prefixes = [prefixes]
+            else:
+                assert isinstance(prefixes, (list, tuple))
+            for prefix in prefixes:
+                if prefix not in cls._prefix_to_backends:
+                    cls._prefix_to_backends[prefix] = backend
+                elif (prefix in cls._prefix_to_backends) and force:
+                    cls._overridden_prefixes.add(prefix)
+                    cls._prefix_to_backends[prefix] = backend
+                else:
+                    raise KeyError(
+                        f'{prefix} is already registered as a storage backend,'
+                        ' add "force=True" if you want to override it')
+
+    @classmethod
+    def register_backend(cls, name, backend=None, force=False, prefixes=None):
+        """Register a backend to FileClient.
+
+        This method can be used as a normal class method or a decorator.
+
+        .. code-block:: python
+
+            class NewBackend(BaseStorageBackend):
+
+                def get(self, filepath):
+                    return filepath
+
+                def get_text(self, filepath):
+                    return filepath
+
+            FileClient.register_backend('new', NewBackend)
+
+        or
+
+        .. code-block:: python
+
+            @FileClient.register_backend('new')
+            class NewBackend(BaseStorageBackend):
+
+                def get(self, filepath):
+                    return filepath
+
+                def get_text(self, filepath):
+                    return filepath
+
+        Args:
+            name (str): The name of the registered backend.
+            backend (class, optional): The backend class to be registered,
+                which must be a subclass of :class:`BaseStorageBackend`.
+                When this method is used as a decorator, backend is None.
+                Defaults to None.
+            force (bool, optional): Whether to override the backend if the name
+                has already been registered. Defaults to False.
+            prefixes (str or list[str] or tuple[str], optional): The prefixes
+                of the registered storage backend. Default: None.
+                `New in version 1.3.15.`
+        """
+        if backend is not None:
+            cls._register_backend(
+                name, backend, force=force, prefixes=prefixes)
+            return
+
+        def _register(backend_cls):
+            cls._register_backend(
+                name, backend_cls, force=force, prefixes=prefixes)
+            return backend_cls
+
+        return _register
+
+    def get(self, filepath: Union[str, Path]) -> Union[bytes, memoryview]:
+        """Read data from a given ``filepath`` with 'rb' mode.
+
+        Note:
+            There are two types of return values for ``get``, one is ``bytes``
+            and the other is ``memoryview``. The advantage of using memoryview
+            is that you can avoid copying, and if you want to convert it to
+            ``bytes``, you can use ``.tobytes()``.
+
+        Args:
+            filepath (str or Path): Path to read data.
+
+        Returns:
+            bytes | memoryview: Expected bytes object or a memory view of the
+            bytes object.
+        """
+        return self.client.get(filepath)
+
+    def get_text(self, filepath: Union[str, Path], encoding='utf-8') -> str:
+        """Read data from a given ``filepath`` with 'r' mode.
+
+        Args:
+            filepath (str or Path): Path to read data.
+            encoding (str): The encoding format used to open the ``filepath``.
+                Default: 'utf-8'.
+
+        Returns:
+            str: Expected text reading from ``filepath``.
+        """
+        return self.client.get_text(filepath, encoding)
+
+    def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
+        """Write data to a given ``filepath`` with 'wb' mode.
+
+        Note:
+            ``put`` should create a directory if the directory of ``filepath``
+            does not exist.
+
+        Args:
+            obj (bytes): Data to be written.
+            filepath (str or Path): Path to write data.
+        """
+        self.client.put(obj, filepath)
+
+    def put_text(self, obj: str, filepath: Union[str, Path]) -> None:
+        """Write data to a given ``filepath`` with 'w' mode.
+
+        Note:
+            ``put_text`` should create a directory if the directory of
+            ``filepath`` does not exist. No encoding is passed on, so the
+            backend's own default applies.
+
+        Args:
+            obj (str): Data to be written.
+            filepath (str or Path): Path to write data.
+        """
+        self.client.put_text(obj, filepath)
+
+    def remove(self, filepath: Union[str, Path]) -> None:
+        """Remove a file.
+
+        Args:
+            filepath (str, Path): Path to be removed.
+        """
+        self.client.remove(filepath)
+
+    def exists(self, filepath: Union[str, Path]) -> bool:
+        """Check whether a file path exists.
+
+        Args:
+            filepath (str or Path): Path to be checked whether exists.
+
+        Returns:
+            bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
+        """
+        return self.client.exists(filepath)
+
+    def isdir(self, filepath: Union[str, Path]) -> bool:
+        """Check whether a file path is a directory.
+
+        Args:
+            filepath (str or Path): Path to be checked whether it is a
+                directory.
+
+        Returns:
+            bool: Return ``True`` if ``filepath`` points to a directory,
+            ``False`` otherwise.
+        """
+        return self.client.isdir(filepath)
+
+    def isfile(self, filepath: Union[str, Path]) -> bool:
+        """Check whether a file path is a file.
+
+        Args:
+            filepath (str or Path): Path to be checked whether it is a file.
+
+        Returns:
+            bool: Return ``True`` if ``filepath`` points to a file, ``False``
+            otherwise.
+        """
+        return self.client.isfile(filepath)
+
+    def join_path(self, filepath: Union[str, Path],
+                  *filepaths: Union[str, Path]) -> str:
+        """Concatenate all file paths.
+
+        Join one or more filepath components intelligently. The return value
+        is the concatenation of filepath and any members of *filepaths.
+
+        Args:
+            filepath (str or Path): Path to be concatenated.
+
+        Returns:
+            str: The result of concatenation.
+        """
+        return self.client.join_path(filepath, *filepaths)
+
+    @contextmanager
+    def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]:
+        """Download data from ``filepath`` and write the data to local path.
+
+        ``get_local_path`` is decorated by :meth:`contextlib.contextmanager`.
+        It can be called with a ``with`` statement, and when exiting from the
+        ``with`` statement, the temporary path will be released.
+
+        Note:
+            If the ``filepath`` is a local path, just return itself.
+
+        .. warning::
+            ``get_local_path`` is an experimental interface that may change in
+            the future.
+
+        Args:
+            filepath (str or Path): Path to be read data.
+
+        Examples:
+            >>> file_client = FileClient(prefix='s3')
+            >>> with file_client.get_local_path('s3://bucket/abc.jpg') as path:
+            ...     # do something here
+
+        Yields:
+            Iterable[str]: Only yield one path.
+        """
+        with self.client.get_local_path(str(filepath)) as local_path:
+            yield local_path
+
+    def list_dir_or_file(self,
+                         dir_path: Union[str, Path],
+                         list_dir: bool = True,
+                         list_file: bool = True,
+                         suffix: Optional[Union[str, Tuple[str]]] = None,
+                         recursive: bool = False) -> Iterator[str]:
+        """Scan a directory to find the interested directories or files in
+        arbitrary order.
+
+        Note:
+            :meth:`list_dir_or_file` returns the path relative to ``dir_path``.
+
+        Args:
+            dir_path (str | Path): Path of the directory.
+            list_dir (bool): List the directories. Default: True.
+            list_file (bool): List the path of files. Default: True.
+            suffix (str or tuple[str], optional): File suffix
+                that we are interested in. Default: None.
+            recursive (bool): If set to True, recursively scan the
+                directory. Default: False.
+
+        Yields:
+            Iterable[str]: A relative path to ``dir_path``.
+        """
+        yield from self.client.list_dir_or_file(dir_path, list_dir, list_file,
+                                                suffix, recursive)
diff --git a/ControlNet/annotator/uniformer/mmcv/fileio/handlers/__init__.py b/ControlNet/annotator/uniformer/mmcv/fileio/handlers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa24d91972837b8756b225f4879bac20436eb72a
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/fileio/handlers/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base import BaseFileHandler
+from .json_handler import JsonHandler
+from .pickle_handler import PickleHandler
+from .yaml_handler import YamlHandler
+__all__ = ['BaseFileHandler', 'JsonHandler', 'PickleHandler', 'YamlHandler']
diff --git a/ControlNet/annotator/uniformer/mmcv/fileio/handlers/base.py b/ControlNet/annotator/uniformer/mmcv/fileio/handlers/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..288878bc57282fbb2f12b32290152ca8e9d3cab0
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/fileio/handlers/base.py
@@ -0,0 +1,30 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta, abstractmethod
+class BaseFileHandler(metaclass=ABCMeta):
+    """Abstract interface for (de)serializing objects to file objects."""
+
+    # `str_like` tells callers whether this handler works on str-like file
+    # objects (True) or bytes-like ones (False). Pickle only processes
+    # bytes-like objects but json only processes str-like objects. For
+    # str-like handlers, `StringIO` is used to process the buffer.
+    str_like = True
+
+    @abstractmethod
+    def load_from_fileobj(self, file, **kwargs):
+        """Deserialize and return an object read from ``file``."""
+
+    @abstractmethod
+    def dump_to_fileobj(self, obj, file, **kwargs):
+        """Serialize ``obj`` into ``file``."""
+
+    @abstractmethod
+    def dump_to_str(self, obj, **kwargs):
+        """Serialize ``obj`` and return the result."""
+
+    def load_from_path(self, filepath, mode='r', **kwargs):
+        """Open ``filepath`` with ``mode`` and load from the opened file."""
+        with open(filepath, mode) as stream:
+            loaded = self.load_from_fileobj(stream, **kwargs)
+        return loaded
+
+    def dump_to_path(self, obj, filepath, mode='w', **kwargs):
+        """Open ``filepath`` with ``mode`` and dump ``obj`` into it."""
+        with open(filepath, mode) as stream:
+            self.dump_to_fileobj(obj, stream, **kwargs)
diff --git a/ControlNet/annotator/uniformer/mmcv/fileio/handlers/json_handler.py b/ControlNet/annotator/uniformer/mmcv/fileio/handlers/json_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..18d4f15f74139d20adff18b20be5529c592a66b6
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/fileio/handlers/json_handler.py
@@ -0,0 +1,36 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import json
+import numpy as np
+from .base import BaseFileHandler
+def set_default(obj):
+    """Fallback converter for values ``json`` cannot serialize natively.
+
+    ``set`` and ``range`` become lists, ``np.ndarray`` becomes a (nested)
+    list, and ``np.generic`` scalars (``np.int32``, ``np.float32``, etc.)
+    become plain Python built-in numbers.
+
+    Raises:
+        TypeError: For any other type.
+    """
+    if isinstance(obj, (set, range)):
+        return list(obj)
+    if isinstance(obj, np.ndarray):
+        return obj.tolist()
+    if isinstance(obj, np.generic):
+        return obj.item()
+    raise TypeError(f'{type(obj)} is unsupported for json dump')
+class JsonHandler(BaseFileHandler):
+    """File handler that (de)serializes objects with the ``json`` module."""
+
+    def load_from_fileobj(self, file, **kwargs):
+        """Parse JSON from ``file``.
+
+        ``**kwargs`` is forwarded to ``json.load``. It is accepted here so
+        the signature matches ``BaseFileHandler.load_from_fileobj`` and the
+        other handlers; without it any keyword argument raised a TypeError.
+        """
+        return json.load(file, **kwargs)
+
+    def dump_to_fileobj(self, obj, file, **kwargs):
+        """Write ``obj`` as JSON into ``file``.
+
+        ``set_default`` is used as the ``default`` hook unless the caller
+        supplies their own.
+        """
+        kwargs.setdefault('default', set_default)
+        json.dump(obj, file, **kwargs)
+
+    def dump_to_str(self, obj, **kwargs):
+        """Return ``obj`` serialized as a JSON string."""
+        kwargs.setdefault('default', set_default)
+        return json.dumps(obj, **kwargs)
diff --git a/ControlNet/annotator/uniformer/mmcv/fileio/handlers/pickle_handler.py b/ControlNet/annotator/uniformer/mmcv/fileio/handlers/pickle_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..b37c79bed4ef9fd8913715e62dbe3fc5cafdc3aa
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/fileio/handlers/pickle_handler.py
@@ -0,0 +1,28 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import pickle
+from .base import BaseFileHandler
+class PickleHandler(BaseFileHandler):
+    """File handler backed by ``pickle``; works on bytes-like file objects.
+
+    NOTE(review): unpickling can execute arbitrary code -- only load data
+    from trusted sources.
+    """
+
+    # pickle reads and writes bytes, so paths are opened in binary mode below.
+    str_like = False
+
+    def load_from_fileobj(self, file, **kwargs):
+        """Unpickle one object from ``file``."""
+        return pickle.load(file, **kwargs)
+
+    def load_from_path(self, filepath, **kwargs):
+        """Unpickle one object from ``filepath`` (opened as 'rb')."""
+        return super().load_from_path(filepath, mode='rb', **kwargs)
+
+    def dump_to_str(self, obj, **kwargs):
+        """Return the pickled form of ``obj`` (protocol 2 unless overridden)."""
+        if 'protocol' not in kwargs:
+            kwargs['protocol'] = 2
+        return pickle.dumps(obj, **kwargs)
+
+    def dump_to_fileobj(self, obj, file, **kwargs):
+        """Pickle ``obj`` into ``file`` (protocol 2 unless overridden)."""
+        if 'protocol' not in kwargs:
+            kwargs['protocol'] = 2
+        pickle.dump(obj, file, **kwargs)
+
+    def dump_to_path(self, obj, filepath, **kwargs):
+        """Pickle ``obj`` into ``filepath`` (opened as 'wb')."""
+        super().dump_to_path(obj, filepath, mode='wb', **kwargs)
diff --git a/ControlNet/annotator/uniformer/mmcv/fileio/handlers/yaml_handler.py b/ControlNet/annotator/uniformer/mmcv/fileio/handlers/yaml_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5aa2eea1e8c76f8baf753d1c8c959dee665e543
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/fileio/handlers/yaml_handler.py
@@ -0,0 +1,24 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import yaml
+try:
+    # Prefer the C-accelerated loader/dumper when PyYAML provides them.
+    from yaml import CLoader as Loader, CDumper as Dumper
+except ImportError:
+    # Fall back to the pure-Python implementations.
+    from yaml import Loader, Dumper
+from .base import BaseFileHandler # isort:skip
+class YamlHandler(BaseFileHandler):
+    """File handler backed by ``yaml``.
+
+    NOTE(review): ``yaml.load`` with the full ``Loader`` can construct
+    arbitrary Python objects -- pass a safe loader via ``Loader=`` when the
+    input is untrusted.
+    """
+
+    def load_from_fileobj(self, file, **kwargs):
+        """Parse YAML from ``file`` (module ``Loader`` unless overridden)."""
+        if 'Loader' not in kwargs:
+            kwargs['Loader'] = Loader
+        return yaml.load(file, **kwargs)
+
+    def dump_to_fileobj(self, obj, file, **kwargs):
+        """Write ``obj`` as YAML into ``file`` (module ``Dumper`` by default)."""
+        if 'Dumper' not in kwargs:
+            kwargs['Dumper'] = Dumper
+        yaml.dump(obj, file, **kwargs)
+
+    def dump_to_str(self, obj, **kwargs):
+        """Return ``obj`` serialized as a YAML string."""
+        if 'Dumper' not in kwargs:
+            kwargs['Dumper'] = Dumper
+        return yaml.dump(obj, **kwargs)
diff --git a/ControlNet/annotator/uniformer/mmcv/fileio/io.py b/ControlNet/annotator/uniformer/mmcv/fileio/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..aaefde58aa3ea5b58f86249ce7e1c40c186eb8dd
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/fileio/io.py
@@ -0,0 +1,151 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from io import BytesIO, StringIO
+from pathlib import Path
+from ..utils import is_list_of, is_str
+from .file_client import FileClient
+from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
+# Registry mapping a file format / extension to the handler instance that
+# (de)serializes it. Extended at runtime by ``_register_handler``.
+file_handlers = {
+    'json': JsonHandler(),
+    'yaml': YamlHandler(),
+    'yml': YamlHandler(),
+    'pickle': PickleHandler(),
+    'pkl': PickleHandler()
+}
+def load(file, file_format=None, file_client_args=None, **kwargs):
+    """Load data from json/yaml/pickle files.
+
+    This method provides a unified api for loading data from serialized files.
+
+    Note:
+        In v1.3.16 and later, ``load`` supports loading data from serialized
+        files that can be stored in different backends.
+
+    Args:
+        file (str or :obj:`Path` or file-like object): Filename or a file-like
+            object.
+        file_format (str, optional): If not specified, the file format will be
+            inferred from the file extension, otherwise use the specified one.
+            Currently supported formats include "json", "yaml/yml" and
+            "pickle/pkl".
+        file_client_args (dict, optional): Arguments to instantiate a
+            FileClient. See :class:`mmcv.fileio.FileClient` for details.
+            Default: None.
+
+    Examples:
+        >>> load('/path/of/your/file')  # file is stored on disk
+        >>> load('https://path/of/your/file')  # file is stored on Internet
+        >>> load('s3://path/of/your/file')  # file is stored in petrel
+
+    Returns:
+        The content from the file.
+
+    Raises:
+        TypeError: If the format is unsupported, or ``file`` is neither a
+            path nor an object with a ``read`` attribute.
+    """
+    if isinstance(file, Path):
+        file = str(file)
+    path_given = is_str(file)
+    if file_format is None and path_given:
+        # Text after the last '.', or the whole string when there is none.
+        file_format = file.rsplit('.', 1)[-1]
+    if file_format not in file_handlers:
+        raise TypeError(f'Unsupported format: {file_format}')
+    handler = file_handlers[file_format]
+
+    if not is_str(file):
+        if hasattr(file, 'read'):
+            return handler.load_from_fileobj(file, **kwargs)
+        raise TypeError('"file" must be a filepath str or a file-object')
+
+    file_client = FileClient.infer_client(file_client_args, file)
+    # Fetch the whole content through the backend, then hand the handler an
+    # in-memory buffer of the kind (text or bytes) it expects.
+    if handler.str_like:
+        buffer = StringIO(file_client.get_text(file))
+    else:
+        buffer = BytesIO(file_client.get(file))
+    with buffer as f:
+        return handler.load_from_fileobj(f, **kwargs)
+def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs):
+    """Dump data to json/yaml/pickle strings or files.
+
+    This method provides a unified api for dumping data as strings or to files,
+    and also supports custom arguments for each file format.
+
+    Note:
+        In v1.3.16 and later, ``dump`` supports dumping data as strings or to
+        files which are saved to different backends.
+
+    Args:
+        obj (any): The python object to be dumped.
+        file (str or :obj:`Path` or file-like object, optional): If not
+            specified, then the object is dumped to a str, otherwise to a file
+            specified by the filename or file-like object.
+        file_format (str, optional): Same as :func:`load`.
+        file_client_args (dict, optional): Arguments to instantiate a
+            FileClient. See :class:`mmcv.fileio.FileClient` for details.
+            Default: None.
+
+    Examples:
+        >>> dump('hello world', '/path/of/your/file')  # disk
+        >>> dump('hello world', 's3://path/of/your/file')  # ceph or petrel
+
+    Returns:
+        The serialized result of ``handler.dump_to_str`` when ``file`` is
+        None; otherwise ``None`` (the data is written to ``file``).
+
+    Raises:
+        ValueError: If both ``file`` and ``file_format`` are None.
+        TypeError: If the format is unsupported, or ``file`` is neither a
+            path nor an object with a ``write`` attribute.
+    """
+    if isinstance(file, Path):
+        file = str(file)
+    if file_format is None:
+        if is_str(file):
+            file_format = file.split('.')[-1]
+        elif file is None:
+            raise ValueError(
+                'file_format must be specified since file is None')
+    if file_format not in file_handlers:
+        raise TypeError(f'Unsupported format: {file_format}')
+
+    handler = file_handlers[file_format]
+    if file is None:
+        return handler.dump_to_str(obj, **kwargs)
+    elif is_str(file):
+        file_client = FileClient.infer_client(file_client_args, file)
+        # Serialize into an in-memory buffer first, then push the complete
+        # payload through the backend in one call.
+        if handler.str_like:
+            with StringIO() as f:
+                handler.dump_to_fileobj(obj, f, **kwargs)
+                file_client.put_text(f.getvalue(), file)
+        else:
+            with BytesIO() as f:
+                handler.dump_to_fileobj(obj, f, **kwargs)
+                file_client.put(f.getvalue(), file)
+    elif hasattr(file, 'write'):
+        handler.dump_to_fileobj(obj, file, **kwargs)
+    else:
+        raise TypeError('"file" must be a filename str or a file-object')
+def _register_handler(handler, file_formats):
+    """Register a handler for some file extensions.
+
+    Args:
+        handler (:obj:`BaseFileHandler`): Handler instance to register.
+        file_formats (str or list[str]): File formats to be handled by this
+            handler. An already registered format is overwritten.
+
+    Raises:
+        TypeError: If ``handler`` is not a ``BaseFileHandler`` instance or
+            ``file_formats`` is not a str or a list of str.
+    """
+    if not isinstance(handler, BaseFileHandler):
+        raise TypeError(
+            f'handler must be a child of BaseFileHandler, not {type(handler)}')
+    formats = [file_formats] if isinstance(file_formats, str) else file_formats
+    if not is_list_of(formats, str):
+        raise TypeError('file_formats must be a str or a list of str')
+    file_handlers.update({ext: handler for ext in formats})
+def register_handler(file_formats, **kwargs):
+    """Class decorator that registers a handler for ``file_formats``.
+
+    The decorated class is instantiated with ``**kwargs``, the instance is
+    passed to ``_register_handler``, and the class itself is returned
+    unchanged.
+    """
+
+    def wrap(cls):
+        handler = cls(**kwargs)
+        _register_handler(handler, file_formats)
+        return cls
+
+    return wrap
diff --git a/ControlNet/annotator/uniformer/mmcv/fileio/parse.py b/ControlNet/annotator/uniformer/mmcv/fileio/parse.py
new file mode 100644
index 0000000000000000000000000000000000000000..f60f0d611b8d75692221d0edd7dc993b0a6445c9
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/fileio/parse.py
@@ -0,0 +1,97 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from io import StringIO
+from .file_client import FileClient
+def list_from_file(filename,
+                   prefix='',
+                   offset=0,
+                   max_num=0,
+                   encoding='utf-8',
+                   file_client_args=None):
+    """Load a text file and parse the content as a list of strings.
+
+    Note:
+        In v1.3.16 and later, ``list_from_file`` supports loading a text file
+        which can be stored in different backends and parsing the content as
+        a list of strings.
+
+    Args:
+        filename (str): Filename.
+        prefix (str): The prefix to be inserted to the beginning of each item.
+        offset (int): Number of leading lines to skip.
+        max_num (int): The maximum number of lines to be read,
+            zeros and negatives mean no limitation.
+        encoding (str): Encoding used to open the file. Default utf-8.
+        file_client_args (dict, optional): Arguments to instantiate a
+            FileClient. See :class:`mmcv.fileio.FileClient` for details.
+            Default: None.
+
+    Examples:
+        >>> list_from_file('/path/of/your/file')  # disk
+        ['hello', 'world']
+        >>> list_from_file('s3://path/of/your/file')  # ceph or petrel
+        ['hello', 'world']
+
+    Returns:
+        list[str]: A list of strings.
+    """
+    file_client = FileClient.infer_client(file_client_args, filename)
+    lines = []
+    with StringIO(file_client.get_text(filename, encoding)) as f:
+        # Consume (and discard) the first ``offset`` lines.
+        for _ in range(offset):
+            f.readline()
+        for taken, line in enumerate(f):
+            # ``max_num <= 0`` disables the limit.
+            if 0 < max_num <= taken:
+                break
+            lines.append(prefix + line.rstrip('\n\r'))
+    return lines
+def dict_from_file(filename,
+                   key_type=str,
+                   encoding='utf-8',
+                   file_client_args=None):
+    """Load a text file and parse the content as a dict.
+
+    Each non-blank line of the text file must hold two or more columns
+    separated by whitespace or tabs. The first column is parsed as the dict
+    key and the following columns as the dict value (a single string for one
+    value column, a list of strings for several). Blank lines are skipped.
+
+    Note:
+        In v1.3.16 and later, ``dict_from_file`` supports loading a text file
+        which can be stored in different backends and parsing the content as
+        a dict.
+
+    Args:
+        filename (str): Filename.
+        key_type (type): Type of the dict keys. ``str`` is used by default
+            and type conversion will be performed if specified.
+        encoding (str): Encoding used to open the file. Default utf-8.
+        file_client_args (dict, optional): Arguments to instantiate a
+            FileClient. See :class:`mmcv.fileio.FileClient` for details.
+            Default: None.
+
+    Examples:
+        >>> dict_from_file('/path/of/your/file')  # disk
+        {'key1': 'value1', 'key2': 'value2'}
+        >>> dict_from_file('s3://path/of/your/file')  # ceph or petrel
+        {'key1': 'value1', 'key2': 'value2'}
+
+    Returns:
+        dict: The parsed contents.
+
+    Raises:
+        AssertionError: If a non-blank line has fewer than two columns.
+    """
+    mapping = {}
+    file_client = FileClient.infer_client(file_client_args, filename)
+    with StringIO(file_client.get_text(filename, encoding)) as f:
+        for line in f:
+            # ``split()`` without arguments already ignores the trailing
+            # newline and any surrounding whitespace.
+            items = line.split()
+            if not items:
+                # Tolerate blank separator lines instead of failing on them.
+                continue
+            assert len(items) >= 2, (
+                f'expected at least two columns per line, got: {line!r}')
+            key = key_type(items[0])
+            val = items[1:] if len(items) > 2 else items[1]
+            mapping[key] = val
+    return mapping
diff --git a/ControlNet/annotator/uniformer/mmcv/image/__init__.py b/ControlNet/annotator/uniformer/mmcv/image/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0051d609d3de4e7562e3fe638335c66617c4d91
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/image/__init__.py
@@ -0,0 +1,28 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .colorspace import (bgr2gray, bgr2hls, bgr2hsv, bgr2rgb, bgr2ycbcr,
+ gray2bgr, gray2rgb, hls2bgr, hsv2bgr, imconvert,
+ rgb2bgr, rgb2gray, rgb2ycbcr, ycbcr2bgr, ycbcr2rgb)
+from .geometric import (cutout, imcrop, imflip, imflip_, impad,
+ impad_to_multiple, imrescale, imresize, imresize_like,
+ imresize_to_multiple, imrotate, imshear, imtranslate,
+ rescale_size)
+from .io import imfrombytes, imread, imwrite, supported_backends, use_backend
+from .misc import tensor2imgs
+from .photometric import (adjust_brightness, adjust_color, adjust_contrast,
+ adjust_lighting, adjust_sharpness, auto_contrast,
+ clahe, imdenormalize, imequalize, iminvert,
+ imnormalize, imnormalize_, lut_transform, posterize,
+ solarize)
+__all__ = [
+    'bgr2gray', 'bgr2hls', 'bgr2hsv', 'bgr2rgb', 'gray2bgr', 'gray2rgb',
+    'hls2bgr', 'hsv2bgr', 'imconvert', 'rgb2bgr', 'rgb2gray', 'imrescale',
+    'imresize', 'imresize_like', 'imresize_to_multiple', 'rescale_size',
+    'imcrop', 'imflip', 'imflip_', 'impad', 'impad_to_multiple', 'imrotate',
+    'imfrombytes', 'imread', 'imwrite', 'supported_backends', 'use_backend',
+    'imdenormalize', 'imnormalize', 'imnormalize_', 'iminvert', 'posterize',
+    'solarize', 'rgb2ycbcr', 'bgr2ycbcr', 'ycbcr2rgb', 'ycbcr2bgr',
+    'tensor2imgs', 'imshear', 'imtranslate', 'adjust_color', 'imequalize',
+    'adjust_brightness', 'adjust_contrast', 'lut_transform', 'clahe',
+    'adjust_sharpness', 'auto_contrast', 'cutout', 'adjust_lighting'
+]
diff --git a/ControlNet/annotator/uniformer/mmcv/image/colorspace.py b/ControlNet/annotator/uniformer/mmcv/image/colorspace.py
new file mode 100644
index 0000000000000000000000000000000000000000..814533952fdfda23d67cb6a3073692d8c1156add
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/image/colorspace.py
@@ -0,0 +1,306 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import cv2
+import numpy as np
+def imconvert(img, src, dst):
+    """Convert an image between two colorspaces known to OpenCV.
+
+    Args:
+        img (ndarray): The input image.
+        src (str): Name of the source colorspace, e.g., 'rgb', 'hsv'.
+        dst (str): Name of the destination colorspace, e.g., 'rgb', 'hsv'.
+
+    Returns:
+        ndarray: The converted image.
+    """
+    # The conversion constant is looked up by name, e.g. cv2.COLOR_RGB2HSV.
+    flag_name = f'COLOR_{src.upper()}2{dst.upper()}'
+    conversion = getattr(cv2, flag_name)
+    return cv2.cvtColor(img, conversion)
+def bgr2gray(img, keepdim=False):
+    """Convert a BGR image to a grayscale image.
+
+    Args:
+        img (ndarray): The input image.
+        keepdim (bool): When True the result keeps a trailing channel axis
+            (3 dims); when False (default) the result has 2 dims.
+
+    Returns:
+        ndarray: The converted grayscale image.
+    """
+    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+    # Re-add the channel axis only on request.
+    return gray[..., None] if keepdim else gray
+def rgb2gray(img, keepdim=False):
+    """Convert a RGB image to a grayscale image.
+
+    Args:
+        img (ndarray): The input image.
+        keepdim (bool): When True the result keeps a trailing channel axis
+            (3 dims); when False (default) the result has 2 dims.
+
+    Returns:
+        ndarray: The converted grayscale image.
+    """
+    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
+    # Re-add the channel axis only on request.
+    return gray[..., None] if keepdim else gray
+def gray2bgr(img):
+    """Convert a grayscale image to a BGR image.
+
+    Args:
+        img (ndarray): The input image; a 2-dim array gets a trailing
+            channel axis appended before conversion.
+
+    Returns:
+        ndarray: The converted BGR image.
+    """
+    if img.ndim == 2:
+        img = img[..., None]
+    return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+def gray2rgb(img):
+    """Convert a grayscale image to a RGB image.
+
+    Args:
+        img (ndarray): The input image; a 2-dim array gets a trailing
+            channel axis appended before conversion.
+
+    Returns:
+        ndarray: The converted RGB image.
+    """
+    if img.ndim == 2:
+        img = img[..., None]
+    return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
+def _convert_input_type_range(img):
+    """Bring an input image to np.float32 with values in [0, 1].
+
+    Used as the pre-processing step of the colorspace conversions in this
+    module such as rgb2ycbcr and ycbcr2rgb.
+
+    Args:
+        img (ndarray): The input image. It accepts:
+            1. np.uint8 type with range [0, 255];
+            2. np.float32 type with range [0, 1].
+
+    Returns:
+        (ndarray): A float32 copy of the image with range [0, 1].
+
+    Raises:
+        TypeError: If the image dtype is neither np.float32 nor np.uint8.
+    """
+    src_dtype = img.dtype
+    converted = img.astype(np.float32)
+    if src_dtype == np.uint8:
+        # uint8 images span [0, 255]; float32 inputs are already in range.
+        converted /= 255.
+    elif src_dtype != np.float32:
+        raise TypeError('The img type should be np.float32 or np.uint8, '
+                        f'but got {src_dtype}')
+    return converted
+def _convert_output_type_range(img, dst_type):
+    """Bring a [0, 255] float image to the requested output type and range.
+
+    Used as the post-processing step of the colorspace conversions in this
+    module such as rgb2ycbcr and ycbcr2rgb.
+
+    Args:
+        img (ndarray): The image to be converted with np.float32 type and
+            range [0, 255].
+        dst_type (np.uint8 | np.float32): For np.uint8 the values are
+            rounded and kept in [0, 255]; for np.float32 the values are
+            divided by 255 (in place) so they span [0, 1].
+
+    Returns:
+        (ndarray): The converted image with desired type and range.
+
+    Raises:
+        TypeError: If ``dst_type`` is neither np.uint8 nor np.float32.
+    """
+    if dst_type not in (np.uint8, np.float32):
+        raise TypeError('The dst_type should be np.float32 or np.uint8, '
+                        f'but got {dst_type}')
+    if dst_type == np.float32:
+        img /= 255.
+    else:
+        img = img.round()
+    return img.astype(dst_type)
+def rgb2ycbcr(img, y_only=False):
+    """Convert a RGB image to YCbCr image.
+
+    This function produces the same results as Matlab's `rgb2ycbcr` function.
+    It implements the ITU-R BT.601 conversion for standard-definition
+    television. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+    It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
+    In OpenCV, it implements a JPEG conversion. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+    Args:
+        img (ndarray): The input image. It accepts:
+            1. np.uint8 type with range [0, 255];
+            2. np.float32 type with range [0, 1].
+        y_only (bool): Whether to only return Y channel. Default: False.
+
+    Returns:
+        ndarray: The converted YCbCr image. The output image has the same type
+            and range as input image.
+    """
+    src_dtype = img.dtype
+    img = _convert_input_type_range(img)
+    if y_only:
+        ycbcr = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
+    else:
+        # One row per input channel (R, G, B); one column per output
+        # channel (Y, Cb, Cr).
+        weights = [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
+                   [24.966, 112.0, -18.214]]
+        ycbcr = np.matmul(img, weights) + [16, 128, 128]
+    return _convert_output_type_range(ycbcr, src_dtype)
+def bgr2ycbcr(img, y_only=False):
+    """Convert a BGR image to YCbCr image.
+
+    The bgr version of rgb2ycbcr.
+    It implements the ITU-R BT.601 conversion for standard-definition
+    television. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+    It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
+    In OpenCV, it implements a JPEG conversion. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+    Args:
+        img (ndarray): The input image. It accepts:
+            1. np.uint8 type with range [0, 255];
+            2. np.float32 type with range [0, 1].
+        y_only (bool): Whether to only return Y channel. Default: False.
+
+    Returns:
+        ndarray: The converted YCbCr image. The output image has the same type
+            and range as input image.
+    """
+    src_dtype = img.dtype
+    img = _convert_input_type_range(img)
+    if y_only:
+        ycbcr = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
+    else:
+        # One row per input channel (B, G, R); one column per output
+        # channel (Y, Cb, Cr).
+        weights = [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
+                   [65.481, -37.797, 112.0]]
+        ycbcr = np.matmul(img, weights) + [16, 128, 128]
+    return _convert_output_type_range(ycbcr, src_dtype)
+def ycbcr2rgb(img):
+    """Convert a YCbCr image to RGB image.
+
+    This function produces the same results as Matlab's ycbcr2rgb function.
+    It implements the ITU-R BT.601 conversion for standard-definition
+    television. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+    It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`.
+    In OpenCV, it implements a JPEG conversion. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+    Args:
+        img (ndarray): The input image. It accepts:
+            1. np.uint8 type with range [0, 255];
+            2. np.float32 type with range [0, 1].
+
+    Returns:
+        ndarray: The converted RGB image. The output image has the same type
+            and range as input image.
+    """
+    src_dtype = img.dtype
+    # The coefficients below operate on values scaled to [0, 255].
+    img = _convert_input_type_range(img) * 255
+    weights = [[0.00456621, 0.00456621, 0.00456621],
+               [0, -0.00153632, 0.00791071],
+               [0.00625893, -0.00318811, 0]]
+    offsets = [-222.921, 135.576, -276.836]
+    rgb = np.matmul(img, weights) * 255.0 + offsets
+    return _convert_output_type_range(rgb, src_dtype)
+def ycbcr2bgr(img):
+    """Convert a YCbCr image to BGR image.
+
+    The bgr version of ycbcr2rgb.
+    It implements the ITU-R BT.601 conversion for standard-definition
+    television. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+    It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`.
+    In OpenCV, it implements a JPEG conversion. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+    Args:
+        img (ndarray): The input image. It accepts:
+            1. np.uint8 type with range [0, 255];
+            2. np.float32 type with range [0, 1].
+
+    Returns:
+        ndarray: The converted BGR image. The output image has the same type
+            and range as input image.
+    """
+    src_dtype = img.dtype
+    # The coefficients below operate on values scaled to [0, 255].
+    img = _convert_input_type_range(img) * 255
+    weights = [[0.00456621, 0.00456621, 0.00456621],
+               [0.00791071, -0.00153632, 0],
+               [0, -0.00318811, 0.00625893]]
+    offsets = [-276.836, 135.576, -222.921]
+    bgr = np.matmul(img, weights) * 255.0 + offsets
+    return _convert_output_type_range(bgr, src_dtype)
+def convert_color_factory(src, dst):
+    """Build a one-argument converter from ``src`` to ``dst`` colorspace.
+
+    The OpenCV conversion code is resolved once, when the factory is called,
+    and the returned function receives a docstring naming both colorspaces.
+
+    Args:
+        src (str): Name of the source colorspace, e.g. 'bgr'.
+        dst (str): Name of the destination colorspace, e.g. 'rgb'.
+
+    Returns:
+        callable: A function taking an image and returning the converted
+            image.
+    """
+    code = getattr(cv2, f'COLOR_{src.upper()}2{dst.upper()}')
+
+    def convert_color(img):
+        return cv2.cvtColor(img, code)
+
+    convert_color.__doc__ = f"""Convert a {src.upper()} image to {dst.upper()}
+        image.
+
+    Args:
+        img (ndarray or str): The input image.
+
+    Returns:
+        ndarray: The converted {dst.upper()} image.
+    """
+    return convert_color
+# Ready-made converters between BGR and the RGB / HSV / HLS colorspaces,
+# each produced by ``convert_color_factory``.
+bgr2rgb = convert_color_factory('bgr', 'rgb')
+rgb2bgr = convert_color_factory('rgb', 'bgr')
+bgr2hsv = convert_color_factory('bgr', 'hsv')
+hsv2bgr = convert_color_factory('hsv', 'bgr')
+bgr2hls = convert_color_factory('bgr', 'hls')
+hls2bgr = convert_color_factory('hls', 'bgr')
diff --git a/ControlNet/annotator/uniformer/mmcv/image/geometric.py b/ControlNet/annotator/uniformer/mmcv/image/geometric.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf97c201cb4e43796c911919d03fb26a07ed817d
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/image/geometric.py
@@ -0,0 +1,728 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numbers
+import cv2
+import numpy as np
+from ..utils import to_2tuple
+from .io import imread_backend
+try:
+    from PIL import Image
+except ImportError:
+    # Pillow is optional: the 'pillow' resize backend is unavailable
+    # without it, and ``Image`` is left as None so callers can check.
+    Image = None
+def _scale_size(size, scale):
+    """Rescale a size by a ratio, rounding half up to whole pixels.
+
+    Args:
+        size (tuple[int]): (w, h).
+        scale (float | tuple(float)): Scaling factor; a single number is
+            applied to both width and height.
+
+    Returns:
+        tuple[int]: scaled size.
+    """
+    if isinstance(scale, (float, int)):
+        scale = (scale, scale)
+    width, height = size
+    new_w = int(width * float(scale[0]) + 0.5)
+    new_h = int(height * float(scale[1]) + 0.5)
+    return new_w, new_h
+# Interpolation names accepted by the resize/warp helpers, mapped to the
+# OpenCV flag each one selects.
+cv2_interp_codes = {
+    'nearest': cv2.INTER_NEAREST,
+    'bilinear': cv2.INTER_LINEAR,
+    'bicubic': cv2.INTER_CUBIC,
+    'area': cv2.INTER_AREA,
+    'lanczos': cv2.INTER_LANCZOS4
+}
+
+# The Pillow table only exists when Pillow could be imported.
+if Image is not None:
+    pillow_interp_codes = {
+        'nearest': Image.NEAREST,
+        'bilinear': Image.BILINEAR,
+        'bicubic': Image.BICUBIC,
+        'box': Image.BOX,
+        'lanczos': Image.LANCZOS,
+        'hamming': Image.HAMMING
+    }
+def imresize(img,
+             size,
+             return_scale=False,
+             interpolation='bilinear',
+             out=None,
+             backend=None):
+    """Resize image to a given size.
+
+    Args:
+        img (ndarray): The input image.
+        size (tuple[int]): Target size (w, h).
+        return_scale (bool): Whether to return `w_scale` and `h_scale`.
+        interpolation (str): Interpolation method, accepted values are
+            "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
+            backend, "nearest", "bilinear" for 'pillow' backend.
+        out (ndarray): The output destination (used by the 'cv2' backend
+            only).
+        backend (str | None): The image resize backend type. Options are `cv2`,
+            `pillow`, `None`. If backend is None, the global imread_backend
+            specified by ``mmcv.use_backend()`` will be used. Default: None.
+
+    Returns:
+        tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or
+            `resized_img`.
+
+    Raises:
+        ValueError: If the backend is neither 'cv2' nor 'pillow'.
+        ImportError: If the 'pillow' backend is requested but Pillow is not
+            installed.
+    """
+    h, w = img.shape[:2]
+    if backend is None:
+        backend = imread_backend
+    if backend not in ['cv2', 'pillow']:
+        raise ValueError(f'backend: {backend} is not supported for resize. '
+                         f"Supported backends are 'cv2', 'pillow'")
+
+    if backend == 'pillow':
+        # Fail with the same message ``use_backend`` uses instead of an
+        # AttributeError on ``None``.
+        if Image is None:
+            raise ImportError('`Pillow` is not installed')
+        assert img.dtype == np.uint8, 'Pillow backend only support uint8 type'
+        pil_image = Image.fromarray(img)
+        pil_image = pil_image.resize(size, pillow_interp_codes[interpolation])
+        resized_img = np.array(pil_image)
+    else:
+        resized_img = cv2.resize(
+            img, size, dst=out, interpolation=cv2_interp_codes[interpolation])
+
+    if not return_scale:
+        return resized_img
+    # Scale factors relate the requested size to the original one.
+    w_scale = size[0] / w
+    h_scale = size[1] / h
+    return resized_img, w_scale, h_scale
+def imresize_to_multiple(img,
+                         divisor,
+                         size=None,
+                         scale_factor=None,
+                         keep_ratio=False,
+                         return_scale=False,
+                         interpolation='bilinear',
+                         out=None,
+                         backend=None):
+    """Resize an image, rounding the target size up to a multiple of divisor.
+
+    The target size comes either from ``size`` or from ``scale_factor``; each
+    edge is then rounded up to the nearest value divisible by the divisor.
+
+    Args:
+        img (ndarray): The input image.
+        divisor (int | tuple): Resized image size will be a multiple of
+            divisor. If divisor is a tuple, divisor should be
+            (w_divisor, h_divisor).
+        size (None | int | tuple[int]): Target size (w, h). Default: None.
+        scale_factor (None | float | tuple[float]): Multiplier for spatial
+            size. Should match input size if it is a tuple and the 2D style is
+            (w_scale_factor, h_scale_factor). Default: None.
+        keep_ratio (bool): Whether to keep the aspect ratio when resizing the
+            image. Only consulted when ``size`` is given. Default: False.
+        return_scale (bool): Whether to return `w_scale` and `h_scale`.
+        interpolation (str): Interpolation method, accepted values are
+            "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
+            backend, "nearest", "bilinear" for 'pillow' backend.
+        out (ndarray): The output destination.
+        backend (str | None): The image resize backend type. Options are `cv2`,
+            `pillow`, `None`. If backend is None, the global imread_backend
+            specified by ``mmcv.use_backend()`` will be used. Default: None.
+
+    Returns:
+        tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or
+            `resized_img`.
+
+    Raises:
+        ValueError: If both or neither of ``size`` and ``scale_factor`` are
+            given.
+    """
+    height, width = img.shape[:2]
+    has_size = size is not None
+    has_factor = scale_factor is not None
+    if has_size and has_factor:
+        raise ValueError('only one of size or scale_factor should be defined')
+    if not has_size and not has_factor:
+        raise ValueError('one of size or scale_factor should be defined')
+
+    if has_size:
+        size = to_2tuple(size)
+        if keep_ratio:
+            size = rescale_size((width, height), size, return_scale=False)
+    else:
+        size = _scale_size((width, height), scale_factor)
+
+    divisor = to_2tuple(divisor)
+    # Round every edge up to the next multiple of its divisor.
+    size = tuple(int(np.ceil(s / d)) * d for s, d in zip(size, divisor))
+    result = imresize(
+        img,
+        size,
+        return_scale=True,
+        interpolation=interpolation,
+        out=out,
+        backend=backend)
+    return result if return_scale else result[0]
+def imresize_like(img,
+                  dst_img,
+                  return_scale=False,
+                  interpolation='bilinear',
+                  backend=None):
+    """Resize an image so it has the same height and width as another one.
+
+    Args:
+        img (ndarray): The input image.
+        dst_img (ndarray): The target image whose size is copied.
+        return_scale (bool): Whether to return `w_scale` and `h_scale`.
+        interpolation (str): Same as :func:`imresize`.
+        backend (str | None): Same as :func:`imresize`.
+
+    Returns:
+        tuple or ndarray: (`resized_img`, `w_scale`, `h_scale`) or
+            `resized_img`.
+    """
+    target_h, target_w = dst_img.shape[:2]
+    # imresize expects the size as (w, h).
+    target_size = (target_w, target_h)
+    return imresize(
+        img, target_size, return_scale, interpolation, backend=backend)
+def rescale_size(old_size, scale, return_scale=False):
+    """Calculate the new size to be rescaled to.
+
+    Args:
+        old_size (tuple[int]): The old size (w, h) of image.
+        scale (float | tuple[int]): The scaling factor or maximum size.
+            If it is a float number, then the image will be rescaled by this
+            factor, else if it is a tuple of 2 integers, then the image will
+            be rescaled as large as possible within the scale.
+        return_scale (bool): Whether to return the scaling factor besides the
+            rescaled image size.
+
+    Returns:
+        tuple[int]: The new rescaled image size, or a
+            (new_size, scale_factor) pair when ``return_scale`` is True.
+
+    Raises:
+        ValueError: If a numeric ``scale`` is not positive.
+        TypeError: If ``scale`` is neither a number nor a tuple.
+    """
+    width, height = old_size
+    if isinstance(scale, tuple):
+        # The larger bound limits the long edge, the smaller one the short
+        # edge; the tighter of the two ratios wins.
+        long_limit = max(scale)
+        short_limit = min(scale)
+        scale_factor = min(long_limit / max(height, width),
+                           short_limit / min(height, width))
+    elif isinstance(scale, (float, int)):
+        if scale <= 0:
+            raise ValueError(f'Invalid scale {scale}, must be positive.')
+        scale_factor = scale
+    else:
+        raise TypeError(
+            f'Scale must be a number or tuple of int, but got {type(scale)}')
+
+    new_size = _scale_size((width, height), scale_factor)
+    return (new_size, scale_factor) if return_scale else new_size
+def imrescale(img,
+              scale,
+              return_scale=False,
+              interpolation='bilinear',
+              backend=None):
+    """Resize image while keeping the aspect ratio.
+
+    Args:
+        img (ndarray): The input image.
+        scale (float | tuple[int]): The scaling factor or maximum size.
+            If it is a float number, then the image will be rescaled by this
+            factor, else if it is a tuple of 2 integers, then the image will
+            be rescaled as large as possible within the scale.
+        return_scale (bool): Whether to return the scaling factor besides the
+            rescaled image.
+        interpolation (str): Same as :func:`imresize`.
+        backend (str | None): Same as :func:`imresize`.
+
+    Returns:
+        ndarray: The rescaled image, or a (rescaled_img, scale_factor) pair
+            when ``return_scale`` is True.
+    """
+    height, width = img.shape[:2]
+    target_size, factor = rescale_size((width, height),
+                                       scale,
+                                       return_scale=True)
+    rescaled = imresize(
+        img, target_size, interpolation=interpolation, backend=backend)
+    return (rescaled, factor) if return_scale else rescaled
+def imflip(img, direction='horizontal'):
+    """Flip an image horizontally, vertically or diagonally.
+
+    Args:
+        img (ndarray): Image to be flipped.
+        direction (str): The flip direction, either "horizontal" or
+            "vertical" or "diagonal".
+
+    Returns:
+        ndarray: The flipped image.
+    """
+    assert direction in ['horizontal', 'vertical', 'diagonal']
+    # Axis (or axes) reversed for each direction; diagonal reverses both.
+    flip_axes = {'horizontal': 1, 'vertical': 0, 'diagonal': (0, 1)}
+    return np.flip(img, axis=flip_axes[direction])
+def imflip_(img, direction='horizontal'):
+    """Inplace flip an image horizontally, vertically or diagonally.
+
+    Args:
+        img (ndarray): Image to be flipped; it is also passed to OpenCV as
+            the destination array.
+        direction (str): The flip direction, either "horizontal" or
+            "vertical" or "diagonal".
+
+    Returns:
+        ndarray: The flipped image (inplace).
+    """
+    assert direction in ['horizontal', 'vertical', 'diagonal']
+    # Flip codes handed to cv2.flip for each direction.
+    flip_codes = {'horizontal': 1, 'vertical': 0, 'diagonal': -1}
+    return cv2.flip(img, flip_codes[direction], img)
+def imrotate(img,
+             angle,
+             center=None,
+             scale=1.0,
+             border_value=0,
+             interpolation='bilinear',
+             auto_bound=False):
+    """Rotate an image.
+
+    Args:
+        img (ndarray): Image to be rotated.
+        angle (float): Rotation angle in degrees, positive values mean
+            clockwise rotation.
+        center (tuple[float], optional): Center point (w, h) of the rotation in
+            the source image. If not specified, the center of the image will be
+            used.
+        scale (float): Isotropic scale factor.
+        border_value (int): Border value.
+        interpolation (str): Same as :func:`imresize`.
+        auto_bound (bool): Whether to adjust the image size to cover the whole
+            rotated image.
+
+    Returns:
+        ndarray: The rotated image.
+
+    Raises:
+        ValueError: If both ``center`` and ``auto_bound`` are given.
+    """
+    if center is not None and auto_bound:
+        raise ValueError('`auto_bound` conflicts with `center`')
+    height, width = img.shape[:2]
+    if center is None:
+        center = ((width - 1) * 0.5, (height - 1) * 0.5)
+    assert isinstance(center, tuple)
+
+    # The angle is negated so that positive values rotate clockwise.
+    matrix = cv2.getRotationMatrix2D(center, -angle, scale)
+    if auto_bound:
+        abs_cos = np.abs(matrix[0, 0])
+        abs_sin = np.abs(matrix[0, 1])
+        bound_w = height * abs_sin + width * abs_cos
+        bound_h = height * abs_cos + width * abs_sin
+        # Shift the content so it stays centred in the enlarged canvas.
+        matrix[0, 2] += (bound_w - width) * 0.5
+        matrix[1, 2] += (bound_h - height) * 0.5
+        width = int(np.round(bound_w))
+        height = int(np.round(bound_h))
+    return cv2.warpAffine(
+        img,
+        matrix, (width, height),
+        flags=cv2_interp_codes[interpolation],
+        borderValue=border_value)
+def bbox_clip(bboxes, img_shape):
+    """Clip bboxes to fit the image shape.
+
+    Args:
+        bboxes (ndarray): Shape (..., 4*k)
+        img_shape (tuple[int]): (height, width) of the image.
+
+    Returns:
+        ndarray: Clipped bboxes.
+    """
+    num_coords = bboxes.shape[-1]
+    assert num_coords % 4 == 0
+    # Per-coordinate upper bound: even slots hold x values (limited by the
+    # width), odd slots hold y values (limited by the height).
+    upper = np.empty(num_coords, dtype=bboxes.dtype)
+    upper[0::2] = img_shape[1] - 1
+    upper[1::2] = img_shape[0] - 1
+    return np.maximum(np.minimum(bboxes, upper), 0)
+def bbox_scaling(bboxes, scale, clip_shape=None):
+    """Scaling bboxes w.r.t the box center.
+
+    Args:
+        bboxes (ndarray): Shape(..., 4).
+        scale (float): Scaling factor.
+        clip_shape (tuple[int], optional): If specified, bboxes that exceed the
+            boundary will be clipped according to the given shape (h, w).
+
+    Returns:
+        ndarray: Scaled bboxes.
+    """
+    if float(scale) == 1.0:
+        scaled = bboxes.copy()
+    else:
+        # Box extents are inclusive, hence the +1.
+        widths = bboxes[..., 2] - bboxes[..., 0] + 1
+        heights = bboxes[..., 3] - bboxes[..., 1] + 1
+        half_dw = (widths * (scale - 1)) * 0.5
+        half_dh = (heights * (scale - 1)) * 0.5
+        growth = np.stack((-half_dw, -half_dh, half_dw, half_dh), axis=-1)
+        scaled = bboxes + growth
+
+    if clip_shape is None:
+        return scaled
+    return bbox_clip(scaled, clip_shape)
+def imcrop(img, bboxes, scale=1.0, pad_fill=None):
+    """Crop image patches.
+
+    3 steps: scale the bboxes -> clip bboxes -> crop and pad.
+
+    Args:
+        img (ndarray): Image to be cropped.
+        bboxes (ndarray): Shape (k, 4) or (4, ), location of cropped bboxes.
+        scale (float, optional): Scale ratio of bboxes, the default value
+            1.0 means no padding.
+        pad_fill (Number | list[Number]): Value to be filled for padding.
+            Default: None, which means no padding.
+
+    Returns:
+        list[ndarray] | ndarray: The cropped image patches; a single patch
+            when ``bboxes`` has shape (4, ).
+    """
+    chn = 1 if img.ndim == 2 else img.shape[2]
+    if pad_fill is not None:
+        if isinstance(pad_fill, (int, float)):
+            pad_fill = [pad_fill for _ in range(chn)]
+        assert len(pad_fill) == chn
+
+    _bboxes = bboxes[None, ...] if bboxes.ndim == 1 else bboxes
+    scaled_bboxes = bbox_scaling(_bboxes, scale).astype(np.int32)
+    clipped_bboxes = bbox_clip(scaled_bboxes, img.shape)
+
+    patches = []
+    for clipped, scaled in zip(clipped_bboxes, scaled_bboxes):
+        x1, y1, x2, y2 = tuple(clipped)
+        if pad_fill is None:
+            patch = img[y1:y2 + 1, x1:x2 + 1, ...]
+        else:
+            _x1, _y1, _x2, _y2 = tuple(scaled)
+            # The canvas follows the rank of the image rather than its
+            # channel count: an (h, w, 1) image needs an (h, w, 1) canvas,
+            # otherwise the slice assignment below cannot broadcast.
+            patch_shape = (_y2 - _y1 + 1, _x2 - _x1 + 1)
+            if img.ndim != 2:
+                patch_shape = patch_shape + (chn, )
+            patch = np.array(
+                pad_fill, dtype=img.dtype) * np.ones(
+                    patch_shape, dtype=img.dtype)
+            # Parts of the scaled box lying outside the image stay filled
+            # with pad_fill; the visible part is pasted at this offset.
+            x_start = 0 if _x1 >= 0 else -_x1
+            y_start = 0 if _y1 >= 0 else -_y1
+            w = x2 - x1 + 1
+            h = y2 - y1 + 1
+            patch[y_start:y_start + h, x_start:x_start + w,
+                  ...] = img[y1:y1 + h, x1:x1 + w, ...]
+        patches.append(patch)
+
+    if bboxes.ndim == 1:
+        return patches[0]
+    return patches
+def impad(img,
+          *,
+          shape=None,
+          padding=None,
+          pad_val=0,
+          padding_mode='constant'):
+    """Pad the given image to a certain shape or pad on all sides with
+    specified padding mode and padding value.
+
+    Args:
+        img (ndarray): Image to be padded.
+        shape (tuple[int]): Expected padding shape (h, w). An edge that is
+            already at least as long as requested is left unpadded.
+            Default: None.
+        padding (int or tuple[int]): Padding on each border. If a single int is
+            provided this is used to pad all borders. If tuple of length 2 is
+            provided this is the padding on left/right and top/bottom
+            respectively. If a tuple of length 4 is provided this is the
+            padding for the left, top, right and bottom borders respectively.
+            Default: None. Note that `shape` and `padding` can not be both
+            set.
+        pad_val (Number | Sequence[Number]): Values to be filled in padding
+            areas when padding_mode is 'constant'. Default: 0.
+        padding_mode (str): Type of padding. Should be: constant, edge,
+            reflect or symmetric. Default: constant.
+
+            - constant: pads with a constant value, this value is specified
+                with pad_val.
+            - edge: pads with the last value at the edge of the image.
+            - reflect: pads with reflection of image without repeating the
+                last value on the edge. For example, padding [1, 2, 3, 4]
+                with 2 elements on both sides in reflect mode will result
+                in [3, 2, 1, 2, 3, 4, 3, 2].
+            - symmetric: pads with reflection of image repeating the last
+                value on the edge. For example, padding [1, 2, 3, 4] with
+                2 elements on both sides in symmetric mode will result in
+                [2, 1, 1, 2, 3, 4, 4, 3]
+
+    Returns:
+        ndarray: The padded image.
+
+    Raises:
+        TypeError: If ``pad_val`` is neither a number nor a tuple.
+        ValueError: If ``padding`` is neither a number nor a tuple of
+            length 2 or 4.
+    """
+    assert (shape is not None) ^ (padding is not None)
+    if shape is not None:
+        # Clamp at zero: a target edge shorter than the image would
+        # otherwise produce a negative border width.
+        pad_right = max(shape[1] - img.shape[1], 0)
+        pad_bottom = max(shape[0] - img.shape[0], 0)
+        padding = (0, 0, pad_right, pad_bottom)
+
+    # check pad_val
+    if isinstance(pad_val, tuple):
+        assert len(pad_val) == img.shape[-1]
+    elif not isinstance(pad_val, numbers.Number):
+        raise TypeError('pad_val must be a int or a tuple. '
+                        f'But received {type(pad_val)}')
+
+    # check padding
+    if isinstance(padding, tuple) and len(padding) in [2, 4]:
+        if len(padding) == 2:
+            padding = (padding[0], padding[1], padding[0], padding[1])
+    elif isinstance(padding, numbers.Number):
+        padding = (padding, padding, padding, padding)
+    else:
+        raise ValueError('Padding must be a int or a 2, or 4 element tuple.'
+                         f'But received {padding}')
+
+    # check padding mode
+    assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
+
+    border_type = {
+        'constant': cv2.BORDER_CONSTANT,
+        'edge': cv2.BORDER_REPLICATE,
+        'reflect': cv2.BORDER_REFLECT_101,
+        'symmetric': cv2.BORDER_REFLECT
+    }
+    # ``padding`` is (left, top, right, bottom); copyMakeBorder takes
+    # top, bottom, left, right.
+    img = cv2.copyMakeBorder(
+        img,
+        padding[1],
+        padding[3],
+        padding[0],
+        padding[2],
+        border_type[padding_mode],
+        value=pad_val)
+    return img
+def impad_to_multiple(img, divisor, pad_val=0):
+    """Pad an image so that each edge becomes a multiple of ``divisor``.
+
+    Args:
+        img (ndarray): Image to be padded.
+        divisor (int): Padded image edges will be multiple to divisor.
+        pad_val (Number | Sequence[Number]): Same as :func:`impad`.
+
+    Returns:
+        ndarray: The padded image.
+    """
+    # Round each edge up to the next multiple of the divisor.
+    target_h = int(np.ceil(img.shape[0] / divisor)) * divisor
+    target_w = int(np.ceil(img.shape[1] / divisor)) * divisor
+    return impad(img, shape=(target_h, target_w), pad_val=pad_val)
+def cutout(img, shape, pad_val=0):
+    """Randomly cut out a rectangle from the original img.
+
+    Args:
+        img (ndarray): Image to be cutout.
+        shape (int | tuple[int]): Expected cutout shape (h, w). If given as a
+            int, the value will be used for both h and w.
+        pad_val (int | float | tuple[int | float]): Values to be filled in the
+            cut area. Defaults to 0.
+
+    Returns:
+        ndarray: The cutout image (a modified copy; ``img`` is untouched).
+
+    Raises:
+        TypeError: If ``pad_val`` is not an int, float or tuple.
+    """
+    channels = 1 if img.ndim == 2 else img.shape[2]
+    if isinstance(shape, int):
+        cut_h = cut_w = shape
+    else:
+        assert isinstance(shape, tuple) and len(shape) == 2, \
+            f'shape must be a int or a tuple with length 2, but got type ' \
+            f'{type(shape)} instead.'
+        cut_h, cut_w = shape
+
+    if isinstance(pad_val, (int, float)):
+        pad_val = (pad_val, ) * channels
+    elif isinstance(pad_val, tuple):
+        assert len(pad_val) == channels, \
+            'Expected the num of elements in tuple equals the channels' \
+            'of input image. Found {} vs {}'.format(
+                len(pad_val), channels)
+    else:
+        raise TypeError(f'Invalid type {type(pad_val)} for `pad_val`')
+
+    img_h, img_w = img.shape[:2]
+    # NOTE(review): a single positional argument to np.random.uniform is
+    # its lower bound, with the upper bound left at the default -- confirm
+    # this sampling range is intended.
+    center_y = np.random.uniform(img_h)
+    center_x = np.random.uniform(img_w)
+
+    # Clamp the rectangle to the image borders.
+    top = int(max(0, center_y - cut_h / 2.))
+    left = int(max(0, center_x - cut_w / 2.))
+    bottom = min(img_h, top + cut_h)
+    right = min(img_w, left + cut_w)
+
+    patch_shape = (bottom - top, right - left)
+    if img.ndim != 2:
+        patch_shape = patch_shape + (channels, )
+
+    result = img.copy()
+    fill = np.array(pad_val, dtype=img.dtype)
+    result[top:bottom, left:right, ...] = fill * np.ones(
+        patch_shape, dtype=img.dtype)
+    return result
+def _get_shear_matrix(magnitude, direction='horizontal'):
+    """Generate the shear matrix for transformation.
+
+    Args:
+        magnitude (int | float): The magnitude used for shear.
+        direction (str): The shear direction, either "horizontal"
+            or "vertical".
+
+    Returns:
+        ndarray: The 2x3 shear matrix with dtype float32.
+
+    Raises:
+        ValueError: If ``direction`` is neither "horizontal" nor
+            "vertical".
+    """
+    if direction == 'horizontal':
+        return np.float32([[1, magnitude, 0], [0, 1, 0]])
+    if direction == 'vertical':
+        return np.float32([[1, 0, 0], [magnitude, 1, 0]])
+    # Report the bad argument instead of failing on an unbound local.
+    raise ValueError(f'Invalid direction: {direction}')
+def imshear(img,
+            magnitude,
+            direction='horizontal',
+            border_value=0,
+            interpolation='bilinear'):
+    """Shear an image.
+
+    Args:
+        img (ndarray): Image to be sheared with format (h, w)
+            or (h, w, c).
+        magnitude (int | float): The magnitude used for shear.
+        direction (str): The shear direction, either "horizontal"
+            or "vertical".
+        border_value (int | tuple[int]): Value used in case of a
+            constant border.
+        interpolation (str): Same as :func:`imresize`.
+
+    Returns:
+        ndarray: The sheared image.
+
+    Raises:
+        ValueError: If ``border_value`` is neither an int nor a tuple.
+    """
+    assert direction in ['horizontal',
+                         'vertical'], f'Invalid direction: {direction}'
+    height, width = img.shape[:2]
+    if img.ndim == 2:
+        channels = 1
+    elif img.ndim == 3:
+        channels = img.shape[-1]
+
+    if isinstance(border_value, int):
+        # Repeat a scalar border value once per channel.
+        border_value = (border_value, ) * channels
+    elif isinstance(border_value, tuple):
+        assert len(border_value) == channels, \
+            'Expected the num of elements in tuple equals the channels' \
+            'of input image. Found {} vs {}'.format(
+                len(border_value), channels)
+    else:
+        raise ValueError(
+            f'Invalid type {type(border_value)} for `border_value`')
+
+    matrix = _get_shear_matrix(magnitude, direction)
+    # Note case when the number elements in `border_value`
+    # greater than 3 (e.g. shearing masks whose channels large
+    # than 3) will raise TypeError in `cv2.warpAffine`.
+    # Here simply slice the first 3 values in `border_value`.
+    return cv2.warpAffine(
+        img,
+        matrix,
+        (width, height),
+        borderValue=border_value[:3],
+        flags=cv2_interp_codes[interpolation])
+def _get_translate_matrix(offset, direction='horizontal'):
+    """Generate the translate matrix.
+
+    Args:
+        offset (int | float): The offset used for translate.
+        direction (str): The translate direction, either
+            "horizontal" or "vertical".
+
+    Returns:
+        ndarray: The 2x3 translate matrix with dtype float32.
+
+    Raises:
+        ValueError: If ``direction`` is neither "horizontal" nor
+            "vertical".
+    """
+    if direction == 'horizontal':
+        return np.float32([[1, 0, offset], [0, 1, 0]])
+    if direction == 'vertical':
+        return np.float32([[1, 0, 0], [0, 1, offset]])
+    # Report the bad argument instead of failing on an unbound local.
+    raise ValueError(f'Invalid direction: {direction}')
+def imtranslate(img,
+                offset,
+                direction='horizontal',
+                border_value=0,
+                interpolation='bilinear'):
+    """Translate an image.
+
+    Args:
+        img (ndarray): Image to be translated with format
+            (h, w) or (h, w, c).
+        offset (int | float): The offset used for translate.
+        direction (str): The translate direction, either "horizontal"
+            or "vertical".
+        border_value (int | tuple[int]): Value used in case of a
+            constant border.
+        interpolation (str): Same as :func:`imresize`.
+
+    Returns:
+        ndarray: The translated image.
+
+    Raises:
+        ValueError: If ``border_value`` is neither an int nor a tuple.
+    """
+    assert direction in ['horizontal',
+                         'vertical'], f'Invalid direction: {direction}'
+    height, width = img.shape[:2]
+    if img.ndim == 2:
+        channels = 1
+    elif img.ndim == 3:
+        channels = img.shape[-1]
+
+    if isinstance(border_value, int):
+        # Repeat a scalar border value once per channel.
+        border_value = (border_value, ) * channels
+    elif isinstance(border_value, tuple):
+        assert len(border_value) == channels, \
+            'Expected the num of elements in tuple equals the channels' \
+            'of input image. Found {} vs {}'.format(
+                len(border_value), channels)
+    else:
+        raise ValueError(
+            f'Invalid type {type(border_value)} for `border_value`.')
+
+    matrix = _get_translate_matrix(offset, direction)
+    # Note case when the number elements in `border_value`
+    # greater than 3 (e.g. translating masks whose channels
+    # large than 3) will raise TypeError in `cv2.warpAffine`.
+    # Here simply slice the first 3 values in `border_value`.
+    return cv2.warpAffine(
+        img,
+        matrix,
+        (width, height),
+        borderValue=border_value[:3],
+        flags=cv2_interp_codes[interpolation])
diff --git a/ControlNet/annotator/uniformer/mmcv/image/io.py b/ControlNet/annotator/uniformer/mmcv/image/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3fa2e8cc06b1a7b0b69de6406980b15d61a1e5d
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/image/io.py
@@ -0,0 +1,258 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import io
+import os.path as osp
+from pathlib import Path
+import cv2
+import numpy as np
+from annotator.uniformer.mmcv.utils import check_file_exist, is_str, mkdir_or_exist
+# Optional decoding backends: each one falls back to ``None`` when the
+# package is missing so that `use_backend` can report it cleanly.
+try:
+    from turbojpeg import TJCS_RGB, TJPF_BGR, TJPF_GRAY, TurboJPEG
+except ImportError:
+    TJCS_RGB = TJPF_GRAY = TJPF_BGR = TurboJPEG = None
+try:
+    from PIL import Image, ImageOps
+except ImportError:
+    Image = None
+try:
+    import tifffile
+except ImportError:
+    tifffile = None
+# Shared TurboJPEG decoder; created by `use_backend('turbojpeg')`.
+jpeg = None
+supported_backends = ['cv2', 'turbojpeg', 'pillow', 'tifffile']
+# Maps the string flags accepted by `imread`/`imfrombytes` to OpenCV codes.
+imread_flags = {
+    'color': cv2.IMREAD_COLOR,
+    'grayscale': cv2.IMREAD_GRAYSCALE,
+    'unchanged': cv2.IMREAD_UNCHANGED,
+    'color_ignore_orientation':
+    cv2.IMREAD_IGNORE_ORIENTATION | cv2.IMREAD_COLOR,
+    'grayscale_ignore_orientation':
+    cv2.IMREAD_IGNORE_ORIENTATION | cv2.IMREAD_GRAYSCALE
+}
+# Backend used when a caller does not pass one explicitly.
+imread_backend = 'cv2'
+def use_backend(backend):
+    """Set the module-wide default backend used for image decoding.
+    Args:
+        backend (str): One of `cv2`, `pillow`, `turbojpeg`
+            (see https://github.com/lilohuang/PyTurboJPEG) and `tifffile`.
+            `turbojpeg` is faster but it only supports `.jpeg` file format.
+    Raises:
+        ImportError: If the package behind the chosen backend is missing.
+    """
+    global imread_backend, jpeg
+    assert backend in supported_backends
+    # The default is switched before the availability check, so it stays
+    # switched even when the check below raises.
+    imread_backend = backend
+    if backend == 'turbojpeg':
+        if TurboJPEG is None:
+            raise ImportError('`PyTurboJPEG` is not installed')
+        # Build the shared decoder only on first use.
+        jpeg = TurboJPEG() if jpeg is None else jpeg
+    elif backend == 'pillow' and Image is None:
+        raise ImportError('`Pillow` is not installed')
+    elif backend == 'tifffile' and tifffile is None:
+        raise ImportError('`tifffile` is not installed')
+def _jpegflag(flag='color', channel_order='bgr'):
+    """Choose the TurboJPEG constant passed to ``jpeg.decode``.
+    Args:
+        flag (str): Either 'color' or 'grayscale'.
+        channel_order (str): 'bgr' or 'rgb' (case-insensitive); only
+            matters for 'color'.
+    Returns:
+        The TurboJPEG constant matching the request.
+    Raises:
+        ValueError: On an unknown channel order or flag.
+    """
+    order = channel_order.lower()
+    if order not in ('rgb', 'bgr'):
+        raise ValueError('channel order must be either "rgb" or "bgr"')
+    if flag == 'grayscale':
+        return TJPF_GRAY
+    if flag != 'color':
+        raise ValueError('flag must be "color" or "grayscale"')
+    # NOTE(review): TJCS_RGB is named like a colorspace constant while the
+    # other returns are TJPF_* pixel formats; presumably its value matches
+    # TJPF_RGB - confirm against PyTurboJPEG.
+    return TJPF_BGR if order == 'bgr' else TJCS_RGB
+def _pillow2array(img, flag='color', channel_order='bgr'):
+    """Turn a PIL image into a numpy array.
+    Args:
+        img (:obj:`PIL.Image.Image`): The image loaded using PIL
+        flag (str): Color type of the result, one of 'color', 'grayscale',
+            'unchanged', 'color_ignore_orientation' and
+            'grayscale_ignore_orientation'. Default to 'color'.
+        channel_order (str): The channel order of the output image array,
+            candidates are 'bgr' and 'rgb'. Default to 'bgr'.
+    Returns:
+        np.ndarray: The converted numpy array
+    Raises:
+        ValueError: On an unknown channel order or flag.
+    """
+    channel_order = channel_order.lower()
+    if channel_order not in ['rgb', 'bgr']:
+        raise ValueError('channel order must be either "rgb" or "bgr"')
+    if flag == 'unchanged':
+        pixels = np.array(img)
+        has_color = pixels.ndim >= 3 and pixels.shape[2] >= 3
+        if has_color:
+            # Swap the first three channels (RGB to BGR), keep any extras.
+            pixels[:, :, :3] = pixels[:, :, (2, 1, 0)]
+        return pixels
+    # Honour the EXIF orientation tag unless the flag opts out.
+    if flag in ['color', 'grayscale']:
+        img = ImageOps.exif_transpose(img)
+    # Normalise every mode to 'RGB' first.
+    if img.mode == 'LA':
+        # The default conversion of 'LA' fills the canvas with black, which
+        # sometimes shadows black objects in the foreground, so paste onto
+        # a canvas of colour (124, 117, 104) instead.
+        rgba = img.convert('RGBA')
+        img = Image.new('RGB', rgba.size, (124, 117, 104))
+        img.paste(rgba, mask=rgba.split()[3])  # 3 is alpha
+    elif img.mode != 'RGB':
+        # Most formats except 'LA' can be directly converted to RGB
+        img = img.convert('RGB')
+    if flag in ['color', 'color_ignore_orientation']:
+        pixels = np.array(img)
+        if channel_order != 'rgb':
+            pixels = pixels[:, :, ::-1]  # RGB to BGR
+        return pixels
+    if flag in ['grayscale', 'grayscale_ignore_orientation']:
+        return np.array(img.convert('L'))
+    raise ValueError(
+        'flag must be "color", "grayscale", "unchanged", '
+        f'"color_ignore_orientation" or "grayscale_ignore_orientation"'
+        f' but got {flag}')
+def imread(img_or_path, flag='color', channel_order='bgr', backend=None):
+    """Read an image.
+    Args:
+        img_or_path (ndarray or str or Path): Either a numpy array or str or
+            pathlib.Path. If it is a numpy array (loaded image), then
+            it will be returned as is.
+        flag (str): Flags specifying the color type of a loaded image,
+            candidates are `color`, `grayscale`, `unchanged`,
+            `color_ignore_orientation` and `grayscale_ignore_orientation`.
+            By default, `cv2` and `pillow` backend would rotate the image
+            according to its EXIF info unless called with `unchanged` or
+            `*_ignore_orientation` flags. `turbojpeg` and `tifffile` backend
+            always ignore image's EXIF info regardless of the flag.
+            The `turbojpeg` backend only supports `color` and `grayscale`.
+        channel_order (str): Order of channel, candidates are `bgr` and `rgb`.
+        backend (str | None): The image decoding backend type. Options are
+            `cv2`, `pillow`, `turbojpeg`, `tifffile`, `None`.
+            If backend is None, the global imread_backend specified by
+            ``mmcv.use_backend()`` will be used. Default: None.
+    Returns:
+        ndarray: Loaded image array.
+    Raises:
+        ValueError: If ``backend`` is not in ``supported_backends``.
+        TypeError: If ``img_or_path`` is not an ndarray, str or Path.
+    """
+    if backend is None:
+        backend = imread_backend
+    if backend not in supported_backends:
+        raise ValueError(f'backend: {backend} is not supported. Supported '
+                         "backends are 'cv2', 'turbojpeg', 'pillow', "
+                         "'tifffile'")
+    if isinstance(img_or_path, Path):
+        img_or_path = str(img_or_path)
+    if isinstance(img_or_path, np.ndarray):
+        return img_or_path
+    elif is_str(img_or_path):
+        check_file_exist(img_or_path,
+                         f'img file does not exist: {img_or_path}')
+        if backend == 'turbojpeg':
+            with open(img_or_path, 'rb') as in_file:
+                img = jpeg.decode(in_file.read(),
+                                  _jpegflag(flag, channel_order))
+                # Drop the trailing singleton channel of grayscale output.
+                if img.shape[-1] == 1:
+                    img = img[:, :, 0]
+            return img
+        elif backend == 'pillow':
+            img = Image.open(img_or_path)
+            img = _pillow2array(img, flag, channel_order)
+            return img
+        elif backend == 'tifffile':
+            img = tifffile.imread(img_or_path)
+            return img
+        else:
+            flag = imread_flags[flag] if is_str(flag) else flag
+            img = cv2.imread(img_or_path, flag)
+            # A failed decode yields None; return it as the BGR path does
+            # rather than crashing inside cvtColor.
+            if (img is not None and flag == cv2.IMREAD_COLOR
+                    and channel_order == 'rgb'):
+                cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)  # inplace
+            return img
+    else:
+        raise TypeError('"img" must be a numpy array or a str or '
+                        'a pathlib.Path object')
+def imfrombytes(content, flag='color', channel_order='bgr', backend=None):
+    """Read an image from bytes.
+    Args:
+        content (bytes): Image bytes got from files or other streams.
+        flag (str): Same as :func:`imread`.
+        channel_order (str): Order of channel, candidates are `bgr` and `rgb`.
+        backend (str | None): The image decoding backend type. Options are
+            `cv2`, `pillow`, `turbojpeg`, `None`. If backend is None, the
+            global imread_backend specified by ``mmcv.use_backend()`` will be
+            used. Any other supported backend name (e.g. `tifffile`) is
+            decoded with `cv2`. Default: None.
+    Returns:
+        ndarray: Loaded image array.
+    Raises:
+        ValueError: If ``backend`` is not in ``supported_backends``.
+    """
+    if backend is None:
+        backend = imread_backend
+    if backend not in supported_backends:
+        raise ValueError(f'backend: {backend} is not supported. Supported '
+                         "backends are 'cv2', 'turbojpeg', 'pillow', "
+                         "'tifffile'")
+    if backend == 'turbojpeg':
+        img = jpeg.decode(content, _jpegflag(flag, channel_order))
+        # Drop the trailing singleton channel of grayscale output.
+        if img.shape[-1] == 1:
+            img = img[:, :, 0]
+        return img
+    elif backend == 'pillow':
+        buff = io.BytesIO(content)
+        img = Image.open(buff)
+        img = _pillow2array(img, flag, channel_order)
+        return img
+    else:
+        img_np = np.frombuffer(content, np.uint8)
+        flag = imread_flags[flag] if is_str(flag) else flag
+        img = cv2.imdecode(img_np, flag)
+        # A failed decode yields None; return it as the BGR path does
+        # rather than crashing inside cvtColor.
+        if (img is not None and flag == cv2.IMREAD_COLOR
+                and channel_order == 'rgb'):
+            cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)  # inplace
+        return img
+def imwrite(img, file_path, params=None, auto_mkdir=True):
+    """Save an image array to disk through OpenCV.
+    Args:
+        img (ndarray): Image array to be written.
+        file_path (str): Image file path.
+        params (None or list): Same as opencv :func:`imwrite` interface.
+        auto_mkdir (bool): Whether to create the parent folder of
+            `file_path` first when it does not exist.
+    Returns:
+        bool: Successful or not.
+    """
+    if auto_mkdir:
+        parent_dir = osp.dirname(file_path)
+        mkdir_or_exist(osp.abspath(parent_dir))
+    return cv2.imwrite(file_path, img, params)
diff --git a/ControlNet/annotator/uniformer/mmcv/image/misc.py b/ControlNet/annotator/uniformer/mmcv/image/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e61f05e3b05e4c7b40de4eb6c8eb100e6da41d0
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/image/misc.py
@@ -0,0 +1,44 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import annotator.uniformer.mmcv as mmcv
+# torch is optional here; `tensor2imgs` raises when it is missing.
+try:
+    import torch
+except ImportError:
+    torch = None
+def tensor2imgs(tensor, mean=(0, 0, 0), std=(1, 1, 1), to_rgb=True):
+    """Turn a batch tensor into a list of denormalized uint8 images.
+    Args:
+        tensor (torch.Tensor): Tensor that contains multiple images, shape (
+            N, C, H, W).
+        mean (tuple[float], optional): Mean of images. Defaults to (0, 0, 0).
+        std (tuple[float], optional): Standard deviation of images.
+            Defaults to (1, 1, 1).
+        to_rgb (bool, optional): Whether the tensor was converted to RGB
+            format in the first place. If so, convert it back to BGR.
+            Defaults to True.
+    Returns:
+        list[np.ndarray]: One contiguous (H, W, C) array per image.
+    Raises:
+        RuntimeError: If torch could not be imported.
+    """
+    if torch is None:
+        raise RuntimeError('pytorch is not installed')
+    assert torch.is_tensor(tensor) and tensor.ndim == 4
+    assert len(mean) == 3
+    assert len(std) == 3
+    mean_arr = np.array(mean, dtype=np.float32)
+    std_arr = np.array(std, dtype=np.float32)
+    def _to_image(index):
+        """Convert the image at ``index`` from CHW tensor to HWC uint8."""
+        hwc = tensor[index, ...].cpu().numpy().transpose(1, 2, 0)
+        restored = mmcv.imdenormalize(hwc, mean_arr, std_arr, to_bgr=to_rgb)
+        return np.ascontiguousarray(restored.astype(np.uint8))
+    return [_to_image(index) for index in range(tensor.size(0))]
diff --git a/ControlNet/annotator/uniformer/mmcv/image/photometric.py b/ControlNet/annotator/uniformer/mmcv/image/photometric.py
new file mode 100644
index 0000000000000000000000000000000000000000..5085d012019c0cbf56f66f421a378278c1a058ae
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/image/photometric.py
@@ -0,0 +1,428 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import cv2
+import numpy as np
+from ..utils import is_tuple_of
+from .colorspace import bgr2gray, gray2bgr
+def imnormalize(img, mean, std, to_rgb=True):
+    """Return a normalized float32 copy of an image.
+    Args:
+        img (ndarray): Image to be normalized.
+        mean (ndarray): The mean to be used for normalize.
+        std (ndarray): The std to be used for normalize.
+        to_rgb (bool): Whether to convert to rgb.
+    Returns:
+        ndarray: The normalized image.
+    """
+    # Work on a float32 copy so the in-place variant leaves `img` untouched.
+    float_copy = img.copy().astype(np.float32)
+    return imnormalize_(float_copy, mean, std, to_rgb)
+def imnormalize_(img, mean, std, to_rgb=True):
+    """Normalize an image in place: ``(img - mean) / std``.
+    Args:
+        img (ndarray): Image to be normalized; must not be uint8.
+        mean (ndarray): The mean to be used for normalize.
+        std (ndarray): The std to be used for normalize.
+        to_rgb (bool): Whether to convert to rgb.
+    Returns:
+        ndarray: The normalized image (the same array as ``img``).
+    """
+    # cv2 inplace normalization does not accept uint8
+    assert img.dtype != np.uint8
+    mean_row = np.float64(mean.reshape(1, -1))
+    # Multiply by the reciprocal instead of dividing by std.
+    inv_std_row = 1 / np.float64(std.reshape(1, -1))
+    if to_rgb:
+        cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)  # inplace
+    cv2.subtract(img, mean_row, img)  # inplace
+    cv2.multiply(img, inv_std_row, img)  # inplace
+    return img
+def imdenormalize(img, mean, std, to_bgr=True):
+    """Undo a normalization: ``img * std + mean``.
+    Args:
+        img (ndarray): Normalized image; must not be uint8.
+        mean (ndarray): The mean that was used for normalize.
+        std (ndarray): The std that was used for normalize.
+        to_bgr (bool): Whether to convert the result from RGB to BGR.
+    Returns:
+        ndarray: The denormalized image (a new array).
+    """
+    assert img.dtype != np.uint8
+    mean_row = mean.reshape(1, -1).astype(np.float64)
+    std_row = std.reshape(1, -1).astype(np.float64)
+    restored = cv2.multiply(img, std_row)  # make a copy
+    cv2.add(restored, mean_row, restored)  # inplace
+    if to_bgr:
+        cv2.cvtColor(restored, cv2.COLOR_RGB2BGR, restored)  # inplace
+    return restored
+def iminvert(img):
+    """Negate an image by subtracting every value from 255.
+    Args:
+        img (ndarray): Image to be inverted.
+    Returns:
+        ndarray: The inverted image.
+    """
+    white = np.full_like(img, 255)
+    return white - img
+def solarize(img, thr=128):
+    """Solarize an image: keep values below ``thr``, invert the rest.
+    Args:
+        img (ndarray): Image to be solarized.
+        thr (int): Threshold for solarizing (0 - 255).
+    Returns:
+        ndarray: The solarized image.
+    """
+    below_threshold = img < thr
+    inverted = 255 - img
+    return np.where(below_threshold, img, inverted)
+def posterize(img, bits):
+    """Posterize an image by keeping only the top ``bits`` of each value.
+    Args:
+        img (ndarray): Image to be posterized.
+        bits (int): Number of bits (1 to 8) to use for posterizing.
+    Returns:
+        ndarray: The posterized image.
+    """
+    dropped_bits = 8 - bits
+    # Shifting right then left zeroes the low-order bits.
+    truncated = np.right_shift(img, dropped_bits)
+    return np.left_shift(truncated, dropped_bits)
+def adjust_color(img, alpha=1, beta=None, gamma=0):
+    r"""Blend an image with its own grayscale version:
+    .. math::
+        output = img * alpha + gray\_img * beta + gamma
+    Args:
+        img (ndarray): The input source image.
+        alpha (int | float): Weight for the source image. Default 1.
+        beta (int | float): Weight for the converted gray image.
+            If None, it's assigned the value (1 - `alpha`).
+        gamma (int | float): Scalar added to each sum.
+            Same as :func:`cv2.addWeighted`. Default 0.
+    Returns:
+        ndarray: Colored image which has the same size and dtype as input.
+    """
+    gray_weight = (1 - alpha) if beta is None else beta
+    # Replicate the single gray channel three times to match `img`.
+    gray_3c = np.tile(bgr2gray(img)[..., None], [1, 1, 3])
+    blended = cv2.addWeighted(img, alpha, gray_3c, gray_weight, gamma)
+    if blended.dtype != np.uint8:
+        # Note when the dtype of `img` is not the default `np.uint8`
+        # (e.g. np.float32), the value in `blended` got from cv2
+        # is not guaranteed to be in range [0, 255], so here clip
+        # is needed.
+        blended = np.clip(blended, 0, 255)
+    return blended
+def imequalize(img):
+    """Equalize the histogram of each of the first three channels.
+    A non-linear mapping is applied per channel so that the grayscale
+    values of the output are spread towards a uniform distribution.
+    Args:
+        img (ndarray): Image to be equalized.
+    Returns:
+        ndarray: The equalized image.
+    """
+    def _equalize_plane(plane):
+        """Equalize one 2-D channel through a 256-entry look-up table."""
+        counts = np.histogram(plane, 256, (0, 255))[0]
+        # The step is computed from the non-empty bins only.
+        occupied = counts[counts > 0]
+        step = (np.sum(occupied) - occupied[-1]) // 255
+        if step:
+            # Cumulative sum, shifted by step // 2 and normalized by step.
+            table = (np.cumsum(counts) + (step // 2)) // step
+            # Shift the table by one entry, prepending 0.
+            table = np.concatenate([[0], table[:-1]], 0)
+            # handle potential integer overflow
+            table[table > 255] = 255
+        else:
+            table = np.array(range(256))
+        # A zero step keeps the channel as is; otherwise map via the table.
+        return np.where(np.equal(step, 0), plane, table[plane])
+    planes = [_equalize_plane(img[:, :, c]) for c in range(3)]
+    return np.stack(planes, axis=-1).astype(img.dtype)
+def adjust_brightness(img, factor=1.):
+    """Adjust image brightness.
+    The source image is blended with an all-black image:
+    .. math::
+        output = img * factor + degenerated * (1 - factor)
+    A factor of 0.0 gives a black image and 1.0 gives the original image.
+    Args:
+        img (ndarray): Image to be brightened.
+        factor (float): A value controls the enhancement.
+            Factor 1.0 returns the original image, lower
+            factors mean less color (brightness, contrast,
+            etc), and higher values more. Default 1.
+    Returns:
+        ndarray: The brightened image.
+    """
+    black = np.zeros_like(img)
+    # Blend in float32 to stay as close as possible to
+    # PIL.ImageEnhance.Brightness; beta is 1 - factor and gamma is 0.
+    blended = cv2.addWeighted(
+        img.astype(np.float32), factor, black.astype(np.float32),
+        1 - factor, 0)
+    return np.clip(blended, 0, 255).astype(img.dtype)
+def adjust_contrast(img, factor=1.):
+    """Adjust image contrast.
+    The source image is blended with a solid image filled with the mean
+    gray level:
+    .. math::
+        output = img * factor + degenerated * (1 - factor)
+    A factor of 0.0 gives a solid grey image and 1.0 gives the original.
+    Args:
+        img (ndarray): Image to be contrasted. BGR order.
+        factor (float): Same as :func:`mmcv.adjust_brightness`.
+    Returns:
+        ndarray: The contrasted image.
+    """
+    gray = bgr2gray(img)
+    # Mean gray level: total intensity over the histogram's pixel count.
+    pixel_count = np.sum(np.histogram(gray, 256, (0, 255))[0])
+    mean_level = round(np.sum(gray) / pixel_count)
+    solid_gray = (np.ones_like(img[..., 0]) * mean_level).astype(img.dtype)
+    solid_bgr = gray2bgr(solid_gray)
+    blended = cv2.addWeighted(
+        img.astype(np.float32), factor, solid_bgr.astype(np.float32),
+        1 - factor, 0)
+    return np.clip(blended, 0, 255).astype(img.dtype)
+def auto_contrast(img, cutoff=0):
+    """Auto adjust image contrast.
+    For each of the first three channels, cutoff percent of the lightest
+    and darkest pixels are removed from the histogram, then the channel is
+    remapped so that the darkest remaining value becomes black (0) and the
+    lightest becomes white (255).
+    Args:
+        img (ndarray): Image to be contrasted. BGR order.
+        cutoff (int | float | tuple): The cutoff percent of the lightest and
+            darkest pixels to be removed. If given as tuple, it shall be
+            (low, high). Otherwise, the single value will be used for both.
+            Defaults to 0.
+    Returns:
+        ndarray: The contrasted image.
+    """
+    def _stretch_plane(plane, cut):
+        """Stretch one 2-D channel to the full 0-255 range."""
+        counts = np.histogram(plane, 256, (0, 255))[0]
+        # Remove cut-off percent pixels from both ends of the histogram.
+        cumulative = np.cumsum(counts)
+        cut_low = cumulative[-1] * cut[0] // 100
+        cut_high = cumulative[-1] - cumulative[-1] * cut[1] // 100
+        cumulative = np.clip(cumulative, cut_low, cut_high) - cut_low
+        counts = np.concatenate([[cumulative[0]], np.diff(cumulative)], 0)
+        occupied = np.nonzero(counts)[0]
+        low, high = occupied[0], occupied[-1]
+        # If all the values have been cut off, return the origin channel
+        if low >= high:
+            return plane
+        scale = 255.0 / (high - low)
+        offset = -low * scale
+        table = np.array(range(256)) * scale + offset
+        return np.clip(table, 0, 255)[plane]
+    if not isinstance(cutoff, (int, float)):
+        assert isinstance(cutoff, tuple), 'cutoff must be of type int, ' \
+            f'float or tuple, but got {type(cutoff)} instead.'
+    else:
+        cutoff = (cutoff, cutoff)
+    # Each channel is stretched independently, then the results are stacked.
+    planes = [_stretch_plane(img[:, :, c], cutoff) for c in range(3)]
+    return np.stack(planes, axis=-1).astype(img.dtype)
+def adjust_sharpness(img, factor=1., kernel=None):
+    """Adjust image sharpness.
+    The source image is blended with a filtered (by default smoothed)
+    version of itself:
+    .. math::
+        output = img * factor + degenerated * (1 - factor)
+    A factor of 0.0 gives the filtered image, 1.0 gives the original image
+    and 2.0 gives a sharpened image.
+    Args:
+        img (ndarray): Image to be sharpened. BGR order.
+        factor (float): Same as :func:`mmcv.adjust_brightness`.
+        kernel (np.ndarray, optional): Filter kernel to be applied on the img
+            to obtain the degenerated img. Defaults to None.
+    Note:
+        No value sanity check is enforced on the kernel set by users. So with
+        an inappropriate kernel, the ``adjust_sharpness`` may fail to perform
+        the function its name indicates but end up performing whatever
+        transform determined by the kernel.
+    Returns:
+        ndarray: The sharpened image.
+    """
+    if kernel is None:
+        # adopted from PIL.ImageFilter.SMOOTH
+        kernel = np.array([[1., 1., 1.], [1., 5., 1.], [1., 1., 1.]]) / 13
+    assert isinstance(kernel, np.ndarray), \
+        f'kernel must be of type np.ndarray, but got {type(kernel)} instead.'
+    assert kernel.ndim == 2, \
+        f'kernel must have a dimension of 2, but got {kernel.ndim} instead.'
+    filtered = cv2.filter2D(img, -1, kernel)
+    blended = cv2.addWeighted(
+        img.astype(np.float32), factor, filtered.astype(np.float32),
+        1 - factor, 0)
+    return np.clip(blended, 0, 255).astype(img.dtype)
+def adjust_lighting(img, eigval, eigvec, alphastd=0.1, to_rgb=True):
+    """AlexNet-style PCA jitter.
+    This data augmentation is proposed in the paper "ImageNet
+    Classification with Deep Convolutional Neural Networks".
+    Args:
+        img (ndarray): Image to be adjusted lighting. BGR order.
+        eigval (ndarray): the eigenvalue of the convariance matrix of pixel
+            values, respectively.
+        eigvec (ndarray): the eigenvector of the convariance matrix of pixel
+            values, respectively.
+        alphastd (float): The standard deviation for distribution of alpha.
+            Defaults to 0.1
+        to_rgb (bool): Whether to convert img to rgb.
+    Returns:
+        ndarray: The adjusted image.
+    """
+    assert isinstance(eigval, np.ndarray) and isinstance(eigvec, np.ndarray), \
+        f'eigval and eigvec should both be of type np.ndarray, got ' \
+        f'{type(eigval)} and {type(eigvec)} instead.'
+    assert eigval.ndim == 1 and eigvec.ndim == 2
+    assert eigvec.shape == (3, eigval.shape[0])
+    n_eigval = eigval.shape[0]
+    assert isinstance(alphastd, float), 'alphastd should be of type float, ' \
+        f'got {type(alphastd)} instead.'
+    work = img.copy().astype(np.float32)
+    if to_rgb:
+        cv2.cvtColor(work, cv2.COLOR_BGR2RGB, work)  # inplace
+    # One random weight per eigenvalue, drawn on every call.
+    alpha = np.random.normal(0, alphastd, n_eigval)
+    grid = (3, n_eigval)
+    alpha_grid = np.broadcast_to(alpha.reshape(1, n_eigval), grid)
+    eigval_grid = np.broadcast_to(eigval.reshape(1, n_eigval), grid)
+    weighted = eigvec * alpha_grid * eigval_grid
+    # Sum over the eigen-directions to get one offset per channel.
+    channel_shift = weighted.sum(axis=1).reshape(1, 1, 3)
+    return work + np.broadcast_to(channel_shift, work.shape)
+def lut_transform(img, lut_table):
+    """Map every value of an array through a look-up table.
+    The output holds the table entries selected by the values of the
+    input array, which is cast to uint8 before the lookup.
+    Args:
+        img (ndarray): Image to be transformed.
+        lut_table (ndarray): look-up table of 256 elements; in case of
+            multi-channel input array, the table should either have a single
+            channel (in this case the same table is used for all channels) or
+            the same number of channels as in the input array.
+    Returns:
+        ndarray: The transformed image.
+    """
+    assert isinstance(img, np.ndarray)
+    assert 0 <= np.min(img) and np.max(img) <= 255
+    assert isinstance(lut_table, np.ndarray)
+    assert lut_table.shape == (256, )
+    indices = np.array(img, dtype=np.uint8)
+    return cv2.LUT(indices, lut_table)
+def clahe(img, clip_limit=40.0, tile_grid_size=(8, 8)):
+    """Process a single-channel image with the CLAHE method.
+    See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J].
+    Graphics Gems, 1994:474-485.` for more information.
+    Args:
+        img (ndarray): Image to be processed; must be 2-D.
+        clip_limit (float): Threshold for contrast limiting. Default: 40.0.
+        tile_grid_size (tuple[int]): Size of grid for histogram equalization.
+            Input image will be divided into equally sized rectangular tiles.
+            It defines the number of tiles in row and column. Default: (8, 8).
+    Returns:
+        ndarray: The processed image.
+    """
+    assert isinstance(img, np.ndarray)
+    assert img.ndim == 2
+    assert isinstance(clip_limit, (float, int))
+    assert is_tuple_of(tile_grid_size, int)
+    assert len(tile_grid_size) == 2
+    equalizer = cv2.createCLAHE(clip_limit, tile_grid_size)
+    as_uint8 = np.array(img, dtype=np.uint8)
+    return equalizer.apply(as_uint8)
diff --git a/ControlNet/annotator/uniformer/mmcv/model_zoo/deprecated.json b/ControlNet/annotator/uniformer/mmcv/model_zoo/deprecated.json
new file mode 100644
index 0000000000000000000000000000000000000000..25cf6f28caecc22a77e3136fefa6b8dfc0e6cb5b
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/model_zoo/deprecated.json
@@ -0,0 +1,6 @@
+ "resnet50_caffe": "detectron/resnet50_caffe",
+ "resnet50_caffe_bgr": "detectron2/resnet50_caffe_bgr",
+ "resnet101_caffe": "detectron/resnet101_caffe",
+ "resnet101_caffe_bgr": "detectron2/resnet101_caffe_bgr"
diff --git a/ControlNet/annotator/uniformer/mmcv/model_zoo/mmcls.json b/ControlNet/annotator/uniformer/mmcv/model_zoo/mmcls.json
new file mode 100644
index 0000000000000000000000000000000000000000..bdb311d9fe6d9f317290feedc9e37236c6cf6e8f
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/model_zoo/mmcls.json
@@ -0,0 +1,31 @@
+ "vgg11": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg11_batch256_imagenet_20210208-4271cd6c.pth",
+ "vgg13": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg13_batch256_imagenet_20210208-4d1d6080.pth",
+ "vgg16": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_batch256_imagenet_20210208-db26f1a5.pth",
+ "vgg19": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg19_batch256_imagenet_20210208-e6920e4a.pth",
+ "vgg11_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg11_bn_batch256_imagenet_20210207-f244902c.pth",
+ "vgg13_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg13_bn_batch256_imagenet_20210207-1a8b7864.pth",
+ "vgg16_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_bn_batch256_imagenet_20210208-7e55cd29.pth",
+ "vgg19_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg19_bn_batch256_imagenet_20210208-da620c4f.pth",
+ "resnet18": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_batch256_imagenet_20200708-34ab8f90.pth",
+ "resnet34": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_batch256_imagenet_20200708-32ffb4f7.pth",
+ "resnet50": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_batch256_imagenet_20200708-cfb998bf.pth",
+ "resnet101": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_batch256_imagenet_20200708-753f3608.pth",
+ "resnet152": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_batch256_imagenet_20200708-ec25b1f9.pth",
+ "resnet50_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d50_batch256_imagenet_20200708-1ad0ce94.pth",
+ "resnet101_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d101_batch256_imagenet_20200708-9cb302ef.pth",
+ "resnet152_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d152_batch256_imagenet_20200708-e79cb6a2.pth",
+ "resnext50_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext50_32x4d_b32x8_imagenet_20210429-56066e27.pth",
+ "resnext101_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext101_32x4d_b32x8_imagenet_20210506-e0fa3dd5.pth",
+ "resnext101_32x8d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext101_32x8d_b32x8_imagenet_20210506-23a247d5.pth",
+ "resnext152_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext152_32x4d_b32x8_imagenet_20210524-927787be.pth",
+ "se-resnet50": "https://download.openmmlab.com/mmclassification/v0/se-resnet/se-resnet50_batch256_imagenet_20200804-ae206104.pth",
+ "se-resnet101": "https://download.openmmlab.com/mmclassification/v0/se-resnet/se-resnet101_batch256_imagenet_20200804-ba5b51d4.pth",
+ "resnest50": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest50_imagenet_converted-1ebf0afe.pth",
+ "resnest101": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest101_imagenet_converted-032caa52.pth",
+ "resnest200": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest200_imagenet_converted-581a60f2.pth",
+ "resnest269": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest269_imagenet_converted-59930960.pth",
+ "shufflenet_v1": "https://download.openmmlab.com/mmclassification/v0/shufflenet_v1/shufflenet_v1_batch1024_imagenet_20200804-5d6cec73.pth",
+ "shufflenet_v2": "https://download.openmmlab.com/mmclassification/v0/shufflenet_v2/shufflenet_v2_batch1024_imagenet_20200812-5bf4721e.pth",
+ "mobilenet_v2": "https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth"
diff --git a/ControlNet/annotator/uniformer/mmcv/model_zoo/open_mmlab.json b/ControlNet/annotator/uniformer/mmcv/model_zoo/open_mmlab.json
new file mode 100644
index 0000000000000000000000000000000000000000..8311db4feef92faa0841c697d75efbee8430c3a0
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/model_zoo/open_mmlab.json
@@ -0,0 +1,50 @@
+ "vgg16_caffe": "https://download.openmmlab.com/pretrain/third_party/vgg16_caffe-292e1171.pth",
+ "detectron/resnet50_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet50_caffe-788b5fa3.pth",
+ "detectron2/resnet50_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet50_msra-5891d200.pth",
+ "detectron/resnet101_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet101_caffe-3ad79236.pth",
+ "detectron2/resnet101_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet101_msra-6cc46731.pth",
+ "detectron2/resnext101_32x8d": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x8d-1516f1aa.pth",
+ "resnext50_32x4d": "https://download.openmmlab.com/pretrain/third_party/resnext50-32x4d-0ab1a123.pth",
+ "resnext101_32x4d": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d-a5af3160.pth",
+ "resnext101_64x4d": "https://download.openmmlab.com/pretrain/third_party/resnext101_64x4d-ee2c6f71.pth",
+ "contrib/resnet50_gn": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn_thangvubk-ad1730dd.pth",
+ "detectron/resnet50_gn": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn-9186a21c.pth",
+ "detectron/resnet101_gn": "https://download.openmmlab.com/pretrain/third_party/resnet101_gn-cac0ab98.pth",
+ "jhu/resnet50_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn_ws-15beedd8.pth",
+ "jhu/resnet101_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnet101_gn_ws-3e3c308c.pth",
+ "jhu/resnext50_32x4d_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnext50_32x4d_gn_ws-0d87ac85.pth",
+ "jhu/resnext101_32x4d_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d_gn_ws-34ac1a9e.pth",
+ "jhu/resnext50_32x4d_gn": "https://download.openmmlab.com/pretrain/third_party/resnext50_32x4d_gn-c7e8b754.pth",
+ "jhu/resnext101_32x4d_gn": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d_gn-ac3bb84e.pth",
+ "msra/hrnetv2_w18_small": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w18_small-b5a04e21.pth",
+ "msra/hrnetv2_w18": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w18-00eb2006.pth",
+ "msra/hrnetv2_w32": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w32-dc9eeb4f.pth",
+ "msra/hrnetv2_w40": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w40-ed0b031c.pth",
+ "msra/hrnetv2_w48": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w48-d2186c55.pth",
+ "bninception_caffe": "https://download.openmmlab.com/pretrain/third_party/bn_inception_caffe-ed2e8665.pth",
+ "kin400/i3d_r50_f32s2_k400": "https://download.openmmlab.com/pretrain/third_party/i3d_r50_f32s2_k400-2c57e077.pth",
+ "kin400/nl3d_r50_f32s2_k400": "https://download.openmmlab.com/pretrain/third_party/nl3d_r50_f32s2_k400-fa7e7caa.pth",
+ "res2net101_v1d_26w_4s": "https://download.openmmlab.com/pretrain/third_party/res2net101_v1d_26w_4s_mmdetv2-f0a600f9.pth",
+ "regnetx_400mf": "https://download.openmmlab.com/pretrain/third_party/regnetx_400mf-a5b10d96.pth",
+ "regnetx_800mf": "https://download.openmmlab.com/pretrain/third_party/regnetx_800mf-1f4be4c7.pth",
+ "regnetx_1.6gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_1.6gf-5791c176.pth",
+ "regnetx_3.2gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_3.2gf-c2599b0f.pth",
+ "regnetx_4.0gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_4.0gf-a88f671e.pth",
+ "regnetx_6.4gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_6.4gf-006af45d.pth",
+ "regnetx_8.0gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_8.0gf-3c68abe7.pth",
+ "regnetx_12gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_12gf-4c2a3350.pth",
+ "resnet18_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet18_v1c-b5776b93.pth",
+ "resnet50_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet50_v1c-2cccc1ad.pth",
+ "resnet101_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet101_v1c-e67eebb6.pth",
+ "mmedit/vgg16": "https://download.openmmlab.com/mmediting/third_party/vgg_state_dict.pth",
+ "mmedit/res34_en_nomixup": "https://download.openmmlab.com/mmediting/third_party/model_best_resnet34_En_nomixup.pth",
+ "mmedit/mobilenet_v2": "https://download.openmmlab.com/mmediting/third_party/mobilenet_v2.pth",
+ "contrib/mobilenet_v3_large": "https://download.openmmlab.com/pretrain/third_party/mobilenet_v3_large-bc2c3fd3.pth",
+ "contrib/mobilenet_v3_small": "https://download.openmmlab.com/pretrain/third_party/mobilenet_v3_small-47085aa1.pth",
+ "resnest50": "https://download.openmmlab.com/pretrain/third_party/resnest50_d2-7497a55b.pth",
+ "resnest101": "https://download.openmmlab.com/pretrain/third_party/resnest101_d2-f3b931b2.pth",
+ "resnest200": "https://download.openmmlab.com/pretrain/third_party/resnest200_d2-ca88e41f.pth",
+ "darknet53": "https://download.openmmlab.com/pretrain/third_party/darknet53-a628ea1b.pth",
+ "mmdet/mobilenet_v2": "https://download.openmmlab.com/mmdetection/v2.0/third_party/mobilenet_v2_batch256_imagenet-ff34753d.pth"
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/__init__.py b/ControlNet/annotator/uniformer/mmcv/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..999e090a458ee148ceca0649f1e3806a40e909bd
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/__init__.py
@@ -0,0 +1,81 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .assign_score_withk import assign_score_withk
+from .ball_query import ball_query
+from .bbox import bbox_overlaps
+from .border_align import BorderAlign, border_align
+from .box_iou_rotated import box_iou_rotated
+from .carafe import CARAFE, CARAFENaive, CARAFEPack, carafe, carafe_naive
+from .cc_attention import CrissCrossAttention
+from .contour_expand import contour_expand
+from .corner_pool import CornerPool
+from .correlation import Correlation
+from .deform_conv import DeformConv2d, DeformConv2dPack, deform_conv2d
+from .deform_roi_pool import (DeformRoIPool, DeformRoIPoolPack,
+ ModulatedDeformRoIPoolPack, deform_roi_pool)
+from .deprecated_wrappers import Conv2d_deprecated as Conv2d
+from .deprecated_wrappers import ConvTranspose2d_deprecated as ConvTranspose2d
+from .deprecated_wrappers import Linear_deprecated as Linear
+from .deprecated_wrappers import MaxPool2d_deprecated as MaxPool2d
+from .focal_loss import (SigmoidFocalLoss, SoftmaxFocalLoss,
+ sigmoid_focal_loss, softmax_focal_loss)
+from .furthest_point_sample import (furthest_point_sample,
+ furthest_point_sample_with_dist)
+from .fused_bias_leakyrelu import FusedBiasLeakyReLU, fused_bias_leakyrelu
+from .gather_points import gather_points
+from .group_points import GroupAll, QueryAndGroup, grouping_operation
+from .info import (get_compiler_version, get_compiling_cuda_version,
+ get_onnxruntime_op_path)
+from .iou3d import boxes_iou_bev, nms_bev, nms_normal_bev
+from .knn import knn
+from .masked_conv import MaskedConv2d, masked_conv2d
+from .modulated_deform_conv import (ModulatedDeformConv2d,
+ ModulatedDeformConv2dPack,
+ modulated_deform_conv2d)
+from .multi_scale_deform_attn import MultiScaleDeformableAttention
+from .nms import batched_nms, nms, nms_match, nms_rotated, soft_nms
+from .pixel_group import pixel_group
+from .point_sample import (SimpleRoIAlign, point_sample,
+ rel_roi_point_to_rel_img_point)
+from .points_in_boxes import (points_in_boxes_all, points_in_boxes_cpu,
+ points_in_boxes_part)
+from .points_sampler import PointsSampler
+from .psa_mask import PSAMask
+from .roi_align import RoIAlign, roi_align
+from .roi_align_rotated import RoIAlignRotated, roi_align_rotated
+from .roi_pool import RoIPool, roi_pool
+from .roiaware_pool3d import RoIAwarePool3d
+from .roipoint_pool3d import RoIPointPool3d
+from .saconv import SAConv2d
+from .scatter_points import DynamicScatter, dynamic_scatter
+from .sync_bn import SyncBatchNorm
+from .three_interpolate import three_interpolate
+from .three_nn import three_nn
+from .tin_shift import TINShift, tin_shift
+from .upfirdn2d import upfirdn2d
+from .voxelize import Voxelization, voxelization
+# Public names re-exported by the ops package.
+__all__ = [
+    'bbox_overlaps', 'CARAFE', 'CARAFENaive', 'CARAFEPack', 'carafe',
+    'carafe_naive', 'CornerPool', 'DeformConv2d', 'DeformConv2dPack',
+    'deform_conv2d', 'DeformRoIPool', 'DeformRoIPoolPack',
+    'ModulatedDeformRoIPoolPack', 'deform_roi_pool', 'SigmoidFocalLoss',
+    'SoftmaxFocalLoss', 'sigmoid_focal_loss', 'softmax_focal_loss',
+    'get_compiler_version', 'get_compiling_cuda_version',
+    'get_onnxruntime_op_path', 'MaskedConv2d', 'masked_conv2d',
+    'ModulatedDeformConv2d', 'ModulatedDeformConv2dPack',
+    'modulated_deform_conv2d', 'batched_nms', 'nms', 'soft_nms', 'nms_match',
+    'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool', 'SyncBatchNorm', 'Conv2d',
+    'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask',
+    'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign',
+    'SAConv2d', 'TINShift', 'tin_shift', 'assign_score_withk',
+    'box_iou_rotated', 'RoIPointPool3d', 'nms_rotated', 'knn', 'ball_query',
+    'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu',
+    'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'QueryAndGroup',
+    'GroupAll', 'grouping_operation', 'contour_expand', 'three_nn',
+    'three_interpolate', 'MultiScaleDeformableAttention', 'BorderAlign',
+    'border_align', 'gather_points', 'furthest_point_sample',
+    'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation',
+    'boxes_iou_bev', 'nms_bev', 'nms_normal_bev', 'Voxelization',
+    'voxelization', 'dynamic_scatter', 'DynamicScatter', 'RoIAwarePool3d',
+    'points_in_boxes_part', 'points_in_boxes_cpu', 'points_in_boxes_all'
+]
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/assign_score_withk.py b/ControlNet/annotator/uniformer/mmcv/ops/assign_score_withk.py
new file mode 100644
index 0000000000000000000000000000000000000000..4906adaa2cffd1b46912fbe7d4f87ef2f9fa0012
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/assign_score_withk.py
@@ -0,0 +1,123 @@
+from torch.autograd import Function
+from ..utils import ext_loader
+ext_module = ext_loader.load_ext(
+ '_ext', ['assign_score_withk_forward', 'assign_score_withk_backward'])
+class AssignScoreWithK(Function):
+ r"""Perform weighted sum to generate output features according to scores.
+ Modified from `PAConv <https://github.com/CVMI-Lab/PAConv>`_.
+ This is a memory-efficient CUDA implementation of assign_scores operation,
+ which first transform all point features with weight bank, then assemble
+ neighbor features with ``knn_idx`` and perform weighted sum of ``scores``.
+ See the PAConv paper, appendix Sec. D, for
+ more detailed descriptions.
+ Note:
+ This implementation assumes using ``neighbor`` kernel input, which is
+ (point_features - center_features, point_features).
+ See https://github.com/CVMI-Lab/PAConv/blob/main/scene_seg/model/
+ pointnet2/paconv.py#L128 for more details.
+ """
+ @staticmethod
+ def forward(ctx,
+ scores,
+ point_features,
+ center_features,
+ knn_idx,
+ aggregate='sum'):
+ """
+ Args:
+ scores (torch.Tensor): (B, npoint, K, M), predicted scores to
+ aggregate weight matrices in the weight bank.
+ ``npoint`` is the number of sampled centers.
+ ``K`` is the number of queried neighbors.
+ ``M`` is the number of weight matrices in the weight bank.
+ point_features (torch.Tensor): (B, N, M, out_dim)
+ Pre-computed point features to be aggregated.
+ center_features (torch.Tensor): (B, N, M, out_dim)
+ Pre-computed center features to be aggregated.
+ knn_idx (torch.Tensor): (B, npoint, K), index of sampled kNN.
+ We assume the first idx in each row is the idx of the center.
+ aggregate (str, optional): Aggregation method.
+ Can be 'sum', 'avg' or 'max'. Defaults: 'sum'.
+ Returns:
+ torch.Tensor: (B, out_dim, npoint, K), the aggregated features.
+ """
+ # Integer codes for the aggregation mode; an unknown name raises KeyError
+ # at the ``agg[aggregate]`` lookups below.
+ agg = {'sum': 0, 'avg': 1, 'max': 2}
+ B, N, M, out_dim = point_features.size()
+ _, npoint, K, _ = scores.size()
+ # Zero-initialised buffer handed to the extension, which is expected to
+ # fill it in place (the call's return value is not used).
+ output = point_features.new_zeros((B, out_dim, npoint, K))
+ ext_module.assign_score_withk_forward(
+ point_features.contiguous(),
+ center_features.contiguous(),
+ scores.contiguous(),
+ knn_idx.contiguous(),
+ output,
+ B=B,
+ N0=N,
+ N1=npoint,
+ M=M,
+ K=K,
+ O=out_dim,
+ aggregate=agg[aggregate])
+ # ``output`` is saved as well, but backward discards it.
+ ctx.save_for_backward(output, point_features, center_features, scores,
+ knn_idx)
+ # Keep the integer mode code so backward passes the same value.
+ ctx.agg = agg[aggregate]
+ return output
+ @staticmethod
+ def backward(ctx, grad_out):
+ """
+ Args:
+ grad_out (torch.Tensor): (B, out_dim, npoint, K)
+ Returns:
+ grad_scores (torch.Tensor): (B, npoint, K, M)
+ grad_point_features (torch.Tensor): (B, N, M, out_dim)
+ grad_center_features (torch.Tensor): (B, N, M, out_dim)
+ """
+ # The first saved tensor is the forward output, which is not needed here.
+ _, point_features, center_features, scores, knn_idx = ctx.saved_tensors
+ agg = ctx.agg
+ B, N, M, out_dim = point_features.size()
+ _, npoint, K, _ = scores.size()
+ # Zero-initialised gradient buffers passed to the extension.
+ grad_point_features = point_features.new_zeros(point_features.shape)
+ grad_center_features = center_features.new_zeros(center_features.shape)
+ grad_scores = scores.new_zeros(scores.shape)
+ ext_module.assign_score_withk_backward(
+ grad_out.contiguous(),
+ point_features.contiguous(),
+ center_features.contiguous(),
+ scores.contiguous(),
+ knn_idx.contiguous(),
+ grad_point_features,
+ grad_center_features,
+ grad_scores,
+ B=B,
+ N0=N,
+ N1=npoint,
+ M=M,
+ K=K,
+ O=out_dim,
+ aggregate=agg)
+ # One entry per forward input, in forward's argument order; ``knn_idx``
+ # and ``aggregate`` receive no gradient.
+ return grad_scores, grad_point_features, \
+ grad_center_features, None, None
+# Functional entry point: calls the autograd Function through ``apply``.
+assign_score_withk = AssignScoreWithK.apply
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/ball_query.py b/ControlNet/annotator/uniformer/mmcv/ops/ball_query.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0466847c6e5c1239e359a0397568413ebc1504a
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/ball_query.py
@@ -0,0 +1,55 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch.autograd import Function
+from ..utils import ext_loader
+ext_module = ext_loader.load_ext('_ext', ['ball_query_forward'])
+class BallQuery(Function):
+    """Find nearby points in spherical space."""
+
+    @staticmethod
+    def forward(ctx, min_radius: float, max_radius: float, sample_num: int,
+                xyz: torch.Tensor, center_xyz: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            min_radius (float): minimum radius of the balls.
+            max_radius (float): maximum radius of the balls.
+            sample_num (int): maximum number of features in the balls.
+            xyz (Tensor): (B, N, 3) xyz coordinates of the features.
+            center_xyz (Tensor): (B, npoint, 3) centers of the ball query.
+
+        Returns:
+            Tensor: (B, npoint, nsample) tensor with the indices of
+                the features that form the query balls.
+
+        Raises:
+            AssertionError: if either tensor is not contiguous or if
+                ``min_radius`` is not strictly smaller than ``max_radius``.
+        """
+        assert center_xyz.is_contiguous()
+        assert xyz.is_contiguous()
+        assert min_radius < max_radius
+
+        B, N, _ = xyz.size()
+        npoint = center_xyz.size(1)
+        # Integer index buffer handed to the extension kernel.
+        idx = xyz.new_zeros(B, npoint, sample_num, dtype=torch.int)
+
+        ext_module.ball_query_forward(
+            center_xyz,
+            xyz,
+            idx,
+            b=B,
+            n=N,
+            m=npoint,
+            min_radius=min_radius,
+            max_radius=max_radius,
+            nsample=sample_num)
+        # The parrots build is skipped here, as in the original code.
+        if torch.__version__ != 'parrots':
+            ctx.mark_non_differentiable(idx)
+        return idx
+
+    @staticmethod
+    def backward(ctx, a=None):
+        """Return no gradients.
+
+        autograd expects one entry per ``forward`` input (five here:
+        min_radius, max_radius, sample_num, xyz, center_xyz); the previous
+        implementation returned only four.
+        """
+        return None, None, None, None, None
+
+
+ball_query = BallQuery.apply
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/bbox.py b/ControlNet/annotator/uniformer/mmcv/ops/bbox.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c4d58b6c91f652933974f519acd3403a833e906
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/bbox.py
@@ -0,0 +1,72 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..utils import ext_loader
+ext_module = ext_loader.load_ext('_ext', ['bbox_overlaps'])
+def bbox_overlaps(bboxes1, bboxes2, mode='iou', aligned=False, offset=0):
+ """Calculate overlap between two set of bboxes.
+ If ``aligned`` is ``False``, then calculate the ious between each bbox
+ of bboxes1 and bboxes2, otherwise the ious between each aligned pair of
+ bboxes1 and bboxes2.
+ Args:
+ bboxes1 (Tensor): shape (m, 4) in <x1, y1, x2, y2> format or empty.
+ bboxes2 (Tensor): shape (n, 4) in <x1, y1, x2, y2> format or empty.
+ If aligned is ``True``, then m and n must be equal.
+ mode (str): "iou" (intersection over union) or iof (intersection over
+ foreground).
+ aligned (bool): If ``True``, compare bboxes1[i] with bboxes2[i] only.
+ Defaults to ``False``.
+ offset (int): Must be 0 or 1; forwarded unchanged to the extension.
+ Defaults to 0.
+ Returns:
+ ious(Tensor): shape (m, n) if aligned == False else shape (m,).
+ When either input is empty the result is an empty tensor of shape
+ (m, n), or (0, 1) if aligned.
+ Example:
+ >>> bboxes1 = torch.FloatTensor([
+ >>> [0, 0, 10, 10],
+ >>> [10, 10, 20, 20],
+ >>> [32, 32, 38, 42],
+ >>> ])
+ >>> bboxes2 = torch.FloatTensor([
+ >>> [0, 0, 10, 20],
+ >>> [0, 10, 10, 19],
+ >>> [10, 10, 20, 20],
+ >>> ])
+ >>> bbox_overlaps(bboxes1, bboxes2)
+ tensor([[0.5000, 0.0000, 0.0000],
+ [0.0000, 0.0000, 1.0000],
+ [0.0000, 0.0000, 0.0000]])
+ Example:
+ >>> empty = torch.FloatTensor([])
+ >>> nonempty = torch.FloatTensor([
+ >>> [0, 0, 10, 9],
+ >>> ])
+ >>> assert tuple(bbox_overlaps(empty, nonempty).shape) == (0, 1)
+ >>> assert tuple(bbox_overlaps(nonempty, empty).shape) == (1, 0)
+ >>> assert tuple(bbox_overlaps(empty, empty).shape) == (0, 0)
+ """
+ # NOTE(review): the box format in the docstring was restored from upstream
+ # mmcv; the extension kernel that interprets it is not visible here.
+ mode_dict = {'iou': 0, 'iof': 1}
+ assert mode in mode_dict.keys()
+ mode_flag = mode_dict[mode]
+ # Either the boxes are empty or the length of boxes' last dimension is 4
+ assert (bboxes1.size(-1) == 4 or bboxes1.size(0) == 0)
+ assert (bboxes2.size(-1) == 4 or bboxes2.size(0) == 0)
+ assert offset == 1 or offset == 0
+ rows = bboxes1.size(0)
+ cols = bboxes2.size(0)
+ if aligned:
+ assert rows == cols
+ # Empty input: return early without calling the extension. ``new`` does not
+ # initialise memory, which is harmless because the result has no elements.
+ if rows * cols == 0:
+ return bboxes1.new(rows, 1) if aligned else bboxes1.new(rows, cols)
+ # NOTE(review): the aligned result is 1-D (rows,), unlike the (0, 1) shape
+ # returned by the empty aligned case above.
+ if aligned:
+ ious = bboxes1.new_zeros(rows)
+ else:
+ ious = bboxes1.new_zeros((rows, cols))
+ # ``ious`` is handed to the extension, which is expected to fill it in place.
+ ext_module.bbox_overlaps(
+ bboxes1, bboxes2, ious, mode=mode_flag, aligned=aligned, offset=offset)
+ return ious
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/border_align.py b/ControlNet/annotator/uniformer/mmcv/ops/border_align.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff305be328e9b0a15e1bbb5e6b41beb940f55c81
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/border_align.py
@@ -0,0 +1,109 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# modified from
+# https://github.com/Megvii-BaseDetection/cvpods/blob/master/cvpods/layers/border_align.py
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from ..utils import ext_loader
+ext_module = ext_loader.load_ext(
+ '_ext', ['border_align_forward', 'border_align_backward'])
+class BorderAlignFunction(Function):
+ """Autograd function wrapping the ``border_align`` extension kernels."""
+ @staticmethod
+ def symbolic(g, input, boxes, pool_size):
+ # ONNX export: emit the custom ``mmcv::MMCVBorderAlign`` op.
+ return g.op(
+ 'mmcv::MMCVBorderAlign', input, boxes, pool_size_i=pool_size)
+ @staticmethod
+ def forward(ctx, input, boxes, pool_size):
+ # Stash what backward needs besides the saved tensors.
+ ctx.pool_size = pool_size
+ ctx.input_shape = input.size()
+ assert boxes.ndim == 3, 'boxes must be with shape [B, H*W, 4]'
+ assert boxes.size(2) == 4, \
+ 'the last dimension of boxes must be (x1, y1, x2, y2)'
+ assert input.size(1) % 4 == 0, \
+ 'the channel for input feature must be divisible by factor 4'
+ # [B, C//4, H*W, 4]
+ output_shape = (input.size(0), input.size(1) // 4, boxes.size(1), 4)
+ output = input.new_zeros(output_shape)
+ # `argmax_idx` only used for backward
+ argmax_idx = input.new_zeros(output_shape).to(torch.int)
+ # ``output`` and ``argmax_idx`` are handed to the extension, which is
+ # expected to fill them in place.
+ ext_module.border_align_forward(
+ input, boxes, output, argmax_idx, pool_size=ctx.pool_size)
+ ctx.save_for_backward(boxes, argmax_idx)
+ return output
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ boxes, argmax_idx = ctx.saved_tensors
+ # Gradient buffer with the shape of the forward input.
+ grad_input = grad_output.new_zeros(ctx.input_shape)
+ # complex head architecture may cause grad_output uncontiguous
+ grad_output = grad_output.contiguous()
+ ext_module.border_align_backward(
+ grad_output,
+ boxes,
+ argmax_idx,
+ grad_input,
+ pool_size=ctx.pool_size)
+ # Only ``input`` gets a gradient; ``boxes`` and ``pool_size`` get None.
+ return grad_input, None, None
+# Functional entry point: calls the autograd Function through ``apply``.
+border_align = BorderAlignFunction.apply
+class BorderAlign(nn.Module):
+    r"""Border align pooling layer.
+
+    Runs ``border_align`` on an input feature map using predicted bboxes,
+    as described in the paper "BorderDet: Border Feature for Dense Object
+    Detection".
+
+    For every border line of every box (top, left, bottom and right) the
+    op proceeds as follows:
+
+    1. `pool_size`+1 positions are sampled uniformly along the line, \
+        start and end points included.
+    2. the feature at each sampled position is obtained by bilinear \
+        interpolation.
+    3. the pooled feature is the maximum over these `pool_size`+1 \
+        positions.
+
+    Args:
+        pool_size (int): number of positions sampled over the boxes' borders
+            (e.g. top, bottom, left, right).
+    """
+
+    def __init__(self, pool_size):
+        super().__init__()
+        self.pool_size = pool_size
+
+    def forward(self, input, boxes):
+        """Pool border features for every box.
+
+        Args:
+            input: Features with shape [N,4C,H,W]. Channels ranged in [0,C),
+                [C,2C), [2C,3C), [3C,4C) represent the top, left, bottom,
+                right features respectively.
+            boxes: Boxes with shape [N,H*W,4]. Coordinate format (x1,y1,x2,y2).
+
+        Returns:
+            Tensor: Pooled features with shape [N,C,H*W,4]. The order is
+                (top,left,bottom,right) for the last dimension.
+        """
+        pooled = border_align(input, boxes, self.pool_size)
+        return pooled
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}(pool_size={self.pool_size})'
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/box_iou_rotated.py b/ControlNet/annotator/uniformer/mmcv/ops/box_iou_rotated.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d78015e9c2a9e7a52859b4e18f84a9aa63481a0
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/box_iou_rotated.py
@@ -0,0 +1,45 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..utils import ext_loader
+ext_module = ext_loader.load_ext('_ext', ['box_iou_rotated'])
+def box_iou_rotated(bboxes1, bboxes2, mode='iou', aligned=False):
+    """Return intersection-over-union (Jaccard index) of rotated boxes.
+
+    Both box sets use the (x_center, y_center, width, height, angle)
+    format. With ``aligned=False`` every box of ``bboxes1`` is compared
+    with every box of ``bboxes2``; with ``aligned=True`` only the pairs
+    sharing the same index are compared.
+
+    Arguments:
+        bboxes1 (Tensor): rotated bboxes 1. \
+            It has shape (N, 5), indicating (x, y, w, h, theta) for each row.
+            Note that theta is in radian.
+        bboxes2 (Tensor): rotated bboxes 2. \
+            It has shape (M, 5), indicating (x, y, w, h, theta) for each row.
+            Note that theta is in radian.
+        mode (str): "iou" (intersection over union) or iof (intersection over
+            foreground).
+        aligned (bool): compare index-aligned pairs only. Defaults to False.
+
+    Returns:
+        ious(Tensor): shape (N, M) if aligned == False else shape (N,)
+    """
+    assert mode in ['iou', 'iof']
+    mode_flag = {'iou': 0, 'iof': 1}[mode]
+
+    num_rows = bboxes1.size(0)
+    num_cols = bboxes2.size(0)
+    # The extension receives a flat buffer; it is reshaped afterwards for the
+    # pairwise case.
+    buffer_len = num_rows if aligned else num_rows * num_cols
+    ious = bboxes1.new_zeros(buffer_len)
+
+    bboxes1 = bboxes1.contiguous()
+    bboxes2 = bboxes2.contiguous()
+    ext_module.box_iou_rotated(
+        bboxes1, bboxes2, ious, mode_flag=mode_flag, aligned=aligned)
+
+    if aligned:
+        return ious
+    return ious.view(num_rows, num_cols)
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/carafe.py b/ControlNet/annotator/uniformer/mmcv/ops/carafe.py
new file mode 100644
index 0000000000000000000000000000000000000000..5154cb3abfccfbbe0a1b2daa67018dbf80aaf6d2
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/carafe.py
@@ -0,0 +1,287 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Function
+from torch.nn.modules.module import Module
+from ..cnn import UPSAMPLE_LAYERS, normal_init, xavier_init
+from ..utils import ext_loader
+# Load the compiled ``_ext`` extension, requesting the four CARAFE kernels.
+ext_module = ext_loader.load_ext('_ext', [
+    'carafe_naive_forward', 'carafe_naive_backward', 'carafe_forward',
+    'carafe_backward'
+])
+class CARAFENaiveFunction(Function):
+ """Autograd function wrapping the ``carafe_naive`` extension kernels."""
+ @staticmethod
+ def symbolic(g, features, masks, kernel_size, group_size, scale_factor):
+ # ONNX export: emit the custom ``mmcv::MMCVCARAFENaive`` op.
+ return g.op(
+ 'mmcv::MMCVCARAFENaive',
+ features,
+ masks,
+ kernel_size_i=kernel_size,
+ group_size_i=group_size,
+ scale_factor_f=scale_factor)
+ @staticmethod
+ def forward(ctx, features, masks, kernel_size, group_size, scale_factor):
+ # Shape contract: masks carry kernel_size^2 * group_size channels at the
+ # upsampled resolution, and kernel_size must be odd.
+ assert scale_factor >= 1
+ assert masks.size(1) == kernel_size * kernel_size * group_size
+ assert masks.size(-1) == features.size(-1) * scale_factor
+ assert masks.size(-2) == features.size(-2) * scale_factor
+ assert features.size(1) % group_size == 0
+ assert (kernel_size - 1) % 2 == 0 and kernel_size >= 1
+ ctx.kernel_size = kernel_size
+ ctx.group_size = group_size
+ ctx.scale_factor = scale_factor
+ ctx.feature_size = features.size()
+ ctx.mask_size = masks.size()
+ n, c, h, w = features.size()
+ # Output buffer at the upsampled resolution, handed to the extension.
+ output = features.new_zeros((n, c, h * scale_factor, w * scale_factor))
+ ext_module.carafe_naive_forward(
+ features,
+ masks,
+ output,
+ kernel_size=kernel_size,
+ group_size=group_size,
+ scale_factor=scale_factor)
+ # Tensors are saved only when a gradient may be requested.
+ if features.requires_grad or masks.requires_grad:
+ ctx.save_for_backward(features, masks)
+ return output
+ @staticmethod
+ def backward(ctx, grad_output):
+ # Backward rejects non-CUDA gradients.
+ assert grad_output.is_cuda
+ features, masks = ctx.saved_tensors
+ kernel_size = ctx.kernel_size
+ group_size = ctx.group_size
+ scale_factor = ctx.scale_factor
+ # Zero-initialised gradient buffers passed to the extension.
+ grad_input = torch.zeros_like(features)
+ grad_masks = torch.zeros_like(masks)
+ ext_module.carafe_naive_backward(
+ grad_output.contiguous(),
+ features,
+ masks,
+ grad_input,
+ grad_masks,
+ kernel_size=kernel_size,
+ group_size=group_size,
+ scale_factor=scale_factor)
+ # Gradients for features and masks; the three scalar arguments get None.
+ return grad_input, grad_masks, None, None, None
+# Functional entry point: calls the autograd Function through ``apply``.
+carafe_naive = CARAFENaiveFunction.apply
+class CARAFENaive(Module):
+    """Module wrapper that stores the CARAFE settings and calls
+    ``carafe_naive`` in ``forward``.
+
+    Args:
+        kernel_size (int): reassemble kernel size
+        group_size (int): reassemble group size
+        scale_factor (int): upsample ratio
+    """
+
+    def __init__(self, kernel_size, group_size, scale_factor):
+        super().__init__()
+
+        # All three settings must be plain ints.
+        assert all(
+            isinstance(value, int)
+            for value in (kernel_size, group_size, scale_factor))
+        self.kernel_size = kernel_size
+        self.group_size = group_size
+        self.scale_factor = scale_factor
+
+    def forward(self, features, masks):
+        settings = (self.kernel_size, self.group_size, self.scale_factor)
+        return carafe_naive(features, masks, *settings)
+class CARAFEFunction(Function):
+ """Autograd function wrapping the ``carafe`` extension kernels."""
+ @staticmethod
+ def symbolic(g, features, masks, kernel_size, group_size, scale_factor):
+ # ONNX export: emit the custom ``mmcv::MMCVCARAFE`` op.
+ return g.op(
+ 'mmcv::MMCVCARAFE',
+ features,
+ masks,
+ kernel_size_i=kernel_size,
+ group_size_i=group_size,
+ scale_factor_f=scale_factor)
+ @staticmethod
+ def forward(ctx, features, masks, kernel_size, group_size, scale_factor):
+ # Shape contract: masks carry kernel_size^2 * group_size channels at the
+ # upsampled resolution, and kernel_size must be odd.
+ assert scale_factor >= 1
+ assert masks.size(1) == kernel_size * kernel_size * group_size
+ assert masks.size(-1) == features.size(-1) * scale_factor
+ assert masks.size(-2) == features.size(-2) * scale_factor
+ assert features.size(1) % group_size == 0
+ assert (kernel_size - 1) % 2 == 0 and kernel_size >= 1
+ ctx.kernel_size = kernel_size
+ ctx.group_size = group_size
+ ctx.scale_factor = scale_factor
+ ctx.feature_size = features.size()
+ ctx.mask_size = masks.size()
+ n, c, h, w = features.size()
+ output = features.new_zeros((n, c, h * scale_factor, w * scale_factor))
+ # NOTE(review): the r* tensors look like scratch buffers the kernel uses
+ # for rearranged copies of its inputs/outputs; confirm against the
+ # extension source.
+ routput = features.new_zeros(output.size(), requires_grad=False)
+ rfeatures = features.new_zeros(features.size(), requires_grad=False)
+ rmasks = masks.new_zeros(masks.size(), requires_grad=False)
+ ext_module.carafe_forward(
+ features,
+ masks,
+ rfeatures,
+ routput,
+ rmasks,
+ output,
+ kernel_size=kernel_size,
+ group_size=group_size,
+ scale_factor=scale_factor)
+ # Tensors are saved only when a gradient may be requested; ``rfeatures``
+ # is kept because backward passes it to the kernel.
+ if features.requires_grad or masks.requires_grad:
+ ctx.save_for_backward(features, masks, rfeatures)
+ return output
+ @staticmethod
+ def backward(ctx, grad_output):
+ # Backward rejects non-CUDA gradients.
+ assert grad_output.is_cuda
+ features, masks, rfeatures = ctx.saved_tensors
+ kernel_size = ctx.kernel_size
+ group_size = ctx.group_size
+ scale_factor = ctx.scale_factor
+ # Four scratch buffers (rgrad_*) plus the two gradient buffers that are
+ # returned; all are zero-initialised and passed to the extension.
+ rgrad_output = torch.zeros_like(grad_output, requires_grad=False)
+ rgrad_input_hs = torch.zeros_like(grad_output, requires_grad=False)
+ rgrad_input = torch.zeros_like(features, requires_grad=False)
+ rgrad_masks = torch.zeros_like(masks, requires_grad=False)
+ grad_input = torch.zeros_like(features, requires_grad=False)
+ grad_masks = torch.zeros_like(masks, requires_grad=False)
+ ext_module.carafe_backward(
+ grad_output.contiguous(),
+ rfeatures,
+ masks,
+ rgrad_output,
+ rgrad_input_hs,
+ rgrad_input,
+ rgrad_masks,
+ grad_input,
+ grad_masks,
+ kernel_size=kernel_size,
+ group_size=group_size,
+ scale_factor=scale_factor)
+ # Gradients for features and masks; the three scalar arguments get None.
+ return grad_input, grad_masks, None, None, None
+# Functional entry point: calls the autograd Function through ``apply``.
+carafe = CARAFEFunction.apply
+class CARAFE(Module):
+    """ CARAFE: Content-Aware ReAssembly of FEatures
+
+    See https://arxiv.org/abs/1905.02188 for the method itself.
+
+    Args:
+        kernel_size (int): reassemble kernel size
+        group_size (int): reassemble group size
+        scale_factor (int): upsample ratio
+
+    Returns:
+        upsampled feature map
+    """
+
+    def __init__(self, kernel_size, group_size, scale_factor):
+        super().__init__()
+
+        # All three settings must be plain ints.
+        assert all(
+            isinstance(value, int)
+            for value in (kernel_size, group_size, scale_factor))
+        self.kernel_size = kernel_size
+        self.group_size = group_size
+        self.scale_factor = scale_factor
+
+    def forward(self, features, masks):
+        settings = (self.kernel_size, self.group_size, self.scale_factor)
+        return carafe(features, masks, *settings)
+class CARAFEPack(nn.Module):
+    """CARAFE upsampler bundled with its mask predictor: 1) channel
+    compressor 2) content encoder 3) CARAFE op.
+
+    Official implementation of ICCV 2019 paper
+    CARAFE: Content-Aware ReAssembly of FEatures
+    Please refer to https://arxiv.org/abs/1905.02188 for more details.
+
+    Args:
+        channels (int): input feature channels
+        scale_factor (int): upsample ratio
+        up_kernel (int): kernel size of CARAFE op
+        up_group (int): group size of CARAFE op
+        encoder_kernel (int): kernel size of content encoder
+        encoder_dilation (int): dilation of content encoder
+        compressed_channels (int): output channels of channels compressor
+
+    Returns:
+        upsampled feature map
+    """
+
+    def __init__(self,
+                 channels,
+                 scale_factor,
+                 up_kernel=5,
+                 up_group=1,
+                 encoder_kernel=3,
+                 encoder_dilation=1,
+                 compressed_channels=64):
+        super().__init__()
+        self.channels = channels
+        self.scale_factor = scale_factor
+        self.up_kernel = up_kernel
+        self.up_group = up_group
+        self.encoder_kernel = encoder_kernel
+        self.encoder_dilation = encoder_dilation
+        self.compressed_channels = compressed_channels
+
+        # 1x1 convolution that reduces the channel count before encoding.
+        self.channel_compressor = nn.Conv2d(channels, self.compressed_channels,
+                                            1)
+
+        # The encoder predicts the mask channels that kernel_normalizer
+        # pixel-shuffles and normalises.
+        encoder_out_channels = (
+            self.up_kernel * self.up_kernel * self.up_group *
+            self.scale_factor * self.scale_factor)
+        encoder_padding = int(
+            (self.encoder_kernel - 1) * self.encoder_dilation / 2)
+        self.content_encoder = nn.Conv2d(
+            self.compressed_channels,
+            encoder_out_channels,
+            self.encoder_kernel,
+            padding=encoder_padding,
+            dilation=self.encoder_dilation,
+            groups=1)
+        self.init_weights()
+
+    def init_weights(self):
+        """Xavier-initialise every conv, then re-initialise the content
+        encoder with a small-std normal distribution."""
+        conv_layers = (
+            layer for layer in self.modules() if isinstance(layer, nn.Conv2d))
+        for layer in conv_layers:
+            xavier_init(layer, distribution='uniform')
+        normal_init(self.content_encoder, std=0.001)
+
+    def kernel_normalizer(self, mask):
+        """Pixel-shuffle the raw mask and softmax-normalise it over the
+        kernel positions."""
+        shuffled = F.pixel_shuffle(mask, self.scale_factor)
+        n, mask_c, h, w = shuffled.size()
+        # use float division explicitly,
+        # to void inconsistency while exporting to onnx
+        mask_channel = int(mask_c / float(self.up_kernel**2))
+        grouped = shuffled.view(n, mask_channel, -1, h, w)
+        normalized = F.softmax(grouped, dim=2, dtype=grouped.dtype)
+        return normalized.view(n, mask_c, h, w).contiguous()
+
+    def feature_reassemble(self, x, mask):
+        """Apply the CARAFE op to ``x`` using the normalised ``mask``."""
+        return carafe(x, mask, self.up_kernel, self.up_group,
+                      self.scale_factor)
+
+    def forward(self, x):
+        raw_mask = self.content_encoder(self.channel_compressor(x))
+        mask = self.kernel_normalizer(raw_mask)
+        return self.feature_reassemble(x, mask)
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/cc_attention.py b/ControlNet/annotator/uniformer/mmcv/ops/cc_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..9207aa95e6730bd9b3362dee612059a5f0ce1c5e
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/cc_attention.py
@@ -0,0 +1,83 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from annotator.uniformer.mmcv.cnn import PLUGIN_LAYERS, Scale
+def NEG_INF_DIAG(n, device):
+    """Build an [n, n] matrix with "-inf" on the diagonal and 0 elsewhere.
+
+    Adding it to the column-wise energy keeps the element shared by the
+    row and the column of the Criss-Cross pattern from being counted
+    twice.
+    """
+    neg_inf_vector = torch.full((n, ), float('-inf'), device=device)
+    return torch.diag(neg_inf_vector, 0)
+class CrissCrossAttention(nn.Module):
+    """Criss-Cross Attention Module.
+
+    .. note::
+        Before v1.3.13, we use a CUDA op. Since v1.3.13, we switch
+        to a pure PyTorch and equivalent implementation. For more
+        details, please refer to https://github.com/open-mmlab/mmcv/pull/1201.
+
+        Speed comparison for one forward pass
+
+        - Input size: [2,512,97,97]
+        - Device: 1 NVIDIA GeForce RTX 2080 Ti
+
+        +-----------------------+---------------+------------+---------------+
+        |                       |PyTorch version|CUDA version|Relative speed |
+        +=======================+===============+============+===============+
+        |with torch.no_grad()   |0.00554402 s   |0.0299619 s |5.4x           |
+        +-----------------------+---------------+------------+---------------+
+        |no with torch.no_grad()|0.00562803 s   |0.0301349 s |5.4x           |
+        +-----------------------+---------------+------------+---------------+
+
+    Args:
+        in_channels (int): Channels of the input feature map.
+    """
+
+    def __init__(self, in_channels):
+        super().__init__()
+        reduced_channels = in_channels // 8
+        self.query_conv = nn.Conv2d(in_channels, reduced_channels, 1)
+        self.key_conv = nn.Conv2d(in_channels, reduced_channels, 1)
+        self.value_conv = nn.Conv2d(in_channels, in_channels, 1)
+        self.gamma = Scale(0.)
+        self.in_channels = in_channels
+
+    def forward(self, x):
+        """forward function of Criss-Cross Attention.
+
+        Args:
+            x (Tensor): Input feature. \
+                shape (batch_size, in_channels, height, width)
+
+        Returns:
+            Tensor: Output of the layer, with shape of \
+                (batch_size, in_channels, height, width)
+        """
+        _, _, H, W = x.size()
+        query = self.query_conv(x)
+        key = self.key_conv(x)
+        value = self.value_conv(x)
+
+        # Column-wise energy; the -inf diagonal masks the position itself so
+        # it is only counted once, through the row-wise energy.
+        column_energy = torch.einsum('bchw,bciw->bwhi', query, key)
+        column_energy = column_energy + NEG_INF_DIAG(H, query.device)
+        energy_H = column_energy.transpose(1, 2)
+        energy_W = torch.einsum('bchw,bchj->bhwj', query, key)
+
+        # Joint softmax over the H column and W row candidates.
+        joint_energy = torch.cat([energy_H, energy_W], dim=-1)
+        attn = F.softmax(joint_energy, dim=-1)  # [B,H,W,(H+W)]
+        attn_H = attn[..., :H]
+        attn_W = attn[..., H:]
+
+        out = torch.einsum('bciw,bhwi->bchw', value, attn_H)
+        out += torch.einsum('bchj,bhwj->bchw', value, attn_W)
+
+        # Scaled attention output added to the input as a residual.
+        out = self.gamma(out) + x
+        return out.contiguous()
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}(in_channels={self.in_channels})'
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/contour_expand.py b/ControlNet/annotator/uniformer/mmcv/ops/contour_expand.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea1111e1768b5f27e118bf7dbc0d9c70a7afd6d7
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/contour_expand.py
@@ -0,0 +1,49 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+from ..utils import ext_loader
+ext_module = ext_loader.load_ext('_ext', ['contour_expand'])
+def contour_expand(kernel_mask, internal_kernel_label, min_kernel_area,
+                   kernel_num):
+    """Expand kernel contours so that foreground pixels are assigned into
+    instances.
+
+    NumPy inputs are converted to tensors before the extension is called.
+
+    Arguments:
+        kernel_mask (np.array or Tensor): The instance kernel mask with
+            size hxw.
+        internal_kernel_label (np.array or Tensor): The instance internal
+            kernel label with size hxw.
+        min_kernel_area (int): The minimum kernel area.
+        kernel_num (int): The instance kernel number.
+
+    Returns:
+        label (list): The instance index map with size hxw.
+    """
+    array_types = (torch.Tensor, np.ndarray)
+    assert isinstance(kernel_mask, array_types)
+    assert isinstance(internal_kernel_label, array_types)
+    assert isinstance(min_kernel_area, int)
+    assert isinstance(kernel_num, int)
+
+    if isinstance(kernel_mask, np.ndarray):
+        kernel_mask = torch.from_numpy(kernel_mask)
+    if isinstance(internal_kernel_label, np.ndarray):
+        internal_kernel_label = torch.from_numpy(internal_kernel_label)
+
+    # Regular PyTorch build: positional call, result returned unchanged.
+    if torch.__version__ != 'parrots':
+        return ext_module.contour_expand(kernel_mask, internal_kernel_label,
+                                         min_kernel_area, kernel_num)
+
+    # parrots build: empty inputs short-circuit to an empty list, otherwise
+    # the returned tensor is converted to a list.
+    if kernel_mask.shape[0] == 0 or internal_kernel_label.shape[0] == 0:
+        return []
+    label = ext_module.contour_expand(
+        kernel_mask,
+        internal_kernel_label,
+        min_kernel_area=min_kernel_area,
+        kernel_num=kernel_num)
+    return label.tolist()
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/corner_pool.py b/ControlNet/annotator/uniformer/mmcv/ops/corner_pool.py
new file mode 100644
index 0000000000000000000000000000000000000000..a33d798b43d405e4c86bee4cd6389be21ca9c637
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/corner_pool.py
@@ -0,0 +1,161 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch import nn
+from torch.autograd import Function
+from ..utils import ext_loader
+# Load the corner-pool ops (one forward/backward pair per direction) from
+# the `_ext` extension module.
+ext_module = ext_loader.load_ext('_ext', [
+    'top_pool_forward', 'top_pool_backward', 'bottom_pool_forward',
+    'bottom_pool_backward', 'left_pool_forward', 'left_pool_backward',
+    'right_pool_forward', 'right_pool_backward'
+])
+# Integer mode codes passed as `mode_i` to the ONNX `MMCVCornerPool` op.
+_mode_dict = {'top': 0, 'bottom': 1, 'left': 2, 'right': 3}
+class TopPoolFunction(Function):
+    """Autograd function backed by the `top_pool` extension ops."""
+
+    @staticmethod
+    def symbolic(g, input):
+        # ONNX export: emit the custom corner-pool op tagged with the
+        # 'top' mode code.
+        return g.op(
+            'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['top']))
+
+    @staticmethod
+    def forward(ctx, input):
+        pooled = ext_module.top_pool_forward(input)
+        # The backward op takes the original input, so keep it on ctx.
+        ctx.save_for_backward(input)
+        return pooled
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        (saved_input, ) = ctx.saved_tensors
+        return ext_module.top_pool_backward(saved_input, grad_output)
+class BottomPoolFunction(Function):
+    """Autograd function backed by the `bottom_pool` extension ops."""
+
+    @staticmethod
+    def symbolic(g, input):
+        # ONNX export: emit the custom corner-pool op tagged with the
+        # 'bottom' mode code.
+        return g.op(
+            'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['bottom']))
+
+    @staticmethod
+    def forward(ctx, input):
+        pooled = ext_module.bottom_pool_forward(input)
+        # The backward op takes the original input, so keep it on ctx.
+        ctx.save_for_backward(input)
+        return pooled
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        (saved_input, ) = ctx.saved_tensors
+        return ext_module.bottom_pool_backward(saved_input, grad_output)
+class LeftPoolFunction(Function):
+    """Autograd function backed by the `left_pool` extension ops."""
+
+    @staticmethod
+    def symbolic(g, input):
+        # ONNX export: emit the custom corner-pool op tagged with the
+        # 'left' mode code.
+        return g.op(
+            'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['left']))
+
+    @staticmethod
+    def forward(ctx, input):
+        pooled = ext_module.left_pool_forward(input)
+        # The backward op takes the original input, so keep it on ctx.
+        ctx.save_for_backward(input)
+        return pooled
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        (saved_input, ) = ctx.saved_tensors
+        return ext_module.left_pool_backward(saved_input, grad_output)
+class RightPoolFunction(Function):
+    """Autograd function backed by the `right_pool` extension ops."""
+
+    @staticmethod
+    def symbolic(g, input):
+        # ONNX export: emit the custom corner-pool op tagged with the
+        # 'right' mode code.
+        return g.op(
+            'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['right']))
+
+    @staticmethod
+    def forward(ctx, input):
+        pooled = ext_module.right_pool_forward(input)
+        # The backward op takes the original input, so keep it on ctx.
+        ctx.save_for_backward(input)
+        return pooled
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        (saved_input, ) = ctx.saved_tensors
+        return ext_module.right_pool_backward(saved_input, grad_output)
+class CornerPool(nn.Module):
+    """Corner Pooling.
+
+    Corner Pooling is a new type of pooling layer that helps a
+    convolutional network better localize corners of bounding boxes.
+    Please refer to https://arxiv.org/abs/1808.01244 for more details.
+    Code is modified from https://github.com/princeton-vl/CornerNet-Lite.
+
+    Args:
+        mode(str): Pooling orientation for the pooling layer
+
+            - 'bottom': Bottom Pooling
+            - 'left': Left Pooling
+            - 'right': Right Pooling
+            - 'top': Top Pooling
+
+    Returns:
+        Feature map after pooling.
+    """
+
+    # Fallback autograd functions (extension ops), keyed by mode.
+    pool_functions = {
+        'bottom': BottomPoolFunction,
+        'left': LeftPoolFunction,
+        'right': RightPoolFunction,
+        'top': TopPoolFunction,
+    }
+
+    # For the `torch.cummax` path: (dimension to scan, whether the tensor is
+    # flipped along that dimension before and after the scan).
+    cummax_dim_flip = {
+        'bottom': (2, False),
+        'left': (3, True),
+        'right': (3, False),
+        'top': (2, True),
+    }
+
+    def __init__(self, mode):
+        super(CornerPool, self).__init__()
+        assert mode in self.pool_functions
+        self.mode = mode
+        self.corner_pool = self.pool_functions[mode]
+
+    @staticmethod
+    def _torch_version_at_least(major, minor):
+        """Return True if the running torch is at least ``major.minor``.
+
+        Version strings must not be compared directly: the comparison is
+        lexicographic, so '1.10.0' sorts below '1.5.0'. The leading digits
+        of the first two dot-separated components are compared as integers
+        instead. A component with no leading digits (e.g. 'parrots') makes
+        the check return False.
+        """
+        components = []
+        release = torch.__version__.split('+')[0]
+        for part in release.split('.')[:2]:
+            digits = ''
+            for ch in part:
+                if not ch.isdigit():
+                    break
+                digits += ch
+            if not digits:
+                return False
+            components.append(int(digits))
+        return tuple(components) >= (major, minor)
+
+    def forward(self, x):
+        use_cummax = (
+            torch.__version__ != 'parrots'
+            and self._torch_version_at_least(1, 5))
+        if not use_cummax:
+            # parrots or torch older than 1.5: use the extension op.
+            return self.corner_pool.apply(x)
+
+        if torch.onnx.is_in_onnx_export():
+            assert self._torch_version_at_least(1, 7), \
+                'When `cummax` serves as an intermediate component whose '\
+                'outputs is used as inputs for another modules, it\'s '\
+                'expected that pytorch version must be >= 1.7.0, '\
+                'otherwise Error appears like: `RuntimeError: tuple '\
+                'appears in op that does not forward tuples, unsupported '\
+                'kind: prim::PythonOp`.'
+
+        dim, flip = self.cummax_dim_flip[self.mode]
+        # A running maximum is taken along `dim`; for the flipped modes the
+        # tensor is reversed first and the result reversed back.
+        if flip:
+            x = x.flip(dim)
+        pool_tensor, _ = torch.cummax(x, dim=dim)
+        if flip:
+            pool_tensor = pool_tensor.flip(dim)
+        return pool_tensor
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/correlation.py b/ControlNet/annotator/uniformer/mmcv/ops/correlation.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d0b79c301b29915dfaf4d2b1846c59be73127d3
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/correlation.py
@@ -0,0 +1,196 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch import Tensor, nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair
+from ..utils import ext_loader
+ext_module = ext_loader.load_ext(
+ '_ext', ['correlation_forward', 'correlation_backward'])
+class CorrelationFunction(Function):
+    """Autograd function backed by the ``correlation_forward`` and
+    ``correlation_backward`` extension ops."""
+
+    @staticmethod
+    def forward(ctx,
+                input1,
+                input2,
+                kernel_size=1,
+                max_displacement=1,
+                stride=1,
+                padding=1,
+                dilation=1,
+                dilation_patch=1):
+        ctx.save_for_backward(input1, input2)
+
+        # Every hyper-parameter is stored on ctx; `backward` and
+        # `_output_size` read them back from there.
+        ctx.kernel_size = _pair(kernel_size)
+        ctx.patch_size = max_displacement * 2 + 1
+        ctx.stride = _pair(stride)
+        ctx.padding = _pair(padding)
+        ctx.dilation = _pair(dilation)
+        ctx.dilation_patch = _pair(dilation_patch)
+
+        kernel_h, kernel_w = ctx.kernel_size
+        stride_h, stride_w = ctx.stride
+        pad_h, pad_w = ctx.padding
+        dil_h, dil_w = ctx.dilation
+        dil_patch_h, dil_patch_w = ctx.dilation_patch
+
+        # The extension op writes its result into this zero-filled buffer.
+        output = input1.new_zeros(
+            CorrelationFunction._output_size(ctx, input1))
+        ext_module.correlation_forward(
+            input1,
+            input2,
+            output,
+            kH=kernel_h,
+            kW=kernel_w,
+            patchH=ctx.patch_size,
+            patchW=ctx.patch_size,
+            padH=pad_h,
+            padW=pad_w,
+            dilationH=dil_h,
+            dilationW=dil_w,
+            dilation_patchH=dil_patch_h,
+            dilation_patchW=dil_patch_w,
+            dH=stride_h,
+            dW=stride_w)
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        input1, input2 = ctx.saved_tensors
+
+        kernel_h, kernel_w = ctx.kernel_size
+        pad_h, pad_w = ctx.padding
+        dil_h, dil_w = ctx.dilation
+        dil_patch_h, dil_patch_w = ctx.dilation_patch
+        stride_h, stride_w = ctx.stride
+
+        # Gradient buffers filled in place by the extension op.
+        grad_input1 = torch.zeros_like(input1)
+        grad_input2 = torch.zeros_like(input2)
+        ext_module.correlation_backward(
+            grad_output,
+            input1,
+            input2,
+            grad_input1,
+            grad_input2,
+            kH=kernel_h,
+            kW=kernel_w,
+            patchH=ctx.patch_size,
+            patchW=ctx.patch_size,
+            padH=pad_h,
+            padW=pad_w,
+            dilationH=dil_h,
+            dilationW=dil_w,
+            dilation_patchH=dil_patch_h,
+            dilation_patchW=dil_patch_w,
+            dH=stride_h,
+            dW=stride_w)
+
+        # One entry per forward argument; the six hyper-parameters get None.
+        return grad_input1, grad_input2, None, None, None, None, None, None
+
+    @staticmethod
+    def _output_size(ctx, input1):
+        """Return the output shape (N, patch, patch, oH, oW)."""
+        batch_size = input1.size(0)
+        in_h, in_w = input1.size(2), input1.size(3)
+        kernel_h, kernel_w = ctx.kernel_size
+        stride_h, stride_w = ctx.stride
+        pad_h, pad_w = ctx.padding
+        dil_h, dil_w = ctx.dilation
+
+        # Effective kernel extent once dilation is applied.
+        dilated_kh = (kernel_h - 1) * dil_h + 1
+        dilated_kw = (kernel_w - 1) * dil_w + 1
+        out_h = int((in_h + 2 * pad_h - dilated_kh) / stride_h + 1)
+        out_w = int((in_w + 2 * pad_w - dilated_kw) / stride_w + 1)
+        return (batch_size, ctx.patch_size, ctx.patch_size, out_h, out_w)
+class Correlation(nn.Module):
+    r"""Correlation operator
+
+    This correlation operator works for optical flow correlation computation.
+
+    There are two batched tensors with shape :math:`(N, C, H, W)`,
+    and the correlation output's shape is :math:`(N, max\_displacement \times
+    2 + 1, max\_displacement * 2 + 1, H_{out}, W_{out})`
+
+    where
+
+    .. math::
+        H_{out} = \left\lfloor\frac{H_{in}  + 2 \times padding -
+            dilation \times (kernel\_size - 1) - 1}
+            {stride} + 1\right\rfloor
+
+    .. math::
+        W_{out} = \left\lfloor\frac{W_{in}  + 2 \times padding - dilation
+            \times (kernel\_size - 1) - 1}
+            {stride} + 1\right\rfloor
+
+    the correlation item :math:`(N_i, dy, dx)` is formed by taking the sliding
+    window convolution between input1 and shifted input2,
+
+    .. math::
+        Corr(N_i, dx, dy) =
+        \sum_{c=0}^{C-1}
+        input1(N_i, c) \star
+        \mathcal{S}(input2(N_i, c), dy, dx)
+
+    where :math:`\star` is the valid 2d sliding window convolution operator,
+    and :math:`\mathcal{S}` means shifting the input features (auto-complete
+    zero marginal), and :math:`dx, dy` are shifting distance, :math:`dx, dy \in
+    [-max\_displacement \times dilation\_patch, max\_displacement \times
+    dilation\_patch]`.
+
+    Args:
+        kernel_size (int): The size of sliding window i.e. local neighborhood
+            representing the center points and involved in correlation
+            computation. Defaults to 1.
+        max_displacement (int): The radius for computing correlation volume,
+            but the actual working space can be dilated by dilation_patch.
+            Defaults to 1.
+        stride (int): The stride of the sliding blocks in the input spatial
+            dimensions. Defaults to 1.
+        padding (int): Zero padding added to all four sides of the input1.
+            Defaults to 0.
+        dilation (int): The spacing of the local neighborhood that is
+            involved in correlation. Defaults to 1.
+        dilation_patch (int): The spacing between positions for which the
+            correlation is computed. Defaults to 1.
+    """
+
+    def __init__(self,
+                 kernel_size: int = 1,
+                 max_displacement: int = 1,
+                 stride: int = 1,
+                 padding: int = 0,
+                 dilation: int = 1,
+                 dilation_patch: int = 1) -> None:
+        super().__init__()
+        # Plain attribute storage; the values are handed to
+        # CorrelationFunction on every forward call.
+        self.kernel_size = kernel_size
+        self.max_displacement = max_displacement
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.dilation_patch = dilation_patch
+
+    def forward(self, input1: Tensor, input2: Tensor) -> Tensor:
+        hyper_params = (self.kernel_size, self.max_displacement, self.stride,
+                        self.padding, self.dilation, self.dilation_patch)
+        return CorrelationFunction.apply(input1, input2, *hyper_params)
+
+    def __repr__(self) -> str:
+        pieces = [
+            self.__class__.__name__,
+            f'(kernel_size={self.kernel_size}, ',
+            f'max_displacement={self.max_displacement}, ',
+            f'stride={self.stride}, ',
+            f'padding={self.padding}, ',
+            f'dilation={self.dilation}, ',
+            f'dilation_patch={self.dilation_patch})',
+        ]
+        return ''.join(pieces)
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/deform_conv.py b/ControlNet/annotator/uniformer/mmcv/ops/deform_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3f8c75ee774823eea334e3b3732af6a18f55038
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/deform_conv.py
@@ -0,0 +1,405 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Tuple, Union
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair, _single
+from annotator.uniformer.mmcv.utils import deprecated_api_warning
+from ..cnn import CONV_LAYERS
+from ..utils import ext_loader, print_log
+# Load the deformable-convolution ops from the `_ext` extension module.
+ext_module = ext_loader.load_ext('_ext', [
+    'deform_conv_forward', 'deform_conv_backward_input',
+    'deform_conv_backward_parameters'
+])
+class DeformConv2dFunction(Function):
+    """Autograd function backed by the deformable-convolution extension
+    ops (`deform_conv_forward`, `deform_conv_backward_input`,
+    `deform_conv_backward_parameters`)."""
+
+    @staticmethod
+    def symbolic(g,
+                 input,
+                 offset,
+                 weight,
+                 stride,
+                 padding,
+                 dilation,
+                 groups,
+                 deform_groups,
+                 bias=False,
+                 im2col_step=32):
+        # ONNX export: emit the custom op with all hyper-parameters as
+        # integer attributes.
+        return g.op(
+            'mmcv::MMCVDeformConv2d',
+            input,
+            offset,
+            weight,
+            stride_i=stride,
+            padding_i=padding,
+            dilation_i=dilation,
+            groups_i=groups,
+            deform_groups_i=deform_groups,
+            bias_i=bias,
+            im2col_step_i=im2col_step)
+
+    @staticmethod
+    def forward(ctx,
+                input,
+                offset,
+                weight,
+                stride=1,
+                padding=0,
+                dilation=1,
+                groups=1,
+                deform_groups=1,
+                bias=False,
+                im2col_step=32):
+        if input is not None and input.dim() != 4:
+            # Adjacent literals instead of a backslash inside the string, so
+            # no source indentation leaks into the message.
+            raise ValueError(
+                f'Expected 4D tensor as input, got {input.dim()}D tensor '
+                'instead.')
+        assert bias is False, 'Only support bias is False.'
+        ctx.stride = _pair(stride)
+        ctx.padding = _pair(padding)
+        ctx.dilation = _pair(dilation)
+        ctx.groups = groups
+        ctx.deform_groups = deform_groups
+        ctx.im2col_step = im2col_step
+
+        # When pytorch version >= 1.6.0, amp is adopted for fp16 mode;
+        # amp won't cast the type of model (float32), but "offset" is cast
+        # to float16 by nn.Conv2d automatically, leading to the type
+        # mismatch with input (when it is float32) or weight.
+        # The flag for whether to use fp16 or amp is the type of "offset",
+        # we cast weight and input to temporarily support fp16 and amp
+        # whatever the pytorch version is.
+        input = input.type_as(offset)
+        weight = weight.type_as(input)
+        ctx.save_for_backward(input, offset, weight)
+
+        output = input.new_empty(
+            DeformConv2dFunction._output_size(ctx, input, weight))
+
+        ctx.bufs_ = [input.new_empty(0), input.new_empty(0)]  # columns, ones
+
+        # Never process more samples per step than the batch holds.
+        cur_im2col_step = min(ctx.im2col_step, input.size(0))
+        assert (input.size(0) %
+                cur_im2col_step) == 0, 'im2col step must divide batchsize'
+        ext_module.deform_conv_forward(
+            input,
+            weight,
+            offset,
+            output,
+            ctx.bufs_[0],
+            ctx.bufs_[1],
+            kW=weight.size(3),
+            kH=weight.size(2),
+            dW=ctx.stride[1],
+            dH=ctx.stride[0],
+            padW=ctx.padding[1],
+            padH=ctx.padding[0],
+            dilationW=ctx.dilation[1],
+            dilationH=ctx.dilation[0],
+            group=ctx.groups,
+            deformable_group=ctx.deform_groups,
+            im2col_step=cur_im2col_step)
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        input, offset, weight = ctx.saved_tensors
+
+        grad_input = grad_offset = grad_weight = None
+
+        cur_im2col_step = min(ctx.im2col_step, input.size(0))
+        assert (input.size(0) % cur_im2col_step
+                ) == 0, 'batch size must be divisible by im2col_step'
+
+        grad_output = grad_output.contiguous()
+        # Input and offset gradients come from one extension call, so both
+        # are computed when either is required.
+        if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
+            grad_input = torch.zeros_like(input)
+            grad_offset = torch.zeros_like(offset)
+            ext_module.deform_conv_backward_input(
+                input,
+                offset,
+                grad_output,
+                grad_input,
+                grad_offset,
+                weight,
+                ctx.bufs_[0],
+                kW=weight.size(3),
+                kH=weight.size(2),
+                dW=ctx.stride[1],
+                dH=ctx.stride[0],
+                padW=ctx.padding[1],
+                padH=ctx.padding[0],
+                dilationW=ctx.dilation[1],
+                dilationH=ctx.dilation[0],
+                group=ctx.groups,
+                deformable_group=ctx.deform_groups,
+                im2col_step=cur_im2col_step)
+
+        if ctx.needs_input_grad[2]:
+            grad_weight = torch.zeros_like(weight)
+            ext_module.deform_conv_backward_parameters(
+                input,
+                offset,
+                grad_output,
+                grad_weight,
+                ctx.bufs_[0],
+                ctx.bufs_[1],
+                kW=weight.size(3),
+                kH=weight.size(2),
+                dW=ctx.stride[1],
+                dH=ctx.stride[0],
+                padW=ctx.padding[1],
+                padH=ctx.padding[0],
+                dilationW=ctx.dilation[1],
+                dilationH=ctx.dilation[0],
+                group=ctx.groups,
+                deformable_group=ctx.deform_groups,
+                scale=1,
+                im2col_step=cur_im2col_step)
+
+        # One entry per forward argument; the seven non-tensor arguments
+        # get None.
+        return grad_input, grad_offset, grad_weight, \
+            None, None, None, None, None, None, None
+
+    @staticmethod
+    def _output_size(ctx, input, weight):
+        """Return the output shape (N, C_out, *spatial) of the convolution.
+
+        Raises:
+            ValueError: If any dimension of the output would be <= 0.
+        """
+        channels = weight.size(0)
+        output_size = (input.size(0), channels)
+        for d in range(input.dim() - 2):
+            in_size = input.size(d + 2)
+            pad = ctx.padding[d]
+            # Effective kernel extent once dilation is applied.
+            kernel = ctx.dilation[d] * (weight.size(d + 2) - 1) + 1
+            stride_ = ctx.stride[d]
+            output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
+        if not all(s > 0 for s in output_size):
+            raise ValueError(
+                'convolution input is too small (output would be ' +
+                'x'.join(map(str, output_size)) + ')')
+        return output_size
+deform_conv2d = DeformConv2dFunction.apply
+class DeformConv2d(nn.Module):
+    r"""Deformable 2D convolution.
+
+    Applies a deformable 2D convolution over an input signal composed of
+    several input planes. DeformConv2d was described in the paper
+    "Deformable Convolutional Networks"
+    (https://arxiv.org/abs/1703.06211).
+
+    Note:
+        The argument ``im2col_step`` was added in version 1.3.17, which means
+        number of samples processed by the ``im2col_cuda_kernel`` per call.
+        It enables users to define ``batch_size`` and ``im2col_step`` more
+        flexibly and solved issue mmcv#1440
+        (https://github.com/open-mmlab/mmcv/issues/1440).
+
+    Args:
+        in_channels (int): Number of channels in the input image.
+        out_channels (int): Number of channels produced by the convolution.
+        kernel_size(int, tuple): Size of the convolving kernel.
+        stride(int, tuple): Stride of the convolution. Default: 1.
+        padding (int or tuple): Zero-padding added to both sides of the input.
+            Default: 0.
+        dilation (int or tuple): Spacing between kernel elements. Default: 1.
+        groups (int): Number of blocked connections from input.
+            channels to output channels. Default: 1.
+        deform_groups (int): Number of deformable group partitions.
+        bias (bool): Must be False; a truthy value fails an assertion in
+            the constructor. Default: False.
+        im2col_step (int): Number of samples processed by im2col_cuda_kernel
+            per call. It will work when ``batch_size`` > ``im2col_step``, but
+            ``batch_size`` must be divisible by ``im2col_step``. Default: 32.
+            `New in version 1.3.17.`
+    """
+
+    @deprecated_api_warning({'deformable_groups': 'deform_groups'},
+                            cls_name='DeformConv2d')
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: Union[int, Tuple[int, ...]],
+                 stride: Union[int, Tuple[int, ...]] = 1,
+                 padding: Union[int, Tuple[int, ...]] = 0,
+                 dilation: Union[int, Tuple[int, ...]] = 1,
+                 groups: int = 1,
+                 deform_groups: int = 1,
+                 bias: bool = False,
+                 im2col_step: int = 32) -> None:
+        super(DeformConv2d, self).__init__()
+
+        assert not bias, \
+            f'bias={bias} is not supported in DeformConv2d.'
+        assert in_channels % groups == 0, \
+            f'in_channels {in_channels} cannot be divisible by groups {groups}'
+        # Adjacent literals instead of a backslash inside the string, so no
+        # source indentation leaks into the message.
+        assert out_channels % groups == 0, \
+            f'out_channels {out_channels} cannot be divisible by groups ' \
+            f'{groups}'
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = _pair(kernel_size)
+        self.stride = _pair(stride)
+        self.padding = _pair(padding)
+        self.dilation = _pair(dilation)
+        self.groups = groups
+        self.deform_groups = deform_groups
+        self.im2col_step = im2col_step
+        # enable compatibility with nn.Conv2d
+        self.transposed = False
+        self.output_padding = _single(0)
+
+        # only weight, no bias
+        self.weight = nn.Parameter(
+            torch.Tensor(out_channels, in_channels // self.groups,
+                         *self.kernel_size))
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        # switch the initialization of `self.weight` to the standard kaiming
+        # method described in `Delving deep into rectifiers: Surpassing
+        # human-level performance on ImageNet classification` - He, K. et al.
+        # (2015), using a uniform distribution
+        nn.init.kaiming_uniform_(self.weight, nonlinearity='relu')
+
+    def forward(self, x: Tensor, offset: Tensor) -> Tensor:
+        """Deformable Convolutional forward function.
+
+        Args:
+            x (Tensor): Input feature, shape (B, C_in, H_in, W_in)
+            offset (Tensor): Offset for deformable convolution, shape
+                (B, deform_groups*kernel_size[0]*kernel_size[1]*2,
+                H_out, W_out), H_out, W_out are equal to the output's.
+
+                An offset is like `[y0, x0, y1, x1, y2, x2, ..., y8, x8]`.
+                The spatial arrangement is like:
+
+                .. code:: text
+
+                    (x0, y0) (x1, y1) (x2, y2)
+                    (x3, y3) (x4, y4) (x5, y5)
+                    (x6, y6) (x7, y7) (x8, y8)
+
+        Returns:
+            Tensor: Output of the layer.
+        """
+        # To fix an assert error in deform_conv_cuda.cpp:128
+        # input image is smaller than kernel
+        input_pad = (x.size(2) < self.kernel_size[0]) or (x.size(3) <
+                                                          self.kernel_size[1])
+        if input_pad:
+            # Zero-pad bottom/right so the input is at least kernel-sized;
+            # the offset is padded by the same amounts.
+            pad_h = max(self.kernel_size[0] - x.size(2), 0)
+            pad_w = max(self.kernel_size[1] - x.size(3), 0)
+            x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
+            offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0)
+            offset = offset.contiguous()
+        out = deform_conv2d(x, offset, self.weight, self.stride, self.padding,
+                            self.dilation, self.groups, self.deform_groups,
+                            False, self.im2col_step)
+        if input_pad:
+            # Crop the rows/columns that exist only because of the padding.
+            out = out[:, :, :out.size(2) - pad_h, :out.size(3) -
+                      pad_w].contiguous()
+        return out
+
+    def __repr__(self):
+        pieces = [
+            self.__class__.__name__,
+            f'(in_channels={self.in_channels},\n',
+            f'out_channels={self.out_channels},\n',
+            f'kernel_size={self.kernel_size},\n',
+            f'stride={self.stride},\n',
+            f'padding={self.padding},\n',
+            f'dilation={self.dilation},\n',
+            f'groups={self.groups},\n',
+            f'deform_groups={self.deform_groups},\n',
+            # bias is not supported in DeformConv2d.
+            'bias=False)',
+        ]
+        return ''.join(pieces)
+class DeformConv2dPack(DeformConv2d):
+    """A Deformable Conv Encapsulation that acts as normal Conv layers.
+
+    The offsets are predicted from the input by an internal ``conv_offset``
+    convolution, so ``forward`` takes only the feature map.
+
+    The offset tensor is like `[y0, x0, y1, x1, y2, x2, ..., y8, x8]`.
+    The spatial arrangement is like:
+
+    .. code:: text
+
+        (x0, y0) (x1, y1) (x2, y2)
+        (x3, y3) (x4, y4) (x5, y5)
+        (x6, y6) (x7, y7) (x8, y8)
+
+    Args:
+        in_channels (int): Same as nn.Conv2d.
+        out_channels (int): Same as nn.Conv2d.
+        kernel_size (int or tuple[int]): Same as nn.Conv2d.
+        stride (int or tuple[int]): Same as nn.Conv2d.
+        padding (int or tuple[int]): Same as nn.Conv2d.
+        dilation (int or tuple[int]): Same as nn.Conv2d.
+        groups (int): Same as nn.Conv2d.
+        bias (bool): Forwarded unchanged to the parent constructor.
+    """
+
+    _version = 2
+
+    def __init__(self, *args, **kwargs):
+        super(DeformConv2dPack, self).__init__(*args, **kwargs)
+        # Two offset channels for every kernel position of every deform
+        # group.
+        offset_channels = (
+            self.deform_groups * 2 * self.kernel_size[0] *
+            self.kernel_size[1])
+        self.conv_offset = nn.Conv2d(
+            self.in_channels,
+            offset_channels,
+            kernel_size=self.kernel_size,
+            stride=_pair(self.stride),
+            padding=_pair(self.padding),
+            dilation=_pair(self.dilation),
+            bias=True)
+        self.init_offset()
+
+    def init_offset(self):
+        # Zero weight and bias: the offset branch outputs zeros until the
+        # parameters are changed.
+        for param in (self.conv_offset.weight, self.conv_offset.bias):
+            param.data.zero_()
+
+    def forward(self, x):
+        predicted_offset = self.conv_offset(x)
+        return deform_conv2d(x, predicted_offset, self.weight, self.stride,
+                             self.padding, self.dilation, self.groups,
+                             self.deform_groups, False, self.im2col_step)
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        version = local_metadata.get('version', None)
+
+        if version is None or version < 2:
+            # the key is different in early versions
+            # In version < 2, DeformConvPack loads previous benchmark models.
+            legacy_pairs = (
+                ('conv_offset.weight', '_offset.weight'),
+                ('conv_offset.bias', '_offset.bias'),
+            )
+            for new_suffix, old_suffix in legacy_pairs:
+                new_key = prefix + new_suffix
+                # Old keys drop the last character of the prefix before the
+                # `_offset` suffix.
+                old_key = prefix[:-1] + old_suffix
+                if new_key not in state_dict and old_key in state_dict:
+                    state_dict[new_key] = state_dict.pop(old_key)
+
+        if version is not None and version > 1:
+            print_log(
+                f'DeformConv2dPack {prefix.rstrip(".")} is upgraded to '
+                'version 2.',
+                logger='root')
+
+        super()._load_from_state_dict(state_dict, prefix, local_metadata,
+                                      strict, missing_keys, unexpected_keys,
+                                      error_msgs)
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/deform_roi_pool.py b/ControlNet/annotator/uniformer/mmcv/ops/deform_roi_pool.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc245ba91fee252226ba22e76bb94a35db9a629b
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/deform_roi_pool.py
@@ -0,0 +1,204 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from torch import nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair
+from ..utils import ext_loader
+ext_module = ext_loader.load_ext(
+ '_ext', ['deform_roi_pool_forward', 'deform_roi_pool_backward'])
+class DeformRoIPoolFunction(Function):
+    """Autograd function backed by the ``deform_roi_pool_forward`` and
+    ``deform_roi_pool_backward`` extension ops."""
+
+    @staticmethod
+    def symbolic(g, input, rois, offset, output_size, spatial_scale,
+                 sampling_ratio, gamma):
+        pooled_h, pooled_w = output_size[0], output_size[1]
+        return g.op(
+            'mmcv::MMCVDeformRoIPool',
+            input,
+            rois,
+            offset,
+            pooled_height_i=pooled_h,
+            pooled_width_i=pooled_w,
+            spatial_scale_f=spatial_scale,
+            sampling_ratio_f=sampling_ratio,
+            gamma_f=gamma)
+
+    @staticmethod
+    def forward(ctx,
+                input,
+                rois,
+                offset,
+                output_size,
+                spatial_scale=1.0,
+                sampling_ratio=0,
+                gamma=0.1):
+        # A missing offset is represented by an empty tensor.
+        if offset is None:
+            offset = input.new_zeros(0)
+
+        ctx.output_size = _pair(output_size)
+        ctx.spatial_scale = float(spatial_scale)
+        ctx.sampling_ratio = int(sampling_ratio)
+        ctx.gamma = float(gamma)
+
+        assert rois.size(1) == 5, 'RoI must be (idx, x1, y1, x2, y2)!'
+
+        pooled_h, pooled_w = ctx.output_size[0], ctx.output_size[1]
+        # One pooled map per RoI, with as many channels as the input.
+        output = input.new_zeros(
+            (rois.size(0), input.size(1), pooled_h, pooled_w))
+        ext_module.deform_roi_pool_forward(
+            input,
+            rois,
+            offset,
+            output,
+            pooled_height=pooled_h,
+            pooled_width=pooled_w,
+            spatial_scale=ctx.spatial_scale,
+            sampling_ratio=ctx.sampling_ratio,
+            gamma=ctx.gamma)
+
+        ctx.save_for_backward(input, rois, offset)
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        input, rois, offset = ctx.saved_tensors
+
+        # Zero-filled gradient buffers filled in place by the extension op.
+        grad_input = grad_output.new_zeros(input.shape)
+        grad_offset = grad_output.new_zeros(offset.shape)
+        ext_module.deform_roi_pool_backward(
+            grad_output,
+            input,
+            rois,
+            offset,
+            grad_input,
+            grad_offset,
+            pooled_height=ctx.output_size[0],
+            pooled_width=ctx.output_size[1],
+            spatial_scale=ctx.spatial_scale,
+            sampling_ratio=ctx.sampling_ratio,
+            gamma=ctx.gamma)
+
+        # An empty offset (the "no offset" placeholder) gets no gradient.
+        if grad_offset.numel() == 0:
+            grad_offset = None
+        return grad_input, None, grad_offset, None, None, None, None
+deform_roi_pool = DeformRoIPoolFunction.apply
+class DeformRoIPool(nn.Module):
+    """Module wrapper around ``deform_roi_pool``.
+
+    Args:
+        output_size: Pooled output size; normalised with ``_pair``.
+        spatial_scale: Stored as float. Default: 1.0.
+        sampling_ratio: Stored as int. Default: 0.
+        gamma: Stored as float. Default: 0.1.
+    """
+
+    def __init__(self,
+                 output_size,
+                 spatial_scale=1.0,
+                 sampling_ratio=0,
+                 gamma=0.1):
+        super(DeformRoIPool, self).__init__()
+        self.output_size = _pair(output_size)
+        self.spatial_scale = float(spatial_scale)
+        self.sampling_ratio = int(sampling_ratio)
+        self.gamma = float(gamma)
+
+    def forward(self, input, rois, offset=None):
+        pool_cfg = (self.output_size, self.spatial_scale, self.sampling_ratio,
+                    self.gamma)
+        return deform_roi_pool(input, rois, offset, *pool_cfg)
+class DeformRoIPoolPack(DeformRoIPool):
+    """Deformable RoI pooling that predicts its own offsets.
+
+    A first pooling pass without offsets feeds a small fully-connected
+    network (``offset_fc``) whose output is used as the offset for a second
+    pooling pass.
+    """
+
+    def __init__(self,
+                 output_size,
+                 output_channels,
+                 deform_fc_channels=1024,
+                 spatial_scale=1.0,
+                 sampling_ratio=0,
+                 gamma=0.1):
+        super(DeformRoIPoolPack, self).__init__(output_size, spatial_scale,
+                                                sampling_ratio, gamma)
+
+        self.output_channels = output_channels
+        self.deform_fc_channels = deform_fc_channels
+
+        pooled_area = self.output_size[0] * self.output_size[1]
+        hidden = self.deform_fc_channels
+        self.offset_fc = nn.Sequential(
+            nn.Linear(pooled_area * self.output_channels, hidden),
+            nn.ReLU(inplace=True),
+            nn.Linear(hidden, hidden),
+            nn.ReLU(inplace=True),
+            nn.Linear(hidden, pooled_area * 2))
+        # Zero the last layer so the predicted offsets start at zero.
+        last_fc = self.offset_fc[-1]
+        last_fc.weight.data.zero_()
+        last_fc.bias.data.zero_()
+
+    def forward(self, input, rois):
+        assert input.size(1) == self.output_channels
+        pool_cfg = (self.output_size, self.spatial_scale, self.sampling_ratio,
+                    self.gamma)
+
+        # Pass 1: pooling without offsets, used only to predict the offsets.
+        x = deform_roi_pool(input, rois, None, *pool_cfg)
+        rois_num = rois.size(0)
+        offset = self.offset_fc(x.view(rois_num, -1))
+        offset = offset.view(rois_num, 2, self.output_size[0],
+                             self.output_size[1])
+
+        # Pass 2: pooling with the predicted offsets.
+        return deform_roi_pool(input, rois, offset, *pool_cfg)
+class ModulatedDeformRoIPoolPack(DeformRoIPool):
+    """Deformable RoI pooling with predicted offsets and a modulation mask.
+
+    A first pooling pass without offsets feeds two fully-connected heads:
+    ``offset_fc`` predicts the offsets for a second pooling pass, and
+    ``mask_fc`` predicts a sigmoid mask that scales the second-pass output.
+    """
+
+    def __init__(self,
+                 output_size,
+                 output_channels,
+                 deform_fc_channels=1024,
+                 spatial_scale=1.0,
+                 sampling_ratio=0,
+                 gamma=0.1):
+        super(ModulatedDeformRoIPoolPack,
+              self).__init__(output_size, spatial_scale, sampling_ratio, gamma)
+
+        self.output_channels = output_channels
+        self.deform_fc_channels = deform_fc_channels
+
+        pooled_area = self.output_size[0] * self.output_size[1]
+        hidden = self.deform_fc_channels
+
+        self.offset_fc = nn.Sequential(
+            nn.Linear(pooled_area * self.output_channels, hidden),
+            nn.ReLU(inplace=True),
+            nn.Linear(hidden, hidden),
+            nn.ReLU(inplace=True),
+            nn.Linear(hidden, pooled_area * 2))
+        # Zero the last layer so the predicted offsets start at zero.
+        last_offset_fc = self.offset_fc[-1]
+        last_offset_fc.weight.data.zero_()
+        last_offset_fc.bias.data.zero_()
+
+        self.mask_fc = nn.Sequential(
+            nn.Linear(pooled_area * self.output_channels, hidden),
+            nn.ReLU(inplace=True),
+            nn.Linear(hidden, pooled_area * 1),
+            nn.Sigmoid())
+        # Index 2 is the final Linear, just before the Sigmoid.
+        last_mask_fc = self.mask_fc[2]
+        last_mask_fc.weight.data.zero_()
+        last_mask_fc.bias.data.zero_()
+
+    def forward(self, input, rois):
+        assert input.size(1) == self.output_channels
+        pool_cfg = (self.output_size, self.spatial_scale, self.sampling_ratio,
+                    self.gamma)
+        out_h, out_w = self.output_size[0], self.output_size[1]
+
+        # Pass 1: pooling without offsets, used to predict offsets and mask.
+        x = deform_roi_pool(input, rois, None, *pool_cfg)
+        rois_num = rois.size(0)
+
+        offset = self.offset_fc(x.view(rois_num, -1))
+        offset = offset.view(rois_num, 2, out_h, out_w)
+
+        mask = self.mask_fc(x.view(rois_num, -1))
+        mask = mask.view(rois_num, 1, out_h, out_w)
+
+        # Pass 2: pooling with the predicted offsets, scaled by the mask.
+        d = deform_roi_pool(input, rois, offset, *pool_cfg)
+        return d * mask
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/deprecated_wrappers.py b/ControlNet/annotator/uniformer/mmcv/ops/deprecated_wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2e593df9ee57637038683d7a1efaa347b2b69e7
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/deprecated_wrappers.py
@@ -0,0 +1,43 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# This file is for backward compatibility.
+# Module wrappers for empty tensor have been moved to mmcv.cnn.bricks.
+import warnings
+from ..cnn.bricks.wrappers import Conv2d, ConvTranspose2d, Linear, MaxPool2d
+class Conv2d_deprecated(Conv2d):
+    """Backward-compatibility alias of ``Conv2d`` that warns on creation."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # The warning is emitted after the wrapped layer has been built.
+        notice = (
+            'Importing Conv2d wrapper from "mmcv.ops" will be deprecated in'
+            ' the future. Please import them from "mmcv.cnn" instead')
+        warnings.warn(notice)
+class ConvTranspose2d_deprecated(ConvTranspose2d):
+    """Backward-compatibility alias of ``ConvTranspose2d`` that warns on
+    creation."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # The warning is emitted after the wrapped layer has been built.
+        notice = (
+            'Importing ConvTranspose2d wrapper from "mmcv.ops" will be '
+            'deprecated in the future. Please import them from "mmcv.cnn" '
+            'instead')
+        warnings.warn(notice)
+class MaxPool2d_deprecated(MaxPool2d):
+    """Backward-compatibility alias of ``MaxPool2d`` that warns on
+    creation."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # The warning is emitted after the wrapped layer has been built.
+        notice = (
+            'Importing MaxPool2d wrapper from "mmcv.ops" will be deprecated in'
+            ' the future. Please import them from "mmcv.cnn" instead')
+        warnings.warn(notice)
+class Linear_deprecated(Linear):
+    """Backward-compatibility alias of ``Linear`` that warns on creation."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # The warning is emitted after the wrapped layer has been built.
+        notice = (
+            'Importing Linear wrapper from "mmcv.ops" will be deprecated in'
+            ' the future. Please import them from "mmcv.cnn" instead')
+        warnings.warn(notice)
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/focal_loss.py b/ControlNet/annotator/uniformer/mmcv/ops/focal_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..763bc93bd2575c49ca8ccf20996bbd92d1e0d1a4
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/focal_loss.py
@@ -0,0 +1,212 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from ..utils import ext_loader
+# Load the '_ext' extension module, requesting the four focal-loss kernels
+# that the autograd functions in this file call.
+ext_module = ext_loader.load_ext('_ext', [
+    'sigmoid_focal_loss_forward', 'sigmoid_focal_loss_backward',
+    'softmax_focal_loss_forward', 'softmax_focal_loss_backward'
+])
+class SigmoidFocalLossFunction(Function):
+ """Autograd function around the ``_ext`` sigmoid focal loss kernels.
+
+ ``forward`` checks the argument shapes, calls
+ ``sigmoid_focal_loss_forward`` and applies the requested reduction;
+ ``backward`` calls ``sigmoid_focal_loss_backward`` and scales the result
+ by the incoming gradient.
+ """
+ @staticmethod
+ def symbolic(g, input, target, gamma, alpha, weight, reduction):
+ # Graph-export hook: emits one custom 'mmcv::MMCVSigmoidFocalLoss' node
+ # with gamma/alpha/weight/reduction passed as node attributes.
+ return g.op(
+ 'mmcv::MMCVSigmoidFocalLoss',
+ input,
+ target,
+ gamma_f=gamma,
+ alpha_f=alpha,
+ weight_f=weight,
+ reduction_s=reduction)
+ @staticmethod
+ def forward(ctx,
+ input,
+ target,
+ gamma=2.0,
+ alpha=0.25,
+ weight=None,
+ reduction='mean'):
+ # Only Long (int64) targets, on CPU or CUDA, are accepted.
+ assert isinstance(target, (torch.LongTensor, torch.cuda.LongTensor))
+ # input is 2-D, target is 1-D with one entry per row of input.
+ assert input.dim() == 2
+ assert target.dim() == 1
+ assert input.size(0) == target.size(0)
+ if weight is None:
+ # No weight given: hand the kernel an empty tensor instead of None.
+ weight = input.new_empty(0)
+ else:
+ # A weight must be 1-D with one entry per column of input.
+ assert weight.dim() == 1
+ assert input.size(1) == weight.size(0)
+ # Reduction names are mapped to integer codes; the code is kept on ctx
+ # so that backward can apply the matching scaling.
+ ctx.reduction_dict = {'none': 0, 'mean': 1, 'sum': 2}
+ assert reduction in ctx.reduction_dict.keys()
+ ctx.gamma = float(gamma)
+ ctx.alpha = float(alpha)
+ ctx.reduction = ctx.reduction_dict[reduction]
+ # Result buffer with the same shape as input, passed to the kernel.
+ output = input.new_zeros(input.size())
+ ext_module.sigmoid_focal_loss_forward(
+ input, target, weight, output, gamma=ctx.gamma, alpha=ctx.alpha)
+ if ctx.reduction == ctx.reduction_dict['mean']:
+ # 'mean' divides the total by the number of rows (input.size(0)),
+ # not by the number of elements.
+ output = output.sum() / input.size(0)
+ elif ctx.reduction == ctx.reduction_dict['sum']:
+ output = output.sum()
+ ctx.save_for_backward(input, target, weight)
+ return output
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ input, target, weight = ctx.saved_tensors
+ # Gradient buffer with the same shape as input, passed to the kernel.
+ grad_input = input.new_zeros(input.size())
+ ext_module.sigmoid_focal_loss_backward(
+ input,
+ target,
+ weight,
+ grad_input,
+ gamma=ctx.gamma,
+ alpha=ctx.alpha)
+ # Scale by the upstream gradient; for 'mean' also apply the same
+ # 1 / input.size(0) factor that forward used.
+ grad_input *= grad_output
+ if ctx.reduction == ctx.reduction_dict['mean']:
+ grad_input /= input.size(0)
+ # Only `input` receives a gradient; the other five forward arguments
+ # get None.
+ return grad_input, None, None, None, None, None
+# Functional entry point: sigmoid_focal_loss(input, target, gamma, alpha,
+# weight, reduction) runs SigmoidFocalLossFunction through autograd.
+sigmoid_focal_loss = SigmoidFocalLossFunction.apply
+class SigmoidFocalLoss(nn.Module):
+    """Module wrapper that stores the settings for ``sigmoid_focal_loss``.
+
+    Args:
+        gamma: Value forwarded as ``gamma`` on every call.
+        alpha: Value forwarded as ``alpha`` on every call.
+        weight (Tensor, optional): Registered as the ``weight`` buffer and
+            forwarded on every call. Default: None.
+        reduction (str): Reduction name forwarded on every call.
+            Default: 'mean'.
+    """
+    def __init__(self, gamma, alpha, weight=None, reduction='mean'):
+        super().__init__()
+        self.gamma, self.alpha = gamma, alpha
+        self.reduction = reduction
+        self.register_buffer('weight', weight)
+    def forward(self, input, target):
+        """Return the loss of ``input`` against ``target``."""
+        settings = (self.gamma, self.alpha, self.weight, self.reduction)
+        return sigmoid_focal_loss(input, target, *settings)
+    def __repr__(self):
+        return (f'{self.__class__.__name__}(gamma={self.gamma}, '
+                f'alpha={self.alpha}, '
+                f'reduction={self.reduction})')
+class SoftmaxFocalLossFunction(Function):
+ """Autograd function around the ``_ext`` softmax focal loss kernels.
+
+ ``forward`` computes a softmax over dim 1 in Python, passes it to
+ ``softmax_focal_loss_forward`` and applies the requested reduction;
+ ``backward`` calls ``softmax_focal_loss_backward`` on the saved softmax.
+ """
+ @staticmethod
+ def symbolic(g, input, target, gamma, alpha, weight, reduction):
+ # Graph-export hook: emits one custom 'mmcv::MMCVSoftmaxFocalLoss' node
+ # with gamma/alpha/weight/reduction passed as node attributes.
+ return g.op(
+ 'mmcv::MMCVSoftmaxFocalLoss',
+ input,
+ target,
+ gamma_f=gamma,
+ alpha_f=alpha,
+ weight_f=weight,
+ reduction_s=reduction)
+ @staticmethod
+ def forward(ctx,
+ input,
+ target,
+ gamma=2.0,
+ alpha=0.25,
+ weight=None,
+ reduction='mean'):
+ # Only Long (int64) targets, on CPU or CUDA, are accepted.
+ assert isinstance(target, (torch.LongTensor, torch.cuda.LongTensor))
+ # input is 2-D, target is 1-D with one entry per row of input.
+ assert input.dim() == 2
+ assert target.dim() == 1
+ assert input.size(0) == target.size(0)
+ if weight is None:
+ # No weight given: hand the kernel an empty tensor instead of None.
+ weight = input.new_empty(0)
+ else:
+ # A weight must be 1-D with one entry per column of input.
+ assert weight.dim() == 1
+ assert input.size(1) == weight.size(0)
+ # Reduction names are mapped to integer codes; the code is kept on ctx
+ # so that backward can apply the matching scaling.
+ ctx.reduction_dict = {'none': 0, 'mean': 1, 'sum': 2}
+ assert reduction in ctx.reduction_dict.keys()
+ ctx.gamma = float(gamma)
+ ctx.alpha = float(alpha)
+ ctx.reduction = ctx.reduction_dict[reduction]
+ # Softmax over dim 1: subtract the per-row maximum before exp(), then
+ # divide by the per-row sum. `input` itself is not modified because the
+ # subtraction creates a new tensor.
+ channel_stats, _ = torch.max(input, dim=1)
+ input_softmax = input - channel_stats.unsqueeze(1).expand_as(input)
+ input_softmax.exp_()
+ channel_stats = input_softmax.sum(dim=1)
+ input_softmax /= channel_stats.unsqueeze(1).expand_as(input)
+ # One result slot per row (1-D), unlike the sigmoid variant.
+ output = input.new_zeros(input.size(0))
+ ext_module.softmax_focal_loss_forward(
+ input_softmax,
+ target,
+ weight,
+ output,
+ gamma=ctx.gamma,
+ alpha=ctx.alpha)
+ if ctx.reduction == ctx.reduction_dict['mean']:
+ output = output.sum() / input.size(0)
+ elif ctx.reduction == ctx.reduction_dict['sum']:
+ output = output.sum()
+ # The softmax (not the raw input) is what backward consumes.
+ ctx.save_for_backward(input_softmax, target, weight)
+ return output
+ @staticmethod
+ # NOTE(review): unlike the sigmoid variant, this backward is not wrapped
+ # in @once_differentiable - confirm whether that is intentional.
+ def backward(ctx, grad_output):
+ input_softmax, target, weight = ctx.saved_tensors
+ # Per-row buffer handed to the kernel; presumably scratch space -
+ # its contents are not read here.
+ buff = input_softmax.new_zeros(input_softmax.size(0))
+ grad_input = input_softmax.new_zeros(input_softmax.size())
+ ext_module.softmax_focal_loss_backward(
+ input_softmax,
+ target,
+ weight,
+ buff,
+ grad_input,
+ gamma=ctx.gamma,
+ alpha=ctx.alpha)
+ # Scale by the upstream gradient; for 'mean' also apply the same
+ # 1 / N factor that forward used.
+ grad_input *= grad_output
+ if ctx.reduction == ctx.reduction_dict['mean']:
+ grad_input /= input_softmax.size(0)
+ # Only `input` receives a gradient; the other five forward arguments
+ # get None.
+ return grad_input, None, None, None, None, None
+# Functional entry point: softmax_focal_loss(input, target, gamma, alpha,
+# weight, reduction) runs SoftmaxFocalLossFunction through autograd.
+softmax_focal_loss = SoftmaxFocalLossFunction.apply
+class SoftmaxFocalLoss(nn.Module):
+    """Module wrapper that stores the settings for ``softmax_focal_loss``.
+
+    Args:
+        gamma: Value forwarded as ``gamma`` on every call.
+        alpha: Value forwarded as ``alpha`` on every call.
+        weight (Tensor, optional): Registered as the ``weight`` buffer and
+            forwarded on every call. Default: None.
+        reduction (str): Reduction name forwarded on every call.
+            Default: 'mean'.
+    """
+    def __init__(self, gamma, alpha, weight=None, reduction='mean'):
+        super().__init__()
+        self.gamma, self.alpha = gamma, alpha
+        self.reduction = reduction
+        self.register_buffer('weight', weight)
+    def forward(self, input, target):
+        """Return the loss of ``input`` against ``target``."""
+        settings = (self.gamma, self.alpha, self.weight, self.reduction)
+        return softmax_focal_loss(input, target, *settings)
+    def __repr__(self):
+        return (f'{self.__class__.__name__}(gamma={self.gamma}, '
+                f'alpha={self.alpha}, '
+                f'reduction={self.reduction})')
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/furthest_point_sample.py b/ControlNet/annotator/uniformer/mmcv/ops/furthest_point_sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..374b7a878f1972c183941af28ba1df216ac1a60f
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/furthest_point_sample.py
@@ -0,0 +1,83 @@
+import torch
+from torch.autograd import Function
+from ..utils import ext_loader
+# Load the '_ext' extension module, requesting the two furthest-point
+# sampling kernels that the autograd functions in this file call.
+ext_module = ext_loader.load_ext('_ext', [
+    'furthest_point_sampling_forward',
+    'furthest_point_sampling_with_dist_forward'
+])
+class FurthestPointSampling(Function):
+    """Uses iterative furthest point sampling to select a set of features whose
+    corresponding points have the furthest distance."""
+    @staticmethod
+    def forward(ctx, points_xyz: torch.Tensor,
+                num_points: int) -> torch.Tensor:
+        """
+        Args:
+            points_xyz (Tensor): (B, N, 3) where N > num_points.
+            num_points (int): Number of points in the sampled set.
+        Returns:
+            Tensor: (B, num_points) int32 indices of the sampled points.
+        """
+        assert points_xyz.is_contiguous()
+        B, N = points_xyz.size()[:2]
+        # Allocate the buffers next to the input (as the WithDist variant
+        # does) rather than on whatever the current CUDA device is; the
+        # dtypes stay int32 / float32 as before.
+        output = points_xyz.new_zeros([B, num_points], dtype=torch.int32)
+        temp = points_xyz.new_zeros([B, N], dtype=torch.float32).fill_(1e10)
+        ext_module.furthest_point_sampling_forward(
+            points_xyz,
+            temp,
+            output,
+            b=B,
+            n=N,
+            m=num_points,
+        )
+        if torch.__version__ != 'parrots':
+            # Indices carry no gradient.
+            ctx.mark_non_differentiable(output)
+        return output
+    @staticmethod
+    def backward(xyz, a=None):
+        # Signature kept as-is: the first positional argument receives the
+        # autograd context. Neither forward input gets a gradient.
+        return None, None
+class FurthestPointSamplingWithDist(Function):
+    """Iterative furthest point sampling that works from a pairwise distance
+    matrix instead of point coordinates."""
+    @staticmethod
+    def forward(ctx, points_dist: torch.Tensor,
+                num_points: int) -> torch.Tensor:
+        """
+        Args:
+            points_dist (Tensor): (B, N, N) Distance between each point pair.
+            num_points (int): Number of points in the sampled set.
+        Returns:
+            Tensor: (B, num_points) indices of the sampled points.
+        """
+        assert points_dist.is_contiguous()
+        batch_size, num_total, _ = points_dist.size()
+        # Both buffers are created from the input tensor, so they live on
+        # its device; the index buffer is int32.
+        sampled_idx = points_dist.new_zeros([batch_size, num_points],
+                                            dtype=torch.int32)
+        dist_buffer = points_dist.new_zeros([batch_size,
+                                             num_total]).fill_(1e10)
+        ext_module.furthest_point_sampling_with_dist_forward(
+            points_dist,
+            dist_buffer,
+            sampled_idx,
+            b=batch_size,
+            n=num_total,
+            m=num_points)
+        if torch.__version__ != 'parrots':
+            ctx.mark_non_differentiable(sampled_idx)
+        return sampled_idx
+    @staticmethod
+    def backward(xyz, a=None):
+        # No forward input receives a gradient.
+        return None, None
+# Functional entry points: each takes (tensor, num_points) and returns the
+# (B, num_points) index tensor produced by the corresponding Function.
+furthest_point_sample = FurthestPointSampling.apply
+furthest_point_sample_with_dist = FurthestPointSamplingWithDist.apply
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/fused_bias_leakyrelu.py b/ControlNet/annotator/uniformer/mmcv/ops/fused_bias_leakyrelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d12508469c6c8fa1884debece44c58d158cb6fa
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/fused_bias_leakyrelu.py
@@ -0,0 +1,268 @@
+# modified from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501
+# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
+# NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator
+# Augmentation (ADA)
+# =======================================================================
+# 1. Definitions
+# "Licensor" means any person or entity that distributes its Work.
+# "Software" means the original work of authorship made available under
+# this License.
+# "Work" means the Software and any additions to or derivative works of
+# the Software that are made available under this License.
+# The terms "reproduce," "reproduction," "derivative works," and
+# "distribution" have the meaning as provided under U.S. copyright law;
+# provided, however, that for the purposes of this License, derivative
+# works shall not include works that remain separable from, or merely
+# link (or bind by name) to the interfaces of, the Work.
+# Works, including the Software, are "made available" under this License
+# by including in or with the Work either (a) a copyright notice
+# referencing the applicability of this License to the Work, or (b) a
+# copy of this License.
+# 2. License Grants
+# 2.1 Copyright Grant. Subject to the terms and conditions of this
+# License, each Licensor grants to you a perpetual, worldwide,
+# non-exclusive, royalty-free, copyright license to reproduce,
+# prepare derivative works of, publicly display, publicly perform,
+# sublicense and distribute its Work and any resulting derivative
+# works in any form.
+# 3. Limitations
+# 3.1 Redistribution. You may reproduce or distribute the Work only
+# if (a) you do so under this License, (b) you include a complete
+# copy of this License with your distribution, and (c) you retain
+# without modification any copyright, patent, trademark, or
+# attribution notices that are present in the Work.
+# 3.2 Derivative Works. You may specify that additional or different
+# terms apply to the use, reproduction, and distribution of your
+# derivative works of the Work ("Your Terms") only if (a) Your Terms
+# provide that the use limitation in Section 3.3 applies to your
+# derivative works, and (b) you identify the specific derivative
+# works that are subject to Your Terms. Notwithstanding Your Terms,
+# this License (including the redistribution requirements in Section
+# 3.1) will continue to apply to the Work itself.
+# 3.3 Use Limitation. The Work and any derivative works thereof only
+# may be used or intended for use non-commercially. Notwithstanding
+# the foregoing, NVIDIA and its affiliates may use the Work and any
+# derivative works commercially. As used herein, "non-commercially"
+# means for research or evaluation purposes only.
+# 3.4 Patent Claims. If you bring or threaten to bring a patent claim
+# against any Licensor (including any claim, cross-claim or
+# counterclaim in a lawsuit) to enforce any patents that you allege
+# are infringed by any Work, then your rights under this License from
+# such Licensor (including the grant in Section 2.1) will terminate
+# immediately.
+# 3.5 Trademarks. This License does not grant any rights to use any
+# Licensor’s or its affiliates’ names, logos, or trademarks, except
+# as necessary to reproduce the notices described in this License.
+# 3.6 Termination. If you violate any term of this License, then your
+# rights under this License (including the grant in Section 2.1) will
+# terminate immediately.
+# 4. Disclaimer of Warranty.
+# 5. Limitation of Liability.
+# =======================================================================
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.autograd import Function
+from ..utils import ext_loader
+# Load the '_ext' extension module, requesting the single
+# 'fused_bias_leakyrelu' op used by both autograd functions below.
+ext_module = ext_loader.load_ext('_ext', ['fused_bias_leakyrelu'])
+class FusedBiasLeakyReLUFunctionBackward(Function):
+ """Calculate second order derivative.
+ This function computes the first-order gradient of the fused leaky relu
+ operation as an autograd Function of its own, so that it can in turn be
+ differentiated (second order derivative).
+ """
+ @staticmethod
+ def forward(ctx, grad_output, out, negative_slope, scale):
+ ctx.save_for_backward(out)
+ ctx.negative_slope = negative_slope
+ ctx.scale = scale
+ # Empty tensor passed in the bias slot of the kernel call.
+ empty = grad_output.new_empty(0)
+ # NOTE(review): act=3 / grad=1 select the kernel mode; presumably
+ # "leaky ReLU" and "first-order gradient" - confirm in the extension.
+ grad_input = ext_module.fused_bias_leakyrelu(
+ grad_output,
+ empty,
+ out,
+ act=3,
+ grad=1,
+ alpha=negative_slope,
+ scale=scale)
+ # The bias gradient is grad_input summed over every dim except dim 1.
+ dim = [0]
+ if grad_input.ndim > 2:
+ dim += list(range(2, grad_input.ndim))
+ grad_bias = grad_input.sum(dim).detach()
+ return grad_input, grad_bias
+ @staticmethod
+ def backward(ctx, gradgrad_input, gradgrad_bias):
+ out, = ctx.saved_tensors
+ # The second order derivative, in fact, contains two parts, while
+ # the first part is zero. Thus, we directly consider the second part,
+ # which is similar to the first order derivative in implementation.
+ gradgrad_out = ext_module.fused_bias_leakyrelu(
+ gradgrad_input,
+ gradgrad_bias.to(out.dtype),
+ out,
+ act=3,
+ grad=1,
+ alpha=ctx.negative_slope,
+ scale=ctx.scale)
+ # Only grad_output gets a gradient; out, negative_slope and scale get
+ # None.
+ return gradgrad_out, None, None, None
+class FusedBiasLeakyReLUFunction(Function):
+ """Autograd function for the ``_ext`` fused bias + leaky ReLU kernel.
+
+ ``backward`` delegates to ``FusedBiasLeakyReLUFunctionBackward`` so that
+ the gradient computation is itself differentiable.
+ """
+ @staticmethod
+ def forward(ctx, input, bias, negative_slope, scale):
+ # Empty tensor passed in the third (reference) slot of the kernel call.
+ empty = input.new_empty(0)
+ # NOTE(review): act=3 / grad=0 select the kernel mode; presumably
+ # "leaky ReLU" and "forward pass" - confirm in the extension.
+ out = ext_module.fused_bias_leakyrelu(
+ input,
+ bias,
+ empty,
+ act=3,
+ grad=0,
+ alpha=negative_slope,
+ scale=scale)
+ # The output (not the input) is what the backward pass consumes.
+ ctx.save_for_backward(out)
+ ctx.negative_slope = negative_slope
+ ctx.scale = scale
+ return out
+ @staticmethod
+ def backward(ctx, grad_output):
+ out, = ctx.saved_tensors
+ grad_input, grad_bias = FusedBiasLeakyReLUFunctionBackward.apply(
+ grad_output, out, ctx.negative_slope, ctx.scale)
+ # Gradients for input and bias; negative_slope and scale get None.
+ return grad_input, grad_bias, None, None
+class FusedBiasLeakyReLU(nn.Module):
+    """Fused bias leaky ReLU layer with a learnable bias vector.
+
+    This function is introduced in the StyleGAN2:
+    http://arxiv.org/abs/1912.04958
+
+    The bias term comes from the convolution operation. In addition, to keep
+    the variance of the feature map or gradients unchanged, a scale similar
+    to Kaiming initialization is applied. The default scale is sqrt(2); you
+    may change it with your own scale.
+
+    Args:
+        num_channels (int): Length of the bias parameter, which is
+            initialised to zeros.
+        negative_slope (float, optional): Same as nn.LeakyRelu.
+            Defaults to 0.2.
+        scale (float, optional): A scalar to adjust the variance of the feature
+            map. Defaults to 2**0.5.
+    """
+    def __init__(self, num_channels, negative_slope=0.2, scale=2**0.5):
+        super().__init__()
+        self.negative_slope = negative_slope
+        self.scale = scale
+        self.bias = nn.Parameter(torch.zeros(num_channels))
+    def forward(self, input):
+        """Apply ``fused_bias_leakyrelu`` with the stored bias and settings."""
+        slope, gain = self.negative_slope, self.scale
+        return fused_bias_leakyrelu(input, self.bias, slope, gain)
+def fused_bias_leakyrelu(input, bias, negative_slope=0.2, scale=2**0.5):
+    """Fused bias leaky ReLU function.
+
+    This function is introduced in the StyleGAN2:
+    http://arxiv.org/abs/1912.04958
+
+    CUDA inputs go through ``FusedBiasLeakyReLUFunction`` with the bias cast
+    to the input dtype; any other input is handled by the pure-PyTorch
+    ``bias_leakyrelu_ref``.
+
+    Args:
+        input (torch.Tensor): Input feature map.
+        bias (nn.Parameter): The bias from convolution operation.
+        negative_slope (float, optional): Same as nn.LeakyRelu.
+            Defaults to 0.2.
+        scale (float, optional): A scalar to adjust the variance of the feature
+            map. Defaults to 2**0.5.
+    Returns:
+        torch.Tensor: Feature map after non-linear activation.
+    """
+    if input.is_cuda:
+        cast_bias = bias.to(input.dtype)
+        return FusedBiasLeakyReLUFunction.apply(input, cast_bias,
+                                                negative_slope, scale)
+    return bias_leakyrelu_ref(input, bias, negative_slope, scale)
+def bias_leakyrelu_ref(x, bias, negative_slope=0.2, scale=2**0.5):
+    """Pure-PyTorch reference: add bias along dim 1, leaky ReLU, then scale.
+
+    Args:
+        x (torch.Tensor): Input whose dim 1 matches the bias length.
+        bias (torch.Tensor | None): 1-D bias, or None to skip the addition.
+        negative_slope (float, optional): Slope passed to ``F.leaky_relu``.
+            Defaults to 0.2.
+        scale (float, optional): Multiplier applied to the activation unless
+            it equals 1. Defaults to 2**0.5.
+    Returns:
+        torch.Tensor: The activated (and possibly scaled) tensor.
+    """
+    if bias is not None:
+        assert bias.ndim == 1
+        assert bias.shape[0] == x.shape[1]
+        # Broadcast shape: -1 on dim 1, 1 everywhere else.
+        broadcast_shape = [1] * x.ndim
+        broadcast_shape[1] = -1
+        x = x + bias.reshape(broadcast_shape)
+    activated = F.leaky_relu(x, negative_slope)
+    return activated if scale == 1 else activated * scale
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/gather_points.py b/ControlNet/annotator/uniformer/mmcv/ops/gather_points.py
new file mode 100644
index 0000000000000000000000000000000000000000..f52f1677d8ea0facafc56a3672d37adb44677ff3
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/gather_points.py
@@ -0,0 +1,57 @@
+import torch
+from torch.autograd import Function
+from ..utils import ext_loader
+# Load the '_ext' extension module, requesting the forward/backward
+# gather-points kernels used by GatherPoints.
+ext_module = ext_loader.load_ext(
+ '_ext', ['gather_points_forward', 'gather_points_backward'])
+class GatherPoints(Function):
+    """Gather points with given index."""
+    @staticmethod
+    def forward(ctx, features: torch.Tensor,
+                indices: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            features (Tensor): (B, C, N) features to gather.
+            indices (Tensor): (B, M) where M is the number of points.
+        Returns:
+            Tensor: (B, C, M) where M is the number of points.
+        """
+        assert features.is_contiguous()
+        assert indices.is_contiguous()
+        B, npoint = indices.size()
+        _, C, N = features.size()
+        # Allocate next to the input instead of on the current CUDA device,
+        # which may differ from features.device; dtype stays float32.
+        output = features.new_zeros((B, C, npoint), dtype=torch.float32)
+        ext_module.gather_points_forward(
+            features, indices, output, b=B, c=C, n=N, npoints=npoint)
+        # Backward needs the indices and the original (C, N) extents.
+        ctx.for_backwards = (indices, C, N)
+        if torch.__version__ != 'parrots':
+            ctx.mark_non_differentiable(indices)
+        return output
+    @staticmethod
+    def backward(ctx, grad_out):
+        """Return the (B, C, N) gradient for ``features`` and None for
+        ``indices``."""
+        idx, C, N = ctx.for_backwards
+        B, npoint = idx.size()
+        # Zero-initialised accumulator on the same device as grad_out.
+        grad_features = grad_out.new_zeros((B, C, N), dtype=torch.float32)
+        grad_out_data = grad_out.data.contiguous()
+        ext_module.gather_points_backward(
+            grad_out_data,
+            idx,
+            grad_features.data,
+            b=B,
+            c=C,
+            n=N,
+            npoints=npoint)
+        return grad_features, None
+# Functional entry point: gather_points(features, indices).
+gather_points = GatherPoints.apply
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/group_points.py b/ControlNet/annotator/uniformer/mmcv/ops/group_points.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c3ec9d758ebe4e1c2205882af4be154008253a5
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/group_points.py
@@ -0,0 +1,224 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Tuple
+import torch
+from torch import nn as nn
+from torch.autograd import Function
+from ..utils import ext_loader
+from .ball_query import ball_query
+from .knn import knn
+# Load the '_ext' extension module, requesting the forward/backward
+# group-points kernels used by GroupingOperation.
+ext_module = ext_loader.load_ext(
+ '_ext', ['group_points_forward', 'group_points_backward'])
+class QueryAndGroup(nn.Module):
+ """Groups points with a ball query of radius.
+ Args:
+ max_radius (float): The maximum radius of the balls.
+ If None is given, we will use kNN sampling instead of ball query.
+ sample_num (int): Maximum number of features to gather in the ball.
+ min_radius (float, optional): The minimum radius of the balls.
+ Default: 0.
+ use_xyz (bool, optional): Whether to use xyz.
+ Default: True.
+ return_grouped_xyz (bool, optional): Whether to return grouped xyz.
+ Default: False.
+ normalize_xyz (bool, optional): Whether to normalize xyz.
+ Default: False.
+ uniform_sample (bool, optional): Whether to sample uniformly.
+ Default: False
+ return_unique_cnt (bool, optional): Whether to return the count of
+ unique samples. Default: False.
+ return_grouped_idx (bool, optional): Whether to return grouped idx.
+ Default: False.
+ """
+ def __init__(self,
+ max_radius,
+ sample_num,
+ min_radius=0,
+ use_xyz=True,
+ return_grouped_xyz=False,
+ normalize_xyz=False,
+ uniform_sample=False,
+ return_unique_cnt=False,
+ return_grouped_idx=False):
+ super().__init__()
+ self.max_radius = max_radius
+ self.min_radius = min_radius
+ self.sample_num = sample_num
+ self.use_xyz = use_xyz
+ self.return_grouped_xyz = return_grouped_xyz
+ self.normalize_xyz = normalize_xyz
+ self.uniform_sample = uniform_sample
+ self.return_unique_cnt = return_unique_cnt
+ self.return_grouped_idx = return_grouped_idx
+ # unique_cnt is only computed inside the uniform_sample branch of
+ # forward, so it can only be returned when uniform_sample is on.
+ if self.return_unique_cnt:
+ assert self.uniform_sample, \
+ 'uniform_sample should be True when ' \
+ 'returning the count of unique samples'
+ # Normalisation divides by max_radius, which must therefore be set.
+ if self.max_radius is None:
+ assert not self.normalize_xyz, \
+ 'can not normalize grouped xyz when max_radius is None'
+ def forward(self, points_xyz, center_xyz, features=None):
+ """
+ Args:
+ points_xyz (Tensor): (B, N, 3) xyz coordinates of the features.
+ center_xyz (Tensor): (B, npoint, 3) coordinates of the centroids.
+ features (Tensor): (B, C, N) Descriptors of the features.
+ Returns:
+ Tensor: (B, 3 + C, npoint, sample_num) Grouped feature. When any of
+ return_grouped_xyz / return_unique_cnt / return_grouped_idx is
+ set, a tuple is returned instead, with the grouped feature first
+ followed by the requested extras in that order.
+ """
+ # if self.max_radius is None, we will perform kNN instead of ball query
+ # idx is of shape [B, npoint, sample_num]
+ if self.max_radius is None:
+ idx = knn(self.sample_num, points_xyz, center_xyz, False)
+ # knn returns (B, k, npoint); swap to (B, npoint, k).
+ idx = idx.transpose(1, 2).contiguous()
+ else:
+ idx = ball_query(self.min_radius, self.max_radius, self.sample_num,
+ points_xyz, center_xyz)
+ if self.uniform_sample:
+ # For every region, keep the unique indices and fill the remaining
+ # sample_num - num_unique slots by drawing, with replacement, from
+ # those unique indices.
+ unique_cnt = torch.zeros((idx.shape[0], idx.shape[1]))
+ for i_batch in range(idx.shape[0]):
+ for i_region in range(idx.shape[1]):
+ unique_ind = torch.unique(idx[i_batch, i_region, :])
+ num_unique = unique_ind.shape[0]
+ unique_cnt[i_batch, i_region] = num_unique
+ sample_ind = torch.randint(
+ 0,
+ num_unique, (self.sample_num - num_unique, ),
+ dtype=torch.long)
+ all_ind = torch.cat((unique_ind, unique_ind[sample_ind]))
+ idx[i_batch, i_region, :] = all_ind
+ xyz_trans = points_xyz.transpose(1, 2).contiguous()
+ # (B, 3, npoint, sample_num)
+ grouped_xyz = grouping_operation(xyz_trans, idx)
+ grouped_xyz_diff = grouped_xyz - \
+ center_xyz.transpose(1, 2).unsqueeze(-1) # relative offsets
+ if self.normalize_xyz:
+ grouped_xyz_diff /= self.max_radius
+ if features is not None:
+ grouped_features = grouping_operation(features, idx)
+ if self.use_xyz:
+ # (B, C + 3, npoint, sample_num)
+ new_features = torch.cat([grouped_xyz_diff, grouped_features],
+ dim=1)
+ else:
+ new_features = grouped_features
+ else:
+ # Without features, the relative xyz offsets are the only output.
+ assert (self.use_xyz
+ ), 'Cannot have not features and not use xyz as a feature!'
+ new_features = grouped_xyz_diff
+ # Collect the optional extras after the grouped feature.
+ ret = [new_features]
+ if self.return_grouped_xyz:
+ ret.append(grouped_xyz)
+ if self.return_unique_cnt:
+ ret.append(unique_cnt)
+ if self.return_grouped_idx:
+ ret.append(idx)
+ if len(ret) == 1:
+ return ret[0]
+ else:
+ return tuple(ret)
+class GroupAll(nn.Module):
+    """Group xyz with feature.
+
+    Args:
+        use_xyz (bool): Whether to prepend the xyz channels to the features.
+    """
+    def __init__(self, use_xyz: bool = True):
+        super().__init__()
+        self.use_xyz = use_xyz
+    def forward(self,
+                xyz: torch.Tensor,
+                new_xyz: torch.Tensor,
+                features: torch.Tensor = None):
+        """
+        Args:
+            xyz (Tensor): (B, N, 3) xyz coordinates of the features.
+            new_xyz (Tensor): new xyz coordinates of the features
+                (not read by this method).
+            features (Tensor): (B, C, N) features to group.
+        Returns:
+            Tensor: (B, C + 3, 1, N) Grouped feature; only the xyz channels
+            when ``features`` is None, only the features when ``use_xyz``
+            is False.
+        """
+        # (B, N, 3) -> (B, 3, 1, N)
+        xyz_channels = xyz.transpose(1, 2).unsqueeze(2)
+        if features is None:
+            return xyz_channels
+        # (B, C, N) -> (B, C, 1, N)
+        expanded = features.unsqueeze(2)
+        if not self.use_xyz:
+            return expanded
+        # (B, 3 + C, 1, N)
+        return torch.cat([xyz_channels, expanded], dim=1)
+class GroupingOperation(Function):
+    """Group feature with given index."""
+    @staticmethod
+    def forward(ctx, features: torch.Tensor,
+                indices: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            features (Tensor): (B, C, N) tensor of features to group.
+            indices (Tensor): (B, npoint, nsample) the indices of
+                features to group with.
+        Returns:
+            Tensor: (B, C, npoint, nsample) Grouped features.
+        """
+        features = features.contiguous()
+        indices = indices.contiguous()
+        B, nfeatures, nsample = indices.size()
+        _, C, N = features.size()
+        # Allocate next to the input instead of on the current CUDA device,
+        # which may differ from features.device; dtype stays float32.
+        output = features.new_zeros((B, C, nfeatures, nsample),
+                                    dtype=torch.float32)
+        ext_module.group_points_forward(B, C, N, nfeatures, nsample, features,
+                                        indices, output)
+        # Backward needs the indices and the original number of points N.
+        ctx.for_backwards = (indices, N)
+        return output
+    @staticmethod
+    def backward(ctx,
+                 grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+            grad_out (Tensor): (B, C, npoint, nsample) tensor of the gradients
+                of the output from forward.
+        Returns:
+            Tensor: (B, C, N) gradient of the features (the second returned
+            value, for ``indices``, is None).
+        """
+        idx, N = ctx.for_backwards
+        B, C, npoint, nsample = grad_out.size()
+        # Zero-initialised accumulator on the same device as grad_out.
+        grad_features = grad_out.new_zeros((B, C, N), dtype=torch.float32)
+        grad_out_data = grad_out.data.contiguous()
+        ext_module.group_points_backward(B, C, N, npoint, nsample,
+                                         grad_out_data, idx,
+                                         grad_features.data)
+        return grad_features, None
+# Functional entry point: grouping_operation(features, indices).
+grouping_operation = GroupingOperation.apply
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/info.py b/ControlNet/annotator/uniformer/mmcv/ops/info.py
new file mode 100644
index 0000000000000000000000000000000000000000..29f2e5598ae2bb5866ccd15a7d3b4de33c0cd14d
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/info.py
@@ -0,0 +1,36 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import glob
+import os
+import torch
+# Two back-ends provide the build information: the parrots runtime reports it
+# through `parrots.version`, every other build asks the compiled '_ext'
+# extension. The branches define the same two functions.
+if torch.__version__ == 'parrots':
+    import parrots
+    def get_compiler_version():
+        """Return the compiler version string, prefixed with 'GCC '."""
+        return 'GCC ' + parrots.version.compiler
+    def get_compiling_cuda_version():
+        """Return the CUDA version reported by parrots."""
+        return parrots.version.cuda
+else:
+    from ..utils import ext_loader
+    ext_module = ext_loader.load_ext(
+        '_ext', ['get_compiler_version', 'get_compiling_cuda_version'])
+    def get_compiler_version():
+        """Return the compiler version reported by the '_ext' extension."""
+        return ext_module.get_compiler_version()
+    def get_compiling_cuda_version():
+        """Return the CUDA version reported by the '_ext' extension."""
+        return ext_module.get_compiling_cuda_version()
+def get_onnxruntime_op_path():
+    """Return the path of the first '_ext_ort.*.so' file found.
+
+    The search directory is the parent of the directory containing this
+    file.
+
+    Returns:
+        str: The first match reported by ``glob``, or '' when there is none.
+    """
+    package_root = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
+    matches = glob.glob(os.path.join(package_root, '_ext_ort.*.so'))
+    return matches[0] if matches else ''
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/iou3d.py b/ControlNet/annotator/uniformer/mmcv/ops/iou3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fc71979190323f44c09f8b7e1761cf49cd2d76b
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/iou3d.py
@@ -0,0 +1,85 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from ..utils import ext_loader
+# Load the '_ext' extension module, requesting the three BEV IoU / NMS
+# kernels that the functions in this file call.
+ext_module = ext_loader.load_ext('_ext', [
+    'iou3d_boxes_iou_bev_forward', 'iou3d_nms_forward',
+    'iou3d_nms_normal_forward'
+])
+def boxes_iou_bev(boxes_a, boxes_b):
+    """Calculate boxes IoU in the Bird's Eye View.
+
+    Args:
+        boxes_a (torch.Tensor): Input boxes a with shape (M, 5).
+        boxes_b (torch.Tensor): Input boxes b with shape (N, 5).
+    Returns:
+        ans_iou (torch.Tensor): IoU result with shape (M, N).
+    """
+    num_a, num_b = boxes_a.shape[0], boxes_b.shape[0]
+    # Result buffer created from boxes_a (same dtype and device).
+    ans_iou = boxes_a.new_zeros(torch.Size((num_a, num_b)))
+    contiguous_a = boxes_a.contiguous()
+    contiguous_b = boxes_b.contiguous()
+    ext_module.iou3d_boxes_iou_bev_forward(contiguous_a, contiguous_b,
+                                           ans_iou)
+    return ans_iou
+def nms_bev(boxes, scores, thresh, pre_max_size=None, post_max_size=None):
+    """NMS function GPU implementation (for BEV boxes). The overlap of two
+    boxes for IoU calculation is defined as the exact overlapping area of the
+    two boxes. In this function, one can also set ``pre_max_size`` and
+    ``post_max_size``.
+
+    Args:
+        boxes (torch.Tensor): Input boxes with the shape of [N, 5]
+            ([x1, y1, x2, y2, ry]).
+        scores (torch.Tensor): Scores of boxes with the shape of [N].
+        thresh (float): Overlap threshold of NMS.
+        pre_max_size (int, optional): Max size of boxes before NMS.
+            Default: None.
+        post_max_size (int, optional): Max size of boxes after NMS.
+            Default: None.
+    Returns:
+        torch.Tensor: Indexes after NMS.
+    """
+    assert boxes.size(1) == 5, 'Input boxes shape should be [N, 5]'
+    # Rank boxes by descending score, optionally keeping only the top ones.
+    _, ranking = scores.sort(0, descending=True)
+    if pre_max_size is not None:
+        ranking = ranking[:pre_max_size]
+    ranked_boxes = boxes[ranking].contiguous()
+    # The kernel fills this CPU buffer and returns how many slots are valid.
+    kept_slots = torch.zeros(ranked_boxes.size(0), dtype=torch.long)
+    num_out = ext_module.iou3d_nms_forward(ranked_boxes, kept_slots, thresh)
+    # Map positions in the ranked list back to indices of the input boxes.
+    valid_slots = kept_slots[:num_out].cuda(ranked_boxes.device)
+    selected = ranking[valid_slots].contiguous()
+    if post_max_size is not None:
+        selected = selected[:post_max_size]
+    return selected
+def nms_normal_bev(boxes, scores, thresh):
+    """Normal NMS function GPU implementation (for BEV boxes). The overlap of
+    two boxes for IoU calculation is defined as the exact overlapping area of
+    the two boxes WITH their yaw angle set to 0.
+
+    Args:
+        boxes (torch.Tensor): Input boxes with shape (N, 5).
+        scores (torch.Tensor): Scores of predicted boxes with shape (N).
+        thresh (float): Overlap threshold of NMS.
+    Returns:
+        torch.Tensor: Remaining indices with scores in descending order.
+    """
+    assert boxes.shape[1] == 5, 'Input boxes shape should be [N, 5]'
+    _, ranking = scores.sort(0, descending=True)
+    ranked_boxes = boxes[ranking].contiguous()
+    # The kernel fills this CPU buffer and returns how many slots are valid.
+    kept_slots = torch.zeros(ranked_boxes.size(0), dtype=torch.long)
+    num_out = ext_module.iou3d_nms_normal_forward(ranked_boxes, kept_slots,
+                                                  thresh)
+    # Map positions in the ranked list back to indices of the input boxes.
+    valid_slots = kept_slots[:num_out].cuda(ranked_boxes.device)
+    return ranking[valid_slots].contiguous()
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/knn.py b/ControlNet/annotator/uniformer/mmcv/ops/knn.py
new file mode 100644
index 0000000000000000000000000000000000000000..f335785036669fc19239825b0aae6dde3f73bf92
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/knn.py
@@ -0,0 +1,77 @@
+import torch
+from torch.autograd import Function
+from ..utils import ext_loader
+# Load the '_ext' extension module, requesting the 'knn_forward' kernel
+# used by KNN.forward.
+ext_module = ext_loader.load_ext('_ext', ['knn_forward'])
+class KNN(Function):
+    r"""KNN (CUDA) based on heap data structure.
+    Modified from PAConv (https://github.com/CVMI-Lab/PAConv).
+    Find k-nearest points.
+    """
+    @staticmethod
+    def forward(ctx,
+                k: int,
+                xyz: torch.Tensor,
+                center_xyz: torch.Tensor = None,
+                transposed: bool = False) -> torch.Tensor:
+        """
+        Args:
+            k (int): number of nearest neighbors.
+            xyz (Tensor): (B, N, 3) if transposed == False, else (B, 3, N).
+                xyz coordinates of the features.
+            center_xyz (Tensor, optional): (B, npoint, 3) if transposed ==
+                False, else (B, 3, npoint). centers of the knn query.
+                Default: None, in which case ``xyz`` is used as the centers.
+            transposed (bool, optional): whether the input tensors are
+                transposed. Should not explicitly use this keyword when
+                calling knn (=KNN.apply), just add the fourth param.
+                Default: False.
+        Returns:
+            Tensor: (B, k, npoint) tensor with the indices of
+                the features that form k-nearest neighbours.
+        """
+        # k must satisfy 0 < k < 100.
+        assert (k > 0) and (k < 100), 'k should be in range(0, 100)'
+        if center_xyz is None:
+            center_xyz = xyz
+        if transposed:
+            # Bring both inputs to the (B, N, 3) layout.
+            xyz = xyz.transpose(2, 1).contiguous()
+            center_xyz = center_xyz.transpose(2, 1).contiguous()
+        assert xyz.is_contiguous()  # [B, N, 3]
+        assert center_xyz.is_contiguous()  # [B, npoint, 3]
+        center_xyz_device = center_xyz.get_device()
+        assert center_xyz_device == xyz.get_device(), \
+            'center_xyz and xyz should be put on the same device'
+        # Switch the current CUDA device to the one holding the inputs
+        # before calling the kernel.
+        if torch.cuda.current_device() != center_xyz_device:
+            torch.cuda.set_device(center_xyz_device)
+        B, npoint, _ = center_xyz.shape
+        N = xyz.shape[1]
+        idx = center_xyz.new_zeros((B, npoint, k)).int()
+        dist2 = center_xyz.new_zeros((B, npoint, k)).float()
+        ext_module.knn_forward(
+            xyz, center_xyz, idx, dist2, b=B, n=N, m=npoint, nsample=k)
+        # idx shape to [B, k, npoint]
+        idx = idx.transpose(2, 1).contiguous()
+        if torch.__version__ != 'parrots':
+            ctx.mark_non_differentiable(idx)
+        return idx
+    @staticmethod
+    def backward(ctx, a=None):
+        # One entry per forward input (k, xyz, center_xyz, transposed);
+        # none of them receives a gradient.
+        return None, None, None, None
+# Functional entry point: knn(k, xyz, center_xyz, transposed).
+knn = KNN.apply
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/masked_conv.py b/ControlNet/annotator/uniformer/mmcv/ops/masked_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd514cc204c1d571ea5dc7e74b038c0f477a008b
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/masked_conv.py
@@ -0,0 +1,111 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair
+from ..utils import ext_loader
+ext_module = ext_loader.load_ext(
+ '_ext', ['masked_im2col_forward', 'masked_col2im_forward'])
+class MaskedConv2dFunction(Function):
+    """Autograd function that evaluates a 2-D convolution at masked positions.
+
+    Output positions whose mask value is not positive are left at zero. Only
+    a batch size of 1 and a stride of 1 are accepted, and no gradient is
+    propagated to any input.
+    """
+
+    @staticmethod
+    def symbolic(g, features, mask, weight, bias, padding, stride):
+        """Emit the ``mmcv::MMCVMaskedConv2d`` node for ONNX export."""
+        return g.op(
+            'mmcv::MMCVMaskedConv2d',
+            features,
+            mask,
+            weight,
+            bias,
+            padding_i=padding,
+            stride_i=stride)
+
+    @staticmethod
+    def forward(ctx, features, mask, weight, bias, padding=0, stride=1):
+        """Compute the masked convolution.
+
+        Args:
+            features (Tensor): 4-D input with a leading batch dimension of 1.
+            mask (Tensor): 3-D mask with a leading dimension of 1 and the
+                same spatial size as ``features``; entries ``> 0`` select
+                the positions that are computed.
+            weight (Tensor): 4-D kernel of shape
+                (out_channel, in_channel, kernel_h, kernel_w).
+            bias (Tensor | None): per-output-channel bias. ``None`` means
+                no bias is added.
+            padding (int | tuple[int]): padding, expanded with ``_pair``.
+            stride (int | tuple[int]): stride, expanded with ``_pair``;
+                must be 1 in both directions.
+
+        Returns:
+            Tensor: zero-initialised output of shape
+            (1, out_channel, out_h, out_w) filled at the masked positions.
+
+        Raises:
+            ValueError: if the stride is not 1.
+        """
+        assert mask.dim() == 3 and mask.size(0) == 1
+        assert features.dim() == 4 and features.size(0) == 1
+        assert features.size()[2:] == mask.size()[1:]
+        pad_h, pad_w = _pair(padding)
+        stride_h, stride_w = _pair(stride)
+        if stride_h != 1 or stride_w != 1:
+            raise ValueError(
+                'Stride could only be 1 in masked_conv2d currently.')
+        out_channel, in_channel, kernel_h, kernel_w = weight.size()
+
+        batch_size = features.size(0)
+        out_h = int(
+            math.floor((features.size(2) + 2 * pad_h -
+                        (kernel_h - 1) - 1) / stride_h + 1))
+        # The width must be derived from kernel_w; using kernel_h here gives
+        # a wrong output width for non-square kernels.
+        out_w = int(
+            math.floor((features.size(3) + 2 * pad_w -
+                        (kernel_w - 1) - 1) / stride_w + 1))
+        mask_inds = torch.nonzero(mask[0] > 0, as_tuple=False)
+        output = features.new_zeros(batch_size, out_channel, out_h, out_w)
+        if mask_inds.numel() > 0:
+            mask_h_idx = mask_inds[:, 0].contiguous()
+            mask_w_idx = mask_inds[:, 1].contiguous()
+            data_col = features.new_zeros(in_channel * kernel_h * kernel_w,
+                                          mask_inds.size(0))
+            ext_module.masked_im2col_forward(
+                features,
+                mask_h_idx,
+                mask_w_idx,
+                data_col,
+                kernel_h=kernel_h,
+                kernel_w=kernel_w,
+                pad_h=pad_h,
+                pad_w=pad_w)
+
+            flat_weight = weight.view(out_channel, -1)
+            if bias is None:
+                masked_output = torch.mm(flat_weight, data_col)
+            else:
+                # The positional ``addmm(beta, input, alpha, m1, m2)``
+                # overload is no longer accepted by PyTorch; beta and alpha
+                # default to 1 in the supported signature.
+                masked_output = torch.addmm(bias[:, None], flat_weight,
+                                            data_col)
+            ext_module.masked_col2im_forward(
+                masked_output,
+                mask_h_idx,
+                mask_w_idx,
+                output,
+                height=out_h,
+                width=out_w,
+                channels=out_channel)
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        """Return no gradients: one ``None`` per ``forward`` input."""
+        # forward takes six inputs (features, mask, weight, bias, padding,
+        # stride); autograd rejects a shorter tuple.
+        return (None, ) * 6
+
+
+masked_conv2d = MaskedConv2dFunction.apply
+class MaskedConv2d(nn.Conv2d):
+    """``nn.Conv2d`` subclass with an optional masked forward path.
+
+    When a mask is given to :meth:`forward`, the computation is delegated to
+    ``masked_conv2d``. That path does not implement the backward function
+    and only supports the stride parameter to be 1 currently.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 groups=1,
+                 bias=True):
+        super().__init__(in_channels, out_channels, kernel_size, stride,
+                         padding, dilation, groups, bias)
+
+    def forward(self, input, mask=None):
+        """Run the masked convolution, or a plain Conv2d when ``mask`` is None."""
+        if mask is not None:
+            return masked_conv2d(input, mask, self.weight, self.bias,
+                                 self.padding)
+        # No mask supplied: fall back to the regular Conv2d forward.
+        return super().forward(input)
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/merge_cells.py b/ControlNet/annotator/uniformer/mmcv/ops/merge_cells.py
new file mode 100644
index 0000000000000000000000000000000000000000..48ca8cc0a8aca8432835bd760c0403a3c35b34cf
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/merge_cells.py
@@ -0,0 +1,149 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import abstractmethod
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from ..cnn import ConvModule
+class BaseMergeCell(nn.Module):
+    """The basic class for cells used in NAS-FPN and NAS-FCOS.
+
+    BaseMergeCell takes 2 inputs. After applying convolution
+    on them, they are resized to the target size. Then,
+    they go through binary_op, which depends on the type of cell.
+    If with_out_conv is True, the result of output will go through
+    another convolution layer.
+
+    Args:
+        fused_channels (int): number of input channels of the out_conv
+            layer, i.e. the channel count produced by ``_binary_op``.
+            Default: 256.
+        out_channels (int): number of output channels in out_conv layer.
+            Also used as the channel count of the optional input convs.
+            Default: 256.
+        with_out_conv (bool): Whether to use out_conv layer. Default: True.
+        out_conv_cfg (dict): Config dict for convolution layer, which should
+            contain "groups", "kernel_size", "padding", "bias" to build
+            out_conv layer.
+        out_norm_cfg (dict): Config dict for normalization layer in out_conv.
+        out_conv_order (tuple): The order of conv/norm/activation layers in
+            out_conv.
+        with_input1_conv (bool): Whether to use convolution on input1.
+        with_input2_conv (bool): Whether to use convolution on input2.
+        input_conv_cfg (dict): Config dict for building input1_conv layer and
+            input2_conv layer, which is expected to contain the type of
+            convolution.
+            Default: None, which means using conv2d.
+        input_norm_cfg (dict): Config dict for normalization layer in
+            input1_conv and input2_conv layer. Default: None.
+        upsample_mode (str): Interpolation method used to resize the output
+            of input1_conv and input2_conv to target size. Currently, we
+            support ['nearest', 'bilinear']. Default: 'nearest'.
+    """
+
+    def __init__(self,
+                 fused_channels=256,
+                 out_channels=256,
+                 with_out_conv=True,
+                 out_conv_cfg=dict(
+                     groups=1, kernel_size=3, padding=1, bias=True),
+                 out_norm_cfg=None,
+                 out_conv_order=('act', 'conv', 'norm'),
+                 with_input1_conv=False,
+                 with_input2_conv=False,
+                 input_conv_cfg=None,
+                 input_norm_cfg=None,
+                 upsample_mode='nearest'):
+        super(BaseMergeCell, self).__init__()
+        assert upsample_mode in ['nearest', 'bilinear']
+        self.with_out_conv = with_out_conv
+        self.with_input1_conv = with_input1_conv
+        self.with_input2_conv = with_input2_conv
+        self.upsample_mode = upsample_mode
+
+        if self.with_out_conv:
+            self.out_conv = ConvModule(
+                fused_channels,
+                out_channels,
+                **out_conv_cfg,
+                norm_cfg=out_norm_cfg,
+                order=out_conv_order)
+
+        # An empty nn.Sequential acts as an identity when no conv is wanted.
+        self.input1_conv = self._build_input_conv(
+            out_channels, input_conv_cfg,
+            input_norm_cfg) if with_input1_conv else nn.Sequential()
+        self.input2_conv = self._build_input_conv(
+            out_channels, input_conv_cfg,
+            input_norm_cfg) if with_input2_conv else nn.Sequential()
+
+    def _build_input_conv(self, channel, conv_cfg, norm_cfg):
+        """Build a 3x3, padding-1 ConvModule that keeps the channel count."""
+        return ConvModule(
+            channel,
+            channel,
+            3,
+            padding=1,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            bias=True)
+
+    @abstractmethod
+    def _binary_op(self, x1, x2):
+        """Combine the two resized inputs; implemented by subclasses."""
+        pass
+
+    def _resize(self, x, size):
+        """Resize ``x`` so that its last two dimensions equal ``size``.
+
+        Smaller inputs are interpolated with ``self.upsample_mode``; larger
+        inputs are max-pooled and must be an integer multiple of ``size`` in
+        both spatial dimensions.
+        """
+        # Normalise to a tuple so list inputs compare correctly against
+        # ``torch.Size`` below instead of raising TypeError.
+        size = tuple(size)
+        if x.shape[-2:] == size:
+            return x
+        elif x.shape[-2:] < size:
+            return F.interpolate(x, size=size, mode=self.upsample_mode)
+        else:
+            assert x.shape[-2] % size[-2] == 0 and x.shape[-1] % size[-1] == 0
+            # Use one pooling factor per spatial dimension: a single factor
+            # taken from the width yields the wrong height whenever the two
+            # downsampling ratios differ.
+            kernel_size = (x.shape[-2] // size[-2], x.shape[-1] // size[-1])
+            x = F.max_pool2d(x, kernel_size=kernel_size, stride=kernel_size)
+            return x
+
+    def forward(self, x1, x2, out_size=None):
+        """Merge ``x1`` and ``x2`` at ``out_size``.
+
+        Args:
+            x1 (Tensor): first input.
+            x2 (Tensor): second input; must share its first two dimensions
+                with ``x1``.
+            out_size (tuple[int] | None): target size of the last two
+                dimensions. If None, the larger of the two input sizes is
+                used.
+
+        Returns:
+            The result of ``_binary_op``, passed through ``out_conv`` when
+            ``with_out_conv`` is True.
+        """
+        assert x1.shape[:2] == x2.shape[:2]
+        assert out_size is None or len(out_size) == 2
+        if out_size is None:  # resize to larger one
+            out_size = max(x1.size()[2:], x2.size()[2:])
+
+        x1 = self.input1_conv(x1)
+        x2 = self.input2_conv(x2)
+
+        x1 = self._resize(x1, out_size)
+        x2 = self._resize(x2, out_size)
+
+        x = self._binary_op(x1, x2)
+        if self.with_out_conv:
+            x = self.out_conv(x)
+        return x
+class SumCell(BaseMergeCell):
+    """Merge cell whose binary op adds the two inputs."""
+
+    def __init__(self, in_channels, out_channels, **kwargs):
+        super().__init__(in_channels, out_channels, **kwargs)
+
+    def _binary_op(self, x1, x2):
+        summed = x1 + x2
+        return summed
+class ConcatCell(BaseMergeCell):
+    """Merge cell whose binary op concatenates the inputs along dim 1.
+
+    The base class is given ``in_channels * 2`` as its fused channel count.
+    """
+
+    def __init__(self, in_channels, out_channels, **kwargs):
+        fused_channels = in_channels * 2
+        super().__init__(fused_channels, out_channels, **kwargs)
+
+    def _binary_op(self, x1, x2):
+        return torch.cat([x1, x2], dim=1)
+class GlobalPoolingCell(BaseMergeCell):
+    """Merge cell that gates ``x1`` with a pooled, sigmoid-activated ``x2``.
+
+    The binary op returns ``x2 + sigmoid(avg_pool(x2)) * x1``, where the
+    pooling is an adaptive average pool to a 1x1 output.
+    """
+
+    def __init__(self, in_channels=None, out_channels=None, **kwargs):
+        super().__init__(in_channels, out_channels, **kwargs)
+        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
+
+    def _binary_op(self, x1, x2):
+        pooled = self.global_pool(x2)
+        gate = pooled.sigmoid()
+        return x2 + gate * x1
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/modulated_deform_conv.py b/ControlNet/annotator/uniformer/mmcv/ops/modulated_deform_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..75559579cf053abcc99538606cbb88c723faf783
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/modulated_deform_conv.py
@@ -0,0 +1,282 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair, _single
+from annotator.uniformer.mmcv.utils import deprecated_api_warning
+from ..cnn import CONV_LAYERS
+from ..utils import ext_loader, print_log
+ext_module = ext_loader.load_ext(
+ '_ext',
+ ['modulated_deform_conv_forward', 'modulated_deform_conv_backward'])
+class ModulatedDeformConv2dFunction(Function):
+    """Autograd function wrapping the modulated deformable conv extension."""
+
+    @staticmethod
+    def symbolic(g, input, offset, mask, weight, bias, stride, padding,
+                 dilation, groups, deform_groups):
+        """Emit the ``mmcv::MMCVModulatedDeformConv2d`` node for ONNX export.
+
+        ``bias`` is appended to the node inputs only when it is not None.
+        """
+        input_tensors = [input, offset, mask, weight]
+        if bias is not None:
+            input_tensors.append(bias)
+        return g.op(
+            'mmcv::MMCVModulatedDeformConv2d',
+            *input_tensors,
+            stride_i=stride,
+            padding_i=padding,
+            dilation_i=dilation,
+            groups_i=groups,
+            deform_groups_i=deform_groups)
+
+    @staticmethod
+    def forward(ctx,
+                input,
+                offset,
+                mask,
+                weight,
+                bias=None,
+                stride=1,
+                padding=0,
+                dilation=1,
+                groups=1,
+                deform_groups=1):
+        """Run ``modulated_deform_conv_forward`` and stash state for backward.
+
+        Args:
+            input (Tensor): 4-D input tensor.
+            offset (Tensor): sampling offsets; ``input`` is cast to its type.
+            mask (Tensor): modulation mask.
+            weight (Tensor): convolution kernel; cast to the type of the
+                (already cast) input.
+            bias (Tensor | None): optional bias. When None an empty
+                placeholder tensor is handed to the extension.
+            stride, padding, dilation (int | tuple[int]): expanded with
+                ``_pair`` and stored on ``ctx``.
+            groups (int): forwarded to the extension as ``group``.
+            deform_groups (int): forwarded as ``deformable_group``.
+
+        Returns:
+            Tensor: output with the size given by :meth:`_output_size`.
+
+        Raises:
+            ValueError: if ``input`` is not 4-D, or if the computed output
+                size is not positive in every dimension.
+        """
+        if input is not None and input.dim() != 4:
+            # Adjacent literals instead of a backslash continuation, which
+            # would embed the source indentation in the message.
+            raise ValueError(
+                f'Expected 4D tensor as input, got {input.dim()}D tensor '
+                'instead.')
+        ctx.stride = _pair(stride)
+        ctx.padding = _pair(padding)
+        ctx.dilation = _pair(dilation)
+        ctx.groups = groups
+        ctx.deform_groups = deform_groups
+        ctx.with_bias = bias is not None
+        if not ctx.with_bias:
+            bias = input.new_empty(0)  # fake tensor
+        # When pytorch version >= 1.6.0, amp is adopted for fp16 mode;
+        # amp won't cast the type of model (float32), but "offset" is cast
+        # to float16 by nn.Conv2d automatically, leading to the type
+        # mismatch with input (when it is float32) or weight.
+        # The flag for whether to use fp16 or amp is the type of "offset",
+        # we cast weight and input to temporarily support fp16 and amp
+        # whatever the pytorch version is.
+        input = input.type_as(offset)
+        weight = weight.type_as(input)
+        ctx.save_for_backward(input, offset, mask, weight, bias)
+        output = input.new_empty(
+            ModulatedDeformConv2dFunction._output_size(ctx, input, weight))
+        # Two scratch buffers handed to the extension; kept on ctx because
+        # backward passes the same buffers again.
+        ctx._bufs = [input.new_empty(0), input.new_empty(0)]
+        ext_module.modulated_deform_conv_forward(
+            input,
+            weight,
+            bias,
+            ctx._bufs[0],
+            offset,
+            mask,
+            output,
+            ctx._bufs[1],
+            kernel_h=weight.size(2),
+            kernel_w=weight.size(3),
+            stride_h=ctx.stride[0],
+            stride_w=ctx.stride[1],
+            pad_h=ctx.padding[0],
+            pad_w=ctx.padding[1],
+            dilation_h=ctx.dilation[0],
+            dilation_w=ctx.dilation[1],
+            group=ctx.groups,
+            deformable_group=ctx.deform_groups,
+            with_bias=ctx.with_bias)
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        """Compute gradients for input, offset, mask, weight and bias.
+
+        Returns:
+            tuple: gradients for the five tensor inputs (the bias gradient
+            is None when forward received no bias), followed by None for
+            the five non-tensor arguments.
+        """
+        input, offset, mask, weight, bias = ctx.saved_tensors
+        grad_input = torch.zeros_like(input)
+        grad_offset = torch.zeros_like(offset)
+        grad_mask = torch.zeros_like(mask)
+        grad_weight = torch.zeros_like(weight)
+        grad_bias = torch.zeros_like(bias)
+        grad_output = grad_output.contiguous()
+        ext_module.modulated_deform_conv_backward(
+            input,
+            weight,
+            bias,
+            ctx._bufs[0],
+            offset,
+            mask,
+            ctx._bufs[1],
+            grad_input,
+            grad_weight,
+            grad_bias,
+            grad_offset,
+            grad_mask,
+            grad_output,
+            kernel_h=weight.size(2),
+            kernel_w=weight.size(3),
+            stride_h=ctx.stride[0],
+            stride_w=ctx.stride[1],
+            pad_h=ctx.padding[0],
+            pad_w=ctx.padding[1],
+            dilation_h=ctx.dilation[0],
+            dilation_w=ctx.dilation[1],
+            group=ctx.groups,
+            deformable_group=ctx.deform_groups,
+            with_bias=ctx.with_bias)
+        if not ctx.with_bias:
+            grad_bias = None
+
+        return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias,
+                None, None, None, None, None)
+
+    @staticmethod
+    def _output_size(ctx, input, weight):
+        """Return the output size tuple for the settings stored on ``ctx``.
+
+        Raises:
+            ValueError: if any resulting dimension is not positive.
+        """
+        channels = weight.size(0)
+        output_size = (input.size(0), channels)
+        for d in range(input.dim() - 2):
+            in_size = input.size(d + 2)
+            pad = ctx.padding[d]
+            # Effective extent of the dilated kernel along this dimension.
+            kernel = ctx.dilation[d] * (weight.size(d + 2) - 1) + 1
+            stride_ = ctx.stride[d]
+            output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
+        if not all(s > 0 for s in output_size):
+            raise ValueError(
+                'convolution input is too small (output would be ' +
+                'x'.join(map(str, output_size)) + ')')
+        return output_size
+
+
+modulated_deform_conv2d = ModulatedDeformConv2dFunction.apply
+class ModulatedDeformConv2d(nn.Module):
+    """Module holding the weight and bias for ``modulated_deform_conv2d``.
+
+    The offset and mask tensors are supplied by the caller of
+    :meth:`forward`. The ``deformable_groups`` keyword is mapped to
+    ``deform_groups`` by ``deprecated_api_warning``.
+    """
+
+    @deprecated_api_warning({'deformable_groups': 'deform_groups'},
+                            cls_name='ModulatedDeformConv2d')
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 groups=1,
+                 deform_groups=1,
+                 bias=True):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = _pair(kernel_size)
+        self.stride = _pair(stride)
+        self.padding = _pair(padding)
+        self.dilation = _pair(dilation)
+        self.groups = groups
+        self.deform_groups = deform_groups
+        # enable compatibility with nn.Conv2d
+        self.transposed = False
+        self.output_padding = _single(0)
+
+        kernel_h, kernel_w = self.kernel_size
+        self.weight = nn.Parameter(
+            torch.Tensor(out_channels, in_channels // groups, kernel_h,
+                         kernel_w))
+        if not bias:
+            self.register_parameter('bias', None)
+        else:
+            self.bias = nn.Parameter(torch.Tensor(out_channels))
+        self.init_weights()
+
+    def init_weights(self):
+        """Draw the weight uniformly from ``[-stdv, stdv]`` and zero the bias.
+
+        ``stdv`` is ``1 / sqrt(in_channels * prod(kernel_size))``.
+        """
+        fan = self.in_channels
+        for extent in self.kernel_size:
+            fan *= extent
+        stdv = 1. / math.sqrt(fan)
+        self.weight.data.uniform_(-stdv, stdv)
+        if self.bias is not None:
+            self.bias.data.zero_()
+
+    def forward(self, x, offset, mask):
+        """Apply ``modulated_deform_conv2d`` with this module's parameters."""
+        conv_args = (self.stride, self.padding, self.dilation, self.groups,
+                     self.deform_groups)
+        return modulated_deform_conv2d(x, offset, mask, self.weight,
+                                       self.bias, *conv_args)
+class ModulatedDeformConv2dPack(ModulatedDeformConv2d):
+    """A ModulatedDeformable Conv Encapsulation that acts as normal Conv
+    layers.
+
+    The offsets and the mask are predicted from the input by an internal
+    ``conv_offset`` convolution, so :meth:`forward` only takes ``x``.
+
+    Args:
+        *args, **kwargs: forwarded unchanged to ``ModulatedDeformConv2d``
+            (in_channels, out_channels, kernel_size, stride, padding,
+            dilation, groups, deform_groups, bias).
+    """
+
+    _version = 2
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        kernel_h, kernel_w = self.kernel_size
+        # Three maps per kernel position and deform group; forward splits
+        # them into two offset chunks and one mask chunk.
+        offset_channels = self.deform_groups * 3 * kernel_h * kernel_w
+        self.conv_offset = nn.Conv2d(
+            self.in_channels,
+            offset_channels,
+            kernel_size=self.kernel_size,
+            stride=self.stride,
+            padding=self.padding,
+            dilation=self.dilation,
+            bias=True)
+        self.init_weights()
+
+    def init_weights(self):
+        """Initialise the parent parameters and zero ``conv_offset``."""
+        super().init_weights()
+        # The parent constructor calls this before conv_offset exists.
+        if hasattr(self, 'conv_offset'):
+            self.conv_offset.weight.data.zero_()
+            self.conv_offset.bias.data.zero_()
+
+    def forward(self, x):
+        """Predict offsets and mask from ``x`` and apply the deformable conv."""
+        predicted = self.conv_offset(x)
+        o1, o2, mask = torch.chunk(predicted, 3, dim=1)
+        offset = torch.cat((o1, o2), dim=1)
+        mask = torch.sigmoid(mask)
+        return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias,
+                                       self.stride, self.padding,
+                                       self.dilation, self.groups,
+                                       self.deform_groups)
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        """Rename legacy ``*_offset`` keys before delegating to the parent."""
+        version = local_metadata.get('version', None)
+
+        if version is None or version < 2:
+            # the key is different in early versions
+            # In version < 2, ModulatedDeformConvPack
+            # loads previous benchmark models.
+            for suffix in ('weight', 'bias'):
+                current_key = prefix + 'conv_offset.' + suffix
+                legacy_key = prefix[:-1] + '_offset.' + suffix
+                if (current_key not in state_dict
+                        and legacy_key in state_dict):
+                    state_dict[current_key] = state_dict.pop(legacy_key)
+
+        # NOTE(review): this message is logged for checkpoints that are
+        # already at version 2 or later, not for upgraded ones; confirm
+        # whether the condition is intended.
+        if version is not None and version > 1:
+            print_log(
+                f'ModulatedDeformConvPack {prefix.rstrip(".")} is upgraded to '
+                'version 2.',
+                logger='root')
+
+        super()._load_from_state_dict(state_dict, prefix, local_metadata,
+                                      strict, missing_keys, unexpected_keys,
+                                      error_msgs)
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/multi_scale_deform_attn.py b/ControlNet/annotator/uniformer/mmcv/ops/multi_scale_deform_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..c52dda18b41705705b47dd0e995b124048c16fba
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/multi_scale_deform_attn.py
@@ -0,0 +1,358 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+import warnings
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd.function import Function, once_differentiable
+from annotator.uniformer.mmcv import deprecated_api_warning
+from annotator.uniformer.mmcv.cnn import constant_init, xavier_init
+from annotator.uniformer.mmcv.cnn.bricks.registry import ATTENTION
+from annotator.uniformer.mmcv.runner import BaseModule
+from ..utils import ext_loader
+ext_module = ext_loader.load_ext(
+ '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
+class MultiScaleDeformableAttnFunction(Function):
+    """Autograd function around the ``ms_deform_attn`` extension ops."""
+
+    @staticmethod
+    def forward(ctx, value, value_spatial_shapes, value_level_start_index,
+                sampling_locations, attention_weights, im2col_step):
+        """GPU version of multi-scale deformable attention.
+
+        Args:
+            value (Tensor): The value has shape
+                (bs, num_keys, num_heads, embed_dims//num_heads)
+            value_spatial_shapes (Tensor): Spatial shape of
+                each feature map, has shape (num_levels, 2),
+                last dimension 2 represent (h, w)
+            value_level_start_index (Tensor): passed through to the
+                extension op unchanged.
+            sampling_locations (Tensor): The location of sampling points,
+                has shape
+                (bs ,num_queries, num_heads, num_levels, num_points, 2),
+                the last dimension 2 represent (x, y).
+            attention_weights (Tensor): The weight of sampling points used
+                when calculate the attention, has shape
+                (bs ,num_queries, num_heads, num_levels, num_points),
+            im2col_step (Tensor): The step used in image to column.
+
+        Returns:
+            Tensor: has shape (bs, num_queries, embed_dims)
+        """
+        ctx.im2col_step = im2col_step
+        tensor_args = (value, value_spatial_shapes, value_level_start_index,
+                       sampling_locations, attention_weights)
+        output = ext_module.ms_deform_attn_forward(
+            *tensor_args, im2col_step=ctx.im2col_step)
+        ctx.save_for_backward(*tensor_args)
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        """GPU version of backward function.
+
+        Args:
+            grad_output (Tensor): Gradient
+                of output tensor of forward.
+
+        Returns:
+            Tuple[Tensor]: Gradient of input tensors in forward; None for
+            the spatial shapes, the level start index and ``im2col_step``.
+        """
+        (value, value_spatial_shapes, value_level_start_index,
+         sampling_locations, attention_weights) = ctx.saved_tensors
+        # Zero-filled buffers handed to the extension op, which is expected
+        # to write the gradients into them.
+        grad_value = torch.zeros_like(value)
+        grad_sampling_loc = torch.zeros_like(sampling_locations)
+        grad_attn_weight = torch.zeros_like(attention_weights)
+
+        ext_module.ms_deform_attn_backward(
+            value,
+            value_spatial_shapes,
+            value_level_start_index,
+            sampling_locations,
+            attention_weights,
+            grad_output.contiguous(),
+            grad_value,
+            grad_sampling_loc,
+            grad_attn_weight,
+            im2col_step=ctx.im2col_step)
+
+        return (grad_value, None, None, grad_sampling_loc, grad_attn_weight,
+                None)
+def multi_scale_deformable_attn_pytorch(value, value_spatial_shapes,
+                                        sampling_locations, attention_weights):
+    """CPU version of multi-scale deformable attention.
+
+    Args:
+        value (Tensor): The value has shape
+            (bs, num_keys, num_heads, embed_dims//num_heads)
+        value_spatial_shapes (Tensor): Spatial shape of
+            each feature map, has shape (num_levels, 2),
+            last dimension 2 represent (h, w)
+        sampling_locations (Tensor): The location of sampling points,
+            has shape
+            (bs ,num_queries, num_heads, num_levels, num_points, 2),
+            the last dimension 2 represent (x, y).
+        attention_weights (Tensor): The weight of sampling points used
+            when calculate the attention, has shape
+            (bs ,num_queries, num_heads, num_levels, num_points),
+
+    Returns:
+        Tensor: has shape (bs, num_queries, embed_dims)
+    """
+    bs, _, _, embed_dims = value.shape
+    _, num_queries, num_heads, num_levels, num_points, _ = \
+        sampling_locations.shape
+
+    level_sizes = [H_ * W_ for H_, W_ in value_spatial_shapes]
+    per_level_value = value.split(level_sizes, dim=1)
+    # Map the locations to the [-1, 1] range used by F.grid_sample.
+    grids = 2 * sampling_locations - 1
+
+    sampled_per_level = []
+    for level, (H_, W_) in enumerate(value_spatial_shapes):
+        # bs, H_*W_, num_heads, embed_dims ->
+        # bs, H_*W_, num_heads*embed_dims ->
+        # bs, num_heads*embed_dims, H_*W_ ->
+        # bs*num_heads, embed_dims, H_, W_
+        level_value = per_level_value[level].flatten(2).transpose(1, 2)
+        level_value = level_value.reshape(bs * num_heads, embed_dims, H_, W_)
+        # bs, num_queries, num_heads, num_points, 2 ->
+        # bs, num_heads, num_queries, num_points, 2 ->
+        # bs*num_heads, num_queries, num_points, 2
+        level_grid = grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
+        # bs*num_heads, embed_dims, num_queries, num_points
+        sampled_per_level.append(
+            F.grid_sample(
+                level_value,
+                level_grid,
+                mode='bilinear',
+                padding_mode='zeros',
+                align_corners=False))
+
+    # (bs, num_queries, num_heads, num_levels, num_points) ->
+    # (bs, num_heads, num_queries, num_levels, num_points) ->
+    # (bs*num_heads, 1, num_queries, num_levels*num_points)
+    weights = attention_weights.transpose(1, 2).reshape(
+        bs * num_heads, 1, num_queries, num_levels * num_points)
+    stacked = torch.stack(sampled_per_level, dim=-2).flatten(-2)
+    output = (stacked * weights).sum(-1)
+    output = output.view(bs, num_heads * embed_dims, num_queries)
+    return output.transpose(1, 2).contiguous()
+class MultiScaleDeformableAttention(BaseModule):
+    """An attention module used in Deformable-Detr.
+
+    `Deformable DETR: Deformable Transformers for End-to-End Object Detection.
+    <https://arxiv.org/pdf/2010.04159.pdf>`_.
+
+    Args:
+        embed_dims (int): The embedding dimension of Attention.
+            Default: 256.
+        num_heads (int): Parallel attention heads. Default: 8.
+        num_levels (int): The number of feature map used in
+            Attention. Default: 4.
+        num_points (int): The number of sampling points for
+            each query in each head. Default: 4.
+        im2col_step (int): The step used in image_to_column.
+            Default: 64.
+        dropout (float): Probability of the Dropout layer applied to the
+            output before the identity is added. Default: 0.1.
+        batch_first (bool): Key, Query and Value are shape of
+            (batch, n, embed_dim)
+            or (n, batch, embed_dim). Default to False.
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: None.
+        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+            Default: None.
+    """
+
+    def __init__(self,
+                 embed_dims=256,
+                 num_heads=8,
+                 num_levels=4,
+                 num_points=4,
+                 im2col_step=64,
+                 dropout=0.1,
+                 batch_first=False,
+                 norm_cfg=None,
+                 init_cfg=None):
+        super().__init__(init_cfg)
+        if embed_dims % num_heads != 0:
+            raise ValueError(f'embed_dims must be divisible by num_heads, '
+                             f'but got {embed_dims} and {num_heads}')
+        dim_per_head = embed_dims // num_heads
+        self.norm_cfg = norm_cfg
+        self.dropout = nn.Dropout(dropout)
+        self.batch_first = batch_first
+
+        # you'd better set dim_per_head to a power of 2
+        # which is more efficient in the CUDA implementation
+        def _is_power_of_2(n):
+            if (not isinstance(n, int)) or (n < 0):
+                raise ValueError(
+                    'invalid input for _is_power_of_2: {} (type: {})'.format(
+                        n, type(n)))
+            return n != 0 and (n & (n - 1)) == 0
+
+        if not _is_power_of_2(dim_per_head):
+            warnings.warn(
+                "You'd better set embed_dims in "
+                'MultiScaleDeformAttention to make '
+                'the dimension of each attention head a power of 2 '
+                'which is more efficient in our CUDA implementation.')
+
+        self.im2col_step = im2col_step
+        self.embed_dims = embed_dims
+        self.num_levels = num_levels
+        self.num_heads = num_heads
+        self.num_points = num_points
+
+        total_points = num_heads * num_levels * num_points
+        self.sampling_offsets = nn.Linear(embed_dims, total_points * 2)
+        self.attention_weights = nn.Linear(embed_dims, total_points)
+        self.value_proj = nn.Linear(embed_dims, embed_dims)
+        self.output_proj = nn.Linear(embed_dims, embed_dims)
+        self.init_weights()
+
+    def init_weights(self):
+        """Default initialization for Parameters of Module."""
+        constant_init(self.sampling_offsets, 0.)
+        # One direction per head, evenly spaced around the unit circle.
+        angle_step = 2.0 * math.pi / self.num_heads
+        thetas = torch.arange(self.num_heads, dtype=torch.float32) * angle_step
+        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
+        # Scale each direction so that its largest absolute component is 1.
+        grid_init = grid_init / grid_init.abs().max(-1, keepdim=True)[0]
+        grid_init = grid_init.view(self.num_heads, 1, 1, 2).repeat(
+            1, self.num_levels, self.num_points, 1)
+        # The i-th sampling point starts (i + 1) steps along the direction.
+        for point_idx in range(self.num_points):
+            grid_init[:, :, point_idx, :] *= point_idx + 1
+
+        self.sampling_offsets.bias.data = grid_init.view(-1)
+        constant_init(self.attention_weights, val=0., bias=0.)
+        xavier_init(self.value_proj, distribution='uniform', bias=0.)
+        xavier_init(self.output_proj, distribution='uniform', bias=0.)
+        self._is_init = True
+
+    @deprecated_api_warning({'residual': 'identity'},
+                            cls_name='MultiScaleDeformableAttention')
+    def forward(self,
+                query,
+                key=None,
+                value=None,
+                identity=None,
+                query_pos=None,
+                key_padding_mask=None,
+                reference_points=None,
+                spatial_shapes=None,
+                level_start_index=None,
+                **kwargs):
+        """Forward Function of MultiScaleDeformAttention.
+
+        Args:
+            query (Tensor): Query of Transformer with shape
+                (num_query, bs, embed_dims).
+            key (Tensor): The key tensor with shape
+                `(num_key, bs, embed_dims)`. Not used by this module.
+            value (Tensor): The value tensor with shape
+                `(num_key, bs, embed_dims)`. If None, `query` (before
+                `query_pos` is added) will be used.
+            identity (Tensor): The tensor used for addition, with the
+                same shape as `query`. Default None. If None,
+                `query` will be used.
+            query_pos (Tensor): The positional encoding for `query`.
+                Default: None.
+            reference_points (Tensor):  The normalized reference
+                points with shape (bs, num_query, num_levels, 2),
+                all elements is range in [0, 1], top-left (0,0),
+                bottom-right (1, 1), including padding area.
+                or (N, Length_{query}, num_levels, 4), add
+                additional two dimensions is (w, h) to
+                form reference boxes.
+            key_padding_mask (Tensor): ByteTensor for `query`, with
+                shape [bs, num_key].
+            spatial_shapes (Tensor): Spatial shape of features in
+                different levels. With shape (num_levels, 2),
+                last dimension represents (h, w).
+            level_start_index (Tensor): The start index of each level.
+                A tensor has shape ``(num_levels, )`` and can be represented
+                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
+
+        Returns:
+             Tensor: forwarded results with shape [num_query, bs, embed_dims].
+
+        Raises:
+            ValueError: if the last dimension of `reference_points` is
+                neither 2 nor 4.
+        """
+        if value is None:
+            value = query
+        if identity is None:
+            identity = query
+        if query_pos is not None:
+            query = query + query_pos
+        if not self.batch_first:
+            # change to (bs, num_query ,embed_dims)
+            query = query.permute(1, 0, 2)
+            value = value.permute(1, 0, 2)
+
+        bs, num_query, _ = query.shape
+        bs, num_value, _ = value.shape
+        assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
+
+        value = self.value_proj(value)
+        if key_padding_mask is not None:
+            value = value.masked_fill(key_padding_mask[..., None], 0.0)
+        value = value.view(bs, num_value, self.num_heads, -1)
+
+        sampling_offsets = self.sampling_offsets(query).view(
+            bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
+        # Normalise jointly over all levels and points of a head, then
+        # restore the separate level / point dimensions.
+        attention_weights = self.attention_weights(query).view(
+            bs, num_query, self.num_heads, self.num_levels * self.num_points)
+        attention_weights = attention_weights.softmax(-1)
+        attention_weights = attention_weights.view(bs, num_query,
+                                                   self.num_heads,
+                                                   self.num_levels,
+                                                   self.num_points)
+
+        ref_dim = reference_points.shape[-1]
+        if ref_dim == 2:
+            # spatial_shapes holds (h, w); the offsets are (x, y).
+            offset_normalizer = torch.stack(
+                [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
+            sampling_locations = reference_points[:, :, None, :, None, :] \
+                + sampling_offsets \
+                / offset_normalizer[None, None, None, :, None, :]
+        elif ref_dim == 4:
+            sampling_locations = reference_points[:, :, None, :, None, :2] \
+                + sampling_offsets / self.num_points \
+                * reference_points[:, :, None, :, None, 2:] \
+                * 0.5
+        else:
+            raise ValueError(
+                f'Last dim of reference_points must be'
+                f' 2 or 4, but get {ref_dim} instead.')
+
+        if torch.cuda.is_available() and value.is_cuda:
+            output = MultiScaleDeformableAttnFunction.apply(
+                value, spatial_shapes, level_start_index, sampling_locations,
+                attention_weights, self.im2col_step)
+        else:
+            output = multi_scale_deformable_attn_pytorch(
+                value, spatial_shapes, sampling_locations, attention_weights)
+
+        output = self.output_proj(output)
+
+        if not self.batch_first:
+            # (num_query, bs ,embed_dims)
+            output = output.permute(1, 0, 2)
+
+        return self.dropout(output) + identity
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/nms.py b/ControlNet/annotator/uniformer/mmcv/ops/nms.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d9634281f486ab284091786886854c451368052
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/nms.py
@@ -0,0 +1,417 @@
+import os
+import numpy as np
+import torch
+from annotator.uniformer.mmcv.utils import deprecated_api_warning
+from ..utils import ext_loader
+ext_module = ext_loader.load_ext(
+ '_ext', ['nms', 'softnms', 'nms_match', 'nms_rotated'])
+# This function is modified from: https://github.com/pytorch/vision/
+class NMSop(torch.autograd.Function):
+    """Autograd function wrapping the compiled ``nms`` extension op.
+
+    ``forward`` calls the extension kernel; ``symbolic`` describes the same
+    operation as ONNX graph nodes for export.
+    """
+
+    @staticmethod
+    def forward(ctx, bboxes, scores, iou_threshold, offset, score_threshold,
+                max_num):
+        """Return the indices of the boxes kept by NMS.
+
+        When ``score_threshold`` is positive, only boxes scoring strictly
+        above it are handed to the kernel, and the returned indices are
+        mapped back to positions in the unfiltered input. When ``max_num``
+        is positive, at most ``max_num`` indices are returned.
+        """
+        prefilter = score_threshold > 0
+        if prefilter:
+            passed = scores > score_threshold
+            bboxes = bboxes[passed]
+            scores = scores[passed]
+            # Positions of the surviving boxes in the original ordering.
+            source_inds = torch.nonzero(passed, as_tuple=False).squeeze(dim=1)
+
+        inds = ext_module.nms(
+            bboxes, scores, iou_threshold=float(iou_threshold), offset=offset)
+        if max_num > 0:
+            inds = inds[:max_num]
+        return source_inds[inds] if prefilter else inds
+
+    @staticmethod
+    def symbolic(g, bboxes, scores, iou_threshold, offset, score_threshold,
+                 max_num):
+        """Emit the ONNX nodes that represent this op during export."""
+        from ..onnx import is_custom_op_loaded
+
+        custom_op_available = is_custom_op_loaded()
+        # TensorRT nms plugin is aligned with original nms in ONNXRuntime
+        targets_tensorrt = os.environ.get('ONNX_BACKEND') == 'MMCVTensorRT'
+        if custom_op_available and not targets_tensorrt:
+            return g.op(
+                'mmcv::NonMaxSuppression',
+                bboxes,
+                scores,
+                iou_threshold_f=float(iou_threshold),
+                offset_i=int(offset))
+
+        # Otherwise build the graph from the standard NonMaxSuppression node.
+        from torch.onnx.symbolic_opset9 import select, squeeze, unsqueeze
+        from ..onnx.onnx_utils.symbolic_helper import _size_helper
+
+        boxes = unsqueeze(g, bboxes, 0)
+        scores = unsqueeze(g, unsqueeze(g, scores, 0), 0)
+        if max_num > 0:
+            max_output_per_class = g.op(
+                'Constant', value_t=torch.tensor(max_num, dtype=torch.long))
+        else:
+            # No cap requested: use the size of dim 0 of the input boxes.
+            dim = g.op('Constant', value_t=torch.tensor(0))
+            max_output_per_class = _size_helper(g, bboxes, dim)
+        iou_node = g.op(
+            'Constant',
+            value_t=torch.tensor([iou_threshold], dtype=torch.float))
+        score_node = g.op(
+            'Constant',
+            value_t=torch.tensor([score_threshold], dtype=torch.float))
+        nms_out = g.op('NonMaxSuppression', boxes, scores,
+                       max_output_per_class, iou_node, score_node)
+        # Select index 2 along dim 1 of the node output, then drop that dim.
+        column = g.op(
+            'Constant', value_t=torch.tensor([2], dtype=torch.long))
+        return squeeze(g, select(g, nms_out, 1, column), 1)
+class SoftNMSop(torch.autograd.Function):
+    """Autograd function wrapping the compiled ``softnms`` extension op."""
+
+    @staticmethod
+    def forward(ctx, boxes, scores, iou_threshold, sigma, min_score, method,
+                offset):
+        """Run the soft-NMS kernel on CPU copies of the inputs.
+
+        Returns:
+            tuple: ``(dets, inds)`` where ``dets`` is the (N, 5) CPU buffer
+                handed to the kernel and ``inds`` is the kernel's return
+                value.
+        """
+        num_boxes = boxes.size(0)
+        dets = boxes.new_empty((num_boxes, 5), device='cpu')
+        kernel_kwargs = dict(
+            iou_threshold=float(iou_threshold),
+            sigma=float(sigma),
+            min_score=float(min_score),
+            method=int(method),
+            offset=int(offset))
+        inds = ext_module.softnms(boxes.cpu(), scores.cpu(), dets.cpu(),
+                                  **kernel_kwargs)
+        return dets, inds
+
+    @staticmethod
+    def symbolic(g, boxes, scores, iou_threshold, sigma, min_score, method,
+                 offset):
+        """Emit the ``mmcv::SoftNonMaxSuppression`` node for ONNX export."""
+        from packaging import version
+
+        # Export requires at least torch 1.7.0.
+        assert version.parse(torch.__version__) >= version.parse('1.7.0')
+        return g.op(
+            'mmcv::SoftNonMaxSuppression',
+            boxes,
+            scores,
+            iou_threshold_f=float(iou_threshold),
+            sigma_f=float(sigma),
+            min_score_f=float(min_score),
+            method_i=int(method),
+            offset_i=int(offset),
+            outputs=2)
+@deprecated_api_warning({'iou_thr': 'iou_threshold'})
+def nms(boxes, scores, iou_threshold, offset=0, score_threshold=0, max_num=-1):
+    """Run NMS on boxes given as torch tensors or numpy arrays.
+
+    Numpy inputs are converted to tensors for the computation. The outputs
+    are converted back to numpy arrays when ``boxes`` was a numpy array.
+
+    Arguments:
+        boxes (torch.Tensor or np.ndarray): boxes in shape (N, 4).
+        scores (torch.Tensor or np.ndarray): scores in shape (N, ).
+        iou_threshold (float): IoU threshold for NMS.
+        offset (int, 0 or 1): boxes' width or height is (x2 - x1 + offset).
+        score_threshold (float): score threshold for NMS.
+        max_num (int): maximum number of boxes after NMS.
+
+    Returns:
+        tuple: kept dets (boxes with their score appended as a fifth
+            column) and the indices of the kept boxes.
+
+    Note:
+        On the ``parrots`` backend ``score_threshold`` and ``max_num`` are
+        not forwarded to the extension op.
+
+    Example:
+        >>> boxes = np.array([[49.1, 32.4, 51.0, 35.9],
+        >>>                   [49.3, 32.9, 51.0, 35.3],
+        >>>                   [49.2, 31.8, 51.0, 35.4],
+        >>>                   [35.1, 11.5, 39.1, 15.7],
+        >>>                   [35.6, 11.8, 39.3, 14.2],
+        >>>                   [35.3, 11.5, 39.9, 14.5],
+        >>>                   [35.2, 11.7, 39.7, 15.7]], dtype=np.float32)
+        >>> scores = np.array([0.9, 0.9, 0.5, 0.5, 0.5, 0.4, 0.3],
+        >>>                   dtype=np.float32)
+        >>> iou_threshold = 0.6
+        >>> dets, inds = nms(boxes, scores, iou_threshold)
+        >>> assert len(inds) == len(dets) == 3
+    """
+    assert isinstance(boxes, (torch.Tensor, np.ndarray))
+    assert isinstance(scores, (torch.Tensor, np.ndarray))
+    # The output container type follows ``boxes`` only.
+    is_numpy = isinstance(boxes, np.ndarray)
+    if is_numpy:
+        boxes = torch.from_numpy(boxes)
+    if isinstance(scores, np.ndarray):
+        scores = torch.from_numpy(scores)
+    assert boxes.size(1) == 4
+    assert boxes.size(0) == scores.size(0)
+    assert offset in (0, 1)
+
+    if torch.__version__ == 'parrots':
+        inds = ext_module.nms(
+            boxes,
+            scores,
+            iou_threshold=float(iou_threshold),
+            offset=int(offset))
+    else:
+        inds = NMSop.apply(boxes, scores, iou_threshold, offset,
+                           score_threshold, max_num)
+
+    kept_scores = scores[inds].reshape(-1, 1)
+    dets = torch.cat((boxes[inds], kept_scores), dim=1)
+    if not is_numpy:
+        return dets, inds
+    return dets.cpu().numpy(), inds.cpu().numpy()
+@deprecated_api_warning({'iou_thr': 'iou_threshold'})
+def soft_nms(boxes,
+             scores,
+             iou_threshold=0.3,
+             sigma=0.5,
+             min_score=1e-3,
+             method='linear',
+             offset=0):
+    """Run Soft-NMS; the extension op is always called on CPU copies.
+
+    The input can be either a torch tensor or numpy array. Numpy arrays are
+    returned when ``boxes`` was a numpy array; otherwise the results are
+    moved to the device of ``boxes``.
+
+    Arguments:
+        boxes (torch.Tensor or np.ndarray): boxes in shape (N, 4).
+        scores (torch.Tensor or np.ndarray): scores in shape (N, ).
+        iou_threshold (float): IoU threshold for NMS.
+        sigma (float): hyperparameter for gaussian method
+        min_score (float): score filter threshold
+        method (str): one of 'naive', 'linear' or 'gaussian'
+        offset (int, 0 or 1): boxes' width or height is (x2 - x1 + offset).
+
+    Returns:
+        tuple: kept dets (boxes and scores) and the indices returned by the
+            extension op.
+
+    Example:
+        >>> boxes = np.array([[4., 3., 5., 3.],
+        >>>                   [4., 3., 5., 4.],
+        >>>                   [3., 1., 3., 1.],
+        >>>                   [3., 1., 3., 1.],
+        >>>                   [3., 1., 3., 1.],
+        >>>                   [3., 1., 3., 1.]], dtype=np.float32)
+        >>> scores = np.array([0.9, 0.9, 0.5, 0.5, 0.4, 0.0], dtype=np.float32)
+        >>> iou_threshold = 0.6
+        >>> dets, inds = soft_nms(boxes, scores, iou_threshold, sigma=0.5)
+        >>> assert len(inds) == len(dets) == 5
+    """
+    assert isinstance(boxes, (torch.Tensor, np.ndarray))
+    assert isinstance(scores, (torch.Tensor, np.ndarray))
+    # The output container type follows ``boxes`` only.
+    is_numpy = isinstance(boxes, np.ndarray)
+    if is_numpy:
+        boxes = torch.from_numpy(boxes)
+    if isinstance(scores, np.ndarray):
+        scores = torch.from_numpy(scores)
+    assert boxes.size(1) == 4
+    assert boxes.size(0) == scores.size(0)
+    assert offset in (0, 1)
+
+    # Integer codes the extension op uses for each method name.
+    method_dict = {'naive': 0, 'linear': 1, 'gaussian': 2}
+    assert method in method_dict
+
+    if torch.__version__ == 'parrots':
+        dets = boxes.new_empty((boxes.size(0), 5), device='cpu')
+        inds = ext_module.softnms(
+            boxes.cpu(),
+            scores.cpu(),
+            dets.cpu(),
+            iou_threshold=float(iou_threshold),
+            sigma=float(sigma),
+            min_score=min_score,
+            method=method_dict[method],
+            offset=int(offset))
+    else:
+        dets, inds = SoftNMSop.apply(boxes.cpu(), scores.cpu(),
+                                     float(iou_threshold), float(sigma),
+                                     float(min_score), method_dict[method],
+                                     int(offset))
+
+    # Keep as many leading rows of the buffer as indices were returned.
+    dets = dets[:inds.size(0)]
+    if is_numpy:
+        return dets.cpu().numpy(), inds.cpu().numpy()
+    target_device = boxes.device
+    return dets.to(device=target_device), inds.to(device=target_device)
+def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False):
+    """Apply NMS independently within each group of boxes.
+
+    Modified from https://github.com/pytorch/vision/blob
+    /505cd6957711af790211896d32b40291bea1bc21/torchvision/ops/boxes.py#L39.
+    Unless the call is class agnostic, every box is shifted by an offset
+    that depends only on its ``idxs`` value and exceeds the largest
+    coordinate, so boxes from different groups never overlap.
+
+    Arguments:
+        boxes (torch.Tensor): boxes in shape (N, 4).
+        scores (torch.Tensor): scores in shape (N, ).
+        idxs (torch.Tensor): each index value correspond to a bbox cluster,
+            and NMS will not be applied between elements of different idxs,
+            shape (N, ).
+        nms_cfg (dict): specify nms type and other parameters like iou_thr.
+            The dict is copied, never modified. Recognised keys:
+
+            - type (str): name of the NMS function to call. Defaults to
+              'nms'.
+            - class_agnostic (bool): overrides the ``class_agnostic``
+              argument when present.
+            - split_thr (float): threshold number of boxes. When the number
+              of boxes is not below it (and no ONNX export is in progress),
+              NMS runs on each group separately and sequentially, which
+              avoids OOM for very large inputs. Defaults to 10000.
+            - max_num (int): cap on the number of kept boxes.
+
+            Remaining keys are forwarded to the NMS function.
+        class_agnostic (bool): if true, nms is class agnostic,
+            i.e. IoU thresholding happens over all boxes,
+            regardless of the predicted class.
+
+    Returns:
+        tuple: kept dets and indice.
+    """
+    cfg = nms_cfg.copy()
+    class_agnostic = cfg.pop('class_agnostic', class_agnostic)
+    if class_agnostic:
+        boxes_for_nms = boxes
+    else:
+        max_coordinate = boxes.max()
+        offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes))
+        boxes_for_nms = boxes + offsets[:, None]
+
+    nms_type = cfg.pop('type', 'nms')
+    # NOTE(review): eval() resolves the op by name from this module's scope;
+    # nms_cfg['type'] must only ever come from trusted configuration.
+    nms_op = eval(nms_type)
+
+    split_thr = cfg.pop('split_thr', 10000)
+    # Won't split to multiple nms nodes when exporting to onnx
+    if boxes_for_nms.shape[0] < split_thr or torch.onnx.is_in_onnx_export():
+        dets, keep = nms_op(boxes_for_nms, scores, **cfg)
+        # -1 indexing works abnormal in TensorRT
+        # This assumes `dets` has 5 dimensions where
+        # the last dimension is score.
+        # TODO: more elegant way to handle the dimension issue.
+        # Some type of nms would reweight the score, such as SoftNMS
+        kept_scores = dets[:, 4]
+        return torch.cat([boxes[keep], kept_scores[:, None]], -1), keep
+
+    max_num = cfg.pop('max_num', -1)
+    total_mask = scores.new_zeros(scores.size(), dtype=torch.bool)
+    # Some type of nms would reweight the score, such as SoftNMS
+    scores_after_nms = scores.new_zeros(scores.size())
+    for group in torch.unique(idxs):
+        members = (idxs == group).nonzero(as_tuple=False).view(-1)
+        dets, keep = nms_op(boxes_for_nms[members], scores[members], **cfg)
+        survivors = members[keep]
+        total_mask[survivors] = True
+        scores_after_nms[survivors] = dets[:, -1]
+
+    keep = total_mask.nonzero(as_tuple=False).view(-1)
+    # Order the survivors by their (possibly reweighted) score.
+    scores, order = scores_after_nms[keep].sort(descending=True)
+    keep = keep[order]
+    boxes = boxes[keep]
+    if max_num > 0:
+        keep = keep[:max_num]
+        boxes = boxes[:max_num]
+        scores = scores[:max_num]
+    return torch.cat([boxes, scores[:, None]], -1), keep
+def nms_match(dets, iou_threshold):
+    """Matched dets into different groups by NMS.
+
+    NMS match is Similar to NMS but when a bbox is suppressed, nms match will
+    record the indice of suppressed bbox and form a group with the indice of
+    kept bbox. In each group, indice is sorted as score order.
+
+    Arguments:
+        dets (torch.Tensor | np.ndarray): Det boxes with scores, shape (N, 5).
+        iou_threshold (float): IoU thresh for NMS.
+
+    Returns:
+        List[torch.Tensor | np.ndarray]: The outer list corresponds different
+            matched group, the inner Tensor corresponds the indices for a
+            group in score order. An empty list is returned for empty input.
+    """
+    if dets.shape[0] == 0:
+        matched = []
+    else:
+        assert dets.shape[-1] == 5, 'inputs dets.shape should be (N, 5), ' \
+                                    f'but get {dets.shape}'
+        # The extension op is always fed a CPU tensor.
+        if isinstance(dets, torch.Tensor):
+            dets_t = dets.detach().cpu()
+        else:
+            dets_t = torch.from_numpy(dets)
+        indata_list = [dets_t]
+        indata_dict = {'iou_threshold': float(iou_threshold)}
+        matched = ext_module.nms_match(*indata_list, **indata_dict)
+        if torch.__version__ == 'parrots':
+            matched = matched.tolist()
+
+    if isinstance(dets, torch.Tensor):
+        return [dets.new_tensor(m, dtype=torch.long) for m in matched]
+    # ``np.int`` was only an alias of the builtin ``int`` and has been
+    # removed from NumPy; the builtin yields the same dtype.
+    return [np.array(m, dtype=int) for m in matched]
+def nms_rotated(dets, scores, iou_threshold, labels=None):
+    """Performs non-maximum suppression (NMS) on the rotated boxes according to
+    their intersection-over-union (IoU).
+
+    Rotated NMS iteratively removes lower scoring rotated boxes which have an
+    IoU greater than iou_threshold with another (higher scoring) rotated box.
+
+    Args:
+        dets (Tensor): Rotated boxes in shape (N, 5). They are expected to
+            be in (x_ctr, y_ctr, width, height, angle_radian) format.
+        scores (Tensor): scores in shape (N, ).
+        iou_threshold (float): IoU thresh for NMS.
+        labels (Tensor): boxes' label in shape (N,).
+
+    Returns:
+        tuple: kept dets (boxes with the score appended as last column) and
+            the kept indices. For empty input, ``dets`` is returned
+            unchanged together with ``None``.
+    """
+    if dets.shape[0] == 0:
+        return dets, None
+
+    multi_label = labels is not None
+    # With labels, the label is appended as an extra column for the kernel.
+    dets_wl = torch.cat((dets, labels.unsqueeze(1)), 1) if multi_label \
+        else dets
+    order = scores.sort(0, descending=True)[1]
+    dets_sorted = dets_wl.index_select(0, order)
+
+    if torch.__version__ == 'parrots':
+        keep_inds = ext_module.nms_rotated(
+            dets_wl,
+            scores,
+            order,
+            dets_sorted,
+            iou_threshold=iou_threshold,
+            multi_label=multi_label)
+    else:
+        keep_inds = ext_module.nms_rotated(dets_wl, scores, order, dets_sorted,
+                                           iou_threshold, multi_label)
+
+    kept_scores = scores[keep_inds].reshape(-1, 1)
+    dets = torch.cat((dets[keep_inds], kept_scores), dim=1)
+    return dets, keep_inds
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/pixel_group.py b/ControlNet/annotator/uniformer/mmcv/ops/pixel_group.py
new file mode 100644
index 0000000000000000000000000000000000000000..2143c75f835a467c802fc3c37ecd3ac0f85bcda4
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/pixel_group.py
@@ -0,0 +1,75 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+from ..utils import ext_loader
+ext_module = ext_loader.load_ext('_ext', ['pixel_group'])
+def pixel_group(score, mask, embedding, kernel_label, kernel_contour,
+                kernel_region_num, distance_threshold):
+    """Group pixels into text instances, which is widely used text detection
+    methods.
+
+    Numpy inputs are converted to tensors before the extension op is called.
+
+    Arguments:
+        score (np.array or Tensor): The foreground score with size hxw.
+        mask (np.array or Tensor): The foreground mask with size hxw.
+        embedding (np.array or Tensor): The embedding with size hxwxc to
+            distinguish instances.
+        kernel_label (np.array or Tensor): The instance kernel index with
+            size hxw.
+        kernel_contour (np.array or Tensor): The kernel contour with size hxw.
+        kernel_region_num (int): The instance kernel region number.
+        distance_threshold (float): The embedding distance threshold between
+            kernel and pixel in one instance.
+
+    Returns:
+        pixel_assignment (List[List[float]]): The instance coordinate list.
+            Each element consists of averaged confidence, pixel number, and
+            coordinates (x_i, y_i for all pixels) in order.
+    """
+    assert isinstance(score, (torch.Tensor, np.ndarray))
+    assert isinstance(mask, (torch.Tensor, np.ndarray))
+    assert isinstance(embedding, (torch.Tensor, np.ndarray))
+    assert isinstance(kernel_label, (torch.Tensor, np.ndarray))
+    assert isinstance(kernel_contour, (torch.Tensor, np.ndarray))
+    assert isinstance(kernel_region_num, int)
+    assert isinstance(distance_threshold, float)
+
+    if isinstance(score, np.ndarray):
+        score = torch.from_numpy(score)
+    if isinstance(mask, np.ndarray):
+        mask = torch.from_numpy(mask)
+    if isinstance(embedding, np.ndarray):
+        embedding = torch.from_numpy(embedding)
+    if isinstance(kernel_label, np.ndarray):
+        kernel_label = torch.from_numpy(kernel_label)
+    if isinstance(kernel_contour, np.ndarray):
+        kernel_contour = torch.from_numpy(kernel_contour)
+
+    if torch.__version__ == 'parrots':
+        label = ext_module.pixel_group(
+            score,
+            mask,
+            embedding,
+            kernel_label,
+            kernel_contour,
+            kernel_region_num=kernel_region_num,
+            distance_threshold=distance_threshold)
+        label = label.tolist()
+        label = label[0]
+        # The first ``kernel_region_num`` entries hold the length of each
+        # region's record; the records follow back to back after them.
+        list_index = kernel_region_num
+        pixel_assignment = []
+        for x in range(kernel_region_num):
+            region_len = int(label[x])
+            # ``np.float`` was only an alias of the builtin ``float`` and
+            # has been removed from NumPy; the builtin yields the same dtype.
+            pixel_assignment.append(
+                np.array(
+                    label[list_index:list_index + region_len], dtype=float))
+            list_index = list_index + region_len
+    else:
+        pixel_assignment = ext_module.pixel_group(score, mask, embedding,
+                                                  kernel_label, kernel_contour,
+                                                  kernel_region_num,
+                                                  distance_threshold)
+    return pixel_assignment
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/point_sample.py b/ControlNet/annotator/uniformer/mmcv/ops/point_sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..267f4b3c56630acd85f9bdc630b7be09abab0aba
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/point_sample.py
@@ -0,0 +1,336 @@
+# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend # noqa
+from os import path as osp
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.modules.utils import _pair
+from torch.onnx.operators import shape_as_tensor
+def bilinear_grid_sample(im, grid, align_corners=False):
+    """Sample ``im`` at the locations in ``grid`` with bilinear interpolation.
+
+    Only bilinear interpolation is implemented. The input is padded with one
+    ring of zeros, and out-of-range neighbours are clipped onto that ring.
+
+    Args:
+        im (torch.Tensor): Input feature map, shape (N, C, H, W)
+        grid (torch.Tensor): Point coordinates, shape (N, Hg, Wg, 2)
+        align_corners {bool}: If set to True, the extrema (-1 and 1) are
+            considered as referring to the center points of the input's
+            corner pixels. If set to False, they are instead considered as
+            referring to the corner points of the input's corner pixels,
+            making the sampling more resolution agnostic.
+
+    Returns:
+        torch.Tensor: A tensor with sampled points, shape (N, C, Hg, Wg)
+    """
+    n, c, h, w = im.shape
+    gn, gh, gw, _ = grid.shape
+    assert n == gn
+
+    x = grid[:, :, :, 0]
+    y = grid[:, :, :, 1]
+    # Map normalized [-1, 1] coordinates onto pixel coordinates.
+    if align_corners:
+        x = ((x + 1) / 2) * (w - 1)
+        y = ((y + 1) / 2) * (h - 1)
+    else:
+        x = ((x + 1) * w - 1) / 2
+        y = ((y + 1) * h - 1) / 2
+    x = x.view(n, -1)
+    y = y.view(n, -1)
+
+    # Integer corners of the cell that encloses each sample point.
+    left = torch.floor(x).long()
+    top = torch.floor(y).long()
+    right = left + 1
+    bottom = top + 1
+
+    # Interpolation weight of each corner.
+    w_left_top = ((right - x) * (bottom - y)).unsqueeze(1)
+    w_left_bottom = ((right - x) * (y - top)).unsqueeze(1)
+    w_right_top = ((x - left) * (bottom - y)).unsqueeze(1)
+    w_right_bottom = ((x - left) * (y - top)).unsqueeze(1)
+
+    # Apply default for grid_sample function zero padding
+    im_padded = F.pad(im, pad=[1, 1, 1, 1], mode='constant', value=0)
+    padded_h = h + 2
+    padded_w = w + 2
+
+    def _clip(index, upper):
+        # Clip coordinates to the padded image size.
+        index = torch.where(index < 0, torch.tensor(0), index)
+        return torch.where(index > upper, torch.tensor(upper), index)
+
+    # Shift by one for the padding ring, then clip.
+    left = _clip(left + 1, padded_w - 1)
+    right = _clip(right + 1, padded_w - 1)
+    top = _clip(top + 1, padded_h - 1)
+    bottom = _clip(bottom + 1, padded_h - 1)
+
+    im_padded = im_padded.view(n, c, -1)
+
+    def _gather(col, row):
+        # Flatten (row, col) into an index of the flattened padded image.
+        flat = (col + row * padded_w).unsqueeze(1).expand(-1, c, -1)
+        return torch.gather(im_padded, 2, flat)
+
+    v_left_top = _gather(left, top)
+    v_left_bottom = _gather(left, bottom)
+    v_right_top = _gather(right, top)
+    v_right_bottom = _gather(right, bottom)
+
+    blended = (v_left_top * w_left_top + v_left_bottom * w_left_bottom +
+               v_right_top * w_right_top + v_right_bottom * w_right_bottom)
+    return blended.reshape(n, c, gh, gw)
+def is_in_onnx_export_without_custom_ops():
+    """Tell whether an ONNX export is running without the custom-op library.
+
+    Returns:
+        bool: True only when torch is currently exporting to ONNX and the
+            path reported by ``get_onnxruntime_op_path`` does not exist.
+    """
+    from annotator.uniformer.mmcv.ops import get_onnxruntime_op_path
+
+    custom_op_lib = get_onnxruntime_op_path()
+    if not torch.onnx.is_in_onnx_export():
+        return False
+    return not osp.exists(custom_op_lib)
+def normalize(grid):
+    """Map grid values from the range [-1, 1] onto [0, 1].
+
+    Args:
+        grid (Tensor): The grid to be normalize, range [-1, 1].
+
+    Returns:
+        Tensor: Normalized grid, range [0, 1].
+    """
+    shifted = grid + 1.0
+    return shifted / 2.0
+def denormalize(grid):
+    """Map grid values from the range [0, 1] onto [-1, 1].
+
+    Args:
+        grid (Tensor): The grid to be denormalize, range [0, 1].
+
+    Returns:
+        Tensor: Denormalized grid, range [-1, 1].
+    """
+    doubled = grid * 2.0
+    return doubled - 1.0
+def generate_grid(num_grid, size, device):
+    """Generate regular square grid of points in [0, 1] x [0, 1] coordinate
+    space.
+
+    Args:
+        num_grid (int): The number of grids to sample, one for each region.
+        size (tuple(int, int)): The side size of the regular grid.
+        device (torch.device): Desired device of returned tensor.
+
+    Returns:
+        (torch.Tensor): A tensor of shape (num_grid, size[0]*size[1], 2) that
+            contains coordinates for the regular grids.
+    """
+    # An identity affine transform yields the regular grid in [-1, 1].
+    identity = torch.tensor([[[1., 0., 0.], [0., 1., 0.]]], device=device)
+    out_shape = torch.Size((1, 1, *size))
+    grid = F.affine_grid(identity, out_shape, align_corners=False)
+    # Rescale from [-1, 1] to [0, 1] (the same mapping as ``normalize``).
+    grid = (grid + 1.0) / 2.0
+    # A single grid is shared (expanded, not copied) across all regions.
+    return grid.view(1, -1, 2).expand(num_grid, -1, -1)
+def rel_roi_point_to_abs_img_point(rois, rel_roi_points):
+    """Convert roi based relative point coordinates to image based absolute
+    point coordinates.
+
+    The computation runs under ``torch.no_grad()``.
+
+    Args:
+        rois (Tensor): RoIs or BBoxes, shape (N, 4) or (N, 5)
+        rel_roi_points (Tensor): Point coordinates inside RoI, relative to
+            RoI, location, range (0, 1), shape (N, P, 2)
+
+    Returns:
+        Tensor: Image based absolute point coordinates, shape (N, P, 2)
+    """
+    with torch.no_grad():
+        assert rel_roi_points.size(0) == rois.size(0)
+        assert rois.dim() == 2
+        assert rel_roi_points.dim() == 3
+        assert rel_roi_points.size(2) == 2
+        # remove batch idx
+        boxes = rois[:, 1:] if rois.size(1) == 5 else rois
+        points = rel_roi_points.clone()
+        x_min = boxes[:, None, 0]
+        y_min = boxes[:, None, 1]
+        # To avoid an error during exporting to onnx use independent
+        # variables instead inplace computation
+        xs = points[:, :, 0] * (boxes[:, None, 2] - x_min)
+        ys = points[:, :, 1] * (boxes[:, None, 3] - y_min)
+        xs += x_min
+        ys += y_min
+        abs_img_points = torch.stack([xs, ys], dim=2)
+    return abs_img_points
+def get_shape_from_feature_map(x):
+    """Get spatial resolution of input feature map considering exporting to
+    onnx mode.
+
+    Args:
+        x (torch.Tensor): Input tensor, shape (N, C, H, W)
+
+    Returns:
+        torch.Tensor: Spatial resolution (width, height), shape (1, 1, 2)
+    """
+    if torch.onnx.is_in_onnx_export():
+        # During export the shape is read through ``shape_as_tensor``.
+        spatial = shape_as_tensor(x)[2:]
+    else:
+        spatial = torch.tensor(x.shape[2:])
+    # Flip (H, W) into (W, H) and shape it for broadcasting over points.
+    return spatial.flip(0).view(1, 1, 2).to(x.device).float()
+def abs_img_point_to_rel_img_point(abs_img_points, img, spatial_scale=1.):
+    """Convert image based absolute point coordinates to image based relative
+    coordinates for sampling.
+
+    Args:
+        abs_img_points (Tensor): Image based absolute point coordinates,
+            shape (N, P, 2)
+        img (tuple/Tensor): (height, width) of image or feature map.
+        spatial_scale (float): Scale points by this factor. Default: 1.
+
+    Returns:
+        Tensor: Image based relative point coordinates for sampling,
+            shape (N, P, 2)
+    """
+    is_size_pair = isinstance(img, tuple)
+    assert (is_size_pair and len(img) == 2) or \
+           (isinstance(img, torch.Tensor) and len(img.shape) == 4)
+
+    if is_size_pair:
+        h, w = img
+        # Points are stored as (x, y), so the divisor is (width, height).
+        scale = torch.tensor([w, h],
+                             dtype=torch.float,
+                             device=abs_img_points.device).view(1, 1, 2)
+    else:
+        scale = get_shape_from_feature_map(img)
+
+    return abs_img_points / scale * spatial_scale
+def rel_roi_point_to_rel_img_point(rois,
+                                   rel_roi_points,
+                                   img,
+                                   spatial_scale=1.):
+    """Convert roi based relative point coordinates to image based relative
+    point coordinates.
+
+    The points are first made absolute with respect to their RoI and then
+    rescaled by the image (or feature map) size.
+
+    Args:
+        rois (Tensor): RoIs or BBoxes, shape (N, 4) or (N, 5)
+        rel_roi_points (Tensor): Point coordinates inside RoI, relative to
+            RoI, location, range (0, 1), shape (N, P, 2)
+        img (tuple/Tensor): (height, width) of image or feature map.
+        spatial_scale (float): Scale points by this factor. Default: 1.
+
+    Returns:
+        Tensor: Image based relative point coordinates for sampling,
+            shape (N, P, 2)
+    """
+    absolute = rel_roi_point_to_abs_img_point(rois, rel_roi_points)
+    return abs_img_point_to_rel_img_point(absolute, img, spatial_scale)
+def point_sample(input, points, align_corners=False, **kwargs):
+    """A wrapper around :func:`grid_sample` to support 3D point_coords tensors
+    Unlike :func:`torch.nn.functional.grid_sample` it assumes point_coords to
+    lie inside ``[0, 1] x [0, 1]`` square.
+
+    Args:
+        input (Tensor): Feature map, shape (N, C, H, W).
+        points (Tensor): Image based absolute point coordinates (normalized),
+            range [0, 1] x [0, 1], shape (N, P, 2) or (N, Hgrid, Wgrid, 2).
+        align_corners (bool): Whether align_corners. Default: False
+        **kwargs: Extra arguments forwarded to ``F.grid_sample``. They are
+            not used when the pure python implementation is selected.
+
+    Returns:
+        Tensor: Features of `point` on `input`, shape (N, C, P) or
+            (N, C, Hgrid, Wgrid).
+    """
+    # A 3D input gets a temporary axis so that it looks like a 4D grid.
+    was_3d = points.dim() == 3
+    if was_3d:
+        points = points.unsqueeze(2)
+
+    use_python_impl = is_in_onnx_export_without_custom_ops()
+    grid = denormalize(points)
+    if use_python_impl:
+        # If custom ops for onnx runtime not compiled use python
+        # implementation of grid_sample function to make onnx graph
+        # with supported nodes
+        output = bilinear_grid_sample(
+            input, grid, align_corners=align_corners)
+    else:
+        output = F.grid_sample(
+            input, grid, align_corners=align_corners, **kwargs)
+
+    return output.squeeze(3) if was_3d else output
+class SimpleRoIAlign(nn.Module):
+    """Simple RoI align in PointRend, faster than standard RoIAlign."""
+
+    def __init__(self, output_size, spatial_scale, aligned=True):
+        """Store the sampling configuration.
+
+        Args:
+            output_size (tuple[int]): h, w
+            spatial_scale (float): scale the input boxes by this number
+            aligned (bool): if False, use the legacy implementation in
+                MMDetection, align_corners=True will be used in F.grid_sample.
+                If True, align the results more perfectly.
+        """
+        super().__init__()
+        self.output_size = _pair(output_size)
+        self.spatial_scale = float(spatial_scale)
+        # to be consistent with other RoI ops
+        self.use_torchvision = False
+        self.aligned = aligned
+
+    def forward(self, features, rois):
+        """Sample a regular grid of point features inside every RoI.
+
+        Args:
+            features (Tensor): Feature maps; dim 0 indexes the image and
+                dim 1 the channels.
+            rois (Tensor): RoIs, one per row; column 0 is compared against
+                the image index to assign each RoI to its image.
+
+        Returns:
+            Tensor: RoI features reshaped to
+                (num_rois, channels, *output_size).
+        """
+        num_imgs = features.size(0)
+        num_rois = rois.size(0)
+        rel_roi_points = generate_grid(
+            num_rois, self.output_size, device=rois.device)
+
+        if torch.onnx.is_in_onnx_export():
+            # During export all RoIs are sampled in a single call.
+            rel_img_points = rel_roi_point_to_rel_img_point(
+                rois, rel_roi_points, features, self.spatial_scale)
+            rel_img_points = rel_img_points.reshape(num_imgs, -1,
+                                                    *rel_img_points.shape[1:])
+            point_feats = point_sample(
+                features, rel_img_points, align_corners=not self.aligned)
+            point_feats = point_feats.transpose(1, 2)
+        else:
+            point_feats = []
+            for batch_ind in range(num_imgs):
+                # unravel batch dim
+                feat = features[batch_ind].unsqueeze(0)
+                inds = (rois[:, 0].long() == batch_ind)
+                if inds.any():
+                    rel_img_points = rel_roi_point_to_rel_img_point(
+                        rois[inds], rel_roi_points[inds], feat,
+                        self.spatial_scale).unsqueeze(0)
+                    point_feat = point_sample(
+                        feat, rel_img_points, align_corners=not self.aligned)
+                    point_feat = point_feat.squeeze(0).transpose(0, 1)
+                    point_feats.append(point_feat)
+
+            point_feats = torch.cat(point_feats, dim=0)
+
+        channels = features.size(1)
+        roi_feats = point_feats.reshape(num_rois, channels, *self.output_size)
+        return roi_feats
+
+    def __repr__(self):
+        """Return ``ClassName(output_size=..., spatial_scale=...)``."""
+        # The closing parenthesis was previously missing from this string.
+        return (f'{self.__class__.__name__}'
+                f'(output_size={self.output_size}, '
+                f'spatial_scale={self.spatial_scale})')
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/points_in_boxes.py b/ControlNet/annotator/uniformer/mmcv/ops/points_in_boxes.py
new file mode 100644
index 0000000000000000000000000000000000000000..4003173a53052161dbcd687a2fa1d755642fdab8
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/points_in_boxes.py
@@ -0,0 +1,133 @@
+import torch
+from ..utils import ext_loader
+# Load the compiled kernels used by the functions in this module.
+ext_module = ext_loader.load_ext('_ext', [
+    'points_in_boxes_part_forward', 'points_in_boxes_cpu_forward',
+    'points_in_boxes_all_forward'
+])
+def points_in_boxes_part(points, boxes):
+    """Find the box in which each point is (CUDA).
+
+    Args:
+        points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR/DEPTH coordinate
+        boxes (torch.Tensor): [B, T, 7],
+            num_valid_boxes <= T, [x, y, z, x_size, y_size, z_size, rz] in
+            LiDAR/DEPTH coordinate, (x, y, z) is the bottom center
+
+    Returns:
+        box_idxs_of_pts (torch.Tensor): (B, M), default background = -1
+    """
+    points_batch, boxes_batch = points.shape[0], boxes.shape[0]
+    assert points_batch == boxes_batch, \
+        'Points and boxes should have the same batch size, ' \
+        f'but got {points_batch} and {boxes_batch}'
+    box_dim = boxes.shape[2]
+    assert box_dim == 7, \
+        'boxes dimension should be 7, ' \
+        f'but got unexpected shape {box_dim}'
+    point_dim = points.shape[2]
+    assert point_dim == 3, \
+        'points dimension should be 3, ' \
+        f'but got unexpected shape {point_dim}'
+
+    batch_size, num_points, _ = points.shape
+    # Output buffer, pre-filled with -1.
+    box_idxs_of_pts = points.new_zeros((batch_size, num_points),
+                                       dtype=torch.int).fill_(-1)
+
+    # If manually put the tensor 'points' or 'boxes' on a device
+    # which is not the current device, some temporary variables
+    # will be created on the current device in the cuda op,
+    # and the output will be incorrect.
+    # Therefore, we force the current device to be the same
+    # as the device of the tensors if it was not.
+    # Please refer to https://github.com/open-mmlab/mmdetection3d/issues/305
+    # for the incorrect output before the fix.
+    points_device = points.get_device()
+    assert points_device == boxes.get_device(), \
+        'Points and boxes should be put on the same device'
+    if torch.cuda.current_device() != points_device:
+        torch.cuda.set_device(points_device)
+
+    ext_module.points_in_boxes_part_forward(boxes.contiguous(),
+                                            points.contiguous(),
+                                            box_idxs_of_pts)
+    return box_idxs_of_pts
+def points_in_boxes_cpu(points, boxes):
+    """Find all boxes in which each point is (CPU). The CPU version of
+    :meth:`points_in_boxes_all`.
+
+    Args:
+        points (torch.Tensor): [B, M, 3], [x, y, z] in
+            LiDAR/DEPTH coordinate
+        boxes (torch.Tensor): [B, T, 7],
+            num_valid_boxes <= T, [x, y, z, x_size, y_size, z_size, rz],
+            (x, y, z) is the bottom center.
+
+    Returns:
+        box_idxs_of_pts (torch.Tensor): (B, M, T), default background = 0.
+    """
+    points_batch, boxes_batch = points.shape[0], boxes.shape[0]
+    assert points_batch == boxes_batch, \
+        'Points and boxes should have the same batch size, ' \
+        f'but got {points_batch} and {boxes_batch}'
+    box_dim = boxes.shape[2]
+    assert box_dim == 7, \
+        'boxes dimension should be 7, ' \
+        f'but got unexpected shape {box_dim}'
+    point_dim = points.shape[2]
+    assert point_dim == 3, \
+        'points dimension should be 3, ' \
+        f'but got unexpected shape {point_dim}'
+
+    batch_size, num_points, _ = points.shape
+    num_boxes = boxes.shape[1]
+    # The kernel fills a (T, M) slice per sample; the result is transposed
+    # to (B, M, T) before it is returned.
+    point_indices = points.new_zeros((batch_size, num_boxes, num_points),
+                                     dtype=torch.int)
+    for sample in range(batch_size):
+        ext_module.points_in_boxes_cpu_forward(
+            boxes[sample].float().contiguous(),
+            points[sample].float().contiguous(), point_indices[sample])
+    return point_indices.transpose(1, 2)
+def points_in_boxes_all(points, boxes):
+    """Find all boxes in which each point is (CUDA).
+
+    Args:
+        points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR/DEPTH coordinate
+        boxes (torch.Tensor): [B, T, 7],
+            num_valid_boxes <= T, [x, y, z, x_size, y_size, z_size, rz],
+            (x, y, z) is the bottom center.
+
+    Returns:
+        box_idxs_of_pts (torch.Tensor): (B, M, T), default background = 0.
+    """
+    # The message previously printed the boxes batch size twice.
+    assert boxes.shape[0] == points.shape[0], \
+        'Points and boxes should have the same batch size, ' \
+        f'but got {points.shape[0]} and {boxes.shape[0]}'
+    assert boxes.shape[2] == 7, \
+        'boxes dimension should be 7, ' \
+        f'but got unexpected shape {boxes.shape[2]}'
+    assert points.shape[2] == 3, \
+        'points dimension should be 3, ' \
+        f'but got unexpected shape {points.shape[2]}'
+    batch_size, num_points, _ = points.shape
+    num_boxes = boxes.shape[1]
+
+    box_idxs_of_pts = points.new_zeros((batch_size, num_points, num_boxes),
+                                       dtype=torch.int).fill_(0)
+
+    # Force the current CUDA device to be the device of the tensors: if it
+    # differs, the cuda op creates temporaries on the current device and
+    # the output is incorrect. See
+    # https://github.com/open-mmlab/mmdetection3d/issues/305
+    points_device = points.get_device()
+    assert points_device == boxes.get_device(), \
+        'Points and boxes should be put on the same device'
+    if torch.cuda.current_device() != points_device:
+        torch.cuda.set_device(points_device)
+
+    ext_module.points_in_boxes_all_forward(boxes.contiguous(),
+                                           points.contiguous(),
+                                           box_idxs_of_pts)
+
+    return box_idxs_of_pts
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/points_sampler.py b/ControlNet/annotator/uniformer/mmcv/ops/points_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..a802a74fd6c3610d9ae178e6201f47423eca7ad1
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/points_sampler.py
@@ -0,0 +1,177 @@
+from typing import List
+import torch
+from torch import nn as nn
+from annotator.uniformer.mmcv.runner import force_fp32
+from .furthest_point_sample import (furthest_point_sample,
+ furthest_point_sample_with_dist)
+def calc_square_dist(point_feat_a, point_feat_b, norm=True):
+    """Calculating square distance between a and b.
+    Uses the expansion ``|a - b|^2 = |a|^2 + |b|^2 - 2 * a.b`` so that all
+    pairs are handled by a single batched matmul.
+    Args:
+        point_feat_a (Tensor): (B, N, C) Feature vector of each point.
+        point_feat_b (Tensor): (B, M, C) Feature vector of each point.
+        norm (Bool, optional): Whether to normalize the distance. When True
+            the result is ``sqrt(square_dist) / C`` instead of the raw
+            squared distance. Default: True.
+    Returns:
+        Tensor: (B, N, M) Distance between each pair points.
+    """
+    num_channel = point_feat_a.shape[-1]
+    # [bs, n, 1]
+    a_square = point_feat_a.pow(2).sum(dim=-1, keepdim=True)
+    # [bs, 1, m]
+    b_square = point_feat_b.pow(2).sum(dim=-1).unsqueeze(dim=1)
+    corr_matrix = torch.matmul(point_feat_a, point_feat_b.transpose(1, 2))
+    dist = a_square + b_square - 2 * corr_matrix
+    if norm:
+        # Rounding in the expansion can leave tiny negative values (e.g. on
+        # the diagonal when a == b); clamp them so sqrt does not yield NaN.
+        dist = torch.sqrt(dist.clamp(min=0)) / num_channel
+    return dist
+def get_sampler_cls(sampler_type):
+    """Get the type and mode of points sampler.
+    Args:
+        sampler_type (str): The type of points sampler.
+            The valid value are "D-FPS", "F-FPS", or "FS".
+    Returns:
+        class: Points sampler type.
+    Raises:
+        KeyError: If ``sampler_type`` is not one of the supported names.
+    """
+    sampler_mappings = {
+        'D-FPS': DFPSSampler,
+        'F-FPS': FFPSSampler,
+        'FS': FSSampler,
+    }
+    try:
+        return sampler_mappings[sampler_type]
+    except KeyError:
+        # Adjacent f-strings are used instead of a backslash continuation
+        # inside the literal, which would embed the source indentation in
+        # the message.
+        raise KeyError(
+            f'Supported `sampler_type` are {list(sampler_mappings)}, '
+            f'but got {sampler_type}') from None
+class PointsSampler(nn.Module):
+    """Points sampling.
+    Args:
+        num_point (list[int]): Number of sample points.
+        fps_mod_list (list[str], optional): Type of FPS method, valid mod
+            ['F-FPS', 'D-FPS', 'FS'], Default: ['D-FPS'].
+            F-FPS: using feature distances for FPS.
+            D-FPS: using Euclidean distances of points for FPS.
+            FS: using F-FPS and D-FPS simultaneously.
+        fps_sample_range_list (list[int], optional):
+            Range of points to apply FPS. Each entry is the (exclusive) end
+            index of the slice handled by the matching sampler; -1 means
+            "up to the last point". Default: [-1].
+    """
+    def __init__(self,
+                 num_point: List[int],
+                 fps_mod_list: List[str] = ['D-FPS'],
+                 fps_sample_range_list: List[int] = [-1]):
+        super().__init__()
+        # FPS would be applied to different fps_mod in the list,
+        # so the length of the num_point should be equal to
+        # fps_mod_list and fps_sample_range_list.
+        assert len(num_point) == len(fps_mod_list) == len(
+            fps_sample_range_list)
+        self.num_point = num_point
+        self.fps_sample_range_list = fps_sample_range_list
+        self.samplers = nn.ModuleList()
+        for fps_mod in fps_mod_list:
+            self.samplers.append(get_sampler_cls(fps_mod)())
+        self.fp16_enabled = False
+    @force_fp32()
+    def forward(self, points_xyz, features):
+        """
+        Args:
+            points_xyz (Tensor): (B, N, 3) xyz coordinates of the features.
+            features (Tensor): (B, C, N) Descriptors of the features. May be
+                None, in which case None is forwarded to every sampler.
+        Returns:
+            Tensor: Indices of sampled points; the per-sampler index tensors
+                concatenated along dim 1.
+        """
+        indices = []
+        last_fps_end_index = 0
+        for fps_sample_range, sampler, npoint in zip(
+                self.fps_sample_range_list, self.samplers, self.num_point):
+            assert fps_sample_range < points_xyz.shape[1]
+            # -1 selects everything from the previous end to the last point.
+            end_index = None if fps_sample_range == -1 else fps_sample_range
+            sample_points_xyz = points_xyz[:, last_fps_end_index:end_index]
+            if features is not None:
+                sample_features = features[:, :,
+                                           last_fps_end_index:end_index]
+            else:
+                sample_features = None
+            fps_idx = sampler(sample_points_xyz.contiguous(), sample_features,
+                              npoint)
+            # Sampler indices are relative to the slice; shift them back to
+            # positions in the full point set.
+            indices.append(fps_idx + last_fps_end_index)
+            # fps_sample_range is an absolute end index (see the slice above),
+            # so the next slice starts exactly there. Accumulating with `+=`
+            # would skip points once three or more ranges are configured.
+            last_fps_end_index = fps_sample_range
+        indices = torch.cat(indices, dim=1)
+        return indices
+class DFPSSampler(nn.Module):
+    """Furthest point sampling driven by the point coordinates (D-FPS)."""
+    def forward(self, points, features, npoint):
+        """Return ``npoint`` indices picked by D-FPS; ``features`` is unused."""
+        return furthest_point_sample(points.contiguous(), npoint)
+class FFPSSampler(nn.Module):
+    """Furthest point sampling driven by feature distances (F-FPS)."""
+    def forward(self, points, features, npoint):
+        """Return ``npoint`` indices picked by F-FPS.
+        The coordinates and the (transposed) features are concatenated along
+        the last dimension and their pairwise squared distances drive the
+        sampling.
+        """
+        assert features is not None, \
+            'feature input to FFPS_Sampler should not be None'
+        fused = torch.cat([points, features.transpose(1, 2)], dim=2)
+        pairwise_dist = calc_square_dist(fused, fused, norm=False)
+        return furthest_point_sample_with_dist(pairwise_dist, npoint)
+class FSSampler(nn.Module):
+    """Combine F-FPS and D-FPS: both run on the same input."""
+    def forward(self, points, features, npoint):
+        """Return the F-FPS indices followed by the D-FPS indices (dim 1)."""
+        assert features is not None, \
+            'feature input to FS_Sampler should not be None'
+        by_feature = FFPSSampler()(points, features, npoint)
+        by_distance = DFPSSampler()(points, features, npoint)
+        return torch.cat([by_feature, by_distance], dim=1)
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/psa_mask.py b/ControlNet/annotator/uniformer/mmcv/ops/psa_mask.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdf14e62b50e8d4dd6856c94333c703bcc4c9ab6
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/psa_mask.py
@@ -0,0 +1,92 @@
+# Modified from https://github.com/hszhao/semseg/blob/master/lib/psa
+from torch import nn
+from torch.autograd import Function
+from torch.nn.modules.utils import _pair
+from ..utils import ext_loader
+ext_module = ext_loader.load_ext('_ext',
+ ['psamask_forward', 'psamask_backward'])
+class PSAMaskFunction(Function):
+    """Autograd function wrapping the compiled ``psamask`` forward/backward.
+    ``forward`` takes ``(input, psa_type, mask_size)`` where ``psa_type`` is
+    the integer code passed straight to the extension and ``mask_size`` is an
+    int or pair ``(h_mask, w_mask)``.
+    """
+    @staticmethod
+    def symbolic(g, input, psa_type, mask_size):
+        # ONNX export: emit the custom mmcv op with both settings as int
+        # attributes.
+        return g.op(
+            'mmcv::MMCVPSAMask',
+            input,
+            psa_type_i=psa_type,
+            mask_size_i=mask_size)
+    @staticmethod
+    def forward(ctx, input, psa_type, mask_size):
+        ctx.psa_type = psa_type
+        ctx.mask_size = _pair(mask_size)
+        ctx.save_for_backward(input)
+        h_mask, w_mask = ctx.mask_size
+        batch_size, channels, h_feature, w_feature = input.size()
+        # One input channel per mask position.
+        assert channels == h_mask * w_mask
+        # The output has one channel per feature-map position.
+        output = input.new_zeros(
+            (batch_size, h_feature * w_feature, h_feature, w_feature))
+        ext_module.psamask_forward(
+            input,
+            output,
+            psa_type=psa_type,
+            num_=batch_size,
+            h_feature=h_feature,
+            w_feature=w_feature,
+            h_mask=h_mask,
+            w_mask=w_mask,
+            half_h_mask=(h_mask - 1) // 2,
+            half_w_mask=(w_mask - 1) // 2)
+        return output
+    @staticmethod
+    def backward(ctx, grad_output):
+        input = ctx.saved_tensors[0]
+        psa_type = ctx.psa_type
+        h_mask, w_mask = ctx.mask_size
+        batch_size, channels, h_feature, w_feature = input.size()
+        grad_input = grad_output.new_zeros(
+            (batch_size, channels, h_feature, w_feature))
+        ext_module.psamask_backward(
+            grad_output,
+            grad_input,
+            psa_type=psa_type,
+            num_=batch_size,
+            h_feature=h_feature,
+            w_feature=w_feature,
+            h_mask=h_mask,
+            w_mask=w_mask,
+            half_h_mask=(h_mask - 1) // 2,
+            half_w_mask=(w_mask - 1) // 2)
+        # One slot per forward input: (input, psa_type, mask_size).
+        return grad_input, None, None
+psa_mask = PSAMaskFunction.apply
+class PSAMask(nn.Module):
+    """Module front-end for ``psa_mask``.
+    Args:
+        psa_type (str): Either 'collect' or 'distribute'; stored as the
+            integer code 0 or 1 respectively in ``psa_type_enum``.
+        mask_size (int | tuple, optional): Mask size forwarded unchanged to
+            ``psa_mask``. Default: None.
+    """
+    def __init__(self, psa_type, mask_size=None):
+        super().__init__()
+        assert psa_type in ['collect', 'distribute']
+        self.psa_type_enum = 0 if psa_type == 'collect' else 1
+        self.mask_size = mask_size
+        self.psa_type = psa_type
+    def forward(self, input):
+        return psa_mask(input, self.psa_type_enum, self.mask_size)
+    def __repr__(self):
+        return (f'{self.__class__.__name__}(psa_type={self.psa_type}, '
+                f'mask_size={self.mask_size})')
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/roi_align.py b/ControlNet/annotator/uniformer/mmcv/ops/roi_align.py
new file mode 100644
index 0000000000000000000000000000000000000000..0755aefc66e67233ceae0f4b77948301c443e9fb
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/roi_align.py
@@ -0,0 +1,223 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair
+from ..utils import deprecated_api_warning, ext_loader
+ext_module = ext_loader.load_ext('_ext',
+ ['roi_align_forward', 'roi_align_backward'])
+class RoIAlignFunction(Function):
+    """Autograd function wrapping the compiled RoI Align forward/backward.
+    ``symbolic`` provides the ONNX export, ``forward``/``backward`` call into
+    ``ext_module``.
+    """
+    @staticmethod
+    def symbolic(g, input, rois, output_size, spatial_scale, sampling_ratio,
+                 pool_mode, aligned):
+        from ..onnx import is_custom_op_loaded
+        has_custom_op = is_custom_op_loaded()
+        if has_custom_op:
+            # Custom op available: export a single mmcv node that carries
+            # every setting, including ``aligned``.
+            return g.op(
+                'mmcv::MMCVRoiAlign',
+                input,
+                rois,
+                output_height_i=output_size[0],
+                output_width_i=output_size[1],
+                spatial_scale_f=spatial_scale,
+                sampling_ratio_i=sampling_ratio,
+                mode_s=pool_mode,
+                aligned_i=aligned)
+        else:
+            # Fallback: rebuild the op from standard ONNX nodes. The rois
+            # tensor is split into a batch-index column and four box columns.
+            from torch.onnx.symbolic_opset9 import sub, squeeze
+            from torch.onnx.symbolic_helper import _slice_helper
+            from torch.onnx import TensorProtoDataType
+            # batch_indices = rois[:, 0].long()
+            batch_indices = _slice_helper(
+                g, rois, axes=[1], starts=[0], ends=[1])
+            batch_indices = squeeze(g, batch_indices, 1)
+            batch_indices = g.op(
+                'Cast', batch_indices, to_i=TensorProtoDataType.INT64)
+            # rois = rois[:, 1:]
+            rois = _slice_helper(g, rois, axes=[1], starts=[1], ends=[5])
+            if aligned:
+                # The RoiAlign node emitted below is given no ``aligned``
+                # attribute, so the shift is applied to the boxes here.
+                # rois -= 0.5/spatial_scale
+                aligned_offset = g.op(
+                    'Constant',
+                    value_t=torch.tensor([0.5 / spatial_scale],
+                                         dtype=torch.float32))
+                rois = sub(g, rois, aligned_offset)
+            # roi align
+            # sampling_ratio is clamped so a negative value is never exported.
+            return g.op(
+                'RoiAlign',
+                input,
+                rois,
+                batch_indices,
+                output_height_i=output_size[0],
+                output_width_i=output_size[1],
+                spatial_scale_f=spatial_scale,
+                sampling_ratio_i=max(0, sampling_ratio),
+                mode_s=pool_mode)
+    @staticmethod
+    def forward(ctx,
+                input,
+                rois,
+                output_size,
+                spatial_scale=1.0,
+                sampling_ratio=0,
+                pool_mode='avg',
+                aligned=True):
+        ctx.output_size = _pair(output_size)
+        ctx.spatial_scale = spatial_scale
+        ctx.sampling_ratio = sampling_ratio
+        assert pool_mode in ('max', 'avg')
+        # The extension receives the pooling mode as an int: 0 = max, 1 = avg.
+        ctx.pool_mode = 0 if pool_mode == 'max' else 1
+        ctx.aligned = aligned
+        # Kept so backward can allocate a gradient of the input's shape.
+        ctx.input_shape = input.size()
+        assert rois.size(1) == 5, 'RoI must be (idx, x1, y1, x2, y2)!'
+        output_shape = (rois.size(0), input.size(1), ctx.output_size[0],
+                        ctx.output_size[1])
+        output = input.new_zeros(output_shape)
+        # Full-size argmax buffers are only allocated for max pooling; avg
+        # pooling passes empty tensors.
+        if ctx.pool_mode == 0:
+            argmax_y = input.new_zeros(output_shape)
+            argmax_x = input.new_zeros(output_shape)
+        else:
+            argmax_y = input.new_zeros(0)
+            argmax_x = input.new_zeros(0)
+        ext_module.roi_align_forward(
+            input,
+            rois,
+            output,
+            argmax_y,
+            argmax_x,
+            aligned_height=ctx.output_size[0],
+            aligned_width=ctx.output_size[1],
+            spatial_scale=ctx.spatial_scale,
+            sampling_ratio=ctx.sampling_ratio,
+            pool_mode=ctx.pool_mode,
+            aligned=ctx.aligned)
+        ctx.save_for_backward(rois, argmax_y, argmax_x)
+        return output
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        rois, argmax_y, argmax_x = ctx.saved_tensors
+        grad_input = grad_output.new_zeros(ctx.input_shape)
+        # complex head architecture may cause grad_output uncontiguous.
+        grad_output = grad_output.contiguous()
+        ext_module.roi_align_backward(
+            grad_output,
+            rois,
+            argmax_y,
+            argmax_x,
+            grad_input,
+            aligned_height=ctx.output_size[0],
+            aligned_width=ctx.output_size[1],
+            spatial_scale=ctx.spatial_scale,
+            sampling_ratio=ctx.sampling_ratio,
+            pool_mode=ctx.pool_mode,
+            aligned=ctx.aligned)
+        # Only ``input`` receives a gradient; the other six forward arguments
+        # get None.
+        return grad_input, None, None, None, None, None, None
+roi_align = RoIAlignFunction.apply
+class RoIAlign(nn.Module):
+    """RoI align pooling layer.
+    Args:
+        output_size (tuple): h, w
+        spatial_scale (float): scale the input boxes by this number
+        sampling_ratio (int): number of inputs samples to take for each
+            output sample. 0 to take samples densely for current models.
+        pool_mode (str, 'avg' or 'max'): pooling mode in each bin. Not
+            forwarded to torchvision when ``use_torchvision`` is True.
+        aligned (bool): if False, use the legacy implementation in
+            MMDetection. If True, align the results more perfectly.
+        use_torchvision (bool): whether to use roi_align from torchvision.
+    Note:
+        The implementation of RoIAlign when aligned=True is modified from
+        https://github.com/facebookresearch/detectron2/
+        The meaning of aligned=True:
+        Given a continuous coordinate c, its two neighboring pixel
+        indices (in our pixel model) are computed by floor(c - 0.5) and
+        ceil(c - 0.5). For example, c=1.3 has pixel neighbors with discrete
+        indices [0] and [1] (which are sampled from the underlying signal
+        at continuous coordinates 0.5 and 1.5). But the original roi_align
+        (aligned=False) does not subtract the 0.5 when computing
+        neighboring pixel indices and therefore it uses pixels with a
+        slightly incorrect alignment (relative to our pixel model) when
+        performing bilinear interpolation.
+        With `aligned=True`,
+        we first appropriately scale the ROI and then shift it by -0.5
+        prior to calling roi_align. This produces the correct neighbors;
+        The difference does not make a difference to the model's
+        performance if ROIAlign is used together with conv layers.
+    """
+    @deprecated_api_warning(
+        {
+            'out_size': 'output_size',
+            'sample_num': 'sampling_ratio'
+        },
+        cls_name='RoIAlign')
+    def __init__(self,
+                 output_size,
+                 spatial_scale=1.0,
+                 sampling_ratio=0,
+                 pool_mode='avg',
+                 aligned=True,
+                 use_torchvision=False):
+        super(RoIAlign, self).__init__()
+        self.output_size = _pair(output_size)
+        self.spatial_scale = float(spatial_scale)
+        self.sampling_ratio = int(sampling_ratio)
+        self.pool_mode = pool_mode
+        self.aligned = aligned
+        self.use_torchvision = use_torchvision
+    def forward(self, input, rois):
+        """
+        Args:
+            input: NCHW images
+            rois: Bx5 boxes. First column is the index into N.
+                The other 4 columns are xyxy. Never modified in place.
+        Returns:
+            The pooled features produced by the selected backend.
+        """
+        if self.use_torchvision:
+            from torchvision.ops import roi_align as tv_roi_align
+            if 'aligned' in tv_roi_align.__code__.co_varnames:
+                return tv_roi_align(input, rois, self.output_size,
+                                    self.spatial_scale, self.sampling_ratio,
+                                    self.aligned)
+            else:
+                # This torchvision has no ``aligned`` argument, so emulate it
+                # by shifting the four box coordinates (not the batch index).
+                if self.aligned:
+                    # Out-of-place on purpose: ``rois -= ...`` would silently
+                    # change the caller's tensor.
+                    rois = rois - rois.new_tensor(
+                        [0.] + [0.5 / self.spatial_scale] * 4)
+                return tv_roi_align(input, rois, self.output_size,
+                                    self.spatial_scale, self.sampling_ratio)
+        else:
+            return roi_align(input, rois, self.output_size, self.spatial_scale,
+                             self.sampling_ratio, self.pool_mode, self.aligned)
+    def __repr__(self):
+        s = self.__class__.__name__
+        s += f'(output_size={self.output_size}, '
+        s += f'spatial_scale={self.spatial_scale}, '
+        s += f'sampling_ratio={self.sampling_ratio}, '
+        s += f'pool_mode={self.pool_mode}, '
+        s += f'aligned={self.aligned}, '
+        s += f'use_torchvision={self.use_torchvision})'
+        return s
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/roi_align_rotated.py b/ControlNet/annotator/uniformer/mmcv/ops/roi_align_rotated.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ce4961a3555d4da8bc3e32f1f7d5ad50036587d
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/roi_align_rotated.py
@@ -0,0 +1,177 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+from torch.autograd import Function
+from ..utils import ext_loader
+ext_module = ext_loader.load_ext(
+ '_ext', ['roi_align_rotated_forward', 'roi_align_rotated_backward'])
+class RoIAlignRotatedFunction(Function):
+    """Autograd function wrapping the compiled rotated RoI Align op."""
+    @staticmethod
+    def _parse_out_size(out_size):
+        """Normalise ``out_size`` (int or 2-tuple of ints) to ``(h, w)``."""
+        if isinstance(out_size, int):
+            return out_size, out_size
+        if isinstance(out_size, tuple):
+            assert len(out_size) == 2
+            assert isinstance(out_size[0], int)
+            assert isinstance(out_size[1], int)
+            return out_size
+        raise TypeError(
+            '"out_size" must be an integer or tuple of integers')
+    @staticmethod
+    def symbolic(g, features, rois, out_size, spatial_scale, sample_num,
+                 aligned, clockwise):
+        out_h, out_w = RoIAlignRotatedFunction._parse_out_size(out_size)
+        # The width attribute must carry out_w; passing out_h for both
+        # exports a wrong graph for non-square outputs.
+        return g.op(
+            'mmcv::MMCVRoIAlignRotated',
+            features,
+            rois,
+            output_height_i=out_h,
+            output_width_i=out_w,
+            spatial_scale_f=spatial_scale,
+            sampling_ratio_i=sample_num,
+            aligned_i=aligned,
+            clockwise_i=clockwise)
+    @staticmethod
+    def forward(ctx,
+                features,
+                rois,
+                out_size,
+                spatial_scale,
+                sample_num=0,
+                aligned=True,
+                clockwise=False):
+        out_h, out_w = RoIAlignRotatedFunction._parse_out_size(out_size)
+        ctx.spatial_scale = spatial_scale
+        ctx.sample_num = sample_num
+        ctx.aligned = aligned
+        ctx.clockwise = clockwise
+        ctx.save_for_backward(rois)
+        # Kept so backward can allocate a gradient of the input's shape.
+        ctx.feature_size = features.size()
+        batch_size, num_channels, data_height, data_width = features.size()
+        num_rois = rois.size(0)
+        output = features.new_zeros(num_rois, num_channels, out_h, out_w)
+        ext_module.roi_align_rotated_forward(
+            features,
+            rois,
+            output,
+            pooled_height=out_h,
+            pooled_width=out_w,
+            spatial_scale=spatial_scale,
+            sample_num=sample_num,
+            aligned=aligned,
+            clockwise=clockwise)
+        return output
+    @staticmethod
+    def backward(ctx, grad_output):
+        feature_size = ctx.feature_size
+        spatial_scale = ctx.spatial_scale
+        aligned = ctx.aligned
+        clockwise = ctx.clockwise
+        sample_num = ctx.sample_num
+        rois = ctx.saved_tensors[0]
+        assert feature_size is not None
+        batch_size, num_channels, data_height, data_width = feature_size
+        out_w = grad_output.size(3)
+        out_h = grad_output.size(2)
+        grad_input = grad_rois = None
+        if ctx.needs_input_grad[0]:
+            # Allocate from grad_output (as the other RoI ops in this package
+            # do) so dtype/device follow the features, not the rois.
+            grad_input = grad_output.new_zeros(batch_size, num_channels,
+                                               data_height, data_width)
+            ext_module.roi_align_rotated_backward(
+                grad_output.contiguous(),
+                rois,
+                grad_input,
+                pooled_height=out_h,
+                pooled_width=out_w,
+                spatial_scale=spatial_scale,
+                sample_num=sample_num,
+                aligned=aligned,
+                clockwise=clockwise)
+        # rois and the five settings receive no gradient.
+        return grad_input, grad_rois, None, None, None, None, None
+roi_align_rotated = RoIAlignRotatedFunction.apply
+class RoIAlignRotated(nn.Module):
+    """RoI align pooling layer for rotated proposals.
+    Takes a feature map of shape (N, C, H, W) together with rois of shape
+    (n, 6), each roi being (batch_index, center_x, center_y, w, h, angle)
+    with the angle given in radian.
+    Args:
+        out_size (tuple): h, w
+        spatial_scale (float): scale the input boxes by this number
+        sample_num (int): number of inputs samples to take for each
+            output sample. 0 to take samples densely for current models.
+        aligned (bool): if False, use the legacy implementation in
+            MMDetection. If True, align the results more perfectly.
+            Default: True.
+        clockwise (bool): If True, the angle in each proposal follows a
+            clockwise fashion in image space, otherwise, the angle is
+            counterclockwise. Default: False.
+    Note:
+        The implementation of RoIAlign when aligned=True is modified from
+        https://github.com/facebookresearch/detectron2/
+        Meaning of aligned=True: for a continuous coordinate c, the two
+        neighboring pixel indices (in our pixel model) are floor(c - 0.5)
+        and ceil(c - 0.5). E.g. c=1.3 has neighbors [0] and [1], which are
+        sampled from the underlying signal at continuous coordinates 0.5
+        and 1.5. The original roi_align (aligned=False) does not subtract
+        the 0.5 and so uses pixels with a slightly incorrect alignment
+        (relative to our pixel model) during bilinear interpolation.
+        With `aligned=True` the ROI is first scaled appropriately and then
+        shifted by -0.5 prior to calling roi_align, which yields the
+        correct neighbors. The difference does not make a difference to the
+        model's performance if ROIAlign is used together with conv layers.
+    """
+    def __init__(self,
+                 out_size,
+                 spatial_scale,
+                 sample_num=0,
+                 aligned=True,
+                 clockwise=False):
+        super().__init__()
+        # Numeric settings are coerced once here; the rest are stored as is.
+        self.out_size = out_size
+        self.spatial_scale = float(spatial_scale)
+        self.sample_num = int(sample_num)
+        self.aligned = aligned
+        self.clockwise = clockwise
+    def forward(self, features, rois):
+        settings = (self.out_size, self.spatial_scale, self.sample_num,
+                    self.aligned, self.clockwise)
+        return RoIAlignRotatedFunction.apply(features, rois, *settings)
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/roi_pool.py b/ControlNet/annotator/uniformer/mmcv/ops/roi_pool.py
new file mode 100644
index 0000000000000000000000000000000000000000..d339d8f2941eabc1cbe181a9c6c5ab5ff4ff4e5f
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/roi_pool.py
@@ -0,0 +1,86 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair
+from ..utils import ext_loader
+ext_module = ext_loader.load_ext('_ext',
+ ['roi_pool_forward', 'roi_pool_backward'])
+class RoIPoolFunction(Function):
+    """Autograd function wrapping the compiled RoI max-pooling op."""
+    @staticmethod
+    def symbolic(g, input, rois, output_size, spatial_scale):
+        # ONNX export: emit a MaxRoiPool node.
+        onnx_attrs = dict(
+            pooled_shape_i=output_size, spatial_scale_f=spatial_scale)
+        return g.op('MaxRoiPool', input, rois, **onnx_attrs)
+    @staticmethod
+    def forward(ctx, input, rois, output_size, spatial_scale=1.0):
+        ctx.output_size = _pair(output_size)
+        ctx.spatial_scale = spatial_scale
+        # Kept so backward can allocate a gradient of the input's shape.
+        ctx.input_shape = input.size()
+        assert rois.size(1) == 5, 'RoI must be (idx, x1, y1, x2, y2)!'
+        pooled_h = ctx.output_size[0]
+        pooled_w = ctx.output_size[1]
+        pooled_shape = (rois.size(0), input.size(1), pooled_h, pooled_w)
+        pooled = input.new_zeros(pooled_shape)
+        # Integer buffer filled by the extension and replayed in backward.
+        max_idx = input.new_zeros(pooled_shape, dtype=torch.int)
+        ext_module.roi_pool_forward(
+            input,
+            rois,
+            pooled,
+            max_idx,
+            pooled_height=pooled_h,
+            pooled_width=pooled_w,
+            spatial_scale=ctx.spatial_scale)
+        ctx.save_for_backward(rois, max_idx)
+        return pooled
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        saved_rois, max_idx = ctx.saved_tensors
+        grad_wrt_input = grad_output.new_zeros(ctx.input_shape)
+        ext_module.roi_pool_backward(
+            grad_output,
+            saved_rois,
+            max_idx,
+            grad_wrt_input,
+            pooled_height=ctx.output_size[0],
+            pooled_width=ctx.output_size[1],
+            spatial_scale=ctx.spatial_scale)
+        # Only ``input`` is differentiable.
+        return grad_wrt_input, None, None, None
+roi_pool = RoIPoolFunction.apply
+class RoIPool(nn.Module):
+    """Module front-end for ``roi_pool``.
+    Args:
+        output_size (int | tuple): Pooled output size; normalised to a pair.
+        spatial_scale (float): Scale forwarded to ``roi_pool``; stored as a
+            float. Default: 1.0.
+    """
+    def __init__(self, output_size, spatial_scale=1.0):
+        super().__init__()
+        self.output_size = _pair(output_size)
+        self.spatial_scale = float(spatial_scale)
+    def forward(self, input, rois):
+        return roi_pool(input, rois, self.output_size, self.spatial_scale)
+    def __repr__(self):
+        return (f'{self.__class__.__name__}'
+                f'(output_size={self.output_size}, '
+                f'spatial_scale={self.spatial_scale})')
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/roiaware_pool3d.py b/ControlNet/annotator/uniformer/mmcv/ops/roiaware_pool3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..291b0e5a9b692492c7d7e495ea639c46042e2f18
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/roiaware_pool3d.py
@@ -0,0 +1,114 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch import nn as nn
+from torch.autograd import Function
+import annotator.uniformer.mmcv as mmcv
+from ..utils import ext_loader
+ext_module = ext_loader.load_ext(
+ '_ext', ['roiaware_pool3d_forward', 'roiaware_pool3d_backward'])
+class RoIAwarePool3d(nn.Module):
+    """Encode the geometry-specific features of each 3D proposal.
+    Please refer to the PartA2 paper for more details.
+    Args:
+        out_size (int or tuple): The size of output features. n or
+            [n1, n2, n3].
+        max_pts_per_voxel (int, optional): The maximum number of points per
+            voxel. Default: 128.
+        mode (str, optional): Pooling method of RoIAware, 'max' or 'avg'.
+            Stored as the integer code 0 ('max') or 1 ('avg').
+            Default: 'max'.
+    """
+    def __init__(self, out_size, max_pts_per_voxel=128, mode='max'):
+        super().__init__()
+        self.out_size = out_size
+        self.max_pts_per_voxel = max_pts_per_voxel
+        assert mode in ['max', 'avg']
+        # The autograd function expects the pooling mode as an int.
+        self.mode = {'max': 0, 'avg': 1}[mode]
+    def forward(self, rois, pts, pts_feature):
+        """
+        Args:
+            rois (torch.Tensor): [N, 7], in LiDAR coordinate,
+                (x, y, z) is the bottom center of rois.
+            pts (torch.Tensor): [npoints, 3], coordinates of input points.
+            pts_feature (torch.Tensor): [npoints, C], features of input points.
+        Returns:
+            pooled_features (torch.Tensor): [N, out_x, out_y, out_z, C]
+        """
+        pool_settings = (self.out_size, self.max_pts_per_voxel, self.mode)
+        return RoIAwarePool3dFunction.apply(rois, pts, pts_feature,
+                                            *pool_settings)
+class RoIAwarePool3dFunction(Function):
+    """Autograd function wrapping the compiled RoI-aware 3D pooling op."""
+    @staticmethod
+    def forward(ctx, rois, pts, pts_feature, out_size, max_pts_per_voxel,
+                mode):
+        """
+        Args:
+            rois (torch.Tensor): [N, 7], in LiDAR coordinate,
+                (x, y, z) is the bottom center of rois.
+            pts (torch.Tensor): [npoints, 3], coordinates of input points.
+            pts_feature (torch.Tensor): [npoints, C], features of input points.
+            out_size (int or tuple): The size of output features. n or
+                [n1, n2, n3].
+            max_pts_per_voxel (int): The maximum number of points per voxel.
+                Default: 128.
+            mode (int): Pooling method of RoIAware, 0 (max pool) or 1 (average
+                pool).
+        Returns:
+            pooled_features (torch.Tensor): [N, out_x, out_y, out_z, C], output
+                pooled features.
+        """
+        if not isinstance(out_size, int):
+            assert len(out_size) == 3
+            assert mmcv.is_tuple_of(out_size, int)
+            out_x, out_y, out_z = out_size
+        else:
+            out_x = out_y = out_z = out_size
+        num_rois = rois.shape[0]
+        num_channels = pts_feature.shape[-1]
+        num_pts = pts.shape[0]
+        # Every buffer shares the per-roi voxel grid and differs only in its
+        # last dimension and dtype.
+        voxel_grid = (num_rois, out_x, out_y, out_z)
+        pooled_features = pts_feature.new_zeros(voxel_grid + (num_channels, ))
+        argmax = pts_feature.new_zeros(
+            voxel_grid + (num_channels, ), dtype=torch.int)
+        pts_idx_of_voxels = pts_feature.new_zeros(
+            voxel_grid + (max_pts_per_voxel, ), dtype=torch.int)
+        ext_module.roiaware_pool3d_forward(rois, pts, pts_feature, argmax,
+                                           pts_idx_of_voxels, pooled_features,
+                                           mode)
+        # Stash what backward needs directly on ctx.
+        ctx.roiaware_pool3d_for_backward = (pts_idx_of_voxels, argmax, mode,
+                                            num_pts, num_channels)
+        return pooled_features
+    @staticmethod
+    def backward(ctx, grad_out):
+        (pts_idx_of_voxels, argmax, mode, num_pts,
+         num_channels) = ctx.roiaware_pool3d_for_backward
+        grad_in = grad_out.new_zeros((num_pts, num_channels))
+        ext_module.roiaware_pool3d_backward(pts_idx_of_voxels, argmax,
+                                            grad_out.contiguous(), grad_in,
+                                            mode)
+        # Only pts_feature (third forward argument) receives a gradient.
+        return None, None, grad_in, None, None, None
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/roipoint_pool3d.py b/ControlNet/annotator/uniformer/mmcv/ops/roipoint_pool3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a21412c0728431c04b84245bc2e3109eea9aefc
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/roipoint_pool3d.py
@@ -0,0 +1,77 @@
+from torch import nn as nn
+from torch.autograd import Function
+from ..utils import ext_loader
+ext_module = ext_loader.load_ext('_ext', ['roipoint_pool3d_forward'])
+class RoIPointPool3d(nn.Module):
+    """Encode the geometry-specific features of each 3D proposal.
+    Please refer to the PartA2 paper for more details.
+    Args:
+        num_sampled_points (int, optional): Number of samples in each roi.
+            Default: 512.
+    """
+    def __init__(self, num_sampled_points=512):
+        super().__init__()
+        self.num_sampled_points = num_sampled_points
+    def forward(self, points, point_features, boxes3d):
+        """
+        Args:
+            points (torch.Tensor): Input points whose shape is (B, N, 3).
+            point_features (torch.Tensor): Features of input points whose shape
+                is (B, N, C).
+            boxes3d (torch.Tensor): Input bounding boxes whose shape is
+                (B, M, 7).
+        Returns:
+            pooled_features (torch.Tensor): The output pooled features whose
+                shape is (B, M, num_sampled_points, 3 + C).
+            pooled_empty_flag (torch.Tensor): Empty flag whose shape is (B, M).
+        """
+        pool = RoIPointPool3dFunction.apply
+        return pool(points, point_features, boxes3d, self.num_sampled_points)
+class RoIPointPool3dFunction(Function):
+    """Autograd function wrapping the compiled RoI point pooling op.
+    Only the forward pass exists; ``backward`` raises NotImplementedError.
+    """
+    @staticmethod
+    def forward(ctx, points, point_features, boxes3d, num_sampled_points=512):
+        """
+        Args:
+            points (torch.Tensor): Input points whose shape is (B, N, 3).
+            point_features (torch.Tensor): Features of input points whose shape
+                is (B, N, C).
+            boxes3d (torch.Tensor): Input bounding boxes whose shape is
+                (B, M, 7).
+            num_sampled_points (int, optional): The num of sampled points.
+                Default: 512.
+        Returns:
+            pooled_features (torch.Tensor): The output pooled features whose
+                shape is (B, M, num_sampled_points, 3 + C).
+            pooled_empty_flag (torch.Tensor): Empty flag whose shape is (B, M).
+        """
+        assert len(points.shape) == 3 and points.shape[2] == 3
+        batch_size = points.shape[0]
+        boxes_num = boxes3d.shape[1]
+        feature_len = point_features.shape[2]
+        pooled_boxes3d = boxes3d.view(batch_size, -1, 7)
+        # Output buffers filled in place by the extension.
+        feature_shape = (batch_size, boxes_num, num_sampled_points,
+                         3 + feature_len)
+        pooled_features = point_features.new_zeros(feature_shape)
+        pooled_empty_flag = point_features.new_zeros(
+            (batch_size, boxes_num)).int()
+        ext_module.roipoint_pool3d_forward(points.contiguous(),
+                                           pooled_boxes3d.contiguous(),
+                                           point_features.contiguous(),
+                                           pooled_features, pooled_empty_flag)
+        return pooled_features, pooled_empty_flag
+    @staticmethod
+    def backward(ctx, grad_out):
+        raise NotImplementedError
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/saconv.py b/ControlNet/annotator/uniformer/mmcv/ops/saconv.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4ee3978e097fca422805db4e31ae481006d7971
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/saconv.py
@@ -0,0 +1,145 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from annotator.uniformer.mmcv.cnn import CONV_LAYERS, ConvAWS2d, constant_init
+from annotator.uniformer.mmcv.ops.deform_conv import deform_conv2d
+from annotator.uniformer.mmcv.utils import TORCH_VERSION, digit_version
+class SAConv2d(ConvAWS2d):
+    """SAC (Switchable Atrous Convolution)
+    This is an implementation of SAC in DetectoRS
+    (https://arxiv.org/pdf/2006.02334.pdf).
+    Args:
+        in_channels (int): Number of channels in the input image
+        out_channels (int): Number of channels produced by the convolution
+        kernel_size (int or tuple): Size of the convolving kernel
+        stride (int or tuple, optional): Stride of the convolution. Default: 1
+        padding (int or tuple, optional): Zero-padding added to both sides of
+            the input. Default: 0
+        dilation (int or tuple, optional): Spacing between kernel elements.
+            Default: 1
+        groups (int, optional): Number of blocked connections from input
+            channels to output channels. Default: 1
+        bias (bool, optional): If ``True``, adds a learnable bias to the
+            output. Default: ``True``
+        use_deform (bool, optional): If ``True``, replace convolution with
+            deformable convolution. Default: ``False``.
+    """
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 groups=1,
+                 bias=True,
+                 use_deform=False):
+        super().__init__(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+            bias=bias)
+        self.use_deform = use_deform
+        # 1x1 conv producing the per-position switch between the two branches.
+        self.switch = nn.Conv2d(
+            self.in_channels, 1, kernel_size=1, stride=stride, bias=True)
+        # Weight offset added for the large-dilation branch; allocated
+        # uninitialised here and zeroed in init_weights().
+        self.weight_diff = nn.Parameter(torch.Tensor(self.weight.size()))
+        self.pre_context = nn.Conv2d(
+            self.in_channels, self.in_channels, kernel_size=1, bias=True)
+        self.post_context = nn.Conv2d(
+            self.out_channels, self.out_channels, kernel_size=1, bias=True)
+        if self.use_deform:
+            self.offset_s = nn.Conv2d(
+                self.in_channels,
+                18,
+                kernel_size=3,
+                padding=1,
+                stride=stride,
+                bias=True)
+            self.offset_l = nn.Conv2d(
+                self.in_channels,
+                18,
+                kernel_size=3,
+                padding=1,
+                stride=stride,
+                bias=True)
+        self.init_weights()
+    def init_weights(self):
+        """Initialise the SAC-specific layers to constants."""
+        constant_init(self.switch, 0, bias=1)
+        self.weight_diff.data.zero_()
+        constant_init(self.pre_context, 0)
+        constant_init(self.post_context, 0)
+        if self.use_deform:
+            constant_init(self.offset_s, 0)
+            constant_init(self.offset_l, 0)
+    def _conv_without_bias(self, x, weight, zero_bias):
+        """Run the parent convolution with ``weight``, picking the call that
+        matches the installed torch version."""
+        if (TORCH_VERSION == 'parrots'
+                or digit_version(TORCH_VERSION) < digit_version('1.5.0')):
+            return super().conv2d_forward(x, weight)
+        if digit_version(TORCH_VERSION) >= digit_version('1.8.0'):
+            # bias is a required argument of _conv_forward in torch 1.8.0
+            return super()._conv_forward(x, weight, zero_bias)
+        return super()._conv_forward(x, weight)
+    def forward(self, x):
+        # pre-context
+        avg_x = F.adaptive_avg_pool2d(x, output_size=1)
+        avg_x = self.pre_context(avg_x)
+        avg_x = avg_x.expand_as(x)
+        x = x + avg_x
+        # switch
+        avg_x = F.pad(x, pad=(2, 2, 2, 2), mode='reflect')
+        avg_x = F.avg_pool2d(avg_x, kernel_size=5, stride=1, padding=0)
+        switch = self.switch(avg_x)
+        # sac
+        weight = self._get_weight(self.weight)
+        zero_bias = torch.zeros(
+            self.out_channels, device=weight.device, dtype=weight.dtype)
+        if self.use_deform:
+            offset = self.offset_s(avg_x)
+            out_s = deform_conv2d(x, offset, weight, self.stride, self.padding,
+                                  self.dilation, self.groups, 1)
+        else:
+            out_s = self._conv_without_bias(x, weight, zero_bias)
+        # The second branch reuses the same conv with padding and dilation
+        # tripled. The parent conv reads them from the module, so they are
+        # swapped temporarily and always restored, even if the conv raises.
+        ori_p = self.padding
+        ori_d = self.dilation
+        self.padding = tuple(3 * p for p in self.padding)
+        self.dilation = tuple(3 * d for d in self.dilation)
+        try:
+            weight = weight + self.weight_diff
+            if self.use_deform:
+                offset = self.offset_l(avg_x)
+                out_l = deform_conv2d(x, offset, weight, self.stride,
+                                      self.padding, self.dilation, self.groups,
+                                      1)
+            else:
+                out_l = self._conv_without_bias(x, weight, zero_bias)
+        finally:
+            self.padding = ori_p
+            self.dilation = ori_d
+        out = switch * out_s + (1 - switch) * out_l
+        # post-context
+        avg_x = F.adaptive_avg_pool2d(out, output_size=1)
+        avg_x = self.post_context(avg_x)
+        avg_x = avg_x.expand_as(out)
+        out = out + avg_x
+        return out
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/scatter_points.py b/ControlNet/annotator/uniformer/mmcv/ops/scatter_points.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b8aa4169e9f6ca4a6f845ce17d6d1e4db416bb8
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/scatter_points.py
@@ -0,0 +1,135 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch import nn
+from torch.autograd import Function
+from ..utils import ext_loader
+# Bind the two dynamic point-to-voxel ops this file calls from the `_ext`
+# extension (presumably the compiled mmcv extension; confirm in ext_loader).
+ext_module = ext_loader.load_ext(
+ '_ext',
+ ['dynamic_point_to_voxel_forward', 'dynamic_point_to_voxel_backward'])
+class _DynamicScatter(Function):
+    """Autograd wrapper around the dynamic point-to-voxel extension ops."""
+
+    @staticmethod
+    def forward(ctx, feats, coors, reduce_type='max'):
+        """Reduce the features of points that share a voxel coordinate.
+
+        Args:
+            feats (torch.Tensor): [N, C]. Point features to be reduced into
+                voxels.
+            coors (torch.Tensor): [N, ndim]. Voxel coordinate (multi-dim
+                voxel index) of each point.
+            reduce_type (str, optional): Reduce op, one of 'max', 'sum' or
+                'mean'. Default: 'max'.
+
+        Returns:
+            tuple: ``(voxel_feats, voxel_coors)`` where ``voxel_feats`` is
+            [M, C] (one row per distinct voxel coordinate) and
+            ``voxel_coors`` is [M, ndim]. ``voxel_coors`` is marked
+            non-differentiable.
+        """
+        (voxel_feats, voxel_coors, point2voxel_map,
+         voxel_points_count) = ext_module.dynamic_point_to_voxel_forward(
+             feats, coors, reduce_type)
+        # Everything the backward op needs is stashed on ctx here.
+        ctx.reduce_type = reduce_type
+        ctx.save_for_backward(feats, voxel_feats, point2voxel_map,
+                              voxel_points_count)
+        ctx.mark_non_differentiable(voxel_coors)
+        return voxel_feats, voxel_coors
+
+    @staticmethod
+    def backward(ctx, grad_voxel_feats, grad_voxel_coors=None):
+        """Return the gradient w.r.t. ``feats``; other inputs get ``None``."""
+        (feats, voxel_feats, point2voxel_map,
+         voxel_points_count) = ctx.saved_tensors
+        # The extension op writes into this zero-initialised buffer.
+        feats_grad = torch.zeros_like(feats)
+        # TODO: whether to use index put or use cuda_backward
+        # To use index put, need point to voxel index
+        ext_module.dynamic_point_to_voxel_backward(
+            feats_grad,
+            grad_voxel_feats.contiguous(),
+            feats,
+            voxel_feats,
+            point2voxel_map,
+            voxel_points_count,
+            ctx.reduce_type,
+        )
+        return feats_grad, None, None
+
+
+dynamic_scatter = _DynamicScatter.apply
+class DynamicScatter(nn.Module):
+    """Scatter points into voxels for a dynamic-voxelization voxel encoder.
+
+    Note:
+        The CPU and GPU implementation get the same output, but have
+        numerical difference after summation and division (e.g., 5e-7).
+
+    Args:
+        voxel_size (list): list [x, y, z] size of three dimension.
+        point_cloud_range (list): The coordinate range of points, [x_min,
+            y_min, z_min, x_max, y_max, z_max].
+        average_points (bool): If True points in a voxel are reduced with
+            'mean', otherwise with 'max'.
+    """
+
+    def __init__(self, voxel_size, point_cloud_range, average_points: bool):
+        super().__init__()
+        self.voxel_size = voxel_size
+        self.point_cloud_range = point_cloud_range
+        self.average_points = average_points
+
+    def forward_single(self, points, coors):
+        """Scatter one sample's points into voxels.
+
+        Args:
+            points (torch.Tensor): Points to be reduced into voxels.
+            coors (torch.Tensor): Voxel coordinate (multi-dim voxel index)
+                of each point.
+
+        Returns:
+            tuple: ``(voxel_feats, voxel_coors)`` as returned by
+            ``dynamic_scatter``.
+        """
+        if self.average_points:
+            reduce_type = 'mean'
+        else:
+            reduce_type = 'max'
+        return dynamic_scatter(points.contiguous(), coors.contiguous(),
+                               reduce_type)
+
+    def forward(self, points, coors):
+        """Scatter points/features into voxels, sample by sample if batched.
+
+        When the last dimension of ``coors`` is 3 the input is handled as a
+        single sample. Otherwise column 0 of ``coors`` is used as the batch
+        index; the batch size is read from the last row, so rows are
+        presumably ordered by batch index (verify against the caller).
+
+        Args:
+            points (torch.Tensor): Points to be reduced into voxels.
+            coors (torch.Tensor): Voxel coordinate of each point, optionally
+                prefixed with a batch index column.
+
+        Returns:
+            tuple: ``(voxel_feats, voxel_coors)``. In the batched case the
+            per-sample results are concatenated along dim 0 and the batch
+            index is put back as the first coordinate column.
+        """
+        if coors.size(-1) == 3:
+            return self.forward_single(points, coors)
+
+        batch_size = coors[-1, 0] + 1
+        feats_per_sample, coors_per_sample = [], []
+        for batch_idx in range(batch_size):
+            selected = torch.where(coors[:, 0] == batch_idx)
+            sample_feats, sample_coors = self.forward_single(
+                points[selected], coors[selected][:, 1:])
+            # Prepend the batch index as a new leading coordinate column.
+            coors_per_sample.append(
+                nn.functional.pad(
+                    sample_coors, (1, 0), mode='constant', value=batch_idx))
+            feats_per_sample.append(sample_feats)
+        return (torch.cat(feats_per_sample, dim=0),
+                torch.cat(coors_per_sample, dim=0))
+
+    def __repr__(self):
+        return (f'{self.__class__.__name__}('
+                f'voxel_size={self.voxel_size!s}'
+                f', point_cloud_range={self.point_cloud_range!s}'
+                f', average_points={self.average_points!s}'
+                ')')
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/sync_bn.py b/ControlNet/annotator/uniformer/mmcv/ops/sync_bn.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9b016fcbe860989c56cd1040034bcfa60e146d2
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/sync_bn.py
@@ -0,0 +1,279 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.module import Module
+from torch.nn.parameter import Parameter
+from annotator.uniformer.mmcv.cnn import NORM_LAYERS
+from ..utils import ext_loader
+# Bind the sync-BN ops this file calls from the `_ext` extension. The list
+# and the call must be closed here; otherwise the statement runs on into the
+# class definition that follows.
+ext_module = ext_loader.load_ext('_ext', [
+    'sync_bn_forward_mean', 'sync_bn_forward_var', 'sync_bn_forward_output',
+    'sync_bn_backward_param', 'sync_bn_backward_data'
+])
+# Autograd Function for synchronized batch norm. The per-rank statistics are
+# computed by the `sync_bn_*` ops of ext_module and combined across the
+# process group with dist.all_reduce.
+class SyncBatchNormFunction(Function):
+ # Export hook: represents the whole op as one `mmcv::MMCVSyncBatchNorm`
+ # node on graph `g`, passing the scalar settings as node attributes.
+ @staticmethod
+ def symbolic(g, input, running_mean, running_var, weight, bias, momentum,
+ eps, group, group_size, stats_mode):
+ return g.op(
+ 'mmcv::MMCVSyncBatchNorm',
+ input,
+ running_mean,
+ running_var,
+ weight,
+ bias,
+ momentum_f=momentum,
+ eps_f=eps,
+ group_i=group,
+ group_size_i=group_size,
+ stats_mode=stats_mode)
+ # NOTE: this is a staticmethod, so `self` below is the autograd context
+ # object, not a SyncBatchNormFunction instance.
+ @staticmethod
+ def forward(self, input, running_mean, running_var, weight, bias, momentum,
+ eps, group, group_size, stats_mode):
+ # Keep the settings on the context; backward reads group/group_size.
+ self.momentum = momentum
+ self.eps = eps
+ self.group = group
+ self.group_size = group_size
+ self.stats_mode = stats_mode
+ assert isinstance(
+ input, (torch.HalfTensor, torch.FloatTensor,
+ torch.cuda.HalfTensor, torch.cuda.FloatTensor)), \
+ f'only support Half or Float Tensor, but {input.type()}'
+ output = torch.zeros_like(input)
+ # Flatten everything after the channel dim; output3d is a view of
+ # `output`, so whatever the extension writes into it lands in `output`.
+ input3d = input.flatten(start_dim=2)
+ output3d = output.view_as(input3d)
+ num_channels = input3d.size(1)
+ # ensure mean/var/norm/std are initialized as zeros
+ # ``torch.empty()`` does not guarantee that
+ mean = torch.zeros(
+ num_channels, dtype=torch.float, device=input3d.device)
+ var = torch.zeros(
+ num_channels, dtype=torch.float, device=input3d.device)
+ norm = torch.zeros_like(
+ input3d, dtype=torch.float, device=input3d.device)
+ std = torch.zeros(
+ num_channels, dtype=torch.float, device=input3d.device)
+ batch_size = input3d.size(0)
+ # batch_flag is 1 when this rank has a non-empty batch and 0 otherwise.
+ if batch_size > 0:
+ ext_module.sync_bn_forward_mean(input3d, mean)
+ batch_flag = torch.ones([1], device=mean.device, dtype=mean.dtype)
+ else:
+ # skip updating mean and leave it as zeros when the input is empty
+ batch_flag = torch.zeros([1], device=mean.device, dtype=mean.dtype)
+ # synchronize mean and the batch flag
+ # The flag is appended so mean and flag go through a single all_reduce.
+ vec = torch.cat([mean, batch_flag])
+ # In 'N' mode both entries are scaled by the local batch size, so the
+ # last element carries this rank's batch size instead of a 0/1 flag.
+ if self.stats_mode == 'N':
+ vec *= batch_size
+ if self.group_size > 1:
+ dist.all_reduce(vec, group=self.group)
+ total_batch = vec[-1].detach()
+ mean = vec[:num_channels]
+ if self.stats_mode == 'default':
+ mean = mean / self.group_size
+ elif self.stats_mode == 'N':
+ # clamp(min=1) avoids dividing by zero when every rank is empty
+ mean = mean / total_batch.clamp(min=1)
+ else:
+ raise NotImplementedError
+ # leave var as zeros when the input is empty
+ if batch_size > 0:
+ ext_module.sync_bn_forward_var(input3d, mean, var)
+ # Variance goes through the same scale / reduce / normalise steps as
+ # the mean above.
+ if self.stats_mode == 'N':
+ var *= batch_size
+ if self.group_size > 1:
+ dist.all_reduce(var, group=self.group)
+ if self.stats_mode == 'default':
+ var /= self.group_size
+ elif self.stats_mode == 'N':
+ var /= total_batch.clamp(min=1)
+ else:
+ raise NotImplementedError
+ # if the total batch size over all the ranks is zero,
+ # we should not update the statistics in the current batch
+ update_flag = total_batch.clamp(max=1)
+ momentum = update_flag * self.momentum
+ ext_module.sync_bn_forward_output(
+ input3d,
+ mean,
+ var,
+ weight,
+ bias,
+ running_mean,
+ running_var,
+ norm,
+ std,
+ output3d,
+ eps=self.eps,
+ momentum=momentum,
+ group_size=self.group_size)
+ # norm and std are the buffers handed to the extension call above;
+ # backward passes them on to the backward ops.
+ self.save_for_backward(norm, std, weight)
+ return output
+ # Returns one gradient slot per forward argument; only input, weight and
+ # bias receive gradients, the remaining seven slots are None.
+ @staticmethod
+ @once_differentiable
+ def backward(self, grad_output):
+ norm, std, weight = self.saved_tensors
+ grad_weight = torch.zeros_like(weight)
+ grad_bias = torch.zeros_like(weight)
+ grad_input = torch.zeros_like(grad_output)
+ # grad_input3d is a view of grad_input, mirroring output3d in forward.
+ grad_output3d = grad_output.flatten(start_dim=2)
+ grad_input3d = grad_input.view_as(grad_output3d)
+ batch_size = grad_input3d.size(0)
+ # With an empty local batch the parameter grads stay zero but this
+ # rank still takes part in the all_reduce calls below.
+ if batch_size > 0:
+ ext_module.sync_bn_backward_param(grad_output3d, norm, grad_weight,
+ grad_bias)
+ # all reduce
+ if self.group_size > 1:
+ dist.all_reduce(grad_weight, group=self.group)
+ dist.all_reduce(grad_bias, group=self.group)
+ grad_weight /= self.group_size
+ grad_bias /= self.group_size
+ if batch_size > 0:
+ ext_module.sync_bn_backward_data(grad_output3d, weight,
+ grad_weight, grad_bias, norm, std,
+ grad_input3d)
+ return grad_input, None, None, grad_weight, grad_bias, \
+ None, None, None, None, None
+class SyncBatchNorm(Module):
+    """Synchronized Batch Normalization.
+
+    Args:
+        num_features (int): number of features/channels in input tensor
+        eps (float, optional): a value added to the denominator for numerical
+            stability. Defaults to 1e-5.
+        momentum (float, optional): the value used for the running_mean and
+            running_var computation. Defaults to 0.1.
+        affine (bool, optional): whether to use learnable affine parameters.
+            Defaults to True.
+        track_running_stats (bool, optional): whether to track the running
+            mean and variance during training. When set to False, this
+            module does not track such statistics, and initializes statistics
+            buffers ``running_mean`` and ``running_var`` as ``None``. When
+            these buffers are ``None``, this module always uses batch
+            statistics in both training and eval modes. Defaults to True.
+        group (int, optional): synchronization of stats happen within
+            each process group individually. By default it is synchronization
+            across the whole world. Defaults to None.
+        stats_mode (str, optional): The statistical mode. Available options
+            includes ``'default'`` and ``'N'``. Defaults to 'default'.
+            When ``stats_mode=='default'``, it computes the overall statistics
+            using those from each worker with equal weight, i.e., the
+            statistics are synchronized and simply divided by ``group``. This
+            mode will produce inaccurate statistics when empty tensors occur.
+            When ``stats_mode=='N'``, it computes the overall statistics using
+            the total number of batches in each worker ignoring the number of
+            group, i.e., the statistics are synchronized and then divided by
+            the total batch ``N``. This mode is beneficial when empty tensors
+            occur during training, as it averages the total mean by the real
+            number of batch.
+    """
+
+    def __init__(self,
+                 num_features,
+                 eps=1e-5,
+                 momentum=0.1,
+                 affine=True,
+                 track_running_stats=True,
+                 group=None,
+                 stats_mode='default'):
+        super().__init__()
+        self.num_features = num_features
+        self.eps = eps
+        self.momentum = momentum
+        self.affine = affine
+        self.track_running_stats = track_running_stats
+        if group is None:
+            group = dist.group.WORLD
+        self.group = group
+        self.group_size = dist.get_world_size(group)
+        assert stats_mode in ['default', 'N'], \
+            f'"stats_mode" only accepts "default" and "N", got "{stats_mode}"'
+        self.stats_mode = stats_mode
+
+        if self.affine:
+            # Uninitialised storage; reset_parameters() fills it below.
+            self.weight = Parameter(torch.Tensor(num_features))
+            self.bias = Parameter(torch.Tensor(num_features))
+        else:
+            for param_name in ('weight', 'bias'):
+                self.register_parameter(param_name, None)
+
+        if self.track_running_stats:
+            self.register_buffer('running_mean', torch.zeros(num_features))
+            self.register_buffer('running_var', torch.ones(num_features))
+            self.register_buffer('num_batches_tracked',
+                                 torch.tensor(0, dtype=torch.long))
+        else:
+            for buffer_name in ('running_mean', 'running_var',
+                                'num_batches_tracked'):
+                self.register_buffer(buffer_name, None)
+        self.reset_parameters()
+
+    def reset_running_stats(self):
+        """Reset the running mean/var and batch counter, if tracked."""
+        if not self.track_running_stats:
+            return
+        self.running_mean.zero_()
+        self.running_var.fill_(1)
+        self.num_batches_tracked.zero_()
+
+    def reset_parameters(self):
+        """Reset running statistics and re-initialise the affine parameters."""
+        self.reset_running_stats()
+        if not self.affine:
+            return
+        self.weight.data.uniform_()  # pytorch use ones_()
+        self.bias.data.zero_()
+
+    def forward(self, input):
+        """Normalize ``input`` with synced batch stats or running stats.
+
+        Raises:
+            ValueError: If ``input`` has fewer than 2 dimensions.
+        """
+        if input.dim() < 2:
+            raise ValueError(
+                f'expected at least 2D input, got {input.dim()}D input')
+
+        factor = 0.0 if self.momentum is None else self.momentum
+        tracking = self.training and self.track_running_stats
+        if tracking and self.num_batches_tracked is not None:
+            self.num_batches_tracked += 1
+            if self.momentum is None:
+                # No momentum given: use a cumulative moving average.
+                factor = 1.0 / float(self.num_batches_tracked)
+
+        if not self.training and self.track_running_stats:
+            # Eval mode with tracked statistics: plain batch_norm on the
+            # running buffers (training flag False).
+            return F.batch_norm(input, self.running_mean, self.running_var,
+                                self.weight, self.bias, False, factor,
+                                self.eps)
+        return SyncBatchNormFunction.apply(
+            input, self.running_mean, self.running_var, self.weight,
+            self.bias, factor, self.eps, self.group, self.group_size,
+            self.stats_mode)
+
+    def __repr__(self):
+        return (f'{self.__class__.__name__}'
+                f'({self.num_features}, '
+                f'eps={self.eps}, '
+                f'momentum={self.momentum}, '
+                f'affine={self.affine}, '
+                f'track_running_stats={self.track_running_stats}, '
+                f'group_size={self.group_size},'
+                f'stats_mode={self.stats_mode})')
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/three_interpolate.py b/ControlNet/annotator/uniformer/mmcv/ops/three_interpolate.py
new file mode 100644
index 0000000000000000000000000000000000000000..203f47f05d58087e034fb3cd8cd6a09233947b4a
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/three_interpolate.py
@@ -0,0 +1,68 @@
+from typing import Tuple
+import torch
+from torch.autograd import Function
+from ..utils import ext_loader
+# Bind the forward/backward three-point interpolation ops from the `_ext`
+# extension (presumably the compiled mmcv extension; confirm in ext_loader).
+ext_module = ext_loader.load_ext(
+ '_ext', ['three_interpolate_forward', 'three_interpolate_backward'])
+class ThreeInterpolate(Function):
+    """Performs weighted linear interpolation on 3 features.
+
+    Please refer to the PointNet++ paper (https://arxiv.org/abs/1706.02413)
+    for more details.
+    """
+
+    @staticmethod
+    def forward(ctx, features: torch.Tensor, indices: torch.Tensor,
+                weight: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            features (Tensor): (B, C, M) Features descriptors to be
+                interpolated
+            indices (Tensor): (B, n, 3) index three nearest neighbors
+                of the target features in features
+            weight (Tensor): (B, n, 3) weights of interpolation
+
+        Returns:
+            Tensor: (B, C, n) float32 tensor of the interpolated features,
+            allocated on the same device as ``features``.
+        """
+        assert features.is_contiguous()
+        assert indices.is_contiguous()
+        assert weight.is_contiguous()
+        B, c, m = features.size()
+        n = indices.size(1)
+        # Kept as a plain attribute (not save_for_backward) because it mixes
+        # tensors with the int m.
+        ctx.three_interpolate_for_backward = (indices, weight, m)
+        # Allocate next to the input instead of on the current CUDA device,
+        # which is what the legacy torch.cuda.FloatTensor constructor did.
+        output = torch.empty((B, c, n),
+                             dtype=torch.float32,
+                             device=features.device)
+        ext_module.three_interpolate_forward(
+            features, indices, weight, output, b=B, c=c, m=m, n=n)
+        return output
+
+    @staticmethod
+    def backward(
+        ctx, grad_out: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Args:
+            grad_out (Tensor): (B, C, N) tensor with gradients of outputs
+
+        Returns:
+            Tensor: (B, C, M) tensor with gradients of features. The other
+            two slots (for ``indices`` and ``weight``) are ``None``.
+        """
+        idx, weight, m = ctx.three_interpolate_for_backward
+        B, c, n = grad_out.size()
+        # Zero-initialised on the gradient's own device; the extension op
+        # writes the result into this buffer.
+        grad_features = torch.zeros((B, c, m),
+                                    dtype=torch.float32,
+                                    device=grad_out.device)
+        grad_out_data = grad_out.data.contiguous()
+        ext_module.three_interpolate_backward(
+            grad_out_data, idx, weight, grad_features.data, b=B, c=c, n=n, m=m)
+        return grad_features, None, None
+
+
+three_interpolate = ThreeInterpolate.apply
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/three_nn.py b/ControlNet/annotator/uniformer/mmcv/ops/three_nn.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b01047a129989cd5545a0a86f23a487f4a13ce1
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/three_nn.py
@@ -0,0 +1,51 @@
+from typing import Tuple
+import torch
+from torch.autograd import Function
+from ..utils import ext_loader
+# Bind the `three_nn_forward` op from the `_ext` extension; no backward op is
+# requested here.
+ext_module = ext_loader.load_ext('_ext', ['three_nn_forward'])
+class ThreeNN(Function):
+    """Find the top-3 nearest neighbors of the target set from the source set.
+
+    Please refer to the PointNet++ paper (https://arxiv.org/abs/1706.02413)
+    for more details.
+    """
+
+    @staticmethod
+    def forward(ctx, target: torch.Tensor,
+                source: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+            target (Tensor): shape (B, N, 3), points set that needs to
+                find the nearest neighbors.
+            source (Tensor): shape (B, M, 3), points set that is used
+                to find the nearest neighbors of points in target set.
+
+        Returns:
+            tuple[Tensor, Tensor]: ``(dist, idx)``. ``dist`` has shape
+            (B, N, 3) and holds the L2 distance of each target point to its
+            three nearest neighbors; ``idx`` is the (B, N, 3) int32 buffer
+            filled by the extension op (presumably the neighbor indices in
+            ``source``).
+        """
+        target = target.contiguous()
+        source = source.contiguous()
+        B, N, _ = target.size()
+        m = source.size(1)
+        # Allocate next to the input instead of on the current CUDA device,
+        # which is what the legacy torch.cuda.*Tensor constructors did.
+        dist2 = torch.empty((B, N, 3),
+                            dtype=torch.float32,
+                            device=target.device)
+        idx = torch.empty((B, N, 3), dtype=torch.int32, device=target.device)
+        ext_module.three_nn_forward(target, source, dist2, idx, b=B, n=N, m=m)
+        if torch.__version__ != 'parrots':
+            ctx.mark_non_differentiable(idx)
+        # The op fills dist2 with squared distances, hence the sqrt.
+        return torch.sqrt(dist2), idx
+
+    @staticmethod
+    def backward(ctx, a=None, b=None):
+        """No gradient is propagated to ``target`` or ``source``."""
+        return None, None
+
+
+three_nn = ThreeNN.apply
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/tin_shift.py b/ControlNet/annotator/uniformer/mmcv/ops/tin_shift.py
new file mode 100644
index 0000000000000000000000000000000000000000..472c9fcfe45a124e819b7ed5653e585f94a8811e
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/tin_shift.py
@@ -0,0 +1,68 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Code reference from "Temporal Interlacing Network"
+# https://github.com/deepcs233/TIN/blob/master/cuda_shift/rtc_wrap.py
+# Hao Shao, Shengju Qian, Yu Liu
+# shaoh19@mails.tsinghua.edu.cn, sjqian@cse.cuhk.edu.hk, yuliu@ee.cuhk.edu.hk
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from ..utils import ext_loader
+# Bind the forward/backward TIN shift ops from the `_ext` extension
+# (presumably the compiled mmcv extension; confirm in ext_loader).
+ext_module = ext_loader.load_ext('_ext',
+ ['tin_shift_forward', 'tin_shift_backward'])
+class TINShiftFunction(Function):
+    """Autograd wrapper around the TIN shift extension ops."""
+
+    @staticmethod
+    def forward(ctx, input, shift):
+        """Apply the shift op to ``input``.
+
+        Raises:
+            ValueError: If ``input.size(2)`` is not a positive multiple of
+                ``shift.size(1)``.
+        """
+        C = input.size(2)
+        num_segments = shift.size(1)
+        is_multiple = C // num_segments > 0 and C % num_segments == 0
+        if not is_multiple:
+            raise ValueError('C should be a multiple of num_segments, '
+                             f'but got C={C} and num_segments={num_segments}.')
+
+        ctx.save_for_backward(shift)
+
+        # The extension op writes into this zero-initialised buffer.
+        shifted = torch.zeros_like(input)
+        ext_module.tin_shift_forward(input, shift, shifted)
+        return shifted
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        """Return the input gradient and an all-zero gradient for ``shift``."""
+        (shift, ) = ctx.saved_tensors
+        input_grad = grad_output.new_zeros(grad_output.size())
+        # Only the input gradient is computed by the op; shift's stays zero.
+        shift_grad = shift.new_zeros(shift.size())
+        ext_module.tin_shift_backward(grad_output, shift, input_grad)
+        return input_grad, shift_grad
+
+
+tin_shift = TINShiftFunction.apply
+class TINShift(nn.Module):
+    """Temporal Interlace Shift.
+
+    Temporal Interlace shift is a differentiable temporal-wise frame shifting
+    which is proposed in "Temporal Interlacing Network".
+
+    Please refer to https://arxiv.org/abs/2001.06499 for more details.
+    Code is modified from https://github.com/mit-han-lab/temporal-shift-module
+    """
+
+    def forward(self, input, shift):
+        """Perform temporal interlace shift.
+
+        Args:
+            input (Tensor): Feature map with shape [N, num_segments, C, H * W].
+            shift (Tensor): Shift tensor with shape [N, num_segments].
+
+        Returns:
+            Feature map after temporal interlace shift.
+        """
+        shifted = tin_shift(input, shift)
+        return shifted
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/upfirdn2d.py b/ControlNet/annotator/uniformer/mmcv/ops/upfirdn2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8bb2c3c949eed38a6465ed369fa881538dca010
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/upfirdn2d.py
@@ -0,0 +1,330 @@
+# modified from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501
+# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
+# NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator
+# Augmentation (ADA)
+# =======================================================================
+# 1. Definitions
+# "Licensor" means any person or entity that distributes its Work.
+# "Software" means the original work of authorship made available under
+# this License.
+# "Work" means the Software and any additions to or derivative works of
+# the Software that are made available under this License.
+# The terms "reproduce," "reproduction," "derivative works," and
+# "distribution" have the meaning as provided under U.S. copyright law;
+# provided, however, that for the purposes of this License, derivative
+# works shall not include works that remain separable from, or merely
+# link (or bind by name) to the interfaces of, the Work.
+# Works, including the Software, are "made available" under this License
+# by including in or with the Work either (a) a copyright notice
+# referencing the applicability of this License to the Work, or (b) a
+# copy of this License.
+# 2. License Grants
+# 2.1 Copyright Grant. Subject to the terms and conditions of this
+# License, each Licensor grants to you a perpetual, worldwide,
+# non-exclusive, royalty-free, copyright license to reproduce,
+# prepare derivative works of, publicly display, publicly perform,
+# sublicense and distribute its Work and any resulting derivative
+# works in any form.
+# 3. Limitations
+# 3.1 Redistribution. You may reproduce or distribute the Work only
+# if (a) you do so under this License, (b) you include a complete
+# copy of this License with your distribution, and (c) you retain
+# without modification any copyright, patent, trademark, or
+# attribution notices that are present in the Work.
+# 3.2 Derivative Works. You may specify that additional or different
+# terms apply to the use, reproduction, and distribution of your
+# derivative works of the Work ("Your Terms") only if (a) Your Terms
+# provide that the use limitation in Section 3.3 applies to your
+# derivative works, and (b) you identify the specific derivative
+# works that are subject to Your Terms. Notwithstanding Your Terms,
+# this License (including the redistribution requirements in Section
+# 3.1) will continue to apply to the Work itself.
+# 3.3 Use Limitation. The Work and any derivative works thereof only
+# may be used or intended for use non-commercially. Notwithstanding
+# the foregoing, NVIDIA and its affiliates may use the Work and any
+# derivative works commercially. As used herein, "non-commercially"
+# means for research or evaluation purposes only.
+# 3.4 Patent Claims. If you bring or threaten to bring a patent claim
+# against any Licensor (including any claim, cross-claim or
+# counterclaim in a lawsuit) to enforce any patents that you allege
+# are infringed by any Work, then your rights under this License from
+# such Licensor (including the grant in Section 2.1) will terminate
+# immediately.
+# 3.5 Trademarks. This License does not grant any rights to use any
+# Licensor’s or its affiliates’ names, logos, or trademarks, except
+# as necessary to reproduce the notices described in this License.
+# 3.6 Termination. If you violate any term of this License, then your
+# rights under this License (including the grant in Section 2.1) will
+# terminate immediately.
+# 4. Disclaimer of Warranty.
+# 5. Limitation of Liability.
+# =======================================================================
+import torch
+from torch.autograd import Function
+from torch.nn import functional as F
+from annotator.uniformer.mmcv.utils import to_2tuple
+from ..utils import ext_loader
+# Bind the `upfirdn2d` op from the `_ext` extension; the two autograd
+# Functions below call it for both the forward and the gradient passes.
+upfirdn2d_ext = ext_loader.load_ext('_ext', ['upfirdn2d'])
+class UpFirDn2dBackward(Function):
+    """Gradient of ``UpFirDn2d`` written as its own Function.
+
+    Having a ``backward`` of its own lets the gradient be differentiated
+    again.
+    """
+
+    @staticmethod
+    def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad,
+                in_size, out_size):
+        """Compute the input gradient with up/down factors swapped."""
+        up_x, up_y = up
+        down_x, down_y = down
+        g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
+
+        flat_grad = grad_output.reshape(-1, out_size[0], out_size[1], 1)
+        # Same op as the forward pass, run with grad_kernel, the gradient
+        # padding, and up/down exchanged.
+        grad_input = upfirdn2d_ext.upfirdn2d(
+            flat_grad,
+            grad_kernel,
+            up_x=down_x,
+            up_y=down_y,
+            down_x=up_x,
+            down_y=up_y,
+            pad_x0=g_pad_x0,
+            pad_x1=g_pad_x1,
+            pad_y0=g_pad_y0,
+            pad_y1=g_pad_y1)
+        grad_input = grad_input.view(in_size[0], in_size[1], in_size[2],
+                                     in_size[3])
+
+        # Remember the original (un-swapped) configuration for backward.
+        ctx.save_for_backward(kernel)
+        ctx.up_x, ctx.up_y = up_x, up_y
+        ctx.down_x, ctx.down_y = down_x, down_y
+        ctx.pad_x0, ctx.pad_x1, ctx.pad_y0, ctx.pad_y1 = pad
+        ctx.in_size = in_size
+        ctx.out_size = out_size
+        return grad_input
+
+    @staticmethod
+    def backward(ctx, gradgrad_input):
+        """Second-order gradient: re-apply the op with the original kernel."""
+        (kernel, ) = ctx.saved_tensors
+        in_size, out_size = ctx.in_size, ctx.out_size
+
+        flat_gradgrad = gradgrad_input.reshape(-1, in_size[2], in_size[3], 1)
+        gradgrad_out = upfirdn2d_ext.upfirdn2d(
+            flat_gradgrad,
+            kernel,
+            up_x=ctx.up_x,
+            up_y=ctx.up_y,
+            down_x=ctx.down_x,
+            down_y=ctx.down_y,
+            pad_x0=ctx.pad_x0,
+            pad_x1=ctx.pad_x1,
+            pad_y0=ctx.pad_y0,
+            pad_y1=ctx.pad_y1)
+        gradgrad_out = gradgrad_out.view(in_size[0], in_size[1], out_size[0],
+                                         out_size[1])
+        # Only grad_output (the first forward argument) receives a gradient.
+        return (gradgrad_out, ) + (None, ) * 8
+class UpFirDn2d(Function):
+    """Autograd Function for the extension ``upfirdn2d`` op."""
+
+    @staticmethod
+    def forward(ctx, input, kernel, up, down, pad):
+        """Run the op on an (n, c, h, w) input with a 2-D kernel.
+
+        ``up`` and ``down`` are ``(x, y)`` pairs and ``pad`` is
+        ``(pad_x0, pad_x1, pad_y0, pad_y1)``.
+        """
+        up_x, up_y = up
+        down_x, down_y = down
+        pad_x0, pad_x1, pad_y0, pad_y1 = pad
+        kernel_h, kernel_w = kernel.shape
+        _, channel, in_h, in_w = input.shape
+
+        ctx.in_size = input.shape
+        # The flipped kernel is saved too; backward uses it as grad_kernel.
+        ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
+
+        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+        ctx.out_size = (out_h, out_w)
+        ctx.up = (up_x, up_y)
+        ctx.down = (down_x, down_y)
+        ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
+
+        # Padding used when the op is re-run to compute the input gradient.
+        ctx.g_pad = (
+            kernel_w - pad_x0 - 1,
+            in_w * up_x - out_w * down_x + pad_x0 - up_x + 1,
+            kernel_h - pad_y0 - 1,
+            in_h * up_y - out_h * down_y + pad_y0 - up_y + 1,
+        )
+
+        # The op is fed a (n*c, h, w, 1) layout.
+        flat_input = input.reshape(-1, in_h, in_w, 1)
+        out = upfirdn2d_ext.upfirdn2d(
+            flat_input,
+            kernel,
+            up_x=up_x,
+            up_y=up_y,
+            down_x=down_x,
+            down_y=down_y,
+            pad_x0=pad_x0,
+            pad_x1=pad_x1,
+            pad_y0=pad_y0,
+            pad_y1=pad_y1)
+        return out.view(-1, channel, out_h, out_w)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        """Delegate to ``UpFirDn2dBackward``; only ``input`` gets a gradient."""
+        kernel, grad_kernel = ctx.saved_tensors
+        grad_input = UpFirDn2dBackward.apply(grad_output, kernel, grad_kernel,
+                                             ctx.up, ctx.down, ctx.pad,
+                                             ctx.g_pad, ctx.in_size,
+                                             ctx.out_size)
+        return grad_input, None, None, None, None
+def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
+    """UpFIRDn for 2d features.
+
+    UpFIRDn is short for upsample, apply FIR filter and downsample. More
+    details can be found in:
+    https://www.mathworks.com/help/signal/ref/upfirdn.html
+
+    Args:
+        input (Tensor): Tensor with shape of (n, c, h, w).
+        kernel (Tensor): Filter kernel.
+        up (int | tuple[int], optional): Upsampling factor. If given a number,
+            we will use this factor for the both height and width side.
+            Defaults to 1.
+        down (int | tuple[int], optional): Downsampling factor. If given a
+            number, we will use this factor for the both height and width side.
+            Defaults to 1.
+        pad (tuple[int], optional): Padding for tensors, either
+            (x_pad_0, x_pad_1, y_pad_0, y_pad_1) or a pair (pad_0, pad_1)
+            that is applied to both axes, i.e. expanded to
+            (pad_0, pad_1, pad_0, pad_1). Defaults to (0, 0).
+
+    Returns:
+        Tensor: Tensor after UpFIRDn.
+
+    Raises:
+        ValueError: If ``pad`` has neither 2 nor 4 elements.
+    """
+    # Validate and normalise pad once for both code paths. Previously a bad
+    # length surfaced as UnboundLocalError (non-CPU) or IndexError (CPU).
+    if len(pad) == 2:
+        pad = (pad[0], pad[1], pad[0], pad[1])
+    elif len(pad) != 4:
+        raise ValueError(
+            f'pad must have 2 or 4 elements, but got {len(pad)}.')
+
+    up = to_2tuple(up)
+    down = to_2tuple(down)
+
+    if input.device.type == 'cpu':
+        # Pure-PyTorch fallback; the extension op is used everywhere else.
+        return upfirdn2d_native(input, kernel, up[0], up[1], down[0],
+                                down[1], pad[0], pad[1], pad[2], pad[3])
+    return UpFirDn2d.apply(input, kernel, up, down, pad)
+def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
+                     pad_y0, pad_y1):
+    """Pure-PyTorch upsample / FIR filter / downsample on an (n, c, h, w) input.
+
+    Steps: insert zeros between samples (upsample), pad or crop the borders
+    (negative padding crops), convolve with the flipped kernel, then keep
+    every ``down``-th sample.
+    """
+    _, channel, in_h, in_w = input.shape
+    input = input.reshape(-1, in_h, in_w, 1)
+    _, in_h, in_w, minor = input.shape
+    kernel_h, kernel_w = kernel.shape
+
+    # Size after upsampling and (signed) padding.
+    padded_h = in_h * up_y + pad_y0 + pad_y1
+    padded_w = in_w * up_x + pad_x0 + pad_x1
+
+    # Upsample: give each sample a trailing unit axis and zero-pad that axis.
+    x = input.view(-1, in_h, 1, in_w, 1, minor)
+    x = F.pad(x, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
+    x = x.view(-1, in_h * up_y, in_w * up_x, minor)
+
+    # Positive padding is added here, negative padding is cropped below.
+    x = F.pad(x, [
+        0, 0,
+        max(pad_x0, 0),
+        max(pad_x1, 0),
+        max(pad_y0, 0),
+        max(pad_y1, 0)
+    ])
+    crop_top, crop_bottom = max(-pad_y0, 0), max(-pad_y1, 0)
+    crop_left, crop_right = max(-pad_x0, 0), max(-pad_x1, 0)
+    x = x[:, crop_top:x.shape[1] - crop_bottom,
+          crop_left:x.shape[2] - crop_right, :]
+
+    # Filter: conv2d is a correlation, so the kernel is flipped first.
+    x = x.permute(0, 3, 1, 2).reshape([-1, 1, padded_h, padded_w])
+    flipped_kernel = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
+    x = F.conv2d(x, flipped_kernel)
+    x = x.reshape(-1, minor, padded_h - kernel_h + 1, padded_w - kernel_w + 1)
+
+    # Downsample by striding.
+    x = x.permute(0, 2, 3, 1)[:, ::down_y, ::down_x, :]
+
+    out_h = (padded_h - kernel_h) // down_y + 1
+    out_w = (padded_w - kernel_w) // down_x + 1
+    return x.view(-1, channel, out_h, out_w)
diff --git a/ControlNet/annotator/uniformer/mmcv/ops/voxelize.py b/ControlNet/annotator/uniformer/mmcv/ops/voxelize.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca3226a4fbcbfe58490fa2ea8e1c16b531214121
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/ops/voxelize.py
@@ -0,0 +1,132 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch import nn
+from torch.autograd import Function
+from torch.nn.modules.utils import _pair
+from ..utils import ext_loader
+ext_module = ext_loader.load_ext(
+ '_ext', ['dynamic_voxelize_forward', 'hard_voxelize_forward'])
+class _Voxelization(Function):
+    """Autograd ``Function`` front-end for the compiled voxelization ops.
+    Only ``forward`` is defined in this class.
+    """
+    @staticmethod
+    def forward(ctx,
+                points,
+                voxel_size,
+                coors_range,
+                max_points=35,
+                max_voxels=20000):
+        """Convert kitti points(N, >=3) to voxels.
+        Args:
+            points (torch.Tensor): [N, ndim]. Points[:, :3] contain xyz
+                points and points[:, 3:] contain other information like
+                reflectivity.
+            voxel_size (tuple or float): The size of voxel with the shape of
+                [3].
+            coors_range (tuple or float): The coordinate range of voxel with
+                the shape of [6].
+            max_points (int, optional): Maximum points contained in a voxel.
+                ``-1`` selects dynamic voxelization. Default: 35.
+            max_voxels (int, optional): Maximum voxels this function creates;
+                ``-1`` also selects dynamic voxelization. For SECOND, 20000
+                is a good choice. Users should shuffle points before calling
+                this function because max_voxels may drop points.
+                Default: 20000.
+        Returns:
+            In dynamic mode, a single int tensor of shape [N, 3] holding one
+            coordinate triple per input point. Otherwise a tuple
+            ``(voxels_out, coors_out, num_points_per_voxel_out)`` with
+            shapes [M, max_points, ndim], [M, 3] and [M], where M is the
+            voxel count reported by the extension.
+        """
+        use_dynamic = max_points == -1 or max_voxels == -1
+        if use_dynamic:
+            point_coors = points.new_zeros(
+                size=(points.size(0), 3), dtype=torch.int)
+            # The extension fills ``point_coors`` in place. The trailing 3 is
+            # presumably the number of spatial dims -- TODO confirm.
+            ext_module.dynamic_voxelize_forward(points, point_coors,
+                                                voxel_size, coors_range, 3)
+            return point_coors
+        # Hard voxelization: allocate worst-case buffers and let the
+        # extension report how many voxels it actually produced.
+        num_feats = points.size(1)
+        voxel_buf = points.new_zeros(
+            size=(max_voxels, max_points, num_feats))
+        coor_buf = points.new_zeros(size=(max_voxels, 3), dtype=torch.int)
+        count_buf = points.new_zeros(size=(max_voxels, ), dtype=torch.int)
+        voxel_num = ext_module.hard_voxelize_forward(
+            points, voxel_buf, coor_buf, count_buf, voxel_size, coors_range,
+            max_points, max_voxels, 3)
+        # Keep only the leading ``voxel_num`` entries of each buffer.
+        return (voxel_buf[:voxel_num], coor_buf[:voxel_num],
+                count_buf[:voxel_num])
+# Functional entry point: the ``apply`` attribute of ``_Voxelization``.
+voxelization = _Voxelization.apply
+class Voxelization(nn.Module):
+    """Convert kitti points(N, >=3) to voxels.
+    Please refer to the PVCNN paper for more details.
+    Args:
+        voxel_size (tuple or float): The size of voxel with the shape of [3].
+        point_cloud_range (tuple or float): The coordinate range of voxel
+            with the shape of [6].
+        max_num_points (int): Maximum points contained in a voxel. If
+            ``max_num_points=-1``, dynamic voxelization is used.
+        max_voxels (int or tuple, optional): Maximum voxels this module
+            creates. A tuple is stored as given and indexed as
+            ``(training, otherwise)``; any other value is expanded with
+            ``_pair``. For SECOND, 20000 is a good choice. Users should
+            shuffle points before calling this module because max_voxels
+            may drop points. Default: 20000.
+    """
+    def __init__(self,
+                 voxel_size,
+                 point_cloud_range,
+                 max_num_points,
+                 max_voxels=20000):
+        super().__init__()
+        self.voxel_size = voxel_size
+        self.point_cloud_range = point_cloud_range
+        self.max_num_points = max_num_points
+        self.max_voxels = (
+            max_voxels if isinstance(max_voxels, tuple) else
+            _pair(max_voxels))
+        range_t = torch.tensor(point_cloud_range, dtype=torch.float32)
+        size_t = torch.tensor(voxel_size, dtype=torch.float32)
+        # Number of cells per axis: (upper bound - lower bound) / cell size.
+        extent = range_t[3:] - range_t[:3]
+        self.grid_size = torch.round(extent / size_t).long()
+        # First two grid entries plus a trailing 1, then reversed.
+        self.pcd_shape = [*self.grid_size[:2], 1][::-1]
+    def forward(self, input):
+        """Voxelize ``input`` with the voxel budget for the current mode."""
+        budget = self.max_voxels[0 if self.training else 1]
+        return voxelization(input, self.voxel_size, self.point_cloud_range,
+                            self.max_num_points, budget)
+    def __repr__(self):
+        fields = ', '.join([
+            f'voxel_size={self.voxel_size!s}',
+            f'point_cloud_range={self.point_cloud_range!s}',
+            f'max_num_points={self.max_num_points!s}',
+            f'max_voxels={self.max_voxels!s}',
+        ])
+        return f'{self.__class__.__name__}({fields})'
diff --git a/ControlNet/annotator/uniformer/mmcv/parallel/__init__.py b/ControlNet/annotator/uniformer/mmcv/parallel/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ed2c17ad357742e423beeaf4d35db03fe9af469
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/parallel/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .collate import collate
+from .data_container import DataContainer
+from .data_parallel import MMDataParallel
+from .distributed import MMDistributedDataParallel
+from .registry import MODULE_WRAPPERS
+from .scatter_gather import scatter, scatter_kwargs
+from .utils import is_module_wrapper
+# Public API of the ``parallel`` package.
+__all__ = [
+    'collate', 'DataContainer', 'MMDataParallel', 'MMDistributedDataParallel',
+    'scatter', 'scatter_kwargs', 'is_module_wrapper', 'MODULE_WRAPPERS'
+]
diff --git a/ControlNet/annotator/uniformer/mmcv/parallel/_functions.py b/ControlNet/annotator/uniformer/mmcv/parallel/_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b5a8a44483ab991411d07122b22a1d027e4be8e
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/parallel/_functions.py
@@ -0,0 +1,79 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch.nn.parallel._functions import _get_stream
+def scatter(input, devices, streams=None):
+    """Scatter a tensor, or a (nested) list of tensors, across devices.
+    Args:
+        input (torch.Tensor | list): Data to place on the devices.
+        devices (list[int]): Target device ids; ``[-1]`` selects the CPU
+            branch.
+        streams (list, optional): One stream per device, used for the
+            copy. Defaults to a list of ``None``.
+    Returns:
+        A tensor for tensor input, or a list with one scattered entry per
+        element for list input.
+    Raises:
+        Exception: If ``input`` is neither a list nor a tensor.
+    """
+    if streams is None:
+        streams = [None] * len(devices)
+    if isinstance(input, list):
+        # Ceil-divide so that consecutive runs of items share one device.
+        per_device = (len(input) - 1) // len(devices) + 1
+        return [
+            scatter(item, [devices[idx // per_device]],
+                    [streams[idx // per_device]])
+            for idx, item in enumerate(input)
+        ]
+    if isinstance(input, torch.Tensor):
+        output = input.contiguous()
+        # TODO: copy to a pinned buffer first (if copying from CPU)
+        stream = streams[0] if output.numel() > 0 else None
+        if devices == [-1]:
+            # unsqueeze the first dimension thus the tensor's shape is the
+            # same as those scattered with GPU.
+            return output.unsqueeze(0)
+        with torch.cuda.device(devices[0]), torch.cuda.stream(stream):
+            output = output.cuda(devices[0], non_blocking=True)
+        return output
+    raise Exception(f'Unknown type {type(input)}.')
+def synchronize_stream(output, devices, streams):
+    """Make each device's current stream wait for its copy stream.
+    Args:
+        output (torch.Tensor | list): Scattered result (possibly nested).
+        devices (list[int]): Device ids matching ``output``.
+        streams (list): Copy streams, one per device.
+    Raises:
+        Exception: If ``output`` is neither a list nor a tensor.
+    """
+    if isinstance(output, torch.Tensor):
+        if output.numel() == 0:
+            # Empty tensors are skipped entirely.
+            return
+        with torch.cuda.device(devices[0]):
+            main_stream = torch.cuda.current_stream()
+            main_stream.wait_stream(streams[0])
+            output.record_stream(main_stream)
+        return
+    if isinstance(output, list):
+        per_device = len(output) // len(devices)
+        for dev_idx, device in enumerate(devices):
+            start = dev_idx * per_device
+            for item in output[start:start + per_device]:
+                synchronize_stream(item, [device], [streams[dev_idx]])
+        return
+    raise Exception(f'Unknown type {type(output)}.')
+def get_input_device(input):
+    """Return the CUDA device index of the first CUDA tensor in ``input``.
+    Args:
+        input (torch.Tensor | list): A tensor or (nested) list of tensors.
+    Returns:
+        int: The device index, or ``-1`` when no CUDA tensor is found.
+    Raises:
+        Exception: If an element is neither a list nor a tensor.
+    """
+    if isinstance(input, torch.Tensor):
+        return input.get_device() if input.is_cuda else -1
+    if isinstance(input, list):
+        # Lazily scan the items and stop at the first one on a CUDA device.
+        found = (dev for dev in map(get_input_device, input) if dev != -1)
+        return next(found, -1)
+    raise Exception(f'Unknown type {type(input)}.')
+class Scatter:
+    """Namespace for the self-implemented scatter entry point."""
+    @staticmethod
+    def forward(target_gpus, input):
+        """Scatter ``input`` to ``target_gpus`` and return a tuple.
+        Args:
+            target_gpus (list[int]): Target device ids; ``[-1]`` means CPU.
+            input (torch.Tensor | list): Data to scatter.
+        Returns:
+            tuple: ``tuple(...)`` of whatever ``scatter`` produced.
+        """
+        on_cpu = get_input_device(input) == -1
+        copy_streams = None
+        if on_cpu and target_gpus != [-1]:
+            # Perform CPU to GPU copies in a background stream
+            copy_streams = [_get_stream(device) for device in target_gpus]
+        outputs = scatter(input, target_gpus, copy_streams)
+        if copy_streams is not None:
+            # Synchronize with the copy stream
+            synchronize_stream(outputs, target_gpus, copy_streams)
+        return tuple(outputs)
diff --git a/ControlNet/annotator/uniformer/mmcv/parallel/collate.py b/ControlNet/annotator/uniformer/mmcv/parallel/collate.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad749197df21b0d74297548be5f66a696adebf7f
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/parallel/collate.py
@@ -0,0 +1,84 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections.abc import Mapping, Sequence
+import torch
+import torch.nn.functional as F
+from torch.utils.data.dataloader import default_collate
+from .data_container import DataContainer
+def collate(batch, samples_per_gpu=1):
+    """Puts each data field into a tensor/DataContainer with outer dimension
+    batch size.
+    Extend default_collate to add support for
+    :type:`~mmcv.parallel.DataContainer`. There are 3 cases.
+    1. cpu_only = True, e.g., meta data
+    2. cpu_only = False, stack = True, e.g., images tensors
+    3. cpu_only = False, stack = False, e.g., gt bboxes
+    Args:
+        batch (Sequence): Samples to collate. Elements may be
+            ``DataContainer`` objects, sequences, mappings, or anything
+            ``default_collate`` accepts.
+        samples_per_gpu (int): Group size; containers are collated in
+            consecutive groups of this many samples. Default: 1.
+    Returns:
+        A ``DataContainer``, list or dict mirroring the structure of
+        ``batch[0]``, or the result of ``default_collate``.
+    Raises:
+        TypeError: If ``batch`` is not a ``Sequence``.
+    """
+    if not isinstance(batch, Sequence):
+        # Report the Python type: arbitrary non-sequence objects have no
+        # ``dtype`` attribute, so formatting ``batch.dtype`` here would raise
+        # AttributeError instead of the intended TypeError.
+        raise TypeError(f'{type(batch)} is not supported.')
+    if isinstance(batch[0], DataContainer):
+        stacked = []
+        if batch[0].cpu_only:
+            # Case 1: keep the raw payloads, grouped per GPU.
+            for i in range(0, len(batch), samples_per_gpu):
+                stacked.append(
+                    [sample.data for sample in batch[i:i + samples_per_gpu]])
+            return DataContainer(
+                stacked, batch[0].stack, batch[0].padding_value, cpu_only=True)
+        elif batch[0].stack:
+            # Case 2: pad each group to a common shape, then stack it.
+            for i in range(0, len(batch), samples_per_gpu):
+                assert isinstance(batch[i].data, torch.Tensor)
+                if batch[i].pad_dims is not None:
+                    ndim = batch[i].dim()
+                    assert ndim > batch[i].pad_dims
+                    # max_shape[k] is the largest size seen for dim -(k + 1).
+                    max_shape = [0 for _ in range(batch[i].pad_dims)]
+                    for dim in range(1, batch[i].pad_dims + 1):
+                        max_shape[dim - 1] = batch[i].size(-dim)
+                    for sample in batch[i:i + samples_per_gpu]:
+                        # Leading (non-padded) dims must match exactly.
+                        for dim in range(0, ndim - batch[i].pad_dims):
+                            assert batch[i].size(dim) == sample.size(dim)
+                        for dim in range(1, batch[i].pad_dims + 1):
+                            max_shape[dim - 1] = max(max_shape[dim - 1],
+                                                     sample.size(-dim))
+                    padded_samples = []
+                    for sample in batch[i:i + samples_per_gpu]:
+                        pad = [0 for _ in range(batch[i].pad_dims * 2)]
+                        # Only the odd slots are filled, i.e. padding is
+                        # added at the end of each padded dimension.
+                        for dim in range(1, batch[i].pad_dims + 1):
+                            pad[2 * dim -
+                                1] = max_shape[dim - 1] - sample.size(-dim)
+                        padded_samples.append(
+                            F.pad(
+                                sample.data, pad, value=sample.padding_value))
+                    stacked.append(default_collate(padded_samples))
+                elif batch[i].pad_dims is None:
+                    stacked.append(
+                        default_collate([
+                            sample.data
+                            for sample in batch[i:i + samples_per_gpu]
+                        ]))
+                else:
+                    raise ValueError(
+                        'pad_dims should be either None or integers (1-3)')
+        else:
+            # Case 3: no stacking, just group the payloads per GPU.
+            for i in range(0, len(batch), samples_per_gpu):
+                stacked.append(
+                    [sample.data for sample in batch[i:i + samples_per_gpu]])
+        return DataContainer(stacked, batch[0].stack, batch[0].padding_value)
+    elif isinstance(batch[0], Sequence):
+        # Transpose a batch of sequences and collate field by field.
+        transposed = zip(*batch)
+        return [collate(samples, samples_per_gpu) for samples in transposed]
+    elif isinstance(batch[0], Mapping):
+        return {
+            key: collate([d[key] for d in batch], samples_per_gpu)
+            for key in batch[0]
+        }
+    else:
+        return default_collate(batch)
diff --git a/ControlNet/annotator/uniformer/mmcv/parallel/data_container.py b/ControlNet/annotator/uniformer/mmcv/parallel/data_container.py
new file mode 100644
index 0000000000000000000000000000000000000000..cedb0d32a51a1f575a622b38de2cee3ab4757821
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/parallel/data_container.py
@@ -0,0 +1,89 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import functools
+import torch
+def assert_tensor_type(func):
+    """Decorate a method so it only runs when ``self.data`` is a tensor.
+    The wrapped callable raises ``AttributeError`` when the ``data``
+    attribute of its first positional argument is not a ``torch.Tensor``.
+    """
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        owner = args[0]
+        if isinstance(owner.data, torch.Tensor):
+            return func(*args, **kwargs)
+        raise AttributeError(
+            f'{owner.__class__.__name__} has no attribute '
+            f'{func.__name__} for type {owner.datatype}')
+    return wrapper
+class DataContainer:
+    """A container for any type of objects.
+    Typically tensors will be stacked in the collate function and sliced
+    along some dimension in the scatter function. This behavior has some
+    limitations.
+    1. All tensors have to be the same size.
+    2. Types are limited (numpy array or Tensor).
+    We design `DataContainer` and `MMDataParallel` to overcome these
+    limitations. The behavior can be either of the following.
+    - copy to GPU, pad all tensors to the same size and stack them
+    - copy to GPU without stacking
+    - leave the objects as is and pass it to the model
+    - pad_dims specifies the number of last few dimensions to do padding
+    Args:
+        data: The wrapped payload.
+        stack (bool): Stored flag exposed as ``stack``. Default: False.
+        padding_value: Stored value exposed as ``padding_value``.
+            Default: 0.
+        cpu_only (bool): Stored flag exposed as ``cpu_only``.
+            Default: False.
+        pad_dims (int | None): One of ``None``, 1, 2 or 3. Default: 2.
+    """
+    def __init__(self,
+                 data,
+                 stack=False,
+                 padding_value=0,
+                 cpu_only=False,
+                 pad_dims=2):
+        self._data = data
+        self._cpu_only = cpu_only
+        self._stack = stack
+        self._padding_value = padding_value
+        assert pad_dims in (None, 1, 2, 3)
+        self._pad_dims = pad_dims
+    def __repr__(self):
+        return f'{self.__class__.__name__}({self.data!r})'
+    def __len__(self):
+        return len(self._data)
+    # Read-only views of the constructor arguments.
+    @property
+    def data(self):
+        return self._data
+    @property
+    def stack(self):
+        return self._stack
+    @property
+    def cpu_only(self):
+        return self._cpu_only
+    @property
+    def padding_value(self):
+        return self._padding_value
+    @property
+    def pad_dims(self):
+        return self._pad_dims
+    @property
+    def datatype(self):
+        """Tensor type string for tensors, the Python type otherwise."""
+        payload = self.data
+        return (payload.type()
+                if isinstance(payload, torch.Tensor) else type(payload))
+    # Tensor-only helpers; they raise AttributeError for other payloads.
+    @assert_tensor_type
+    def size(self, *args, **kwargs):
+        return self.data.size(*args, **kwargs)
+    @assert_tensor_type
+    def dim(self):
+        return self.data.dim()
diff --git a/ControlNet/annotator/uniformer/mmcv/parallel/data_parallel.py b/ControlNet/annotator/uniformer/mmcv/parallel/data_parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..79b5f69b654cf647dc7ae9174223781ab5c607d2
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/parallel/data_parallel.py
@@ -0,0 +1,89 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from itertools import chain
+from torch.nn.parallel import DataParallel
+from .scatter_gather import scatter_kwargs
+class MMDataParallel(DataParallel):
+    """The DataParallel module that supports DataContainer.
+    MMDataParallel has two main differences with PyTorch DataParallel:
+    - It supports a custom type :class:`DataContainer` which allows more
+      flexible control of input data during both GPU and CPU inference.
+    - It implements two more APIs ``train_step()`` and ``val_step()``.
+    Args:
+        module (:class:`nn.Module`): Module to be encapsulated.
+        device_ids (list[int]): Device IDS of modules to be scattered to.
+            Defaults to None when GPU is not available.
+        output_device (str | int): Device ID for output. Defaults to None.
+        dim (int): Dimension used to scatter the data. Defaults to 0.
+    """
+    def __init__(self, *args, dim=0, **kwargs):
+        super(MMDataParallel, self).__init__(*args, dim=dim, **kwargs)
+        self.dim = dim
+    def forward(self, *inputs, **kwargs):
+        """Override the original forward function.
+        The main difference lies in the CPU inference where the data in
+        :class:`DataContainers` will still be gathered.
+        """
+        if not self.device_ids:
+            # We add the following line thus the module could gather and
+            # convert data containers as those in GPU inference
+            inputs, kwargs = self.scatter(inputs, kwargs, [-1])
+            return self.module(*inputs[0], **kwargs[0])
+        else:
+            return super().forward(*inputs, **kwargs)
+    def scatter(self, inputs, kwargs, device_ids):
+        """Scatter positional and keyword inputs with DataContainer support."""
+        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
+    def _call_step(self, step_name, inputs, kwargs):
+        """Scatter the inputs and invoke ``self.module.<step_name>``.
+        Shared implementation of ``train_step`` and ``val_step``.
+        Raises:
+            AssertionError: If more than one device id is configured.
+            RuntimeError: If a parameter or buffer of the module is not on
+                ``self.src_device_obj``.
+        """
+        if not self.device_ids:
+            # We add the following line thus the module could gather and
+            # convert data containers as those in GPU inference
+            inputs, kwargs = self.scatter(inputs, kwargs, [-1])
+            return getattr(self.module, step_name)(*inputs[0], **kwargs[0])
+        assert len(self.device_ids) == 1, \
+            ('MMDataParallel only supports single GPU training, if you need to'
+             ' train with multiple GPUs, please use MMDistributedDataParallel'
+             ' instead.')
+        for t in chain(self.module.parameters(), self.module.buffers()):
+            if t.device != self.src_device_obj:
+                raise RuntimeError(
+                    'module must have its parameters and buffers '
+                    f'on device {self.src_device_obj} (device_ids[0]) but '
+                    f'found one of them on device: {t.device}')
+        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+        return getattr(self.module, step_name)(*inputs[0], **kwargs[0])
+    def train_step(self, *inputs, **kwargs):
+        """Scatter the inputs and call ``self.module.train_step``."""
+        return self._call_step('train_step', inputs, kwargs)
+    def val_step(self, *inputs, **kwargs):
+        """Scatter the inputs and call ``self.module.val_step``."""
+        return self._call_step('val_step', inputs, kwargs)
diff --git a/ControlNet/annotator/uniformer/mmcv/parallel/distributed.py b/ControlNet/annotator/uniformer/mmcv/parallel/distributed.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e4c27903db58a54d37ea1ed9ec0104098b486f2
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/parallel/distributed.py
@@ -0,0 +1,112 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch.nn.parallel.distributed import (DistributedDataParallel,
+ _find_tensors)
+from annotator.uniformer.mmcv import print_log
+from annotator.uniformer.mmcv.utils import TORCH_VERSION, digit_version
+from .scatter_gather import scatter_kwargs
+class MMDistributedDataParallel(DistributedDataParallel):
+ """The DDP module that supports DataContainer.
+ MMDDP has two main differences with PyTorch DDP:
+ - It supports a custom type :class:`DataContainer` which allows more
+ flexible control of input data.
+ - It implements two APIs ``train_step()`` and ``val_step()``.
+ """
+ def to_kwargs(self, inputs, kwargs, device_id):
+ # Use `self.to_kwargs` instead of `self.scatter` in pytorch1.8
+ # to move all tensors to device_id
+ return scatter_kwargs(inputs, kwargs, [device_id], dim=self.dim)
+ def scatter(self, inputs, kwargs, device_ids):
+ # Scatter positional and keyword inputs with DataContainer support.
+ return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
+ def train_step(self, *inputs, **kwargs):
+ """train_step() API for module wrapped by DistributedDataParallel.
+ This method is basically the same as
+ ``DistributedDataParallel.forward()``, while replacing
+ ``self.module.forward()`` with ``self.module.train_step()``.
+ The original note states compatibility with PyTorch 1.1 - 1.5; the body
+ additionally contains a branch for PyTorch >= 1.7.
+ Returns:
+ The output of ``self.module.train_step`` (or the gathered outputs in
+ the multi-device branch).
+ """
+ # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
+ # end of backward to the beginning of forward.
+ if ('parrots' not in TORCH_VERSION
+ and digit_version(TORCH_VERSION) >= digit_version('1.7')
+ and self.reducer._rebuild_buckets()):
+ print_log(
+ 'Reducer buckets have been rebuilt in this iteration.',
+ logger='mmcv')
+ # Parameter sync runs unless the flag has been set to False (the flag
+ # defaults to True when the attribute is missing).
+ if getattr(self, 'require_forward_param_sync', True):
+ self._sync_params()
+ if self.device_ids:
+ inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+ if len(self.device_ids) == 1:
+ output = self.module.train_step(*inputs[0], **kwargs[0])
+ else:
+ # NOTE(review): this branch goes through ``parallel_apply`` rather than
+ # calling ``train_step`` directly -- confirm that is intended.
+ outputs = self.parallel_apply(
+ self._module_copies[:len(inputs)], inputs, kwargs)
+ output = self.gather(outputs, self.output_device)
+ else:
+ output = self.module.train_step(*inputs, **kwargs)
+ # Prepare the reducer for backward only when gradients are enabled and
+ # gradient sync has not been switched off.
+ if torch.is_grad_enabled() and getattr(
+ self, 'require_backward_grad_sync', True):
+ if self.find_unused_parameters:
+ self.reducer.prepare_for_backward(list(_find_tensors(output)))
+ else:
+ self.reducer.prepare_for_backward([])
+ else:
+ if ('parrots' not in TORCH_VERSION
+ and digit_version(TORCH_VERSION) > digit_version('1.2')):
+ self.require_forward_param_sync = False
+ return output
+ def val_step(self, *inputs, **kwargs):
+ """val_step() API for module wrapped by DistributedDataParallel.
+ This method is basically the same as
+ ``DistributedDataParallel.forward()``, while replacing
+ ``self.module.forward()`` with ``self.module.val_step()``.
+ The original note states compatibility with PyTorch 1.1 - 1.5; the body
+ additionally contains a branch for PyTorch >= 1.7.
+ Returns:
+ The output of ``self.module.val_step`` (or the gathered outputs in
+ the multi-device branch).
+ """
+ # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
+ # end of backward to the beginning of forward.
+ if ('parrots' not in TORCH_VERSION
+ and digit_version(TORCH_VERSION) >= digit_version('1.7')
+ and self.reducer._rebuild_buckets()):
+ print_log(
+ 'Reducer buckets have been rebuilt in this iteration.',
+ logger='mmcv')
+ # Parameter sync runs unless the flag has been set to False (the flag
+ # defaults to True when the attribute is missing).
+ if getattr(self, 'require_forward_param_sync', True):
+ self._sync_params()
+ if self.device_ids:
+ inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+ if len(self.device_ids) == 1:
+ output = self.module.val_step(*inputs[0], **kwargs[0])
+ else:
+ # NOTE(review): this branch goes through ``parallel_apply`` rather than
+ # calling ``val_step`` directly -- confirm that is intended.
+ outputs = self.parallel_apply(
+ self._module_copies[:len(inputs)], inputs, kwargs)
+ output = self.gather(outputs, self.output_device)
+ else:
+ output = self.module.val_step(*inputs, **kwargs)
+ # Prepare the reducer for backward only when gradients are enabled and
+ # gradient sync has not been switched off.
+ if torch.is_grad_enabled() and getattr(
+ self, 'require_backward_grad_sync', True):
+ if self.find_unused_parameters:
+ self.reducer.prepare_for_backward(list(_find_tensors(output)))
+ else:
+ self.reducer.prepare_for_backward([])
+ else:
+ if ('parrots' not in TORCH_VERSION
+ and digit_version(TORCH_VERSION) > digit_version('1.2')):
+ self.require_forward_param_sync = False
+ return output
diff --git a/ControlNet/annotator/uniformer/mmcv/parallel/distributed_deprecated.py b/ControlNet/annotator/uniformer/mmcv/parallel/distributed_deprecated.py
new file mode 100644
index 0000000000000000000000000000000000000000..676937a2085d4da20fa87923041a200fca6214eb
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/parallel/distributed_deprecated.py
@@ -0,0 +1,70 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch._utils import (_flatten_dense_tensors, _take_tensors,
+ _unflatten_dense_tensors)
+from annotator.uniformer.mmcv.utils import TORCH_VERSION, digit_version
+from .registry import MODULE_WRAPPERS
+from .scatter_gather import scatter_kwargs
+class MMDistributedDataParallel(nn.Module):
+    """Module wrapper that broadcasts state from rank 0 and scatters inputs.
+    NOTE(review): ``MODULE_WRAPPERS`` is imported by this file but no
+    registration of this class is visible here -- confirm whether one is
+    expected.
+    Args:
+        module (nn.Module): Module to wrap.
+        dim (int): Dimension used when scattering inputs. Default: 0.
+        broadcast_buffers (bool): Whether buffers are broadcast as well.
+            Default: True.
+        bucket_cap_mb (int): Broadcast bucket size in MiB. Default: 25.
+    """
+    def __init__(self,
+                 module,
+                 dim=0,
+                 broadcast_buffers=True,
+                 bucket_cap_mb=25):
+        super().__init__()
+        self.module = module
+        self.dim = dim
+        self.broadcast_buffers = broadcast_buffers
+        # Bucket size is kept in bytes.
+        self.broadcast_bucket_size = bucket_cap_mb * 1024 * 1024
+        self._sync_params()
+    def _dist_broadcast_coalesced(self, tensors, buffer_size):
+        """Broadcast ``tensors`` from rank 0 in flattened buckets."""
+        for bucket in _take_tensors(tensors, buffer_size):
+            flat = _flatten_dense_tensors(bucket)
+            dist.broadcast(flat, 0)
+            # Copy the received values back into the original tensors.
+            restored = _unflatten_dense_tensors(flat, bucket)
+            for target, synced in zip(bucket, restored):
+                target.copy_(synced)
+    def _sync_params(self):
+        """Broadcast the module state and, optionally, its buffers."""
+        states = list(self.module.state_dict().values())
+        if states:
+            self._dist_broadcast_coalesced(states,
+                                           self.broadcast_bucket_size)
+        if not self.broadcast_buffers:
+            return
+        legacy_api = (
+            TORCH_VERSION != 'parrots'
+            and digit_version(TORCH_VERSION) < digit_version('1.0'))
+        buffer_iter = (
+            self.module._all_buffers() if legacy_api else
+            self.module.buffers())
+        buffers = [b.data for b in buffer_iter]
+        if buffers:
+            self._dist_broadcast_coalesced(buffers,
+                                           self.broadcast_bucket_size)
+    def scatter(self, inputs, kwargs, device_ids):
+        """Scatter positional and keyword inputs with DataContainer support."""
+        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
+    def _scatter_to_current(self, inputs, kwargs):
+        """Scatter to the current CUDA device only."""
+        return self.scatter(inputs, kwargs, [torch.cuda.current_device()])
+    def forward(self, *inputs, **kwargs):
+        inputs, kwargs = self._scatter_to_current(inputs, kwargs)
+        return self.module(*inputs[0], **kwargs[0])
+    def train_step(self, *inputs, **kwargs):
+        inputs, kwargs = self._scatter_to_current(inputs, kwargs)
+        return self.module.train_step(*inputs[0], **kwargs[0])
+    def val_step(self, *inputs, **kwargs):
+        inputs, kwargs = self._scatter_to_current(inputs, kwargs)
+        return self.module.val_step(*inputs[0], **kwargs[0])
diff --git a/ControlNet/annotator/uniformer/mmcv/parallel/registry.py b/ControlNet/annotator/uniformer/mmcv/parallel/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..a204a07fba10e614223f090d1a57cf9c4d74d4a1
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/parallel/registry.py
@@ -0,0 +1,8 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from torch.nn.parallel import DataParallel, DistributedDataParallel
+from annotator.uniformer.mmcv.utils import Registry
+# Registry of classes that count as module wrappers.
+# NOTE(review): DataParallel and DistributedDataParallel are imported above
+# but no registration call is visible in this file -- confirm they are
+# registered, otherwise those imports are unused.
+MODULE_WRAPPERS = Registry('module wrapper')
diff --git a/ControlNet/annotator/uniformer/mmcv/parallel/scatter_gather.py b/ControlNet/annotator/uniformer/mmcv/parallel/scatter_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..900ff88566f8f14830590459dc4fd16d4b382e47
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/parallel/scatter_gather.py
@@ -0,0 +1,59 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch.nn.parallel._functions import Scatter as OrigScatter
+from ._functions import Scatter
+from .data_container import DataContainer
+def scatter(inputs, target_gpus, dim=0):
+ """Scatter inputs to target gpus.
+ The only difference from original :func:`scatter` is to add support for
+ :type:`~mmcv.parallel.DataContainer`.
+ Args:
+ inputs: Tensor, DataContainer, tuple, list, dict or any other object.
+ target_gpus (list[int]): Target device ids; ``[-1]`` selects the
+ self-implemented CPU path for tensors.
+ dim (int): Dimension passed to the original scatter for tensors.
+ Returns:
+ The value produced by the recursive ``scatter_map`` helper.
+ """
+ def scatter_map(obj):
+ # Tensors: original scatter on GPU targets, own implementation otherwise.
+ if isinstance(obj, torch.Tensor):
+ if target_gpus != [-1]:
+ return OrigScatter.apply(target_gpus, None, dim, obj)
+ else:
+ # for CPU inference we use self-implemented scatter
+ return Scatter.forward(target_gpus, obj)
+ # DataContainers: cpu_only payloads are returned as-is, others are
+ # scattered with the self-implemented Scatter.
+ if isinstance(obj, DataContainer):
+ if obj.cpu_only:
+ return obj.data
+ else:
+ return Scatter.forward(target_gpus, obj.data)
+ # Non-empty containers: scatter each element, then regroup with zip.
+ if isinstance(obj, tuple) and len(obj) > 0:
+ return list(zip(*map(scatter_map, obj)))
+ if isinstance(obj, list) and len(obj) > 0:
+ out = list(map(list, zip(*map(scatter_map, obj))))
+ return out
+ if isinstance(obj, dict) and len(obj) > 0:
+ out = list(map(type(obj), zip(*map(scatter_map, obj.items()))))
+ return out
+ # Anything else (including empty containers) is repeated once per target.
+ return [obj for targets in target_gpus]
+ # After scatter_map is called, a scatter_map cell will exist. This cell
+ # has a reference to the actual function scatter_map, which has references
+ # to a closure that has a reference to the scatter_map cell (because the
+ # fn is recursive). To avoid this reference cycle, we set the function to
+ # None, clearing the cell
+ try:
+ return scatter_map(inputs)
+ finally:
+ scatter_map = None
+def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
+    """Scatter with support for kwargs dictionary.
+    Args:
+        inputs: Positional arguments to scatter; falsy values become ``[]``.
+        kwargs: Keyword arguments to scatter; falsy values become ``[]``.
+        target_gpus (list[int]): Target device ids.
+        dim (int): Dimension forwarded to ``scatter``. Default: 0.
+    Returns:
+        tuple: ``(inputs, kwargs)`` as two tuples of equal length, the
+        shorter one padded with ``()`` or ``{}``.
+    """
+    scattered_args = scatter(inputs, target_gpus, dim) if inputs else []
+    scattered_kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
+    # Positive: positional side is shorter; negative: keyword side is.
+    shortfall = len(scattered_kwargs) - len(scattered_args)
+    if shortfall > 0:
+        scattered_args.extend(() for _ in range(shortfall))
+    elif shortfall < 0:
+        scattered_kwargs.extend({} for _ in range(-shortfall))
+    return tuple(scattered_args), tuple(scattered_kwargs)
diff --git a/ControlNet/annotator/uniformer/mmcv/parallel/utils.py b/ControlNet/annotator/uniformer/mmcv/parallel/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f5712cb42c38a2e8563bf563efb6681383cab9b
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/parallel/utils.py
@@ -0,0 +1,20 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .registry import MODULE_WRAPPERS
+def is_module_wrapper(module):
+    """Check if a module is a module wrapper.
+    A module counts as a wrapper when it is an instance of any class stored
+    in ``MODULE_WRAPPERS.module_dict``. According to the original note these
+    are DataParallel, DistributedDataParallel and MMDistributedDataParallel
+    (the deprecated version), plus their subclasses. You may add your own
+    module wrapper by registering it to mmcv.parallel.MODULE_WRAPPERS.
+    Args:
+        module (nn.Module): The module to be checked.
+    Returns:
+        bool: True if the input module is a module wrapper.
+    """
+    registered = MODULE_WRAPPERS.module_dict.values()
+    return isinstance(module, tuple(registered))
diff --git a/ControlNet/annotator/uniformer/mmcv/runner/__init__.py b/ControlNet/annotator/uniformer/mmcv/runner/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..52e4b48d383a84a055dcd7f6236f6e8e58eab924
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/runner/__init__.py
@@ -0,0 +1,47 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base_module import BaseModule, ModuleList, Sequential
+from .base_runner import BaseRunner
+from .builder import RUNNERS, build_runner
+from .checkpoint import (CheckpointLoader, _load_checkpoint,
+ _load_checkpoint_with_prefix, load_checkpoint,
+ load_state_dict, save_checkpoint, weights_to_cpu)
+from .default_constructor import DefaultRunnerConstructor
+from .dist_utils import (allreduce_grads, allreduce_params, get_dist_info,
+ init_dist, master_only)
+from .epoch_based_runner import EpochBasedRunner, Runner
+from .fp16_utils import LossScaler, auto_fp16, force_fp32, wrap_fp16_model
+from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistEvalHook,
+ DistSamplerSeedHook, DvcliveLoggerHook, EMAHook, EvalHook,
+ Fp16OptimizerHook, GradientCumulativeFp16OptimizerHook,
+ GradientCumulativeOptimizerHook, Hook, IterTimerHook,
+ LoggerHook, LrUpdaterHook, MlflowLoggerHook,
+ NeptuneLoggerHook, OptimizerHook, PaviLoggerHook,
+ SyncBuffersHook, TensorboardLoggerHook, TextLoggerHook,
+ WandbLoggerHook)
+from .iter_based_runner import IterBasedRunner, IterLoader
+from .log_buffer import LogBuffer
+from .optimizer import (OPTIMIZER_BUILDERS, OPTIMIZERS,
+ DefaultOptimizerConstructor, build_optimizer,
+ build_optimizer_constructor)
+from .priority import Priority, get_priority
+from .utils import get_host_info, get_time_str, obj_from_dict, set_random_seed
+# Public API of the ``runner`` package.
+__all__ = [
+    'BaseRunner', 'Runner', 'EpochBasedRunner', 'IterBasedRunner', 'LogBuffer',
+    'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
+    'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'LoggerHook',
+    'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
+    'NeptuneLoggerHook', 'WandbLoggerHook', 'MlflowLoggerHook',
+    'DvcliveLoggerHook', '_load_checkpoint', 'load_state_dict',
+    'load_checkpoint', 'weights_to_cpu', 'save_checkpoint', 'Priority',
+    'get_priority', 'get_host_info', 'get_time_str', 'obj_from_dict',
+    'init_dist', 'get_dist_info', 'master_only', 'OPTIMIZER_BUILDERS',
+    'OPTIMIZERS', 'DefaultOptimizerConstructor', 'build_optimizer',
+    'build_optimizer_constructor', 'IterLoader', 'set_random_seed',
+    'auto_fp16', 'force_fp32', 'wrap_fp16_model', 'Fp16OptimizerHook',
+    'SyncBuffersHook', 'EMAHook', 'build_runner', 'RUNNERS', 'allreduce_grads',
+    'allreduce_params', 'LossScaler', 'CheckpointLoader', 'BaseModule',
+    '_load_checkpoint_with_prefix', 'EvalHook', 'DistEvalHook', 'Sequential',
+    'ModuleList', 'GradientCumulativeOptimizerHook',
+    'GradientCumulativeFp16OptimizerHook', 'DefaultRunnerConstructor'
+]
diff --git a/ControlNet/annotator/uniformer/mmcv/runner/base_module.py b/ControlNet/annotator/uniformer/mmcv/runner/base_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..617fad9bb89f10a9a0911d962dfb3bc8f3a3628c
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/runner/base_module.py
@@ -0,0 +1,195 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import warnings
+from abc import ABCMeta
+from collections import defaultdict
+from logging import FileHandler
+import torch.nn as nn
+from annotator.uniformer.mmcv.runner.dist_utils import master_only
+from annotator.uniformer.mmcv.utils.logging import get_logger, logger_initialized, print_log
+class BaseModule(nn.Module, metaclass=ABCMeta):
+ """Base module for all modules in openmmlab.
+ ``BaseModule`` is a wrapper of ``torch.nn.Module`` with additional
+ functionality of parameter initialization. Compared with
+ ``torch.nn.Module``, ``BaseModule`` mainly adds three attributes.
+ - ``init_cfg``: the config to control the initialization.
+ - ``init_weights``: The function of parameter
+ initialization and recording initialization
+ information.
+ - ``_params_init_info``: Used to track the parameter
+ initialization information. This attribute only
+ exists during executing the ``init_weights``.
+ Args:
+ init_cfg (dict, optional): Initialization config dict. A deep copy
+ is stored on the instance, so the caller's object is not shared.
+ """
+ def __init__(self, init_cfg=None):
+ """Initialize BaseModule, inherited from `torch.nn.Module`"""
+ # NOTE init_cfg can be defined in different levels, but init_cfg
+ # in low levels has a higher priority.
+ super(BaseModule, self).__init__()
+ # define default value of init_cfg instead of hard code
+ # in init_weights() function
+ # `_is_init` is set to True by `init_weights()` and is exposed
+ # read-only through the `is_init` property.
+ self._is_init = False
+ self.init_cfg = copy.deepcopy(init_cfg)
+ # Backward compatibility in derived classes
+ # if pretrained is not None:
+ # warnings.warn('DeprecationWarning: pretrained is a deprecated \
+ # key, please consider using init_cfg')
+ # self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+ @property
+ def is_init(self):
+ """bool: Whether ``init_weights`` has already run on this module."""
+ return self._is_init
+ def init_weights(self):
+ """Initialize the weights.
+ Applies ``init_cfg`` (if any) to this module through ``initialize``,
+ then calls ``init_weights`` on every direct child that defines it.
+ The call on which ``_params_init_info`` does not exist yet is treated
+ as the top-level call: it creates the shared record, dumps it through
+ ``_dump_init_info`` and removes it from all submodules afterwards.
+ A warning is issued when the module has already been initialized.
+ NOTE(review): the ``return`` taken for a ``'Pretrained'`` init_cfg
+ appears to leave this method before the dump/cleanup at the end;
+ the indentation is not recoverable from this view, so confirm.
+ """
+ is_top_level_module = False
+ # check if it is top-level module
+ if not hasattr(self, '_params_init_info'):
+ # The `_params_init_info` is used to record the initialization
+ # information of the parameters
+ # the key should be the obj:`nn.Parameter` of model and the value
+ # should be a dict containing
+ # - init_info (str): The string that describes the initialization.
+ # - tmp_mean_value (FloatTensor): The mean of the parameter,
+ # which indicates whether the parameter has been modified.
+ # this attribute would be deleted after all parameters
+ # is initialized.
+ self._params_init_info = defaultdict(dict)
+ is_top_level_module = True
+ # Initialize the `_params_init_info`,
+ # When detecting the `tmp_mean_value` of
+ # the corresponding parameter is changed, update related
+ # initialization information
+ for name, param in self.named_parameters():
+ self._params_init_info[param][
+ 'init_info'] = f'The value is the same before and ' \
+ f'after calling `init_weights` ' \
+ f'of {self.__class__.__name__} '
+ self._params_init_info[param][
+ 'tmp_mean_value'] = param.data.mean()
+ # pass `params_init_info` to all submodules
+ # All submodules share the same `params_init_info`,
+ # so it will be updated when parameters are
+ # modified at any level of the model.
+ for sub_module in self.modules():
+ sub_module._params_init_info = self._params_init_info
+ # Get the initialized logger, if not exist,
+ # create a logger named `mmcv`
+ logger_names = list(logger_initialized.keys())
+ logger_name = logger_names[0] if logger_names else 'mmcv'
+ # Imported inside the method rather than at module level
+ # (presumably to avoid a circular import with the cnn package).
+ from ..cnn import initialize
+ from ..cnn.utils.weight_init import update_init_info
+ module_name = self.__class__.__name__
+ if not self._is_init:
+ if self.init_cfg:
+ print_log(
+ f'initialize {module_name} with init_cfg {self.init_cfg}',
+ logger=logger_name)
+ initialize(self, self.init_cfg)
+ if isinstance(self.init_cfg, dict):
+ # prevent the parameters of
+ # the pre-trained model
+ # from being overwritten by
+ # the `init_weights`
+ if self.init_cfg['type'] == 'Pretrained':
+ return
+ for m in self.children():
+ if hasattr(m, 'init_weights'):
+ m.init_weights()
+ # users may overload the `init_weights`
+ update_init_info(
+ m,
+ init_info=f'Initialized by '
+ f'user-defined `init_weights`'
+ f' in {m.__class__.__name__} ')
+ self._is_init = True
+ else:
+ warnings.warn(f'init_weights of {self.__class__.__name__} has '
+ f'been called more than once.')
+ if is_top_level_module:
+ self._dump_init_info(logger_name)
+ # The record only lives for the duration of the top-level call.
+ for sub_module in self.modules():
+ del sub_module._params_init_info
+ @master_only
+ def _dump_init_info(self, logger_name):
+ """Dump the per-parameter initialization information.
+ For every ``FileHandler`` attached to the logger, the information is
+ written directly to that handler's stream. When the logger has no
+ ``FileHandler``, each entry is emitted through ``print_log`` instead.
+ Args:
+ logger_name (str): The name of logger.
+ """
+ logger = get_logger(logger_name)
+ with_file_handler = False
+ # dump the information to the logger file if there is a `FileHandler`
+ for handler in logger.handlers:
+ if isinstance(handler, FileHandler):
+ handler.stream.write(
+ 'Name of parameter - Initialization information\n')
+ for name, param in self.named_parameters():
+ handler.stream.write(
+ f'\n{name} - {param.shape}: '
+ f"\n{self._params_init_info[param]['init_info']} \n")
+ handler.stream.flush()
+ with_file_handler = True
+ if not with_file_handler:
+ for name, param in self.named_parameters():
+ print_log(
+ f'\n{name} - {param.shape}: '
+ f"\n{self._params_init_info[param]['init_info']} \n ",
+ logger=logger_name)
+ def __repr__(self):
+ """Return the ``nn.Module`` repr, followed by ``init_cfg`` when set."""
+ s = super().__repr__()
+ if self.init_cfg:
+ s += f'\ninit_cfg={self.init_cfg}'
+ return s
+class Sequential(BaseModule, nn.Sequential):
+ """Sequential module in openmmlab.
+ Combines ``BaseModule`` (``init_cfg`` handling) with ``nn.Sequential``.
+ Args:
+ *args: Positional arguments forwarded to ``nn.Sequential.__init__``.
+ init_cfg (dict, optional): Initialization config dict.
+ """
+ def __init__(self, *args, init_cfg=None):
+ # Both bases are initialised explicitly: BaseModule receives the
+ # init_cfg, nn.Sequential receives the positional arguments.
+ BaseModule.__init__(self, init_cfg)
+ nn.Sequential.__init__(self, *args)
+class ModuleList(BaseModule, nn.ModuleList):
+ """ModuleList in openmmlab.
+ Combines ``BaseModule`` (``init_cfg`` handling) with ``nn.ModuleList``.
+ Args:
+ modules (iterable, optional): an iterable of modules to add.
+ init_cfg (dict, optional): Initialization config dict.
+ """
+ def __init__(self, modules=None, init_cfg=None):
+ # Both bases are initialised explicitly: BaseModule receives the
+ # init_cfg, nn.ModuleList receives the modules.
+ BaseModule.__init__(self, init_cfg)
+ nn.ModuleList.__init__(self, modules)
diff --git a/ControlNet/annotator/uniformer/mmcv/runner/base_runner.py b/ControlNet/annotator/uniformer/mmcv/runner/base_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..4928db0a73b56fe0218a4bf66ec4ffa082d31ccc
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/runner/base_runner.py
@@ -0,0 +1,542 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import logging
+import os.path as osp
+import warnings
+from abc import ABCMeta, abstractmethod
+import torch
+from torch.optim import Optimizer
+import annotator.uniformer.mmcv as mmcv
+from ..parallel import is_module_wrapper
+from .checkpoint import load_checkpoint
+from .dist_utils import get_dist_info
+from .hooks import HOOKS, Hook
+from .log_buffer import LogBuffer
+from .priority import Priority, get_priority
+from .utils import get_time_str
+class BaseRunner(metaclass=ABCMeta):
+ """The base class of Runner, a training helper for PyTorch.
+ All subclasses should implement the following APIs:
+ - ``run()``
+ - ``train()``
+ - ``val()``
+ - ``save_checkpoint()``
+ Args:
+ model (:obj:`torch.nn.Module`): The model to be run.
+ batch_processor (callable): A callable method that process a data
+ batch. The interface of this method should be
+ `batch_processor(model, data, train_mode) -> dict`
+ optimizer (dict or :obj:`torch.optim.Optimizer`): It can be either an
+ optimizer (in most cases) or a dict of optimizers (in models that
+ requires more than one optimizer, e.g., GAN).
+ work_dir (str, optional): The working directory to save checkpoints
+ and logs. Defaults to None.
+ logger (:obj:`logging.Logger`): Logger used during training.
+ Defaults to None. (The default value is just for backward
+ compatibility)
+ meta (dict | None): A dict records some import information such as
+ environment info and seed, which will be logged in logger hook.
+ Defaults to None.
+ max_epochs (int, optional): Total training epochs.
+ max_iters (int, optional): Total training iterations.
+ """
+ def __init__(self,
+ model,
+ batch_processor=None,
+ optimizer=None,
+ work_dir=None,
+ logger=None,
+ meta=None,
+ max_iters=None,
+ max_epochs=None):
+ """Validate the arguments and set up the runner state.
+ See the class docstring for the meaning of each argument.
+ Raises:
+ TypeError: If ``batch_processor`` is given but not callable, if
+ ``optimizer`` is not an ``Optimizer``, a dict of them or
+ ``None``, if ``logger`` is not a ``logging.Logger`` (this
+ includes the default ``None``), if ``meta`` is neither a dict
+ nor ``None``, or if ``work_dir`` is neither a str nor ``None``.
+ RuntimeError: If ``batch_processor`` is given while the model
+ also defines ``train_step`` or ``val_step``.
+ ValueError: If both ``max_epochs`` and ``max_iters`` are set.
+ """
+ if batch_processor is not None:
+ if not callable(batch_processor):
+ raise TypeError('batch_processor must be callable, '
+ f'but got {type(batch_processor)}')
+ warnings.warn('batch_processor is deprecated, please implement '
+ 'train_step() and val_step() in the model instead.')
+ # raise an error is `batch_processor` is not None and
+ # `model.train_step()` exists.
+ if is_module_wrapper(model):
+ _model = model.module
+ else:
+ _model = model
+ if hasattr(_model, 'train_step') or hasattr(_model, 'val_step'):
+ raise RuntimeError(
+ 'batch_processor and model.train_step()/model.val_step() '
+ 'cannot be both available.')
+ else:
+ # Without a batch_processor the model itself must provide train_step.
+ assert hasattr(model, 'train_step')
+ # check the type of `optimizer`
+ if isinstance(optimizer, dict):
+ for name, optim in optimizer.items():
+ if not isinstance(optim, Optimizer):
+ raise TypeError(
+ f'optimizer must be a dict of torch.optim.Optimizers, '
+ f'but optimizer["{name}"] is a {type(optim)}')
+ elif not isinstance(optimizer, Optimizer) and optimizer is not None:
+ raise TypeError(
+ f'optimizer must be a torch.optim.Optimizer object '
+ f'or dict or None, but got {type(optimizer)}')
+ # check the type of `logger`
+ if not isinstance(logger, logging.Logger):
+ raise TypeError(f'logger must be a logging.Logger object, '
+ f'but got {type(logger)}')
+ # check the type of `meta`
+ if meta is not None and not isinstance(meta, dict):
+ raise TypeError(
+ f'meta must be a dict or None, but got {type(meta)}')
+ self.model = model
+ self.batch_processor = batch_processor
+ self.optimizer = optimizer
+ self.logger = logger
+ self.meta = meta
+ # create work_dir
+ if mmcv.is_str(work_dir):
+ self.work_dir = osp.abspath(work_dir)
+ mmcv.mkdir_or_exist(self.work_dir)
+ elif work_dir is None:
+ self.work_dir = None
+ else:
+ raise TypeError('"work_dir" must be a str or None')
+ # get model name from the model class
+ if hasattr(self.model, 'module'):
+ self._model_name = self.model.module.__class__.__name__
+ else:
+ self._model_name = self.model.__class__.__name__
+ self._rank, self._world_size = get_dist_info()
+ self.timestamp = get_time_str()
+ self.mode = None
+ self._hooks = []
+ self._epoch = 0
+ self._iter = 0
+ self._inner_iter = 0
+ if max_epochs is not None and max_iters is not None:
+ raise ValueError(
+ 'Only one of `max_epochs` or `max_iters` can be set.')
+ self._max_epochs = max_epochs
+ self._max_iters = max_iters
+ # TODO: Redesign LogBuffer, it is not flexible and elegant enough
+ self.log_buffer = LogBuffer()
+ @property
+ def model_name(self):
+ """str: Class name of the model; when the model has a ``module``
+ attribute, the class name of ``model.module`` is used instead."""
+ return self._model_name
+ @property
+ def rank(self):
+ """int: Rank of current process (distributed training), as returned
+ by ``get_dist_info()`` when the runner was constructed."""
+ return self._rank
+ @property
+ def world_size(self):
+ """int: Number of processes participating in the job (distributed
+ training), as returned by ``get_dist_info()`` when the runner was
+ constructed."""
+ return self._world_size
+ @property
+ def hooks(self):
+ """list[:obj:`Hook`]: A list of registered hooks. This is the internal
+ list itself, not a copy."""
+ return self._hooks
+ @property
+ def epoch(self):
+ """int: Current epoch. Starts at 0; ``resume()`` restores it from the
+ checkpoint meta."""
+ return self._epoch
+ @property
+ def iter(self):
+ """int: Current iteration. Starts at 0; ``resume()`` restores it from
+ the checkpoint meta."""
+ return self._iter
+ @property
+ def inner_iter(self):
+ """int: Iteration in an epoch. Initialised to 0 by the constructor."""
+ return self._inner_iter
+ @property
+ def max_epochs(self):
+ """int | None: Maximum training epochs, as passed to the constructor."""
+ return self._max_epochs
+ @property
+ def max_iters(self):
+ """int | None: Maximum training iterations, as passed to the
+ constructor."""
+ return self._max_iters
+ @abstractmethod
+ def train(self):
+ """Abstract training entry point; subclasses must implement it."""
+ pass
+ @abstractmethod
+ def val(self):
+ """Abstract validation entry point; subclasses must implement it."""
+ pass
+ @abstractmethod
+ def run(self, data_loaders, workflow, **kwargs):
+ """Abstract driver of the whole run; subclasses must implement it.
+ The exact contract of ``data_loaders`` and ``workflow`` is defined by
+ the concrete runner.
+ """
+ pass
+ @abstractmethod
+ def save_checkpoint(self,
+ out_dir,
+ filename_tmpl,
+ save_optimizer=True,
+ meta=None,
+ create_symlink=True):
+ """Abstract checkpoint writer; subclasses must implement it."""
+ pass
+ def current_lr(self):
+     """Return the learning rate of every parameter group.
+     Returns:
+         list[float] | dict[str, list[float]]: One value per param group.
+             When the runner holds a dict of optimizers, a dict with the
+             same keys is returned, each mapping to such a list.
+     Raises:
+         RuntimeError: If ``self.optimizer`` is neither an optimizer nor
+             a dict.
+     """
+     if isinstance(self.optimizer, torch.optim.Optimizer):
+         return [pg['lr'] for pg in self.optimizer.param_groups]
+     if isinstance(self.optimizer, dict):
+         return {
+             key: [pg['lr'] for pg in opt.param_groups]
+             for key, opt in self.optimizer.items()
+         }
+     raise RuntimeError(
+         'lr is not applicable because optimizer does not exist.')
+ def current_momentum(self):
+     """Get current momentums.
+     For each param group the value is taken from ``'momentum'`` if
+     present, otherwise from the first element of ``'betas'``, otherwise
+     it is reported as 0.
+     Returns:
+         list[float] | dict[str, list[float]]: Current momentums of all
+             param groups. If the runner has a dict of optimizers, this
+             method will return a dict.
+     Raises:
+         RuntimeError: If the runner has no optimizer.
+         TypeError: If ``self.optimizer`` is neither an optimizer nor a
+             dict (previously this fell through to an unbound local).
+     """
+     def _get_momentum(optimizer):
+         momentums = []
+         for group in optimizer.param_groups:
+             if 'momentum' in group:
+                 momentums.append(group['momentum'])
+             elif 'betas' in group:
+                 momentums.append(group['betas'][0])
+             else:
+                 momentums.append(0)
+         return momentums
+     if self.optimizer is None:
+         raise RuntimeError(
+             'momentum is not applicable because optimizer does not exist.')
+     if isinstance(self.optimizer, torch.optim.Optimizer):
+         return _get_momentum(self.optimizer)
+     if isinstance(self.optimizer, dict):
+         return {
+             name: _get_momentum(optim)
+             for name, optim in self.optimizer.items()
+         }
+     raise TypeError(
+         'Optimizer should be dict or torch.optim.Optimizer '
+         f'but got {type(self.optimizer)}')
+ def register_hook(self, hook, priority='NORMAL'):
+     """Insert a hook into the priority-ordered hook list.
+     The list is kept sorted by ascending priority value (a lower value
+     means a higher priority, see :class:`Priority`). A new hook is placed
+     after all hooks whose priority value is less than or equal to its
+     own, so hooks of equal priority keep their registration order.
+     Args:
+         hook (:obj:`Hook`): The hook to be registered.
+         priority (int or str or :obj:`Priority`): Hook priority.
+             Lower value means higher priority.
+     Raises:
+         ValueError: If the hook already has a ``priority`` attribute.
+     """
+     assert isinstance(hook, Hook)
+     if hasattr(hook, 'priority'):
+         raise ValueError('"priority" is a reserved attribute for hooks')
+     priority = get_priority(priority)
+     hook.priority = priority
+     # Scan from the tail for the last hook that does not outrank the new
+     # one; the new hook goes right behind it.
+     for pos in reversed(range(len(self._hooks))):
+         if priority >= self._hooks[pos].priority:
+             self._hooks.insert(pos + 1, hook)
+             break
+     else:
+         # Every existing hook has a larger priority value (or the list
+         # is empty): the new hook goes first.
+         self._hooks.insert(0, hook)
+ def register_hook_from_cfg(self, hook_cfg):
+     """Build a hook from its config and register it.
+     Args:
+         hook_cfg (dict): Hook config. It is copied before use. The
+             optional key 'priority' (default ``'NORMAL'``) is removed
+             from the copy and used as the registration priority; the
+             remaining keys are passed to the ``HOOKS`` registry builder.
+     Notes:
+         The specific hook class to register should not use 'type' and
+         'priority' arguments during initialization.
+     """
+     build_cfg = hook_cfg.copy()
+     hook_priority = build_cfg.pop('priority', 'NORMAL')
+     built_hook = mmcv.build_from_cfg(build_cfg, HOOKS)
+     self.register_hook(built_hook, priority=hook_priority)
+ def call_hook(self, fn_name):
+     """Invoke one stage method on every registered hook, in list order.
+     Args:
+         fn_name (str): The function name in each hook to be called, such as
+             "before_train_epoch". Each hook is called with the runner as
+             its only argument.
+     """
+     for registered in self._hooks:
+         stage_fn = getattr(registered, fn_name)
+         stage_fn(self)
+ def get_hook_info(self):
+     """Return a printable summary of the hooks triggered in each stage.
+     Returns:
+         str: One section per stage of ``Hook.stages`` that has at least
+             one hook, listing the priority and class name of each hook.
+     """
+     lines_by_stage = {stage: [] for stage in Hook.stages}
+     for hook in self.hooks:
+         try:
+             priority = Priority(hook.priority).name
+         except ValueError:
+             # Not one of the named levels: show the raw value instead.
+             priority = hook.priority
+         classname = hook.__class__.__name__
+         line = f'({priority:<12}) {classname:<35}'
+         for stage in hook.get_triggered_stages():
+             lines_by_stage[stage].append(line)
+     sections = [
+         f'{stage}:\n' + '\n'.join(lines_by_stage[stage]) +
+         '\n -------------------- '
+         for stage in Hook.stages
+         if len(lines_by_stage[stage]) > 0
+     ]
+     return '\n'.join(sections)
+ def load_checkpoint(self,
+ filename,
+ map_location='cpu',
+ strict=False,
+ revise_keys=[(r'^module.', '')]):
+ """Load a checkpoint into ``self.model``.
+ Thin wrapper that forwards to the module-level ``load_checkpoint``
+ together with the runner's model and logger, and returns its result.
+ NOTE(review): ``revise_keys`` uses a mutable list as default; it is
+ only passed through here, so confirm the callee never mutates it.
+ """
+ return load_checkpoint(
+ self.model,
+ filename,
+ map_location,
+ strict,
+ self.logger,
+ revise_keys=revise_keys)
+ def resume(self,
+ checkpoint,
+ resume_optimizer=True,
+ map_location='default'):
+ """Restore model weights, progress counters and optimizer state.
+ Args:
+ checkpoint: Passed to ``self.load_checkpoint`` as the file to load.
+ resume_optimizer (bool): Whether to load the optimizer state when
+ the checkpoint contains an ``'optimizer'`` entry.
+ map_location: ``'default'`` maps storages to the current CUDA
+ device when CUDA is available and otherwise uses the default
+ of ``self.load_checkpoint``; any other value is forwarded.
+ Raises:
+ TypeError: If the optimizer state should be restored but
+ ``self.optimizer`` is neither an ``Optimizer`` nor a dict.
+ """
+ if map_location == 'default':
+ if torch.cuda.is_available():
+ device_id = torch.cuda.current_device()
+ checkpoint = self.load_checkpoint(
+ checkpoint,
+ map_location=lambda storage, loc: storage.cuda(device_id))
+ else:
+ checkpoint = self.load_checkpoint(checkpoint)
+ else:
+ checkpoint = self.load_checkpoint(
+ checkpoint, map_location=map_location)
+ # From here on `checkpoint` is the loaded checkpoint, not the path.
+ self._epoch = checkpoint['meta']['epoch']
+ self._iter = checkpoint['meta']['iter']
+ if self.meta is None:
+ self.meta = {}
+ self.meta.setdefault('hook_msgs', {})
+ # load `last_ckpt`, `best_score`, `best_ckpt`, etc. for hook messages
+ self.meta['hook_msgs'].update(checkpoint['meta'].get('hook_msgs', {}))
+ # Re-calculate the number of iterations when resuming
+ # models with different number of GPUs
+ if 'config' in checkpoint['meta']:
+ config = mmcv.Config.fromstring(
+ checkpoint['meta']['config'], file_format='.py')
+ previous_gpu_ids = config.get('gpu_ids', None)
+ if previous_gpu_ids and len(previous_gpu_ids) > 0 and len(
+ previous_gpu_ids) != self.world_size:
+ self._iter = int(self._iter * len(previous_gpu_ids) /
+ self.world_size)
+ self.logger.info('the iteration number is changed due to '
+ 'change of GPU number')
+ # resume meta information meta
+ # NOTE(review): this rebinding replaces the dict whose 'hook_msgs'
+ # were merged above; confirm that discarding the merge is intended.
+ self.meta = checkpoint['meta']
+ if 'optimizer' in checkpoint and resume_optimizer:
+ if isinstance(self.optimizer, Optimizer):
+ self.optimizer.load_state_dict(checkpoint['optimizer'])
+ elif isinstance(self.optimizer, dict):
+ for k in self.optimizer.keys():
+ self.optimizer[k].load_state_dict(
+ checkpoint['optimizer'][k])
+ else:
+ raise TypeError(
+ 'Optimizer should be dict or torch.optim.Optimizer '
+ f'but got {type(self.optimizer)}')
+ self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)
+ def register_lr_hook(self, lr_config):
+     """Register the learning-rate updater hook with ``VERY_HIGH`` priority.
+     Args:
+         lr_config (dict | object | None): ``None`` registers nothing. A
+             dict must contain ``'policy'``; the hook type
+             ``<Policy>LrUpdaterHook`` is derived from it and the hook is
+             built from the ``HOOKS`` registry. The caller's dict is left
+             untouched. Any other value is registered as the hook itself.
+     """
+     if lr_config is None:
+         return
+     if isinstance(lr_config, dict):
+         assert 'policy' in lr_config
+         # Work on a shallow copy: popping 'policy' from the caller's dict
+         # would make a second registration with the same config fail.
+         lr_config = lr_config.copy()
+         policy_type = lr_config.pop('policy')
+         # If the type of policy is all in lower case, e.g., 'cyclic',
+         # then its first letter will be capitalized, e.g., to be 'Cyclic'.
+         # This is for the convenient usage of Lr updater.
+         # Since this is not applicable for `CosineAnnealingLrUpdater`,
+         # the string will not be changed if it contains capital letters.
+         if policy_type == policy_type.lower():
+             policy_type = policy_type.title()
+         lr_config['type'] = policy_type + 'LrUpdaterHook'
+         hook = mmcv.build_from_cfg(lr_config, HOOKS)
+     else:
+         hook = lr_config
+     self.register_hook(hook, priority='VERY_HIGH')
+ def register_momentum_hook(self, momentum_config):
+     """Register the momentum updater hook with ``HIGH`` priority.
+     Args:
+         momentum_config (dict | object | None): ``None`` registers
+             nothing. A dict must contain ``'policy'``; the hook type
+             ``<Policy>MomentumUpdaterHook`` is derived from it and the
+             hook is built from the ``HOOKS`` registry. The caller's dict
+             is left untouched. Any other value is registered as the hook
+             itself.
+     """
+     if momentum_config is None:
+         return
+     if isinstance(momentum_config, dict):
+         assert 'policy' in momentum_config
+         # Work on a shallow copy: popping 'policy' from the caller's dict
+         # would make a second registration with the same config fail.
+         momentum_config = momentum_config.copy()
+         policy_type = momentum_config.pop('policy')
+         # If the type of policy is all in lower case, e.g., 'cyclic',
+         # then its first letter will be capitalized, e.g., to be 'Cyclic'.
+         # This is for the convenient usage of momentum updater.
+         # Since this is not applicable for
+         # `CosineAnnealingMomentumUpdater`,
+         # the string will not be changed if it contains capital letters.
+         if policy_type == policy_type.lower():
+             policy_type = policy_type.title()
+         momentum_config['type'] = policy_type + 'MomentumUpdaterHook'
+         hook = mmcv.build_from_cfg(momentum_config, HOOKS)
+     else:
+         hook = momentum_config
+     self.register_hook(hook, priority='HIGH')
+ def register_optimizer_hook(self, optimizer_config):
+     """Register the optimizer hook with ``ABOVE_NORMAL`` priority.
+     Args:
+         optimizer_config (dict | object | None): ``None`` registers
+             nothing. A dict gets ``'type'`` defaulted to
+             ``'OptimizerHook'`` (in place) and is built from the ``HOOKS``
+             registry. Any other value is registered as the hook itself.
+     """
+     if optimizer_config is None:
+         return
+     hook = optimizer_config
+     if isinstance(optimizer_config, dict):
+         optimizer_config.setdefault('type', 'OptimizerHook')
+         hook = mmcv.build_from_cfg(optimizer_config, HOOKS)
+     self.register_hook(hook, priority='ABOVE_NORMAL')
+ def register_checkpoint_hook(self, checkpoint_config):
+     """Register the checkpoint hook with ``NORMAL`` priority.
+     Args:
+         checkpoint_config (dict | object | None): ``None`` registers
+             nothing. A dict gets ``'type'`` defaulted to
+             ``'CheckpointHook'`` (in place) and is built from the
+             ``HOOKS`` registry. Any other value is registered as the hook
+             itself.
+     """
+     if checkpoint_config is None:
+         return
+     hook = checkpoint_config
+     if isinstance(checkpoint_config, dict):
+         checkpoint_config.setdefault('type', 'CheckpointHook')
+         hook = mmcv.build_from_cfg(checkpoint_config, HOOKS)
+     self.register_hook(hook, priority='NORMAL')
+ def register_logger_hooks(self, log_config):
+     """Build and register every logger hook with ``VERY_LOW`` priority.
+     Args:
+         log_config (dict | None): ``None`` registers nothing. Otherwise
+             ``log_config['hooks']`` is iterated and each entry is built
+             from the ``HOOKS`` registry, with ``log_config['interval']``
+             supplied as the default ``interval`` argument.
+     """
+     if log_config is None:
+         return
+     shared_defaults = dict(interval=log_config['interval'])
+     for hook_cfg in log_config['hooks']:
+         built = mmcv.build_from_cfg(
+             hook_cfg, HOOKS, default_args=dict(shared_defaults))
+         self.register_hook(built, priority='VERY_LOW')
+ def register_timer_hook(self, timer_config):
+     """Register the timer hook with ``LOW`` priority.
+     Args:
+         timer_config (dict | object | None): ``None`` registers nothing.
+             A dict is deep-copied and built from the ``HOOKS`` registry,
+             so the caller's dict is not modified. Any other value is
+             registered as the hook itself.
+     """
+     if timer_config is None:
+         return
+     hook = timer_config
+     if isinstance(timer_config, dict):
+         hook = mmcv.build_from_cfg(copy.deepcopy(timer_config), HOOKS)
+     self.register_hook(hook, priority='LOW')
+ def register_custom_hooks(self, custom_config):
+     """Register user-supplied hooks.
+     Args:
+         custom_config (list | dict | object | None): ``None`` registers
+             nothing. A value that is not a list is treated as a single
+             item. Dict items go through ``register_hook_from_cfg``; all
+             other items are registered directly with ``NORMAL`` priority.
+     """
+     if custom_config is None:
+         return
+     if isinstance(custom_config, list):
+         entries = custom_config
+     else:
+         entries = [custom_config]
+     for entry in entries:
+         if isinstance(entry, dict):
+             self.register_hook_from_cfg(entry)
+         else:
+             self.register_hook(entry, priority='NORMAL')
+ def register_profiler_hook(self, profiler_config):
+     """Register the profiler hook with the default priority.
+     Args:
+         profiler_config (dict | object | None): ``None`` registers
+             nothing. A dict gets ``'type'`` defaulted to
+             ``'ProfilerHook'`` (in place) and is built from the ``HOOKS``
+             registry. Any other value is registered as the hook itself.
+     """
+     if profiler_config is None:
+         return
+     hook = profiler_config
+     if isinstance(profiler_config, dict):
+         profiler_config.setdefault('type', 'ProfilerHook')
+         hook = mmcv.build_from_cfg(profiler_config, HOOKS)
+     self.register_hook(hook)
+ def register_training_hooks(self,
+ lr_config,
+ optimizer_config=None,
+ checkpoint_config=None,
+ log_config=None,
+ momentum_config=None,
+ timer_config=dict(type='IterTimerHook'),
+ custom_hooks_config=None):
+ """Register default and custom hooks for training.
+ Each argument is forwarded to the matching ``register_*`` method of
+ this class; an argument that is ``None`` registers nothing.
+ Default and custom hooks include:
+ +----------------------+-------------------------+
+ | Hooks                | Priority                |
+ +======================+=========================+
+ | LrUpdaterHook        | VERY_HIGH (10)          |
+ +----------------------+-------------------------+
+ | MomentumUpdaterHook  | HIGH (30)               |
+ +----------------------+-------------------------+
+ | OptimizerStepperHook | ABOVE_NORMAL (40)       |
+ +----------------------+-------------------------+
+ | CheckpointSaverHook  | NORMAL (50)             |
+ +----------------------+-------------------------+
+ | IterTimerHook        | LOW (70)                |
+ +----------------------+-------------------------+
+ | LoggerHook(s)        | VERY_LOW (90)           |
+ +----------------------+-------------------------+
+ | CustomHook(s)        | defaults to NORMAL (50) |
+ +----------------------+-------------------------+
+ If custom hooks have same priority with default hooks, custom hooks
+ will be triggered after default hooks.
+ """
+ # Custom hooks are registered last, which is what places them after
+ # default hooks of equal priority.
+ self.register_lr_hook(lr_config)
+ self.register_momentum_hook(momentum_config)
+ self.register_optimizer_hook(optimizer_config)
+ self.register_checkpoint_hook(checkpoint_config)
+ self.register_timer_hook(timer_config)
+ self.register_logger_hooks(log_config)
+ self.register_custom_hooks(custom_hooks_config)
diff --git a/ControlNet/annotator/uniformer/mmcv/runner/builder.py b/ControlNet/annotator/uniformer/mmcv/runner/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..77c96ba0b2f30ead9da23f293c5dc84dd3e4a74f
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/runner/builder.py
@@ -0,0 +1,24 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+from ..utils import Registry
+# Registry named 'runner' (presumably populated with the concrete runner
+# classes elsewhere in the package; confirm against the registrations).
+RUNNERS = Registry('runner')
+# Registry named 'runner builder'; `build_runner_constructor` builds from it.
+RUNNER_BUILDERS = Registry('runner builder')
+def build_runner_constructor(cfg):
+ """Build a runner constructor from ``cfg`` via ``RUNNER_BUILDERS``."""
+ return RUNNER_BUILDERS.build(cfg)
+def build_runner(cfg, default_args=None):
+    """Build a runner from its config.
+    Args:
+        cfg (dict): Runner config. It is deep-copied, so the caller's
+            object is not modified. The optional key ``'constructor'``
+            selects the runner constructor type and defaults to
+            ``'DefaultRunnerConstructor'``.
+        default_args (dict, optional): Forwarded to the runner constructor.
+    Returns:
+        The object returned by calling the built runner constructor.
+    """
+    runner_cfg = copy.deepcopy(cfg)
+    constructor_cfg = dict(
+        type=runner_cfg.pop('constructor', 'DefaultRunnerConstructor'),
+        runner_cfg=runner_cfg,
+        default_args=default_args)
+    return build_runner_constructor(constructor_cfg)()
diff --git a/ControlNet/annotator/uniformer/mmcv/runner/checkpoint.py b/ControlNet/annotator/uniformer/mmcv/runner/checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..b29ca320679164432f446adad893e33fb2b4b29e
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/runner/checkpoint.py
@@ -0,0 +1,707 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import io
+import os
+import os.path as osp
+import pkgutil
+import re
+import time
+import warnings
+from collections import OrderedDict
+from importlib import import_module
+from tempfile import TemporaryDirectory
+import torch
+import torchvision
+from torch.optim import Optimizer
+from torch.utils import model_zoo
+import annotator.uniformer.mmcv as mmcv
+from ..fileio import FileClient
+from ..fileio import load as load_file
+from ..parallel import is_module_wrapper
+from ..utils import mkdir_or_exist
+from .dist_utils import get_dist_info
+# Names of the environment variables consulted by `_get_mmcv_home`.
+ENV_MMCV_HOME = 'MMCV_HOME'
+ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
+DEFAULT_CACHE_DIR = '~/.cache'
+def _get_mmcv_home():
+    """Return the mmcv cache directory, creating it if necessary.
+    The directory is resolved in this order:
+    1. the ``MMCV_HOME`` environment variable, if set;
+    2. ``$XDG_CACHE_HOME/mmcv``, if ``XDG_CACHE_HOME`` is set;
+    3. ``~/.cache/mmcv``.
+    A leading ``~`` is expanded to the user's home directory.
+    Returns:
+        str: Path of the cache directory.
+    """
+    cache_root = os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR)
+    mmcv_home = os.path.expanduser(
+        os.getenv(ENV_MMCV_HOME, os.path.join(cache_root, 'mmcv')))
+    mkdir_or_exist(mmcv_home)
+    return mmcv_home
+def load_state_dict(module, state_dict, strict=False, logger=None):
+ """Load state_dict to a module.
+ This method is modified from :meth:`torch.nn.Module.load_state_dict`.
+ Default value for ``strict`` is set to ``False`` and the message for
+ param mismatch will be shown even if strict is False.
+ Args:
+ module (Module): Module that receives the state_dict.
+ state_dict (OrderedDict): Weights.
+ strict (bool): whether to strictly enforce that the keys
+ in :attr:`state_dict` match the keys returned by this module's
+ :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
+ logger (:obj:`logging.Logger`, optional): Logger to log the error
+ message. If not specified, print function will be used.
+ Raises:
+ RuntimeError: If ``strict`` is True and a mismatch was recorded
+ (only checked on rank 0, see below).
+ """
+ unexpected_keys = []
+ all_missing_keys = []
+ err_msg = []
+ # Copy the dict so the caller's object is not the one handed to the
+ # loading routine; `_metadata` is re-attached to the copy.
+ metadata = getattr(state_dict, '_metadata', None)
+ state_dict = state_dict.copy()
+ if metadata is not None:
+ state_dict._metadata = metadata
+ # use _load_from_state_dict to enable checkpoint version control
+ def load(module, prefix=''):
+ # recursively check parallel module in case that the model has a
+ # complicated structure, e.g., nn.Module(nn.Module(DDP))
+ if is_module_wrapper(module):
+ module = module.module
+ local_metadata = {} if metadata is None else metadata.get(
+ prefix[:-1], {})
+ # `True` is passed as the strict flag regardless of this function's
+ # own `strict`; strictness is enforced further down.
+ module._load_from_state_dict(state_dict, prefix, local_metadata, True,
+ all_missing_keys, unexpected_keys,
+ err_msg)
+ for name, child in module._modules.items():
+ if child is not None:
+ load(child, prefix + name + '.')
+ load(module)
+ load = None # break load->load reference cycle
+ # ignore "num_batches_tracked" of BN layers
+ missing_keys = [
+ key for key in all_missing_keys if 'num_batches_tracked' not in key
+ ]
+ if unexpected_keys:
+ err_msg.append('unexpected key in source '
+ f'state_dict: {", ".join(unexpected_keys)}\n')
+ if missing_keys:
+ err_msg.append(
+ f'missing keys in source state_dict: {", ".join(missing_keys)}\n')
+ # Mismatches are reported (or raised) on rank 0 only.
+ rank, _ = get_dist_info()
+ if len(err_msg) > 0 and rank == 0:
+ err_msg.insert(
+ 0, 'The model and loaded state dict do not match exactly\n')
+ err_msg = '\n'.join(err_msg)
+ if strict:
+ raise RuntimeError(err_msg)
+ elif logger is not None:
+ logger.warning(err_msg)
+ else:
+ print(err_msg)
+def get_torchvision_models():
+    """Collect the ``model_urls`` dicts of all ``torchvision.models`` modules.
+    Sub-packages are skipped; every remaining module is imported and, when
+    it defines ``model_urls``, that mapping is merged into the result.
+    Returns:
+        dict: The merged ``model_urls`` entries.
+    """
+    collected = dict()
+    for _, mod_name, is_package in pkgutil.walk_packages(
+            torchvision.models.__path__):
+        if not is_package:
+            zoo_module = import_module(f'torchvision.models.{mod_name}')
+            if hasattr(zoo_module, 'model_urls'):
+                collected.update(getattr(zoo_module, 'model_urls'))
+    return collected
+def get_external_models():
+    """Load the open-mmlab model URL table.
+    The table shipped with the package (``model_zoo/open_mmlab.json``) is
+    loaded first; entries from ``open_mmlab.json`` in the mmcv home
+    directory, when that file exists, are merged on top of it.
+    Returns:
+        dict: The merged table.
+    """
+    home_dir = _get_mmcv_home()
+    bundled_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')
+    urls = load_file(bundled_path)
+    assert isinstance(urls, dict)
+    override_path = osp.join(home_dir, 'open_mmlab.json')
+    if osp.exists(override_path):
+        overrides = load_file(override_path)
+        assert isinstance(overrides, dict)
+        urls.update(overrides)
+    return urls
+def get_mmcls_models():
+    """Load and return the content of the bundled ``model_zoo/mmcls.json``."""
+    return load_file(osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json'))
+def get_deprecated_model_names():
+    """Load the bundled ``model_zoo/deprecated.json`` table.
+    Returns:
+        dict: The content of the file.
+    """
+    package_dir = mmcv.__path__[0]
+    deprecated = load_file(osp.join(package_dir, 'model_zoo/deprecated.json'))
+    assert isinstance(deprecated, dict)
+    return deprecated
+def _process_mmcls_checkpoint(checkpoint):
+    """Keep only the backbone weights of a checkpoint.
+    Args:
+        checkpoint (dict): Checkpoint with a ``'state_dict'`` mapping.
+    Returns:
+        dict: ``{'state_dict': OrderedDict}`` holding the entries whose key
+            starts with ``'backbone.'``, with that prefix removed. All
+            other entries are dropped.
+    """
+    backbone_weights = OrderedDict(
+        (key[9:], value)
+        for key, value in checkpoint['state_dict'].items()
+        if key.startswith('backbone.'))
+    return dict(state_dict=backbone_weights)
+class CheckpointLoader:
+    """A general checkpoint loader to manage all schemes.
+    Loaders are registered under one or more path prefixes (for example
+    ``'http://'``). ``load_checkpoint`` picks the loader whose prefix
+    matches the start of the given path and falls back to the local-file
+    loader when no registered prefix matches.
+    """
+    # Maps prefix -> loader; kept in reverse lexicographic key order.
+    _schemes = {}
+    @classmethod
+    def _register_scheme(cls, prefixes, loader, force=False):
+        """Store ``loader`` under every prefix and re-sort the table.
+        Raises:
+            KeyError: If a prefix is already registered and ``force`` is
+                False.
+        """
+        if isinstance(prefixes, str):
+            prefixes = [prefixes]
+        else:
+            assert isinstance(prefixes, (list, tuple))
+        for prefix in prefixes:
+            if (prefix not in cls._schemes) or force:
+                cls._schemes[prefix] = loader
+            else:
+                raise KeyError(
+                    f'{prefix} is already registered as a loader backend, '
+                    'add "force=True" if you want to override it')
+        # sort, longer prefixes take priority
+        # (in reverse lexicographic order a prefix that extends another
+        # one always comes before it)
+        cls._schemes = OrderedDict(
+            sorted(cls._schemes.items(), key=lambda t: t[0], reverse=True))
+    @classmethod
+    def register_scheme(cls, prefixes, loader=None, force=False):
+        """Register a loader to CheckpointLoader.
+        This method can be used as a normal class method or a decorator.
+        Args:
+            prefixes (str or list[str] or tuple[str]):
+                The prefix of the registered loader.
+            loader (function, optional): The loader function to be registered.
+                When this method is used as a decorator, loader is None.
+                Defaults to None.
+            force (bool, optional): Whether to override the loader
+                if the prefix has already been registered. Defaults to False.
+        """
+        if loader is not None:
+            cls._register_scheme(prefixes, loader, force=force)
+            return
+        def _register(loader_cls):
+            cls._register_scheme(prefixes, loader_cls, force=force)
+            return loader_cls
+        return _register
+    @classmethod
+    def _get_checkpoint_loader(cls, path):
+        """Finds a loader that supports the given path. Falls back to the local
+        loader if no other loader is found.
+        Args:
+            path (str): checkpoint path
+        Returns:
+            loader (function): checkpoint loader
+        """
+        for p in cls._schemes:
+            if path.startswith(p):
+                return cls._schemes[p]
+        # No registered prefix matched: honour the documented fallback
+        # instead of returning None, which made `load_checkpoint` fail on
+        # `None.__name__` for plain file paths.
+        return load_from_local
+    @classmethod
+    def load_checkpoint(cls, filename, map_location=None, logger=None):
+        """load checkpoint through URL scheme path.
+        Args:
+            filename (str): checkpoint file name with given prefix
+            map_location (str, optional): Same as :func:`torch.load`.
+                Default: None
+            logger (:mod:`logging.Logger`, optional): The logger for message.
+                Default: None
+        Returns:
+            dict or OrderedDict: The loaded checkpoint.
+        """
+        checkpoint_loader = cls._get_checkpoint_loader(filename)
+        class_name = checkpoint_loader.__name__
+        # `[10:]` drops the first ten characters of the loader name
+        # (the length of 'load_from_') for the log message.
+        mmcv.print_log(
+            f'load checkpoint from {class_name[10:]} path: {filename}', logger)
+        return checkpoint_loader(filename, map_location)
+def load_from_local(filename, map_location):
+    """Load a checkpoint from a local file.
+    NOTE(review): unlike ``load_from_http``, no ``register_scheme``
+    decorator is visible on this loader in this view; confirm it is
+    registered as the catch-all loader.
+    Args:
+        filename (str): local checkpoint file path
+        map_location (str, optional): Same as :func:`torch.load`.
+    Returns:
+        dict or OrderedDict: The loaded checkpoint.
+    Raises:
+        IOError: If ``filename`` is not an existing regular file.
+    """
+    if osp.isfile(filename):
+        return torch.load(filename, map_location=map_location)
+    raise IOError(f'{filename} is not a checkpoint file')
+@CheckpointLoader.register_scheme(prefixes=('http://', 'https://'))
+def load_from_http(filename, map_location=None, model_dir=None):
+ """load checkpoint through HTTP or HTTPS scheme path. In distributed
+ setting, this function only download checkpoint at local rank 0.
+ Args:
+ filename (str): checkpoint file path with modelzoo or
+ torchvision prefix
+ map_location (str, optional): Same as :func:`torch.load`.
+ model_dir (string, optional): directory in which to save the object,
+ Default: None
+ Returns:
+ dict or OrderedDict: The loaded checkpoint.
+ """
+ rank, world_size = get_dist_info()
+ # The LOCAL_RANK environment variable, when set, overrides the rank
+ # reported by get_dist_info().
+ rank = int(os.environ.get('LOCAL_RANK', rank))
+ if rank == 0:
+ checkpoint = model_zoo.load_url(
+ filename, model_dir=model_dir, map_location=map_location)
+ # Ranks other than 0 call load_url only after the barrier (presumably so
+ # they find the file already downloaded by rank 0; confirm).
+ # NOTE(review): whether the `rank > 0` branch is nested under the
+ # `world_size > 1` check cannot be told from this flattened view.
+ if world_size > 1:
+ torch.distributed.barrier()
+ if rank > 0:
+ checkpoint = model_zoo.load_url(
+ filename, model_dir=model_dir, map_location=map_location)
+ return checkpoint
+@CheckpointLoader.register_scheme(prefixes=('pavi://', ))
+def load_from_pavi(filename, map_location=None):
+    """Load a checkpoint from a path prefixed with ``pavi://``.
+
+    The model file is downloaded into a fresh temporary directory on every
+    call and loaded from there.
+
+    Args:
+        filename (str): checkpoint file path with pavi prefix
+        map_location (str, optional): Same as :func:`torch.load`.
+            Default: None
+
+    Returns:
+        dict or OrderedDict: The loaded checkpoint.
+
+    Raises:
+        ImportError: If the ``pavi`` package cannot be imported.
+    """
+    assert filename.startswith('pavi://'), \
+        f'Expected filename startswith `pavi://`, but get {filename}'
+    # Strip the 7-character 'pavi://' prefix.
+    model_path = filename[7:]
+    try:
+        from pavi import modelcloud
+    except ImportError as err:
+        raise ImportError(
+            'Please install pavi to load checkpoint from modelcloud.') from err
+    model = modelcloud.get(model_path)
+    with TemporaryDirectory() as tmp_dir:
+        downloaded_file = osp.join(tmp_dir, model.name)
+        model.download(downloaded_file)
+        # Load while the temporary directory still exists.
+        checkpoint = torch.load(downloaded_file, map_location=map_location)
+    return checkpoint
+@CheckpointLoader.register_scheme(prefixes=('s3://', ))
+def load_from_ceph(filename, map_location=None, backend='petrel'):
+    """Load a checkpoint from a path prefixed with ``s3://``.
+
+    The file content is fetched through a ``FileClient`` and loaded from an
+    in-memory buffer.
+
+    Args:
+        filename (str): checkpoint file path with s3 prefix
+        map_location (str, optional): Same as :func:`torch.load`.
+        backend (str, optional): The storage backend type. Options are 'ceph',
+            'petrel'. Default: 'petrel'.
+
+    .. warning::
+        :class:`mmcv.fileio.file_client.CephBackend` will be deprecated,
+        please use :class:`mmcv.fileio.file_client.PetrelBackend` instead.
+
+    Returns:
+        dict or OrderedDict: The loaded checkpoint.
+
+    Raises:
+        ValueError: If ``backend`` is neither 'ceph' nor 'petrel'.
+    """
+    allowed_backends = ['ceph', 'petrel']
+    if backend not in allowed_backends:
+        raise ValueError(f'Load from Backend {backend} is not supported.')
+    if backend == 'ceph':
+        warnings.warn(
+            'CephBackend will be deprecated, please use PetrelBackend instead')
+    # CephClient and PetrelBackend have the same prefix 's3://' and the latter
+    # will be chosen as default. If the requested backend can not be
+    # instantiated because of a missing import, fall back to the other one.
+    try:
+        file_client = FileClient(backend=backend)
+    except ImportError:
+        allowed_backends.remove(backend)
+        file_client = FileClient(backend=allowed_backends[0])
+    with io.BytesIO(file_client.get(filename)) as buffer:
+        checkpoint = torch.load(buffer, map_location=map_location)
+    return checkpoint
+@CheckpointLoader.register_scheme(prefixes=('modelzoo://', 'torchvision://'))
+def load_from_torchvision(filename, map_location=None):
+    """Load a checkpoint named with a ``modelzoo://`` or ``torchvision://``
+    prefix.
+
+    The part after the prefix is looked up in the mapping returned by
+    ``get_torchvision_models()`` and the resulting URL is loaded over HTTP.
+
+    Args:
+        filename (str): checkpoint file path with modelzoo or
+            torchvision prefix
+        map_location (str, optional): Same as :func:`torch.load`.
+
+    Returns:
+        dict or OrderedDict: The loaded checkpoint.
+    """
+    model_urls = get_torchvision_models()
+    legacy_prefix = 'modelzoo://'
+    if filename.startswith(legacy_prefix):
+        warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
+                      'use "torchvision://" instead')
+        model_name = filename[len(legacy_prefix):]
+    else:
+        model_name = filename[len('torchvision://'):]
+    url = model_urls[model_name]
+    return load_from_http(url, map_location=map_location)
+@CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://'))
+def load_from_openmmlab(filename, map_location=None):
+    """Load a checkpoint named with an ``open-mmlab://`` or ``openmmlab://``
+    prefix.
+
+    Deprecated model names are mapped to their replacement (with a warning)
+    before the lookup. The resolved entry is either an HTTP(S) URL or a path
+    relative to the mmcv home directory.
+
+    Args:
+        filename (str): checkpoint file path with open-mmlab or
+            openmmlab prefix
+        map_location (str, optional): Same as :func:`torch.load`.
+            Default: None
+
+    Returns:
+        dict or OrderedDict: The loaded checkpoint.
+
+    Raises:
+        IOError: If the entry resolves to a local path that is not a file.
+    """
+    model_urls = get_external_models()
+    prefix_str = 'open-mmlab://'
+    if not filename.startswith(prefix_str):
+        prefix_str = 'openmmlab://'
+    model_name = filename[len(prefix_str):]
+
+    deprecated_urls = get_deprecated_model_names()
+    if model_name in deprecated_urls:
+        replacement = deprecated_urls[model_name]
+        warnings.warn(f'{prefix_str}{model_name} is deprecated in favor '
+                      f'of {prefix_str}{replacement}')
+        model_name = replacement
+
+    model_url = model_urls[model_name]
+    # Remote entries go through the HTTP loader; everything else is a file
+    # below the mmcv home directory.
+    if model_url.startswith(('http://', 'https://')):
+        return load_from_http(model_url, map_location=map_location)
+    local_path = osp.join(_get_mmcv_home(), model_url)
+    if not osp.isfile(local_path):
+        raise IOError(f'{local_path} is not a checkpoint file')
+    return torch.load(local_path, map_location=map_location)
+@CheckpointLoader.register_scheme(prefixes=('mmcls://', ))
+def load_from_mmcls(filename, map_location=None):
+    """Load a checkpoint from a path prefixed with ``mmcls://``.
+
+    The model name after the prefix is looked up in the mapping returned by
+    ``get_mmcls_models()``, downloaded over HTTP and then passed through
+    ``_process_mmcls_checkpoint``.
+
+    Args:
+        filename (str): checkpoint file path with mmcls prefix
+        map_location (str, optional): Same as :func:`torch.load`.
+
+    Returns:
+        dict or OrderedDict: The loaded checkpoint.
+    """
+    model_urls = get_mmcls_models()
+    # Strip the 8-character 'mmcls://' prefix.
+    model_name = filename[8:]
+    checkpoint = load_from_http(
+        model_urls[model_name], map_location=map_location)
+    return _process_mmcls_checkpoint(checkpoint)
+def _load_checkpoint(filename, map_location=None, logger=None):
+    """Load a checkpoint by delegating to ``CheckpointLoader``.
+
+    Args:
+        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
+            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
+            details.
+        map_location (str, optional): Same as :func:`torch.load`.
+            Default: None.
+        logger (:mod:`logging.Logger`, optional): The logger for error message.
+            Default: None
+
+    Returns:
+        dict or OrderedDict: The loaded checkpoint. It can be either an
+            OrderedDict storing model weights or a dict containing other
+            information, which depends on the checkpoint.
+    """
+    loaded = CheckpointLoader.load_checkpoint(
+        filename, map_location=map_location, logger=logger)
+    return loaded
+def _load_checkpoint_with_prefix(prefix, filename, map_location=None):
+    """Load only the weights of one sub-module from a checkpoint.
+
+    Keys that start with ``prefix`` (a trailing ``.`` is appended when
+    missing) are kept and returned with that prefix removed.
+
+    Args:
+        prefix (str): The prefix of sub-module.
+        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
+            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
+            details.
+        map_location (str | None): Same as :func:`torch.load`. Default: None.
+
+    Returns:
+        dict or OrderedDict: The loaded checkpoint.
+    """
+    checkpoint = _load_checkpoint(filename, map_location=map_location)
+    weights = (
+        checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint)
+    if not prefix.endswith('.'):
+        prefix += '.'
+    offset = len(prefix)
+    selected = {}
+    for key, value in weights.items():
+        if key.startswith(prefix):
+            selected[key[offset:]] = value
+    assert selected, f'{prefix} is not in the pretrained model'
+    return selected
+def load_checkpoint(model,
+                    filename,
+                    map_location=None,
+                    strict=False,
+                    logger=None,
+                    revise_keys=[(r'^module\.', '')]):
+    """Load a checkpoint from a file or URI into ``model``.
+
+    Args:
+        model (Module): Module to load checkpoint.
+        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
+            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
+            details.
+        map_location (str): Same as :func:`torch.load`.
+        strict (bool): Whether to allow different params for the model and
+            checkpoint.
+        logger (:mod:`logging.Logger` or None): The logger for error message.
+        revise_keys (list): A list of customized keywords to modify the
+            state_dict in checkpoint. Each item is a (pattern, replacement)
+            pair of the regular expression operations. Default: strip
+            the prefix 'module.' by [(r'^module\\.', '')].
+
+    Returns:
+        dict or OrderedDict: The loaded checkpoint.
+
+    Raises:
+        RuntimeError: If the loaded object is not a dict.
+    """
+    checkpoint = _load_checkpoint(filename, map_location, logger)
+    # OrderedDict is a subclass of dict, so this accepts both.
+    if not isinstance(checkpoint, dict):
+        raise RuntimeError(
+            f'No state_dict found in checkpoint file {filename}')
+    # The weights are either nested under 'state_dict' or are the
+    # checkpoint itself.
+    if 'state_dict' in checkpoint:
+        state_dict = checkpoint['state_dict']
+    else:
+        state_dict = checkpoint
+
+    # Remember the metadata before the dict is rebuilt below.
+    metadata = getattr(state_dict, '_metadata', OrderedDict())
+    for pattern, replacement in revise_keys:
+        state_dict = OrderedDict(
+            (re.sub(pattern, replacement, key), value)
+            for key, value in state_dict.items())
+    # Re-attach the metadata to the (possibly new) state_dict.
+    state_dict._metadata = metadata
+
+    load_state_dict(model, state_dict, strict, logger)
+    return checkpoint
+def weights_to_cpu(state_dict):
+    """Copy a model state_dict to cpu.
+
+    Args:
+        state_dict (OrderedDict): Model weights, possibly on GPU.
+
+    Returns:
+        OrderedDict: The same keys, with every value moved to CPU via
+            ``.cpu()``. The ``_metadata`` attribute of the input is carried
+            over (an empty OrderedDict when the input has none).
+    """
+    cpu_weights = OrderedDict(
+        (name, tensor.cpu()) for name, tensor in state_dict.items())
+    # Keep metadata in state_dict
+    cpu_weights._metadata = getattr(state_dict, '_metadata', OrderedDict())
+    return cpu_weights
+def _save_to_state_dict(module, destination, prefix, keep_vars):
+    """Write the parameters and buffers of ``module`` into ``destination``.
+
+    This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
+
+    Args:
+        module (nn.Module): The module to generate state_dict.
+        destination (dict): A dict where state will be stored.
+        prefix (str): The prefix for parameters and buffers used in this
+            module.
+        keep_vars (bool): When False, detached tensors are stored instead of
+            the original objects.
+    """
+    for param_name, param in module._parameters.items():
+        if param is None:
+            continue
+        destination[prefix + param_name] = (
+            param if keep_vars else param.detach())
+    for buf_name, buf in module._buffers.items():
+        # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
+        if buf is None:
+            continue
+        destination[prefix + buf_name] = buf if keep_vars else buf.detach()
+def get_state_dict(module, destination=None, prefix='', keep_vars=False):
+    """Return a dictionary holding the whole state of ``module``.
+
+    Parameters and buffers are included; keys are the corresponding
+    parameter and buffer names.
+
+    This method is modified from :meth:`torch.nn.Module.state_dict` to
+    recursively check parallel module in case that the model has a complicated
+    structure, e.g., nn.Module(nn.Module(DDP)).
+
+    Args:
+        module (nn.Module): The module to generate state_dict.
+        destination (OrderedDict): Returned dict for the state of the
+            module.
+        prefix (str): Prefix of the key.
+        keep_vars (bool): Whether to keep the variable property of the
+            parameters. Default: False.
+
+    Returns:
+        dict: A dictionary containing a whole state of the module.
+    """
+    # Unwrap a parallel wrapper at every recursion level, so nested wrappers
+    # such as nn.Module(nn.Module(DDP)) are handled as well.
+    if is_module_wrapper(module):
+        module = module.module
+
+    # From here on this mirrors torch.nn.Module.state_dict().
+    if destination is None:
+        destination = OrderedDict()
+        destination._metadata = OrderedDict()
+    local_metadata = dict(version=module._version)
+    destination._metadata[prefix[:-1]] = local_metadata
+    _save_to_state_dict(module, destination, prefix, keep_vars)
+    for child_name, child in module._modules.items():
+        if child is None:
+            continue
+        get_state_dict(
+            child, destination, prefix + child_name + '.',
+            keep_vars=keep_vars)
+    # A hook may hand back a replacement for the destination dict.
+    for hook in module._state_dict_hooks.values():
+        replaced = hook(module, destination, prefix, local_metadata)
+        if replaced is not None:
+            destination = replaced
+    return destination
+def save_checkpoint(model,
+                    filename,
+                    optimizer=None,
+                    meta=None,
+                    file_client_args=None):
+    """Save checkpoint to file.
+
+    The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
+    ``optimizer``. By default ``meta`` will contain version and time info.
+    Note that a ``meta`` dict passed by the caller is updated in place.
+
+    Args:
+        model (Module): Module whose params are to be saved.
+        filename (str): Checkpoint filename.
+        optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. A dict
+            of optimizers is also accepted; each one is saved under its key.
+        meta (dict, optional): Metadata to be saved in checkpoint.
+        file_client_args (dict, optional): Arguments to instantiate a
+            FileClient. See :class:`mmcv.fileio.FileClient` for details.
+            Default: None.
+            `New in version 1.3.16.`
+
+    Raises:
+        TypeError: If ``meta`` is neither a dict nor None.
+        ValueError: If ``file_client_args`` is given together with a
+            ``pavi://`` filename.
+        ImportError: If a ``pavi://`` filename is used without ``pavi``
+            installed.
+    """
+    if meta is None:
+        meta = {}
+    elif not isinstance(meta, dict):
+        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
+    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
+
+    if is_module_wrapper(model):
+        model = model.module
+
+    if hasattr(model, 'CLASSES') and model.CLASSES is not None:
+        # save class name to the meta
+        meta.update(CLASSES=model.CLASSES)
+
+    checkpoint = {
+        'meta': meta,
+        'state_dict': weights_to_cpu(get_state_dict(model))
+    }
+    # save optimizer state dict in the checkpoint
+    if isinstance(optimizer, Optimizer):
+        checkpoint['optimizer'] = optimizer.state_dict()
+    elif isinstance(optimizer, dict):
+        checkpoint['optimizer'] = {
+            name: optim.state_dict()
+            for name, optim in optimizer.items()
+        }
+
+    if filename.startswith('pavi://'):
+        if file_client_args is not None:
+            # The trailing space keeps the two literals from running
+            # together in the rendered message.
+            raise ValueError(
+                'file_client_args should be "None" if filename starts with '
+                f'"pavi://", but got {file_client_args}')
+        try:
+            from pavi import modelcloud
+            from pavi import exception
+        except ImportError as err:
+            raise ImportError(
+                'Please install pavi to save checkpoint to modelcloud.'
+            ) from err
+        # Strip the 7-character 'pavi://' prefix.
+        model_path = filename[7:]
+        root = modelcloud.Folder()
+        model_dir, model_name = osp.split(model_path)
+        try:
+            model = modelcloud.get(model_dir)
+        except exception.NodeNotFoundError:
+            model = root.create_training_model(model_dir)
+        with TemporaryDirectory() as tmp_dir:
+            checkpoint_file = osp.join(tmp_dir, model_name)
+            with open(checkpoint_file, 'wb') as f:
+                torch.save(checkpoint, f)
+                f.flush()
+            # Upload while the temporary file still exists.
+            model.create_file(checkpoint_file, name=model_name)
+    else:
+        file_client = FileClient.infer_client(file_client_args, filename)
+        # Serialise into memory first, then hand the bytes to the client.
+        with io.BytesIO() as f:
+            torch.save(checkpoint, f)
+            file_client.put(f.getvalue(), filename)
diff --git a/ControlNet/annotator/uniformer/mmcv/runner/default_constructor.py b/ControlNet/annotator/uniformer/mmcv/runner/default_constructor.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f1f5b44168768dfda3947393a63a6cf9cf50b41
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/runner/default_constructor.py
@@ -0,0 +1,44 @@
+from .builder import RUNNER_BUILDERS, RUNNERS
+class DefaultRunnerConstructor:
+    """Default constructor for runners.
+
+    Customize an existing `Runner` such as `EpochBasedRunner` through a
+    `RunnerConstructor`. For example, we can inject some new properties and
+    functions into the `Runner`.
+
+    Example:
+        >>> from annotator.uniformer.mmcv.runner import RUNNER_BUILDERS, build_runner
+        >>> # Define a new RunnerConstructor
+        >>> @RUNNER_BUILDERS.register_module()
+        >>> class MyRunnerConstructor:
+        ...     def __init__(self, runner_cfg, default_args=None):
+        ...         if not isinstance(runner_cfg, dict):
+        ...             raise TypeError('runner_cfg should be a dict, '
+        ...                             f'but got {type(runner_cfg)}')
+        ...         self.runner_cfg = runner_cfg
+        ...         self.default_args = default_args
+        ...
+        ...     def __call__(self):
+        ...         runner = RUNNERS.build(self.runner_cfg,
+        ...                                default_args=self.default_args)
+        ...         # Add new properties for existing runner
+        ...         runner.my_name = 'my_runner'
+        ...         runner.my_function = lambda self: print(self.my_name)
+        ...         ...
+        >>> # build your runner
+        >>> runner_cfg = dict(type='EpochBasedRunner', max_epochs=40,
+        ...                   constructor='MyRunnerConstructor')
+        >>> runner = build_runner(runner_cfg)
+    """
+
+    def __init__(self, runner_cfg, default_args=None):
+        """Store the runner config and the default build arguments.
+
+        Args:
+            runner_cfg (dict): Config passed to ``RUNNERS.build``.
+            default_args (optional): Forwarded to ``RUNNERS.build`` as
+                ``default_args``. Default: None.
+
+        Raises:
+            TypeError: If ``runner_cfg`` is not a dict.
+        """
+        if not isinstance(runner_cfg, dict):
+            # One message string: separating the two parts with a comma
+            # would give the exception a two-element ``args`` tuple.
+            raise TypeError('runner_cfg should be a dict, '
+                            f'but got {type(runner_cfg)}')
+        self.runner_cfg = runner_cfg
+        self.default_args = default_args
+
+    def __call__(self):
+        """Build and return the runner described by ``runner_cfg``."""
+        return RUNNERS.build(self.runner_cfg, default_args=self.default_args)
diff --git a/ControlNet/annotator/uniformer/mmcv/runner/dist_utils.py b/ControlNet/annotator/uniformer/mmcv/runner/dist_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3a1ef3fda5ceeb31bf15a73779da1b1903ab0fe
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/runner/dist_utils.py
@@ -0,0 +1,164 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import functools
+import os
+import subprocess
+from collections import OrderedDict
+import torch
+import torch.multiprocessing as mp
+from torch import distributed as dist
+from torch._utils import (_flatten_dense_tensors, _take_tensors,
+ _unflatten_dense_tensors)
+def init_dist(launcher, backend='nccl', **kwargs):
+    """Initialise distributed training for the given launcher.
+
+    Args:
+        launcher (str): One of 'pytorch', 'mpi' or 'slurm'.
+        backend (str): Backend of torch.distributed. Default: 'nccl'.
+        **kwargs: Forwarded to the launcher-specific initialiser.
+
+    Raises:
+        ValueError: If ``launcher`` is not one of the supported names.
+    """
+    # Only choose 'spawn' when no start method has been set yet.
+    if mp.get_start_method(allow_none=True) is None:
+        mp.set_start_method('spawn')
+    if launcher == 'pytorch':
+        _init_dist_pytorch(backend, **kwargs)
+        return
+    if launcher == 'mpi':
+        _init_dist_mpi(backend, **kwargs)
+        return
+    if launcher == 'slurm':
+        _init_dist_slurm(backend, **kwargs)
+        return
+    raise ValueError(f'Invalid launcher type: {launcher}')
+def _init_dist_pytorch(backend, **kwargs):
+    """Initialise the process group using the ``RANK`` environment variable.
+
+    Args:
+        backend (str): Backend of torch.distributed.
+        **kwargs: Forwarded to ``dist.init_process_group``.
+    """
+    # TODO: use local_rank instead of rank % num_gpus
+    global_rank = int(os.environ['RANK'])
+    device_index = global_rank % torch.cuda.device_count()
+    torch.cuda.set_device(device_index)
+    dist.init_process_group(backend=backend, **kwargs)
+def _init_dist_mpi(backend, **kwargs):
+    """Initialise the process group using ``OMPI_COMM_WORLD_RANK``.
+
+    Args:
+        backend (str): Backend of torch.distributed.
+        **kwargs: Forwarded to ``dist.init_process_group``.
+    """
+    # TODO: use local_rank instead of rank % num_gpus
+    global_rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
+    device_index = global_rank % torch.cuda.device_count()
+    torch.cuda.set_device(device_index)
+    dist.init_process_group(backend=backend, **kwargs)
+def _init_dist_slurm(backend, port=None):
+    """Initialize slurm distributed training environment.
+
+    If argument ``port`` is not specified, then the master port will be system
+    environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
+    environment variable, then a default port ``29500`` will be used.
+
+    Args:
+        backend (str): Backend of torch.distributed.
+        port (int, optional): Master port. Defaults to None.
+    """
+    env = os.environ
+    proc_id = int(env['SLURM_PROCID'])
+    ntasks = int(env['SLURM_NTASKS'])
+    node_list = env['SLURM_NODELIST']
+    num_gpus = torch.cuda.device_count()
+    local_rank = proc_id % num_gpus
+    torch.cuda.set_device(local_rank)
+    # The first host name printed for the node list is used as the address.
+    addr = subprocess.getoutput(
+        f'scontrol show hostname {node_list} | head -n1')
+    # An explicit ``port`` wins; otherwise keep an existing MASTER_PORT and
+    # only fall back to 29500 (the torch.distributed default port).
+    if port is not None:
+        env['MASTER_PORT'] = str(port)
+    elif 'MASTER_PORT' not in env:
+        env['MASTER_PORT'] = '29500'
+    # use MASTER_ADDR in the environment variable if it already exists
+    if 'MASTER_ADDR' not in env:
+        env['MASTER_ADDR'] = addr
+    env['WORLD_SIZE'] = str(ntasks)
+    env['LOCAL_RANK'] = str(local_rank)
+    env['RANK'] = str(proc_id)
+    dist.init_process_group(backend=backend)
+def get_dist_info():
+    """Return ``(rank, world_size)`` of the current process.
+
+    Returns:
+        tuple[int, int]: The values reported by ``torch.distributed`` when it
+            is available and initialised, otherwise ``(0, 1)``.
+    """
+    if not (dist.is_available() and dist.is_initialized()):
+        return 0, 1
+    return dist.get_rank(), dist.get_world_size()
+def master_only(func):
+    """Decorator that runs ``func`` only on the process with rank 0.
+
+    On every other rank the wrapped call does nothing and returns None.
+    """
+
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        current_rank = get_dist_info()[0]
+        if current_rank != 0:
+            return None
+        return func(*args, **kwargs)
+
+    return wrapper
+def allreduce_params(params, coalesce=True, bucket_size_mb=-1):
+    """Allreduce parameters.
+
+    Does nothing when the world size is 1.
+
+    Args:
+        params (list[torch.Parameters]): List of parameters or buffers of a
+            model.
+        coalesce (bool, optional): Whether allreduce parameters as a whole.
+            Defaults to True.
+        bucket_size_mb (int, optional): Size of bucket, the unit is MB.
+            Defaults to -1.
+    """
+    world_size = get_dist_info()[1]
+    if world_size == 1:
+        return
+    tensors = [param.data for param in params]
+    if not coalesce:
+        # Divide in place first, then sum across processes.
+        for tensor in tensors:
+            dist.all_reduce(tensor.div_(world_size))
+        return
+    _allreduce_coalesced(tensors, world_size, bucket_size_mb)
+def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
+    """Allreduce gradients.
+
+    Only parameters that require grad and currently have a gradient take
+    part. Does nothing when the world size is 1.
+
+    Args:
+        params (list[torch.Parameters]): List of parameters of a model
+        coalesce (bool, optional): Whether allreduce parameters as a whole.
+            Defaults to True.
+        bucket_size_mb (int, optional): Size of bucket, the unit is MB.
+            Defaults to -1.
+    """
+    grads = []
+    for param in params:
+        if param.requires_grad and param.grad is not None:
+            grads.append(param.grad.data)
+    world_size = get_dist_info()[1]
+    if world_size == 1:
+        return
+    if not coalesce:
+        # Divide in place first, then sum across processes.
+        for grad in grads:
+            dist.all_reduce(grad.div_(world_size))
+        return
+    _allreduce_coalesced(grads, world_size, bucket_size_mb)
+def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
+    """Allreduce ``tensors`` in flattened buckets and average them in place.
+
+    Args:
+        tensors (list[torch.Tensor]): Tensors to synchronise; each one is
+            overwritten with the reduced value divided by ``world_size``.
+        world_size (int): Divisor applied after the all-reduce.
+        bucket_size_mb (int, optional): Size of bucket, the unit is MB. When
+            not positive, tensors are grouped by their ``type()`` string
+            instead. Defaults to -1.
+    """
+    if bucket_size_mb > 0:
+        buckets = _take_tensors(tensors, bucket_size_mb * 1024 * 1024)
+    else:
+        grouped = OrderedDict()
+        for tensor in tensors:
+            grouped.setdefault(tensor.type(), []).append(tensor)
+        buckets = grouped.values()
+
+    for bucket in buckets:
+        flat = _flatten_dense_tensors(bucket)
+        dist.all_reduce(flat)
+        flat.div_(world_size)
+        # Copy each reduced slice back into its original tensor.
+        synced_parts = _unflatten_dense_tensors(flat, bucket)
+        for tensor, synced in zip(bucket, synced_parts):
+            tensor.copy_(synced)
diff --git a/ControlNet/annotator/uniformer/mmcv/runner/epoch_based_runner.py b/ControlNet/annotator/uniformer/mmcv/runner/epoch_based_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..766a9ce6afdf09cd11b1b15005f5132583011348
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/runner/epoch_based_runner.py
@@ -0,0 +1,187 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import platform
+import shutil
+import time
+import warnings
+import torch
+import annotator.uniformer.mmcv as mmcv
+from .base_runner import BaseRunner
+from .builder import RUNNERS
+from .checkpoint import save_checkpoint
+from .utils import get_host_info
+class EpochBasedRunner(BaseRunner):
+    """Epoch-based Runner.
+
+    This runner trains models epoch by epoch.
+    """
+
+    def run_iter(self, data_batch, train_mode, **kwargs):
+        """Process one batch and store the result in ``self.outputs``.
+
+        ``self.batch_processor`` is used when it is set; otherwise the model's
+        ``train_step`` or ``val_step`` is called depending on ``train_mode``.
+
+        Raises:
+            TypeError: If the step does not return a dict.
+        """
+        if self.batch_processor is not None:
+            outputs = self.batch_processor(
+                self.model, data_batch, train_mode=train_mode, **kwargs)
+        else:
+            step = self.model.train_step if train_mode else self.model.val_step
+            outputs = step(data_batch, self.optimizer, **kwargs)
+        if not isinstance(outputs, dict):
+            raise TypeError('"batch_processor()" or "model.train_step()"'
+                            'and "model.val_step()" must return a dict')
+        if 'log_vars' in outputs:
+            self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
+        self.outputs = outputs
+
+    def train(self, data_loader, **kwargs):
+        """Run one training epoch over ``data_loader``."""
+        self.model.train()
+        self.mode = 'train'
+        self.data_loader = data_loader
+        self._max_iters = self._max_epochs * len(self.data_loader)
+        self.call_hook('before_train_epoch')
+        time.sleep(2)  # Prevent possible deadlock during epoch transition
+        for batch_idx, batch in enumerate(self.data_loader):
+            self._inner_iter = batch_idx
+            self.call_hook('before_train_iter')
+            self.run_iter(batch, train_mode=True, **kwargs)
+            self.call_hook('after_train_iter')
+            self._iter += 1
+
+        self.call_hook('after_train_epoch')
+        self._epoch += 1
+
+    @torch.no_grad()
+    def val(self, data_loader, **kwargs):
+        """Run one validation epoch over ``data_loader`` without gradients."""
+        self.model.eval()
+        self.mode = 'val'
+        self.data_loader = data_loader
+        self.call_hook('before_val_epoch')
+        time.sleep(2)  # Prevent possible deadlock during epoch transition
+        for batch_idx, batch in enumerate(self.data_loader):
+            self._inner_iter = batch_idx
+            self.call_hook('before_val_iter')
+            self.run_iter(batch, train_mode=False)
+            self.call_hook('after_val_iter')
+
+        self.call_hook('after_val_epoch')
+
+    def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
+        """Start running.
+
+        Args:
+            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
+                and validation.
+            workflow (list[tuple]): A list of (phase, epochs) to specify the
+                running order and epochs. E.g, [('train', 2), ('val', 1)] means
+                running 2 epochs for training and 1 epoch for validation,
+                iteratively.
+            max_epochs (int, optional): Deprecated way to set the number of
+                epochs; emits a DeprecationWarning when given.
+        """
+        assert isinstance(data_loaders, list)
+        assert mmcv.is_list_of(workflow, tuple)
+        assert len(data_loaders) == len(workflow)
+        if max_epochs is not None:
+            warnings.warn(
+                'setting max_epochs in run is deprecated, '
+                'please set max_epochs in runner_config', DeprecationWarning)
+            self._max_epochs = max_epochs
+
+        assert self._max_epochs is not None, (
+            'max_epochs must be specified during instantiation')
+
+        # The first 'train' entry of the workflow determines max_iters.
+        for idx, (mode, _epochs) in enumerate(workflow):
+            if mode == 'train':
+                self._max_iters = self._max_epochs * len(data_loaders[idx])
+                break
+
+        work_dir = 'NONE' if self.work_dir is None else self.work_dir
+        self.logger.info('Start running, host: %s, work_dir: %s',
+                         get_host_info(), work_dir)
+        self.logger.info('Hooks will be executed in the following order:\n%s',
+                         self.get_hook_info())
+        self.logger.info('workflow: %s, max: %d epochs', workflow,
+                         self._max_epochs)
+        self.call_hook('before_run')
+
+        while self.epoch < self._max_epochs:
+            for idx, (mode, epochs) in enumerate(workflow):
+                if not isinstance(mode, str):
+                    raise TypeError(
+                        'mode in workflow must be a str, but got {}'.format(
+                            type(mode)))
+                if not hasattr(self, mode):
+                    raise ValueError(
+                        f'runner has no method named "{mode}" to run an '
+                        'epoch')
+                # e.g. 'train' resolves to self.train
+                epoch_runner = getattr(self, mode)
+
+                for _ in range(epochs):
+                    if mode == 'train' and self.epoch >= self._max_epochs:
+                        break
+                    epoch_runner(data_loaders[idx], **kwargs)
+
+        time.sleep(1)  # wait for some hooks like loggers to finish
+        self.call_hook('after_run')
+
+    def save_checkpoint(self,
+                        out_dir,
+                        filename_tmpl='epoch_{}.pth',
+                        save_optimizer=True,
+                        meta=None,
+                        create_symlink=True):
+        """Save the checkpoint.
+
+        Args:
+            out_dir (str): The directory that checkpoints are saved.
+            filename_tmpl (str, optional): The checkpoint filename template,
+                which contains a placeholder for the epoch number.
+                Defaults to 'epoch_{}.pth'.
+            save_optimizer (bool, optional): Whether to save the optimizer to
+                the checkpoint. Defaults to True.
+            meta (dict, optional): The meta information to be saved in the
+                checkpoint. Defaults to None.
+            create_symlink (bool, optional): Whether to create a symlink
+                "latest.pth" to point to the latest checkpoint.
+                Defaults to True.
+
+        Raises:
+            TypeError: If ``meta`` is neither a dict nor None.
+        """
+        if meta is None:
+            meta = {}
+        elif not isinstance(meta, dict):
+            raise TypeError(
+                f'meta should be a dict or None, but got {type(meta)}')
+        if self.meta is not None:
+            meta.update(self.meta)
+            # Note: meta.update(self.meta) should be done before
+            # meta.update(epoch=self.epoch + 1, iter=self.iter) otherwise
+            # there will be problems with resumed checkpoints.
+            # More details in https://github.com/open-mmlab/mmcv/pull/1108
+        next_epoch = self.epoch + 1
+        meta.update(epoch=next_epoch, iter=self.iter)
+
+        filename = filename_tmpl.format(next_epoch)
+        filepath = osp.join(out_dir, filename)
+        optimizer = self.optimizer if save_optimizer else None
+        save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
+        # in some environments, `os.symlink` is not supported, you may need to
+        # set `create_symlink` to False
+        if not create_symlink:
+            return
+        dst_file = osp.join(out_dir, 'latest.pth')
+        if platform.system() == 'Windows':
+            # On Windows the checkpoint is copied instead of linked.
+            shutil.copy(filepath, dst_file)
+        else:
+            mmcv.symlink(filename, dst_file)
+class Runner(EpochBasedRunner):
+    """Deprecated name of EpochBasedRunner."""
+
+    def __init__(self, *args, **kwargs):
+        """Warn about the deprecated name, then defer to the parent class."""
+        deprecation_msg = (
+            'Runner was deprecated, please use EpochBasedRunner instead')
+        warnings.warn(deprecation_msg)
+        super().__init__(*args, **kwargs)
diff --git a/ControlNet/annotator/uniformer/mmcv/runner/fp16_utils.py b/ControlNet/annotator/uniformer/mmcv/runner/fp16_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1981011d6859192e3e663e29d13500d56ba47f6c
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/runner/fp16_utils.py
@@ -0,0 +1,410 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import functools
+import warnings
+from collections import abc
+from inspect import getfullargspec
+import numpy as np
+import torch
+import torch.nn as nn
+from annotator.uniformer.mmcv.utils import TORCH_VERSION, digit_version
+from .dist_utils import allreduce_grads as _allreduce_grads
try:
    # If PyTorch version >= 1.6.0, torch.cuda.amp.autocast would be imported
    # and used; otherwise, auto fp16 will adopt mmcv's implementation.
    # Note that when PyTorch >= 1.6.0, we still cast tensor types to fp16
    # manually, so the behavior may not be consistent with real amp.
    from torch.cuda.amp import autocast
except ImportError:
    # ``autocast`` is only referenced behind a ``>= 1.6.0`` version check in
    # the decorators below, so a failed import is deliberately ignored here.
    pass
def cast_tensor_type(inputs, src_type, dst_type):
    """Recursively convert Tensors in ``inputs`` from src_type to dst_type.

    Args:
        inputs: Inputs that are to be cast. May be a Tensor, a mapping, an
            iterable, or any other object (returned unchanged).
        src_type (torch.dtype): Source type. Only tensors whose dtype equals
            ``src_type`` are converted.
        dst_type (torch.dtype): Destination type.

    Returns:
        The same type as ``inputs``, but every contained Tensor of dtype
        ``src_type`` has been cast to ``dst_type``.
    """
    if isinstance(inputs, nn.Module):
        return inputs
    elif isinstance(inputs, torch.Tensor):
        # Tensors of any other dtype (e.g. integer labels or boolean masks)
        # must not be touched: only ``src_type`` tensors are converted.
        return inputs.to(dst_type) if inputs.dtype == src_type else inputs
    elif isinstance(inputs, (str, np.ndarray)):
        # Strings and arrays are iterable but must not be rebuilt element by
        # element, so they are returned as-is before the Iterable branch.
        return inputs
    elif isinstance(inputs, abc.Mapping):
        return type(inputs)({
            k: cast_tensor_type(v, src_type, dst_type)
            for k, v in inputs.items()
        })
    elif isinstance(inputs, abc.Iterable):
        return type(inputs)(
            cast_tensor_type(item, src_type, dst_type) for item in inputs)
    else:
        return inputs
def auto_fp16(apply_to=None, out_fp32=False):
    """Decorator to enable fp16 training automatically.

    This decorator is useful when you write custom modules and want to support
    mixed precision training. If inputs arguments are fp32 tensors, they will
    be converted to fp16 automatically. Arguments other than fp32 tensors are
    ignored. If you are using PyTorch >= 1.6, torch.cuda.amp is used as the
    backend, otherwise, original mmcv implementation will be adopted.

    Args:
        apply_to (Iterable, optional): The argument names to be converted.
            `None` indicates all arguments.
        out_fp32 (bool): Whether to convert the output back to fp32.

    Raises:
        TypeError: If the first positional argument of the decorated callable
            is not an ``nn.Module`` (i.e. it is not used on a module method).

    Example:

        >>> import torch.nn as nn
        >>> class MyModule1(nn.Module):
        >>>
        >>>     # Convert x and y to fp16
        >>>     @auto_fp16()
        >>>     def forward(self, x, y):
        >>>         pass

        >>> import torch.nn as nn
        >>> class MyModule2(nn.Module):
        >>>
        >>>     # convert pred to fp16
        >>>     @auto_fp16(apply_to=('pred', ))
        >>>     def do_something(self, pred, others):
        >>>         pass
    """

    def auto_fp16_wrapper(old_func):

        @functools.wraps(old_func)
        def new_func(*args, **kwargs):
            # check if the module has set the attribute `fp16_enabled`, if not,
            # just fallback to the original method.
            if not isinstance(args[0], torch.nn.Module):
                raise TypeError('@auto_fp16 can only be used to decorate the '
                                'method of nn.Module')
            if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled):
                return old_func(*args, **kwargs)

            # get the arg spec of the decorated method
            args_info = getfullargspec(old_func)
            # get the argument names to be casted
            args_to_cast = args_info.args if apply_to is None else apply_to

            # convert the positional args that need to be processed
            # NOTE: default args are not taken into consideration
            arg_names = args_info.args[:len(args)]
            new_args = [
                cast_tensor_type(arg, torch.float, torch.half)
                if arg_name in args_to_cast else arg
                for arg_name, arg in zip(arg_names, args)
            ]
            # Positionals beyond the named parameters are captured by
            # ``*varargs``; they have no name to match against, so they are
            # forwarded untouched instead of being dropped.
            new_args.extend(args[len(arg_names):])

            # convert the kwargs that need to be processed
            new_kwargs = {
                arg_name:
                cast_tensor_type(arg_value, torch.float, torch.half)
                if arg_name in args_to_cast else arg_value
                for arg_name, arg_value in kwargs.items()
            }

            # apply converted arguments to the decorated method
            if (TORCH_VERSION != 'parrots' and
                    digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
                with autocast(enabled=True):
                    output = old_func(*new_args, **new_kwargs)
            else:
                output = old_func(*new_args, **new_kwargs)
            # cast the results back to fp32 if necessary
            if out_fp32:
                output = cast_tensor_type(output, torch.half, torch.float)
            return output

        return new_func

    return auto_fp16_wrapper
def force_fp32(apply_to=None, out_fp16=False):
    """Decorator to convert input arguments to fp32 in force.

    This decorator is useful when you write custom modules and want to support
    mixed precision training. If there are some inputs that must be processed
    in fp32 mode, then this decorator can handle it. If inputs arguments are
    fp16 tensors, they will be converted to fp32 automatically. Arguments other
    than fp16 tensors are ignored. If you are using PyTorch >= 1.6,
    torch.cuda.amp is used as the backend, otherwise, original mmcv
    implementation will be adopted.

    Args:
        apply_to (Iterable, optional): The argument names to be converted.
            `None` indicates all arguments.
        out_fp16 (bool): Whether to convert the output back to fp16.

    Raises:
        TypeError: If the first positional argument of the decorated callable
            is not an ``nn.Module`` (i.e. it is not used on a module method).

    Example:

        >>> import torch.nn as nn
        >>> class MyModule1(nn.Module):
        >>>
        >>>     # Convert x and y to fp32
        >>>     @force_fp32()
        >>>     def loss(self, x, y):
        >>>         pass

        >>> import torch.nn as nn
        >>> class MyModule2(nn.Module):
        >>>
        >>>     # convert pred to fp32
        >>>     @force_fp32(apply_to=('pred', ))
        >>>     def post_process(self, pred, others):
        >>>         pass
    """

    def force_fp32_wrapper(old_func):

        @functools.wraps(old_func)
        def new_func(*args, **kwargs):
            # check if the module has set the attribute `fp16_enabled`, if not,
            # just fallback to the original method.
            if not isinstance(args[0], torch.nn.Module):
                raise TypeError('@force_fp32 can only be used to decorate the '
                                'method of nn.Module')
            if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled):
                return old_func(*args, **kwargs)

            # get the arg spec of the decorated method
            args_info = getfullargspec(old_func)
            # get the argument names to be casted
            args_to_cast = args_info.args if apply_to is None else apply_to

            # convert the positional args that need to be processed
            arg_names = args_info.args[:len(args)]
            new_args = [
                cast_tensor_type(arg, torch.half, torch.float)
                if arg_name in args_to_cast else arg
                for arg_name, arg in zip(arg_names, args)
            ]
            # Positionals beyond the named parameters are captured by
            # ``*varargs``; they have no name to match against, so they are
            # forwarded untouched instead of being dropped.
            new_args.extend(args[len(arg_names):])

            # convert the kwargs that need to be processed
            new_kwargs = {
                arg_name:
                cast_tensor_type(arg_value, torch.half, torch.float)
                if arg_name in args_to_cast else arg_value
                for arg_name, arg_value in kwargs.items()
            }

            # apply converted arguments to the decorated method
            if (TORCH_VERSION != 'parrots' and
                    digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
                with autocast(enabled=False):
                    output = old_func(*new_args, **new_kwargs)
            else:
                output = old_func(*new_args, **new_kwargs)
            # cast the results back to fp16 if necessary
            if out_fp16:
                output = cast_tensor_type(output, torch.float, torch.half)
            return output

        return new_func

    return force_fp32_wrapper
def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
    """Deprecated alias of ``dist_utils.allreduce_grads``.

    Emits a ``DeprecationWarning`` and forwards all arguments unchanged to
    the implementation imported from ``.dist_utils``.

    Args:
        params: Forwarded as the first positional argument.
        coalesce (bool): Forwarded as ``coalesce``. Defaults to True.
        bucket_size_mb (int): Forwarded as ``bucket_size_mb``.
            Defaults to -1.
    """
    # ``warnings.warning`` does not exist; ``warnings.warn`` is the API.
    warnings.warn(
        '"mmcv.runner.fp16_utils.allreduce_grads" is deprecated, and will be '
        'removed in v2.8. Please switch to "mmcv.runner.allreduce_grads"',
        DeprecationWarning)
    _allreduce_grads(params, coalesce=coalesce, bucket_size_mb=bucket_size_mb)
def wrap_fp16_model(model):
    """Prepare an FP32 model for FP16 execution.

    With PyTorch >= 1.6, torch.cuda.amp serves as the backend and this
    function only flips the fp16 flag inside the model. Otherwise (older
    PyTorch or parrots) the original mmcv implementation is used:

    1. Convert the FP32 model to FP16.
    2. Keep some necessary layers in FP32, e.g., normalization layers.
    3. Set the `fp16_enabled` flag inside the model to True.

    Args:
        model (nn.Module): Model in FP32.
    """
    use_legacy_backend = TORCH_VERSION == 'parrots' or (
        digit_version(TORCH_VERSION) < digit_version('1.6.0'))
    if use_legacy_backend:
        # halve the whole model, then restore the normalization layers so
        # that they keep working in fp32 mode
        model.half()
        patch_norm_fp32(model)
    # switch on the flag for every submodule that declares it
    flagged = (sub for sub in model.modules() if hasattr(sub, 'fp16_enabled'))
    for sub in flagged:
        sub.fp16_enabled = True
def patch_norm_fp32(module):
    """Walk a module tree and put its normalization layers back into FP32.

    Args:
        module (nn.Module): The (FP16) module tree to process in place.

    Returns:
        nn.Module: The same module object that was passed in, with its
            batch-norm and group-norm layers converted to FP32.
    """
    norm_layer_types = (nn.modules.batchnorm._BatchNorm, nn.GroupNorm)
    if isinstance(module, norm_layer_types):
        module.float()
        wrap_forward = (
            isinstance(module, nn.GroupNorm) or torch.__version__ < '1.3')
        if wrap_forward:
            # the wrapped forward casts half inputs to float and the output
            # back to half
            module.forward = patch_forward_method(module.forward, torch.half,
                                                  torch.float)
    for submodule in module.children():
        patch_norm_fp32(submodule)
    return module
def patch_forward_method(func, src_type, dst_type, convert_output=True):
    """Build a wrapper around a forward method that casts its tensors.

    Args:
        func (callable): The original forward method.
        src_type (torch.dtype): Type of input arguments to be converted from.
        dst_type (torch.dtype): Type of input arguments to be converted to.
        convert_output (bool): Whether to convert the output back to src_type.

    Returns:
        callable: The patched forward method.
    """

    def new_forward(*args, **kwargs):
        # cast positional and keyword inputs before delegating
        cast_args = cast_tensor_type(args, src_type, dst_type)
        cast_kwargs = cast_tensor_type(kwargs, src_type, dst_type)
        result = func(*cast_args, **cast_kwargs)
        if not convert_output:
            return result
        # undo the cast on the way out
        return cast_tensor_type(result, dst_type, src_type)

    return new_forward
class LossScaler:
    """Class that manages loss scaling in mixed precision training which
    supports both dynamic or static mode.

    The implementation refers to
    https://github.com/NVIDIA/apex/blob/master/apex/fp16_utils/loss_scaler.py.
    Indirectly, by supplying ``mode='dynamic'`` for dynamic loss scaling.
    It's important to understand how :class:`LossScaler` operates.
    Loss scaling is designed to combat the problem of underflowing
    gradients encountered at long times when training fp16 networks.
    Dynamic loss scaling begins by attempting a very high loss
    scale. Ironically, this may result in OVERflowing gradients.
    If overflowing gradients are encountered, :class:`FP16_Optimizer` then
    skips the update step for this particular iteration/minibatch,
    and :class:`LossScaler` adjusts the loss scale to a lower value.
    If a certain number of iterations occur without overflowing gradients
    detected, :class:`LossScaler` increases the loss scale once more.
    In this way :class:`LossScaler` attempts to "ride the edge" of always
    using the highest loss scale possible without incurring overflow.

    Args:
        init_scale (float): Initial loss scale value, default: 2**32.
        mode (str): Loss scaling mode. 'dynamic' or 'static'.
            Default: 'dynamic'.
        scale_factor (float): Factor used when adjusting the loss scale.
            Default: 2.
        scale_window (int): Number of consecutive iterations without an
            overflow to wait before increasing the loss scale. Default: 1000.
    """

    def __init__(self,
                 init_scale=2**32,
                 mode='dynamic',
                 scale_factor=2.,
                 scale_window=1000):
        self.cur_scale = init_scale
        self.cur_iter = 0
        assert mode in ('dynamic',
                        'static'), 'mode can only be dynamic or static'
        self.mode = mode
        self.last_overflow_iter = -1
        self.scale_factor = scale_factor
        self.scale_window = scale_window

    def has_overflow(self, params):
        """Check if the gradients of ``params`` contain inf or NaN.

        Always returns False in static mode.
        """
        if self.mode != 'dynamic':
            return False
        return any(p.grad is not None
                   and LossScaler._has_inf_or_nan(p.grad.data)
                   for p in params)

    @staticmethod
    def _has_inf_or_nan(x):
        """Check if the sum of tensor ``x`` is inf or NaN."""
        try:
            cpu_sum = float(x.float().sum())
        except RuntimeError as instance:
            # Only the conversion failure is treated as an overflow; any
            # other runtime error is re-raised unchanged.
            if 'value cannot be converted' not in instance.args[0]:
                raise
            return True
        # NaN is the only float value that differs from itself.
        return (cpu_sum in (float('inf'), -float('inf'))
                or cpu_sum != cpu_sum)

    def update_scale(self, overflow):
        """Update the current loss scale value.

        In dynamic mode the scale is divided by ``scale_factor`` (never
        below 1) on overflow, and multiplied by it whenever the number of
        iterations since the last overflow is a multiple of
        ``scale_window``. Nothing happens in static mode.
        """
        if self.mode != 'dynamic':
            return
        if overflow:
            self.cur_scale = max(self.cur_scale / self.scale_factor, 1)
            self.last_overflow_iter = self.cur_iter
        elif (self.cur_iter -
              self.last_overflow_iter) % self.scale_window == 0:
            self.cur_scale *= self.scale_factor
        self.cur_iter += 1

    def state_dict(self):
        """Returns the state of the scaler as a :class:`dict`."""
        return dict(
            cur_scale=self.cur_scale,
            cur_iter=self.cur_iter,
            mode=self.mode,
            last_overflow_iter=self.last_overflow_iter,
            scale_factor=self.scale_factor,
            scale_window=self.scale_window)

    def load_state_dict(self, state_dict):
        """Loads the loss_scaler state dict.

        Args:
           state_dict (dict): scaler state.
        """
        self.cur_scale = state_dict['cur_scale']
        self.cur_iter = state_dict['cur_iter']
        self.mode = state_dict['mode']
        self.last_overflow_iter = state_dict['last_overflow_iter']
        self.scale_factor = state_dict['scale_factor']
        self.scale_window = state_dict['scale_window']

    @property
    def loss_scale(self):
        """The loss scale currently in effect."""
        return self.cur_scale
diff --git a/ControlNet/annotator/uniformer/mmcv/runner/hooks/__init__.py b/ControlNet/annotator/uniformer/mmcv/runner/hooks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..915af28cefab14a14c1188ed861161080fd138a3
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/runner/hooks/__init__.py
@@ -0,0 +1,29 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .checkpoint import CheckpointHook
+from .closure import ClosureHook
+from .ema import EMAHook
+from .evaluation import DistEvalHook, EvalHook
+from .hook import HOOKS, Hook
+from .iter_timer import IterTimerHook
+from .logger import (DvcliveLoggerHook, LoggerHook, MlflowLoggerHook,
+ NeptuneLoggerHook, PaviLoggerHook, TensorboardLoggerHook,
+ TextLoggerHook, WandbLoggerHook)
+from .lr_updater import LrUpdaterHook
+from .memory import EmptyCacheHook
+from .momentum_updater import MomentumUpdaterHook
+from .optimizer import (Fp16OptimizerHook, GradientCumulativeFp16OptimizerHook,
+ GradientCumulativeOptimizerHook, OptimizerHook)
+from .profiler import ProfilerHook
+from .sampler_seed import DistSamplerSeedHook
+from .sync_buffer import SyncBuffersHook
# Public names re-exported by this package.
__all__ = [
    'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
    'OptimizerHook', 'Fp16OptimizerHook', 'IterTimerHook',
    'DistSamplerSeedHook', 'EmptyCacheHook', 'LoggerHook', 'MlflowLoggerHook',
    'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
    'NeptuneLoggerHook', 'WandbLoggerHook', 'DvcliveLoggerHook',
    'MomentumUpdaterHook', 'SyncBuffersHook', 'EMAHook', 'EvalHook',
    'DistEvalHook', 'ProfilerHook', 'GradientCumulativeOptimizerHook',
    'GradientCumulativeFp16OptimizerHook'
]
diff --git a/ControlNet/annotator/uniformer/mmcv/runner/hooks/checkpoint.py b/ControlNet/annotator/uniformer/mmcv/runner/hooks/checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..6af3fae43ac4b35532641a81eb13557edfc7dfba
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/runner/hooks/checkpoint.py
@@ -0,0 +1,167 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import warnings
+from annotator.uniformer.mmcv.fileio import FileClient
+from ..dist_utils import allreduce_params, master_only
+from .hook import HOOKS, Hook
# NOTE(review): this module imports HOOKS but no registration decorator is
# visible on this class; confirm whether one is expected here.
class CheckpointHook(Hook):
    """Save checkpoints periodically.

    Args:
        interval (int): The saving period. If ``by_epoch=True``, interval
            indicates epochs, otherwise it indicates iterations.
            Default: -1, which means "never".
        by_epoch (bool): Saving checkpoints by epoch or by iteration.
            Default: True.
        save_optimizer (bool): Whether to save optimizer state_dict in the
            checkpoint. It is usually used for resuming experiments.
            Default: True.
        out_dir (str, optional): The root directory to save checkpoints. If not
            specified, ``runner.work_dir`` will be used by default. If
            specified, the ``out_dir`` will be the concatenation of ``out_dir``
            and the last level directory of ``runner.work_dir``.
            `Changed in version 1.3.16.`
        max_keep_ckpts (int, optional): The maximum checkpoints to keep.
            In some cases we want only the latest few checkpoints and would
            like to delete old ones to save the disk space.
            Default: -1, which means unlimited.
        save_last (bool, optional): Whether to force the last checkpoint to be
            saved regardless of interval. Default: True.
        sync_buffer (bool, optional): Whether to synchronize buffers in
            different gpus. Default: False.
        file_client_args (dict, optional): Arguments to instantiate a
            FileClient. See :class:`mmcv.fileio.FileClient` for details.
            Default: None.
            `New in version 1.3.16.`
        **kwargs: Extra keyword arguments forwarded to
            ``runner.save_checkpoint`` (e.g. ``filename_tmpl``,
            ``create_symlink``).

    .. warning::
        Before v1.3.16, the ``out_dir`` argument indicates the path where the
        checkpoint is stored. However, since v1.3.16, ``out_dir`` indicates the
        root directory and the final path to save checkpoint is the
        concatenation of ``out_dir`` and the last level directory of
        ``runner.work_dir``. Suppose the value of ``out_dir`` is "/path/of/A"
        and the value of ``runner.work_dir`` is "/path/of/B", then the final
        path will be "/path/of/A/B".
    """

    def __init__(self,
                 interval=-1,
                 by_epoch=True,
                 save_optimizer=True,
                 out_dir=None,
                 max_keep_ckpts=-1,
                 save_last=True,
                 sync_buffer=False,
                 file_client_args=None,
                 **kwargs):
        self.interval = interval
        self.by_epoch = by_epoch
        self.save_optimizer = save_optimizer
        self.out_dir = out_dir
        self.max_keep_ckpts = max_keep_ckpts
        self.save_last = save_last
        self.args = kwargs
        self.sync_buffer = sync_buffer
        self.file_client_args = file_client_args

    def before_run(self, runner):
        """Resolve the output directory, file client and symlink option."""
        if not self.out_dir:
            self.out_dir = runner.work_dir

        self.file_client = FileClient.infer_client(self.file_client_args,
                                                   self.out_dir)

        # if `self.out_dir` is not equal to `runner.work_dir`, it means that
        # `self.out_dir` is set so the final `self.out_dir` is the
        # concatenation of `self.out_dir` and the last level directory of
        # `runner.work_dir`
        if self.out_dir != runner.work_dir:
            basename = osp.basename(runner.work_dir.rstrip(osp.sep))
            self.out_dir = self.file_client.join_path(self.out_dir, basename)

        runner.logger.info((f'Checkpoints will be saved to {self.out_dir} by '
                            f'{self.file_client.name}.'))

        # disable the create_symlink option because some file backends do not
        # allow to create a symlink
        if 'create_symlink' in self.args:
            if self.args[
                    'create_symlink'] and not self.file_client.allow_symlink:
                self.args['create_symlink'] = False
                warnings.warn(
                    ('create_symlink is set as True by the user but is changed '
                     'to be False because creating symbolic link is not '
                     f'allowed in {self.file_client.name}'))
        else:
            self.args['create_symlink'] = self.file_client.allow_symlink

    def after_train_epoch(self, runner):
        """Save a checkpoint at the end of an epoch when saving by epoch."""
        if not self.by_epoch:
            return

        # save checkpoint for following cases:
        # 1. every ``self.interval`` epochs
        # 2. reach the last epoch of training
        if self.every_n_epochs(
                runner, self.interval) or (self.save_last
                                           and self.is_last_epoch(runner)):
            runner.logger.info(
                f'Saving checkpoint at {runner.epoch + 1} epochs')
            if self.sync_buffer:
                allreduce_params(runner.model.buffers())
            self._save_checkpoint(runner)

    @master_only
    def _save_checkpoint(self, runner):
        """Save the current checkpoint and delete unwanted checkpoint."""
        runner.save_checkpoint(
            self.out_dir, save_optimizer=self.save_optimizer, **self.args)

        # The default filename template and the 1-based index of the
        # checkpoint just written depend on the saving granularity.
        if self.by_epoch:
            default_tmpl = 'epoch_{}.pth'
            current_ckpt = runner.epoch + 1
        else:
            default_tmpl = 'iter_{}.pth'
            current_ckpt = runner.iter + 1
        filename_tmpl = self.args.get('filename_tmpl', default_tmpl)

        if runner.meta is not None:
            # record where the latest checkpoint lives
            runner.meta.setdefault('hook_msgs', dict())
            runner.meta['hook_msgs']['last_ckpt'] = self.file_client.join_path(
                self.out_dir, filename_tmpl.format(current_ckpt))

        # remove other checkpoints
        if self.max_keep_ckpts > 0:
            redundant_ckpts = range(
                current_ckpt - self.max_keep_ckpts * self.interval, 0,
                -self.interval)
            for _step in redundant_ckpts:
                ckpt_path = self.file_client.join_path(
                    self.out_dir, filename_tmpl.format(_step))
                if self.file_client.isfile(ckpt_path):
                    self.file_client.remove(ckpt_path)
                else:
                    # stop at the first missing checkpoint, walking backwards
                    break

    def after_train_iter(self, runner):
        """Save a checkpoint after an iteration when saving by iteration."""
        if self.by_epoch:
            return

        # save checkpoint for following cases:
        # 1. every ``self.interval`` iterations
        # 2. reach the last iteration of training
        if self.every_n_iters(
                runner, self.interval) or (self.save_last
                                           and self.is_last_iter(runner)):
            runner.logger.info(
                f'Saving checkpoint at {runner.iter + 1} iterations')
            if self.sync_buffer:
                allreduce_params(runner.model.buffers())
            self._save_checkpoint(runner)
diff --git a/ControlNet/annotator/uniformer/mmcv/runner/hooks/closure.py b/ControlNet/annotator/uniformer/mmcv/runner/hooks/closure.py
new file mode 100644
index 0000000000000000000000000000000000000000..b955f81f425be4ac3e6bb3f4aac653887989e872
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/runner/hooks/closure.py
@@ -0,0 +1,11 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .hook import HOOKS, Hook
# NOTE(review): this module imports HOOKS but no registration decorator is
# visible on this class; confirm whether one is expected here.
class ClosureHook(Hook):
    """Hook that replaces one of its own attributes with a given callable.

    Args:
        fn_name (str): Name of an attribute that already exists on the hook.
        fn (callable): Callable stored on the instance under ``fn_name``.
    """

    def __init__(self, fn_name, fn):
        # the target attribute must already exist on the hook ...
        assert hasattr(self, fn_name)
        # ... and its replacement must be callable
        assert callable(fn)
        setattr(self, fn_name, fn)
diff --git a/ControlNet/annotator/uniformer/mmcv/runner/hooks/ema.py b/ControlNet/annotator/uniformer/mmcv/runner/hooks/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..15c7e68088f019802a59e7ae41cc1fe0c7f28f96
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/runner/hooks/ema.py
@@ -0,0 +1,89 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ...parallel import is_module_wrapper
+from ..hooks.hook import HOOKS, Hook
# NOTE(review): this module imports HOOKS but no registration decorator is
# visible on this class; confirm whether one is expected here.
class EMAHook(Hook):
    r"""Exponential Moving Average Hook.

    Use Exponential Moving Average on all parameters of model in training
    process. All parameters have a ema backup, which update by the formula
    as below. EMAHook takes priority over EvalHook and CheckpointSaverHook.

        .. math::

            \text{Xema\_{t+1}} = (1 - \text{momentum}) \times
            \text{Xema\_{t}} +  \text{momentum} \times X_t

    Args:
        momentum (float): The momentum used for updating ema parameter.
            Defaults to 0.0002.
        interval (int): Update ema parameter every interval iteration.
            Defaults to 1.
        warm_up (int): During first warm_up steps, we may use smaller momentum
            to update ema parameters more slowly. Defaults to 100.
        resume_from (str): The checkpoint path. Defaults to None.
    """

    def __init__(self,
                 momentum=0.0002,
                 interval=1,
                 warm_up=100,
                 resume_from=None):
        assert isinstance(interval, int) and interval > 0
        self.warm_up = warm_up
        self.interval = interval
        assert momentum > 0 and momentum < 1
        self.momentum = momentum**interval
        self.checkpoint = resume_from

    def before_run(self, runner):
        """To resume model with it's ema parameters more friendly.

        Register ema parameter as ``named_buffer`` to model
        """
        model = runner.model
        if is_module_wrapper(model):
            model = model.module
        self.param_ema_buffer = {}
        self.model_parameters = dict(model.named_parameters(recurse=True))
        for name, value in self.model_parameters.items():
            # "." is not allowed in module's buffer name
            buffer_name = f"ema_{name.replace('.', '_')}"
            self.param_ema_buffer[name] = buffer_name
            model.register_buffer(buffer_name, value.data.clone())
        self.model_buffers = dict(model.named_buffers(recurse=True))
        if self.checkpoint is not None:
            runner.resume(self.checkpoint)

    def after_train_iter(self, runner):
        """Update ema parameter every self.interval iterations."""
        curr_step = runner.iter
        # We warm up the momentum considering the instability at beginning
        momentum = min(self.momentum,
                       (1 + curr_step) / (self.warm_up + curr_step))
        if curr_step % self.interval != 0:
            return
        for name, parameter in self.model_parameters.items():
            buffer_name = self.param_ema_buffer[name]
            buffer_parameter = self.model_buffers[buffer_name]
            # ema = (1 - momentum) * ema + momentum * parameter; the scalar
            # goes in ``alpha`` because the positional ``add_(Number, Tensor)``
            # overload is deprecated.
            buffer_parameter.mul_(1 - momentum).add_(
                parameter.data, alpha=momentum)

    def after_train_epoch(self, runner):
        """We load parameter values from ema backup to model before the
        EvalHook."""
        self._swap_ema_parameters()

    def before_train_epoch(self, runner):
        """We recover model's parameter from ema backup after last epoch's
        EvalHook."""
        self._swap_ema_parameters()

    def _swap_ema_parameters(self):
        """Swap the parameter of model with parameter in ema_buffer."""
        for name, value in self.model_parameters.items():
            temp = value.data.clone()
            ema_buffer = self.model_buffers[self.param_ema_buffer[name]]
            value.data.copy_(ema_buffer.data)
            ema_buffer.data.copy_(temp)
diff --git a/ControlNet/annotator/uniformer/mmcv/runner/hooks/evaluation.py b/ControlNet/annotator/uniformer/mmcv/runner/hooks/evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d00999ce5665c53bded8de9e084943eee2d230d
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/runner/hooks/evaluation.py
@@ -0,0 +1,509 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import warnings
+from math import inf
+import torch.distributed as dist
+from torch.nn.modules.batchnorm import _BatchNorm
+from torch.utils.data import DataLoader
+from annotator.uniformer.mmcv.fileio import FileClient
+from annotator.uniformer.mmcv.utils import is_seq_of
+from .hook import Hook
+from .logger import LoggerHook
+class EvalHook(Hook):
+    """Non-Distributed evaluation hook.
+
+    This hook will regularly perform evaluation in a given interval when
+    performing in non-distributed environment.
+
+    Args:
+        dataloader (DataLoader): A PyTorch dataloader, whose dataset has
+            implemented ``evaluate`` function.
+        start (int | None, optional): Evaluation starting epoch. It enables
+            evaluation before the training starts if ``start`` <= the resuming
+            epoch. If None, whether to evaluate is merely decided by
+            ``interval``. Default: None.
+        interval (int): Evaluation interval. Default: 1.
+        by_epoch (bool): Determine perform evaluation by epoch or by iteration.
+            If set to True, it will perform by epoch. Otherwise, by iteration.
+            Default: True.
+        save_best (str, optional): If a metric is specified, it would measure
+            the best checkpoint during evaluation. The information about best
+            checkpoint would be saved in ``runner.meta['hook_msgs']`` to keep
+            best score value and best checkpoint path, which will be also
+            loaded when resume checkpoint. Options are the evaluation metrics
+            on the test dataset. e.g., ``bbox_mAP``, ``segm_mAP`` for bbox
+            detection and instance segmentation. ``AR@100`` for proposal
+            recall. If ``save_best`` is ``auto``, the first key of the returned
+            ``OrderedDict`` result will be used. Default: None.
+        rule (str | None, optional): Comparison rule for best score. If set to
+            None, it will infer a reasonable rule. Keys such as 'acc', 'top',
+            etc. will be inferred by 'greater' rule. Keys containing 'loss'
+            will be inferred by 'less' rule. Options are 'greater', 'less',
+            None. Default: None.
+        test_fn (callable, optional): test a model with samples from a
+            dataloader, and return the test results. If ``None``, the default
+            test function ``mmcv.engine.single_gpu_test`` will be used.
+            (default: ``None``)
+        greater_keys (List[str] | None, optional): Metric keys that will be
+            inferred by 'greater' comparison rule. If ``None``,
+            _default_greater_keys will be used. (default: ``None``)
+        less_keys (List[str] | None, optional): Metric keys that will be
+            inferred by 'less' comparison rule. If ``None``, _default_less_keys
+            will be used. (default: ``None``)
+        out_dir (str, optional): The root directory to save checkpoints. If not
+            specified, `runner.work_dir` will be used by default. If specified,
+            the `out_dir` will be the concatenation of `out_dir` and the last
+            level directory of `runner.work_dir`.
+            `New in version 1.3.16.`
+        file_client_args (dict): Arguments to instantiate a FileClient.
+            See :class:`mmcv.fileio.FileClient` for details. Default: None.
+            `New in version 1.3.16.`
+        **eval_kwargs: Evaluation arguments fed into the evaluate function of
+            the dataset.
+
+    Notes:
+        If new arguments are added for EvalHook, tools/test.py,
+        tools/eval_metric.py may be affected.
+    """
+
+    # Since the key for determine greater or less is related to the downstream
+    # tasks, downstream repos may need to overwrite the following inner
+    # variable accordingly.
+    rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}
+    init_value_map = {'greater': -inf, 'less': inf}
+    _default_greater_keys = [
+        'acc', 'top', 'AR@', 'auc', 'precision', 'mAP', 'mDice', 'mIoU',
+        'mAcc', 'aAcc'
+    ]
+    _default_less_keys = ['loss']
+
+    def __init__(self,
+                 dataloader,
+                 start=None,
+                 interval=1,
+                 by_epoch=True,
+                 save_best=None,
+                 rule=None,
+                 test_fn=None,
+                 greater_keys=None,
+                 less_keys=None,
+                 out_dir=None,
+                 file_client_args=None,
+                 **eval_kwargs):
+        if not isinstance(dataloader, DataLoader):
+            raise TypeError(f'dataloader must be a pytorch DataLoader, '
+                            f'but got {type(dataloader)}')
+        if interval <= 0:
+            raise ValueError(f'interval must be a positive number, '
+                             f'but got {interval}')
+        assert isinstance(by_epoch, bool), '``by_epoch`` should be a boolean'
+        if start is not None and start < 0:
+            raise ValueError(f'The evaluation start epoch {start} is smaller '
+                             f'than 0')
+        self.dataloader = dataloader
+        self.interval = interval
+        self.start = start
+        self.by_epoch = by_epoch
+        assert isinstance(save_best, str) or save_best is None, \
+            '""save_best"" should be a str or None ' \
+            f'rather than {type(save_best)}'
+        self.save_best = save_best
+        self.eval_kwargs = eval_kwargs
+        # True until the one-off "evaluate before training" check has run.
+        self.initial_flag = True
+        if test_fn is None:
+            # Imported lazily, only when the default test function is needed.
+            from annotator.uniformer.mmcv.engine import single_gpu_test
+            self.test_fn = single_gpu_test
+        else:
+            self.test_fn = test_fn
+        if greater_keys is None:
+            self.greater_keys = self._default_greater_keys
+        else:
+            # Accept a single key as well as a list/tuple of keys.
+            if not isinstance(greater_keys, (list, tuple)):
+                greater_keys = (greater_keys, )
+            assert is_seq_of(greater_keys, str)
+            self.greater_keys = greater_keys
+        if less_keys is None:
+            self.less_keys = self._default_less_keys
+        else:
+            if not isinstance(less_keys, (list, tuple)):
+                less_keys = (less_keys, )
+            assert is_seq_of(less_keys, str)
+            self.less_keys = less_keys
+        if self.save_best is not None:
+            self.best_ckpt_path = None
+            self._init_rule(rule, self.save_best)
+        self.out_dir = out_dir
+        self.file_client_args = file_client_args
+
+    def _init_rule(self, rule, key_indicator):
+        """Initialize rule, key_indicator, comparison_func, and best score.
+
+        Here is the rule to determine which rule is used for key indicator
+        when the rule is not specific (note that the key indicator matching
+        is case-insensitive):
+
+        1. If the key indicator is in ``self.greater_keys``, the rule will be
+           specified as 'greater'.
+        2. Or if the key indicator is in ``self.less_keys``, the rule will be
+           specified as 'less'.
+        3. Or if any item of ``self.greater_keys`` is a substring of the key
+           indicator, the rule will be specified as 'greater'.
+        4. Or if any item of ``self.less_keys`` is a substring of the key
+           indicator, the rule will be specified as 'less'.
+
+        Args:
+            rule (str | None): Comparison rule for best score.
+            key_indicator (str | None): Key indicator to determine the
+                comparison rule.
+
+        Raises:
+            KeyError: If ``rule`` is neither None nor a key of ``rule_map``.
+            ValueError: If ``rule`` is None and no rule can be inferred from
+                ``key_indicator``.
+        """
+        if rule not in self.rule_map and rule is not None:
+            raise KeyError(f'rule must be greater, less or None, '
+                           f'but got {rule}.')
+        if rule is None:
+            # With 'auto' the rule is inferred later, in ``evaluate``, once
+            # the first metric name is known.
+            if key_indicator != 'auto':
+                # `_lc` here means we use the lower case of keys for
+                # case-insensitive matching
+                key_indicator_lc = key_indicator.lower()
+                greater_keys = [key.lower() for key in self.greater_keys]
+                less_keys = [key.lower() for key in self.less_keys]
+                if key_indicator_lc in greater_keys:
+                    rule = 'greater'
+                elif key_indicator_lc in less_keys:
+                    rule = 'less'
+                elif any(key in key_indicator_lc for key in greater_keys):
+                    rule = 'greater'
+                elif any(key in key_indicator_lc for key in less_keys):
+                    rule = 'less'
+                else:
+                    raise ValueError(f'Cannot infer the rule for key '
+                                     f'{key_indicator}, thus a specific rule '
+                                     f'must be specified.')
+        self.rule = rule
+        self.key_indicator = key_indicator
+        if self.rule is not None:
+            self.compare_func = self.rule_map[self.rule]
+
+    def before_run(self, runner):
+        """Resolve the output directory and restore best-checkpoint info."""
+        if not self.out_dir:
+            self.out_dir = runner.work_dir
+        self.file_client = FileClient.infer_client(self.file_client_args,
+                                                   self.out_dir)
+        # if `self.out_dir` is not equal to `runner.work_dir`, it means that
+        # `self.out_dir` is set so the final `self.out_dir` is the
+        # concatenation of `self.out_dir` and the last level directory of
+        # `runner.work_dir`
+        if self.out_dir != runner.work_dir:
+            basename = osp.basename(runner.work_dir.rstrip(osp.sep))
+            self.out_dir = self.file_client.join_path(self.out_dir, basename)
+            runner.logger.info(
+                (f'The best checkpoint will be saved to {self.out_dir} by '
+                 f'{self.file_client.name}'))
+        if self.save_best is not None:
+            if runner.meta is None:
+                warnings.warn('runner.meta is None. Creating an empty one.')
+                runner.meta = dict()
+            runner.meta.setdefault('hook_msgs', dict())
+            self.best_ckpt_path = runner.meta['hook_msgs'].get(
+                'best_ckpt', None)
+
+    def before_train_iter(self, runner):
+        """Evaluate the model only at the start of training by iteration."""
+        if self.by_epoch or not self.initial_flag:
+            return
+        if self.start is not None and runner.iter >= self.start:
+            self.after_train_iter(runner)
+        self.initial_flag = False
+
+    def before_train_epoch(self, runner):
+        """Evaluate the model only at the start of training by epoch."""
+        if not (self.by_epoch and self.initial_flag):
+            return
+        if self.start is not None and runner.epoch >= self.start:
+            self.after_train_epoch(runner)
+        self.initial_flag = False
+
+    def after_train_iter(self, runner):
+        """Called after every training iter to evaluate the results."""
+        if not self.by_epoch and self._should_evaluate(runner):
+            # Because the priority of EvalHook is higher than LoggerHook, the
+            # training log and the evaluating log are mixed. Therefore,
+            # we need to dump the training log and clear it before evaluating
+            # log is generated. In addition, this problem will only appear in
+            # `IterBasedRunner` whose `self.by_epoch` is False, because
+            # `EpochBasedRunner` whose `self.by_epoch` is True calls
+            # `_do_evaluate` in `after_train_epoch` stage, and at this stage
+            # the training log has been printed, so it will not cause any
+            # problem. more details at
+            # https://github.com/open-mmlab/mmsegmentation/issues/694
+            for hook in runner._hooks:
+                if isinstance(hook, LoggerHook):
+                    hook.after_train_iter(runner)
+            runner.log_buffer.clear()
+            self._do_evaluate(runner)
+
+    def after_train_epoch(self, runner):
+        """Called after every training epoch to evaluate the results."""
+        if self.by_epoch and self._should_evaluate(runner):
+            self._do_evaluate(runner)
+
+    def _do_evaluate(self, runner):
+        """perform evaluation and save ckpt."""
+        results = self.test_fn(runner.model, self.dataloader)
+        runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
+        key_score = self.evaluate(runner, results)
+        # `key_score` is `None` when there is nothing to compare, in which
+        # case saving the best checkpoint is skipped. Compare against `None`
+        # explicitly so that a legitimate score of 0 is not ignored.
+        if self.save_best and key_score is not None:
+            self._save_ckpt(runner, key_score)
+
+    def _should_evaluate(self, runner):
+        """Judge whether to perform evaluation.
+
+        Here is the rule to judge whether to perform evaluation:
+
+        1. It will not perform evaluation during the epoch/iteration interval,
+           which is determined by ``self.interval``.
+        2. It will not perform evaluation if the start time is larger than
+           current time.
+        3. It will not perform evaluation when current time is larger than
+           the start time but during epoch/iteration interval.
+
+        Returns:
+            bool: The flag indicating whether to perform evaluation.
+        """
+        if self.by_epoch:
+            current = runner.epoch
+            check_time = self.every_n_epochs
+        else:
+            current = runner.iter
+            check_time = self.every_n_iters
+        if self.start is None:
+            if not check_time(runner, self.interval):
+                # No evaluation during the interval.
+                return False
+        elif (current + 1) < self.start:
+            # No evaluation if start is larger than the current time.
+            return False
+        else:
+            # Evaluation only at epochs/iters 3, 5, 7...
+            # if start==3 and interval==2
+            if (current + 1 - self.start) % self.interval:
+                return False
+        return True
+
+    def _save_ckpt(self, runner, key_score):
+        """Save the best checkpoint.
+
+        It will compare the score according to the compare function, write
+        related information (best score, best checkpoint path) and save the
+        best checkpoint into ``self.out_dir``.
+        """
+        if self.by_epoch:
+            current = f'epoch_{runner.epoch + 1}'
+            cur_type, cur_time = 'epoch', runner.epoch + 1
+        else:
+            current = f'iter_{runner.iter + 1}'
+            cur_type, cur_time = 'iter', runner.iter + 1
+        best_score = runner.meta['hook_msgs'].get(
+            'best_score', self.init_value_map[self.rule])
+        if self.compare_func(key_score, best_score):
+            best_score = key_score
+            runner.meta['hook_msgs']['best_score'] = best_score
+            # Only one "best" checkpoint is kept: drop the previous file.
+            if self.best_ckpt_path and self.file_client.isfile(
+                    self.best_ckpt_path):
+                self.file_client.remove(self.best_ckpt_path)
+                runner.logger.info(
+                    (f'The previous best checkpoint {self.best_ckpt_path} was '
+                     'removed'))
+            best_ckpt_name = f'best_{self.key_indicator}_{current}.pth'
+            self.best_ckpt_path = self.file_client.join_path(
+                self.out_dir, best_ckpt_name)
+            runner.meta['hook_msgs']['best_ckpt'] = self.best_ckpt_path
+            runner.save_checkpoint(
+                self.out_dir, best_ckpt_name, create_symlink=False)
+            runner.logger.info(
+                f'Now best checkpoint is saved as {best_ckpt_name}.')
+            runner.logger.info(
+                f'Best {self.key_indicator} is {best_score:0.4f} '
+                f'at {cur_time} {cur_type}.')
+
+    def evaluate(self, runner, results):
+        """Evaluate the results.
+
+        Args:
+            runner (:obj:`mmcv.Runner`): The underlying training runner.
+            results (list): Output results.
+
+        Returns:
+            The value of the key indicator in the evaluation results, or
+            ``None`` when ``save_best`` is not set or the evaluation results
+            are empty.
+        """
+        eval_res = self.dataloader.dataset.evaluate(
+            results, logger=runner.logger, **self.eval_kwargs)
+        for name, val in eval_res.items():
+            runner.log_buffer.output[name] = val
+        runner.log_buffer.ready = True
+        if self.save_best is not None:
+            # If the performance of model is poor, the `eval_res` may be an
+            # empty dict and it will raise exception when `self.save_best` is
+            # not None. More details at
+            # https://github.com/open-mmlab/mmdetection/issues/6265.
+            if not eval_res:
+                warnings.warn(
+                    'Since `eval_res` is an empty dict, the behavior to save '
+                    'the best checkpoint will be skipped in this evaluation.')
+                return None
+            if self.key_indicator == 'auto':
+                # infer from eval_results: use the first returned metric
+                self._init_rule(self.rule, next(iter(eval_res)))
+            return eval_res[self.key_indicator]
+        return None
+class DistEvalHook(EvalHook):
+    """Distributed evaluation hook.
+
+    This hook will regularly perform evaluation in a given interval when
+    performing in distributed environment.
+
+    Args:
+        dataloader (DataLoader): A PyTorch dataloader, whose dataset has
+            implemented ``evaluate`` function.
+        start (int | None, optional): Evaluation starting epoch. It enables
+            evaluation before the training starts if ``start`` <= the resuming
+            epoch. If None, whether to evaluate is merely decided by
+            ``interval``. Default: None.
+        interval (int): Evaluation interval. Default: 1.
+        by_epoch (bool): Determine perform evaluation by epoch or by iteration.
+            If set to True, it will perform by epoch. Otherwise, by iteration.
+            default: True.
+        save_best (str, optional): If a metric is specified, it would measure
+            the best checkpoint during evaluation. The information about best
+            checkpoint would be saved in ``runner.meta['hook_msgs']`` to keep
+            best score value and best checkpoint path, which will be also
+            loaded when resume checkpoint. Options are the evaluation metrics
+            on the test dataset. e.g., ``bbox_mAP``, ``segm_mAP`` for bbox
+            detection and instance segmentation. ``AR@100`` for proposal
+            recall. If ``save_best`` is ``auto``, the first key of the returned
+            ``OrderedDict`` result will be used. Default: None.
+        rule (str | None, optional): Comparison rule for best score. If set to
+            None, it will infer a reasonable rule. Keys such as 'acc', 'top',
+            etc. will be inferred by 'greater' rule. Keys containing 'loss'
+            will be inferred by 'less' rule. Options are 'greater', 'less',
+            None. Default: None.
+        test_fn (callable, optional): test a model with samples from a
+            dataloader in a multi-gpu manner, and return the test results. If
+            ``None``, the default test function ``mmcv.engine.multi_gpu_test``
+            will be used. (default: ``None``)
+        tmpdir (str | None): Temporary directory to save the results of all
+            processes. Default: None.
+        gpu_collect (bool): Whether to use gpu or cpu to collect results.
+            Default: False.
+        broadcast_bn_buffer (bool): Whether to broadcast the
+            buffer(running_mean and running_var) of rank 0 to other rank
+            before evaluation. Default: True.
+        out_dir (str, optional): The root directory to save checkpoints. If not
+            specified, `runner.work_dir` will be used by default. If specified,
+            the `out_dir` will be the concatenation of `out_dir` and the last
+            level directory of `runner.work_dir`.
+        file_client_args (dict): Arguments to instantiate a FileClient.
+            See :class:`mmcv.fileio.FileClient` for details. Default: None.
+        **eval_kwargs: Evaluation arguments fed into the evaluate function of
+            the dataset.
+    """
+
+    def __init__(self,
+                 dataloader,
+                 start=None,
+                 interval=1,
+                 by_epoch=True,
+                 save_best=None,
+                 rule=None,
+                 test_fn=None,
+                 greater_keys=None,
+                 less_keys=None,
+                 broadcast_bn_buffer=True,
+                 tmpdir=None,
+                 gpu_collect=False,
+                 out_dir=None,
+                 file_client_args=None,
+                 **eval_kwargs):
+        if test_fn is None:
+            # Imported lazily, only when the default test function is needed.
+            from annotator.uniformer.mmcv.engine import multi_gpu_test
+            test_fn = multi_gpu_test
+        super().__init__(
+            dataloader,
+            start=start,
+            interval=interval,
+            by_epoch=by_epoch,
+            save_best=save_best,
+            rule=rule,
+            test_fn=test_fn,
+            greater_keys=greater_keys,
+            less_keys=less_keys,
+            out_dir=out_dir,
+            file_client_args=file_client_args,
+            **eval_kwargs)
+        self.broadcast_bn_buffer = broadcast_bn_buffer
+        self.tmpdir = tmpdir
+        self.gpu_collect = gpu_collect
+
+    def _do_evaluate(self, runner):
+        """perform evaluation and save ckpt."""
+        # Synchronization of BatchNorm's buffer (running_mean
+        # and running_var) is not supported in the DDP of pytorch,
+        # which may cause the inconsistent performance of models in
+        # different ranks, so we broadcast BatchNorm's buffers
+        # of rank 0 to other ranks to avoid this.
+        if self.broadcast_bn_buffer:
+            model = runner.model
+            for module in model.modules():
+                if isinstance(module,
+                              _BatchNorm) and module.track_running_stats:
+                    dist.broadcast(module.running_var, 0)
+                    dist.broadcast(module.running_mean, 0)
+        tmpdir = self.tmpdir
+        if tmpdir is None:
+            tmpdir = osp.join(runner.work_dir, '.eval_hook')
+        results = self.test_fn(
+            runner.model,
+            self.dataloader,
+            tmpdir=tmpdir,
+            gpu_collect=self.gpu_collect)
+        # Only rank 0 evaluates the results and saves the checkpoint.
+        if runner.rank == 0:
+            print('\n')
+            runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
+            key_score = self.evaluate(runner, results)
+            # `key_score` is `None` when there is nothing to compare, in
+            # which case saving the best checkpoint is skipped. Compare
+            # against `None` explicitly so that a legitimate score of 0 is
+            # not ignored.
+            if self.save_best and key_score is not None:
+                self._save_ckpt(runner, key_score)
diff --git a/ControlNet/annotator/uniformer/mmcv/runner/hooks/hook.py b/ControlNet/annotator/uniformer/mmcv/runner/hooks/hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8855c107727ecf85b917c890fc8b7f6359238a4
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/runner/hooks/hook.py
@@ -0,0 +1,92 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from annotator.uniformer.mmcv.utils import Registry, is_method_overridden
+HOOKS = Registry('hook')
+class Hook:
+    """Base class of runner hooks.
+
+    Every stage method is a no-op by default. The train/val specific
+    methods delegate to the generic ``before_epoch``/``after_epoch``/
+    ``before_iter``/``after_iter`` methods.
+    """
+
+    stages = ('before_run', 'before_train_epoch', 'before_train_iter',
+              'after_train_iter', 'after_train_epoch', 'before_val_epoch',
+              'before_val_iter', 'after_val_iter', 'after_val_epoch',
+              'after_run')
+
+    def before_run(self, runner):
+        pass
+
+    def after_run(self, runner):
+        pass
+
+    def before_epoch(self, runner):
+        pass
+
+    def after_epoch(self, runner):
+        pass
+
+    def before_iter(self, runner):
+        pass
+
+    def after_iter(self, runner):
+        pass
+
+    def before_train_epoch(self, runner):
+        self.before_epoch(runner)
+
+    def before_val_epoch(self, runner):
+        self.before_epoch(runner)
+
+    def after_train_epoch(self, runner):
+        self.after_epoch(runner)
+
+    def after_val_epoch(self, runner):
+        self.after_epoch(runner)
+
+    def before_train_iter(self, runner):
+        self.before_iter(runner)
+
+    def before_val_iter(self, runner):
+        self.before_iter(runner)
+
+    def after_train_iter(self, runner):
+        self.after_iter(runner)
+
+    def after_val_iter(self, runner):
+        self.after_iter(runner)
+
+    def every_n_epochs(self, runner, n):
+        """Whether ``runner.epoch + 1`` is a multiple of ``n``.
+
+        Returns False when ``n`` is not positive.
+        """
+        if not n > 0:
+            return False
+        return (runner.epoch + 1) % n == 0
+
+    def every_n_inner_iters(self, runner, n):
+        """Whether ``runner.inner_iter + 1`` is a multiple of ``n``.
+
+        Returns False when ``n`` is not positive.
+        """
+        if not n > 0:
+            return False
+        return (runner.inner_iter + 1) % n == 0
+
+    def every_n_iters(self, runner, n):
+        """Whether ``runner.iter + 1`` is a multiple of ``n``.
+
+        Returns False when ``n`` is not positive.
+        """
+        if not n > 0:
+            return False
+        return (runner.iter + 1) % n == 0
+
+    def end_of_epoch(self, runner):
+        """Whether the next inner iteration count equals the loader length."""
+        total = len(runner.data_loader)
+        return runner.inner_iter + 1 == total
+
+    def is_last_epoch(self, runner):
+        """Whether ``runner.epoch + 1`` equals ``runner._max_epochs``."""
+        return runner._max_epochs == runner.epoch + 1
+
+    def is_last_iter(self, runner):
+        """Whether ``runner.iter + 1`` equals ``runner._max_iters``."""
+        return runner._max_iters == runner.iter + 1
+
+    def get_triggered_stages(self):
+        """List the stages this hook overrides, in ``Hook.stages`` order."""
+        triggered = {
+            stage
+            for stage in Hook.stages
+            if is_method_overridden(stage, Hook, self)
+        }
+        # some methods will be triggered in multi stages
+        # use this dict to map method to stages.
+        generic_to_stages = {
+            'before_epoch': ['before_train_epoch', 'before_val_epoch'],
+            'after_epoch': ['after_train_epoch', 'after_val_epoch'],
+            'before_iter': ['before_train_iter', 'before_val_iter'],
+            'after_iter': ['after_train_iter', 'after_val_iter'],
+        }
+        for generic_name, mapped_stages in generic_to_stages.items():
+            if is_method_overridden(generic_name, Hook, self):
+                triggered.update(mapped_stages)
+        return [stage for stage in Hook.stages if stage in triggered]
diff --git a/ControlNet/annotator/uniformer/mmcv/runner/hooks/iter_timer.py b/ControlNet/annotator/uniformer/mmcv/runner/hooks/iter_timer.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfd5002fe85ffc6992155ac01003878064a1d9be
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/runner/hooks/iter_timer.py
@@ -0,0 +1,18 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import time
+from .hook import HOOKS, Hook
+class IterTimerHook(Hook):
+    """Record per-iteration timings into ``runner.log_buffer``.
+
+    ``data_time`` is the time elapsed between the reference timestamp and
+    the start of an iteration; ``time`` is the time elapsed between the
+    reference timestamp and the end of the iteration.
+    """
+
+    def before_epoch(self, runner):
+        # Reference timestamp for the first iteration of the epoch.
+        self.t = time.time()
+
+    def before_iter(self, runner):
+        elapsed = time.time() - self.t
+        runner.log_buffer.update({'data_time': elapsed})
+
+    def after_iter(self, runner):
+        elapsed = time.time() - self.t
+        runner.log_buffer.update({'time': elapsed})
+        # Restart the clock for the next iteration.
+        self.t = time.time()
diff --git a/ControlNet/annotator/uniformer/mmcv/runner/hooks/logger/__init__.py b/ControlNet/annotator/uniformer/mmcv/runner/hooks/logger/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0b6b345640a895368ac8a647afef6f24333d90e
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/runner/hooks/logger/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base import LoggerHook
+from .dvclive import DvcliveLoggerHook
+from .mlflow import MlflowLoggerHook
+from .neptune import NeptuneLoggerHook
+from .pavi import PaviLoggerHook
+from .tensorboard import TensorboardLoggerHook
+from .text import TextLoggerHook
+from .wandb import WandbLoggerHook
+__all__ = [
+    'LoggerHook', 'MlflowLoggerHook', 'PaviLoggerHook',
+    'TensorboardLoggerHook', 'TextLoggerHook', 'WandbLoggerHook',
+    'NeptuneLoggerHook', 'DvcliveLoggerHook'
+]
diff --git a/ControlNet/annotator/uniformer/mmcv/runner/hooks/logger/base.py b/ControlNet/annotator/uniformer/mmcv/runner/hooks/logger/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..f845256729458ced821762a1b8ef881e17ff9955
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/runner/hooks/logger/base.py
@@ -0,0 +1,166 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numbers
+from abc import ABCMeta, abstractmethod
+import numpy as np
+import torch
+from ..hook import Hook
+class LoggerHook(Hook):
+    """Base class for logger hooks.
+
+    Args:
+        interval (int): Logging interval (every k iterations).
+        ignore_last (bool): Ignore the log of last iterations in each epoch
+            if less than `interval`.
+        reset_flag (bool): Whether to clear the output buffer after logging.
+        by_epoch (bool): Whether EpochBasedRunner is used.
+    """
+
+    # NOTE: ``__metaclass__`` is the Python 2 spelling and has no effect on
+    # Python 3, so ``log`` is not actually enforced as abstract. It is kept
+    # as is so that the class stays instantiable for existing callers.
+    __metaclass__ = ABCMeta
+
+    def __init__(self,
+                 interval=10,
+                 ignore_last=True,
+                 reset_flag=False,
+                 by_epoch=True):
+        self.interval = interval
+        self.ignore_last = ignore_last
+        self.reset_flag = reset_flag
+        self.by_epoch = by_epoch
+
+    @abstractmethod
+    def log(self, runner):
+        """Emit the content of ``runner.log_buffer``; set by subclasses."""
+        pass
+
+    @staticmethod
+    def is_scalar(val, include_np=True, include_torch=True):
+        """Tell the input variable is a scalar or not.
+
+        Args:
+            val: Input variable.
+            include_np (bool): Whether include 0-d np.ndarray as a scalar.
+            include_torch (bool): Whether include single-element
+                torch.Tensor (0-d included) as a scalar.
+
+        Returns:
+            bool: True or False.
+        """
+        if isinstance(val, numbers.Number):
+            return True
+        elif include_np and isinstance(val, np.ndarray) and val.ndim == 0:
+            return True
+        # ``numel`` is used instead of ``len``: ``len`` raises ``TypeError``
+        # on a 0-d tensor and only looks at the first dimension otherwise.
+        elif (include_torch and isinstance(val, torch.Tensor)
+              and val.numel() == 1):
+            return True
+        else:
+            return False
+
+    def get_mode(self, runner):
+        """Return 'train' or 'val' for the current log buffer content.
+
+        In train mode the buffer is treated as validation output when it
+        holds no 'time' entry.
+
+        Raises:
+            ValueError: If ``runner.mode`` is neither 'train' nor 'val'.
+        """
+        if runner.mode == 'train':
+            if 'time' in runner.log_buffer.output:
+                mode = 'train'
+            else:
+                mode = 'val'
+        elif runner.mode == 'val':
+            mode = 'val'
+        else:
+            raise ValueError(f"runner mode should be 'train' or 'val', "
+                             f'but got {runner.mode}')
+        return mode
+
+    def get_epoch(self, runner):
+        """Return the 1-based epoch number to report.
+
+        Raises:
+            ValueError: If ``runner.mode`` is neither 'train' nor 'val'.
+        """
+        if runner.mode == 'train':
+            epoch = runner.epoch + 1
+        elif runner.mode == 'val':
+            # normal val mode
+            # runner.epoch += 1 has been done before val workflow
+            epoch = runner.epoch
+        else:
+            raise ValueError(f"runner mode should be 'train' or 'val', "
+                             f'but got {runner.mode}')
+        return epoch
+
+    def get_iter(self, runner, inner_iter=False):
+        """Get the current training iteration step."""
+        if self.by_epoch and inner_iter:
+            current_iter = runner.inner_iter + 1
+        else:
+            current_iter = runner.iter + 1
+        return current_iter
+
+    def get_lr_tags(self, runner):
+        """Collect the first learning rate of each optimizer as tags."""
+        tags = {}
+        lrs = runner.current_lr()
+        if isinstance(lrs, dict):
+            for name, value in lrs.items():
+                tags[f'learning_rate/{name}'] = value[0]
+        else:
+            tags['learning_rate'] = lrs[0]
+        return tags
+
+    def get_momentum_tags(self, runner):
+        """Collect the first momentum of each optimizer as tags."""
+        tags = {}
+        momentums = runner.current_momentum()
+        if isinstance(momentums, dict):
+            for name, value in momentums.items():
+                tags[f'momentum/{name}'] = value[0]
+        else:
+            tags['momentum'] = momentums[0]
+        return tags
+
+    def get_loggable_tags(self,
+                          runner,
+                          allow_scalar=True,
+                          allow_text=False,
+                          add_mode=True,
+                          tags_to_skip=('time', 'data_time')):
+        """Build the tag dict to log from ``runner.log_buffer.output``.
+
+        Args:
+            runner: The runner whose log buffer is read.
+            allow_scalar (bool): Whether scalar values are kept.
+            allow_text (bool): Whether ``str`` values are kept.
+            add_mode (bool): Whether to prefix tag names with the mode
+                returned by ``get_mode``.
+            tags_to_skip (tuple[str]): Buffer keys that are never logged.
+
+        Returns:
+            dict: The filtered tags, plus learning rate and momentum tags.
+        """
+        tags = {}
+        for var, val in runner.log_buffer.output.items():
+            if var in tags_to_skip:
+                continue
+            if self.is_scalar(val) and not allow_scalar:
+                continue
+            if isinstance(val, str) and not allow_text:
+                continue
+            if add_mode:
+                var = f'{self.get_mode(runner)}/{var}'
+            tags[var] = val
+        tags.update(self.get_lr_tags(runner))
+        tags.update(self.get_momentum_tags(runner))
+        return tags
+
+    def before_run(self, runner):
+        # Force ``reset_flag`` on the last logger hook in ``runner.hooks``.
+        for hook in runner.hooks[::-1]:
+            if isinstance(hook, LoggerHook):
+                hook.reset_flag = True
+                break
+
+    def before_epoch(self, runner):
+        runner.log_buffer.clear()  # clear logs of last epoch
+
+    def after_train_iter(self, runner):
+        if self.by_epoch and self.every_n_inner_iters(runner, self.interval):
+            runner.log_buffer.average(self.interval)
+        elif not self.by_epoch and self.every_n_iters(runner, self.interval):
+            runner.log_buffer.average(self.interval)
+        elif self.end_of_epoch(runner) and not self.ignore_last:
+            # not precise but more stable
+            runner.log_buffer.average(self.interval)
+        if runner.log_buffer.ready:
+            self.log(runner)
+            if self.reset_flag:
+                runner.log_buffer.clear_output()
+
+    def after_train_epoch(self, runner):
+        if runner.log_buffer.ready:
+            self.log(runner)
+            if self.reset_flag:
+                runner.log_buffer.clear_output()
+
+    def after_val_epoch(self, runner):
+        runner.log_buffer.average()
+        self.log(runner)
+        if self.reset_flag:
+            runner.log_buffer.clear_output()
diff --git a/ControlNet/annotator/uniformer/mmcv/runner/hooks/logger/dvclive.py b/ControlNet/annotator/uniformer/mmcv/runner/hooks/logger/dvclive.py
new file mode 100644
index 0000000000000000000000000000000000000000..687cdc58c0336c92b1e4f9a410ba67ebaab2bc7a
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/runner/hooks/logger/dvclive.py
@@ -0,0 +1,58 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ...dist_utils import master_only
+from ..hook import HOOKS
+from .base import LoggerHook
+class DvcliveLoggerHook(LoggerHook):
+    """Logger hook that reports metrics through dvclive.
+
+    It requires `dvclive`_ to be installed.
+
+    Args:
+        path (str): Directory where dvclive will write TSV log files.
+        interval (int): Logging interval (every k iterations).
+            Default 10.
+        ignore_last (bool): Ignore the log of last iterations in each epoch
+            if less than `interval`.
+            Default: True.
+        reset_flag (bool): Whether to clear the output buffer after logging.
+            Default: True.
+        by_epoch (bool): Whether EpochBasedRunner is used.
+            Default: True.
+
+    .. _dvclive:
+        https://dvc.org/doc/dvclive
+    """
+
+    def __init__(self,
+                 path,
+                 interval=10,
+                 ignore_last=True,
+                 reset_flag=True,
+                 by_epoch=True):
+        super().__init__(interval, ignore_last, reset_flag, by_epoch)
+        self.path = path
+        self.import_dvclive()
+
+    def import_dvclive(self):
+        """Import dvclive and keep the module on ``self.dvclive``.
+
+        Raises:
+            ImportError: If dvclive is not installed.
+        """
+        try:
+            import dvclive
+        except ImportError:
+            raise ImportError(
+                'Please run "pip install dvclive" to install dvclive')
+        self.dvclive = dvclive
+
+    @master_only
+    def before_run(self, runner):
+        self.dvclive.init(self.path)
+
+    @master_only
+    def log(self, runner):
+        """Log every loggable tag at the current iteration step."""
+        tags = self.get_loggable_tags(runner)
+        if not tags:
+            return
+        for tag_name, tag_value in tags.items():
+            self.dvclive.log(tag_name, tag_value, step=self.get_iter(runner))
diff --git a/ControlNet/annotator/uniformer/mmcv/runner/hooks/logger/mlflow.py b/ControlNet/annotator/uniformer/mmcv/runner/hooks/logger/mlflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9a72592be47b534ce22573775fd5a7e8e86d72d
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/runner/hooks/logger/mlflow.py
@@ -0,0 +1,78 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ...dist_utils import master_only
+from ..hook import HOOKS
+from .base import LoggerHook
+class MlflowLoggerHook(LoggerHook):
+
+    def __init__(self,
+                 exp_name=None,
+                 tags=None,
+                 log_model=True,
+                 interval=10,
+                 ignore_last=True,
+                 reset_flag=False,
+                 by_epoch=True):
+        """Class to log metrics and (optionally) a trained model to MLflow.
+
+        It requires `MLflow`_ to be installed.
+
+        Args:
+            exp_name (str, optional): Name of the experiment to be used.
+                Default None.
+                If not None, set the active experiment.
+                If experiment does not exist, an experiment with provided name
+                will be created.
+            tags (dict of str: str, optional): Tags for the current run.
+                Default None.
+                If not None, set tags for the current run.
+            log_model (bool, optional): Whether to log an MLflow artifact.
+                Default True.
+                If True, log runner.model as an MLflow artifact
+                for the current run.
+            interval (int): Logging interval (every k iterations).
+            ignore_last (bool): Ignore the log of last iterations in each epoch
+                if less than `interval`.
+            reset_flag (bool): Whether to clear the output buffer after logging
+            by_epoch (bool): Whether EpochBasedRunner is used.
+
+        .. _MLflow:
+            https://www.mlflow.org/docs/latest/index.html
+        """
+        super().__init__(interval, ignore_last, reset_flag, by_epoch)
+        self.import_mlflow()
+        self.exp_name = exp_name
+        self.tags = tags
+        self.log_model = log_model
+
+    def import_mlflow(self):
+        """Import mlflow and ``mlflow.pytorch`` and keep them on ``self``.
+
+        Raises:
+            ImportError: If mlflow is not installed.
+        """
+        try:
+            import mlflow
+            import mlflow.pytorch as mlflow_pytorch
+        except ImportError:
+            raise ImportError(
+                'Please run "pip install mlflow" to install mlflow')
+        self.mlflow = mlflow
+        self.mlflow_pytorch = mlflow_pytorch
+
+    @master_only
+    def before_run(self, runner):
+        super().before_run(runner)
+        # Experiment and tags are optional; only set the ones provided.
+        if self.exp_name is not None:
+            self.mlflow.set_experiment(self.exp_name)
+        if self.tags is not None:
+            self.mlflow.set_tags(self.tags)
+
+    @master_only
+    def log(self, runner):
+        """Log all loggable tags as metrics at the current iteration."""
+        metrics = self.get_loggable_tags(runner)
+        if not metrics:
+            return
+        self.mlflow.log_metrics(metrics, step=self.get_iter(runner))
+
+    @master_only
+    def after_run(self, runner):
+        """Log ``runner.model`` as an artifact when ``log_model`` is set."""
+        if not self.log_model:
+            return
+        self.mlflow_pytorch.log_model(runner.model, 'models')
diff --git a/ControlNet/annotator/uniformer/mmcv/runner/hooks/logger/neptune.py b/ControlNet/annotator/uniformer/mmcv/runner/hooks/logger/neptune.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a38772b0c93a8608f32c6357b8616e77c139dc9
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/runner/hooks/logger/neptune.py
@@ -0,0 +1,82 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ...dist_utils import master_only
+from ..hook import HOOKS
+from .base import LoggerHook
+class NeptuneLoggerHook(LoggerHook):
+    """Class to log metrics to NeptuneAI.
+    It requires `neptune-client` to be installed.
+    Args:
+        init_kwargs (dict): a dict contains the initialization keys as below:
+            - project (str): Name of a project in a form of
+              namespace/project_name. If None, the value of
+              NEPTUNE_PROJECT environment variable will be taken.
+            - api_token (str): User's API token.
+              If None, the value of NEPTUNE_API_TOKEN environment
+              variable will be taken. Note: It is strongly recommended
+              to use NEPTUNE_API_TOKEN environment variable rather than
+              placing your API token in plain text in your source code.
+            - name (str, optional, default is 'Untitled'): Editable name of
+              the run. Name is displayed in the run's Details and in
+              Runs table as a column.
+            Check https://docs.neptune.ai/api-reference/neptune#init for
+            more init arguments.
+        interval (int): Logging interval (every k iterations).
+        ignore_last (bool): Ignore the log of last iterations in each epoch
+            if less than `interval`.
+        reset_flag (bool): Whether to clear the output buffer after logging
+        with_step (bool): If True, every tag value is logged with the
+            current iteration passed as ``step``. If False, values are
+            logged without a step and the iteration is logged to a separate
+            ``global_step`` series. Default: True.
+        by_epoch (bool): Whether EpochBasedRunner is used.
+    .. _NeptuneAI:
+        https://docs.neptune.ai/you-should-know/logging-metadata
+    """
+    def __init__(self,
+                 init_kwargs=None,
+                 interval=10,
+                 ignore_last=True,
+                 reset_flag=True,
+                 with_step=True,
+                 by_epoch=True):
+        super(NeptuneLoggerHook, self).__init__(interval, ignore_last,
+                                                reset_flag, by_epoch)
+        self.import_neptune()
+        self.init_kwargs = init_kwargs
+        self.with_step = with_step
+    def import_neptune(self):
+        """Import ``neptune.new`` and reset the run handle.
+        Raises:
+            ImportError: If ``neptune.new`` cannot be imported.
+        """
+        try:
+            import neptune.new as neptune
+        except ImportError:
+            raise ImportError(
+                'Please run "pip install neptune-client" to install neptune')
+        self.neptune = neptune
+        self.run = None
+    @master_only
+    def before_run(self, runner):
+        """Start the Neptune run, forwarding ``init_kwargs`` when given."""
+        if self.init_kwargs:
+            self.run = self.neptune.init(**self.init_kwargs)
+        else:
+            self.run = self.neptune.init()
+    @master_only
+    def log(self, runner):
+        """Log every loggable tag of ``runner`` to its own Neptune series."""
+        tags = self.get_loggable_tags(runner)
+        if not tags:
+            return
+        step = self.get_iter(runner)
+        if self.with_step:
+            for tag_name, tag_value in tags.items():
+                self.run[tag_name].log(tag_value, step=step)
+        else:
+            # The previous code inserted 'global_step' into ``tags`` while
+            # iterating over it (RuntimeError: dictionary changed size
+            # during iteration) and passed the whole dict to one series.
+            # Log the step as its own series and each value on its own.
+            self.run['global_step'].log(step)
+            for tag_name, tag_value in tags.items():
+                self.run[tag_name].log(tag_value)
+    @master_only
+    def after_run(self, runner):
+        """Stop the Neptune run if one was started."""
+        # ``self.run`` stays None until ``before_run`` has created a run.
+        if self.run is not None:
+            self.run.stop()
diff --git a/ControlNet/annotator/uniformer/mmcv/runner/hooks/logger/pavi.py b/ControlNet/annotator/uniformer/mmcv/runner/hooks/logger/pavi.py
new file mode 100644
index 0000000000000000000000000000000000000000..1dcf146d8163aff1363e9764999b0a74d674a595
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/runner/hooks/logger/pavi.py
@@ -0,0 +1,117 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import json
+import os
+import os.path as osp
+import torch
+import yaml
+import annotator.uniformer.mmcv as mmcv
+from ....parallel.utils import is_module_wrapper
+from ...dist_utils import master_only
+from ..hook import HOOKS
+from .base import LoggerHook
+class PaviLoggerHook(LoggerHook):
+    """Logger hook that writes training logs through ``pavi.SummaryWriter``.
+    Args:
+        init_kwargs (dict, optional): Keyword arguments for
+            ``pavi.SummaryWriter``. ``name``, ``model`` and, when a config
+            is available in ``runner.meta``, ``session_text`` are filled in
+            by :meth:`before_run`. Default: None.
+        add_graph (bool): Whether to call ``writer.add_graph`` with the
+            model before the first epoch. Default: False.
+        add_last_ckpt (bool): Whether to upload ``latest.pth`` from
+            ``runner.work_dir`` in :meth:`after_run`. Default: False.
+        interval (int): Forwarded to ``LoggerHook.__init__``. Default: 10.
+        ignore_last (bool): Forwarded to ``LoggerHook.__init__``.
+            Default: True.
+        reset_flag (bool): Forwarded to ``LoggerHook.__init__``.
+            Default: False.
+        by_epoch (bool): Forwarded to ``LoggerHook.__init__``.
+            Default: True.
+        img_key (str): Key of the data batch holding the input passed to
+            ``writer.add_graph``. Default: 'img_info'.
+    """
+    def __init__(self,
+                 init_kwargs=None,
+                 add_graph=False,
+                 add_last_ckpt=False,
+                 interval=10,
+                 ignore_last=True,
+                 reset_flag=False,
+                 by_epoch=True,
+                 img_key='img_info'):
+        super(PaviLoggerHook, self).__init__(interval, ignore_last, reset_flag,
+                                             by_epoch)
+        self.init_kwargs = init_kwargs
+        self.add_graph = add_graph
+        self.add_last_ckpt = add_last_ckpt
+        self.img_key = img_key
+    @master_only
+    def before_run(self, runner):
+        """Build the writer kwargs and create the Pavi ``SummaryWriter``.
+        Raises:
+            ImportError: If ``pavi`` cannot be imported.
+        """
+        super(PaviLoggerHook, self).before_run(runner)
+        try:
+            from pavi import SummaryWriter
+        except ImportError:
+            raise ImportError('Please run "pip install pavi" to install pavi.')
+        # Last path component of work_dir. Trailing separators are stripped
+        # first so 'work_dirs/exp/' yields 'exp' rather than an empty name,
+        # matching how TextLoggerHook derives the same directory name.
+        self.run_name = osp.basename(runner.work_dir.rstrip(osp.sep))
+        if not self.init_kwargs:
+            self.init_kwargs = dict()
+        self.init_kwargs['name'] = self.run_name
+        self.init_kwargs['model'] = runner._model_name
+        if runner.meta is not None:
+            if 'config_dict' in runner.meta:
+                config_dict = runner.meta['config_dict']
+                assert isinstance(
+                    config_dict,
+                    dict), ('meta["config_dict"] has to be of a dict, '
+                            f'but got {type(config_dict)}')
+            elif 'config_file' in runner.meta:
+                config_file = runner.meta['config_file']
+                config_dict = dict(mmcv.Config.fromfile(config_file))
+            else:
+                config_dict = None
+            if config_dict is not None:
+                # 'max_.*iter' is parsed in pavi sdk as the maximum iterations
+                # to properly set up the progress bar.
+                config_dict = config_dict.copy()
+                config_dict.setdefault('max_iter', runner.max_iters)
+                # non-serializable values are first converted in
+                # mmcv.dump to json
+                config_dict = json.loads(
+                    mmcv.dump(config_dict, file_format='json'))
+                session_text = yaml.dump(config_dict)
+                self.init_kwargs['session_text'] = session_text
+        self.writer = SummaryWriter(**self.init_kwargs)
+    def get_step(self, runner):
+        """Get the total training step/epoch."""
+        if self.get_mode(runner) == 'val' and self.by_epoch:
+            return self.get_epoch(runner)
+        else:
+            return self.get_iter(runner)
+    @master_only
+    def log(self, runner):
+        """Write the loggable tags as scalars under the current mode."""
+        tags = self.get_loggable_tags(runner, add_mode=False)
+        if tags:
+            self.writer.add_scalars(
+                self.get_mode(runner), tags, self.get_step(runner))
+    @master_only
+    def after_run(self, runner):
+        """Optionally upload the last checkpoint, then close the writer.
+        Returns:
+            The result of ``writer.add_snapshot_file`` when a checkpoint
+            was uploaded, otherwise None.
+        """
+        snapshot_result = None
+        if self.add_last_ckpt:
+            ckpt_path = osp.join(runner.work_dir, 'latest.pth')
+            if osp.islink(ckpt_path):
+                ckpt_path = osp.join(runner.work_dir, os.readlink(ckpt_path))
+            if osp.isfile(ckpt_path):
+                # runner.epoch += 1 has been done before `after_run`.
+                iteration = runner.epoch if self.by_epoch else runner.iter
+                snapshot_result = self.writer.add_snapshot_file(
+                    tag=self.run_name,
+                    snapshot_file_path=ckpt_path,
+                    iteration=iteration)
+        # flush the buffer and send a task ending signal to Pavi.
+        # The previous code returned straight from the upload branch, so
+        # the writer was never closed when a checkpoint had been added.
+        self.writer.close()
+        return snapshot_result
+    @master_only
+    def before_epoch(self, runner):
+        """Add the model graph to the writer before the first epoch."""
+        if runner.epoch == 0 and self.add_graph:
+            if is_module_wrapper(runner.model):
+                _model = runner.model.module
+            else:
+                _model = runner.model
+            device = next(_model.parameters()).device
+            data = next(iter(runner.data_loader))
+            # Keep only the first sample of the batch as the graph input.
+            image = data[self.img_key][0:1].to(device)
+            with torch.no_grad():
+                self.writer.add_graph(_model, image)
diff --git a/ControlNet/annotator/uniformer/mmcv/runner/hooks/logger/tensorboard.py b/ControlNet/annotator/uniformer/mmcv/runner/hooks/logger/tensorboard.py
new file mode 100644
index 0000000000000000000000000000000000000000..4dd5011dc08def6c09eef86d3ce5b124c9fc5372
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/runner/hooks/logger/tensorboard.py
@@ -0,0 +1,57 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from annotator.uniformer.mmcv.utils import TORCH_VERSION, digit_version
+from ...dist_utils import master_only
+from ..hook import HOOKS
+from .base import LoggerHook
+class TensorboardLoggerHook(LoggerHook):
+    """Logger hook that writes scalars and text to a ``SummaryWriter``.
+    Args:
+        log_dir (str, optional): Directory handed to ``SummaryWriter``.
+            When None, ``<runner.work_dir>/tf_logs`` is used. Default: None.
+        interval (int): Forwarded to ``LoggerHook.__init__``. Default: 10.
+        ignore_last (bool): Forwarded to ``LoggerHook.__init__``.
+            Default: True.
+        reset_flag (bool): Forwarded to ``LoggerHook.__init__``.
+            Default: False.
+        by_epoch (bool): Forwarded to ``LoggerHook.__init__``.
+            Default: True.
+    """
+    def __init__(self,
+                 log_dir=None,
+                 interval=10,
+                 ignore_last=True,
+                 reset_flag=False,
+                 by_epoch=True):
+        super().__init__(interval, ignore_last, reset_flag, by_epoch)
+        self.log_dir = log_dir
+    @master_only
+    def before_run(self, runner):
+        """Pick the ``SummaryWriter`` implementation and open the writer.
+        ``tensorboardX`` is used for 'parrots' or torch older than 1.1,
+        ``torch.utils.tensorboard`` otherwise.
+        Raises:
+            ImportError: If the selected writer package cannot be imported.
+        """
+        super().before_run(runner)
+        needs_tensorboardx = (
+            TORCH_VERSION == 'parrots'
+            or digit_version(TORCH_VERSION) < digit_version('1.1'))
+        if not needs_tensorboardx:
+            try:
+                from torch.utils.tensorboard import SummaryWriter
+            except ImportError:
+                raise ImportError(
+                    'Please run "pip install future tensorboard" to install '
+                    'the dependencies to use torch.utils.tensorboard '
+                    '(applicable to PyTorch 1.1 or higher)')
+        else:
+            try:
+                from tensorboardX import SummaryWriter
+            except ImportError:
+                raise ImportError('Please install tensorboardX to use '
+                                  'TensorboardLoggerHook.')
+        if self.log_dir is None:
+            self.log_dir = osp.join(runner.work_dir, 'tf_logs')
+        self.writer = SummaryWriter(self.log_dir)
+    @master_only
+    def log(self, runner):
+        """Write string tags as text and all other tags as scalars."""
+        entries = self.get_loggable_tags(runner, allow_text=True)
+        for tag, val in entries.items():
+            if isinstance(val, str):
+                write = self.writer.add_text
+            else:
+                write = self.writer.add_scalar
+            write(tag, val, self.get_iter(runner))
+    @master_only
+    def after_run(self, runner):
+        """Close the writer."""
+        self.writer.close()
diff --git a/ControlNet/annotator/uniformer/mmcv/runner/hooks/logger/text.py b/ControlNet/annotator/uniformer/mmcv/runner/hooks/logger/text.py
new file mode 100644
index 0000000000000000000000000000000000000000..87b1a3eca9595a130121526f8b4c29915387ab35
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/runner/hooks/logger/text.py
@@ -0,0 +1,256 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import datetime
+import os
+import os.path as osp
+from collections import OrderedDict
+import torch
+import torch.distributed as dist
+import annotator.uniformer.mmcv as mmcv
+from annotator.uniformer.mmcv.fileio.file_client import FileClient
+from annotator.uniformer.mmcv.utils import is_tuple_of, scandir
+from ..hook import HOOKS
+from .base import LoggerHook
+class TextLoggerHook(LoggerHook):
+    """Logger hook in text.
+    In this logger hook, the information will be printed on terminal and
+    saved in json file.
+    Args:
+        by_epoch (bool, optional): Whether EpochBasedRunner is used.
+            Default: True.
+        interval (int, optional): Logging interval (every k iterations).
+            Default: 10.
+        ignore_last (bool, optional): Ignore the log of last iterations in each
+            epoch if less than :attr:`interval`. Default: True.
+        reset_flag (bool, optional): Whether to clear the output buffer after
+            logging. Default: False.
+        interval_exp_name (int, optional): Logging interval for experiment
+            name. This feature is to help users conveniently get the experiment
+            information from screen or log file. Default: 1000.
+        out_dir (str, optional): Logs are saved in ``runner.work_dir`` default.
+            If ``out_dir`` is specified, logs will be copied to a new directory
+            which is the concatenation of ``out_dir`` and the last level
+            directory of ``runner.work_dir``. Default: None.
+            `New in version 1.3.16.`
+        out_suffix (str or tuple[str], optional): Those filenames ending with
+            ``out_suffix`` will be copied to ``out_dir``.
+            Default: ('.log.json', '.log', '.py').
+            `New in version 1.3.16.`
+        keep_local (bool, optional): Whether to keep local log when
+            :attr:`out_dir` is specified. If False, the local log will be
+            removed. Default: True.
+            `New in version 1.3.16.`
+        file_client_args (dict, optional): Arguments to instantiate a
+            FileClient. See :class:`mmcv.fileio.FileClient` for details.
+            Default: None.
+            `New in version 1.3.16.`
+    Raises:
+        ValueError: If ``file_client_args`` is given without ``out_dir``.
+        TypeError: If ``out_dir`` is not None, a string or a tuple of
+            strings.
+    """
+    def __init__(self,
+                 by_epoch=True,
+                 interval=10,
+                 ignore_last=True,
+                 reset_flag=False,
+                 interval_exp_name=1000,
+                 out_dir=None,
+                 out_suffix=('.log.json', '.log', '.py'),
+                 keep_local=True,
+                 file_client_args=None):
+        super(TextLoggerHook, self).__init__(interval, ignore_last, reset_flag,
+                                             by_epoch)
+        self.by_epoch = by_epoch
+        self.time_sec_tot = 0
+        self.interval_exp_name = interval_exp_name
+        if out_dir is None and file_client_args is not None:
+            # Trailing space added: the two literals used to concatenate
+            # into "notspecified.".
+            raise ValueError(
+                'file_client_args should be "None" when `out_dir` is not '
+                'specified.')
+        self.out_dir = out_dir
+        if not (out_dir is None or isinstance(out_dir, str)
+                or is_tuple_of(out_dir, str)):
+            # The second literal needs the f-prefix, otherwise the message
+            # shows the text "{out_dir}" instead of the offending value.
+            raise TypeError('out_dir should be "None" or string or tuple of '
+                            f'string, but got {out_dir}')
+        self.out_suffix = out_suffix
+        self.keep_local = keep_local
+        self.file_client_args = file_client_args
+        if self.out_dir is not None:
+            self.file_client = FileClient.infer_client(file_client_args,
+                                                       self.out_dir)
+    def before_run(self, runner):
+        """Resolve the output directory and open the json log for the run."""
+        super(TextLoggerHook, self).before_run(runner)
+        if self.out_dir is not None:
+            self.file_client = FileClient.infer_client(self.file_client_args,
+                                                       self.out_dir)
+            # The final `self.out_dir` is the concatenation of `self.out_dir`
+            # and the last level directory of `runner.work_dir`
+            basename = osp.basename(runner.work_dir.rstrip(osp.sep))
+            self.out_dir = self.file_client.join_path(self.out_dir, basename)
+            runner.logger.info(
+                (f'Text logs will be saved to {self.out_dir} by '
+                 f'{self.file_client.name} after the training process.'))
+        self.start_iter = runner.iter
+        self.json_log_path = osp.join(runner.work_dir,
+                                      f'{runner.timestamp}.log.json')
+        if runner.meta is not None:
+            self._dump_log(runner.meta, runner)
+    def _get_max_memory(self, runner):
+        """Return the max allocated CUDA memory in MB as an int.
+        When ``runner.world_size > 1`` the value is reduced with MAX onto
+        rank 0 before being read.
+        """
+        device = getattr(runner.model, 'output_device', None)
+        mem = torch.cuda.max_memory_allocated(device=device)
+        mem_mb = torch.tensor([mem / (1024 * 1024)],
+                              dtype=torch.int,
+                              device=device)
+        if runner.world_size > 1:
+            dist.reduce(mem_mb, 0, op=dist.ReduceOp.MAX)
+        return mem_mb.item()
+    def _log_info(self, log_dict, runner):
+        """Format ``log_dict`` as one line and send it to ``runner.logger``."""
+        # print exp name for users to distinguish experiments
+        # at every ``interval_exp_name`` iterations and the end of each epoch
+        if runner.meta is not None and 'exp_name' in runner.meta:
+            if (self.every_n_iters(runner, self.interval_exp_name)) or (
+                    self.by_epoch and self.end_of_epoch(runner)):
+                exp_info = f'Exp name: {runner.meta["exp_name"]}'
+                runner.logger.info(exp_info)
+        if log_dict['mode'] == 'train':
+            if isinstance(log_dict['lr'], dict):
+                lr_str = ' '.join(
+                    f'lr_{k}: {val:.3e}' for k, val in log_dict['lr'].items())
+            else:
+                lr_str = f'lr: {log_dict["lr"]:.3e}'
+            # by epoch: Epoch [4][100/1000]
+            # by iter:  Iter [100/100000]
+            if self.by_epoch:
+                log_str = f'Epoch [{log_dict["epoch"]}]' \
+                          f'[{log_dict["iter"]}/{len(runner.data_loader)}]\t'
+            else:
+                log_str = f'Iter [{log_dict["iter"]}/{runner.max_iters}]\t'
+            log_str += f'{lr_str}, '
+            if 'time' in log_dict:
+                # Accumulate elapsed time to estimate the remaining time.
+                self.time_sec_tot += (log_dict['time'] * self.interval)
+                time_sec_avg = self.time_sec_tot / (
+                    runner.iter - self.start_iter + 1)
+                eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
+                eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
+                log_str += f'eta: {eta_str}, '
+                log_str += f'time: {log_dict["time"]:.3f}, ' \
+                           f'data_time: {log_dict["data_time"]:.3f}, '
+                # statistic memory
+                if torch.cuda.is_available():
+                    log_str += f'memory: {log_dict["memory"]}, '
+        else:
+            # val/test time
+            # here 1000 is the length of the val dataloader
+            # by epoch: Epoch[val] [4][1000]
+            # by iter: Iter[val] [1000]
+            if self.by_epoch:
+                log_str = f'Epoch({log_dict["mode"]}) ' \
+                          f'[{log_dict["epoch"]}][{log_dict["iter"]}]\t'
+            else:
+                log_str = f'Iter({log_dict["mode"]}) [{log_dict["iter"]}]\t'
+        log_items = []
+        for name, val in log_dict.items():
+            # TODO: resolve this hack
+            # these items have been in log_str
+            if name in [
+                    'mode', 'Epoch', 'iter', 'lr', 'time', 'data_time',
+                    'memory', 'epoch'
+            ]:
+                continue
+            if isinstance(val, float):
+                val = f'{val:.4f}'
+            log_items.append(f'{name}: {val}')
+        log_str += ', '.join(log_items)
+        runner.logger.info(log_str)
+    def _dump_log(self, log_dict, runner):
+        """Append ``log_dict`` as one json line to the json log (rank 0)."""
+        # dump log in json format
+        json_log = OrderedDict()
+        for k, v in log_dict.items():
+            json_log[k] = self._round_float(v)
+        # only append log at last line
+        if runner.rank == 0:
+            with open(self.json_log_path, 'a+') as f:
+                mmcv.dump(json_log, f, file_format='json')
+                f.write('\n')
+    def _round_float(self, items):
+        """Round floats (also inside nested lists) to 5 decimal places.
+        Values of any other type are returned unchanged.
+        """
+        if isinstance(items, list):
+            return [self._round_float(item) for item in items]
+        elif isinstance(items, float):
+            return round(items, 5)
+        else:
+            return items
+    def log(self, runner):
+        """Collect the current log record, print it and dump it to json.
+        Returns:
+            dict: The record that was logged.
+        """
+        if 'eval_iter_num' in runner.log_buffer.output:
+            # this doesn't modify runner.iter and is regardless of by_epoch
+            cur_iter = runner.log_buffer.output.pop('eval_iter_num')
+        else:
+            cur_iter = self.get_iter(runner, inner_iter=True)
+        log_dict = OrderedDict(
+            mode=self.get_mode(runner),
+            epoch=self.get_epoch(runner),
+            iter=cur_iter)
+        # only record lr of the first param group
+        cur_lr = runner.current_lr()
+        if isinstance(cur_lr, list):
+            log_dict['lr'] = cur_lr[0]
+        else:
+            assert isinstance(cur_lr, dict)
+            log_dict['lr'] = {}
+            for k, lr_ in cur_lr.items():
+                assert isinstance(lr_, list)
+                log_dict['lr'].update({k: lr_[0]})
+        if 'time' in runner.log_buffer.output:
+            # statistic memory
+            if torch.cuda.is_available():
+                log_dict['memory'] = self._get_max_memory(runner)
+        log_dict = dict(log_dict, **runner.log_buffer.output)
+        self._log_info(log_dict, runner)
+        self._dump_log(log_dict, runner)
+        return log_dict
+    def after_run(self, runner):
+        """Copy matching log files to ``out_dir`` when it is configured."""
+        # copy or upload logs to self.out_dir
+        if self.out_dir is not None:
+            for filename in scandir(runner.work_dir, self.out_suffix, True):
+                local_filepath = osp.join(runner.work_dir, filename)
+                out_filepath = self.file_client.join_path(
+                    self.out_dir, filename)
+                with open(local_filepath, 'r') as f:
+                    self.file_client.put_text(f.read(), out_filepath)
+                runner.logger.info(
+                    (f'The file {local_filepath} has been uploaded to '
+                     f'{out_filepath}.'))
+                if not self.keep_local:
+                    os.remove(local_filepath)
+                    runner.logger.info(
+                        (f'{local_filepath} was removed due to the '
+                         '`self.keep_local=False`'))
diff --git a/ControlNet/annotator/uniformer/mmcv/runner/hooks/logger/wandb.py b/ControlNet/annotator/uniformer/mmcv/runner/hooks/logger/wandb.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f6808462eb79ab2b04806a5d9f0d3dd079b5ea9
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/runner/hooks/logger/wandb.py
@@ -0,0 +1,56 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ...dist_utils import master_only
+from ..hook import HOOKS
+from .base import LoggerHook
+class WandbLoggerHook(LoggerHook):
+    """Logger hook that sends loggable tags to Weights & Biases.
+    Args:
+        init_kwargs (dict, optional): Keyword arguments for ``wandb.init``.
+            Default: None.
+        interval (int): Forwarded to ``LoggerHook.__init__``. Default: 10.
+        ignore_last (bool): Forwarded to ``LoggerHook.__init__``.
+            Default: True.
+        reset_flag (bool): Forwarded to ``LoggerHook.__init__``.
+            Default: False.
+        commit (bool): Passed as ``commit`` to ``wandb.log``. Default: True.
+        by_epoch (bool): Forwarded to ``LoggerHook.__init__``.
+            Default: True.
+        with_step (bool): If True, the current iteration is passed to
+            ``wandb.log`` as ``step``; otherwise it is stored in the logged
+            dict under ``'global_step'``. Default: True.
+    """
+    def __init__(self,
+                 init_kwargs=None,
+                 interval=10,
+                 ignore_last=True,
+                 reset_flag=False,
+                 commit=True,
+                 by_epoch=True,
+                 with_step=True):
+        super().__init__(interval, ignore_last, reset_flag, by_epoch)
+        self.import_wandb()
+        self.init_kwargs = init_kwargs
+        self.commit = commit
+        self.with_step = with_step
+    def import_wandb(self):
+        """Import ``wandb`` and keep the module on this hook.
+        Raises:
+            ImportError: If ``wandb`` cannot be imported.
+        """
+        install_hint = 'Please run "pip install wandb" to install wandb'
+        try:
+            import wandb
+        except ImportError:
+            raise ImportError(install_hint)
+        self.wandb = wandb
+    @master_only
+    def before_run(self, runner):
+        """Make sure wandb is imported, then call ``wandb.init``."""
+        super().before_run(runner)
+        if self.wandb is None:
+            self.import_wandb()
+        # An empty or missing init_kwargs means a plain ``wandb.init()``.
+        init_args = self.init_kwargs if self.init_kwargs else {}
+        self.wandb.init(**init_args)
+    @master_only
+    def log(self, runner):
+        """Log the runner's loggable tags, if there are any."""
+        payload = self.get_loggable_tags(runner)
+        if not payload:
+            return
+        if self.with_step:
+            self.wandb.log(
+                payload, step=self.get_iter(runner), commit=self.commit)
+            return
+        payload['global_step'] = self.get_iter(runner)
+        self.wandb.log(payload, commit=self.commit)
+    @master_only
+    def after_run(self, runner):
+        """Call ``wandb.join()`` at the end of the run."""
+        self.wandb.join()
diff --git a/ControlNet/annotator/uniformer/mmcv/runner/hooks/lr_updater.py b/ControlNet/annotator/uniformer/mmcv/runner/hooks/lr_updater.py
new file mode 100644
index 0000000000000000000000000000000000000000..6365908ddf6070086de2ffc0afada46ed2f32256
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/runner/hooks/lr_updater.py
@@ -0,0 +1,670 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numbers
+from math import cos, pi
+import annotator.uniformer.mmcv as mmcv
+from .hook import HOOKS, Hook
+class LrUpdaterHook(Hook):
+    """LR Scheduler in MMCV.
+    Subclasses implement :meth:`get_lr`; this base class handles warmup and
+    writing the resulting values into the optimizer param groups.
+    Args:
+        by_epoch (bool): LR changes epoch by epoch
+        warmup (string): Type of warmup used. It can be None(use no warmup),
+            'constant', 'linear' or 'exp'
+        warmup_iters (int): The number of iterations or epochs that warmup
+            lasts
+        warmup_ratio (float): LR used at the beginning of warmup equals to
+            warmup_ratio * initial_lr
+        warmup_by_epoch (bool): When warmup_by_epoch == True, warmup_iters
+            means the number of epochs that warmup lasts, otherwise means the
+            number of iteration that warmup lasts
+    Raises:
+        ValueError: If ``warmup`` is not None, 'constant', 'linear' or 'exp'.
+    """
+    def __init__(self,
+                 by_epoch=True,
+                 warmup=None,
+                 warmup_iters=0,
+                 warmup_ratio=0.1,
+                 warmup_by_epoch=False):
+        # validate the "warmup" argument
+        if warmup is not None:
+            if warmup not in ['constant', 'linear', 'exp']:
+                # The message now lists "exp" too; it is an accepted value
+                # but was missing from the hint.
+                raise ValueError(
+                    f'"{warmup}" is not a supported type for warming up, valid'
+                    ' types are "constant", "linear" and "exp"')
+            assert warmup_iters > 0, \
+                '"warmup_iters" must be a positive integer'
+            assert 0 < warmup_ratio <= 1.0, \
+                '"warmup_ratio" must be in range (0,1]'
+        self.by_epoch = by_epoch
+        self.warmup = warmup
+        self.warmup_iters = warmup_iters
+        self.warmup_ratio = warmup_ratio
+        self.warmup_by_epoch = warmup_by_epoch
+        if self.warmup_by_epoch:
+            # The iteration count is derived in before_train_epoch, once
+            # the length of the data loader is known.
+            self.warmup_epochs = self.warmup_iters
+            self.warmup_iters = None
+        else:
+            self.warmup_epochs = None
+        self.base_lr = []  # initial lr for all param groups
+        self.regular_lr = []  # expected lr if no warming up is performed
+    def _set_lr(self, runner, lr_groups):
+        """Write ``lr_groups`` into the ``'lr'`` of each param group.
+        ``lr_groups`` is a dict keyed like ``runner.optimizer`` when the
+        optimizer is a dict, otherwise a flat sequence.
+        """
+        if isinstance(runner.optimizer, dict):
+            for k, optim in runner.optimizer.items():
+                for param_group, lr in zip(optim.param_groups, lr_groups[k]):
+                    param_group['lr'] = lr
+        else:
+            for param_group, lr in zip(runner.optimizer.param_groups,
+                                       lr_groups):
+                param_group['lr'] = lr
+    def get_lr(self, runner, base_lr):
+        """Return the scheduled LR for ``base_lr``; implemented by subclasses.
+        Raises:
+            NotImplementedError: Always, in this base class.
+        """
+        raise NotImplementedError
+    def get_regular_lr(self, runner):
+        """Apply :meth:`get_lr` to every base LR.
+        Returns:
+            dict | list: Same layout as ``self.base_lr``.
+        """
+        if isinstance(runner.optimizer, dict):
+            lr_groups = {}
+            for k in runner.optimizer.keys():
+                _lr_group = [
+                    self.get_lr(runner, _base_lr)
+                    for _base_lr in self.base_lr[k]
+                ]
+                lr_groups.update({k: _lr_group})
+            return lr_groups
+        else:
+            return [self.get_lr(runner, _base_lr) for _base_lr in self.base_lr]
+    def get_warmup_lr(self, cur_iters):
+        """Scale ``self.regular_lr`` according to the warmup type.
+        Returns:
+            dict | list: Same layout as ``self.regular_lr``.
+        """
+        def _get_warmup_lr(cur_iters, regular_lr):
+            if self.warmup == 'constant':
+                warmup_lr = [_lr * self.warmup_ratio for _lr in regular_lr]
+            elif self.warmup == 'linear':
+                k = (1 - cur_iters / self.warmup_iters) * (
+                    1 - self.warmup_ratio)
+                warmup_lr = [_lr * (1 - k) for _lr in regular_lr]
+            elif self.warmup == 'exp':
+                k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters)
+                warmup_lr = [_lr * k for _lr in regular_lr]
+            return warmup_lr
+        if isinstance(self.regular_lr, dict):
+            lr_groups = {}
+            for key, regular_lr in self.regular_lr.items():
+                lr_groups[key] = _get_warmup_lr(cur_iters, regular_lr)
+            return lr_groups
+        else:
+            return _get_warmup_lr(cur_iters, self.regular_lr)
+    def before_run(self, runner):
+        """Record the initial LR of every param group as ``self.base_lr``."""
+        # NOTE: when resuming from a checkpoint, if 'initial_lr' is not saved,
+        # it will be set according to the optimizer params
+        if isinstance(runner.optimizer, dict):
+            self.base_lr = {}
+            for k, optim in runner.optimizer.items():
+                for group in optim.param_groups:
+                    group.setdefault('initial_lr', group['lr'])
+                _base_lr = [
+                    group['initial_lr'] for group in optim.param_groups
+                ]
+                self.base_lr.update({k: _base_lr})
+        else:
+            for group in runner.optimizer.param_groups:
+                group.setdefault('initial_lr', group['lr'])
+            self.base_lr = [
+                group['initial_lr'] for group in runner.optimizer.param_groups
+            ]
+    def before_train_epoch(self, runner):
+        """Resolve epoch-based warmup length and set the per-epoch LR."""
+        if self.warmup_iters is None:
+            epoch_len = len(runner.data_loader)
+            self.warmup_iters = self.warmup_epochs * epoch_len
+        if not self.by_epoch:
+            return
+        self.regular_lr = self.get_regular_lr(runner)
+        self._set_lr(runner, self.regular_lr)
+    def before_train_iter(self, runner):
+        """Set the LR for this iteration, applying warmup while it lasts."""
+        cur_iter = runner.iter
+        if not self.by_epoch:
+            self.regular_lr = self.get_regular_lr(runner)
+            if self.warmup is None or cur_iter >= self.warmup_iters:
+                self._set_lr(runner, self.regular_lr)
+            else:
+                warmup_lr = self.get_warmup_lr(cur_iter)
+                self._set_lr(runner, warmup_lr)
+        elif self.by_epoch:
+            if self.warmup is None or cur_iter > self.warmup_iters:
+                return
+            elif cur_iter == self.warmup_iters:
+                # Warmup has just ended: restore the regular LR once.
+                self._set_lr(runner, self.regular_lr)
+            else:
+                warmup_lr = self.get_warmup_lr(cur_iter)
+                self._set_lr(runner, warmup_lr)
+class FixedLrUpdaterHook(LrUpdaterHook):
+    """LR schedule whose :meth:`get_lr` returns the base LR unchanged.
+    All keyword arguments are forwarded to ``LrUpdaterHook``.
+    """
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+    def get_lr(self, runner, base_lr):
+        """Return ``base_lr`` as is; ``runner`` is not used."""
+        return base_lr
+class StepLrUpdaterHook(LrUpdaterHook):
+    """Step LR scheduler with min_lr clipping.
+    Args:
+        step (int | list[int]): Step to decay the LR. If an int value is given,
+            regard it as the decay interval. If a list is given, decay LR at
+            these steps.
+        gamma (float, optional): Decay LR ratio. Default: 0.1.
+        min_lr (float, optional): Minimum LR value to keep. If LR after decay
+            is lower than `min_lr`, it will be clipped to this value. If None
+            is given, we don't perform lr clipping. Default: None.
+    Raises:
+        TypeError: If ``step`` is neither a list nor an int.
+    """
+    def __init__(self, step, gamma=0.1, min_lr=None, **kwargs):
+        if not isinstance(step, (list, int)):
+            raise TypeError('"step" must be a list or integer')
+        if isinstance(step, list):
+            assert mmcv.is_list_of(step, int)
+            assert all(milestone > 0 for milestone in step)
+        else:
+            assert step > 0
+        self.step = step
+        self.gamma = gamma
+        self.min_lr = min_lr
+        super().__init__(**kwargs)
+    def get_lr(self, runner, base_lr):
+        """Return ``base_lr * gamma**exp``, clipped from below by ``min_lr``.
+        ``exp`` is ``progress // step`` for an int step, otherwise the index
+        of the first listed step that is greater than the progress (or the
+        list length when there is none).
+        """
+        progress = runner.epoch if self.by_epoch else runner.iter
+        # calculate exponential term
+        if isinstance(self.step, int):
+            exp = progress // self.step
+        else:
+            exp = next(
+                (idx for idx, milestone in enumerate(self.step)
+                 if progress < milestone), len(self.step))
+        lr = base_lr * (self.gamma**exp)
+        if self.min_lr is None:
+            return lr
+        # clip to a minimum value
+        return max(lr, self.min_lr)
+class ExpLrUpdaterHook(LrUpdaterHook):
+    """Exponential LR schedule: ``base_lr * gamma**progress``.
+    Args:
+        gamma (float): Base of the exponential decay.
+    """
+    def __init__(self, gamma, **kwargs):
+        self.gamma = gamma
+        super().__init__(**kwargs)
+    def get_lr(self, runner, base_lr):
+        """Return the decayed LR for the current epoch or iteration."""
+        if self.by_epoch:
+            progress = runner.epoch
+        else:
+            progress = runner.iter
+        decay = self.gamma**progress
+        return base_lr * decay
+class PolyLrUpdaterHook(LrUpdaterHook):
+    """Polynomial LR schedule decaying from the base LR towards ``min_lr``.
+    Args:
+        power (float): Exponent applied to ``1 - progress / max_progress``.
+            Default: 1.
+        min_lr (float): LR reached when the coefficient is 0. Default: 0.
+    """
+    def __init__(self, power=1., min_lr=0., **kwargs):
+        self.power = power
+        self.min_lr = min_lr
+        super().__init__(**kwargs)
+    def get_lr(self, runner, base_lr):
+        """Return ``(base_lr - min_lr) * coeff + min_lr``."""
+        progress, max_progress = (
+            (runner.epoch, runner.max_epochs) if self.by_epoch else
+            (runner.iter, runner.max_iters))
+        coeff = (1 - progress / max_progress)**self.power
+        span = base_lr - self.min_lr
+        return span * coeff + self.min_lr
+class InvLrUpdaterHook(LrUpdaterHook):
+    """Inverse LR schedule: ``base_lr * (1 + gamma * progress)**(-power)``.
+    Args:
+        gamma (float): Factor multiplied with the progress.
+        power (float): Exponent of the decay. Default: 1.
+    """
+    def __init__(self, gamma, power=1., **kwargs):
+        self.gamma = gamma
+        self.power = power
+        super().__init__(**kwargs)
+    def get_lr(self, runner, base_lr):
+        """Return the decayed LR for the current epoch or iteration."""
+        if self.by_epoch:
+            progress = runner.epoch
+        else:
+            progress = runner.iter
+        scale = (1 + self.gamma * progress)**(-self.power)
+        return base_lr * scale
+class CosineAnnealingLrUpdaterHook(LrUpdaterHook):
+    """Anneal the LR from the base LR to a target LR with ``annealing_cos``.
+    Args:
+        min_lr (float, optional): The target LR. Default: None.
+        min_lr_ratio (float, optional): Ratio of the target LR to the base
+            LR. Exactly one of `min_lr` and `min_lr_ratio` must be given.
+            Default: None.
+    """
+    def __init__(self, min_lr=None, min_lr_ratio=None, **kwargs):
+        assert (min_lr is None) ^ (min_lr_ratio is None)
+        self.min_lr = min_lr
+        self.min_lr_ratio = min_lr_ratio
+        super().__init__(**kwargs)
+    def get_lr(self, runner, base_lr):
+        """Return the annealed LR at ``progress / max_progress``."""
+        progress, max_progress = (
+            (runner.epoch, runner.max_epochs) if self.by_epoch else
+            (runner.iter, runner.max_iters))
+        target_lr = (
+            base_lr * self.min_lr_ratio
+            if self.min_lr_ratio is not None else self.min_lr)
+        return annealing_cos(base_lr, target_lr, progress / max_progress)
+class FlatCosineAnnealingLrUpdaterHook(LrUpdaterHook):
+    """Flat + Cosine lr schedule.
+    The base LR is kept until ``start_percent`` of training has passed and
+    is annealed with ``annealing_cos`` afterwards.
+    Modified from https://github.com/fastai/fastai/blob/master/fastai/callback/schedule.py#L128 # noqa: E501
+    Args:
+        start_percent (float): When to start annealing the learning rate
+            after the percentage of the total training steps.
+            Must be a float; values below 0 or above 1 are rejected.
+            Default: 0.75
+        min_lr (float, optional): The minimum lr. Default: None.
+        min_lr_ratio (float, optional): The ratio of minimum lr to the base lr.
+            Either `min_lr` or `min_lr_ratio` should be specified.
+            Default: None.
+    Raises:
+        ValueError: If ``start_percent`` fails the check described above.
+    """
+    def __init__(self,
+                 start_percent=0.75,
+                 min_lr=None,
+                 min_lr_ratio=None,
+                 **kwargs):
+        assert (min_lr is None) ^ (min_lr_ratio is None)
+        out_of_range = start_percent < 0 or start_percent > 1
+        if out_of_range or not isinstance(start_percent, float):
+            raise ValueError(
+                'expected float between 0 and 1 start_percent, but '
+                f'got {start_percent}')
+        self.start_percent = start_percent
+        self.min_lr = min_lr
+        self.min_lr_ratio = min_lr_ratio
+        super().__init__(**kwargs)
+    def get_lr(self, runner, base_lr):
+        """Return ``base_lr`` before the start point, annealed LR after."""
+        if self.by_epoch:
+            total, current = runner.max_epochs, runner.epoch
+        else:
+            total, current = runner.max_iters, runner.iter
+        # Everything is measured relative to the point where annealing
+        # starts.
+        start = round(total * self.start_percent)
+        progress = current - start
+        max_progress = total - start
+        target_lr = (
+            base_lr * self.min_lr_ratio
+            if self.min_lr_ratio is not None else self.min_lr)
+        if progress < 0:
+            return base_lr
+        return annealing_cos(base_lr, target_lr, progress / max_progress)
+class CosineRestartLrUpdaterHook(LrUpdaterHook):
+    """Cosine annealing with restarts learning rate scheme.
+    Args:
+        periods (list[int]): Periods for each cosine annealing cycle.
+        restart_weights (list[float], optional): Restart weights at each
+            restart iteration. Must have the same length as `periods`.
+            Default: [1].
+        min_lr (float, optional): The minimum lr. Default: None.
+        min_lr_ratio (float, optional): The ratio of minimum lr to the base lr.
+            Either `min_lr` or `min_lr_ratio` should be specified.
+            Default: None.
+    """
+    def __init__(self,
+                 periods,
+                 restart_weights=[1],
+                 min_lr=None,
+                 min_lr_ratio=None,
+                 **kwargs):
+        assert (min_lr is None) ^ (min_lr_ratio is None)
+        self.periods = periods
+        self.min_lr = min_lr
+        self.min_lr_ratio = min_lr_ratio
+        self.restart_weights = restart_weights
+        assert (len(self.periods) == len(self.restart_weights)
+                ), 'periods and restart_weights should have the same length.'
+        super().__init__(**kwargs)
+        # Running totals of the periods: entry i is where cycle i ends.
+        running_total = 0
+        boundaries = []
+        for period in self.periods:
+            running_total = running_total + period
+            boundaries.append(running_total)
+        self.cumulative_periods = boundaries
+    def get_lr(self, runner, base_lr):
+        """Return the annealed LR inside the cycle containing the progress."""
+        progress = runner.epoch if self.by_epoch else runner.iter
+        target_lr = (
+            base_lr * self.min_lr_ratio
+            if self.min_lr_ratio is not None else self.min_lr)
+        idx = get_position_from_periods(progress, self.cumulative_periods)
+        current_weight = self.restart_weights[idx]
+        if idx == 0:
+            nearest_restart = 0
+        else:
+            nearest_restart = self.cumulative_periods[idx - 1]
+        current_periods = self.periods[idx]
+        elapsed = progress - nearest_restart
+        alpha = min(elapsed / current_periods, 1)
+        return annealing_cos(base_lr, target_lr, alpha, current_weight)
+def get_position_from_periods(iteration, cumulative_periods):
+    """Get the position from a period list.
+    Returns the index of the first entry of ``cumulative_periods`` that is
+    strictly greater than ``iteration``.
+    For example, the cumulative_periods = [100, 200, 300, 400],
+    if iteration == 50, return 0;
+    if iteration == 210, return 2;
+    if iteration == 300, return 3.
+    Args:
+        iteration (int): Current iteration.
+        cumulative_periods (list[int]): Cumulative period list.
+    Returns:
+        int: The position of the right-closest number in the period list.
+    Raises:
+        ValueError: If no entry is greater than ``iteration``.
+    """
+    position = next(
+        (idx for idx, boundary in enumerate(cumulative_periods)
+         if iteration < boundary), None)
+    if position is None:
+        raise ValueError(f'Current iteration {iteration} exceeds '
+                         f'cumulative_periods {cumulative_periods}')
+    return position
+class CyclicLrUpdaterHook(LrUpdaterHook):
+    """Cyclic LR Scheduler.
+    Implement the cyclical learning rate policy (CLR) described in
+    https://arxiv.org/pdf/1506.01186.pdf
+    Different from the original paper, we use cosine annealing rather than
+    triangular policy inside a cycle. This improves the performance in the
+    3D detection area.
+    Args:
+        by_epoch (bool): Whether to update LR by epoch. Only False is
+            accepted.
+        target_ratio (tuple[float]): Relative ratio of the highest LR and the
+            lowest LR to the initial LR. A single float, or a 1-tuple, is
+            expanded to ``(ratio, ratio / 1e5)``.
+        cyclic_times (int): Number of cycles during training
+        step_ratio_up (float): The ratio of the increasing process of LR in
+            the total cycle.
+        anneal_strategy (str): {'cos', 'linear'}
+            Specifies the annealing strategy: 'cos' for cosine annealing,
+            'linear' for linear annealing. Default: 'cos'.
+    Raises:
+        ValueError: If ``target_ratio`` is neither a float nor a tuple, or
+            ``anneal_strategy`` is not 'cos' or 'linear'.
+    """
+    def __init__(self,
+                 by_epoch=False,
+                 target_ratio=(10, 1e-4),
+                 cyclic_times=1,
+                 step_ratio_up=0.4,
+                 anneal_strategy='cos',
+                 **kwargs):
+        if isinstance(target_ratio, float):
+            target_ratio = (target_ratio, target_ratio / 1e5)
+        elif isinstance(target_ratio, tuple):
+            if len(target_ratio) == 1:
+                target_ratio = (target_ratio[0], target_ratio[0] / 1e5)
+        else:
+            raise ValueError('target_ratio should be either float '
+                             f'or tuple, got {type(target_ratio)}')
+        assert len(target_ratio) == 2, \
+            '"target_ratio" must be list or tuple of two floats'
+        assert 0 <= step_ratio_up < 1.0, \
+            '"step_ratio_up" must be in range [0,1)'
+        self.target_ratio = target_ratio
+        self.cyclic_times = cyclic_times
+        self.step_ratio_up = step_ratio_up
+        self.lr_phases = []  # init lr_phases
+        # validate anneal_strategy
+        if anneal_strategy not in ['cos', 'linear']:
+            raise ValueError('anneal_strategy must be one of "cos" or '
+                             f'"linear", instead got {anneal_strategy}')
+        if anneal_strategy == 'cos':
+            self.anneal_func = annealing_cos
+        else:
+            self.anneal_func = annealing_linear
+        assert not by_epoch, \
+            'currently only support "by_epoch" = False'
+        super().__init__(by_epoch, **kwargs)
+    def before_run(self, runner):
+        """Append the up and down phases of one cycle to ``lr_phases``.
+        Each phase is ``[start_iter, end_iter, iters_per_cycle,
+        start_ratio, end_ratio]``.
+        """
+        super().before_run(runner)
+        # initiate lr_phases
+        # total lr_phases are separated as up and down
+        max_iter_per_phase = runner.max_iters // self.cyclic_times
+        iter_up_phase = int(self.step_ratio_up * max_iter_per_phase)
+        peak_ratio, final_ratio = self.target_ratio[0], self.target_ratio[1]
+        up_phase = [0, iter_up_phase, max_iter_per_phase, 1, peak_ratio]
+        down_phase = [
+            iter_up_phase, max_iter_per_phase, max_iter_per_phase,
+            peak_ratio, final_ratio
+        ]
+        self.lr_phases.append(up_phase)
+        self.lr_phases.append(down_phase)
+    def get_lr(self, runner, base_lr):
+        """Return the annealed LR of the phase containing ``runner.iter``.
+        Returns None implicitly when no phase matches.
+        """
+        curr_iter = runner.iter
+        for phase in self.lr_phases:
+            (start_iter, end_iter, max_iter_per_phase, start_ratio,
+             end_ratio) = phase
+            # Position inside the current cycle.
+            curr_iter %= max_iter_per_phase
+            if not (start_iter <= curr_iter < end_iter):
+                continue
+            progress = curr_iter - start_iter
+            return self.anneal_func(base_lr * start_ratio,
+                                    base_lr * end_ratio,
+                                    progress / (end_iter - start_iter))
+class OneCycleLrUpdaterHook(LrUpdaterHook):
+    """One Cycle LR Scheduler.
+
+    The 1cycle learning rate policy changes the learning rate after every
+    batch. The one cycle learning rate policy is described in
+    https://arxiv.org/pdf/1708.07120.pdf
+
+    Args:
+        max_lr (float or list or dict): Upper learning rate boundaries in the
+            cycle for each parameter group (a dict maps optimizer names to
+            such values).
+        total_steps (int, optional): The total number of steps in the cycle.
+            Note that if a value is not provided here, it will be the
+            max_iter of runner. Default: None.
+        pct_start (float): The percentage of the cycle (in number of steps)
+            spent increasing the learning rate.
+            Default: 0.3
+        anneal_strategy (str): {'cos', 'linear'}
+            Specifies the annealing strategy: 'cos' for cosine annealing,
+            'linear' for linear annealing.
+            Default: 'cos'
+        div_factor (float): Determines the initial learning rate via
+            initial_lr = max_lr/div_factor
+            Default: 25
+        final_div_factor (float): Determines the minimum learning rate via
+            min_lr = initial_lr/final_div_factor
+            Default: 1e4
+        three_phase (bool): If three_phase is True, use a third phase of the
+            schedule to annihilate the learning rate according to
+            final_div_factor instead of modifying the second phase (the first
+            two phases will be symmetrical about the step indicated by
+            pct_start).
+            Default: False
+
+    Raises:
+        ValueError: If ``max_lr``, ``total_steps``, ``pct_start`` or
+            ``anneal_strategy`` has an unsupported type or value.
+    """
+
+    def __init__(self,
+                 max_lr,
+                 total_steps=None,
+                 pct_start=0.3,
+                 anneal_strategy='cos',
+                 div_factor=25,
+                 final_div_factor=1e4,
+                 three_phase=False,
+                 **kwargs):
+        # validate by_epoch, currently only support by_epoch = False
+        if 'by_epoch' not in kwargs:
+            kwargs['by_epoch'] = False
+        else:
+            assert not kwargs['by_epoch'], \
+                'currently only support "by_epoch" = False'
+        if not isinstance(max_lr, (numbers.Number, list, dict)):
+            raise ValueError('the type of max_lr must be the one of number, '
+                             f'list or dict, but got {type(max_lr)}')
+        self._max_lr = max_lr
+        # total_steps is only stored when given; before_run falls back to
+        # runner.max_iters when the attribute is absent.
+        if total_steps is not None:
+            if not isinstance(total_steps, int):
+                raise ValueError('the type of total_steps must be int, but '
+                                 f'got {type(total_steps)}')
+            self.total_steps = total_steps
+        # validate pct_start; the type is checked first so that a
+        # non-numeric value yields this ValueError rather than a TypeError
+        # from the comparison
+        if (not isinstance(pct_start, float) or pct_start < 0
+                or pct_start > 1):
+            raise ValueError('expected float between 0 and 1 pct_start, but '
+                             f'got {pct_start}')
+        self.pct_start = pct_start
+        # validate anneal_strategy
+        if anneal_strategy not in ['cos', 'linear']:
+            raise ValueError('anneal_strategy must be one of "cos" or '
+                             f'"linear", instead got {anneal_strategy}')
+        elif anneal_strategy == 'cos':
+            self.anneal_func = annealing_cos
+        elif anneal_strategy == 'linear':
+            self.anneal_func = annealing_linear
+        self.div_factor = div_factor
+        self.final_div_factor = final_div_factor
+        self.three_phase = three_phase
+        self.lr_phases = []  # filled in before_run
+        super(OneCycleLrUpdaterHook, self).__init__(**kwargs)
+
+    def _init_group_lrs(self, name, optim):
+        """Compute ``max_lr / div_factor`` per param group of ``optim``.
+
+        Each value is also stored as the group's ``initial_lr`` unless the
+        group already has one.
+        """
+        group_lrs = [
+            lr / self.div_factor
+            for lr in format_param(name, optim, self._max_lr)
+        ]
+        for group, lr in zip(optim.param_groups, group_lrs):
+            group.setdefault('initial_lr', lr)
+        return group_lrs
+
+    def before_run(self, runner):
+        """Derive the base LRs and build the schedule phases.
+
+        Raises:
+            ValueError: If ``total_steps`` is smaller than
+                ``runner.max_iters``.
+        """
+        if hasattr(self, 'total_steps'):
+            total_steps = self.total_steps
+        else:
+            total_steps = runner.max_iters
+        if total_steps < runner.max_iters:
+            raise ValueError(
+                'The total steps must be greater than or equal to max '
+                f'iterations {runner.max_iters} of runner, but total steps '
+                f'is {total_steps}.')
+
+        if isinstance(runner.optimizer, dict):
+            self.base_lr = {}
+            for k, optim in runner.optimizer.items():
+                self.base_lr[k] = self._init_group_lrs(k, optim)
+        else:
+            k = type(runner.optimizer).__name__
+            self.base_lr = self._init_group_lrs(k, runner.optimizer)
+
+        # Rebuild instead of appending so that calling before_run more than
+        # once does not accumulate duplicate phases.
+        # Each phase: [end_iter, start_ratio, end_ratio] relative to base_lr.
+        warmup_end = float(self.pct_start * total_steps) - 1
+        if self.three_phase:
+            self.lr_phases = [
+                [warmup_end, 1, self.div_factor],
+                [
+                    float(2 * self.pct_start * total_steps) - 2,
+                    self.div_factor, 1
+                ],
+                [total_steps - 1, 1, 1 / self.final_div_factor],
+            ]
+        else:
+            self.lr_phases = [
+                [warmup_end, 1, self.div_factor],
+                [total_steps - 1, self.div_factor, 1 / self.final_div_factor],
+            ]
+
+    def get_lr(self, runner, base_lr):
+        """Return the annealed LR for ``runner.iter``.
+
+        Args:
+            runner: Object exposing the current iteration as ``iter``.
+            base_lr (float): LR that the phase ratios are multiplied with.
+
+        Returns:
+            float: LR interpolated inside the phase containing the iteration.
+        """
+        curr_iter = runner.iter
+        start_iter = 0
+        last_phase = len(self.lr_phases) - 1
+        for i, (end_iter, start_lr, end_lr) in enumerate(self.lr_phases):
+            # The last phase is used as a fallback (same rule as
+            # OneCycleMomentumUpdaterHook.get_momentum) so that ``lr`` is
+            # always bound even if the iteration lies past every end_iter.
+            if curr_iter <= end_iter or i == last_phase:
+                pct = (curr_iter - start_iter) / (end_iter - start_iter)
+                lr = self.anneal_func(base_lr * start_lr, base_lr * end_lr,
+                                      pct)
+                break
+            start_iter = end_iter
+        return lr
+def annealing_cos(start, end, factor, weight=1):
+    """Compute a cosine-annealed value between two endpoints.
+
+    The result moves along half a cosine period from
+    `weight * start + (1 - weight) * end` (at ``factor == 0.0``) to `end`
+    (at ``factor == 1.0``).
+
+    Args:
+        start (float): The starting learning rate of the cosine annealing.
+        end (float): The ending learning rate of the cosine annealing.
+        factor (float): The coefficient of `pi` when calculating the current
+            percentage. Range from 0.0 to 1.0.
+        weight (float, optional): The combination factor of `start` and `end`
+            when calculating the actual starting learning rate. Default to 1.
+
+    Returns:
+        float: The annealed value.
+    """
+    half_span = 0.5 * weight * (start - end)
+    return end + half_span * (cos(pi * factor) + 1)
+def annealing_linear(start, end, factor):
+    """Compute a linearly interpolated value between two endpoints.
+
+    The result is `start` at ``factor == 0.0`` and `end` at
+    ``factor == 1.0``.
+
+    Args:
+        start (float): The starting learning rate of the linear annealing.
+        end (float): The ending learning rate of the linear annealing.
+        factor (float): Fraction of the way from `start` to `end`.
+            Range from 0.0 to 1.0.
+
+    Returns:
+        float: The interpolated value.
+    """
+    span = end - start
+    return start + span * factor
+def format_param(name, optim, param):
+    """Expand ``param`` into one value per param group of ``optim``.
+
+    Args:
+        name (str): Key used to look ``param`` up when it is a mapping; also
+            used in error messages.
+        optim: Object exposing ``param_groups``.
+        param: A single number (repeated for every group), a list/tuple with
+            one entry per group (returned as is), or a mapping from
+            optimizer name to such a value.
+
+    Returns:
+        The per-group values for the optimizer called ``name``.
+
+    Raises:
+        ValueError: If a list/tuple does not have one entry per param group.
+        KeyError: If ``param`` is a mapping without an entry for ``name``.
+    """
+    if isinstance(param, numbers.Number):
+        return [param] * len(optim.param_groups)
+
+    if isinstance(param, (list, tuple)):  # multi param groups
+        expected = len(optim.param_groups)
+        if len(param) == expected:
+            return param
+        raise ValueError(f'expected {expected} '
+                         f'values for {name}, got {len(param)}')
+
+    # multi optimizers
+    if name in param:
+        return param[name]
+    raise KeyError(f'{name} is not found in {param.keys()}')
diff --git a/ControlNet/annotator/uniformer/mmcv/runner/hooks/memory.py b/ControlNet/annotator/uniformer/mmcv/runner/hooks/memory.py
new file mode 100644
index 0000000000000000000000000000000000000000..70cf9a838fb314e3bd3c07aadbc00921a81e83ed
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/runner/hooks/memory.py
@@ -0,0 +1,25 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from .hook import HOOKS, Hook
+class EmptyCacheHook(Hook):
+    """Hook that calls ``torch.cuda.empty_cache()`` at selected stages.
+
+    Args:
+        before_epoch (bool): Empty the cache in ``before_epoch``.
+            Default: False.
+        after_epoch (bool): Empty the cache in ``after_epoch``.
+            Default: True.
+        after_iter (bool): Empty the cache in ``after_iter``.
+            Default: False.
+    """
+
+    def __init__(self, before_epoch=False, after_epoch=True, after_iter=False):
+        self._before_epoch, self._after_epoch, self._after_iter = (
+            before_epoch, after_epoch, after_iter)
+
+    def after_iter(self, runner):
+        """Empty the CUDA cache after an iteration, if enabled."""
+        if not self._after_iter:
+            return
+        torch.cuda.empty_cache()
+
+    def before_epoch(self, runner):
+        """Empty the CUDA cache before an epoch, if enabled."""
+        if not self._before_epoch:
+            return
+        torch.cuda.empty_cache()
+
+    def after_epoch(self, runner):
+        """Empty the CUDA cache after an epoch, if enabled."""
+        if not self._after_epoch:
+            return
+        torch.cuda.empty_cache()
diff --git a/ControlNet/annotator/uniformer/mmcv/runner/hooks/momentum_updater.py b/ControlNet/annotator/uniformer/mmcv/runner/hooks/momentum_updater.py
new file mode 100644
index 0000000000000000000000000000000000000000..60437756ceedf06055ec349df69a25465738d3f0
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/runner/hooks/momentum_updater.py
@@ -0,0 +1,493 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import annotator.uniformer.mmcv as mmcv
+from .hook import HOOKS, Hook
+from .lr_updater import annealing_cos, annealing_linear, format_param
+class MomentumUpdaterHook(Hook):
+    """Base hook that schedules the momentum of the runner's optimizer(s).
+
+    Subclasses implement :meth:`get_momentum`. The hook writes the result
+    into each param group's ``momentum`` entry, or into the first element of
+    ``betas`` when the group has no ``momentum`` key.
+
+    Args:
+        by_epoch (bool): Recompute the regular momentum once per epoch
+            (True) or at every iteration (False). Default: True.
+        warmup (str, optional): Warmup type, one of 'constant', 'linear' or
+            'exp'. ``None`` disables warmup. Default: None.
+        warmup_iters (int): Number of iterations the warmup lasts. Must be
+            positive when ``warmup`` is set. Default: 0.
+        warmup_ratio (float): Ratio in ``(0, 1]`` the regular momentum is
+            divided by during warmup. Default: 0.9.
+
+    Raises:
+        ValueError: If ``warmup`` is not a supported type.
+    """
+
+    def __init__(self,
+                 by_epoch=True,
+                 warmup=None,
+                 warmup_iters=0,
+                 warmup_ratio=0.9):
+        # validate the "warmup" argument
+        if warmup is not None:
+            if warmup not in ['constant', 'linear', 'exp']:
+                raise ValueError(
+                    f'"{warmup}" is not a supported type for warming up, valid'
+                    ' types are "constant", "linear" and "exp"')
+            assert warmup_iters > 0, \
+                '"warmup_iters" must be a positive integer'
+            assert 0 < warmup_ratio <= 1.0, \
+                '"warmup_ratio" must be in range (0,1]'
+
+        self.by_epoch = by_epoch
+        self.warmup = warmup
+        self.warmup_iters = warmup_iters
+        self.warmup_ratio = warmup_ratio
+
+        self.base_momentum = []  # initial momentum for all param groups
+        # expected momentum if no warming up is performed
+        self.regular_momentum = []
+        # legacy alias, kept for code that still reads ``regular_mom``
+        self.regular_mom = self.regular_momentum
+
+    @staticmethod
+    def _write_momentum(param_groups, momentums):
+        """Write one momentum value into each param group."""
+        for param_group, mom in zip(param_groups, momentums):
+            if 'momentum' in param_group.keys():
+                param_group['momentum'] = mom
+            elif 'betas' in param_group.keys():
+                param_group['betas'] = (mom, param_group['betas'][1])
+
+    @staticmethod
+    def _read_initial_momentum(optim):
+        """Return the ``initial_momentum`` of every group of ``optim``.
+
+        Groups without that key get it from ``momentum`` or ``betas[0]``.
+        """
+        for group in optim.param_groups:
+            if 'momentum' in group.keys():
+                group.setdefault('initial_momentum', group['momentum'])
+            else:
+                group.setdefault('initial_momentum', group['betas'][0])
+        return [group['initial_momentum'] for group in optim.param_groups]
+
+    def _refresh_regular_momentum(self, runner):
+        """Recompute the regular momentum and keep both names in sync."""
+        self.regular_momentum = self.get_regular_momentum(runner)
+        self.regular_mom = self.regular_momentum
+
+    def _set_momentum(self, runner, momentum_groups):
+        """Apply ``momentum_groups`` to the runner's optimizer(s).
+
+        ``momentum_groups`` is a list for a single optimizer, or a dict of
+        lists keyed like ``runner.optimizer`` when that is a dict.
+        """
+        if isinstance(runner.optimizer, dict):
+            for k, optim in runner.optimizer.items():
+                self._write_momentum(optim.param_groups, momentum_groups[k])
+        else:
+            self._write_momentum(runner.optimizer.param_groups,
+                                 momentum_groups)
+
+    def get_momentum(self, runner, base_momentum):
+        """Return the scheduled momentum; must be implemented by subclasses."""
+        raise NotImplementedError
+
+    def get_regular_momentum(self, runner):
+        """Evaluate :meth:`get_momentum` for every base momentum.
+
+        Returns:
+            list | dict: Same layout as ``self.base_momentum``.
+        """
+        if isinstance(runner.optimizer, dict):
+            momentum_groups = {}
+            for k in runner.optimizer.keys():
+                _momentum_group = [
+                    self.get_momentum(runner, _base_momentum)
+                    for _base_momentum in self.base_momentum[k]
+                ]
+                momentum_groups.update({k: _momentum_group})
+            return momentum_groups
+        else:
+            return [
+                self.get_momentum(runner, _base_momentum)
+                for _base_momentum in self.base_momentum
+            ]
+
+    def get_warmup_momentum(self, cur_iters):
+        """Return the warmup momentum derived from the regular momentum.
+
+        Args:
+            cur_iters (int): Current iteration.
+
+        Returns:
+            list | dict: Same layout as ``self.regular_momentum``.
+        """
+
+        def _get_warmup_momentum(cur_iters, regular_momentum):
+            # Use the list that was passed in. The previous code read
+            # attributes of ``self`` here, which made 'constant' warmup
+            # iterate an always-empty list and broke dict optimizers.
+            if self.warmup == 'constant':
+                warmup_momentum = [
+                    _momentum / self.warmup_ratio
+                    for _momentum in regular_momentum
+                ]
+            elif self.warmup == 'linear':
+                k = (1 - cur_iters / self.warmup_iters) * (1 -
+                                                           self.warmup_ratio)
+                warmup_momentum = [
+                    _momentum / (1 - k) for _momentum in regular_momentum
+                ]
+            elif self.warmup == 'exp':
+                k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters)
+                warmup_momentum = [
+                    _momentum / k for _momentum in regular_momentum
+                ]
+            return warmup_momentum
+
+        if isinstance(self.regular_momentum, dict):
+            momentum_groups = {}
+            for key, regular_momentum in self.regular_momentum.items():
+                momentum_groups[key] = _get_warmup_momentum(
+                    cur_iters, regular_momentum)
+            return momentum_groups
+        else:
+            return _get_warmup_momentum(cur_iters, self.regular_momentum)
+
+    def before_run(self, runner):
+        """Record the initial momentum of every param group."""
+        # NOTE: when resuming from a checkpoint,
+        # if 'initial_momentum' is not saved,
+        # it will be set according to the optimizer params
+        if isinstance(runner.optimizer, dict):
+            self.base_momentum = {}
+            for k, optim in runner.optimizer.items():
+                self.base_momentum.update(
+                    {k: self._read_initial_momentum(optim)})
+        else:
+            self.base_momentum = self._read_initial_momentum(
+                runner.optimizer)
+
+    def before_train_epoch(self, runner):
+        """In by-epoch mode, recompute and apply the regular momentum."""
+        if not self.by_epoch:
+            return
+        self._refresh_regular_momentum(runner)
+        self._set_momentum(runner, self.regular_momentum)
+
+    def before_train_iter(self, runner):
+        """Apply the regular or warmup momentum for this iteration."""
+        cur_iter = runner.iter
+        if not self.by_epoch:
+            self._refresh_regular_momentum(runner)
+            if self.warmup is None or cur_iter >= self.warmup_iters:
+                self._set_momentum(runner, self.regular_momentum)
+            else:
+                warmup_momentum = self.get_warmup_momentum(cur_iter)
+                self._set_momentum(runner, warmup_momentum)
+        else:
+            # by-epoch: the regular value was applied in before_train_epoch,
+            # so only the warmup window needs handling here
+            if self.warmup is None or cur_iter > self.warmup_iters:
+                return
+            elif cur_iter == self.warmup_iters:
+                self._set_momentum(runner, self.regular_momentum)
+            else:
+                warmup_momentum = self.get_warmup_momentum(cur_iter)
+                self._set_momentum(runner, warmup_momentum)
+class StepMomentumUpdaterHook(MomentumUpdaterHook):
+    """Step momentum scheduler with min value clipping.
+
+    Args:
+        step (int | list[int]): Step to decay the momentum. If an int value is
+            given, regard it as the decay interval. If a list is given, decay
+            momentum at these steps.
+        gamma (float, optional): Decay momentum ratio. Default: 0.5.
+        min_momentum (float, optional): Minimum momentum value to keep. If
+            momentum after decay is lower than this value, it will be clipped
+            accordingly. If None is given, no clipping is performed.
+            Default: None.
+
+    Raises:
+        TypeError: If ``step`` is neither a list nor an int.
+    """
+
+    def __init__(self, step, gamma=0.5, min_momentum=None, **kwargs):
+        if isinstance(step, int):
+            assert step > 0
+        elif isinstance(step, list):
+            assert mmcv.is_list_of(step, int)
+            assert all(milestone > 0 for milestone in step)
+        else:
+            raise TypeError('"step" must be a list or integer')
+        self.step = step
+        self.gamma = gamma
+        self.min_momentum = min_momentum
+        super(StepMomentumUpdaterHook, self).__init__(**kwargs)
+
+    def get_momentum(self, runner, base_momentum):
+        """Return ``base_momentum`` decayed by ``gamma`` per passed step.
+
+        Progress is ``runner.epoch`` in by-epoch mode, else ``runner.iter``.
+        """
+        progress = runner.iter if not self.by_epoch else runner.epoch
+
+        # number of decays applied so far
+        if isinstance(self.step, int):
+            num_decays = progress // self.step
+        else:
+            # index of the first milestone not reached yet, or all of them
+            num_decays = next(
+                (idx for idx, milestone in enumerate(self.step)
+                 if progress < milestone), len(self.step))
+
+        decayed = base_momentum * (self.gamma**num_decays)
+        if self.min_momentum is None:
+            return decayed
+        # clip to a minimum value
+        return max(decayed, self.min_momentum)
+class CosineAnnealingMomentumUpdaterHook(MomentumUpdaterHook):
+    """Cosine-anneal the momentum from its base value to a target value.
+
+    Exactly one of ``min_momentum`` and ``min_momentum_ratio`` must be given.
+
+    Args:
+        min_momentum (float, optional): Absolute target momentum.
+        min_momentum_ratio (float, optional): Target momentum expressed as a
+            ratio of the base momentum.
+    """
+
+    def __init__(self, min_momentum=None, min_momentum_ratio=None, **kwargs):
+        assert (min_momentum is None) ^ (min_momentum_ratio is None)
+        self.min_momentum = min_momentum
+        self.min_momentum_ratio = min_momentum_ratio
+        super(CosineAnnealingMomentumUpdaterHook, self).__init__(**kwargs)
+
+    def get_momentum(self, runner, base_momentum):
+        """Return the annealed momentum for the runner's current progress.
+
+        Progress is measured in epochs in by-epoch mode, else in iterations.
+        """
+        progress, max_progress = (
+            (runner.epoch, runner.max_epochs) if self.by_epoch else
+            (runner.iter, runner.max_iters))
+
+        if self.min_momentum_ratio is None:
+            target_momentum = self.min_momentum
+        else:
+            target_momentum = base_momentum * self.min_momentum_ratio
+
+        fraction_done = progress / max_progress
+        return annealing_cos(base_momentum, target_momentum, fraction_done)
+class CyclicMomentumUpdaterHook(MomentumUpdaterHook):
+    """Cyclic momentum Scheduler.
+
+    Implement the cyclical momentum scheduler policy described in
+    https://arxiv.org/pdf/1708.07120.pdf
+
+    This momentum scheduler usually used together with the CyclicLRUpdater
+    to improve the performance in the 3D detection area.
+
+    Args:
+        by_epoch (bool): Whether to update momentum by epoch. Only ``False``
+            is supported.
+        target_ratio (float | int | tuple[float]): Relative ratio of the
+            lowest momentum and the highest momentum to the initial
+            momentum. A single number ``r`` (or a 1-tuple ``(r,)``) is
+            expanded to ``(r, r / 1e5)``.
+        cyclic_times (int): Number of cycles during training.
+        step_ratio_up (float): The ratio of the increasing process of momentum
+            in the total cycle. Must be in ``[0, 1)``.
+
+    Raises:
+        ValueError: If ``target_ratio`` is neither a number nor a tuple.
+    """
+
+    def __init__(self,
+                 by_epoch=False,
+                 target_ratio=(0.85 / 0.95, 1),
+                 cyclic_times=1,
+                 step_ratio_up=0.4,
+                 **kwargs):
+        # Plain ints are accepted as well as floats: the default tuple itself
+        # contains an int, so rejecting an int scalar was inconsistent.
+        if isinstance(target_ratio, (int, float)):
+            target_ratio = (target_ratio, target_ratio / 1e5)
+        elif isinstance(target_ratio, tuple):
+            if len(target_ratio) == 1:
+                target_ratio = (target_ratio[0], target_ratio[0] / 1e5)
+        else:
+            raise ValueError('target_ratio should be either float '
+                             f'or tuple, got {type(target_ratio)}')
+
+        assert len(target_ratio) == 2, \
+            '"target_ratio" must be list or tuple of two floats'
+        assert 0 <= step_ratio_up < 1.0, \
+            '"step_ratio_up" must be in range [0,1)'
+
+        self.target_ratio = target_ratio
+        self.cyclic_times = cyclic_times
+        self.step_ratio_up = step_ratio_up
+        self.momentum_phases = []  # filled in before_run
+        # currently only support by_epoch=False
+        assert not by_epoch, \
+            'currently only support "by_epoch" = False'
+        super(CyclicMomentumUpdaterHook, self).__init__(by_epoch, **kwargs)
+
+    def before_run(self, runner):
+        """Build the two phases of one cycle from ``runner.max_iters``."""
+        super(CyclicMomentumUpdaterHook, self).before_run(runner)
+        max_iter_per_phase = runner.max_iters // self.cyclic_times
+        iter_up_phase = int(self.step_ratio_up * max_iter_per_phase)
+        # Rebuild instead of appending so that calling before_run more than
+        # once does not accumulate duplicate phases.
+        # Each phase: [start_iter, end_iter, cycle_len, start_ratio, end_ratio]
+        self.momentum_phases = [
+            [0, iter_up_phase, max_iter_per_phase, 1, self.target_ratio[0]],
+            [
+                iter_up_phase, max_iter_per_phase, max_iter_per_phase,
+                self.target_ratio[0], self.target_ratio[1]
+            ],
+        ]
+
+    def get_momentum(self, runner, base_momentum):
+        """Return the cosine-annealed momentum for ``runner.iter``.
+
+        Args:
+            runner: Object exposing the current iteration as ``iter``.
+            base_momentum (float): Momentum the phase ratios are multiplied
+                with.
+
+        Returns:
+            float | None: The momentum of the matching phase; ``None`` if no
+            phase matches (e.g. ``before_run`` has not been called).
+        """
+        curr_iter = runner.iter
+        for (start_iter, end_iter, max_iter_per_phase, start_ratio,
+             end_ratio) in self.momentum_phases:
+            # fold the global iteration into a position inside one cycle
+            curr_iter %= max_iter_per_phase
+            if start_iter <= curr_iter < end_iter:
+                progress = curr_iter - start_iter
+                return annealing_cos(base_momentum * start_ratio,
+                                     base_momentum * end_ratio,
+                                     progress / (end_iter - start_iter))
+class OneCycleMomentumUpdaterHook(MomentumUpdaterHook):
+    """OneCycle momentum Scheduler.
+
+    This momentum scheduler usually used together with the OneCycleLrUpdater
+    to improve the performance.
+
+    Args:
+        base_momentum (float or list or dict): Lower momentum boundaries in
+            the cycle for each parameter group. Note that momentum is cycled
+            inversely to learning rate; at the peak of a cycle, momentum is
+            'base_momentum' and learning rate is 'max_lr'.
+            Default: 0.85
+        max_momentum (float or list or dict): Upper momentum boundaries in
+            the cycle for each parameter group. Functionally,
+            it defines the cycle amplitude (max_momentum - base_momentum).
+            Note that momentum is cycled inversely
+            to learning rate; at the start of a cycle, momentum is
+            'max_momentum' and learning rate is 'base_lr'
+            Default: 0.95
+        pct_start (float): The percentage of the cycle (in number of steps)
+            spent increasing the learning rate.
+            Default: 0.3
+        anneal_strategy (str): {'cos', 'linear'}
+            Specifies the annealing strategy: 'cos' for cosine annealing,
+            'linear' for linear annealing.
+            Default: 'cos'
+        three_phase (bool): If three_phase is True, the schedule has three
+            phases: max -> base, base -> max (symmetrical about the step
+            indicated by pct_start), then max_momentum is held until the
+            end. Otherwise it has two phases: max -> base and base -> max.
+            Default: False
+
+    Raises:
+        ValueError: If ``base_momentum``, ``max_momentum``, ``pct_start`` or
+            ``anneal_strategy`` has an unsupported type or value.
+    """
+
+    def __init__(self,
+                 base_momentum=0.85,
+                 max_momentum=0.95,
+                 pct_start=0.3,
+                 anneal_strategy='cos',
+                 three_phase=False,
+                 **kwargs):
+        # validate by_epoch, currently only support by_epoch=False
+        if 'by_epoch' not in kwargs:
+            kwargs['by_epoch'] = False
+        else:
+            assert not kwargs['by_epoch'], \
+                'currently only support "by_epoch" = False'
+        if not isinstance(base_momentum, (float, list, dict)):
+            raise ValueError('base_momentum must be the type among of float, '
+                             'list or dict.')
+        self._base_momentum = base_momentum
+        if not isinstance(max_momentum, (float, list, dict)):
+            raise ValueError('max_momentum must be the type among of float, '
+                             'list or dict.')
+        self._max_momentum = max_momentum
+        # validate pct_start; the type is checked first so that a
+        # non-numeric value yields this ValueError rather than a TypeError
+        # from the comparison
+        if (not isinstance(pct_start, float) or pct_start < 0
+                or pct_start > 1):
+            raise ValueError('Expected float between 0 and 1 pct_start, but '
+                             f'got {pct_start}')
+        self.pct_start = pct_start
+        # validate anneal_strategy
+        if anneal_strategy not in ['cos', 'linear']:
+            raise ValueError('anneal_strategy must be one of "cos" or '
+                             f'"linear", instead got {anneal_strategy}')
+        elif anneal_strategy == 'cos':
+            self.anneal_func = annealing_cos
+        elif anneal_strategy == 'linear':
+            self.anneal_func = annealing_linear
+        self.three_phase = three_phase
+        self.momentum_phases = []  # filled in before_run
+        super(OneCycleMomentumUpdaterHook, self).__init__(**kwargs)
+
+    def _init_optimizer_momentum(self, name, optim):
+        """Seed every param group of ``optim`` with its momentum bounds.
+
+        The current momentum (or ``betas[0]``) is set to the group's
+        ``max_momentum``; both bounds are stored in the group for
+        :meth:`get_momentum`.
+        """
+        if ('momentum' not in optim.defaults
+                and 'betas' not in optim.defaults):
+            raise ValueError('optimizer must support momentum with '
+                             'option enabled')
+        self.use_beta1 = 'betas' in optim.defaults
+        _base_momentum = format_param(name, optim, self._base_momentum)
+        _max_momentum = format_param(name, optim, self._max_momentum)
+        for group, b_momentum, m_momentum in zip(optim.param_groups,
+                                                 _base_momentum,
+                                                 _max_momentum):
+            if self.use_beta1:
+                _, beta2 = group['betas']
+                group['betas'] = (m_momentum, beta2)
+            else:
+                group['momentum'] = m_momentum
+            group['base_momentum'] = b_momentum
+            group['max_momentum'] = m_momentum
+
+    def before_run(self, runner):
+        """Initialise param groups and build the schedule phases.
+
+        Raises:
+            ValueError: If an optimizer has neither ``momentum`` nor
+                ``betas`` in its defaults.
+        """
+        if isinstance(runner.optimizer, dict):
+            for k, optim in runner.optimizer.items():
+                self._init_optimizer_momentum(k, optim)
+        else:
+            optim = runner.optimizer
+            self._init_optimizer_momentum(type(optim).__name__, optim)
+
+        # Rebuild instead of appending so that calling before_run more than
+        # once does not accumulate duplicate phases. The start/end entries
+        # are keys looked up in each param group by get_momentum.
+        first_phase = {
+            'end_iter': float(self.pct_start * runner.max_iters) - 1,
+            'start_momentum': 'max_momentum',
+            'end_momentum': 'base_momentum'
+        }
+        if self.three_phase:
+            self.momentum_phases = [
+                first_phase,
+                {
+                    'end_iter':
+                    float(2 * self.pct_start * runner.max_iters) - 2,
+                    'start_momentum': 'base_momentum',
+                    'end_momentum': 'max_momentum'
+                },
+                {
+                    'end_iter': runner.max_iters - 1,
+                    'start_momentum': 'max_momentum',
+                    'end_momentum': 'max_momentum'
+                },
+            ]
+        else:
+            self.momentum_phases = [
+                first_phase,
+                {
+                    'end_iter': runner.max_iters - 1,
+                    'start_momentum': 'base_momentum',
+                    'end_momentum': 'max_momentum'
+                },
+            ]
+
+    def _set_momentum(self, runner, momentum_groups):
+        """Apply ``momentum_groups`` to the runner's optimizer(s).
+
+        The logic is the same as the base class implementation, so it is
+        delegated rather than duplicated.
+        """
+        super(OneCycleMomentumUpdaterHook,
+              self)._set_momentum(runner, momentum_groups)
+
+    def get_momentum(self, runner, param_group):
+        """Return the annealed momentum of ``param_group`` at ``runner.iter``.
+
+        Args:
+            runner: Object exposing the current iteration as ``iter``.
+            param_group (dict): Param group holding ``base_momentum`` and
+                ``max_momentum`` (written by :meth:`before_run`).
+
+        Returns:
+            float: Momentum interpolated inside the phase containing the
+            iteration; the last phase is used if none contains it.
+        """
+        curr_iter = runner.iter
+        start_iter = 0
+        for i, phase in enumerate(self.momentum_phases):
+            end_iter = phase['end_iter']
+            if curr_iter <= end_iter or i == len(self.momentum_phases) - 1:
+                pct = (curr_iter - start_iter) / (end_iter - start_iter)
+                momentum = self.anneal_func(
+                    param_group[phase['start_momentum']],
+                    param_group[phase['end_momentum']], pct)
+                break
+            start_iter = end_iter
+        return momentum
+
+    def get_regular_momentum(self, runner):
+        """Evaluate :meth:`get_momentum` for every param group.
+
+        Returns:
+            list | dict: A list for a single optimizer, or a dict of lists
+            keyed like ``runner.optimizer`` when that is a dict.
+        """
+        if isinstance(runner.optimizer, dict):
+            return {
+                k: [
+                    self.get_momentum(runner, param_group)
+                    for param_group in optim.param_groups
+                ]
+                for k, optim in runner.optimizer.items()
+            }
+        return [
+            self.get_momentum(runner, param_group)
+            for param_group in runner.optimizer.param_groups
+        ]
diff --git a/ControlNet/annotator/uniformer/mmcv/runner/hooks/optimizer.py b/ControlNet/annotator/uniformer/mmcv/runner/hooks/optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ef3e9ff8f9c6926e32bdf027612267b64ed80df
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/runner/hooks/optimizer.py
@@ -0,0 +1,508 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+from collections import defaultdict
+from itertools import chain
+from torch.nn.utils import clip_grad
+from annotator.uniformer.mmcv.utils import TORCH_VERSION, _BatchNorm, digit_version
+from ..dist_utils import allreduce_grads
+from ..fp16_utils import LossScaler, wrap_fp16_model
+from .hook import HOOKS, Hook
+ # If PyTorch version >= 1.6.0, torch.cuda.amp.GradScaler would be imported
+ # and used; otherwise, auto fp16 will adopt mmcv's implementation.
+ from torch.cuda.amp import GradScaler
+except ImportError:
+ pass
+class OptimizerHook(Hook):
+ def __init__(self, grad_clip=None):
+ self.grad_clip = grad_clip
+ def clip_grads(self, params):
+ params = list(
+ filter(lambda p: p.requires_grad and p.grad is not None, params))
+ if len(params) > 0:
+ return clip_grad.clip_grad_norm_(params, **self.grad_clip)
+ def after_train_iter(self, runner):
+ runner.optimizer.zero_grad()
+ runner.outputs['loss'].backward()
+ if self.grad_clip is not None:
+ grad_norm = self.clip_grads(runner.model.parameters())
+ if grad_norm is not None:
+ # Add grad norm to the logger
+ runner.log_buffer.update({'grad_norm': float(grad_norm)},
+ runner.outputs['num_samples'])
+ runner.optimizer.step()
+class GradientCumulativeOptimizerHook(OptimizerHook):
+ """Optimizer Hook implements multi-iters gradient cumulating.
+ Args:
+ cumulative_iters (int, optional): Num of gradient cumulative iters.
+ The optimizer will step every `cumulative_iters` iters.
+ Defaults to 1.
+ Examples:
+ >>> # Use cumulative_iters to simulate a large batch size
+ >>> # It is helpful when the hardware cannot handle a large batch size.
+ >>> loader = DataLoader(data, batch_size=64)
+ >>> optim_hook = GradientCumulativeOptimizerHook(cumulative_iters=4)
+ >>> # almost equals to
+ >>> loader = DataLoader(data, batch_size=256)
+ >>> optim_hook = OptimizerHook()
+ """
+ def __init__(self, cumulative_iters=1, **kwargs):
+ super(GradientCumulativeOptimizerHook, self).__init__(**kwargs)
+ assert isinstance(cumulative_iters, int) and cumulative_iters > 0, \
+ f'cumulative_iters only accepts positive int, but got ' \
+ f'{type(cumulative_iters)} instead.'
+ self.cumulative_iters = cumulative_iters
+ self.divisible_iters = 0
+ self.remainder_iters = 0
+ self.initialized = False
+ def has_batch_norm(self, module):
+ if isinstance(module, _BatchNorm):
+ return True
+ for m in module.children():
+ if self.has_batch_norm(m):
+ return True
+ return False
+ def _init(self, runner):
+ if runner.iter % self.cumulative_iters != 0:
+ runner.logger.warning(
+ 'Resume iter number is not divisible by cumulative_iters in '
+ 'GradientCumulativeOptimizerHook, which means the gradient of '
+ 'some iters is lost and the result may be influenced slightly.'
+ )
+ if self.has_batch_norm(runner.model) and self.cumulative_iters > 1:
+ runner.logger.warning(
+ 'GradientCumulativeOptimizerHook may slightly decrease '
+ 'performance if the model has BatchNorm layers.')
+ residual_iters = runner.max_iters - runner.iter
+ self.divisible_iters = (
+ residual_iters // self.cumulative_iters * self.cumulative_iters)
+ self.remainder_iters = residual_iters - self.divisible_iters
+ self.initialized = True
+ def after_train_iter(self, runner):
+ if not self.initialized:
+ self._init(runner)
+ if runner.iter < self.divisible_iters:
+ loss_factor = self.cumulative_iters
+ else:
+ loss_factor = self.remainder_iters
+ loss = runner.outputs['loss']
+ loss = loss / loss_factor
+ loss.backward()
+ if (self.every_n_iters(runner, self.cumulative_iters)
+ or self.is_last_iter(runner)):
+ if self.grad_clip is not None:
+ grad_norm = self.clip_grads(runner.model.parameters())
+ if grad_norm is not None:
+ # Add grad norm to the logger
+ runner.log_buffer.update({'grad_norm': float(grad_norm)},
+ runner.outputs['num_samples'])
+ runner.optimizer.step()
+ runner.optimizer.zero_grad()
+# Variant used when torch is not 'parrots' and is at least version 1.6.0:
+# the hooks in this branch delegate loss scaling to GradScaler.
+if (TORCH_VERSION != 'parrots'
+ and digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
+ @HOOKS.register_module()
+ class Fp16OptimizerHook(OptimizerHook):
+ """FP16 optimizer hook (using PyTorch's implementation).
+ If you are using PyTorch >= 1.6, torch.cuda.amp is used as the backend,
+ to take care of the optimization procedure.
+ Args:
+ loss_scale (float | str | dict): Scale factor configuration.
+ If loss_scale is a float, static loss scaling will be used with
+ the specified scale. If loss_scale is a string, it must be
+ 'dynamic', then dynamic loss scaling will be used.
+ It can also be a dict containing arguments of GradScalar.
+ Defaults to 512. For Pytorch >= 1.6, mmcv uses official
+ implementation of GradScaler. If you use a dict version of
+ loss_scale to create GradScaler, please refer to:
+ https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler
+ for the parameters.
+ Examples:
+ >>> loss_scale = dict(
+ ... init_scale=65536.0,
+ ... growth_factor=2.0,
+ ... backoff_factor=0.5,
+ ... growth_interval=2000
+ ... )
+ >>> optimizer_hook = Fp16OptimizerHook(loss_scale=loss_scale)
+ """
+ # The parent __init__ is not called; grad_clip is assigned directly.
+ # coalesce, bucket_size_mb and distributed are stored but not read by
+ # any method of this class.
+ def __init__(self,
+ grad_clip=None,
+ coalesce=True,
+ bucket_size_mb=-1,
+ loss_scale=512.,
+ distributed=True):
+ self.grad_clip = grad_clip
+ self.coalesce = coalesce
+ self.bucket_size_mb = bucket_size_mb
+ self.distributed = distributed
+ # Value passed to loss_scaler.update() after every step; it stays None
+ # unless a float loss_scale is given, in which case it is that float.
+ self._scale_update_param = None
+ if loss_scale == 'dynamic':
+ self.loss_scaler = GradScaler()
+ # NOTE(review): the check is for float only, so an int such as 512
+ # falls through to the ValueError below - confirm this is intended.
+ elif isinstance(loss_scale, float):
+ self._scale_update_param = loss_scale
+ self.loss_scaler = GradScaler(init_scale=loss_scale)
+ elif isinstance(loss_scale, dict):
+ self.loss_scaler = GradScaler(**loss_scale)
+ else:
+ raise ValueError('loss_scale must be of type float, dict, or '
+ f'"dynamic", got {loss_scale}')
+ def before_run(self, runner):
+ """Preparing steps before Mixed Precision Training."""
+ # wrap model mode to fp16
+ wrap_fp16_model(runner.model)
+ # resume from state dict
+ if 'fp16' in runner.meta and 'loss_scaler' in runner.meta['fp16']:
+ scaler_state_dict = runner.meta['fp16']['loss_scaler']
+ self.loss_scaler.load_state_dict(scaler_state_dict)
+ # Not called by any method of this class.
+ def copy_grads_to_fp32(self, fp16_net, fp32_weights):
+ """Copy gradients from fp16 model to fp32 weight copy."""
+ # Parameters are paired purely by position in the two iterables.
+ for fp32_param, fp16_param in zip(fp32_weights,
+ fp16_net.parameters()):
+ if fp16_param.grad is not None:
+ if fp32_param.grad is None:
+ fp32_param.grad = fp32_param.data.new(
+ fp32_param.size())
+ fp32_param.grad.copy_(fp16_param.grad)
+ # Not called by any method of this class.
+ def copy_params_to_fp16(self, fp16_net, fp32_weights):
+ """Copy updated params from fp32 weight copy to fp16 model."""
+ for fp16_param, fp32_param in zip(fp16_net.parameters(),
+ fp32_weights):
+ fp16_param.data.copy_(fp32_param.data)
+ def after_train_iter(self, runner):
+ """Backward optimization steps for Mixed Precision Training. For
+ dynamic loss scaling, please refer to
+ https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler.
+ 1. Scale the loss by a scale factor.
+ 2. Backward the loss to obtain the gradients.
+ 3. Unscale the optimizer’s gradient tensors.
+ 4. Call optimizer.step() and update scale factor.
+ 5. Save loss_scaler state_dict for resume purpose.
+ """
+ # clear grads of last iteration
+ runner.model.zero_grad()
+ runner.optimizer.zero_grad()
+ self.loss_scaler.scale(runner.outputs['loss']).backward()
+ # Unscale before clipping so the clip threshold applies to real norms.
+ self.loss_scaler.unscale_(runner.optimizer)
+ # grad clip
+ if self.grad_clip is not None:
+ grad_norm = self.clip_grads(runner.model.parameters())
+ if grad_norm is not None:
+ # Add grad norm to the logger
+ runner.log_buffer.update({'grad_norm': float(grad_norm)},
+ runner.outputs['num_samples'])
+ # backward and update scaler
+ self.loss_scaler.step(runner.optimizer)
+ self.loss_scaler.update(self._scale_update_param)
+ # save state_dict of loss_scaler
+ runner.meta.setdefault(
+ 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()
+ @HOOKS.register_module()
+ class GradientCumulativeFp16OptimizerHook(GradientCumulativeOptimizerHook,
+ Fp16OptimizerHook):
+ """Fp16 optimizer Hook (using PyTorch's implementation) implements
+ multi-iters gradient cumulating.
+ If you are using PyTorch >= 1.6, torch.cuda.amp is used as the backend,
+ to take care of the optimization procedure.
+ """
+ # All arguments are forwarded along the MRO; cumulative_iters is consumed
+ # by GradientCumulativeOptimizerHook.__init__.
+ def __init__(self, *args, **kwargs):
+ super(GradientCumulativeFp16OptimizerHook,
+ self).__init__(*args, **kwargs)
+ def after_train_iter(self, runner):
+ # Lazily compute divisible_iters / remainder_iters on the first call.
+ if not self.initialized:
+ self._init(runner)
+ # Pick the divisor that averages the loss over the current window.
+ if runner.iter < self.divisible_iters:
+ loss_factor = self.cumulative_iters
+ else:
+ loss_factor = self.remainder_iters
+ loss = runner.outputs['loss']
+ loss = loss / loss_factor
+ # Gradients accumulate across iterations until the step below.
+ self.loss_scaler.scale(loss).backward()
+ if (self.every_n_iters(runner, self.cumulative_iters)
+ or self.is_last_iter(runner)):
+ # copy fp16 grads in the model to fp32 params in the optimizer
+ self.loss_scaler.unscale_(runner.optimizer)
+ if self.grad_clip is not None:
+ grad_norm = self.clip_grads(runner.model.parameters())
+ if grad_norm is not None:
+ # Add grad norm to the logger
+ runner.log_buffer.update(
+ {'grad_norm': float(grad_norm)},
+ runner.outputs['num_samples'])
+ # backward and update scaler
+ self.loss_scaler.step(runner.optimizer)
+ self.loss_scaler.update(self._scale_update_param)
+ # save state_dict of loss_scaler
+ runner.meta.setdefault(
+ 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()
+ # clear grads
+ runner.model.zero_grad()
+ runner.optimizer.zero_grad()
+ # NOTE(review): this second Fp16OptimizerHook uses LossScaler instead of
+ # GradScaler; presumably it belongs to an `else:` branch of the version
+ # check whose line is missing from this text - confirm.
+ @HOOKS.register_module()
+ class Fp16OptimizerHook(OptimizerHook):
+ """FP16 optimizer hook (mmcv's implementation).
+ The steps of fp16 optimizer is as follows.
+ 1. Scale the loss value.
+ 2. BP in the fp16 model.
+ 2. Copy gradients from fp16 model to fp32 weights.
+ 3. Update fp32 weights.
+ 4. Copy updated parameters from fp32 weights to fp16 model.
+ Refer to https://arxiv.org/abs/1710.03740 for more details.
+ Args:
+ loss_scale (float | str | dict): Scale factor configuration.
+ If loss_scale is a float, static loss scaling will be used with
+ the specified scale. If loss_scale is a string, it must be
+ 'dynamic', then dynamic loss scaling will be used.
+ It can also be a dict containing arguments of LossScaler.
+ Defaults to 512.
+ """
+ # The parent __init__ is not called; grad_clip is assigned directly.
+ def __init__(self,
+ grad_clip=None,
+ coalesce=True,
+ bucket_size_mb=-1,
+ loss_scale=512.,
+ distributed=True):
+ self.grad_clip = grad_clip
+ self.coalesce = coalesce
+ self.bucket_size_mb = bucket_size_mb
+ self.distributed = distributed
+ if loss_scale == 'dynamic':
+ self.loss_scaler = LossScaler(mode='dynamic')
+ # NOTE(review): float only; an int loss_scale reaches the ValueError.
+ elif isinstance(loss_scale, float):
+ self.loss_scaler = LossScaler(
+ init_scale=loss_scale, mode='static')
+ elif isinstance(loss_scale, dict):
+ self.loss_scaler = LossScaler(**loss_scale)
+ else:
+ raise ValueError('loss_scale must be of type float, dict, or '
+ f'"dynamic", got {loss_scale}')
+ def before_run(self, runner):
+ """Preparing steps before Mixed Precision Training.
+ 1. Make a master copy of fp32 weights for optimization.
+ 2. Convert the main model from fp32 to fp16.
+ """
+ # keep a copy of fp32 weights
+ old_groups = runner.optimizer.param_groups
+ runner.optimizer.param_groups = copy.deepcopy(
+ runner.optimizer.param_groups)
+ # Re-key the optimizer state from the old parameters to their copies;
+ # old and new parameters are matched by position across the groups.
+ state = defaultdict(dict)
+ p_map = {
+ old_p: p
+ for old_p, p in zip(
+ chain(*(g['params'] for g in old_groups)),
+ chain(*(g['params']
+ for g in runner.optimizer.param_groups)))
+ }
+ for k, v in runner.optimizer.state.items():
+ state[p_map[k]] = v
+ runner.optimizer.state = state
+ # convert model to fp16
+ wrap_fp16_model(runner.model)
+ # resume from state dict
+ if 'fp16' in runner.meta and 'loss_scaler' in runner.meta['fp16']:
+ scaler_state_dict = runner.meta['fp16']['loss_scaler']
+ self.loss_scaler.load_state_dict(scaler_state_dict)
+ def copy_grads_to_fp32(self, fp16_net, fp32_weights):
+ """Copy gradients from fp16 model to fp32 weight copy."""
+ # Parameters are paired purely by position in the two iterables.
+ for fp32_param, fp16_param in zip(fp32_weights,
+ fp16_net.parameters()):
+ if fp16_param.grad is not None:
+ if fp32_param.grad is None:
+ fp32_param.grad = fp32_param.data.new(
+ fp32_param.size())
+ fp32_param.grad.copy_(fp16_param.grad)
+ def copy_params_to_fp16(self, fp16_net, fp32_weights):
+ """Copy updated params from fp32 weight copy to fp16 model."""
+ for fp16_param, fp32_param in zip(fp16_net.parameters(),
+ fp32_weights):
+ fp16_param.data.copy_(fp32_param.data)
+ def after_train_iter(self, runner):
+ """Backward optimization steps for Mixed Precision Training. For
+ dynamic loss scaling, please refer `loss_scalar.py`
+ 1. Scale the loss by a scale factor.
+ 2. Backward the loss to obtain the gradients (fp16).
+ 3. Copy gradients from the model to the fp32 weight copy.
+ 4. Scale the gradients back and update the fp32 weight copy.
+ 5. Copy back the params from fp32 weight copy to the fp16 model.
+ 6. Save loss_scaler state_dict for resume purpose.
+ """
+ # clear grads of last iteration
+ runner.model.zero_grad()
+ runner.optimizer.zero_grad()
+ # scale the loss value
+ scaled_loss = runner.outputs['loss'] * self.loss_scaler.loss_scale
+ scaled_loss.backward()
+ # copy fp16 grads in the model to fp32 params in the optimizer
+ fp32_weights = []
+ for param_group in runner.optimizer.param_groups:
+ fp32_weights += param_group['params']
+ self.copy_grads_to_fp32(runner.model, fp32_weights)
+ # allreduce grads
+ if self.distributed:
+ allreduce_grads(fp32_weights, self.coalesce,
+ self.bucket_size_mb)
+ has_overflow = self.loss_scaler.has_overflow(fp32_weights)
+ # if has overflow, skip this iteration
+ if not has_overflow:
+ # scale the gradients back
+ for param in fp32_weights:
+ if param.grad is not None:
+ param.grad.div_(self.loss_scaler.loss_scale)
+ if self.grad_clip is not None:
+ grad_norm = self.clip_grads(fp32_weights)
+ if grad_norm is not None:
+ # Add grad norm to the logger
+ runner.log_buffer.update(
+ {'grad_norm': float(grad_norm)},
+ runner.outputs['num_samples'])
+ # update fp32 params
+ runner.optimizer.step()
+ # copy fp32 params to the fp16 model
+ self.copy_params_to_fp16(runner.model, fp32_weights)
+ # The scaler is updated before the warning, so the message reads
+ # cur_scale after update_scale() has run.
+ self.loss_scaler.update_scale(has_overflow)
+ if has_overflow:
+ runner.logger.warning('Check overflow, downscale loss scale '
+ f'to {self.loss_scaler.cur_scale}')
+ # save state_dict of loss_scaler
+ runner.meta.setdefault(
+ 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()
+ @HOOKS.register_module()
+ class GradientCumulativeFp16OptimizerHook(GradientCumulativeOptimizerHook,
+ Fp16OptimizerHook):
+ """Fp16 optimizer Hook (using mmcv implementation) implements multi-
+ iters gradient cumulating."""
+ # All arguments are forwarded along the MRO; cumulative_iters is consumed
+ # by GradientCumulativeOptimizerHook.__init__.
+ def __init__(self, *args, **kwargs):
+ super(GradientCumulativeFp16OptimizerHook,
+ self).__init__(*args, **kwargs)
+ def after_train_iter(self, runner):
+ # Lazily compute divisible_iters / remainder_iters on the first call.
+ if not self.initialized:
+ self._init(runner)
+ # Pick the divisor that averages the loss over the current window.
+ if runner.iter < self.divisible_iters:
+ loss_factor = self.cumulative_iters
+ else:
+ loss_factor = self.remainder_iters
+ loss = runner.outputs['loss']
+ loss = loss / loss_factor
+ # scale the loss value
+ scaled_loss = loss * self.loss_scaler.loss_scale
+ scaled_loss.backward()
+ if (self.every_n_iters(runner, self.cumulative_iters)
+ or self.is_last_iter(runner)):
+ # copy fp16 grads in the model to fp32 params in the optimizer
+ fp32_weights = []
+ for param_group in runner.optimizer.param_groups:
+ fp32_weights += param_group['params']
+ self.copy_grads_to_fp32(runner.model, fp32_weights)
+ # allreduce grads
+ if self.distributed:
+ allreduce_grads(fp32_weights, self.coalesce,
+ self.bucket_size_mb)
+ has_overflow = self.loss_scaler.has_overflow(fp32_weights)
+ # if has overflow, skip this iteration
+ if not has_overflow:
+ # scale the gradients back
+ for param in fp32_weights:
+ if param.grad is not None:
+ param.grad.div_(self.loss_scaler.loss_scale)
+ if self.grad_clip is not None:
+ grad_norm = self.clip_grads(fp32_weights)
+ if grad_norm is not None:
+ # Add grad norm to the logger
+ runner.log_buffer.update(
+ {'grad_norm': float(grad_norm)},
+ runner.outputs['num_samples'])
+ # update fp32 params
+ runner.optimizer.step()
+ # copy fp32 params to the fp16 model
+ self.copy_params_to_fp16(runner.model, fp32_weights)
+ else:
+ # NOTE(review): cur_scale is read here before update_scale() runs
+ # below, so the message shows the scale prior to the update.
+ runner.logger.warning(
+ 'Check overflow, downscale loss scale '
+ f'to {self.loss_scaler.cur_scale}')
+ self.loss_scaler.update_scale(has_overflow)
+ # save state_dict of loss_scaler
+ runner.meta.setdefault(
+ 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()
+ # clear grads
+ runner.model.zero_grad()
+ runner.optimizer.zero_grad()
diff --git a/ControlNet/annotator/uniformer/mmcv/runner/hooks/profiler.py b/ControlNet/annotator/uniformer/mmcv/runner/hooks/profiler.py
new file mode 100644
index 0000000000000000000000000000000000000000..b70236997eec59c2209ef351ae38863b4112d0ec
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/runner/hooks/profiler.py
@@ -0,0 +1,180 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+from typing import Callable, List, Optional, Union
+import torch
+from ..dist_utils import master_only
+from .hook import HOOKS, Hook
+# NOTE(review): HOOKS is imported above but no registration decorator is
+# visible on this class - confirm whether a decorator line was lost.
+class ProfilerHook(Hook):
+ """Profiler to analyze performance during training.
+ PyTorch Profiler is a tool that allows the collection of the performance
+ metrics during the training. More details on Profiler can be found at
+ https://pytorch.org/docs/1.8.1/profiler.html#torch.profiler.profile
+ Args:
+ by_epoch (bool): Profile performance by epoch or by iteration.
+ Default: True.
+ profile_iters (int): Number of iterations for profiling.
+ If ``by_epoch=True``, profile_iters indicates that they are the
+ first profile_iters epochs at the beginning of the
+ training, otherwise it indicates the first profile_iters
+ iterations. Default: 1.
+ activities (list[str]): List of activity groups (CPU, CUDA) to use in
+ profiling. Default: ['cpu', 'cuda'].
+ schedule (dict, optional): Config of generating the callable schedule.
+ if schedule is None, profiler will not add step markers into the
+ trace and table view. Default: None.
+ on_trace_ready (callable, dict): Either a handler or a dict of generate
+ handler. Default: None.
+ record_shapes (bool): Save information about operator's input shapes.
+ Default: False.
+ profile_memory (bool): Track tensor memory allocation/deallocation.
+ Default: False.
+ with_stack (bool): Record source information (file and line number)
+ for the ops. Default: False.
+ with_flops (bool): Use formula to estimate the FLOPS of specific
+ operators (matrix multiplication and 2D convolution).
+ Default: False.
+ json_trace_path (str, optional): Exports the collected trace in Chrome
+ JSON format. Default: None.
+ Example:
+ >>> runner = ... # instantiate a Runner
+ >>> # tensorboard trace
+ >>> trace_config = dict(type='tb_trace', dir_name='work_dir')
+ >>> profiler_config = dict(on_trace_ready=trace_config)
+ >>> runner.register_profiler_hook(profiler_config)
+ >>> runner.run(data_loaders=[trainloader], workflow=[('train', 1)])
+ """
+ # The default ``activities`` list is a shared mutable object; it is only
+ # iterated below, never mutated.
+ def __init__(self,
+ by_epoch: bool = True,
+ profile_iters: int = 1,
+ activities: List[str] = ['cpu', 'cuda'],
+ schedule: Optional[dict] = None,
+ on_trace_ready: Optional[Union[Callable, dict]] = None,
+ record_shapes: bool = False,
+ profile_memory: bool = False,
+ with_stack: bool = False,
+ with_flops: bool = False,
+ json_trace_path: Optional[str] = None) -> None:
+ # Imported here so a torch build without the profiler module fails with
+ # the explanatory ImportError below.
+ try:
+ from torch import profiler # torch version >= 1.8.1
+ except ImportError:
+ raise ImportError('profiler is the new feature of torch1.8.1, '
+ f'but your version is {torch.__version__}')
+ assert isinstance(by_epoch, bool), '``by_epoch`` should be a boolean.'
+ self.by_epoch = by_epoch
+ if profile_iters < 1:
+ raise ValueError('profile_iters should be greater than 0, but got '
+ f'{profile_iters}')
+ self.profile_iters = profile_iters
+ if not isinstance(activities, list):
+ raise ValueError(
+ f'activities should be list, but got {type(activities)}')
+ # Translate the case-insensitive names into ProfilerActivity members.
+ self.activities = []
+ for activity in activities:
+ activity = activity.lower()
+ if activity == 'cpu':
+ self.activities.append(profiler.ProfilerActivity.CPU)
+ elif activity == 'cuda':
+ self.activities.append(profiler.ProfilerActivity.CUDA)
+ else:
+ raise ValueError(
+ f'activity should be "cpu" or "cuda", but got {activity}')
+ if schedule is not None:
+ self.schedule = profiler.schedule(**schedule)
+ else:
+ self.schedule = None
+ # Stored as given; turned into an actual handler in before_run().
+ self.on_trace_ready = on_trace_ready
+ self.record_shapes = record_shapes
+ self.profile_memory = profile_memory
+ self.with_stack = with_stack
+ self.with_flops = with_flops
+ self.json_trace_path = json_trace_path
+ @master_only
+ def before_run(self, runner):
+ # Validate profile_iters against the run length, build the trace
+ # handler, then create and enter the profiler.
+ if self.by_epoch and runner.max_epochs < self.profile_iters:
+ raise ValueError('self.profile_iters should not be greater than '
+ f'{runner.max_epochs}')
+ if not self.by_epoch and runner.max_iters < self.profile_iters:
+ raise ValueError('self.profile_iters should not be greater than '
+ f'{runner.max_iters}')
+ if callable(self.on_trace_ready): # handler
+ _on_trace_ready = self.on_trace_ready
+ elif isinstance(self.on_trace_ready, dict): # config of handler
+ # Work on a copy so popping 'type' leaves the stored config intact.
+ trace_cfg = self.on_trace_ready.copy()
+ trace_type = trace_cfg.pop('type') # log_trace handler
+ if trace_type == 'log_trace':
+ # The remaining config keys become keyword arguments of table().
+ def _log_handler(prof):
+ print(prof.key_averages().table(**trace_cfg))
+ _on_trace_ready = _log_handler
+ elif trace_type == 'tb_trace': # tensorboard_trace handler
+ try:
+ import torch_tb_profiler # noqa: F401
+ except ImportError:
+ raise ImportError('please run "pip install '
+ 'torch-tb-profiler" to install '
+ 'torch_tb_profiler')
+ _on_trace_ready = torch.profiler.tensorboard_trace_handler(
+ **trace_cfg)
+ else:
+ raise ValueError('trace_type should be "log_trace" or '
+ f'"tb_trace", but got {trace_type}')
+ elif self.on_trace_ready is None:
+ _on_trace_ready = None # type: ignore
+ else:
+ raise ValueError('on_trace_ready should be handler, dict or None, '
+ f'but got {type(self.on_trace_ready)}')
+ if runner.max_epochs > 1:
+ warnings.warn(f'profiler will profile {runner.max_epochs} epochs '
+ 'instead of 1 epoch. Since profiler will slow down '
+ 'the training, it is recommended to train 1 epoch '
+ 'with ProfilerHook and adjust your setting according'
+ ' to the profiler summary. During normal training '
+ '(epoch > 1), you may disable the ProfilerHook.')
+ self.profiler = torch.profiler.profile(
+ activities=self.activities,
+ schedule=self.schedule,
+ on_trace_ready=_on_trace_ready,
+ record_shapes=self.record_shapes,
+ profile_memory=self.profile_memory,
+ with_stack=self.with_stack,
+ with_flops=self.with_flops)
+ # Entered manually; the matching __exit__ is in the after_* callbacks.
+ self.profiler.__enter__()
+ runner.logger.info('profiler is profiling...')
+ @master_only
+ def after_train_epoch(self, runner):
+ # Epoch mode: stop once epoch index profile_iters - 1 has finished.
+ if self.by_epoch and runner.epoch == self.profile_iters - 1:
+ runner.logger.info('profiler may take a few minutes...')
+ self.profiler.__exit__(None, None, None)
+ if self.json_trace_path is not None:
+ self.profiler.export_chrome_trace(self.json_trace_path)
+ @master_only
+ def after_train_iter(self, runner):
+ # NOTE(review): step() runs on every iteration, including those after
+ # the profiler has been exited - confirm this is harmless.
+ self.profiler.step()
+ if not self.by_epoch and runner.iter == self.profile_iters - 1:
+ runner.logger.info('profiler may take a few minutes...')
+ self.profiler.__exit__(None, None, None)
+ if self.json_trace_path is not None:
+ self.profiler.export_chrome_trace(self.json_trace_path)
diff --git a/ControlNet/annotator/uniformer/mmcv/runner/hooks/sampler_seed.py b/ControlNet/annotator/uniformer/mmcv/runner/hooks/sampler_seed.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee0dc6bdd8df5775857028aaed5444c0f59caf80
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/runner/hooks/sampler_seed.py
@@ -0,0 +1,20 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .hook import HOOKS, Hook
+class DistSamplerSeedHook(Hook):
+ """Data-loading sampler for distributed training.
+ When distributed training, it is only useful in conjunction with
+ :obj:`EpochBasedRunner`, while :obj:`IterBasedRunner` achieves the same
+ purpose with :obj:`IterLoader`.
+ """
+ def before_epoch(self, runner):
+ if hasattr(runner.data_loader.sampler, 'set_epoch'):
+ # in case the data loader uses `SequentialSampler` in Pytorch
+ runner.data_loader.sampler.set_epoch(runner.epoch)
+ elif hasattr(runner.data_loader.batch_sampler.sampler, 'set_epoch'):
+ # batch sampler in pytorch warps the sampler as its attributes.
+ runner.data_loader.batch_sampler.sampler.set_epoch(runner.epoch)
diff --git a/ControlNet/annotator/uniformer/mmcv/runner/hooks/sync_buffer.py b/ControlNet/annotator/uniformer/mmcv/runner/hooks/sync_buffer.py
new file mode 100644
index 0000000000000000000000000000000000000000..6376b7ff894280cb2782243b25e8973650591577
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/runner/hooks/sync_buffer.py
@@ -0,0 +1,22 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..dist_utils import allreduce_params
+from .hook import HOOKS, Hook
+class SyncBuffersHook(Hook):
+ """Synchronize model buffers such as running_mean and running_var in BN at
+ the end of each epoch.
+ Args:
+ distributed (bool): Whether distributed training is used. It is
+ effective only for distributed training. Defaults to True.
+ """
+ def __init__(self, distributed=True):
+ self.distributed = distributed
+ def after_epoch(self, runner):
+ """All-reduce model buffers at the end of each epoch."""
+ if self.distributed:
+ allreduce_params(runner.model.buffers())
diff --git a/ControlNet/annotator/uniformer/mmcv/runner/iter_based_runner.py b/ControlNet/annotator/uniformer/mmcv/runner/iter_based_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..1df4de8c0285669dec9b014dfd1f3dd1600f0831
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/runner/iter_based_runner.py
@@ -0,0 +1,273 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import platform
+import shutil
+import time
+import warnings
+import torch
+from torch.optim import Optimizer
+import annotator.uniformer.mmcv as mmcv
+from .base_runner import BaseRunner
+from .builder import RUNNERS
+from .checkpoint import save_checkpoint
+from .hooks import IterTimerHook
+from .utils import get_host_info
+class IterLoader:
+ def __init__(self, dataloader):
+ self._dataloader = dataloader
+ self.iter_loader = iter(self._dataloader)
+ self._epoch = 0
+ @property
+ def epoch(self):
+ return self._epoch
+ def __next__(self):
+ try:
+ data = next(self.iter_loader)
+ except StopIteration:
+ self._epoch += 1
+ if hasattr(self._dataloader.sampler, 'set_epoch'):
+ self._dataloader.sampler.set_epoch(self._epoch)
+ time.sleep(2) # Prevent possible deadlock during epoch transition
+ self.iter_loader = iter(self._dataloader)
+ data = next(self.iter_loader)
+ return data
+ def __len__(self):
+ return len(self._dataloader)
+# NOTE(review): RUNNERS is imported above but no registration decorator is
+# visible on this class - confirm whether a decorator line was lost.
+class IterBasedRunner(BaseRunner):
+ """Iteration-based Runner.
+ This runner train models iteration by iteration.
+ """
+ # Runs exactly one training iteration: one batch, one train_step call,
+ # with the before/after_train_iter hooks around it.
+ def train(self, data_loader, **kwargs):
+ self.model.train()
+ self.mode = 'train'
+ self.data_loader = data_loader
+ # data_loader is expected to expose ``epoch`` and support next(), as
+ # IterLoader does.
+ self._epoch = data_loader.epoch
+ data_batch = next(data_loader)
+ self.call_hook('before_train_iter')
+ outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
+ if not isinstance(outputs, dict):
+ raise TypeError('model.train_step() must return a dict')
+ if 'log_vars' in outputs:
+ self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
+ self.outputs = outputs
+ self.call_hook('after_train_iter')
+ self._inner_iter += 1
+ self._iter += 1
+ # Runs exactly one validation iteration under no_grad; unlike train(),
+ # it advances only _inner_iter, not _iter.
+ @torch.no_grad()
+ def val(self, data_loader, **kwargs):
+ self.model.eval()
+ self.mode = 'val'
+ self.data_loader = data_loader
+ data_batch = next(data_loader)
+ self.call_hook('before_val_iter')
+ outputs = self.model.val_step(data_batch, **kwargs)
+ if not isinstance(outputs, dict):
+ raise TypeError('model.val_step() must return a dict')
+ if 'log_vars' in outputs:
+ self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
+ self.outputs = outputs
+ self.call_hook('after_val_iter')
+ self._inner_iter += 1
+ def run(self, data_loaders, workflow, max_iters=None, **kwargs):
+ """Start running.
+ Args:
+ data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
+ and validation.
+ workflow (list[tuple]): A list of (phase, iters) to specify the
+ running order and iterations. E.g, [('train', 10000),
+ ('val', 1000)] means running 10000 iterations for training and
+ 1000 iterations for validation, iteratively.
+ """
+ assert isinstance(data_loaders, list)
+ assert mmcv.is_list_of(workflow, tuple)
+ assert len(data_loaders) == len(workflow)
+ if max_iters is not None:
+ warnings.warn(
+ 'setting max_iters in run is deprecated, '
+ 'please set max_iters in runner_config', DeprecationWarning)
+ self._max_iters = max_iters
+ assert self._max_iters is not None, (
+ 'max_iters must be specified during instantiation')
+ work_dir = self.work_dir if self.work_dir is not None else 'NONE'
+ self.logger.info('Start running, host: %s, work_dir: %s',
+ get_host_info(), work_dir)
+ self.logger.info('Hooks will be executed in the following order:\n%s',
+ self.get_hook_info())
+ self.logger.info('workflow: %s, max: %d iters', workflow,
+ self._max_iters)
+ self.call_hook('before_run')
+ # One endless IterLoader per dataloader, matched to workflow by index.
+ iter_loaders = [IterLoader(x) for x in data_loaders]
+ self.call_hook('before_epoch')
+ while self.iter < self._max_iters:
+ for i, flow in enumerate(workflow):
+ self._inner_iter = 0
+ mode, iters = flow
+ if not isinstance(mode, str) or not hasattr(self, mode):
+ raise ValueError(
+ 'runner has no method named "{}" to run a workflow'.
+ format(mode))
+ # ``mode`` names the method to call, e.g. 'train' or 'val'.
+ iter_runner = getattr(self, mode)
+ for _ in range(iters):
+ # Only the train phase is cut short by the iteration budget.
+ if mode == 'train' and self.iter >= self._max_iters:
+ break
+ iter_runner(iter_loaders[i], **kwargs)
+ time.sleep(1) # wait for some hooks like loggers to finish
+ self.call_hook('after_epoch')
+ self.call_hook('after_run')
+ def resume(self,
+ checkpoint,
+ resume_optimizer=True,
+ map_location='default'):
+ """Resume model from checkpoint.
+ Args:
+ checkpoint (str): Checkpoint to resume from.
+ resume_optimizer (bool, optional): Whether resume the optimizer(s)
+ if the checkpoint file includes optimizer(s). Default to True.
+ map_location (str, optional): Same as :func:`torch.load`.
+ Default to 'default'.
+ """
+ # NOTE(review): the 'default' path queries the current CUDA device
+ # unconditionally; presumably this fails on a host without CUDA -
+ # confirm, or pass an explicit map_location there.
+ if map_location == 'default':
+ device_id = torch.cuda.current_device()
+ checkpoint = self.load_checkpoint(
+ checkpoint,
+ map_location=lambda storage, loc: storage.cuda(device_id))
+ else:
+ checkpoint = self.load_checkpoint(
+ checkpoint, map_location=map_location)
+ self._epoch = checkpoint['meta']['epoch']
+ self._iter = checkpoint['meta']['iter']
+ self._inner_iter = checkpoint['meta']['iter']
+ if 'optimizer' in checkpoint and resume_optimizer:
+ if isinstance(self.optimizer, Optimizer):
+ self.optimizer.load_state_dict(checkpoint['optimizer'])
+ elif isinstance(self.optimizer, dict):
+ for k in self.optimizer.keys():
+ self.optimizer[k].load_state_dict(
+ checkpoint['optimizer'][k])
+ else:
+ raise TypeError(
+ 'Optimizer should be dict or torch.optim.Optimizer '
+ f'but got {type(self.optimizer)}')
+ self.logger.info(f'resumed from epoch: {self.epoch}, iter {self.iter}')
+ def save_checkpoint(self,
+ out_dir,
+ filename_tmpl='iter_{}.pth',
+ meta=None,
+ save_optimizer=True,
+ create_symlink=True):
+ """Save checkpoint to file.
+ Args:
+ out_dir (str): Directory to save checkpoint files.
+ filename_tmpl (str, optional): Checkpoint file template.
+ Defaults to 'iter_{}.pth'.
+ meta (dict, optional): Metadata to be saved in checkpoint.
+ Defaults to None.
+ save_optimizer (bool, optional): Whether save optimizer.
+ Defaults to True.
+ create_symlink (bool, optional): Whether create symlink to the
+ latest checkpoint file. Defaults to True.
+ """
+ if meta is None:
+ meta = {}
+ elif not isinstance(meta, dict):
+ raise TypeError(
+ f'meta should be a dict or None, but got {type(meta)}')
+ # A caller-supplied ``meta`` dict is updated in place below.
+ if self.meta is not None:
+ meta.update(self.meta)
+ # Note: meta.update(self.meta) should be done before
+ # meta.update(epoch=self.epoch + 1, iter=self.iter) otherwise
+ # there will be problems with resumed checkpoints.
+ # More details in https://github.com/open-mmlab/mmcv/pull/1108
+ meta.update(epoch=self.epoch + 1, iter=self.iter)
+ # The file name uses iter + 1 while the stored meta records iter.
+ filename = filename_tmpl.format(self.iter + 1)
+ filepath = osp.join(out_dir, filename)
+ optimizer = self.optimizer if save_optimizer else None
+ save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
+ # in some environments, `os.symlink` is not supported, you may need to
+ # set `create_symlink` to False
+ if create_symlink:
+ dst_file = osp.join(out_dir, 'latest.pth')
+ # On Windows the checkpoint is copied instead of symlinked.
+ if platform.system() != 'Windows':
+ mmcv.symlink(filename, dst_file)
+ else:
+ shutil.copy(filepath, dst_file)
+ def register_training_hooks(self,
+ lr_config,
+ optimizer_config=None,
+ checkpoint_config=None,
+ log_config=None,
+ momentum_config=None,
+ custom_hooks_config=None):
+ """Register default hooks for iter-based training.
+ Checkpoint hook, optimizer stepper hook and logger hooks will be set to
+ `by_epoch=False` by default.
+ Default hooks include:
+ +----------------------+-------------------------+
+ | Hooks | Priority |
+ +======================+=========================+
+ | LrUpdaterHook | VERY_HIGH (10) |
+ +----------------------+-------------------------+
+ | MomentumUpdaterHook | HIGH (30) |
+ +----------------------+-------------------------+
+ | OptimizerStepperHook | ABOVE_NORMAL (40) |
+ +----------------------+-------------------------+
+ | CheckpointSaverHook | NORMAL (50) |
+ +----------------------+-------------------------+
+ | IterTimerHook | LOW (70) |
+ +----------------------+-------------------------+
+ | LoggerHook(s) | VERY_LOW (90) |
+ +----------------------+-------------------------+
+ | CustomHook(s) | defaults to NORMAL (50) |
+ +----------------------+-------------------------+
+ If custom hooks have same priority with default hooks, custom hooks
+ will be triggered after default hooks.
+ """
+ # setdefault only fills in by_epoch=False when the caller left it
+ # unset; the passed config dicts are modified in place.
+ if checkpoint_config is not None:
+ checkpoint_config.setdefault('by_epoch', False)
+ if lr_config is not None:
+ lr_config.setdefault('by_epoch', False)
+ if log_config is not None:
+ for info in log_config['hooks']:
+ info.setdefault('by_epoch', False)
+ super(IterBasedRunner, self).register_training_hooks(
+ lr_config=lr_config,
+ momentum_config=momentum_config,
+ optimizer_config=optimizer_config,
+ checkpoint_config=checkpoint_config,
+ log_config=log_config,
+ timer_config=IterTimerHook(),
+ custom_hooks_config=custom_hooks_config)
diff --git a/ControlNet/annotator/uniformer/mmcv/runner/log_buffer.py b/ControlNet/annotator/uniformer/mmcv/runner/log_buffer.py
new file mode 100644
index 0000000000000000000000000000000000000000..d949e2941c5400088c7cd8a1dc893d8b233ae785
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/runner/log_buffer.py
@@ -0,0 +1,41 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections import OrderedDict
+import numpy as np
class LogBuffer:
    """Collect per-key value histories and reduce them to weighted averages.

    ``val_history`` and ``n_history`` map each key to parallel lists of
    recorded values and their counts. ``output`` holds the averages computed
    by the latest :meth:`average` call and ``ready`` tells whether ``output``
    is populated.
    """

    def __init__(self):
        self.val_history = OrderedDict()
        self.n_history = OrderedDict()
        self.output = OrderedDict()
        self.ready = False

    def clear(self):
        """Forget every recorded value and any computed output."""
        self.val_history.clear()
        self.n_history.clear()
        self.clear_output()

    def clear_output(self):
        """Discard the computed averages and mark the buffer as not ready."""
        self.output.clear()
        self.ready = False

    def update(self, vars, count=1):
        """Append each value in ``vars`` to its history with weight ``count``.

        Args:
            vars (dict): Mapping from key to the value to record.
            count (int): Weight stored alongside every value of this call.
        """
        assert isinstance(vars, dict)
        for name, value in vars.items():
            is_new_key = name not in self.val_history
            if is_new_key:
                self.val_history[name] = []
                self.n_history[name] = []
            self.val_history[name].append(value)
            self.n_history[name].append(count)

    def average(self, n=0):
        """Average latest n values or all values."""
        assert n >= 0
        # slice(-0, None) covers the whole list, so n == 0 means "everything".
        window = slice(-n, None)
        for name in self.val_history:
            samples = np.array(self.val_history[name][window])
            weights = np.array(self.n_history[name][window])
            self.output[name] = np.sum(samples * weights) / np.sum(weights)
        self.ready = True
diff --git a/ControlNet/annotator/uniformer/mmcv/runner/optimizer/__init__.py b/ControlNet/annotator/uniformer/mmcv/runner/optimizer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..53c34d0470992cbc374f29681fdd00dc0e57968d
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/runner/optimizer/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .builder import (OPTIMIZER_BUILDERS, OPTIMIZERS, build_optimizer,
+ build_optimizer_constructor)
+from .default_constructor import DefaultOptimizerConstructor
# Names re-exported by the optimizer sub-package (imported just above).
__all__ = [
    'OPTIMIZER_BUILDERS', 'OPTIMIZERS', 'DefaultOptimizerConstructor',
    'build_optimizer', 'build_optimizer_constructor'
]
diff --git a/ControlNet/annotator/uniformer/mmcv/runner/optimizer/builder.py b/ControlNet/annotator/uniformer/mmcv/runner/optimizer/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9234eed8f1f186d9d8dfda34562157ee39bdb3a
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/runner/optimizer/builder.py
@@ -0,0 +1,44 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import inspect
+import torch
+from ...utils import Registry, build_from_cfg
+OPTIMIZERS = Registry('optimizer')  # registry that optimizer classes are registered into
+OPTIMIZER_BUILDERS = Registry('optimizer builder')  # registry used to build optimizer constructors
def register_torch_optimizers():
    """Register every optimizer class found in ``torch.optim``.

    Attributes whose name starts with ``__`` are ignored; the remaining ones
    are registered into ``OPTIMIZERS`` when they are classes derived from
    ``torch.optim.Optimizer``.

    Returns:
        list[str]: Attribute names of ``torch.optim`` that were registered.
    """
    registered = []
    for attr_name in dir(torch.optim):
        if attr_name.startswith('__'):
            continue
        candidate = getattr(torch.optim, attr_name)
        is_optimizer_cls = inspect.isclass(candidate) and issubclass(
            candidate, torch.optim.Optimizer)
        if not is_optimizer_cls:
            continue
        OPTIMIZERS.register_module()(candidate)
        registered.append(attr_name)
    return registered


# Populated once, at import time.
TORCH_OPTIMIZERS = register_torch_optimizers()
def build_optimizer_constructor(cfg):
    """Build an optimizer constructor from ``cfg`` using ``OPTIMIZER_BUILDERS``.

    Args:
        cfg (dict): Config passed to ``build_from_cfg``.

    Returns:
        Whatever ``build_from_cfg`` builds for ``cfg``.
    """
    constructor = build_from_cfg(cfg, OPTIMIZER_BUILDERS)
    return constructor
def build_optimizer(model, cfg):
    """Build an optimizer for ``model`` from ``cfg``.

    ``cfg`` is deep-copied, so the caller's config is left untouched. The
    optional ``constructor`` entry (default ``'DefaultOptimizerConstructor'``)
    selects the constructor type and the optional ``paramwise_cfg`` entry is
    forwarded to it; everything else becomes ``optimizer_cfg``.

    Args:
        model: Object handed to the built optimizer constructor.
        cfg (dict): Optimizer config.

    Returns:
        The result of calling the built constructor on ``model``.
    """
    settings = copy.deepcopy(cfg)
    constructor_type = settings.pop('constructor',
                                    'DefaultOptimizerConstructor')
    paramwise_cfg = settings.pop('paramwise_cfg', None)
    constructor_cfg = dict(
        type=constructor_type,
        optimizer_cfg=settings,
        paramwise_cfg=paramwise_cfg)
    return build_optimizer_constructor(constructor_cfg)(model)
diff --git a/ControlNet/annotator/uniformer/mmcv/runner/optimizer/default_constructor.py b/ControlNet/annotator/uniformer/mmcv/runner/optimizer/default_constructor.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c0da3503b75441738efe38d70352b55a210a34a
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/runner/optimizer/default_constructor.py
@@ -0,0 +1,249 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+import torch
+from torch.nn import GroupNorm, LayerNorm
+from annotator.uniformer.mmcv.utils import _BatchNorm, _InstanceNorm, build_from_cfg, is_list_of
+from annotator.uniformer.mmcv.utils.ext_loader import check_ops_exist
class DefaultOptimizerConstructor:
    """Default constructor for optimizers.

    By default each parameter share the same optimizer settings, and we
    provide an argument ``paramwise_cfg`` to specify parameter-wise settings.
    It is a dict and may contain the following fields:

    - ``custom_keys`` (dict): Specified parameters-wise settings by keys. If
      one of the keys in ``custom_keys`` is a substring of the name of one
      parameter, then the setting of the parameter will be specified by
      ``custom_keys[key]`` and other setting like ``bias_lr_mult`` etc. will
      be ignored. It should be noted that the aforementioned ``key`` is the
      longest key that is a substring of the name of the parameter. If there
      are multiple matched keys with the same length, then the key with lower
      alphabet order will be chosen.
      ``custom_keys[key]`` should be a dict and may contain fields ``lr_mult``
      and ``decay_mult``. See Example 2 below.
    - ``bias_lr_mult`` (float): It will be multiplied to the learning
      rate for all bias parameters (except for those in normalization
      layers and offset layers of DCN).
    - ``bias_decay_mult`` (float): It will be multiplied to the weight
      decay for all bias parameters (except for those in
      normalization layers, depthwise conv layers, offset layers of DCN).
    - ``norm_decay_mult`` (float): It will be multiplied to the weight
      decay for all weight and bias parameters of normalization
      layers.
    - ``dwconv_decay_mult`` (float): It will be multiplied to the weight
      decay for all weight and bias parameters of depthwise conv
      layers.
    - ``dcn_offset_lr_mult`` (float): It will be multiplied to the learning
      rate for parameters of offset layer in the deformable convs
      of a model.
    - ``bypass_duplicate`` (bool): If true, the duplicate parameters
      would not be added into optimizer. Default: False.

    Note:
        1. If the option ``dcn_offset_lr_mult`` is used, the constructor will
        override the effect of ``bias_lr_mult`` in the bias of offset
        layer. So be careful when using both ``bias_lr_mult`` and
        ``dcn_offset_lr_mult``. If you wish to apply both of them to the
        offset layer in deformable convs, set ``dcn_offset_lr_mult``
        to the original ``dcn_offset_lr_mult`` * ``bias_lr_mult``.
        2. If the option ``dcn_offset_lr_mult`` is used, the constructor will
        apply it to all the DCN layers in the model. So be careful when
        the model contains multiple DCN layers in places other than
        backbone.

    Args:
        optimizer_cfg (dict): The config dict of the optimizer.
            Positional fields are

                - `type`: class name of the optimizer.

            Optional fields are

                - any arguments of the corresponding optimizer type, e.g.,
                  lr, weight_decay, momentum, etc.
        paramwise_cfg (dict, optional): Parameter-wise options.

    The model (:obj:`nn.Module`) whose parameters are optimized is passed
    when the constructed object is called.

    Example 1:
        >>> model = torch.nn.modules.Conv1d(1, 1, 1)
        >>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9,
        >>>                      weight_decay=0.0001)
        >>> paramwise_cfg = dict(norm_decay_mult=0.)
        >>> optim_builder = DefaultOptimizerConstructor(
        >>>     optimizer_cfg, paramwise_cfg)
        >>> optimizer = optim_builder(model)

    Example 2:
        >>> # assume model have attribute model.backbone and model.cls_head
        >>> optimizer_cfg = dict(type='SGD', lr=0.01, weight_decay=0.95)
        >>> paramwise_cfg = dict(custom_keys={
        >>>     '.backbone': dict(lr_mult=0.1, decay_mult=0.9)})
        >>> optim_builder = DefaultOptimizerConstructor(
        >>>     optimizer_cfg, paramwise_cfg)
        >>> optimizer = optim_builder(model)
        >>> # Then the `lr` and `weight_decay` for model.backbone is
        >>> # (0.01 * 0.1, 0.95 * 0.9). `lr` and `weight_decay` for
        >>> # model.cls_head is (0.01, 0.95).
    """

    def __init__(self, optimizer_cfg, paramwise_cfg=None):
        if not isinstance(optimizer_cfg, dict):
            # One concatenated message: a comma here would turn the two
            # literals into separate exception arguments.
            raise TypeError('optimizer_cfg should be a dict, '
                            f'but got {type(optimizer_cfg)}')
        self.optimizer_cfg = optimizer_cfg
        self.paramwise_cfg = {} if paramwise_cfg is None else paramwise_cfg
        self.base_lr = optimizer_cfg.get('lr', None)
        self.base_wd = optimizer_cfg.get('weight_decay', None)
        self._validate_cfg()

    def _validate_cfg(self):
        """Check ``paramwise_cfg`` for type errors and missing weight decay.

        Raises:
            TypeError: If ``paramwise_cfg`` or its ``custom_keys`` entry is
                not a dict.
            ValueError: If a decay multiplier is configured while the
                optimizer config has no ``weight_decay``.
        """
        if not isinstance(self.paramwise_cfg, dict):
            raise TypeError('paramwise_cfg should be None or a dict, '
                            f'but got {type(self.paramwise_cfg)}')

        if 'custom_keys' in self.paramwise_cfg:
            if not isinstance(self.paramwise_cfg['custom_keys'], dict):
                raise TypeError(
                    'If specified, custom_keys must be a dict, '
                    f'but got {type(self.paramwise_cfg["custom_keys"])}')
            if self.base_wd is None:
                for key in self.paramwise_cfg['custom_keys']:
                    if 'decay_mult' in self.paramwise_cfg['custom_keys'][key]:
                        raise ValueError('base_wd should not be None')

        # get base lr and weight decay
        # weight_decay must be explicitly specified if mult is specified
        if ('bias_decay_mult' in self.paramwise_cfg
                or 'norm_decay_mult' in self.paramwise_cfg
                or 'dwconv_decay_mult' in self.paramwise_cfg):
            if self.base_wd is None:
                raise ValueError('base_wd should not be None')

    def _is_in(self, param_group, param_group_list):
        """Return True if ``param_group`` shares a parameter with any group
        in ``param_group_list``."""
        assert is_list_of(param_group_list, dict)
        param = set(param_group['params'])
        param_set = set()
        for group in param_group_list:
            param_set.update(set(group['params']))

        return not param.isdisjoint(param_set)

    def add_params(self, params, module, prefix='', is_dcn_module=None):
        """Add all parameters of module to the params list.

        The parameters of the given module will be added to the list of param
        groups, with specific rules defined by paramwise_cfg.

        Args:
            params (list[dict]): A list of param groups, it will be modified
                in place.
            module (nn.Module): The module to be added.
            prefix (str): The prefix of the module
            is_dcn_module (int|float|None): If the current module is a
                submodule of DCN, `is_dcn_module` will be passed to
                control conv_offset layer's learning rate. Defaults to None.
        """
        # get param-wise options
        custom_keys = self.paramwise_cfg.get('custom_keys', {})
        # first sort with alphabet order and then sort with reversed len of str
        sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)

        bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', 1.)
        bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', 1.)
        norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', 1.)
        dwconv_decay_mult = self.paramwise_cfg.get('dwconv_decay_mult', 1.)
        bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False)
        dcn_offset_lr_mult = self.paramwise_cfg.get('dcn_offset_lr_mult', 1.)

        # special rules for norm layers and depth-wise conv layers
        is_norm = isinstance(module,
                             (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))
        is_dwconv = (
            isinstance(module, torch.nn.Conv2d)
            and module.in_channels == module.groups)

        for name, param in module.named_parameters(recurse=False):
            param_group = {'params': [param]}
            if not param.requires_grad:
                # Parameters that do not require grad are still added, but
                # without any per-parameter overrides.
                params.append(param_group)
                continue
            if bypass_duplicate and self._is_in(param_group, params):
                warnings.warn(f'{prefix} is duplicate. It is skipped since '
                              f'bypass_duplicate={bypass_duplicate}')
                continue
            # if the parameter match one of the custom keys, ignore other rules
            is_custom = False
            for key in sorted_keys:
                if key in f'{prefix}.{name}':
                    is_custom = True
                    lr_mult = custom_keys[key].get('lr_mult', 1.)
                    param_group['lr'] = self.base_lr * lr_mult
                    if self.base_wd is not None:
                        decay_mult = custom_keys[key].get('decay_mult', 1.)
                        param_group['weight_decay'] = self.base_wd * decay_mult
                    # Only the first (longest) matching key applies.
                    break

            if not is_custom:
                # bias_lr_mult affects all bias parameters
                # except for norm.bias dcn.conv_offset.bias
                if name == 'bias' and not (is_norm or is_dcn_module):
                    param_group['lr'] = self.base_lr * bias_lr_mult

                if (prefix.find('conv_offset') != -1 and is_dcn_module
                        and isinstance(module, torch.nn.Conv2d)):
                    # deal with both dcn_offset's bias & weight
                    param_group['lr'] = self.base_lr * dcn_offset_lr_mult

                # apply weight decay policies
                if self.base_wd is not None:
                    # norm decay
                    if is_norm:
                        param_group[
                            'weight_decay'] = self.base_wd * norm_decay_mult
                    # depth-wise conv
                    elif is_dwconv:
                        param_group[
                            'weight_decay'] = self.base_wd * dwconv_decay_mult
                    # bias lr and decay
                    elif name == 'bias' and not is_dcn_module:
                        # TODO: current bias_decay_mult will have affect on DCN
                        param_group[
                            'weight_decay'] = self.base_wd * bias_decay_mult
            params.append(param_group)

        # Tell the children whether *this* module is a deformable conv, so
        # that its conv_offset child can be recognised above.
        if check_ops_exist():
            from annotator.uniformer.mmcv.ops import DeformConv2d, ModulatedDeformConv2d
            is_dcn_module = isinstance(module,
                                       (DeformConv2d, ModulatedDeformConv2d))
        else:
            is_dcn_module = False
        for child_name, child_mod in module.named_children():
            child_prefix = f'{prefix}.{child_name}' if prefix else child_name
            self.add_params(
                params,
                child_mod,
                prefix=child_prefix,
                is_dcn_module=is_dcn_module)

    def __call__(self, model):
        """Build the optimizer for ``model``.

        If ``model`` has a ``module`` attribute, that attribute is used
        instead of ``model`` itself.
        """
        if hasattr(model, 'module'):
            model = model.module

        # NOTE(review): ``OPTIMIZERS`` is not imported in the visible part of
        # this module - confirm it is in scope before relying on this method.
        optimizer_cfg = self.optimizer_cfg.copy()
        # if no paramwise option is specified, just use the global setting
        if not self.paramwise_cfg:
            optimizer_cfg['params'] = model.parameters()
            return build_from_cfg(optimizer_cfg, OPTIMIZERS)

        # set param-wise lr and weight decay recursively
        params = []
        self.add_params(params, model)
        optimizer_cfg['params'] = params

        return build_from_cfg(optimizer_cfg, OPTIMIZERS)
diff --git a/ControlNet/annotator/uniformer/mmcv/runner/priority.py b/ControlNet/annotator/uniformer/mmcv/runner/priority.py
new file mode 100644
index 0000000000000000000000000000000000000000..64cc4e3a05f8d5b89ab6eb32461e6e80f1d62e67
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/runner/priority.py
@@ -0,0 +1,60 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from enum import Enum
class Priority(Enum):
    """Hook priority levels.

    +--------------+------------+
    | Level        | Value      |
    +==============+============+
    | HIGHEST      | 0          |
    +--------------+------------+
    | VERY_HIGH    | 10         |
    +--------------+------------+
    | HIGH         | 30         |
    +--------------+------------+
    | ABOVE_NORMAL | 40         |
    +--------------+------------+
    | NORMAL       | 50         |
    +--------------+------------+
    | BELOW_NORMAL | 60         |
    +--------------+------------+
    | LOW          | 70         |
    +--------------+------------+
    | VERY_LOW     | 90         |
    +--------------+------------+
    | LOWEST       | 100        |
    +--------------+------------+
    """

    # Every level listed in the table above is defined, so looking a level
    # up by name (e.g. ``Priority['ABOVE_NORMAL']``) cannot fail for a
    # documented name.
    HIGHEST = 0
    VERY_HIGH = 10
    HIGH = 30
    ABOVE_NORMAL = 40
    NORMAL = 50
    BELOW_NORMAL = 60
    LOW = 70
    VERY_LOW = 90
    LOWEST = 100
def get_priority(priority):
    """Convert a priority given in any supported form to its integer value.

    Args:
        priority (int or str or :obj:`Priority`): An integer in [0, 100], a
            :obj:`Priority` member, or the (case-insensitive) name of one.

    Returns:
        int: The priority value.

    Raises:
        ValueError: If an integer priority lies outside [0, 100].
        TypeError: If ``priority`` is none of the supported types.
    """
    if isinstance(priority, int):
        if not 0 <= priority <= 100:
            raise ValueError('priority must be between 0 and 100')
        return priority
    if isinstance(priority, Priority):
        return priority.value
    if isinstance(priority, str):
        member = Priority[priority.upper()]
        return member.value
    raise TypeError('priority must be an integer or Priority enum value')
diff --git a/ControlNet/annotator/uniformer/mmcv/runner/utils.py b/ControlNet/annotator/uniformer/mmcv/runner/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5befb8e56ece50b5fecfd007b26f8a29124c0bd
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/runner/utils.py
@@ -0,0 +1,93 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import random
+import sys
+import time
+import warnings
+from getpass import getuser
+from socket import gethostname
+import numpy as np
+import torch
+import annotator.uniformer.mmcv as mmcv
def get_host_info():
    """Get hostname and username.

    Return empty string if exception raised, e.g. ``getpass.getuser()`` will
    lead to error in docker container

    Returns:
        str: ``'<user>@<host>'``, or ``''`` when either lookup fails.
    """
    try:
        return f'{getuser()}@{gethostname()}'
    except Exception as e:
        # Best effort: warn and fall back to an empty string. The return is
        # deliberately not placed in a ``finally`` clause, which would also
        # swallow KeyboardInterrupt / SystemExit.
        warnings.warn(f'Host or user not found: {str(e)}')
        return ''
def get_time_str():
    """Return the current local time formatted as ``YYYYmmdd_HHMMSS``."""
    now = time.localtime()
    return time.strftime('%Y%m%d_%H%M%S', now)
def obj_from_dict(info, parent=None, default_args=None):
    """Initialize an object from dict.

    ``info['type']`` selects what to call; every other entry of ``info`` is
    passed to it as a keyword argument. A string type is looked up as an
    attribute of ``parent`` when ``parent`` is given, otherwise in
    ``sys.modules``.

    Args:
        info (dict): Object types and arguments.
        parent (:class:`module`): Module which may containing expected object
            classes.
        default_args (dict, optional): Default arguments for initializing the
            object; they never override entries already present in ``info``.

    Returns:
        any type: Object built from the dict.

    Raises:
        TypeError: If ``info['type']`` is neither a string nor a type.
    """
    assert isinstance(info, dict) and 'type' in info
    assert isinstance(default_args, dict) or default_args is None

    kwargs = info.copy()
    obj_type = kwargs.pop('type')
    if mmcv.is_str(obj_type):
        obj_type = (
            getattr(parent, obj_type)
            if parent is not None else sys.modules[obj_type])
    elif not isinstance(obj_type, type):
        raise TypeError('type must be a str or valid type, but '
                        f'got {type(obj_type)}')

    if default_args is not None:
        for arg_name in default_args:
            kwargs.setdefault(arg_name, default_args[arg_name])
    return obj_type(**kwargs)
def set_random_seed(seed, deterministic=False, use_rank_shift=False):
    """Set random seed.

    Seeds ``random``, ``numpy.random`` and ``torch`` (CPU and CUDA) and
    stores the seed in the ``PYTHONHASHSEED`` environment variable.

    Args:
        seed (int): Seed to be used.
        deterministic (bool): Whether to set the deterministic option for
            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
            to True and `torch.backends.cudnn.benchmark` to False.
            Default: False.
        use_rank_shift (bool): Whether to add the rank returned by
            ``mmcv.runner.get_dist_info()`` to the seed. Default: False.
    """
    if use_rank_shift:
        rank, _ = mmcv.runner.get_dist_info()
        seed += rank

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
diff --git a/ControlNet/annotator/uniformer/mmcv/utils/__init__.py b/ControlNet/annotator/uniformer/mmcv/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..378a0068432a371af364de9d73785901c0f83383
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/utils/__init__.py
@@ -0,0 +1,69 @@
+# flake8: noqa
+# Copyright (c) OpenMMLab. All rights reserved.
+from .config import Config, ConfigDict, DictAction
+from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
+ has_method, import_modules_from_strings, is_list_of,
+ is_method_overridden, is_seq_of, is_str, is_tuple_of,
+ iter_cast, list_cast, requires_executable, requires_package,
+ slice_list, to_1tuple, to_2tuple, to_3tuple, to_4tuple,
+ to_ntuple, tuple_cast)
+from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist,
+ scandir, symlink)
+from .progressbar import (ProgressBar, track_iter_progress,
+ track_parallel_progress, track_progress)
+from .testing import (assert_attrs_equal, assert_dict_contains_subset,
+ assert_dict_has_keys, assert_is_norm_layer,
+ assert_keys_equal, assert_params_all_zeros,
+ check_python_script)
+from .timer import Timer, TimerError, check_time
+from .version_utils import digit_version, get_git_hash
# The torch-dependent helpers are exported only when torch can be imported;
# otherwise the package exposes the torch-free subset.
try:
    import torch
except ImportError:
    __all__ = [
        'Config', 'ConfigDict', 'DictAction', 'is_str', 'iter_cast',
        'list_cast', 'tuple_cast', 'is_seq_of', 'is_list_of', 'is_tuple_of',
        'slice_list', 'concat_list', 'check_prerequisites', 'requires_package',
        'requires_executable', 'is_filepath', 'fopen', 'check_file_exist',
        'mkdir_or_exist', 'symlink', 'scandir', 'ProgressBar',
        'track_progress', 'track_iter_progress', 'track_parallel_progress',
        'Timer', 'TimerError', 'check_time', 'deprecated_api_warning',
        'digit_version', 'get_git_hash', 'import_modules_from_strings',
        'assert_dict_contains_subset', 'assert_attrs_equal',
        'assert_dict_has_keys', 'assert_keys_equal', 'check_python_script',
        'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple',
        'is_method_overridden', 'has_method'
    ]
else:
    from .env import collect_env
    from .logging import get_logger, print_log
    from .parrots_jit import jit, skip_no_elena
    from .parrots_wrapper import (
        TORCH_VERSION, BuildExtension, CppExtension, CUDAExtension, DataLoader,
        PoolDataLoader, SyncBatchNorm, _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd,
        _AvgPoolNd, _BatchNorm, _ConvNd, _ConvTransposeMixin, _InstanceNorm,
        _MaxPoolNd, get_build_config, is_rocm_pytorch, _get_cuda_home)
    from .registry import Registry, build_from_cfg
    from .trace import is_jit_tracing
    __all__ = [
        'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger',
        'print_log', 'is_str', 'iter_cast', 'list_cast', 'tuple_cast',
        'is_seq_of', 'is_list_of', 'is_tuple_of', 'slice_list', 'concat_list',
        'check_prerequisites', 'requires_package', 'requires_executable',
        'is_filepath', 'fopen', 'check_file_exist', 'mkdir_or_exist',
        'symlink', 'scandir', 'ProgressBar', 'track_progress',
        'track_iter_progress', 'track_parallel_progress', 'Registry',
        'build_from_cfg', 'Timer', 'TimerError', 'check_time', 'SyncBatchNorm',
        '_AdaptiveAvgPoolNd', '_AdaptiveMaxPoolNd', '_AvgPoolNd', '_BatchNorm',
        '_ConvNd', '_ConvTransposeMixin', '_InstanceNorm', '_MaxPoolNd',
        'get_build_config', 'BuildExtension', 'CppExtension', 'CUDAExtension',
        'DataLoader', 'PoolDataLoader', 'TORCH_VERSION',
        'deprecated_api_warning', 'digit_version', 'get_git_hash',
        'import_modules_from_strings', 'jit', 'skip_no_elena',
        'assert_dict_contains_subset', 'assert_attrs_equal',
        'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer',
        'assert_params_all_zeros', 'check_python_script',
        'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch',
        '_get_cuda_home', 'has_method'
    ]
diff --git a/ControlNet/annotator/uniformer/mmcv/utils/config.py b/ControlNet/annotator/uniformer/mmcv/utils/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..17149353aefac6d737c67bb2f35a3a6cd2147b0a
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/utils/config.py
@@ -0,0 +1,688 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import ast
+import copy
+import os
+import os.path as osp
+import platform
+import shutil
+import sys
+import tempfile
+import uuid
+import warnings
+from argparse import Action, ArgumentParser
+from collections import abc
+from importlib import import_module
+from addict import Dict
+from yapf.yapflib.yapf_api import FormatCode
+from .misc import import_modules_from_strings
+from .path import check_file_exist
# Pick the regex engine once at import time: the ``regex`` package on
# Windows, the standard-library ``re`` everywhere else. Both are bound to
# the name ``re`` so the rest of the module is platform-agnostic.
if platform.system() == 'Windows':
    import regex as re
else:
    import re
+BASE_KEY = '_base_'  # config entry naming the base config file(s) to inherit from
+DELETE_KEY = '_delete_'  # when truthy in a child dict, the base value is replaced instead of merged
+DEPRECATION_KEY = '_deprecation_'  # config entry that triggers a deprecation warning on load
+RESERVED_KEYS = ['filename', 'text', 'pretty_text']  # NOTE(review): presumably names a config may not use as keys - confirm at the use site
class ConfigDict(Dict):
    """``Dict`` subclass that raises for missing keys and attributes.

    Missing keys raise ``KeyError``; a ``KeyError`` during attribute access
    is reported as ``AttributeError``.
    """

    def __missing__(self, name):
        raise KeyError(name)

    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except KeyError:
            failure = AttributeError(
                f"'{self.__class__.__name__}' object has no "
                f"attribute '{name}'")
        except Exception as e:
            failure = e
        # Raised outside the handlers so the original exception is not
        # attached as implicit context.
        raise failure
def add_args(parser, cfg, prefix=''):
    """Add one command-line option per entry of ``cfg`` to ``parser``.

    Nested dicts are flattened with dotted names (``--outer.inner``).

    Args:
        parser (argparse.ArgumentParser): Parser to extend in place.
        cfg (dict): Mapping whose values decide the option type: ``bool``
            becomes a ``store_true`` flag, ``str``/``int``/``float`` a typed
            option, ``dict`` recurses, and any other iterable a ``nargs='+'``
            option typed after its first element.
        prefix (str): Prefix prepended to every option name.

    Returns:
        argparse.ArgumentParser: The same ``parser``.
    """
    for k, v in cfg.items():
        option = '--' + prefix + k
        # bool must be tested before int: bool is a subclass of int, so the
        # int branch would otherwise claim every boolean value.
        if isinstance(v, bool):
            parser.add_argument(option, action='store_true')
        elif isinstance(v, str):
            parser.add_argument(option)
        elif isinstance(v, int):
            parser.add_argument(option, type=int)
        elif isinstance(v, float):
            parser.add_argument(option, type=float)
        elif isinstance(v, dict):
            add_args(parser, v, prefix + k + '.')
        elif isinstance(v, abc.Iterable):
            # An empty iterable has no element to infer a type from; leave
            # the values as strings instead of failing on ``v[0]``.
            items = list(v)
            elem_type = type(items[0]) if items else None
            parser.add_argument(option, type=elem_type, nargs='+')
        else:
            print(f'cannot parse key {prefix + k} of type {type(v)}')
    return parser
+class Config:
+ """A facility for config and config files.
+ It supports common file formats as configs: python/json/yaml. The interface
+ is the same as a dict object and also allows access config values as
+ attributes.
+ Example:
+ >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
+ >>> cfg.a
+ 1
+ >>> cfg.b
+ {'b1': [0, 1]}
+ >>> cfg.b.b1
+ [0, 1]
+ >>> cfg = Config.fromfile('tests/data/config/a.py')
+ >>> cfg.filename
+ "/home/kchen/projects/mmcv/tests/data/config/a.py"
+ >>> cfg.item4
+ 'test'
+ >>> cfg
+ "Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: "
+ "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
+ """
+ @staticmethod
+ def _validate_py_syntax(filename):
+ with open(filename, 'r', encoding='utf-8') as f:
+ # Setting encoding explicitly to resolve coding issue on windows
+ content = f.read()
+ try:
+ ast.parse(content)
+ except SyntaxError as e:
+ raise SyntaxError('There are syntax errors in config '
+ f'file {filename}: {e}')
+ @staticmethod
+ def _substitute_predefined_vars(filename, temp_config_name):
+ file_dirname = osp.dirname(filename)
+ file_basename = osp.basename(filename)
+ file_basename_no_extension = osp.splitext(file_basename)[0]
+ file_extname = osp.splitext(filename)[1]
+ support_templates = dict(
+ fileDirname=file_dirname,
+ fileBasename=file_basename,
+ fileBasenameNoExtension=file_basename_no_extension,
+ fileExtname=file_extname)
+ with open(filename, 'r', encoding='utf-8') as f:
+ # Setting encoding explicitly to resolve coding issue on windows
+ config_file = f.read()
+ for key, value in support_templates.items():
+ regexp = r'\{\{\s*' + str(key) + r'\s*\}\}'
+ value = value.replace('\\', '/')
+ config_file = re.sub(regexp, value, config_file)
+ with open(temp_config_name, 'w', encoding='utf-8') as tmp_config_file:
+ tmp_config_file.write(config_file)
+ @staticmethod
+ def _pre_substitute_base_vars(filename, temp_config_name):
+ """Substitute base variable placehoders to string, so that parsing
+ would work."""
+ with open(filename, 'r', encoding='utf-8') as f:
+ # Setting encoding explicitly to resolve coding issue on windows
+ config_file = f.read()
+ base_var_dict = {}
+ regexp = r'\{\{\s*' + BASE_KEY + r'\.([\w\.]+)\s*\}\}'
+ base_vars = set(re.findall(regexp, config_file))
+ for base_var in base_vars:
+ randstr = f'_{base_var}_{uuid.uuid4().hex.lower()[:6]}'
+ base_var_dict[randstr] = base_var
+ regexp = r'\{\{\s*' + BASE_KEY + r'\.' + base_var + r'\s*\}\}'
+ config_file = re.sub(regexp, f'"{randstr}"', config_file)
+ with open(temp_config_name, 'w', encoding='utf-8') as tmp_config_file:
+ tmp_config_file.write(config_file)
+ return base_var_dict
+ @staticmethod
+ def _substitute_base_vars(cfg, base_var_dict, base_cfg):
+ """Substitute variable strings to their actual values."""
+ cfg = copy.deepcopy(cfg)
+ if isinstance(cfg, dict):
+ for k, v in cfg.items():
+ if isinstance(v, str) and v in base_var_dict:
+ new_v = base_cfg
+ for new_k in base_var_dict[v].split('.'):
+ new_v = new_v[new_k]
+ cfg[k] = new_v
+ elif isinstance(v, (list, tuple, dict)):
+ cfg[k] = Config._substitute_base_vars(
+ v, base_var_dict, base_cfg)
+ elif isinstance(cfg, tuple):
+ cfg = tuple(
+ Config._substitute_base_vars(c, base_var_dict, base_cfg)
+ for c in cfg)
+ elif isinstance(cfg, list):
+ cfg = [
+ Config._substitute_base_vars(c, base_var_dict, base_cfg)
+ for c in cfg
+ ]
+ elif isinstance(cfg, str) and cfg in base_var_dict:
+ new_v = base_cfg
+ for new_k in base_var_dict[cfg].split('.'):
+ new_v = new_v[new_k]
+ cfg = new_v
+ return cfg
+ @staticmethod
+ def _file2dict(filename, use_predefined_variables=True):  # returns (cfg_dict, cfg_text) for a py/json/yaml config file
+ filename = osp.abspath(osp.expanduser(filename))  # expand ~ and make the path absolute
+ check_file_exist(filename)
+ fileExtname = osp.splitext(filename)[1]
+ if fileExtname not in ['.py', '.json', '.yaml', '.yml']:
+ raise IOError('Only py/yml/yaml/json type are supported now!')
+ with tempfile.TemporaryDirectory() as temp_config_dir:
+ temp_config_file = tempfile.NamedTemporaryFile(
+ dir=temp_config_dir, suffix=fileExtname)
+ if platform.system() == 'Windows':
+ temp_config_file.close()
+ temp_config_name = osp.basename(temp_config_file.name)
+ # Substitute predefined variables
+ if use_predefined_variables:
+ Config._substitute_predefined_vars(filename,
+ temp_config_file.name)
+ else:
+ shutil.copyfile(filename, temp_config_file.name)
+ # Substitute base variables from placeholders to strings
+ base_var_dict = Config._pre_substitute_base_vars(
+ temp_config_file.name, temp_config_file.name)  # same path for input and output: rewritten in place
+ if filename.endswith('.py'):
+ temp_module_name = osp.splitext(temp_config_name)[0]
+ sys.path.insert(0, temp_config_dir)  # make the temp copy importable as a module
+ Config._validate_py_syntax(filename)
+ mod = import_module(temp_module_name)
+ sys.path.pop(0)  # undo the sys.path change made for the import
+ cfg_dict = {
+ name: value
+ for name, value in mod.__dict__.items()
+ if not name.startswith('__')
+ }
+ # delete imported module
+ del sys.modules[temp_module_name]
+ elif filename.endswith(('.yml', '.yaml', '.json')):
+ import annotator.uniformer.mmcv as mmcv
+ cfg_dict = mmcv.load(temp_config_file.name)
+ # close temp file
+ temp_config_file.close()
+ # check deprecation information
+ if DEPRECATION_KEY in cfg_dict:
+ deprecation_info = cfg_dict.pop(DEPRECATION_KEY)
+ warning_msg = f'The config file {filename} will be deprecated ' \
+ 'in the future.'
+ if 'expected' in deprecation_info:
+ warning_msg += f' Please use {deprecation_info["expected"]} ' \
+ 'instead.'
+ if 'reference' in deprecation_info:
+ warning_msg += ' More information can be found at ' \
+ f'{deprecation_info["reference"]}'
+ warnings.warn(warning_msg)
+ cfg_text = filename + '\n'  # cfg_text starts with the file path, followed by the raw file content
+ with open(filename, 'r', encoding='utf-8') as f:
+ # Setting encoding explicitly to resolve coding issue on windows
+ cfg_text += f.read()
+ if BASE_KEY in cfg_dict:
+ cfg_dir = osp.dirname(filename)  # base paths are resolved relative to this config's directory
+ base_filename = cfg_dict.pop(BASE_KEY)
+ base_filename = base_filename if isinstance(
+ base_filename, list) else [base_filename]
+ cfg_dict_list = list()
+ cfg_text_list = list()
+ for f in base_filename:
+ _cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f))  # recursive load; use_predefined_variables is not forwarded here
+ cfg_dict_list.append(_cfg_dict)
+ cfg_text_list.append(_cfg_text)
+ base_cfg_dict = dict()
+ for c in cfg_dict_list:
+ duplicate_keys = base_cfg_dict.keys() & c.keys()  # keys defined by more than one base
+ if len(duplicate_keys) > 0:
+ raise KeyError('Duplicate key is not allowed among bases. '
+ f'Duplicate keys: {duplicate_keys}')
+ base_cfg_dict.update(c)
+ # Substitute base variables from strings to their actual values
+ cfg_dict = Config._substitute_base_vars(cfg_dict, base_var_dict,
+ base_cfg_dict)
+ base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)  # this file's values override the merged bases
+ cfg_dict = base_cfg_dict
+ # merge cfg_text
+ cfg_text_list.append(cfg_text)
+ cfg_text = '\n'.join(cfg_text_list)
+ return cfg_dict, cfg_text
@staticmethod
def _merge_a_into_b(a, b, allow_list_keys=False):
    """Return a shallow copy of ``b`` updated with the contents of ``a``.

    Values from ``a`` win over values from ``b``. Nested dicts are merged
    recursively unless the dict in ``a`` carries a truthy ``DELETE_KEY``
    entry, in which case it replaces the value in ``b`` outright. Note that
    the ``DELETE_KEY`` entry is popped from the nested dict of ``a``.

    Args:
        a (dict): The source dict whose entries are merged into ``b``.
        b (dict | list): The base container. It is copied first, so the
            object passed in is not modified at the top level.
        allow_list_keys (bool): If True, digit-string keys (e.g. '0', '1')
            in ``a`` address the element at that index when ``b`` is a
            list. Default: False.

    Returns:
        dict | list: The merged copy of ``b``.

    Raises:
        KeyError: If a digit key points past the end of list ``b``.
        TypeError: If ``a`` holds a dict for a key whose value in ``b`` is
            not a dict (or a list, when ``allow_list_keys`` is True).

    Examples:
        # Normally merge a into b.
        >>> Config._merge_a_into_b(
        ...     dict(obj=dict(a=2)), dict(obj=dict(a=1)))
        {'obj': {'a': 2}}

        # Delete b first and merge a into b.
        >>> Config._merge_a_into_b(
        ...     dict(obj=dict(_delete_=True, a=2)), dict(obj=dict(a=1)))
        {'obj': {'a': 2}}

        # b is a list
        >>> Config._merge_a_into_b(
        ...     {'0': dict(a=2)}, [dict(a=1), dict(b=2)], True)
        [{'a': 2}, {'b': 2}]
    """
    merged = b.copy()
    for key, value in a.items():
        # Case 1: a digit key indexing into a list base.
        if allow_list_keys and key.isdigit() and isinstance(merged, list):
            index = int(key)
            if index >= len(merged):
                raise KeyError(
                    f'Index {index} exceeds the length of list {merged}')
            merged[index] = Config._merge_a_into_b(value, merged[index],
                                                   allow_list_keys)
            continue

        # Case 2: plain overwrite. The pop() only runs for dict values whose
        # key already exists in the base, exactly as the short-circuit
        # order below dictates.
        inherits = (isinstance(value, dict) and key in merged
                    and not value.pop(DELETE_KEY, False))
        if not inherits:
            merged[key] = value
            continue

        # Case 3: recursive merge of a nested dict into the base value.
        mergeable_types = (dict, list) if allow_list_keys else dict
        if not isinstance(merged[key], mergeable_types):
            raise TypeError(
                f'{key}={value} in child config cannot inherit from base '
                f'because {key} is a dict in the child config but is of '
                f'type {type(merged[key])} in base config. You may set '
                f'`{DELETE_KEY}=True` to ignore the base config')
        merged[key] = Config._merge_a_into_b(value, merged[key],
                                             allow_list_keys)
    return merged
@staticmethod
def fromfile(filename,
             use_predefined_variables=True,
             import_custom_modules=True):
    """Build a ``Config`` from a config file.

    Args:
        filename (str): Path of the config file, forwarded to
            ``Config._file2dict``.
        use_predefined_variables (bool): Forwarded to
            ``Config._file2dict``. Default: True.
        import_custom_modules (bool): If True and the parsed config has a
            non-empty ``custom_imports`` entry, that entry is unpacked as
            keyword arguments to ``import_modules_from_strings``.
            Default: True.

    Returns:
        Config: The config object built from the parsed dict and text.
    """
    parsed = Config._file2dict(filename, use_predefined_variables)
    cfg_dict, cfg_text = parsed
    if import_custom_modules:
        custom_imports = cfg_dict.get('custom_imports', None)
        if custom_imports:
            import_modules_from_strings(**cfg_dict['custom_imports'])
    return Config(cfg_dict, cfg_text=cfg_text, filename=filename)
@staticmethod
def fromstring(cfg_str, file_format):
    """Generate config from config str.

    The string is written to a temporary file with the given suffix and
    then parsed through ``Config.fromfile``. The temporary file is always
    removed afterwards, including when writing or parsing fails.

    Args:
        cfg_str (str): Config str.
        file_format (str): Config file format corresponding to the
            config str. Only py/yml/yaml/json type are supported now!

    Returns:
        obj:`Config`: Config obj.

    Raises:
        IOError: If ``file_format`` is not one of the supported suffixes.
    """
    if file_format not in ['.py', '.json', '.yaml', '.yml']:
        raise IOError('Only py/yml/yaml/json type are supported now!')
    if file_format != '.py' and 'dict(' in cfg_str:
        # check if users specify a wrong suffix for python
        warnings.warn(
            'Please check "file_format", the file format may be .py')
    temp_file = tempfile.NamedTemporaryFile(
        'w', encoding='utf-8', suffix=file_format, delete=False)
    try:
        # The file must be closed before it is parsed: on windows,
        # reading it while still open causes an error (see PR 1077).
        with temp_file:
            temp_file.write(cfg_str)
        return Config.fromfile(temp_file.name)
    finally:
        # delete=False means nobody else cleans this up; do it even when
        # writing or parsing raised.
        os.remove(temp_file.name)
@staticmethod
def auto_argparser(description=None):
    """Generate argparser from config file automatically (experimental).

    Args:
        description (str | None): Description passed to both parsers.

    Returns:
        tuple: ``(parser, cfg)``, the full argument parser and the config
        loaded from the ``config`` positional argument.
    """
    # First pass: read only the config path and tolerate unknown arguments.
    bootstrap_parser = ArgumentParser(description=description)
    bootstrap_parser.add_argument('config', help='config file path')
    known_args, _ = bootstrap_parser.parse_known_args()
    cfg = Config.fromfile(known_args.config)

    # Second pass: the parser handed back to the caller, extended by
    # ``add_args`` using the loaded config.
    parser = ArgumentParser(description=description)
    parser.add_argument('config', help='config file path')
    add_args(parser, cfg)
    return parser, cfg
def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
    """Initialise the config from a dict and optional source text.

    Args:
        cfg_dict (dict | None): Config content. ``None`` means an empty
            config.
        cfg_text (str | None): Source text of the config. When falsy and
            ``filename`` is given, the text is read from that file.
        filename (str | None): Path the config was loaded from.

    Raises:
        TypeError: If ``cfg_dict`` is neither ``None`` nor a dict.
        KeyError: If ``cfg_dict`` contains a key from ``RESERVED_KEYS``.
    """
    if cfg_dict is None:
        cfg_dict = dict()
    elif not isinstance(cfg_dict, dict):
        raise TypeError('cfg_dict must be a dict, but '
                        f'got {type(cfg_dict)}')
    for key in cfg_dict:
        if key in RESERVED_KEYS:
            raise KeyError(f'{key} is reserved for config file')

    # Attributes are stored through the base class because this class's
    # own __setattr__ writes into the wrapped config dict instead.
    super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
    super(Config, self).__setattr__('_filename', filename)
    if cfg_text:
        text = cfg_text
    elif filename:
        # Read as UTF-8 explicitly, matching how this class writes config
        # files, instead of relying on the platform's locale encoding.
        with open(filename, 'r', encoding='utf-8') as f:
            text = f.read()
    else:
        text = ''
    super(Config, self).__setattr__('_text', text)
@property
def filename(self):
    """Return the stored ``_filename`` value (may be ``None``)."""
    stored_filename = self._filename
    return stored_filename
@property
def text(self):
    """Return the stored ``_text`` value set at construction time."""
    stored_text = self._text
    return stored_text
@property
def pretty_text(self):
    """Render the config dict as formatted Python-style source text.

    The dict is first serialised by the nested helpers below and the
    result is then passed through ``FormatCode`` with a pep8-based style.
    """
    indent = 4

    def _indent(text, num_spaces):
        # Indent every line except the first by ``num_spaces`` spaces.
        head, newline, rest = text.partition('\n')
        if not newline:
            return text
        padding = num_spaces * ' '
        shifted = '\n'.join(padding + line for line in rest.split('\n'))
        return head + '\n' + shifted

    def _key_prefix(k, use_mapping):
        # ``'key': `` inside a ``{...}`` literal, ``key=`` otherwise.
        if not use_mapping:
            return f'{str(k)}='
        k_str = f"'{k}'" if isinstance(k, str) else str(k)
        return f'{k_str}: '

    def _format_basic_types(k, v, use_mapping=False):
        v_str = f"'{v}'" if isinstance(v, str) else str(v)
        return _indent(_key_prefix(k, use_mapping) + v_str, indent)

    def _format_list(k, v, use_mapping=False):
        # Only lists made up entirely of dicts get the multi-line layout.
        if not all(isinstance(item, dict) for item in v):
            return _format_basic_types(k, v, use_mapping)
        items = '\n'.join(
            f'dict({_indent(_format_dict(item), indent)}),'
            for item in v).rstrip(',')
        v_str = '[\n' + items
        return _indent(_key_prefix(k, use_mapping) + v_str, indent) + ']'

    def _contain_invalid_identifier(dict_str):
        # True when at least one key cannot be written as ``key=value``.
        return any(not str(key_name).isidentifier()
                   for key_name in dict_str)

    def _format_dict(input_dict, outest_level=False):
        use_mapping = _contain_invalid_identifier(input_dict)
        last_idx = len(input_dict) - 1
        entries = []
        for idx, (k, v) in enumerate(input_dict.items()):
            # No trailing comma at the outermost level or on the last item.
            end = '' if outest_level or idx >= last_idx else ','
            if isinstance(v, dict):
                v_str = '\n' + _format_dict(v)
                attr_str = _key_prefix(k, use_mapping) + f'dict({v_str}'
                attr_str = _indent(attr_str, indent) + ')' + end
            elif isinstance(v, list):
                attr_str = _format_list(k, v, use_mapping) + end
            else:
                attr_str = _format_basic_types(k, v, use_mapping) + end
            entries.append(attr_str)
        body = '\n'.join(entries)
        if use_mapping:
            return '{' + body + '}'
        return body

    cfg_dict = self._cfg_dict.to_dict()
    text = _format_dict(cfg_dict, outest_level=True)
    # copied from setup.cfg
    yapf_style = {
        'based_on_style': 'pep8',
        'blank_line_before_nested_class_or_def': True,
        'split_before_expression_after_opening_paren': True,
    }
    text, _ = FormatCode(text, style_config=yapf_style, verify=True)
    return text
def __repr__(self):
    """Return ``Config (path: <filename>): <repr of the config dict>``."""
    path = self.filename
    content = repr(self._cfg_dict)
    return f'Config (path: {path}): {content}'
def __len__(self):
    """Return the number of entries in the wrapped config dict."""
    wrapped = self._cfg_dict
    return len(wrapped)
def __getattr__(self, name):
    """Delegate attribute lookup of ``name`` to the wrapped config dict."""
    wrapped = self._cfg_dict
    return getattr(wrapped, name)
def __getitem__(self, name):
    """Delegate item lookup of ``name`` to the wrapped config dict."""
    wrapped = self._cfg_dict
    return wrapped[name]
def __setattr__(self, name, value):
    """Set ``name`` on the wrapped config dict.

    Dict values are converted to ``ConfigDict`` before being stored.
    """
    stored = ConfigDict(value) if isinstance(value, dict) else value
    setattr(self._cfg_dict, name, stored)
def __setitem__(self, name, value):
    """Store ``value`` under key ``name`` in the wrapped config dict.

    Dict values are converted to ``ConfigDict`` before being stored.
    """
    stored = ConfigDict(value) if isinstance(value, dict) else value
    self._cfg_dict[name] = stored
def __iter__(self):
    """Return an iterator over the wrapped config dict."""
    wrapped = self._cfg_dict
    return iter(wrapped)
def __getstate__(self):
    """Return the ``(_cfg_dict, _filename, _text)`` tuple used as state."""
    state = (self._cfg_dict, self._filename, self._text)
    return state
def __setstate__(self, state):
    """Restore ``_cfg_dict``, ``_filename`` and ``_text`` from ``state``.

    Args:
        state (tuple): A 3-item sequence in the order
            ``(_cfg_dict, _filename, _text)``.
    """
    cfg_dict, filename, text = state
    restored = (
        ('_cfg_dict', cfg_dict),
        ('_filename', filename),
        ('_text', text),
    )
    # Go through the base class: this class's own __setattr__ would write
    # into the wrapped config dict instead of setting the attribute.
    for attr_name, attr_value in restored:
        super(Config, self).__setattr__(attr_name, attr_value)
def dump(self, file=None):
    """Serialise the config, either to a string or to ``file``.

    A config whose filename ends with ``.py``, or which has no filename at
    all (e.g. one built directly from a dict), is dumped as
    ``pretty_text``. Any other config is dumped through ``mmcv.dump``,
    using the filename's extension as the format when no ``file`` is
    given.

    Args:
        file (str | None): Destination path. If None, the serialised
            content is returned instead of written.

    Returns:
        str | None: The serialised content when ``file`` is None,
        otherwise None.
    """
    # Without this None check, ``self.filename.endswith`` raises
    # AttributeError for configs that were not loaded from a file.
    if self.filename is None or self.filename.endswith('.py'):
        if file is None:
            return self.pretty_text
        with open(file, 'w', encoding='utf-8') as f:
            f.write(self.pretty_text)
        return None

    import annotator.uniformer.mmcv as mmcv
    cfg_dict = super(Config, self).__getattribute__('_cfg_dict').to_dict()
    if file is None:
        file_format = self.filename.split('.')[-1]
        return mmcv.dump(cfg_dict, file_format=file_format)
    mmcv.dump(cfg_dict, file)
    return None
def merge_from_dict(self, options, allow_list_keys=True):
    """Merge dotted-key options into the config dict.

    Each key of ``options`` is split on ``.`` and expanded into nested
    dicts; the resulting tree is merged into the current config with
    ``Config._merge_a_into_b`` and stored as the new config dict.

    Examples:
        >>> options = {'model.backbone.depth': 50,
        ...            'model.backbone.with_cp':True}
        >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet'))))
        >>> cfg.merge_from_dict(options)
        >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
        >>> assert cfg_dict == dict(
        ...     model=dict(backbone=dict(depth=50, with_cp=True)))

        # Merge list element
        >>> cfg = Config(dict(pipeline=[
        ...     dict(type='LoadImage'), dict(type='LoadAnnotations')]))
        >>> options = dict(pipeline={'0': dict(type='SelfLoadImage')})
        >>> cfg.merge_from_dict(options, allow_list_keys=True)
        >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
        >>> assert cfg_dict == dict(pipeline=[
        ...     dict(type='SelfLoadImage'), dict(type='LoadAnnotations')])

    Args:
        options (dict): dict of configs to merge from.
        allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
            are allowed in ``options`` and will replace the element of the
            corresponding index in the config if the config is a list.
            Default: True.
    """
    nested_options = {}
    for dotted_key, option_value in options.items():
        *parent_keys, leaf_key = dotted_key.split('.')
        # Walk down (creating levels as needed) to the dict holding the leaf.
        node = nested_options
        for parent_key in parent_keys:
            node.setdefault(parent_key, ConfigDict())
            node = node[parent_key]
        node[leaf_key] = option_value

    current_cfg = super(Config, self).__getattribute__('_cfg_dict')
    merged_cfg = Config._merge_a_into_b(
        nested_options, current_cfg, allow_list_keys=allow_list_keys)
    super(Config, self).__setattr__('_cfg_dict', merged_cfg)
class DictAction(Action):
    """
    argparse action to split an argument into KEY=VALUE form
    on the first = and append to a dictionary. List options can
    be passed as comma separated values, i.e 'KEY=V1,V2,V3', or with explicit
    brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested brackets to build
    list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]'
    """

    @staticmethod
    def _parse_int_float_bool(val):
        """Convert ``val`` to int, float or bool when possible.

        Conversion is tried in that order; a string that matches none of
        them is returned unchanged.
        """
        try:
            return int(val)
        except ValueError:
            pass
        try:
            return float(val)
        except ValueError:
            pass
        if val.lower() in ['true', 'false']:
            return val.lower() == 'true'
        return val

    @staticmethod
    def _parse_iterable(val):
        """Parse iterable values in the string.

        All elements inside '()' or '[]' are treated as iterable values.

        Args:
            val (str): Value string.

        Returns:
            list | tuple: The expanded list or tuple from the string. A
            string without brackets and commas is returned as a single
            parsed scalar instead.

        Examples:
            >>> DictAction._parse_iterable('1,2,3')
            [1, 2, 3]
            >>> DictAction._parse_iterable('[a, b, c]')
            ['a', 'b', 'c']
            >>> DictAction._parse_iterable('[(1, 2, 3), [a, b], c]')
            [(1, 2, 3), ['a', 'b'], 'c']
        """

        def find_next_comma(string):
            """Find the position of next comma in the string.

            If no ',' is found in the string, return the string length. All
            chars inside '()' and '[]' are treated as one element and thus
            ',' inside these brackets are ignored.
            """
            assert (string.count('(') == string.count(')')) and (
                    string.count('[') == string.count(']')), \
                f'Imbalanced brackets exist in {string}'
            # Running "opened minus closed" counts for each bracket kind,
            # updated in one pass instead of re-counting the prefix for
            # every character.
            paren_balance = 0
            bracket_balance = 0
            for idx, char in enumerate(string):
                if char == ',':
                    # The string before this ',' is balanced
                    if paren_balance == 0 and bracket_balance == 0:
                        return idx
                elif char == '(':
                    paren_balance += 1
                elif char == ')':
                    paren_balance -= 1
                elif char == '[':
                    bracket_balance += 1
                elif char == ']':
                    bracket_balance -= 1
            return len(string)

        # Strip ' and " characters and replace whitespace.
        val = val.strip('\'\"').replace(' ', '')
        is_tuple = False
        if val.startswith('(') and val.endswith(')'):
            is_tuple = True
            val = val[1:-1]
        elif val.startswith('[') and val.endswith(']'):
            val = val[1:-1]
        elif ',' not in val:
            # val is a single value
            return DictAction._parse_int_float_bool(val)

        values = []
        while len(val) > 0:
            comma_idx = find_next_comma(val)
            element = DictAction._parse_iterable(val[:comma_idx])
            values.append(element)
            val = val[comma_idx + 1:]
        if is_tuple:
            values = tuple(values)
        return values

    def __call__(self, parser, namespace, values, option_string=None):
        """Parse every ``KEY=VALUE`` item and store the dict on ``namespace``.

        Raises:
            ValueError: If an item contains no ``=``.
        """
        options = {}
        for kv in values:
            key, sep, val = kv.partition('=')
            if not sep:
                # Same exception type as the tuple-unpacking failure this
                # replaces, but with a message that names the bad item.
                raise ValueError(
                    f'Expected an argument of the form KEY=VALUE, got {kv!r}')
            options[key] = self._parse_iterable(val)
        setattr(namespace, self.dest, options)
diff --git a/ControlNet/annotator/uniformer/mmcv/utils/env.py b/ControlNet/annotator/uniformer/mmcv/utils/env.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3f0d92529e193e6d8339419bcd9bed7901a7769
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/utils/env.py
@@ -0,0 +1,95 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""This file holding some environment constant for sharing by other files."""
+import os.path as osp
+import subprocess
+import sys
+from collections import defaultdict
+import cv2
+import torch
+import annotator.uniformer.mmcv as mmcv
+from .parrots_wrapper import get_build_config
def collect_env():
    """Collect the information of the running environments.

    Returns:
        dict: The environment information. The following fields are contained.

            - sys.platform: The variable of ``sys.platform``.
            - Python: Python version.
            - CUDA available: Bool, indicating if CUDA is available.
            - GPU devices: Device type of each GPU.
            - CUDA_HOME (optional): The env var ``CUDA_HOME``.
            - NVCC (optional): NVCC version, "Not Available" if nvcc cannot
              be run.
            - GCC: GCC version, "n/a" if GCC is not installed.
            - PyTorch: PyTorch version.
            - PyTorch compiling details: The output of \
                ``torch.__config__.show()``.
            - TorchVision (optional): TorchVision version.
            - OpenCV: OpenCV version.
            - MMCV: MMCV version.
            - MMCV Compiler: The GCC version for compiling MMCV ops.
            - MMCV CUDA Compiler: The CUDA version for compiling MMCV ops.
    """
    env_info = {}
    env_info['sys.platform'] = sys.platform
    env_info['Python'] = sys.version.replace('\n', '')

    cuda_available = torch.cuda.is_available()
    env_info['CUDA available'] = cuda_available

    if cuda_available:
        # Group device indices by device name, e.g. 'GPU 0,1' -> name.
        devices = defaultdict(list)
        for k in range(torch.cuda.device_count()):
            devices[torch.cuda.get_device_name(k)].append(str(k))
        for name, device_ids in devices.items():
            env_info['GPU ' + ','.join(device_ids)] = name

        from annotator.uniformer.mmcv.utils.parrots_wrapper import _get_cuda_home
        CUDA_HOME = _get_cuda_home()
        env_info['CUDA_HOME'] = CUDA_HOME

        if CUDA_HOME is not None and osp.isdir(CUDA_HOME):
            try:
                nvcc = osp.join(CUDA_HOME, 'bin/nvcc')
                # Run nvcc directly rather than through `| tail -n1`: in a
                # shell pipeline the exit status is tail's, so a missing
                # nvcc was reported as an empty string, not as a failure.
                nvcc_out = subprocess.check_output([nvcc, '-V'])
                nvcc_lines = nvcc_out.decode('utf-8').strip().splitlines()
                nvcc = (nvcc_lines[-1].strip()
                        if nvcc_lines else 'Not Available')
            except (subprocess.SubprocessError, OSError):
                # OSError covers the executable being absent.
                nvcc = 'Not Available'
            env_info['NVCC'] = nvcc

    try:
        # Same reasoning as for nvcc: no `| head -n1` pipeline, so a
        # missing gcc really reaches the except branch.
        gcc_out = subprocess.check_output(['gcc', '--version'])
        gcc_lines = gcc_out.decode('utf-8').strip().splitlines()
        env_info['GCC'] = gcc_lines[0].strip() if gcc_lines else 'n/a'
    except (subprocess.CalledProcessError, OSError):  # gcc is unavailable
        env_info['GCC'] = 'n/a'

    env_info['PyTorch'] = torch.__version__
    env_info['PyTorch compiling details'] = get_build_config()

    try:
        import torchvision
        env_info['TorchVision'] = torchvision.__version__
    except ModuleNotFoundError:
        pass

    env_info['OpenCV'] = cv2.__version__

    env_info['MMCV'] = mmcv.__version__

    try:
        from annotator.uniformer.mmcv.ops import get_compiler_version, get_compiling_cuda_version
    except ModuleNotFoundError:
        env_info['MMCV Compiler'] = 'n/a'
        env_info['MMCV CUDA Compiler'] = 'n/a'
    else:
        env_info['MMCV Compiler'] = get_compiler_version()
        env_info['MMCV CUDA Compiler'] = get_compiling_cuda_version()

    return env_info
diff --git a/ControlNet/annotator/uniformer/mmcv/utils/ext_loader.py b/ControlNet/annotator/uniformer/mmcv/utils/ext_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..08132d2c1b9a1c28880e4bab4d4fa1ba39d9d083
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/utils/ext_loader.py
@@ -0,0 +1,71 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import importlib
+import os
+import pkgutil
+import warnings
+from collections import namedtuple
+import torch
# Two mutually exclusive implementations of ``load_ext``: the regular
# PyTorch one imports the compiled extension module, the parrots one loads
# every op through ``parrots.extension``. The parrots imports must stay
# inside the ``else`` branch so that they are never attempted under PyTorch.
if torch.__version__ != 'parrots':

    def load_ext(name, funcs):
        """Import extension module ``mmcv.<name>`` and check its functions.

        Args:
            name (str): Module name relative to the ``mmcv`` package.
            funcs (list[str]): Attribute names the module must provide.

        Returns:
            module: The imported extension module.
        """
        ext = importlib.import_module('mmcv.' + name)
        for fun in funcs:
            assert hasattr(ext, fun), f'{fun} miss in module {name}'
        return ext
else:
    from parrots import extension
    from parrots.base import ParrotsException

    # Ops exposed through ``.op``; every other op is exposed through
    # ``.op_`` (see ``load_ext`` below).
    has_return_value_ops = [
        'nms',
        'softnms',
        'nms_match',
        'nms_rotated',
        'top_pool_forward',
        'top_pool_backward',
        'bottom_pool_forward',
        'bottom_pool_backward',
        'left_pool_forward',
        'left_pool_backward',
        'right_pool_forward',
        'right_pool_backward',
        'fused_bias_leakyrelu',
        'upfirdn2d',
        'ms_deform_attn_forward',
        'pixel_group',
        'contour_expand',
    ]

    def get_fake_func(name, e):
        """Return a stand-in that warns and re-raises ``e`` when called."""

        def fake_func(*args, **kwargs):
            warnings.warn(f'{name} is not supported in parrots now')
            raise e

        return fake_func

    def load_ext(name, funcs):
        """Load each op in ``funcs`` via ``parrots.extension``.

        Ops that fail to load are replaced by a fake function that raises
        the original exception on first use.

        Returns:
            ExtModule: A namedtuple whose fields are named after ``funcs``.
        """
        ExtModule = namedtuple('ExtModule', funcs)
        ext_list = []
        lib_root = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
        for fun in funcs:
            try:
                ext_fun = extension.load(fun, name, lib_dir=lib_root)
            except ParrotsException as e:
                if 'No element registered' not in e.message:
                    warnings.warn(e.message)
                ext_fun = get_fake_func(fun, e)
                ext_list.append(ext_fun)
            else:
                if fun in has_return_value_ops:
                    ext_list.append(ext_fun.op)
                else:
                    ext_list.append(ext_fun.op_)
        return ExtModule(*ext_list)
def check_ops_exist():
    """Check whether the compiled extension module ``mmcv._ext`` exists.

    Returns:
        bool: True if an importable ``mmcv._ext`` can be located, False
        otherwise, including when the ``mmcv`` package itself is missing.
    """
    # Imported locally: ``import importlib`` alone does not guarantee that
    # the ``importlib.util`` submodule is loaded.
    import importlib.util

    try:
        spec = importlib.util.find_spec('mmcv._ext')
    except (ImportError, ValueError):
        # find_spec imports the parent package first and raises when that
        # fails; report "no ops" instead of propagating the error.
        return False
    return spec is not None and spec.loader is not None
diff --git a/ControlNet/annotator/uniformer/mmcv/utils/logging.py b/ControlNet/annotator/uniformer/mmcv/utils/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..4aa0e04bb9b3ab2a4bfbc4def50404ccbac2c6e6
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/utils/logging.py
@@ -0,0 +1,110 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+import torch.distributed as dist
# Names of the loggers that ``get_logger`` has already configured.
logger_initialized = {}


def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
    """Initialize and get a logger by name.

    If the logger has not been initialized, this method will initialize the
    logger by adding one or two handlers, otherwise the initialized logger will
    be directly returned. During initialization, a StreamHandler will always be
    added. If `log_file` is specified and the process rank is 0, a FileHandler
    will also be added.

    Args:
        name (str): Logger name.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the logger.
        log_level (int): The logger level. Note that only the process of
            rank 0 is affected, and other processes will set the level to
            "Error" thus be silent most of the time.
        file_mode (str): The file mode used in opening log file.
            Defaults to 'w'.

    Returns:
        logging.Logger: The expected logger.
    """
    logger = logging.getLogger(name)
    if name in logger_initialized:
        return logger
    # handle hierarchical names
    # e.g., logger "a" is initialized, then logger "a.b" will skip the
    # initialization since it is a child of "a".
    # The trailing '.' matters: a bare prefix test would also treat an
    # unrelated logger such as "ab" as a child of "a" and leave it without
    # any handler.
    for logger_name in logger_initialized:
        if name.startswith(logger_name + '.'):
            return logger

    # handle duplicate logs to the console
    # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler <stderr> (NOTSET)
    # to the root logger. As logger.propagate is True by default, this root
    # level handler causes logging messages from rank>0 processes to
    # unexpectedly show up on the console, creating much unwanted clutter.
    # To fix this issue, we set the root logger's StreamHandler, if any, to log
    # at the ERROR level.
    for handler in logger.root.handlers:
        if type(handler) is logging.StreamHandler:
            handler.setLevel(logging.ERROR)

    stream_handler = logging.StreamHandler()
    handlers = [stream_handler]

    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
    else:
        rank = 0

    # only rank 0 will add a FileHandler
    if rank == 0 and log_file is not None:
        # Here, the default behaviour of the official logger is 'a'. Thus, we
        # provide an interface to change the file mode to the default
        # behaviour.
        file_handler = logging.FileHandler(log_file, file_mode)
        handlers.append(file_handler)

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    if rank == 0:
        logger.setLevel(log_level)
    else:
        logger.setLevel(logging.ERROR)

    logger_initialized[name] = True

    return logger
def print_log(msg, logger=None, level=logging.INFO):
    """Print a log message.

    Args:
        msg (str): The message to be logged.
        logger (logging.Logger | str | None): The logger to be used.
            Some special loggers are:
            - "silent": no message will be printed.
            - other str: the logger obtained with `get_logger(logger)`.
            - None: The `print()` method will be used to print log messages.
        level (int): Logging level. Only used when ``logger`` is a Logger
            object or a logger name other than "silent".

    Raises:
        TypeError: If ``logger`` is none of the accepted kinds.
    """
    if logger is None:
        print(msg)
        return
    if isinstance(logger, logging.Logger):
        logger.log(level, msg)
        return
    if logger == 'silent':
        return
    if isinstance(logger, str):
        named_logger = get_logger(logger)
        named_logger.log(level, msg)
        return
    raise TypeError(
        'logger should be either a logging.Logger object, str, '
        f'"silent" or None, but got {type(logger)}')
diff --git a/ControlNet/annotator/uniformer/mmcv/utils/misc.py b/ControlNet/annotator/uniformer/mmcv/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c58d0d7fee9fe3d4519270ad8c1e998d0d8a18c
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/utils/misc.py
@@ -0,0 +1,377 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import collections.abc
+import functools
+import itertools
+import subprocess
+import warnings
+from collections import abc
+from importlib import import_module
+from inspect import getfullargspec
+from itertools import repeat
+# From PyTorch internals
def _ntuple(n):
    """Build a converter that expands a scalar into an ``n``-tuple.

    The returned function hands back any iterable argument unchanged
    (strings included) and repeats any other value ``n`` times in a tuple.
    """

    def parse(value):
        is_iterable = isinstance(value, collections.abc.Iterable)
        return value if is_iterable else tuple(repeat(value, n))

    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
def is_str(x):
    """Return True if ``x`` is an instance of ``str``.

    Note: This method is deprecated since python 2 is no longer supported.
    """
    result = isinstance(x, str)
    return result
def import_modules_from_strings(imports, allow_failed_imports=False):
    """Import modules from the given list of strings.

    Args:
        imports (list | str | None): The given module names to be imported.
        allow_failed_imports (bool): If True, the failed imports will return
            None. Otherwise, the ImportError is raised. Default: False.

    Returns:
        list[module] | module | None: The imported modules. A single module
        is returned for a str input, and None for an empty/None input.

    Raises:
        TypeError: If ``imports`` is neither a str nor a list, or one of its
            items is not a str.
        ImportError: If a module cannot be imported and
            ``allow_failed_imports`` is False.

    Examples:
        >>> osp, sys = import_modules_from_strings(
        ...     ['os.path', 'sys'])
        >>> import os.path as osp_
        >>> import sys as sys_
        >>> assert osp == osp_
        >>> assert sys == sys_
    """
    if not imports:
        return
    single_import = False
    if isinstance(imports, str):
        single_import = True
        imports = [imports]
    if not isinstance(imports, list):
        raise TypeError(
            f'custom_imports must be a list but got type {type(imports)}')
    imported = []
    for imp in imports:
        if not isinstance(imp, str):
            raise TypeError(
                f'{imp} is of type {type(imp)} and cannot be imported.')
        try:
            imported_tmp = import_module(imp)
        except ImportError:
            if not allow_failed_imports:
                # Re-raise the original error so that the message still
                # says which module could not be imported.
                raise
            warnings.warn(f'{imp} failed to import and is ignored.',
                          UserWarning)
            imported_tmp = None
        imported.append(imported_tmp)
    if single_import:
        imported = imported[0]
    return imported
def iter_cast(inputs, dst_type, return_type=None):
    """Cast elements of an iterable object into some type.

    Args:
        inputs (Iterable): The input object.
        dst_type (type): Destination type applied to every element.
        return_type (type, optional): If specified, the lazily mapped
            elements are passed to this callable and its result is returned.

    Returns:
        iterator or specified type: A ``map`` iterator when ``return_type``
        is None, otherwise ``return_type(<map iterator>)``.

    Raises:
        TypeError: If ``inputs`` is not iterable or ``dst_type`` is not a
            type.
    """
    if not isinstance(inputs, abc.Iterable):
        raise TypeError('inputs must be an iterable object')
    if not isinstance(dst_type, type):
        raise TypeError('"dst_type" must be a valid type')

    converted = map(dst_type, inputs)
    return converted if return_type is None else return_type(converted)
def list_cast(inputs, dst_type):
    """Cast every element of ``inputs`` to ``dst_type`` and return a list.

    Shorthand for :func:`iter_cast` with ``return_type=list``.
    """
    as_list = iter_cast(inputs, dst_type, return_type=list)
    return as_list
def tuple_cast(inputs, dst_type):
    """Cast every element of ``inputs`` to ``dst_type`` and return a tuple.

    Shorthand for :func:`iter_cast` with ``return_type=tuple``.
    """
    as_tuple = iter_cast(inputs, dst_type, return_type=tuple)
    return as_tuple
def is_seq_of(seq, expected_type, seq_type=None):
    """Check whether it is a sequence of some type.

    Args:
        seq (Sequence): The sequence to be checked.
        expected_type (type): Expected type of sequence items.
        seq_type (type, optional): Expected sequence type. When omitted,
            any ``collections.abc.Sequence`` is accepted.

    Returns:
        bool: True if ``seq`` has the expected container type and every
        item is an instance of ``expected_type``.
    """
    if seq_type is not None:
        assert isinstance(seq_type, type)
        container_type = seq_type
    else:
        container_type = abc.Sequence

    if not isinstance(seq, container_type):
        return False
    return all(isinstance(item, expected_type) for item in seq)
def is_list_of(seq, expected_type):
    """Check whether ``seq`` is a list whose items are all ``expected_type``.

    Shorthand for :func:`is_seq_of` with ``seq_type=list``.
    """
    verdict = is_seq_of(seq, expected_type, seq_type=list)
    return verdict
def is_tuple_of(seq, expected_type):
    """Check whether ``seq`` is a tuple whose items are all ``expected_type``.

    Shorthand for :func:`is_seq_of` with ``seq_type=tuple``.
    """
    verdict = is_seq_of(seq, expected_type, seq_type=tuple)
    return verdict
def slice_list(in_list, lens):
    """Slice a list into several sub lists by a list of given length.

    Args:
        in_list (list): The list to be sliced.
        lens (int or list): The expected length of each out list. An int
            means "equal chunks of this size" and must divide
            ``len(in_list)`` exactly.

    Returns:
        list: A list of sliced list.

    Raises:
        TypeError: If ``lens`` is neither an int nor a list.
        ValueError: If the lengths in ``lens`` do not sum to
            ``len(in_list)``.
    """
    if isinstance(lens, int):
        assert len(in_list) % lens == 0
        # Exact integer division; no round trip through float.
        lens = [lens] * (len(in_list) // lens)
    if not isinstance(lens, list):
        # The message names the actual parameter ("lens").
        raise TypeError('"lens" must be an integer or a list of integers')
    elif sum(lens) != len(in_list):
        raise ValueError('sum of lens and list length does not '
                         f'match: {sum(lens)} != {len(in_list)}')
    out_list = []
    start = 0
    for length in lens:
        out_list.append(in_list[start:start + length])
        start += length
    return out_list
def concat_list(in_list):
    """Concatenate a list of list into a single list.

    Args:
        in_list (list): The list of list to be merged.

    Returns:
        list: The concatenated flat list (one level of nesting removed).
    """
    flattened = itertools.chain.from_iterable(in_list)
    return list(flattened)
def check_prerequisites(
        prerequisites,
        checker,
        msg_tmpl='Prerequisites "{}" are required in method "{}" but not '
        'found, please install them first.'):  # yapf: disable
    """A decorator factory to check if prerequisites are satisfied.

    The check runs on every call of the decorated function. When at least
    one prerequisite is missing, the formatted message is printed and a
    RuntimeError is raised instead of calling the function.

    Args:
        prerequisites (str | list[str]): Prerequisites to be checked.
        checker (callable): The checker method that returns True if a
            prerequisite is meet, False otherwise.
        msg_tmpl (str): The message template with two variables: the
            comma-joined missing prerequisites and the function name.

    Returns:
        decorator: A specific decorator.
    """

    def wrap(func):

        @functools.wraps(func)
        def wrapped_func(*args, **kwargs):
            if isinstance(prerequisites, str):
                requirements = [prerequisites]
            else:
                requirements = prerequisites
            missing = [item for item in requirements if not checker(item)]
            if not missing:
                return func(*args, **kwargs)
            print(msg_tmpl.format(', '.join(missing), func.__name__))
            raise RuntimeError('Prerequisites not meet.')

        return wrapped_func

    return wrap
+def _check_py_package(package):
+ try:
+ import_module(package)
+ except ImportError:
+ return False
+ else:
+ return True
+def _check_executable(cmd):
+ if subprocess.call(f'which {cmd}', shell=True) != 0:
+ return False
+ else:
+ return True
def requires_package(prerequisites):
    """A decorator to check if some python packages are installed.

    Built on ``check_prerequisites`` with ``_check_py_package`` as the
    checker, so a missing package makes the decorated function raise
    ``RuntimeError`` when it is called.

    Example:
        >>> @requires_package('numpy')
        ... def func(arg1, args):
        ...     return numpy.zeros(1)
        >>> func(1, 2)
        array([0.])
        >>> @requires_package(['numpy', 'non_package'])
        ... def func(arg1, args):
        ...     return numpy.zeros(1)
        >>> func(1, 2)
        RuntimeError
    """
    decorator = check_prerequisites(prerequisites, checker=_check_py_package)
    return decorator
def requires_executable(prerequisites):
    """A decorator to check if some executable files are installed.

    Built on ``check_prerequisites`` with ``_check_executable`` as the
    checker.

    Example:
        >>> @requires_executable('ffmpeg')
        ... def func(arg1, args):
        ...     print(1)
        >>> func(1, 2)
        1
    """
    decorator = check_prerequisites(prerequisites, checker=_check_executable)
    return decorator
def deprecated_api_warning(name_dict, cls_name=None):
    """A decorator to check if some arguments are deprecate and try to replace
    deprecate src_arg_name to dst_arg_name.

    Deprecated names passed as keywords are renamed before the call. For
    deprecated names among the positionally bound parameters only a warning
    is emitted; the positional values are forwarded as they are.

    Args:
        name_dict(dict):
            key (str): Deprecate argument names.
            val (str): Expected argument names.
        cls_name (str, optional): Class name prepended to the function name
            in warning messages.

    Returns:
        func: New function.
    """

    def api_warning_wrapper(old_func):

        @functools.wraps(old_func)
        def new_func(*args, **kwargs):
            # get the arg spec of the decorated method
            spec = getfullargspec(old_func)
            # get name of the function
            if cls_name is None:
                func_name = old_func.__name__
            else:
                func_name = f'{cls_name}.{old_func.__name__}'

            if args:
                # Names of the parameters bound by the positional arguments.
                positional_names = spec.args[:len(args)]
                for src_arg_name, dst_arg_name in name_dict.items():
                    if src_arg_name not in positional_names:
                        continue
                    warnings.warn(
                        f'"{src_arg_name}" is deprecated in '
                        f'`{func_name}`, please use "{dst_arg_name}" '
                        'instead')
                    position = positional_names.index(src_arg_name)
                    positional_names[position] = dst_arg_name

            if kwargs:
                for src_arg_name, dst_arg_name in name_dict.items():
                    if src_arg_name not in kwargs:
                        continue
                    assert dst_arg_name not in kwargs, (
                        f'The expected behavior is to replace '
                        f'the deprecated key `{src_arg_name}` to '
                        f'new key `{dst_arg_name}`, but got them '
                        f'in the arguments at the same time, which '
                        f'is confusing. `{src_arg_name} will be '
                        f'deprecated in the future, please '
                        f'use `{dst_arg_name}` instead.')
                    warnings.warn(
                        f'"{src_arg_name}" is deprecated in '
                        f'`{func_name}`, please use "{dst_arg_name}" '
                        'instead')
                    kwargs[dst_arg_name] = kwargs.pop(src_arg_name)

            # apply converted arguments to the decorated method
            return old_func(*args, **kwargs)

        return new_func

    return api_warning_wrapper
def is_method_overridden(method, base_class, derived_class):
    """Check if a method of base class is overridden in derived class.

    Args:
        method (str): the method name to check.
        base_class (type): the class of the base class.
        derived_class (type | Any): the class or instance of the derived class.

    Returns:
        bool: True if the attribute ``method`` looked up on the derived
        class compares unequal to the one looked up on the base class.
    """
    assert isinstance(base_class, type), \
        "base_class doesn't accept instance, Please pass class instead."

    # An instance is accepted for the derived side; compare on its class.
    derived_is_class = isinstance(derived_class, type)
    derived_type = derived_class if derived_is_class else derived_class.__class__

    method_on_base = getattr(base_class, method)
    method_on_derived = getattr(derived_type, method)
    return method_on_derived != method_on_base
def has_method(obj: object, method: str) -> bool:
    """Check whether the object has a method.

    Args:
        method (str): The method name to check.
        obj (object): The object to check.

    Returns:
        bool: True if the object has an attribute of that name and the
        attribute is callable, else False.
    """
    if not hasattr(obj, method):
        return False
    return callable(getattr(obj, method))
diff --git a/ControlNet/annotator/uniformer/mmcv/utils/parrots_jit.py b/ControlNet/annotator/uniformer/mmcv/utils/parrots_jit.py
new file mode 100644
index 0000000000000000000000000000000000000000..61873f6dbb9b10ed972c90aa8faa321e3cb3249e
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/utils/parrots_jit.py
@@ -0,0 +1,41 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+from .parrots_wrapper import TORCH_VERSION
parrots_jit_option = os.getenv('PARROTS_JIT_OPTION')

# Use the real parrots JIT decorator only under parrots with
# PARROTS_JIT_OPTION=ON; in every other case define a no-op stand-in with
# the same keyword arguments. The stand-in has to live in the ``else``
# branch, otherwise it would shadow the parrots import.
if TORCH_VERSION == 'parrots' and parrots_jit_option == 'ON':
    from parrots.jit import pat as jit
else:

    def jit(func=None,
            check_input=None,
            full_shape=True,
            derivate=False,
            coderize=False,
            optimize=False):
        """No-op replacement for the parrots ``jit`` decorator.

        Used bare (``@jit``) it returns ``func`` unchanged. Used with
        arguments (``@jit(...)``) it returns a decorator whose wrapper
        simply forwards the call. All option arguments are ignored.
        """

        def wrapper(func):

            def wrapper_inner(*args, **kargs):
                return func(*args, **kargs)

            return wrapper_inner

        if func is None:
            return wrapper
        else:
            return func
# Under parrots use the real test helper; otherwise define a pass-through
# decorator. The fallback has to live in the ``else`` branch, otherwise it
# would shadow the parrots import.
if TORCH_VERSION == 'parrots':
    from parrots.utils.tester import skip_no_elena
else:

    def skip_no_elena(func):
        """Pass-through decorator: the wrapper just forwards the call."""

        def wrapper(*args, **kargs):
            return func(*args, **kargs)

        return wrapper
diff --git a/ControlNet/annotator/uniformer/mmcv/utils/parrots_wrapper.py b/ControlNet/annotator/uniformer/mmcv/utils/parrots_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..93c97640d4b9ed088ca82cfe03e6efebfcfa9dbf
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/utils/parrots_wrapper.py
@@ -0,0 +1,107 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from functools import partial
+import torch
+TORCH_VERSION = torch.__version__
+def is_rocm_pytorch() -> bool:
+    """Report whether the installed PyTorch is a ROCm (HIP) build.
+
+    Returns:
+        bool: True when not running on parrots, ``torch.version.hip`` is set
+        and ``ROCM_HOME`` can be imported and is not None; False otherwise.
+    """
+    if TORCH_VERSION == 'parrots':
+        return False
+    try:
+        from torch.utils.cpp_extension import ROCM_HOME
+    except ImportError:
+        return False
+    return (torch.version.hip is not None) and (ROCM_HOME is not None)
+def _get_cuda_home():
+    """Return the toolkit home directory for the active backend.
+
+    Returns:
+        The ``CUDA_HOME`` exported by parrots or by
+        ``torch.utils.cpp_extension``; on a ROCm build of PyTorch the value
+        of ``ROCM_HOME`` is returned instead.
+    """
+    if TORCH_VERSION == 'parrots':
+        from parrots.utils.build_extension import CUDA_HOME
+    else:
+        if is_rocm_pytorch():
+            from torch.utils.cpp_extension import ROCM_HOME
+            # Bind the name that is returned below; without this the ROCm
+            # path would fail with UnboundLocalError.
+            CUDA_HOME = ROCM_HOME
+        else:
+            from torch.utils.cpp_extension import CUDA_HOME
+    return CUDA_HOME
+def get_build_config():
+    """Return the build configuration description of the active backend.
+
+    Uses ``parrots.config.get_build_info()`` on parrots and
+    ``torch.__config__.show()`` otherwise.
+    """
+    if TORCH_VERSION != 'parrots':
+        return torch.__config__.show()
+    from parrots.config import get_build_info
+    return get_build_info()
+def _get_conv():
+    """Return the ``(_ConvNd, _ConvTransposeMixin)`` base classes.
+
+    They are imported from parrots when ``TORCH_VERSION`` is ``'parrots'``
+    and from ``torch.nn.modules.conv`` otherwise.
+    """
+    if TORCH_VERSION != 'parrots':
+        from torch.nn.modules.conv import _ConvNd, _ConvTransposeMixin
+    else:
+        from parrots.nn.modules.conv import _ConvNd, _ConvTransposeMixin
+    return _ConvNd, _ConvTransposeMixin
+def _get_dataloader():
+    """Return the ``(DataLoader, PoolDataLoader)`` pair.
+
+    Outside parrots there is no separate pool loader, so ``DataLoader`` is
+    returned for both positions.
+    """
+    if TORCH_VERSION != 'parrots':
+        from torch.utils.data import DataLoader
+        return DataLoader, DataLoader
+    from torch.utils.data import DataLoader, PoolDataLoader
+    return DataLoader, PoolDataLoader
+def _get_extension():
+    """Return ``(BuildExtension, CppExtension, CUDAExtension)``.
+
+    On parrots the two extension factories are ``Extension`` partially
+    applied with ``cuda=False`` / ``cuda=True``; otherwise all three names
+    come from ``torch.utils.cpp_extension``.
+    """
+    if TORCH_VERSION != 'parrots':
+        from torch.utils.cpp_extension import (BuildExtension, CppExtension,
+                                               CUDAExtension)
+        return BuildExtension, CppExtension, CUDAExtension
+    from parrots.utils.build_extension import BuildExtension, Extension
+    cpp_ext = partial(Extension, cuda=False)
+    cuda_ext = partial(Extension, cuda=True)
+    return BuildExtension, cpp_ext, cuda_ext
+def _get_pool():
+    """Return the private pooling base classes of the active backend.
+
+    Returns:
+        tuple: ``(_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd,
+        _MaxPoolNd)``.
+    """
+    if TORCH_VERSION != 'parrots':
+        from torch.nn.modules.pooling import (_AdaptiveAvgPoolNd,
+                                              _AdaptiveMaxPoolNd, _AvgPoolNd,
+                                              _MaxPoolNd)
+    else:
+        from parrots.nn.modules.pool import (_AdaptiveAvgPoolNd,
+                                             _AdaptiveMaxPoolNd, _AvgPoolNd,
+                                             _MaxPoolNd)
+    return _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd
+def _get_norm():
+    """Return ``(_BatchNorm, _InstanceNorm, SyncBatchNorm_)``.
+
+    The sync batch-norm class is ``torch.nn.SyncBatchNorm2d`` on parrots and
+    ``torch.nn.SyncBatchNorm`` otherwise.
+    """
+    if TORCH_VERSION != 'parrots':
+        from torch.nn.modules.instancenorm import _InstanceNorm
+        from torch.nn.modules.batchnorm import _BatchNorm
+        sync_bn_cls = torch.nn.SyncBatchNorm
+    else:
+        from parrots.nn.modules.batchnorm import _BatchNorm, _InstanceNorm
+        sync_bn_cls = torch.nn.SyncBatchNorm2d
+    return _BatchNorm, _InstanceNorm, sync_bn_cls
+# Resolve the backend-specific classes once, at import time, by calling the
+# selector helpers defined above and binding their results to module-level
+# names.
+_ConvNd, _ConvTransposeMixin = _get_conv()
+DataLoader, PoolDataLoader = _get_dataloader()
+BuildExtension, CppExtension, CUDAExtension = _get_extension()
+_BatchNorm, _InstanceNorm, SyncBatchNorm_ = _get_norm()
+_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool()
+class SyncBatchNorm(SyncBatchNorm_):
+    """Sync batch norm whose input-dimension check adapts to the backend."""
+
+    def _check_input_dim(self, input):
+        """Validate the dimensionality of ``input``.
+
+        On parrots any input with at least two dimensions is accepted;
+        otherwise the parent class check is used.
+
+        Raises:
+            ValueError: On parrots, if ``input`` has fewer than 2 dimensions.
+        """
+        if TORCH_VERSION != 'parrots':
+            super()._check_input_dim(input)
+            return
+        if input.dim() < 2:
+            raise ValueError(
+                f'expected at least 2D input (got {input.dim()}D input)')
diff --git a/ControlNet/annotator/uniformer/mmcv/utils/path.py b/ControlNet/annotator/uniformer/mmcv/utils/path.py
new file mode 100644
index 0000000000000000000000000000000000000000..7dab4b3041413b1432b0f434b8b14783097d33c6
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/utils/path.py
@@ -0,0 +1,101 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import os.path as osp
+from pathlib import Path
+from .misc import is_str
+def is_filepath(x):
+    """Return True if ``x`` is a string (per ``is_str``) or a ``Path``."""
+    if is_str(x):
+        return True
+    return isinstance(x, Path)
+def fopen(filepath, *args, **kwargs):
+    """Open ``filepath`` whether it is a string or a ``Path``.
+
+    Extra positional and keyword arguments are forwarded to ``open`` /
+    ``Path.open``.
+
+    Raises:
+        ValueError: If ``filepath`` is neither a string nor a ``Path``.
+    """
+    if is_str(filepath):
+        opener, target = open, (filepath, )
+    elif isinstance(filepath, Path):
+        opener, target = filepath.open, ()
+    else:
+        raise ValueError('`filepath` should be a string or a Path')
+    return opener(*target, *args, **kwargs)
+def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
+    """Raise ``FileNotFoundError`` unless ``filename`` is an existing file.
+
+    Args:
+        filename: Path to test with ``os.path.isfile``.
+        msg_tmpl (str): Message template; ``{}`` is replaced by ``filename``.
+    """
+    if osp.isfile(filename):
+        return
+    raise FileNotFoundError(msg_tmpl.format(filename))
+def mkdir_or_exist(dir_name, mode=0o777):
+    """Create ``dir_name`` (and parents) if it does not exist yet.
+
+    An empty string is silently ignored. ``~`` is expanded before creation.
+
+    Args:
+        dir_name (str): Directory to create.
+        mode (int): Permission bits passed to ``os.makedirs``.
+    """
+    if dir_name != '':
+        expanded = osp.expanduser(dir_name)
+        os.makedirs(expanded, mode=mode, exist_ok=True)
+def symlink(src, dst, overwrite=True, **kwargs):
+    """Create a symbolic link ``dst`` pointing to ``src``.
+
+    Args:
+        src: Link target.
+        dst: Path of the link to create.
+        overwrite (bool): Remove an existing ``dst`` first.
+        **kwargs: Forwarded to ``os.symlink``.
+    """
+    already_there = os.path.lexists(dst)
+    if already_there and overwrite:
+        os.remove(dst)
+    os.symlink(src, dst, **kwargs)
+def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True):
+    """Lazily list the non-hidden files below a directory.
+
+    Args:
+        dir_path (str | obj:`Path`): Directory to scan.
+        suffix (str | tuple(str), optional): Only yield files whose relative
+            path ends with this suffix (or one of them). Default: None.
+        recursive (bool, optional): Descend into sub-directories.
+            Default: False.
+        case_sensitive (bool, optional): If False, the suffix comparison
+            ignores case. Default: True.
+
+    Returns:
+        A generator of matching file paths, relative to ``dir_path``.
+
+    Raises:
+        TypeError: If ``dir_path`` or ``suffix`` has an unsupported type.
+    """
+    if not isinstance(dir_path, (str, Path)):
+        raise TypeError('"dir_path" must be a string or Path object')
+    dir_path = str(dir_path)
+
+    if suffix is not None:
+        if not isinstance(suffix, (str, tuple)):
+            raise TypeError('"suffix" must be a string or tuple of strings')
+        if not case_sensitive:
+            if isinstance(suffix, str):
+                suffix = suffix.lower()
+            else:
+                suffix = tuple(item.lower() for item in suffix)
+
+    root = dir_path
+
+    def _walk(current):
+        for entry in os.scandir(current):
+            # Entries whose name starts with '.' are never yielded as files.
+            if not entry.name.startswith('.') and entry.is_file():
+                rel_path = osp.relpath(entry.path, root)
+                probe = rel_path if case_sensitive else rel_path.lower()
+                if suffix is None or probe.endswith(suffix):
+                    yield rel_path
+            elif recursive and os.path.isdir(entry.path):
+                yield from _walk(entry.path)
+
+    return _walk(dir_path)
+def find_vcs_root(path, markers=('.git', )):
+    """Walk upwards from ``path`` until a directory holding a marker is found.
+
+    Args:
+        path (str): Path of a directory or a file; for a file the search
+            starts at its containing directory.
+        markers (list[str], optional): File or directory names to look for.
+
+    Returns:
+        The first directory (starting with ``path`` itself) that contains
+        one of the markers, or None if the filesystem root is reached.
+    """
+    start = osp.dirname(path) if osp.isfile(path) else path
+    current = osp.abspath(osp.expanduser(start))
+    previous = None
+    # osp.split() returns the same path once the root is reached, which
+    # terminates the loop.
+    while current != previous:
+        for marker in markers:
+            if osp.exists(osp.join(current, marker)):
+                return current
+        previous, current = current, osp.split(current)[0]
+    return None
diff --git a/ControlNet/annotator/uniformer/mmcv/utils/progressbar.py b/ControlNet/annotator/uniformer/mmcv/utils/progressbar.py
new file mode 100644
index 0000000000000000000000000000000000000000..0062f670dd94fa9da559ab26ef85517dcf5211c7
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/utils/progressbar.py
@@ -0,0 +1,208 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import sys
+from collections.abc import Iterable
+from multiprocessing import Pool
+from shutil import get_terminal_size
+from .timer import Timer
+class ProgressBar:
+ """A progress bar which can print the progress."""
+ # task_num: total number of tasks; 0 selects the counter-only output.
+ # bar_width: maximum width of the drawn bar, in characters.
+ # start: whether to print the initial line and start the timer right away.
+ # file: stream the bar is written to.
+ def __init__(self, task_num=0, bar_width=50, start=True, file=sys.stdout):
+ self.task_num = task_num
+ self.bar_width = bar_width
+ self.completed = 0
+ self.file = file
+ if start:
+ self.start()
+ @property
+ def terminal_width(self):
+ # Current terminal width in columns, queried on every access.
+ width, _ = get_terminal_size()
+ return width
+ def start(self):
+ # Print the initial (empty) status line and create the timer used by
+ # update() to compute elapsed time.
+ if self.task_num > 0:
+ self.file.write(f'[{" " * self.bar_width}] 0/{self.task_num}, '
+ 'elapsed: 0s, ETA:')
+ else:
+ self.file.write('completed: 0, elapsed: 0s')
+ self.file.flush()
+ self.timer = Timer()
+ def update(self, num_tasks=1):
+ # Record num_tasks more finished tasks and redraw the status line.
+ assert num_tasks > 0
+ self.completed += num_tasks
+ elapsed = self.timer.since_start()
+ if elapsed > 0:
+ fps = self.completed / elapsed
+ else:
+ # No measurable time has passed yet; avoid dividing by zero.
+ fps = float('inf')
+ if self.task_num > 0:
+ percentage = self.completed / float(self.task_num)
+ eta = int(elapsed * (1 - percentage) / percentage + 0.5)
+ # '{{}}' leaves a literal '{}' placeholder that is filled with the
+ # bar characters by msg.format() below.
+ msg = f'\r[{{}}] {self.completed}/{self.task_num}, ' \
+ f'{fps:.1f} task/s, elapsed: {int(elapsed + 0.5)}s, ' \
+ f'ETA: {eta:5}s'
+ # Shrink the bar so the whole line fits the terminal: the '+ 2'
+ # compensates for the two placeholder characters counted in len(msg).
+ bar_width = min(self.bar_width,
+ int(self.terminal_width - len(msg)) + 2,
+ int(self.terminal_width * 0.6))
+ bar_width = max(2, bar_width)
+ mark_width = int(bar_width * percentage)
+ bar_chars = '>' * mark_width + ' ' * (bar_width - mark_width)
+ self.file.write(msg.format(bar_chars))
+ else:
+ self.file.write(
+ f'completed: {self.completed}, elapsed: {int(elapsed + 0.5)}s,'
+ f' {fps:.1f} tasks/s')
+ self.file.flush()
+def track_progress(func, tasks, bar_width=50, file=sys.stdout, **kwargs):
+    """Apply ``func`` to every task sequentially while drawing a progress bar.
+
+    Args:
+        func (callable): The function to be applied to each task.
+        tasks (list or tuple[Iterable, int]): A list of tasks or
+            (tasks, total num).
+        bar_width (int): Width of progress bar.
+        file: Stream the progress bar is written to.
+        **kwargs: Extra keyword arguments passed to every ``func`` call.
+
+    Returns:
+        list: The task results.
+    """
+    if isinstance(tasks, tuple):
+        assert len(tasks) == 2
+        assert isinstance(tasks[0], Iterable)
+        assert isinstance(tasks[1], int)
+        tasks, task_num = tasks
+    elif isinstance(tasks, Iterable):
+        task_num = len(tasks)
+    else:
+        raise TypeError(
+            '"tasks" must be an iterable object or a (iterator, int) tuple')
+
+    prog_bar = ProgressBar(task_num, bar_width, file=file)
+    results = []
+    for item in tasks:
+        outcome = func(item, **kwargs)
+        results.append(outcome)
+        prog_bar.update()
+    prog_bar.file.write('\n')
+    return results
+def init_pool(process_num, initializer=None, initargs=None):
+    """Create a ``multiprocessing.Pool`` with optional worker initializer.
+
+    Args:
+        process_num (int): Number of worker processes.
+        initializer (None or callable): Called in each worker at start-up.
+        initargs (None or tuple): Arguments for ``initializer``.
+
+    Raises:
+        TypeError: If ``initializer`` is given and ``initargs`` is neither
+            None nor a tuple.
+    """
+    if initializer is None:
+        return Pool(process_num)
+    if initargs is None:
+        return Pool(process_num, initializer)
+    if not isinstance(initargs, tuple):
+        raise TypeError('"initargs" must be a tuple')
+    return Pool(process_num, initializer, initargs)
+def track_parallel_progress(func,
+                            tasks,
+                            nproc,
+                            initializer=None,
+                            initargs=None,
+                            bar_width=50,
+                            chunksize=1,
+                            skip_first=False,
+                            keep_order=True,
+                            file=sys.stdout):
+    """Track the progress of parallel task execution with a progress bar.
+
+    The built-in :mod:`multiprocessing` module is used for process pools and
+    tasks are done with :func:`Pool.imap` or :func:`Pool.imap_unordered`.
+
+    Args:
+        func (callable): The function to be applied to each task.
+        tasks (list or tuple[Iterable, int]): A list of tasks or
+            (tasks, total num).
+        nproc (int): Process (worker) number.
+        initializer (None or callable): Refer to :class:`multiprocessing.Pool`
+            for details.
+        initargs (None or tuple): Refer to :class:`multiprocessing.Pool` for
+            details.
+        chunksize (int): Refer to :class:`multiprocessing.Pool` for details.
+        bar_width (int): Width of progress bar.
+        skip_first (bool): Whether to skip the first sample for each worker
+            when estimating fps, since the initialization step may take
+            longer.
+        keep_order (bool): If True, :func:`Pool.imap` is used, otherwise
+            :func:`Pool.imap_unordered` is used.
+        file: Stream the progress bar is written to.
+
+    Returns:
+        list: The task results.
+    """
+    if isinstance(tasks, tuple):
+        assert len(tasks) == 2
+        assert isinstance(tasks[0], Iterable)
+        assert isinstance(tasks[1], int)
+        task_num = tasks[1]
+        tasks = tasks[0]
+    elif isinstance(tasks, Iterable):
+        task_num = len(tasks)
+    else:
+        raise TypeError(
+            '"tasks" must be an iterable object or a (iterator, int) tuple')
+    pool = init_pool(nproc, initializer, initargs)
+    try:
+        start = not skip_first
+        # The first nproc * chunksize results are excluded from the bar when
+        # skip_first is set, so they are not counted in the total either.
+        task_num -= nproc * chunksize * int(skip_first)
+        prog_bar = ProgressBar(task_num, bar_width, start, file=file)
+        results = []
+        if keep_order:
+            gen = pool.imap(func, tasks, chunksize)
+        else:
+            gen = pool.imap_unordered(func, tasks, chunksize)
+        for result in gen:
+            results.append(result)
+            if skip_first:
+                if len(results) < nproc * chunksize:
+                    continue
+                elif len(results) == nproc * chunksize:
+                    # Warm-up finished: start timing from here.
+                    prog_bar.start()
+                    continue
+            prog_bar.update()
+        prog_bar.file.write('\n')
+        pool.close()
+    except BaseException:
+        # A task or the progress output failed: stop the workers instead of
+        # leaving the pool open, then let the error propagate.
+        pool.terminate()
+        raise
+    finally:
+        pool.join()
+    return results
+def track_iter_progress(tasks, bar_width=50, file=sys.stdout):
+    """Yield tasks one by one while drawing a progress bar.
+
+    The bar advances after the consumer has finished with each yielded task.
+
+    Args:
+        tasks (list or tuple[Iterable, int]): A list of tasks or
+            (tasks, total num).
+        bar_width (int): Width of progress bar.
+        file: Stream the progress bar is written to.
+
+    Yields:
+        The tasks, in iteration order.
+    """
+    if isinstance(tasks, tuple):
+        assert len(tasks) == 2
+        assert isinstance(tasks[0], Iterable)
+        assert isinstance(tasks[1], int)
+        tasks, task_num = tasks
+    elif isinstance(tasks, Iterable):
+        task_num = len(tasks)
+    else:
+        raise TypeError(
+            '"tasks" must be an iterable object or a (iterator, int) tuple')
+
+    prog_bar = ProgressBar(task_num, bar_width, file=file)
+    for item in tasks:
+        yield item
+        prog_bar.update()
+    prog_bar.file.write('\n')
diff --git a/ControlNet/annotator/uniformer/mmcv/utils/registry.py b/ControlNet/annotator/uniformer/mmcv/utils/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa9df39bc9f3d8d568361e7250ab35468f2b74e0
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/utils/registry.py
@@ -0,0 +1,315 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import inspect
+import warnings
+from functools import partial
+from .misc import is_seq_of
+def build_from_cfg(cfg, registry, default_args=None):
+    """Instantiate the class described by a config dict.
+
+    Args:
+        cfg (dict): Config dict. It should at least contain the key "type".
+        registry (:obj:`Registry`): The registry to search the type from.
+        default_args (dict, optional): Default initialization arguments;
+            they never override keys already present in ``cfg``.
+
+    Returns:
+        object: The constructed object.
+
+    Raises:
+        TypeError: If ``cfg``, ``registry``, ``default_args`` or the "type"
+            value has an unsupported type.
+        KeyError: If no "type" is given or it is not found in ``registry``.
+    """
+    if not isinstance(cfg, dict):
+        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
+    type_given = 'type' in cfg or (default_args is not None
+                                   and 'type' in default_args)
+    if not type_given:
+        raise KeyError(
+            '`cfg` or `default_args` must contain the key "type", '
+            f'but got {cfg}\n{default_args}')
+    if not isinstance(registry, Registry):
+        raise TypeError('registry must be an mmcv.Registry object, '
+                        f'but got {type(registry)}')
+    if default_args is not None and not isinstance(default_args, dict):
+        raise TypeError('default_args must be a dict or None, '
+                        f'but got {type(default_args)}')
+
+    # Work on a copy so the caller's cfg is left untouched.
+    args = cfg.copy()
+    if default_args is not None:
+        for name, value in default_args.items():
+            args.setdefault(name, value)
+
+    obj_type = args.pop('type')
+    if isinstance(obj_type, str):
+        obj_cls = registry.get(obj_type)
+        if obj_cls is None:
+            raise KeyError(
+                f'{obj_type} is not in the {registry.name} registry')
+    elif inspect.isclass(obj_type):
+        obj_cls = obj_type
+    else:
+        raise TypeError(
+            f'type must be a str or valid type, but got {type(obj_type)}')
+
+    try:
+        instance = obj_cls(**args)
+    except Exception as e:
+        # Normal TypeError does not print class name.
+        raise type(e)(f'{obj_cls.__name__}: {e}')
+    return instance
+class Registry:
+    """A registry to map strings to classes.
+
+    Registered object could be built from registry.
+
+    Example:
+        >>> MODELS = Registry('models')
+        >>> @MODELS.register_module()
+        >>> class ResNet:
+        >>>     pass
+        >>> resnet = MODELS.build(dict(type='ResNet'))
+
+    Please refer to
+    https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for
+    advanced usage.
+
+    Args:
+        name (str): Registry name.
+        build_func(func, optional): Build function to construct instance from
+            Registry, func:`build_from_cfg` is used if neither ``parent`` or
+            ``build_func`` is specified. If ``parent`` is specified and
+            ``build_func`` is not given,  ``build_func`` will be inherited
+            from ``parent``. Default: None.
+        parent (Registry, optional): Parent registry. The class registered in
+            children registry could be built from parent. Default: None.
+        scope (str, optional): The scope of registry. It is the key to search
+            for children registry. If not specified, scope will be the name of
+            the package where class is defined, e.g. mmdet, mmcls, mmseg.
+            Default: None.
+    """
+
+    def __init__(self, name, build_func=None, parent=None, scope=None):
+        self._name = name
+        self._module_dict = dict()
+        self._children = dict()
+        self._scope = self.infer_scope() if scope is None else scope
+
+        # self.build_func will be set with the following priority:
+        # 1. build_func
+        # 2. parent.build_func
+        # 3. build_from_cfg
+        if build_func is None:
+            if parent is not None:
+                self.build_func = parent.build_func
+            else:
+                self.build_func = build_from_cfg
+        else:
+            self.build_func = build_func
+        if parent is not None:
+            assert isinstance(parent, Registry)
+            parent._add_children(self)
+            self.parent = parent
+        else:
+            self.parent = None
+
+    def __len__(self):
+        return len(self._module_dict)
+
+    def __contains__(self, key):
+        return self.get(key) is not None
+
+    def __repr__(self):
+        format_str = self.__class__.__name__ + \
+                     f'(name={self._name}, ' \
+                     f'items={self._module_dict})'
+        return format_str
+
+    @staticmethod
+    def infer_scope():
+        """Infer the scope of registry.
+
+        The name of the package where registry is defined will be returned.
+
+        Example:
+            # in mmdet/models/backbone/resnet.py
+            >>> MODELS = Registry('models')
+            >>> @MODELS.register_module()
+            >>> class ResNet:
+            >>>     pass
+            The scope of ``ResNet`` will be ``mmdet``.
+
+        Returns:
+            scope (str): The inferred scope name.
+        """
+        # inspect.stack() trace where this function is called, the index-2
+        # indicates the frame where `infer_scope()` is called
+        filename = inspect.getmodule(inspect.stack()[2][0]).__name__
+        split_filename = filename.split('.')
+        return split_filename[0]
+
+    @staticmethod
+    def split_scope_key(key):
+        """Split scope and key.
+
+        The first scope will be split from key.
+
+        Examples:
+            >>> Registry.split_scope_key('mmdet.ResNet')
+            'mmdet', 'ResNet'
+            >>> Registry.split_scope_key('ResNet')
+            None, 'ResNet'
+
+        Return:
+            scope (str, None): The first scope.
+            key (str): The remaining key.
+        """
+        split_index = key.find('.')
+        if split_index != -1:
+            return key[:split_index], key[split_index + 1:]
+        else:
+            return None, key
+
+    @property
+    def name(self):
+        return self._name
+
+    @property
+    def scope(self):
+        return self._scope
+
+    @property
+    def module_dict(self):
+        return self._module_dict
+
+    @property
+    def children(self):
+        return self._children
+
+    def get(self, key):
+        """Get the registry record.
+
+        Args:
+            key (str): The class name in string format, optionally prefixed
+                with a scope (``'scope.ClassName'``).
+
+        Returns:
+            class: The corresponding class, or None if it is not registered.
+        """
+        scope, real_key = self.split_scope_key(key)
+        if scope is None or scope == self._scope:
+            # get from self
+            if real_key in self._module_dict:
+                return self._module_dict[real_key]
+        else:
+            # get from self._children
+            if scope in self._children:
+                return self._children[scope].get(real_key)
+            else:
+                # goto root
+                parent = self.parent
+                if parent is None:
+                    # This registry already is the root and none of its
+                    # children owns the scope, so the key is unknown.
+                    return None
+                while parent.parent is not None:
+                    parent = parent.parent
+                return parent.get(key)
+
+    def build(self, *args, **kwargs):
+        """Build an object by calling ``build_func`` with ``registry=self``."""
+        return self.build_func(*args, **kwargs, registry=self)
+
+    def _add_children(self, registry):
+        """Add children for a registry.
+
+        The ``registry`` will be added as children based on its scope.
+        The parent registry could build objects from children registry.
+
+        Example:
+            >>> models = Registry('models')
+            >>> mmdet_models = Registry('models', parent=models)
+            >>> @mmdet_models.register_module()
+            >>> class ResNet:
+            >>>     pass
+            >>> resnet = models.build(dict(type='mmdet.ResNet'))
+        """
+
+        assert isinstance(registry, Registry)
+        assert registry.scope is not None
+        assert registry.scope not in self.children, \
+            f'scope {registry.scope} exists in {self.name} registry'
+        self.children[registry.scope] = registry
+
+    def _register_module(self, module_class, module_name=None, force=False):
+        """Store ``module_class`` under one or several names.
+
+        Raises:
+            TypeError: If ``module_class`` is not a class.
+            KeyError: If a name is already taken and ``force`` is False.
+        """
+        if not inspect.isclass(module_class):
+            raise TypeError('module must be a class, '
+                            f'but got {type(module_class)}')
+
+        if module_name is None:
+            module_name = module_class.__name__
+        if isinstance(module_name, str):
+            module_name = [module_name]
+        for name in module_name:
+            if not force and name in self._module_dict:
+                raise KeyError(f'{name} is already registered '
+                               f'in {self.name}')
+            self._module_dict[name] = module_class
+
+    def deprecated_register_module(self, cls=None, force=False):
+        """Old-style ``register_module(cls, force)`` entry point.
+
+        Emits a warning on every call, then registers ``cls`` under its
+        class name (or returns a decorator when ``cls`` is None).
+        """
+        warnings.warn(
+            'The old API of register_module(module, force=False) '
+            'is deprecated and will be removed, please use the new API '
+            'register_module(name=None, force=False, module=None) instead.')
+        if cls is None:
+            return partial(self.deprecated_register_module, force=force)
+        self._register_module(cls, force=force)
+        return cls
+
+    def register_module(self, name=None, force=False, module=None):
+        """Register a module.
+
+        A record will be added to `self._module_dict`, whose key is the class
+        name or the specified name, and value is the class itself.
+        It can be used as a decorator or a normal function.
+
+        Example:
+            >>> backbones = Registry('backbone')
+            >>> @backbones.register_module()
+            >>> class ResNet:
+            >>>     pass
+
+            >>> backbones = Registry('backbone')
+            >>> @backbones.register_module(name='mnet')
+            >>> class MobileNet:
+            >>>     pass
+
+            >>> backbones = Registry('backbone')
+            >>> class ResNet:
+            >>>     pass
+            >>> backbones.register_module(ResNet)
+
+        Args:
+            name (str | None): The module name to be registered. If not
+                specified, the class name will be used.
+            force (bool, optional): Whether to override an existing class with
+                the same name. Default: False.
+            module (type): Module class to be registered.
+        """
+        if not isinstance(force, bool):
+            raise TypeError(f'force must be a boolean, but got {type(force)}')
+        # NOTE: This is a workaround to be compatible with the old api,
+        # while it may introduce unexpected bugs.
+        if isinstance(name, type):
+            return self.deprecated_register_module(name, force=force)
+
+        # raise the error ahead of time
+        if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
+            raise TypeError(
+                'name must be either of None, an instance of str or a sequence'
+                f'  of str, but got {type(name)}')
+
+        # use it as a normal method: x.register_module(module=SomeClass)
+        if module is not None:
+            self._register_module(
+                module_class=module, module_name=name, force=force)
+            return module
+
+        # use it as a decorator: @x.register_module()
+        def _register(cls):
+            self._register_module(
+                module_class=cls, module_name=name, force=force)
+            return cls
+
+        return _register
diff --git a/ControlNet/annotator/uniformer/mmcv/utils/testing.py b/ControlNet/annotator/uniformer/mmcv/utils/testing.py
new file mode 100644
index 0000000000000000000000000000000000000000..a27f936da8ec14bac18562ede0a79d476d82f797
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/utils/testing.py
@@ -0,0 +1,140 @@
+# Copyright (c) Open-MMLab.
+import sys
+from collections.abc import Iterable
+from runpy import run_path
+from shlex import split
+from typing import Any, Dict, List
+from unittest.mock import patch
+def check_python_script(cmd):
+    """Run a python command line in the current process as ``__main__``.
+
+    Unlike ``os.system``, the script executes in this process, so coverage
+    tools can track it. Two forms are supported:
+
+    - ./tests/data/scripts/hello.py zz
+    - python tests/data/scripts/hello.py zz
+    """
+    argv = split(cmd)
+    if argv[0] == 'python':
+        # Drop the interpreter name so argv[0] is the script path.
+        argv = argv[1:]
+    script = argv[0]
+    with patch.object(sys, 'argv', argv):
+        run_path(script, run_name='__main__')
+def _any(judge_result):
+    """Recursive variant of ``any`` for arbitrarily nested iterables.
+
+    A non-iterable argument is returned unchanged; otherwise the result is
+    True as soon as any (nested) element is truthy.
+    """
+    if not isinstance(judge_result, Iterable):
+        return judge_result
+
+    try:
+        return any(_any(element) for element in judge_result)
+    except TypeError:
+        # Maybe encounter the case: torch.tensor(True) | torch.tensor(False)
+        return True if judge_result else False
+def assert_dict_contains_subset(dict_obj: Dict[Any, Any],
+                                expected_subset: Dict[Any, Any]) -> bool:
+    """Check that every expected key/value pair is present in ``dict_obj``.
+
+    Args:
+        dict_obj (Dict[Any, Any]): Dict object to be checked.
+        expected_subset (Dict[Any, Any]): Pairs expected to be contained in
+            dict_obj.
+
+    Returns:
+        bool: False as soon as a key is missing or a value differs, else True.
+    """
+    for key, expected in expected_subset.items():
+        if key not in dict_obj.keys():
+            return False
+        if _any(dict_obj[key] != expected):
+            return False
+    return True
+def assert_attrs_equal(obj: Any, expected_attrs: Dict[str, Any]) -> bool:
+    """Check that ``obj`` carries all expected attributes with equal values.
+
+    Args:
+        obj (object): Object to be checked.
+        expected_attrs (Dict[str, Any]): Attribute names mapped to the
+            values they are expected to have.
+
+    Returns:
+        bool: False as soon as an attribute is missing or differs, else True.
+    """
+    for attr, expected in expected_attrs.items():
+        if not hasattr(obj, attr):
+            return False
+        if _any(getattr(obj, attr) != expected):
+            return False
+    return True
+def assert_dict_has_keys(obj: Dict[str, Any],
+                         expected_keys: List[str]) -> bool:
+    """Check that ``obj`` has all of ``expected_keys``.
+
+    Args:
+        obj (Dict[str, Any]): Dict to be checked.
+        expected_keys (List[str]): Keys expected to be among the keys of
+            the obj.
+
+    Returns:
+        bool: Whether the obj has the expected keys.
+    """
+    wanted = set(expected_keys)
+    present = set(obj.keys())
+    return wanted.issubset(present)
+def assert_keys_equal(result_keys: List[str], target_keys: List[str]) -> bool:
+    """Compare two key collections as sets (order and duplicates ignored).
+
+    Args:
+        result_keys (List[str]): Result keys to be checked.
+        target_keys (List[str]): Target keys to compare against.
+
+    Returns:
+        bool: Whether both collections contain the same distinct keys.
+    """
+    actual = set(result_keys)
+    expected = set(target_keys)
+    return actual == expected
+def assert_is_norm_layer(module) -> bool:
+    """Check if the module is a norm layer.
+
+    Batch norm, instance norm, group norm and layer norm modules count as
+    norm layers.
+
+    Args:
+        module (nn.Module): The module to be checked.
+
+    Returns:
+        bool: Whether the module is a norm layer.
+    """
+    from .parrots_wrapper import _BatchNorm, _InstanceNorm
+    from torch.nn import GroupNorm, LayerNorm
+    return isinstance(module,
+                      (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))
+def assert_params_all_zeros(module) -> bool:
+    """Check if the parameters of the module are all zeros.
+
+    The weight is always checked; the bias is checked only when the module
+    has one that is not None.
+
+    Args:
+        module (nn.Module): The module to be checked.
+
+    Returns:
+        bool: Whether weight and (if present) bias are all close to zero.
+    """
+    weight = module.weight.data
+    weight_is_zero = weight.allclose(weight.new_zeros(weight.size()))
+
+    bias_is_zero = True
+    if hasattr(module, 'bias') and module.bias is not None:
+        bias = module.bias.data
+        bias_is_zero = bias.allclose(bias.new_zeros(bias.size()))
+
+    return weight_is_zero and bias_is_zero
diff --git a/ControlNet/annotator/uniformer/mmcv/utils/timer.py b/ControlNet/annotator/uniformer/mmcv/utils/timer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3db7d497d8b374e18b5297e0a1d6eb186fd8cba
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/utils/timer.py
@@ -0,0 +1,118 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from time import time
+class TimerError(Exception):
+    """Error raised for invalid timer usage; keeps the text in ``message``."""
+
+    def __init__(self, message):
+        super().__init__(message)
+        self.message = message
+class Timer:
+    """A flexible Timer class.
+
+    :Example:
+
+    >>> import time
+    >>> import annotator.uniformer.mmcv as mmcv
+    >>> with mmcv.Timer():
+    >>>     # simulate a code block that will run for 1s
+    >>>     time.sleep(1)
+    1.000
+    >>> with mmcv.Timer(print_tmpl='it takes {:.1f} seconds'):
+    >>>     # simulate a code block that will run for 1s
+    >>>     time.sleep(1)
+    it takes 1.0 seconds
+    >>> timer = mmcv.Timer()
+    >>> time.sleep(0.5)
+    >>> print(timer.since_start())
+    0.500
+    >>> time.sleep(0.5)
+    >>> print(timer.since_last_check())
+    0.500
+    >>> print(timer.since_start())
+    1.000
+
+    Args:
+        start (bool): Start the timer on construction. Default: True.
+        print_tmpl (str, optional): Format template used when the timer is
+            used as a context manager. Default: ``'{:.3f}'``.
+    """
+
+    def __init__(self, start=True, print_tmpl=None):
+        self._is_running = False
+        self.print_tmpl = print_tmpl if print_tmpl else '{:.3f}'
+        if start:
+            self.start()
+
+    @property
+    def is_running(self):
+        """bool: indicate whether the timer is running"""
+        return self._is_running
+
+    def __enter__(self):
+        self.start()
+        return self
+
+    def __exit__(self, type, value, traceback):
+        # Print the time since the last check, then stop the timer.
+        print(self.print_tmpl.format(self.since_last_check()))
+        self._is_running = False
+
+    def start(self):
+        """Start the timer.
+
+        The start time is only set if the timer is not running yet; the
+        last-check time is reset on every call.
+        """
+        if not self._is_running:
+            self._t_start = time()
+            self._is_running = True
+        self._t_last = time()
+
+    def since_start(self):
+        """Total time since the timer is started.
+
+        Returns (float): Time in seconds.
+
+        Raises:
+            TimerError: If the timer is not running.
+        """
+        if not self._is_running:
+            raise TimerError('timer is not running')
+        self._t_last = time()
+        return self._t_last - self._t_start
+
+    def since_last_check(self):
+        """Time since the last checking.
+
+        Either :func:`since_start` or :func:`since_last_check` is a checking
+        operation.
+
+        Returns (float): Time in seconds.
+
+        Raises:
+            TimerError: If the timer is not running.
+        """
+        if not self._is_running:
+            raise TimerError('timer is not running')
+        # Read the clock once so the new check point is exactly the end of
+        # the interval just reported; two reads would drop the time between
+        # them from every interval.
+        now = time()
+        dur = now - self._t_last
+        self._t_last = now
+        return dur
+_g_timers = {}  # global timers
+
+
+def check_time(timer_id):
+    """Add check points in a single line.
+
+    This method is suitable for running a task on a list of items. A timer
+    is registered under ``timer_id`` on the first call, which returns 0;
+    later calls return the time since the previous check of that timer.
+
+    :Example:
+
+    >>> import time
+    >>> import annotator.uniformer.mmcv as mmcv
+    >>> for i in range(1, 6):
+    >>>     # simulate a code block
+    >>>     time.sleep(i)
+    >>>     mmcv.check_time('task1')
+    2.000
+    3.000
+    4.000
+    5.000
+
+    Args:
+        timer_id (str): Timer identifier.
+    """
+    if timer_id in _g_timers:
+        return _g_timers[timer_id].since_last_check()
+    _g_timers[timer_id] = Timer()
+    return 0
diff --git a/ControlNet/annotator/uniformer/mmcv/utils/trace.py b/ControlNet/annotator/uniformer/mmcv/utils/trace.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ca99dc3eda05ef980d9a4249b50deca8273b6cc
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/utils/trace.py
@@ -0,0 +1,23 @@
+import warnings
+import torch
+from annotator.uniformer.mmcv.utils import digit_version
+def is_jit_tracing() -> bool:
+    """Tell whether the code currently runs under ``torch.jit`` tracing.
+
+    Returns:
+        bool: The tracing state for PyTorch >= 1.6.0; always False (with a
+        ``UserWarning``) on parrots or older PyTorch versions.
+    """
+    supported = (
+        torch.__version__ != 'parrots'
+        and digit_version(torch.__version__) >= digit_version('1.6.0'))
+    if not supported:
+        warnings.warn(
+            'torch.jit.is_tracing is only supported after v1.6.0. '
+            'Therefore is_tracing returns False automatically. Please '
+            'set on_trace manually if you are using trace.', UserWarning)
+        return False
+
+    on_trace = torch.jit.is_tracing()
+    # In PyTorch 1.6, torch.jit.is_tracing has a bug.
+    # Refers to https://github.com/pytorch/pytorch/issues/42448
+    if isinstance(on_trace, bool):
+        return on_trace
+    return torch._C._is_tracing()
diff --git a/ControlNet/annotator/uniformer/mmcv/utils/version_utils.py b/ControlNet/annotator/uniformer/mmcv/utils/version_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..963c45a2e8a86a88413ab6c18c22481fb9831985
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/utils/version_utils.py
@@ -0,0 +1,90 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import subprocess
+import warnings
+from packaging.version import parse
+def digit_version(version_str: str, length: int = 4):
+    """Convert a version string into a tuple of integers.
+
+    This method is usually used for comparing two versions. For pre-release
+    versions: alpha < beta < rc.
+
+    Args:
+        version_str (str): The version string.
+        length (int): The maximum number of version levels. Default: 4.
+
+    Returns:
+        tuple[int]: ``length`` release numbers followed by two integers that
+        encode the pre-/post-release state.
+    """
+    assert 'parrots' not in version_str
+    version = parse(version_str)
+    assert version.release, f'failed to parse version {version_str}'
+
+    # Truncate or zero-pad the release part to exactly `length` entries.
+    release = list(version.release)[:length]
+    if len(release) < length:
+        release = release + [0] * (length - len(release))
+
+    if version.is_prerelease:
+        mapping = {'a': -3, 'b': -2, 'rc': -1}
+        val = -4
+        pre = version.pre
+        # version.pre can be None
+        if pre:
+            tag = pre[0]
+            if tag in mapping:
+                val = mapping[tag]
+            else:
+                warnings.warn(f'unknown prerelease version {version.pre[0]}, '
+                              'version checking may go wrong')
+            tail = [val, pre[-1]]
+        else:
+            tail = [val, 0]
+    elif version.is_postrelease:
+        tail = [1, version.post]
+    else:
+        tail = [0, 0]
+    release.extend(tail)
+    return tuple(release)
+def _minimal_ext_cmd(cmd):
+    """Run ``cmd`` in a minimal, C-locale environment and return its stdout.
+
+    Only ``SYSTEMROOT``, ``PATH`` and ``HOME`` are inherited from the current
+    environment (when set).
+
+    Args:
+        cmd (list[str]): Command and arguments passed to ``subprocess.Popen``.
+
+    Returns:
+        bytes: Everything the command wrote to standard output.
+    """
+    # construct minimal environment
+    env = {}
+    for key in ('SYSTEMROOT', 'PATH', 'HOME'):
+        inherited = os.environ.get(key)
+        if inherited is not None:
+            env[key] = inherited
+    # LANGUAGE is used on win32
+    env.update({'LANGUAGE': 'C', 'LANG': 'C', 'LC_ALL': 'C'})
+    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env)
+    stdout_data, _ = proc.communicate()
+    return stdout_data
+def get_git_hash(fallback='unknown', digits=None):
+    """Get the git hash of the current repo.
+
+    Args:
+        fallback (str, optional): The fallback string when git hash is
+            unavailable. Defaults to 'unknown'.
+        digits (int, optional): kept digits of the hash. Defaults to None,
+            meaning all digits are kept.
+
+    Returns:
+        str: Git commit hash, or ``fallback`` when git cannot be run or
+        prints no hash.
+
+    Raises:
+        TypeError: If ``digits`` is neither None nor an integer.
+    """
+
+    if digits is not None and not isinstance(digits, int):
+        raise TypeError('digits must be None or an integer')
+
+    try:
+        out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
+    except OSError:
+        # git executable missing or not runnable
+        return fallback
+
+    sha = out.strip().decode('ascii')
+    if not sha:
+        # git ran but printed nothing on stdout (e.g. not inside a
+        # repository); report the fallback instead of an empty string.
+        return fallback
+    if digits is not None:
+        sha = sha[:digits]
+    return sha
diff --git a/ControlNet/annotator/uniformer/mmcv/version.py b/ControlNet/annotator/uniformer/mmcv/version.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cce4e50bd692d4002e3cac3c545a3fb2efe95d0
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/version.py
@@ -0,0 +1,35 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+__version__ = '1.3.17'
def parse_version_info(version_str: str, length: int = 4) -> tuple:
    """Parse a version string into a tuple.

    The release part is truncated or zero-padded to ``length`` items and
    followed by two more items describing the release kind:

    - pre-release: the ``(letter, number)`` pair, e.g. ``('rc', 1)``
    - development release without a pre-release tag: ``('dev', number)``
    - post-release: ``('post', number)``
    - final release: ``(0, 0)``

    Args:
        version_str (str): The version string.
        length (int): The maximum number of version levels. Default: 4.

    Returns:
        tuple[int | str]: The version info, e.g., "1.3.0" is parsed into
            (1, 3, 0, 0, 0, 0), and "2.0.0rc1" is parsed into
            (2, 0, 0, 0, 'rc', 1) (when length is set to 4).
    """
    from packaging.version import parse
    version = parse(version_str)
    assert version.release, f'failed to parse version {version_str}'
    release = list(version.release[:length])
    release += [0] * (length - len(release))
    if version.pre is not None:
        release.extend(version.pre)
    elif version.dev is not None:
        # ``is_prerelease`` is also true for dev releases, where ``pre`` is
        # None and cannot be turned into a list.
        release.extend(['dev', version.dev])
    elif version.is_postrelease:
        # ``post`` is a plain integer, so it has to be wrapped explicitly.
        release.extend(['post', version.post])
    else:
        release.extend([0, 0])
    return tuple(release)
# First three dot-separated fields of ``__version__`` as integers.
version_info = tuple(map(int, __version__.split('.')[:3]))

__all__ = ['__version__', 'version_info', 'parse_version_info']
diff --git a/ControlNet/annotator/uniformer/mmcv/video/__init__.py b/ControlNet/annotator/uniformer/mmcv/video/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73199b01dec52820dc6ca0139903536344d5a1eb
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/video/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .io import Cache, VideoReader, frames2video
+from .optflow import (dequantize_flow, flow_from_bytes, flow_warp, flowread,
+ flowwrite, quantize_flow, sparse_flow_from_bytes)
+from .processing import concat_video, convert_video, cut_video, resize_video
+__all__ = [
+ 'Cache', 'VideoReader', 'frames2video', 'convert_video', 'resize_video',
+ 'cut_video', 'concat_video', 'flowread', 'flowwrite', 'quantize_flow',
+ 'dequantize_flow', 'flow_warp', 'flow_from_bytes', 'sparse_flow_from_bytes'
diff --git a/ControlNet/annotator/uniformer/mmcv/video/io.py b/ControlNet/annotator/uniformer/mmcv/video/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..9879154227f640c262853b92c219461c6f67ee8e
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/video/io.py
@@ -0,0 +1,318 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from collections import OrderedDict
+import cv2
+ CAP_PROP_POS_FRAMES, VideoWriter_fourcc)
+from annotator.uniformer.mmcv.utils import (check_file_exist, mkdir_or_exist, scandir,
+ track_progress)
class Cache:
    """Bounded key/value store that evicts the oldest inserted entry.

    Args:
        capacity (int): Maximum number of entries. The value is converted
            with ``int()`` and the converted value must be positive.

    Raises:
        ValueError: If the integer capacity is not positive.
    """

    def __init__(self, capacity):
        self._cache = OrderedDict()
        self._capacity = int(capacity)
        # Validate the truncated value: e.g. 0.5 becomes 0, which would make
        # ``put`` pop from an empty dict.
        if self._capacity <= 0:
            raise ValueError('capacity must be a positive integer')

    @property
    def capacity(self):
        """int: Maximum number of entries kept."""
        return self._capacity

    @property
    def size(self):
        """int: Number of entries currently stored."""
        return len(self._cache)

    def put(self, key, val):
        """Store ``val`` under ``key``; an existing key is left untouched.

        When the cache is full the entry inserted first is dropped.
        """
        if key in self._cache:
            return
        if len(self._cache) >= self.capacity:
            self._cache.popitem(last=False)
        self._cache[key] = val

    def get(self, key, default=None):
        """Return the value stored for ``key``, or ``default`` if absent."""
        return self._cache.get(key, default)
class VideoReader:
    """Video class with similar usage to a list object.

    This video wrapper class provides convenient apis to access frames.
    There exists an issue of OpenCV's VideoCapture class that jumping to a
    certain frame may be inaccurate. It is fixed in this class by checking
    the position after jumping each time.
    Decoded frames are kept in a cache, so a frame that is requested again
    while still cached does not have to be decoded a second time.

    :Example:

    >>> import annotator.uniformer.mmcv as mmcv
    >>> v = mmcv.VideoReader('sample.mp4')
    >>> len(v)  # get the total frame number with `len()`
    120
    >>> for img in v:  # v is iterable
    >>>     mmcv.imshow(img)
    >>> v[5]  # get the 6th frame
    """

    def __init__(self, filename, cache_capacity=10):
        # Local paths are checked for existence; urls are passed through.
        is_url = filename.startswith(('https://', 'http://'))
        if not is_url:
            check_file_exist(filename, 'Video file not found: ' + filename)
        self._vcap = cv2.VideoCapture(filename)
        assert cache_capacity > 0
        self._cache = Cache(cache_capacity)
        self._position = 0
        # basic stream properties, queried once at construction time
        capture = self._vcap
        self._width = int(capture.get(CAP_PROP_FRAME_WIDTH))
        self._height = int(capture.get(CAP_PROP_FRAME_HEIGHT))
        self._fps = capture.get(CAP_PROP_FPS)
        self._frame_cnt = int(capture.get(CAP_PROP_FRAME_COUNT))
        self._fourcc = capture.get(CAP_PROP_FOURCC)

    @property
    def vcap(self):
        """:obj:`cv2.VideoCapture`: The raw VideoCapture object."""
        return self._vcap

    @property
    def opened(self):
        """bool: Indicate whether the video is opened."""
        return self._vcap.isOpened()

    @property
    def width(self):
        """int: Width of video frames."""
        return self._width

    @property
    def height(self):
        """int: Height of video frames."""
        return self._height

    @property
    def resolution(self):
        """tuple: Video resolution (width, height)."""
        return (self._width, self._height)

    @property
    def fps(self):
        """float: FPS of the video."""
        return self._fps

    @property
    def frame_cnt(self):
        """int: Total frames of the video."""
        return self._frame_cnt

    @property
    def fourcc(self):
        """"Four character code" value as reported by the capture object."""
        return self._fourcc

    @property
    def position(self):
        """int: Current cursor position, indicating frame decoded."""
        return self._position

    def _get_real_position(self):
        """Return the capture object's own frame cursor, rounded to int."""
        return int(round(self._vcap.get(CAP_PROP_POS_FRAMES)))

    def _set_real_position(self, frame_id):
        """Seek the capture object to ``frame_id`` and sync the cursor."""
        self._vcap.set(CAP_PROP_POS_FRAMES, frame_id)
        reached = self._get_real_position()
        # If the seek landed early, decode and drop frames to catch up.
        for _ in range(frame_id - reached):
            self._vcap.read()
        self._position = frame_id

    def read(self):
        """Read the next frame.

        A cached frame is returned directly; otherwise the frame is decoded,
        stored in the cache and returned.

        Returns:
            ndarray or None: Return the frame if successful, otherwise None.
        """
        if not self._cache:
            ok, frame = self._vcap.read()
        else:
            frame = self._cache.get(self._position)
            ok = frame is not None
            if not ok:
                if self._position != self._get_real_position():
                    self._set_real_position(self._position)
                ok, frame = self._vcap.read()
                if ok:
                    self._cache.put(self._position, frame)
        if ok:
            self._position += 1
        return frame

    def get_frame(self, frame_id):
        """Get frame by index.

        Args:
            frame_id (int): Index of the expected frame, 0-based.

        Returns:
            ndarray or None: Return the frame if successful, otherwise None.

        Raises:
            IndexError: If ``frame_id`` is outside ``[0, frame_cnt)``.
        """
        if frame_id < 0 or frame_id >= self._frame_cnt:
            raise IndexError(
                f'"frame_id" must be between 0 and {self._frame_cnt - 1}')
        if frame_id == self._position:
            return self.read()
        if self._cache:
            cached = self._cache.get(frame_id)
            if cached is not None:
                self._position = frame_id + 1
                return cached
        self._set_real_position(frame_id)
        ok, frame = self._vcap.read()
        if ok:
            if self._cache:
                self._cache.put(self._position, frame)
            self._position += 1
        return frame

    def current_frame(self):
        """Get the current frame (frame that is just visited).

        Returns:
            ndarray or None: If the video is fresh, return None, otherwise
            return the cached frame just before the cursor (None when it is
            no longer cached).
        """
        if self._position == 0:
            return None
        return self._cache.get(self._position - 1)

    def cvt2frames(self,
                   frame_dir,
                   file_start=0,
                   filename_tmpl='{:06d}.jpg',
                   start=0,
                   max_num=0,
                   show_progress=True):
        """Convert a video to frame images.

        Args:
            frame_dir (str): Output directory to store all the frame images.
            file_start (int): Filenames will start from the specified number.
            filename_tmpl (str): Filename template with the index as the
                placeholder.
            start (int): The starting frame index.
            max_num (int): Maximum number of frames to be written; 0 means
                all remaining frames.
            show_progress (bool): Whether to show a progress bar.

        Raises:
            ValueError: If no frame is left to write.
        """
        mkdir_or_exist(frame_dir)
        remaining = self.frame_cnt - start
        task_num = remaining if max_num == 0 else min(remaining, max_num)
        if task_num <= 0:
            raise ValueError('start must be less than total frame number')
        if start > 0:
            self._set_real_position(start)

        def write_frame(file_idx):
            frame = self.read()
            if frame is None:
                return
            target = osp.join(frame_dir, filename_tmpl.format(file_idx))
            cv2.imwrite(target, frame)

        file_indices = range(file_start, file_start + task_num)
        if show_progress:
            track_progress(write_frame, file_indices)
        else:
            for file_idx in file_indices:
                write_frame(file_idx)

    def __len__(self):
        return self.frame_cnt

    def __getitem__(self, index):
        if isinstance(index, slice):
            picked = range(*index.indices(self.frame_cnt))
            return [self.get_frame(i) for i in picked]
        # support negative indexing
        if index < 0:
            index += self.frame_cnt
            if index < 0:
                raise IndexError('index out of range')
        return self.get_frame(index)

    def __iter__(self):
        self._set_real_position(0)
        return self

    def __next__(self):
        frame = self.read()
        if frame is None:
            raise StopIteration
        return frame

    next = __next__

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._vcap.release()
def frames2video(frame_dir,
                 video_file,
                 fps=30,
                 fourcc='XVID',
                 filename_tmpl='{:06d}.jpg',
                 start=0,
                 end=0,
                 show_progress=True):
    """Read the frame images from a directory and join them as a video.

    Args:
        frame_dir (str): The directory containing video frames.
        video_file (str): Output filename.
        fps (float): FPS of the output video.
        fourcc (str): Fourcc of the output video, this should be compatible
            with the output file type.
        filename_tmpl (str): Filename template with the index as the variable.
        start (int): Starting frame index.
        end (int): Ending frame index (exclusive). 0 means the number of
            files in ``frame_dir`` that have the template's extension.
        show_progress (bool): Whether to show a progress bar.
    """
    if end == 0:
        ext = filename_tmpl.split('.')[-1]
        end = sum(1 for _ in scandir(frame_dir, ext))
    first_file = osp.join(frame_dir, filename_tmpl.format(start))
    check_file_exist(first_file, 'The start frame not found: ' + first_file)
    # The first frame fixes the resolution of the whole output video.
    img = cv2.imread(first_file)
    height, width = img.shape[:2]
    resolution = (width, height)
    vwriter = cv2.VideoWriter(video_file, VideoWriter_fourcc(*fourcc), fps,
                              resolution)

    def write_frame(file_idx):
        filename = osp.join(frame_dir, filename_tmpl.format(file_idx))
        img = cv2.imread(filename)
        vwriter.write(img)

    # Release the writer even when reading or writing a frame fails, so the
    # output file handle is not leaked.
    try:
        if show_progress:
            track_progress(write_frame, range(start, end))
        else:
            for i in range(start, end):
                write_frame(i)
    finally:
        vwriter.release()
diff --git a/ControlNet/annotator/uniformer/mmcv/video/optflow.py b/ControlNet/annotator/uniformer/mmcv/video/optflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..84160f8d6ef9fceb5a2f89e7481593109fc1905d
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/video/optflow.py
@@ -0,0 +1,254 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+import cv2
+import numpy as np
+from annotator.uniformer.mmcv.arraymisc import dequantize, quantize
+from annotator.uniformer.mmcv.image import imread, imwrite
+from annotator.uniformer.mmcv.utils import is_str
def flowread(flow_or_path, quantize=False, concat_axis=0, *args, **kwargs):
    """Read an optical flow map.

    Args:
        flow_or_path (ndarray or str): A flow map or filepath. An ndarray is
            validated and returned as is.
        quantize (bool): whether to read quantized pair, if set to True,
            remaining args will be passed to :func:`dequantize_flow`.
        concat_axis (int): The axis that dx and dy are concatenated,
            can be either 0 or 1. Ignored if quantize is False.

    Returns:
        ndarray: Optical flow represented as a (h, w, 2) numpy array

    Raises:
        ValueError: If an ndarray input is not 3-D with a last axis of 2.
        TypeError: If the input is neither an ndarray nor a string.
        IOError: If the file content is not a valid flow.
    """
    if isinstance(flow_or_path, np.ndarray):
        if not (flow_or_path.ndim == 3 and flow_or_path.shape[-1] == 2):
            raise ValueError(f'Invalid flow with shape {flow_or_path.shape}')
        return flow_or_path
    if not is_str(flow_or_path):
        raise TypeError(f'"flow_or_path" must be a filename or numpy array, '
                        f'not {type(flow_or_path)}')

    if quantize:
        # dx and dy are stored side by side in one single-channel image
        assert concat_axis in [0, 1]
        cat_flow = imread(flow_or_path, flag='unchanged')
        if cat_flow.ndim != 2:
            raise IOError(
                f'{flow_or_path} is not a valid quantized flow file, '
                f'its dimension is {cat_flow.ndim}.')
        assert cat_flow.shape[concat_axis] % 2 == 0
        dx, dy = np.split(cat_flow, 2, axis=concat_axis)
        return dequantize_flow(dx, dy, *args, **kwargs).astype(np.float32)

    with open(flow_or_path, 'rb') as f:
        try:
            magic = f.read(4).decode('utf-8')
        except Exception:
            raise IOError(f'Invalid flow file: {flow_or_path}')
        if magic != 'PIEH':
            raise IOError(f'Invalid flow file: {flow_or_path}, '
                          'header does not contain PIEH')
        # header is followed by width, height, then the float32 payload
        w = np.fromfile(f, np.int32, 1).squeeze()
        h = np.fromfile(f, np.int32, 1).squeeze()
        flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2))
    return flow.astype(np.float32)
def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
    """Write optical flow to file.

    Without quantization the flow is stored as a 'PIEH' header, the int32
    width and height, and the float32 flow values. With quantization, dx and
    dy are quantized, concatenated along ``concat_axis`` and written as one
    image via ``imwrite``.

    Args:
        flow (ndarray): (h, w, 2) array of optical flow.
        filename (str): Output filepath.
        quantize (bool): Whether to quantize the flow before saving. If set
            to True, remaining args will be passed to
            :func:`quantize_flow`.
        concat_axis (int): The axis that dx and dy are concatenated,
            can be either 0 or 1. Ignored if quantize is False.
    """
    if quantize:
        assert concat_axis in [0, 1]
        dx, dy = quantize_flow(flow, *args, **kwargs)
        stacked = np.concatenate((dx, dy), axis=concat_axis)
        imwrite(stacked, filename)
        return

    with open(filename, 'wb') as f:
        f.write('PIEH'.encode('utf-8'))
        # width first, then height
        dims = np.array([flow.shape[1], flow.shape[0]], dtype=np.int32)
        dims.tofile(f)
        flow.astype(np.float32).tofile(f)
        f.flush()
def quantize_flow(flow, max_val=0.02, norm=True):
    """Quantize flow to [0, 255].

    After this step, the size of flow will be much smaller, and can be
    dumped as jpeg images.

    Args:
        flow (ndarray): (h, w, 2) array of optical flow.
        max_val (float): Maximum value of flow, values beyond
                        [-max_val, max_val] will be truncated.
        norm (bool): Whether to divide flow values by image width/height.

    Returns:
        tuple[ndarray]: Quantized dx and dy.
    """
    height, width, _ = flow.shape
    dx, dy = flow[..., 0], flow[..., 1]
    if norm:
        # plain division creates new arrays, leaving ``flow`` untouched
        dx = dx / width
        dy = dy / height
    # use 255 levels instead of 256 to make sure 0 is 0 after dequantization.
    return tuple(
        quantize(component, -max_val, max_val, 255, np.uint8)
        for component in (dx, dy))
def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
    """Recover from quantized flow.

    Args:
        dx (ndarray): Quantized dx.
        dy (ndarray): Quantized dy.
        max_val (float): Maximum value used when quantizing.
        denorm (bool): Whether to multiply flow values with width/height.

    Returns:
        ndarray: Dequantized flow, dx and dy stacked along the third axis.
    """
    assert dx.shape == dy.shape
    assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1)

    dx = dequantize(dx, -max_val, max_val, 255)
    dy = dequantize(dy, -max_val, max_val, 255)

    if denorm:
        height, width = dx.shape[0], dx.shape[1]
        dx *= width
        dy *= height
    return np.dstack((dx, dy))
def flow_warp(img, flow, filling_value=0, interpolate_mode='nearest'):
    """Use flow to warp img.

    Args:
        img (ndarray, float or uint8): Image to be warped.
        flow (ndarray, float): Optical Flow.
        filling_value (int): The missing pixels will be set with filling_value.
        interpolate_mode (str): bilinear -> Bilinear Interpolation;
                                nearest -> Nearest Neighbor.

    Returns:
        ndarray: Warped image with the same shape of img

    Raises:
        NotImplementedError: If ``interpolate_mode`` is neither 'nearest'
            nor 'bilinear'.
    """
    warnings.warn('This function is just for prototyping and cannot '
                  'guarantee the computational efficiency.')
    assert flow.ndim == 3, 'Flow must be in 3D arrays.'
    height, width = flow.shape[0], flow.shape[1]
    channels = img.shape[2]

    output = np.ones(
        (height, width, channels), dtype=img.dtype) * filling_value

    # (h, w, 2) grid holding the row index and the column index of each pixel
    grid = np.indices((height, width)).transpose(1, 2, 0)
    # source row = row + flow channel 1, source column = column + channel 0
    dx = grid[:, :, 0] + flow[:, :, 1]
    dy = grid[:, :, 1] + flow[:, :, 0]
    sx = np.floor(dx).astype(int)
    sy = np.floor(dy).astype(int)
    # pixels whose source falls outside this range keep the filling value
    valid = (sx >= 0) & (sx < height - 1) & (sy >= 0) & (sy < width - 1)

    if interpolate_mode == 'nearest':
        rows = dx[valid].round().astype(int)
        cols = dy[valid].round().astype(int)
        output[valid, :] = img[rows, cols, :]
    elif interpolate_mode == 'bilinear':
        # dirty workaround for integer positions
        eps_ = 1e-6
        px = (dx + eps_)[valid]
        py = (dy + eps_)[valid]
        x_lo, x_hi = np.floor(px), np.ceil(px)
        y_lo, y_hi = np.floor(py), np.ceil(py)
        w_x_lo = (x_hi - px)[:, None]
        w_x_hi = (px - x_lo)[:, None]
        w_y_lo = (y_hi - py)[:, None]
        w_y_hi = (py - y_lo)[:, None]
        x_lo, x_hi = x_lo.astype(int), x_hi.astype(int)
        y_lo, y_hi = y_lo.astype(int), y_hi.astype(int)
        left_top_ = img[x_lo, y_lo, :] * w_x_lo * w_y_lo
        left_down_ = img[x_hi, y_lo, :] * w_x_hi * w_y_lo
        right_top_ = img[x_lo, y_hi, :] * w_x_lo * w_y_hi
        right_down_ = img[x_hi, y_hi, :] * w_x_hi * w_y_hi
        output[valid, :] = left_top_ + left_down_ + right_top_ + right_down_
    else:
        raise NotImplementedError(
            'We only support interpolation modes of nearest and bilinear, '
            f'but got {interpolate_mode}.')
    return output.astype(img.dtype)
def flow_from_bytes(content):
    """Read dense optical flow from bytes.

    .. note::
        This load optical flow function works for FlyingChairs, FlyingThings3D,
        Sintel, FlyingChairsOcc datasets, but cannot load the data from
        ChairsSDHom.

    Args:
        content (bytes): Optical flow bytes got from files or other streams.

    Returns:
        ndarray: Loaded optical flow with the shape (H, W, 2).

    Raises:
        Exception: If the first four bytes are not 'PIEH'.
    """
    # bytes 0-3: magic header
    if content[:4].decode('utf-8') != 'PIEH':
        raise Exception('Flow file header does not contain PIEH')
    # bytes 4-7: width, bytes 8-11: height (int32 each)
    width = np.frombuffer(content, np.int32, 1, offset=4).squeeze()
    height = np.frombuffer(content, np.int32, 1, offset=8).squeeze()
    # everything from byte 12 on is float32 flow data
    values = np.frombuffer(content, np.float32, width * height * 2, offset=12)
    return values.reshape((height, width, 2))
def sparse_flow_from_bytes(content):
    """Read the optical flow in KITTI datasets from bytes.

    This function is modified from the KITTI loading code of RAFT.

    Args:
        content (bytes): Optical flow bytes got from files or other streams.

    Returns:
        Tuple(ndarray, ndarray): Loaded optical flow with the shape (H, W, 2)
        and flow valid mask with the shape (H, W).
    """
    encoded = np.frombuffer(content, np.uint8)
    decoded = cv2.imdecode(encoded, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR)
    # reverse the channel order of the decoded image
    decoded = decoded[:, :, ::-1].astype(np.float32)
    # first two channels carry the flow, the third one the valid mask
    valid = decoded[:, :, 2]
    flow = (decoded[:, :, :2] - 2**15) / 64.0
    return flow, valid
diff --git a/ControlNet/annotator/uniformer/mmcv/video/processing.py b/ControlNet/annotator/uniformer/mmcv/video/processing.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d90b96e0823d5f116755e7f498d25d17017224a
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/video/processing.py
@@ -0,0 +1,160 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import os.path as osp
+import subprocess
+import tempfile
+from annotator.uniformer.mmcv.utils import requires_executable
def convert_video(in_file,
                  out_file,
                  print_cmd=False,
                  pre_options='',
                  **kwargs):
    """Convert a video with ffmpeg.

    This provides a general api to ffmpeg, the executed command is::

        `ffmpeg -y <pre_options> -i <in_file> <options> <out_file>`

    Options(kwargs) are mapped to ffmpeg commands with the following rules:

    - key=val: "-key val"
    - key=True: "-key"
    - key=False: ""
    - log_level=val: "-loglevel val"

    Args:
        in_file (str): Input video filename.
        out_file (str): Output video filename.
        pre_options (str): Options appears before "-i <in_file>".
        print_cmd (bool): Whether to print the final ffmpeg command.
    """
    log_levels = [
        'quiet', 'panic', 'fatal', 'error', 'warning', 'info',
        'verbose', 'debug', 'trace'
    ]
    options = []
    for key, value in kwargs.items():
        if isinstance(value, bool):
            # boolean flags appear only when switched on
            if value:
                options.append(f'-{key}')
            continue
        if key == 'log_level':
            assert value in log_levels
            options.append(f'-loglevel {value}')
            continue
        options.append(f'-{key} {value}')
    joined = ' '.join(options)
    cmd = f'ffmpeg -y {pre_options} -i {in_file} {joined} {out_file}'
    if print_cmd:
        print(cmd)
    # NOTE(review): the command runs through the shell and none of the
    # interpolated values are quoted here.
    subprocess.call(cmd, shell=True)
def resize_video(in_file,
                 out_file,
                 size=None,
                 ratio=None,
                 keep_ar=False,
                 log_level='info',
                 print_cmd=False):
    """Resize a video.

    Exactly one of ``size`` and ``ratio`` has to be given.

    Args:
        in_file (str): Input video filename.
        out_file (str): Output video filename.
        size (tuple): Expected size (w, h), eg, (320, 240) or (320, -1).
        ratio (tuple or float): Expected resize ratio, (2, 0.5) means
            (w*2, h*0.5).
        keep_ar (bool): Whether to keep original aspect ratio.
        log_level (str): Logging level of ffmpeg.
        print_cmd (bool): Whether to print the final ffmpeg command.

    Raises:
        ValueError: If neither or both of ``size`` and ``ratio`` are given.
    """
    size_given = size is not None
    ratio_given = ratio is not None
    if not size_given and not ratio_given:
        raise ValueError('expected size or ratio must be specified')
    if size_given and ratio_given:
        raise ValueError('size and ratio cannot be specified at the same time')

    if size:
        if keep_ar:
            scale_filter = (f'scale=w={size[0]}:h={size[1]}:'
                            'force_original_aspect_ratio=decrease')
        else:
            scale_filter = f'scale={size[0]}:{size[1]}'
    else:
        # a single number is applied to both width and height
        if not isinstance(ratio, tuple):
            ratio = (ratio, ratio)
        scale_filter = (
            f'scale="trunc(iw*{ratio[0]}):trunc(ih*{ratio[1]})"')
    options = {'log_level': log_level, 'vf': scale_filter}
    convert_video(in_file, out_file, print_cmd, **options)
def cut_video(in_file,
              out_file,
              start=None,
              end=None,
              vcodec=None,
              acodec=None,
              log_level='info',
              print_cmd=False):
    """Cut a clip from a video.

    Args:
        in_file (str): Input video filename.
        out_file (str): Output video filename.
        start (None or float): Start time (in seconds).
        end (None or float): End time (in seconds).
        vcodec (None or str): Output video codec, None for unchanged.
        acodec (None or str): Output audio codec, None for unchanged.
        log_level (str): Logging level of ffmpeg.
        print_cmd (bool): Whether to print the final ffmpeg command.
    """
    options = {'log_level': log_level}
    # None keeps the stream as is ('copy'); an explicit codec is forwarded
    # to ffmpeg instead of being dropped.
    options['vcodec'] = 'copy' if vcodec is None else vcodec
    options['acodec'] = 'copy' if acodec is None else acodec
    if start:
        options['ss'] = start
    else:
        start = 0
    if end:
        # ffmpeg's -t takes a duration, so convert the end time
        options['t'] = end - start
    convert_video(in_file, out_file, print_cmd, **options)
def concat_video(video_list,
                 out_file,
                 vcodec=None,
                 acodec=None,
                 log_level='info',
                 print_cmd=False):
    """Concatenate multiple videos into a single one.

    The absolute paths of the inputs are written to a temporary list file
    that is handed to ffmpeg with ``-f concat -safe 0``; the list file is
    removed afterwards, also when the conversion raises.

    Args:
        video_list (list): A list of video filenames
        out_file (str): Output video filename
        vcodec (None or str): Output video codec, None for unchanged
        acodec (None or str): Output audio codec, None for unchanged
        log_level (str): Logging level of ffmpeg.
        print_cmd (bool): Whether to print the final ffmpeg command.
    """
    tmp_filehandler, tmp_filename = tempfile.mkstemp(suffix='.txt', text=True)
    try:
        # fdopen takes ownership of the descriptor and closes it on exit
        with os.fdopen(tmp_filehandler, 'w') as f:
            for filename in video_list:
                f.write(f'file {osp.abspath(filename)}\n')
        options = {'log_level': log_level}
        # None keeps the stream as is ('copy'); an explicit codec is
        # forwarded to ffmpeg instead of being dropped.
        options['vcodec'] = 'copy' if vcodec is None else vcodec
        options['acodec'] = 'copy' if acodec is None else acodec
        convert_video(
            tmp_filename,
            out_file,
            print_cmd,
            pre_options='-f concat -safe 0',
            **options)
    finally:
        os.remove(tmp_filename)
diff --git a/ControlNet/annotator/uniformer/mmcv/visualization/__init__.py b/ControlNet/annotator/uniformer/mmcv/visualization/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..835df136bdcf69348281d22914d41aa84cdf92b1
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/visualization/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .color import Color, color_val
+from .image import imshow, imshow_bboxes, imshow_det_bboxes
+from .optflow import flow2rgb, flowshow, make_color_wheel
+__all__ = [
+ 'Color', 'color_val', 'imshow', 'imshow_bboxes', 'imshow_det_bboxes',
+ 'flowshow', 'flow2rgb', 'make_color_wheel'
diff --git a/ControlNet/annotator/uniformer/mmcv/visualization/color.py b/ControlNet/annotator/uniformer/mmcv/visualization/color.py
new file mode 100644
index 0000000000000000000000000000000000000000..9041e0e6b7581c3356795d6a3c5e84667c88f025
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/visualization/color.py
@@ -0,0 +1,51 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from enum import Enum
+import numpy as np
+from annotator.uniformer.mmcv.utils import is_str
class Color(Enum):
    """Enumeration of eight common colors.

    Members are red, green, blue, cyan, yellow, magenta, white and black.
    Every value is a 3-tuple of 0-255 integers with the blue component
    first and the red component last.
    """

    # primary colors
    red = (0, 0, 255)
    green = (0, 255, 0)
    blue = (255, 0, 0)
    # secondary colors
    cyan = (255, 255, 0)
    yellow = (0, 255, 255)
    magenta = (255, 0, 255)
    # extremes
    white = (255, 255, 255)
    black = (0, 0, 0)
def color_val(color):
    """Convert various input to color tuples.

    Args:
        color (:obj:`Color`/str/tuple/int/ndarray): Color inputs. A string
            is looked up by name in :obj:`Color`; an int is used for all
            three channels; a tuple or 1-D ndarray must hold three values
            in [0, 255].

    Returns:
        tuple[int]: A tuple of 3 integers indicating BGR channels.

    Raises:
        TypeError: If ``color`` has an unsupported type.
    """
    if is_str(color):
        return Color[color].value
    elif isinstance(color, Color):
        return color.value
    elif isinstance(color, tuple):
        assert len(color) == 3
        for channel in color:
            assert 0 <= channel <= 255
        return color
    elif isinstance(color, int):
        assert 0 <= color <= 255
        return color, color, color
    elif isinstance(color, np.ndarray):
        assert color.ndim == 1 and color.size == 3
        assert np.all((color >= 0) & (color <= 255))
        # ``tolist`` yields built-in ints; iterating the array directly
        # would leave numpy scalars in the returned tuple.
        return tuple(color.astype(np.uint8).tolist())
    else:
        raise TypeError(f'Invalid type for color: {type(color)}')
diff --git a/ControlNet/annotator/uniformer/mmcv/visualization/image.py b/ControlNet/annotator/uniformer/mmcv/visualization/image.py
new file mode 100644
index 0000000000000000000000000000000000000000..61a56c75b67f593c298408462c63c0468be8e276
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/visualization/image.py
@@ -0,0 +1,152 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import cv2
+import numpy as np
+from annotator.uniformer.mmcv.image import imread, imwrite
+from .color import color_val
def imshow(img, win_name='', wait_time=0):
    """Show an image.

    Args:
        img (str or ndarray): The image to be displayed.
        win_name (str): The window name.
        wait_time (int): Value of waitKey param. With 0 the call polls
            until a key is pressed or the window is no longer visible.
    """
    cv2.imshow(win_name, imread(img))
    if wait_time != 0:
        cv2.waitKey(wait_time)
        return
    # prevent from hanging if windows was closed
    while True:
        key = cv2.waitKey(1)
        window_gone = cv2.getWindowProperty(win_name,
                                            cv2.WND_PROP_VISIBLE) < 1
        # stop if some key was pressed or if the user closed the window
        if key != -1 or window_gone:
            break
def imshow_bboxes(img,
                  bboxes,
                  colors='green',
                  top_k=-1,
                  thickness=1,
                  show=True,
                  win_name='',
                  wait_time=0,
                  out_file=None):
    """Draw bboxes on an image.

    Args:
        img (str or ndarray): The image to be displayed.
        bboxes (list or ndarray): A list of ndarray of shape (k, 4).
        colors (list[str or tuple or Color]): A list of colors; a single
            color is used for every bbox group.
        top_k (int): Plot the first k bboxes only if set positive.
        thickness (int): Thickness of lines.
        show (bool): Whether to show the image.
        win_name (str): The window name.
        wait_time (int): Value of waitKey param.
        out_file (str, optional): The filename to write the image.

    Returns:
        ndarray: The image with bboxes drawn on it.
    """
    img = np.ascontiguousarray(imread(img))

    if isinstance(bboxes, np.ndarray):
        bboxes = [bboxes]
    if not isinstance(colors, list):
        colors = [colors] * len(bboxes)
    colors = [color_val(c) for c in colors]
    assert len(bboxes) == len(colors)

    for group, group_color in zip(bboxes, colors):
        group = group.astype(np.int32)
        count = group.shape[0]
        limit = count if top_k <= 0 else min(top_k, count)
        for box in group[:limit]:
            left_top = (box[0], box[1])
            right_bottom = (box[2], box[3])
            cv2.rectangle(
                img, left_top, right_bottom, group_color, thickness=thickness)

    if show:
        imshow(img, win_name, wait_time)
    if out_file is not None:
        imwrite(img, out_file)
    return img
def imshow_det_bboxes(img,
                      bboxes,
                      labels,
                      class_names=None,
                      score_thr=0,
                      bbox_color='green',
                      text_color='green',
                      thickness=1,
                      font_scale=0.5,
                      show=True,
                      win_name='',
                      wait_time=0,
                      out_file=None):
    """Draw bboxes and class labels (with scores) on an image.

    Args:
        img (str or ndarray): The image to be displayed.
        bboxes (ndarray): Bounding boxes (with scores), shaped (n, 4) or
            (n, 5).
        labels (ndarray): Labels of bboxes.
        class_names (list[str]): Names of each classes.
        score_thr (float): Minimum score of bboxes to be shown. A positive
            threshold requires (n, 5) bboxes.
        bbox_color (str or tuple or :obj:`Color`): Color of bbox lines.
        text_color (str or tuple or :obj:`Color`): Color of texts.
        thickness (int): Thickness of lines.
        font_scale (float): Font scales of texts.
        show (bool): Whether to show the image.
        win_name (str): The window name.
        wait_time (int): Value of waitKey param.
        out_file (str or None): The filename to write the image.

    Returns:
        ndarray: The image with bboxes drawn on it.
    """
    assert bboxes.ndim == 2
    assert labels.ndim == 1
    assert bboxes.shape[0] == labels.shape[0]
    assert bboxes.shape[1] == 4 or bboxes.shape[1] == 5
    img = np.ascontiguousarray(imread(img))

    if score_thr > 0:
        assert bboxes.shape[1] == 5
        # the last column holds the score
        keep = bboxes[:, -1] > score_thr
        bboxes = bboxes[keep, :]
        labels = labels[keep]

    bbox_color = color_val(bbox_color)
    text_color = color_val(text_color)

    for bbox, label in zip(bboxes, labels):
        corners = bbox.astype(np.int32)
        left_top = (corners[0], corners[1])
        right_bottom = (corners[2], corners[3])
        cv2.rectangle(
            img, left_top, right_bottom, bbox_color, thickness=thickness)
        if class_names is not None:
            label_text = class_names[label]
        else:
            label_text = f'cls {label}'
        if len(bbox) > 4:
            label_text += f'|{bbox[-1]:.02f}'
        # text is placed two pixels above the box's top-left corner
        cv2.putText(img, label_text, (corners[0], corners[1] - 2),
                    cv2.FONT_HERSHEY_COMPLEX, font_scale, text_color)

    if show:
        imshow(img, win_name, wait_time)
    if out_file is not None:
        imwrite(img, out_file)
    return img
diff --git a/ControlNet/annotator/uniformer/mmcv/visualization/optflow.py b/ControlNet/annotator/uniformer/mmcv/visualization/optflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3870c700f7c946177ee5d536ce3f6c814a77ce7
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv/visualization/optflow.py
@@ -0,0 +1,112 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from __future__ import division
+import numpy as np
+from annotator.uniformer.mmcv.image import rgb2bgr
+from annotator.uniformer.mmcv.video import flowread
+from .image import imshow
+def flowshow(flow, win_name='', wait_time=0):
+    """Render an optical flow field as a color image and display it.
+    Args:
+        flow (ndarray or str): The optical flow to be displayed.
+        win_name (str): The window name.
+        wait_time (int): Value of waitKey param.
+    """
+    rgb_img = flow2rgb(flowread(flow))
+    # imshow receives the image after the RGB -> BGR conversion
+    bgr_img = rgb2bgr(rgb_img)
+    imshow(bgr_img, win_name, wait_time)
+def flow2rgb(flow, color_wheel=None, unknown_thr=1e6):
+ """Convert flow map to RGB image.
+ Args:
+ flow (ndarray): Array of optical flow. Must be 3-dimensional with a
+ last axis of size 2 (asserted below); channel 0 is used as dx and
+ channel 1 as dy.
+ color_wheel (ndarray or None): Color wheel used to map flow field to
+ RGB colorspace. Default color wheel will be used if not specified.
+ Must be 2-dimensional with 3 columns (asserted below).
+ unknown_thr (float): Values above this threshold will be marked as
+ unknown and thus ignored.
+ Returns:
+ ndarray: RGB image that can be visualized.
+ """
+ assert flow.ndim == 3 and flow.shape[-1] == 2
+ if color_wheel is None:
+ color_wheel = make_color_wheel()
+ assert color_wheel.ndim == 2 and color_wheel.shape[1] == 3
+ num_bins = color_wheel.shape[0]
+ # work on copies so the in-place edits below do not touch the input array
+ dx = flow[:, :, 0].copy()
+ dy = flow[:, :, 1].copy()
+ # a pixel is ignored when either component is NaN or exceeds unknown_thr
+ # in absolute value
+ ignore_inds = (
+ np.isnan(dx) | np.isnan(dy) | (np.abs(dx) > unknown_thr) |
+ (np.abs(dy) > unknown_thr))
+ dx[ignore_inds] = 0
+ dy[ignore_inds] = 0
+ rad = np.sqrt(dx**2 + dy**2)
+ # normalize by the largest magnitude, skipped when every magnitude is
+ # within machine epsilon of zero
+ if np.any(rad > np.finfo(float).eps):
+ max_rad = np.max(rad)
+ dx /= max_rad
+ dy /= max_rad
+ rad = np.sqrt(dx**2 + dy**2)
+ # arctan2(...) / pi lies in [-1, 1]; map it to a fractional bin index in
+ # [0, num_bins - 1]
+ angle = np.arctan2(-dy, -dx) / np.pi
+ bin_real = (angle + 1) / 2 * (num_bins - 1)
+ bin_left = np.floor(bin_real).astype(int)
+ # the modulo wraps the right neighbour of the last bin back to bin 0
+ bin_right = (bin_left + 1) % num_bins
+ # w is the interpolation weight between the two neighbouring wheel colors
+ w = (bin_real - bin_left.astype(np.float32))[..., None]
+ flow_img = (1 -
+ w) * color_wheel[bin_left, :] + w * color_wheel[bin_right, :]
+ # pixels with rad <= 1 are blended towards 1 as rad shrinks; the remaining
+ # pixels are scaled by 0.75
+ small_ind = rad <= 1
+ flow_img[small_ind] = 1 - rad[small_ind, None] * (1 - flow_img[small_ind])
+ flow_img[np.logical_not(small_ind)] *= 0.75
+ # ignored pixels are set to 0 in all three channels
+ flow_img[ignore_inds, :] = 0
+ return flow_img
+def make_color_wheel(bins=None):
+    """Build a color wheel.
+    Args:
+        bins(list or tuple, optional): Specify the number of bins for each
+            color range, corresponding to six ranges: red -> yellow,
+            yellow -> green, green -> cyan, cyan -> blue, blue -> magenta,
+            magenta -> red. [15, 6, 4, 11, 13, 6] is used for default
+            (see Middlebury).
+    Returns:
+        ndarray: Color wheel of shape (total_bins, 3).
+    """
+    if bins is None:
+        bins = [15, 6, 4, 11, 13, 6]
+    assert len(bins) == 6
+    n_ry, n_yg, n_gc, n_cb, n_bm, n_mr = tuple(bins)
+    # one [R, G, B] triple per color range; a scalar is broadcast over the
+    # whole range, an array ramps across it
+    segments = [
+        [1, np.arange(n_ry) / n_ry, 0],
+        [1 - np.arange(n_yg) / n_yg, 1, 0],
+        [0, 1, np.arange(n_gc) / n_gc],
+        [0, 1 - np.arange(n_cb) / n_cb, 1],
+        [np.arange(n_bm) / n_bm, 0, 1],
+        [1, 0, 1 - np.arange(n_mr) / n_mr],
+    ]
+    total_bins = n_ry + n_yg + n_gc + n_cb + n_bm + n_mr
+    wheel = np.zeros((3, total_bins), dtype=np.float32)
+    start = 0
+    for width, channels in zip(bins, segments):
+        stop = start + width
+        for row, values in enumerate(channels):
+            wheel[row, start:stop] = values
+        start = stop
+    # rows were filled per channel; transpose to (total_bins, 3)
+    return wheel.T
diff --git a/ControlNet/annotator/uniformer/mmcv_custom/__init__.py b/ControlNet/annotator/uniformer/mmcv_custom/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b958738b9fd93bfcec239c550df1d9a44b8c536
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv_custom/__init__.py
@@ -0,0 +1,5 @@
+# -*- coding: utf-8 -*-
+from .checkpoint import load_checkpoint
+__all__ = ['load_checkpoint']
\ No newline at end of file
diff --git a/ControlNet/annotator/uniformer/mmcv_custom/checkpoint.py b/ControlNet/annotator/uniformer/mmcv_custom/checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..19b87fef0a52d31babcdb3edb8f3089b6420173f
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmcv_custom/checkpoint.py
@@ -0,0 +1,500 @@
+# Copyright (c) Open-MMLab. All rights reserved.
+import io
+import os
+import os.path as osp
+import pkgutil
+import time
+import warnings
+from collections import OrderedDict
+from importlib import import_module
+from tempfile import TemporaryDirectory
+import torch
+import torchvision
+from torch.optim import Optimizer
+from torch.utils import model_zoo
+from torch.nn import functional as F
+import annotator.uniformer.mmcv as mmcv
+from annotator.uniformer.mmcv.fileio import FileClient
+from annotator.uniformer.mmcv.fileio import load as load_file
+from annotator.uniformer.mmcv.parallel import is_module_wrapper
+from annotator.uniformer.mmcv.utils import mkdir_or_exist
+from annotator.uniformer.mmcv.runner import get_dist_info
+DEFAULT_CACHE_DIR = '~/.cache'
+def _get_mmcv_home():
+    """Return the mmcv home directory, creating it if it does not exist.
+    The location is ``$MMCV_HOME`` when that variable is set; otherwise it is
+    ``<cache>/mmcv`` where ``<cache>`` is ``$XDG_CACHE_HOME`` or, failing
+    that, ``DEFAULT_CACHE_DIR``.
+    Returns:
+        str: The directory path with a leading ``~`` expanded.
+    """
+    # The variable names are spelled out here: this module defines no ENV_*
+    # constants, and the previous body both referenced an undefined name and
+    # passed the joined path to os.getenv as if it were a variable name.
+    cache_root = os.getenv('XDG_CACHE_HOME', DEFAULT_CACHE_DIR)
+    mmcv_home = os.path.expanduser(
+        os.getenv('MMCV_HOME', os.path.join(cache_root, 'mmcv')))
+    mkdir_or_exist(mmcv_home)
+    return mmcv_home
+def load_state_dict(module, state_dict, strict=False, logger=None):
+ """Load state_dict to a module.
+ This method is modified from :meth:`torch.nn.Module.load_state_dict`.
+ Default value for ``strict`` is set to ``False`` and the message for
+ param mismatch will be shown even if strict is False.
+ Args:
+ module (Module): Module that receives the state_dict.
+ state_dict (OrderedDict): Weights.
+ strict (bool): whether to strictly enforce that the keys
+ in :attr:`state_dict` match the keys returned by this module's
+ :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
+ logger (:obj:`logging.Logger`, optional): Logger to log the error
+ message. If not specified, print function will be used.
+ Raises:
+ RuntimeError: If ``strict`` is True and a mismatch message was
+ collected; the report is only evaluated when the rank returned by
+ ``get_dist_info`` is 0.
+ """
+ unexpected_keys = []
+ all_missing_keys = []
+ err_msg = []
+ # copy the mapping so the caller's object is not the one handed to
+ # _load_from_state_dict; _metadata is re-attached to the copy by hand
+ metadata = getattr(state_dict, '_metadata', None)
+ state_dict = state_dict.copy()
+ if metadata is not None:
+ state_dict._metadata = metadata
+ # use _load_from_state_dict to enable checkpoint version control
+ def load(module, prefix=''):
+ # recursively check parallel module in case that the model has a
+ # complicated structure, e.g., nn.Module(nn.Module(DDP))
+ if is_module_wrapper(module):
+ module = module.module
+ # metadata is keyed by the prefix without its trailing dot
+ local_metadata = {} if metadata is None else metadata.get(
+ prefix[:-1], {})
+ module._load_from_state_dict(state_dict, prefix, local_metadata, True,
+ all_missing_keys, unexpected_keys,
+ err_msg)
+ for name, child in module._modules.items():
+ if child is not None:
+ load(child, prefix + name + '.')
+ load(module)
+ load = None # break load->load reference cycle
+ # ignore "num_batches_tracked" of BN layers
+ missing_keys = [
+ key for key in all_missing_keys if 'num_batches_tracked' not in key
+ ]
+ if unexpected_keys:
+ err_msg.append('unexpected key in source '
+ f'state_dict: {", ".join(unexpected_keys)}\n')
+ if missing_keys:
+ err_msg.append(
+ f'missing keys in source state_dict: {", ".join(missing_keys)}\n')
+ # only rank 0 reports, either by raising, via the logger, or with print
+ rank, _ = get_dist_info()
+ if len(err_msg) > 0 and rank == 0:
+ err_msg.insert(
+ 0, 'The model and loaded state dict do not match exactly\n')
+ err_msg = '\n'.join(err_msg)
+ if strict:
+ raise RuntimeError(err_msg)
+ elif logger is not None:
+ logger.warning(err_msg)
+ else:
+ print(err_msg)
+def load_url_dist(url, model_dir=None):
+    """Load a checkpoint from ``url``, letting local rank 0 go first.
+    Local rank 0 calls ``model_zoo.load_url`` immediately. When more than one
+    process is running, all processes then meet at a barrier, after which the
+    other ranks issue the same call.
+    Args:
+        url (str): URL handed to ``model_zoo.load_url``.
+        model_dir (str, optional): Forwarded to ``model_zoo.load_url``.
+    Returns:
+        The object returned by ``model_zoo.load_url``.
+    """
+    global_rank, num_procs = get_dist_info()
+    # LOCAL_RANK from the environment takes precedence over the rank
+    # reported by get_dist_info
+    local_rank = int(os.environ.get('LOCAL_RANK', global_rank))
+    if local_rank == 0:
+        checkpoint = model_zoo.load_url(url, model_dir=model_dir)
+    if num_procs > 1:
+        torch.distributed.barrier()
+        if local_rank > 0:
+            checkpoint = model_zoo.load_url(url, model_dir=model_dir)
+    return checkpoint
+def load_pavimodel_dist(model_path, map_location=None):
+    """Load a checkpoint from pavi modelcloud, letting local rank 0 go first.
+    Local rank 0 downloads and loads the model immediately. When more than one
+    process is running, all processes then meet at a barrier, after which the
+    other ranks download and load it themselves.
+    Args:
+        model_path (str): Path handed to ``modelcloud.get``.
+        map_location (str | None): Same as :func:`torch.load`.
+    Returns:
+        The object returned by :func:`torch.load`.
+    Raises:
+        ImportError: If ``pavi`` cannot be imported.
+    """
+    try:
+        from pavi import modelcloud
+    except ImportError:
+        raise ImportError(
+            'Please install pavi to load checkpoint from modelcloud.')
+    def _download_and_load():
+        # the file is deserialized while the temporary directory still exists
+        remote = modelcloud.get(model_path)
+        with TemporaryDirectory() as scratch_dir:
+            local_copy = osp.join(scratch_dir, remote.name)
+            remote.download(local_copy)
+            return torch.load(local_copy, map_location=map_location)
+    global_rank, num_procs = get_dist_info()
+    local_rank = int(os.environ.get('LOCAL_RANK', global_rank))
+    if local_rank == 0:
+        checkpoint = _download_and_load()
+    if num_procs > 1:
+        torch.distributed.barrier()
+        if local_rank > 0:
+            checkpoint = _download_and_load()
+    return checkpoint
+def load_fileclient_dist(filename, backend, map_location):
+    """Load a checkpoint through ``FileClient``, letting local rank 0 go first.
+    Local rank 0 reads and loads the file immediately. When more than one
+    process is running, all processes then meet at a barrier, after which the
+    other ranks read and load it themselves.
+    Args:
+        filename (str): Name handed to ``FileClient.get``.
+        backend (str): FileClient backend; only ``'ceph'`` is accepted.
+        map_location (str | None): Same as :func:`torch.load`.
+    Returns:
+        The object returned by :func:`torch.load`.
+    Raises:
+        ValueError: If ``backend`` is not an accepted backend.
+    """
+    global_rank, num_procs = get_dist_info()
+    local_rank = int(os.environ.get('LOCAL_RANK', global_rank))
+    supported = ['ceph']
+    if backend not in supported:
+        raise ValueError(f'Load from Backend {backend} is not supported.')
+    def _read_and_load():
+        # the payload is fetched as bytes and deserialized from memory
+        client = FileClient(backend=backend)
+        payload = io.BytesIO(client.get(filename))
+        return torch.load(payload, map_location=map_location)
+    if local_rank == 0:
+        checkpoint = _read_and_load()
+    if num_procs > 1:
+        torch.distributed.barrier()
+        if local_rank > 0:
+            checkpoint = _read_and_load()
+    return checkpoint
+def get_torchvision_models():
+    """Gather the ``model_urls`` tables exposed by torchvision model modules.
+    Every non-package module found under ``torchvision.models`` is imported;
+    when it has a ``model_urls`` attribute, its entries are merged into the
+    result.
+    Returns:
+        dict: The merged ``model_urls`` entries.
+    """
+    collected = dict()
+    for _, mod_name, is_pkg in pkgutil.walk_packages(
+            torchvision.models.__path__):
+        if not is_pkg:
+            zoo_module = import_module(f'torchvision.models.{mod_name}')
+            # modules without a model_urls table contribute nothing
+            collected.update(getattr(zoo_module, 'model_urls', {}))
+    return collected
+def get_external_models():
+    """Build the model table used for ``open-mmlab://`` names.
+    The table bundled with mmcv (``model_zoo/open_mmlab.json``) is loaded
+    first; entries from ``open_mmlab.json`` in the mmcv home directory, when
+    that file exists, are merged on top of it.
+    Returns:
+        dict: The merged table.
+    """
+    home_dir = _get_mmcv_home()
+    urls = load_file(osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json'))
+    assert isinstance(urls, dict)
+    user_json = osp.join(home_dir, 'open_mmlab.json')
+    if not osp.exists(user_json):
+        return urls
+    overrides = load_file(user_json)
+    assert isinstance(overrides, dict)
+    urls.update(overrides)
+    return urls
+def get_mmcls_models():
+    """Load the model table used for ``mmcls://`` names.
+    Returns:
+        The parsed content of ``model_zoo/mmcls.json`` under the mmcv package
+        directory.
+    """
+    json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json')
+    return load_file(json_path)
+def get_deprecated_model_names():
+    """Load the table of deprecated model names.
+    Returns:
+        dict: The parsed content of ``model_zoo/deprecated.json`` under the
+        mmcv package directory.
+    """
+    table = load_file(
+        osp.join(mmcv.__path__[0], 'model_zoo/deprecated.json'))
+    assert isinstance(table, dict)
+    return table
+def _process_mmcls_checkpoint(checkpoint):
+    """Keep only the backbone weights of a checkpoint.
+    Entries of ``checkpoint['state_dict']`` whose key starts with
+    ``backbone.`` are kept with that prefix removed; all others are dropped.
+    Args:
+        checkpoint (dict): Checkpoint holding a ``state_dict`` entry.
+    Returns:
+        dict: A new dict whose only key, ``state_dict``, maps to an
+        OrderedDict of the kept weights.
+    """
+    prefix = 'backbone.'
+    backbone_weights = OrderedDict(
+        (key[len(prefix):], value)
+        for key, value in checkpoint['state_dict'].items()
+        if key.startswith(prefix))
+    return dict(state_dict=backbone_weights)
+def _load_checkpoint(filename, map_location=None):
+ """Load checkpoint from somewhere (modelzoo, file, url).
+ The source is chosen by the prefix of ``filename``; the checks run in the
+ order written below and a name without a recognised prefix is treated as
+ a local file path.
+ Args:
+ filename (str): Accept local filepath, URL, ``torchvision://xxx``,
+ ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
+ details.
+ map_location (str | None): Same as :func:`torch.load`. Default: None.
+ Returns:
+ dict | OrderedDict: The loaded checkpoint. It can be either an
+ OrderedDict storing model weights or a dict containing other
+ information, which depends on the checkpoint.
+ Raises:
+ IOError: If a local path is expected but is not an existing file.
+ """
+ # NOTE: the branches that go through load_url_dist do not forward
+ # map_location
+ if filename.startswith('modelzoo://'):
+ warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
+ 'use "torchvision://" instead')
+ model_urls = get_torchvision_models()
+ model_name = filename[11:]
+ checkpoint = load_url_dist(model_urls[model_name])
+ elif filename.startswith('torchvision://'):
+ model_urls = get_torchvision_models()
+ model_name = filename[14:]
+ checkpoint = load_url_dist(model_urls[model_name])
+ elif filename.startswith('open-mmlab://'):
+ model_urls = get_external_models()
+ model_name = filename[13:]
+ # a deprecated name is replaced by its successor before the lookup
+ deprecated_urls = get_deprecated_model_names()
+ if model_name in deprecated_urls:
+ warnings.warn(f'open-mmlab://{model_name} is deprecated in favor '
+ f'of open-mmlab://{deprecated_urls[model_name]}')
+ model_name = deprecated_urls[model_name]
+ model_url = model_urls[model_name]
+ # check if is url
+ if model_url.startswith(('http://', 'https://')):
+ checkpoint = load_url_dist(model_url)
+ else:
+ # a non-URL entry is resolved relative to the mmcv home directory
+ filename = osp.join(_get_mmcv_home(), model_url)
+ if not osp.isfile(filename):
+ raise IOError(f'{filename} is not a checkpoint file')
+ checkpoint = torch.load(filename, map_location=map_location)
+ elif filename.startswith('mmcls://'):
+ model_urls = get_mmcls_models()
+ model_name = filename[8:]
+ checkpoint = load_url_dist(model_urls[model_name])
+ # keep only the backbone weights of the mmcls checkpoint
+ checkpoint = _process_mmcls_checkpoint(checkpoint)
+ elif filename.startswith(('http://', 'https://')):
+ checkpoint = load_url_dist(filename)
+ elif filename.startswith('pavi://'):
+ model_path = filename[7:]
+ checkpoint = load_pavimodel_dist(model_path, map_location=map_location)
+ elif filename.startswith('s3://'):
+ # the full name, including the s3:// prefix, is passed on
+ checkpoint = load_fileclient_dist(
+ filename, backend='ceph', map_location=map_location)
+ else:
+ if not osp.isfile(filename):
+ raise IOError(f'{filename} is not a checkpoint file')
+ checkpoint = torch.load(filename, map_location=map_location)
+ return checkpoint
+def load_checkpoint(model,
+                    filename,
+                    map_location='cpu',
+                    strict=False,
+                    logger=None):
+    """Load checkpoint from a file or URI.
+    Before the weights are handed to ``load_state_dict`` they are adapted:
+    a ``module.`` prefix is stripped, only the ``encoder.`` branch is kept
+    for MoBY checkpoints, ``absolute_pos_embed`` is reshaped to the model's
+    layout, and ``relative_position_bias_table`` entries are bicubically
+    resized when their length differs from the model's.
+    Args:
+        model (Module): Module to load checkpoint.
+        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
+            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
+            details.
+        map_location (str): Same as :func:`torch.load`.
+        strict (bool): Whether to allow different params for the model and
+            checkpoint.
+        logger (:mod:`logging.Logger` or None): The logger for error message.
+            When None, messages are printed instead.
+    Returns:
+        dict or OrderedDict: The loaded checkpoint.
+    Raises:
+        RuntimeError: If the loaded checkpoint is not a dict.
+    """
+    def _warn(message):
+        # ``logger`` defaults to None; fall back to print, as load_state_dict
+        # does, instead of failing on ``None.warning``
+        if logger is not None:
+            logger.warning(message)
+        else:
+            print(message)
+    checkpoint = _load_checkpoint(filename, map_location)
+    # OrderedDict is a subclass of dict
+    if not isinstance(checkpoint, dict):
+        raise RuntimeError(
+            f'No state_dict found in checkpoint file {filename}')
+    # get state_dict from checkpoint
+    if 'state_dict' in checkpoint:
+        state_dict = checkpoint['state_dict']
+    elif 'model' in checkpoint:
+        state_dict = checkpoint['model']
+    else:
+        state_dict = checkpoint
+    # strip prefix of state_dict; the check peeks at the first key, so it is
+    # skipped for an empty state_dict rather than raising IndexError
+    if state_dict and next(iter(state_dict)).startswith('module.'):
+        state_dict = {k[7:]: v for k, v in state_dict.items()}
+    # for MoBY, load model of online branch
+    if state_dict and min(state_dict).startswith('encoder'):
+        state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}
+    # reshape absolute position embedding
+    if state_dict.get('absolute_pos_embed') is not None:
+        absolute_pos_embed = state_dict['absolute_pos_embed']
+        N1, L, C1 = absolute_pos_embed.size()
+        N2, C2, H, W = model.absolute_pos_embed.size()
+        if N1 != N2 or C1 != C2 or L != H*W:
+            _warn("Error in loading absolute_pos_embed, pass")
+        else:
+            state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2)
+    # interpolate position bias table if needed
+    relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
+    if relative_position_bias_table_keys:
+        # build the model's state dict once instead of once per table
+        model_state = model.state_dict()
+    for table_key in relative_position_bias_table_keys:
+        table_pretrained = state_dict[table_key]
+        table_current = model_state[table_key]
+        L1, nH1 = table_pretrained.size()
+        L2, nH2 = table_current.size()
+        if nH1 != nH2:
+            _warn(f"Error in loading {table_key}, pass")
+        else:
+            if L1 != L2:
+                # each table is treated as a square S x S grid per head
+                S1 = int(L1 ** 0.5)
+                S2 = int(L2 ** 0.5)
+                table_pretrained_resized = F.interpolate(
+                    table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
+                    size=(S2, S2), mode='bicubic')
+                state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0)
+    # load state_dict
+    load_state_dict(model, state_dict, strict, logger)
+    return checkpoint
+def weights_to_cpu(state_dict):
+    """Copy a model state_dict to cpu.
+    Args:
+        state_dict (OrderedDict): Model weights on GPU.
+    Returns:
+        OrderedDict: Model weights on CPU, carrying the ``_metadata`` of the
+        input (an empty OrderedDict when the input has none).
+    """
+    state_dict_cpu = OrderedDict()
+    for key, val in state_dict.items():
+        state_dict_cpu[key] = val.cpu()
+    # Keep metadata in state_dict: get_state_dict attaches per-module version
+    # info as ``_metadata`` and load_state_dict reads it back, so it must
+    # survive this copy to end up in the saved checkpoint.
+    state_dict_cpu._metadata = getattr(state_dict, '_metadata', OrderedDict())
+    return state_dict_cpu
+def _save_to_state_dict(module, destination, prefix, keep_vars):
+    """Store the module's own parameters and buffers in ``destination``.
+    This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
+    Args:
+        module (nn.Module): The module to generate state_dict.
+        destination (dict): A dict where state will be stored.
+        prefix (str): The prefix for parameters and buffers used in this
+            module.
+        keep_vars (bool): When true the tensors are stored as they are,
+            otherwise their ``detach()`` result is stored.
+    """
+    # parameters first, then buffers; the check of
+    # _non_persistent_buffers_set is removed to allow nn.BatchNorm2d
+    for tensors in (module._parameters, module._buffers):
+        for name, tensor in tensors.items():
+            if tensor is None:
+                continue
+            if keep_vars:
+                destination[prefix + name] = tensor
+            else:
+                destination[prefix + name] = tensor.detach()
+def get_state_dict(module, destination=None, prefix='', keep_vars=False):
+ """Returns a dictionary containing a whole state of the module.
+ Both parameters and persistent buffers (e.g. running averages) are
+ included. Keys are corresponding parameter and buffer names.
+ This method is modified from :meth:`torch.nn.Module.state_dict` to
+ recursively check parallel module in case that the model has a complicated
+ structure, e.g., nn.Module(nn.Module(DDP)).
+ Args:
+ module (nn.Module): The module to generate state_dict.
+ destination (OrderedDict): Returned dict for the state of the
+ module.
+ prefix (str): Prefix of the key.
+ keep_vars (bool): Whether to keep the variable property of the
+ parameters. Default: False.
+ Returns:
+ dict: A dictionary containing a whole state of the module.
+ """
+ # recursively check parallel module in case that the model has a
+ # complicated structure, e.g., nn.Module(nn.Module(DDP))
+ if is_module_wrapper(module):
+ module = module.module
+ # below is the same as torch.nn.Module.state_dict()
+ if destination is None:
+ destination = OrderedDict()
+ destination._metadata = OrderedDict()
+ # record this module's version under its prefix without the trailing dot
+ destination._metadata[prefix[:-1]] = local_metadata = dict(
+ version=module._version)
+ _save_to_state_dict(module, destination, prefix, keep_vars)
+ # children write into the same destination under an extended prefix
+ for name, child in module._modules.items():
+ if child is not None:
+ get_state_dict(
+ child, destination, prefix + name + '.', keep_vars=keep_vars)
+ # a state-dict hook may return a replacement for destination
+ for hook in module._state_dict_hooks.values():
+ hook_result = hook(module, destination, prefix, local_metadata)
+ if hook_result is not None:
+ destination = hook_result
+ return destination
+def save_checkpoint(model, filename, optimizer=None, meta=None):
+ """Save checkpoint to file.
+ The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
+ ``optimizer``. By default ``meta`` will contain version and time info.
+ Args:
+ model (Module): Module whose params are to be saved.
+ filename (str): Checkpoint filename. A name starting with
+ ``pavi://`` is uploaded through pavi modelcloud instead of being
+ written to the local filesystem.
+ optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. A dict
+ of optimizers is also accepted; each one's state_dict is stored
+ under its key.
+ meta (dict, optional): Metadata to be saved in checkpoint. A dict
+ passed here is updated in place.
+ Raises:
+ TypeError: If ``meta`` is neither a dict nor None.
+ ImportError: If a ``pavi://`` filename is given and pavi cannot be
+ imported.
+ """
+ if meta is None:
+ meta = {}
+ elif not isinstance(meta, dict):
+ raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
+ meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
+ # save the wrapped module rather than the parallel wrapper
+ if is_module_wrapper(model):
+ model = model.module
+ if hasattr(model, 'CLASSES') and model.CLASSES is not None:
+ # save class name to the meta
+ meta.update(CLASSES=model.CLASSES)
+ checkpoint = {
+ 'meta': meta,
+ 'state_dict': weights_to_cpu(get_state_dict(model))
+ }
+ # save optimizer state dict in the checkpoint
+ if isinstance(optimizer, Optimizer):
+ checkpoint['optimizer'] = optimizer.state_dict()
+ elif isinstance(optimizer, dict):
+ checkpoint['optimizer'] = {}
+ for name, optim in optimizer.items():
+ checkpoint['optimizer'][name] = optim.state_dict()
+ if filename.startswith('pavi://'):
+ try:
+ from pavi import modelcloud
+ from pavi.exception import NodeNotFoundError
+ except ImportError:
+ raise ImportError(
+ 'Please install pavi to load checkpoint from modelcloud.')
+ model_path = filename[7:]
+ root = modelcloud.Folder()
+ model_dir, model_name = osp.split(model_path)
+ # reuse the remote model node when it exists, otherwise create it
+ try:
+ model = modelcloud.get(model_dir)
+ except NodeNotFoundError:
+ model = root.create_training_model(model_dir)
+ # serialize into a temporary file, then hand that file to pavi
+ with TemporaryDirectory() as tmp_dir:
+ checkpoint_file = osp.join(tmp_dir, model_name)
+ with open(checkpoint_file, 'wb') as f:
+ torch.save(checkpoint, f)
+ f.flush()
+ model.create_file(checkpoint_file, name=model_name)
+ else:
+ mmcv.mkdir_or_exist(osp.dirname(filename))
+ # immediately flush buffer
+ with open(filename, 'wb') as f:
+ torch.save(checkpoint, f)
+ f.flush()
\ No newline at end of file
diff --git a/ControlNet/annotator/uniformer/mmseg/apis/__init__.py b/ControlNet/annotator/uniformer/mmseg/apis/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..170724be38de42daf2bc1a1910e181d68818f165
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/apis/__init__.py
@@ -0,0 +1,9 @@
+from .inference import inference_segmentor, init_segmentor, show_result_pyplot
+from .test import multi_gpu_test, single_gpu_test
+from .train import get_root_logger, set_random_seed, train_segmentor
+# Public API of the package: the names imported above.
+__all__ = [
+    'get_root_logger', 'set_random_seed', 'train_segmentor', 'init_segmentor',
+    'inference_segmentor', 'multi_gpu_test', 'single_gpu_test',
+    'show_result_pyplot'
+]
diff --git a/ControlNet/annotator/uniformer/mmseg/apis/inference.py b/ControlNet/annotator/uniformer/mmseg/apis/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..90bc1c0c68525734bd6793f07c15fe97d3c8342c
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/apis/inference.py
@@ -0,0 +1,136 @@
+import matplotlib.pyplot as plt
+import annotator.uniformer.mmcv as mmcv
+import torch
+from annotator.uniformer.mmcv.parallel import collate, scatter
+from annotator.uniformer.mmcv.runner import load_checkpoint
+from annotator.uniformer.mmseg.datasets.pipelines import Compose
+from annotator.uniformer.mmseg.models import build_segmentor
+def init_segmentor(config, checkpoint=None, device='cuda:0'):
+    """Build a segmentor from a config and put it in eval mode.
+    Args:
+        config (str or :obj:`mmcv.Config`): Config file path or the config
+            object.
+        checkpoint (str, optional): Checkpoint path. If left as None, the model
+            will not load any weights.
+        device (str, optional) CPU/CUDA device option. Default 'cuda:0'.
+            Use 'cpu' for loading model on CPU.
+    Returns:
+        nn.Module: The constructed segmentor.
+    Raises:
+        TypeError: If ``config`` is neither a str nor an ``mmcv.Config``.
+    """
+    if isinstance(config, str):
+        config = mmcv.Config.fromfile(config)
+    elif not isinstance(config, mmcv.Config):
+        raise TypeError('config must be a filename or Config object, '
+                        'but got {}'.format(type(config)))
+    # the model is built without pretrained weights or a train config
+    config.model.pretrained = None
+    config.model.train_cfg = None
+    segmentor = build_segmentor(
+        config.model, test_cfg=config.get('test_cfg'))
+    if checkpoint is not None:
+        # class names and palette are taken from the checkpoint's meta
+        loaded = load_checkpoint(segmentor, checkpoint, map_location='cpu')
+        segmentor.CLASSES = loaded['meta']['CLASSES']
+        segmentor.PALETTE = loaded['meta']['PALETTE']
+    segmentor.cfg = config  # save the config in the model for convenience
+    segmentor.to(device)
+    segmentor.eval()
+    return segmentor
+class LoadImage:
+    """A simple pipeline to load image."""
+    def __call__(self, results):
+        """Read ``results['img']`` and record the image and its shape.
+        Args:
+            results (dict): A result dict whose ``img`` entry is handed to
+                ``mmcv.imread``.
+        Returns:
+            dict: The same ``results`` with ``filename``, ``ori_filename``,
+            ``img``, ``img_shape`` and ``ori_shape`` filled in.
+        """
+        source = results['img']
+        # the file name fields are only filled when a str was given
+        name = source if isinstance(source, str) else None
+        results['filename'] = name
+        results['ori_filename'] = name
+        image = mmcv.imread(source)
+        results['img'] = image
+        results['img_shape'] = image.shape
+        results['ori_shape'] = image.shape
+        return results
+def inference_segmentor(model, img):
+    """Run the segmentor on one input.
+    Args:
+        model (nn.Module): The loaded segmentor; its ``cfg`` attribute
+            supplies the test pipeline.
+        img (str/ndarray): Input stored under the ``img`` key of the data
+            dict that enters the pipeline.
+    Returns:
+        (list[Tensor]): The segmentation result.
+    """
+    cfg = model.cfg
+    first_param = next(model.parameters())  # used to find the model device
+    # build the data pipeline: the first configured step is replaced by
+    # LoadImage
+    pipeline = Compose([LoadImage()] + cfg.data.test.pipeline[1:])
+    # prepare data
+    batch = collate([pipeline(dict(img=img))], samples_per_gpu=1)
+    if first_param.is_cuda:
+        # scatter to specified GPU
+        batch = scatter(batch, [first_param.device])[0]
+    else:
+        batch['img_metas'] = [meta.data[0] for meta in batch['img_metas']]
+    # forward the model
+    with torch.no_grad():
+        return model(return_loss=False, rescale=True, **batch)
+def show_result_pyplot(model,
+                       img,
+                       result,
+                       palette=None,
+                       fig_size=(15, 10),
+                       opacity=0.5,
+                       title='',
+                       block=True):
+    """Paint the segmentation result on the image and return it.
+    No pyplot figure is opened: ``fig_size``, ``title`` and ``block`` are
+    accepted but not used by this body.
+    Args:
+        model (nn.Module): The loaded segmentor.
+        img (str or np.ndarray): Image filename or loaded image.
+        result (list): The segmentation result.
+        palette (list[list[int]]] | None): The palette of segmentation
+            map. If None is given, random palette will be generated.
+            Default: None
+        fig_size (tuple): Figure size of the pyplot figure.
+        opacity(float): Opacity of painted segmentation map.
+            Default 0.5.
+            Must be in (0, 1] range.
+        title (str): The title of pyplot figure.
+            Default is ''.
+        block (bool): Whether to block the pyplot figure.
+            Default is True.
+    Returns:
+        The output of ``mmcv.bgr2rgb`` applied to the image returned by the
+        model's ``show_result``.
+    """
+    # unwrap a model that exposes the real segmentor as ``.module``
+    segmentor = model.module if hasattr(model, 'module') else model
+    painted = segmentor.show_result(
+        img, result, palette=palette, show=False, opacity=opacity)
+    return mmcv.bgr2rgb(painted)
diff --git a/ControlNet/annotator/uniformer/mmseg/apis/test.py b/ControlNet/annotator/uniformer/mmseg/apis/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..e574eb7da04f09a59cf99ff953c36468ae87a326
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/apis/test.py
@@ -0,0 +1,238 @@
+import os.path as osp
+import pickle
+import shutil
+import tempfile
+import annotator.uniformer.mmcv as mmcv
+import numpy as np
+import torch
+import torch.distributed as dist
+from annotator.uniformer.mmcv.image import tensor2imgs
+from annotator.uniformer.mmcv.runner import get_dist_info
+def np2tmp(array, temp_file_name=None):
+    """Save ndarray to local numpy file.
+    Args:
+        array (ndarray): Ndarray to save.
+        temp_file_name (str): Numpy file name. If 'temp_file_name=None', this
+            function will generate a file name with tempfile.NamedTemporaryFile
+            to save ndarray. Default: None.
+    Returns:
+        str: The numpy file name.
+    """
+    if temp_file_name is None:
+        # only the name is needed: close the handle right away instead of
+        # leaving it to the garbage collector; delete=False keeps the file so
+        # np.save can reopen the path
+        with tempfile.NamedTemporaryFile(
+                suffix='.npy', delete=False) as tmp_file:
+            temp_file_name = tmp_file.name
+    np.save(temp_file_name, array)
+    return temp_file_name
+def single_gpu_test(model,
+                    data_loader,
+                    show=False,
+                    out_dir=None,
+                    efficient_test=False,
+                    opacity=0.5):
+    """Test with single GPU.
+    Args:
+        model (nn.Module): Model to be tested.
+        data_loader (utils.data.Dataloader): Pytorch data loader.
+        show (bool): Whether show results during inference. Default: False.
+        out_dir (str, optional): If specified, the results will be dumped into
+            the directory to save output results.
+        efficient_test (bool): Whether save the results as local numpy files to
+            save CPU memory during evaluation. Default: False.
+        opacity(float): Opacity of painted segmentation map.
+            Default 0.5.
+            Must be in (0, 1] range.
+    Returns:
+        list: The prediction results.
+    """
+    model.eval()
+    results = []
+    dataset = data_loader.dataset
+    prog_bar = mmcv.ProgressBar(len(dataset))
+    for i, data in enumerate(data_loader):
+        with torch.no_grad():
+            result = model(return_loss=False, **data)
+        if show or out_dir:
+            img_tensor = data['img'][0]
+            img_metas = data['img_metas'][0].data[0]
+            imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
+            assert len(imgs) == len(img_metas)
+            for img, img_meta in zip(imgs, img_metas):
+                # crop to img_shape, then resize back to the original size
+                h, w, _ = img_meta['img_shape']
+                img_show = img[:h, :w, :]
+                ori_h, ori_w = img_meta['ori_shape'][:-1]
+                img_show = mmcv.imresize(img_show, (ori_w, ori_h))
+                if out_dir:
+                    out_file = osp.join(out_dir, img_meta['ori_filename'])
+                else:
+                    out_file = None
+                model.module.show_result(
+                    img_show,
+                    result,
+                    palette=dataset.PALETTE,
+                    show=show,
+                    out_file=out_file,
+                    opacity=opacity)
+        # Measure the batch before efficient_test may replace a non-list
+        # result with a temp-file path: len() of that path string would
+        # advance the progress bar by the number of characters in the name.
+        batch_size = len(result)
+        if isinstance(result, list):
+            if efficient_test:
+                result = [np2tmp(_) for _ in result]
+            results.extend(result)
+        else:
+            if efficient_test:
+                result = np2tmp(result)
+            results.append(result)
+        for _ in range(batch_size):
+            prog_bar.update()
+    return results
+def multi_gpu_test(model,
+ data_loader,
+ tmpdir=None,
+ gpu_collect=False,
+ efficient_test=False):
+ """Test model with multiple gpus.
+ This method tests model with multiple gpus and collects the results
+ under two different modes: gpu and cpu modes. By setting 'gpu_collect=True'
+ it encodes results to gpu tensors and use gpu communication for results
+ collection. On cpu mode it saves the results on different gpus to 'tmpdir'
+ and collects them by the rank 0 worker.
+ Args:
+ model (nn.Module): Model to be tested.
+ data_loader (utils.data.Dataloader): Pytorch data loader.
+ tmpdir (str): Path of directory to save the temporary results from
+ different gpus under cpu mode.
+ gpu_collect (bool): Option to use either gpu or cpu to collect results.
+ efficient_test (bool): Whether save the results as local numpy files to
+ save CPU memory during evaluation. Default: False.
+ Returns:
+ list: The prediction results.
+ """
+ model.eval()
+ results = []
+ dataset = data_loader.dataset
+ rank, world_size = get_dist_info()
+ # the progress bar is created and advanced on rank 0 only
+ if rank == 0:
+ prog_bar = mmcv.ProgressBar(len(dataset))
+ for i, data in enumerate(data_loader):
+ with torch.no_grad():
+ result = model(return_loss=False, rescale=True, **data)
+ # with efficient_test each result is replaced by the path returned by
+ # np2tmp
+ if isinstance(result, list):
+ if efficient_test:
+ result = [np2tmp(_) for _ in result]
+ results.extend(result)
+ else:
+ if efficient_test:
+ result = np2tmp(result)
+ results.append(result)
+ if rank == 0:
+ # rank 0 advances the bar by its own batch size times world_size
+ batch_size = data['img'][0].size(0)
+ for _ in range(batch_size * world_size):
+ prog_bar.update()
+ # collect results from all ranks
+ if gpu_collect:
+ results = collect_results_gpu(results, len(dataset))
+ else:
+ results = collect_results_cpu(results, len(dataset), tmpdir)
+ return results
+def collect_results_cpu(result_part, size, tmpdir=None):
+ """Collect results with CPU.
+ Every rank dumps its part to ``part_<rank>.pkl`` inside a shared
+ directory; after a barrier, rank 0 loads all parts, interleaves them,
+ truncates to ``size`` and removes the directory.
+ Args:
+ result_part (list): Results produced by the calling rank.
+ size (int): Number of results to keep after merging.
+ tmpdir (str, optional): Directory for the part files. When None,
+ rank 0 creates one with ``tempfile.mkdtemp`` and broadcasts its
+ path to the other ranks.
+ Returns:
+ list or None: The merged results on rank 0, None on other ranks.
+ """
+ rank, world_size = get_dist_info()
+ # create a tmp dir if it is not specified
+ if tmpdir is None:
+ # the path travels as a fixed-length uint8 CUDA tensor padded with
+ # spaces, which are stripped again after decoding
+ MAX_LEN = 512
+ # 32 is whitespace
+ dir_tensor = torch.full((MAX_LEN, ),
+ 32,
+ dtype=torch.uint8,
+ device='cuda')
+ if rank == 0:
+ tmpdir = tempfile.mkdtemp()
+ tmpdir = torch.tensor(
+ bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
+ dir_tensor[:len(tmpdir)] = tmpdir
+ dist.broadcast(dir_tensor, 0)
+ tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
+ else:
+ mmcv.mkdir_or_exist(tmpdir)
+ # dump the part result to the dir
+ mmcv.dump(result_part, osp.join(tmpdir, 'part_{}.pkl'.format(rank)))
+ # wait until every rank has written its part file
+ dist.barrier()
+ # collect all parts
+ if rank != 0:
+ return None
+ else:
+ # load results of all parts from tmp dir
+ part_list = []
+ for i in range(world_size):
+ part_file = osp.join(tmpdir, 'part_{}.pkl'.format(i))
+ part_list.append(mmcv.load(part_file))
+ # sort the results
+ # zip(*part_list) takes one item from each rank in turn
+ ordered_results = []
+ for res in zip(*part_list):
+ ordered_results.extend(list(res))
+ # the dataloader may pad some samples
+ ordered_results = ordered_results[:size]
+ # remove tmp dir
+ shutil.rmtree(tmpdir)
+ return ordered_results
+def collect_results_gpu(result_part, size):
+ """Collect results with GPU.
+ Each rank pickles its part into a uint8 CUDA tensor; the tensors are
+ padded to a common length and exchanged with ``all_gather``. Rank 0
+ unpickles every part, interleaves them and truncates to ``size``.
+ Args:
+ result_part (list): Results produced by the calling rank.
+ size (int): Number of results to keep after merging.
+ Returns:
+ list: The merged results on rank 0.
+ """
+ # NOTE(review): presumably ranks other than 0 fall through and return
+ # None; the flattened indentation here does not show whether the final
+ # return sits inside the rank check - confirm.
+ rank, world_size = get_dist_info()
+ # dump result part to tensor with pickle
+ part_tensor = torch.tensor(
+ bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
+ # gather all result part tensor shape
+ shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
+ shape_list = [shape_tensor.clone() for _ in range(world_size)]
+ dist.all_gather(shape_list, shape_tensor)
+ # padding result part tensor to max length
+ shape_max = torch.tensor(shape_list).max()
+ part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
+ part_send[:shape_tensor[0]] = part_tensor
+ part_recv_list = [
+ part_tensor.new_zeros(shape_max) for _ in range(world_size)
+ ]
+ # gather all result part
+ dist.all_gather(part_recv_list, part_send)
+ if rank == 0:
+ part_list = []
+ for recv, shape in zip(part_recv_list, shape_list):
+ # cut the padding off before unpickling; the bytes come from the
+ # peer ranks of this job
+ part_list.append(
+ pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()))
+ # sort the results
+ # zip(*part_list) takes one item from each rank in turn
+ ordered_results = []
+ for res in zip(*part_list):
+ ordered_results.extend(list(res))
+ # the dataloader may pad some samples
+ ordered_results = ordered_results[:size]
+ return ordered_results
diff --git a/ControlNet/annotator/uniformer/mmseg/apis/train.py b/ControlNet/annotator/uniformer/mmseg/apis/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..63f319a919ff023931a6a663e668f27dd1a07a2e
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/apis/train.py
@@ -0,0 +1,116 @@
+import random
+import warnings
+import numpy as np
+import torch
+from annotator.uniformer.mmcv.parallel import MMDataParallel, MMDistributedDataParallel
+from annotator.uniformer.mmcv.runner import build_optimizer, build_runner
+from annotator.uniformer.mmseg.core import DistEvalHook, EvalHook
+from annotator.uniformer.mmseg.datasets import build_dataloader, build_dataset
+from annotator.uniformer.mmseg.utils import get_root_logger
def set_random_seed(seed, deterministic=False):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices).

    Args:
        seed (int): Seed handed to every random number generator.
        deterministic (bool): When True, additionally set
            `torch.backends.cudnn.deterministic` to True and
            `torch.backends.cudnn.benchmark` to False.
            Default: False.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    if not deterministic:
        return
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def train_segmentor(model,
                    dataset,
                    cfg,
                    distributed=False,
                    validate=False,
                    timestamp=None,
                    meta=None):
    """Build the data loaders, wrap the model, set up a runner and train.

    Args:
        model: Segmentor to train. It is moved to GPU and wrapped in
            ``MMDistributedDataParallel`` or ``MMDataParallel`` here.
        dataset: A dataset, or a list/tuple of datasets; one training
            dataloader is built per entry.
        cfg: Config object, read through attribute access and ``cfg.get``.
            ``cfg.runner`` is filled in with an ``IterBasedRunner`` default
            (with a warning) when it is missing.
        distributed (bool): Selects the distributed wrapper and
            ``DistEvalHook``. Default: False.
        validate (bool): Whether to register an evaluation hook built from
            ``cfg.data.val``. Default: False.
        timestamp: Value stored on ``runner.timestamp``. Default: None.
        meta: Passed to the runner as ``meta``. Default: None.
    """
    logger = get_root_logger(cfg.log_level)

    # One training dataloader per dataset.
    datasets = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    data_loaders = []
    for ds in datasets:
        loader = build_dataloader(
            ds,
            cfg.data.samples_per_gpu,
            cfg.data.workers_per_gpu,
            # cfg.gpus will be ignored if distributed
            len(cfg.gpu_ids),
            dist=distributed,
            seed=cfg.seed,
            drop_last=True)
        data_loaders.append(loader)

    # Put the model on GPU(s) behind the matching parallel wrapper.
    if not distributed:
        model = MMDataParallel(
            model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
    else:
        # Forwarded to torch.nn.parallel.DistributedDataParallel.
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)

    optimizer = build_optimizer(model, cfg.optimizer)

    # Older configs have no `runner` section: fall back to iteration-based.
    if cfg.get('runner') is None:
        cfg.runner = {'type': 'IterBasedRunner', 'max_iters': cfg.total_iters}
        warnings.warn(
            'config is now expected to have a `runner` section, '
            'please set `runner` in your config.', UserWarning)

    runner_cfg = cfg.runner
    runner_defaults = dict(
        model=model,
        batch_processor=None,
        optimizer=optimizer,
        work_dir=cfg.work_dir,
        logger=logger,
        meta=meta)
    runner = build_runner(runner_cfg, default_args=runner_defaults)

    runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))

    # Workaround so the .log and .log.json filenames share one timestamp.
    runner.timestamp = timestamp

    if validate:
        val_dataset = build_dataset(cfg.data.val, {'test_mode': True})
        val_dataloader = build_dataloader(
            val_dataset,
            samples_per_gpu=1,
            workers_per_gpu=cfg.data.workers_per_gpu,
            dist=distributed,
            shuffle=False)
        eval_cfg = cfg.get('evaluation', {})
        # Evaluate per epoch unless the runner is iteration based.
        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
        hook_cls = DistEvalHook if distributed else EvalHook
        runner.register_hook(
            hook_cls(val_dataloader, **eval_cfg), priority='LOW')

    # Resuming takes precedence over plain checkpoint loading.
    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)

    runner.run(data_loaders, cfg.workflow)
diff --git a/ControlNet/annotator/uniformer/mmseg/core/__init__.py b/ControlNet/annotator/uniformer/mmseg/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..965605587211b7bf0bd6bc3acdbb33dd49cab023
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/core/__init__.py
@@ -0,0 +1,3 @@
+from .evaluation import * # noqa: F401, F403
+from .seg import * # noqa: F401, F403
+from .utils import * # noqa: F401, F403
diff --git a/ControlNet/annotator/uniformer/mmseg/core/evaluation/__init__.py b/ControlNet/annotator/uniformer/mmseg/core/evaluation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7cc4b23413a0639e9de00eeb0bf600632d2c6cd
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/core/evaluation/__init__.py
@@ -0,0 +1,8 @@
+from .class_names import get_classes, get_palette
+from .eval_hooks import DistEvalHook, EvalHook
+from .metrics import eval_metrics, mean_dice, mean_fscore, mean_iou
# Names re-exported by this package; the list literal is closed explicitly.
__all__ = [
    'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'mean_fscore',
    'eval_metrics', 'get_classes', 'get_palette'
]
diff --git a/ControlNet/annotator/uniformer/mmseg/core/evaluation/class_names.py b/ControlNet/annotator/uniformer/mmseg/core/evaluation/class_names.py
new file mode 100644
index 0000000000000000000000000000000000000000..ffae816cf980ce4b03e491cc0c4298cb823797e6
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/core/evaluation/class_names.py
@@ -0,0 +1,152 @@
+import annotator.uniformer.mmcv as mmcv
def cityscapes_classes():
    """Return the 19 Cityscapes class names as a fresh list."""
    names = (
        'road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
        'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
        'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
        'bicycle',
    )
    return list(names)
def ade_classes():
    """Return the 150 ADE20K class names as a fresh list.

    The names are reproduced verbatim, including the trailing space in
    ``'bed '``.
    """
    names = (
        'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ',
        'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',
        'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',
        'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',
        'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',
        'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
        'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
        'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',
        'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',
        'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',
        'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',
        'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
        'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
        'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
        'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',
        'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
        'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',
        'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',
        'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',
        'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',
        'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',
        'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen',
        'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
        'clock', 'flag',
    )
    return list(names)
def voc_classes():
    """Return the 21 Pascal VOC class names (background first) as a list."""
    names = (
        'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
        'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
        'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
        'tvmonitor',
    )
    return list(names)
def cityscapes_palette():
    """Return the Cityscapes palette: one three-int colour per class."""
    palette = [
        [128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
        [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
        [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
        [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
        [0, 0, 230], [119, 11, 32],
    ]
    return palette
def ade_palette():
    """Return the ADE20K palette: 150 three-int colours, one per class."""
    palette = [
        [120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
        [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
        [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
        [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
        [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
        [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
        [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
        [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
        [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
        [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
        [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
        [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
        [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
        [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
        [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
        [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
        [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
        [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
        [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
        [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
        [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
        [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
        [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
        [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
        [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
        [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
        [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
        [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
        [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
        [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
        [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
        [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
        [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
        [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
        [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
        [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
        [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
        [102, 255, 0], [92, 0, 255],
    ]
    return palette
def voc_palette():
    """Return the Pascal VOC palette: 21 three-int colours, one per class."""
    palette = [
        [0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
        [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
        [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
        [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
        [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128],
    ]
    return palette
# Canonical dataset name -> every alias accepted for it. The canonical name
# is also the prefix of the matching ``<name>_classes`` / ``<name>_palette``
# functions in this module.
dataset_aliases = {
    'cityscapes': ['cityscapes'],
    'ade': ['ade', 'ade20k'],
    'voc': ['voc', 'pascal_voc', 'voc12', 'voc12aug']
}
def get_classes(dataset):
    """Get class names of a dataset.

    Args:
        dataset (str): A dataset alias listed in ``dataset_aliases``.

    Returns:
        The value returned by the matching ``<name>_classes()`` function of
        this module (a list of class-name strings for the datasets defined
        here).

    Raises:
        TypeError: If ``dataset`` is not a string (per ``mmcv.is_str``).
        ValueError: If ``dataset`` is not a known alias.
    """
    if not mmcv.is_str(dataset):
        raise TypeError(f'dataset must a str, but got {type(dataset)}')

    alias2name = {
        alias: name
        for name, aliases in dataset_aliases.items() for alias in aliases
    }
    if dataset not in alias2name:
        raise ValueError(f'Unrecognized dataset: {dataset}')

    # Resolve the provider function by name in this module instead of
    # eval()-ing a string built at runtime.
    provider = globals()[alias2name[dataset] + '_classes']
    return provider()
def get_palette(dataset):
    """Get class palette (RGB) of a dataset.

    Args:
        dataset (str): A dataset alias listed in ``dataset_aliases``.

    Returns:
        The value returned by the matching ``<name>_palette()`` function of
        this module (a list of three-int colours for the datasets defined
        here).

    Raises:
        TypeError: If ``dataset`` is not a string (per ``mmcv.is_str``).
        ValueError: If ``dataset`` is not a known alias.
    """
    if not mmcv.is_str(dataset):
        raise TypeError(f'dataset must a str, but got {type(dataset)}')

    alias2name = {
        alias: name
        for name, aliases in dataset_aliases.items() for alias in aliases
    }
    if dataset not in alias2name:
        raise ValueError(f'Unrecognized dataset: {dataset}')

    # Resolve the provider function by name in this module instead of
    # eval()-ing a string built at runtime.
    provider = globals()[alias2name[dataset] + '_palette']
    return provider()
diff --git a/ControlNet/annotator/uniformer/mmseg/core/evaluation/eval_hooks.py b/ControlNet/annotator/uniformer/mmseg/core/evaluation/eval_hooks.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fc100c8f96e817a6ed2666f7c9f762af2463b48
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/core/evaluation/eval_hooks.py
@@ -0,0 +1,109 @@
+import os.path as osp
+from annotator.uniformer.mmcv.runner import DistEvalHook as _DistEvalHook
+from annotator.uniformer.mmcv.runner import EvalHook as _EvalHook
class EvalHook(_EvalHook):
    """Single GPU EvalHook, with efficient test support.

    Args:
        by_epoch (bool): Determine perform evaluation by epoch or by iteration.
            If set to True, it will perform by epoch. Otherwise, by iteration.
            Default: False.
        efficient_test (bool): Whether save the results as local numpy files to
            save CPU memory during evaluation. Honoured in both the
            per-iteration and the per-epoch hook. Default: False.
    """

    greater_keys = ['mIoU', 'mAcc', 'aAcc']

    def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs):
        super().__init__(*args, by_epoch=by_epoch, **kwargs)
        self.efficient_test = efficient_test

    def _run_single_gpu_eval(self, runner):
        """Run ``single_gpu_test`` on the eval dataloader and evaluate."""
        # Imported inside the method, as the original hooks did
        # (presumably to avoid a circular import with mmseg.apis).
        from annotator.uniformer.mmseg.apis import single_gpu_test
        runner.log_buffer.clear()
        results = single_gpu_test(
            runner.model,
            self.dataloader,
            show=False,
            efficient_test=self.efficient_test)
        self.evaluate(runner, results)

    def after_train_iter(self, runner):
        """After train iteration hook.

        Override default ``single_gpu_test``. Does nothing in by-epoch mode
        or when the current iteration is not on the evaluation interval.
        """
        if self.by_epoch or not self.every_n_iters(runner, self.interval):
            return
        self._run_single_gpu_eval(runner)

    def after_train_epoch(self, runner):
        """After train epoch hook.

        Override default ``single_gpu_test``. Does nothing in by-iteration
        mode or when the current epoch is not on the evaluation interval.
        """
        if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
            return
        # Previously ``efficient_test`` was dropped here; the shared helper
        # forwards it just like the per-iteration path.
        self._run_single_gpu_eval(runner)
class DistEvalHook(_DistEvalHook):
    """Distributed EvalHook, with efficient test support.

    Args:
        by_epoch (bool): Determine perform evaluation by epoch or by iteration.
            If set to True, it will perform by epoch. Otherwise, by iteration.
            Default: False.
        efficient_test (bool): Whether save the results as local numpy files to
            save CPU memory during evaluation. Honoured in both the
            per-iteration and the per-epoch hook. Default: False.
    """

    greater_keys = ['mIoU', 'mAcc', 'aAcc']

    def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs):
        super().__init__(*args, by_epoch=by_epoch, **kwargs)
        self.efficient_test = efficient_test

    def _run_multi_gpu_eval(self, runner):
        """Run ``multi_gpu_test`` on every rank and evaluate on rank 0."""
        # Imported inside the method, as the original hooks did
        # (presumably to avoid a circular import with mmseg.apis).
        from annotator.uniformer.mmseg.apis import multi_gpu_test
        runner.log_buffer.clear()
        results = multi_gpu_test(
            runner.model,
            self.dataloader,
            tmpdir=osp.join(runner.work_dir, '.eval_hook'),
            gpu_collect=self.gpu_collect,
            efficient_test=self.efficient_test)
        # Only rank 0 evaluates: the result collectors in this package
        # return None on every other rank.
        if runner.rank == 0:
            print('\n')
            self.evaluate(runner, results)

    def after_train_iter(self, runner):
        """After train iteration hook.

        Override default ``multi_gpu_test``. Does nothing in by-epoch mode
        or when the current iteration is not on the evaluation interval.
        """
        if self.by_epoch or not self.every_n_iters(runner, self.interval):
            return
        self._run_multi_gpu_eval(runner)

    def after_train_epoch(self, runner):
        """After train epoch hook.

        Override default ``multi_gpu_test``. Does nothing in by-iteration
        mode or when the current epoch is not on the evaluation interval.
        """
        if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
            return
        # Previously ``efficient_test`` was dropped here; the shared helper
        # forwards it just like the per-iteration path.
        self._run_multi_gpu_eval(runner)
diff --git a/ControlNet/annotator/uniformer/mmseg/core/evaluation/metrics.py b/ControlNet/annotator/uniformer/mmseg/core/evaluation/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..16c7dd47cadd53cf1caaa194e28a343f2aacc599
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/core/evaluation/metrics.py
@@ -0,0 +1,326 @@
+from collections import OrderedDict
+import annotator.uniformer.mmcv as mmcv
+import numpy as np
+import torch
def f_score(precision, recall, beta=1):
    """Calculate the F-beta score from precision and recall.

    Args:
        precision (float | torch.Tensor): The precision value.
        recall (float | torch.Tensor): The recall value.
        beta (int): Determines the weight of recall in the combined score.
            Default: 1.

    Returns:
        float | torch.Tensor: ``(1 + beta^2) * P * R / (beta^2 * P + R)``,
        of the same kind as the inputs.
    """
    beta_sq = beta**2
    numerator = (1 + beta_sq) * (precision * recall)
    denominator = (beta_sq * precision) + recall
    return numerator / denominator
def intersect_and_union(pred_label,
                        label,
                        num_classes,
                        ignore_index,
                        label_map=dict(),
                        reduce_zero_label=False):
    """Calculate intersection and Union.

    Args:
        pred_label (ndarray | str): Prediction segmentation map
            or predict result filename (loaded with ``np.load``).
        label (ndarray | str): Ground truth segmentation map
            or label filename (read with ``mmcv.imread``).
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        label_map (dict): Mapping old labels to new labels. Every old id is
            matched against the labels as loaded, so mappings whose targets
            are also keys (e.g. ``{1: 2, 2: 0}``) do not chain.
            Default: dict().
        reduce_zero_label (bool): Whether ignore zero label: label 0 becomes
            255 and every other label is shifted down by one.
            Default: False.

    Returns:
        torch.Tensor: The intersection of prediction and ground truth
            histogram on all classes.
        torch.Tensor: The union of prediction and ground truth histogram on
            all classes.
        torch.Tensor: The prediction histogram on all classes.
        torch.Tensor: The ground truth histogram on all classes.
    """
    if isinstance(pred_label, str):
        pred_label = torch.from_numpy(np.load(pred_label))
    else:
        pred_label = torch.from_numpy(pred_label)

    if isinstance(label, str):
        label = torch.from_numpy(
            mmcv.imread(label, flag='unchanged', backend='pillow'))
    else:
        label = torch.from_numpy(label)

    if label_map or reduce_zero_label:
        # torch.from_numpy shares memory with the source array, so edit a
        # private copy instead of the caller's ground-truth map.
        loaded = label
        label = loaded.clone()
        if label_map:
            # Match against the untouched labels so one remap cannot feed
            # into the next.
            for old_id, new_id in label_map.items():
                label[loaded == old_id] = new_id
        if reduce_zero_label:
            # Park 0 at 255 first so the shift by one cannot underflow.
            label[label == 0] = 255
            label = label - 1
            label[label == 254] = 255

    # Drop ignored pixels from both maps (this also flattens them).
    mask = (label != ignore_index)
    pred_label = pred_label[mask]
    label = label[mask]

    intersect = pred_label[pred_label == label]
    area_intersect = torch.histc(
        intersect.float(), bins=(num_classes), min=0, max=num_classes - 1)
    area_pred_label = torch.histc(
        pred_label.float(), bins=(num_classes), min=0, max=num_classes - 1)
    area_label = torch.histc(
        label.float(), bins=(num_classes), min=0, max=num_classes - 1)
    area_union = area_pred_label + area_label - area_intersect
    return area_intersect, area_union, area_pred_label, area_label
def total_intersect_and_union(results,
                              gt_seg_maps,
                              num_classes,
                              ignore_index,
                              label_map=dict(),
                              reduce_zero_label=False):
    """Sum the per-image ``intersect_and_union`` histograms over a dataset.

    Args:
        results (list[ndarray] | list[str]): List of prediction segmentation
            maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str]): list of ground truth
            segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        label_map (dict): Mapping old labels to new labels. Default: dict().
        reduce_zero_label (bool): Whether ignore zero label. Default: False.

    Returns:
        tuple[torch.Tensor]: Four float64 tensors of shape (num_classes, ),
        in this order: total intersection, total union, total prediction
        histogram, total ground truth histogram.
    """
    num_imgs = len(results)
    assert len(gt_seg_maps) == num_imgs

    # Running sums, in the order intersect_and_union returns its values.
    totals = tuple(
        torch.zeros((num_classes, ), dtype=torch.float64) for _ in range(4))
    for idx in range(num_imgs):
        per_image = intersect_and_union(
            results[idx], gt_seg_maps[idx], num_classes, ignore_index,
            label_map, reduce_zero_label)
        for running, area in zip(totals, per_image):
            running += area
    return totals
def mean_iou(results,
             gt_seg_maps,
             num_classes,
             ignore_index,
             nan_to_num=None,
             label_map=dict(),
             reduce_zero_label=False):
    """Calculate Mean Intersection and Union (mIoU).

    Thin wrapper that calls ``eval_metrics`` with ``metrics=['mIoU']``.

    Args:
        results (list[ndarray] | list[str]): List of prediction segmentation
            maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str]): list of ground truth
            segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        label_map (dict): Mapping old labels to new labels. Default: dict().
        reduce_zero_label (bool): Whether ignore zero label. Default: False.

    Returns:
        dict[str, float | ndarray]: ``eval_metrics`` output with the keys
        ``'aAcc'`` (overall accuracy), ``'IoU'`` and ``'Acc'`` (per
        category, shape (num_classes, )).
    """
    return eval_metrics(
        results=results,
        gt_seg_maps=gt_seg_maps,
        num_classes=num_classes,
        ignore_index=ignore_index,
        metrics=['mIoU'],
        nan_to_num=nan_to_num,
        label_map=label_map,
        reduce_zero_label=reduce_zero_label)
def mean_dice(results,
              gt_seg_maps,
              num_classes,
              ignore_index,
              nan_to_num=None,
              label_map=dict(),
              reduce_zero_label=False):
    """Calculate Mean Dice (mDice).

    Thin wrapper that calls ``eval_metrics`` with ``metrics=['mDice']``.

    Args:
        results (list[ndarray] | list[str]): List of prediction segmentation
            maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str]): list of ground truth
            segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        label_map (dict): Mapping old labels to new labels. Default: dict().
        reduce_zero_label (bool): Whether ignore zero label. Default: False.

    Returns:
        dict[str, float | ndarray]: ``eval_metrics`` output with the keys
        ``'aAcc'`` (overall accuracy), ``'Dice'`` and ``'Acc'`` (per
        category, shape (num_classes, )).
    """
    return eval_metrics(
        results=results,
        gt_seg_maps=gt_seg_maps,
        num_classes=num_classes,
        ignore_index=ignore_index,
        metrics=['mDice'],
        nan_to_num=nan_to_num,
        label_map=label_map,
        reduce_zero_label=reduce_zero_label)
def mean_fscore(results,
                gt_seg_maps,
                num_classes,
                ignore_index,
                nan_to_num=None,
                label_map=dict(),
                reduce_zero_label=False,
                beta=1):
    """Calculate the mean F-score (mFscore).

    Thin wrapper that calls ``eval_metrics`` with ``metrics=['mFscore']``.

    Args:
        results (list[ndarray] | list[str]): List of prediction segmentation
            maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str]): list of ground truth
            segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        label_map (dict): Mapping old labels to new labels. Default: dict().
        reduce_zero_label (bool): Whether ignore zero label. Default: False.
        beta (int): Determines the weight of recall in the combined score.
            Default: 1.

    Returns:
        dict[str, float | ndarray]: ``eval_metrics`` output with the keys
        ``'aAcc'`` (overall accuracy), ``'Fscore'``, ``'Precision'`` and
        ``'Recall'`` (per category, shape (num_classes, )).
    """
    return eval_metrics(
        results=results,
        gt_seg_maps=gt_seg_maps,
        num_classes=num_classes,
        ignore_index=ignore_index,
        metrics=['mFscore'],
        nan_to_num=nan_to_num,
        label_map=label_map,
        reduce_zero_label=reduce_zero_label,
        beta=beta)
def eval_metrics(results,
                 gt_seg_maps,
                 num_classes,
                 ignore_index,
                 metrics=['mIoU'],
                 nan_to_num=None,
                 label_map=dict(),
                 reduce_zero_label=False,
                 beta=1):
    """Calculate evaluation metrics.

    Args:
        results (list[ndarray] | list[str]): List of prediction segmentation
            maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str]): list of ground truth
            segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        metrics (list[str] | str): Metrics to be evaluated: 'mIoU', 'mDice'
            and/or 'mFscore'.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        label_map (dict): Mapping old labels to new labels. Default: dict().
        reduce_zero_label (bool): Whether ignore zero label. Default: False.
        beta (int): Weight of recall in the F-score; only used for
            'mFscore'. Default: 1.

    Returns:
        dict[str, ndarray]: ``'aAcc'`` (overall accuracy) plus, per
        requested metric, per-category arrays: 'IoU' and 'Acc' for 'mIoU';
        'Dice' and 'Acc' for 'mDice'; 'Fscore', 'Precision' and 'Recall'
        for 'mFscore'.

    Raises:
        KeyError: If any requested metric is not supported.
    """
    requested = [metrics] if isinstance(metrics, str) else metrics
    supported = {'mIoU', 'mDice', 'mFscore'}
    if not set(requested) <= supported:
        raise KeyError('metrics {} is not supported'.format(requested))

    inter, union, pred_area, gt_area = total_intersect_and_union(
        results, gt_seg_maps, num_classes, ignore_index, label_map,
        reduce_zero_label)

    ret_metrics = OrderedDict({'aAcc': inter.sum() / gt_area.sum()})
    for metric in requested:
        if metric == 'mIoU':
            ret_metrics['IoU'] = inter / union
            ret_metrics['Acc'] = inter / gt_area
        elif metric == 'mDice':
            ret_metrics['Dice'] = 2 * inter / (pred_area + gt_area)
            ret_metrics['Acc'] = inter / gt_area
        elif metric == 'mFscore':
            precision = inter / pred_area
            recall = inter / gt_area
            ret_metrics['Fscore'] = torch.tensor(
                [f_score(p, r, beta) for p, r in zip(precision, recall)])
            ret_metrics['Precision'] = precision
            ret_metrics['Recall'] = recall

    # Hand plain numpy values back to the caller.
    ret_metrics = {
        name: value.numpy()
        for name, value in ret_metrics.items()
    }
    if nan_to_num is not None:
        ret_metrics = OrderedDict({
            name: np.nan_to_num(value, nan=nan_to_num)
            for name, value in ret_metrics.items()
        })
    return ret_metrics
diff --git a/ControlNet/annotator/uniformer/mmseg/core/seg/__init__.py b/ControlNet/annotator/uniformer/mmseg/core/seg/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..93bc129b685e4a3efca2cc891729981b2865900d
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/core/seg/__init__.py
@@ -0,0 +1,4 @@
+from .builder import build_pixel_sampler
+from .sampler import BasePixelSampler, OHEMPixelSampler
+__all__ = ['build_pixel_sampler', 'BasePixelSampler', 'OHEMPixelSampler']
diff --git a/ControlNet/annotator/uniformer/mmseg/core/seg/builder.py b/ControlNet/annotator/uniformer/mmseg/core/seg/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..db61f03d4abb2072f2532ce4429c0842495e015b
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/core/seg/builder.py
@@ -0,0 +1,8 @@
+from annotator.uniformer.mmcv.utils import Registry, build_from_cfg
# Registry that ``build_pixel_sampler`` resolves sampler configs against.
PIXEL_SAMPLERS = Registry('pixel sampler')


def build_pixel_sampler(cfg, **default_args):
    """Build pixel sampler for segmentation map.

    Args:
        cfg: Sampler config passed to ``build_from_cfg``.
        **default_args: Collected into a dict and passed to
            ``build_from_cfg`` as its default arguments.

    Returns:
        Whatever ``build_from_cfg`` builds from ``PIXEL_SAMPLERS``.
    """
    sampler = build_from_cfg(cfg, PIXEL_SAMPLERS, default_args)
    return sampler
diff --git a/ControlNet/annotator/uniformer/mmseg/core/seg/sampler/__init__.py b/ControlNet/annotator/uniformer/mmseg/core/seg/sampler/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..332b242c03d1c5e80d4577df442a9a037b1816e1
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/core/seg/sampler/__init__.py
@@ -0,0 +1,4 @@
+from .base_pixel_sampler import BasePixelSampler
+from .ohem_pixel_sampler import OHEMPixelSampler
+__all__ = ['BasePixelSampler', 'OHEMPixelSampler']
diff --git a/ControlNet/annotator/uniformer/mmseg/core/seg/sampler/base_pixel_sampler.py b/ControlNet/annotator/uniformer/mmseg/core/seg/sampler/base_pixel_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..b75b1566c9f18169cee51d4b55d75e0357b69c57
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/core/seg/sampler/base_pixel_sampler.py
@@ -0,0 +1,12 @@
+from abc import ABCMeta, abstractmethod
class BasePixelSampler(metaclass=ABCMeta):
    """Abstract base class of pixel samplers.

    Subclasses must implement :meth:`sample`.
    """

    def __init__(self, **kwargs):
        """Accept any keyword arguments and ignore them."""

    @abstractmethod
    def sample(self, seg_logit, seg_label):
        """Placeholder for sample function."""
diff --git a/ControlNet/annotator/uniformer/mmseg/core/seg/sampler/ohem_pixel_sampler.py b/ControlNet/annotator/uniformer/mmseg/core/seg/sampler/ohem_pixel_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..88bb10d44026ba9f21756eaea9e550841cd59b9f
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/core/seg/sampler/ohem_pixel_sampler.py
@@ -0,0 +1,76 @@
+import torch
+import torch.nn.functional as F
+from ..builder import PIXEL_SAMPLERS
+from .base_pixel_sampler import BasePixelSampler
class OHEMPixelSampler(BasePixelSampler):
    """Online Hard Example Mining Sampler for segmentation.

    Args:
        context (nn.Module): The context of sampler, subclass of
            :obj:`BaseDecodeHead`. Its ``ignore_index`` and ``loss_decode``
            attributes are used by :meth:`sample`.
        thresh (float, optional): The threshold for hard example selection.
            Below which, are prediction with low confidence. If not
            specified, the hard examples will be pixels of top ``min_kept``
            loss. Default: None.
        min_kept (int, optional): The minimum number of predictions to keep.
            Default: 100000.
    """

    def __init__(self, context, thresh=None, min_kept=100000):
        super().__init__()
        self.context = context
        assert min_kept > 1
        self.thresh = thresh
        self.min_kept = min_kept

    def sample(self, seg_logit, seg_label):
        """Sample pixels that have high loss or with low prediction confidence.

        Args:
            seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W)
            seg_label (torch.Tensor): segmentation label, shape (N, 1, H, W)

        Returns:
            torch.Tensor: segmentation weight, shape (N, H, W); 1 for the
            selected pixels and 0 elsewhere.
        """
        with torch.no_grad():
            assert seg_logit.shape[2:] == seg_label.shape[2:]
            assert seg_label.shape[1] == 1
            ignore_index = self.context.ignore_index
            labels = seg_label.squeeze(1).long()
            # ``min_kept`` is per sample, so scale it by the batch size.
            batch_kept = self.min_kept * labels.size(0)
            valid_mask = labels != ignore_index
            seg_weight = seg_logit.new_zeros(size=labels.size())
            valid_weight = seg_weight[valid_mask]

            if self.thresh is None:
                # Loss mode: keep the ``batch_kept`` highest-loss pixels.
                losses = self.context.loss_decode(
                    seg_logit,
                    labels,
                    weight=None,
                    ignore_index=ignore_index,
                    reduction_override='none')
                # faster than topk according to https://github.com/pytorch/pytorch/issues/22812  # noqa
                _, hardest = losses[valid_mask].sort(descending=True)
                valid_weight[hardest[:batch_kept]] = 1.
            else:
                # Confidence mode: probability assigned to the true class.
                probs = F.softmax(seg_logit, dim=1)
                gather_index = labels.clone().unsqueeze(1)
                # gather() needs a valid class index at ignored pixels;
                # they are masked out right after.
                gather_index[gather_index == ignore_index] = 0
                probs = probs.gather(1, gather_index).squeeze(1)
                valid_probs = probs[valid_mask]
                sorted_probs, _ = valid_probs.sort()
                if sorted_probs.numel() > 0:
                    min_threshold = sorted_probs[min(
                        batch_kept, sorted_probs.numel() - 1)]
                else:
                    min_threshold = 0.0
                threshold = max(min_threshold, self.thresh)
                valid_weight[valid_probs < threshold] = 1.

            seg_weight[valid_mask] = valid_weight
            return seg_weight
diff --git a/ControlNet/annotator/uniformer/mmseg/core/utils/__init__.py b/ControlNet/annotator/uniformer/mmseg/core/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2678b321c295bcceaef945111ac3524be19d6e4
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/core/utils/__init__.py
@@ -0,0 +1,3 @@
+from .misc import add_prefix
+__all__ = ['add_prefix']
diff --git a/ControlNet/annotator/uniformer/mmseg/core/utils/misc.py b/ControlNet/annotator/uniformer/mmseg/core/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb862a82bd47c8624db3dd5c6fb6ad8a03b62466
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/core/utils/misc.py
@@ -0,0 +1,17 @@
+def add_prefix(inputs, prefix):
+ """Add prefix for dict.
+ Args:
+ inputs (dict): The input dict with str keys.
+ prefix (str): The prefix to add.
+ Returns:
+ dict: The dict with keys updated with ``prefix``.
+ """
+ outputs = dict()
+ for name, value in inputs.items():
+ outputs[f'{prefix}.{name}'] = value
+ return outputs
diff --git a/ControlNet/annotator/uniformer/mmseg/datasets/__init__.py b/ControlNet/annotator/uniformer/mmseg/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebeaef4a28ef655e43578552a8aef6b77f13a636
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/datasets/__init__.py
@@ -0,0 +1,19 @@
+from .ade import ADE20KDataset
+from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset
+from .chase_db1 import ChaseDB1Dataset
+from .cityscapes import CityscapesDataset
+from .custom import CustomDataset
+from .dataset_wrappers import ConcatDataset, RepeatDataset
+from .drive import DRIVEDataset
+from .hrf import HRFDataset
+from .pascal_context import PascalContextDataset, PascalContextDataset59
+from .stare import STAREDataset
+from .voc import PascalVOCDataset
+__all__ = [
+ 'CustomDataset', 'build_dataloader', 'ConcatDataset', 'RepeatDataset',
+ 'DATASETS', 'build_dataset', 'PIPELINES', 'CityscapesDataset',
+ 'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset',
+ 'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset',
+ 'STAREDataset'
diff --git a/ControlNet/annotator/uniformer/mmseg/datasets/ade.py b/ControlNet/annotator/uniformer/mmseg/datasets/ade.py
new file mode 100644
index 0000000000000000000000000000000000000000..5913e43775ed4920b6934c855eb5a37c54218ebf
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/datasets/ade.py
@@ -0,0 +1,84 @@
+from .builder import DATASETS
+from .custom import CustomDataset
+class ADE20KDataset(CustomDataset):
+ """ADE20K dataset.
+ In segmentation map annotation for ADE20K, 0 stands for background, which
+ is not included in 150 categories. ``reduce_zero_label`` is fixed to True.
+ The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to
+ '.png'.
+ """
+ 'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ',
+ 'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',
+ 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',
+ 'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',
+ 'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',
+ 'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
+ 'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
+ 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',
+ 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',
+ 'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',
+ 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',
+ 'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
+ 'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
+ 'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
+ 'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',
+ 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
+ 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',
+ 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',
+ 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',
+ 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',
+ 'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',
+ 'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen',
+ 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
+ 'clock', 'flag')
+ PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
+ [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
+ [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
+ [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
+ [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
+ [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
+ [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
+ [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
+ [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
+ [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
+ [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
+ [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
+ [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
+ [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
+ [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
+ [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
+ [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
+ [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
+ [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
+ [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
+ [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
+ [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
+ [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
+ [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
+ [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
+ [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
+ [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
+ [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
+ [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
+ [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
+ [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
+ [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
+ [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
+ [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
+ [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
+ [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
+ [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
+ [102, 255, 0], [92, 0, 255]]
+ def __init__(self, **kwargs):
+ super(ADE20KDataset, self).__init__(
+ img_suffix='.jpg',
+ seg_map_suffix='.png',
+ reduce_zero_label=True,
+ **kwargs)
diff --git a/ControlNet/annotator/uniformer/mmseg/datasets/builder.py b/ControlNet/annotator/uniformer/mmseg/datasets/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..0798b14cd8b39fc58d8f2a4930f1e079b5bf8b55
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/datasets/builder.py
@@ -0,0 +1,169 @@
+import copy
+import platform
+import random
+from functools import partial
+import numpy as np
+from annotator.uniformer.mmcv.parallel import collate
+from annotator.uniformer.mmcv.runner import get_dist_info
+from annotator.uniformer.mmcv.utils import Registry, build_from_cfg
+from annotator.uniformer.mmcv.utils.parrots_wrapper import DataLoader, PoolDataLoader
+from torch.utils.data import DistributedSampler
+if platform.system() != 'Windows':
+ # https://github.com/pytorch/pytorch/issues/973
+ import resource
+ rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
+ hard_limit = rlimit[1]
+ soft_limit = min(4096, hard_limit)
+ resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))
+# Registry that build_dataset() resolves dataset configs against (it is the
+# registry handed to build_from_cfg).
+DATASETS = Registry('dataset')
+# Registry named 'pipeline'; nothing in this module looks anything up in it.
+# NOTE(review): presumably holds the data-pipeline transforms — confirm.
+PIPELINES = Registry('pipeline')
+def _concat_dataset(cfg, default_args=None):
+ """Build :obj:`ConcatDataset by."""
+ from .dataset_wrappers import ConcatDataset
+ img_dir = cfg['img_dir']
+ ann_dir = cfg.get('ann_dir', None)
+ split = cfg.get('split', None)
+ num_img_dir = len(img_dir) if isinstance(img_dir, (list, tuple)) else 1
+ if ann_dir is not None:
+ num_ann_dir = len(ann_dir) if isinstance(ann_dir, (list, tuple)) else 1
+ else:
+ num_ann_dir = 0
+ if split is not None:
+ num_split = len(split) if isinstance(split, (list, tuple)) else 1
+ else:
+ num_split = 0
+ if num_img_dir > 1:
+ assert num_img_dir == num_ann_dir or num_ann_dir == 0
+ assert num_img_dir == num_split or num_split == 0
+ else:
+ assert num_split == num_ann_dir or num_ann_dir <= 1
+ num_dset = max(num_split, num_img_dir)
+ datasets = []
+ for i in range(num_dset):
+ data_cfg = copy.deepcopy(cfg)
+ if isinstance(img_dir, (list, tuple)):
+ data_cfg['img_dir'] = img_dir[i]
+ if isinstance(ann_dir, (list, tuple)):
+ data_cfg['ann_dir'] = ann_dir[i]
+ if isinstance(split, (list, tuple)):
+ data_cfg['split'] = split[i]
+ datasets.append(build_dataset(data_cfg, default_args))
+ return ConcatDataset(datasets)
+def build_dataset(cfg, default_args=None):
+ """Build datasets."""
+ from .dataset_wrappers import ConcatDataset, RepeatDataset
+ if isinstance(cfg, (list, tuple)):
+ dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
+ elif cfg['type'] == 'RepeatDataset':
+ dataset = RepeatDataset(
+ build_dataset(cfg['dataset'], default_args), cfg['times'])
+ elif isinstance(cfg.get('img_dir'), (list, tuple)) or isinstance(
+ cfg.get('split', None), (list, tuple)):
+ dataset = _concat_dataset(cfg, default_args)
+ else:
+ dataset = build_from_cfg(cfg, DATASETS, default_args)
+ return dataset
+def build_dataloader(dataset,
+ samples_per_gpu,
+ workers_per_gpu,
+ num_gpus=1,
+ dist=True,
+ shuffle=True,
+ seed=None,
+ drop_last=False,
+ pin_memory=True,
+ dataloader_type='PoolDataLoader',
+ **kwargs):
+ """Build PyTorch DataLoader.
+ In distributed training, each GPU/process has a dataloader.
+ In non-distributed training, there is only one dataloader for all GPUs.
+ Args:
+ dataset (Dataset): A PyTorch dataset.
+ samples_per_gpu (int): Number of training samples on each GPU, i.e.,
+ batch size of each GPU.
+ workers_per_gpu (int): How many subprocesses to use for data loading
+ for each GPU.
+ num_gpus (int): Number of GPUs. Only used in non-distributed training.
+ dist (bool): Distributed training/test or not. Default: True.
+ shuffle (bool): Whether to shuffle the data at every epoch.
+ Default: True.
+ seed (int | None): Seed to be used. Default: None.
+ drop_last (bool): Whether to drop the last incomplete batch in epoch.
+ Default: False
+ pin_memory (bool): Whether to use pin_memory in DataLoader.
+ Default: True
+ dataloader_type (str): Type of dataloader. Default: 'PoolDataLoader'
+ kwargs: any keyword argument to be used to initialize DataLoader
+ Returns:
+ DataLoader: A PyTorch dataloader.
+ """
+ rank, world_size = get_dist_info()
+ if dist:
+ sampler = DistributedSampler(
+ dataset, world_size, rank, shuffle=shuffle)
+ shuffle = False
+ batch_size = samples_per_gpu
+ num_workers = workers_per_gpu
+ else:
+ sampler = None
+ batch_size = num_gpus * samples_per_gpu
+ num_workers = num_gpus * workers_per_gpu
+ init_fn = partial(
+ worker_init_fn, num_workers=num_workers, rank=rank,
+ seed=seed) if seed is not None else None
+ assert dataloader_type in (
+ 'DataLoader',
+ 'PoolDataLoader'), f'unsupported dataloader {dataloader_type}'
+ if dataloader_type == 'PoolDataLoader':
+ dataloader = PoolDataLoader
+ elif dataloader_type == 'DataLoader':
+ dataloader = DataLoader
+ data_loader = dataloader(
+ dataset,
+ batch_size=batch_size,
+ sampler=sampler,
+ num_workers=num_workers,
+ collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
+ pin_memory=pin_memory,
+ shuffle=shuffle,
+ worker_init_fn=init_fn,
+ drop_last=drop_last,
+ **kwargs)
+ return data_loader
+def worker_init_fn(worker_id, num_workers, rank, seed):
+ """Worker init func for dataloader.
+ The seed of each worker equals to num_worker * rank + worker_id + user_seed
+ Args:
+ worker_id (int): Worker id.
+ num_workers (int): Number of workers.
+ rank (int): The rank of current process.
+ seed (int): The random seed to use.
+ """
+ worker_seed = num_workers * rank + worker_id + seed
+ np.random.seed(worker_seed)
+ random.seed(worker_seed)
diff --git a/ControlNet/annotator/uniformer/mmseg/datasets/chase_db1.py b/ControlNet/annotator/uniformer/mmseg/datasets/chase_db1.py
new file mode 100644
index 0000000000000000000000000000000000000000..8bc29bea14704a4407f83474610cbc3bef32c708
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/datasets/chase_db1.py
@@ -0,0 +1,27 @@
+import os.path as osp
+from .builder import DATASETS
+from .custom import CustomDataset
+class ChaseDB1Dataset(CustomDataset):
+ """Chase_db1 dataset.
+ In segmentation map annotation for Chase_db1, 0 stands for background,
+ which is included in 2 categories. ``reduce_zero_label`` is fixed to False.
+ The ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
+ '_1stHO.png'.
+ """
+ CLASSES = ('background', 'vessel')
+ PALETTE = [[120, 120, 120], [6, 230, 230]]
+ def __init__(self, **kwargs):
+ super(ChaseDB1Dataset, self).__init__(
+ img_suffix='.png',
+ seg_map_suffix='_1stHO.png',
+ reduce_zero_label=False,
+ **kwargs)
+ assert osp.exists(self.img_dir)
diff --git a/ControlNet/annotator/uniformer/mmseg/datasets/cityscapes.py b/ControlNet/annotator/uniformer/mmseg/datasets/cityscapes.py
new file mode 100644
index 0000000000000000000000000000000000000000..81e47a914a1aa2e5458e18669d65ffb742f46fc6
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/datasets/cityscapes.py
@@ -0,0 +1,217 @@
+import os.path as osp
+import tempfile
+import annotator.uniformer.mmcv as mmcv
+import numpy as np
+from annotator.uniformer.mmcv.utils import print_log
+from PIL import Image
+from .builder import DATASETS
+from .custom import CustomDataset
+class CityscapesDataset(CustomDataset):
+ """Cityscapes dataset.
+ The ``img_suffix`` is fixed to '_leftImg8bit.png' and ``seg_map_suffix`` is
+ fixed to '_gtFine_labelTrainIds.png' for Cityscapes dataset.
+ """
+ CLASSES = ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
+ 'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
+ 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
+ 'bicycle')
+ PALETTE = [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
+ [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
+ [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
+ [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100],
+ [0, 80, 100], [0, 0, 230], [119, 11, 32]]
+ def __init__(self, **kwargs):
+ super(CityscapesDataset, self).__init__(
+ img_suffix='_leftImg8bit.png',
+ seg_map_suffix='_gtFine_labelTrainIds.png',
+ **kwargs)
+ @staticmethod
+ def _convert_to_label_id(result):
+ """Convert trainId to id for cityscapes."""
+ if isinstance(result, str):
+ result = np.load(result)
+ import cityscapesscripts.helpers.labels as CSLabels
+ result_copy = result.copy()
+ for trainId, label in CSLabels.trainId2label.items():
+ result_copy[result == trainId] = label.id
+ return result_copy
+ def results2img(self, results, imgfile_prefix, to_label_id):
+ """Write the segmentation results to images.
+ Args:
+ results (list[list | tuple | ndarray]): Testing results of the
+ dataset.
+ imgfile_prefix (str): The filename prefix of the png files.
+ If the prefix is "somepath/xxx",
+ the png files will be named "somepath/xxx.png".
+ to_label_id (bool): whether convert output to label_id for
+ submission
+ Returns:
+ list[str: str]: result txt files which contains corresponding
+ semantic segmentation images.
+ """
+ mmcv.mkdir_or_exist(imgfile_prefix)
+ result_files = []
+ prog_bar = mmcv.ProgressBar(len(self))
+ for idx in range(len(self)):
+ result = results[idx]
+ if to_label_id:
+ result = self._convert_to_label_id(result)
+ filename = self.img_infos[idx]['filename']
+ basename = osp.splitext(osp.basename(filename))[0]
+ png_filename = osp.join(imgfile_prefix, f'{basename}.png')
+ output = Image.fromarray(result.astype(np.uint8)).convert('P')
+ import cityscapesscripts.helpers.labels as CSLabels
+ palette = np.zeros((len(CSLabels.id2label), 3), dtype=np.uint8)
+ for label_id, label in CSLabels.id2label.items():
+ palette[label_id] = label.color
+ output.putpalette(palette)
+ output.save(png_filename)
+ result_files.append(png_filename)
+ prog_bar.update()
+ return result_files
+ def format_results(self, results, imgfile_prefix=None, to_label_id=True):
+ """Format the results into dir (standard format for Cityscapes
+ evaluation).
+ Args:
+ results (list): Testing results of the dataset.
+ imgfile_prefix (str | None): The prefix of images files. It
+ includes the file path and the prefix of filename, e.g.,
+ "a/b/prefix". If not specified, a temp file will be created.
+ Default: None.
+ to_label_id (bool): whether convert output to label_id for
+ submission. Default: False
+ Returns:
+ tuple: (result_files, tmp_dir), result_files is a list containing
+ the image paths, tmp_dir is the temporal directory created
+ for saving json/png files when img_prefix is not specified.
+ """
+ assert isinstance(results, list), 'results must be a list'
+ assert len(results) == len(self), (
+ 'The length of results is not equal to the dataset len: '
+ f'{len(results)} != {len(self)}')
+ if imgfile_prefix is None:
+ tmp_dir = tempfile.TemporaryDirectory()
+ imgfile_prefix = tmp_dir.name
+ else:
+ tmp_dir = None
+ result_files = self.results2img(results, imgfile_prefix, to_label_id)
+ return result_files, tmp_dir
+ def evaluate(self,
+ results,
+ metric='mIoU',
+ logger=None,
+ imgfile_prefix=None,
+ efficient_test=False):
+ """Evaluation in Cityscapes/default protocol.
+ Args:
+ results (list): Testing results of the dataset.
+ metric (str | list[str]): Metrics to be evaluated.
+ logger (logging.Logger | None | str): Logger used for printing
+ related information during evaluation. Default: None.
+ imgfile_prefix (str | None): The prefix of output image file,
+ for cityscapes evaluation only. It includes the file path and
+ the prefix of filename, e.g., "a/b/prefix".
+ If results are evaluated with cityscapes protocol, it would be
+ the prefix of output png files. The output files would be
+ png images under folder "a/b/prefix/xxx.png", where "xxx" is
+ the image name of cityscapes. If not specified, a temp file
+ will be created for evaluation.
+ Default: None.
+ Returns:
+ dict[str, float]: Cityscapes/default metrics.
+ """
+ eval_results = dict()
+ metrics = metric.copy() if isinstance(metric, list) else [metric]
+ if 'cityscapes' in metrics:
+ eval_results.update(
+ self._evaluate_cityscapes(results, logger, imgfile_prefix))
+ metrics.remove('cityscapes')
+ if len(metrics) > 0:
+ eval_results.update(
+ super(CityscapesDataset,
+ self).evaluate(results, metrics, logger, efficient_test))
+ return eval_results
+ def _evaluate_cityscapes(self, results, logger, imgfile_prefix):
+ """Evaluation in Cityscapes protocol.
+ Args:
+ results (list): Testing results of the dataset.
+ logger (logging.Logger | str | None): Logger used for printing
+ related information during evaluation. Default: None.
+ imgfile_prefix (str | None): The prefix of output image file
+ Returns:
+ dict[str: float]: Cityscapes evaluation results.
+ """
+ try:
+ import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval # noqa
+ except ImportError:
+ raise ImportError('Please run "pip install cityscapesscripts" to '
+ 'install cityscapesscripts first.')
+ msg = 'Evaluating in Cityscapes style'
+ if logger is None:
+ msg = '\n' + msg
+ print_log(msg, logger=logger)
+ result_files, tmp_dir = self.format_results(results, imgfile_prefix)
+ if tmp_dir is None:
+ result_dir = imgfile_prefix
+ else:
+ result_dir = tmp_dir.name
+ eval_results = dict()
+ print_log(f'Evaluating results under {result_dir} ...', logger=logger)
+ CSEval.args.evalInstLevelScore = True
+ CSEval.args.predictionPath = osp.abspath(result_dir)
+ CSEval.args.evalPixelAccuracy = True
+ CSEval.args.JSONOutput = False
+ seg_map_list = []
+ pred_list = []
+ # when evaluating with official cityscapesscripts,
+ # **_gtFine_labelIds.png is used
+ for seg_map in mmcv.scandir(
+ self.ann_dir, 'gtFine_labelIds.png', recursive=True):
+ seg_map_list.append(osp.join(self.ann_dir, seg_map))
+ pred_list.append(CSEval.getPrediction(CSEval.args, seg_map))
+ eval_results.update(
+ CSEval.evaluateImgLists(pred_list, seg_map_list, CSEval.args))
+ if tmp_dir is not None:
+ tmp_dir.cleanup()
+ return eval_results
diff --git a/ControlNet/annotator/uniformer/mmseg/datasets/custom.py b/ControlNet/annotator/uniformer/mmseg/datasets/custom.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8eb2a709cc7a3a68fc6a1e3a1ad98faef4c5b7b
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/datasets/custom.py
@@ -0,0 +1,400 @@
+import os
+import os.path as osp
+from collections import OrderedDict
+from functools import reduce
+import annotator.uniformer.mmcv as mmcv
+import numpy as np
+from annotator.uniformer.mmcv.utils import print_log
+from prettytable import PrettyTable
+from torch.utils.data import Dataset
+from annotator.uniformer.mmseg.core import eval_metrics
+from annotator.uniformer.mmseg.utils import get_root_logger
+from .builder import DATASETS
+from .pipelines import Compose
+class CustomDataset(Dataset):
+ """Custom dataset for semantic segmentation. An example of file structure
+ is as followed.
+ .. code-block:: none
+ ├── data
+ │ ├── my_dataset
+ │ │ ├── img_dir
+ │ │ │ ├── train
+ │ │ │ │ ├── xxx{img_suffix}
+ │ │ │ │ ├── yyy{img_suffix}
+ │ │ │ │ ├── zzz{img_suffix}
+ │ │ │ ├── val
+ │ │ ├── ann_dir
+ │ │ │ ├── train
+ │ │ │ │ ├── xxx{seg_map_suffix}
+ │ │ │ │ ├── yyy{seg_map_suffix}
+ │ │ │ │ ├── zzz{seg_map_suffix}
+ │ │ │ ├── val
+ The img/gt_semantic_seg pair of CustomDataset should be of the same
+ except suffix. A valid img/gt_semantic_seg filename pair should be like
+ ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (extension is also included
+ in the suffix). If split is given, then ``xxx`` is specified in txt file.
+ Otherwise, all files in ``img_dir/``and ``ann_dir`` will be loaded.
+ Please refer to ``docs/tutorials/new_dataset.md`` for more details.
+ Args:
+ pipeline (list[dict]): Processing pipeline
+ img_dir (str): Path to image directory
+ img_suffix (str): Suffix of images. Default: '.jpg'
+ ann_dir (str, optional): Path to annotation directory. Default: None
+ seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
+ split (str, optional): Split txt file. If split is specified, only
+ file with suffix in the splits will be loaded. Otherwise, all
+ images in img_dir/ann_dir will be loaded. Default: None
+ data_root (str, optional): Data root for img_dir/ann_dir. Default:
+ None.
+ test_mode (bool): If test_mode=True, gt wouldn't be loaded.
+ ignore_index (int): The label index to be ignored. Default: 255
+ reduce_zero_label (bool): Whether to mark label zero as ignored.
+ Default: False
+ classes (str | Sequence[str], optional): Specify classes to load.
+ If is None, ``cls.CLASSES`` will be used. Default: None.
+ palette (Sequence[Sequence[int]]] | np.ndarray | None):
+ The palette of segmentation map. If None is given, and
+ self.PALETTE is None, random palette will be generated.
+ Default: None
+ """
+ CLASSES = None
+ PALETTE = None
+ def __init__(self,
+ pipeline,
+ img_dir,
+ img_suffix='.jpg',
+ ann_dir=None,
+ seg_map_suffix='.png',
+ split=None,
+ data_root=None,
+ test_mode=False,
+ ignore_index=255,
+ reduce_zero_label=False,
+ classes=None,
+ palette=None):
+ self.pipeline = Compose(pipeline)
+ self.img_dir = img_dir
+ self.img_suffix = img_suffix
+ self.ann_dir = ann_dir
+ self.seg_map_suffix = seg_map_suffix
+ self.split = split
+ self.data_root = data_root
+ self.test_mode = test_mode
+ self.ignore_index = ignore_index
+ self.reduce_zero_label = reduce_zero_label
+ self.label_map = None
+ self.CLASSES, self.PALETTE = self.get_classes_and_palette(
+ classes, palette)
+ # join paths if data_root is specified
+ if self.data_root is not None:
+ if not osp.isabs(self.img_dir):
+ self.img_dir = osp.join(self.data_root, self.img_dir)
+ if not (self.ann_dir is None or osp.isabs(self.ann_dir)):
+ self.ann_dir = osp.join(self.data_root, self.ann_dir)
+ if not (self.split is None or osp.isabs(self.split)):
+ self.split = osp.join(self.data_root, self.split)
+ # load annotations
+ self.img_infos = self.load_annotations(self.img_dir, self.img_suffix,
+ self.ann_dir,
+ self.seg_map_suffix, self.split)
+ def __len__(self):
+ """Total number of samples of data."""
+ return len(self.img_infos)
+ def load_annotations(self, img_dir, img_suffix, ann_dir, seg_map_suffix,
+ split):
+ """Load annotation from directory.
+ Args:
+ img_dir (str): Path to image directory
+ img_suffix (str): Suffix of images.
+ ann_dir (str|None): Path to annotation directory.
+ seg_map_suffix (str|None): Suffix of segmentation maps.
+ split (str|None): Split txt file. If split is specified, only file
+ with suffix in the splits will be loaded. Otherwise, all images
+ in img_dir/ann_dir will be loaded. Default: None
+ Returns:
+ list[dict]: All image info of dataset.
+ """
+ img_infos = []
+ if split is not None:
+ with open(split) as f:
+ for line in f:
+ img_name = line.strip()
+ img_info = dict(filename=img_name + img_suffix)
+ if ann_dir is not None:
+ seg_map = img_name + seg_map_suffix
+ img_info['ann'] = dict(seg_map=seg_map)
+ img_infos.append(img_info)
+ else:
+ for img in mmcv.scandir(img_dir, img_suffix, recursive=True):
+ img_info = dict(filename=img)
+ if ann_dir is not None:
+ seg_map = img.replace(img_suffix, seg_map_suffix)
+ img_info['ann'] = dict(seg_map=seg_map)
+ img_infos.append(img_info)
+ print_log(f'Loaded {len(img_infos)} images', logger=get_root_logger())
+ return img_infos
+ def get_ann_info(self, idx):
+ """Get annotation by index.
+ Args:
+ idx (int): Index of data.
+ Returns:
+ dict: Annotation info of specified index.
+ """
+ return self.img_infos[idx]['ann']
+ def pre_pipeline(self, results):
+ """Prepare results dict for pipeline."""
+ results['seg_fields'] = []
+ results['img_prefix'] = self.img_dir
+ results['seg_prefix'] = self.ann_dir
+ if self.custom_classes:
+ results['label_map'] = self.label_map
+ def __getitem__(self, idx):
+ """Get training/test data after pipeline.
+ Args:
+ idx (int): Index of data.
+ Returns:
+ dict: Training/test data (with annotation if `test_mode` is set
+ False).
+ """
+ if self.test_mode:
+ return self.prepare_test_img(idx)
+ else:
+ return self.prepare_train_img(idx)
+ def prepare_train_img(self, idx):
+ """Get training data and annotations after pipeline.
+ Args:
+ idx (int): Index of data.
+ Returns:
+ dict: Training data and annotation after pipeline with new keys
+ introduced by pipeline.
+ """
+ img_info = self.img_infos[idx]
+ ann_info = self.get_ann_info(idx)
+ results = dict(img_info=img_info, ann_info=ann_info)
+ self.pre_pipeline(results)
+ return self.pipeline(results)
+ def prepare_test_img(self, idx):
+ """Get testing data after pipeline.
+ Args:
+ idx (int): Index of data.
+ Returns:
+ dict: Testing data after pipeline with new keys introduced by
+ pipeline.
+ """
+ img_info = self.img_infos[idx]
+ results = dict(img_info=img_info)
+ self.pre_pipeline(results)
+ return self.pipeline(results)
+ def format_results(self, results, **kwargs):
+ """Place holder to format result to dataset specific output."""
+ def get_gt_seg_maps(self, efficient_test=False):
+ """Get ground truth segmentation maps for evaluation."""
+ gt_seg_maps = []
+ for img_info in self.img_infos:
+ seg_map = osp.join(self.ann_dir, img_info['ann']['seg_map'])
+ if efficient_test:
+ gt_seg_map = seg_map
+ else:
+ gt_seg_map = mmcv.imread(
+ seg_map, flag='unchanged', backend='pillow')
+ gt_seg_maps.append(gt_seg_map)
+ return gt_seg_maps
+ def get_classes_and_palette(self, classes=None, palette=None):
+ """Get class names of current dataset.
+ Args:
+ classes (Sequence[str] | str | None): If classes is None, use
+ default CLASSES defined by builtin dataset. If classes is a
+ string, take it as a file name. The file contains the name of
+ classes where each line contains one class name. If classes is
+ a tuple or list, override the CLASSES defined by the dataset.
+ palette (Sequence[Sequence[int]]] | np.ndarray | None):
+ The palette of segmentation map. If None is given, random
+ palette will be generated. Default: None
+ """
+ if classes is None:
+ self.custom_classes = False
+ return self.CLASSES, self.PALETTE
+ self.custom_classes = True
+ if isinstance(classes, str):
+ # take it as a file path
+ class_names = mmcv.list_from_file(classes)
+ elif isinstance(classes, (tuple, list)):
+ class_names = classes
+ else:
+ raise ValueError(f'Unsupported type {type(classes)} of classes.')
+ if self.CLASSES:
+ if not set(classes).issubset(self.CLASSES):
+ raise ValueError('classes is not a subset of CLASSES.')
+ # dictionary, its keys are the old label ids and its values
+ # are the new label ids.
+ # used for changing pixel labels in load_annotations.
+ self.label_map = {}
+ for i, c in enumerate(self.CLASSES):
+ if c not in class_names:
+ self.label_map[i] = -1
+ else:
+ self.label_map[i] = classes.index(c)
+ palette = self.get_palette_for_custom_classes(class_names, palette)
+ return class_names, palette
+ def get_palette_for_custom_classes(self, class_names, palette=None):
+ if self.label_map is not None:
+ # return subset of palette
+ palette = []
+ for old_id, new_id in sorted(
+ self.label_map.items(), key=lambda x: x[1]):
+ if new_id != -1:
+ palette.append(self.PALETTE[old_id])
+ palette = type(self.PALETTE)(palette)
+ elif palette is None:
+ if self.PALETTE is None:
+ palette = np.random.randint(0, 255, size=(len(class_names), 3))
+ else:
+ palette = self.PALETTE
+ return palette
+ def evaluate(self,
+ results,
+ metric='mIoU',
+ logger=None,
+ efficient_test=False,
+ **kwargs):
+ """Evaluate the dataset.
+ Args:
+ results (list): Testing results of the dataset.
+ metric (str | list[str]): Metrics to be evaluated. 'mIoU',
+ 'mDice' and 'mFscore' are supported.
+ logger (logging.Logger | None | str): Logger used for printing
+ related information during evaluation. Default: None.
+ Returns:
+ dict[str, float]: Default metrics.
+ """
+ if isinstance(metric, str):
+ metric = [metric]
+ allowed_metrics = ['mIoU', 'mDice', 'mFscore']
+ if not set(metric).issubset(set(allowed_metrics)):
+ raise KeyError('metric {} is not supported'.format(metric))
+ eval_results = {}
+ gt_seg_maps = self.get_gt_seg_maps(efficient_test)
+ if self.CLASSES is None:
+ num_classes = len(
+ reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps]))
+ else:
+ num_classes = len(self.CLASSES)
+ ret_metrics = eval_metrics(
+ results,
+ gt_seg_maps,
+ num_classes,
+ self.ignore_index,
+ metric,
+ label_map=self.label_map,
+ reduce_zero_label=self.reduce_zero_label)
+ if self.CLASSES is None:
+ class_names = tuple(range(num_classes))
+ else:
+ class_names = self.CLASSES
+ # summary table
+ ret_metrics_summary = OrderedDict({
+ ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)
+ for ret_metric, ret_metric_value in ret_metrics.items()
+ })
+ # each class table
+ ret_metrics.pop('aAcc', None)
+ ret_metrics_class = OrderedDict({
+ ret_metric: np.round(ret_metric_value * 100, 2)
+ for ret_metric, ret_metric_value in ret_metrics.items()
+ })
+ ret_metrics_class.update({'Class': class_names})
+ ret_metrics_class.move_to_end('Class', last=False)
+ # for logger
+ class_table_data = PrettyTable()
+ for key, val in ret_metrics_class.items():
+ class_table_data.add_column(key, val)
+ summary_table_data = PrettyTable()
+ for key, val in ret_metrics_summary.items():
+ if key == 'aAcc':
+ summary_table_data.add_column(key, [val])
+ else:
+ summary_table_data.add_column('m' + key, [val])
+ print_log('per class results:', logger)
+ print_log('\n' + class_table_data.get_string(), logger=logger)
+ print_log('Summary:', logger)
+ print_log('\n' + summary_table_data.get_string(), logger=logger)
+ # each metric dict
+ for key, value in ret_metrics_summary.items():
+ if key == 'aAcc':
+ eval_results[key] = value / 100.0
+ else:
+ eval_results['m' + key] = value / 100.0
+ ret_metrics_class.pop('Class', None)
+ for key, value in ret_metrics_class.items():
+ eval_results.update({
+ key + '.' + str(name): value[idx] / 100.0
+ for idx, name in enumerate(class_names)
+ })
+ if mmcv.is_list_of(results, str):
+ for file_name in results:
+ os.remove(file_name)
+ return eval_results
diff --git a/ControlNet/annotator/uniformer/mmseg/datasets/dataset_wrappers.py b/ControlNet/annotator/uniformer/mmseg/datasets/dataset_wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6a5e957ec3b44465432617cf6e8f0b86a8a5efa
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/datasets/dataset_wrappers.py
@@ -0,0 +1,50 @@
+from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
+from .builder import DATASETS
+class ConcatDataset(_ConcatDataset):
+ """A wrapper of concatenated dataset.
+ Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but also exposes
+ the ``CLASSES`` and ``PALETTE`` of the first dataset.
+ Args:
+ datasets (list[:obj:`Dataset`]): A list of datasets.
+ """
+ def __init__(self, datasets):
+ super(ConcatDataset, self).__init__(datasets)
+ self.CLASSES = datasets[0].CLASSES
+ self.PALETTE = datasets[0].PALETTE
+class RepeatDataset(object):
+ """A wrapper of repeated dataset.
+ The length of the repeated dataset is `times` times that of the original
+ dataset. This is useful when the data loading time is long but the dataset
+ is small. Using RepeatDataset can reduce the data loading time between
+ epochs.
+ Args:
+ dataset (:obj:`Dataset`): The dataset to be repeated.
+ times (int): Repeat times.
+ """
+ def __init__(self, dataset, times):
+ self.dataset = dataset
+ self.times = times
+ self.CLASSES = dataset.CLASSES
+ self.PALETTE = dataset.PALETTE
+ self._ori_len = len(self.dataset)
+ def __getitem__(self, idx):
+ """Get item from original dataset."""
+ return self.dataset[idx % self._ori_len]
+ def __len__(self):
+ """The length is multiplied by ``times``"""
+ return self.times * self._ori_len
diff --git a/ControlNet/annotator/uniformer/mmseg/datasets/drive.py b/ControlNet/annotator/uniformer/mmseg/datasets/drive.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cbfda8ae74bdf26c5aef197ff2866a7c7ad0cfd
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/datasets/drive.py
@@ -0,0 +1,27 @@
+import os.path as osp
+from .builder import DATASETS
+from .custom import CustomDataset
+class DRIVEDataset(CustomDataset):
+ """DRIVE dataset.
+ In segmentation map annotation for DRIVE, 0 stands for background, which is
+ included in 2 categories. ``reduce_zero_label`` is fixed to False. The
+ ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
+ '_manual1.png'.
+ """
+ CLASSES = ('background', 'vessel')
+ PALETTE = [[120, 120, 120], [6, 230, 230]]
+ def __init__(self, **kwargs):
+ super(DRIVEDataset, self).__init__(
+ img_suffix='.png',
+ seg_map_suffix='_manual1.png',
+ reduce_zero_label=False,
+ **kwargs)
+ assert osp.exists(self.img_dir)
diff --git a/ControlNet/annotator/uniformer/mmseg/datasets/hrf.py b/ControlNet/annotator/uniformer/mmseg/datasets/hrf.py
new file mode 100644
index 0000000000000000000000000000000000000000..923203b51377f9344277fc561803d7a78bd2c684
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/datasets/hrf.py
@@ -0,0 +1,27 @@
+import os.path as osp
+from .builder import DATASETS
+from .custom import CustomDataset
+class HRFDataset(CustomDataset):
+ """HRF dataset.
+ In segmentation map annotation for HRF, 0 stands for background, which is
+ included in 2 categories. ``reduce_zero_label`` is fixed to False. The
+ ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
+ '.png'.
+ """
+ CLASSES = ('background', 'vessel')
+ PALETTE = [[120, 120, 120], [6, 230, 230]]
+ def __init__(self, **kwargs):
+ super(HRFDataset, self).__init__(
+ img_suffix='.png',
+ seg_map_suffix='.png',
+ reduce_zero_label=False,
+ **kwargs)
+ assert osp.exists(self.img_dir)
diff --git a/ControlNet/annotator/uniformer/mmseg/datasets/pascal_context.py b/ControlNet/annotator/uniformer/mmseg/datasets/pascal_context.py
new file mode 100644
index 0000000000000000000000000000000000000000..541a63c66a13fb16fd52921e755715ad8d078fdd
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/datasets/pascal_context.py
@@ -0,0 +1,103 @@
+import os.path as osp
+from .builder import DATASETS
+from .custom import CustomDataset
+class PascalContextDataset(CustomDataset):
+ """PascalContext dataset.
+ In segmentation map annotation for PascalContext, 0 stands for background,
+ which is included in 60 categories. ``reduce_zero_label`` is fixed to
+ False. The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is
+ fixed to '.png'.
+ Args:
+ split (str): Split txt file for PascalContext.
+ """
+ CLASSES = ('background', 'aeroplane', 'bag', 'bed', 'bedclothes', 'bench',
+ 'bicycle', 'bird', 'boat', 'book', 'bottle', 'building', 'bus',
+ 'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth',
+ 'computer', 'cow', 'cup', 'curtain', 'dog', 'door', 'fence',
+ 'floor', 'flower', 'food', 'grass', 'ground', 'horse',
+ 'keyboard', 'light', 'motorbike', 'mountain', 'mouse', 'person',
+ 'plate', 'platform', 'pottedplant', 'road', 'rock', 'sheep',
+ 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa', 'table',
+ 'track', 'train', 'tree', 'truck', 'tvmonitor', 'wall', 'water',
+ 'window', 'wood')
+ PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
+ [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
+ [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
+ [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
+ [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
+ [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
+ [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
+ [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
+ [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
+ [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
+ [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
+ [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
+ [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
+ [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
+ [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]]
+ def __init__(self, split, **kwargs):
+ super(PascalContextDataset, self).__init__(
+ img_suffix='.jpg',
+ seg_map_suffix='.png',
+ split=split,
+ reduce_zero_label=False,
+ **kwargs)
+ assert osp.exists(self.img_dir) and self.split is not None
+class PascalContextDataset59(CustomDataset):
+ """PascalContext dataset.
+ In segmentation map annotation for PascalContext, 0 stands for background,
+ which is not one of the 59 categories. ``reduce_zero_label`` is fixed to
+ True. The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is
+ fixed to '.png'.
+ Args:
+ split (str): Split txt file for PascalContext.
+ """
+ CLASSES = ('aeroplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle',
+ 'bird', 'boat', 'book', 'bottle', 'building', 'bus', 'cabinet',
+ 'car', 'cat', 'ceiling', 'chair', 'cloth', 'computer', 'cow',
+ 'cup', 'curtain', 'dog', 'door', 'fence', 'floor', 'flower',
+ 'food', 'grass', 'ground', 'horse', 'keyboard', 'light',
+ 'motorbike', 'mountain', 'mouse', 'person', 'plate', 'platform',
+ 'pottedplant', 'road', 'rock', 'sheep', 'shelves', 'sidewalk',
+ 'sign', 'sky', 'snow', 'sofa', 'table', 'track', 'train',
+ 'tree', 'truck', 'tvmonitor', 'wall', 'water', 'window', 'wood')
+ PALETTE = [[180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3],
+ [120, 120, 80], [140, 140, 140], [204, 5, 255], [230, 230, 230],
+ [4, 250, 7], [224, 5, 255], [235, 255, 7], [150, 5, 61],
+ [120, 120, 70], [8, 255, 51], [255, 6, 82], [143, 255, 140],
+ [204, 255, 4], [255, 51, 7], [204, 70, 3], [0, 102, 200],
+ [61, 230, 250], [255, 6, 51], [11, 102, 255], [255, 7, 71],
+ [255, 9, 224], [9, 7, 230], [220, 220, 220], [255, 9, 92],
+ [112, 9, 255], [8, 255, 214], [7, 255, 224], [255, 184, 6],
+ [10, 255, 71], [255, 41, 10], [7, 255, 255], [224, 255, 8],
+ [102, 8, 255], [255, 61, 6], [255, 194, 7], [255, 122, 8],
+ [0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255],
+ [235, 12, 255], [160, 150, 20], [0, 163, 255], [140, 140, 140],
+ [250, 10, 15], [20, 255, 0], [31, 255, 0], [255, 31, 0],
+ [255, 224, 0], [153, 255, 0], [0, 0, 255], [255, 71, 0],
+ [0, 235, 255], [0, 173, 255], [31, 0, 255]]
+ def __init__(self, split, **kwargs):
+ super(PascalContextDataset59, self).__init__(
+ img_suffix='.jpg',
+ seg_map_suffix='.png',
+ split=split,
+ reduce_zero_label=True,
+ **kwargs)
+ assert osp.exists(self.img_dir) and self.split is not None
diff --git a/ControlNet/annotator/uniformer/mmseg/datasets/pipelines/__init__.py b/ControlNet/annotator/uniformer/mmseg/datasets/pipelines/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b9046b07bb4ddea7a707a392b42e72db7c9df67
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/datasets/pipelines/__init__.py
@@ -0,0 +1,16 @@
+from .compose import Compose
+from .formating import (Collect, ImageToTensor, ToDataContainer, ToTensor,
+ Transpose, to_tensor)
+from .loading import LoadAnnotations, LoadImageFromFile
+from .test_time_aug import MultiScaleFlipAug
+from .transforms import (CLAHE, AdjustGamma, Normalize, Pad,
+ PhotoMetricDistortion, RandomCrop, RandomFlip,
+ RandomRotate, Rerange, Resize, RGB2Gray, SegRescale)
+# Public names re-exported by the pipelines package.
+__all__ = [
+    'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer',
+    'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile',
+    'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop',
+    'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate',
+    'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray'
+]
diff --git a/ControlNet/annotator/uniformer/mmseg/datasets/pipelines/compose.py b/ControlNet/annotator/uniformer/mmseg/datasets/pipelines/compose.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbfcbb925c6d4ebf849328b9f94ef6fc24359bf5
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/datasets/pipelines/compose.py
@@ -0,0 +1,51 @@
+import collections
+from annotator.uniformer.mmcv.utils import build_from_cfg
+from ..builder import PIPELINES
+class Compose(object):
+ """Compose multiple transforms sequentially.
+ Args:
+ transforms (Sequence[dict | callable]): Sequence of transform object or
+ config dict to be composed.
+ """
+ def __init__(self, transforms):
+ assert isinstance(transforms, collections.abc.Sequence)
+ self.transforms = []
+ for transform in transforms:
+ if isinstance(transform, dict):
+ transform = build_from_cfg(transform, PIPELINES)
+ self.transforms.append(transform)
+ elif callable(transform):
+ self.transforms.append(transform)
+ else:
+ raise TypeError('transform must be callable or a dict')
+ def __call__(self, data):
+ """Call function to apply transforms sequentially.
+ Args:
+ data (dict): A result dict contains the data to transform.
+ Returns:
+ dict: Transformed data.
+ """
+ for t in self.transforms:
+ data = t(data)
+ if data is None:
+ return None
+ return data
+ def __repr__(self):
+ format_string = self.__class__.__name__ + '('
+ for t in self.transforms:
+ format_string += '\n'
+ format_string += f' {t}'
+ format_string += '\n)'
+ return format_string
diff --git a/ControlNet/annotator/uniformer/mmseg/datasets/pipelines/formating.py b/ControlNet/annotator/uniformer/mmseg/datasets/pipelines/formating.py
new file mode 100644
index 0000000000000000000000000000000000000000..97db85f4f9db39fb86ba77ead7d1a8407d810adb
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/datasets/pipelines/formating.py
@@ -0,0 +1,288 @@
+from collections.abc import Sequence
+import annotator.uniformer.mmcv as mmcv
+import numpy as np
+import torch
+from annotator.uniformer.mmcv.parallel import DataContainer as DC
+from ..builder import PIPELINES
+def to_tensor(data):
+ """Convert objects of various python types to :obj:`torch.Tensor`.
+ Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
+ :class:`Sequence`, :class:`int` and :class:`float`.
+ Args:
+ data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
+ be converted.
+ """
+ if isinstance(data, torch.Tensor):
+ return data
+ elif isinstance(data, np.ndarray):
+ return torch.from_numpy(data)
+ elif isinstance(data, Sequence) and not mmcv.is_str(data):
+ return torch.tensor(data)
+ elif isinstance(data, int):
+ return torch.LongTensor([data])
+ elif isinstance(data, float):
+ return torch.FloatTensor([data])
+ else:
+ raise TypeError(f'type {type(data)} cannot be converted to tensor.')
+class ToTensor(object):
+ """Convert some results to :obj:`torch.Tensor` by given keys.
+ Args:
+ keys (Sequence[str]): Keys that need to be converted to Tensor.
+ """
+ def __init__(self, keys):
+ self.keys = keys
+ def __call__(self, results):
+ """Call function to convert data in results to :obj:`torch.Tensor`.
+ Args:
+ results (dict): Result dict contains the data to convert.
+ Returns:
+ dict: The result dict contains the data converted
+ to :obj:`torch.Tensor`.
+ """
+ for key in self.keys:
+ results[key] = to_tensor(results[key])
+ return results
+ def __repr__(self):
+ return self.__class__.__name__ + f'(keys={self.keys})'
+class ImageToTensor(object):
+ """Convert image to :obj:`torch.Tensor` by given keys.
+ The dimension order of input image is (H, W, C). The pipeline will convert
+ it to (C, H, W). If only 2 dimension (H, W) is given, the output would be
+ (1, H, W).
+ Args:
+ keys (Sequence[str]): Key of images to be converted to Tensor.
+ """
+ def __init__(self, keys):
+ self.keys = keys
+ def __call__(self, results):
+ """Call function to convert image in results to :obj:`torch.Tensor` and
+ transpose the channel order.
+ Args:
+ results (dict): Result dict contains the image data to convert.
+ Returns:
+ dict: The result dict contains the image converted
+ to :obj:`torch.Tensor` and transposed to (C, H, W) order.
+ """
+ for key in self.keys:
+ img = results[key]
+ if len(img.shape) < 3:
+ img = np.expand_dims(img, -1)
+ results[key] = to_tensor(img.transpose(2, 0, 1))
+ return results
+ def __repr__(self):
+ return self.__class__.__name__ + f'(keys={self.keys})'
+class Transpose(object):
+ """Transpose some results by given keys.
+ Args:
+ keys (Sequence[str]): Keys of results to be transposed.
+ order (Sequence[int]): Order of transpose.
+ """
+ def __init__(self, keys, order):
+ self.keys = keys
+ self.order = order
+ def __call__(self, results):
+ """Call function to transpose the data in results by the configured
+ axis order.
+ Args:
+ results (dict): Result dict contains the data to transpose.
+ Returns:
+ dict: The result dict contains the data transposed according
+ to ``self.order``.
+ """
+ for key in self.keys:
+ results[key] = results[key].transpose(self.order)
+ return results
+ def __repr__(self):
+ return self.__class__.__name__ + \
+ f'(keys={self.keys}, order={self.order})'
+class ToDataContainer(object):
+ """Convert results to :obj:`mmcv.DataContainer` by given fields.
+ Args:
+ fields (Sequence[dict]): Each field is a dict like
+ ``dict(key='xxx', **kwargs)``. The ``key`` in result will
+ be converted to :obj:`mmcv.DataContainer` with ``**kwargs``.
+ Default: ``(dict(key='img', stack=True),
+ dict(key='gt_semantic_seg'))``.
+ """
+ def __init__(self,
+ fields=(dict(key='img',
+ stack=True), dict(key='gt_semantic_seg'))):
+ self.fields = fields
+ def __call__(self, results):
+ """Call function to convert data in results to
+ :obj:`mmcv.DataContainer`.
+ Args:
+ results (dict): Result dict contains the data to convert.
+ Returns:
+ dict: The result dict contains the data converted to
+ :obj:`mmcv.DataContainer`.
+ """
+ for field in self.fields:
+ field = field.copy()
+ key = field.pop('key')
+ results[key] = DC(results[key], **field)
+ return results
+ def __repr__(self):
+ return self.__class__.__name__ + f'(fields={self.fields})'
+class DefaultFormatBundle(object):
+ """Default formatting bundle.
+ It simplifies the pipeline of formatting common fields, including "img"
+ and "gt_semantic_seg". These fields are formatted as follows.
+ - img: (1)transpose, (2)to tensor, (3)to DataContainer (stack=True)
+ - gt_semantic_seg: (1)unsqueeze dim-0 (2)to tensor,
+ (3)to DataContainer (stack=True)
+ """
+ def __call__(self, results):
+ """Call function to transform and format common fields in results.
+ Args:
+ results (dict): Result dict contains the data to convert.
+ Returns:
+ dict: The result dict contains the data that is formatted with
+ default bundle.
+ """
+ if 'img' in results:
+ img = results['img']
+ if len(img.shape) < 3:
+ img = np.expand_dims(img, -1)
+ img = np.ascontiguousarray(img.transpose(2, 0, 1))
+ results['img'] = DC(to_tensor(img), stack=True)
+ if 'gt_semantic_seg' in results:
+ # convert to long
+ results['gt_semantic_seg'] = DC(
+ to_tensor(results['gt_semantic_seg'][None,
+ ...].astype(np.int64)),
+ stack=True)
+ return results
+ def __repr__(self):
+ return self.__class__.__name__
+class Collect(object):
+ """Collect data from the loader relevant to the specific task.
+ This is usually the last stage of the data loader pipeline. Typically keys
+ is set to some subset of "img", "gt_semantic_seg".
+ The "img_metas" item is always populated. The contents of the "img_metas"
+ dictionary depend on "meta_keys". By default this includes:
+ - "img_shape": shape of the image input to the network as a tuple
+ (h, w, c). Note that images may be zero padded on the bottom/right
+ if the batch tensor is larger than this shape.
+ - "scale_factor": a float indicating the preprocessing scale
+ - "flip": a boolean indicating if image flip transform was used
+ - "filename": path to the image file
+ - "ori_shape": original shape of the image as a tuple (h, w, c)
+ - "pad_shape": image shape after padding
+ - "img_norm_cfg": a dict of normalization information:
+ - mean - per channel mean subtraction
+ - std - per channel std divisor
+ - to_rgb - bool indicating if bgr was converted to rgb
+ Args:
+ keys (Sequence[str]): Keys of results to be collected in ``data``.
+ meta_keys (Sequence[str], optional): Meta keys to be converted to
+ ``mmcv.DataContainer`` and collected in ``data[img_metas]``.
+ Default: ``('filename', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'img_norm_cfg')``
+ """
+ def __init__(self,
+ keys,
+ meta_keys=('filename', 'ori_filename', 'ori_shape',
+ 'img_shape', 'pad_shape', 'scale_factor', 'flip',
+ 'flip_direction', 'img_norm_cfg')):
+ self.keys = keys
+ self.meta_keys = meta_keys
+ def __call__(self, results):
+ """Call function to collect keys in results. The keys in ``meta_keys``
+ will be converted to :obj:`mmcv.DataContainer`.
+ Args:
+ results (dict): Result dict contains the data to collect.
+ Returns:
+ dict: The result dict contains the following keys
+ - keys in ``self.keys``
+ - ``img_metas``
+ """
+ data = {}
+ img_meta = {}
+ for key in self.meta_keys:
+ img_meta[key] = results[key]
+ data['img_metas'] = DC(img_meta, cpu_only=True)
+ for key in self.keys:
+ data[key] = results[key]
+ return data
+ def __repr__(self):
+ return self.__class__.__name__ + \
+ f'(keys={self.keys}, meta_keys={self.meta_keys})'
diff --git a/ControlNet/annotator/uniformer/mmseg/datasets/pipelines/loading.py b/ControlNet/annotator/uniformer/mmseg/datasets/pipelines/loading.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3692ae91f19b9c7ccf6023168788ff42c9e93e3
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/datasets/pipelines/loading.py
@@ -0,0 +1,153 @@
+import os.path as osp
+import annotator.uniformer.mmcv as mmcv
+import numpy as np
+from ..builder import PIPELINES
+class LoadImageFromFile(object):
+ """Load an image from file.
+ Required keys are "img_prefix" and "img_info" (a dict that must contain the
+ key "filename"). Added or updated keys are "filename", "img", "img_shape",
+ "ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`),
+ "scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1).
+ Args:
+ to_float32 (bool): Whether to convert the loaded image to a float32
+ numpy array. If set to False, the loaded image is an uint8 array.
+ Defaults to False.
+ color_type (str): The flag argument for :func:`mmcv.imfrombytes`.
+ Defaults to 'color'.
+ file_client_args (dict): Arguments to instantiate a FileClient.
+ See :class:`mmcv.fileio.FileClient` for details.
+ Defaults to ``dict(backend='disk')``.
+ imdecode_backend (str): Backend for :func:`mmcv.imdecode`. Default:
+ 'cv2'
+ """
+ def __init__(self,
+ to_float32=False,
+ color_type='color',
+ file_client_args=dict(backend='disk'),
+ imdecode_backend='cv2'):
+ self.to_float32 = to_float32
+ self.color_type = color_type
+ self.file_client_args = file_client_args.copy()
+ self.file_client = None
+ self.imdecode_backend = imdecode_backend
+ def __call__(self, results):
+ """Call functions to load image and get image meta information.
+ Args:
+ results (dict): Result dict from :obj:`mmseg.CustomDataset`.
+ Returns:
+ dict: The dict contains loaded image and meta information.
+ """
+ if self.file_client is None:
+ self.file_client = mmcv.FileClient(**self.file_client_args)
+ if results.get('img_prefix') is not None:
+ filename = osp.join(results['img_prefix'],
+ results['img_info']['filename'])
+ else:
+ filename = results['img_info']['filename']
+ img_bytes = self.file_client.get(filename)
+ img = mmcv.imfrombytes(
+ img_bytes, flag=self.color_type, backend=self.imdecode_backend)
+ if self.to_float32:
+ img = img.astype(np.float32)
+ results['filename'] = filename
+ results['ori_filename'] = results['img_info']['filename']
+ results['img'] = img
+ results['img_shape'] = img.shape
+ results['ori_shape'] = img.shape
+ # Set initial values for default meta_keys
+ results['pad_shape'] = img.shape
+ results['scale_factor'] = 1.0
+ num_channels = 1 if len(img.shape) < 3 else img.shape[2]
+ results['img_norm_cfg'] = dict(
+ mean=np.zeros(num_channels, dtype=np.float32),
+ std=np.ones(num_channels, dtype=np.float32),
+ to_rgb=False)
+ return results
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(to_float32={self.to_float32},'
+ repr_str += f"color_type='{self.color_type}',"
+ repr_str += f"imdecode_backend='{self.imdecode_backend}')"
+ return repr_str
+class LoadAnnotations(object):
+ """Load annotations for semantic segmentation.
+ Args:
+ reduce_zero_label (bool): Whether reduce all label value by 1.
+ Usually used for datasets where 0 is background label.
+ Default: False.
+ file_client_args (dict): Arguments to instantiate a FileClient.
+ See :class:`mmcv.fileio.FileClient` for details.
+ Defaults to ``dict(backend='disk')``.
+ imdecode_backend (str): Backend for :func:`mmcv.imdecode`. Default:
+ 'pillow'
+ """
+ def __init__(self,
+ reduce_zero_label=False,
+ file_client_args=dict(backend='disk'),
+ imdecode_backend='pillow'):
+ self.reduce_zero_label = reduce_zero_label
+ self.file_client_args = file_client_args.copy()
+ self.file_client = None
+ self.imdecode_backend = imdecode_backend
+ def __call__(self, results):
+ """Call function to load multiple types annotations.
+ Args:
+ results (dict): Result dict from :obj:`mmseg.CustomDataset`.
+ Returns:
+ dict: The dict contains loaded semantic segmentation annotations.
+ """
+ if self.file_client is None:
+ self.file_client = mmcv.FileClient(**self.file_client_args)
+ if results.get('seg_prefix', None) is not None:
+ filename = osp.join(results['seg_prefix'],
+ results['ann_info']['seg_map'])
+ else:
+ filename = results['ann_info']['seg_map']
+ img_bytes = self.file_client.get(filename)
+ gt_semantic_seg = mmcv.imfrombytes(
+ img_bytes, flag='unchanged',
+ backend=self.imdecode_backend).squeeze().astype(np.uint8)
+ # modify if custom classes
+ if results.get('label_map', None) is not None:
+ for old_id, new_id in results['label_map'].items():
+ gt_semantic_seg[gt_semantic_seg == old_id] = new_id
+ # reduce zero_label
+ if self.reduce_zero_label:
+ # avoid using underflow conversion
+ gt_semantic_seg[gt_semantic_seg == 0] = 255
+ gt_semantic_seg = gt_semantic_seg - 1
+ gt_semantic_seg[gt_semantic_seg == 254] = 255
+ results['gt_semantic_seg'] = gt_semantic_seg
+ results['seg_fields'].append('gt_semantic_seg')
+ return results
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(reduce_zero_label={self.reduce_zero_label},'
+ repr_str += f"imdecode_backend='{self.imdecode_backend}')"
+ return repr_str
diff --git a/ControlNet/annotator/uniformer/mmseg/datasets/pipelines/test_time_aug.py b/ControlNet/annotator/uniformer/mmseg/datasets/pipelines/test_time_aug.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a1611a04d9d927223c9afbe5bf68af04d62937a
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/datasets/pipelines/test_time_aug.py
@@ -0,0 +1,133 @@
+import warnings
+import annotator.uniformer.mmcv as mmcv
+from ..builder import PIPELINES
+from .compose import Compose
+class MultiScaleFlipAug(object):
+ """Test-time augmentation with multiple scales and flipping.
+ An example configuration is as follows:
+ .. code-block::
+ img_scale=(2048, 1024),
+ img_ratios=[0.5, 1.0],
+ flip=True,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ]
+ After MultiScaleFlipAug with the above configuration, the results are wrapped
+ into lists of the same length as follows:
+ .. code-block::
+ dict(
+ img=[...],
+ img_shape=[...],
+ scale=[(1024, 512), (1024, 512), (2048, 1024), (2048, 1024)]
+ flip=[False, True, False, True]
+ ...
+ )
+ Args:
+ transforms (list[dict]): Transforms to apply in each augmentation.
+ img_scale (None | tuple | list[tuple]): Images scales for resizing.
+ img_ratios (float | list[float]): Image ratios for resizing
+ flip (bool): Whether apply flip augmentation. Default: False.
+ flip_direction (str | list[str]): Flip augmentation directions,
+ options are "horizontal" and "vertical". If flip_direction is list,
+ multiple flip augmentations will be applied.
+ It has no effect when flip == False. Default: "horizontal".
+ """
+ def __init__(self,
+ transforms,
+ img_scale,
+ img_ratios=None,
+ flip=False,
+ flip_direction='horizontal'):
+ self.transforms = Compose(transforms)
+ if img_ratios is not None:
+ img_ratios = img_ratios if isinstance(img_ratios,
+ list) else [img_ratios]
+ assert mmcv.is_list_of(img_ratios, float)
+ if img_scale is None:
+ # mode 1: given img_scale=None and a range of image ratio
+ self.img_scale = None
+ assert mmcv.is_list_of(img_ratios, float)
+ elif isinstance(img_scale, tuple) and mmcv.is_list_of(
+ img_ratios, float):
+ assert len(img_scale) == 2
+ # mode 2: given a scale and a range of image ratio
+ self.img_scale = [(int(img_scale[0] * ratio),
+ int(img_scale[1] * ratio))
+ for ratio in img_ratios]
+ else:
+ # mode 3: given multiple scales
+ self.img_scale = img_scale if isinstance(img_scale,
+ list) else [img_scale]
+ assert mmcv.is_list_of(self.img_scale, tuple) or self.img_scale is None
+ self.flip = flip
+ self.img_ratios = img_ratios
+ self.flip_direction = flip_direction if isinstance(
+ flip_direction, list) else [flip_direction]
+ assert mmcv.is_list_of(self.flip_direction, str)
+ if not self.flip and self.flip_direction != ['horizontal']:
+ warnings.warn(
+ 'flip_direction has no effect when flip is set to False')
+ if (self.flip
+ and not any([t['type'] == 'RandomFlip' for t in transforms])):
+ warnings.warn(
+ 'flip has no effect when RandomFlip is not in transforms')
+ def __call__(self, results):
+ """Call function to apply test time augment transforms on results.
+ Args:
+ results (dict): Result dict contains the data to transform.
+ Returns:
+ dict[str: list]: The augmented data, where each value is wrapped
+ into a list.
+ """
+ aug_data = []
+ if self.img_scale is None and mmcv.is_list_of(self.img_ratios, float):
+ h, w = results['img'].shape[:2]
+ img_scale = [(int(w * ratio), int(h * ratio))
+ for ratio in self.img_ratios]
+ else:
+ img_scale = self.img_scale
+ flip_aug = [False, True] if self.flip else [False]
+ for scale in img_scale:
+ for flip in flip_aug:
+ for direction in self.flip_direction:
+ _results = results.copy()
+ _results['scale'] = scale
+ _results['flip'] = flip
+ _results['flip_direction'] = direction
+ data = self.transforms(_results)
+ aug_data.append(data)
+ # list of dict to dict of list
+ aug_data_dict = {key: [] for key in aug_data[0]}
+ for data in aug_data:
+ for key, val in data.items():
+ aug_data_dict[key].append(val)
+ return aug_data_dict
+ def __repr__(self):
+     """Return a string showing the transforms, scales and flip settings."""
+     repr_str = self.__class__.__name__
+     repr_str += f'(transforms={self.transforms}, '
+     # Keep flip_direction inside the parentheses; the closing ')' must be
+     # the last character so the repr reads as one well-formed call.
+     repr_str += f'img_scale={self.img_scale}, flip={self.flip}, '
+     repr_str += f'flip_direction={self.flip_direction})'
+     return repr_str
diff --git a/ControlNet/annotator/uniformer/mmseg/datasets/pipelines/transforms.py b/ControlNet/annotator/uniformer/mmseg/datasets/pipelines/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..94e869b252ef6d8b43604add2bbc02f034614bfb
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/datasets/pipelines/transforms.py
@@ -0,0 +1,889 @@
+import annotator.uniformer.mmcv as mmcv
+import numpy as np
+from annotator.uniformer.mmcv.utils import deprecated_api_warning, is_tuple_of
+from numpy import random
+from ..builder import PIPELINES
+class Resize(object):
+ """Resize images & seg.
+ This transform resizes the input image to some scale. If the input dict
+ contains the key "scale", then the scale in the input dict is used,
+ otherwise the specified scale in the init method is used.
+ ``img_scale`` can be None, a tuple (single-scale) or a list of tuple
+ (multi-scale). There are 4 multiscale modes:
+ - ``ratio_range is not None``:
+ 1. When img_scale is None, img_scale is the shape of image in results
+ (img_scale = results['img'].shape[:2]) and the image is resized based
+ on the original size. (mode 1)
+ 2. When img_scale is a tuple (single-scale), randomly sample a ratio from
+ the ratio range and multiply it with the image scale. (mode 2)
+ - ``ratio_range is None and multiscale_mode == "range"``: randomly sample a
+ scale from the a range. (mode 3)
+ - ``ratio_range is None and multiscale_mode == "value"``: randomly sample a
+ scale from multiple scales. (mode 4)
+ Args:
+ img_scale (tuple or list[tuple]): Images scales for resizing.
+ multiscale_mode (str): Either "range" or "value".
+ ratio_range (tuple[float]): (min_ratio, max_ratio)
+ keep_ratio (bool): Whether to keep the aspect ratio when resizing the
+ image.
+ """
+ def __init__(self,
+ img_scale=None,
+ multiscale_mode='range',
+ ratio_range=None,
+ keep_ratio=True):
+ if img_scale is None:
+ self.img_scale = None
+ else:
+ if isinstance(img_scale, list):
+ self.img_scale = img_scale
+ else:
+ self.img_scale = [img_scale]
+ assert mmcv.is_list_of(self.img_scale, tuple)
+ if ratio_range is not None:
+ # mode 1: given img_scale=None and a range of image ratio
+ # mode 2: given a scale and a range of image ratio
+ assert self.img_scale is None or len(self.img_scale) == 1
+ else:
+ # mode 3 and 4: given multiple scales or a range of scales
+ assert multiscale_mode in ['value', 'range']
+ self.multiscale_mode = multiscale_mode
+ self.ratio_range = ratio_range
+ self.keep_ratio = keep_ratio
+ @staticmethod
+ def random_select(img_scales):
+ """Randomly select an img_scale from given candidates.
+ Args:
+ img_scales (list[tuple]): Images scales for selection.
+ Returns:
+ (tuple, int): Returns a tuple ``(img_scale, scale_dix)``,
+ where ``img_scale`` is the selected image scale and
+ ``scale_idx`` is the selected index in the given candidates.
+ """
+ assert mmcv.is_list_of(img_scales, tuple)
+ scale_idx = np.random.randint(len(img_scales))
+ img_scale = img_scales[scale_idx]
+ return img_scale, scale_idx
+ @staticmethod
+ def random_sample(img_scales):
+ """Randomly sample an img_scale when ``multiscale_mode=='range'``.
+ Args:
+ img_scales (list[tuple]): Images scale range for sampling.
+ There must be two tuples in img_scales, which specify the lower
+ and upper bound of image scales.
+ Returns:
+ (tuple, None): Returns a tuple ``(img_scale, None)``, where
+ ``img_scale`` is sampled scale and None is just a placeholder
+ to be consistent with :func:`random_select`.
+ """
+ assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2
+ img_scale_long = [max(s) for s in img_scales]
+ img_scale_short = [min(s) for s in img_scales]
+ long_edge = np.random.randint(
+ min(img_scale_long),
+ max(img_scale_long) + 1)
+ short_edge = np.random.randint(
+ min(img_scale_short),
+ max(img_scale_short) + 1)
+ img_scale = (long_edge, short_edge)
+ return img_scale, None
+ @staticmethod
+ def random_sample_ratio(img_scale, ratio_range):
+ """Randomly sample an img_scale when ``ratio_range`` is specified.
+ A ratio will be randomly sampled from the range specified by
+ ``ratio_range``. Then it would be multiplied with ``img_scale`` to
+ generate sampled scale.
+ Args:
+ img_scale (tuple): Images scale base to multiply with ratio.
+ ratio_range (tuple[float]): The minimum and maximum ratio to scale
+ the ``img_scale``.
+ Returns:
+ (tuple, None): Returns a tuple ``(scale, None)``, where
+ ``scale`` is sampled ratio multiplied with ``img_scale`` and
+ None is just a placeholder to be consistent with
+ :func:`random_select`.
+ """
+ assert isinstance(img_scale, tuple) and len(img_scale) == 2
+ min_ratio, max_ratio = ratio_range
+ assert min_ratio <= max_ratio
+ ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio
+ scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio)
+ return scale, None
+ def _random_scale(self, results):
+ """Randomly sample an img_scale according to ``ratio_range`` and
+ ``multiscale_mode``.
+ If ``ratio_range`` is specified, a ratio will be sampled and be
+ multiplied with ``img_scale``.
+ If multiple scales are specified by ``img_scale``, a scale will be
+ sampled according to ``multiscale_mode``.
+ Otherwise, single scale will be used.
+ Args:
+ results (dict): Result dict from :obj:`dataset`.
+ Returns:
+ dict: Two new keys 'scale` and 'scale_idx` are added into
+ ``results``, which would be used by subsequent pipelines.
+ """
+ if self.ratio_range is not None:
+ if self.img_scale is None:
+ h, w = results['img'].shape[:2]
+ scale, scale_idx = self.random_sample_ratio((w, h),
+ self.ratio_range)
+ else:
+ scale, scale_idx = self.random_sample_ratio(
+ self.img_scale[0], self.ratio_range)
+ elif len(self.img_scale) == 1:
+ scale, scale_idx = self.img_scale[0], 0
+ elif self.multiscale_mode == 'range':
+ scale, scale_idx = self.random_sample(self.img_scale)
+ elif self.multiscale_mode == 'value':
+ scale, scale_idx = self.random_select(self.img_scale)
+ else:
+ raise NotImplementedError
+ results['scale'] = scale
+ results['scale_idx'] = scale_idx
+ def _resize_img(self, results):
+ """Resize images with ``results['scale']``."""
+ if self.keep_ratio:
+ img, scale_factor = mmcv.imrescale(
+ results['img'], results['scale'], return_scale=True)
+ # the w_scale and h_scale has minor difference
+ # a real fix should be done in the mmcv.imrescale in the future
+ new_h, new_w = img.shape[:2]
+ h, w = results['img'].shape[:2]
+ w_scale = new_w / w
+ h_scale = new_h / h
+ else:
+ img, w_scale, h_scale = mmcv.imresize(
+ results['img'], results['scale'], return_scale=True)
+ scale_factor = np.array([w_scale, h_scale, w_scale, h_scale],
+ dtype=np.float32)
+ results['img'] = img
+ results['img_shape'] = img.shape
+ results['pad_shape'] = img.shape # in case that there is no padding
+ results['scale_factor'] = scale_factor
+ results['keep_ratio'] = self.keep_ratio
+ def _resize_seg(self, results):
+ """Resize semantic segmentation map with ``results['scale']``."""
+ for key in results.get('seg_fields', []):
+ if self.keep_ratio:
+ gt_seg = mmcv.imrescale(
+ results[key], results['scale'], interpolation='nearest')
+ else:
+ gt_seg = mmcv.imresize(
+ results[key], results['scale'], interpolation='nearest')
+ results[key] = gt_seg
+ def __call__(self, results):
+ """Call function to resize images, bounding boxes, masks, semantic
+ segmentation map.
+ Args:
+ results (dict): Result dict from loading pipeline.
+ Returns:
+ dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor',
+ 'keep_ratio' keys are added into result dict.
+ """
+ if 'scale' not in results:
+ self._random_scale(results)
+ self._resize_img(results)
+ self._resize_seg(results)
+ return results
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += (f'(img_scale={self.img_scale}, '
+ f'multiscale_mode={self.multiscale_mode}, '
+ f'ratio_range={self.ratio_range}, '
+ f'keep_ratio={self.keep_ratio})')
+ return repr_str
+class RandomFlip(object):
+ """Flip the image & seg.
+ If the input dict contains the key "flip", then the flag will be used,
+ otherwise it will be randomly decided by a ratio specified in the init
+ method.
+ Args:
+ prob (float, optional): The flipping probability. Default: None.
+ direction(str, optional): The flipping direction. Options are
+ 'horizontal' and 'vertical'. Default: 'horizontal'.
+ """
+ @deprecated_api_warning({'flip_ratio': 'prob'}, cls_name='RandomFlip')
+ def __init__(self, prob=None, direction='horizontal'):
+ self.prob = prob
+ self.direction = direction
+ if prob is not None:
+ assert prob >= 0 and prob <= 1
+ assert direction in ['horizontal', 'vertical']
+ def __call__(self, results):
+ """Call function to flip bounding boxes, masks, semantic segmentation
+ maps.
+ Args:
+ results (dict): Result dict from loading pipeline.
+ Returns:
+ dict: Flipped results, 'flip', 'flip_direction' keys are added into
+ result dict.
+ """
+ if 'flip' not in results:
+ flip = True if np.random.rand() < self.prob else False
+ results['flip'] = flip
+ if 'flip_direction' not in results:
+ results['flip_direction'] = self.direction
+ if results['flip']:
+ # flip image
+ results['img'] = mmcv.imflip(
+ results['img'], direction=results['flip_direction'])
+ # flip segs
+ for key in results.get('seg_fields', []):
+ # use copy() to make numpy stride positive
+ results[key] = mmcv.imflip(
+ results[key], direction=results['flip_direction']).copy()
+ return results
+ def __repr__(self):
+ return self.__class__.__name__ + f'(prob={self.prob})'
+class Pad(object):
+ """Pad the image & mask.
+ There are two padding modes: (1) pad to a fixed size and (2) pad to the
+ minimum size that is divisible by some number.
+ Added keys are "pad_shape", "pad_fixed_size", "pad_size_divisor",
+ Args:
+ size (tuple, optional): Fixed padding size.
+ size_divisor (int, optional): The divisor of padded size.
+ pad_val (float, optional): Padding value. Default: 0.
+ seg_pad_val (float, optional): Padding value of segmentation map.
+ Default: 255.
+ """
+ def __init__(self,
+ size=None,
+ size_divisor=None,
+ pad_val=0,
+ seg_pad_val=255):
+ self.size = size
+ self.size_divisor = size_divisor
+ self.pad_val = pad_val
+ self.seg_pad_val = seg_pad_val
+ # only one of size and size_divisor should be valid
+ assert size is not None or size_divisor is not None
+ assert size is None or size_divisor is None
+ def _pad_img(self, results):
+ """Pad images according to ``self.size``."""
+ if self.size is not None:
+ padded_img = mmcv.impad(
+ results['img'], shape=self.size, pad_val=self.pad_val)
+ elif self.size_divisor is not None:
+ padded_img = mmcv.impad_to_multiple(
+ results['img'], self.size_divisor, pad_val=self.pad_val)
+ results['img'] = padded_img
+ results['pad_shape'] = padded_img.shape
+ results['pad_fixed_size'] = self.size
+ results['pad_size_divisor'] = self.size_divisor
+ def _pad_seg(self, results):
+ """Pad masks according to ``results['pad_shape']``."""
+ for key in results.get('seg_fields', []):
+ results[key] = mmcv.impad(
+ results[key],
+ shape=results['pad_shape'][:2],
+ pad_val=self.seg_pad_val)
+ def __call__(self, results):
+ """Call function to pad images, masks, semantic segmentation maps.
+ Args:
+ results (dict): Result dict from loading pipeline.
+ Returns:
+ dict: Updated result dict.
+ """
+ self._pad_img(results)
+ self._pad_seg(results)
+ return results
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(size={self.size}, size_divisor={self.size_divisor}, ' \
+ f'pad_val={self.pad_val})'
+ return repr_str
+class Normalize(object):
+ """Normalize the image.
+ Added key is "img_norm_cfg".
+ Args:
+ mean (sequence): Mean values of 3 channels.
+ std (sequence): Std values of 3 channels.
+ to_rgb (bool): Whether to convert the image from BGR to RGB,
+ default is true.
+ """
+ def __init__(self, mean, std, to_rgb=True):
+ self.mean = np.array(mean, dtype=np.float32)
+ self.std = np.array(std, dtype=np.float32)
+ self.to_rgb = to_rgb
+ def __call__(self, results):
+ """Call function to normalize images.
+ Args:
+ results (dict): Result dict from loading pipeline.
+ Returns:
+ dict: Normalized results, 'img_norm_cfg' key is added into
+ result dict.
+ """
+ results['img'] = mmcv.imnormalize(results['img'], self.mean, self.std,
+ self.to_rgb)
+ results['img_norm_cfg'] = dict(
+ mean=self.mean, std=self.std, to_rgb=self.to_rgb)
+ return results
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(mean={self.mean}, std={self.std}, to_rgb=' \
+ f'{self.to_rgb})'
+ return repr_str
+class Rerange(object):
+ """Rerange the image pixel value.
+ Args:
+ min_value (float or int): Minimum value of the reranged image.
+ Default: 0.
+ max_value (float or int): Maximum value of the reranged image.
+ Default: 255.
+ """
+ def __init__(self, min_value=0, max_value=255):
+ assert isinstance(min_value, float) or isinstance(min_value, int)
+ assert isinstance(max_value, float) or isinstance(max_value, int)
+ assert min_value < max_value
+ self.min_value = min_value
+ self.max_value = max_value
+ def __call__(self, results):
+ """Call function to rerange images.
+ Args:
+ results (dict): Result dict from loading pipeline.
+ Returns:
+ dict: Reranged results.
+ """
+ img = results['img']
+ img_min_value = np.min(img)
+ img_max_value = np.max(img)
+ assert img_min_value < img_max_value
+ # rerange to [0, 1]
+ img = (img - img_min_value) / (img_max_value - img_min_value)
+ # rerange to [min_value, max_value]
+ img = img * (self.max_value - self.min_value) + self.min_value
+ results['img'] = img
+ return results
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(min_value={self.min_value}, max_value={self.max_value})'
+ return repr_str
+class CLAHE(object):
+ """Use CLAHE method to process the image.
+ See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J].
+ Graphics Gems, 1994:474-485.` for more information.
+ Args:
+ clip_limit (float): Threshold for contrast limiting. Default: 40.0.
+ tile_grid_size (tuple[int]): Size of grid for histogram equalization.
+ Input image will be divided into equally sized rectangular tiles.
+ It defines the number of tiles in row and column. Default: (8, 8).
+ """
+ def __init__(self, clip_limit=40.0, tile_grid_size=(8, 8)):
+ assert isinstance(clip_limit, (float, int))
+ self.clip_limit = clip_limit
+ assert is_tuple_of(tile_grid_size, int)
+ assert len(tile_grid_size) == 2
+ self.tile_grid_size = tile_grid_size
+ def __call__(self, results):
+ """Call function to Use CLAHE method process images.
+ Args:
+ results (dict): Result dict from loading pipeline.
+ Returns:
+ dict: Processed results.
+ """
+ for i in range(results['img'].shape[2]):
+ results['img'][:, :, i] = mmcv.clahe(
+ np.array(results['img'][:, :, i], dtype=np.uint8),
+ self.clip_limit, self.tile_grid_size)
+ return results
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(clip_limit={self.clip_limit}, '\
+ f'tile_grid_size={self.tile_grid_size})'
+ return repr_str
+class RandomCrop(object):
+ """Random crop the image & seg.
+ Args:
+ crop_size (tuple): Expected size after cropping, (h, w).
+ cat_max_ratio (float): The maximum ratio that single category could
+ occupy.
+ """
+ def __init__(self, crop_size, cat_max_ratio=1., ignore_index=255):
+ assert crop_size[0] > 0 and crop_size[1] > 0
+ self.crop_size = crop_size
+ self.cat_max_ratio = cat_max_ratio
+ self.ignore_index = ignore_index
+ def get_crop_bbox(self, img):
+ """Randomly get a crop bounding box."""
+ margin_h = max(img.shape[0] - self.crop_size[0], 0)
+ margin_w = max(img.shape[1] - self.crop_size[1], 0)
+ offset_h = np.random.randint(0, margin_h + 1)
+ offset_w = np.random.randint(0, margin_w + 1)
+ crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0]
+ crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1]
+ return crop_y1, crop_y2, crop_x1, crop_x2
+ def crop(self, img, crop_bbox):
+ """Crop from ``img``"""
+ crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox
+ img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
+ return img
+ def __call__(self, results):
+ """Call function to randomly crop images, semantic segmentation maps.
+ Args:
+ results (dict): Result dict from loading pipeline.
+ Returns:
+ dict: Randomly cropped results, 'img_shape' key in result dict is
+ updated according to crop size.
+ """
+ img = results['img']
+ crop_bbox = self.get_crop_bbox(img)
+ if self.cat_max_ratio < 1.:
+ # Repeat 10 times
+ for _ in range(10):
+ seg_temp = self.crop(results['gt_semantic_seg'], crop_bbox)
+ labels, cnt = np.unique(seg_temp, return_counts=True)
+ cnt = cnt[labels != self.ignore_index]
+ if len(cnt) > 1 and np.max(cnt) / np.sum(
+ cnt) < self.cat_max_ratio:
+ break
+ crop_bbox = self.get_crop_bbox(img)
+ # crop the image
+ img = self.crop(img, crop_bbox)
+ img_shape = img.shape
+ results['img'] = img
+ results['img_shape'] = img_shape
+ # crop semantic seg
+ for key in results.get('seg_fields', []):
+ results[key] = self.crop(results[key], crop_bbox)
+ return results
+ def __repr__(self):
+ return self.__class__.__name__ + f'(crop_size={self.crop_size})'
+class RandomRotate(object):
+ """Rotate the image & seg.
+ Args:
+ prob (float): The rotation probability.
+ degree (float, tuple[float]): Range of degrees to select from. If
+ degree is a number instead of tuple like (min, max),
+ the range of degree will be (``-degree``, ``+degree``)
+ pad_val (float, optional): Padding value of image. Default: 0.
+ seg_pad_val (float, optional): Padding value of segmentation map.
+ Default: 255.
+ center (tuple[float], optional): Center point (w, h) of the rotation in
+ the source image. If not specified, the center of the image will be
+ used. Default: None.
+ auto_bound (bool): Whether to adjust the image size to cover the whole
+ rotated image. Default: False
+ """
+ def __init__(self,
+ prob,
+ degree,
+ pad_val=0,
+ seg_pad_val=255,
+ center=None,
+ auto_bound=False):
+ self.prob = prob
+ assert prob >= 0 and prob <= 1
+ if isinstance(degree, (float, int)):
+ assert degree > 0, f'degree {degree} should be positive'
+ self.degree = (-degree, degree)
+ else:
+ self.degree = degree
+ assert len(self.degree) == 2, f'degree {self.degree} should be a ' \
+ f'tuple of (min, max)'
+ self.pal_val = pad_val
+ self.seg_pad_val = seg_pad_val
+ self.center = center
+ self.auto_bound = auto_bound
+ def __call__(self, results):
+ """Call function to rotate image, semantic segmentation maps.
+ Args:
+ results (dict): Result dict from loading pipeline.
+ Returns:
+ dict: Rotated results.
+ """
+ rotate = True if np.random.rand() < self.prob else False
+ degree = np.random.uniform(min(*self.degree), max(*self.degree))
+ if rotate:
+ # rotate image
+ results['img'] = mmcv.imrotate(
+ results['img'],
+ angle=degree,
+ border_value=self.pal_val,
+ center=self.center,
+ auto_bound=self.auto_bound)
+ # rotate segs
+ for key in results.get('seg_fields', []):
+ results[key] = mmcv.imrotate(
+ results[key],
+ angle=degree,
+ border_value=self.seg_pad_val,
+ center=self.center,
+ auto_bound=self.auto_bound,
+ interpolation='nearest')
+ return results
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(prob={self.prob}, ' \
+ f'degree={self.degree}, ' \
+ f'pad_val={self.pal_val}, ' \
+ f'seg_pad_val={self.seg_pad_val}, ' \
+ f'center={self.center}, ' \
+ f'auto_bound={self.auto_bound})'
+ return repr_str
+class RGB2Gray(object):
+ """Convert RGB image to grayscale image.
+ This transform calculate the weighted mean of input image channels with
+ ``weights`` and then expand the channels to ``out_channels``. When
+ ``out_channels`` is None, the number of output channels is the same as
+ input channels.
+ Args:
+ out_channels (int): Expected number of output channels after
+ transforming. Default: None.
+ weights (tuple[float]): The weights to calculate the weighted mean.
+ Default: (0.299, 0.587, 0.114).
+ """
+ def __init__(self, out_channels=None, weights=(0.299, 0.587, 0.114)):
+ assert out_channels is None or out_channels > 0
+ self.out_channels = out_channels
+ assert isinstance(weights, tuple)
+ for item in weights:
+ assert isinstance(item, (float, int))
+ self.weights = weights
+ def __call__(self, results):
+ """Call function to convert RGB image to grayscale image.
+ Args:
+ results (dict): Result dict from loading pipeline.
+ Returns:
+ dict: Result dict with grayscale image.
+ """
+ img = results['img']
+ assert len(img.shape) == 3
+ assert img.shape[2] == len(self.weights)
+ weights = np.array(self.weights).reshape((1, 1, -1))
+ img = (img * weights).sum(2, keepdims=True)
+ if self.out_channels is None:
+ img = img.repeat(weights.shape[2], axis=2)
+ else:
+ img = img.repeat(self.out_channels, axis=2)
+ results['img'] = img
+ results['img_shape'] = img.shape
+ return results
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(out_channels={self.out_channels}, ' \
+ f'weights={self.weights})'
+ return repr_str
+class AdjustGamma(object):
+ """Using gamma correction to process the image.
+ Args:
+ gamma (float or int): Gamma value used in gamma correction.
+ Default: 1.0.
+ """
+ def __init__(self, gamma=1.0):
+ assert isinstance(gamma, float) or isinstance(gamma, int)
+ assert gamma > 0
+ self.gamma = gamma
+ inv_gamma = 1.0 / gamma
+ self.table = np.array([(i / 255.0)**inv_gamma * 255
+ for i in np.arange(256)]).astype('uint8')
+ def __call__(self, results):
+ """Call function to process the image with gamma correction.
+ Args:
+ results (dict): Result dict from loading pipeline.
+ Returns:
+ dict: Processed results.
+ """
+ results['img'] = mmcv.lut_transform(
+ np.array(results['img'], dtype=np.uint8), self.table)
+ return results
+ def __repr__(self):
+ return self.__class__.__name__ + f'(gamma={self.gamma})'
+class SegRescale(object):
+ """Rescale semantic segmentation maps.
+ Args:
+ scale_factor (float): The scale factor of the final output.
+ """
+ def __init__(self, scale_factor=1):
+ self.scale_factor = scale_factor
+ def __call__(self, results):
+ """Call function to scale the semantic segmentation map.
+ Args:
+ results (dict): Result dict from loading pipeline.
+ Returns:
+ dict: Result dict with semantic segmentation map scaled.
+ """
+ for key in results.get('seg_fields', []):
+ if self.scale_factor != 1:
+ results[key] = mmcv.imrescale(
+ results[key], self.scale_factor, interpolation='nearest')
+ return results
+ def __repr__(self):
+ return self.__class__.__name__ + f'(scale_factor={self.scale_factor})'
+class PhotoMetricDistortion(object):
+ """Apply photometric distortion to image sequentially, every transformation
+ is applied with a probability of 0.5. The position of random contrast is in
+ second or second to last.
+ 1. random brightness
+ 2. random contrast (mode 0)
+ 3. convert color from BGR to HSV
+ 4. random saturation
+ 5. random hue
+ 6. convert color from HSV to BGR
+ 7. random contrast (mode 1)
+ Args:
+ brightness_delta (int): delta of brightness.
+ contrast_range (tuple): range of contrast.
+ saturation_range (tuple): range of saturation.
+ hue_delta (int): delta of hue.
+ """
+ def __init__(self,
+ brightness_delta=32,
+ contrast_range=(0.5, 1.5),
+ saturation_range=(0.5, 1.5),
+ hue_delta=18):
+ self.brightness_delta = brightness_delta
+ self.contrast_lower, self.contrast_upper = contrast_range
+ self.saturation_lower, self.saturation_upper = saturation_range
+ self.hue_delta = hue_delta
+ def convert(self, img, alpha=1, beta=0):
+ """Multiple with alpha and add beat with clip."""
+ img = img.astype(np.float32) * alpha + beta
+ img = np.clip(img, 0, 255)
+ return img.astype(np.uint8)
+ def brightness(self, img):
+ """Brightness distortion."""
+ if random.randint(2):
+ return self.convert(
+ img,
+ beta=random.uniform(-self.brightness_delta,
+ self.brightness_delta))
+ return img
+ def contrast(self, img):
+ """Contrast distortion."""
+ if random.randint(2):
+ return self.convert(
+ img,
+ alpha=random.uniform(self.contrast_lower, self.contrast_upper))
+ return img
+ def saturation(self, img):
+ """Saturation distortion."""
+ if random.randint(2):
+ img = mmcv.bgr2hsv(img)
+ img[:, :, 1] = self.convert(
+ img[:, :, 1],
+ alpha=random.uniform(self.saturation_lower,
+ self.saturation_upper))
+ img = mmcv.hsv2bgr(img)
+ return img
+ def hue(self, img):
+ """Hue distortion."""
+ if random.randint(2):
+ img = mmcv.bgr2hsv(img)
+ img[:, :,
+ 0] = (img[:, :, 0].astype(int) +
+ random.randint(-self.hue_delta, self.hue_delta)) % 180
+ img = mmcv.hsv2bgr(img)
+ return img
+ def __call__(self, results):
+ """Call function to perform photometric distortion on images.
+ Args:
+ results (dict): Result dict from loading pipeline.
+ Returns:
+ dict: Result dict with images distorted.
+ """
+ img = results['img']
+ # random brightness
+ img = self.brightness(img)
+ # mode == 0 --> do random contrast first
+ # mode == 1 --> do random contrast last
+ mode = random.randint(2)
+ if mode == 1:
+ img = self.contrast(img)
+ # random saturation
+ img = self.saturation(img)
+ # random hue
+ img = self.hue(img)
+ # random contrast
+ if mode == 0:
+ img = self.contrast(img)
+ results['img'] = img
+ return results
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += (f'(brightness_delta={self.brightness_delta}, '
+ f'contrast_range=({self.contrast_lower}, '
+ f'{self.contrast_upper}), '
+ f'saturation_range=({self.saturation_lower}, '
+ f'{self.saturation_upper}), '
+ f'hue_delta={self.hue_delta})')
+ return repr_str
diff --git a/ControlNet/annotator/uniformer/mmseg/datasets/stare.py b/ControlNet/annotator/uniformer/mmseg/datasets/stare.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbd14e0920e7f6a73baff1432e5a32ccfdb0dfae
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/datasets/stare.py
@@ -0,0 +1,27 @@
+import os.path as osp
+from .builder import DATASETS
+from .custom import CustomDataset
+class STAREDataset(CustomDataset):
+ """STARE dataset.
+ In segmentation map annotation for STARE, 0 stands for background, which is
+ included in 2 categories. ``reduce_zero_label`` is fixed to False. The
+ ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
+ '.ah.png'.
+ """
+ CLASSES = ('background', 'vessel')
+ PALETTE = [[120, 120, 120], [6, 230, 230]]
+ def __init__(self, **kwargs):
+ super(STAREDataset, self).__init__(
+ img_suffix='.png',
+ seg_map_suffix='.ah.png',
+ reduce_zero_label=False,
+ **kwargs)
+ assert osp.exists(self.img_dir)
diff --git a/ControlNet/annotator/uniformer/mmseg/datasets/voc.py b/ControlNet/annotator/uniformer/mmseg/datasets/voc.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8855203b14ee0dc4da9099a2945d4aedcffbcd6
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/datasets/voc.py
@@ -0,0 +1,29 @@
+import os.path as osp
+from .builder import DATASETS
+from .custom import CustomDataset
+class PascalVOCDataset(CustomDataset):
+ """Pascal VOC dataset.
+ Args:
+ split (str): Split txt file for Pascal VOC.
+ """
+ CLASSES = ('background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
+ 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
+ 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa',
+ 'train', 'tvmonitor')
+ PALETTE = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
+ [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
+ [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
+ [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
+ [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]
+ def __init__(self, split, **kwargs):
+ super(PascalVOCDataset, self).__init__(
+ img_suffix='.jpg', seg_map_suffix='.png', split=split, **kwargs)
+ assert osp.exists(self.img_dir) and self.split is not None
diff --git a/ControlNet/annotator/uniformer/mmseg/models/__init__.py b/ControlNet/annotator/uniformer/mmseg/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cf93f8bec9cf0cef0a3bd76ca3ca92eb188f535
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/__init__.py
@@ -0,0 +1,12 @@
+from .backbones import * # noqa: F401,F403
+from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, build_backbone,
+ build_head, build_loss, build_segmentor)
+from .decode_heads import * # noqa: F401,F403
+from .losses import * # noqa: F401,F403
+from .necks import * # noqa: F401,F403
+from .segmentors import * # noqa: F401,F403
+__all__ = [
+ 'BACKBONES', 'HEADS', 'LOSSES', 'SEGMENTORS', 'build_backbone',
+ 'build_head', 'build_loss', 'build_segmentor'
diff --git a/ControlNet/annotator/uniformer/mmseg/models/backbones/__init__.py b/ControlNet/annotator/uniformer/mmseg/models/backbones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8339983905fb5d20bae42ba6f76fea75d278b1aa
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/backbones/__init__.py
@@ -0,0 +1,17 @@
+from .cgnet import CGNet
+# from .fast_scnn import FastSCNN
+from .hrnet import HRNet
+from .mobilenet_v2 import MobileNetV2
+from .mobilenet_v3 import MobileNetV3
+from .resnest import ResNeSt
+from .resnet import ResNet, ResNetV1c, ResNetV1d
+from .resnext import ResNeXt
+from .unet import UNet
+from .vit import VisionTransformer
+from .uniformer import UniFormer
+__all__ = [
+ 'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet',
+ 'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
+ 'VisionTransformer', 'UniFormer'
diff --git a/ControlNet/annotator/uniformer/mmseg/models/backbones/cgnet.py b/ControlNet/annotator/uniformer/mmseg/models/backbones/cgnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8bca442c8f18179f217e40c298fb5ef39df77c4
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/backbones/cgnet.py
@@ -0,0 +1,367 @@
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from annotator.uniformer.mmcv.cnn import (ConvModule, build_conv_layer, build_norm_layer,
+ constant_init, kaiming_init)
+from annotator.uniformer.mmcv.runner import load_checkpoint
+from annotator.uniformer.mmcv.utils.parrots_wrapper import _BatchNorm
+from annotator.uniformer.mmseg.utils import get_root_logger
+from ..builder import BACKBONES
+class GlobalContextExtractor(nn.Module):
+ """Global Context Extractor for CGNet.
+ This class is employed to refine the joint feature of both local feature
+ and surrounding context.
+ Args:
+ channel (int): Number of input feature channels.
+ reduction (int): Reductions for global context extractor. Default: 16.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ """
+ def __init__(self, channel, reduction=16, with_cp=False):
+ super(GlobalContextExtractor, self).__init__()
+ self.channel = channel
+ self.reduction = reduction
+ assert reduction >= 1 and channel >= reduction
+ self.with_cp = with_cp
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
+ self.fc = nn.Sequential(
+ nn.Linear(channel, channel // reduction), nn.ReLU(inplace=True),
+ nn.Linear(channel // reduction, channel), nn.Sigmoid())
+ def forward(self, x):
+ def _inner_forward(x):
+ num_batch, num_channel = x.size()[:2]
+ y = self.avg_pool(x).view(num_batch, num_channel)
+ y = self.fc(y).view(num_batch, num_channel, 1, 1)
+ return x * y
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+ return out
+class ContextGuidedBlock(nn.Module):
+ """Context Guided Block for CGNet.
+ This class consists of four components: local feature extractor,
+ surrounding feature extractor, joint feature extractor and global
+ context extractor.
+ Args:
+ in_channels (int): Number of input feature channels.
+ out_channels (int): Number of output feature channels.
+ dilation (int): Dilation rate for surrounding context extractor.
+ Default: 2.
+ reduction (int): Reduction for global context extractor. Default: 16.
+ skip_connect (bool): Add input to output or not. Default: True.
+ downsample (bool): Downsample the input to 1/2 or not. Default: False.
+ conv_cfg (dict): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN', requires_grad=True).
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='PReLU').
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ """
+ def __init__(self,
+ in_channels,
+ out_channels,
+ dilation=2,
+ reduction=16,
+ skip_connect=True,
+ downsample=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ act_cfg=dict(type='PReLU'),
+ with_cp=False):
+ super(ContextGuidedBlock, self).__init__()
+ self.with_cp = with_cp
+ self.downsample = downsample
+ channels = out_channels if downsample else out_channels // 2
+ if 'type' in act_cfg and act_cfg['type'] == 'PReLU':
+ act_cfg['num_parameters'] = channels
+ kernel_size = 3 if downsample else 1
+ stride = 2 if downsample else 1
+ padding = (kernel_size - 1) // 2
+ self.conv1x1 = ConvModule(
+ in_channels,
+ channels,
+ kernel_size,
+ stride,
+ padding,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ self.f_loc = build_conv_layer(
+ conv_cfg,
+ channels,
+ channels,
+ kernel_size=3,
+ padding=1,
+ groups=channels,
+ bias=False)
+ self.f_sur = build_conv_layer(
+ conv_cfg,
+ channels,
+ channels,
+ kernel_size=3,
+ padding=dilation,
+ groups=channels,
+ dilation=dilation,
+ bias=False)
+ self.bn = build_norm_layer(norm_cfg, 2 * channels)[1]
+ self.activate = nn.PReLU(2 * channels)
+ if downsample:
+ self.bottleneck = build_conv_layer(
+ conv_cfg,
+ 2 * channels,
+ out_channels,
+ kernel_size=1,
+ bias=False)
+ self.skip_connect = skip_connect and not downsample
+ self.f_glo = GlobalContextExtractor(out_channels, reduction, with_cp)
+ def forward(self, x):
+ def _inner_forward(x):
+ out = self.conv1x1(x)
+ loc = self.f_loc(out)
+ sur = self.f_sur(out)
+ joi_feat = torch.cat([loc, sur], 1) # the joint feature
+ joi_feat = self.bn(joi_feat)
+ joi_feat = self.activate(joi_feat)
+ if self.downsample:
+ joi_feat = self.bottleneck(joi_feat) # channel = out_channels
+ # f_glo is employed to refine the joint feature
+ out = self.f_glo(joi_feat)
+ if self.skip_connect:
+ return x + out
+ else:
+ return out
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+ return out
+class InputInjection(nn.Module):
+ """Downsampling module for CGNet."""
+ def __init__(self, num_downsampling):
+ super(InputInjection, self).__init__()
+ self.pool = nn.ModuleList()
+ for i in range(num_downsampling):
+ self.pool.append(nn.AvgPool2d(3, stride=2, padding=1))
+ def forward(self, x):
+ for pool in self.pool:
+ x = pool(x)
+ return x
+class CGNet(nn.Module):
+ """CGNet backbone.
+ A Light-weight Context Guided Network for Semantic Segmentation
+ arXiv: https://arxiv.org/abs/1811.08201
+ Args:
+ in_channels (int): Number of input image channels. Normally 3.
+ num_channels (tuple[int]): Numbers of feature channels at each stages.
+ Default: (32, 64, 128).
+ num_blocks (tuple[int]): Numbers of CG blocks at stage 1 and stage 2.
+ Default: (3, 21).
+ dilations (tuple[int]): Dilation rate for surrounding context
+ extractors at stage 1 and stage 2. Default: (2, 4).
+ reductions (tuple[int]): Reductions for global context extractors at
+ stage 1 and stage 2. Default: (8, 16).
+ conv_cfg (dict): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN', requires_grad=True).
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='PReLU').
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ """
+ def __init__(self,
+ in_channels=3,
+ num_channels=(32, 64, 128),
+ num_blocks=(3, 21),
+ dilations=(2, 4),
+ reductions=(8, 16),
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ act_cfg=dict(type='PReLU'),
+ norm_eval=False,
+ with_cp=False):
+ super(CGNet, self).__init__()
+ self.in_channels = in_channels
+ self.num_channels = num_channels
+ assert isinstance(self.num_channels, tuple) and len(
+ self.num_channels) == 3
+ self.num_blocks = num_blocks
+ assert isinstance(self.num_blocks, tuple) and len(self.num_blocks) == 2
+ self.dilations = dilations
+ assert isinstance(self.dilations, tuple) and len(self.dilations) == 2
+ self.reductions = reductions
+ assert isinstance(self.reductions, tuple) and len(self.reductions) == 2
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ if 'type' in self.act_cfg and self.act_cfg['type'] == 'PReLU':
+ self.act_cfg['num_parameters'] = num_channels[0]
+ self.norm_eval = norm_eval
+ self.with_cp = with_cp
+ cur_channels = in_channels
+ self.stem = nn.ModuleList()
+ for i in range(3):
+ self.stem.append(
+ ConvModule(
+ cur_channels,
+ num_channels[0],
+ 3,
+ 2 if i == 0 else 1,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+ cur_channels = num_channels[0]
+ self.inject_2x = InputInjection(1) # down-sample for Input, factor=2
+ self.inject_4x = InputInjection(2) # down-sample for Input, factor=4
+ cur_channels += in_channels
+ self.norm_prelu_0 = nn.Sequential(
+ build_norm_layer(norm_cfg, cur_channels)[1],
+ nn.PReLU(cur_channels))
+ # stage 1
+ self.level1 = nn.ModuleList()
+ for i in range(num_blocks[0]):
+ self.level1.append(
+ ContextGuidedBlock(
+ cur_channels if i == 0 else num_channels[1],
+ num_channels[1],
+ dilations[0],
+ reductions[0],
+ downsample=(i == 0),
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ with_cp=with_cp)) # CG block
+ cur_channels = 2 * num_channels[1] + in_channels
+ self.norm_prelu_1 = nn.Sequential(
+ build_norm_layer(norm_cfg, cur_channels)[1],
+ nn.PReLU(cur_channels))
+ # stage 2
+ self.level2 = nn.ModuleList()
+ for i in range(num_blocks[1]):
+ self.level2.append(
+ ContextGuidedBlock(
+ cur_channels if i == 0 else num_channels[2],
+ num_channels[2],
+ dilations[1],
+ reductions[1],
+ downsample=(i == 0),
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ with_cp=with_cp)) # CG block
+ cur_channels = 2 * num_channels[2]
+ self.norm_prelu_2 = nn.Sequential(
+ build_norm_layer(norm_cfg, cur_channels)[1],
+ nn.PReLU(cur_channels))
+ def forward(self, x):
+ output = []
+ # stage 0
+ inp_2x = self.inject_2x(x)
+ inp_4x = self.inject_4x(x)
+ for layer in self.stem:
+ x = layer(x)
+ x = self.norm_prelu_0(torch.cat([x, inp_2x], 1))
+ output.append(x)
+ # stage 1
+ for i, layer in enumerate(self.level1):
+ x = layer(x)
+ if i == 0:
+ down1 = x
+ x = self.norm_prelu_1(torch.cat([x, down1, inp_4x], 1))
+ output.append(x)
+ # stage 2
+ for i, layer in enumerate(self.level2):
+ x = layer(x)
+ if i == 0:
+ down2 = x
+ x = self.norm_prelu_2(torch.cat([down2, x], 1))
+ output.append(x)
+ return output
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in backbone.
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ if isinstance(pretrained, str):
+ logger = get_root_logger()
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ for m in self.modules():
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
+ kaiming_init(m)
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+ constant_init(m, 1)
+ elif isinstance(m, nn.PReLU):
+ constant_init(m, 0)
+ else:
+ raise TypeError('pretrained must be a str or None')
+ def train(self, mode=True):
+ """Convert the model into training mode will keeping the normalization
+ layer freezed."""
+ super(CGNet, self).train(mode)
+ if mode and self.norm_eval:
+ for m in self.modules():
+ # trick: eval have effect on BatchNorm only
+ if isinstance(m, _BatchNorm):
+ m.eval()
diff --git a/ControlNet/annotator/uniformer/mmseg/models/backbones/fast_scnn.py b/ControlNet/annotator/uniformer/mmseg/models/backbones/fast_scnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..38c2350177cbc2066f45add568d30eb6041f74f3
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/backbones/fast_scnn.py
@@ -0,0 +1,375 @@
+import torch
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule, constant_init,
+ kaiming_init)
+from torch.nn.modules.batchnorm import _BatchNorm
+from annotator.uniformer.mmseg.models.decode_heads.psp_head import PPM
+from annotator.uniformer.mmseg.ops import resize
+from ..builder import BACKBONES
+from ..utils.inverted_residual import InvertedResidual
+class LearningToDownsample(nn.Module):
+ """Learning to downsample module.
+ Args:
+ in_channels (int): Number of input channels.
+ dw_channels (tuple[int]): Number of output channels of the first and
+ the second depthwise conv (dwconv) layers.
+ out_channels (int): Number of output channels of the whole
+ 'learning to downsample' module.
+ conv_cfg (dict | None): Config of conv layers. Default: None
+ norm_cfg (dict | None): Config of norm layers. Default:
+ dict(type='BN')
+ act_cfg (dict): Config of activation layers. Default:
+ dict(type='ReLU')
+ """
+ def __init__(self,
+ in_channels,
+ dw_channels,
+ out_channels,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU')):
+ super(LearningToDownsample, self).__init__()
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ dw_channels1 = dw_channels[0]
+ dw_channels2 = dw_channels[1]
+ self.conv = ConvModule(
+ in_channels,
+ dw_channels1,
+ 3,
+ stride=2,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.dsconv1 = DepthwiseSeparableConvModule(
+ dw_channels1,
+ dw_channels2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ norm_cfg=self.norm_cfg)
+ self.dsconv2 = DepthwiseSeparableConvModule(
+ dw_channels2,
+ out_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ norm_cfg=self.norm_cfg)
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.dsconv1(x)
+ x = self.dsconv2(x)
+ return x
+class GlobalFeatureExtractor(nn.Module):
+ """Global feature extractor module.
+ Args:
+ in_channels (int): Number of input channels of the GFE module.
+ Default: 64
+ block_channels (tuple[int]): Tuple of ints. Each int specifies the
+ number of output channels of each Inverted Residual module.
+ Default: (64, 96, 128)
+ out_channels(int): Number of output channels of the GFE module.
+ Default: 128
+ expand_ratio (int): Adjusts number of channels of the hidden layer
+ in InvertedResidual by this amount.
+ Default: 6
+ num_blocks (tuple[int]): Tuple of ints. Each int specifies the
+ number of times each Inverted Residual module is repeated.
+ The repeated Inverted Residual modules are called a 'group'.
+ Default: (3, 3, 3)
+ strides (tuple[int]): Tuple of ints. Each int specifies
+ the downsampling factor of each 'group'.
+ Default: (2, 2, 1)
+ pool_scales (tuple[int]): Tuple of ints. Each int specifies
+ the parameter required in 'global average pooling' within PPM.
+ Default: (1, 2, 3, 6)
+ conv_cfg (dict | None): Config of conv layers. Default: None
+ norm_cfg (dict | None): Config of norm layers. Default:
+ dict(type='BN')
+ act_cfg (dict): Config of activation layers. Default:
+ dict(type='ReLU')
+ align_corners (bool): align_corners argument of F.interpolate.
+ Default: False
+ """
+ def __init__(self,
+ in_channels=64,
+ block_channels=(64, 96, 128),
+ out_channels=128,
+ expand_ratio=6,
+ num_blocks=(3, 3, 3),
+ strides=(2, 2, 1),
+ pool_scales=(1, 2, 3, 6),
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ align_corners=False):
+ super(GlobalFeatureExtractor, self).__init__()
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ assert len(block_channels) == len(num_blocks) == 3
+ self.bottleneck1 = self._make_layer(in_channels, block_channels[0],
+ num_blocks[0], strides[0],
+ expand_ratio)
+ self.bottleneck2 = self._make_layer(block_channels[0],
+ block_channels[1], num_blocks[1],
+ strides[1], expand_ratio)
+ self.bottleneck3 = self._make_layer(block_channels[1],
+ block_channels[2], num_blocks[2],
+ strides[2], expand_ratio)
+ self.ppm = PPM(
+ pool_scales,
+ block_channels[2],
+ block_channels[2] // 4,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ align_corners=align_corners)
+ self.out = ConvModule(
+ block_channels[2] * 2,
+ out_channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ def _make_layer(self,
+ in_channels,
+ out_channels,
+ blocks,
+ stride=1,
+ expand_ratio=6):
+ layers = [
+ InvertedResidual(
+ in_channels,
+ out_channels,
+ stride,
+ expand_ratio,
+ norm_cfg=self.norm_cfg)
+ ]
+ for i in range(1, blocks):
+ layers.append(
+ InvertedResidual(
+ out_channels,
+ out_channels,
+ 1,
+ expand_ratio,
+ norm_cfg=self.norm_cfg))
+ return nn.Sequential(*layers)
+ def forward(self, x):
+ x = self.bottleneck1(x)
+ x = self.bottleneck2(x)
+ x = self.bottleneck3(x)
+ x = torch.cat([x, *self.ppm(x)], dim=1)
+ x = self.out(x)
+ return x
+class FeatureFusionModule(nn.Module):
+ """Feature fusion module.
+ Args:
+ higher_in_channels (int): Number of input channels of the
+ higher-resolution branch.
+ lower_in_channels (int): Number of input channels of the
+ lower-resolution branch.
+ out_channels (int): Number of output channels.
+ conv_cfg (dict | None): Config of conv layers. Default: None
+ norm_cfg (dict | None): Config of norm layers. Default:
+ dict(type='BN')
+ act_cfg (dict): Config of activation layers. Default:
+ dict(type='ReLU')
+ align_corners (bool): align_corners argument of F.interpolate.
+ Default: False
+ """
+ def __init__(self,
+ higher_in_channels,
+ lower_in_channels,
+ out_channels,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ align_corners=False):
+ super(FeatureFusionModule, self).__init__()
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.align_corners = align_corners
+ self.dwconv = ConvModule(
+ lower_in_channels,
+ out_channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.conv_lower_res = ConvModule(
+ out_channels,
+ out_channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=None)
+ self.conv_higher_res = ConvModule(
+ higher_in_channels,
+ out_channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=None)
+ self.relu = nn.ReLU(True)
+ def forward(self, higher_res_feature, lower_res_feature):
+ lower_res_feature = resize(
+ lower_res_feature,
+ size=higher_res_feature.size()[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ lower_res_feature = self.dwconv(lower_res_feature)
+ lower_res_feature = self.conv_lower_res(lower_res_feature)
+ higher_res_feature = self.conv_higher_res(higher_res_feature)
+ out = higher_res_feature + lower_res_feature
+ return self.relu(out)
+class FastSCNN(nn.Module):
+ """Fast-SCNN Backbone.
+ Args:
+ in_channels (int): Number of input image channels. Default: 3.
+ downsample_dw_channels (tuple[int]): Number of output channels after
+ the first conv layer & the second conv layer in
+ Learning-To-Downsample (LTD) module.
+ Default: (32, 48).
+ global_in_channels (int): Number of input channels of
+ Global Feature Extractor(GFE).
+ Equal to number of output channels of LTD.
+ Default: 64.
+ global_block_channels (tuple[int]): Tuple of integers that describe
+ the output channels for each of the MobileNet-v2 bottleneck
+ residual blocks in GFE.
+ Default: (64, 96, 128).
+ global_block_strides (tuple[int]): Tuple of integers
+ that describe the strides (downsampling factors) for each of the
+ MobileNet-v2 bottleneck residual blocks in GFE.
+ Default: (2, 2, 1).
+ global_out_channels (int): Number of output channels of GFE.
+ Default: 128.
+ higher_in_channels (int): Number of input channels of the higher
+ resolution branch in FFM.
+ Equal to global_in_channels.
+ Default: 64.
+ lower_in_channels (int): Number of input channels of the lower
+ resolution branch in FFM.
+ Equal to global_out_channels.
+ Default: 128.
+ fusion_out_channels (int): Number of output channels of FFM.
+ Default: 128.
+ out_indices (tuple): Tuple of indices of list
+ [higher_res_features, lower_res_features, fusion_output].
+ Often set to (0,1,2) to enable aux. heads.
+ Default: (0, 1, 2).
+ conv_cfg (dict | None): Config of conv layers. Default: None
+ norm_cfg (dict | None): Config of norm layers. Default:
+ dict(type='BN')
+ act_cfg (dict): Config of activation layers. Default:
+ dict(type='ReLU')
+ align_corners (bool): align_corners argument of F.interpolate.
+ Default: False
+ """
+ def __init__(self,
+ in_channels=3,
+ downsample_dw_channels=(32, 48),
+ global_in_channels=64,
+ global_block_channels=(64, 96, 128),
+ global_block_strides=(2, 2, 1),
+ global_out_channels=128,
+ higher_in_channels=64,
+ lower_in_channels=128,
+ fusion_out_channels=128,
+ out_indices=(0, 1, 2),
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ align_corners=False):
+ super(FastSCNN, self).__init__()
+ if global_in_channels != higher_in_channels:
+ raise AssertionError('Global Input Channels must be the same \
+ with Higher Input Channels!')
+ elif global_out_channels != lower_in_channels:
+ raise AssertionError('Global Output Channels must be the same \
+ with Lower Input Channels!')
+ self.in_channels = in_channels
+ self.downsample_dw_channels1 = downsample_dw_channels[0]
+ self.downsample_dw_channels2 = downsample_dw_channels[1]
+ self.global_in_channels = global_in_channels
+ self.global_block_channels = global_block_channels
+ self.global_block_strides = global_block_strides
+ self.global_out_channels = global_out_channels
+ self.higher_in_channels = higher_in_channels
+ self.lower_in_channels = lower_in_channels
+ self.fusion_out_channels = fusion_out_channels
+ self.out_indices = out_indices
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.align_corners = align_corners
+ self.learning_to_downsample = LearningToDownsample(
+ in_channels,
+ downsample_dw_channels,
+ global_in_channels,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.global_feature_extractor = GlobalFeatureExtractor(
+ global_in_channels,
+ global_block_channels,
+ global_out_channels,
+ strides=self.global_block_strides,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ align_corners=self.align_corners)
+ self.feature_fusion = FeatureFusionModule(
+ higher_in_channels,
+ lower_in_channels,
+ fusion_out_channels,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ align_corners=self.align_corners)
+ def init_weights(self, pretrained=None):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+ constant_init(m, 1)
+ def forward(self, x):
+ higher_res_features = self.learning_to_downsample(x)
+ lower_res_features = self.global_feature_extractor(higher_res_features)
+ fusion_output = self.feature_fusion(higher_res_features,
+ lower_res_features)
+ outs = [higher_res_features, lower_res_features, fusion_output]
+ outs = [outs[i] for i in self.out_indices]
+ return tuple(outs)
diff --git a/ControlNet/annotator/uniformer/mmseg/models/backbones/hrnet.py b/ControlNet/annotator/uniformer/mmseg/models/backbones/hrnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..331ebf3ccb8597b3f507670753789073fc3c946d
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/backbones/hrnet.py
@@ -0,0 +1,555 @@
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
+ kaiming_init)
+from annotator.uniformer.mmcv.runner import load_checkpoint
+from annotator.uniformer.mmcv.utils.parrots_wrapper import _BatchNorm
+from annotator.uniformer.mmseg.ops import Upsample, resize
+from annotator.uniformer.mmseg.utils import get_root_logger
+from ..builder import BACKBONES
+from .resnet import BasicBlock, Bottleneck
+class HRModule(nn.Module):
+ """High-Resolution Module for HRNet.
+ In this module, every branch has 4 BasicBlocks/Bottlenecks. Fusion/Exchange
+ is in this module.
+ """
+ def __init__(self,
+ num_branches,
+ blocks,
+ num_blocks,
+ in_channels,
+ num_channels,
+ multiscale_output=True,
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True)):
+ super(HRModule, self).__init__()
+ self._check_branches(num_branches, num_blocks, in_channels,
+ num_channels)
+ self.in_channels = in_channels
+ self.num_branches = num_branches
+ self.multiscale_output = multiscale_output
+ self.norm_cfg = norm_cfg
+ self.conv_cfg = conv_cfg
+ self.with_cp = with_cp
+ self.branches = self._make_branches(num_branches, blocks, num_blocks,
+ num_channels)
+ self.fuse_layers = self._make_fuse_layers()
+ self.relu = nn.ReLU(inplace=False)
+ def _check_branches(self, num_branches, num_blocks, in_channels,
+ num_channels):
+ """Check branches configuration."""
+ if num_branches != len(num_blocks):
+ error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_BLOCKS(' \
+ f'{len(num_blocks)})'
+ raise ValueError(error_msg)
+ if num_branches != len(num_channels):
+ error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_CHANNELS(' \
+ f'{len(num_channels)})'
+ raise ValueError(error_msg)
+ if num_branches != len(in_channels):
+ error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_INCHANNELS(' \
+ f'{len(in_channels)})'
+ raise ValueError(error_msg)
+ def _make_one_branch(self,
+ branch_index,
+ block,
+ num_blocks,
+ num_channels,
+ stride=1):
+ """Build one branch."""
+ downsample = None
+ if stride != 1 or \
+ self.in_channels[branch_index] != \
+ num_channels[branch_index] * block.expansion:
+ downsample = nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ self.in_channels[branch_index],
+ num_channels[branch_index] * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias=False),
+ build_norm_layer(self.norm_cfg, num_channels[branch_index] *
+ block.expansion)[1])
+ layers = []
+ layers.append(
+ block(
+ self.in_channels[branch_index],
+ num_channels[branch_index],
+ stride,
+ downsample=downsample,
+ with_cp=self.with_cp,
+ norm_cfg=self.norm_cfg,
+ conv_cfg=self.conv_cfg))
+ self.in_channels[branch_index] = \
+ num_channels[branch_index] * block.expansion
+ for i in range(1, num_blocks[branch_index]):
+ layers.append(
+ block(
+ self.in_channels[branch_index],
+ num_channels[branch_index],
+ with_cp=self.with_cp,
+ norm_cfg=self.norm_cfg,
+ conv_cfg=self.conv_cfg))
+ return nn.Sequential(*layers)
+ def _make_branches(self, num_branches, block, num_blocks, num_channels):
+ """Build multiple branch."""
+ branches = []
+ for i in range(num_branches):
+ branches.append(
+ self._make_one_branch(i, block, num_blocks, num_channels))
+ return nn.ModuleList(branches)
+ def _make_fuse_layers(self):
+ """Build fuse layer."""
+ if self.num_branches == 1:
+ return None
+ num_branches = self.num_branches
+ in_channels = self.in_channels
+ fuse_layers = []
+ num_out_branches = num_branches if self.multiscale_output else 1
+ for i in range(num_out_branches):
+ fuse_layer = []
+ for j in range(num_branches):
+ if j > i:
+ fuse_layer.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels[j],
+ in_channels[i],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False),
+ build_norm_layer(self.norm_cfg, in_channels[i])[1],
+ # we set align_corners=False for HRNet
+ Upsample(
+ scale_factor=2**(j - i),
+ mode='bilinear',
+ align_corners=False)))
+ elif j == i:
+ fuse_layer.append(None)
+ else:
+ conv_downsamples = []
+ for k in range(i - j):
+ if k == i - j - 1:
+ conv_downsamples.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels[j],
+ in_channels[i],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg,
+ in_channels[i])[1]))
+ else:
+ conv_downsamples.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels[j],
+ in_channels[j],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg,
+ in_channels[j])[1],
+ nn.ReLU(inplace=False)))
+ fuse_layer.append(nn.Sequential(*conv_downsamples))
+ fuse_layers.append(nn.ModuleList(fuse_layer))
+ return nn.ModuleList(fuse_layers)
+ def forward(self, x):
+ """Forward function."""
+ if self.num_branches == 1:
+ return [self.branches[0](x[0])]
+ for i in range(self.num_branches):
+ x[i] = self.branches[i](x[i])
+ x_fuse = []
+ for i in range(len(self.fuse_layers)):
+ y = 0
+ for j in range(self.num_branches):
+ if i == j:
+ y += x[j]
+ elif j > i:
+ y = y + resize(
+ self.fuse_layers[i][j](x[j]),
+ size=x[i].shape[2:],
+ mode='bilinear',
+ align_corners=False)
+ else:
+ y += self.fuse_layers[i][j](x[j])
+ x_fuse.append(self.relu(y))
+ return x_fuse
+class HRNet(nn.Module):
+ """HRNet backbone.
+ High-Resolution Representations for Labeling Pixels and Regions
+ arXiv: https://arxiv.org/abs/1904.04514
+ Args:
+ extra (dict): detailed configuration for each stage of HRNet.
+ in_channels (int): Number of input image channels. Normally 3.
+ conv_cfg (dict): dictionary to construct and config conv layer.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ zero_init_residual (bool): whether to use zero init for last norm layer
+ in resblocks to let them behave as identity.
+ Example:
+ >>> from annotator.uniformer.mmseg.models import HRNet
+ >>> import torch
+ >>> extra = dict(
+ >>> stage1=dict(
+ >>> num_modules=1,
+ >>> num_branches=1,
+ >>> block='BOTTLENECK',
+ >>> num_blocks=(4, ),
+ >>> num_channels=(64, )),
+ >>> stage2=dict(
+ >>> num_modules=1,
+ >>> num_branches=2,
+ >>> block='BASIC',
+ >>> num_blocks=(4, 4),
+ >>> num_channels=(32, 64)),
+ >>> stage3=dict(
+ >>> num_modules=4,
+ >>> num_branches=3,
+ >>> block='BASIC',
+ >>> num_blocks=(4, 4, 4),
+ >>> num_channels=(32, 64, 128)),
+ >>> stage4=dict(
+ >>> num_modules=3,
+ >>> num_branches=4,
+ >>> block='BASIC',
+ >>> num_blocks=(4, 4, 4, 4),
+ >>> num_channels=(32, 64, 128, 256)))
+ >>> self = HRNet(extra, in_channels=1)
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 1, 32, 32)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ (1, 32, 8, 8)
+ (1, 64, 4, 4)
+ (1, 128, 2, 2)
+ (1, 256, 1, 1)
+ """
+ blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}
+ def __init__(self,
+ extra,
+ in_channels=3,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ norm_eval=False,
+ with_cp=False,
+ zero_init_residual=False):
+ super(HRNet, self).__init__()
+ self.extra = extra
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.norm_eval = norm_eval
+ self.with_cp = with_cp
+ self.zero_init_residual = zero_init_residual
+ # stem net
+ self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
+ self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2)
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ in_channels,
+ 64,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False)
+ self.add_module(self.norm1_name, norm1)
+ self.conv2 = build_conv_layer(
+ self.conv_cfg,
+ 64,
+ 64,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False)
+ self.add_module(self.norm2_name, norm2)
+ self.relu = nn.ReLU(inplace=True)
+ # stage 1
+ self.stage1_cfg = self.extra['stage1']
+ num_channels = self.stage1_cfg['num_channels'][0]
+ block_type = self.stage1_cfg['block']
+ num_blocks = self.stage1_cfg['num_blocks'][0]
+ block = self.blocks_dict[block_type]
+ stage1_out_channels = num_channels * block.expansion
+ self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
+ # stage 2
+ self.stage2_cfg = self.extra['stage2']
+ num_channels = self.stage2_cfg['num_channels']
+ block_type = self.stage2_cfg['block']
+ block = self.blocks_dict[block_type]
+ num_channels = [channel * block.expansion for channel in num_channels]
+ self.transition1 = self._make_transition_layer([stage1_out_channels],
+ num_channels)
+ self.stage2, pre_stage_channels = self._make_stage(
+ self.stage2_cfg, num_channels)
+ # stage 3
+ self.stage3_cfg = self.extra['stage3']
+ num_channels = self.stage3_cfg['num_channels']
+ block_type = self.stage3_cfg['block']
+ block = self.blocks_dict[block_type]
+ num_channels = [channel * block.expansion for channel in num_channels]
+ self.transition2 = self._make_transition_layer(pre_stage_channels,
+ num_channels)
+ self.stage3, pre_stage_channels = self._make_stage(
+ self.stage3_cfg, num_channels)
+ # stage 4
+ self.stage4_cfg = self.extra['stage4']
+ num_channels = self.stage4_cfg['num_channels']
+ block_type = self.stage4_cfg['block']
+ block = self.blocks_dict[block_type]
+ num_channels = [channel * block.expansion for channel in num_channels]
+ self.transition3 = self._make_transition_layer(pre_stage_channels,
+ num_channels)
+ self.stage4, pre_stage_channels = self._make_stage(
+ self.stage4_cfg, num_channels)
+ @property
+ def norm1(self):
+ """nn.Module: the normalization layer named "norm1" """
+ return getattr(self, self.norm1_name)
+ @property
+ def norm2(self):
+ """nn.Module: the normalization layer named "norm2" """
+ return getattr(self, self.norm2_name)
+ def _make_transition_layer(self, num_channels_pre_layer,
+ num_channels_cur_layer):
+ """Make transition layer."""
+ num_branches_cur = len(num_channels_cur_layer)
+ num_branches_pre = len(num_channels_pre_layer)
+ transition_layers = []
+ for i in range(num_branches_cur):
+ if i < num_branches_pre:
+ if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
+ transition_layers.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ num_channels_pre_layer[i],
+ num_channels_cur_layer[i],
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg,
+ num_channels_cur_layer[i])[1],
+ nn.ReLU(inplace=True)))
+ else:
+ transition_layers.append(None)
+ else:
+ conv_downsamples = []
+ for j in range(i + 1 - num_branches_pre):
+ in_channels = num_channels_pre_layer[-1]
+ out_channels = num_channels_cur_layer[i] \
+ if j == i - num_branches_pre else in_channels
+ conv_downsamples.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg, out_channels)[1],
+ nn.ReLU(inplace=True)))
+ transition_layers.append(nn.Sequential(*conv_downsamples))
+ return nn.ModuleList(transition_layers)
+ def _make_layer(self, block, inplanes, planes, blocks, stride=1):
+ """Make each layer."""
+ downsample = None
+ if stride != 1 or inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias=False),
+ build_norm_layer(self.norm_cfg, planes * block.expansion)[1])
+ layers = []
+ layers.append(
+ block(
+ inplanes,
+ planes,
+ stride,
+ downsample=downsample,
+ with_cp=self.with_cp,
+ norm_cfg=self.norm_cfg,
+ conv_cfg=self.conv_cfg))
+ inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(
+ block(
+ inplanes,
+ planes,
+ with_cp=self.with_cp,
+ norm_cfg=self.norm_cfg,
+ conv_cfg=self.conv_cfg))
+ return nn.Sequential(*layers)
+ def _make_stage(self, layer_config, in_channels, multiscale_output=True):
+ """Make each stage."""
+ num_modules = layer_config['num_modules']
+ num_branches = layer_config['num_branches']
+ num_blocks = layer_config['num_blocks']
+ num_channels = layer_config['num_channels']
+ block = self.blocks_dict[layer_config['block']]
+ hr_modules = []
+ for i in range(num_modules):
+ # multi_scale_output is only used for the last module
+ if not multiscale_output and i == num_modules - 1:
+ reset_multiscale_output = False
+ else:
+ reset_multiscale_output = True
+ hr_modules.append(
+ HRModule(
+ num_branches,
+ block,
+ num_blocks,
+ in_channels,
+ num_channels,
+ reset_multiscale_output,
+ with_cp=self.with_cp,
+ norm_cfg=self.norm_cfg,
+ conv_cfg=self.conv_cfg))
+ return nn.Sequential(*hr_modules), in_channels
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in backbone.
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ if isinstance(pretrained, str):
+ logger = get_root_logger()
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+ constant_init(m, 1)
+ if self.zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, Bottleneck):
+ constant_init(m.norm3, 0)
+ elif isinstance(m, BasicBlock):
+ constant_init(m.norm2, 0)
+ else:
+ raise TypeError('pretrained must be a str or None')
+ def forward(self, x):
+ """Forward function."""
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu(x)
+ x = self.conv2(x)
+ x = self.norm2(x)
+ x = self.relu(x)
+ x = self.layer1(x)
+ x_list = []
+ for i in range(self.stage2_cfg['num_branches']):
+ if self.transition1[i] is not None:
+ x_list.append(self.transition1[i](x))
+ else:
+ x_list.append(x)
+ y_list = self.stage2(x_list)
+ x_list = []
+ for i in range(self.stage3_cfg['num_branches']):
+ if self.transition2[i] is not None:
+ x_list.append(self.transition2[i](y_list[-1]))
+ else:
+ x_list.append(y_list[i])
+ y_list = self.stage3(x_list)
+ x_list = []
+ for i in range(self.stage4_cfg['num_branches']):
+ if self.transition3[i] is not None:
+ x_list.append(self.transition3[i](y_list[-1]))
+ else:
+ x_list.append(y_list[i])
+ y_list = self.stage4(x_list)
+ return y_list
+ def train(self, mode=True):
+ """Convert the model into training mode will keeping the normalization
+ layer freezed."""
+ super(HRNet, self).train(mode)
+ if mode and self.norm_eval:
+ for m in self.modules():
+ # trick: eval have effect on BatchNorm only
+ if isinstance(m, _BatchNorm):
+ m.eval()
diff --git a/ControlNet/annotator/uniformer/mmseg/models/backbones/mobilenet_v2.py b/ControlNet/annotator/uniformer/mmseg/models/backbones/mobilenet_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab6b3791692a0d1b5da3601875711710b7bd01ba
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/backbones/mobilenet_v2.py
@@ -0,0 +1,180 @@
+import logging
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import ConvModule, constant_init, kaiming_init
+from annotator.uniformer.mmcv.runner import load_checkpoint
+from torch.nn.modules.batchnorm import _BatchNorm
+from ..builder import BACKBONES
+from ..utils import InvertedResidual, make_divisible
+class MobileNetV2(nn.Module):
+ """MobileNetV2 backbone.
+ Args:
+ widen_factor (float): Width multiplier, multiply number of
+ channels in each layer by this amount. Default: 1.0.
+ strides (Sequence[int], optional): Strides of the first block of each
+ layer. If not specified, default config in ``arch_setting`` will
+ be used.
+ dilations (Sequence[int]): Dilation of each layer.
+ out_indices (None or Sequence[int]): Output from which stages.
+ Default: (7, ).
+ frozen_stages (int): Stages to be frozen (all param fixed).
+ Default: -1, which means not freezing any parameters.
+ conv_cfg (dict): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='ReLU6').
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ """
+ # Parameters to build layers. 3 parameters are needed to construct a
+ # layer, from left to right: expand_ratio, channel, num_blocks.
+ arch_settings = [[1, 16, 1], [6, 24, 2], [6, 32, 3], [6, 64, 4],
+ [6, 96, 3], [6, 160, 3], [6, 320, 1]]
+ def __init__(self,
+ widen_factor=1.,
+ strides=(1, 2, 2, 2, 1, 2, 1),
+ dilations=(1, 1, 1, 1, 1, 1, 1),
+ out_indices=(1, 2, 4, 6),
+ frozen_stages=-1,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU6'),
+ norm_eval=False,
+ with_cp=False):
+ super(MobileNetV2, self).__init__()
+ self.widen_factor = widen_factor
+ self.strides = strides
+ self.dilations = dilations
+ assert len(strides) == len(dilations) == len(self.arch_settings)
+ self.out_indices = out_indices
+ for index in out_indices:
+ if index not in range(0, 7):
+ raise ValueError('the item in out_indices must in '
+ f'range(0, 8). But received {index}')
+ if frozen_stages not in range(-1, 7):
+ raise ValueError('frozen_stages must be in range(-1, 7). '
+ f'But received {frozen_stages}')
+ self.out_indices = out_indices
+ self.frozen_stages = frozen_stages
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.norm_eval = norm_eval
+ self.with_cp = with_cp
+ self.in_channels = make_divisible(32 * widen_factor, 8)
+ self.conv1 = ConvModule(
+ in_channels=3,
+ out_channels=self.in_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.layers = []
+ for i, layer_cfg in enumerate(self.arch_settings):
+ expand_ratio, channel, num_blocks = layer_cfg
+ stride = self.strides[i]
+ dilation = self.dilations[i]
+ out_channels = make_divisible(channel * widen_factor, 8)
+ inverted_res_layer = self.make_layer(
+ out_channels=out_channels,
+ num_blocks=num_blocks,
+ stride=stride,
+ dilation=dilation,
+ expand_ratio=expand_ratio)
+ layer_name = f'layer{i + 1}'
+ self.add_module(layer_name, inverted_res_layer)
+ self.layers.append(layer_name)
+ def make_layer(self, out_channels, num_blocks, stride, dilation,
+ expand_ratio):
+ """Stack InvertedResidual blocks to build a layer for MobileNetV2.
+ Args:
+ out_channels (int): out_channels of block.
+ num_blocks (int): Number of blocks.
+ stride (int): Stride of the first block.
+ dilation (int): Dilation of the first block.
+ expand_ratio (int): Expand the number of channels of the
+ hidden layer in InvertedResidual by this ratio.
+ """
+ layers = []
+ for i in range(num_blocks):
+ layers.append(
+ InvertedResidual(
+ self.in_channels,
+ out_channels,
+ stride if i == 0 else 1,
+ expand_ratio=expand_ratio,
+ dilation=dilation if i == 0 else 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ with_cp=self.with_cp))
+ self.in_channels = out_channels
+ return nn.Sequential(*layers)
+ def init_weights(self, pretrained=None):
+ if isinstance(pretrained, str):
+ logger = logging.getLogger()
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+ constant_init(m, 1)
+ else:
+ raise TypeError('pretrained must be a str or None')
+ def forward(self, x):
+ x = self.conv1(x)
+ outs = []
+ for i, layer_name in enumerate(self.layers):
+ layer = getattr(self, layer_name)
+ x = layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+ if len(outs) == 1:
+ return outs[0]
+ else:
+ return tuple(outs)
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ for param in self.conv1.parameters():
+ param.requires_grad = False
+ for i in range(1, self.frozen_stages + 1):
+ layer = getattr(self, f'layer{i}')
+ layer.eval()
+ for param in layer.parameters():
+ param.requires_grad = False
+ def train(self, mode=True):
+ super(MobileNetV2, self).train(mode)
+ self._freeze_stages()
+ if mode and self.norm_eval:
+ for m in self.modules():
+ if isinstance(m, _BatchNorm):
+ m.eval()
diff --git a/ControlNet/annotator/uniformer/mmseg/models/backbones/mobilenet_v3.py b/ControlNet/annotator/uniformer/mmseg/models/backbones/mobilenet_v3.py
new file mode 100644
index 0000000000000000000000000000000000000000..16817400b4102899794fe64c9644713a4e54e2f9
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/backbones/mobilenet_v3.py
@@ -0,0 +1,255 @@
+import logging
+import annotator.uniformer.mmcv as mmcv
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import ConvModule, constant_init, kaiming_init
+from annotator.uniformer.mmcv.cnn.bricks import Conv2dAdaptivePadding
+from annotator.uniformer.mmcv.runner import load_checkpoint
+from torch.nn.modules.batchnorm import _BatchNorm
+from ..builder import BACKBONES
+from ..utils import InvertedResidualV3 as InvertedResidual
+class MobileNetV3(nn.Module):
+ """MobileNetV3 backbone.
+ This backbone is the improved implementation of `Searching for MobileNetV3
+ `_.
+ Args:
+ arch (str): Architecture of mobilnetv3, from {'small', 'large'}.
+ Default: 'small'.
+ conv_cfg (dict): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN').
+ out_indices (tuple[int]): Output from which layer.
+ Default: (0, 1, 12).
+ frozen_stages (int): Stages to be frozen (all param fixed).
+ Default: -1, which means not freezing any parameters.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save
+ some memory while slowing down the training speed.
+ Default: False.
+ """
+ # Parameters to build each block:
+ # [kernel size, mid channels, out channels, with_se, act type, stride]
+ arch_settings = {
+ 'small': [[3, 16, 16, True, 'ReLU', 2], # block0 layer1 os=4
+ [3, 72, 24, False, 'ReLU', 2], # block1 layer2 os=8
+ [3, 88, 24, False, 'ReLU', 1],
+ [5, 96, 40, True, 'HSwish', 2], # block2 layer4 os=16
+ [5, 240, 40, True, 'HSwish', 1],
+ [5, 240, 40, True, 'HSwish', 1],
+ [5, 120, 48, True, 'HSwish', 1], # block3 layer7 os=16
+ [5, 144, 48, True, 'HSwish', 1],
+ [5, 288, 96, True, 'HSwish', 2], # block4 layer9 os=32
+ [5, 576, 96, True, 'HSwish', 1],
+ [5, 576, 96, True, 'HSwish', 1]],
+ 'large': [[3, 16, 16, False, 'ReLU', 1], # block0 layer1 os=2
+ [3, 64, 24, False, 'ReLU', 2], # block1 layer2 os=4
+ [3, 72, 24, False, 'ReLU', 1],
+ [5, 72, 40, True, 'ReLU', 2], # block2 layer4 os=8
+ [5, 120, 40, True, 'ReLU', 1],
+ [5, 120, 40, True, 'ReLU', 1],
+ [3, 240, 80, False, 'HSwish', 2], # block3 layer7 os=16
+ [3, 200, 80, False, 'HSwish', 1],
+ [3, 184, 80, False, 'HSwish', 1],
+ [3, 184, 80, False, 'HSwish', 1],
+ [3, 480, 112, True, 'HSwish', 1], # block4 layer11 os=16
+ [3, 672, 112, True, 'HSwish', 1],
+ [5, 672, 160, True, 'HSwish', 2], # block5 layer13 os=32
+ [5, 960, 160, True, 'HSwish', 1],
+ [5, 960, 160, True, 'HSwish', 1]]
+ } # yapf: disable
+ def __init__(self,
+ arch='small',
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ out_indices=(0, 1, 12),
+ frozen_stages=-1,
+ reduction_factor=1,
+ norm_eval=False,
+ with_cp=False):
+ super(MobileNetV3, self).__init__()
+ assert arch in self.arch_settings
+ assert isinstance(reduction_factor, int) and reduction_factor > 0
+ assert mmcv.is_tuple_of(out_indices, int)
+ for index in out_indices:
+ if index not in range(0, len(self.arch_settings[arch]) + 2):
+ raise ValueError(
+ 'the item in out_indices must in '
+ f'range(0, {len(self.arch_settings[arch])+2}). '
+ f'But received {index}')
+ if frozen_stages not in range(-1, len(self.arch_settings[arch]) + 2):
+ raise ValueError('frozen_stages must be in range(-1, '
+ f'{len(self.arch_settings[arch])+2}). '
+ f'But received {frozen_stages}')
+ self.arch = arch
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.out_indices = out_indices
+ self.frozen_stages = frozen_stages
+ self.reduction_factor = reduction_factor
+ self.norm_eval = norm_eval
+ self.with_cp = with_cp
+ self.layers = self._make_layer()
+ def _make_layer(self):
+ layers = []
+ # build the first layer (layer0)
+ in_channels = 16
+ layer = ConvModule(
+ in_channels=3,
+ out_channels=in_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ conv_cfg=dict(type='Conv2dAdaptivePadding'),
+ norm_cfg=self.norm_cfg,
+ act_cfg=dict(type='HSwish'))
+ self.add_module('layer0', layer)
+ layers.append('layer0')
+ layer_setting = self.arch_settings[self.arch]
+ for i, params in enumerate(layer_setting):
+ (kernel_size, mid_channels, out_channels, with_se, act,
+ stride) = params
+ if self.arch == 'large' and i >= 12 or self.arch == 'small' and \
+ i >= 8:
+ mid_channels = mid_channels // self.reduction_factor
+ out_channels = out_channels // self.reduction_factor
+ if with_se:
+ se_cfg = dict(
+ channels=mid_channels,
+ ratio=4,
+ act_cfg=(dict(type='ReLU'),
+ dict(type='HSigmoid', bias=3.0, divisor=6.0)))
+ else:
+ se_cfg = None
+ layer = InvertedResidual(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ mid_channels=mid_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ se_cfg=se_cfg,
+ with_expand_conv=(in_channels != mid_channels),
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=dict(type=act),
+ with_cp=self.with_cp)
+ in_channels = out_channels
+ layer_name = 'layer{}'.format(i + 1)
+ self.add_module(layer_name, layer)
+ layers.append(layer_name)
+ # build the last layer
+ # block5 layer12 os=32 for small model
+ # block6 layer16 os=32 for large model
+ layer = ConvModule(
+ in_channels=in_channels,
+ out_channels=576 if self.arch == 'small' else 960,
+ kernel_size=1,
+ stride=1,
+ dilation=4,
+ padding=0,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=dict(type='HSwish'))
+ layer_name = 'layer{}'.format(len(layer_setting) + 1)
+ self.add_module(layer_name, layer)
+ layers.append(layer_name)
+ # next, convert backbone MobileNetV3 to a semantic segmentation version
+ if self.arch == 'small':
+ self.layer4.depthwise_conv.conv.stride = (1, 1)
+ self.layer9.depthwise_conv.conv.stride = (1, 1)
+ for i in range(4, len(layers)):
+ layer = getattr(self, layers[i])
+ if isinstance(layer, InvertedResidual):
+ modified_module = layer.depthwise_conv.conv
+ else:
+ modified_module = layer.conv
+ if i < 9:
+ modified_module.dilation = (2, 2)
+ pad = 2
+ else:
+ modified_module.dilation = (4, 4)
+ pad = 4
+ if not isinstance(modified_module, Conv2dAdaptivePadding):
+ # Adjust padding
+ pad *= (modified_module.kernel_size[0] - 1) // 2
+ modified_module.padding = (pad, pad)
+ else:
+ self.layer7.depthwise_conv.conv.stride = (1, 1)
+ self.layer13.depthwise_conv.conv.stride = (1, 1)
+ for i in range(7, len(layers)):
+ layer = getattr(self, layers[i])
+ if isinstance(layer, InvertedResidual):
+ modified_module = layer.depthwise_conv.conv
+ else:
+ modified_module = layer.conv
+ if i < 13:
+ modified_module.dilation = (2, 2)
+ pad = 2
+ else:
+ modified_module.dilation = (4, 4)
+ pad = 4
+ if not isinstance(modified_module, Conv2dAdaptivePadding):
+ # Adjust padding
+ pad *= (modified_module.kernel_size[0] - 1) // 2
+ modified_module.padding = (pad, pad)
+ return layers
+ def init_weights(self, pretrained=None):
+ if isinstance(pretrained, str):
+ logger = logging.getLogger()
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, nn.BatchNorm2d):
+ constant_init(m, 1)
+ else:
+ raise TypeError('pretrained must be a str or None')
+ def forward(self, x):
+ outs = []
+ for i, layer_name in enumerate(self.layers):
+ layer = getattr(self, layer_name)
+ x = layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+ return outs
+ def _freeze_stages(self):
+ for i in range(self.frozen_stages + 1):
+ layer = getattr(self, f'layer{i}')
+ layer.eval()
+ for param in layer.parameters():
+ param.requires_grad = False
+ def train(self, mode=True):
+ super(MobileNetV3, self).train(mode)
+ self._freeze_stages()
+ if mode and self.norm_eval:
+ for m in self.modules():
+ if isinstance(m, _BatchNorm):
+ m.eval()
diff --git a/ControlNet/annotator/uniformer/mmseg/models/backbones/resnest.py b/ControlNet/annotator/uniformer/mmseg/models/backbones/resnest.py
new file mode 100644
index 0000000000000000000000000000000000000000..b45a837f395230029e9d4194ff9f7f2f8f7067b0
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/backbones/resnest.py
@@ -0,0 +1,314 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as cp
+from annotator.uniformer.mmcv.cnn import build_conv_layer, build_norm_layer
+from ..builder import BACKBONES
+from ..utils import ResLayer
+from .resnet import Bottleneck as _Bottleneck
+from .resnet import ResNetV1d
+class RSoftmax(nn.Module):
+ """Radix Softmax module in ``SplitAttentionConv2d``.
+ Args:
+ radix (int): Radix of input.
+ groups (int): Groups of input.
+ """
+ def __init__(self, radix, groups):
+ super().__init__()
+ self.radix = radix
+ self.groups = groups
+ def forward(self, x):
+ batch = x.size(0)
+ if self.radix > 1:
+ x = x.view(batch, self.groups, self.radix, -1).transpose(1, 2)
+ x = F.softmax(x, dim=1)
+ x = x.reshape(batch, -1)
+ else:
+ x = torch.sigmoid(x)
+ return x
+class SplitAttentionConv2d(nn.Module):
+ """Split-Attention Conv2d in ResNeSt.
+ Args:
+ in_channels (int): Same as nn.Conv2d.
+ out_channels (int): Same as nn.Conv2d.
+ kernel_size (int | tuple[int]): Same as nn.Conv2d.
+ stride (int | tuple[int]): Same as nn.Conv2d.
+ padding (int | tuple[int]): Same as nn.Conv2d.
+ dilation (int | tuple[int]): Same as nn.Conv2d.
+ groups (int): Same as nn.Conv2d.
+ radix (int): Radix of SpltAtConv2d. Default: 2
+ reduction_factor (int): Reduction factor of inter_channels. Default: 4.
+ conv_cfg (dict): Config dict for convolution layer. Default: None,
+ which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer. Default: None.
+ dcn (dict): Config dict for DCN. Default: None.
+ """
+ def __init__(self,
+ in_channels,
+ channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ radix=2,
+ reduction_factor=4,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ dcn=None):
+ super(SplitAttentionConv2d, self).__init__()
+ inter_channels = max(in_channels * radix // reduction_factor, 32)
+ self.radix = radix
+ self.groups = groups
+ self.channels = channels
+ self.with_dcn = dcn is not None
+ self.dcn = dcn
+ fallback_on_stride = False
+ if self.with_dcn:
+ fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
+ if self.with_dcn and not fallback_on_stride:
+ assert conv_cfg is None, 'conv_cfg must be None for DCN'
+ conv_cfg = dcn
+ self.conv = build_conv_layer(
+ conv_cfg,
+ in_channels,
+ channels * radix,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups * radix,
+ bias=False)
+ self.norm0_name, norm0 = build_norm_layer(
+ norm_cfg, channels * radix, postfix=0)
+ self.add_module(self.norm0_name, norm0)
+ self.relu = nn.ReLU(inplace=True)
+ self.fc1 = build_conv_layer(
+ None, channels, inter_channels, 1, groups=self.groups)
+ self.norm1_name, norm1 = build_norm_layer(
+ norm_cfg, inter_channels, postfix=1)
+ self.add_module(self.norm1_name, norm1)
+ self.fc2 = build_conv_layer(
+ None, inter_channels, channels * radix, 1, groups=self.groups)
+ self.rsoftmax = RSoftmax(radix, groups)
+ @property
+ def norm0(self):
+ """nn.Module: the normalization layer named "norm0" """
+ return getattr(self, self.norm0_name)
+ @property
+ def norm1(self):
+ """nn.Module: the normalization layer named "norm1" """
+ return getattr(self, self.norm1_name)
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.norm0(x)
+ x = self.relu(x)
+ batch, rchannel = x.shape[:2]
+ batch = x.size(0)
+ if self.radix > 1:
+ splits = x.view(batch, self.radix, -1, *x.shape[2:])
+ gap = splits.sum(dim=1)
+ else:
+ gap = x
+ gap = F.adaptive_avg_pool2d(gap, 1)
+ gap = self.fc1(gap)
+ gap = self.norm1(gap)
+ gap = self.relu(gap)
+ atten = self.fc2(gap)
+ atten = self.rsoftmax(atten).view(batch, -1, 1, 1)
+ if self.radix > 1:
+ attens = atten.view(batch, self.radix, -1, *atten.shape[2:])
+ out = torch.sum(attens * splits, dim=1)
+ else:
+ out = atten * x
+ return out.contiguous()
+class Bottleneck(_Bottleneck):
+ """Bottleneck block for ResNeSt.
+ Args:
+ inplane (int): Input planes of this block.
+ planes (int): Middle planes of this block.
+ groups (int): Groups of conv2.
+ width_per_group (int): Width per group of conv2. 64x4d indicates
+ ``groups=64, width_per_group=4`` and 32x8d indicates
+ ``groups=32, width_per_group=8``.
+ radix (int): Radix of SpltAtConv2d. Default: 2
+ reduction_factor (int): Reduction factor of inter_channels in
+ SplitAttentionConv2d. Default: 4.
+ avg_down_stride (bool): Whether to use average pool for stride in
+ Bottleneck. Default: True.
+ kwargs (dict): Key word arguments for base class.
+ """
+ expansion = 4
+ def __init__(self,
+ inplanes,
+ planes,
+ groups=1,
+ base_width=4,
+ base_channels=64,
+ radix=2,
+ reduction_factor=4,
+ avg_down_stride=True,
+ **kwargs):
+ """Bottleneck block for ResNeSt."""
+ super(Bottleneck, self).__init__(inplanes, planes, **kwargs)
+ if groups == 1:
+ width = self.planes
+ else:
+ width = math.floor(self.planes *
+ (base_width / base_channels)) * groups
+ self.avg_down_stride = avg_down_stride and self.conv2_stride > 1
+ self.norm1_name, norm1 = build_norm_layer(
+ self.norm_cfg, width, postfix=1)
+ self.norm3_name, norm3 = build_norm_layer(
+ self.norm_cfg, self.planes * self.expansion, postfix=3)
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ self.inplanes,
+ width,
+ kernel_size=1,
+ stride=self.conv1_stride,
+ bias=False)
+ self.add_module(self.norm1_name, norm1)
+ self.with_modulated_dcn = False
+ self.conv2 = SplitAttentionConv2d(
+ width,
+ width,
+ kernel_size=3,
+ stride=1 if self.avg_down_stride else self.conv2_stride,
+ padding=self.dilation,
+ dilation=self.dilation,
+ groups=groups,
+ radix=radix,
+ reduction_factor=reduction_factor,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ dcn=self.dcn)
+ delattr(self, self.norm2_name)
+ if self.avg_down_stride:
+ self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1)
+ self.conv3 = build_conv_layer(
+ self.conv_cfg,
+ width,
+ self.planes * self.expansion,
+ kernel_size=1,
+ bias=False)
+ self.add_module(self.norm3_name, norm3)
+ def forward(self, x):
+ def _inner_forward(x):
+ identity = x
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = self.relu(out)
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv1_plugin_names)
+ out = self.conv2(out)
+ if self.avg_down_stride:
+ out = self.avd_layer(out)
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv2_plugin_names)
+ out = self.conv3(out)
+ out = self.norm3(out)
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv3_plugin_names)
+ if self.downsample is not None:
+ identity = self.downsample(x)
+ out += identity
+ return out
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+ out = self.relu(out)
+ return out
+class ResNeSt(ResNetV1d):
+ """ResNeSt backbone.
+ Args:
+ groups (int): Number of groups of Bottleneck. Default: 1
+ base_width (int): Base width of Bottleneck. Default: 4
+ radix (int): Radix of SpltAtConv2d. Default: 2
+ reduction_factor (int): Reduction factor of inter_channels in
+ SplitAttentionConv2d. Default: 4.
+ avg_down_stride (bool): Whether to use average pool for stride in
+ Bottleneck. Default: True.
+ kwargs (dict): Keyword arguments for ResNet.
+ """
+ arch_settings = {
+ 50: (Bottleneck, (3, 4, 6, 3)),
+ 101: (Bottleneck, (3, 4, 23, 3)),
+ 152: (Bottleneck, (3, 8, 36, 3)),
+ 200: (Bottleneck, (3, 24, 36, 3))
+ }
+ def __init__(self,
+ groups=1,
+ base_width=4,
+ radix=2,
+ reduction_factor=4,
+ avg_down_stride=True,
+ **kwargs):
+ self.groups = groups
+ self.base_width = base_width
+ self.radix = radix
+ self.reduction_factor = reduction_factor
+ self.avg_down_stride = avg_down_stride
+ super(ResNeSt, self).__init__(**kwargs)
+ def make_res_layer(self, **kwargs):
+ """Pack all blocks in a stage into a ``ResLayer``."""
+ return ResLayer(
+ groups=self.groups,
+ base_width=self.base_width,
+ base_channels=self.base_channels,
+ radix=self.radix,
+ reduction_factor=self.reduction_factor,
+ avg_down_stride=self.avg_down_stride,
+ **kwargs)
diff --git a/ControlNet/annotator/uniformer/mmseg/models/backbones/resnet.py b/ControlNet/annotator/uniformer/mmseg/models/backbones/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e52bf048d28ecb069db4728e5f05ad85ac53198
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/backbones/resnet.py
@@ -0,0 +1,688 @@
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from annotator.uniformer.mmcv.cnn import (build_conv_layer, build_norm_layer, build_plugin_layer,
+ constant_init, kaiming_init)
+from annotator.uniformer.mmcv.runner import load_checkpoint
+from annotator.uniformer.mmcv.utils.parrots_wrapper import _BatchNorm
+from annotator.uniformer.mmseg.utils import get_root_logger
+from ..builder import BACKBONES
+from ..utils import ResLayer
+class BasicBlock(nn.Module):
+    """Two-convolution residual block for ResNet.
+
+    Data flow: conv1 -> norm1 -> ReLU -> conv2 -> norm2, then the shortcut
+    is added in place and a final ReLU is applied.
+
+    Args:
+        inplanes (int): Input channels of the first convolution.
+        planes (int): Output channels of both convolutions.
+        stride (int): Stride of the first convolution. Default: 1.
+        dilation (int): Dilation and padding of the first convolution.
+            Default: 1.
+        downsample (callable | None): When given, applied to the block
+            input to produce the shortcut; otherwise the input itself is
+            the shortcut. Default: None.
+        style (str): Accepted but not read by this block. Default: 'pytorch'.
+        with_cp (bool): Run the residual branch under
+            ``torch.utils.checkpoint`` when the input requires grad.
+            Default: False.
+        conv_cfg (dict | None): Config handed to ``build_conv_layer``.
+        norm_cfg (dict): Config handed to ``build_norm_layer``.
+        dcn: Must be None (not implemented).
+        plugins: Must be None (not implemented).
+    """
+
+    expansion = 1
+
+    def __init__(self,
+                 inplanes,
+                 planes,
+                 stride=1,
+                 dilation=1,
+                 downsample=None,
+                 style='pytorch',
+                 with_cp=False,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 dcn=None,
+                 plugins=None):
+        super().__init__()
+        assert dcn is None, 'Not implemented yet.'
+        assert plugins is None, 'Not implemented yet.'
+
+        # Norm layers are created up front but registered right after their
+        # convolution, giving the module order conv1, norm1, conv2, norm2.
+        self.norm1_name, first_norm = build_norm_layer(
+            norm_cfg, planes, postfix=1)
+        self.norm2_name, second_norm = build_norm_layer(
+            norm_cfg, planes, postfix=2)
+
+        self.conv1 = build_conv_layer(
+            conv_cfg,
+            inplanes,
+            planes,
+            3,
+            stride=stride,
+            padding=dilation,
+            dilation=dilation,
+            bias=False)
+        self.add_module(self.norm1_name, first_norm)
+
+        self.conv2 = build_conv_layer(
+            conv_cfg, planes, planes, 3, padding=1, bias=False)
+        self.add_module(self.norm2_name, second_norm)
+
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+        self.dilation = dilation
+        self.with_cp = with_cp
+
+    @property
+    def norm1(self):
+        """nn.Module: the norm layer registered under ``norm1_name``."""
+        return getattr(self, self.norm1_name)
+
+    @property
+    def norm2(self):
+        """nn.Module: the norm layer registered under ``norm2_name``."""
+        return getattr(self, self.norm2_name)
+
+    def forward(self, x):
+        """Apply the residual branch, add the shortcut, then ReLU."""
+
+        def _residual(inp):
+            out = self.relu(self.norm1(self.conv1(inp)))
+            out = self.norm2(self.conv2(out))
+            shortcut = inp
+            if self.downsample is not None:
+                shortcut = self.downsample(inp)
+            out += shortcut
+            return out
+
+        # Checkpointing is only used when gradients flow through the input.
+        if self.with_cp and x.requires_grad:
+            summed = cp.checkpoint(_residual, x)
+        else:
+            summed = _residual(x)
+        return self.relu(summed)
+class Bottleneck(nn.Module):
+    """Bottleneck block for ResNet.
+
+    If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is
+    "caffe", the stride-two layer is the first 1x1 conv layer.
+
+    Data flow: 1x1 conv -> norm -> ReLU -> 3x3 conv -> norm -> ReLU ->
+    1x1 conv -> norm, with optional plugins after each of the three stages,
+    then the shortcut is added in place and a final ReLU is applied.
+
+    Args:
+        inplanes (int): Input channels of the block.
+        planes (int): Channels of the two inner convolutions; the block
+            outputs ``planes * expansion`` channels.
+        stride (int): Stride of the block, placed according to ``style``.
+            Default: 1.
+        dilation (int): Dilation and padding of the 3x3 conv. Default: 1.
+        downsample (callable | None): When given, applied to the block
+            input to produce the shortcut. Default: None.
+        style (str): 'pytorch' or 'caffe'. Default: 'pytorch'.
+        with_cp (bool): Run the residual branch under
+            ``torch.utils.checkpoint`` when the input requires grad.
+            Default: False.
+        conv_cfg (dict | None): Config handed to ``build_conv_layer``.
+        norm_cfg (dict): Config handed to ``build_norm_layer``.
+        dcn (dict | None): When given (and 'fallback_on_stride' is not
+            truthy), used as the config of the 3x3 conv. NOTE: the key
+            'fallback_on_stride' is popped from this dict in place.
+        plugins (list[dict] | None): Each item needs a 'cfg' dict and a
+            'position' in {'after_conv1', 'after_conv2', 'after_conv3'}.
+    """
+
+    expansion = 4
+
+    def __init__(self,
+                 inplanes,
+                 planes,
+                 stride=1,
+                 dilation=1,
+                 downsample=None,
+                 style='pytorch',
+                 with_cp=False,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 dcn=None,
+                 plugins=None):
+        super(Bottleneck, self).__init__()
+        assert style in ['pytorch', 'caffe']
+        assert dcn is None or isinstance(dcn, dict)
+        assert plugins is None or isinstance(plugins, list)
+        if plugins is not None:
+            allowed_position = ['after_conv1', 'after_conv2', 'after_conv3']
+            assert all(p['position'] in allowed_position for p in plugins)
+
+        self.inplanes = inplanes
+        self.planes = planes
+        self.stride = stride
+        self.dilation = dilation
+        self.style = style
+        self.with_cp = with_cp
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+        self.dcn = dcn
+        self.with_dcn = dcn is not None
+        self.plugins = plugins
+        self.with_plugins = plugins is not None
+
+        if self.with_plugins:
+            # collect plugins for conv1/conv2/conv3
+            self.after_conv1_plugins = [
+                plugin['cfg'] for plugin in plugins
+                if plugin['position'] == 'after_conv1'
+            ]
+            self.after_conv2_plugins = [
+                plugin['cfg'] for plugin in plugins
+                if plugin['position'] == 'after_conv2'
+            ]
+            self.after_conv3_plugins = [
+                plugin['cfg'] for plugin in plugins
+                if plugin['position'] == 'after_conv3'
+            ]
+
+        # 'pytorch' strides in the 3x3 conv, 'caffe' in the first 1x1 conv.
+        if self.style == 'pytorch':
+            self.conv1_stride = 1
+            self.conv2_stride = stride
+        else:
+            self.conv1_stride = stride
+            self.conv2_stride = 1
+
+        self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
+        self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
+        self.norm3_name, norm3 = build_norm_layer(
+            norm_cfg, planes * self.expansion, postfix=3)
+
+        self.conv1 = build_conv_layer(
+            conv_cfg,
+            inplanes,
+            planes,
+            kernel_size=1,
+            stride=self.conv1_stride,
+            bias=False)
+        self.add_module(self.norm1_name, norm1)
+        fallback_on_stride = False
+        if self.with_dcn:
+            # Removes the key from the caller's dict so the remaining
+            # entries can be handed to build_conv_layer below.
+            fallback_on_stride = dcn.pop('fallback_on_stride', False)
+        if not self.with_dcn or fallback_on_stride:
+            self.conv2 = build_conv_layer(
+                conv_cfg,
+                planes,
+                planes,
+                kernel_size=3,
+                stride=self.conv2_stride,
+                padding=dilation,
+                dilation=dilation,
+                bias=False)
+        else:
+            assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
+            self.conv2 = build_conv_layer(
+                dcn,
+                planes,
+                planes,
+                kernel_size=3,
+                stride=self.conv2_stride,
+                padding=dilation,
+                dilation=dilation,
+                bias=False)
+
+        self.add_module(self.norm2_name, norm2)
+        self.conv3 = build_conv_layer(
+            conv_cfg,
+            planes,
+            planes * self.expansion,
+            kernel_size=1,
+            bias=False)
+        self.add_module(self.norm3_name, norm3)
+
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+
+        if self.with_plugins:
+            self.after_conv1_plugin_names = self.make_block_plugins(
+                planes, self.after_conv1_plugins)
+            self.after_conv2_plugin_names = self.make_block_plugins(
+                planes, self.after_conv2_plugins)
+            self.after_conv3_plugin_names = self.make_block_plugins(
+                planes * self.expansion, self.after_conv3_plugins)
+
+    def make_block_plugins(self, in_channels, plugins):
+        """Build the given plugins and register them on this block.
+
+        Args:
+            in_channels (int): Input channels of plugin.
+            plugins (list[dict]): List of plugins cfg to build.
+
+        Returns:
+            list[str]: Names under which the plugins were registered, in
+                the order of ``plugins``.
+        """
+        assert isinstance(plugins, list)
+        plugin_names = []
+        for plugin in plugins:
+            # Work on a copy so popping 'postfix' leaves the caller's cfg
+            # untouched.
+            plugin = plugin.copy()
+            name, layer = build_plugin_layer(
+                plugin,
+                in_channels=in_channels,
+                postfix=plugin.pop('postfix', ''))
+            assert not hasattr(self, name), f'duplicate plugin {name}'
+            self.add_module(name, layer)
+            plugin_names.append(name)
+        return plugin_names
+
+    def forward_plugin(self, x, plugin_names):
+        """Run ``x`` through the named plugins in order.
+
+        Args:
+            x: Input to the first plugin.
+            plugin_names (list[str]): Attribute names of the plugins.
+
+        Returns:
+            The output of the last plugin, or ``x`` itself when
+            ``plugin_names`` is empty.
+        """
+        out = x
+        for name in plugin_names:
+            # Feed the previous plugin's output, not the original input, so
+            # several plugins at one position compose instead of the last
+            # one silently replacing the others.
+            out = getattr(self, name)(out)
+        return out
+
+    @property
+    def norm1(self):
+        """nn.Module: normalization layer after the first convolution layer"""
+        return getattr(self, self.norm1_name)
+
+    @property
+    def norm2(self):
+        """nn.Module: normalization layer after the second convolution layer"""
+        return getattr(self, self.norm2_name)
+
+    @property
+    def norm3(self):
+        """nn.Module: normalization layer after the third convolution layer"""
+        return getattr(self, self.norm3_name)
+
+    def forward(self, x):
+        """Apply the residual branch, add the shortcut, then ReLU."""
+
+        def _inner_forward(x):
+            identity = x
+
+            out = self.conv1(x)
+            out = self.norm1(out)
+            out = self.relu(out)
+
+            if self.with_plugins:
+                out = self.forward_plugin(out, self.after_conv1_plugin_names)
+
+            out = self.conv2(out)
+            out = self.norm2(out)
+            out = self.relu(out)
+
+            if self.with_plugins:
+                out = self.forward_plugin(out, self.after_conv2_plugin_names)
+
+            out = self.conv3(out)
+            out = self.norm3(out)
+
+            if self.with_plugins:
+                out = self.forward_plugin(out, self.after_conv3_plugin_names)
+
+            if self.downsample is not None:
+                identity = self.downsample(x)
+
+            out += identity
+
+            return out
+
+        # Checkpointing is only used when gradients flow through the input.
+        if self.with_cp and x.requires_grad:
+            out = cp.checkpoint(_inner_forward, x)
+        else:
+            out = _inner_forward(x)
+
+        out = self.relu(out)
+
+        return out
+class ResNet(nn.Module):
+    """ResNet backbone.
+
+    Args:
+        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
+        in_channels (int): Number of input image channels. Default: 3.
+        stem_channels (int): Number of stem channels. Default: 64.
+        base_channels (int): Number of base channels of res layer. Default: 64.
+        num_stages (int): Resnet stages, normally 4.
+        strides (Sequence[int]): Strides of the first block of each stage.
+        dilations (Sequence[int]): Dilation of each stage.
+        out_indices (Sequence[int]): Output from which stages.
+        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+            layer is the 3x3 conv layer, otherwise the stride-two layer is
+            the first 1x1 conv layer.
+        deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv
+        avg_down (bool): Use AvgPool instead of stride conv when
+            downsampling in the bottleneck.
+        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+            -1 means not freezing any parameters.
+        conv_cfg (dict | None): Config handed to ``build_conv_layer`` for
+            the stem and forwarded to every stage.
+        norm_cfg (dict): Dictionary to construct and config norm layer.
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only.
+        dcn (dict | None): Forwarded to the stages whose entry in
+            ``stage_with_dcn`` is truthy.
+        stage_with_dcn (Sequence[bool]): Per-stage switch for ``dcn``.
+        plugins (list[dict]): List of plugins for stages, each dict contains:
+            - cfg (dict, required): Cfg dict to build plugin.
+            - position (str, required): Position inside block to insert plugin,
+              options: 'after_conv1', 'after_conv2', 'after_conv3'.
+            - stages (tuple[bool], optional): Stages to apply plugin, length
+              should be same as 'num_stages'
+        multi_grid (Sequence[int]|None): Multi grid dilation rates of last
+            stage. Default: None
+        contract_dilation (bool): Whether contract first dilation of each layer
+            Default: False
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed.
+        zero_init_residual (bool): Whether to use zero init for last norm layer
+            in resblocks to let them behave as identity.
+
+    Example:
+        >>> from annotator.uniformer.mmseg.models import ResNet
+        >>> import torch
+        >>> self = ResNet(depth=18)
+        >>> self.eval()
+        >>> inputs = torch.rand(1, 3, 32, 32)
+        >>> level_outputs = self.forward(inputs)
+        >>> for level_out in level_outputs:
+        ...     print(tuple(level_out.shape))
+        (1, 64, 8, 8)
+        (1, 128, 4, 4)
+        (1, 256, 2, 2)
+        (1, 512, 1, 1)
+    """
+
+    # depth -> (block class, number of blocks per stage)
+    arch_settings = {
+        18: (BasicBlock, (2, 2, 2, 2)),
+        34: (BasicBlock, (3, 4, 6, 3)),
+        50: (Bottleneck, (3, 4, 6, 3)),
+        101: (Bottleneck, (3, 4, 23, 3)),
+        152: (Bottleneck, (3, 8, 36, 3))
+    }
+
+    def __init__(self,
+                 depth,
+                 in_channels=3,
+                 stem_channels=64,
+                 base_channels=64,
+                 num_stages=4,
+                 strides=(1, 2, 2, 2),
+                 dilations=(1, 1, 1, 1),
+                 out_indices=(0, 1, 2, 3),
+                 style='pytorch',
+                 deep_stem=False,
+                 avg_down=False,
+                 frozen_stages=-1,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN', requires_grad=True),
+                 norm_eval=False,
+                 dcn=None,
+                 stage_with_dcn=(False, False, False, False),
+                 plugins=None,
+                 multi_grid=None,
+                 contract_dilation=False,
+                 with_cp=False,
+                 zero_init_residual=True):
+        super(ResNet, self).__init__()
+        if depth not in self.arch_settings:
+            raise KeyError(f'invalid depth {depth} for resnet')
+        self.depth = depth
+        self.stem_channels = stem_channels
+        self.base_channels = base_channels
+        self.num_stages = num_stages
+        assert num_stages >= 1 and num_stages <= 4
+        self.strides = strides
+        self.dilations = dilations
+        assert len(strides) == len(dilations) == num_stages
+        self.out_indices = out_indices
+        assert max(out_indices) < num_stages
+        self.style = style
+        self.deep_stem = deep_stem
+        self.avg_down = avg_down
+        self.frozen_stages = frozen_stages
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+        self.with_cp = with_cp
+        self.norm_eval = norm_eval
+        self.dcn = dcn
+        self.stage_with_dcn = stage_with_dcn
+        if dcn is not None:
+            assert len(stage_with_dcn) == num_stages
+        self.plugins = plugins
+        self.multi_grid = multi_grid
+        self.contract_dilation = contract_dilation
+        self.zero_init_residual = zero_init_residual
+        self.block, stage_blocks = self.arch_settings[depth]
+        self.stage_blocks = stage_blocks[:num_stages]
+        # Running input width of the next stage; updated after each stage.
+        self.inplanes = stem_channels
+
+        self._make_stem_layer(in_channels, stem_channels)
+
+        # Attribute names ('layer1', 'layer2', ...) of the built stages.
+        self.res_layers = []
+        for i, num_blocks in enumerate(self.stage_blocks):
+            stride = strides[i]
+            dilation = dilations[i]
+            dcn = self.dcn if self.stage_with_dcn[i] else None
+            if plugins is not None:
+                stage_plugins = self.make_stage_plugins(plugins, i)
+            else:
+                stage_plugins = None
+            # multi grid is applied to last layer only
+            stage_multi_grid = multi_grid if i == len(
+                self.stage_blocks) - 1 else None
+            planes = base_channels * 2**i
+            res_layer = self.make_res_layer(
+                block=self.block,
+                inplanes=self.inplanes,
+                planes=planes,
+                num_blocks=num_blocks,
+                stride=stride,
+                dilation=dilation,
+                style=self.style,
+                avg_down=self.avg_down,
+                with_cp=with_cp,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                dcn=dcn,
+                plugins=stage_plugins,
+                multi_grid=stage_multi_grid,
+                contract_dilation=contract_dilation)
+            self.inplanes = planes * self.block.expansion
+            layer_name = f'layer{i+1}'
+            self.add_module(layer_name, res_layer)
+            self.res_layers.append(layer_name)
+
+        self._freeze_stages()
+
+        # Channel count produced by the last built stage.
+        self.feat_dim = self.block.expansion * base_channels * 2**(
+            len(self.stage_blocks) - 1)
+
+    def make_stage_plugins(self, plugins, stage_idx):
+        """make plugins for ResNet 'stage_idx'th stage .
+
+        Currently we support to insert 'context_block',
+        'empirical_attention_block', 'nonlocal_block' into the backbone like
+        ResNet/ResNeXt. They could be inserted after conv1/conv2/conv3 of
+        Bottleneck.
+
+        An example of plugins format could be :
+        >>> plugins=[
+        ...     dict(cfg=dict(type='xxx', arg1='xxx'),
+        ...          stages=(False, True, True, True),
+        ...          position='after_conv2'),
+        ...     dict(cfg=dict(type='yyy'),
+        ...          stages=(True, True, True, True),
+        ...          position='after_conv3'),
+        ...     dict(cfg=dict(type='zzz', postfix='1'),
+        ...          stages=(True, True, True, True),
+        ...          position='after_conv3'),
+        ...     dict(cfg=dict(type='zzz', postfix='2'),
+        ...          stages=(True, True, True, True),
+        ...          position='after_conv3')
+        ... ]
+        >>> self = ResNet(depth=18)
+        >>> stage_plugins = self.make_stage_plugins(plugins, 0)
+        >>> assert len(stage_plugins) == 3
+
+        Suppose 'stage_idx=0', the structure of blocks in the stage would be:
+            conv1-> conv2->conv3->yyy->zzz1->zzz2
+        Suppose 'stage_idx=1', the structure of blocks in the stage would be:
+            conv1-> conv2->xxx->conv3->yyy->zzz1->zzz2
+
+        If stages is missing, the plugin would be applied to all stages.
+
+        Args:
+            plugins (list[dict]): List of plugins cfg to build. The postfix is
+                required if multiple same type plugins are inserted.
+            stage_idx (int): Index of stage to build
+
+        Returns:
+            list[dict]: Plugins for current stage
+        """
+        stage_plugins = []
+        for plugin in plugins:
+            # Copy so that popping 'stages' does not alter the caller's cfg.
+            plugin = plugin.copy()
+            stages = plugin.pop('stages', None)
+            assert stages is None or len(stages) == self.num_stages
+            # whether to insert plugin into current stage
+            if stages is None or stages[stage_idx]:
+                stage_plugins.append(plugin)
+
+        return stage_plugins
+
+    def make_res_layer(self, **kwargs):
+        """Pack all blocks in a stage into a ``ResLayer``."""
+        return ResLayer(**kwargs)
+
+    @property
+    def norm1(self):
+        """nn.Module: the normalization layer named "norm1" """
+        return getattr(self, self.norm1_name)
+
+    def _make_stem_layer(self, in_channels, stem_channels):
+        """Make stem layer for ResNet.
+
+        With ``deep_stem`` the stem is three 3x3 conv/norm/ReLU groups stored
+        as ``self.stem``; otherwise it is a single 7x7 conv, a norm layer and
+        a ReLU. A 3x3 stride-2 max pool is created in both cases.
+        """
+        if self.deep_stem:
+            self.stem = nn.Sequential(
+                build_conv_layer(
+                    self.conv_cfg,
+                    in_channels,
+                    stem_channels // 2,
+                    kernel_size=3,
+                    stride=2,
+                    padding=1,
+                    bias=False),
+                build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
+                nn.ReLU(inplace=True),
+                build_conv_layer(
+                    self.conv_cfg,
+                    stem_channels // 2,
+                    stem_channels // 2,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    bias=False),
+                build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
+                nn.ReLU(inplace=True),
+                build_conv_layer(
+                    self.conv_cfg,
+                    stem_channels // 2,
+                    stem_channels,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    bias=False),
+                build_norm_layer(self.norm_cfg, stem_channels)[1],
+                nn.ReLU(inplace=True))
+        else:
+            self.conv1 = build_conv_layer(
+                self.conv_cfg,
+                in_channels,
+                stem_channels,
+                kernel_size=7,
+                stride=2,
+                padding=3,
+                bias=False)
+            self.norm1_name, norm1 = build_norm_layer(
+                self.norm_cfg, stem_channels, postfix=1)
+            self.add_module(self.norm1_name, norm1)
+            self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+    def _freeze_stages(self):
+        """Freeze stages param and norm stats.
+
+        When ``frozen_stages >= 0`` the stem is frozen; stages
+        ``layer1`` .. ``layer{frozen_stages}`` are additionally put in eval
+        mode with ``requires_grad`` disabled.
+        """
+        if self.frozen_stages >= 0:
+            if self.deep_stem:
+                self.stem.eval()
+                for param in self.stem.parameters():
+                    param.requires_grad = False
+            else:
+                self.norm1.eval()
+                for m in [self.conv1, self.norm1]:
+                    for param in m.parameters():
+                        param.requires_grad = False
+
+        for i in range(1, self.frozen_stages + 1):
+            m = getattr(self, f'layer{i}')
+            m.eval()
+            for param in m.parameters():
+                param.requires_grad = False
+
+    def init_weights(self, pretrained=None):
+        """Initialize the weights in backbone.
+
+        Args:
+            pretrained (str, optional): Path to pre-trained weights.
+                Defaults to None.
+
+        Raises:
+            TypeError: If ``pretrained`` is neither a str nor None.
+        """
+        if isinstance(pretrained, str):
+            logger = get_root_logger()
+            load_checkpoint(self, pretrained, strict=False, logger=logger)
+        elif pretrained is None:
+            for m in self.modules():
+                if isinstance(m, nn.Conv2d):
+                    kaiming_init(m)
+                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+                    constant_init(m, 1)
+
+            if self.dcn is not None:
+                for m in self.modules():
+                    if isinstance(m, Bottleneck) and hasattr(
+                            m, 'conv2_offset'):
+                        constant_init(m.conv2_offset, 0)
+
+            if self.zero_init_residual:
+                # Zero the last norm of each block so the residual branch
+                # starts out contributing nothing.
+                for m in self.modules():
+                    if isinstance(m, Bottleneck):
+                        constant_init(m.norm3, 0)
+                    elif isinstance(m, BasicBlock):
+                        constant_init(m.norm2, 0)
+        else:
+            raise TypeError('pretrained must be a str or None')
+
+    def forward(self, x):
+        """Forward function.
+
+        Returns:
+            tuple: Outputs of the stages whose index is in ``out_indices``.
+        """
+        if self.deep_stem:
+            x = self.stem(x)
+        else:
+            x = self.conv1(x)
+            x = self.norm1(x)
+            x = self.relu(x)
+        x = self.maxpool(x)
+        outs = []
+        for i, layer_name in enumerate(self.res_layers):
+            res_layer = getattr(self, layer_name)
+            x = res_layer(x)
+            if i in self.out_indices:
+                outs.append(x)
+        return tuple(outs)
+
+    def train(self, mode=True):
+        """Switch training mode while keeping frozen parts frozen.
+
+        After the standard ``nn.Module.train`` call, the frozen stages are
+        re-frozen, and with ``norm_eval`` every ``_BatchNorm`` module is put
+        back into eval mode.
+
+        Args:
+            mode (bool): True for training mode, False for eval mode.
+
+        Returns:
+            ResNet: ``self``, matching ``nn.Module.train`` so that chained
+                calls such as ``net = ResNet(...).eval()`` keep the module.
+        """
+        super(ResNet, self).train(mode)
+        self._freeze_stages()
+        if mode and self.norm_eval:
+            for m in self.modules():
+                # trick: eval have effect on BatchNorm only
+                if isinstance(m, _BatchNorm):
+                    m.eval()
+        # nn.Module.eval() returns self.train(False); without this the
+        # override made both train() and eval() return None.
+        return self
+class ResNetV1c(ResNet):
+    """ResNetV1c variant described in [1]_.
+
+    Identical to the default ResNet (ResNetV1b) except that the 7x7 conv of
+    the input stem is replaced by three 3x3 convs (``deep_stem=True``).
+    ``avg_down`` is fixed to False.
+
+    References:
+        .. [1] https://arxiv.org/pdf/1812.01187.pdf
+    """
+
+    def __init__(self, **kwargs):
+        # deep_stem/avg_down are fixed here; passing either in kwargs raises
+        # TypeError because of the duplicate keyword.
+        super().__init__(deep_stem=True, avg_down=False, **kwargs)
+class ResNetV1d(ResNet):
+    """ResNetV1d variant described in [1]_.
+
+    Compared with the default ResNet (ResNetV1b), the 7x7 conv of the input
+    stem is replaced by three 3x3 convs (``deep_stem=True``), and in the
+    downsampling block a 2x2 avg_pool with stride 2 is added before the
+    conv, whose stride is changed to 1 (``avg_down=True``).
+    """
+
+    def __init__(self, **kwargs):
+        # deep_stem/avg_down are fixed here; passing either in kwargs raises
+        # TypeError because of the duplicate keyword.
+        super().__init__(deep_stem=True, avg_down=True, **kwargs)
diff --git a/ControlNet/annotator/uniformer/mmseg/models/backbones/resnext.py b/ControlNet/annotator/uniformer/mmseg/models/backbones/resnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..962249ad6fd9b50960ad6426f7ce3cac6ed8c5bc
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/backbones/resnext.py
@@ -0,0 +1,145 @@
+import math
+from annotator.uniformer.mmcv.cnn import build_conv_layer, build_norm_layer
+from ..builder import BACKBONES
+from ..utils import ResLayer
+from .resnet import Bottleneck as _Bottleneck
+from .resnet import ResNet
+class Bottleneck(_Bottleneck):
+    """Bottleneck block for ResNeXt.
+
+    If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is
+    "caffe", the stride-two layer is the first 1x1 conv layer.
+
+    The parent constructor runs first; its convolutions and norm layers are
+    then rebuilt with the grouped inner width computed here.
+
+    Args:
+        inplanes (int): Input channels of the block.
+        planes (int): Base channels; the block outputs
+            ``planes * expansion`` channels.
+        groups (int): Groups of the 3x3 conv. Default: 1.
+        base_width (int): With ``groups > 1`` the inner width is
+            ``floor(planes * base_width / base_channels) * groups``.
+            Default: 4.
+        base_channels (int): Denominator of the width formula. Default: 64.
+        kwargs (dict): Remaining keyword arguments of the parent block.
+    """
+
+    def __init__(self,
+                 inplanes,
+                 planes,
+                 groups=1,
+                 base_width=4,
+                 base_channels=64,
+                 **kwargs):
+        super().__init__(inplanes, planes, **kwargs)
+
+        # Inner width of conv1/conv2; equals planes when ungrouped.
+        width = self.planes if groups == 1 else math.floor(
+            self.planes * (base_width / base_channels)) * groups
+        out_channels = self.planes * self.expansion
+
+        self.norm1_name, first_norm = build_norm_layer(
+            self.norm_cfg, width, postfix=1)
+        self.norm2_name, second_norm = build_norm_layer(
+            self.norm_cfg, width, postfix=2)
+        self.norm3_name, third_norm = build_norm_layer(
+            self.norm_cfg, out_channels, postfix=3)
+
+        self.conv1 = build_conv_layer(
+            self.conv_cfg,
+            self.inplanes,
+            width,
+            kernel_size=1,
+            stride=self.conv1_stride,
+            bias=False)
+        self.add_module(self.norm1_name, first_norm)
+
+        fallback_on_stride = False
+        self.with_modulated_dcn = False
+        if self.with_dcn:
+            # NOTE(review): the parent constructor already pops this key
+            # from what looks like the same dict object, so the default
+            # may always be returned here -- confirm.
+            fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
+
+        # Both variants of conv2 share every argument except the config.
+        if not self.with_dcn or fallback_on_stride:
+            conv2_cfg = self.conv_cfg
+        else:
+            assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
+            conv2_cfg = self.dcn
+        self.conv2 = build_conv_layer(
+            conv2_cfg,
+            width,
+            width,
+            kernel_size=3,
+            stride=self.conv2_stride,
+            padding=self.dilation,
+            dilation=self.dilation,
+            groups=groups,
+            bias=False)
+        self.add_module(self.norm2_name, second_norm)
+
+        self.conv3 = build_conv_layer(
+            self.conv_cfg,
+            width,
+            out_channels,
+            kernel_size=1,
+            bias=False)
+        self.add_module(self.norm3_name, third_norm)
+class ResNeXt(ResNet):
+    """ResNeXt backbone.
+
+    Args:
+        depth (int): Depth of resnext, from {50, 101, 152}.
+        in_channels (int): Number of input image channels. Normally 3.
+        num_stages (int): Resnet stages, normally 4.
+        groups (int): Group of resnext.
+        base_width (int): Base width of resnext.
+        strides (Sequence[int]): Strides of the first block of each stage.
+        dilations (Sequence[int]): Dilation of each stage.
+        out_indices (Sequence[int]): Output from which stages.
+        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+            layer is the 3x3 conv layer, otherwise the stride-two layer is
+            the first 1x1 conv layer.
+        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
+            not freezing any parameters.
+        norm_cfg (dict): dictionary to construct and config norm layer.
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed.
+        zero_init_residual (bool): whether to use zero init for last norm layer
+            in resblocks to let them behave as identity.
+
+    Example:
+        >>> from annotator.uniformer.mmseg.models import ResNeXt
+        >>> import torch
+        >>> self = ResNeXt(depth=50)
+        >>> self.eval()
+        >>> inputs = torch.rand(1, 3, 32, 32)
+        >>> level_outputs = self.forward(inputs)
+        >>> for level_out in level_outputs:
+        ...     print(tuple(level_out.shape))
+        (1, 256, 8, 8)
+        (1, 512, 4, 4)
+        (1, 1024, 2, 2)
+        (1, 2048, 1, 1)
+    """
+
+    # depth -> (block class, number of blocks per stage)
+    arch_settings = {
+        50: (Bottleneck, (3, 4, 6, 3)),
+        101: (Bottleneck, (3, 4, 23, 3)),
+        152: (Bottleneck, (3, 8, 36, 3))
+    }
+
+    def __init__(self, groups=1, base_width=4, **kwargs):
+        # Stored before delegating because the parent constructor builds
+        # the stages through ``make_res_layer``, which reads both values.
+        self.groups = groups
+        self.base_width = base_width
+        super().__init__(**kwargs)
+
+    def make_res_layer(self, **kwargs):
+        """Build one stage as a ``ResLayer`` with the ResNeXt settings added.
+
+        Args:
+            kwargs (dict): Stage arguments supplied by the caller; passed
+                through to ``ResLayer`` unchanged.
+
+        Returns:
+            ResLayer: The stage module.
+        """
+        return ResLayer(
+            groups=self.groups,
+            base_width=self.base_width,
+            base_channels=self.base_channels,
+            **kwargs)
diff --git a/ControlNet/annotator/uniformer/mmseg/models/backbones/unet.py b/ControlNet/annotator/uniformer/mmseg/models/backbones/unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..82caa16a94c195c192a2a920fb7bc7e60f0f3ce3
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/backbones/unet.py
@@ -0,0 +1,429 @@
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from annotator.uniformer.mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer,
+ build_norm_layer, constant_init, kaiming_init)
+from annotator.uniformer.mmcv.runner import load_checkpoint
+from annotator.uniformer.mmcv.utils.parrots_wrapper import _BatchNorm
+from annotator.uniformer.mmseg.utils import get_root_logger
+from ..builder import BACKBONES
+from ..utils import UpConvBlock
+class BasicConvBlock(nn.Module):
+    """Basic convolutional block for UNet.
+
+    A plain stack of ``num_convs`` 3x3 ``ConvModule`` layers.
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        num_convs (int): Number of convolutional layers. Default: 2.
+        stride (int): Stride of the first convolutional layer; every later
+            layer uses stride 1. Options are 1 or 2. Default: 1.
+        dilation (int): Dilation rate (and padding) of every layer except
+            the first, whose dilation and padding are always 1.
+            Default: 1.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+        conv_cfg (dict | None): Config dict for convolution layer.
+            Default: None.
+        norm_cfg (dict | None): Config dict for normalization layer.
+            Default: dict(type='BN').
+        act_cfg (dict | None): Config dict for activation layer in ConvModule.
+            Default: dict(type='ReLU').
+        dcn (bool): Must be None (not implemented). Default: None.
+        plugins (dict): Must be None (not implemented). Default: None.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 num_convs=2,
+                 stride=1,
+                 dilation=1,
+                 with_cp=False,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU'),
+                 dcn=None,
+                 plugins=None):
+        super().__init__()
+        assert dcn is None, 'Not implemented yet.'
+        assert plugins is None, 'Not implemented yet.'
+
+        self.with_cp = with_cp
+        # Only the first layer changes the channel count and may stride.
+        layers = [
+            ConvModule(
+                in_channels=in_channels if idx == 0 else out_channels,
+                out_channels=out_channels,
+                kernel_size=3,
+                stride=stride if idx == 0 else 1,
+                dilation=1 if idx == 0 else dilation,
+                padding=1 if idx == 0 else dilation,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg) for idx in range(num_convs)
+        ]
+        self.convs = nn.Sequential(*layers)
+
+    def forward(self, x):
+        """Apply the conv stack, checkpointed when enabled and applicable."""
+        if self.with_cp and x.requires_grad:
+            return cp.checkpoint(self.convs, x)
+        return self.convs(x)
+class DeconvModule(nn.Module):
+    """Deconvolution upsample module in decoder for UNet (2X upsample).
+
+    A transposed convolution followed by a norm layer and an activation,
+    used to upsample the feature map in the decoder of UNet.
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+        norm_cfg (dict | None): Config dict for normalization layer.
+            Default: dict(type='BN').
+        act_cfg (dict | None): Config dict for activation layer in ConvModule.
+            Default: dict(type='ReLU').
+        kernel_size (int): Kernel size of the convolutional layer. Default: 4.
+        scale_factor (int): Used as the stride of the transposed
+            convolution. ``kernel_size - scale_factor`` must be
+            non-negative and even. Default: 2.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 with_cp=False,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU'),
+                 *,
+                 kernel_size=4,
+                 scale_factor=2):
+        super().__init__()
+
+        # Half of this margin becomes the padding on each side.
+        margin = kernel_size - scale_factor
+        assert margin >= 0 and margin % 2 == 0, (
+            f'kernel_size should be greater than or equal to scale_factor '
+            f'and (kernel_size - scale_factor) should be even numbers, '
+            f'while the kernel size is {kernel_size} and scale_factor is '
+            f'{scale_factor}.')
+
+        self.with_cp = with_cp
+        deconv = nn.ConvTranspose2d(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=scale_factor,
+            padding=margin // 2)
+        _, norm = build_norm_layer(norm_cfg, out_channels)
+        activate = build_activation_layer(act_cfg)
+        self.deconv_upsamping = nn.Sequential(deconv, norm, activate)
+
+    def forward(self, x):
+        """Upsample ``x``, checkpointed when enabled and applicable."""
+        if self.with_cp and x.requires_grad:
+            return cp.checkpoint(self.deconv_upsamping, x)
+        return self.deconv_upsamping(x)
+class InterpConv(nn.Module):
+    """Interpolation upsample module in decoder for UNet.
+
+    Combines one ``nn.Upsample`` layer and one ``ConvModule``. With
+    ``conv_first=False`` the upsample layer runs first and the convolution
+    second; with ``conv_first=True`` the order is reversed.
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+        norm_cfg (dict | None): Config dict for normalization layer.
+            Default: dict(type='BN').
+        act_cfg (dict | None): Config dict for activation layer in ConvModule.
+            Default: dict(type='ReLU').
+        conv_cfg (dict | None): Config dict for convolution layer.
+            Default: None.
+        conv_first (bool): Whether convolutional layer or interpolation
+            upsample layer first. Default: False. It means interpolation
+            upsample layer followed by one convolutional layer.
+        kernel_size (int): Kernel size of the convolutional layer. Default: 1.
+        stride (int): Stride of the convolutional layer. Default: 1.
+        padding (int): Padding of the convolutional layer. Default: 0.
+        upsample_cfg (dict): Keyword arguments of ``nn.Upsample``.
+            Default: dict(
+                scale_factor=2, mode='bilinear', align_corners=False).
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 with_cp=False,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU'),
+                 *,
+                 conv_cfg=None,
+                 conv_first=False,
+                 kernel_size=1,
+                 stride=1,
+                 padding=0,
+                 upsample_cfg=dict(
+                     scale_factor=2, mode='bilinear', align_corners=False)):
+        super().__init__()
+
+        self.with_cp = with_cp
+        conv = ConvModule(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg)
+        upsample = nn.Upsample(**upsample_cfg)
+        # Same two layers either way; only their order depends on the flag.
+        ordered = (conv, upsample) if conv_first else (upsample, conv)
+        self.interp_upsample = nn.Sequential(*ordered)
+
+    def forward(self, x):
+        """Upsample ``x``, checkpointed when enabled and applicable."""
+        if self.with_cp and x.requires_grad:
+            return cp.checkpoint(self.interp_upsample, x)
+        return self.interp_upsample(x)
+class UNet(nn.Module):
+ """UNet backbone.
+ U-Net: Convolutional Networks for Biomedical Image Segmentation.
+ https://arxiv.org/pdf/1505.04597.pdf
+ Args:
+ in_channels (int): Number of input image channels. Default" 3.
+ base_channels (int): Number of base channels of each stage.
+ The output channels of the first stage. Default: 64.
+ num_stages (int): Number of stages in encoder, normally 5. Default: 5.
+ strides (Sequence[int 1 | 2]): Strides of each stage in encoder.
+ len(strides) is equal to num_stages. Normally the stride of the
+ first stage in encoder is 1. If strides[i]=2, it uses stride
+ convolution to downsample in the correspondence encoder stage.
+ Default: (1, 1, 1, 1, 1).
+ enc_num_convs (Sequence[int]): Number of convolutional layers in the
+ convolution block of the correspondence encoder stage.
+ Default: (2, 2, 2, 2, 2).
+ dec_num_convs (Sequence[int]): Number of convolutional layers in the
+ convolution block of the correspondence decoder stage.
+ Default: (2, 2, 2, 2).
+ downsamples (Sequence[int]): Whether use MaxPool to downsample the
+ feature map after the first stage of encoder
+ (stages: [1, num_stages)). If the correspondence encoder stage use
+ stride convolution (strides[i]=2), it will never use MaxPool to
+ downsample, even downsamples[i-1]=True.
+ Default: (True, True, True, True).
+ enc_dilations (Sequence[int]): Dilation rate of each stage in encoder.
+ Default: (1, 1, 1, 1, 1).
+ dec_dilations (Sequence[int]): Dilation rate of each stage in decoder.
+ Default: (1, 1, 1, 1).
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ conv_cfg (dict | None): Config dict for convolution layer.
+ Default: None.
+ norm_cfg (dict | None): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict | None): Config dict for activation layer in ConvModule.
+ Default: dict(type='ReLU').
+ upsample_cfg (dict): The upsample config of the upsample module in
+ decoder. Default: dict(type='InterpConv').
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ dcn (bool): Use deformable convolution in convolutional layer or not.
+ Default: None.
+ plugins (dict): plugins for convolutional layers. Default: None.
+ Notice:
+ The input image size should be divisible by the whole downsample rate
+ of the encoder. More detail of the whole downsample rate can be found
+ in UNet._check_input_divisible.
+ """
+ def __init__(self,
+ in_channels=3,
+ base_channels=64,
+ num_stages=5,
+ strides=(1, 1, 1, 1, 1),
+ enc_num_convs=(2, 2, 2, 2, 2),
+ dec_num_convs=(2, 2, 2, 2),
+ downsamples=(True, True, True, True),
+ enc_dilations=(1, 1, 1, 1, 1),
+ dec_dilations=(1, 1, 1, 1),
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ upsample_cfg=dict(type='InterpConv'),
+ norm_eval=False,
+ dcn=None,
+ plugins=None):
+ super(UNet, self).__init__()
+ assert dcn is None, 'Not implemented yet.'
+ assert plugins is None, 'Not implemented yet.'
+ assert len(strides) == num_stages, \
+ 'The length of strides should be equal to num_stages, '\
+ f'while the strides is {strides}, the length of '\
+ f'strides is {len(strides)}, and the num_stages is '\
+ f'{num_stages}.'
+ assert len(enc_num_convs) == num_stages, \
+ 'The length of enc_num_convs should be equal to num_stages, '\
+ f'while the enc_num_convs is {enc_num_convs}, the length of '\
+ f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\
+ f'{num_stages}.'
+ assert len(dec_num_convs) == (num_stages-1), \
+ 'The length of dec_num_convs should be equal to (num_stages-1), '\
+ f'while the dec_num_convs is {dec_num_convs}, the length of '\
+ f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\
+ f'{num_stages}.'
+ assert len(downsamples) == (num_stages-1), \
+ 'The length of downsamples should be equal to (num_stages-1), '\
+ f'while the downsamples is {downsamples}, the length of '\
+ f'downsamples is {len(downsamples)}, and the num_stages is '\
+ f'{num_stages}.'
+ assert len(enc_dilations) == num_stages, \
+ 'The length of enc_dilations should be equal to num_stages, '\
+ f'while the enc_dilations is {enc_dilations}, the length of '\
+ f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\
+ f'{num_stages}.'
+ assert len(dec_dilations) == (num_stages-1), \
+ 'The length of dec_dilations should be equal to (num_stages-1), '\
+ f'while the dec_dilations is {dec_dilations}, the length of '\
+ f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\
+ f'{num_stages}.'
+ self.num_stages = num_stages
+ self.strides = strides
+ self.downsamples = downsamples
+ self.norm_eval = norm_eval
+ self.base_channels = base_channels
+ self.encoder = nn.ModuleList()
+ self.decoder = nn.ModuleList()
+ for i in range(num_stages):
+ enc_conv_block = []
+ if i != 0:
+ if strides[i] == 1 and downsamples[i - 1]:
+ enc_conv_block.append(nn.MaxPool2d(kernel_size=2))
+ upsample = (strides[i] != 1 or downsamples[i - 1])
+ self.decoder.append(
+ UpConvBlock(
+ conv_block=BasicConvBlock,
+ in_channels=base_channels * 2**i,
+ skip_channels=base_channels * 2**(i - 1),
+ out_channels=base_channels * 2**(i - 1),
+ num_convs=dec_num_convs[i - 1],
+ stride=1,
+ dilation=dec_dilations[i - 1],
+ with_cp=with_cp,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ upsample_cfg=upsample_cfg if upsample else None,
+ dcn=None,
+ plugins=None))
+ enc_conv_block.append(
+ BasicConvBlock(
+ in_channels=in_channels,
+ out_channels=base_channels * 2**i,
+ num_convs=enc_num_convs[i],
+ stride=strides[i],
+ dilation=enc_dilations[i],
+ with_cp=with_cp,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ dcn=None,
+ plugins=None))
+ self.encoder.append((nn.Sequential(*enc_conv_block)))
+ in_channels = base_channels * 2**i
+ def forward(self, x):
+ self._check_input_divisible(x)
+ enc_outs = []
+ for enc in self.encoder:
+ x = enc(x)
+ enc_outs.append(x)
+ dec_outs = [x]
+ for i in reversed(range(len(self.decoder))):
+ x = self.decoder[i](enc_outs[i], x)
+ dec_outs.append(x)
+ return dec_outs
+ def train(self, mode=True):
+ """Convert the model into training mode while keep normalization layer
+ freezed."""
+ super(UNet, self).train(mode)
+ if mode and self.norm_eval:
+ for m in self.modules():
+ # trick: eval have effect on BatchNorm only
+ if isinstance(m, _BatchNorm):
+ m.eval()
+ def _check_input_divisible(self, x):
+ h, w = x.shape[-2:]
+ whole_downsample_rate = 1
+ for i in range(1, self.num_stages):
+ if self.strides[i] == 2 or self.downsamples[i - 1]:
+ whole_downsample_rate *= 2
+ assert (h % whole_downsample_rate == 0) \
+ and (w % whole_downsample_rate == 0),\
+ f'The input image size {(h, w)} should be divisible by the whole '\
+ f'downsample rate {whole_downsample_rate}, when num_stages is '\
+ f'{self.num_stages}, strides is {self.strides}, and downsamples '\
+ f'is {self.downsamples}.'
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in backbone.
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ if isinstance(pretrained, str):
+ logger = get_root_logger()
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+ constant_init(m, 1)
+ else:
+ raise TypeError('pretrained must be a str or None')
diff --git a/ControlNet/annotator/uniformer/mmseg/models/backbones/uniformer.py b/ControlNet/annotator/uniformer/mmseg/models/backbones/uniformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c4bb88e4c928540cca9ab609988b916520f5b7a
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/backbones/uniformer.py
@@ -0,0 +1,422 @@
+# --------------------------------------------------------
+# UniFormer
+# Copyright (c) 2022 SenseTime X-Lab
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Kunchang Li
+# --------------------------------------------------------
+from collections import OrderedDict
+import math
+from functools import partial
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+import numpy as np
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+from annotator.uniformer.mmcv_custom import load_checkpoint
+from annotator.uniformer.mmseg.utils import get_root_logger
+from ..builder import BACKBONES
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+class CMlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
+ self.act = act_layer()
+ self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
+ self.drop = nn.Dropout(drop)
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+class CBlock(nn.Module):
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
+ self.norm1 = nn.BatchNorm2d(dim)
+ self.conv1 = nn.Conv2d(dim, dim, 1)
+ self.conv2 = nn.Conv2d(dim, dim, 1)
+ self.attn = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = nn.BatchNorm2d(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = CMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+ def forward(self, x):
+ x = x + self.pos_embed(x)
+ x = x + self.drop_path(self.conv2(self.attn(self.conv1(self.norm1(x)))))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+class Attention(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
+ self.scale = qk_scale or head_dim ** -0.5
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+ def forward(self, x):
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+class SABlock(nn.Module):
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim,
+ num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ attn_drop=attn_drop, proj_drop=drop)
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+ def forward(self, x):
+ x = x + self.pos_embed(x)
+ B, N, H, W = x.shape
+ x = x.flatten(2).transpose(1, 2)
+ x = x + self.drop_path(self.attn(self.norm1(x)))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ x = x.transpose(1, 2).reshape(B, N, H, W)
+ return x
+def window_partition(x, window_size):
+ """
+ Args:
+ x: (B, H, W, C)
+ window_size (int): window size
+ Returns:
+ windows: (num_windows*B, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows
+def window_reverse(windows, window_size, H, W):
+ """
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+ window_size (int): Window size
+ H (int): Height of image
+ W (int): Width of image
+ Returns:
+ x: (B, H, W, C)
+ """
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+class SABlock_Windows(nn.Module):
+ def __init__(self, dim, num_heads, window_size=14, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.window_size=window_size
+ self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim,
+ num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ attn_drop=attn_drop, proj_drop=drop)
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+ def forward(self, x):
+ x = x + self.pos_embed(x)
+ x = x.permute(0, 2, 3, 1)
+ B, H, W, C = x.shape
+ shortcut = x
+ x = self.norm1(x)
+ pad_l = pad_t = 0
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+ _, Hp, Wp, _ = x.shape
+ x_windows = window_partition(x, self.window_size) # nW*B, window_size, window_size, C
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
+ attn_windows = self.attn(x_windows) # nW*B, window_size*window_size, C
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+ x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
+ # reverse cyclic shift
+ if pad_r > 0 or pad_b > 0:
+ x = x[:, :H, :W, :].contiguous()
+ x = shortcut + self.drop_path(x)
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ x = x.permute(0, 3, 1, 2).reshape(B, C, H, W)
+ return x
+class PatchEmbed(nn.Module):
+ """ Image to Patch Embedding
+ """
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+ self.norm = nn.LayerNorm(embed_dim)
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+ def forward(self, x):
+ B, _, H, W = x.shape
+ x = self.proj(x)
+ B, _, H, W = x.shape
+ x = x.flatten(2).transpose(1, 2)
+ x = self.norm(x)
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
+ return x
+class UniFormer(nn.Module):
+ """ Vision Transformer
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
+ https://arxiv.org/abs/2010.11929
+ """
+ def __init__(self, layers=[3, 4, 8, 3], img_size=224, in_chans=3, num_classes=80, embed_dim=[64, 128, 320, 512],
+ head_dim=64, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ pretrained_path=None, use_checkpoint=False, checkpoint_num=[0, 0, 0, 0],
+ windows=False, hybrid=False, window_size=14):
+ """
+ Args:
+ layer (list): number of block in each layer
+ img_size (int, tuple): input image size
+ in_chans (int): number of input channels
+ num_classes (int): number of classes for classification head
+ embed_dim (int): embedding dimension
+ head_dim (int): dimension of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
+ drop_rate (float): dropout rate
+ attn_drop_rate (float): attention dropout rate
+ drop_path_rate (float): stochastic depth rate
+ norm_layer (nn.Module): normalization layer
+ pretrained_path (str): path of pretrained model
+ use_checkpoint (bool): whether use checkpoint
+ checkpoint_num (list): index for using checkpoint in every stage
+ windows (bool): whether use window MHRA
+ hybrid (bool): whether use hybrid MHRA
+ window_size (int): size of window (>14)
+ """
+ super().__init__()
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ self.checkpoint_num = checkpoint_num
+ self.windows = windows
+ print(f'Use Checkpoint: {self.use_checkpoint}')
+ print(f'Checkpoint Number: {self.checkpoint_num}')
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+ self.patch_embed1 = PatchEmbed(
+ img_size=img_size, patch_size=4, in_chans=in_chans, embed_dim=embed_dim[0])
+ self.patch_embed2 = PatchEmbed(
+ img_size=img_size // 4, patch_size=2, in_chans=embed_dim[0], embed_dim=embed_dim[1])
+ self.patch_embed3 = PatchEmbed(
+ img_size=img_size // 8, patch_size=2, in_chans=embed_dim[1], embed_dim=embed_dim[2])
+ self.patch_embed4 = PatchEmbed(
+ img_size=img_size // 16, patch_size=2, in_chans=embed_dim[2], embed_dim=embed_dim[3])
+ self.pos_drop = nn.Dropout(p=drop_rate)
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(layers))] # stochastic depth decay rule
+ num_heads = [dim // head_dim for dim in embed_dim]
+ self.blocks1 = nn.ModuleList([
+ CBlock(
+ dim=embed_dim[0], num_heads=num_heads[0], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
+ for i in range(layers[0])])
+ self.norm1=norm_layer(embed_dim[0])
+ self.blocks2 = nn.ModuleList([
+ CBlock(
+ dim=embed_dim[1], num_heads=num_heads[1], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+layers[0]], norm_layer=norm_layer)
+ for i in range(layers[1])])
+ self.norm2 = norm_layer(embed_dim[1])
+ if self.windows:
+ print('Use local window for all blocks in stage3')
+ self.blocks3 = nn.ModuleList([
+ SABlock_Windows(
+ dim=embed_dim[2], num_heads=num_heads[2], window_size=window_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+layers[0]+layers[1]], norm_layer=norm_layer)
+ for i in range(layers[2])])
+ elif hybrid:
+ print('Use hybrid window for blocks in stage3')
+ block3 = []
+ for i in range(layers[2]):
+ if (i + 1) % 4 == 0:
+ block3.append(SABlock(
+ dim=embed_dim[2], num_heads=num_heads[2], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+layers[0]+layers[1]], norm_layer=norm_layer))
+ else:
+ block3.append(SABlock_Windows(
+ dim=embed_dim[2], num_heads=num_heads[2], window_size=window_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+layers[0]+layers[1]], norm_layer=norm_layer))
+ self.blocks3 = nn.ModuleList(block3)
+ else:
+ print('Use global window for all blocks in stage3')
+ self.blocks3 = nn.ModuleList([
+ SABlock(
+ dim=embed_dim[2], num_heads=num_heads[2], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+layers[0]+layers[1]], norm_layer=norm_layer)
+ for i in range(layers[2])])
+ self.norm3 = norm_layer(embed_dim[2])
+ self.blocks4 = nn.ModuleList([
+ SABlock(
+ dim=embed_dim[3], num_heads=num_heads[3], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+layers[0]+layers[1]+layers[2]], norm_layer=norm_layer)
+ for i in range(layers[3])])
+ self.norm4 = norm_layer(embed_dim[3])
+ # Representation layer
+ if representation_size:
+ self.num_features = representation_size
+ self.pre_logits = nn.Sequential(OrderedDict([
+ ('fc', nn.Linear(embed_dim, representation_size)),
+ ('act', nn.Tanh())
+ ]))
+ else:
+ self.pre_logits = nn.Identity()
+ self.apply(self._init_weights)
+ self.init_weights(pretrained=pretrained_path)
+ def init_weights(self, pretrained):
+ if isinstance(pretrained, str):
+ logger = get_root_logger()
+ load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)
+ print(f'Load pretrained model from {pretrained}')
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'pos_embed', 'cls_token'}
+ def get_classifier(self):
+ return self.head
+ def reset_classifier(self, num_classes, global_pool=''):
+ self.num_classes = num_classes
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+ def forward_features(self, x):
+ out = []
+ x = self.patch_embed1(x)
+ x = self.pos_drop(x)
+ for i, blk in enumerate(self.blocks1):
+ if self.use_checkpoint and i < self.checkpoint_num[0]:
+ x = checkpoint.checkpoint(blk, x)
+ else:
+ x = blk(x)
+ x_out = self.norm1(x.permute(0, 2, 3, 1))
+ out.append(x_out.permute(0, 3, 1, 2).contiguous())
+ x = self.patch_embed2(x)
+ for i, blk in enumerate(self.blocks2):
+ if self.use_checkpoint and i < self.checkpoint_num[1]:
+ x = checkpoint.checkpoint(blk, x)
+ else:
+ x = blk(x)
+ x_out = self.norm2(x.permute(0, 2, 3, 1))
+ out.append(x_out.permute(0, 3, 1, 2).contiguous())
+ x = self.patch_embed3(x)
+ for i, blk in enumerate(self.blocks3):
+ if self.use_checkpoint and i < self.checkpoint_num[2]:
+ x = checkpoint.checkpoint(blk, x)
+ else:
+ x = blk(x)
+ x_out = self.norm3(x.permute(0, 2, 3, 1))
+ out.append(x_out.permute(0, 3, 1, 2).contiguous())
+ x = self.patch_embed4(x)
+ for i, blk in enumerate(self.blocks4):
+ if self.use_checkpoint and i < self.checkpoint_num[3]:
+ x = checkpoint.checkpoint(blk, x)
+ else:
+ x = blk(x)
+ x_out = self.norm4(x.permute(0, 2, 3, 1))
+ out.append(x_out.permute(0, 3, 1, 2).contiguous())
+ return tuple(out)
+ def forward(self, x):
+ x = self.forward_features(x)
+ return x
diff --git a/ControlNet/annotator/uniformer/mmseg/models/backbones/vit.py b/ControlNet/annotator/uniformer/mmseg/models/backbones/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..59e4479650690e08cbc4cab9427aefda47c2116d
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/backbones/vit.py
@@ -0,0 +1,459 @@
+"""Modified from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py."""
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as cp
+from annotator.uniformer.mmcv.cnn import (Conv2d, Linear, build_activation_layer, build_norm_layer,
+ constant_init, kaiming_init, normal_init)
+from annotator.uniformer.mmcv.runner import _load_checkpoint
+from annotator.uniformer.mmcv.utils.parrots_wrapper import _BatchNorm
+from annotator.uniformer.mmseg.utils import get_root_logger
+from ..builder import BACKBONES
+from ..utils import DropPath, trunc_normal_
+class Mlp(nn.Module):
+ """MLP layer for Encoder block.
+ Args:
+ in_features(int): Input dimension for the first fully
+ connected layer.
+ hidden_features(int): Output dimension for the first fully
+ connected layer.
+ out_features(int): Output dementsion for the second fully
+ connected layer.
+ act_cfg(dict): Config dict for activation layer.
+ Default: dict(type='GELU').
+ drop(float): Drop rate for the dropout layer. Dropout rate has
+ to be between 0 and 1. Default: 0.
+ """
+ def __init__(self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_cfg=dict(type='GELU'),
+ drop=0.):
+ super(Mlp, self).__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = Linear(in_features, hidden_features)
+ self.act = build_activation_layer(act_cfg)
+ self.fc2 = Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+class Attention(nn.Module):
+ """Attention layer for Encoder block.
+ Args:
+ dim (int): Dimension for the input vector.
+ num_heads (int): Number of parallel attention heads.
+ qkv_bias (bool): Enable bias for qkv if True. Default: False.
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
+ attn_drop (float): Drop rate for attention output weights.
+ Default: 0.
+ proj_drop (float): Drop rate for output weights. Default: 0.
+ """
+ def __init__(self,
+ dim,
+ num_heads=8,
+ qkv_bias=False,
+ qk_scale=None,
+ attn_drop=0.,
+ proj_drop=0.):
+ super(Attention, self).__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim**-0.5
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+ def forward(self, x):
+ b, n, c = x.shape
+ qkv = self.qkv(x).reshape(b, n, 3, self.num_heads,
+ c // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2]
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = (attn @ v).transpose(1, 2).reshape(b, n, c)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+class Block(nn.Module):
+ """Implements encoder block with residual connection.
+ Args:
+ dim (int): The feature dimension.
+ num_heads (int): Number of parallel attention heads.
+ mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float): Drop rate for mlp output weights. Default: 0.
+ attn_drop (float): Drop rate for attention output weights.
+ Default: 0.
+ proj_drop (float): Drop rate for attn layer output weights.
+ Default: 0.
+ drop_path (float): Drop rate for paths of model.
+ Default: 0.
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='GELU').
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='LN', requires_grad=True).
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ """
+ def __init__(self,
+ dim,
+ num_heads,
+ mlp_ratio=4,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.,
+ attn_drop=0.,
+ proj_drop=0.,
+ drop_path=0.,
+ act_cfg=dict(type='GELU'),
+ norm_cfg=dict(type='LN', eps=1e-6),
+ with_cp=False):
+ super(Block, self).__init__()
+ self.with_cp = with_cp
+ _, self.norm1 = build_norm_layer(norm_cfg, dim)
+ self.attn = Attention(dim, num_heads, qkv_bias, qk_scale, attn_drop,
+ proj_drop)
+ self.drop_path = DropPath(
+ drop_path) if drop_path > 0. else nn.Identity()
+ _, self.norm2 = build_norm_layer(norm_cfg, dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_cfg=act_cfg,
+ drop=drop)
+ def forward(self, x):
+ def _inner_forward(x):
+ out = x + self.drop_path(self.attn(self.norm1(x)))
+ out = out + self.drop_path(self.mlp(self.norm2(out)))
+ return out
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+ return out
+class PatchEmbed(nn.Module):
+ """Image to Patch Embedding.
+ Args:
+ img_size (int | tuple): Input image size.
+ default: 224.
+ patch_size (int): Width and height for a patch.
+ default: 16.
+ in_channels (int): Input channels for images. Default: 3.
+ embed_dim (int): The embedding dimension. Default: 768.
+ """
+ def __init__(self,
+ img_size=224,
+ patch_size=16,
+ in_channels=3,
+ embed_dim=768):
+ super(PatchEmbed, self).__init__()
+ if isinstance(img_size, int):
+ self.img_size = (img_size, img_size)
+ elif isinstance(img_size, tuple):
+ self.img_size = img_size
+ else:
+ raise TypeError('img_size must be type of int or tuple')
+ h, w = self.img_size
+ self.patch_size = (patch_size, patch_size)
+ self.num_patches = (h // patch_size) * (w // patch_size)
+ self.proj = Conv2d(
+ in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+ def forward(self, x):
+ return self.proj(x).flatten(2).transpose(1, 2)
+class VisionTransformer(nn.Module):
+ """Vision transformer backbone.
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for
+ Image Recognition at Scale` - https://arxiv.org/abs/2010.11929
+ Args:
+ img_size (tuple): input image size. Default: (224, 224).
+ patch_size (int, tuple): patch size. Default: 16.
+ in_channels (int): number of input channels. Default: 3.
+ embed_dim (int): embedding dimension. Default: 768.
+ depth (int): depth of transformer. Default: 12.
+ num_heads (int): number of attention heads. Default: 12.
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
+ Default: 4.
+ out_indices (list | tuple | int): Output from which stages.
+ Default: -1.
+ qkv_bias (bool): enable bias for qkv if True. Default: True.
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set.
+ drop_rate (float): dropout rate. Default: 0.
+ attn_drop_rate (float): attention dropout rate. Default: 0.
+ drop_path_rate (float): Rate of DropPath. Default: 0.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='LN', eps=1e-6, requires_grad=True).
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='GELU').
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ final_norm (bool): Whether to add a additional layer to normalize
+ final feature map. Default: False.
+ interpolate_mode (str): Select the interpolate mode for position
+ embeding vector resize. Default: bicubic.
+ with_cls_token (bool): If concatenating class token into image tokens
+ as transformer input. Default: True.
+ with_cp (bool): Use checkpoint or not. Using checkpoint
+ will save some memory while slowing down the training speed.
+ Default: False.
+ """
+ def __init__(self,
+ img_size=(224, 224),
+ patch_size=16,
+ in_channels=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ out_indices=11,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.,
+ norm_cfg=dict(type='LN', eps=1e-6, requires_grad=True),
+ act_cfg=dict(type='GELU'),
+ norm_eval=False,
+ final_norm=False,
+ with_cls_token=True,
+ interpolate_mode='bicubic',
+ with_cp=False):
+ super(VisionTransformer, self).__init__()
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.features = self.embed_dim = embed_dim
+ self.patch_embed = PatchEmbed(
+ img_size=img_size,
+ patch_size=patch_size,
+ in_channels=in_channels,
+ embed_dim=embed_dim)
+ self.with_cls_token = with_cls_token
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+ self.pos_embed = nn.Parameter(
+ torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim))
+ self.pos_drop = nn.Dropout(p=drop_rate)
+ if isinstance(out_indices, int):
+ self.out_indices = [out_indices]
+ elif isinstance(out_indices, list) or isinstance(out_indices, tuple):
+ self.out_indices = out_indices
+ else:
+ raise TypeError('out_indices must be type of int, list or tuple')
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
+ ] # stochastic depth decay rule
+ self.blocks = nn.ModuleList([
+ Block(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=dpr[i],
+ attn_drop=attn_drop_rate,
+ act_cfg=act_cfg,
+ norm_cfg=norm_cfg,
+ with_cp=with_cp) for i in range(depth)
+ ])
+ self.interpolate_mode = interpolate_mode
+ self.final_norm = final_norm
+ if final_norm:
+ _, self.norm = build_norm_layer(norm_cfg, embed_dim)
+ self.norm_eval = norm_eval
+ self.with_cp = with_cp
+ def init_weights(self, pretrained=None):
+ if isinstance(pretrained, str):
+ logger = get_root_logger()
+ checkpoint = _load_checkpoint(pretrained, logger=logger)
+ if 'state_dict' in checkpoint:
+ state_dict = checkpoint['state_dict']
+ else:
+ state_dict = checkpoint
+ if 'pos_embed' in state_dict.keys():
+ if self.pos_embed.shape != state_dict['pos_embed'].shape:
+ logger.info(msg=f'Resize the pos_embed shape from \
+{state_dict["pos_embed"].shape} to {self.pos_embed.shape}')
+ h, w = self.img_size
+ pos_size = int(
+ math.sqrt(state_dict['pos_embed'].shape[1] - 1))
+ state_dict['pos_embed'] = self.resize_pos_embed(
+ state_dict['pos_embed'], (h, w), (pos_size, pos_size),
+ self.patch_size, self.interpolate_mode)
+ self.load_state_dict(state_dict, False)
+ elif pretrained is None:
+ # We only implement the 'jax_impl' initialization implemented at
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
+ trunc_normal_(self.pos_embed, std=.02)
+ trunc_normal_(self.cls_token, std=.02)
+ for n, m in self.named_modules():
+ if isinstance(m, Linear):
+ trunc_normal_(m.weight, std=.02)
+ if m.bias is not None:
+ if 'mlp' in n:
+ normal_init(m.bias, std=1e-6)
+ else:
+ constant_init(m.bias, 0)
+ elif isinstance(m, Conv2d):
+ kaiming_init(m.weight, mode='fan_in')
+ if m.bias is not None:
+ constant_init(m.bias, 0)
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
+ constant_init(m.bias, 0)
+ constant_init(m.weight, 1.0)
+ else:
+ raise TypeError('pretrained must be a str or None')
+ def _pos_embeding(self, img, patched_img, pos_embed):
+ """Positiong embeding method.
+ Resize the pos_embed, if the input image size doesn't match
+ the training size.
+ Args:
+ img (torch.Tensor): The inference image tensor, the shape
+ must be [B, C, H, W].
+ patched_img (torch.Tensor): The patched image, it should be
+ shape of [B, L1, C].
+ pos_embed (torch.Tensor): The pos_embed weighs, it should be
+ shape of [B, L2, c].
+ Return:
+ torch.Tensor: The pos encoded image feature.
+ """
+ assert patched_img.ndim == 3 and pos_embed.ndim == 3, \
+ 'the shapes of patched_img and pos_embed must be [B, L, C]'
+ x_len, pos_len = patched_img.shape[1], pos_embed.shape[1]
+ if x_len != pos_len:
+ if pos_len == (self.img_size[0] // self.patch_size) * (
+ self.img_size[1] // self.patch_size) + 1:
+ pos_h = self.img_size[0] // self.patch_size
+ pos_w = self.img_size[1] // self.patch_size
+ else:
+ raise ValueError(
+ 'Unexpected shape of pos_embed, got {}.'.format(
+ pos_embed.shape))
+ pos_embed = self.resize_pos_embed(pos_embed, img.shape[2:],
+ (pos_h, pos_w), self.patch_size,
+ self.interpolate_mode)
+ return self.pos_drop(patched_img + pos_embed)
+ @staticmethod
+ def resize_pos_embed(pos_embed, input_shpae, pos_shape, patch_size, mode):
+ """Resize pos_embed weights.
+ Resize pos_embed using bicubic interpolate method.
+ Args:
+ pos_embed (torch.Tensor): pos_embed weights.
+ input_shpae (tuple): Tuple for (input_h, intput_w).
+ pos_shape (tuple): Tuple for (pos_h, pos_w).
+ patch_size (int): Patch size.
+ Return:
+ torch.Tensor: The resized pos_embed of shape [B, L_new, C]
+ """
+ assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
+ input_h, input_w = input_shpae
+ pos_h, pos_w = pos_shape
+ cls_token_weight = pos_embed[:, 0]
+ pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
+ pos_embed_weight = pos_embed_weight.reshape(
+ 1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
+ pos_embed_weight = F.interpolate(
+ pos_embed_weight,
+ size=[input_h // patch_size, input_w // patch_size],
+ align_corners=False,
+ mode=mode)
+ cls_token_weight = cls_token_weight.unsqueeze(1)
+ pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
+ pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
+ return pos_embed
+ def forward(self, inputs):
+ B = inputs.shape[0]
+ x = self.patch_embed(inputs)
+ cls_tokens = self.cls_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, x), dim=1)
+ x = self._pos_embeding(inputs, x, self.pos_embed)
+ if not self.with_cls_token:
+ # Remove class token for transformer input
+ x = x[:, 1:]
+ outs = []
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if i == len(self.blocks) - 1:
+ if self.final_norm:
+ x = self.norm(x)
+ if i in self.out_indices:
+ if self.with_cls_token:
+ # Remove class token and reshape token for decoder head
+ out = x[:, 1:]
+ else:
+ out = x
+ B, _, C = out.shape
+ out = out.reshape(B, inputs.shape[2] // self.patch_size,
+ inputs.shape[3] // self.patch_size,
+ C).permute(0, 3, 1, 2)
+ outs.append(out)
+ return tuple(outs)
+ def train(self, mode=True):
+ super(VisionTransformer, self).train(mode)
+ if mode and self.norm_eval:
+ for m in self.modules():
+ if isinstance(m, nn.LayerNorm):
+ m.eval()
diff --git a/ControlNet/annotator/uniformer/mmseg/models/builder.py b/ControlNet/annotator/uniformer/mmseg/models/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f5b971252bfc971c3ffbaa27746d69b1d3ea9fd
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/builder.py
@@ -0,0 +1,46 @@
+import warnings
+from annotator.uniformer.mmcv.cnn import MODELS as MMCV_MODELS
+from annotator.uniformer.mmcv.utils import Registry
+MODELS = Registry('models', parent=MMCV_MODELS)
+def build_backbone(cfg):
+ """Build backbone."""
+ return BACKBONES.build(cfg)
+def build_neck(cfg):
+ """Build neck."""
+ return NECKS.build(cfg)
+def build_head(cfg):
+ """Build head."""
+ return HEADS.build(cfg)
+def build_loss(cfg):
+ """Build loss."""
+ return LOSSES.build(cfg)
+def build_segmentor(cfg, train_cfg=None, test_cfg=None):
+ """Build segmentor."""
+ if train_cfg is not None or test_cfg is not None:
+ warnings.warn(
+ 'train_cfg and test_cfg is deprecated, '
+ 'please specify them in model', UserWarning)
+ assert cfg.get('train_cfg') is None or train_cfg is None, \
+ 'train_cfg specified in both outer field and model field '
+ assert cfg.get('test_cfg') is None or test_cfg is None, \
+ 'test_cfg specified in both outer field and model field '
+ return SEGMENTORS.build(
+ cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
diff --git a/ControlNet/annotator/uniformer/mmseg/models/decode_heads/__init__.py b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac66d3cfe0ea04af45c0f3594bf135841c3812e3
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/__init__.py
@@ -0,0 +1,28 @@
+from .ann_head import ANNHead
+from .apc_head import APCHead
+from .aspp_head import ASPPHead
+from .cc_head import CCHead
+from .da_head import DAHead
+from .dm_head import DMHead
+from .dnl_head import DNLHead
+from .ema_head import EMAHead
+from .enc_head import EncHead
+from .fcn_head import FCNHead
+from .fpn_head import FPNHead
+from .gc_head import GCHead
+from .lraspp_head import LRASPPHead
+from .nl_head import NLHead
+from .ocr_head import OCRHead
+# from .point_head import PointHead
+from .psa_head import PSAHead
+from .psp_head import PSPHead
+from .sep_aspp_head import DepthwiseSeparableASPPHead
+from .sep_fcn_head import DepthwiseSeparableFCNHead
+from .uper_head import UPerHead
+__all__ = [
+ 'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
+ 'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
+ 'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead',
+ 'APCHead', 'DMHead', 'LRASPPHead'
diff --git a/ControlNet/annotator/uniformer/mmseg/models/decode_heads/ann_head.py b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/ann_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..30aaacc2cafc568d3de71d1477b4de0dc0fea9d3
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/ann_head.py
@@ -0,0 +1,245 @@
+import torch
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import ConvModule
+from ..builder import HEADS
+from ..utils import SelfAttentionBlock as _SelfAttentionBlock
+from .decode_head import BaseDecodeHead
+class PPMConcat(nn.ModuleList):
+ """Pyramid Pooling Module that only concat the features of each layer.
+ Args:
+ pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
+ Module.
+ """
+ def __init__(self, pool_scales=(1, 3, 6, 8)):
+ super(PPMConcat, self).__init__(
+ [nn.AdaptiveAvgPool2d(pool_scale) for pool_scale in pool_scales])
+ def forward(self, feats):
+ """Forward function."""
+ ppm_outs = []
+ for ppm in self:
+ ppm_out = ppm(feats)
+ ppm_outs.append(ppm_out.view(*feats.shape[:2], -1))
+ concat_outs = torch.cat(ppm_outs, dim=2)
+ return concat_outs
+class SelfAttentionBlock(_SelfAttentionBlock):
+ """Make a ANN used SelfAttentionBlock.
+ Args:
+ low_in_channels (int): Input channels of lower level feature,
+ which is the key feature for self-attention.
+ high_in_channels (int): Input channels of higher level feature,
+ which is the query feature for self-attention.
+ channels (int): Output channels of key/query transform.
+ out_channels (int): Output channels.
+ share_key_query (bool): Whether share projection weight between key
+ and query projection.
+ query_scale (int): The scale of query feature map.
+ key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
+ Module of key feature.
+ conv_cfg (dict|None): Config of conv layers.
+ norm_cfg (dict|None): Config of norm layers.
+ act_cfg (dict|None): Config of activation layers.
+ """
+ def __init__(self, low_in_channels, high_in_channels, channels,
+ out_channels, share_key_query, query_scale, key_pool_scales,
+ conv_cfg, norm_cfg, act_cfg):
+ key_psp = PPMConcat(key_pool_scales)
+ if query_scale > 1:
+ query_downsample = nn.MaxPool2d(kernel_size=query_scale)
+ else:
+ query_downsample = None
+ super(SelfAttentionBlock, self).__init__(
+ key_in_channels=low_in_channels,
+ query_in_channels=high_in_channels,
+ channels=channels,
+ out_channels=out_channels,
+ share_key_query=share_key_query,
+ query_downsample=query_downsample,
+ key_downsample=key_psp,
+ key_query_num_convs=1,
+ key_query_norm=True,
+ value_out_num_convs=1,
+ value_out_norm=False,
+ matmul_norm=True,
+ with_out=True,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+class AFNB(nn.Module):
+ """Asymmetric Fusion Non-local Block(AFNB)
+ Args:
+ low_in_channels (int): Input channels of lower level feature,
+ which is the key feature for self-attention.
+ high_in_channels (int): Input channels of higher level feature,
+ which is the query feature for self-attention.
+ channels (int): Output channels of key/query transform.
+ out_channels (int): Output channels.
+ and query projection.
+ query_scales (tuple[int]): The scales of query feature map.
+ Default: (1,)
+ key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
+ Module of key feature.
+ conv_cfg (dict|None): Config of conv layers.
+ norm_cfg (dict|None): Config of norm layers.
+ act_cfg (dict|None): Config of activation layers.
+ """
+ def __init__(self, low_in_channels, high_in_channels, channels,
+ out_channels, query_scales, key_pool_scales, conv_cfg,
+ norm_cfg, act_cfg):
+ super(AFNB, self).__init__()
+ self.stages = nn.ModuleList()
+ for query_scale in query_scales:
+ self.stages.append(
+ SelfAttentionBlock(
+ low_in_channels=low_in_channels,
+ high_in_channels=high_in_channels,
+ channels=channels,
+ out_channels=out_channels,
+ share_key_query=False,
+ query_scale=query_scale,
+ key_pool_scales=key_pool_scales,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+ self.bottleneck = ConvModule(
+ out_channels + high_in_channels,
+ out_channels,
+ 1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=None)
+ def forward(self, low_feats, high_feats):
+ """Forward function."""
+ priors = [stage(high_feats, low_feats) for stage in self.stages]
+ context = torch.stack(priors, dim=0).sum(dim=0)
+ output = self.bottleneck(torch.cat([context, high_feats], 1))
+ return output
+class APNB(nn.Module):
+ """Asymmetric Pyramid Non-local Block (APNB)
+ Args:
+ in_channels (int): Input channels of key/query feature,
+ which is the key feature for self-attention.
+ channels (int): Output channels of key/query transform.
+ out_channels (int): Output channels.
+ query_scales (tuple[int]): The scales of query feature map.
+ Default: (1,)
+ key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
+ Module of key feature.
+ conv_cfg (dict|None): Config of conv layers.
+ norm_cfg (dict|None): Config of norm layers.
+ act_cfg (dict|None): Config of activation layers.
+ """
+ def __init__(self, in_channels, channels, out_channels, query_scales,
+ key_pool_scales, conv_cfg, norm_cfg, act_cfg):
+ super(APNB, self).__init__()
+ self.stages = nn.ModuleList()
+ for query_scale in query_scales:
+ self.stages.append(
+ SelfAttentionBlock(
+ low_in_channels=in_channels,
+ high_in_channels=in_channels,
+ channels=channels,
+ out_channels=out_channels,
+ share_key_query=True,
+ query_scale=query_scale,
+ key_pool_scales=key_pool_scales,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+ self.bottleneck = ConvModule(
+ 2 * in_channels,
+ out_channels,
+ 1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ def forward(self, feats):
+ """Forward function."""
+ priors = [stage(feats, feats) for stage in self.stages]
+ context = torch.stack(priors, dim=0).sum(dim=0)
+ output = self.bottleneck(torch.cat([context, feats], 1))
+ return output
+class ANNHead(BaseDecodeHead):
+ """Asymmetric Non-local Neural Networks for Semantic Segmentation.
+ This head is the implementation of `ANNNet
+ `_.
+ Args:
+ project_channels (int): Projection channels for Nonlocal.
+ query_scales (tuple[int]): The scales of query feature map.
+ Default: (1,)
+ key_pool_scales (tuple[int]): The pooling scales of key feature map.
+ Default: (1, 3, 6, 8).
+ """
+ def __init__(self,
+ project_channels,
+ query_scales=(1, ),
+ key_pool_scales=(1, 3, 6, 8),
+ **kwargs):
+ super(ANNHead, self).__init__(
+ input_transform='multiple_select', **kwargs)
+ assert len(self.in_channels) == 2
+ low_in_channels, high_in_channels = self.in_channels
+ self.project_channels = project_channels
+ self.fusion = AFNB(
+ low_in_channels=low_in_channels,
+ high_in_channels=high_in_channels,
+ out_channels=high_in_channels,
+ channels=project_channels,
+ query_scales=query_scales,
+ key_pool_scales=key_pool_scales,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.bottleneck = ConvModule(
+ high_in_channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.context = APNB(
+ in_channels=self.channels,
+ out_channels=self.channels,
+ channels=project_channels,
+ query_scales=query_scales,
+ key_pool_scales=key_pool_scales,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ def forward(self, inputs):
+ """Forward function."""
+ low_feats, high_feats = self._transform_inputs(inputs)
+ output = self.fusion(low_feats, high_feats)
+ output = self.dropout(output)
+ output = self.bottleneck(output)
+ output = self.context(output)
+ output = self.cls_seg(output)
+ return output
diff --git a/ControlNet/annotator/uniformer/mmseg/models/decode_heads/apc_head.py b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/apc_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7038bdbe0edf2a1f184b6899486d2d190dda076
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/apc_head.py
@@ -0,0 +1,158 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from annotator.uniformer.mmcv.cnn import ConvModule
+from annotator.uniformer.mmseg.ops import resize
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+class ACM(nn.Module):
+ """Adaptive Context Module used in APCNet.
+ Args:
+ pool_scale (int): Pooling scale used in Adaptive Context
+ Module to extract region features.
+ fusion (bool): Add one conv to fuse residual feature.
+ in_channels (int): Input channels.
+ channels (int): Channels after modules, before conv_seg.
+ conv_cfg (dict | None): Config of conv layers.
+ norm_cfg (dict | None): Config of norm layers.
+ act_cfg (dict): Config of activation layers.
+ """
+ def __init__(self, pool_scale, fusion, in_channels, channels, conv_cfg,
+ norm_cfg, act_cfg):
+ super(ACM, self).__init__()
+ self.pool_scale = pool_scale
+ self.fusion = fusion
+ self.in_channels = in_channels
+ self.channels = channels
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.pooled_redu_conv = ConvModule(
+ self.in_channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.input_redu_conv = ConvModule(
+ self.in_channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.global_info = ConvModule(
+ self.channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.gla = nn.Conv2d(self.channels, self.pool_scale**2, 1, 1, 0)
+ self.residual_conv = ConvModule(
+ self.channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ if self.fusion:
+ self.fusion_conv = ConvModule(
+ self.channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ def forward(self, x):
+ """Forward function."""
+ pooled_x = F.adaptive_avg_pool2d(x, self.pool_scale)
+ # [batch_size, channels, h, w]
+ x = self.input_redu_conv(x)
+ # [batch_size, channels, pool_scale, pool_scale]
+ pooled_x = self.pooled_redu_conv(pooled_x)
+ batch_size = x.size(0)
+ # [batch_size, pool_scale * pool_scale, channels]
+ pooled_x = pooled_x.view(batch_size, self.channels,
+ -1).permute(0, 2, 1).contiguous()
+ # [batch_size, h * w, pool_scale * pool_scale]
+ affinity_matrix = self.gla(x + resize(
+ self.global_info(F.adaptive_avg_pool2d(x, 1)), size=x.shape[2:])
+ ).permute(0, 2, 3, 1).reshape(
+ batch_size, -1, self.pool_scale**2)
+ affinity_matrix = F.sigmoid(affinity_matrix)
+ # [batch_size, h * w, channels]
+ z_out = torch.matmul(affinity_matrix, pooled_x)
+ # [batch_size, channels, h * w]
+ z_out = z_out.permute(0, 2, 1).contiguous()
+ # [batch_size, channels, h, w]
+ z_out = z_out.view(batch_size, self.channels, x.size(2), x.size(3))
+ z_out = self.residual_conv(z_out)
+ z_out = F.relu(z_out + x)
+ if self.fusion:
+ z_out = self.fusion_conv(z_out)
+ return z_out
+class APCHead(BaseDecodeHead):
+ """Adaptive Pyramid Context Network for Semantic Segmentation.
+ This head is the implementation of
+ `APCNet `_.
+ Args:
+ pool_scales (tuple[int]): Pooling scales used in Adaptive Context
+ Module. Default: (1, 2, 3, 6).
+ fusion (bool): Add one conv to fuse residual feature.
+ """
+ def __init__(self, pool_scales=(1, 2, 3, 6), fusion=True, **kwargs):
+ super(APCHead, self).__init__(**kwargs)
+ assert isinstance(pool_scales, (list, tuple))
+ self.pool_scales = pool_scales
+ self.fusion = fusion
+ acm_modules = []
+ for pool_scale in self.pool_scales:
+ acm_modules.append(
+ ACM(pool_scale,
+ self.fusion,
+ self.in_channels,
+ self.channels,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ self.acm_modules = nn.ModuleList(acm_modules)
+ self.bottleneck = ConvModule(
+ self.in_channels + len(pool_scales) * self.channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ acm_outs = [x]
+ for acm_module in self.acm_modules:
+ acm_outs.append(acm_module(x))
+ acm_outs = torch.cat(acm_outs, dim=1)
+ output = self.bottleneck(acm_outs)
+ output = self.cls_seg(output)
+ return output
diff --git a/ControlNet/annotator/uniformer/mmseg/models/decode_heads/aspp_head.py b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/aspp_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa914b5bb25124d1ff199553d96713d6a80484c0
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/aspp_head.py
@@ -0,0 +1,107 @@
+import torch
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import ConvModule
+from annotator.uniformer.mmseg.ops import resize
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+class ASPPModule(nn.ModuleList):
+ """Atrous Spatial Pyramid Pooling (ASPP) Module.
+ Args:
+ dilations (tuple[int]): Dilation rate of each layer.
+ in_channels (int): Input channels.
+ channels (int): Channels after modules, before conv_seg.
+ conv_cfg (dict|None): Config of conv layers.
+ norm_cfg (dict|None): Config of norm layers.
+ act_cfg (dict): Config of activation layers.
+ """
+ def __init__(self, dilations, in_channels, channels, conv_cfg, norm_cfg,
+ act_cfg):
+ super(ASPPModule, self).__init__()
+ self.dilations = dilations
+ self.in_channels = in_channels
+ self.channels = channels
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ for dilation in dilations:
+ self.append(
+ ConvModule(
+ self.in_channels,
+ self.channels,
+ 1 if dilation == 1 else 3,
+ dilation=dilation,
+ padding=0 if dilation == 1 else dilation,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ def forward(self, x):
+ """Forward function."""
+ aspp_outs = []
+ for aspp_module in self:
+ aspp_outs.append(aspp_module(x))
+ return aspp_outs
+class ASPPHead(BaseDecodeHead):
+ """Rethinking Atrous Convolution for Semantic Image Segmentation.
+ This head is the implementation of `DeepLabV3
+ `_.
+ Args:
+ dilations (tuple[int]): Dilation rates for ASPP module.
+ Default: (1, 6, 12, 18).
+ """
+ def __init__(self, dilations=(1, 6, 12, 18), **kwargs):
+ super(ASPPHead, self).__init__(**kwargs)
+ assert isinstance(dilations, (list, tuple))
+ self.dilations = dilations
+ self.image_pool = nn.Sequential(
+ nn.AdaptiveAvgPool2d(1),
+ ConvModule(
+ self.in_channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ self.aspp_modules = ASPPModule(
+ dilations,
+ self.in_channels,
+ self.channels,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.bottleneck = ConvModule(
+ (len(dilations) + 1) * self.channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ aspp_outs = [
+ resize(
+ self.image_pool(x),
+ size=x.size()[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ ]
+ aspp_outs.extend(self.aspp_modules(x))
+ aspp_outs = torch.cat(aspp_outs, dim=1)
+ output = self.bottleneck(aspp_outs)
+ output = self.cls_seg(output)
+ return output
diff --git a/ControlNet/annotator/uniformer/mmseg/models/decode_heads/cascade_decode_head.py b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/cascade_decode_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..d02122ca0e68743b1bf7a893afae96042f23838c
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/cascade_decode_head.py
@@ -0,0 +1,57 @@
+from abc import ABCMeta, abstractmethod
+from .decode_head import BaseDecodeHead
+class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta):
+ """Base class for cascade decode head used in
+ :class:`CascadeEncoderDecoder."""
+ def __init__(self, *args, **kwargs):
+ super(BaseCascadeDecodeHead, self).__init__(*args, **kwargs)
+ @abstractmethod
+ def forward(self, inputs, prev_output):
+ """Placeholder of forward function."""
+ pass
+ def forward_train(self, inputs, prev_output, img_metas, gt_semantic_seg,
+ train_cfg):
+ """Forward function for training.
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ prev_output (Tensor): The output of previous decode head.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
+ gt_semantic_seg (Tensor): Semantic segmentation masks
+ used if the architecture supports semantic segmentation task.
+ train_cfg (dict): The training config.
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+ seg_logits = self.forward(inputs, prev_output)
+ losses = self.losses(seg_logits, gt_semantic_seg)
+ return losses
+ def forward_test(self, inputs, prev_output, img_metas, test_cfg):
+ """Forward function for testing.
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ prev_output (Tensor): The output of previous decode head.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
+ test_cfg (dict): The testing config.
+ Returns:
+ Tensor: Output segmentation map.
+ """
+ return self.forward(inputs, prev_output)
diff --git a/ControlNet/annotator/uniformer/mmseg/models/decode_heads/cc_head.py b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/cc_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b9abb4e747f92657f4220b29788539340986c00
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/cc_head.py
@@ -0,0 +1,42 @@
+import torch
+from ..builder import HEADS
+from .fcn_head import FCNHead
+ from annotator.uniformer.mmcv.ops import CrissCrossAttention
+except ModuleNotFoundError:
+ CrissCrossAttention = None
+class CCHead(FCNHead):
+ """CCNet: Criss-Cross Attention for Semantic Segmentation.
+ This head is the implementation of `CCNet
+ `_.
+ Args:
+ recurrence (int): Number of recurrence of Criss Cross Attention
+ module. Default: 2.
+ """
+ def __init__(self, recurrence=2, **kwargs):
+ if CrissCrossAttention is None:
+ raise RuntimeError('Please install mmcv-full for '
+ 'CrissCrossAttention ops')
+ super(CCHead, self).__init__(num_convs=2, **kwargs)
+ self.recurrence = recurrence
+ self.cca = CrissCrossAttention(self.channels)
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ output = self.convs[0](x)
+ for _ in range(self.recurrence):
+ output = self.cca(output)
+ output = self.convs[1](output)
+ if self.concat_input:
+ output = self.conv_cat(torch.cat([x, output], dim=1))
+ output = self.cls_seg(output)
+ return output
diff --git a/ControlNet/annotator/uniformer/mmseg/models/decode_heads/da_head.py b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/da_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cd49fcfdc7c0a70f9485cc71843dcf3e0cb1774
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/da_head.py
@@ -0,0 +1,178 @@
+import torch
+import torch.nn.functional as F
+from annotator.uniformer.mmcv.cnn import ConvModule, Scale
+from torch import nn
+from annotator.uniformer.mmseg.core import add_prefix
+from ..builder import HEADS
+from ..utils import SelfAttentionBlock as _SelfAttentionBlock
+from .decode_head import BaseDecodeHead
+class PAM(_SelfAttentionBlock):
+ """Position Attention Module (PAM)
+ Args:
+ in_channels (int): Input channels of key/query feature.
+ channels (int): Output channels of key/query transform.
+ """
+ def __init__(self, in_channels, channels):
+ super(PAM, self).__init__(
+ key_in_channels=in_channels,
+ query_in_channels=in_channels,
+ channels=channels,
+ out_channels=in_channels,
+ share_key_query=False,
+ query_downsample=None,
+ key_downsample=None,
+ key_query_num_convs=1,
+ key_query_norm=False,
+ value_out_num_convs=1,
+ value_out_norm=False,
+ matmul_norm=False,
+ with_out=False,
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=None)
+ self.gamma = Scale(0)
+ def forward(self, x):
+ """Forward function."""
+ out = super(PAM, self).forward(x, x)
+ out = self.gamma(out) + x
+ return out
+class CAM(nn.Module):
+ """Channel Attention Module (CAM)"""
+ def __init__(self):
+ super(CAM, self).__init__()
+ self.gamma = Scale(0)
+ def forward(self, x):
+ """Forward function."""
+ batch_size, channels, height, width = x.size()
+ proj_query = x.view(batch_size, channels, -1)
+ proj_key = x.view(batch_size, channels, -1).permute(0, 2, 1)
+ energy = torch.bmm(proj_query, proj_key)
+ energy_new = torch.max(
+ energy, -1, keepdim=True)[0].expand_as(energy) - energy
+ attention = F.softmax(energy_new, dim=-1)
+ proj_value = x.view(batch_size, channels, -1)
+ out = torch.bmm(attention, proj_value)
+ out = out.view(batch_size, channels, height, width)
+ out = self.gamma(out) + x
+ return out
+class DAHead(BaseDecodeHead):
+ """Dual Attention Network for Scene Segmentation.
+ This head is the implementation of `DANet
+ `_.
+ Args:
+ pam_channels (int): The channels of Position Attention Module(PAM).
+ """
+ def __init__(self, pam_channels, **kwargs):
+ super(DAHead, self).__init__(**kwargs)
+ self.pam_channels = pam_channels
+ self.pam_in_conv = ConvModule(
+ self.in_channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.pam = PAM(self.channels, pam_channels)
+ self.pam_out_conv = ConvModule(
+ self.channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.pam_conv_seg = nn.Conv2d(
+ self.channels, self.num_classes, kernel_size=1)
+ self.cam_in_conv = ConvModule(
+ self.in_channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.cam = CAM()
+ self.cam_out_conv = ConvModule(
+ self.channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.cam_conv_seg = nn.Conv2d(
+ self.channels, self.num_classes, kernel_size=1)
+ def pam_cls_seg(self, feat):
+ """PAM feature classification."""
+ if self.dropout is not None:
+ feat = self.dropout(feat)
+ output = self.pam_conv_seg(feat)
+ return output
+ def cam_cls_seg(self, feat):
+ """CAM feature classification."""
+ if self.dropout is not None:
+ feat = self.dropout(feat)
+ output = self.cam_conv_seg(feat)
+ return output
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ pam_feat = self.pam_in_conv(x)
+ pam_feat = self.pam(pam_feat)
+ pam_feat = self.pam_out_conv(pam_feat)
+ pam_out = self.pam_cls_seg(pam_feat)
+ cam_feat = self.cam_in_conv(x)
+ cam_feat = self.cam(cam_feat)
+ cam_feat = self.cam_out_conv(cam_feat)
+ cam_out = self.cam_cls_seg(cam_feat)
+ feat_sum = pam_feat + cam_feat
+ pam_cam_out = self.cls_seg(feat_sum)
+ return pam_cam_out, pam_out, cam_out
+ def forward_test(self, inputs, img_metas, test_cfg):
+ """Forward function for testing, only ``pam_cam`` is used."""
+ return self.forward(inputs)[0]
+ def losses(self, seg_logit, seg_label):
+ """Compute ``pam_cam``, ``pam``, ``cam`` loss."""
+ pam_cam_seg_logit, pam_seg_logit, cam_seg_logit = seg_logit
+ loss = dict()
+ loss.update(
+ add_prefix(
+ super(DAHead, self).losses(pam_cam_seg_logit, seg_label),
+ 'pam_cam'))
+ loss.update(
+ add_prefix(
+ super(DAHead, self).losses(pam_seg_logit, seg_label), 'pam'))
+ loss.update(
+ add_prefix(
+ super(DAHead, self).losses(cam_seg_logit, seg_label), 'cam'))
+ return loss
diff --git a/ControlNet/annotator/uniformer/mmseg/models/decode_heads/decode_head.py b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/decode_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..88a661b8f6fec5d4c031d3d85e80777ee63951a6
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/decode_head.py
@@ -0,0 +1,234 @@
+from abc import ABCMeta, abstractmethod
+import torch
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import normal_init
+from annotator.uniformer.mmcv.runner import auto_fp16, force_fp32
+from annotator.uniformer.mmseg.core import build_pixel_sampler
+from annotator.uniformer.mmseg.ops import resize
+from ..builder import build_loss
+from ..losses import accuracy
+class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
+ """Abstract base class for decode heads.
+ Subclasses implement ``forward``; this class provides input selection,
+ the final per-pixel classifier (``cls_seg``) and the loss computation.
+ Args:
+ in_channels (int|Sequence[int]): Input channels.
+ channels (int): Channels after modules, before conv_seg.
+ num_classes (int): Number of classes.
+ dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
+ conv_cfg (dict|None): Config of conv layers. Default: None.
+ norm_cfg (dict|None): Config of norm layers. Default: None.
+ act_cfg (dict): Config of activation layers.
+ Default: dict(type='ReLU')
+ in_index (int|Sequence[int]): Input feature index. Default: -1
+ input_transform (str|None): Transformation type of input features.
+ Options: 'resize_concat', 'multiple_select', None.
+ 'resize_concat': Multiple feature maps will be resized to the
+ same size as the first one and then concatenated together.
+ Usually used in FCN head of HRNet.
+ 'multiple_select': Multiple feature maps will be bundled into
+ a list and passed into decode head.
+ None: Only one select feature map is allowed.
+ Default: None.
+ loss_decode (dict): Config of decode loss.
+ Default: dict(type='CrossEntropyLoss').
+ ignore_index (int | None): The label index to be ignored. When using
+ masked BCE loss, ignore_index should be set to None. Default: 255
+ sampler (dict|None): The config of segmentation map sampler.
+ Default: None.
+ align_corners (bool): align_corners argument of F.interpolate.
+ Default: False.
+ """
+ # NOTE(review): ``act_cfg`` and ``loss_decode`` default to dict objects that
+ # are created once at definition time and shared between calls; confirm
+ # nothing downstream mutates them.
+ def __init__(self,
+ in_channels,
+ channels,
+ *,
+ num_classes,
+ dropout_ratio=0.1,
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=dict(type='ReLU'),
+ in_index=-1,
+ input_transform=None,
+ loss_decode=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=False,
+ loss_weight=1.0),
+ ignore_index=255,
+ sampler=None,
+ align_corners=False):
+ super(BaseDecodeHead, self).__init__()
+ # Validates the arguments and sets self.in_channels, self.in_index and
+ # self.input_transform.
+ self._init_inputs(in_channels, in_index, input_transform)
+ self.channels = channels
+ self.num_classes = num_classes
+ self.dropout_ratio = dropout_ratio
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ # Same value already stored by _init_inputs above.
+ self.in_index = in_index
+ self.loss_decode = build_loss(loss_decode)
+ self.ignore_index = ignore_index
+ self.align_corners = align_corners
+ if sampler is not None:
+ self.sampler = build_pixel_sampler(sampler, context=self)
+ else:
+ self.sampler = None
+ # Final 1x1 classifier: channels -> num_classes.
+ self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
+ # Dropout is optional; cls_seg checks for None before applying it.
+ if dropout_ratio > 0:
+ self.dropout = nn.Dropout2d(dropout_ratio)
+ else:
+ self.dropout = None
+ self.fp16_enabled = False
+ def extra_repr(self):
+ """Return the extra fields shown in the module's repr string."""
+ s = f'input_transform={self.input_transform}, ' \
+ f'ignore_index={self.ignore_index}, ' \
+ f'align_corners={self.align_corners}'
+ return s
+ def _init_inputs(self, in_channels, in_index, input_transform):
+ """Check and initialize input transforms.
+ The in_channels, in_index and input_transform must match.
+ Specifically, when input_transform is None, only single feature map
+ will be selected. So in_channels and in_index must be of type int.
+ When input_transform is not None, in_channels and in_index must be
+ lists or tuples of the same length.
+ Args:
+ in_channels (int|Sequence[int]): Input channels.
+ in_index (int|Sequence[int]): Input feature index.
+ input_transform (str|None): Transformation type of input features.
+ Options: 'resize_concat', 'multiple_select', None.
+ 'resize_concat': Multiple feature maps will be resized to the
+ same size as the first one and then concatenated together.
+ Usually used in FCN head of HRNet.
+ 'multiple_select': Multiple feature maps will be bundled into
+ a list and passed into decode head.
+ None: Only one select feature map is allowed.
+ """
+ if input_transform is not None:
+ assert input_transform in ['resize_concat', 'multiple_select']
+ self.input_transform = input_transform
+ self.in_index = in_index
+ if input_transform is not None:
+ assert isinstance(in_channels, (list, tuple))
+ assert isinstance(in_index, (list, tuple))
+ assert len(in_channels) == len(in_index)
+ if input_transform == 'resize_concat':
+ # Concatenation along channels: the head sees the channel sum.
+ self.in_channels = sum(in_channels)
+ else:
+ # 'multiple_select': keep the per-level channel sequence.
+ self.in_channels = in_channels
+ else:
+ assert isinstance(in_channels, int)
+ assert isinstance(in_index, int)
+ self.in_channels = in_channels
+ def init_weights(self):
+ """Initialize weights of classification layer."""
+ normal_init(self.conv_seg, mean=0, std=0.01)
+ def _transform_inputs(self, inputs):
+ """Transform inputs for decoder.
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ Returns:
+ Tensor | list[Tensor]: A single tensor for 'resize_concat' and
+ for no transform; a list of the selected tensors for
+ 'multiple_select'.
+ """
+ if self.input_transform == 'resize_concat':
+ inputs = [inputs[i] for i in self.in_index]
+ # Resize every selected level to the spatial size of the first one
+ # so they can be concatenated along the channel dimension.
+ upsampled_inputs = [
+ resize(
+ input=x,
+ size=inputs[0].shape[2:],
+ mode='bilinear',
+ align_corners=self.align_corners) for x in inputs
+ ]
+ inputs = torch.cat(upsampled_inputs, dim=1)
+ elif self.input_transform == 'multiple_select':
+ inputs = [inputs[i] for i in self.in_index]
+ else:
+ inputs = inputs[self.in_index]
+ return inputs
+ @auto_fp16()
+ @abstractmethod
+ def forward(self, inputs):
+ """Placeholder of forward function; subclasses must override it."""
+ pass
+ def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
+ """Forward function for training.
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
+ gt_semantic_seg (Tensor): Semantic segmentation masks
+ used if the architecture supports semantic segmentation task.
+ train_cfg (dict): The training config.
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+ # img_metas and train_cfg are not read by this base implementation.
+ seg_logits = self.forward(inputs)
+ losses = self.losses(seg_logits, gt_semantic_seg)
+ return losses
+ def forward_test(self, inputs, img_metas, test_cfg):
+ """Forward function for testing.
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
+ test_cfg (dict): The testing config.
+ Returns:
+ Tensor: Output segmentation map.
+ """
+ # img_metas and test_cfg are not read by this base implementation.
+ return self.forward(inputs)
+ def cls_seg(self, feat):
+ """Classify each pixel.
+ Applies ``self.dropout`` when configured, then ``self.conv_seg``.
+ """
+ if self.dropout is not None:
+ feat = self.dropout(feat)
+ output = self.conv_seg(feat)
+ return output
+ @force_fp32(apply_to=('seg_logit', ))
+ def losses(self, seg_logit, seg_label):
+ """Compute segmentation loss.
+ Args:
+ seg_logit (Tensor): Prediction of the head.
+ seg_label (Tensor): Ground truth; presumably (N, 1, H, W) given
+ the ``shape[2:]`` / ``squeeze(1)`` usage below - TODO confirm.
+ Returns:
+ dict: 'loss_seg' from ``self.loss_decode`` and 'acc_seg' from
+ ``accuracy``.
+ """
+ loss = dict()
+ # Bring the prediction to the label's spatial size before scoring.
+ seg_logit = resize(
+ input=seg_logit,
+ size=seg_label.shape[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ if self.sampler is not None:
+ seg_weight = self.sampler.sample(seg_logit, seg_label)
+ else:
+ seg_weight = None
+ # Drop dimension 1 of the label (only removed if it has size 1).
+ seg_label = seg_label.squeeze(1)
+ loss['loss_seg'] = self.loss_decode(
+ seg_logit,
+ seg_label,
+ weight=seg_weight,
+ ignore_index=self.ignore_index)
+ loss['acc_seg'] = accuracy(seg_logit, seg_label)
+ return loss
diff --git a/ControlNet/annotator/uniformer/mmseg/models/decode_heads/dm_head.py b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/dm_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..19c963923126b53ce22f60813540a35badf24b3d
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/dm_head.py
@@ -0,0 +1,140 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from annotator.uniformer.mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+class DCM(nn.Module):
+ """Dynamic Convolutional Module used in DMNet.
+ A per-sample, per-channel convolution kernel is generated from an
+ adaptively pooled copy of the input and applied to the channel-reduced
+ input as a grouped (depthwise-style) convolution.
+ Args:
+ filter_size (int): The filter size of generated convolution kernel
+ used in Dynamic Convolutional Module.
+ fusion (bool): Add one conv to fuse DCM output feature.
+ in_channels (int): Input channels.
+ channels (int): Channels after modules, before conv_seg.
+ conv_cfg (dict | None): Config of conv layers.
+ norm_cfg (dict | None): Config of norm layers.
+ act_cfg (dict): Config of activation layers.
+ """
+ def __init__(self, filter_size, fusion, in_channels, channels, conv_cfg,
+ norm_cfg, act_cfg):
+ super(DCM, self).__init__()
+ self.filter_size = filter_size
+ self.fusion = fusion
+ self.in_channels = in_channels
+ self.channels = channels
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ # 1x1 conv (stride 1, padding 0) that turns the pooled input into the
+ # dynamic filter weights.
+ self.filter_gen_conv = nn.Conv2d(self.in_channels, self.channels, 1, 1,
+ 0)
+ # 1x1 conv reducing the input from in_channels to channels.
+ self.input_redu_conv = ConvModule(
+ self.in_channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ # Norm applied after the dynamic convolution; skipped without norm_cfg.
+ if self.norm_cfg is not None:
+ self.norm = build_norm_layer(self.norm_cfg, self.channels)[1]
+ else:
+ self.norm = None
+ self.activate = build_activation_layer(self.act_cfg)
+ if self.fusion:
+ self.fusion_conv = ConvModule(
+ self.channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ def forward(self, x):
+ """Forward function.
+ Args:
+ x (Tensor): Input feature, unpacked below as (b, c, h, w) after
+ channel reduction.
+ Returns:
+ Tensor: Feature viewed as (b, c, h, w) with c = self.channels.
+ """
+ # Pool the input to filter_size x filter_size and project it to obtain
+ # one kernel per (sample, channel) pair.
+ generated_filter = self.filter_gen_conv(
+ F.adaptive_avg_pool2d(x, self.filter_size))
+ x = self.input_redu_conv(x)
+ b, c, h, w = x.shape
+ # [1, b * c, h, w], c = self.channels
+ # Folding the batch into channels lets a single grouped conv apply a
+ # different kernel to every sample.
+ x = x.view(1, b * c, h, w)
+ # [b * c, 1, filter_size, filter_size]
+ generated_filter = generated_filter.view(b * c, 1, self.filter_size,
+ self.filter_size)
+ # Zero-pad so the un-padded conv below keeps the h x w size; an even
+ # filter_size gets one extra pixel on the left/top side.
+ pad = (self.filter_size - 1) // 2
+ if (self.filter_size - 1) % 2 == 0:
+ p2d = (pad, pad, pad, pad)
+ else:
+ p2d = (pad + 1, pad, pad + 1, pad)
+ x = F.pad(input=x, pad=p2d, mode='constant', value=0)
+ # [1, b * c, h, w]
+ output = F.conv2d(input=x, weight=generated_filter, groups=b * c)
+ # [b, c, h, w]
+ output = output.view(b, c, h, w)
+ if self.norm is not None:
+ output = self.norm(output)
+ output = self.activate(output)
+ if self.fusion:
+ output = self.fusion_conv(output)
+ return output
+class DMHead(BaseDecodeHead):
+ """Dynamic Multi-scale Filters for Semantic Segmentation.
+ This head is the implementation of DMNet.
+ One ``DCM`` is built per entry of ``filter_sizes``; their outputs are
+ concatenated with the input and fused by a 3x3 bottleneck conv.
+ Args:
+ filter_sizes (tuple[int]): The size of generated convolutional filters
+ used in Dynamic Convolutional Module. Default: (1, 3, 5, 7).
+ fusion (bool): Add one conv to fuse DCM output feature.
+ Default: False.
+ """
+ def __init__(self, filter_sizes=(1, 3, 5, 7), fusion=False, **kwargs):
+ super(DMHead, self).__init__(**kwargs)
+ assert isinstance(filter_sizes, (list, tuple))
+ self.filter_sizes = filter_sizes
+ self.fusion = fusion
+ dcm_modules = []
+ for filter_size in self.filter_sizes:
+ dcm_modules.append(
+ DCM(filter_size,
+ self.fusion,
+ self.in_channels,
+ self.channels,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ self.dcm_modules = nn.ModuleList(dcm_modules)
+ # Input width = raw input channels plus one `channels`-wide map per DCM,
+ # matching the concatenation done in forward.
+ self.bottleneck = ConvModule(
+ self.in_channels + len(filter_sizes) * self.channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ def forward(self, inputs):
+ """Forward function.
+ Returns:
+ Tensor: Output of ``cls_seg`` on the fused feature.
+ """
+ x = self._transform_inputs(inputs)
+ # The untouched input is kept as the first element of the concat.
+ dcm_outs = [x]
+ for dcm_module in self.dcm_modules:
+ dcm_outs.append(dcm_module(x))
+ dcm_outs = torch.cat(dcm_outs, dim=1)
+ output = self.bottleneck(dcm_outs)
+ output = self.cls_seg(output)
+ return output
diff --git a/ControlNet/annotator/uniformer/mmseg/models/decode_heads/dnl_head.py b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/dnl_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..333280c5947066fd3c7ebcfe302a0e7ad65480d5
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/dnl_head.py
@@ -0,0 +1,131 @@
+import torch
+from annotator.uniformer.mmcv.cnn import NonLocal2d
+from torch import nn
+from ..builder import HEADS
+from .fcn_head import FCNHead
+class DisentangledNonLocal2d(NonLocal2d):
+ """Disentangled Non-Local Blocks.
+ Extends ``NonLocal2d`` with a temperature-scaled pairwise term computed
+ on mean-subtracted embeddings, plus a unary term derived from a 1x1
+ ``conv_mask``.
+ Args:
+ temperature (float): Temperature to adjust attention. Required
+ keyword-only argument (no default at this level).
+ """
+ def __init__(self, *arg, temperature, **kwargs):
+ super().__init__(*arg, **kwargs)
+ self.temperature = temperature
+ # Single-channel 1x1 conv producing the logits of the unary mask.
+ self.conv_mask = nn.Conv2d(self.in_channels, 1, kernel_size=1)
+ def embedded_gaussian(self, theta_x, phi_x):
+ """Embedded gaussian with temperature.
+ Returns:
+ Tensor: ``softmax(theta_x @ phi_x / scale / temperature)`` taken
+ over the last dimension, where ``scale`` is
+ ``sqrt(theta_x.shape[-1])`` when ``self.use_scale`` is set.
+ """
+ # NonLocal2d pairwise_weight: [N, HxW, HxW]
+ pairwise_weight = torch.matmul(theta_x, phi_x)
+ if self.use_scale:
+ # theta_x.shape[-1] is `self.inter_channels`
+ pairwise_weight /= theta_x.shape[-1]**0.5
+ pairwise_weight /= self.temperature
+ pairwise_weight = pairwise_weight.softmax(dim=-1)
+ return pairwise_weight
+ def forward(self, x):
+ """Forward function.
+ Args:
+ x (Tensor): Input feature of shape [N, C, H, W]. It is not
+ modified by this method.
+ Returns:
+ Tensor: ``x + conv_out(pairwise_term + unary_term)``.
+ """
+ # x: [N, C, H, W]
+ n = x.size(0)
+ # g_x: [N, HxW, C]
+ g_x = self.g(x).view(n, self.inter_channels, -1)
+ g_x = g_x.permute(0, 2, 1)
+ # theta_x: [N, HxW, C], phi_x: [N, C, HxW]
+ if self.mode == 'gaussian':
+ theta_x = x.view(n, self.in_channels, -1)
+ theta_x = theta_x.permute(0, 2, 1)
+ if self.sub_sample:
+ phi_x = self.phi(x).view(n, self.in_channels, -1)
+ else:
+ phi_x = x.view(n, self.in_channels, -1)
+ elif self.mode == 'concatenation':
+ theta_x = self.theta(x).view(n, self.inter_channels, -1, 1)
+ phi_x = self.phi(x).view(n, self.inter_channels, 1, -1)
+ else:
+ theta_x = self.theta(x).view(n, self.inter_channels, -1)
+ theta_x = theta_x.permute(0, 2, 1)
+ phi_x = self.phi(x).view(n, self.inter_channels, -1)
+ # subtract mean
+ # Done out-of-place on purpose: in 'gaussian' mode theta_x (and phi_x
+ # without sub_sample) are views of ``x``, so an in-place ``-=`` would
+ # overwrite the input that conv_mask and the residual sum reuse below.
+ theta_x = theta_x - theta_x.mean(dim=-2, keepdim=True)
+ phi_x = phi_x - phi_x.mean(dim=-1, keepdim=True)
+ pairwise_func = getattr(self, self.mode)
+ # pairwise_weight: [N, HxW, HxW]
+ pairwise_weight = pairwise_func(theta_x, phi_x)
+ # y: [N, HxW, C]
+ y = torch.matmul(pairwise_weight, g_x)
+ # y: [N, C, H, W]
+ y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels,
+ *x.size()[2:])
+ # unary_mask: [N, 1, HxW]
+ unary_mask = self.conv_mask(x)
+ unary_mask = unary_mask.view(n, 1, -1)
+ unary_mask = unary_mask.softmax(dim=-1)
+ # unary_x: [N, 1, C]
+ unary_x = torch.matmul(unary_mask, g_x)
+ # unary_x: [N, C, 1, 1]
+ unary_x = unary_x.permute(0, 2, 1).contiguous().reshape(
+ n, self.inter_channels, 1, 1)
+ # unary_x broadcasts over H x W when added to y.
+ output = x + self.conv_out(y + unary_x)
+ return output
+class DNLHead(FCNHead):
+ """Disentangled Non-Local Neural Networks.
+ This head is the implementation of DNLNet.
+ Args:
+ reduction (int): Reduction factor of projection transform. Default: 2.
+ use_scale (bool): Whether to scale pairwise_weight by
+ sqrt(1/inter_channels). Default: True.
+ mode (str): The nonlocal mode. Options are 'embedded_gaussian',
+ 'dot_product'. Default: 'embedded_gaussian'.
+ temperature (float): Temperature to adjust attention. Default: 0.05
+ """
+ def __init__(self,
+ reduction=2,
+ use_scale=True,
+ mode='embedded_gaussian',
+ temperature=0.05,
+ **kwargs):
+ # num_convs is fixed to 2: forward indexes self.convs[0] and [1].
+ super(DNLHead, self).__init__(num_convs=2, **kwargs)
+ self.reduction = reduction
+ self.use_scale = use_scale
+ self.mode = mode
+ self.temperature = temperature
+ self.dnl_block = DisentangledNonLocal2d(
+ in_channels=self.channels,
+ reduction=self.reduction,
+ use_scale=self.use_scale,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ mode=self.mode,
+ temperature=self.temperature)
+ def forward(self, inputs):
+ """Forward function.
+ The DNL block is inserted between the two convs of the FCN head.
+ """
+ x = self._transform_inputs(inputs)
+ output = self.convs[0](x)
+ output = self.dnl_block(output)
+ output = self.convs[1](output)
+ if self.concat_input:
+ output = self.conv_cat(torch.cat([x, output], dim=1))
+ output = self.cls_seg(output)
+ return output
diff --git a/ControlNet/annotator/uniformer/mmseg/models/decode_heads/ema_head.py b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/ema_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..12267cb40569d2b5a4a2955a6dc2671377ff5e0a
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/ema_head.py
@@ -0,0 +1,168 @@
+import math
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+from annotator.uniformer.mmcv.cnn import ConvModule
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+def reduce_mean(tensor):
+ """Reduce mean when distributed training.
+ Returns ``tensor`` itself when torch.distributed is unavailable or not
+ initialized. Otherwise returns a clone that has been divided by the
+ world size and sum-reduced across processes; the argument is left
+ untouched.
+ """
+ if not (dist.is_available() and dist.is_initialized()):
+ return tensor
+ tensor = tensor.clone()
+ # div_ works in place on the clone, which all_reduce then overwrites with
+ # the sum over all processes, i.e. the mean of the original values.
+ dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
+ return tensor
+class EMAModule(nn.Module):
+ """Expectation Maximization Attention Module used in EMANet.
+ Args:
+ channels (int): Channels of the whole module.
+ num_bases (int): Number of bases.
+ num_stages (int): Number of the EM iterations.
+ momentum (float): Weight given to the freshly estimated bases when
+ the ``bases`` buffer is updated in training mode.
+ """
+ def __init__(self, channels, num_bases, num_stages, momentum):
+ super(EMAModule, self).__init__()
+ assert num_stages >= 1, 'num_stages must be at least 1!'
+ self.num_bases = num_bases
+ self.num_stages = num_stages
+ self.momentum = momentum
+ # Random normal init with std sqrt(2 / num_bases), then L2-normalized
+ # along the channel dimension.
+ bases = torch.zeros(1, channels, self.num_bases)
+ bases.normal_(0, math.sqrt(2. / self.num_bases))
+ # [1, channels, num_bases]
+ bases = F.normalize(bases, dim=1, p=2)
+ # Registered as a buffer rather than a parameter.
+ self.register_buffer('bases', bases)
+ def forward(self, feats):
+ """Forward function.
+ Args:
+ feats (Tensor): Input of shape
+ [batch_size, channels, height, width].
+ Returns:
+ Tensor: Reconstructed features with the same shape as ``feats``.
+ """
+ batch_size, channels, height, width = feats.size()
+ # [batch_size, channels, height*width]
+ feats = feats.view(batch_size, channels, height * width)
+ # [batch_size, channels, num_bases]
+ bases = self.bases.repeat(batch_size, 1, 1)
+ with torch.no_grad():
+ for i in range(self.num_stages):
+ # [batch_size, height*width, num_bases]
+ attention = torch.einsum('bcn,bck->bnk', feats, bases)
+ attention = F.softmax(attention, dim=2)
+ # l1 norm
+ attention_normed = F.normalize(attention, dim=1, p=1)
+ # [batch_size, channels, num_bases]
+ bases = torch.einsum('bcn,bnk->bck', feats, attention_normed)
+ # l2 norm
+ bases = F.normalize(bases, dim=1, p=2)
+ # Rebuild the features from the final bases and attention.
+ feats_recon = torch.einsum('bck,bnk->bcn', bases, attention)
+ feats_recon = feats_recon.view(batch_size, channels, height, width)
+ if self.training:
+ # Average the bases over the batch (and over processes when
+ # distributed), re-normalize, then blend into the stored buffer.
+ bases = bases.mean(dim=0, keepdim=True)
+ bases = reduce_mean(bases)
+ # l2 norm
+ bases = F.normalize(bases, dim=1, p=2)
+ self.bases = (1 -
+ self.momentum) * self.bases + self.momentum * bases
+ return feats_recon
+class EMAHead(BaseDecodeHead):
+ """Expectation Maximization Attention Networks for Semantic Segmentation.
+ This head is the implementation of EMANet.
+ Args:
+ ema_channels (int): EMA module channels
+ num_bases (int): Number of bases.
+ num_stages (int): Number of the EM iterations.
+ concat_input (bool): Whether concat the input and output of convs
+ before classification layer. Default: True
+ momentum (float): Momentum to update the base. Default: 0.1.
+ """
+ def __init__(self,
+ ema_channels,
+ num_bases,
+ num_stages,
+ concat_input=True,
+ momentum=0.1,
+ **kwargs):
+ super(EMAHead, self).__init__(**kwargs)
+ self.ema_channels = ema_channels
+ self.num_bases = num_bases
+ self.num_stages = num_stages
+ self.concat_input = concat_input
+ self.momentum = momentum
+ self.ema_module = EMAModule(self.ema_channels, self.num_bases,
+ self.num_stages, self.momentum)
+ self.ema_in_conv = ConvModule(
+ self.in_channels,
+ self.ema_channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ # project (0, inf) -> (-inf, inf)
+ self.ema_mid_conv = ConvModule(
+ self.ema_channels,
+ self.ema_channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=None,
+ act_cfg=None)
+ # The projection conv is frozen: its parameters receive no gradient.
+ for param in self.ema_mid_conv.parameters():
+ param.requires_grad = False
+ # No activation here; ReLU is applied after the residual sum in forward.
+ self.ema_out_conv = ConvModule(
+ self.ema_channels,
+ self.ema_channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=None)
+ self.bottleneck = ConvModule(
+ self.ema_channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ if self.concat_input:
+ self.conv_cat = ConvModule(
+ self.in_channels + self.channels,
+ self.channels,
+ kernel_size=3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ def forward(self, inputs):
+ """Forward function.
+ Returns:
+ Tensor: Output of ``cls_seg``.
+ """
+ x = self._transform_inputs(inputs)
+ feats = self.ema_in_conv(x)
+ # Kept for the residual connection around the EMA module.
+ identity = feats
+ feats = self.ema_mid_conv(feats)
+ recon = self.ema_module(feats)
+ recon = F.relu(recon, inplace=True)
+ recon = self.ema_out_conv(recon)
+ output = F.relu(identity + recon, inplace=True)
+ output = self.bottleneck(output)
+ if self.concat_input:
+ output = self.conv_cat(torch.cat([x, output], dim=1))
+ output = self.cls_seg(output)
+ return output
diff --git a/ControlNet/annotator/uniformer/mmseg/models/decode_heads/enc_head.py b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/enc_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..da57af617e05d41761628fd2d6d232655b32d905
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/enc_head.py
@@ -0,0 +1,187 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from annotator.uniformer.mmcv.cnn import ConvModule, build_norm_layer
+from annotator.uniformer.mmseg.ops import Encoding, resize
+from ..builder import HEADS, build_loss
+from .decode_head import BaseDecodeHead
+class EncModule(nn.Module):
+ """Encoding Module used in EncNet.
+ Args:
+ in_channels (int): Input channels.
+ num_codes (int): Number of code words.
+ conv_cfg (dict|None): Config of conv layers.
+ norm_cfg (dict|None): Config of norm layers.
+ act_cfg (dict): Config of activation layers.
+ """
+ def __init__(self, in_channels, num_codes, conv_cfg, norm_cfg, act_cfg):
+ super(EncModule, self).__init__()
+ self.encoding_project = ConvModule(
+ in_channels,
+ in_channels,
+ 1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ # TODO: resolve this hack
+ # change to 1d
+ # The norm after Encoding needs a 1d variant, so the 2d norm config is
+ # rewritten on a copy, leaving the caller's dict untouched.
+ if norm_cfg is not None:
+ encoding_norm_cfg = norm_cfg.copy()
+ if encoding_norm_cfg['type'] in ['BN', 'IN']:
+ encoding_norm_cfg['type'] += '1d'
+ else:
+ encoding_norm_cfg['type'] = encoding_norm_cfg['type'].replace(
+ '2d', '1d')
+ else:
+ # fallback to BN1d
+ encoding_norm_cfg = dict(type='BN1d')
+ self.encoding = nn.Sequential(
+ Encoding(channels=in_channels, num_codes=num_codes),
+ build_norm_layer(encoding_norm_cfg, num_codes)[1],
+ nn.ReLU(inplace=True))
+ # Produces one sigmoid gate per channel from the encoded feature.
+ self.fc = nn.Sequential(
+ nn.Linear(in_channels, in_channels), nn.Sigmoid())
+ def forward(self, x):
+ """Forward function.
+ Returns:
+ tuple: ``(encoding_feat, output)`` where ``encoding_feat`` is the
+ encoding averaged over dim 1 and ``output`` is the
+ channel-gated input ``relu(x + x * gate)``.
+ """
+ encoding_projection = self.encoding_project(x)
+ # Average over dim 1 (presumably the code-word axis - TODO confirm
+ # against the Encoding op).
+ encoding_feat = self.encoding(encoding_projection).mean(dim=1)
+ batch_size, channels, _, _ = x.size()
+ gamma = self.fc(encoding_feat)
+ # Reshape the gates so they broadcast over the spatial dimensions.
+ y = gamma.view(batch_size, channels, 1, 1)
+ output = F.relu_(x + x * y)
+ return encoding_feat, output
+class EncHead(BaseDecodeHead):
+ """Context Encoding for Semantic Segmentation.
+ This head is the implementation of EncNet.
+ Args:
+ num_codes (int): Number of code words. Default: 32.
+ use_se_loss (bool): Whether use Semantic Encoding Loss (SE-loss) to
+ regularize the training. Default: True.
+ add_lateral (bool): Whether use lateral connection to fuse features.
+ Default: False.
+ loss_se_decode (dict): Config of decode loss.
+ Default: dict(type='CrossEntropyLoss', use_sigmoid=True,
+ loss_weight=0.2).
+ """
+ def __init__(self,
+ num_codes=32,
+ use_se_loss=True,
+ add_lateral=False,
+ loss_se_decode=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=True,
+ loss_weight=0.2),
+ **kwargs):
+ # The head always works on a list of selected feature levels.
+ super(EncHead, self).__init__(
+ input_transform='multiple_select', **kwargs)
+ self.use_se_loss = use_se_loss
+ self.add_lateral = add_lateral
+ self.num_codes = num_codes
+ # Only the last selected level goes through the bottleneck.
+ self.bottleneck = ConvModule(
+ self.in_channels[-1],
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ if add_lateral:
+ self.lateral_convs = nn.ModuleList()
+ for in_channels in self.in_channels[:-1]: # skip the last one
+ self.lateral_convs.append(
+ ConvModule(
+ in_channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ # Fuses the bottleneck output with all laterals: one `channels`-wide
+ # map per input level.
+ self.fusion = ConvModule(
+ len(self.in_channels) * self.channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.enc_module = EncModule(
+ self.channels,
+ num_codes=num_codes,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ if self.use_se_loss:
+ self.loss_se_decode = build_loss(loss_se_decode)
+ self.se_layer = nn.Linear(self.channels, self.num_classes)
+ def forward(self, inputs):
+ """Forward function.
+ Returns:
+ Tensor | tuple[Tensor]: ``(output, se_output)`` when
+ ``use_se_loss`` is set, otherwise ``output`` alone.
+ """
+ inputs = self._transform_inputs(inputs)
+ feat = self.bottleneck(inputs[-1])
+ if self.add_lateral:
+ # Project each earlier level and resize it to the bottleneck size.
+ laterals = [
+ resize(
+ lateral_conv(inputs[i]),
+ size=feat.shape[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ for i, lateral_conv in enumerate(self.lateral_convs)
+ ]
+ feat = self.fusion(torch.cat([feat, *laterals], 1))
+ encode_feat, output = self.enc_module(feat)
+ output = self.cls_seg(output)
+ if self.use_se_loss:
+ se_output = self.se_layer(encode_feat)
+ return output, se_output
+ else:
+ return output
+ def forward_test(self, inputs, img_metas, test_cfg):
+ """Forward function for testing, ignore se_loss."""
+ if self.use_se_loss:
+ return self.forward(inputs)[0]
+ else:
+ return self.forward(inputs)
+ @staticmethod
+ def _convert_to_onehot_labels(seg_label, num_classes):
+ """Convert segmentation label to onehot.
+ Marks, per sample, which classes occur anywhere in the label map.
+ Args:
+ seg_label (Tensor): Segmentation label of shape (N, H, W).
+ num_classes (int): Number of classes.
+ Returns:
+ Tensor: Onehot labels of shape (N, num_classes).
+ """
+ batch_size = seg_label.size(0)
+ onehot_labels = seg_label.new_zeros((batch_size, num_classes))
+ for i in range(batch_size):
+ # One histogram bin per class over [0, num_classes - 1]; a class is
+ # present when its bin count is non-zero.
+ hist = seg_label[i].float().histc(
+ bins=num_classes, min=0, max=num_classes - 1)
+ onehot_labels[i] = hist > 0
+ return onehot_labels
+ def losses(self, seg_logit, seg_label):
+ """Compute segmentation and semantic encoding loss.
+ Returns:
+ dict: The parent-class loss entries plus 'loss_se'.
+ """
+ # NOTE(review): this unpacking expects the (output, se_output) pair that
+ # forward returns only when use_se_loss is set - confirm the behaviour
+ # intended for use_se_loss=False.
+ seg_logit, se_seg_logit = seg_logit
+ loss = dict()
+ loss.update(super(EncHead, self).losses(seg_logit, seg_label))
+ se_loss = self.loss_se_decode(
+ se_seg_logit,
+ self._convert_to_onehot_labels(seg_label, self.num_classes))
+ loss['loss_se'] = se_loss
+ return loss
diff --git a/ControlNet/annotator/uniformer/mmseg/models/decode_heads/fcn_head.py b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/fcn_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..edb32c283fa4baada6b4a0bf3f7540c3580c3468
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/fcn_head.py
@@ -0,0 +1,81 @@
+import torch
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import ConvModule
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+class FCNHead(BaseDecodeHead):
+ """Fully Convolution Networks for Semantic Segmentation.
+ This head is the implementation of FCNNet.
+ Args:
+ num_convs (int): Number of convs in the head. Default: 2.
+ kernel_size (int): The kernel size for convs in the head. Default: 3.
+ concat_input (bool): Whether concat the input and output of convs
+ before classification layer. Default: True.
+ dilation (int): The dilation rate for convs in the head. Default: 1.
+ """
+ def __init__(self,
+ num_convs=2,
+ kernel_size=3,
+ concat_input=True,
+ dilation=1,
+ **kwargs):
+ assert num_convs >= 0 and dilation > 0 and isinstance(dilation, int)
+ self.num_convs = num_convs
+ self.concat_input = concat_input
+ self.kernel_size = kernel_size
+ super(FCNHead, self).__init__(**kwargs)
+ # With no convs the input reaches the classifier unchanged, so its
+ # channel count must already equal `channels`.
+ if num_convs == 0:
+ assert self.in_channels == self.channels
+ # Padding that keeps the spatial size for odd kernel sizes.
+ conv_padding = (kernel_size // 2) * dilation
+ convs = []
+ # NOTE(review): this first conv appears to be appended even when
+ # num_convs == 0, where the list is then discarded in favour of
+ # nn.Identity below - confirm.
+ convs.append(
+ ConvModule(
+ self.in_channels,
+ self.channels,
+ kernel_size=kernel_size,
+ padding=conv_padding,
+ dilation=dilation,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ for i in range(num_convs - 1):
+ convs.append(
+ ConvModule(
+ self.channels,
+ self.channels,
+ kernel_size=kernel_size,
+ padding=conv_padding,
+ dilation=dilation,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ if num_convs == 0:
+ self.convs = nn.Identity()
+ else:
+ self.convs = nn.Sequential(*convs)
+ if self.concat_input:
+ # Fuses the raw input with the conv output (channel concat).
+ self.conv_cat = ConvModule(
+ self.in_channels + self.channels,
+ self.channels,
+ kernel_size=kernel_size,
+ padding=kernel_size // 2,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ def forward(self, inputs):
+ """Forward function.
+ Returns:
+ Tensor: Output of ``cls_seg``.
+ """
+ x = self._transform_inputs(inputs)
+ output = self.convs(x)
+ if self.concat_input:
+ output = self.conv_cat(torch.cat([x, output], dim=1))
+ output = self.cls_seg(output)
+ return output
diff --git a/ControlNet/annotator/uniformer/mmseg/models/decode_heads/fpn_head.py b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/fpn_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..1241c55b0813d1ecdddf1e66e7c5031fbf78ed50
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/fpn_head.py
@@ -0,0 +1,68 @@
+import numpy as np
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import ConvModule
+from annotator.uniformer.mmseg.ops import resize
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+class FPNHead(BaseDecodeHead):
+ """Panoptic Feature Pyramid Networks.
+ This head is the implementation of Semantic FPN.
+ Args:
+ feature_strides (tuple[int]): The strides for input feature maps.
+ All strides are supposed to be powers of 2. The first one must
+ be the smallest stride, i.e. the largest resolution.
+ """
+ def __init__(self, feature_strides, **kwargs):
+ super(FPNHead, self).__init__(
+ input_transform='multiple_select', **kwargs)
+ assert len(feature_strides) == len(self.in_channels)
+ assert min(feature_strides) == feature_strides[0]
+ self.feature_strides = feature_strides
+ self.scale_heads = nn.ModuleList()
+ for i in range(len(feature_strides)):
+ # Number of conv stages for this level: log2 of its stride ratio to
+ # the first level, but at least one.
+ head_length = max(
+ 1,
+ int(np.log2(feature_strides[i]) - np.log2(feature_strides[0])))
+ scale_head = []
+ for k in range(head_length):
+ scale_head.append(
+ ConvModule(
+ self.in_channels[i] if k == 0 else self.channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ # Levels coarser than the first one get 2x bilinear upsampling.
+ if feature_strides[i] != feature_strides[0]:
+ scale_head.append(
+ nn.Upsample(
+ scale_factor=2,
+ mode='bilinear',
+ align_corners=self.align_corners))
+ self.scale_heads.append(nn.Sequential(*scale_head))
+ def forward(self, inputs):
+ """Forward function.
+ Sums the outputs of all scale heads at the resolution of the first
+ level and classifies the result.
+ """
+ x = self._transform_inputs(inputs)
+ output = self.scale_heads[0](x[0])
+ for i in range(1, len(self.feature_strides)):
+ # non inplace
+ output = output + resize(
+ self.scale_heads[i](x[i]),
+ size=output.shape[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ output = self.cls_seg(output)
+ return output
diff --git a/ControlNet/annotator/uniformer/mmseg/models/decode_heads/gc_head.py b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/gc_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..70741245af975800840709911bd18d72247e3e04
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/gc_head.py
@@ -0,0 +1,47 @@
+import torch
+from annotator.uniformer.mmcv.cnn import ContextBlock
+from ..builder import HEADS
+from .fcn_head import FCNHead
+class GCHead(FCNHead):
+    """GCNet decode head: an FCN head with a global-context block.
+
+    The parent head is built with ``num_convs=2`` and a ``ContextBlock`` is
+    applied between ``self.convs[0]`` and ``self.convs[1]``.
+
+    Args:
+        ratio (float): Multiplier of channels ratio. Default: 1/4.
+        pooling_type (str): The pooling type of context aggregation.
+            Options are 'att', 'avg'. Default: 'att'.
+        fusion_types (tuple[str]): The fusion type for feature fusion.
+            Options are 'channel_add', 'channel_mul'.
+            Default: ('channel_add',)
+    """
+
+    def __init__(self,
+                 ratio=1 / 4.,
+                 pooling_type='att',
+                 fusion_types=('channel_add', ),
+                 **kwargs):
+        super(GCHead, self).__init__(num_convs=2, **kwargs)
+        self.ratio = ratio
+        self.pooling_type = pooling_type
+        self.fusion_types = fusion_types
+        self.gc_block = ContextBlock(
+            in_channels=self.channels,
+            ratio=ratio,
+            pooling_type=pooling_type,
+            fusion_types=fusion_types)
+
+    def forward(self, inputs):
+        """Forward function."""
+        x = self._transform_inputs(inputs)
+        feats = self.convs[0](x)
+        feats = self.convs[1](self.gc_block(feats))
+        if self.concat_input:
+            # Fuse the head input with the context-enhanced features.
+            feats = self.conv_cat(torch.cat([x, feats], dim=1))
+        return self.cls_seg(feats)
diff --git a/ControlNet/annotator/uniformer/mmseg/models/decode_heads/lraspp_head.py b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/lraspp_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..69bf320934d787aaa11984a0c4effe9ad8015b22
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/lraspp_head.py
@@ -0,0 +1,90 @@
+import torch
+import torch.nn as nn
+from annotator.uniformer.mmcv import is_tuple_of
+from annotator.uniformer.mmcv.cnn import ConvModule
+from annotator.uniformer.mmseg.ops import resize
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+class LRASPPHead(BaseDecodeHead):
+    """Lite R-ASPP (LRASPP) head proposed in Searching for MobileNetV3.
+
+    The deepest selected feature map is projected by a 1x1 ``ConvModule`` and
+    multiplied by a sigmoid-activated, average-pooled branch. The result is
+    then resized and fused with each shallower input, from deep to shallow,
+    before classification.
+
+    Args:
+        branch_channels (tuple[int]): Output channels of the 1x1 conv applied
+            to each shallower input. Must contain exactly one entry fewer
+            than ``in_channels``. Default: (32, 64).
+
+    Raises:
+        ValueError: If ``input_transform`` is not ``'multiple_select'``.
+    """
+
+    def __init__(self, branch_channels=(32, 64), **kwargs):
+        super(LRASPPHead, self).__init__(**kwargs)
+        if self.input_transform != 'multiple_select':
+            raise ValueError('in Lite R-ASPP (LRASPP) head, input_transform '
+                             f'must be \'multiple_select\'. But received '
+                             f'\'{self.input_transform}\'')
+        assert is_tuple_of(branch_channels, int)
+        assert len(branch_channels) == len(self.in_channels) - 1
+        self.branch_channels = branch_channels
+
+        self.convs = nn.Sequential()
+        self.conv_ups = nn.Sequential()
+        for i, branch_ch in enumerate(branch_channels):
+            # 1x1 projection of the i-th shallow input.
+            self.convs.add_module(
+                f'conv{i}',
+                nn.Conv2d(self.in_channels[i], branch_ch, 1, bias=False))
+            # Fuses the upsampled deep features with that projection.
+            self.conv_ups.add_module(
+                f'conv_up{i}',
+                ConvModule(
+                    self.channels + branch_ch,
+                    self.channels,
+                    1,
+                    norm_cfg=self.norm_cfg,
+                    act_cfg=self.act_cfg,
+                    bias=False))
+
+        self.conv_up_input = nn.Conv2d(self.channels, self.channels, 1)
+
+        self.aspp_conv = ConvModule(
+            self.in_channels[-1],
+            self.channels,
+            1,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg,
+            bias=False)
+        self.image_pool = nn.Sequential(
+            nn.AvgPool2d(kernel_size=49, stride=(16, 20)),
+            ConvModule(
+                # forward() feeds this branch with inputs[-1], so it must be
+                # sized from the last input level. The previous hard-coded
+                # index 2 only matched when exactly three inputs were used.
+                self.in_channels[-1],
+                self.channels,
+                1,
+                act_cfg=dict(type='Sigmoid'),
+                bias=False))
+
+    def forward(self, inputs):
+        """Forward function."""
+        inputs = self._transform_inputs(inputs)
+        x = inputs[-1]
+        # Gate the projected deep features with the pooled (sigmoid) branch.
+        x = self.aspp_conv(x) * resize(
+            self.image_pool(x),
+            size=x.size()[2:],
+            mode='bilinear',
+            align_corners=self.align_corners)
+        x = self.conv_up_input(x)
+
+        # Walk the shallower inputs from deepest to shallowest.
+        for i in range(len(self.branch_channels) - 1, -1, -1):
+            x = resize(
+                x,
+                size=inputs[i].size()[2:],
+                mode='bilinear',
+                align_corners=self.align_corners)
+            x = torch.cat([x, self.convs[i](inputs[i])], 1)
+            x = self.conv_ups[i](x)
+
+        return self.cls_seg(x)
diff --git a/ControlNet/annotator/uniformer/mmseg/models/decode_heads/nl_head.py b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/nl_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..3eee424199e6aa363b564e2a3340a070db04db86
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/nl_head.py
@@ -0,0 +1,49 @@
+import torch
+from annotator.uniformer.mmcv.cnn import NonLocal2d
+from ..builder import HEADS
+from .fcn_head import FCNHead
+class NLHead(FCNHead):
+    """Non-local Neural Networks decode head.
+
+    The parent head is built with ``num_convs=2`` and a ``NonLocal2d`` block
+    is applied between ``self.convs[0]`` and ``self.convs[1]``.
+
+    Args:
+        reduction (int): Reduction factor of projection transform. Default: 2.
+        use_scale (bool): Whether to scale pairwise_weight by
+            sqrt(1/inter_channels). Default: True.
+        mode (str): The nonlocal mode. Options are 'embedded_gaussian',
+            'dot_product'. Default: 'embedded_gaussian'.
+    """
+
+    def __init__(self,
+                 reduction=2,
+                 use_scale=True,
+                 mode='embedded_gaussian',
+                 **kwargs):
+        super(NLHead, self).__init__(num_convs=2, **kwargs)
+        self.reduction = reduction
+        self.use_scale = use_scale
+        self.mode = mode
+        self.nl_block = NonLocal2d(
+            in_channels=self.channels,
+            reduction=reduction,
+            use_scale=use_scale,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            mode=mode)
+
+    def forward(self, inputs):
+        """Forward function."""
+        x = self._transform_inputs(inputs)
+        feats = self.convs[0](x)
+        feats = self.convs[1](self.nl_block(feats))
+        if self.concat_input:
+            # Fuse the head input with the non-local features.
+            feats = self.conv_cat(torch.cat([x, feats], dim=1))
+        return self.cls_seg(feats)
diff --git a/ControlNet/annotator/uniformer/mmseg/models/decode_heads/ocr_head.py b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/ocr_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..715852e94e81dc46623972748285d2d19237a341
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/ocr_head.py
@@ -0,0 +1,127 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from annotator.uniformer.mmcv.cnn import ConvModule
+from annotator.uniformer.mmseg.ops import resize
+from ..builder import HEADS
+from ..utils import SelfAttentionBlock as _SelfAttentionBlock
+from .cascade_decode_head import BaseCascadeDecodeHead
+class SpatialGatherModule(nn.Module):
+    """Aggregate context features with a soft, per-class spatial weighting.
+
+    The class score maps are turned into a softmax distribution over all
+    spatial positions (after multiplication by ``scale``) and used as weights
+    to average the features, giving one feature vector per class.
+    """
+
+    def __init__(self, scale):
+        super(SpatialGatherModule, self).__init__()
+        self.scale = scale
+
+    def forward(self, feats, probs):
+        """Forward function."""
+        batch_size, num_classes, _, _ = probs.size()
+        channels = feats.size(1)
+        # [batch_size, num_classes, height*width], normalized over positions
+        weights = F.softmax(
+            self.scale * probs.view(batch_size, num_classes, -1), dim=2)
+        # [batch_size, height*width, channels]
+        flat_feats = feats.view(batch_size, channels, -1).permute(0, 2, 1)
+        # [batch_size, num_classes, channels]
+        context = torch.matmul(weights, flat_feats)
+        # [batch_size, channels, num_classes, 1]
+        return context.permute(0, 2, 1).contiguous().unsqueeze(3)
+class ObjectAttentionBlock(_SelfAttentionBlock):
+    """Self-attention block used by the OCR head.
+
+    Pixel features act as queries and object-region features as keys/values.
+    The attended context is concatenated with the query features and fused
+    by a 1x1 ``ConvModule`` bottleneck back to ``in_channels``.
+
+    Args:
+        in_channels (int): Channels of both query and key features; also the
+            number of output channels.
+        channels (int): Intermediate channels passed to the parent block.
+        scale (int): If greater than 1, queries are max-pooled with this
+            kernel size inside the parent block.
+        conv_cfg (dict|None): Config of conv layers.
+        norm_cfg (dict|None): Config of norm layers.
+        act_cfg (dict): Config of activation layers.
+    """
+
+    def __init__(self, in_channels, channels, scale, conv_cfg, norm_cfg,
+                 act_cfg):
+        if scale > 1:
+            query_downsample = nn.MaxPool2d(kernel_size=scale)
+        else:
+            query_downsample = None
+        super(ObjectAttentionBlock, self).__init__(
+            key_in_channels=in_channels,
+            query_in_channels=in_channels,
+            channels=channels,
+            out_channels=in_channels,
+            share_key_query=False,
+            query_downsample=query_downsample,
+            key_downsample=None,
+            key_query_num_convs=2,
+            key_query_norm=True,
+            value_out_num_convs=1,
+            value_out_norm=True,
+            matmul_norm=True,
+            with_out=True,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg)
+        self.bottleneck = ConvModule(
+            in_channels * 2,
+            in_channels,
+            1,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+
+    def forward(self, query_feats, key_feats):
+        """Forward function.
+
+        Args:
+            query_feats (Tensor): Query feature map.
+            key_feats (Tensor): Key/value feature map.
+
+        Returns:
+            Tensor: Fused features at the spatial size of ``query_feats``.
+        """
+        context = super(ObjectAttentionBlock,
+                        self).forward(query_feats, key_feats)
+        # The previous code replaced the result with ``resize(query_feats)``
+        # (no target size, attention output discarded) when queries were
+        # downsampled. Instead, bring the context to the query resolution
+        # so it can be concatenated and fused.
+        # NOTE(review): assumes a downsampled query can yield a context of
+        # a different spatial size - confirm against SelfAttentionBlock.
+        if context.shape[2:] != query_feats.shape[2:]:
+            context = resize(
+                context,
+                size=query_feats.shape[2:],
+                mode='bilinear',
+                align_corners=False)
+        return self.bottleneck(torch.cat([context, query_feats], dim=1))
+class OCRHead(BaseCascadeDecodeHead):
+    """Object-Contextual Representations for Semantic Segmentation.
+
+    This head is the implementation of OCRNet. It gathers per-class context
+    from the previous head's prediction and uses it to refine the pixel
+    features before classification.
+
+    Args:
+        ocr_channels (int): The intermediate channels of OCR block.
+        scale (int): Scale passed to both ``ObjectAttentionBlock`` and
+            ``SpatialGatherModule``. Default: 1.
+    """
+
+    def __init__(self, ocr_channels, scale=1, **kwargs):
+        super(OCRHead, self).__init__(**kwargs)
+        self.ocr_channels = ocr_channels
+        self.scale = scale
+        # Layer configs shared by the attention block and the bottleneck.
+        layer_cfg = dict(
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+        self.object_context_block = ObjectAttentionBlock(
+            self.channels, self.ocr_channels, self.scale, **layer_cfg)
+        self.spatial_gather_module = SpatialGatherModule(self.scale)
+        self.bottleneck = ConvModule(
+            self.in_channels, self.channels, 3, padding=1, **layer_cfg)
+
+    def forward(self, inputs, prev_output):
+        """Forward function."""
+        pixel_feats = self.bottleneck(self._transform_inputs(inputs))
+        region_context = self.spatial_gather_module(pixel_feats, prev_output)
+        refined = self.object_context_block(pixel_feats, region_context)
+        return self.cls_seg(refined)
diff --git a/ControlNet/annotator/uniformer/mmseg/models/decode_heads/point_head.py b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/point_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..3342aa28bb8d264b2c3d01cbf5098d145943c193
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/point_head.py
@@ -0,0 +1,349 @@
+# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend/point_head/point_head.py # noqa
+import torch
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import ConvModule, normal_init
+from annotator.uniformer.mmcv.ops import point_sample
+from annotator.uniformer.mmseg.models.builder import HEADS
+from annotator.uniformer.mmseg.ops import resize
+from ..losses import accuracy
+from .cascade_decode_head import BaseCascadeDecodeHead
+def calculate_uncertainty(seg_logits):
+    """Estimate uncertainty based on seg logits.
+
+    For each location the score is the second-largest logit minus the
+    largest logit along the class dimension. Scores are therefore <= 0 and
+    the most uncertain locations (smallest margin) get the highest score.
+
+    Args:
+        seg_logits (Tensor): Semantic segmentation logits,
+            shape (batch_size, num_classes, height, width).
+
+    Returns:
+        scores (Tensor): Uncertainty scores, shape
+            (batch_size, 1, height, width).
+    """
+    best, runner_up = torch.topk(seg_logits, k=2, dim=1)[0].unbind(dim=1)
+    return (runner_up - best).unsqueeze(1)
+class PointHead(BaseCascadeDecodeHead):
+ """A mask point head use in PointRend.
+ ``PointHead`` use shared multi-layer perceptron (equivalent to
+ nn.Conv1d) to predict the logit of input points. The fine-grained feature
+ and coarse feature will be concatenate together for predication.
+ NOTE(review): only ``num_fcs``, ``coarse_pred_each_layer``, ``conv_cfg``,
+ ``norm_cfg`` and ``act_cfg`` are explicit parameters of ``__init__``; the
+ other entries below would have to arrive through ``**kwargs`` - confirm
+ against the base class.
+ Args:
+ num_fcs (int): Number of fc layers in the head. Default: 3.
+ in_channels (int): Number of input channels. Default: 256.
+ fc_channels (int): Number of fc channels. Default: 256.
+ num_classes (int): Number of classes for logits. Default: 80.
+ class_agnostic (bool): Whether use class agnostic classification.
+ If so, the output channels of logits will be 1. Default: False.
+ coarse_pred_each_layer (bool): Whether concatenate coarse feature with
+ the output of each fc layer. Default: True.
+ conv_cfg (dict|None): Dictionary to construct and config conv layer.
+ Default: dict(type='Conv1d'))
+ norm_cfg (dict|None): Dictionary to construct and config norm layer.
+ Default: None.
+ act_cfg (dict): Dictionary to construct and config activation layer.
+ Default: dict(type='ReLU', inplace=False).
+ loss_point (dict): Dictionary to construct and config loss layer of
+ point head. Default: dict(type='CrossEntropyLoss', use_mask=True,
+ loss_weight=1.0).
+ """
+ def __init__(self,
+ num_fcs=3,
+ coarse_pred_each_layer=True,
+ conv_cfg=dict(type='Conv1d'),
+ norm_cfg=None,
+ act_cfg=dict(type='ReLU', inplace=False),
+ **kwargs):
+ super(PointHead, self).__init__(
+ input_transform='multiple_select',
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ **kwargs)
+ self.num_fcs = num_fcs
+ self.coarse_pred_each_layer = coarse_pred_each_layer
+ # The first fc consumes the fine-grained features of all inputs
+ # concatenated with the coarse prediction (see forward()).
+ fc_in_channels = sum(self.in_channels) + self.num_classes
+ fc_channels = self.channels
+ self.fcs = nn.ModuleList()
+ for k in range(num_fcs):
+ fc = ConvModule(
+ fc_in_channels,
+ fc_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ self.fcs.append(fc)
+ fc_in_channels = fc_channels
+ # forward() re-appends the coarse prediction after each fc when
+ # coarse_pred_each_layer is set, widening the next layer's input.
+ fc_in_channels += self.num_classes if self.coarse_pred_each_layer \
+ else 0
+ self.fc_seg = nn.Conv1d(
+ fc_in_channels,
+ self.num_classes,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ if self.dropout_ratio > 0:
+ self.dropout = nn.Dropout(self.dropout_ratio)
+ # Classification goes through fc_seg here, so the `conv_seg` attribute
+ # (presumably created by the base class - confirm) is removed.
+ delattr(self, 'conv_seg')
+ def init_weights(self):
+ """Initialize weights of classification layer."""
+ normal_init(self.fc_seg, std=0.001)
+ def cls_seg(self, feat):
+ """Classify each pixel with fc."""
+ if self.dropout is not None:
+ feat = self.dropout(feat)
+ output = self.fc_seg(feat)
+ return output
+ def forward(self, fine_grained_point_feats, coarse_point_feats):
+ """Predict point logits from fine-grained and coarse point features.
+ Both inputs are concatenated along dim 1, passed through the fc
+ layers (re-appending the coarse features after each layer when
+ ``coarse_pred_each_layer`` is set) and classified by ``cls_seg``.
+ """
+ x = torch.cat([fine_grained_point_feats, coarse_point_feats], dim=1)
+ for fc in self.fcs:
+ x = fc(x)
+ if self.coarse_pred_each_layer:
+ x = torch.cat((x, coarse_point_feats), dim=1)
+ return self.cls_seg(x)
+ def _get_fine_grained_point_feats(self, x, points):
+ """Sample from fine grained features.
+ Args:
+ x (list[Tensor]): Feature pyramid from by neck or backbone.
+ points (Tensor): Point coordinates, shape (batch_size,
+ num_points, 2).
+ Returns:
+ fine_grained_feats (Tensor): Sampled fine grained feature,
+ shape (batch_size, sum(channels of x), num_points).
+ """
+ # Sample every feature level at the same points.
+ fine_grained_feats_list = [
+ point_sample(_, points, align_corners=self.align_corners)
+ for _ in x
+ ]
+ if len(fine_grained_feats_list) > 1:
+ fine_grained_feats = torch.cat(fine_grained_feats_list, dim=1)
+ else:
+ fine_grained_feats = fine_grained_feats_list[0]
+ return fine_grained_feats
+ def _get_coarse_point_feats(self, prev_output, points):
+ """Sample from the coarse prediction of the previous head.
+ Args:
+ prev_output (list[Tensor]): Prediction of previous decode head.
+ points (Tensor): Point coordinates, shape (batch_size,
+ num_points, 2).
+ Returns:
+ coarse_feats (Tensor): Sampled coarse feature, shape (batch_size,
+ num_classes, num_points).
+ """
+ coarse_feats = point_sample(
+ prev_output, points, align_corners=self.align_corners)
+ return coarse_feats
+ def forward_train(self, inputs, prev_output, img_metas, gt_semantic_seg,
+ train_cfg):
+ """Forward function for training.
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ prev_output (Tensor): The output of previous decode head.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
+ gt_semantic_seg (Tensor): Semantic segmentation masks
+ used if the architecture supports semantic segmentation task.
+ train_cfg (dict): The training config.
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+ x = self._transform_inputs(inputs)
+ # Point selection is excluded from gradient tracking.
+ with torch.no_grad():
+ points = self.get_points_train(
+ prev_output, calculate_uncertainty, cfg=train_cfg)
+ fine_grained_point_feats = self._get_fine_grained_point_feats(
+ x, points)
+ coarse_point_feats = self._get_coarse_point_feats(prev_output, points)
+ point_logits = self.forward(fine_grained_point_feats,
+ coarse_point_feats)
+ # Labels are sampled at the same points with nearest-neighbour lookup.
+ point_label = point_sample(
+ gt_semantic_seg.float(),
+ points,
+ mode='nearest',
+ align_corners=self.align_corners)
+ point_label = point_label.squeeze(1).long()
+ losses = self.losses(point_logits, point_label)
+ return losses
+ def forward_test(self, inputs, prev_output, img_metas, test_cfg):
+ """Forward function for testing.
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ prev_output (Tensor): The output of previous decode head.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
+ test_cfg (dict): The testing config.
+ Returns:
+ Tensor: Output segmentation map.
+ """
+ x = self._transform_inputs(inputs)
+ # Work on a copy so the in-place scatter below never touches
+ # prev_output.
+ refined_seg_logits = prev_output.clone()
+ # Each subdivision step upsamples the logits and then overwrites the
+ # selected uncertain points with re-predicted point logits.
+ for _ in range(test_cfg.subdivision_steps):
+ refined_seg_logits = resize(
+ refined_seg_logits,
+ scale_factor=test_cfg.scale_factor,
+ mode='bilinear',
+ align_corners=self.align_corners)
+ batch_size, channels, height, width = refined_seg_logits.shape
+ point_indices, points = self.get_points_test(
+ refined_seg_logits, calculate_uncertainty, cfg=test_cfg)
+ fine_grained_point_feats = self._get_fine_grained_point_feats(
+ x, points)
+ coarse_point_feats = self._get_coarse_point_feats(
+ prev_output, points)
+ point_logits = self.forward(fine_grained_point_feats,
+ coarse_point_feats)
+ # Repeat the flat indices for every channel, then write the point
+ # logits at those positions of the flattened map.
+ point_indices = point_indices.unsqueeze(1).expand(-1, channels, -1)
+ refined_seg_logits = refined_seg_logits.reshape(
+ batch_size, channels, height * width)
+ refined_seg_logits = refined_seg_logits.scatter_(
+ 2, point_indices, point_logits)
+ refined_seg_logits = refined_seg_logits.view(
+ batch_size, channels, height, width)
+ return refined_seg_logits
+ def losses(self, point_logits, point_label):
+ """Compute segmentation loss."""
+ loss = dict()
+ loss['loss_point'] = self.loss_decode(
+ point_logits, point_label, ignore_index=self.ignore_index)
+ loss['acc_point'] = accuracy(point_logits, point_label)
+ return loss
+ def get_points_train(self, seg_logits, uncertainty_func, cfg):
+ """Sample points for training.
+ Sample points in [0, 1] x [0, 1] coordinate space based on their
+ uncertainty. The uncertainties are calculated for each point using
+ 'uncertainty_func' function that takes point's logit prediction as
+ input.
+ Args:
+ seg_logits (Tensor): Semantic segmentation logits, shape (
+ batch_size, num_classes, height, width).
+ uncertainty_func (func): uncertainty calculation function.
+ cfg (dict): Training config of point head.
+ Returns:
+ point_coords (Tensor): A tensor of shape (batch_size, num_points,
+ 2) that contains the coordinates of ``num_points`` sampled
+ points.
+ """
+ num_points = cfg.num_points
+ oversample_ratio = cfg.oversample_ratio
+ importance_sample_ratio = cfg.importance_sample_ratio
+ assert oversample_ratio >= 1
+ assert 0 <= importance_sample_ratio <= 1
+ batch_size = seg_logits.shape[0]
+ # Draw more random candidates than needed; the most uncertain ones
+ # are kept below.
+ num_sampled = int(num_points * oversample_ratio)
+ point_coords = torch.rand(
+ batch_size, num_sampled, 2, device=seg_logits.device)
+ point_logits = point_sample(seg_logits, point_coords)
+ # It is crucial to calculate uncertainty based on the sampled
+ # prediction value for the points. Calculating uncertainties of the
+ # coarse predictions first and sampling them for points leads to
+ # incorrect results. To illustrate this: assume uncertainty func(
+ # logits)=-abs(logits), a sampled point between two coarse
+ # predictions with -1 and 1 logits has 0 logits, and therefore 0
+ # uncertainty value. However, if we calculate uncertainties for the
+ # coarse predictions first, both will have -1 uncertainty,
+ # and sampled point will get -1 uncertainty.
+ point_uncertainties = uncertainty_func(point_logits)
+ num_uncertain_points = int(importance_sample_ratio * num_points)
+ num_random_points = num_points - num_uncertain_points
+ idx = torch.topk(
+ point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
+ # Offset the per-sample indices so they address the flattened
+ # (batch_size * num_sampled) list of candidate coordinates.
+ shift = num_sampled * torch.arange(
+ batch_size, dtype=torch.long, device=seg_logits.device)
+ idx += shift[:, None]
+ point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(
+ batch_size, num_uncertain_points, 2)
+ # Fill the remaining quota with fresh uniformly random points.
+ if num_random_points > 0:
+ rand_point_coords = torch.rand(
+ batch_size, num_random_points, 2, device=seg_logits.device)
+ point_coords = torch.cat((point_coords, rand_point_coords), dim=1)
+ return point_coords
+ def get_points_test(self, seg_logits, uncertainty_func, cfg):
+ """Sample points for testing.
+ Find ``num_points`` most uncertain points from ``uncertainty_map``.
+ Args:
+ seg_logits (Tensor): A tensor of shape (batch_size, num_classes,
+ height, width) for class-specific or class-agnostic prediction.
+ uncertainty_func (func): uncertainty calculation function.
+ cfg (dict): Testing config of point head.
+ Returns:
+ point_indices (Tensor): A tensor of shape (batch_size, num_points)
+ that contains indices from [0, height x width) of the most
+ uncertain points.
+ point_coords (Tensor): A tensor of shape (batch_size, num_points,
+ 2) that contains [0, 1] x [0, 1] normalized coordinates of the
+ most uncertain points from the ``height x width`` grid .
+ """
+ num_points = cfg.subdivision_num_points
+ uncertainty_map = uncertainty_func(seg_logits)
+ batch_size, _, height, width = uncertainty_map.shape
+ h_step = 1.0 / height
+ w_step = 1.0 / width
+ uncertainty_map = uncertainty_map.view(batch_size, height * width)
+ # Never request more points than there are grid cells.
+ num_points = min(height * width, num_points)
+ point_indices = uncertainty_map.topk(num_points, dim=1)[1]
+ point_coords = torch.zeros(
+ batch_size,
+ num_points,
+ 2,
+ dtype=torch.float,
+ device=seg_logits.device)
+ # Flat index -> normalized cell-center coordinates: x comes from the
+ # column (index % width), y from the row (index // width).
+ point_coords[:, :, 0] = w_step / 2.0 + (point_indices %
+ width).float() * w_step
+ point_coords[:, :, 1] = h_step / 2.0 + (point_indices //
+ width).float() * h_step
+ return point_indices, point_coords
diff --git a/ControlNet/annotator/uniformer/mmseg/models/decode_heads/psa_head.py b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/psa_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..480dbd1a081262e45bf87e32c4a339ac8f8b4ffb
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/psa_head.py
@@ -0,0 +1,196 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from annotator.uniformer.mmcv.cnn import ConvModule
+from annotator.uniformer.mmseg.ops import resize
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+try:
+    from annotator.uniformer.mmcv.ops import PSAMask
+except ModuleNotFoundError:
+    # PSAMask is optional at import time; PSAHead.__init__ raises a
+    # RuntimeError when it is None.
+    PSAMask = None
+class PSAHead(BaseDecodeHead):
+ """Point-wise Spatial Attention Network for Scene Parsing.
+ This head is the implementation of `PSANet
+ `_.
+ Args:
+ mask_size (tuple[int]): The PSA mask size. It usually equals input
+ size.
+ psa_type (str): The type of psa module. Options are 'collect',
+ 'distribute', 'bi-direction'. Default: 'bi-direction'
+ compact (bool): Whether use compact map for 'collect' mode.
+ Default: False.
+ shrink_factor (int): The downsample factors of psa mask. Default: 2.
+ normalization_factor (float): The normalize factor of attention.
+ Default: 1.0. If None, ``mask_h * mask_w`` is used.
+ psa_softmax (bool): Whether use softmax for attention.
+ Default: True.
+ """
+ def __init__(self,
+ mask_size,
+ psa_type='bi-direction',
+ compact=False,
+ shrink_factor=2,
+ normalization_factor=1.0,
+ psa_softmax=True,
+ **kwargs):
+ # PSAMask is None when its import failed at module load time.
+ if PSAMask is None:
+ raise RuntimeError('Please install mmcv-full for PSAMask ops')
+ super(PSAHead, self).__init__(**kwargs)
+ assert psa_type in ['collect', 'distribute', 'bi-direction']
+ self.psa_type = psa_type
+ self.compact = compact
+ self.shrink_factor = shrink_factor
+ self.mask_size = mask_size
+ mask_h, mask_w = mask_size
+ self.psa_softmax = psa_softmax
+ if normalization_factor is None:
+ normalization_factor = mask_h * mask_w
+ self.normalization_factor = normalization_factor
+ # 1x1 reduction of the input followed by an attention branch that
+ # predicts mask_h * mask_w channels per position.
+ self.reduce = ConvModule(
+ self.in_channels,
+ self.channels,
+ kernel_size=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.attention = nn.Sequential(
+ ConvModule(
+ self.channels,
+ self.channels,
+ kernel_size=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg),
+ nn.Conv2d(
+ self.channels, mask_h * mask_w, kernel_size=1, bias=False))
+ # 'bi-direction' adds a second (``*_p``) reduce/attention pair and
+ # uses one PSAMask per direction.
+ if psa_type == 'bi-direction':
+ self.reduce_p = ConvModule(
+ self.in_channels,
+ self.channels,
+ kernel_size=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.attention_p = nn.Sequential(
+ ConvModule(
+ self.channels,
+ self.channels,
+ kernel_size=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg),
+ nn.Conv2d(
+ self.channels, mask_h * mask_w, kernel_size=1, bias=False))
+ self.psamask_collect = PSAMask('collect', mask_size)
+ self.psamask_distribute = PSAMask('distribute', mask_size)
+ else:
+ self.psamask = PSAMask(psa_type, mask_size)
+ # NOTE(review): a 1x1 conv with padding=1 presumably grows the map by
+ # one pixel per side; forward() resizes the result back to the input
+ # size afterwards - confirm the padding is intended.
+ self.proj = ConvModule(
+ self.channels * (2 if psa_type == 'bi-direction' else 1),
+ self.in_channels,
+ kernel_size=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ # Fuses the head input with the projected attention features.
+ self.bottleneck = ConvModule(
+ self.in_channels * 2,
+ self.channels,
+ kernel_size=3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ identity = x
+ align_corners = self.align_corners
+ # Single-direction variants ('collect' or 'distribute').
+ if self.psa_type in ['collect', 'distribute']:
+ out = self.reduce(x)
+ n, c, h, w = out.size()
+ if self.shrink_factor != 1:
+ # NOTE(review): with `and`, the rounded-up size is used only when
+ # BOTH h and w are not multiples of shrink_factor - confirm.
+ if h % self.shrink_factor and w % self.shrink_factor:
+ h = (h - 1) // self.shrink_factor + 1
+ w = (w - 1) // self.shrink_factor + 1
+ align_corners = True
+ else:
+ h = h // self.shrink_factor
+ w = w // self.shrink_factor
+ align_corners = False
+ out = resize(
+ out,
+ size=(h, w),
+ mode='bilinear',
+ align_corners=align_corners)
+ y = self.attention(out)
+ if self.compact:
+ if self.psa_type == 'collect':
+ y = y.view(n, h * w,
+ h * w).transpose(1, 2).view(n, h * w, h, w)
+ else:
+ y = self.psamask(y)
+ if self.psa_softmax:
+ y = F.softmax(y, dim=1)
+ # Apply the (h*w x h*w) attention to the flattened features and
+ # scale by 1 / normalization_factor.
+ out = torch.bmm(
+ out.view(n, c, h * w), y.view(n, h * w, h * w)).view(
+ n, c, h, w) * (1.0 / self.normalization_factor)
+ else:
+ # 'bi-direction': separate collect and distribute branches.
+ x_col = self.reduce(x)
+ x_dis = self.reduce_p(x)
+ n, c, h, w = x_col.size()
+ if self.shrink_factor != 1:
+ # Same size rule as in the single-direction branch above.
+ if h % self.shrink_factor and w % self.shrink_factor:
+ h = (h - 1) // self.shrink_factor + 1
+ w = (w - 1) // self.shrink_factor + 1
+ align_corners = True
+ else:
+ h = h // self.shrink_factor
+ w = w // self.shrink_factor
+ align_corners = False
+ x_col = resize(
+ x_col,
+ size=(h, w),
+ mode='bilinear',
+ align_corners=align_corners)
+ x_dis = resize(
+ x_dis,
+ size=(h, w),
+ mode='bilinear',
+ align_corners=align_corners)
+ y_col = self.attention(x_col)
+ y_dis = self.attention_p(x_dis)
+ if self.compact:
+ y_dis = y_dis.view(n, h * w,
+ h * w).transpose(1, 2).view(n, h * w, h, w)
+ else:
+ y_col = self.psamask_collect(y_col)
+ y_dis = self.psamask_distribute(y_dis)
+ if self.psa_softmax:
+ y_col = F.softmax(y_col, dim=1)
+ y_dis = F.softmax(y_dis, dim=1)
+ x_col = torch.bmm(
+ x_col.view(n, c, h * w), y_col.view(n, h * w, h * w)).view(
+ n, c, h, w) * (1.0 / self.normalization_factor)
+ x_dis = torch.bmm(
+ x_dis.view(n, c, h * w), y_dis.view(n, h * w, h * w)).view(
+ n, c, h, w) * (1.0 / self.normalization_factor)
+ # Both directions are concatenated along the channel dim.
+ out = torch.cat([x_col, x_dis], 1)
+ out = self.proj(out)
+ # Back to the input resolution; align_corners is whatever the shrink
+ # step above selected (or self.align_corners if nothing was shrunk).
+ out = resize(
+ out,
+ size=identity.shape[2:],
+ mode='bilinear',
+ align_corners=align_corners)
+ out = self.bottleneck(torch.cat((identity, out), dim=1))
+ out = self.cls_seg(out)
+ return out
diff --git a/ControlNet/annotator/uniformer/mmseg/models/decode_heads/psp_head.py b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/psp_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5f1e71c70c3a20f4007c263ec471a87bb214a48
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/psp_head.py
@@ -0,0 +1,101 @@
+import torch
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import ConvModule
+from annotator.uniformer.mmseg.ops import resize
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+class PPM(nn.ModuleList):
+    """Pooling Pyramid Module used in PSPNet.
+
+    Holds one ``AdaptiveAvgPool2d`` + 1x1 ``ConvModule`` branch per pooling
+    scale; ``forward`` returns every branch output resized to the input size.
+
+    Args:
+        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
+            Module.
+        in_channels (int): Input channels.
+        channels (int): Channels after modules, before conv_seg.
+        conv_cfg (dict|None): Config of conv layers.
+        norm_cfg (dict|None): Config of norm layers.
+        act_cfg (dict): Config of activation layers.
+        align_corners (bool): align_corners argument of F.interpolate.
+    """
+
+    def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg,
+                 act_cfg, align_corners):
+        super(PPM, self).__init__()
+        self.pool_scales = pool_scales
+        self.align_corners = align_corners
+        self.in_channels = in_channels
+        self.channels = channels
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+        self.act_cfg = act_cfg
+        for scale in pool_scales:
+            pooling = nn.AdaptiveAvgPool2d(scale)
+            projection = ConvModule(
+                in_channels,
+                channels,
+                1,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg)
+            self.append(nn.Sequential(pooling, projection))
+
+    def forward(self, x):
+        """Forward function."""
+        upsampled = []
+        for branch in self:
+            pooled = branch(x)
+            upsampled.append(
+                resize(
+                    pooled,
+                    size=x.size()[2:],
+                    mode='bilinear',
+                    align_corners=self.align_corners))
+        return upsampled
+class PSPHead(BaseDecodeHead):
+    """Pyramid Scene Parsing Network decode head (PSPNet).
+
+    The head input is concatenated with the outputs of a ``PPM`` and fused
+    by a 3x3 ``ConvModule`` bottleneck before classification.
+
+    Args:
+        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
+            Module. Default: (1, 2, 3, 6).
+    """
+
+    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
+        super(PSPHead, self).__init__(**kwargs)
+        assert isinstance(pool_scales, (list, tuple))
+        self.pool_scales = pool_scales
+        # Layer configs shared by the pyramid branches and the bottleneck.
+        layer_cfg = dict(
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+        self.psp_modules = PPM(
+            self.pool_scales,
+            self.in_channels,
+            self.channels,
+            align_corners=self.align_corners,
+            **layer_cfg)
+        # The bottleneck sees the input plus one map per pooling scale.
+        fused_channels = self.in_channels + len(pool_scales) * self.channels
+        self.bottleneck = ConvModule(
+            fused_channels, self.channels, 3, padding=1, **layer_cfg)
+
+    def forward(self, inputs):
+        """Forward function."""
+        x = self._transform_inputs(inputs)
+        pyramid = torch.cat([x, *self.psp_modules(x)], dim=1)
+        return self.cls_seg(self.bottleneck(pyramid))
diff --git a/ControlNet/annotator/uniformer/mmseg/models/decode_heads/sep_aspp_head.py b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/sep_aspp_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..3339a7ac56e77dfc638e9bffb557d4699148686b
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/sep_aspp_head.py
@@ -0,0 +1,101 @@
+import torch
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
+from annotator.uniformer.mmseg.ops import resize
+from ..builder import HEADS
+from .aspp_head import ASPPHead, ASPPModule
class DepthwiseSeparableASPPModule(ASPPModule):
    """ASPP module whose dilated branches use depthwise separable convs.

    After the parent constructor has populated the container, every entry
    whose dilation is greater than 1 is replaced by a 3x3
    ``DepthwiseSeparableConvModule`` with matching dilation and padding.
    Entries with dilation 1 (or less) are left as the parent built them.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        for idx, rate in enumerate(self.dilations):
            if rate <= 1:
                continue
            self[idx] = DepthwiseSeparableConvModule(
                self.in_channels,
                self.channels,
                3,
                dilation=rate,
                padding=rate,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
# NOTE(review): `HEADS` is imported by this module but not used in the visible
# code; confirm whether a registry decorator belongs on this class.
class DepthwiseSeparableASPPHead(ASPPHead):
    """ASPP head with separable convolutions and an optional c1 decoder.

    This head is the implementation of DeepLabV3+ (Encoder-Decoder with
    Atrous Separable Convolution for Semantic Image Segmentation).

    Args:
        c1_in_channels (int): The input channels of c1 decoder. If it is 0,
            no c1 decoder is built and ``inputs[0]`` is not used in
            ``forward``.
        c1_channels (int): The intermediate channels of c1 decoder.
        **kwargs: Forwarded unchanged to ``ASPPHead``.
    """

    def __init__(self, c1_in_channels, c1_channels, **kwargs):
        super().__init__(**kwargs)
        assert c1_in_channels >= 0

        # Replace the ASPP branches built by the parent with separable ones.
        self.aspp_modules = DepthwiseSeparableASPPModule(
            dilations=self.dilations,
            in_channels=self.in_channels,
            channels=self.channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        self.c1_bottleneck = ConvModule(
            c1_in_channels,
            c1_channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg) if c1_in_channels > 0 else None

        # Both separable convs share everything except the input width.
        sep_cfg = dict(padding=1, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg)
        self.sep_bottleneck = nn.Sequential(
            DepthwiseSeparableConvModule(
                self.channels + c1_channels, self.channels, 3, **sep_cfg),
            DepthwiseSeparableConvModule(
                self.channels, self.channels, 3, **sep_cfg))

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)

        pooled = resize(
            self.image_pool(x),
            size=x.size()[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        branches = [pooled, *self.aspp_modules(x)]
        output = self.bottleneck(torch.cat(branches, dim=1))

        if self.c1_bottleneck is not None:
            # Fuse with the first input after resizing to its resolution.
            low_level = self.c1_bottleneck(inputs[0])
            output = resize(
                input=output,
                size=low_level.shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            output = torch.cat([output, low_level], dim=1)

        output = self.sep_bottleneck(output)
        return self.cls_seg(output)
diff --git a/ControlNet/annotator/uniformer/mmseg/models/decode_heads/sep_fcn_head.py b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/sep_fcn_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0986143fa4f2bd36f5271354fe5f843f35b9e6f
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/sep_fcn_head.py
@@ -0,0 +1,51 @@
+from annotator.uniformer.mmcv.cnn import DepthwiseSeparableConvModule
+from ..builder import HEADS
+from .fcn_head import FCNHead
# NOTE(review): `HEADS` is imported by this module but not used in the visible
# code; confirm whether a registry decorator belongs on this class.
class DepthwiseSeparableFCNHead(FCNHead):
    """FCN head whose convolutions are depthwise separable.

    This head is implemented according to Fast-SCNN paper. The parent
    constructor builds the head; this class then overwrites ``convs[0]``,
    ``convs[1:num_convs]`` and, when ``concat_input`` is set, ``conv_cat``
    with ``DepthwiseSeparableConvModule`` instances.

    Args:
        in_channels(int): Number of output channels of FFM.
        channels(int): Number of middle-stage channels in the decode head.
        concat_input(bool): Whether to concatenate original decode input into
            the result of several consecutive convolution layers.
            Default: True.
        num_classes(int): Used to determine the dimension of
            final prediction tensor.
        in_index(int): Correspond with 'out_indices' in FastSCNN backbone.
        norm_cfg (dict | None): Config of norm layers.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        loss_decode(dict): Config of loss type and some
            relevant additional options.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        def _sep_conv(in_channels):
            # Every replacement layer differs only in its input width.
            return DepthwiseSeparableConvModule(
                in_channels,
                self.channels,
                kernel_size=self.kernel_size,
                padding=self.kernel_size // 2,
                norm_cfg=self.norm_cfg)

        self.convs[0] = _sep_conv(self.in_channels)
        for idx in range(1, self.num_convs):
            self.convs[idx] = _sep_conv(self.channels)

        if self.concat_input:
            self.conv_cat = _sep_conv(self.in_channels + self.channels)
diff --git a/ControlNet/annotator/uniformer/mmseg/models/decode_heads/uper_head.py b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/uper_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e1301b706b0d83ed714bbdee8ee24693f150455
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/decode_heads/uper_head.py
@@ -0,0 +1,126 @@
+import torch
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import ConvModule
+from annotator.uniformer.mmseg.ops import resize
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+from .psp_head import PPM
# NOTE(review): `HEADS` is imported by this module but not used in the visible
# code; confirm whether a registry decorator belongs on this class.
class UPerHead(BaseDecodeHead):
    """Decode head combining a PSP module with an FPN-style top-down path.

    This head is the implementation of UPerNet (Unified Perceptual Parsing
    for Scene Understanding). The parent is always initialised with
    ``input_transform='multiple_select'``.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module applied on the last feature. Default: (1, 2, 3, 6).
        **kwargs: Forwarded unchanged to ``BaseDecodeHead``.
    """

    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
        super().__init__(input_transform='multiple_select', **kwargs)

        shared_cfg = dict(
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        top_channels = self.in_channels[-1]

        # PSP module, applied to the last input only.
        self.psp_modules = PPM(
            pool_scales,
            top_channels,
            self.channels,
            align_corners=self.align_corners,
            **shared_cfg)
        self.bottleneck = ConvModule(
            top_channels + len(pool_scales) * self.channels,
            self.channels,
            3,
            padding=1,
            **shared_cfg)

        # FPN module: one lateral 1x1 conv and one 3x3 output conv per input
        # except the last, which is handled by the PSP module above.
        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()
        for level_channels in self.in_channels[:-1]:
            self.lateral_convs.append(
                ConvModule(
                    level_channels,
                    self.channels,
                    1,
                    inplace=False,
                    **shared_cfg))
            self.fpn_convs.append(
                ConvModule(
                    self.channels,
                    self.channels,
                    3,
                    padding=1,
                    inplace=False,
                    **shared_cfg))

        self.fpn_bottleneck = ConvModule(
            len(self.in_channels) * self.channels,
            self.channels,
            3,
            padding=1,
            **shared_cfg)

    def psp_forward(self, inputs):
        """Forward function of PSP module."""
        top = inputs[-1]
        pyramid = [top, *self.psp_modules(top)]
        return self.bottleneck(torch.cat(pyramid, dim=1))

    def forward(self, inputs):
        """Forward function."""
        inputs = self._transform_inputs(inputs)

        # Lateral features: 1x1 convs on every input but the last, plus the
        # PSP output for the last one.
        laterals = [
            conv(inputs[idx]) for idx, conv in enumerate(self.lateral_convs)
        ]
        laterals.append(self.psp_forward(inputs))
        num_levels = len(laterals)

        # Top-down path: add each level, resized, into the one before it.
        # The accumulation is deliberately in place, as in the original.
        for upper in range(num_levels - 1, 0, -1):
            lower = upper - 1
            laterals[lower] += resize(
                laterals[upper],
                size=laterals[lower].shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)

        # Output convs on all levels except the last; the PSP feature is
        # appended as-is.
        fpn_outs = [
            self.fpn_convs[idx](laterals[idx])
            for idx in range(num_levels - 1)
        ]
        fpn_outs.append(laterals[-1])

        # Bring every level to the size of level 0 before concatenation.
        base_size = fpn_outs[0].shape[2:]
        for idx in range(num_levels - 1, 0, -1):
            fpn_outs[idx] = resize(
                fpn_outs[idx],
                size=base_size,
                mode='bilinear',
                align_corners=self.align_corners)

        fused = self.fpn_bottleneck(torch.cat(fpn_outs, dim=1))
        return self.cls_seg(fused)
diff --git a/ControlNet/annotator/uniformer/mmseg/models/losses/__init__.py b/ControlNet/annotator/uniformer/mmseg/models/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..beca72045694273d63465bac2f27dbc6672271db
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/losses/__init__.py
@@ -0,0 +1,12 @@
+from .accuracy import Accuracy, accuracy
+from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
+ cross_entropy, mask_cross_entropy)
+from .dice_loss import DiceLoss
+from .lovasz_loss import LovaszLoss
+from .utils import reduce_loss, weight_reduce_loss, weighted_loss
# Public names re-exported by the losses package. The list literal was left
# unclosed, which made the module a syntax error; the closing bracket is
# restored here.
__all__ = [
    'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
    'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
    'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss'
]
diff --git a/ControlNet/annotator/uniformer/mmseg/models/losses/accuracy.py b/ControlNet/annotator/uniformer/mmseg/models/losses/accuracy.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0fd2e7e74a0f721c4a814c09d6e453e5956bb38
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/losses/accuracy.py
@@ -0,0 +1,78 @@
+import torch.nn as nn
def accuracy(pred, target, topk=1, thresh=None):
    """Calculate accuracy according to the prediction and target.

    Args:
        pred (torch.Tensor): The model prediction, shape (N, num_class, ...)
        target (torch.Tensor): The target of each prediction, shape (N, ...)
        topk (int | tuple[int], optional): If the predictions in ``topk``
            matches the target, the predictions will be regarded as
            correct ones. Defaults to 1.
        thresh (float, optional): If not None, predictions with scores under
            this threshold are considered incorrect. Default to None.

    Returns:
        torch.Tensor | list[torch.Tensor]: If the input ``topk`` is a single
            integer, a single one-element tensor holding the accuracy in
            percent (a 0-dim zero tensor when ``pred`` is empty). If
            ``topk`` is a tuple, a list with one such tensor per entry.

    Raises:
        AssertionError: If ``topk`` has the wrong type, the shapes of
            ``pred`` and ``target`` are inconsistent, or ``max(topk)``
            exceeds the class dimension of ``pred``.
    """
    assert isinstance(topk, (int, tuple))
    return_single = isinstance(topk, int)
    if return_single:
        topk = (topk, )
    maxk = max(topk)

    # An empty batch has nothing to score; report zero for every k.
    if pred.size(0) == 0:
        accu = [pred.new_tensor(0.) for _ in topk]
        return accu[0] if return_single else accu

    assert pred.ndim == target.ndim + 1
    assert pred.size(0) == target.size(0)
    assert maxk <= pred.size(1), \
        f'maxk {maxk} exceeds pred dimension {pred.size(1)}'

    pred_value, pred_label = pred.topk(maxk, dim=1)
    # transpose to shape (maxk, N, ...)
    pred_label = pred_label.transpose(0, 1)
    correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label))
    if thresh is not None:
        # Only prediction values larger than thresh are counted as correct.
        # ``transpose(0, 1)`` mirrors the transpose applied to ``pred_label``
        # and, unlike ``Tensor.t()``, also works when ``pred`` has more than
        # two dimensions.
        correct = correct & (pred_value > thresh).transpose(0, 1)

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / target.numel()))
    return res[0] if return_single else res
class Accuracy(nn.Module):
    """Module wrapper that stores top-k settings and calls ``accuracy``."""

    def __init__(self, topk=(1, ), thresh=None):
        """Store the accuracy criteria.

        Args:
            topk (tuple, optional): The criterion used to calculate the
                accuracy. Defaults to (1,).
            thresh (float, optional): If not None, predictions with scores
                under this threshold are considered incorrect. Default to
                None.
        """
        super().__init__()
        self.thresh = thresh
        self.topk = topk

    def forward(self, pred, target):
        """Compute the accuracy of ``pred`` against ``target``.

        Args:
            pred (torch.Tensor): Prediction of models.
            target (torch.Tensor): Target for each prediction.

        Returns:
            Whatever ``accuracy`` returns for the stored ``topk`` and
            ``thresh``.
        """
        return accuracy(pred, target, topk=self.topk, thresh=self.thresh)
diff --git a/ControlNet/annotator/uniformer/mmseg/models/losses/cross_entropy_loss.py b/ControlNet/annotator/uniformer/mmseg/models/losses/cross_entropy_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..42c0790c98616bb69621deed55547fc04c7392ef
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/losses/cross_entropy_loss.py
@@ -0,0 +1,198 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from ..builder import LOSSES
+from .utils import get_class_weight, weight_reduce_loss
def cross_entropy(pred,
                  label,
                  weight=None,
                  class_weight=None,
                  reduction='mean',
                  avg_factor=None,
                  ignore_index=-100):
    """Wrap :func:`F.cross_entropy` with element weighting and reduction.

    Args:
        pred (torch.Tensor): Prediction passed to ``F.cross_entropy``.
        label (torch.Tensor): Target passed to ``F.cross_entropy``.
        weight (torch.Tensor, optional): Element-wise loss weight; cast to
            float before use.
        class_weight (torch.Tensor, optional): Manual rescaling weight given
            to each class (the ``weight`` argument of ``F.cross_entropy``).
        reduction (str, optional): Reduction handed to
            ``weight_reduce_loss``.
        avg_factor (optional): Average factor handed to
            ``weight_reduce_loss``.
        ignore_index (int): Label value ignored by ``F.cross_entropy``.
            Default: -100.

    Returns:
        torch.Tensor: The weighted, reduced loss.
    """
    # Always compute element-wise losses; the reduction happens afterwards so
    # that the optional element weights can be applied first.
    elementwise = F.cross_entropy(
        pred,
        label,
        weight=class_weight,
        reduction='none',
        ignore_index=ignore_index)

    sample_weight = weight.float() if weight is not None else None
    return weight_reduce_loss(
        elementwise,
        weight=sample_weight,
        reduction=reduction,
        avg_factor=avg_factor)
+def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
+ """Expand onehot labels to match the size of prediction."""
+ bin_labels = labels.new_zeros(target_shape)
+ valid_mask = (labels >= 0) & (labels != ignore_index)
+ inds = torch.nonzero(valid_mask, as_tuple=True)
+ if inds[0].numel() > 0:
+ if labels.dim() == 3:
+ bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1
+ else:
+ bin_labels[inds[0], labels[valid_mask]] = 1
+ valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()
+ if label_weights is None:
+ bin_label_weights = valid_mask
+ else:
+ bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
+ bin_label_weights *= valid_mask
+ return bin_labels, bin_label_weights
def binary_cross_entropy(pred,
                         label,
                         weight=None,
                         reduction='mean',
                         avg_factor=None,
                         class_weight=None,
                         ignore_index=255):
    """Calculate the binary CrossEntropy loss.

    When ``pred`` and ``label`` differ in number of dimensions, ``label``
    (and ``weight``) are first expanded to one-hot form matching
    ``pred.shape``; ``ignore_index`` is only used in that case.

    Args:
        pred (torch.Tensor): The prediction logits.
        label (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (torch.Tensor, optional): Passed as ``pos_weight`` to
            ``F.binary_cross_entropy_with_logits``.
        ignore_index (int | None): The label index to be ignored. Default: 255

    Returns:
        torch.Tensor: The calculated loss
    """
    if pred.dim() != label.dim():
        class_index_form = (pred.dim() == 2 and label.dim() == 1) or (
            pred.dim() == 4 and label.dim() == 3)
        assert class_index_form, \
            'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \
            'H, W], label shape [N, H, W] are supported'
        label, weight = _expand_onehot_labels(label, weight, pred.shape,
                                              ignore_index)

    if weight is not None:
        weight = weight.float()

    # Element-wise losses first; weighting and reduction happen afterwards.
    elementwise = F.binary_cross_entropy_with_logits(
        pred, label.float(), pos_weight=class_weight, reduction='none')
    return weight_reduce_loss(
        elementwise, weight, reduction=reduction, avg_factor=avg_factor)
def mask_cross_entropy(pred,
                       target,
                       label,
                       reduction='mean',
                       avg_factor=None,
                       class_weight=None,
                       ignore_index=None):
    """Calculate the CrossEntropy loss for masks.

    For every row ``i`` of ``pred`` the slice ``pred[i, label[i]]`` is
    selected and compared with ``target`` using binary cross entropy with
    logits, averaged over all elements.

    Args:
        pred (torch.Tensor): The prediction, indexed as (N, C, ...) where C
            is the number of classes.
        target (torch.Tensor): The learning label of the prediction.
        label (torch.Tensor): Class index per row of ``pred``; selects which
            class slice of the prediction is scored.
        reduction (str, optional): Reserved; must be 'mean'.
        avg_factor (int, optional): Reserved; must be None.
        class_weight (torch.Tensor, optional): Passed as ``weight`` to
            ``F.binary_cross_entropy_with_logits``.
        ignore_index (None): Placeholder, to be consistent with other loss.
            Default: None.

    Returns:
        torch.Tensor: The calculated loss, as a one-element tensor.

    Raises:
        AssertionError: If ``ignore_index`` is not None, ``reduction`` is not
            'mean' or ``avg_factor`` is not None.
    """
    assert ignore_index is None, 'BCE loss does not support ignore_index'
    # TODO: handle these two reserved arguments
    assert reduction == 'mean' and avg_factor is None

    row_index = torch.arange(
        0, pred.size()[0], dtype=torch.long, device=pred.device)
    selected = pred[row_index, label].squeeze(1)
    loss = F.binary_cross_entropy_with_logits(
        selected, target, weight=class_weight, reduction='mean')
    # Add a leading dimension so the result is a 1-element tensor.
    return loss[None]
# NOTE(review): `LOSSES` is imported by this module but not used in the
# visible code; confirm whether a registry decorator belongs on this class.
class CrossEntropyLoss(nn.Module):
    """Cross entropy loss module with selectable criterion.

    Args:
        use_sigmoid (bool, optional): If True, ``binary_cross_entropy`` is
            used as criterion. Defaults to False.
        use_mask (bool, optional): If True (and ``use_sigmoid`` is False),
            ``mask_cross_entropy`` is used as criterion. Defaults to False.
        reduction (str, optional): Defaults to 'mean'.
            Options are "none", "mean" and "sum".
        class_weight (list[float] | str, optional): Weight of each class. If
            in str format, read them from a file. Defaults to None.
        loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
    """

    def __init__(self,
                 use_sigmoid=False,
                 use_mask=False,
                 reduction='mean',
                 class_weight=None,
                 loss_weight=1.0):
        super().__init__()
        # The two alternative criteria are mutually exclusive.
        assert (use_sigmoid is False) or (use_mask is False)
        self.use_sigmoid = use_sigmoid
        self.use_mask = use_mask
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = get_class_weight(class_weight)

        criterion = cross_entropy
        if self.use_sigmoid:
            criterion = binary_cross_entropy
        elif self.use_mask:
            criterion = mask_cross_entropy
        self.cls_criterion = criterion

    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        """Compute the weighted loss of ``cls_score`` against ``label``.

        ``reduction_override`` replaces the configured reduction when it is
        set; extra keyword arguments are forwarded to the criterion.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = reduction_override or self.reduction

        # Materialise the class weights on the same device/dtype as the
        # scores.
        class_weight = None
        if self.class_weight is not None:
            class_weight = cls_score.new_tensor(self.class_weight)

        raw_loss = self.cls_criterion(
            cls_score,
            label,
            weight,
            class_weight=class_weight,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return self.loss_weight * raw_loss
diff --git a/ControlNet/annotator/uniformer/mmseg/models/losses/dice_loss.py b/ControlNet/annotator/uniformer/mmseg/models/losses/dice_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..27a77b962d7d8b3079c7d6cd9db52280c6fb4970
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/losses/dice_loss.py
@@ -0,0 +1,119 @@
+"""Modified from https://github.com/LikeLy-Journey/SegmenTron/blob/master/
+segmentron/solver/loss.py (Apache-2.0 License)"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from ..builder import LOSSES
+from .utils import get_class_weight, weighted_loss
# `DiceLoss.forward` calls this function with `reduction=` and `avg_factor=`,
# which the bare signature below does not accept, and `weighted_loss` is
# imported from `.utils` but otherwise unused. The decorator is therefore
# restored here.
# NOTE(review): `weighted_loss` is assumed to consume `weight`, `reduction`
# and `avg_factor` and forward the remaining keywords — confirm against its
# definition in `.utils`.
@weighted_loss
def dice_loss(pred,
              target,
              valid_mask,
              smooth=1,
              exponent=2,
              class_weight=None,
              ignore_index=255):
    """Average the binary dice loss over the class channels of ``pred``.

    Args:
        pred (torch.Tensor): Per-class scores, indexed as ``pred[:, i]``.
        target (torch.Tensor): One-hot targets with the class on the last
            axis, indexed as ``target[..., i]``.
        valid_mask (torch.Tensor): Mask forwarded to ``binary_dice_loss``.
        smooth (float): Smoothing term forwarded to ``binary_dice_loss``.
        exponent (float): Exponent forwarded to ``binary_dice_loss``.
        class_weight (optional): Per-class multipliers, indexed by class.
        ignore_index (int): Class index that is skipped. Default: 255.

    Returns:
        The summed per-class losses divided by the total number of class
        channels (a skipped class still counts in the divisor).
    """
    assert pred.shape[0] == target.shape[0]
    total_loss = 0
    num_classes = pred.shape[1]
    for i in range(num_classes):
        if i == ignore_index:
            continue
        # A distinct local name avoids shadowing this function's own name.
        class_loss = binary_dice_loss(
            pred[:, i],
            target[..., i],
            valid_mask=valid_mask,
            smooth=smooth,
            exponent=exponent)
        if class_weight is not None:
            class_loss *= class_weight[i]
        total_loss += class_loss
    return total_loss / num_classes
def binary_dice_loss(pred, target, valid_mask, smooth=1, exponent=2, **kwards):
    """Compute one dice loss value per sample.

    All three tensors are flattened to ``(batch, -1)``. The numerator is
    ``2 * sum(pred * target * valid_mask) + smooth``; the denominator is
    ``sum(pred ** exponent + target ** exponent) + smooth`` (the mask is not
    applied to the denominator).

    Args:
        pred (torch.Tensor): Predictions; first axis is the batch.
        target (torch.Tensor): Targets with the same batch size as ``pred``.
        valid_mask (torch.Tensor): Mask multiplied into the overlap term.
        smooth (float): Added to numerator and denominator. Default: 1.
        exponent (float): Power applied in the denominator. Default: 2.
        **kwards: Accepted and ignored.

    Returns:
        torch.Tensor: ``1 - numerator / denominator`` for each sample.
    """
    assert pred.shape[0] == target.shape[0]
    flat_pred = pred.reshape(pred.shape[0], -1)
    flat_target = target.reshape(target.shape[0], -1)
    flat_mask = valid_mask.reshape(valid_mask.shape[0], -1)

    overlap = torch.sum(torch.mul(flat_pred, flat_target) * flat_mask, dim=1)
    power_sum = torch.sum(
        flat_pred.pow(exponent) + flat_target.pow(exponent), dim=1)
    return 1 - (overlap * 2 + smooth) / (power_sum + smooth)
# NOTE(review): `LOSSES` is imported by this module but not used in the
# visible code; confirm whether a registry decorator belongs on this class.
class DiceLoss(nn.Module):
    """Dice loss module.

    This loss is proposed in `V-Net: Fully Convolutional Neural Networks for
    Volumetric Medical Image Segmentation`.

    Args:
        smooth (float): A float number to smooth loss, and avoid NaN error.
            Default: 1
        exponent (float): A float number to calculate denominator
            value: \\sum{x^exponent} + \\sum{y^exponent}. Default: 2.
        reduction (str, optional): The method used to reduce the loss. Options
            are "none", "mean" and "sum". Default: 'mean'.
        class_weight (list[float] | str, optional): Weight of each class. If in
            str format, read them from a file. Defaults to None.
        loss_weight (float, optional): Weight of the loss. Default to 1.0.
        ignore_index (int | None): The label index to be ignored. Default: 255.
        **kwards: Accepted and ignored.
    """

    def __init__(self,
                 smooth=1,
                 exponent=2,
                 reduction='mean',
                 class_weight=None,
                 loss_weight=1.0,
                 ignore_index=255,
                 **kwards):
        super().__init__()
        self.smooth = smooth
        self.exponent = exponent
        self.reduction = reduction
        self.class_weight = get_class_weight(class_weight)
        self.loss_weight = loss_weight
        self.ignore_index = ignore_index

    def forward(self,
                pred,
                target,
                avg_factor=None,
                reduction_override=None,
                **kwards):
        """Compute the weighted dice loss of ``pred`` against ``target``.

        ``pred`` is passed through a softmax over dim 1 and ``target`` is
        converted to one-hot form (labels are clamped into the valid class
        range first). ``reduction`` and ``avg_factor`` are forwarded to
        ``dice_loss`` as keyword arguments.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = reduction_override or self.reduction

        class_weight = None
        if self.class_weight is not None:
            class_weight = pred.new_tensor(self.class_weight)

        probs = F.softmax(pred, dim=1)
        num_classes = probs.shape[1]
        # Clamping keeps out-of-range labels (such as the ignore value)
        # encodable; those positions are excluded through `valid_mask`.
        clamped = torch.clamp(target.long(), 0, num_classes - 1)
        one_hot_target = F.one_hot(clamped, num_classes=num_classes)
        valid_mask = (target != self.ignore_index).long()

        raw_loss = dice_loss(
            probs,
            one_hot_target,
            valid_mask=valid_mask,
            reduction=reduction,
            avg_factor=avg_factor,
            smooth=self.smooth,
            exponent=self.exponent,
            class_weight=class_weight,
            ignore_index=self.ignore_index)
        return self.loss_weight * raw_loss
diff --git a/ControlNet/annotator/uniformer/mmseg/models/losses/lovasz_loss.py b/ControlNet/annotator/uniformer/mmseg/models/losses/lovasz_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..6badb67f6d987b59fb07aa97caaaf89896e27a8d
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/losses/lovasz_loss.py
@@ -0,0 +1,303 @@
+"""Modified from https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytor
+ch/lovasz_losses.py Lovasz-Softmax and Jaccard hinge loss in PyTorch Maxim
+Berman 2018 ESAT-PSI KU Leuven (MIT License)"""
+import annotator.uniformer.mmcv as mmcv
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from ..builder import LOSSES
+from .utils import get_class_weight, weight_reduce_loss
def lovasz_grad(gt_sorted):
    """Computes gradient of the Lovasz extension w.r.t sorted errors.

    See Alg. 1 in paper.

    Args:
        gt_sorted (torch.Tensor): 1-D ground truth values, ordered to match
            the sorted errors.

    Returns:
        torch.Tensor: Tensor with one entry per element of ``gt_sorted``:
        the first Jaccard value followed by the successive differences.
    """
    count = len(gt_sorted)
    total_fg = gt_sorted.sum()
    running_fg = gt_sorted.float().cumsum(0)
    running_bg = (1 - gt_sorted).float().cumsum(0)

    jaccard = 1. - (total_fg - running_fg) / (total_fg + running_bg)
    if count > 1:  # cover 1-pixel case
        # The right-hand side is evaluated before the assignment, so the
        # differences use the unmodified values.
        jaccard[1:count] = jaccard[1:count] - jaccard[0:-1]
    return jaccard
def flatten_binary_logits(logits, labels, ignore_index=None):
    """Flatten logits and labels to 1-D, dropping ignored labels.

    Args:
        logits (torch.Tensor): Logits; flattened with ``view(-1)``.
        labels (torch.Tensor): Labels; flattened with ``view(-1)``.
        ignore_index (int | None): When not None, positions whose label
            equals this value are removed from both outputs.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: Flattened logits and labels.
    """
    flat_logits = logits.view(-1)
    flat_labels = labels.view(-1)
    if ignore_index is not None:
        keep = flat_labels != ignore_index
        return flat_logits[keep], flat_labels[keep]
    return flat_logits, flat_labels
def flatten_probs(probs, labels, ignore_index=None):
    """Flatten predictions in the batch to shape (P, C).

    Args:
        probs (torch.Tensor): Probabilities of shape (B, C, H, W), or
            (B, H, W), which is treated as a single channel.
        labels (torch.Tensor): Labels; flattened with ``view(-1)``.
        ignore_index (int | None): When not None, positions whose label
            equals this value are removed from both outputs.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: Probabilities of shape (P, C) and
        labels of shape (P,), where P is the number of kept positions. The
        probabilities stay 2-D even when P is 0 or 1.
    """
    if probs.dim() == 3:
        # assumes output of a sigmoid layer
        B, H, W = probs.size()
        probs = probs.view(B, 1, H, W)
    B, C, H, W = probs.size()
    probs = probs.permute(0, 2, 3, 1).contiguous().view(-1, C)  # B*H*W, C=P,C
    labels = labels.view(-1)
    if ignore_index is None:
        return probs, labels

    valid = (labels != ignore_index)
    # Boolean-mask indexing keeps the (P, C) layout. The previous
    # `valid.nonzero().squeeze()` index collapsed to a 0-dim tensor when
    # exactly one position was valid, which turned the result into shape (C,).
    return probs[valid], labels[valid]
def lovasz_hinge_flat(logits, labels):
    """Binary Lovasz hinge loss on flattened inputs.

    Args:
        logits (torch.Tensor): [P], logits at each prediction
            (between -infty and +infty).
        labels (torch.Tensor): [P], binary ground truth labels (0 or 1).

    Returns:
        torch.Tensor: The calculated loss; a zero that still depends on
        ``logits`` when ``labels`` is empty.
    """
    if len(labels) == 0:
        # only void pixels, the gradients should be 0
        return logits.sum() * 0.

    # Map labels {0, 1} to signs {-1, +1} and compute the hinge margins.
    signs = 2. * labels.float() - 1.
    margins = 1. - logits * signs
    margins_sorted, order = torch.sort(margins, dim=0, descending=True)
    order = order.data
    grad = lovasz_grad(labels[order])
    return torch.dot(F.relu(margins_sorted), grad)
def lovasz_hinge(logits,
                 labels,
                 classes='present',
                 per_image=False,
                 class_weight=None,
                 reduction='mean',
                 avg_factor=None,
                 ignore_index=255):
    """Binary Lovasz hinge loss.

    Args:
        logits (torch.Tensor): [B, H, W], logits at each pixel
            (between -infty and +infty).
        labels (torch.Tensor): [B, H, W], binary ground truth masks (0 or 1).
        classes (str | list[int], optional): Placeholder, to be consistent with
            other loss; not used here.
        per_image (bool, optional): If per_image is True, compute the loss per
            image instead of per batch. Default: False.
        class_weight (list[float], optional): Placeholder, to be consistent
            with other loss; not used here.
        reduction (str, optional): The method used to reduce the loss. Options
            are "none", "mean" and "sum". This parameter only works when
            per_image is True. Default: 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. This parameter only works when per_image is True.
            Default: None.
        ignore_index (int | None): The label index to be ignored. Default: 255.

    Returns:
        torch.Tensor: The calculated loss.
    """
    if not per_image:
        # Whole batch at once: a single loss value, no further reduction.
        return lovasz_hinge_flat(
            *flatten_binary_logits(logits, labels, ignore_index))

    per_image_losses = [
        lovasz_hinge_flat(*flatten_binary_logits(
            logit.unsqueeze(0), label.unsqueeze(0), ignore_index))
        for logit, label in zip(logits, labels)
    ]
    return weight_reduce_loss(
        torch.stack(per_image_losses), None, reduction, avg_factor)
def lovasz_softmax_flat(probs, labels, classes='present', class_weight=None):
    """Multi-class Lovasz-Softmax loss on flattened inputs.

    Args:
        probs (torch.Tensor): [P, C], class probabilities at each prediction
            (between 0 and 1).
        labels (torch.Tensor): [P], ground truth labels (between 0 and C - 1).
        classes (str | list[int], optional): Classes chosen to calculate loss.
            'all' for all classes, 'present' for classes present in labels, or
            a list of classes to average. Default: 'present'.
        class_weight (list[float], optional): The weight for each class.
            Default: None.

    Returns:
        torch.Tensor: The calculated loss, always a scalar. A zero that still
        depends on ``probs`` is returned when ``probs`` is empty or when no
        class contributes a loss term.

    Raises:
        ValueError: If ``probs`` has a single channel but more than one class
            is selected.
    """
    if probs.numel() == 0:
        # only void pixels, the gradients should be 0. Return a scalar (as
        # `lovasz_hinge_flat` does) rather than an empty tensor, so the value
        # can be stacked and reduced like any other loss.
        return probs.sum() * 0.

    C = probs.size(1)
    losses = []
    class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
    for c in class_to_sum:
        fg = (labels == c).float()  # foreground for class c
        if classes == 'present' and fg.sum() == 0:
            continue
        if C == 1:
            # Count the selected classes, not the characters of the
            # 'all'/'present' keyword; `len(classes)` made every
            # single-channel call with a string selector raise.
            if len(class_to_sum) > 1:
                raise ValueError('Sigmoid output possible only with 1 class')
            class_pred = probs[:, 0]
        else:
            class_pred = probs[:, c]
        errors = (fg - class_pred).abs()
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        perm = perm.data
        fg_sorted = fg[perm]
        loss = torch.dot(errors_sorted, lovasz_grad(fg_sorted))
        if class_weight is not None:
            loss = loss * class_weight[c]
        losses.append(loss)

    if not losses:
        # Every class was skipped; `torch.stack([])` would raise.
        return probs.sum() * 0.
    return torch.stack(losses).mean()
def lovasz_softmax(probs,
                   labels,
                   classes='present',
                   per_image=False,
                   class_weight=None,
                   reduction='mean',
                   avg_factor=None,
                   ignore_index=255):
    """Multi-class Lovasz-Softmax loss.

    Args:
        probs (torch.Tensor): [B, C, H, W], class probabilities at each
            prediction (between 0 and 1).
        labels (torch.Tensor): [B, H, W], ground truth labels (between 0 and
            C - 1).
        classes (str | list[int], optional): Classes chosen to calculate loss.
            'all' for all classes, 'present' for classes present in labels, or
            a list of classes to average. Default: 'present'.
        per_image (bool, optional): If per_image is True, compute the loss per
            image instead of per batch. Default: False.
        class_weight (list[float], optional): The weight for each class.
            Default: None.
        reduction (str, optional): The method used to reduce the loss. Options
            are "none", "mean" and "sum". This parameter only works when
            per_image is True. Default: 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. This parameter only works when per_image is True.
            Default: None.
        ignore_index (int | None): The label index to be ignored. Default: 255.

    Returns:
        torch.Tensor: The calculated loss.
    """
    if not per_image:
        # Whole batch at once: a single loss value, no further reduction.
        return lovasz_softmax_flat(
            *flatten_probs(probs, labels, ignore_index),
            classes=classes,
            class_weight=class_weight)

    per_image_losses = [
        lovasz_softmax_flat(
            *flatten_probs(
                prob.unsqueeze(0), label.unsqueeze(0), ignore_index),
            classes=classes,
            class_weight=class_weight)
        for prob, label in zip(probs, labels)
    ]
    return weight_reduce_loss(
        torch.stack(per_image_losses), None, reduction, avg_factor)
# NOTE(review): `LOSSES` is imported by this module but not used in the
# visible code; confirm whether a registry decorator belongs on this class.
class LovaszLoss(nn.Module):
    """Lovasz loss module.

    This loss is proposed in `The Lovasz-Softmax loss: A tractable surrogate
    for the optimization of the intersection-over-union measure in neural
    networks`.

    Note:
        When ``per_image`` is False, ``reduction`` must be ``'none'``. The
        default arguments (``per_image=False``, ``reduction='mean'``)
        therefore fail the constructor's assertion; pass
        ``reduction='none'`` or ``per_image=True``.

    Args:
        loss_type (str, optional): Binary or multi-class loss.
            Default: 'multi_class'. Options are "binary" and "multi_class".
        classes (str | list[int], optional): Classes chosen to calculate loss.
            'all' for all classes, 'present' for classes present in labels, or
            a list of classes to average. Default: 'present'.
        per_image (bool, optional): If per_image is True, compute the loss per
            image instead of per batch. Default: False.
        reduction (str, optional): The method used to reduce the loss. Options
            are "none", "mean" and "sum". This parameter only works when
            per_image is True. Default: 'mean'.
        class_weight (list[float] | str, optional): Weight of each class. If in
            str format, read them from a file. Defaults to None.
        loss_weight (float, optional): Weight of the loss. Defaults to 1.0.

    Raises:
        AssertionError: If ``loss_type`` or ``classes`` is invalid, or if
            ``reduction`` is not 'none' while ``per_image`` is False.
    """

    def __init__(self,
                 loss_type='multi_class',
                 classes='present',
                 per_image=False,
                 reduction='mean',
                 class_weight=None,
                 loss_weight=1.0):
        super(LovaszLoss, self).__init__()
        # Messages are plain adjacent literals: a backslash continuation
        # inside a string literal embeds the next line's indentation in the
        # message text.
        assert loss_type in ('binary', 'multi_class'), (
            "loss_type should be 'binary' or 'multi_class'.")

        if loss_type == 'binary':
            self.cls_criterion = lovasz_hinge
        else:
            self.cls_criterion = lovasz_softmax
        assert classes in ('all', 'present') or mmcv.is_list_of(classes, int)
        if not per_image:
            assert reduction == 'none', (
                "reduction should be 'none' when per_image is False.")

        self.classes = classes
        self.per_image = per_image
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = get_class_weight(class_weight)

    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        """Compute the weighted Lovasz loss of ``cls_score`` against ``label``.

        ``weight`` is accepted but not used. ``reduction_override`` replaces
        the configured reduction when it is set; extra keyword arguments are
        forwarded to the criterion.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.class_weight is not None:
            class_weight = cls_score.new_tensor(self.class_weight)
        else:
            class_weight = None

        # if multi-class loss, transform logits to probs
        if self.cls_criterion == lovasz_softmax:
            cls_score = F.softmax(cls_score, dim=1)

        loss_cls = self.loss_weight * self.cls_criterion(
            cls_score,
            label,
            self.classes,
            self.per_image,
            class_weight=class_weight,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return loss_cls
diff --git a/ControlNet/annotator/uniformer/mmseg/models/losses/utils.py b/ControlNet/annotator/uniformer/mmseg/models/losses/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..85aec9f3045240c3de96a928324ae8f5c3aebe8b
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/losses/utils.py
@@ -0,0 +1,121 @@
+import functools
+import annotator.uniformer.mmcv as mmcv
+import numpy as np
+import torch.nn.functional as F
+def get_class_weight(class_weight):
+    """Resolve the class weights used by a loss function.
+    Args:
+        class_weight (list[float] | str | None): Weights themselves, or a
+            file name to read them from. Non-string values are returned
+            untouched.
+    Returns:
+        The loaded weights for a string argument, otherwise the argument
+        itself.
+    """
+    if not isinstance(class_weight, str):
+        return class_weight
+    # '.npy' files go through numpy, everything else (pkl, json or yaml)
+    # through mmcv.load
+    loader = np.load if class_weight.endswith('.npy') else mmcv.load
+    return loader(class_weight)
+def reduce_loss(loss, reduction):
+    """Collapse an element-wise loss according to ``reduction``.
+    Args:
+        loss (Tensor): Elementwise loss tensor.
+        reduction (str): One of "none", "mean" and "sum".
+    Return:
+        Tensor: ``loss`` itself for "none", its mean for "mean", its sum
+        for "sum".
+    """
+    # F._Reduction.get_enum maps: none -> 0, mean -> 1, sum -> 2
+    mode = F._Reduction.get_enum(reduction)
+    if mode == 1:
+        return loss.mean()
+    if mode == 2:
+        return loss.sum()
+    if mode == 0:
+        return loss
+def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
+    """Weight an element-wise loss and then reduce it.
+    Args:
+        loss (Tensor): Element-wise loss.
+        weight (Tensor): Element-wise weights.
+        reduction (str): Same as built-in losses of PyTorch.
+        avg_factor (float): Average factor when computing the mean of losses.
+    Returns:
+        Tensor: Processed loss values.
+    Raises:
+        ValueError: If ``avg_factor`` is given together with a reduction
+            other than 'mean' or 'none'.
+    """
+    if weight is not None:
+        # weight must have the same rank as loss; along dim 1 it either
+        # matches loss or has size 1
+        assert weight.dim() == loss.dim()
+        if weight.dim() > 1:
+            assert weight.size(1) in (1, loss.size(1))
+        loss = loss * weight
+    # without an explicit averaging factor, use the plain reduction
+    if avg_factor is None:
+        return reduce_loss(loss, reduction)
+    # with a factor, 'mean' divides the total by it ...
+    if reduction == 'mean':
+        return loss.sum() / avg_factor
+    # ... 'none' leaves the weighted loss as is, anything else is an error
+    if reduction != 'none':
+        raise ValueError('avg_factor can not be used with reduction="sum"')
+    return loss
+def weighted_loss(loss_func):
+    """Decorator that adds weighting and reduction to an element-wise loss.
+    ``loss_func`` must have a signature like ``loss_func(pred, target,
+    **kwargs)`` and return the element-wise loss without any reduction. The
+    decorated function has the signature ``loss_func(pred, target,
+    weight=None, reduction='mean', avg_factor=None, **kwargs)`` and passes
+    the raw loss through :func:`weight_reduce_loss`.
+    :Example:
+    >>> import torch
+    >>> @weighted_loss
+    ... def l1_loss(pred, target):
+    ...     return (pred - target).abs()
+    >>> pred = torch.Tensor([0, 2, 3])
+    >>> target = torch.Tensor([1, 1, 1])
+    >>> weight = torch.Tensor([1, 0, 1])
+    >>> l1_loss(pred, target)
+    tensor(1.3333)
+    >>> l1_loss(pred, target, weight)
+    tensor(1.)
+    >>> l1_loss(pred, target, reduction='none')
+    tensor([1., 1., 2.])
+    >>> l1_loss(pred, target, weight, avg_factor=2)
+    tensor(1.5000)
+    """
+    @functools.wraps(loss_func)
+    def wrapped(pred,
+                target,
+                weight=None,
+                reduction='mean',
+                avg_factor=None,
+                **kwargs):
+        # element-wise loss first, weighting and reduction afterwards
+        raw_loss = loss_func(pred, target, **kwargs)
+        return weight_reduce_loss(raw_loss, weight, reduction, avg_factor)
+    return wrapped
diff --git a/ControlNet/annotator/uniformer/mmseg/models/necks/__init__.py b/ControlNet/annotator/uniformer/mmseg/models/necks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b9d3d5b3fe80247642d962edd6fb787537d01d6
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/necks/__init__.py
@@ -0,0 +1,4 @@
+from .fpn import FPN
+from .multilevel_neck import MultiLevelNeck
+__all__ = ['FPN', 'MultiLevelNeck']
diff --git a/ControlNet/annotator/uniformer/mmseg/models/necks/fpn.py b/ControlNet/annotator/uniformer/mmseg/models/necks/fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..a53b2a69500f8c2edb835abc3ff0ccc2173d1fb1
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/necks/fpn.py
@@ -0,0 +1,212 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from annotator.uniformer.mmcv.cnn import ConvModule, xavier_init
+from ..builder import NECKS
+class FPN(nn.Module):
+ """Feature Pyramid Network.
+ This is an implementation of - Feature Pyramid Networks for Object
+ Detection (https://arxiv.org/abs/1612.03144)
+ Args:
+ in_channels (List[int]): Number of input channels per scale.
+ out_channels (int): Number of output channels (used at each scale)
+ num_outs (int): Number of output scales.
+ start_level (int): Index of the start input backbone level used to
+ build the feature pyramid. Default: 0.
+ end_level (int): Index of the end input backbone level (exclusive) to
+ build the feature pyramid. Default: -1, which means the last level.
+ add_extra_convs (bool | str): If bool, it decides whether to add conv
+ layers on top of the original feature maps. Default to False.
+ If True, its actual mode is specified by `extra_convs_on_inputs`.
+ If str, it specifies the source feature map of the extra convs.
+ Only the following options are allowed
+ - 'on_input': Last feat map of neck inputs (i.e. backbone feature).
+ - 'on_lateral': Last feature map after lateral convs.
+ - 'on_output': The last output feature map after fpn convs.
+ extra_convs_on_inputs (bool, deprecated): Whether to apply extra convs
+ on the original feature from the backbone. If True,
+ it is equivalent to `add_extra_convs='on_input'`. If False, it is
+ equivalent to set `add_extra_convs='on_output'`. Default to False.
+ relu_before_extra_convs (bool): Whether to apply relu before the extra
+ conv. Default: False.
+ no_norm_on_lateral (bool): Whether to apply norm on lateral.
+ Default: False.
+ conv_cfg (dict): Config dict for convolution layer. Default: None.
+ norm_cfg (dict): Config dict for normalization layer. Default: None.
+ act_cfg (dict): Config dict for activation layer in ConvModule.
+ Default: None.
+ upsample_cfg (dict): Config dict for interpolate layer.
+ Default: `dict(mode='nearest')`
+ Example:
+ >>> import torch
+ >>> in_channels = [2, 3, 5, 7]
+ >>> scales = [340, 170, 84, 43]
+ >>> inputs = [torch.rand(1, c, s, s)
+ ... for c, s in zip(in_channels, scales)]
+ >>> self = FPN(in_channels, 11, len(in_channels)).eval()
+ >>> outputs = self.forward(inputs)
+ >>> for i in range(len(outputs)):
+ ... print(f'outputs[{i}].shape = {outputs[i].shape}')
+ outputs[0].shape = torch.Size([1, 11, 340, 340])
+ outputs[1].shape = torch.Size([1, 11, 170, 170])
+ outputs[2].shape = torch.Size([1, 11, 84, 84])
+ outputs[3].shape = torch.Size([1, 11, 43, 43])
+ """
+ def __init__(self,
+ in_channels,
+ out_channels,
+ num_outs,
+ start_level=0,
+ end_level=-1,
+ add_extra_convs=False,
+ extra_convs_on_inputs=False,
+ relu_before_extra_convs=False,
+ no_norm_on_lateral=False,
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=None,
+ upsample_cfg=dict(mode='nearest')):
+ super(FPN, self).__init__()
+ assert isinstance(in_channels, list)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.num_ins = len(in_channels)
+ self.num_outs = num_outs
+ self.relu_before_extra_convs = relu_before_extra_convs
+ self.no_norm_on_lateral = no_norm_on_lateral
+ self.fp16_enabled = False
+ # copy, so the shared default dict in the signature is never aliased
+ self.upsample_cfg = upsample_cfg.copy()
+ # resolve the (exclusive) last backbone level that feeds the pyramid
+ if end_level == -1:
+ self.backbone_end_level = self.num_ins
+ assert num_outs >= self.num_ins - start_level
+ else:
+ # if end_level < inputs, no extra level is allowed
+ self.backbone_end_level = end_level
+ assert end_level <= len(in_channels)
+ assert num_outs == end_level - start_level
+ self.start_level = start_level
+ self.end_level = end_level
+ self.add_extra_convs = add_extra_convs
+ assert isinstance(add_extra_convs, (str, bool))
+ # normalise add_extra_convs: after this block it is either False or
+ # one of the three source names
+ if isinstance(add_extra_convs, str):
+ # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output'
+ assert add_extra_convs in ('on_input', 'on_lateral', 'on_output')
+ elif add_extra_convs: # True
+ if extra_convs_on_inputs:
+ # For compatibility with previous release
+ # TODO: deprecate `extra_convs_on_inputs`
+ self.add_extra_convs = 'on_input'
+ else:
+ self.add_extra_convs = 'on_output'
+ self.lateral_convs = nn.ModuleList()
+ self.fpn_convs = nn.ModuleList()
+ # one 1x1 lateral conv and one 3x3 output conv per used backbone level
+ for i in range(self.start_level, self.backbone_end_level):
+ l_conv = ConvModule(
+ in_channels[i],
+ out_channels,
+ 1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
+ act_cfg=act_cfg,
+ inplace=False)
+ fpn_conv = ConvModule(
+ out_channels,
+ out_channels,
+ 3,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ inplace=False)
+ self.lateral_convs.append(l_conv)
+ self.fpn_convs.append(fpn_conv)
+ # add extra conv layers (e.g., RetinaNet)
+ extra_levels = num_outs - self.backbone_end_level + self.start_level
+ if self.add_extra_convs and extra_levels >= 1:
+ for i in range(extra_levels):
+ # note: the `in_channels` argument is rebound to an int here;
+ # the original list remains available as self.in_channels
+ if i == 0 and self.add_extra_convs == 'on_input':
+ in_channels = self.in_channels[self.backbone_end_level - 1]
+ else:
+ in_channels = out_channels
+ # stride-2 3x3 conv; appended after the per-level output convs
+ extra_fpn_conv = ConvModule(
+ in_channels,
+ out_channels,
+ 3,
+ stride=2,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ inplace=False)
+ self.fpn_convs.append(extra_fpn_conv)
+ # default init_weights for conv(msra) and norm in ConvModule
+ def init_weights(self):
+ # re-initialise every nn.Conv2d in this module with Xavier-uniform
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ xavier_init(m, distribution='uniform')
+ def forward(self, inputs):
+ # Takes one feature map per entry of self.in_channels and returns a
+ # tuple of self.num_outs pyramid levels.
+ assert len(inputs) == len(self.in_channels)
+ # build laterals
+ laterals = [
+ lateral_conv(inputs[i + self.start_level])
+ for i, lateral_conv in enumerate(self.lateral_convs)
+ ]
+ # build top-down path
+ # (coarsest to finest; each lateral is updated in place, so the order
+ # of this loop matters)
+ used_backbone_levels = len(laterals)
+ for i in range(used_backbone_levels - 1, 0, -1):
+ # In some cases, fixing `scale factor` (e.g. 2) is preferred, but
+ # it cannot co-exist with `size` in `F.interpolate`.
+ if 'scale_factor' in self.upsample_cfg:
+ laterals[i - 1] += F.interpolate(laterals[i],
+ **self.upsample_cfg)
+ else:
+ prev_shape = laterals[i - 1].shape[2:]
+ laterals[i - 1] += F.interpolate(
+ laterals[i], size=prev_shape, **self.upsample_cfg)
+ # build outputs
+ # part 1: from original levels
+ outs = [
+ self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
+ ]
+ # part 2: add extra levels
+ if self.num_outs > len(outs):
+ # use max pool to get more levels on top of outputs
+ # (e.g., Faster R-CNN, Mask R-CNN)
+ if not self.add_extra_convs:
+ for i in range(self.num_outs - used_backbone_levels):
+ outs.append(F.max_pool2d(outs[-1], 1, stride=2))
+ # add conv layers on top of original feature maps (RetinaNet)
+ else:
+ if self.add_extra_convs == 'on_input':
+ extra_source = inputs[self.backbone_end_level - 1]
+ elif self.add_extra_convs == 'on_lateral':
+ extra_source = laterals[-1]
+ elif self.add_extra_convs == 'on_output':
+ extra_source = outs[-1]
+ else:
+ raise NotImplementedError
+ # first extra level is fed from the selected source, the remaining
+ # ones from the previous output
+ outs.append(self.fpn_convs[used_backbone_levels](extra_source))
+ for i in range(used_backbone_levels + 1, self.num_outs):
+ if self.relu_before_extra_convs:
+ outs.append(self.fpn_convs[i](F.relu(outs[-1])))
+ else:
+ outs.append(self.fpn_convs[i](outs[-1]))
+ return tuple(outs)
diff --git a/ControlNet/annotator/uniformer/mmseg/models/necks/multilevel_neck.py b/ControlNet/annotator/uniformer/mmseg/models/necks/multilevel_neck.py
new file mode 100644
index 0000000000000000000000000000000000000000..766144d8136326a1fab5906a153a0c0df69b6b60
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/necks/multilevel_neck.py
@@ -0,0 +1,70 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from annotator.uniformer.mmcv.cnn import ConvModule
+from ..builder import NECKS
+class MultiLevelNeck(nn.Module):
+    """MultiLevelNeck.
+    A neck structure connect vit backbone and decoder_heads.
+    Args:
+        in_channels (List[int]): Number of input channels per scale.
+        out_channels (int): Number of output channels (used at each scale).
+        scales (List[int]): Scale factors for each input feature map.
+        norm_cfg (dict): Config dict for normalization layer. Default: None.
+        act_cfg (dict): Config dict for activation layer in ConvModule.
+            Default: None.
+    """
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 scales=[0.5, 1, 2, 4],
+                 norm_cfg=None,
+                 act_cfg=None):
+        super(MultiLevelNeck, self).__init__()
+        assert isinstance(in_channels, list)
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        # copy, so neither the shared default list nor a caller's list is
+        # aliased by this instance
+        self.scales = list(scales)
+        self.num_outs = len(self.scales)
+        self.lateral_convs = nn.ModuleList()
+        self.convs = nn.ModuleList()
+        # one 1x1 lateral conv per input scale
+        for in_channel in in_channels:
+            self.lateral_convs.append(
+                ConvModule(
+                    in_channel,
+                    out_channels,
+                    kernel_size=1,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg))
+        # one 3x3 output conv per output scale
+        for _ in range(self.num_outs):
+            self.convs.append(
+                ConvModule(
+                    out_channels,
+                    out_channels,
+                    kernel_size=3,
+                    padding=1,
+                    stride=1,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg))
+    def forward(self, inputs):
+        """Resize each lateral feature by its scale and refine it.
+        Args:
+            inputs (Sequence[Tensor]): One feature map per entry of
+                ``in_channels``.
+        Returns:
+            tuple[Tensor]: ``num_outs`` feature maps, one per scale.
+        """
+        assert len(inputs) == len(self.in_channels)
+        # the debugging print of the input shape that used to sit here was
+        # removed: it wrote to stdout on every forward pass
+        feats = [
+            lateral_conv(inputs[i])
+            for i, lateral_conv in enumerate(self.lateral_convs)
+        ]
+        # for len(inputs) not equal to self.num_outs
+        if len(feats) == 1:
+            feats = [feats[0] for _ in range(self.num_outs)]
+        outs = []
+        for i in range(self.num_outs):
+            x_resize = F.interpolate(
+                feats[i], scale_factor=self.scales[i], mode='bilinear')
+            outs.append(self.convs[i](x_resize))
+        return tuple(outs)
diff --git a/ControlNet/annotator/uniformer/mmseg/models/segmentors/__init__.py b/ControlNet/annotator/uniformer/mmseg/models/segmentors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dca2f09405330743c476e190896bee39c45498ea
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/segmentors/__init__.py
@@ -0,0 +1,5 @@
+from .base import BaseSegmentor
+from .cascade_encoder_decoder import CascadeEncoderDecoder
+from .encoder_decoder import EncoderDecoder
+__all__ = ['BaseSegmentor', 'EncoderDecoder', 'CascadeEncoderDecoder']
diff --git a/ControlNet/annotator/uniformer/mmseg/models/segmentors/base.py b/ControlNet/annotator/uniformer/mmseg/models/segmentors/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..172fc63b736c4f13be1cd909433bc260760a1eaa
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/segmentors/base.py
@@ -0,0 +1,273 @@
+import logging
+import warnings
+from abc import ABCMeta, abstractmethod
+from collections import OrderedDict
+import annotator.uniformer.mmcv as mmcv
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from annotator.uniformer.mmcv.runner import auto_fp16
+class BaseSegmentor(nn.Module):
+ """Base class for segmentors."""
+ # NOTE(review): `__metaclass__` is the Python 2 spelling; under Python 3
+ # (this file uses f-strings) it is an ordinary class attribute, so the
+ # @abstractmethod markers below are presumably not enforced - confirm.
+ __metaclass__ = ABCMeta
+ def __init__(self):
+ super(BaseSegmentor, self).__init__()
+ self.fp16_enabled = False
+ @property
+ def with_neck(self):
+ """bool: whether the segmentor has neck"""
+ return hasattr(self, 'neck') and self.neck is not None
+ @property
+ def with_auxiliary_head(self):
+ """bool: whether the segmentor has auxiliary head"""
+ return hasattr(self,
+ 'auxiliary_head') and self.auxiliary_head is not None
+ @property
+ def with_decode_head(self):
+ """bool: whether the segmentor has decode head"""
+ return hasattr(self, 'decode_head') and self.decode_head is not None
+ @abstractmethod
+ def extract_feat(self, imgs):
+ """Placeholder for extract features from images."""
+ pass
+ @abstractmethod
+ def encode_decode(self, img, img_metas):
+ """Placeholder for encode images with backbone and decode into a
+ semantic segmentation map of the same size as input."""
+ pass
+ @abstractmethod
+ def forward_train(self, imgs, img_metas, **kwargs):
+ """Placeholder for Forward function for training."""
+ pass
+ @abstractmethod
+ def simple_test(self, img, img_meta, **kwargs):
+ """Placeholder for single image test."""
+ pass
+ @abstractmethod
+ def aug_test(self, imgs, img_metas, **kwargs):
+ """Placeholder for augmentation test."""
+ pass
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in segmentor.
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ # the base implementation only logs the path; it loads nothing itself
+ if pretrained is not None:
+ logger = logging.getLogger()
+ logger.info(f'load model from: {pretrained}')
+ def forward_test(self, imgs, img_metas, **kwargs):
+ """
+ Args:
+ imgs (List[Tensor]): the outer list indicates test-time
+ augmentations and inner Tensor should have a shape NxCxHxW,
+ which contains all images in the batch.
+ img_metas (List[List[dict]]): the outer list indicates test-time
+ augs (multiscale, flip, etc.) and the inner list indicates
+ images in a batch.
+ """
+ for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
+ if not isinstance(var, list):
+ raise TypeError(f'{name} must be a list, but got '
+ f'{type(var)}')
+ num_augs = len(imgs)
+ if num_augs != len(img_metas):
+ raise ValueError(f'num of augmentations ({len(imgs)}) != '
+ f'num of image meta ({len(img_metas)})')
+ # all images in the same aug batch all of the same ori_shape and pad
+ # shape
+ for img_meta in img_metas:
+ ori_shapes = [_['ori_shape'] for _ in img_meta]
+ assert all(shape == ori_shapes[0] for shape in ori_shapes)
+ img_shapes = [_['img_shape'] for _ in img_meta]
+ assert all(shape == img_shapes[0] for shape in img_shapes)
+ pad_shapes = [_['pad_shape'] for _ in img_meta]
+ assert all(shape == pad_shapes[0] for shape in pad_shapes)
+ # a single augmentation goes to simple_test, several to aug_test
+ if num_augs == 1:
+ return self.simple_test(imgs[0], img_metas[0], **kwargs)
+ else:
+ return self.aug_test(imgs, img_metas, **kwargs)
+ @auto_fp16(apply_to=('img', ))
+ def forward(self, img, img_metas, return_loss=True, **kwargs):
+ """Calls either :func:`forward_train` or :func:`forward_test` depending
+ on whether ``return_loss`` is ``True``.
+ Note this setting will change the expected inputs. When
+ ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor
+ and List[dict]), and when ``return_loss=False``, img and img_meta
+ should be double nested (i.e. List[Tensor], List[List[dict]]), with
+ the outer list indicating test time augmentations.
+ """
+ if return_loss:
+ return self.forward_train(img, img_metas, **kwargs)
+ else:
+ return self.forward_test(img, img_metas, **kwargs)
+ def train_step(self, data_batch, optimizer, **kwargs):
+ """The iteration step during training.
+ This method defines an iteration step during training, except for the
+ back propagation and optimizer updating, which are done in an optimizer
+ hook. Note that in some complicated cases or models, the whole process
+ including back propagation and optimizer updating is also defined in
+ this method, such as GAN.
+ Args:
+ data_batch (dict): The output of dataloader.
+ optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
+ runner is passed to ``train_step()``. This argument is unused
+ and reserved.
+ Returns:
+ dict: It should contain at least 3 keys: ``loss``, ``log_vars``,
+ ``num_samples``.
+ ``loss`` is a tensor for back propagation, which can be a
+ weighted sum of multiple losses.
+ ``log_vars`` contains all the variables to be sent to the
+ logger.
+ ``num_samples`` indicates the batch size (when the model is
+ DDP, it means the batch size on each GPU), which is used for
+ averaging the logs.
+ """
+ # data_batch is unpacked as keyword arguments of forward()
+ losses = self(**data_batch)
+ loss, log_vars = self._parse_losses(losses)
+ outputs = dict(
+ loss=loss,
+ log_vars=log_vars,
+ num_samples=len(data_batch['img_metas']))
+ return outputs
+ def val_step(self, data_batch, **kwargs):
+ """The iteration step during validation.
+ This method shares the same signature as :func:`train_step`, but used
+ during val epochs. Note that the evaluation after training epochs is
+ not implemented with this method, but an evaluation hook.
+ """
+ output = self(**data_batch, **kwargs)
+ return output
+ @staticmethod
+ def _parse_losses(losses):
+ """Parse the raw outputs (losses) of the network.
+ Args:
+ losses (dict): Raw output of the network, which usually contain
+ losses and other necessary information.
+ Returns:
+ tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor
+ which may be a weighted sum of all losses, log_vars contains
+ all the variables to be sent to the logger.
+ """
+ log_vars = OrderedDict()
+ # reduce every entry to a scalar: tensors by their mean, lists of
+ # tensors by the sum of the element means
+ for loss_name, loss_value in losses.items():
+ if isinstance(loss_value, torch.Tensor):
+ log_vars[loss_name] = loss_value.mean()
+ elif isinstance(loss_value, list):
+ log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
+ else:
+ raise TypeError(
+ f'{loss_name} is not a tensor or list of tensors')
+ # only entries whose key contains 'loss' contribute to the total
+ loss = sum(_value for _key, _value in log_vars.items()
+ if 'loss' in _key)
+ log_vars['loss'] = loss
+ for loss_name, loss_value in log_vars.items():
+ # reduce loss when distributed training
+ if dist.is_available() and dist.is_initialized():
+ loss_value = loss_value.data.clone()
+ dist.all_reduce(loss_value.div_(dist.get_world_size()))
+ # logged values are plain Python numbers; `loss` stays a tensor
+ log_vars[loss_name] = loss_value.item()
+ return loss, log_vars
+ def show_result(self,
+ img,
+ result,
+ palette=None,
+ win_name='',
+ show=False,
+ wait_time=0,
+ out_file=None,
+ opacity=0.5):
+ """Draw `result` over `img`.
+ Args:
+ img (str or Tensor): The image to be displayed.
+ result (Tensor): The semantic segmentation results to draw over
+ `img`.
+ palette (list[list[int]]] | np.ndarray | None): The palette of
+ segmentation map. If None is given, random palette will be
+ generated. Default: None
+ win_name (str): The window name.
+ wait_time (int): Value of waitKey param.
+ Default: 0.
+ show (bool): Whether to show the image.
+ Default: False.
+ out_file (str or None): The filename to write the image.
+ Default: None.
+ opacity(float): Opacity of painted segmentation map.
+ Default 0.5.
+ Must be in (0, 1] range.
+ Returns:
+ img (Tensor): Only if not `show` or `out_file`
+ """
+ img = mmcv.imread(img)
+ img = img.copy()
+ # only the first entry of `result` is drawn
+ seg = result[0]
+ # reads self.PALETTE and self.CLASSES, which this class does not define
+ # NOTE(review): presumably set by subclasses or at load time - confirm
+ if palette is None:
+ if self.PALETTE is None:
+ palette = np.random.randint(
+ 0, 255, size=(len(self.CLASSES), 3))
+ else:
+ palette = self.PALETTE
+ palette = np.array(palette)
+ assert palette.shape[0] == len(self.CLASSES)
+ assert palette.shape[1] == 3
+ assert len(palette.shape) == 2
+ assert 0 < opacity <= 1.0
+ color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
+ for label, color in enumerate(palette):
+ color_seg[seg == label, :] = color
+ # convert to BGR
+ color_seg = color_seg[..., ::-1]
+ # alpha-blend the coloured map over the image
+ img = img * (1 - opacity) + color_seg * opacity
+ img = img.astype(np.uint8)
+ # if out_file specified, do not show image in window
+ if out_file is not None:
+ show = False
+ if show:
+ mmcv.imshow(img, win_name, wait_time)
+ if out_file is not None:
+ mmcv.imwrite(img, out_file)
+ # the image is returned only when it was neither shown nor written
+ if not (show or out_file):
+ warnings.warn('show==False and out_file is not specified, only '
+ 'result image will be returned')
+ return img
diff --git a/ControlNet/annotator/uniformer/mmseg/models/segmentors/cascade_encoder_decoder.py b/ControlNet/annotator/uniformer/mmseg/models/segmentors/cascade_encoder_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..873957d8d6468147c994493d92ff5c1b15bfb703
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/segmentors/cascade_encoder_decoder.py
@@ -0,0 +1,98 @@
+from torch import nn
+from annotator.uniformer.mmseg.core import add_prefix
+from annotator.uniformer.mmseg.ops import resize
+from .. import builder
+from ..builder import SEGMENTORS
+from .encoder_decoder import EncoderDecoder
+class CascadeEncoderDecoder(EncoderDecoder):
+    """Encoder-decoder segmentor with a cascade of decode heads.
+    Works like EncoderDecoder, except that ``decode_head`` is a list of
+    ``num_stages`` head configs. The first head sees only the features;
+    every later head additionally receives the output of the head before it.
+    """
+    def __init__(self,
+                 num_stages,
+                 backbone,
+                 decode_head,
+                 neck=None,
+                 auxiliary_head=None,
+                 train_cfg=None,
+                 test_cfg=None,
+                 pretrained=None):
+        # has to be set before the parent constructor runs: that constructor
+        # calls _init_decode_head(), which reads self.num_stages
+        self.num_stages = num_stages
+        super(CascadeEncoderDecoder, self).__init__(
+            backbone=backbone,
+            decode_head=decode_head,
+            neck=neck,
+            auxiliary_head=auxiliary_head,
+            train_cfg=train_cfg,
+            test_cfg=test_cfg,
+            pretrained=pretrained)
+    def _init_decode_head(self, decode_head):
+        """Build one decode head per stage from the list of configs."""
+        assert isinstance(decode_head, list)
+        assert len(decode_head) == self.num_stages
+        self.decode_head = nn.ModuleList()
+        for stage in range(self.num_stages):
+            stage_head = builder.build_head(decode_head[stage])
+            self.decode_head.append(stage_head)
+        # the last stage defines the segmentor-level settings
+        final_head = self.decode_head[-1]
+        self.align_corners = final_head.align_corners
+        self.num_classes = final_head.num_classes
+    def init_weights(self, pretrained=None):
+        """Initialize the weights in backbone and heads.
+        Args:
+            pretrained (str, optional): Path to pre-trained weights.
+                Defaults to None.
+        """
+        self.backbone.init_weights(pretrained=pretrained)
+        for stage in range(self.num_stages):
+            self.decode_head[stage].init_weights()
+        if not self.with_auxiliary_head:
+            return
+        # treat a single auxiliary head like a one-element collection
+        aux_heads = self.auxiliary_head
+        if not isinstance(aux_heads, nn.ModuleList):
+            aux_heads = [aux_heads]
+        for aux_head in aux_heads:
+            aux_head.init_weights()
+    def encode_decode(self, img, img_metas):
+        """Run backbone and all decode stages, then resize the result to the
+        spatial size of ``img``."""
+        feats = self.extract_feat(img)
+        seg_logits = self.decode_head[0].forward_test(feats, img_metas,
+                                                      self.test_cfg)
+        for stage in range(1, self.num_stages):
+            # later stages refine the previous stage's output
+            seg_logits = self.decode_head[stage].forward_test(
+                feats, seg_logits, img_metas, self.test_cfg)
+        return resize(
+            input=seg_logits,
+            size=img.shape[2:],
+            mode='bilinear',
+            align_corners=self.align_corners)
+    def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg):
+        """Compute the training losses of every decode stage.
+        Loss keys are prefixed with ``decode_<stage index>``.
+        """
+        losses = dict()
+        first_loss = self.decode_head[0].forward_train(
+            x, img_metas, gt_semantic_seg, self.train_cfg)
+        losses.update(add_prefix(first_loss, 'decode_0'))
+        for stage in range(1, self.num_stages):
+            # forward test again, maybe unnecessary for most methods.
+            prev_outputs = self.decode_head[stage - 1].forward_test(
+                x, img_metas, self.test_cfg)
+            stage_loss = self.decode_head[stage].forward_train(
+                x, prev_outputs, img_metas, gt_semantic_seg, self.train_cfg)
+            losses.update(add_prefix(stage_loss, f'decode_{stage}'))
+        return losses
diff --git a/ControlNet/annotator/uniformer/mmseg/models/segmentors/encoder_decoder.py b/ControlNet/annotator/uniformer/mmseg/models/segmentors/encoder_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..98392ac04c4c44a7f4e7b1c0808266875877dd1f
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/segmentors/encoder_decoder.py
@@ -0,0 +1,298 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from annotator.uniformer.mmseg.core import add_prefix
+from annotator.uniformer.mmseg.ops import resize
+from .. import builder
+from ..builder import SEGMENTORS
+from .base import BaseSegmentor
+class EncoderDecoder(BaseSegmentor):
+ """Encoder Decoder segmentors.
+ EncoderDecoder typically consists of backbone, decode_head, auxiliary_head.
+ Note that auxiliary_head is only used for deep supervision during training,
+ which could be dumped during inference.
+ """
+ def __init__(self,
+              backbone,
+              decode_head,
+              neck=None,
+              auxiliary_head=None,
+              train_cfg=None,
+              test_cfg=None,
+              pretrained=None):
+     """Build backbone, optional neck and heads, then initialise weights.
+     Args:
+         backbone: Config passed to ``builder.build_backbone``.
+         decode_head: Config passed to ``_init_decode_head``.
+         neck: Optional config passed to ``builder.build_neck``.
+         auxiliary_head: Optional config (or list of configs) passed to
+             ``_init_auxiliary_head``.
+         train_cfg: Stored as ``self.train_cfg``.
+         test_cfg: Stored as ``self.test_cfg``.
+         pretrained (str, optional): Forwarded to ``init_weights``.
+     """
+     super().__init__()
+     # sub-modules are created in this fixed order
+     self.backbone = builder.build_backbone(backbone)
+     if neck is not None:
+         self.neck = builder.build_neck(neck)
+     self._init_decode_head(decode_head)
+     self._init_auxiliary_head(auxiliary_head)
+     self.train_cfg, self.test_cfg = train_cfg, test_cfg
+     self.init_weights(pretrained=pretrained)
+     assert self.with_decode_head
+ def _init_decode_head(self, decode_head):
+     """Build the decode head and mirror its ``align_corners`` and
+     ``num_classes`` on the segmentor."""
+     head = builder.build_head(decode_head)
+     self.decode_head = head
+     self.align_corners = head.align_corners
+     self.num_classes = head.num_classes
+ def _init_auxiliary_head(self, auxiliary_head):
+     """Build the auxiliary head(s); a ``None`` config leaves the attribute
+     unset, a list yields an ``nn.ModuleList``."""
+     if auxiliary_head is None:
+         return
+     if not isinstance(auxiliary_head, list):
+         self.auxiliary_head = builder.build_head(auxiliary_head)
+         return
+     self.auxiliary_head = nn.ModuleList()
+     for head_cfg in auxiliary_head:
+         self.auxiliary_head.append(builder.build_head(head_cfg))
+ def init_weights(self, pretrained=None):
+     """Initialize the weights in backbone and heads.
+     Args:
+         pretrained (str, optional): Path to pre-trained weights.
+             Defaults to None.
+     """
+     super().init_weights(pretrained)
+     self.backbone.init_weights(pretrained=pretrained)
+     self.decode_head.init_weights()
+     if not self.with_auxiliary_head:
+         return
+     # treat a single auxiliary head like a one-element collection
+     aux_heads = self.auxiliary_head
+     if not isinstance(aux_heads, nn.ModuleList):
+         aux_heads = [aux_heads]
+     for aux_head in aux_heads:
+         aux_head.init_weights()
+ def extract_feat(self, img):
+     """Run the backbone, and the neck when one is present, on ``img``."""
+     feats = self.backbone(img)
+     return self.neck(feats) if self.with_neck else feats
+ def encode_decode(self, img, img_metas):
+     """Extract features, decode them, and bilinearly resize the logits to
+     the spatial size of ``img``."""
+     feats = self.extract_feat(img)
+     seg_logits = self._decode_head_forward_test(feats, img_metas)
+     return resize(
+         input=seg_logits,
+         size=img.shape[2:],
+         mode='bilinear',
+         align_corners=self.align_corners)
+ def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg):
+     """Compute the decode head's training losses; keys are prefixed with
+     'decode'."""
+     head_losses = self.decode_head.forward_train(
+         x, img_metas, gt_semantic_seg, self.train_cfg)
+     losses = {}
+     losses.update(add_prefix(head_losses, 'decode'))
+     return losses
+ def _decode_head_forward_test(self, x, img_metas):
+     """Run the decode head in test mode and return its output."""
+     return self.decode_head.forward_test(x, img_metas, self.test_cfg)
+ def _auxiliary_head_forward_train(self, x, img_metas, gt_semantic_seg):
+     """Compute the auxiliary heads' training losses.
+     Keys are prefixed with ``aux_<index>`` for a list of heads and with
+     ``aux`` for a single head.
+     """
+     losses = {}
+     aux = self.auxiliary_head
+     if not isinstance(aux, nn.ModuleList):
+         single_loss = aux.forward_train(x, img_metas, gt_semantic_seg,
+                                         self.train_cfg)
+         losses.update(add_prefix(single_loss, 'aux'))
+         return losses
+     for idx, aux_head in enumerate(aux):
+         head_loss = aux_head.forward_train(x, img_metas, gt_semantic_seg,
+                                            self.train_cfg)
+         losses.update(add_prefix(head_loss, f'aux_{idx}'))
+     return losses
+ def forward_dummy(self, img):
+     """Run ``encode_decode`` on ``img`` without image meta information."""
+     return self.encode_decode(img, None)
+ def forward_train(self, img, img_metas, gt_semantic_seg):
+     """Forward function for training.
+     Args:
+         img (Tensor): Input images.
+         img_metas (list[dict]): List of image info dict where each dict
+             has: 'img_shape', 'scale_factor', 'flip', and may also contain
+             'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+             For details on the values of these keys see
+             `mmseg/datasets/pipelines/formatting.py:Collect`.
+         gt_semantic_seg (Tensor): Semantic segmentation masks
+             used if the architecture supports semantic segmentation task.
+     Returns:
+         dict[str, Tensor]: a dictionary of loss components
+     """
+     feats = self.extract_feat(img)
+     losses = dict()
+     # decode-head losses always, auxiliary losses only when configured
+     losses.update(
+         self._decode_head_forward_train(feats, img_metas, gt_semantic_seg))
+     if self.with_auxiliary_head:
+         losses.update(
+             self._auxiliary_head_forward_train(feats, img_metas,
+                                                gt_semantic_seg))
+     return losses
+ # TODO refactor
+ def slide_inference(self, img, img_meta, rescale):
+ """Inference by sliding-window with overlap.
+ If h_crop > h_img or w_crop > w_img, the small patch will be used to
+ decode without padding.
+ Window size and step come from ``self.test_cfg.crop_size`` and
+ ``self.test_cfg.stride``. Logits of overlapping windows are averaged.
+ When ``rescale`` is true the result is resized to
+ ``img_meta[0]['ori_shape'][:2]``.
+ """
+ h_stride, w_stride = self.test_cfg.stride
+ h_crop, w_crop = self.test_cfg.crop_size
+ batch_size, _, h_img, w_img = img.size()
+ num_classes = self.num_classes
+ # number of window positions per axis: ceil division, at least one
+ h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
+ w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
+ # preds accumulates logits, count_mat how many windows hit each pixel
+ preds = img.new_zeros((batch_size, num_classes, h_img, w_img))
+ count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
+ for h_idx in range(h_grids):
+ for w_idx in range(w_grids):
+ y1 = h_idx * h_stride
+ x1 = w_idx * w_stride
+ # clamp the far edge to the image, then move the near edge back so the
+ # window keeps the crop size where the image is large enough
+ y2 = min(y1 + h_crop, h_img)
+ x2 = min(x1 + w_crop, w_img)
+ y1 = max(y2 - h_crop, 0)
+ x1 = max(x2 - w_crop, 0)
+ crop_img = img[:, :, y1:y2, x1:x2]
+ crop_seg_logit = self.encode_decode(crop_img, img_meta)
+ # zero-pad the crop logits back to full image size before adding
+ preds += F.pad(crop_seg_logit,
+ (int(x1), int(preds.shape[3] - x2), int(y1),
+ int(preds.shape[2] - y2)))
+ count_mat[:, :, y1:y2, x1:x2] += 1
+ # every pixel must have been covered by at least one window
+ assert (count_mat == 0).sum() == 0
+ if torch.onnx.is_in_onnx_export():
+ # cast count_mat to constant while exporting to ONNX
+ count_mat = torch.from_numpy(
+ count_mat.cpu().detach().numpy()).to(device=img.device)
+ # average the accumulated logits over the overlapping windows
+ preds = preds / count_mat
+ if rescale:
+ preds = resize(
+ preds,
+ size=img_meta[0]['ori_shape'][:2],
+ mode='bilinear',
+ align_corners=self.align_corners,
+ warning=False)
+ return preds
+ def whole_inference(self, img, img_meta, rescale):
+     """Run ``encode_decode`` on the full image in a single pass.
+
+     When ``rescale`` is truthy the logits are bilinearly resized: to the
+     spatial size of ``img`` during ONNX export, otherwise to the first two
+     entries of ``img_meta[0]['ori_shape']``.
+     """
+     seg_logit = self.encode_decode(img, img_meta)
+     if not rescale:
+         return seg_logit
+     # support dynamic shape for onnx
+     if torch.onnx.is_in_onnx_export():
+         target_size = img.shape[2:]
+     else:
+         target_size = img_meta[0]['ori_shape'][:2]
+     return resize(
+         seg_logit,
+         size=target_size,
+         mode='bilinear',
+         align_corners=self.align_corners,
+         warning=False)
+ def inference(self, img, img_meta, rescale):
+     """Inference with slide/whole style.
+
+     Args:
+         img (Tensor): The input image of shape (N, 3, H, W).
+         img_meta (list[dict]): Image info dicts, one per image; each has
+             'img_shape', 'scale_factor', 'flip', and may also contain
+             'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+             All entries must share the same 'ori_shape'; flip settings
+             are taken from the first entry.
+             For details on the values of these keys see
+             `mmseg/datasets/pipelines/formatting.py:Collect`.
+         rescale (bool): Whether rescale back to original shape.
+
+     Returns:
+         Tensor: Softmax (over dim 1) of the segmentation logits, flipped
+             back if the input had been flipped.
+     """
+     mode = self.test_cfg.mode
+     assert mode in ['slide', 'whole']
+     reference_shape = img_meta[0]['ori_shape']
+     assert all(meta['ori_shape'] == reference_shape for meta in img_meta)
+     if mode == 'slide':
+         seg_logit = self.slide_inference(img, img_meta, rescale)
+     else:
+         seg_logit = self.whole_inference(img, img_meta, rescale)
+     output = F.softmax(seg_logit, dim=1)
+     first_meta = img_meta[0]
+     if not first_meta['flip']:
+         return output
+     # Undo the test-time flip so predictions align with the original image.
+     flip_direction = first_meta['flip_direction']
+     assert flip_direction in ['horizontal', 'vertical']
+     if flip_direction == 'horizontal':
+         output = output.flip(dims=(3, ))
+     elif flip_direction == 'vertical':
+         output = output.flip(dims=(2, ))
+     return output
+ def simple_test(self, img, img_meta, rescale=True):
+     """Predict per-pixel class indices for one batch without augmentation.
+
+     Returns a list with one numpy array per batch element, or, during ONNX
+     export, the argmax tensor with an extra leading dimension.
+     """
+     seg_pred = self.inference(img, img_meta, rescale).argmax(dim=1)
+     if not torch.onnx.is_in_onnx_export():
+         # unravel batch dim
+         return list(seg_pred.cpu().numpy())
+     # our inference backend only support 4D output
+     return seg_pred.unsqueeze(0)
+ def aug_test(self, imgs, img_metas, rescale=True):
+     """Average predictions over several augmented views of the input.
+
+     Only rescale=True is supported.
+
+     Returns a list with one numpy array of class indices per batch
+     element.
+     """
+     # aug_test rescale all imgs back to ori_shape for now
+     assert rescale
+     num_views = len(imgs)
+     # to save memory, the first view's output is used as the accumulator
+     # and later views are added to it in place
+     seg_logit = self.inference(imgs[0], img_metas[0], rescale)
+     for view_idx in range(1, num_views):
+         seg_logit += self.inference(imgs[view_idx], img_metas[view_idx],
+                                     rescale)
+     seg_logit /= num_views
+     # unravel batch dim
+     return list(seg_logit.argmax(dim=1).cpu().numpy())
diff --git a/ControlNet/annotator/uniformer/mmseg/models/utils/__init__.py b/ControlNet/annotator/uniformer/mmseg/models/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d3bdd349b9f2ae499a2fcb2ac1d2e3c77befebe
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/utils/__init__.py
@@ -0,0 +1,13 @@
+from .drop import DropPath
+from .inverted_residual import InvertedResidual, InvertedResidualV3
+from .make_divisible import make_divisible
+from .res_layer import ResLayer
+from .se_layer import SELayer
+from .self_attention_block import SelfAttentionBlock
+from .up_conv_block import UpConvBlock
+from .weight_init import trunc_normal_
+# Public names re-exported by this package.
+__all__ = [
+    'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual',
+    'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'DropPath', 'trunc_normal_'
+]
diff --git a/ControlNet/annotator/uniformer/mmseg/models/utils/drop.py b/ControlNet/annotator/uniformer/mmseg/models/utils/drop.py
new file mode 100644
index 0000000000000000000000000000000000000000..4520b0ff407d2a95a864086bdbca0065f222aa63
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/utils/drop.py
@@ -0,0 +1,31 @@
+"""Modified from https://github.com/rwightman/pytorch-image-models."""
+import torch
+from torch import nn
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of
+    residual blocks).
+
+    During training each sample in the batch is either kept (and rescaled by
+    ``1 / keep_prob``) or zeroed as a whole. In eval mode, or when
+    ``drop_prob`` is 0, the input is returned unchanged.
+
+    Args:
+        drop_prob (float): Drop rate for paths of model. Dropout rate has
+            to be between 0 and 1. Default: 0.
+    """
+
+    def __init__(self, drop_prob=0.):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+        self.keep_prob = 1 - drop_prob
+
+    def forward(self, x):
+        """Apply stochastic depth to ``x`` (first dimension is the batch)."""
+        if self.drop_prob == 0. or not self.training:
+            return x
+        # One random value per sample, broadcast over all remaining dims
+        # (work with diff dim tensors, not just 2D ConvNets).
+        shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
+        random_tensor = self.keep_prob + torch.rand(
+            shape, dtype=x.dtype, device=x.device)
+        random_tensor.floor_()  # binarize
+        if self.keep_prob > 0.:
+            return x.div(self.keep_prob) * random_tensor
+        # keep_prob == 0 (drop_prob == 1): every path is dropped. Skip the
+        # rescale, which would divide by zero and produce NaN (inf * 0).
+        return x * random_tensor
diff --git a/ControlNet/annotator/uniformer/mmseg/models/utils/inverted_residual.py b/ControlNet/annotator/uniformer/mmseg/models/utils/inverted_residual.py
new file mode 100644
index 0000000000000000000000000000000000000000..53b8fcd41f71d814738f1ac3f5acd3c3d701bf96
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/utils/inverted_residual.py
@@ -0,0 +1,208 @@
+from annotator.uniformer.mmcv.cnn import ConvModule
+from torch import nn
+from torch.utils import checkpoint as cp
+from .se_layer import SELayer
+class InvertedResidual(nn.Module):
+ """InvertedResidual block for MobileNetV2.
+ The block is an optional 1x1 expansion conv, a 3x3 depthwise conv and a
+ 1x1 projection conv without activation, stacked in ``self.conv``.
+ Args:
+ in_channels (int): The input channels of the InvertedResidual block.
+ out_channels (int): The output channels of the InvertedResidual block.
+ stride (int): Stride of the middle (first) 3x3 convolution.
+ Must be 1 or 2.
+ expand_ratio (int): Adjusts number of channels of the hidden layer
+ in InvertedResidual by this amount.
+ dilation (int): Dilation rate of depthwise conv. Default: 1
+ conv_cfg (dict): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='ReLU6').
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ Returns:
+ Tensor: The output tensor.
+ """
+ def __init__(self,
+ in_channels,
+ out_channels,
+ stride,
+ expand_ratio,
+ dilation=1,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU6'),
+ with_cp=False):
+ super(InvertedResidual, self).__init__()
+ self.stride = stride
+ assert stride in [1, 2], f'stride must in [1, 2]. ' \
+ f'But received {stride}.'
+ self.with_cp = with_cp
+ # The residual shortcut is used only when stride is 1 and the channel
+ # counts match.
+ self.use_res_connect = self.stride == 1 and in_channels == out_channels
+ hidden_dim = int(round(in_channels * expand_ratio))
+ layers = []
+ # 1x1 expansion conv; skipped when expand_ratio is 1.
+ if expand_ratio != 1:
+ layers.append(
+ ConvModule(
+ in_channels=in_channels,
+ out_channels=hidden_dim,
+ kernel_size=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+ # 3x3 depthwise conv (groups == hidden_dim) followed by a 1x1 projection
+ # conv that has no activation.
+ layers.extend([
+ ConvModule(
+ in_channels=hidden_dim,
+ out_channels=hidden_dim,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ dilation=dilation,
+ groups=hidden_dim,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg),
+ ConvModule(
+ in_channels=hidden_dim,
+ out_channels=out_channels,
+ kernel_size=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=None)
+ ])
+ self.conv = nn.Sequential(*layers)
+ def forward(self, x):
+ """Apply the block, adding the input back when the shortcut is enabled."""
+ def _inner_forward(x):
+ if self.use_res_connect:
+ return x + self.conv(x)
+ else:
+ return self.conv(x)
+ # Checkpointing is used only when enabled and the input requires grad.
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+ return out
+class InvertedResidualV3(nn.Module):
+ """Inverted Residual Block for MobileNetV3.
+ The block is an optional 1x1 expand conv, a depthwise conv, an optional
+ SE layer and a 1x1 linear conv without activation.
+ Args:
+ in_channels (int): The input channels of this Module.
+ out_channels (int): The output channels of this Module.
+ mid_channels (int): The input channels of the depthwise convolution.
+ kernel_size (int): The kernel size of the depthwise convolution.
+ Default: 3.
+ stride (int): The stride of the depthwise convolution. Must be 1 or 2.
+ Default: 1.
+ se_cfg (dict): Config dict for se layer. Default: None, which means no
+ se layer.
+ with_expand_conv (bool): Use expand conv or not. If set False,
+ mid_channels must be the same with in_channels. Default: True.
+ conv_cfg (dict): Config dict for convolution layer. Default: None,
+ which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='ReLU').
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ Returns:
+ Tensor: The output tensor.
+ """
+ def __init__(self,
+ in_channels,
+ out_channels,
+ mid_channels,
+ kernel_size=3,
+ stride=1,
+ se_cfg=None,
+ with_expand_conv=True,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ with_cp=False):
+ super(InvertedResidualV3, self).__init__()
+ # The residual shortcut is used only when stride is 1 and the channel
+ # counts match.
+ self.with_res_shortcut = (stride == 1 and in_channels == out_channels)
+ assert stride in [1, 2]
+ self.with_cp = with_cp
+ self.with_se = se_cfg is not None
+ self.with_expand_conv = with_expand_conv
+ if self.with_se:
+ assert isinstance(se_cfg, dict)
+ # Without the expand conv the depthwise conv consumes the input directly.
+ if not self.with_expand_conv:
+ assert mid_channels == in_channels
+ if self.with_expand_conv:
+ self.expand_conv = ConvModule(
+ in_channels=in_channels,
+ out_channels=mid_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ # Depthwise conv (groups == mid_channels). For stride 2 the conv type is
+ # forced to 'Conv2dAdaptivePadding' regardless of conv_cfg.
+ self.depthwise_conv = ConvModule(
+ in_channels=mid_channels,
+ out_channels=mid_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=kernel_size // 2,
+ groups=mid_channels,
+ conv_cfg=dict(
+ type='Conv2dAdaptivePadding') if stride == 2 else conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ if self.with_se:
+ self.se = SELayer(**se_cfg)
+ # 1x1 projection conv without activation.
+ self.linear_conv = ConvModule(
+ in_channels=mid_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=None)
+ def forward(self, x):
+ """Apply expand -> depthwise -> (SE) -> linear, plus optional shortcut."""
+ def _inner_forward(x):
+ out = x
+ if self.with_expand_conv:
+ out = self.expand_conv(out)
+ out = self.depthwise_conv(out)
+ if self.with_se:
+ out = self.se(out)
+ out = self.linear_conv(out)
+ if self.with_res_shortcut:
+ return x + out
+ else:
+ return out
+ # Checkpointing is used only when enabled and the input requires grad.
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+ return out
diff --git a/ControlNet/annotator/uniformer/mmseg/models/utils/make_divisible.py b/ControlNet/annotator/uniformer/mmseg/models/utils/make_divisible.py
new file mode 100644
index 0000000000000000000000000000000000000000..75ad756052529f52fe83bb95dd1f0ecfc9a13078
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/utils/make_divisible.py
@@ -0,0 +1,27 @@
+def make_divisible(value, divisor, min_value=None, min_ratio=0.9):
+    """Round a channel count to a multiple of ``divisor``.
+
+    The value is rounded to the nearest multiple of ``divisor`` (never below
+    ``min_value``); if the result falls under ``min_ratio * value`` one more
+    ``divisor`` is added. Taken from the original tf repo, see
+    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py  # noqa
+
+    Args:
+        value (int): The original channel number.
+        divisor (int): The divisor to fully divide the channel number.
+        min_value (int): The minimum value of the output channel.
+            Default: None, means that the minimum value equal to the divisor.
+        min_ratio (float): The minimum ratio of the rounded channel number to
+            the original channel number. Default: 0.9.
+
+    Returns:
+        int: The modified output channel number.
+    """
+    lower_bound = divisor if min_value is None else min_value
+    nearest_multiple = int(value + divisor / 2) // divisor * divisor
+    rounded = max(lower_bound, nearest_multiple)
+    # Make sure that round down does not go down by more than (1-min_ratio).
+    return rounded + divisor if rounded < min_ratio * value else rounded
diff --git a/ControlNet/annotator/uniformer/mmseg/models/utils/res_layer.py b/ControlNet/annotator/uniformer/mmseg/models/utils/res_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2c07b47007e92e4c3945b989e79f9d50306f5fe
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/utils/res_layer.py
@@ -0,0 +1,94 @@
+from annotator.uniformer.mmcv.cnn import build_conv_layer, build_norm_layer
+from torch import nn as nn
+class ResLayer(nn.Sequential):
+ """ResLayer to build ResNet style backbone.
+ Args:
+ block (nn.Module): block used to build ResLayer. Its ``expansion``
+ attribute is read to compute the output channels.
+ inplanes (int): inplanes of block.
+ planes (int): planes of block.
+ num_blocks (int): number of blocks.
+ stride (int): stride of the first block. Default: 1
+ dilation (int): dilation passed to the blocks when ``multi_grid`` is
+ None. Default: 1
+ avg_down (bool): Use AvgPool instead of stride conv when
+ downsampling in the bottleneck. Default: False
+ conv_cfg (dict): dictionary to construct and config conv layer.
+ Default: None
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN')
+ multi_grid (Sequence[int] | None): Multi grid dilation rates of last
+ stage, indexed by block position. Default: None
+ contract_dilation (bool): Whether contract first dilation of each layer
+ Default: False
+ **kwargs: Extra keyword arguments forwarded to every ``block``.
+ """
+ def __init__(self,
+ block,
+ inplanes,
+ planes,
+ num_blocks,
+ stride=1,
+ dilation=1,
+ avg_down=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ multi_grid=None,
+ contract_dilation=False,
+ **kwargs):
+ self.block = block
+ # A projection shortcut is built when the first block changes the
+ # stride or the channel count.
+ downsample = None
+ if stride != 1 or inplanes != planes * block.expansion:
+ downsample = []
+ conv_stride = stride
+ # NOTE(review): indentation was lost in this copy; presumably only the
+ # conv_stride reset and the AvgPool2d append belong to the avg_down
+ # branch - confirm against the original file.
+ if avg_down:
+ conv_stride = 1
+ downsample.append(
+ nn.AvgPool2d(
+ kernel_size=stride,
+ stride=stride,
+ ceil_mode=True,
+ count_include_pad=False))
+ downsample.extend([
+ build_conv_layer(
+ conv_cfg,
+ inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=conv_stride,
+ bias=False),
+ build_norm_layer(norm_cfg, planes * block.expansion)[1]
+ ])
+ downsample = nn.Sequential(*downsample)
+ layers = []
+ # Dilation of the first block: halved when contract_dilation is set and
+ # dilation > 1, or taken from multi_grid[0] when multi_grid is given.
+ if multi_grid is None:
+ if dilation > 1 and contract_dilation:
+ first_dilation = dilation // 2
+ else:
+ first_dilation = dilation
+ else:
+ first_dilation = multi_grid[0]
+ # Only the first block receives the stride and the downsample module.
+ layers.append(
+ block(
+ inplanes=inplanes,
+ planes=planes,
+ stride=stride,
+ dilation=first_dilation,
+ downsample=downsample,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ **kwargs))
+ inplanes = planes * block.expansion
+ # Remaining blocks use stride 1 and the expanded channel count as input.
+ for i in range(1, num_blocks):
+ layers.append(
+ block(
+ inplanes=inplanes,
+ planes=planes,
+ stride=1,
+ dilation=dilation if multi_grid is None else multi_grid[i],
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ **kwargs))
+ super(ResLayer, self).__init__(*layers)
diff --git a/ControlNet/annotator/uniformer/mmseg/models/utils/se_layer.py b/ControlNet/annotator/uniformer/mmseg/models/utils/se_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..083bd7d1ccee909c900c7aed2cc928bf14727f3e
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/utils/se_layer.py
@@ -0,0 +1,57 @@
+import annotator.uniformer.mmcv as mmcv
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import ConvModule
+from .make_divisible import make_divisible
+class SELayer(nn.Module):
+ """Squeeze-and-Excitation Module.
+ Args:
+ channels (int): The input (and output) channels of the SE layer.
+ ratio (int): Squeeze ratio in SELayer, the intermediate channel will be
+ ``make_divisible(channels // ratio, 8)``. Default: 16.
+ conv_cfg (None or dict): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ act_cfg (dict or Sequence[dict]): Config dict for activation layer.
+ If act_cfg is a dict, two activation layers will be configured
+ by this dict. If act_cfg is a sequence of dicts, the first
+ activation layer will be configured by the first dict and the
+ second activation layer will be configured by the second dict.
+ Default: (dict(type='ReLU'), dict(type='HSigmoid', bias=3.0,
+ divisor=6.0)).
+ """
+ def __init__(self,
+ channels,
+ ratio=16,
+ conv_cfg=None,
+ act_cfg=(dict(type='ReLU'),
+ dict(type='HSigmoid', bias=3.0, divisor=6.0))):
+ super(SELayer, self).__init__()
+ # A single dict configures both activation layers.
+ if isinstance(act_cfg, dict):
+ act_cfg = (act_cfg, act_cfg)
+ assert len(act_cfg) == 2
+ assert mmcv.is_tuple_of(act_cfg, dict)
+ # Squeeze: pool every channel down to a single value.
+ self.global_avgpool = nn.AdaptiveAvgPool2d(1)
+ # Excitation: 1x1 conv to the reduced width, then 1x1 conv back.
+ self.conv1 = ConvModule(
+ in_channels=channels,
+ out_channels=make_divisible(channels // ratio, 8),
+ kernel_size=1,
+ stride=1,
+ conv_cfg=conv_cfg,
+ act_cfg=act_cfg[0])
+ self.conv2 = ConvModule(
+ in_channels=make_divisible(channels // ratio, 8),
+ out_channels=channels,
+ kernel_size=1,
+ stride=1,
+ conv_cfg=conv_cfg,
+ act_cfg=act_cfg[1])
+ def forward(self, x):
+ """Scale ``x`` by the per-channel weights computed from its pooled features."""
+ out = self.global_avgpool(x)
+ out = self.conv1(out)
+ out = self.conv2(out)
+ return x * out
diff --git a/ControlNet/annotator/uniformer/mmseg/models/utils/self_attention_block.py b/ControlNet/annotator/uniformer/mmseg/models/utils/self_attention_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..440c7b73ee4706fde555595926d63a18d7574acc
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/utils/self_attention_block.py
@@ -0,0 +1,159 @@
+import torch
+from annotator.uniformer.mmcv.cnn import ConvModule, constant_init
+from torch import nn as nn
+from torch.nn import functional as F
+class SelfAttentionBlock(nn.Module):
+ """General self-attention block/non-local block.
+ Please refer to https://arxiv.org/abs/1706.03762 for details about key,
+ query and value.
+ Args:
+ key_in_channels (int): Input channels of key feature.
+ query_in_channels (int): Input channels of query feature.
+ channels (int): Output channels of key/query transform.
+ out_channels (int): Output channels.
+ share_key_query (bool): Whether share projection weight between key
+ and query projection. Requires
+ ``key_in_channels == query_in_channels``.
+ query_downsample (nn.Module): Query downsample module.
+ key_downsample (nn.Module): Key downsample module. It is applied to
+ both the key and the value.
+ key_query_num_convs (int): Number of convs for key/query projection.
+ value_out_num_convs (int): Number of convs for value/out projection.
+ key_query_norm (bool): Whether the key/query projections are built
+ from ``ConvModule`` (using conv_cfg/norm_cfg/act_cfg) instead of
+ plain ``nn.Conv2d``.
+ value_out_norm (bool): Whether the value/out projections are built
+ from ``ConvModule`` instead of plain ``nn.Conv2d``.
+ matmul_norm (bool): Whether normalize attention map with sqrt of
+ channels
+ with_out (bool): Whether use out projection.
+ conv_cfg (dict|None): Config of conv layers.
+ norm_cfg (dict|None): Config of norm layers.
+ act_cfg (dict|None): Config of activation layers.
+ """
+ def __init__(self, key_in_channels, query_in_channels, channels,
+ out_channels, share_key_query, query_downsample,
+ key_downsample, key_query_num_convs, value_out_num_convs,
+ key_query_norm, value_out_norm, matmul_norm, with_out,
+ conv_cfg, norm_cfg, act_cfg):
+ super(SelfAttentionBlock, self).__init__()
+ if share_key_query:
+ assert key_in_channels == query_in_channels
+ self.key_in_channels = key_in_channels
+ self.query_in_channels = query_in_channels
+ self.out_channels = out_channels
+ self.channels = channels
+ self.share_key_query = share_key_query
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.key_project = self.build_project(
+ key_in_channels,
+ channels,
+ num_convs=key_query_num_convs,
+ use_conv_module=key_query_norm,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ # With shared weights the query projection is the very same module.
+ if share_key_query:
+ self.query_project = self.key_project
+ else:
+ self.query_project = self.build_project(
+ query_in_channels,
+ channels,
+ num_convs=key_query_num_convs,
+ use_conv_module=key_query_norm,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ # Without an out projection the value projection must already produce
+ # out_channels.
+ self.value_project = self.build_project(
+ key_in_channels,
+ channels if with_out else out_channels,
+ num_convs=value_out_num_convs,
+ use_conv_module=value_out_norm,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ if with_out:
+ self.out_project = self.build_project(
+ channels,
+ out_channels,
+ num_convs=value_out_num_convs,
+ use_conv_module=value_out_norm,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ else:
+ self.out_project = None
+ self.query_downsample = query_downsample
+ self.key_downsample = key_downsample
+ self.matmul_norm = matmul_norm
+ self.init_weights()
+ def init_weights(self):
+ """Initialize weight of later layer."""
+ # Only an out projection that is not a ConvModule is passed to
+ # constant_init with value 0.
+ if self.out_project is not None:
+ if not isinstance(self.out_project, ConvModule):
+ constant_init(self.out_project, 0)
+ def build_project(self, in_channels, channels, num_convs, use_conv_module,
+ conv_cfg, norm_cfg, act_cfg):
+ """Build projection layer for key/query/value/out."""
+ # Stack num_convs 1x1 convs: ConvModule when use_conv_module is set,
+ # plain nn.Conv2d otherwise.
+ if use_conv_module:
+ convs = [
+ ConvModule(
+ in_channels,
+ channels,
+ 1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ ]
+ for _ in range(num_convs - 1):
+ convs.append(
+ ConvModule(
+ channels,
+ channels,
+ 1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+ else:
+ convs = [nn.Conv2d(in_channels, channels, 1)]
+ for _ in range(num_convs - 1):
+ convs.append(nn.Conv2d(channels, channels, 1))
+ # A single conv is returned bare rather than wrapped in nn.Sequential.
+ if len(convs) > 1:
+ convs = nn.Sequential(*convs)
+ else:
+ convs = convs[0]
+ return convs
+ def forward(self, query_feats, key_feats):
+ """Forward function."""
+ batch_size = query_feats.size(0)
+ query = self.query_project(query_feats)
+ if self.query_downsample is not None:
+ query = self.query_downsample(query)
+ # Flatten everything after the first two dims, then move that flattened
+ # axis in front of the channel axis.
+ query = query.reshape(*query.shape[:2], -1)
+ query = query.permute(0, 2, 1).contiguous()
+ key = self.key_project(key_feats)
+ value = self.value_project(key_feats)
+ # The key downsample module is applied to both key and value.
+ if self.key_downsample is not None:
+ key = self.key_downsample(key)
+ value = self.key_downsample(value)
+ key = key.reshape(*key.shape[:2], -1)
+ value = value.reshape(*value.shape[:2], -1)
+ value = value.permute(0, 2, 1).contiguous()
+ # Attention map: query x key, optionally scaled by channels ** -0.5,
+ # normalized with softmax over the last axis.
+ sim_map = torch.matmul(query, key)
+ if self.matmul_norm:
+ sim_map = (self.channels**-.5) * sim_map
+ sim_map = F.softmax(sim_map, dim=-1)
+ context = torch.matmul(sim_map, value)
+ context = context.permute(0, 2, 1).contiguous()
+ # NOTE(review): the trailing dims come from query_feats, not from the
+ # (possibly downsampled) query - confirm this is intended when
+ # query_downsample changes the spatial size.
+ context = context.reshape(batch_size, -1, *query_feats.shape[2:])
+ if self.out_project is not None:
+ context = self.out_project(context)
+ return context
diff --git a/ControlNet/annotator/uniformer/mmseg/models/utils/up_conv_block.py b/ControlNet/annotator/uniformer/mmseg/models/utils/up_conv_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..378469da76cb7bff6a639e7877b3c275d50490fb
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/utils/up_conv_block.py
@@ -0,0 +1,101 @@
+import torch
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import ConvModule, build_upsample_layer
+class UpConvBlock(nn.Module):
+ """Upsample convolution block in decoder for UNet.
+ This upsample convolution block consists of one upsample module
+ followed by one convolution block. The upsample module expands the
+ high-level low-resolution feature map and the convolution block fuses
+ the upsampled high-level low-resolution feature map and the low-level
+ high-resolution feature map from encoder.
+ Args:
+ conv_block (nn.Sequential): Sequential of convolutional layers.
+ in_channels (int): Number of input channels of the high-level
+ low-resolution feature map.
+ skip_channels (int): Number of input channels of the low-level
+ high-resolution feature map from encoder.
+ out_channels (int): Number of output channels.
+ num_convs (int): Number of convolutional layers in the conv_block.
+ Default: 2.
+ stride (int): Stride of convolutional layer in conv_block. Default: 1.
+ dilation (int): Dilation rate of convolutional layer in conv_block.
+ Default: 1.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ conv_cfg (dict | None): Config dict for convolution layer.
+ Default: None.
+ norm_cfg (dict | None): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict | None): Config dict for activation layer in ConvModule.
+ Default: dict(type='ReLU').
+ upsample_cfg (dict): The upsample config of the upsample module in
+ decoder. Default: dict(type='InterpConv'). If the size of
+ high-level feature map is the same as that of skip feature map
+ (low-level feature map from encoder), it does not need upsample the
+ high-level feature map and the upsample_cfg is None.
+ dcn (bool): Use deformable convolution in convolutional layer or not.
+ Default: None. Must be None (not implemented).
+ plugins (dict): plugins for convolutional layers. Default: None.
+ Must be None (not implemented).
+ """
+ def __init__(self,
+ conv_block,
+ in_channels,
+ skip_channels,
+ out_channels,
+ num_convs=2,
+ stride=1,
+ dilation=1,
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ upsample_cfg=dict(type='InterpConv'),
+ dcn=None,
+ plugins=None):
+ super(UpConvBlock, self).__init__()
+ assert dcn is None, 'Not implemented yet.'
+ assert plugins is None, 'Not implemented yet.'
+ # The conv block consumes the skip feature concatenated with the
+ # upsampled feature, both with skip_channels channels.
+ self.conv_block = conv_block(
+ in_channels=2 * skip_channels,
+ out_channels=out_channels,
+ num_convs=num_convs,
+ stride=stride,
+ dilation=dilation,
+ with_cp=with_cp,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ dcn=None,
+ plugins=None)
+ # With no upsample config a 1x1 ConvModule only maps in_channels to
+ # skip_channels.
+ if upsample_cfg is not None:
+ self.upsample = build_upsample_layer(
+ cfg=upsample_cfg,
+ in_channels=in_channels,
+ out_channels=skip_channels,
+ with_cp=with_cp,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ else:
+ self.upsample = ConvModule(
+ in_channels,
+ skip_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ def forward(self, skip, x):
+ """Forward function.
+ ``x`` is passed through the upsample module, concatenated with ``skip``
+ along dim 1 and fused by the conv block.
+ """
+ x = self.upsample(x)
+ out = torch.cat([skip, x], dim=1)
+ out = self.conv_block(out)
+ return out
diff --git a/ControlNet/annotator/uniformer/mmseg/models/utils/weight_init.py b/ControlNet/annotator/uniformer/mmseg/models/utils/weight_init.py
new file mode 100644
index 0000000000000000000000000000000000000000..38141ba3d61f64ddfc0a31574b4648cbad96d7dd
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/models/utils/weight_init.py
@@ -0,0 +1,62 @@
+"""Modified from https://github.com/rwightman/pytorch-image-models."""
+import math
+import warnings
+import torch
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+    """Fill ``tensor`` in place with truncated-normal samples in [a, b].
+
+    Reference: https://people.sc.fsu.edu/~jburkardt/presentations
+    /truncated_normal.pdf
+    """
+    sqrt_two = math.sqrt(2.)
+
+    def norm_cdf(x):
+        # Computes standard normal cumulative distribution function
+        return (1. + math.erf(x / sqrt_two)) / 2.
+
+    if (mean < a - 2 * std) or (mean > b + 2 * std):
+        warnings.warn(
+            'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
+            'The distribution of values may be incorrect.',
+            stacklevel=2)
+
+    with torch.no_grad():
+        # Sample uniformly between the CDF values of the two cutoffs
+        # (mapped from [l, u] to [2l-1, 2u-1]) and push the samples through
+        # the inverse CDF of the normal distribution.
+        cdf_low = norm_cdf((a - mean) / std)
+        cdf_high = norm_cdf((b - mean) / std)
+        tensor.uniform_(2 * cdf_low - 1, 2 * cdf_high - 1)
+        # erfinv gives a truncated standard normal; then rescale and shift
+        # to the requested std and mean.
+        tensor.erfinv_().mul_(std * sqrt_two).add_(mean)
+        # Clamp to ensure it's in the proper range
+        tensor.clamp_(min=a, max=b)
+        return tensor
+def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+    r"""Fill ``tensor`` with samples from a truncated normal distribution.
+
+    The values are effectively drawn from the normal distribution
+    :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values outside
+    :math:`[a, b]` redrawn until they are within the bounds. The method
+    used for generating the random values works best when
+    :math:`a \leq \text{mean} \leq b`.
+
+    Args:
+        tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`
+        mean (float): the mean of the normal distribution
+        std (float): the standard deviation of the normal distribution
+        a (float): the minimum cutoff value
+        b (float): the maximum cutoff value
+
+    Returns:
+        The result of ``_no_grad_trunc_normal_`` for the same arguments.
+    """
+    return _no_grad_trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)
diff --git a/ControlNet/annotator/uniformer/mmseg/ops/__init__.py b/ControlNet/annotator/uniformer/mmseg/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bec51c75b9363a9a19e9fb5c35f4e7dbd6f7751c
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/ops/__init__.py
@@ -0,0 +1,4 @@
+from .encoding import Encoding
+from .wrappers import Upsample, resize
+__all__ = ['Upsample', 'resize', 'Encoding']
diff --git a/ControlNet/annotator/uniformer/mmseg/ops/encoding.py b/ControlNet/annotator/uniformer/mmseg/ops/encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..7eb3629a6426550b8e4c537ee1ff4341893e489e
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/ops/encoding.py
@@ -0,0 +1,74 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+class Encoding(nn.Module):
+    """Encoding Layer: a learnable residual encoder.
+
+    Input is of shape (batch_size, channels, height, width).
+    Output is of shape (batch_size, num_codes, channels).
+
+    Args:
+        channels: dimension of the features or feature channels
+        num_codes: number of code words
+    """
+
+    def __init__(self, channels, num_codes):
+        super(Encoding, self).__init__()
+        self.channels, self.num_codes = channels, num_codes
+        # Codewords [num_codes, channels] are drawn uniformly from
+        # (-bound, bound).
+        bound = 1. / ((num_codes * channels)**0.5)
+        codeword_init = torch.empty(num_codes, channels, dtype=torch.float)
+        codeword_init.uniform_(-bound, bound)
+        self.codewords = nn.Parameter(codeword_init, requires_grad=True)
+        # Smoothing factors [num_codes] are drawn uniformly from (-1, 0).
+        scale_init = torch.empty(num_codes, dtype=torch.float)
+        scale_init.uniform_(-1, 0)
+        self.scale = nn.Parameter(scale_init, requires_grad=True)
+
+    @staticmethod
+    def scaled_l2(x, codewords, scale):
+        """Scaled squared L2 distance between every feature and codeword."""
+        num_codes, channels = codewords.size()
+        target_shape = (x.size(0), x.size(1), num_codes, channels)
+        residual = (
+            x.unsqueeze(2).expand(target_shape) -
+            codewords.view((1, 1, num_codes, channels)))
+        return scale.view((1, 1, num_codes)) * residual.pow(2).sum(dim=3)
+
+    @staticmethod
+    def aggregate(assignment_weights, x, codewords):
+        """Sum the weighted residuals to each codeword over all positions."""
+        num_codes, channels = codewords.size()
+        target_shape = (x.size(0), x.size(1), num_codes, channels)
+        residual = (
+            x.unsqueeze(2).expand(target_shape) -
+            codewords.view((1, 1, num_codes, channels)))
+        return (assignment_weights.unsqueeze(3) * residual).sum(dim=1)
+
+    def forward(self, x):
+        assert x.dim() == 4 and x.size(1) == self.channels
+        num_samples = x.size(0)
+        # [batch_size, channels, height, width]
+        #   -> [batch_size, height x width, channels]
+        flat = x.view(num_samples, self.channels, -1)
+        flat = flat.transpose(1, 2).contiguous()
+        # assignment_weights: [batch_size, height x width, num_codes]
+        distances = self.scaled_l2(flat, self.codewords, self.scale)
+        assignment_weights = F.softmax(distances, dim=2)
+        # aggregate: [batch_size, num_codes, channels]
+        return self.aggregate(assignment_weights, flat, self.codewords)
+
+    def __repr__(self):
+        return self.__class__.__name__ + (
+            f'(Nx{self.channels}xHxW =>Nx{self.num_codes}'
+            f'x{self.channels})')
diff --git a/ControlNet/annotator/uniformer/mmseg/ops/wrappers.py b/ControlNet/annotator/uniformer/mmseg/ops/wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ed9a0cb8d7c0e0ec2748dd89c652756653cac78
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/ops/wrappers.py
@@ -0,0 +1,50 @@
+import warnings
+import torch.nn as nn
+import torch.nn.functional as F
+def resize(input,
+           size=None,
+           scale_factor=None,
+           mode='nearest',
+           align_corners=None,
+           warning=True):
+    """Call ``F.interpolate`` after an optional align_corners sanity warning.
+    Args:
+        input: tensor whose dims from index 2 on are read as (height, width).
+        size: target (height, width), or None.
+        scale_factor: forwarded unchanged to ``F.interpolate``.
+        mode: interpolation mode, forwarded unchanged.
+        align_corners: forwarded unchanged; the warning is only considered
+            when this is truthy and ``size`` is given.
+        warning: set to False to skip the sanity check entirely.
+    Returns:
+        The result of ``F.interpolate(input, size, scale_factor, mode,
+        align_corners)``.
+    """
+    if warning and size is not None and align_corners:
+        input_h, input_w = tuple(int(x) for x in input.shape[2:])
+        output_h, output_w = tuple(int(x) for x in size)
+        # Upsampling in either dimension triggers the check. The width test
+        # previously compared output_w with output_h, so width-only
+        # upsampling was silently skipped.
+        if output_h > input_h or output_w > input_w:
+            if ((output_h > 1 and output_w > 1 and input_h > 1
+                 and input_w > 1) and (output_h - 1) % (input_h - 1)
+                    and (output_w - 1) % (input_w - 1)):
+                warnings.warn(
+                    f'When align_corners={align_corners}, '
+                    'the output would more aligned if '
+                    f'input size {(input_h, input_w)} is `x+1` and '
+                    f'out size {(output_h, output_w)} is `nx+1`')
+    return F.interpolate(input, size, scale_factor, mode, align_corners)
+class Upsample(nn.Module):
+    """Module wrapper around ``resize`` with a fixed size or scale factor.
+    Args:
+        size: target size passed to ``resize``; takes precedence when truthy.
+        scale_factor: a number applied to both of the last two dims, or a
+            tuple with one factor per dim. Stored as float(s); a falsy
+            scalar is stored as None.
+        mode: interpolation mode forwarded to ``resize``.
+        align_corners: forwarded to ``resize``.
+    """
+    def __init__(self,
+                 size=None,
+                 scale_factor=None,
+                 mode='nearest',
+                 align_corners=None):
+        super(Upsample, self).__init__()
+        self.size = size
+        if isinstance(scale_factor, tuple):
+            self.scale_factor = tuple(float(factor) for factor in scale_factor)
+        else:
+            self.scale_factor = float(scale_factor) if scale_factor else None
+        self.mode = mode
+        self.align_corners = align_corners
+    def forward(self, x):
+        """Resize ``x`` to ``self.size``, or to its last two dims times the scale factor."""
+        if self.size:
+            size = self.size
+        elif isinstance(self.scale_factor, tuple):
+            # Per-dimension factors: the scalar path would compute
+            # ``int * tuple`` (tuple repetition) and then fail in int().
+            size = [int(t * f) for t, f in zip(x.shape[-2:], self.scale_factor)]
+        else:
+            size = [int(t * self.scale_factor) for t in x.shape[-2:]]
+        return resize(x, size, None, self.mode, self.align_corners)
diff --git a/ControlNet/annotator/uniformer/mmseg/utils/__init__.py b/ControlNet/annotator/uniformer/mmseg/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac489e2dbbc0e6fa87f5088b4edcc20f8cadc1a6
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/utils/__init__.py
@@ -0,0 +1,4 @@
+from .collect_env import collect_env
+from .logger import get_root_logger
+__all__ = ['get_root_logger', 'collect_env']
diff --git a/ControlNet/annotator/uniformer/mmseg/utils/collect_env.py b/ControlNet/annotator/uniformer/mmseg/utils/collect_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..65c2134ddbee9655161237dd0894d38c768c2624
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/utils/collect_env.py
@@ -0,0 +1,17 @@
+from annotator.uniformer.mmcv.utils import collect_env as collect_base_env
+from annotator.uniformer.mmcv.utils import get_git_hash
+import annotator.uniformer.mmseg as mmseg
+def collect_env():
+    """Return the base environment info extended with an 'MMSegmentation' entry.
+    The entry is ``<mmseg.__version__>+<first 7 chars of get_git_hash()>``.
+    """
+    info = collect_base_env()
+    version = mmseg.__version__
+    short_hash = get_git_hash()[:7]
+    info['MMSegmentation'] = f'{version}+{short_hash}'
+    return info
+if __name__ == '__main__':
+    # Print one "name: value" line per collected environment entry.
+    for key, value in collect_env().items():
+        print(f'{key}: {value}')
diff --git a/ControlNet/annotator/uniformer/mmseg/utils/logger.py b/ControlNet/annotator/uniformer/mmseg/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..4149d9eda3dfef07490352d22ac40c42460315e4
--- /dev/null
+++ b/ControlNet/annotator/uniformer/mmseg/utils/logger.py
@@ -0,0 +1,27 @@
+import logging
+from annotator.uniformer.mmcv.utils import get_logger
+def get_root_logger(log_file=None, log_level=logging.INFO):
+    """Return the logger named "mmseg" obtained from ``get_logger``.
+    Args:
+        log_file (str | None): Log filename forwarded to ``get_logger``.
+        log_level (int): Logging level forwarded to ``get_logger``.
+    Returns:
+        The object returned by ``get_logger`` for the name "mmseg".
+    """
+    return get_logger(name='mmseg', log_file=log_file, log_level=log_level)
diff --git a/ControlNet/annotator/util.py b/ControlNet/annotator/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..90831643d19cc1b9b0940df3d4fd4d846ba74a05
--- /dev/null
+++ b/ControlNet/annotator/util.py
@@ -0,0 +1,38 @@
+import numpy as np
+import cv2
+import os
+# Path of the 'ckpts' directory that sits next to this module file.
+annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts')
+def HWC3(x):
+    """Convert a uint8 image with 1, 3 or 4 channels to a 3-channel array.
+    A 2-D input gains a trailing channel axis first. One channel is repeated
+    three times, three channels are returned unchanged (same object), and a
+    fourth channel is used as alpha to blend the colour over white (255).
+    """
+    assert x.dtype == np.uint8
+    img = x[:, :, None] if x.ndim == 2 else x
+    assert img.ndim == 3
+    channels = img.shape[2]
+    assert channels in (1, 3, 4)
+    if channels == 1:
+        return np.concatenate([img, img, img], axis=2)
+    if channels == 3:
+        return img
+    # channels == 4: alpha-composite onto a white background in float32.
+    rgb = img[:, :, 0:3].astype(np.float32)
+    opacity = img[:, :, 3:4].astype(np.float32) / 255.0
+    blended = rgb * opacity + 255.0 * (1.0 - opacity)
+    return blended.clip(0, 255).astype(np.uint8)
+def resize_image(input_image, resolution):
+    """Rescale so the shorter side is about ``resolution``, snapping both sides to multiples of 64.
+    Uses Lanczos interpolation when enlarging (scale > 1), area interpolation otherwise.
+    """
+    src_h, src_w, _ = input_image.shape
+    src_h, src_w = float(src_h), float(src_w)
+    factor = float(resolution) / min(src_h, src_w)
+    # Round each scaled side to the nearest multiple of 64.
+    dst_h = int(np.round(src_h * factor / 64.0)) * 64
+    dst_w = int(np.round(src_w * factor / 64.0)) * 64
+    method = cv2.INTER_LANCZOS4 if factor > 1 else cv2.INTER_AREA
+    return cv2.resize(input_image, (dst_w, dst_h), interpolation=method)
diff --git a/ControlNet/cldm/cldm.py b/ControlNet/cldm/cldm.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b3ac7a575cf4933fc14dfc15dd3cca41cb3f3e8
--- /dev/null
+++ b/ControlNet/cldm/cldm.py
@@ -0,0 +1,435 @@
+import einops
+import torch
+import torch as th
+import torch.nn as nn
+from ldm.modules.diffusionmodules.util import (
+ conv_nd,
+ linear,
+ zero_module,
+ timestep_embedding,
+from einops import rearrange, repeat
+from torchvision.utils import make_grid
+from ldm.modules.attention import SpatialTransformer
+from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
+from ldm.models.diffusion.ddpm import LatentDiffusion
+from ldm.util import log_txt_as_img, exists, instantiate_from_config
+from ldm.models.diffusion.ddim import DDIMSampler
+class ControlledUnetModel(UNetModel):
+ # UNetModel variant whose forward() can add externally supplied `control` tensors to the
+ # middle-block output and to the skip connections consumed by the output blocks.
+ def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs):
+ # control: list consumed from its end with .pop(), so the caller's list is mutated; None disables control.
+ # only_mid_control: when truthy, only the middle-block addition below is applied.
+ # Returns self.out(h) after casting h back to x.dtype.
+ hs = []
+ # NOTE(review): indentation is not visible in this listing; presumably the embedding,
+ # input blocks and middle block all run inside this no_grad context - confirm.
+ with torch.no_grad():
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb)
+ h = x.type(self.dtype)
+ # Keep every input-block output as a skip connection.
+ for module in self.input_blocks:
+ h = module(h, emb, context)
+ hs.append(h)
+ h = self.middle_block(h, emb, context)
+ # The last control entry is added in place to the middle-block output.
+ if control is not None:
+ h += control.pop()
+ # Skip connections are popped in reverse order; each gets one control entry added unless disabled.
+ for i, module in enumerate(self.output_blocks):
+ if only_mid_control or control is None:
+ h = torch.cat([h, hs.pop()], dim=1)
+ else:
+ h = torch.cat([h, hs.pop() + control.pop()], dim=1)
+ h = module(h, emb, context)
+ h = h.type(x.dtype)
+ return self.out(h)
+class ControlNet(nn.Module):
+ # Builds a time embedding, a hint-encoding conv stack, a list of input blocks with one
+ # zero-initialised 1x1 conv per block (see make_zero_conv), and a middle block with its own zero conv.
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ hint_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=-1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ use_spatial_transformer=False, # custom transformer support
+ transformer_depth=1, # custom transformer support
+ context_dim=None, # custom transformer support
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
+ legacy=True,
+ disable_self_attentions=None,
+ num_attention_blocks=None,
+ disable_middle_self_attn=False,
+ use_linear_in_transformer=False,
+ ):
+ super().__init__()
+ # use_spatial_transformer and context_dim must be given together.
+ if use_spatial_transformer:
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
+ if context_dim is not None:
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
+ from omegaconf.listconfig import ListConfig
+ if type(context_dim) == ListConfig:
+ context_dim = list(context_dim)
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+ # At least one of num_heads / num_head_channels must differ from the -1 sentinel.
+ if num_heads == -1:
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+ if num_head_channels == -1:
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+ self.dims = dims
+ self.image_size = image_size
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ # An int num_res_blocks is replicated once per entry of channel_mult.
+ if isinstance(num_res_blocks, int):
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+ else:
+ if len(num_res_blocks) != len(channel_mult):
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
+ "as a list/tuple (per-level) with the same length as channel_mult")
+ self.num_res_blocks = num_res_blocks
+ if disable_self_attentions is not None:
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
+ assert len(disable_self_attentions) == len(channel_mult)
+ if num_attention_blocks is not None:
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
+ # NOTE(review): the message below names UNetModel although this class is ControlNet.
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+ f"attention will still not be set.")
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.use_checkpoint = use_checkpoint
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+ self.predict_codebook_ids = n_embed is not None
+ # Time embedding MLP: model_channels -> 4 * model_channels -> 4 * model_channels.
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+ # First input block: a single 3x3 conv from in_channels to model_channels.
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ # zero_convs gains one entry for every entry appended to input_blocks.
+ self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
+ # Hint encoder: 3x3 convs with SiLU, three of them with stride 2, ending in a
+ # zero_module-wrapped conv that outputs model_channels.
+ self.input_hint_block = TimestepEmbedSequential(
+ conv_nd(dims, hint_channels, 16, 3, padding=1),
+ nn.SiLU(),
+ conv_nd(dims, 16, 16, 3, padding=1),
+ nn.SiLU(),
+ conv_nd(dims, 16, 32, 3, padding=1, stride=2),
+ nn.SiLU(),
+ conv_nd(dims, 32, 32, 3, padding=1),
+ nn.SiLU(),
+ conv_nd(dims, 32, 96, 3, padding=1, stride=2),
+ nn.SiLU(),
+ conv_nd(dims, 96, 96, 3, padding=1),
+ nn.SiLU(),
+ conv_nd(dims, 96, 256, 3, padding=1, stride=2),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ # ch tracks the current channel count; ds is doubled after every downsampling block.
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for nr in range(self.num_res_blocks[level]):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ # Attention is added only at levels whose ds value is listed in attention_resolutions.
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ # num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ if exists(disable_self_attentions):
+ disabled_sa = disable_self_attentions[level]
+ else:
+ disabled_sa = False
+ # num_attention_blocks, when given, caps how many res blocks per level get attention.
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer(
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
+ use_checkpoint=use_checkpoint
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self.zero_convs.append(self.make_zero_conv(ch))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ # Every level except the last ends with a downsampling block (ResBlock with down=True, or Downsample).
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ self.zero_convs.append(self.make_zero_conv(ch))
+ ds *= 2
+ self._feature_size += ch
+ # Head configuration for the middle block, derived from the final channel count.
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ # num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ # Middle block: ResBlock, then attention (AttentionBlock or SpatialTransformer), then ResBlock.
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
+ use_checkpoint=use_checkpoint
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self.middle_block_out = self.make_zero_conv(ch)
+ self._feature_size += ch
+ def make_zero_conv(self, channels):
+     """Return a TimestepEmbedSequential holding a zero_module-wrapped 1x1 conv (channels -> channels)."""
+     pointwise = conv_nd(self.dims, channels, channels, 1, padding=0)
+     return TimestepEmbedSequential(zero_module(pointwise))
+ def forward(self, x, hint, timesteps, context, **kwargs):
+     """Return one zero-conv output per input block plus the middle-block output.
+     The encoded hint is added in place to the output of the first input block only.
+     """
+     emb = self.time_embed(timestep_embedding(timesteps, self.model_channels, repeat_only=False))
+     pending_hint = self.input_hint_block(hint, emb, context)
+     h = x.type(self.dtype)
+     outs = []
+     for block, zero_conv in zip(self.input_blocks, self.zero_convs):
+         h = block(h, emb, context)
+         if pending_hint is not None:
+             # Inject the hint once, then drop it for the remaining blocks.
+             h += pending_hint
+             pending_hint = None
+         outs.append(zero_conv(h, emb, context))
+     h = self.middle_block(h, emb, context)
+     outs.append(self.middle_block_out(h, emb, context))
+     return outs
+class ControlLDM(LatentDiffusion):
+ # LatentDiffusion subclass that instantiates a control model from config and passes its
+ # scaled outputs to the diffusion model as `control` (see apply_model).
+ def __init__(self, control_stage_config, control_key, only_mid_control, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.control_model = instantiate_from_config(control_stage_config)
+ self.control_key = control_key
+ self.only_mid_control = only_mid_control
+ # One multiplier per control output; apply_model zips these with the control list,
+ # so zip stops at the shorter of the two.
+ self.control_scales = [1.0] * 13
+ @torch.no_grad()
+ def get_input(self, batch, k, bs=None, *args, **kwargs):
+ # `k` is not used here: self.first_stage_key is passed to the parent instead.
+ # NOTE(review): `bs` is applied to the control tensor only and is not forwarded to super().get_input - confirm intended.
+ x, c = super().get_input(batch, self.first_stage_key, *args, **kwargs)
+ control = batch[self.control_key]
+ if bs is not None:
+ control = control[:bs]
+ control = control.to(self.device)
+ # Move the channel axis from last to second position.
+ control = einops.rearrange(control, 'b h w c -> b c h w')
+ control = control.to(memory_format=torch.contiguous_format).float()
+ return x, dict(c_crossattn=[c], c_concat=[control])
+ def apply_model(self, x_noisy, t, cond, *args, **kwargs):
+ # cond must be a dict with 'c_crossattn' (list of tensors) and 'c_concat' (list of tensors, or None).
+ assert isinstance(cond, dict)
+ diffusion_model = self.model.diffusion_model
+ cond_txt = torch.cat(cond['c_crossattn'], 1)
+ # With no hint the diffusion model runs without control; otherwise the control model output is scaled first.
+ if cond['c_concat'] is None:
+ eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control)
+ else:
+ control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)
+ control = [c * scale for c, scale in zip(control, self.control_scales)]
+ eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)
+ return eps
+ @torch.no_grad()
+ def get_unconditional_conditioning(self, N):
+ # Learned conditioning for N empty strings.
+ return self.get_learned_conditioning([""] * N)
+ @torch.no_grad()
+ def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
+ quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
+ plot_diffusion_rows=False, unconditional_guidance_scale=9.0, unconditional_guidance_label=None,
+ use_ema_scope=True,
+ **kwargs):
+ # Returns a dict of tensors to log. return_keys, quantize_denoised, inpaint, plot_progressive_rows,
+ # unconditional_guidance_label and use_ema_scope are accepted but not read in this body.
+ use_ddim = ddim_steps is not None
+ log = dict()
+ z, c = self.get_input(batch, self.first_stage_key, bs=N)
+ c_cat, c = c["c_concat"][0][:N], c["c_crossattn"][0][:N]
+ N = min(z.shape[0], N)
+ n_row = min(z.shape[0], n_row)
+ log["reconstruction"] = self.decode_first_stage(z)
+ # NOTE(review): the rescale below presumably assumes the hint is in [0, 1] - confirm against the data loader.
+ log["control"] = c_cat * 2.0 - 1.0
+ log["conditioning"] = log_txt_as_img((512, 512), batch[self.cond_stage_key], size=16)
+ if plot_diffusion_rows:
+ # get diffusion row
+ diffusion_row = list()
+ z_start = z[:n_row]
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(z_start)
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+ diffusion_row.append(self.decode_first_stage(z_noisy))
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+ log["diffusion_row"] = diffusion_grid
+ if sample:
+ # get denoise row
+ samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
+ batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta)
+ x_samples = self.decode_first_stage(samples)
+ log["samples"] = x_samples
+ if plot_denoise_rows:
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+ log["denoise_row"] = denoise_grid
+ # Classifier-free-guidance samples: the unconditional branch keeps the same hint and swaps in empty-prompt conditioning.
+ if unconditional_guidance_scale > 1.0:
+ uc_cross = self.get_unconditional_conditioning(N)
+ uc_cat = c_cat # torch.zeros_like(c_cat)
+ uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
+ samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
+ batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=uc_full,
+ )
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
+ log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+ return log
+ @torch.no_grad()
+ def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
+ # `ddim` is accepted but not read: a DDIMSampler is always used.
+ ddim_sampler = DDIMSampler(self)
+ b, c, h, w = cond["c_concat"][0].shape
+ # Sample shape is the hint's spatial size divided by a hard-coded factor of 8.
+ shape = (self.channels, h // 8, w // 8)
+ samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
+ return samples, intermediates
+ def configure_optimizers(self):
+ # Optimises the control model; when self.sd_locked is falsy, the diffusion model's output_blocks and out layers are added.
+ lr = self.learning_rate
+ params = list(self.control_model.parameters())
+ if not self.sd_locked:
+ params += list(self.model.diffusion_model.output_blocks.parameters())
+ params += list(self.model.diffusion_model.out.parameters())
+ opt = torch.optim.AdamW(params, lr=lr)
+ return opt
+ def low_vram_shift(self, is_diffusing):
+ # Swaps which sub-models are on CUDA: model + control_model while diffusing,
+ # first_stage_model + cond_stage_model otherwise; the other pair is moved to CPU.
+ if is_diffusing:
+ self.model = self.model.cuda()
+ self.control_model = self.control_model.cuda()
+ self.first_stage_model = self.first_stage_model.cpu()
+ self.cond_stage_model = self.cond_stage_model.cpu()
+ else:
+ self.model = self.model.cpu()
+ self.control_model = self.control_model.cpu()
+ self.first_stage_model = self.first_stage_model.cuda()
+ self.cond_stage_model = self.cond_stage_model.cuda()
diff --git a/ControlNet/cldm/ddim_hacked.py b/ControlNet/cldm/ddim_hacked.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c040b363ba0705f52509b75437b5ea932c80ec1
--- /dev/null
+++ b/ControlNet/cldm/ddim_hacked.py
@@ -0,0 +1,316 @@
+import torch
+import numpy as np
+from tqdm import tqdm
+from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
+class DDIMSampler(object):
+    def __init__(self, model, schedule="linear", **kwargs):
+        """Store the model, its ``num_timesteps`` and the schedule name.
+        Extra keyword arguments are accepted and ignored.
+        """
+        super().__init__()
+        self.model = model
+        self.ddpm_num_timesteps = self.model.num_timesteps
+        self.schedule = schedule
+ def register_buffer(self, name, attr):
+     """Store ``attr`` on the sampler as ``name``, moving tensors to the model's device.
+     Tensors used to be forced onto ``torch.device("cuda")``. That fails on
+     CPU-only hosts and ignores the GPU index the model actually lives on.
+     The target is now ``self.model.device``, the same device make_schedule
+     uses for its own conversions. Non-tensor values are stored unchanged.
+     """
+     if isinstance(attr, torch.Tensor):
+         target = self.model.device
+         if attr.device != target:
+             attr = attr.to(target)
+     setattr(self, name, attr)
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+ # Computes the DDIM timestep subset and stores every schedule quantity on the sampler via register_buffer.
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
+ alphas_cumprod = self.model.alphas_cumprod
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+ # Detached float32 copy on the model's device.
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+ self.register_buffer('betas', to_torch(self.model.betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+ # ddim sampling parameters
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
+ ddim_timesteps=self.ddim_timesteps,
+ eta=ddim_eta,verbose=verbose)
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
+ self.register_buffer('ddim_alphas', ddim_alphas)
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+ # Sigmas for the full-length schedule, computed from the buffers registered above; scaled by ddim_eta.
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+ @torch.no_grad()
+ def sample(self,
+            S,
+            batch_size,
+            shape,
+            conditioning=None,
+            callback=None,
+            normals_sequence=None,
+            img_callback=None,
+            quantize_x0=False,
+            eta=0.,
+            mask=None,
+            x0=None,
+            temperature=1.,
+            noise_dropout=0.,
+            score_corrector=None,
+            corrector_kwargs=None,
+            verbose=True,
+            x_T=None,
+            log_every_t=100,
+            unconditional_guidance_scale=1.,
+            unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+            dynamic_threshold=None,
+            ucg_schedule=None,
+            **kwargs
+            ):
+     """Build the DDIM schedule for ``S`` steps and run ``ddim_sampling``.
+     Args:
+         S: number of DDIM steps, passed to make_schedule.
+         batch_size: number of samples; prepended to ``shape``.
+         shape: (C, H, W) of a single sample.
+         conditioning: dict, list or tensor; only its batch dimension is
+             checked here, and a mismatch prints a warning.
+         eta: passed to make_schedule as ``ddim_eta``.
+         Remaining arguments are forwarded to ``ddim_sampling``;
+         ``normals_sequence`` and ``**kwargs`` are accepted but unused.
+     Returns:
+         The ``(samples, intermediates)`` pair from ``ddim_sampling``.
+     """
+     if conditioning is not None:
+         if isinstance(conditioning, dict):
+             # Take the first value and unwrap nested lists to reach a tensor.
+             ctmp = conditioning[list(conditioning.keys())[0]]
+             while isinstance(ctmp, list):
+                 ctmp = ctmp[0]
+             cbs = ctmp.shape[0]
+             if cbs != batch_size:
+                 print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+         elif isinstance(conditioning, list):
+             for ctmp in conditioning:
+                 if ctmp.shape[0] != batch_size:
+                     # Report this entry's own batch size: `cbs` is only
+                     # bound in the dict branch, so it raised here before.
+                     print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
+         else:
+             if conditioning.shape[0] != batch_size:
+                 print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+     self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+     # sampling
+     C, H, W = shape
+     size = (batch_size, C, H, W)
+     print(f'Data shape for DDIM sampling is {size}, eta {eta}')
+     samples, intermediates = self.ddim_sampling(conditioning, size,
+                                                 callback=callback,
+                                                 img_callback=img_callback,
+                                                 quantize_denoised=quantize_x0,
+                                                 mask=mask, x0=x0,
+                                                 ddim_use_original_steps=False,
+                                                 noise_dropout=noise_dropout,
+                                                 temperature=temperature,
+                                                 score_corrector=score_corrector,
+                                                 corrector_kwargs=corrector_kwargs,
+                                                 x_T=x_T,
+                                                 log_every_t=log_every_t,
+                                                 unconditional_guidance_scale=unconditional_guidance_scale,
+                                                 unconditional_conditioning=unconditional_conditioning,
+                                                 dynamic_threshold=dynamic_threshold,
+                                                 ucg_schedule=ucg_schedule
+                                                 )
+     return samples, intermediates
+ @torch.no_grad()
+ def ddim_sampling(self, cond, shape,
+                   x_T=None, ddim_use_original_steps=False,
+                   callback=None, timesteps=None, quantize_denoised=False,
+                   mask=None, x0=None, img_callback=None, log_every_t=100,
+                   temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                   unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
+                   ucg_schedule=None):
+     """Run the reverse loop from ``x_T`` (or fresh noise) and return the final image.
+     Args:
+         cond: conditioning forwarded to ``p_sample_ddim``.
+         shape: full batch shape; ``shape[0]`` is the batch size.
+         x_T: optional start tensor; ``torch.randn(shape)`` when None.
+         ddim_use_original_steps: iterate over all original timesteps
+             instead of the DDIM subset.
+         timesteps: optional step count used to truncate the DDIM subset.
+         mask, x0: when ``mask`` is given, ``x0`` is required and the masked
+             region is replaced each step by ``q_sample(x0, ts)``.
+         ucg_schedule: optional per-step guidance scales; its length must
+             equal the number of steps.
+     Returns:
+         ``(img, intermediates)`` where ``intermediates`` has the lists
+         'x_inter' and 'pred_x0'.
+     """
+     device = self.model.betas.device
+     b = shape[0]
+     if x_T is None:
+         img = torch.randn(shape, device=device)
+     else:
+         img = x_T
+     if timesteps is None:
+         timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+     elif timesteps is not None and not ddim_use_original_steps:
+         subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+         timesteps = self.ddim_timesteps[:subset_end]
+     intermediates = {'x_inter': [img], 'pred_x0': [img]}
+     time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
+     total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+     print(f"Running DDIM Sampling with {total_steps} timesteps")
+     iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
+     for i, step in enumerate(iterator):
+         index = total_steps - i - 1
+         ts = torch.full((b,), step, device=device, dtype=torch.long)
+         if mask is not None:
+             assert x0 is not None
+             img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
+             img = img_orig * mask + (1. - mask) * img
+         if ucg_schedule is not None:
+             # Compare with total_steps: time_range can be a `reversed`
+             # iterator, which has no len().
+             assert len(ucg_schedule) == total_steps
+             unconditional_guidance_scale = ucg_schedule[i]
+         outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+                                   quantize_denoised=quantize_denoised, temperature=temperature,
+                                   noise_dropout=noise_dropout, score_corrector=score_corrector,
+                                   corrector_kwargs=corrector_kwargs,
+                                   unconditional_guidance_scale=unconditional_guidance_scale,
+                                   unconditional_conditioning=unconditional_conditioning,
+                                   dynamic_threshold=dynamic_threshold)
+         img, pred_x0 = outs
+         if callback: callback(i)
+         if img_callback: img_callback(pred_x0, i)
+         if index % log_every_t == 0 or index == total_steps - 1:
+             intermediates['x_inter'].append(img)
+             intermediates['pred_x0'].append(pred_x0)
+     return img, intermediates
+ @torch.no_grad()
+ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+                   temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                   unconditional_guidance_scale=1., unconditional_conditioning=None,
+                   dynamic_threshold=None):
+     """Perform one DDIM update and return ``(x_prev, pred_x0)``.
+     Args:
+         x: current sample; ``x.shape[0]`` is the batch size.
+         c: conditioning passed to ``self.model.apply_model``.
+         t: timestep tensor passed to the model.
+         index: position in the schedule arrays used for this step.
+         use_original_steps: read the full-length schedule instead of the
+             DDIM subset.
+         unconditional_guidance_scale, unconditional_conditioning: when a
+             conditioning is given and the scale is not 1, the output is
+             ``uncond + scale * (cond - uncond)``.
+         dynamic_threshold: must be None.
+     Raises:
+         NotImplementedError: if ``dynamic_threshold`` is not None.
+     """
+     b, *_, device = *x.shape, x.device
+     if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+         model_output = self.model.apply_model(x, t, c)
+     else:
+         model_t = self.model.apply_model(x, t, c)
+         model_uncond = self.model.apply_model(x, t, unconditional_conditioning)
+         model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
+     if self.model.parameterization == "v":
+         e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
+     else:
+         e_t = model_output
+     if score_corrector is not None:
+         assert self.model.parameterization == "eps", 'not implemented'
+         e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+     alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+     alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+     sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+     # make_schedule registers the full-length sigmas on the sampler, so
+     # read them from self rather than from the model.
+     sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+     # select parameters corresponding to the currently considered timestep
+     a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+     a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+     sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+     sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+     # current prediction for x_0
+     if self.model.parameterization != "v":
+         pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+     else:
+         pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
+     if quantize_denoised:
+         pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+     if dynamic_threshold is not None:
+         raise NotImplementedError()
+     # direction pointing to x_t
+     dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+     noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+     if noise_dropout > 0.:
+         noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+     x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+     return x_prev, pred_x0
+ @torch.no_grad()
+ def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
+ unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
+ # Iteratively maps x0 forward for t_enc steps using the model's noise prediction.
+ # Returns (x_next, out) where out holds 'x_encoded', 'intermediate_steps' and, if requested, 'intermediates'.
+ num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
+ assert t_enc <= num_reference_steps
+ num_steps = t_enc
+ if use_original_steps:
+ alphas_next = self.alphas_cumprod[:num_steps]
+ alphas = self.alphas_cumprod_prev[:num_steps]
+ else:
+ alphas_next = self.ddim_alphas[:num_steps]
+ # Only the "prev" array is wrapped in torch.tensor here; ddim_alphas is used as stored.
+ alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
+ x_next = x0
+ intermediates = []
+ inter_steps = []
+ for i in tqdm(range(num_steps), desc='Encoding Image'):
+ # The model is queried with the loop index i as the timestep value.
+ t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
+ if unconditional_guidance_scale == 1.:
+ noise_pred = self.model.apply_model(x_next, t, c)
+ else:
+ assert unconditional_conditioning is not None
+ # One doubled batch: the first half is unconditional, the second half conditional.
+ e_t_uncond, noise_pred = torch.chunk(
+ self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
+ torch.cat((unconditional_conditioning, c))), 2)
+ noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
+ xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
+ weighted_noise_pred = alphas_next[i].sqrt() * (
+ (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
+ x_next = xt_weighted + weighted_noise_pred
+ # NOTE(review): return_intermediates is used as a divisor; a value larger than num_steps
+ # makes `num_steps // return_intermediates` zero and the modulo would fail - confirm callers.
+ if return_intermediates and i % (
+ num_steps // return_intermediates) == 0 and i < num_steps - 1:
+ intermediates.append(x_next)
+ inter_steps.append(i)
+ elif return_intermediates and i >= num_steps - 2:
+ intermediates.append(x_next)
+ inter_steps.append(i)
+ if callback: callback(i)
+ out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
+ if return_intermediates:
+ out.update({'intermediates': intermediates})
+ return x_next, out
+ @torch.no_grad()
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
+     """Noise ``x0`` directly to step ``t`` (fast, but not exactly invertible).
+     ``t`` indexes the schedule arrays: the full-length ones when
+     ``use_original_steps`` is set, otherwise the DDIM subset. Fresh Gaussian
+     noise is drawn when ``noise`` is None.
+     """
+     if use_original_steps:
+         signal_coef = self.sqrt_alphas_cumprod
+         noise_coef = self.sqrt_one_minus_alphas_cumprod
+     else:
+         signal_coef = torch.sqrt(self.ddim_alphas)
+         noise_coef = self.ddim_sqrt_one_minus_alphas
+     eps = torch.randn_like(x0) if noise is None else noise
+     scaled_signal = extract_into_tensor(signal_coef, t, x0.shape) * x0
+     scaled_noise = extract_into_tensor(noise_coef, t, x0.shape) * eps
+     return scaled_signal + scaled_noise
+ @torch.no_grad()
+ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
+            use_original_steps=False, callback=None):
+     """Denoise ``x_latent`` over the first ``t_start`` schedule entries, latest first.
+     Each step calls ``p_sample_ddim`` and keeps only the updated sample;
+     ``callback(i)`` is invoked after every step when provided.
+     """
+     if use_original_steps:
+         schedule = np.arange(self.ddpm_num_timesteps)
+     else:
+         schedule = self.ddim_timesteps
+     schedule = schedule[:t_start]
+     descending = np.flip(schedule)
+     n_steps = schedule.shape[0]
+     print(f"Running DDIM Sampling with {n_steps} timesteps")
+     progress = tqdm(descending, desc='Decoding image', total=n_steps)
+     current = x_latent
+     for i, step in enumerate(progress):
+         # Schedule index counts down while the loop counter counts up.
+         sched_idx = n_steps - i - 1
+         ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
+         current, _ = self.p_sample_ddim(current, cond, ts, index=sched_idx, use_original_steps=use_original_steps,
+                                         unconditional_guidance_scale=unconditional_guidance_scale,
+                                         unconditional_conditioning=unconditional_conditioning)
+         if callback:
+             callback(i)
+     return current
\ No newline at end of file
diff --git a/ControlNet/cldm/hack.py b/ControlNet/cldm/hack.py
new file mode 100644
index 0000000000000000000000000000000000000000..454361e9d036cd1a6a79122c2fd16b489e4767b1
--- /dev/null
+++ b/ControlNet/cldm/hack.py
@@ -0,0 +1,111 @@
+import torch
+import einops
+import ldm.modules.encoders.modules
+import ldm.modules.attention
+from transformers import logging
+from ldm.modules.attention import default
+def disable_verbosity():
+    """Set the ``transformers`` logging verbosity to errors only and report it on stdout."""
+    logging.set_verbosity_error()
+    print('logging improved.')
+def enable_sliced_attention():
+    """Replace ``CrossAttention.forward`` with this module's sliced implementation."""
+    attention_cls = ldm.modules.attention.CrossAttention
+    attention_cls.forward = _hacked_sliced_attentin_forward
+    print('Enabled sliced_attention.')
+def hack_everything(clip_skip=0):
+    """Silence ``transformers`` logging and patch ``FrozenCLIPEmbedder``.
+
+    The embedder class gets ``_hacked_clip_forward`` as its ``forward`` and a
+    class-level ``clip_skip`` attribute set to the given value.
+    """
+    disable_verbosity()
+    embedder_cls = ldm.modules.encoders.modules.FrozenCLIPEmbedder
+    embedder_cls.forward = _hacked_clip_forward
+    embedder_cls.clip_skip = clip_skip
+    print('Enabled clip hacks.')
+# Written by Lvmin
+def _hacked_clip_forward(self, text):
+    """Encode a batch of prompts as three 77-token windows each and concatenate the results.
+
+    Every prompt is tokenized without truncation or special tokens, cut into three
+    consecutive slices of 75 ids (ids beyond the first 225 are dropped), and each
+    slice is wrapped as ``[BOS] + slice + [EOS]`` and right-padded with PAD to 77.
+    The three windows are encoded as separate rows and their outputs are joined
+    along the token axis.
+
+    :param text: iterable of prompts accepted by ``self.tokenizer``.
+    :return: tensor rearranged as ``b (f i) c`` with ``f = 3``.
+    """
+    tok = self.tokenizer
+    pad_id = tok.pad_token_id
+    eos_id = tok.eos_token_id
+    bos_id = tok.bos_token_id
+
+    chunk_len, n_chunks, window_len = 75, 3, 77
+
+    encoded = tok(text, truncation=False, add_special_tokens=False)["input_ids"]
+
+    batch_windows = []
+    for ids in encoded:
+        windows = []
+        for c in range(n_chunks):
+            piece = [bos_id] + ids[chunk_len * c: chunk_len * (c + 1)] + [eos_id]
+            # A wrapped slice holds at most 77 ids, so only right-padding is ever needed.
+            piece = piece + [pad_id] * (window_len - len(piece))
+            windows.append(piece)
+        batch_windows.append(windows)
+
+    token_tensor = torch.IntTensor(batch_windows).to(self.device)
+    flat = einops.rearrange(token_tensor, 'b f i -> (b f) i')
+
+    if self.clip_skip > 1:
+        # Take an earlier hidden state and apply the text model's final layer norm to it.
+        outputs = self.transformer(input_ids=flat, output_hidden_states=True)
+        hidden = self.transformer.text_model.final_layer_norm(outputs.hidden_states[-self.clip_skip])
+    else:
+        hidden = self.transformer(input_ids=flat, output_hidden_states=False).last_hidden_state
+
+    return einops.rearrange(hidden, '(b f) i c -> b (f i) c', f=3)
+# Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py
+def _hacked_sliced_attentin_forward(self, x, context=None, mask=None):
+    """Cross-attention forward pass computed one (batch*head) slice at a time.
+
+    Intermediate tensors are deleted as soon as they are no longer needed so that
+    only one slice's attention matrix is alive at any moment.
+
+    :param x: query input, projected by ``self.to_q``.
+    :param context: key/value input; falls back to ``x`` via ``default`` when None.
+    :param mask: accepted for signature compatibility; it is not used here.
+    :return: ``self.to_out`` applied to the re-assembled attention output.
+    """
+    h = self.heads
+
+    q = self.to_q(x)
+    context = default(context, x)
+    k = self.to_k(context)
+    v = self.to_v(context)
+    del context, x
+
+    q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+    limit = k.shape[0]
+    att_step = 1
+    q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0))
+    k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0))
+    v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0))
+
+    # Reversed so that ``pop()`` hands the chunks back in their original order.
+    q_chunks.reverse()
+    k_chunks.reverse()
+    v_chunks.reverse()
+
+    # Allocate the accumulator in the projections' dtype: the default (float32)
+    # would upcast half-precision results and hand ``to_out`` a mismatched dtype.
+    sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+    del k, q, v
+
+    for i in range(0, limit, att_step):
+        q_buffer = q_chunks.pop()
+        k_buffer = k_chunks.pop()
+        v_buffer = v_chunks.pop()
+        sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
+
+        del k_buffer, q_buffer
+        # attention, what we cannot get enough of, by chunks
+        sim_buffer = sim_buffer.softmax(dim=-1)
+
+        sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
+        del v_buffer
+        sim[i:i + att_step, :, :] = sim_buffer
+        del sim_buffer
+
+    sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h)
+    return self.to_out(sim)
diff --git a/ControlNet/cldm/logger.py b/ControlNet/cldm/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a8803846f2a8979f87f3cf9ea5b12869439e62f
--- /dev/null
+++ b/ControlNet/cldm/logger.py
@@ -0,0 +1,76 @@
+import os
+import numpy as np
+import torch
+import torchvision
+from PIL import Image
+from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.utilities.distributed import rank_zero_only
+class ImageLogger(Callback):
+    """Callback that periodically writes the images returned by ``pl_module.log_images`` to disk.
+
+    Files are saved as PNG under ``<logger.save_dir>/image_log/<split>/``.
+    """
+
+    def __init__(self, batch_frequency=2000, max_images=4, clamp=True, increase_log_steps=True,
+                 rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
+                 log_images_kwargs=None):
+        """
+        :param batch_frequency: log whenever ``batch_idx % batch_frequency == 0``.
+        :param max_images: maximum number of leading batch entries kept per image key.
+        :param clamp: clamp tensor images to [-1, 1] before saving.
+        :param increase_log_steps: when False, ``self.log_steps`` is set to ``[batch_frequency]``.
+        :param rescale: map grids from [-1, 1] to [0, 1] before conversion to uint8.
+        :param disabled: turn the callback into a no-op.
+        :param log_on_batch_idx: stored only; not read inside this class.
+        :param log_first_step: stored only; not read inside this class.
+        :param log_images_kwargs: extra keyword arguments forwarded to ``log_images``.
+        """
+        super().__init__()
+        self.rescale = rescale
+        self.batch_freq = batch_frequency
+        self.max_images = max_images
+        if not increase_log_steps:
+            self.log_steps = [self.batch_freq]
+        self.clamp = clamp
+        self.disabled = disabled
+        self.log_on_batch_idx = log_on_batch_idx
+        self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
+        self.log_first_step = log_first_step
+
+    @rank_zero_only
+    def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
+        """Save every entry of ``images`` as one PNG grid (four images per row)."""
+        root = os.path.join(save_dir, "image_log", split)
+        for k in images:
+            grid = torchvision.utils.make_grid(images[k], nrow=4)
+            if self.rescale:
+                grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
+            # c,h,w -> h,w,c; a trailing axis of size 1 is squeezed away.
+            grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
+            grid = grid.numpy()
+            grid = (grid * 255).astype(np.uint8)
+            filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
+            path = os.path.join(root, filename)
+            os.makedirs(os.path.split(path)[0], exist_ok=True)
+            Image.fromarray(grid).save(path)
+
+    def log_img(self, pl_module, batch, batch_idx, split="train"):
+        """Generate and save images for this batch when the logging frequency is hit.
+
+        The module is switched to eval mode while images are produced and is put
+        back into train mode afterwards, even if image generation or saving raises.
+        """
+        check_idx = batch_idx  # if self.log_on_batch_idx else pl_module.global_step
+        should_log = (self.check_frequency(check_idx) and  # batch_idx % self.batch_freq == 0
+                      hasattr(pl_module, "log_images") and
+                      callable(pl_module.log_images) and
+                      self.max_images > 0)
+        if not should_log:
+            return
+
+        is_train = pl_module.training
+        if is_train:
+            pl_module.eval()
+
+        try:
+            with torch.no_grad():
+                images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
+
+            for k in images:
+                N = min(images[k].shape[0], self.max_images)
+                images[k] = images[k][:N]
+                if isinstance(images[k], torch.Tensor):
+                    images[k] = images[k].detach().cpu()
+                    if self.clamp:
+                        images[k] = torch.clamp(images[k], -1., 1.)
+
+            self.log_local(pl_module.logger.save_dir, split, images,
+                           pl_module.global_step, pl_module.current_epoch, batch_idx)
+        finally:
+            # Never leave the module stuck in eval mode after a failed logging attempt.
+            if is_train:
+                pl_module.train()
+
+    def check_frequency(self, check_idx):
+        """Return True when ``check_idx`` is a multiple of the configured batch frequency."""
+        return check_idx % self.batch_freq == 0
+
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
+        """Training hook: log images for the finished batch unless the callback is disabled.
+
+        ``dataloader_idx`` is optional so the hook works whether or not the trainer passes it.
+        """
+        if not self.disabled:
+            self.log_img(pl_module, batch, batch_idx, split="train")
diff --git a/ControlNet/cldm/model.py b/ControlNet/cldm/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..fed3c31ac145b78907c7f771d1d8db6fb32d92ed
--- /dev/null
+++ b/ControlNet/cldm/model.py
@@ -0,0 +1,28 @@
+import os
+import torch
+from omegaconf import OmegaConf
+from ldm.util import instantiate_from_config
+def get_state_dict(d):
+    """Return ``d['state_dict']`` when that key exists, otherwise ``d`` unchanged."""
+    if 'state_dict' in d:
+        return d['state_dict']
+    return d
+def load_state_dict(ckpt_path, location='cpu'):
+    """Load a checkpoint file and return its state dict.
+
+    Files ending in ``.safetensors`` (case-insensitive) are read with
+    ``safetensors.torch.load_file``; everything else goes through ``torch.load``.
+    A top-level ``'state_dict'`` wrapper is removed via ``get_state_dict``.
+
+    :param ckpt_path: path of the checkpoint file.
+    :param location: device string handed to the loader.
+    :return: the loaded state dict.
+    """
+    suffix = os.path.splitext(ckpt_path)[1].lower()
+    if suffix == ".safetensors":
+        import safetensors.torch
+        loaded = safetensors.torch.load_file(ckpt_path, device=location)
+    else:
+        # NOTE(review): torch.load unpickles the file; only use it on trusted checkpoints.
+        raw = torch.load(ckpt_path, map_location=torch.device(location))
+        loaded = get_state_dict(raw)
+    # Unwrapping is applied once more to whatever either branch produced.
+    state_dict = get_state_dict(loaded)
+    print(f'Loaded state_dict from [{ckpt_path}]')
+    return state_dict
+def create_model(config_path):
+    """Build the model described under ``model`` in an OmegaConf file and place it on the CPU.
+
+    :param config_path: path of the configuration file.
+    :return: the instantiated model after ``.cpu()``.
+    """
+    cfg = OmegaConf.load(config_path)
+    net = instantiate_from_config(cfg.model)
+    net = net.cpu()
+    print(f'Loaded model config from [{config_path}]')
+    return net
diff --git a/ControlNet/ldm/data/__init__.py b/ControlNet/ldm/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ControlNet/ldm/data/util.py b/ControlNet/ldm/data/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b60ceb2349e3bd7900ff325740e2022d2903b1c
--- /dev/null
+++ b/ControlNet/ldm/data/util.py
@@ -0,0 +1,24 @@
+import torch
+from ldm.modules.midas.api import load_midas_transform
+class AddMiDaS(object):
+    """Sample transform that stores a MiDaS-preprocessed copy of the image under ``'midas_in'``."""
+
+    def __init__(self, model_type):
+        """Fetch the preprocessing transform for ``model_type`` via ``load_midas_transform``."""
+        super().__init__()
+        self.transform = load_midas_transform(model_type)
+
+    def pt2np(self, x):
+        """Map a tensor from [-1, 1] to [0, 1] and return it as a NumPy array on the CPU."""
+        unit_range = (x + 1.0) * .5
+        return unit_range.detach().cpu().numpy()
+
+    def np2pt(self, x):
+        """Convert a NumPy array to a tensor and map it from [0, 1] to [-1, 1]."""
+        as_tensor = torch.from_numpy(x)
+        return as_tensor * 2 - 1.
+
+    def __call__(self, sample):
+        """Add ``sample['midas_in']`` computed from ``sample['jpg']`` and return the sample."""
+        # sample['jpg'] is tensor hwc in [-1, 1] at this point
+        image = self.pt2np(sample['jpg'])
+        sample['midas_in'] = self.transform({"image": image})["image"]
+        return sample
\ No newline at end of file
diff --git a/ControlNet/ldm/models/autoencoder.py b/ControlNet/ldm/models/autoencoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..d122549995ce2cd64092c81a58419ed4a15a02fd
--- /dev/null
+++ b/ControlNet/ldm/models/autoencoder.py
@@ -0,0 +1,219 @@
+import torch
+import pytorch_lightning as pl
+import torch.nn.functional as F
+from contextlib import contextmanager
+from ldm.modules.diffusionmodules.model import Encoder, Decoder
+from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
+from ldm.util import instantiate_from_config
+from ldm.modules.ema import LitEma
+class AutoencoderKL(pl.LightningModule):
+    """Autoencoder whose encoder output parameterises a ``DiagonalGaussianDistribution``.
+
+    ``encode`` returns the posterior, ``decode`` maps a latent back through
+    ``post_quant_conv`` and the decoder, and training alternates between two
+    optimizers (autoencoder and discriminator) selected by ``optimizer_idx``.
+    """
+
+    def __init__(self,
+                 ddconfig,
+                 lossconfig,
+                 embed_dim,
+                 ckpt_path=None,
+                 ignore_keys=(),
+                 image_key="image",
+                 colorize_nlabels=None,
+                 monitor=None,
+                 ema_decay=None,
+                 learn_logvar=False
+                 ):
+        """
+        :param ddconfig: keyword arguments for ``Encoder``/``Decoder``; must set ``double_z``.
+        :param lossconfig: config passed to ``instantiate_from_config`` to build the loss.
+        :param embed_dim: channel count of the latent produced by ``quant_conv``.
+        :param ckpt_path: optional checkpoint to restore from at the end of construction.
+        :param ignore_keys: key prefixes to drop from that checkpoint before loading.
+        :param image_key: batch key that holds the input images.
+        :param colorize_nlabels: when given, registers a random ``colorize`` buffer of that width.
+        :param monitor: when given, stored as ``self.monitor``.
+        :param ema_decay: enables EMA weights with this decay when not None (must be in (0, 1)).
+        :param learn_logvar: also optimise ``self.loss.logvar`` in ``configure_optimizers``.
+        """
+        super().__init__()
+        self.learn_logvar = learn_logvar
+        self.image_key = image_key
+        self.encoder = Encoder(**ddconfig)
+        self.decoder = Decoder(**ddconfig)
+        self.loss = instantiate_from_config(lossconfig)
+        assert ddconfig["double_z"]
+        self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
+        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+        self.embed_dim = embed_dim
+        if colorize_nlabels is not None:
+            assert isinstance(colorize_nlabels, int)
+            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
+        if monitor is not None:
+            self.monitor = monitor
+
+        self.use_ema = ema_decay is not None
+        if self.use_ema:
+            self.ema_decay = ema_decay
+            assert 0. < ema_decay < 1.
+            self.model_ema = LitEma(self, decay=ema_decay)
+            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+        if ckpt_path is not None:
+            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+
+    def init_from_ckpt(self, path, ignore_keys=()):
+        """Load ``state_dict`` from ``path`` (non-strict), dropping keys that start with any ignore prefix."""
+        # NOTE(review): torch.load unpickles the file; only use it on trusted checkpoints.
+        sd = torch.load(path, map_location="cpu")["state_dict"]
+        for k in list(sd.keys()):
+            # ``any`` deletes each key at most once, even when several prefixes match it.
+            if any(k.startswith(ik) for ik in ignore_keys):
+                print("Deleting key {} from state_dict.".format(k))
+                del sd[k]
+        self.load_state_dict(sd, strict=False)
+        print(f"Restored from {path}")
+
+    @contextmanager
+    def ema_scope(self, context=None):
+        """Temporarily swap in the EMA weights; a no-op when EMA is disabled.
+
+        The current parameters are stored on entry and restored on exit, also
+        when the body raises. ``context`` only labels the printed messages.
+        """
+        if self.use_ema:
+            self.model_ema.store(self.parameters())
+            self.model_ema.copy_to(self)
+            if context is not None:
+                print(f"{context}: Switched to EMA weights")
+        try:
+            yield None
+        finally:
+            if self.use_ema:
+                self.model_ema.restore(self.parameters())
+                if context is not None:
+                    print(f"{context}: Restored training weights")
+
+    def on_train_batch_end(self, *args, **kwargs):
+        """Update the EMA weights after each training batch when EMA is enabled."""
+        if self.use_ema:
+            self.model_ema(self)
+
+    def encode(self, x):
+        """Encode ``x`` and return the resulting ``DiagonalGaussianDistribution``."""
+        h = self.encoder(x)
+        moments = self.quant_conv(h)
+        posterior = DiagonalGaussianDistribution(moments)
+        return posterior
+
+    def decode(self, z):
+        """Decode latent ``z`` through ``post_quant_conv`` and the decoder."""
+        z = self.post_quant_conv(z)
+        dec = self.decoder(z)
+        return dec
+
+    def forward(self, input, sample_posterior=True):
+        """Return ``(reconstruction, posterior)``; the latent is sampled or taken as the mode."""
+        posterior = self.encode(input)
+        if sample_posterior:
+            z = posterior.sample()
+        else:
+            z = posterior.mode()
+        dec = self.decode(z)
+        return dec, posterior
+
+    def get_input(self, batch, k):
+        """Fetch ``batch[k]`` and return it as a contiguous float tensor permuted with (0, 3, 1, 2)."""
+        x = batch[k]
+        if len(x.shape) == 3:
+            # Add a trailing axis so the permute below always sees four dimensions.
+            x = x[..., None]
+        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
+        return x
+
+    def training_step(self, batch, batch_idx, optimizer_idx):
+        """Compute and log the loss for the optimizer selected by ``optimizer_idx`` (0: autoencoder, 1: discriminator)."""
+        inputs = self.get_input(batch, self.image_key)
+        reconstructions, posterior = self(inputs)
+
+        if optimizer_idx == 0:
+            # train encoder+decoder+logvar
+            aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
+                                            last_layer=self.get_last_layer(), split="train")
+            self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+            return aeloss
+
+        if optimizer_idx == 1:
+            # train the discriminator
+            discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
+                                                last_layer=self.get_last_layer(), split="train")
+            self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+            return discloss
+
+    def validation_step(self, batch, batch_idx):
+        """Validate with the current weights and again inside ``ema_scope``; return the first result."""
+        log_dict = self._validation_step(batch, batch_idx)
+        with self.ema_scope():
+            # Run only for its logging side effects (``val_ema/...`` entries).
+            self._validation_step(batch, batch_idx, postfix="_ema")
+        return log_dict
+
+    def _validation_step(self, batch, batch_idx, postfix=""):
+        """Evaluate both loss branches on ``batch``, log them, and return the logged values.
+
+        :return: dict merging the autoencoder and discriminator log entries.
+        """
+        inputs = self.get_input(batch, self.image_key)
+        reconstructions, posterior = self(inputs)
+        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
+                                        last_layer=self.get_last_layer(), split="val"+postfix)
+
+        discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
+                                            last_layer=self.get_last_layer(), split="val"+postfix)
+
+        self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
+        self.log_dict(log_dict_ae)
+        self.log_dict(log_dict_disc)
+        # Return the logged values, not the bound ``self.log_dict`` method.
+        return {**log_dict_ae, **log_dict_disc}
+
+    def configure_optimizers(self):
+        """Create two Adam optimizers: autoencoder parameters first, discriminator second.
+
+        Reads ``self.learning_rate``, which is not assigned inside this class.
+        """
+        lr = self.learning_rate
+        ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
+            self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
+        if self.learn_logvar:
+            print(f"{self.__class__.__name__}: Learning logvar")
+            ae_params_list.append(self.loss.logvar)
+        opt_ae = torch.optim.Adam(ae_params_list,
+                                  lr=lr, betas=(0.5, 0.9))
+        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
+                                    lr=lr, betas=(0.5, 0.9))
+        return [opt_ae, opt_disc], []
+
+    def get_last_layer(self):
+        """Return the weight of the decoder's output convolution."""
+        return self.decoder.conv_out.weight
+
+    @torch.no_grad()
+    def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
+        """Build a dict of tensors for image logging.
+
+        Always contains ``inputs``. Unless ``only_inputs`` is set it also contains
+        ``samples`` and ``reconstructions``, plus ``samples_ema`` and
+        ``reconstructions_ema`` when ``log_ema`` or ``self.use_ema`` is true.
+        """
+        log = dict()
+        x = self.get_input(batch, self.image_key)
+        x = x.to(self.device)
+        if not only_inputs:
+            xrec, posterior = self(x)
+            if x.shape[1] > 3:
+                # colorize with random projection
+                assert xrec.shape[1] > 3
+                x = self.to_rgb(x)
+                xrec = self.to_rgb(xrec)
+            log["samples"] = self.decode(torch.randn_like(posterior.sample()))
+            log["reconstructions"] = xrec
+            if log_ema or self.use_ema:
+                with self.ema_scope():
+                    # NOTE(review): ``x`` may already have been colorized above, in which
+                    # case the EMA pass receives the RGB projection - confirm this is intended.
+                    xrec_ema, posterior_ema = self(x)
+                    if x.shape[1] > 3:
+                        # colorize with random projection
+                        assert xrec_ema.shape[1] > 3
+                        xrec_ema = self.to_rgb(xrec_ema)
+                    log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
+                    log["reconstructions_ema"] = xrec_ema
+        log["inputs"] = x
+        return log
+
+    def to_rgb(self, x):
+        """Project ``x`` to three channels with the ``colorize`` buffer and rescale to [-1, 1].
+
+        Only valid when ``image_key`` is ``"segmentation"``; the buffer is created
+        randomly on first use if it does not exist yet.
+        """
+        assert self.image_key == "segmentation"
+        if not hasattr(self, "colorize"):
+            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
+        x = F.conv2d(x, weight=self.colorize)
+        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
+        return x
+class IdentityFirstStage(torch.nn.Module):
+    """First-stage stand-in whose encode/decode/forward return their input unchanged."""
+
+    def __init__(self, *args, vq_interface=False, **kwargs):
+        """
+        :param vq_interface: when True, ``quantize`` returns a 3-tuple instead of the bare input.
+
+        Additional positional and keyword arguments are accepted and ignored.
+        """
+        # Initialise the Module machinery before assigning any attribute.
+        super().__init__()
+        self.vq_interface = vq_interface
+
+    def encode(self, x, *args, **kwargs):
+        """Return ``x`` unchanged."""
+        return x
+
+    def decode(self, x, *args, **kwargs):
+        """Return ``x`` unchanged."""
+        return x
+
+    def quantize(self, x, *args, **kwargs):
+        """Return ``x``, or ``(x, None, [None, None, None])`` when ``vq_interface`` is set."""
+        if self.vq_interface:
+            return x, None, [None, None, None]
+        return x
+
+    def forward(self, x, *args, **kwargs):
+        """Return ``x`` unchanged."""
+        return x
diff --git a/ControlNet/ldm/models/diffusion/__init__.py b/ControlNet/ldm/models/diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ControlNet/ldm/models/diffusion/ddim.py b/ControlNet/ldm/models/diffusion/ddim.py
new file mode 100644
index 0000000000000000000000000000000000000000..27ead0ea914c64c747b64e690662899fb3801144
--- /dev/null
+++ b/ControlNet/ldm/models/diffusion/ddim.py
@@ -0,0 +1,336 @@
+import torch
+import numpy as np
+from tqdm import tqdm
+from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
+class DDIMSampler(object):
+    """DDIM sampler built around a diffusion ``model``.
+
+    The model is read for ``num_timesteps``, ``betas``, ``alphas_cumprod``,
+    ``alphas_cumprod_prev``, ``device`` and ``parameterization`` and is invoked
+    through ``apply_model`` (plus ``q_sample`` when a mask is supplied).
+    """
+
+    def __init__(self, model, schedule="linear", **kwargs):
+        """Store the model, its number of training timesteps and the schedule name."""
+        super().__init__()
+        self.model = model
+        self.ddpm_num_timesteps = model.num_timesteps
+        self.schedule = schedule
+
+    def register_buffer(self, name, attr):
+        """Attach ``attr`` to the sampler as ``name``; tensors are moved to the model's device.
+
+        Non-tensor values (for example NumPy arrays) are stored as they are.
+        """
+        if isinstance(attr, torch.Tensor):
+            # Follow the model instead of assuming CUDA so CPU-only setups work too.
+            attr = attr.to(self.model.device)
+        setattr(self, name, attr)
+
+    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+        """Compute the DDIM timestep subset and register all schedule tables on the sampler."""
+        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+                                                  num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
+        alphas_cumprod = self.model.alphas_cumprod
+        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+        self.register_buffer('betas', to_torch(self.model.betas))
+        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
+        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+
+        # ddim sampling parameters
+        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
+                                                                                   ddim_timesteps=self.ddim_timesteps,
+                                                                                   eta=ddim_eta, verbose=verbose)
+        self.register_buffer('ddim_sigmas', ddim_sigmas)
+        self.register_buffer('ddim_alphas', ddim_alphas)
+        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+                    1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+
+    @torch.no_grad()
+    def sample(self,
+               S,
+               batch_size,
+               shape,
+               conditioning=None,
+               callback=None,
+               normals_sequence=None,
+               img_callback=None,
+               quantize_x0=False,
+               eta=0.,
+               mask=None,
+               x0=None,
+               temperature=1.,
+               noise_dropout=0.,
+               score_corrector=None,
+               corrector_kwargs=None,
+               verbose=True,
+               x_T=None,
+               log_every_t=100,
+               unconditional_guidance_scale=1.,
+               unconditional_conditioning=None,  # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+               dynamic_threshold=None,
+               ucg_schedule=None,
+               **kwargs
+               ):
+        """Build a schedule with ``S`` DDIM steps and sample a batch of the given ``shape``.
+
+        :param S: number of DDIM steps.
+        :param batch_size: number of samples; a warning is printed when the conditioning disagrees.
+        :param shape: ``(C, H, W)`` of a single sample.
+        :param conditioning: tensor, list of tensors, or dict of tensors / nested lists of tensors.
+        :return: ``(samples, intermediates)`` as produced by ``ddim_sampling``.
+        """
+        if conditioning is not None:
+            if isinstance(conditioning, dict):
+                ctmp = conditioning[list(conditioning.keys())[0]]
+                while isinstance(ctmp, list):
+                    ctmp = ctmp[0]
+                cbs = ctmp.shape[0]
+                if cbs != batch_size:
+                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+            elif isinstance(conditioning, list):
+                for ctmp in conditioning:
+                    if ctmp.shape[0] != batch_size:
+                        # Report this entry's own batch size (``cbs`` is not defined in this branch).
+                        print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
+            else:
+                if conditioning.shape[0] != batch_size:
+                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+        # sampling
+        C, H, W = shape
+        size = (batch_size, C, H, W)
+        print(f'Data shape for DDIM sampling is {size}, eta {eta}')
+
+        samples, intermediates = self.ddim_sampling(conditioning, size,
+                                                    callback=callback,
+                                                    img_callback=img_callback,
+                                                    quantize_denoised=quantize_x0,
+                                                    mask=mask, x0=x0,
+                                                    ddim_use_original_steps=False,
+                                                    noise_dropout=noise_dropout,
+                                                    temperature=temperature,
+                                                    score_corrector=score_corrector,
+                                                    corrector_kwargs=corrector_kwargs,
+                                                    x_T=x_T,
+                                                    log_every_t=log_every_t,
+                                                    unconditional_guidance_scale=unconditional_guidance_scale,
+                                                    unconditional_conditioning=unconditional_conditioning,
+                                                    dynamic_threshold=dynamic_threshold,
+                                                    ucg_schedule=ucg_schedule
+                                                    )
+        return samples, intermediates
+
+    @torch.no_grad()
+    def ddim_sampling(self, cond, shape,
+                      x_T=None, ddim_use_original_steps=False,
+                      callback=None, timesteps=None, quantize_denoised=False,
+                      mask=None, x0=None, img_callback=None, log_every_t=100,
+                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                      unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
+                      ucg_schedule=None):
+        """Run the reverse sampling loop, starting from ``x_T`` or fresh Gaussian noise.
+
+        :param mask: when given, ``x0`` is required and the masked region is
+            overwritten with ``model.q_sample(x0, ts)`` before every step.
+        :param ucg_schedule: optional per-step guidance scales; its length must equal the step count.
+        :return: ``(img, intermediates)`` where ``intermediates`` holds the lists
+            ``'x_inter'`` and ``'pred_x0'``.
+        """
+        device = self.model.betas.device
+        b = shape[0]
+        if x_T is None:
+            img = torch.randn(shape, device=device)
+        else:
+            img = x_T
+
+        if timesteps is None:
+            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+        elif timesteps is not None and not ddim_use_original_steps:
+            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+            timesteps = self.ddim_timesteps[:subset_end]
+
+        intermediates = {'x_inter': [img], 'pred_x0': [img]}
+        time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
+        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+        print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
+
+        for i, step in enumerate(iterator):
+            index = total_steps - i - 1
+            ts = torch.full((b,), step, device=device, dtype=torch.long)
+
+            if mask is not None:
+                assert x0 is not None
+                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
+                img = img_orig * mask + (1. - mask) * img
+
+            if ucg_schedule is not None:
+                # ``time_range`` can be a reversed-range iterator without ``len``; use the step count.
+                assert len(ucg_schedule) == total_steps
+                unconditional_guidance_scale = ucg_schedule[i]
+
+            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+                                      quantize_denoised=quantize_denoised, temperature=temperature,
+                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
+                                      corrector_kwargs=corrector_kwargs,
+                                      unconditional_guidance_scale=unconditional_guidance_scale,
+                                      unconditional_conditioning=unconditional_conditioning,
+                                      dynamic_threshold=dynamic_threshold)
+            img, pred_x0 = outs
+            if callback:
+                callback(i)
+            if img_callback:
+                img_callback(pred_x0, i)
+
+            if index % log_every_t == 0 or index == total_steps - 1:
+                intermediates['x_inter'].append(img)
+                intermediates['pred_x0'].append(pred_x0)
+
+        return img, intermediates
+
+    @torch.no_grad()
+    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                      unconditional_guidance_scale=1., unconditional_conditioning=None,
+                      dynamic_threshold=None):
+        """Perform one DDIM update and return ``(x_prev, pred_x0)``.
+
+        When an unconditional conditioning is given and the guidance scale is not 1,
+        the model is evaluated on a doubled batch (unconditional first, conditional
+        second) and the two outputs are blended with the guidance scale.
+
+        :raises NotImplementedError: when ``dynamic_threshold`` is not None.
+        """
+        b, *_, device = *x.shape, x.device
+
+        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+            model_output = self.model.apply_model(x, t, c)
+        else:
+            x_in = torch.cat([x] * 2)
+            t_in = torch.cat([t] * 2)
+            if isinstance(c, dict):
+                assert isinstance(unconditional_conditioning, dict)
+                c_in = dict()
+                for k in c:
+                    if isinstance(c[k], list):
+                        c_in[k] = [torch.cat([
+                            unconditional_conditioning[k][i],
+                            c[k][i]]) for i in range(len(c[k]))]
+                    else:
+                        c_in[k] = torch.cat([
+                            unconditional_conditioning[k],
+                            c[k]])
+            elif isinstance(c, list):
+                c_in = list()
+                assert isinstance(unconditional_conditioning, list)
+                for i in range(len(c)):
+                    c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
+            else:
+                c_in = torch.cat([unconditional_conditioning, c])
+            model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+            model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
+
+        if self.model.parameterization == "v":
+            e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
+        else:
+            e_t = model_output
+
+        if score_corrector is not None:
+            assert self.model.parameterization == "eps", 'not implemented'
+            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+        # NOTE(review): make_schedule registers ddim_sigmas_for_original_num_steps on the
+        # sampler itself; confirm the model really exposes an attribute of that name.
+        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+        # select parameters corresponding to the currently considered timestep
+        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
+
+        # current prediction for x_0
+        if self.model.parameterization != "v":
+            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+        else:
+            pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
+
+        if quantize_denoised:
+            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+
+        if dynamic_threshold is not None:
+            raise NotImplementedError()
+
+        # direction pointing to x_t
+        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+        if noise_dropout > 0.:
+            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+        return x_prev, pred_x0
+
+    @torch.no_grad()
+    def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
+               unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
+        """Step ``x0`` forward through the first ``t_enc`` schedule entries using model predictions.
+
+        :param return_intermediates: when truthy, roughly that many intermediate
+            latents are collected (plus the last two steps).
+        :return: ``(x_next, out)`` where ``out`` holds ``'x_encoded'``,
+            ``'intermediate_steps'`` and, if requested, ``'intermediates'``.
+        """
+        num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
+
+        assert t_enc <= num_reference_steps
+        num_steps = t_enc
+
+        if use_original_steps:
+            alphas_next = self.alphas_cumprod[:num_steps]
+            alphas = self.alphas_cumprod_prev[:num_steps]
+        else:
+            alphas_next = self.ddim_alphas[:num_steps]
+            alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
+
+        x_next = x0
+        intermediates = []
+        inter_steps = []
+        for i in tqdm(range(num_steps), desc='Encoding Image'):
+            t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
+            if unconditional_guidance_scale == 1.:
+                noise_pred = self.model.apply_model(x_next, t, c)
+            else:
+                assert unconditional_conditioning is not None
+                e_t_uncond, noise_pred = torch.chunk(
+                    self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
+                                           torch.cat((unconditional_conditioning, c))), 2)
+                noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
+
+            xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
+            weighted_noise_pred = alphas_next[i].sqrt() * (
+                    (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
+            x_next = xt_weighted + weighted_noise_pred
+            if return_intermediates and i % (
+                    num_steps // return_intermediates) == 0 and i < num_steps - 1:
+                intermediates.append(x_next)
+                inter_steps.append(i)
+            elif return_intermediates and i >= num_steps - 2:
+                intermediates.append(x_next)
+                inter_steps.append(i)
+            if callback:
+                callback(i)
+
+        out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
+        if return_intermediates:
+            out.update({'intermediates': intermediates})
+        return x_next, out
+
+    @torch.no_grad()
+    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
+        """Noise ``x0`` in one step using the coefficients gathered at index ``t``.
+
+        ``noise`` is drawn with ``torch.randn_like(x0)`` when it is None.
+        """
+        # fast, but does not allow for exact reconstruction
+        # t serves as an index to gather the correct alphas
+        if use_original_steps:
+            sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
+            sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
+        else:
+            sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+            sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
+
+        if noise is None:
+            noise = torch.randn_like(x0)
+        return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
+                extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
+
+    @torch.no_grad()
+    def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
+               use_original_steps=False, callback=None):
+        """Run ``p_sample_ddim`` backwards over the first ``t_start`` schedule entries and return the result."""
+        timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
+        timesteps = timesteps[:t_start]
+
+        time_range = np.flip(timesteps)
+        total_steps = timesteps.shape[0]
+        print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+        iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
+        x_dec = x_latent
+        for i, step in enumerate(iterator):
+            index = total_steps - i - 1
+            ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
+            x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
+                                          unconditional_guidance_scale=unconditional_guidance_scale,
+                                          unconditional_conditioning=unconditional_conditioning)
+            if callback:
+                callback(i)
+        return x_dec
\ No newline at end of file
diff --git a/ControlNet/ldm/models/diffusion/ddpm.py b/ControlNet/ldm/models/diffusion/ddpm.py
new file mode 100644
index 0000000000000000000000000000000000000000..f71a44af48c8cba8e97849b7e6813b3e6f9fe83c
--- /dev/null
+++ b/ControlNet/ldm/models/diffusion/ddpm.py
@@ -0,0 +1,1797 @@
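+"""
+Wild mixture of code adapted from lucidrains/denoising-diffusion-pytorch,
+openai/improved-diffusion and CompVis/taming-transformers (the original
+source links are missing from this copy).
+-- merci
+"""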
+import torch
+import torch.nn as nn
+import numpy as np
+import pytorch_lightning as pl
+from torch.optim.lr_scheduler import LambdaLR
+from einops import rearrange, repeat
+from contextlib import contextmanager, nullcontext
+from functools import partial
+import itertools
+from tqdm import tqdm
+from torchvision.utils import make_grid
+from pytorch_lightning.utilities.distributed import rank_zero_only
+from omegaconf import ListConfig
+from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
+from ldm.modules.ema import LitEma
+from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
+from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL
+from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
+from ldm.models.diffusion.ddim import DDIMSampler
# Maps a conditioning mode name (the wrapper's ``conditioning_key``) to the
# dictionary key under which that conditioning is stored.
__conditioning_keys__ = dict(
    concat='c_concat',
    crossattn='c_crossattn',
    adm='y',
)
def disabled_train(self, mode=True):
    """No-op replacement for ``model.train``.

    Assign this over a module's ``train`` attribute so that later attempts to
    switch between train and eval mode have no effect. ``mode`` is accepted
    for signature compatibility and ignored.

    :return: ``self``, unchanged.
    """
    return self
def uniform_on_device(r1, r2, shape, device):
    """Sample ``torch.rand`` on ``device`` and map it affinely onto the r2..r1 span.

    :param r1: value the samples approach as the underlying draw approaches 1.
    :param r2: value returned where the underlying draw is 0.
    :param shape: iterable of dimension sizes, unpacked into ``torch.rand``.
    :param device: device on which the samples are created.
    :return: ``(r1 - r2) * u + r2`` with ``u = torch.rand(*shape)``.
    """
    unit_samples = torch.rand(*shape, device=device)
    span = r1 - r2
    return span * unit_samples + r2
+class DDPM(pl.LightningModule):
+ # classic DDPM with Gaussian diffusion, in image space
def __init__(self,
             unet_config,
             timesteps=1000,
             beta_schedule="linear",
             loss_type="l2",
             ckpt_path=None,
             ignore_keys=[],
             load_only_unet=False,
             monitor="val/loss",
             use_ema=True,
             first_stage_key="image",
             image_size=256,
             channels=3,
             log_every_t=100,
             clip_denoised=True,
             linear_start=1e-4,
             linear_end=2e-2,
             cosine_s=8e-3,
             given_betas=None,
             original_elbo_weight=0.,
             v_posterior=0.,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
             l_simple_weight=1.,
             conditioning_key=None,
             parameterization="eps",  # all assuming fixed variance schedules
             scheduler_config=None,
             use_positional_encodings=False,
             learn_logvar=False,
             logvar_init=0.,
             make_it_fit=False,
             ucg_training=None,
             reset_ema=False,
             reset_num_ema_updates=False,
             ):
    """Build the wrapped denoiser, the optional EMA copy and the noise schedule.

    Most arguments are stored on ``self`` unchanged; the ones with extra
    behaviour are:

    :param unet_config: config handed to ``DiffusionWrapper`` together with
        ``conditioning_key``.
    :param ckpt_path: if given, weights are restored via ``init_from_ckpt``
        (``ignore_keys`` and ``load_only_unet`` are forwarded to it).
    :param use_ema: keep a ``LitEma`` shadow of ``self.model``.
    :param reset_ema: after loading ``ckpt_path`` (which must then be given),
        rebuild the EMA from the freshly loaded model weights.
    :param reset_num_ema_updates: reset the EMA update counter.
    :param parameterization: one of ``"eps"``, ``"x0"``, ``"v"``.
    :param scheduler_config: stored only when not ``None``; its presence sets
        ``self.use_scheduler``.
    :param learn_logvar: register ``logvar`` as a trainable parameter instead
        of a buffer.
    :param logvar_init: fill value of the per-timestep ``logvar`` vector.
    :param ucg_training: optional mapping consumed by ``training_step``; a
        ``RandomState`` is created only when it is non-empty.
    """
    super().__init__()
    assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"'
    self.parameterization = parameterization
    print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
    self.cond_stage_model = None
    self.clip_denoised = clip_denoised
    self.log_every_t = log_every_t
    self.first_stage_key = first_stage_key
    self.image_size = image_size  # try conv?
    self.channels = channels
    self.use_positional_encodings = use_positional_encodings
    self.model = DiffusionWrapper(unet_config, conditioning_key)
    count_params(self.model, verbose=True)
    self.use_ema = use_ema
    if self.use_ema:
        self.model_ema = LitEma(self.model)
        print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

    self.use_scheduler = scheduler_config is not None
    if self.use_scheduler:
        self.scheduler_config = scheduler_config

    # v_posterior must be set before register_schedule(), which reads it.
    self.v_posterior = v_posterior
    self.original_elbo_weight = original_elbo_weight
    self.l_simple_weight = l_simple_weight

    if monitor is not None:
        self.monitor = monitor
    # make_it_fit must be set before init_from_ckpt(), which reads it.
    self.make_it_fit = make_it_fit
    if reset_ema:
        assert exists(ckpt_path)
    if ckpt_path is not None:
        self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
        if reset_ema:
            assert self.use_ema
            print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
            self.model_ema = LitEma(self.model)
    if reset_num_ema_updates:
        print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
        assert self.use_ema
        self.model_ema.reset_num_updates()

    self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
                           linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)

    self.loss_type = loss_type

    self.learn_logvar = learn_logvar
    logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
    if self.learn_logvar:
        # Wrap the freshly built local tensor: ``self.logvar`` does not exist
        # yet at this point, so reading it here raised AttributeError.
        self.logvar = nn.Parameter(logvar, requires_grad=True)
    else:
        self.register_buffer('logvar', logvar)

    self.ucg_training = ucg_training or dict()
    if self.ucg_training:
        self.ucg_prng = np.random.RandomState()
def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
                      linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    """Compute the beta schedule and register every derived table as a buffer.

    :param given_betas: explicit beta array; when it "exists" it replaces the
        schedule that ``make_beta_schedule`` would otherwise build.
    :param beta_schedule: schedule name forwarded to ``make_beta_schedule``.
    :param timesteps: requested number of steps; the stored
        ``self.num_timesteps`` is taken from the length of the betas.
    :param linear_start: forwarded to ``make_beta_schedule`` and stored.
    :param linear_end: forwarded to ``make_beta_schedule`` and stored.
    :param cosine_s: forwarded to ``make_beta_schedule``.
    :raises NotImplementedError: for an unknown ``self.parameterization``.
    """
    betas = given_betas if exists(given_betas) else make_beta_schedule(
        beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
    alphas = 1. - betas
    alphas_cumprod = np.cumprod(alphas, axis=0)
    alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

    (timesteps,) = betas.shape
    self.num_timesteps = int(timesteps)
    self.linear_start = linear_start
    self.linear_end = linear_end
    assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

    to_torch = partial(torch.tensor, dtype=torch.float32)

    # base schedule plus calculations for diffusion q(x_t | x_{t-1}) and others
    schedule_tables = (
        ('betas', betas),
        ('alphas_cumprod', alphas_cumprod),
        ('alphas_cumprod_prev', alphas_cumprod_prev),
        ('sqrt_alphas_cumprod', np.sqrt(alphas_cumprod)),
        ('sqrt_one_minus_alphas_cumprod', np.sqrt(1. - alphas_cumprod)),
        ('log_one_minus_alphas_cumprod', np.log(1. - alphas_cumprod)),
        ('sqrt_recip_alphas_cumprod', np.sqrt(1. / alphas_cumprod)),
        ('sqrt_recipm1_alphas_cumprod', np.sqrt(1. / alphas_cumprod - 1)),
    )
    for buffer_name, values in schedule_tables:
        self.register_buffer(buffer_name, to_torch(values))

    # calculations for posterior q(x_{t-1} | x_t, x_0)
    # with v_posterior == 0 this equals 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
    posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
            1. - alphas_cumprod) + self.v_posterior * betas
    posterior_tables = (
        ('posterior_variance', posterior_variance),
        # log is clipped because the posterior variance is 0 at the beginning of the diffusion chain
        ('posterior_log_variance_clipped', np.log(np.maximum(posterior_variance, 1e-20))),
        ('posterior_mean_coef1', betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)),
        ('posterior_mean_coef2', (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)),
    )
    for buffer_name, values in posterior_tables:
        self.register_buffer(buffer_name, to_torch(values))

    mode = self.parameterization
    if mode in ("eps", "v"):
        eps_weights = self.betas ** 2 / (
                2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
        # "v" only borrows the shape/dtype of the eps weights and uses all ones.
        lvlb_weights = eps_weights if mode == "eps" else torch.ones_like(eps_weights)
    elif mode == "x0":
        lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
    else:
        raise NotImplementedError("mu not supported")

    # the first entry is overwritten with the second one before registering
    lvlb_weights[0] = lvlb_weights[1]
    self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
    assert not torch.isnan(self.lvlb_weights).all()
@contextmanager
def ema_scope(self, context=None):
    """Temporarily run with EMA weights inside a ``with`` block.

    When ``self.use_ema`` is set, the current model parameters are stored, the
    EMA weights are copied into the model, and the stored parameters are
    restored on exit (also when the body raises). Without EMA this is a no-op.

    :param context: optional label; when given, the switch and the restore are
        announced on stdout.
    """
    announce = context is not None
    if self.use_ema:
        self.model_ema.store(self.model.parameters())
        self.model_ema.copy_to(self.model)
        if announce:
            print(f"{context}: Switched to EMA weights")
    try:
        yield None
    finally:
        if self.use_ema:
            self.model_ema.restore(self.model.parameters())
            if announce:
                print(f"{context}: Restored training weights")
@torch.no_grad()
def init_from_ckpt(self, path, ignore_keys=(), only_model=False):
    """Restore weights from a checkpoint file, optionally reshaping them to fit.

    :param path: checkpoint path passed to ``torch.load``; a top-level
        ``"state_dict"`` entry is unwrapped when present.
    :param ignore_keys: iterable of key prefixes; every state-dict entry whose
        name starts with one of them is dropped before loading.
    :param only_model: load into ``self.model`` instead of ``self``.

    When ``self.make_it_fit`` is set, checkpoint tensors whose shape differs
    from the current parameter/buffer are tiled (modulo-indexed) along the
    first two axes and rescaled before loading.
    """
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    sd = torch.load(path, map_location="cpu")
    if "state_dict" in list(sd.keys()):
        sd = sd["state_dict"]

    for k in list(sd.keys()):
        # Delete at most once per key: with the delete nested inside the
        # prefix loop, a key matching two prefixes raised KeyError.
        if any(k.startswith(ik) for ik in ignore_keys):
            print("Deleting key {} from state_dict.".format(k))
            del sd[k]

    if self.make_it_fit:
        n_params = sum(1 for _ in itertools.chain(self.named_parameters(),
                                                  self.named_buffers()))
        for name, param in tqdm(
                itertools.chain(self.named_parameters(),
                                self.named_buffers()),
                desc="Fitting old weights to new weights",
                total=n_params
        ):
            if name not in sd:
                continue
            old_shape = sd[name].shape
            new_shape = param.shape
            assert len(old_shape) == len(new_shape)
            if len(new_shape) > 2:
                # we only modify first two axes
                assert new_shape[2:] == old_shape[2:]
            # assumes first axis corresponds to output dim
            if not new_shape == old_shape:
                new_param = param.clone()
                old_param = sd[name]
                if len(new_shape) == 1:
                    for i in range(new_param.shape[0]):
                        new_param[i] = old_param[i % old_shape[0]]
                elif len(new_shape) >= 2:
                    for i in range(new_param.shape[0]):
                        for j in range(new_param.shape[1]):
                            new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]]

                    # count (starting from one) how often each old input
                    # column is reused, then divide the tiled weights by it
                    n_used_old = torch.ones(old_shape[1])
                    for j in range(new_param.shape[1]):
                        n_used_old[j % old_shape[1]] += 1
                    n_used_new = torch.zeros(new_shape[1])
                    for j in range(new_param.shape[1]):
                        n_used_new[j] = n_used_old[j % old_shape[1]]

                    n_used_new = n_used_new[None, :]
                    while len(n_used_new.shape) < len(new_shape):
                        n_used_new = n_used_new.unsqueeze(-1)
                    new_param /= n_used_new

                sd[name] = new_param

    target = self.model if only_model else self
    missing, unexpected = target.load_state_dict(sd, strict=False)
    print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
    if len(missing) > 0:
        print(f"Missing Keys:\n {missing}")
    if len(unexpected) > 0:
        print(f"\nUnexpected Keys:\n {unexpected}")
def q_mean_variance(self, x_start, t):
    """
    Get the distribution q(x_t | x_0).
    :param x_start: the [N x C x ...] tensor of noiseless inputs.
    :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
    :return: A tuple (mean, variance, log_variance), all of x_start's shape.
    """
    shape = x_start.shape
    signal_scale = extract_into_tensor(self.sqrt_alphas_cumprod, t, shape)
    mean = signal_scale * x_start
    variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, shape)
    log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, shape)
    return mean, variance, log_variance
def predict_start_from_noise(self, x_t, t, noise):
    """Recover the x_0 estimate from ``x_t`` and a noise estimate.

    :return: ``sqrt_recip_alphas_cumprod[t] * x_t - sqrt_recipm1_alphas_cumprod[t] * noise``.
    """
    x_scale = extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape)
    noise_scale = extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
    return x_scale * x_t - noise_scale * noise
def predict_start_from_z_and_v(self, x_t, t, v):
    """Recover the x_0 estimate from ``x_t`` and a v-prediction.

    :return: ``sqrt_alphas_cumprod[t] * x_t - sqrt_one_minus_alphas_cumprod[t] * v``.
    """
    signal_scale = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape)
    noise_scale = extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
    return signal_scale * x_t - noise_scale * v
def predict_eps_from_z_and_v(self, x_t, t, v):
    """Recover the noise estimate from ``x_t`` and a v-prediction.

    :return: ``sqrt_alphas_cumprod[t] * v + sqrt_one_minus_alphas_cumprod[t] * x_t``.
    """
    signal_scale = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape)
    noise_scale = extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
    return signal_scale * v + noise_scale * x_t
def q_posterior(self, x_start, x_t, t):
    """Gather the parameters of the posterior q(x_{t-1} | x_t, x_0) at step ``t``.

    :return: tuple ``(posterior_mean, posterior_variance,
        posterior_log_variance_clipped)``, each gathered to ``x_t``'s shape.
    """
    shape = x_t.shape
    coef_start = extract_into_tensor(self.posterior_mean_coef1, t, shape)
    coef_current = extract_into_tensor(self.posterior_mean_coef2, t, shape)
    mean = coef_start * x_start + coef_current * x_t
    variance = extract_into_tensor(self.posterior_variance, t, shape)
    log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, shape)
    return mean, variance, log_variance_clipped
def p_mean_variance(self, x, t, clip_denoised: bool):
    """Run the model on ``x`` and return the posterior parameters it implies.

    The model output is first turned into an x_0 estimate according to
    ``self.parameterization`` ("eps", "x0" or "v"), optionally clamped to
    [-1, 1], and then passed to ``q_posterior``.

    :param x: current noisy tensor.
    :param t: timestep tensor.
    :param clip_denoised: clamp the x_0 estimate in place to [-1, 1].
    :return: ``(model_mean, posterior_variance, posterior_log_variance)``.
    :raises NotImplementedError: for an unsupported parameterization.
    """
    model_out = self.model(x, t)
    if self.parameterization == "eps":
        x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
    elif self.parameterization == "x0":
        x_recon = model_out
    elif self.parameterization == "v":
        # "v" is accepted by the constructor but previously fell through
        # here, leaving x_recon unbound.
        x_recon = self.predict_start_from_z_and_v(x, t, model_out)
    else:
        raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")
    if clip_denoised:
        x_recon.clamp_(-1., 1.)

    model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
    return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
    """Take one reverse-diffusion step from ``x`` at timestep ``t``.

    :param x: current noisy tensor.
    :param t: per-sample timestep tensor; entries equal to 0 get no noise added.
    :param clip_denoised: forwarded to ``p_mean_variance``.
    :param repeat_noise: forwarded to ``noise_like``.
    :return: ``model_mean + mask * exp(0.5 * log_variance) * noise``.
    """
    batch_size, *_, device = *x.shape, x.device
    mean, _, log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
    eps = noise_like(x.shape, device, repeat_noise)
    # no noise when t == 0; the mask is shaped (b, 1, ..., 1) to broadcast over x
    trailing_ones = (1,) * (len(x.shape) - 1)
    keep_noise = (1 - (t == 0).float()).reshape(batch_size, *trailing_ones)
    std = (0.5 * log_variance).exp()
    return mean + keep_noise * std * eps
@torch.no_grad()
def p_sample_loop(self, shape, return_intermediates=False):
    """Sample by running ``p_sample`` from the last timestep down to 0.

    :param shape: shape of the tensor to generate; ``shape[0]`` is the batch size.
    :param return_intermediates: also return the list of snapshots (the
        initial noise plus one every ``self.log_every_t`` steps and at the
        last timestep index).
    :return: the final tensor, or ``(final, intermediates)``.
    """
    device = self.betas.device
    batch_size = shape[0]
    current = torch.randn(shape, device=device)
    snapshots = [current]
    countdown = reversed(range(0, self.num_timesteps))
    for step in tqdm(countdown, desc='Sampling t', total=self.num_timesteps):
        step_batch = torch.full((batch_size,), step, device=device, dtype=torch.long)
        current = self.p_sample(current, step_batch, clip_denoised=self.clip_denoised)
        if step % self.log_every_t == 0 or step == self.num_timesteps - 1:
            snapshots.append(current)
    if return_intermediates:
        return current, snapshots
    return current
@torch.no_grad()
def sample(self, batch_size=16, return_intermediates=False):
    """Generate ``batch_size`` square samples via ``p_sample_loop``.

    The requested shape is ``(batch_size, self.channels, self.image_size,
    self.image_size)``.

    :return: whatever ``p_sample_loop`` returns for that shape.
    """
    side = self.image_size
    n_channels = self.channels
    shape = (batch_size, n_channels, side, side)
    return self.p_sample_loop(shape, return_intermediates=return_intermediates)
def q_sample(self, x_start, t, noise=None):
    """Diffuse ``x_start`` to timestep ``t`` in closed form.

    :param noise: optional noise tensor; a ``torch.randn_like(x_start)`` draw
        is used when it is not supplied.
    :return: ``sqrt_alphas_cumprod[t] * x_start + sqrt_one_minus_alphas_cumprod[t] * noise``.
    """
    noise = default(noise, lambda: torch.randn_like(x_start))
    signal_scale = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape)
    noise_scale = extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
    return signal_scale * x_start + noise_scale * noise
def get_v(self, x, noise, t):
    """Build the v-prediction target for ``x`` and ``noise`` at timestep ``t``.

    :return: ``sqrt_alphas_cumprod[t] * noise - sqrt_one_minus_alphas_cumprod[t] * x``.
    """
    signal_scale = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape)
    noise_scale = extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape)
    return signal_scale * noise - noise_scale * x
def get_loss(self, pred, target, mean=True):
    """Compute the L1 or L2 loss selected by ``self.loss_type``.

    :param pred: model prediction.
    :param target: regression target.
    :param mean: reduce to a scalar mean; otherwise return the element-wise loss.
    :return: loss tensor.
    :raises NotImplementedError: when ``self.loss_type`` is neither ``'l1'``
        nor ``'l2'``.
    """
    if self.loss_type == 'l1':
        loss = (target - pred).abs()
        if mean:
            loss = loss.mean()
    elif self.loss_type == 'l2':
        reduction = 'mean' if mean else 'none'
        loss = torch.nn.functional.mse_loss(target, pred, reduction=reduction)
    else:
        # The message previously lacked the f-prefix (and referenced a name
        # not in scope), so it printed the literal "{loss_type}".
        raise NotImplementedError(f"unknown loss type '{self.loss_type}'")

    return loss
def p_losses(self, x_start, t, noise=None):
    """Compute the training loss for ``x_start`` at timesteps ``t``.

    ``x_start`` is noised with ``q_sample``, the model is run on the result
    and its output is compared to the target implied by
    ``self.parameterization`` (noise, ``x_start`` or the v-target).

    :param noise: optional noise tensor; drawn with ``torch.randn_like`` when
        not supplied.
    :return: ``(loss, loss_dict)`` where ``loss_dict`` holds the
        ``loss_simple``, ``loss_vlb`` and ``loss`` entries under a
        ``train/`` or ``val/`` prefix.
    :raises NotImplementedError: for an unsupported parameterization.
    """
    noise = default(noise, lambda: torch.randn_like(x_start))
    x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
    prediction = self.model(x_noisy, t)

    mode = self.parameterization
    if mode == "eps":
        target = noise
    elif mode == "x0":
        target = x_start
    elif mode == "v":
        target = self.get_v(x_start, noise, t)
    else:
        raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")

    # per-sample loss, averaged over every non-batch dimension
    per_sample = self.get_loss(prediction, target, mean=False).mean(dim=[1, 2, 3])

    log_prefix = 'train' if self.training else 'val'
    metrics = {}
    metrics[f'{log_prefix}/loss_simple'] = per_sample.mean()
    weighted_simple = per_sample.mean() * self.l_simple_weight

    vlb_term = (self.lvlb_weights[t] * per_sample).mean()
    metrics[f'{log_prefix}/loss_vlb'] = vlb_term

    total = weighted_simple + self.original_elbo_weight * vlb_term
    metrics[f'{log_prefix}/loss'] = total

    return total, metrics
def forward(self, x, *args, **kwargs):
    """Draw one random timestep per sample and return ``p_losses`` for it.

    Timesteps are sampled uniformly from ``[0, self.num_timesteps)`` on
    ``self.device``; extra positional/keyword arguments go to ``p_losses``.
    """
    # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
    # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
    batch_size = x.shape[0]
    timesteps = torch.randint(0, self.num_timesteps, (batch_size,), device=self.device).long()
    return self.p_losses(x, timesteps, *args, **kwargs)
def get_input(self, batch, k):
    """Fetch ``batch[k]`` and return it as a contiguous float tensor, channels first.

    A 3-D entry gets a trailing singleton axis appended before the
    ``'b h w c -> b c h w'`` rearrangement.
    """
    images = batch[k]
    if len(images.shape) == 3:
        images = images[..., None]
    channels_first = rearrange(images, 'b h w c -> b c h w')
    return channels_first.to(memory_format=torch.contiguous_format).float()
def shared_step(self, batch):
    """Extract the first-stage input from ``batch`` and run the module on it.

    :return: the ``(loss, loss_dict)`` pair produced by calling ``self``.
    """
    inputs = self.get_input(batch, self.first_stage_key)
    total, metrics = self(inputs)
    return total, metrics
def training_step(self, batch, batch_idx):
    """Run one training step and log its metrics.

    Before the loss is computed, every key configured in ``self.ucg_training``
    has each of its batch entries replaced by the configured ``"val"`` (an
    empty string when that is ``None``) with probability ``"p"``.

    :return: the loss from ``shared_step``.
    """
    for key in self.ucg_training:
        spec = self.ucg_training[key]
        drop_prob = spec["p"]
        replacement = spec["val"]
        if replacement is None:
            replacement = ""
        for idx in range(len(batch[key])):
            if self.ucg_prng.choice(2, p=[1 - drop_prob, drop_prob]):
                batch[key][idx] = replacement

    loss, loss_dict = self.shared_step(batch)

    self.log_dict(loss_dict, prog_bar=True,
                  logger=True, on_step=True, on_epoch=True)

    step_only = dict(prog_bar=True, logger=True, on_step=True, on_epoch=False)
    self.log("global_step", self.global_step, **step_only)

    if self.use_scheduler:
        current_lr = self.optimizers().param_groups[0]['lr']
        self.log('lr_abs', current_lr, **step_only)

    return loss
@torch.no_grad()
def validation_step(self, batch, batch_idx):
    """Log validation metrics for both the current and the EMA weights.

    ``shared_step`` is evaluated once as-is and once inside ``ema_scope``; the
    second set of metrics is logged with an ``_ema`` suffix on every key.
    """
    _, metrics_plain = self.shared_step(batch)
    with self.ema_scope():
        _, metrics_ema = self.shared_step(batch)
        metrics_ema = {name + '_ema': value for name, value in metrics_ema.items()}

    epoch_only = dict(prog_bar=False, logger=True, on_step=False, on_epoch=True)
    self.log_dict(metrics_plain, **epoch_only)
    self.log_dict(metrics_ema, **epoch_only)
def on_train_batch_end(self, *args, **kwargs):
    """After each training batch, update the EMA with the current model (if EMA is on)."""
    if not self.use_ema:
        return
    self.model_ema(self.model)
def _get_rows_from_list(self, samples):
    """Arrange a list of image batches into one grid image.

    The ``n`` batches are interleaved so that each batch element's ``n``
    images are adjacent, then laid out with ``n`` images per grid row.

    :param samples: sequence of ``n`` batches, each ``b c h w``.
    :return: the grid produced by ``make_grid``.
    """
    images_per_row = len(samples)
    # same as 'n b c h w -> b n c h w' followed by 'b n c h w -> (b n) c h w'
    flattened = rearrange(samples, 'n b c h w -> (b n) c h w')
    return make_grid(flattened, nrow=images_per_row)
@torch.no_grad()
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
    """Collect inputs, a forward-diffusion row and (optionally) samples for logging.

    :param batch: batch dict read through ``get_input``.
    :param N: maximum number of inputs / samples to log.
    :param n_row: maximum number of inputs used for the diffusion row.
    :param sample: also draw samples under EMA weights and log them together
        with their denoising row.
    :param return_keys: optional list of keys; when at least one of them is
        present in the log, only those keys are returned.
    :return: dict of logged tensors.
    """
    log = dict()
    inputs = self.get_input(batch, self.first_stage_key)
    N = min(inputs.shape[0], N)
    n_row = min(inputs.shape[0], n_row)
    inputs = inputs.to(self.device)[:N]
    log["inputs"] = inputs

    # get diffusion row: noise the first n_row inputs at regularly spaced steps
    diffusion_row = list()
    x_start = inputs[:n_row]
    for step in range(self.num_timesteps):
        if not (step % self.log_every_t == 0 or step == self.num_timesteps - 1):
            continue
        t_batch = repeat(torch.tensor([step]), '1 -> b', b=n_row)
        t_batch = t_batch.to(self.device).long()
        noise = torch.randn_like(x_start)
        diffusion_row.append(self.q_sample(x_start=x_start, t=t_batch, noise=noise))

    log["diffusion_row"] = self._get_rows_from_list(diffusion_row)

    if sample:
        # get denoise row
        with self.ema_scope("Plotting"):
            samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)

        log["samples"] = samples
        log["denoise_row"] = self._get_rows_from_list(denoise_row)

    if not return_keys:
        return log
    overlap = np.intersect1d(list(log.keys()), return_keys)
    if overlap.shape[0] == 0:
        return log
    return {key: log[key] for key in return_keys}
def configure_optimizers(self):
    """Create an AdamW optimizer over the model parameters.

    ``self.logvar`` is appended to the optimized parameters when
    ``self.learn_logvar`` is set. The learning rate is ``self.learning_rate``.
    """
    learning_rate = self.learning_rate
    trainable = list(self.model.parameters())
    if self.learn_logvar:
        trainable = trainable + [self.logvar]
    return torch.optim.AdamW(trainable, lr=learning_rate)
+class LatentDiffusion(DDPM):
+ """main class"""
def __init__(self,
             first_stage_config,
             cond_stage_config,
             num_timesteps_cond=None,
             cond_stage_key="image",
             cond_stage_trainable=False,
             concat_mode=True,
             cond_stage_forward=None,
             conditioning_key=None,
             scale_factor=1.0,
             scale_by_std=False,
             force_null_conditioning=False,
             *args, **kwargs):
    """Set up the latent diffusion model around a first stage and a conditioner.

    :param first_stage_config: config for ``instantiate_first_stage``.
    :param cond_stage_config: config for ``instantiate_cond_stage``; the
        string ``'__is_unconditional__'`` disables conditioning unless
        ``force_null_conditioning`` is set.
    :param num_timesteps_cond: number of conditioning timesteps (defaults to 1).
    :param cond_stage_key: batch key of the conditioning input.
    :param cond_stage_trainable: whether the conditioner is trained.
    :param concat_mode: selects ``'concat'`` vs ``'crossattn'`` when
        ``conditioning_key`` is not given.
    :param cond_stage_forward: optional name of the conditioner method to call.
    :param conditioning_key: explicit conditioning mode for the wrapper.
    :param scale_factor: latent scale; registered as a buffer when
        ``scale_by_std`` is set.
    :param scale_by_std: enables the std-based rescaling handled in
        ``on_train_batch_start``.
    :param kwargs: forwarded to the parent constructor, except ``ckpt_path``,
        ``reset_ema``, ``reset_num_ema_updates`` and ``ignore_keys``, which
        are handled here after both stages have been instantiated.
    """
    # These three are read by register_schedule(), which the parent
    # constructor calls, so they must be set before super().__init__().
    self.force_null_conditioning = force_null_conditioning
    self.num_timesteps_cond = default(num_timesteps_cond, 1)
    self.scale_by_std = scale_by_std
    # Fall back to the parent's default when the caller did not pass
    # ``timesteps`` instead of failing with a KeyError.
    assert self.num_timesteps_cond <= kwargs.get('timesteps', 1000)
    # for backwards compatibility after implementation of DiffusionWrapper
    if conditioning_key is None:
        conditioning_key = 'concat' if concat_mode else 'crossattn'
    if cond_stage_config == '__is_unconditional__' and not self.force_null_conditioning:
        conditioning_key = None
    ckpt_path = kwargs.pop("ckpt_path", None)
    reset_ema = kwargs.pop("reset_ema", False)
    reset_num_ema_updates = kwargs.pop("reset_num_ema_updates", False)
    ignore_keys = kwargs.pop("ignore_keys", [])
    super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
    self.concat_mode = concat_mode
    self.cond_stage_trainable = cond_stage_trainable
    self.cond_stage_key = cond_stage_key
    try:
        self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
    except Exception:
        # Best effort: configs without a ddconfig.ch_mult entry get 0. Only
        # ordinary errors are swallowed, not KeyboardInterrupt/SystemExit.
        self.num_downs = 0
    if not scale_by_std:
        self.scale_factor = scale_factor
    else:
        self.register_buffer('scale_factor', torch.tensor(scale_factor))
    self.instantiate_first_stage(first_stage_config)
    self.instantiate_cond_stage(cond_stage_config)
    self.cond_stage_forward = cond_stage_forward
    self.clip_denoised = False
    self.bbox_tokenizer = None

    self.restarted_from_ckpt = False
    if ckpt_path is not None:
        self.init_from_ckpt(ckpt_path, ignore_keys)
        self.restarted_from_ckpt = True
        if reset_ema:
            assert self.use_ema
            print(
                f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
            self.model_ema = LitEma(self.model)
    if reset_num_ema_updates:
        print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
        assert self.use_ema
        self.model_ema.reset_num_updates()
def make_cond_schedule(self, ):
    """Build ``self.cond_ids``, the timestep lookup used for noising the conditioning.

    The table has ``num_timesteps`` entries, all set to the last timestep
    index; the first ``num_timesteps_cond`` entries are then overwritten with
    evenly spaced (rounded) indices from 0 to the last timestep index.
    """
    last_step = self.num_timesteps - 1
    cond_ids = torch.full(size=(self.num_timesteps,), fill_value=last_step, dtype=torch.long)
    self.cond_ids = cond_ids
    spaced = torch.linspace(0, last_step, self.num_timesteps_cond)
    cond_ids[:self.num_timesteps_cond] = torch.round(spaced).long()
@rank_zero_only
@torch.no_grad()
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
    """On the very first batch of a fresh run, set ``scale_factor`` to 1/std of the latents.

    Only active when ``scale_by_std`` is set, training starts at epoch 0 /
    global step 0 / batch 0, and the model was not restored from a checkpoint.
    """
    # only for very first batch
    is_first_fresh_batch = (self.scale_by_std and self.current_epoch == 0 and self.global_step == 0
                            and batch_idx == 0 and not self.restarted_from_ckpt)
    if not is_first_fresh_batch:
        return

    assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
    # set rescale weight to 1./std of encodings
    print("### USING STD-RESCALING ###")
    images = super().get_input(batch, self.first_stage_key)
    images = images.to(self.device)
    posterior = self.encode_first_stage(images)
    latents = self.get_first_stage_encoding(posterior).detach()
    # the attribute is removed first so it can be re-registered as a buffer
    del self.scale_factor
    self.register_buffer('scale_factor', 1. / latents.flatten().std())
    print(f"setting self.scale_factor to {self.scale_factor}")
    print("### USING STD-RESCALING ###")
def register_schedule(self,
                      given_betas=None, beta_schedule="linear", timesteps=1000,
                      linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    """Register the parent schedule, then build the conditioning schedule if needed.

    ``self.shorten_cond_schedule`` is set when ``num_timesteps_cond > 1``; in
    that case ``make_cond_schedule`` is called as well.
    """
    super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)

    needs_cond_schedule = self.num_timesteps_cond > 1
    self.shorten_cond_schedule = needs_cond_schedule
    if needs_cond_schedule:
        self.make_cond_schedule()
def instantiate_first_stage(self, config):
    """Instantiate the first-stage model from ``config`` and freeze it.

    The model is put in eval mode, its ``train`` is replaced by
    ``disabled_train`` so the mode cannot be switched back, and gradients are
    disabled for all of its parameters.
    """
    first_stage = instantiate_from_config(config)
    self.first_stage_model = first_stage.eval()
    self.first_stage_model.train = disabled_train
    for weight in self.first_stage_model.parameters():
        weight.requires_grad = False
def instantiate_cond_stage(self, config):
    """Set ``self.cond_stage_model`` according to ``config``.

    A trainable conditioner is instantiated from ``config`` as-is (the two
    marker strings are not allowed then). Otherwise ``"__is_first_stage__"``
    reuses the first-stage model, ``"__is_unconditional__"`` sets no
    conditioner, and any other config is instantiated and frozen.
    """
    if self.cond_stage_trainable:
        assert config != '__is_first_stage__'
        assert config != '__is_unconditional__'
        self.cond_stage_model = instantiate_from_config(config)
        return

    if config == "__is_first_stage__":
        print("Using first stage also as cond stage.")
        self.cond_stage_model = self.first_stage_model
        return

    if config == "__is_unconditional__":
        print(f"Training {self.__class__.__name__} as an unconditional model.")
        self.cond_stage_model = None
        # self.be_unconditional = True
        return

    conditioner = instantiate_from_config(config)
    self.cond_stage_model = conditioner.eval()
    self.cond_stage_model.train = disabled_train
    for weight in self.cond_stage_model.parameters():
        weight.requires_grad = False
def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
    """Decode a list of latent batches and arrange the images into one grid.

    :param samples: iterable of latent batches; each is moved to
        ``self.device`` and decoded with ``decode_first_stage``.
    :param desc: progress-bar description.
    :param force_no_decoder_quantization: forwarded as ``force_not_quantize``.
    :return: the grid produced by ``make_grid`` with one image per logged
        step in each row.
    """
    decoded = [
        self.decode_first_stage(latent.to(self.device),
                                force_not_quantize=force_no_decoder_quantization)
        for latent in tqdm(samples, desc=desc)
    ]
    images_per_row = len(decoded)
    stacked = torch.stack(decoded)  # n_log_step, n_row, C, H, W
    # same as 'n b c h w -> b n c h w' followed by 'b n c h w -> (b n) c h w'
    flattened = rearrange(stacked, 'n b c h w -> (b n) c h w')
    return make_grid(flattened, nrow=images_per_row)
def get_first_stage_encoding(self, encoder_posterior):
    """Turn the first-stage encoder output into a scaled latent.

    A ``DiagonalGaussianDistribution`` is sampled; a plain tensor is used
    as-is. The result is multiplied by ``self.scale_factor``.

    :raises NotImplementedError: for any other input type.
    """
    if not isinstance(encoder_posterior, (DiagonalGaussianDistribution, torch.Tensor)):
        raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")

    if isinstance(encoder_posterior, DiagonalGaussianDistribution):
        latent = encoder_posterior.sample()
    else:
        latent = encoder_posterior
    return self.scale_factor * latent
def get_learned_conditioning(self, c):
    """Encode the raw conditioning ``c`` with the conditioning model.

    If ``self.cond_stage_forward`` names a method, that method is called.
    Otherwise the model's ``encode`` is used when it exists and is callable
    (taking the mode of a ``DiagonalGaussianDistribution`` result), and the
    model itself is called as a fallback.
    """
    encoder = self.cond_stage_model
    forward_name = self.cond_stage_forward

    if forward_name is not None:
        assert hasattr(encoder, forward_name)
        return getattr(encoder, forward_name)(c)

    if not (hasattr(encoder, 'encode') and callable(encoder.encode)):
        return encoder(c)

    encoded = encoder.encode(c)
    if isinstance(encoded, DiagonalGaussianDistribution):
        encoded = encoded.mode()
    return encoded
def meshgrid(self, h, w):
    """Return an ``(h, w, 2)`` integer tensor holding ``(row, column)`` at every position."""
    rows = torch.arange(0, h).view(h, 1).expand(h, w)
    cols = torch.arange(0, w).view(1, w).expand(h, w)
    return torch.stack([rows, cols], dim=-1)
def delta_border(self, h, w):
    """
    :param h: height
    :param w: width
    :return: normalized distance to image border,
     with min distance = 0 at border and max dist = 0.5 at image center
    """
    far_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
    # coordinates rescaled so both axes run from 0 to 1
    coords = self.meshgrid(h, w) / far_corner
    to_top_left = torch.min(coords, dim=-1, keepdims=True)[0]
    to_bottom_right = torch.min(1 - coords, dim=-1, keepdims=True)[0]
    both = torch.cat([to_top_left, to_bottom_right], dim=-1)
    return torch.min(both, dim=-1)[0]
def get_weighting(self, h, w, Ly, Lx, device):
    """Build the per-pixel blending weights used when folding crops back together.

    The border-distance map of an ``h x w`` crop is clipped to the configured
    ``clip_min_weight`` / ``clip_max_weight`` and repeated for each of the
    ``Ly * Lx`` crops. With ``tie_braker`` enabled it is additionally
    multiplied by a clipped border-distance map over the crop grid.

    :return: tensor of shape ``(1, h * w, Ly * Lx)`` on ``device``.
    """
    params = self.split_input_params
    n_crops = Ly * Lx

    pixel_weights = self.delta_border(h, w)
    pixel_weights = torch.clip(pixel_weights, params["clip_min_weight"],
                               params["clip_max_weight"], )
    pixel_weights = pixel_weights.view(1, h * w, 1).repeat(1, 1, n_crops).to(device)

    if not params["tie_braker"]:
        return pixel_weights

    crop_weights = self.delta_border(Ly, Lx)
    crop_weights = torch.clip(crop_weights,
                              params["clip_min_tie_weight"],
                              params["clip_max_tie_weight"])
    crop_weights = crop_weights.view(1, 1, n_crops).to(device)
    return pixel_weights * crop_weights
def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1):  # todo load once not every time, shorten code
    """
    Build the unfold/fold pair used to process an image as overlapping crops.

    :param x: img of size (bs, c, h, w)
    :param kernel_size: (height, width) of one crop in input space
    :param stride: (vertical, horizontal) step between crops in input space
    :param uf: integer upscaling factor of the folded output relative to ``x``
    :param df: integer downscaling factor of the folded output relative to ``x``
    :return: tuple ``(fold, unfold, normalization, weighting)`` where
        ``unfold`` cuts ``x`` into crops, ``fold`` reassembles crops at the
        output resolution, ``normalization`` is the folded weighting (used to
        normalize overlaps) and ``weighting`` holds the per-crop blending
        weights shaped ``(1, 1, kh, kw, Ly * Lx)``.
    :raises NotImplementedError: when both ``uf`` and ``df`` differ from 1
        (or either is below 1).
    """
    bs, nc, h, w = x.shape

    # number of crops in image
    Ly = (h - kernel_size[0]) // stride[0] + 1
    Lx = (w - kernel_size[1]) // stride[1] + 1

    if uf == 1 and df == 1:
        fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
        unfold = torch.nn.Unfold(**fold_params)

        fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)

        weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
        normalization = fold(weighting).view(1, 1, h, w)  # normalizes the overlap
        weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))

    elif uf > 1 and df == 1:
        fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
        unfold = torch.nn.Unfold(**fold_params)

        # The output-space kernel must scale height and width separately;
        # using kernel_size[0] for both broke non-square kernels.
        fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[1] * uf),
                            dilation=1, padding=0,
                            stride=(stride[0] * uf, stride[1] * uf))
        fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)

        weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
        normalization = fold(weighting).view(1, 1, h * uf, w * uf)  # normalizes the overlap
        weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))

    elif df > 1 and uf == 1:
        fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
        unfold = torch.nn.Unfold(**fold_params)

        # Same fix as above for the downscaled kernel.
        fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[1] // df),
                            dilation=1, padding=0,
                            stride=(stride[0] // df, stride[1] // df))
        fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)

        weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
        normalization = fold(weighting).view(1, 1, h // df, w // df)  # normalizes the overlap
        weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))

    else:
        raise NotImplementedError

    return fold, unfold, normalization, weighting
@torch.no_grad()
def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
              cond_key=None, return_original_cond=False, bs=None, return_x=False):
    """Encode the batch into latents and assemble the conditioning.

    :param batch: batch dict.
    :param k: key of the first-stage input.
    :param return_first_stage_outputs: also return the input and its
        reconstruction from the latent.
    :param force_c_encode: encode the conditioning even if the conditioner is
        trainable.
    :param cond_key: batch key of the conditioning (defaults to
        ``self.cond_stage_key``).
    :param return_original_cond: also return the raw conditioning input.
    :param bs: optional cap on the batch size.
    :param return_x: also return the input.
    :return: list ``[z, c, ...]`` with the optional extras appended in the
        order: first-stage outputs, input, original conditioning.
    """
    x = super().get_input(batch, k)
    if bs is not None:
        x = x[:bs]
    x = x.to(self.device)
    posterior = self.encode_first_stage(x)
    z = self.get_first_stage_encoding(posterior).detach()

    conditioned = self.model.conditioning_key is not None and not self.force_null_conditioning
    if conditioned:
        if cond_key is None:
            cond_key = self.cond_stage_key

        # pick the raw conditioning input
        if cond_key == self.first_stage_key:
            xc = x
        elif cond_key in ('caption', 'coordinates_bbox', "txt"):
            xc = batch[cond_key]
        elif cond_key in ('class_label', 'cls'):
            xc = batch
        else:
            xc = super().get_input(batch, cond_key).to(self.device)

        # encode it now unless the conditioner is trained jointly
        if self.cond_stage_trainable and not force_c_encode:
            c = xc
        elif isinstance(xc, (dict, list)):
            c = self.get_learned_conditioning(xc)
        else:
            c = self.get_learned_conditioning(xc.to(self.device))

        if bs is not None:
            c = c[:bs]

        if self.use_positional_encodings:
            pos_x, pos_y = self.compute_latent_shifts(batch)
            ckey = __conditioning_keys__[self.model.conditioning_key]
            c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
    else:
        c = None
        xc = None
        if self.use_positional_encodings:
            pos_x, pos_y = self.compute_latent_shifts(batch)
            c = {'pos_x': pos_x, 'pos_y': pos_y}

    out = [z, c]
    if return_first_stage_outputs:
        xrec = self.decode_first_stage(z)
        out.extend([x, xrec])
    if return_x:
        out.append(x)
    if return_original_cond:
        out.append(xc)
    return out
@torch.no_grad()
def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
    """Decode latents with the first-stage model after undoing the latent scaling.

    :param z: latents (or codebook logits / indices when ``predict_cids``).
    :param predict_cids: look the latents up in the first stage's codebook
        first; 4-D input is reduced to indices with an argmax over dim 1.
    :param force_not_quantize: accepted for interface compatibility; not used
        in this method.
    :return: output of ``self.first_stage_model.decode``.
    """
    if predict_cids:
        if z.dim() == 4:
            z = torch.argmax(z.exp(), dim=1).long()
        codebook = self.first_stage_model.quantize
        z = codebook.get_codebook_entry(z, shape=None)
        z = rearrange(z, 'b h w c -> b c h w').contiguous()

    unscaled = 1. / self.scale_factor * z
    return self.first_stage_model.decode(unscaled)
@torch.no_grad()
def encode_first_stage(self, x):
    """Encode ``x`` with the first-stage model and return its output unchanged."""
    first_stage = self.first_stage_model
    return first_stage.encode(x)
def shared_step(self, batch, **kwargs):
    """Encode the batch into latents plus conditioning and run the module on them.

    :return: whatever calling ``self(latents, conditioning)`` returns.
    """
    latents, conditioning = self.get_input(batch, self.first_stage_key)
    return self(latents, conditioning)
def forward(self, x, c, *args, **kwargs):
    """Draw one random timestep per sample, prepare the conditioning, return ``p_losses``.

    When the model is conditioned, ``c`` must be given; it is encoded first
    if the conditioner is trainable and, with a shortened conditioning
    schedule, noised to the conditioning timestep looked up in ``cond_ids``.
    """
    batch_size = x.shape[0]
    t = torch.randint(0, self.num_timesteps, (batch_size,), device=self.device).long()

    if self.model.conditioning_key is not None:
        assert c is not None
        if self.cond_stage_trainable:
            c = self.get_learned_conditioning(c)
        if self.shorten_cond_schedule:  # TODO: drop this option
            cond_t = self.cond_ids[t].to(self.device)
            cond_noise = torch.randn_like(c.float())
            c = self.q_sample(x_start=c, t=cond_t, noise=cond_noise)

    return self.p_losses(x, c, t, *args, **kwargs)
def apply_model(self, x_noisy, t, cond, return_ids=False):
    """Call the wrapped model with the conditioning packed as keyword arguments.

    A dict ``cond`` is passed through unchanged (hybrid case). Anything else
    is wrapped in a list if needed and stored under ``'c_concat'`` when the
    conditioning key is ``'concat'``, otherwise under ``'c_crossattn'``.

    :param return_ids: when the model returns a tuple, return it whole
        instead of only its first element.
    """
    if not isinstance(cond, dict):
        cond_list = cond if isinstance(cond, list) else [cond]
        if self.model.conditioning_key == 'concat':
            key = 'c_concat'
        else:
            key = 'c_crossattn'
        cond = {key: cond_list}

    output = self.model(x_noisy, t, **cond)

    if isinstance(output, tuple) and not return_ids:
        return output[0]
    return output
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
    """Recover the noise estimate implied by ``x_t`` and an x_0 prediction.

    :return: ``(sqrt_recip_alphas_cumprod[t] * x_t - pred_xstart) / sqrt_recipm1_alphas_cumprod[t]``.
    """
    x_scale = extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape)
    numerator = x_scale * x_t - pred_xstart
    denominator = extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
    return numerator / denominator
def _prior_bpd(self, x_start):
    """
    Get the prior KL term for the variational lower-bound, measured in
    bits-per-dim.
    This term can't be optimized, as it only depends on the encoder.
    :param x_start: the [N x C x ...] tensor of inputs.
    :return: a batch of [N] KL values (in bits), one per batch element.
    """
    n_samples = x_start.shape[0]
    last_step = self.num_timesteps - 1
    t = torch.tensor([last_step] * n_samples, device=x_start.device)
    mean_at_end, _, log_var_at_end = self.q_mean_variance(x_start, t)
    # KL against a standard normal (zero mean, zero log-variance)
    kl_nats = normal_kl(mean1=mean_at_end, logvar1=log_var_at_end, mean2=0.0, logvar2=0.0)
    return mean_flat(kl_nats) / np.log(2.0)
+ def p_losses(self, x_start, cond, t, noise=None):
+     """Noise ``x_start`` to step ``t``, run the model and build the training loss.
+     The regression target depends on ``self.parameterization`` (``"x0"``, ``"eps"``
+     or ``"v"``; anything else raises ``NotImplementedError``).
+     :return: ``(loss, loss_dict)`` where ``loss_dict`` holds the logged terms,
+         keyed with a ``train``/``val`` prefix taken from ``self.training``.
+     """
+     noise = default(noise, lambda: torch.randn_like(x_start))
+     x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+     model_output = self.apply_model(x_noisy, t, cond)
+     stage = 'train' if self.training else 'val'
+     metrics = {}
+     mode = self.parameterization
+     if mode == "x0":
+         target = x_start
+     elif mode == "eps":
+         target = noise
+     elif mode == "v":
+         target = self.get_v(x_start, noise, t)
+     else:
+         raise NotImplementedError()
+     # per-sample loss, averaged over dims 1..3
+     per_sample = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
+     metrics[f'{stage}/loss_simple'] = per_sample.mean()
+     logvar_t = self.logvar[t].to(self.device)
+     weighted = per_sample / torch.exp(logvar_t) + logvar_t
+     # loss = loss_simple / torch.exp(self.logvar) + self.logvar
+     if self.learn_logvar:
+         metrics[f'{stage}/loss_gamma'] = weighted.mean()
+         metrics['logvar'] = self.logvar.data.mean()
+     loss = self.l_simple_weight * weighted.mean()
+     # same reduction as above, re-weighted by the per-timestep lvlb weights
+     vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
+     vlb = (self.lvlb_weights[t] * vlb).mean()
+     metrics[f'{stage}/loss_vlb'] = vlb
+     loss += (self.original_elbo_weight * vlb)
+     metrics[f'{stage}/loss'] = loss
+     return loss, metrics
+ def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
+                     return_x0=False, score_corrector=None, corrector_kwargs=None):
+     """Evaluate the model at ``(x, t, c)`` and return the posterior parameters from ``q_posterior``.
+     :param clip_denoised: clamp the reconstructed x0 to [-1, 1] in place.
+     :param return_codebook_ids: the model output is then expected to be a
+         ``(output, logits)`` pair and ``logits`` is returned as 4th element.
+     :param quantize_denoised: pass the reconstruction through ``first_stage_model.quantize``.
+     :param return_x0: return the reconstruction as 4th element (only when
+         ``return_codebook_ids`` is false).
+     :param score_corrector: optional object whose ``modify_score`` adjusts the model
+         output; only allowed with the ``"eps"`` parameterization.
+     :param corrector_kwargs: extra keyword arguments for ``modify_score``; ``None`` means none.
+     :return: ``(model_mean, posterior_variance, posterior_log_variance[, logits | x_recon])``.
+     :raises NotImplementedError: for parameterizations other than ``"eps"`` and ``"x0"``.
+     """
+     model_out = self.apply_model(x, t, c, return_ids=return_codebook_ids)
+     if score_corrector is not None:
+         assert self.parameterization == "eps"
+         # corrector_kwargs defaults to None and ``**None`` raises TypeError,
+         # so fall back to an empty mapping.
+         model_out = score_corrector.modify_score(self, model_out, x, t, c, **(corrector_kwargs or {}))
+     if return_codebook_ids:
+         model_out, logits = model_out
+     if self.parameterization == "eps":
+         x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+     elif self.parameterization == "x0":
+         x_recon = model_out
+     else:
+         raise NotImplementedError()
+     if clip_denoised:
+         x_recon.clamp_(-1., 1.)
+     if quantize_denoised:
+         # only the quantized tensor is needed; the nested unpacking keeps the shape check
+         x_recon, _, [_, _, _] = self.first_stage_model.quantize(x_recon)
+     model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+     if return_codebook_ids:
+         return model_mean, posterior_variance, posterior_log_variance, logits
+     if return_x0:
+         return model_mean, posterior_variance, posterior_log_variance, x_recon
+     return model_mean, posterior_variance, posterior_log_variance
+ @torch.no_grad()
+ def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
+              return_codebook_ids=False, quantize_denoised=False, return_x0=False,
+              temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
+     """Take one reverse-diffusion step from ``x`` at timestep ``t``.
+     The step is ``model_mean + exp(0.5 * model_log_variance) * noise``, where the
+     noise is scaled by ``temperature``, optionally passed through dropout with
+     probability ``noise_dropout``, and masked out for batch entries with ``t == 0``.
+     :param return_codebook_ids: no longer supported; a true value raises.
+     :param return_x0: also return the x0 estimate produced by ``p_mean_variance``.
+     :return: the sample, or ``(sample, x0)`` when ``return_x0`` is true.
+     :raises DeprecationWarning: if ``return_codebook_ids`` is true.
+     """
+     if return_codebook_ids:
+         # Support was dropped; fail before spending a model evaluation.
+         raise DeprecationWarning("Support dropped.")
+     b = x.shape[0]
+     device = x.device
+     outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
+                                    return_codebook_ids=return_codebook_ids,
+                                    quantize_denoised=quantize_denoised,
+                                    return_x0=return_x0,
+                                    score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
+     if return_x0:
+         model_mean, _, model_log_variance, x0 = outputs
+     else:
+         model_mean, _, model_log_variance = outputs
+     noise = noise_like(x.shape, device, repeat_noise) * temperature
+     if noise_dropout > 0.:
+         noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+     # no noise when t == 0
+     nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+     sample = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+     if return_x0:
+         return sample, x0
+     return sample
+ @torch.no_grad()
+ def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
+                           img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
+                           score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
+                           log_every_t=None):
+     """Run the reverse process and collect the intermediate x0 estimates.
+     :param cond: conditioning (tensor, list or dict); truncated to the batch size.
+     :param shape: sample shape; without a leading batch dim when ``batch_size`` is
+         given, otherwise ``shape[0]`` is used as the batch size.
+     :param temperature: a single number applied at every step, or a per-timestep
+         sequence indexed by the timestep.
+     :param mask: optional mask; where it is non-zero the sample is replaced by a
+         noised version of ``x0`` (which must then be provided).
+     :param x_T: optional starting tensor instead of fresh Gaussian noise.
+     :param start_T: optional upper bound on the number of timesteps to run.
+     :param log_every_t: logging interval; falls back to ``self.log_every_t``.
+     :return: ``(img, intermediates)`` where ``intermediates`` holds the x0 estimates
+         of the logged steps.
+     """
+     if not log_every_t:
+         log_every_t = self.log_every_t
+     timesteps = self.num_timesteps
+     if batch_size is not None:
+         b = batch_size
+         shape = [batch_size] + list(shape)
+     else:
+         b = batch_size = shape[0]
+     if x_T is None:
+         img = torch.randn(shape, device=self.device)
+     else:
+         img = x_T
+     intermediates = []
+     if cond is not None:
+         if isinstance(cond, dict):
+             cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
+                     [entry[:batch_size] for entry in cond[key]] for key in cond}
+         else:
+             cond = [entry[:batch_size] for entry in cond] if isinstance(cond, list) else cond[:batch_size]
+     if start_T is not None:
+         timesteps = min(timesteps, start_T)
+     steps = reversed(range(0, timesteps))
+     iterator = tqdm(steps, desc='Progressive Generation', total=timesteps) if verbose else steps
+     # Accept any scalar (e.g. ``temperature=1``), not only exact floats;
+     # a bare int would otherwise fail at ``temperature[i]`` below.
+     if isinstance(temperature, (int, float)):
+         temperature = [temperature] * timesteps
+     for i in iterator:
+         ts = torch.full((b,), i, device=self.device, dtype=torch.long)
+         if self.shorten_cond_schedule:
+             assert self.model.conditioning_key != 'hybrid'
+             tc = self.cond_ids[ts].to(cond.device)
+             cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+         img, x0_partial = self.p_sample(img, cond, ts,
+                                         clip_denoised=self.clip_denoised,
+                                         quantize_denoised=quantize_denoised, return_x0=True,
+                                         temperature=temperature[i], noise_dropout=noise_dropout,
+                                         score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
+         if mask is not None:
+             assert x0 is not None
+             img_orig = self.q_sample(x0, ts)
+             img = img_orig * mask + (1. - mask) * img
+         if i % log_every_t == 0 or i == timesteps - 1:
+             intermediates.append(x0_partial)
+         if callback:
+             callback(i)
+         if img_callback:
+             img_callback(img, i)
+     return img, intermediates
+ @torch.no_grad()
+ def p_sample_loop(self, cond, shape, return_intermediates=False,
+                   x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
+                   mask=None, x0=None, img_callback=None, start_T=None,
+                   log_every_t=None):
+     """Run the full reverse process, starting from ``x_T`` or fresh Gaussian noise.
+     :param cond: conditioning forwarded to ``p_sample``.
+     :param shape: full sample shape; ``shape[0]`` is the batch size.
+     :param timesteps: number of steps to run (defaults to ``self.num_timesteps``),
+         further capped by ``start_T`` when given.
+     :param mask: optional mask; where it is non-zero the sample is replaced by a
+         noised version of ``x0`` after every step. Requires ``x0`` with the same
+         spatial size (dims 2 and 3).
+     :param log_every_t: logging interval; falls back to ``self.log_every_t``.
+     :return: the final sample, or ``(sample, intermediates)`` when
+         ``return_intermediates`` is true.
+     """
+     if not log_every_t:
+         log_every_t = self.log_every_t
+     device = self.betas.device
+     b = shape[0]
+     if x_T is None:
+         img = torch.randn(shape, device=device)
+     else:
+         img = x_T
+     intermediates = [img]
+     if timesteps is None:
+         timesteps = self.num_timesteps
+     if start_T is not None:
+         timesteps = min(timesteps, start_T)
+     steps = reversed(range(0, timesteps))
+     iterator = tqdm(steps, desc='Sampling t', total=timesteps) if verbose else steps
+     if mask is not None:
+         assert x0 is not None
+         # spatial size has to match: compare height AND width
+         # (the former ``[2:3]`` slice only covered the height).
+         assert x0.shape[2:4] == mask.shape[2:4]
+     for i in iterator:
+         ts = torch.full((b,), i, device=device, dtype=torch.long)
+         if self.shorten_cond_schedule:
+             assert self.model.conditioning_key != 'hybrid'
+             tc = self.cond_ids[ts].to(cond.device)
+             cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+         img = self.p_sample(img, cond, ts,
+                             clip_denoised=self.clip_denoised,
+                             quantize_denoised=quantize_denoised)
+         if mask is not None:
+             img_orig = self.q_sample(x0, ts)
+             img = img_orig * mask + (1. - mask) * img
+         if i % log_every_t == 0 or i == timesteps - 1:
+             intermediates.append(img)
+         if callback:
+             callback(i)
+         if img_callback:
+             img_callback(img, i)
+     if return_intermediates:
+         return img, intermediates
+     return img
+ @torch.no_grad()
+ def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
+            verbose=True, timesteps=None, quantize_denoised=False,
+            mask=None, x0=None, shape=None, **kwargs):
+     """Truncate ``cond`` to ``batch_size`` and delegate to ``p_sample_loop``.
+     ``shape`` defaults to ``(batch_size, channels, image_size, image_size)``.
+     Tensors, lists of tensors and dicts of either are sliced along dim 0.
+     Additional keyword arguments are accepted but not forwarded.
+     """
+     def _head(entry):
+         return entry[:batch_size]
+     if shape is None:
+         shape = (batch_size, self.channels, self.image_size, self.image_size)
+     if cond is not None:
+         if isinstance(cond, dict):
+             trimmed = {}
+             for name in cond:
+                 value = cond[name]
+                 trimmed[name] = [_head(v) for v in value] if isinstance(value, list) else _head(value)
+             cond = trimmed
+         elif isinstance(cond, list):
+             cond = [_head(entry) for entry in cond]
+         else:
+             cond = _head(cond)
+     return self.p_sample_loop(cond,
+                               shape,
+                               return_intermediates=return_intermediates, x_T=x_T,
+                               verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
+                               mask=mask, x0=x0)
+ @torch.no_grad()
+ def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
+     """Draw samples either with a ``DDIMSampler`` or with ``self.sample``.
+     With ``ddim`` true, the DDIM sampler runs for ``ddim_steps`` steps on a
+     ``(channels, image_size, image_size)`` shape; otherwise ``self.sample`` is called
+     with ``return_intermediates=True``. Extra keyword arguments go to the chosen sampler.
+     :return: ``(samples, intermediates)``.
+     """
+     if not ddim:
+         samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
+                                              return_intermediates=True, **kwargs)
+         return samples, intermediates
+     sampler = DDIMSampler(self)
+     latent_shape = (self.channels, self.image_size, self.image_size)
+     samples, intermediates = sampler.sample(ddim_steps, batch_size,
+                                             latent_shape, cond, verbose=False, **kwargs)
+     return samples, intermediates
+ @torch.no_grad()
+ def get_unconditional_conditioning(self, batch_size, null_label=None):
+     """Build the "unconditional" conditioning for a batch of size ``batch_size``.
+     With a ``null_label``: a ``ListConfig`` is converted to a list, anything that is
+     not a dict/list and has a ``.to`` method is moved to ``self.device``, and the
+     result is encoded with ``get_learned_conditioning`` and repeated along the
+     leading dimension.
+     Without one: only the ``class_label``/``cls`` cond stage keys are supported; the
+     cond stage model's own unconditional conditioning is encoded and returned as is.
+     :raises NotImplementedError: no ``null_label`` and an unsupported cond stage key.
+     """
+     if null_label is None:
+         if self.cond_stage_key not in ["class_label", "cls"]:
+             raise NotImplementedError("todo")
+         null_cond = self.cond_stage_model.get_unconditional_conditioning(batch_size, device=self.device)
+         return self.get_learned_conditioning(null_cond)
+     label = null_label
+     if isinstance(label, ListConfig):
+         label = list(label)
+     is_container = isinstance(label, (dict, list))
+     if not is_container and hasattr(label, "to"):
+         label = label.to(self.device)
+     c = self.get_learned_conditioning(label)
+     if isinstance(c, list):  # in case the encoder gives us a list
+         for idx, entry in enumerate(c):
+             c[idx] = repeat(entry, '1 ... -> b ...', b=batch_size).to(self.device)
+     else:
+         c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device)
+     return c
+ @torch.no_grad()
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0., return_keys=None,
+ quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
+ plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
+ use_ema_scope=True,
+ **kwargs):
+ """Collect a dict of tensors to log for ``batch``.
+ Depending on the flags the dict holds: inputs, reconstruction, conditioning,
+ a forward-diffusion grid, samples (plain, x0-quantized, classifier-free guided),
+ in-/outpainting samples with their mask, and a progressive-denoising row.
+ DDIM sampling is used whenever ``ddim_steps`` is not None.
+ If ``return_keys`` is given and at least one of them is present in the log,
+ only those keys are returned; otherwise the full dict is returned.
+ """
+ # sample under the EMA weights unless disabled
+ ema_scope = self.ema_scope if use_ema_scope else nullcontext
+ use_ddim = ddim_steps is not None
+ log = dict()
+ z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
+ return_first_stage_outputs=True,
+ force_c_encode=True,
+ return_original_cond=True,
+ bs=N)
+ # the batch may hold fewer than N / n_row items
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ log["inputs"] = x
+ log["reconstruction"] = xrec
+ # pick a visualisation of the conditioning, by cond stage capabilities / key
+ if self.model.conditioning_key is not None:
+ if hasattr(self.cond_stage_model, "decode"):
+ xc = self.cond_stage_model.decode(c)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ["caption", "txt"]:
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ['class_label', "cls"]:
+ try:
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
+ log['conditioning'] = xc
+ except KeyError:
+ # probably no "human_label" in batch
+ pass
+ elif isimage(xc):
+ log["conditioning"] = xc
+ if ismap(xc):
+ log["original_conditioning"] = self.to_rgb(xc)
+ if plot_diffusion_rows:
+ # get diffusion row
+ diffusion_row = list()
+ z_start = z[:n_row]
+ # noise the latents at every logged timestep (and the last one) and decode them
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(z_start)
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+ diffusion_row.append(self.decode_first_stage(z_noisy))
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+ log["diffusion_row"] = diffusion_grid
+ if sample:
+ # get denoise row
+ with ema_scope("Sampling"):
+ samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta)
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+ x_samples = self.decode_first_stage(samples)
+ log["samples"] = x_samples
+ if plot_denoise_rows:
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+ log["denoise_row"] = denoise_grid
+ # quantized-x0 sampling is skipped for AutoencoderKL / IdentityFirstStage first stages
+ if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
+ self.first_stage_model, IdentityFirstStage):
+ # also display when quantizing x0 while sampling
+ with ema_scope("Plotting Quantized Denoised"):
+ samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta,
+ quantize_denoised=True)
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
+ # quantize_denoised=True)
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_x0_quantized"] = x_samples
+ # classifier-free guidance samples are only produced for scales above 1
+ if unconditional_guidance_scale > 1.0:
+ uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+ if self.model.conditioning_key == "crossattn-adm":
+ uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
+ with ema_scope("Sampling with classifier-free guidance"):
+ samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=uc,
+ )
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
+ log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+ if inpaint:
+ # make a simple center square
+ b, h, w = z.shape[0], z.shape[2], z.shape[3]
+ mask = torch.ones(N, h, w).to(self.device)
+ # zeros will be filled in
+ mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
+ mask = mask[:, None, ...]
+ with ema_scope("Plotting Inpaint"):
+ samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_inpainting"] = x_samples
+ log["mask"] = mask
+ # outpaint
+ # invert the mask: now the border is filled in instead of the centre
+ mask = 1. - mask
+ with ema_scope("Plotting Outpaint"):
+ samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_outpainting"] = x_samples
+ if plot_progressive_rows:
+ with ema_scope("Plotting Progressives"):
+ img, progressives = self.progressive_denoising(c,
+ shape=(self.channels, self.image_size, self.image_size),
+ batch_size=N)
+ prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
+ log["progressive_row"] = prog_row
+ if return_keys:
+ # no overlap between the requested keys and the log: hand back everything
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
+ return log
+ else:
+ # NOTE(review): a requested key that is missing from the log raises KeyError here
+ return {key: log[key] for key in return_keys}
+ return log
+ def configure_optimizers(self):
+     """Create the AdamW optimizer and, if enabled, a per-step LambdaLR scheduler.
+     The optimizer covers the diffusion model's parameters, plus the cond stage
+     parameters when it is trainable and ``self.logvar`` when it is learned.
+     :return: the optimizer alone, or ``([optimizer], [scheduler_dict])`` when
+         ``self.use_scheduler`` is set.
+     """
+     lr = self.learning_rate
+     trainable = list(self.model.parameters())
+     if self.cond_stage_trainable:
+         print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
+         trainable = trainable + list(self.cond_stage_model.parameters())
+     if self.learn_logvar:
+         print('Diffusion model optimizing logvar')
+         trainable.append(self.logvar)
+     optimizer = torch.optim.AdamW(trainable, lr=lr)
+     if not self.use_scheduler:
+         return optimizer
+     assert 'target' in self.scheduler_config
+     schedule_obj = instantiate_from_config(self.scheduler_config)
+     print("Setting up LambdaLR scheduler...")
+     lr_scheduler = {
+         'scheduler': LambdaLR(optimizer, lr_lambda=schedule_obj.schedule),
+         'interval': 'step',
+         'frequency': 1
+     }
+     return [optimizer], [lr_scheduler]
+ @torch.no_grad()
+ def to_rgb(self, x):
+     """Project ``x`` to 3 channels with a 1x1 convolution and rescale to [-1, 1].
+     The convolution weight is random, created on first use and cached on
+     ``self.colorize``. The rescaling uses the global min and max of the projection.
+     """
+     x = x.float()
+     if not hasattr(self, "colorize"):
+         self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
+     projected = nn.functional.conv2d(x, weight=self.colorize)
+     lowest = projected.min()
+     value_range = projected.max() - projected.min()
+     return 2. * (projected - lowest) / value_range - 1.
+class DiffusionWrapper(pl.LightningModule):
+    """Wraps the diffusion model and routes conditioning to it by ``conditioning_key``."""
+    def __init__(self, diff_model_config, conditioning_key):
+        """Instantiate the diffusion model from ``diff_model_config``.
+        ``"sequential_crossattn"`` is popped from the config (default False) before
+        instantiation. ``conditioning_key`` must be one of the supported modes.
+        """
+        super().__init__()
+        self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False)
+        self.diffusion_model = instantiate_from_config(diff_model_config)
+        self.conditioning_key = conditioning_key
+        assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
+    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
+        """Call the diffusion model with the inputs required by the conditioning mode.
+        ``c_concat`` entries are concatenated to ``x`` along dim 1; ``c_crossattn``
+        entries are concatenated along dim 1 and passed as ``context`` (or passed as
+        the raw list in ``crossattn`` mode with sequential cross-attention);
+        ``c_adm`` is passed as ``y``. In ``adm`` mode the first ``c_crossattn``
+        entry is used as ``y``.
+        """
+        mode = self.conditioning_key
+        net = self.diffusion_model
+        if mode is None:
+            return net(x, t)
+        if mode == 'concat':
+            stacked = torch.cat([x] + c_concat, dim=1)
+            return net(stacked, t)
+        if mode == 'crossattn':
+            if self.sequential_cross_attn:
+                context = c_crossattn
+            else:
+                context = torch.cat(c_crossattn, 1)
+            return net(x, t, context=context)
+        if mode == 'hybrid':
+            stacked = torch.cat([x] + c_concat, dim=1)
+            context = torch.cat(c_crossattn, 1)
+            return net(stacked, t, context=context)
+        if mode == 'hybrid-adm':
+            assert c_adm is not None
+            stacked = torch.cat([x] + c_concat, dim=1)
+            context = torch.cat(c_crossattn, 1)
+            return net(stacked, t, context=context, y=c_adm)
+        if mode == 'crossattn-adm':
+            assert c_adm is not None
+            context = torch.cat(c_crossattn, 1)
+            return net(x, t, context=context, y=c_adm)
+        if mode == 'adm':
+            label = c_crossattn[0]
+            return net(x, t, y=label)
+        raise NotImplementedError()
+class LatentUpscaleDiffusion(LatentDiffusion):
+ # Latent diffusion variant that is additionally conditioned on a low-resolution
+ # image processed by a frozen "low scale" model (see get_input).
+ def __init__(self, *args, low_scale_config, low_scale_key="LR", noise_level_key=None, **kwargs):
+ # low_scale_config: config used to instantiate the low-scale model
+ # low_scale_key: batch key holding the low-resolution image
+ # noise_level_key: reserved; a non-None value makes get_input raise NotImplementedError
+ super().__init__(*args, **kwargs)
+ # assumes that neither the cond_stage nor the low_scale_model contain trainable params
+ assert not self.cond_stage_trainable
+ self.instantiate_low_stage(low_scale_config)
+ self.low_scale_key = low_scale_key
+ self.noise_level_key = noise_level_key
+ def instantiate_low_stage(self, config):
+ # Build the low-scale model, put it in eval mode and freeze it:
+ # train() is replaced by disabled_train and all parameters stop requiring grad.
+ model = instantiate_from_config(config)
+ self.low_scale_model = model.eval()
+ self.low_scale_model.train = disabled_train
+ for param in self.low_scale_model.parameters():
+ param.requires_grad = False
+ @torch.no_grad()
+ def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
+ # Returns (z, all_conds), or in log_mode
+ # (z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level).
+ # all_conds carries the low-scale encoding as c_concat, the encoded
+ # conditioning as c_crossattn and the noise level as c_adm.
+ if not log_mode:
+ z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)
+ else:
+ z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
+ force_c_encode=True, return_original_cond=True, bs=bs)
+ # low-res image: channels-last in the batch, converted to contiguous channels-first float
+ x_low = batch[self.low_scale_key][:bs]
+ x_low = rearrange(x_low, 'b h w c -> b c h w')
+ x_low = x_low.to(memory_format=torch.contiguous_format).float()
+ zx, noise_level = self.low_scale_model(x_low)
+ if self.noise_level_key is not None:
+ # get noise level from batch instead, e.g. when extracting a custom noise level for bsr
+ raise NotImplementedError('TODO')
+ all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
+ if log_mode:
+ # TODO: maybe disable if too expensive
+ x_low_rec = self.low_scale_model.decode(zx)
+ return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level
+ return z, all_conds
+ @torch.no_grad()
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
+ plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True,
+ unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True,
+ **kwargs):
+ # Collect a dict of tensors to log: inputs, reconstruction, the low-res input and
+ # its reconstruction, conditioning, and optionally a diffusion row, samples,
+ # classifier-free-guided samples and a progressive-denoising row.
+ # NOTE(review): return_keys is accepted but not used in this override.
+ ema_scope = self.ema_scope if use_ema_scope else nullcontext
+ use_ddim = ddim_steps is not None
+ log = dict()
+ z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, self.first_stage_key, bs=N,
+ log_mode=True)
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ log["inputs"] = x
+ log["reconstruction"] = xrec
+ log["x_lr"] = x_low
+ # the log key embeds the per-sample noise levels, joined by '-'
+ log[f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"] = x_low_rec
+ if self.model.conditioning_key is not None:
+ if hasattr(self.cond_stage_model, "decode"):
+ xc = self.cond_stage_model.decode(c)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ["caption", "txt"]:
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ['class_label', 'cls']:
+ # NOTE(review): a batch without "human_label" raises KeyError here
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
+ log['conditioning'] = xc
+ elif isimage(xc):
+ log["conditioning"] = xc
+ if ismap(xc):
+ log["original_conditioning"] = self.to_rgb(xc)
+ if plot_diffusion_rows:
+ # get diffusion row
+ diffusion_row = list()
+ z_start = z[:n_row]
+ # noise the latents at every logged timestep (and the last one) and decode them
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(z_start)
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+ diffusion_row.append(self.decode_first_stage(z_noisy))
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+ log["diffusion_row"] = diffusion_grid
+ if sample:
+ # get denoise row
+ with ema_scope("Sampling"):
+ samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta)
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+ x_samples = self.decode_first_stage(samples)
+ log["samples"] = x_samples
+ if plot_denoise_rows:
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+ log["denoise_row"] = denoise_grid
+ if unconditional_guidance_scale > 1.0:
+ uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+ # TODO explore better "unconditional" choices for the other keys
+ # maybe guide away from empty text label and highest noise level and maximally degraded zx?
+ # build the unconditional dict: only c_crossattn is replaced, other entries are reused from c
+ uc = dict()
+ for k in c:
+ if k == "c_crossattn":
+ assert isinstance(c[k], list) and len(c[k]) == 1
+ uc[k] = [uc_tmp]
+ elif k == "c_adm": # todo: only run with text-based guidance?
+ assert isinstance(c[k], torch.Tensor)
+ #uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level
+ uc[k] = c[k]
+ elif isinstance(c[k], list):
+ uc[k] = [c[k][i] for i in range(len(c[k]))]
+ else:
+ uc[k] = c[k]
+ with ema_scope("Sampling with classifier-free guidance"):
+ samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=uc,
+ )
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
+ log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+ if plot_progressive_rows:
+ with ema_scope("Plotting Progressives"):
+ img, progressives = self.progressive_denoising(c,
+ shape=(self.channels, self.image_size, self.image_size),
+ batch_size=N)
+ prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
+ log["progressive_row"] = prog_row
+ return log
+class LatentFinetuneDiffusion(LatentDiffusion):
+    """
+    Basis for different finetunes, such as inpainting or depth2image
+    To disable finetuning mode, set finetune_keys to None
+    """
+    def __init__(self,
+                 concat_keys: tuple,
+                 finetune_keys=("model.diffusion_model.input_blocks.0.0.weight",
+                                "model_ema.diffusion_modelinput_blocks00weight"
+                                ),
+                 keep_finetune_dims=4,
+                 # if model was trained without concat mode before and we would like to keep these channels
+                 c_concat_log_start=None,  # to log reconstruction of c_concat codes
+                 c_concat_log_end=None,
+                 *args, **kwargs
+                 ):
+        """
+        :param concat_keys: batch keys whose entries are used as concat conditioning.
+        :param finetune_keys: state-dict keys whose checkpoint tensors are copied into
+            zero-initialised, wider parameters; ``None`` disables finetuning mode.
+        :param keep_finetune_dims: number of leading dim-1 channels filled from the checkpoint.
+        :param c_concat_log_start: start of the ``c_concat`` channel slice decoded in ``log_images``.
+        :param c_concat_log_end: end of that slice.
+        ``ckpt_path`` and ``ignore_keys`` are taken out of ``kwargs`` so the checkpoint is
+        loaded here, after the attributes above are set, rather than by the parent.
+        """
+        ckpt_path = kwargs.pop("ckpt_path", None)
+        ignore_keys = kwargs.pop("ignore_keys", list())
+        super().__init__(*args, **kwargs)
+        self.finetune_keys = finetune_keys
+        self.concat_keys = concat_keys
+        self.keep_dims = keep_finetune_dims
+        self.c_concat_log_start = c_concat_log_start
+        self.c_concat_log_end = c_concat_log_end
+        if exists(self.finetune_keys):
+            assert exists(ckpt_path), 'can only finetune from a given checkpoint'
+        if exists(ckpt_path):
+            self.init_from_ckpt(ckpt_path, ignore_keys)
+    def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
+        """Load weights from ``path``, widening the finetune keys on the way.
+        :param path: checkpoint file; a top-level ``"state_dict"`` entry is unwrapped.
+        :param ignore_keys: key prefixes to drop from the state dict (``None`` = none).
+        :param only_model: load into ``self.model`` instead of the whole module.
+        Entries listed in ``self.finetune_keys`` are replaced by a zero tensor shaped
+        like the matching parameter, whose first ``keep_dims`` dim-1 channels are
+        filled with the checkpoint values. Loading is non-strict; missing and
+        unexpected keys are printed.
+        """
+        ignore_keys = () if ignore_keys is None else ignore_keys
+        # SECURITY: torch.load unpickles the file; only load checkpoints from trusted sources.
+        sd = torch.load(path, map_location="cpu")
+        if "state_dict" in sd:
+            sd = sd["state_dict"]
+        for k in list(sd.keys()):
+            if any(k.startswith(ik) for ik in ignore_keys):
+                print("Deleting key {} from state_dict.".format(k))
+                del sd[k]
+                # The key is gone: deleting it again for a second matching prefix, or
+                # reading it in the finetune branch below, would raise KeyError.
+                continue
+            # make it explicit, finetune by including extra input channels
+            if exists(self.finetune_keys) and k in self.finetune_keys:
+                new_entry = None
+                for name, param in self.named_parameters():
+                    if name in self.finetune_keys:
+                        print(
+                            f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only")
+                        new_entry = torch.zeros_like(param)  # zero init
+                assert exists(new_entry), 'did not find matching parameter to modify'
+                new_entry[:, :self.keep_dims, ...] = sd[k]
+                sd[k] = new_entry
+        target = self.model if only_model else self
+        missing, unexpected = target.load_state_dict(sd, strict=False)
+        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+        if len(missing) > 0:
+            print(f"Missing Keys: {missing}")
+        if len(unexpected) > 0:
+            print(f"Unexpected Keys: {unexpected}")
+    @torch.no_grad()
+    def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
+                   quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
+                   plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
+                   use_ema_scope=True,
+                   **kwargs):
+        """Collect a dict of tensors to log for ``batch``.
+        Holds inputs, reconstruction, conditioning, optionally the decoded
+        ``c_concat`` slice, a forward-diffusion grid, samples and classifier-free
+        guided samples. ``self.get_input`` is expected to return
+        ``(z, cond_dict, x, xrec, xc)`` with ``c_concat`` / ``c_crossattn`` lists in
+        ``cond_dict``. ``return_keys``, ``quantize_denoised``, ``inpaint`` and
+        ``plot_progressive_rows`` are accepted but not used here.
+        """
+        ema_scope = self.ema_scope if use_ema_scope else nullcontext
+        use_ddim = ddim_steps is not None
+        log = dict()
+        z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True)
+        c_cat, c = c["c_concat"][0], c["c_crossattn"][0]
+        N = min(x.shape[0], N)
+        n_row = min(x.shape[0], n_row)
+        log["inputs"] = x
+        log["reconstruction"] = xrec
+        if self.model.conditioning_key is not None:
+            if hasattr(self.cond_stage_model, "decode"):
+                xc = self.cond_stage_model.decode(c)
+                log["conditioning"] = xc
+            elif self.cond_stage_key in ["caption", "txt"]:
+                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
+                log["conditioning"] = xc
+            elif self.cond_stage_key in ['class_label', 'cls']:
+                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
+                log['conditioning'] = xc
+            elif isimage(xc):
+                log["conditioning"] = xc
+            if ismap(xc):
+                log["original_conditioning"] = self.to_rgb(xc)
+        if not (self.c_concat_log_start is None and self.c_concat_log_end is None):
+            log["c_concat_decoded"] = self.decode_first_stage(c_cat[:, self.c_concat_log_start:self.c_concat_log_end])
+        if plot_diffusion_rows:
+            # get diffusion row
+            diffusion_row = list()
+            z_start = z[:n_row]
+            for t in range(self.num_timesteps):
+                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+                    t = t.to(self.device).long()
+                    noise = torch.randn_like(z_start)
+                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+                    diffusion_row.append(self.decode_first_stage(z_noisy))
+            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
+            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+            log["diffusion_row"] = diffusion_grid
+        if sample:
+            # get denoise row
+            with ema_scope("Sampling"):
+                samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
+                                                         batch_size=N, ddim=use_ddim,
+                                                         ddim_steps=ddim_steps, eta=ddim_eta)
+            x_samples = self.decode_first_stage(samples)
+            log["samples"] = x_samples
+            if plot_denoise_rows:
+                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+                log["denoise_row"] = denoise_grid
+        if unconditional_guidance_scale > 1.0:
+            # only the cross-attention part is made unconditional; the concat part is reused
+            uc_cross = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+            uc_cat = c_cat
+            uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
+            with ema_scope("Sampling with classifier-free guidance"):
+                samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
+                                                 batch_size=N, ddim=use_ddim,
+                                                 ddim_steps=ddim_steps, eta=ddim_eta,
+                                                 unconditional_guidance_scale=unconditional_guidance_scale,
+                                                 unconditional_conditioning=uc_full,
+                                                 )
+                x_samples_cfg = self.decode_first_stage(samples_cfg)
+                log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+        return log
+class LatentInpaintDiffusion(LatentFinetuneDiffusion):
+    """
+    can either run as pure inpainting model (only concat mode) or with mixed conditionings,
+    e.g. mask as concat and text via cross-attn.
+    To disable finetuning mode, set finetune_keys to None
+    """
+    def __init__(self,
+                 concat_keys=("mask", "masked_image"),
+                 masked_image_key="masked_image",
+                 *args, **kwargs
+                 ):
+        """
+        :param concat_keys: batch keys concatenated into the concat conditioning.
+        :param masked_image_key: the concat key that is encoded with the first stage
+            instead of being resized; must be one of ``concat_keys``.
+        """
+        super().__init__(concat_keys, *args, **kwargs)
+        self.masked_image_key = masked_image_key
+        assert self.masked_image_key in concat_keys
+    @torch.no_grad()
+    def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
+        """Return ``(z, all_conds)`` or, with ``return_first_stage_outputs``, ``(z, all_conds, x, xrec, xc)``.
+        ``all_conds["c_concat"]`` holds the channel-wise concatenation of all concat
+        keys: the masked image is encoded to latent space, every other key is
+        interpolated to the latent's spatial size. ``k`` and ``cond_key`` are not
+        used; the parent is always queried with ``self.first_stage_key``.
+        """
+        # note: restricted to non-trainable encoders currently
+        assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting'
+        z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
+                                              force_c_encode=True, return_original_cond=True, bs=bs)
+        assert exists(self.concat_keys)
+        latent_hw = z.shape[-2:]
+        c_cat = list()
+        for ck in self.concat_keys:
+            # batch entries are channels-last; convert to contiguous channels-first float
+            cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
+            if bs is not None:
+                cc = cc[:bs]
+            cc = cc.to(self.device)
+            if ck != self.masked_image_key:
+                cc = torch.nn.functional.interpolate(cc, size=latent_hw)
+            else:
+                cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
+            c_cat.append(cc)
+        c_cat = torch.cat(c_cat, dim=1)
+        all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+        if return_first_stage_outputs:
+            return z, all_conds, x, xrec, xc
+        return z, all_conds
+    @torch.no_grad()
+    def log_images(self, *args, **kwargs):
+        """Extend the parent's log with the masked image under ``"masked_image"``.
+        The batch is taken from the first positional argument or the ``batch`` keyword,
+        and the image is read from the configured ``self.masked_image_key``.
+        """
+        log = super(LatentInpaintDiffusion, self).log_images(*args, **kwargs)
+        batch = args[0] if args else kwargs["batch"]
+        # read the key get_input uses, not a hard-coded "masked_image"
+        log["masked_image"] = rearrange(batch[self.masked_image_key],
+                                        'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
+        return log
+class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion):
+    """
+    condition on monocular depth estimation
+    """
+    def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwargs):
+        """
+        :param depth_stage_config: config used to instantiate the depth model.
+        :param concat_keys: batch keys for the concat conditioning; the first one is
+            the depth model's input key.
+        """
+        super().__init__(concat_keys=concat_keys, *args, **kwargs)
+        self.depth_model = instantiate_from_config(depth_stage_config)
+        self.depth_stage_key = concat_keys[0]
+    @torch.no_grad()
+    def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
+        """Return ``(z, all_conds)`` or, with ``return_first_stage_outputs``, ``(z, all_conds, x, xrec, xc)``.
+        The concat conditioning is the depth model's output for the single concat
+        key, bicubically resized to the latent's spatial size and rescaled per
+        sample to roughly [-1, 1]. ``k`` and ``cond_key`` are not used.
+        """
+        # note: restricted to non-trainable encoders currently
+        assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for depth2img'
+        z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
+                                              force_c_encode=True, return_original_cond=True, bs=bs)
+        assert exists(self.concat_keys)
+        assert len(self.concat_keys) == 1
+        c_cat = list()
+        for ck in self.concat_keys:
+            cc = batch[ck]
+            if bs is not None:
+                cc = cc[:bs]
+            cc = cc.to(self.device)
+            cc = self.depth_model(cc)
+            cc = torch.nn.functional.interpolate(
+                cc,
+                size=z.shape[2:],
+                mode="bicubic",
+                align_corners=False,
+            )
+            depth_min = torch.amin(cc, dim=[1, 2, 3], keepdim=True)
+            depth_max = torch.amax(cc, dim=[1, 2, 3], keepdim=True)
+            # the small constant keeps a constant depth map from dividing by zero
+            cc = 2. * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1.
+            c_cat.append(cc)
+        c_cat = torch.cat(c_cat, dim=1)
+        all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+        if return_first_stage_outputs:
+            return z, all_conds, x, xrec, xc
+        return z, all_conds
+    @torch.no_grad()
+    def log_images(self, *args, **kwargs):
+        """Extend the parent's log with the normalised depth map under ``"depth"``.
+        The batch is taken from the first positional argument or the ``batch`` keyword.
+        The depth model's output is rescaled per sample with the same formula
+        (including the 0.001 stabiliser) that ``get_input`` applies.
+        """
+        log = super().log_images(*args, **kwargs)
+        batch = args[0] if args else kwargs["batch"]
+        # move to the module's device first, as get_input does
+        depth = self.depth_model(batch[self.depth_stage_key].to(self.device))
+        depth_min = torch.amin(depth, dim=[1, 2, 3], keepdim=True)
+        depth_max = torch.amax(depth, dim=[1, 2, 3], keepdim=True)
+        # without the stabiliser a constant depth map yields 0/0 = NaN
+        log["depth"] = 2. * (depth - depth_min) / (depth_max - depth_min + 0.001) - 1.
+        return log
+class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
+    """
+    condition on low-res image (and optionally on some spatial noise augmentation)
+    """
+    def __init__(self, concat_keys=("lr",), reshuffle_patch_size=None,
+                 low_scale_config=None, low_scale_key=None, *args, **kwargs):
+        """
+        :param concat_keys: batch keys for the concat conditioning (exactly one is
+            supported by ``get_input``).
+        :param reshuffle_patch_size: optional int; when set, spatial patches of the
+            conditioning are folded into the channel dimension.
+        :param low_scale_config: optional config of a frozen low-scale model.
+        :param low_scale_key: concat key that is passed through the low-scale model;
+            required when ``low_scale_config`` is given.
+        """
+        super().__init__(concat_keys=concat_keys, *args, **kwargs)
+        self.reshuffle_patch_size = reshuffle_patch_size
+        self.low_scale_model = None
+        # always define the attribute so it can be read without a low-scale model
+        self.low_scale_key = low_scale_key
+        if low_scale_config is not None:
+            print("Initializing a low-scale model")
+            assert exists(low_scale_key)
+            self.instantiate_low_stage(low_scale_config)
+    def instantiate_low_stage(self, config):
+        """Build the low-scale model, put it in eval mode and freeze it.
+        ``train`` is replaced by ``disabled_train`` and all parameters stop requiring grad.
+        """
+        model = instantiate_from_config(config)
+        self.low_scale_model = model.eval()
+        self.low_scale_model.train = disabled_train
+        for param in self.low_scale_model.parameters():
+            param.requires_grad = False
+    @torch.no_grad()
+    def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
+        """Return ``(z, all_conds)`` or, with ``return_first_stage_outputs``, ``(z, all_conds, x, xrec, xc)``.
+        The single concat key is read channels-last from the batch, converted to
+        channels-first, optionally patch-reshuffled and, when it is the low-scale key
+        and a low-scale model exists, passed through that model. In that case the
+        returned noise level is added to ``all_conds`` as ``c_adm``.
+        ``k`` and ``cond_key`` are not used.
+        """
+        # note: restricted to non-trainable encoders currently
+        assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for upscaling-ft'
+        z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
+                                              force_c_encode=True, return_original_cond=True, bs=bs)
+        assert exists(self.concat_keys)
+        assert len(self.concat_keys) == 1
+        # optionally make spatial noise_level here
+        c_cat = list()
+        noise_level = None
+        for ck in self.concat_keys:
+            cc = batch[ck]
+            cc = rearrange(cc, 'b h w c -> b c h w')
+            if exists(self.reshuffle_patch_size):
+                assert isinstance(self.reshuffle_patch_size, int)
+                cc = rearrange(cc, 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w',
+                               p1=self.reshuffle_patch_size, p2=self.reshuffle_patch_size)
+            if bs is not None:
+                cc = cc[:bs]
+            cc = cc.to(self.device)
+            if exists(self.low_scale_model) and ck == self.low_scale_key:
+                cc, noise_level = self.low_scale_model(cc)
+            c_cat.append(cc)
+        c_cat = torch.cat(c_cat, dim=1)
+        all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+        if exists(noise_level):
+            all_conds["c_adm"] = noise_level
+        if return_first_stage_outputs:
+            return z, all_conds, x, xrec, xc
+        return z, all_conds
+    @torch.no_grad()
+    def log_images(self, *args, **kwargs):
+        """Extend the parent's log with the low-res conditioning image under ``"lr"``.
+        The batch is taken from the first positional argument or the ``batch`` keyword,
+        and the image is read from the configured concat key.
+        """
+        log = super().log_images(*args, **kwargs)
+        batch = args[0] if args else kwargs["batch"]
+        # read the key get_input uses, not a hard-coded "lr"
+        log["lr"] = rearrange(batch[self.concat_keys[0]], 'b h w c -> b c h w')
+        return log
diff --git a/ControlNet/ldm/models/diffusion/dpm_solver/__init__.py b/ControlNet/ldm/models/diffusion/dpm_solver/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7427f38c07530afbab79154ea8aaf88c4bf70a08
--- /dev/null
+++ b/ControlNet/ldm/models/diffusion/dpm_solver/__init__.py
@@ -0,0 +1 @@
+from .sampler import DPMSolverSampler
\ No newline at end of file
diff --git a/ControlNet/ldm/models/diffusion/dpm_solver/dpm_solver.py b/ControlNet/ldm/models/diffusion/dpm_solver/dpm_solver.py
new file mode 100644
index 0000000000000000000000000000000000000000..095e5ba3ce0b1aa7f4b3f1e2e5d8fff7cfe6dc8c
--- /dev/null
+++ b/ControlNet/ldm/models/diffusion/dpm_solver/dpm_solver.py
@@ -0,0 +1,1154 @@
+import torch
+import torch.nn.functional as F
+import math
+from tqdm import tqdm
+class NoiseScheduleVP:
+    def __init__(
+            self,
+            schedule='discrete',
+            betas=None,
+            alphas_cumprod=None,
+            continuous_beta_0=0.1,
+            continuous_beta_1=20.,
+    ):
+        r"""Create a wrapper class for the forward SDE (VP type).
+
+        ***
+        Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation
+        for log_alpha_t. We recommend to use schedule='discrete' for the discrete-time diffusion models,
+        especially for high-resolution images.
+        ***
+
+        The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
+        We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the
+        DPM-Solver paper). Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t.
+        For t in [0, T], we have:
+
+            log_alpha_t = self.marginal_log_mean_coeff(t)
+            sigma_t = self.marginal_std(t)
+            lambda_t = self.marginal_lambda(t)
+
+        Moreover, as lambda(t) is an invertible function, we also support its inverse function:
+
+            t = self.inverse_lambda(lambda_t)
+
+        ===============================================================
+        We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs
+        (trained on t in [t_0, T]).
+
+        1. For discrete-time DPMs:
+            For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous
+            time steps by:
+                t_i = (i + 1) / N
+            e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
+            We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
+            Args:
+                betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper
+                    for details)
+                alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original
+                    DDPM paper for details)
+            Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of
+            `betas` and `alphas_cumprod`; `betas` takes precedence when both are given.
+            **Important**: Please pay special attention for the args for `alphas_cumprod`:
+                The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs
+                assume that
+                    q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
+                Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver.
+                In fact, we have
+                    alpha_{t_n} = \sqrt{\hat{alpha_n}},
+                and
+                    log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
+
+        2. For continuous-time DPMs:
+            We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM).
+            Args:
+                continuous_beta_0: A `float` number. The smallest beta for the linear schedule.
+                continuous_beta_1: A `float` number. The largest beta for the linear schedule.
+            The cosine schedule uses the fixed hyperparameters cosine_s = 0.008 and cosine_beta_max = 999.
+            The ending time T of the forward process is 0.9946 for 'cosine' and 1. otherwise.
+
+        ===============================================================
+        Args:
+            schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
+                'linear' or 'cosine' for continuous-time DPMs.
+        Returns:
+            A wrapper object of the forward SDE (VP type).
+        Raises:
+            ValueError: If `schedule` is not one of 'discrete', 'linear', 'cosine'.
+
+        ===============================================================
+        Example:
+        # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
+        >>> ns = NoiseScheduleVP('discrete', betas=betas)
+        # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
+        >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
+        # For continuous-time DPMs (VPSDE), linear schedule:
+        >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
+        """
+        if schedule not in ['discrete', 'linear', 'cosine']:
+            raise ValueError(
+                "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
+                    schedule))
+        self.schedule = schedule
+        if schedule == 'discrete':
+            if betas is not None:
+                # log(alpha_{t_n}) = 0.5 * log(cumprod(1 - betas)), accumulated in log space.
+                log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
+            else:
+                assert alphas_cumprod is not None, "either `betas` or `alphas_cumprod` is required"
+                log_alphas = 0.5 * torch.log(alphas_cumprod)
+            self.total_N = len(log_alphas)
+            self.T = 1.
+            # Continuous time grid t_i = (i + 1) / N, stored as a (1, N) row for interpolation.
+            self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
+            self.log_alpha_array = log_alphas.reshape((1, -1,))
+        else:
+            self.total_N = 1000
+            self.beta_0 = continuous_beta_0
+            self.beta_1 = continuous_beta_1
+            self.cosine_s = 0.008
+            self.cosine_beta_max = 999.
+            self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (
+                    1. + self.cosine_s) / math.pi - self.cosine_s
+            self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
+            if schedule == 'cosine':
+                # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
+                # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
+                self.T = 0.9946
+            else:
+                self.T = 1.
+
+    def marginal_log_mean_coeff(self, t):
+        """
+        Compute log(alpha_t) of a given continuous-time label t in [0, T].
+        """
+        if self.schedule == 'discrete':
+            return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device),
+                                  self.log_alpha_array.to(t.device)).reshape((-1))
+        elif self.schedule == 'linear':
+            return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
+        elif self.schedule == 'cosine':
+            log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
+            log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
+            return log_alpha_t
+
+    def marginal_alpha(self, t):
+        """
+        Compute alpha_t of a given continuous-time label t in [0, T].
+        """
+        return torch.exp(self.marginal_log_mean_coeff(t))
+
+    def marginal_std(self, t):
+        """
+        Compute sigma_t of a given continuous-time label t in [0, T].
+        """
+        return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
+
+    def marginal_lambda(self, t):
+        """
+        Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
+        """
+        log_mean_coeff = self.marginal_log_mean_coeff(t)
+        log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
+        return log_mean_coeff - log_std
+
+    def inverse_lambda(self, lamb):
+        """
+        Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
+        """
+        if self.schedule == 'linear':
+            tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
+            Delta = self.beta_0 ** 2 + tmp
+            return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
+        elif self.schedule == 'discrete':
+            log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
+            # Both arrays are flipped before interpolating from log_alpha back to t.
+            t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]),
+                               torch.flip(self.t_array.to(lamb.device), [1]))
+            return t.reshape((-1,))
+        else:
+            log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
+            t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (
+                    1. + self.cosine_s) / math.pi - self.cosine_s
+            t = t_fn(log_alpha)
+            return t
+def model_wrapper(
+        model,
+        noise_schedule,
+        model_type="noise",
+        model_kwargs=None,
+        guidance_type="uncond",
+        condition=None,
+        unconditional_condition=None,
+        guidance_scale=1.,
+        classifier_fn=None,
+        classifier_kwargs=None,
+):
+    """Create a wrapper function for the noise prediction model.
+
+    DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
+    firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
+
+    We support four types of the diffusion model by setting `model_type`:
+        1. "noise": noise prediction model. (Trained by predicting noise).
+        2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
+        3. "v": velocity prediction model. (Trained by predicting the velocity).
+            The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
+            [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
+                arXiv preprint arXiv:2202.00512 (2022).
+            [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
+                arXiv preprint arXiv:2210.02303 (2022).
+        4. "score": marginal score function. (Trained by denoising score matching).
+            Note that the score function and the noise prediction model follows a simple relationship:
+            ```
+                noise(x_t, t) = -sigma_t * score(x_t, t)
+            ```
+
+    We support three types of guided sampling by DPMs by setting `guidance_type`:
+        1. "uncond": unconditional sampling by DPMs.
+            The input `model` has the following format:
+            ``
+                model(x, t_input, **model_kwargs) -> noise | x_start | v | score
+            ``
+        2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
+            The input `model` has the following format:
+            ``
+                model(x, t_input, **model_kwargs) -> noise | x_start | v | score
+            ``
+            The input `classifier_fn` has the following format:
+            ``
+                classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
+            ``
+            [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
+                in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
+        3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
+            The input `model` has the following format:
+            ``
+                model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
+            ``
+            And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
+            [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
+                arXiv preprint arXiv:2207.12598 (2022).
+
+    The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
+    or continuous-time labels (i.e. epsilon to T).
+    We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
+    ``
+        def model_fn(x, t_continuous) -> noise:
+            t_input = get_model_input_time(t_continuous)
+            return noise_pred(model, x, t_input, **model_kwargs)
+    ``
+    where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
+
+    ===============================================================
+    Args:
+        model: A diffusion model with the corresponding format described above.
+        noise_schedule: A noise schedule object, such as NoiseScheduleVP.
+        model_type: A `str`. The parameterization type of the diffusion model.
+            "noise" or "x_start" or "v" or "score".
+        model_kwargs: A `dict`. A dict for the other inputs of the model function. `None` means no extra inputs.
+        guidance_type: A `str`. The type of the guidance for sampling.
+            "uncond" or "classifier" or "classifier-free".
+        condition: A pytorch tensor. The condition for the guided sampling.
+            Only used for "classifier" or "classifier-free" guidance type.
+        unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
+            Only used for "classifier-free" guidance type.
+        guidance_scale: A `float`. The scale for the guided sampling.
+        classifier_fn: A classifier function. Only used for the classifier guidance.
+        classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
+            `None` means no extra inputs.
+    Returns:
+        A noise prediction model that accepts the noised data and the continuous time as the inputs.
+    Raises:
+        AssertionError: If `model_type` or `guidance_type` is not one of the supported values.
+    """
+    # Validate up front, before any closure is built. "score" is accepted because
+    # noise_pred_fn below implements it and the documentation lists it.
+    assert model_type in ["noise", "x_start", "v", "score"]
+    assert guidance_type in ["uncond", "classifier", "classifier-free"]
+    # Fresh dicts per call instead of shared mutable defaults.
+    if model_kwargs is None:
+        model_kwargs = {}
+    if classifier_kwargs is None:
+        classifier_kwargs = {}
+
+    def get_model_input_time(t_continuous):
+        """
+        Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
+        For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
+        For continuous-time DPMs, we just use `t_continuous`.
+        """
+        if noise_schedule.schedule == 'discrete':
+            return (t_continuous - 1. / noise_schedule.total_N) * 1000.
+        else:
+            return t_continuous
+
+    def noise_pred_fn(x, t_continuous, cond=None):
+        """Evaluate `model` and convert its output to a noise prediction according to `model_type`."""
+        # A single time value is broadcast to one entry per batch element.
+        if t_continuous.reshape((-1,)).shape[0] == 1:
+            t_continuous = t_continuous.expand((x.shape[0]))
+        t_input = get_model_input_time(t_continuous)
+        if cond is None:
+            output = model(x, t_input, **model_kwargs)
+        else:
+            output = model(x, t_input, cond, **model_kwargs)
+        if model_type == "noise":
+            return output
+        elif model_type == "x_start":
+            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
+            dims = x.dim()
+            return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
+        elif model_type == "v":
+            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
+            dims = x.dim()
+            return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
+        elif model_type == "score":
+            sigma_t = noise_schedule.marginal_std(t_continuous)
+            dims = x.dim()
+            return -expand_dims(sigma_t, dims) * output
+
+    def cond_grad_fn(x, t_input):
+        """
+        Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
+        """
+        with torch.enable_grad():
+            x_in = x.detach().requires_grad_(True)
+            log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
+            return torch.autograd.grad(log_prob.sum(), x_in)[0]
+
+    def model_fn(x, t_continuous):
+        """
+        The noise prediction model function that is used for DPM-Solver.
+        """
+        if t_continuous.reshape((-1,)).shape[0] == 1:
+            t_continuous = t_continuous.expand((x.shape[0]))
+        if guidance_type == "uncond":
+            return noise_pred_fn(x, t_continuous)
+        elif guidance_type == "classifier":
+            assert classifier_fn is not None
+            t_input = get_model_input_time(t_continuous)
+            cond_grad = cond_grad_fn(x, t_input)
+            sigma_t = noise_schedule.marginal_std(t_continuous)
+            noise = noise_pred_fn(x, t_continuous)
+            return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
+        elif guidance_type == "classifier-free":
+            if guidance_scale == 1. or unconditional_condition is None:
+                return noise_pred_fn(x, t_continuous, cond=condition)
+            else:
+                # One batched forward pass: first half unconditional, second half conditional.
+                x_in = torch.cat([x] * 2)
+                t_in = torch.cat([t_continuous] * 2)
+                c_in = torch.cat([unconditional_condition, condition])
+                noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
+                return noise_uncond + guidance_scale * (noise - noise_uncond)
+
+    return model_fn
+class DPM_Solver:
+ def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.):
+     """Construct a DPM-Solver.
+
+     Two parameterizations are supported: the noise prediction model ("predicting epsilon") and the
+     data prediction model ("predicting x0").
+         - `predict_x0` False: solver for the noise prediction model (DPM-Solver).
+         - `predict_x0` True: solver for the data prediction model (DPM-Solver++). In this mode the
+           "dynamic thresholding" of [1] is applied when `thresholding` is True; it can greatly improve
+           the sample quality for pixel-space DPMs with large guidance scales.
+
+     Args:
+         model_fn: A noise prediction model function which accepts the continuous-time input
+             (t in [epsilon, T]):
+             ``
+             def model_fn(x, t_continuous):
+                 return noise
+             ``
+         noise_schedule: A noise schedule object, such as NoiseScheduleVP.
+         predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model.
+         thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1].
+         max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for
+             thresholding.
+
+     [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed
+         Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic
+         text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
+     """
+     # Solver-mode switches.
+     self.predict_x0 = predict_x0
+     self.thresholding = thresholding
+     self.max_val = max_val
+     # Collaborators: the schedule and the wrapped model callable.
+     self.noise_schedule = noise_schedule
+     self.model = model_fn
+ def noise_prediction_fn(self, x, t):
+     """
+     Evaluate the wrapped model at `(x, t)` and return its output unchanged.
+     """
+     prediction = self.model(x, t)
+     return prediction
+ def data_prediction_fn(self, x, t):
+     """
+     Turn the noise prediction at `(x, t)` into a data (x0) prediction, optionally thresholded.
+     When `self.thresholding` is set, each sample is clamped to +/- s and divided by s, where s is the
+     larger of its 0.995 quantile of |x0| and `self.max_val`.
+     """
+     schedule = self.noise_schedule
+     eps = self.noise_prediction_fn(x, t)
+     n_dims = x.dim()
+     alpha_t = schedule.marginal_alpha(t)
+     sigma_t = schedule.marginal_std(t)
+     x0 = (x - expand_dims(sigma_t, n_dims) * eps) / expand_dims(alpha_t, n_dims)
+     if not self.thresholding:
+         return x0
+     quantile = 0.995  # A hyperparameter in the paper of "Imagen" [1].
+     flat_abs = torch.abs(x0).reshape((x0.shape[0], -1))
+     scale = torch.quantile(flat_abs, quantile, dim=1)
+     floor = self.max_val * torch.ones_like(scale).to(scale.device)
+     scale = expand_dims(torch.maximum(scale, floor), n_dims)
+     return torch.clamp(x0, -scale, scale) / scale
+ def model_fn(self, x, t):
+     """
+     Dispatch to the data prediction model when `self.predict_x0` is set, else to the noise prediction model.
+     """
+     predictor = self.data_prediction_fn if self.predict_x0 else self.noise_prediction_fn
+     return predictor(x, t)
+ def get_time_steps(self, skip_type, t_T, t_0, N, device):
+     """Compute the intermediate time steps for sampling.
+
+     Args:
+         skip_type: A `str`. The type for the spacing of the time steps. We support three types:
+             - 'logSNR': uniform logSNR for the time steps.
+             - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
+             - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
+         t_T: A `float`. The starting time of the sampling (default is T).
+         t_0: A `float`. The ending time of the sampling (default is epsilon).
+         N: A `int`. The total number of the spacing of the time steps.
+         device: A torch device.
+     Returns:
+         A pytorch tensor of the time steps, with the shape (N + 1,).
+     Raises:
+         ValueError: If `skip_type` is not one of the three supported values.
+     """
+     n_points = N + 1
+     if skip_type == 'time_uniform':
+         return torch.linspace(t_T, t_0, n_points).to(device)
+     if skip_type == 'time_quadratic':
+         t_order = 2
+         root_grid = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), n_points)
+         return root_grid.pow(t_order).to(device)
+     if skip_type == 'logSNR':
+         schedule = self.noise_schedule
+         lambda_T = schedule.marginal_lambda(torch.tensor(t_T).to(device))
+         lambda_0 = schedule.marginal_lambda(torch.tensor(t_0).to(device))
+         # Uniform grid in lambda, mapped back to time through the schedule's inverse.
+         lambda_grid = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), n_points).to(device)
+         return schedule.inverse_lambda(lambda_grid)
+     raise ValueError(
+         "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
+ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
+     """
+     Get the order of each step for sampling by the singlestep DPM-Solver.
+     We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast".
+     Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
+         - If order == 1:
+             We take `steps` of DPM-Solver-1 (i.e. DDIM).
+         - If order == 2:
+             - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
+             - If steps % 2 == 0, we use K steps of DPM-Solver-2.
+             - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
+         - If order == 3:
+             - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
+             - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step
+               of DPM-Solver-1.
+             - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
+             - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
+     ============================================
+     Args:
+         order: A `int`. The max order for the solver (1, 2 or 3).
+         steps: A `int`. The total number of function evaluations (NFE).
+         skip_type: A `str`. The type for the spacing of the time steps. We support three types:
+             - 'logSNR': uniform logSNR for the time steps.
+             - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
+             - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
+         t_T: A `float`. The starting time of the sampling (default is T).
+         t_0: A `float`. The ending time of the sampling (default is epsilon).
+         device: A torch device.
+     Returns:
+         timesteps_outer: A pytorch tensor with len(orders) + 1 entries, the boundaries of each solver step.
+         orders: A list of the solver order of each step.
+     Raises:
+         ValueError: If `order` is not 1, 2 or 3.
+     """
+     # K is the number of outer solver steps and always equals len(orders).
+     if order == 3:
+         K = steps // 3 + 1
+         remainder = steps % 3
+         if remainder == 0:
+             orders = [3] * (K - 2) + [2, 1]
+         elif remainder == 1:
+             orders = [3] * (K - 1) + [1]
+         else:
+             orders = [3] * (K - 1) + [2]
+     elif order == 2:
+         if steps % 2 == 0:
+             K = steps // 2
+             orders = [2] * K
+         else:
+             K = steps // 2 + 1
+             orders = [2] * (K - 1) + [1]
+     elif order == 1:
+         # One outer step per function evaluation, so the logSNR grid below gets steps + 1 points.
+         K = steps
+         orders = [1] * steps
+     else:
+         raise ValueError("'order' must be '1' or '2' or '3'.")
+     if skip_type == 'logSNR':
+         # To reproduce the results in DPM-Solver paper
+         timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
+     else:
+         # Pick the boundaries of each outer step from the fine grid of `steps` intervals;
+         # torch.cumsum requires an explicit dimension.
+         outer_index = torch.cumsum(torch.tensor([0] + orders), dim=0).to(device)
+         timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[outer_index]
+     return timesteps_outer, orders
+ def denoise_to_zero_fn(self, x, s):
+     """
+     Final denoising step: return the data prediction at `(x, s)`.
+     This is equivalent to solving the ODE from lambda_s to infty by first-order discretization.
+     """
+     x0_estimate = self.data_prediction_fn(x, s)
+     return x0_estimate
+ def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
+     """
+     DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
+
+     Args:
+         x: A pytorch tensor. The initial value at time `s`.
+         s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+         t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+         model_s: A pytorch tensor. The model function evaluated at time `s`.
+             If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+         return_intermediate: A `bool`. If true, also return the model value at time `s`.
+     Returns:
+         x_t: A pytorch tensor. The approximated solution at time `t`. With `return_intermediate`, a tuple
+             `(x_t, {'model_s': model_s})`.
+     """
+     schedule = self.noise_schedule
+     n_dims = x.dim()
+     lambda_s, lambda_t = schedule.marginal_lambda(s), schedule.marginal_lambda(t)
+     # Step size in half-logSNR.
+     h = lambda_t - lambda_s
+     log_alpha_s, log_alpha_t = schedule.marginal_log_mean_coeff(s), schedule.marginal_log_mean_coeff(t)
+     sigma_s, sigma_t = schedule.marginal_std(s), schedule.marginal_std(t)
+     # Both parameterizations need the model value at `s`; reuse the caller's if given.
+     if model_s is None:
+         model_s = self.model_fn(x, s)
+     if self.predict_x0:
+         alpha_t = torch.exp(log_alpha_t)
+         decay = expand_dims(sigma_t / sigma_s, n_dims)
+         drive = expand_dims(alpha_t * torch.expm1(-h), n_dims)
+     else:
+         decay = expand_dims(torch.exp(log_alpha_t - log_alpha_s), n_dims)
+         drive = expand_dims(sigma_t * torch.expm1(h), n_dims)
+     x_t = (
+         decay * x
+         - drive * model_s
+     )
+     if return_intermediate:
+         return x_t, {'model_s': model_s}
+     return x_t
+ def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False,
+ solver_type='dpm_solver'):
+ """
+ Singlestep solver DPM-Solver-2 from time `s` to time `t`.
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ r1: A `float`. The hyperparameter of the second-order solver. `None` falls back to 0.5.
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+ return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ Raises:
+ ValueError: If `solver_type` is neither 'dpm_solver' nor 'taylor'.
+ """
+ if solver_type not in ['dpm_solver', 'taylor']:
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
+ if r1 is None:
+ r1 = 0.5
+ ns = self.noise_schedule
+ dims = x.dim()
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+ # h is the full step in half-logSNR; s1 is the time whose lambda lies a fraction r1 of the way along it.
+ h = lambda_t - lambda_s
+ lambda_s1 = lambda_s + r1 * h
+ s1 = ns.inverse_lambda(lambda_s1)
+ log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(
+ s1), ns.marginal_log_mean_coeff(t)
+ sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
+ alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
+ if self.predict_x0:
+ # Data-prediction parameterization: coefficients are built from expm1 of the negated step.
+ phi_11 = torch.expm1(-r1 * h)
+ phi_1 = torch.expm1(-h)
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ # First-order step from s to the intermediate time s1, then a second model evaluation there.
+ x_s1 = (
+ expand_dims(sigma_s1 / sigma_s, dims) * x
+ - expand_dims(alpha_s1 * phi_11, dims) * model_s
+ )
+ model_s1 = self.model_fn(x_s1, s1)
+ # Full step from s to t: first-order term plus a correction from (model_s1 - model_s).
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s)
+ )
+ elif solver_type == 'taylor':
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * (
+ model_s1 - model_s)
+ )
+ else:
+ # Noise-prediction parameterization: same structure with expm1 of the positive step.
+ phi_11 = torch.expm1(r1 * h)
+ phi_1 = torch.expm1(h)
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ x_s1 = (
+ expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
+ - expand_dims(sigma_s1 * phi_11, dims) * model_s
+ )
+ model_s1 = self.model_fn(x_s1, s1)
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s)
+ )
+ elif solver_type == 'taylor':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s)
+ )
+ # The intermediate model values are handed back so callers can reuse them.
+ if return_intermediate:
+ return x_t, {'model_s': model_s, 'model_s1': model_s1}
+ else:
+ return x_t
+ def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None,
+ return_intermediate=False, solver_type='dpm_solver'):
+ """
+ Singlestep solver DPM-Solver-3 from time `s` to time `t`.
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ r1: A `float`. The hyperparameter of the third-order solver. `None` falls back to 1/3.
+ r2: A `float`. The hyperparameter of the third-order solver. `None` falls back to 2/3.
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+ model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
+ If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ Raises:
+ ValueError: If `solver_type` is neither 'dpm_solver' nor 'taylor'.
+ """
+ if solver_type not in ['dpm_solver', 'taylor']:
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
+ if r1 is None:
+ r1 = 1. / 3.
+ if r2 is None:
+ r2 = 2. / 3.
+ ns = self.noise_schedule
+ dims = x.dim()
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+ # h is the full step in half-logSNR; s1 and s2 are the times at fractions r1 and r2 of that step.
+ h = lambda_t - lambda_s
+ lambda_s1 = lambda_s + r1 * h
+ lambda_s2 = lambda_s + r2 * h
+ s1 = ns.inverse_lambda(lambda_s1)
+ s2 = ns.inverse_lambda(lambda_s2)
+ log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(
+ s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
+ sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(
+ s2), ns.marginal_std(t)
+ alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
+ if self.predict_x0:
+ # Data-prediction parameterization: coefficients are built from expm1 of the negated step.
+ phi_11 = torch.expm1(-r1 * h)
+ phi_12 = torch.expm1(-r2 * h)
+ phi_1 = torch.expm1(-h)
+ phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.
+ phi_2 = phi_1 / h + 1.
+ phi_3 = phi_2 / h - 0.5
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ # The step to s1 is only needed when the caller did not supply the model value there.
+ if model_s1 is None:
+ x_s1 = (
+ expand_dims(sigma_s1 / sigma_s, dims) * x
+ - expand_dims(alpha_s1 * phi_11, dims) * model_s
+ )
+ model_s1 = self.model_fn(x_s1, s1)
+ # Step from s to s2, corrected with the difference (model_s1 - model_s), then evaluate the model at s2.
+ x_s2 = (
+ expand_dims(sigma_s2 / sigma_s, dims) * x
+ - expand_dims(alpha_s2 * phi_12, dims) * model_s
+ + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s)
+ )
+ model_s2 = self.model_fn(x_s2, s2)
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s)
+ )
+ elif solver_type == 'taylor':
+ # D1 and D2 combine the two scaled differences D1_0 and D1_1 for the phi_2 and phi_3 terms.
+ D1_0 = (1. / r1) * (model_s1 - model_s)
+ D1_1 = (1. / r2) * (model_s2 - model_s)
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ + expand_dims(alpha_t * phi_2, dims) * D1
+ - expand_dims(alpha_t * phi_3, dims) * D2
+ )
+ else:
+ # Noise-prediction parameterization: same structure with expm1 of the positive step.
+ phi_11 = torch.expm1(r1 * h)
+ phi_12 = torch.expm1(r2 * h)
+ phi_1 = torch.expm1(h)
+ phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.
+ phi_2 = phi_1 / h - 1.
+ phi_3 = phi_2 / h - 0.5
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ if model_s1 is None:
+ x_s1 = (
+ expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
+ - expand_dims(sigma_s1 * phi_11, dims) * model_s
+ )
+ model_s1 = self.model_fn(x_s1, s1)
+ x_s2 = (
+ expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x
+ - expand_dims(sigma_s2 * phi_12, dims) * model_s
+ - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s)
+ )
+ model_s2 = self.model_fn(x_s2, s2)
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s)
+ )
+ elif solver_type == 'taylor':
+ D1_0 = (1. / r1) * (model_s1 - model_s)
+ D1_1 = (1. / r2) * (model_s2 - model_s)
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ - expand_dims(sigma_t * phi_2, dims) * D1
+ - expand_dims(sigma_t * phi_3, dims) * D2
+ )
+ # The intermediate model values are handed back so callers can reuse them.
+ if return_intermediate:
+ return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2}
+ else:
+ return x_t
+ def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"):
+     """
+     Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
+     Only the two most recent entries of `model_prev_list` and `t_prev_list` are read, so a
+     longer history (for example the three-entry history kept by a third-order run) is accepted.
+     Args:
+         x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
+         model_prev_list: A list of pytorch tensor. The previous computed model values, oldest first.
+         t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
+         t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+         solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+             The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+     Returns:
+         x_t: A pytorch tensor. The approximated solution at time `t`.
+     Raises:
+         ValueError: if `solver_type` is not supported, or fewer than two previous values are supplied.
+     """
+     if solver_type not in ['dpm_solver', 'taylor']:
+         raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
+     if len(model_prev_list) < 2 or len(t_prev_list) < 2:
+         raise ValueError(
+             "The second-order multistep update needs at least 2 previous values, got {} model values and {} times".format(
+                 len(model_prev_list), len(t_prev_list)))
+     ns = self.noise_schedule
+     dims = x.dim()
+     # Index from the end instead of unpacking: the history may hold more than two entries.
+     model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1]
+     t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1]
+     lambda_prev_1 = ns.marginal_lambda(t_prev_1)
+     lambda_prev_0 = ns.marginal_lambda(t_prev_0)
+     lambda_t = ns.marginal_lambda(t)
+     log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
+     sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
+     alpha_t = torch.exp(log_alpha_t)
+     h_0 = lambda_prev_0 - lambda_prev_1
+     h = lambda_t - lambda_prev_0
+     r0 = h_0 / h
+     # Finite-difference estimate of the first derivative of the model output.
+     D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
+     if self.predict_x0:
+         if solver_type == 'dpm_solver':
+             x_t = (
+                 expand_dims(sigma_t / sigma_prev_0, dims) * x
+                 - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
+                 - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0
+             )
+         elif solver_type == 'taylor':
+             x_t = (
+                 expand_dims(sigma_t / sigma_prev_0, dims) * x
+                 - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
+                 + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0
+             )
+     else:
+         if solver_type == 'dpm_solver':
+             x_t = (
+                 expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+                 - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
+                 - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0
+             )
+         elif solver_type == 'taylor':
+             x_t = (
+                 expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+                 - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
+                 - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0
+             )
+     return x_t
+ def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'):
+     """
+     Third-order multistep DPM-Solver step from time `t_prev_list[-1]` to time `t`.
+     Args:
+         x: A pytorch tensor. The current value, at time `t_prev_list[-1]`.
+         model_prev_list: A list of exactly three pytorch tensors, the previous model values (oldest first).
+         t_prev_list: A list of exactly three pytorch tensors, the matching times,
+             each with the shape (x.shape[0],).
+         t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+         solver_type: accepted for signature parity with the other updates; it is not read in this method.
+     Returns:
+         x_t: A pytorch tensor. The approximated solution at time `t`.
+     """
+     schedule = self.noise_schedule
+     ndim = x.dim()
+     model_2, model_1, model_0 = model_prev_list
+     time_2, time_1, time_0 = t_prev_list
+     lam_2 = schedule.marginal_lambda(time_2)
+     lam_1 = schedule.marginal_lambda(time_1)
+     lam_0 = schedule.marginal_lambda(time_0)
+     lam_t = schedule.marginal_lambda(t)
+     log_alpha_0 = schedule.marginal_log_mean_coeff(time_0)
+     log_alpha_t = schedule.marginal_log_mean_coeff(t)
+     sigma_0 = schedule.marginal_std(time_0)
+     sigma_t = schedule.marginal_std(t)
+     alpha_t = torch.exp(log_alpha_t)
+     # Step sizes in logSNR, and their ratios to the current step `h`.
+     step_1 = lam_1 - lam_2
+     step_0 = lam_0 - lam_1
+     h = lam_t - lam_0
+     r0 = step_0 / h
+     r1 = step_1 / h
+     # Divided differences of the stored model outputs.
+     slope_0 = expand_dims(1. / r0, ndim) * (model_0 - model_1)
+     slope_1 = expand_dims(1. / r1, ndim) * (model_1 - model_2)
+     first_order = slope_0 + expand_dims(r0 / (r0 + r1), ndim) * (slope_0 - slope_1)
+     second_order = expand_dims(1. / (r0 + r1), ndim) * (slope_0 - slope_1)
+     if self.predict_x0:
+         return (
+             expand_dims(sigma_t / sigma_0, ndim) * x
+             - expand_dims(alpha_t * (torch.exp(-h) - 1.), ndim) * model_0
+             + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), ndim) * first_order
+             - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), ndim) * second_order
+         )
+     return (
+         expand_dims(torch.exp(log_alpha_t - log_alpha_0), ndim) * x
+         - expand_dims(sigma_t * (torch.exp(h) - 1.), ndim) * model_0
+         - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), ndim) * first_order
+         - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), ndim) * second_order
+     )
+ def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None,
+                                  r2=None):
+     """
+     Dispatch one singlestep DPM-Solver step of the requested `order` from time `s` to time `t`.
+     Args:
+         x: A pytorch tensor. The initial value at time `s`.
+         s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+         t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+         order: A `int`. 1, 2 or 3.
+         return_intermediate: A `bool`. Forwarded unchanged to the selected update.
+         solver_type: either 'dpm_solver' or 'taylor'. Forwarded to the order-2 and order-3 updates.
+         r1: A `float`. Hyperparameter forwarded to the order-2 and order-3 updates.
+         r2: A `float`. Hyperparameter forwarded to the order-3 update.
+     Returns:
+         Whatever the selected update returns.
+     Raises:
+         ValueError: if `order` is not 1, 2 or 3.
+     """
+     forwarded = {'return_intermediate': return_intermediate}
+     if order == 1:
+         return self.dpm_solver_first_update(x, s, t, **forwarded)
+     # Higher orders additionally take the solver type and the r1 hyperparameter.
+     forwarded['solver_type'] = solver_type
+     forwarded['r1'] = r1
+     if order == 2:
+         return self.singlestep_dpm_solver_second_update(x, s, t, **forwarded)
+     if order == 3:
+         return self.singlestep_dpm_solver_third_update(x, s, t, r2=r2, **forwarded)
+     raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
+ def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'):
+     """
+     Dispatch one multistep DPM-Solver step of the requested `order` from time `t_prev_list[-1]` to time `t`.
+     Args:
+         x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
+         model_prev_list: A list of pytorch tensor. The previous computed model values.
+         t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
+         t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+         order: A `int`. 1, 2 or 3.
+         solver_type: either 'dpm_solver' or 'taylor'. Forwarded to the order-2 and order-3 updates.
+     Returns:
+         x_t: the value returned by the selected update.
+     Raises:
+         ValueError: if `order` is not 1, 2 or 3.
+     """
+     if order == 1:
+         # First order only needs the most recent time and model value.
+         latest_t = t_prev_list[-1]
+         latest_model = model_prev_list[-1]
+         return self.dpm_solver_first_update(x, latest_t, t, model_s=latest_model)
+     if order == 2:
+         higher_order_step = self.multistep_dpm_solver_second_update
+     elif order == 3:
+         higher_order_step = self.multistep_dpm_solver_third_update
+     else:
+         raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
+     return higher_order_step(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
+ def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5,
+                         solver_type='dpm_solver'):
+     """
+     Adaptive step size solver that pairs a lower-order and a higher-order singlestep DPM-Solver.
+     Each iteration proposes a step of size `h` in logSNR, compares the two solutions, accepts the
+     higher-order one when the scaled error is at most 1, and then rescales `h`.
+     Args:
+         x: A pytorch tensor. The initial value at time `t_T`.
+         order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
+         t_T: A `float`. The starting time of the sampling (default is T).
+         t_0: A `float`. The ending time of the sampling (default is epsilon).
+         h_init: A `float`. The initial step size (for logSNR).
+         atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1].
+         rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
+         theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1].
+         t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
+             current time and `t_0` is less than `t_err`. The default setting is 1e-5.
+         solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+             The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+     Returns:
+         x_0: A pytorch tensor. The approximated solution at time `t_0`.
+     Raises:
+         ValueError: if `order` is neither 2 nor 3.
+     [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
+     """
+     schedule = self.noise_schedule
+     s = t_T * torch.ones((x.shape[0],)).to(x)
+     lambda_s = schedule.marginal_lambda(s)
+     lambda_0 = schedule.marginal_lambda(t_0 * torch.ones_like(s).to(x))
+     h = h_init * torch.ones_like(s).to(x)
+     x_prev = x
+     nfe = 0
+     if order == 2:
+         r1 = 0.5
+         def lower_update(x, s, t):
+             return self.dpm_solver_first_update(x, s, t, return_intermediate=True)
+         def higher_update(x, s, t, **kwargs):
+             return self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs)
+     elif order == 3:
+         r1, r2 = 1. / 3., 2. / 3.
+         def lower_update(x, s, t):
+             return self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, return_intermediate=True,
+                                                             solver_type=solver_type)
+         def higher_update(x, s, t, **kwargs):
+             return self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type,
+                                                            **kwargs)
+     else:
+         raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
+     def rms_per_sample(v):
+         # Root-mean-square over everything except the leading axis.
+         flattened = v.reshape((v.shape[0], -1))
+         return torch.sqrt(torch.square(flattened).mean(dim=-1, keepdim=True))
+     while torch.abs((s - t_0)).mean() > t_err:
+         t = schedule.inverse_lambda(lambda_s + h)
+         # The lower-order step also returns model evaluations that the higher-order step reuses.
+         x_lower, lower_noise_kwargs = lower_update(x, s, t)
+         x_higher = higher_update(x, s, t, **lower_noise_kwargs)
+         absolute_floor = torch.ones_like(x).to(x) * atol
+         relative_scale = rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev))
+         delta = torch.max(absolute_floor, relative_scale)
+         E = rms_per_sample((x_higher - x_lower) / delta).max()
+         if torch.all(E <= 1.):
+             # Accept the proposal and advance the current time.
+             x, s, x_prev = x_higher, t, x_lower
+             lambda_s = schedule.marginal_lambda(s)
+         # The step size is rescaled whether or not the proposal was accepted.
+         h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
+         nfe += order
+     print('adaptive solver nfe', nfe)
+     return x
+ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
+            method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
+            atol=0.0078, rtol=0.05,
+            ):
+     """
+     Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
+     =====================================================
+     We support the following algorithms for both noise prediction model and data prediction model:
+         - 'singlestep':
+             Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
+             We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
+             The total number of function evaluations (NFE) == `steps`.
+             Given a fixed NFE == `steps`, the sampling procedure is:
+                 - If `order` == 1:
+                     - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
+                 - If `order` == 2:
+                     - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
+                     - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
+                     - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
+                 - If `order` == 3:
+                     - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
+                     - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
+                     - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
+                     - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
+         - 'multistep':
+             Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
+             We initialize the first `order` values by lower order multistep solvers.
+             Given a fixed NFE == `steps`, the sampling procedure is:
+                 Denote K = steps.
+                 - If `order` == 1:
+                     - We use K steps of DPM-Solver-1 (i.e. DDIM).
+                 - If `order` == 2:
+                     - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2.
+                 - If `order` == 3:
+                     - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3.
+         - 'singlestep_fixed':
+             Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
+             We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
+         - 'adaptive':
+             Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
+             We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
+             You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computation costs
+             (NFE) and the sample quality.
+                 - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
+                 - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
+     =====================================================
+     Some advice for choosing the algorithm:
+         - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
+             Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`.
+             e.g.
+             >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False)
+             >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
+                     skip_type='time_uniform', method='singlestep')
+         - For **guided sampling with large guidance scale** by DPMs:
+             Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`.
+             e.g.
+             >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True)
+             >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
+                     skip_type='time_uniform', method='multistep')
+     We support three types of `skip_type`:
+         - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images**
+         - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**.
+         - 'time_quadratic': quadratic time for the time steps.
+     =====================================================
+     Args:
+         x: A pytorch tensor. The initial value at time `t_start`
+             e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
+         steps: A `int`. The total number of function evaluations (NFE).
+         t_start: A `float`. The starting time of the sampling.
+             If `t_start` is None, we use self.noise_schedule.T (default is 1.0).
+         t_end: A `float`. The ending time of the sampling.
+             If `t_end` is None, we use 1. / self.noise_schedule.total_N.
+             e.g. if total_N == 1000, we have `t_end` == 1e-3.
+             For discrete-time DPMs:
+                 - We recommend `t_end` == 1. / self.noise_schedule.total_N.
+             For continuous-time DPMs:
+                 - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
+         order: A `int`. The order of DPM-Solver.
+         skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
+         method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
+         denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
+             Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
+             This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and
+             score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID
+             for diffusion models sampling by diffusion SDEs for low-resolutional images
+             (such as CIFAR-10). However, we observed that such trick does not matter for
+             high-resolutional images. As it needs an additional NFE, we do not recommend
+             it for high-resolutional images.
+         lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
+             Only valid for `method=multistep` and `steps < 15`. We empirically find that
+             this trick is a key to stabilizing the sampling by DPM-Solver with very few steps
+             (especially for steps <= 10). So we recommend to set it to be `True`.
+         solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`.
+         atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
+         rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
+     Returns:
+         x_end: A pytorch tensor. The approximated solution at time `t_end`.
+     Raises:
+         ValueError: if `method` is not one of the four supported names.
+     """
+     t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
+     t_T = self.noise_schedule.T if t_start is None else t_start
+     device = x.device
+     if method == 'adaptive':
+         with torch.no_grad():
+             x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol,
+                                          solver_type=solver_type)
+     elif method == 'multistep':
+         assert steps >= order
+         timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
+         assert timesteps.shape[0] - 1 == steps
+         with torch.no_grad():
+             vec_t = timesteps[0].expand((x.shape[0]))
+             model_prev_list = [self.model_fn(x, vec_t)]
+             t_prev_list = [vec_t]
+             # Init the first `order` values by lower order multistep DPM-Solver.
+             for init_order in tqdm(range(1, order), desc="DPM init order"):
+                 vec_t = timesteps[init_order].expand(x.shape[0])
+                 x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order,
+                                                      solver_type=solver_type)
+                 model_prev_list.append(self.model_fn(x, vec_t))
+                 t_prev_list.append(vec_t)
+             # Compute the remaining values by `order`-th order multistep DPM-Solver.
+             for step in tqdm(range(order, steps + 1), desc="DPM multistep"):
+                 vec_t = timesteps[step].expand(x.shape[0])
+                 if lower_order_final and steps < 15:
+                     step_order = min(order, steps + 1 - step)
+                 else:
+                     step_order = order
+                 x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order,
+                                                      solver_type=solver_type)
+                 # Shift the history window one slot towards the past.
+                 for i in range(order - 1):
+                     t_prev_list[i] = t_prev_list[i + 1]
+                     model_prev_list[i] = model_prev_list[i + 1]
+                 t_prev_list[-1] = vec_t
+                 # We do not need to evaluate the final model value.
+                 if step < steps:
+                     model_prev_list[-1] = self.model_fn(x, vec_t)
+     elif method in ['singlestep', 'singlestep_fixed']:
+         if method == 'singlestep':
+             timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order,
+                                                                                           skip_type=skip_type,
+                                                                                           t_T=t_T, t_0=t_0,
+                                                                                           device=device)
+         elif method == 'singlestep_fixed':
+             K = steps // order
+             orders = [order, ] * K
+             timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
+         # A separate loop name keeps the `order` argument from being overwritten.
+         for i, inner_order in enumerate(orders):
+             t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1]
+             timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(),
+                                                   N=inner_order, device=device)
+             lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
+             vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0])
+             h = lambda_inner[-1] - lambda_inner[0]
+             r1 = None if inner_order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
+             r2 = None if inner_order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
+             x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, inner_order, solver_type=solver_type, r1=r1, r2=r2)
+     else:
+         # Previously an unknown method silently returned the input untouched.
+         raise ValueError(
+             "'method' must be 'singlestep', 'multistep', 'singlestep_fixed' or 'adaptive', got {}".format(method))
+     if denoise_to_zero:
+         x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
+     return x
+# other utility functions
+def interpolate_fn(x, xp, yp):
+    """
+    Evaluate the piecewise linear function through the keypoints (xp, yp) at the query points x.
+    Only differentiable tensor operations are used. Queries outside the keypoint range are
+    extrapolated along the outermost segment.
+    Args:
+        x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
+        xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
+        yp: PyTorch tensor with shape [C, K].
+    Returns:
+        The function values f(x), with shape [N, C].
+    """
+    batch, num_keys = x.shape[0], xp.shape[1]
+    # Put each query in front of its channel's keypoints and sort them together.
+    merged = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((batch, 1, 1))], dim=2)
+    merged_sorted, source_index = torch.sort(merged, dim=2)
+    # The query was element 0 before sorting, so argmin of the permutation is its sorted rank.
+    rank = torch.argmin(source_index, dim=2)
+    below = rank - 1
+    is_leftmost = torch.eq(rank, 0)
+    is_rightmost = torch.eq(rank, num_keys)
+    # Shared by both lookups: use the last segment when the query lies right of every keypoint.
+    clamped_right = torch.where(is_rightmost, torch.tensor(num_keys - 2, device=x.device), below)
+    lo_pos = torch.where(is_leftmost, torch.tensor(1, device=x.device), clamped_right)
+    # Step over the query itself when it sits between the two bracketing keypoints.
+    hi_pos = torch.where(torch.eq(lo_pos, below), lo_pos + 2, lo_pos + 1)
+    lo_x = torch.gather(merged_sorted, dim=2, index=lo_pos.unsqueeze(2)).squeeze(2)
+    hi_x = torch.gather(merged_sorted, dim=2, index=hi_pos.unsqueeze(2)).squeeze(2)
+    # Indices into yp refer to the keypoints alone (no query inserted).
+    lo_key = torch.where(is_leftmost, torch.tensor(0, device=x.device), clamped_right)
+    tiled_y = yp.unsqueeze(0).expand(batch, -1, -1)
+    lo_y = torch.gather(tiled_y, dim=2, index=lo_key.unsqueeze(2)).squeeze(2)
+    hi_y = torch.gather(tiled_y, dim=2, index=(lo_key + 1).unsqueeze(2)).squeeze(2)
+    return lo_y + (x - lo_x) * (hi_y - lo_y) / (hi_x - lo_x)
+def expand_dims(v, dims):
+    """
+    Append trailing singleton axes to `v` so that a [N] tensor ends up with `dims` axes.
+    Args:
+        `v`: a PyTorch tensor with shape [N].
+        `dims`: a `int`. The total number of axes of the result.
+    Returns:
+        a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
+    """
+    trailing_axes = (None,) * (dims - 1)
+    return v[(Ellipsis,) + trailing_axes]
\ No newline at end of file
diff --git a/ControlNet/ldm/models/diffusion/dpm_solver/sampler.py b/ControlNet/ldm/models/diffusion/dpm_solver/sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d137b8cf36718c1c58faa09f9dd919e5fb2977b
--- /dev/null
+++ b/ControlNet/ldm/models/diffusion/dpm_solver/sampler.py
@@ -0,0 +1,87 @@
+import torch
+from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
+# Maps the wrapped model's `parameterization` attribute to the `model_type` passed to `model_wrapper`.
+MODEL_TYPES = {
+    "eps": "noise",
+    "v": "v"
+}
+class DPMSolverSampler(object):
+    """Sampler that drives a diffusion model with the multistep DPM-Solver."""
+    def __init__(self, model, **kwargs):
+        """
+        Store `model` and cache a float32 copy of `model.alphas_cumprod` on `model.device`.
+        Extra keyword arguments are accepted and not used.
+        """
+        super().__init__()
+        self.model = model
+        to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
+        self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
+    def register_buffer(self, name, attr):
+        """
+        Set `attr` as attribute `name` on this sampler.
+        Tensors are first moved to `self.model.device`, so the sampler also works on hosts
+        without CUDA and its buffers stay on the same device as the model.
+        """
+        if isinstance(attr, torch.Tensor):
+            attr = attr.to(self.model.device)
+        setattr(self, name, attr)
+    @torch.no_grad()
+    def sample(self,
+               S,
+               batch_size,
+               shape,
+               conditioning=None,
+               callback=None,
+               normals_sequence=None,
+               img_callback=None,
+               quantize_x0=False,
+               eta=0.,
+               mask=None,
+               x0=None,
+               temperature=1.,
+               noise_dropout=0.,
+               score_corrector=None,
+               corrector_kwargs=None,
+               verbose=True,
+               x_T=None,
+               log_every_t=100,
+               unconditional_guidance_scale=1.,
+               unconditional_conditioning=None,
+               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+               **kwargs
+               ):
+        """
+        Draw `batch_size` samples with `S` steps of second-order multistep DPM-Solver.
+        Args:
+            S: number of solver steps.
+            batch_size: number of samples to generate.
+            shape: (C, H, W) of one sample.
+            conditioning: condition passed to the model; a tensor or a dict of tensors.
+                A batch-size mismatch only prints a warning.
+            x_T: optional starting tensor; when None, standard normal noise is drawn.
+            unconditional_guidance_scale: classifier-free guidance scale.
+            unconditional_conditioning: condition used for the unconditional branch.
+            The remaining parameters are not referenced in this method.
+        Returns:
+            A tuple (samples, None).
+        """
+        if conditioning is not None:
+            if isinstance(conditioning, dict):
+                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+                if cbs != batch_size:
+                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+            else:
+                if conditioning.shape[0] != batch_size:
+                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+        # sampling
+        C, H, W = shape
+        size = (batch_size, C, H, W)
+        print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
+        device = self.model.betas.device
+        if x_T is None:
+            img = torch.randn(size, device=device)
+        else:
+            img = x_T
+        ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
+        model_fn = model_wrapper(
+            lambda x, t, c: self.model.apply_model(x, t, c),
+            ns,
+            model_type=MODEL_TYPES[self.model.parameterization],
+            guidance_type="classifier-free",
+            condition=conditioning,
+            unconditional_condition=unconditional_conditioning,
+            guidance_scale=unconditional_guidance_scale,
+        )
+        dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
+        x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
+        return x.to(device), None
\ No newline at end of file
diff --git a/ControlNet/ldm/models/diffusion/plms.py b/ControlNet/ldm/models/diffusion/plms.py
new file mode 100644
index 0000000000000000000000000000000000000000..7002a365d27168ced0a04e9a4d83e088f8284eae
--- /dev/null
+++ b/ControlNet/ldm/models/diffusion/plms.py
@@ -0,0 +1,244 @@
+import torch
+import numpy as np
+from tqdm import tqdm
+from functools import partial
+from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
+from ldm.models.diffusion.sampling_util import norm_thresholding
+class PLMSSampler(object):
+    """Sampler implementing the pseudo linear multistep (PLMS) update on a DDIM timestep schedule."""
+    def __init__(self, model, schedule="linear", **kwargs):
+        """Store `model`, its `num_timesteps` and the schedule name; extra keyword arguments are not used."""
+        super().__init__()
+        self.model = model
+        self.ddpm_num_timesteps = model.num_timesteps
+        self.schedule = schedule
+    def register_buffer(self, name, attr):
+        """
+        Set `attr` as attribute `name` on this sampler.
+        Tensors are first moved to `self.model.device`, so the sampler also works on hosts
+        without CUDA and its buffers stay on the same device as the model.
+        """
+        if isinstance(attr, torch.Tensor):
+            attr = attr.to(self.model.device)
+        setattr(self, name, attr)
+    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+        """
+        Build the DDIM timestep subset and register every schedule-derived buffer on the sampler.
+        Raises:
+            ValueError: if `ddim_eta` is not 0.
+        """
+        if ddim_eta != 0:
+            raise ValueError('ddim_eta must be 0 for PLMS')
+        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+                                                  num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
+        alphas_cumprod = self.model.alphas_cumprod
+        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+        self.register_buffer('betas', to_torch(self.model.betas))
+        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
+        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+        # ddim sampling parameters
+        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
+                                                                                   ddim_timesteps=self.ddim_timesteps,
+                                                                                   eta=ddim_eta,verbose=verbose)
+        self.register_buffer('ddim_sigmas', ddim_sigmas)
+        self.register_buffer('ddim_alphas', ddim_alphas)
+        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+                    1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+    @torch.no_grad()
+    def sample(self,
+               S,
+               batch_size,
+               shape,
+               conditioning=None,
+               callback=None,
+               normals_sequence=None,
+               img_callback=None,
+               quantize_x0=False,
+               eta=0.,
+               mask=None,
+               x0=None,
+               temperature=1.,
+               noise_dropout=0.,
+               score_corrector=None,
+               corrector_kwargs=None,
+               verbose=True,
+               x_T=None,
+               log_every_t=100,
+               unconditional_guidance_scale=1.,
+               unconditional_conditioning=None,
+               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+               dynamic_threshold=None,
+               **kwargs
+               ):
+        """
+        Build an `S`-step schedule and run PLMS sampling for `batch_size` samples of `shape` (C, H, W).
+        A conditioning batch-size mismatch only prints a warning.
+        Returns:
+            A tuple (samples, intermediates) as produced by `plms_sampling`.
+        """
+        if conditioning is not None:
+            if isinstance(conditioning, dict):
+                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+                if cbs != batch_size:
+                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+            else:
+                if conditioning.shape[0] != batch_size:
+                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+        # sampling
+        C, H, W = shape
+        size = (batch_size, C, H, W)
+        print(f'Data shape for PLMS sampling is {size}')
+        samples, intermediates = self.plms_sampling(conditioning, size,
+                                                    callback=callback,
+                                                    img_callback=img_callback,
+                                                    quantize_denoised=quantize_x0,
+                                                    mask=mask, x0=x0,
+                                                    ddim_use_original_steps=False,
+                                                    noise_dropout=noise_dropout,
+                                                    temperature=temperature,
+                                                    score_corrector=score_corrector,
+                                                    corrector_kwargs=corrector_kwargs,
+                                                    x_T=x_T,
+                                                    log_every_t=log_every_t,
+                                                    unconditional_guidance_scale=unconditional_guidance_scale,
+                                                    unconditional_conditioning=unconditional_conditioning,
+                                                    dynamic_threshold=dynamic_threshold,
+                                                    )
+        return samples, intermediates
+    @torch.no_grad()
+    def plms_sampling(self, cond, shape,
+                      x_T=None, ddim_use_original_steps=False,
+                      callback=None, timesteps=None, quantize_denoised=False,
+                      mask=None, x0=None, img_callback=None, log_every_t=100,
+                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                      unconditional_guidance_scale=1., unconditional_conditioning=None,
+                      dynamic_threshold=None):
+        """
+        Run the PLMS loop over the timestep schedule, from the largest timestep down.
+        Returns:
+            A tuple (img, intermediates): the final sample and a dict with the lists
+            'x_inter' and 'pred_x0' collected every `log_every_t` indices.
+        """
+        device = self.model.betas.device
+        b = shape[0]
+        if x_T is None:
+            img = torch.randn(shape, device=device)
+        else:
+            img = x_T
+        if timesteps is None:
+            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+        elif timesteps is not None and not ddim_use_original_steps:
+            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+            timesteps = self.ddim_timesteps[:subset_end]
+        intermediates = {'x_inter': [img], 'pred_x0': [img]}
+        time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
+        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+        print(f"Running PLMS Sampling with {total_steps} timesteps")
+        iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
+        # Model outputs of earlier steps; at most three are kept.
+        old_eps = []
+        for i, step in enumerate(iterator):
+            index = total_steps - i - 1
+            ts = torch.full((b,), step, device=device, dtype=torch.long)
+            ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
+            if mask is not None:
+                assert x0 is not None
+                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
+                img = img_orig * mask + (1. - mask) * img
+            outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+                                      quantize_denoised=quantize_denoised, temperature=temperature,
+                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
+                                      corrector_kwargs=corrector_kwargs,
+                                      unconditional_guidance_scale=unconditional_guidance_scale,
+                                      unconditional_conditioning=unconditional_conditioning,
+                                      old_eps=old_eps, t_next=ts_next,
+                                      dynamic_threshold=dynamic_threshold)
+            img, pred_x0, e_t = outs
+            old_eps.append(e_t)
+            if len(old_eps) >= 4:
+                old_eps.pop(0)
+            if callback: callback(i)
+            if img_callback: img_callback(pred_x0, i)
+            if index % log_every_t == 0 or index == total_steps - 1:
+                intermediates['x_inter'].append(img)
+                intermediates['pred_x0'].append(pred_x0)
+        return img, intermediates
+    @torch.no_grad()
+    def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                      unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
+                      dynamic_threshold=None):
+        """
+        Perform one PLMS step at schedule position `index`.
+        The model output is combined with up to three earlier outputs from `old_eps`; with an
+        empty history a second model evaluation at `t_next` is averaged in instead.
+        Returns:
+            A tuple (x_prev, pred_x0, e_t) where `e_t` is the model output at (x, t).
+        """
+        b, *_, device = *x.shape, x.device
+        def get_model_output(x, t):
+            if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+                e_t = self.model.apply_model(x, t, c)
+            else:
+                # Evaluate the unconditional and conditional branches in one batched call.
+                x_in = torch.cat([x] * 2)
+                t_in = torch.cat([t] * 2)
+                c_in = torch.cat([unconditional_conditioning, c])
+                e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+                e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+            if score_corrector is not None:
+                assert self.model.parameterization == "eps"
+                e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+            return e_t
+        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+        def get_x_prev_and_pred_x0(e_t, index):
+            # select parameters corresponding to the currently considered timestep
+            a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+            a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+            sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+            sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+            # current prediction for x_0
+            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+            if quantize_denoised:
+                pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+            if dynamic_threshold is not None:
+                pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
+            # direction pointing to x_t
+            dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+            noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+            if noise_dropout > 0.:
+                noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+            x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+            return x_prev, pred_x0
+        e_t = get_model_output(x, t)
+        if len(old_eps) == 0:
+            # Pseudo Improved Euler (2nd order)
+            x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
+            e_t_next = get_model_output(x_prev, t_next)
+            e_t_prime = (e_t + e_t_next) / 2
+        elif len(old_eps) == 1:
+            # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
+            e_t_prime = (3 * e_t - old_eps[-1]) / 2
+        elif len(old_eps) == 2:
+            # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
+            e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
+        elif len(old_eps) >= 3:
+            # 4th order Pseudo Linear Multistep (Adams-Bashforth)
+            e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
+        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
+        return x_prev, pred_x0, e_t
diff --git a/ControlNet/ldm/models/diffusion/sampling_util.py b/ControlNet/ldm/models/diffusion/sampling_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..7eff02be6d7c54d43ee6680636ac0698dd3b3f33
--- /dev/null
+++ b/ControlNet/ldm/models/diffusion/sampling_util.py
@@ -0,0 +1,22 @@
+import torch
+import numpy as np
+def append_dims(x, target_dims):
+    """Pad ``x`` with trailing singleton axes until it has ``target_dims`` dimensions.
+
+    Adapted from k-diffusion:
+    https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py
+
+    Raises:
+        ValueError: if ``x`` already has more than ``target_dims`` dimensions.
+    """
+    missing = target_dims - x.ndim
+    if missing < 0:
+        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
+    # Indexing with ``None`` inserts a new axis of size 1; the Ellipsis keeps all existing axes.
+    trailing_axes = (None,) * missing
+    return x[(...,) + trailing_axes]
+def norm_thresholding(x0, value):
+    """Scale each sample of ``x0`` down so its root-mean-square does not exceed ``value``.
+
+    The RMS is taken over everything except the first axis. Samples whose RMS is already
+    at or below ``value`` are multiplied by 1 (the clamp makes the divisor equal ``value``).
+    """
+    per_sample_rms = x0.pow(2).flatten(1).mean(1).sqrt()
+    divisor = append_dims(per_sample_rms.clamp(min=value), x0.ndim)
+    return x0 * (value / divisor)
+def spatial_norm_thresholding(x0, value):
+    """Like ``norm_thresholding`` but the RMS is taken over axis 1 only, per remaining position.
+
+    The original comment documents the expected layout of ``x0`` as ``b c h w``.
+    """
+    rms_over_axis1 = x0.pow(2).mean(1, keepdim=True).sqrt()
+    divisor = rms_over_axis1.clamp(min=value)
+    return x0 * (value / divisor)
\ No newline at end of file
diff --git a/ControlNet/ldm/modules/attention.py b/ControlNet/ldm/modules/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..509cd873768f0dd75a75ab3fcdd652822b12b59f
--- /dev/null
+++ b/ControlNet/ldm/modules/attention.py
@@ -0,0 +1,341 @@
+from inspect import isfunction
+import math
+import torch
+import torch.nn.functional as F
+from torch import nn, einsum
+from einops import rearrange, repeat
+from typing import Optional, Any
+from ldm.modules.diffusionmodules.util import checkpoint
+ import xformers
+ import xformers.ops
+# CrossAttn precision handling
+import os
+_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
+def exists(val):
+    """Return ``True`` for any value other than ``None`` (falsy values such as 0 still count)."""
+    if val is None:
+        return False
+    return True
+def uniq(arr):
+    """Return a keys view of the distinct elements of ``arr`` in first-seen order."""
+    seen = dict.fromkeys(arr, True)
+    return seen.keys()
+def default(val, d):
+    """Return ``val`` unless it is ``None``; otherwise return the fallback ``d``.
+
+    If ``d`` is a plain Python function it is called (lazily) to produce the fallback.
+    """
+    if val is None:
+        return d() if isfunction(d) else d
+    return val
+def max_neg_value(t):
+    """Return the most negative finite value representable in ``t``'s floating dtype."""
+    limits = torch.finfo(t.dtype)
+    return -limits.max
+def init_(tensor):
+    """Fill ``tensor`` in place with U(-1/sqrt(d), 1/sqrt(d)), d being its last-axis size.
+
+    Returns the same tensor object.
+    """
+    bound = 1 / math.sqrt(tensor.shape[-1])
+    # ``uniform_`` mutates in place and returns the tensor it was called on.
+    return tensor.uniform_(-bound, bound)
+# feedforward
+class GEGLU(nn.Module):
+    """GELU-gated linear unit: one linear layer yields a value half and a gate half."""
+
+    def __init__(self, dim_in, dim_out):
+        super().__init__()
+        # Twice the output width: first half is the value, second half the gate.
+        self.proj = nn.Linear(dim_in, dim_out * 2)
+
+    def forward(self, x):
+        value, gate = self.proj(x).chunk(2, dim=-1)
+        return value * F.gelu(gate)
+class FeedForward(nn.Module):
+    """Position-wise MLP: input projection (GELU or GEGLU), dropout, output projection."""
+
+    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
+        super().__init__()
+        hidden = int(dim * mult)
+        dim_out = default(dim_out, dim)
+        if glu:
+            project_in = GEGLU(dim, hidden)
+        else:
+            project_in = nn.Sequential(nn.Linear(dim, hidden), nn.GELU())
+
+        self.net = nn.Sequential(
+            project_in,
+            nn.Dropout(dropout),
+            nn.Linear(hidden, dim_out),
+        )
+
+    def forward(self, x):
+        return self.net(x)
+def zero_module(module):
+    """
+    Set every parameter of ``module`` to zero in place and hand the module back.
+    """
+    for param in module.parameters():
+        # detach() shares storage with the parameter, so zero_() clears the real weights
+        # without being recorded by autograd.
+        param.detach().zero_()
+    return module
+def Normalize(in_channels):
+    """Build the affine 32-group ``GroupNorm`` (eps=1e-6) used by the attention blocks."""
+    return torch.nn.GroupNorm(32, in_channels, eps=1e-6, affine=True)
+class SpatialSelfAttention(nn.Module):
+    """Single-head self-attention over the spatial positions of a ``b c h w`` feature map.
+
+    Queries, keys, values and the output projection are 1x1 convolutions; the result is
+    added back to the input (residual connection).
+    """
+
+    def __init__(self, in_channels):
+        super().__init__()
+        self.in_channels = in_channels
+
+        def pointwise_conv():
+            return torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+
+        self.norm = Normalize(in_channels)
+        self.q = pointwise_conv()
+        self.k = pointwise_conv()
+        self.v = pointwise_conv()
+        self.proj_out = pointwise_conv()
+
+    def forward(self, x):
+        normed = self.norm(x)
+        q = self.q(normed)
+        k = self.k(normed)
+        v = self.v(normed)
+
+        # attention weights between every pair of spatial positions
+        b, c, h, w = q.shape
+        q = rearrange(q, 'b c h w -> b (h w) c')
+        k = rearrange(k, 'b c h w -> b c (h w)')
+        scores = torch.einsum('bij,bjk->bik', q, k)
+        scores = scores * (int(c)**(-0.5))
+        weights = torch.nn.functional.softmax(scores, dim=2)
+
+        # weighted sum of the values
+        v = rearrange(v, 'b c h w -> b c (h w)')
+        weights = rearrange(weights, 'b i j -> b j i')
+        attended = torch.einsum('bij,bjk->bik', v, weights)
+        attended = rearrange(attended, 'b c (h w) -> b c h w', h=h)
+        attended = self.proj_out(attended)
+
+        return x + attended
+class CrossAttention(nn.Module):
+    """Multi-head attention of ``x`` over ``context``; attends to ``x`` itself without context."""
+
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+        super().__init__()
+        context_dim = default(context_dim, query_dim)
+        inner_dim = dim_head * heads
+
+        self.scale = dim_head ** -0.5
+        self.heads = heads
+
+        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+
+    def forward(self, x, context=None, mask=None):
+        n_heads = self.heads
+
+        q = self.to_q(x)
+        context = default(context, x)
+        k = self.to_k(context)
+        v = self.to_v(context)
+
+        # fold the heads into the batch axis
+        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=n_heads) for t in (q, k, v))
+
+        if _ATTN_PRECISION == "fp32":
+            # force cast to fp32 to avoid overflowing
+            with torch.autocast(enabled=False, device_type='cuda'):
+                q, k = q.float(), k.float()
+                sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+        else:
+            sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+        del q, k
+
+        if exists(mask):
+            mask = rearrange(mask, 'b ... -> b (...)')
+            fill_value = -torch.finfo(sim.dtype).max
+            mask = repeat(mask, 'b j -> (b h) () j', h=n_heads)
+            # positions where the mask is False get the most negative finite score
+            sim.masked_fill_(~mask, fill_value)
+
+        # attention, what we cannot get enough of
+        sim = sim.softmax(dim=-1)
+
+        out = einsum('b i j, b j d -> b i d', sim, v)
+        out = rearrange(out, '(b h) n d -> b n (h d)', h=n_heads)
+        return self.to_out(out)
+class MemoryEfficientCrossAttention(nn.Module):
+    """Multi-head (cross-)attention computed with ``xformers.ops.memory_efficient_attention``.
+
+    Attention masks are not supported: passing ``mask`` raises ``NotImplementedError``.
+    """
+    # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
+        super().__init__()
+        print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
+              f"{heads} heads.")
+        inner_dim = dim_head * heads
+        context_dim = default(context_dim, query_dim)
+
+        self.heads = heads
+        self.dim_head = dim_head
+
+        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+        # Forwarded as ``op=`` to xformers; None leaves the operator choice to xformers.
+        self.attention_op: Optional[Any] = None
+
+    def forward(self, x, context=None, mask=None):
+        # Reject the unsupported argument up front. Previously the check ran only after the
+        # full attention had been computed, so the work was wasted before the error surfaced.
+        if exists(mask):
+            raise NotImplementedError
+
+        q = self.to_q(x)
+        context = default(context, x)
+        k = self.to_k(context)
+        v = self.to_v(context)
+
+        b, _, _ = q.shape
+        # (b, n, heads*dim_head) -> (b*heads, n, dim_head)
+        q, k, v = map(
+            lambda t: t.unsqueeze(3)
+            .reshape(b, t.shape[1], self.heads, self.dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b * self.heads, t.shape[1], self.dim_head)
+            .contiguous(),
+            (q, k, v),
+        )
+
+        # actually compute the attention, what we cannot get enough of
+        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+
+        # (b*heads, n, dim_head) -> (b, n, heads*dim_head)
+        out = (
+            out.unsqueeze(0)
+            .reshape(b, self.heads, out.shape[1], self.dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b, out.shape[1], self.heads * self.dim_head)
+        )
+        return self.to_out(out)
+class BasicTransformerBlock(nn.Module):
+    """Pre-norm transformer block: attention, (cross-)attention, feed-forward, each with a residual.
+
+    ``attn1`` is self-attention unless ``disable_self_attn`` is set, in which case it attends
+    to ``context`` as well. ``attn2`` attends to ``context`` (or to ``x`` when context is None).
+    """
+    # Restored the ``ATTENTION_MODES = {`` opener: without it the two entries below were a
+    # dangling dict body and ``self.ATTENTION_MODES`` (used in __init__) was undefined.
+    ATTENTION_MODES = {
+        "softmax": CrossAttention,  # vanilla attention
+        "softmax-xformers": MemoryEfficientCrossAttention
+    }
+
+    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
+                 disable_self_attn=False):
+        super().__init__()
+        # XFORMERS_IS_AVAILBLE (sic) is a module-level flag defined outside this class.
+        attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
+        assert attn_mode in self.ATTENTION_MODES
+        attn_cls = self.ATTENTION_MODES[attn_mode]
+        self.disable_self_attn = disable_self_attn
+        self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
+                              context_dim=context_dim if self.disable_self_attn else None)  # is a self-attention if not self.disable_self_attn
+        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+        self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
+                              heads=n_heads, dim_head=d_head, dropout=dropout)  # is self-attn if context is none
+        self.norm1 = nn.LayerNorm(dim)
+        self.norm2 = nn.LayerNorm(dim)
+        self.norm3 = nn.LayerNorm(dim)
+        # Flag handed to the module-level ``checkpoint`` helper in forward(); the parameter
+        # name only shadows that helper inside __init__.
+        self.checkpoint = checkpoint
+
+    def forward(self, x, context=None):
+        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
+
+    def _forward(self, x, context=None):
+        x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
+        x = self.attn2(self.norm2(x), context=context) + x
+        x = self.ff(self.norm3(x)) + x
+        return x
+class SpatialTransformer(nn.Module):
+    """
+    Transformer block for image-like data.
+    First, project the input (aka embedding)
+    and reshape to b, t, d.
+    Then apply standard transformer action.
+    Finally, reshape to image
+    NEW: use_linear for more efficiency instead of the 1x1 convs
+
+    ``context_dim`` may be None (every block then self-attends), a single value, or a list
+    with one entry per block. A single value or one-element list is shared by all ``depth``
+    blocks. Likewise, a single context passed to ``forward`` is reused for every block.
+    """
+
+    def __init__(self, in_channels, n_heads, d_head,
+                 depth=1, dropout=0., context_dim=None,
+                 disable_self_attn=False, use_linear=False,
+                 use_checkpoint=True):
+        super().__init__()
+        # Normalise context_dim to a list with one entry per block. Previously None was left
+        # as-is and ``context_dim[d]`` below raised TypeError; a single value with depth > 1
+        # raised IndexError.
+        if not isinstance(context_dim, list):
+            context_dim = [context_dim]
+        if len(context_dim) == 1 and depth > 1:
+            context_dim = context_dim * depth
+        self.in_channels = in_channels
+        inner_dim = n_heads * d_head
+        self.norm = Normalize(in_channels)
+        if not use_linear:
+            self.proj_in = nn.Conv2d(in_channels,
+                                     inner_dim,
+                                     kernel_size=1,
+                                     stride=1,
+                                     padding=0)
+        else:
+            self.proj_in = nn.Linear(in_channels, inner_dim)
+
+        self.transformer_blocks = nn.ModuleList(
+            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
+                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
+             for d in range(depth)]
+        )
+        if not use_linear:
+            self.proj_out = zero_module(nn.Conv2d(inner_dim,
+                                                  in_channels,
+                                                  kernel_size=1,
+                                                  stride=1,
+                                                  padding=0))
+        else:
+            # Maps inner_dim back to in_channels, mirroring the conv branch. The arguments
+            # were previously swapped, which only worked when in_channels == inner_dim
+            # (weight shape is unchanged in that case).
+            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
+        self.use_linear = use_linear
+
+    def forward(self, x, context=None):
+        # note: if no context is given, cross-attention defaults to self-attention
+        if not isinstance(context, list):
+            context = [context]
+        b, c, h, w = x.shape
+        x_in = x
+        x = self.norm(x)
+        # conv projection acts on b c h w, linear projection on the flattened b (h w) c
+        if not self.use_linear:
+            x = self.proj_in(x)
+        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
+        if self.use_linear:
+            x = self.proj_in(x)
+        for i, block in enumerate(self.transformer_blocks):
+            # a single context is shared by all blocks (previously IndexError for depth > 1)
+            block_context = context[0] if len(context) == 1 else context[i]
+            x = block(x, context=block_context)
+        if self.use_linear:
+            x = self.proj_out(x)
+        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
+        if not self.use_linear:
+            x = self.proj_out(x)
+        return x + x_in
diff --git a/ControlNet/ldm/modules/diffusionmodules/__init__.py b/ControlNet/ldm/modules/diffusionmodules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ControlNet/ldm/modules/diffusionmodules/model.py b/ControlNet/ldm/modules/diffusionmodules/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..b089eebbe1676d8249005bb9def002ff5180715b
--- /dev/null
+++ b/ControlNet/ldm/modules/diffusionmodules/model.py
@@ -0,0 +1,852 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import rearrange
+from typing import Optional, Any
+from ldm.modules.attention import MemoryEfficientCrossAttention
+ import xformers
+ import xformers.ops
+ print("No module 'xformers'. Proceeding without it.")
+def get_timestep_embedding(timesteps, embedding_dim):
+    """
+    Build sinusoidal timestep embeddings (DDPM / Fairseq / tensor2tensor variant, which
+    differs slightly from Section 3.5 of "Attention Is All You Need").
+
+    ``timesteps`` must be 1-D. The result has shape (len(timesteps), embedding_dim): the
+    sine half first, then the cosine half, plus one zero column when embedding_dim is odd.
+    """
+    assert len(timesteps.shape) == 1
+
+    half_dim = embedding_dim // 2
+    log_step = math.log(10000) / (half_dim - 1)
+    freqs = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -log_step)
+    freqs = freqs.to(device=timesteps.device)
+    angles = timesteps.float()[:, None] * freqs[None, :]
+    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
+    if embedding_dim % 2 == 1:  # zero pad
+        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+    return emb
+def nonlinearity(x):
+    """Swish activation: ``x * sigmoid(x)``."""
+    gate = torch.sigmoid(x)
+    return x * gate
+def Normalize(in_channels, num_groups=32):
+    """Build an affine ``GroupNorm`` over ``in_channels`` channels with eps=1e-6."""
+    return torch.nn.GroupNorm(num_groups, in_channels, eps=1e-6, affine=True)
+class Upsample(nn.Module):
+    """Nearest-neighbour 2x upsampling, optionally followed by a 3x3 convolution."""
+
+    def __init__(self, in_channels, with_conv):
+        super().__init__()
+        self.with_conv = with_conv
+        if with_conv:
+            self.conv = torch.nn.Conv2d(in_channels, in_channels,
+                                        kernel_size=3, stride=1, padding=1)
+
+    def forward(self, x):
+        upsampled = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+        if not self.with_conv:
+            return upsampled
+        return self.conv(upsampled)
+class Downsample(nn.Module):
+    """2x downsampling: a stride-2 3x3 convolution, or 2x2 average pooling."""
+
+    def __init__(self, in_channels, with_conv):
+        super().__init__()
+        self.with_conv = with_conv
+        if with_conv:
+            # no asymmetric padding in torch conv, must do it ourselves (see forward)
+            self.conv = torch.nn.Conv2d(in_channels, in_channels,
+                                        kernel_size=3, stride=2, padding=0)
+
+    def forward(self, x):
+        if not self.with_conv:
+            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+        # zero-pad one column on the right and one row at the bottom only
+        padded = torch.nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
+        return self.conv(padded)
+class ResnetBlock(nn.Module):
+    """Residual block: (norm, swish, 3x3 conv) twice, with an optional timestep-embedding bias.
+
+    When the channel count changes, the skip path goes through a 3x3 convolution
+    (``conv_shortcut=True``) or a 1x1 convolution (default).
+    """
+
+    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
+                 dropout, temb_channels=512):
+        super().__init__()
+        if out_channels is None:
+            out_channels = in_channels
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.use_conv_shortcut = conv_shortcut
+
+        self.norm1 = Normalize(in_channels)
+        self.conv1 = torch.nn.Conv2d(in_channels, out_channels,
+                                     kernel_size=3, stride=1, padding=1)
+        # the embedding projection only exists when temb_channels is positive
+        if temb_channels > 0:
+            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+        self.norm2 = Normalize(out_channels)
+        self.dropout = torch.nn.Dropout(dropout)
+        self.conv2 = torch.nn.Conv2d(out_channels, out_channels,
+                                     kernel_size=3, stride=1, padding=1)
+        if self.in_channels != self.out_channels:
+            if self.use_conv_shortcut:
+                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels,
+                                                     kernel_size=3, stride=1, padding=1)
+            else:
+                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels,
+                                                    kernel_size=1, stride=1, padding=0)
+
+    def forward(self, x, temb):
+        h = self.conv1(nonlinearity(self.norm1(x)))
+
+        if temb is not None:
+            # broadcast the per-sample embedding over the two trailing axes
+            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
+
+        h = nonlinearity(self.norm2(h))
+        h = self.conv2(self.dropout(h))
+
+        skip = x
+        if self.in_channels != self.out_channels:
+            shortcut = self.conv_shortcut if self.use_conv_shortcut else self.nin_shortcut
+            skip = shortcut(x)
+
+        return skip + h
+class AttnBlock(nn.Module):
+    """Single-head spatial self-attention with 1x1-conv projections and a residual connection."""
+
+    def __init__(self, in_channels):
+        super().__init__()
+        self.in_channels = in_channels
+
+        def pointwise_conv():
+            return torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+
+        self.norm = Normalize(in_channels)
+        self.q = pointwise_conv()
+        self.k = pointwise_conv()
+        self.v = pointwise_conv()
+        self.proj_out = pointwise_conv()
+
+    def forward(self, x):
+        normed = self.norm(x)
+        q = self.q(normed)
+        k = self.k(normed)
+        v = self.v(normed)
+
+        # compute attention
+        b, c, h, w = q.shape
+        n = h * w
+        q = q.reshape(b, c, n).permute(0, 2, 1)   # b,hw,c
+        k = k.reshape(b, c, n)                    # b,c,hw
+        scores = torch.bmm(q, k)                  # b,hw,hw: scores[b,i,j] = sum_c q[b,i,c] k[b,c,j]
+        scores = scores * (int(c)**(-0.5))
+        weights = torch.nn.functional.softmax(scores, dim=2)
+
+        # attend to values
+        v = v.reshape(b, c, n)
+        weights = weights.permute(0, 2, 1)        # b,hw,hw (first hw of k, second of q)
+        attended = torch.bmm(v, weights)          # b,c,hw (hw of q)
+        attended = attended.reshape(b, c, h, w)
+
+        return x + self.proj_out(attended)
+class MemoryEfficientAttnBlock(nn.Module):
+    """
+    Spatial self-attention block backed by xformers' memory-efficient attention,
+    see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+    Note: this is a single-head self-attention operation
+    """
+
+    def __init__(self, in_channels):
+        super().__init__()
+        self.in_channels = in_channels
+
+        def pointwise_conv():
+            return torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+
+        self.norm = Normalize(in_channels)
+        self.q = pointwise_conv()
+        self.k = pointwise_conv()
+        self.v = pointwise_conv()
+        self.proj_out = pointwise_conv()
+        # Forwarded as ``op=`` to xformers in forward().
+        self.attention_op: Optional[Any] = None
+
+    def forward(self, x):
+        normed = self.norm(x)
+        q = self.q(normed)
+        k = self.k(normed)
+        v = self.v(normed)
+
+        # compute attention
+        B, C, H, W = q.shape
+
+        def to_tokens(t):
+            # b c h w -> b (h w) c, then the single-head reshuffle used by the xformers call
+            t = rearrange(t, 'b c h w -> b (h w) c')
+            return (
+                t.unsqueeze(3)
+                .reshape(B, t.shape[1], 1, C)
+                .permute(0, 2, 1, 3)
+                .reshape(B * 1, t.shape[1], C)
+                .contiguous()
+            )
+
+        q, k, v = to_tokens(q), to_tokens(k), to_tokens(v)
+        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+
+        out = (
+            out.unsqueeze(0)
+            .reshape(B, 1, out.shape[1], C)
+            .permute(0, 2, 1, 3)
+            .reshape(B, out.shape[1], C)
+        )
+        out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
+        return x + self.proj_out(out)
+class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
+    """Apply ``MemoryEfficientCrossAttention`` to a ``b c h w`` feature map, with a residual."""
+
+    def forward(self, x, context=None, mask=None):
+        b, c, h, w = x.shape
+        # Keep the image-shaped input for the residual. Previously ``x`` was overwritten by
+        # its flattened (b, h*w, c) form and then added to the (b, c, h, w) output, which
+        # mismatches shapes.
+        tokens = rearrange(x, 'b c h w -> b (h w) c')
+        out = super().forward(tokens, context=context, mask=mask)
+        out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c)
+        return x + out
+def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
+    """Build the attention module selected by ``attn_type`` for ``in_channels`` channels.
+
+    "vanilla" is upgraded to "vanilla-xformers" when the module-level
+    XFORMERS_IS_AVAILBLE flag is set. ``attn_kwargs`` is only used by
+    "memory-efficient-cross-attn" and must be None for "vanilla".
+
+    Raises:
+        AssertionError: unknown ``attn_type``, or ``attn_kwargs`` given with "vanilla".
+        NotImplementedError: ``attn_type`` is accepted but has no implementation here
+            (currently "linear").
+    """
+    assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
+    if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
+        attn_type = "vanilla-xformers"
+    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+    if attn_type == "vanilla":
+        assert attn_kwargs is None
+        return AttnBlock(in_channels)
+    elif attn_type == "vanilla-xformers":
+        print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
+        return MemoryEfficientAttnBlock(in_channels)
+    elif attn_type == "memory-efficient-cross-attn":
+        # FIX: this branch compared the builtin ``type`` to the string, which is never
+        # equal, so the option always fell through to NotImplementedError.
+        # Copy the kwargs so the caller's dict is not mutated (also tolerates None).
+        kwargs = dict(attn_kwargs or {})
+        kwargs["query_dim"] = in_channels
+        return MemoryEfficientCrossAttentionWrapper(**kwargs)
+    elif attn_type == "none":
+        return nn.Identity(in_channels)
+    else:
+        raise NotImplementedError()
+class Model(nn.Module):
+ # U-Net-style network: a downsampling path of ResnetBlocks (with optional attention),
+ # a middle stage, and an upsampling path that concatenates the stored downsampling
+ # activations (skip connections). Optionally conditioned on a timestep embedding.
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
+ super().__init__()
+ # NOTE(review): make_attn in this file has no branch for "linear", so
+ # use_linear_attn=True appears to end in NotImplementedError - confirm.
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ # width of the timestep embedding fed to every ResnetBlock
+ self.temb_ch = self.ch*4
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.use_timestep = use_timestep
+ if self.use_timestep:
+ # timestep embedding
+ self.temb = nn.Module()
+ self.temb.dense = nn.ModuleList([
+ torch.nn.Linear(self.ch,
+ self.temb_ch),
+ torch.nn.Linear(self.temb_ch,
+ self.temb_ch),
+ ])
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ # curr_res tracks the spatial resolution so attention is added only at the
+ # resolutions listed in attn_resolutions
+ curr_res = resolution
+ # in_ch_mult[i] is the channel multiplier entering level i (1 for the first level)
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ # every level except the last halves the resolution
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ # skip_in is the channel count of the skip activation concatenated in forward()
+ skip_in = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ if i_block == self.num_res_blocks:
+ skip_in = ch*in_ch_mult[i_level]
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ # Run the network on x; t is required when use_timestep is set. A given context is
+ # concatenated to x along axis 1 before the first convolution.
+ def forward(self, x, t=None, context=None):
+ #assert x.shape[2] == x.shape[3] == self.resolution
+ if context is not None:
+ # assume aligned context, cat along channel axis
+ x = torch.cat((x, context), dim=1)
+ if self.use_timestep:
+ # timestep embedding
+ assert t is not None
+ temb = get_timestep_embedding(t, self.ch)
+ temb = self.temb.dense[0](temb)
+ temb = nonlinearity(temb)
+ temb = self.temb.dense[1](temb)
+ else:
+ temb = None
+ # downsampling
+ # hs collects every intermediate activation; the upsampling path pops them in reverse
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](
+ torch.cat([h, hs.pop()], dim=1), temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+ # Return the weight tensor of the final output convolution.
+ def get_last_layer(self):
+ return self.conv_out.weight
+class Encoder(nn.Module):
+ # Convolutional encoder: conv_in, a downsampling stack of ResnetBlocks (with attention
+ # at the resolutions in attn_resolutions), a middle stage, and a final norm/swish/conv.
+ # No timestep conditioning: temb_ch is 0 and forward() passes temb=None to every block.
+ # out_ch is accepted but not read in this constructor; extra keywords are ignored.
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
+ **ignore_kwargs):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ curr_res = resolution
+ # in_ch_mult[i] is the channel multiplier entering level i (1 for the first level)
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.in_ch_mult = in_ch_mult
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ # every level except the last halves the resolution
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ # end
+ self.norm_out = Normalize(block_in)
+ # double_z doubles the output channels (2*z_channels instead of z_channels)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ 2*z_channels if double_z else z_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ # Encode x; only the last activation of the downsampling stack is carried forward.
+ def forward(self, x):
+ # timestep embedding
+ temb = None
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+class Decoder(nn.Module):
+ # Convolutional decoder: conv_in from z_channels, a middle stage, then an upsampling
+ # stack of ResnetBlocks (with attention at the resolutions in attn_resolutions) and a
+ # final norm/swish/conv to out_ch channels. No timestep conditioning (temb_ch is 0).
+ # in_channels is stored but not otherwise used here; extra keywords are ignored.
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
+ attn_type="vanilla", **ignorekwargs):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+ self.tanh_out = tanh_out
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1,)+tuple(ch_mult)
+ block_in = ch*ch_mult[self.num_resolutions-1]
+ curr_res = resolution // 2**(self.num_resolutions-1)
+ # z_shape is only used for the informational print below in this class
+ self.z_shape = (1,z_channels,curr_res,curr_res)
+ print("Working with z of shape {} = {} dimensions.".format(
+ self.z_shape, np.prod(self.z_shape)))
+ # z to block_in
+ self.conv_in = torch.nn.Conv2d(z_channels,
+ block_in,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ # every level except level 0 doubles the resolution
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ # Decode z. With give_pre_end the features before norm_out/conv_out are returned;
+ # with tanh_out the final output is passed through tanh.
+ def forward(self, z):
+ #assert z.shape[1:] == self.z_shape[1:]
+ # side effect: remembers the shape of the most recent input on the module
+ self.last_z_shape = z.shape
+ # timestep embedding
+ temb = None
+ # z to block_in
+ h = self.conv_in(z)
+ # middle
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](h, temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+ # end
+ if self.give_pre_end:
+ return h
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ if self.tanh_out:
+ h = torch.tanh(h)
+ return h
+class SimpleDecoder(nn.Module):
+    """Small fixed decoder: 1x1 conv, three ResnetBlocks, 1x1 conv, 2x upsample, output conv."""
+
+    def __init__(self, in_channels, out_channels, *args, **kwargs):
+        super().__init__()
+        layers = [
+            nn.Conv2d(in_channels, in_channels, 1),
+            ResnetBlock(in_channels=in_channels,
+                        out_channels=2 * in_channels,
+                        temb_channels=0, dropout=0.0),
+            ResnetBlock(in_channels=2 * in_channels,
+                        out_channels=4 * in_channels,
+                        temb_channels=0, dropout=0.0),
+            ResnetBlock(in_channels=4 * in_channels,
+                        out_channels=2 * in_channels,
+                        temb_channels=0, dropout=0.0),
+            nn.Conv2d(2*in_channels, in_channels, 1),
+            Upsample(in_channels, with_conv=True),
+        ]
+        self.model = nn.ModuleList(layers)
+        # end
+        self.norm_out = Normalize(in_channels)
+        self.conv_out = torch.nn.Conv2d(in_channels, out_channels,
+                                        kernel_size=3, stride=1, padding=1)
+
+    def forward(self, x):
+        for index, layer in enumerate(self.model):
+            # positions 1..3 hold the ResnetBlocks, which also take a (None) timestep embedding
+            if 1 <= index <= 3:
+                x = layer(x, None)
+            else:
+                x = layer(x)
+
+        h = nonlinearity(self.norm_out(x))
+        return self.conv_out(h)
+class UpsampleDecoder(nn.Module):
+    """Stack of ResnetBlock groups with a 2x ``Upsample`` between consecutive groups."""
+
+    def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
+                 ch_mult=(2,2), dropout=0.0):
+        super().__init__()
+        # upsampling
+        self.temb_ch = 0
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        block_in = in_channels
+        curr_res = resolution // 2 ** (self.num_resolutions - 1)
+        self.res_blocks = nn.ModuleList()
+        self.upsample_blocks = nn.ModuleList()
+        last_level = self.num_resolutions - 1
+        for level, mult in enumerate(ch_mult):
+            block_out = ch * mult
+            group = []
+            for _ in range(self.num_res_blocks + 1):
+                group.append(ResnetBlock(in_channels=block_in,
+                                         out_channels=block_out,
+                                         temb_channels=self.temb_ch,
+                                         dropout=dropout))
+                block_in = block_out
+            self.res_blocks.append(nn.ModuleList(group))
+            if level != last_level:
+                self.upsample_blocks.append(Upsample(block_in, True))
+                curr_res = curr_res * 2
+
+        # end
+        self.norm_out = Normalize(block_in)
+        self.conv_out = torch.nn.Conv2d(block_in, out_channels,
+                                        kernel_size=3, stride=1, padding=1)
+
+    def forward(self, x):
+        # upsampling
+        h = x
+        last_level = self.num_resolutions - 1
+        for level in range(self.num_resolutions):
+            for i_block in range(self.num_res_blocks + 1):
+                h = self.res_blocks[level][i_block](h, None)
+            if level != last_level:
+                h = self.upsample_blocks[level](h)
+        h = nonlinearity(self.norm_out(h))
+        return self.conv_out(h)
+class LatentRescaler(nn.Module):
+    """Rescale a feature map by ``factor``: conv, ResnetBlocks, interpolate, attention,
+    ResnetBlocks, 1x1 conv."""
+
+    def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
+        super().__init__()
+        # residual block, interpolate, residual block
+        self.factor = factor
+
+        def make_res_stack():
+            return nn.ModuleList([ResnetBlock(in_channels=mid_channels,
+                                              out_channels=mid_channels,
+                                              temb_channels=0,
+                                              dropout=0.0) for _ in range(depth)])
+
+        self.conv_in = nn.Conv2d(in_channels, mid_channels,
+                                 kernel_size=3, stride=1, padding=1)
+        self.res_block1 = make_res_stack()
+        self.attn = AttnBlock(mid_channels)
+        self.res_block2 = make_res_stack()
+        self.conv_out = nn.Conv2d(mid_channels, out_channels, kernel_size=1)
+
+    def forward(self, x):
+        x = self.conv_in(x)
+        for block in self.res_block1:
+            x = block(x, None)
+        # target size: both trailing axes scaled by factor and rounded to the nearest int
+        new_h = int(round(x.shape[2]*self.factor))
+        new_w = int(round(x.shape[3]*self.factor))
+        x = torch.nn.functional.interpolate(x, size=(new_h, new_w))
+        x = self.attn(x)
+        for block in self.res_block2:
+            x = block(x, None)
+        return self.conv_out(x)
+class MergedRescaleEncoder(nn.Module):
+    """An ``Encoder`` followed by a ``LatentRescaler`` that maps to ``out_ch`` channels."""
+
+    def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
+                 attn_resolutions, dropout=0.0, resamp_with_conv=True,
+                 ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
+        super().__init__()
+        # the encoder emits ch * ch_mult[-1] channels; the rescaler reduces them to out_ch
+        intermediate_chn = ch * ch_mult[-1]
+        encoder_kwargs = dict(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
+                              z_channels=intermediate_chn, double_z=False, resolution=resolution,
+                              attn_resolutions=attn_resolutions, dropout=dropout,
+                              resamp_with_conv=resamp_with_conv, out_ch=None)
+        self.encoder = Encoder(**encoder_kwargs)
+        self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
+                                       mid_channels=intermediate_chn, out_channels=out_ch,
+                                       depth=rescale_module_depth)
+
+    def forward(self, x):
+        return self.rescaler(self.encoder(x))
+class MergedRescaleDecoder(nn.Module):
+    """A ``LatentRescaler`` followed by a ``Decoder``."""
+
+    def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
+                 dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
+        super().__init__()
+        # the rescaler widens z_channels to tmp_chn, which is what the decoder consumes
+        tmp_chn = z_channels*ch_mult[-1]
+        decoder_kwargs = dict(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions,
+                              dropout=dropout, resamp_with_conv=resamp_with_conv, in_channels=None,
+                              num_res_blocks=num_res_blocks, ch_mult=ch_mult, resolution=resolution, ch=ch)
+        self.decoder = Decoder(**decoder_kwargs)
+        self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
+                                       out_channels=tmp_chn, depth=rescale_module_depth)
+
+    def forward(self, x):
+        return self.decoder(self.rescaler(x))
+class Upsampler(nn.Module):
+ # Upsamples from in_size to out_size with a LatentRescaler followed by a Decoder that
+ # has num_blocks resolution levels, each using the same channel multiplier ch_mult.
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
+ super().__init__()
+ assert out_size >= in_size
+ # one decoder level per power of two in the integer ratio out_size // in_size, plus one
+ num_blocks = int(np.log2(out_size//in_size))+1
+ # NOTE(review): this adds the integer remainder (out_size % in_size) to 1.0; a
+ # fractional residual scale may have been intended - confirm before relying on
+ # sizes where out_size is not a multiple of in_size.
+ factor_up = 1.+ (out_size % in_size)
+ print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
+ self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
+ out_channels=in_channels)
+ self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
+ attn_resolutions=[], in_channels=None, ch=in_channels,
+ ch_mult=[ch_mult for _ in range(num_blocks)])
+ # Apply the rescaler, then the decoder.
+ def forward(self, x):
+ x = self.rescaler(x)
+ x = self.decoder(x)
+ return x
+class Resize(nn.Module):
+    """Resize a feature map by ``scale_factor`` using ``torch.nn.functional.interpolate``.
+
+    Args:
+        in_channels: only relevant to the (unimplemented) learned variant.
+        learned: request learned, convolutional resizing. Not implemented; raises
+            ``NotImplementedError``.
+        mode: interpolation mode passed to ``interpolate``.
+    """
+
+    def __init__(self, in_channels=None, learned=False, mode="bilinear"):
+        super().__init__()
+        self.with_conv = learned
+        self.mode = mode
+        if self.with_conv:
+            # FIX: was ``self.__class__.__name``. Inside a class body that identifier is
+            # name-mangled to ``_Resize__name``, so an AttributeError was raised instead of
+            # the intended NotImplementedError.
+            print(f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode")
+            # The unreachable statements that followed this raise (an ``in_channels``
+            # assertion and a kernel-4, stride-2, padding-1 Conv2d) have been removed.
+            raise NotImplementedError()
+
+    def forward(self, x, scale_factor=1.0):
+        if scale_factor == 1.0:
+            return x
+        return torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False,
+                                               scale_factor=scale_factor)
diff --git a/ControlNet/ldm/modules/diffusionmodules/openaimodel.py b/ControlNet/ldm/modules/diffusionmodules/openaimodel.py
new file mode 100644
index 0000000000000000000000000000000000000000..7df6b5abfe8eff07f0c8e8703ba8aee90d45984b
--- /dev/null
+++ b/ControlNet/ldm/modules/diffusionmodules/openaimodel.py
@@ -0,0 +1,786 @@
+from abc import abstractmethod
+import math
+import numpy as np
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+from ldm.modules.diffusionmodules.util import (
+ checkpoint,
+ conv_nd,
+ linear,
+ avg_pool_nd,
+ zero_module,
+ normalization,
+ timestep_embedding,
+from ldm.modules.attention import SpatialTransformer
+from ldm.util import exists
+# dummy replace
+def convert_module_to_f16(x):
+    """Placeholder that accepts a module and does nothing with it."""
+    return None
+def convert_module_to_f32(x):
+    """Placeholder that accepts a module and does nothing with it."""
+    return None
+## go
+class AttentionPool2d(nn.Module):
+    """
+    Attention pooling: prepend the mean over all positions as an extra token, run QKV
+    attention over the tokens and return the output at that extra token.
+    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
+    """
+    def __init__(
+        self,
+        spacial_dim: int,
+        embed_dim: int,
+        num_heads_channels: int,
+        output_dim: int = None,
+    ):
+        super().__init__()
+        # One embedding column per flattened position plus one for the mean token.
+        num_tokens = spacial_dim ** 2 + 1
+        self.positional_embedding = nn.Parameter(th.randn(embed_dim, num_tokens) / embed_dim ** 0.5)
+        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
+        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
+        self.num_heads = embed_dim // num_heads_channels
+        self.attention = QKVAttention(self.num_heads)
+    def forward(self, x):
+        """Pool ``x`` (first two axes are batch and channels) and return the first token of the result."""
+        batch, channels = x.shape[0], x.shape[1]
+        tokens = x.reshape(batch, channels, -1)  # NC(HW)
+        mean_token = tokens.mean(dim=-1, keepdim=True)
+        tokens = th.cat([mean_token, tokens], dim=-1)  # NC(HW+1)
+        tokens = tokens + self.positional_embedding[None, :, :].to(tokens.dtype)  # NC(HW+1)
+        pooled = self.c_proj(self.attention(self.qkv_proj(tokens)))
+        return pooled[:, :, 0]
+class TimestepBlock(nn.Module):
+    """
+    Marker base class for modules whose forward() receives timestep embeddings
+    as its second positional argument.
+    """
+    @abstractmethod
+    def forward(self, x, emb):
+        """
+        Apply the module to `x`, conditioned on the timestep embeddings `emb`.
+        """
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+    """
+    Sequential container that hands each child the extra input it expects:
+    ``emb`` for TimestepBlock children, ``context`` for SpatialTransformer
+    children, and nothing extra for every other layer.
+    """
+    def forward(self, x, emb, context=None):
+        for layer in self:
+            if isinstance(layer, TimestepBlock):
+                extra = (emb,)
+            elif isinstance(layer, SpatialTransformer):
+                extra = (context,)
+            else:
+                extra = ()
+            x = layer(x, *extra)
+        return x
+class Upsample(nn.Module):
+    """
+    Nearest-neighbour 2x upsampling, optionally followed by a convolution.
+    :param channels: channels in the inputs and outputs.
+    :param use_conv: a bool determining if a convolution is applied.
+    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+                 upsampling occurs in the inner-two dimensions.
+    :param out_channels: output channels of the convolution; defaults to ``channels``.
+    :param padding: padding passed to the convolution.
+    """
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.dims = dims
+        if not use_conv:
+            return
+        self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
+    def forward(self, x):
+        assert x.shape[1] == self.channels
+        if self.dims != 3:
+            upsampled = F.interpolate(x, scale_factor=2, mode="nearest")
+        else:
+            # Keep axis 2 as is; double only the last two axes.
+            target_size = (x.shape[2], x.shape[3] * 2, x.shape[4] * 2)
+            upsampled = F.interpolate(x, target_size, mode="nearest")
+        return self.conv(upsampled) if self.use_conv else upsampled
+class TransposedUpsample(nn.Module):
+    """
+    Learned 2x upsampling without padding: a single stride-2 ``ConvTranspose2d``.
+    :param channels: input channels.
+    :param out_channels: output channels; defaults to ``channels``.
+    :param ks: kernel size of the transposed convolution.
+    """
+    def __init__(self, channels, out_channels=None, ks=5):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels or channels
+        self.up = nn.ConvTranspose2d(
+            self.channels,
+            self.out_channels,
+            kernel_size=ks,
+            stride=2,
+        )
+    def forward(self, x):
+        upsampled = self.up(x)
+        return upsampled
+class Downsample(nn.Module):
+    """
+    Stride-2 downsampling, done either by a convolution or by average pooling.
+    :param channels: channels in the inputs and outputs.
+    :param use_conv: a bool determining if a convolution is applied.
+    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+                 downsampling occurs in the inner-two dimensions.
+    :param out_channels: output channels of the convolution; must equal ``channels``
+                 (or be omitted) when pooling is used.
+    :param padding: padding passed to the convolution.
+    """
+    def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.dims = dims
+        # For 3D signals the first spatial axis keeps stride 1.
+        stride = (1, 2, 2) if dims == 3 else 2
+        if not use_conv:
+            # Pooling cannot change the channel count.
+            assert self.channels == self.out_channels
+            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+        else:
+            self.op = conv_nd(
+                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
+            )
+    def forward(self, x):
+        assert x.shape[1] == self.channels
+        downsampled = self.op(x)
+        return downsampled
+class ResBlock(TimestepBlock):
+    """
+    Timestep-conditioned residual block that can change the channel count and
+    optionally resample its input.
+    :param channels: the number of input channels.
+    :param emb_channels: the number of timestep embedding channels.
+    :param dropout: the rate of dropout.
+    :param out_channels: if specified, the number of out channels.
+    :param use_conv: if True and out_channels is specified, use a spatial
+        convolution instead of a smaller 1x1 convolution to change the
+        channels in the skip connection.
+    :param use_scale_shift_norm: if True, the embedding yields a scale and a shift
+        applied after the output normalization instead of being added to the features.
+    :param dims: determines if the signal is 1D, 2D, or 3D.
+    :param use_checkpoint: if True, use gradient checkpointing on this module.
+    :param up: if True, use this block for upsampling.
+    :param down: if True, use this block for downsampling.
+    """
+    def __init__(
+        self,
+        channels,
+        emb_channels,
+        dropout,
+        out_channels=None,
+        use_conv=False,
+        use_scale_shift_norm=False,
+        dims=2,
+        use_checkpoint=False,
+        up=False,
+        down=False,
+    ):
+        super().__init__()
+        self.channels = channels
+        self.emb_channels = emb_channels
+        self.dropout = dropout
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.use_checkpoint = use_checkpoint
+        self.use_scale_shift_norm = use_scale_shift_norm
+        self.in_layers = nn.Sequential(
+            normalization(channels),
+            nn.SiLU(),
+            conv_nd(dims, channels, self.out_channels, 3, padding=1),
+        )
+        self.updown = up or down
+        # `up` takes precedence over `down` when both are set.
+        if up:
+            resample_cls = Upsample
+        elif down:
+            resample_cls = Downsample
+        else:
+            resample_cls = None
+        if resample_cls is None:
+            self.h_upd = self.x_upd = nn.Identity()
+        else:
+            self.h_upd = resample_cls(channels, False, dims)
+            self.x_upd = resample_cls(channels, False, dims)
+        # Scale-shift conditioning needs two values per output channel.
+        emb_out_channels = 2 * self.out_channels if use_scale_shift_norm else self.out_channels
+        self.emb_layers = nn.Sequential(
+            nn.SiLU(),
+            linear(emb_channels, emb_out_channels),
+        )
+        self.out_layers = nn.Sequential(
+            normalization(self.out_channels),
+            nn.SiLU(),
+            nn.Dropout(p=dropout),
+            zero_module(
+                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
+            ),
+        )
+        if self.out_channels == channels:
+            skip = nn.Identity()
+        else:
+            kernel_size, conv_kwargs = (3, {"padding": 1}) if use_conv else (1, {})
+            skip = conv_nd(dims, channels, self.out_channels, kernel_size, **conv_kwargs)
+        self.skip_connection = skip
+    def forward(self, x, emb):
+        """
+        Run the block on a Tensor, conditioned on a timestep embedding.
+        :param x: an [N x C x ...] Tensor of features.
+        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+        :return: an [N x C x ...] Tensor of outputs.
+        """
+        inputs = (x, emb)
+        return checkpoint(self._forward, inputs, self.parameters(), self.use_checkpoint)
+    def _forward(self, x, emb):
+        if not self.updown:
+            h = self.in_layers(x)
+        else:
+            # Resample between the norm/activation and the convolution, and resample
+            # the skip input the same way.
+            pre_conv = self.in_layers[:-1]
+            conv = self.in_layers[-1]
+            h = self.h_upd(pre_conv(x))
+            x = self.x_upd(x)
+            h = conv(h)
+        emb_out = self.emb_layers(emb).type(h.dtype)
+        # Append trailing singleton axes until the embedding has as many axes as h.
+        missing_axes = max(len(h.shape) - len(emb_out.shape), 0)
+        emb_out = emb_out[(...,) + (None,) * missing_axes]
+        if not self.use_scale_shift_norm:
+            h = self.out_layers(h + emb_out)
+        else:
+            norm = self.out_layers[0]
+            remainder = self.out_layers[1:]
+            scale, shift = th.chunk(emb_out, 2, dim=1)
+            h = remainder(norm(h) * (1 + scale) + shift)
+        return self.skip_connection(x) + h
+class AttentionBlock(nn.Module):
+    """
+    Self-attention over all positions of the input (axes after batch and channel are flattened).
+    Originally ported from here, but adapted to the N-d case.
+    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+    :param channels: channels of the input and output.
+    :param num_heads: number of heads, used when ``num_head_channels`` is -1.
+    :param num_head_channels: channels per head; when not -1 it determines the number of heads.
+    :param use_checkpoint: stored on the module; forward() currently checkpoints regardless.
+    :param use_new_attention_order: select QKVAttention instead of QKVAttentionLegacy.
+    """
+    def __init__(
+        self,
+        channels,
+        num_heads=1,
+        num_head_channels=-1,
+        use_checkpoint=False,
+        use_new_attention_order=False,
+    ):
+        super().__init__()
+        self.channels = channels
+        if num_head_channels != -1:
+            assert (
+                channels % num_head_channels == 0
+            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+            self.num_heads = channels // num_head_channels
+        else:
+            self.num_heads = num_heads
+        self.use_checkpoint = use_checkpoint
+        self.norm = normalization(channels)
+        self.qkv = conv_nd(1, channels, channels * 3, 1)
+        # QKVAttention splits qkv before splitting heads; the legacy variant
+        # splits heads before splitting qkv.
+        attention_cls = QKVAttention if use_new_attention_order else QKVAttentionLegacy
+        self.attention = attention_cls(self.num_heads)
+        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+    def forward(self, x):
+        # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
+        # The checkpoint flag is hard-coded to True here; self.use_checkpoint is not consulted.
+        always_checkpoint = True
+        return checkpoint(self._forward, (x,), self.parameters(), always_checkpoint)
+    def _forward(self, x):
+        batch, channels = x.shape[0], x.shape[1]
+        spatial = x.shape[2:]
+        flat = x.reshape(batch, channels, -1)
+        attended = self.attention(self.qkv(self.norm(flat)))
+        residual = self.proj_out(attended)
+        return (flat + residual).reshape(batch, channels, *spatial)
+def count_flops_attn(model, _x, y):
+    """
+    Operation counter for the `thop` package that accounts for the two matmuls
+    of an attention operation and adds them to ``model.total_ops``.
+    Meant to be used like:
+        macs, params = thop.profile(
+            model,
+            inputs=(inputs, timestamps),
+            custom_ops={QKVAttention: QKVAttention.count_flops},
+        )
+    """
+    out_shape = y[0].shape
+    batch, channels = out_shape[0], out_shape[1]
+    num_spatial = int(np.prod(list(out_shape[2:])))
+    # Two matmuls with the same cost: one builds the weight matrix, the other
+    # combines the value vectors.
+    matmul_ops = 2 * batch * (num_spatial ** 2) * channels
+    model.total_ops += th.DoubleTensor([matmul_ops])
+class QKVAttentionLegacy(nn.Module):
+    """
+    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
+    """
+    def __init__(self, n_heads):
+        super().__init__()
+        self.n_heads = n_heads
+    def forward(self, qkv):
+        """
+        Apply QKV attention.
+        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+        :return: an [N x (H * C) x T] tensor after attention.
+        """
+        bs, width, length = qkv.shape
+        heads = self.n_heads
+        assert width % (3 * heads) == 0
+        ch = width // (3 * heads)
+        # Heads are folded into the batch axis first, then each head's slab is cut into q, k, v.
+        per_head = qkv.reshape(bs * heads, ch * 3, length)
+        q, k, v = per_head.split(ch, dim=1)
+        scale = 1 / math.sqrt(math.sqrt(ch))
+        # Scaling q and k separately is more stable with f16 than dividing afterwards.
+        logits = th.einsum("bct,bcs->bts", q * scale, k * scale)
+        probs = th.softmax(logits.float(), dim=-1).type(logits.dtype)
+        attended = th.einsum("bts,bcs->bct", probs, v)
+        return attended.reshape(bs, -1, length)
+    @staticmethod
+    def count_flops(model, _x, y):
+        return count_flops_attn(model, _x, y)
+class QKVAttention(nn.Module):
+    """
+    A module which performs QKV attention, splitting q/k/v first and heads second.
+    """
+    def __init__(self, n_heads):
+        super().__init__()
+        self.n_heads = n_heads
+    def forward(self, qkv):
+        """
+        Apply QKV attention.
+        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+        :return: an [N x (H * C) x T] tensor after attention.
+        """
+        bs, width, length = qkv.shape
+        heads = self.n_heads
+        assert width % (3 * heads) == 0
+        ch = width // (3 * heads)
+        head_shape = (bs * heads, ch, length)
+        q, k, v = qkv.chunk(3, dim=1)
+        scale = 1 / math.sqrt(math.sqrt(ch))
+        # Scaling q and k separately is more stable with f16 than dividing afterwards.
+        q_heads = (q * scale).view(*head_shape)
+        k_heads = (k * scale).view(*head_shape)
+        logits = th.einsum("bct,bcs->bts", q_heads, k_heads)
+        probs = th.softmax(logits.float(), dim=-1).type(logits.dtype)
+        attended = th.einsum("bts,bcs->bct", probs, v.reshape(*head_shape))
+        return attended.reshape(bs, -1, length)
+    @staticmethod
+    def count_flops(model, _x, y):
+        return count_flops_attn(model, _x, y)
+class UNetModel(nn.Module):
+    """
+    The full UNet model with attention and timestep embedding.
+    :param image_size: stored on the instance; not otherwise used by this class.
+    :param in_channels: channels in the input Tensor.
+    :param model_channels: base channel count for the model.
+    :param out_channels: channels in the output Tensor.
+    :param num_res_blocks: number of residual blocks per downsample, either one int
+        for all levels or a list/tuple with one entry per entry of channel_mult.
+    :param attention_resolutions: a collection of downsample rates at which
+        attention will take place. May be a set, list, or tuple.
+        For example, if this contains 4, then at 4x downsampling, attention
+        will be used.
+    :param dropout: the dropout probability.
+    :param channel_mult: channel multiplier for each level of the UNet.
+    :param conv_resample: if True, use learned convolutions for upsampling and
+        downsampling.
+    :param dims: determines if the signal is 1D, 2D, or 3D.
+    :param num_classes: if an int, the model is class-conditional with an embedding
+        over `num_classes` classes; if the string "continuous", a Linear(1, ...)
+        layer embeds `y` instead.
+    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
+    :param use_fp16: if True, forward() casts the input to float16, else to float32.
+    :param num_heads: the number of attention heads in each attention layer.
+    :param num_head_channels: if specified, ignore num_heads and instead use
+        a fixed channel width per attention head.
+    :param num_heads_upsample: works with num_heads to set a different number
+        of heads for upsampling. Deprecated.
+    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+    :param resblock_updown: use residual blocks for up/downsampling.
+    :param use_new_attention_order: use a different attention pattern for potentially
+        increased efficiency.
+    :param use_spatial_transformer: use SpatialTransformer layers instead of
+        AttentionBlock; requires context_dim.
+    :param transformer_depth: `depth` passed to every SpatialTransformer.
+    :param context_dim: `context_dim` passed to every SpatialTransformer; a
+        ListConfig is converted to a plain list.
+    :param n_embed: if given, forward() returns the output of an id-predictor head
+        with `n_embed` output channels instead of the regular output head.
+    :param legacy: if True, the per-head width is `ch // num_heads` for spatial
+        transformers and `num_head_channels` for attention blocks.
+    :param disable_self_attentions: optional per-level list of booleans passed to
+        SpatialTransformer as `disable_self_attn`.
+    :param num_attention_blocks: optional per-level limit on how many of the residual
+        blocks of a level get an attention layer.
+    :param disable_middle_self_attn: `disable_self_attn` for the middle-block transformer.
+    :param use_linear_in_transformer: passed to SpatialTransformer as `use_linear`.
+    """
+    def __init__(
+        self,
+        image_size,
+        in_channels,
+        model_channels,
+        out_channels,
+        num_res_blocks,
+        attention_resolutions,
+        dropout=0,
+        channel_mult=(1, 2, 4, 8),
+        conv_resample=True,
+        dims=2,
+        num_classes=None,
+        use_checkpoint=False,
+        use_fp16=False,
+        num_heads=-1,
+        num_head_channels=-1,
+        num_heads_upsample=-1,
+        use_scale_shift_norm=False,
+        resblock_updown=False,
+        use_new_attention_order=False,
+        use_spatial_transformer=False,    # custom transformer support
+        transformer_depth=1,              # custom transformer support
+        context_dim=None,                 # custom transformer support
+        n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
+        legacy=True,
+        disable_self_attentions=None,
+        num_attention_blocks=None,
+        disable_middle_self_attn=False,
+        use_linear_in_transformer=False,
+    ):
+        super().__init__()
+        if use_spatial_transformer:
+            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
+        if context_dim is not None:
+            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
+            from omegaconf.listconfig import ListConfig
+            if isinstance(context_dim, ListConfig):
+                context_dim = list(context_dim)
+        if num_heads_upsample == -1:
+            num_heads_upsample = num_heads
+        if num_heads == -1:
+            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+        if num_head_channels == -1:
+            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+        self.image_size = image_size
+        self.in_channels = in_channels
+        self.model_channels = model_channels
+        self.out_channels = out_channels
+        if isinstance(num_res_blocks, int):
+            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+        else:
+            if len(num_res_blocks) != len(channel_mult):
+                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
+                                 "as a list/tuple (per-level) with the same length as channel_mult")
+            self.num_res_blocks = num_res_blocks
+        if disable_self_attentions is not None:
+            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
+            assert len(disable_self_attentions) == len(channel_mult)
+        if num_attention_blocks is not None:
+            assert len(num_attention_blocks) == len(self.num_res_blocks)
+            assert all(
+                self.num_res_blocks[i] >= num_attention_blocks[i]
+                for i in range(len(num_attention_blocks))
+            )
+            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+                  f"attention will still not be set.")
+        self.attention_resolutions = attention_resolutions
+        self.dropout = dropout
+        self.channel_mult = channel_mult
+        self.conv_resample = conv_resample
+        self.num_classes = num_classes
+        self.use_checkpoint = use_checkpoint
+        self.dtype = th.float16 if use_fp16 else th.float32
+        self.num_heads = num_heads
+        self.num_head_channels = num_head_channels
+        self.num_heads_upsample = num_heads_upsample
+        self.predict_codebook_ids = n_embed is not None
+        time_embed_dim = model_channels * 4
+        self.time_embed = nn.Sequential(
+            linear(model_channels, time_embed_dim),
+            nn.SiLU(),
+            linear(time_embed_dim, time_embed_dim),
+        )
+        if self.num_classes is not None:
+            if isinstance(self.num_classes, int):
+                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+            elif self.num_classes == "continuous":
+                print("setting up linear c_adm embedding layer")
+                self.label_emb = nn.Linear(1, time_embed_dim)
+            else:
+                raise ValueError()
+        def make_res_block(channels, **kwargs):
+            # Residual block with the settings shared by every ResBlock of the model.
+            return ResBlock(
+                channels,
+                time_embed_dim,
+                dropout,
+                dims=dims,
+                use_checkpoint=use_checkpoint,
+                use_scale_shift_norm=use_scale_shift_norm,
+                **kwargs,
+            )
+        def make_attention(channels, attn_block_heads, transformer_heads, head_dim, disable_self_attn):
+            # Attention layer of the configured flavour. AttentionBlock and SpatialTransformer
+            # receive separate head counts because the output path hands
+            # num_heads_upsample to AttentionBlock only.
+            if use_spatial_transformer:
+                return SpatialTransformer(
+                    channels, transformer_heads, head_dim, depth=transformer_depth, context_dim=context_dim,
+                    disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer,
+                    use_checkpoint=use_checkpoint
+                )
+            return AttentionBlock(
+                channels,
+                use_checkpoint=use_checkpoint,
+                num_heads=attn_block_heads,
+                num_head_channels=head_dim,
+                use_new_attention_order=use_new_attention_order,
+            )
+        self.input_blocks = nn.ModuleList(
+            [
+                TimestepEmbedSequential(
+                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
+                )
+            ]
+        )
+        self._feature_size = model_channels
+        input_block_chans = [model_channels]
+        ch = model_channels
+        ds = 1
+        for level, mult in enumerate(channel_mult):
+            for nr in range(self.num_res_blocks[level]):
+                layers = [make_res_block(ch, out_channels=mult * model_channels)]
+                ch = mult * model_channels
+                if ds in attention_resolutions:
+                    num_heads, dim_head = self._attention_head_dims(
+                        ch, num_heads, num_head_channels, legacy, use_spatial_transformer
+                    )
+                    disabled_sa = disable_self_attentions[level] if exists(disable_self_attentions) else False
+                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
+                        layers.append(make_attention(ch, num_heads, num_heads, dim_head, disabled_sa))
+                self.input_blocks.append(TimestepEmbedSequential(*layers))
+                self._feature_size += ch
+                input_block_chans.append(ch)
+            if level != len(channel_mult) - 1:
+                out_ch = ch
+                self.input_blocks.append(
+                    TimestepEmbedSequential(
+                        make_res_block(ch, out_channels=out_ch, down=True)
+                        if resblock_updown
+                        else Downsample(
+                            ch, conv_resample, dims=dims, out_channels=out_ch
+                        )
+                    )
+                )
+                ch = out_ch
+                input_block_chans.append(ch)
+                ds *= 2
+                self._feature_size += ch
+        num_heads, dim_head = self._attention_head_dims(
+            ch, num_heads, num_head_channels, legacy, use_spatial_transformer
+        )
+        self.middle_block = TimestepEmbedSequential(
+            make_res_block(ch),
+            # the spatial transformer always uses a self-attn unless disable_middle_self_attn is set
+            make_attention(ch, num_heads, num_heads, dim_head, disable_middle_self_attn),
+            make_res_block(ch),
+        )
+        self._feature_size += ch
+        self.output_blocks = nn.ModuleList([])
+        for level, mult in list(enumerate(channel_mult))[::-1]:
+            for i in range(self.num_res_blocks[level] + 1):
+                # Each output block consumes one skip tensor recorded on the way down.
+                ich = input_block_chans.pop()
+                layers = [make_res_block(ch + ich, out_channels=model_channels * mult)]
+                ch = model_channels * mult
+                if ds in attention_resolutions:
+                    num_heads, dim_head = self._attention_head_dims(
+                        ch, num_heads, num_head_channels, legacy, use_spatial_transformer
+                    )
+                    disabled_sa = disable_self_attentions[level] if exists(disable_self_attentions) else False
+                    if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
+                        layers.append(make_attention(ch, num_heads_upsample, num_heads, dim_head, disabled_sa))
+                if level and i == self.num_res_blocks[level]:
+                    out_ch = ch
+                    layers.append(
+                        make_res_block(ch, out_channels=out_ch, up=True)
+                        if resblock_updown
+                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+                    )
+                    ds //= 2
+                self.output_blocks.append(TimestepEmbedSequential(*layers))
+                self._feature_size += ch
+        # The heads read `ch` channels (model_channels * channel_mult[0]), matching the
+        # normalization in front of them. Using model_channels here only worked when
+        # channel_mult[0] == 1.
+        self.out = nn.Sequential(
+            normalization(ch),
+            nn.SiLU(),
+            zero_module(conv_nd(dims, ch, out_channels, 3, padding=1)),
+        )
+        if self.predict_codebook_ids:
+            self.id_predictor = nn.Sequential(
+                normalization(ch),
+                conv_nd(dims, ch, n_embed, 1),
+                #nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
+            )
+    @staticmethod
+    def _attention_head_dims(ch, num_heads, num_head_channels, legacy, use_spatial_transformer):
+        """
+        Return ``(num_heads, dim_head)`` for an attention layer operating on ``ch`` channels.
+        With ``num_head_channels == -1`` the head count is kept and the width derived from
+        it; otherwise the head count is derived from the fixed width. In legacy mode the
+        width is ``ch // num_heads`` for spatial transformers and ``num_head_channels``
+        (possibly -1) for attention blocks.
+        """
+        if num_head_channels == -1:
+            dim_head = ch // num_heads
+        else:
+            num_heads = ch // num_head_channels
+            dim_head = num_head_channels
+        if legacy:
+            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+        return num_heads, dim_head
+    def convert_to_fp16(self):
+        """
+        Convert the torso of the model to float16.
+        """
+        self.input_blocks.apply(convert_module_to_f16)
+        self.middle_block.apply(convert_module_to_f16)
+        self.output_blocks.apply(convert_module_to_f16)
+    def convert_to_fp32(self):
+        """
+        Convert the torso of the model to float32.
+        """
+        self.input_blocks.apply(convert_module_to_f32)
+        self.middle_block.apply(convert_module_to_f32)
+        self.output_blocks.apply(convert_module_to_f32)
+    def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
+        """
+        Apply the model to an input batch.
+        :param x: an [N x C x ...] Tensor of inputs.
+        :param timesteps: a 1-D batch of timesteps.
+        :param context: conditioning plugged in via crossattn
+        :param y: an [N] Tensor of labels, if class-conditional.
+        :return: an [N x C x ...] Tensor of outputs.
+        """
+        assert (y is not None) == (
+            self.num_classes is not None
+        ), "must specify y if and only if the model is class-conditional"
+        hs = []
+        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+        emb = self.time_embed(t_emb)
+        if self.num_classes is not None:
+            assert y.shape[0] == x.shape[0]
+            emb = emb + self.label_emb(y)
+        h = x.type(self.dtype)
+        for module in self.input_blocks:
+            h = module(h, emb, context)
+            # Keep every encoder activation as a skip connection for the decoder.
+            hs.append(h)
+        h = self.middle_block(h, emb, context)
+        for module in self.output_blocks:
+            h = th.cat([h, hs.pop()], dim=1)
+            h = module(h, emb, context)
+        h = h.type(x.dtype)
+        if self.predict_codebook_ids:
+            return self.id_predictor(h)
+        else:
+            return self.out(h)
diff --git a/ControlNet/ldm/modules/diffusionmodules/upscaling.py b/ControlNet/ldm/modules/diffusionmodules/upscaling.py
new file mode 100644
index 0000000000000000000000000000000000000000..03816662098ce1ffac79bd939b892e867ab91988
--- /dev/null
+++ b/ControlNet/ldm/modules/diffusionmodules/upscaling.py
@@ -0,0 +1,81 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from functools import partial
+from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
+from ldm.util import default
+class AbstractLowScaleModel(nn.Module):
+    """
+    Base class for models that concatenate a downsampled image to the latent representation.
+    When a noise schedule config is given, the diffusion schedule is registered as buffers
+    so that q_sample() can noise an input.
+    """
+    def __init__(self, noise_schedule_config=None):
+        super().__init__()
+        if noise_schedule_config is None:
+            return
+        self.register_schedule(**noise_schedule_config)
+    def register_schedule(self, beta_schedule="linear", timesteps=1000,
+                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+        """Compute the beta schedule and register the derived quantities as float32 buffers."""
+        betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
+                                   cosine_s=cosine_s)
+        alphas = 1. - betas
+        alphas_cumprod = np.cumprod(alphas, axis=0)
+        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+        (num_steps,) = betas.shape
+        self.num_timesteps = int(num_steps)
+        self.linear_start = linear_start
+        self.linear_end = linear_end
+        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
+        schedule_buffers = {
+            'betas': betas,
+            'alphas_cumprod': alphas_cumprod,
+            'alphas_cumprod_prev': alphas_cumprod_prev,
+            # calculations for diffusion q(x_t | x_{t-1}) and others
+            'sqrt_alphas_cumprod': np.sqrt(alphas_cumprod),
+            'sqrt_one_minus_alphas_cumprod': np.sqrt(1. - alphas_cumprod),
+            'log_one_minus_alphas_cumprod': np.log(1. - alphas_cumprod),
+            'sqrt_recip_alphas_cumprod': np.sqrt(1. / alphas_cumprod),
+            'sqrt_recipm1_alphas_cumprod': np.sqrt(1. / alphas_cumprod - 1),
+        }
+        for buffer_name, values in schedule_buffers.items():
+            self.register_buffer(buffer_name, torch.tensor(values, dtype=torch.float32))
+    def q_sample(self, x_start, t, noise=None):
+        """Mix ``x_start`` with noise using the schedule coefficients selected by ``t``."""
+        noise = default(noise, lambda: torch.randn_like(x_start))
+        signal_coef = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape)
+        noise_coef = extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
+        return signal_coef * x_start + noise_coef * noise
+    def forward(self, x):
+        return x, None
+    def decode(self, x):
+        return x
+class SimpleImageConcat(AbstractLowScaleModel):
+    """Low-scale model without noise level conditioning: the noise level is always zero."""
+    def __init__(self):
+        super().__init__(noise_schedule_config=None)
+        self.max_noise_level = 0
+    def forward(self, x):
+        # fix to constant noise level
+        batch_size = x.shape[0]
+        noise_level = torch.zeros(batch_size, device=x.device).long()
+        return x, noise_level
+class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
+    """
+    Low-scale model that noises its input with q_sample().
+    The noise level is either supplied by the caller (as a tensor) or drawn uniformly
+    from ``[0, max_noise_level)`` for each batch element. ``to_cuda`` is accepted but unused.
+    """
+    def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
+        super().__init__(noise_schedule_config=noise_schedule_config)
+        self.max_noise_level = max_noise_level
+    def forward(self, x, noise_level=None):
+        if noise_level is not None:
+            assert isinstance(noise_level, torch.Tensor)
+        else:
+            batch_shape = (x.shape[0],)
+            noise_level = torch.randint(0, self.max_noise_level, batch_shape, device=x.device).long()
+        noised = self.q_sample(x, noise_level)
+        return noised, noise_level
diff --git a/ControlNet/ldm/modules/diffusionmodules/util.py b/ControlNet/ldm/modules/diffusionmodules/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..637363dfe34799e70cfdbcd11445212df9d9ca1f
--- /dev/null
+++ b/ControlNet/ldm/modules/diffusionmodules/util.py
@@ -0,0 +1,270 @@
+# adopted from
+# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+# and
+# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+# and
+# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
+# thanks!
+import os
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import repeat
+from ldm.util import instantiate_from_config
+def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+    """
+    Build a beta schedule and return it as a float64 numpy array.
+    :param schedule: one of "linear", "cosine", "sqrt_linear" or "sqrt".
+    :param n_timestep: number of betas to produce.
+    :param linear_start: first value used by the linspace-based schedules.
+    :param linear_end: last value used by the linspace-based schedules.
+    :param cosine_s: offset added to the normalized timesteps of the cosine schedule.
+    :return: numpy array with `n_timestep` betas.
+    :raises ValueError: if `schedule` is not a known name.
+    """
+    if schedule == "linear":
+        # Interpolate in sqrt-space, then square.
+        betas = (
+            torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
+        )
+    elif schedule == "cosine":
+        timesteps = (
+            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
+        )
+        alphas = timesteps / (1 + cosine_s) * np.pi / 2
+        alphas = torch.cos(alphas).pow(2)
+        alphas = alphas / alphas[0]
+        betas = 1 - alphas[1:] / alphas[:-1]
+        # Clamp with torch so `betas` stays a Tensor; np.clip may hand back an
+        # ndarray, which would break the `.numpy()` call below.
+        betas = torch.clamp(betas, min=0, max=0.999)
+    elif schedule == "sqrt_linear":
+        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
+    elif schedule == "sqrt":
+        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
+    else:
+        raise ValueError(f"schedule '{schedule}' unknown.")
+    return betas.numpy()
+def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
+    """
+    Pick the DDPM timesteps a DDIM sampler will visit.
+    :param ddim_discr_method: 'uniform' (constant stride) or 'quad' (quadratic spacing).
+    :param num_ddim_timesteps: number of DDIM steps requested.
+    :param num_ddpm_timesteps: number of DDPM steps to select from.
+    :param verbose: print the selected timesteps when True.
+    :return: numpy array of the selected timesteps, each shifted up by one.
+    """
+    if ddim_discr_method == 'quad':
+        grid = np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)
+        ddim_timesteps = (grid ** 2).astype(int)
+    elif ddim_discr_method == 'uniform':
+        stride = num_ddpm_timesteps // num_ddim_timesteps
+        ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, stride)))
+    else:
+        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
+    # add one to get the final alpha values right (the ones from first scale to data during sampling)
+    steps_out = ddim_timesteps + 1
+    if verbose:
+        print(f'Selected timesteps for ddim sampler: {steps_out}')
+    return steps_out
+def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
+    """
+    Compute the DDIM sigma schedule and the alpha values it is based on.
+    :param alphacums: array indexed by timestep.
+    :param ddim_timesteps: integer array of the timesteps to select.
+    :param eta: scale factor applied to the sigmas.
+    :param verbose: print the selected values when True.
+    :return: tuple `(sigmas, alphas, alphas_prev)`.
+    """
+    # select alphas for computing the variance schedule
+    alphas = alphacums[ddim_timesteps]
+    previous = alphacums[ddim_timesteps[:-1]].tolist()
+    alphas_prev = np.asarray([alphacums[0]] + previous)
+    # according to the formula provided in https://arxiv.org/abs/2010.02502
+    variance_term = (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
+    sigmas = eta * np.sqrt(variance_term)
+    if verbose:
+        print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
+        print(f'For the chosen value of eta, which is {eta}, '
+              f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
+    return sigmas, alphas, alphas_prev
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+    """
+    Discretize an alpha_t_bar function into a beta schedule.
+    :param num_diffusion_timesteps: the number of betas to produce.
+    :param alpha_bar: callable taking t in [0, 1] and returning the cumulative
+                      product of (1-beta) up to that part of the process.
+    :param max_beta: upper bound applied to every beta; values below 1
+                     prevent singularities.
+    :return: numpy array of `num_diffusion_timesteps` betas.
+    """
+    def _beta_at(step):
+        # Ratio of alpha_bar over one discretization interval.
+        t_start = step / num_diffusion_timesteps
+        t_end = (step + 1) / num_diffusion_timesteps
+        return min(1 - alpha_bar(t_end) / alpha_bar(t_start), max_beta)
+    return np.array([_beta_at(step) for step in range(num_diffusion_timesteps)])
+def extract_into_tensor(a, t, x_shape):
+    """Gather `a[t]` along the last axis and reshape to (batch, 1, ..., 1) with `len(x_shape)` dims."""
+    batch, *_ = t.shape
+    gathered = a.gather(-1, t)
+    # One singleton axis for every non-batch dimension of `x_shape`.
+    singleton_axes = (1,) * (len(x_shape) - 1)
+    return gathered.reshape(batch, *singleton_axes)
+def checkpoint(func, inputs, params, flag):
+    """
+    Run `func` with optional gradient checkpointing, trading extra backward
+    compute for not caching intermediate activations.
+    :param func: the function to evaluate.
+    :param inputs: the argument sequence to pass to `func`.
+    :param params: a sequence of parameters `func` depends on but does not
+                   explicitly take as arguments.
+    :param flag: if False, disable gradient checkpointing.
+    """
+    if not flag:
+        return func(*inputs)
+    # Inputs come first, then parameters; the count tells the Function where to split.
+    packed = tuple(inputs) + tuple(params)
+    return CheckpointFunction.apply(func, len(inputs), *packed)
+class CheckpointFunction(torch.autograd.Function):
+ # Custom autograd Function used by `checkpoint`: forward runs `run_function`
+ # under no_grad, and backward re-runs it with gradients enabled to obtain
+ # the gradients for the inputs and parameters.
+ @staticmethod
+ def forward(ctx, run_function, length, *args):
+ # The first `length` entries of `args` are inputs, the rest are parameters.
+ ctx.run_function = run_function
+ ctx.input_tensors = list(args[:length])
+ ctx.input_params = list(args[length:])
+ # Record the autocast state so backward can recompute under the same settings.
+ ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
+ "dtype": torch.get_autocast_gpu_dtype(),
+ "cache_enabled": torch.is_autocast_cache_enabled()}
+ with torch.no_grad():
+ output_tensors = ctx.run_function(*ctx.input_tensors)
+ return output_tensors
+ @staticmethod
+ def backward(ctx, *output_grads):
+ # Detach the saved inputs and mark them as requiring grad for the recomputation.
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+ with torch.enable_grad(), \
+ torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+ output_tensors = ctx.run_function(*shallow_copies)
+ # Gradients w.r.t. inputs and parameters; unused ones come back as None.
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ ctx.input_tensors + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ # Drop the references held on ctx and the recomputed outputs.
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ # The two leading Nones correspond to `run_function` and `length`.
+ return (None, None) + input_grads
+def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+    """
+    Build sinusoidal embeddings for a batch of timesteps.
+    :param timesteps: a 1-D Tensor of N indices, one per batch element.
+                      These may be fractional.
+    :param dim: the dimension of the output.
+    :param max_period: controls the minimum frequency of the embeddings.
+    :param repeat_only: if True, skip the sinusoids and tile each timestep `dim` times.
+    :return: an [N x dim] Tensor of positional embeddings.
+    """
+    if repeat_only:
+        return repeat(timesteps, 'b -> b d', d=dim)
+    half = dim // 2
+    scaled = -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32)
+    freqs = torch.exp(scaled / half).to(device=timesteps.device)
+    args = timesteps[:, None].float() * freqs[None]
+    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+    if dim % 2:
+        # Odd widths get one trailing zero column.
+        zero_column = torch.zeros_like(embedding[:, :1])
+        embedding = torch.cat([embedding, zero_column], dim=-1)
+    return embedding
+def zero_module(module):
+    """
+    Set every parameter of `module` to zero in place and return the module.
+    """
+    for param in module.parameters():
+        raw = param.detach()
+        raw.zero_()
+    return module
+def scale_module(module, scale):
+    """
+    Multiply every parameter of `module` by `scale` in place and return the module.
+    """
+    for param in module.parameters():
+        raw = param.detach()
+        raw.mul_(scale)
+    return module
+def mean_flat(tensor):
+    """
+    Average `tensor` over every dimension except the first (batch) one.
+    """
+    non_batch_dims = list(range(1, len(tensor.shape)))
+    return tensor.mean(dim=non_batch_dims)
+def normalization(channels):
+    """
+    Build the standard normalization layer: a `GroupNorm32` with 32 groups.
+    :param channels: number of input channels.
+    :return: an nn.Module for normalization.
+    """
+    num_groups = 32
+    return GroupNorm32(num_groups, channels)
+# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+class SiLU(nn.Module):
+    """SiLU activation: the input multiplied by its sigmoid."""
+    def forward(self, x):
+        gate = torch.sigmoid(x)
+        return x * gate
+class GroupNorm32(nn.GroupNorm):
+    """GroupNorm that computes in float32 and casts the result back to the input dtype."""
+    def forward(self, x):
+        original_dtype = x.dtype
+        normalized = super().forward(x.float())
+        return normalized.type(original_dtype)
+def conv_nd(dims, *args, **kwargs):
+    """
+    Build an nn.Conv1d, nn.Conv2d or nn.Conv3d depending on `dims`.
+    Remaining arguments are forwarded to the layer constructor.
+    Raises ValueError for any other `dims`.
+    """
+    for ndim, conv_cls in ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d)):
+        if dims == ndim:
+            return conv_cls(*args, **kwargs)
+    raise ValueError(f"unsupported dimensions: {dims}")
+def linear(*args, **kwargs):
+    """
+    Build an nn.Linear, forwarding all arguments to its constructor.
+    """
+    layer = nn.Linear(*args, **kwargs)
+    return layer
+def avg_pool_nd(dims, *args, **kwargs):
+    """
+    Build an nn.AvgPool1d, nn.AvgPool2d or nn.AvgPool3d depending on `dims`.
+    Remaining arguments are forwarded to the layer constructor.
+    Raises ValueError for any other `dims`.
+    """
+    for ndim, pool_cls in ((1, nn.AvgPool1d), (2, nn.AvgPool2d), (3, nn.AvgPool3d)):
+        if dims == ndim:
+            return pool_cls(*args, **kwargs)
+    raise ValueError(f"unsupported dimensions: {dims}")
+class HybridConditioner(nn.Module):
+    """Holds two conditioners built from configs: one for concat and one for cross-attention inputs."""
+    def __init__(self, c_concat_config, c_crossattn_config):
+        super().__init__()
+        self.concat_conditioner = instantiate_from_config(c_concat_config)
+        self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
+    def forward(self, c_concat, c_crossattn):
+        """Run each input through its conditioner and return both wrapped in single-item lists."""
+        concat_out = self.concat_conditioner(c_concat)
+        crossattn_out = self.crossattn_conditioner(c_crossattn)
+        return {'c_concat': [concat_out], 'c_crossattn': [crossattn_out]}
+def noise_like(shape, device, repeat=False):
+    """Draw standard-normal noise of `shape`; with `repeat`, one sample is tiled along the first axis."""
+    if repeat:
+        single = torch.randn((1, *shape[1:]), device=device)
+        tiling = (shape[0],) + (1,) * (len(shape) - 1)
+        return single.repeat(*tiling)
+    return torch.randn(shape, device=device)
\ No newline at end of file
diff --git a/ControlNet/ldm/modules/distributions/__init__.py b/ControlNet/ldm/modules/distributions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ControlNet/ldm/modules/distributions/distributions.py b/ControlNet/ldm/modules/distributions/distributions.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2b8ef901130efc171aa69742ca0244d94d3f2e9
--- /dev/null
+++ b/ControlNet/ldm/modules/distributions/distributions.py
@@ -0,0 +1,92 @@
+import torch
+import numpy as np
+class AbstractDistribution:
+    """Interface for distributions; both methods raise until a subclass overrides them."""
+    def sample(self):
+        """Draw a sample (not implemented here)."""
+        raise NotImplementedError()
+    def mode(self):
+        """Return the mode (not implemented here)."""
+        raise NotImplementedError()
+class DiracDistribution(AbstractDistribution):
+    """Distribution whose sample and mode are both the stored value."""
+    def __init__(self, value):
+        self.value = value
+    def sample(self):
+        """Return the stored value."""
+        return self.value
+    def mode(self):
+        """Return the stored value."""
+        return self.value
+class DiagonalGaussianDistribution(object):
+    """
+    Diagonal Gaussian whose mean and log-variance are the two halves of
+    `parameters` along dim 1. The log-variance is clamped to [-30, 20].
+    When `deterministic` is True, std and var are zero, so sampling yields the mean.
+    """
+    def __init__(self, parameters, deterministic=False):
+        self.parameters = parameters
+        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+        self.deterministic = deterministic
+        self.std = torch.exp(0.5 * self.logvar)
+        self.var = torch.exp(self.logvar)
+        if self.deterministic:
+            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
+    def sample(self):
+        """Return mean + std * standard-normal noise, on the parameters' device."""
+        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
+        return x
+    def kl(self, other=None):
+        """
+        KL divergence to `other`, or to a standard normal when `other` is None,
+        summed over dims 1, 2 and 3. Deterministic distributions return a
+        one-element zero tensor.
+        """
+        if self.deterministic:
+            # Keep the zero on the same device as the parameters so callers can
+            # combine it with other on-device terms.
+            return torch.tensor([0.], device=self.parameters.device)
+        if other is None:
+            return 0.5 * torch.sum(torch.pow(self.mean, 2)
+                                   + self.var - 1.0 - self.logvar,
+                                   dim=[1, 2, 3])
+        return 0.5 * torch.sum(
+            torch.pow(self.mean - other.mean, 2) / other.var
+            + self.var / other.var - 1.0 - self.logvar + other.logvar,
+            dim=[1, 2, 3])
+    def nll(self, sample, dims=(1, 2, 3)):
+        """
+        Negative log-likelihood of `sample`, summed over `dims`.
+        Deterministic distributions return a one-element zero tensor.
+        """
+        if self.deterministic:
+            return torch.tensor([0.], device=self.parameters.device)
+        logtwopi = np.log(2.0 * np.pi)
+        # list() so a caller-supplied list or tuple is handled the same way.
+        return 0.5 * torch.sum(
+            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+            dim=list(dims))
+    def mode(self):
+        """Return the mean."""
+        return self.mean
+def normal_kl(mean1, logvar1, mean2, logvar2):
+    """
+    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
+    KL divergence between two gaussians given their means and log-variances.
+    Inputs broadcast against each other, so batches can be compared to
+    scalars; at least one argument must be a Tensor.
+    """
+    candidates = (mean1, logvar1, mean2, logvar2)
+    tensor = next((obj for obj in candidates if isinstance(obj, torch.Tensor)), None)
+    assert tensor is not None, "at least one argument must be a Tensor"
+    # Force variances to be Tensors. Broadcasting helps convert scalars to
+    # Tensors, but it does not work for torch.exp().
+    def _as_tensor(value):
+        if isinstance(value, torch.Tensor):
+            return value
+        return torch.tensor(value).to(tensor)
+    logvar1 = _as_tensor(logvar1)
+    logvar2 = _as_tensor(logvar2)
+    variance_ratio = torch.exp(logvar1 - logvar2)
+    scaled_sq_diff = ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
+    return 0.5 * (-1.0 + logvar2 - logvar1 + variance_ratio + scaled_sq_diff)
diff --git a/ControlNet/ldm/modules/ema.py b/ControlNet/ldm/modules/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..bded25019b9bcbcd0260f0b8185f8c7859ca58c4
--- /dev/null
+++ b/ControlNet/ldm/modules/ema.py
@@ -0,0 +1,80 @@
+import torch
+from torch import nn
+class LitEma(nn.Module):
+ # Keeps a shadow copy of a model's trainable parameters as buffers and
+ # moves it toward the live parameters on every forward() call.
+ def __init__(self, model, decay=0.9999, use_num_upates=True):
+ super().__init__()
+ if decay < 0.0 or decay > 1.0:
+ raise ValueError('Decay must be between 0 and 1')
+ # Maps model parameter names to the buffer names used for their shadow copies.
+ self.m_name2s_name = {}
+ self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
+ # num_updates starts at 0 when counting is enabled and at -1 otherwise.
+ self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
+ else torch.tensor(-1, dtype=torch.int))
+ for name, p in model.named_parameters():
+ if p.requires_grad:
+ # remove as '.'-character is not allowed in buffers
+ s_name = name.replace('.', '')
+ self.m_name2s_name.update({name: s_name})
+ self.register_buffer(s_name, p.clone().detach().data)
+ # Filled by store() and read back by restore().
+ self.collected_params = []
+ def reset_num_updates(self):
+ # Re-register the counter buffer with value 0.
+ del self.num_updates
+ self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
+ def forward(self, model):
+ # Update the shadow buffers from `model`'s current trainable parameters.
+ decay = self.decay
+ if self.num_updates >= 0:
+ self.num_updates += 1
+ # While counting, cap the decay at (1 + n) / (10 + n).
+ decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
+ one_minus_decay = 1.0 - decay
+ with torch.no_grad():
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+ for key in m_param:
+ if m_param[key].requires_grad:
+ sname = self.m_name2s_name[key]
+ # NOTE(review): if type_as() has to convert, it returns a new tensor, so the
+ # in-place update below would not reach the registered buffer - confirm.
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
+ # shadow <- shadow - (1 - decay) * (shadow - param)
+ shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
+ else:
+ assert not key in self.m_name2s_name
+ def copy_to(self, model):
+ # Overwrite `model`'s trainable parameters with their shadow values.
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+ for key in m_param:
+ if m_param[key].requires_grad:
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
+ else:
+ assert not key in self.m_name2s_name
+ def store(self, parameters):
+ """
+ Save the current parameters for restoring later.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ temporarily stored.
+ """
+ self.collected_params = [param.clone() for param in parameters]
+ def restore(self, parameters):
+ """
+ Restore the parameters stored with the `store` method.
+ Useful to validate the model with EMA parameters without affecting the
+ original optimization process. Store the parameters before the
+ `copy_to` method. After validation (or model saving), use this to
+ restore the former parameters.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored parameters.
+ """
+ # Stored copies are matched to `parameters` by position.
+ for c_param, param in zip(self.collected_params, parameters):
+ param.data.copy_(c_param.data)
diff --git a/ControlNet/ldm/modules/encoders/__init__.py b/ControlNet/ldm/modules/encoders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ControlNet/ldm/modules/encoders/modules.py b/ControlNet/ldm/modules/encoders/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..4edd5496b9e668ea72a5be39db9cca94b6a42f9b
--- /dev/null
+++ b/ControlNet/ldm/modules/encoders/modules.py
@@ -0,0 +1,213 @@
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
+import open_clip
+from ldm.util import default, count_params
+class AbstractEncoder(nn.Module):
+    """Base class for encoders; `encode` raises until a subclass overrides it."""
+    def __init__(self):
+        super().__init__()
+    def encode(self, *args, **kwargs):
+        """Encode the inputs (not implemented here)."""
+        raise NotImplementedError
+class IdentityEncoder(AbstractEncoder):
+    """Encoder that returns its input unchanged."""
+    def encode(self, x):
+        encoded = x
+        return encoded
+class ClassEmbedder(nn.Module):
+    """Embeds class labels read from a batch dict, randomly replacing some with the last class id."""
+    def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
+        super().__init__()
+        self.key = key
+        self.embedding = nn.Embedding(n_classes, embed_dim)
+        self.n_classes = n_classes
+        self.ucg_rate = ucg_rate
+    def forward(self, batch, key=None, disable_dropout=False):
+        """Look up `batch[key]` (default `self.key`) and return its embeddings with an added axis at dim 1."""
+        key = self.key if key is None else key
+        # this is for use in crossattn
+        c = batch[key][:, None]
+        apply_dropout = self.ucg_rate > 0. and not disable_dropout
+        if apply_dropout:
+            # keep == 1 retains the label, keep == 0 swaps in the last class id.
+            keep = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
+            c = keep * c + (1 - keep) * torch.ones_like(c) * (self.n_classes - 1)
+            c = c.long()
+        return self.embedding(c)
+    def get_unconditional_conditioning(self, bs, device="cuda"):
+        """Return `{self.key: tensor}` holding `bs` copies of the last class id (n_classes - 1)."""
+        uc_class = self.n_classes - 1
+        uc = torch.ones((bs,), device=device) * uc_class
+        return {self.key: uc}
+def disabled_train(self, mode=True):
+    """Replacement for `model.train` that ignores `mode` and returns the
+    model, so train/eval mode does not change anymore."""
+    model = self
+    return model
+class FrozenT5Embedder(AbstractEncoder):
+    """Uses the T5 transformer encoder for text"""
+    def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True):
+        # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
+        super().__init__()
+        self.tokenizer = T5Tokenizer.from_pretrained(version)
+        self.transformer = T5EncoderModel.from_pretrained(version)
+        self.device = device
+        self.max_length = max_length  # TODO: typical value?
+        if freeze:
+            self.freeze()
+    def freeze(self):
+        """Switch the transformer to eval mode and disable gradients on all parameters."""
+        self.transformer = self.transformer.eval()
+        for frozen in self.parameters():
+            frozen.requires_grad = False
+    def forward(self, text):
+        """Tokenize `text` to `max_length` and return the transformer's last hidden state."""
+        tokenizer_options = dict(truncation=True, max_length=self.max_length, return_length=True,
+                                 return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+        batch_encoding = self.tokenizer(text, **tokenizer_options)
+        token_ids = batch_encoding["input_ids"].to(self.device)
+        return self.transformer(input_ids=token_ids).last_hidden_state
+    def encode(self, text):
+        return self(text)
+class FrozenCLIPEmbedder(AbstractEncoder):
+ """Uses the CLIP transformer encoder for text (from huggingface)"""
+ # Accepted values for the `layer` argument; selects which output forward() returns.
+ LAYERS = [
+ "last",
+ "pooled",
+ "hidden"
+ ]
+ def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
+ freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32
+ super().__init__()
+ assert layer in self.LAYERS
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
+ self.transformer = CLIPTextModel.from_pretrained(version)
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+ self.layer = layer
+ self.layer_idx = layer_idx
+ # "hidden" needs an explicit index with absolute value at most 12.
+ if layer == "hidden":
+ assert layer_idx is not None
+ assert 0 <= abs(layer_idx) <= 12
+ def freeze(self):
+ # Put the transformer in eval mode and disable gradients on all parameters.
+ self.transformer = self.transformer.eval()
+ #self.train = disabled_train
+ for param in self.parameters():
+ param.requires_grad = False
+ def forward(self, text):
+ # Tokenize with truncation/padding to max_length, then run the text transformer.
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+ tokens = batch_encoding["input_ids"].to(self.device)
+ # Hidden states are only requested when the "hidden" layer mode is selected.
+ outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
+ if self.layer == "last":
+ z = outputs.last_hidden_state
+ elif self.layer == "pooled":
+ # Insert an axis at dim 1 on the pooled output.
+ z = outputs.pooler_output[:, None, :]
+ else:
+ z = outputs.hidden_states[self.layer_idx]
+ return z
+ def encode(self, text):
+ return self(text)
+class FrozenOpenCLIPEmbedder(AbstractEncoder):
+ """
+ Uses the OpenCLIP transformer encoder for text
+ """
+ # Accepted values for the `layer` argument.
+ LAYERS = [
+ #"pooled",
+ "last",
+ "penultimate"
+ ]
+ def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
+ freeze=True, layer="last"):
+ super().__init__()
+ assert layer in self.LAYERS
+ # The model is created on CPU and its visual part is deleted; only text modules are used below.
+ model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
+ del model.visual
+ self.model = model
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+ self.layer = layer
+ # layer_idx is the number of trailing resblocks skipped in text_transformer_forward.
+ if self.layer == "last":
+ self.layer_idx = 0
+ elif self.layer == "penultimate":
+ self.layer_idx = 1
+ else:
+ raise NotImplementedError()
+ def freeze(self):
+ # Put the model in eval mode and disable gradients on all parameters.
+ self.model = self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+ def forward(self, text):
+ # Tokenize with open_clip, move the tokens to self.device and encode them.
+ tokens = open_clip.tokenize(text)
+ z = self.encode_with_transformer(tokens.to(self.device))
+ return z
+ def encode_with_transformer(self, text):
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
+ x = x + self.model.positional_embedding
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.model.ln_final(x)
+ return x
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
+ # Run the resblocks in order, stopping `layer_idx` blocks before the end.
+ for i, r in enumerate(self.model.transformer.resblocks):
+ if i == len(self.model.transformer.resblocks) - self.layer_idx:
+ break
+ # Use checkpointing when the transformer enables it and we are not scripting.
+ if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
+ x = checkpoint(r, x, attn_mask)
+ else:
+ x = r(x, attn_mask=attn_mask)
+ return x
+ def encode(self, text):
+ return self(text)
+class FrozenCLIPT5Encoder(AbstractEncoder):
+    """Wraps a FrozenCLIPEmbedder and a FrozenT5Embedder and returns both encodings."""
+    def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
+                 clip_max_length=77, t5_max_length=77):
+        super().__init__()
+        self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
+        self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
+        # Report the parameter count of each sub-encoder, in millions.
+        print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
+              f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.")
+    def encode(self, text):
+        return self(text)
+    def forward(self, text):
+        """Return `[clip_z, t5_z]` for `text`."""
+        encodings = [self.clip_encoder.encode(text)]
+        encodings.append(self.t5_encoder.encode(text))
+        return encodings
diff --git a/ControlNet/ldm/modules/image_degradation/__init__.py b/ControlNet/ldm/modules/image_degradation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7836cada81f90ded99c58d5942eea4c3477f58fc
--- /dev/null
+++ b/ControlNet/ldm/modules/image_degradation/__init__.py
@@ -0,0 +1,2 @@
+from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
+from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
diff --git a/ControlNet/ldm/modules/image_degradation/bsrgan.py b/ControlNet/ldm/modules/image_degradation/bsrgan.py
new file mode 100644
index 0000000000000000000000000000000000000000..32ef56169978e550090261cddbcf5eb611a6173b
--- /dev/null
+++ b/ControlNet/ldm/modules/image_degradation/bsrgan.py
@@ -0,0 +1,730 @@
+# -*- coding: utf-8 -*-
+# --------------------------------------------
+# Super-Resolution
+# --------------------------------------------
+# Kai Zhang (cskaizhang@gmail.com)
+# https://github.com/cszn
+# From 2019/03--2021/08
+# --------------------------------------------
+import numpy as np
+import cv2
+import torch
+from functools import partial
+import random
+from scipy import ndimage
+import scipy
+import scipy.stats as ss
+from scipy.interpolate import interp2d
+from scipy.linalg import orth
+import albumentations
+import ldm.modules.image_degradation.utils_image as util
+def modcrop_np(img, sf):
+    '''
+    Crop a copy of `img` so that its first two dimensions are multiples of `sf`.
+    Args:
+        img: numpy image, 2-D or 3-D
+        sf: scale factor
+    Return:
+        cropped image
+    '''
+    dim0, dim1 = img.shape[:2]
+    copied = np.copy(img)
+    keep0 = dim0 - dim0 % sf
+    keep1 = dim1 - dim1 % sf
+    return copied[:keep0, :keep1, ...]
+# --------------------------------------------
+# anisotropic Gaussian kernels
+# --------------------------------------------
+def analytic_kernel(k):
+    """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
+    k_size = k.shape[0]
+    # The combined kernel has side length 3 * k_size - 2.
+    big_side = 3 * k_size - 2
+    big_k = np.zeros((big_side, big_side))
+    # Add a weighted copy of the small kernel at stride-2 offsets.
+    for row in range(k_size):
+        for col in range(k_size):
+            big_k[2 * row:2 * row + k_size, 2 * col:2 * col + k_size] += k[row, col] * k
+    # Trim k_size // 2 entries from every edge of the big kernel.
+    margin = k_size // 2
+    trimmed = big_k[margin:-margin, margin:-margin]
+    # Normalize to 1
+    return trimmed / trimmed.sum()
+def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
+    """ Build an anisotropic Gaussian blur kernel.
+    Args:
+        ksize : e.g., 15, kernel size
+        theta : [0, pi], rotation angle range
+        l1    : [0.1,50], scaling of eigenvalues
+        l2    : [0.1,l1], scaling of eigenvalues
+        If l1 = l2, will get an isotropic Gaussian kernel.
+    Returns:
+        k     : kernel
+    """
+    cos_t, sin_t = np.cos(theta), np.sin(theta)
+    rotation = np.array([[cos_t, -sin_t], [sin_t, cos_t]])
+    # Rotate the unit x-vector by theta to get the first axis.
+    v = np.dot(rotation, np.array([1., 0.]))
+    V = np.array([[v[0], v[1]], [v[1], -v[0]]])
+    D = np.array([[l1, 0], [0, l2]])
+    # Covariance assembled as V D V^-1.
+    Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
+    return gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
+def gm_blur_kernel(mean, cov, size=15):
+    """Sample a 2-D Gaussian pdf (`mean`, `cov`) on a `size` x `size` grid and normalize the sum to 1."""
+    center = size / 2.0 + 0.5
+    kernel = np.zeros([size, size])
+    for row in range(size):
+        offset_y = row - center + 1
+        for col in range(size):
+            offset_x = col - center + 1
+            kernel[row, col] = ss.multivariate_normal.pdf([offset_x, offset_y], mean=mean, cov=cov)
+    return kernel / np.sum(kernel)
+def shift_pixel(x, sf, upper_left=True):
+    """shift pixel for super-resolution with different scale factors
+    Args:
+        x: 2-D or 3-D array; the first axis is read as height and the second
+            as width. A 3-D input is overwritten channel by channel.
+        sf: scale factor; the shift applied is (sf - 1) / 2 pixels
+        upper_left: shift direction (sample at +shift when True, -shift otherwise)
+    Returns:
+        the image linearly resampled on the shifted grid, with sample
+        positions clipped to the image bounds
+    """
+    # `scipy.interpolate.interp2d` was removed from SciPy; a degree-1
+    # RectBivariateSpline is the documented replacement for linear
+    # interpolation on a regular grid.
+    from scipy.interpolate import RectBivariateSpline
+    h, w = x.shape[:2]
+    shift = (sf - 1) * 0.5
+    xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
+    if upper_left:
+        x1 = xv + shift
+        y1 = yv + shift
+    else:
+        x1 = xv - shift
+        y1 = yv - shift
+    x1 = np.clip(x1, 0, w - 1)
+    y1 = np.clip(y1, 0, h - 1)
+    def _resample(plane):
+        # Rows index y and columns index x, hence the (yv, xv) ordering.
+        return RectBivariateSpline(yv, xv, plane, kx=1, ky=1)(y1, x1)
+    if x.ndim == 2:
+        x = _resample(x)
+    elif x.ndim == 3:
+        for i in range(x.shape[-1]):
+            x[:, :, i] = _resample(x[:, :, i])
+    return x
+def blur(x, k):
+    '''
+    Filter each image in a batch with its own kernel, applied to every channel.
+    x: image, NxcxHxW
+    k: kernel, Nx1xhxw
+    Returns a tensor shaped like `x` (for odd kernel sizes), using replicate
+    padding at the borders.
+    '''
+    n, c = x.shape[:2]
+    pad_h, pad_w = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
+    # F.pad takes (left, right, top, bottom): width padding comes first. The
+    # previous (p1, p2, p1, p2) ordering was only right for square kernels.
+    x = torch.nn.functional.pad(x, pad=(pad_w, pad_w, pad_h, pad_h), mode='replicate')
+    # One kernel per (image, channel) pair, applied as a grouped convolution.
+    k = k.repeat(1, c, 1, 1)
+    k = k.view(-1, 1, k.shape[2], k.shape[3])
+    x = x.view(1, -1, x.shape[2], x.shape[3])
+    x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
+    x = x.view(n, c, x.shape[2], x.shape[3])
+    return x
+def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
+    """
+    Generate a random anisotropic Gaussian kernel, normalized to sum to 1.
+    # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
+    # Kai Zhang
+    # min_var = 0.175 * sf  # variance of the gaussian kernel will be sampled between min_var and max_var
+    # max_var = 2.5 * sf
+    Args:
+        k_size: numpy array with the two kernel dimensions
+        scale_factor: numpy array with the per-axis scale factor, used to shift the kernel centre
+        min_var, max_var: range the two covariance eigenvalues are sampled from
+        noise_level: amplitude of the multiplicative uniform noise in [-noise_level, noise_level]
+    Returns:
+        the kernel as a numpy array
+    """
+    # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
+    lambda_1 = min_var + np.random.rand() * (max_var - min_var)
+    lambda_2 = min_var + np.random.rand() * (max_var - min_var)
+    theta = np.random.rand() * np.pi  # random theta
+    noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
+    # Set COV matrix using Lambdas and Theta
+    LAMBDA = np.diag([lambda_1, lambda_2])
+    Q = np.array([[np.cos(theta), -np.sin(theta)],
+                  [np.sin(theta), np.cos(theta)]])
+    # SIGMA was used below without being defined; build it as Q LAMBDA Q^T.
+    SIGMA = Q @ LAMBDA @ Q.T
+    INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
+    # Set expectation position (shifting kernel for aligned image)
+    MU = k_size // 2 - 0.5 * (scale_factor - 1)  # - 0.5 * (scale_factor - k_size % 2)
+    MU = MU[None, None, :, None]
+    # Create meshgrid for Gaussian
+    [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
+    Z = np.stack([X, Y], 2)[:, :, :, None]
+    # Calculate Gaussian for every pixel of the kernel
+    ZZ = Z - MU
+    ZZ_t = ZZ.transpose(0, 1, 3, 2)
+    raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
+    # Normalize the kernel and return (no extra centering shift is applied)
+    kernel = raw_kernel / np.sum(raw_kernel)
+    return kernel
+def fspecial_gaussian(hsize, sigma):
+    """
+    Build an `hsize` x `hsize` Gaussian filter with standard deviation `sigma`.
+    Entries below machine epsilon times the peak are zeroed, and the filter is
+    normalized to sum to 1 when its sum is non-zero.
+    """
+    hsize = [hsize, hsize]
+    siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
+    std = sigma
+    [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
+    arg = -(x * x + y * y) / (2 * std * std)
+    h = np.exp(arg)
+    # np.finfo replaces the removed `scipy.finfo` alias.
+    h[h < np.finfo(float).eps * h.max()] = 0
+    sumh = h.sum()
+    if sumh != 0:
+        h = h / sumh
+    return h
+def fspecial_laplacian(alpha):
+    """Build the 3x3 Laplacian filter for `alpha`, which is clamped to [0, 1]."""
+    alpha = max([0, min([alpha, 1])])
+    denom = alpha + 1
+    corner = alpha / denom
+    edge = (1 - alpha) / denom
+    center = -4 / denom
+    return np.array([
+        [corner, edge, corner],
+        [edge, center, edge],
+        [corner, edge, corner],
+    ])
+def fspecial(filter_type, *args, **kwargs):
+    '''
+    Dispatch to the 'gaussian' or 'laplacian' filter builder; any other
+    `filter_type` yields None.
+    python code from:
+    https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
+    '''
+    if filter_type == 'gaussian':
+        result = fspecial_gaussian(*args, **kwargs)
+    elif filter_type == 'laplacian':
+        result = fspecial_laplacian(*args, **kwargs)
+    else:
+        result = None
+    return result
+# --------------------------------------------
+# degradation models
+# --------------------------------------------
+def bicubic_degradation(x, sf=3):
+    '''
+    Downscale `x` by a factor of `sf` via `util.imresize_np`.
+    Args:
+        x: HxWxC image, [0, 1]
+        sf: down-scale factor
+    Return:
+        bicubicly downsampled LR image
+    '''
+    downscale = 1 / sf
+    return util.imresize_np(x, scale=downscale)
+def srmd_degradation(x, k, sf=3):
+    ''' blur + bicubic downsampling
+    Args:
+        x: HxWxC image, [0, 1]
+        k: hxw, double
+        sf: down-scale factor
+    Return:
+        downsampled LR image
+    Reference:
+        @inproceedings{zhang2018learning,
+          title={Learning a single convolutional super-resolution network for multiple degradations},
+          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+          pages={3262--3271},
+          year={2018}
+        }
+    '''
+    # `ndimage.convolve` replaces the deprecated `ndimage.filters` namespace.
+    # The kernel gets a trailing axis so it is applied to each channel separately.
+    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')  # 'nearest' | 'mirror'
+    x = bicubic_degradation(x, sf=sf)
+    return x
+def dpsr_degradation(x, k, sf=3):
+    ''' bicubic downsampling + blur
+    Args:
+        x: HxWxC image, [0, 1]
+        k: hxw, double
+        sf: down-scale factor
+    Return:
+        downsampled LR image
+    Reference:
+        @inproceedings{zhang2019deep,
+          title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
+          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+          pages={1671--1681},
+          year={2019}
+        }
+    '''
+    x = bicubic_degradation(x, sf=sf)
+    # `ndimage.convolve` replaces the deprecated `ndimage.filters` namespace.
+    # The kernel gets a trailing axis so it is applied to each channel separately.
+    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+    return x
+def classical_degradation(x, k, sf=3):
+    ''' blur + downsampling
+    Args:
+        x: HxWxC image, [0, 1]/[0, 255]
+        k: hxw, double
+        sf: down-scale factor
+    Return:
+        downsampled LR image
+    '''
+    # `ndimage.convolve` replaces the deprecated `ndimage.filters` namespace.
+    # The kernel gets a trailing axis so it is applied to each channel separately.
+    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+    # Keep every sf-th pixel along both spatial axes, starting at index 0.
+    st = 0
+    return x[st::sf, st::sf, ...]
+def add_sharpening(img, weight=0.5, radius=50, threshold=10):
+    """USM sharpening. borrowed from real-ESRGAN
+    Input image: I; Blurry image: B.
+    1. K = I + weight * (I - B), clipped to [0, 1]
+    2. Mask = 1 if abs(I - B) * 255 > threshold, else: 0
+    3. Blur mask:
+    4. Out = Mask * K + (1 - Mask) * I
+    Args:
+        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
+        weight (float): Sharp weight. Default: 0.5.
+        radius (float): Kernel size of Gaussian blur; an even value is
+            bumped to the next odd one. Default: 50.
+        threshold (int): residual magnitude (on a 0-255 scale) above which a
+            pixel is sharpened. Default: 10.
+    """
+    if radius % 2 == 0:
+        radius += 1
+    ksize = (radius, radius)
+    blurred = cv2.GaussianBlur(img, ksize, 0)
+    residual = img - blurred
+    hard_mask = (np.abs(residual) * 255 > threshold).astype('float32')
+    soft_mask = cv2.GaussianBlur(hard_mask, ksize, 0)
+    sharpened = np.clip(img + weight * residual, 0, 1)
+    return soft_mask * sharpened + (1 - soft_mask) * img
+def add_blur(img, sf=4):
+    """
+    Blur `img` with a randomly chosen Gaussian kernel.
+    With probability 0.5 an anisotropic kernel is used, otherwise an isotropic
+    one; the kernel side is an odd size in [7, 25] and the kernel widths scale
+    with `sf`.
+    """
+    wd2 = 4.0 + sf
+    wd = 2.0 + 0.2 * sf
+    if random.random() < 0.5:
+        l1 = wd2 * random.random()
+        l2 = wd2 * random.random()
+        k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
+    else:
+        k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random())
+    # `ndimage.convolve` replaces the deprecated `ndimage.filters` namespace.
+    # The kernel gets a trailing axis so it is applied to each channel separately.
+    img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
+    return img
+def add_resize(img, sf=4):
+ rnum = np.random.rand()
+ if rnum > 0.8: # up
+ sf1 = random.uniform(1, 2)
+ elif rnum < 0.7: # down
+ sf1 = random.uniform(0.5 / sf, 1)
+ else:
+ sf1 = 1.0
+ img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+ return img
+# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+# noise_level = random.randint(noise_level1, noise_level2)
+# rnum = np.random.rand()
+# if rnum > 0.6: # add color Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+# elif rnum < 0.4: # add grayscale Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+# else: # add noise
+# L = noise_level2 / 255.
+# D = np.diag(np.random.rand(3))
+# U = orth(np.random.rand(3, 3))
+# conv = np.dot(np.dot(np.transpose(U), D), U)
+# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+# img = np.clip(img, 0.0, 1.0)
+# return img
+def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ rnum = np.random.rand()
+ if rnum > 0.6: # add color Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4: # add grayscale Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else: # add noise
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+def add_speckle_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ img = np.clip(img, 0.0, 1.0)
+ rnum = random.random()
+ if rnum > 0.6:
+ img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4:
+ img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else:
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+def add_Poisson_noise(img):
+ img = np.clip((img * 255.0).round(), 0, 255) / 255.
+ vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
+ if random.random() < 0.5:
+ img = np.random.poisson(img * vals).astype(np.float32) / vals
+ else:
+ img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
+ img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
+ noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
+ img += noise_gray[:, :, np.newaxis]
+ img = np.clip(img, 0.0, 1.0)
+ return img
+def add_JPEG_noise(img):
+ quality_factor = random.randint(30, 95)
+ img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
+ result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
+ img = cv2.imdecode(encimg, 1)
+ img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
+ return img
+def random_crop(lq, hq, sf=4, lq_patchsize=64):
+ h, w = lq.shape[:2]
+ rnd_h = random.randint(0, h - lq_patchsize)
+ rnd_w = random.randint(0, w - lq_patchsize)
+ lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
+ rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
+ hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
+ return lq, hq
+def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+ ----------
+ img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
+ sf: scale factor
+ isp_model: camera ISP model
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+ h1, w1 = img.shape[:2]
+ img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
+ h, w = img.shape[:2]
+ if h < lq_patchsize * sf or w < lq_patchsize * sf:
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
+ hq = img.copy()
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ img = util.imresize_np(img, 1 / 2, True)
+ img = np.clip(img, 0.0, 1.0)
+ sf = 2
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+ if idx1 > idx2: # keep downsample3 last
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+ for i in shuffle_order:
+ if i == 0:
+ img = add_blur(img, sf=sf)
+ elif i == 1:
+ img = add_blur(img, sf=sf)
+ elif i == 2:
+ a, b = img.shape[1], img.shape[0]
+ # downsample2
+ if random.random() < 0.75:
+ sf1 = random.uniform(1, 2 * sf)
+ img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+ img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ img = img[0::sf, 0::sf, ...] # nearest downsampling
+ img = np.clip(img, 0.0, 1.0)
+ elif i == 3:
+ # downsample3
+ img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+ elif i == 4:
+ # add Gaussian noise
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ img = add_JPEG_noise(img)
+ elif i == 6:
+ # add processed camera sensor noise
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+ # add final JPEG compression noise
+ img = add_JPEG_noise(img)
+ # random crop
+ img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
+ return img, hq
+# todo no isp_model?
+def degradation_bsrgan_variant(image, sf=4, isp_model=None):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+ ----------
+ sf: scale factor
+ isp_model: camera ISP model
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+ image = util.uint2single(image)
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+ h1, w1 = image.shape[:2]
+ image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
+ h, w = image.shape[:2]
+ hq = image.copy()
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ image = util.imresize_np(image, 1 / 2, True)
+ image = np.clip(image, 0.0, 1.0)
+ sf = 2
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+ if idx1 > idx2: # keep downsample3 last
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+ for i in shuffle_order:
+ if i == 0:
+ image = add_blur(image, sf=sf)
+ elif i == 1:
+ image = add_blur(image, sf=sf)
+ elif i == 2:
+ a, b = image.shape[1], image.shape[0]
+ # downsample2
+ if random.random() < 0.75:
+ sf1 = random.uniform(1, 2 * sf)
+ image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+ image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ image = image[0::sf, 0::sf, ...] # nearest downsampling
+ image = np.clip(image, 0.0, 1.0)
+ elif i == 3:
+ # downsample3
+ image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ image = np.clip(image, 0.0, 1.0)
+ elif i == 4:
+ # add Gaussian noise
+ image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ image = add_JPEG_noise(image)
+ # elif i == 6:
+ # # add processed camera sensor noise
+ # if random.random() < isp_prob and isp_model is not None:
+ # with torch.no_grad():
+ # img, hq = isp_model.forward(img.copy(), hq)
+ # add final JPEG compression noise
+ image = add_JPEG_noise(image)
+ image = util.single2uint(image)
+ example = {"image":image}
+ return example
+# TODO: in case there is a pickle error, one needs to replace `a += x` with `a = a + x` in add_speckle_noise etc.
+def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None):
+ """
+ This is an extended degradation model by combining
+ the degradation models of BSRGAN and Real-ESRGAN
+ ----------
+ img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
+ sf: scale factor
+ use_shuffle: the degradation shuffle
+ use_sharp: sharpening the img
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+ h1, w1 = img.shape[:2]
+ img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
+ h, w = img.shape[:2]
+ if h < lq_patchsize * sf or w < lq_patchsize * sf:
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
+ if use_sharp:
+ img = add_sharpening(img)
+ hq = img.copy()
+ if random.random() < shuffle_prob:
+ shuffle_order = random.sample(range(13), 13)
+ else:
+ shuffle_order = list(range(13))
+ # local shuffle for noise, JPEG is always the last one
+ shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
+ shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))
+ poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1
+ for i in shuffle_order:
+ if i == 0:
+ img = add_blur(img, sf=sf)
+ elif i == 1:
+ img = add_resize(img, sf=sf)
+ elif i == 2:
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+ elif i == 3:
+ if random.random() < poisson_prob:
+ img = add_Poisson_noise(img)
+ elif i == 4:
+ if random.random() < speckle_prob:
+ img = add_speckle_noise(img)
+ elif i == 5:
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+ elif i == 6:
+ img = add_JPEG_noise(img)
+ elif i == 7:
+ img = add_blur(img, sf=sf)
+ elif i == 8:
+ img = add_resize(img, sf=sf)
+ elif i == 9:
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+ elif i == 10:
+ if random.random() < poisson_prob:
+ img = add_Poisson_noise(img)
+ elif i == 11:
+ if random.random() < speckle_prob:
+ img = add_speckle_noise(img)
+ elif i == 12:
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+ else:
+ print('check the shuffle!')
+ # resize to desired size
+ img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ # add final JPEG compression noise
+ img = add_JPEG_noise(img)
+ # random crop
+ img, hq = random_crop(img, hq, sf, lq_patchsize)
+ return img, hq
+if __name__ == '__main__':
+ print("hey")
+ img = util.imread_uint('utils/test.png', 3)
+ print(img)
+ img = util.uint2single(img)
+ print(img)
+ img = img[:448, :448]
+ h = img.shape[0] // 4
+ print("resizing to", h)
+ sf = 4
+ deg_fn = partial(degradation_bsrgan_variant, sf=sf)
+ for i in range(20):
+ print(i)
+ img_lq = deg_fn(img)
+ print(img_lq)
+ img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"]
+ print(img_lq.shape)
+ print("bicubic", img_lq_bicubic.shape)
+ print(img_hq.shape)
+ lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+ interpolation=0)
+ lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+ interpolation=0)
+ img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
+ util.imsave(img_concat, str(i) + '.png')
diff --git a/ControlNet/ldm/modules/image_degradation/bsrgan_light.py b/ControlNet/ldm/modules/image_degradation/bsrgan_light.py
new file mode 100644
index 0000000000000000000000000000000000000000..808c7f882cb75e2ba2340d5b55881d11927351f0
--- /dev/null
+++ b/ControlNet/ldm/modules/image_degradation/bsrgan_light.py
@@ -0,0 +1,651 @@
+# -*- coding: utf-8 -*-
+import numpy as np
+import cv2
+import torch
+from functools import partial
+import random
+from scipy import ndimage
+import scipy
+import scipy.stats as ss
+from scipy.interpolate import interp2d
+from scipy.linalg import orth
+import albumentations
+import ldm.modules.image_degradation.utils_image as util
+# --------------------------------------------
+# Super-Resolution
+# --------------------------------------------
+# Kai Zhang (cskaizhang@gmail.com)
+# https://github.com/cszn
+# From 2019/03--2021/08
+# --------------------------------------------
+def modcrop_np(img, sf):
+ '''
+ Args:
+ img: numpy image, WxH or WxHxC
+ sf: scale factor
+ Return:
+ cropped image
+ '''
+ w, h = img.shape[:2]
+ im = np.copy(img)
+ return im[:w - w % sf, :h - h % sf, ...]
+# --------------------------------------------
+# anisotropic Gaussian kernels
+# --------------------------------------------
+def analytic_kernel(k):
+ """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
+ k_size = k.shape[0]
+ # Calculate the big kernels size
+ big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
+ # Loop over the small kernel to fill the big one
+ for r in range(k_size):
+ for c in range(k_size):
+ big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
+ # Crop the edges of the big kernel to ignore very small values and increase run time of SR
+ crop = k_size // 2
+ cropped_big_k = big_k[crop:-crop, crop:-crop]
+ # Normalize to 1
+ return cropped_big_k / cropped_big_k.sum()
+def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
+ """ generate an anisotropic Gaussian kernel
+ Args:
+ ksize : e.g., 15, kernel size
+ theta : [0, pi], rotation angle range
+ l1 : [0.1,50], scaling of eigenvalues
+ l2 : [0.1,l1], scaling of eigenvalues
+ If l1 = l2, will get an isotropic Gaussian kernel.
+ Returns:
+ k : kernel
+ """
+ v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
+ V = np.array([[v[0], v[1]], [v[1], -v[0]]])
+ D = np.array([[l1, 0], [0, l2]])
+ Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
+ k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
+ return k
+def gm_blur_kernel(mean, cov, size=15):
+ center = size / 2.0 + 0.5
+ k = np.zeros([size, size])
+ for y in range(size):
+ for x in range(size):
+ cy = y - center + 1
+ cx = x - center + 1
+ k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
+ k = k / np.sum(k)
+ return k
+def shift_pixel(x, sf, upper_left=True):
+ """shift pixel for super-resolution with different scale factors
+ Args:
+ x: WxHxC or WxH
+ sf: scale factor
+ upper_left: shift direction
+ """
+ h, w = x.shape[:2]
+ shift = (sf - 1) * 0.5
+ xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
+ if upper_left:
+ x1 = xv + shift
+ y1 = yv + shift
+ else:
+ x1 = xv - shift
+ y1 = yv - shift
+ x1 = np.clip(x1, 0, w - 1)
+ y1 = np.clip(y1, 0, h - 1)
+ if x.ndim == 2:
+ x = interp2d(xv, yv, x)(x1, y1)
+ if x.ndim == 3:
+ for i in range(x.shape[-1]):
+ x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
+ return x
+def blur(x, k):
+ '''
+ x: image, NxcxHxW
+ k: kernel, Nx1xhxw
+ '''
+ n, c = x.shape[:2]
+ p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
+ x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
+ k = k.repeat(1, c, 1, 1)
+ k = k.view(-1, 1, k.shape[2], k.shape[3])
+ x = x.view(1, -1, x.shape[2], x.shape[3])
+ x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
+ x = x.view(n, c, x.shape[2], x.shape[3])
+ return x
+def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
+ """"
+ # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
+ # Kai Zhang
+ # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
+ # max_var = 2.5 * sf
+ """
+ # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
+ lambda_1 = min_var + np.random.rand() * (max_var - min_var)
+ lambda_2 = min_var + np.random.rand() * (max_var - min_var)
+ theta = np.random.rand() * np.pi # random theta
+ noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
+ # Set COV matrix using Lambdas and Theta
+ LAMBDA = np.diag([lambda_1, lambda_2])
+ Q = np.array([[np.cos(theta), -np.sin(theta)],
+ [np.sin(theta), np.cos(theta)]])
+ INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
+ # Set expectation position (shifting kernel for aligned image)
+ MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
+ MU = MU[None, None, :, None]
+ # Create meshgrid for Gaussian
+ [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
+ Z = np.stack([X, Y], 2)[:, :, :, None]
+ # Calcualte Gaussian for every pixel of the kernel
+ ZZ = Z - MU
+ ZZ_t = ZZ.transpose(0, 1, 3, 2)
+ raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
+ # shift the kernel so it will be centered
+ # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
+ # Normalize the kernel and return
+ # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
+ kernel = raw_kernel / np.sum(raw_kernel)
+ return kernel
+def fspecial_gaussian(hsize, sigma):
+ hsize = [hsize, hsize]
+ siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
+ std = sigma
+ [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
+ arg = -(x * x + y * y) / (2 * std * std)
+ h = np.exp(arg)
+ h[h < scipy.finfo(float).eps * h.max()] = 0
+ sumh = h.sum()
+ if sumh != 0:
+ h = h / sumh
+ return h
+def fspecial_laplacian(alpha):
+ alpha = max([0, min([alpha, 1])])
+ h1 = alpha / (alpha + 1)
+ h2 = (1 - alpha) / (alpha + 1)
+ h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
+ h = np.array(h)
+ return h
+def fspecial(filter_type, *args, **kwargs):
+ '''
+ python code from:
+ https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
+ '''
+ if filter_type == 'gaussian':
+ return fspecial_gaussian(*args, **kwargs)
+ if filter_type == 'laplacian':
+ return fspecial_laplacian(*args, **kwargs)
+# --------------------------------------------
+# degradation models
+# --------------------------------------------
+def bicubic_degradation(x, sf=3):
+ '''
+ Args:
+ x: HxWxC image, [0, 1]
+ sf: down-scale factor
+ Return:
+ bicubicly downsampled LR image
+ '''
+ x = util.imresize_np(x, scale=1 / sf)
+ return x
+def srmd_degradation(x, k, sf=3):
+ ''' blur + bicubic downsampling
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2018learning,
+ title={Learning a single convolutional super-resolution network for multiple degradations},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={3262--3271},
+ year={2018}
+ }
+ '''
+ x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
+ x = bicubic_degradation(x, sf=sf)
+ return x
+def dpsr_degradation(x, k, sf=3):
+ ''' bicubic downsampling + blur
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2019deep,
+ title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={1671--1681},
+ year={2019}
+ }
+ '''
+ x = bicubic_degradation(x, sf=sf)
+ x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ return x
+def classical_degradation(x, k, sf=3):
+ ''' blur + downsampling
+ Args:
+ x: HxWxC image, [0, 1]/[0, 255]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ '''
+ x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
+ st = 0
+ return x[st::sf, st::sf, ...]
+def add_sharpening(img, weight=0.5, radius=50, threshold=10):
+ """USM sharpening. borrowed from real-ESRGAN
+ Input image: I; Blurry image: B.
+ 1. K = I + weight * (I - B)
+ 2. Mask = 1 if abs(I - B) > threshold, else: 0
+ 3. Blur mask:
+ 4. Out = Mask * K + (1 - Mask) * I
+ Args:
+ img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
+ weight (float): Sharp weight. Default: 1.
+ radius (float): Kernel size of Gaussian blur. Default: 50.
+ threshold (int):
+ """
+ if radius % 2 == 0:
+ radius += 1
+ blur = cv2.GaussianBlur(img, (radius, radius), 0)
+ residual = img - blur
+ mask = np.abs(residual) * 255 > threshold
+ mask = mask.astype('float32')
+ soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
+ K = img + weight * residual
+ K = np.clip(K, 0, 1)
+ return soft_mask * K + (1 - soft_mask) * img
+def add_blur(img, sf=4):
+ wd2 = 4.0 + sf
+ wd = 2.0 + 0.2 * sf
+ wd2 = wd2/4
+ wd = wd/4
+ if random.random() < 0.5:
+ l1 = wd2 * random.random()
+ l2 = wd2 * random.random()
+ k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
+ else:
+ k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random())
+ img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
+ return img
+def add_resize(img, sf=4):
+ rnum = np.random.rand()
+ if rnum > 0.8: # up
+ sf1 = random.uniform(1, 2)
+ elif rnum < 0.7: # down
+ sf1 = random.uniform(0.5 / sf, 1)
+ else:
+ sf1 = 1.0
+ img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+ return img
+# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+# noise_level = random.randint(noise_level1, noise_level2)
+# rnum = np.random.rand()
+# if rnum > 0.6: # add color Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+# elif rnum < 0.4: # add grayscale Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+# else: # add noise
+# L = noise_level2 / 255.
+# D = np.diag(np.random.rand(3))
+# U = orth(np.random.rand(3, 3))
+# conv = np.dot(np.dot(np.transpose(U), D), U)
+# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+# img = np.clip(img, 0.0, 1.0)
+# return img
+def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ rnum = np.random.rand()
+ if rnum > 0.6: # add color Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4: # add grayscale Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else: # add noise
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+def add_speckle_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ img = np.clip(img, 0.0, 1.0)
+ rnum = random.random()
+ if rnum > 0.6:
+ img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4:
+ img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else:
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+def add_Poisson_noise(img):
+ img = np.clip((img * 255.0).round(), 0, 255) / 255.
+ vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
+ if random.random() < 0.5:
+ img = np.random.poisson(img * vals).astype(np.float32) / vals
+ else:
+ img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
+ img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
+ noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
+ img += noise_gray[:, :, np.newaxis]
+ img = np.clip(img, 0.0, 1.0)
+ return img
+def add_JPEG_noise(img):
+ quality_factor = random.randint(80, 95)
+ img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
+ result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
+ img = cv2.imdecode(encimg, 1)
+ img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
+ return img
+def random_crop(lq, hq, sf=4, lq_patchsize=64):
+ h, w = lq.shape[:2]
+ rnd_h = random.randint(0, h - lq_patchsize)
+ rnd_w = random.randint(0, w - lq_patchsize)
+ lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
+ rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
+ hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
+ return lq, hq
+def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+ ----------
+ img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
+ sf: scale factor
+ isp_model: camera ISP model
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+ h1, w1 = img.shape[:2]
+ img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
+ h, w = img.shape[:2]
+ if h < lq_patchsize * sf or w < lq_patchsize * sf:
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
+ hq = img.copy()
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ img = util.imresize_np(img, 1 / 2, True)
+ img = np.clip(img, 0.0, 1.0)
+ sf = 2
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+ if idx1 > idx2: # keep downsample3 last
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+ for i in shuffle_order:
+ if i == 0:
+ img = add_blur(img, sf=sf)
+ elif i == 1:
+ img = add_blur(img, sf=sf)
+ elif i == 2:
+ a, b = img.shape[1], img.shape[0]
+ # downsample2
+ if random.random() < 0.75:
+ sf1 = random.uniform(1, 2 * sf)
+ img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+ img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ img = img[0::sf, 0::sf, ...] # nearest downsampling
+ img = np.clip(img, 0.0, 1.0)
+ elif i == 3:
+ # downsample3
+ img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+ elif i == 4:
+ # add Gaussian noise
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8)
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ img = add_JPEG_noise(img)
+ elif i == 6:
+ # add processed camera sensor noise
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+ # add final JPEG compression noise
+ img = add_JPEG_noise(img)
+ # random crop
+ img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
+ return img, hq
+# todo no isp_model?
+def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+ ----------
+ sf: scale factor
+ isp_model: camera ISP model
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+ image = util.uint2single(image)
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+ h1, w1 = image.shape[:2]
+ image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
+ h, w = image.shape[:2]
+ hq = image.copy()
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ image = util.imresize_np(image, 1 / 2, True)
+ image = np.clip(image, 0.0, 1.0)
+ sf = 2
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+ if idx1 > idx2: # keep downsample3 last
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+ for i in shuffle_order:
+ if i == 0:
+ image = add_blur(image, sf=sf)
+ # elif i == 1:
+ # image = add_blur(image, sf=sf)
+ if i == 0:
+ pass
+ elif i == 2:
+ a, b = image.shape[1], image.shape[0]
+ # downsample2
+ if random.random() < 0.8:
+ sf1 = random.uniform(1, 2 * sf)
+ image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+ image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ image = image[0::sf, 0::sf, ...] # nearest downsampling
+ image = np.clip(image, 0.0, 1.0)
+ elif i == 3:
+ # downsample3
+ image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ image = np.clip(image, 0.0, 1.0)
+ elif i == 4:
+ # add Gaussian noise
+ image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ image = add_JPEG_noise(image)
+ #
+ # elif i == 6:
+ # # add processed camera sensor noise
+ # if random.random() < isp_prob and isp_model is not None:
+ # with torch.no_grad():
+ # img, hq = isp_model.forward(img.copy(), hq)
+ # add final JPEG compression noise
+ image = add_JPEG_noise(image)
+ image = util.single2uint(image)
+ if up:
+ image = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_CUBIC) # todo: random, as above? want to condition on it then
+ example = {"image": image}
+ return example
+if __name__ == '__main__':
+ print("hey")
+ img = util.imread_uint('utils/test.png', 3)
+ img = img[:448, :448]
+ h = img.shape[0] // 4
+ print("resizing to", h)
+ sf = 4
+ deg_fn = partial(degradation_bsrgan_variant, sf=sf)
+ for i in range(20):
+ print(i)
+ img_hq = img
+ img_lq = deg_fn(img)["image"]
+ img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
+ print(img_lq)
+ img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
+ print(img_lq.shape)
+ print("bicubic", img_lq_bicubic.shape)
+ print(img_hq.shape)
+ lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+ interpolation=0)
+ lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic),
+ (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+ interpolation=0)
+ img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
+ util.imsave(img_concat, str(i) + '.png')
diff --git a/ControlNet/ldm/modules/image_degradation/utils/test.png b/ControlNet/ldm/modules/image_degradation/utils/test.png
new file mode 100644
index 0000000000000000000000000000000000000000..4249b43de0f22707758d13c240268a401642f6e6
Binary files /dev/null and b/ControlNet/ldm/modules/image_degradation/utils/test.png differ
diff --git a/ControlNet/ldm/modules/image_degradation/utils_image.py b/ControlNet/ldm/modules/image_degradation/utils_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..0175f155ad900ae33c3c46ed87f49b352e3faf98
--- /dev/null
+++ b/ControlNet/ldm/modules/image_degradation/utils_image.py
@@ -0,0 +1,916 @@
+import os
+import math
+import random
+import numpy as np
+import torch
+import cv2
+from torchvision.utils import make_grid
+from datetime import datetime
+#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
+# --------------------------------------------
+# Kai Zhang (github: https://github.com/cszn)
+# 03/Mar/2019
+# --------------------------------------------
+# https://github.com/twhui/SRGAN-pyTorch
+# https://github.com/xinntao/BasicSR
+# --------------------------------------------
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif']


def is_image_file(filename):
    """Return True if *filename* ends with one of IMG_EXTENSIONS.

    The comparison is case-insensitive, so spellings that are not listed
    verbatim (e.g. ``.TIF`` or ``.Png``) are accepted too.

    Args:
        filename (str): file name or path.
    """
    lowered = filename.lower()
    return any(lowered.endswith(extension.lower()) for extension in IMG_EXTENSIONS)
def get_timestamp():
    """Return the current local time formatted as 'yymmdd-HHMMSS'."""
    now = datetime.now()
    return format(now, '%y%m%d-%H%M%S')
def imshow(x, title=None, cbar=False, figsize=None):
    """Display an image in a matplotlib figure using a gray colormap.

    Args:
        x: array-like image; singleton dimensions are squeezed before display.
        title: optional figure title.
        cbar: when true, add a colorbar.
        figsize: optional figure size forwarded to ``plt.figure``.
    """
    # The module-level matplotlib import is commented out, so `plt` was an
    # undefined name here. Import lazily so the module itself still loads
    # without matplotlib installed.
    import matplotlib.pyplot as plt

    plt.figure(figsize=figsize)
    plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')
    if title:
        plt.title(title)
    if cbar:
        plt.colorbar()
    plt.show()
def surf(Z, cmap='rainbow', figsize=None):
    """Show a 2-D array as a 3-D surface plot.

    Args:
        Z: 2-D array of surface heights.
        cmap: matplotlib colormap name.
        figsize: optional figure size forwarded to ``plt.figure``.
    """
    # The module-level matplotlib import is commented out, so `plt` was an
    # undefined name here. Import lazily instead.
    import matplotlib.pyplot as plt

    plt.figure(figsize=figsize)
    ax3 = plt.axes(projection='3d')
    rows, cols = Z.shape[:2]
    # meshgrid(x, y) yields arrays of shape (len(y), len(x)); passing the
    # column axis first makes X/Y match Z's (rows, cols) shape even when Z is
    # not square, while giving the same grids as before for square input.
    X, Y = np.meshgrid(np.arange(0, cols, 1), np.arange(0, rows, 1))
    ax3.plot_surface(X, Y, Z, cmap=cmap)
    plt.show()
+# --------------------------------------------
+# get image paths
+# --------------------------------------------
def get_image_paths(dataroot):
    """Return the sorted image paths found under *dataroot*, or None if *dataroot* is None."""
    if dataroot is None:
        return None
    return sorted(_get_paths_from_images(dataroot))
+def _get_paths_from_images(path):
+ assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
+ images = []
+ for dirpath, _, fnames in sorted(os.walk(path)):
+ for fname in sorted(fnames):
+ if is_image_file(fname):
+ img_path = os.path.join(dirpath, fname)
+ images.append(img_path)
+ assert images, '{:s} has no valid image file'.format(path)
+ return images
+# --------------------------------------------
+# split large images into small images
+# --------------------------------------------
def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
    """Split a large image into overlapping square patches.

    The image is split only when BOTH of its first two dimensions exceed
    *p_max*; otherwise it is returned unchanged as the single list element.

    Args:
        img: array indexed as img[a, b, channel].
        p_size: side length of each patch.
        p_overlap: overlap between neighbouring patches.
        p_max: size threshold above which the image is split.

    Returns:
        list of patches (views into *img*).
    """
    dim0, dim1 = img.shape[:2]
    if not (dim0 > p_max and dim1 > p_max):
        return [img]

    stride = p_size - p_overlap
    # `np.int` was removed in NumPy 1.24; the builtin `int` is the documented
    # replacement and gives the same dtype.
    starts0 = list(np.arange(0, dim0 - p_size, stride, dtype=int))
    starts1 = list(np.arange(0, dim1 - p_size, stride, dtype=int))
    # Always add a final patch flush with the far edge of the image.
    starts0.append(dim0 - p_size)
    starts1.append(dim1 - p_size)
    return [img[i:i + p_size, j:j + p_size, :] for i in starts0 for j in starts1]
def imssave(imgs, img_path):
    """Write a list of images next to *img_path* as '<name>_sNNNN.png'.

    imgs: list of N images; 3-D images have their first three channels
    reversed (index [2, 1, 0]) before being handed to cv2.imwrite.
    """
    out_dir = os.path.dirname(img_path)
    stem = os.path.splitext(os.path.basename(img_path))[0]
    for idx, picture in enumerate(imgs):
        if picture.ndim == 3:
            picture = picture[:, :, [2, 1, 0]]
        target = os.path.join(out_dir, stem + '_s{:04d}'.format(idx) + '.png')
        cv2.imwrite(target, picture)
def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000):
    """Split every image under *original_dataroot* into patches saved in *taget_dataroot*.

    Each image is read with imread_uint, cut by patches_from_image and written
    by imssave under its original base name.

    Args:
        original_dataroot: directory scanned for source images.
        taget_dataroot: directory that receives the patches.
        n_channels: channel count passed to imread_uint.
        p_size: patch side length.
        p_overlap: overlap between neighbouring patches.
        p_max: images are split only if larger than (p_max)x(p_max).
    """
    for source_path in get_image_paths(original_dataroot):
        image = imread_uint(source_path, n_channels=n_channels)
        pieces = patches_from_image(image, p_size, p_overlap, p_max)
        destination = os.path.join(taget_dataroot, os.path.basename(source_path))
        imssave(pieces, destination)
+# --------------------------------------------
+# makedir
+# --------------------------------------------
def mkdir(path):
    """Create directory *path* (including parents) if it does not exist yet.

    Uses ``exist_ok=True`` rather than a separate existence check, so a
    concurrent creation of the same directory no longer raises.
    Note: if *path* exists but is not a directory, FileExistsError is raised.
    """
    os.makedirs(path, exist_ok=True)
def mkdirs(paths):
    """Create one directory (when *paths* is a str) or each directory in an iterable."""
    targets = [paths] if isinstance(paths, str) else paths
    for target in targets:
        mkdir(target)
def mkdir_and_rename(path):
    """Create a fresh directory at *path*, archiving any existing one first.

    An existing *path* is renamed to '<path>_archived_<timestamp>' before the
    new, empty directory is created.
    """
    already_there = os.path.exists(path)
    if already_there:
        archived = path + '_archived_' + get_timestamp()
        print('Path already exists. Rename it to [{:s}]'.format(archived))
        os.rename(path, archived)
    os.makedirs(path)
+# --------------------------------------------
+# read image from path
+# opencv is fast, but read BGR numpy image
+# --------------------------------------------
+# --------------------------------------------
+# get uint8 image of size HxWxn_channels (RGB)
+# --------------------------------------------
def imread_uint(path, n_channels=3):
    """Read an image with OpenCV and return it in RGB (or single-channel) order.

    Args:
        path: image file path.
        n_channels: 1 for an HxWx1 grayscale result, 3 for an HxWx3 result
            (RGB, or the gray channel repeated three times).

    Returns:
        numpy array as returned by cv2, converted as described above.

    Raises:
        ValueError: if *n_channels* is neither 1 nor 3 (previously this fell
            through to an UnboundLocalError).
        OSError: if cv2.imread could not read the file (it returns None).
    """
    if n_channels not in (1, 3):
        raise ValueError('n_channels must be 1 or 3, got {!r}'.format(n_channels))

    flag = cv2.IMREAD_GRAYSCALE if n_channels == 1 else cv2.IMREAD_UNCHANGED
    img = cv2.imread(path, flag)
    if img is None:
        raise OSError('Failed to read image: {}'.format(path))

    if n_channels == 1:
        return np.expand_dims(img, axis=2)  # HxWx1
    if img.ndim == 2:
        return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # GGG
    if img.shape[2] == 4:
        # IMREAD_UNCHANGED keeps an alpha channel; BGR2RGB rejects 4 channels.
        return cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # RGB
+# --------------------------------------------
+# matlab's imwrite
+# --------------------------------------------
def imsave(img, img_path):
    """Squeeze *img* and write it to *img_path* with cv2.imwrite.

    A 3-D result has its first three channels reversed (index [2, 1, 0])
    before writing.
    """
    squeezed = np.squeeze(img)
    to_write = squeezed[:, :, [2, 1, 0]] if squeezed.ndim == 3 else squeezed
    cv2.imwrite(img_path, to_write)
def imwrite(img, img_path):
    """Alias of imsave: squeeze *img*, reverse the channel order of 3-D images and write it.

    The body used to be a verbatim copy of imsave; delegating keeps the two
    entry points from drifting apart.
    """
    imsave(img, img_path)
+# --------------------------------------------
+# get single image of size HxWxn_channels (BGR)
+# --------------------------------------------
def read_img(path):
    """Read an image with cv2 and return it as float32 divided by 255.

    The channel order is whatever cv2.imread produces; a 2-D image gets a
    trailing channel axis and images with more than 3 channels are cut to the
    first three.

    Raises:
        OSError: if cv2.imread could not read the file (it returns None, which
            previously surfaced as an AttributeError on ``.astype``).
    """
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None:
        raise OSError('Failed to read image: {}'.format(path))
    img = img.astype(np.float32) / 255.
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    # some images have 4 channels
    if img.shape[2] > 3:
        img = img[:, :, :3]
    return img
+# --------------------------------------------
+# image format conversion
+# --------------------------------------------
+# numpy(single) <---> numpy(uint)
+# numpy(single) <---> tensor
+# numpy(uint) <---> tensor
+# --------------------------------------------
+# --------------------------------------------
+# numpy(single) [0, 1] <---> numpy(uint)
+# --------------------------------------------
def uint2single(img):
    """Divide *img* by 255 and return the result as float32."""
    scaled = img / 255.
    return np.float32(scaled)
def single2uint(img):
    """Clip *img* to [0, 1], scale by 255, round and return as uint8."""
    clipped = img.clip(0, 1)
    scaled = clipped * 255.
    return np.uint8(scaled.round())
def uint162single(img):
    """Divide *img* by 65535 and return the result as float32."""
    scaled = img / 65535.
    return np.float32(scaled)
def single2uint16(img):
    """Clip *img* to [0, 1], scale by 65535, round and return as uint16."""
    clipped = img.clip(0, 1)
    scaled = clipped * 65535.
    return np.uint16(scaled.round())
+# --------------------------------------------
+# numpy(uint) (HxWxC or HxW) <---> tensor
+# --------------------------------------------
+# convert uint to 4-dimensional torch tensor
def uint2tensor4(img):
    """Convert an HxW or HxWxC array to a 1xCxHxW float tensor divided by 255."""
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    chw = torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1)
    return chw.float().div(255.).unsqueeze(0)
+# convert uint to 3-dimensional torch tensor
def uint2tensor3(img):
    """Convert an HxW or HxWxC array to a CxHxW float tensor divided by 255."""
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    chw = torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1)
    return chw.float().div(255.)
+# convert 2/3/4-dimensional torch tensor to uint
def tensor2uint(img):
    """Convert a 2/3/4-dimensional tensor to a uint8 numpy image.

    The tensor is squeezed, clamped to [0, 1], scaled by 255 and rounded;
    a 3-D result is transposed from CxHxW to HxWxC.

    The clamp is done out of place: `.data.squeeze().float()` can share
    storage with the caller's tensor (e.g. a float32 input), so the previous
    in-place `clamp_` silently overwrote the input values.
    """
    img = img.data.squeeze().float().clamp(0, 1).cpu().numpy()
    if img.ndim == 3:
        img = np.transpose(img, (1, 2, 0))
    return np.uint8((img * 255.0).round())
+# --------------------------------------------
+# numpy(single) (HxWxC) <---> tensor
+# --------------------------------------------
+# convert single (HxWxC) to 3-dimensional torch tensor
def single2tensor3(img):
    """Convert an HxWxC array to a CxHxW float tensor."""
    as_tensor = torch.from_numpy(np.ascontiguousarray(img))
    return as_tensor.permute(2, 0, 1).float()
+# convert single (HxWxC) to 4-dimensional torch tensor
def single2tensor4(img):
    """Convert an HxWxC array to a 1xCxHxW float tensor."""
    as_tensor = torch.from_numpy(np.ascontiguousarray(img))
    return as_tensor.permute(2, 0, 1).float().unsqueeze(0)
+# convert torch tensor to single
def tensor2single(img):
    """Squeeze a tensor to a float numpy array; a 3-D result becomes HxWxC."""
    arr = img.data.squeeze().float().cpu().numpy()
    return np.transpose(arr, (1, 2, 0)) if arr.ndim == 3 else arr
+# convert torch tensor to single
def tensor2single3(img):
    """Squeeze a tensor to a float numpy array that keeps a channel axis.

    A 3-D result is transposed to HxWxC; a 2-D result gets a trailing axis.
    """
    arr = img.data.squeeze().float().cpu().numpy()
    rank = arr.ndim
    if rank == 3:
        return np.transpose(arr, (1, 2, 0))
    if rank == 2:
        return np.expand_dims(arr, axis=2)
    return arr
def single2tensor5(img):
    """Permute a 4-D array with axes (2, 0, 1, 3) and add a leading batch axis."""
    as_tensor = torch.from_numpy(np.ascontiguousarray(img))
    return as_tensor.permute(2, 0, 1, 3).float().unsqueeze(0)
def single32tensor5(img):
    """Convert an array to a float tensor with two leading singleton axes."""
    as_tensor = torch.from_numpy(np.ascontiguousarray(img)).float()
    return as_tensor.unsqueeze(0).unsqueeze(0)
def single42tensor4(img):
    """Permute a 4-D array with axes (2, 0, 1, 3) and return it as a float tensor."""
    as_tensor = torch.from_numpy(np.ascontiguousarray(img))
    return as_tensor.permute(2, 0, 1, 3).float()
+# from skimage.io import imread, imsave
def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
    '''
    Converts a torch Tensor into an image Numpy array of BGR channel order
    Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
    Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)

    The tensor is squeezed first, so the 4D/3D/2D branch is chosen from the
    shape that remains after singleton dimensions are removed.

    The input tensor is left untouched: clamping is done out of place because
    `squeeze().float().cpu()` can share storage with the caller's tensor (a
    float32 CPU input), which the former in-place `clamp_` overwrote.

    Raises:
        TypeError: if the squeezed tensor is not 2D, 3D or 4D.
    '''
    low, high = min_max
    tensor = tensor.squeeze().float().cpu().clamp(low, high)  # squeeze first, then clamp
    tensor = (tensor - low) / (high - low)  # to range [0,1]
    n_dim = tensor.dim()
    if n_dim == 4:
        n_img = len(tensor)
        img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 3:
        img_np = tensor.numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 2:
        img_np = tensor.numpy()
    else:
        raise TypeError(
            'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
    if out_type == np.uint8:
        # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
        img_np = (img_np * 255.0).round()
    return img_np.astype(out_type)
+# --------------------------------------------
+# Augmentation, flip and/or rotate
+# --------------------------------------------
+# The following two are enough.
+# (1) augment_img: numpy image of WxHxC or WxH
+# (2) augment_img_tensor4: tensor image 1xCxWxH
+# --------------------------------------------
def augment_img(img, mode=0):
    '''Apply one of eight flip/rotate augmentations to a numpy image.

    Kai Zhang (github: https://github.com/cszn)

    mode 0 returns the input unchanged; modes 1-7 combine np.rot90 and
    np.flipud. Any other mode returns None.
    '''
    transforms = {
        0: lambda a: a,
        1: lambda a: np.flipud(np.rot90(a)),
        2: lambda a: np.flipud(a),
        3: lambda a: np.rot90(a, k=3),
        4: lambda a: np.flipud(np.rot90(a, k=2)),
        5: lambda a: np.rot90(a),
        6: lambda a: np.rot90(a, k=2),
        7: lambda a: np.flipud(np.rot90(a, k=3)),
    }
    chosen = transforms.get(mode)
    if chosen is None:
        return None
    return chosen(img)
def augment_img_tensor4(img, mode=0):
    '''Apply one of eight flip/rotate augmentations over dims 2 and 3 of a tensor.

    Kai Zhang (github: https://github.com/cszn)

    mode 0 returns the input unchanged; modes 1-7 combine rot90 over dims
    [2, 3] with a flip of dim 2. Any other mode returns None.
    '''
    plane = [2, 3]
    transforms = {
        0: lambda t: t,
        1: lambda t: t.rot90(1, plane).flip([2]),
        2: lambda t: t.flip([2]),
        3: lambda t: t.rot90(3, plane),
        4: lambda t: t.rot90(2, plane).flip([2]),
        5: lambda t: t.rot90(1, plane),
        6: lambda t: t.rot90(2, plane),
        7: lambda t: t.rot90(3, plane).flip([2]),
    }
    chosen = transforms.get(mode)
    if chosen is None:
        return None
    return chosen(img)
def augment_img_tensor(img, mode=0):
    '''Augment a 3-D or 4-D tensor by round-tripping through augment_img on numpy.

    Kai Zhang (github: https://github.com/cszn)

    The tensor is moved to numpy with its last two axes brought to the front,
    augmented with augment_img, permuted back and cast to the input's type.
    '''
    to_numpy_axes = {3: (1, 2, 0), 4: (2, 3, 1, 0)}
    to_tensor_axes = {3: (2, 0, 1), 4: (3, 2, 0, 1)}
    rank = len(img.size())

    array = img.data.cpu().numpy()
    if rank in to_numpy_axes:
        array = np.transpose(array, to_numpy_axes[rank])
    array = augment_img(array, mode=mode)

    result = torch.from_numpy(np.ascontiguousarray(array))
    if rank in to_tensor_axes:
        result = result.permute(*to_tensor_axes[rank])
    return result.type_as(img)
def augment_img_np3(img, mode=0):
    """Apply one of eight flip/transpose augmentations to a 3-D numpy image.

    Each mode is a combination of, in this order: reversing axis 1, reversing
    axis 0, and swapping axes 0 and 1. Modes outside 0-7 return None.
    """
    # mode -> (reverse axis 1, reverse axis 0, swap axes 0/1)
    recipe = {
        0: (False, False, False),
        1: (False, False, True),
        2: (False, True, False),
        3: (False, True, True),
        4: (True, False, False),
        5: (True, False, True),
        6: (True, True, False),
        7: (True, True, True),
    }.get(mode)
    if recipe is None:
        return None
    reverse_axis1, reverse_axis0, swap = recipe
    if reverse_axis1:
        img = img[:, ::-1, :]
    if reverse_axis0:
        img = img[::-1, :, :]
    if swap:
        img = img.transpose(1, 0, 2)
    return img
def augment_imgs(img_list, hflip=True, rot=True):
    """Apply the same random flip/transpose to every image in *img_list*.

    One coin flip (p=0.5) is drawn per enabled operation: reversing axis 1
    when *hflip* is set, and reversing axis 0 and swapping axes 0/1 when
    *rot* is set. The draws happen once and are shared by all images.
    """
    # Short-circuit `and` keeps the number of random draws dependent on the flags.
    flip_axis1 = hflip and random.random() < 0.5
    flip_axis0 = rot and random.random() < 0.5
    swap_axes = rot and random.random() < 0.5

    augmented = []
    for picture in img_list:
        if flip_axis1:
            picture = picture[:, ::-1, :]
        if flip_axis0:
            picture = picture[::-1, :, :]
        if swap_axes:
            picture = picture.transpose(1, 0, 2)
        augmented.append(picture)
    return augmented
+# --------------------------------------------
+# modcrop and shave
+# --------------------------------------------
def modcrop(img_in, scale):
    """Return a copy of *img_in* cropped so H and W are multiples of *scale*.

    img_in: Numpy, HWC or HW. Raises ValueError for any other rank.
    """
    img = np.copy(img_in)
    if img.ndim not in (2, 3):
        raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
    H, W = img.shape[:2]
    return img[:H - H % scale, :W - W % scale]
def shave(img_in, border=0):
    """Return a copy of *img_in* (Numpy, HWC or HW) with *border* pixels removed from each side."""
    copied = np.copy(img_in)
    rows, cols = copied.shape[:2]
    return copied[border:rows - border, border:cols - border]
+# --------------------------------------------
+# image processing process on numpy image
+# channel_convert(in_c, tar_type, img_list):
+# rgb2ycbcr(img, only_y=True):
+# bgr2ycbcr(img, only_y=True):
+# ycbcr2rgb(img):
+# --------------------------------------------
def rgb2ycbcr(img, only_y=True):
    '''same as matlab rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    Output:
        array of the input dtype; uint8 results are rounded, other dtypes are
        scaled back by 1/255.

    The input array is never modified. The previous code discarded the result
    of `img.astype(np.float32)` and then ran `img *= 255.` in place, which
    overwrote the caller's float array on every call.
    '''
    in_img_type = img.dtype
    if in_img_type != np.uint8:
        img = img * 255.  # new array; leaves the caller's data intact
    # convert
    if only_y:
        rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
                              [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
def ycbcr2rgb(img):
    '''same as matlab ycbcr2rgb
    Input:
        uint8, [0, 255]
        float, [0, 1]
    Output:
        array of the input dtype; uint8 results are rounded and saturated to
        [0, 255], other dtypes are scaled back by 1/255.

    The input array is never modified. The previous code discarded the result
    of `img.astype(np.float32)` and then ran `img *= 255.` in place, which
    overwrote the caller's float array. uint8 results are now clipped before
    the cast so out-of-gamut values saturate instead of wrapping around.
    '''
    in_img_type = img.dtype
    if in_img_type != np.uint8:
        img = img * 255.  # new array; leaves the caller's data intact
    # convert
    rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
                          [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
    if in_img_type == np.uint8:
        rlt = np.clip(rlt.round(), 0, 255)
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
def bgr2ycbcr(img, only_y=True):
    '''bgr version of rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    Output:
        array of the input dtype; uint8 results are rounded, other dtypes are
        scaled back by 1/255.

    The input array is never modified. The previous code discarded the result
    of `img.astype(np.float32)` and then ran `img *= 255.` in place, which
    overwrote the caller's float array on every call.
    '''
    in_img_type = img.dtype
    if in_img_type != np.uint8:
        img = img * 255.  # new array; leaves the caller's data intact
    # convert
    if only_y:
        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
                              [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
def channel_convert(in_c, tar_type, img_list):
    """Convert a list of images among BGR, gray and y.

    (3, 'gray') converts BGR to gray and (3, 'y') converts BGR to the Y
    channel, both with a trailing channel axis; (1, 'RGB') expands gray/y to
    BGR via cv2.COLOR_GRAY2BGR. Any other combination returns *img_list* as is.
    """
    route = (in_c, tar_type)
    if route == (3, 'gray'):  # BGR to gray
        converted = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
    elif route == (3, 'y'):  # BGR to y
        converted = [bgr2ycbcr(img, only_y=True) for img in img_list]
    elif route == (1, 'RGB'):  # gray/y to BGR
        return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
    else:
        return img_list
    return [np.expand_dims(img, axis=2) for img in converted]
+# --------------------------------------------
+# metric, PSNR and SSIM
+# --------------------------------------------
+# --------------------------------------------
+# --------------------------------------------
def calculate_psnr(img1, img2, border=0):
    """Return the PSNR (dB, peak value 255) between two equally shaped images.

    img1 and img2 have range [0, 255]. *border* pixels are cropped from each
    side of the first two axes before comparison. Identical images give inf.

    Raises:
        ValueError: if the two images differ in shape.
    """
    if img1.shape != img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    rows, cols = img1.shape[:2]
    window = (slice(border, rows - border), slice(border, cols - border))
    first = img1[window].astype(np.float64)
    second = img2[window].astype(np.float64)
    mse = np.mean((first - second) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))
+# --------------------------------------------
+# --------------------------------------------
def calculate_ssim(img1, img2, border=0):
    '''calculate SSIM
    the same outputs as MATLAB's
    img1, img2: [0, 255]

    *border* pixels are cropped from each side of the first two axes. A 2-D
    image is scored directly; a 3-D image is scored per channel and the mean
    is returned. This now covers any channel count: previously a 3-D input
    whose last axis was neither 1 nor 3 silently returned None.

    Raises:
        ValueError: if the shapes differ or the images are not 2-D or 3-D.
    '''
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    img1 = img1[border:h-border, border:w-border]
    img2 = img2[border:h-border, border:w-border]
    if img1.ndim == 2:
        return ssim(img1, img2)
    if img1.ndim == 3:
        per_channel = [ssim(img1[:, :, c], img2[:, :, c]) for c in range(img1.shape[2])]
        return np.array(per_channel).mean()
    raise ValueError('Wrong input image dimensions.')
def ssim(img1, img2):
    """Return the mean SSIM of two single-channel images.

    Uses an 11x11 Gaussian window (sigma 1.5) and crops 5 pixels from each
    side of every filtered map so only fully covered positions are kept.
    """
    C1 = (0.01 * 255)**2
    C2 = (0.03 * 255)**2
    first = img1.astype(np.float64)
    second = img2.astype(np.float64)
    gauss = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(gauss, gauss.transpose())

    def local_mean(values):
        # Gaussian-weighted local mean, 'valid' region only.
        return cv2.filter2D(values, -1, window)[5:-5, 5:-5]

    mu1 = local_mean(first)
    mu2 = local_mean(second)
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = local_mean(first**2) - mu1_sq
    sigma2_sq = local_mean(second**2) - mu2_sq
    sigma12 = local_mean(first * second) - mu1_mu2
    numerator = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
    denominator = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    ssim_map = numerator / denominator
    return ssim_map.mean()
+# --------------------------------------------
+# matlab's bicubic imresize (numpy and torch) [0, 1]
+# --------------------------------------------
+# matlab 'imresize' function, now only support 'bicubic'
def cubic(x):
    """Evaluate the piecewise cubic interpolation kernel on tensor *x*.

    Uses one polynomial for |x| <= 1, another for 1 < |x| <= 2 and is zero
    elsewhere.
    """
    absx = torch.abs(x)
    absx2 = absx**2
    absx3 = absx**3
    inner_poly = 1.5*absx3 - 2.5*absx2 + 1
    outer_poly = -0.5*absx3 + 2.5*absx2 - 4*absx + 2
    inner_mask = (absx <= 1).type_as(absx)
    outer_mask = ((absx > 1)*(absx <= 2)).type_as(absx)
    return inner_poly * inner_mask + outer_poly * outer_mask
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
    """Compute cubic resampling weights and input indices along one axis.

    Args:
        in_length: number of input samples along the axis.
        out_length: number of output samples along the axis.
        scale: resize factor (values below 1 shrink).
        kernel: not used by this function; the cubic kernel is always applied.
        kernel_width: support width of the kernel before antialias widening.
        antialiasing: when true and scale < 1, widen the kernel by 1/scale.

    Returns:
        (weights, indices, sym_len_s, sym_len_e): row k of `weights` and
        `indices` belongs to output sample k; `indices` is shifted by
        `sym_len_s - 1`, and the two ints are derived from the smallest and
        largest index reached relative to the input range.
    """
    if (scale < 1) and (antialiasing):
        # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
        kernel_width = kernel_width / scale
    # Output-space coordinates
    x = torch.linspace(1, out_length, out_length)
    # Input-space coordinates. Calculate the inverse mapping such that 0.5
    # in output space maps to 0.5 in input space, and 0.5+scale in output
    # space maps to 1.5 in input space.
    u = x / scale + 0.5 * (1 - 1 / scale)
    # What is the left-most pixel that can be involved in the computation?
    left = torch.floor(u - kernel_width / 2)
    # What is the maximum number of pixels that can be involved in the
    # computation? Note: it's OK to use an extra pixel here; if the
    # corresponding weights are all zero, it will be eliminated at the end
    # of this function.
    P = math.ceil(kernel_width) + 2
    # The indices of the input pixels involved in computing the k-th output
    # pixel are in row k of the indices matrix.
    indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
        1, P).expand(out_length, P)
    # The weights used to compute the k-th output pixel are in row k of the
    # weights matrix.
    distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
    # apply cubic kernel
    if (scale < 1) and (antialiasing):
        weights = scale * cubic(distance_to_center * scale)
    else:
        weights = cubic(distance_to_center)
    # Normalize the weights matrix so that each row sums to 1.
    weights_sum = torch.sum(weights, 1).view(out_length, 1)
    weights = weights / weights_sum.expand(out_length, P)
    # If a column in weights is all zero, get rid of it. only consider the first and last column.
    # weights_zero_tmp holds, per column, the count of zero entries.
    weights_zero_tmp = torch.sum((weights == 0), 0)
    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 1, P - 2)
        weights = weights.narrow(1, 1, P - 2)
    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 0, P - 2)
        weights = weights.narrow(1, 0, P - 2)
    weights = weights.contiguous()
    indices = indices.contiguous()
    # Padding lengths, then shift indices by sym_len_s - 1.
    sym_len_s = -indices.min() + 1
    sym_len_e = indices.max() - in_length
    indices = indices + sym_len_s - 1
    return weights, indices, int(sym_len_s), int(sym_len_e)
+# --------------------------------------------
+# imresize for tensor image [0, 1]
+# --------------------------------------------
def imresize(img, scale, antialiasing=True):
    """Resize a tensor image with a cubic kernel, first along H and then along W.

    Now the scale should be the same for H and W.

    Args:
        img: pytorch tensor, CHW or HW [0,1].
        scale: resize factor applied to both H and W.
        antialiasing: forwarded to calculate_weights_indices.

    Returns:
        float tensor, CHW or HW [0,1] w/o round.

    The input is not modified: an HW input used to be reshaped in place with
    `img.unsqueeze_(0)`, leaving the caller holding a 1xHxW tensor afterwards.
    """
    need_squeeze = img.dim() == 2
    if need_squeeze:
        img = img.unsqueeze(0)  # out of place; keep the caller's tensor HW
    in_C, in_H, in_W = img.size()
    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
    kernel_width = 4
    kernel = 'cubic'
    # Return the desired dimension order for performing the resize. The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.
    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing)
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing)
    # process H dimension
    # symmetric copying: pad top and bottom with mirrored rows
    img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
    img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
    sym_patch = img[:, :sym_len_Hs, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
    sym_patch = img[:, -sym_len_He:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
    out_1 = torch.FloatTensor(in_C, out_H, in_W)
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        for j in range(out_C):
            out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
    # process W dimension
    # symmetric copying: pad left and right with mirrored columns
    out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
    out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
    sym_patch = out_1[:, :, :sym_len_Ws]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
    sym_patch = out_1[:, :, -sym_len_We:]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
    out_2 = torch.FloatTensor(in_C, out_H, out_W)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        for j in range(out_C):
            out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
    if need_squeeze:
        out_2.squeeze_()
    return out_2
+# --------------------------------------------
+# imresize for numpy image [0, 1]
+# --------------------------------------------
def imresize_np(img, scale, antialiasing=True):
    """Resize a numpy image with a cubic kernel, first along H and then along W.

    Args:
        img: Numpy, HWC or HW [0,1].
        scale: resize factor applied to both H and W.
        antialiasing: forwarded to calculate_weights_indices.

    Returns:
        numpy array, HWC or HW [0,1] w/o round.
    """
    # Now the scale should be the same for H and W
    # input: img: Numpy, HWC or HW [0,1]
    # output: HWC or HW [0,1] w/o round
    img = torch.from_numpy(img)
    need_squeeze = True if img.dim() == 2 else False
    if need_squeeze:
        img.unsqueeze_(2)
    in_H, in_W, in_C = img.size()
    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
    kernel_width = 4
    kernel = 'cubic'
    # Return the desired dimension order for performing the resize. The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.
    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing)
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing)
    # process H dimension
    # symmetric copying: pad top and bottom with mirrored rows
    img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
    img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
    sym_patch = img[:sym_len_Hs, :, :]
    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(0, inv_idx)
    img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
    sym_patch = img[-sym_len_He:, :, :]
    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(0, inv_idx)
    img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
    out_1 = torch.FloatTensor(out_H, in_W, in_C)
    # From here on kernel_width is the number of taps per output sample.
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        for j in range(out_C):
            out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
    # process W dimension
    # symmetric copying: pad left and right with mirrored columns
    out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
    out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
    sym_patch = out_1[:, :sym_len_Ws, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
    sym_patch = out_1[:, -sym_len_We:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
    out_2 = torch.FloatTensor(out_H, out_W, in_C)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        for j in range(out_C):
            out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
    if need_squeeze:
        out_2.squeeze_()
    return out_2.numpy()
if __name__ == '__main__':
    # Placeholder entry point: running the module as a script only prints a separator.
    print('---')
# Example usage, left disabled:
# img = imread_uint('test.bmp', 3)
# img = uint2single(img)
# img_bicubic = imresize_np(img, 1/4)
\ No newline at end of file
diff --git a/ControlNet/ldm/modules/midas/__init__.py b/ControlNet/ldm/modules/midas/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ControlNet/ldm/modules/midas/api.py b/ControlNet/ldm/modules/midas/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..b58ebbffd942a2fc22264f0ab47e400c26b9f41c
--- /dev/null
+++ b/ControlNet/ldm/modules/midas/api.py
@@ -0,0 +1,170 @@
+# based on https://github.com/isl-org/MiDaS
+import cv2
+import torch
+import torch.nn as nn
+from torchvision.transforms import Compose
+from ldm.modules.midas.midas.dpt_depth import DPTDepthModel
+from ldm.modules.midas.midas.midas_net import MidasNet
+from ldm.modules.midas.midas.midas_net_custom import MidasNet_small
+from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet
# Checkpoint path per model type, looked up by load_model. The opening and
# closing lines of this dict literal were missing, leaving bare key/value
# lines that do not parse.
ISL_PATHS = {
    "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt",
    "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt",
    "midas_v21": "",
    "midas_v21_small": "",
}
def disabled_train(self, mode=True):
    """Replacement for ``model.train`` that ignores *mode* and changes nothing.

    Assigning this over a model's ``train`` keeps its train/eval state fixed.
    """
    # Deliberately a no-op: hand back the first argument untouched.
    return self
def load_midas_transform(model_type):
    """Build the preprocessing transform for a MiDaS model type (no weights loaded).

    Based on https://github.com/isl-org/MiDaS/blob/master/run.py

    Args:
        model_type: one of "dpt_large", "dpt_hybrid", "midas_v21",
            "midas_v21_small".

    Returns:
        a Compose of Resize, NormalizeImage and PrepareForNet.

    Raises:
        AssertionError: for an unknown *model_type*. Raised explicitly (same
            type and message as the former `assert False`) so the check still
            runs when Python is started with -O.
    """
    # model_type -> (net_w, net_h, resize_mode, mean, std)
    settings = {
        "dpt_large": (384, 384, "minimal", [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # DPT-Large
        "dpt_hybrid": (384, 384, "minimal", [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # DPT-Hybrid
        "midas_v21": (384, 384, "upper_bound", [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        "midas_v21_small": (256, 256, "upper_bound", [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    }
    if model_type not in settings:
        raise AssertionError(f"model_type '{model_type}' not implemented, use: --model_type large")

    net_w, net_h, resize_mode, mean, std = settings[model_type]
    normalization = NormalizeImage(mean=mean, std=std)
    transform = Compose(
        [
            Resize(
                net_w,
                net_h,
                resize_target=None,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method=resize_mode,
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            normalization,
            PrepareForNet(),
        ]
    )
    return transform
def load_model(model_type):
    """Load a MiDaS network in eval mode together with its preprocessing transform.

    Based on https://github.com/isl-org/MiDaS/blob/master/run.py

    Args:
        model_type: key of ISL_PATHS selecting the architecture and checkpoint.

    Returns:
        (model.eval(), transform)

    Raises:
        KeyError: if *model_type* has no entry in ISL_PATHS.
        AssertionError: if *model_type* has a path but no architecture branch.
    """
    # load network
    model_path = ISL_PATHS[model_type]
    if model_type == "dpt_large":  # DPT-Large
        model = DPTDepthModel(
            path=model_path,
            backbone="vitl16_384",
            non_negative=True,
        )
    elif model_type == "dpt_hybrid":  # DPT-Hybrid
        model = DPTDepthModel(
            path=model_path,
            backbone="vitb_rn50_384",
            non_negative=True,
        )
    elif model_type == "midas_v21":
        model = MidasNet(model_path, non_negative=True)
    elif model_type == "midas_v21_small":
        model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
                               non_negative=True, blocks={'expand': True})
    else:
        message = f"model_type '{model_type}' not implemented, use: --model_type large"
        print(message)
        # Explicit raise so the check is not stripped under -O.
        raise AssertionError(message)

    # The per-type resize/normalisation settings were duplicated here verbatim;
    # build the transform through the shared helper instead.
    transform = load_midas_transform(model_type)
    return model.eval(), transform
class MiDaSInference(nn.Module):
    """Frozen MiDaS depth predictor wrapped as an nn.Module."""

    # NOTE(review): the assignment lines that open the two lists below were
    # missing. MODEL_TYPES_ISL is the name __init__ reads; the name of the
    # first list is presumed to be MODEL_TYPES_TORCH_HUB - confirm against upstream.
    MODEL_TYPES_TORCH_HUB = [
        "DPT_Large",
        "DPT_Hybrid",
        "MiDaS_small"
    ]
    MODEL_TYPES_ISL = [
        "dpt_large",
        "dpt_hybrid",
        "midas_v21",
        "midas_v21_small",
    ]

    def __init__(self, model_type):
        """Load the network for *model_type*, which must be in MODEL_TYPES_ISL."""
        super().__init__()
        assert (model_type in self.MODEL_TYPES_ISL)
        model, _ = load_model(model_type)
        self.model = model
        # Replace train() with a no-op so the mode set by load_model sticks.
        self.model.train = disabled_train

    def forward(self, x):
        """Predict depth for batch *x* and resize it to x's spatial size.

        Returns a tensor of shape (x.shape[0], 1, x.shape[2], x.shape[3]).
        """
        # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array
        # NOTE: we expect that the correct transform has been called during dataloading.
        with torch.no_grad():
            prediction = self.model(x)
            prediction = torch.nn.functional.interpolate(
                prediction.unsqueeze(1),
                size=x.shape[2:],
                mode="bicubic",
                align_corners=False,
            )
        assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3])
        return prediction
diff --git a/ControlNet/ldm/modules/midas/midas/__init__.py b/ControlNet/ldm/modules/midas/midas/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ControlNet/ldm/modules/midas/midas/base_model.py b/ControlNet/ldm/modules/midas/midas/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cf430239b47ec5ec07531263f26f5c24a2311cd
--- /dev/null
+++ b/ControlNet/ldm/modules/midas/midas/base_model.py
@@ -0,0 +1,16 @@
+import torch
+class BaseModel(torch.nn.Module):
+ """Base ``nn.Module`` that adds checkpoint loading through ``load``."""
+ def load(self, path):
+ """Load model from file.
+ The file is read onto the CPU. If the loaded object contains an
+ "optimizer" key it is treated as a training checkpoint and the state
+ dict is taken from its "model" entry; otherwise the loaded object is
+ used as the state dict directly.
+ Args:
+ path (str): file path
+ """
+ # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
+ parameters = torch.load(path, map_location=torch.device('cpu'))
+ if "optimizer" in parameters:
+ parameters = parameters["model"]
+ self.load_state_dict(parameters)
diff --git a/ControlNet/ldm/modules/midas/midas/blocks.py b/ControlNet/ldm/modules/midas/midas/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..2145d18fa98060a618536d9a64fe6589e9be4f78
--- /dev/null
+++ b/ControlNet/ldm/modules/midas/midas/blocks.py
@@ -0,0 +1,342 @@
+import torch
+import torch.nn as nn
+from .vit import (
+ _make_pretrained_vitb_rn50_384,
+ _make_pretrained_vitl16_384,
+ _make_pretrained_vitb16_384,
+ forward_vit,
+def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
+ """Build the (pretrained, scratch) module pair for a backbone name.
+ Args:
+ backbone (str): "vitl16_384", "vitb_rn50_384", "vitb16_384",
+ "resnext101_wsl" or "efficientnet_lite3".
+ features (int): output channel count passed to ``_make_scratch``.
+ use_pretrained (bool): forwarded to the backbone constructor.
+ groups (int, optional): conv groups for the scratch layers. Defaults to 1.
+ expand (bool, optional): forwarded to ``_make_scratch``. Defaults to False.
+ exportable (bool, optional): only used by "efficientnet_lite3".
+ hooks (list, optional): only used by the ViT backbones.
+ use_vit_only (bool, optional): only used by "vitb_rn50_384".
+ use_readout (str, optional): only used by the ViT backbones.
+ Returns:
+ tuple: (pretrained, scratch)
+ Raises:
+ AssertionError: for an unknown backbone name (after printing a message).
+ """
+ # The list passed to _make_scratch is the per-stage input channel count of each backbone.
+ if backbone == "vitl16_384":
+ pretrained = _make_pretrained_vitl16_384(
+ use_pretrained, hooks=hooks, use_readout=use_readout
+ )
+ scratch = _make_scratch(
+ [256, 512, 1024, 1024], features, groups=groups, expand=expand
+ ) # ViT-L/16 - 85.0% Top1 (backbone)
+ elif backbone == "vitb_rn50_384":
+ pretrained = _make_pretrained_vitb_rn50_384(
+ use_pretrained,
+ hooks=hooks,
+ use_vit_only=use_vit_only,
+ use_readout=use_readout,
+ )
+ scratch = _make_scratch(
+ [256, 512, 768, 768], features, groups=groups, expand=expand
+ ) # ViT-H/16 - 85.0% Top1 (backbone)
+ elif backbone == "vitb16_384":
+ pretrained = _make_pretrained_vitb16_384(
+ use_pretrained, hooks=hooks, use_readout=use_readout
+ )
+ scratch = _make_scratch(
+ [96, 192, 384, 768], features, groups=groups, expand=expand
+ ) # ViT-B/16 - 84.6% Top1 (backbone)
+ elif backbone == "resnext101_wsl":
+ pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
+ scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # resnext101_wsl
+ elif backbone == "efficientnet_lite3":
+ pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
+ scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
+ else:
+ print(f"Backbone '{backbone}' not implemented")
+ assert False
+ return pretrained, scratch
+def _make_scratch(in_shape, out_shape, groups=1, expand=False):
+ """Create the four 3x3 projection convs that map encoder stages to decoder width.
+ Args:
+ in_shape (list): input channel count for each of the four stages.
+ out_shape (int): base output channel count.
+ groups (int, optional): conv groups. Defaults to 1.
+ expand (bool, optional): if True, stages 1-4 output
+ out_shape, 2*out_shape, 4*out_shape and 8*out_shape channels;
+ otherwise all four output out_shape channels. Defaults to False.
+ Returns:
+ nn.Module: container with layer1_rn .. layer4_rn attributes.
+ """
+ scratch = nn.Module()
+ out_shape1 = out_shape
+ out_shape2 = out_shape
+ out_shape3 = out_shape
+ out_shape4 = out_shape
+ if expand==True:
+ out_shape1 = out_shape
+ out_shape2 = out_shape*2
+ out_shape3 = out_shape*4
+ out_shape4 = out_shape*8
+ # stride 1 / padding 1 keeps the spatial size; the convs have no bias.
+ scratch.layer1_rn = nn.Conv2d(
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer2_rn = nn.Conv2d(
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer3_rn = nn.Conv2d(
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer4_rn = nn.Conv2d(
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ return scratch
+def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
+ """Fetch tf_efficientnet_lite3 through torch.hub and wrap it as a four-stage backbone.
+ Args:
+ use_pretrained (bool): passed to the hub entry point as ``pretrained``.
+ exportable (bool, optional): passed to the hub entry point. Defaults to False.
+ Returns:
+ nn.Module: result of ``_make_efficientnet_backbone``.
+ """
+ # NOTE(review): torch.hub.load may need network access to fetch the repository -- confirm for offline use.
+ efficientnet = torch.hub.load(
+ "rwightman/gen-efficientnet-pytorch",
+ "tf_efficientnet_lite3",
+ pretrained=use_pretrained,
+ exportable=exportable
+ )
+ return _make_efficientnet_backbone(efficientnet)
+def _make_efficientnet_backbone(effnet):
+ """Regroup an EfficientNet into four sequential stages (layer1 .. layer4).
+ Args:
+ effnet: model exposing conv_stem, bn1, act1 and an indexable ``blocks``.
+ Returns:
+ nn.Module: container with layer1 .. layer4 attributes.
+ """
+ pretrained = nn.Module()
+ # Stage 1 is the stem plus blocks 0-1; the later stages are blocks 2, 3-4 and 5-8.
+ pretrained.layer1 = nn.Sequential(
+ effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
+ )
+ pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
+ pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
+ pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
+ return pretrained
+def _make_resnet_backbone(resnet):
+ """Regroup a ResNet-style model into four stages (layer1 .. layer4).
+ Args:
+ resnet: model exposing conv1, bn1, relu, maxpool and layer1 .. layer4.
+ Returns:
+ nn.Module: container with layer1 .. layer4 attributes.
+ """
+ pretrained = nn.Module()
+ # The stem (conv1, bn1, relu, maxpool) is folded into the first stage.
+ pretrained.layer1 = nn.Sequential(
+ resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
+ )
+ pretrained.layer2 = resnet.layer2
+ pretrained.layer3 = resnet.layer3
+ pretrained.layer4 = resnet.layer4
+ return pretrained
+def _make_pretrained_resnext101_wsl(use_pretrained):
+ """Fetch resnext101_32x8d_wsl through torch.hub and wrap it as a four-stage backbone.
+ Args:
+ use_pretrained: accepted for signature parity; not used by this function.
+ Returns:
+ nn.Module: result of ``_make_resnet_backbone``.
+ """
+ resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
+ return _make_resnet_backbone(resnet)
+class Interpolate(nn.Module):
+ """Interpolation module.
+ Wraps ``nn.functional.interpolate`` with a fixed scale factor, mode and
+ align_corners setting so it can be placed inside ``nn.Sequential``.
+ """
+ def __init__(self, scale_factor, mode, align_corners=False):
+ """Init.
+ Args:
+ scale_factor (float): scaling
+ mode (str): interpolation mode
+ align_corners (bool, optional): forwarded to interpolate. Defaults to False.
+ """
+ super(Interpolate, self).__init__()
+ self.interp = nn.functional.interpolate
+ self.scale_factor = scale_factor
+ self.mode = mode
+ self.align_corners = align_corners
+ def forward(self, x):
+ """Forward pass.
+ Args:
+ x (tensor): input
+ Returns:
+ tensor: interpolated data
+ """
+ x = self.interp(
+ x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
+ )
+ return x
+class ResidualConvUnit(nn.Module):
+ """Residual convolution module.
+ Applies ReLU -> 3x3 conv -> ReLU -> 3x3 conv and adds the input back.
+ """
+ def __init__(self, features):
+ """Init.
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+ self.conv1 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
+ )
+ self.conv2 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
+ )
+ self.relu = nn.ReLU(inplace=True)
+ def forward(self, x):
+ """Forward pass.
+ Args:
+ x (tensor): input
+ Returns:
+ tensor: output
+ """
+ # self.relu is in-place, so this also rectifies x itself: the caller's
+ # tensor is modified and the skip term added below is relu(x).
+ out = self.relu(x)
+ out = self.conv1(out)
+ out = self.relu(out)
+ out = self.conv2(out)
+ return out + x
+class FeatureFusionBlock(nn.Module):
+ """Feature fusion block.
+ Optionally merges a second feature map into the first, refines the
+ result and upsamples it by a factor of two.
+ """
+ def __init__(self, features):
+ """Init.
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock, self).__init__()
+ self.resConfUnit1 = ResidualConvUnit(features)
+ self.resConfUnit2 = ResidualConvUnit(features)
+ def forward(self, *xs):
+ """Forward pass.
+ Args:
+ *xs (tensor): one or two feature maps; xs[0] is the running path,
+ the optional xs[1] is passed through resConfUnit1 and added to it.
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+ if len(xs) == 2:
+ # In-place add: this also modifies the tensor passed as xs[0].
+ output += self.resConfUnit1(xs[1])
+ output = self.resConfUnit2(output)
+ output = nn.functional.interpolate(
+ output, scale_factor=2, mode="bilinear", align_corners=True
+ )
+ return output
+class ResidualConvUnit_custom(nn.Module):
+ """Residual convolution module.
+ Variant with a caller-supplied activation, optional batch norm and a
+ ``FloatFunctional`` skip addition.
+ """
+ def __init__(self, features, activation, bn):
+ """Init.
+ Args:
+ features (int): number of features
+ activation (callable): activation applied before each conv
+ bn (bool): if True, a BatchNorm2d follows each conv
+ """
+ super().__init__()
+ self.bn = bn
+ self.groups=1
+ self.conv1 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+ )
+ self.conv2 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+ )
+ if self.bn==True:
+ self.bn1 = nn.BatchNorm2d(features)
+ self.bn2 = nn.BatchNorm2d(features)
+ self.activation = activation
+ self.skip_add = nn.quantized.FloatFunctional()
+ def forward(self, x):
+ """Forward pass.
+ Args:
+ x (tensor): input
+ Returns:
+ tensor: output
+ """
+ out = self.activation(x)
+ out = self.conv1(out)
+ if self.bn==True:
+ out = self.bn1(out)
+ out = self.activation(out)
+ out = self.conv2(out)
+ if self.bn==True:
+ out = self.bn2(out)
+ # self.groups is fixed to 1 in __init__, so this branch is not taken;
+ # conv_merge is not defined anywhere in this class.
+ if self.groups > 1:
+ out = self.conv_merge(out)
+ return self.skip_add.add(out, x)
+ # return out + x
+class FeatureFusionBlock_custom(nn.Module):
+ """Feature fusion block.
+ Merges an optional second feature map, refines, upsamples by two and
+ applies a 1x1 output conv.
+ """
+ def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
+ """Init.
+ Args:
+ features (int): number of features
+ activation (callable): activation handed to the residual units
+ deconv (bool, optional): stored on the instance; not used in forward.
+ bn (bool, optional): enable batch norm in the residual units.
+ expand (bool, optional): if True the output conv halves the channel count.
+ align_corners (bool, optional): forwarded to the bilinear upsampling.
+ """
+ super(FeatureFusionBlock_custom, self).__init__()
+ self.deconv = deconv
+ self.align_corners = align_corners
+ self.groups=1
+ self.expand = expand
+ out_features = features
+ if self.expand==True:
+ out_features = features//2
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
+ self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
+ self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
+ self.skip_add = nn.quantized.FloatFunctional()
+ def forward(self, *xs):
+ """Forward pass.
+ Args:
+ *xs (tensor): one or two feature maps; the optional xs[1] is
+ refined by resConfUnit1 and added to xs[0].
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+ if len(xs) == 2:
+ res = self.resConfUnit1(xs[1])
+ output = self.skip_add.add(output, res)
+ # output += res
+ output = self.resConfUnit2(output)
+ output = nn.functional.interpolate(
+ output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
+ )
+ output = self.out_conv(output)
+ return output
diff --git a/ControlNet/ldm/modules/midas/midas/dpt_depth.py b/ControlNet/ldm/modules/midas/midas/dpt_depth.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e9aab5d2767dffea39da5b3f30e2798688216f1
--- /dev/null
+++ b/ControlNet/ldm/modules/midas/midas/dpt_depth.py
@@ -0,0 +1,109 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .base_model import BaseModel
+from .blocks import (
+ FeatureFusionBlock,
+ FeatureFusionBlock_custom,
+ Interpolate,
+ _make_encoder,
+ forward_vit,
+def _make_fusion_block(features, use_bn):
+ """Create the fusion block configuration used by the DPT decoder.
+ Args:
+ features (int): number of features
+ use_bn (bool): enable batch norm inside the block
+ Returns:
+ FeatureFusionBlock_custom: block with a non-in-place ReLU, no channel
+ expansion and align_corners=True.
+ """
+ return FeatureFusionBlock_custom(
+ features,
+ nn.ReLU(False),
+ deconv=False,
+ bn=use_bn,
+ expand=False,
+ align_corners=True,
+ )
+class DPT(BaseModel):
+    """ViT-based encoder with a four-stage fusion decoder and a pluggable head.
+
+    The encoder and the per-stage projection convs come from
+    ``_make_encoder``; four ``FeatureFusionBlock_custom`` blocks merge the
+    stages from coarsest to finest before ``head`` is applied.
+    """
+    def __init__(
+        self,
+        head,
+        features=256,
+        backbone="vitb_rn50_384",
+        readout="project",
+        channels_last=False,
+        use_bn=False,
+    ):
+        """Init.
+
+        Args:
+            head (nn.Module): module applied to the finest decoder output.
+            features (int, optional): decoder width. Defaults to 256.
+            backbone (str, optional): "vitb_rn50_384", "vitb16_384" or
+                "vitl16_384". Defaults to "vitb_rn50_384".
+            readout (str, optional): readout-token handling passed to the
+                encoder as ``use_readout``. Defaults to "project".
+            channels_last (bool, optional): convert the input to
+                channels-last memory format in forward. Defaults to False.
+            use_bn (bool, optional): batch norm in the fusion blocks.
+
+        Raises:
+            KeyError: if ``backbone`` has no entry in the hook table.
+        """
+        super(DPT, self).__init__()
+        self.channels_last = channels_last
+        # Indices of the transformer blocks whose outputs feed the decoder.
+        hooks = {
+            "vitb_rn50_384": [0, 1, 8, 11],
+            "vitb16_384": [2, 5, 8, 11],
+            "vitl16_384": [5, 11, 17, 23],
+        }
+        # Instantiate backbone and reassemble blocks
+        self.pretrained, self.scratch = _make_encoder(
+            backbone,
+            features,
+            False,  # Set to true if you want to train from scratch, uses ImageNet weights
+            groups=1,
+            expand=False,
+            exportable=False,
+            hooks=hooks[backbone],
+            use_readout=readout,
+        )
+        self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
+        self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
+        self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
+        self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
+        self.scratch.output_conv = head
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x (tensor): input batch (4-D; unpacked as b, c, h, w by forward_vit).
+
+        Returns:
+            tensor: output of ``head`` on the finest decoder feature map.
+        """
+        if self.channels_last:
+            # Tensor.contiguous returns the converted tensor rather than
+            # converting in place, so the result has to be rebound.
+            x = x.contiguous(memory_format=torch.channels_last)
+        layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
+        layer_1_rn = self.scratch.layer1_rn(layer_1)
+        layer_2_rn = self.scratch.layer2_rn(layer_2)
+        layer_3_rn = self.scratch.layer3_rn(layer_3)
+        layer_4_rn = self.scratch.layer4_rn(layer_4)
+        # Decode from the coarsest stage to the finest, fusing one stage per step.
+        path_4 = self.scratch.refinenet4(layer_4_rn)
+        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+        out = self.scratch.output_conv(path_1)
+        return out
+class DPTDepthModel(DPT):
+ """DPT with a single-channel depth head."""
+ def __init__(self, path=None, non_negative=True, **kwargs):
+ """Init.
+ Args:
+ path (str, optional): checkpoint to load after construction. Defaults to None.
+ non_negative (bool, optional): end the head with a ReLU instead of an
+ identity. Defaults to True.
+ **kwargs: forwarded to ``DPT.__init__``; "features" (default 256)
+ also sets the head width.
+ """
+ features = kwargs["features"] if "features" in kwargs else 256
+ # Head: 3x3 conv -> 2x bilinear upsampling -> 3x3 conv -> ReLU -> 1x1 conv to one channel.
+ head = nn.Sequential(
+ nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(True) if non_negative else nn.Identity(),
+ nn.Identity(),
+ )
+ super().__init__(head, **kwargs)
+ if path is not None:
+ self.load(path)
+ def forward(self, x):
+ """Run the DPT forward pass and drop the channel axis (dim 1) of the result."""
+ return super().forward(x).squeeze(dim=1)
diff --git a/ControlNet/ldm/modules/midas/midas/midas_net.py b/ControlNet/ldm/modules/midas/midas/midas_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a954977800b0a0f48807e80fa63041910e33c1f
--- /dev/null
+++ b/ControlNet/ldm/modules/midas/midas/midas_net.py
@@ -0,0 +1,76 @@
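+"""MidasNet: Network for monocular depth estimation trained by mixing several datasets.
+This file contains code that is adapted from
+https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
+"""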
+import torch
+import torch.nn as nn
+from .base_model import BaseModel
+from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
+class MidasNet(BaseModel):
+ """Network for monocular depth estimation.
+ Uses the "resnext101_wsl" encoder with four FeatureFusionBlock decoder stages.
+ """
+ def __init__(self, path=None, features=256, non_negative=True):
+ """Init.
+ Args:
+ path (str, optional): Path to saved model. Defaults to None.
+ features (int, optional): Number of features. Defaults to 256.
+ non_negative (bool, optional): End the output head with a ReLU
+ instead of an identity. Defaults to True.
+ """
+ print("Loading weights: ", path)
+ super(MidasNet, self).__init__()
+ # use_pretrained is True exactly when a checkpoint path is given.
+ use_pretrained = False if path is None else True
+ self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
+ self.scratch.refinenet4 = FeatureFusionBlock(features)
+ self.scratch.refinenet3 = FeatureFusionBlock(features)
+ self.scratch.refinenet2 = FeatureFusionBlock(features)
+ self.scratch.refinenet1 = FeatureFusionBlock(features)
+ # Head: 3x3 conv -> 2x bilinear upsampling -> 3x3 conv -> ReLU -> 1x1 conv to one channel.
+ self.scratch.output_conv = nn.Sequential(
+ nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
+ Interpolate(scale_factor=2, mode="bilinear"),
+ nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(True) if non_negative else nn.Identity(),
+ )
+ if path:
+ self.load(path)
+ def forward(self, x):
+ """Forward pass.
+ Args:
+ x (tensor): input data (image)
+ Returns:
+ tensor: depth
+ """
+ # Encoder: each stage consumes the previous stage's output.
+ layer_1 = self.pretrained.layer1(x)
+ layer_2 = self.pretrained.layer2(layer_1)
+ layer_3 = self.pretrained.layer3(layer_2)
+ layer_4 = self.pretrained.layer4(layer_3)
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+ # Decoder: fuse from the coarsest stage to the finest.
+ path_4 = self.scratch.refinenet4(layer_4_rn)
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+ out = self.scratch.output_conv(path_1)
+ # Drop the single channel axis.
+ return torch.squeeze(out, dim=1)
diff --git a/ControlNet/ldm/modules/midas/midas/midas_net_custom.py b/ControlNet/ldm/modules/midas/midas/midas_net_custom.py
new file mode 100644
index 0000000000000000000000000000000000000000..50e4acb5e53d5fabefe3dde16ab49c33c2b7797c
--- /dev/null
+++ b/ControlNet/ldm/modules/midas/midas/midas_net_custom.py
@@ -0,0 +1,128 @@
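+"""MidasNet: Network for monocular depth estimation trained by mixing several datasets.
+This file contains code that is adapted from
+https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
+"""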
+import torch
+import torch.nn as nn
+from .base_model import BaseModel
+from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
+class MidasNet_small(BaseModel):
+    """Network for monocular depth estimation.
+
+    Lightweight variant: encoder from ``_make_encoder`` (efficientnet_lite3
+    by default) with four ``FeatureFusionBlock_custom`` decoder stages.
+    """
+    def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
+        blocks={'expand': True}):
+        """Init.
+
+        Args:
+            path (str, optional): Path to saved model. Defaults to None.
+            features (int, optional): Number of features. Defaults to 64.
+            backbone (str, optional): Backbone network for encoder.
+                Defaults to "efficientnet_lite3".
+            non_negative (bool, optional): End the head with a ReLU instead
+                of an identity. Defaults to True.
+            exportable (bool, optional): Forwarded to the encoder factory.
+            channels_last (bool, optional): Convert the input to
+                channels-last memory format in forward. Defaults to False.
+            align_corners (bool, optional): Forwarded to the fusion blocks.
+            blocks (dict, optional): Options; ``{'expand': True}`` doubles
+                the decoder width at each coarser stage. The dict is stored
+                but never mutated.
+        """
+        print("Loading weights: ", path)
+        super(MidasNet_small, self).__init__()
+        # Pretrained backbone weights are requested only when no checkpoint is given.
+        use_pretrained = False if path else True
+        self.channels_last = channels_last
+        self.blocks = blocks
+        self.backbone = backbone
+        self.groups = 1
+        features1=features
+        features2=features
+        features3=features
+        features4=features
+        self.expand = False
+        if "expand" in self.blocks and self.blocks['expand'] == True:
+            self.expand = True
+            features1=features
+            features2=features*2
+            features3=features*4
+            features4=features*8
+        self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
+        self.scratch.activation = nn.ReLU(False)
+        self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+        self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+        self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+        # The finest block keeps its channel count (no expand).
+        self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
+        self.scratch.output_conv = nn.Sequential(
+            nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
+            Interpolate(scale_factor=2, mode="bilinear"),
+            nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
+            self.scratch.activation,
+            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+            nn.ReLU(True) if non_negative else nn.Identity(),
+            nn.Identity(),
+        )
+        if path:
+            self.load(path)
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x (tensor): input data (image)
+
+        Returns:
+            tensor: depth
+        """
+        if self.channels_last==True:
+            print("self.channels_last = ", self.channels_last)
+            # Tensor.contiguous returns the converted tensor rather than
+            # converting in place, so the result has to be rebound.
+            x = x.contiguous(memory_format=torch.channels_last)
+        layer_1 = self.pretrained.layer1(x)
+        layer_2 = self.pretrained.layer2(layer_1)
+        layer_3 = self.pretrained.layer3(layer_2)
+        layer_4 = self.pretrained.layer4(layer_3)
+        layer_1_rn = self.scratch.layer1_rn(layer_1)
+        layer_2_rn = self.scratch.layer2_rn(layer_2)
+        layer_3_rn = self.scratch.layer3_rn(layer_3)
+        layer_4_rn = self.scratch.layer4_rn(layer_4)
+        # Decoder: fuse from the coarsest stage to the finest.
+        path_4 = self.scratch.refinenet4(layer_4_rn)
+        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+        out = self.scratch.output_conv(path_1)
+        return torch.squeeze(out, dim=1)
+def fuse_model(m):
+ """Fuse Conv2d + BatchNorm2d (+ ReLU) runs of ``m`` in place for quantization.
+ Walks ``m.named_modules()`` while remembering the two previously visited
+ modules, and calls ``torch.quantization.fuse_modules`` whenever they form
+ a Conv2d followed by a BatchNorm2d.
+ Args:
+ m (nn.Module): model to fuse; modified in place.
+ """
+ # The initial values are Identity instances, which never compare equal to a module type.
+ prev_previous_type = nn.Identity()
+ prev_previous_name = ''
+ previous_type = nn.Identity()
+ previous_name = ''
+ for name, module in m.named_modules():
+ # Conv + BN followed by the current ReLU: fuse all three.
+ if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
+ # print("FUSED ", prev_previous_name, previous_name, name)
+ torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
+ # Conv + BN not followed by a ReLU: fuse the pair.
+ elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
+ # print("FUSED ", prev_previous_name, previous_name)
+ torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
+ # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
+ # print("FUSED ", previous_name, name)
+ # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
+ # Slide the two-module window forward.
+ prev_previous_type = previous_type
+ prev_previous_name = previous_name
+ previous_type = type(module)
+ previous_name = name
\ No newline at end of file
diff --git a/ControlNet/ldm/modules/midas/midas/transforms.py b/ControlNet/ldm/modules/midas/midas/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..350cbc11662633ad7f8968eb10be2e7de6e384e9
--- /dev/null
+++ b/ControlNet/ldm/modules/midas/midas/transforms.py
@@ -0,0 +1,234 @@
+import numpy as np
+import cv2
+import math
+def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
+ """Resize the sample to ensure the given size. Keeps aspect ratio.
+ The sample dict is modified in place: "image", "disparity" and "mask"
+ are all resized when the disparity map is smaller than ``size``.
+ Args:
+ sample (dict): sample
+ size (tuple): image size
+ image_interpolation_method (int, optional): cv2 interpolation flag for
+ the image. Defaults to cv2.INTER_AREA.
+ Returns:
+ tuple: new size
+ NOTE(review): when the sample is already large enough, the sample dict
+ itself is returned instead of a size tuple.
+ """
+ shape = list(sample["disparity"].shape)
+ if shape[0] >= size[0] and shape[1] >= size[1]:
+ return sample
+ scale = [0, 0]
+ scale[0] = size[0] / shape[0]
+ scale[1] = size[1] / shape[1]
+ # Use the larger factor so both dimensions reach the requested size.
+ scale = max(scale)
+ shape[0] = math.ceil(scale * shape[0])
+ shape[1] = math.ceil(scale * shape[1])
+ # resize
+ # shape is reversed for cv2.resize, whose size argument is (width, height).
+ sample["image"] = cv2.resize(
+ sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
+ )
+ sample["disparity"] = cv2.resize(
+ sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
+ )
+ # The mask is converted to float32 for resizing and back to bool afterwards.
+ sample["mask"] = cv2.resize(
+ sample["mask"].astype(np.float32),
+ tuple(shape[::-1]),
+ interpolation=cv2.INTER_NEAREST,
+ )
+ sample["mask"] = sample["mask"].astype(bool)
+ return tuple(shape)
+class Resize(object):
+ """Resize sample to given size (width, height).
+ Instances are callables that take and return a sample dict.
+ """
+ def __init__(
+ self,
+ width,
+ height,
+ resize_target=True,
+ keep_aspect_ratio=False,
+ ensure_multiple_of=1,
+ resize_method="lower_bound",
+ image_interpolation_method=cv2.INTER_AREA,
+ ):
+ """Init.
+ Args:
+ width (int): desired output width
+ height (int): desired output height
+ resize_target (bool, optional):
+ True: Resize the full sample (image, mask, target).
+ False: Resize image only.
+ Defaults to True.
+ keep_aspect_ratio (bool, optional):
+ True: Keep the aspect ratio of the input sample.
+ Output sample might not have the given width and height, and
+ resize behaviour depends on the parameter 'resize_method'.
+ Defaults to False.
+ ensure_multiple_of (int, optional):
+ Output width and height is constrained to be multiple of this parameter.
+ Defaults to 1.
+ resize_method (str, optional):
+ "lower_bound": Output will be at least as large as the given size.
+ "upper_bound": Output will be at most as large as the given size. (Output size might be smaller than given size.)
+ "minimal": Scale as little as possible. (Output size might be smaller than given size.)
+ Defaults to "lower_bound".
+ image_interpolation_method (int, optional):
+ cv2 interpolation flag used for the image. Defaults to cv2.INTER_AREA.
+ """
+ self.__width = width
+ self.__height = height
+ self.__resize_target = resize_target
+ self.__keep_aspect_ratio = keep_aspect_ratio
+ self.__multiple_of = ensure_multiple_of
+ self.__resize_method = resize_method
+ self.__image_interpolation_method = image_interpolation_method
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
+ """Snap x to a multiple of ``ensure_multiple_of``.
+ Rounds to the nearest multiple; rounds down instead if that exceeds
+ max_val, and rounds up instead if the result is below min_val.
+ Args:
+ x (float): value to constrain
+ min_val (int, optional): lower limit. Defaults to 0.
+ max_val (int, optional): upper limit, or None for no limit.
+ Returns:
+ int: constrained value
+ """
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
+ if max_val is not None and y > max_val:
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
+ if y < min_val:
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
+ return y
+ def get_size(self, width, height):
+ """Compute the output size for an input of the given size.
+ Args:
+ width (int): input width
+ height (int): input height
+ Returns:
+ tuple: (new_width, new_height)
+ Raises:
+ ValueError: if the configured resize_method is not implemented.
+ """
+ # determine new height and width
+ scale_height = self.__height / height
+ scale_width = self.__width / width
+ # With keep_aspect_ratio one of the two scales is chosen and used for both axes.
+ if self.__keep_aspect_ratio:
+ if self.__resize_method == "lower_bound":
+ # scale such that output size is lower bound
+ if scale_width > scale_height:
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ elif self.__resize_method == "upper_bound":
+ # scale such that output size is upper bound
+ if scale_width < scale_height:
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ elif self.__resize_method == "minimal":
+ # scale as little as possible
+ if abs(1 - scale_width) < abs(1 - scale_height):
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ else:
+ raise ValueError(
+ f"resize_method {self.__resize_method} not implemented"
+ )
+ # Snap the scaled size to the multiple, bounded according to the resize method.
+ if self.__resize_method == "lower_bound":
+ new_height = self.constrain_to_multiple_of(
+ scale_height * height, min_val=self.__height
+ )
+ new_width = self.constrain_to_multiple_of(
+ scale_width * width, min_val=self.__width
+ )
+ elif self.__resize_method == "upper_bound":
+ new_height = self.constrain_to_multiple_of(
+ scale_height * height, max_val=self.__height
+ )
+ new_width = self.constrain_to_multiple_of(
+ scale_width * width, max_val=self.__width
+ )
+ elif self.__resize_method == "minimal":
+ new_height = self.constrain_to_multiple_of(scale_height * height)
+ new_width = self.constrain_to_multiple_of(scale_width * width)
+ else:
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
+ return (new_width, new_height)
+ def __call__(self, sample):
+ """Resize ``sample["image"]`` and, if resize_target is set, its targets.
+ Args:
+ sample (dict): sample with an "image" array; "disparity", "depth"
+ and "mask" are handled when targets are resized.
+ Returns:
+ dict: the same sample dict, modified in place.
+ """
+ width, height = self.get_size(
+ sample["image"].shape[1], sample["image"].shape[0]
+ )
+ # resize sample
+ sample["image"] = cv2.resize(
+ sample["image"],
+ (width, height),
+ interpolation=self.__image_interpolation_method,
+ )
+ if self.__resize_target:
+ if "disparity" in sample:
+ sample["disparity"] = cv2.resize(
+ sample["disparity"],
+ (width, height),
+ interpolation=cv2.INTER_NEAREST,
+ )
+ if "depth" in sample:
+ sample["depth"] = cv2.resize(
+ sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
+ )
+ # NOTE(review): "mask" is read without a membership check; presumably this
+ # sits inside the resize_target branch (indentation is flattened here) -- confirm.
+ sample["mask"] = cv2.resize(
+ sample["mask"].astype(np.float32),
+ (width, height),
+ interpolation=cv2.INTER_NEAREST,
+ )
+ sample["mask"] = sample["mask"].astype(bool)
+ return sample
+class NormalizeImage(object):
+ """Normalize image by given mean and std.
+ """
+ def __init__(self, mean, std):
+ """Store the mean and std later applied to ``sample["image"]``."""
+ self.__mean = mean
+ self.__std = std
+ def __call__(self, sample):
+ """Replace ``sample["image"]`` with ``(image - mean) / std`` and return the sample."""
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
+ return sample
+class PrepareForNet(object):
+ """Prepare sample for usage as network input.
+ Moves the last image axis to the front and converts every present
+ array to contiguous float32.
+ """
+ def __init__(self):
+ pass
+ def __call__(self, sample):
+ """Convert the sample in place and return it.
+ Args:
+ sample (dict): sample with a 3-D "image"; "mask", "disparity" and
+ "depth" are converted when present.
+ Returns:
+ dict: the same sample dict.
+ """
+ # Axis order (0, 1, 2) -> (2, 0, 1).
+ image = np.transpose(sample["image"], (2, 0, 1))
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
+ if "mask" in sample:
+ sample["mask"] = sample["mask"].astype(np.float32)
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
+ if "disparity" in sample:
+ disparity = sample["disparity"].astype(np.float32)
+ sample["disparity"] = np.ascontiguousarray(disparity)
+ if "depth" in sample:
+ depth = sample["depth"].astype(np.float32)
+ sample["depth"] = np.ascontiguousarray(depth)
+ return sample
diff --git a/ControlNet/ldm/modules/midas/midas/vit.py b/ControlNet/ldm/modules/midas/midas/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea46b1be88b261b0dec04f3da0256f5f66f88a74
--- /dev/null
+++ b/ControlNet/ldm/modules/midas/midas/vit.py
@@ -0,0 +1,491 @@
+import torch
+import torch.nn as nn
+import timm
+import types
+import math
+import torch.nn.functional as F
+class Slice(nn.Module):
+ """Readout op that drops the first ``start_index`` entries along dim 1."""
+ def __init__(self, start_index=1):
+ super(Slice, self).__init__()
+ self.start_index = start_index
+ def forward(self, x):
+ """Return ``x[:, start_index:]``."""
+ return x[:, self.start_index :]
+class AddReadout(nn.Module):
+ """Readout op that adds the leading token(s) to every remaining token."""
+ def __init__(self, start_index=1):
+ super(AddReadout, self).__init__()
+ self.start_index = start_index
+ def forward(self, x):
+ """Return ``x[:, start_index:]`` plus the readout, broadcast over dim 1."""
+ # With two leading tokens the readout is their mean; otherwise it is token 0.
+ if self.start_index == 2:
+ readout = (x[:, 0] + x[:, 1]) / 2
+ else:
+ readout = x[:, 0]
+ return x[:, self.start_index :] + readout.unsqueeze(1)
+class ProjectReadout(nn.Module):
+ """Readout op that concatenates token 0 to every remaining token and projects back."""
+ def __init__(self, in_features, start_index=1):
+ super(ProjectReadout, self).__init__()
+ self.start_index = start_index
+ # Linear maps the concatenated 2 * in_features back to in_features, followed by GELU.
+ self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
+ def forward(self, x):
+ """Return the projected concatenation of ``x[:, start_index:]`` and token 0."""
+ # Token 0 is expanded to the shape of the remaining tokens before concatenating on the last dim.
+ readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
+ features = torch.cat((x[:, self.start_index :], readout), -1)
+ return self.project(features)
+class Transpose(nn.Module):
+ """Module wrapper around ``Tensor.transpose`` for use inside ``nn.Sequential``."""
+ def __init__(self, dim0, dim1):
+ super(Transpose, self).__init__()
+ self.dim0 = dim0
+ self.dim1 = dim1
+ def forward(self, x):
+ """Return x with dim0 and dim1 swapped."""
+ x = x.transpose(self.dim0, self.dim1)
+ return x
+def forward_vit(pretrained, x):
+ """Run the ViT backbone and return its four post-processed feature maps.
+ Args:
+ pretrained (nn.Module): backbone container exposing ``model``,
+ ``activations`` and ``act_postprocess1`` .. ``act_postprocess4``.
+ x (tensor): 4-D input batch, unpacked as (b, c, h, w).
+ Returns:
+ tuple: (layer_1, layer_2, layer_3, layer_4)
+ """
+ b, c, h, w = x.shape
+ # The return value is unused; the activations read below are taken from pretrained.activations.
+ glob = pretrained.model.forward_flex(x)
+ layer_1 = pretrained.activations["1"]
+ layer_2 = pretrained.activations["2"]
+ layer_3 = pretrained.activations["3"]
+ layer_4 = pretrained.activations["4"]
+ # Apply only the first two post-processing steps here.
+ layer_1 = pretrained.act_postprocess1[0:2](layer_1)
+ layer_2 = pretrained.act_postprocess2[0:2](layer_2)
+ layer_3 = pretrained.act_postprocess3[0:2](layer_3)
+ layer_4 = pretrained.act_postprocess4[0:2](layer_4)
+ # Step 2 of each act_postprocess is skipped and replaced by this unflatten,
+ # whose grid is derived from the actual input size and the patch size.
+ unflatten = nn.Sequential(
+ nn.Unflatten(
+ 2,
+ torch.Size(
+ [
+ h // pretrained.model.patch_size[1],
+ w // pretrained.model.patch_size[0],
+ ]
+ ),
+ )
+ )
+ # Only 3-D tensors are unflattened.
+ if layer_1.ndim == 3:
+ layer_1 = unflatten(layer_1)
+ if layer_2.ndim == 3:
+ layer_2 = unflatten(layer_2)
+ if layer_3.ndim == 3:
+ layer_3 = unflatten(layer_3)
+ if layer_4.ndim == 3:
+ layer_4 = unflatten(layer_4)
+ # Apply the remaining post-processing steps (index 3 onwards).
+ layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
+ layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
+ layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
+ layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
+ return layer_1, layer_2, layer_3, layer_4
+def _resize_pos_embed(self, posemb, gs_h, gs_w):
+ """Bilinearly resize the grid part of a position embedding to (gs_h, gs_w).
+ Written as a method: ``self`` must provide ``start_index``.
+ Args:
+ posemb (tensor): position embedding with a leading dim of size 1.
+ gs_h (int): target grid height
+ gs_w (int): target grid width
+ Returns:
+ tensor: leading token embeddings followed by the resized grid embeddings.
+ """
+ # The first start_index entries are kept as they are; the rest form the grid.
+ posemb_tok, posemb_grid = (
+ posemb[:, : self.start_index],
+ posemb[0, self.start_index :],
+ )
+ # The existing grid is taken to be square.
+ gs_old = int(math.sqrt(len(posemb_grid)))
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
+ posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
+ return posemb
+def forward_flex(self, x):
+ """ViT forward pass with the position embedding resized to the input size.
+ Written as a method: ``self`` must provide pos_embed, patch_size,
+ patch_embed, cls_token, pos_drop, blocks, norm and _resize_pos_embed.
+ Args:
+ x (tensor): 4-D input batch, unpacked as (b, c, h, w).
+ Returns:
+ tensor: normalized token sequence after all blocks.
+ """
+ b, c, h, w = x.shape
+ pos_embed = self._resize_pos_embed(
+ self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
+ )
+ B = x.shape[0]
+ # If the patch embedding has a backbone, run it before the projection.
+ if hasattr(self.patch_embed, "backbone"):
+ x = self.patch_embed.backbone(x)
+ if isinstance(x, (list, tuple)):
+ x = x[-1] # last feature if backbone outputs list/tuple of features
+ x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
+ # Prepend the class token, plus the distillation token when the model has one.
+ if getattr(self, "dist_token", None) is not None:
+ cls_tokens = self.cls_token.expand(
+ B, -1, -1
+ ) # stole cls_tokens impl from Phil Wang, thanks
+ dist_token = self.dist_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, dist_token, x), dim=1)
+ else:
+ cls_tokens = self.cls_token.expand(
+ B, -1, -1
+ ) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_tokens, x), dim=1)
+ x = x + pos_embed
+ x = self.pos_drop(x)
+ for blk in self.blocks:
+ x = blk(x)
+ x = self.norm(x)
+ return x
+# Module-level store written by the hooks from get_activation. Every backbone
+# built in this module points at this same dict, so a later forward pass
+# overwrites the entries of an earlier one.
+activations = {}
+def get_activation(name):
+ """Return a forward hook that stores a module's output under ``activations[name]``."""
+ def hook(model, input, output):
+ activations[name] = output
+ return hook
+def get_readout_oper(vit_features, features, use_readout, start_index=1):
+ """Build one readout operation per entry of ``features``.
+ Args:
+ vit_features (int): token width, used only by the "project" mode.
+ features (list): only its length is used.
+ use_readout (str): "ignore", "add" or "project".
+ start_index (int, optional): number of leading tokens. Defaults to 1.
+ Returns:
+ list: readout modules. For "ignore" and "add" the list repeats one
+ shared instance; "project" creates a separate module per entry.
+ Raises:
+ AssertionError: for any other ``use_readout`` value.
+ """
+ if use_readout == "ignore":
+ readout_oper = [Slice(start_index)] * len(features)
+ elif use_readout == "add":
+ readout_oper = [AddReadout(start_index)] * len(features)
+ elif use_readout == "project":
+ readout_oper = [
+ ProjectReadout(vit_features, start_index) for out_feat in features
+ ]
+ else:
+ assert (
+ False
+ ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
+ return readout_oper
+def _make_vit_b16_backbone(
+    model,
+    features=[96, 192, 384, 768],
+    size=[384, 384],
+    hooks=[2, 5, 8, 11],
+    vit_features=768,
+    use_readout="ignore",
+    start_index=1,
+):
+    """Wrap a patch-16 ViT so four intermediate blocks can feed a decoder.
+
+    Registers forward hooks on four transformer blocks, attaches one
+    post-processing ``nn.Sequential`` per tapped block and injects
+    ``forward_flex`` / ``_resize_pos_embed`` into the model instance.
+
+    Args:
+        model (nn.Module): ViT exposing an indexable ``blocks`` attribute.
+        features (list, optional): output channels of the four stages.
+        size (list, optional): nominal input size used for the fixed
+            unflatten step. Defaults to [384, 384].
+        hooks (list, optional): indices of the four blocks to tap.
+        vit_features (int, optional): token width of the ViT. Defaults to 768.
+        use_readout (str, optional): "ignore", "add" or "project".
+        start_index (int, optional): number of leading non-patch tokens.
+
+    Returns:
+        nn.Module: container with ``model``, ``activations`` and
+        ``act_postprocess1`` .. ``act_postprocess4``.
+    """
+    pretrained = nn.Module()
+    pretrained.model = model
+    pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
+    pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
+    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
+    pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
+    # This is the module-level dict filled by the hooks, not a per-model copy.
+    pretrained.activations = activations
+    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
+    # 32, 48, 136, 384
+    # The position of each step matters: forward_vit applies steps [0:2],
+    # substitutes its own unflatten for step 2, then applies steps [3:].
+    # Stage 1: 1x1 conv, then 4x upsampling by transposed conv.
+    pretrained.act_postprocess1 = nn.Sequential(
+        readout_oper[0],
+        Transpose(1, 2),
+        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+        nn.Conv2d(
+            in_channels=vit_features,
+            out_channels=features[0],
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        ),
+        nn.ConvTranspose2d(
+            in_channels=features[0],
+            out_channels=features[0],
+            kernel_size=4,
+            stride=4,
+            padding=0,
+            bias=True,
+            dilation=1,
+            groups=1,
+        ),
+    )
+    # Stage 2: 1x1 conv, then 2x upsampling by transposed conv.
+    pretrained.act_postprocess2 = nn.Sequential(
+        readout_oper[1],
+        Transpose(1, 2),
+        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+        nn.Conv2d(
+            in_channels=vit_features,
+            out_channels=features[1],
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        ),
+        nn.ConvTranspose2d(
+            in_channels=features[1],
+            out_channels=features[1],
+            kernel_size=2,
+            stride=2,
+            padding=0,
+            bias=True,
+            dilation=1,
+            groups=1,
+        ),
+    )
+    # Stage 3: 1x1 conv only, resolution unchanged.
+    pretrained.act_postprocess3 = nn.Sequential(
+        readout_oper[2],
+        Transpose(1, 2),
+        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+        nn.Conv2d(
+            in_channels=vit_features,
+            out_channels=features[2],
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        ),
+    )
+    # Stage 4: 1x1 conv, then 2x downsampling by a stride-2 3x3 conv.
+    pretrained.act_postprocess4 = nn.Sequential(
+        readout_oper[3],
+        Transpose(1, 2),
+        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+        nn.Conv2d(
+            in_channels=vit_features,
+            out_channels=features[3],
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        ),
+        nn.Conv2d(
+            in_channels=features[3],
+            out_channels=features[3],
+            kernel_size=3,
+            stride=2,
+            padding=1,
+        ),
+    )
+    pretrained.model.start_index = start_index
+    pretrained.model.patch_size = [16, 16]
+    # We inject this function into the VisionTransformer instances so that
+    # we can use it with interpolated position embeddings without modifying the library source.
+    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
+    pretrained.model._resize_pos_embed = types.MethodType(
+        _resize_pos_embed, pretrained.model
+    )
+    return pretrained
+def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
+ hooks = [5, 11, 17, 23] if hooks == None else hooks
+ return _make_vit_b16_backbone(
+ model,
+ features=[256, 512, 1024, 1024],
+ hooks=hooks,
+ vit_features=1024,
+ use_readout=use_readout,
+ )
+def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
+ return _make_vit_b16_backbone(
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
+ )
+def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
+ return _make_vit_b16_backbone(
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
+ )
+def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model(
+ "vit_deit_base_distilled_patch16_384", pretrained=pretrained
+ )
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
+ return _make_vit_b16_backbone(
+ model,
+ features=[96, 192, 384, 768],
+ hooks=hooks,
+ use_readout=use_readout,
+ start_index=2,
+ )
+def _make_vit_b_rn50_backbone(
+ model,
+ features=[256, 512, 768, 768],
+ size=[384, 384],
+ hooks=[0, 1, 8, 11],
+ vit_features=768,
+ use_vit_only=False,
+ use_readout="ignore",
+ start_index=1,
+ pretrained = nn.Module()
+ pretrained.model = model
+ if use_vit_only == True:
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
+ else:
+ pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
+ get_activation("1")
+ )
+ pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
+ get_activation("2")
+ )
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
+ pretrained.activations = activations
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
+ if use_vit_only == True:
+ pretrained.act_postprocess1 = nn.Sequential(
+ readout_oper[0],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[0],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[0],
+ out_channels=features[0],
+ kernel_size=4,
+ stride=4,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+ pretrained.act_postprocess2 = nn.Sequential(
+ readout_oper[1],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[1],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[1],
+ out_channels=features[1],
+ kernel_size=2,
+ stride=2,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+ else:
+ pretrained.act_postprocess1 = nn.Sequential(
+ nn.Identity(), nn.Identity(), nn.Identity()
+ )
+ pretrained.act_postprocess2 = nn.Sequential(
+ nn.Identity(), nn.Identity(), nn.Identity()
+ )
+ pretrained.act_postprocess3 = nn.Sequential(
+ readout_oper[2],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[2],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ )
+ pretrained.act_postprocess4 = nn.Sequential(
+ readout_oper[3],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[3],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.Conv2d(
+ in_channels=features[3],
+ out_channels=features[3],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ ),
+ )
+ pretrained.model.start_index = start_index
+ pretrained.model.patch_size = [16, 16]
+ # We inject this function into the VisionTransformer instances so that
+ # we can use it with interpolated position embeddings without modifying the library source.
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
+ # We inject this function into the VisionTransformer instances so that
+ # we can use it with interpolated position embeddings without modifying the library source.
+ pretrained.model._resize_pos_embed = types.MethodType(
+ _resize_pos_embed, pretrained.model
+ )
+ return pretrained
+def _make_pretrained_vitb_rn50_384(
+ pretrained, use_readout="ignore", hooks=None, use_vit_only=False
+ model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
+ hooks = [0, 1, 8, 11] if hooks == None else hooks
+ return _make_vit_b_rn50_backbone(
+ model,
+ features=[256, 512, 768, 768],
+ size=[384, 384],
+ hooks=hooks,
+ use_vit_only=use_vit_only,
+ use_readout=use_readout,
+ )
diff --git a/ControlNet/ldm/modules/midas/utils.py b/ControlNet/ldm/modules/midas/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a9d3b5b66370fa98da9e067ba53ead848ea9a59
--- /dev/null
+++ b/ControlNet/ldm/modules/midas/utils.py
@@ -0,0 +1,189 @@
+"""Utils for monoDepth."""
+import sys
+import re
+import numpy as np
+import cv2
+import torch
+def read_pfm(path):
+ """Read pfm file.
+ Args:
+ path (str): path to file
+ Returns:
+ tuple: (data, scale)
+ """
+ with open(path, "rb") as file:
+ color = None
+ width = None
+ height = None
+ scale = None
+ endian = None
+ header = file.readline().rstrip()
+ if header.decode("ascii") == "PF":
+ color = True
+ elif header.decode("ascii") == "Pf":
+ color = False
+ else:
+ raise Exception("Not a PFM file: " + path)
+ dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
+ if dim_match:
+ width, height = list(map(int, dim_match.groups()))
+ else:
+ raise Exception("Malformed PFM header.")
+ scale = float(file.readline().decode("ascii").rstrip())
+ if scale < 0:
+ # little-endian
+ endian = "<"
+ scale = -scale
+ else:
+ # big-endian
+ endian = ">"
+ data = np.fromfile(file, endian + "f")
+ shape = (height, width, 3) if color else (height, width)
+ data = np.reshape(data, shape)
+ data = np.flipud(data)
+ return data, scale
+def write_pfm(path, image, scale=1):
+ """Write pfm file.
+ Args:
+ path (str): pathto file
+ image (array): data
+ scale (int, optional): Scale. Defaults to 1.
+ """
+ with open(path, "wb") as file:
+ color = None
+ if image.dtype.name != "float32":
+ raise Exception("Image dtype must be float32.")
+ image = np.flipud(image)
+ if len(image.shape) == 3 and image.shape[2] == 3: # color image
+ color = True
+ elif (
+ len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
+ ): # greyscale
+ color = False
+ else:
+ raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
+ file.write("PF\n" if color else "Pf\n".encode())
+ file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
+ endian = image.dtype.byteorder
+ if endian == "<" or endian == "=" and sys.byteorder == "little":
+ scale = -scale
+ file.write("%f\n".encode() % scale)
+ image.tofile(file)
+def read_image(path):
+ """Read image and output RGB image (0-1).
+ Args:
+ path (str): path to file
+ Returns:
+ array: RGB image (0-1)
+ """
+ img = cv2.imread(path)
+ if img.ndim == 2:
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
+ return img
+def resize_image(img):
+ """Resize image and make it fit for network.
+ Args:
+ img (array): image
+ Returns:
+ tensor: data ready for network
+ """
+ height_orig = img.shape[0]
+ width_orig = img.shape[1]
+ if width_orig > height_orig:
+ scale = width_orig / 384
+ else:
+ scale = height_orig / 384
+ height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
+ width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
+ img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
+ img_resized = (
+ torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
+ )
+ img_resized = img_resized.unsqueeze(0)
+ return img_resized
+def resize_depth(depth, width, height):
+ """Resize depth map and bring to CPU (numpy).
+ Args:
+ depth (tensor): depth
+ width (int): image width
+ height (int): image height
+ Returns:
+ array: processed depth
+ """
+ depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
+ depth_resized = cv2.resize(
+ depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
+ )
+ return depth_resized
+def write_depth(path, depth, bits=1):
+ """Write depth map to pfm and png file.
+ Args:
+ path (str): filepath without extension
+ depth (array): depth
+ """
+ write_pfm(path + ".pfm", depth.astype(np.float32))
+ depth_min = depth.min()
+ depth_max = depth.max()
+ max_val = (2**(8*bits))-1
+ if depth_max - depth_min > np.finfo("float").eps:
+ out = max_val * (depth - depth_min) / (depth_max - depth_min)
+ else:
+ out = np.zeros(depth.shape, dtype=depth.type)
+ if bits == 1:
+ cv2.imwrite(path + ".png", out.astype("uint8"))
+ elif bits == 2:
+ cv2.imwrite(path + ".png", out.astype("uint16"))
+ return
diff --git a/ControlNet/ldm/util.py b/ControlNet/ldm/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..45cb050ece6f401a22dde098ce3f1ff663c5eb6a
--- /dev/null
+++ b/ControlNet/ldm/util.py
@@ -0,0 +1,197 @@
+import importlib
+import torch
+from torch import optim
+import numpy as np
+from inspect import isfunction
+from PIL import Image, ImageDraw, ImageFont
+def log_txt_as_img(wh, xc, size=10):
+ # wh a tuple of (width, height)
+ # xc a list of captions to plot
+ b = len(xc)
+ txts = list()
+ for bi in range(b):
+ txt = Image.new("RGB", wh, color="white")
+ draw = ImageDraw.Draw(txt)
+ font = ImageFont.truetype('font/DejaVuSans.ttf', size=size)
+ nc = int(40 * (wh[0] / 256))
+ lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
+ try:
+ draw.text((0, 0), lines, fill="black", font=font)
+ except UnicodeEncodeError:
+ print("Cant encode string for logging. Skipping.")
+ txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
+ txts.append(txt)
+ txts = np.stack(txts)
+ txts = torch.tensor(txts)
+ return txts
+def ismap(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
+def isimage(x):
+ if not isinstance(x,torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
+def exists(x):
+ return x is not None
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+def mean_flat(tensor):
+ """
+ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+def count_params(model, verbose=False):
+ total_params = sum(p.numel() for p in model.parameters())
+ if verbose:
+ print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
+ return total_params
+def instantiate_from_config(config):
+ if not "target" in config:
+ if config == '__is_first_stage__':
+ return None
+ elif config == "__is_unconditional__":
+ return None
+ raise KeyError("Expected key `target` to instantiate.")
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
+def get_obj_from_str(string, reload=False):
+ module, cls = string.rsplit(".", 1)
+ if reload:
+ module_imp = importlib.import_module(module)
+ importlib.reload(module_imp)
+ return getattr(importlib.import_module(module, package=None), cls)
+class AdamWwithEMAandWings(optim.Optimizer):
+ # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
+ def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using
+ weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code
+ ema_power=1., param_names=()):
+ """AdamW that saves EMA versions of the parameters."""
+ if not 0.0 <= lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 <= eps:
+ raise ValueError("Invalid epsilon value: {}".format(eps))
+ if not 0.0 <= betas[0] < 1.0:
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+ if not 0.0 <= betas[1] < 1.0:
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+ if not 0.0 <= weight_decay:
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ if not 0.0 <= ema_decay <= 1.0:
+ raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
+ defaults = dict(lr=lr, betas=betas, eps=eps,
+ weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
+ ema_power=ema_power, param_names=param_names)
+ super().__init__(params, defaults)
+ def __setstate__(self, state):
+ super().__setstate__(state)
+ for group in self.param_groups:
+ group.setdefault('amsgrad', False)
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Performs a single optimization step.
+ Args:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+ for group in self.param_groups:
+ params_with_grad = []
+ grads = []
+ exp_avgs = []
+ exp_avg_sqs = []
+ ema_params_with_grad = []
+ state_sums = []
+ max_exp_avg_sqs = []
+ state_steps = []
+ amsgrad = group['amsgrad']
+ beta1, beta2 = group['betas']
+ ema_decay = group['ema_decay']
+ ema_power = group['ema_power']
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ params_with_grad.append(p)
+ if p.grad.is_sparse:
+ raise RuntimeError('AdamW does not support sparse gradients')
+ grads.append(p.grad)
+ state = self.state[p]
+ # State initialization
+ if len(state) == 0:
+ state['step'] = 0
+ # Exponential moving average of gradient values
+ state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+ # Exponential moving average of squared gradient values
+ state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+ if amsgrad:
+ # Maintains max of all exp. moving avg. of sq. grad. values
+ state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+ # Exponential moving average of parameter values
+ state['param_exp_avg'] = p.detach().float().clone()
+ exp_avgs.append(state['exp_avg'])
+ exp_avg_sqs.append(state['exp_avg_sq'])
+ ema_params_with_grad.append(state['param_exp_avg'])
+ if amsgrad:
+ max_exp_avg_sqs.append(state['max_exp_avg_sq'])
+ # update the steps for each param group update
+ state['step'] += 1
+ # record the step after step update
+ state_steps.append(state['step'])
+ optim._functional.adamw(params_with_grad,
+ grads,
+ exp_avgs,
+ exp_avg_sqs,
+ max_exp_avg_sqs,
+ state_steps,
+ amsgrad=amsgrad,
+ beta1=beta1,
+ beta2=beta2,
+ lr=group['lr'],
+ weight_decay=group['weight_decay'],
+ eps=group['eps'],
+ maximize=False)
+ cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
+ for param, ema_param in zip(params_with_grad, ema_params_with_grad):
+ ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
+ return loss
\ No newline at end of file
diff --git a/app.py b/app.py
index 689ab2d63b5271c730d86820561e550fa9dec63d..63d5fbaba235b3f0f1563e6e231215ca8e3ce3ac 100644
--- a/app.py
+++ b/app.py
@@ -57,13 +57,6 @@ import re
import gradio as gr
-def set_openai_api_key(api_key, agent):
- if api_key:
- os.environ["OPENAI_API_KEY"] = api_key
- vectorstore = get_weaviate_store()
- qa_chain = get_new_chain1(vectorstore)
- os.environ["OPENAI_API_KEY"] = ""
- return qa_chain
def cut_dialogue_history(history_memory, keep_last_n_words=500):
tokens = history_memory.split()
@@ -106,26 +99,26 @@ def create_model(config_path, device):
class ConversationBot:
def __init__(self):
print("Initializing VisualChatGPT")
- self.llm = OpenAI(temperature=0)
+ self.llm = OpenAI(temperature=0, openai_api_key="")
self.edit = ImageEditing(device="cuda:6")
self.i2t = ImageCaptioning(device="cuda:4")
self.t2i = T2I(device="cuda:1")
- self.image2canny = image2canny()
- self.canny2image = canny2image(device="cuda:1")
- self.image2line = image2line()
- self.line2image = line2image(device="cuda:1")
- self.image2hed = image2hed()
- self.hed2image = hed2image(device="cuda:2")
- self.image2scribble = image2scribble()
- self.scribble2image = scribble2image(device="cuda:3")
- self.image2pose = image2pose()
- self.pose2image = pose2image(device="cuda:3")
- self.BLIPVQA = BLIPVQA(device="cuda:4")
- self.image2seg = image2seg()
- self.seg2image = seg2image(device="cuda:7")
- self.image2depth = image2depth()
- self.depth2image = depth2image(device="cuda:7")
- self.image2normal = image2normal()
+ # self.image2canny = image2canny()
+ # self.canny2image = canny2image(device="cuda:1")
+ # self.image2line = image2line()
+ # self.line2image = line2image(device="cuda:1")
+ # self.image2hed = image2hed()
+ # self.hed2image = hed2image(device="cuda:2")
+ # self.image2scribble = image2scribble()
+ # self.scribble2image = scribble2image(device="cuda:3")
+ # self.image2pose = image2pose()
+ # self.pose2image = pose2image(device="cuda:3")
+ # self.BLIPVQA = BLIPVQA(device="cuda:4")
+ # self.image2seg = image2seg()
+ # self.seg2image = seg2image(device="cuda:7")
+ # self.image2depth = image2depth()
+ # self.depth2image = depth2image(device="cuda:7")
+ # self.image2normal = image2normal()
self.normal2image = normal2image(device="cuda:5")
self.pix2pix = Pix2Pix(device="cuda:3")
self.memory = ConversationBufferMemory(memory_key="chat_history", output_key='output')
@@ -274,10 +267,5 @@ with gr.Blocks(css="#chatbot .overflow-y-auto{height:500px}") as demo:
clear.click(lambda: [], None, chatbot)
clear.click(lambda: [], None, state)
- openai_api_key_textbox.change(
- set_openai_api_key,
- inputs=[openai_api_key_textbox, agent_state],
- outputs=[agent_state],
- )
demo.launch(server_name="", server_port=7860)
diff --git a/visual_foundation_models.py b/visual_foundation_models.py
index 2ae3a21f1ebe276ebc21e724211b590186b2ccde..f8ba69e259bcdb04a29012aa48589f176ebdcd8a 100644
--- a/visual_foundation_models.py
+++ b/visual_foundation_models.py
@@ -4,7 +4,7 @@ from diffusers import StableDiffusionInpaintPipeline
from diffusers import StableDiffusionInstructPix2PixPipeline, EulerAncestralDiscreteScheduler
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPSegProcessor, CLIPSegForImageSegmentation
from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering
-from ldm.util import instantiate_from_config
+from ControlNet.ldm.util import instantiate_from_config
from ControlNet.cldm.model import create_model, load_state_dict
from ControlNet.cldm.ddim_hacked import DDIMSampler
from ControlNet.annotator.canny import CannyDetector