|
from typing import List, Tuple, Union

import torch
import torch.nn as nn

from ding.hpc_rl import hpc_wrapper
|
|
|
|
|
def shape_fn_scatter_connection(args, kwargs) -> List[Union[int, str]]:
    """
    Overview:
        Return the shape of scatter_connection for HPC.
    Arguments:
        - args (:obj:`Tuple`): The positional arguments passed to ``ScatterConnection.forward``. Because the \
            method is wrapped with ``is_cls_method=True``, ``args[0]`` is expected to be the module instance, \
            followed by ``x`` and ``spatial_size`` when they are passed positionally.
        - kwargs (:obj:`Dict`): The keyword arguments passed to ``ScatterConnection.forward``. ``x`` and \
            ``spatial_size`` are read from here when they are not passed positionally.
    Returns:
        - shape (:obj:`List[Union[int, str]]`): A list of the form ``[B, M, N, H, W, scatter_type]``. The \
            first five entries are ints and the last is the ``scatter_type`` string of the module.
    """
    owner = args[0]
    # A positional argument cannot follow a keyword one, so if ``x`` is a kwarg then so is ``spatial_size``.
    x = args[1] if len(args) > 1 else kwargs['x']
    spatial_size = args[2] if len(args) > 2 else kwargs['spatial_size']
    shape = list(x.shape)
    shape.extend(spatial_size)
    shape.append(owner.scatter_type)
    return shape
|
|
|
|
|
class ScatterConnection(nn.Module):
    """
    Overview:
        Scatter feature to its corresponding location. In AlphaStar, each entity is embedded into a tensor,
        and these tensors are scattered into a feature map with map size.
    Interfaces:
        ``__init__``, ``forward``, ``xy_forward``
    """

    def __init__(self, scatter_type: str) -> None:
        """
        Overview:
            Initialize the ScatterConnection object.
        Arguments:
            - scatter_type (:obj:`str`): The scatter type, which decides the behavior when several entities \
                share the same location. It must be either 'add' or 'cover'.
                - 'add': the features of all entities at that location are summed.
                - 'cover': only one entity's features are kept. PyTorch's ``scatter_`` does not specify \
                  which value wins when indices repeat, so the surviving entity may be non-deterministic.
        """
        super().__init__()
        self.scatter_type = scatter_type
        assert self.scatter_type in ['cover', 'add'], \
            f"scatter_type must be 'cover' or 'add', got {scatter_type!r}"

    def _scatter(self, x: torch.Tensor, index: torch.Tensor, spatial_size: Tuple[int, int]) -> torch.Tensor:
        """
        Overview:
            Shared implementation of ``forward`` and ``xy_forward``. It scatters entity features into a \
            spatial map, given each entity's flattened (row-major) location index.
        Arguments:
            - x (:obj:`torch.Tensor`): Entity features of shape `(B, M, N)`.
            - index (:obj:`torch.Tensor`): Flattened location indices of shape `(B, M)`. Each index must lie \
                in `[0, H * W)`, as required by ``scatter_`` / ``scatter_add_``.
            - spatial_size (:obj:`Tuple[int, int]`): The size `(H, W)` of the output feature map.
        Returns:
            - output (:obj:`torch.Tensor`): The scattered feature map of shape `(B, N, H, W)`, with the same \
                dtype and device as ``x``.
        """
        B, _, N = x.shape
        H, W = spatial_size
        # (B, M, N) -> (B, N, M): line the entity axis up with the flattened spatial axis (dim=2).
        src = x.permute(0, 2, 1)
        # Every feature channel of an entity goes to the same location. scatter needs an int64 index.
        index = index.long().unsqueeze(dim=1).repeat(1, N, 1)
        # The buffer must match x's dtype. scatter_/scatter_add_ require self and src to share a dtype, so a
        # hard-coded float32 buffer would reject float16/float64 inputs.
        output = torch.zeros(size=(B, N, H * W), dtype=x.dtype, device=x.device)
        if self.scatter_type == 'cover':
            output.scatter_(dim=2, index=index, src=src)
        elif self.scatter_type == 'add':
            output.scatter_add_(dim=2, index=index, src=src)
        return output.view(B, N, H, W)

    @hpc_wrapper(
        shape_fn=shape_fn_scatter_connection,
        namedtuple_data=False,
        include_args=[0, 2],
        include_kwargs=['x', 'location'],
        is_cls_method=True
    )
    def forward(self, x: torch.Tensor, spatial_size: Tuple[int, int], location: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Scatter input tensor 'x' into a spatial feature map.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor of shape `(B, M, N)`, where `B` is the batch size, `M` \
                is the number of entities, and `N` is the dimension of entity attributes.
            - spatial_size (:obj:`Tuple[int, int]`): The size `(H, W)` of the spatial feature map into which 'x' \
                will be scattered, where `H` is the height and `W` is the width.
            - location (:obj:`torch.Tensor`): The tensor of locations of shape `(B, M, 2)`. \
                Each location is (y, x): y indexes the H axis and x indexes the W axis. Values are cast to \
                int64 before scattering.
        Returns:
            - output (:obj:`torch.Tensor`): The scattered feature map of shape `(B, N, H, W)`, with the same \
                dtype as ``x``.
        Note:
            When several entities share a location, 'cover' mode keeps only one of them and loses information.
            'add' mode sums them and is used as a temporary substitute.
        """
        _, W = spatial_size
        # Row-major flattening: flat = y * W + x.
        index = location[:, :, 1] + location[:, :, 0] * W
        return self._scatter(x, index, spatial_size)

    def xy_forward(
        self, x: torch.Tensor, spatial_size: Tuple[int, int], coord_x: torch.Tensor, coord_y: torch.Tensor
    ) -> torch.Tensor:
        """
        Overview:
            Scatter input tensor 'x' into a spatial feature map using separate coordinate tensors.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor of shape `(B, M, N)`, where `B` is the batch size, `M` \
                is the number of entities, and `N` is the dimension of entity attributes.
            - spatial_size (:obj:`Tuple[int, int]`): The size `(H, W)` of the spatial feature map into which 'x' \
                will be scattered, where `H` is the height and `W` is the width.
            - coord_x (:obj:`torch.Tensor`): Coordinates of shape `(B, M)`. They are multiplied by `W`, so they \
                select the row (H axis).
            - coord_y (:obj:`torch.Tensor`): Coordinates of shape `(B, M)`. They select the column (W axis).
        Returns:
            - output (:obj:`torch.Tensor`): The scattered feature map of shape `(B, N, H, W)`, with the same \
                dtype as ``x``.
        Note:
            When several entities share a location, 'cover' mode keeps only one of them and loses information.
            'add' mode sums them and is used as a temporary substitute.
        """
        _, W = spatial_size
        # NOTE(review): coord_x is treated as the row index here. That is transposed relative to ``forward``'s
        # (y, x) convention, where x is the column. It is kept unchanged because callers may depend on it;
        # confirm it is intended for non-square maps.
        index = coord_x * W + coord_y
        return self._scatter(x, index, spatial_size)
|
|