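# NOTE(review): the next three lines are not Python. They look like header residue from a
# repository web view (avatar caption, commit message, short commit hash) - confirm and remove.
# chendl's picture
# add requirements
# a1d409e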
from tempfile import TemporaryDirectory
from unittest import TestCase
from unittest.mock import MagicMock, patch
from transformers import AutoModel, TFAutoModel
from transformers.onnx import FeaturesManager
from transformers.testing_utils import SMALL_MODEL_IDENTIFIER, require_tf, require_torch
@require_torch
@require_tf
class DetermineFrameworkTest(TestCase):
    """
    Test `FeaturesManager.determine_framework`.

    Three situations are exercised: the caller names the framework explicitly, the framework has to be
    worked out from a local checkpoint directory, and the framework has to be chosen from what the
    (mocked) environment reports as available.
    """

    def setUp(self):
        """Store the model identifier and the framework codes the assertions compare against."""
        self.test_model = SMALL_MODEL_IDENTIFIER
        self.framework_pt = "pt"
        self.framework_tf = "tf"

    def _setup_pt_ckpt(self, save_dir):
        """Load the test model with `AutoModel` and save it into `save_dir`."""
        model_pt = AutoModel.from_pretrained(self.test_model)
        model_pt.save_pretrained(save_dir)

    def _setup_tf_ckpt(self, save_dir):
        """Load the test model with `TFAutoModel` (passing `from_pt=True`) and save it into `save_dir`."""
        model_tf = TFAutoModel.from_pretrained(self.test_model, from_pt=True)
        model_tf.save_pretrained(save_dir)

    def test_framework_provided(self):
        """
        Ensure that the provided framework is returned.
        """
        # Deliberately not a real framework code: the test expects the value to be echoed back as-is.
        mock_framework = "mock_framework"

        # Framework provided - return whatever the user provides
        result = FeaturesManager.determine_framework(self.test_model, mock_framework)
        self.assertEqual(result, mock_framework)

        # Local checkpoint and framework provided - return provided framework
        # PyTorch checkpoint
        with TemporaryDirectory() as local_pt_ckpt:
            self._setup_pt_ckpt(local_pt_ckpt)
            result = FeaturesManager.determine_framework(local_pt_ckpt, mock_framework)
            self.assertEqual(result, mock_framework)

        # TensorFlow checkpoint
        with TemporaryDirectory() as local_tf_ckpt:
            self._setup_tf_ckpt(local_tf_ckpt)
            result = FeaturesManager.determine_framework(local_tf_ckpt, mock_framework)
            self.assertEqual(result, mock_framework)

    def test_checkpoint_provided(self):
        """
        Ensure that the determined framework is the one used for the local checkpoint.

        For the functionality to execute, local checkpoints are provided but framework is not.
        """
        # PyTorch checkpoint
        with TemporaryDirectory() as local_pt_ckpt:
            self._setup_pt_ckpt(local_pt_ckpt)
            result = FeaturesManager.determine_framework(local_pt_ckpt)
            self.assertEqual(result, self.framework_pt)

        # TensorFlow checkpoint
        with TemporaryDirectory() as local_tf_ckpt:
            self._setup_tf_ckpt(local_tf_ckpt)
            result = FeaturesManager.determine_framework(local_tf_ckpt)
            self.assertEqual(result, self.framework_tf)

        # Invalid local checkpoint: nothing is saved into the directory before the call.
        with TemporaryDirectory() as local_invalid_ckpt:
            with self.assertRaises(FileNotFoundError):
                # The call is expected to raise, so its return value is not bound.
                FeaturesManager.determine_framework(local_invalid_ckpt)

    def test_from_environment(self):
        """
        Ensure that the determined framework is the one available in the environment.

        For the functionality to execute, framework and local checkpoints are not provided.
        """
        # Framework not provided, hub model is used (no local checkpoint directory)
        # TensorFlow not in environment -> use PyTorch
        mock_tf_available = MagicMock(return_value=False)
        with patch("transformers.onnx.features.is_tf_available", mock_tf_available):
            result = FeaturesManager.determine_framework(self.test_model)
            self.assertEqual(result, self.framework_pt)

        # PyTorch not in environment -> use TensorFlow
        mock_torch_available = MagicMock(return_value=False)
        with patch("transformers.onnx.features.is_torch_available", mock_torch_available):
            result = FeaturesManager.determine_framework(self.test_model)
            self.assertEqual(result, self.framework_tf)

        # Both in environment -> use PyTorch
        mock_tf_available = MagicMock(return_value=True)
        mock_torch_available = MagicMock(return_value=True)
        with patch("transformers.onnx.features.is_tf_available", mock_tf_available), patch(
            "transformers.onnx.features.is_torch_available", mock_torch_available
        ):
            result = FeaturesManager.determine_framework(self.test_model)
            self.assertEqual(result, self.framework_pt)

        # Both not in environment -> raise error
        mock_tf_available = MagicMock(return_value=False)
        mock_torch_available = MagicMock(return_value=False)
        with patch("transformers.onnx.features.is_tf_available", mock_tf_available), patch(
            "transformers.onnx.features.is_torch_available", mock_torch_available
        ):
            with self.assertRaises(EnvironmentError):
                # The call is expected to raise, so its return value is not bound.
                FeaturesManager.determine_framework(self.test_model)