import os
import sys
from pathlib import Path

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.models import densenet121
import pretrainedmodels

# Make the sibling ``chestXray14`` directory importable. This must run BEFORE
# the project-local imports below: previously it ran after
# ``from layers import ...`` / ``from constant import ...`` and so could not
# help those imports resolve.
script_dir = os.path.dirname(os.path.abspath(__file__))
chestxray14_dir = os.path.join(script_dir, '..', 'chestXray14')
yolov9 = chestxray14_dir  # legacy (misleading) name, kept for external references
if chestxray14_dir not in sys.path:
    sys.path.append(chestxray14_dir)

from layers import Flatten  # noqa: E402  (relies on the sys.path entry above)
from constant import IMAGENET_MEAN, IMAGENET_STD  # noqa: E402


class ChexNet(nn.Module):
    """DenseNet-121 feature extractor followed by a 14-output linear head.

    ``forward`` returns raw logits; ``predict`` applies an element-wise sigmoid
    to turn them into independent per-class probabilities for one image.
    """

    # Preprocessing used by ``predict``: resize to 224x224, convert to a float
    # tensor, then normalise with the IMAGENET_MEAN / IMAGENET_STD constants.
    tfm = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ])

    def __init__(self, trained=False, model_name='20180525-222635',
                 weights_path='chexnet.h5'):
        """Build the network.

        Args:
            trained: If true, build an uninitialised backbone and load saved
                weights via ``load_model``; otherwise start from a pretrained
                backbone with a fresh head via ``load_pretrained``.
            model_name: Forwarded to ``load_model`` (currently unused there).
            weights_path: Checkpoint file forwarded to ``load_model``.
        """
        super().__init__()
        # NOTE(review): an earlier comment claimed all parameters except the
        # head are frozen, but nothing in this class sets ``requires_grad``;
        # any freezing has to be done by the caller.
        if trained:
            self.load_model(model_name, weights_path=weights_path)
        else:
            self.load_pretrained()

    @staticmethod
    def _build_head():
        """Return the classifier head: global avg-pool -> Flatten -> Linear(1024, 14)."""
        # NOTE(review): torchvision's own DenseNet applies a ReLU between
        # ``features`` and pooling; this head does not. Left unchanged because
        # adding it would alter the outputs of existing checkpoints.
        return nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            Flatten(),
            nn.Linear(1024, 14),
        )

    def load_model(self, model_name, weights_path='chexnet.h5'):
        """Build a DenseNet-121 backbone without pretrained weights, add the
        head, and load a saved ``state_dict`` into this module.

        Args:
            model_name: Currently unused. NOTE(review): presumably meant to
                select a checkpoint (it looks like a run timestamp); confirm
                and wire it into ``weights_path`` if so.
            weights_path: File containing a state dict for this module
                (presumably written with ``torch.save``). A relative path is
                resolved against the current working directory.
        """
        self.backbone = densenet121(False).features
        self.head = self._build_head()
        # Deserialize onto the CPU so a checkpoint saved from a GPU still loads
        # on a CPU-only machine; load_state_dict then copies the values into
        # this module's parameters. torch.load unpickles the file, so only
        # point it at trusted checkpoints.
        state_dict = torch.load(weights_path, map_location='cpu')
        self.load_state_dict(state_dict)

    def load_pretrained(self, torch=False):
        """Build the backbone from pretrained weights and attach a fresh head.

        Args:
            torch: If true, take the backbone from torchvision's
                ``densenet121(True)``; otherwise from the ``pretrainedmodels``
                package. (This parameter shadows the ``torch`` module inside
                this method; the name is kept for keyword-argument callers.)
        """
        if torch:
            self.backbone = densenet121(True).features
        else:
            self.backbone = pretrainedmodels.__dict__['densenet121']().features
        self.head = self._build_head()

    def forward(self, x):
        """Return the head's raw logits (14 per sample) for the image batch ``x``.

        ``x`` is presumably a (batch, 3, H, W) float tensor -- confirm with callers.
        """
        return self.head(self.backbone(x))

    def predict(self, image):
        """Return per-class probabilities for a single image.

        Args:
            image: A PIL image. Images in any mode other than 'RGB' (e.g. 'L'
                or 'RGBA') are converted to RGB first, because the DenseNet
                backbone expects three input channels.

        Returns:
            A 1-D numpy array of 14 sigmoid probabilities.
        """
        if image.mode != 'RGB':
            # Otherwise ToTensor yields 1 or 4 channels and the normalisation /
            # first convolution fail.
            image = image.convert('RGB')
        batch = self.tfm(image).unsqueeze(0)  # add batch dimension
        batch = batch.to(next(self.parameters()).device)  # match the model's device

        # nn.Module starts in training mode, where BatchNorm would use this
        # single image's statistics and update its running stats. Evaluate in
        # eval mode and restore whatever mode the caller had set.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                probs = torch.sigmoid(self(batch))
        finally:
            self.train(was_training)
        return probs.cpu().numpy()[0]