gomoku / DI-engine /ding /torch_utils /network /scatter_connection.py
zjowowen's picture
init space
079c32c
raw
history blame
5.43 kB
import torch
import torch.nn as nn
from typing import Tuple, List
from ding.hpc_rl import hpc_wrapper
def shape_fn_scatter_connection(args, kwargs) -> List[int]:
    """
    Overview:
        Build the shape descriptor that the HPC wrapper uses for ``ScatterConnection.forward``.
    Arguments:
        - args (:obj:`Tuple`): Positional arguments of the wrapped call. ``args[0]`` is the \
            ``ScatterConnection`` instance; ``args[1]`` (if present) is ``x`` and ``args[2]`` \
            (if present) is ``spatial_size``.
        - kwargs (:obj:`Dict`): Keyword arguments of the wrapped call. ``'x'`` and ``'spatial_size'`` \
            are read from here when they were not passed positionally.
    Returns:
        - shape (:obj:`List[int]`): The dimensions of ``x`` followed by the entries of ``spatial_size`` \
            and finally ``scatter_type``, i.e. ``[B, M, N, H, W, scatter_type]`` for a 3-dim ``x``.
    """
    # Position 0 is the bound module instance, so the input tensor sits at position 1.
    entity = args[1] if len(args) > 1 else kwargs['x']
    shape = list(entity.shape)

    spatial_size = args[2] if len(args) > 2 else kwargs['spatial_size']
    shape += spatial_size

    return shape + [args[0].scatter_type]
class ScatterConnection(nn.Module):
    """
    Overview:
        Scatter feature to its corresponding location. In AlphaStar, each entity is embedded into a tensor,
        and these tensors are scattered into a feature map with map size.
    Interfaces:
        ``__init__``, ``forward``, ``xy_forward``
    """

    def __init__(self, scatter_type: str) -> None:
        """
        Overview:
            Initialize the ScatterConnection object.
        Arguments:
            - scatter_type (:obj:`str`): The scatter type, which decides the behavior when two entities have \
                the same location. It must be either 'add' or 'cover'. With 'add' the overlapping features \
                are summed (``scatter_add_``); with 'cover' one of them overwrites the others (``scatter_``).
        """
        super(ScatterConnection, self).__init__()
        self.scatter_type = scatter_type
        assert self.scatter_type in ['cover', 'add']

    def _scatter(self, x: torch.Tensor, spatial_size: Tuple[int, int], index: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Shared implementation of ``forward`` and ``xy_forward``: write every entity feature of ``x`` \
            into a zero-initialised map at the given flattened position.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor of shape `(B, M, N)`.
            - spatial_size (:obj:`Tuple[int, int]`): The size `(H, W)` of the output feature map.
            - index (:obj:`torch.Tensor`): Flattened (row-major) target position of each entity, \
                shape `(B, M)`, with values expected in `[0, H * W)`.
        Returns:
            - output (:obj:`torch.Tensor`): The scattered feature map of shape `(B, N, H, W)`, created on \
                the device of ``x`` and with the dtype of ``x``.
        """
        B, _, N = x.shape
        H, W = spatial_size
        # scatter along the last dim needs the source laid out as (B, N, M)
        src = x.permute(0, 2, 1)
        # scatter_/scatter_add_ need an int64 index with the same shape as the source,
        # so cast once here and replicate the index over the N feature channels
        index = index.long().unsqueeze(dim=1).repeat(1, N, 1)
        # allocate with x's dtype: scatter_/scatter_add_ reject a source whose dtype differs from the target
        output = torch.zeros(size=(B, N, H * W), dtype=x.dtype, device=x.device)
        if self.scatter_type == 'cover':
            output.scatter_(dim=2, index=index, src=src)
        elif self.scatter_type == 'add':
            output.scatter_add_(dim=2, index=index, src=src)
        return output.view(B, N, H, W)

    @hpc_wrapper(
        shape_fn=shape_fn_scatter_connection,
        namedtuple_data=False,
        include_args=[0, 2],
        include_kwargs=['x', 'location'],
        is_cls_method=True
    )
    def forward(self, x: torch.Tensor, spatial_size: Tuple[int, int], location: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Scatter input tensor 'x' into a spatial feature map.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor of shape `(B, M, N)`, where `B` is the batch size, `M` \
                is the number of entities, and `N` is the dimension of entity attributes.
            - spatial_size (:obj:`Tuple[int, int]`): The size `(H, W)` of the spatial feature map into which 'x' \
                will be scattered, where `H` is the height and `W` is the width.
            - location (:obj:`torch.Tensor`): The tensor of locations of shape `(B, M, 2)`. \
                Each location should be (y, x); it is cast to int64 before being used as an index.
        Returns:
            - output (:obj:`torch.Tensor`): The scattered feature map of shape `(B, N, H, W)`, with the same \
                dtype and device as 'x'.
        Note:
            When there are some overlapping in locations, 'cover' mode will result in the loss of information.
            'add' mode is used as a temporary substitute.
        """
        W = spatial_size[1]
        # row-major flattening: position = y * W + x
        index = location[:, :, 1] + location[:, :, 0] * W
        return self._scatter(x, spatial_size, index)

    def xy_forward(
            self, x: torch.Tensor, spatial_size: Tuple[int, int], coord_x: torch.Tensor, coord_y: torch.Tensor
    ) -> torch.Tensor:
        """
        Overview:
            Scatter input tensor 'x' into a spatial feature map using two separate coordinate tensors.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor of shape `(B, M, N)`, where `B` is the batch size, `M` \
                is the number of entities, and `N` is the dimension of entity attributes.
            - spatial_size (:obj:`Tuple[int, int]`): The size `(H, W)` of the spatial feature map into which 'x' \
                will be scattered, where `H` is the height and `W` is the width.
            - coord_x (:obj:`torch.Tensor`): The first coordinate tensor of shape `(B, M)`. It is multiplied \
                by `W`, i.e. it selects the row (`H` axis) of the output.
            - coord_y (:obj:`torch.Tensor`): The second coordinate tensor of shape `(B, M)`. It is added \
                unscaled, i.e. it selects the column (`W` axis) of the output.
        Returns:
            - output (:obj:`torch.Tensor`): The scattered feature map of shape `(B, N, H, W)`, with the same \
                dtype and device as 'x'.
        Note:
            When there are some overlapping in locations, 'cover' mode will result in the loss of information.
            'add' mode is used as a temporary substitute.

            The flattened index is ``coord_x * W + coord_y``, so ``coord_x`` plays the role of \
            ``location[..., 0]`` in ``forward`` and ``coord_y`` the role of ``location[..., 1]``.
        """
        W = spatial_size[1]
        index = coord_x * W + coord_y
        return self._scatter(x, spatial_size, index)