File size: 4,810 Bytes
a1d409e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from tempfile import TemporaryDirectory
from unittest import TestCase
from unittest.mock import MagicMock, patch

from transformers import AutoModel, TFAutoModel
from transformers.onnx import FeaturesManager
from transformers.testing_utils import SMALL_MODEL_IDENTIFIER, require_tf, require_torch


@require_torch
@require_tf
class DetermineFrameworkTest(TestCase):
    """
    Test `FeaturesManager.determine_framework`.

    The tests below cover three ways the framework can be resolved: an explicitly provided framework, a local
    checkpoint directory, and (when neither is given) the frameworks available in the environment.
    """

    def setUp(self):
        self.test_model = SMALL_MODEL_IDENTIFIER
        self.framework_pt = "pt"
        self.framework_tf = "tf"

    def _setup_pt_ckpt(self, save_dir):
        """Download `self.test_model` as a PyTorch model and save it into `save_dir`."""
        model_pt = AutoModel.from_pretrained(self.test_model)
        model_pt.save_pretrained(save_dir)

    def _setup_tf_ckpt(self, save_dir):
        """Load `self.test_model` as a TensorFlow model (converted from PyTorch weights) and save it into `save_dir`."""
        model_tf = TFAutoModel.from_pretrained(self.test_model, from_pt=True)
        model_tf.save_pretrained(save_dir)

    def test_framework_provided(self):
        """
        Ensure that the provided framework is returned.

        An explicitly passed framework string is returned as-is, whether the model is a hub identifier or a local
        PyTorch / TensorFlow checkpoint directory.
        """
        mock_framework = "mock_framework"

        # Framework provided - return whatever the user provides
        result = FeaturesManager.determine_framework(self.test_model, mock_framework)
        self.assertEqual(result, mock_framework)

        # Local checkpoint and framework provided - return provided framework
        # PyTorch checkpoint
        with TemporaryDirectory() as local_pt_ckpt:
            self._setup_pt_ckpt(local_pt_ckpt)
            result = FeaturesManager.determine_framework(local_pt_ckpt, mock_framework)
            self.assertEqual(result, mock_framework)

        # TensorFlow checkpoint
        with TemporaryDirectory() as local_tf_ckpt:
            self._setup_tf_ckpt(local_tf_ckpt)
            result = FeaturesManager.determine_framework(local_tf_ckpt, mock_framework)
            self.assertEqual(result, mock_framework)

    def test_checkpoint_provided(self):
        """
        Ensure that the determined framework is the one used for the local checkpoint.

        For the functionality to execute, local checkpoints are provided but framework is not. A directory that holds
        no checkpoint must raise `FileNotFoundError`.
        """
        # PyTorch checkpoint
        with TemporaryDirectory() as local_pt_ckpt:
            self._setup_pt_ckpt(local_pt_ckpt)
            result = FeaturesManager.determine_framework(local_pt_ckpt)
            self.assertEqual(result, self.framework_pt)

        # TensorFlow checkpoint
        with TemporaryDirectory() as local_tf_ckpt:
            self._setup_tf_ckpt(local_tf_ckpt)
            result = FeaturesManager.determine_framework(local_tf_ckpt)
            self.assertEqual(result, self.framework_tf)

        # Invalid local checkpoint: an empty directory contains no weights to inspect
        with TemporaryDirectory() as local_invalid_ckpt:
            with self.assertRaises(FileNotFoundError):
                FeaturesManager.determine_framework(local_invalid_ckpt)

    def test_from_environment(self):
        """
        Ensure that the determined framework is the one available in the environment.

        For the functionality to execute, framework and local checkpoints are not provided. Availability is faked by
        patching the `is_tf_available` / `is_torch_available` names looked up in `transformers.onnx.features`.
        PyTorch is expected to win when both are available, and `EnvironmentError` when neither is.
        """
        # Framework not provided, hub model is used (no local checkpoint directory)
        # TensorFlow not in environment -> use PyTorch
        # (PyTorch itself is really installed, guaranteed by @require_torch)
        mock_tf_available = MagicMock(return_value=False)
        with patch("transformers.onnx.features.is_tf_available", mock_tf_available):
            result = FeaturesManager.determine_framework(self.test_model)
            self.assertEqual(result, self.framework_pt)

        # PyTorch not in environment -> use TensorFlow
        # (TensorFlow itself is really installed, guaranteed by @require_tf)
        mock_torch_available = MagicMock(return_value=False)
        with patch("transformers.onnx.features.is_torch_available", mock_torch_available):
            result = FeaturesManager.determine_framework(self.test_model)
            self.assertEqual(result, self.framework_tf)

        # Both in environment -> use PyTorch
        mock_tf_available = MagicMock(return_value=True)
        mock_torch_available = MagicMock(return_value=True)
        with patch("transformers.onnx.features.is_tf_available", mock_tf_available), patch(
            "transformers.onnx.features.is_torch_available", mock_torch_available
        ):
            result = FeaturesManager.determine_framework(self.test_model)
            self.assertEqual(result, self.framework_pt)

        # Both not in environment -> raise error
        mock_tf_available = MagicMock(return_value=False)
        mock_torch_available = MagicMock(return_value=False)
        with patch("transformers.onnx.features.is_tf_available", mock_tf_available), patch(
            "transformers.onnx.features.is_torch_available", mock_torch_available
        ):
            with self.assertRaises(EnvironmentError):
                FeaturesManager.determine_framework(self.test_model)