OpenSLU / tools /load_from_hugging_face.py
LightChen2333's picture
Upload 78 files
223340a
'''
Author: Qiguang Chen
LastEditors: Qiguang Chen
Date: 2023-02-13 10:44:39
LastEditTime: 2023-02-14 10:28:43
Description:
'''
import os
import dill
from common import utils
from common.utils import InputData, download
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
# parser = argparse.ArgumentParser()
# parser.add_argument('--config_path', '-cp', type=str, default="config/reproduction/atis/joint_bert.yaml")
# args = parser.parse_args()
# config = Config.load_from_yaml(args.config_path)
# config.base["train"] = False
# config.base["test"] = False
# model_manager = ModelManager(config)
# model_manager.load()
class PretrainedConfigForSLU(PretrainedConfig):
def __init__(self, **kargs) -> None:
super().__init__(**kargs)
# pretrained_config = PretrainedConfigForSLU()
# # pretrained_config.push_to_hub("xxxx")
class PretrainedModelForSLU(PreTrainedModel):
def __init__(self, config: PretrainedConfig, *inputs, **kwargs) -> None:
super().__init__(config, *inputs, **kwargs)
self.config_class = config
self.model = utils.instantiate(config.model)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
cls.config_class = PretrainedConfigForSLU
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
class PreTrainedTokenizerForSLU(PreTrainedTokenizer):
def __init__(self, **kwargs):
super().__init__(**kwargs)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
dir_names = f"save/{pretrained_model_name_or_path}".split("/")
dir_name = ""
for name in dir_names:
dir_name += name+"/"
if not os.path.exists(dir_name):
os.mkdir(dir_name)
cache_path = f"./save/{pretrained_model_name_or_path}/tokenizer.pkl"
if not os.path.exists(cache_path):
download(f"https://huggingface.co/{pretrained_model_name_or_path}/resolve/main/tokenizer.pkl", cache_path)
with open(cache_path, "rb") as f:
tokenizer = dill.load(f)
return tokenizer
# pretrained_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# pretrained_tokenizer = PreTrainedTokenizerForSLU.from_pretrained("LightChen2333/joint-bert-slu-atis")
# test_model = PretrainedModelForSLU.from_pretrained("LightChen2333/joint-bert-slu-atis")
# print(test_model(InputData([pretrained_tokenizer("I want to go to Beijing !")])))