File size: 7,414 Bytes
6127b48 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 |
# Adapted from https://github.com/pytorch/audio/
import hashlib
import logging
import os
import tarfile
import urllib
import urllib.parse
import urllib.request
import zipfile
from os.path import expanduser
from typing import Any, Iterable, List, Optional

from torch.utils.model_zoo import tqdm
def stream_url(
    url: str, start_byte: Optional[int] = None, block_size: int = 32 * 1024, progress_bar: bool = True
) -> Iterable:
    """Stream the content of ``url`` chunk by chunk.

    A ``HEAD`` request is issued first to read ``Content-Length``. If it equals ``start_byte``,
    the caller already holds the whole resource and nothing is yielded. Otherwise a ``GET``
    request is streamed, with a ``Range`` header when ``start_byte`` is non-zero.

    Args:
        url (str): Url.
        start_byte (int or None, optional): Start streaming at that point (Default: ``None``).
        block_size (int, optional): Size of chunks to stream (Default: ``32 * 1024``).
        progress_bar (bool, optional): Display a progress bar (Default: ``True``).

    Yields:
        bytes: Successive chunks of at most ``block_size`` bytes.
    """
    # If we already have the whole file, there is no need to download it again
    req = urllib.request.Request(url, method="HEAD")
    with urllib.request.urlopen(req) as response:
        url_size = int(response.info().get("Content-Length", -1))
    if url_size == start_byte:
        return

    req = urllib.request.Request(url)
    if start_byte:
        req.add_header("Range", "bytes={}-".format(start_byte))

    with urllib.request.urlopen(req) as upointer, tqdm(
        unit="B",
        unit_scale=True,
        unit_divisor=1024,
        # -1 means "no Content-Length header"; give tqdm an unknown total instead of a negative one.
        total=url_size if url_size >= 0 else None,
        disable=not progress_bar,
    ) as pbar:
        if start_byte:
            # The bar's total is the full file size, so count the bytes the caller already has.
            # NOTE(review): if the server ignores the Range header (200 instead of 206) the whole
            # file is streamed again and the bar overshoots; callers appending would corrupt data.
            pbar.update(start_byte)
        while True:
            chunk = upointer.read(block_size)
            if not chunk:
                break
            yield chunk
            pbar.update(len(chunk))
def download_url(
    url: str,
    download_folder: str,
    filename: Optional[str] = None,
    hash_value: Optional[str] = None,
    hash_type: str = "sha256",
    progress_bar: bool = True,
    resume: bool = False,
) -> None:
    """Download file to disk.

    A ``HEAD`` request is made first to discover the filename (``Content-Disposition``) and the
    remote size (``Content-Length``). When resuming with a ``hash_value`` and the local file
    already has the remote size, only the hash is checked instead of downloading again.

    Args:
        url (str): Url.
        download_folder (str): Folder to download file.
        filename (str or None, optional): Name of downloaded file. If None, it is inferred from the
            ``Content-Disposition`` header, then from the last segment of the url path
            (Default: ``None``).
        hash_value (str or None, optional): Hash for url (Default: ``None``).
        hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``).
        progress_bar (bool, optional): Display a progress bar (Default: ``True``).
        resume (bool, optional): Enable resuming download (Default: ``False``).

    Raises:
        RuntimeError: If the target file already exists and ``resume`` is False, or if
            ``hash_value`` is given and the file content does not match it.
    """
    req = urllib.request.Request(url, method="HEAD")
    # Only the headers are needed; close the connection right away instead of leaking it.
    with urllib.request.urlopen(req) as response:
        req_info = response.info()

    # Detect filename. The url is parsed first so that "?query" / "#fragment" parts do not end up
    # in the filename; the raw basename is kept as a last resort (e.g. for urls with no path).
    filename = (
        filename
        or req_info.get_filename()
        or os.path.basename(urllib.parse.urlsplit(url).path)
        or os.path.basename(url)
    )
    filepath = os.path.join(download_folder, filename)

    if os.path.exists(filepath):
        if not resume:
            raise RuntimeError("{} already exists. Delete the file manually and retry.".format(filepath))
        mode = "ab"
        local_size: Optional[int] = os.path.getsize(filepath)
    else:
        mode = "wb"
        local_size = None

    # The local copy is already complete: validate it rather than downloading again.
    if hash_value and local_size == int(req_info.get("Content-Length", -1)):
        with open(filepath, "rb") as file_obj:
            if validate_file(file_obj, hash_value, hash_type):
                return
        raise RuntimeError("The hash of {} does not match. Delete the file manually and retry.".format(filepath))

    with open(filepath, mode) as fpointer:
        for chunk in stream_url(url, start_byte=local_size, progress_bar=progress_bar):
            fpointer.write(chunk)

    if hash_value:
        with open(filepath, "rb") as file_obj:
            if not validate_file(file_obj, hash_value, hash_type):
                raise RuntimeError(
                    "The hash of {} does not match. Delete the file manually and retry.".format(filepath)
                )
def validate_file(file_obj: Any, hash_value: str, hash_type: str = "sha256") -> bool:
    """Validate a given file object with its hash.

    The file object is read from its current position until exhausted, in 1 MiB chunks so that
    large files are never held in memory at once.

    Args:
        file_obj: File object to read from (binary mode; hashlib only accepts bytes).
        hash_value (str): Expected hexadecimal digest. Compared case-insensitively.
        hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``).

    Returns:
        bool: True if the digest of the file content matches ``hash_value``, else False.

    Raises:
        ValueError: If ``hash_type`` is not one of the supported hash types.
    """
    if hash_type == "sha256":
        hash_func = hashlib.sha256()
    elif hash_type == "md5":
        hash_func = hashlib.md5()
    else:
        raise ValueError('Unsupported hash type {!r}; expected "sha256" or "md5".'.format(hash_type))
    while True:
        # Read by chunk to avoid filling memory
        chunk = file_obj.read(1024**2)
        if not chunk:
            break
        hash_func.update(chunk)
    # hexdigest() is lowercase; normalise the expected value so uppercase hashes also match.
    return hash_func.hexdigest() == hash_value.lower()
def _is_within_directory(directory: str, target: str) -> bool:
    """Return True if ``target`` resolves to ``directory`` itself or to a path below it."""
    abs_directory = os.path.realpath(directory)
    abs_target = os.path.realpath(target)
    try:
        return os.path.commonpath([abs_directory, abs_target]) == abs_directory
    except ValueError:
        # commonpath raises e.g. on Windows when the paths live on different drives.
        return False


def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bool = False) -> List[str]:
    """Extract archive.

    The file is first opened as a (possibly compressed) tar archive; if ``tarfile`` cannot read
    it, it is opened as a zip archive. Files that already exist under ``to_path`` are skipped
    unless ``overwrite`` is True.

    Args:
        from_path (str): the path of the archive.
        to_path (str or None, optional): the root path of the extracted files (directory of from_path)
            (Default: ``None``)
        overwrite (bool, optional): overwrite existing files (Default: ``False``)

    Returns:
        list: List of paths to extracted files even if not overwritten. For tar archives these are
        the regular-file members joined with ``to_path``; for zip archives they are the member
        names exactly as returned by ``ZipFile.namelist()`` (relative, directories included).

    Raises:
        RuntimeError: If a tar member, or the target of a tar link member, would be written
            outside ``to_path``.
        NotImplementedError: If ``from_path`` is neither a tar nor a zip archive.
    """
    if to_path is None:
        to_path = os.path.dirname(from_path)
    try:
        with tarfile.open(from_path, "r") as tar:
            logging.info("Opened tar file %s.", from_path)
            files = []
            for file_ in tar:  # type: Any
                file_path = os.path.join(to_path, file_.name)
                # tar.extract() trusts member names, so "../x", "/abs/x" or links pointing
                # outside the destination would write outside to_path. Refuse them up front.
                if not _is_within_directory(to_path, file_path):
                    raise RuntimeError(
                        "Refusing to extract {} from {}: it escapes {}.".format(file_.name, from_path, to_path)
                    )
                if file_.issym():
                    # Symlink targets are relative to the directory containing the link.
                    link_target: Optional[str] = os.path.join(os.path.dirname(file_path), file_.linkname)
                elif file_.islnk():
                    # Hard-link targets are relative to the archive root.
                    link_target = os.path.join(to_path, file_.linkname)
                else:
                    link_target = None
                if link_target is not None and not _is_within_directory(to_path, link_target):
                    raise RuntimeError(
                        "Refusing to extract link {} from {}: its target escapes {}.".format(
                            file_.name, from_path, to_path
                        )
                    )
                if file_.isfile():
                    files.append(file_path)
                    if os.path.exists(file_path):
                        logging.info("%s already extracted.", file_path)
                        if not overwrite:
                            continue
                tar.extract(file_, to_path)
            return files
    except tarfile.ReadError:
        # Not a tar archive: fall through and try zip.
        pass
    try:
        with zipfile.ZipFile(from_path, "r") as zfile:
            logging.info("Opened zip file %s.", from_path)
            files = zfile.namelist()
            for file_ in files:
                # ZipFile.extract() itself strips absolute prefixes and ".." components.
                file_path = os.path.join(to_path, file_)
                if os.path.exists(file_path):
                    logging.info("%s already extracted.", file_path)
                    if not overwrite:
                        continue
                zfile.extract(file_, to_path)
            return files
    except zipfile.BadZipFile:
        pass
    raise NotImplementedError(" > [!] only supports tar.gz, tgz, and zip achives.")
def download_kaggle_dataset(dataset_path: str, dataset_name: str, output_path: str):
    """Download dataset from kaggle.

    The files are downloaded and unzipped into ``os.path.join(output_path, dataset_name)``.
    If kaggle raises an ``OSError`` (e.g. no API token), a hint about the token location is
    printed instead of raising.

    Args:
        dataset_path (str):
            This is the kaggle link to the dataset. For example, vctk is
            'mfekadu/english-multispeaker-corpus-for-voice-cloning'.
        dataset_name (str): Name of the folder the dataset will be saved in.
        output_path (str): Path of the location you want the dataset folder to be saved to.

    Raises:
        ModuleNotFoundError: If the ``kaggle`` package is not installed.
    """
    data_path = os.path.join(output_path, dataset_name)
    try:
        import kaggle  # pylint: disable=import-outside-toplevel

        kaggle.api.authenticate()
        print(f"""\nDownloading {dataset_name}...""")
        kaggle.api.dataset_download_files(dataset_path, path=data_path, unzip=True)
    except ModuleNotFoundError as err:
        # Same exception type as before, but tell the user how to fix it.
        raise ModuleNotFoundError(
            'The "kaggle" package is required to download kaggle datasets. '
            "Install it with `pip install kaggle`.",
            name="kaggle",
        ) from err
    except OSError:
        print(
            f"""[!] in order to download kaggle datasets, you need to have a kaggle api token stored in your {os.path.join(expanduser('~'), '.kaggle/kaggle.json')}"""
        )
|