# bark / bark_perform.py
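"""CLI driver for bark_infinity: parses command-line arguments, optionally
prints environment/status reports, preloads the Bark models, and generates
audio for one or more text prompts via api.generate_audio_long()."""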
import argparse
import random
import time

import numpy as np
import torch
from torch.utils import collect_env
from rich import print

from bark_infinity import api, config, generation, text_processing

logger = config.logger

# Built-in fallback prompts, used when no text is supplied on the command line.
text_prompts_in_this_file = []
try:
    text_prompts_in_this_file.append(
        f"It's {text_processing.current_date_time_in_words()}. And if you're hearing this, Bark is working. But you didn't provide any text."
    )
except Exception as e:
    print(f"An error occurred: {e}")
text_prompt = """
In the beginning the Universe was created. This has made a lot of people very angry and been widely regarded as a bad move. However, Bark is working.
"""
text_prompts_in_this_file.append(text_prompt)
text_prompt = """
A common mistake that people make when trying to design something completely foolproof is to underestimate the ingenuity of complete fools.
"""
text_prompts_in_this_file.append(text_prompt)
def get_group_args(group_name, updated_args):
    """Return only the parsed args that belong to the given config defaults group."""
    # Convert the Namespace object to a dictionary
    updated_args_dict = vars(updated_args)
    group_args = {}
    for key, value in updated_args_dict.items():
        if key in dict(config.DEFAULTS[group_name]):
            group_args[key] = value
    return group_args
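# Usage sketch for get_group_args (it is not called in this file; the group
# name "generation" below is a hypothetical example, not a confirmed key of
# config.DEFAULTS):
#
#     gen_args = get_group_args("generation", updated_args)
#     api.generate_audio_long(**gen_args)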
def main(args):
    if args.loglevel is not None:
        logger.setLevel(args.loglevel)

    if args.OFFLOAD_CPU is not None:
        generation.OFFLOAD_CPU = args.OFFLOAD_CPU
        # print(f"OFFLOAD_CPU is set to {generation.OFFLOAD_CPU}")
    else:
        if generation.get_SUNO_USE_DIRECTML() is not True:
            generation.OFFLOAD_CPU = True  # default to on, just in case

    if args.USE_SMALL_MODELS is not None:
        generation.USE_SMALL_MODELS = args.USE_SMALL_MODELS
        # print(f"USE_SMALL_MODELS is set to {generation.USE_SMALL_MODELS}")

    if args.GLOBAL_ENABLE_MPS is not None:
        generation.GLOBAL_ENABLE_MPS = args.GLOBAL_ENABLE_MPS
        # print(f"GLOBAL_ENABLE_MPS is set to {generation.GLOBAL_ENABLE_MPS}")

    if not args.silent:
        if args.detailed_gpu_report or args.show_all_reports:
            print(api.startup_status_report(quick=False))
        elif not args.text_prompt and not args.prompt_file:  # probably a test run, default to show
            print(api.startup_status_report(quick=True))
        if args.detailed_hugging_face_cache_report or args.show_all_reports:
            print(api.hugging_face_cache_report())
        if args.detailed_cuda_report or args.show_all_reports:
            print(api.cuda_status_report())
        if args.detailed_numpy_report:
            print(api.numpy_report())
        if args.run_numpy_benchmark or args.show_all_reports:
            from bark_infinity.debug import numpy_benchmark

            numpy_benchmark()

    if args.list_speakers:
        api.list_speakers()
        return

    if args.render_npz_samples:
        api.render_npz_samples()
        return
    if args.text_prompt:
        text_prompts_to_process = [args.text_prompt]
    elif args.prompt_file:
        text_file = text_processing.load_text(args.prompt_file)
        if text_file is None:
            logger.error(f"Error loading file: {args.prompt_file}")
            return
        # Reuse the text we already loaded rather than reading the file twice.
        text_prompts_to_process = text_processing.split_text(
            text_file,
            args.split_input_into_separate_prompts_by,
            args.split_input_into_separate_prompts_by_value,
        )
        print(f"\nProcessing file: {args.prompt_file}")
        print(f"  Looks like: {len(text_prompts_to_process)} prompt(s)")
    else:
        print("No --text_prompt or --prompt_file specified, using test prompts.")
        text_prompts_to_process = random.sample(text_prompts_in_this_file, 2)
    # Total outputs is prompts times iterations, not their sum.
    things = len(text_prompts_to_process) * args.output_iterations
    if things > 10 and not args.dry_run:
        print(
            f"WARNING: You are about to generate {things} audio clips. Consider using '--dry-run' to test things first."
        )
    # pprint(args)

    print("Loading Bark models...")
    if not args.dry_run and generation.get_SUNO_USE_DIRECTML() is not True:
        generation.preload_models(
            args.text_use_gpu,
            args.text_use_small,
            args.coarse_use_gpu,
            args.coarse_use_small,
            args.fine_use_gpu,
            args.fine_use_small,
            args.codec_use_gpu,
            args.force_reload,
        )
    print("Done.")
    for idx, text_prompt in enumerate(text_prompts_to_process, start=1):
        if len(text_prompts_to_process) > 1:
            print(f"\nPrompt {idx}/{len(text_prompts_to_process)}:")
        # print(f"Text prompt: {text_prompt}")
        for iteration in range(1, args.output_iterations + 1):
            if args.output_iterations > 1:
                print(f"\nIteration {iteration} of {args.output_iterations}.")
            if iteration == 1:
                print(f"Text prompt: {text_prompt}")
            args.current_iteration = iteration
            args.text_prompt = text_prompt
            args_dict = vars(args)
            api.generate_audio_long(**args_dict)
if __name__ == "__main__":
    parser = config.create_argument_parser()
    args = parser.parse_args()
    updated_args = config.update_group_args_with_defaults(args)
    namespace_args = argparse.Namespace(**updated_args)
    main(namespace_args)
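# Usage sketch (the exact flag spellings are defined by
# config.create_argument_parser() in bark_infinity, so the forms below are
# assumptions inferred from the attribute names read in main()):
#
#     python bark_perform.py --text_prompt "Hello from Bark."
#     python bark_perform.py --prompt_file story.txt --output_iterations 2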