Holiday-StyleGAN-NADA / e4e /configs /transforms_config.py
mjdolan's picture
Duplicate from Gradio-Blocks/StyleGAN-NADA
07998f9
raw
history blame
1.77 kB
from abc import ABC, abstractmethod

import torchvision.transforms as transforms
class TransformsConfig(ABC):
    """Abstract base for transform configurations.

    Subclasses must implement :meth:`get_transforms`. Inheriting from ``ABC``
    is what makes ``@abstractmethod`` take effect: with a plain ``object``
    base the decorator is ignored. In that case the base class, or a subclass
    that forgot to override ``get_transforms``, could be instantiated and
    would silently return ``None``.

    Args:
        opts: Options object stored on the instance for subclasses to use;
            this base class does not read it.
    """

    def __init__(self, opts):
        self.opts = opts

    @abstractmethod
    def get_transforms(self):
        """Return the named transform pipelines for this configuration.

        The subclasses in this module return a dict mapping names such as
        'transform_gt_train' or 'transform_test' to a transform pipeline, or
        to None when no transform applies.
        """
class EncodeTransforms(TransformsConfig):
    """Encoder transforms that resize every image to 256x256.

    Each pipeline resizes, converts to a tensor, and normalizes all three
    channels with mean 0.5 and std 0.5. Only the training ground-truth
    pipeline also applies a random horizontal flip with probability 0.5.
    """

    def __init__(self, opts):
        super().__init__(opts)

    def get_transforms(self):
        """Build and return the named transform pipelines.

        Returns a dict with the keys 'transform_gt_train', 'transform_source'
        (always None), 'transform_test' and 'transform_inference'. New
        transform objects are created on every call.
        """
        target_size = (256, 256)

        def make_pipeline(*, flip):
            # Build fresh step instances for each pipeline so none are shared.
            steps = [transforms.Resize(target_size)]
            if flip:
                steps.append(transforms.RandomHorizontalFlip(0.5))
            steps.extend([
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ])
            return transforms.Compose(steps)

        return {
            'transform_gt_train': make_pipeline(flip=True),
            'transform_source': None,
            'transform_test': make_pipeline(flip=False),
            'transform_inference': make_pipeline(flip=False),
        }
class CarsEncodeTransforms(TransformsConfig):
    """Encoder transforms that resize images to (192, 256).

    torchvision's ``Resize`` reads a 2-sequence as (height, width), so output
    images are 192 pixels high and 256 wide. Every pipeline ends with
    ``ToTensor`` followed by a per-channel ``Normalize`` with mean 0.5 and
    std 0.5. The training ground-truth pipeline also applies a random
    horizontal flip with probability 0.5.
    """

    def __init__(self, opts):
        super().__init__(opts)

    def get_transforms(self):
        """Build and return the named transform pipelines.

        Returns a dict with the keys 'transform_gt_train', 'transform_source'
        (always None), 'transform_test' and 'transform_inference'. New
        transform objects are created on every call.
        """
        def build(with_flip):
            head = [transforms.Resize((192, 256))]
            middle = [transforms.RandomHorizontalFlip(0.5)] if with_flip else []
            tail = [
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
            return transforms.Compose(head + middle + tail)

        gt_train = build(True)
        test = build(False)
        inference = build(False)
        return {
            'transform_gt_train': gt_train,
            'transform_source': None,
            'transform_test': test,
            'transform_inference': inference,
        }