akhaliq's picture
akhaliq HF staff
add files
81170fd
from tqdm import tqdm
import requests
import os
import tempfile
def download(ckpt_dir, url):
    """Download a file into a local ``flaxmodels`` cache and return its path.

    The file name is the last path segment of ``url`` with any query string
    removed. If a file with that name is already present in the cache
    directory, nothing is downloaded.

    Args:
        ckpt_dir: Directory in which the ``flaxmodels`` cache folder is
            created. If ``None``, the system temporary directory is used.
        url: URL of the file to download. It may carry a query string
            (``...?dl=1``), which is ignored when deriving the file name.

    Returns:
        Path of the cached file. If the download was incomplete, an error
        message is printed, the partial file is removed, and the returned
        path does not exist.

    Raises:
        requests.HTTPError: If the server answers with an error status, so
            that an error page is never stored as the checkpoint.
    """
    # Drop the query string before extracting the name. Slicing with
    # url.rfind('?') is wrong for URLs without '?': rfind returns -1 and the
    # slice silently chops off the last character of the file name.
    base_url = url.partition('?')[0]
    name = base_url[base_url.rfind('/') + 1:]

    if ckpt_dir is None:
        ckpt_dir = tempfile.gettempdir()
    ckpt_dir = os.path.join(ckpt_dir, 'flaxmodels')
    ckpt_file = os.path.join(ckpt_dir, name)

    # Cache hit: nothing to do.
    if os.path.exists(ckpt_file):
        return ckpt_file

    print(f'Downloading: "{base_url}" to {ckpt_file}')
    # exist_ok avoids the check-then-create race between processes.
    os.makedirs(ckpt_dir, exist_ok=True)

    # Write to a temp file first, so a failed download never leaves a
    # truncated file under the final name.
    ckpt_file_temp = os.path.join(ckpt_dir, name + '.temp')

    # The context manager releases the connection even if streaming fails.
    with requests.get(url, stream=True) as response:
        # Fail loudly on 4xx/5xx instead of caching the error body.
        response.raise_for_status()
        total_size_in_bytes = int(response.headers.get('content-length', 0))
        progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
        try:
            with open(ckpt_file_temp, 'wb') as file:
                for data in response.iter_content(chunk_size=1024):
                    progress_bar.update(len(data))
                    file.write(data)
        finally:
            progress_bar.close()

    # A total of 0 means the server sent no content-length; the size cannot
    # be verified in that case, so the download is accepted as is.
    if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
        print('An error occured while downloading, please try again.')
        if os.path.exists(ckpt_file_temp):
            os.remove(ckpt_file_temp)
    else:
        # Download succeeded: move the temp file into place. os.replace
        # overwrites an existing destination on every platform, whereas
        # os.rename raises on Windows if the target already exists.
        os.replace(ckpt_file_temp, ckpt_file)
    return ckpt_file