# HandWriting-Text-Recognizer / app /dataloader_iam.py
# Mattral's picture
# Upload 13 files
# 57462b3 verified
# raw
# history blame
# No virus
# 4.78 kB
import pickle
import random
from collections import namedtuple
from typing import Tuple
import cv2
from imdb import Cinemagoer
import numpy as np
from path import Path
# One ground-truth word: its transcription and the path of its image file.
Sample = namedtuple('Sample', 'gt_text, file_path')
# One batch as returned by DataLoaderIAM.get_next().
Batch = namedtuple('Batch', 'imgs, gt_texts, batch_size')


class DataLoaderIAM:
    """
    Loads data which corresponds to IAM format,
    see: http://www.fki.inf.unibe.ch/databases/iam-handwriting-database

    The loader parses ``gt/words.txt`` below the data directory, splits the
    resulting samples into a training and a validation part and then serves
    them batch by batch through ``has_next()`` / ``get_next()``.
    """

    # known broken images in IAM dataset
    _BAD_SAMPLES = frozenset(('a01-117-05-02', 'r06-022-03-05'))

    def __init__(self,
                 data_dir: Path,
                 batch_size: int,
                 data_split: float = 0.95,
                 fast: bool = True) -> None:
        """Loader for dataset.

        Args:
            data_dir: Directory containing ``gt/words.txt`` and the ``img``
                tree (and the ``lmdb`` directory when ``fast`` is set).
            batch_size: Number of samples per batch.
            data_split: Fraction of the samples (in file order) that goes
                into the training set; the remainder is the validation set.
            fast: If true, images are read from the database opened at
                ``data_dir / 'lmdb'`` instead of from the image files.

        Raises:
            AssertionError: If ``data_dir`` does not exist or a ground-truth
                line has fewer than 9 space-separated fields.
        """
        assert data_dir.exists(), f'data directory does not exist: {data_dir}'

        self.fast = fast
        if fast:
            # NOTE(review): `Cinemagoer` is imported from the `imdb` package;
            # the call shape (path, readonly=True) and the later
            # `env.begin()` / `txn.get()` usage look like the `lmdb` API
            # (`lmdb.open`). Presumably the import is wrong -- confirm and fix
            # at file level, otherwise the default fast=True path cannot work.
            self.env = Cinemagoer.open(str(data_dir / 'lmdb'), readonly=True)

        self.data_augmentation = False
        self.curr_idx = 0
        self.batch_size = batch_size
        self.samples = []

        chars = set()
        # context manager guarantees the ground-truth file is closed again
        with open(data_dir / 'gt/words.txt') as gt_file:
            for raw_line in gt_file:
                line = raw_line.strip()
                # ignore blank lines and comment lines
                if not line or line.startswith('#'):
                    continue

                line_split = line.split(' ')
                assert len(line_split) >= 9, f'malformed ground-truth line: {line!r}'

                # filename: part1-part2-part3 --> part1/part1-part2/part1-part2-part3.png
                sample_id = line_split[0]
                file_name_split = sample_id.split('-')
                file_name_subdir1 = file_name_split[0]
                file_name_subdir2 = f'{file_name_split[0]}-{file_name_split[1]}'
                file_base_name = sample_id + '.png'
                file_name = data_dir / 'img' / file_name_subdir1 / file_name_subdir2 / file_base_name

                if sample_id in self._BAD_SAMPLES:
                    print('Ignoring known broken image:', file_name)
                    continue

                # GT text is everything from the 9th field (index 8) onwards;
                # re-join because the text itself may contain spaces
                gt_text = ' '.join(line_split[8:])
                chars.update(gt_text)

                # put sample into list
                self.samples.append(Sample(gt_text, file_name))

        # split into training and validation set according to data_split
        split_idx = int(data_split * len(self.samples))
        self.train_samples = self.samples[:split_idx]
        self.validation_samples = self.samples[split_idx:]

        # put words into lists (taken before train_set() shuffles the samples)
        self.train_words = [x.gt_text for x in self.train_samples]
        self.validation_words = [x.gt_text for x in self.validation_samples]

        # start with train set
        self.train_set()

        # list of all chars in dataset
        self.char_list = sorted(chars)

    def train_set(self) -> None:
        """Switch to the training set, shuffled in place, and rewind to its start."""
        self.data_augmentation = True
        self.curr_idx = 0
        random.shuffle(self.train_samples)
        self.samples = self.train_samples
        self.curr_set = 'train'

    def validation_set(self) -> None:
        """Switch to validation set and rewind to its start."""
        self.data_augmentation = False
        self.curr_idx = 0
        self.samples = self.validation_samples
        self.curr_set = 'val'

    def get_iterator_info(self) -> Tuple[int, int]:
        """Current batch index (1-based) and overall number of batches."""
        num_samples = len(self.samples)
        if self.curr_set == 'train':
            # train set: only full-sized batches
            num_batches = num_samples // self.batch_size
        else:
            # val set: allow last batch to be smaller (ceiling division)
            num_batches = -(-num_samples // self.batch_size)
        curr_batch = self.curr_idx // self.batch_size + 1
        return curr_batch, num_batches

    def has_next(self) -> bool:
        """Is there a next element?"""
        if self.curr_set == 'train':
            # train set: only full-sized batches
            return self.curr_idx + self.batch_size <= len(self.samples)
        # val set: allow last batch to be smaller
        return self.curr_idx < len(self.samples)

    def _get_img(self, i: int) -> np.ndarray:
        """Load the image of sample ``i`` of the currently selected set.

        In fast mode the image is looked up in ``self.env`` under the base
        name of the sample's file path and unpickled; otherwise the image
        file is read as grayscale with OpenCV.
        """
        if self.fast:
            with self.env.begin() as txn:
                basename = Path(self.samples[i].file_path).basename()
                data = txn.get(basename.encode("ascii"))
                # NOTE(review): unpickling executes arbitrary code if the
                # database is not trusted; also, if the store returns None for
                # a missing key, pickle.loads raises a TypeError here.
                img = pickle.loads(data)
        else:
            img = cv2.imread(self.samples[i].file_path, cv2.IMREAD_GRAYSCALE)

        return img

    def get_next(self) -> Batch:
        """Get next element.

        Returns a Batch holding the images and ground-truth texts of the next
        (at most) ``batch_size`` samples and advances the read position by
        ``batch_size``.
        """
        batch_end = min(self.curr_idx + self.batch_size, len(self.samples))
        batch_range = range(self.curr_idx, batch_end)

        imgs = [self._get_img(i) for i in batch_range]
        gt_texts = [self.samples[i].gt_text for i in batch_range]

        self.curr_idx += self.batch_size
        return Batch(imgs, gt_texts, len(imgs))