import os
import torch
import matplotlib.pyplot as plt
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class ImageDataset(Dataset):
    # Pairs each image with itself: the grayscale ('L') conversion is the
    # model input and the RGB conversion is the colorization target.
    def __init__(self, image_dir, transform=None) -> None:
        self.image_dir = image_dir
        self.transform = transform
        self.file_list = sorted(os.listdir(self.image_dir))

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        image_name = self.file_list[idx]
        image_path = os.path.join(self.image_dir, image_name)
        grayscale_image = Image.open(image_path).convert('L')
        colorized_image = Image.open(image_path).convert('RGB')
        if self.transform:
            grayscale_image = self.transform(grayscale_image)
            colorized_image = self.transform(colorized_image)
        return grayscale_image, colorized_image


def show_image(image_tensor):
    try:
        # Single-channel (1, H, W) tensors are drawn with the gray colormap;
        # multi-channel tensors are rearranged to (H, W, C) for matplotlib.
        if image_tensor.shape[0] == 1:
            plt.imshow(image_tensor[0], cmap="gray")
        else:
            plt.imshow(image_tensor.numpy().transpose(1, 2, 0))
    except Exception as e:
        print(f"Exception when showing image: {e}")


# Resize the output tensor to the target's spatial size so that MSE loss
# can still be computed when the output and target shapes differ.
def adjust_output_shape(output_tensor, target_tensor):
    adjusted_tensor = torch.nn.functional.interpolate(
        output_tensor, size=target_tensor.shape[2:], mode="bilinear", align_corners=False
    )
    return adjusted_tensor
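
# A hypothetical example (the shapes below are illustrative): a smaller
# (1, 3, 128, 128) model output is upsampled to a (1, 3, 256, 256) target
# before the loss is taken.
#   output = torch.randn(1, 3, 128, 128)
#   target = torch.randn(1, 3, 256, 256)
#   loss = torch.nn.functional.mse_loss(adjust_output_shape(output, target), target)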


# Convert a PIL image to a (1, C, H, W) float tensor, adding a batch dim.
def pil_to_torch(pil_image):
    transform = transforms.ToTensor()
    return transform(pil_image).unsqueeze(0)


# Drop the batch dimension and convert a (1, C, H, W) tensor back to PIL.
def torch_to_pil(torch_image):
    transform = transforms.ToPILImage()
    return transform(torch_image.squeeze(0))