import logging
import random
from collections import Counter
from operator import itemgetter
from typing import Iterator, List, Optional, Union

import numpy as np
from torch.utils.data import Dataset, DistributedSampler, Sampler

LOGGER = logging.getLogger(__name__)
class DatasetFromSampler(Dataset):
    """Dataset to create indexes from `Sampler`.

    Args:
        sampler: PyTorch sampler
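
    Example (an illustrative sketch; any PyTorch sampler can be wrapped here):

    >>> from torch.utils.data import SequentialSampler
    >>> indices = DatasetFromSampler(SequentialSampler(range(5)))
    >>> len(indices), indices[0]
    (5, 0)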
""" | |
def __init__(self, sampler: Sampler): | |
"""Initialisation for DatasetFromSampler.""" | |
self.sampler = sampler | |
self.sampler_list = None | |
def __getitem__(self, index: int): | |
"""Gets element of the dataset. | |
Args: | |
index: index of the element in the dataset | |
Returns: | |
Single element by index | |
""" | |
if self.sampler_list is None: | |
self.sampler_list = list(self.sampler) | |
return self.sampler_list[index] | |
def __len__(self) -> int: | |
""" | |
Returns: | |
int: length of the dataset | |
""" | |
return len(self.sampler) | |
class BalanceClassSampler(Sampler):
    """Allows you to create a stratified sample from unbalanced classes.

    Args:
        labels: list of class labels for each element in the dataset
        mode: strategy to balance classes.
            Must be one of ``"downsampling"``, ``"upsampling"``,
            or an integer number of samples to draw per class
    Python API examples:

    .. code-block:: python

        import os
        from torch import nn, optim
        from torch.utils.data import DataLoader
        from catalyst import dl
        from catalyst.data import ToTensor, BalanceClassSampler
        from catalyst.contrib.datasets import MNIST

        train_data = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor())
        train_labels = train_data.targets.cpu().numpy().tolist()
        train_sampler = BalanceClassSampler(train_labels, mode=5000)
        valid_data = MNIST(os.getcwd(), train=False)

        loaders = {
            "train": DataLoader(train_data, sampler=train_sampler, batch_size=32),
            "valid": DataLoader(valid_data, batch_size=32),
        }

        model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=0.02)

        runner = dl.SupervisedRunner()
        # model training
        runner.train(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            num_epochs=1,
            logdir="./logs",
            valid_loader="valid",
            valid_metric="loss",
            minimize_valid_metric=True,
            verbose=True,
        )
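
    A smaller self-contained sketch (illustrative labels, no training loop):

    >>> labels = [0, 0, 0, 0, 1, 1]
    >>> sampler = BalanceClassSampler(labels, mode="upsampling")
    >>> len(sampler)  # 2 classes * 4 samples per class after upsampling
    8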
""" | |
def __init__(self, labels: List[int], mode: Union[str, int] = "downsampling"): | |
"""Sampler initialisation.""" | |
super().__init__(labels) | |
labels = np.array(labels) | |
samples_per_class = {label: (labels == label).sum() for label in set(labels)} | |
self.lbl2idx = { | |
label: np.arange(len(labels))[labels == label].tolist() | |
for label in set(labels) | |
} | |
if isinstance(mode, str): | |
assert mode in ["downsampling", "upsampling"] | |
if isinstance(mode, int) or mode == "upsampling": | |
samples_per_class = ( | |
mode if isinstance(mode, int) else max(samples_per_class.values()) | |
) | |
else: | |
samples_per_class = min(samples_per_class.values()) | |
self.labels = labels | |
self.samples_per_class = samples_per_class | |
self.length = self.samples_per_class * len(set(labels)) | |
    def __iter__(self) -> Iterator[int]:
        """
        Returns:
            iterator of indices of stratified sample
        """
        indices = []
        for key in sorted(self.lbl2idx):
            replace_flag = self.samples_per_class > len(self.lbl2idx[key])
            indices += np.random.choice(
                self.lbl2idx[key], self.samples_per_class, replace=replace_flag
            ).tolist()
        assert len(indices) == self.length
        np.random.shuffle(indices)

        return iter(indices)

    def __len__(self) -> int:
        """
        Returns:
            length of result sample
        """
        return self.length
class BatchBalanceClassSampler(Sampler):
    """
    This kind of sampler can be used for both metric learning and classification tasks.

    BatchSampler with the given strategy for a dataset with C unique classes:
    - Selects `num_classes` of the C classes for each batch
    - Selects `num_samples` instances for each class in the batch
    The epoch ends after `num_batches`.
    So, the batch size is `num_classes` * `num_samples`.

    One of the purposes of this sampler is to be used for
    forming triplets and pos/neg pairs inside the batch.
    To guarantee the existence of these pairs in the batch,
    `num_classes` and `num_samples` should be > 1. (1)

    This type of sampling can be found in the classical Person Re-Id paper,
    where P (`num_classes`) equals 32 and K (`num_samples`) equals 4:
    `In Defense of the Triplet Loss for Person Re-Identification`_.

    Args:
        labels: list of class labels for each element in the dataset
        num_classes: number of classes in a batch, should be > 1
        num_samples: number of instances of each class in a batch, should be > 1
        num_batches: number of batches in an epoch
            (default = len(labels) // (num_classes * num_samples))

    .. _In Defense of the Triplet Loss for Person Re-Identification:
        https://arxiv.org/abs/1703.07737
    Python API examples:

    .. code-block:: python

        import os
        from torch import nn, optim
        from torch.utils.data import DataLoader
        from catalyst import dl
        from catalyst.data import ToTensor, BatchBalanceClassSampler
        from catalyst.contrib.datasets import MNIST

        train_data = MNIST(os.getcwd(), train=True, download=True)
        train_labels = train_data.targets.cpu().numpy().tolist()
        train_sampler = BatchBalanceClassSampler(
            train_labels, num_classes=10, num_samples=4)
        valid_data = MNIST(os.getcwd(), train=False)

        loaders = {
            "train": DataLoader(train_data, batch_sampler=train_sampler),
            "valid": DataLoader(valid_data, batch_size=32),
        }

        model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=0.02)

        runner = dl.SupervisedRunner()
        # model training
        runner.train(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            num_epochs=1,
            logdir="./logs",
            valid_loader="valid",
            valid_metric="loss",
            minimize_valid_metric=True,
            verbose=True,
        )
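
    A smaller self-contained sketch (illustrative labels, no training loop):

    >>> labels = [0, 0, 0, 1, 1, 1, 2, 2, 2]
    >>> sampler = BatchBalanceClassSampler(labels, num_classes=2, num_samples=2)
    >>> batches = list(sampler)
    >>> len(batches), len(batches[0])  # num_batches lists of num_classes * num_samples indices
    (2, 4)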
""" | |
def __init__( | |
self, | |
labels: Union[List[int], np.ndarray], | |
num_classes: int, | |
num_samples: int, | |
num_batches: int = None, | |
): | |
"""Sampler initialisation.""" | |
super().__init__(labels) | |
classes = set(labels) | |
assert isinstance(num_classes, int) and isinstance(num_samples, int) | |
assert (1 < num_classes <= len(classes)) and (1 < num_samples) | |
assert all( | |
n > 1 for n in Counter(labels).values() | |
), "Each class shoud contain at least 2 instances to fit (1)" | |
labels = np.array(labels) | |
self._labels = list(set(labels.tolist())) | |
self._num_classes = num_classes | |
self._num_samples = num_samples | |
self._batch_size = self._num_classes * self._num_samples | |
self._num_batches = num_batches or len(labels) // self._batch_size | |
self.lbl2idx = { | |
label: np.arange(len(labels))[labels == label].tolist() | |
for label in set(labels) | |
} | |
    @property
    def batch_size(self) -> int:
        """
        Returns:
            this value should be used in DataLoader as batch size
        """
        return self._batch_size

    @property
    def batches_in_epoch(self) -> int:
        """
        Returns:
            number of batches in an epoch
        """
        return self._num_batches

    def __len__(self) -> int:
        """
        Returns:
            number of batches in an epoch (the sampler is used as a ``batch_sampler``)
        """
        return self._num_batches  # * self._batch_size
    def __iter__(self) -> Iterator[int]:
        """
        Returns:
            indices for sampling dataset elements during an epoch
        """
        indices = []
        for _ in range(self._num_batches):
            batch_indices = []
            classes_for_batch = random.sample(self._labels, self._num_classes)
            while self._num_classes != len(set(classes_for_batch)):
                classes_for_batch = random.sample(self._labels, self._num_classes)
            for cls_id in classes_for_batch:
                replace_flag = self._num_samples > len(self.lbl2idx[cls_id])
                batch_indices += np.random.choice(
                    self.lbl2idx[cls_id], self._num_samples, replace=replace_flag
                ).tolist()
            indices.append(batch_indices)
        return iter(indices)
class DynamicBalanceClassSampler(Sampler):
    r"""
    This kind of sampler can be used for classification tasks with significant
    class imbalance.

    The idea of this sampler is to start from the original class distribution
    and gradually move to a uniform class distribution, as with downsampling.

    Let's define :math:`D_i = \#C_i / \#C_{min}`, where :math:`\#C_i` is the size
    of class :math:`i` and :math:`\#C_{min}` is the size of the rarest class, so
    :math:`D_i` defines the class distribution. Also define :math:`g(n\_epoch)`,
    an exponential scheduler. On each epoch the current :math:`D_i` is calculated
    as :math:`D_i^{current} = D_i^{g(n\_epoch)}`,
    after which the data is sampled according to this distribution.

    Notes:
        At the end of training, epochs will contain only
        min_class_size * n_classes examples, so it may not be
        necessary to run validation on every epoch. For this reason use
        ControlFlowCallback.
    Examples:

        >>> import torch
        >>> import numpy as np
        >>> from catalyst.data import DynamicBalanceClassSampler
        >>> from torch.utils import data

        >>> features = torch.Tensor(np.random.random((200, 100)))
        >>> labels = np.random.randint(0, 4, size=(200,))
        >>> sampler = DynamicBalanceClassSampler(labels)
        >>> labels = torch.LongTensor(labels)
        >>> dataset = data.TensorDataset(features, labels)
        >>> loader = data.dataloader.DataLoader(dataset, batch_size=8)

        >>> for batch in loader:
        >>>     b_features, b_labels = batch

    Sampler was inspired by https://arxiv.org/abs/1901.06783
    """
    def __init__(
        self,
        labels: List[Union[int, str]],
        exp_lambda: float = 0.9,
        start_epoch: int = 0,
        max_d: Optional[int] = None,
        mode: Union[str, int] = "downsampling",
        ignore_warning: bool = False,
    ):
        """
        Args:
            labels: list of labels for each element in the dataset
            exp_lambda: base of the exponential schedule
            start_epoch: starting epoch number, can be useful for multi-stage
                experiments
            max_d: if not None, a limit on the ratio between the most
                frequent and the rarest classes, heuristic
            mode: number of samples per class at the end of training. Must be
                "downsampling" or a number. Before changing it, make sure you
                understand how it works
            ignore_warning: ignore the warning about the minimum class size
        """
        assert isinstance(start_epoch, int)
        assert 0 < exp_lambda < 1, "exp_lambda must be in (0, 1)"
        super().__init__(labels)
        self.exp_lambda = exp_lambda
        if max_d is None:
            max_d = np.inf
        self.max_d = max_d
        self.epoch = start_epoch
        labels = np.array(labels)
        samples_per_class = Counter(labels)
        self.min_class_size = min(samples_per_class.values())

        if self.min_class_size < 100 and not ignore_warning:
            LOGGER.warning(
                f"the smallest class contains only"
                f" {self.min_class_size} examples. At the end of"
                f" training, epochs will contain only"
                f" {self.min_class_size * len(samples_per_class)}"
                f" examples"
            )

        self.original_d = {
            key: value / self.min_class_size for key, value in samples_per_class.items()
        }
        self.label2idxes = {
            label: np.arange(len(labels))[labels == label].tolist()
            for label in set(labels)
        }

        if isinstance(mode, int):
            self.min_class_size = mode
        else:
            assert mode == "downsampling"

        self.labels = labels
        self._update()
    def _update(self) -> None:
        """Update d coefficients."""
        current_d = {
            key: min(value ** self._exp_scheduler(), self.max_d)
            for key, value in self.original_d.items()
        }
        samples_per_classes = {
            key: int(value * self.min_class_size) for key, value in current_d.items()
        }
        self.samples_per_classes = samples_per_classes
        self.length = np.sum(list(samples_per_classes.values()))
        self.epoch += 1

    def _exp_scheduler(self) -> float:
        return self.exp_lambda**self.epoch
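
    # Illustrative decay of the schedule (assumed numbers): with exp_lambda=0.9 and a
    # class that is 8x larger than the rarest one (D_i = 8), its share shrinks as
    #   epoch 0:  8 ** 0.9**0  = 8.0x
    #   epoch 10: 8 ** 0.9**10 ~= 2.1x
    #   epoch 30: 8 ** 0.9**30 ~= 1.1x
    # so late in training every class contributes roughly min_class_size samples.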
    def __iter__(self) -> Iterator[int]:
        """
        Returns:
            iterator of indices of stratified sample
        """
        indices = []
        for key in sorted(self.label2idxes):
            samples_per_class = self.samples_per_classes[key]
            replace_flag = samples_per_class > len(self.label2idxes[key])
            indices += np.random.choice(
                self.label2idxes[key], samples_per_class, replace=replace_flag
            ).tolist()
        assert len(indices) == self.length
        np.random.shuffle(indices)
        self._update()
        return iter(indices)

    def __len__(self) -> int:
        """
        Returns:
            length of result sample
        """
        return self.length
class MiniEpochSampler(Sampler):
    """
    Sampler that iterates over the dataset in mini epochs of ``mini_epoch_len`` samples.

    Args:
        data_len: size of the dataset
        mini_epoch_len: number of samples from the dataset used in one
            mini epoch
        drop_last: if ``True``, the sampler will drop the last mini epoch
            if it would contain fewer than ``mini_epoch_len`` samples
        shuffle: one of ``"per_mini_epoch"``, ``"per_epoch"``, or ``None``.
            The sampler will shuffle indices
            > "per_mini_epoch" -- every mini epoch (every ``__iter__`` call)
            > "per_epoch" -- every real epoch
            > None -- don't shuffle
    Example:

        >>> MiniEpochSampler(len(dataset), mini_epoch_len=100)
        >>> MiniEpochSampler(len(dataset), mini_epoch_len=100, drop_last=True)
        >>> MiniEpochSampler(len(dataset), mini_epoch_len=100,
        >>>     shuffle="per_epoch")
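
    A concrete walk-through (illustrative numbers): 10 samples split into mini
    epochs of 4 gives two full mini epochs plus a 2-sample remainder when
    ``drop_last=False``.

        >>> sampler = MiniEpochSampler(10, mini_epoch_len=4)
        >>> [len(list(sampler)) for _ in range(3)]
        [4, 4, 2]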
""" | |
def __init__( | |
self, | |
data_len: int, | |
mini_epoch_len: int, | |
drop_last: bool = False, | |
shuffle: str = None, | |
): | |
"""Sampler initialisation.""" | |
super().__init__(None) | |
self.data_len = int(data_len) | |
self.mini_epoch_len = int(mini_epoch_len) | |
self.steps = int(data_len / self.mini_epoch_len) | |
self.state_i = 0 | |
has_reminder = data_len - self.steps * mini_epoch_len > 0 | |
if self.steps == 0: | |
self.divider = 1 | |
elif has_reminder and not drop_last: | |
self.divider = self.steps + 1 | |
else: | |
self.divider = self.steps | |
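        # `divider` is the number of mini epochs per real epoch: the full chunks of
        # `mini_epoch_len`, plus one partial chunk when a remainder exists and
        # `drop_last` is False.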
        self._indices = np.arange(self.data_len)
        self.indices = self._indices
        self.end_pointer = max(self.data_len, self.mini_epoch_len)

        if not (shuffle is None or shuffle in ["per_mini_epoch", "per_epoch"]):
            raise ValueError(
                "Shuffle must be one of ['per_mini_epoch', 'per_epoch']. "
                + f"Got {shuffle}"
            )
        self.shuffle_type = shuffle
    def shuffle(self) -> None:
        """Shuffle sampler indices."""
        if self.shuffle_type == "per_mini_epoch" or (
            self.shuffle_type == "per_epoch" and self.state_i == 0
        ):
            if self.data_len >= self.mini_epoch_len:
                self.indices = self._indices
                np.random.shuffle(self.indices)
            else:
                self.indices = np.random.choice(
                    self._indices, self.mini_epoch_len, replace=True
                )

    def __iter__(self) -> Iterator[int]:
        """Iterate over sampler.

        Returns:
            python iterator
        """
        self.state_i = self.state_i % self.divider
        self.shuffle()

        start = self.state_i * self.mini_epoch_len
        stop = (
            self.end_pointer
            if (self.state_i == self.steps)
            else (self.state_i + 1) * self.mini_epoch_len
        )
        indices = self.indices[start:stop].tolist()

        self.state_i += 1
        return iter(indices)

    def __len__(self) -> int:
        """
        Returns:
            int: length of the mini-epoch
        """
        return self.mini_epoch_len
class DistributedSamplerWrapper(DistributedSampler):
    """
    Wrapper over `Sampler` for distributed training.
    Allows you to use any sampler in distributed mode.

    It is especially useful in conjunction with
    `torch.nn.parallel.DistributedDataParallel`. In such a case, each
    process can pass a DistributedSamplerWrapper instance as a DataLoader
    sampler, and load a subset of subsampled data of the original dataset
    that is exclusive to it.

    .. note::
        Sampler is assumed to be of constant size.
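
    A minimal usage sketch (assumes a two-process DDP setup and reuses the
    ``train_data`` / ``train_labels`` from the MNIST examples above, purely for
    illustration):

    .. code-block:: python

        from torch.utils.data import DataLoader
        from catalyst.data import BalanceClassSampler, DistributedSamplerWrapper

        train_sampler = BalanceClassSampler(train_labels, mode="upsampling")
        train_sampler = DistributedSamplerWrapper(train_sampler, num_replicas=2, rank=0)
        train_loader = DataLoader(train_data, sampler=train_sampler, batch_size=32)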
""" | |
def __init__( | |
self, | |
sampler, | |
num_replicas: Optional[int] = None, | |
rank: Optional[int] = None, | |
shuffle: bool = True, | |
): | |
""" | |
Args: | |
sampler: Sampler used for subsampling | |
num_replicas (int, optional): Number of processes participating in | |
distributed training | |
rank (int, optional): Rank of the current process | |
within ``num_replicas`` | |
shuffle (bool, optional): If true (default), | |
sampler will shuffle the indices | |
""" | |
super(DistributedSamplerWrapper, self).__init__( | |
DatasetFromSampler(sampler), | |
num_replicas=num_replicas, | |
rank=rank, | |
shuffle=shuffle, | |
) | |
self.sampler = sampler | |
def __iter__(self) -> Iterator[int]: | |
"""Iterate over sampler. | |
Returns: | |
python iterator | |
""" | |
        self.dataset = DatasetFromSampler(self.sampler)
        indexes_of_indexes = super().__iter__()
        subsampler_indexes = self.dataset
        return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes))
__all__ = [
    "BalanceClassSampler",
    "BatchBalanceClassSampler",
    "DistributedSamplerWrapper",
    "DynamicBalanceClassSampler",
    "MiniEpochSampler",
]