Safetensors
custom_code
File size: 1,371 Bytes
3c63951
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
from argparse import Namespace
from typing import Dict, Any

import torch

from .adaptor_generic import GenericAdaptor, AdaptorBase

# Alias for mappings with str keys and torch.Tensor values; used as the type
# of the ``state`` argument handed to adaptor factories.
state_t = Dict[str, torch.Tensor]
# Alias for generic str-keyed mappings with arbitrary values; used as the type
# of the ``adaptor_config`` argument handed to adaptor factories.
dict_t = Dict[str, Any]


class AdaptorRegistry:
    """Name-keyed registry of adaptor factory functions.

    Factories are added with the :meth:`register_adaptor` decorator and are
    looked up by :meth:`create_adaptor`. Any name that was never registered
    falls back to ``GenericAdaptor``.
    """

    def __init__(self):
        # Maps adaptor name -> factory callable invoked as
        # factory(main_config, adaptor_config, state).
        self._registry = {}

    def register_adaptor(self, name):
        """Return a decorator that registers a factory under ``name``.

        The decorated function is stored as-is and returned unchanged, so it
        remains directly callable after registration.

        Raises:
            ValueError: when the decorator is applied, if a factory is
                already registered under ``name``.
        """
        def decorator(factory_function):
            # Refuse silent overrides: a second registration under the same
            # name is almost certainly a mistake.
            if name in self._registry:
                raise ValueError(f"Adaptor '{name}' already registered")
            self._registry[name] = factory_function
            return factory_function
        return decorator

    def create_adaptor(self, name, main_config: Namespace, adaptor_config: dict_t, state: state_t) -> AdaptorBase:
        """Instantiate the adaptor registered under ``name``.

        Args:
            name: Registry key to look up.
            main_config: Forwarded unchanged to the factory.
            adaptor_config: Forwarded unchanged to the factory.
            state: Forwarded unchanged to the factory.

        Returns:
            Whatever the selected factory returns. For unregistered names
            this is a ``GenericAdaptor`` built from the same arguments.
        """
        # Single lookup; unregistered names fall back to the generic adaptor.
        factory = self._registry.get(name, GenericAdaptor)
        return factory(main_config, adaptor_config, state)

# Shared module-level registry instance. Factories can be added with
# ``@adaptor_registry.register_adaptor(name)`` and built via
# ``adaptor_registry.create_adaptor(...)``.
adaptor_registry: AdaptorRegistry = AdaptorRegistry()