advcloud commited on
Commit
9b2bdf6
1 Parent(s): d90c489

commit from $USER

Browse files
Files changed (48) hide show
  1. e4e/.gitignore +129 -0
  2. e4e/LICENSE +21 -0
  3. e4e/README.md +142 -0
  4. e4e/configs/__init__.py +0 -0
  5. e4e/configs/data_configs.py +41 -0
  6. e4e/configs/paths_config.py +28 -0
  7. e4e/configs/transforms_config.py +62 -0
  8. e4e/criteria/__init__.py +0 -0
  9. e4e/criteria/id_loss.py +47 -0
  10. e4e/criteria/lpips/__init__.py +0 -0
  11. e4e/criteria/lpips/lpips.py +35 -0
  12. e4e/criteria/lpips/networks.py +96 -0
  13. e4e/criteria/lpips/utils.py +30 -0
  14. e4e/criteria/moco_loss.py +71 -0
  15. e4e/criteria/w_norm.py +14 -0
  16. e4e/datasets/__init__.py +0 -0
  17. e4e/datasets/gt_res_dataset.py +32 -0
  18. e4e/datasets/images_dataset.py +33 -0
  19. e4e/datasets/inference_dataset.py +25 -0
  20. e4e/editings/ganspace.py +22 -0
  21. e4e/editings/latent_editor.py +45 -0
  22. e4e/editings/sefa.py +46 -0
  23. e4e/environment/e4e_env.yaml +73 -0
  24. e4e/metrics/LEC.py +134 -0
  25. e4e/models/__init__.py +0 -0
  26. e4e/models/discriminator.py +20 -0
  27. e4e/models/encoders/__init__.py +0 -0
  28. e4e/models/encoders/helpers.py +140 -0
  29. e4e/models/encoders/model_irse.py +84 -0
  30. e4e/models/encoders/psp_encoders.py +200 -0
  31. e4e/models/latent_codes_pool.py +55 -0
  32. e4e/models/psp.py +99 -0
  33. e4e/models/stylegan2/__init__.py +0 -0
  34. e4e/models/stylegan2/model.py +678 -0
  35. e4e/options/__init__.py +0 -0
  36. e4e/options/train_options.py +84 -0
  37. e4e/scripts/calc_losses_on_images.py +87 -0
  38. e4e/scripts/inference.py +133 -0
  39. e4e/scripts/train.py +88 -0
  40. e4e/training/__init__.py +0 -0
  41. e4e/training/coach.py +437 -0
  42. e4e/training/ranger.py +164 -0
  43. e4e/utils/__init__.py +0 -0
  44. e4e/utils/alignment.py +115 -0
  45. e4e/utils/common.py +55 -0
  46. e4e/utils/data_utils.py +25 -0
  47. e4e/utils/model_utils.py +35 -0
  48. e4e/utils/train_utils.py +13 -0
e4e/.gitignore ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .env
106
+ .venv
107
+ env/
108
+ venv/
109
+ ENV/
110
+ env.bak/
111
+ venv.bak/
112
+
113
+ # Spyder project settings
114
+ .spyderproject
115
+ .spyproject
116
+
117
+ # Rope project settings
118
+ .ropeproject
119
+
120
+ # mkdocs documentation
121
+ /site
122
+
123
+ # mypy
124
+ .mypy_cache/
125
+ .dmypy.json
126
+ dmypy.json
127
+
128
+ # Pyre type checker
129
+ .pyre/
e4e/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2021 omertov
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
e4e/README.md ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Designing an Encoder for StyleGAN Image Manipulation
2
+ <a href="https://arxiv.org/abs/2102.02766"><img src="https://img.shields.io/badge/arXiv-2008.00951-b31b1b.svg"></a>
3
+ <a href="https://opensource.org/licenses/MIT"><img src="https://img.shields.io/badge/License-MIT-yellow.svg"></a>
4
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://colab.research.google.com/github/omertov/encoder4editing/blob/main/notebooks/inference_playground.ipynb)
5
+
6
+ > Recently, there has been a surge of diverse methods for performing image editing by employing pre-trained unconditional generators. Applying these methods on real images, however, remains a challenge, as it necessarily requires the inversion of the images into their latent space. To successfully invert a real image, one needs to find a latent code that reconstructs the input image accurately, and more importantly, allows for its meaningful manipulation. In this paper, we carefully study the latent space of StyleGAN, the state-of-the-art unconditional generator. We identify and analyze the existence of a distortion-editability tradeoff and a distortion-perception tradeoff within the StyleGAN latent space. We then suggest two principles for designing encoders in a manner that allows one to control the proximity of the inversions to regions that StyleGAN was originally trained on. We present an encoder based on our two principles that is specifically designed for facilitating editing on real images by balancing these tradeoffs. By evaluating its performance qualitatively and quantitatively on numerous challenging domains, including cars and horses, we show that our inversion method, followed by common editing techniques, achieves superior real-image editing quality, with only a small reconstruction accuracy drop.
7
+
8
+ <p align="center">
9
+ <img src="docs/teaser.jpg" width="800px"/>
10
+ </p>
11
+
12
+ ## Description
13
+ Official Implementation of "<a href="https://arxiv.org/abs/2102.02766">Designing an Encoder for StyleGAN Image Manipulation</a>" paper for both training and evaluation.
14
+ The e4e encoder is specifically designed to complement existing image manipulation techniques performed over StyleGAN's latent space.
15
+
16
+ ## Recent Updates
17
+ `2021.03.25`: Add pose editing direction.
18
+
19
+ ## Getting Started
20
+ ### Prerequisites
21
+ - Linux or macOS
22
+ - NVIDIA GPU + CUDA CuDNN (CPU may be possible with some modifications, but is not inherently supported)
23
+ - Python 3
24
+
25
+ ### Installation
26
+ - Clone the repository:
27
+ ```
28
+ git clone https://github.com/omertov/encoder4editing.git
29
+ cd encoder4editing
30
+ ```
31
+ - Dependencies:
32
+ We recommend running this repository using [Anaconda](https://docs.anaconda.com/anaconda/install/).
33
+ All dependencies for defining the environment are provided in `environment/e4e_env.yaml`.
34
+
35
+ ### Inference Notebook
36
+ We provide a Jupyter notebook found in `notebooks/inference_playground.ipynb` that allows one to encode and perform several editings on real images using StyleGAN.
37
+
38
+ ### Pretrained Models
39
+ Please download the pre-trained models from the following links. Each e4e model contains the entire pSp framework architecture, including the encoder and decoder weights.
40
+ | Path | Description
41
+ | :--- | :----------
42
+ |[FFHQ Inversion](https://drive.google.com/file/d/1cUv_reLE6k3604or78EranS7XzuVMWeO/view?usp=sharing) | FFHQ e4e encoder.
43
+ |[Cars Inversion](https://drive.google.com/file/d/17faPqBce2m1AQeLCLHUVXaDfxMRU2QcV/view?usp=sharing) | Cars e4e encoder.
44
+ |[Horse Inversion](https://drive.google.com/file/d/1TkLLnuX86B_BMo2ocYD0kX9kWh53rUVX/view?usp=sharing) | Horse e4e encoder.
45
+ |[Church Inversion](https://drive.google.com/file/d/1-L0ZdnQLwtdy6-A_Ccgq5uNJGTqE7qBa/view?usp=sharing) | Church e4e encoder.
46
+
47
+ If you wish to use one of the pretrained models for training or inference, you may do so using the flag `--checkpoint_path`.
48
+
49
+ In addition, we provide various auxiliary models needed for training your own e4e model from scratch.
50
+ | Path | Description
51
+ | :--- | :----------
52
+ |[FFHQ StyleGAN](https://drive.google.com/file/d/1EM87UquaoQmk17Q8d5kYIAHqu0dkYqdT/view?usp=sharing) | StyleGAN model pretrained on FFHQ taken from [rosinality](https://github.com/rosinality/stylegan2-pytorch) with 1024x1024 output resolution.
53
+ |[IR-SE50 Model](https://drive.google.com/file/d/1KW7bjndL3QG3sxBbZxreGHigcCCpsDgn/view?usp=sharing) | Pretrained IR-SE50 model taken from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) for use in our ID loss during training.
54
+ |[MOCOv2 Model](https://drive.google.com/file/d/18rLcNGdteX5LwT7sv_F7HWr12HpVEzVe/view?usp=sharing) | Pretrained ResNet-50 model trained using MOCOv2 for use in our simmilarity loss for domains other then human faces during training.
55
+
56
+ By default, we assume that all auxiliary models are downloaded and saved to the directory `pretrained_models`. However, you may use your own paths by changing the necessary values in `configs/path_configs.py`.
57
+
58
+ ## Training
59
+ To train the e4e encoder, make sure the paths to the required models, as well as training and testing data is configured in `configs/path_configs.py` and `configs/data_configs.py`.
60
+ #### **Training the e4e Encoder**
61
+ ```
62
+ python scripts/train.py \
63
+ --dataset_type cars_encode \
64
+ --exp_dir new/experiment/directory \
65
+ --start_from_latent_avg \
66
+ --use_w_pool \
67
+ --w_discriminator_lambda 0.1 \
68
+ --progressive_start 20000 \
69
+ --id_lambda 0.5 \
70
+ --val_interval 10000 \
71
+ --max_steps 200000 \
72
+ --stylegan_size 512 \
73
+ --stylegan_weights path/to/pretrained/stylegan.pt \
74
+ --workers 8 \
75
+ --batch_size 8 \
76
+ --test_batch_size 4 \
77
+ --test_workers 4
78
+ ```
79
+
80
+ #### Training on your own dataset
81
+ In order to train the e4e encoder on a custom dataset, perform the following adjustments:
82
+ 1. Insert the paths to your train and test data into the `dataset_paths` variable defined in `configs/paths_config.py`:
83
+ ```
84
+ dataset_paths = {
85
+ 'my_train_data': '/path/to/train/images/directory',
86
+ 'my_test_data': '/path/to/test/images/directory'
87
+ }
88
+ ```
89
+ 2. Configure a new dataset under the DATASETS variable defined in `configs/data_configs.py`:
90
+ ```
91
+ DATASETS = {
92
+ 'my_data_encode': {
93
+ 'transforms': transforms_config.EncodeTransforms,
94
+ 'train_source_root': dataset_paths['my_train_data'],
95
+ 'train_target_root': dataset_paths['my_train_data'],
96
+ 'test_source_root': dataset_paths['my_test_data'],
97
+ 'test_target_root': dataset_paths['my_test_data']
98
+ }
99
+ }
100
+ ```
101
+ Refer to `configs/transforms_config.py` for the transformations applied to the train and test images during training.
102
+
103
+ 3. Finally, run a training session with `--dataset_type my_data_encode`.
104
+
105
+ ## Inference
106
+ Having trained your model, you can use `scripts/inference.py` to apply the model on a set of images.
107
+ For example,
108
+ ```
109
+ python scripts/inference.py \
110
+ --images_dir=/path/to/images/directory \
111
+ --save_dir=/path/to/saving/directory \
112
+ path/to/checkpoint.pt
113
+ ```
114
+
115
+ ## Latent Editing Consistency (LEC)
116
+ As described in the paper, we suggest a new metric, Latent Editing Consistency (LEC), for evaluating the encoder's
117
+ performance.
118
+ We provide an example for calculating the metric over the FFHQ StyleGAN using the aging editing direction in
119
+ `metrics/LEC.py`.
120
+
121
+ To run the example:
122
+ ```
123
+ cd metrics
124
+ python LEC.py \
125
+ --images_dir=/path/to/images/directory \
126
+ path/to/checkpoint.pt
127
+ ```
128
+
129
+ ## Acknowledgments
130
+ This code borrows heavily from [pixel2style2pixel](https://github.com/eladrich/pixel2style2pixel)
131
+
132
+ ## Citation
133
+ If you use this code for your research, please cite our paper <a href="https://arxiv.org/abs/2102.02766">Designing an Encoder for StyleGAN Image Manipulation</a>:
134
+
135
+ ```
136
+ @article{tov2021designing,
137
+ title={Designing an Encoder for StyleGAN Image Manipulation},
138
+ author={Tov, Omer and Alaluf, Yuval and Nitzan, Yotam and Patashnik, Or and Cohen-Or, Daniel},
139
+ journal={arXiv preprint arXiv:2102.02766},
140
+ year={2021}
141
+ }
142
+ ```
e4e/configs/__init__.py ADDED
File without changes
e4e/configs/data_configs.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from configs import transforms_config
2
+ from configs.paths_config import dataset_paths
3
+
4
+
5
+ DATASETS = {
6
+ 'ffhq_encode': {
7
+ 'transforms': transforms_config.EncodeTransforms,
8
+ 'train_source_root': dataset_paths['ffhq'],
9
+ 'train_target_root': dataset_paths['ffhq'],
10
+ 'test_source_root': dataset_paths['celeba_test'],
11
+ 'test_target_root': dataset_paths['celeba_test'],
12
+ },
13
+ 'cars_encode': {
14
+ 'transforms': transforms_config.CarsEncodeTransforms,
15
+ 'train_source_root': dataset_paths['cars_train'],
16
+ 'train_target_root': dataset_paths['cars_train'],
17
+ 'test_source_root': dataset_paths['cars_test'],
18
+ 'test_target_root': dataset_paths['cars_test'],
19
+ },
20
+ 'horse_encode': {
21
+ 'transforms': transforms_config.EncodeTransforms,
22
+ 'train_source_root': dataset_paths['horse_train'],
23
+ 'train_target_root': dataset_paths['horse_train'],
24
+ 'test_source_root': dataset_paths['horse_test'],
25
+ 'test_target_root': dataset_paths['horse_test'],
26
+ },
27
+ 'church_encode': {
28
+ 'transforms': transforms_config.EncodeTransforms,
29
+ 'train_source_root': dataset_paths['church_train'],
30
+ 'train_target_root': dataset_paths['church_train'],
31
+ 'test_source_root': dataset_paths['church_test'],
32
+ 'test_target_root': dataset_paths['church_test'],
33
+ },
34
+ 'cats_encode': {
35
+ 'transforms': transforms_config.EncodeTransforms,
36
+ 'train_source_root': dataset_paths['cats_train'],
37
+ 'train_target_root': dataset_paths['cats_train'],
38
+ 'test_source_root': dataset_paths['cats_test'],
39
+ 'test_target_root': dataset_paths['cats_test'],
40
+ }
41
+ }
e4e/configs/paths_config.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset_paths = {
2
+ # Face Datasets (In the paper: FFHQ - train, CelebAHQ - test)
3
+ 'ffhq': '',
4
+ 'celeba_test': '',
5
+
6
+ # Cars Dataset (In the paper: Stanford cars)
7
+ 'cars_train': '',
8
+ 'cars_test': '',
9
+
10
+ # Horse Dataset (In the paper: LSUN Horse)
11
+ 'horse_train': '',
12
+ 'horse_test': '',
13
+
14
+ # Church Dataset (In the paper: LSUN Church)
15
+ 'church_train': '',
16
+ 'church_test': '',
17
+
18
+ # Cats Dataset (In the paper: LSUN Cat)
19
+ 'cats_train': '',
20
+ 'cats_test': ''
21
+ }
22
+
23
+ model_paths = {
24
+ 'stylegan_ffhq': 'pretrained_models/stylegan2-ffhq-config-f.pt',
25
+ 'ir_se50': 'pretrained_models/model_ir_se50.pth',
26
+ 'shape_predictor': 'pretrained_models/shape_predictor_68_face_landmarks.dat',
27
+ 'moco': 'pretrained_models/moco_v2_800ep_pretrain.pth'
28
+ }
e4e/configs/transforms_config.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import abstractmethod
2
+ import torchvision.transforms as transforms
3
+
4
+
5
+ class TransformsConfig(object):
6
+
7
+ def __init__(self, opts):
8
+ self.opts = opts
9
+
10
+ @abstractmethod
11
+ def get_transforms(self):
12
+ pass
13
+
14
+
15
+ class EncodeTransforms(TransformsConfig):
16
+
17
+ def __init__(self, opts):
18
+ super(EncodeTransforms, self).__init__(opts)
19
+
20
+ def get_transforms(self):
21
+ transforms_dict = {
22
+ 'transform_gt_train': transforms.Compose([
23
+ transforms.Resize((256, 256)),
24
+ transforms.RandomHorizontalFlip(0.5),
25
+ transforms.ToTensor(),
26
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
27
+ 'transform_source': None,
28
+ 'transform_test': transforms.Compose([
29
+ transforms.Resize((256, 256)),
30
+ transforms.ToTensor(),
31
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
32
+ 'transform_inference': transforms.Compose([
33
+ transforms.Resize((256, 256)),
34
+ transforms.ToTensor(),
35
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
36
+ }
37
+ return transforms_dict
38
+
39
+
40
+ class CarsEncodeTransforms(TransformsConfig):
41
+
42
+ def __init__(self, opts):
43
+ super(CarsEncodeTransforms, self).__init__(opts)
44
+
45
+ def get_transforms(self):
46
+ transforms_dict = {
47
+ 'transform_gt_train': transforms.Compose([
48
+ transforms.Resize((192, 256)),
49
+ transforms.RandomHorizontalFlip(0.5),
50
+ transforms.ToTensor(),
51
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
52
+ 'transform_source': None,
53
+ 'transform_test': transforms.Compose([
54
+ transforms.Resize((192, 256)),
55
+ transforms.ToTensor(),
56
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
57
+ 'transform_inference': transforms.Compose([
58
+ transforms.Resize((192, 256)),
59
+ transforms.ToTensor(),
60
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
61
+ }
62
+ return transforms_dict
e4e/criteria/__init__.py ADDED
File without changes
e4e/criteria/id_loss.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from configs.paths_config import model_paths
4
+ from models.encoders.model_irse import Backbone
5
+
6
+
7
+ class IDLoss(nn.Module):
8
+ def __init__(self):
9
+ super(IDLoss, self).__init__()
10
+ print('Loading ResNet ArcFace')
11
+ self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
12
+ self.facenet.load_state_dict(torch.load(model_paths['ir_se50']))
13
+ self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
14
+ self.facenet.eval()
15
+ for module in [self.facenet, self.face_pool]:
16
+ for param in module.parameters():
17
+ param.requires_grad = False
18
+
19
+ def extract_feats(self, x):
20
+ x = x[:, :, 35:223, 32:220] # Crop interesting region
21
+ x = self.face_pool(x)
22
+ x_feats = self.facenet(x)
23
+ return x_feats
24
+
25
+ def forward(self, y_hat, y, x):
26
+ n_samples = x.shape[0]
27
+ x_feats = self.extract_feats(x)
28
+ y_feats = self.extract_feats(y) # Otherwise use the feature from there
29
+ y_hat_feats = self.extract_feats(y_hat)
30
+ y_feats = y_feats.detach()
31
+ loss = 0
32
+ sim_improvement = 0
33
+ id_logs = []
34
+ count = 0
35
+ for i in range(n_samples):
36
+ diff_target = y_hat_feats[i].dot(y_feats[i])
37
+ diff_input = y_hat_feats[i].dot(x_feats[i])
38
+ diff_views = y_feats[i].dot(x_feats[i])
39
+ id_logs.append({'diff_target': float(diff_target),
40
+ 'diff_input': float(diff_input),
41
+ 'diff_views': float(diff_views)})
42
+ loss += 1 - diff_target
43
+ id_diff = float(diff_target) - float(diff_views)
44
+ sim_improvement += id_diff
45
+ count += 1
46
+
47
+ return loss / count, sim_improvement / count, id_logs
e4e/criteria/lpips/__init__.py ADDED
File without changes
e4e/criteria/lpips/lpips.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from criteria.lpips.networks import get_network, LinLayers
5
+ from criteria.lpips.utils import get_state_dict
6
+
7
+
8
+ class LPIPS(nn.Module):
9
+ r"""Creates a criterion that measures
10
+ Learned Perceptual Image Patch Similarity (LPIPS).
11
+ Arguments:
12
+ net_type (str): the network type to compare the features:
13
+ 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
14
+ version (str): the version of LPIPS. Default: 0.1.
15
+ """
16
+ def __init__(self, net_type: str = 'alex', version: str = '0.1'):
17
+
18
+ assert version in ['0.1'], 'v0.1 is only supported now'
19
+
20
+ super(LPIPS, self).__init__()
21
+
22
+ # pretrained network
23
+ self.net = get_network(net_type).to("cuda")
24
+
25
+ # linear layers
26
+ self.lin = LinLayers(self.net.n_channels_list).to("cuda")
27
+ self.lin.load_state_dict(get_state_dict(net_type, version))
28
+
29
+ def forward(self, x: torch.Tensor, y: torch.Tensor):
30
+ feat_x, feat_y = self.net(x), self.net(y)
31
+
32
+ diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
33
+ res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
34
+
35
+ return torch.sum(torch.cat(res, 0)) / x.shape[0]
e4e/criteria/lpips/networks.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Sequence
2
+
3
+ from itertools import chain
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torchvision import models
8
+
9
+ from criteria.lpips.utils import normalize_activation
10
+
11
+
12
+ def get_network(net_type: str):
13
+ if net_type == 'alex':
14
+ return AlexNet()
15
+ elif net_type == 'squeeze':
16
+ return SqueezeNet()
17
+ elif net_type == 'vgg':
18
+ return VGG16()
19
+ else:
20
+ raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
21
+
22
+
23
+ class LinLayers(nn.ModuleList):
24
+ def __init__(self, n_channels_list: Sequence[int]):
25
+ super(LinLayers, self).__init__([
26
+ nn.Sequential(
27
+ nn.Identity(),
28
+ nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
29
+ ) for nc in n_channels_list
30
+ ])
31
+
32
+ for param in self.parameters():
33
+ param.requires_grad = False
34
+
35
+
36
+ class BaseNet(nn.Module):
37
+ def __init__(self):
38
+ super(BaseNet, self).__init__()
39
+
40
+ # register buffer
41
+ self.register_buffer(
42
+ 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
43
+ self.register_buffer(
44
+ 'std', torch.Tensor([.458, .448, .450])[None, :, None, None])
45
+
46
+ def set_requires_grad(self, state: bool):
47
+ for param in chain(self.parameters(), self.buffers()):
48
+ param.requires_grad = state
49
+
50
+ def z_score(self, x: torch.Tensor):
51
+ return (x - self.mean) / self.std
52
+
53
+ def forward(self, x: torch.Tensor):
54
+ x = self.z_score(x)
55
+
56
+ output = []
57
+ for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
58
+ x = layer(x)
59
+ if i in self.target_layers:
60
+ output.append(normalize_activation(x))
61
+ if len(output) == len(self.target_layers):
62
+ break
63
+ return output
64
+
65
+
66
+ class SqueezeNet(BaseNet):
67
+ def __init__(self):
68
+ super(SqueezeNet, self).__init__()
69
+
70
+ self.layers = models.squeezenet1_1(True).features
71
+ self.target_layers = [2, 5, 8, 10, 11, 12, 13]
72
+ self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
73
+
74
+ self.set_requires_grad(False)
75
+
76
+
77
+ class AlexNet(BaseNet):
78
+ def __init__(self):
79
+ super(AlexNet, self).__init__()
80
+
81
+ self.layers = models.alexnet(True).features
82
+ self.target_layers = [2, 5, 8, 10, 12]
83
+ self.n_channels_list = [64, 192, 384, 256, 256]
84
+
85
+ self.set_requires_grad(False)
86
+
87
+
88
+ class VGG16(BaseNet):
89
+ def __init__(self):
90
+ super(VGG16, self).__init__()
91
+
92
+ self.layers = models.vgg16(True).features
93
+ self.target_layers = [4, 9, 16, 23, 30]
94
+ self.n_channels_list = [64, 128, 256, 512, 512]
95
+
96
+ self.set_requires_grad(False)
e4e/criteria/lpips/utils.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+
3
+ import torch
4
+
5
+
6
+ def normalize_activation(x, eps=1e-10):
7
+ norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
8
+ return x / (norm_factor + eps)
9
+
10
+
11
+ def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
12
+ # build url
13
+ url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
14
+ + f'master/lpips/weights/v{version}/{net_type}.pth'
15
+
16
+ # download
17
+ old_state_dict = torch.hub.load_state_dict_from_url(
18
+ url, progress=True,
19
+ map_location=None if torch.cuda.is_available() else torch.device('cpu')
20
+ )
21
+
22
+ # rename keys
23
+ new_state_dict = OrderedDict()
24
+ for key, val in old_state_dict.items():
25
+ new_key = key
26
+ new_key = new_key.replace('lin', '')
27
+ new_key = new_key.replace('model.', '')
28
+ new_state_dict[new_key] = val
29
+
30
+ return new_state_dict
e4e/criteria/moco_loss.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+
5
+ from configs.paths_config import model_paths
6
+
7
+
8
+ class MocoLoss(nn.Module):
9
+
10
+ def __init__(self, opts):
11
+ super(MocoLoss, self).__init__()
12
+ print("Loading MOCO model from path: {}".format(model_paths["moco"]))
13
+ self.model = self.__load_model()
14
+ self.model.eval()
15
+ for param in self.model.parameters():
16
+ param.requires_grad = False
17
+
18
+ @staticmethod
19
+ def __load_model():
20
+ import torchvision.models as models
21
+ model = models.__dict__["resnet50"]()
22
+ # freeze all layers but the last fc
23
+ for name, param in model.named_parameters():
24
+ if name not in ['fc.weight', 'fc.bias']:
25
+ param.requires_grad = False
26
+ checkpoint = torch.load(model_paths['moco'], map_location="cpu")
27
+ state_dict = checkpoint['state_dict']
28
+ # rename moco pre-trained keys
29
+ for k in list(state_dict.keys()):
30
+ # retain only encoder_q up to before the embedding layer
31
+ if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
32
+ # remove prefix
33
+ state_dict[k[len("module.encoder_q."):]] = state_dict[k]
34
+ # delete renamed or unused k
35
+ del state_dict[k]
36
+ msg = model.load_state_dict(state_dict, strict=False)
37
+ assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
38
+ # remove output layer
39
+ model = nn.Sequential(*list(model.children())[:-1]).cuda()
40
+ return model
41
+
42
+ def extract_feats(self, x):
43
+ x = F.interpolate(x, size=224)
44
+ x_feats = self.model(x)
45
+ x_feats = nn.functional.normalize(x_feats, dim=1)
46
+ x_feats = x_feats.squeeze()
47
+ return x_feats
48
+
49
+ def forward(self, y_hat, y, x):
50
+ n_samples = x.shape[0]
51
+ x_feats = self.extract_feats(x)
52
+ y_feats = self.extract_feats(y)
53
+ y_hat_feats = self.extract_feats(y_hat)
54
+ y_feats = y_feats.detach()
55
+ loss = 0
56
+ sim_improvement = 0
57
+ sim_logs = []
58
+ count = 0
59
+ for i in range(n_samples):
60
+ diff_target = y_hat_feats[i].dot(y_feats[i])
61
+ diff_input = y_hat_feats[i].dot(x_feats[i])
62
+ diff_views = y_feats[i].dot(x_feats[i])
63
+ sim_logs.append({'diff_target': float(diff_target),
64
+ 'diff_input': float(diff_input),
65
+ 'diff_views': float(diff_views)})
66
+ loss += 1 - diff_target
67
+ sim_diff = float(diff_target) - float(diff_views)
68
+ sim_improvement += sim_diff
69
+ count += 1
70
+
71
+ return loss / count, sim_improvement / count, sim_logs
e4e/criteria/w_norm.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class WNormLoss(nn.Module):
6
+
7
+ def __init__(self, start_from_latent_avg=True):
8
+ super(WNormLoss, self).__init__()
9
+ self.start_from_latent_avg = start_from_latent_avg
10
+
11
+ def forward(self, latent, latent_avg=None):
12
+ if self.start_from_latent_avg:
13
+ latent = latent - latent_avg
14
+ return torch.sum(latent.norm(2, dim=(1, 2))) / latent.shape[0]
e4e/datasets/__init__.py ADDED
File without changes
e4e/datasets/gt_res_dataset.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ # encoding: utf-8
3
+ import os
4
+ from torch.utils.data import Dataset
5
+ from PIL import Image
6
+ import torch
7
+
8
+ class GTResDataset(Dataset):
9
+
10
+ def __init__(self, root_path, gt_dir=None, transform=None, transform_train=None):
11
+ self.pairs = []
12
+ for f in os.listdir(root_path):
13
+ image_path = os.path.join(root_path, f)
14
+ gt_path = os.path.join(gt_dir, f)
15
+ if f.endswith(".jpg") or f.endswith(".png"):
16
+ self.pairs.append([image_path, gt_path.replace('.png', '.jpg'), None])
17
+ self.transform = transform
18
+ self.transform_train = transform_train
19
+
20
+ def __len__(self):
21
+ return len(self.pairs)
22
+
23
+ def __getitem__(self, index):
24
+ from_path, to_path, _ = self.pairs[index]
25
+ from_im = Image.open(from_path).convert('RGB')
26
+ to_im = Image.open(to_path).convert('RGB')
27
+
28
+ if self.transform:
29
+ to_im = self.transform(to_im)
30
+ from_im = self.transform(from_im)
31
+
32
+ return from_im, to_im
e4e/datasets/images_dataset.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import Dataset
2
+ from PIL import Image
3
+ from utils import data_utils
4
+
5
+
6
+ class ImagesDataset(Dataset):
7
+
8
+ def __init__(self, source_root, target_root, opts, target_transform=None, source_transform=None):
9
+ self.source_paths = sorted(data_utils.make_dataset(source_root))
10
+ self.target_paths = sorted(data_utils.make_dataset(target_root))
11
+ self.source_transform = source_transform
12
+ self.target_transform = target_transform
13
+ self.opts = opts
14
+
15
+ def __len__(self):
16
+ return len(self.source_paths)
17
+
18
+ def __getitem__(self, index):
19
+ from_path = self.source_paths[index]
20
+ from_im = Image.open(from_path)
21
+ from_im = from_im.convert('RGB')
22
+
23
+ to_path = self.target_paths[index]
24
+ to_im = Image.open(to_path).convert('RGB')
25
+ if self.target_transform:
26
+ to_im = self.target_transform(to_im)
27
+
28
+ if self.source_transform:
29
+ from_im = self.source_transform(from_im)
30
+ else:
31
+ from_im = to_im
32
+
33
+ return from_im, to_im
e4e/datasets/inference_dataset.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import Dataset
2
+ from PIL import Image
3
+ from utils import data_utils
4
+
5
+
6
+ class InferenceDataset(Dataset):
7
+
8
+ def __init__(self, root, opts, transform=None, preprocess=None):
9
+ self.paths = sorted(data_utils.make_dataset(root))
10
+ self.transform = transform
11
+ self.preprocess = preprocess
12
+ self.opts = opts
13
+
14
+ def __len__(self):
15
+ return len(self.paths)
16
+
17
+ def __getitem__(self, index):
18
+ from_path = self.paths[index]
19
+ if self.preprocess is not None:
20
+ from_im = self.preprocess(from_path)
21
+ else:
22
+ from_im = Image.open(from_path).convert('RGB')
23
+ if self.transform:
24
+ from_im = self.transform(from_im)
25
+ return from_im
e4e/editings/ganspace.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ def edit(latents, pca, edit_directions):
5
+ edit_latents = []
6
+ for latent in latents:
7
+ for pca_idx, start, end, strength in edit_directions:
8
+ delta = get_delta(pca, latent, pca_idx, strength)
9
+ delta_padded = torch.zeros(latent.shape).to('cuda')
10
+ delta_padded[start:end] += delta.repeat(end - start, 1)
11
+ edit_latents.append(latent + delta_padded)
12
+ return torch.stack(edit_latents)
13
+
14
+
15
+ def get_delta(pca, latent, idx, strength):
16
+ # pca: ganspace checkpoint. latent: (16, 512) w+
17
+ w_centered = latent - pca['mean'].to('cuda')
18
+ lat_comp = pca['comp'].to('cuda')
19
+ lat_std = pca['std'].to('cuda')
20
+ w_coord = torch.sum(w_centered[0].reshape(-1)*lat_comp[idx].reshape(-1)) / lat_std[idx]
21
+ delta = (strength - w_coord)*lat_comp[idx]*lat_std[idx]
22
+ return delta
e4e/editings/latent_editor.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import sys
3
+ sys.path.append(".")
4
+ sys.path.append("..")
5
+ from editings import ganspace, sefa
6
+ from utils.common import tensor2im
7
+
8
+
9
+ class LatentEditor(object):
10
+ def __init__(self, stylegan_generator, is_cars=False):
11
+ self.generator = stylegan_generator
12
+ self.is_cars = is_cars # Since the cars StyleGAN output is 384x512, there is a need to crop the 512x512 output.
13
+
14
+ def apply_ganspace(self, latent, ganspace_pca, edit_directions):
15
+ edit_latents = ganspace.edit(latent, ganspace_pca, edit_directions)
16
+ return self._latents_to_image(edit_latents)
17
+
18
+ def apply_interfacegan(self, latent, direction, factor=1, factor_range=None):
19
+ edit_latents = []
20
+ if factor_range is not None: # Apply a range of editing factors. for example, (-5, 5)
21
+ for f in range(*factor_range):
22
+ edit_latent = latent + f * direction
23
+ edit_latents.append(edit_latent)
24
+ edit_latents = torch.cat(edit_latents)
25
+ else:
26
+ edit_latents = latent + factor * direction
27
+ return self._latents_to_image(edit_latents)
28
+
29
+ def apply_sefa(self, latent, indices=[2, 3, 4, 5], **kwargs):
30
+ edit_latents = sefa.edit(self.generator, latent, indices, **kwargs)
31
+ return self._latents_to_image(edit_latents)
32
+
33
+ # Currently, in order to apply StyleFlow editings, one should run inference,
34
+ # save the latent codes and load them form the official StyleFlow repository.
35
+ # def apply_styleflow(self):
36
+ # pass
37
+
38
+ def _latents_to_image(self, latents):
39
+ with torch.no_grad():
40
+ images, _ = self.generator([latents], randomize_noise=False, input_is_latent=True)
41
+ if self.is_cars:
42
+ images = images[:, :, 64:448, :] # 512x512 -> 384x512
43
+ horizontal_concat_image = torch.cat(list(images), 2)
44
+ final_image = tensor2im(horizontal_concat_image)
45
+ return final_image
e4e/editings/sefa.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from tqdm import tqdm
4
+
5
+
6
+ def edit(generator, latents, indices, semantics=1, start_distance=-15.0, end_distance=15.0, num_samples=1, step=11):
7
+
8
+ layers, boundaries, values = factorize_weight(generator, indices)
9
+ codes = latents.detach().cpu().numpy() # (1,18,512)
10
+
11
+ # Generate visualization pages.
12
+ distances = np.linspace(start_distance, end_distance, step)
13
+ num_sam = num_samples
14
+ num_sem = semantics
15
+
16
+ edited_latents = []
17
+ for sem_id in tqdm(range(num_sem), desc='Semantic ', leave=False):
18
+ boundary = boundaries[sem_id:sem_id + 1]
19
+ for sam_id in tqdm(range(num_sam), desc='Sample ', leave=False):
20
+ code = codes[sam_id:sam_id + 1]
21
+ for col_id, d in enumerate(distances, start=1):
22
+ temp_code = code.copy()
23
+ temp_code[:, layers, :] += boundary * d
24
+ edited_latents.append(torch.from_numpy(temp_code).float().cuda())
25
+ return torch.cat(edited_latents)
26
+
27
+
28
+ def factorize_weight(g_ema, layers='all'):
29
+
30
+ weights = []
31
+ if layers == 'all' or 0 in layers:
32
+ weight = g_ema.conv1.conv.modulation.weight.T
33
+ weights.append(weight.cpu().detach().numpy())
34
+
35
+ if layers == 'all':
36
+ layers = list(range(g_ema.num_layers - 1))
37
+ else:
38
+ layers = [l - 1 for l in layers if l != 0]
39
+
40
+ for idx in layers:
41
+ weight = g_ema.convs[idx].conv.modulation.weight.T
42
+ weights.append(weight.cpu().detach().numpy())
43
+ weight = np.concatenate(weights, axis=1).astype(np.float32)
44
+ weight = weight / np.linalg.norm(weight, axis=0, keepdims=True)
45
+ eigen_values, eigen_vectors = np.linalg.eig(weight.dot(weight.T))
46
+ return layers, eigen_vectors.T, eigen_values
e4e/environment/e4e_env.yaml ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: e4e_env
2
+ channels:
3
+ - conda-forge
4
+ - defaults
5
+ dependencies:
6
+ - _libgcc_mutex=0.1=main
7
+ - ca-certificates=2020.4.5.1=hecc5488_0
8
+ - certifi=2020.4.5.1=py36h9f0ad1d_0
9
+ - libedit=3.1.20181209=hc058e9b_0
10
+ - libffi=3.2.1=hd88cf55_4
11
+ - libgcc-ng=9.1.0=hdf63c60_0
12
+ - libstdcxx-ng=9.1.0=hdf63c60_0
13
+ - ncurses=6.2=he6710b0_1
14
+ - ninja=1.10.0=hc9558a2_0
15
+ - openssl=1.1.1g=h516909a_0
16
+ - pip=20.0.2=py36_3
17
+ - python=3.6.7=h0371630_0
18
+ - python_abi=3.6=1_cp36m
19
+ - readline=7.0=h7b6447c_5
20
+ - setuptools=46.4.0=py36_0
21
+ - sqlite=3.31.1=h62c20be_1
22
+ - tk=8.6.8=hbc83047_0
23
+ - wheel=0.34.2=py36_0
24
+ - xz=5.2.5=h7b6447c_0
25
+ - zlib=1.2.11=h7b6447c_3
26
+ - pip:
27
+ - absl-py==0.9.0
28
+ - cachetools==4.1.0
29
+ - chardet==3.0.4
30
+ - cycler==0.10.0
31
+ - decorator==4.4.2
32
+ - future==0.18.2
33
+ - google-auth==1.15.0
34
+ - google-auth-oauthlib==0.4.1
35
+ - grpcio==1.29.0
36
+ - idna==2.9
37
+ - imageio==2.8.0
38
+ - importlib-metadata==1.6.0
39
+ - kiwisolver==1.2.0
40
+ - markdown==3.2.2
41
+ - matplotlib==3.2.1
42
+ - mxnet==1.6.0
43
+ - networkx==2.4
44
+ - numpy==1.18.4
45
+ - oauthlib==3.1.0
46
+ - opencv-python==4.2.0.34
47
+ - pillow==7.1.2
48
+ - protobuf==3.12.1
49
+ - pyasn1==0.4.8
50
+ - pyasn1-modules==0.2.8
51
+ - pyparsing==2.4.7
52
+ - python-dateutil==2.8.1
53
+ - pytorch-lightning==0.7.1
54
+ - pywavelets==1.1.1
55
+ - requests==2.23.0
56
+ - requests-oauthlib==1.3.0
57
+ - rsa==4.0
58
+ - scikit-image==0.17.2
59
+ - scipy==1.4.1
60
+ - six==1.15.0
61
+ - tensorboard==2.2.1
62
+ - tensorboard-plugin-wit==1.6.0.post3
63
+ - tensorboardx==1.9
64
+ - tifffile==2020.5.25
65
+ - torch==1.6.0
66
+ - torchvision==0.7.1
67
+ - tqdm==4.46.0
68
+ - urllib3==1.25.9
69
+ - werkzeug==1.0.1
70
+ - zipp==3.1.0
71
+ - pyaml
72
+ prefix: ~/anaconda3/envs/e4e_env
73
+
e4e/metrics/LEC.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import argparse
3
+ import torch
4
+ import numpy as np
5
+ from torch.utils.data import DataLoader
6
+
7
+ sys.path.append(".")
8
+ sys.path.append("..")
9
+
10
+ from configs import data_configs
11
+ from datasets.images_dataset import ImagesDataset
12
+ from utils.model_utils import setup_model
13
+
14
+
15
+ class LEC:
16
+ def __init__(self, net, is_cars=False):
17
+ """
18
+ Latent Editing Consistency metric as proposed in the main paper.
19
+ :param net: e4e model loaded over the pSp framework.
20
+ :param is_cars: An indication as to whether or not to crop the middle of the StyleGAN's output images.
21
+ """
22
+ self.net = net
23
+ self.is_cars = is_cars
24
+
25
+ def _encode(self, images):
26
+ """
27
+ Encodes the given images into StyleGAN's latent space.
28
+ :param images: Tensor of shape NxCxHxW representing the images to be encoded.
29
+ :return: Tensor of shape NxKx512 representing the latent space embeddings of the given image (in W(K, *) space).
30
+ """
31
+ codes = self.net.encoder(images)
32
+ assert codes.ndim == 3, f"Invalid latent codes shape, should be NxKx512 but is {codes.shape}"
33
+ # normalize with respect to the center of an average face
34
+ if self.net.opts.start_from_latent_avg:
35
+ codes = codes + self.net.latent_avg.repeat(codes.shape[0], 1, 1)
36
+ return codes
37
+
38
+ def _generate(self, codes):
39
+ """
40
+ Generate the StyleGAN2 images of the given codes
41
+ :param codes: Tensor of shape NxKx512 representing the StyleGAN's latent codes (in W(K, *) space).
42
+ :return: Tensor of shape NxCxHxW representing the generated images.
43
+ """
44
+ images, _ = self.net.decoder([codes], input_is_latent=True, randomize_noise=False, return_latents=True)
45
+ images = self.net.face_pool(images)
46
+ if self.is_cars:
47
+ images = images[:, :, 32:224, :]
48
+ return images
49
+
50
+ @staticmethod
51
+ def _filter_outliers(arr):
52
+ arr = np.array(arr)
53
+
54
+ lo = np.percentile(arr, 1, interpolation="lower")
55
+ hi = np.percentile(arr, 99, interpolation="higher")
56
+ return np.extract(
57
+ np.logical_and(lo <= arr, arr <= hi), arr
58
+ )
59
+
60
+ def calculate_metric(self, data_loader, edit_function, inverse_edit_function):
61
+ """
62
+ Calculate the LEC metric score.
63
+ :param data_loader: An iterable that returns a tuple of (images, _), similar to the training data loader.
64
+ :param edit_function: A function that receives latent codes and performs a semantically meaningful edit in the
65
+ latent space.
66
+ :param inverse_edit_function: A function that receives latent codes and performs the inverse edit of the
67
+ `edit_function` parameter.
68
+ :return: The LEC metric score.
69
+ """
70
+ distances = []
71
+ with torch.no_grad():
72
+ for batch in data_loader:
73
+ x, _ = batch
74
+ inputs = x.to(device).float()
75
+
76
+ codes = self._encode(inputs)
77
+ edited_codes = edit_function(codes)
78
+ edited_image = self._generate(edited_codes)
79
+ edited_image_inversion_codes = self._encode(edited_image)
80
+ inverse_edit_codes = inverse_edit_function(edited_image_inversion_codes)
81
+
82
+ dist = (codes - inverse_edit_codes).norm(2, dim=(1, 2)).mean()
83
+ distances.append(dist.to("cpu").numpy())
84
+
85
+ distances = self._filter_outliers(distances)
86
+ return distances.mean()
87
+
88
+
89
+ if __name__ == "__main__":
90
+ device = "cuda"
91
+
92
+ parser = argparse.ArgumentParser(description="LEC metric calculator")
93
+
94
+ parser.add_argument("--batch", type=int, default=8, help="batch size for the models")
95
+ parser.add_argument("--images_dir", type=str, default=None,
96
+ help="Path to the images directory on which we calculate the LEC score")
97
+ parser.add_argument("ckpt", metavar="CHECKPOINT", help="path to the model checkpoints")
98
+
99
+ args = parser.parse_args()
100
+ print(args)
101
+
102
+ net, opts = setup_model(args.ckpt, device)
103
+ dataset_args = data_configs.DATASETS[opts.dataset_type]
104
+ transforms_dict = dataset_args['transforms'](opts).get_transforms()
105
+
106
+ images_directory = dataset_args['test_source_root'] if args.images_dir is None else args.images_dir
107
+ test_dataset = ImagesDataset(source_root=images_directory,
108
+ target_root=images_directory,
109
+ source_transform=transforms_dict['transform_source'],
110
+ target_transform=transforms_dict['transform_test'],
111
+ opts=opts)
112
+
113
+ data_loader = DataLoader(test_dataset,
114
+ batch_size=args.batch,
115
+ shuffle=False,
116
+ num_workers=2,
117
+ drop_last=True)
118
+
119
+ print(f'dataset length: {len(test_dataset)}')
120
+
121
+ # In the following example, we are using an InterfaceGAN based editing to calculate the LEC metric.
122
+ # Change the provided example according to your domain and needs.
123
+ direction = torch.load('../editings/interfacegan_directions/age.pt').to(device)
124
+
125
+ def edit_func_example(codes):
126
+ return codes + 3 * direction
127
+
128
+
129
+ def inverse_edit_func_example(codes):
130
+ return codes - 3 * direction
131
+
132
+ lec = LEC(net, is_cars='car' in opts.dataset_type)
133
+ result = lec.calculate_metric(data_loader, edit_func_example, inverse_edit_func_example)
134
+ print(f"LEC: {result}")
e4e/models/__init__.py ADDED
File without changes
e4e/models/discriminator.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+
3
+
4
+ class LatentCodesDiscriminator(nn.Module):
5
+ def __init__(self, style_dim, n_mlp):
6
+ super().__init__()
7
+
8
+ self.style_dim = style_dim
9
+
10
+ layers = []
11
+ for i in range(n_mlp-1):
12
+ layers.append(
13
+ nn.Linear(style_dim, style_dim)
14
+ )
15
+ layers.append(nn.LeakyReLU(0.2))
16
+ layers.append(nn.Linear(512, 1))
17
+ self.mlp = nn.Sequential(*layers)
18
+
19
+ def forward(self, w):
20
+ return self.mlp(w)
e4e/models/encoders/__init__.py ADDED
File without changes
e4e/models/encoders/helpers.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import namedtuple
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
5
+
6
+ """
7
+ ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
8
+ """
9
+
10
+
11
+ class Flatten(Module):
12
+ def forward(self, input):
13
+ return input.view(input.size(0), -1)
14
+
15
+
16
+ def l2_norm(input, axis=1):
17
+ norm = torch.norm(input, 2, axis, True)
18
+ output = torch.div(input, norm)
19
+ return output
20
+
21
+
22
+ class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
23
+ """ A named tuple describing a ResNet block. """
24
+
25
+
26
+ def get_block(in_channel, depth, num_units, stride=2):
27
+ return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
28
+
29
+
30
+ def get_blocks(num_layers):
31
+ if num_layers == 50:
32
+ blocks = [
33
+ get_block(in_channel=64, depth=64, num_units=3),
34
+ get_block(in_channel=64, depth=128, num_units=4),
35
+ get_block(in_channel=128, depth=256, num_units=14),
36
+ get_block(in_channel=256, depth=512, num_units=3)
37
+ ]
38
+ elif num_layers == 100:
39
+ blocks = [
40
+ get_block(in_channel=64, depth=64, num_units=3),
41
+ get_block(in_channel=64, depth=128, num_units=13),
42
+ get_block(in_channel=128, depth=256, num_units=30),
43
+ get_block(in_channel=256, depth=512, num_units=3)
44
+ ]
45
+ elif num_layers == 152:
46
+ blocks = [
47
+ get_block(in_channel=64, depth=64, num_units=3),
48
+ get_block(in_channel=64, depth=128, num_units=8),
49
+ get_block(in_channel=128, depth=256, num_units=36),
50
+ get_block(in_channel=256, depth=512, num_units=3)
51
+ ]
52
+ else:
53
+ raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
54
+ return blocks
55
+
56
+
57
+ class SEModule(Module):
58
+ def __init__(self, channels, reduction):
59
+ super(SEModule, self).__init__()
60
+ self.avg_pool = AdaptiveAvgPool2d(1)
61
+ self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
62
+ self.relu = ReLU(inplace=True)
63
+ self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
64
+ self.sigmoid = Sigmoid()
65
+
66
+ def forward(self, x):
67
+ module_input = x
68
+ x = self.avg_pool(x)
69
+ x = self.fc1(x)
70
+ x = self.relu(x)
71
+ x = self.fc2(x)
72
+ x = self.sigmoid(x)
73
+ return module_input * x
74
+
75
+
76
+ class bottleneck_IR(Module):
77
+ def __init__(self, in_channel, depth, stride):
78
+ super(bottleneck_IR, self).__init__()
79
+ if in_channel == depth:
80
+ self.shortcut_layer = MaxPool2d(1, stride)
81
+ else:
82
+ self.shortcut_layer = Sequential(
83
+ Conv2d(in_channel, depth, (1, 1), stride, bias=False),
84
+ BatchNorm2d(depth)
85
+ )
86
+ self.res_layer = Sequential(
87
+ BatchNorm2d(in_channel),
88
+ Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
89
+ Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
90
+ )
91
+
92
+ def forward(self, x):
93
+ shortcut = self.shortcut_layer(x)
94
+ res = self.res_layer(x)
95
+ return res + shortcut
96
+
97
+
98
+ class bottleneck_IR_SE(Module):
99
+ def __init__(self, in_channel, depth, stride):
100
+ super(bottleneck_IR_SE, self).__init__()
101
+ if in_channel == depth:
102
+ self.shortcut_layer = MaxPool2d(1, stride)
103
+ else:
104
+ self.shortcut_layer = Sequential(
105
+ Conv2d(in_channel, depth, (1, 1), stride, bias=False),
106
+ BatchNorm2d(depth)
107
+ )
108
+ self.res_layer = Sequential(
109
+ BatchNorm2d(in_channel),
110
+ Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
111
+ PReLU(depth),
112
+ Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
113
+ BatchNorm2d(depth),
114
+ SEModule(depth, 16)
115
+ )
116
+
117
+ def forward(self, x):
118
+ shortcut = self.shortcut_layer(x)
119
+ res = self.res_layer(x)
120
+ return res + shortcut
121
+
122
+
123
+ def _upsample_add(x, y):
124
+ """Upsample and add two feature maps.
125
+ Args:
126
+ x: (Variable) top feature map to be upsampled.
127
+ y: (Variable) lateral feature map.
128
+ Returns:
129
+ (Variable) added feature map.
130
+ Note in PyTorch, when input size is odd, the upsampled feature map
131
+ with `F.upsample(..., scale_factor=2, mode='nearest')`
132
+ maybe not equal to the lateral feature map size.
133
+ e.g.
134
+ original input size: [N,_,15,15] ->
135
+ conv2d feature map size: [N,_,8,8] ->
136
+ upsampled feature map size: [N,_,16,16]
137
+ So we choose bilinear upsample which supports arbitrary output sizes.
138
+ """
139
+ _, _, H, W = y.size()
140
+ return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y
e4e/models/encoders/model_irse.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
2
+ from e4e.models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
3
+
4
+ """
5
+ Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
6
+ """
7
+
8
+
9
+ class Backbone(Module):
10
+ def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
11
+ super(Backbone, self).__init__()
12
+ assert input_size in [112, 224], "input_size should be 112 or 224"
13
+ assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
14
+ assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
15
+ blocks = get_blocks(num_layers)
16
+ if mode == 'ir':
17
+ unit_module = bottleneck_IR
18
+ elif mode == 'ir_se':
19
+ unit_module = bottleneck_IR_SE
20
+ self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
21
+ BatchNorm2d(64),
22
+ PReLU(64))
23
+ if input_size == 112:
24
+ self.output_layer = Sequential(BatchNorm2d(512),
25
+ Dropout(drop_ratio),
26
+ Flatten(),
27
+ Linear(512 * 7 * 7, 512),
28
+ BatchNorm1d(512, affine=affine))
29
+ else:
30
+ self.output_layer = Sequential(BatchNorm2d(512),
31
+ Dropout(drop_ratio),
32
+ Flatten(),
33
+ Linear(512 * 14 * 14, 512),
34
+ BatchNorm1d(512, affine=affine))
35
+
36
+ modules = []
37
+ for block in blocks:
38
+ for bottleneck in block:
39
+ modules.append(unit_module(bottleneck.in_channel,
40
+ bottleneck.depth,
41
+ bottleneck.stride))
42
+ self.body = Sequential(*modules)
43
+
44
+ def forward(self, x):
45
+ x = self.input_layer(x)
46
+ x = self.body(x)
47
+ x = self.output_layer(x)
48
+ return l2_norm(x)
49
+
50
+
51
+ def IR_50(input_size):
52
+ """Constructs a ir-50 model."""
53
+ model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
54
+ return model
55
+
56
+
57
+ def IR_101(input_size):
58
+ """Constructs a ir-101 model."""
59
+ model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
60
+ return model
61
+
62
+
63
+ def IR_152(input_size):
64
+ """Constructs a ir-152 model."""
65
+ model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
66
+ return model
67
+
68
+
69
+ def IR_SE_50(input_size):
70
+ """Constructs a ir_se-50 model."""
71
+ model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
72
+ return model
73
+
74
+
75
+ def IR_SE_101(input_size):
76
+ """Constructs a ir_se-101 model."""
77
+ model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
78
+ return model
79
+
80
+
81
+ def IR_SE_152(input_size):
82
+ """Constructs a ir_se-152 model."""
83
+ model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
84
+ return model
e4e/models/encoders/psp_encoders.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+ import math
3
+ import numpy as np
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module
7
+
8
+ from e4e.models.encoders.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE, _upsample_add
9
+ from e4e.models.stylegan2.model import EqualLinear
10
+
11
+
12
+ class ProgressiveStage(Enum):
13
+ WTraining = 0
14
+ Delta1Training = 1
15
+ Delta2Training = 2
16
+ Delta3Training = 3
17
+ Delta4Training = 4
18
+ Delta5Training = 5
19
+ Delta6Training = 6
20
+ Delta7Training = 7
21
+ Delta8Training = 8
22
+ Delta9Training = 9
23
+ Delta10Training = 10
24
+ Delta11Training = 11
25
+ Delta12Training = 12
26
+ Delta13Training = 13
27
+ Delta14Training = 14
28
+ Delta15Training = 15
29
+ Delta16Training = 16
30
+ Delta17Training = 17
31
+ Inference = 18
32
+
33
+
34
+ class GradualStyleBlock(Module):
35
+ def __init__(self, in_c, out_c, spatial):
36
+ super(GradualStyleBlock, self).__init__()
37
+ self.out_c = out_c
38
+ self.spatial = spatial
39
+ num_pools = int(np.log2(spatial))
40
+ modules = []
41
+ modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1),
42
+ nn.LeakyReLU()]
43
+ for i in range(num_pools - 1):
44
+ modules += [
45
+ Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1),
46
+ nn.LeakyReLU()
47
+ ]
48
+ self.convs = nn.Sequential(*modules)
49
+ self.linear = EqualLinear(out_c, out_c, lr_mul=1)
50
+
51
+ def forward(self, x):
52
+ x = self.convs(x)
53
+ x = x.view(-1, self.out_c)
54
+ x = self.linear(x)
55
+ return x
56
+
57
+
58
+ class GradualStyleEncoder(Module):
59
+ def __init__(self, num_layers, mode='ir', opts=None):
60
+ super(GradualStyleEncoder, self).__init__()
61
+ assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
62
+ assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
63
+ blocks = get_blocks(num_layers)
64
+ if mode == 'ir':
65
+ unit_module = bottleneck_IR
66
+ elif mode == 'ir_se':
67
+ unit_module = bottleneck_IR_SE
68
+ self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
69
+ BatchNorm2d(64),
70
+ PReLU(64))
71
+ modules = []
72
+ for block in blocks:
73
+ for bottleneck in block:
74
+ modules.append(unit_module(bottleneck.in_channel,
75
+ bottleneck.depth,
76
+ bottleneck.stride))
77
+ self.body = Sequential(*modules)
78
+
79
+ self.styles = nn.ModuleList()
80
+ log_size = int(math.log(opts.stylegan_size, 2))
81
+ self.style_count = 2 * log_size - 2
82
+ self.coarse_ind = 3
83
+ self.middle_ind = 7
84
+ for i in range(self.style_count):
85
+ if i < self.coarse_ind:
86
+ style = GradualStyleBlock(512, 512, 16)
87
+ elif i < self.middle_ind:
88
+ style = GradualStyleBlock(512, 512, 32)
89
+ else:
90
+ style = GradualStyleBlock(512, 512, 64)
91
+ self.styles.append(style)
92
+ self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
93
+ self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)
94
+
95
+ def forward(self, x):
96
+ x = self.input_layer(x)
97
+
98
+ latents = []
99
+ modulelist = list(self.body._modules.values())
100
+ for i, l in enumerate(modulelist):
101
+ x = l(x)
102
+ if i == 6:
103
+ c1 = x
104
+ elif i == 20:
105
+ c2 = x
106
+ elif i == 23:
107
+ c3 = x
108
+
109
+ for j in range(self.coarse_ind):
110
+ latents.append(self.styles[j](c3))
111
+
112
+ p2 = _upsample_add(c3, self.latlayer1(c2))
113
+ for j in range(self.coarse_ind, self.middle_ind):
114
+ latents.append(self.styles[j](p2))
115
+
116
+ p1 = _upsample_add(p2, self.latlayer2(c1))
117
+ for j in range(self.middle_ind, self.style_count):
118
+ latents.append(self.styles[j](p1))
119
+
120
+ out = torch.stack(latents, dim=1)
121
+ return out
122
+
123
+
124
+ class Encoder4Editing(Module):
125
+ def __init__(self, num_layers, mode='ir', opts=None):
126
+ super(Encoder4Editing, self).__init__()
127
+ assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
128
+ assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
129
+ blocks = get_blocks(num_layers)
130
+ if mode == 'ir':
131
+ unit_module = bottleneck_IR
132
+ elif mode == 'ir_se':
133
+ unit_module = bottleneck_IR_SE
134
+ self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
135
+ BatchNorm2d(64),
136
+ PReLU(64))
137
+ modules = []
138
+ for block in blocks:
139
+ for bottleneck in block:
140
+ modules.append(unit_module(bottleneck.in_channel,
141
+ bottleneck.depth,
142
+ bottleneck.stride))
143
+ self.body = Sequential(*modules)
144
+
145
+ self.styles = nn.ModuleList()
146
+ log_size = int(math.log(opts.stylegan_size, 2))
147
+ self.style_count = 2 * log_size - 2
148
+ self.coarse_ind = 3
149
+ self.middle_ind = 7
150
+
151
+ for i in range(self.style_count):
152
+ if i < self.coarse_ind:
153
+ style = GradualStyleBlock(512, 512, 16)
154
+ elif i < self.middle_ind:
155
+ style = GradualStyleBlock(512, 512, 32)
156
+ else:
157
+ style = GradualStyleBlock(512, 512, 64)
158
+ self.styles.append(style)
159
+
160
+ self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
161
+ self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)
162
+
163
+ self.progressive_stage = ProgressiveStage.Inference
164
+
165
+ def get_deltas_starting_dimensions(self):
166
+ ''' Get a list of the initial dimension of every delta from which it is applied '''
167
+ return list(range(self.style_count)) # Each dimension has a delta applied to it
168
+
169
+ def set_progressive_stage(self, new_stage: ProgressiveStage):
170
+ self.progressive_stage = new_stage
171
+ print('Changed progressive stage to: ', new_stage)
172
+
173
+ def forward(self, x):
174
+ x = self.input_layer(x)
175
+
176
+ modulelist = list(self.body._modules.values())
177
+ for i, l in enumerate(modulelist):
178
+ x = l(x)
179
+ if i == 6:
180
+ c1 = x
181
+ elif i == 20:
182
+ c2 = x
183
+ elif i == 23:
184
+ c3 = x
185
+
186
+ # Infer main W and duplicate it
187
+ w0 = self.styles[0](c3)
188
+ w = w0.repeat(self.style_count, 1, 1).permute(1, 0, 2)
189
+ stage = self.progressive_stage.value
190
+ features = c3
191
+ for i in range(1, min(stage + 1, self.style_count)): # Infer additional deltas
192
+ if i == self.coarse_ind:
193
+ p2 = _upsample_add(c3, self.latlayer1(c2)) # FPN's middle features
194
+ features = p2
195
+ elif i == self.middle_ind:
196
+ p1 = _upsample_add(p2, self.latlayer2(c1)) # FPN's fine features
197
+ features = p1
198
+ delta_i = self.styles[i](features)
199
+ w[:, i] += delta_i
200
+ return w
e4e/models/latent_codes_pool.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import torch
3
+
4
+
5
+ class LatentCodesPool:
6
+ """This class implements latent codes buffer that stores previously generated w latent codes.
7
+ This buffer enables us to update discriminators using a history of generated w's
8
+ rather than the ones produced by the latest encoder.
9
+ """
10
+
11
+ def __init__(self, pool_size):
12
+ """Initialize the ImagePool class
13
+ Parameters:
14
+ pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created
15
+ """
16
+ self.pool_size = pool_size
17
+ if self.pool_size > 0: # create an empty pool
18
+ self.num_ws = 0
19
+ self.ws = []
20
+
21
+ def query(self, ws):
22
+ """Return w's from the pool.
23
+ Parameters:
24
+ ws: the latest generated w's from the generator
25
+ Returns w's from the buffer.
26
+ By 50/100, the buffer will return input w's.
27
+ By 50/100, the buffer will return w's previously stored in the buffer,
28
+ and insert the current w's to the buffer.
29
+ """
30
+ if self.pool_size == 0: # if the buffer size is 0, do nothing
31
+ return ws
32
+ return_ws = []
33
+ for w in ws: # ws.shape: (batch, 512) or (batch, n_latent, 512)
34
+ # w = torch.unsqueeze(image.data, 0)
35
+ if w.ndim == 2:
36
+ i = random.randint(0, len(w) - 1) # apply a random latent index as a candidate
37
+ w = w[i]
38
+ self.handle_w(w, return_ws)
39
+ return_ws = torch.stack(return_ws, 0) # collect all the images and return
40
+ return return_ws
41
+
42
+ def handle_w(self, w, return_ws):
43
+ if self.num_ws < self.pool_size: # if the buffer is not full; keep inserting current codes to the buffer
44
+ self.num_ws = self.num_ws + 1
45
+ self.ws.append(w)
46
+ return_ws.append(w)
47
+ else:
48
+ p = random.uniform(0, 1)
49
+ if p > 0.5: # by 50% chance, the buffer will return a previously stored latent code, and insert the current code into the buffer
50
+ random_id = random.randint(0, self.pool_size - 1) # randint is inclusive
51
+ tmp = self.ws[random_id].clone()
52
+ self.ws[random_id] = w
53
+ return_ws.append(tmp)
54
+ else: # by another 50% chance, the buffer will return the current image
55
+ return_ws.append(w)
e4e/models/psp.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib
2
+
3
+ matplotlib.use('Agg')
4
+ import torch
5
+ from torch import nn
6
+ from e4e.models.encoders import psp_encoders
7
+ from e4e.models.stylegan2.model import Generator
8
+ from e4e.configs.paths_config import model_paths
9
+
10
+
11
+ def get_keys(d, name):
12
+ if 'state_dict' in d:
13
+ d = d['state_dict']
14
+ d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name}
15
+ return d_filt
16
+
17
+
18
+ class pSp(nn.Module):
19
+
20
+ def __init__(self, opts, device):
21
+ super(pSp, self).__init__()
22
+ self.opts = opts
23
+ self.device = device
24
+ # Define architecture
25
+ self.encoder = self.set_encoder()
26
+ self.decoder = Generator(opts.stylegan_size, 512, 8, channel_multiplier=2)
27
+ self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
28
+ # Load weights if needed
29
+ self.load_weights()
30
+
31
+ def set_encoder(self):
32
+ if self.opts.encoder_type == 'GradualStyleEncoder':
33
+ encoder = psp_encoders.GradualStyleEncoder(50, 'ir_se', self.opts)
34
+ elif self.opts.encoder_type == 'Encoder4Editing':
35
+ encoder = psp_encoders.Encoder4Editing(50, 'ir_se', self.opts)
36
+ else:
37
+ raise Exception('{} is not a valid encoders'.format(self.opts.encoder_type))
38
+ return encoder
39
+
40
+ def load_weights(self):
41
+ if self.opts.checkpoint_path is not None:
42
+ print('Loading e4e over the pSp framework from checkpoint: {}'.format(self.opts.checkpoint_path))
43
+ ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
44
+ self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True)
45
+ self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True)
46
+ self.__load_latent_avg(ckpt)
47
+ else:
48
+ print('Loading encoders weights from irse50!')
49
+ encoder_ckpt = torch.load(model_paths['ir_se50'])
50
+ self.encoder.load_state_dict(encoder_ckpt, strict=False)
51
+ print('Loading decoder weights from pretrained!')
52
+ ckpt = torch.load(self.opts.stylegan_weights)
53
+ self.decoder.load_state_dict(ckpt['g_ema'], strict=False)
54
+ self.__load_latent_avg(ckpt, repeat=self.encoder.style_count)
55
+
56
+ def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True,
57
+ inject_latent=None, return_latents=False, alpha=None):
58
+ if input_code:
59
+ codes = x
60
+ else:
61
+ codes = self.encoder(x)
62
+ # normalize with respect to the center of an average face
63
+ if self.opts.start_from_latent_avg:
64
+ if codes.ndim == 2:
65
+ codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)[:, 0, :]
66
+ else:
67
+ codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)
68
+
69
+ if latent_mask is not None:
70
+ for i in latent_mask:
71
+ if inject_latent is not None:
72
+ if alpha is not None:
73
+ codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
74
+ else:
75
+ codes[:, i] = inject_latent[:, i]
76
+ else:
77
+ codes[:, i] = 0
78
+
79
+ input_is_latent = not input_code
80
+ images, result_latent = self.decoder([codes],
81
+ input_is_latent=input_is_latent,
82
+ randomize_noise=randomize_noise,
83
+ return_latents=return_latents)
84
+
85
+ if resize:
86
+ images = self.face_pool(images)
87
+
88
+ if return_latents:
89
+ return images, result_latent
90
+ else:
91
+ return images
92
+
93
+ def __load_latent_avg(self, ckpt, repeat=None):
94
+ if 'latent_avg' in ckpt:
95
+ self.latent_avg = ckpt['latent_avg'].to(self.device)
96
+ if repeat is not None:
97
+ self.latent_avg = self.latent_avg.repeat(repeat, 1)
98
+ else:
99
+ self.latent_avg = None
e4e/models/stylegan2/__init__.py ADDED
File without changes
e4e/models/stylegan2/model.py ADDED
@@ -0,0 +1,678 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import random
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ if torch.cuda.is_available():
8
+ from op.fused_act import FusedLeakyReLU, fused_leaky_relu
9
+ from op.upfirdn2d import upfirdn2d
10
+ else:
11
+ from op.fused_act_cpu import FusedLeakyReLU, fused_leaky_relu
12
+ from op.upfirdn2d_cpu import upfirdn2d
13
+
14
+
15
+ class PixelNorm(nn.Module):
16
+ def __init__(self):
17
+ super().__init__()
18
+
19
+ def forward(self, input):
20
+ return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
21
+
22
+
23
+ def make_kernel(k):
24
+ k = torch.tensor(k, dtype=torch.float32)
25
+
26
+ if k.ndim == 1:
27
+ k = k[None, :] * k[:, None]
28
+
29
+ k /= k.sum()
30
+
31
+ return k
32
+
33
+
34
+ class Upsample(nn.Module):
35
+ def __init__(self, kernel, factor=2):
36
+ super().__init__()
37
+
38
+ self.factor = factor
39
+ kernel = make_kernel(kernel) * (factor ** 2)
40
+ self.register_buffer('kernel', kernel)
41
+
42
+ p = kernel.shape[0] - factor
43
+
44
+ pad0 = (p + 1) // 2 + factor - 1
45
+ pad1 = p // 2
46
+
47
+ self.pad = (pad0, pad1)
48
+
49
+ def forward(self, input):
50
+ out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
51
+
52
+ return out
53
+
54
+
55
+ class Downsample(nn.Module):
56
+ def __init__(self, kernel, factor=2):
57
+ super().__init__()
58
+
59
+ self.factor = factor
60
+ kernel = make_kernel(kernel)
61
+ self.register_buffer('kernel', kernel)
62
+
63
+ p = kernel.shape[0] - factor
64
+
65
+ pad0 = (p + 1) // 2
66
+ pad1 = p // 2
67
+
68
+ self.pad = (pad0, pad1)
69
+
70
+ def forward(self, input):
71
+ out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
72
+
73
+ return out
74
+
75
+
76
+ class Blur(nn.Module):
77
+ def __init__(self, kernel, pad, upsample_factor=1):
78
+ super().__init__()
79
+
80
+ kernel = make_kernel(kernel)
81
+
82
+ if upsample_factor > 1:
83
+ kernel = kernel * (upsample_factor ** 2)
84
+
85
+ self.register_buffer('kernel', kernel)
86
+
87
+ self.pad = pad
88
+
89
+ def forward(self, input):
90
+ out = upfirdn2d(input, self.kernel, pad=self.pad)
91
+
92
+ return out
93
+
94
+
95
+ class EqualConv2d(nn.Module):
96
+ def __init__(
97
+ self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
98
+ ):
99
+ super().__init__()
100
+
101
+ self.weight = nn.Parameter(
102
+ torch.randn(out_channel, in_channel, kernel_size, kernel_size)
103
+ )
104
+ self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
105
+
106
+ self.stride = stride
107
+ self.padding = padding
108
+
109
+ if bias:
110
+ self.bias = nn.Parameter(torch.zeros(out_channel))
111
+
112
+ else:
113
+ self.bias = None
114
+
115
+ def forward(self, input):
116
+ out = F.conv2d(
117
+ input,
118
+ self.weight * self.scale,
119
+ bias=self.bias,
120
+ stride=self.stride,
121
+ padding=self.padding,
122
+ )
123
+
124
+ return out
125
+
126
+ def __repr__(self):
127
+ return (
128
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
129
+ f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
130
+ )
131
+
132
+
133
+ class EqualLinear(nn.Module):
134
+ def __init__(
135
+ self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
136
+ ):
137
+ super().__init__()
138
+
139
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
140
+
141
+ if bias:
142
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
143
+
144
+ else:
145
+ self.bias = None
146
+
147
+ self.activation = activation
148
+
149
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
150
+ self.lr_mul = lr_mul
151
+
152
+ def forward(self, input):
153
+ if self.activation:
154
+ out = F.linear(input, self.weight * self.scale)
155
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
156
+
157
+ else:
158
+ out = F.linear(
159
+ input, self.weight * self.scale, bias=self.bias * self.lr_mul
160
+ )
161
+
162
+ return out
163
+
164
+ def __repr__(self):
165
+ return (
166
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
167
+ )
168
+
169
+
170
+ class ScaledLeakyReLU(nn.Module):
171
+ def __init__(self, negative_slope=0.2):
172
+ super().__init__()
173
+
174
+ self.negative_slope = negative_slope
175
+
176
+ def forward(self, input):
177
+ out = F.leaky_relu(input, negative_slope=self.negative_slope)
178
+
179
+ return out * math.sqrt(2)
180
+
181
+
182
+ class ModulatedConv2d(nn.Module):
183
+ def __init__(
184
+ self,
185
+ in_channel,
186
+ out_channel,
187
+ kernel_size,
188
+ style_dim,
189
+ demodulate=True,
190
+ upsample=False,
191
+ downsample=False,
192
+ blur_kernel=[1, 3, 3, 1],
193
+ ):
194
+ super().__init__()
195
+
196
+ self.eps = 1e-8
197
+ self.kernel_size = kernel_size
198
+ self.in_channel = in_channel
199
+ self.out_channel = out_channel
200
+ self.upsample = upsample
201
+ self.downsample = downsample
202
+
203
+ if upsample:
204
+ factor = 2
205
+ p = (len(blur_kernel) - factor) - (kernel_size - 1)
206
+ pad0 = (p + 1) // 2 + factor - 1
207
+ pad1 = p // 2 + 1
208
+
209
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
210
+
211
+ if downsample:
212
+ factor = 2
213
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
214
+ pad0 = (p + 1) // 2
215
+ pad1 = p // 2
216
+
217
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1))
218
+
219
+ fan_in = in_channel * kernel_size ** 2
220
+ self.scale = 1 / math.sqrt(fan_in)
221
+ self.padding = kernel_size // 2
222
+
223
+ self.weight = nn.Parameter(
224
+ torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
225
+ )
226
+
227
+ self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
228
+
229
+ self.demodulate = demodulate
230
+
231
+ def __repr__(self):
232
+ return (
233
+ f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
234
+ f'upsample={self.upsample}, downsample={self.downsample})'
235
+ )
236
+
237
+ def forward(self, input, style):
238
+ batch, in_channel, height, width = input.shape
239
+
240
+ style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
241
+ weight = self.scale * self.weight * style
242
+
243
+ if self.demodulate:
244
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
245
+ weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
246
+
247
+ weight = weight.view(
248
+ batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
249
+ )
250
+
251
+ if self.upsample:
252
+ input = input.view(1, batch * in_channel, height, width)
253
+ weight = weight.view(
254
+ batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
255
+ )
256
+ weight = weight.transpose(1, 2).reshape(
257
+ batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
258
+ )
259
+ out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
260
+ _, _, height, width = out.shape
261
+ out = out.view(batch, self.out_channel, height, width)
262
+ out = self.blur(out)
263
+
264
+ elif self.downsample:
265
+ input = self.blur(input)
266
+ _, _, height, width = input.shape
267
+ input = input.view(1, batch * in_channel, height, width)
268
+ out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
269
+ _, _, height, width = out.shape
270
+ out = out.view(batch, self.out_channel, height, width)
271
+
272
+ else:
273
+ input = input.view(1, batch * in_channel, height, width)
274
+ out = F.conv2d(input, weight, padding=self.padding, groups=batch)
275
+ _, _, height, width = out.shape
276
+ out = out.view(batch, self.out_channel, height, width)
277
+
278
+ return out
279
+
280
+
281
+ class NoiseInjection(nn.Module):
282
+ def __init__(self):
283
+ super().__init__()
284
+
285
+ self.weight = nn.Parameter(torch.zeros(1))
286
+
287
+ def forward(self, image, noise=None):
288
+ if noise is None:
289
+ batch, _, height, width = image.shape
290
+ noise = image.new_empty(batch, 1, height, width).normal_()
291
+
292
+ return image + self.weight * noise
293
+
294
+
295
+ class ConstantInput(nn.Module):
296
+ def __init__(self, channel, size=4):
297
+ super().__init__()
298
+
299
+ self.input = nn.Parameter(torch.randn(1, channel, size, size))
300
+
301
+ def forward(self, input):
302
+ batch = input.shape[0]
303
+ out = self.input.repeat(batch, 1, 1, 1)
304
+
305
+ return out
306
+
307
+
308
+ class StyledConv(nn.Module):
309
+ def __init__(
310
+ self,
311
+ in_channel,
312
+ out_channel,
313
+ kernel_size,
314
+ style_dim,
315
+ upsample=False,
316
+ blur_kernel=[1, 3, 3, 1],
317
+ demodulate=True,
318
+ ):
319
+ super().__init__()
320
+
321
+ self.conv = ModulatedConv2d(
322
+ in_channel,
323
+ out_channel,
324
+ kernel_size,
325
+ style_dim,
326
+ upsample=upsample,
327
+ blur_kernel=blur_kernel,
328
+ demodulate=demodulate,
329
+ )
330
+
331
+ self.noise = NoiseInjection()
332
+ # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
333
+ # self.activate = ScaledLeakyReLU(0.2)
334
+ self.activate = FusedLeakyReLU(out_channel)
335
+
336
+ def forward(self, input, style, noise=None):
337
+ out = self.conv(input, style)
338
+ out = self.noise(out, noise=noise)
339
+ # out = out + self.bias
340
+ out = self.activate(out)
341
+
342
+ return out
343
+
344
+
345
+ class ToRGB(nn.Module):
346
+ def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
347
+ super().__init__()
348
+
349
+ if upsample:
350
+ self.upsample = Upsample(blur_kernel)
351
+
352
+ self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
353
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
354
+
355
+ def forward(self, input, style, skip=None):
356
+ out = self.conv(input, style)
357
+ out = out + self.bias
358
+
359
+ if skip is not None:
360
+ skip = self.upsample(skip)
361
+
362
+ out = out + skip
363
+
364
+ return out
365
+
366
+
367
+ class Generator(nn.Module):
368
+ def __init__(
369
+ self,
370
+ size,
371
+ style_dim,
372
+ n_mlp,
373
+ channel_multiplier=2,
374
+ blur_kernel=[1, 3, 3, 1],
375
+ lr_mlp=0.01,
376
+ ):
377
+ super().__init__()
378
+
379
+ self.size = size
380
+
381
+ self.style_dim = style_dim
382
+
383
+ layers = [PixelNorm()]
384
+
385
+ for i in range(n_mlp):
386
+ layers.append(
387
+ EqualLinear(
388
+ style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
389
+ )
390
+ )
391
+
392
+ self.style = nn.Sequential(*layers)
393
+
394
+ self.channels = {
395
+ 4: 512,
396
+ 8: 512,
397
+ 16: 512,
398
+ 32: 512,
399
+ 64: 256 * channel_multiplier,
400
+ 128: 128 * channel_multiplier,
401
+ 256: 64 * channel_multiplier,
402
+ 512: 32 * channel_multiplier,
403
+ 1024: 16 * channel_multiplier,
404
+ }
405
+
406
+ self.input = ConstantInput(self.channels[4])
407
+ self.conv1 = StyledConv(
408
+ self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
409
+ )
410
+ self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
411
+
412
+ self.log_size = int(math.log(size, 2))
413
+ self.num_layers = (self.log_size - 2) * 2 + 1
414
+
415
+ self.convs = nn.ModuleList()
416
+ self.upsamples = nn.ModuleList()
417
+ self.to_rgbs = nn.ModuleList()
418
+ self.noises = nn.Module()
419
+
420
+ in_channel = self.channels[4]
421
+
422
+ for layer_idx in range(self.num_layers):
423
+ res = (layer_idx + 5) // 2
424
+ shape = [1, 1, 2 ** res, 2 ** res]
425
+ self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))
426
+
427
+ for i in range(3, self.log_size + 1):
428
+ out_channel = self.channels[2 ** i]
429
+
430
+ self.convs.append(
431
+ StyledConv(
432
+ in_channel,
433
+ out_channel,
434
+ 3,
435
+ style_dim,
436
+ upsample=True,
437
+ blur_kernel=blur_kernel,
438
+ )
439
+ )
440
+
441
+ self.convs.append(
442
+ StyledConv(
443
+ out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
444
+ )
445
+ )
446
+
447
+ self.to_rgbs.append(ToRGB(out_channel, style_dim))
448
+
449
+ in_channel = out_channel
450
+
451
+ self.n_latent = self.log_size * 2 - 2
452
+
453
+ def make_noise(self):
454
+ device = self.input.input.device
455
+
456
+ noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
457
+
458
+ for i in range(3, self.log_size + 1):
459
+ for _ in range(2):
460
+ noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
461
+
462
+ return noises
463
+
464
+ def mean_latent(self, n_latent):
465
+ latent_in = torch.randn(
466
+ n_latent, self.style_dim, device=self.input.input.device
467
+ )
468
+ latent = self.style(latent_in).mean(0, keepdim=True)
469
+
470
+ return latent
471
+
472
+ def get_latent(self, input):
473
+ return self.style(input)
474
+
475
+ def forward(
476
+ self,
477
+ styles,
478
+ return_latents=False,
479
+ return_features=False,
480
+ inject_index=None,
481
+ truncation=1,
482
+ truncation_latent=None,
483
+ input_is_latent=False,
484
+ noise=None,
485
+ randomize_noise=True,
486
+ ):
487
+ if not input_is_latent:
488
+ styles = [self.style(s) for s in styles]
489
+
490
+ if noise is None:
491
+ if randomize_noise:
492
+ noise = [None] * self.num_layers
493
+ else:
494
+ noise = [
495
+ getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
496
+ ]
497
+
498
+ if truncation < 1:
499
+ style_t = []
500
+
501
+ for style in styles:
502
+ style_t.append(
503
+ truncation_latent + truncation * (style - truncation_latent)
504
+ )
505
+
506
+ styles = style_t
507
+
508
+ if len(styles) < 2:
509
+ inject_index = self.n_latent
510
+
511
+ if styles[0].ndim < 3:
512
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
513
+ else:
514
+ latent = styles[0]
515
+
516
+ else:
517
+ if inject_index is None:
518
+ inject_index = random.randint(1, self.n_latent - 1)
519
+
520
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
521
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
522
+
523
+ latent = torch.cat([latent, latent2], 1)
524
+
525
+ out = self.input(latent)
526
+ out = self.conv1(out, latent[:, 0], noise=noise[0])
527
+
528
+ skip = self.to_rgb1(out, latent[:, 1])
529
+
530
+ i = 1
531
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
532
+ self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
533
+ ):
534
+ out = conv1(out, latent[:, i], noise=noise1)
535
+ out = conv2(out, latent[:, i + 1], noise=noise2)
536
+ skip = to_rgb(out, latent[:, i + 2], skip)
537
+
538
+ i += 2
539
+
540
+ image = skip
541
+
542
+ if return_latents:
543
+ return image, latent
544
+ elif return_features:
545
+ return image, out
546
+ else:
547
+ return image, None
548
+
549
+
550
+ class ConvLayer(nn.Sequential):
551
+ def __init__(
552
+ self,
553
+ in_channel,
554
+ out_channel,
555
+ kernel_size,
556
+ downsample=False,
557
+ blur_kernel=[1, 3, 3, 1],
558
+ bias=True,
559
+ activate=True,
560
+ ):
561
+ layers = []
562
+
563
+ if downsample:
564
+ factor = 2
565
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
566
+ pad0 = (p + 1) // 2
567
+ pad1 = p // 2
568
+
569
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
570
+
571
+ stride = 2
572
+ self.padding = 0
573
+
574
+ else:
575
+ stride = 1
576
+ self.padding = kernel_size // 2
577
+
578
+ layers.append(
579
+ EqualConv2d(
580
+ in_channel,
581
+ out_channel,
582
+ kernel_size,
583
+ padding=self.padding,
584
+ stride=stride,
585
+ bias=bias and not activate,
586
+ )
587
+ )
588
+
589
+ if activate:
590
+ if bias:
591
+ layers.append(FusedLeakyReLU(out_channel))
592
+
593
+ else:
594
+ layers.append(ScaledLeakyReLU(0.2))
595
+
596
+ super().__init__(*layers)
597
+
598
+
599
+ class ResBlock(nn.Module):
600
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
601
+ super().__init__()
602
+
603
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
604
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
605
+
606
+ self.skip = ConvLayer(
607
+ in_channel, out_channel, 1, downsample=True, activate=False, bias=False
608
+ )
609
+
610
+ def forward(self, input):
611
+ out = self.conv1(input)
612
+ out = self.conv2(out)
613
+
614
+ skip = self.skip(input)
615
+ out = (out + skip) / math.sqrt(2)
616
+
617
+ return out
618
+
619
+
620
+ class Discriminator(nn.Module):
621
+ def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
622
+ super().__init__()
623
+
624
+ channels = {
625
+ 4: 512,
626
+ 8: 512,
627
+ 16: 512,
628
+ 32: 512,
629
+ 64: 256 * channel_multiplier,
630
+ 128: 128 * channel_multiplier,
631
+ 256: 64 * channel_multiplier,
632
+ 512: 32 * channel_multiplier,
633
+ 1024: 16 * channel_multiplier,
634
+ }
635
+
636
+ convs = [ConvLayer(3, channels[size], 1)]
637
+
638
+ log_size = int(math.log(size, 2))
639
+
640
+ in_channel = channels[size]
641
+
642
+ for i in range(log_size, 2, -1):
643
+ out_channel = channels[2 ** (i - 1)]
644
+
645
+ convs.append(ResBlock(in_channel, out_channel, blur_kernel))
646
+
647
+ in_channel = out_channel
648
+
649
+ self.convs = nn.Sequential(*convs)
650
+
651
+ self.stddev_group = 4
652
+ self.stddev_feat = 1
653
+
654
+ self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
655
+ self.final_linear = nn.Sequential(
656
+ EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
657
+ EqualLinear(channels[4], 1),
658
+ )
659
+
660
+ def forward(self, input):
661
+ out = self.convs(input)
662
+
663
+ batch, channel, height, width = out.shape
664
+ group = min(batch, self.stddev_group)
665
+ stddev = out.view(
666
+ group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
667
+ )
668
+ stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
669
+ stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
670
+ stddev = stddev.repeat(group, 1, height, width)
671
+ out = torch.cat([out, stddev], 1)
672
+
673
+ out = self.final_conv(out)
674
+
675
+ out = out.view(batch, -1)
676
+ out = self.final_linear(out)
677
+
678
+ return out
e4e/options/__init__.py ADDED
File without changes
e4e/options/train_options.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from argparse import ArgumentParser
2
+ from configs.paths_config import model_paths
3
+
4
+
5
+ class TrainOptions:
6
+
7
+ def __init__(self):
8
+ self.parser = ArgumentParser()
9
+ self.initialize()
10
+
11
+ def initialize(self):
12
+ self.parser.add_argument('--exp_dir', type=str, help='Path to experiment output directory')
13
+ self.parser.add_argument('--dataset_type', default='ffhq_encode', type=str,
14
+ help='Type of dataset/experiment to run')
15
+ self.parser.add_argument('--encoder_type', default='Encoder4Editing', type=str, help='Which encoder to use')
16
+
17
+ self.parser.add_argument('--batch_size', default=4, type=int, help='Batch size for training')
18
+ self.parser.add_argument('--test_batch_size', default=2, type=int, help='Batch size for testing and inference')
19
+ self.parser.add_argument('--workers', default=4, type=int, help='Number of train dataloader workers')
20
+ self.parser.add_argument('--test_workers', default=2, type=int,
21
+ help='Number of test/inference dataloader workers')
22
+
23
+ self.parser.add_argument('--learning_rate', default=0.0001, type=float, help='Optimizer learning rate')
24
+ self.parser.add_argument('--optim_name', default='ranger', type=str, help='Which optimizer to use')
25
+ self.parser.add_argument('--train_decoder', default=False, type=bool, help='Whether to train the decoder model')
26
+ self.parser.add_argument('--start_from_latent_avg', action='store_true',
27
+ help='Whether to add average latent vector to generate codes from encoder.')
28
+ self.parser.add_argument('--lpips_type', default='alex', type=str, help='LPIPS backbone')
29
+
30
+ self.parser.add_argument('--lpips_lambda', default=0.8, type=float, help='LPIPS loss multiplier factor')
31
+ self.parser.add_argument('--id_lambda', default=0.1, type=float, help='ID loss multiplier factor')
32
+ self.parser.add_argument('--l2_lambda', default=1.0, type=float, help='L2 loss multiplier factor')
33
+
34
+ self.parser.add_argument('--stylegan_weights', default=model_paths['stylegan_ffhq'], type=str,
35
+ help='Path to StyleGAN model weights')
36
+ self.parser.add_argument('--stylegan_size', default=1024, type=int,
37
+ help='size of pretrained StyleGAN Generator')
38
+ self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to pSp model checkpoint')
39
+
40
+ self.parser.add_argument('--max_steps', default=500000, type=int, help='Maximum number of training steps')
41
+ self.parser.add_argument('--image_interval', default=100, type=int,
42
+ help='Interval for logging train images during training')
43
+ self.parser.add_argument('--board_interval', default=50, type=int,
44
+ help='Interval for logging metrics to tensorboard')
45
+ self.parser.add_argument('--val_interval', default=1000, type=int, help='Validation interval')
46
+ self.parser.add_argument('--save_interval', default=None, type=int, help='Model checkpoint interval')
47
+
48
+ # Discriminator flags
49
+ self.parser.add_argument('--w_discriminator_lambda', default=0, type=float, help='Dw loss multiplier')
50
+ self.parser.add_argument('--w_discriminator_lr', default=2e-5, type=float, help='Dw learning rate')
51
+ self.parser.add_argument("--r1", type=float, default=10, help="weight of the r1 regularization")
52
+ self.parser.add_argument("--d_reg_every", type=int, default=16,
53
+ help="interval for applying r1 regularization")
54
+ self.parser.add_argument('--use_w_pool', action='store_true',
55
+ help='Whether to store a latnet codes pool for the discriminator\'s training')
56
+ self.parser.add_argument("--w_pool_size", type=int, default=50,
57
+ help="W\'s pool size, depends on --use_w_pool")
58
+
59
+ # e4e specific
60
+ self.parser.add_argument('--delta_norm', type=int, default=2, help="norm type of the deltas")
61
+ self.parser.add_argument('--delta_norm_lambda', type=float, default=2e-4, help="lambda for delta norm loss")
62
+
63
+ # Progressive training
64
+ self.parser.add_argument('--progressive_steps', nargs='+', type=int, default=None,
65
+ help="The training steps of training new deltas. steps[i] starts the delta_i training")
66
+ self.parser.add_argument('--progressive_start', type=int, default=None,
67
+ help="The training step to start training the deltas, overrides progressive_steps")
68
+ self.parser.add_argument('--progressive_step_every', type=int, default=2_000,
69
+ help="Amount of training steps for each progressive step")
70
+
71
+ # Save additional training info to enable future training continuation from produced checkpoints
72
+ self.parser.add_argument('--save_training_data', action='store_true',
73
+ help='Save intermediate training data to resume training from the checkpoint')
74
+ self.parser.add_argument('--sub_exp_dir', default=None, type=str, help='Name of sub experiment directory')
75
+ self.parser.add_argument('--keep_optimizer', action='store_true',
76
+ help='Whether to continue from the checkpoint\'s optimizer')
77
+ self.parser.add_argument('--resume_training_from_ckpt', default=None, type=str,
78
+ help='Path to training checkpoint, works when --save_training_data was set to True')
79
+ self.parser.add_argument('--update_param_list', nargs='+', type=str, default=None,
80
+ help="Name of training parameters to update the loaded training checkpoint")
81
+
82
+ def parse(self):
83
+ opts = self.parser.parse_args()
84
+ return opts
e4e/scripts/calc_losses_on_images.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from argparse import ArgumentParser
2
+ import os
3
+ import json
4
+ import sys
5
+ from tqdm import tqdm
6
+ import numpy as np
7
+ import torch
8
+ from torch.utils.data import DataLoader
9
+ import torchvision.transforms as transforms
10
+
11
+ sys.path.append(".")
12
+ sys.path.append("..")
13
+
14
+ from criteria.lpips.lpips import LPIPS
15
+ from datasets.gt_res_dataset import GTResDataset
16
+
17
+
18
+ def parse_args():
19
+ parser = ArgumentParser(add_help=False)
20
+ parser.add_argument('--mode', type=str, default='lpips', choices=['lpips', 'l2'])
21
+ parser.add_argument('--data_path', type=str, default='results')
22
+ parser.add_argument('--gt_path', type=str, default='gt_images')
23
+ parser.add_argument('--workers', type=int, default=4)
24
+ parser.add_argument('--batch_size', type=int, default=4)
25
+ parser.add_argument('--is_cars', action='store_true')
26
+ args = parser.parse_args()
27
+ return args
28
+
29
+
30
+ def run(args):
31
+ resize_dims = (256, 256)
32
+ if args.is_cars:
33
+ resize_dims = (192, 256)
34
+ transform = transforms.Compose([transforms.Resize(resize_dims),
35
+ transforms.ToTensor(),
36
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
37
+
38
+ print('Loading dataset')
39
+ dataset = GTResDataset(root_path=args.data_path,
40
+ gt_dir=args.gt_path,
41
+ transform=transform)
42
+
43
+ dataloader = DataLoader(dataset,
44
+ batch_size=args.batch_size,
45
+ shuffle=False,
46
+ num_workers=int(args.workers),
47
+ drop_last=True)
48
+
49
+ if args.mode == 'lpips':
50
+ loss_func = LPIPS(net_type='alex')
51
+ elif args.mode == 'l2':
52
+ loss_func = torch.nn.MSELoss()
53
+ else:
54
+ raise Exception('Not a valid mode!')
55
+ loss_func.cuda()
56
+
57
+ global_i = 0
58
+ scores_dict = {}
59
+ all_scores = []
60
+ for result_batch, gt_batch in tqdm(dataloader):
61
+ for i in range(args.batch_size):
62
+ loss = float(loss_func(result_batch[i:i + 1].cuda(), gt_batch[i:i + 1].cuda()))
63
+ all_scores.append(loss)
64
+ im_path = dataset.pairs[global_i][0]
65
+ scores_dict[os.path.basename(im_path)] = loss
66
+ global_i += 1
67
+
68
+ all_scores = list(scores_dict.values())
69
+ mean = np.mean(all_scores)
70
+ std = np.std(all_scores)
71
+ result_str = 'Average loss is {:.2f}+-{:.2f}'.format(mean, std)
72
+ print('Finished with ', args.data_path)
73
+ print(result_str)
74
+
75
+ out_path = os.path.join(os.path.dirname(args.data_path), 'inference_metrics')
76
+ if not os.path.exists(out_path):
77
+ os.makedirs(out_path)
78
+
79
+ with open(os.path.join(out_path, 'stat_{}.txt'.format(args.mode)), 'w') as f:
80
+ f.write(result_str)
81
+ with open(os.path.join(out_path, 'scores_{}.json'.format(args.mode)), 'w') as f:
82
+ json.dump(scores_dict, f)
83
+
84
+
85
+ if __name__ == '__main__':
86
+ args = parse_args()
87
+ run(args)
e4e/scripts/inference.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ import torch
4
+ import numpy as np
5
+ import sys
6
+ import os
7
+ import dlib
8
+
9
+ sys.path.append(".")
10
+ sys.path.append("..")
11
+
12
+ from configs import data_configs, paths_config
13
+ from datasets.inference_dataset import InferenceDataset
14
+ from torch.utils.data import DataLoader
15
+ from utils.model_utils import setup_model
16
+ from utils.common import tensor2im
17
+ from utils.alignment import align_face
18
+ from PIL import Image
19
+
20
+
21
+ def main(args):
22
+ net, opts = setup_model(args.ckpt, device)
23
+ is_cars = 'cars_' in opts.dataset_type
24
+ generator = net.decoder
25
+ generator.eval()
26
+ args, data_loader = setup_data_loader(args, opts)
27
+
28
+ # Check if latents exist
29
+ latents_file_path = os.path.join(args.save_dir, 'latents.pt')
30
+ if os.path.exists(latents_file_path):
31
+ latent_codes = torch.load(latents_file_path).to(device)
32
+ else:
33
+ latent_codes = get_all_latents(net, data_loader, args.n_sample, is_cars=is_cars)
34
+ torch.save(latent_codes, latents_file_path)
35
+
36
+ if not args.latents_only:
37
+ generate_inversions(args, generator, latent_codes, is_cars=is_cars)
38
+
39
+
40
+ def setup_data_loader(args, opts):
41
+ dataset_args = data_configs.DATASETS[opts.dataset_type]
42
+ transforms_dict = dataset_args['transforms'](opts).get_transforms()
43
+ images_path = args.images_dir if args.images_dir is not None else dataset_args['test_source_root']
44
+ print(f"images path: {images_path}")
45
+ align_function = None
46
+ if args.align:
47
+ align_function = run_alignment
48
+ test_dataset = InferenceDataset(root=images_path,
49
+ transform=transforms_dict['transform_test'],
50
+ preprocess=align_function,
51
+ opts=opts)
52
+
53
+ data_loader = DataLoader(test_dataset,
54
+ batch_size=args.batch,
55
+ shuffle=False,
56
+ num_workers=2,
57
+ drop_last=True)
58
+
59
+ print(f'dataset length: {len(test_dataset)}')
60
+
61
+ if args.n_sample is None:
62
+ args.n_sample = len(test_dataset)
63
+ return args, data_loader
64
+
65
+
66
+ def get_latents(net, x, is_cars=False):
67
+ codes = net.encoder(x)
68
+ if net.opts.start_from_latent_avg:
69
+ if codes.ndim == 2:
70
+ codes = codes + net.latent_avg.repeat(codes.shape[0], 1, 1)[:, 0, :]
71
+ else:
72
+ codes = codes + net.latent_avg.repeat(codes.shape[0], 1, 1)
73
+ if codes.shape[1] == 18 and is_cars:
74
+ codes = codes[:, :16, :]
75
+ return codes
76
+
77
+
78
+ def get_all_latents(net, data_loader, n_images=None, is_cars=False):
79
+ all_latents = []
80
+ i = 0
81
+ with torch.no_grad():
82
+ for batch in data_loader:
83
+ if n_images is not None and i > n_images:
84
+ break
85
+ x = batch
86
+ inputs = x.to(device).float()
87
+ latents = get_latents(net, inputs, is_cars)
88
+ all_latents.append(latents)
89
+ i += len(latents)
90
+ return torch.cat(all_latents)
91
+
92
+
93
+ def save_image(img, save_dir, idx):
94
+ result = tensor2im(img)
95
+ im_save_path = os.path.join(save_dir, f"{idx:05d}.jpg")
96
+ Image.fromarray(np.array(result)).save(im_save_path)
97
+
98
+
99
+ @torch.no_grad()
100
+ def generate_inversions(args, g, latent_codes, is_cars):
101
+ print('Saving inversion images')
102
+ inversions_directory_path = os.path.join(args.save_dir, 'inversions')
103
+ os.makedirs(inversions_directory_path, exist_ok=True)
104
+ for i in range(args.n_sample):
105
+ imgs, _ = g([latent_codes[i].unsqueeze(0)], input_is_latent=True, randomize_noise=False, return_latents=True)
106
+ if is_cars:
107
+ imgs = imgs[:, :, 64:448, :]
108
+ save_image(imgs[0], inversions_directory_path, i + 1)
109
+
110
+
111
+ def run_alignment(image_path):
112
+ predictor = dlib.shape_predictor(paths_config.model_paths['shape_predictor'])
113
+ aligned_image = align_face(filepath=image_path, predictor=predictor)
114
+ print("Aligned image has shape: {}".format(aligned_image.size))
115
+ return aligned_image
116
+
117
+
118
+ if __name__ == "__main__":
119
+ device = "cuda"
120
+
121
+ parser = argparse.ArgumentParser(description="Inference")
122
+ parser.add_argument("--images_dir", type=str, default=None,
123
+ help="The directory of the images to be inverted")
124
+ parser.add_argument("--save_dir", type=str, default=None,
125
+ help="The directory to save the latent codes and inversion images. (default: images_dir")
126
+ parser.add_argument("--batch", type=int, default=1, help="batch size for the generator")
127
+ parser.add_argument("--n_sample", type=int, default=None, help="number of the samples to infer.")
128
+ parser.add_argument("--latents_only", action="store_true", help="infer only the latent codes of the directory")
129
+ parser.add_argument("--align", action="store_true", help="align face images before inference")
130
+ parser.add_argument("ckpt", metavar="CHECKPOINT", help="path to generator checkpoint")
131
+
132
+ args = parser.parse_args()
133
+ main(args)
e4e/scripts/train.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file runs the main training/val loop
3
+ """
4
+ import os
5
+ import json
6
+ import math
7
+ import sys
8
+ import pprint
9
+ import torch
10
+ from argparse import Namespace
11
+
12
+ sys.path.append(".")
13
+ sys.path.append("..")
14
+
15
+ from options.train_options import TrainOptions
16
+ from training.coach import Coach
17
+
18
+
19
+ def main():
20
+ opts = TrainOptions().parse()
21
+ previous_train_ckpt = None
22
+ if opts.resume_training_from_ckpt:
23
+ opts, previous_train_ckpt = load_train_checkpoint(opts)
24
+ else:
25
+ setup_progressive_steps(opts)
26
+ create_initial_experiment_dir(opts)
27
+
28
+ coach = Coach(opts, previous_train_ckpt)
29
+ coach.train()
30
+
31
+
32
+ def load_train_checkpoint(opts):
33
+ train_ckpt_path = opts.resume_training_from_ckpt
34
+ previous_train_ckpt = torch.load(opts.resume_training_from_ckpt, map_location='cpu')
35
+ new_opts_dict = vars(opts)
36
+ opts = previous_train_ckpt['opts']
37
+ opts['resume_training_from_ckpt'] = train_ckpt_path
38
+ update_new_configs(opts, new_opts_dict)
39
+ pprint.pprint(opts)
40
+ opts = Namespace(**opts)
41
+ if opts.sub_exp_dir is not None:
42
+ sub_exp_dir = opts.sub_exp_dir
43
+ opts.exp_dir = os.path.join(opts.exp_dir, sub_exp_dir)
44
+ create_initial_experiment_dir(opts)
45
+ return opts, previous_train_ckpt
46
+
47
+
48
+ def setup_progressive_steps(opts):
49
+ log_size = int(math.log(opts.stylegan_size, 2))
50
+ num_style_layers = 2*log_size - 2
51
+ num_deltas = num_style_layers - 1
52
+ if opts.progressive_start is not None: # If progressive delta training
53
+ opts.progressive_steps = [0]
54
+ next_progressive_step = opts.progressive_start
55
+ for i in range(num_deltas):
56
+ opts.progressive_steps.append(next_progressive_step)
57
+ next_progressive_step += opts.progressive_step_every
58
+
59
+ assert opts.progressive_steps is None or is_valid_progressive_steps(opts, num_style_layers), \
60
+ "Invalid progressive training input"
61
+
62
+
63
+ def is_valid_progressive_steps(opts, num_style_layers):
64
+ return len(opts.progressive_steps) == num_style_layers and opts.progressive_steps[0] == 0
65
+
66
+
67
+ def create_initial_experiment_dir(opts):
68
+ if os.path.exists(opts.exp_dir):
69
+ raise Exception('Oops... {} already exists'.format(opts.exp_dir))
70
+ os.makedirs(opts.exp_dir)
71
+
72
+ opts_dict = vars(opts)
73
+ pprint.pprint(opts_dict)
74
+ with open(os.path.join(opts.exp_dir, 'opt.json'), 'w') as f:
75
+ json.dump(opts_dict, f, indent=4, sort_keys=True)
76
+
77
+
78
+ def update_new_configs(ckpt_opts, new_opts):
79
+ for k, v in new_opts.items():
80
+ if k not in ckpt_opts:
81
+ ckpt_opts[k] = v
82
+ if new_opts['update_param_list']:
83
+ for param in new_opts['update_param_list']:
84
+ ckpt_opts[param] = new_opts[param]
85
+
86
+
87
+ if __name__ == '__main__':
88
+ main()
e4e/training/__init__.py ADDED
File without changes
e4e/training/coach.py ADDED
@@ -0,0 +1,437 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import matplotlib
4
+ import matplotlib.pyplot as plt
5
+
6
+ matplotlib.use('Agg')
7
+
8
+ import torch
9
+ from torch import nn, autograd
10
+ from torch.utils.data import DataLoader
11
+ from torch.utils.tensorboard import SummaryWriter
12
+ import torch.nn.functional as F
13
+
14
+ from utils import common, train_utils
15
+ from criteria import id_loss, moco_loss
16
+ from configs import data_configs
17
+ from datasets.images_dataset import ImagesDataset
18
+ from criteria.lpips.lpips import LPIPS
19
+ from models.psp import pSp
20
+ from models.latent_codes_pool import LatentCodesPool
21
+ from models.discriminator import LatentCodesDiscriminator
22
+ from models.encoders.psp_encoders import ProgressiveStage
23
+ from training.ranger import Ranger
24
+
25
+ random.seed(0)
26
+ torch.manual_seed(0)
27
+
28
+
29
+ class Coach:
30
+ def __init__(self, opts, prev_train_checkpoint=None):
31
+ self.opts = opts
32
+
33
+ self.global_step = 0
34
+
35
+ self.device = 'cuda:0'
36
+ self.opts.device = self.device
37
+ # Initialize network
38
+ self.net = pSp(self.opts).to(self.device)
39
+
40
+ # Initialize loss
41
+ if self.opts.lpips_lambda > 0:
42
+ self.lpips_loss = LPIPS(net_type=self.opts.lpips_type).to(self.device).eval()
43
+ if self.opts.id_lambda > 0:
44
+ if 'ffhq' in self.opts.dataset_type or 'celeb' in self.opts.dataset_type:
45
+ self.id_loss = id_loss.IDLoss().to(self.device).eval()
46
+ else:
47
+ self.id_loss = moco_loss.MocoLoss(opts).to(self.device).eval()
48
+ self.mse_loss = nn.MSELoss().to(self.device).eval()
49
+
50
+ # Initialize optimizer
51
+ self.optimizer = self.configure_optimizers()
52
+
53
+ # Initialize discriminator
54
+ if self.opts.w_discriminator_lambda > 0:
55
+ self.discriminator = LatentCodesDiscriminator(512, 4).to(self.device)
56
+ self.discriminator_optimizer = torch.optim.Adam(list(self.discriminator.parameters()),
57
+ lr=opts.w_discriminator_lr)
58
+ self.real_w_pool = LatentCodesPool(self.opts.w_pool_size)
59
+ self.fake_w_pool = LatentCodesPool(self.opts.w_pool_size)
60
+
61
+ # Initialize dataset
62
+ self.train_dataset, self.test_dataset = self.configure_datasets()
63
+ self.train_dataloader = DataLoader(self.train_dataset,
64
+ batch_size=self.opts.batch_size,
65
+ shuffle=True,
66
+ num_workers=int(self.opts.workers),
67
+ drop_last=True)
68
+ self.test_dataloader = DataLoader(self.test_dataset,
69
+ batch_size=self.opts.test_batch_size,
70
+ shuffle=False,
71
+ num_workers=int(self.opts.test_workers),
72
+ drop_last=True)
73
+
74
+ # Initialize logger
75
+ log_dir = os.path.join(opts.exp_dir, 'logs')
76
+ os.makedirs(log_dir, exist_ok=True)
77
+ self.logger = SummaryWriter(log_dir=log_dir)
78
+
79
+ # Initialize checkpoint dir
80
+ self.checkpoint_dir = os.path.join(opts.exp_dir, 'checkpoints')
81
+ os.makedirs(self.checkpoint_dir, exist_ok=True)
82
+ self.best_val_loss = None
83
+ if self.opts.save_interval is None:
84
+ self.opts.save_interval = self.opts.max_steps
85
+
86
+ if prev_train_checkpoint is not None:
87
+ self.load_from_train_checkpoint(prev_train_checkpoint)
88
+ prev_train_checkpoint = None
89
+
90
+ def load_from_train_checkpoint(self, ckpt):
91
+ print('Loading previous training data...')
92
+ self.global_step = ckpt['global_step'] + 1
93
+ self.best_val_loss = ckpt['best_val_loss']
94
+ self.net.load_state_dict(ckpt['state_dict'])
95
+
96
+ if self.opts.keep_optimizer:
97
+ self.optimizer.load_state_dict(ckpt['optimizer'])
98
+ if self.opts.w_discriminator_lambda > 0:
99
+ self.discriminator.load_state_dict(ckpt['discriminator_state_dict'])
100
+ self.discriminator_optimizer.load_state_dict(ckpt['discriminator_optimizer_state_dict'])
101
+ if self.opts.progressive_steps:
102
+ self.check_for_progressive_training_update(is_resume_from_ckpt=True)
103
+ print(f'Resuming training from step {self.global_step}')
104
+
105
+ def train(self):
106
+ self.net.train()
107
+ if self.opts.progressive_steps:
108
+ self.check_for_progressive_training_update()
109
+ while self.global_step < self.opts.max_steps:
110
+ for batch_idx, batch in enumerate(self.train_dataloader):
111
+ loss_dict = {}
112
+ if self.is_training_discriminator():
113
+ loss_dict = self.train_discriminator(batch)
114
+ x, y, y_hat, latent = self.forward(batch)
115
+ loss, encoder_loss_dict, id_logs = self.calc_loss(x, y, y_hat, latent)
116
+ loss_dict = {**loss_dict, **encoder_loss_dict}
117
+ self.optimizer.zero_grad()
118
+ loss.backward()
119
+ self.optimizer.step()
120
+
121
+ # Logging related
122
+ if self.global_step % self.opts.image_interval == 0 or (
123
+ self.global_step < 1000 and self.global_step % 25 == 0):
124
+ self.parse_and_log_images(id_logs, x, y, y_hat, title='images/train/faces')
125
+ if self.global_step % self.opts.board_interval == 0:
126
+ self.print_metrics(loss_dict, prefix='train')
127
+ self.log_metrics(loss_dict, prefix='train')
128
+
129
+ # Validation related
130
+ val_loss_dict = None
131
+ if self.global_step % self.opts.val_interval == 0 or self.global_step == self.opts.max_steps:
132
+ val_loss_dict = self.validate()
133
+ if val_loss_dict and (self.best_val_loss is None or val_loss_dict['loss'] < self.best_val_loss):
134
+ self.best_val_loss = val_loss_dict['loss']
135
+ self.checkpoint_me(val_loss_dict, is_best=True)
136
+
137
+ if self.global_step % self.opts.save_interval == 0 or self.global_step == self.opts.max_steps:
138
+ if val_loss_dict is not None:
139
+ self.checkpoint_me(val_loss_dict, is_best=False)
140
+ else:
141
+ self.checkpoint_me(loss_dict, is_best=False)
142
+
143
+ if self.global_step == self.opts.max_steps:
144
+ print('OMG, finished training!')
145
+ break
146
+
147
+ self.global_step += 1
148
+ if self.opts.progressive_steps:
149
+ self.check_for_progressive_training_update()
150
+
151
+ def check_for_progressive_training_update(self, is_resume_from_ckpt=False):
152
+ for i in range(len(self.opts.progressive_steps)):
153
+ if is_resume_from_ckpt and self.global_step >= self.opts.progressive_steps[i]: # Case checkpoint
154
+ self.net.encoder.set_progressive_stage(ProgressiveStage(i))
155
+ if self.global_step == self.opts.progressive_steps[i]: # Case training reached progressive step
156
+ self.net.encoder.set_progressive_stage(ProgressiveStage(i))
157
+
158
+ def validate(self):
159
+ self.net.eval()
160
+ agg_loss_dict = []
161
+ for batch_idx, batch in enumerate(self.test_dataloader):
162
+ cur_loss_dict = {}
163
+ if self.is_training_discriminator():
164
+ cur_loss_dict = self.validate_discriminator(batch)
165
+ with torch.no_grad():
166
+ x, y, y_hat, latent = self.forward(batch)
167
+ loss, cur_encoder_loss_dict, id_logs = self.calc_loss(x, y, y_hat, latent)
168
+ cur_loss_dict = {**cur_loss_dict, **cur_encoder_loss_dict}
169
+ agg_loss_dict.append(cur_loss_dict)
170
+
171
+ # Logging related
172
+ self.parse_and_log_images(id_logs, x, y, y_hat,
173
+ title='images/test/faces',
174
+ subscript='{:04d}'.format(batch_idx))
175
+
176
+ # For first step just do sanity test on small amount of data
177
+ if self.global_step == 0 and batch_idx >= 4:
178
+ self.net.train()
179
+ return None # Do not log, inaccurate in first batch
180
+
181
+ loss_dict = train_utils.aggregate_loss_dict(agg_loss_dict)
182
+ self.log_metrics(loss_dict, prefix='test')
183
+ self.print_metrics(loss_dict, prefix='test')
184
+
185
+ self.net.train()
186
+ return loss_dict
187
+
188
+ def checkpoint_me(self, loss_dict, is_best):
189
+ save_name = 'best_model.pt' if is_best else 'iteration_{}.pt'.format(self.global_step)
190
+ save_dict = self.__get_save_dict()
191
+ checkpoint_path = os.path.join(self.checkpoint_dir, save_name)
192
+ torch.save(save_dict, checkpoint_path)
193
+ with open(os.path.join(self.checkpoint_dir, 'timestamp.txt'), 'a') as f:
194
+ if is_best:
195
+ f.write(
196
+ '**Best**: Step - {}, Loss - {:.3f} \n{}\n'.format(self.global_step, self.best_val_loss, loss_dict))
197
+ else:
198
+ f.write('Step - {}, \n{}\n'.format(self.global_step, loss_dict))
199
+
200
+ def configure_optimizers(self):
201
+ params = list(self.net.encoder.parameters())
202
+ if self.opts.train_decoder:
203
+ params += list(self.net.decoder.parameters())
204
+ else:
205
+ self.requires_grad(self.net.decoder, False)
206
+ if self.opts.optim_name == 'adam':
207
+ optimizer = torch.optim.Adam(params, lr=self.opts.learning_rate)
208
+ else:
209
+ optimizer = Ranger(params, lr=self.opts.learning_rate)
210
+ return optimizer
211
+
212
+ def configure_datasets(self):
213
+ if self.opts.dataset_type not in data_configs.DATASETS.keys():
214
+ Exception('{} is not a valid dataset_type'.format(self.opts.dataset_type))
215
+ print('Loading dataset for {}'.format(self.opts.dataset_type))
216
+ dataset_args = data_configs.DATASETS[self.opts.dataset_type]
217
+ transforms_dict = dataset_args['transforms'](self.opts).get_transforms()
218
+ train_dataset = ImagesDataset(source_root=dataset_args['train_source_root'],
219
+ target_root=dataset_args['train_target_root'],
220
+ source_transform=transforms_dict['transform_source'],
221
+ target_transform=transforms_dict['transform_gt_train'],
222
+ opts=self.opts)
223
+ test_dataset = ImagesDataset(source_root=dataset_args['test_source_root'],
224
+ target_root=dataset_args['test_target_root'],
225
+ source_transform=transforms_dict['transform_source'],
226
+ target_transform=transforms_dict['transform_test'],
227
+ opts=self.opts)
228
+ print("Number of training samples: {}".format(len(train_dataset)))
229
+ print("Number of test samples: {}".format(len(test_dataset)))
230
+ return train_dataset, test_dataset
231
+
232
+ def calc_loss(self, x, y, y_hat, latent):
233
+ loss_dict = {}
234
+ loss = 0.0
235
+ id_logs = None
236
+ if self.is_training_discriminator(): # Adversarial loss
237
+ loss_disc = 0.
238
+ dims_to_discriminate = self.get_dims_to_discriminate() if self.is_progressive_training() else \
239
+ list(range(self.net.decoder.n_latent))
240
+
241
+ for i in dims_to_discriminate:
242
+ w = latent[:, i, :]
243
+ fake_pred = self.discriminator(w)
244
+ loss_disc += F.softplus(-fake_pred).mean()
245
+ loss_disc /= len(dims_to_discriminate)
246
+ loss_dict['encoder_discriminator_loss'] = float(loss_disc)
247
+ loss += self.opts.w_discriminator_lambda * loss_disc
248
+
249
+ if self.opts.progressive_steps and self.net.encoder.progressive_stage.value != 18: # delta regularization loss
250
+ total_delta_loss = 0
251
+ deltas_latent_dims = self.net.encoder.get_deltas_starting_dimensions()
252
+
253
+ first_w = latent[:, 0, :]
254
+ for i in range(1, self.net.encoder.progressive_stage.value + 1):
255
+ curr_dim = deltas_latent_dims[i]
256
+ delta = latent[:, curr_dim, :] - first_w
257
+ delta_loss = torch.norm(delta, self.opts.delta_norm, dim=1).mean()
258
+ loss_dict[f"delta{i}_loss"] = float(delta_loss)
259
+ total_delta_loss += delta_loss
260
+ loss_dict['total_delta_loss'] = float(total_delta_loss)
261
+ loss += self.opts.delta_norm_lambda * total_delta_loss
262
+
263
+ if self.opts.id_lambda > 0: # Similarity loss
264
+ loss_id, sim_improvement, id_logs = self.id_loss(y_hat, y, x)
265
+ loss_dict['loss_id'] = float(loss_id)
266
+ loss_dict['id_improve'] = float(sim_improvement)
267
+ loss += loss_id * self.opts.id_lambda
268
+ if self.opts.l2_lambda > 0:
269
+ loss_l2 = F.mse_loss(y_hat, y)
270
+ loss_dict['loss_l2'] = float(loss_l2)
271
+ loss += loss_l2 * self.opts.l2_lambda
272
+ if self.opts.lpips_lambda > 0:
273
+ loss_lpips = self.lpips_loss(y_hat, y)
274
+ loss_dict['loss_lpips'] = float(loss_lpips)
275
+ loss += loss_lpips * self.opts.lpips_lambda
276
+ loss_dict['loss'] = float(loss)
277
+ return loss, loss_dict, id_logs
278
+
279
+ def forward(self, batch):
280
+ x, y = batch
281
+ x, y = x.to(self.device).float(), y.to(self.device).float()
282
+ y_hat, latent = self.net.forward(x, return_latents=True)
283
+ if self.opts.dataset_type == "cars_encode":
284
+ y_hat = y_hat[:, :, 32:224, :]
285
+ return x, y, y_hat, latent
286
+
287
+ def log_metrics(self, metrics_dict, prefix):
288
+ for key, value in metrics_dict.items():
289
+ self.logger.add_scalar('{}/{}'.format(prefix, key), value, self.global_step)
290
+
291
+ def print_metrics(self, metrics_dict, prefix):
292
+ print('Metrics for {}, step {}'.format(prefix, self.global_step))
293
+ for key, value in metrics_dict.items():
294
+ print('\t{} = '.format(key), value)
295
+
296
+ def parse_and_log_images(self, id_logs, x, y, y_hat, title, subscript=None, display_count=2):
297
+ im_data = []
298
+ for i in range(display_count):
299
+ cur_im_data = {
300
+ 'input_face': common.log_input_image(x[i], self.opts),
301
+ 'target_face': common.tensor2im(y[i]),
302
+ 'output_face': common.tensor2im(y_hat[i]),
303
+ }
304
+ if id_logs is not None:
305
+ for key in id_logs[i]:
306
+ cur_im_data[key] = id_logs[i][key]
307
+ im_data.append(cur_im_data)
308
+ self.log_images(title, im_data=im_data, subscript=subscript)
309
+
310
+ def log_images(self, name, im_data, subscript=None, log_latest=False):
311
+ fig = common.vis_faces(im_data)
312
+ step = self.global_step
313
+ if log_latest:
314
+ step = 0
315
+ if subscript:
316
+ path = os.path.join(self.logger.log_dir, name, '{}_{:04d}.jpg'.format(subscript, step))
317
+ else:
318
+ path = os.path.join(self.logger.log_dir, name, '{:04d}.jpg'.format(step))
319
+ os.makedirs(os.path.dirname(path), exist_ok=True)
320
+ fig.savefig(path)
321
+ plt.close(fig)
322
+
323
+ def __get_save_dict(self):
324
+ save_dict = {
325
+ 'state_dict': self.net.state_dict(),
326
+ 'opts': vars(self.opts)
327
+ }
328
+ # save the latent avg in state_dict for inference if truncation of w was used during training
329
+ if self.opts.start_from_latent_avg:
330
+ save_dict['latent_avg'] = self.net.latent_avg
331
+
332
+ if self.opts.save_training_data: # Save necessary information to enable training continuation from checkpoint
333
+ save_dict['global_step'] = self.global_step
334
+ save_dict['optimizer'] = self.optimizer.state_dict()
335
+ save_dict['best_val_loss'] = self.best_val_loss
336
+ if self.opts.w_discriminator_lambda > 0:
337
+ save_dict['discriminator_state_dict'] = self.discriminator.state_dict()
338
+ save_dict['discriminator_optimizer_state_dict'] = self.discriminator_optimizer.state_dict()
339
+ return save_dict
340
+
341
+ def get_dims_to_discriminate(self):
342
+ deltas_starting_dimensions = self.net.encoder.get_deltas_starting_dimensions()
343
+ return deltas_starting_dimensions[:self.net.encoder.progressive_stage.value + 1]
344
+
345
+ def is_progressive_training(self):
346
+ return self.opts.progressive_steps is not None
347
+
348
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Discriminator ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
349
+
350
+ def is_training_discriminator(self):
351
+ return self.opts.w_discriminator_lambda > 0
352
+
353
+ @staticmethod
354
+ def discriminator_loss(real_pred, fake_pred, loss_dict):
355
+ real_loss = F.softplus(-real_pred).mean()
356
+ fake_loss = F.softplus(fake_pred).mean()
357
+
358
+ loss_dict['d_real_loss'] = float(real_loss)
359
+ loss_dict['d_fake_loss'] = float(fake_loss)
360
+
361
+ return real_loss + fake_loss
362
+
363
+ @staticmethod
364
+ def discriminator_r1_loss(real_pred, real_w):
365
+ grad_real, = autograd.grad(
366
+ outputs=real_pred.sum(), inputs=real_w, create_graph=True
367
+ )
368
+ grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()
369
+
370
+ return grad_penalty
371
+
372
+ @staticmethod
373
+ def requires_grad(model, flag=True):
374
+ for p in model.parameters():
375
+ p.requires_grad = flag
376
+
377
+ def train_discriminator(self, batch):
378
+ loss_dict = {}
379
+ x, _ = batch
380
+ x = x.to(self.device).float()
381
+ self.requires_grad(self.discriminator, True)
382
+
383
+ with torch.no_grad():
384
+ real_w, fake_w = self.sample_real_and_fake_latents(x)
385
+ real_pred = self.discriminator(real_w)
386
+ fake_pred = self.discriminator(fake_w)
387
+ loss = self.discriminator_loss(real_pred, fake_pred, loss_dict)
388
+ loss_dict['discriminator_loss'] = float(loss)
389
+
390
+ self.discriminator_optimizer.zero_grad()
391
+ loss.backward()
392
+ self.discriminator_optimizer.step()
393
+
394
+ # r1 regularization
395
+ d_regularize = self.global_step % self.opts.d_reg_every == 0
396
+ if d_regularize:
397
+ real_w = real_w.detach()
398
+ real_w.requires_grad = True
399
+ real_pred = self.discriminator(real_w)
400
+ r1_loss = self.discriminator_r1_loss(real_pred, real_w)
401
+
402
+ self.discriminator.zero_grad()
403
+ r1_final_loss = self.opts.r1 / 2 * r1_loss * self.opts.d_reg_every + 0 * real_pred[0]
404
+ r1_final_loss.backward()
405
+ self.discriminator_optimizer.step()
406
+ loss_dict['discriminator_r1_loss'] = float(r1_final_loss)
407
+
408
+ # Reset to previous state
409
+ self.requires_grad(self.discriminator, False)
410
+
411
+ return loss_dict
412
+
413
+ def validate_discriminator(self, test_batch):
414
+ with torch.no_grad():
415
+ loss_dict = {}
416
+ x, _ = test_batch
417
+ x = x.to(self.device).float()
418
+ real_w, fake_w = self.sample_real_and_fake_latents(x)
419
+ real_pred = self.discriminator(real_w)
420
+ fake_pred = self.discriminator(fake_w)
421
+ loss = self.discriminator_loss(real_pred, fake_pred, loss_dict)
422
+ loss_dict['discriminator_loss'] = float(loss)
423
+ return loss_dict
424
+
425
+ def sample_real_and_fake_latents(self, x):
426
+ sample_z = torch.randn(self.opts.batch_size, 512, device=self.device)
427
+ real_w = self.net.decoder.get_latent(sample_z)
428
+ fake_w = self.net.encoder(x)
429
+ if self.is_progressive_training(): # When progressive training, feed only unique w's
430
+ dims_to_discriminate = self.get_dims_to_discriminate()
431
+ fake_w = fake_w[:, dims_to_discriminate, :]
432
+ if self.opts.use_w_pool:
433
+ real_w = self.real_w_pool.query(real_w)
434
+ fake_w = self.fake_w_pool.query(fake_w)
435
+ if fake_w.ndim == 3:
436
+ fake_w = fake_w[:, 0, :]
437
+ return real_w, fake_w
e4e/training/ranger.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ranger deep learning optimizer - RAdam + Lookahead + Gradient Centralization, combined into one optimizer.
2
+
3
+ # https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
4
+ # and/or
5
+ # https://github.com/lessw2020/Best-Deep-Learning-Optimizers
6
+
7
+ # Ranger has now been used to capture 12 records on the FastAI leaderboard.
8
+
9
+ # This version = 20.4.11
10
+
11
+ # Credits:
12
+ # Gradient Centralization --> https://arxiv.org/abs/2004.01461v2 (a new optimization technique for DNNs), github: https://github.com/Yonghongwei/Gradient-Centralization
13
+ # RAdam --> https://github.com/LiyuanLucasLiu/RAdam
14
+ # Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code.
15
+ # Lookahead paper --> MZhang,G Hinton https://arxiv.org/abs/1907.08610
16
+
17
+ # summary of changes:
18
+ # 4/11/20 - add gradient centralization option. Set new testing benchmark for accuracy with it, toggle with use_gc flag at init.
19
+ # full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights),
20
+ # supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues.
21
+ # changes 8/31/19 - fix references to *self*.N_sma_threshold;
22
+ # changed eps to 1e-5 as better default than 1e-8.
23
+
24
+ import math
25
+ import torch
26
+ from torch.optim.optimizer import Optimizer
27
+
28
+
29
+ class Ranger(Optimizer):
30
+
31
+ def __init__(self, params, lr=1e-3, # lr
32
+ alpha=0.5, k=6, N_sma_threshhold=5, # Ranger options
33
+ betas=(.95, 0.999), eps=1e-5, weight_decay=0, # Adam options
34
+ use_gc=True, gc_conv_only=False
35
+ # Gradient centralization on or off, applied to conv layers only or conv + fc layers
36
+ ):
37
+
38
+ # parameter checks
39
+ if not 0.0 <= alpha <= 1.0:
40
+ raise ValueError(f'Invalid slow update rate: {alpha}')
41
+ if not 1 <= k:
42
+ raise ValueError(f'Invalid lookahead steps: {k}')
43
+ if not lr > 0:
44
+ raise ValueError(f'Invalid Learning Rate: {lr}')
45
+ if not eps > 0:
46
+ raise ValueError(f'Invalid eps: {eps}')
47
+
48
+ # parameter comments:
49
+ # beta1 (momentum) of .95 seems to work better than .90...
50
+ # N_sma_threshold of 5 seems better in testing than 4.
51
+ # In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you.
52
+
53
+ # prep defaults and init torch.optim base
54
+ defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, N_sma_threshhold=N_sma_threshhold,
55
+ eps=eps, weight_decay=weight_decay)
56
+ super().__init__(params, defaults)
57
+
58
+ # adjustable threshold
59
+ self.N_sma_threshhold = N_sma_threshhold
60
+
61
+ # look ahead params
62
+
63
+ self.alpha = alpha
64
+ self.k = k
65
+
66
+ # radam buffer for state
67
+ self.radam_buffer = [[None, None, None] for ind in range(10)]
68
+
69
+ # gc on or off
70
+ self.use_gc = use_gc
71
+
72
+ # level of gradient centralization
73
+ self.gc_gradient_threshold = 3 if gc_conv_only else 1
74
+
75
+ def __setstate__(self, state):
76
+ super(Ranger, self).__setstate__(state)
77
+
78
+ def step(self, closure=None):
79
+ loss = None
80
+
81
+ # Evaluate averages and grad, update param tensors
82
+ for group in self.param_groups:
83
+
84
+ for p in group['params']:
85
+ if p.grad is None:
86
+ continue
87
+ grad = p.grad.data.float()
88
+
89
+ if grad.is_sparse:
90
+ raise RuntimeError('Ranger optimizer does not support sparse gradients')
91
+
92
+ p_data_fp32 = p.data.float()
93
+
94
+ state = self.state[p] # get state dict for this param
95
+
96
+ if len(state) == 0: # if first time to run...init dictionary with our desired entries
97
+ # if self.first_run_check==0:
98
+ # self.first_run_check=1
99
+ # print("Initializing slow buffer...should not see this at load from saved model!")
100
+ state['step'] = 0
101
+ state['exp_avg'] = torch.zeros_like(p_data_fp32)
102
+ state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
103
+
104
+ # look ahead weight storage now in state dict
105
+ state['slow_buffer'] = torch.empty_like(p.data)
106
+ state['slow_buffer'].copy_(p.data)
107
+
108
+ else:
109
+ state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
110
+ state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
111
+
112
+ # begin computations
113
+ exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
114
+ beta1, beta2 = group['betas']
115
+
116
+ # GC operation for Conv layers and FC layers
117
+ if grad.dim() > self.gc_gradient_threshold:
118
+ grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True))
119
+
120
+ state['step'] += 1
121
+
122
+ # compute variance mov avg
123
+ exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
124
+ # compute mean moving avg
125
+ exp_avg.mul_(beta1).add_(1 - beta1, grad)
126
+
127
+ buffered = self.radam_buffer[int(state['step'] % 10)]
128
+
129
+ if state['step'] == buffered[0]:
130
+ N_sma, step_size = buffered[1], buffered[2]
131
+ else:
132
+ buffered[0] = state['step']
133
+ beta2_t = beta2 ** state['step']
134
+ N_sma_max = 2 / (1 - beta2) - 1
135
+ N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
136
+ buffered[1] = N_sma
137
+ if N_sma > self.N_sma_threshhold:
138
+ step_size = math.sqrt(
139
+ (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
140
+ N_sma_max - 2)) / (1 - beta1 ** state['step'])
141
+ else:
142
+ step_size = 1.0 / (1 - beta1 ** state['step'])
143
+ buffered[2] = step_size
144
+
145
+ if group['weight_decay'] != 0:
146
+ p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
147
+
148
+ # apply lr
149
+ if N_sma > self.N_sma_threshhold:
150
+ denom = exp_avg_sq.sqrt().add_(group['eps'])
151
+ p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom)
152
+ else:
153
+ p_data_fp32.add_(-step_size * group['lr'], exp_avg)
154
+
155
+ p.data.copy_(p_data_fp32)
156
+
157
+ # integrated look ahead...
158
+ # we do it at the param level instead of group level
159
+ if state['step'] % group['k'] == 0:
160
+ slow_p = state['slow_buffer'] # get access to slow param tensor
161
+ slow_p.add_(self.alpha, p.data - slow_p) # (fast weights - slow weights) * alpha
162
+ p.data.copy_(slow_p) # copy interpolated weights to RAdam param tensor
163
+
164
+ return loss
e4e/utils/__init__.py ADDED
File without changes
e4e/utils/alignment.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import PIL
3
+ import PIL.Image
4
+ import scipy
5
+ import scipy.ndimage
6
+ import dlib
7
+
8
+
9
+ def get_landmark(filepath, predictor):
10
+ """get landmark with dlib
11
+ :return: np.array shape=(68, 2)
12
+ """
13
+ detector = dlib.get_frontal_face_detector()
14
+
15
+ img = dlib.load_rgb_image(filepath)
16
+ dets = detector(img, 1)
17
+
18
+ for k, d in enumerate(dets):
19
+ shape = predictor(img, d)
20
+
21
+ t = list(shape.parts())
22
+ a = []
23
+ for tt in t:
24
+ a.append([tt.x, tt.y])
25
+ lm = np.array(a)
26
+ return lm
27
+
28
+
29
+ def align_face(filepath, predictor):
30
+ """
31
+ :param filepath: str
32
+ :return: PIL Image
33
+ """
34
+
35
+ lm = get_landmark(filepath, predictor)
36
+
37
+ lm_chin = lm[0: 17] # left-right
38
+ lm_eyebrow_left = lm[17: 22] # left-right
39
+ lm_eyebrow_right = lm[22: 27] # left-right
40
+ lm_nose = lm[27: 31] # top-down
41
+ lm_nostrils = lm[31: 36] # top-down
42
+ lm_eye_left = lm[36: 42] # left-clockwise
43
+ lm_eye_right = lm[42: 48] # left-clockwise
44
+ lm_mouth_outer = lm[48: 60] # left-clockwise
45
+ lm_mouth_inner = lm[60: 68] # left-clockwise
46
+
47
+ # Calculate auxiliary vectors.
48
+ eye_left = np.mean(lm_eye_left, axis=0)
49
+ eye_right = np.mean(lm_eye_right, axis=0)
50
+ eye_avg = (eye_left + eye_right) * 0.5
51
+ eye_to_eye = eye_right - eye_left
52
+ mouth_left = lm_mouth_outer[0]
53
+ mouth_right = lm_mouth_outer[6]
54
+ mouth_avg = (mouth_left + mouth_right) * 0.5
55
+ eye_to_mouth = mouth_avg - eye_avg
56
+
57
+ # Choose oriented crop rectangle.
58
+ x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
59
+ x /= np.hypot(*x)
60
+ x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
61
+ y = np.flipud(x) * [-1, 1]
62
+ c = eye_avg + eye_to_mouth * 0.1
63
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
64
+ qsize = np.hypot(*x) * 2
65
+
66
+ # read image
67
+ img = PIL.Image.open(filepath)
68
+
69
+ output_size = 256
70
+ transform_size = 256
71
+ enable_padding = True
72
+
73
+ # Shrink.
74
+ shrink = int(np.floor(qsize / output_size * 0.5))
75
+ if shrink > 1:
76
+ rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
77
+ img = img.resize(rsize, PIL.Image.ANTIALIAS)
78
+ quad /= shrink
79
+ qsize /= shrink
80
+
81
+ # Crop.
82
+ border = max(int(np.rint(qsize * 0.1)), 3)
83
+ crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
84
+ int(np.ceil(max(quad[:, 1]))))
85
+ crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]),
86
+ min(crop[3] + border, img.size[1]))
87
+ if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
88
+ img = img.crop(crop)
89
+ quad -= crop[0:2]
90
+
91
+ # Pad.
92
+ pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
93
+ int(np.ceil(max(quad[:, 1]))))
94
+ pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0),
95
+ max(pad[3] - img.size[1] + border, 0))
96
+ if enable_padding and max(pad) > border - 4:
97
+ pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
98
+ img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
99
+ h, w, _ = img.shape
100
+ y, x, _ = np.ogrid[:h, :w, :1]
101
+ mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]),
102
+ 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]))
103
+ blur = qsize * 0.02
104
+ img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
105
+ img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
106
+ img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
107
+ quad += pad[:2]
108
+
109
+ # Transform.
110
+ img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)
111
+ if output_size < transform_size:
112
+ img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)
113
+
114
+ # Return aligned image.
115
+ return img
e4e/utils/common.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import matplotlib.pyplot as plt
3
+
4
+
5
+ # Log images
6
+ def log_input_image(x, opts):
7
+ return tensor2im(x)
8
+
9
+
10
+ def tensor2im(var):
11
+ # var shape: (3, H, W)
12
+ var = var.cpu().detach().transpose(0, 2).transpose(0, 1).numpy()
13
+ var = ((var + 1) / 2)
14
+ var[var < 0] = 0
15
+ var[var > 1] = 1
16
+ var = var * 255
17
+ return Image.fromarray(var.astype('uint8'))
18
+
19
+
20
+ def vis_faces(log_hooks):
21
+ display_count = len(log_hooks)
22
+ fig = plt.figure(figsize=(8, 4 * display_count))
23
+ gs = fig.add_gridspec(display_count, 3)
24
+ for i in range(display_count):
25
+ hooks_dict = log_hooks[i]
26
+ fig.add_subplot(gs[i, 0])
27
+ if 'diff_input' in hooks_dict:
28
+ vis_faces_with_id(hooks_dict, fig, gs, i)
29
+ else:
30
+ vis_faces_no_id(hooks_dict, fig, gs, i)
31
+ plt.tight_layout()
32
+ return fig
33
+
34
+
35
+ def vis_faces_with_id(hooks_dict, fig, gs, i):
36
+ plt.imshow(hooks_dict['input_face'])
37
+ plt.title('Input\nOut Sim={:.2f}'.format(float(hooks_dict['diff_input'])))
38
+ fig.add_subplot(gs[i, 1])
39
+ plt.imshow(hooks_dict['target_face'])
40
+ plt.title('Target\nIn={:.2f}, Out={:.2f}'.format(float(hooks_dict['diff_views']),
41
+ float(hooks_dict['diff_target'])))
42
+ fig.add_subplot(gs[i, 2])
43
+ plt.imshow(hooks_dict['output_face'])
44
+ plt.title('Output\n Target Sim={:.2f}'.format(float(hooks_dict['diff_target'])))
45
+
46
+
47
+ def vis_faces_no_id(hooks_dict, fig, gs, i):
48
+ plt.imshow(hooks_dict['input_face'], cmap="gray")
49
+ plt.title('Input')
50
+ fig.add_subplot(gs[i, 1])
51
+ plt.imshow(hooks_dict['target_face'])
52
+ plt.title('Target')
53
+ fig.add_subplot(gs[i, 2])
54
+ plt.imshow(hooks_dict['output_face'])
55
+ plt.title('Output')
e4e/utils/data_utils.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Code adopted from pix2pixHD:
3
+ https://github.com/NVIDIA/pix2pixHD/blob/master/data/image_folder.py
4
+ """
5
+ import os
6
+
7
+ IMG_EXTENSIONS = [
8
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
9
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff'
10
+ ]
11
+
12
+
13
+ def is_image_file(filename):
14
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
15
+
16
+
17
+ def make_dataset(dir):
18
+ images = []
19
+ assert os.path.isdir(dir), '%s is not a valid directory' % dir
20
+ for root, _, fnames in sorted(os.walk(dir)):
21
+ for fname in fnames:
22
+ if is_image_file(fname):
23
+ path = os.path.join(root, fname)
24
+ images.append(path)
25
+ return images
e4e/utils/model_utils.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import argparse
3
+ from models.psp import pSp
4
+ from models.encoders.psp_encoders import Encoder4Editing
5
+
6
+
7
+ def setup_model(checkpoint_path, device='cuda'):
8
+ ckpt = torch.load(checkpoint_path, map_location='cpu')
9
+ opts = ckpt['opts']
10
+
11
+ opts['checkpoint_path'] = checkpoint_path
12
+ opts['device'] = device
13
+ opts = argparse.Namespace(**opts)
14
+
15
+ net = pSp(opts)
16
+ net.eval()
17
+ net = net.to(device)
18
+ return net, opts
19
+
20
+
21
+ def load_e4e_standalone(checkpoint_path, device='cuda'):
22
+ ckpt = torch.load(checkpoint_path, map_location='cpu')
23
+ opts = argparse.Namespace(**ckpt['opts'])
24
+ e4e = Encoder4Editing(50, 'ir_se', opts)
25
+ e4e_dict = {k.replace('encoder.', ''): v for k, v in ckpt['state_dict'].items() if k.startswith('encoder.')}
26
+ e4e.load_state_dict(e4e_dict)
27
+ e4e.eval()
28
+ e4e = e4e.to(device)
29
+ latent_avg = ckpt['latent_avg'].to(device)
30
+
31
+ def add_latent_avg(model, inputs, outputs):
32
+ return outputs + latent_avg.repeat(outputs.shape[0], 1, 1)
33
+
34
+ e4e.register_forward_hook(add_latent_avg)
35
+ return e4e
e4e/utils/train_utils.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ def aggregate_loss_dict(agg_loss_dict):
3
+ mean_vals = {}
4
+ for output in agg_loss_dict:
5
+ for key in output:
6
+ mean_vals[key] = mean_vals.setdefault(key, []) + [output[key]]
7
+ for key in mean_vals:
8
+ if len(mean_vals[key]) > 0:
9
+ mean_vals[key] = sum(mean_vals[key]) / len(mean_vals[key])
10
+ else:
11
+ print('{} has no value'.format(key))
12
+ mean_vals[key] = 0
13
+ return mean_vals