Spaces:
Sleeping
Sleeping
import torch.nn as nn | |
import pretrainedmodels | |
from torchvision.models import densenet121 | |
from layers import Flatten | |
import torch | |
import torchvision.transforms as transforms | |
from pathlib import Path | |
from constant import IMAGENET_MEAN, IMAGENET_STD | |
import os | |
import sys | |
# Make the ``chestXray14`` directory that sits next to this file's directory
# (i.e. ``<script_dir>/../chestXray14``) importable.
# NOTE(review): this runs after ``from layers import Flatten`` and
# ``from constant import ...`` above; if those modules live in chestXray14,
# the path has to be added before those imports -- confirm.
script_dir = os.path.dirname(os.path.abspath(__file__))
chestxray14_dir = os.path.join(script_dir, '..', 'chestXray14')
# Backward-compatible alias: the old name ``yolov9`` does not describe this path.
yolov9 = chestxray14_dir
# Avoid a duplicate sys.path entry if this module is executed more than once
# (e.g. importlib.reload or runpy).
if chestxray14_dir not in sys.path:
    sys.path.append(chestxray14_dir)
class ChexNet(nn.Module):
    """DenseNet-121 feature extractor with a 14-way linear head.

    The head is ``AdaptiveAvgPool2d(1) -> Flatten -> Linear(1024, 14)``, and
    ``predict`` applies a sigmoid to the 14 logits. The 14 outputs presumably
    correspond to the ChestX-ray14 finding labels -- confirm against the
    training code.
    """

    # Preprocessing used by ``predict``: resize to 224x224, convert to a tensor,
    # and normalize with the IMAGENET_MEAN / IMAGENET_STD statistics.
    tfm = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ])

    def __init__(self, trained=False, model_name='20180525-222635'):
        """Build the network.

        Args:
            trained: if True, build the architecture and load fine-tuned
                weights via ``load_model``; otherwise initialise the backbone
                from pretrained weights via ``load_pretrained``.
            model_name: forwarded to ``load_model`` (currently unused there).
        """
        super().__init__()
        # NOTE: no parameters are frozen here; if only the head should be
        # trained, freeze ``self.backbone`` explicitly.
        if trained:
            self.load_model(model_name)
        else:
            self.load_pretrained()

    @staticmethod
    def _make_head():
        """Return the classification head shared by both construction paths."""
        return nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            Flatten(),
            nn.Linear(1024, 14),
        )

    def load_model(self, model_name):
        """Build the architecture without pretrained weights and load ``chexnet.h5``.

        Args:
            model_name: currently unused -- the weights file name is fixed.
                Kept so existing callers keep working.

        The weights file is opened relative to the current working directory.
        """
        self.backbone = densenet121(False).features
        self.head = self._make_head()
        # map_location='cpu': without it, a checkpoint saved from CUDA tensors
        # cannot be deserialized on a machine without a GPU.
        state_dict = torch.load('chexnet.h5', map_location='cpu')
        self.load_state_dict(state_dict)

    def load_pretrained(self, torch=False):
        """Initialise the backbone from pretrained DenseNet-121 weights.

        Args:
            torch: if True, take the backbone from torchvision's
                ``densenet121(True)``; otherwise from the ``pretrainedmodels``
                package. This parameter shadows the ``torch`` module inside
                this method; the name is kept for backward compatibility.

        The head is freshly (randomly) initialised.
        """
        if torch:
            self.backbone = densenet121(True).features
        else:
            self.backbone = pretrainedmodels.__dict__['densenet121']().features
        self.head = self._make_head()

    def forward(self, x):
        """Return the raw 14 logits for a batch ``x``."""
        return self.head(self.backbone(x))

    def predict(self, image):
        """Return per-class sigmoid probabilities for a single image.

        Args:
            image: a PIL image. Images whose ``mode`` is not ``'RGB'`` (e.g.
                grayscale ``'L'`` or ``'RGBA'`` radiographs) are converted to
                RGB first, because the normalization step expects exactly
                three channels.

        Returns:
            numpy array with the 14 probabilities.

        The forward pass runs in eval mode, so BatchNorm uses (and does not
        update) its running statistics. The module's previous train/eval mode
        is restored afterwards.
        """
        if getattr(image, 'mode', 'RGB') != 'RGB':
            image = image.convert('RGB')
        image_tensor = self.tfm(image).unsqueeze(0)  # add batch dimension
        # Move the input to the same device as the model's parameters.
        image_tensor = image_tensor.to(next(self.parameters()).device)
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                py = torch.sigmoid(self(image_tensor))
        finally:
            self.train(was_training)
        return py.cpu().numpy()[0]