import torch
import torch.nn as nn
from typing import Tuple, List
from ding.hpc_rl import hpc_wrapper


def shape_fn_scatter_connection(args, kwargs) -> List:
    """
    Overview:
        Return the shape of scatter_connection for HPC.
    Arguments:
        - args (:obj:`Tuple`): The positional arguments passed to ``ScatterConnection.forward``, where \
            ``args[0]`` is the ``ScatterConnection`` instance itself.
        - kwargs (:obj:`Dict`): The keyword arguments passed to ``ScatterConnection.forward``.
    Returns:
        - shape (:obj:`List`): A list describing the scatter operation, \
            in the form of ``[B, M, N, H, W, scatter_type]``.
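    Examples:
        >>> # Illustrative sketch: forward called positionally as (self, x, spatial_size, ...).
        >>> m = ScatterConnection('add')
        >>> shape_fn_scatter_connection((m, torch.randn(4, 256, 32), (128, 128)), {})
        [4, 256, 32, 128, 128, 'add']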
    """
    if len(args) <= 1:
        tmp = list(kwargs['x'].shape)  # x was passed as a keyword argument
    else:
        tmp = list(args[1].shape)  # args[0] is the ScatterConnection instance
    if len(args) <= 2:
        tmp.extend(kwargs['spatial_size'])
    else:
        tmp.extend(args[2])
    tmp.append(args[0].scatter_type)  # record the scatter mode ('add' or 'cover')
    return tmp


class ScatterConnection(nn.Module):
    """
    Overview:
        Scatter features to their corresponding locations. In AlphaStar, each entity is embedded into a vector,
        and these vectors are scattered onto a feature map of the given spatial size.
    Interfaces:
        ``__init__``, ``forward``, ``xy_forward``
    """

    def __init__(self, scatter_type: str) -> None:
        """
        Overview:
            Initialize the ScatterConnection object.
        Arguments:
            - scatter_type (:obj:`str`): The scatter type, which decides the behavior when two entities share the \
                same location. It can be either 'add' or 'cover'. If 'add', the features of overlapping entities \
                are summed together. If 'cover', the later entity overwrites the earlier one.
        """
        super(ScatterConnection, self).__init__()
        self.scatter_type = scatter_type
        assert self.scatter_type in ['cover', 'add']

    @hpc_wrapper(
        shape_fn=shape_fn_scatter_connection,
        namedtuple_data=False,
        include_args=[0, 2],
        include_kwargs=['x', 'location'],
        is_cls_method=True
    )
    def forward(self, x: torch.Tensor, spatial_size: Tuple[int, int], location: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Scatter input tensor 'x' into a spatial feature map.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor of shape `(B, M, N)`, where `B` is the batch size, `M` \
                is the number of entities, and `N` is the dimension of entity attributes.
            - spatial_size (:obj:`Tuple[int, int]`): The size `(H, W)` of the spatial feature map into which 'x' \
                will be scattered, where `H` is the height and `W` is the width.
            - location (:obj:`torch.Tensor`): The location tensor of shape `(B, M, 2)`. Each location is a \
                (y, x) pair, with y in `[0, H)` and x in `[0, W)`.
        Returns:
            - output (:obj:`torch.Tensor`): The scattered feature map of shape `(B, N, H, W)`.
        Note:
            When multiple entities are scattered to the same location, 'cover' mode keeps only one of them and \
            thus loses information; 'add' mode is used as a temporary workaround.
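        Examples:
            >>> # Illustrative sketch: scatter two 4-dim entity features onto a 3x3 map.
            >>> scatter_conn = ScatterConnection(scatter_type='add')
            >>> x = torch.randn(1, 2, 4)
            >>> location = torch.LongTensor([[[1, 2], [0, 0]]])  # (y, x) pairs
            >>> output = scatter_conn(x, (3, 3), location)
            >>> output.shape
            torch.Size([1, 4, 3, 3])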
        """
        device = x.device
        B, M, N = x.shape
        x = x.permute(0, 2, 1)  # (B, N, M): scatter needs one index per feature channel
        H, W = spatial_size
        index = location[:, :, 1] + location[:, :, 0] * W  # flatten (y, x) into a row-major index over H * W
        index = index.unsqueeze(dim=1).repeat(1, N, 1)  # (B, N, M)
        output = torch.zeros(size=(B, N, H, W), device=device).view(B, N, H * W)
        if self.scatter_type == 'cover':
            output.scatter_(dim=2, index=index, src=x)
        elif self.scatter_type == 'add':
            output.scatter_add_(dim=2, index=index, src=x)
        output = output.view(B, N, H, W)
        return output

    def xy_forward(
            self, x: torch.Tensor, spatial_size: Tuple[int, int], coord_x: torch.Tensor, coord_y: torch.Tensor
    ) -> torch.Tensor:
        """
        Overview:
            Scatter input tensor 'x' into a spatial feature map using separate x and y coordinates.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor of shape `(B, M, N)`, where `B` is the batch size, `M` \
                is the number of entities, and `N` is the dimension of entity attributes.
            - spatial_size (:obj:`Tuple[int, int]`): The size `(H, W)` of the spatial feature map into which 'x' \
                will be scattered, where `H` is the height and `W` is the width.
            - coord_x (:obj:`torch.Tensor`): The x-coordinates tensor of shape `(B, M)`.
            - coord_y (:obj:`torch.Tensor`): The y-coordinates tensor of shape `(B, M)`.
        Returns:
            - output (:obj:`torch.Tensor`): The scattered feature map of shape `(B, N, H, W)`.
        Note:
            When multiple entities are scattered to the same location, 'cover' mode keeps only one of them and \
            thus loses information; 'add' mode is used as a temporary workaround.
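        Examples:
            >>> # Illustrative sketch: the same scatter as ``forward``, with coordinates passed separately.
            >>> scatter_conn = ScatterConnection(scatter_type='cover')
            >>> x = torch.randn(1, 2, 4)
            >>> coord_x = torch.LongTensor([[2, 0]])
            >>> coord_y = torch.LongTensor([[1, 0]])
            >>> output = scatter_conn.xy_forward(x, (3, 3), coord_x, coord_y)
            >>> output.shape
            torch.Size([1, 4, 3, 3])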
        """
        device = x.device
        B, M, N = x.shape
        x = x.permute(0, 2, 1)  # (B, N, M): scatter needs one index per feature channel
        H, W = spatial_size
        index = (coord_y * W + coord_x).long()  # row-major flattening of (y, x), consistent with ``forward``
        index = index.unsqueeze(dim=1).repeat(1, N, 1)  # (B, N, M)
        output = torch.zeros(size=(B, N, H, W), device=device).view(B, N, H * W)
        if self.scatter_type == 'cover':
            output.scatter_(dim=2, index=index, src=x)
        elif self.scatter_type == 'add':
            output.scatter_add_(dim=2, index=index, src=x)
        output = output.view(B, N, H, W)
        return output
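

if __name__ == '__main__':
    # Minimal usage sketch (an assumed entry point, not part of any public API): show how
    # 'add' and 'cover' differ when two entities are scattered to the same location,
    # i.e. the overlap behavior discussed in the Note sections above.
    x = torch.ones(1, 2, 1)  # two entities, each with a single feature equal to 1
    location = torch.LongTensor([[[0, 0], [0, 0]]])  # both entities land on (y=0, x=0)
    add_out = ScatterConnection('add')(x, (2, 2), location)
    cover_out = ScatterConnection('cover')(x, (2, 2), location)
    print(add_out[0, 0])  # position (0, 0) holds 2.0: overlapping features are summed
    print(cover_out[0, 0])  # position (0, 0) holds 1.0: only one entity's feature survives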