|
from torchvision.transforms import ( |
|
Normalize, |
|
Compose, |
|
RandomResizedCrop, |
|
InterpolationMode, |
|
ToTensor, |
|
Resize, |
|
CenterCrop, |
|
) |
|
|
|
|
|
def _convert_to_rgb(image): |
|
return image.convert("RGB") |
|
|
|
|
|
def image_transform(
    image_size: "int | tuple[int, int]",
    is_train: bool,
    mean=(0.48145466, 0.4578275, 0.40821073),
    std=(0.26862954, 0.26130258, 0.27577711),
):
    """Build the torchvision preprocessing pipeline for one image.

    Args:
        image_size: Target output size. An int yields a square
            ``image_size x image_size`` output. A ``(height, width)`` list or
            tuple is also accepted; a square pair is treated exactly like the
            equivalent int so the evaluation resize keeps the aspect ratio.
        is_train: If true, return the training pipeline (bicubic
            ``RandomResizedCrop`` with ``scale=(0.9, 1.0)``). Otherwise return
            the deterministic evaluation pipeline (bicubic ``Resize`` followed
            by ``CenterCrop``).
        mean: Per-channel mean passed to ``Normalize``.
        std: Per-channel standard deviation passed to ``Normalize``.

    Returns:
        A ``Compose`` that takes an image exposing ``.convert("RGB")``
        (presumably a PIL image) and returns a normalized tensor.
    """
    # Resize((s, s)) stretches the image to exactly s x s and distorts its
    # aspect ratio. Resize(s) instead scales the shorter edge to s and leaves
    # the excess for CenterCrop to trim. Collapse a square pair to the int so
    # both spellings of a square size behave the same.
    if (
        isinstance(image_size, (list, tuple))
        and len(image_size) == 2
        and image_size[0] == image_size[1]
    ):
        image_size = image_size[0]

    normalize = Normalize(mean=mean, std=std)

    # Shared tail of both pipelines. RGB conversion runs after the geometric
    # step, so resampling happens in the image's original mode. This order is
    # kept to preserve existing outputs.
    # NOTE(review): Pillow may not honour BICUBIC for palette ("P") images;
    # converting first would change outputs for such inputs.
    finalize = [_convert_to_rgb, ToTensor(), normalize]

    if is_train:
        geometric = [
            RandomResizedCrop(
                image_size,
                scale=(0.9, 1.0),
                interpolation=InterpolationMode.BICUBIC,
            ),
        ]
    else:
        geometric = [
            Resize(image_size, interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
        ]

    return Compose(geometric + finalize)
|
|