import os
import sys
import argparse
import subprocess
from pathlib import Path
from concurrent.futures import (
    ProcessPoolExecutor,
    as_completed,
)
from zipnn_decompress_file import (
    decompress_file,
)

sys.path.append(
    os.path.abspath(
        os.path.join(
            os.path.dirname(__file__),
            "..",
        )
    )
)

# ANSI escape sequences used to colour console messages.
RED = "\033[91m"
YELLOW = "\033[93m"
GREEN = "\033[92m"
RESET = "\033[0m"


def check_and_install_zipnn():
    """Import ``zipnn``, installing it with pip first if the import fails.

    Raises:
        subprocess.CalledProcessError: If the ``pip install`` command fails.
        ImportError: If ``zipnn`` still cannot be imported after installing.
    """
    try:
        import zipnn  # noqa: F401 -- imported only to check availability
    except ImportError:
        print("zipnn not found. Installing...")
        subprocess.check_call(
            [
                sys.executable,
                "-m",
                "pip",
                "install",
                "zipnn",
            ]
        )
        import zipnn  # noqa: F401


def replace_in_file(file_path: Path | str, old: str, new: str) -> None:
    """Replace all occurrences of `old` with `new` in `file_path`, in place.

    The file is read fully into memory and rewritten. It is opened as UTF-8
    in both directions so the result does not depend on the platform's
    default encoding.
    """
    with open(file_path, "r", encoding="utf-8") as file:
        file_data = file.read()
    file_data = file_data.replace(old, new)
    with open(file_path, "w", encoding="utf-8") as file:
        file.write(file_data)


def _confirm(prompt):
    """Show `prompt` on stdin/stdout and return True for a 'y'/'yes' answer."""
    return input(prompt).strip().lower() in ("y", "yes")


def _resolve_hf_snapshot_path(model, branch):
    """Return the snapshot directory of `model` at `branch` in the HF cache.

    Raises:
        ImportError: If ``huggingface_hub`` is not installed.
        FileNotFoundError: If the model is not in the cache, or the branch
            has no ref file inside the cached repo.
    """
    try:
        from huggingface_hub import scan_cache_dir
    except ImportError as err:
        raise ImportError(
            "huggingface_hub not found. Please pip install huggingface_hub."
        ) from err

    cache = scan_cache_dir()
    repo = next((r for r in cache.repos if r.repo_id == model), None)
    if repo is None:
        # Previously this case fell through silently and processed the
        # default `path` instead of the requested model.
        raise FileNotFoundError(
            f"Model {model} not found in the Hugging Face cache"
        )
    print(f"Found repo {model} in cache")

    # The ref file holds the revision hash the branch points to.
    try:
        with open(os.path.join(repo.repo_path, "refs", branch), "r") as ref:
            # Strip so a trailing newline cannot corrupt the snapshot path.
            revision = ref.read().strip()
    except FileNotFoundError as err:
        raise FileNotFoundError(
            f"Branch {branch} not found in repo {model}"
        ) from err
    return os.path.join(repo.repo_path, "snapshots", revision)


def _collect_znn_files(path, force):
    """List the ``.znn`` files directly inside `path` that should be processed.

    Only the top level of `path` is scanned. Unless `force` is set, the user
    is asked before a file whose decompressed counterpart already exists in
    `path` is included: once for "overwrite all", then per file if declined.
    """
    file_list = []
    overwrite_first = True
    for file_name in os.listdir(path):
        if not file_name.endswith(".znn"):
            continue
        decompressed_name = file_name[:-4]
        # Look for the existing output next to the compressed file, not in
        # the current working directory.
        if not force and os.path.exists(os.path.join(path, decompressed_name)):
            if overwrite_first:
                overwrite_first = False
                if _confirm(
                    "Decompressed files already exists; Would you like to overwrite them all (y/n)? "
                ):
                    print("Overwriting all decompressed files.")
                    force = True
                else:
                    print("No forced overwriting.")
            if not force and not _confirm(
                f"{decompressed_name} already exists; overwrite (y/n)? "
            ):
                print(f"Skipping {file_name}...")
                continue
        file_list.append(os.path.join(path, file_name))
    return file_list


def _fix_hf_index(path, file_list):
    """Rewrite the HF weights index in `path` to drop the ``.znn`` suffix.

    The safetensors index is preferred; the other weights index is used only
    if the former does not exist. At most one index file is rewritten.

    Raises:
        ImportError: If ``transformers`` is not installed.
    """
    try:
        from transformers.utils import (
            SAFE_WEIGHTS_INDEX_NAME,
            WEIGHTS_INDEX_NAME,
        )
    except ImportError as err:
        raise ImportError(
            "Transformers not found. Please pip install transformers."
        ) from err

    # Extension just before ".znn", taken from the first file found.
    suffix = os.path.basename(file_list[0]).split(".")[-2]
    for index_name in (SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME):
        index_path = os.path.join(path, index_name)
        if not os.path.exists(index_path):
            continue
        print(f"{YELLOW}Fixing Hugging Face model json...{RESET}")
        # realpath follows the snapshot symlink to the underlying file and,
        # unlike os.readlink, also works when the index is a regular file.
        replace_in_file(
            file_path=os.path.realpath(index_path),
            old=f"{suffix}.znn",
            new=f"{suffix}",
        )
        break


def decompress_znn_files(
    path=".",
    delete=False,
    force=False,
    max_processes=1,
    hf_cache=False,
    model="",
    branch="main",
):
    """Decompress every ``.znn`` file found directly inside a folder.

    Args:
        path: Folder to scan (top level only). Replaced by the cached
            snapshot directory when `model` is given.
        delete: Forwarded unchanged to ``decompress_file``.
        force: If True, never prompt when a decompressed file already exists.
        max_processes: Number of worker processes used for decompression.
        hf_cache: Whether the files live in the Hugging Face cache; when set,
            the model's weights index file is rewritten to reference the
            decompressed file names.
        model: Hugging Face repo id to look up in the local cache. Requires
            `hf_cache`.
        branch: Branch of `model` whose snapshot is processed.

    Raises:
        ValueError: If `model` is given without `hf_cache`.
        ImportError: If ``huggingface_hub`` / ``transformers`` are required
            but not installed.
        FileNotFoundError: If `model` or `branch` is not in the cache.
    """
    import zipnn  # noqa: F401 -- fail early if zipnn is not installed

    if model:
        if not hf_cache:
            raise ValueError("Must specify --hf_cache when using --model")
        path = _resolve_hf_snapshot_path(model, branch)

    file_list = _collect_znn_files(path, force)

    if file_list and hf_cache:
        _fix_hf_index(path, file_list)

    with ProcessPoolExecutor(max_workers=max_processes) as executor:
        # The executor queues the work itself, so every file can be submitted
        # up front; at most `max_processes` run concurrently.
        # NOTE(review): the third positional argument is passed as True, as
        # in the original call; its meaning is defined by decompress_file.
        future_to_file = {
            executor.submit(
                decompress_file,
                file,
                delete,
                True,
                hf_cache,
            ): file
            for file in file_list
        }
        for future in as_completed(future_to_file):
            file = future_to_file[future]
            try:
                future.result()
            except Exception as exc:
                # Report and keep going so one bad file does not abort the
                # rest of the batch.
                print(
                    f"{RED}File {file} generated an exception: {exc}{RESET}"
                )

    print(f"{GREEN}All files decompressed{RESET}")


if __name__ == "__main__":
    check_and_install_zipnn()

    parser = argparse.ArgumentParser(
        description="Decompresses all .znn files."
    )
    parser.add_argument(
        "--path",
        type=str,
        help="Path to folder of files to decompress. If left empty, checks current folder.",
    )
    parser.add_argument(
        "--delete",
        action="store_true",
        help="A flag that triggers deletion of a single compressed file instead of decompression",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="A flag that forces overwriting when decompressing.",
    )
    parser.add_argument(
        "--max_processes",
        type=int,
        help="The amount of maximum processes.",
    )
    parser.add_argument(
        "--hf_cache",
        action="store_true",
        help="A flag that indicates if the file is in the Hugging Face cache. Must either specify --model or --path to the model's snapshot cache.",
    )
    parser.add_argument(
        "--model",
        type=str,
        help="Only when using --hf_cache, specify the model name or path. E.g. 'ibm-granite/granite-7b-instruct'",
    )
    parser.add_argument(
        "--model_branch",
        type=str,
        default="main",
        help="Only when using --model, specify the model branch. Default is 'main'",
    )
    args = parser.parse_args()

    # Forward only the options the user actually set, so the function's own
    # defaults apply to everything else.
    optional_kwargs = {}
    if args.path is not None:
        optional_kwargs["path"] = args.path
    if args.delete:
        optional_kwargs["delete"] = args.delete
    if args.force:
        optional_kwargs["force"] = args.force
    if args.max_processes:
        optional_kwargs["max_processes"] = args.max_processes
    if args.hf_cache:
        optional_kwargs["hf_cache"] = args.hf_cache
    if args.model:
        optional_kwargs["model"] = args.model
    if args.model_branch:
        optional_kwargs["branch"] = args.model_branch

    decompress_znn_files(**optional_kwargs)