# AbLang_heavy / extra_fns.py
import math

import torch


def gelu_new(x):
    """
    Implementation of the GELU activation function currently used in the Google BERT repo
    (identical to OpenAI GPT). See also the Gaussian Error Linear Units paper:
    https://arxiv.org/abs/1606.08415
    """
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

def gelu_fast(x):
    # Faster GELU approximation: 0.7978845608 ≈ sqrt(2 / pi), precomputed so no
    # math.sqrt call is needed; algebraically equivalent to gelu_new.
    return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))

def mish(x):
    # Mish activation: x * tanh(softplus(x)).
    return x * torch.tanh(torch.nn.functional.softplus(x))

# Mapping from activation-function name strings to the corresponding callables.
ACT2FN = {
    "relu": torch.nn.functional.relu,
    "gelu": torch.nn.functional.gelu,
    "tanh": torch.tanh,
    "gelu_new": gelu_new,
    "gelu_fast": gelu_fast,
    "mish": mish,
    "sigmoid": torch.sigmoid,
}
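
# Usage sketch (illustrative; not part of the original file): look up an
# activation by name in ACT2FN and apply it elementwise to a tensor. The
# output has the same shape as the input.
if __name__ == "__main__":
    x = torch.randn(2, 4)
    act_fn = ACT2FN["gelu_new"]
    y = act_fn(x)
    assert y.shape == x.shape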