# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import shutil
import warnings

import mmcv

from mmocr import digit_version
from mmocr.utils import list_from_file


def _split_nonempty_lines(content):
    """Decode raw annotation content and return its non-blank lines.

    Args:
        content (bytes): Raw file content as fetched by a ``FileClient``.

    Returns:
        list[str]: The content decoded as UTF-8 and split on ``'\\n'``, with
        lines that are empty or whitespace-only removed.
    """
    lines = str(content, encoding='utf-8').split('\n')
    return [x for x in lines if x.strip() != '']


class LmdbAnnFileBackend:
    """Lmdb storage backend for annotation file.

    The database is read through a ``'total_number'`` entry (the number of
    lines) and per-line entries keyed by ``str(index)``.

    Args:
        lmdb_path (str): Lmdb file path.
        encoding (str): Encoding used to decode the stored values.
            Defaults to 'utf8'.
    """

    def __init__(self, lmdb_path, encoding='utf8'):
        self.lmdb_path = lmdb_path
        self.encoding = encoding
        # Open a short-lived environment only to read the line count and
        # close it immediately; __getitem__ opens its own one lazily.
        env = self._get_env()
        try:
            with env.begin(write=False) as txn:
                self.total_number = int(
                    txn.get('total_number'.encode('utf-8')).decode(
                        self.encoding))
        finally:
            env.close()

    def __getitem__(self, index):
        """Retrieve one line from lmdb file by index.

        Args:
            index (int): Line index; looked up under the key ``str(index)``.

        Returns:
            str: The stored line decoded with ``self.encoding``.

        Raises:
            IndexError: If no entry is stored under ``str(index)``. Because
                the class defines no ``__iter__``, this is also what stops a
                plain ``for`` loop over the backend.
        """
        # only attach env to self when __getitem__ is called
        # because env object cannot be pickled
        if not hasattr(self, 'env'):
            self.env = self._get_env()
        with self.env.begin(write=False) as txn:
            value = txn.get(str(index).encode('utf-8'))
            if value is None:
                raise IndexError(f'No entry for index {index} in lmdb file '
                                 f'{self.lmdb_path}.')
            return value.decode(self.encoding)

    def __len__(self):
        return self.total_number

    def _get_env(self):
        """Open ``self.lmdb_path`` as a read-only, lock-free lmdb env."""
        try:
            import lmdb
        except ImportError:
            raise ImportError(
                'Please install lmdb to enable LmdbAnnFileBackend.')
        return lmdb.open(
            self.lmdb_path,
            max_readers=1,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )

    def close(self):
        """Close the environment opened by ``__getitem__``, if any.

        Safe to call when no line has been read yet. A later ``__getitem__``
        call reopens the environment.
        """
        if hasattr(self, 'env'):
            self.env.close()
            del self.env


class HardDiskAnnFileBackend:
    """Load annotation file with raw hard disks storage backend.

    Args:
        file_format (str): 'txt' or 'lmdb'. Defaults to 'txt'.
    """

    def __init__(self, file_format='txt'):
        assert file_format in ['txt', 'lmdb']
        self.file_format = file_format

    def __call__(self, ann_file):
        """Return an ``LmdbAnnFileBackend`` for lmdb files, otherwise the
        result of ``list_from_file(ann_file)``."""
        if self.file_format == 'lmdb':
            return LmdbAnnFileBackend(ann_file)
        return list_from_file(ann_file)


class PetrelAnnFileBackend:
    """Load annotation file with petrel storage backend.

    Args:
        file_format (str): 'txt' or 'lmdb'. Defaults to 'txt'.
        save_dir (str): Local directory under which remote lmdb directories
            are mirrored before being opened. Defaults to 'tmp_dir'.
    """

    def __init__(self, file_format='txt', save_dir='tmp_dir'):
        assert file_format in ['txt', 'lmdb']
        self.file_format = file_format
        self.save_dir = save_dir

    def __call__(self, ann_file):
        """Load ``ann_file`` from petrel.

        Returns:
            LmdbAnnFileBackend | list[str]: For lmdb, a backend over a local
            copy of the remote directory; otherwise the non-blank lines of
            the text file.
        """
        file_client = mmcv.FileClient(backend='petrel')
        if self.file_format == 'lmdb':
            return LmdbAnnFileBackend(self._fetch_lmdb(file_client, ann_file))
        return _split_nonempty_lines(file_client.get(ann_file))

    def _fetch_lmdb(self, file_client, ann_file):
        """Copy the remote lmdb directory ``ann_file`` under ``save_dir``.

        An already existing local copy is reused (with a warning) instead of
        being fetched again.

        Returns:
            str: Path of the local directory holding the lmdb files.
        """
        mmcv_version = digit_version(mmcv.__version__)
        if mmcv_version < digit_version('1.3.16'):
            raise Exception('Please update mmcv to 1.3.16 or higher '
                            'to enable "get_local_path" of "FileClient".')

        assert file_client.isdir(ann_file)
        files = file_client.list_dir_or_file(ann_file)

        # Mirror the remote path (minus any 's3://' scheme) under save_dir.
        ann_file_rel_path = ann_file.split('s3://')[-1]
        ann_file_dir = osp.dirname(ann_file_rel_path)
        ann_file_name = osp.basename(ann_file_rel_path)
        local_dir = osp.join(self.save_dir, ann_file_dir, ann_file_name)
        if osp.exists(local_dir):
            warnings.warn(
                f'local_ann_file: {local_dir} is already existed and '
                'will be used. If it is not the correct ann_file '
                f'corresponding to {ann_file}, please remove it or '
                'change "save_dir" first then try again.')
        else:
            os.makedirs(local_dir, exist_ok=True)
            print(f'Fetching {ann_file} to {local_dir}...')
            for each_file in files:
                tmp_file_path = file_client.join_path(ann_file, each_file)
                with file_client.get_local_path(
                        tmp_file_path) as local_path:
                    shutil.copy(local_path, osp.join(local_dir, each_file))
        return local_dir


class HTTPAnnFileBackend:
    """Load annotation file with http storage backend.

    Args:
        file_format (str): 'txt' or 'lmdb'; 'lmdb' is rejected when called.
            Defaults to 'txt'.
    """

    def __init__(self, file_format='txt'):
        assert file_format in ['txt', 'lmdb']
        self.file_format = file_format

    def __call__(self, ann_file):
        """Return the non-blank lines of the text file at ``ann_file``.

        Raises:
            NotImplementedError: If ``file_format`` is 'lmdb'.
        """
        file_client = mmcv.FileClient(backend='http')
        if self.file_format == 'lmdb':
            raise NotImplementedError(
                'Loading lmdb file on http is not supported yet.')
        return _split_nonempty_lines(file_client.get(ann_file))