File size: 2,012 Bytes
4d4dd90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
"""
Simply load images from a folder or nested folders (does not have any split).
"""

import logging
from pathlib import Path

import omegaconf
import torch

from ..utils.image import ImagePreprocessor, load_image
from .base_dataset import BaseDataset


class ImageFolder(BaseDataset, torch.utils.data.Dataset):
    """Flat dataset over a collection of images, without any split.

    The ``images`` configuration entry selects the collection:

    * a path to a directory: the directory is searched recursively with
      every pattern in ``glob``; the directory becomes the root and the
      stored image paths are relative to it;
    * a path to a text file: one image path per line, resolved against
      ``root_folder``;
    * a list of image paths, also resolved against ``root_folder``.

    Absolute image paths are unaffected by the root because joining a
    path with an absolute path yields the absolute path.
    """

    default_conf = {
        "glob": ["*.jpg", "*.png", "*.jpeg", "*.JPG", "*.PNG"],
        # "???" is OmegaConf's marker for a mandatory value.
        "images": "???",
        "root_folder": "/",
        "preprocessing": ImagePreprocessor.default_conf,
    }

    def _init(self, conf):
        """Collect the image paths described by ``conf``.

        Presumably invoked by ``BaseDataset`` once the configuration has
        been merged with ``default_conf`` (not visible from this file).

        Raises:
            ValueError: if ``conf.images`` is a directory without any
                matching image, or is neither a string nor a list.
        """
        self.root = conf.root_folder
        if isinstance(conf.images, str):
            source = Path(conf.images)
            if source.is_dir():
                patterns = [conf.glob] if isinstance(conf.glob, str) else conf.glob
                # A set removes files matched by several patterns (e.g.
                # "*.jpg" and "*.JPG" on a case-insensitive filesystem).
                found = set()
                for pattern in patterns:
                    found.update(source.glob("**/" + pattern))
                if not found:
                    raise ValueError(
                        f"Could not find any image in folder: {conf.images}."
                    )
                # Sorting makes the index -> image mapping reproducible;
                # glob order depends on the filesystem.
                self.images = sorted(p.relative_to(source) for p in found)
                self.root = conf.images
                logging.info(f"Found {len(self.images)} images in folder.")
            else:
                with open(source, "r") as f:
                    lines = f.read().splitlines()
                # Skip blank lines so that an empty or padded list file
                # does not produce empty image paths.
                self.images = [line for line in lines if line.strip()]
                logging.info(f"Found {len(self.images)} images in list file.")
        elif isinstance(conf.images, omegaconf.listconfig.ListConfig):
            # Conversion to a plain list goes through the static OmegaConf API.
            self.images = omegaconf.OmegaConf.to_container(conf.images)
        else:
            raise ValueError(
                "`images` must be a folder, a list file or a list of paths, "
                f"got {type(conf.images).__name__}: {conf.images!r}"
            )

        self.preprocessor = ImagePreprocessor(conf.preprocessing)

    def get_dataset(self, split):
        """Return this dataset itself; ``split`` is ignored."""
        return self

    def __getitem__(self, idx):
        """Load and preprocess the image at position ``idx``.

        Returns:
            A dict with the key ``"name"`` (the stored image path as a
            string, without the root) merged with the output of the
            preprocessor.
        """
        path = self.images[idx]
        # Stored paths may be relative to the root (always the case in
        # folder mode), so the root has to be prepended before loading.
        img = load_image(Path(self.root) / path)
        data = {"name": str(path), **self.preprocessor(img)}
        return data

    def __len__(self):
        """Return the number of images in the collection."""
        return len(self.images)