File size: 4,719 Bytes
c7e6202 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from audiocraft.modules.rope import RotaryEmbedding
from audiocraft.modules.transformer import StreamingTransformer
def test_rope():
    """Check that `rotate_qk` with a non-zero `start` returns tensors shaped like its inputs."""
    shape = (8, 75, 16, 128)
    rotary = RotaryEmbedding(dim=shape[-1])

    query = torch.rand(shape)
    key = torch.rand(shape)
    rotated_query, rotated_key = rotary.rotate_qk(query, key, start=7)

    # Both outputs must keep the exact input shape.
    for rotated in (rotated_query, rotated_key):
        assert list(rotated.shape) == list(shape)
def test_rope_io_dtypes():
    """Check that `rotate_qk` returns outputs in the dtype of its inputs.

    bfloat16 and float32 inputs are each rotated by rotary embeddings built
    with `dtype=torch.float32` and `dtype=torch.float64`. Both the rotated
    query and the rotated key are expected to come back in the input dtype.
    """
    B, T, H, C = 8, 75, 16, 128
    rope_32 = RotaryEmbedding(dim=C, dtype=torch.float32)
    rope_64 = RotaryEmbedding(dim=C, dtype=torch.float64)

    for input_dtype in (torch.bfloat16, torch.float32):
        xq = torch.rand((B, T, H, C)).to(input_dtype)
        xk = torch.rand((B, T, H, C)).to(input_dtype)
        for rope in (rope_32, rope_64):
            xq_out, xk_out = rope.rotate_qk(xq, xk)
            assert xq_out.dtype == input_dtype
            # The key is rotated by the same call, so it is held to the
            # same dtype contract as the query.
            assert xk_out.dtype == input_dtype
def test_transformer_with_rope():
    """Run a forward pass with each of 'rope' and 'sin_rope' and check the output keeps the input shape."""
    torch.manual_seed(1234)
    num_steps = 12
    for embedding_kind in ('rope', 'sin_rope'):
        model = StreamingTransformer(
            16, 4, 2,
            custom=True,
            dropout=0.,
            layer_scale=0.1,
            positional_embedding=embedding_kind,
        )
        model.eval()

        inputs = torch.randn(3, num_steps, 16)
        outputs = model(inputs)
        assert list(outputs.shape) == list(inputs.shape)
@torch.no_grad()
def test_rope_streaming():
    """Compare streaming output with a single full forward pass.

    A causal transformer using the 'rope' positional embedding is run once
    on the whole input, then again inside `streaming()` one step at a time.
    The concatenated streamed output must match the reference within a
    relative error of 1e-6.
    """
    torch.manual_seed(1234)
    num_steps = 12
    model = StreamingTransformer(
        16, 4, 2, causal=True, dropout=0.,
        custom=True, positional_embedding='rope')
    model.eval()

    inputs = torch.randn(3, num_steps, 16)
    reference = model(inputs)

    with model.streaming():
        # Feed the model a single step per call, in order.
        chunks = [model(inputs[:, step:step + 1]) for step in range(num_steps)]

    streamed = torch.cat(chunks, dim=1)
    assert list(streamed.shape) == [3, num_steps, 16]
    delta = torch.norm(streamed - reference) / torch.norm(streamed)
    assert delta < 1e-6, delta
@torch.no_grad()
def test_rope_streaming_past_context():
    """Compare streaming output with a full forward pass, with and without `past_context`.

    For `past_context` of `None` and of 10, a causal transformer using the
    'rope' positional embedding is run once on the whole input and then one
    step at a time inside `streaming()`. The concatenated streamed output
    must match the reference within a relative error of 1e-6.
    """
    torch.manual_seed(1234)
    num_steps = 20

    for past_context in (None, 10):
        model = StreamingTransformer(
            16, 4, 1 if past_context else 2,
            causal=True, past_context=past_context, custom=True,
            dropout=0., positional_embedding='rope')
        model.eval()

        inputs = torch.randn(3, num_steps, 16)
        reference = model(inputs)

        with model.streaming():
            # Feed the model a single step per call, in order.
            chunks = [model(inputs[:, step:step + 1]) for step in range(num_steps)]

        streamed = torch.cat(chunks, dim=1)
        assert list(streamed.shape) == [3, num_steps, 16]
        delta = torch.norm(streamed - reference) / torch.norm(streamed)
        assert delta < 1e-6, delta
def test_rope_memory_efficient():
    """Check that the `custom=True` and `memory_efficient=True` transformers agree.

    Both models use the 'rope' positional embedding, and the second one
    receives the first one's weights through `load_state_dict`. Their
    outputs on the same input must match within `atol=1e-7`.
    """
    torch.manual_seed(1234)
    tr = StreamingTransformer(
        16, 4, 2, custom=True, dropout=0., layer_scale=0.1,
        positional_embedding='rope')
    tr_mem_efficient = StreamingTransformer(
        16, 4, 2, dropout=0., memory_efficient=True, layer_scale=0.1,
        positional_embedding='rope')
    tr_mem_efficient.load_state_dict(tr.state_dict())

    # Put BOTH models in eval mode so the comparison is made in the same
    # module mode and does not rely on train-only behavior being a no-op.
    tr.eval()
    tr_mem_efficient.eval()

    steps = 12
    x = torch.randn(3, steps, 16)

    with torch.no_grad():
        y = tr(x)
        y2 = tr_mem_efficient(x)
        # Check at float precision b/c this is the rope default.
        assert torch.allclose(y, y2, atol=1e-7), (y - y2).norm()
def test_rope_with_xpos():
    """Check that `rotate_qk` with `xpos=True` returns tensors shaped like its inputs."""
    shape = (8, 75, 16, 128)
    rotary = RotaryEmbedding(dim=shape[-1], xpos=True)

    query = torch.rand(shape)
    key = torch.rand(shape)
    rotated_query, rotated_key = rotary.rotate_qk(query, key, start=7)

    # Both outputs must keep the exact input shape.
    for rotated in (rotated_query, rotated_key):
        assert list(rotated.shape) == list(shape)
def test_positional_scale():
    """Check that with `scale=0.0` the outputs of `rotate_qk` are `allclose` to its inputs."""
    shape = (8, 75, 16, 128)
    rotary = RotaryEmbedding(dim=shape[-1], xpos=True, scale=0.0)

    query = torch.rand(shape)
    key = torch.rand(shape)
    rotated_query, rotated_key = rotary.rotate_qk(query, key, start=7)

    for original, rotated in ((query, rotated_query), (key, rotated_key)):
        assert torch.allclose(original, rotated)
|