import os

import numpy as np
from PIL import Image
import cv2
import torch
from torch.utils.data import DataLoader
import albumentations as A
import pytorch_lightning as pl
from transformers import AutoImageProcessor
from datasets import Dataset, DatasetDict


# Checkpoint used to load the matching Hugging Face image processor.
MODEL_CHECKPOINT = "apple/deeplabv3-mobilevit-xx-small"

# Target (width, height) used by the Albumentations resize step.
IMG_SIZE = [256, 256]


class FluorescentNeuronalDataModule(pl.LightningDataModule):
    def __init__(self, batch_size, data_dir, dataset_size=1.0):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.dataset_size = dataset_size

        # Image processor matching the checkpoint; it applies the model-specific
        # preprocessing and converts the inputs to tensors.
        self.image_processor = AutoImageProcessor.from_pretrained(
            MODEL_CHECKPOINT, do_reduce_labels=False
        )
        # Resize images and masks to IMG_SIZE using nearest-neighbour interpolation.
        self.image_resizer = A.Compose(
            [
                A.Resize(
                    width=IMG_SIZE[0],
                    height=IMG_SIZE[1],
                    interpolation=cv2.INTER_NEAREST,
                )
            ]
        )
        # Train-time augmentations, each applied independently with p=0.6.
        self.image_augmentator = A.Compose(
            [
                A.HorizontalFlip(p=0.6),
                A.VerticalFlip(p=0.6),
                A.RandomBrightnessContrast(p=0.6),
                A.RandomGamma(p=0.6),
                A.HueSaturationValue(p=0.6),
            ]
        )

    def _create_dataset(self):
        images_path = os.path.join(self.data_dir, "all_images", "images")
        masks_path = os.path.join(self.data_dir, "all_masks", "masks")
        list_images = os.listdir(images_path)

        # Optionally subsample the dataset for quicker experiments.
        if self.dataset_size < 1.0:
            n_images = int(len(list_images) * self.dataset_size)
            list_images = list_images[:n_images]

        images = []
        masks = []
        for image_filename in list_images:
            image_path = os.path.join(images_path, image_filename)
            mask_path = os.path.join(masks_path, image_filename)

            image = np.array(Image.open(image_path).convert("RGB"), dtype=np.uint8)
            mask = np.array(Image.open(mask_path).convert("L"), dtype=np.uint8)
            # Convert 0/255 mask values to 0/1 class ids.
            mask = (mask / 255).astype(np.uint8)

            images.append(image)
            masks.append(mask)

        dataset = Dataset.from_dict({"image": images, "mask": masks})

        # 90/10 split into train+validation and test...
        dataset = dataset.train_test_split(test_size=0.1)
        train_val = dataset["train"]
        test_ds = dataset["test"]
        del dataset

        # ...then 80/20 split of the remainder into train and validation,
        # giving roughly 72% train / 18% validation / 10% test overall.
        train_val = train_val.train_test_split(test_size=0.2)
        train_ds = train_val["train"]
        valid_ds = train_val["test"]
        del train_val

        dataset = DatasetDict(
            {"train": train_ds, "validation": valid_ds, "test": test_ds}
        )
        del train_ds, valid_ds, test_ds
        return dataset

    def _transform_train_data(self, batch):
        images, masks = [], []
        for i, m in zip(batch["image"], batch["mask"]):
            img = np.asarray(i, dtype=np.uint8)
            mask = np.asarray(m, dtype=np.uint8)

            # Each source sample contributes two training examples: the resized
            # original and an independently augmented copy.
            resized_outputs = self.image_resizer(image=img, mask=mask)
            images.append(resized_outputs["image"])
            masks.append(resized_outputs["mask"])

            augmented_outputs = self.image_augmentator(
                image=resized_outputs["image"], mask=resized_outputs["mask"]
            )
            images.append(augmented_outputs["image"])
            masks.append(augmented_outputs["mask"])

        inputs = self.image_processor(
            images=images,
            return_tensors="pt",
        )
        # Stack the masks into one array before converting to a tensor.
        inputs["labels"] = torch.as_tensor(np.stack(masks), dtype=torch.long)
        return inputs

    def _transform_data(self, batch):
        # Validation/test transform: resize only, no augmentation.
        images, masks = [], []
        for i, m in zip(batch["image"], batch["mask"]):
            img = np.asarray(i, dtype=np.uint8)
            mask = np.asarray(m, dtype=np.uint8)

            resized_outputs = self.image_resizer(image=img, mask=mask)
            images.append(resized_outputs["image"])
            masks.append(resized_outputs["mask"])

        inputs = self.image_processor(
            images=images,
            return_tensors="pt",
        )
        inputs["labels"] = torch.as_tensor(np.stack(masks), dtype=torch.long)
        return inputs

    def setup(self, stage=None):
        dataset = self._create_dataset()
        train_ds = dataset["train"]
        valid_ds = dataset["validation"]
        test_ds = dataset["test"]

        if stage is None or stage == "fit":
            self.train_ds = train_ds.with_transform(self._transform_train_data)
            self.valid_ds = valid_ds.with_transform(self._transform_data)
        if stage is None or stage == "test" or stage == "predict":
            self.test_ds = test_ds.with_transform(self._transform_data)

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.valid_ds, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_ds, batch_size=self.batch_size)

    def predict_dataloader(self):
        return DataLoader(self.test_ds, batch_size=self.batch_size)
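

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the module's API).
    # The data directory below is a placeholder: it is assumed to contain
    # "all_images/images" and "all_masks/masks" subfolders with identically
    # named image/mask files, as expected by _create_dataset above.
    datamodule = FluorescentNeuronalDataModule(
        batch_size=8,
        data_dir="path/to/fluorescent-neuronal-cells",  # placeholder path
        dataset_size=0.1,  # load only 10% of the images for a quick check
    )
    datamodule.setup("fit")
    batch = next(iter(datamodule.train_dataloader()))
    # The processor output carries "pixel_values"; masks are added as "labels".
    print(batch["pixel_values"].shape, batch["labels"].shape)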