import os

import torch
import numpy as np
from tqdm import tqdm
from torch.utils.data import Dataset

class FileLatentDataset(Dataset):
    """Paired latents stored as two whole-tensor files (one src, one dst)."""

    def __init__(self, src_file, dst_file, device="cpu", dtype=torch.float16):
        assert os.path.isfile(src_file), f"src bin missing! ({src_file})"
        assert os.path.isfile(dst_file), f"dst bin missing! ({dst_file})"

        # Keep both tensors resident in the requested dtype/device (fp16 by
        # default to halve memory); items are cast back to fp32 when served.
        self.src_data = torch.load(src_file).to(dtype).to(device)
        self.dst_data = torch.load(dst_file).to(dtype).to(device)
        assert self.src_data.shape[0] == self.dst_data.shape[0], "Data size mismatch!"

    def __len__(self):
        return self.src_data.shape[0]

    def __getitem__(self, index):
        return {
            "src": self.src_data[index].float(),
            "dst": self.dst_data[index].float(),
        }
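
# Example usage (a minimal sketch; "src.pt" and "dst.pt" are hypothetical
# paths, not files shipped with this module):
#
#   from torch.utils.data import DataLoader
#   dataset = FileLatentDataset("src.pt", "dst.pt")
#   loader = DataLoader(dataset, batch_size=32, shuffle=True)
#   batch = next(iter(loader))  # {"src": fp32 tensor, "dst": fp32 tensor}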

class Shard:
    """One src/dst pair of .npy latent files; loads lazily unless preloaded."""

    def __init__(self, paths):
        self.paths = paths  # dict mapping key (e.g. "src"/"dst") to a .npy path
        self.data = None    # populated by preload()

    def exists(self):
        return all(os.path.isfile(x) for x in self.paths.values())

    def get_data(self):
        # Serve cached tensors if preload() has run, otherwise read from disk.
        if self.data is not None:
            return self.data
        return {k: self.load_latent(v) for k, v in self.paths.items()}

    def load_latent(self, path):
        lat = torch.from_numpy(np.load(path))
        # Drop a leading singleton batch dimension if present.
        if lat.shape[0] == 1:
            lat = torch.squeeze(lat, 0)
        assert not torch.isnan(torch.sum(lat.float()))
        return lat

    def preload(self):
        self.data = self.get_data()
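
# Minimal sketch of how a Shard is used (paths below are hypothetical):
#
#   shard = Shard({"src": "latents_a/0001.npy", "dst": "latents_b/0001.npy"})
#   if shard.exists():
#       shard.preload()            # optional: cache tensors in RAM
#       pair = shard.get_data()    # {"src": tensor, "dst": tensor}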

class LatentDataset(Dataset):
    """Paired latents stored as per-item .npy files with matching names in two folders."""

    def __init__(self, src_root, dst_root, preload=True):
        assert os.path.isdir(src_root), f"Source folder missing! ({src_root})"
        assert os.path.isdir(dst_root), f"Destination folder missing! ({dst_root})"

        # Only filenames present in both folders form valid pairs.
        print("Dataset: Parsing data from disk")
        fnames = list(
            set(os.listdir(src_root)).intersection(
                set(os.listdir(dst_root))
            )
        )
        assert len(fnames) > 0, "Source/destination have no overlapping files"

        self.shards = []
        for fname in tqdm(fnames):
            src_path = os.path.join(src_root, fname)
            dst_path = os.path.join(dst_root, fname)
            _, ext = os.path.splitext(fname)
            if ext not in [".npy"]:
                continue
            shard = Shard({
                "src": src_path,
                "dst": dst_path,
            })
            if shard.exists():
                self.shards.append(shard)
        assert len(self.shards) > 0, "No valid files found."

        if preload:
            print("Dataset: Preloading data to system RAM")
            for shard in tqdm(self.shards):
                shard.preload()

        print(f"Dataset: OK, {len(self)} items")

    def __len__(self):
        return len(self.shards)

    def __getitem__(self, index):
        return self.shards[index].get_data()
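
# Example usage (a minimal sketch; the folder names are hypothetical, and both
# folders are expected to hold .npy files with matching names):
#
#   from torch.utils.data import DataLoader
#   dataset = LatentDataset("latents/src", "latents/dst", preload=True)
#   loader = DataLoader(dataset, batch_size=4, shuffle=True)
#   batch = next(iter(loader))  # {"src": batched tensor, "dst": batched tensor}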

def load_evals(evals):
    """Load named evaluation shards; `evals` maps a name to a {key: path} dict."""
    data = {}
    for name, paths in evals.items():
        shard = Shard(paths)
        assert shard.exists(), f"Eval data missing ({name})"
        data[name] = {}
        for k, v in shard.get_data().items():
            # Ensure a leading batch dimension on 3D latents and serve fp32.
            if len(v.shape) == 3:
                v = v.unsqueeze(0)
            data[name][k] = v.float()
    return data
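
# Minimal sketch of calling load_evals (names and paths are hypothetical):
#
#   evals = {
#       "sample_01": {"src": "evals/src/01.npy", "dst": "evals/dst/01.npy"},
#   }
#   eval_data = load_evals(evals)
#   eval_data["sample_01"]["src"]  # fp32 tensor with a batch dimension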