Spaces:
Build error
Build error
import jax | |
import flax | |
import numpy as np | |
from tqdm import tqdm | |
import requests | |
import os | |
import tempfile | |
import logging | |
logger = logging.getLogger(__name__) | |
def download(url, ckpt_dir=None):
    """Download a checkpoint file and cache it on disk.

    The file is stored as ``<ckpt_dir>/flaxmodels/<name>``. ``<name>`` is the
    last path component of ``url`` with any query string (``?...``) or fragment
    (``#...``) removed. If that file already exists, nothing is downloaded and
    its path is returned immediately.

    Args:
        url (str): URL of the checkpoint file.
        ckpt_dir (str): Parent directory of the cache. Defaults to the system
            temporary directory (``tempfile.gettempdir()``).

    Returns:
        (str): Path of the cached checkpoint file. If the number of received
            bytes does not match the server's ``content-length``, an error is
            logged, the partial file is removed, and the returned path does
            not exist.

    Raises:
        requests.HTTPError: If the server answers with an HTTP error status.
    """
    # Strip the fragment and query string before taking the last path
    # component. The old slice up to url.rfind('?') dropped the final character
    # whenever the URL contained no '?' (rfind returns -1).
    url_path = url.split('#', 1)[0].split('?', 1)[0]
    name = url_path[url_path.rfind('/') + 1:]
    if ckpt_dir is None:
        ckpt_dir = tempfile.gettempdir()
    ckpt_dir = os.path.join(ckpt_dir, 'flaxmodels')
    ckpt_file = os.path.join(ckpt_dir, name)
    if os.path.exists(ckpt_file):
        return ckpt_file

    logger.info(f'Downloading: "{url_path}" to {ckpt_file}')
    os.makedirs(ckpt_dir, exist_ok=True)
    # First download into a temp file, so a failed download never leaves a
    # file at ckpt_file (which would be reused as-is on the next call).
    ckpt_file_temp = os.path.join(ckpt_dir, name + '.temp')
    with requests.get(url, stream=True) as response:
        # Fail fast on 4xx/5xx instead of caching an error page as a checkpoint.
        response.raise_for_status()
        total_size_in_bytes = int(response.headers.get('content-length', 0))
        progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
        try:
            with open(ckpt_file_temp, 'wb') as file:
                for data in response.iter_content(chunk_size=1024):
                    progress_bar.update(len(data))
                    file.write(data)
        except BaseException:
            # Connection error, KeyboardInterrupt, ...: remove the partial file.
            if os.path.exists(ckpt_file_temp):
                os.remove(ckpt_file_temp)
            raise
        finally:
            progress_bar.close()

    if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
        logger.error('An error occured while downloading, please try again.')
        if os.path.exists(ckpt_file_temp):
            os.remove(ckpt_file_temp)
    else:
        # Download succeeded: move the temp file into place. Unlike os.rename,
        # os.replace also overwrites an existing destination on Windows.
        os.replace(ckpt_file_temp, ckpt_file)
    return ckpt_file
def get(dictionary, key):
    """Look up ``key`` in ``dictionary``, tolerating a missing dictionary.

    Args:
        dictionary: A mapping supporting ``in`` and ``[]``, or ``None``.
        key: The key to look up.

    Returns:
        ``dictionary[key]`` when ``dictionary`` is not None and contains
        ``key``; otherwise ``None``.
    """
    if dictionary is not None and key in dictionary:
        return dictionary[key]
    return None
def prefetch(dataset, n_prefetch):
    """Convert a dataset's batches to NumPy and optionally prefetch them to device.

    Taken from: https://github.com/google-research/vision_transformer/blob/master/vit_jax/input_pipeline.py

    Args:
        dataset: Iterable of batches. Each batch is a pytree whose leaves
            support the buffer protocol (they are wrapped with ``memoryview``).
        n_prefetch (int): Number of batches to prefetch with
            ``flax.jax_utils.prefetch_to_device``. A falsy value (0/None)
            disables prefetching.

    Returns:
        A lazy iterator over the converted batches.
    """

    def _to_numpy(batch):
        # Expose each leaf's buffer and wrap it as a NumPy array. Use
        # jax.tree_util.tree_map: the jax.tree_map alias was deprecated and
        # later removed from JAX.
        return jax.tree_util.tree_map(lambda t: np.asarray(memoryview(t)), batch)

    ds_iter = map(_to_numpy, iter(dataset))
    if n_prefetch:
        # NOTE(review): prefetch_to_device shards leaves across local devices,
        # presumably expecting a leading (n_devices, ...) axis; confirm with callers.
        ds_iter = flax.jax_utils.prefetch_to_device(ds_iter, n_prefetch)
    return ds_iter