Spaces:
Build error
Build error
Update e4e/models/psp.py
Browse files- e4e/models/psp.py +3 -1
e4e/models/psp.py
CHANGED
@@ -40,9 +40,11 @@ class pSp(nn.Module):
|
|
40 |
def load_weights(self):
|
41 |
if self.opts.checkpoint_path is not None:
|
42 |
print('Loading e4e over the pSp framework from checkpoint: {}'.format(self.opts.checkpoint_path))
|
43 |
-
ckpt = torch.load(self.opts.checkpoint_path, map_location='
|
44 |
self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True)
|
|
|
45 |
self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True)
|
|
|
46 |
self.__load_latent_avg(ckpt)
|
47 |
else:
|
48 |
print('Loading encoders weights from irse50!')
|
|
|
40 |
def load_weights(self):
|
41 |
if self.opts.checkpoint_path is not None:
|
42 |
print('Loading e4e over the pSp framework from checkpoint: {}'.format(self.opts.checkpoint_path))
|
43 |
+
ckpt = torch.load(self.opts.checkpoint_path, map_location='cuda:0' if torch.cuda.is_available() else "cpu")
|
44 |
self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True)
|
45 |
+
self.encoder.to(self.device)
|
46 |
self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True)
|
47 |
+
self.decoder.to(self.device)
|
48 |
self.__load_latent_avg(ckpt)
|
49 |
else:
|
50 |
print('Loading encoders weights from irse50!')
|