from geffnet import config
from geffnet.activations.activations_me import *
from geffnet.activations.activations_jit import *
from geffnet.activations.activations import *

import torch
from torch import nn
from torch.nn import functional as F

# nn.SiLU / F.silu were added in PyTorch 1.7; on older versions, fall back to
# the local swish implementations imported above.
_has_silu = 'silu' in dir(torch.nn.functional)

_ACT_FN_DEFAULT = dict(
    silu=F.silu if _has_silu else swish,
    swish=F.silu if _has_silu else swish,
    mish=mish,
    relu=F.relu,
    relu6=F.relu6,
    sigmoid=sigmoid,
    tanh=tanh,
    hard_sigmoid=hard_sigmoid,
    hard_swish=hard_swish,
)

_ACT_FN_JIT = dict(
    silu=F.silu if _has_silu else swish_jit,
    swish=F.silu if _has_silu else swish_jit,
    mish=mish_jit,
)

_ACT_FN_ME = dict(
    silu=F.silu if _has_silu else swish_me,
    swish=F.silu if _has_silu else swish_me,
    mish=mish_me,
    hard_swish=hard_swish_me,
    hard_sigmoid=hard_sigmoid_me,
)

_ACT_LAYER_DEFAULT = dict(
    silu=nn.SiLU if _has_silu else Swish,
    swish=nn.SiLU if _has_silu else Swish,
    mish=Mish,
    relu=nn.ReLU,
    relu6=nn.ReLU6,
    sigmoid=Sigmoid,
    tanh=Tanh,
    hard_sigmoid=HardSigmoid,
    hard_swish=HardSwish,
)

_ACT_LAYER_JIT = dict(
    silu=nn.SiLU if _has_silu else SwishJit,
    swish=nn.SiLU if _has_silu else SwishJit,
    mish=MishJit,
)

_ACT_LAYER_ME = dict(
    silu=nn.SiLU if _has_silu else SwishMe,
    swish=nn.SiLU if _has_silu else SwishMe,
    mish=MishMe,
    hard_swish=HardSwishMe,
    hard_sigmoid=HardSigmoidMe,
)
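
# Resolution order used by the factories below: explicit overrides win, then the
# memory-efficient (ME) custom-autograd variants, then the jit-scripted variants,
# and finally the plain Python / torch builtin defaults, as gated by the current
# geffnet.config flags (exportable / scriptable / no_jit).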
_OVERRIDE_FN = dict()
_OVERRIDE_LAYER = dict()


def add_override_act_fn(name, fn):
    global _OVERRIDE_FN
    _OVERRIDE_FN[name] = fn


def update_override_act_fn(overrides):
    assert isinstance(overrides, dict)
    global _OVERRIDE_FN
    _OVERRIDE_FN.update(overrides)


def clear_override_act_fn():
    global _OVERRIDE_FN
    _OVERRIDE_FN = dict()


def add_override_act_layer(name, fn):
    global _OVERRIDE_LAYER
    _OVERRIDE_LAYER[name] = fn


def update_override_act_layer(overrides):
    assert isinstance(overrides, dict)
    global _OVERRIDE_LAYER
    _OVERRIDE_LAYER.update(overrides)


def clear_override_act_layer():
    global _OVERRIDE_LAYER
    _OVERRIDE_LAYER = dict()
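
# Usage sketch for the override registry. The 'gelu' mapping is illustrative
# only; it is not one of the defaults above.
#
#   add_override_act_fn('gelu', F.gelu)
#   add_override_act_layer('gelu', nn.GELU)
#   assert get_act_fn('gelu') is F.gelu   # overrides take precedence in the factories
#   clear_override_act_fn()
#   clear_override_act_layer()
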
def get_act_fn(name='relu'):
    """ Activation Function Factory
    Fetching activation functions by name through this factory allows export- or
    TorchScript-friendly implementations to be returned dynamically, based on the
    current config.
    """
    if name in _OVERRIDE_FN:
        return _OVERRIDE_FN[name]
    use_me = not (config.is_exportable() or config.is_scriptable() or config.is_no_jit())
    if use_me and name in _ACT_FN_ME:
        # If not exporting or scripting the model, first look for a memory-efficient
        # version with custom autograd, then fall back to a jit-scripted version,
        # then to a plain Python or torch builtin.
        return _ACT_FN_ME[name]
    if config.is_exportable() and name in ('silu', 'swish'):
        # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
        return swish
    use_jit = not (config.is_exportable() or config.is_no_jit())
    # NOTE: export tracing should work with jit scripted components, but I keep running into issues
    if use_jit and name in _ACT_FN_JIT:  # jit scripted models should be okay for export/scripting
        return _ACT_FN_JIT[name]
    return _ACT_FN_DEFAULT[name]
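
# Usage sketch for get_act_fn. The set_exportable context manager is assumed
# to exist in geffnet.config (it does in upstream geffnet); with no flags set,
# the ME variants are preferred for the names they cover.
#
#   fn = get_act_fn('swish')            # swish_me (F.silu on PyTorch >= 1.7)
#   with config.set_exportable(True):
#       get_act_fn('swish')             # plain swish, safe for ONNX export
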
def get_act_layer(name='relu'):
    """ Activation Layer Factory
    Fetching activation layers by name through this factory allows export- or
    TorchScript-friendly implementations to be returned dynamically, based on the
    current config.
    """
    if name in _OVERRIDE_LAYER:
        return _OVERRIDE_LAYER[name]
    use_me = not (config.is_exportable() or config.is_scriptable() or config.is_no_jit())
    if use_me and name in _ACT_LAYER_ME:
        return _ACT_LAYER_ME[name]
    if config.is_exportable() and name in ('silu', 'swish'):
        # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
        return Swish
    use_jit = not (config.is_exportable() or config.is_no_jit())
    # NOTE: export tracing should work with jit scripted components, but I keep running into issues
    if use_jit and name in _ACT_LAYER_JIT:  # jit scripted models should be okay for export/scripting
        return _ACT_LAYER_JIT[name]
    return _ACT_LAYER_DEFAULT[name]
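

if __name__ == '__main__':
    # Minimal smoke test / usage sketch for the layer factory. Assumes the
    # activation modules above accept an `inplace` kwarg, as geffnet's
    # activation classes do.
    act_layer = get_act_layer('hard_swish')
    model = nn.Sequential(
        nn.Conv2d(8, 8, 3, padding=1),
        act_layer(inplace=True),
    )
    out = model(torch.randn(1, 8, 16, 16))
    print(act_layer.__name__, tuple(out.shape))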