Spaces:
No application file
No application file
1c6c068dac340ae702a665f6810ce95b3184b1d6dd405feac9040ae0cefc341a
Browse files- Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/__init__.cpython-38.pyc +0 -0
- Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/codebook.cpython-38.pyc +0 -0
- Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/lpips.cpython-38.pyc +0 -0
- Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/vqgan.cpython-38.pyc +0 -0
- Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/cache/vgg.pth +3 -0
- Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/codebook.py +109 -0
- Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/lpips.py +181 -0
- Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/vqgan.py +561 -0
- Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/vq_gan_3d/utils.py +177 -0
- Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/model_weight/diff_kidney_fold0_early.pt +3 -0
- Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/model_weight/diff_kidney_fold0_noearly_t200.pt +3 -0
- Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/model_weight/diff_liver_fold0_early.pt +3 -0
- Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/model_weight/diff_liver_fold0_noearly_t200.pt +3 -0
- Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/model_weight/diff_pancreas_fold0_early.pt +3 -0
- Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/model_weight/diff_pancreas_fold0_noearly_t200.pt +3 -0
- Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/model_weight/recon_96d4_all.ckpt +3 -0
- Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/utils.py +465 -0
- Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/utils_.py +298 -0
- Generation_Pipeline_filter_all2/syn_kidney/healthy_kidney_1k.txt +565 -0
- Generation_Pipeline_filter_all2/syn_kidney/requirements.txt +94 -0
- Generation_Pipeline_filter_all2/syn_liver/CT_syn_data.py +240 -0
- Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/.DS_Store +0 -0
- Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/README.md +5 -0
- Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/TumorGenerated.py +39 -0
- Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/__init__.py +5 -0
- Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/__pycache__/TumorGenerated.cpython-38.pyc +0 -0
- Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/__pycache__/__init__.cpython-38.pyc +0 -0
- Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/__pycache__/utils.cpython-38.pyc +0 -0
- Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/__pycache__/utils_.cpython-38.pyc +0 -0
- Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/diffusion_config/ddpm.yaml +29 -0
- Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/diffusion_config/vq_gan_3d.yaml +37 -0
- Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/ldm/ddpm/__init__.py +1 -0
- Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/__init__.cpython-38.pyc +0 -0
- Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/ddim.cpython-38.pyc +0 -0
- Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/diffusion.cpython-38.pyc +0 -0
- Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/text.cpython-38.pyc +0 -0
- Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/time_embedding.cpython-38.pyc +0 -0
- Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/unet.cpython-38.pyc +0 -0
- Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/util.cpython-38.pyc +0 -0
- Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/ldm/ddpm/ddim.py +206 -0
- Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/ldm/ddpm/diffusion.py +1016 -0
- Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/ldm/ddpm/text.py +94 -0
- Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/ldm/ddpm/time_embedding.py +75 -0
- Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/ldm/ddpm/unet.py +226 -0
- Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/ldm/ddpm/util.py +271 -0
- Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/ldm/vq_gan_3d/__pycache__/utils.cpython-38.pyc +0 -0
- Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__init__.py +3 -0
- Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/__init__.cpython-38.pyc +0 -0
- Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/codebook.cpython-38.pyc +0 -0
- Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/lpips.cpython-38.pyc +0 -0
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (295 Bytes). View file
|
|
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/codebook.cpython-38.pyc
ADDED
Binary file (3.42 kB). View file
|
|
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/lpips.cpython-38.pyc
ADDED
Binary file (6.79 kB). View file
|
|
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/vqgan.cpython-38.pyc
ADDED
Binary file (16.6 kB). View file
|
|
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/cache/vgg.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868
|
3 |
+
size 7289
|
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/codebook.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Adapted from https://github.com/SongweiGe/TATS"""
|
2 |
+
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import torch.distributed as dist
|
10 |
+
|
11 |
+
from ..utils import shift_dim
|
12 |
+
|
13 |
+
|
14 |
+
class Codebook(nn.Module):
|
15 |
+
def __init__(self, n_codes, embedding_dim, no_random_restart=False, restart_thres=1.0):
|
16 |
+
super().__init__()
|
17 |
+
self.register_buffer('embeddings', torch.randn(n_codes, embedding_dim))
|
18 |
+
self.register_buffer('N', torch.zeros(n_codes))
|
19 |
+
self.register_buffer('z_avg', self.embeddings.data.clone())
|
20 |
+
|
21 |
+
self.n_codes = n_codes
|
22 |
+
self.embedding_dim = embedding_dim
|
23 |
+
self._need_init = True
|
24 |
+
self.no_random_restart = no_random_restart
|
25 |
+
self.restart_thres = restart_thres
|
26 |
+
|
27 |
+
def _tile(self, x):
|
28 |
+
d, ew = x.shape
|
29 |
+
if d < self.n_codes:
|
30 |
+
n_repeats = (self.n_codes + d - 1) // d
|
31 |
+
std = 0.01 / np.sqrt(ew)
|
32 |
+
x = x.repeat(n_repeats, 1)
|
33 |
+
x = x + torch.randn_like(x) * std
|
34 |
+
return x
|
35 |
+
|
36 |
+
def _init_embeddings(self, z):
|
37 |
+
# z: [b, c, t, h, w]
|
38 |
+
self._need_init = False
|
39 |
+
breakpoint()
|
40 |
+
flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) # [65536, 8] [2, 8, 32, 32, 32]
|
41 |
+
y = self._tile(flat_inputs) # [65536, 8]
|
42 |
+
|
43 |
+
d = y.shape[0]
|
44 |
+
_k_rand = y[torch.randperm(y.shape[0])][:self.n_codes]
|
45 |
+
if dist.is_initialized():
|
46 |
+
dist.broadcast(_k_rand, 0)
|
47 |
+
self.embeddings.data.copy_(_k_rand)
|
48 |
+
self.z_avg.data.copy_(_k_rand)
|
49 |
+
self.N.data.copy_(torch.ones(self.n_codes))
|
50 |
+
|
51 |
+
def forward(self, z):
|
52 |
+
# z: [b, c, t, h, w]
|
53 |
+
if self._need_init and self.training:
|
54 |
+
self._init_embeddings(z)
|
55 |
+
flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) # [bthw, c] [65536, 8]
|
56 |
+
distances = (flat_inputs ** 2).sum(dim=1, keepdim=True) \
|
57 |
+
- 2 * flat_inputs @ self.embeddings.t() \
|
58 |
+
+ (self.embeddings.t() ** 2).sum(dim=0, keepdim=True) # [bthw, c] [65536, 8]
|
59 |
+
|
60 |
+
encoding_indices = torch.argmin(distances, dim=1) # [65536]
|
61 |
+
encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as(
|
62 |
+
flat_inputs) # [bthw, ncode] [65536, 16384]
|
63 |
+
encoding_indices = encoding_indices.view(
|
64 |
+
z.shape[0], *z.shape[2:]) # [b, t, h, w, ncode] [2, 32, 32, 32]
|
65 |
+
|
66 |
+
embeddings = F.embedding(
|
67 |
+
encoding_indices, self.embeddings) # [b, t, h, w, c] self.embeddings [16384, 8]
|
68 |
+
embeddings = shift_dim(embeddings, -1, 1) # [b, c, t, h, w] [2, 8, 32, 32, 32]
|
69 |
+
|
70 |
+
commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach())
|
71 |
+
|
72 |
+
# EMA codebook update
|
73 |
+
if self.training:
|
74 |
+
n_total = encode_onehot.sum(dim=0) # [16384]
|
75 |
+
encode_sum = flat_inputs.t() @ encode_onehot # [8, 16384]
|
76 |
+
if dist.is_initialized():
|
77 |
+
dist.all_reduce(n_total)
|
78 |
+
dist.all_reduce(encode_sum)
|
79 |
+
|
80 |
+
self.N.data.mul_(0.99).add_(n_total, alpha=0.01)
|
81 |
+
self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01)
|
82 |
+
|
83 |
+
n = self.N.sum()
|
84 |
+
weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n
|
85 |
+
encode_normalized = self.z_avg / weights.unsqueeze(1)
|
86 |
+
self.embeddings.data.copy_(encode_normalized)
|
87 |
+
|
88 |
+
y = self._tile(flat_inputs)
|
89 |
+
_k_rand = y[torch.randperm(y.shape[0])][:self.n_codes]
|
90 |
+
if dist.is_initialized():
|
91 |
+
dist.broadcast(_k_rand, 0)
|
92 |
+
|
93 |
+
if not self.no_random_restart:
|
94 |
+
usage = (self.N.view(self.n_codes, 1)
|
95 |
+
>= self.restart_thres).float()
|
96 |
+
self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage))
|
97 |
+
|
98 |
+
embeddings_st = (embeddings - z).detach() + z
|
99 |
+
|
100 |
+
avg_probs = torch.mean(encode_onehot, dim=0)
|
101 |
+
perplexity = torch.exp(-torch.sum(avg_probs *
|
102 |
+
torch.log(avg_probs + 1e-10)))
|
103 |
+
|
104 |
+
return dict(embeddings=embeddings_st, encodings=encoding_indices,
|
105 |
+
commitment_loss=commitment_loss, perplexity=perplexity)
|
106 |
+
|
107 |
+
def dictionary_lookup(self, encodings):
|
108 |
+
embeddings = F.embedding(encodings, self.embeddings)
|
109 |
+
return embeddings
|
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/lpips.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Adapted from https://github.com/SongweiGe/TATS"""
|
2 |
+
|
3 |
+
"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
|
4 |
+
|
5 |
+
|
6 |
+
from collections import namedtuple
|
7 |
+
from torchvision import models
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch
|
10 |
+
from tqdm import tqdm
|
11 |
+
import requests
|
12 |
+
import os
|
13 |
+
import hashlib
|
14 |
+
URL_MAP = {
|
15 |
+
"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
|
16 |
+
}
|
17 |
+
|
18 |
+
CKPT_MAP = {
|
19 |
+
"vgg_lpips": "vgg.pth"
|
20 |
+
}
|
21 |
+
|
22 |
+
MD5_MAP = {
|
23 |
+
"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
|
24 |
+
}
|
25 |
+
|
26 |
+
|
27 |
+
def download(url, local_path, chunk_size=1024):
|
28 |
+
os.makedirs(os.path.split(local_path)[0], exist_ok=True)
|
29 |
+
with requests.get(url, stream=True) as r:
|
30 |
+
total_size = int(r.headers.get("content-length", 0))
|
31 |
+
with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
|
32 |
+
with open(local_path, "wb") as f:
|
33 |
+
for data in r.iter_content(chunk_size=chunk_size):
|
34 |
+
if data:
|
35 |
+
f.write(data)
|
36 |
+
pbar.update(chunk_size)
|
37 |
+
|
38 |
+
|
39 |
+
def md5_hash(path):
|
40 |
+
with open(path, "rb") as f:
|
41 |
+
content = f.read()
|
42 |
+
return hashlib.md5(content).hexdigest()
|
43 |
+
|
44 |
+
|
45 |
+
def get_ckpt_path(name, root, check=False):
|
46 |
+
assert name in URL_MAP
|
47 |
+
path = os.path.join(root, CKPT_MAP[name])
|
48 |
+
if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
|
49 |
+
print("Downloading {} model from {} to {}".format(
|
50 |
+
name, URL_MAP[name], path))
|
51 |
+
download(URL_MAP[name], path)
|
52 |
+
md5 = md5_hash(path)
|
53 |
+
assert md5 == MD5_MAP[name], md5
|
54 |
+
return path
|
55 |
+
|
56 |
+
|
57 |
+
class LPIPS(nn.Module):
|
58 |
+
# Learned perceptual metric
|
59 |
+
def __init__(self, use_dropout=True):
|
60 |
+
super().__init__()
|
61 |
+
self.scaling_layer = ScalingLayer()
|
62 |
+
self.chns = [64, 128, 256, 512, 512] # vg16 features
|
63 |
+
self.net = vgg16(pretrained=True, requires_grad=False)
|
64 |
+
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
|
65 |
+
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
|
66 |
+
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
|
67 |
+
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
|
68 |
+
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
|
69 |
+
# self.load_from_pretrained()
|
70 |
+
for param in self.parameters():
|
71 |
+
param.requires_grad = False
|
72 |
+
|
73 |
+
def load_from_pretrained(self, name="vgg_lpips"):
|
74 |
+
ckpt = get_ckpt_path(name, os.path.join(
|
75 |
+
os.path.dirname(os.path.abspath(__file__)), "cache"))
|
76 |
+
self.load_state_dict(torch.load(
|
77 |
+
ckpt, map_location=torch.device("cpu")), strict=False)
|
78 |
+
print("loaded pretrained LPIPS loss from {}".format(ckpt))
|
79 |
+
|
80 |
+
@classmethod
|
81 |
+
def from_pretrained(cls, name="vgg_lpips"):
|
82 |
+
if name is not "vgg_lpips":
|
83 |
+
raise NotImplementedError
|
84 |
+
model = cls()
|
85 |
+
ckpt = get_ckpt_path(name, os.path.join(
|
86 |
+
os.path.dirname(os.path.abspath(__file__)), "cache"))
|
87 |
+
model.load_state_dict(torch.load(
|
88 |
+
ckpt, map_location=torch.device("cpu")), strict=False)
|
89 |
+
return model
|
90 |
+
|
91 |
+
def forward(self, input, target):
|
92 |
+
in0_input, in1_input = (self.scaling_layer(
|
93 |
+
input), self.scaling_layer(target))
|
94 |
+
outs0, outs1 = self.net(in0_input), self.net(in1_input)
|
95 |
+
feats0, feats1, diffs = {}, {}, {}
|
96 |
+
lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
|
97 |
+
for kk in range(len(self.chns)):
|
98 |
+
feats0[kk], feats1[kk] = normalize_tensor(
|
99 |
+
outs0[kk]), normalize_tensor(outs1[kk])
|
100 |
+
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
|
101 |
+
|
102 |
+
res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True)
|
103 |
+
for kk in range(len(self.chns))]
|
104 |
+
val = res[0]
|
105 |
+
for l in range(1, len(self.chns)):
|
106 |
+
val += res[l]
|
107 |
+
return val
|
108 |
+
|
109 |
+
|
110 |
+
class ScalingLayer(nn.Module):
|
111 |
+
def __init__(self):
|
112 |
+
super(ScalingLayer, self).__init__()
|
113 |
+
self.register_buffer('shift', torch.Tensor(
|
114 |
+
[-.030, -.088, -.188])[None, :, None, None])
|
115 |
+
self.register_buffer('scale', torch.Tensor(
|
116 |
+
[.458, .448, .450])[None, :, None, None])
|
117 |
+
|
118 |
+
def forward(self, inp):
|
119 |
+
return (inp - self.shift) / self.scale
|
120 |
+
|
121 |
+
|
122 |
+
class NetLinLayer(nn.Module):
|
123 |
+
""" A single linear layer which does a 1x1 conv """
|
124 |
+
|
125 |
+
def __init__(self, chn_in, chn_out=1, use_dropout=False):
|
126 |
+
super(NetLinLayer, self).__init__()
|
127 |
+
layers = [nn.Dropout(), ] if (use_dropout) else []
|
128 |
+
layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1,
|
129 |
+
padding=0, bias=False), ]
|
130 |
+
self.model = nn.Sequential(*layers)
|
131 |
+
|
132 |
+
|
133 |
+
class vgg16(torch.nn.Module):
|
134 |
+
def __init__(self, requires_grad=False, pretrained=True):
|
135 |
+
super(vgg16, self).__init__()
|
136 |
+
vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
|
137 |
+
self.slice1 = torch.nn.Sequential()
|
138 |
+
self.slice2 = torch.nn.Sequential()
|
139 |
+
self.slice3 = torch.nn.Sequential()
|
140 |
+
self.slice4 = torch.nn.Sequential()
|
141 |
+
self.slice5 = torch.nn.Sequential()
|
142 |
+
self.N_slices = 5
|
143 |
+
for x in range(4):
|
144 |
+
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
145 |
+
for x in range(4, 9):
|
146 |
+
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
147 |
+
for x in range(9, 16):
|
148 |
+
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
149 |
+
for x in range(16, 23):
|
150 |
+
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
151 |
+
for x in range(23, 30):
|
152 |
+
self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
153 |
+
if not requires_grad:
|
154 |
+
for param in self.parameters():
|
155 |
+
param.requires_grad = False
|
156 |
+
|
157 |
+
def forward(self, X):
|
158 |
+
h = self.slice1(X)
|
159 |
+
h_relu1_2 = h
|
160 |
+
h = self.slice2(h)
|
161 |
+
h_relu2_2 = h
|
162 |
+
h = self.slice3(h)
|
163 |
+
h_relu3_3 = h
|
164 |
+
h = self.slice4(h)
|
165 |
+
h_relu4_3 = h
|
166 |
+
h = self.slice5(h)
|
167 |
+
h_relu5_3 = h
|
168 |
+
vgg_outputs = namedtuple(
|
169 |
+
"VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
|
170 |
+
out = vgg_outputs(h_relu1_2, h_relu2_2,
|
171 |
+
h_relu3_3, h_relu4_3, h_relu5_3)
|
172 |
+
return out
|
173 |
+
|
174 |
+
|
175 |
+
def normalize_tensor(x, eps=1e-10):
|
176 |
+
norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
|
177 |
+
return x/(norm_factor+eps)
|
178 |
+
|
179 |
+
|
180 |
+
def spatial_average(x, keepdim=True):
|
181 |
+
return x.mean([2, 3], keepdim=keepdim)
|
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/vqgan.py
ADDED
@@ -0,0 +1,561 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Adapted from https://github.com/SongweiGe/TATS"""
|
2 |
+
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
|
3 |
+
|
4 |
+
import math
|
5 |
+
import argparse
|
6 |
+
import numpy as np
|
7 |
+
import pickle as pkl
|
8 |
+
|
9 |
+
import pytorch_lightning as pl
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
import torch.distributed as dist
|
14 |
+
|
15 |
+
from ..utils import shift_dim, adopt_weight, comp_getattr
|
16 |
+
from .lpips import LPIPS
|
17 |
+
from .codebook import Codebook
|
18 |
+
|
19 |
+
|
20 |
+
def silu(x):
|
21 |
+
return x*torch.sigmoid(x)
|
22 |
+
|
23 |
+
|
24 |
+
class SiLU(nn.Module):
|
25 |
+
def __init__(self):
|
26 |
+
super(SiLU, self).__init__()
|
27 |
+
|
28 |
+
def forward(self, x):
|
29 |
+
return silu(x)
|
30 |
+
|
31 |
+
|
32 |
+
def hinge_d_loss(logits_real, logits_fake):
|
33 |
+
loss_real = torch.mean(F.relu(1. - logits_real))
|
34 |
+
loss_fake = torch.mean(F.relu(1. + logits_fake))
|
35 |
+
d_loss = 0.5 * (loss_real + loss_fake)
|
36 |
+
return d_loss
|
37 |
+
|
38 |
+
|
39 |
+
def vanilla_d_loss(logits_real, logits_fake):
|
40 |
+
d_loss = 0.5 * (
|
41 |
+
torch.mean(torch.nn.functional.softplus(-logits_real)) +
|
42 |
+
torch.mean(torch.nn.functional.softplus(logits_fake)))
|
43 |
+
return d_loss
|
44 |
+
|
45 |
+
|
46 |
+
class VQGAN(pl.LightningModule):
|
47 |
+
def __init__(self, cfg):
|
48 |
+
super().__init__()
|
49 |
+
self.cfg = cfg
|
50 |
+
self.embedding_dim = cfg.model.embedding_dim # 8
|
51 |
+
self.n_codes = cfg.model.n_codes # 16384
|
52 |
+
|
53 |
+
self.encoder = Encoder(cfg.model.n_hiddens, # 16
|
54 |
+
cfg.model.downsample, # [2, 2, 2]
|
55 |
+
cfg.dataset.image_channels, # 1
|
56 |
+
cfg.model.norm_type, # group
|
57 |
+
cfg.model.padding_type, # replicate
|
58 |
+
cfg.model.num_groups, # 32
|
59 |
+
)
|
60 |
+
self.decoder = Decoder(
|
61 |
+
cfg.model.n_hiddens, cfg.model.downsample, cfg.dataset.image_channels, cfg.model.norm_type, cfg.model.num_groups)
|
62 |
+
self.enc_out_ch = self.encoder.out_channels
|
63 |
+
self.pre_vq_conv = SamePadConv3d(
|
64 |
+
self.enc_out_ch, cfg.model.embedding_dim, 1, padding_type=cfg.model.padding_type)
|
65 |
+
self.post_vq_conv = SamePadConv3d(
|
66 |
+
cfg.model.embedding_dim, self.enc_out_ch, 1)
|
67 |
+
|
68 |
+
self.codebook = Codebook(cfg.model.n_codes, cfg.model.embedding_dim,
|
69 |
+
no_random_restart=cfg.model.no_random_restart, restart_thres=cfg.model.restart_thres)
|
70 |
+
|
71 |
+
self.gan_feat_weight = cfg.model.gan_feat_weight
|
72 |
+
# TODO: Changed batchnorm from sync to normal
|
73 |
+
self.image_discriminator = NLayerDiscriminator(
|
74 |
+
cfg.dataset.image_channels, cfg.model.disc_channels, cfg.model.disc_layers, norm_layer=nn.BatchNorm2d)
|
75 |
+
self.video_discriminator = NLayerDiscriminator3D(
|
76 |
+
cfg.dataset.image_channels, cfg.model.disc_channels, cfg.model.disc_layers, norm_layer=nn.BatchNorm3d)
|
77 |
+
|
78 |
+
if cfg.model.disc_loss_type == 'vanilla':
|
79 |
+
self.disc_loss = vanilla_d_loss
|
80 |
+
elif cfg.model.disc_loss_type == 'hinge':
|
81 |
+
self.disc_loss = hinge_d_loss
|
82 |
+
|
83 |
+
self.perceptual_model = LPIPS().eval()
|
84 |
+
|
85 |
+
self.image_gan_weight = cfg.model.image_gan_weight
|
86 |
+
self.video_gan_weight = cfg.model.video_gan_weight
|
87 |
+
|
88 |
+
self.perceptual_weight = cfg.model.perceptual_weight
|
89 |
+
|
90 |
+
self.l1_weight = cfg.model.l1_weight
|
91 |
+
self.save_hyperparameters()
|
92 |
+
|
93 |
+
def encode(self, x, include_embeddings=False, quantize=True):
|
94 |
+
h = self.pre_vq_conv(self.encoder(x))
|
95 |
+
if quantize:
|
96 |
+
vq_output = self.codebook(h)
|
97 |
+
if include_embeddings:
|
98 |
+
return vq_output['embeddings'], vq_output['encodings']
|
99 |
+
else:
|
100 |
+
return vq_output['encodings']
|
101 |
+
return h
|
102 |
+
|
103 |
+
def decode(self, latent, quantize=False):
|
104 |
+
if quantize:
|
105 |
+
vq_output = self.codebook(latent)
|
106 |
+
latent = vq_output['encodings']
|
107 |
+
h = F.embedding(latent, self.codebook.embeddings)
|
108 |
+
h = self.post_vq_conv(shift_dim(h, -1, 1))
|
109 |
+
return self.decoder(h)
|
110 |
+
|
111 |
+
def forward(self, x, optimizer_idx=None, log_image=False):
|
112 |
+
B, C, T, H, W = x.shape
|
113 |
+
z = self.pre_vq_conv(self.encoder(x)) # [2, 32, 32, 32, 32] [2, 8, 32, 32, 32]
|
114 |
+
vq_output = self.codebook(z) # ['embeddings', 'encodings', 'commitment_loss', 'perplexity']
|
115 |
+
x_recon = self.decoder(self.post_vq_conv(vq_output['embeddings'])) # [2, 8, 32, 32, 32] [2, 32, 32, 32, 32]
|
116 |
+
|
117 |
+
recon_loss = F.l1_loss(x_recon, x) * self.l1_weight
|
118 |
+
|
119 |
+
# Selects one random 2D image from each 3D Image
|
120 |
+
frame_idx = torch.randint(0, T, [B]).cuda()
|
121 |
+
frame_idx_selected = frame_idx.reshape(-1,
|
122 |
+
1, 1, 1, 1).repeat(1, C, 1, H, W) # [2, 1, 1, 64, 64]
|
123 |
+
frames = torch.gather(x, 2, frame_idx_selected).squeeze(2) # [2, 1, 64, 64]
|
124 |
+
frames_recon = torch.gather(x_recon, 2, frame_idx_selected).squeeze(2) # [2, 1, 64, 64]
|
125 |
+
|
126 |
+
if log_image:
|
127 |
+
return frames, frames_recon, x, x_recon
|
128 |
+
|
129 |
+
if optimizer_idx == 0:
|
130 |
+
# Autoencoder - train the "generator"
|
131 |
+
|
132 |
+
# Perceptual loss
|
133 |
+
perceptual_loss = 0
|
134 |
+
if self.perceptual_weight > 0:
|
135 |
+
perceptual_loss = self.perceptual_model(
|
136 |
+
frames, frames_recon).mean() * self.perceptual_weight
|
137 |
+
|
138 |
+
# Discriminator loss (turned on after a certain epoch)
|
139 |
+
logits_image_fake, pred_image_fake = self.image_discriminator(
|
140 |
+
frames_recon)
|
141 |
+
logits_video_fake, pred_video_fake = self.video_discriminator(
|
142 |
+
x_recon)
|
143 |
+
g_image_loss = -torch.mean(logits_image_fake)
|
144 |
+
g_video_loss = -torch.mean(logits_video_fake)
|
145 |
+
g_loss = self.image_gan_weight*g_image_loss + self.video_gan_weight*g_video_loss
|
146 |
+
disc_factor = adopt_weight(
|
147 |
+
self.global_step, threshold=self.cfg.model.discriminator_iter_start)
|
148 |
+
aeloss = disc_factor * g_loss
|
149 |
+
|
150 |
+
# GAN feature matching loss - tune features such that we get the same prediction result on the discriminator
|
151 |
+
image_gan_feat_loss = 0
|
152 |
+
video_gan_feat_loss = 0
|
153 |
+
feat_weights = 4.0 / (3 + 1)
|
154 |
+
if self.image_gan_weight > 0:
|
155 |
+
logits_image_real, pred_image_real = self.image_discriminator(
|
156 |
+
frames)
|
157 |
+
for i in range(len(pred_image_fake)-1):
|
158 |
+
image_gan_feat_loss += feat_weights * \
|
159 |
+
F.l1_loss(pred_image_fake[i], pred_image_real[i].detach(
|
160 |
+
)) * (self.image_gan_weight > 0)
|
161 |
+
if self.video_gan_weight > 0:
|
162 |
+
logits_video_real, pred_video_real = self.video_discriminator(
|
163 |
+
x)
|
164 |
+
for i in range(len(pred_video_fake)-1):
|
165 |
+
video_gan_feat_loss += feat_weights * \
|
166 |
+
F.l1_loss(pred_video_fake[i], pred_video_real[i].detach(
|
167 |
+
)) * (self.video_gan_weight > 0)
|
168 |
+
gan_feat_loss = disc_factor * self.gan_feat_weight * \
|
169 |
+
(image_gan_feat_loss + video_gan_feat_loss)
|
170 |
+
|
171 |
+
self.log("train/g_image_loss", g_image_loss,
|
172 |
+
logger=True, on_step=True, on_epoch=True)
|
173 |
+
self.log("train/g_video_loss", g_video_loss,
|
174 |
+
logger=True, on_step=True, on_epoch=True)
|
175 |
+
self.log("train/image_gan_feat_loss", image_gan_feat_loss,
|
176 |
+
logger=True, on_step=True, on_epoch=True)
|
177 |
+
self.log("train/video_gan_feat_loss", video_gan_feat_loss,
|
178 |
+
logger=True, on_step=True, on_epoch=True)
|
179 |
+
self.log("train/perceptual_loss", perceptual_loss,
|
180 |
+
prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
181 |
+
self.log("train/recon_loss", recon_loss, prog_bar=True,
|
182 |
+
logger=True, on_step=True, on_epoch=True)
|
183 |
+
self.log("train/aeloss", aeloss, prog_bar=True,
|
184 |
+
logger=True, on_step=True, on_epoch=True)
|
185 |
+
self.log("train/commitment_loss", vq_output['commitment_loss'],
|
186 |
+
prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
187 |
+
self.log('train/perplexity', vq_output['perplexity'],
|
188 |
+
prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
189 |
+
return recon_loss, x_recon, vq_output, aeloss, perceptual_loss, gan_feat_loss
|
190 |
+
|
191 |
+
if optimizer_idx == 1:
|
192 |
+
# Train discriminator
|
193 |
+
logits_image_real, _ = self.image_discriminator(frames.detach())
|
194 |
+
logits_video_real, _ = self.video_discriminator(x.detach())
|
195 |
+
|
196 |
+
logits_image_fake, _ = self.image_discriminator(
|
197 |
+
frames_recon.detach())
|
198 |
+
logits_video_fake, _ = self.video_discriminator(x_recon.detach())
|
199 |
+
|
200 |
+
d_image_loss = self.disc_loss(logits_image_real, logits_image_fake)
|
201 |
+
d_video_loss = self.disc_loss(logits_video_real, logits_video_fake)
|
202 |
+
disc_factor = adopt_weight(
|
203 |
+
self.global_step, threshold=self.cfg.model.discriminator_iter_start)
|
204 |
+
discloss = disc_factor * \
|
205 |
+
(self.image_gan_weight*d_image_loss +
|
206 |
+
self.video_gan_weight*d_video_loss)
|
207 |
+
|
208 |
+
self.log("train/logits_image_real", logits_image_real.mean().detach(),
|
209 |
+
logger=True, on_step=True, on_epoch=True)
|
210 |
+
self.log("train/logits_image_fake", logits_image_fake.mean().detach(),
|
211 |
+
logger=True, on_step=True, on_epoch=True)
|
212 |
+
self.log("train/logits_video_real", logits_video_real.mean().detach(),
|
213 |
+
logger=True, on_step=True, on_epoch=True)
|
214 |
+
self.log("train/logits_video_fake", logits_video_fake.mean().detach(),
|
215 |
+
logger=True, on_step=True, on_epoch=True)
|
216 |
+
self.log("train/d_image_loss", d_image_loss,
|
217 |
+
logger=True, on_step=True, on_epoch=True)
|
218 |
+
self.log("train/d_video_loss", d_video_loss,
|
219 |
+
logger=True, on_step=True, on_epoch=True)
|
220 |
+
self.log("train/discloss", discloss, prog_bar=True,
|
221 |
+
logger=True, on_step=True, on_epoch=True)
|
222 |
+
return discloss
|
223 |
+
|
224 |
+
perceptual_loss = self.perceptual_model(
|
225 |
+
frames, frames_recon) * self.perceptual_weight
|
226 |
+
return recon_loss, x_recon, vq_output, perceptual_loss
|
227 |
+
|
228 |
+
def training_step(self, batch, batch_idx, optimizer_idx):
|
229 |
+
x = batch['image']
|
230 |
+
if optimizer_idx == 0:
|
231 |
+
recon_loss, _, vq_output, aeloss, perceptual_loss, gan_feat_loss = self.forward(
|
232 |
+
x, optimizer_idx)
|
233 |
+
commitment_loss = vq_output['commitment_loss']
|
234 |
+
loss = recon_loss + commitment_loss + aeloss + perceptual_loss + gan_feat_loss
|
235 |
+
if optimizer_idx == 1:
|
236 |
+
discloss = self.forward(x, optimizer_idx)
|
237 |
+
loss = discloss
|
238 |
+
return loss
|
239 |
+
|
240 |
+
def validation_step(self, batch, batch_idx):
|
241 |
+
x = batch['image'] # TODO: batch['stft']
|
242 |
+
recon_loss, _, vq_output, perceptual_loss = self.forward(x)
|
243 |
+
self.log('val/recon_loss', recon_loss, prog_bar=True)
|
244 |
+
self.log('val/perceptual_loss', perceptual_loss, prog_bar=True)
|
245 |
+
self.log('val/perplexity', vq_output['perplexity'], prog_bar=True)
|
246 |
+
self.log('val/commitment_loss',
|
247 |
+
vq_output['commitment_loss'], prog_bar=True)
|
248 |
+
|
249 |
+
def configure_optimizers(self):
|
250 |
+
lr = self.cfg.model.lr
|
251 |
+
opt_ae = torch.optim.Adam(list(self.encoder.parameters()) +
|
252 |
+
list(self.decoder.parameters()) +
|
253 |
+
list(self.pre_vq_conv.parameters()) +
|
254 |
+
list(self.post_vq_conv.parameters()) +
|
255 |
+
list(self.codebook.parameters()),
|
256 |
+
lr=lr, betas=(0.5, 0.9))
|
257 |
+
opt_disc = torch.optim.Adam(list(self.image_discriminator.parameters()) +
|
258 |
+
list(self.video_discriminator.parameters()),
|
259 |
+
lr=lr, betas=(0.5, 0.9))
|
260 |
+
return [opt_ae, opt_disc], []
|
261 |
+
|
262 |
+
def log_images(self, batch, **kwargs):
|
263 |
+
log = dict()
|
264 |
+
x = batch['image']
|
265 |
+
x = x.to(self.device)
|
266 |
+
frames, frames_rec, _, _ = self(x, log_image=True)
|
267 |
+
log["inputs"] = frames
|
268 |
+
log["reconstructions"] = frames_rec
|
269 |
+
#log['mean_org'] = batch['mean_org']
|
270 |
+
#log['std_org'] = batch['std_org']
|
271 |
+
return log
|
272 |
+
|
273 |
+
def log_videos(self, batch, **kwargs):
|
274 |
+
log = dict()
|
275 |
+
x = batch['image']
|
276 |
+
_, _, x, x_rec = self(x, log_image=True)
|
277 |
+
log["inputs"] = x
|
278 |
+
log["reconstructions"] = x_rec
|
279 |
+
#log['mean_org'] = batch['mean_org']
|
280 |
+
#log['std_org'] = batch['std_org']
|
281 |
+
return log
|
282 |
+
|
283 |
+
|
284 |
+
def Normalize(in_channels, norm_type='group', num_groups=32):
|
285 |
+
assert norm_type in ['group', 'batch']
|
286 |
+
if norm_type == 'group':
|
287 |
+
# TODO Changed num_groups from 32 to 8
|
288 |
+
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
289 |
+
elif norm_type == 'batch':
|
290 |
+
return torch.nn.SyncBatchNorm(in_channels)
|
291 |
+
|
292 |
+
|
293 |
+
class Encoder(nn.Module):
|
294 |
+
def __init__(self, n_hiddens, downsample, image_channel=3, norm_type='group', padding_type='replicate', num_groups=32):
|
295 |
+
super().__init__()
|
296 |
+
n_times_downsample = np.array([int(math.log2(d)) for d in downsample])
|
297 |
+
self.conv_blocks = nn.ModuleList()
|
298 |
+
max_ds = n_times_downsample.max()
|
299 |
+
|
300 |
+
self.conv_first = SamePadConv3d(
|
301 |
+
image_channel, n_hiddens, kernel_size=3, padding_type=padding_type)
|
302 |
+
|
303 |
+
for i in range(max_ds):
|
304 |
+
block = nn.Module()
|
305 |
+
in_channels = n_hiddens * 2**i
|
306 |
+
out_channels = n_hiddens * 2**(i+1)
|
307 |
+
stride = tuple([2 if d > 0 else 1 for d in n_times_downsample])
|
308 |
+
block.down = SamePadConv3d(
|
309 |
+
in_channels, out_channels, 4, stride=stride, padding_type=padding_type)
|
310 |
+
block.res = ResBlock(
|
311 |
+
out_channels, out_channels, norm_type=norm_type, num_groups=num_groups)
|
312 |
+
self.conv_blocks.append(block)
|
313 |
+
n_times_downsample -= 1
|
314 |
+
|
315 |
+
self.final_block = nn.Sequential(
|
316 |
+
Normalize(out_channels, norm_type, num_groups=num_groups),
|
317 |
+
SiLU()
|
318 |
+
)
|
319 |
+
|
320 |
+
self.out_channels = out_channels
|
321 |
+
|
322 |
+
def forward(self, x):
|
323 |
+
h = self.conv_first(x)
|
324 |
+
for block in self.conv_blocks:
|
325 |
+
h = block.down(h)
|
326 |
+
h = block.res(h)
|
327 |
+
h = self.final_block(h)
|
328 |
+
return h
|
329 |
+
|
330 |
+
|
331 |
+
class Decoder(nn.Module):
|
332 |
+
def __init__(self, n_hiddens, upsample, image_channel, norm_type='group', num_groups=32):
|
333 |
+
super().__init__()
|
334 |
+
|
335 |
+
n_times_upsample = np.array([int(math.log2(d)) for d in upsample])
|
336 |
+
max_us = n_times_upsample.max()
|
337 |
+
|
338 |
+
in_channels = n_hiddens*2**max_us
|
339 |
+
self.final_block = nn.Sequential(
|
340 |
+
Normalize(in_channels, norm_type, num_groups=num_groups),
|
341 |
+
SiLU()
|
342 |
+
)
|
343 |
+
|
344 |
+
self.conv_blocks = nn.ModuleList()
|
345 |
+
for i in range(max_us):
|
346 |
+
block = nn.Module()
|
347 |
+
in_channels = in_channels if i == 0 else n_hiddens*2**(max_us-i+1)
|
348 |
+
out_channels = n_hiddens*2**(max_us-i)
|
349 |
+
us = tuple([2 if d > 0 else 1 for d in n_times_upsample])
|
350 |
+
block.up = SamePadConvTranspose3d(
|
351 |
+
in_channels, out_channels, 4, stride=us)
|
352 |
+
block.res1 = ResBlock(
|
353 |
+
out_channels, out_channels, norm_type=norm_type, num_groups=num_groups)
|
354 |
+
block.res2 = ResBlock(
|
355 |
+
out_channels, out_channels, norm_type=norm_type, num_groups=num_groups)
|
356 |
+
self.conv_blocks.append(block)
|
357 |
+
n_times_upsample -= 1
|
358 |
+
|
359 |
+
self.conv_last = SamePadConv3d(
|
360 |
+
out_channels, image_channel, kernel_size=3)
|
361 |
+
|
362 |
+
def forward(self, x):
|
363 |
+
h = self.final_block(x)
|
364 |
+
for i, block in enumerate(self.conv_blocks):
|
365 |
+
h = block.up(h)
|
366 |
+
h = block.res1(h)
|
367 |
+
h = block.res2(h)
|
368 |
+
h = self.conv_last(h)
|
369 |
+
return h
|
370 |
+
|
371 |
+
|
372 |
+
class ResBlock(nn.Module):
|
373 |
+
def __init__(self, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, norm_type='group', padding_type='replicate', num_groups=32):
|
374 |
+
super().__init__()
|
375 |
+
self.in_channels = in_channels
|
376 |
+
out_channels = in_channels if out_channels is None else out_channels
|
377 |
+
self.out_channels = out_channels
|
378 |
+
self.use_conv_shortcut = conv_shortcut
|
379 |
+
|
380 |
+
self.norm1 = Normalize(in_channels, norm_type, num_groups=num_groups)
|
381 |
+
self.conv1 = SamePadConv3d(
|
382 |
+
in_channels, out_channels, kernel_size=3, padding_type=padding_type)
|
383 |
+
self.dropout = torch.nn.Dropout(dropout)
|
384 |
+
self.norm2 = Normalize(in_channels, norm_type, num_groups=num_groups)
|
385 |
+
self.conv2 = SamePadConv3d(
|
386 |
+
out_channels, out_channels, kernel_size=3, padding_type=padding_type)
|
387 |
+
if self.in_channels != self.out_channels:
|
388 |
+
self.conv_shortcut = SamePadConv3d(
|
389 |
+
in_channels, out_channels, kernel_size=3, padding_type=padding_type)
|
390 |
+
|
391 |
+
def forward(self, x):
|
392 |
+
h = x
|
393 |
+
h = self.norm1(h)
|
394 |
+
h = silu(h)
|
395 |
+
h = self.conv1(h)
|
396 |
+
h = self.norm2(h)
|
397 |
+
h = silu(h)
|
398 |
+
h = self.conv2(h)
|
399 |
+
|
400 |
+
if self.in_channels != self.out_channels:
|
401 |
+
x = self.conv_shortcut(x)
|
402 |
+
|
403 |
+
return x+h
|
404 |
+
|
405 |
+
|
406 |
+
# Does not support dilation
|
407 |
+
class SamePadConv3d(nn.Module):
|
408 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'):
|
409 |
+
super().__init__()
|
410 |
+
if isinstance(kernel_size, int):
|
411 |
+
kernel_size = (kernel_size,) * 3
|
412 |
+
if isinstance(stride, int):
|
413 |
+
stride = (stride,) * 3
|
414 |
+
|
415 |
+
# assumes that the input shape is divisible by stride
|
416 |
+
total_pad = tuple([k - s for k, s in zip(kernel_size, stride)])
|
417 |
+
pad_input = []
|
418 |
+
for p in total_pad[::-1]: # reverse since F.pad starts from last dim
|
419 |
+
pad_input.append((p // 2 + p % 2, p // 2))
|
420 |
+
pad_input = sum(pad_input, tuple())
|
421 |
+
self.pad_input = pad_input
|
422 |
+
self.padding_type = padding_type
|
423 |
+
|
424 |
+
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size,
|
425 |
+
stride=stride, padding=0, bias=bias)
|
426 |
+
|
427 |
+
def forward(self, x):
|
428 |
+
return self.conv(F.pad(x, self.pad_input, mode=self.padding_type))
|
429 |
+
|
430 |
+
|
431 |
+
class SamePadConvTranspose3d(nn.Module):
|
432 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'):
|
433 |
+
super().__init__()
|
434 |
+
if isinstance(kernel_size, int):
|
435 |
+
kernel_size = (kernel_size,) * 3
|
436 |
+
if isinstance(stride, int):
|
437 |
+
stride = (stride,) * 3
|
438 |
+
|
439 |
+
total_pad = tuple([k - s for k, s in zip(kernel_size, stride)])
|
440 |
+
pad_input = []
|
441 |
+
for p in total_pad[::-1]: # reverse since F.pad starts from last dim
|
442 |
+
pad_input.append((p // 2 + p % 2, p // 2))
|
443 |
+
pad_input = sum(pad_input, tuple())
|
444 |
+
self.pad_input = pad_input
|
445 |
+
self.padding_type = padding_type
|
446 |
+
|
447 |
+
self.convt = nn.ConvTranspose3d(in_channels, out_channels, kernel_size,
|
448 |
+
stride=stride, bias=bias,
|
449 |
+
padding=tuple([k - 1 for k in kernel_size]))
|
450 |
+
|
451 |
+
def forward(self, x):
|
452 |
+
return self.convt(F.pad(x, self.pad_input, mode=self.padding_type))
|
453 |
+
|
454 |
+
|
455 |
+
class NLayerDiscriminator(nn.Module):
|
456 |
+
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.SyncBatchNorm, use_sigmoid=False, getIntermFeat=True):
|
457 |
+
# def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=True):
|
458 |
+
super(NLayerDiscriminator, self).__init__()
|
459 |
+
self.getIntermFeat = getIntermFeat
|
460 |
+
self.n_layers = n_layers
|
461 |
+
|
462 |
+
kw = 4
|
463 |
+
padw = int(np.ceil((kw-1.0)/2))
|
464 |
+
sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw,
|
465 |
+
stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]
|
466 |
+
|
467 |
+
nf = ndf
|
468 |
+
for n in range(1, n_layers):
|
469 |
+
nf_prev = nf
|
470 |
+
nf = min(nf * 2, 512)
|
471 |
+
sequence += [[
|
472 |
+
nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
|
473 |
+
norm_layer(nf), nn.LeakyReLU(0.2, True)
|
474 |
+
]]
|
475 |
+
|
476 |
+
nf_prev = nf
|
477 |
+
nf = min(nf * 2, 512)
|
478 |
+
sequence += [[
|
479 |
+
nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
|
480 |
+
norm_layer(nf),
|
481 |
+
nn.LeakyReLU(0.2, True)
|
482 |
+
]]
|
483 |
+
|
484 |
+
sequence += [[nn.Conv2d(nf, 1, kernel_size=kw,
|
485 |
+
stride=1, padding=padw)]]
|
486 |
+
|
487 |
+
if use_sigmoid:
|
488 |
+
sequence += [[nn.Sigmoid()]]
|
489 |
+
|
490 |
+
if getIntermFeat:
|
491 |
+
for n in range(len(sequence)):
|
492 |
+
setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
|
493 |
+
else:
|
494 |
+
sequence_stream = []
|
495 |
+
for n in range(len(sequence)):
|
496 |
+
sequence_stream += sequence[n]
|
497 |
+
self.model = nn.Sequential(*sequence_stream)
|
498 |
+
|
499 |
+
def forward(self, input):
|
500 |
+
if self.getIntermFeat:
|
501 |
+
res = [input]
|
502 |
+
for n in range(self.n_layers+2):
|
503 |
+
model = getattr(self, 'model'+str(n))
|
504 |
+
res.append(model(res[-1]))
|
505 |
+
return res[-1], res[1:]
|
506 |
+
else:
|
507 |
+
return self.model(input), _
|
508 |
+
|
509 |
+
|
510 |
+
class NLayerDiscriminator3D(nn.Module):
|
511 |
+
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.SyncBatchNorm, use_sigmoid=False, getIntermFeat=True):
|
512 |
+
super(NLayerDiscriminator3D, self).__init__()
|
513 |
+
self.getIntermFeat = getIntermFeat
|
514 |
+
self.n_layers = n_layers
|
515 |
+
|
516 |
+
kw = 4
|
517 |
+
padw = int(np.ceil((kw-1.0)/2))
|
518 |
+
sequence = [[nn.Conv3d(input_nc, ndf, kernel_size=kw,
|
519 |
+
stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]
|
520 |
+
|
521 |
+
nf = ndf
|
522 |
+
for n in range(1, n_layers):
|
523 |
+
nf_prev = nf
|
524 |
+
nf = min(nf * 2, 512)
|
525 |
+
sequence += [[
|
526 |
+
nn.Conv3d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
|
527 |
+
norm_layer(nf), nn.LeakyReLU(0.2, True)
|
528 |
+
]]
|
529 |
+
|
530 |
+
nf_prev = nf
|
531 |
+
nf = min(nf * 2, 512)
|
532 |
+
sequence += [[
|
533 |
+
nn.Conv3d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
|
534 |
+
norm_layer(nf),
|
535 |
+
nn.LeakyReLU(0.2, True)
|
536 |
+
]]
|
537 |
+
|
538 |
+
sequence += [[nn.Conv3d(nf, 1, kernel_size=kw,
|
539 |
+
stride=1, padding=padw)]]
|
540 |
+
|
541 |
+
if use_sigmoid:
|
542 |
+
sequence += [[nn.Sigmoid()]]
|
543 |
+
|
544 |
+
if getIntermFeat:
|
545 |
+
for n in range(len(sequence)):
|
546 |
+
setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
|
547 |
+
else:
|
548 |
+
sequence_stream = []
|
549 |
+
for n in range(len(sequence)):
|
550 |
+
sequence_stream += sequence[n]
|
551 |
+
self.model = nn.Sequential(*sequence_stream)
|
552 |
+
|
553 |
+
def forward(self, input):
|
554 |
+
if self.getIntermFeat:
|
555 |
+
res = [input]
|
556 |
+
for n in range(self.n_layers+2):
|
557 |
+
model = getattr(self, 'model'+str(n))
|
558 |
+
res.append(model(res[-1]))
|
559 |
+
return res[-1], res[1:]
|
560 |
+
else:
|
561 |
+
return self.model(input), _
|
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/vq_gan_3d/utils.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Adapted from https://github.com/SongweiGe/TATS"""
|
2 |
+
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
|
3 |
+
|
4 |
+
import warnings
|
5 |
+
import torch
|
6 |
+
import imageio
|
7 |
+
|
8 |
+
import math
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
import sys
|
12 |
+
import pdb as pdb_original
|
13 |
+
import logging
|
14 |
+
|
15 |
+
import imageio.core.util
|
16 |
+
logging.getLogger("imageio_ffmpeg").setLevel(logging.ERROR)
|
17 |
+
|
18 |
+
|
19 |
+
class ForkedPdb(pdb_original.Pdb):
|
20 |
+
"""A Pdb subclass that may be used
|
21 |
+
from a forked multiprocessing child
|
22 |
+
|
23 |
+
"""
|
24 |
+
|
25 |
+
def interaction(self, *args, **kwargs):
|
26 |
+
_stdin = sys.stdin
|
27 |
+
try:
|
28 |
+
sys.stdin = open('/dev/stdin')
|
29 |
+
pdb_original.Pdb.interaction(self, *args, **kwargs)
|
30 |
+
finally:
|
31 |
+
sys.stdin = _stdin
|
32 |
+
|
33 |
+
|
34 |
+
# Shifts src_tf dim to dest dim
|
35 |
+
# i.e. shift_dim(x, 1, -1) would be (b, c, t, h, w) -> (b, t, h, w, c)
|
36 |
+
def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):
|
37 |
+
n_dims = len(x.shape)
|
38 |
+
if src_dim < 0:
|
39 |
+
src_dim = n_dims + src_dim
|
40 |
+
if dest_dim < 0:
|
41 |
+
dest_dim = n_dims + dest_dim
|
42 |
+
|
43 |
+
assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims
|
44 |
+
|
45 |
+
dims = list(range(n_dims))
|
46 |
+
del dims[src_dim]
|
47 |
+
|
48 |
+
permutation = []
|
49 |
+
ctr = 0
|
50 |
+
for i in range(n_dims):
|
51 |
+
if i == dest_dim:
|
52 |
+
permutation.append(src_dim)
|
53 |
+
else:
|
54 |
+
permutation.append(dims[ctr])
|
55 |
+
ctr += 1
|
56 |
+
x = x.permute(permutation)
|
57 |
+
if make_contiguous:
|
58 |
+
x = x.contiguous()
|
59 |
+
return x
|
60 |
+
|
61 |
+
|
62 |
+
# reshapes tensor start from dim i (inclusive)
|
63 |
+
# to dim j (exclusive) to the desired shape
|
64 |
+
# e.g. if x.shape = (b, thw, c) then
|
65 |
+
# view_range(x, 1, 2, (t, h, w)) returns
|
66 |
+
# x of shape (b, t, h, w, c)
|
67 |
+
def view_range(x, i, j, shape):
|
68 |
+
shape = tuple(shape)
|
69 |
+
|
70 |
+
n_dims = len(x.shape)
|
71 |
+
if i < 0:
|
72 |
+
i = n_dims + i
|
73 |
+
|
74 |
+
if j is None:
|
75 |
+
j = n_dims
|
76 |
+
elif j < 0:
|
77 |
+
j = n_dims + j
|
78 |
+
|
79 |
+
assert 0 <= i < j <= n_dims
|
80 |
+
|
81 |
+
x_shape = x.shape
|
82 |
+
target_shape = x_shape[:i] + shape + x_shape[j:]
|
83 |
+
return x.view(target_shape)
|
84 |
+
|
85 |
+
|
86 |
+
def accuracy(output, target, topk=(1,)):
|
87 |
+
"""Computes the accuracy over the k top predictions for the specified values of k"""
|
88 |
+
with torch.no_grad():
|
89 |
+
maxk = max(topk)
|
90 |
+
batch_size = target.size(0)
|
91 |
+
|
92 |
+
_, pred = output.topk(maxk, 1, True, True)
|
93 |
+
pred = pred.t()
|
94 |
+
correct = pred.eq(target.reshape(1, -1).expand_as(pred))
|
95 |
+
|
96 |
+
res = []
|
97 |
+
for k in topk:
|
98 |
+
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
|
99 |
+
res.append(correct_k.mul_(100.0 / batch_size))
|
100 |
+
return res
|
101 |
+
|
102 |
+
|
103 |
+
def tensor_slice(x, begin, size):
|
104 |
+
assert all([b >= 0 for b in begin])
|
105 |
+
size = [l - b if s == -1 else s
|
106 |
+
for s, b, l in zip(size, begin, x.shape)]
|
107 |
+
assert all([s >= 0 for s in size])
|
108 |
+
|
109 |
+
slices = [slice(b, b + s) for b, s in zip(begin, size)]
|
110 |
+
return x[slices]
|
111 |
+
|
112 |
+
|
113 |
+
def adopt_weight(global_step, threshold=0, value=0.):
|
114 |
+
weight = 1
|
115 |
+
if global_step < threshold:
|
116 |
+
weight = value
|
117 |
+
return weight
|
118 |
+
|
119 |
+
|
120 |
+
def save_video_grid(video, fname, nrow=None, fps=6):
|
121 |
+
b, c, t, h, w = video.shape
|
122 |
+
video = video.permute(0, 2, 3, 4, 1)
|
123 |
+
video = (video.cpu().numpy() * 255).astype('uint8')
|
124 |
+
if nrow is None:
|
125 |
+
nrow = math.ceil(math.sqrt(b))
|
126 |
+
ncol = math.ceil(b / nrow)
|
127 |
+
padding = 1
|
128 |
+
video_grid = np.zeros((t, (padding + h) * nrow + padding,
|
129 |
+
(padding + w) * ncol + padding, c), dtype='uint8')
|
130 |
+
for i in range(b):
|
131 |
+
r = i // ncol
|
132 |
+
c = i % ncol
|
133 |
+
start_r = (padding + h) * r
|
134 |
+
start_c = (padding + w) * c
|
135 |
+
video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i]
|
136 |
+
video = []
|
137 |
+
for i in range(t):
|
138 |
+
video.append(video_grid[i])
|
139 |
+
imageio.mimsave(fname, video, fps=fps)
|
140 |
+
## skvideo.io.vwrite(fname, video_grid, inputdict={'-r': '5'})
|
141 |
+
#print('saved videos to', fname)
|
142 |
+
|
143 |
+
|
144 |
+
def comp_getattr(args, attr_name, default=None):
|
145 |
+
if hasattr(args, attr_name):
|
146 |
+
return getattr(args, attr_name)
|
147 |
+
else:
|
148 |
+
return default
|
149 |
+
|
150 |
+
|
151 |
+
def visualize_tensors(t, name=None, nest=0):
|
152 |
+
if name is not None:
|
153 |
+
print(name, "current nest: ", nest)
|
154 |
+
print("type: ", type(t))
|
155 |
+
if 'dict' in str(type(t)):
|
156 |
+
print(t.keys())
|
157 |
+
for k in t.keys():
|
158 |
+
if t[k] is None:
|
159 |
+
print(k, "None")
|
160 |
+
else:
|
161 |
+
if 'Tensor' in str(type(t[k])):
|
162 |
+
print(k, t[k].shape)
|
163 |
+
elif 'dict' in str(type(t[k])):
|
164 |
+
print(k, 'dict')
|
165 |
+
visualize_tensors(t[k], name, nest + 1)
|
166 |
+
elif 'list' in str(type(t[k])):
|
167 |
+
print(k, len(t[k]))
|
168 |
+
visualize_tensors(t[k], name, nest + 1)
|
169 |
+
elif 'list' in str(type(t)):
|
170 |
+
print("list length: ", len(t))
|
171 |
+
for t2 in t:
|
172 |
+
visualize_tensors(t2, name, nest + 1)
|
173 |
+
elif 'Tensor' in str(type(t)):
|
174 |
+
print(t.shape)
|
175 |
+
else:
|
176 |
+
print(t)
|
177 |
+
return ""
|
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/model_weight/diff_kidney_fold0_early.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d889b3561803f7490f4050c03a02163f099633e4f00fea4cb10b5b993685e5cc
|
3 |
+
size 290138333
|
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/model_weight/diff_kidney_fold0_noearly_t200.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:26bc7847ae15377a5586535cbb2e6a1ec5b6a98732f7f795c284d7dcda208c97
|
3 |
+
size 290156765
|
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/model_weight/diff_liver_fold0_early.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0135000f031f741252b3e706748b674d33e7278402a7cb2500fec5f4966847bd
|
3 |
+
size 290138333
|
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/model_weight/diff_liver_fold0_noearly_t200.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c2a21980e53efd6758ae92e79a82668f0e1e6d9b52fdf6b2a709cb929ebedb3b
|
3 |
+
size 290156765
|
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/model_weight/diff_pancreas_fold0_early.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9438e39a44af92bb0fbaf5cc50a3ac3aaa260978a69ac341ed7ec23512c080a5
|
3 |
+
size 290138333
|
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/model_weight/diff_pancreas_fold0_noearly_t200.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fb37011840156f548fd2348dbb5578f9bc81de16719ec226fbef2de6f0244f9d
|
3 |
+
size 290156765
|
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/model_weight/recon_96d4_all.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ef88523af9590a7325bc9ca41999de191c3fbc41afc6186a8c4db5528446bb1f
|
3 |
+
size 242615727
|
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/utils.py
ADDED
@@ -0,0 +1,465 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### Tumor Generateion
|
2 |
+
import random
|
3 |
+
import cv2
|
4 |
+
import elasticdeform
|
5 |
+
import numpy as np
|
6 |
+
from scipy.ndimage import gaussian_filter
|
7 |
+
from TumorGeneration.ldm.ddpm.ddim import DDIMSampler
|
8 |
+
|
9 |
+
# Step 1: Random select (numbers) location for tumor.
|
10 |
+
def random_select(mask_scan, organ_type):
|
11 |
+
# we first find z index and then sample point with z slice
|
12 |
+
# print('mask_scan',np.unique(mask_scan))
|
13 |
+
# print('pixel num', (mask_scan == 1).sum())
|
14 |
+
z_start, z_end = np.where(np.any(mask_scan, axis=(0, 1)))[0].min(), np.where(np.any(mask_scan, axis=(0, 1)))[0].max()
|
15 |
+
# print('z_start, z_end',z_start, z_end)
|
16 |
+
# we need to strict number z's position (0.3 - 0.7 in the middle of liver)
|
17 |
+
while 1:
|
18 |
+
z = round(random.uniform(0.3, 0.7) * (z_end - z_start)) + z_start
|
19 |
+
liver_mask = mask_scan[..., z]
|
20 |
+
# erode the mask (we don't want the edge points)
|
21 |
+
if organ_type == 'liver':
|
22 |
+
kernel = np.ones((5,5), dtype=np.uint8)
|
23 |
+
liver_mask = cv2.erode(liver_mask, kernel, iterations=1)
|
24 |
+
if (liver_mask == 1).sum() > 0:
|
25 |
+
break
|
26 |
+
|
27 |
+
|
28 |
+
|
29 |
+
# print('liver_mask', (liver_mask == 1).sum())
|
30 |
+
coordinates = np.argwhere(liver_mask == 1)
|
31 |
+
random_index = np.random.randint(0, len(coordinates))
|
32 |
+
xyz = coordinates[random_index].tolist() # get x,y
|
33 |
+
xyz.append(z)
|
34 |
+
potential_points = xyz
|
35 |
+
|
36 |
+
return potential_points
|
37 |
+
|
38 |
+
def center_select(mask_scan):
|
39 |
+
z_start, z_end = np.where(np.any(mask_scan, axis=(0, 1)))[0].min(), np.where(np.any(mask_scan, axis=(0, 1)))[0].max()
|
40 |
+
x_start, x_end = np.where(np.any(mask_scan, axis=(1, 2)))[0].min(), np.where(np.any(mask_scan, axis=(1, 2)))[0].max()
|
41 |
+
y_start, y_end = np.where(np.any(mask_scan, axis=(0, 2)))[0].min(), np.where(np.any(mask_scan, axis=(0, 2)))[0].max()
|
42 |
+
|
43 |
+
z = round(0.5 * (z_end - z_start)) + z_start
|
44 |
+
x = round(0.5 * (x_end - x_start)) + x_start
|
45 |
+
y = round(0.5 * (y_end - y_start)) + y_start
|
46 |
+
|
47 |
+
xyz = [x, y, z]
|
48 |
+
potential_points = xyz
|
49 |
+
|
50 |
+
return potential_points
|
51 |
+
|
52 |
+
# Step 2 : generate the ellipsoid
|
53 |
+
def get_ellipsoid(x, y, z):
|
54 |
+
""""
|
55 |
+
x, y, z is the radius of this ellipsoid in x, y, z direction respectly.
|
56 |
+
"""
|
57 |
+
sh = (4*x, 4*y, 4*z)
|
58 |
+
out = np.zeros(sh, int)
|
59 |
+
aux = np.zeros(sh)
|
60 |
+
radii = np.array([x, y, z])
|
61 |
+
com = np.array([2*x, 2*y, 2*z]) # center point
|
62 |
+
|
63 |
+
# calculate the ellipsoid
|
64 |
+
bboxl = np.floor(com-radii).clip(0,None).astype(int)
|
65 |
+
bboxh = (np.ceil(com+radii)+1).clip(None, sh).astype(int)
|
66 |
+
roi = out[tuple(map(slice,bboxl,bboxh))]
|
67 |
+
roiaux = aux[tuple(map(slice,bboxl,bboxh))]
|
68 |
+
logrid = *map(np.square,np.ogrid[tuple(
|
69 |
+
map(slice,(bboxl-com)/radii,(bboxh-com-1)/radii,1j*(bboxh-bboxl)))]),
|
70 |
+
dst = (1-sum(logrid)).clip(0,None)
|
71 |
+
mask = dst>roiaux
|
72 |
+
roi[mask] = 1
|
73 |
+
np.copyto(roiaux,dst,where=mask)
|
74 |
+
|
75 |
+
return out
|
76 |
+
|
77 |
+
def get_fixed_geo(mask_scan, tumor_type, organ_type):
|
78 |
+
if tumor_type == 'large':
|
79 |
+
enlarge_x, enlarge_y, enlarge_z = 280, 280, 280
|
80 |
+
else:
|
81 |
+
enlarge_x, enlarge_y, enlarge_z = 160, 160, 160
|
82 |
+
geo_mask = np.zeros((mask_scan.shape[0] + enlarge_x, mask_scan.shape[1] + enlarge_y, mask_scan.shape[2] + enlarge_z), dtype=np.int8)
|
83 |
+
tiny_radius, small_radius, medium_radius, large_radius = 4, 8, 16, 32
|
84 |
+
|
85 |
+
if tumor_type == 'tiny':
|
86 |
+
num_tumor = random.randint(1,3)
|
87 |
+
# num_tumor = 1
|
88 |
+
for _ in range(num_tumor):
|
89 |
+
# Tiny tumor
|
90 |
+
x = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius))
|
91 |
+
y = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius))
|
92 |
+
z = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius))
|
93 |
+
sigma = random.uniform(0.5,1)
|
94 |
+
|
95 |
+
geo = get_ellipsoid(x, y, z)
|
96 |
+
geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1))
|
97 |
+
geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2))
|
98 |
+
geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2))
|
99 |
+
point = random_select(mask_scan, organ_type)
|
100 |
+
new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2]
|
101 |
+
x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2
|
102 |
+
y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2
|
103 |
+
z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2
|
104 |
+
|
105 |
+
# paste small tumor geo into test sample
|
106 |
+
geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo
|
107 |
+
|
108 |
+
if tumor_type == 'small':
|
109 |
+
num_tumor = random.randint(1,3)
|
110 |
+
# num_tumor = 1
|
111 |
+
for _ in range(num_tumor):
|
112 |
+
# Small tumor
|
113 |
+
x = random.randint(int(0.75*small_radius), int(1.25*small_radius))
|
114 |
+
y = random.randint(int(0.75*small_radius), int(1.25*small_radius))
|
115 |
+
z = random.randint(int(0.75*small_radius), int(1.25*small_radius))
|
116 |
+
sigma = random.randint(1, 2)
|
117 |
+
|
118 |
+
geo = get_ellipsoid(x, y, z)
|
119 |
+
geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1))
|
120 |
+
geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2))
|
121 |
+
geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2))
|
122 |
+
# texture = get_texture((4*x, 4*y, 4*z))
|
123 |
+
point = random_select(mask_scan, organ_type)
|
124 |
+
new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2]
|
125 |
+
x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2
|
126 |
+
y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2
|
127 |
+
z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2
|
128 |
+
|
129 |
+
# paste small tumor geo into test sample
|
130 |
+
geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo
|
131 |
+
# texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture
|
132 |
+
|
133 |
+
if tumor_type == 'medium':
|
134 |
+
# num_tumor = random.randint(1, 3)
|
135 |
+
num_tumor = 1
|
136 |
+
for _ in range(num_tumor):
|
137 |
+
# medium tumor
|
138 |
+
x = random.randint(int(0.75*medium_radius), int(1.25*medium_radius))
|
139 |
+
y = random.randint(int(0.75*medium_radius), int(1.25*medium_radius))
|
140 |
+
z = random.randint(int(0.75*medium_radius), int(1.25*medium_radius))
|
141 |
+
sigma = random.randint(3, 6)
|
142 |
+
|
143 |
+
geo = get_ellipsoid(x, y, z)
|
144 |
+
geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1))
|
145 |
+
geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2))
|
146 |
+
geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2))
|
147 |
+
# texture = get_texture((4*x, 4*y, 4*z))
|
148 |
+
point = random_select(mask_scan, organ_type)
|
149 |
+
new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2]
|
150 |
+
x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2
|
151 |
+
y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2
|
152 |
+
z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2
|
153 |
+
|
154 |
+
# paste medium tumor geo into test sample
|
155 |
+
geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo
|
156 |
+
# texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture
|
157 |
+
|
158 |
+
if tumor_type == 'large':
|
159 |
+
num_tumor = 1 # random.randint(1,3)
|
160 |
+
for _ in range(num_tumor):
|
161 |
+
# Large tumor
|
162 |
+
|
163 |
+
x = random.randint(int(0.75*large_radius), int(1.25*large_radius))
|
164 |
+
y = random.randint(int(0.75*large_radius), int(1.25*large_radius))
|
165 |
+
z = random.randint(int(0.75*large_radius), int(1.25*large_radius))
|
166 |
+
sigma = random.randint(5, 10)
|
167 |
+
|
168 |
+
geo = get_ellipsoid(x, y, z)
|
169 |
+
geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1))
|
170 |
+
geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2))
|
171 |
+
geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2))
|
172 |
+
if organ_type == 'liver' or organ_type == 'kidney' :
|
173 |
+
point = random_select(mask_scan, organ_type)
|
174 |
+
new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2]
|
175 |
+
x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2
|
176 |
+
y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2
|
177 |
+
z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2
|
178 |
+
else:
|
179 |
+
x_start, x_end = np.where(np.any(geo, axis=(1, 2)))[0].min(), np.where(np.any(geo, axis=(1, 2)))[0].max()
|
180 |
+
y_start, y_end = np.where(np.any(geo, axis=(0, 2)))[0].min(), np.where(np.any(geo, axis=(0, 2)))[0].max()
|
181 |
+
z_start, z_end = np.where(np.any(geo, axis=(0, 1)))[0].min(), np.where(np.any(geo, axis=(0, 1)))[0].max()
|
182 |
+
geo = geo[x_start:x_end, y_start:y_end, z_start:z_end]
|
183 |
+
|
184 |
+
point = center_select(mask_scan)
|
185 |
+
|
186 |
+
new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2]
|
187 |
+
x_low = new_point[0] - geo.shape[0]//2
|
188 |
+
y_low = new_point[1] - geo.shape[1]//2
|
189 |
+
z_low = new_point[2] - geo.shape[2]//2
|
190 |
+
|
191 |
+
# paste small tumor geo into test sample
|
192 |
+
geo_mask[x_low:x_low+geo.shape[0], y_low:y_low+geo.shape[1], z_low:z_low+geo.shape[2]] += geo
|
193 |
+
|
194 |
+
geo_mask = geo_mask[enlarge_x//2:-enlarge_x//2, enlarge_y//2:-enlarge_y//2, enlarge_z//2:-enlarge_z//2]
|
195 |
+
# texture_map = texture_map[enlarge_x//2:-enlarge_x//2, enlarge_y//2:-enlarge_y//2, enlarge_z//2:-enlarge_z//2]
|
196 |
+
if ((tumor_type == 'medium') or (tumor_type == 'large')) and (organ_type == 'kidney'):
|
197 |
+
if random.random() > 0.5:
|
198 |
+
geo_mask = (geo_mask>=1)
|
199 |
+
else:
|
200 |
+
geo_mask = (geo_mask * mask_scan) >=1
|
201 |
+
else:
|
202 |
+
geo_mask = (geo_mask * mask_scan) >=1
|
203 |
+
|
204 |
+
return geo_mask
|
205 |
+
|
206 |
+
|
207 |
+
from .ldm.vq_gan_3d.model.vqgan import VQGAN
|
208 |
+
import matplotlib.pyplot as plt
|
209 |
+
import SimpleITK as sitk
|
210 |
+
from .ldm.ddpm import Unet3D, GaussianDiffusion, Tester
|
211 |
+
from hydra import initialize, compose
|
212 |
+
import torch
|
213 |
+
import yaml
|
214 |
+
def synt_model_prepare(device, vqgan_ckpt='TumorGeneration/model_weight/recon_96d4_all.ckpt', diffusion_ckpt='TumorGeneration/model_weight/', fold=0, organ='liver'):
|
215 |
+
with initialize(config_path="diffusion_config/"):
|
216 |
+
cfg=compose(config_name="ddpm.yaml")
|
217 |
+
print('diffusion_ckpt',diffusion_ckpt)
|
218 |
+
vqgan = VQGAN.load_from_checkpoint(vqgan_ckpt)
|
219 |
+
vqgan = vqgan.to(device)
|
220 |
+
vqgan.eval()
|
221 |
+
|
222 |
+
early_Unet3D = Unet3D(
|
223 |
+
dim=cfg.diffusion_img_size,
|
224 |
+
dim_mults=cfg.dim_mults,
|
225 |
+
channels=cfg.diffusion_num_channels,
|
226 |
+
out_dim=cfg.out_dim
|
227 |
+
).to(device)
|
228 |
+
|
229 |
+
early_diffusion = GaussianDiffusion(
|
230 |
+
early_Unet3D,
|
231 |
+
vqgan_ckpt= vqgan_ckpt, # cfg.vqgan_ckpt,
|
232 |
+
image_size=cfg.diffusion_img_size,
|
233 |
+
num_frames=cfg.diffusion_depth_size,
|
234 |
+
channels=cfg.diffusion_num_channels,
|
235 |
+
timesteps=4, # cfg.timesteps,
|
236 |
+
loss_type=cfg.loss_type,
|
237 |
+
device=device
|
238 |
+
).to(device)
|
239 |
+
|
240 |
+
noearly_Unet3D = Unet3D(
|
241 |
+
dim=cfg.diffusion_img_size,
|
242 |
+
dim_mults=cfg.dim_mults,
|
243 |
+
channels=cfg.diffusion_num_channels,
|
244 |
+
out_dim=cfg.out_dim
|
245 |
+
).to(device)
|
246 |
+
|
247 |
+
noearly_diffusion = GaussianDiffusion(
|
248 |
+
noearly_Unet3D,
|
249 |
+
vqgan_ckpt= vqgan_ckpt, # cfg.vqgan_ckpt,
|
250 |
+
image_size=cfg.diffusion_img_size,
|
251 |
+
num_frames=cfg.diffusion_depth_size,
|
252 |
+
channels=cfg.diffusion_num_channels,
|
253 |
+
timesteps=200, # cfg.timesteps,
|
254 |
+
loss_type=cfg.loss_type,
|
255 |
+
device=device
|
256 |
+
).to(device)
|
257 |
+
|
258 |
+
early_tester = Tester(early_diffusion)
|
259 |
+
# noearly_tester = Tester(noearly_diffusion)
|
260 |
+
early_tester.load(diffusion_ckpt+'diff_{}_fold{}_early.pt'.format(organ, fold), map_location=device)
|
261 |
+
# noearly_tester.load(diffusion_ckpt+'diff_liver_fold{}_noearly_t200.pt'.format(fold), map_location=device)
|
262 |
+
|
263 |
+
# early_checkpoint = torch.load(diffusion_ckpt+'diff_liver_fold{}_early.pt'.format(fold), map_location=device)
|
264 |
+
noearly_checkpoint = torch.load(diffusion_ckpt+'diff_{}_fold{}_noearly_t200.pt'.format(organ, fold), map_location=device)
|
265 |
+
# early_diffusion.load_state_dict(early_checkpoint['ema'])
|
266 |
+
noearly_diffusion.load_state_dict(noearly_checkpoint['ema'])
|
267 |
+
# early_sampler = DDIMSampler(early_diffusion, schedule="cosine")
|
268 |
+
noearly_sampler = DDIMSampler(noearly_diffusion, schedule="cosine")
|
269 |
+
# breakpoint()
|
270 |
+
return vqgan, early_tester, noearly_sampler
|
271 |
+
|
272 |
+
def synthesize_early_tumor(ct_volume, organ_mask, organ_type, vqgan, tester):
|
273 |
+
device=ct_volume.device
|
274 |
+
|
275 |
+
# generate tumor mask
|
276 |
+
tumor_types = ['tiny', 'small']
|
277 |
+
# tumor_probs = np.array([0.5, 0.5])
|
278 |
+
tumor_probs = np.array([0.2, 0.8])
|
279 |
+
total_tumor_mask = []
|
280 |
+
organ_mask_np = organ_mask.cpu().numpy()
|
281 |
+
with torch.no_grad():
|
282 |
+
# get model input
|
283 |
+
for bs in range(organ_mask_np.shape[0]):
|
284 |
+
synthetic_tumor_type = np.random.choice(tumor_types, p=tumor_probs.ravel())
|
285 |
+
tumor_mask = get_fixed_geo(organ_mask_np[bs,0], synthetic_tumor_type, organ_type)
|
286 |
+
total_tumor_mask.append(torch.from_numpy(tumor_mask)[None,:])
|
287 |
+
total_tumor_mask = torch.stack(total_tumor_mask, dim=0).to(dtype=torch.float32, device=device)
|
288 |
+
|
289 |
+
volume = ct_volume*2.0 - 1.0
|
290 |
+
mask = total_tumor_mask*2.0 - 1.0
|
291 |
+
mask_ = 1-total_tumor_mask
|
292 |
+
masked_volume = (volume*mask_).detach()
|
293 |
+
|
294 |
+
volume = volume.permute(0,1,-1,-3,-2)
|
295 |
+
masked_volume = masked_volume.permute(0,1,-1,-3,-2)
|
296 |
+
mask = mask.permute(0,1,-1,-3,-2)
|
297 |
+
|
298 |
+
# vqgan encoder inference
|
299 |
+
masked_volume_feat = vqgan.encode(masked_volume, quantize=False, include_embeddings=True)
|
300 |
+
masked_volume_feat = ((masked_volume_feat - vqgan.codebook.embeddings.min()) /
|
301 |
+
(vqgan.codebook.embeddings.max() - vqgan.codebook.embeddings.min())) * 2.0 - 1.0
|
302 |
+
|
303 |
+
cc = torch.nn.functional.interpolate(mask, size=masked_volume_feat.shape[-3:])
|
304 |
+
cond = torch.cat((masked_volume_feat, cc), dim=1)
|
305 |
+
|
306 |
+
# diffusion inference and decoder
|
307 |
+
tester.ema_model.eval()
|
308 |
+
sample = tester.ema_model.sample(batch_size=volume.shape[0], cond=cond)
|
309 |
+
|
310 |
+
# if organ_type == 'liver' or organ_type == 'kidney' :
|
311 |
+
|
312 |
+
mask_01 = torch.clamp((mask+1.0)/2.0, min=0.0, max=1.0)
|
313 |
+
sigma = np.random.uniform(0, 4) # (1, 2)
|
314 |
+
mask_01_np_blur = gaussian_filter(mask_01.cpu().numpy()*1.0, sigma=[0,0,sigma,sigma,sigma])
|
315 |
+
# mask_01_np_blur = mask_01_np_blur*mask_01.cpu().numpy()
|
316 |
+
|
317 |
+
volume_ = torch.clamp((volume+1.0)/2.0, min=0.0, max=1.0)
|
318 |
+
sample_ = torch.clamp((sample+1.0)/2.0, min=0.0, max=1.0)
|
319 |
+
|
320 |
+
mask_01_blur = torch.from_numpy(mask_01_np_blur).to(device=device)
|
321 |
+
final_volume_ = (1-mask_01_blur)*volume_ +mask_01_blur*sample_
|
322 |
+
final_volume_ = torch.clamp(final_volume_, min=0.0, max=1.0)
|
323 |
+
# elif organ_type == 'pancreas':
|
324 |
+
# final_volume_ = (sample+1.0)/2.0
|
325 |
+
final_volume_ = final_volume_.permute(0,1,-2,-1,-3)
|
326 |
+
organ_tumor_mask = torch.ones_like(organ_mask)
|
327 |
+
organ_tumor_mask[total_tumor_mask==1] = 2
|
328 |
+
|
329 |
+
return final_volume_, organ_tumor_mask
|
330 |
+
|
331 |
+
def synthesize_medium_tumor(ct_volume, organ_mask, organ_type, vqgan, sampler, ddim_ts=50):
|
332 |
+
device=ct_volume.device
|
333 |
+
|
334 |
+
# generate tumor mask
|
335 |
+
# tumor_types = ['large']
|
336 |
+
# tumor_probs = np.array([1.0])
|
337 |
+
total_tumor_mask = []
|
338 |
+
organ_mask_np = organ_mask.cpu().numpy()
|
339 |
+
with torch.no_grad():
|
340 |
+
# get model input
|
341 |
+
for bs in range(organ_mask_np.shape[0]):
|
342 |
+
# synthetic_tumor_type = np.random.choice(tumor_types, p=tumor_probs.ravel())
|
343 |
+
synthetic_tumor_type = 'medium'
|
344 |
+
tumor_mask = get_fixed_geo(organ_mask_np[bs,0], synthetic_tumor_type, organ_type)
|
345 |
+
total_tumor_mask.append(torch.from_numpy(tumor_mask)[None,:])
|
346 |
+
total_tumor_mask = torch.stack(total_tumor_mask, dim=0).to(dtype=torch.float32, device=device)
|
347 |
+
|
348 |
+
volume = ct_volume*2.0 - 1.0
|
349 |
+
mask = total_tumor_mask*2.0 - 1.0
|
350 |
+
mask_ = 1-total_tumor_mask
|
351 |
+
masked_volume = (volume*mask_).detach()
|
352 |
+
|
353 |
+
volume = volume.permute(0,1,-1,-3,-2)
|
354 |
+
masked_volume = masked_volume.permute(0,1,-1,-3,-2)
|
355 |
+
mask = mask.permute(0,1,-1,-3,-2)
|
356 |
+
|
357 |
+
# vqgan encoder inference
|
358 |
+
masked_volume_feat = vqgan.encode(masked_volume, quantize=False, include_embeddings=True)
|
359 |
+
masked_volume_feat = ((masked_volume_feat - vqgan.codebook.embeddings.min()) /
|
360 |
+
(vqgan.codebook.embeddings.max() - vqgan.codebook.embeddings.min())) * 2.0 - 1.0
|
361 |
+
|
362 |
+
cc = torch.nn.functional.interpolate(mask, size=masked_volume_feat.shape[-3:])
|
363 |
+
cond = torch.cat((masked_volume_feat, cc), dim=1)
|
364 |
+
|
365 |
+
# diffusion inference and decoder
|
366 |
+
shape = masked_volume_feat.shape[-4:]
|
367 |
+
samples_ddim, _ = sampler.sample(S=ddim_ts,
|
368 |
+
conditioning=cond,
|
369 |
+
batch_size=1,
|
370 |
+
shape=shape,
|
371 |
+
verbose=False)
|
372 |
+
samples_ddim = (((samples_ddim + 1.0) / 2.0) * (vqgan.codebook.embeddings.max() -
|
373 |
+
vqgan.codebook.embeddings.min())) + vqgan.codebook.embeddings.min()
|
374 |
+
|
375 |
+
sample = vqgan.decode(samples_ddim, quantize=True)
|
376 |
+
|
377 |
+
# if organ_type == 'liver' or organ_type == 'kidney':
|
378 |
+
# post-process
|
379 |
+
mask_01 = torch.clamp((mask+1.0)/2.0, min=0.0, max=1.0)
|
380 |
+
sigma = np.random.uniform(0, 4) # (1, 2)
|
381 |
+
mask_01_np_blur = gaussian_filter(mask_01.cpu().numpy()*1.0, sigma=[0,0,sigma,sigma,sigma])
|
382 |
+
# mask_01_np_blur = mask_01_np_blur*mask_01.cpu().numpy()
|
383 |
+
|
384 |
+
volume_ = torch.clamp((volume+1.0)/2.0, min=0.0, max=1.0)
|
385 |
+
sample_ = torch.clamp((sample+1.0)/2.0, min=0.0, max=1.0)
|
386 |
+
|
387 |
+
mask_01_blur = torch.from_numpy(mask_01_np_blur).to(device=device)
|
388 |
+
final_volume_ = (1-mask_01_blur)*volume_ +mask_01_blur*sample_
|
389 |
+
final_volume_ = torch.clamp(final_volume_, min=0.0, max=1.0)
|
390 |
+
# elif organ_type == 'pancreas':
|
391 |
+
# final_volume_ = (sample+1.0)/2.0
|
392 |
+
|
393 |
+
final_volume_ = final_volume_.permute(0,1,-2,-1,-3)
|
394 |
+
organ_tumor_mask = torch.ones_like(organ_mask)
|
395 |
+
organ_tumor_mask[total_tumor_mask==1] = 2
|
396 |
+
|
397 |
+
return final_volume_, organ_tumor_mask
|
398 |
+
|
399 |
+
def synthesize_large_tumor(ct_volume, organ_mask, organ_type, vqgan, sampler, ddim_ts=50):
|
400 |
+
device=ct_volume.device
|
401 |
+
|
402 |
+
# generate tumor mask
|
403 |
+
# tumor_types = ['large']
|
404 |
+
# tumor_probs = np.array([1.0])
|
405 |
+
total_tumor_mask = []
|
406 |
+
organ_mask_np = organ_mask.cpu().numpy()
|
407 |
+
with torch.no_grad():
|
408 |
+
# get model input
|
409 |
+
for bs in range(organ_mask_np.shape[0]):
|
410 |
+
# synthetic_tumor_type = np.random.choice(tumor_types, p=tumor_probs.ravel())
|
411 |
+
synthetic_tumor_type = 'large'
|
412 |
+
tumor_mask = get_fixed_geo(organ_mask_np[bs,0], synthetic_tumor_type, organ_type)
|
413 |
+
total_tumor_mask.append(torch.from_numpy(tumor_mask)[None,:])
|
414 |
+
total_tumor_mask = torch.stack(total_tumor_mask, dim=0).to(dtype=torch.float32, device=device)
|
415 |
+
|
416 |
+
volume = ct_volume*2.0 - 1.0
|
417 |
+
mask = total_tumor_mask*2.0 - 1.0
|
418 |
+
mask_ = 1-total_tumor_mask
|
419 |
+
masked_volume = (volume*mask_).detach()
|
420 |
+
|
421 |
+
volume = volume.permute(0,1,-1,-3,-2)
|
422 |
+
masked_volume = masked_volume.permute(0,1,-1,-3,-2)
|
423 |
+
mask = mask.permute(0,1,-1,-3,-2)
|
424 |
+
|
425 |
+
# vqgan encoder inference
|
426 |
+
masked_volume_feat = vqgan.encode(masked_volume, quantize=False, include_embeddings=True)
|
427 |
+
masked_volume_feat = ((masked_volume_feat - vqgan.codebook.embeddings.min()) /
|
428 |
+
(vqgan.codebook.embeddings.max() - vqgan.codebook.embeddings.min())) * 2.0 - 1.0
|
429 |
+
|
430 |
+
cc = torch.nn.functional.interpolate(mask, size=masked_volume_feat.shape[-3:])
|
431 |
+
cond = torch.cat((masked_volume_feat, cc), dim=1)
|
432 |
+
|
433 |
+
# diffusion inference and decoder
|
434 |
+
shape = masked_volume_feat.shape[-4:]
|
435 |
+
samples_ddim, _ = sampler.sample(S=ddim_ts,
|
436 |
+
conditioning=cond,
|
437 |
+
batch_size=1,
|
438 |
+
shape=shape,
|
439 |
+
verbose=False)
|
440 |
+
samples_ddim = (((samples_ddim + 1.0) / 2.0) * (vqgan.codebook.embeddings.max() -
|
441 |
+
vqgan.codebook.embeddings.min())) + vqgan.codebook.embeddings.min()
|
442 |
+
|
443 |
+
sample = vqgan.decode(samples_ddim, quantize=True)
|
444 |
+
|
445 |
+
# if organ_type == 'liver' or organ_type == 'kidney':
|
446 |
+
# post-process
|
447 |
+
mask_01 = torch.clamp((mask+1.0)/2.0, min=0.0, max=1.0)
|
448 |
+
sigma = np.random.uniform(0, 4) # (1, 2)
|
449 |
+
mask_01_np_blur = gaussian_filter(mask_01.cpu().numpy()*1.0, sigma=[0,0,sigma,sigma,sigma])
|
450 |
+
# mask_01_np_blur = mask_01_np_blur*mask_01.cpu().numpy()
|
451 |
+
|
452 |
+
volume_ = torch.clamp((volume+1.0)/2.0, min=0.0, max=1.0)
|
453 |
+
sample_ = torch.clamp((sample+1.0)/2.0, min=0.0, max=1.0)
|
454 |
+
|
455 |
+
mask_01_blur = torch.from_numpy(mask_01_np_blur).to(device=device)
|
456 |
+
final_volume_ = (1-mask_01_blur)*volume_ +mask_01_blur*sample_
|
457 |
+
final_volume_ = torch.clamp(final_volume_, min=0.0, max=1.0)
|
458 |
+
# elif organ_type == 'pancreas':
|
459 |
+
# final_volume_ = (sample+1.0)/2.0
|
460 |
+
|
461 |
+
final_volume_ = final_volume_.permute(0,1,-2,-1,-3)
|
462 |
+
organ_tumor_mask = torch.ones_like(organ_mask)
|
463 |
+
organ_tumor_mask[total_tumor_mask==1] = 2
|
464 |
+
|
465 |
+
return final_volume_, organ_tumor_mask
|
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/utils_.py
ADDED
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### Tumor Generateion
|
2 |
+
import random
|
3 |
+
import cv2
|
4 |
+
import elasticdeform
|
5 |
+
import numpy as np
|
6 |
+
from scipy.ndimage import gaussian_filter
|
7 |
+
|
8 |
+
# Step 1: Random select (numbers) location for tumor.
|
9 |
+
def random_select(mask_scan):
|
10 |
+
# we first find z index and then sample point with z slice
|
11 |
+
z_start, z_end = np.where(np.any(mask_scan, axis=(0, 1)))[0][[0, -1]]
|
12 |
+
|
13 |
+
# we need to strict number z's position (0.3 - 0.7 in the middle of liver)
|
14 |
+
z = round(random.uniform(0.3, 0.7) * (z_end - z_start)) + z_start
|
15 |
+
|
16 |
+
liver_mask = mask_scan[..., z]
|
17 |
+
|
18 |
+
# erode the mask (we don't want the edge points)
|
19 |
+
kernel = np.ones((5,5), dtype=np.uint8)
|
20 |
+
liver_mask = cv2.erode(liver_mask, kernel, iterations=1)
|
21 |
+
|
22 |
+
coordinates = np.argwhere(liver_mask == 1)
|
23 |
+
random_index = np.random.randint(0, len(coordinates))
|
24 |
+
xyz = coordinates[random_index].tolist() # get x,y
|
25 |
+
xyz.append(z)
|
26 |
+
potential_points = xyz
|
27 |
+
|
28 |
+
return potential_points
|
29 |
+
|
30 |
+
# Step 2 : generate the ellipsoid
|
31 |
+
def get_ellipsoid(x, y, z):
|
32 |
+
""""
|
33 |
+
x, y, z is the radius of this ellipsoid in x, y, z direction respectly.
|
34 |
+
"""
|
35 |
+
sh = (4*x, 4*y, 4*z)
|
36 |
+
out = np.zeros(sh, int)
|
37 |
+
aux = np.zeros(sh)
|
38 |
+
radii = np.array([x, y, z])
|
39 |
+
com = np.array([2*x, 2*y, 2*z]) # center point
|
40 |
+
|
41 |
+
# calculate the ellipsoid
|
42 |
+
bboxl = np.floor(com-radii).clip(0,None).astype(int)
|
43 |
+
bboxh = (np.ceil(com+radii)+1).clip(None, sh).astype(int)
|
44 |
+
roi = out[tuple(map(slice,bboxl,bboxh))]
|
45 |
+
roiaux = aux[tuple(map(slice,bboxl,bboxh))]
|
46 |
+
logrid = *map(np.square,np.ogrid[tuple(
|
47 |
+
map(slice,(bboxl-com)/radii,(bboxh-com-1)/radii,1j*(bboxh-bboxl)))]),
|
48 |
+
dst = (1-sum(logrid)).clip(0,None)
|
49 |
+
mask = dst>roiaux
|
50 |
+
roi[mask] = 1
|
51 |
+
np.copyto(roiaux,dst,where=mask)
|
52 |
+
|
53 |
+
return out
|
54 |
+
|
55 |
+
def get_fixed_geo(mask_scan, tumor_type):
|
56 |
+
|
57 |
+
enlarge_x, enlarge_y, enlarge_z = 160, 160, 160
|
58 |
+
geo_mask = np.zeros((mask_scan.shape[0] + enlarge_x, mask_scan.shape[1] + enlarge_y, mask_scan.shape[2] + enlarge_z), dtype=np.int8)
|
59 |
+
tiny_radius, small_radius, medium_radius, large_radius = 4, 8, 16, 32
|
60 |
+
|
61 |
+
if tumor_type == 'tiny':
|
62 |
+
num_tumor = random.randint(3,10)
|
63 |
+
for _ in range(num_tumor):
|
64 |
+
# Tiny tumor
|
65 |
+
x = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius))
|
66 |
+
y = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius))
|
67 |
+
z = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius))
|
68 |
+
sigma = random.uniform(0.5,1)
|
69 |
+
|
70 |
+
geo = get_ellipsoid(x, y, z)
|
71 |
+
geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1))
|
72 |
+
geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2))
|
73 |
+
geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2))
|
74 |
+
point = random_select(mask_scan)
|
75 |
+
new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2]
|
76 |
+
x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2
|
77 |
+
y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2
|
78 |
+
z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2
|
79 |
+
|
80 |
+
# paste small tumor geo into test sample
|
81 |
+
geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo
|
82 |
+
|
83 |
+
if tumor_type == 'small':
|
84 |
+
num_tumor = random.randint(3,10)
|
85 |
+
for _ in range(num_tumor):
|
86 |
+
# Small tumor
|
87 |
+
x = random.randint(int(0.75*small_radius), int(1.25*small_radius))
|
88 |
+
y = random.randint(int(0.75*small_radius), int(1.25*small_radius))
|
89 |
+
z = random.randint(int(0.75*small_radius), int(1.25*small_radius))
|
90 |
+
sigma = random.randint(1, 2)
|
91 |
+
|
92 |
+
geo = get_ellipsoid(x, y, z)
|
93 |
+
geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1))
|
94 |
+
geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2))
|
95 |
+
geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2))
|
96 |
+
# texture = get_texture((4*x, 4*y, 4*z))
|
97 |
+
point = random_select(mask_scan)
|
98 |
+
new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2]
|
99 |
+
x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2
|
100 |
+
y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2
|
101 |
+
z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2
|
102 |
+
|
103 |
+
# paste small tumor geo into test sample
|
104 |
+
geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo
|
105 |
+
# texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture
|
106 |
+
|
107 |
+
if tumor_type == 'medium':
|
108 |
+
num_tumor = random.randint(2, 5)
|
109 |
+
for _ in range(num_tumor):
|
110 |
+
# medium tumor
|
111 |
+
x = random.randint(int(0.75*medium_radius), int(1.25*medium_radius))
|
112 |
+
y = random.randint(int(0.75*medium_radius), int(1.25*medium_radius))
|
113 |
+
z = random.randint(int(0.75*medium_radius), int(1.25*medium_radius))
|
114 |
+
sigma = random.randint(3, 6)
|
115 |
+
|
116 |
+
geo = get_ellipsoid(x, y, z)
|
117 |
+
geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1))
|
118 |
+
geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2))
|
119 |
+
geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2))
|
120 |
+
# texture = get_texture((4*x, 4*y, 4*z))
|
121 |
+
point = random_select(mask_scan)
|
122 |
+
new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2]
|
123 |
+
x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2
|
124 |
+
y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2
|
125 |
+
z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2
|
126 |
+
|
127 |
+
# paste medium tumor geo into test sample
|
128 |
+
geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo
|
129 |
+
# texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture
|
130 |
+
|
131 |
+
if tumor_type == 'large':
|
132 |
+
num_tumor = random.randint(1,3)
|
133 |
+
for _ in range(num_tumor):
|
134 |
+
# Large tumor
|
135 |
+
x = random.randint(int(0.75*large_radius), int(1.25*large_radius))
|
136 |
+
y = random.randint(int(0.75*large_radius), int(1.25*large_radius))
|
137 |
+
z = random.randint(int(0.75*large_radius), int(1.25*large_radius))
|
138 |
+
sigma = random.randint(5, 10)
|
139 |
+
|
140 |
+
geo = get_ellipsoid(x, y, z)
|
141 |
+
geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1))
|
142 |
+
geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2))
|
143 |
+
geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2))
|
144 |
+
# texture = get_texture((4*x, 4*y, 4*z))
|
145 |
+
point = random_select(mask_scan)
|
146 |
+
new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2]
|
147 |
+
x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2
|
148 |
+
y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2
|
149 |
+
z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2
|
150 |
+
|
151 |
+
# paste small tumor geo into test sample
|
152 |
+
geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo
|
153 |
+
# texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture
|
154 |
+
|
155 |
+
if tumor_type == "mix":
|
156 |
+
# tiny
|
157 |
+
num_tumor = random.randint(3,10)
|
158 |
+
for _ in range(num_tumor):
|
159 |
+
# Tiny tumor
|
160 |
+
x = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius))
|
161 |
+
y = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius))
|
162 |
+
z = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius))
|
163 |
+
sigma = random.uniform(0.5,1)
|
164 |
+
|
165 |
+
geo = get_ellipsoid(x, y, z)
|
166 |
+
geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1))
|
167 |
+
geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2))
|
168 |
+
geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2))
|
169 |
+
point = random_select(mask_scan)
|
170 |
+
new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2]
|
171 |
+
x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2
|
172 |
+
y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2
|
173 |
+
z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2
|
174 |
+
|
175 |
+
# paste small tumor geo into test sample
|
176 |
+
geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo
|
177 |
+
|
178 |
+
# small
|
179 |
+
num_tumor = random.randint(5,10)
|
180 |
+
for _ in range(num_tumor):
|
181 |
+
# Small tumor
|
182 |
+
x = random.randint(int(0.75*small_radius), int(1.25*small_radius))
|
183 |
+
y = random.randint(int(0.75*small_radius), int(1.25*small_radius))
|
184 |
+
z = random.randint(int(0.75*small_radius), int(1.25*small_radius))
|
185 |
+
sigma = random.randint(1, 2)
|
186 |
+
|
187 |
+
geo = get_ellipsoid(x, y, z)
|
188 |
+
geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1))
|
189 |
+
geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2))
|
190 |
+
geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2))
|
191 |
+
# texture = get_texture((4*x, 4*y, 4*z))
|
192 |
+
point = random_select(mask_scan)
|
193 |
+
new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2]
|
194 |
+
x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2
|
195 |
+
y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2
|
196 |
+
z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2
|
197 |
+
|
198 |
+
# paste small tumor geo into test sample
|
199 |
+
geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo
|
200 |
+
# texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture
|
201 |
+
|
202 |
+
# medium
|
203 |
+
num_tumor = random.randint(2, 5)
|
204 |
+
for _ in range(num_tumor):
|
205 |
+
# medium tumor
|
206 |
+
x = random.randint(int(0.75*medium_radius), int(1.25*medium_radius))
|
207 |
+
y = random.randint(int(0.75*medium_radius), int(1.25*medium_radius))
|
208 |
+
z = random.randint(int(0.75*medium_radius), int(1.25*medium_radius))
|
209 |
+
sigma = random.randint(3, 6)
|
210 |
+
|
211 |
+
geo = get_ellipsoid(x, y, z)
|
212 |
+
geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1))
|
213 |
+
geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2))
|
214 |
+
geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2))
|
215 |
+
# texture = get_texture((4*x, 4*y, 4*z))
|
216 |
+
point = random_select(mask_scan)
|
217 |
+
new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2]
|
218 |
+
x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2
|
219 |
+
y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2
|
220 |
+
z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2
|
221 |
+
|
222 |
+
# paste medium tumor geo into test sample
|
223 |
+
geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo
|
224 |
+
# texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture
|
225 |
+
|
226 |
+
# large
|
227 |
+
num_tumor = random.randint(1,3)
|
228 |
+
for _ in range(num_tumor):
|
229 |
+
# Large tumor
|
230 |
+
x = random.randint(int(0.75*large_radius), int(1.25*large_radius))
|
231 |
+
y = random.randint(int(0.75*large_radius), int(1.25*large_radius))
|
232 |
+
z = random.randint(int(0.75*large_radius), int(1.25*large_radius))
|
233 |
+
sigma = random.randint(5, 10)
|
234 |
+
geo = get_ellipsoid(x, y, z)
|
235 |
+
geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1))
|
236 |
+
geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2))
|
237 |
+
geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2))
|
238 |
+
# texture = get_texture((4*x, 4*y, 4*z))
|
239 |
+
point = random_select(mask_scan)
|
240 |
+
new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2]
|
241 |
+
x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2
|
242 |
+
y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2
|
243 |
+
z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2
|
244 |
+
|
245 |
+
# paste small tumor geo into test sample
|
246 |
+
geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo
|
247 |
+
# texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture
|
248 |
+
|
249 |
+
geo_mask = geo_mask[enlarge_x//2:-enlarge_x//2, enlarge_y//2:-enlarge_y//2, enlarge_z//2:-enlarge_z//2]
|
250 |
+
# texture_map = texture_map[enlarge_x//2:-enlarge_x//2, enlarge_y//2:-enlarge_y//2, enlarge_z//2:-enlarge_z//2]
|
251 |
+
geo_mask = (geo_mask * mask_scan) >=1
|
252 |
+
|
253 |
+
return geo_mask
|
254 |
+
|
255 |
+
|
256 |
+
def get_tumor(volume_scan, mask_scan, tumor_type):
|
257 |
+
tumor_mask = get_fixed_geo(mask_scan, tumor_type)
|
258 |
+
|
259 |
+
sigma = np.random.uniform(1, 2)
|
260 |
+
# difference = np.random.uniform(65, 145)
|
261 |
+
difference = 1
|
262 |
+
|
263 |
+
# blur the boundary
|
264 |
+
tumor_mask_blur = gaussian_filter(tumor_mask*1.0, sigma)
|
265 |
+
|
266 |
+
|
267 |
+
abnormally_full = volume_scan * (1 - mask_scan) + abnormally
|
268 |
+
abnormally_mask = mask_scan + geo_mask
|
269 |
+
|
270 |
+
return abnormally_full, abnormally_mask
|
271 |
+
|
272 |
+
def SynthesisTumor(volume_scan, mask_scan, tumor_type):
|
273 |
+
# for speed_generate_tumor, we only send the liver part into the generate program
|
274 |
+
x_start, x_end = np.where(np.any(mask_scan, axis=(1, 2)))[0][[0, -1]]
|
275 |
+
y_start, y_end = np.where(np.any(mask_scan, axis=(0, 2)))[0][[0, -1]]
|
276 |
+
z_start, z_end = np.where(np.any(mask_scan, axis=(0, 1)))[0][[0, -1]]
|
277 |
+
|
278 |
+
# shrink the boundary
|
279 |
+
x_start, x_end = max(0, x_start+1), min(mask_scan.shape[0], x_end-1)
|
280 |
+
y_start, y_end = max(0, y_start+1), min(mask_scan.shape[1], y_end-1)
|
281 |
+
z_start, z_end = max(0, z_start+1), min(mask_scan.shape[2], z_end-1)
|
282 |
+
|
283 |
+
ct_volume = volume_scan[x_start:x_end, y_start:y_end, z_start:z_end]
|
284 |
+
organ_mask = mask_scan[x_start:x_end, y_start:y_end, z_start:z_end]
|
285 |
+
|
286 |
+
# input texture shape: 420, 300, 320
|
287 |
+
# we need to cut it into the shape of liver_mask
|
288 |
+
# for examples, the liver_mask.shape = 286, 173, 46; we should change the texture shape
|
289 |
+
x_length, y_length, z_length = 64, 64, 64
|
290 |
+
crop_x = random.randint(x_start, x_end - x_length - 1) # random select the start point, -1 is to avoid boundary check
|
291 |
+
crop_y = random.randint(y_start, y_end - y_length - 1)
|
292 |
+
crop_z = random.randint(z_start, z_end - z_length - 1)
|
293 |
+
|
294 |
+
ct_volume, organ_tumor_mask = get_tumor(ct_volume, organ_mask, tumor_type)
|
295 |
+
volume_scan[x_start:x_end, y_start:y_end, z_start:z_end] = ct_volume
|
296 |
+
mask_scan[x_start:x_end, y_start:y_end, z_start:z_end] = organ_tumor_mask
|
297 |
+
|
298 |
+
return volume_scan, mask_scan
|
Generation_Pipeline_filter_all2/syn_kidney/healthy_kidney_1k.txt
ADDED
@@ -0,0 +1,565 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
BDMAP_00002275
|
2 |
+
BDMAP_00001907
|
3 |
+
BDMAP_00002712
|
4 |
+
BDMAP_00004615
|
5 |
+
BDMAP_00004651
|
6 |
+
BDMAP_00002230
|
7 |
+
BDMAP_00002955
|
8 |
+
BDMAP_00004183
|
9 |
+
BDMAP_00002304
|
10 |
+
BDMAP_00002029
|
11 |
+
BDMAP_00001646
|
12 |
+
BDMAP_00002909
|
13 |
+
BDMAP_00002328
|
14 |
+
BDMAP_00004829
|
15 |
+
BDMAP_00001093
|
16 |
+
BDMAP_00002117
|
17 |
+
BDMAP_00004600
|
18 |
+
BDMAP_00003771
|
19 |
+
BDMAP_00001198
|
20 |
+
BDMAP_00003451
|
21 |
+
BDMAP_00002719
|
22 |
+
BDMAP_00002846
|
23 |
+
BDMAP_00002282
|
24 |
+
BDMAP_00003827
|
25 |
+
BDMAP_00001649
|
26 |
+
BDMAP_00005141
|
27 |
+
BDMAP_00000941
|
28 |
+
BDMAP_00002875
|
29 |
+
BDMAP_00004641
|
30 |
+
BDMAP_00003373
|
31 |
+
BDMAP_00001924
|
32 |
+
BDMAP_00003897
|
33 |
+
BDMAP_00005074
|
34 |
+
BDMAP_00001753
|
35 |
+
BDMAP_00000101
|
36 |
+
BDMAP_00003412
|
37 |
+
BDMAP_00002945
|
38 |
+
BDMAP_00002598
|
39 |
+
BDMAP_00004858
|
40 |
+
BDMAP_00001632
|
41 |
+
BDMAP_00003327
|
42 |
+
BDMAP_00005130
|
43 |
+
BDMAP_00004783
|
44 |
+
BDMAP_00002844
|
45 |
+
BDMAP_00002479
|
46 |
+
BDMAP_00001464
|
47 |
+
BDMAP_00001809
|
48 |
+
BDMAP_00003385
|
49 |
+
BDMAP_00003918
|
50 |
+
BDMAP_00004995
|
51 |
+
BDMAP_00004447
|
52 |
+
BDMAP_00003972
|
53 |
+
BDMAP_00003438
|
54 |
+
BDMAP_00003898
|
55 |
+
BDMAP_00001057
|
56 |
+
BDMAP_00005005
|
57 |
+
BDMAP_00003244
|
58 |
+
BDMAP_00003631
|
59 |
+
BDMAP_00004103
|
60 |
+
BDMAP_00000069
|
61 |
+
BDMAP_00001736
|
62 |
+
BDMAP_00003002
|
63 |
+
BDMAP_00004704
|
64 |
+
BDMAP_00001055
|
65 |
+
BDMAP_00000447
|
66 |
+
BDMAP_00000778
|
67 |
+
BDMAP_00005097
|
68 |
+
BDMAP_00004264
|
69 |
+
BDMAP_00004304
|
70 |
+
BDMAP_00005170
|
71 |
+
BDMAP_00000547
|
72 |
+
BDMAP_00004764
|
73 |
+
BDMAP_00004229
|
74 |
+
BDMAP_00001414
|
75 |
+
BDMAP_00001828
|
76 |
+
BDMAP_00003151
|
77 |
+
BDMAP_00003769
|
78 |
+
BDMAP_00001962
|
79 |
+
BDMAP_00003333
|
80 |
+
BDMAP_00000676
|
81 |
+
BDMAP_00001704
|
82 |
+
BDMAP_00004459
|
83 |
+
BDMAP_00003683
|
84 |
+
BDMAP_00003439
|
85 |
+
BDMAP_00004016
|
86 |
+
BDMAP_00000438
|
87 |
+
BDMAP_00004117
|
88 |
+
BDMAP_00001785
|
89 |
+
BDMAP_00002688
|
90 |
+
BDMAP_00000913
|
91 |
+
BDMAP_00000942
|
92 |
+
BDMAP_00003400
|
93 |
+
BDMAP_00003824
|
94 |
+
BDMAP_00000470
|
95 |
+
BDMAP_00002918
|
96 |
+
BDMAP_00002828
|
97 |
+
BDMAP_00004286
|
98 |
+
BDMAP_00001845
|
99 |
+
BDMAP_00002791
|
100 |
+
BDMAP_00004672
|
101 |
+
BDMAP_00002717
|
102 |
+
BDMAP_00002856
|
103 |
+
BDMAP_00002188
|
104 |
+
BDMAP_00001701
|
105 |
+
BDMAP_00001175
|
106 |
+
BDMAP_00002841
|
107 |
+
BDMAP_00003254
|
108 |
+
BDMAP_00004508
|
109 |
+
BDMAP_00000373
|
110 |
+
BDMAP_00001565
|
111 |
+
BDMAP_00002214
|
112 |
+
BDMAP_00000701
|
113 |
+
BDMAP_00000690
|
114 |
+
BDMAP_00001215
|
115 |
+
BDMAP_00000324
|
116 |
+
BDMAP_00004015
|
117 |
+
BDMAP_00004196
|
118 |
+
BDMAP_00001419
|
119 |
+
BDMAP_00000618
|
120 |
+
BDMAP_00003640
|
121 |
+
BDMAP_00001697
|
122 |
+
BDMAP_00000332
|
123 |
+
BDMAP_00004023
|
124 |
+
BDMAP_00002815
|
125 |
+
BDMAP_00004199
|
126 |
+
BDMAP_00003890
|
127 |
+
BDMAP_00002529
|
128 |
+
BDMAP_00004843
|
129 |
+
BDMAP_00002076
|
130 |
+
BDMAP_00004895
|
131 |
+
BDMAP_00000623
|
132 |
+
BDMAP_00002244
|
133 |
+
BDMAP_00000205
|
134 |
+
BDMAP_00001185
|
135 |
+
BDMAP_00003133
|
136 |
+
BDMAP_00001957
|
137 |
+
BDMAP_00001015
|
138 |
+
BDMAP_00003932
|
139 |
+
BDMAP_00001010
|
140 |
+
BDMAP_00001102
|
141 |
+
BDMAP_00004880
|
142 |
+
BDMAP_00004664
|
143 |
+
BDMAP_00002748
|
144 |
+
BDMAP_00000430
|
145 |
+
BDMAP_00004293
|
146 |
+
BDMAP_00002829
|
147 |
+
BDMAP_00000558
|
148 |
+
BDMAP_00000084
|
149 |
+
BDMAP_00001438
|
150 |
+
BDMAP_00001917
|
151 |
+
BDMAP_00004129
|
152 |
+
BDMAP_00000232
|
153 |
+
BDMAP_00002463
|
154 |
+
BDMAP_00004839
|
155 |
+
BDMAP_00003664
|
156 |
+
BDMAP_00004604
|
157 |
+
BDMAP_00002021
|
158 |
+
BDMAP_00004550
|
159 |
+
BDMAP_00004106
|
160 |
+
BDMAP_00004128
|
161 |
+
BDMAP_00000696
|
162 |
+
BDMAP_00002411
|
163 |
+
BDMAP_00003569
|
164 |
+
BDMAP_00001912
|
165 |
+
BDMAP_00003036
|
166 |
+
BDMAP_00001288
|
167 |
+
BDMAP_00002216
|
168 |
+
BDMAP_00002199
|
169 |
+
BDMAP_00000100
|
170 |
+
BDMAP_00003634
|
171 |
+
BDMAP_00000345
|
172 |
+
BDMAP_00000614
|
173 |
+
BDMAP_00001769
|
174 |
+
BDMAP_00002580
|
175 |
+
BDMAP_00004676
|
176 |
+
BDMAP_00000388
|
177 |
+
BDMAP_00003357
|
178 |
+
BDMAP_00004431
|
179 |
+
BDMAP_00002359
|
180 |
+
BDMAP_00000132
|
181 |
+
BDMAP_00004097
|
182 |
+
BDMAP_00003847
|
183 |
+
BDMAP_00003017
|
184 |
+
BDMAP_00003680
|
185 |
+
BDMAP_00001737
|
186 |
+
BDMAP_00003361
|
187 |
+
BDMAP_00003377
|
188 |
+
BDMAP_00000437
|
189 |
+
BDMAP_00002237
|
190 |
+
BDMAP_00003900
|
191 |
+
BDMAP_00001754
|
192 |
+
BDMAP_00004288
|
193 |
+
BDMAP_00002612
|
194 |
+
BDMAP_00003329
|
195 |
+
BDMAP_00004187
|
196 |
+
BDMAP_00000873
|
197 |
+
BDMAP_00003525
|
198 |
+
BDMAP_00000921
|
199 |
+
BDMAP_00004231
|
200 |
+
BDMAP_00001343
|
201 |
+
BDMAP_00004793
|
202 |
+
BDMAP_00001898
|
203 |
+
BDMAP_00002271
|
204 |
+
BDMAP_00002313
|
205 |
+
BDMAP_00002896
|
206 |
+
BDMAP_00000851
|
207 |
+
BDMAP_00004165
|
208 |
+
BDMAP_00003840
|
209 |
+
BDMAP_00000338
|
210 |
+
BDMAP_00000715
|
211 |
+
BDMAP_00004295
|
212 |
+
BDMAP_00000236
|
213 |
+
BDMAP_00001985
|
214 |
+
BDMAP_00003633
|
215 |
+
BDMAP_00004825
|
216 |
+
BDMAP_00002305
|
217 |
+
BDMAP_00001237
|
218 |
+
BDMAP_00002419
|
219 |
+
BDMAP_00001766
|
220 |
+
BDMAP_00004546
|
221 |
+
BDMAP_00000881
|
222 |
+
BDMAP_00001836
|
223 |
+
BDMAP_00003052
|
224 |
+
BDMAP_00001502
|
225 |
+
BDMAP_00003483
|
226 |
+
BDMAP_00003396
|
227 |
+
BDMAP_00005119
|
228 |
+
BDMAP_00003299
|
229 |
+
BDMAP_00000568
|
230 |
+
BDMAP_00003590
|
231 |
+
BDMAP_00002616
|
232 |
+
BDMAP_00001835
|
233 |
+
BDMAP_00002172
|
234 |
+
BDMAP_00004964
|
235 |
+
BDMAP_00002944
|
236 |
+
BDMAP_00002465
|
237 |
+
BDMAP_00002227
|
238 |
+
BDMAP_00001905
|
239 |
+
BDMAP_00002603
|
240 |
+
BDMAP_00003111
|
241 |
+
BDMAP_00004398
|
242 |
+
BDMAP_00002373
|
243 |
+
BDMAP_00000093
|
244 |
+
BDMAP_00001247
|
245 |
+
BDMAP_00003172
|
246 |
+
BDMAP_00001865
|
247 |
+
BDMAP_00001545
|
248 |
+
BDMAP_00000411
|
249 |
+
BDMAP_00002349
|
250 |
+
BDMAP_00001617
|
251 |
+
BDMAP_00003884
|
252 |
+
BDMAP_00000809
|
253 |
+
BDMAP_00003497
|
254 |
+
BDMAP_00003961
|
255 |
+
BDMAP_00005139
|
256 |
+
BDMAP_00001628
|
257 |
+
BDMAP_00004969
|
258 |
+
BDMAP_00004228
|
259 |
+
BDMAP_00001316
|
260 |
+
BDMAP_00005160
|
261 |
+
BDMAP_00001024
|
262 |
+
BDMAP_00005073
|
263 |
+
BDMAP_00001209
|
264 |
+
BDMAP_00004954
|
265 |
+
BDMAP_00003798
|
266 |
+
BDMAP_00005063
|
267 |
+
BDMAP_00001476
|
268 |
+
BDMAP_00000243
|
269 |
+
BDMAP_00003809
|
270 |
+
BDMAP_00001309
|
271 |
+
BDMAP_00003886
|
272 |
+
BDMAP_00002758
|
273 |
+
BDMAP_00002289
|
274 |
+
BDMAP_00001862
|
275 |
+
BDMAP_00004804
|
276 |
+
BDMAP_00003113
|
277 |
+
BDMAP_00001361
|
278 |
+
BDMAP_00000692
|
279 |
+
BDMAP_00001523
|
280 |
+
BDMAP_00004115
|
281 |
+
BDMAP_00002387
|
282 |
+
BDMAP_00003781
|
283 |
+
BDMAP_00000087
|
284 |
+
BDMAP_00001823
|
285 |
+
BDMAP_00000940
|
286 |
+
BDMAP_00004719
|
287 |
+
BDMAP_00004624
|
288 |
+
BDMAP_00002849
|
289 |
+
BDMAP_00003657
|
290 |
+
BDMAP_00001461
|
291 |
+
BDMAP_00002690
|
292 |
+
BDMAP_00003236
|
293 |
+
BDMAP_00004558
|
294 |
+
BDMAP_00004639
|
295 |
+
BDMAP_00004541
|
296 |
+
BDMAP_00005083
|
297 |
+
BDMAP_00000907
|
298 |
+
BDMAP_00000972
|
299 |
+
BDMAP_00001200
|
300 |
+
BDMAP_00003168
|
301 |
+
BDMAP_00000828
|
302 |
+
BDMAP_00004450
|
303 |
+
BDMAP_00001597
|
304 |
+
BDMAP_00003867
|
305 |
+
BDMAP_00001746
|
306 |
+
BDMAP_00002252
|
307 |
+
BDMAP_00002947
|
308 |
+
BDMAP_00004878
|
309 |
+
BDMAP_00001842
|
310 |
+
BDMAP_00002654
|
311 |
+
BDMAP_00002185
|
312 |
+
BDMAP_00001802
|
313 |
+
BDMAP_00001040
|
314 |
+
BDMAP_00004198
|
315 |
+
BDMAP_00000831
|
316 |
+
BDMAP_00004491
|
317 |
+
BDMAP_00003109
|
318 |
+
BDMAP_00002120
|
319 |
+
BDMAP_00001834
|
320 |
+
BDMAP_00002619
|
321 |
+
BDMAP_00000138
|
322 |
+
BDMAP_00004773
|
323 |
+
BDMAP_00001236
|
324 |
+
BDMAP_00002402
|
325 |
+
BDMAP_00001598
|
326 |
+
BDMAP_00000714
|
327 |
+
BDMAP_00003356
|
328 |
+
BDMAP_00000462
|
329 |
+
BDMAP_00001114
|
330 |
+
BDMAP_00000607
|
331 |
+
BDMAP_00004297
|
332 |
+
BDMAP_00004841
|
333 |
+
BDMAP_00005022
|
334 |
+
BDMAP_00000572
|
335 |
+
BDMAP_00000541
|
336 |
+
BDMAP_00005140
|
337 |
+
BDMAP_00004415
|
338 |
+
BDMAP_00003946
|
339 |
+
BDMAP_00003319
|
340 |
+
BDMAP_00003510
|
341 |
+
BDMAP_00004163
|
342 |
+
BDMAP_00002458
|
343 |
+
BDMAP_00005020
|
344 |
+
BDMAP_00004511
|
345 |
+
BDMAP_00004549
|
346 |
+
BDMAP_00005155
|
347 |
+
BDMAP_00004147
|
348 |
+
BDMAP_00004876
|
349 |
+
BDMAP_00002103
|
350 |
+
BDMAP_00000882
|
351 |
+
BDMAP_00003138
|
352 |
+
BDMAP_00005037
|
353 |
+
BDMAP_00003853
|
354 |
+
BDMAP_00002039
|
355 |
+
BDMAP_00000774
|
356 |
+
BDMAP_00004741
|
357 |
+
BDMAP_00001171
|
358 |
+
BDMAP_00004636
|
359 |
+
BDMAP_00002332
|
360 |
+
BDMAP_00004894
|
361 |
+
BDMAP_00002730
|
362 |
+
BDMAP_00001125
|
363 |
+
BDMAP_00003822
|
364 |
+
BDMAP_00003592
|
365 |
+
BDMAP_00001368
|
366 |
+
BDMAP_00003513
|
367 |
+
BDMAP_00003612
|
368 |
+
BDMAP_00005169
|
369 |
+
BDMAP_00004017
|
370 |
+
BDMAP_00002855
|
371 |
+
BDMAP_00000152
|
372 |
+
BDMAP_00000091
|
373 |
+
BDMAP_00004529
|
374 |
+
BDMAP_00003443
|
375 |
+
BDMAP_00003543
|
376 |
+
BDMAP_00002267
|
377 |
+
BDMAP_00004462
|
378 |
+
BDMAP_00000874
|
379 |
+
BDMAP_00002793
|
380 |
+
BDMAP_00001471
|
381 |
+
BDMAP_00001605
|
382 |
+
BDMAP_00000709
|
383 |
+
BDMAP_00004435
|
384 |
+
BDMAP_00003524
|
385 |
+
BDMAP_00000965
|
386 |
+
BDMAP_00000939
|
387 |
+
BDMAP_00002278
|
388 |
+
BDMAP_00002295
|
389 |
+
BDMAP_00000971
|
390 |
+
BDMAP_00004917
|
391 |
+
BDMAP_00003812
|
392 |
+
BDMAP_00002401
|
393 |
+
BDMAP_00003074
|
394 |
+
BDMAP_00004028
|
395 |
+
BDMAP_00001982
|
396 |
+
BDMAP_00004281
|
397 |
+
BDMAP_00000347
|
398 |
+
BDMAP_00001732
|
399 |
+
BDMAP_00001205
|
400 |
+
BDMAP_00001379
|
401 |
+
BDMAP_00001095
|
402 |
+
BDMAP_00004770
|
403 |
+
BDMAP_00002283
|
404 |
+
BDMAP_00000052
|
405 |
+
BDMAP_00000192
|
406 |
+
BDMAP_00003564
|
407 |
+
BDMAP_00003427
|
408 |
+
BDMAP_00004888
|
409 |
+
BDMAP_00005016
|
410 |
+
BDMAP_00004745
|
411 |
+
BDMAP_00001078
|
412 |
+
BDMAP_00001122
|
413 |
+
BDMAP_00001584
|
414 |
+
BDMAP_00003551
|
415 |
+
BDMAP_00002495
|
416 |
+
BDMAP_00000589
|
417 |
+
BDMAP_00005065
|
418 |
+
BDMAP_00002171
|
419 |
+
BDMAP_00004830
|
420 |
+
BDMAP_00001804
|
421 |
+
BDMAP_00004493
|
422 |
+
BDMAP_00000400
|
423 |
+
BDMAP_00000745
|
424 |
+
BDMAP_00001333
|
425 |
+
BDMAP_00004890
|
426 |
+
BDMAP_00002845
|
427 |
+
BDMAP_00001875
|
428 |
+
BDMAP_00001096
|
429 |
+
BDMAP_00004060
|
430 |
+
BDMAP_00002451
|
431 |
+
BDMAP_00002523
|
432 |
+
BDMAP_00002899
|
433 |
+
BDMAP_00000642
|
434 |
+
BDMAP_00005075
|
435 |
+
BDMAP_00003685
|
436 |
+
BDMAP_00004650
|
437 |
+
BDMAP_00001618
|
438 |
+
BDMAP_00000771
|
439 |
+
BDMAP_00003920
|
440 |
+
BDMAP_00002309
|
441 |
+
BDMAP_00004847
|
442 |
+
BDMAP_00002485
|
443 |
+
BDMAP_00001590
|
444 |
+
BDMAP_00001692
|
445 |
+
BDMAP_00003502
|
446 |
+
BDMAP_00000431
|
447 |
+
BDMAP_00000679
|
448 |
+
BDMAP_00002986
|
449 |
+
BDMAP_00003277
|
450 |
+
BDMAP_00004885
|
451 |
+
BDMAP_00000427
|
452 |
+
BDMAP_00000716
|
453 |
+
BDMAP_00003744
|
454 |
+
BDMAP_00001806
|
455 |
+
BDMAP_00003857
|
456 |
+
BDMAP_00000859
|
457 |
+
BDMAP_00001067
|
458 |
+
BDMAP_00004121
|
459 |
+
BDMAP_00002475
|
460 |
+
BDMAP_00002318
|
461 |
+
BDMAP_00003114
|
462 |
+
BDMAP_00001712
|
463 |
+
BDMAP_00001214
|
464 |
+
BDMAP_00000362
|
465 |
+
BDMAP_00001441
|
466 |
+
BDMAP_00003272
|
467 |
+
BDMAP_00000956
|
468 |
+
BDMAP_00005064
|
469 |
+
BDMAP_00000154
|
470 |
+
BDMAP_00005186
|
471 |
+
BDMAP_00003658
|
472 |
+
BDMAP_00002704
|
473 |
+
BDMAP_00004796
|
474 |
+
BDMAP_00000197
|
475 |
+
BDMAP_00005070
|
476 |
+
BDMAP_00005001
|
477 |
+
BDMAP_00000480
|
478 |
+
BDMAP_00005078
|
479 |
+
BDMAP_00001564
|
480 |
+
BDMAP_00001025
|
481 |
+
BDMAP_00003598
|
482 |
+
BDMAP_00004262
|
483 |
+
BDMAP_00001092
|
484 |
+
BDMAP_00004185
|
485 |
+
BDMAP_00003776
|
486 |
+
BDMAP_00001270
|
487 |
+
BDMAP_00000615
|
488 |
+
BDMAP_00003141
|
489 |
+
BDMAP_00003330
|
490 |
+
BDMAP_00000190
|
491 |
+
BDMAP_00003650
|
492 |
+
BDMAP_00001397
|
493 |
+
BDMAP_00005185
|
494 |
+
BDMAP_00001966
|
495 |
+
BDMAP_00004184
|
496 |
+
BDMAP_00004992
|
497 |
+
BDMAP_00004416
|
498 |
+
BDMAP_00000993
|
499 |
+
BDMAP_00001445
|
500 |
+
BDMAP_00003482
|
501 |
+
BDMAP_00004514
|
502 |
+
BDMAP_00001504
|
503 |
+
BDMAP_00000416
|
504 |
+
BDMAP_00002805
|
505 |
+
BDMAP_00002232
|
506 |
+
BDMAP_00004384
|
507 |
+
BDMAP_00001921
|
508 |
+
BDMAP_00001426
|
509 |
+
BDMAP_00004910
|
510 |
+
BDMAP_00003560
|
511 |
+
BDMAP_00003130
|
512 |
+
BDMAP_00005108
|
513 |
+
BDMAP_00000113
|
514 |
+
BDMAP_00001521
|
515 |
+
BDMAP_00003556
|
516 |
+
BDMAP_00003376
|
517 |
+
BDMAP_00000273
|
518 |
+
BDMAP_00004735
|
519 |
+
BDMAP_00001539
|
520 |
+
BDMAP_00004494
|
521 |
+
BDMAP_00001212
|
522 |
+
BDMAP_00005067
|
523 |
+
BDMAP_00000413
|
524 |
+
BDMAP_00002863
|
525 |
+
BDMAP_00000671
|
526 |
+
BDMAP_00004927
|
527 |
+
BDMAP_00002167
|
528 |
+
BDMAP_00002152
|
529 |
+
BDMAP_00005168
|
530 |
+
BDMAP_00003911
|
531 |
+
BDMAP_00002250
|
532 |
+
BDMAP_00003215
|
533 |
+
BDMAP_00002737
|
534 |
+
BDMAP_00001514
|
535 |
+
BDMAP_00003440
|
536 |
+
BDMAP_00003031
|
537 |
+
BDMAP_00001786
|
538 |
+
BDMAP_00000552
|
539 |
+
BDMAP_00004943
|
540 |
+
BDMAP_00003268
|
541 |
+
BDMAP_00002233
|
542 |
+
BDMAP_00002362
|
543 |
+
BDMAP_00001440
|
544 |
+
BDMAP_00000225
|
545 |
+
BDMAP_00003347
|
546 |
+
BDMAP_00002739
|
547 |
+
BDMAP_00003479
|
548 |
+
BDMAP_00003481
|
549 |
+
BDMAP_00003326
|
550 |
+
BDMAP_00000683
|
551 |
+
BDMAP_00004378
|
552 |
+
BDMAP_00003367
|
553 |
+
BDMAP_00000855
|
554 |
+
BDMAP_00002298
|
555 |
+
BDMAP_00004077
|
556 |
+
BDMAP_00002253
|
557 |
+
BDMAP_00001331
|
558 |
+
BDMAP_00000542
|
559 |
+
BDMAP_00002924
|
560 |
+
BDMAP_00005092
|
561 |
+
BDMAP_00004374
|
562 |
+
BDMAP_00004509
|
563 |
+
BDMAP_00000264
|
564 |
+
BDMAP_00000918
|
565 |
+
BDMAP_00000030
|
Generation_Pipeline_filter_all2/syn_kidney/requirements.txt
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
absl-py==1.1.0
|
2 |
+
accelerate==0.11.0
|
3 |
+
aiohttp==3.8.1
|
4 |
+
aiosignal==1.2.0
|
5 |
+
antlr4-python3-runtime==4.9.3
|
6 |
+
async-timeout==4.0.2
|
7 |
+
attrs==21.4.0
|
8 |
+
autopep8==1.6.0
|
9 |
+
cachetools==5.2.0
|
10 |
+
certifi==2022.6.15
|
11 |
+
charset-normalizer==2.0.12
|
12 |
+
click==8.1.3
|
13 |
+
cycler==0.11.0
|
14 |
+
Deprecated==1.2.13
|
15 |
+
docker-pycreds==0.4.0
|
16 |
+
einops==0.4.1
|
17 |
+
einops-exts==0.0.3
|
18 |
+
ema-pytorch==0.0.8
|
19 |
+
fonttools==4.34.4
|
20 |
+
frozenlist==1.3.0
|
21 |
+
fsspec==2022.5.0
|
22 |
+
ftfy==6.1.1
|
23 |
+
future==0.18.2
|
24 |
+
gitdb==4.0.9
|
25 |
+
GitPython==3.1.27
|
26 |
+
google-auth==2.9.0
|
27 |
+
google-auth-oauthlib==0.4.6
|
28 |
+
grpcio==1.47.0
|
29 |
+
h5py==3.7.0
|
30 |
+
humanize==4.2.2
|
31 |
+
hydra-core==1.2.0
|
32 |
+
idna==3.3
|
33 |
+
imageio==2.19.3
|
34 |
+
imageio-ffmpeg==0.4.7
|
35 |
+
importlib-metadata==4.12.0
|
36 |
+
importlib-resources==5.9.0
|
37 |
+
joblib==1.1.0
|
38 |
+
kiwisolver==1.4.3
|
39 |
+
lxml==4.9.1
|
40 |
+
Markdown==3.3.7
|
41 |
+
matplotlib==3.5.2
|
42 |
+
multidict==6.0.2
|
43 |
+
networkx==2.8.5
|
44 |
+
nibabel==4.0.1
|
45 |
+
nilearn==0.9.1
|
46 |
+
numpy==1.23.0
|
47 |
+
oauthlib==3.2.0
|
48 |
+
omegaconf==2.2.3
|
49 |
+
pandas==1.4.3
|
50 |
+
Pillow==9.1.1
|
51 |
+
pyasn1==0.4.8
|
52 |
+
pyasn1-modules==0.2.8
|
53 |
+
pycodestyle==2.8.0
|
54 |
+
pyDeprecate==0.3.1
|
55 |
+
pydicom==2.3.0
|
56 |
+
pytorch-lightning==1.6.4
|
57 |
+
pytz==2022.1
|
58 |
+
PyWavelets==1.3.0
|
59 |
+
PyYAML==6.0
|
60 |
+
pyzmq==19.0.2
|
61 |
+
regex==2022.6.2
|
62 |
+
requests==2.28.0
|
63 |
+
requests-oauthlib==1.3.1
|
64 |
+
rotary-embedding-torch==0.1.5
|
65 |
+
rsa==4.8
|
66 |
+
scikit-image==0.19.3
|
67 |
+
scikit-learn==1.1.2
|
68 |
+
scikit-video==1.1.11
|
69 |
+
scipy==1.8.1
|
70 |
+
seaborn==0.11.2
|
71 |
+
sentry-sdk==1.7.2
|
72 |
+
setproctitle==1.2.3
|
73 |
+
shortuuid==1.0.9
|
74 |
+
SimpleITK==2.1.1.2
|
75 |
+
smmap==5.0.0
|
76 |
+
tensorboard==2.9.1
|
77 |
+
tensorboard-data-server==0.6.1
|
78 |
+
tensorboard-plugin-wit==1.8.1
|
79 |
+
threadpoolctl==3.1.0
|
80 |
+
tifffile==2022.8.3
|
81 |
+
toml==0.10.2
|
82 |
+
torch-tb-profiler==0.4.0
|
83 |
+
torchio==0.18.80
|
84 |
+
torchmetrics==0.9.1
|
85 |
+
tqdm==4.64.0
|
86 |
+
typing_extensions==4.2.0
|
87 |
+
urllib3==1.26.9
|
88 |
+
wandb==0.12.21
|
89 |
+
Werkzeug==2.1.2
|
90 |
+
wrapt==1.14.1
|
91 |
+
yarl==1.7.2
|
92 |
+
zipp==3.8.0
|
93 |
+
wandb
|
94 |
+
tensorboardX==2.4.1
|
Generation_Pipeline_filter_all2/syn_liver/CT_syn_data.py
ADDED
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, time, csv
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from sklearn.metrics import confusion_matrix
|
5 |
+
from scipy import ndimage
|
6 |
+
from scipy.ndimage import label
|
7 |
+
from functools import partial
|
8 |
+
import monai
|
9 |
+
from monai.transforms import AsDiscrete,AsDiscreted,Compose,Invertd,SaveImaged
|
10 |
+
from monai import transforms, data
|
11 |
+
from TumorGeneration.utils import synthesize_early_tumor, synthesize_medium_tumor, synthesize_large_tumor, synt_model_prepare
|
12 |
+
import nibabel as nib
|
13 |
+
|
14 |
+
import warnings
|
15 |
+
warnings.filterwarnings("ignore")
|
16 |
+
|
17 |
+
import argparse
|
18 |
+
parser = argparse.ArgumentParser(description='liver tumor validation')
|
19 |
+
|
20 |
+
# file dir
|
21 |
+
parser.add_argument('--data_root', default=None, type=str)
|
22 |
+
parser.add_argument('--organ_type', default='liver', type=str)
|
23 |
+
parser.add_argument('--save_dir', default='out', type=str)
|
24 |
+
parser.add_argument('--data_file', default='out', type=str)
|
25 |
+
parser.add_argument('--ddim_ts', default=50, type=int)
|
26 |
+
parser.add_argument('--fg_thresh', default=30, type=int)
|
27 |
+
parser.add_argument('--start', default=0, type=int)
|
28 |
+
parser.add_argument('--end', default=1000, type=int)
|
29 |
+
|
30 |
+
def voxel2R(A):
|
31 |
+
return (np.array(A)/4*3/np.pi)**(1/3)
|
32 |
+
|
33 |
+
class RandCropByPosNegLabeld_select(transforms.RandCropByPosNegLabeld):
|
34 |
+
def __init__(self, keys, label_key, spatial_size,
|
35 |
+
pos=1.0, neg=1.0, num_samples=1,
|
36 |
+
image_key=None, image_threshold=0.0, allow_missing_keys=True,
|
37 |
+
fg_thresh=0):
|
38 |
+
super().__init__(keys=keys, label_key=label_key, spatial_size=spatial_size,
|
39 |
+
pos=pos, neg=neg, num_samples=num_samples,
|
40 |
+
image_key=image_key, image_threshold=image_threshold, allow_missing_keys=allow_missing_keys)
|
41 |
+
self.fg_thresh = fg_thresh
|
42 |
+
|
43 |
+
def R2voxel(self,R):
|
44 |
+
return (4/3*np.pi)*(R)**(3)
|
45 |
+
|
46 |
+
def __call__(self, data):
|
47 |
+
d = dict(data)
|
48 |
+
data_name = d['name']
|
49 |
+
d.pop('name')
|
50 |
+
|
51 |
+
if '10_Decathlon' in data_name or '05_KiTS' in data_name:
|
52 |
+
d_crop = super().__call__(d)
|
53 |
+
|
54 |
+
else:
|
55 |
+
flag=0
|
56 |
+
while 1:
|
57 |
+
flag+=1
|
58 |
+
|
59 |
+
d_crop = super().__call__(d)
|
60 |
+
pixel_num = (d_crop[0]['label']>0).sum()
|
61 |
+
|
62 |
+
if pixel_num > self.R2voxel(self.fg_thresh):
|
63 |
+
break
|
64 |
+
if flag>5 and pixel_num > self.R2voxel(max(self.fg_thresh-5, 5)):
|
65 |
+
break
|
66 |
+
if flag>10 and pixel_num > self.R2voxel(max(self.fg_thresh-10, 5)):
|
67 |
+
break
|
68 |
+
if flag>15 and pixel_num > self.R2voxel(max(self.fg_thresh-15, 5)):
|
69 |
+
break
|
70 |
+
if flag>20 and pixel_num > self.R2voxel(max(self.fg_thresh-20, 5)):
|
71 |
+
break
|
72 |
+
if flag>25 and pixel_num > self.R2voxel(max(self.fg_thresh-25, 5)):
|
73 |
+
break
|
74 |
+
if flag>50:
|
75 |
+
break
|
76 |
+
|
77 |
+
d_crop[0]['name'] = data_name
|
78 |
+
|
79 |
+
return d_crop
|
80 |
+
|
81 |
+
def _get_loader(args):
|
82 |
+
# val_data_dir = args.val_dir
|
83 |
+
# datalist_json = args.json_dir
|
84 |
+
val_org_transform = transforms.Compose(
|
85 |
+
[
|
86 |
+
transforms.LoadImaged(keys=["image", "label", "raw_image"]),
|
87 |
+
transforms.AddChanneld(keys=["image", "label", "raw_image"]),
|
88 |
+
transforms.Orientationd(keys=["image", "label"], axcodes="RAS"),
|
89 |
+
transforms.Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "bilinear")),
|
90 |
+
transforms.ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
|
91 |
+
transforms.SpatialPadd(keys=["image", "label"], mode=["minimum", "constant"], spatial_size=[96, 96, 96]),
|
92 |
+
RandCropByPosNegLabeld_select(
|
93 |
+
keys=["image", "label", "name"],
|
94 |
+
label_key="label",
|
95 |
+
spatial_size=(96, 96, 96),
|
96 |
+
pos=1,
|
97 |
+
neg=0,
|
98 |
+
num_samples=1,
|
99 |
+
image_key="image",
|
100 |
+
image_threshold=0,
|
101 |
+
fg_thresh = args.fg_thresh,
|
102 |
+
),
|
103 |
+
transforms.ToTensord(keys=["image", "label", "raw_image"]),
|
104 |
+
]
|
105 |
+
)
|
106 |
+
|
107 |
+
val_img=[]
|
108 |
+
val_lbl=[]
|
109 |
+
val_name=[]
|
110 |
+
|
111 |
+
for line in open(args.data_file):
|
112 |
+
# name = line.strip().split()[1].split('.')[0]
|
113 |
+
# val_img.append(args.data_root + line.strip().split()[0])
|
114 |
+
# val_lbl.append(args.data_root + line.strip().split()[1])
|
115 |
+
# breakpoint()
|
116 |
+
name = line.strip()
|
117 |
+
val_img.append(os.path.join(args.data_root, name, 'ct.nii.gz'))
|
118 |
+
val_lbl.append(os.path.join(args.data_root, name, 'segmentations/liver.nii.gz'))
|
119 |
+
val_name.append(name)
|
120 |
+
data_dicts_val = [{'image': image, 'raw_image':image, 'label': label, 'name': name}
|
121 |
+
for image, label, name in zip(val_img, val_lbl, val_name)]
|
122 |
+
|
123 |
+
if args.end < len(data_dicts_val):
|
124 |
+
data_dicts_val = data_dicts_val[args.start:args.end]
|
125 |
+
else:
|
126 |
+
data_dicts_val = data_dicts_val[args.start:]
|
127 |
+
print('val len {}'.format(len(data_dicts_val)))
|
128 |
+
val_org_ds = data.Dataset(data_dicts_val, transform=val_org_transform)
|
129 |
+
val_org_loader = data.DataLoader(val_org_ds, batch_size=1, shuffle=False, num_workers=4, sampler=None, pin_memory=True)
|
130 |
+
|
131 |
+
post_transforms = Compose([
|
132 |
+
Invertd(
|
133 |
+
keys=['image'],
|
134 |
+
transform=val_org_transform,
|
135 |
+
orig_keys="image",
|
136 |
+
nearest_interp=False,
|
137 |
+
# nearest_interp=True,
|
138 |
+
to_tensor=True,
|
139 |
+
),
|
140 |
+
Invertd(
|
141 |
+
keys=['label'],
|
142 |
+
transform=val_org_transform,
|
143 |
+
orig_keys="label",
|
144 |
+
nearest_interp=False,
|
145 |
+
# nearest_interp=True,
|
146 |
+
to_tensor=True,
|
147 |
+
)
|
148 |
+
])
|
149 |
+
return val_org_loader, post_transforms
|
150 |
+
|
151 |
+
def main():
|
152 |
+
args = parser.parse_args()
|
153 |
+
output_dir = args.save_dir
|
154 |
+
if not os.path.exists(output_dir):
|
155 |
+
os.makedirs(output_dir)
|
156 |
+
print("MAIN Argument values:")
|
157 |
+
for k, v in vars(args).items():
|
158 |
+
print(k, '=>', v)
|
159 |
+
print('-----------------')
|
160 |
+
|
161 |
+
## loader and post_transform
|
162 |
+
val_loader, post_transforms = _get_loader(args)
|
163 |
+
|
164 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
165 |
+
model = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=2).to(device)
|
166 |
+
model.load_state_dict(torch.load("../best_metric_model_classification3d_dict.pth"))
|
167 |
+
model.eval()
|
168 |
+
|
169 |
+
start_time=0
|
170 |
+
with torch.no_grad():
|
171 |
+
for idx, val_data in enumerate(val_loader):
|
172 |
+
print('idx',idx)
|
173 |
+
if idx == 0:
|
174 |
+
start_time = time.time()
|
175 |
+
# val_inputs = val_data["image"]
|
176 |
+
# name = val_data['label_meta_dict']['filename_or_obj'][0].split('/')[-1].split('.')[0]
|
177 |
+
|
178 |
+
vqgan, early_sampler, noearly_sampler= synt_model_prepare(device = torch.device("cuda"), fold=0, organ=args.organ_type)
|
179 |
+
|
180 |
+
healthy_data, healthy_target, data_names, raw_data = val_data['image'], val_data['label'], val_data['name'], val_data['raw_image']
|
181 |
+
case_name = data_names[0].split('/')[-1]
|
182 |
+
print('case_name', case_name)
|
183 |
+
original_affine = val_data["label_meta_dict"]["original_affine"][0].numpy()
|
184 |
+
if healthy_target.sum() == 0:
|
185 |
+
val_data = [post_transforms(i) for i in data.decollate_batch(val_data)]
|
186 |
+
tumor_mask = val_data[0]['label'][0].cpu().numpy().astype(np.uint8)
|
187 |
+
tumor_mask_ = np.zeros_like(tumor_mask)
|
188 |
+
nib.save(nib.Nifti1Image(tumor_mask_, original_affine), os.path.join(output_dir, f'{case_name}/segmentations/liver_tumor.nii.gz'))
|
189 |
+
continue
|
190 |
+
|
191 |
+
healthy_data, healthy_target = healthy_data.cuda(), healthy_target.cuda()
|
192 |
+
healthy_target = (healthy_target==1).to(healthy_target)
|
193 |
+
|
194 |
+
tumor_types = ['early', 'medium', 'large']
|
195 |
+
# tumor_probs = np.array([0.45, 0.45, 0.1])
|
196 |
+
# tumor_probs = np.array([1.0, 0.0, 0.0])
|
197 |
+
tumor_probs = np.array([0.5, 0.4, 0.1])
|
198 |
+
synthetic_tumor_type = np.random.choice(tumor_types, p=tumor_probs.ravel())
|
199 |
+
print('synthetic_tumor_type',synthetic_tumor_type)
|
200 |
+
flag=0
|
201 |
+
while 1:
|
202 |
+
if synthetic_tumor_type == 'early':
|
203 |
+
synt_data, synt_target = synthesize_early_tumor(healthy_data, healthy_target, args.organ_type, vqgan, early_sampler)
|
204 |
+
elif synthetic_tumor_type == 'medium':
|
205 |
+
synt_data, synt_target = synthesize_medium_tumor(healthy_data, healthy_target, args.organ_type, vqgan, noearly_sampler, ddim_ts=args.ddim_ts)
|
206 |
+
elif synthetic_tumor_type == 'large':
|
207 |
+
synt_data, synt_target = synthesize_large_tumor(healthy_data, healthy_target, args.organ_type, vqgan, noearly_sampler, ddim_ts=args.ddim_ts)
|
208 |
+
|
209 |
+
syn_confidence = model(synt_data).sigmoid()[:,1]
|
210 |
+
flag+=1
|
211 |
+
if syn_confidence>0.005:
|
212 |
+
break
|
213 |
+
elif flag > 20 and syn_confidence>0.001:
|
214 |
+
break
|
215 |
+
val_data['image'] = synt_data.detach()
|
216 |
+
val_data['label'] = synt_target.detach()
|
217 |
+
|
218 |
+
val_data = [post_transforms(i) for i in data.decollate_batch(val_data)]
|
219 |
+
synt_data = val_data[0]['image'][0]
|
220 |
+
synt_target = val_data[0]['label'][0]
|
221 |
+
final_data = raw_data[0,0]
|
222 |
+
|
223 |
+
synt_data = (synt_data*(250+175)-175)
|
224 |
+
final_data[synt_target>1] = synt_data[synt_target>1]
|
225 |
+
final_data = final_data.cpu().numpy()
|
226 |
+
final_label = (synt_target>=1.5).cpu().numpy().astype(np.uint8)
|
227 |
+
|
228 |
+
os.makedirs(os.path.join(output_dir, f'{case_name}'), exist_ok=True)
|
229 |
+
os.makedirs(os.path.join(output_dir, f'{case_name}/segmentations'), exist_ok=True)
|
230 |
+
nib.save(nib.Nifti1Image(final_data, original_affine), os.path.join(output_dir, f'{case_name}/ct.nii.gz'))
|
231 |
+
nib.save(nib.Nifti1Image(final_label, original_affine), os.path.join(output_dir, f'{case_name}/segmentations/liver_tumor.nii.gz'))
|
232 |
+
# breakpoint()
|
233 |
+
# nib.save(nib.Nifti1Image(synt_data.cpu().numpy(), original_affine), os.path.join(output_dir, 'synt_data.nii.gz'))
|
234 |
+
# nib.save(nib.Nifti1Image(synt_target.cpu().numpy(), original_affine), os.path.join(output_dir, 'synt_target.nii.gz'))
|
235 |
+
print('time = ', time.time()-start_time)
|
236 |
+
start_time = time.time()
|
237 |
+
|
238 |
+
# breakpoint()
|
239 |
+
if __name__ == "__main__":
|
240 |
+
main()
|
Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/README.md
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
```bash
|
2 |
+
wget https://www.dropbox.com/scl/fi/k856fhk60kck8uqxtxazw/model_weight.tar.gz?rlkey=hrcn4cbt690dzern1bkbfejll
|
3 |
+
mv model_weight.tar.gz?rlkey=hrcn4cbt690dzern1bkbfejll model_weight.tar.gz
|
4 |
+
tar -xzvf model_weight.tar.gz
|
5 |
+
```
|
Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/TumorGenerated.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
from typing import Hashable, Mapping, Dict
|
3 |
+
|
4 |
+
from monai.config import KeysCollection
|
5 |
+
from monai.config.type_definitions import NdarrayOrTensor
|
6 |
+
from monai.transforms.transform import MapTransform, RandomizableTransform
|
7 |
+
|
8 |
+
from .utils_ import SynthesisTumor
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
class TumorGenerated(RandomizableTransform, MapTransform):
|
12 |
+
def __init__(self,
|
13 |
+
keys: KeysCollection,
|
14 |
+
prob: float = 0.1,
|
15 |
+
tumor_prob = [0.2, 0.2, 0.2, 0.2, 0.2],
|
16 |
+
allow_missing_keys: bool = False
|
17 |
+
) -> None:
|
18 |
+
MapTransform.__init__(self, keys, allow_missing_keys)
|
19 |
+
RandomizableTransform.__init__(self, prob)
|
20 |
+
random.seed(0)
|
21 |
+
np.random.seed(0)
|
22 |
+
|
23 |
+
self.tumor_types = ['tiny', 'small', 'medium', 'large', 'mix']
|
24 |
+
|
25 |
+
assert len(tumor_prob) == 5
|
26 |
+
self.tumor_prob = np.array(tumor_prob)
|
27 |
+
|
28 |
+
|
29 |
+
|
30 |
+
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
|
31 |
+
d = dict(data)
|
32 |
+
self.randomize(None)
|
33 |
+
|
34 |
+
if self._do_transform and (np.max(d['label']) <= 1):
|
35 |
+
tumor_type = np.random.choice(self.tumor_types, p=self.tumor_prob.ravel())
|
36 |
+
|
37 |
+
d['image'][0], d['label'][0] = SynthesisTumor(d['image'][0], d['label'][0], tumor_type)
|
38 |
+
# print(tumor_type, d['image'].shape, np.max(d['label']))
|
39 |
+
return d
|
Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### Online Version TumorGeneration ###
|
2 |
+
|
3 |
+
from .TumorGenerated import TumorGenerated
|
4 |
+
|
5 |
+
# from .utils import synthesize_early_tumor, synthesize_medium_tumor, synthesize_large_tumor
|
Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/__pycache__/TumorGenerated.cpython-38.pyc
ADDED
Binary file (1.62 kB). View file
|
|
Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (301 Bytes). View file
|
|
Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (11.1 kB). View file
|
|
Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/__pycache__/utils_.cpython-38.pyc
ADDED
Binary file (7.2 kB). View file
|
|
Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/diffusion_config/ddpm.yaml
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
vqgan_ckpt: None
|
2 |
+
|
3 |
+
# Have to be derived from VQ-GAN Latent space dimensions
|
4 |
+
diffusion_img_size: 24
|
5 |
+
diffusion_depth_size: 24
|
6 |
+
diffusion_num_channels: 17 # 17
|
7 |
+
out_dim: 8
|
8 |
+
dim_mults: [1,2,4,8]
|
9 |
+
results_folder: checkpoints/ddpm/
|
10 |
+
results_folder_postfix: 'own_dataset_t2'
|
11 |
+
load_milestone: False # False
|
12 |
+
|
13 |
+
batch_size: 2 # 40
|
14 |
+
num_workers: 20
|
15 |
+
logger: wandb
|
16 |
+
objective: pred_x0
|
17 |
+
save_and_sample_every: 1000
|
18 |
+
denoising_fn: Unet3D
|
19 |
+
train_lr: 1e-4
|
20 |
+
timesteps: 2 # number of steps
|
21 |
+
sampling_timesteps: 250 # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper])
|
22 |
+
loss_type: l1 # L1 or L2
|
23 |
+
train_num_steps: 700000 # total training steps
|
24 |
+
gradient_accumulate_every: 2 # gradient accumulation steps
|
25 |
+
ema_decay: 0.995 # exponential moving average decay
|
26 |
+
amp: False # turn on mixed precision
|
27 |
+
num_sample_rows: 1
|
28 |
+
gpus: 0
|
29 |
+
|
Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/diffusion_config/vq_gan_3d.yaml
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
seed: 1234
|
2 |
+
batch_size: 2 # 30
|
3 |
+
num_workers: 32 # 30
|
4 |
+
|
5 |
+
gpus: 1
|
6 |
+
accumulate_grad_batches: 1
|
7 |
+
default_root_dir: checkpoints/vq_gan/
|
8 |
+
default_root_dir_postfix: 'flair'
|
9 |
+
resume_from_checkpoint:
|
10 |
+
max_steps: -1
|
11 |
+
max_epochs: -1
|
12 |
+
precision: 16
|
13 |
+
gradient_clip_val: 1.0
|
14 |
+
|
15 |
+
|
16 |
+
embedding_dim: 8 # 256
|
17 |
+
n_codes: 16384 # 2048
|
18 |
+
n_hiddens: 16
|
19 |
+
lr: 3e-4
|
20 |
+
downsample: [2, 2, 2] # [4, 4, 4]
|
21 |
+
disc_channels: 64
|
22 |
+
disc_layers: 3
|
23 |
+
discriminator_iter_start: 10000 # 50000
|
24 |
+
disc_loss_type: hinge
|
25 |
+
image_gan_weight: 1.0
|
26 |
+
video_gan_weight: 1.0
|
27 |
+
l1_weight: 4.0
|
28 |
+
gan_feat_weight: 4.0 # 0.0
|
29 |
+
perceptual_weight: 4.0 # 0.0
|
30 |
+
i3d_feat: False
|
31 |
+
restart_thres: 1.0
|
32 |
+
no_random_restart: False
|
33 |
+
norm_type: group
|
34 |
+
padding_type: replicate
|
35 |
+
num_groups: 32
|
36 |
+
|
37 |
+
|
Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/ldm/ddpm/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .diffusion import Unet3D, GaussianDiffusion, Tester
|
Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (239 Bytes). View file
|
|
Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/ddim.cpython-38.pyc
ADDED
Binary file (6.01 kB). View file
|
|
Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/diffusion.cpython-38.pyc
ADDED
Binary file (28.5 kB). View file
|
|
Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/text.cpython-38.pyc
ADDED
Binary file (1.86 kB). View file
|
|
Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/time_embedding.cpython-38.pyc
ADDED
Binary file (2.85 kB). View file
|
|
Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/unet.cpython-38.pyc
ADDED
Binary file (5.77 kB). View file
|
|
Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/util.cpython-38.pyc
ADDED
Binary file (9.42 kB). View file
|
|
Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/ldm/ddpm/ddim.py
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""SAMPLING ONLY."""
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
from tqdm import tqdm
|
6 |
+
from functools import partial
|
7 |
+
|
8 |
+
from .util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
|
9 |
+
|
10 |
+
|
11 |
+
class DDIMSampler(object):
|
12 |
+
def __init__(self, model, schedule="linear", **kwargs):
|
13 |
+
super().__init__()
|
14 |
+
self.model = model
|
15 |
+
self.ddpm_num_timesteps = model.num_timesteps
|
16 |
+
self.schedule = schedule
|
17 |
+
|
18 |
+
def register_buffer(self, name, attr):
|
19 |
+
if type(attr) == torch.Tensor:
|
20 |
+
if attr.device != torch.device("cuda"):
|
21 |
+
attr = attr.to(torch.device("cuda"))
|
22 |
+
setattr(self, name, attr)
|
23 |
+
|
24 |
+
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): # "uniform" 'quad'
|
25 |
+
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
26 |
+
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
|
27 |
+
|
28 |
+
alphas_cumprod = self.model.alphas_cumprod
|
29 |
+
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
30 |
+
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
31 |
+
|
32 |
+
self.register_buffer('betas', to_torch(self.model.betas))
|
33 |
+
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
34 |
+
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
|
35 |
+
|
36 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
37 |
+
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
|
38 |
+
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
39 |
+
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
|
40 |
+
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
41 |
+
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
42 |
+
# breakpoint()
|
43 |
+
# ddim sampling parameters
|
44 |
+
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
|
45 |
+
ddim_timesteps=self.ddim_timesteps,
|
46 |
+
eta=ddim_eta,verbose=verbose)
|
47 |
+
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
48 |
+
self.register_buffer('ddim_alphas', ddim_alphas)
|
49 |
+
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
50 |
+
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
|
51 |
+
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
52 |
+
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
|
53 |
+
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
54 |
+
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
55 |
+
|
56 |
+
@torch.no_grad()
|
57 |
+
def sample(self,
|
58 |
+
S,
|
59 |
+
batch_size,
|
60 |
+
shape,
|
61 |
+
conditioning=None,
|
62 |
+
callback=None,
|
63 |
+
normals_sequence=None,
|
64 |
+
img_callback=None,
|
65 |
+
quantize_x0=False,
|
66 |
+
eta=0.,
|
67 |
+
mask=None,
|
68 |
+
x0=None,
|
69 |
+
temperature=1.,
|
70 |
+
noise_dropout=0.,
|
71 |
+
score_corrector=None,
|
72 |
+
corrector_kwargs=None,
|
73 |
+
verbose=True,
|
74 |
+
x_T=None,
|
75 |
+
log_every_t=100,
|
76 |
+
unconditional_guidance_scale=1.,
|
77 |
+
unconditional_conditioning=None,
|
78 |
+
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
79 |
+
**kwargs
|
80 |
+
):
|
81 |
+
if conditioning is not None:
|
82 |
+
if isinstance(conditioning, dict):
|
83 |
+
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
84 |
+
if cbs != batch_size:
|
85 |
+
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
86 |
+
else:
|
87 |
+
if conditioning.shape[0] != batch_size:
|
88 |
+
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
89 |
+
|
90 |
+
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
91 |
+
# sampling
|
92 |
+
C, T, H, W = shape
|
93 |
+
# breakpoint()
|
94 |
+
size = (batch_size, C, T, H, W)
|
95 |
+
# print(f'Data shape for DDIM sampling is {size}, eta {eta}')
|
96 |
+
|
97 |
+
samples, intermediates = self.ddim_sampling(conditioning, size,
|
98 |
+
callback=callback,
|
99 |
+
img_callback=img_callback,
|
100 |
+
quantize_denoised=quantize_x0,
|
101 |
+
mask=mask, x0=x0,
|
102 |
+
ddim_use_original_steps=False,
|
103 |
+
noise_dropout=noise_dropout,
|
104 |
+
temperature=temperature,
|
105 |
+
score_corrector=score_corrector,
|
106 |
+
corrector_kwargs=corrector_kwargs,
|
107 |
+
x_T=x_T,
|
108 |
+
log_every_t=log_every_t,
|
109 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
110 |
+
unconditional_conditioning=unconditional_conditioning,
|
111 |
+
)
|
112 |
+
return samples, intermediates
|
113 |
+
|
114 |
+
@torch.no_grad()
|
115 |
+
def ddim_sampling(self, cond, shape,
|
116 |
+
x_T=None, ddim_use_original_steps=False,
|
117 |
+
callback=None, timesteps=None, quantize_denoised=False,
|
118 |
+
mask=None, x0=None, img_callback=None, log_every_t=100,
|
119 |
+
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
120 |
+
unconditional_guidance_scale=1., unconditional_conditioning=None,):
|
121 |
+
device = self.model.betas.device
|
122 |
+
b = shape[0]
|
123 |
+
if x_T is None:
|
124 |
+
img = torch.randn(shape, device=device)
|
125 |
+
else:
|
126 |
+
img = x_T
|
127 |
+
|
128 |
+
if timesteps is None:
|
129 |
+
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
130 |
+
elif timesteps is not None and not ddim_use_original_steps:
|
131 |
+
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
|
132 |
+
timesteps = self.ddim_timesteps[:subset_end]
|
133 |
+
|
134 |
+
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
135 |
+
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
|
136 |
+
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
137 |
+
# print(f"Running DDIM Sampling with {total_steps} timesteps")
|
138 |
+
|
139 |
+
# iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
|
140 |
+
|
141 |
+
for i, step in enumerate(time_range):
|
142 |
+
index = total_steps - i - 1
|
143 |
+
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
144 |
+
|
145 |
+
if mask is not None:
|
146 |
+
assert x0 is not None
|
147 |
+
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
|
148 |
+
img = img_orig * mask + (1. - mask) * img
|
149 |
+
|
150 |
+
outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
|
151 |
+
quantize_denoised=quantize_denoised, temperature=temperature,
|
152 |
+
noise_dropout=noise_dropout, score_corrector=score_corrector,
|
153 |
+
corrector_kwargs=corrector_kwargs,
|
154 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
155 |
+
unconditional_conditioning=unconditional_conditioning)
|
156 |
+
img, pred_x0 = outs
|
157 |
+
if callback: callback(i)
|
158 |
+
if img_callback: img_callback(pred_x0, i)
|
159 |
+
|
160 |
+
if index % log_every_t == 0 or index == total_steps - 1:
|
161 |
+
intermediates['x_inter'].append(img)
|
162 |
+
intermediates['pred_x0'].append(pred_x0)
|
163 |
+
|
164 |
+
return img, intermediates
|
165 |
+
|
166 |
+
@torch.no_grad()
|
167 |
+
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
168 |
+
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
169 |
+
unconditional_guidance_scale=1., unconditional_conditioning=None):
|
170 |
+
b, *_, device = *x.shape, x.device
|
171 |
+
|
172 |
+
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
173 |
+
# breakpoint()
|
174 |
+
e_t = self.model.denoise_fn(x, t, c)
|
175 |
+
else:
|
176 |
+
x_in = torch.cat([x] * 2)
|
177 |
+
t_in = torch.cat([t] * 2)
|
178 |
+
c_in = torch.cat([unconditional_conditioning, c])
|
179 |
+
e_t_uncond, e_t = self.model.denoise_fn(x_in, t_in, c_in).chunk(2)
|
180 |
+
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
181 |
+
|
182 |
+
if score_corrector is not None:
|
183 |
+
assert self.model.parameterization == "eps"
|
184 |
+
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
185 |
+
|
186 |
+
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
187 |
+
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
188 |
+
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
189 |
+
sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
190 |
+
# select parameters corresponding to the currently considered timestep
|
191 |
+
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
192 |
+
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
193 |
+
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
194 |
+
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
195 |
+
|
196 |
+
# current prediction for x_0
|
197 |
+
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
198 |
+
if quantize_denoised:
|
199 |
+
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
200 |
+
# direction pointing to x_t
|
201 |
+
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
202 |
+
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
203 |
+
if noise_dropout > 0.:
|
204 |
+
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
205 |
+
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
206 |
+
return x_prev, pred_x0
|
Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/ldm/ddpm/diffusion.py
ADDED
@@ -0,0 +1,1016 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"Largely taken and adapted from https://github.com/lucidrains/video-diffusion-pytorch"
|
2 |
+
|
3 |
+
import math
|
4 |
+
import copy
|
5 |
+
import torch
|
6 |
+
from torch import nn, einsum
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from functools import partial
|
9 |
+
|
10 |
+
from torch.utils import data
|
11 |
+
from pathlib import Path
|
12 |
+
from torch.optim import Adam
|
13 |
+
from torchvision import transforms as T, utils
|
14 |
+
from torch.cuda.amp import autocast, GradScaler
|
15 |
+
from PIL import Image
|
16 |
+
|
17 |
+
from tqdm import tqdm
|
18 |
+
from einops import rearrange
|
19 |
+
from einops_exts import check_shape, rearrange_many
|
20 |
+
|
21 |
+
from rotary_embedding_torch import RotaryEmbedding
|
22 |
+
|
23 |
+
from .text import tokenize, bert_embed, BERT_MODEL_DIM
|
24 |
+
from torch.utils.data import Dataset, DataLoader
|
25 |
+
from ..vq_gan_3d.model.vqgan import VQGAN
|
26 |
+
|
27 |
+
import matplotlib.pyplot as plt
|
28 |
+
|
29 |
+
# helpers functions
|
30 |
+
|
31 |
+
|
32 |
+
def exists(x):
|
33 |
+
return x is not None
|
34 |
+
|
35 |
+
|
36 |
+
def noop(*args, **kwargs):
|
37 |
+
pass
|
38 |
+
|
39 |
+
|
40 |
+
def is_odd(n):
|
41 |
+
return (n % 2) == 1
|
42 |
+
|
43 |
+
|
44 |
+
def default(val, d):
|
45 |
+
if exists(val):
|
46 |
+
return val
|
47 |
+
return d() if callable(d) else d
|
48 |
+
|
49 |
+
|
50 |
+
def cycle(dl):
|
51 |
+
while True:
|
52 |
+
for data in dl:
|
53 |
+
yield data
|
54 |
+
|
55 |
+
|
56 |
+
def num_to_groups(num, divisor):
|
57 |
+
groups = num // divisor
|
58 |
+
remainder = num % divisor
|
59 |
+
arr = [divisor] * groups
|
60 |
+
if remainder > 0:
|
61 |
+
arr.append(remainder)
|
62 |
+
return arr
|
63 |
+
|
64 |
+
|
65 |
+
def prob_mask_like(shape, prob, device):
|
66 |
+
if prob == 1:
|
67 |
+
return torch.ones(shape, device=device, dtype=torch.bool)
|
68 |
+
elif prob == 0:
|
69 |
+
return torch.zeros(shape, device=device, dtype=torch.bool)
|
70 |
+
else:
|
71 |
+
return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob
|
72 |
+
|
73 |
+
|
74 |
+
def is_list_str(x):
|
75 |
+
if not isinstance(x, (list, tuple)):
|
76 |
+
return False
|
77 |
+
return all([type(el) == str for el in x])
|
78 |
+
|
79 |
+
# relative positional bias
|
80 |
+
|
81 |
+
|
82 |
+
class RelativePositionBias(nn.Module):
|
83 |
+
def __init__(
|
84 |
+
self,
|
85 |
+
heads=8,
|
86 |
+
num_buckets=32,
|
87 |
+
max_distance=128
|
88 |
+
):
|
89 |
+
super().__init__()
|
90 |
+
self.num_buckets = num_buckets
|
91 |
+
self.max_distance = max_distance
|
92 |
+
self.relative_attention_bias = nn.Embedding(num_buckets, heads)
|
93 |
+
|
94 |
+
@staticmethod
|
95 |
+
def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
|
96 |
+
ret = 0
|
97 |
+
n = -relative_position
|
98 |
+
|
99 |
+
num_buckets //= 2
|
100 |
+
ret += (n < 0).long() * num_buckets
|
101 |
+
n = torch.abs(n)
|
102 |
+
|
103 |
+
max_exact = num_buckets // 2
|
104 |
+
is_small = n < max_exact
|
105 |
+
|
106 |
+
val_if_large = max_exact + (
|
107 |
+
torch.log(n.float() / max_exact) / math.log(max_distance /
|
108 |
+
max_exact) * (num_buckets - max_exact)
|
109 |
+
).long()
|
110 |
+
val_if_large = torch.min(
|
111 |
+
val_if_large, torch.full_like(val_if_large, num_buckets - 1))
|
112 |
+
|
113 |
+
ret += torch.where(is_small, n, val_if_large)
|
114 |
+
return ret
|
115 |
+
|
116 |
+
def forward(self, n, device):
|
117 |
+
q_pos = torch.arange(n, dtype=torch.long, device=device)
|
118 |
+
k_pos = torch.arange(n, dtype=torch.long, device=device)
|
119 |
+
rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
|
120 |
+
rp_bucket = self._relative_position_bucket(
|
121 |
+
rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance)
|
122 |
+
values = self.relative_attention_bias(rp_bucket)
|
123 |
+
return rearrange(values, 'i j h -> h i j')
|
124 |
+
|
125 |
+
# small helper modules
|
126 |
+
|
127 |
+
|
128 |
+
class EMA():
|
129 |
+
def __init__(self, beta):
|
130 |
+
super().__init__()
|
131 |
+
self.beta = beta
|
132 |
+
|
133 |
+
def update_model_average(self, ma_model, current_model):
|
134 |
+
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
|
135 |
+
old_weight, up_weight = ma_params.data, current_params.data
|
136 |
+
ma_params.data = self.update_average(old_weight, up_weight)
|
137 |
+
|
138 |
+
def update_average(self, old, new):
|
139 |
+
if old is None:
|
140 |
+
return new
|
141 |
+
return old * self.beta + (1 - self.beta) * new
|
142 |
+
|
143 |
+
|
144 |
+
class Residual(nn.Module):
|
145 |
+
def __init__(self, fn):
|
146 |
+
super().__init__()
|
147 |
+
self.fn = fn
|
148 |
+
|
149 |
+
def forward(self, x, *args, **kwargs):
|
150 |
+
return self.fn(x, *args, **kwargs) + x
|
151 |
+
|
152 |
+
|
153 |
+
class SinusoidalPosEmb(nn.Module):
|
154 |
+
def __init__(self, dim):
|
155 |
+
super().__init__()
|
156 |
+
self.dim = dim
|
157 |
+
|
158 |
+
def forward(self, x):
|
159 |
+
device = x.device
|
160 |
+
half_dim = self.dim // 2
|
161 |
+
emb = math.log(10000) / (half_dim - 1)
|
162 |
+
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
163 |
+
emb = x[:, None] * emb[None, :]
|
164 |
+
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
165 |
+
return emb
|
166 |
+
|
167 |
+
|
168 |
+
def Upsample(dim):
|
169 |
+
return nn.ConvTranspose3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1))
|
170 |
+
|
171 |
+
|
172 |
+
def Downsample(dim):
|
173 |
+
return nn.Conv3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1))
|
174 |
+
|
175 |
+
|
176 |
+
class LayerNorm(nn.Module):
|
177 |
+
def __init__(self, dim, eps=1e-5):
|
178 |
+
super().__init__()
|
179 |
+
self.eps = eps
|
180 |
+
self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1, 1))
|
181 |
+
|
182 |
+
def forward(self, x):
|
183 |
+
var = torch.var(x, dim=1, unbiased=False, keepdim=True)
|
184 |
+
mean = torch.mean(x, dim=1, keepdim=True)
|
185 |
+
return (x - mean) / (var + self.eps).sqrt() * self.gamma
|
186 |
+
|
187 |
+
|
188 |
+
class PreNorm(nn.Module):
|
189 |
+
def __init__(self, dim, fn):
|
190 |
+
super().__init__()
|
191 |
+
self.fn = fn
|
192 |
+
self.norm = LayerNorm(dim)
|
193 |
+
|
194 |
+
def forward(self, x, **kwargs):
|
195 |
+
x = self.norm(x)
|
196 |
+
return self.fn(x, **kwargs)
|
197 |
+
|
198 |
+
# building block modules
|
199 |
+
|
200 |
+
|
201 |
+
class Block(nn.Module):
|
202 |
+
def __init__(self, dim, dim_out, groups=8):
|
203 |
+
super().__init__()
|
204 |
+
self.proj = nn.Conv3d(dim, dim_out, (1, 3, 3), padding=(0, 1, 1))
|
205 |
+
self.norm = nn.GroupNorm(groups, dim_out)
|
206 |
+
self.act = nn.SiLU()
|
207 |
+
|
208 |
+
def forward(self, x, scale_shift=None):
|
209 |
+
x = self.proj(x)
|
210 |
+
x = self.norm(x)
|
211 |
+
|
212 |
+
if exists(scale_shift):
|
213 |
+
scale, shift = scale_shift
|
214 |
+
x = x * (scale + 1) + shift
|
215 |
+
|
216 |
+
return self.act(x)
|
217 |
+
|
218 |
+
|
219 |
+
class ResnetBlock(nn.Module):
|
220 |
+
def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
|
221 |
+
super().__init__()
|
222 |
+
self.mlp = nn.Sequential(
|
223 |
+
nn.SiLU(),
|
224 |
+
nn.Linear(time_emb_dim, dim_out * 2)
|
225 |
+
) if exists(time_emb_dim) else None
|
226 |
+
|
227 |
+
self.block1 = Block(dim, dim_out, groups=groups)
|
228 |
+
self.block2 = Block(dim_out, dim_out, groups=groups)
|
229 |
+
self.res_conv = nn.Conv3d(
|
230 |
+
dim, dim_out, 1) if dim != dim_out else nn.Identity()
|
231 |
+
|
232 |
+
def forward(self, x, time_emb=None):
|
233 |
+
|
234 |
+
scale_shift = None
|
235 |
+
if exists(self.mlp):
|
236 |
+
assert exists(time_emb), 'time emb must be passed in'
|
237 |
+
time_emb = self.mlp(time_emb)
|
238 |
+
time_emb = rearrange(time_emb, 'b c -> b c 1 1 1')
|
239 |
+
scale_shift = time_emb.chunk(2, dim=1)
|
240 |
+
|
241 |
+
h = self.block1(x, scale_shift=scale_shift)
|
242 |
+
|
243 |
+
h = self.block2(h)
|
244 |
+
return h + self.res_conv(x)
|
245 |
+
|
246 |
+
|
247 |
+
class SpatialLinearAttention(nn.Module):
|
248 |
+
def __init__(self, dim, heads=4, dim_head=32):
|
249 |
+
super().__init__()
|
250 |
+
self.scale = dim_head ** -0.5
|
251 |
+
self.heads = heads
|
252 |
+
hidden_dim = dim_head * heads
|
253 |
+
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
|
254 |
+
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
255 |
+
|
256 |
+
def forward(self, x):
|
257 |
+
b, c, f, h, w = x.shape
|
258 |
+
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
259 |
+
|
260 |
+
qkv = self.to_qkv(x).chunk(3, dim=1)
|
261 |
+
q, k, v = rearrange_many(
|
262 |
+
qkv, 'b (h c) x y -> b h c (x y)', h=self.heads)
|
263 |
+
|
264 |
+
q = q.softmax(dim=-2)
|
265 |
+
k = k.softmax(dim=-1)
|
266 |
+
|
267 |
+
q = q * self.scale
|
268 |
+
context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
|
269 |
+
|
270 |
+
out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
|
271 |
+
out = rearrange(out, 'b h c (x y) -> b (h c) x y',
|
272 |
+
h=self.heads, x=h, y=w)
|
273 |
+
out = self.to_out(out)
|
274 |
+
return rearrange(out, '(b f) c h w -> b c f h w', b=b)
|
275 |
+
|
276 |
+
# attention along space and time
|
277 |
+
|
278 |
+
|
279 |
+
class EinopsToAndFrom(nn.Module):
|
280 |
+
def __init__(self, from_einops, to_einops, fn):
|
281 |
+
super().__init__()
|
282 |
+
self.from_einops = from_einops
|
283 |
+
self.to_einops = to_einops
|
284 |
+
self.fn = fn
|
285 |
+
|
286 |
+
def forward(self, x, **kwargs):
|
287 |
+
shape = x.shape
|
288 |
+
reconstitute_kwargs = dict(
|
289 |
+
tuple(zip(self.from_einops.split(' '), shape)))
|
290 |
+
x = rearrange(x, f'{self.from_einops} -> {self.to_einops}')
|
291 |
+
x = self.fn(x, **kwargs)
|
292 |
+
x = rearrange(
|
293 |
+
x, f'{self.to_einops} -> {self.from_einops}', **reconstitute_kwargs)
|
294 |
+
return x
|
295 |
+
|
296 |
+
|
297 |
+
class Attention(nn.Module):
|
298 |
+
def __init__(
|
299 |
+
self,
|
300 |
+
dim,
|
301 |
+
heads=4,
|
302 |
+
dim_head=32,
|
303 |
+
rotary_emb=None
|
304 |
+
):
|
305 |
+
super().__init__()
|
306 |
+
self.scale = dim_head ** -0.5
|
307 |
+
self.heads = heads
|
308 |
+
hidden_dim = dim_head * heads
|
309 |
+
|
310 |
+
self.rotary_emb = rotary_emb
|
311 |
+
self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias=False)
|
312 |
+
self.to_out = nn.Linear(hidden_dim, dim, bias=False)
|
313 |
+
|
314 |
+
def forward(
|
315 |
+
self,
|
316 |
+
x,
|
317 |
+
pos_bias=None,
|
318 |
+
focus_present_mask=None
|
319 |
+
):
|
320 |
+
n, device = x.shape[-2], x.device
|
321 |
+
|
322 |
+
qkv = self.to_qkv(x).chunk(3, dim=-1)
|
323 |
+
|
324 |
+
if exists(focus_present_mask) and focus_present_mask.all():
|
325 |
+
# if all batch samples are focusing on present
|
326 |
+
# it would be equivalent to passing that token's values through to the output
|
327 |
+
values = qkv[-1]
|
328 |
+
return self.to_out(values)
|
329 |
+
|
330 |
+
# split out heads
|
331 |
+
|
332 |
+
q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h=self.heads)
|
333 |
+
|
334 |
+
# scale
|
335 |
+
|
336 |
+
q = q * self.scale
|
337 |
+
|
338 |
+
# rotate positions into queries and keys for time attention
|
339 |
+
|
340 |
+
if exists(self.rotary_emb):
|
341 |
+
q = self.rotary_emb.rotate_queries_or_keys(q)
|
342 |
+
k = self.rotary_emb.rotate_queries_or_keys(k)
|
343 |
+
|
344 |
+
# similarity
|
345 |
+
|
346 |
+
sim = einsum('... h i d, ... h j d -> ... h i j', q, k)
|
347 |
+
|
348 |
+
# relative positional bias
|
349 |
+
|
350 |
+
if exists(pos_bias):
|
351 |
+
sim = sim + pos_bias
|
352 |
+
|
353 |
+
if exists(focus_present_mask) and not (~focus_present_mask).all():
|
354 |
+
attend_all_mask = torch.ones(
|
355 |
+
(n, n), device=device, dtype=torch.bool)
|
356 |
+
attend_self_mask = torch.eye(n, device=device, dtype=torch.bool)
|
357 |
+
|
358 |
+
mask = torch.where(
|
359 |
+
rearrange(focus_present_mask, 'b -> b 1 1 1 1'),
|
360 |
+
rearrange(attend_self_mask, 'i j -> 1 1 1 i j'),
|
361 |
+
rearrange(attend_all_mask, 'i j -> 1 1 1 i j'),
|
362 |
+
)
|
363 |
+
|
364 |
+
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
|
365 |
+
|
366 |
+
# numerical stability
|
367 |
+
|
368 |
+
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
|
369 |
+
attn = sim.softmax(dim=-1)
|
370 |
+
|
371 |
+
# aggregate values
|
372 |
+
|
373 |
+
out = einsum('... h i j, ... h j d -> ... h i d', attn, v)
|
374 |
+
out = rearrange(out, '... h n d -> ... n (h d)')
|
375 |
+
return self.to_out(out)
|
376 |
+
|
377 |
+
# model
|
378 |
+
|
379 |
+
|
380 |
+
class Unet3D(nn.Module):
|
381 |
+
def __init__(
|
382 |
+
self,
|
383 |
+
dim,
|
384 |
+
cond_dim=None,
|
385 |
+
out_dim=None,
|
386 |
+
dim_mults=(1, 2, 4, 8),
|
387 |
+
channels=3,
|
388 |
+
attn_heads=8,
|
389 |
+
attn_dim_head=32,
|
390 |
+
use_bert_text_cond=False,
|
391 |
+
init_dim=None,
|
392 |
+
init_kernel_size=7,
|
393 |
+
use_sparse_linear_attn=True,
|
394 |
+
block_type='resnet',
|
395 |
+
resnet_groups=8
|
396 |
+
):
|
397 |
+
super().__init__()
|
398 |
+
self.channels = channels
|
399 |
+
|
400 |
+
# temporal attention and its relative positional encoding
|
401 |
+
|
402 |
+
rotary_emb = RotaryEmbedding(min(32, attn_dim_head))
|
403 |
+
|
404 |
+
def temporal_attn(dim): return EinopsToAndFrom('b c f h w', 'b (h w) f c', Attention(
|
405 |
+
dim, heads=attn_heads, dim_head=attn_dim_head, rotary_emb=rotary_emb))
|
406 |
+
|
407 |
+
# realistically will not be able to generate that many frames of video... yet
|
408 |
+
self.time_rel_pos_bias = RelativePositionBias(
|
409 |
+
heads=attn_heads, max_distance=32)
|
410 |
+
|
411 |
+
# initial conv
|
412 |
+
|
413 |
+
init_dim = default(init_dim, dim)
|
414 |
+
assert is_odd(init_kernel_size)
|
415 |
+
|
416 |
+
init_padding = init_kernel_size // 2
|
417 |
+
self.init_conv = nn.Conv3d(channels, init_dim, (1, init_kernel_size,
|
418 |
+
init_kernel_size), padding=(0, init_padding, init_padding))
|
419 |
+
|
420 |
+
self.init_temporal_attn = Residual(
|
421 |
+
PreNorm(init_dim, temporal_attn(init_dim)))
|
422 |
+
|
423 |
+
# dimensions
|
424 |
+
|
425 |
+
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
|
426 |
+
in_out = list(zip(dims[:-1], dims[1:]))
|
427 |
+
|
428 |
+
# time conditioning
|
429 |
+
|
430 |
+
time_dim = dim * 4
|
431 |
+
self.time_mlp = nn.Sequential(
|
432 |
+
SinusoidalPosEmb(dim),
|
433 |
+
nn.Linear(dim, time_dim),
|
434 |
+
nn.GELU(),
|
435 |
+
nn.Linear(time_dim, time_dim)
|
436 |
+
)
|
437 |
+
|
438 |
+
# text conditioning
|
439 |
+
|
440 |
+
self.has_cond = exists(cond_dim) or use_bert_text_cond
|
441 |
+
cond_dim = BERT_MODEL_DIM if use_bert_text_cond else cond_dim
|
442 |
+
|
443 |
+
self.null_cond_emb = nn.Parameter(
|
444 |
+
torch.randn(1, cond_dim)) if self.has_cond else None
|
445 |
+
|
446 |
+
cond_dim = time_dim + int(cond_dim or 0)
|
447 |
+
|
448 |
+
# layers
|
449 |
+
|
450 |
+
self.downs = nn.ModuleList([])
|
451 |
+
self.ups = nn.ModuleList([])
|
452 |
+
|
453 |
+
num_resolutions = len(in_out)
|
454 |
+
# block type
|
455 |
+
|
456 |
+
block_klass = partial(ResnetBlock, groups=resnet_groups)
|
457 |
+
block_klass_cond = partial(block_klass, time_emb_dim=cond_dim)
|
458 |
+
|
459 |
+
# modules for all layers
|
460 |
+
for ind, (dim_in, dim_out) in enumerate(in_out):
|
461 |
+
is_last = ind >= (num_resolutions - 1)
|
462 |
+
|
463 |
+
self.downs.append(nn.ModuleList([
|
464 |
+
block_klass_cond(dim_in, dim_out),
|
465 |
+
block_klass_cond(dim_out, dim_out),
|
466 |
+
Residual(PreNorm(dim_out, SpatialLinearAttention(
|
467 |
+
dim_out, heads=attn_heads))) if use_sparse_linear_attn else nn.Identity(),
|
468 |
+
Residual(PreNorm(dim_out, temporal_attn(dim_out))),
|
469 |
+
Downsample(dim_out) if not is_last else nn.Identity()
|
470 |
+
]))
|
471 |
+
|
472 |
+
mid_dim = dims[-1]
|
473 |
+
self.mid_block1 = block_klass_cond(mid_dim, mid_dim)
|
474 |
+
|
475 |
+
spatial_attn = EinopsToAndFrom(
|
476 |
+
'b c f h w', 'b f (h w) c', Attention(mid_dim, heads=attn_heads))
|
477 |
+
|
478 |
+
self.mid_spatial_attn = Residual(PreNorm(mid_dim, spatial_attn))
|
479 |
+
self.mid_temporal_attn = Residual(
|
480 |
+
PreNorm(mid_dim, temporal_attn(mid_dim)))
|
481 |
+
|
482 |
+
self.mid_block2 = block_klass_cond(mid_dim, mid_dim)
|
483 |
+
|
484 |
+
for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
|
485 |
+
is_last = ind >= (num_resolutions - 1)
|
486 |
+
|
487 |
+
self.ups.append(nn.ModuleList([
|
488 |
+
block_klass_cond(dim_out * 2, dim_in),
|
489 |
+
block_klass_cond(dim_in, dim_in),
|
490 |
+
Residual(PreNorm(dim_in, SpatialLinearAttention(
|
491 |
+
dim_in, heads=attn_heads))) if use_sparse_linear_attn else nn.Identity(),
|
492 |
+
Residual(PreNorm(dim_in, temporal_attn(dim_in))),
|
493 |
+
Upsample(dim_in) if not is_last else nn.Identity()
|
494 |
+
]))
|
495 |
+
|
496 |
+
out_dim = default(out_dim, channels)
|
497 |
+
self.final_conv = nn.Sequential(
|
498 |
+
block_klass(dim * 2, dim),
|
499 |
+
nn.Conv3d(dim, out_dim, 1)
|
500 |
+
)
|
501 |
+
|
502 |
+
def forward_with_cond_scale(
|
503 |
+
self,
|
504 |
+
*args,
|
505 |
+
cond_scale=2.,
|
506 |
+
**kwargs
|
507 |
+
):
|
508 |
+
logits = self.forward(*args, null_cond_prob=0., **kwargs)
|
509 |
+
if cond_scale == 1 or not self.has_cond:
|
510 |
+
return logits
|
511 |
+
|
512 |
+
null_logits = self.forward(*args, null_cond_prob=1., **kwargs)
|
513 |
+
return null_logits + (logits - null_logits) * cond_scale
|
514 |
+
|
515 |
+
def forward(
|
516 |
+
self,
|
517 |
+
x,
|
518 |
+
time,
|
519 |
+
cond=None,
|
520 |
+
null_cond_prob=0.,
|
521 |
+
focus_present_mask=None,
|
522 |
+
# probability at which a given batch sample will focus on the present (0. is all off, 1. is completely arrested attention across time)
|
523 |
+
prob_focus_present=0.
|
524 |
+
):
|
525 |
+
assert not (self.has_cond and not exists(cond)
|
526 |
+
), 'cond must be passed in if cond_dim specified'
|
527 |
+
x = torch.cat([x, cond], dim=1)
|
528 |
+
|
529 |
+
batch, device = x.shape[0], x.device
|
530 |
+
|
531 |
+
focus_present_mask = default(focus_present_mask, lambda: prob_mask_like(
|
532 |
+
(batch,), prob_focus_present, device=device))
|
533 |
+
|
534 |
+
time_rel_pos_bias = self.time_rel_pos_bias(x.shape[2], device=x.device)
|
535 |
+
|
536 |
+
x = self.init_conv(x)
|
537 |
+
r = x.clone()
|
538 |
+
|
539 |
+
x = self.init_temporal_attn(x, pos_bias=time_rel_pos_bias)
|
540 |
+
|
541 |
+
t = self.time_mlp(time) if exists(self.time_mlp) else None # [2, 128]
|
542 |
+
|
543 |
+
# classifier free guidance
|
544 |
+
|
545 |
+
if self.has_cond:
|
546 |
+
batch, device = x.shape[0], x.device
|
547 |
+
mask = prob_mask_like((batch,), null_cond_prob, device=device)
|
548 |
+
cond = torch.where(rearrange(mask, 'b -> b 1'),
|
549 |
+
self.null_cond_emb, cond)
|
550 |
+
t = torch.cat((t, cond), dim=-1)
|
551 |
+
|
552 |
+
h = []
|
553 |
+
|
554 |
+
for block1, block2, spatial_attn, temporal_attn, downsample in self.downs:
|
555 |
+
x = block1(x, t)
|
556 |
+
x = block2(x, t)
|
557 |
+
x = spatial_attn(x)
|
558 |
+
x = temporal_attn(x, pos_bias=time_rel_pos_bias,
|
559 |
+
focus_present_mask=focus_present_mask)
|
560 |
+
h.append(x)
|
561 |
+
x = downsample(x)
|
562 |
+
|
563 |
+
# [2, 256, 32, 4, 4]
|
564 |
+
x = self.mid_block1(x, t)
|
565 |
+
x = self.mid_spatial_attn(x)
|
566 |
+
x = self.mid_temporal_attn(
|
567 |
+
x, pos_bias=time_rel_pos_bias, focus_present_mask=focus_present_mask)
|
568 |
+
x = self.mid_block2(x, t)
|
569 |
+
|
570 |
+
for block1, block2, spatial_attn, temporal_attn, upsample in self.ups:
|
571 |
+
x = torch.cat((x, h.pop()), dim=1)
|
572 |
+
x = block1(x, t)
|
573 |
+
x = block2(x, t)
|
574 |
+
x = spatial_attn(x)
|
575 |
+
x = temporal_attn(x, pos_bias=time_rel_pos_bias,
|
576 |
+
focus_present_mask=focus_present_mask)
|
577 |
+
x = upsample(x)
|
578 |
+
|
579 |
+
x = torch.cat((x, r), dim=1)
|
580 |
+
return self.final_conv(x)
|
581 |
+
|
582 |
+
# gaussian diffusion trainer class
|
583 |
+
|
584 |
+
|
585 |
+
def extract(a, t, x_shape):
|
586 |
+
b, *_ = t.shape
|
587 |
+
out = a.gather(-1, t)
|
588 |
+
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
589 |
+
|
590 |
+
|
591 |
+
def cosine_beta_schedule(timesteps, s=0.008):
|
592 |
+
"""
|
593 |
+
cosine schedule
|
594 |
+
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
|
595 |
+
"""
|
596 |
+
steps = timesteps + 1
|
597 |
+
x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
|
598 |
+
alphas_cumprod = torch.cos(
|
599 |
+
((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
|
600 |
+
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
601 |
+
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
602 |
+
return torch.clip(betas, 0, 0.9999)
|
603 |
+
|
604 |
+
|
605 |
+
class GaussianDiffusion(nn.Module):
|
606 |
+
def __init__(
|
607 |
+
self,
|
608 |
+
denoise_fn,
|
609 |
+
*,
|
610 |
+
image_size,
|
611 |
+
num_frames,
|
612 |
+
text_use_bert_cls=False,
|
613 |
+
channels=3,
|
614 |
+
timesteps=1000,
|
615 |
+
loss_type='l1',
|
616 |
+
use_dynamic_thres=False, # from the Imagen paper
|
617 |
+
dynamic_thres_percentile=0.9,
|
618 |
+
vqgan_ckpt=None,
|
619 |
+
device=None
|
620 |
+
):
|
621 |
+
super().__init__()
|
622 |
+
self.channels = channels
|
623 |
+
self.image_size = image_size
|
624 |
+
self.num_frames = num_frames
|
625 |
+
self.denoise_fn = denoise_fn
|
626 |
+
self.device = device
|
627 |
+
|
628 |
+
if vqgan_ckpt:
|
629 |
+
self.vqgan = VQGAN.load_from_checkpoint(vqgan_ckpt).cuda()
|
630 |
+
self.vqgan.eval()
|
631 |
+
else:
|
632 |
+
self.vqgan = None
|
633 |
+
|
634 |
+
betas = cosine_beta_schedule(timesteps)
|
635 |
+
|
636 |
+
alphas = 1. - betas
|
637 |
+
alphas_cumprod = torch.cumprod(alphas, axis=0)
|
638 |
+
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)
|
639 |
+
|
640 |
+
timesteps, = betas.shape
|
641 |
+
self.num_timesteps = int(timesteps)
|
642 |
+
self.loss_type = loss_type
|
643 |
+
|
644 |
+
# register buffer helper function that casts float64 to float32
|
645 |
+
|
646 |
+
def register_buffer(name, val): return self.register_buffer(
|
647 |
+
name, val.to(torch.float32))
|
648 |
+
|
649 |
+
register_buffer('betas', betas)
|
650 |
+
register_buffer('alphas_cumprod', alphas_cumprod)
|
651 |
+
register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
|
652 |
+
|
653 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
654 |
+
|
655 |
+
register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
|
656 |
+
register_buffer('sqrt_one_minus_alphas_cumprod',
|
657 |
+
torch.sqrt(1. - alphas_cumprod))
|
658 |
+
register_buffer('log_one_minus_alphas_cumprod',
|
659 |
+
torch.log(1. - alphas_cumprod))
|
660 |
+
register_buffer('sqrt_recip_alphas_cumprod',
|
661 |
+
torch.sqrt(1. / alphas_cumprod))
|
662 |
+
register_buffer('sqrt_recipm1_alphas_cumprod',
|
663 |
+
torch.sqrt(1. / alphas_cumprod - 1))
|
664 |
+
|
665 |
+
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
666 |
+
|
667 |
+
posterior_variance = betas * \
|
668 |
+
(1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
|
669 |
+
|
670 |
+
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
671 |
+
|
672 |
+
register_buffer('posterior_variance', posterior_variance)
|
673 |
+
|
674 |
+
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
675 |
+
|
676 |
+
register_buffer('posterior_log_variance_clipped',
|
677 |
+
torch.log(posterior_variance.clamp(min=1e-20)))
|
678 |
+
register_buffer('posterior_mean_coef1', betas *
|
679 |
+
torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
680 |
+
register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev)
|
681 |
+
* torch.sqrt(alphas) / (1. - alphas_cumprod))
|
682 |
+
|
683 |
+
# text conditioning parameters
|
684 |
+
|
685 |
+
self.text_use_bert_cls = text_use_bert_cls
|
686 |
+
|
687 |
+
# dynamic thresholding when sampling
|
688 |
+
|
689 |
+
self.use_dynamic_thres = use_dynamic_thres
|
690 |
+
self.dynamic_thres_percentile = dynamic_thres_percentile
|
691 |
+
|
692 |
+
def q_mean_variance(self, x_start, t):
|
693 |
+
mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
694 |
+
variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
|
695 |
+
log_variance = extract(
|
696 |
+
self.log_one_minus_alphas_cumprod, t, x_start.shape)
|
697 |
+
return mean, variance, log_variance
|
698 |
+
|
699 |
+
def predict_start_from_noise(self, x_t, t, noise):
|
700 |
+
return (
|
701 |
+
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
|
702 |
+
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
|
703 |
+
)
|
704 |
+
|
705 |
+
def q_posterior(self, x_start, x_t, t):
|
706 |
+
posterior_mean = (
|
707 |
+
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
|
708 |
+
extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
709 |
+
)
|
710 |
+
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
|
711 |
+
posterior_log_variance_clipped = extract(
|
712 |
+
self.posterior_log_variance_clipped, t, x_t.shape)
|
713 |
+
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
714 |
+
|
715 |
+
def p_mean_variance(self, x, t, clip_denoised: bool, cond=None, cond_scale=1.):
|
716 |
+
x_recon = self.predict_start_from_noise(
|
717 |
+
x, t=t, noise=self.denoise_fn.forward_with_cond_scale(x, t, cond=cond, cond_scale=cond_scale))
|
718 |
+
|
719 |
+
if clip_denoised:
|
720 |
+
s = 1.
|
721 |
+
if self.use_dynamic_thres:
|
722 |
+
s = torch.quantile(
|
723 |
+
rearrange(x_recon, 'b ... -> b (...)').abs(),
|
724 |
+
self.dynamic_thres_percentile,
|
725 |
+
dim=-1
|
726 |
+
)
|
727 |
+
|
728 |
+
s.clamp_(min=1.)
|
729 |
+
s = s.view(-1, *((1,) * (x_recon.ndim - 1)))
|
730 |
+
|
731 |
+
# clip by threshold, depending on whether static or dynamic
|
732 |
+
x_recon = x_recon.clamp(-s, s) / s
|
733 |
+
|
734 |
+
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
|
735 |
+
x_start=x_recon, x_t=x, t=t)
|
736 |
+
return model_mean, posterior_variance, posterior_log_variance
|
737 |
+
|
738 |
+
@torch.inference_mode()
|
739 |
+
def p_sample(self, x, t, cond=None, cond_scale=1., clip_denoised=True):
|
740 |
+
b, *_, device = *x.shape, x.device
|
741 |
+
model_mean, _, model_log_variance = self.p_mean_variance(
|
742 |
+
x=x, t=t, clip_denoised=clip_denoised, cond=cond, cond_scale=cond_scale)
|
743 |
+
noise = torch.randn_like(x)
|
744 |
+
# no noise when t == 0
|
745 |
+
nonzero_mask = (1 - (t == 0).float()).reshape(b,
|
746 |
+
*((1,) * (len(x.shape) - 1)))
|
747 |
+
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
748 |
+
|
749 |
+
@torch.inference_mode()
|
750 |
+
def p_sample_loop(self, shape, cond=None, cond_scale=1.):
|
751 |
+
device = self.betas.device
|
752 |
+
|
753 |
+
b = shape[0]
|
754 |
+
img = torch.randn(shape, device=device)
|
755 |
+
# print('cond', cond.shape)
|
756 |
+
for i in reversed(range(0, self.num_timesteps)):
|
757 |
+
img = self.p_sample(img, torch.full(
|
758 |
+
(b,), i, device=device, dtype=torch.long), cond=cond, cond_scale=cond_scale)
|
759 |
+
|
760 |
+
return img
|
761 |
+
|
762 |
+
@torch.inference_mode()
|
763 |
+
def sample(self, cond=None, cond_scale=1., batch_size=16):
|
764 |
+
device = next(self.denoise_fn.parameters()).device
|
765 |
+
|
766 |
+
if is_list_str(cond):
|
767 |
+
cond = bert_embed(tokenize(cond)).to(device)
|
768 |
+
|
769 |
+
# batch_size = cond.shape[0] if exists(cond) else batch_size
|
770 |
+
batch_size = batch_size
|
771 |
+
image_size = self.image_size
|
772 |
+
channels = 8 # self.channels
|
773 |
+
num_frames = self.num_frames
|
774 |
+
# print((batch_size, channels, num_frames, image_size, image_size))
|
775 |
+
# print('cond_',cond.shape)
|
776 |
+
_sample = self.p_sample_loop(
|
777 |
+
(batch_size, channels, num_frames, image_size, image_size), cond=cond, cond_scale=cond_scale)
|
778 |
+
|
779 |
+
if isinstance(self.vqgan, VQGAN):
|
780 |
+
# denormalize TODO: Remove eventually
|
781 |
+
_sample = (((_sample + 1.0) / 2.0) * (self.vqgan.codebook.embeddings.max() -
|
782 |
+
self.vqgan.codebook.embeddings.min())) + self.vqgan.codebook.embeddings.min()
|
783 |
+
|
784 |
+
_sample = self.vqgan.decode(_sample, quantize=True)
|
785 |
+
else:
|
786 |
+
unnormalize_img(_sample)
|
787 |
+
|
788 |
+
return _sample
|
789 |
+
|
790 |
+
@torch.inference_mode()
|
791 |
+
def interpolate(self, x1, x2, t=None, lam=0.5):
|
792 |
+
b, *_, device = *x1.shape, x1.device
|
793 |
+
t = default(t, self.num_timesteps - 1)
|
794 |
+
|
795 |
+
assert x1.shape == x2.shape
|
796 |
+
|
797 |
+
t_batched = torch.stack([torch.tensor(t, device=device)] * b)
|
798 |
+
xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2))
|
799 |
+
|
800 |
+
img = (1 - lam) * xt1 + lam * xt2
|
801 |
+
for i in reversed(range(0, t)):
|
802 |
+
img = self.p_sample(img, torch.full(
|
803 |
+
(b,), i, device=device, dtype=torch.long))
|
804 |
+
|
805 |
+
return img
|
806 |
+
|
807 |
+
def q_sample(self, x_start, t, noise=None):
|
808 |
+
noise = default(noise, lambda: torch.randn_like(x_start))
|
809 |
+
|
810 |
+
return (
|
811 |
+
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
|
812 |
+
extract(self.sqrt_one_minus_alphas_cumprod,
|
813 |
+
t, x_start.shape) * noise
|
814 |
+
)
|
815 |
+
|
816 |
+
def p_losses(self, x_start, t, cond=None, noise=None, **kwargs):
|
817 |
+
b, c, f, h, w, device = *x_start.shape, x_start.device
|
818 |
+
noise = default(noise, lambda: torch.randn_like(x_start))
|
819 |
+
# breakpoint()
|
820 |
+
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) # [2, 8, 32, 32, 32]
|
821 |
+
|
822 |
+
if is_list_str(cond):
|
823 |
+
cond = bert_embed(
|
824 |
+
tokenize(cond), return_cls_repr=self.text_use_bert_cls)
|
825 |
+
cond = cond.to(device)
|
826 |
+
|
827 |
+
x_recon = self.denoise_fn(x_noisy, t, cond=cond, **kwargs)
|
828 |
+
|
829 |
+
if self.loss_type == 'l1':
|
830 |
+
loss = F.l1_loss(noise, x_recon)
|
831 |
+
elif self.loss_type == 'l2':
|
832 |
+
loss = F.mse_loss(noise, x_recon)
|
833 |
+
else:
|
834 |
+
raise NotImplementedError()
|
835 |
+
|
836 |
+
return loss
|
837 |
+
|
838 |
+
def forward(self, x, *args, **kwargs):
|
839 |
+
bs = int(x.shape[0]/2)
|
840 |
+
img=x[:bs,...]
|
841 |
+
mask=x[bs:,...]
|
842 |
+
mask_=(1-mask).detach()
|
843 |
+
masked_img = (img*mask_).detach()
|
844 |
+
masked_img=masked_img.permute(0,1,-1,-3,-2)
|
845 |
+
img=img.permute(0,1,-1,-3,-2)
|
846 |
+
mask=mask.permute(0,1,-1,-3,-2)
|
847 |
+
# breakpoint()
|
848 |
+
if isinstance(self.vqgan, VQGAN):
|
849 |
+
with torch.no_grad():
|
850 |
+
img = self.vqgan.encode(
|
851 |
+
img, quantize=False, include_embeddings=True)
|
852 |
+
# normalize to -1 and 1
|
853 |
+
img = ((img - self.vqgan.codebook.embeddings.min()) /
|
854 |
+
(self.vqgan.codebook.embeddings.max() -
|
855 |
+
self.vqgan.codebook.embeddings.min())) * 2.0 - 1.0
|
856 |
+
|
857 |
+
masked_img = self.vqgan.encode(
|
858 |
+
masked_img, quantize=False, include_embeddings=True)
|
859 |
+
# normalize to -1 and 1
|
860 |
+
masked_img = ((masked_img - self.vqgan.codebook.embeddings.min()) /
|
861 |
+
(self.vqgan.codebook.embeddings.max() -
|
862 |
+
self.vqgan.codebook.embeddings.min())) * 2.0 - 1.0
|
863 |
+
else:
|
864 |
+
print("Hi")
|
865 |
+
img = normalize_img(img)
|
866 |
+
masked_img = normalize_img(masked_img)
|
867 |
+
mask = mask*2.0 - 1.0
|
868 |
+
cc = torch.nn.functional.interpolate(mask, size=masked_img.shape[-3:])
|
869 |
+
cond = torch.cat((masked_img, cc), dim=1)
|
870 |
+
|
871 |
+
b, device, img_size, = img.shape[0], img.device, self.image_size
|
872 |
+
t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
|
873 |
+
# breakpoint()
|
874 |
+
return self.p_losses(img, t, cond=cond, *args, **kwargs)
|
875 |
+
|
876 |
+
# trainer class
|
877 |
+
|
878 |
+
|
879 |
+
CHANNELS_TO_MODE = {
|
880 |
+
1: 'L',
|
881 |
+
3: 'RGB',
|
882 |
+
4: 'RGBA'
|
883 |
+
}
|
884 |
+
|
885 |
+
|
886 |
+
def seek_all_images(img, channels=3):
|
887 |
+
assert channels in CHANNELS_TO_MODE, f'channels {channels} invalid'
|
888 |
+
mode = CHANNELS_TO_MODE[channels]
|
889 |
+
|
890 |
+
i = 0
|
891 |
+
while True:
|
892 |
+
try:
|
893 |
+
img.seek(i)
|
894 |
+
yield img.convert(mode)
|
895 |
+
except EOFError:
|
896 |
+
break
|
897 |
+
i += 1
|
898 |
+
|
899 |
+
# tensor of shape (channels, frames, height, width) -> gif
|
900 |
+
|
901 |
+
|
902 |
+
def video_tensor_to_gif(tensor, path, duration=120, loop=0, optimize=True):
|
903 |
+
tensor = ((tensor - tensor.min()) / (tensor.max() - tensor.min())) * 1.0
|
904 |
+
images = map(T.ToPILImage(), tensor.unbind(dim=1))
|
905 |
+
first_img, *rest_imgs = images
|
906 |
+
first_img.save(path, save_all=True, append_images=rest_imgs,
|
907 |
+
duration=duration, loop=loop, optimize=optimize)
|
908 |
+
return images
|
909 |
+
|
910 |
+
# gif -> (channels, frame, height, width) tensor
|
911 |
+
|
912 |
+
|
913 |
+
def gif_to_tensor(path, channels=3, transform=T.ToTensor()):
|
914 |
+
img = Image.open(path)
|
915 |
+
tensors = tuple(map(transform, seek_all_images(img, channels=channels)))
|
916 |
+
return torch.stack(tensors, dim=1)
|
917 |
+
|
918 |
+
|
919 |
+
def identity(t, *args, **kwargs):
|
920 |
+
return t
|
921 |
+
|
922 |
+
|
923 |
+
def normalize_img(t):
|
924 |
+
return t * 2 - 1
|
925 |
+
|
926 |
+
|
927 |
+
def unnormalize_img(t):
|
928 |
+
return (t + 1) * 0.5
|
929 |
+
|
930 |
+
|
931 |
+
def cast_num_frames(t, *, frames):
|
932 |
+
f = t.shape[1]
|
933 |
+
|
934 |
+
if f == frames:
|
935 |
+
return t
|
936 |
+
|
937 |
+
if f > frames:
|
938 |
+
return t[:, :frames]
|
939 |
+
|
940 |
+
return F.pad(t, (0, 0, 0, 0, 0, frames - f))
|
941 |
+
|
942 |
+
|
943 |
+
class Dataset(data.Dataset):
|
944 |
+
def __init__(
|
945 |
+
self,
|
946 |
+
folder,
|
947 |
+
image_size,
|
948 |
+
channels=3,
|
949 |
+
num_frames=16,
|
950 |
+
horizontal_flip=False,
|
951 |
+
force_num_frames=True,
|
952 |
+
exts=['gif']
|
953 |
+
):
|
954 |
+
super().__init__()
|
955 |
+
self.folder = folder
|
956 |
+
self.image_size = image_size
|
957 |
+
self.channels = channels
|
958 |
+
self.paths = [p for ext in exts for p in Path(
|
959 |
+
f'{folder}').glob(f'**/*.{ext}')]
|
960 |
+
|
961 |
+
self.cast_num_frames_fn = partial(
|
962 |
+
cast_num_frames, frames=num_frames) if force_num_frames else identity
|
963 |
+
|
964 |
+
self.transform = T.Compose([
|
965 |
+
T.Resize(image_size),
|
966 |
+
T.RandomHorizontalFlip() if horizontal_flip else T.Lambda(identity),
|
967 |
+
T.CenterCrop(image_size),
|
968 |
+
T.ToTensor()
|
969 |
+
])
|
970 |
+
|
971 |
+
def __len__(self):
|
972 |
+
return len(self.paths)
|
973 |
+
|
974 |
+
def __getitem__(self, index):
|
975 |
+
path = self.paths[index]
|
976 |
+
tensor = gif_to_tensor(path, self.channels, transform=self.transform)
|
977 |
+
return self.cast_num_frames_fn(tensor)
|
978 |
+
|
979 |
+
# trainer class
|
980 |
+
|
981 |
+
|
982 |
+
class Tester(object):
|
983 |
+
def __init__(
|
984 |
+
self,
|
985 |
+
diffusion_model,
|
986 |
+
):
|
987 |
+
super().__init__()
|
988 |
+
self.model = diffusion_model
|
989 |
+
self.ema_model = copy.deepcopy(self.model)
|
990 |
+
self.step=0
|
991 |
+
self.image_size = diffusion_model.image_size
|
992 |
+
|
993 |
+
self.reset_parameters()
|
994 |
+
|
995 |
+
def reset_parameters(self):
|
996 |
+
self.ema_model.load_state_dict(self.model.state_dict())
|
997 |
+
|
998 |
+
|
999 |
+
def load(self, milestone, map_location=None, **kwargs):
|
1000 |
+
if milestone == -1:
|
1001 |
+
all_milestones = [int(p.stem.split('-')[-1])
|
1002 |
+
for p in Path(self.results_folder).glob('**/*.pt')]
|
1003 |
+
assert len(
|
1004 |
+
all_milestones) > 0, 'need to have at least one milestone to load from latest checkpoint (milestone == -1)'
|
1005 |
+
milestone = max(all_milestones)
|
1006 |
+
|
1007 |
+
if map_location:
|
1008 |
+
data = torch.load(milestone, map_location=map_location)
|
1009 |
+
else:
|
1010 |
+
data = torch.load(milestone)
|
1011 |
+
|
1012 |
+
self.step = data['step']
|
1013 |
+
self.model.load_state_dict(data['model'], **kwargs)
|
1014 |
+
self.ema_model.load_state_dict(data['ema'], **kwargs)
|
1015 |
+
|
1016 |
+
|
Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/ldm/ddpm/text.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"Taken and adapted from https://github.com/lucidrains/video-diffusion-pytorch"
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from einops import rearrange
|
5 |
+
|
6 |
+
|
7 |
+
def exists(val):
|
8 |
+
return val is not None
|
9 |
+
|
10 |
+
# singleton globals
|
11 |
+
|
12 |
+
|
13 |
+
MODEL = None
|
14 |
+
TOKENIZER = None
|
15 |
+
BERT_MODEL_DIM = 768
|
16 |
+
|
17 |
+
|
18 |
+
def get_tokenizer():
|
19 |
+
global TOKENIZER
|
20 |
+
if not exists(TOKENIZER):
|
21 |
+
TOKENIZER = torch.hub.load(
|
22 |
+
'huggingface/pytorch-transformers', 'tokenizer', 'bert-base-cased')
|
23 |
+
return TOKENIZER
|
24 |
+
|
25 |
+
|
26 |
+
def get_bert():
|
27 |
+
global MODEL
|
28 |
+
if not exists(MODEL):
|
29 |
+
MODEL = torch.hub.load(
|
30 |
+
'huggingface/pytorch-transformers', 'model', 'bert-base-cased')
|
31 |
+
if torch.cuda.is_available():
|
32 |
+
MODEL = MODEL.cuda()
|
33 |
+
|
34 |
+
return MODEL
|
35 |
+
|
36 |
+
# tokenize
|
37 |
+
|
38 |
+
|
39 |
+
def tokenize(texts, add_special_tokens=True):
|
40 |
+
if not isinstance(texts, (list, tuple)):
|
41 |
+
texts = [texts]
|
42 |
+
|
43 |
+
tokenizer = get_tokenizer()
|
44 |
+
|
45 |
+
encoding = tokenizer.batch_encode_plus(
|
46 |
+
texts,
|
47 |
+
add_special_tokens=add_special_tokens,
|
48 |
+
padding=True,
|
49 |
+
return_tensors='pt'
|
50 |
+
)
|
51 |
+
|
52 |
+
token_ids = encoding.input_ids
|
53 |
+
return token_ids
|
54 |
+
|
55 |
+
# embedding function
|
56 |
+
|
57 |
+
|
58 |
+
@torch.no_grad()
|
59 |
+
def bert_embed(
|
60 |
+
token_ids,
|
61 |
+
return_cls_repr=False,
|
62 |
+
eps=1e-8,
|
63 |
+
pad_id=0.
|
64 |
+
):
|
65 |
+
model = get_bert()
|
66 |
+
mask = token_ids != pad_id
|
67 |
+
|
68 |
+
if torch.cuda.is_available():
|
69 |
+
token_ids = token_ids.cuda()
|
70 |
+
mask = mask.cuda()
|
71 |
+
|
72 |
+
outputs = model(
|
73 |
+
input_ids=token_ids,
|
74 |
+
attention_mask=mask,
|
75 |
+
output_hidden_states=True
|
76 |
+
)
|
77 |
+
|
78 |
+
hidden_state = outputs.hidden_states[-1]
|
79 |
+
|
80 |
+
if return_cls_repr:
|
81 |
+
# return [cls] as representation
|
82 |
+
return hidden_state[:, 0]
|
83 |
+
|
84 |
+
if not exists(mask):
|
85 |
+
return hidden_state.mean(dim=1)
|
86 |
+
|
87 |
+
# mean all tokens excluding [cls], accounting for length
|
88 |
+
mask = mask[:, 1:]
|
89 |
+
mask = rearrange(mask, 'b n -> b n 1')
|
90 |
+
|
91 |
+
numer = (hidden_state[:, 1:] * mask).sum(dim=1)
|
92 |
+
denom = mask.sum(dim=1)
|
93 |
+
masked_mean = numer / (denom + eps)
|
94 |
+
return
|
Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/ldm/ddpm/time_embedding.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
from monai.networks.layers.utils import get_act_layer
|
6 |
+
|
7 |
+
|
8 |
+
class SinusoidalPosEmb(nn.Module):
|
9 |
+
def __init__(self, emb_dim=16, downscale_freq_shift=1, max_period=10000, flip_sin_to_cos=False):
|
10 |
+
super().__init__()
|
11 |
+
self.emb_dim = emb_dim
|
12 |
+
self.downscale_freq_shift = downscale_freq_shift
|
13 |
+
self.max_period = max_period
|
14 |
+
self.flip_sin_to_cos = flip_sin_to_cos
|
15 |
+
|
16 |
+
def forward(self, x):
|
17 |
+
device = x.device
|
18 |
+
half_dim = self.emb_dim // 2
|
19 |
+
emb = math.log(self.max_period) / \
|
20 |
+
(half_dim - self.downscale_freq_shift)
|
21 |
+
emb = torch.exp(-emb*torch.arange(half_dim, device=device))
|
22 |
+
emb = x[:, None] * emb[None, :]
|
23 |
+
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
24 |
+
|
25 |
+
if self.flip_sin_to_cos:
|
26 |
+
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
|
27 |
+
|
28 |
+
if self.emb_dim % 2 == 1:
|
29 |
+
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
30 |
+
return emb
|
31 |
+
|
32 |
+
|
33 |
+
class LearnedSinusoidalPosEmb(nn.Module):
|
34 |
+
""" following @crowsonkb 's lead with learned sinusoidal pos emb """
|
35 |
+
""" https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """
|
36 |
+
|
37 |
+
def __init__(self, emb_dim):
|
38 |
+
super().__init__()
|
39 |
+
self.emb_dim = emb_dim
|
40 |
+
half_dim = emb_dim // 2
|
41 |
+
self.weights = nn.Parameter(torch.randn(half_dim))
|
42 |
+
|
43 |
+
def forward(self, x):
|
44 |
+
x = x[:, None]
|
45 |
+
freqs = x * self.weights[None, :] * 2 * math.pi
|
46 |
+
fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
|
47 |
+
fouriered = torch.cat((x, fouriered), dim=-1)
|
48 |
+
if self.emb_dim % 2 == 1:
|
49 |
+
fouriered = torch.nn.functional.pad(fouriered, (0, 1, 0, 0))
|
50 |
+
return fouriered
|
51 |
+
|
52 |
+
|
53 |
+
class TimeEmbbeding(nn.Module):
|
54 |
+
def __init__(
|
55 |
+
self,
|
56 |
+
emb_dim=64,
|
57 |
+
pos_embedder=SinusoidalPosEmb,
|
58 |
+
pos_embedder_kwargs={},
|
59 |
+
act_name=("SWISH", {}) # Swish = SiLU
|
60 |
+
):
|
61 |
+
super().__init__()
|
62 |
+
self.emb_dim = emb_dim
|
63 |
+
self.pos_emb_dim = pos_embedder_kwargs.get('emb_dim', emb_dim//4)
|
64 |
+
pos_embedder_kwargs['emb_dim'] = self.pos_emb_dim
|
65 |
+
self.pos_embedder = pos_embedder(**pos_embedder_kwargs)
|
66 |
+
|
67 |
+
self.time_emb = nn.Sequential(
|
68 |
+
self.pos_embedder,
|
69 |
+
nn.Linear(self.pos_emb_dim, self.emb_dim),
|
70 |
+
get_act_layer(act_name),
|
71 |
+
nn.Linear(self.emb_dim, self.emb_dim)
|
72 |
+
)
|
73 |
+
|
74 |
+
def forward(self, time):
|
75 |
+
return self.time_emb(time)
|
Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/ldm/ddpm/unet.py
ADDED
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ddpm.time_embedding import TimeEmbbeding
|
2 |
+
|
3 |
+
import monai.networks.nets as nets
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from einops import rearrange
|
7 |
+
|
8 |
+
from monai.networks.blocks import UnetBasicBlock, UnetResBlock, UnetUpBlock, Convolution, UnetOutBlock
|
9 |
+
from monai.networks.layers.utils import get_act_layer
|
10 |
+
|
11 |
+
|
12 |
+
class DownBlock(nn.Module):
|
13 |
+
def __init__(
|
14 |
+
self,
|
15 |
+
spatial_dims,
|
16 |
+
in_ch,
|
17 |
+
out_ch,
|
18 |
+
time_emb_dim,
|
19 |
+
cond_emb_dim,
|
20 |
+
act_name=("swish", {}),
|
21 |
+
**kwargs):
|
22 |
+
super(DownBlock, self).__init__()
|
23 |
+
self.loca_time_embedder = nn.Sequential(
|
24 |
+
get_act_layer(name=act_name),
|
25 |
+
nn.Linear(time_emb_dim, in_ch) # in_ch * 2
|
26 |
+
)
|
27 |
+
if cond_emb_dim is not None:
|
28 |
+
self.loca_cond_embedder = nn.Sequential(
|
29 |
+
get_act_layer(name=act_name),
|
30 |
+
nn.Linear(cond_emb_dim, in_ch),
|
31 |
+
)
|
32 |
+
self.down_op = UnetBasicBlock(
|
33 |
+
spatial_dims, in_ch, out_ch, act_name=act_name, **kwargs)
|
34 |
+
|
35 |
+
def forward(self, x, time_emb, cond_emb):
|
36 |
+
b, c, *_ = x.shape
|
37 |
+
sp_dim = x.ndim-2
|
38 |
+
|
39 |
+
# ------------ Time ----------
|
40 |
+
time_emb = self.loca_time_embedder(time_emb)
|
41 |
+
time_emb = time_emb.reshape(b, c, *((1,)*sp_dim))
|
42 |
+
# scale, shift = time_emb.chunk(2, dim = 1)
|
43 |
+
|
44 |
+
# ------------ Combine ------------
|
45 |
+
# x = x * (scale + 1) + shift
|
46 |
+
x = x + time_emb
|
47 |
+
|
48 |
+
# ----------- Condition ------------
|
49 |
+
if cond_emb is not None:
|
50 |
+
cond_emb = self.loca_cond_embedder(cond_emb)
|
51 |
+
cond_emb = cond_emb.reshape(b, c, *((1,)*sp_dim))
|
52 |
+
x = x + cond_emb
|
53 |
+
|
54 |
+
# ----------- Image ---------
|
55 |
+
y = self.down_op(x)
|
56 |
+
return y
|
57 |
+
|
58 |
+
|
59 |
+
class UpBlock(nn.Module):
|
60 |
+
def __init__(
|
61 |
+
self,
|
62 |
+
spatial_dims,
|
63 |
+
skip_ch,
|
64 |
+
enc_ch,
|
65 |
+
time_emb_dim,
|
66 |
+
cond_emb_dim,
|
67 |
+
act_name=("swish", {}),
|
68 |
+
**kwargs):
|
69 |
+
super(UpBlock, self).__init__()
|
70 |
+
self.up_op = UnetUpBlock(spatial_dims, enc_ch,
|
71 |
+
skip_ch, act_name=act_name, **kwargs)
|
72 |
+
self.loca_time_embedder = nn.Sequential(
|
73 |
+
get_act_layer(name=act_name),
|
74 |
+
nn.Linear(time_emb_dim, skip_ch * 2),
|
75 |
+
)
|
76 |
+
if cond_emb_dim is not None:
|
77 |
+
self.loca_cond_embedder = nn.Sequential(
|
78 |
+
get_act_layer(name=act_name),
|
79 |
+
nn.Linear(cond_emb_dim, skip_ch * 2),
|
80 |
+
)
|
81 |
+
|
82 |
+
def forward(self, x_skip, x_enc, time_emb, cond_emb):
|
83 |
+
b, c, *_ = x_enc.shape
|
84 |
+
sp_dim = x_enc.ndim-2
|
85 |
+
|
86 |
+
# ----------- Time --------------
|
87 |
+
time_emb = self.loca_time_embedder(time_emb)
|
88 |
+
time_emb = time_emb.reshape(b, c, *((1,)*sp_dim))
|
89 |
+
# scale, shift = time_emb.chunk(2, dim = 1)
|
90 |
+
|
91 |
+
# -------- Combine -------------
|
92 |
+
# y = x * (scale + 1) + shift
|
93 |
+
x_enc = x_enc + time_emb
|
94 |
+
|
95 |
+
# ----------- Condition ------------
|
96 |
+
if cond_emb is not None:
|
97 |
+
cond_emb = self.loca_cond_embedder(cond_emb)
|
98 |
+
cond_emb = cond_emb.reshape(b, c, *((1,)*sp_dim))
|
99 |
+
x_enc = x_enc + cond_emb
|
100 |
+
|
101 |
+
# ----------- Image -------------
|
102 |
+
y = self.up_op(x_enc, x_skip)
|
103 |
+
|
104 |
+
# -------- Combine -------------
|
105 |
+
# y = y * (scale + 1) + shift
|
106 |
+
|
107 |
+
return y
|
108 |
+
|
109 |
+
|
110 |
+
class UNet(nn.Module):
|
111 |
+
|
112 |
+
def __init__(self,
|
113 |
+
in_ch=1,
|
114 |
+
out_ch=1,
|
115 |
+
spatial_dims=3,
|
116 |
+
hid_chs=[32, 64, 128, 256, 512],
|
117 |
+
kernel_sizes=[(1, 3, 3), (1, 3, 3), (1, 3, 3), 3, 3],
|
118 |
+
strides=[1, (1, 2, 2), (1, 2, 2), 2, 2],
|
119 |
+
upsample_kernel_sizes=None,
|
120 |
+
act_name=("SWISH", {}),
|
121 |
+
norm_name=("INSTANCE", {"affine": True}),
|
122 |
+
time_embedder=TimeEmbbeding,
|
123 |
+
time_embedder_kwargs={},
|
124 |
+
cond_embedder=None,
|
125 |
+
cond_embedder_kwargs={},
|
126 |
+
# True = all but last layer, 0/False=disable, 1=only first layer, ...
|
127 |
+
deep_ver_supervision=True,
|
128 |
+
estimate_variance=False,
|
129 |
+
use_self_conditioning=False,
|
130 |
+
**kwargs
|
131 |
+
):
|
132 |
+
super().__init__()
|
133 |
+
if upsample_kernel_sizes is None:
|
134 |
+
upsample_kernel_sizes = strides[1:]
|
135 |
+
|
136 |
+
# ------------- Time-Embedder-----------
|
137 |
+
self.time_embedder = time_embedder(**time_embedder_kwargs)
|
138 |
+
|
139 |
+
# ------------- Condition-Embedder-----------
|
140 |
+
if cond_embedder is not None:
|
141 |
+
self.cond_embedder = cond_embedder(**cond_embedder_kwargs)
|
142 |
+
cond_emb_dim = self.cond_embedder.emb_dim
|
143 |
+
else:
|
144 |
+
self.cond_embedder = None
|
145 |
+
cond_emb_dim = None
|
146 |
+
|
147 |
+
# ----------- In-Convolution ------------
|
148 |
+
in_ch = in_ch*2 if use_self_conditioning else in_ch
|
149 |
+
self.inc = UnetBasicBlock(spatial_dims, in_ch, hid_chs[0], kernel_size=kernel_sizes[0], stride=strides[0],
|
150 |
+
act_name=act_name, norm_name=norm_name, **kwargs)
|
151 |
+
|
152 |
+
# ----------- Encoder ----------------
|
153 |
+
self.encoders = nn.ModuleList([
|
154 |
+
DownBlock(spatial_dims, hid_chs[i-1], hid_chs[i], time_emb_dim=self.time_embedder.emb_dim,
|
155 |
+
cond_emb_dim=cond_emb_dim, kernel_size=kernel_sizes[
|
156 |
+
i], stride=strides[i], act_name=act_name,
|
157 |
+
norm_name=norm_name, **kwargs)
|
158 |
+
for i in range(1, len(strides))
|
159 |
+
])
|
160 |
+
|
161 |
+
# ------------ Decoder ----------
|
162 |
+
self.decoders = nn.ModuleList([
|
163 |
+
UpBlock(spatial_dims, hid_chs[i], hid_chs[i+1], time_emb_dim=self.time_embedder.emb_dim,
|
164 |
+
cond_emb_dim=cond_emb_dim, kernel_size=kernel_sizes[i +
|
165 |
+
1], stride=strides[i+1], act_name=act_name,
|
166 |
+
norm_name=norm_name, upsample_kernel_size=upsample_kernel_sizes[i], **kwargs)
|
167 |
+
for i in range(len(strides)-1)
|
168 |
+
])
|
169 |
+
|
170 |
+
# --------------- Out-Convolution ----------------
|
171 |
+
out_ch_hor = out_ch*2 if estimate_variance else out_ch
|
172 |
+
self.outc = UnetOutBlock(
|
173 |
+
spatial_dims, hid_chs[0], out_ch_hor, dropout=None)
|
174 |
+
if isinstance(deep_ver_supervision, bool):
|
175 |
+
deep_ver_supervision = len(
|
176 |
+
strides)-2 if deep_ver_supervision else 0
|
177 |
+
self.outc_ver = nn.ModuleList([
|
178 |
+
UnetOutBlock(spatial_dims, hid_chs[i], out_ch, dropout=None)
|
179 |
+
for i in range(1, deep_ver_supervision+1)
|
180 |
+
])
|
181 |
+
|
182 |
+
def forward(self, x_t, t, cond=None, self_cond=None, **kwargs):
|
183 |
+
condition = cond
|
184 |
+
# x_t [B, C, (D), H, W]
|
185 |
+
# t [B,]
|
186 |
+
|
187 |
+
# -------- In-Convolution --------------
|
188 |
+
x = [None for _ in range(len(self.encoders)+1)]
|
189 |
+
x_t = torch.cat([x_t, self_cond],
|
190 |
+
dim=1) if self_cond is not None else x_t
|
191 |
+
x[0] = self.inc(x_t)
|
192 |
+
|
193 |
+
# -------- Time Embedding (Gloabl) -----------
|
194 |
+
time_emb = self.time_embedder(t) # [B, C]
|
195 |
+
|
196 |
+
# -------- Condition Embedding (Gloabl) -----------
|
197 |
+
if (condition is None) or (self.cond_embedder is None):
|
198 |
+
cond_emb = None
|
199 |
+
else:
|
200 |
+
cond_emb = self.cond_embedder(condition) # [B, C]
|
201 |
+
|
202 |
+
# --------- Encoder --------------
|
203 |
+
for i in range(len(self.encoders)):
|
204 |
+
x[i+1] = self.encoders[i](x[i], time_emb, cond_emb)
|
205 |
+
|
206 |
+
# -------- Decoder -----------
|
207 |
+
for i in range(len(self.decoders), 0, -1):
|
208 |
+
x[i-1] = self.decoders[i-1](x[i-1], x[i], time_emb, cond_emb)
|
209 |
+
|
210 |
+
# ---------Out-Convolution ------------
|
211 |
+
y_hor = self.outc(x[0])
|
212 |
+
y_ver = [outc_ver_i(x[i+1])
|
213 |
+
for i, outc_ver_i in enumerate(self.outc_ver)]
|
214 |
+
|
215 |
+
return y_hor # , y_ver
|
216 |
+
|
217 |
+
def forward_with_cond_scale(self, *args, cond_scale=0., **kwargs):
|
218 |
+
return self.forward(*args, **kwargs)
|
219 |
+
|
220 |
+
|
221 |
+
if __name__ == '__main__':
|
222 |
+
model = UNet(in_ch=3)
|
223 |
+
input = torch.randn((1, 3, 16, 128, 128))
|
224 |
+
time = torch.randn((1,))
|
225 |
+
out_hor, out_ver = model(input, time)
|
226 |
+
print(out_hor[0].shape)
|
Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/ldm/ddpm/util.py
ADDED
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# adopted from
|
2 |
+
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
3 |
+
# and
|
4 |
+
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
|
5 |
+
# and
|
6 |
+
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
|
7 |
+
#
|
8 |
+
# thanks!
|
9 |
+
|
10 |
+
|
11 |
+
import os
|
12 |
+
import math
|
13 |
+
import torch
|
14 |
+
import torch.nn as nn
|
15 |
+
import numpy as np
|
16 |
+
from einops import repeat
|
17 |
+
|
18 |
+
# from ldm.util import instantiate_from_config
|
19 |
+
|
20 |
+
|
21 |
+
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
22 |
+
if schedule == "linear":
|
23 |
+
betas = (
|
24 |
+
torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
|
25 |
+
)
|
26 |
+
|
27 |
+
elif schedule == "cosine":
|
28 |
+
timesteps = (
|
29 |
+
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
|
30 |
+
)
|
31 |
+
alphas = timesteps / (1 + cosine_s) * np.pi / 2
|
32 |
+
alphas = torch.cos(alphas).pow(2)
|
33 |
+
alphas = alphas / alphas[0]
|
34 |
+
betas = 1 - alphas[1:] / alphas[:-1]
|
35 |
+
betas = np.clip(betas, a_min=0, a_max=0.999)
|
36 |
+
|
37 |
+
elif schedule == "sqrt_linear":
|
38 |
+
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
|
39 |
+
elif schedule == "sqrt":
|
40 |
+
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
|
41 |
+
else:
|
42 |
+
raise ValueError(f"schedule '{schedule}' unknown.")
|
43 |
+
return betas.numpy()
|
44 |
+
|
45 |
+
|
46 |
+
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
|
47 |
+
if ddim_discr_method == 'uniform':
|
48 |
+
c = num_ddpm_timesteps // num_ddim_timesteps
|
49 |
+
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
|
50 |
+
elif ddim_discr_method == 'quad':
|
51 |
+
ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
|
52 |
+
else:
|
53 |
+
raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
|
54 |
+
|
55 |
+
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
|
56 |
+
# add one to get the final alpha values right (the ones from first scale to data during sampling)
|
57 |
+
if c != 1:
|
58 |
+
steps_out = ddim_timesteps + 1
|
59 |
+
else:
|
60 |
+
steps_out = ddim_timesteps
|
61 |
+
if verbose:
|
62 |
+
print(f'Selected timesteps for ddim sampler: {steps_out}')
|
63 |
+
return steps_out
|
64 |
+
|
65 |
+
|
66 |
+
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
|
67 |
+
# select alphas for computing the variance schedule
|
68 |
+
|
69 |
+
alphas = alphacums[ddim_timesteps]
|
70 |
+
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
|
71 |
+
|
72 |
+
# according the the formula provided in https://arxiv.org/abs/2010.02502
|
73 |
+
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
|
74 |
+
if verbose:
|
75 |
+
print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
|
76 |
+
print(f'For the chosen value of eta, which is {eta}, '
|
77 |
+
f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
|
78 |
+
return sigmas, alphas, alphas_prev
|
79 |
+
|
80 |
+
|
81 |
+
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
|
82 |
+
"""
|
83 |
+
Create a beta schedule that discretizes the given alpha_t_bar function,
|
84 |
+
which defines the cumulative product of (1-beta) over time from t = [0,1].
|
85 |
+
:param num_diffusion_timesteps: the number of betas to produce.
|
86 |
+
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
|
87 |
+
produces the cumulative product of (1-beta) up to that
|
88 |
+
part of the diffusion process.
|
89 |
+
:param max_beta: the maximum beta to use; use values lower than 1 to
|
90 |
+
prevent singularities.
|
91 |
+
"""
|
92 |
+
betas = []
|
93 |
+
for i in range(num_diffusion_timesteps):
|
94 |
+
t1 = i / num_diffusion_timesteps
|
95 |
+
t2 = (i + 1) / num_diffusion_timesteps
|
96 |
+
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
97 |
+
return np.array(betas)
|
98 |
+
|
99 |
+
|
100 |
+
def extract_into_tensor(a, t, x_shape):
|
101 |
+
b, *_ = t.shape
|
102 |
+
out = a.gather(-1, t)
|
103 |
+
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
104 |
+
|
105 |
+
|
106 |
+
def checkpoint(func, inputs, params, flag):
|
107 |
+
"""
|
108 |
+
Evaluate a function without caching intermediate activations, allowing for
|
109 |
+
reduced memory at the expense of extra compute in the backward pass.
|
110 |
+
:param func: the function to evaluate.
|
111 |
+
:param inputs: the argument sequence to pass to `func`.
|
112 |
+
:param params: a sequence of parameters `func` depends on but does not
|
113 |
+
explicitly take as arguments.
|
114 |
+
:param flag: if False, disable gradient checkpointing.
|
115 |
+
"""
|
116 |
+
if flag:
|
117 |
+
args = tuple(inputs) + tuple(params)
|
118 |
+
return CheckpointFunction.apply(func, len(inputs), *args)
|
119 |
+
else:
|
120 |
+
return func(*inputs)
|
121 |
+
|
122 |
+
|
123 |
+
class CheckpointFunction(torch.autograd.Function):
|
124 |
+
@staticmethod
|
125 |
+
def forward(ctx, run_function, length, *args):
|
126 |
+
ctx.run_function = run_function
|
127 |
+
ctx.input_tensors = list(args[:length])
|
128 |
+
ctx.input_params = list(args[length:])
|
129 |
+
|
130 |
+
with torch.no_grad():
|
131 |
+
output_tensors = ctx.run_function(*ctx.input_tensors)
|
132 |
+
return output_tensors
|
133 |
+
|
134 |
+
@staticmethod
|
135 |
+
def backward(ctx, *output_grads):
|
136 |
+
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
|
137 |
+
with torch.enable_grad():
|
138 |
+
# Fixes a bug where the first op in run_function modifies the
|
139 |
+
# Tensor storage in place, which is not allowed for detach()'d
|
140 |
+
# Tensors.
|
141 |
+
shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
|
142 |
+
output_tensors = ctx.run_function(*shallow_copies)
|
143 |
+
input_grads = torch.autograd.grad(
|
144 |
+
output_tensors,
|
145 |
+
ctx.input_tensors + ctx.input_params,
|
146 |
+
output_grads,
|
147 |
+
allow_unused=True,
|
148 |
+
)
|
149 |
+
del ctx.input_tensors
|
150 |
+
del ctx.input_params
|
151 |
+
del output_tensors
|
152 |
+
return (None, None) + input_grads
|
153 |
+
|
154 |
+
|
155 |
+
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
|
156 |
+
"""
|
157 |
+
Create sinusoidal timestep embeddings.
|
158 |
+
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
159 |
+
These may be fractional.
|
160 |
+
:param dim: the dimension of the output.
|
161 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
162 |
+
:return: an [N x dim] Tensor of positional embeddings.
|
163 |
+
"""
|
164 |
+
if not repeat_only:
|
165 |
+
half = dim // 2
|
166 |
+
freqs = torch.exp(
|
167 |
+
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
168 |
+
).to(device=timesteps.device)
|
169 |
+
args = timesteps[:, None].float() * freqs[None]
|
170 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
171 |
+
if dim % 2:
|
172 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
173 |
+
else:
|
174 |
+
embedding = repeat(timesteps, 'b -> b d', d=dim)
|
175 |
+
return embedding
|
176 |
+
|
177 |
+
|
178 |
+
def zero_module(module):
|
179 |
+
"""
|
180 |
+
Zero out the parameters of a module and return it.
|
181 |
+
"""
|
182 |
+
for p in module.parameters():
|
183 |
+
p.detach().zero_()
|
184 |
+
return module
|
185 |
+
|
186 |
+
|
187 |
+
def scale_module(module, scale):
|
188 |
+
"""
|
189 |
+
Scale the parameters of a module and return it.
|
190 |
+
"""
|
191 |
+
for p in module.parameters():
|
192 |
+
p.detach().mul_(scale)
|
193 |
+
return module
|
194 |
+
|
195 |
+
|
196 |
+
def mean_flat(tensor):
|
197 |
+
"""
|
198 |
+
Take the mean over all non-batch dimensions.
|
199 |
+
"""
|
200 |
+
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
201 |
+
|
202 |
+
|
203 |
+
def normalization(channels):
|
204 |
+
"""
|
205 |
+
Make a standard normalization layer.
|
206 |
+
:param channels: number of input channels.
|
207 |
+
:return: an nn.Module for normalization.
|
208 |
+
"""
|
209 |
+
return GroupNorm32(32, channels)
|
210 |
+
|
211 |
+
|
212 |
+
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
|
213 |
+
class SiLU(nn.Module):
|
214 |
+
def forward(self, x):
|
215 |
+
return x * torch.sigmoid(x)
|
216 |
+
|
217 |
+
|
218 |
+
class GroupNorm32(nn.GroupNorm):
|
219 |
+
def forward(self, x):
|
220 |
+
return super().forward(x.float()).type(x.dtype)
|
221 |
+
|
222 |
+
def conv_nd(dims, *args, **kwargs):
|
223 |
+
"""
|
224 |
+
Create a 1D, 2D, or 3D convolution module.
|
225 |
+
"""
|
226 |
+
if dims == 1:
|
227 |
+
return nn.Conv1d(*args, **kwargs)
|
228 |
+
elif dims == 2:
|
229 |
+
return nn.Conv2d(*args, **kwargs)
|
230 |
+
elif dims == 3:
|
231 |
+
return nn.Conv3d(*args, **kwargs)
|
232 |
+
raise ValueError(f"unsupported dimensions: {dims}")
|
233 |
+
|
234 |
+
|
235 |
+
def linear(*args, **kwargs):
|
236 |
+
"""
|
237 |
+
Create a linear module.
|
238 |
+
"""
|
239 |
+
return nn.Linear(*args, **kwargs)
|
240 |
+
|
241 |
+
|
242 |
+
def avg_pool_nd(dims, *args, **kwargs):
|
243 |
+
"""
|
244 |
+
Create a 1D, 2D, or 3D average pooling module.
|
245 |
+
"""
|
246 |
+
if dims == 1:
|
247 |
+
return nn.AvgPool1d(*args, **kwargs)
|
248 |
+
elif dims == 2:
|
249 |
+
return nn.AvgPool2d(*args, **kwargs)
|
250 |
+
elif dims == 3:
|
251 |
+
return nn.AvgPool3d(*args, **kwargs)
|
252 |
+
raise ValueError(f"unsupported dimensions: {dims}")
|
253 |
+
|
254 |
+
|
255 |
+
class HybridConditioner(nn.Module):
|
256 |
+
|
257 |
+
def __init__(self, c_concat_config, c_crossattn_config):
|
258 |
+
super().__init__()
|
259 |
+
self.concat_conditioner = instantiate_from_config(c_concat_config)
|
260 |
+
self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
|
261 |
+
|
262 |
+
def forward(self, c_concat, c_crossattn):
|
263 |
+
c_concat = self.concat_conditioner(c_concat)
|
264 |
+
c_crossattn = self.crossattn_conditioner(c_crossattn)
|
265 |
+
return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
|
266 |
+
|
267 |
+
|
268 |
+
def noise_like(shape, device, repeat=False):
|
269 |
+
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
|
270 |
+
noise = lambda: torch.randn(shape, device=device)
|
271 |
+
return repeat_noise() if repeat else noise()
|
Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/ldm/vq_gan_3d/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (4.84 kB). View file
|
|
Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .vqgan import VQGAN
|
2 |
+
from .codebook import Codebook
|
3 |
+
from .lpips import LPIPS
|
Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (276 Bytes). View file
|
|
Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/codebook.cpython-38.pyc
ADDED
Binary file (3.4 kB). View file
|
|
Generation_Pipeline_filter_all2/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/lpips.cpython-38.pyc
ADDED
Binary file (6.78 kB). View file
|
|