File size: 2,925 Bytes
f771463
 
 
 
 
 
 
 
 
 
 
 
 
dee7eb6
f771463
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from typing import Dict, List, Set

from spacy.cli.download import get_compatibility


def metrics_options() -> List[str]:
    """Return the names of the metric groups that can be requested.

    A new list is created on every call, so callers are free to mutate it.
    """
    metric_groups = (
        "descriptive_stats",
        "readability",
        "dependency_distance",
        "pos_proportions",
        "coherence",
        "quality",
        "information_theory",
    )
    return list(metric_groups)


def language_options() -> Dict[str, str]:
    """Map human-readable language names to their two-letter language codes.

    The codes are matched against the language prefix of spaCy model names
    elsewhere in this module. Insertion order (alphabetical by display name)
    is preserved in the returned dict.
    """
    name_code_pairs = (
        ("Catalan", "ca"),
        ("Chinese", "zh"),
        ("Croatian", "hr"),
        ("Danish", "da"),
        ("Dutch", "nl"),
        ("English", "en"),
        ("Finnish", "fi"),
        ("French", "fr"),
        ("German", "de"),
        ("Greek", "el"),
        ("Italian", "it"),
        ("Japanese", "ja"),
        ("Korean", "ko"),
        ("Lithuanian", "lt"),
        ("Macedonian", "mk"),
        ("Multi-language", "xx"),
        ("Norwegian Bokmål", "nb"),
        ("Polish", "pl"),
        ("Portuguese", "pt"),
        ("Romanian", "ro"),
        ("Russian", "ru"),
        ("Spanish", "es"),
        ("Swedish", "sv"),
        ("Ukrainian", "uk"),
    )
    return dict(name_code_pairs)


#################
# Model options #
#################


def all_model_size_options_pretty_to_short() -> Dict[str, str]:
    """Map display names of model sizes to the size suffix of spaCy model names.

    The suffix is the part after the last underscore of a model name
    (e.g. the "sm" in a name ending in "_sm").
    """
    # Transformer models ("Transformer" -> "trf") are disabled for now.
    size_map = dict(Small="sm", Medium="md", Large="lg")
    return size_map


def all_model_size_options_short_to_pretty() -> Dict[str, str]:
    """Return the inverse of ``all_model_size_options_pretty_to_short``.

    Keys are size suffixes (e.g. "sm"); values are display names (e.g. "Small").
    """
    pretty_to_short = all_model_size_options_pretty_to_short()
    # Swap keys and values while keeping the original ordering.
    return dict(zip(pretty_to_short.values(), pretty_to_short.keys()))


def available_model_size_options(lang) -> List[str]:
    """Return the display names of the model sizes available for ``lang``.

    ``lang`` is compared against the language prefix of spaCy model names
    (e.g. "en"). The special value ``"all"`` skips the availability check and
    returns every known size. The result is sorted alphabetically by display name.
    """
    short_to_pretty = all_model_size_options_short_to_pretty()
    if lang != "all":
        sizes = ModelAvailabilityChecker.available_model_sizes_for_language(lang)
        return sorted(short_to_pretty[size] for size in sizes)
    return sorted(short_to_pretty.values())


class ModelAvailabilityChecker:
    """Answer questions about which spaCy models (language/size pairs) exist.

    All information comes from ``get_compatibility()`` (imported from
    ``spacy.cli.download``); every public method ends up calling it once per
    invocation.
    NOTE(review): if ``get_compatibility`` performs a network request, repeated
    calls may be slow; consider caching at the call site — confirm.
    """

    @staticmethod
    def available_models() -> List[str]:
        """Return the model names (the keys) of spaCy's compatibility table."""
        return list(get_compatibility().keys())

    @staticmethod
    def extract_language_and_size() -> List[List[str]]:
        """Split every available model name into ``[language, size]``.

        The language is the text before the first underscore and the size is
        the text after the last one, e.g. ``"en_core_web_sm"`` -> ``["en", "sm"]``.
        A name without underscores yields ``[name, name]``.
        """
        return [
            [name.partition("_")[0], name.rpartition("_")[2]]
            for name in ModelAvailabilityChecker.available_models()
        ]

    @staticmethod
    def model_is_available(lang: str, size: str) -> bool:
        """Return True if an available model has language ``lang`` and size ``size``."""
        return [lang, size] in ModelAvailabilityChecker.extract_language_and_size()

    @staticmethod
    def available_model_sizes_for_language(lang: str) -> Set[str]:
        """Return the size suffixes of available models for language ``lang``.

        Only sizes listed in ``all_model_size_options_pretty_to_short`` are kept;
        sizes whose option is not enabled there (such as the commented-out
        "trf") are ignored.
        """
        # Build the set of known sizes once per call instead of once per model.
        known_sizes = set(all_model_size_options_pretty_to_short().values())
        return {
            size
            for model_lang, size in ModelAvailabilityChecker.extract_language_and_size()
            if model_lang == lang and size in known_sizes
        }