|
import os |
|
import sys |
|
import numpy as np |
|
from PIL import Image |
|
import torch |
|
import torchvision.transforms as transforms |
|
from argparse import Namespace |
|
from e4e.models.psp import pSp |
|
from util import * |
|
|
|
|
|
|
|
@ torch.no_grad() |
|
def projection(img, name, device='cuda'): |
|
|
|
|
|
model_path = 'e4e.pt' |
|
ckpt = torch.load(model_path, map_location='cpu') |
|
opts = ckpt['opts'] |
|
opts['checkpoint_path'] = model_path |
|
opts= Namespace(**opts) |
|
net = pSp(opts, device).eval().to(device) |
|
|
|
transform = transforms.Compose( |
|
[ |
|
transforms.Resize(256), |
|
transforms.CenterCrop(256), |
|
transforms.ToTensor(), |
|
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), |
|
] |
|
) |
|
|
|
img = transform(img).unsqueeze(0).to(device) |
|
images, w_plus = net(img, randomize_noise=False, return_latents=True) |
|
result_file = {} |
|
result_file['latent'] = w_plus[0] |
|
torch.save(result_file, name) |
|
return w_plus[0] |
|
|