import pickle
import random
from collections import namedtuple
from typing import Tuple

import cv2
import lmdb
import numpy as np
from path import Path

Sample = namedtuple('Sample', 'gt_text, file_path')
Batch = namedtuple('Batch', 'imgs, gt_texts, batch_size')


class DataLoaderIAM:
    """
    Loads data which corresponds to IAM format,
    see: http://www.fki.inf.unibe.ch/databases/iam-handwriting-database
    """

    def __init__(self,
                 data_dir: Path,
                 batch_size: int,
                 data_split: float = 0.95,
                 fast: bool = True) -> None:
        """Loader for dataset."""

        assert data_dir.exists()

        self.fast = fast
        if fast:
            # fast mode: images are served from a preprocessed LMDB store
            self.env = lmdb.open(str(data_dir / 'lmdb'), readonly=True)

        self.data_augmentation = False
        self.curr_idx = 0
        self.batch_size = batch_size
        self.samples = []

        f = open(data_dir / 'gt/words.txt')
        chars = set()
        bad_samples_reference = ['a01-117-05-02', 'r06-022-03-05']  # known broken images in IAM dataset
        for line in f:
            # ignore empty and comment lines
            if not line or line[0] == '#':
                continue

            line_split = line.strip().split(' ')
            assert len(line_split) >= 9

            # filename: part1-part2-part3 --> part1/part1-part2/part1-part2-part3.png
            file_name_split = line_split[0].split('-')
            file_name_subdir1 = file_name_split[0]
            file_name_subdir2 = f'{file_name_split[0]}-{file_name_split[1]}'
            file_base_name = line_split[0] + '.png'
            file_name = data_dir / 'img' / file_name_subdir1 / file_name_subdir2 / file_base_name

            if line_split[0] in bad_samples_reference:
                print('Ignoring known broken image:', file_name)
                continue

            # GT text are columns starting at 9
            gt_text = ' '.join(line_split[8:])
            chars = chars.union(set(list(gt_text)))

            # put sample into list
            self.samples.append(Sample(gt_text, file_name))

        # split into training and validation set: 95% - 5%
        split_idx = int(data_split * len(self.samples))
        self.train_samples = self.samples[:split_idx]
        self.validation_samples = self.samples[split_idx:]

        # put words into lists
        self.train_words = [x.gt_text for x in self.train_samples]
        self.validation_words = [x.gt_text for x in self.validation_samples]

        # start with train set
        self.train_set()

        # list of all chars in dataset
        self.char_list = sorted(list(chars))

    def train_set(self) -> None:
        """Switch to randomly chosen subset of training set."""
        self.data_augmentation = True
        self.curr_idx = 0
        random.shuffle(self.train_samples)
        self.samples = self.train_samples
        self.curr_set = 'train'

    def validation_set(self) -> None:
        """Switch to validation set."""
        self.data_augmentation = False
        self.curr_idx = 0
        self.samples = self.validation_samples
        self.curr_set = 'val'

    def get_iterator_info(self) -> Tuple[int, int]:
        """Current batch index and overall number of batches."""
        if self.curr_set == 'train':
            num_batches = int(np.floor(len(self.samples) / self.batch_size))  # train set: only full-sized batches
        else:
            num_batches = int(np.ceil(len(self.samples) / self.batch_size))  # val set: allow last batch to be smaller
        curr_batch = self.curr_idx // self.batch_size + 1
        return curr_batch, num_batches

    def has_next(self) -> bool:
        """Is there a next element?"""
        if self.curr_set == 'train':
            return self.curr_idx + self.batch_size <= len(self.samples)  # train set: only full-sized batches
        else:
            return self.curr_idx < len(self.samples)  # val set: allow last batch to be smaller

    def _get_img(self, i: int) -> np.ndarray:
        if self.fast:
            # fast mode: fetch the pickled image from LMDB, keyed by file basename
            with self.env.begin() as txn:
                basename = Path(self.samples[i].file_path).basename()
                data = txn.get(basename.encode("ascii"))
                img = pickle.loads(data)
        else:
            # slow mode: read the grayscale image directly from disk
            img = cv2.imread(self.samples[i].file_path, cv2.IMREAD_GRAYSCALE)
        return img

    def get_next(self) -> Batch:
Batch: """Get next element.""" batch_range = range(self.curr_idx, min(self.curr_idx + self.batch_size, len(self.samples))) imgs = [self._get_img(i) for i in batch_range] gt_texts = [self.samples[i].gt_text for i in batch_range] self.curr_idx += self.batch_size return Batch(imgs, gt_texts, len(imgs))