Spaces:
Running
on
A10G
Running
on
A10G
import json
import os

import cv2
import numpy as np
from torch.utils.data import Dataset
class MyDataset(Dataset):
    """Dataset of (target image, text prompt, hint image) training samples.

    The index is a JSON-lines file: one JSON object per line, each with the
    keys ``source``, ``target`` and ``prompt``. ``source`` and ``target`` are
    image paths relative to ``data_root``.

    Args:
        data_root: Directory that contains the prompt file and the images.
            Defaults to the location used by the original hard-coded paths.
        prompt_file: Name of the JSON-lines index file inside ``data_root``.
    """

    def __init__(self, data_root='./training/fill50k/', prompt_file='prompt.json'):
        super().__init__()
        self.data_root = data_root
        self.data = []
        index_path = os.path.join(data_root, prompt_file)
        # JSON text is UTF-8; do not depend on the platform's locale encoding.
        with open(index_path, 'rt', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                # Tolerate blank lines (e.g. a trailing newline at end of
                # file) instead of crashing inside json.loads.
                if not line:
                    continue
                self.data.append(json.loads(line))

    def __len__(self):
        """Return the number of records read from the prompt file."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load and normalize the sample at position ``idx``.

        Returns:
            A dict with three entries:
            ``jpg``  - target image as float32 RGB, scaled to [-1, 1];
            ``txt``  - the prompt exactly as stored in the index;
            ``hint`` - source image as float32 RGB, scaled to [0, 1].

        Raises:
            IndexError: If ``idx`` is outside the index.
            KeyError: If the record lacks ``source``, ``target`` or ``prompt``.
            FileNotFoundError: If either image cannot be read.
        """
        item = self.data[idx]
        prompt = item['prompt']

        source = self._load_rgb(item['source'])
        target = self._load_rgb(item['target'])

        # Normalize source images to [0, 1].
        source = source.astype(np.float32) / 255.0

        # Normalize target images to [-1, 1].
        target = (target.astype(np.float32) / 127.5) - 1.0

        return dict(jpg=target, txt=prompt, hint=source)

    def _load_rgb(self, filename):
        """Read ``filename`` (relative to ``data_root``) as an RGB array."""
        path = os.path.join(self.data_root, filename)
        image = cv2.imread(path)
        # cv2.imread signals failure by returning None rather than raising;
        # fail here with the offending path instead of an opaque error
        # inside cvtColor.
        if image is None:
            raise FileNotFoundError(f'Could not read image: {path}')
        # Do not forget that OpenCV reads images in BGR order.
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)