import pytest
import torch
from ding.torch_utils.network.merge import TorchBilinearCustomized, TorchBilinear, BilinearGeneral, FiLM


@pytest.mark.unittest
def test_torch_bilinear_customized():
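    """
    Overview:
        Test the output shape of the `TorchBilinearCustomized` layer.
    """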
    batch_size = 10
    in1_features = 20
    in2_features = 30
    out_features = 40
    bilinear_customized = TorchBilinearCustomized(in1_features, in2_features, out_features)
    x = torch.randn(batch_size, in1_features)
    z = torch.randn(batch_size, in2_features)
    out = bilinear_customized(x, z)
    assert out.shape == (batch_size, out_features), "Output shape does not match expected shape."


@pytest.mark.unittest
def test_torch_bilinear():
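    """
    Overview:
        Test the output shape of the `TorchBilinear` layer.
    """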
    batch_size = 10
    in1_features = 20
    in2_features = 30
    out_features = 40
    torch_bilinear = TorchBilinear(in1_features, in2_features, out_features)
    x = torch.randn(batch_size, in1_features)
    z = torch.randn(batch_size, in2_features)
    out = torch_bilinear(x, z)
    assert out.shape == (batch_size, out_features), "Output shape does not match expected shape."


@pytest.mark.unittest
def test_bilinear_consistency():
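    """
    Overview:
        Compare `TorchBilinearCustomized` and `TorchBilinear` on identical
        weights and inputs, reporting the MSE between their outputs.
    """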
    batch_size = 10
    in1_features = 20
    in2_features = 30
    out_features = 40

    # Initialize weights and biases with fixed shared values
    weight = torch.randn(out_features, in1_features, in2_features)
    bias = torch.randn(out_features)

    # Create and initialize TorchBilinearCustomized and TorchBilinear models
    bilinear_customized = TorchBilinearCustomized(in1_features, in2_features, out_features)
    bilinear_customized.weight.data = weight.clone()
    bilinear_customized.bias.data = bias.clone()

    torch_bilinear = TorchBilinear(in1_features, in2_features, out_features)
    torch_bilinear.weight.data = weight.clone()
    torch_bilinear.bias.data = bias.clone()

    # Provide same input to both models
    x = torch.randn(batch_size, in1_features)
    z = torch.randn(batch_size, in2_features)

    # Compute outputs
    out_bilinear_customized = bilinear_customized(x, z)
    out_torch_bilinear = torch_bilinear(x, z)

    # Compute the mean squared error between outputs
    mse = torch.mean((out_bilinear_customized - out_torch_bilinear) ** 2)

    print(f"Mean Squared Error between outputs: {mse.item()}")

    # The strict equality check is intentionally left disabled; this test only
    # reports the MSE between the two implementations.
    # assert torch.allclose(out_bilinear_customized, out_torch_bilinear), \
    #     "Outputs of TorchBilinearCustomized and TorchBilinear are not the same."


@pytest.mark.unittest
def test_bilinear_general():
    """
    Overview:
        Test for the `BilinearGeneral` class.
    """
    # Define the input dimensions and batch size
    in1_features = 20
    in2_features = 30
    out_features = 40
    batch_size = 10

    # Create a BilinearGeneral instance
    bilinear_general = BilinearGeneral(in1_features, in2_features, out_features)

    # Create random inputs
    input1 = torch.randn(batch_size, in1_features)
    input2 = torch.randn(batch_size, in2_features)

    # Perform forward pass
    output = bilinear_general(input1, input2)

    # Check output shape
    assert output.shape == (batch_size, out_features), "Output shape does not match expected shape."

    # Check parameter shapes
    assert bilinear_general.W.shape == (
        out_features, in1_features, in2_features
    ), "Weight W shape does not match expected shape."
    assert bilinear_general.U.shape == (out_features, in2_features), "Weight U shape does not match expected shape."
    assert bilinear_general.V.shape == (out_features, in1_features), "Weight V shape does not match expected shape."
    assert bilinear_general.b.shape == (out_features, ), "Bias shape does not match expected shape."

    # Check parameter types
    assert isinstance(bilinear_general.W, torch.nn.Parameter), "Weight W is not an instance of torch.nn.Parameter."
    assert isinstance(bilinear_general.U, torch.nn.Parameter), "Weight U is not an instance of torch.nn.Parameter."
    assert isinstance(bilinear_general.V, torch.nn.Parameter), "Weight V is not an instance of torch.nn.Parameter."
    assert isinstance(bilinear_general.b, torch.nn.Parameter), "Bias is not an instance of torch.nn.Parameter."
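

# The parameter shapes asserted above suggest the general bilinear form
# y = x^T W z + U z + V x + b. A minimal sketch of that computation, inferred
# from the shapes alone (not taken from BilinearGeneral's implementation):
def _bilinear_general_reference(x, z, W, U, V, b):
    bilinear_term = torch.einsum('bi,oij,bj->bo', x, W, z)  # x^T W z per output unit
    linear_terms = z.matmul(U.t()) + x.matmul(V.t())  # U z + V x
    return bilinear_term + linear_terms + b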


@pytest.mark.unittest
def test_film_forward():
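    """
    Overview:
        Test the forward pass of the `FiLM` layer.
    """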
    # Set the feature and context dimensions
    feature_dim = 128
    context_dim = 256

    # Initialize the FiLM layer
    film_layer = FiLM(feature_dim, context_dim)

    # Create random feature and context vectors
    feature = torch.randn((32, feature_dim))  # batch size is 32
    context = torch.randn((32, context_dim))  # batch size is 32

    # Forward propagation
    conditioned_feature = film_layer(feature, context)

    # Check the output shape
    assert conditioned_feature.shape == feature.shape, \
        f'Expected output shape {feature.shape}, but got {conditioned_feature.shape}'

    # Check that the output is different from the input
    assert not torch.all(torch.eq(feature, conditioned_feature)), \
        'The output feature is the same as the input feature'
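

# FiLM (Perez et al., 2018) conditions a feature by a per-channel affine
# transform predicted from the context: out = gamma * feature + beta. A minimal
# sketch assuming this standard formulation; the `projection` layer (mapping
# context_dim -> 2 * feature_dim) is a hypothetical stand-in, not DI-engine's
# internal module.
def _film_reference(feature, context, projection):
    # Split the projected context into a scale (gamma) and a shift (beta).
    gamma, beta = projection(context).chunk(2, dim=-1)
    return gamma * feature + beta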