Spaces:
Sleeping
Sleeping
""" | |
Utilities for splitting batches of examples into smaller sub-batches.

This is useful during training when the batch size is too large to fit on GPU,
meaning that gradient accumulation across multiple sub-batches must be used.
It is also useful for batching examples during evaluation. Unlike a naive
approach, this code groups examples with similar lengths to reduce the amount
of wasted computation due to padding.
""" | |
import numpy as np | |
def split(*data, costs, max_cost):
    """Splits a batch of input items into sub-batches.

    Items are visited in order of increasing cost and grouped greedily, so
    that items with similar costs end up in the same sub-batch. The cost of a
    sub-batch is taken to be its size multiplied by the cost of its most
    expensive item (i.e. the cost after padding every item to the largest
    one); an item is added to the current sub-batch only while that product
    stays within max_cost.

    Args:
        *data: One or more lists of input items, all of the same length
        costs: A list of costs for each item (converted to integers)
        max_cost: Maximum total cost for each sub-batch

    Yields:
        (example_ids, *subbatch_data) tuples, where example_ids is a list of
        indices into the original inputs and each element of subbatch_data is
        the list of corresponding items taken from one of the input lists. An
        item whose own cost exceeds max_cost is still yielded, in a sub-batch
        by itself, so that no input is ever dropped.
    """
    costs = np.asarray(costs, dtype=int)
    # A stable sort keeps equal-cost items in their original relative order,
    # which makes the grouping deterministic.
    order = np.argsort(costs, kind="stable").tolist()
    num_items = len(order)

    start = 0
    while start < num_items:
        # Every sub-batch holds at least one item, even an over-budget one.
        end = start + 1
        # order[end] is the most expensive item of the candidate sub-batch
        # order[start:end + 1], which would contain (end - start + 1) items.
        while (
            end < num_items
            and (end - start + 1) * costs[order[end]] <= max_cost
        ):
            end += 1
        subbatch_item_ids = order[start:end]
        subbatch_data = [[items[i] for i in subbatch_item_ids] for items in data]
        yield (subbatch_item_ids,) + tuple(subbatch_data)
        start = end
def map(func, *data, costs, max_cost, **common_kwargs):
    """Applies a function to the input items one sub-batch at a time.

    Args:
        func: Function to call on each sub-batch
        *data: One or more lists of input items, all of the same length.
        costs: A list of costs for each item
        max_cost: Maximum total cost for each sub-batch
        **common_kwargs: Keyword arguments forwarded to every call of func

    Returns:
        A list with one entry per input item, holding the outputs of
        func(*subbatch_data, **common_kwargs) placed back at the positions of
        the items they were computed from (i.e. in the original item order).
    """
    # One output slot per item, sized from the first input list.
    outputs = [None] * len(data[0])
    for subbatch in split(*data, costs=costs, max_cost=max_cost):
        # The first element holds the original indices; the rest are inputs.
        positions = subbatch[0]
        inputs = subbatch[1:]
        results = func(*inputs, **common_kwargs)
        # Scatter each per-item output back to where its item came from.
        for position, value in zip(positions, results):
            outputs[position] = value
    return outputs