Spaces:
Runtime error
Runtime error
File size: 5,544 Bytes
a1d409e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import importlib
import io
import unittest
import transformers
# Try to import everything from transformers to ensure every object can be loaded.
from transformers import * # noqa F406
from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER, require_flax, require_tf, require_torch
from transformers.utils import ContextManagers, find_labels, is_flax_available, is_tf_available, is_torch_available
if is_torch_available():
from transformers import BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification
if is_tf_available():
from transformers import TFBertForPreTraining, TFBertForQuestionAnswering, TFBertForSequenceClassification
if is_flax_available():
from transformers import FlaxBertForPreTraining, FlaxBertForQuestionAnswering, FlaxBertForSequenceClassification
# NOTE(review): none of the constants below is referenced by the tests in this
# section of the file; confirm they are still needed before relying on them.

# An actual model hosted on huggingface.co.
MODEL_ID = DUMMY_UNKNOWN_IDENTIFIER

# Default branch name.
REVISION_ID_DEFAULT = "main"
# One particular commit (not the top of `main`).
REVISION_ID_ONE_SPECIFIC_COMMIT = "f2c752cfc5c0ab6f4bdec59acea69eefbee381c2"
# This commit does not exist, so we should 404.
REVISION_ID_INVALID = "aaaaaaa"

# SHA-1 of config.json on the top of `main`, for checking purposes.
PINNED_SHA1 = "d9e9f15bc825e4b2c9249e9578f884bbcb5e3684"
# SHA-256 of pytorch_model.bin on the top of `main`, for checking purposes.
PINNED_SHA256 = "4b243c475af8d0a7754e87d7d096c92e5199ec2fe168a2ee7998e3b8e9bcb1d3"
# Dummy contexts to test `ContextManagers`
@contextlib.contextmanager
def context_en():
    """Dummy context manager that greets in English on entry and says goodbye on exit.

    Used by the ``ContextManagers`` tests to check the order in which
    wrapped contexts are entered and exited.
    """
    opening_line, closing_line = "Welcome!", "Bye!"
    print(opening_line)
    yield
    # No try/finally around the yield: if the managed block raises,
    # the exception propagates from here and the farewell is not printed.
    print(closing_line)
@contextlib.contextmanager
def context_fr():
    """Dummy context manager that greets in French on entry and says goodbye on exit.

    Used by the ``ContextManagers`` tests to check the order in which
    wrapped contexts are entered and exited.
    """
    opening_line, closing_line = "Bonjour!", "Au revoir!"
    print(opening_line)
    yield
    # No try/finally around the yield: if the managed block raises,
    # the exception propagates from here and the farewell is not printed.
    print(closing_line)
class TestImportMechanisms(unittest.TestCase):
    """Sanity checks on how the ``transformers`` package is exposed to the import system."""

    def test_module_spec_available(self):
        """The package must carry a module spec and be discoverable via ``importlib``."""
        # `import importlib` alone does not guarantee that the `importlib.util`
        # submodule is loaded; import it explicitly instead of relying on some
        # other module having imported it as a side effect.
        import importlib.util

        # If the spec is missing, importlib would not be able to import the module dynamically.
        # unittest assertions are used instead of bare `assert`, which `python -O` strips.
        self.assertIsNotNone(transformers.__spec__)
        self.assertIsNotNone(importlib.util.find_spec("transformers"))
class GenericUtilTests(unittest.TestCase):
    """Tests for the generic helpers ``ContextManagers`` and ``find_labels``."""

    @staticmethod
    def _capture_stdout_inside(managers):
        """Print a fixed message inside ``ContextManagers(managers)`` and return all captured stdout.

        ``contextlib.redirect_stdout`` is used rather than ``unittest.mock.patch``:
        the module only does ``import unittest``, which does not load the
        ``unittest.mock`` submodule.
        """
        buffer = io.StringIO()
        with contextlib.redirect_stdout(buffer):
            with ContextManagers(managers):
                print("Transformers are awesome!")
        return buffer.getvalue()

    def test_context_managers_no_context(self):
        # The print statement adds a new line at the end of the output
        self.assertEqual(self._capture_stdout_inside([]), "Transformers are awesome!\n")

    def test_context_managers_one_context(self):
        # The output should be wrapped with an English welcome and goodbye
        self.assertEqual(
            self._capture_stdout_inside([context_en()]),
            "Welcome!\nTransformers are awesome!\nBye!\n",
        )

    def test_context_managers_two_context(self):
        # The first context in the list (French) wraps the second (English):
        # French welcome, English welcome, body, English goodbye, French goodbye.
        self.assertEqual(
            self._capture_stdout_inside([context_fr(), context_en()]),
            "Bonjour!\nWelcome!\nTransformers are awesome!\nBye!\nAu revoir!\n",
        )

    @require_torch
    def test_find_labels_pt(self):
        self.assertEqual(find_labels(BertForSequenceClassification), ["labels"])
        self.assertEqual(find_labels(BertForPreTraining), ["labels", "next_sentence_label"])
        self.assertEqual(find_labels(BertForQuestionAnswering), ["start_positions", "end_positions"])

        # find_labels works regardless of the class name (it detects the framework through inheritance)
        class DummyModel(BertForSequenceClassification):
            pass

        self.assertEqual(find_labels(DummyModel), ["labels"])

    @require_tf
    def test_find_labels_tf(self):
        self.assertEqual(find_labels(TFBertForSequenceClassification), ["labels"])
        self.assertEqual(find_labels(TFBertForPreTraining), ["labels", "next_sentence_label"])
        self.assertEqual(find_labels(TFBertForQuestionAnswering), ["start_positions", "end_positions"])

        # find_labels works regardless of the class name (it detects the framework through inheritance)
        class DummyModel(TFBertForSequenceClassification):
            pass

        self.assertEqual(find_labels(DummyModel), ["labels"])

    @require_flax
    def test_find_labels_flax(self):
        # Flax models don't have labels: find_labels is expected to return an empty list
        self.assertEqual(find_labels(FlaxBertForSequenceClassification), [])
        self.assertEqual(find_labels(FlaxBertForPreTraining), [])
        self.assertEqual(find_labels(FlaxBertForQuestionAnswering), [])

        # find_labels works regardless of the class name (it detects the framework through inheritance)
        class DummyModel(FlaxBertForSequenceClassification):
            pass

        self.assertEqual(find_labels(DummyModel), [])
|