|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from argparse import Namespace |
|
from typing import Any, Callable, Dict
|
|
|
import torch |
|
|
|
from .adaptor_generic import GenericAdaptor, AdaptorBase |
|
|
|
# Type aliases shared by the adaptor API below:
#   dict_t  -- string-keyed mapping with arbitrary values (the adaptor_config
#              argument of create_adaptor)
#   state_t -- string-keyed mapping of torch tensors (the state argument of
#              create_adaptor)
dict_t, state_t = Dict[str, Any], Dict[str, torch.Tensor]
|
|
|
|
|
# Signature every registered factory (and the GenericAdaptor fallback) is
# called with in AdaptorRegistry.create_adaptor.
adaptor_factory_t = Callable[[Namespace, dict_t, state_t], AdaptorBase]


class AdaptorRegistry:
    """Name-keyed registry of adaptor factories.

    Factories are registered with the ``register_adaptor`` decorator and
    invoked through ``create_adaptor``. Names that were never registered fall
    back to ``GenericAdaptor``.
    """

    def __init__(self):
        # Maps a registered name to the callable that builds its adaptor.
        self._registry: Dict[str, adaptor_factory_t] = {}

    def register_adaptor(self, name: str) -> Callable[[adaptor_factory_t], adaptor_factory_t]:
        """Return a decorator that registers a factory under ``name``.

        The decorated callable is stored and returned unchanged, so the
        decorator does not alter the function or class it is applied to.

        Args:
            name: Key under which the factory is stored.

        Returns:
            The registering decorator.

        Raises:
            ValueError: When the decorator is applied and ``name`` is already
                registered (duplicates are rejected, never overwritten).
        """

        def decorator(factory_function: adaptor_factory_t) -> adaptor_factory_t:
            if name in self._registry:
                raise ValueError(f"Model '{name}' already registered")
            self._registry[name] = factory_function
            return factory_function

        return decorator

    def create_adaptor(self, name: str, main_config: Namespace, adaptor_config: dict_t, state: state_t) -> AdaptorBase:
        """Build the adaptor registered under ``name``.

        Args:
            name: Registry key to look up.
            main_config: Forwarded unchanged to the factory.
            adaptor_config: Forwarded unchanged to the factory.
            state: Forwarded unchanged to the factory.

        Returns:
            The object returned by the registered factory, or
            ``GenericAdaptor(main_config, adaptor_config, state)`` when
            ``name`` has no registered factory.
        """
        # One lookup instead of a membership test followed by indexing; the
        # fallback is called with the same arguments as a registered factory.
        factory = self._registry.get(name, GenericAdaptor)
        return factory(main_config, adaptor_config, state)


# Module-level registry instance that factories are registered on and that
# create_adaptor is called through.
adaptor_registry = AdaptorRegistry()
|
|