|
import argparse |
|
import json |
|
import os |
|
import shutil |
|
from collections import defaultdict |
|
from datetime import datetime |
|
from tempfile import TemporaryDirectory |
|
from typing import Dict, List, Optional, Set, Tuple |
|
|
|
import torch |
|
|
|
from huggingface_hub import HfApi, Repository, hf_hub_download |
|
from huggingface_hub.file_download import repo_folder_name |
|
from safetensors.torch import _find_shared_tensors, _is_complete, load_file, save_file |
|
|
|
REPORT_DESCRIPTION = """ |
|
This is an automated report created with a custom conversion tool.

This new file is equivalent to `pytorch_model.bin` but is safe in the sense that
no arbitrary code can be injected into it.

These files also load much faster than their PyTorch counterpart:
https://colab.research.google.com/github/huggingface/notebooks/blob/main/safetensors_doc/en/speed.ipynb

The widgets on your model page will run using this model, ensuring that the file actually works.

If you run into any issue, please report it at the following link: https://huggingface.co/spaces/safetensors/convert/discussions

Feel free to ignore this report.
|
""" |
|
|
|
ConversionResult = Tuple[List[str], List[Tuple[str, Exception]]]
|
|
|
def _remove_duplicate_names(state_dict: Dict[str, torch.Tensor], *, preferred_names: Optional[List[str]] = None, discard_names: Optional[List[str]] = None) -> Dict[str, List[str]]:
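    """Resolve shared/duplicated tensors in a state dict.

    safetensors refuses to serialize aliased storage, so for each group of
    tensors sharing memory we pick a single name to keep (preferring complete,
    non-discarded and explicitly preferred names, in that order) and return a
    mapping of kept name -> duplicate names to drop.
    """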
|
if preferred_names is None: |
|
preferred_names = [] |
|
preferred_names = set(preferred_names) |
|
if discard_names is None: |
|
discard_names = [] |
|
discard_names = set(discard_names) |
|
shareds = _find_shared_tensors(state_dict) |
|
to_remove = defaultdict(list) |
|
for shared in shareds: |
|
        complete_names = {name for name in shared if _is_complete(state_dict[name])}
|
if not complete_names: |
|
if len(shared) == 1: |
|
name = list(shared)[0] |
|
state_dict[name] = state_dict[name].clone() |
|
complete_names = {name} |
|
else: |
|
raise RuntimeError(f"Error al intentar encontrar nombres para remover al guardar el state dict, pero no se encontró un nombre adecuado para mantener entre: {shared}. Ninguno cubre todo el almacenamiento. Rechazando guardar/cargar el modelo ya que podrías estar almacenando mucha más memoria de la necesaria. Por favor, refiérete a https://huggingface.co/docs/safetensors/torch_shared_tensors para más información. O abre un issue.") |
|
keep_name = sorted(list(complete_names))[0] |
|
preferred = complete_names.difference(discard_names) |
|
if preferred: |
|
keep_name = sorted(list(preferred))[0] |
|
if preferred_names: |
|
preferred = preferred_names.intersection(complete_names) |
|
if preferred: |
|
keep_name = sorted(list(preferred))[0] |
|
for name in sorted(shared): |
|
if name != keep_name: |
|
to_remove[keep_name].append(name) |
|
return to_remove |
|
|
|
def get_discard_names(model_id: str, revision: Optional[str], folder: str, token: Optional[str]) -> List[str]: |
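    """Return the `_tied_weights_keys` of the model's transformers class.

    Tied weights are safe to drop during conversion because transformers
    re-ties them at load time; an empty list is returned if anything fails
    (e.g. transformers is not installed or the config lists no architecture).
    """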
|
try: |
|
import transformers |
|
config_filename = hf_hub_download(model_id, revision=revision, filename="config.json", token=token, cache_dir=folder) |
|
with open(config_filename, "r") as f: |
|
config = json.load(f) |
|
architecture = config["architectures"][0] |
|
class_ = getattr(transformers, architecture) |
|
discard_names = getattr(class_, "_tied_weights_keys", []) |
|
except Exception: |
|
discard_names = [] |
|
return discard_names |
|
|
|
def check_file_size(sf_filename: str, pt_filename: str): |
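    """Sanity check that the safetensors file is not more than 1% larger than the original PyTorch file."""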
|
sf_size = os.stat(sf_filename).st_size |
|
pt_size = os.stat(pt_filename).st_size |
|
if (sf_size - pt_size) / pt_size > 0.01: |
|
raise RuntimeError(f"La diferencia de tamaño de archivo es mayor al 1%:\n - {sf_filename}: {sf_size} bytes\n - {pt_filename}: {pt_size} bytes") |
|
|
|
def rename(model_id: str, pt_filename: str) -> str: |
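    """Build a flat safetensors filename by prefixing the model id (with `/` replaced by `_`) to the weight file's base name."""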
|
    filename, _ = os.path.splitext(pt_filename)
|
base_name = os.path.basename(filename) |
|
safetensors_name = f"{model_id.replace('/', '_')}_{base_name}.safetensors" |
|
return safetensors_name |
|
|
|
def convert_multi(model_id: str, *, revision: Optional[str], folder: str, token: Optional[str], discard_names: List[str]) -> ConversionResult: |
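    """Convert a sharded checkpoint: every shard referenced by `pytorch_model.bin.index.json`, plus a rewritten index pointing at the renamed `.safetensors` shards."""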
|
filename = hf_hub_download(repo_id=model_id, revision=revision, filename="pytorch_model.bin.index.json", token=token, cache_dir=folder) |
|
with open(filename, "r") as f: |
|
data = json.load(f) |
|
filenames = set(data["weight_map"].values()) |
|
local_filenames = [] |
|
errors = [] |
|
for filename in filenames: |
|
try: |
|
            pt_filename = hf_hub_download(repo_id=model_id, revision=revision, filename=filename, token=token, cache_dir=folder)
|
sf_filename = rename(model_id, filename) |
|
sf_filepath = os.path.join(folder, sf_filename) |
|
convert_file(pt_filename, sf_filepath, discard_names=discard_names) |
|
local_filenames.append(sf_filepath) |
|
except Exception as e: |
|
errors.append((filename, e)) |
|
index = os.path.join(folder, f"{model_id.replace('/', '_')}_model.safetensors.index.json") |
|
try: |
|
with open(index, "w") as f: |
|
newdata = {k: v for k, v in data.items()} |
|
newmap = {k: rename(model_id, v) for k, v in data["weight_map"].items()} |
|
newdata["weight_map"] = newmap |
|
json.dump(newdata, f, indent=4) |
|
local_filenames.append(index) |
|
except Exception as e: |
|
errors.append((index, e)) |
|
return local_filenames, errors |
|
|
|
def convert_single(model_id: str, *, revision: Optional[str], folder: str, token: Optional[str], discard_names: List[str], filename: str = "pytorch_model.bin") -> ConversionResult:
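    """Convert a single-file checkpoint; `filename` defaults to `pytorch_model.bin`."""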
|
try: |
|
        pt_filename = hf_hub_download(repo_id=model_id, revision=revision, filename=filename, token=token, cache_dir=folder)

        sf_name = rename(model_id, filename)
|
sf_filepath = os.path.join(folder, sf_name) |
|
convert_file(pt_filename, sf_filepath, discard_names) |
|
local_filenames = [sf_filepath] |
|
errors = [] |
|
except Exception as e: |
|
local_filenames = [] |
|
        errors = [(filename, e)]
|
return local_filenames, errors |
|
|
|
def convert_file(pt_filename: str, sf_filename: str, discard_names: List[str]): |
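    """Convert one PyTorch checkpoint file to safetensors.

    Loads the state dict on CPU, drops duplicated shared tensors (recording
    the kept name of each dropped tensor in the file metadata), saves the
    result, then verifies the file size and reloads the file to check that
    every tensor round-trips exactly.
    """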
|
loaded = torch.load(pt_filename, map_location="cpu", weights_only=True) |
|
if "state_dict" in loaded: |
|
loaded = loaded["state_dict"] |
|
to_removes = _remove_duplicate_names(loaded, discard_names=discard_names) |
|
metadata = {"format": "pt"} |
|
for kept_name, to_remove_group in to_removes.items(): |
|
for to_remove in to_remove_group: |
|
if to_remove not in metadata: |
|
metadata[to_remove] = kept_name |
|
del loaded[to_remove] |
|
loaded = {k: v.contiguous() for k, v in loaded.items()} |
|
dirname = os.path.dirname(sf_filename) |
|
os.makedirs(dirname, exist_ok=True) |
|
save_file(loaded, sf_filename, metadata=metadata) |
|
check_file_size(sf_filename, pt_filename) |
|
reloaded = load_file(sf_filename) |
|
for k in loaded: |
|
pt_tensor = loaded[k] |
|
sf_tensor = reloaded[k] |
|
if not torch.equal(pt_tensor, sf_tensor): |
|
raise RuntimeError(f"Los tensores de salida no coinciden para la clave {k}") |
|
|
|
def convert_generic(model_id: str, *, revision: Optional[str], folder: str, filenames: Set[str], token: Optional[str]) -> ConversionResult: |
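    """Convert every `.bin`/`.ckpt`/`.pth` file in the repo; used for models that are not transformers checkpoints."""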
|
local_filenames = [] |
|
errors = [] |
|
    extensions = {".bin", ".ckpt", ".pth"}
|
for filename in filenames: |
|
prefix, ext = os.path.splitext(filename) |
|
if ext in extensions: |
|
try: |
|
pt_filename = hf_hub_download(model_id, revision=revision, filename=filename, token=token, cache_dir=folder) |
|
dirname, raw_filename = os.path.split(filename) |
|
if raw_filename in {"pytorch_model.bin", "pytorch_model.pth"}: |
|
sf_in_repo = rename(model_id, raw_filename) |
|
else: |
|
sf_in_repo = rename(model_id, filename) |
|
sf_filepath = os.path.join(folder, sf_in_repo) |
|
convert_file(pt_filename, sf_filepath, discard_names=[]) |
|
local_filenames.append(sf_filepath) |
|
except Exception as e: |
|
errors.append((filename, e)) |
|
return local_filenames, errors |
|
|
|
def prepare_target_repo_files(model_id: str, revision: Optional[str], folder: str, token: str, repo_dir: str): |
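    """Copy common auxiliary files (config, tokenizer, license, ...) from the source repo into `repo_dir`, writing minimal defaults for any file that cannot be downloaded."""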
|
|
try: |
|
common_files = [ |
|
".gitattributes", |
|
"LICENSE.txt", |
|
"README.md", |
|
"USE_POLICY.md", |
|
"config.json", |
|
"generation_config.json", |
|
"special_tokens_map.json", |
|
"tokenizer.json", |
|
"tokenizer_config.json" |
|
] |
|
for file in common_files: |
|
try: |
|
file_path = hf_hub_download(repo_id=model_id, revision=revision, filename=file, token=token, cache_dir=folder) |
|
shutil.copy(file_path, repo_dir) |
|
except Exception: |
|
if file == ".gitattributes": |
|
                    gitattributes_content = "*.safetensors filter=lfs diff=lfs merge=lfs -text\n"
|
with open(os.path.join(repo_dir, file), "w") as f: |
|
f.write(gitattributes_content) |
|
elif file == "LICENSE.txt": |
|
default_license = "MIT License\n\nCopyright (c) 2024" |
|
with open(os.path.join(repo_dir, file), "w") as f: |
|
f.write(default_license) |
|
elif file == "README.md": |
|
                    readme_content = f"# {model_id.replace('/', ' ').title()}\n\nModel converted to safetensors."
|
with open(os.path.join(repo_dir, file), "w") as f: |
|
f.write(readme_content) |
|
elif file == "USE_POLICY.md": |
|
                    use_policy_content = "### Use Policy\n\nThis model is distributed under standard terms of use."
|
with open(os.path.join(repo_dir, file), "w") as f: |
|
f.write(use_policy_content) |
|
elif file in {"config.json", "generation_config.json", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json"}: |
|
default_json_content = {} |
|
with open(os.path.join(repo_dir, file), "w") as f: |
|
json.dump(default_json_content, f, indent=4) |
|
    except Exception:

        raise
|
|
|
def generate_report(model_id: str, local_filenames: List[str], errors: List[Tuple[str, Exception]], output_md_path: str, output_json_path: Optional[str] = None):
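    """Write the conversion report as Markdown and as JSON.

    If `output_json_path` is not given, the JSON report is written next to
    the Markdown one with a `_report.json` suffix.
    """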
|
report_lines = [ |
|
f"# Reporte de Conversión para el Modelo `{model_id}`", |
|
f"Fecha y Hora: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", |
|
"", |
|
"## Archivos Convertidos Exitosamente", |
|
] |
|
if local_filenames: |
|
for filename in local_filenames: |
|
report_lines.append(f"- `{os.path.basename(filename)}`") |
|
else: |
|
report_lines.append("No se convirtieron archivos.") |
|
report_lines.append("") |
|
report_lines.append("## Errores Durante la Conversión") |
|
if errors: |
|
for filename, error in errors: |
|
report_lines.append(f"- **Archivo**: `{os.path.basename(filename)}`\n - **Error**: {error}") |
|
else: |
|
report_lines.append("No hubo errores durante la conversión.") |
|
report_content_md = "\n".join(report_lines) |
|
with open(output_md_path, "w") as f: |
|
f.write(report_content_md) |
|
report_json = { |
|
"model_id": model_id, |
|
"timestamp": datetime.now().strftime('%Y-%m-%d %H:%M:%S'), |
|
"converted_files": [os.path.basename(f) for f in local_filenames], |
|
"errors": [{"file": os.path.basename(f), "error": str(e)} for f, e in errors], |
|
"description": REPORT_DESCRIPTION.strip() |
|
} |
|
    json_output_path = output_json_path or (os.path.splitext(output_md_path)[0] + "_report.json")
|
with open(json_output_path, "w") as f: |
|
json.dump(report_json, f, indent=4) |
|
print(f"Reportes generados en: {output_md_path} y {json_output_path}") |
|
|
|
def convert(model_id: str, revision: Optional[str] = None, force: bool = False, token: Optional[str] = None, folder: Optional[str] = None) -> ConversionResult:

    """Convert all PyTorch weight files of `model_id` into `folder`.

    The converted files must outlive this call so that the caller can upload
    them; if no folder is given, a working folder is created in the current
    directory instead of a temporary one that would be deleted on return.
    """

    api = HfApi()

    info = api.model_info(repo_id=model_id, revision=revision)

    filenames = {s.rfilename for s in info.siblings}

    if folder is None:

        # Fall back to a persistent local working folder so the converted files survive after this call.
        folder = os.path.join(os.getcwd(), repo_folder_name(repo_id=model_id, repo_type="models"))

    os.makedirs(folder, exist_ok=True)

    local_filenames = []

    errors = []

    if not force and any(filename.endswith(".safetensors") for filename in filenames):

        print(f"Model `{model_id}` already has converted `.safetensors` files. Skipping; rerun with --force to convert anyway.")

    else:

        library_name = getattr(info, "library_name", None)

        if library_name == "transformers":

            discard_names = get_discard_names(model_id, revision=revision, folder=folder, token=token)

            if "pytorch_model.bin" in filenames or "pytorch_model.pth" in filenames:

                weights_file = "pytorch_model.bin" if "pytorch_model.bin" in filenames else "pytorch_model.pth"

                converted, conv_errors = convert_single(model_id, revision=revision, folder=folder, token=token, discard_names=discard_names, filename=weights_file)

                local_filenames.extend(converted)

                errors.extend(conv_errors)

            elif "pytorch_model.bin.index.json" in filenames:

                converted, conv_errors = convert_multi(model_id, revision=revision, folder=folder, token=token, discard_names=discard_names)

                local_filenames.extend(converted)

                errors.extend(conv_errors)

            else:

                print(f"Model `{model_id}` does not appear to contain valid PyTorch weights. Nothing to convert.")

        else:

            converted, conv_errors = convert_generic(model_id, revision=revision, folder=folder, filenames=filenames, token=token)

            local_filenames.extend(converted)

            errors.extend(conv_errors)

    return local_filenames, errors
|
|
|
def read_token(token_file: Optional[str]) -> Optional[str]: |
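    """Read the Hugging Face token from `token_file`, falling back to the `HF_TOKEN` environment variable."""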
|
if token_file: |
|
if os.path.isfile(token_file): |
|
with open(token_file, "r") as f: |
|
token = f.read().strip() |
|
return token |
|
else: |
|
print(f"El archivo de token especificado no existe: {token_file}") |
|
return None |
|
else: |
|
return os.getenv("HF_TOKEN") |
|
|
|
def create_target_repo(model_id: str, api: HfApi, token: str) -> str: |
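    """Create (or reuse) `<username>/<flattened model_id>_safetensors` under the authenticated user's namespace and return its repo id."""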
|
target_repo_id = f"{api.whoami(token=token)['name']}/{model_id.replace('/', '_')}_safetensors" |
|
try: |
|
        api.create_repo(repo_id=target_repo_id, repo_type="model", exist_ok=True, token=token)
|
print(f"Repositorio creado o ya existente: {target_repo_id}") |
|
except Exception as e: |
|
print(f"Error al crear el repositorio `{target_repo_id}`: {e}") |
|
        raise
|
return target_repo_id |
|
|
|
def upload_to_hf(local_filenames: List[str], target_repo_id: str, token: str, additional_files: List[str]): |
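    """Clone the target repo, copy the converted and additional files into it, then commit and push (auto-tracking large files with LFS)."""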
|
repo_dir = "./temp_repo" |
|
if os.path.exists(repo_dir): |
|
shutil.rmtree(repo_dir) |
|
os.makedirs(repo_dir, exist_ok=True) |
|
try: |
|
repo = Repository(local_dir=repo_dir, clone_from=target_repo_id, use_auth_token=token) |
|
for file_path in local_filenames: |
|
shutil.copy(file_path, repo_dir) |
|
for file_path in additional_files: |
|
shutil.copy(file_path, repo_dir) |
|
repo.git_add(auto_lfs_track=True) |
|
        repo.git_commit("Add converted safetensors files")
|
repo.git_push() |
|
print(f"Archivos subidos exitosamente al repositorio: {target_repo_id}") |
|
except Exception as e: |
|
print(f"Error al subir archivos al repositorio `{target_repo_id}`: {e}") |
|
        raise
|
finally: |
|
shutil.rmtree(repo_dir) |
|
|
|
def main(): |
|
DESCRIPTION = """ |
|
Herramienta de utilidad simple para convertir automáticamente algunos pesos en el hub al formato `safetensors`. |
|
Actualmente exclusiva para PyTorch. |
|
Funciona descargando los pesos (PT), convirtiéndolos localmente, subiéndolos a tu propio perfil en Hugging Face Hub y generando reportes en formato Markdown y JSON. |
|
""" |
|
parser = argparse.ArgumentParser(description=DESCRIPTION) |
|
parser.add_argument( |
|
"model_id", |
|
type=str, |
|
help="El nombre del modelo en el hub para convertir. Por ejemplo, `gpt2` o `facebook/wav2vec2-base-960h`", |
|
) |
|
parser.add_argument( |
|
"--revision", |
|
type=str, |
|
help="La revisión a convertir", |
|
) |
|
parser.add_argument( |
|
"--force", |
|
action="store_true", |
|
help="Forzar la conversión incluso si ya existen archivos `.safetensors` en el modelo.", |
|
) |
|
parser.add_argument( |
|
"-y", |
|
action="store_true", |
|
help="Ignorar prompt de seguridad", |
|
) |
|
parser.add_argument( |
|
"--output", |
|
type=str, |
|
default="conversion_report.md", |
|
help="Ruta donde se guardará el reporte de conversión en formato Markdown.", |
|
) |
|
parser.add_argument( |
|
"--output-json", |
|
type=str, |
|
default=None, |
|
help="Ruta donde se guardará el reporte de conversión en formato JSON. Si no se especifica, se creará en la misma ubicación que el reporte Markdown.", |
|
) |
|
parser.add_argument( |
|
"--token-file", |
|
type=str, |
|
default=None, |
|
help="Ruta al archivo que contiene el token de autenticación de Hugging Face. Si no se especifica, se intentará leer desde la variable de entorno 'HF_TOKEN'.", |
|
) |
|
args = parser.parse_args() |
|
model_id = args.model_id |
|
token = read_token(args.token_file) |
|
if not token: |
|
print("No se proporcionó un token de autenticación válido. Por favor, proporciónalo mediante --token-file o establece la variable de entorno 'HF_TOKEN'.") |
|
return |
|
api = HfApi() |
|
try: |
|
user_info = api.whoami(token=token) |
|
print(f"Autenticado como: {user_info['name']}") |
|
except Exception as e: |
|
print(f"No se pudo autenticar con Hugging Face Hub: {e}") |
|
return |
|
if args.y: |
|
proceed = True |
|
else: |
|
txt = input( |
|
"Este script de conversión desenpaca un archivo pickled, lo cual es inherentemente inseguro. Si no confías en este archivo, te invitamos a usar " |
|
"https://huggingface.co/spaces/safetensors/convert o Google Colab u otra solución alojada para evitar posibles problemas con este archivo." |
|
" ¿Continuar [Y/n] ? " |
|
) |
|
proceed = txt.lower() in {"", "y", "yes"} |
|
if proceed: |
|
try: |
|
with TemporaryDirectory() as d: |
|
folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models")) |
|
os.makedirs(folder, exist_ok=True) |
|
                local_filenames, errors = convert(model_id, revision=args.revision, force=args.force, token=token, folder=folder)
|
target_repo_id = create_target_repo(model_id, api, token) |
|
with TemporaryDirectory() as repo_temp_dir: |
|
prepare_target_repo_files(model_id, args.revision, folder, token, repo_temp_dir) |
|
additional_files = [os.path.join(repo_temp_dir, f) for f in os.listdir(repo_temp_dir)] |
|
if local_filenames or additional_files: |
|
upload_to_hf(local_filenames, target_repo_id, token, additional_files) |
|
print(f"Archivos convertidos y adicionales subidos exitosamente a: {target_repo_id}") |
|
else: |
|
print("No hay archivos convertidos ni adicionales para subir.") |
|
                output_md = args.output

                output_json = args.output_json or (os.path.splitext(output_md)[0] + "_report.json")

                generate_report(model_id, local_filenames, errors, output_md, output_json)
|
except Exception as e: |
|
print(f"Ocurrió un error inesperado: {e}") |
|
else: |
|
print(f"La respuesta fue '{txt}', abortando.") |
|
|
|
if __name__ == "__main__": |
|
main() |
|
|