import copy
import glob
import os
from multiprocessing.dummy import Pool as ThreadPool

from PIL import Image
from torchvision.transforms.functional import to_tensor

import torch
from torch import nn

from ..Models import *


class ImageSplitter:
    # Key points:
    # Border padding and overlapping image splitting to avoid instability at patch edges.
    # Thanks to waifu2x's author nagadomi for the suggestion (https://github.com/nagadomi/waifu2x/issues/238).

    def __init__(self, seg_size=48, scale_factor=2, border_pad_size=3):
        self.seg_size = seg_size
        self.scale_factor = scale_factor
        self.pad_size = border_pad_size
        self.height = 0
        self.width = 0
        self.upsampler = nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)

    def split_img_tensor(self, pil_img, scale_method=Image.BILINEAR, img_pad=0):
        # Convert the image into a tensor and replication-pad the border.
        img_tensor = to_tensor(pil_img).unsqueeze(0)
        img_tensor = nn.ReplicationPad2d(self.pad_size)(img_tensor)
        batch, channel, height, width = img_tensor.size()
        self.height = height
        self.width = width

        if scale_method is not None:
            # Pre-upscale the whole image so each patch has a matching upscaled counterpart.
            img_up = pil_img.resize((self.scale_factor * pil_img.size[0], self.scale_factor * pil_img.size[1]),
                                    scale_method)
            img_up = to_tensor(img_up).unsqueeze(0)
            img_up = nn.ReplicationPad2d(self.pad_size * self.scale_factor)(img_up)

        patch_box = []
        # Grow the segment size if the residual piece would be smaller than the padding.
        if height % self.seg_size < self.pad_size or width % self.seg_size < self.pad_size:
            self.seg_size += self.scale_factor * self.pad_size

        # Split the image into overlapping pieces.
        for i in range(self.pad_size, height, self.seg_size):
            for j in range(self.pad_size, width, self.seg_size):
                part = img_tensor[:, :,
                       (i - self.pad_size):min(i + self.pad_size + self.seg_size, height),
                       (j - self.pad_size):min(j + self.pad_size + self.seg_size, width)]
                if img_pad > 0:
                    part = nn.ZeroPad2d(img_pad)(part)
                if scale_method is not None:
                    # Take the matching region from the pre-upscaled image
                    # (self.upsampler could compute this per patch instead).
                    part_up = img_up[:, :,
                              self.scale_factor * (i - self.pad_size):min(i + self.pad_size + self.seg_size,
                                                                          height) * self.scale_factor,
                              self.scale_factor * (j - self.pad_size):min(j + self.pad_size + self.seg_size,
                                                                          width) * self.scale_factor]

                    patch_box.append((part, part_up))
                else:
                    patch_box.append(part)
        return patch_box

    def merge_img_tensor(self, list_img_tensor):
        out = torch.zeros((1, 3, self.height * self.scale_factor, self.width * self.scale_factor))
        img_tensors = copy.copy(list_img_tensor)
        # Overlap to strip from each upscaled piece: pad_size in input space,
        # scaled by scale_factor in output space.
        rem = self.pad_size * self.scale_factor

        pad_size = self.scale_factor * self.pad_size
        seg_size = self.scale_factor * self.seg_size
        height = self.scale_factor * self.height
        width = self.scale_factor * self.width
        for i in range(pad_size, height, seg_size):
            for j in range(pad_size, width, seg_size):
                part = img_tensors.pop(0)
                # Strip the overlapping border, then paste the piece in place.
                part = part[:, :, rem:-rem, rem:-rem]
                if len(part.size()) > 3:
                    _, _, p_h, p_w = part.size()
                    out[:, :, i:i + p_h, j:j + p_w] = part
        # Strip the replication padding added in split_img_tensor.
        out = out[:, :, rem:-rem, rem:-rem]
        return out


def load_single_image(img_file,
                      up_scale=False,
                      up_scale_factor=2,
                      up_scale_method=Image.BILINEAR,
                      zero_padding=0):
    img = Image.open(img_file).convert("RGB")
    out = to_tensor(img).unsqueeze(0)
    if zero_padding:
        out = nn.ZeroPad2d(zero_padding)(out)
    if up_scale:
        size = tuple(map(lambda x: x * up_scale_factor, img.size))
        img_up = img.resize(size, up_scale_method)
        img_up = to_tensor(img_up).unsqueeze(0)
        out = (out, img_up)

    return out
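

# Illustrative example: load a low-resolution tensor together with its 2x
# bilinear pre-upscale, the pair format returned when up_scale=True.
# "./example.png" is a hypothetical path.
def _example_load_pair():
    low_res, bilinear_up = load_single_image("./example.png",
                                             up_scale=True,
                                             up_scale_factor=2)
    return low_res, bilinear_up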


def standardize_img_format(img_folder):
    def process(img_file):
        img_path = os.path.dirname(img_file)
        # Use splitext so file names containing extra dots are handled correctly.
        img_name, _ = os.path.splitext(os.path.basename(img_file))
        out = os.path.join(img_path, img_name + ".JPEG")
        os.rename(img_file, out)

    list_imgs = []
    for ext in ['png', 'jpeg', 'jpg']:
        list_imgs.extend(glob.glob(os.path.join(img_folder, "**", "*." + ext), recursive=True))
    print("Found {} images.".format(len(list_imgs)))
    pool = ThreadPool(4)
    pool.map(process, list_imgs)
    pool.close()
    pool.join()
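

if __name__ == "__main__":
    # Illustrative only: recursively rename png/jpeg/jpg files under a
    # hypothetical folder to the unified .JPEG extension. Note that this
    # renames files in place; it does not re-encode them.
    standardize_img_format("./demo_images/")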