#!/usr/bin/python3
"""Command-line entry point for AudioSR audio super resolution.

Processes either a single audio file (``-i``) or every file listed in a
text file (``-il``). Each result is written with ``save_wave`` at
``SAMPLE_RATE`` into a subdirectory of ``--save_path`` named by
``get_time()``.
"""
import os
import torch
import logging
from audiosr import super_resolution, build_model, save_wave, get_time, read_list
import argparse

os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Silence matplotlib's debug/info chatter; keep warnings and errors.
matplotlib_logger = logging.getLogger('matplotlib')
matplotlib_logger.setLevel(logging.WARNING)

# Sample rate passed to save_wave for every output file.
SAMPLE_RATE = 48000
# Passed through unchanged to super_resolution(latent_t_per_second=...).
LATENT_T_PER_SECOND = 12.8


def _build_parser():
    """Return the argument parser for this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i",
        "--input_audio_file",
        type=str,
        required=False,
        help="Input audio file for audio super resolution",
    )
    parser.add_argument(
        "-il",
        "--input_file_list",
        type=str,
        required=False,
        default="",
        help="A file that contains all audio files that need to perform audio super resolution",
    )
    parser.add_argument(
        "-s",
        "--save_path",
        type=str,
        required=False,
        help="The path to save model output",
        default="./output",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        required=False,
        help="The checkpoint you gonna use",
        default="basic",
        choices=["basic", "speech"],
    )
    parser.add_argument(
        "-d",
        "--device",
        type=str,
        required=False,
        help="The device for computation. If not specified, the script will "
        "automatically choose the device based on your environment.",
        default="auto",
    )
    parser.add_argument(
        "--ddim_steps",
        type=int,
        required=False,
        default=50,
        help="The sampling step for DDIM",
    )
    parser.add_argument(
        "-gs",
        "--guidance_scale",
        type=float,
        required=False,
        default=3.5,
        help="Guidance scale (Large => better quality and relevancy to text; Small => better diversity)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        required=False,
        default=42,
        help="Changing this value (any integer number) will lead to a different generation result.",
    )
    parser.add_argument(
        "--suffix",
        type=str,
        required=False,
        help="Suffix for the output file",
        default="_AudioSR_Processed_48K",
    )
    return parser


def main():
    """Parse CLI arguments, build the model, and process each input file."""
    parser = _build_parser()
    args = parser.parse_args()

    # --input_file_list defaults to "" (never None), so an `is not None`
    # check always passes. Test truthiness instead. parser.error also
    # survives `python -O`, which strips asserts, and exits with a usage
    # message instead of crashing later on os.path.basename(None).
    if not args.input_file_list and not args.input_audio_file:
        parser.error("Please provide either a list of audio files or a single audio file")

    torch.set_float32_matmul_precision("high")

    save_path = os.path.join(args.save_path, get_time())
    os.makedirs(save_path, exist_ok=True)

    audiosr = build_model(model_name=args.model_name, device=args.device)

    # The list file takes precedence over a single input file when both are given.
    if args.input_file_list:
        print("Processing audio files listed in %s" % args.input_file_list)
        # NOTE(review): presumably one audio path per line; confirm read_list's format.
        files_todo = read_list(args.input_file_list)
    else:
        files_todo = [args.input_audio_file]

    for input_file in files_todo:
        name = os.path.splitext(os.path.basename(input_file))[0] + args.suffix
        waveform = super_resolution(
            audiosr,
            input_file,
            seed=args.seed,
            guidance_scale=args.guidance_scale,
            ddim_steps=args.ddim_steps,
            latent_t_per_second=LATENT_T_PER_SECOND,
        )
        save_wave(waveform, inputpath=input_file, savepath=save_path, name=name, samplerate=SAMPLE_RATE)


if __name__ == "__main__":
    main()