Spaces:
Build error
Build error
from tqdm import tqdm | |
import requests | |
import os | |
import tempfile | |
def _remove_if_exists(path):
    """Delete ``path`` if present; a missing file is not an error."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


def download(ckpt_dir, url, timeout=60):
    """Download a checkpoint into a local cache directory unless it is already cached.

    Args:
        ckpt_dir: Root cache directory; the file is stored in its ``flaxmodels``
            subdirectory. ``None`` means the system temp directory.
        url: URL of the file. Any query string (e.g. ``?dl=1``) is ignored when
            deriving the local file name.
        timeout: Connect/read timeout in seconds passed to ``requests.get``;
            ``None`` waits indefinitely.

    Returns:
        Path of the local file. If the number of bytes received does not match a
        non-zero ``Content-Length``, an error is printed, the partial file is
        discarded, and the (non-existent) path is still returned.

    Raises:
        ValueError: If no file name can be derived from ``url``.
        requests.HTTPError: If the server answers with an error status.
        Other ``requests`` exceptions (connection errors, timeouts) propagate.
    """
    # Drop the query string before taking the last path component. The old
    # rfind('?') slice cut off the final character when there was no '?'.
    base_url = url.partition('?')[0]
    name = base_url.rpartition('/')[2]
    if not name:
        raise ValueError(f'Cannot derive a file name from URL: {url!r}')

    if ckpt_dir is None:
        ckpt_dir = tempfile.gettempdir()
    ckpt_dir = os.path.join(ckpt_dir, 'flaxmodels')
    ckpt_file = os.path.join(ckpt_dir, name)
    if os.path.exists(ckpt_file):
        return ckpt_file

    print(f'Downloading: \"{base_url}\" to {ckpt_file}')
    os.makedirs(ckpt_dir, exist_ok=True)

    # Write to a temp file first so that a failed or interrupted download never
    # leaves a truncated file at ckpt_file, which later calls would treat as cached.
    ckpt_file_temp = os.path.join(ckpt_dir, name + '.temp')
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Never cache an HTTP error page as if it were the checkpoint.
            response.raise_for_status()
            total_size_in_bytes = int(response.headers.get('content-length', 0))
            progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
            with progress_bar, open(ckpt_file_temp, 'wb') as file:
                for data in response.iter_content(chunk_size=1024):
                    progress_bar.update(len(data))
                    file.write(data)
            downloaded = progress_bar.n
    except BaseException:
        # Clean up the partial temp file on any failure (incl. Ctrl-C), then re-raise.
        _remove_if_exists(ckpt_file_temp)
        raise

    if total_size_in_bytes != 0 and downloaded != total_size_in_bytes:
        print('An error occured while downloading, please try again.')
        _remove_if_exists(ckpt_file_temp)
    else:
        # os.replace (unlike os.rename) also succeeds on Windows when the
        # destination already exists, e.g. after a concurrent download.
        os.replace(ckpt_file_temp, ckpt_file)
    return ckpt_file