Spaces:
Runtime error
Runtime error
# Copyright 2020 The HuggingFace Team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import json | |
import os | |
import shutil | |
import warnings | |
from argparse import ArgumentParser, Namespace | |
from pathlib import Path | |
from typing import List | |
from ..utils import logging | |
from . import BaseTransformersCLICommand | |
# `cookiecutter` is an optional dependency (the ImportError raised in `AddNewModelCommand.run` tells users to
# install the `modelcreation` extra). Record whether it is importable instead of failing at import time, so
# the rest of the CLI keeps working without it.
try:
    from cookiecutter.main import cookiecutter

    _has_cookiecutter = True
except ImportError:
    _has_cookiecutter = False

# Module-level logger for this command.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
def add_new_model_command_factory(args: Namespace):
    """Build an `AddNewModelCommand` from the parsed `add-new-model` CLI arguments.

    Args:
        args: Namespace produced by argparse; its `testing`, `testing_file` and `path` attributes are
            forwarded to the command (`path` as a keyword argument).

    Returns:
        The configured `AddNewModelCommand` instance.
    """
    testing, testing_file, cookiecutter_path = args.testing, args.testing_file, args.path
    return AddNewModelCommand(testing, testing_file, path=cookiecutter_path)
class AddNewModelCommand(BaseTransformersCLICommand):
    """Deprecated `transformers-cli add-new-model` command.

    Renders the `templates/adding_a_new_model` cookiecutter template into the current working directory. It
    moves the generated configuration, modeling, test, tokenizer and documentation files into the repository
    tree. It then splices the snippets described by the generated `to_replace_<model>.py` file into existing
    files, and finally removes the (now empty) generated directory.
    """

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Register the `add-new-model` subcommand and its arguments.

        Args:
            parser: Object exposing `add_parser` (presumably the action returned by
                `ArgumentParser.add_subparsers()` in the CLI entry point — confirm against the caller).
        """
        add_new_model_parser = parser.add_parser("add-new-model")
        add_new_model_parser.add_argument("--testing", action="store_true", help="If in testing mode.")
        add_new_model_parser.add_argument("--testing_file", type=str, help="Configuration file on which to run.")
        add_new_model_parser.add_argument(
            "--path", type=str, help="Path to cookiecutter. Should only be used for testing purposes."
        )
        add_new_model_parser.set_defaults(func=add_new_model_command_factory)

    def __init__(self, testing: bool, testing_file: str, path=None, *args):
        """
        Args:
            testing: If True, run cookiecutter non-interactively with the JSON context read from `testing_file`,
                and keep `# Copied from transformers.` lines in the generated modeling files.
            testing_file: Path to a JSON file used as cookiecutter `extra_context` in testing mode.
            path: Optional path to the cookiecutter template. When given, the repository root is taken to be
                its grandparent directory.
            *args: Accepted and ignored.
        """
        self._testing = testing
        self._testing_file = testing_file
        self._path = path

    def run(self):
        """Generate the new model's files and integrate them into the repository.

        Raises:
            ImportError: If `cookiecutter` is not installed.
            ValueError: If a `cookiecutter-template-*` directory already exists in the working directory, or if
                an anchor line from the `to_replace_*` file cannot be found in its target file.
            FileNotFoundError: If cookiecutter did not create a `cookiecutter-template-*` directory.
        """
        warnings.warn(
            "The command `transformers-cli add-new-model` is deprecated and will be removed in v5 of Transformers. "
            "It is not actively maintained anymore, so might give a result that won't pass all tests and quality "
            "checks, you should use `transformers-cli add-new-model-like` instead."
        )
        if not _has_cookiecutter:
            raise ImportError(
                "Model creation dependencies are required to use the `add_new_model` command. Install them by running "
                "the following at the root of your `transformers` clone:\n\n\t$ pip install -e .[modelcreation]\n"
            )
        # Ensure that there is no other `cookiecutter-template-xxx` directory in the current working directory;
        # the generated directory is located afterwards by this same prefix.
        directories = [directory for directory in os.listdir() if directory.startswith("cookiecutter-template-")]
        if directories:
            raise ValueError(
                "Several directories starting with `cookiecutter-template-` in current working directory. "
                "Please clean your directory by removing all folders starting with `cookiecutter-template-` or "
                "change your working directory."
            )

        path_to_transformer_root = (
            Path(__file__).parent.parent.parent.parent if self._path is None else Path(self._path).parent.parent
        )
        path_to_cookiecutter = path_to_transformer_root / "templates" / "adding_a_new_model"

        # Execute cookiecutter
        if not self._testing:
            cookiecutter(str(path_to_cookiecutter))
        else:
            with open(self._testing_file, "r") as configuration_file:
                testing_configuration = json.load(configuration_file)

            cookiecutter(
                str(path_to_cookiecutter if self._path is None else self._path),
                no_input=True,
                extra_context=testing_configuration,
            )

        generated = [directory for directory in os.listdir() if directory.startswith("cookiecutter-template-")]
        if not generated:
            # Previously an opaque IndexError from `[0]`.
            raise FileNotFoundError(
                "cookiecutter did not create a `cookiecutter-template-` directory in the current working directory."
            )
        directory = generated[0]

        # Retrieve configuration
        with open(directory + "/configuration.json", "r") as configuration_file:
            configuration = json.load(configuration_file)

        lowercase_model_name = configuration["lowercase_modelname"]
        generate_tensorflow_pytorch_and_flax = configuration["generate_tensorflow_pytorch_and_flax"]
        os.remove(f"{directory}/configuration.json")

        output_pytorch = "PyTorch" in generate_tensorflow_pytorch_and_flax
        output_tensorflow = "TensorFlow" in generate_tensorflow_pytorch_and_flax
        output_flax = "Flax" in generate_tensorflow_pytorch_and_flax

        model_dir = f"{path_to_transformer_root}/src/transformers/models/{lowercase_model_name}"
        tests_dir = f"{path_to_transformer_root}/tests/models/{lowercase_model_name}"
        os.makedirs(model_dir, exist_ok=True)
        os.makedirs(tests_dir, exist_ok=True)

        # Tests require submodules as they have parent imports
        with open(f"{tests_dir}/__init__.py", "w"):
            pass

        shutil.move(
            f"{directory}/__init__.py",
            f"{model_dir}/__init__.py",
        )
        shutil.move(
            f"{directory}/configuration_{lowercase_model_name}.py",
            f"{model_dir}/configuration_{lowercase_model_name}.py",
        )

        def remove_copy_lines(path):
            """Rewrite the file at `path` in place without any line containing `# Copied from transformers.`."""
            with open(path, "r") as f:
                lines = f.readlines()
            with open(path, "w") as f:
                for line in lines:
                    if "# Copied from transformers." not in line:
                        f.write(line)

        # For each framework (PyTorch, TensorFlow, Flax, in that order): move the generated modeling file and its
        # test into the repository if that framework was requested, otherwise delete both generated files.
        for output_framework, prefix in (
            (output_pytorch, "modeling_"),
            (output_tensorflow, "modeling_tf_"),
            (output_flax, "modeling_flax_"),
        ):
            modeling_file = f"{prefix}{lowercase_model_name}.py"
            test_file = f"test_{modeling_file}"
            if output_framework:
                if not self._testing:
                    remove_copy_lines(f"{directory}/{modeling_file}")
                shutil.move(f"{directory}/{modeling_file}", f"{model_dir}/{modeling_file}")
                shutil.move(f"{directory}/{test_file}", f"{tests_dir}/{test_file}")
            else:
                os.remove(f"{directory}/{modeling_file}")
                os.remove(f"{directory}/{test_file}")

        shutil.move(
            f"{directory}/{lowercase_model_name}.md",
            f"{path_to_transformer_root}/docs/source/en/model_doc/{lowercase_model_name}.md",
        )

        shutil.move(
            f"{directory}/tokenization_{lowercase_model_name}.py",
            f"{model_dir}/tokenization_{lowercase_model_name}.py",
        )

        # The template names the fast tokenizer `tokenization_fast_<model>.py`; the library expects
        # `tokenization_<model>_fast.py`.
        shutil.move(
            f"{directory}/tokenization_fast_{lowercase_model_name}.py",
            f"{model_dir}/tokenization_{lowercase_model_name}_fast.py",
        )

        from os import fdopen, remove
        from shutil import copymode, move
        from tempfile import mkstemp

        def replace(original_file: str, line_to_copy_below: str, lines_to_copy: List[str]):
            """Insert `lines_to_copy` after every line of `original_file` that contains `line_to_copy_below`.

            The new content is written to a temporary file that receives the original's permission bits and then
            replaces the original.

            Raises:
                ValueError: If no line contains `line_to_copy_below`; `original_file` is left untouched.
            """
            # Create temp file
            fh, abs_path = mkstemp()
            line_found = False
            try:
                with fdopen(fh, "w") as new_file:
                    with open(original_file) as old_file:
                        for line in old_file:
                            new_file.write(line)
                            if line_to_copy_below in line:
                                line_found = True
                                for line_to_copy in lines_to_copy:
                                    new_file.write(line_to_copy)

                if not line_found:
                    raise ValueError(f"Line {line_to_copy_below} was not found in file.")
            except BaseException:
                # `mkstemp` leaves deletion to the caller: don't leak the temp file on failure.
                remove(abs_path)
                raise

            # Copy the file permissions from the old file to the new file
            copymode(original_file, abs_path)
            # Remove original file
            remove(original_file)
            # Move new file
            move(abs_path, original_file)

        def skip_units(line):
            """Return True if `line` is tagged `generating <framework>` for a framework that is not being output."""
            return (
                ("generating PyTorch" in line and not output_pytorch)
                or ("generating TensorFlow" in line and not output_tensorflow)
                or ("generating Flax" in line and not output_flax)
            )

        def replace_in_files(path_to_datafile):
            """Apply every snippet described in `path_to_datafile`, then delete that file.

            Lines containing `##` are ignored. Otherwise:
              - `# To replace in: "<file>"` selects the target file (first double-quoted string);
              - `# Below: "<text>"` selects the anchor text to insert after;
              - `# Replace with` starts a new, empty snippet;
              - `# End.` inserts the collected snippet (unless the file or anchor marker is tagged for a
                framework not being generated) and resets it;
              - any other line is appended to the current snippet.
            """
            with open(path_to_datafile) as datafile:
                lines_to_copy = []
                skip_file = False
                skip_snippet = False
                for line in datafile:
                    if "# To replace in: " in line and "##" not in line:
                        file_to_replace_in = line.split('"')[1]
                        skip_file = skip_units(line)
                    elif "# Below: " in line and "##" not in line:
                        line_to_copy_below = line.split('"')[1]
                        skip_snippet = skip_units(line)
                    elif "# End." in line and "##" not in line:
                        if not skip_file and not skip_snippet:
                            replace(file_to_replace_in, line_to_copy_below, lines_to_copy)

                        lines_to_copy = []
                    elif "# Replace with" in line and "##" not in line:
                        lines_to_copy = []
                    elif "##" not in line:
                        lines_to_copy.append(line)

            # Delete only after the handle is closed: removing a file that is still open fails on Windows.
            remove(path_to_datafile)

        replace_in_files(f"{directory}/to_replace_{lowercase_model_name}.py")
        os.rmdir(directory)