SyncDreamer / generate.py
liuyuan-pal's picture
init
8bb8404
raw
history blame contribute delete
No virus
2.3 kB
import argparse
from pathlib import Path
import numpy as np
import torch
from omegaconf import OmegaConf
from skimage.io import imsave
from ldm.models.diffusion.sync_dreamer import SyncMultiviewDiffusion
from ldm.util import instantiate_from_config, prepare_inputs
def load_model(cfg, ckpt, strict=True):
    """Instantiate the configured model and populate it with checkpoint weights.

    Args:
        cfg: Path of a config file readable by ``OmegaConf.load``. Its ``model``
            entry is passed to ``instantiate_from_config``.
        ckpt: Path of a checkpoint readable by ``torch.load``. The loaded object
            must expose a ``'state_dict'`` entry.
        strict: Forwarded unchanged to ``load_state_dict``.

    Returns:
        The model after it has been moved to CUDA and switched to eval mode.
    """
    model_conf = OmegaConf.load(cfg).model
    network = instantiate_from_config(model_conf)

    print(f'loading model from {ckpt} ...')
    # Deserialize onto the CPU first; the model is moved to the GPU only after
    # the weights have been copied in.
    checkpoint = torch.load(ckpt, map_location='cpu')
    network.load_state_dict(checkpoint['state_dict'], strict=strict)

    return network.cuda().eval()
def main():
    """Command-line entry point: sample multiview images and save them as PNGs.

    Steps, in order:
      1. Parse the command-line options.
      2. Seed the torch and NumPy random number generators with ``--seed``.
      3. Load the model via ``load_model`` (strict state-dict loading).
      4. Create the output directory, including missing parents.
      5. Build the model inputs with ``prepare_inputs``, add a leading batch
         axis to every entry, move it to CUDA and repeat it ``--sample_num``
         times along that axis.
      6. Call ``model.sample`` and convert the result from the [-1, 1] range to
         ``uint8`` values.
      7. Write one ``<index>.png`` per batch entry into the output directory,
         with all views of that entry concatenated along axis 1.
    """
    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        ('--cfg', dict(type=str, default='configs/syncdreamer.yaml')),
        ('--ckpt', dict(type=str, default='ckpt/syncdreamer-step80k.ckpt')),
        ('--output', dict(type=str, required=True)),
        ('--input', dict(type=str, required=True)),
        ('--elevation', dict(type=float, required=True)),
        ('--sample_num', dict(type=int, default=4)),
        ('--crop_size', dict(type=int, default=-1)),
        ('--cfg_scale', dict(type=float, default=2.0)),
        ('--batch_view_num', dict(type=int, default=8)),
        ('--seed', dict(type=int, default=6033)),
    )
    arg_parser = argparse.ArgumentParser()
    for option, spec in option_table:
        arg_parser.add_argument(option, **spec)
    args = arg_parser.parse_args()

    # Seed before the model is built so everything downstream sees the same RNG state.
    torch.random.manual_seed(args.seed)
    np.random.seed(args.seed)

    model = load_model(args.cfg, args.ckpt, strict=True)
    assert isinstance(model, SyncMultiviewDiffusion)

    out_dir = Path(args.output)
    out_dir.mkdir(parents=True, exist_ok=True)

    # prepare data
    inputs = prepare_inputs(args.input, args.elevation, args.crop_size)
    for key, tensor in inputs.items():
        # Add a batch axis, move to the GPU, then duplicate once per requested sample.
        batched = tensor.unsqueeze(0).cuda()
        inputs[key] = torch.repeat_interleave(batched, args.sample_num, dim=0)

    images = model.sample(inputs, args.cfg_scale, args.batch_view_num)
    # The 5-way unpack also requires the sampler output to be 5-dimensional.
    num_samples, num_views, _, _, _ = images.shape

    # Clamp to [-1, 1], shift/scale to [0, 1], move axis 2 to the end, then
    # scale by 255; the uint8 cast truncates the fractional part.
    images = (torch.clamp(images, min=-1.0, max=1.0) + 1) * 0.5
    images = images.permute(0, 1, 3, 4, 2).cpu().numpy() * 255
    images = images.astype(np.uint8)

    for sample_idx in range(num_samples):
        # Join all views of this sample along axis 1 (side by side, assuming
        # each view is height x width x channels after the permute).
        views = [images[sample_idx, view_idx] for view_idx in range(num_views)]
        imsave(out_dir / f'{sample_idx}.png', np.concatenate(views, 1))
# Run the generation pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__=="__main__":
    main()