akhaliq's picture
akhaliq HF staff
add files
81170fd
import logging
import os
import tempfile
from urllib.parse import urlsplit, urlunsplit

import flax
import jax
import numpy as np
import requests
from tqdm import tqdm
# Module-level logger named after this module; download() reports progress and errors through it.
logger = logging.getLogger(name=__name__)
def download(url, ckpt_dir=None):
    """Download a file into a local cache directory and return its cached path.

    The file is stored as ``<ckpt_dir>/flaxmodels/<name>``, where ``<name>`` is the
    last component of the URL path (query string and fragment are ignored). If that
    file already exists, nothing is downloaded and its path is returned directly.

    Args:
        url: URL of the file to fetch. A query string such as ``?dl=1`` may be
            present or absent; it does not affect the cached file name.
        ckpt_dir: Base cache directory. Defaults to ``tempfile.gettempdir()``.

    Returns:
        Path of the cached file. If the number of bytes received differs from a
        non-zero ``Content-Length`` header, an error is logged, the partial file is
        removed, and the (then non-existent) path is still returned.

    Raises:
        ValueError: If no file name can be derived from ``url``.
        requests.HTTPError: If the server answers with an HTTP error status.
    """
    parts = urlsplit(url)
    # Take the name from the URL path only, so URLs with and without a query
    # string both work (slicing up to rfind('?') dropped a character when no
    # '?' was present).
    name = parts.path.rsplit('/', 1)[-1]
    if not name:
        raise ValueError(f'Cannot derive a file name from URL: {url!r}')
    if ckpt_dir is None:
        ckpt_dir = tempfile.gettempdir()
    ckpt_dir = os.path.join(ckpt_dir, 'flaxmodels')
    ckpt_file = os.path.join(ckpt_dir, name)
    if os.path.exists(ckpt_file):
        return ckpt_file

    display_url = urlunsplit((parts.scheme, parts.netloc, parts.path, '', ''))
    logger.info(f'Downloading: \"{display_url}\" to {ckpt_file}')
    os.makedirs(ckpt_dir, exist_ok=True)
    # Write to a temporary file first so an interrupted download never leaves a
    # truncated file under the final cached name.
    ckpt_file_temp = os.path.join(ckpt_dir, name + '.temp')
    received = 0
    try:
        with requests.get(url, stream=True) as response:
            # Fail loudly instead of caching an HTTP error page as the checkpoint.
            response.raise_for_status()
            total_size_in_bytes = int(response.headers.get('content-length', 0))
            with tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) as progress_bar:
                with open(ckpt_file_temp, 'wb') as file:
                    for data in response.iter_content(chunk_size=1024):
                        progress_bar.update(len(data))
                        file.write(data)
                        received += len(data)
    except BaseException:
        # Clean up the partial temp file on any failure (including Ctrl-C), then re-raise.
        if os.path.exists(ckpt_file_temp):
            os.remove(ckpt_file_temp)
        raise

    if total_size_in_bytes != 0 and received != total_size_in_bytes:
        logger.error('An error occured while downloading, please try again.')
        if os.path.exists(ckpt_file_temp):
            os.remove(ckpt_file_temp)
    else:
        # os.replace overwrites atomically, even on Windows, if another process
        # finished the same download first.
        os.replace(ckpt_file_temp, ckpt_file)
    return ckpt_file
def get(dictionary, key):
    """Look up ``key`` in ``dictionary``, tolerating a missing container.

    Args:
        dictionary: Any container supporting ``in`` and ``[]``, or ``None``.
        key: The key to look up.

    Returns:
        ``dictionary[key]`` when ``dictionary`` is not ``None`` and contains
        ``key``; otherwise ``None``.
    """
    has_key = dictionary is not None and key in dictionary
    return dictionary[key] if has_key else None
def prefetch(dataset, n_prefetch):
    """Iterate over ``dataset`` as NumPy pytrees, optionally prefetching to device.

    Adapted from:
    https://github.com/google-research/vision_transformer/blob/master/vit_jax/input_pipeline.py

    Args:
        dataset: An iterable whose elements are pytrees with leaves that
            ``memoryview`` accepts (e.g. NumPy arrays; presumably TensorFlow
            eager tensors in the original tf.data use — confirm with callers).
        n_prefetch: Number of elements to prefetch onto device via
            ``flax.jax_utils.prefetch_to_device``. A falsy value (0 or None)
            disables device prefetching.

    Returns:
        A lazy iterator over the converted elements.
    """
    def _to_numpy(tree):
        # memoryview exposes each leaf's buffer so np.asarray can wrap it.
        # jax.tree_util.tree_map replaces the deprecated (and, in recent JAX,
        # removed) jax.tree_map alias.
        return jax.tree_util.tree_map(lambda t: np.asarray(memoryview(t)), tree)

    ds_iter = map(_to_numpy, iter(dataset))
    if n_prefetch:
        # NOTE(review): flax's prefetch_to_device shards leaves along their
        # leading axis across local devices — confirm batches are pre-sharded.
        ds_iter = flax.jax_utils.prefetch_to_device(ds_iter, n_prefetch)
    return ds_iter