File size: 4,949 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import pytest
import torch
from ding.torch_utils.network.merge import TorchBilinearCustomized, TorchBilinear, BilinearGeneral, FiLM
@pytest.mark.unittest
def test_torch_bilinear_customized():
    """
    Overview:
        Check that ``TorchBilinearCustomized`` maps a pair of batched inputs of widths
        ``in1_features`` and ``in2_features`` to an output of shape ``(batch_size, out_features)``.
    """
    batch, dim_a, dim_b, dim_out = 10, 20, 30, 40
    layer = TorchBilinearCustomized(dim_a, dim_b, dim_out)
    lhs = torch.randn(batch, dim_a)
    rhs = torch.randn(batch, dim_b)
    result = layer(lhs, rhs)
    expected_shape = (batch, dim_out)
    assert result.shape == expected_shape, "Output shape does not match expected shape."
@pytest.mark.unittest
def test_torch_bilinear():
    """
    Overview:
        Check that ``TorchBilinear`` maps a pair of batched inputs of widths
        ``in1_features`` and ``in2_features`` to an output of shape ``(batch_size, out_features)``.
    """
    batch, dim_a, dim_b, dim_out = 10, 20, 30, 40
    layer = TorchBilinear(dim_a, dim_b, dim_out)
    lhs = torch.randn(batch, dim_a)
    rhs = torch.randn(batch, dim_b)
    result = layer(lhs, rhs)
    expected_shape = (batch, dim_out)
    assert result.shape == expected_shape, "Output shape does not match expected shape."
@pytest.mark.unittest
def test_bilinear_consistency():
    """
    Overview:
        Check that ``TorchBilinearCustomized`` and ``TorchBilinear`` produce matching outputs
        when both are loaded with the same weight and bias and fed the same inputs.
    """
    batch_size = 10
    in1_features = 20
    in2_features = 30
    out_features = 40
    # Shared random parameters, copied into both models so that any output difference can
    # only come from the forward computations themselves.
    weight = torch.randn(out_features, in1_features, in2_features)
    bias = torch.randn(out_features)
    bilinear_customized = TorchBilinearCustomized(in1_features, in2_features, out_features)
    torch_bilinear = TorchBilinear(in1_features, in2_features, out_features)
    # ``copy_`` (unlike rebinding ``.data``) raises if a model's parameter shape differs from
    # the shared tensor's, instead of silently swapping in a tensor of another shape.
    with torch.no_grad():
        for model in (bilinear_customized, torch_bilinear):
            model.weight.copy_(weight)
            model.bias.copy_(bias)
    # Provide the same input to both models.
    x = torch.randn(batch_size, in1_features)
    z = torch.randn(batch_size, in2_features)
    out_bilinear_customized = bilinear_customized(x, z)
    out_torch_bilinear = torch_bilinear(x, z)
    assert out_bilinear_customized.shape == out_torch_bilinear.shape == (batch_size, out_features), \
        "Outputs of TorchBilinearCustomized and TorchBilinear have different shapes."
    # Each output element is a float32 sum over in1_features * in2_features products, and the
    # two implementations may sum in a different order. ``allclose``'s default atol (1e-8)
    # is far below that rounding noise for outputs near zero, so compare with explicit,
    # looser tolerances instead of requiring (near-)exact equality.
    max_abs_diff = (out_bilinear_customized - out_torch_bilinear).abs().max().item()
    assert torch.allclose(out_bilinear_customized, out_torch_bilinear, rtol=1e-4, atol=1e-4), \
        f"Outputs of TorchBilinearCustomized and TorchBilinear are not the same (max abs diff: {max_abs_diff})."
@pytest.mark.unittest
def test_bilinear_general():
    """
    Overview:
        Test for the ``BilinearGeneral`` class: the forward output shape, plus the shape and
        ``torch.nn.Parameter`` type of each learnable tensor ``W``, ``U``, ``V`` and ``b``.
        Marked ``unittest`` like the other tests in this module, so that marker-filtered
        runs (``pytest -m unittest``) do not skip it.
    """
    # Define the input dimensions and batch size.
    in1_features = 20
    in2_features = 30
    out_features = 40
    batch_size = 10
    bilinear_general = BilinearGeneral(in1_features, in2_features, out_features)
    input1 = torch.randn(batch_size, in1_features)
    input2 = torch.randn(batch_size, in2_features)
    output = bilinear_general(input1, input2)
    # Check output shape.
    assert output.shape == (batch_size, out_features), "Output shape does not match expected shape."
    # Check parameter shapes.
    assert bilinear_general.W.shape == (
        out_features, in1_features, in2_features
    ), "Weight W shape does not match expected shape."
    assert bilinear_general.U.shape == (out_features, in2_features), "Weight U shape does not match expected shape."
    assert bilinear_general.V.shape == (out_features, in1_features), "Weight V shape does not match expected shape."
    assert bilinear_general.b.shape == (out_features, ), "Bias shape does not match expected shape."
    # Check parameter types: each must be registered as a learnable Parameter.
    assert isinstance(bilinear_general.W, torch.nn.Parameter), "Weight W is not an instance of torch.nn.Parameter."
    assert isinstance(bilinear_general.U, torch.nn.Parameter), "Weight U is not an instance of torch.nn.Parameter."
    assert isinstance(bilinear_general.V, torch.nn.Parameter), "Weight V is not an instance of torch.nn.Parameter."
    assert isinstance(bilinear_general.b, torch.nn.Parameter), "Bias is not an instance of torch.nn.Parameter."
@pytest.mark.unittest
def test_film_forward():
    """
    Overview:
        Check that ``FiLM`` conditions a batch of features on a batch of context vectors,
        returning a tensor with the same shape as the features whose values are not all
        identical to the input features.
    """
    n_samples = 32
    feature_dim, context_dim = 128, 256
    film_layer = FiLM(feature_dim, context_dim)
    feature = torch.randn((n_samples, feature_dim))
    context = torch.randn((n_samples, context_dim))
    conditioned_feature = film_layer(feature, context)
    assert conditioned_feature.shape == feature.shape, \
        f'Expected output shape {feature.shape}, but got {conditioned_feature.shape}'
    # The layer must actually modify its input, not pass it through unchanged.
    unchanged = (feature == conditioned_feature).all()
    assert not unchanged, 'The output feature is the same as the input feature'
|