Spaces:
Runtime error
Runtime error
# coding=utf-8 | |
# Copyright 2018 The HuggingFace Inc. team. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
""" Auto Model class.""" | |
import warnings | |
from collections import OrderedDict | |
from ...utils import logging | |
from .auto_factory import _BaseAutoBackboneClass, _BaseAutoModelClass, _LazyAutoMapping, auto_class_update | |
from .configuration_auto import CONFIG_MAPPING_NAMES | |
logger = logging.get_logger(__name__) | |
MODEL_MAPPING_NAMES = OrderedDict( | |
[ | |
# Base model mapping | |
("albert", "AlbertModel"), | |
("align", "AlignModel"), | |
("altclip", "AltCLIPModel"), | |
("audio-spectrogram-transformer", "ASTModel"), | |
("autoformer", "AutoformerModel"), | |
("bark", "BarkModel"), | |
("bart", "BartModel"), | |
("beit", "BeitModel"), | |
("bert", "BertModel"), | |
("bert-generation", "BertGenerationEncoder"), | |
("big_bird", "BigBirdModel"), | |
("bigbird_pegasus", "BigBirdPegasusModel"), | |
("biogpt", "BioGptModel"), | |
("bit", "BitModel"), | |
("blenderbot", "BlenderbotModel"), | |
("blenderbot-small", "BlenderbotSmallModel"), | |
("blip", "BlipModel"), | |
("blip-2", "Blip2Model"), | |
("bloom", "BloomModel"), | |
("bridgetower", "BridgeTowerModel"), | |
("bros", "BrosModel"), | |
("camembert", "CamembertModel"), | |
("canine", "CanineModel"), | |
("chinese_clip", "ChineseCLIPModel"), | |
("clap", "ClapModel"), | |
("clip", "CLIPModel"), | |
("clipseg", "CLIPSegModel"), | |
("code_llama", "LlamaModel"), | |
("codegen", "CodeGenModel"), | |
("conditional_detr", "ConditionalDetrModel"), | |
("convbert", "ConvBertModel"), | |
("convnext", "ConvNextModel"), | |
("convnextv2", "ConvNextV2Model"), | |
("cpmant", "CpmAntModel"), | |
("ctrl", "CTRLModel"), | |
("cvt", "CvtModel"), | |
("data2vec-audio", "Data2VecAudioModel"), | |
("data2vec-text", "Data2VecTextModel"), | |
("data2vec-vision", "Data2VecVisionModel"), | |
("deberta", "DebertaModel"), | |
("deberta-v2", "DebertaV2Model"), | |
("decision_transformer", "DecisionTransformerModel"), | |
("deformable_detr", "DeformableDetrModel"), | |
("deit", "DeiTModel"), | |
("deta", "DetaModel"), | |
("detr", "DetrModel"), | |
("dinat", "DinatModel"), | |
("dinov2", "Dinov2Model"), | |
("distilbert", "DistilBertModel"), | |
("donut-swin", "DonutSwinModel"), | |
("dpr", "DPRQuestionEncoder"), | |
("dpt", "DPTModel"), | |
("efficientformer", "EfficientFormerModel"), | |
("efficientnet", "EfficientNetModel"), | |
("electra", "ElectraModel"), | |
("encodec", "EncodecModel"), | |
("ernie", "ErnieModel"), | |
("ernie_m", "ErnieMModel"), | |
("esm", "EsmModel"), | |
("falcon", "FalconModel"), | |
("flaubert", "FlaubertModel"), | |
("flava", "FlavaModel"), | |
("fnet", "FNetModel"), | |
("focalnet", "FocalNetModel"), | |
("fsmt", "FSMTModel"), | |
("funnel", ("FunnelModel", "FunnelBaseModel")), | |
("git", "GitModel"), | |
("glpn", "GLPNModel"), | |
("gpt-sw3", "GPT2Model"), | |
("gpt2", "GPT2Model"), | |
("gpt_bigcode", "GPTBigCodeModel"), | |
("gpt_neo", "GPTNeoModel"), | |
("gpt_neox", "GPTNeoXModel"), | |
("gpt_neox_japanese", "GPTNeoXJapaneseModel"), | |
("gptj", "GPTJModel"), | |
("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"), | |
("graphormer", "GraphormerModel"), | |
("groupvit", "GroupViTModel"), | |
("hubert", "HubertModel"), | |
("ibert", "IBertModel"), | |
("idefics", "IdeficsModel"), | |
("imagegpt", "ImageGPTModel"), | |
("informer", "InformerModel"), | |
("jukebox", "JukeboxModel"), | |
("layoutlm", "LayoutLMModel"), | |
("layoutlmv2", "LayoutLMv2Model"), | |
("layoutlmv3", "LayoutLMv3Model"), | |
("led", "LEDModel"), | |
("levit", "LevitModel"), | |
("lilt", "LiltModel"), | |
("llama", "LlamaModel"), | |
("longformer", "LongformerModel"), | |
("longt5", "LongT5Model"), | |
("luke", "LukeModel"), | |
("lxmert", "LxmertModel"), | |
("m2m_100", "M2M100Model"), | |
("marian", "MarianModel"), | |
("markuplm", "MarkupLMModel"), | |
("mask2former", "Mask2FormerModel"), | |
("maskformer", "MaskFormerModel"), | |
("maskformer-swin", "MaskFormerSwinModel"), | |
("mbart", "MBartModel"), | |
("mctct", "MCTCTModel"), | |
("mega", "MegaModel"), | |
("megatron-bert", "MegatronBertModel"), | |
("mgp-str", "MgpstrForSceneTextRecognition"), | |
("mistral", "MistralModel"), | |
("mobilebert", "MobileBertModel"), | |
("mobilenet_v1", "MobileNetV1Model"), | |
("mobilenet_v2", "MobileNetV2Model"), | |
("mobilevit", "MobileViTModel"), | |
("mobilevitv2", "MobileViTV2Model"), | |
("mpnet", "MPNetModel"), | |
("mpt", "MptModel"), | |
("mra", "MraModel"), | |
("mt5", "MT5Model"), | |
("mvp", "MvpModel"), | |
("nat", "NatModel"), | |
("nezha", "NezhaModel"), | |
("nllb-moe", "NllbMoeModel"), | |
("nystromformer", "NystromformerModel"), | |
("oneformer", "OneFormerModel"), | |
("open-llama", "OpenLlamaModel"), | |
("openai-gpt", "OpenAIGPTModel"), | |
("opt", "OPTModel"), | |
("owlvit", "OwlViTModel"), | |
("pegasus", "PegasusModel"), | |
("pegasus_x", "PegasusXModel"), | |
("perceiver", "PerceiverModel"), | |
("persimmon", "PersimmonModel"), | |
("plbart", "PLBartModel"), | |
("poolformer", "PoolFormerModel"), | |
("prophetnet", "ProphetNetModel"), | |
("pvt", "PvtModel"), | |
("qdqbert", "QDQBertModel"), | |
("reformer", "ReformerModel"), | |
("regnet", "RegNetModel"), | |
("rembert", "RemBertModel"), | |
("resnet", "ResNetModel"), | |
("retribert", "RetriBertModel"), | |
("roberta", "RobertaModel"), | |
("roberta-prelayernorm", "RobertaPreLayerNormModel"), | |
("roc_bert", "RoCBertModel"), | |
("roformer", "RoFormerModel"), | |
("rwkv", "RwkvModel"), | |
("sam", "SamModel"), | |
("segformer", "SegformerModel"), | |
("sew", "SEWModel"), | |
("sew-d", "SEWDModel"), | |
("speech_to_text", "Speech2TextModel"), | |
("speecht5", "SpeechT5Model"), | |
("splinter", "SplinterModel"), | |
("squeezebert", "SqueezeBertModel"), | |
("swiftformer", "SwiftFormerModel"), | |
("swin", "SwinModel"), | |
("swin2sr", "Swin2SRModel"), | |
("swinv2", "Swinv2Model"), | |
("switch_transformers", "SwitchTransformersModel"), | |
("t5", "T5Model"), | |
("table-transformer", "TableTransformerModel"), | |
("tapas", "TapasModel"), | |
("time_series_transformer", "TimeSeriesTransformerModel"), | |
("timesformer", "TimesformerModel"), | |
("timm_backbone", "TimmBackbone"), | |
("trajectory_transformer", "TrajectoryTransformerModel"), | |
("transfo-xl", "TransfoXLModel"), | |
("tvlt", "TvltModel"), | |
("umt5", "UMT5Model"), | |
("unispeech", "UniSpeechModel"), | |
("unispeech-sat", "UniSpeechSatModel"), | |
("van", "VanModel"), | |
("videomae", "VideoMAEModel"), | |
("vilt", "ViltModel"), | |
("vision-text-dual-encoder", "VisionTextDualEncoderModel"), | |
("visual_bert", "VisualBertModel"), | |
("vit", "ViTModel"), | |
("vit_hybrid", "ViTHybridModel"), | |
("vit_mae", "ViTMAEModel"), | |
("vit_msn", "ViTMSNModel"), | |
("vitdet", "VitDetModel"), | |
("vits", "VitsModel"), | |
("vivit", "VivitModel"), | |
("wav2vec2", "Wav2Vec2Model"), | |
("wav2vec2-conformer", "Wav2Vec2ConformerModel"), | |
("wavlm", "WavLMModel"), | |
("whisper", "WhisperModel"), | |
("xclip", "XCLIPModel"), | |
("xglm", "XGLMModel"), | |
("xlm", "XLMModel"), | |
("xlm-prophetnet", "XLMProphetNetModel"), | |
("xlm-roberta", "XLMRobertaModel"), | |
("xlm-roberta-xl", "XLMRobertaXLModel"), | |
("xlnet", "XLNetModel"), | |
("xmod", "XmodModel"), | |
("yolos", "YolosModel"), | |
("yoso", "YosoModel"), | |
] | |
) | |
MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict( | |
[ | |
# Model for pre-training mapping | |
("albert", "AlbertForPreTraining"), | |
("bart", "BartForConditionalGeneration"), | |
("bert", "BertForPreTraining"), | |
("big_bird", "BigBirdForPreTraining"), | |
("bloom", "BloomForCausalLM"), | |
("camembert", "CamembertForMaskedLM"), | |
("ctrl", "CTRLLMHeadModel"), | |
("data2vec-text", "Data2VecTextForMaskedLM"), | |
("deberta", "DebertaForMaskedLM"), | |
("deberta-v2", "DebertaV2ForMaskedLM"), | |
("distilbert", "DistilBertForMaskedLM"), | |
("electra", "ElectraForPreTraining"), | |
("ernie", "ErnieForPreTraining"), | |
("flaubert", "FlaubertWithLMHeadModel"), | |
("flava", "FlavaForPreTraining"), | |
("fnet", "FNetForPreTraining"), | |
("fsmt", "FSMTForConditionalGeneration"), | |
("funnel", "FunnelForPreTraining"), | |
("gpt-sw3", "GPT2LMHeadModel"), | |
("gpt2", "GPT2LMHeadModel"), | |
("gpt_bigcode", "GPTBigCodeForCausalLM"), | |
("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"), | |
("ibert", "IBertForMaskedLM"), | |
("idefics", "IdeficsForVisionText2Text"), | |
("layoutlm", "LayoutLMForMaskedLM"), | |
("longformer", "LongformerForMaskedLM"), | |
("luke", "LukeForMaskedLM"), | |
("lxmert", "LxmertForPreTraining"), | |
("mega", "MegaForMaskedLM"), | |
("megatron-bert", "MegatronBertForPreTraining"), | |
("mobilebert", "MobileBertForPreTraining"), | |
("mpnet", "MPNetForMaskedLM"), | |
("mpt", "MptForCausalLM"), | |
("mra", "MraForMaskedLM"), | |
("mvp", "MvpForConditionalGeneration"), | |
("nezha", "NezhaForPreTraining"), | |
("nllb-moe", "NllbMoeForConditionalGeneration"), | |
("openai-gpt", "OpenAIGPTLMHeadModel"), | |
("retribert", "RetriBertModel"), | |
("roberta", "RobertaForMaskedLM"), | |
("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"), | |
("roc_bert", "RoCBertForPreTraining"), | |
("rwkv", "RwkvForCausalLM"), | |
("splinter", "SplinterForPreTraining"), | |
("squeezebert", "SqueezeBertForMaskedLM"), | |
("switch_transformers", "SwitchTransformersForConditionalGeneration"), | |
("t5", "T5ForConditionalGeneration"), | |
("tapas", "TapasForMaskedLM"), | |
("transfo-xl", "TransfoXLLMHeadModel"), | |
("tvlt", "TvltForPreTraining"), | |
("unispeech", "UniSpeechForPreTraining"), | |
("unispeech-sat", "UniSpeechSatForPreTraining"), | |
("videomae", "VideoMAEForPreTraining"), | |
("visual_bert", "VisualBertForPreTraining"), | |
("vit_mae", "ViTMAEForPreTraining"), | |
("wav2vec2", "Wav2Vec2ForPreTraining"), | |
("wav2vec2-conformer", "Wav2Vec2ConformerForPreTraining"), | |
("xlm", "XLMWithLMHeadModel"), | |
("xlm-roberta", "XLMRobertaForMaskedLM"), | |
("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"), | |
("xlnet", "XLNetLMHeadModel"), | |
("xmod", "XmodForMaskedLM"), | |
] | |
) | |
MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict( | |
[ | |
# Model with LM heads mapping | |
("albert", "AlbertForMaskedLM"), | |
("bart", "BartForConditionalGeneration"), | |
("bert", "BertForMaskedLM"), | |
("big_bird", "BigBirdForMaskedLM"), | |
("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"), | |
("blenderbot-small", "BlenderbotSmallForConditionalGeneration"), | |
("bloom", "BloomForCausalLM"), | |
("camembert", "CamembertForMaskedLM"), | |
("codegen", "CodeGenForCausalLM"), | |
("convbert", "ConvBertForMaskedLM"), | |
("cpmant", "CpmAntForCausalLM"), | |
("ctrl", "CTRLLMHeadModel"), | |
("data2vec-text", "Data2VecTextForMaskedLM"), | |
("deberta", "DebertaForMaskedLM"), | |
("deberta-v2", "DebertaV2ForMaskedLM"), | |
("distilbert", "DistilBertForMaskedLM"), | |
("electra", "ElectraForMaskedLM"), | |
("encoder-decoder", "EncoderDecoderModel"), | |
("ernie", "ErnieForMaskedLM"), | |
("esm", "EsmForMaskedLM"), | |
("flaubert", "FlaubertWithLMHeadModel"), | |
("fnet", "FNetForMaskedLM"), | |
("fsmt", "FSMTForConditionalGeneration"), | |
("funnel", "FunnelForMaskedLM"), | |
("git", "GitForCausalLM"), | |
("gpt-sw3", "GPT2LMHeadModel"), | |
("gpt2", "GPT2LMHeadModel"), | |
("gpt_bigcode", "GPTBigCodeForCausalLM"), | |
("gpt_neo", "GPTNeoForCausalLM"), | |
("gpt_neox", "GPTNeoXForCausalLM"), | |
("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"), | |
("gptj", "GPTJForCausalLM"), | |
("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"), | |
("ibert", "IBertForMaskedLM"), | |
("layoutlm", "LayoutLMForMaskedLM"), | |
("led", "LEDForConditionalGeneration"), | |
("longformer", "LongformerForMaskedLM"), | |
("longt5", "LongT5ForConditionalGeneration"), | |
("luke", "LukeForMaskedLM"), | |
("m2m_100", "M2M100ForConditionalGeneration"), | |
("marian", "MarianMTModel"), | |
("mega", "MegaForMaskedLM"), | |
("megatron-bert", "MegatronBertForCausalLM"), | |
("mobilebert", "MobileBertForMaskedLM"), | |
("mpnet", "MPNetForMaskedLM"), | |
("mpt", "MptForCausalLM"), | |
("mra", "MraForMaskedLM"), | |
("mvp", "MvpForConditionalGeneration"), | |
("nezha", "NezhaForMaskedLM"), | |
("nllb-moe", "NllbMoeForConditionalGeneration"), | |
("nystromformer", "NystromformerForMaskedLM"), | |
("openai-gpt", "OpenAIGPTLMHeadModel"), | |
("pegasus_x", "PegasusXForConditionalGeneration"), | |
("plbart", "PLBartForConditionalGeneration"), | |
("pop2piano", "Pop2PianoForConditionalGeneration"), | |
("qdqbert", "QDQBertForMaskedLM"), | |
("reformer", "ReformerModelWithLMHead"), | |
("rembert", "RemBertForMaskedLM"), | |
("roberta", "RobertaForMaskedLM"), | |
("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"), | |
("roc_bert", "RoCBertForMaskedLM"), | |
("roformer", "RoFormerForMaskedLM"), | |
("rwkv", "RwkvForCausalLM"), | |
("speech_to_text", "Speech2TextForConditionalGeneration"), | |
("squeezebert", "SqueezeBertForMaskedLM"), | |
("switch_transformers", "SwitchTransformersForConditionalGeneration"), | |
("t5", "T5ForConditionalGeneration"), | |
("tapas", "TapasForMaskedLM"), | |
("transfo-xl", "TransfoXLLMHeadModel"), | |
("wav2vec2", "Wav2Vec2ForMaskedLM"), | |
("whisper", "WhisperForConditionalGeneration"), | |
("xlm", "XLMWithLMHeadModel"), | |
("xlm-roberta", "XLMRobertaForMaskedLM"), | |
("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"), | |
("xlnet", "XLNetLMHeadModel"), | |
("xmod", "XmodForMaskedLM"), | |
("yoso", "YosoForMaskedLM"), | |
] | |
) | |
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( | |
[ | |
# Model for Causal LM mapping | |
("bart", "BartForCausalLM"), | |
("bert", "BertLMHeadModel"), | |
("bert-generation", "BertGenerationDecoder"), | |
("big_bird", "BigBirdForCausalLM"), | |
("bigbird_pegasus", "BigBirdPegasusForCausalLM"), | |
("biogpt", "BioGptForCausalLM"), | |
("blenderbot", "BlenderbotForCausalLM"), | |
("blenderbot-small", "BlenderbotSmallForCausalLM"), | |
("bloom", "BloomForCausalLM"), | |
("camembert", "CamembertForCausalLM"), | |
("code_llama", "LlamaForCausalLM"), | |
("codegen", "CodeGenForCausalLM"), | |
("cpmant", "CpmAntForCausalLM"), | |
("ctrl", "CTRLLMHeadModel"), | |
("data2vec-text", "Data2VecTextForCausalLM"), | |
("electra", "ElectraForCausalLM"), | |
("ernie", "ErnieForCausalLM"), | |
("falcon", "FalconForCausalLM"), | |
("git", "GitForCausalLM"), | |
("gpt-sw3", "GPT2LMHeadModel"), | |
("gpt2", "GPT2LMHeadModel"), | |
("gpt_bigcode", "GPTBigCodeForCausalLM"), | |
("gpt_neo", "GPTNeoForCausalLM"), | |
("gpt_neox", "GPTNeoXForCausalLM"), | |
("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"), | |
("gptj", "GPTJForCausalLM"), | |
("llama", "LlamaForCausalLM"), | |
("marian", "MarianForCausalLM"), | |
("mbart", "MBartForCausalLM"), | |
("mega", "MegaForCausalLM"), | |
("megatron-bert", "MegatronBertForCausalLM"), | |
("mistral", "MistralForCausalLM"), | |
("mpt", "MptForCausalLM"), | |
("musicgen", "MusicgenForCausalLM"), | |
("mvp", "MvpForCausalLM"), | |
("open-llama", "OpenLlamaForCausalLM"), | |
("openai-gpt", "OpenAIGPTLMHeadModel"), | |
("opt", "OPTForCausalLM"), | |
("pegasus", "PegasusForCausalLM"), | |
("persimmon", "PersimmonForCausalLM"), | |
("plbart", "PLBartForCausalLM"), | |
("prophetnet", "ProphetNetForCausalLM"), | |
("qdqbert", "QDQBertLMHeadModel"), | |
("reformer", "ReformerModelWithLMHead"), | |
("rembert", "RemBertForCausalLM"), | |
("roberta", "RobertaForCausalLM"), | |
("roberta-prelayernorm", "RobertaPreLayerNormForCausalLM"), | |
("roc_bert", "RoCBertForCausalLM"), | |
("roformer", "RoFormerForCausalLM"), | |
("rwkv", "RwkvForCausalLM"), | |
("speech_to_text_2", "Speech2Text2ForCausalLM"), | |
("transfo-xl", "TransfoXLLMHeadModel"), | |
("trocr", "TrOCRForCausalLM"), | |
("xglm", "XGLMForCausalLM"), | |
("xlm", "XLMWithLMHeadModel"), | |
("xlm-prophetnet", "XLMProphetNetForCausalLM"), | |
("xlm-roberta", "XLMRobertaForCausalLM"), | |
("xlm-roberta-xl", "XLMRobertaXLForCausalLM"), | |
("xlnet", "XLNetLMHeadModel"), | |
("xmod", "XmodForCausalLM"), | |
] | |
) | |
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES = OrderedDict( | |
[ | |
("deit", "DeiTForMaskedImageModeling"), | |
("focalnet", "FocalNetForMaskedImageModeling"), | |
("swin", "SwinForMaskedImageModeling"), | |
("swinv2", "Swinv2ForMaskedImageModeling"), | |
("vit", "ViTForMaskedImageModeling"), | |
] | |
) | |
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES = OrderedDict( | |
# Model for Causal Image Modeling mapping | |
[ | |
("imagegpt", "ImageGPTForCausalImageModeling"), | |
] | |
) | |
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( | |
[ | |
# Model for Image Classification mapping | |
("beit", "BeitForImageClassification"), | |
("bit", "BitForImageClassification"), | |
("convnext", "ConvNextForImageClassification"), | |
("convnextv2", "ConvNextV2ForImageClassification"), | |
("cvt", "CvtForImageClassification"), | |
("data2vec-vision", "Data2VecVisionForImageClassification"), | |
("deit", ("DeiTForImageClassification", "DeiTForImageClassificationWithTeacher")), | |
("dinat", "DinatForImageClassification"), | |
("dinov2", "Dinov2ForImageClassification"), | |
( | |
"efficientformer", | |
( | |
"EfficientFormerForImageClassification", | |
"EfficientFormerForImageClassificationWithTeacher", | |
), | |
), | |
("efficientnet", "EfficientNetForImageClassification"), | |
("focalnet", "FocalNetForImageClassification"), | |
("imagegpt", "ImageGPTForImageClassification"), | |
("levit", ("LevitForImageClassification", "LevitForImageClassificationWithTeacher")), | |
("mobilenet_v1", "MobileNetV1ForImageClassification"), | |
("mobilenet_v2", "MobileNetV2ForImageClassification"), | |
("mobilevit", "MobileViTForImageClassification"), | |
("mobilevitv2", "MobileViTV2ForImageClassification"), | |
("nat", "NatForImageClassification"), | |
( | |
"perceiver", | |
( | |
"PerceiverForImageClassificationLearned", | |
"PerceiverForImageClassificationFourier", | |
"PerceiverForImageClassificationConvProcessing", | |
), | |
), | |
("poolformer", "PoolFormerForImageClassification"), | |
("pvt", "PvtForImageClassification"), | |
("regnet", "RegNetForImageClassification"), | |
("resnet", "ResNetForImageClassification"), | |
("segformer", "SegformerForImageClassification"), | |
("swiftformer", "SwiftFormerForImageClassification"), | |
("swin", "SwinForImageClassification"), | |
("swinv2", "Swinv2ForImageClassification"), | |
("van", "VanForImageClassification"), | |
("vit", "ViTForImageClassification"), | |
("vit_hybrid", "ViTHybridForImageClassification"), | |
("vit_msn", "ViTMSNForImageClassification"), | |
] | |
) | |
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = OrderedDict( | |
[ | |
# Do not add new models here, this class will be deprecated in the future. | |
# Model for Image Segmentation mapping | |
("detr", "DetrForSegmentation"), | |
] | |
) | |
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict( | |
[ | |
# Model for Semantic Segmentation mapping | |
("beit", "BeitForSemanticSegmentation"), | |
("data2vec-vision", "Data2VecVisionForSemanticSegmentation"), | |
("dpt", "DPTForSemanticSegmentation"), | |
("mobilenet_v2", "MobileNetV2ForSemanticSegmentation"), | |
("mobilevit", "MobileViTForSemanticSegmentation"), | |
("mobilevitv2", "MobileViTV2ForSemanticSegmentation"), | |
("segformer", "SegformerForSemanticSegmentation"), | |
("upernet", "UperNetForSemanticSegmentation"), | |
] | |
) | |
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES = OrderedDict( | |
[ | |
# Model for Instance Segmentation mapping | |
# MaskFormerForInstanceSegmentation can be removed from this mapping in v5 | |
("maskformer", "MaskFormerForInstanceSegmentation"), | |
] | |
) | |
MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES = OrderedDict( | |
[ | |
# Model for Universal Segmentation mapping | |
("detr", "DetrForSegmentation"), | |
("mask2former", "Mask2FormerForUniversalSegmentation"), | |
("maskformer", "MaskFormerForInstanceSegmentation"), | |
("oneformer", "OneFormerForUniversalSegmentation"), | |
] | |
) | |
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES = OrderedDict( | |
[ | |
("timesformer", "TimesformerForVideoClassification"), | |
("videomae", "VideoMAEForVideoClassification"), | |
("vivit", "VivitForVideoClassification"), | |
] | |
) | |
MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict( | |
[ | |
("blip", "BlipForConditionalGeneration"), | |
("blip-2", "Blip2ForConditionalGeneration"), | |
("git", "GitForCausalLM"), | |
("instructblip", "InstructBlipForConditionalGeneration"), | |
("pix2struct", "Pix2StructForConditionalGeneration"), | |
("vision-encoder-decoder", "VisionEncoderDecoderModel"), | |
] | |
) | |
MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict( | |
[ | |
# Model for Masked LM mapping | |
("albert", "AlbertForMaskedLM"), | |
("bart", "BartForConditionalGeneration"), | |
("bert", "BertForMaskedLM"), | |
("big_bird", "BigBirdForMaskedLM"), | |
("camembert", "CamembertForMaskedLM"), | |
("convbert", "ConvBertForMaskedLM"), | |
("data2vec-text", "Data2VecTextForMaskedLM"), | |
("deberta", "DebertaForMaskedLM"), | |
("deberta-v2", "DebertaV2ForMaskedLM"), | |
("distilbert", "DistilBertForMaskedLM"), | |
("electra", "ElectraForMaskedLM"), | |
("ernie", "ErnieForMaskedLM"), | |
("esm", "EsmForMaskedLM"), | |
("flaubert", "FlaubertWithLMHeadModel"), | |
("fnet", "FNetForMaskedLM"), | |
("funnel", "FunnelForMaskedLM"), | |
("ibert", "IBertForMaskedLM"), | |
("layoutlm", "LayoutLMForMaskedLM"), | |
("longformer", "LongformerForMaskedLM"), | |
("luke", "LukeForMaskedLM"), | |
("mbart", "MBartForConditionalGeneration"), | |
("mega", "MegaForMaskedLM"), | |
("megatron-bert", "MegatronBertForMaskedLM"), | |
("mobilebert", "MobileBertForMaskedLM"), | |
("mpnet", "MPNetForMaskedLM"), | |
("mra", "MraForMaskedLM"), | |
("mvp", "MvpForConditionalGeneration"), | |
("nezha", "NezhaForMaskedLM"), | |
("nystromformer", "NystromformerForMaskedLM"), | |
("perceiver", "PerceiverForMaskedLM"), | |
("qdqbert", "QDQBertForMaskedLM"), | |
("reformer", "ReformerForMaskedLM"), | |
("rembert", "RemBertForMaskedLM"), | |
("roberta", "RobertaForMaskedLM"), | |
("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"), | |
("roc_bert", "RoCBertForMaskedLM"), | |
("roformer", "RoFormerForMaskedLM"), | |
("squeezebert", "SqueezeBertForMaskedLM"), | |
("tapas", "TapasForMaskedLM"), | |
("wav2vec2", "Wav2Vec2ForMaskedLM"), | |
("xlm", "XLMWithLMHeadModel"), | |
("xlm-roberta", "XLMRobertaForMaskedLM"), | |
("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"), | |
("xmod", "XmodForMaskedLM"), | |
("yoso", "YosoForMaskedLM"), | |
] | |
) | |
MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict( | |
[ | |
# Model for Object Detection mapping | |
("conditional_detr", "ConditionalDetrForObjectDetection"), | |
("deformable_detr", "DeformableDetrForObjectDetection"), | |
("deta", "DetaForObjectDetection"), | |
("detr", "DetrForObjectDetection"), | |
("table-transformer", "TableTransformerForObjectDetection"), | |
("yolos", "YolosForObjectDetection"), | |
] | |
) | |
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict( | |
[ | |
# Model for Zero Shot Object Detection mapping | |
("owlvit", "OwlViTForObjectDetection") | |
] | |
) | |
MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = OrderedDict( | |
[ | |
# Model for depth estimation mapping | |
("dpt", "DPTForDepthEstimation"), | |
("glpn", "GLPNForDepthEstimation"), | |
] | |
) | |
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict( | |
[ | |
# Model for Seq2Seq Causal LM mapping | |
("bart", "BartForConditionalGeneration"), | |
("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"), | |
("blenderbot", "BlenderbotForConditionalGeneration"), | |
("blenderbot-small", "BlenderbotSmallForConditionalGeneration"), | |
("encoder-decoder", "EncoderDecoderModel"), | |
("fsmt", "FSMTForConditionalGeneration"), | |
("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"), | |
("led", "LEDForConditionalGeneration"), | |
("longt5", "LongT5ForConditionalGeneration"), | |
("m2m_100", "M2M100ForConditionalGeneration"), | |
("marian", "MarianMTModel"), | |
("mbart", "MBartForConditionalGeneration"), | |
("mt5", "MT5ForConditionalGeneration"), | |
("mvp", "MvpForConditionalGeneration"), | |
("nllb-moe", "NllbMoeForConditionalGeneration"), | |
("pegasus", "PegasusForConditionalGeneration"), | |
("pegasus_x", "PegasusXForConditionalGeneration"), | |
("plbart", "PLBartForConditionalGeneration"), | |
("prophetnet", "ProphetNetForConditionalGeneration"), | |
("switch_transformers", "SwitchTransformersForConditionalGeneration"), | |
("t5", "T5ForConditionalGeneration"), | |
("umt5", "UMT5ForConditionalGeneration"), | |
("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"), | |
] | |
) | |
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict( | |
[ | |
("pop2piano", "Pop2PianoForConditionalGeneration"), | |
("speech-encoder-decoder", "SpeechEncoderDecoderModel"), | |
("speech_to_text", "Speech2TextForConditionalGeneration"), | |
("speecht5", "SpeechT5ForSpeechToText"), | |
("whisper", "WhisperForConditionalGeneration"), | |
] | |
) | |
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( | |
[ | |
# Model for Sequence Classification mapping | |
("albert", "AlbertForSequenceClassification"), | |
("bart", "BartForSequenceClassification"), | |
("bert", "BertForSequenceClassification"), | |
("big_bird", "BigBirdForSequenceClassification"), | |
("bigbird_pegasus", "BigBirdPegasusForSequenceClassification"), | |
("biogpt", "BioGptForSequenceClassification"), | |
("bloom", "BloomForSequenceClassification"), | |
("camembert", "CamembertForSequenceClassification"), | |
("canine", "CanineForSequenceClassification"), | |
("code_llama", "LlamaForSequenceClassification"), | |
("convbert", "ConvBertForSequenceClassification"), | |
("ctrl", "CTRLForSequenceClassification"), | |
("data2vec-text", "Data2VecTextForSequenceClassification"), | |
("deberta", "DebertaForSequenceClassification"), | |
("deberta-v2", "DebertaV2ForSequenceClassification"), | |
("distilbert", "DistilBertForSequenceClassification"), | |
("electra", "ElectraForSequenceClassification"), | |
("ernie", "ErnieForSequenceClassification"), | |
("ernie_m", "ErnieMForSequenceClassification"), | |
("esm", "EsmForSequenceClassification"), | |
("falcon", "FalconForSequenceClassification"), | |
("flaubert", "FlaubertForSequenceClassification"), | |
("fnet", "FNetForSequenceClassification"), | |
("funnel", "FunnelForSequenceClassification"), | |
("gpt-sw3", "GPT2ForSequenceClassification"), | |
("gpt2", "GPT2ForSequenceClassification"), | |
("gpt_bigcode", "GPTBigCodeForSequenceClassification"), | |
("gpt_neo", "GPTNeoForSequenceClassification"), | |
("gpt_neox", "GPTNeoXForSequenceClassification"), | |
("gptj", "GPTJForSequenceClassification"), | |
("ibert", "IBertForSequenceClassification"), | |
("layoutlm", "LayoutLMForSequenceClassification"), | |
("layoutlmv2", "LayoutLMv2ForSequenceClassification"), | |
("layoutlmv3", "LayoutLMv3ForSequenceClassification"), | |
("led", "LEDForSequenceClassification"), | |
("lilt", "LiltForSequenceClassification"), | |
("llama", "LlamaForSequenceClassification"), | |
("longformer", "LongformerForSequenceClassification"), | |
("luke", "LukeForSequenceClassification"), | |
("markuplm", "MarkupLMForSequenceClassification"), | |
("mbart", "MBartForSequenceClassification"), | |
("mega", "MegaForSequenceClassification"), | |
("megatron-bert", "MegatronBertForSequenceClassification"), | |
("mistral", "MistralForSequenceClassification"), | |
("mobilebert", "MobileBertForSequenceClassification"), | |
("mpnet", "MPNetForSequenceClassification"), | |
("mpt", "MptForSequenceClassification"), | |
("mra", "MraForSequenceClassification"), | |
("mt5", "MT5ForSequenceClassification"), | |
("mvp", "MvpForSequenceClassification"), | |
("nezha", "NezhaForSequenceClassification"), | |
("nystromformer", "NystromformerForSequenceClassification"), | |
("open-llama", "OpenLlamaForSequenceClassification"), | |
("openai-gpt", "OpenAIGPTForSequenceClassification"), | |
("opt", "OPTForSequenceClassification"), | |
("perceiver", "PerceiverForSequenceClassification"), | |
("persimmon", "PersimmonForSequenceClassification"), | |
("plbart", "PLBartForSequenceClassification"), | |
("qdqbert", "QDQBertForSequenceClassification"), | |
("reformer", "ReformerForSequenceClassification"), | |
("rembert", "RemBertForSequenceClassification"), | |
("roberta", "RobertaForSequenceClassification"), | |
("roberta-prelayernorm", "RobertaPreLayerNormForSequenceClassification"), | |
("roc_bert", "RoCBertForSequenceClassification"), | |
("roformer", "RoFormerForSequenceClassification"), | |
("squeezebert", "SqueezeBertForSequenceClassification"), | |
("t5", "T5ForSequenceClassification"), | |
("tapas", "TapasForSequenceClassification"), | |
("transfo-xl", "TransfoXLForSequenceClassification"), | |
("umt5", "UMT5ForSequenceClassification"), | |
("xlm", "XLMForSequenceClassification"), | |
("xlm-roberta", "XLMRobertaForSequenceClassification"), | |
("xlm-roberta-xl", "XLMRobertaXLForSequenceClassification"), | |
("xlnet", "XLNetForSequenceClassification"), | |
("xmod", "XmodForSequenceClassification"), | |
("yoso", "YosoForSequenceClassification"), | |
] | |
) | |
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( | |
[ | |
# Model for Question Answering mapping | |
("albert", "AlbertForQuestionAnswering"), | |
("bart", "BartForQuestionAnswering"), | |
("bert", "BertForQuestionAnswering"), | |
("big_bird", "BigBirdForQuestionAnswering"), | |
("bigbird_pegasus", "BigBirdPegasusForQuestionAnswering"), | |
("bloom", "BloomForQuestionAnswering"), | |
("camembert", "CamembertForQuestionAnswering"), | |
("canine", "CanineForQuestionAnswering"), | |
("convbert", "ConvBertForQuestionAnswering"), | |
("data2vec-text", "Data2VecTextForQuestionAnswering"), | |
("deberta", "DebertaForQuestionAnswering"), | |
("deberta-v2", "DebertaV2ForQuestionAnswering"), | |
("distilbert", "DistilBertForQuestionAnswering"), | |
("electra", "ElectraForQuestionAnswering"), | |
("ernie", "ErnieForQuestionAnswering"), | |
("ernie_m", "ErnieMForQuestionAnswering"), | |
("falcon", "FalconForQuestionAnswering"), | |
("flaubert", "FlaubertForQuestionAnsweringSimple"), | |
("fnet", "FNetForQuestionAnswering"), | |
("funnel", "FunnelForQuestionAnswering"), | |
("gpt2", "GPT2ForQuestionAnswering"), | |
("gpt_neo", "GPTNeoForQuestionAnswering"), | |
("gpt_neox", "GPTNeoXForQuestionAnswering"), | |
("gptj", "GPTJForQuestionAnswering"), | |
("ibert", "IBertForQuestionAnswering"), | |
("layoutlmv2", "LayoutLMv2ForQuestionAnswering"), | |
("layoutlmv3", "LayoutLMv3ForQuestionAnswering"), | |
("led", "LEDForQuestionAnswering"), | |
("lilt", "LiltForQuestionAnswering"), | |
("longformer", "LongformerForQuestionAnswering"), | |
("luke", "LukeForQuestionAnswering"), | |
("lxmert", "LxmertForQuestionAnswering"), | |
("markuplm", "MarkupLMForQuestionAnswering"), | |
("mbart", "MBartForQuestionAnswering"), | |
("mega", "MegaForQuestionAnswering"), | |
("megatron-bert", "MegatronBertForQuestionAnswering"), | |
("mobilebert", "MobileBertForQuestionAnswering"), | |
("mpnet", "MPNetForQuestionAnswering"), | |
("mpt", "MptForQuestionAnswering"), | |
("mra", "MraForQuestionAnswering"), | |
("mt5", "MT5ForQuestionAnswering"), | |
("mvp", "MvpForQuestionAnswering"), | |
("nezha", "NezhaForQuestionAnswering"), | |
("nystromformer", "NystromformerForQuestionAnswering"), | |
("opt", "OPTForQuestionAnswering"), | |
("qdqbert", "QDQBertForQuestionAnswering"), | |
("reformer", "ReformerForQuestionAnswering"), | |
("rembert", "RemBertForQuestionAnswering"), | |
("roberta", "RobertaForQuestionAnswering"), | |
("roberta-prelayernorm", "RobertaPreLayerNormForQuestionAnswering"), | |
("roc_bert", "RoCBertForQuestionAnswering"), | |
("roformer", "RoFormerForQuestionAnswering"), | |
("splinter", "SplinterForQuestionAnswering"), | |
("squeezebert", "SqueezeBertForQuestionAnswering"), | |
("t5", "T5ForQuestionAnswering"), | |
("umt5", "UMT5ForQuestionAnswering"), | |
("xlm", "XLMForQuestionAnsweringSimple"), | |
("xlm-roberta", "XLMRobertaForQuestionAnswering"), | |
("xlm-roberta-xl", "XLMRobertaXLForQuestionAnswering"), | |
("xlnet", "XLNetForQuestionAnsweringSimple"), | |
("xmod", "XmodForQuestionAnswering"), | |
("yoso", "YosoForQuestionAnswering"), | |
] | |
) | |
# Table question answering: model-type key -> model class name.
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        ("tapas", "TapasForQuestionAnswering"),
    ]
)
# Visual question answering: model-type key -> model class name.
# blip-2 is served by its conditional-generation class, not a *ForQuestionAnswering class.
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        ("blip-2", "Blip2ForConditionalGeneration"),
        ("vilt", "ViltForQuestionAnswering"),
    ]
)
# Document question answering: the LayoutLM family. The v2/v3 entries point at the same
# classes the extractive question-answering table uses for those model types.
MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        ("layoutlm", "LayoutLMForQuestionAnswering"),
        ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
        ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"),
    ]
)
# Token classification heads: model-type key -> model class name.
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        ("albert", "AlbertForTokenClassification"),
        ("bert", "BertForTokenClassification"),
        ("big_bird", "BigBirdForTokenClassification"),
        ("biogpt", "BioGptForTokenClassification"),
        ("bloom", "BloomForTokenClassification"),
        ("bros", "BrosForTokenClassification"),
        ("camembert", "CamembertForTokenClassification"),
        ("canine", "CanineForTokenClassification"),
        ("convbert", "ConvBertForTokenClassification"),
        ("data2vec-text", "Data2VecTextForTokenClassification"),
        ("deberta", "DebertaForTokenClassification"),
        ("deberta-v2", "DebertaV2ForTokenClassification"),
        ("distilbert", "DistilBertForTokenClassification"),
        ("electra", "ElectraForTokenClassification"),
        ("ernie", "ErnieForTokenClassification"),
        ("ernie_m", "ErnieMForTokenClassification"),
        ("esm", "EsmForTokenClassification"),
        ("falcon", "FalconForTokenClassification"),
        ("flaubert", "FlaubertForTokenClassification"),
        ("fnet", "FNetForTokenClassification"),
        ("funnel", "FunnelForTokenClassification"),
        # gpt-sw3 has no class of its own here; it reuses the GPT-2 head.
        ("gpt-sw3", "GPT2ForTokenClassification"),
        ("gpt2", "GPT2ForTokenClassification"),
        ("gpt_bigcode", "GPTBigCodeForTokenClassification"),
        ("gpt_neo", "GPTNeoForTokenClassification"),
        ("gpt_neox", "GPTNeoXForTokenClassification"),
        ("ibert", "IBertForTokenClassification"),
        ("layoutlm", "LayoutLMForTokenClassification"),
        ("layoutlmv2", "LayoutLMv2ForTokenClassification"),
        ("layoutlmv3", "LayoutLMv3ForTokenClassification"),
        ("lilt", "LiltForTokenClassification"),
        ("longformer", "LongformerForTokenClassification"),
        ("luke", "LukeForTokenClassification"),
        ("markuplm", "MarkupLMForTokenClassification"),
        ("mega", "MegaForTokenClassification"),
        ("megatron-bert", "MegatronBertForTokenClassification"),
        ("mobilebert", "MobileBertForTokenClassification"),
        ("mpnet", "MPNetForTokenClassification"),
        ("mpt", "MptForTokenClassification"),
        ("mra", "MraForTokenClassification"),
        ("nezha", "NezhaForTokenClassification"),
        ("nystromformer", "NystromformerForTokenClassification"),
        ("qdqbert", "QDQBertForTokenClassification"),
        ("rembert", "RemBertForTokenClassification"),
        ("roberta", "RobertaForTokenClassification"),
        ("roberta-prelayernorm", "RobertaPreLayerNormForTokenClassification"),
        ("roc_bert", "RoCBertForTokenClassification"),
        ("roformer", "RoFormerForTokenClassification"),
        ("squeezebert", "SqueezeBertForTokenClassification"),
        ("xlm", "XLMForTokenClassification"),
        ("xlm-roberta", "XLMRobertaForTokenClassification"),
        ("xlm-roberta-xl", "XLMRobertaXLForTokenClassification"),
        ("xlnet", "XLNetForTokenClassification"),
        ("xmod", "XmodForTokenClassification"),
        ("yoso", "YosoForTokenClassification"),
    ]
)
# Multiple-choice heads: model-type key -> model class name.
# Only deberta-v2 is listed for the DeBERTa family (there is no plain "deberta" entry).
MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
    [
        ("albert", "AlbertForMultipleChoice"),
        ("bert", "BertForMultipleChoice"),
        ("big_bird", "BigBirdForMultipleChoice"),
        ("camembert", "CamembertForMultipleChoice"),
        ("canine", "CanineForMultipleChoice"),
        ("convbert", "ConvBertForMultipleChoice"),
        ("data2vec-text", "Data2VecTextForMultipleChoice"),
        ("deberta-v2", "DebertaV2ForMultipleChoice"),
        ("distilbert", "DistilBertForMultipleChoice"),
        ("electra", "ElectraForMultipleChoice"),
        ("ernie", "ErnieForMultipleChoice"),
        ("ernie_m", "ErnieMForMultipleChoice"),
        ("flaubert", "FlaubertForMultipleChoice"),
        ("fnet", "FNetForMultipleChoice"),
        ("funnel", "FunnelForMultipleChoice"),
        ("ibert", "IBertForMultipleChoice"),
        ("longformer", "LongformerForMultipleChoice"),
        ("luke", "LukeForMultipleChoice"),
        ("mega", "MegaForMultipleChoice"),
        ("megatron-bert", "MegatronBertForMultipleChoice"),
        ("mobilebert", "MobileBertForMultipleChoice"),
        ("mpnet", "MPNetForMultipleChoice"),
        ("mra", "MraForMultipleChoice"),
        ("nezha", "NezhaForMultipleChoice"),
        ("nystromformer", "NystromformerForMultipleChoice"),
        ("qdqbert", "QDQBertForMultipleChoice"),
        ("rembert", "RemBertForMultipleChoice"),
        ("roberta", "RobertaForMultipleChoice"),
        ("roberta-prelayernorm", "RobertaPreLayerNormForMultipleChoice"),
        ("roc_bert", "RoCBertForMultipleChoice"),
        ("roformer", "RoFormerForMultipleChoice"),
        ("squeezebert", "SqueezeBertForMultipleChoice"),
        ("xlm", "XLMForMultipleChoice"),
        ("xlm-roberta", "XLMRobertaForMultipleChoice"),
        ("xlm-roberta-xl", "XLMRobertaXLForMultipleChoice"),
        ("xlnet", "XLNetForMultipleChoice"),
        ("xmod", "XmodForMultipleChoice"),
        ("yoso", "YosoForMultipleChoice"),
    ]
)
# Next-sentence-prediction heads: model-type key -> model class name.
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
    [
        ("bert", "BertForNextSentencePrediction"),
        ("ernie", "ErnieForNextSentencePrediction"),
        ("fnet", "FNetForNextSentencePrediction"),
        ("megatron-bert", "MegatronBertForNextSentencePrediction"),
        ("mobilebert", "MobileBertForNextSentencePrediction"),
        ("nezha", "NezhaForNextSentencePrediction"),
        ("qdqbert", "QDQBertForNextSentencePrediction"),
    ]
)
# Utterance-level audio classification: model-type key -> model class name.
# Most speech encoders map to their *ForSequenceClassification class; AST and Whisper
# provide dedicated *ForAudioClassification classes instead.
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        ("audio-spectrogram-transformer", "ASTForAudioClassification"),
        ("data2vec-audio", "Data2VecAudioForSequenceClassification"),
        ("hubert", "HubertForSequenceClassification"),
        ("sew", "SEWForSequenceClassification"),
        ("sew-d", "SEWDForSequenceClassification"),
        ("unispeech", "UniSpeechForSequenceClassification"),
        ("unispeech-sat", "UniSpeechSatForSequenceClassification"),
        ("wav2vec2", "Wav2Vec2ForSequenceClassification"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForSequenceClassification"),
        ("wavlm", "WavLMForSequenceClassification"),
        ("whisper", "WhisperForAudioClassification"),
    ]
)
# Connectionist temporal classification (CTC) heads: model-type key -> model class name.
MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict(
    [
        ("data2vec-audio", "Data2VecAudioForCTC"),
        ("hubert", "HubertForCTC"),
        ("mctct", "MCTCTForCTC"),
        ("sew", "SEWForCTC"),
        ("sew-d", "SEWDForCTC"),
        ("unispeech", "UniSpeechForCTC"),
        ("unispeech-sat", "UniSpeechSatForCTC"),
        ("wav2vec2", "Wav2Vec2ForCTC"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForCTC"),
        ("wavlm", "WavLMForCTC"),
    ]
)
# Audio frame classification (per-frame labels): model-type key -> model class name.
MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        ("data2vec-audio", "Data2VecAudioForAudioFrameClassification"),
        ("unispeech-sat", "UniSpeechSatForAudioFrameClassification"),
        ("wav2vec2", "Wav2Vec2ForAudioFrameClassification"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForAudioFrameClassification"),
        ("wavlm", "WavLMForAudioFrameClassification"),
    ]
)
# Audio x-vector heads: model-type key -> model class name.
MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = OrderedDict(
    [
        ("data2vec-audio", "Data2VecAudioForXVector"),
        ("unispeech-sat", "UniSpeechSatForXVector"),
        ("wav2vec2", "Wav2Vec2ForXVector"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForXVector"),
        ("wavlm", "WavLMForXVector"),
    ]
)
# Text-to-spectrogram models: model-type key -> model class name.
MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES = OrderedDict(
    [
        ("speecht5", "SpeechT5ForTextToSpeech"),
    ]
)
# Text-to-waveform models: model-type key -> model class name.
# bark and vits map to their base model classes; musicgen to its conditional-generation class.
MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = OrderedDict(
    [
        ("bark", "BarkModel"),
        ("musicgen", "MusicgenForConditionalGeneration"),
        ("vits", "VitsModel"),
    ]
)
# Zero-shot image classification: model-type key -> model class name.
# Every entry reuses the same class that MODEL_MAPPING_NAMES lists for that model type.
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        ("align", "AlignModel"),
        ("altclip", "AltCLIPModel"),
        ("blip", "BlipModel"),
        ("chinese_clip", "ChineseCLIPModel"),
        ("clip", "CLIPModel"),
        ("clipseg", "CLIPSegModel"),
    ]
)
# Backbone classes (consumed by AutoBackbone): model-type key -> backbone class name.
MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
    [
        ("bit", "BitBackbone"),
        ("convnext", "ConvNextBackbone"),
        ("convnextv2", "ConvNextV2Backbone"),
        ("dinat", "DinatBackbone"),
        ("dinov2", "Dinov2Backbone"),
        ("focalnet", "FocalNetBackbone"),
        ("maskformer-swin", "MaskFormerSwinBackbone"),
        ("nat", "NatBackbone"),
        ("resnet", "ResNetBackbone"),
        ("swin", "SwinBackbone"),
        ("timm_backbone", "TimmBackbone"),
        ("vitdet", "VitDetBackbone"),
    ]
)
# Mask generation: model-type key -> model class name (SAM uses its base model class).
MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
    [
        ("sam", "SamModel"),
    ]
)
# Text encoders: model-type key -> model class name.
# The t5-style families (mt5, t5, umt5) map to their *EncoderModel classes; the others map
# to their base *Model class.
MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
    [
        ("albert", "AlbertModel"),
        ("bert", "BertModel"),
        ("big_bird", "BigBirdModel"),
        ("data2vec-text", "Data2VecTextModel"),
        ("deberta", "DebertaModel"),
        ("deberta-v2", "DebertaV2Model"),
        ("distilbert", "DistilBertModel"),
        ("electra", "ElectraModel"),
        ("flaubert", "FlaubertModel"),
        ("ibert", "IBertModel"),
        ("longformer", "LongformerModel"),
        ("mobilebert", "MobileBertModel"),
        ("mt5", "MT5EncoderModel"),
        ("nystromformer", "NystromformerModel"),
        ("reformer", "ReformerModel"),
        ("rembert", "RemBertModel"),
        ("roberta", "RobertaModel"),
        ("roberta-prelayernorm", "RobertaPreLayerNormModel"),
        ("roc_bert", "RoCBertModel"),
        ("roformer", "RoFormerModel"),
        ("squeezebert", "SqueezeBertModel"),
        ("t5", "T5EncoderModel"),
        ("umt5", "UMT5EncoderModel"),
        ("xlm", "XLMModel"),
        ("xlm-roberta", "XLMRobertaModel"),
        ("xlm-roberta-xl", "XLMRobertaXLModel"),
    ]
)
# Image-to-image models: model-type key -> model class name.
MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = OrderedDict(
    [
        ("swin2sr", "Swin2SRForImageSuperResolution"),
    ]
)
# Each public *_MAPPING below pairs CONFIG_MAPPING_NAMES with the matching *_MAPPING_NAMES
# table through _LazyAutoMapping. NOTE(review): the "Lazy" name suggests model classes are
# only imported on first lookup -- confirm against auto_factory before relying on it.
MODEL_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_MAPPING_NAMES,
)
MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_PRETRAINING_MAPPING_NAMES,
)
MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_WITH_LM_HEAD_MAPPING_NAMES,
)
MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
)
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES,
)
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
)
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
)
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES,
)
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,
)
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES,
)
MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES,
)
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES,
)
MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES,
)
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES,
)
MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES,
)
MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_MASKED_LM_MAPPING_NAMES,
)
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
)
MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES,
)
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES,
)
MODEL_FOR_DEPTH_ESTIMATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES,
)
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
)
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
)
MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
)
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
)
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
)
MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
)
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
)
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
)
MODEL_FOR_CTC_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_CTC_MAPPING_NAMES,
)
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
)
MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES,
)
MODEL_FOR_AUDIO_XVECTOR_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES,
)
MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES,
)
MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES,
)
MODEL_FOR_BACKBONE_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_BACKBONE_MAPPING_NAMES,
)
MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_MASK_GENERATION_MAPPING_NAMES,
)
MODEL_FOR_TEXT_ENCODING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES,
)
MODEL_FOR_IMAGE_TO_IMAGE_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES,
)
# Auto class backed by MODEL_FOR_MASK_GENERATION_MAPPING (e.g. "sam" -> "SamModel").
class AutoModelForMaskGeneration(_BaseAutoModelClass):
    # Lazy model-type -> class table this auto class looks models up in.
    _model_mapping = MODEL_FOR_MASK_GENERATION_MAPPING
# Auto class backed by MODEL_FOR_TEXT_ENCODING_MAPPING (e.g. "t5" -> "T5EncoderModel").
class AutoModelForTextEncoding(_BaseAutoModelClass):
    # Lazy model-type -> class table this auto class looks models up in.
    _model_mapping = MODEL_FOR_TEXT_ENCODING_MAPPING
# Auto class backed by MODEL_FOR_IMAGE_TO_IMAGE_MAPPING (e.g. "swin2sr" -> "Swin2SRForImageSuperResolution").
class AutoModelForImageToImage(_BaseAutoModelClass):
    # Lazy model-type -> class table this auto class looks models up in.
    _model_mapping = MODEL_FOR_IMAGE_TO_IMAGE_MAPPING
# Auto class for base models, backed by MODEL_MAPPING (e.g. "bert" -> "BertModel").
class AutoModel(_BaseAutoModelClass):
    _model_mapping = MODEL_MAPPING


# The base-model auto class is passed to auto_class_update without a head_doc.
AutoModel = auto_class_update(AutoModel)
# Auto class for pretraining heads, backed by MODEL_FOR_PRETRAINING_MAPPING.
class AutoModelForPreTraining(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_PRETRAINING_MAPPING


AutoModelForPreTraining = auto_class_update(
    AutoModelForPreTraining,
    head_doc="pretraining",
)
# Deliberately private: the public AutoModelWithLMHead subclass derives from this class
# and layers the deprecation warnings on top of it.
class _AutoModelWithLMHead(_BaseAutoModelClass):
    _model_mapping = MODEL_WITH_LM_HEAD_MAPPING


_AutoModelWithLMHead = auto_class_update(
    _AutoModelWithLMHead,
    head_doc="language modeling",
)
# Auto class for causal language-modeling heads, backed by MODEL_FOR_CAUSAL_LM_MAPPING.
class AutoModelForCausalLM(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING


AutoModelForCausalLM = auto_class_update(
    AutoModelForCausalLM,
    head_doc="causal language modeling",
)
# Auto class for masked language-modeling heads, backed by MODEL_FOR_MASKED_LM_MAPPING.
class AutoModelForMaskedLM(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_MASKED_LM_MAPPING


AutoModelForMaskedLM = auto_class_update(
    AutoModelForMaskedLM,
    head_doc="masked language modeling",
)
# Auto class for sequence-to-sequence LM heads, backed by MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.
class AutoModelForSeq2SeqLM(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING


# "t5-base" replaces the default example checkpoint passed to auto_class_update.
AutoModelForSeq2SeqLM = auto_class_update(
    AutoModelForSeq2SeqLM,
    head_doc="sequence-to-sequence language modeling",
    checkpoint_for_example="t5-base",
)
# Auto class for sequence-classification heads, backed by MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.
class AutoModelForSequenceClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING


AutoModelForSequenceClassification = auto_class_update(
    AutoModelForSequenceClassification,
    head_doc="sequence classification",
)
# Auto class for extractive question-answering heads, backed by MODEL_FOR_QUESTION_ANSWERING_MAPPING.
class AutoModelForQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING


AutoModelForQuestionAnswering = auto_class_update(
    AutoModelForQuestionAnswering,
    head_doc="question answering",
)
# Auto class for table question answering (e.g. "tapas"), backed by
# MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING; the example checkpoint is a TAPAS model.
class AutoModelForTableQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING


AutoModelForTableQuestionAnswering = auto_class_update(
    AutoModelForTableQuestionAnswering,
    head_doc="table question answering",
    checkpoint_for_example="google/tapas-base-finetuned-wtq",
)
# Auto class for visual question answering, backed by
# MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING; the example checkpoint is a ViLT model.
class AutoModelForVisualQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING


AutoModelForVisualQuestionAnswering = auto_class_update(
    AutoModelForVisualQuestionAnswering,
    head_doc="visual question answering",
    checkpoint_for_example="dandelin/vilt-b32-finetuned-vqa",
)
# Auto class for document question answering (the LayoutLM family), backed by
# MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING.
class AutoModelForDocumentQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING


# NOTE(review): the checkpoint string deliberately embeds double quotes. It appears meant to
# close the quoted checkpoint in the generated docstring example and inject a
# `revision="52e01b3"` argument -- confirm against auto_class_update's template before
# touching it. Keep it byte-for-byte.
AutoModelForDocumentQuestionAnswering = auto_class_update(
    AutoModelForDocumentQuestionAnswering,
    head_doc="document question answering",
    checkpoint_for_example='impira/layoutlm-document-qa", revision="52e01b3',
)
# Auto class for token-classification heads, backed by MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.
class AutoModelForTokenClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING


AutoModelForTokenClassification = auto_class_update(
    AutoModelForTokenClassification,
    head_doc="token classification",
)
# Auto class for multiple-choice heads, backed by MODEL_FOR_MULTIPLE_CHOICE_MAPPING.
class AutoModelForMultipleChoice(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_MULTIPLE_CHOICE_MAPPING


AutoModelForMultipleChoice = auto_class_update(
    AutoModelForMultipleChoice,
    head_doc="multiple choice",
)
# Auto class for next-sentence-prediction heads, backed by MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.
class AutoModelForNextSentencePrediction(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING


AutoModelForNextSentencePrediction = auto_class_update(
    AutoModelForNextSentencePrediction,
    head_doc="next sentence prediction",
)
# Auto class for image-classification heads, backed by MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.
class AutoModelForImageClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING


AutoModelForImageClassification = auto_class_update(
    AutoModelForImageClassification,
    head_doc="image classification",
)
# Auto class for zero-shot image classification, backed by
# MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING (CLIP-style base models).
class AutoModelForZeroShotImageClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING


AutoModelForZeroShotImageClassification = auto_class_update(
    AutoModelForZeroShotImageClassification,
    head_doc="zero-shot image classification",
)
# Auto class for image-segmentation heads, backed by MODEL_FOR_IMAGE_SEGMENTATION_MAPPING.
class AutoModelForImageSegmentation(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING


AutoModelForImageSegmentation = auto_class_update(
    AutoModelForImageSegmentation,
    head_doc="image segmentation",
)
# Auto class for semantic-segmentation heads, backed by MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING.
class AutoModelForSemanticSegmentation(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING


AutoModelForSemanticSegmentation = auto_class_update(
    AutoModelForSemanticSegmentation,
    head_doc="semantic segmentation",
)
# Auto class for universal image-segmentation heads, backed by MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING.
class AutoModelForUniversalSegmentation(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING


AutoModelForUniversalSegmentation = auto_class_update(
    AutoModelForUniversalSegmentation,
    head_doc="universal image segmentation",
)
# Auto class for instance-segmentation heads, backed by MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING.
class AutoModelForInstanceSegmentation(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING


AutoModelForInstanceSegmentation = auto_class_update(
    AutoModelForInstanceSegmentation,
    head_doc="instance segmentation",
)
# Auto class for object-detection heads, backed by MODEL_FOR_OBJECT_DETECTION_MAPPING.
class AutoModelForObjectDetection(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING


AutoModelForObjectDetection = auto_class_update(
    AutoModelForObjectDetection,
    head_doc="object detection",
)
# Auto class for zero-shot object detection, backed by MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING.
class AutoModelForZeroShotObjectDetection(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING


AutoModelForZeroShotObjectDetection = auto_class_update(
    AutoModelForZeroShotObjectDetection,
    head_doc="zero-shot object detection",
)
# Auto class for depth-estimation heads, backed by MODEL_FOR_DEPTH_ESTIMATION_MAPPING.
class AutoModelForDepthEstimation(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING


AutoModelForDepthEstimation = auto_class_update(
    AutoModelForDepthEstimation,
    head_doc="depth estimation",
)
# Auto class for video-classification heads, backed by MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING.
class AutoModelForVideoClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING


AutoModelForVideoClassification = auto_class_update(
    AutoModelForVideoClassification,
    head_doc="video classification",
)
# Auto class for vision-to-text models, backed by MODEL_FOR_VISION_2_SEQ_MAPPING.
class AutoModelForVision2Seq(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING


AutoModelForVision2Seq = auto_class_update(
    AutoModelForVision2Seq,
    head_doc="vision-to-text modeling",
)
# Auto class for utterance-level audio classification, backed by MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING.
class AutoModelForAudioClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING


AutoModelForAudioClassification = auto_class_update(
    AutoModelForAudioClassification,
    head_doc="audio classification",
)
# Auto class for CTC heads, backed by MODEL_FOR_CTC_MAPPING.
class AutoModelForCTC(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_CTC_MAPPING


AutoModelForCTC = auto_class_update(
    AutoModelForCTC,
    head_doc="connectionist temporal classification",
)
# Auto class for sequence-to-sequence speech-to-text models, backed by MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.
class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING


AutoModelForSpeechSeq2Seq = auto_class_update(
    AutoModelForSpeechSeq2Seq,
    head_doc="sequence-to-sequence speech-to-text modeling",
)
# Auto class for per-frame audio classification, backed by MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING.
class AutoModelForAudioFrameClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING


AutoModelForAudioFrameClassification = auto_class_update(
    AutoModelForAudioFrameClassification,
    head_doc="audio frame (token) classification",
)
# Auto class for audio x-vector heads, backed by MODEL_FOR_AUDIO_XVECTOR_MAPPING.
class AutoModelForAudioXVector(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_AUDIO_XVECTOR_MAPPING


AutoModelForAudioXVector = auto_class_update(
    AutoModelForAudioXVector,
    head_doc="audio retrieval via x-vector",
)


# The three auto classes below are defined without an auto_class_update call.
class AutoModelForTextToSpectrogram(_BaseAutoModelClass):
    # e.g. "speecht5" -> "SpeechT5ForTextToSpeech"
    _model_mapping = MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING


class AutoModelForTextToWaveform(_BaseAutoModelClass):
    # e.g. "bark" -> "BarkModel", "vits" -> "VitsModel"
    _model_mapping = MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING


# Backbones derive from _BaseAutoBackboneClass rather than _BaseAutoModelClass.
class AutoBackbone(_BaseAutoBackboneClass):
    _model_mapping = MODEL_FOR_BACKBONE_MAPPING
# Auto class for masked image modeling heads, backed by MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING.
class AutoModelForMaskedImageModeling(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING


AutoModelForMaskedImageModeling = auto_class_update(
    AutoModelForMaskedImageModeling,
    head_doc="masked image modeling",
)
class AutoModelWithLMHead(_AutoModelWithLMHead):
    """Deprecated auto class for models with a language-modeling head.

    Uses the same model mapping as ``_AutoModelWithLMHead``. Both ``from_config`` and
    ``from_pretrained`` first emit a ``FutureWarning`` steering users to
    ``AutoModelForCausalLM``, ``AutoModelForMaskedLM`` or ``AutoModelForSeq2SeqLM``,
    then delegate to the parent implementation.
    """

    # Single source for the warning text so both entry points cannot drift apart.
    _DEPRECATION_MESSAGE = (
        "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
        "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
        "`AutoModelForSeq2SeqLM` for encoder-decoder models."
    )

    # Both overrides must be classmethods. They are called on the class
    # (``AutoModelWithLMHead.from_pretrained(...)``), and the zero-argument ``super()``
    # needs ``cls`` bound to this class to reach the parent implementation. Without the
    # decorator, the first positional argument would be bound to ``cls`` and the call
    # would fail with a TypeError.
    @classmethod
    def from_config(cls, config):
        """Warn that this class is deprecated, then delegate to the parent ``from_config``.

        Args:
            config: Configuration object, forwarded unchanged.

        Returns:
            Whatever the parent class's ``from_config`` returns.
        """
        # stacklevel=2 attributes the warning to the caller's line instead of this module.
        warnings.warn(cls._DEPRECATION_MESSAGE, FutureWarning, stacklevel=2)
        return super().from_config(config)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Warn that this class is deprecated, then delegate to the parent ``from_pretrained``.

        Args:
            pretrained_model_name_or_path: Checkpoint identifier or path, forwarded unchanged.
            *model_args: Extra positional arguments, forwarded unchanged.
            **kwargs: Extra keyword arguments, forwarded unchanged.

        Returns:
            Whatever the parent class's ``from_pretrained`` returns.
        """
        warnings.warn(cls._DEPRECATION_MESSAGE, FutureWarning, stacklevel=2)
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)