import argparse
import random
import time

import numpy as np
import torch
from rich import print
from torch.utils import collect_env

from bark_infinity import api, config, generation, text_processing

logger = config.logger

# Fallback prompts used when no --text_prompt or --prompt_file is supplied.
text_prompts_in_this_file = []

try:
    text_prompts_in_this_file.append(
        f"It's {text_processing.current_date_time_in_words()}. And if you're hearing this, Bark is working. But you didn't provide any text."
    )
except Exception as e:
    print(f"An error occurred: {e}")

text_prompt = """
In the beginning the Universe was created. This has made a lot of people very angry and been widely regarded as a bad move. However, Bark is working.
"""
text_prompts_in_this_file.append(text_prompt)

text_prompt = """
A common mistake that people make when trying to design something completely foolproof is to underestimate the ingenuity of complete fools.
"""
text_prompts_in_this_file.append(text_prompt)


def get_group_args(group_name, updated_args):
    """Return only the arguments that belong to the given config.DEFAULTS group."""
    # Convert the Namespace object to a dictionary
    updated_args_dict = vars(updated_args)
    group_args = {}
    for key, value in updated_args_dict.items():
        if key in dict(config.DEFAULTS[group_name]):
            group_args[key] = value
    return group_args


def main(args):
    if args.loglevel is not None:
        logger.setLevel(args.loglevel)

    # Apply global generation overrides from the command line.
    if args.OFFLOAD_CPU is not None:
        generation.OFFLOAD_CPU = args.OFFLOAD_CPU
        # print(f"OFFLOAD_CPU is set to {generation.OFFLOAD_CPU}")
    else:
        if generation.get_SUNO_USE_DIRECTML() is not True:
            generation.OFFLOAD_CPU = True  # default on just in case

    if args.USE_SMALL_MODELS is not None:
        generation.USE_SMALL_MODELS = args.USE_SMALL_MODELS
        # print(f"USE_SMALL_MODELS is set to {generation.USE_SMALL_MODELS}")

    if args.GLOBAL_ENABLE_MPS is not None:
        generation.GLOBAL_ENABLE_MPS = args.GLOBAL_ENABLE_MPS
        # print(f"GLOBAL_ENABLE_MPS is set to {generation.GLOBAL_ENABLE_MPS}")

    # Optional status reports, skipped entirely in --silent mode.
    if not args.silent:
        if args.detailed_gpu_report or args.show_all_reports:
            print(api.startup_status_report(quick=False))
        elif not args.text_prompt and not args.prompt_file:
            # probably a test run, default to show
            print(api.startup_status_report(quick=True))

        if args.detailed_hugging_face_cache_report or args.show_all_reports:
            print(api.hugging_face_cache_report())

        if args.detailed_cuda_report or args.show_all_reports:
            print(api.cuda_status_report())

        if args.detailed_numpy_report:
            print(api.numpy_report())

        if args.run_numpy_benchmark or args.show_all_reports:
            from bark_infinity.debug import numpy_benchmark

            numpy_benchmark()

    if args.list_speakers:
        api.list_speakers()
        return

    if args.render_npz_samples:
        api.render_npz_samples()
        return

    # Decide which prompts to process.
    if args.text_prompt:
        text_prompts_to_process = [args.text_prompt]
    elif args.prompt_file:
        text_file = text_processing.load_text(args.prompt_file)
        if text_file is None:
            logger.error(f"Error loading file: {args.prompt_file}")
            return
        text_prompts_to_process = text_processing.split_text(
            text_file,
            args.split_input_into_separate_prompts_by,
            args.split_input_into_separate_prompts_by_value,
        )
        print(f"\nProcessing file: {args.prompt_file}")
        print(f" Looks like: {len(text_prompts_to_process)} prompt(s)")
    else:
        print("No --text_prompt or --prompt_file specified, using test prompts.")
        text_prompts_to_process = random.sample(text_prompts_in_this_file, 2)

    # Total outputs this run will produce (prompts x iterations).
    things = len(text_prompts_to_process) * args.output_iterations
    if things > 10:
        if args.dry_run is False:
            print(
                f"WARNING: You are about to process {things} prompts. Consider using '--dry-run' to test things first."
            )

    # pprint(args)

    print("Loading Bark models...")
    if not args.dry_run and generation.get_SUNO_USE_DIRECTML() is not True:
        generation.preload_models(
            args.text_use_gpu,
            args.text_use_small,
            args.coarse_use_gpu,
            args.coarse_use_small,
            args.fine_use_gpu,
            args.fine_use_small,
            args.codec_use_gpu,
            args.force_reload,
        )
    print("Done.")

    for idx, text_prompt in enumerate(text_prompts_to_process, start=1):
        if len(text_prompts_to_process) > 1:
            print(f"\nPrompt {idx}/{len(text_prompts_to_process)}:")
            # print(f"Text prompt: {text_prompt}")

        for iteration in range(1, args.output_iterations + 1):
            if args.output_iterations > 1:
                print(f"\nIteration {iteration} of {args.output_iterations}.")
                if iteration == 1:
                    print(text_prompt)

            args.current_iteration = iteration
            args.text_prompt = text_prompt
            args_dict = vars(args)

            api.generate_audio_long(**args_dict)


if __name__ == "__main__":
    parser = config.create_argument_parser()
    args = parser.parse_args()
    updated_args = config.update_group_args_with_defaults(args)
    namespace_args = argparse.Namespace(**updated_args)

    main(namespace_args)