Spaces:
Paused
Paused
File size: 720 Bytes
db6a3b7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
import torch
import torch.nn as nn
class LayerNorm32(nn.LayerNorm):
    """
    A LayerNorm layer that converts to float32 before the forward pass.

    The normalized result is cast back to the dtype of the incoming tensor,
    so callers get out the same dtype they passed in.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize a float32 copy of ``x`` and return it in ``x``'s dtype."""
        input_dtype = x.dtype
        # Run the parent normalization on a float32 version of the input.
        normalized = super().forward(x.float())
        # Restore the caller's original dtype.
        return normalized.to(dtype=input_dtype)
class GroupNorm32(nn.GroupNorm):
    """
    GroupNorm variant whose forward pass runs on a float32 copy of the input.

    The output is cast back to the dtype the input arrived in.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Group-normalize a float32 copy of ``x`` and return it in ``x``'s dtype."""
        input_dtype = x.dtype
        # Upcast first, let nn.GroupNorm do the actual normalization.
        as_float32 = x.float()
        normalized = super().forward(as_float32)
        # Hand the result back in the caller's dtype.
        return normalized.to(dtype=input_dtype)
class ChannelLayerNorm32(LayerNorm32):
    """
    LayerNorm32 applied with dimension 1 temporarily moved to the last position.

    The input is permuted so that dim 1 becomes the trailing dim, passed
    through ``LayerNorm32.forward``, and then permuted back to the original
    dimension order. Presumably dim 1 is the channel axis of a channels-first
    tensor (per the class name) -- confirm against callers.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize over what was dim 1 of ``x``; output keeps ``x``'s dim order."""
        ndim = x.dim()
        # (0, 2, 3, ..., ndim-1, 1): push dim 1 to the end.
        dim1_to_last = (0, *range(2, ndim), 1)
        # (0, ndim-1, 1, 2, ..., ndim-2): pull the last dim back to position 1.
        last_to_dim1 = (0, ndim - 1, *range(1, ndim - 1))

        moved = x.permute(*dim1_to_last).contiguous()
        normalized = super().forward(moved)
        return normalized.permute(*last_to_dim1).contiguous()
|