# U-Net grayscale-to-RGB "fruit colorizer": input transforms, model
# definition, and checkpoint loading for inference.
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F

# Run on GPU when one is available; every .to(device)/map_location below
# uses this string.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Pipeline for RGB source images: resize to 256x256, convert to a float
# tensor in [0, 1], then collapse to a single channel — the network takes
# 1-channel (grayscale) input. Note Grayscale is applied *after* ToTensor,
# so Normalize still sees 3 channels. Normalize with mean 0 / std 1 is an
# identity transform ((x - 0) / 1), kept presumably as a placeholder for
# real statistics — confirm before relying on it.
image_transforms_rgb = torchvision.transforms.Compose([
    torchvision.transforms.Resize((256, 256)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
    torchvision.transforms.Grayscale(),
])
# Pipeline for images that are already single-channel grayscale: resize to
# 256x256 and convert to a float tensor in [0, 1]. The single-channel
# Normalize with mean 0 / std 1 is an identity transform ((x - 0) / 1).
image_transforms_gs = torchvision.transforms.Compose([
    torchvision.transforms.Resize((256, 256)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.0], std=[1.0]),
])
class ConvBlock(nn.Module):
    """Standard U-Net "double conv": two 3x3 Conv -> BatchNorm -> ReLU stages.

    Stride 1 with padding 1 preserves the spatial size; only the channel
    count changes, from ``in_channel`` to ``out_channel``.
    """

    def __init__(self, in_channel, out_channel):
        super(ConvBlock, self).__init__()
        # Attribute name ``main`` must stay unchanged so saved state_dict
        # keys (main.0.weight, ...) keep matching.
        self.main = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(True),
            nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(True),
        )

    def forward(self, x):
        """Apply both conv stages; output has ``out_channel`` channels."""
        return self.main(x)
class UNETFruitColor(nn.Module):
    """U-Net that maps a 1-channel grayscale image to 3-channel RGB logits.

    Encoder: four ConvBlocks (64 -> 128 -> 256 -> 512), each followed by
    2x2 max pooling. Bottleneck: ConvBlock to 1024 channels. Decoder: four
    2x2 transpose-conv upsamplings, each concatenated with the matching
    encoder skip map before a ConvBlock. Head: 1x1 conv down to 3 channels.
    The output is raw (the final sigmoid is deliberately commented out in
    forward), so callers must squash it themselves if they need [0, 1].
    """

    def __init__(self):
        super(UNETFruitColor, self).__init__()
        # Attribute names below must stay unchanged to match checkpoint keys.
        self.convs = [64, 128, 256, 512]
        self.convEncoder = nn.ModuleList()
        in_feature = 1  # grayscale input
        for conv in self.convs:
            self.convEncoder.append(ConvBlock(in_feature, conv))
            in_feature = conv
        self.bottleNeck = ConvBlock(self.convs[-1], self.convs[-1] * 2)
        in_feature = self.convs[-1] * 2
        self.convDecoder = nn.ModuleList()
        self.decoderUpConvs = nn.ModuleList()
        for conv in self.convs[::-1]:
            # Decoder block input = upsampled channels + skip channels
            # (both equal to ``conv`` after the transpose conv), i.e.
            # in_feature = 2 * conv at each stage.
            self.convDecoder.append(ConvBlock(in_feature, conv))
            self.decoderUpConvs.append(
                nn.ConvTranspose2d(in_feature, conv, kernel_size=2, stride=2, padding=0)
            )
            in_feature = conv
        # Final 1x1 projection from 64 channels to RGB.
        self.finalUpConv = nn.Conv2d(in_feature, 3, (1, 1))
        # Unused in forward (see commented line there); kept so the module
        # structure matches the code the checkpoint was saved from.
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Run the U-Net. Expects x of shape (N, 1, H, W); returns (N, 3, ~H, ~W).

        Output spatial size equals the input's exactly when H and W are
        divisible by 16 (four pooling stages); otherwise pooling floors the
        sizes and the skip maps are resized to match on the way up.
        """
        skip_conns = []
        for conv in self.convEncoder:
            x = conv(x)
            skip_conns.append(x)  # saved for the decoder's concatenation
            x = F.max_pool2d(x, (2, 2), stride=2)
        x = self.bottleNeck(x)
        skip_conns = skip_conns[::-1]  # decoder consumes deepest skip first
        for idx in range(len(self.convDecoder)):
            upconv = self.decoderUpConvs[idx]
            deconv = self.convDecoder[idx]
            skp = skip_conns[idx]
            x = upconv(x)
            # Resize the skip map to x's spatial size before concatenating
            # (identity when the input side lengths are divisible by 16).
            x_cat = torchvision.transforms.Resize((x.shape[2], x.shape[3]))(skp)
            x = torch.cat([x_cat, x], dim=1)
            x = deconv(x)
        x = self.finalUpConv(x)
        # x = self.sigmoid(x)  # intentionally disabled: raw logits are returned
        return x
# Build the network and wrap it in DataParallel *before* loading the
# checkpoint — presumably the state_dict was saved from a DataParallel
# model, so its keys carry the "module." prefix (strict=True would fail
# otherwise; verify against how the checkpoint was produced).
model = UNETFruitColor()
model = nn.DataParallel(model).to(device)
# NOTE(review): torch.load without weights_only unpickles arbitrary
# objects — only load checkpoints from a trusted source.
model.load_state_dict(
    torch.load("unet_colorizer_flickr_5_93_Ploss_10_14K.pth", map_location=device),
    strict=True,
)
model.eval()  # inference mode: freezes BatchNorm running statistics