Simon Duerr
add fast af
85bd48b
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Specialized mapping functions."""
import functools
from typing import Any, Callable, Optional, Sequence, Union
import haiku as hk
import jax
import jax.numpy as jnp
# Type aliases used purely for annotation readability; both resolve to `Any`.
PYTREE = Any
PYTREE_JAX_ARRAY = Any
# Module-level shorthand for functools.partial.
partial = functools.partial
# Unique sentinel object that replaces `None` inside expanded axis trees
# (see `_expand_axes`); `_maybe_slice` and `_maybe_get_size` test for it to
# recognise inputs that must not be sliced.
PROXY = object()
def _maybe_slice(array, i, slice_size, axis):
  """Takes a window of `array` along `axis`, unless `axis` is PROXY.

  Args:
    array: Array to slice.
    i: Start index of the window along `axis`.
    slice_size: Length of the window along `axis`.
    axis: Axis to slice along, or the PROXY sentinel to leave `array` as is.

  Returns:
    `array` itself when `axis` is PROXY, otherwise the result of
    `jax.lax.dynamic_slice_in_dim` for the requested window.
  """
  if axis is not PROXY:
    return jax.lax.dynamic_slice_in_dim(array, i, slice_size, axis=axis)
  # PROXY marks an input that is passed through whole.
  return array
def _maybe_get_size(array, axis):
  """Returns the extent of `array` along `axis`, or -1 when `axis` is PROXY.

  Args:
    array: Object exposing a `.shape` sequence.
    axis: Index into `array.shape`, or the PROXY sentinel.

  Returns:
    -1 if `axis` is PROXY (the input has no mapped axis), otherwise
    `array.shape[axis]`.
  """
  # PROXY is a sentinel object, so compare by identity (as `_maybe_slice`
  # does) instead of `==`, which could dispatch to an arbitrary `__eq__`.
  if axis is PROXY:
    return -1
  return array.shape[axis]
def _expand_axes(axes, values, name='sharded_apply'):
  """Broadcasts an axis specification to the full tree structure of `values`.

  Args:
    axes: Axis specification (a single value or a tree prefix of `values`);
      `None` entries mark inputs without a mapped axis.
    values: Pytree whose structure the result should have.
    name: Name forwarded to `jax.api_util.flatten_axes` (presumably used in
      its error messages -- verify against the JAX version in use).

  Returns:
    A pytree with the same structure as `values` holding one axis entry per
    leaf, where every `None` has been replaced by the PROXY sentinel.
  """
  # Use the stable `jax.tree_util` namespace; the top-level `jax.tree_flatten`
  # / `jax.tree_unflatten` aliases are deprecated.
  values_tree_def = jax.tree_util.tree_structure(values)
  flat_axes = jax.api_util.flatten_axes(name, values_tree_def, axes)
  # Replace None's with PROXY so that every leaf carries an explicit marker
  # once the list is unflattened back into a tree.
  flat_axes = [PROXY if x is None else x for x in flat_axes]
  return jax.tree_util.tree_unflatten(values_tree_def, flat_axes)
def sharded_map(
    fun: Callable[..., PYTREE_JAX_ARRAY],
    shard_size: Union[int, None] = 1,
    in_axes: Union[int, PYTREE] = 0,
    out_axes: Union[int, PYTREE] = 0) -> Callable[..., PYTREE_JAX_ARRAY]:
  """Sharded vmap.

  Vectorises `fun` with `hk.vmap` and then evaluates the vectorised function
  shard by shard through `sharded_apply`, so that only `shard_size` entries
  of the mapped axis are processed at a time. This trades memory usage (as in
  a plain map) against throughput (as in a vmap).

  Args:
    fun: Function to apply smap transform to.
    shard_size: Integer denoting shard size.
    in_axes: Either integer or pytree describing which axis to map over for
      each input to `fun`, None denotes broadcasting.
    out_axes: Integer or pytree denoting to what axis in the output the mapped
      over axis maps.

  Returns:
    Function with smap applied.
  """
  return sharded_apply(
      hk.vmap(fun, in_axes, out_axes),
      shard_size=shard_size,
      in_axes=in_axes,
      out_axes=out_axes)
def sharded_apply(
    fun: Callable[..., PYTREE_JAX_ARRAY],  # pylint: disable=g-bare-generic
    shard_size: Union[int, None] = 1,
    in_axes: Union[int, PYTREE] = 0,
    out_axes: Union[int, PYTREE] = 0,
    new_out_axes: bool = False) -> Callable[..., PYTREE_JAX_ARRAY]:
  """Sharded apply.

  Applies `fun` over shards to axes, in a way similar to vmap,
  but does so in shards of `shard_size`. Shards are stacked after.
  This allows a smooth trade-off between
  memory usage (as in a plain map) vs higher throughput (as in a vmap).

  The returned function allocates zero-filled output buffers, runs all
  full-size shards inside an `hk.scan`, and then handles a trailing shorter
  shard (if any) with one extra call, writing each shard's result into the
  buffers with `jax.lax.dynamic_update_slice_in_dim`.

  Args:
    fun: Function to apply smap transform to.
    shard_size: Integer denoting shard size. `None` disables sharding and
      returns `fun` unchanged.
    in_axes: Either integer or pytree describing which axis to map over for
      each input to `fun`, None denotes broadcasting.
    out_axes: Integer or pytree denoting to what axis in the output the mapped
      over axis maps.
    new_out_axes: Whether to stack outputs on new axes. This assumes that the
      output sizes for each shard (including the possible remainder shard) are
      the same. Not implemented.

  Returns:
    Function with smap applied.

  Raises:
    NotImplementedError: If `new_out_axes` is True.
  """
  docstr = ('Mapped version of {fun}. Takes similar arguments to {fun} '
            'but with additional array axes over which {fun} is mapped.')
  if new_out_axes:
    raise NotImplementedError('New output axes not yet implemented.')

  # shard size None denotes no sharding
  if shard_size is None:
    return fun

  @jax.util.wraps(fun, docstr=docstr)
  def mapped_fn(*args):
    # Expand in axes and Determine Loop range
    in_axes_ = _expand_axes(in_axes, args)

    in_sizes = jax.tree_util.tree_map(_maybe_get_size, args, in_axes_)
    flat_sizes = jax.tree_util.tree_leaves(in_sizes)
    in_size = max(flat_sizes)
    # Every mapped input must agree on the mapped size; -1 marks inputs
    # without a mapped axis.
    assert all(i in {in_size, -1} for i in flat_sizes), (
        f'Mapped axes have inconsistent sizes: {flat_sizes}')

    # Number of shards in addition to the last (possibly shorter) one.
    num_extra_shards = (in_size - 1) // shard_size

    # Fix Up if necessary: an exact multiple means the last shard is full.
    last_shard_size = in_size % shard_size
    last_shard_size = shard_size if last_shard_size == 0 else last_shard_size

    def apply_fun_to_slice(slice_start, slice_size):
      input_slice = jax.tree_util.tree_map(
          lambda array, axis: _maybe_slice(
              array, slice_start, slice_size, axis),
          args, in_axes_)
      return fun(*input_slice)

    # Shapes/dtypes of the last shard's outputs, obtained without running
    # `fun` on real data.
    remainder_shape_dtype = hk.eval_shape(
        partial(apply_fun_to_slice, 0, last_shard_size))
    out_dtypes = jax.tree_util.tree_map(
        lambda x: x.dtype, remainder_shape_dtype)
    out_shapes = jax.tree_util.tree_map(
        lambda x: x.shape, remainder_shape_dtype)
    out_axes_ = _expand_axes(out_axes, remainder_shape_dtype)

    if num_extra_shards > 0:
      regular_shard_shape_dtype = hk.eval_shape(
          partial(apply_fun_to_slice, 0, shard_size))
      shard_shapes = jax.tree_util.tree_map(
          lambda x: x.shape, regular_shard_shape_dtype)

      def make_output_shape(axis, shard_shape, remainder_shape):
        # Along `axis` the full output is all regular shards plus the last
        # shard; every other dimension is taken from the regular shard.
        return shard_shape[:axis] + (
            shard_shape[axis] * num_extra_shards +
            remainder_shape[axis],) + shard_shape[axis + 1:]

      # `out_axes_` goes first so that its leaves define the tree structure
      # and the shape tuples are passed to `make_output_shape` whole.
      out_shapes = jax.tree_util.tree_map(
          make_output_shape, out_axes_, shard_shapes, out_shapes)

    # Calls dynamic Update slice with different argument order
    # This is here since tree_map only passes tree leaves positionally
    def dynamic_update_slice_in_dim(full_array, update, axis, i):
      return jax.lax.dynamic_update_slice_in_dim(full_array, update, i, axis)

    def compute_shard(outputs, slice_start, slice_size):
      slice_out = apply_fun_to_slice(slice_start, slice_size)
      update_slice = partial(
          dynamic_update_slice_in_dim, i=slice_start)
      return jax.tree_util.tree_map(
          update_slice, outputs, slice_out, out_axes_)

    def scan_iteration(outputs, i):
      new_outputs = compute_shard(outputs, i, shard_size)
      return new_outputs, ()

    # Start offsets of all full-size shards.
    slice_starts = jnp.arange(0, in_size - shard_size + 1, shard_size)

    def allocate_buffer(dtype, shape):
      return jnp.zeros(shape, dtype=dtype)

    # `out_dtypes` goes first so that the shape tuples in `out_shapes` are
    # treated as leaves rather than being flattened.
    outputs = jax.tree_util.tree_map(allocate_buffer, out_dtypes, out_shapes)

    if slice_starts.shape[0] > 0:
      outputs, _ = hk.scan(scan_iteration, outputs, slice_starts)

    # A shorter trailing shard is not covered by the scan above.
    if last_shard_size != shard_size:
      remainder_start = in_size - last_shard_size
      outputs = compute_shard(outputs, remainder_start, last_shard_size)

    return outputs

  return mapped_fn
def inference_subbatch(
    module: Callable[..., PYTREE_JAX_ARRAY],
    subbatch_size: int,
    batched_args: Sequence[PYTREE_JAX_ARRAY],
    nonbatched_args: Sequence[PYTREE_JAX_ARRAY],
    low_memory: bool = True,
    input_subbatch_dim: int = 0,
    output_subbatch_dim: Optional[int] = None) -> PYTREE_JAX_ARRAY:
  """Run through subbatches (like batch apply but with split and concat).

  Args:
    module: Callable invoked as `module(*batched_args, *nonbatched_args)`.
    subbatch_size: Shard size handed to `sharded_apply`.
    batched_args: Arguments that are split along `input_subbatch_dim`; must
      not be empty.
    nonbatched_args: Arguments passed unchanged to every call of `module`.
    low_memory: If False, `module` is called once on the full inputs and no
      subbatching takes place.
    input_subbatch_dim: Axis of the batched inputs to split along.
    output_subbatch_dim: Axis of the outputs along which shard results are
      placed; defaults to `input_subbatch_dim` when None.

  Returns:
    The output of `module`, computed either directly or shard by shard.
  """
  assert len(batched_args) > 0  # pylint: disable=g-explicit-length-test

  if not low_memory:
    return module(*batched_args, *nonbatched_args)

  out_dim = (
      input_subbatch_dim if output_subbatch_dim is None
      else output_subbatch_dim)

  def run_module(*shard_args):
    # Only the batched arguments are sharded; the rest are appended as is.
    return module(*shard_args, *nonbatched_args)

  return sharded_apply(
      run_module,
      shard_size=subbatch_size,
      in_axes=input_subbatch_dim,
      out_axes=out_dim)(*batched_args)