import torch.nn as nn
import torch.nn.functional as F
import torch
from layers import SaveFeature
import pretrainedmodels
from torchvision.models import resnet34, resnet50, resnet101, resnet152
from pathlib import Path
from torchvision.models.resnet import conv3x3, BasicBlock, Bottleneck
import skimage
from scipy import ndimage
import numpy as np
import torchvision.transforms as transforms
import cv2
from constant import IMAGENET_MEAN, IMAGENET_STD

# Preferred compute device; also used as ``map_location`` when loading saved weights.
device = "cuda" if torch.cuda.is_available() else "cpu"


class UpBlock(nn.Module):
    """One decoder step: upsample ``u``, concatenate the skip tensor ``x``, fuse with a 3x3 conv.

    ``inplanes`` and ``planes`` are multiplied by ``expansion`` before the
    layers are built.
    """

    expansion = 1

    def __init__(self, inplanes, planes, expansion=1):
        super().__init__()
        inplanes = inplanes * expansion
        planes = planes * expansion
        # Kernel 2, stride 2, no padding: doubles the spatial size of ``u``.
        self.upconv = nn.ConvTranspose2d(inplanes, planes, 2, 2, 0)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        # conv1 consumes cat([x, up]). Its input width equals ``inplanes`` only
        # when ``x`` has ``inplanes - planes`` channels. That holds for the
        # calls in ``Unet.init_head`` (inplanes == 2 * planes).
        self.conv1 = conv3x3(inplanes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

    def forward(self, u, x):
        """Return relu(bn(conv(cat([x, relu(bn(upconv(u)))])))).

        u: deeper feature map to upsample.
        x: skip-connection feature map, assumed to match the upsampled spatial size.
        """
        up = self.relu(self.bn1(self.upconv(u)))
        out = torch.cat([x, up], dim=1)  # cat along channel
        out = self.relu(self.bn2(self.conv1(out)))
        return out


class UpLayer(nn.Module):
    """An ``UpBlock`` followed by ``blocks - 1`` residual blocks of type ``block``."""

    def __init__(self, block, inplanes, planes, blocks):
        super().__init__()
        self.up = UpBlock(inplanes, planes, block.expansion)
        # ``range(1, blocks)``: the UpBlock counts as the first of ``blocks`` stages.
        layers = [block(planes * block.expansion, planes) for _ in range(1, blocks)]
        self.conv = nn.Sequential(*layers)

    def forward(self, u, x):
        """Upsample ``u`` with skip ``x``, then refine with the residual stack."""
        x = self.up(u, x)
        x = self.conv(x)
        return x


class Unet(nn.Module):
    """U-Net with a ResNet-50 encoder that produces a single-channel segmentation logit map.

    Skip features are captured with ``SaveFeature`` hooks on backbone children
    2, 4, 5 and 6. Call ``close()`` to remove those hooks.
    """

    # Preprocessing applied by ``segment``: resize to 256x256, convert to a
    # tensor, normalize with IMAGENET_MEAN / IMAGENET_STD.
    tfm = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
    ])

    def __init__(self, trained=False, model_name=None):
        """Build the network.

        trained: if True, build an untrained ResNet-50 encoder and load the
            full model state from ``unet.h5`` next to this file (see
            ``load_model``). Otherwise use an ImageNet-style pretrained encoder
            (``load_pretrained``) with a freshly initialised decoder.
        model_name: must be non-None when ``trained`` is True. NOTE(review):
            ``load_model`` currently ignores its value.
        """
        super().__init__()
        # Bottleneck counts for the three decoder stages (consumed in reverse order).
        self.layers = [3, 4, 6]
        self.block = Bottleneck
        if trained:
            assert model_name is not None
            self.load_model(model_name)
        else:
            self.load_pretrained()

    def cut_model(self, model, cut):
        """Return the first ``cut`` direct children of ``model`` as a list."""
        return list(model.children())[:cut]

    def load_model(self, model_name):
        """Build the encoder/decoder and load the saved state dict from ``unet.h5``.

        The weights file is always ``Path(__file__).parent / 'unet.h5'``;
        ``model_name`` is not used. ``torch.load`` unpickles the file, so it must
        remain a trusted, bundled artifact.
        """
        resnet = resnet50(False)
        # Keep the first 8 children. For a torchvision ResNet these are the stem
        # plus layer1..layer4, dropping avgpool/fc.
        self.backbone = nn.Sequential(*self.cut_model(resnet, 8))
        self.init_head()
        model_path = Path(__file__).parent / 'unet.h5'
        state_dict = torch.load(model_path, map_location=torch.device(device))
        self.load_state_dict(state_dict)

    def load_pretrained(self, torch=False):
        """Build the encoder from a pretrained ResNet-50 and initialise the decoder.

        torch: if True, use ``torchvision`` ``resnet50(True)``; otherwise use
            ``pretrainedmodels``' ``resnet50`` with its default arguments.
            NOTE: this parameter shadows the ``torch`` module inside this method.
        """
        if torch:
            resnet = resnet50(True)
        else:
            resnet = pretrainedmodels.__dict__['resnet50']()
        self.backbone = nn.Sequential(*self.cut_model(resnet, 8))
        self.init_head()

    def init_head(self):
        """Register skip-feature hooks and create the decoder layers."""
        # Hooks on backbone children 2, 4, 5, 6. For a torchvision-style ResNet
        # these are presumably relu, layer1, layer2 and layer3 — confirm.
        self.sfs = [SaveFeature(self.backbone[i]) for i in [2, 4, 5, 6]]
        self.up_layer1 = UpLayer(self.block, 512, 256, self.layers[-1])
        self.up_layer2 = UpLayer(self.block, 256, 128, self.layers[-2])
        self.up_layer3 = UpLayer(self.block, 128, 64, self.layers[-3])
        self.map = conv3x3(64 * self.block.expansion, 64)  # 64e -> 64
        # 128 input channels = 64 from ``map`` + 64 from the sfs[0] skip (assumed).
        self.conv = conv3x3(128, 64)
        self.bn_conv = nn.BatchNorm2d(64)
        self.up_conv = nn.ConvTranspose2d(64, 1, 2, 2, 0)
        self.bn_up = nn.BatchNorm2d(1)

    def forward(self, x):
        """Return the single-channel map relu(bn(up_conv(...))).

        Because of the final ReLU the output is non-negative. ``segment`` applies
        the sigmoid afterwards.
        """
        x = F.relu(self.backbone(x))
        x = self.up_layer1(x, self.sfs[3].features)
        x = self.up_layer2(x, self.sfs[2].features)
        x = self.up_layer3(x, self.sfs[1].features)
        x = self.map(x)
        x = F.interpolate(x, scale_factor=2)
        x = torch.cat([self.sfs[0].features, x], dim=1)
        x = F.relu(self.bn_conv(self.conv(x)))
        x = F.relu(self.bn_up(self.up_conv(x)))
        return x

    def close(self):
        """Remove all feature hooks registered in ``init_head``."""
        for sf in self.sfs:
            sf.remove()

    def segment(self, image):
        """Segment ``image`` and return a padded bounding box around the mask.

        image: cropped CXR PIL Image (h, w, 3)

        Returns ``((t, l, b, r), mask)``:
          - ``mask`` is the opened 256x256 binary prediction, resized with
            cv2's default interpolation back to the image size (float32).
          - The box is the bounding box of the mask's foreground (mask > 0.5),
            enlarged so the object spans ~87.5% of each side, and clamped to the
            image bounds.
          - If no foreground remains, the box is the full image
            ``(0, 0, ih, iw)``.

        NOTE(review): this does not call ``self.eval()``. If the model is still
        in training mode, BatchNorm uses batch statistics. Confirm callers
        switch to eval mode.
        """
        # uint8 structuring element: the conventional OpenCV kernel type.
        kernel = np.ones((10, 10), np.uint8)
        iw, ih = image.size
        image_tensor = self.tfm(image).unsqueeze(0).to(next(self.parameters()).device)
        with torch.no_grad():
            py = torch.sigmoid(self(image_tensor))
        py = (py[0].cpu() > 0.5).type(torch.FloatTensor)  # 1, 256, 256
        mask = py[0].numpy()
        # Morphological opening removes small speckles (< 10x10) from the prediction.
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
        # cv2 expects dsize as (width, height).
        mask = cv2.resize(mask, (iw, ih))

        # find_objects expects an integer label image. The resized float mask
        # has fractional values along edges, so binarize it explicitly.
        labels = (mask > 0.5).astype(np.int32)
        found = ndimage.find_objects(labels, 1)[0]
        if found is None:
            # Nothing segmented: keep the whole image instead of crashing on None.
            return (0, 0, ih, iw), mask
        slice_y, slice_x = found

        h, w = slice_y.stop - slice_y.start, slice_x.stop - slice_x.start
        # Grow the box so the object covers 87.5% of it, split evenly per side.
        nw, nh = int(w / .875), int(h / .875)
        dw, dh = (nw - w) // 2, (nh - h) // 2
        t = max(slice_y.start - dh, 0)
        l = max(slice_x.start - dw, 0)
        b = min(slice_y.stop + dh, ih)
        r = min(slice_x.stop + dw, iw)
        return (t, l, b, r), mask