Update tokenization_dart.py
Browse files
tokenization_dart.py — CHANGED: 1 addition, 23 deletions
@@ -1,7 +1,5 @@  (old side)
 1    import logging
 2  - import json
 3  - from typing import Dict, List
 4  - from pydantic.dataclasses import dataclass
 5
 6    from transformers import PreTrainedTokenizerFast
 7    from tokenizers.decoders import Decoder
@@ -57,26 +55,6 @@ PROMPT_TEMPLATE = (  (old side)
 57    # fmt: on
 58
 59
 60  - @dataclass
 61  - class Category:
 62  -     name: str
 63  -     bos_token_id: int
 64  -     eos_token_id: int
 65  -
 66  -
 67  - @dataclass
 68  - class TagCategoryConfig:
 69  -     categories: Dict[str, Category]
 70  -     category_to_token_ids: Dict[str, List[int]]
 71  -
 72  -
 73  - def load_tag_category_config(config_json: str):
 74  -     with open(config_json, "rb") as file:
 75  -         config: TagCategoryConfig = TagCategoryConfig(**json.loads(file.read()))
 76  -
 77  -     return config
 78  -
 79  -
 80    class DartDecoder:
 81        def __init__(self, special_tokens: List[str]):
 82            self.special_tokens = list(special_tokens)
(new side)
 1    import logging
 2  + from typing import List
 3
 4    from transformers import PreTrainedTokenizerFast
 5    from tokenizers.decoders import Decoder
(new side)
 55    # fmt: on
 56
 57
 58    class DartDecoder:
 59        def __init__(self, special_tokens: List[str]):
 60            self.special_tokens = list(special_tokens)