import pickle
import random
from collections import namedtuple
from typing import Tuple

import cv2
import lmdb
import numpy as np
from path import Path

Sample = namedtuple('Sample', 'gt_text, file_path')
Batch = namedtuple('Batch', 'imgs, gt_texts, batch_size')
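# Illustration (hypothetical values): Sample('word', Path('data/img/a01/a01-000u/a01-000u-00-01.png'))
# pairs one ground-truth word with the path of its image; Batch bundles the loaded images,
# their GT texts and the actual number of elements in the batch.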


class DataLoaderIAM:
    """
    Loads data which corresponds to IAM format,
    see: http://www.fki.inf.unibe.ch/databases/iam-handwriting-database
    """

    def __init__(self,
                 data_dir: Path,
                 batch_size: int,
                 data_split: float = 0.95,
                 fast: bool = True) -> None:
        """Loader for dataset."""

        assert data_dir.exists()

        self.fast = fast
        if fast:
            # fast mode: images are read from a prebuilt LMDB database instead of from disk
            self.env = lmdb.open(str(data_dir / 'lmdb'), readonly=True)

        self.data_augmentation = False
        self.curr_idx = 0
        self.batch_size = batch_size
        self.samples = []

        f = open(data_dir / 'gt/words.txt')
        chars = set()
        bad_samples_reference = ['a01-117-05-02', 'r06-022-03-05']  # known broken images in IAM dataset
        for line in f:
            # ignore empty lines and comment lines
            if not line or line[0] == '#':
                continue

            line_split = line.strip().split(' ')
            assert len(line_split) >= 9

            # filename: part1-part2-part3 --> part1/part1-part2/part1-part2-part3.png
            file_name_split = line_split[0].split('-')
            file_name_subdir1 = file_name_split[0]
            file_name_subdir2 = f'{file_name_split[0]}-{file_name_split[1]}'
            file_base_name = line_split[0] + '.png'
            file_name = data_dir / 'img' / file_name_subdir1 / file_name_subdir2 / file_base_name

            if line_split[0] in bad_samples_reference:
                print('Ignoring known broken image:', file_name)
                continue

            # the GT text starts at column 9 (index 8); join in case it contains spaces
            gt_text = ' '.join(line_split[8:])
            chars = chars.union(set(list(gt_text)))

            # put sample into list
            self.samples.append(Sample(gt_text, file_name))

        # split into training and validation set according to data_split (default: 95% - 5%)
        split_idx = int(data_split * len(self.samples))
        self.train_samples = self.samples[:split_idx]
        self.validation_samples = self.samples[split_idx:]

        # put words into lists
        self.train_words = [x.gt_text for x in self.train_samples]
        self.validation_words = [x.gt_text for x in self.validation_samples]

        # start with train set
        self.train_set()

        # list of all chars in dataset
        self.char_list = sorted(list(chars))

    def train_set(self) -> None:
        """Switch to a randomly shuffled training set."""
        self.data_augmentation = True
        self.curr_idx = 0
        random.shuffle(self.train_samples)
        self.samples = self.train_samples
        self.curr_set = 'train'

    def validation_set(self) -> None:
        """Switch to validation set."""
        self.data_augmentation = False
        self.curr_idx = 0
        self.samples = self.validation_samples
        self.curr_set = 'val'

    def get_iterator_info(self) -> Tuple[int, int]:
        """Current batch index and overall number of batches."""
        if self.curr_set == 'train':
            num_batches = int(np.floor(len(self.samples) / self.batch_size))  # train set: only full-sized batches
        else:
            num_batches = int(np.ceil(len(self.samples) / self.batch_size))  # val set: allow last batch to be smaller
        curr_batch = self.curr_idx // self.batch_size + 1
        return curr_batch, num_batches

    def has_next(self) -> bool:
        """Is there a next element?"""
        if self.curr_set == 'train':
            return self.curr_idx + self.batch_size <= len(self.samples)  # train set: only full-sized batches
        else:
            return self.curr_idx < len(self.samples)  # val set: allow last batch to be smaller

    def _get_img(self, i: int) -> np.ndarray:
        if self.fast:
            # fast mode: load the pickled image from LMDB, keyed by the image file basename
            with self.env.begin() as txn:
                basename = Path(self.samples[i].file_path).basename()
                data = txn.get(basename.encode("ascii"))
                img = pickle.loads(data)
        else:
            # slow mode: read the grayscale image directly from disk
            img = cv2.imread(self.samples[i].file_path, cv2.IMREAD_GRAYSCALE)

        return img

    def get_next(self) -> Batch:
        """Get next element."""
        batch_range = range(self.curr_idx, min(self.curr_idx + self.batch_size, len(self.samples)))

        imgs = [self._get_img(i) for i in batch_range]
        gt_texts = [self.samples[i].gt_text for i in batch_range]

        self.curr_idx += self.batch_size
        return Batch(imgs, gt_texts, len(imgs))
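

# A minimal usage sketch, assuming an IAM-formatted directory at the hypothetical path '../data'
# (containing 'gt/words.txt', the 'img' folder and, when fast=True, a prebuilt 'lmdb' database).
if __name__ == '__main__':
    loader = DataLoaderIAM(Path('../data'), batch_size=100, fast=False)
    loader.train_set()
    while loader.has_next():
        curr_batch, num_batches = loader.get_iterator_info()
        batch = loader.get_next()
        print(f'Batch {curr_batch}/{num_batches}: {batch.batch_size} images, '
              f'first GT text: {batch.gt_texts[0]!r}')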