# Author: abhirajeshbhai
# Commit 03856d4: loaded new colorizer weights
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
# Target device name as a plain string: "cuda" when a CUDA device is
# available to torch, otherwise "cpu".
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def _resize_and_tensorize():
    """Return fresh ``[Resize((256, 256)), ToTensor()]`` steps for a pipeline."""
    return [
        torchvision.transforms.Resize((256, 256)),
        torchvision.transforms.ToTensor(),
    ]


# Pipeline for colour input images: resize to 256x256, convert to a float
# tensor, apply a Normalize whose mean 0 / std 1 leaves values unchanged, then
# collapse to one grayscale channel. Despite the "rgb" name, the output is
# single-channel (presumably the model input: UNETFruitColor's first conv
# takes 1 channel).
image_transforms_rgb = torchvision.transforms.Compose(
    _resize_and_tensorize()
    + [
        torchvision.transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
        torchvision.transforms.Grayscale(),
    ]
)

# Pipeline presumably meant for already-grayscale images: same resize/tensor
# steps, followed by an identity (mean 0, std 1) single-value Normalize.
image_transforms_gs = torchvision.transforms.Compose(
    _resize_and_tensorize()
    + [torchvision.transforms.Normalize(mean=[0.0], std=[1.0])]
)
class ConvBlock(nn.Module):
    """Two 3x3 convolutions, each followed by BatchNorm2d and an in-place ReLU.

    Stride 1 with padding 1 keeps the spatial size unchanged; the channel
    count goes from ``in_channel`` to ``out_channel`` in the first conv and
    stays at ``out_channel`` in the second.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        stages = []
        # First stage reads in_channel, second reads out_channel; layers are
        # created in the same order as a hand-written list would be.
        for source_channels in (in_channel, out_channel):
            stages.append(
                nn.Conv2d(source_channels, out_channel, kernel_size=3, stride=1, padding=1)
            )
            stages.append(nn.BatchNorm2d(out_channel))
            stages.append(nn.ReLU(True))
        # The attribute name and flat layer order determine the state_dict
        # keys (main.0.*, main.1.*, ...), which the module-level
        # load_state_dict(..., strict=True) requires to stay stable.
        self.main = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv/BN/ReLU stages to ``x``."""
        return self.main(x)
class UNETFruitColor(nn.Module):
    """U-Net mapping an ``in_channels`` image to ``out_channels`` outputs.

    Encoder: one ConvBlock per width in ``features`` (shallowest first), each
    followed by 2x2 max-pooling. Bottleneck: a ConvBlock doubling the deepest
    width. Decoder (deepest level first): a 2x2 stride-2 transposed conv, a
    channel-wise concat with the matching encoder output, then a ConvBlock.
    A final 1x1 conv projects to ``out_channels``.

    Args:
        in_channels: Channels of the input tensor (default 1).
        out_channels: Channels produced by the final 1x1 conv (default 3).
        features: Encoder widths, shallowest first. The defaults reproduce the
            original 64/128/256/512 layout. Submodule names (convEncoder,
            bottleNeck, convDecoder, decoderUpConvs, finalUpConv) are
            unchanged, so state dicts saved from the original still load.

    Raises:
        ValueError: If ``features`` is empty.
    """

    def __init__(self, in_channels=1, out_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        self.convs = list(features)
        if not self.convs:
            raise ValueError("features must contain at least one width")

        self.convEncoder = nn.ModuleList()
        in_feature = in_channels
        for conv in self.convs:
            self.convEncoder.append(ConvBlock(in_feature, conv))
            in_feature = conv

        self.bottleNeck = ConvBlock(self.convs[-1], self.convs[-1] * 2)
        in_feature = self.convs[-1] * 2

        self.convDecoder = nn.ModuleList()
        self.decoderUpConvs = nn.ModuleList()
        for conv in reversed(self.convs):
            # The decoder block sees the upsampled tensor (conv channels)
            # concatenated with the skip (conv channels): 2 * conv in total.
            # Using the previous level's width here would only be correct when
            # every width is exactly double the one before (true for the
            # defaults, where both expressions are equal).
            self.convDecoder.append(ConvBlock(2 * conv, conv))
            self.decoderUpConvs.append(
                nn.ConvTranspose2d(in_feature, conv, kernel_size=2, stride=2, padding=0)
            )
            in_feature = conv

        # 1x1 projection from the shallowest decoder width to the outputs.
        self.finalUpConv = nn.Conv2d(in_feature, out_channels, (1, 1))
        # Defined but not applied in forward(); kept so the attribute exists.
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Run the U-Net.

        Args:
            x: Tensor of shape (N, in_channels, H, W).

        Returns:
            Tensor of shape (N, out_channels, H', W') holding the raw 1x1-conv
            output (no sigmoid applied). With k = len(self.convs),
            H' = 2**k * floor(H / 2**k) (likewise W'), so the size is preserved
            when H and W are divisible by 2**k.
        """
        skip_conns = []
        for encode in self.convEncoder:
            x = encode(x)
            skip_conns.append(x)
            x = F.max_pool2d(x, (2, 2), stride=2)

        x = self.bottleNeck(x)

        levels = zip(self.decoderUpConvs, self.convDecoder, reversed(skip_conns))
        for upconv, decode, skp in levels:
            x = upconv(x)
            # When max-pooling floored an odd size, the skip is larger than the
            # upsampled tensor; resize it to match before concatenating. A
            # same-size resize is an identity, so it is skipped entirely.
            if skp.shape[-2:] != x.shape[-2:]:
                skp = torchvision.transforms.Resize((x.shape[2], x.shape[3]))(skp)
            x = torch.cat([skp, x], dim=1)
            x = decode(x)

        return self.finalUpConv(x)
# Build the colorizer, wrap it in DataParallel (its state_dict keys carry a
# "module." prefix), move it to the selected device, load the checkpoint with
# strict key matching, and switch to inference mode.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
# On torch >= 1.13, weights_only=True is an option for plain state dicts.
model = nn.DataParallel(UNETFruitColor()).to(device)
model.load_state_dict(
    torch.load("unet_colorizer_flickr_5_93_Ploss_10_14K.pth", map_location=device),
    strict=True,
)
model.eval()