Spaces: Sleeping
IceClear committed
Commit a68feeb • Parent(s): 2abaa1b
init app
This view is limited to 50 files because it contains too many changes. See raw diff.
- StableSR/.gitignore +134 -0
- StableSR/LICENSE.txt +35 -0
- StableSR/README.md +175 -0
- StableSR/basicsr/__init__.py +12 -0
- StableSR/basicsr/archs/__init__.py +24 -0
- StableSR/basicsr/archs/arch_util.py +352 -0
- StableSR/basicsr/archs/basicvsr_arch.py +336 -0
- StableSR/basicsr/archs/basicvsrpp_arch.py +417 -0
- StableSR/basicsr/archs/degradat_arch.py +90 -0
- StableSR/basicsr/archs/dfdnet_arch.py +169 -0
- StableSR/basicsr/archs/dfdnet_util.py +162 -0
- StableSR/basicsr/archs/discriminator_arch.py +150 -0
- StableSR/basicsr/archs/duf_arch.py +276 -0
- StableSR/basicsr/archs/ecbsr_arch.py +275 -0
- StableSR/basicsr/archs/edsr_arch.py +61 -0
- StableSR/basicsr/archs/edvr_arch.py +382 -0
- StableSR/basicsr/archs/hifacegan_arch.py +260 -0
- StableSR/basicsr/archs/hifacegan_util.py +255 -0
- StableSR/basicsr/archs/inception.py +307 -0
- StableSR/basicsr/archs/rcan_arch.py +135 -0
- StableSR/basicsr/archs/ridnet_arch.py +180 -0
- StableSR/basicsr/archs/rrdbnet_arch.py +119 -0
- StableSR/basicsr/archs/spynet_arch.py +96 -0
- StableSR/basicsr/archs/srresnet_arch.py +65 -0
- StableSR/basicsr/archs/srvgg_arch.py +70 -0
- StableSR/basicsr/archs/stylegan2_arch.py +799 -0
- StableSR/basicsr/archs/stylegan2_bilinear_arch.py +614 -0
- StableSR/basicsr/archs/swinir_arch.py +956 -0
- StableSR/basicsr/archs/tof_arch.py +172 -0
- StableSR/basicsr/archs/vgg_arch.py +161 -0
- StableSR/basicsr/data/__init__.py +101 -0
- StableSR/basicsr/data/data_sampler.py +48 -0
- StableSR/basicsr/data/data_util.py +362 -0
- StableSR/basicsr/data/degradations.py +935 -0
- StableSR/basicsr/data/ffhq_dataset.py +80 -0
- StableSR/basicsr/data/ffhq_degradation_dataset.py +231 -0
- StableSR/basicsr/data/paired_image_dataset.py +115 -0
- StableSR/basicsr/data/prefetch_dataloader.py +122 -0
- StableSR/basicsr/data/realesrgan_dataset.py +242 -0
- StableSR/basicsr/data/realesrgan_paired_dataset.py +114 -0
- StableSR/basicsr/data/reds_dataset.py +352 -0
- StableSR/basicsr/data/single_image_dataset.py +164 -0
- StableSR/basicsr/data/transforms.py +240 -0
- StableSR/basicsr/data/video_test_dataset.py +283 -0
- StableSR/basicsr/data/vimeo90k_dataset.py +199 -0
- StableSR/basicsr/losses/__init__.py +31 -0
- StableSR/basicsr/losses/basic_loss.py +253 -0
- StableSR/basicsr/losses/gan_loss.py +207 -0
- StableSR/basicsr/losses/loss_util.py +145 -0
- StableSR/basicsr/metrics/README.md +48 -0
StableSR/.gitignore
ADDED
@@ -0,0 +1,134 @@
+# ignored folders
+logs/*
+models/*
+src/
+results/
+wandb/
+output/
+
+*.DS_Store
+.idea
+
+# ignored files
+version.py
+
+# ignored files with suffix
+*.html
+*.png
+*.jpeg
+*.jpg
+*.gif
+*.pth
+*.zip
+# *.txt
+*.svg
+*.ckpt
+
+# template
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+
+outputs/
StableSR/LICENSE.txt
ADDED
@@ -0,0 +1,35 @@
+S-Lab License 1.0
+
+Copyright 2022 S-Lab
+
+Redistribution and use for non-commercial purpose in source and
+binary forms, with or without modification, are permitted provided
+that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in
+   the documentation and/or other materials provided with the
+   distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+   contributors may be used to endorse or promote products derived
+   from this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+In the event that redistribution and/or use for commercial purpose in
+source or binary forms, with or without modification is required,
+please contact the contributor(s) of the work.
StableSR/README.md
ADDED
@@ -0,0 +1,175 @@
+<p align="center">
+  <img src="https://user-images.githubusercontent.com/22350795/236680126-0b1cdd62-d6fc-4620-b998-75ed6c31bf6f.png" height=40>
+</p>
+
+## Exploiting Diffusion Prior for Real-World Image Super-Resolution
+
+[Paper](https://arxiv.org/abs/2305.07015) | [Project Page](https://iceclear.github.io/projects/stablesr/) | [Video](https://www.youtube.com/watch?v=5MZy9Uhpkw4) | [WebUI](https://github.com/pkuliyi2015/sd-webui-stablesr) | [ModelScope](https://modelscope.cn/models/xhlin129/cv_stablesr_image-super-resolution/summary)
+
+<a href="https://colab.research.google.com/drive/11SE2_oDvbYtcuHDbaLAxsKk_o3flsO1T?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a> [![Replicate](https://img.shields.io/badge/Demo-%F0%9F%9A%80%20Replicate-blue)](https://replicate.com/cjwbw/stablesr) ![visitors](https://visitor-badge.laobi.icu/badge?page_id=IceClear/StableSR)
+
+[Jianyi Wang](https://iceclear.github.io/), [Zongsheng Yue](https://zsyoaoa.github.io/), [Shangchen Zhou](https://shangchenzhou.com/), [Kelvin C.K. Chan](https://ckkelvinchan.github.io/), [Chen Change Loy](https://www.mmlab-ntu.com/person/ccloy/)
+
+S-Lab, Nanyang Technological University
+
+<img src="assets/network.png" width="800px"/>
+
+:star: If StableSR is helpful to your images or projects, please help star this repo. Thanks! :hugs:
+
+### Update
+- **2023.07.31**: Integrated into :rocket: [Replicate](https://replicate.com/explore). Try out the online demo! [![Replicate](https://img.shields.io/badge/Demo-%F0%9F%9A%80%20Replicate-blue)](https://replicate.com/cjwbw/stablesr) Thanks to [Chenxi](https://github.com/chenxwh) for the implementation!
+- **2023.07.16**: You may reproduce the LDM baseline used in our paper with [LDM-SRtuning](https://github.com/IceClear/LDM-SRtuning) [![GitHub Stars](https://img.shields.io/github/stars/IceClear/LDM-SRtuning?style=social)](https://github.com/IceClear/LDM-SRtuning).
+- **2023.07.14**: :whale: [**ModelScope**](https://modelscope.cn/models/xhlin129/cv_stablesr_image-super-resolution/summary) for StableSR is released!
+- **2023.06.30**: :whale: A [**new model**](https://huggingface.co/Iceclear/StableSR/blob/main/stablesr_768v_000139.ckpt) trained on [SD-2.1-768v](https://huggingface.co/stabilityai/stable-diffusion-2-1) is released! Better performance with fewer artifacts!
+- **2023.06.28**: Support training on SD-2.1-768v.
+- **2023.05.22**: :whale: Improved the code to save more GPU memory; 128 --> 512 now needs 8.9G. Enabled starting from intermediate steps.
+- **2023.05.20**: :whale: The [**WebUI**](https://github.com/pkuliyi2015/sd-webui-stablesr) [![GitHub Stars](https://img.shields.io/github/stars/pkuliyi2015/sd-webui-stablesr?style=social)](https://github.com/pkuliyi2015/sd-webui-stablesr) of StableSR is available. Thanks to [Li Yi](https://github.com/pkuliyi2015) for the implementation!
+- **2023.05.13**: Add Colab demo of StableSR. <a href="https://colab.research.google.com/drive/11SE2_oDvbYtcuHDbaLAxsKk_o3flsO1T?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>
+- **2023.05.11**: Repo is released.
+
+### TODO
+- [ ] HuggingFace demo (if necessary)
+- [x] ~~Code release~~
+- [x] ~~Update link to paper and project page~~
+- [x] ~~Pretrained models~~
+- [x] ~~Colab demo~~
+- [x] ~~StableSR-768v released~~
+- [x] ~~Replicate demo~~
+
+### Demo on real-world SR
+
+[<img src="assets/imgsli_1.jpg" height="223px"/>](https://imgsli.com/MTc2MTI2) [<img src="assets/imgsli_2.jpg" height="223px"/>](https://imgsli.com/MTc2MTE2) [<img src="assets/imgsli_3.jpg" height="223px"/>](https://imgsli.com/MTc2MTIw)
+[<img src="assets/imgsli_8.jpg" height="223px"/>](https://imgsli.com/MTc2MjUy) [<img src="assets/imgsli_4.jpg" height="223px"/>](https://imgsli.com/MTc2MTMy) [<img src="assets/imgsli_5.jpg" height="223px"/>](https://imgsli.com/MTc2MTMz)
+[<img src="assets/imgsli_9.jpg" height="214px"/>](https://imgsli.com/MTc2MjQ5) [<img src="assets/imgsli_6.jpg" height="214px"/>](https://imgsli.com/MTc2MTM0) [<img src="assets/imgsli_7.jpg" height="214px"/>](https://imgsli.com/MTc2MTM2) [<img src="assets/imgsli_10.jpg" height="214px"/>](https://imgsli.com/MTc2MjU0)
+
+For more evaluation, please refer to our [paper](https://arxiv.org/abs/2305.07015) for details.
+
+### Demo on 4K Results
+
+- StableSR can in theory achieve arbitrary upscaling; below is an 8x example with a result beyond 4K (5120x3680).
+The example image is taken from [here](https://github.com/Mikubill/sd-webui-controlnet/blob/main/tests/images/ski.jpg).
+
+[<img src="assets/imgsli_11.jpg" width="800px"/>](https://imgsli.com/MTc4NDk2)
+
+- We further test StableSR directly on AIGC images and compare it with several diffusion-based upscalers, following the suggestions.
+A 4K demo is [here](https://imgsli.com/MTc4MDg3): a 4x SR on an image from [here](https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111).
+More comparisons can be found [here](https://github.com/IceClear/StableSR/issues/2).
+
+### Dependencies and Installation
+- PyTorch == 1.12.1
+- CUDA == 11.7
+- pytorch-lightning == 1.4.2
+- xformers == 0.0.16 (optional)
+- Other required packages in `environment.yaml`
+```
+# git clone this repository
+git clone https://github.com/IceClear/StableSR.git
+cd StableSR
+
+# Create a conda environment and activate it
+conda env create --file environment.yaml
+conda activate stablesr
+
+# Install xformers
+conda install xformers -c xformers/label/dev
+
+# Install taming & clip
+pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
+pip install -e git+https://github.com/openai/CLIP.git@main#egg=clip
+pip install -e .
+```
+
+### Running Examples
+
+#### Train
+Download the pretrained Stable Diffusion models from [HuggingFace](https://huggingface.co/stabilityai/stable-diffusion-2-1-base).
+
+- Train the time-aware encoder with SFT: set `ckpt_path` in the config file ([Line 22](https://github.com/IceClear/StableSR/blob/main/configs/stableSRNew/v2-finetune_text_T_512.yaml#L22) and [Line 55](https://github.com/IceClear/StableSR/blob/main/configs/stableSRNew/v2-finetune_text_T_512.yaml#L55)).
+```
+python main.py --train --base configs/stableSRNew/v2-finetune_text_T_512.yaml --gpus GPU_ID, --name NAME --scale_lr False
+```
+
+- Train CFW: set `ckpt_path` in the config file ([Line 6](https://github.com/IceClear/StableSR/blob/main/configs/autoencoder/autoencoder_kl_64x64x4_resi.yaml#L6)).
+
+You first need to generate training data using the diffusion model finetuned in the first stage. The data folder should look like this:
+```
+CFW_trainingdata/
+└── inputs
+    └── 00000001.png # LQ images, (512, 512, 3) (resize to 512x512)
+    └── ...
+└── gts
+    └── 00000001.png # GT images, (512, 512, 3) (512x512)
+    └── ...
+└── latents
+    └── 00000001.npy # Latent codes (N, 4, 64, 64) of HR images generated by the diffusion U-net, saved in .npy format.
+    └── ...
+└── samples
+    └── 00000001.png # The HR images generated from latent codes, just to make sure the generated latents are correct.
+    └── ...
+```
+
+Then you can train CFW:
+```
+python main.py --train --base configs/autoencoder/autoencoder_kl_64x64x4_resi.yaml --gpus GPU_ID, --name NAME --scale_lr False
+```
+
+#### Resume
+
+```
+python main.py --train --base configs/stableSRNew/v2-finetune_text_T_512.yaml --gpus GPU_ID, --resume RESUME_PATH --scale_lr False
+```
+
+#### Test directly
+
+Download the diffusion and autoencoder pretrained models from [[HuggingFace](https://huggingface.co/Iceclear/StableSR/blob/main/README.md) | [Google Drive](https://drive.google.com/drive/folders/1FBkW9FtTBssM_42kOycMPE0o9U5biYCl?usp=sharing) | [OneDrive](https://entuedu-my.sharepoint.com/:f:/g/personal/jianyi001_e_ntu_edu_sg/Et5HPkgRyyxNk269f5xYCacBpZq-bggFRCDbL9imSQ5QDQ)].
+We use the same color correction scheme introduced in the paper by default.
+You may switch to ```--colorfix_type wavelet``` for better color correction.
+You may also disable color correction with ```--colorfix_type nofix```.
+
+- Test on 128 --> 512: you need at least 10G GPU memory to run this script (batch size 2 by default).
+```
+python scripts/sr_val_ddpm_text_T_vqganfin_old.py --config configs/stableSRNew/v2-finetune_text_T_512.yaml --ckpt CKPT_PATH --vqgan_ckpt VQGANCKPT_PATH --init-img INPUT_PATH --outdir OUT_DIR --ddpm_steps 200 --dec_w 0.5 --colorfix_type adain
+```
+- Test on arbitrary size w/o chopping for the autoencoder (for results beyond 512): the memory cost depends on your image size, but is usually above 10G.
+```
+python scripts/sr_val_ddpm_text_T_vqganfin_oldcanvas.py --config configs/stableSRNew/v2-finetune_text_T_512.yaml --ckpt CKPT_PATH --vqgan_ckpt VQGANCKPT_PATH --init-img INPUT_PATH --outdir OUT_DIR --ddpm_steps 200 --dec_w 0.5 --colorfix_type adain
+```
+
+- Test on arbitrary size w/ chopping for the autoencoder: the current default setting needs at least 18G to run; you may reduce the autoencoder tile size by setting ```--vqgantile_size``` and ```--vqgantile_stride```.
+Note that the minimum tile size is 512 and the stride should be smaller than the tile size. A smaller size may introduce more border artifacts.
+```
+python scripts/sr_val_ddpm_text_T_vqganfin_oldcanvas_tile.py --config configs/stableSRNew/v2-finetune_text_T_512.yaml --ckpt CKPT_PATH --vqgan_ckpt VQGANCKPT_PATH --init-img INPUT_PATH --outdir OUT_DIR --ddpm_steps 200 --dec_w 0.5 --colorfix_type adain
+```
+
+- To test the 768 model, set ```--config configs/stableSRNew/v2-finetune_text_T_768v.yaml```, ```--input_size 768``` and ```--ckpt``` accordingly. You can also adjust ```--tile_overlap```, ```--vqgantile_size``` and ```--vqgantile_stride```. We did not finetune CFW.
+
+#### Test using Replicate API
+```
+import replicate
+model = replicate.models.get(<model_name>)
+model.predict(input_image=...)
+```
+See [here](https://replicate.com/cjwbw/stablesr/api) for more information.
+
+### Citation
+If our work is useful for your research, please consider citing:
+
+    @inproceedings{wang2023exploiting,
+        author = {Wang, Jianyi and Yue, Zongsheng and Zhou, Shangchen and Chan, Kelvin CK and Loy, Chen Change},
+        title = {Exploiting Diffusion Prior for Real-World Image Super-Resolution},
+        booktitle = {arXiv preprint arXiv:2305.07015},
+        year = {2023}
+    }
+
+### License
+
+This project is licensed under <a rel="license" href="https://github.com/IceClear/StableSR/blob/main/LICENSE.txt">NTU S-Lab License 1.0</a>. Redistribution and use should follow this license.
+
+### Acknowledgement
+
+This project is based on [stablediffusion](https://github.com/Stability-AI/stablediffusion), [latent-diffusion](https://github.com/CompVis/latent-diffusion), [SPADE](https://github.com/NVlabs/SPADE), [mixture-of-diffusers](https://github.com/albarji/mixture-of-diffusers) and [BasicSR](https://github.com/XPixelGroup/BasicSR). Thanks for their awesome work.
+
+### Contact
+If you have any questions, please feel free to reach out to me at `iceclearwjy@gmail.com`.
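The `--colorfix_type adain` and `--colorfix_type wavelet` options referenced in the test commands above correct global color drift between the diffusion output and the input. As a rough, hypothetical sketch (not the repository's actual implementation) of what an AdaIN-style color fix does, one can re-normalize the SR result so its per-channel statistics match the resized low-quality input:

```python
# Illustrative sketch only: AdaIN-style color correction matches the
# per-channel mean/std of the SR output to those of the LQ input,
# suppressing global color shifts. Function name and details are assumptions.
import torch
import torch.nn.functional as F

def adain_color_fix_sketch(sr: torch.Tensor, lq: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """sr, lq: (N, 3, H, W) tensors; lq is resized to the SR resolution first."""
    lq = F.interpolate(lq, size=sr.shape[-2:], mode='bicubic', align_corners=False)
    sr_mean, sr_std = sr.mean(dim=(2, 3), keepdim=True), sr.std(dim=(2, 3), keepdim=True)
    lq_mean, lq_std = lq.mean(dim=(2, 3), keepdim=True), lq.std(dim=(2, 3), keepdim=True)
    # normalize SR to zero mean / unit std, then re-color with LQ statistics
    return (sr - sr_mean) / (sr_std + eps) * lq_std + lq_mean
```

The wavelet variant follows the same idea but transfers only low-frequency (color) bands, keeping the generated high-frequency detail untouched.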
StableSR/basicsr/__init__.py
ADDED
@@ -0,0 +1,12 @@
+# https://github.com/xinntao/BasicSR
+# flake8: noqa
+from .archs import *
+from .data import *
+from .losses import *
+from .metrics import *
+from .models import *
+from .ops import *
+from .test import *
+from .train import *
+from .utils import *
+# from .version import __gitsha__, __version__
StableSR/basicsr/archs/__init__.py
ADDED
@@ -0,0 +1,24 @@
+import importlib
+from copy import deepcopy
+from os import path as osp
+
+from basicsr.utils import get_root_logger, scandir
+from basicsr.utils.registry import ARCH_REGISTRY
+
+__all__ = ['build_network']
+
+# automatically scan and import arch modules for registry
+# scan all the files under the 'archs' folder and collect files ending with '_arch.py'
+arch_folder = osp.dirname(osp.abspath(__file__))
+arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
+# import all the arch modules
+_arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames]
+
+
+def build_network(opt):
+    opt = deepcopy(opt)
+    network_type = opt.pop('type')
+    net = ARCH_REGISTRY.get(network_type)(**opt)
+    logger = get_root_logger()
+    logger.info(f'Network [{net.__class__.__name__}] is created.')
+    return net
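`build_network` is a registry-driven factory: importing every `*_arch.py` module registers its classes in `ARCH_REGISTRY`, so a network can be built purely from a config dict. A minimal usage sketch, assuming the vendored `basicsr` package imports cleanly (the `MSRResNet` name and its keyword arguments follow BasicSR conventions):

```python
# Sketch: config-driven network construction via the arch registry.
# The option dict mirrors a BasicSR YAML "network_g" section.
from basicsr.archs import build_network

opt = {
    'type': 'MSRResNet',  # looked up in ARCH_REGISTRY
    'num_in_ch': 3,
    'num_out_ch': 3,
    'num_feat': 64,
    'upscale': 4,
}
net = build_network(opt)  # pops 'type', passes the rest as kwargs
print(net.__class__.__name__)  # MSRResNet
```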
StableSR/basicsr/archs/arch_util.py
ADDED
@@ -0,0 +1,352 @@
+import collections.abc
+import math
+import torch
+import torchvision
+import warnings
+from distutils.version import LooseVersion
+from itertools import repeat
+from torch import nn as nn
+from torch.nn import functional as F
+from torch.nn import init as init
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv
+from basicsr.utils import get_root_logger
+
+
+@torch.no_grad()
+def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
+    """Initialize network weights.
+
+    Args:
+        module_list (list[nn.Module] | nn.Module): Modules to be initialized.
+        scale (float): Scale initialized weights, especially for residual
+            blocks. Default: 1.
+        bias_fill (float): The value to fill bias. Default: 0
+        kwargs (dict): Other arguments for initialization function.
+    """
+    if not isinstance(module_list, list):
+        module_list = [module_list]
+    for module in module_list:
+        for m in module.modules():
+            if isinstance(m, nn.Conv2d):
+                init.kaiming_normal_(m.weight, **kwargs)
+                m.weight.data *= scale
+                if m.bias is not None:
+                    m.bias.data.fill_(bias_fill)
+            elif isinstance(m, nn.Linear):
+                init.kaiming_normal_(m.weight, **kwargs)
+                m.weight.data *= scale
+                if m.bias is not None:
+                    m.bias.data.fill_(bias_fill)
+            elif isinstance(m, _BatchNorm):
+                init.constant_(m.weight, 1)
+                if m.bias is not None:
+                    m.bias.data.fill_(bias_fill)
+
+
+def make_layer(basic_block, num_basic_block, **kwarg):
+    """Make layers by stacking the same blocks.
+
+    Args:
+        basic_block (nn.module): nn.module class for basic block.
+        num_basic_block (int): number of blocks.
+
+    Returns:
+        nn.Sequential: Stacked blocks in nn.Sequential.
+    """
+    layers = []
+    for _ in range(num_basic_block):
+        layers.append(basic_block(**kwarg))
+    return nn.Sequential(*layers)
+
+
+class PixelShufflePack(nn.Module):
+    """Pixel Shuffle upsample layer.
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        scale_factor (int): Upsample ratio.
+        upsample_kernel (int): Kernel size of Conv layer to expand channels.
+    Returns:
+        Upsampled feature map.
+    """
+
+    def __init__(self, in_channels, out_channels, scale_factor, upsample_kernel):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.scale_factor = scale_factor
+        self.upsample_kernel = upsample_kernel
+        self.upsample_conv = nn.Conv2d(
+            self.in_channels,
+            self.out_channels * scale_factor * scale_factor,
+            self.upsample_kernel,
+            padding=(self.upsample_kernel - 1) // 2)
+        self.init_weights()
+
+    def init_weights(self):
+        """Initialize weights for PixelShufflePack."""
+        default_init_weights(self, 1)
+
+    def forward(self, x):
+        """Forward function for PixelShufflePack.
+        Args:
+            x (Tensor): Input tensor with shape (n, c, h, w).
+        Returns:
+            Tensor: Forward results.
+        """
+        x = self.upsample_conv(x)
+        x = F.pixel_shuffle(x, self.scale_factor)
+        return x
+
+
+class ResidualBlockNoBN(nn.Module):
+    """Residual block without BN.
+
+    Args:
+        num_feat (int): Channel number of intermediate features.
+            Default: 64.
+        res_scale (float): Residual scale. Default: 1.
+        pytorch_init (bool): If set to True, use pytorch default init,
+            otherwise, use default_init_weights. Default: False.
+    """
+
+    def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
+        super(ResidualBlockNoBN, self).__init__()
+        self.res_scale = res_scale
+        self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+        self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+        self.relu = nn.ReLU(inplace=True)
+
+        if not pytorch_init:
+            default_init_weights([self.conv1, self.conv2], 0.1)
+
+    def forward(self, x):
+        identity = x
+        out = self.conv2(self.relu(self.conv1(x)))
+        return identity + out * self.res_scale
+
+
+class Upsample(nn.Sequential):
+    """Upsample module.
+
+    Args:
+        scale (int): Scale factor. Supported scales: 2^n and 3.
+        num_feat (int): Channel number of intermediate features.
+    """
+
+    def __init__(self, scale, num_feat):
+        m = []
+        if (scale & (scale - 1)) == 0:  # scale = 2^n
+            for _ in range(int(math.log(scale, 2))):
+                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
+                m.append(nn.PixelShuffle(2))
+        elif scale == 3:
+            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
+            m.append(nn.PixelShuffle(3))
+        else:
+            raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
+        super(Upsample, self).__init__(*m)
+
+
+def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
+    """Warp an image or feature map with optical flow.
+
+    Args:
+        x (Tensor): Tensor with size (n, c, h, w).
+        flow (Tensor): Tensor with size (n, h, w, 2), normal value.
+        interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
+        padding_mode (str): 'zeros' or 'border' or 'reflection'.
+            Default: 'zeros'.
+        align_corners (bool): Before pytorch 1.3, the default value is
+            align_corners=True. After pytorch 1.3, the default value is
+            align_corners=False. Here, we use the True as default.
+
+    Returns:
+        Tensor: Warped image or feature map.
+    """
+    assert x.size()[-2:] == flow.size()[1:3]
+    _, _, h, w = x.size()
+    # create mesh grid
+    grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
+    grid = torch.stack((grid_x, grid_y), 2).float()  # W(x), H(y), 2
+    grid.requires_grad = False
+
+    vgrid = grid + flow
+    # scale grid to [-1,1]
+    vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
+    vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
+    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
+    output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
+
+    # TODO, what if align_corners=False
+    return output
+
+
+def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
+    """Resize a flow according to ratio or shape.
+
+    Args:
+        flow (Tensor): Precomputed flow. shape [N, 2, H, W].
+        size_type (str): 'ratio' or 'shape'.
+        sizes (list[int | float]): the ratio for resizing or the final output
+            shape.
+            1) The order of ratio should be [ratio_h, ratio_w]. For
+            downsampling, the ratio should be smaller than 1.0 (i.e., ratio
+            < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
+            ratio > 1.0).
+            2) The order of output_size should be [out_h, out_w].
+        interp_mode (str): The mode of interpolation for resizing.
+            Default: 'bilinear'.
+        align_corners (bool): Whether align corners. Default: False.
+
+    Returns:
+        Tensor: Resized flow.
+    """
+    _, _, flow_h, flow_w = flow.size()
+    if size_type == 'ratio':
+        output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
+    elif size_type == 'shape':
+        output_h, output_w = sizes[0], sizes[1]
+    else:
+        raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')
+
+    input_flow = flow.clone()
+    ratio_h = output_h / flow_h
+    ratio_w = output_w / flow_w
+    input_flow[:, 0, :, :] *= ratio_w
+    input_flow[:, 1, :, :] *= ratio_h
+    resized_flow = F.interpolate(
+        input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
+    return resized_flow
+
+
+# TODO: may write a cpp file
+def pixel_unshuffle(x, scale):
+    """ Pixel unshuffle.
+
+    Args:
+        x (Tensor): Input feature with shape (b, c, hh, hw).
+        scale (int): Downsample ratio.
+
+    Returns:
+        Tensor: the pixel unshuffled feature.
+    """
+    b, c, hh, hw = x.size()
+    out_channel = c * (scale**2)
+    assert hh % scale == 0 and hw % scale == 0
+    h = hh // scale
+    w = hw // scale
+    x_view = x.view(b, c, h, scale, w, scale)
+    return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
+
+
+class DCNv2Pack(ModulatedDeformConvPack):
+    """Modulated deformable conv for deformable alignment.
+
+    Different from the official DCNv2Pack, which generates offsets and masks
+    from the preceding features, this DCNv2Pack takes another different
+    features to generate offsets and masks.
+
+    ``Paper: Delving Deep into Deformable Alignment in Video Super-Resolution``
+    """
+
+    def forward(self, x, feat):
+        out = self.conv_offset(feat)
+        o1, o2, mask = torch.chunk(out, 3, dim=1)
+        offset = torch.cat((o1, o2), dim=1)
+        mask = torch.sigmoid(mask)
+
+        offset_absmean = torch.mean(torch.abs(offset))
+        if offset_absmean > 50:
+            logger = get_root_logger()
+            logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.')
+
+        if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'):
+            return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
+                                                 self.dilation, mask)
+        else:
+            return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
+                                         self.dilation, self.groups, self.deformable_groups)
+
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+    # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
+    # Cut & paste from PyTorch official master until it's in a few official releases - RW
+    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+    def norm_cdf(x):
+        # Computes standard normal cumulative distribution function
+        return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+    if (mean < a - 2 * std) or (mean > b + 2 * std):
+        warnings.warn(
+            'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
+            'The distribution of values may be incorrect.',
+            stacklevel=2)
+
+    with torch.no_grad():
+        # Values are generated by using a truncated uniform distribution and
+        # then using the inverse CDF for the normal distribution.
+        # Get upper and lower cdf values
+        low = norm_cdf((a - mean) / std)
+        up = norm_cdf((b - mean) / std)
+
+        # Uniformly fill tensor with values from [low, up], then translate to
+        # [2l-1, 2u-1].
+        tensor.uniform_(2 * low - 1, 2 * up - 1)
+
+        # Use inverse cdf transform for normal distribution to get truncated
+        # standard normal
+        tensor.erfinv_()
+
+        # Transform to proper mean, std
+        tensor.mul_(std * math.sqrt(2.))
+        tensor.add_(mean)
+
+        # Clamp to ensure it's in the proper range
+        tensor.clamp_(min=a, max=b)
+        return tensor
+
+
+def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+    r"""Fills the input Tensor with values drawn from a truncated
+    normal distribution.
+
+    From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
+
+    The values are effectively drawn from the
+    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+    with values outside :math:`[a, b]` redrawn until they are within
+    the bounds. The method used for generating the random values works
+    best when :math:`a \leq \text{mean} \leq b`.
+
+    Args:
+        tensor: an n-dimensional `torch.Tensor`
+        mean: the mean of the normal distribution
+        std: the standard deviation of the normal distribution
+        a: the minimum cutoff value
+        b: the maximum cutoff value
+
+    Examples:
+        >>> w = torch.empty(3, 5)
+        >>> nn.init.trunc_normal_(w)
+    """
+    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+# From PyTorch
+def _ntuple(n):
+
+    def parse(x):
+        if isinstance(x, collections.abc.Iterable):
+            return x
+        return tuple(repeat(x, n))
+
+    return parse
+
+
+to_1tuple = _ntuple(1)
+to_2tuple = _ntuple(2)
+to_3tuple = _ntuple(3)
+to_4tuple = _ntuple(4)
+to_ntuple = _ntuple
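Two of the helpers above have simple invariants worth knowing: `pixel_unshuffle` is the exact inverse of `F.pixel_shuffle` (its channel ordering matches PyTorch's), and `flow_warp` with an all-zero flow returns its input unchanged up to interpolation error. A quick self-check sketch, assuming the vendored `basicsr` package imports cleanly:

```python
# Sanity-check sketch for pixel_unshuffle and flow_warp.
import torch
import torch.nn.functional as F
from basicsr.archs.arch_util import flow_warp, pixel_unshuffle

x = torch.randn(1, 4, 8, 8)
down = pixel_unshuffle(x, scale=2)            # (1, 16, 4, 4)
assert torch.allclose(F.pixel_shuffle(down, 2), x)  # exact round trip

feat = torch.randn(1, 4, 8, 8)
zero_flow = torch.zeros(1, 8, 8, 2)           # flow uses (n, h, w, 2) layout
warped = flow_warp(feat, zero_flow)           # identity warp
assert torch.allclose(warped, feat, atol=1e-5)
```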
StableSR/basicsr/archs/basicvsr_arch.py
ADDED
@@ -0,0 +1,336 @@
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from basicsr.utils.registry import ARCH_REGISTRY
+from .arch_util import ResidualBlockNoBN, flow_warp, make_layer
+from .edvr_arch import PCDAlignment, TSAFusion
+from .spynet_arch import SpyNet
+
+
+@ARCH_REGISTRY.register()
+class BasicVSR(nn.Module):
+    """A recurrent network for video SR. Now only x4 is supported.
+
+    Args:
+        num_feat (int): Number of channels. Default: 64.
+        num_block (int): Number of residual blocks for each branch. Default: 15
+        spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
+    """
+
+    def __init__(self, num_feat=64, num_block=15, spynet_path=None):
+        super().__init__()
+        self.num_feat = num_feat
+
+        # alignment
+        self.spynet = SpyNet(spynet_path)
+
+        # propagation
+        self.backward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block)
+        self.forward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block)
+
+        # reconstruction
+        self.fusion = nn.Conv2d(num_feat * 2, num_feat, 1, 1, 0, bias=True)
+        self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1, bias=True)
+        self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1, bias=True)
+        self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
+        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
+
+        self.pixel_shuffle = nn.PixelShuffle(2)
+
+        # activation functions
+        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+    def get_flow(self, x):
+        b, n, c, h, w = x.size()
+
+        x_1 = x[:, :-1, :, :, :].reshape(-1, c, h, w)
+        x_2 = x[:, 1:, :, :, :].reshape(-1, c, h, w)
+
+        flows_backward = self.spynet(x_1, x_2).view(b, n - 1, 2, h, w)
+        flows_forward = self.spynet(x_2, x_1).view(b, n - 1, 2, h, w)
+
+        return flows_forward, flows_backward
+
+    def forward(self, x):
+        """Forward function of BasicVSR.
+
+        Args:
+            x: Input frames with shape (b, n, c, h, w). n is the temporal dimension / number of frames.
+        """
+        flows_forward, flows_backward = self.get_flow(x)
+        b, n, _, h, w = x.size()
+
+        # backward branch
+        out_l = []
+        feat_prop = x.new_zeros(b, self.num_feat, h, w)
+        for i in range(n - 1, -1, -1):
+            x_i = x[:, i, :, :, :]
+            if i < n - 1:
+                flow = flows_backward[:, i, :, :, :]
+                feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
+            feat_prop = torch.cat([x_i, feat_prop], dim=1)
+            feat_prop = self.backward_trunk(feat_prop)
+            out_l.insert(0, feat_prop)
+
+        # forward branch
+        feat_prop = torch.zeros_like(feat_prop)
+        for i in range(0, n):
+            x_i = x[:, i, :, :, :]
+            if i > 0:
+                flow = flows_forward[:, i - 1, :, :, :]
+                feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
+
+            feat_prop = torch.cat([x_i, feat_prop], dim=1)
+            feat_prop = self.forward_trunk(feat_prop)
+
+            # upsample
+            out = torch.cat([out_l[i], feat_prop], dim=1)
+            out = self.lrelu(self.fusion(out))
+            out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
+            out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
+            out = self.lrelu(self.conv_hr(out))
+            out = self.conv_last(out)
+            base = F.interpolate(x_i, scale_factor=4, mode='bilinear', align_corners=False)
+            out += base
+            out_l[i] = out
+
+        return torch.stack(out_l, dim=1)
+
+
+class ConvResidualBlocks(nn.Module):
+    """Conv and residual block used in BasicVSR.
+
+    Args:
+        num_in_ch (int): Number of input channels. Default: 3.
+        num_out_ch (int): Number of output channels. Default: 64.
+        num_block (int): Number of residual blocks. Default: 15.
+    """
+
+    def __init__(self, num_in_ch=3, num_out_ch=64, num_block=15):
+        super().__init__()
+        self.main = nn.Sequential(
+            nn.Conv2d(num_in_ch, num_out_ch, 3, 1, 1, bias=True), nn.LeakyReLU(negative_slope=0.1, inplace=True),
+            make_layer(ResidualBlockNoBN, num_block, num_feat=num_out_ch))
+
+    def forward(self, fea):
+        return self.main(fea)
+
+
+@ARCH_REGISTRY.register()
+class IconVSR(nn.Module):
+    """IconVSR, proposed also in the BasicVSR paper.
+
+    Args:
+        num_feat (int): Number of channels. Default: 64.
+        num_block (int): Number of residual blocks for each branch. Default: 15.
+        keyframe_stride (int): Keyframe stride. Default: 5.
+        temporal_padding (int): Temporal padding. Default: 2.
+        spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
+        edvr_path (str): Path to the pretrained EDVR model. Default: None.
+    """
+
+    def __init__(self,
+                 num_feat=64,
+                 num_block=15,
+                 keyframe_stride=5,
+                 temporal_padding=2,
+                 spynet_path=None,
+                 edvr_path=None):
+        super().__init__()
+
+        self.num_feat = num_feat
+        self.temporal_padding = temporal_padding
+        self.keyframe_stride = keyframe_stride
+
+        # keyframe_branch
+        self.edvr = EDVRFeatureExtractor(temporal_padding * 2 + 1, num_feat, edvr_path)
+        # alignment
+        self.spynet = SpyNet(spynet_path)
+
+        # propagation
+        self.backward_fusion = nn.Conv2d(2 * num_feat, num_feat, 3, 1, 1, bias=True)
+        self.backward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block)
+
+        self.forward_fusion = nn.Conv2d(2 * num_feat, num_feat, 3, 1, 1, bias=True)
+        self.forward_trunk = ConvResidualBlocks(2 * num_feat + 3, num_feat, num_block)
+
+        # reconstruction
+        self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1, bias=True)
+        self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1, bias=True)
+        self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
+        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
+
+        self.pixel_shuffle = nn.PixelShuffle(2)
+
+        # activation functions
+        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+    def pad_spatial(self, x):
+        """Apply padding spatially.
+
+        Since the PCD module in EDVR requires that the resolution is a multiple
+        of 4, we apply padding to the input LR images if their resolution is
+        not divisible by 4.
+
+        Args:
+            x (Tensor): Input LR sequence with shape (n, t, c, h, w).
+        Returns:
+            Tensor: Padded LR sequence with shape (n, t, c, h_pad, w_pad).
+        """
+        n, t, c, h, w = x.size()
+
+        pad_h = (4 - h % 4) % 4
+        pad_w = (4 - w % 4) % 4
+
+        # padding
+        x = x.view(-1, c, h, w)
+        x = F.pad(x, [0, pad_w, 0, pad_h], mode='reflect')
+
+        return x.view(n, t, c, h + pad_h, w + pad_w)
+
+    def get_flow(self, x):
+        b, n, c, h, w = x.size()
+
+        x_1 = x[:, :-1, :, :, :].reshape(-1, c, h, w)
+        x_2 = x[:, 1:, :, :, :].reshape(-1, c, h, w)
+
+        flows_backward = self.spynet(x_1, x_2).view(b, n - 1, 2, h, w)
+        flows_forward = self.spynet(x_2, x_1).view(b, n - 1, 2, h, w)
+
+        return flows_forward, flows_backward
+
+    def get_keyframe_feature(self, x, keyframe_idx):
+        if self.temporal_padding == 2:
+            x = [x[:, [4, 3]], x, x[:, [-4, -5]]]
+        elif self.temporal_padding == 3:
+            x = [x[:, [6, 5, 4]], x, x[:, [-5, -6, -7]]]
+        x = torch.cat(x, dim=1)
+
+        num_frames = 2 * self.temporal_padding + 1
+        feats_keyframe = {}
+        for i in keyframe_idx:
+            feats_keyframe[i] = self.edvr(x[:, i:i + num_frames].contiguous())
+        return feats_keyframe
+
+    def forward(self, x):
+        b, n, _, h_input, w_input = x.size()
+
+        x = self.pad_spatial(x)
+        h, w = x.shape[3:]
+
+        keyframe_idx = list(range(0, n, self.keyframe_stride))
+        if keyframe_idx[-1] != n - 1:
+            keyframe_idx.append(n - 1)  # last frame is a keyframe
+
+        # compute flow and keyframe features
+        flows_forward, flows_backward = self.get_flow(x)
+        feats_keyframe = self.get_keyframe_feature(x, keyframe_idx)
+
+        # backward branch
+        out_l = []
+        feat_prop = x.new_zeros(b, self.num_feat, h, w)
+        for i in range(n - 1, -1, -1):
+            x_i = x[:, i, :, :, :]
+            if i < n - 1:
+                flow = flows_backward[:, i, :, :, :]
+                feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
+            if i in keyframe_idx:
+                feat_prop = torch.cat([feat_prop, feats_keyframe[i]], dim=1)
+                feat_prop = self.backward_fusion(feat_prop)
+            feat_prop = torch.cat([x_i, feat_prop], dim=1)
+            feat_prop = self.backward_trunk(feat_prop)
+            out_l.insert(0, feat_prop)
+
+        # forward branch
+        feat_prop = torch.zeros_like(feat_prop)
+        for i in range(0, n):
+            x_i = x[:, i, :, :, :]
+            if i > 0:
+                flow = flows_forward[:, i - 1, :, :, :]
+                feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
+            if i in keyframe_idx:
+                feat_prop = torch.cat([feat_prop, feats_keyframe[i]], dim=1)
+                feat_prop = self.forward_fusion(feat_prop)
+
+            feat_prop = torch.cat([x_i, out_l[i], feat_prop], dim=1)
+            feat_prop = self.forward_trunk(feat_prop)
+
+            # upsample
+            out = self.lrelu(self.pixel_shuffle(self.upconv1(feat_prop)))
+            out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
+            out = self.lrelu(self.conv_hr(out))
+            out = self.conv_last(out)
+            base = F.interpolate(x_i, scale_factor=4, mode='bilinear', align_corners=False)
+            out += base
+            out_l[i] = out
+
+        return torch.stack(out_l, dim=1)[..., :4 * h_input, :4 * w_input]
+
+
+class EDVRFeatureExtractor(nn.Module):
+    """EDVR feature extractor used in IconVSR.
+
+    Args:
+        num_input_frame (int): Number of input frames.
+        num_feat (int): Number of feature channels
+        load_path (str): Path to the pretrained weights of EDVR. Default: None.
+    """
+
+    def __init__(self, num_input_frame, num_feat, load_path):
+
+        super(EDVRFeatureExtractor, self).__init__()
+
+        self.center_frame_idx = num_input_frame // 2
+
+        # extract pyramid features
+        self.conv_first = nn.Conv2d(3, num_feat, 3, 1, 1)
+        self.feature_extraction = make_layer(ResidualBlockNoBN, 5, num_feat=num_feat)
+        self.conv_l2_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
+        self.conv_l2_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+        self.conv_l3_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
+        self.conv_l3_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+
+        # pcd and tsa module
+        self.pcd_align = PCDAlignment(num_feat=num_feat, deformable_groups=8)
+        self.fusion = TSAFusion(num_feat=num_feat, num_frame=num_input_frame, center_frame_idx=self.center_frame_idx)
+
+        # activation function
+        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+        if load_path:
+            self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params'])
+
+    def forward(self, x):
+        b, n, c, h, w = x.size()
+
+        # extract features for each frame
+        # L1
+        feat_l1 = self.lrelu(self.conv_first(x.view(-1, c, h, w)))
+        feat_l1 = self.feature_extraction(feat_l1)
+        # L2
+        feat_l2 = self.lrelu(self.conv_l2_1(feat_l1))
+        feat_l2 = self.lrelu(self.conv_l2_2(feat_l2))
+        # L3
+        feat_l3 = self.lrelu(self.conv_l3_1(feat_l2))
+        feat_l3 = self.lrelu(self.conv_l3_2(feat_l3))
+
+        feat_l1 = feat_l1.view(b, n, -1, h, w)
+        feat_l2 = feat_l2.view(b, n, -1, h // 2, w // 2)
+        feat_l3 = feat_l3.view(b, n, -1, h // 4, w // 4)
+
+        # PCD alignment
+        ref_feat_l = [  # reference feature list
+            feat_l1[:, self.center_frame_idx, :, :, :].clone(), feat_l2[:, self.center_frame_idx, :, :, :].clone(),
+            feat_l3[:, self.center_frame_idx, :, :, :].clone()
+        ]
+        aligned_feat = []
+        for i in range(n):
+            nbr_feat_l = [  # neighboring feature list
+                feat_l1[:, i, :, :, :].clone(), feat_l2[:, i, :, :, :].clone(), feat_l3[:, i, :, :, :].clone()
+            ]
+            aligned_feat.append(self.pcd_align(nbr_feat_l, ref_feat_l))
+        aligned_feat = torch.stack(aligned_feat, dim=1)  # (b, t, c, h, w)
+
+        # TSA fusion
+        return self.fusion(aligned_feat)
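The heart of BasicVSR's forward pass is the pair of recurrent loops above: features are warped along precomputed flows and pushed through a residual trunk, once backward and once forward in time, then fused per frame. A stripped-down sketch of that propagation pattern (a single conv stands in for `ConvResidualBlocks`, and flow warping is omitted):

```python
# Stripped-down sketch of BasicVSR's bidirectional propagation; the real
# model warps `feat` with SPyNet flows via flow_warp before each trunk call
# and fuses the two directions with a 1x1 conv before upsampling.
import torch

def propagate_sketch(frames: torch.Tensor, num_feat: int = 8) -> list:
    """frames: (b, n, c, h, w). Returns per-frame propagated features."""
    b, n, c, h, w = frames.shape
    trunk = torch.nn.Conv2d(c + num_feat, num_feat, 3, 1, 1)  # stand-in trunk

    # backward pass: n-1 -> 0, hidden state carries future context
    out = [None] * n
    feat = frames.new_zeros(b, num_feat, h, w)
    for i in range(n - 1, -1, -1):
        feat = trunk(torch.cat([frames[:, i], feat], dim=1))
        out[i] = feat

    # forward pass: 0 -> n-1, combining with the backward features
    feat = torch.zeros_like(feat)
    for i in range(n):
        feat = trunk(torch.cat([frames[:, i], feat], dim=1))
        out[i] = out[i] + feat
    return out

features = propagate_sketch(torch.randn(2, 5, 3, 16, 16))
print(len(features), features[0].shape)  # 5 torch.Size([2, 8, 16, 16])
```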
StableSR/basicsr/archs/basicvsrpp_arch.py
ADDED
@@ -0,0 +1,417 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision
+import warnings
+
+from basicsr.archs.arch_util import flow_warp
+from basicsr.archs.basicvsr_arch import ConvResidualBlocks
+from basicsr.archs.spynet_arch import SpyNet
+from basicsr.ops.dcn import ModulatedDeformConvPack
+from basicsr.utils.registry import ARCH_REGISTRY
+
+
+@ARCH_REGISTRY.register()
+class BasicVSRPlusPlus(nn.Module):
+    """BasicVSR++ network structure.
+
+    Support either x4 upsampling or same size output. Since DCN is used in this
+    model, it can only be used with CUDA enabled. If CUDA is not enabled,
+    feature alignment will be skipped. Besides, we adopt the official DCN
+    implementation and the version of torch need to be higher than 1.9.
+
+    ``Paper: BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment``
+
+    Args:
+        mid_channels (int, optional): Channel number of the intermediate
+            features. Default: 64.
+        num_blocks (int, optional): The number of residual blocks in each
+            propagation branch. Default: 7.
+        max_residue_magnitude (int): The maximum magnitude of the offset
+            residue (Eq. 6 in paper). Default: 10.
+        is_low_res_input (bool, optional): Whether the input is low-resolution
+            or not. If False, the output resolution is equal to the input
+            resolution. Default: True.
+        spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
+        cpu_cache_length (int, optional): When the length of sequence is larger
+            than this value, the intermediate features are sent to CPU. This
+            saves GPU memory, but slows down the inference speed. You can
+            increase this number if you have a GPU with large memory.
+            Default: 100.
+    """
+
+    def __init__(self,
+                 mid_channels=64,
+                 num_blocks=7,
+                 max_residue_magnitude=10,
+                 is_low_res_input=True,
+                 spynet_path=None,
+                 cpu_cache_length=100):
+
+        super().__init__()
+        self.mid_channels = mid_channels
+        self.is_low_res_input = is_low_res_input
+        self.cpu_cache_length = cpu_cache_length
+
+        # optical flow
+        self.spynet = SpyNet(spynet_path)
+
+        # feature extraction module
+        if is_low_res_input:
+            self.feat_extract = ConvResidualBlocks(3, mid_channels, 5)
+        else:
+            self.feat_extract = nn.Sequential(
+                nn.Conv2d(3, mid_channels, 3, 2, 1), nn.LeakyReLU(negative_slope=0.1, inplace=True),
+                nn.Conv2d(mid_channels, mid_channels, 3, 2, 1), nn.LeakyReLU(negative_slope=0.1, inplace=True),
+                ConvResidualBlocks(mid_channels, mid_channels, 5))
+
+        # propagation branches
+        self.deform_align = nn.ModuleDict()
+        self.backbone = nn.ModuleDict()
+        modules = ['backward_1', 'forward_1', 'backward_2', 'forward_2']
+        for i, module in enumerate(modules):
+            if torch.cuda.is_available():
+                self.deform_align[module] = SecondOrderDeformableAlignment(
+                    2 * mid_channels,
+                    mid_channels,
+                    3,
+                    padding=1,
+                    deformable_groups=16,
+                    max_residue_magnitude=max_residue_magnitude)
+            self.backbone[module] = ConvResidualBlocks((2 + i) * mid_channels, mid_channels, num_blocks)
+
+        # upsampling module
+        self.reconstruction = ConvResidualBlocks(5 * mid_channels, mid_channels, 5)
+
+        self.upconv1 = nn.Conv2d(mid_channels, mid_channels * 4, 3, 1, 1, bias=True)
+        self.upconv2 = nn.Conv2d(mid_channels, 64 * 4, 3, 1, 1, bias=True)
+
+        self.pixel_shuffle = nn.PixelShuffle(2)
+
+        self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
+        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
+        self.img_upsample = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)
+
+        # activation function
+        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+        # check if the sequence is augmented by flipping
+        self.is_mirror_extended = False
+
+        if len(self.deform_align) > 0:
+            self.is_with_alignment = True
+        else:
+            self.is_with_alignment = False
+            warnings.warn('Deformable alignment module is not added. '
+                          'Probably your CUDA is not configured correctly. DCN can only '
+                          'be used with CUDA enabled. Alignment is skipped now.')
+
+    def check_if_mirror_extended(self, lqs):
+        """Check whether the input is a mirror-extended sequence.
+
+        If mirror-extended, the i-th (i=0, ..., t-1) frame is equal to the (t-1-i)-th frame.
+
+        Args:
+            lqs (tensor): Input low quality (LQ) sequence with shape (n, t, c, h, w).
+        """
+
+        if lqs.size(1) % 2 == 0:
+            lqs_1, lqs_2 = torch.chunk(lqs, 2, dim=1)
+            if torch.norm(lqs_1 - lqs_2.flip(1)) == 0:
+                self.is_mirror_extended = True
+
+    def compute_flow(self, lqs):
+        """Compute optical flow using SPyNet for feature alignment.
+
+        Note that if the input is an mirror-extended sequence, 'flows_forward'
|
127 |
+
is not needed, since it is equal to 'flows_backward.flip(1)'.
|
128 |
+
|
129 |
+
Args:
|
130 |
+
lqs (tensor): Input low quality (LQ) sequence with
|
131 |
+
shape (n, t, c, h, w).
|
132 |
+
|
133 |
+
Return:
|
134 |
+
tuple(Tensor): Optical flow. 'flows_forward' corresponds to the flows used for forward-time propagation \
|
135 |
+
(current to previous). 'flows_backward' corresponds to the flows used for backward-time \
|
136 |
+
propagation (current to next).
|
137 |
+
"""
|
138 |
+
|
139 |
+
n, t, c, h, w = lqs.size()
|
140 |
+
lqs_1 = lqs[:, :-1, :, :, :].reshape(-1, c, h, w)
|
141 |
+
lqs_2 = lqs[:, 1:, :, :, :].reshape(-1, c, h, w)
|
142 |
+
|
143 |
+
flows_backward = self.spynet(lqs_1, lqs_2).view(n, t - 1, 2, h, w)
|
144 |
+
|
145 |
+
if self.is_mirror_extended: # flows_forward = flows_backward.flip(1)
|
146 |
+
flows_forward = flows_backward.flip(1)
|
147 |
+
else:
|
148 |
+
flows_forward = self.spynet(lqs_2, lqs_1).view(n, t - 1, 2, h, w)
|
149 |
+
|
150 |
+
if self.cpu_cache:
|
151 |
+
flows_backward = flows_backward.cpu()
|
152 |
+
flows_forward = flows_forward.cpu()
|
153 |
+
|
154 |
+
return flows_forward, flows_backward
|
155 |
+
|
156 |
+
def propagate(self, feats, flows, module_name):
|
157 |
+
"""Propagate the latent features throughout the sequence.
|
158 |
+
|
159 |
+
Args:
|
160 |
+
feats dict(list[tensor]): Features from previous branches. Each
|
161 |
+
component is a list of tensors with shape (n, c, h, w).
|
162 |
+
flows (tensor): Optical flows with shape (n, t - 1, 2, h, w).
|
163 |
+
module_name (str): The name of the propgation branches. Can either
|
164 |
+
be 'backward_1', 'forward_1', 'backward_2', 'forward_2'.
|
165 |
+
|
166 |
+
Return:
|
167 |
+
dict(list[tensor]): A dictionary containing all the propagated \
|
168 |
+
features. Each key in the dictionary corresponds to a \
|
169 |
+
propagation branch, which is represented by a list of tensors.
|
170 |
+
"""
|
171 |
+
|
172 |
+
n, t, _, h, w = flows.size()
|
173 |
+
|
174 |
+
frame_idx = range(0, t + 1)
|
175 |
+
flow_idx = range(-1, t)
|
176 |
+
mapping_idx = list(range(0, len(feats['spatial'])))
|
177 |
+
mapping_idx += mapping_idx[::-1]
|
178 |
+
|
179 |
+
if 'backward' in module_name:
|
180 |
+
frame_idx = frame_idx[::-1]
|
181 |
+
flow_idx = frame_idx
|
182 |
+
|
183 |
+
feat_prop = flows.new_zeros(n, self.mid_channels, h, w)
|
184 |
+
for i, idx in enumerate(frame_idx):
|
185 |
+
feat_current = feats['spatial'][mapping_idx[idx]]
|
186 |
+
if self.cpu_cache:
|
187 |
+
feat_current = feat_current.cuda()
|
188 |
+
feat_prop = feat_prop.cuda()
|
189 |
+
# second-order deformable alignment
|
190 |
+
if i > 0 and self.is_with_alignment:
|
191 |
+
flow_n1 = flows[:, flow_idx[i], :, :, :]
|
192 |
+
if self.cpu_cache:
|
193 |
+
flow_n1 = flow_n1.cuda()
|
194 |
+
|
195 |
+
cond_n1 = flow_warp(feat_prop, flow_n1.permute(0, 2, 3, 1))
|
196 |
+
|
197 |
+
# initialize second-order features
|
198 |
+
feat_n2 = torch.zeros_like(feat_prop)
|
199 |
+
flow_n2 = torch.zeros_like(flow_n1)
|
200 |
+
cond_n2 = torch.zeros_like(cond_n1)
|
201 |
+
|
202 |
+
if i > 1: # second-order features
|
203 |
+
feat_n2 = feats[module_name][-2]
|
204 |
+
if self.cpu_cache:
|
205 |
+
feat_n2 = feat_n2.cuda()
|
206 |
+
|
207 |
+
flow_n2 = flows[:, flow_idx[i - 1], :, :, :]
|
208 |
+
if self.cpu_cache:
|
209 |
+
flow_n2 = flow_n2.cuda()
|
210 |
+
|
211 |
+
flow_n2 = flow_n1 + flow_warp(flow_n2, flow_n1.permute(0, 2, 3, 1))
|
212 |
+
cond_n2 = flow_warp(feat_n2, flow_n2.permute(0, 2, 3, 1))
|
213 |
+
|
214 |
+
# flow-guided deformable convolution
|
215 |
+
cond = torch.cat([cond_n1, feat_current, cond_n2], dim=1)
|
216 |
+
feat_prop = torch.cat([feat_prop, feat_n2], dim=1)
|
217 |
+
feat_prop = self.deform_align[module_name](feat_prop, cond, flow_n1, flow_n2)
|
218 |
+
|
219 |
+
# concatenate and residual blocks
|
220 |
+
feat = [feat_current] + [feats[k][idx] for k in feats if k not in ['spatial', module_name]] + [feat_prop]
|
221 |
+
if self.cpu_cache:
|
222 |
+
feat = [f.cuda() for f in feat]
|
223 |
+
|
224 |
+
feat = torch.cat(feat, dim=1)
|
225 |
+
feat_prop = feat_prop + self.backbone[module_name](feat)
|
226 |
+
feats[module_name].append(feat_prop)
|
227 |
+
|
228 |
+
if self.cpu_cache:
|
229 |
+
feats[module_name][-1] = feats[module_name][-1].cpu()
|
230 |
+
torch.cuda.empty_cache()
|
231 |
+
|
232 |
+
if 'backward' in module_name:
|
233 |
+
feats[module_name] = feats[module_name][::-1]
|
234 |
+
|
235 |
+
return feats
|
236 |
+
|
237 |
+
def upsample(self, lqs, feats):
|
238 |
+
"""Compute the output image given the features.
|
239 |
+
|
240 |
+
Args:
|
241 |
+
lqs (tensor): Input low quality (LQ) sequence with
|
242 |
+
shape (n, t, c, h, w).
|
243 |
+
feats (dict): The features from the propagation branches.
|
244 |
+
|
245 |
+
Returns:
|
246 |
+
Tensor: Output HR sequence with shape (n, t, c, 4h, 4w).
|
247 |
+
"""
|
248 |
+
|
249 |
+
outputs = []
|
250 |
+
num_outputs = len(feats['spatial'])
|
251 |
+
|
252 |
+
mapping_idx = list(range(0, num_outputs))
|
253 |
+
mapping_idx += mapping_idx[::-1]
|
254 |
+
|
255 |
+
for i in range(0, lqs.size(1)):
|
256 |
+
hr = [feats[k].pop(0) for k in feats if k != 'spatial']
|
257 |
+
hr.insert(0, feats['spatial'][mapping_idx[i]])
|
258 |
+
hr = torch.cat(hr, dim=1)
|
259 |
+
if self.cpu_cache:
|
260 |
+
hr = hr.cuda()
|
261 |
+
|
262 |
+
hr = self.reconstruction(hr)
|
263 |
+
hr = self.lrelu(self.pixel_shuffle(self.upconv1(hr)))
|
264 |
+
hr = self.lrelu(self.pixel_shuffle(self.upconv2(hr)))
|
265 |
+
hr = self.lrelu(self.conv_hr(hr))
|
266 |
+
hr = self.conv_last(hr)
|
267 |
+
if self.is_low_res_input:
|
268 |
+
hr += self.img_upsample(lqs[:, i, :, :, :])
|
269 |
+
else:
|
270 |
+
hr += lqs[:, i, :, :, :]
|
271 |
+
|
272 |
+
if self.cpu_cache:
|
273 |
+
hr = hr.cpu()
|
274 |
+
torch.cuda.empty_cache()
|
275 |
+
|
276 |
+
outputs.append(hr)
|
277 |
+
|
278 |
+
return torch.stack(outputs, dim=1)
|
279 |
+
|
280 |
+
def forward(self, lqs):
|
281 |
+
"""Forward function for BasicVSR++.
|
282 |
+
|
283 |
+
Args:
|
284 |
+
lqs (tensor): Input low quality (LQ) sequence with
|
285 |
+
shape (n, t, c, h, w).
|
286 |
+
|
287 |
+
Returns:
|
288 |
+
Tensor: Output HR sequence with shape (n, t, c, 4h, 4w).
|
289 |
+
"""
|
290 |
+
|
291 |
+
n, t, c, h, w = lqs.size()
|
292 |
+
|
293 |
+
# whether to cache the features in CPU
|
294 |
+
self.cpu_cache = True if t > self.cpu_cache_length else False
|
295 |
+
|
296 |
+
if self.is_low_res_input:
|
297 |
+
lqs_downsample = lqs.clone()
|
298 |
+
else:
|
299 |
+
lqs_downsample = F.interpolate(
|
300 |
+
lqs.view(-1, c, h, w), scale_factor=0.25, mode='bicubic').view(n, t, c, h // 4, w // 4)
|
301 |
+
|
302 |
+
# check whether the input is an extended sequence
|
303 |
+
self.check_if_mirror_extended(lqs)
|
304 |
+
|
305 |
+
feats = {}
|
306 |
+
# compute spatial features
|
307 |
+
if self.cpu_cache:
|
308 |
+
feats['spatial'] = []
|
309 |
+
for i in range(0, t):
|
310 |
+
feat = self.feat_extract(lqs[:, i, :, :, :]).cpu()
|
311 |
+
feats['spatial'].append(feat)
|
312 |
+
torch.cuda.empty_cache()
|
313 |
+
else:
|
314 |
+
feats_ = self.feat_extract(lqs.view(-1, c, h, w))
|
315 |
+
h, w = feats_.shape[2:]
|
316 |
+
feats_ = feats_.view(n, t, -1, h, w)
|
317 |
+
feats['spatial'] = [feats_[:, i, :, :, :] for i in range(0, t)]
|
318 |
+
|
319 |
+
# compute optical flow using the low-res inputs
|
320 |
+
assert lqs_downsample.size(3) >= 64 and lqs_downsample.size(4) >= 64, (
|
321 |
+
'The height and width of low-res inputs must be at least 64, '
|
322 |
+
f'but got {h} and {w}.')
|
323 |
+
flows_forward, flows_backward = self.compute_flow(lqs_downsample)
|
324 |
+
|
325 |
+
# feature propgation
|
326 |
+
for iter_ in [1, 2]:
|
327 |
+
for direction in ['backward', 'forward']:
|
328 |
+
module = f'{direction}_{iter_}'
|
329 |
+
|
330 |
+
feats[module] = []
|
331 |
+
|
332 |
+
if direction == 'backward':
|
333 |
+
flows = flows_backward
|
334 |
+
elif flows_forward is not None:
|
335 |
+
flows = flows_forward
|
336 |
+
else:
|
337 |
+
flows = flows_backward.flip(1)
|
338 |
+
|
339 |
+
feats = self.propagate(feats, flows, module)
|
340 |
+
if self.cpu_cache:
|
341 |
+
del flows
|
342 |
+
torch.cuda.empty_cache()
|
343 |
+
|
344 |
+
return self.upsample(lqs, feats)
|
345 |
+
|
346 |
+
|
347 |
+
class SecondOrderDeformableAlignment(ModulatedDeformConvPack):
|
348 |
+
"""Second-order deformable alignment module.
|
349 |
+
|
350 |
+
Args:
|
351 |
+
in_channels (int): Same as nn.Conv2d.
|
352 |
+
out_channels (int): Same as nn.Conv2d.
|
353 |
+
kernel_size (int or tuple[int]): Same as nn.Conv2d.
|
354 |
+
stride (int or tuple[int]): Same as nn.Conv2d.
|
355 |
+
padding (int or tuple[int]): Same as nn.Conv2d.
|
356 |
+
dilation (int or tuple[int]): Same as nn.Conv2d.
|
357 |
+
groups (int): Same as nn.Conv2d.
|
358 |
+
bias (bool or str): If specified as `auto`, it will be decided by the
|
359 |
+
norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
|
360 |
+
False.
|
361 |
+
max_residue_magnitude (int): The maximum magnitude of the offset
|
362 |
+
residue (Eq. 6 in paper). Default: 10.
|
363 |
+
"""
|
364 |
+
|
365 |
+
def __init__(self, *args, **kwargs):
|
366 |
+
self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10)
|
367 |
+
|
368 |
+
super(SecondOrderDeformableAlignment, self).__init__(*args, **kwargs)
|
369 |
+
|
370 |
+
self.conv_offset = nn.Sequential(
|
371 |
+
nn.Conv2d(3 * self.out_channels + 4, self.out_channels, 3, 1, 1),
|
372 |
+
nn.LeakyReLU(negative_slope=0.1, inplace=True),
|
373 |
+
nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
|
374 |
+
nn.LeakyReLU(negative_slope=0.1, inplace=True),
|
375 |
+
nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
|
376 |
+
nn.LeakyReLU(negative_slope=0.1, inplace=True),
|
377 |
+
nn.Conv2d(self.out_channels, 27 * self.deformable_groups, 3, 1, 1),
|
378 |
+
)
|
379 |
+
|
380 |
+
self.init_offset()
|
381 |
+
|
382 |
+
def init_offset(self):
|
383 |
+
|
384 |
+
def _constant_init(module, val, bias=0):
|
385 |
+
if hasattr(module, 'weight') and module.weight is not None:
|
386 |
+
nn.init.constant_(module.weight, val)
|
387 |
+
if hasattr(module, 'bias') and module.bias is not None:
|
388 |
+
nn.init.constant_(module.bias, bias)
|
389 |
+
|
390 |
+
_constant_init(self.conv_offset[-1], val=0, bias=0)
|
391 |
+
|
392 |
+
def forward(self, x, extra_feat, flow_1, flow_2):
|
393 |
+
extra_feat = torch.cat([extra_feat, flow_1, flow_2], dim=1)
|
394 |
+
out = self.conv_offset(extra_feat)
|
395 |
+
o1, o2, mask = torch.chunk(out, 3, dim=1)
|
396 |
+
|
397 |
+
# offset
|
398 |
+
offset = self.max_residue_magnitude * torch.tanh(torch.cat((o1, o2), dim=1))
|
399 |
+
offset_1, offset_2 = torch.chunk(offset, 2, dim=1)
|
400 |
+
offset_1 = offset_1 + flow_1.flip(1).repeat(1, offset_1.size(1) // 2, 1, 1)
|
401 |
+
offset_2 = offset_2 + flow_2.flip(1).repeat(1, offset_2.size(1) // 2, 1, 1)
|
402 |
+
offset = torch.cat([offset_1, offset_2], dim=1)
|
403 |
+
|
404 |
+
# mask
|
405 |
+
mask = torch.sigmoid(mask)
|
406 |
+
|
407 |
+
return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
|
408 |
+
self.dilation, mask)
|
409 |
+
|
410 |
+
|
411 |
+
# if __name__ == '__main__':
|
412 |
+
# spynet_path = 'experiments/pretrained_models/flownet/spynet_sintel_final-3d2a1287.pth'
|
413 |
+
# model = BasicVSRPlusPlus(spynet_path=spynet_path).cuda()
|
414 |
+
# input = torch.rand(1, 2, 3, 64, 64).cuda()
|
415 |
+
# output = model(input)
|
416 |
+
# print('===================')
|
417 |
+
# print(output.shape)
|
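Expanding the commented-out smoke test above into a runnable sketch (not part of the commit; it assumes a CUDA machine with the basicsr DCN op available, and passing spynet_path=None leaves SPyNet randomly initialized, which is enough for a shape check -- note the forward pass asserts the low-res frames are at least 64x64):

import torch
from basicsr.archs.basicvsrpp_arch import BasicVSRPlusPlus

model = BasicVSRPlusPlus(mid_channels=64, num_blocks=7, spynet_path=None).cuda().eval()
lqs = torch.rand(1, 5, 3, 64, 64).cuda()  # (n, t, c, h, w); h and w must be >= 64
with torch.no_grad():
    out = model(lqs)
print(out.shape)  # torch.Size([1, 5, 3, 256, 256]) -- x4 upsampling per frame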
StableSR/basicsr/archs/degradat_arch.py
ADDED
@@ -0,0 +1,90 @@
from torch import nn as nn

from basicsr.archs.arch_util import ResidualBlockNoBN, default_init_weights
from basicsr.utils.registry import ARCH_REGISTRY


@ARCH_REGISTRY.register()
class DEResNet(nn.Module):
    """Degradation Estimator with a ResNetNoBN arch. v2.1, no vector anymore.

    As shown in the paper 'Towards Flexible Blind JPEG Artifacts Removal',
    a resnet arch works for image quality estimation.

    Args:
        num_in_ch (int): Channel number of inputs. Default: 3.
        num_degradation (int): Number of degradations the estimator should predict. Default: 2 (blur + noise).
        degradation_degree_actv (str): Activation function for the degradation degree scalar. Default: sigmoid.
        num_feats (list | tuple): Channel number of each stage.
        num_blocks (list | tuple): Number of residual blocks in each stage.
        downscales (list | tuple): Downscale factor of each stage.
    """

    def __init__(self,
                 num_in_ch=3,
                 num_degradation=2,
                 degradation_degree_actv='sigmoid',
                 num_feats=(64, 128, 256, 512),
                 num_blocks=(2, 2, 2, 2),
                 downscales=(2, 2, 2, 1)):
        super(DEResNet, self).__init__()

        # the defaults are tuples, so accept both lists and tuples here
        assert isinstance(num_feats, (list, tuple))
        assert isinstance(num_blocks, (list, tuple))
        assert isinstance(downscales, (list, tuple))
        assert len(num_feats) == len(num_blocks) and len(num_feats) == len(downscales)

        num_stage = len(num_feats)

        self.conv_first = nn.ModuleList()
        for _ in range(num_degradation):
            self.conv_first.append(nn.Conv2d(num_in_ch, num_feats[0], 3, 1, 1))
        self.body = nn.ModuleList()
        for _ in range(num_degradation):
            body = list()
            for stage in range(num_stage):
                for _ in range(num_blocks[stage]):
                    body.append(ResidualBlockNoBN(num_feats[stage]))
                if downscales[stage] == 1:
                    if stage < num_stage - 1 and num_feats[stage] != num_feats[stage + 1]:
                        body.append(nn.Conv2d(num_feats[stage], num_feats[stage + 1], 3, 1, 1))
                    continue
                elif downscales[stage] == 2:
                    body.append(nn.Conv2d(num_feats[stage], num_feats[min(stage + 1, num_stage - 1)], 3, 2, 1))
                else:
                    raise NotImplementedError
            self.body.append(nn.Sequential(*body))

        self.num_degradation = num_degradation
        self.fc_degree = nn.ModuleList()
        if degradation_degree_actv == 'sigmoid':
            actv = nn.Sigmoid
        elif degradation_degree_actv == 'tanh':
            actv = nn.Tanh
        else:
            raise NotImplementedError(f'only sigmoid and tanh are supported for degradation_degree_actv, '
                                      f'{degradation_degree_actv} is not supported yet.')
        for _ in range(num_degradation):
            self.fc_degree.append(
                nn.Sequential(
                    nn.Linear(num_feats[-1], 512),
                    nn.ReLU(inplace=True),
                    nn.Linear(512, 1),
                    actv(),
                ))

        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        default_init_weights([self.conv_first, self.body, self.fc_degree], 0.1)

    def forward(self, x):
        degrees = []
        for i in range(self.num_degradation):
            x_out = self.conv_first[i](x)
            feat = self.body[i](x_out)
            feat = self.avg_pool(feat)
            feat = feat.squeeze(-1).squeeze(-1)
            degrees.append(self.fc_degree[i](feat).squeeze(-1))

        return degrees
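A quick shape check for the degradation estimator (a sketch; with the default sigmoid head, each predicted degree lies in [0, 1]):

import torch
from basicsr.archs.degradat_arch import DEResNet

model = DEResNet(num_in_ch=3, num_degradation=2)  # defaults estimate blur + noise degrees
x = torch.rand(4, 3, 128, 128)
degrees = model(x)                 # list with one tensor per degradation
print([d.shape for d in degrees])  # [torch.Size([4]), torch.Size([4])]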
StableSR/basicsr/archs/dfdnet_arch.py
ADDED
@@ -0,0 +1,169 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.spectral_norm import spectral_norm

from basicsr.utils.registry import ARCH_REGISTRY
from .dfdnet_util import AttentionBlock, Blur, MSDilationBlock, UpResBlock, adaptive_instance_normalization
from .vgg_arch import VGGFeatureExtractor


class SFTUpBlock(nn.Module):
    """Spatial feature transform (SFT) with upsampling block.

    Args:
        in_channel (int): Number of input channels.
        out_channel (int): Number of output channels.
        kernel_size (int): Kernel size in convolutions. Default: 3.
        padding (int): Padding in convolutions. Default: 1.
    """

    def __init__(self, in_channel, out_channel, kernel_size=3, padding=1):
        super(SFTUpBlock, self).__init__()
        self.conv1 = nn.Sequential(
            Blur(in_channel),
            spectral_norm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)),
            nn.LeakyReLU(0.04, True),
            # the official code applies LeakyReLU(0.2) twice here; a single slope of 0.04 (= 0.2^2) is equivalent
        )
        self.convup = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            spectral_norm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
            nn.LeakyReLU(0.2, True),
        )

        # for SFT scale and shift
        self.scale_block = nn.Sequential(
            spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True),
            spectral_norm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)))
        self.shift_block = nn.Sequential(
            spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True),
            spectral_norm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)), nn.Sigmoid())
        # the official code uses a sigmoid on the shift block; the reason is not documented

    def forward(self, x, updated_feat):
        out = self.conv1(x)
        # SFT
        scale = self.scale_block(updated_feat)
        shift = self.shift_block(updated_feat)
        out = out * scale + shift
        # upsample
        out = self.convup(out)
        return out


@ARCH_REGISTRY.register()
class DFDNet(nn.Module):
    """DFDNet: Deep Face Dictionary Network.

    It only processes faces with 512x512 size.

    Args:
        num_feat (int): Number of feature channels.
        dict_path (str): Path to the facial component dictionary.
    """

    def __init__(self, num_feat, dict_path):
        super().__init__()
        self.parts = ['left_eye', 'right_eye', 'nose', 'mouth']
        # part_sizes: [80, 80, 50, 110]
        channel_sizes = [128, 256, 512, 512]
        self.feature_sizes = np.array([256, 128, 64, 32])
        self.vgg_layers = ['relu2_2', 'relu3_4', 'relu4_4', 'conv5_4']
        self.flag_dict_device = False

        # dict
        self.dict = torch.load(dict_path)

        # vgg face extractor
        self.vgg_extractor = VGGFeatureExtractor(
            layer_name_list=self.vgg_layers,
            vgg_type='vgg19',
            use_input_norm=True,
            range_norm=True,
            requires_grad=False)

        # attention block for fusing dictionary features and input features
        self.attn_blocks = nn.ModuleDict()
        for idx, feat_size in enumerate(self.feature_sizes):
            for name in self.parts:
                self.attn_blocks[f'{name}_{feat_size}'] = AttentionBlock(channel_sizes[idx])

        # multi scale dilation block
        self.multi_scale_dilation = MSDilationBlock(num_feat * 8, dilation=[4, 3, 2, 1])

        # upsampling and reconstruction
        self.upsample0 = SFTUpBlock(num_feat * 8, num_feat * 8)
        self.upsample1 = SFTUpBlock(num_feat * 8, num_feat * 4)
        self.upsample2 = SFTUpBlock(num_feat * 4, num_feat * 2)
        self.upsample3 = SFTUpBlock(num_feat * 2, num_feat)
        self.upsample4 = nn.Sequential(
            spectral_norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1)), nn.LeakyReLU(0.2, True), UpResBlock(num_feat),
            UpResBlock(num_feat), nn.Conv2d(num_feat, 3, kernel_size=3, stride=1, padding=1), nn.Tanh())

    def swap_feat(self, vgg_feat, updated_feat, dict_feat, location, part_name, f_size):
        """Swap the features from the dictionary."""
        # get the original vgg features
        part_feat = vgg_feat[:, :, location[1]:location[3], location[0]:location[2]].clone()
        # resize original vgg features
        part_resize_feat = F.interpolate(part_feat, dict_feat.size()[2:4], mode='bilinear', align_corners=False)
        # use adaptive instance normalization to adjust color and illuminations
        dict_feat = adaptive_instance_normalization(dict_feat, part_resize_feat)
        # get similarity scores
        similarity_score = F.conv2d(part_resize_feat, dict_feat)
        similarity_score = F.softmax(similarity_score.view(-1), dim=0)
        # select the most similar features in the dict (after norm)
        select_idx = torch.argmax(similarity_score)
        swap_feat = F.interpolate(dict_feat[select_idx:select_idx + 1], part_feat.size()[2:4])
        # attention
        attn = self.attn_blocks[f'{part_name}_' + str(f_size)](swap_feat - part_feat)
        attn_feat = attn * swap_feat
        # update features
        updated_feat[:, :, location[1]:location[3], location[0]:location[2]] = attn_feat + part_feat
        return updated_feat

    def put_dict_to_device(self, x):
        if self.flag_dict_device is False:
            for k, v in self.dict.items():
                for kk, vv in v.items():
                    self.dict[k][kk] = vv.to(x)
            self.flag_dict_device = True

    def forward(self, x, part_locations):
        """
        Now only supports testing with batch size = 1.

        Args:
            x (Tensor): Input faces with shape (b, c, 512, 512).
            part_locations (list[Tensor]): Part locations.
        """
        self.put_dict_to_device(x)
        # extract vggface features
        vgg_features = self.vgg_extractor(x)
        # update vggface features using the dictionary for each part
        updated_vgg_features = []
        batch = 0  # only supports testing with batch size = 1
        for vgg_layer, f_size in zip(self.vgg_layers, self.feature_sizes):
            dict_features = self.dict[f'{f_size}']
            vgg_feat = vgg_features[vgg_layer]
            updated_feat = vgg_feat.clone()

            # swap features from dictionary
            for part_idx, part_name in enumerate(self.parts):
                location = (part_locations[part_idx][batch] // (512 / f_size)).int()
                updated_feat = self.swap_feat(vgg_feat, updated_feat, dict_features[part_name], location, part_name,
                                              f_size)

            updated_vgg_features.append(updated_feat)

        vgg_feat_dilation = self.multi_scale_dilation(vgg_features['conv5_4'])
        # use the updated vgg features to modulate the upsampled features in an
        # SFT (Spatial Feature Transform) scale-and-shift manner
        upsampled_feat = self.upsample0(vgg_feat_dilation, updated_vgg_features[3])
        upsampled_feat = self.upsample1(upsampled_feat, updated_vgg_features[2])
        upsampled_feat = self.upsample2(upsampled_feat, updated_vgg_features[1])
        upsampled_feat = self.upsample3(upsampled_feat, updated_vgg_features[0])
        out = self.upsample4(upsampled_feat)

        return out
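A sketch of a DFDNet call (the dictionary path below is hypothetical -- the real facial component dictionary ships with the DFDNet release -- and part_locations are per-part (x1, y1, x2, y2) boxes in 512x512 coordinates, e.g. from a landmark detector; the box values here are purely illustrative):

import torch
from basicsr.archs.dfdnet_arch import DFDNet

model = DFDNet(num_feat=64, dict_path='path/to/DFDNet_dict_512.pth').eval()  # hypothetical path

face = torch.rand(1, 3, 512, 512)  # batch size must be 1
# boxes ordered as model.parts: left_eye, right_eye, nose, mouth (illustrative values)
part_locations = [torch.tensor([[146, 186, 225, 265]]),
                  torch.tensor([[287, 186, 366, 265]]),
                  torch.tensor([[231, 232, 280, 281]]),
                  torch.tensor([[201, 320, 311, 430]])]
with torch.no_grad():
    restored = model(face, part_locations)  # (1, 3, 512, 512), in [-1, 1] via the final tanh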
StableSR/basicsr/archs/dfdnet_util.py
ADDED
@@ -0,0 +1,162 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.nn.utils.spectral_norm import spectral_norm


class BlurFunctionBackward(Function):

    @staticmethod
    def forward(ctx, grad_output, kernel, kernel_flip):
        ctx.save_for_backward(kernel, kernel_flip)
        grad_input = F.conv2d(grad_output, kernel_flip, padding=1, groups=grad_output.shape[1])
        return grad_input

    @staticmethod
    def backward(ctx, gradgrad_output):
        kernel, _ = ctx.saved_tensors
        grad_input = F.conv2d(gradgrad_output, kernel, padding=1, groups=gradgrad_output.shape[1])
        return grad_input, None, None


class BlurFunction(Function):

    @staticmethod
    def forward(ctx, x, kernel, kernel_flip):
        ctx.save_for_backward(kernel, kernel_flip)
        output = F.conv2d(x, kernel, padding=1, groups=x.shape[1])
        return output

    @staticmethod
    def backward(ctx, grad_output):
        kernel, kernel_flip = ctx.saved_tensors
        grad_input = BlurFunctionBackward.apply(grad_output, kernel, kernel_flip)
        return grad_input, None, None


blur = BlurFunction.apply


class Blur(nn.Module):

    def __init__(self, channel):
        super().__init__()
        kernel = torch.tensor([[1, 2, 1], [2, 4, 2], [1, 2, 1]], dtype=torch.float32)
        kernel = kernel.view(1, 1, 3, 3)
        kernel = kernel / kernel.sum()
        kernel_flip = torch.flip(kernel, [2, 3])

        self.kernel = kernel.repeat(channel, 1, 1, 1)
        self.kernel_flip = kernel_flip.repeat(channel, 1, 1, 1)

    def forward(self, x):
        return blur(x, self.kernel.type_as(x), self.kernel_flip.type_as(x))


def calc_mean_std(feat, eps=1e-5):
    """Calculate mean and std for adaptive_instance_normalization.

    Args:
        feat (Tensor): 4D tensor.
        eps (float): A small value added to the variance to avoid
            divide-by-zero. Default: 1e-5.
    """
    size = feat.size()
    assert len(size) == 4, 'The input feature should be 4D tensor.'
    n, c = size[:2]
    feat_var = feat.view(n, c, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().view(n, c, 1, 1)
    feat_mean = feat.view(n, c, -1).mean(dim=2).view(n, c, 1, 1)
    return feat_mean, feat_std


def adaptive_instance_normalization(content_feat, style_feat):
    """Adaptive instance normalization.

    Adjust the reference features to have color and illumination statistics
    similar to those of the degraded features.

    Args:
        content_feat (Tensor): The reference features.
        style_feat (Tensor): The degraded features.
    """
    size = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)
    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
    return normalized_feat * style_std.expand(size) + style_mean.expand(size)


def AttentionBlock(in_channel):
    return nn.Sequential(
        spectral_norm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True),
        spectral_norm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)))


def conv_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, bias=True):
    """Conv block used in MSDilationBlock."""

    return nn.Sequential(
        spectral_norm(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=((kernel_size - 1) // 2) * dilation,
                bias=bias)),
        nn.LeakyReLU(0.2),
        spectral_norm(
            nn.Conv2d(
                out_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=((kernel_size - 1) // 2) * dilation,
                bias=bias)),
    )


class MSDilationBlock(nn.Module):
    """Multi-scale dilation block."""

    def __init__(self, in_channels, kernel_size=3, dilation=(1, 1, 1, 1), bias=True):
        super(MSDilationBlock, self).__init__()

        self.conv_blocks = nn.ModuleList()
        for i in range(4):
            self.conv_blocks.append(conv_block(in_channels, in_channels, kernel_size, dilation=dilation[i], bias=bias))
        self.conv_fusion = spectral_norm(
            nn.Conv2d(
                in_channels * 4,
                in_channels,
                kernel_size=kernel_size,
                stride=1,
                padding=(kernel_size - 1) // 2,
                bias=bias))

    def forward(self, x):
        out = []
        for i in range(4):
            out.append(self.conv_blocks[i](x))
        out = torch.cat(out, 1)
        out = self.conv_fusion(out) + x
        return out


class UpResBlock(nn.Module):

    def __init__(self, in_channel):
        super(UpResBlock, self).__init__()
        self.body = nn.Sequential(
            nn.Conv2d(in_channel, in_channel, 3, 1, 1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(in_channel, in_channel, 3, 1, 1),
        )

    def forward(self, x):
        out = x + self.body(x)
        return out
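Since adaptive_instance_normalization above is self-contained, its effect is easy to verify numerically: after the transfer, each channel of the output carries the per-channel mean/std of the style input (a minimal sketch):

import torch
from basicsr.archs.dfdnet_util import adaptive_instance_normalization, calc_mean_std

content = torch.randn(1, 8, 16, 16)
style = torch.randn(1, 8, 16, 16) * 3.0 + 5.0

out = adaptive_instance_normalization(content, style)
out_mean, out_std = calc_mean_std(out)
style_mean, style_std = calc_mean_std(style)
print(torch.allclose(out_mean, style_mean, atol=1e-4))  # True
print(torch.allclose(out_std, style_std, atol=1e-3))    # True, up to the eps inside calc_mean_std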
StableSR/basicsr/archs/discriminator_arch.py
ADDED
@@ -0,0 +1,150 @@
from torch import nn as nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm

from basicsr.utils.registry import ARCH_REGISTRY


@ARCH_REGISTRY.register()
class VGGStyleDiscriminator(nn.Module):
    """VGG style discriminator with input size 128 x 128 or 256 x 256.

    It is used to train SRGAN, ESRGAN, and VideoGAN.

    Args:
        num_in_ch (int): Channel number of inputs. Default: 3.
        num_feat (int): Channel number of base intermediate features. Default: 64.
    """

    def __init__(self, num_in_ch, num_feat, input_size=128):
        super(VGGStyleDiscriminator, self).__init__()
        self.input_size = input_size
        assert self.input_size == 128 or self.input_size == 256, (
            f'input size must be 128 or 256, but received {input_size}')

        self.conv0_0 = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1, bias=True)
        self.conv0_1 = nn.Conv2d(num_feat, num_feat, 4, 2, 1, bias=False)
        self.bn0_1 = nn.BatchNorm2d(num_feat, affine=True)

        self.conv1_0 = nn.Conv2d(num_feat, num_feat * 2, 3, 1, 1, bias=False)
        self.bn1_0 = nn.BatchNorm2d(num_feat * 2, affine=True)
        self.conv1_1 = nn.Conv2d(num_feat * 2, num_feat * 2, 4, 2, 1, bias=False)
        self.bn1_1 = nn.BatchNorm2d(num_feat * 2, affine=True)

        self.conv2_0 = nn.Conv2d(num_feat * 2, num_feat * 4, 3, 1, 1, bias=False)
        self.bn2_0 = nn.BatchNorm2d(num_feat * 4, affine=True)
        self.conv2_1 = nn.Conv2d(num_feat * 4, num_feat * 4, 4, 2, 1, bias=False)
        self.bn2_1 = nn.BatchNorm2d(num_feat * 4, affine=True)

        self.conv3_0 = nn.Conv2d(num_feat * 4, num_feat * 8, 3, 1, 1, bias=False)
        self.bn3_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
        self.conv3_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
        self.bn3_1 = nn.BatchNorm2d(num_feat * 8, affine=True)

        self.conv4_0 = nn.Conv2d(num_feat * 8, num_feat * 8, 3, 1, 1, bias=False)
        self.bn4_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
        self.conv4_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
        self.bn4_1 = nn.BatchNorm2d(num_feat * 8, affine=True)

        if self.input_size == 256:
            self.conv5_0 = nn.Conv2d(num_feat * 8, num_feat * 8, 3, 1, 1, bias=False)
            self.bn5_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
            self.conv5_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
            self.bn5_1 = nn.BatchNorm2d(num_feat * 8, affine=True)

        self.linear1 = nn.Linear(num_feat * 8 * 4 * 4, 100)
        self.linear2 = nn.Linear(100, 1)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        assert x.size(2) == self.input_size, (f'Input size must be identical to input_size, but received {x.size()}.')

        feat = self.lrelu(self.conv0_0(x))
        feat = self.lrelu(self.bn0_1(self.conv0_1(feat)))  # output spatial size: /2

        feat = self.lrelu(self.bn1_0(self.conv1_0(feat)))
        feat = self.lrelu(self.bn1_1(self.conv1_1(feat)))  # output spatial size: /4

        feat = self.lrelu(self.bn2_0(self.conv2_0(feat)))
        feat = self.lrelu(self.bn2_1(self.conv2_1(feat)))  # output spatial size: /8

        feat = self.lrelu(self.bn3_0(self.conv3_0(feat)))
        feat = self.lrelu(self.bn3_1(self.conv3_1(feat)))  # output spatial size: /16

        feat = self.lrelu(self.bn4_0(self.conv4_0(feat)))
        feat = self.lrelu(self.bn4_1(self.conv4_1(feat)))  # output spatial size: /32

        if self.input_size == 256:
            feat = self.lrelu(self.bn5_0(self.conv5_0(feat)))
            feat = self.lrelu(self.bn5_1(self.conv5_1(feat)))  # output spatial size: /64

        # spatial size: (4, 4)
        feat = feat.view(feat.size(0), -1)
        feat = self.lrelu(self.linear1(feat))
        out = self.linear2(feat)
        return out


@ARCH_REGISTRY.register(suffix='basicsr')
class UNetDiscriminatorSN(nn.Module):
    """Defines a U-Net discriminator with spectral normalization (SN).

    It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.

    Args:
        num_in_ch (int): Channel number of inputs. Default: 3.
        num_feat (int): Channel number of base intermediate features. Default: 64.
        skip_connection (bool): Whether to use skip connections in the U-Net. Default: True.
    """

    def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
        super(UNetDiscriminatorSN, self).__init__()
        self.skip_connection = skip_connection
        norm = spectral_norm
        # the first convolution
        self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
        # downsample
        self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
        self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
        self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
        # upsample
        self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
        self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
        self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
        # extra convolutions
        self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
        self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
        self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)

    def forward(self, x):
        # downsample
        x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
        x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
        x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
        x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)

        # upsample
        x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)
        x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)

        if self.skip_connection:
            x4 = x4 + x2
        x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)

        if self.skip_connection:
            x5 = x5 + x1
        x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False)
        x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)

        if self.skip_connection:
            x6 = x6 + x0

        # extra convolutions
        out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
        out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
        out = self.conv9(out)

        return out
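A shape check for the U-Net discriminator (a sketch; spatial sizes should be divisible by 8 so the three stride-2 downsamples are undone cleanly by the bilinear upsamples, and the output is a per-pixel realness map rather than a single logit):

import torch
from basicsr.archs.discriminator_arch import UNetDiscriminatorSN

disc = UNetDiscriminatorSN(num_in_ch=3, num_feat=64)
img = torch.rand(2, 3, 128, 128)
logits = disc(img)
print(logits.shape)  # torch.Size([2, 1, 128, 128])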
StableSR/basicsr/archs/duf_arch.py
ADDED
@@ -0,0 +1,276 @@
import numpy as np
import torch
from torch import nn as nn
from torch.nn import functional as F

from basicsr.utils.registry import ARCH_REGISTRY


class DenseBlocksTemporalReduce(nn.Module):
    """A concatenation of 3 dense blocks with reduction in temporal dimension.

    Note that the output temporal dimension is 6 fewer than the input temporal dimension, since there are 3 blocks.

    Args:
        num_feat (int): Number of channels in the blocks. Default: 64.
        num_grow_ch (int): Growing factor of the dense blocks. Default: 32.
        adapt_official_weights (bool): Whether to adapt the weights translated from the official implementation.
            Set to false if you want to train from scratch. Default: False.
    """

    def __init__(self, num_feat=64, num_grow_ch=32, adapt_official_weights=False):
        super(DenseBlocksTemporalReduce, self).__init__()
        if adapt_official_weights:
            eps = 1e-3
            momentum = 1e-3
        else:  # pytorch default values
            eps = 1e-05
            momentum = 0.1

        self.temporal_reduce1 = nn.Sequential(
            nn.BatchNorm3d(num_feat, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
            nn.Conv3d(num_feat, num_feat, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True),
            nn.BatchNorm3d(num_feat, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
            nn.Conv3d(num_feat, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True))

        self.temporal_reduce2 = nn.Sequential(
            nn.BatchNorm3d(num_feat + num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
            nn.Conv3d(
                num_feat + num_grow_ch,
                num_feat + num_grow_ch, (1, 1, 1),
                stride=(1, 1, 1),
                padding=(0, 0, 0),
                bias=True), nn.BatchNorm3d(num_feat + num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
            nn.Conv3d(num_feat + num_grow_ch, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True))

        self.temporal_reduce3 = nn.Sequential(
            nn.BatchNorm3d(num_feat + 2 * num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
            nn.Conv3d(
                num_feat + 2 * num_grow_ch,
                num_feat + 2 * num_grow_ch, (1, 1, 1),
                stride=(1, 1, 1),
                padding=(0, 0, 0),
                bias=True), nn.BatchNorm3d(num_feat + 2 * num_grow_ch, eps=eps, momentum=momentum),
            nn.ReLU(inplace=True),
            nn.Conv3d(
                num_feat + 2 * num_grow_ch, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True))

    def forward(self, x):
        """
        Args:
            x (Tensor): Input tensor with shape (b, num_feat, t, h, w).

        Returns:
            Tensor: Output with shape (b, num_feat + num_grow_ch * 3, 1, h, w).
        """
        x1 = self.temporal_reduce1(x)
        x1 = torch.cat((x[:, :, 1:-1, :, :], x1), 1)

        x2 = self.temporal_reduce2(x1)
        x2 = torch.cat((x1[:, :, 1:-1, :, :], x2), 1)

        x3 = self.temporal_reduce3(x2)
        x3 = torch.cat((x2[:, :, 1:-1, :, :], x3), 1)

        return x3


class DenseBlocks(nn.Module):
    """A concatenation of N dense blocks.

    Args:
        num_feat (int): Number of channels in the blocks. Default: 64.
        num_grow_ch (int): Growing factor of the dense blocks. Default: 32.
        num_block (int): Number of dense blocks. The values are:
            DUF-S (16 layers): 3
            DUF-M (28 layers): 9
            DUF-L (52 layers): 21
        adapt_official_weights (bool): Whether to adapt the weights translated from the official implementation.
            Set to false if you want to train from scratch. Default: False.
    """

    def __init__(self, num_block, num_feat=64, num_grow_ch=16, adapt_official_weights=False):
        super(DenseBlocks, self).__init__()
        if adapt_official_weights:
            eps = 1e-3
            momentum = 1e-3
        else:  # pytorch default values
            eps = 1e-05
            momentum = 0.1

        self.dense_blocks = nn.ModuleList()
        for i in range(0, num_block):
            self.dense_blocks.append(
                nn.Sequential(
                    nn.BatchNorm3d(num_feat + i * num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
                    nn.Conv3d(
                        num_feat + i * num_grow_ch,
                        num_feat + i * num_grow_ch, (1, 1, 1),
                        stride=(1, 1, 1),
                        padding=(0, 0, 0),
                        bias=True), nn.BatchNorm3d(num_feat + i * num_grow_ch, eps=eps, momentum=momentum),
                    nn.ReLU(inplace=True),
                    nn.Conv3d(
                        num_feat + i * num_grow_ch,
                        num_grow_ch, (3, 3, 3),
                        stride=(1, 1, 1),
                        padding=(1, 1, 1),
                        bias=True)))

    def forward(self, x):
        """
        Args:
            x (Tensor): Input tensor with shape (b, num_feat, t, h, w).

        Returns:
            Tensor: Output with shape (b, num_feat + num_block * num_grow_ch, t, h, w).
        """
        for i in range(0, len(self.dense_blocks)):
            y = self.dense_blocks[i](x)
            x = torch.cat((x, y), 1)
        return x


class DynamicUpsamplingFilter(nn.Module):
    """Dynamic upsampling filter used in DUF.

    Reference: https://github.com/yhjo09/VSR-DUF

    It only supports input with 3 channels. And it applies the same filters to 3 channels.

    Args:
        filter_size (tuple): Filter size of generated filters. The shape is (kh, kw). Default: (5, 5).
    """

    def __init__(self, filter_size=(5, 5)):
        super(DynamicUpsamplingFilter, self).__init__()
        if not isinstance(filter_size, tuple):
            raise TypeError(f'The type of filter_size must be tuple, but got {type(filter_size)}')
        if len(filter_size) != 2:
            raise ValueError(f'The length of filter size must be 2, but got {len(filter_size)}.')
        # generate a local expansion filter, similar to im2col
        self.filter_size = filter_size
        filter_prod = np.prod(filter_size)
        expansion_filter = torch.eye(int(filter_prod)).view(filter_prod, 1, *filter_size)  # (kh*kw, 1, kh, kw)
        self.expansion_filter = expansion_filter.repeat(3, 1, 1, 1)  # repeat for all the 3 channels

    def forward(self, x, filters):
        """Forward function for DynamicUpsamplingFilter.

        Args:
            x (Tensor): Input image with 3 channels. The shape is (n, 3, h, w).
            filters (Tensor): Generated dynamic filters. The shape is (n, filter_prod, upsampling_square, h, w).
                filter_prod: prod of filter kernel size, e.g., 1*5*5=25.
                upsampling_square: similar to pixel shuffle, upsampling_square = upsampling * upsampling.
                e.g., for x4 upsampling, upsampling_square = 4*4 = 16.

        Returns:
            Tensor: Filtered image with shape (n, 3*upsampling_square, h, w)
        """
        n, filter_prod, upsampling_square, h, w = filters.size()
        kh, kw = self.filter_size
        expanded_input = F.conv2d(
            x, self.expansion_filter.to(x), padding=(kh // 2, kw // 2), groups=3)  # (n, 3*filter_prod, h, w)
        expanded_input = expanded_input.view(n, 3, filter_prod, h, w).permute(0, 3, 4, 1,
                                                                              2)  # (n, h, w, 3, filter_prod)
        filters = filters.permute(0, 3, 4, 1, 2)  # (n, h, w, filter_prod, upsampling_square)
        out = torch.matmul(expanded_input, filters)  # (n, h, w, 3, upsampling_square)
        return out.permute(0, 3, 4, 1, 2).view(n, 3 * upsampling_square, h, w)


@ARCH_REGISTRY.register()
class DUF(nn.Module):
    """Network architecture for DUF

    ``Paper: Deep Video Super-Resolution Network Using Dynamic Upsampling Filters Without Explicit Motion Compensation``

    Reference: https://github.com/yhjo09/VSR-DUF

    For all the models below, 'adapt_official_weights' is only necessary when
    loading the weights converted from the official TensorFlow weights.
    Please set it to False if you are training the model from scratch.

    There are three models with different model sizes: DUF16Layers, DUF28Layers,
    and DUF52Layers. This class is the base class for these models.

    Args:
        scale (int): The upsampling factor. Default: 4.
        num_layer (int): The number of layers. Default: 52.
        adapt_official_weights (bool): Whether to adapt the weights
            translated from the official implementation. Set to false if you
            want to train from scratch. Default: False.
    """

    def __init__(self, scale=4, num_layer=52, adapt_official_weights=False):
        super(DUF, self).__init__()
        self.scale = scale
        if adapt_official_weights:
            eps = 1e-3
            momentum = 1e-3
        else:  # pytorch default values
            eps = 1e-05
            momentum = 0.1

        self.conv3d1 = nn.Conv3d(3, 64, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)
        self.dynamic_filter = DynamicUpsamplingFilter((5, 5))

        if num_layer == 16:
            num_block = 3
            num_grow_ch = 32
        elif num_layer == 28:
            num_block = 9
            num_grow_ch = 16
        elif num_layer == 52:
            num_block = 21
            num_grow_ch = 16
        else:
            raise ValueError(f'Only supported (16, 28, 52) layers, but got {num_layer}.')

        self.dense_block1 = DenseBlocks(
            num_block=num_block, num_feat=64, num_grow_ch=num_grow_ch,
            adapt_official_weights=adapt_official_weights)  # T = 7
        self.dense_block2 = DenseBlocksTemporalReduce(
            64 + num_grow_ch * num_block, num_grow_ch, adapt_official_weights=adapt_official_weights)  # T = 1
        channels = 64 + num_grow_ch * num_block + num_grow_ch * 3
        self.bn3d2 = nn.BatchNorm3d(channels, eps=eps, momentum=momentum)
        self.conv3d2 = nn.Conv3d(channels, 256, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)

        self.conv3d_r1 = nn.Conv3d(256, 256, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
        self.conv3d_r2 = nn.Conv3d(256, 3 * (scale**2), (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)

        self.conv3d_f1 = nn.Conv3d(256, 512, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
        self.conv3d_f2 = nn.Conv3d(
            512, 1 * 5 * 5 * (scale**2), (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)

    def forward(self, x):
        """
        Args:
            x (Tensor): Input with shape (b, 7, c, h, w)

        Returns:
            Tensor: Output with shape (b, c, h * scale, w * scale)
        """
        num_batches, num_imgs, _, h, w = x.size()

        x = x.permute(0, 2, 1, 3, 4)  # (b, c, 7, h, w) for Conv3D
        x_center = x[:, :, num_imgs // 2, :, :]

        x = self.conv3d1(x)
        x = self.dense_block1(x)
        x = self.dense_block2(x)
        x = F.relu(self.bn3d2(x), inplace=True)
        x = F.relu(self.conv3d2(x), inplace=True)

        # residual image
        res = self.conv3d_r2(F.relu(self.conv3d_r1(x), inplace=True))

        # filter
        filter_ = self.conv3d_f2(F.relu(self.conv3d_f1(x), inplace=True))
        filter_ = F.softmax(filter_.view(num_batches, 25, self.scale**2, h, w), dim=1)

        # dynamic filter
        out = self.dynamic_filter(x_center, filter_)
        out += res.squeeze_(2)
        out = F.pixel_shuffle(out, self.scale)

        return out
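A shape check for DUF (a sketch; the network consumes a fixed window of 7 frames, reduces the temporal dimension to 1 inside dense_block2, and upsamples the center frame):

import torch
from basicsr.archs.duf_arch import DUF

model = DUF(scale=4, num_layer=16).eval()  # the small DUF-S variant for a quick test
clip = torch.rand(1, 7, 3, 32, 32)         # (b, t=7, c, h, w)
with torch.no_grad():
    sr = model(clip)
print(sr.shape)  # torch.Size([1, 3, 128, 128])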
StableSR/basicsr/archs/ecbsr_arch.py
ADDED
@@ -0,0 +1,275 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from basicsr.utils.registry import ARCH_REGISTRY


class SeqConv3x3(nn.Module):
    """The re-parameterizable block used in the ECBSR architecture.

    ``Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices``

    Reference: https://github.com/xindongzhang/ECBSR

    Args:
        seq_type (str): Sequence type, option: conv1x1-conv3x3 | conv1x1-sobelx | conv1x1-sobely | conv1x1-laplacian.
        in_channels (int): Channel number of input.
        out_channels (int): Channel number of output.
        depth_multiplier (int): Width multiplier in the expand-and-squeeze conv. Default: 1.
    """

    def __init__(self, seq_type, in_channels, out_channels, depth_multiplier=1):
        super(SeqConv3x3, self).__init__()
        self.seq_type = seq_type
        self.in_channels = in_channels
        self.out_channels = out_channels

        if self.seq_type == 'conv1x1-conv3x3':
            self.mid_planes = int(out_channels * depth_multiplier)
            conv0 = torch.nn.Conv2d(self.in_channels, self.mid_planes, kernel_size=1, padding=0)
            self.k0 = conv0.weight
            self.b0 = conv0.bias

            conv1 = torch.nn.Conv2d(self.mid_planes, self.out_channels, kernel_size=3)
            self.k1 = conv1.weight
            self.b1 = conv1.bias

        elif self.seq_type == 'conv1x1-sobelx':
            conv0 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, padding=0)
            self.k0 = conv0.weight
            self.b0 = conv0.bias

            # init scale and bias
            scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3
            self.scale = nn.Parameter(scale)
            bias = torch.randn(self.out_channels) * 1e-3
            bias = torch.reshape(bias, (self.out_channels, ))
            self.bias = nn.Parameter(bias)
            # init mask
            self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32)
            for i in range(self.out_channels):
                self.mask[i, 0, 0, 0] = 1.0
                self.mask[i, 0, 1, 0] = 2.0
                self.mask[i, 0, 2, 0] = 1.0
                self.mask[i, 0, 0, 2] = -1.0
                self.mask[i, 0, 1, 2] = -2.0
                self.mask[i, 0, 2, 2] = -1.0
            self.mask = nn.Parameter(data=self.mask, requires_grad=False)

        elif self.seq_type == 'conv1x1-sobely':
            conv0 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, padding=0)
            self.k0 = conv0.weight
            self.b0 = conv0.bias

            # init scale and bias
            scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3
            self.scale = nn.Parameter(torch.FloatTensor(scale))
            bias = torch.randn(self.out_channels) * 1e-3
            bias = torch.reshape(bias, (self.out_channels, ))
            self.bias = nn.Parameter(torch.FloatTensor(bias))
            # init mask
            self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32)
            for i in range(self.out_channels):
                self.mask[i, 0, 0, 0] = 1.0
                self.mask[i, 0, 0, 1] = 2.0
                self.mask[i, 0, 0, 2] = 1.0
                self.mask[i, 0, 2, 0] = -1.0
                self.mask[i, 0, 2, 1] = -2.0
                self.mask[i, 0, 2, 2] = -1.0
            self.mask = nn.Parameter(data=self.mask, requires_grad=False)

        elif self.seq_type == 'conv1x1-laplacian':
            conv0 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, padding=0)
            self.k0 = conv0.weight
            self.b0 = conv0.bias

            # init scale and bias
            scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3
            self.scale = nn.Parameter(torch.FloatTensor(scale))
            bias = torch.randn(self.out_channels) * 1e-3
            bias = torch.reshape(bias, (self.out_channels, ))
            self.bias = nn.Parameter(torch.FloatTensor(bias))
            # init mask
            self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32)
            for i in range(self.out_channels):
                self.mask[i, 0, 0, 1] = 1.0
                self.mask[i, 0, 1, 0] = 1.0
                self.mask[i, 0, 1, 2] = 1.0
                self.mask[i, 0, 2, 1] = 1.0
                self.mask[i, 0, 1, 1] = -4.0
            self.mask = nn.Parameter(data=self.mask, requires_grad=False)
        else:
            raise ValueError('The type of seqconv is not supported!')

    def forward(self, x):
        if self.seq_type == 'conv1x1-conv3x3':
            # conv-1x1
            y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
            # explicitly padding with bias
            y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0)
            b0_pad = self.b0.view(1, -1, 1, 1)
            y0[:, :, 0:1, :] = b0_pad
            y0[:, :, -1:, :] = b0_pad
            y0[:, :, :, 0:1] = b0_pad
            y0[:, :, :, -1:] = b0_pad
            # conv-3x3
            y1 = F.conv2d(input=y0, weight=self.k1, bias=self.b1, stride=1)
        else:
            y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
            # explicitly padding with bias
            y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0)
            b0_pad = self.b0.view(1, -1, 1, 1)
            y0[:, :, 0:1, :] = b0_pad
            y0[:, :, -1:, :] = b0_pad
            y0[:, :, :, 0:1] = b0_pad
            y0[:, :, :, -1:] = b0_pad
            # conv-3x3
            y1 = F.conv2d(input=y0, weight=self.scale * self.mask, bias=self.bias, stride=1, groups=self.out_channels)
        return y1

    def rep_params(self):
        device = self.k0.get_device()
        if device < 0:
            device = None

        if self.seq_type == 'conv1x1-conv3x3':
            # re-param conv kernel
            rep_weight = F.conv2d(input=self.k1, weight=self.k0.permute(1, 0, 2, 3))
            # re-param conv bias
            rep_bias = torch.ones(1, self.mid_planes, 3, 3, device=device) * self.b0.view(1, -1, 1, 1)
            rep_bias = F.conv2d(input=rep_bias, weight=self.k1).view(-1, ) + self.b1
        else:
            tmp = self.scale * self.mask
            k1 = torch.zeros((self.out_channels, self.out_channels, 3, 3), device=device)
            for i in range(self.out_channels):
                k1[i, i, :, :] = tmp[i, 0, :, :]
            b1 = self.bias
            # re-param conv kernel
            rep_weight = F.conv2d(input=k1, weight=self.k0.permute(1, 0, 2, 3))
            # re-param conv bias
            rep_bias = torch.ones(1, self.out_channels, 3, 3, device=device) * self.b0.view(1, -1, 1, 1)
            rep_bias = F.conv2d(input=rep_bias, weight=k1).view(-1, ) + b1
        return rep_weight, rep_bias


class ECB(nn.Module):
    """The ECB block used in the ECBSR architecture.

    Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices
    Ref git repo: https://github.com/xindongzhang/ECBSR

    Args:
        in_channels (int): Channel number of input.
        out_channels (int): Channel number of output.
        depth_multiplier (int): Width multiplier in the expand-and-squeeze conv. Default: 1.
        act_type (str): Activation type. Option: prelu | relu | rrelu | softplus | linear. Default: prelu.
        with_idt (bool): Whether to use identity connection. Default: False.
    """

    def __init__(self, in_channels, out_channels, depth_multiplier, act_type='prelu', with_idt=False):
        super(ECB, self).__init__()

        self.depth_multiplier = depth_multiplier
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.act_type = act_type

        if with_idt and (self.in_channels == self.out_channels):
            self.with_idt = True
        else:
            self.with_idt = False

        self.conv3x3 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=3, padding=1)
        self.conv1x1_3x3 = SeqConv3x3('conv1x1-conv3x3', self.in_channels, self.out_channels, self.depth_multiplier)
        self.conv1x1_sbx = SeqConv3x3('conv1x1-sobelx', self.in_channels, self.out_channels)
        self.conv1x1_sby = SeqConv3x3('conv1x1-sobely', self.in_channels, self.out_channels)
        self.conv1x1_lpl = SeqConv3x3('conv1x1-laplacian', self.in_channels, self.out_channels)

        if self.act_type == 'prelu':
            self.act = nn.PReLU(num_parameters=self.out_channels)
        elif self.act_type == 'relu':
            self.act = nn.ReLU(inplace=True)
        elif self.act_type == 'rrelu':
            self.act = nn.RReLU(lower=-0.05, upper=0.05)
        elif self.act_type == 'softplus':
            self.act = nn.Softplus()
        elif self.act_type == 'linear':
            pass
        else:
            raise ValueError('The type of activation is not supported!')

    def forward(self, x):
        if self.training:
            y = self.conv3x3(x) + self.conv1x1_3x3(x) + self.conv1x1_sbx(x) + self.conv1x1_sby(x) + self.conv1x1_lpl(x)
            if self.with_idt:
                y += x
        else:
            rep_weight, rep_bias = self.rep_params()
            y = F.conv2d(input=x, weight=rep_weight, bias=rep_bias, stride=1, padding=1)
        if self.act_type != 'linear':
            y = self.act(y)
        return y

    def rep_params(self):
        weight0, bias0 = self.conv3x3.weight, self.conv3x3.bias
        weight1, bias1 = self.conv1x1_3x3.rep_params()
        weight2, bias2 = self.conv1x1_sbx.rep_params()
        weight3, bias3 = self.conv1x1_sby.rep_params()
        weight4, bias4 = self.conv1x1_lpl.rep_params()
        rep_weight, rep_bias = (weight0 + weight1 + weight2 + weight3 + weight4), (
            bias0 + bias1 + bias2 + bias3 + bias4)

        if self.with_idt:
            device = rep_weight.get_device()
            if device < 0:
                device = None
            weight_idt = torch.zeros(self.out_channels, self.out_channels, 3, 3, device=device)
            for i in range(self.out_channels):
                weight_idt[i, i, 1, 1] = 1.0
            bias_idt = 0.0
            rep_weight, rep_bias = rep_weight + weight_idt, rep_bias + bias_idt
        return rep_weight, rep_bias


@ARCH_REGISTRY.register()
class ECBSR(nn.Module):
    """ECBSR architecture.

    Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices
    Ref git repo: https://github.com/xindongzhang/ECBSR

    Args:
        num_in_ch (int): Channel number of inputs.
        num_out_ch (int): Channel number of outputs.
        num_block (int): Block number in the trunk network.
        num_channel (int): Channel number.
        with_idt (bool): Whether to use identity in convolution layers.
        act_type (str): Activation type.
        scale (int): Upsampling factor.
    """

    def __init__(self, num_in_ch, num_out_ch, num_block, num_channel, with_idt, act_type, scale):
        super(ECBSR, self).__init__()
        self.num_in_ch = num_in_ch
        self.scale = scale

        backbone = []
        backbone += [ECB(num_in_ch, num_channel, depth_multiplier=2.0, act_type=act_type, with_idt=with_idt)]
        for _ in range(num_block):
            backbone += [ECB(num_channel, num_channel, depth_multiplier=2.0, act_type=act_type, with_idt=with_idt)]
        backbone += [
            ECB(num_channel, num_out_ch * scale * scale, depth_multiplier=2.0, act_type='linear', with_idt=with_idt)
        ]

        self.backbone = nn.Sequential(*backbone)
        self.upsampler = nn.PixelShuffle(scale)

    def forward(self, x):
        if self.num_in_ch > 1:
            shortcut = torch.repeat_interleave(x, self.scale * self.scale, dim=1)
        else:
            shortcut = x  # will repeat the input in the channel dimension (repeat scale * scale times)
        y = self.backbone(x) + shortcut
        y = self.upsampler(y)
        return y
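
A quick equivalence check for the re-parameterization above: in train mode ECB runs its five parallel branches, while in eval mode it collapses them into a single 3x3 convolution via rep_params(). With a linear activation the two paths should agree up to floating-point tolerance (a minimal sketch; the tolerance value is an assumption):

import torch
from basicsr.archs.ecbsr_arch import ECB

ecb = ECB(16, 16, depth_multiplier=2.0, act_type='linear', with_idt=True)
x = torch.rand(1, 16, 24, 24)
with torch.no_grad():
    ecb.train()
    y_multi = ecb(x)  # multi-branch training path
    ecb.eval()
    y_rep = ecb(x)    # single re-parameterized 3x3 conv
print(torch.allclose(y_multi, y_rep, atol=1e-4))  # expected: True
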
StableSR/basicsr/archs/edsr_arch.py
ADDED
@@ -0,0 +1,61 @@
import torch
from torch import nn as nn

from basicsr.archs.arch_util import ResidualBlockNoBN, Upsample, make_layer
from basicsr.utils.registry import ARCH_REGISTRY


@ARCH_REGISTRY.register()
class EDSR(nn.Module):
    """EDSR network structure.

    Paper: Enhanced Deep Residual Networks for Single Image Super-Resolution.
    Ref git repo: https://github.com/thstkdgus35/EDSR-PyTorch

    Args:
        num_in_ch (int): Channel number of inputs.
        num_out_ch (int): Channel number of outputs.
        num_feat (int): Channel number of intermediate features.
            Default: 64.
        num_block (int): Block number in the trunk network. Default: 16.
        upscale (int): Upsampling factor. Support 2^n and 3.
            Default: 4.
        res_scale (float): Used to scale the residual in residual block.
            Default: 1.
        img_range (float): Image range. Default: 255.
        rgb_mean (tuple[float]): Image mean in RGB orders.
            Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset.
    """

    def __init__(self,
                 num_in_ch,
                 num_out_ch,
                 num_feat=64,
                 num_block=16,
                 upscale=4,
                 res_scale=1,
                 img_range=255.,
                 rgb_mean=(0.4488, 0.4371, 0.4040)):
        super(EDSR, self).__init__()

        self.img_range = img_range
        self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)

        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat, res_scale=res_scale, pytorch_init=True)
        self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.upsample = Upsample(upscale, num_feat)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

    def forward(self, x):
        self.mean = self.mean.type_as(x)

        x = (x - self.mean) * self.img_range
        x = self.conv_first(x)
        res = self.conv_after_body(self.body(x))
        res += x

        x = self.conv_last(self.upsample(res))
        x = x / self.img_range + self.mean

        return x
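
For reference, a minimal instantiation of the EDSR network above (a sketch; the shapes follow from the docstring defaults):

import torch
from basicsr.archs.edsr_arch import EDSR

net = EDSR(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16, upscale=4).eval()
lr = torch.rand(1, 3, 48, 48)
with torch.no_grad():
    sr = net(lr)
print(sr.shape)  # expected: torch.Size([1, 3, 192, 192])
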
StableSR/basicsr/archs/edvr_arch.py
ADDED
@@ -0,0 +1,382 @@
import torch
from torch import nn as nn
from torch.nn import functional as F

from basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import DCNv2Pack, ResidualBlockNoBN, make_layer


class PCDAlignment(nn.Module):
    """Alignment module using Pyramid, Cascading and Deformable convolution
    (PCD). It is used in EDVR.

    ``Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks``

    Args:
        num_feat (int): Channel number of middle features. Default: 64.
        deformable_groups (int): Deformable groups. Default: 8.
    """

    def __init__(self, num_feat=64, deformable_groups=8):
        super(PCDAlignment, self).__init__()

        # Pyramid has three levels:
        # L3: level 3, 1/4 spatial size
        # L2: level 2, 1/2 spatial size
        # L1: level 1, original spatial size
        self.offset_conv1 = nn.ModuleDict()
        self.offset_conv2 = nn.ModuleDict()
        self.offset_conv3 = nn.ModuleDict()
        self.dcn_pack = nn.ModuleDict()
        self.feat_conv = nn.ModuleDict()

        # Pyramids
        for i in range(3, 0, -1):
            level = f'l{i}'
            self.offset_conv1[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
            if i == 3:
                self.offset_conv2[level] = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            else:
                self.offset_conv2[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
                self.offset_conv3[level] = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            self.dcn_pack[level] = DCNv2Pack(num_feat, num_feat, 3, padding=1, deformable_groups=deformable_groups)

            if i < 3:
                self.feat_conv[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)

        # Cascading dcn
        self.cas_offset_conv1 = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
        self.cas_offset_conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.cas_dcnpack = DCNv2Pack(num_feat, num_feat, 3, padding=1, deformable_groups=deformable_groups)

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, nbr_feat_l, ref_feat_l):
        """Align neighboring frame features to the reference frame features.

        Args:
            nbr_feat_l (list[Tensor]): Neighboring feature list. It
                contains three pyramid levels (L1, L2, L3),
                each with shape (b, c, h, w).
            ref_feat_l (list[Tensor]): Reference feature list. It
                contains three pyramid levels (L1, L2, L3),
                each with shape (b, c, h, w).

        Returns:
            Tensor: Aligned features.
        """
        # Pyramids
        upsampled_offset, upsampled_feat = None, None
        for i in range(3, 0, -1):
            level = f'l{i}'
            offset = torch.cat([nbr_feat_l[i - 1], ref_feat_l[i - 1]], dim=1)
            offset = self.lrelu(self.offset_conv1[level](offset))
            if i == 3:
                offset = self.lrelu(self.offset_conv2[level](offset))
            else:
                offset = self.lrelu(self.offset_conv2[level](torch.cat([offset, upsampled_offset], dim=1)))
                offset = self.lrelu(self.offset_conv3[level](offset))

            feat = self.dcn_pack[level](nbr_feat_l[i - 1], offset)
            if i < 3:
                feat = self.feat_conv[level](torch.cat([feat, upsampled_feat], dim=1))
            if i > 1:
                feat = self.lrelu(feat)

            if i > 1:  # upsample offset and features
                # x2: when we upsample the offset, we should also enlarge
                # the magnitude.
                upsampled_offset = self.upsample(offset) * 2
                upsampled_feat = self.upsample(feat)

        # Cascading
        offset = torch.cat([feat, ref_feat_l[0]], dim=1)
        offset = self.lrelu(self.cas_offset_conv2(self.lrelu(self.cas_offset_conv1(offset))))
        feat = self.lrelu(self.cas_dcnpack(feat, offset))
        return feat


class TSAFusion(nn.Module):
    """Temporal Spatial Attention (TSA) fusion module.

    Temporal: Calculate the correlation between center frame and
        neighboring frames;
    Spatial: It has 3 pyramid levels, the attention is similar to SFT.
        (SFT: Recovering realistic texture in image super-resolution by deep
        spatial feature transform.)

    Args:
        num_feat (int): Channel number of middle features. Default: 64.
        num_frame (int): Number of frames. Default: 5.
        center_frame_idx (int): The index of center frame. Default: 2.
    """

    def __init__(self, num_feat=64, num_frame=5, center_frame_idx=2):
        super(TSAFusion, self).__init__()
        self.center_frame_idx = center_frame_idx
        # temporal attention (before fusion conv)
        self.temporal_attn1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.temporal_attn2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.feat_fusion = nn.Conv2d(num_frame * num_feat, num_feat, 1, 1)

        # spatial attention (after fusion conv)
        self.max_pool = nn.MaxPool2d(3, stride=2, padding=1)
        self.avg_pool = nn.AvgPool2d(3, stride=2, padding=1)
        self.spatial_attn1 = nn.Conv2d(num_frame * num_feat, num_feat, 1)
        self.spatial_attn2 = nn.Conv2d(num_feat * 2, num_feat, 1)
        self.spatial_attn3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.spatial_attn4 = nn.Conv2d(num_feat, num_feat, 1)
        self.spatial_attn5 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.spatial_attn_l1 = nn.Conv2d(num_feat, num_feat, 1)
        self.spatial_attn_l2 = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
        self.spatial_attn_l3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.spatial_attn_add1 = nn.Conv2d(num_feat, num_feat, 1)
        self.spatial_attn_add2 = nn.Conv2d(num_feat, num_feat, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, aligned_feat):
        """
        Args:
            aligned_feat (Tensor): Aligned features with shape (b, t, c, h, w).

        Returns:
            Tensor: Features after TSA with the shape (b, c, h, w).
        """
        b, t, c, h, w = aligned_feat.size()
        # temporal attention
        embedding_ref = self.temporal_attn1(aligned_feat[:, self.center_frame_idx, :, :, :].clone())
        embedding = self.temporal_attn2(aligned_feat.view(-1, c, h, w))
        embedding = embedding.view(b, t, -1, h, w)  # (b, t, c, h, w)

        corr_l = []  # correlation list
        for i in range(t):
            emb_neighbor = embedding[:, i, :, :, :]
            corr = torch.sum(emb_neighbor * embedding_ref, 1)  # (b, h, w)
            corr_l.append(corr.unsqueeze(1))  # (b, 1, h, w)
        corr_prob = torch.sigmoid(torch.cat(corr_l, dim=1))  # (b, t, h, w)
        corr_prob = corr_prob.unsqueeze(2).expand(b, t, c, h, w)
        corr_prob = corr_prob.contiguous().view(b, -1, h, w)  # (b, t*c, h, w)
        aligned_feat = aligned_feat.view(b, -1, h, w) * corr_prob

        # fusion
        feat = self.lrelu(self.feat_fusion(aligned_feat))

        # spatial attention
        attn = self.lrelu(self.spatial_attn1(aligned_feat))
        attn_max = self.max_pool(attn)
        attn_avg = self.avg_pool(attn)
        attn = self.lrelu(self.spatial_attn2(torch.cat([attn_max, attn_avg], dim=1)))
        # pyramid levels
        attn_level = self.lrelu(self.spatial_attn_l1(attn))
        attn_max = self.max_pool(attn_level)
        attn_avg = self.avg_pool(attn_level)
        attn_level = self.lrelu(self.spatial_attn_l2(torch.cat([attn_max, attn_avg], dim=1)))
        attn_level = self.lrelu(self.spatial_attn_l3(attn_level))
        attn_level = self.upsample(attn_level)

        attn = self.lrelu(self.spatial_attn3(attn)) + attn_level
        attn = self.lrelu(self.spatial_attn4(attn))
        attn = self.upsample(attn)
        attn = self.spatial_attn5(attn)
        attn_add = self.spatial_attn_add2(self.lrelu(self.spatial_attn_add1(attn)))
        attn = torch.sigmoid(attn)

        # after initialization, * 2 makes (attn * 2) to be close to 1.
        feat = feat * attn * 2 + attn_add
        return feat


class PredeblurModule(nn.Module):
    """Pre-deblur module.

    Args:
        num_in_ch (int): Channel number of input image. Default: 3.
        num_feat (int): Channel number of intermediate features. Default: 64.
        hr_in (bool): Whether the input has high resolution. Default: False.
    """

    def __init__(self, num_in_ch=3, num_feat=64, hr_in=False):
        super(PredeblurModule, self).__init__()
        self.hr_in = hr_in

        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        if self.hr_in:
            # downsample x4 by stride conv
            self.stride_conv_hr1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
            self.stride_conv_hr2 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)

        # generate feature pyramid
        self.stride_conv_l2 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
        self.stride_conv_l3 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)

        self.resblock_l3 = ResidualBlockNoBN(num_feat=num_feat)
        self.resblock_l2_1 = ResidualBlockNoBN(num_feat=num_feat)
        self.resblock_l2_2 = ResidualBlockNoBN(num_feat=num_feat)
        self.resblock_l1 = nn.ModuleList([ResidualBlockNoBN(num_feat=num_feat) for i in range(5)])

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, x):
        feat_l1 = self.lrelu(self.conv_first(x))
        if self.hr_in:
            feat_l1 = self.lrelu(self.stride_conv_hr1(feat_l1))
            feat_l1 = self.lrelu(self.stride_conv_hr2(feat_l1))

        # generate feature pyramid
        feat_l2 = self.lrelu(self.stride_conv_l2(feat_l1))
        feat_l3 = self.lrelu(self.stride_conv_l3(feat_l2))

        feat_l3 = self.upsample(self.resblock_l3(feat_l3))
        feat_l2 = self.resblock_l2_1(feat_l2) + feat_l3
        feat_l2 = self.upsample(self.resblock_l2_2(feat_l2))

        for i in range(2):
            feat_l1 = self.resblock_l1[i](feat_l1)
        feat_l1 = feat_l1 + feat_l2
        for i in range(2, 5):
            feat_l1 = self.resblock_l1[i](feat_l1)
        return feat_l1


@ARCH_REGISTRY.register()
class EDVR(nn.Module):
    """EDVR network structure for video super-resolution.

    Currently only the x4 upsampling factor is supported.

    ``Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks``

    Args:
        num_in_ch (int): Channel number of input image. Default: 3.
        num_out_ch (int): Channel number of output image. Default: 3.
        num_feat (int): Channel number of intermediate features. Default: 64.
        num_frame (int): Number of input frames. Default: 5.
        deformable_groups (int): Deformable groups. Default: 8.
        num_extract_block (int): Number of blocks for feature extraction.
            Default: 5.
        num_reconstruct_block (int): Number of blocks for reconstruction.
            Default: 10.
        center_frame_idx (int): The index of center frame. Frame counting from
            0. Default: Middle of input frames.
        hr_in (bool): Whether the input has high resolution. Default: False.
        with_predeblur (bool): Whether to use the predeblur module.
            Default: False.
        with_tsa (bool): Whether to use the TSA module. Default: True.
    """

    def __init__(self,
                 num_in_ch=3,
                 num_out_ch=3,
                 num_feat=64,
                 num_frame=5,
                 deformable_groups=8,
                 num_extract_block=5,
                 num_reconstruct_block=10,
                 center_frame_idx=None,
                 hr_in=False,
                 with_predeblur=False,
                 with_tsa=True):
        super(EDVR, self).__init__()
        if center_frame_idx is None:
            self.center_frame_idx = num_frame // 2
        else:
            self.center_frame_idx = center_frame_idx
        self.hr_in = hr_in
        self.with_predeblur = with_predeblur
        self.with_tsa = with_tsa

        # extract features for each frame
        if self.with_predeblur:
            self.predeblur = PredeblurModule(num_feat=num_feat, hr_in=self.hr_in)
            self.conv_1x1 = nn.Conv2d(num_feat, num_feat, 1, 1)
        else:
            self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)

        # extract pyramid features
        self.feature_extraction = make_layer(ResidualBlockNoBN, num_extract_block, num_feat=num_feat)
        self.conv_l2_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
        self.conv_l2_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_l3_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
        self.conv_l3_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)

        # pcd and tsa module
        self.pcd_align = PCDAlignment(num_feat=num_feat, deformable_groups=deformable_groups)
        if self.with_tsa:
            self.fusion = TSAFusion(num_feat=num_feat, num_frame=num_frame, center_frame_idx=self.center_frame_idx)
        else:
            self.fusion = nn.Conv2d(num_frame * num_feat, num_feat, 1, 1)

        # reconstruction
        self.reconstruction = make_layer(ResidualBlockNoBN, num_reconstruct_block, num_feat=num_feat)
        # upsample
        self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
        self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1)
        self.pixel_shuffle = nn.PixelShuffle(2)
        self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, x):
        b, t, c, h, w = x.size()
        if self.hr_in:
            assert h % 16 == 0 and w % 16 == 0, ('The height and width must be multiples of 16.')
        else:
            assert h % 4 == 0 and w % 4 == 0, ('The height and width must be multiples of 4.')

        x_center = x[:, self.center_frame_idx, :, :, :].contiguous()

        # extract features for each frame
        # L1
        if self.with_predeblur:
            feat_l1 = self.conv_1x1(self.predeblur(x.view(-1, c, h, w)))
            if self.hr_in:
                h, w = h // 4, w // 4
        else:
            feat_l1 = self.lrelu(self.conv_first(x.view(-1, c, h, w)))

        feat_l1 = self.feature_extraction(feat_l1)
        # L2
        feat_l2 = self.lrelu(self.conv_l2_1(feat_l1))
        feat_l2 = self.lrelu(self.conv_l2_2(feat_l2))
        # L3
        feat_l3 = self.lrelu(self.conv_l3_1(feat_l2))
        feat_l3 = self.lrelu(self.conv_l3_2(feat_l3))

        feat_l1 = feat_l1.view(b, t, -1, h, w)
        feat_l2 = feat_l2.view(b, t, -1, h // 2, w // 2)
        feat_l3 = feat_l3.view(b, t, -1, h // 4, w // 4)

        # PCD alignment
        ref_feat_l = [  # reference feature list
            feat_l1[:, self.center_frame_idx, :, :, :].clone(), feat_l2[:, self.center_frame_idx, :, :, :].clone(),
            feat_l3[:, self.center_frame_idx, :, :, :].clone()
        ]
        aligned_feat = []
        for i in range(t):
            nbr_feat_l = [  # neighboring feature list
                feat_l1[:, i, :, :, :].clone(), feat_l2[:, i, :, :, :].clone(), feat_l3[:, i, :, :, :].clone()
            ]
            aligned_feat.append(self.pcd_align(nbr_feat_l, ref_feat_l))
        aligned_feat = torch.stack(aligned_feat, dim=1)  # (b, t, c, h, w)

        if not self.with_tsa:
            aligned_feat = aligned_feat.view(b, -1, h, w)
        feat = self.fusion(aligned_feat)

        out = self.reconstruction(feat)
        out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
        out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
        out = self.lrelu(self.conv_hr(out))
        out = self.conv_last(out)
        if self.hr_in:
            base = x_center
        else:
            base = F.interpolate(x_center, scale_factor=4, mode='bilinear', align_corners=False)
        out += base
        return out
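
A shape sketch for the EDVR forward pass above. PCDAlignment relies on DCNv2Pack from arch_util, so this only runs where the compiled deformable-convolution ops it wraps are available; with hr_in=False the input height and width must be multiples of 4:

import torch
from basicsr.archs.edvr_arch import EDVR

net = EDVR(num_in_ch=3, num_out_ch=3, num_feat=64, num_frame=5, with_tsa=True).eval()
clip = torch.rand(1, 5, 3, 64, 64)  # (b, t, c, h, w)
with torch.no_grad():
    out = net(clip)
print(out.shape)  # expected: torch.Size([1, 3, 256, 256]), x4 SR of the center frame
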
StableSR/basicsr/archs/hifacegan_arch.py
ADDED
@@ -0,0 +1,260 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from basicsr.utils.registry import ARCH_REGISTRY
from .hifacegan_util import BaseNetwork, LIPEncoder, SPADEResnetBlock, get_nonspade_norm_layer


class SPADEGenerator(BaseNetwork):
    """Generator with SPADEResBlock"""

    def __init__(self,
                 num_in_ch=3,
                 num_feat=64,
                 use_vae=False,
                 z_dim=256,
                 crop_size=512,
                 norm_g='spectralspadesyncbatch3x3',
                 is_train=True,
                 init_train_phase=3):  # progressive training disabled
        super().__init__()
        self.nf = num_feat
        self.input_nc = num_in_ch
        self.is_train = is_train
        self.train_phase = init_train_phase

        self.scale_ratio = 5  # hardcoded now
        self.sw = crop_size // (2**self.scale_ratio)
        self.sh = self.sw  # 20210519: By default use square image, aspect_ratio = 1.0

        if use_vae:
            # In case of VAE, we will sample from random z vector
            self.fc = nn.Linear(z_dim, 16 * self.nf * self.sw * self.sh)
        else:
            # Otherwise, we make the network deterministic by starting with
            # downsampled segmentation map instead of random z
            self.fc = nn.Conv2d(num_in_ch, 16 * self.nf, 3, padding=1)

        self.head_0 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g)

        self.g_middle_0 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g)
        self.g_middle_1 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g)

        self.ups = nn.ModuleList([
            SPADEResnetBlock(16 * self.nf, 8 * self.nf, norm_g),
            SPADEResnetBlock(8 * self.nf, 4 * self.nf, norm_g),
            SPADEResnetBlock(4 * self.nf, 2 * self.nf, norm_g),
            SPADEResnetBlock(2 * self.nf, 1 * self.nf, norm_g)
        ])

        self.to_rgbs = nn.ModuleList([
            nn.Conv2d(8 * self.nf, 3, 3, padding=1),
            nn.Conv2d(4 * self.nf, 3, 3, padding=1),
            nn.Conv2d(2 * self.nf, 3, 3, padding=1),
            nn.Conv2d(1 * self.nf, 3, 3, padding=1)
        ])

        self.up = nn.Upsample(scale_factor=2)

    def encode(self, input_tensor):
        """
        Encode input_tensor into feature maps, can be overridden in derived classes
        Default: nearest downsampling of 2**5 = 32 times
        """
        h, w = input_tensor.size()[-2:]
        sh, sw = h // 2**self.scale_ratio, w // 2**self.scale_ratio
        x = F.interpolate(input_tensor, size=(sh, sw))
        return self.fc(x)

    def forward(self, x):
        # In original SPADE, seg means a segmentation map, but here we use x instead.
        seg = x

        x = self.encode(x)
        x = self.head_0(x, seg)

        x = self.up(x)
        x = self.g_middle_0(x, seg)
        x = self.g_middle_1(x, seg)

        if self.is_train:
            phase = self.train_phase + 1
        else:
            phase = len(self.to_rgbs)

        for i in range(phase):
            x = self.up(x)
            x = self.ups[i](x, seg)

        x = self.to_rgbs[phase - 1](F.leaky_relu(x, 2e-1))
        x = torch.tanh(x)

        return x

    def mixed_guidance_forward(self, input_x, seg=None, n=0, mode='progressive'):
        """
        A helper function for subspace visualization. Input and seg are different images.
        For the first n levels (including encoder) we use input, for the rest we use seg.

        If mode = 'progressive', the output looks like: AAABBB
        If mode = 'one_plug', the output looks like: AAABAA
        If mode = 'one_ablate', the output looks like: BBBABB
        """

        if seg is None:
            return self.forward(input_x)

        if self.is_train:
            phase = self.train_phase + 1
        else:
            phase = len(self.to_rgbs)

        if mode == 'progressive':
            n = max(min(n, 4 + phase), 0)
            guide_list = [input_x] * n + [seg] * (4 + phase - n)
        elif mode == 'one_plug':
            n = max(min(n, 4 + phase - 1), 0)
            guide_list = [seg] * (4 + phase)
            guide_list[n] = input_x
        elif mode == 'one_ablate':
            if n > 3 + phase:
                return self.forward(input_x)
            guide_list = [input_x] * (4 + phase)
            guide_list[n] = seg

        x = self.encode(guide_list[0])
        x = self.head_0(x, guide_list[1])

        x = self.up(x)
        x = self.g_middle_0(x, guide_list[2])
        x = self.g_middle_1(x, guide_list[3])

        for i in range(phase):
            x = self.up(x)
            x = self.ups[i](x, guide_list[4 + i])

        x = self.to_rgbs[phase - 1](F.leaky_relu(x, 2e-1))
        x = torch.tanh(x)

        return x


@ARCH_REGISTRY.register()
class HiFaceGAN(SPADEGenerator):
    """
    HiFaceGAN: SPADEGenerator with a learnable feature encoder
    Current encoder design: LIPEncoder
    """

    def __init__(self,
                 num_in_ch=3,
                 num_feat=64,
                 use_vae=False,
                 z_dim=256,
                 crop_size=512,
                 norm_g='spectralspadesyncbatch3x3',
                 is_train=True,
                 init_train_phase=3):
        super().__init__(num_in_ch, num_feat, use_vae, z_dim, crop_size, norm_g, is_train, init_train_phase)
        self.lip_encoder = LIPEncoder(num_in_ch, num_feat, self.sw, self.sh, self.scale_ratio)

    def encode(self, input_tensor):
        return self.lip_encoder(input_tensor)


@ARCH_REGISTRY.register()
class HiFaceGANDiscriminator(BaseNetwork):
    """
    Inspired by pix2pixHD multiscale discriminator.

    Args:
        num_in_ch (int): Channel number of inputs. Default: 3.
        num_out_ch (int): Channel number of outputs. Default: 3.
        conditional_d (bool): Whether to use a conditional discriminator.
            Default: True.
        num_d (int): Number of Multiscale discriminators. Default: 3.
        n_layers_d (int): Number of downsample layers in each D. Default: 4.
        num_feat (int): Channel number of base intermediate features.
            Default: 64.
        norm_d (str): String to determine normalization layers in D.
            Choices: [spectral][instance/batch/syncbatch]
            Default: 'spectralinstance'.
        keep_features (bool): Keep intermediate features for matching loss, etc.
            Default: True.
    """

    def __init__(self,
                 num_in_ch=3,
                 num_out_ch=3,
                 conditional_d=True,
                 num_d=2,
                 n_layers_d=4,
                 num_feat=64,
                 norm_d='spectralinstance',
                 keep_features=True):
        super().__init__()
        self.num_d = num_d

        input_nc = num_in_ch
        if conditional_d:
            input_nc += num_out_ch

        for i in range(num_d):
            subnet_d = NLayerDiscriminator(input_nc, n_layers_d, num_feat, norm_d, keep_features)
            self.add_module(f'discriminator_{i}', subnet_d)

    def downsample(self, x):
        return F.avg_pool2d(x, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False)

    # Returns list of lists of discriminator outputs.
    # The final result is of size opt.num_d x opt.n_layers_D
    def forward(self, x):
        result = []
        for _, _net_d in self.named_children():
            out = _net_d(x)
            result.append(out)
            x = self.downsample(x)

        return result


class NLayerDiscriminator(BaseNetwork):
    """Defines the PatchGAN discriminator with the specified arguments."""

    def __init__(self, input_nc, n_layers_d, num_feat, norm_d, keep_features):
        super().__init__()
        kw = 4
        padw = int(np.ceil((kw - 1.0) / 2))
        nf = num_feat
        self.keep_features = keep_features

        norm_layer = get_nonspade_norm_layer(norm_d)
        sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, False)]]

        for n in range(1, n_layers_d):
            nf_prev = nf
            nf = min(nf * 2, 512)
            stride = 1 if n == n_layers_d - 1 else 2
            sequence += [[
                norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=stride, padding=padw)),
                nn.LeakyReLU(0.2, False)
            ]]

        sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]

        # We divide the layers into groups to extract intermediate layer outputs
        for n in range(len(sequence)):
            self.add_module('model' + str(n), nn.Sequential(*sequence[n]))

    def forward(self, x):
        results = [x]
        for submodel in self.children():
            intermediate_output = submodel(results[-1])
            results.append(intermediate_output)

        if self.keep_features:
            return results[1:]
        else:
            return results[-1]
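
A minimal usage sketch for the HiFaceGAN generator above: with is_train=False all four up-stages run, so a crop_size x crop_size input is encoded down by 2**5 and decoded back to full resolution (the default norm_g falls back to instance norm under single-GPU mode, as noted in hifacegan_util):

import torch
from basicsr.archs.hifacegan_arch import HiFaceGAN

net = HiFaceGAN(num_in_ch=3, num_feat=64, crop_size=512, is_train=False).eval()
face = torch.rand(1, 3, 512, 512)
with torch.no_grad():
    restored = net(face)
print(restored.shape)  # expected: torch.Size([1, 3, 512, 512]), tanh output in [-1, 1]
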
StableSR/basicsr/archs/hifacegan_util.py
ADDED
@@ -0,0 +1,255 @@
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
# Warning: spectral norm could be buggy
# under eval mode and multi-GPU inference
# A workaround is sticking to single-GPU inference and train mode
from torch.nn.utils import spectral_norm


class SPADE(nn.Module):

    def __init__(self, config_text, norm_nc, label_nc):
        super().__init__()

        assert config_text.startswith('spade')
        parsed = re.search('spade(\\D+)(\\d)x\\d', config_text)
        param_free_norm_type = str(parsed.group(1))
        ks = int(parsed.group(2))

        if param_free_norm_type == 'instance':
            self.param_free_norm = nn.InstanceNorm2d(norm_nc)
        elif param_free_norm_type == 'syncbatch':
            print('SyncBatchNorm is currently not supported under single-GPU mode, switch to "instance" instead')
            self.param_free_norm = nn.InstanceNorm2d(norm_nc)
        elif param_free_norm_type == 'batch':
            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
        else:
            raise ValueError(f'{param_free_norm_type} is not a recognized param-free norm type in SPADE')

        # The dimension of the intermediate embedding space. Yes, hardcoded.
        nhidden = 128 if norm_nc > 128 else norm_nc

        pw = ks // 2
        self.mlp_shared = nn.Sequential(nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw), nn.ReLU())
        self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw, bias=False)
        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw, bias=False)

    def forward(self, x, segmap):

        # Part 1. generate parameter-free normalized activations
        normalized = self.param_free_norm(x)

        # Part 2. produce scaling and bias conditioned on semantic map
        segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
        actv = self.mlp_shared(segmap)
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)

        # apply scale and bias
        out = normalized * gamma + beta

        return out


class SPADEResnetBlock(nn.Module):
    """
    ResNet block that uses SPADE. It differs from the ResNet block of pix2pixHD in that
    it takes in the segmentation map as input, learns the skip connection if necessary,
    and applies normalization first and then convolution.
    This architecture seemed like a standard architecture for unconditional or
    class-conditional GAN architecture using residual block.
    The code was inspired from https://github.com/LMescheder/GAN_stability.
    """

    def __init__(self, fin, fout, norm_g='spectralspadesyncbatch3x3', semantic_nc=3):
        super().__init__()
        # Attributes
        self.learned_shortcut = (fin != fout)
        fmiddle = min(fin, fout)

        # create conv layers
        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)

        # apply spectral norm if specified
        if 'spectral' in norm_g:
            self.conv_0 = spectral_norm(self.conv_0)
            self.conv_1 = spectral_norm(self.conv_1)
            if self.learned_shortcut:
                self.conv_s = spectral_norm(self.conv_s)

        # define normalization layers
        spade_config_str = norm_g.replace('spectral', '')
        self.norm_0 = SPADE(spade_config_str, fin, semantic_nc)
        self.norm_1 = SPADE(spade_config_str, fmiddle, semantic_nc)
        if self.learned_shortcut:
            self.norm_s = SPADE(spade_config_str, fin, semantic_nc)

    # note the resnet block with SPADE also takes in |seg|,
    # the semantic segmentation map as input
    def forward(self, x, seg):
        x_s = self.shortcut(x, seg)
        dx = self.conv_0(self.act(self.norm_0(x, seg)))
        dx = self.conv_1(self.act(self.norm_1(dx, seg)))
        out = x_s + dx
        return out

    def shortcut(self, x, seg):
        if self.learned_shortcut:
            x_s = self.conv_s(self.norm_s(x, seg))
        else:
            x_s = x
        return x_s

    def act(self, x):
        return F.leaky_relu(x, 2e-1)


class BaseNetwork(nn.Module):
    """ A basis for hifacegan archs with custom initialization """

    def init_weights(self, init_type='normal', gain=0.02):

        def init_func(m):
            classname = m.__class__.__name__
            if classname.find('BatchNorm2d') != -1:
                if hasattr(m, 'weight') and m.weight is not None:
                    init.normal_(m.weight.data, 1.0, gain)
                if hasattr(m, 'bias') and m.bias is not None:
                    init.constant_(m.bias.data, 0.0)
            elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
                if init_type == 'normal':
                    init.normal_(m.weight.data, 0.0, gain)
                elif init_type == 'xavier':
                    init.xavier_normal_(m.weight.data, gain=gain)
                elif init_type == 'xavier_uniform':
                    init.xavier_uniform_(m.weight.data, gain=1.0)
                elif init_type == 'kaiming':
                    init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
                elif init_type == 'orthogonal':
                    init.orthogonal_(m.weight.data, gain=gain)
                elif init_type == 'none':  # uses pytorch's default init method
                    m.reset_parameters()
                else:
                    raise NotImplementedError(f'initialization method [{init_type}] is not implemented')
                if hasattr(m, 'bias') and m.bias is not None:
                    init.constant_(m.bias.data, 0.0)

        self.apply(init_func)

        # propagate to children
        for m in self.children():
            if hasattr(m, 'init_weights'):
                m.init_weights(init_type, gain)

    def forward(self, x):
        pass


def lip2d(x, logit, kernel=3, stride=2, padding=1):
    weight = logit.exp()
    return F.avg_pool2d(x * weight, kernel, stride, padding) / F.avg_pool2d(weight, kernel, stride, padding)


class SoftGate(nn.Module):
    COEFF = 12.0

    def forward(self, x):
        return torch.sigmoid(x).mul(self.COEFF)


class SimplifiedLIP(nn.Module):

    def __init__(self, channels):
        super(SimplifiedLIP, self).__init__()
        self.logit = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, bias=False), nn.InstanceNorm2d(channels, affine=True),
            SoftGate())

    def init_layer(self):
        self.logit[0].weight.data.fill_(0.0)

    def forward(self, x):
        frac = lip2d(x, self.logit(x))
        return frac


class LIPEncoder(BaseNetwork):
    """Local Importance-based Pooling (Ziteng Gao et al., ICCV 2019)"""

    def __init__(self, input_nc, ngf, sw, sh, n_2xdown, norm_layer=nn.InstanceNorm2d):
        super().__init__()
        self.sw = sw
        self.sh = sh
        self.max_ratio = 16
        # 20200310: Several Convolution (stride 1) + LIP blocks, 4 fold
        kw = 3
        pw = (kw - 1) // 2

        model = [
            nn.Conv2d(input_nc, ngf, kw, stride=1, padding=pw, bias=False),
            norm_layer(ngf),
            nn.ReLU(),
        ]
        cur_ratio = 1
        for i in range(n_2xdown):
            next_ratio = min(cur_ratio * 2, self.max_ratio)
            model += [
                SimplifiedLIP(ngf * cur_ratio),
                nn.Conv2d(ngf * cur_ratio, ngf * next_ratio, kw, stride=1, padding=pw),
                norm_layer(ngf * next_ratio),
            ]
            cur_ratio = next_ratio
            if i < n_2xdown - 1:
                model += [nn.ReLU(inplace=True)]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)


def get_nonspade_norm_layer(norm_type='instance'):
    # helper function to get # output channels of the previous layer
    def get_out_channel(layer):
        if hasattr(layer, 'out_channels'):
            return getattr(layer, 'out_channels')
        return layer.weight.size(0)

    # this function will be returned
    def add_norm_layer(layer):
        nonlocal norm_type
        if norm_type.startswith('spectral'):
            layer = spectral_norm(layer)
            subnorm_type = norm_type[len('spectral'):]
        else:
            # guard the non-spectral path; otherwise subnorm_type is unbound below
            subnorm_type = norm_type

        if subnorm_type == 'none' or len(subnorm_type) == 0:
            return layer

        # remove bias in the previous layer, which is meaningless
        # since it has no effect after normalization
        if getattr(layer, 'bias', None) is not None:
            delattr(layer, 'bias')
            layer.register_parameter('bias', None)

        if subnorm_type == 'batch':
            norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == 'sync_batch':
            print('SyncBatchNorm is currently not supported under single-GPU mode, switch to "instance" instead')
            # norm_layer = SynchronizedBatchNorm2d(
            #     get_out_channel(layer), affine=True)
            norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
        elif subnorm_type == 'instance':
            norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
        else:
            raise ValueError(f'normalization layer {subnorm_type} is not recognized')

        return nn.Sequential(layer, norm_layer)

    print('This is a legacy from nvlabs/SPADE, and will be removed in future versions.')
    return add_norm_layer
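
For reference, how get_nonspade_norm_layer above is meant to be used: it returns a wrapper that applies spectral norm when requested, strips the now-redundant bias, and chains the base layer with the parsed normalization (a minimal sketch):

import torch.nn as nn
from basicsr.archs.hifacegan_util import get_nonspade_norm_layer

norm_layer = get_nonspade_norm_layer('spectralinstance')
block = norm_layer(nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1))
# block is nn.Sequential(spectral-normed Conv2d with bias removed, InstanceNorm2d(128))
print(block)
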
StableSR/basicsr/archs/inception.py
ADDED
@@ -0,0 +1,307 @@
+# Modified from https://github.com/mseitzer/pytorch-fid/blob/master/pytorch_fid/inception.py # noqa: E501
+# For FID metric
+
+import os
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.model_zoo import load_url
+from torchvision import models
+
+# Inception weights ported to Pytorch from
+# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
+FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'  # noqa: E501
+LOCAL_FID_WEIGHTS = 'experiments/pretrained_models/pt_inception-2015-12-05-6726825d.pth'  # noqa: E501
+
+
+class InceptionV3(nn.Module):
+    """Pretrained InceptionV3 network returning feature maps"""
+
+    # Index of default block of inception to return,
+    # corresponds to output of final average pooling
+    DEFAULT_BLOCK_INDEX = 3
+
+    # Maps feature dimensionality to their output blocks indices
+    BLOCK_INDEX_BY_DIM = {
+        64: 0,  # First max pooling features
+        192: 1,  # Second max pooling features
+        768: 2,  # Pre-aux classifier features
+        2048: 3  # Final average pooling features
+    }
+
+    def __init__(self,
+                 output_blocks=(DEFAULT_BLOCK_INDEX, ),
+                 resize_input=True,
+                 normalize_input=True,
+                 requires_grad=False,
+                 use_fid_inception=True):
+        """Build pretrained InceptionV3.
+
+        Args:
+            output_blocks (list[int]): Indices of blocks to return features of.
+                Possible values are:
+                - 0: corresponds to output of first max pooling
+                - 1: corresponds to output of second max pooling
+                - 2: corresponds to output which is fed to aux classifier
+                - 3: corresponds to output of final average pooling
+            resize_input (bool): If true, bilinearly resizes input to width and
+                height 299 before feeding input to model. As the network
+                without fully connected layers is fully convolutional, it
+                should be able to handle inputs of arbitrary size, so resizing
+                might not be strictly needed. Default: True.
+            normalize_input (bool): If true, scales the input from range (0, 1)
+                to the range the pretrained Inception network expects,
+                namely (-1, 1). Default: True.
+            requires_grad (bool): If true, parameters of the model require
+                gradients. Possibly useful for finetuning the network.
+                Default: False.
+            use_fid_inception (bool): If true, uses the pretrained Inception
+                model used in Tensorflow's FID implementation.
+                If false, uses the pretrained Inception model available in
+                torchvision. The FID Inception model has different weights
+                and a slightly different structure from torchvision's
+                Inception model. If you want to compute FID scores, you are
+                strongly advised to set this parameter to true to get
+                comparable results. Default: True.
+        """
+        super(InceptionV3, self).__init__()
+
+        self.resize_input = resize_input
+        self.normalize_input = normalize_input
+        self.output_blocks = sorted(output_blocks)
+        self.last_needed_block = max(output_blocks)
+
+        assert self.last_needed_block <= 3, ('Last possible output block index is 3')
+
+        self.blocks = nn.ModuleList()
+
+        if use_fid_inception:
+            inception = fid_inception_v3()
+        else:
+            try:
+                inception = models.inception_v3(pretrained=True, init_weights=False)
+            except TypeError:
+                # pytorch < 1.5 does not have init_weights for inception_v3
+                inception = models.inception_v3(pretrained=True)
+
+        # Block 0: input to maxpool1
+        block0 = [
+            inception.Conv2d_1a_3x3, inception.Conv2d_2a_3x3, inception.Conv2d_2b_3x3,
+            nn.MaxPool2d(kernel_size=3, stride=2)
+        ]
+        self.blocks.append(nn.Sequential(*block0))
+
+        # Block 1: maxpool1 to maxpool2
+        if self.last_needed_block >= 1:
+            block1 = [inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3, nn.MaxPool2d(kernel_size=3, stride=2)]
+            self.blocks.append(nn.Sequential(*block1))
+
+        # Block 2: maxpool2 to aux classifier
+        if self.last_needed_block >= 2:
+            block2 = [
+                inception.Mixed_5b,
+                inception.Mixed_5c,
+                inception.Mixed_5d,
+                inception.Mixed_6a,
+                inception.Mixed_6b,
+                inception.Mixed_6c,
+                inception.Mixed_6d,
+                inception.Mixed_6e,
+            ]
+            self.blocks.append(nn.Sequential(*block2))
+
+        # Block 3: aux classifier to final avgpool
+        if self.last_needed_block >= 3:
+            block3 = [
+                inception.Mixed_7a, inception.Mixed_7b, inception.Mixed_7c,
+                nn.AdaptiveAvgPool2d(output_size=(1, 1))
+            ]
+            self.blocks.append(nn.Sequential(*block3))
+
+        for param in self.parameters():
+            param.requires_grad = requires_grad
+
+    def forward(self, x):
+        """Get Inception feature maps.
+
+        Args:
+            x (Tensor): Input tensor of shape (b, 3, h, w).
+                Values are expected to be in range (-1, 1). You can also input
+                (0, 1) with setting normalize_input = True.
+
+        Returns:
+            list[Tensor]: Corresponding to the selected output block, sorted
+                ascending by index.
+        """
+        output = []
+
+        if self.resize_input:
+            x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)
+
+        if self.normalize_input:
+            x = 2 * x - 1  # Scale from range (0, 1) to range (-1, 1)
+
+        for idx, block in enumerate(self.blocks):
+            x = block(x)
+            if idx in self.output_blocks:
+                output.append(x)
+
+            if idx == self.last_needed_block:
+                break
+
+        return output
+
+
+def fid_inception_v3():
+    """Build pretrained Inception model for FID computation.
+
+    The Inception model for FID computation uses a different set of weights
+    and has a slightly different structure than torchvision's Inception.
+
+    This method first constructs torchvision's Inception and then patches the
+    necessary parts that are different in the FID Inception model.
+    """
+    try:
+        inception = models.inception_v3(num_classes=1008, aux_logits=False, pretrained=False, init_weights=False)
+    except TypeError:
+        # pytorch < 1.5 does not have init_weights for inception_v3
+        inception = models.inception_v3(num_classes=1008, aux_logits=False, pretrained=False)
+
+    inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
+    inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
+    inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
+    inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
+    inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
+    inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
+    inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
+    inception.Mixed_7b = FIDInceptionE_1(1280)
+    inception.Mixed_7c = FIDInceptionE_2(2048)
+
+    if os.path.exists(LOCAL_FID_WEIGHTS):
+        state_dict = torch.load(LOCAL_FID_WEIGHTS, map_location=lambda storage, loc: storage)
+    else:
+        state_dict = load_url(FID_WEIGHTS_URL, progress=True)
+
+    inception.load_state_dict(state_dict)
+    return inception
+
+
+class FIDInceptionA(models.inception.InceptionA):
+    """InceptionA block patched for FID computation"""
+
+    def __init__(self, in_channels, pool_features):
+        super(FIDInceptionA, self).__init__(in_channels, pool_features)
+
+    def forward(self, x):
+        branch1x1 = self.branch1x1(x)
+
+        branch5x5 = self.branch5x5_1(x)
+        branch5x5 = self.branch5x5_2(branch5x5)
+
+        branch3x3dbl = self.branch3x3dbl_1(x)
+        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
+
+        # Patch: Tensorflow's average pool does not use the padded zero's in
+        # its average calculation
+        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
+        branch_pool = self.branch_pool(branch_pool)
+
+        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
+        return torch.cat(outputs, 1)
+
+
+class FIDInceptionC(models.inception.InceptionC):
+    """InceptionC block patched for FID computation"""
+
+    def __init__(self, in_channels, channels_7x7):
+        super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
+
+    def forward(self, x):
+        branch1x1 = self.branch1x1(x)
+
+        branch7x7 = self.branch7x7_1(x)
+        branch7x7 = self.branch7x7_2(branch7x7)
+        branch7x7 = self.branch7x7_3(branch7x7)
+
+        branch7x7dbl = self.branch7x7dbl_1(x)
+        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
+        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
+        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
+        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
+
+        # Patch: Tensorflow's average pool does not use the padded zero's in
+        # its average calculation
+        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
+        branch_pool = self.branch_pool(branch_pool)
+
+        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
+        return torch.cat(outputs, 1)
+
+
+class FIDInceptionE_1(models.inception.InceptionE):
+    """First InceptionE block patched for FID computation"""
+
+    def __init__(self, in_channels):
+        super(FIDInceptionE_1, self).__init__(in_channels)
+
+    def forward(self, x):
+        branch1x1 = self.branch1x1(x)
+
+        branch3x3 = self.branch3x3_1(x)
+        branch3x3 = [
+            self.branch3x3_2a(branch3x3),
+            self.branch3x3_2b(branch3x3),
+        ]
+        branch3x3 = torch.cat(branch3x3, 1)
+
+        branch3x3dbl = self.branch3x3dbl_1(x)
+        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+        branch3x3dbl = [
+            self.branch3x3dbl_3a(branch3x3dbl),
+            self.branch3x3dbl_3b(branch3x3dbl),
+        ]
+        branch3x3dbl = torch.cat(branch3x3dbl, 1)
+
+        # Patch: Tensorflow's average pool does not use the padded zero's in
+        # its average calculation
+        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
+        branch_pool = self.branch_pool(branch_pool)
+
+        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
+        return torch.cat(outputs, 1)
+
+
+class FIDInceptionE_2(models.inception.InceptionE):
+    """Second InceptionE block patched for FID computation"""
+
+    def __init__(self, in_channels):
+        super(FIDInceptionE_2, self).__init__(in_channels)
+
+    def forward(self, x):
+        branch1x1 = self.branch1x1(x)
+
+        branch3x3 = self.branch3x3_1(x)
+        branch3x3 = [
+            self.branch3x3_2a(branch3x3),
+            self.branch3x3_2b(branch3x3),
+        ]
+        branch3x3 = torch.cat(branch3x3, 1)
+
+        branch3x3dbl = self.branch3x3dbl_1(x)
+        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+        branch3x3dbl = [
+            self.branch3x3dbl_3a(branch3x3dbl),
+            self.branch3x3dbl_3b(branch3x3dbl),
+        ]
+        branch3x3dbl = torch.cat(branch3x3dbl, 1)
+
+        # Patch: The FID Inception model uses max pooling instead of average
+        # pooling. This is likely an error in this specific Inception
+        # implementation, as other Inception models use average pooling here
+        # (which matches the description in the paper).
+        branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
+        branch_pool = self.branch_pool(branch_pool)
+
+        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
+        return torch.cat(outputs, 1)
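A minimal sketch of how this wrapper is typically used for FID feature extraction (not part of the commit; the module path is assumed from the layout above, and the FID weights are downloaded on first use unless cached at LOCAL_FID_WEIGHTS):

import torch
from basicsr.archs.inception import InceptionV3  # path assumed

# Block index 3 = final average pooling, i.e. the standard 2048-d FID features.
model = InceptionV3(output_blocks=(3, ), use_fid_inception=True).eval()
imgs = torch.rand(4, 3, 128, 128)  # values in (0, 1); normalize_input rescales to (-1, 1)
with torch.no_grad():
    feat = model(imgs)[0]  # (4, 2048, 1, 1)
feat = feat.squeeze(-1).squeeze(-1)  # (4, 2048), ready for FID mean/covariance statistics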
StableSR/basicsr/archs/rcan_arch.py
ADDED
@@ -0,0 +1,135 @@
+import torch
+from torch import nn as nn
+
+from basicsr.utils.registry import ARCH_REGISTRY
+from .arch_util import Upsample, make_layer
+
+
+class ChannelAttention(nn.Module):
+    """Channel attention used in RCAN.
+
+    Args:
+        num_feat (int): Channel number of intermediate features.
+        squeeze_factor (int): Channel squeeze factor. Default: 16.
+    """
+
+    def __init__(self, num_feat, squeeze_factor=16):
+        super(ChannelAttention, self).__init__()
+        self.attention = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1), nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0),
+            nn.ReLU(inplace=True), nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0), nn.Sigmoid())
+
+    def forward(self, x):
+        y = self.attention(x)
+        return x * y
+
+
+class RCAB(nn.Module):
+    """Residual Channel Attention Block (RCAB) used in RCAN.
+
+    Args:
+        num_feat (int): Channel number of intermediate features.
+        squeeze_factor (int): Channel squeeze factor. Default: 16.
+        res_scale (float): Scale the residual. Default: 1.
+    """
+
+    def __init__(self, num_feat, squeeze_factor=16, res_scale=1):
+        super(RCAB, self).__init__()
+        self.res_scale = res_scale
+
+        self.rcab = nn.Sequential(
+            nn.Conv2d(num_feat, num_feat, 3, 1, 1), nn.ReLU(True), nn.Conv2d(num_feat, num_feat, 3, 1, 1),
+            ChannelAttention(num_feat, squeeze_factor))
+
+    def forward(self, x):
+        res = self.rcab(x) * self.res_scale
+        return res + x
+
+
+class ResidualGroup(nn.Module):
+    """Residual Group of RCAB.
+
+    Args:
+        num_feat (int): Channel number of intermediate features.
+        num_block (int): Block number in the body network.
+        squeeze_factor (int): Channel squeeze factor. Default: 16.
+        res_scale (float): Scale the residual. Default: 1.
+    """
+
+    def __init__(self, num_feat, num_block, squeeze_factor=16, res_scale=1):
+        super(ResidualGroup, self).__init__()
+
+        self.residual_group = make_layer(
+            RCAB, num_block, num_feat=num_feat, squeeze_factor=squeeze_factor, res_scale=res_scale)
+        self.conv = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+
+    def forward(self, x):
+        res = self.conv(self.residual_group(x))
+        return res + x
+
+
+@ARCH_REGISTRY.register()
+class RCAN(nn.Module):
+    """Residual Channel Attention Networks.
+
+    ``Paper: Image Super-Resolution Using Very Deep Residual Channel Attention Networks``
+
+    Reference: https://github.com/yulunzhang/RCAN
+
+    Args:
+        num_in_ch (int): Channel number of inputs.
+        num_out_ch (int): Channel number of outputs.
+        num_feat (int): Channel number of intermediate features.
+            Default: 64.
+        num_group (int): Number of ResidualGroup. Default: 10.
+        num_block (int): Number of RCAB in ResidualGroup. Default: 16.
+        squeeze_factor (int): Channel squeeze factor. Default: 16.
+        upscale (int): Upsampling factor. Support 2^n and 3.
+            Default: 4.
+        res_scale (float): Used to scale the residual in residual block.
+            Default: 1.
+        img_range (float): Image range. Default: 255.
+        rgb_mean (tuple[float]): Image mean in RGB orders.
+            Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset.
+    """
+
+    def __init__(self,
+                 num_in_ch,
+                 num_out_ch,
+                 num_feat=64,
+                 num_group=10,
+                 num_block=16,
+                 squeeze_factor=16,
+                 upscale=4,
+                 res_scale=1,
+                 img_range=255.,
+                 rgb_mean=(0.4488, 0.4371, 0.4040)):
+        super(RCAN, self).__init__()
+
+        self.img_range = img_range
+        self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
+
+        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+        self.body = make_layer(
+            ResidualGroup,
+            num_group,
+            num_feat=num_feat,
+            num_block=num_block,
+            squeeze_factor=squeeze_factor,
+            res_scale=res_scale)
+        self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+        self.upsample = Upsample(upscale, num_feat)
+        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+    def forward(self, x):
+        self.mean = self.mean.type_as(x)
+
+        x = (x - self.mean) * self.img_range
+        x = self.conv_first(x)
+        res = self.conv_after_body(self.body(x))
+        res += x
+
+        x = self.conv_last(self.upsample(res))
+        x = x / self.img_range + self.mean
+
+        return x
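For reference, a shape-level usage sketch (not part of the commit; module path assumed from the layout above):

import torch
from basicsr.archs.rcan_arch import RCAN  # path assumed

# x4 RCAN with the defaults documented above.
net = RCAN(num_in_ch=3, num_out_ch=3, num_feat=64, num_group=10, num_block=16, upscale=4)
sr = net(torch.rand(1, 3, 48, 48))  # -> (1, 3, 192, 192)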
StableSR/basicsr/archs/ridnet_arch.py
ADDED
@@ -0,0 +1,180 @@
+import torch
+import torch.nn as nn
+
+from basicsr.utils.registry import ARCH_REGISTRY
+from .arch_util import ResidualBlockNoBN, make_layer
+
+
+class MeanShift(nn.Conv2d):
+    """ Data normalization with mean and std.
+
+    Args:
+        rgb_range (int): Maximum value of RGB.
+        rgb_mean (list[float]): Mean for RGB channels.
+        rgb_std (list[float]): Std for RGB channels.
+        sign (int): For subtraction, sign is -1, for addition, sign is 1.
+            Default: -1.
+        requires_grad (bool): Whether to update the self.weight and self.bias.
+            Default: True.
+    """
+
+    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1, requires_grad=True):
+        super(MeanShift, self).__init__(3, 3, kernel_size=1)
+        std = torch.Tensor(rgb_std)
+        self.weight.data = torch.eye(3).view(3, 3, 1, 1)
+        self.weight.data.div_(std.view(3, 1, 1, 1))
+        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
+        self.bias.data.div_(std)
+        self.requires_grad = requires_grad
+
+
+class EResidualBlockNoBN(nn.Module):
+    """Enhanced Residual block without BN.
+
+    There are three convolution layers in residual branch.
+    """
+
+    def __init__(self, in_channels, out_channels):
+        super(EResidualBlockNoBN, self).__init__()
+
+        self.body = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(out_channels, out_channels, 1, 1, 0),
+        )
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        out = self.body(x)
+        out = self.relu(out + x)
+        return out
+
+
+class MergeRun(nn.Module):
+    """ Merge-and-run unit.
+
+    This unit contains two branches with different dilated convolutions,
+    followed by a convolution to process the concatenated features.
+
+    Paper: Real Image Denoising with Feature Attention
+    Ref git repo: https://github.com/saeed-anwar/RIDNet
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
+        super(MergeRun, self).__init__()
+
+        self.dilation1 = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding), nn.ReLU(inplace=True),
+            nn.Conv2d(out_channels, out_channels, kernel_size, stride, 2, 2), nn.ReLU(inplace=True))
+        self.dilation2 = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size, stride, 3, 3), nn.ReLU(inplace=True),
+            nn.Conv2d(out_channels, out_channels, kernel_size, stride, 4, 4), nn.ReLU(inplace=True))
+
+        self.aggregation = nn.Sequential(
+            nn.Conv2d(out_channels * 2, out_channels, kernel_size, stride, padding), nn.ReLU(inplace=True))
+
+    def forward(self, x):
+        dilation1 = self.dilation1(x)
+        dilation2 = self.dilation2(x)
+        out = torch.cat([dilation1, dilation2], dim=1)
+        out = self.aggregation(out)
+        out = out + x
+        return out
+
+
+class ChannelAttention(nn.Module):
+    """Channel attention.
+
+    Args:
+        mid_channels (int): Channel number of intermediate features.
+        squeeze_factor (int): Channel squeeze factor. Default: 16.
+    """
+
+    def __init__(self, mid_channels, squeeze_factor=16):
+        super(ChannelAttention, self).__init__()
+        self.attention = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1), nn.Conv2d(mid_channels, mid_channels // squeeze_factor, 1, padding=0),
+            nn.ReLU(inplace=True), nn.Conv2d(mid_channels // squeeze_factor, mid_channels, 1, padding=0), nn.Sigmoid())
+
+    def forward(self, x):
+        y = self.attention(x)
+        return x * y
+
+
+class EAM(nn.Module):
+    """Enhancement attention modules (EAM) in RIDNet.
+
+    This module contains a merge-and-run unit, a residual block,
+    an enhanced residual block and a feature attention unit.
+
+    Attributes:
+        merge: The merge-and-run unit.
+        block1: The residual block.
+        block2: The enhanced residual block.
+        ca: The feature/channel attention unit.
+    """
+
+    def __init__(self, in_channels, mid_channels, out_channels):
+        super(EAM, self).__init__()
+
+        self.merge = MergeRun(in_channels, mid_channels)
+        self.block1 = ResidualBlockNoBN(mid_channels)
+        self.block2 = EResidualBlockNoBN(mid_channels, out_channels)
+        self.ca = ChannelAttention(out_channels)
+        # The residual block in the paper contains a relu after addition.
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        out = self.merge(x)
+        out = self.relu(self.block1(out))
+        out = self.block2(out)
+        out = self.ca(out)
+        return out
+
+
+@ARCH_REGISTRY.register()
+class RIDNet(nn.Module):
+    """RIDNet: Real Image Denoising with Feature Attention.
+
+    Ref git repo: https://github.com/saeed-anwar/RIDNet
+
+    Args:
+        in_channels (int): Channel number of inputs.
+        mid_channels (int): Channel number of EAM modules.
+            Default: 64.
+        out_channels (int): Channel number of outputs.
+        num_block (int): Number of EAM. Default: 4.
+        img_range (float): Image range. Default: 255.
+        rgb_mean (tuple[float]): Image mean in RGB orders.
+            Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 mid_channels,
+                 out_channels,
+                 num_block=4,
+                 img_range=255.,
+                 rgb_mean=(0.4488, 0.4371, 0.4040),
+                 rgb_std=(1.0, 1.0, 1.0)):
+        super(RIDNet, self).__init__()
+
+        self.sub_mean = MeanShift(img_range, rgb_mean, rgb_std)
+        self.add_mean = MeanShift(img_range, rgb_mean, rgb_std, 1)
+
+        self.head = nn.Conv2d(in_channels, mid_channels, 3, 1, 1)
+        self.body = make_layer(
+            EAM, num_block, in_channels=mid_channels, mid_channels=mid_channels, out_channels=mid_channels)
+        self.tail = nn.Conv2d(mid_channels, out_channels, 3, 1, 1)
+
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        res = self.sub_mean(x)
+        res = self.tail(self.body(self.relu(self.head(res))))
+        res = self.add_mean(res)
+
+        out = x + res
+        return out
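A shape-level usage sketch (not part of the commit; module path assumed, and the [0, 255] input scale is an assumption implied by the default img_range=255):

import torch
from basicsr.archs.ridnet_arch import RIDNet  # path assumed

# Denoising keeps the spatial size; the network predicts a residual over x.
net = RIDNet(in_channels=3, mid_channels=64, out_channels=3, num_block=4)
noisy = torch.rand(1, 3, 96, 96) * 255  # input scale assumed from img_range
denoised = net(noisy)  # -> (1, 3, 96, 96)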
StableSR/basicsr/archs/rrdbnet_arch.py
ADDED
@@ -0,0 +1,119 @@
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from basicsr.utils.registry import ARCH_REGISTRY
+from .arch_util import default_init_weights, make_layer, pixel_unshuffle
+
+
+class ResidualDenseBlock(nn.Module):
+    """Residual Dense Block.
+
+    Used in RRDB block in ESRGAN.
+
+    Args:
+        num_feat (int): Channel number of intermediate features.
+        num_grow_ch (int): Channels for each growth.
+    """
+
+    def __init__(self, num_feat=64, num_grow_ch=32):
+        super(ResidualDenseBlock, self).__init__()
+        self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
+        self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
+        self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
+        self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
+        self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
+
+        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+        # initialization
+        default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
+
+    def forward(self, x):
+        x1 = self.lrelu(self.conv1(x))
+        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
+        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
+        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
+        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+        # Empirically, we use 0.2 to scale the residual for better performance
+        return x5 * 0.2 + x
+
+
+class RRDB(nn.Module):
+    """Residual in Residual Dense Block.
+
+    Used in RRDB-Net in ESRGAN.
+
+    Args:
+        num_feat (int): Channel number of intermediate features.
+        num_grow_ch (int): Channels for each growth.
+    """
+
+    def __init__(self, num_feat, num_grow_ch=32):
+        super(RRDB, self).__init__()
+        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
+        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
+        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
+
+    def forward(self, x):
+        out = self.rdb1(x)
+        out = self.rdb2(out)
+        out = self.rdb3(out)
+        # Empirically, we use 0.2 to scale the residual for better performance
+        return out * 0.2 + x
+
+
+@ARCH_REGISTRY.register()
+class RRDBNet(nn.Module):
+    """Networks consisting of Residual in Residual Dense Block, which is used
+    in ESRGAN.
+
+    ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
+
+    We extend ESRGAN for scale x2 and scale x1.
+    Note: This is one option for scale 1, scale 2 in RRDBNet.
+    We first employ the pixel-unshuffle (an inverse operation of pixelshuffle) to reduce the spatial size
+    and enlarge the channel size before feeding inputs into the main ESRGAN architecture.
+
+    Args:
+        num_in_ch (int): Channel number of inputs.
+        num_out_ch (int): Channel number of outputs.
+        num_feat (int): Channel number of intermediate features.
+            Default: 64.
+        num_block (int): Block number in the trunk network. Default: 23.
+        num_grow_ch (int): Channels for each growth. Default: 32.
+    """
+
+    def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
+        super(RRDBNet, self).__init__()
+        self.scale = scale
+        if scale == 2:
+            num_in_ch = num_in_ch * 4
+        elif scale == 1:
+            num_in_ch = num_in_ch * 16
+        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+        self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
+        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+        # upsample
+        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+    def forward(self, x):
+        if self.scale == 2:
+            feat = pixel_unshuffle(x, scale=2)
+        elif self.scale == 1:
+            feat = pixel_unshuffle(x, scale=4)
+        else:
+            feat = x
+        feat = self.conv_first(feat)
+        body_feat = self.conv_body(self.body(feat))
+        feat = feat + body_feat
+        # upsample
+        feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
+        feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
+        out = self.conv_last(self.lrelu(self.conv_hr(feat)))
+        return out
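A shape-level usage sketch (not part of the commit; module path assumed from the layout above):

import torch
from basicsr.archs.rrdbnet_arch import RRDBNet  # path assumed

# x4 ESRGAN generator; the scale=2/scale=1 variants pixel-unshuffle the input first.
net = RRDBNet(num_in_ch=3, num_out_ch=3, scale=4, num_feat=64, num_block=23)
sr = net(torch.rand(1, 3, 32, 32))  # two nearest x2 upsamples -> (1, 3, 128, 128)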
StableSR/basicsr/archs/spynet_arch.py
ADDED
@@ -0,0 +1,96 @@
+import math
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from basicsr.utils.registry import ARCH_REGISTRY
+from .arch_util import flow_warp
+
+
+class BasicModule(nn.Module):
+    """Basic Module for SpyNet.
+    """
+
+    def __init__(self):
+        super(BasicModule, self).__init__()
+
+        self.basic_module = nn.Sequential(
+            nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
+            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
+            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
+            nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
+            nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3))
+
+    def forward(self, tensor_input):
+        return self.basic_module(tensor_input)
+
+
+@ARCH_REGISTRY.register()
+class SpyNet(nn.Module):
+    """SpyNet architecture.
+
+    Args:
+        load_path (str): path for pretrained SpyNet. Default: None.
+    """
+
+    def __init__(self, load_path=None):
+        super(SpyNet, self).__init__()
+        self.basic_module = nn.ModuleList([BasicModule() for _ in range(6)])
+        if load_path:
+            self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params'])
+
+        self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+        self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+    def preprocess(self, tensor_input):
+        tensor_output = (tensor_input - self.mean) / self.std
+        return tensor_output
+
+    def process(self, ref, supp):
+        flow = []
+
+        ref = [self.preprocess(ref)]
+        supp = [self.preprocess(supp)]
+
+        for level in range(5):
+            ref.insert(0, F.avg_pool2d(input=ref[0], kernel_size=2, stride=2, count_include_pad=False))
+            supp.insert(0, F.avg_pool2d(input=supp[0], kernel_size=2, stride=2, count_include_pad=False))
+
+        flow = ref[0].new_zeros(
+            [ref[0].size(0), 2,
+             int(math.floor(ref[0].size(2) / 2.0)),
+             int(math.floor(ref[0].size(3) / 2.0))])
+
+        for level in range(len(ref)):
+            upsampled_flow = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0
+
+            if upsampled_flow.size(2) != ref[level].size(2):
+                upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 0, 0, 1], mode='replicate')
+            if upsampled_flow.size(3) != ref[level].size(3):
+                upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 1, 0, 0], mode='replicate')
+
+            flow = self.basic_module[level](torch.cat([
+                ref[level],
+                flow_warp(
+                    supp[level], upsampled_flow.permute(0, 2, 3, 1), interp_mode='bilinear', padding_mode='border'),
+                upsampled_flow
+            ], 1)) + upsampled_flow
+
+        return flow
+
+    def forward(self, ref, supp):
+        assert ref.size() == supp.size()
+
+        h, w = ref.size(2), ref.size(3)
+        w_floor = math.floor(math.ceil(w / 32.0) * 32.0)
+        h_floor = math.floor(math.ceil(h / 32.0) * 32.0)
+
+        ref = F.interpolate(input=ref, size=(h_floor, w_floor), mode='bilinear', align_corners=False)
+        supp = F.interpolate(input=supp, size=(h_floor, w_floor), mode='bilinear', align_corners=False)
+
+        flow = F.interpolate(input=self.process(ref, supp), size=(h, w), mode='bilinear', align_corners=False)
+
+        flow[:, 0, :, :] *= float(w) / float(w_floor)
+        flow[:, 1, :, :] *= float(h) / float(h_floor)
+
+        return flow
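A shape-level usage sketch (not part of the commit; module path assumed — without a pretrained load_path the predicted flow is untrained and only the shapes are meaningful):

import torch
from basicsr.archs.spynet_arch import SpyNet  # path assumed

net = SpyNet(load_path=None)  # pass a pretrained .pth path for real flow estimates
ref = torch.rand(1, 3, 64, 64)
supp = torch.rand(1, 3, 64, 64)
flow = net(ref, supp)  # -> (1, 2, 64, 64), a per-pixel 2-channel displacement field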
StableSR/basicsr/archs/srresnet_arch.py
ADDED
@@ -0,0 +1,65 @@
+from torch import nn as nn
+from torch.nn import functional as F
+
+from basicsr.utils.registry import ARCH_REGISTRY
+from .arch_util import ResidualBlockNoBN, default_init_weights, make_layer
+
+
+@ARCH_REGISTRY.register()
+class MSRResNet(nn.Module):
+    """Modified SRResNet.
+
+    A compacted version modified from SRResNet in
+    "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network"
+    It uses residual blocks without BN, similar to EDSR.
+    Currently, it supports x2, x3 and x4 upsampling scale factor.
+
+    Args:
+        num_in_ch (int): Channel number of inputs. Default: 3.
+        num_out_ch (int): Channel number of outputs. Default: 3.
+        num_feat (int): Channel number of intermediate features. Default: 64.
+        num_block (int): Block number in the body network. Default: 16.
+        upscale (int): Upsampling factor. Support x2, x3 and x4. Default: 4.
+    """
+
+    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16, upscale=4):
+        super(MSRResNet, self).__init__()
+        self.upscale = upscale
+
+        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+        self.body = make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat)
+
+        # upsampling
+        if self.upscale in [2, 3]:
+            self.upconv1 = nn.Conv2d(num_feat, num_feat * self.upscale * self.upscale, 3, 1, 1)
+            self.pixel_shuffle = nn.PixelShuffle(self.upscale)
+        elif self.upscale == 4:
+            self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
+            self.upconv2 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
+            self.pixel_shuffle = nn.PixelShuffle(2)
+
+        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+        # activation function
+        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+        # initialization
+        default_init_weights([self.conv_first, self.upconv1, self.conv_hr, self.conv_last], 0.1)
+        if self.upscale == 4:
+            default_init_weights(self.upconv2, 0.1)
+
+    def forward(self, x):
+        feat = self.lrelu(self.conv_first(x))
+        out = self.body(feat)
+
+        if self.upscale == 4:
+            out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
+            out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
+        elif self.upscale in [2, 3]:
+            out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
+
+        out = self.conv_last(self.lrelu(self.conv_hr(out)))
+        base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False)
+        out += base
+        return out
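A shape-level usage sketch (not part of the commit; module path assumed from the layout above):

import torch
from basicsr.archs.srresnet_arch import MSRResNet  # path assumed

# x2/x3 use a single pixel-shuffle step; x4 stacks two x2 steps (see __init__ above).
net = MSRResNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16, upscale=3)
sr = net(torch.rand(1, 3, 40, 40))  # -> (1, 3, 120, 120)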
StableSR/basicsr/archs/srvgg_arch.py
ADDED
@@ -0,0 +1,70 @@
+from torch import nn as nn
+from torch.nn import functional as F
+
+from basicsr.utils.registry import ARCH_REGISTRY
+
+
+@ARCH_REGISTRY.register(suffix='basicsr')
+class SRVGGNetCompact(nn.Module):
+    """A compact VGG-style network structure for super-resolution.
+
+    It is a compact network structure, which performs upsampling in the last layer and no convolution is
+    conducted on the HR feature space.
+
+    Args:
+        num_in_ch (int): Channel number of inputs. Default: 3.
+        num_out_ch (int): Channel number of outputs. Default: 3.
+        num_feat (int): Channel number of intermediate features. Default: 64.
+        num_conv (int): Number of convolution layers in the body network. Default: 16.
+        upscale (int): Upsampling factor. Default: 4.
+        act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu.
+    """
+
+    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
+        super(SRVGGNetCompact, self).__init__()
+        self.num_in_ch = num_in_ch
+        self.num_out_ch = num_out_ch
+        self.num_feat = num_feat
+        self.num_conv = num_conv
+        self.upscale = upscale
+        self.act_type = act_type
+
+        self.body = nn.ModuleList()
+        # the first conv
+        self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
+        # the first activation
+        if act_type == 'relu':
+            activation = nn.ReLU(inplace=True)
+        elif act_type == 'prelu':
+            activation = nn.PReLU(num_parameters=num_feat)
+        elif act_type == 'leakyrelu':
+            activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+        self.body.append(activation)
+
+        # the body structure
+        for _ in range(num_conv):
+            self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
+            # activation
+            if act_type == 'relu':
+                activation = nn.ReLU(inplace=True)
+            elif act_type == 'prelu':
+                activation = nn.PReLU(num_parameters=num_feat)
+            elif act_type == 'leakyrelu':
+                activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+            self.body.append(activation)
+
+        # the last conv
+        self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
+        # upsample
+        self.upsampler = nn.PixelShuffle(upscale)
+
+    def forward(self, x):
+        out = x
+        for i in range(0, len(self.body)):
+            out = self.body[i](out)
+
+        out = self.upsampler(out)
+        # add the nearest upsampled image, so that the network learns the residual
+        base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
+        out += base
+        return out
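A shape-level usage sketch (not part of the commit; module path assumed from the layout above):

import torch
from basicsr.archs.srvgg_arch import SRVGGNetCompact  # path assumed

# All convs run at LR resolution; one pixel-shuffle at the end plus a
# nearest-upsampled copy of the input as the residual base.
net = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4)
sr = net(torch.rand(1, 3, 32, 32))  # -> (1, 3, 128, 128)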
StableSR/basicsr/archs/stylegan2_arch.py
ADDED
@@ -0,0 +1,799 @@
+import math
+import random
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from basicsr.ops.fused_act import FusedLeakyReLU, fused_leaky_relu
+from basicsr.ops.upfirdn2d import upfirdn2d
+from basicsr.utils.registry import ARCH_REGISTRY
+
+
+class NormStyleCode(nn.Module):
+
+    def forward(self, x):
+        """Normalize the style codes.
+
+        Args:
+            x (Tensor): Style codes with shape (b, c).
+
+        Returns:
+            Tensor: Normalized tensor.
+        """
+        return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
+
+
+def make_resample_kernel(k):
+    """Make resampling kernel for UpFirDn.
+
+    Args:
+        k (list[int]): A list indicating the 1D resample kernel magnitude.
+
+    Returns:
+        Tensor: 2D resampled kernel.
+    """
+    k = torch.tensor(k, dtype=torch.float32)
+    if k.ndim == 1:
+        k = k[None, :] * k[:, None]  # to 2D kernel, outer product
+    # normalize
+    k /= k.sum()
+    return k
+
+
+class UpFirDnUpsample(nn.Module):
+    """Upsample, FIR filter, and downsample (upsample version).
+
+    References:
+    1. https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.upfirdn.html  # noqa: E501
+    2. http://www.ece.northwestern.edu/local-apps/matlabhelp/toolbox/signal/upfirdn.html  # noqa: E501
+
+    Args:
+        resample_kernel (list[int]): A list indicating the 1D resample kernel
+            magnitude.
+        factor (int): Upsampling scale factor. Default: 2.
+    """
+
+    def __init__(self, resample_kernel, factor=2):
+        super(UpFirDnUpsample, self).__init__()
+        self.kernel = make_resample_kernel(resample_kernel) * (factor**2)
+        self.factor = factor
+
+        pad = self.kernel.shape[0] - factor
+        self.pad = ((pad + 1) // 2 + factor - 1, pad // 2)
+
+    def forward(self, x):
+        out = upfirdn2d(x, self.kernel.type_as(x), up=self.factor, down=1, pad=self.pad)
+        return out
+
+    def __repr__(self):
+        return (f'{self.__class__.__name__}(factor={self.factor})')
+
+
+class UpFirDnDownsample(nn.Module):
+    """Upsample, FIR filter, and downsample (downsample version).
+
+    Args:
+        resample_kernel (list[int]): A list indicating the 1D resample kernel
+            magnitude.
+        factor (int): Downsampling scale factor. Default: 2.
+    """
+
+    def __init__(self, resample_kernel, factor=2):
+        super(UpFirDnDownsample, self).__init__()
+        self.kernel = make_resample_kernel(resample_kernel)
+        self.factor = factor
+
+        pad = self.kernel.shape[0] - factor
+        self.pad = ((pad + 1) // 2, pad // 2)
+
+    def forward(self, x):
+        out = upfirdn2d(x, self.kernel.type_as(x), up=1, down=self.factor, pad=self.pad)
+        return out
+
+    def __repr__(self):
+        return (f'{self.__class__.__name__}(factor={self.factor})')
+
+
+class UpFirDnSmooth(nn.Module):
+    """Upsample, FIR filter, and downsample (smooth version).
+
+    Args:
+        resample_kernel (list[int]): A list indicating the 1D resample kernel
+            magnitude.
+        upsample_factor (int): Upsampling scale factor. Default: 1.
+        downsample_factor (int): Downsampling scale factor. Default: 1.
+        kernel_size (int): Kernel size. Default: 1.
+    """
+
+    def __init__(self, resample_kernel, upsample_factor=1, downsample_factor=1, kernel_size=1):
+        super(UpFirDnSmooth, self).__init__()
+        self.upsample_factor = upsample_factor
+        self.downsample_factor = downsample_factor
+        self.kernel = make_resample_kernel(resample_kernel)
+        if upsample_factor > 1:
+            self.kernel = self.kernel * (upsample_factor**2)
+
+        if upsample_factor > 1:
+            pad = (self.kernel.shape[0] - upsample_factor) - (kernel_size - 1)
+            self.pad = ((pad + 1) // 2 + upsample_factor - 1, pad // 2 + 1)
+        elif downsample_factor > 1:
+            pad = (self.kernel.shape[0] - downsample_factor) + (kernel_size - 1)
+            self.pad = ((pad + 1) // 2, pad // 2)
+        else:
+            raise NotImplementedError
+
+    def forward(self, x):
+        out = upfirdn2d(x, self.kernel.type_as(x), up=1, down=1, pad=self.pad)
+        return out
+
+    def __repr__(self):
+        return (f'{self.__class__.__name__}(upsample_factor={self.upsample_factor}'
+                f', downsample_factor={self.downsample_factor})')
+
+
+class EqualLinear(nn.Module):
+    """Equalized Linear as StyleGAN2.
+
+    Args:
+        in_channels (int): Size of each sample.
+        out_channels (int): Size of each output sample.
+        bias (bool): If set to ``False``, the layer will not learn an additive
+            bias. Default: ``True``.
+        bias_init_val (float): Bias initialized value. Default: 0.
+        lr_mul (float): Learning rate multiplier. Default: 1.
+        activation (None | str): The activation after ``linear`` operation.
+            Supported: 'fused_lrelu', None. Default: None.
+    """
+
+    def __init__(self, in_channels, out_channels, bias=True, bias_init_val=0, lr_mul=1, activation=None):
+        super(EqualLinear, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.lr_mul = lr_mul
+        self.activation = activation
+        if self.activation not in ['fused_lrelu', None]:
+            raise ValueError(f'Wrong activation value in EqualLinear: {activation}'
+                             "Supported ones are: ['fused_lrelu', None].")
+        self.scale = (1 / math.sqrt(in_channels)) * lr_mul
+
+        self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul))
+        if bias:
+            self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
+        else:
+            self.register_parameter('bias', None)
+
+    def forward(self, x):
+        if self.bias is None:
+            bias = None
+        else:
+            bias = self.bias * self.lr_mul
+        if self.activation == 'fused_lrelu':
+            out = F.linear(x, self.weight * self.scale)
+            out = fused_leaky_relu(out, bias)
+        else:
+            out = F.linear(x, self.weight * self.scale, bias=bias)
+        return out
+
+    def __repr__(self):
+        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
+                f'out_channels={self.out_channels}, bias={self.bias is not None})')
+
+
+class ModulatedConv2d(nn.Module):
+    """Modulated Conv2d used in StyleGAN2.
+
+    There is no bias in ModulatedConv2d.
+
+    Args:
+        in_channels (int): Channel number of the input.
+        out_channels (int): Channel number of the output.
+        kernel_size (int): Size of the convolving kernel.
+        num_style_feat (int): Channel number of style features.
+        demodulate (bool): Whether to demodulate in the conv layer.
+            Default: True.
+        sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
+            Default: None.
+        resample_kernel (list[int]): A list indicating the 1D resample kernel
+            magnitude. Default: (1, 3, 3, 1).
+        eps (float): A value added to the denominator for numerical stability.
+            Default: 1e-8.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 num_style_feat,
+                 demodulate=True,
+                 sample_mode=None,
+                 resample_kernel=(1, 3, 3, 1),
+                 eps=1e-8):
+        super(ModulatedConv2d, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.demodulate = demodulate
+        self.sample_mode = sample_mode
+        self.eps = eps
+
+        if self.sample_mode == 'upsample':
+            self.smooth = UpFirDnSmooth(
+                resample_kernel, upsample_factor=2, downsample_factor=1, kernel_size=kernel_size)
+        elif self.sample_mode == 'downsample':
+            self.smooth = UpFirDnSmooth(
+                resample_kernel, upsample_factor=1, downsample_factor=2, kernel_size=kernel_size)
+        elif self.sample_mode is None:
+            pass
+        else:
+            raise ValueError(f'Wrong sample mode {self.sample_mode}, '
+                             "supported ones are ['upsample', 'downsample', None].")
+
+        self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
+        # modulation inside each modulated conv
+        self.modulation = EqualLinear(
+            num_style_feat, in_channels, bias=True, bias_init_val=1, lr_mul=1, activation=None)
+
+        self.weight = nn.Parameter(torch.randn(1, out_channels, in_channels, kernel_size, kernel_size))
+        self.padding = kernel_size // 2
+
+    def forward(self, x, style):
+        """Forward function.
+
+        Args:
+            x (Tensor): Tensor with shape (b, c, h, w).
+            style (Tensor): Tensor with shape (b, num_style_feat).
+
+        Returns:
+            Tensor: Modulated tensor after convolution.
+        """
+        b, c, h, w = x.shape  # c = c_in
+        # weight modulation
+        style = self.modulation(style).view(b, 1, c, 1, 1)
+        # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
+        weight = self.scale * self.weight * style  # (b, c_out, c_in, k, k)
+
+        if self.demodulate:
+            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
+            weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
+
+        weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)
+
+        if self.sample_mode == 'upsample':
+            x = x.view(1, b * c, h, w)
+            weight = weight.view(b, self.out_channels, c, self.kernel_size, self.kernel_size)
+            weight = weight.transpose(1, 2).reshape(b * c, self.out_channels, self.kernel_size, self.kernel_size)
+            out = F.conv_transpose2d(x, weight, padding=0, stride=2, groups=b)
+            out = out.view(b, self.out_channels, *out.shape[2:4])
+            out = self.smooth(out)
+        elif self.sample_mode == 'downsample':
+            x = self.smooth(x)
+            x = x.view(1, b * c, *x.shape[2:4])
+            out = F.conv2d(x, weight, padding=0, stride=2, groups=b)
+            out = out.view(b, self.out_channels, *out.shape[2:4])
+        else:
+            x = x.view(1, b * c, h, w)
+            # weight: (b*c_out, c_in, k, k), groups=b
+            out = F.conv2d(x, weight, padding=self.padding, groups=b)
+            out = out.view(b, self.out_channels, *out.shape[2:4])
+
+        return out
+
+    def __repr__(self):
+        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
+                f'out_channels={self.out_channels}, '
+                f'kernel_size={self.kernel_size}, '
+                f'demodulate={self.demodulate}, sample_mode={self.sample_mode})')
+
+
+class StyleConv(nn.Module):
+    """Style conv.
+
+    Args:
+        in_channels (int): Channel number of the input.
+        out_channels (int): Channel number of the output.
+        kernel_size (int): Size of the convolving kernel.
+        num_style_feat (int): Channel number of style features.
+        demodulate (bool): Whether demodulate in the conv layer. Default: True.
+        sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
+            Default: None.
+        resample_kernel (list[int]): A list indicating the 1D resample kernel
+            magnitude. Default: (1, 3, 3, 1).
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 num_style_feat,
+                 demodulate=True,
+                 sample_mode=None,
+                 resample_kernel=(1, 3, 3, 1)):
+        super(StyleConv, self).__init__()
+        self.modulated_conv = ModulatedConv2d(
+            in_channels,
+            out_channels,
+            kernel_size,
+            num_style_feat,
+            demodulate=demodulate,
+            sample_mode=sample_mode,
+            resample_kernel=resample_kernel)
+        self.weight = nn.Parameter(torch.zeros(1))  # for noise injection
+        self.activate = FusedLeakyReLU(out_channels)
+
+    def forward(self, x, style, noise=None):
+        # modulate
+        out = self.modulated_conv(x, style)
+        # noise injection
+        if noise is None:
+            b, _, h, w = out.shape
+            noise = out.new_empty(b, 1, h, w).normal_()
+        out = out + self.weight * noise
+        # activation (with bias)
+        out = self.activate(out)
+        return out
+
+
+class ToRGB(nn.Module):
+    """To RGB from features.
+
+    Args:
+        in_channels (int): Channel number of input.
+        num_style_feat (int): Channel number of style features.
+        upsample (bool): Whether to upsample. Default: True.
+        resample_kernel (list[int]): A list indicating the 1D resample kernel
+            magnitude. Default: (1, 3, 3, 1).
+    """
+
+    def __init__(self, in_channels, num_style_feat, upsample=True, resample_kernel=(1, 3, 3, 1)):
+        super(ToRGB, self).__init__()
+        if upsample:
+            self.upsample = UpFirDnUpsample(resample_kernel, factor=2)
+        else:
+            self.upsample = None
+        self.modulated_conv = ModulatedConv2d(
+            in_channels, 3, kernel_size=1, num_style_feat=num_style_feat, demodulate=False, sample_mode=None)
+        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
+
+    def forward(self, x, style, skip=None):
+        """Forward function.
+
+        Args:
+            x (Tensor): Feature tensor with shape (b, c, h, w).
+            style (Tensor): Tensor with shape (b, num_style_feat).
+            skip (Tensor): Base/skip tensor. Default: None.
+
+        Returns:
+            Tensor: RGB images.
+        """
+        out = self.modulated_conv(x, style)
+        out = out + self.bias
+        if skip is not None:
+            if self.upsample:
+                skip = self.upsample(skip)
+            out = out + skip
+        return out
+
+
+class ConstantInput(nn.Module):
+    """Constant input.
+
+    Args:
+        num_channel (int): Channel number of constant input.
|
382 |
+
size (int): Spatial size of constant input.
|
383 |
+
"""
|
384 |
+
|
385 |
+
def __init__(self, num_channel, size):
|
386 |
+
super(ConstantInput, self).__init__()
|
387 |
+
self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))
|
388 |
+
|
389 |
+
def forward(self, batch):
|
390 |
+
out = self.weight.repeat(batch, 1, 1, 1)
|
391 |
+
return out
|
392 |
+
|
393 |
+
|
394 |
+
@ARCH_REGISTRY.register()
|
395 |
+
class StyleGAN2Generator(nn.Module):
|
396 |
+
"""StyleGAN2 Generator.
|
397 |
+
|
398 |
+
Args:
|
399 |
+
out_size (int): The spatial size of outputs.
|
400 |
+
num_style_feat (int): Channel number of style features. Default: 512.
|
401 |
+
num_mlp (int): Layer number of MLP style layers. Default: 8.
|
402 |
+
channel_multiplier (int): Channel multiplier for large networks of
|
403 |
+
StyleGAN2. Default: 2.
|
404 |
+
resample_kernel (list[int]): A list indicating the 1D resample kernel
|
405 |
+
magnitude. A cross production will be applied to extent 1D resample
|
406 |
+
kernel to 2D resample kernel. Default: (1, 3, 3, 1).
|
407 |
+
lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
|
408 |
+
narrow (float): Narrow ratio for channels. Default: 1.0.
|
409 |
+
"""
|
410 |
+
|
411 |
+
def __init__(self,
|
412 |
+
out_size,
|
413 |
+
num_style_feat=512,
|
414 |
+
num_mlp=8,
|
415 |
+
channel_multiplier=2,
|
416 |
+
resample_kernel=(1, 3, 3, 1),
|
417 |
+
lr_mlp=0.01,
|
418 |
+
narrow=1):
|
419 |
+
super(StyleGAN2Generator, self).__init__()
|
420 |
+
# Style MLP layers
|
421 |
+
self.num_style_feat = num_style_feat
|
422 |
+
style_mlp_layers = [NormStyleCode()]
|
423 |
+
for i in range(num_mlp):
|
424 |
+
style_mlp_layers.append(
|
425 |
+
EqualLinear(
|
426 |
+
num_style_feat, num_style_feat, bias=True, bias_init_val=0, lr_mul=lr_mlp,
|
427 |
+
activation='fused_lrelu'))
|
428 |
+
self.style_mlp = nn.Sequential(*style_mlp_layers)
|
429 |
+
|
430 |
+
channels = {
|
431 |
+
'4': int(512 * narrow),
|
432 |
+
'8': int(512 * narrow),
|
433 |
+
'16': int(512 * narrow),
|
434 |
+
'32': int(512 * narrow),
|
435 |
+
'64': int(256 * channel_multiplier * narrow),
|
436 |
+
'128': int(128 * channel_multiplier * narrow),
|
437 |
+
'256': int(64 * channel_multiplier * narrow),
|
438 |
+
'512': int(32 * channel_multiplier * narrow),
|
439 |
+
'1024': int(16 * channel_multiplier * narrow)
|
440 |
+
}
|
441 |
+
self.channels = channels
|
442 |
+
|
443 |
+
self.constant_input = ConstantInput(channels['4'], size=4)
|
444 |
+
self.style_conv1 = StyleConv(
|
445 |
+
channels['4'],
|
446 |
+
channels['4'],
|
447 |
+
kernel_size=3,
|
448 |
+
num_style_feat=num_style_feat,
|
449 |
+
demodulate=True,
|
450 |
+
sample_mode=None,
|
451 |
+
resample_kernel=resample_kernel)
|
452 |
+
self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False, resample_kernel=resample_kernel)
|
453 |
+
|
454 |
+
self.log_size = int(math.log(out_size, 2))
|
455 |
+
self.num_layers = (self.log_size - 2) * 2 + 1
|
456 |
+
self.num_latent = self.log_size * 2 - 2
|
457 |
+
|
458 |
+
self.style_convs = nn.ModuleList()
|
459 |
+
self.to_rgbs = nn.ModuleList()
|
460 |
+
self.noises = nn.Module()
|
461 |
+
|
462 |
+
in_channels = channels['4']
|
463 |
+
# noise
|
464 |
+
for layer_idx in range(self.num_layers):
|
465 |
+
resolution = 2**((layer_idx + 5) // 2)
|
466 |
+
shape = [1, 1, resolution, resolution]
|
467 |
+
self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
|
468 |
+
# style convs and to_rgbs
|
469 |
+
for i in range(3, self.log_size + 1):
|
470 |
+
out_channels = channels[f'{2**i}']
|
471 |
+
self.style_convs.append(
|
472 |
+
StyleConv(
|
473 |
+
in_channels,
|
474 |
+
out_channels,
|
475 |
+
kernel_size=3,
|
476 |
+
num_style_feat=num_style_feat,
|
477 |
+
demodulate=True,
|
478 |
+
sample_mode='upsample',
|
479 |
+
resample_kernel=resample_kernel,
|
480 |
+
))
|
481 |
+
self.style_convs.append(
|
482 |
+
StyleConv(
|
483 |
+
out_channels,
|
484 |
+
out_channels,
|
485 |
+
kernel_size=3,
|
486 |
+
num_style_feat=num_style_feat,
|
487 |
+
demodulate=True,
|
488 |
+
sample_mode=None,
|
489 |
+
resample_kernel=resample_kernel))
|
490 |
+
self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True, resample_kernel=resample_kernel))
|
491 |
+
in_channels = out_channels
|
492 |
+
|
493 |
+
def make_noise(self):
|
494 |
+
"""Make noise for noise injection."""
|
495 |
+
device = self.constant_input.weight.device
|
496 |
+
noises = [torch.randn(1, 1, 4, 4, device=device)]
|
497 |
+
|
498 |
+
for i in range(3, self.log_size + 1):
|
499 |
+
for _ in range(2):
|
500 |
+
noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
|
501 |
+
|
502 |
+
return noises
|
503 |
+
|
504 |
+
def get_latent(self, x):
|
505 |
+
return self.style_mlp(x)
|
506 |
+
|
507 |
+
def mean_latent(self, num_latent):
|
508 |
+
latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
|
509 |
+
latent = self.style_mlp(latent_in).mean(0, keepdim=True)
|
510 |
+
return latent
|
511 |
+
|
512 |
+
def forward(self,
|
513 |
+
styles,
|
514 |
+
input_is_latent=False,
|
515 |
+
noise=None,
|
516 |
+
randomize_noise=True,
|
517 |
+
truncation=1,
|
518 |
+
truncation_latent=None,
|
519 |
+
inject_index=None,
|
520 |
+
return_latents=False):
|
521 |
+
"""Forward function for StyleGAN2Generator.
|
522 |
+
|
523 |
+
Args:
|
524 |
+
styles (list[Tensor]): Sample codes of styles.
|
525 |
+
input_is_latent (bool): Whether input is latent style.
|
526 |
+
Default: False.
|
527 |
+
noise (Tensor | None): Input noise or None. Default: None.
|
528 |
+
randomize_noise (bool): Randomize noise, used when 'noise' is
|
529 |
+
False. Default: True.
|
530 |
+
truncation (float): TODO. Default: 1.
|
531 |
+
truncation_latent (Tensor | None): TODO. Default: None.
|
532 |
+
inject_index (int | None): The injection index for mixing noise.
|
533 |
+
Default: None.
|
534 |
+
return_latents (bool): Whether to return style latents.
|
535 |
+
Default: False.
|
536 |
+
"""
|
537 |
+
# style codes -> latents with Style MLP layer
|
538 |
+
if not input_is_latent:
|
539 |
+
styles = [self.style_mlp(s) for s in styles]
|
540 |
+
# noises
|
541 |
+
if noise is None:
|
542 |
+
if randomize_noise:
|
543 |
+
noise = [None] * self.num_layers # for each style conv layer
|
544 |
+
else: # use the stored noise
|
545 |
+
noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
|
546 |
+
# style truncation
|
547 |
+
if truncation < 1:
|
548 |
+
style_truncation = []
|
549 |
+
for style in styles:
|
550 |
+
style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
|
551 |
+
styles = style_truncation
|
552 |
+
# get style latent with injection
|
553 |
+
if len(styles) == 1:
|
554 |
+
inject_index = self.num_latent
|
555 |
+
|
556 |
+
if styles[0].ndim < 3:
|
557 |
+
# repeat latent code for all the layers
|
558 |
+
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
559 |
+
else: # used for encoder with different latent code for each layer
|
560 |
+
latent = styles[0]
|
561 |
+
elif len(styles) == 2: # mixing noises
|
562 |
+
if inject_index is None:
|
563 |
+
inject_index = random.randint(1, self.num_latent - 1)
|
564 |
+
latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
565 |
+
latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
|
566 |
+
latent = torch.cat([latent1, latent2], 1)
|
567 |
+
|
568 |
+
# main generation
|
569 |
+
out = self.constant_input(latent.shape[0])
|
570 |
+
out = self.style_conv1(out, latent[:, 0], noise=noise[0])
|
571 |
+
skip = self.to_rgb1(out, latent[:, 1])
|
572 |
+
|
573 |
+
i = 1
|
574 |
+
for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
|
575 |
+
noise[2::2], self.to_rgbs):
|
576 |
+
out = conv1(out, latent[:, i], noise=noise1)
|
577 |
+
out = conv2(out, latent[:, i + 1], noise=noise2)
|
578 |
+
skip = to_rgb(out, latent[:, i + 2], skip)
|
579 |
+
i += 2
|
580 |
+
|
581 |
+
image = skip
|
582 |
+
|
583 |
+
if return_latents:
|
584 |
+
return image, latent
|
585 |
+
else:
|
586 |
+
return image, None
|
587 |
+
|
588 |
+
|
589 |
+
class ScaledLeakyReLU(nn.Module):
|
590 |
+
"""Scaled LeakyReLU.
|
591 |
+
|
592 |
+
Args:
|
593 |
+
negative_slope (float): Negative slope. Default: 0.2.
|
594 |
+
"""
|
595 |
+
|
596 |
+
def __init__(self, negative_slope=0.2):
|
597 |
+
super(ScaledLeakyReLU, self).__init__()
|
598 |
+
self.negative_slope = negative_slope
|
599 |
+
|
600 |
+
def forward(self, x):
|
601 |
+
out = F.leaky_relu(x, negative_slope=self.negative_slope)
|
602 |
+
return out * math.sqrt(2)
|
603 |
+
|
604 |
+
|
605 |
+
class EqualConv2d(nn.Module):
|
606 |
+
"""Equalized Linear as StyleGAN2.
|
607 |
+
|
608 |
+
Args:
|
609 |
+
in_channels (int): Channel number of the input.
|
610 |
+
out_channels (int): Channel number of the output.
|
611 |
+
kernel_size (int): Size of the convolving kernel.
|
612 |
+
stride (int): Stride of the convolution. Default: 1
|
613 |
+
padding (int): Zero-padding added to both sides of the input.
|
614 |
+
Default: 0.
|
615 |
+
bias (bool): If ``True``, adds a learnable bias to the output.
|
616 |
+
Default: ``True``.
|
617 |
+
bias_init_val (float): Bias initialized value. Default: 0.
|
618 |
+
"""
|
619 |
+
|
620 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, bias_init_val=0):
|
621 |
+
super(EqualConv2d, self).__init__()
|
622 |
+
self.in_channels = in_channels
|
623 |
+
self.out_channels = out_channels
|
624 |
+
self.kernel_size = kernel_size
|
625 |
+
self.stride = stride
|
626 |
+
self.padding = padding
|
627 |
+
self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
|
628 |
+
|
629 |
+
self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
|
630 |
+
if bias:
|
631 |
+
self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
|
632 |
+
else:
|
633 |
+
self.register_parameter('bias', None)
|
634 |
+
|
635 |
+
def forward(self, x):
|
636 |
+
out = F.conv2d(
|
637 |
+
x,
|
638 |
+
self.weight * self.scale,
|
639 |
+
bias=self.bias,
|
640 |
+
stride=self.stride,
|
641 |
+
padding=self.padding,
|
642 |
+
)
|
643 |
+
|
644 |
+
return out
|
645 |
+
|
646 |
+
def __repr__(self):
|
647 |
+
return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
|
648 |
+
f'out_channels={self.out_channels}, '
|
649 |
+
f'kernel_size={self.kernel_size},'
|
650 |
+
f' stride={self.stride}, padding={self.padding}, '
|
651 |
+
f'bias={self.bias is not None})')
|
652 |
+
|
653 |
+
|
654 |
+
class ConvLayer(nn.Sequential):
|
655 |
+
"""Conv Layer used in StyleGAN2 Discriminator.
|
656 |
+
|
657 |
+
Args:
|
658 |
+
in_channels (int): Channel number of the input.
|
659 |
+
out_channels (int): Channel number of the output.
|
660 |
+
kernel_size (int): Kernel size.
|
661 |
+
downsample (bool): Whether downsample by a factor of 2.
|
662 |
+
Default: False.
|
663 |
+
resample_kernel (list[int]): A list indicating the 1D resample
|
664 |
+
kernel magnitude. A cross production will be applied to
|
665 |
+
extent 1D resample kernel to 2D resample kernel.
|
666 |
+
Default: (1, 3, 3, 1).
|
667 |
+
bias (bool): Whether with bias. Default: True.
|
668 |
+
activate (bool): Whether use activateion. Default: True.
|
669 |
+
"""
|
670 |
+
|
671 |
+
def __init__(self,
|
672 |
+
in_channels,
|
673 |
+
out_channels,
|
674 |
+
kernel_size,
|
675 |
+
downsample=False,
|
676 |
+
resample_kernel=(1, 3, 3, 1),
|
677 |
+
bias=True,
|
678 |
+
activate=True):
|
679 |
+
layers = []
|
680 |
+
# downsample
|
681 |
+
if downsample:
|
682 |
+
layers.append(
|
683 |
+
UpFirDnSmooth(resample_kernel, upsample_factor=1, downsample_factor=2, kernel_size=kernel_size))
|
684 |
+
stride = 2
|
685 |
+
self.padding = 0
|
686 |
+
else:
|
687 |
+
stride = 1
|
688 |
+
self.padding = kernel_size // 2
|
689 |
+
# conv
|
690 |
+
layers.append(
|
691 |
+
EqualConv2d(
|
692 |
+
in_channels, out_channels, kernel_size, stride=stride, padding=self.padding, bias=bias
|
693 |
+
and not activate))
|
694 |
+
# activation
|
695 |
+
if activate:
|
696 |
+
if bias:
|
697 |
+
layers.append(FusedLeakyReLU(out_channels))
|
698 |
+
else:
|
699 |
+
layers.append(ScaledLeakyReLU(0.2))
|
700 |
+
|
701 |
+
super(ConvLayer, self).__init__(*layers)
|
702 |
+
|
703 |
+
|
704 |
+
class ResBlock(nn.Module):
|
705 |
+
"""Residual block used in StyleGAN2 Discriminator.
|
706 |
+
|
707 |
+
Args:
|
708 |
+
in_channels (int): Channel number of the input.
|
709 |
+
out_channels (int): Channel number of the output.
|
710 |
+
resample_kernel (list[int]): A list indicating the 1D resample
|
711 |
+
kernel magnitude. A cross production will be applied to
|
712 |
+
extent 1D resample kernel to 2D resample kernel.
|
713 |
+
Default: (1, 3, 3, 1).
|
714 |
+
"""
|
715 |
+
|
716 |
+
def __init__(self, in_channels, out_channels, resample_kernel=(1, 3, 3, 1)):
|
717 |
+
super(ResBlock, self).__init__()
|
718 |
+
|
719 |
+
self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
|
720 |
+
self.conv2 = ConvLayer(
|
721 |
+
in_channels, out_channels, 3, downsample=True, resample_kernel=resample_kernel, bias=True, activate=True)
|
722 |
+
self.skip = ConvLayer(
|
723 |
+
in_channels, out_channels, 1, downsample=True, resample_kernel=resample_kernel, bias=False, activate=False)
|
724 |
+
|
725 |
+
def forward(self, x):
|
726 |
+
out = self.conv1(x)
|
727 |
+
out = self.conv2(out)
|
728 |
+
skip = self.skip(x)
|
729 |
+
out = (out + skip) / math.sqrt(2)
|
730 |
+
return out
|
731 |
+
|
732 |
+
|
733 |
+
@ARCH_REGISTRY.register()
|
734 |
+
class StyleGAN2Discriminator(nn.Module):
|
735 |
+
"""StyleGAN2 Discriminator.
|
736 |
+
|
737 |
+
Args:
|
738 |
+
out_size (int): The spatial size of outputs.
|
739 |
+
channel_multiplier (int): Channel multiplier for large networks of
|
740 |
+
StyleGAN2. Default: 2.
|
741 |
+
resample_kernel (list[int]): A list indicating the 1D resample kernel
|
742 |
+
magnitude. A cross production will be applied to extent 1D resample
|
743 |
+
kernel to 2D resample kernel. Default: (1, 3, 3, 1).
|
744 |
+
stddev_group (int): For group stddev statistics. Default: 4.
|
745 |
+
narrow (float): Narrow ratio for channels. Default: 1.0.
|
746 |
+
"""
|
747 |
+
|
748 |
+
def __init__(self, out_size, channel_multiplier=2, resample_kernel=(1, 3, 3, 1), stddev_group=4, narrow=1):
|
749 |
+
super(StyleGAN2Discriminator, self).__init__()
|
750 |
+
|
751 |
+
channels = {
|
752 |
+
'4': int(512 * narrow),
|
753 |
+
'8': int(512 * narrow),
|
754 |
+
'16': int(512 * narrow),
|
755 |
+
'32': int(512 * narrow),
|
756 |
+
'64': int(256 * channel_multiplier * narrow),
|
757 |
+
'128': int(128 * channel_multiplier * narrow),
|
758 |
+
'256': int(64 * channel_multiplier * narrow),
|
759 |
+
'512': int(32 * channel_multiplier * narrow),
|
760 |
+
'1024': int(16 * channel_multiplier * narrow)
|
761 |
+
}
|
762 |
+
|
763 |
+
log_size = int(math.log(out_size, 2))
|
764 |
+
|
765 |
+
conv_body = [ConvLayer(3, channels[f'{out_size}'], 1, bias=True, activate=True)]
|
766 |
+
|
767 |
+
in_channels = channels[f'{out_size}']
|
768 |
+
for i in range(log_size, 2, -1):
|
769 |
+
out_channels = channels[f'{2**(i - 1)}']
|
770 |
+
conv_body.append(ResBlock(in_channels, out_channels, resample_kernel))
|
771 |
+
in_channels = out_channels
|
772 |
+
self.conv_body = nn.Sequential(*conv_body)
|
773 |
+
|
774 |
+
self.final_conv = ConvLayer(in_channels + 1, channels['4'], 3, bias=True, activate=True)
|
775 |
+
self.final_linear = nn.Sequential(
|
776 |
+
EqualLinear(
|
777 |
+
channels['4'] * 4 * 4, channels['4'], bias=True, bias_init_val=0, lr_mul=1, activation='fused_lrelu'),
|
778 |
+
EqualLinear(channels['4'], 1, bias=True, bias_init_val=0, lr_mul=1, activation=None),
|
779 |
+
)
|
780 |
+
self.stddev_group = stddev_group
|
781 |
+
self.stddev_feat = 1
|
782 |
+
|
783 |
+
def forward(self, x):
|
784 |
+
out = self.conv_body(x)
|
785 |
+
|
786 |
+
b, c, h, w = out.shape
|
787 |
+
# concatenate a group stddev statistics to out
|
788 |
+
group = min(b, self.stddev_group) # Minibatch must be divisible by (or smaller than) group_size
|
789 |
+
stddev = out.view(group, -1, self.stddev_feat, c // self.stddev_feat, h, w)
|
790 |
+
stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
|
791 |
+
stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
|
792 |
+
stddev = stddev.repeat(group, 1, h, w)
|
793 |
+
out = torch.cat([out, stddev], 1)
|
794 |
+
|
795 |
+
out = self.final_conv(out)
|
796 |
+
out = out.view(b, -1)
|
797 |
+
out = self.final_linear(out)
|
798 |
+
|
799 |
+
return out
|
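For orientation, here is a minimal usage sketch of the two registered architectures above. It is an editor's illustration, not part of the committed file, and it assumes the compiled basicsr ops (fused_act, upfirdn2d) that this module imports are available:

import torch
from basicsr.archs.stylegan2_arch import StyleGAN2Discriminator, StyleGAN2Generator

# A 256x256 generator: log2(256) = 8, so num_latent = 2 * 8 - 2 = 14 style vectors per image.
generator = StyleGAN2Generator(out_size=256, num_style_feat=512, num_mlp=8)
discriminator = StyleGAN2Discriminator(out_size=256)

z = torch.randn(4, 512)          # one 512-dim style code per sample
image, _ = generator([z])        # forward takes a *list* of style codes
score = discriminator(image)
print(image.shape, score.shape)  # torch.Size([4, 3, 256, 256]) torch.Size([4, 1])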
StableSR/basicsr/archs/stylegan2_bilinear_arch.py
ADDED
@@ -0,0 +1,614 @@
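The file below is a variant of stylegan2_arch.py in which the UpFirDn-based resampling (which needs the compiled upfirdn2d op) is replaced by plain F.interpolate calls. As a quick illustration of the changed path (an editor's sketch, not part of the commit), its ModulatedConv2d first interpolates the input by a factor of 2 when sample_mode='upsample' and only then applies the grouped, style-modulated convolution:

import torch
from basicsr.archs.stylegan2_bilinear_arch import ModulatedConv2d

conv = ModulatedConv2d(8, 16, kernel_size=3, num_style_feat=512, sample_mode='upsample')
x = torch.randn(2, 8, 4, 4)  # (b, c_in, h, w)
style = torch.randn(2, 512)  # one style vector per sample
out = conv(x, style)
print(out.shape)             # torch.Size([2, 16, 8, 8]): h and w doubled by F.interpolate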
import math
import random
import torch
from torch import nn
from torch.nn import functional as F

from basicsr.ops.fused_act import FusedLeakyReLU, fused_leaky_relu
from basicsr.utils.registry import ARCH_REGISTRY


class NormStyleCode(nn.Module):

    def forward(self, x):
        """Normalize the style codes.

        Args:
            x (Tensor): Style codes with shape (b, c).

        Returns:
            Tensor: Normalized tensor.
        """
        return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)


class EqualLinear(nn.Module):
    """Equalized Linear as StyleGAN2.

    Args:
        in_channels (int): Size of each sample.
        out_channels (int): Size of each output sample.
        bias (bool): If set to ``False``, the layer will not learn an additive
            bias. Default: ``True``.
        bias_init_val (float): Bias initialized value. Default: 0.
        lr_mul (float): Learning rate multiplier. Default: 1.
        activation (None | str): The activation after ``linear`` operation.
            Supported: 'fused_lrelu', None. Default: None.
    """

    def __init__(self, in_channels, out_channels, bias=True, bias_init_val=0, lr_mul=1, activation=None):
        super(EqualLinear, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.lr_mul = lr_mul
        self.activation = activation
        if self.activation not in ['fused_lrelu', None]:
            raise ValueError(f'Wrong activation value in EqualLinear: {activation}. '
                             "Supported ones are: ['fused_lrelu', None].")
        self.scale = (1 / math.sqrt(in_channels)) * lr_mul

        self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
        else:
            self.register_parameter('bias', None)

    def forward(self, x):
        if self.bias is None:
            bias = None
        else:
            bias = self.bias * self.lr_mul
        if self.activation == 'fused_lrelu':
            out = F.linear(x, self.weight * self.scale)
            out = fused_leaky_relu(out, bias)
        else:
            out = F.linear(x, self.weight * self.scale, bias=bias)
        return out

    def __repr__(self):
        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
                f'out_channels={self.out_channels}, bias={self.bias is not None})')


class ModulatedConv2d(nn.Module):
    """Modulated Conv2d used in StyleGAN2.

    There is no bias in ModulatedConv2d.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether to demodulate in the conv layer.
            Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
            Default: None.
        eps (float): A value added to the denominator for numerical stability.
            Default: 1e-8.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 num_style_feat,
                 demodulate=True,
                 sample_mode=None,
                 eps=1e-8,
                 interpolation_mode='bilinear'):
        super(ModulatedConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.demodulate = demodulate
        self.sample_mode = sample_mode
        self.eps = eps
        self.interpolation_mode = interpolation_mode
        if self.interpolation_mode == 'nearest':
            self.align_corners = None
        else:
            self.align_corners = False

        self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
        # modulation inside each modulated conv
        self.modulation = EqualLinear(
            num_style_feat, in_channels, bias=True, bias_init_val=1, lr_mul=1, activation=None)

        self.weight = nn.Parameter(torch.randn(1, out_channels, in_channels, kernel_size, kernel_size))
        self.padding = kernel_size // 2

    def forward(self, x, style):
        """Forward function.

        Args:
            x (Tensor): Tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).

        Returns:
            Tensor: Modulated tensor after convolution.
        """
        b, c, h, w = x.shape  # c = c_in
        # weight modulation
        style = self.modulation(style).view(b, 1, c, 1, 1)
        # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
        weight = self.scale * self.weight * style  # (b, c_out, c_in, k, k)

        if self.demodulate:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
            weight = weight * demod.view(b, self.out_channels, 1, 1, 1)

        weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)

        if self.sample_mode == 'upsample':
            x = F.interpolate(x, scale_factor=2, mode=self.interpolation_mode, align_corners=self.align_corners)
        elif self.sample_mode == 'downsample':
            x = F.interpolate(x, scale_factor=0.5, mode=self.interpolation_mode, align_corners=self.align_corners)

        b, c, h, w = x.shape
        x = x.view(1, b * c, h, w)
        # weight: (b*c_out, c_in, k, k), groups=b
        out = F.conv2d(x, weight, padding=self.padding, groups=b)
        out = out.view(b, self.out_channels, *out.shape[2:4])

        return out

    def __repr__(self):
        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
                f'out_channels={self.out_channels}, '
                f'kernel_size={self.kernel_size}, '
                f'demodulate={self.demodulate}, sample_mode={self.sample_mode})')


class StyleConv(nn.Module):
    """Style conv.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether to demodulate in the conv layer. Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 num_style_feat,
                 demodulate=True,
                 sample_mode=None,
                 interpolation_mode='bilinear'):
        super(StyleConv, self).__init__()
        self.modulated_conv = ModulatedConv2d(
            in_channels,
            out_channels,
            kernel_size,
            num_style_feat,
            demodulate=demodulate,
            sample_mode=sample_mode,
            interpolation_mode=interpolation_mode)
        self.weight = nn.Parameter(torch.zeros(1))  # for noise injection
        self.activate = FusedLeakyReLU(out_channels)

    def forward(self, x, style, noise=None):
        # modulate
        out = self.modulated_conv(x, style)
        # noise injection
        if noise is None:
            b, _, h, w = out.shape
            noise = out.new_empty(b, 1, h, w).normal_()
        out = out + self.weight * noise
        # activation (with bias)
        out = self.activate(out)
        return out


class ToRGB(nn.Module):
    """To RGB from features.

    Args:
        in_channels (int): Channel number of input.
        num_style_feat (int): Channel number of style features.
        upsample (bool): Whether to upsample. Default: True.
    """

    def __init__(self, in_channels, num_style_feat, upsample=True, interpolation_mode='bilinear'):
        super(ToRGB, self).__init__()
        self.upsample = upsample
        self.interpolation_mode = interpolation_mode
        if self.interpolation_mode == 'nearest':
            self.align_corners = None
        else:
            self.align_corners = False
        self.modulated_conv = ModulatedConv2d(
            in_channels,
            3,
            kernel_size=1,
            num_style_feat=num_style_feat,
            demodulate=False,
            sample_mode=None,
            interpolation_mode=interpolation_mode)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, x, style, skip=None):
        """Forward function.

        Args:
            x (Tensor): Feature tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).
            skip (Tensor): Base/skip tensor. Default: None.

        Returns:
            Tensor: RGB images.
        """
        out = self.modulated_conv(x, style)
        out = out + self.bias
        if skip is not None:
            if self.upsample:
                skip = F.interpolate(
                    skip, scale_factor=2, mode=self.interpolation_mode, align_corners=self.align_corners)
            out = out + skip
        return out


class ConstantInput(nn.Module):
    """Constant input.

    Args:
        num_channel (int): Channel number of constant input.
        size (int): Spatial size of constant input.
    """

    def __init__(self, num_channel, size):
        super(ConstantInput, self).__init__()
        self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))

    def forward(self, batch):
        out = self.weight.repeat(batch, 1, 1, 1)
        return out


@ARCH_REGISTRY.register(suffix='basicsr')
class StyleGAN2GeneratorBilinear(nn.Module):
    """StyleGAN2 Generator (bilinear resampling version).

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        channel_multiplier (int): Channel multiplier for large networks of
            StyleGAN2. Default: 2.
        lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
        narrow (float): Narrow ratio for channels. Default: 1.0.
    """

    def __init__(self,
                 out_size,
                 num_style_feat=512,
                 num_mlp=8,
                 channel_multiplier=2,
                 lr_mlp=0.01,
                 narrow=1,
                 interpolation_mode='bilinear'):
        super(StyleGAN2GeneratorBilinear, self).__init__()
        # Style MLP layers
        self.num_style_feat = num_style_feat
        style_mlp_layers = [NormStyleCode()]
        for i in range(num_mlp):
            style_mlp_layers.append(
                EqualLinear(
                    num_style_feat, num_style_feat, bias=True, bias_init_val=0, lr_mul=lr_mlp,
                    activation='fused_lrelu'))
        self.style_mlp = nn.Sequential(*style_mlp_layers)

        channels = {
            '4': int(512 * narrow),
            '8': int(512 * narrow),
            '16': int(512 * narrow),
            '32': int(512 * narrow),
            '64': int(256 * channel_multiplier * narrow),
            '128': int(128 * channel_multiplier * narrow),
            '256': int(64 * channel_multiplier * narrow),
            '512': int(32 * channel_multiplier * narrow),
            '1024': int(16 * channel_multiplier * narrow)
        }
        self.channels = channels

        self.constant_input = ConstantInput(channels['4'], size=4)
        self.style_conv1 = StyleConv(
            channels['4'],
            channels['4'],
            kernel_size=3,
            num_style_feat=num_style_feat,
            demodulate=True,
            sample_mode=None,
            interpolation_mode=interpolation_mode)
        self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False, interpolation_mode=interpolation_mode)

        self.log_size = int(math.log(out_size, 2))
        self.num_layers = (self.log_size - 2) * 2 + 1
        self.num_latent = self.log_size * 2 - 2

        self.style_convs = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        in_channels = channels['4']
        # noise
        for layer_idx in range(self.num_layers):
            resolution = 2**((layer_idx + 5) // 2)
            shape = [1, 1, resolution, resolution]
            self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
        # style convs and to_rgbs
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            self.style_convs.append(
                StyleConv(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode='upsample',
                    interpolation_mode=interpolation_mode))
            self.style_convs.append(
                StyleConv(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode=None,
                    interpolation_mode=interpolation_mode))
            self.to_rgbs.append(
                ToRGB(out_channels, num_style_feat, upsample=True, interpolation_mode=interpolation_mode))
            in_channels = out_channels

    def make_noise(self):
        """Make noise for noise injection."""
        device = self.constant_input.weight.device
        noises = [torch.randn(1, 1, 4, 4, device=device)]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))

        return noises

    def get_latent(self, x):
        return self.style_mlp(x)

    def mean_latent(self, num_latent):
        latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
        latent = self.style_mlp(latent_in).mean(0, keepdim=True)
        return latent

    def forward(self,
                styles,
                input_is_latent=False,
                noise=None,
                randomize_noise=True,
                truncation=1,
                truncation_latent=None,
                inject_index=None,
                return_latents=False):
        """Forward function for StyleGAN2GeneratorBilinear.

        Args:
            styles (list[Tensor]): Sample codes of styles.
            input_is_latent (bool): Whether input is latent style.
                Default: False.
            noise (Tensor | None): Input noise or None. Default: None.
            randomize_noise (bool): Randomize noise, used when 'noise' is
                None. Default: True.
            truncation (float): Ratio for the truncation trick. Default: 1.
            truncation_latent (Tensor | None): The mean latent toward which
                styles are truncated. Default: None.
            inject_index (int | None): The injection index for mixing noise.
                Default: None.
            return_latents (bool): Whether to return style latents.
                Default: False.
        """
        # style codes -> latents with Style MLP layer
        if not input_is_latent:
            styles = [self.style_mlp(s) for s in styles]
        # noises
        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers  # for each style conv layer
            else:  # use the stored noise
                noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
        # style truncation
        if truncation < 1:
            style_truncation = []
            for style in styles:
                style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
            styles = style_truncation
        # get style latent with injection
        if len(styles) == 1:
            inject_index = self.num_latent

            if styles[0].ndim < 3:
                # repeat latent code for all the layers
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:  # used for encoder with different latent code for each layer
                latent = styles[0]
        elif len(styles) == 2:  # mixing noises
            if inject_index is None:
                inject_index = random.randint(1, self.num_latent - 1)
            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
            latent = torch.cat([latent1, latent2], 1)

        # main generation
        out = self.constant_input(latent.shape[0])
        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
                                                        noise[2::2], self.to_rgbs):
            out = conv1(out, latent[:, i], noise=noise1)
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)
            i += 2

        image = skip

        if return_latents:
            return image, latent
        else:
            return image, None


class ScaledLeakyReLU(nn.Module):
    """Scaled LeakyReLU.

    Args:
        negative_slope (float): Negative slope. Default: 0.2.
    """

    def __init__(self, negative_slope=0.2):
        super(ScaledLeakyReLU, self).__init__()
        self.negative_slope = negative_slope

    def forward(self, x):
        out = F.leaky_relu(x, negative_slope=self.negative_slope)
        return out * math.sqrt(2)


class EqualConv2d(nn.Module):
    """Equalized Conv2d as in StyleGAN2.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        stride (int): Stride of the convolution. Default: 1.
        padding (int): Zero-padding added to both sides of the input.
            Default: 0.
        bias (bool): If ``True``, adds a learnable bias to the output.
            Default: ``True``.
        bias_init_val (float): Bias initialized value. Default: 0.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, bias_init_val=0):
        super(EqualConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.scale = 1 / math.sqrt(in_channels * kernel_size**2)

        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
        else:
            self.register_parameter('bias', None)

    def forward(self, x):
        out = F.conv2d(
            x,
            self.weight * self.scale,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )

        return out

    def __repr__(self):
        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
                f'out_channels={self.out_channels}, '
                f'kernel_size={self.kernel_size},'
                f' stride={self.stride}, padding={self.padding}, '
                f'bias={self.bias is not None})')


class ConvLayer(nn.Sequential):
    """Conv Layer used in StyleGAN2 Discriminator.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Kernel size.
        downsample (bool): Whether to downsample by a factor of 2.
            Default: False.
        bias (bool): Whether with bias. Default: True.
        activate (bool): Whether to use activation. Default: True.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 downsample=False,
                 bias=True,
                 activate=True,
                 interpolation_mode='bilinear'):
        layers = []
        self.interpolation_mode = interpolation_mode
        # downsample
        if downsample:
            if self.interpolation_mode == 'nearest':
                self.align_corners = None
            else:
                self.align_corners = False

            layers.append(
                torch.nn.Upsample(scale_factor=0.5, mode=interpolation_mode, align_corners=self.align_corners))
        stride = 1
        self.padding = kernel_size // 2
        # conv
        layers.append(
            EqualConv2d(
                in_channels, out_channels, kernel_size, stride=stride, padding=self.padding, bias=bias
                and not activate))
        # activation
        if activate:
            if bias:
                layers.append(FusedLeakyReLU(out_channels))
            else:
                layers.append(ScaledLeakyReLU(0.2))

        super(ConvLayer, self).__init__(*layers)


class ResBlock(nn.Module):
    """Residual block used in StyleGAN2 Discriminator.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
    """

    def __init__(self, in_channels, out_channels, interpolation_mode='bilinear'):
        super(ResBlock, self).__init__()

        self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
        self.conv2 = ConvLayer(
            in_channels,
            out_channels,
            3,
            downsample=True,
            interpolation_mode=interpolation_mode,
            bias=True,
            activate=True)
        self.skip = ConvLayer(
            in_channels,
            out_channels,
            1,
            downsample=True,
            interpolation_mode=interpolation_mode,
            bias=False,
            activate=False)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        skip = self.skip(x)
        out = (out + skip) / math.sqrt(2)
        return out
StableSR/basicsr/archs/swinir_arch.py
ADDED
@@ -0,0 +1,956 @@
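Before the file body, a short sanity check of the window_partition/window_reverse pair defined near the top of the file below (an editor's sketch, not part of the commit): partitioning (b, h, w, c) features into non-overlapping windows and then reversing the operation must reproduce the input exactly, since both are pure reshape/permute operations:

import torch
from basicsr.archs.swinir_arch import window_partition, window_reverse

x = torch.randn(2, 8, 8, 4)       # (b, h, w, c) with h and w divisible by the window size
windows = window_partition(x, 4)  # (num_windows * b, 4, 4, c) = (8, 4, 4, 4)
y = window_reverse(windows, 4, 8, 8)
assert torch.equal(x, y)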
1 |
+
# Modified from https://github.com/JingyunLiang/SwinIR
|
2 |
+
# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
|
3 |
+
# Originally Written by Ze Liu, Modified by Jingyun Liang.
|
4 |
+
|
5 |
+
import math
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.utils.checkpoint as checkpoint
|
9 |
+
|
10 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
11 |
+
from .arch_util import to_2tuple, trunc_normal_
|
12 |
+
|
13 |
+
|
14 |
+
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
15 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
16 |
+
|
17 |
+
From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
|
18 |
+
"""
|
19 |
+
if drop_prob == 0. or not training:
|
20 |
+
return x
|
21 |
+
keep_prob = 1 - drop_prob
|
22 |
+
shape = (x.shape[0], ) + (1, ) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
23 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
24 |
+
random_tensor.floor_() # binarize
|
25 |
+
output = x.div(keep_prob) * random_tensor
|
26 |
+
return output
|
27 |
+
|
28 |
+
|
29 |
+
class DropPath(nn.Module):
|
30 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
31 |
+
|
32 |
+
From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
|
33 |
+
"""
|
34 |
+
|
35 |
+
def __init__(self, drop_prob=None):
|
36 |
+
super(DropPath, self).__init__()
|
37 |
+
self.drop_prob = drop_prob
|
38 |
+
|
39 |
+
def forward(self, x):
|
40 |
+
return drop_path(x, self.drop_prob, self.training)
|
41 |
+
|
42 |
+
|
43 |
+
class Mlp(nn.Module):
|
44 |
+
|
45 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
46 |
+
super().__init__()
|
47 |
+
out_features = out_features or in_features
|
48 |
+
hidden_features = hidden_features or in_features
|
49 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
50 |
+
self.act = act_layer()
|
51 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
52 |
+
self.drop = nn.Dropout(drop)
|
53 |
+
|
54 |
+
def forward(self, x):
|
55 |
+
x = self.fc1(x)
|
56 |
+
x = self.act(x)
|
57 |
+
x = self.drop(x)
|
58 |
+
x = self.fc2(x)
|
59 |
+
x = self.drop(x)
|
60 |
+
return x
|
61 |
+
|
62 |
+
|
63 |
+
def window_partition(x, window_size):
|
64 |
+
"""
|
65 |
+
Args:
|
66 |
+
x: (b, h, w, c)
|
67 |
+
window_size (int): window size
|
68 |
+
|
69 |
+
Returns:
|
70 |
+
windows: (num_windows*b, window_size, window_size, c)
|
71 |
+
"""
|
72 |
+
b, h, w, c = x.shape
|
73 |
+
x = x.view(b, h // window_size, window_size, w // window_size, window_size, c)
|
74 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c)
|
75 |
+
return windows
|
76 |
+
|
77 |
+
|
78 |
+
def window_reverse(windows, window_size, h, w):
|
79 |
+
"""
|
80 |
+
Args:
|
81 |
+
windows: (num_windows*b, window_size, window_size, c)
|
82 |
+
window_size (int): Window size
|
83 |
+
h (int): Height of image
|
84 |
+
w (int): Width of image
|
85 |
+
|
86 |
+
Returns:
|
87 |
+
x: (b, h, w, c)
|
88 |
+
"""
|
89 |
+
b = int(windows.shape[0] / (h * w / window_size / window_size))
|
90 |
+
x = windows.view(b, h // window_size, w // window_size, window_size, window_size, -1)
|
91 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)
|
92 |
+
return x
|
93 |
+
|
94 |
+
|
95 |
+
class WindowAttention(nn.Module):
|
96 |
+
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
97 |
+
It supports both of shifted and non-shifted window.
|
98 |
+
|
99 |
+
Args:
|
100 |
+
dim (int): Number of input channels.
|
101 |
+
window_size (tuple[int]): The height and width of the window.
|
102 |
+
num_heads (int): Number of attention heads.
|
103 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
104 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
105 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
106 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
107 |
+
"""
|
108 |
+
|
109 |
+
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
|
110 |
+
|
111 |
+
super().__init__()
|
112 |
+
self.dim = dim
|
113 |
+
self.window_size = window_size # Wh, Ww
|
114 |
+
self.num_heads = num_heads
|
115 |
+
head_dim = dim // num_heads
|
116 |
+
self.scale = qk_scale or head_dim**-0.5
|
117 |
+
|
118 |
+
# define a parameter table of relative position bias
|
119 |
+
self.relative_position_bias_table = nn.Parameter(
|
120 |
+
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
121 |
+
|
122 |
+
# get pair-wise relative position index for each token inside the window
|
123 |
+
coords_h = torch.arange(self.window_size[0])
|
124 |
+
coords_w = torch.arange(self.window_size[1])
|
125 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
126 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
127 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
128 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
129 |
+
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
|
130 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
131 |
+
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
132 |
+
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
133 |
+
self.register_buffer('relative_position_index', relative_position_index)
|
134 |
+
|
135 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
136 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
137 |
+
self.proj = nn.Linear(dim, dim)
|
138 |
+
|
139 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
140 |
+
|
141 |
+
trunc_normal_(self.relative_position_bias_table, std=.02)
|
142 |
+
self.softmax = nn.Softmax(dim=-1)
|
143 |
+
|
144 |
+
def forward(self, x, mask=None):
|
145 |
+
"""
|
146 |
+
Args:
|
147 |
+
x: input features with shape of (num_windows*b, n, c)
|
148 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
149 |
+
"""
|
150 |
+
b_, n, c = x.shape
|
151 |
+
qkv = self.qkv(x).reshape(b_, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4)
|
152 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
153 |
+
|
154 |
+
q = q * self.scale
|
155 |
+
attn = (q @ k.transpose(-2, -1))
|
156 |
+
|
157 |
+
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
158 |
+
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
159 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
160 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
161 |
+
|
162 |
+
if mask is not None:
|
163 |
+
nw = mask.shape[0]
|
164 |
+
attn = attn.view(b_ // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0)
|
165 |
+
attn = attn.view(-1, self.num_heads, n, n)
|
166 |
+
attn = self.softmax(attn)
|
167 |
+
else:
|
168 |
+
attn = self.softmax(attn)
|
169 |
+
|
170 |
+
attn = self.attn_drop(attn)
|
171 |
+
|
172 |
+
x = (attn @ v).transpose(1, 2).reshape(b_, n, c)
|
173 |
+
x = self.proj(x)
|
174 |
+
x = self.proj_drop(x)
|
175 |
+
return x
|
176 |
+
|
177 |
+
def extra_repr(self) -> str:
|
178 |
+
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
|
179 |
+
|
180 |
+
def flops(self, n):
|
181 |
+
# calculate flops for 1 window with token length of n
|
182 |
+
flops = 0
|
183 |
+
# qkv = self.qkv(x)
|
184 |
+
flops += n * self.dim * 3 * self.dim
|
185 |
+
# attn = (q @ k.transpose(-2, -1))
|
186 |
+
flops += self.num_heads * n * (self.dim // self.num_heads) * n
|
187 |
+
# x = (attn @ v)
|
188 |
+
flops += self.num_heads * n * n * (self.dim // self.num_heads)
|
189 |
+
# x = self.proj(x)
|
190 |
+
flops += n * self.dim * self.dim
|
191 |
+
return flops
|
192 |
+
|
193 |
+
|
194 |
+
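# --- editorial sketch (not part of the original file) ------------------------
# A minimal shape check for the WindowAttention module above, assuming `torch`
# is imported at the top of this file. With window_size=(8, 8) each window
# holds 64 tokens, and the module preserves the (num_windows*b, n, c) shape.
def _demo_window_attention():  # hypothetical helper, for illustration only
    attn = WindowAttention(dim=96, window_size=(8, 8), num_heads=6)
    x = torch.randn(4, 64, 96)  # (num_windows*b, n, c)
    assert attn(x).shape == (4, 64, 96)
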
class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self,
                 dim,
                 input_resolution,
                 num_heads,
                 window_size=7,
                 shift_size=0,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, 'shift_size must be in [0, window_size)'

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=to_2tuple(self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            attn_mask = self.calculate_mask(self.input_resolution)
        else:
            attn_mask = None

        self.register_buffer('attn_mask', attn_mask)

    def calculate_mask(self, x_size):
        # calculate attention mask for SW-MSA
        h, w = x_size
        img_mask = torch.zeros((1, h, w, 1))  # 1 h w 1
        h_slices = (slice(0, -self.window_size), slice(-self.window_size,
                                                       -self.shift_size), slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size), slice(-self.window_size,
                                                       -self.shift_size), slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # nw, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

        return attn_mask

    def forward(self, x, x_size):
        h, w = x_size
        b, _, c = x.shape
        # assert seq_len == h * w, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(b, h, w, c)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nw*b, window_size, window_size, c
        x_windows = x_windows.view(-1, self.window_size * self.window_size, c)  # nw*b, window_size*window_size, c

        # W-MSA/SW-MSA (to be compatible with testing on images whose shapes are a multiple of the window size)
        if self.input_resolution == x_size:
            attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nw*b, window_size*window_size, c
        else:
            attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, c)
        shifted_x = window_reverse(attn_windows, self.window_size, h, w)  # b h' w' c

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(b, h * w, c)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return (f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, '
                f'window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}')

    def flops(self):
        flops = 0
        h, w = self.input_resolution
        # norm1
        flops += self.dim * h * w
        # W-MSA/SW-MSA
        nw = h * w / self.window_size / self.window_size
        flops += nw * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * h * w * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * h * w
        return flops

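# --- editorial sketch (not part of the original file) ------------------------
# How the SW-MSA mask looks for a toy case, assuming the helper layers defined
# earlier in this file: an 8x8 feature map with window_size=4 and shift_size=2
# yields 4 windows of 16 tokens, and the precomputed `attn_mask` holds -100
# wherever two tokens come from different shifted regions, so softmax zeroes
# out that attention.
def _demo_shifted_window_mask():  # hypothetical helper, for illustration only
    blk = SwinTransformerBlock(dim=16, input_resolution=(8, 8), num_heads=2,
                               window_size=4, shift_size=2)
    assert blk.attn_mask.shape == (4, 16, 16)  # (num_windows, Wh*Ww, Wh*Ww)
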
class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: b, h*w, c
        """
        h, w = self.input_resolution
        b, seq_len, c = x.shape
        assert seq_len == h * w, 'input feature has wrong size'
        assert h % 2 == 0 and w % 2 == 0, f'x size ({h}*{w}) are not even.'

        x = x.view(b, h, w, c)

        x0 = x[:, 0::2, 0::2, :]  # b h/2 w/2 c
        x1 = x[:, 1::2, 0::2, :]  # b h/2 w/2 c
        x2 = x[:, 0::2, 1::2, :]  # b h/2 w/2 c
        x3 = x[:, 1::2, 1::2, :]  # b h/2 w/2 c
        x = torch.cat([x0, x1, x2, x3], -1)  # b h/2 w/2 4*c
        x = x.view(b, -1, 4 * c)  # b h/2*w/2 4*c

        x = self.norm(x)
        x = self.reduction(x)

        return x

    def extra_repr(self) -> str:
        return f'input_resolution={self.input_resolution}, dim={self.dim}'

    def flops(self):
        h, w = self.input_resolution
        flops = h * w * self.dim
        flops += (h // 2) * (w // 2) * 4 * self.dim * 2 * self.dim
        return flops

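# --- editorial sketch (not part of the original file) ------------------------
# PatchMerging halves each spatial dimension and doubles the channel count,
# e.g. an 8x8 grid of 32-dim tokens becomes a 4x4 grid of 64-dim tokens.
def _demo_patch_merging():  # hypothetical helper, for illustration only
    merge = PatchMerging(input_resolution=(8, 8), dim=32)
    assert merge(torch.randn(2, 64, 32)).shape == (2, 16, 64)
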
class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self,
                 dim,
                 input_resolution,
                 depth,
                 num_heads,
                 window_size,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 use_checkpoint=False):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim,
                input_resolution=input_resolution,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else window_size // 2,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer) for i in range(depth)
        ])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x, x_size):
        for blk in self.blocks:
            if self.use_checkpoint:
                # x_size must be forwarded as well; the block's forward requires it
                x = checkpoint.checkpoint(blk, x, x_size)
            else:
                x = blk(x, x_size)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}'

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops

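# --- editorial sketch (not part of the original file) ------------------------
# A BasicLayer is just `depth` SwinTransformerBlocks that alternate between
# regular (shift 0) and shifted windows; token shape is unchanged when no
# downsample module is given.
def _demo_basic_layer():  # hypothetical helper, for illustration only
    layer = BasicLayer(dim=32, input_resolution=(8, 8), depth=2, num_heads=2, window_size=4)
    assert layer(torch.randn(2, 64, 32), (8, 8)).shape == (2, 64, 32)
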
class RSTB(nn.Module):
    """Residual Swin Transformer Block (RSTB).

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        img_size: Input image size.
        patch_size: Patch size.
        resi_connection: The convolutional block before residual connection.
    """

    def __init__(self,
                 dim,
                 input_resolution,
                 depth,
                 num_heads,
                 window_size,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 use_checkpoint=False,
                 img_size=224,
                 patch_size=4,
                 resi_connection='1conv'):
        super(RSTB, self).__init__()

        self.dim = dim
        self.input_resolution = input_resolution

        self.residual_group = BasicLayer(
            dim=dim,
            input_resolution=input_resolution,
            depth=depth,
            num_heads=num_heads,
            window_size=window_size,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop=drop,
            attn_drop=attn_drop,
            drop_path=drop_path,
            norm_layer=norm_layer,
            downsample=downsample,
            use_checkpoint=use_checkpoint)

        if resi_connection == '1conv':
            self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
        elif resi_connection == '3conv':
            # to save parameters and memory
            self.conv = nn.Sequential(
                nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Conv2d(dim // 4, dim, 3, 1, 1))

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)

        self.patch_unembed = PatchUnEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)

    def forward(self, x, x_size):
        return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x

    def flops(self):
        flops = 0
        flops += self.residual_group.flops()
        h, w = self.input_resolution
        flops += h * w * self.dim * self.dim * 9
        flops += self.patch_embed.flops()
        flops += self.patch_unembed.flops()

        return flops

class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        x = x.flatten(2).transpose(1, 2)  # b Ph*Pw c
        if self.norm is not None:
            x = self.norm(x)
        return x

    def flops(self):
        flops = 0
        h, w = self.img_size
        if self.norm is not None:
            flops += h * w * self.embed_dim
        return flops


class PatchUnEmbed(nn.Module):
    r""" Image to Patch Unembedding

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

    def forward(self, x, x_size):
        x = x.transpose(1, 2).view(x.shape[0], self.embed_dim, x_size[0], x_size[1])  # b c Ph Pw
        return x

    def flops(self):
        flops = 0
        return flops

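# --- editorial sketch (not part of the original file) ------------------------
# With patch_size=1 (the SwinIR default), PatchEmbed/PatchUnEmbed are exact
# inverse reshapes between (b, c, h, w) maps and (b, h*w, c) token sequences,
# and an RSTB keeps the token shape.
def _demo_patch_round_trip():  # hypothetical helper, for illustration only
    pe = PatchEmbed(embed_dim=32)
    pu = PatchUnEmbed(embed_dim=32)
    feat = torch.randn(2, 32, 8, 8)
    assert torch.equal(pu(pe(feat), (8, 8)), feat)  # lossless round trip
    rstb = RSTB(dim=32, input_resolution=(8, 8), depth=2, num_heads=2, window_size=4)
    assert rstb(pe(feat), (8, 8)).shape == (2, 64, 32)
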
class Upsample(nn.Sequential):
    """Upsample module.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.
    """

    def __init__(self, scale, num_feat):
        m = []
        if (scale & (scale - 1)) == 0:  # scale = 2^n
            for _ in range(int(math.log(scale, 2))):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*m)


class UpsampleOneStep(nn.Sequential):
    """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
       Used in lightweight SR to save parameters.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.
    """

    def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
        self.num_feat = num_feat
        self.input_resolution = input_resolution
        m = []
        m.append(nn.Conv2d(num_feat, (scale**2) * num_out_ch, 3, 1, 1))
        m.append(nn.PixelShuffle(scale))
        super(UpsampleOneStep, self).__init__(*m)

    def flops(self):
        h, w = self.input_resolution
        flops = h * w * self.num_feat * 3 * 9
        return flops

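# --- editorial sketch (not part of the original file) ------------------------
# Pixel-shuffle upsampling: a x4 Upsample is two (conv -> PixelShuffle(2))
# stages, so spatial dims grow 4x while the channel count is preserved.
def _demo_upsample():  # hypothetical helper, for illustration only
    up = Upsample(scale=4, num_feat=16)
    assert up(torch.randn(1, 16, 8, 8)).shape == (1, 16, 32, 32)
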
@ARCH_REGISTRY.register()
class SwinIR(nn.Module):
    r""" SwinIR
    A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.

    Args:
        img_size (int | tuple(int)): Input image size. Default 64
        patch_size (int | tuple(int)): Patch size. Default: 1
        in_chans (int): Number of input image channels. Default: 3
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
        upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
        img_range: Image range. 1. or 255.
        upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
        resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
    """

    def __init__(self,
                 img_size=64,
                 patch_size=1,
                 in_chans=3,
                 embed_dim=96,
                 depths=(6, 6, 6, 6),
                 num_heads=(6, 6, 6, 6),
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm,
                 ape=False,
                 patch_norm=True,
                 use_checkpoint=False,
                 upscale=2,
                 img_range=1.,
                 upsampler='',
                 resi_connection='1conv',
                 **kwargs):
        super(SwinIR, self).__init__()
        num_in_ch = in_chans
        num_out_ch = in_chans
        num_feat = 64
        self.img_range = img_range
        if in_chans == 3:
            rgb_mean = (0.4488, 0.4371, 0.4040)
            self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
        else:
            self.mean = torch.zeros(1, 1, 1, 1)
        self.upscale = upscale
        self.upsampler = upsampler

        # ------------------------- 1, shallow feature extraction ------------------------- #
        self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)

        # ------------------------- 2, deep feature extraction ------------------------- #
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = embed_dim
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=embed_dim,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # merge non-overlapping patches into image
        self.patch_unembed = PatchUnEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=embed_dim,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build Residual Swin Transformer blocks (RSTB)
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = RSTB(
                dim=embed_dim,
                input_resolution=(patches_resolution[0], patches_resolution[1]),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=self.mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],  # no impact on SR results
                norm_layer=norm_layer,
                downsample=None,
                use_checkpoint=use_checkpoint,
                img_size=img_size,
                patch_size=patch_size,
                resi_connection=resi_connection)
            self.layers.append(layer)
        self.norm = norm_layer(self.num_features)

        # build the last conv layer in deep feature extraction
        if resi_connection == '1conv':
            self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
        elif resi_connection == '3conv':
            # to save parameters and memory
            self.conv_after_body = nn.Sequential(
                nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))

        # ------------------------- 3, high quality image reconstruction ------------------------- #
        if self.upsampler == 'pixelshuffle':
            # for classical SR
            self.conv_before_upsample = nn.Sequential(
                nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
            self.upsample = Upsample(upscale, num_feat)
            self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
        elif self.upsampler == 'pixelshuffledirect':
            # for lightweight SR (to save parameters)
            self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
                                            (patches_resolution[0], patches_resolution[1]))
        elif self.upsampler == 'nearest+conv':
            # for real-world SR (less artifacts)
            assert self.upscale == 4, 'only support x4 now.'
            self.conv_before_upsample = nn.Sequential(
                nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
            self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
            self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        else:
            # for image denoising and JPEG compression artifact reduction
            self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward_features(self, x):
        x_size = (x.shape[2], x.shape[3])
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x, x_size)

        x = self.norm(x)  # b seq_len c
        x = self.patch_unembed(x, x_size)

        return x

    def forward(self, x):
        self.mean = self.mean.type_as(x)
        x = (x - self.mean) * self.img_range

        if self.upsampler == 'pixelshuffle':
            # for classical SR
            x = self.conv_first(x)
            x = self.conv_after_body(self.forward_features(x)) + x
            x = self.conv_before_upsample(x)
            x = self.conv_last(self.upsample(x))
        elif self.upsampler == 'pixelshuffledirect':
            # for lightweight SR
            x = self.conv_first(x)
            x = self.conv_after_body(self.forward_features(x)) + x
            x = self.upsample(x)
        elif self.upsampler == 'nearest+conv':
            # for real-world SR
            x = self.conv_first(x)
            x = self.conv_after_body(self.forward_features(x)) + x
            x = self.conv_before_upsample(x)
            x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
            x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
            x = self.conv_last(self.lrelu(self.conv_hr(x)))
        else:
            # for image denoising and JPEG compression artifact reduction
            x_first = self.conv_first(x)
            res = self.conv_after_body(self.forward_features(x_first)) + x_first
            x = x + self.conv_last(res)

        x = x / self.img_range + self.mean

        return x

    def flops(self):
        flops = 0
        h, w = self.patches_resolution
        flops += h * w * 3 * self.embed_dim * 9
        flops += self.patch_embed.flops()
        for layer in self.layers:
            flops += layer.flops()
        flops += h * w * 3 * self.embed_dim * self.embed_dim
        flops += self.upsample.flops()
        return flops

if __name__ == '__main__':
    upscale = 4
    window_size = 8
    height = (1024 // upscale // window_size + 1) * window_size
    width = (720 // upscale // window_size + 1) * window_size
    model = SwinIR(
        upscale=2,
        img_size=(height, width),
        window_size=window_size,
        img_range=1.,
        depths=[6, 6, 6, 6],
        embed_dim=60,
        num_heads=[6, 6, 6, 6],
        mlp_ratio=2,
        upsampler='pixelshuffledirect')
    print(model)
    print(height, width, model.flops() / 1e9)

    x = torch.randn((1, 3, height, width))
    x = model(x)
    print(x.shape)
StableSR/basicsr/archs/tof_arch.py
ADDED
@@ -0,0 +1,172 @@
import torch
from torch import nn as nn
from torch.nn import functional as F

from basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import flow_warp


class BasicModule(nn.Module):
    """Basic module of SPyNet.

    Note that unlike the architecture in spynet_arch.py, the basic module
    here contains batch normalization.
    """

    def __init__(self):
        super(BasicModule, self).__init__()
        self.basic_module = nn.Sequential(
            nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3, bias=False),
            nn.BatchNorm2d(32), nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3, bias=False),
            nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3, bias=False),
            nn.BatchNorm2d(32), nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3, bias=False),
            nn.BatchNorm2d(16), nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3))

    def forward(self, tensor_input):
        """
        Args:
            tensor_input (Tensor): Input tensor with shape (b, 8, h, w).
                8 channels contain:
                [reference image (3), neighbor image (3), initial flow (2)].

        Returns:
            Tensor: Estimated flow with shape (b, 2, h, w)
        """
        return self.basic_module(tensor_input)


class SPyNetTOF(nn.Module):
    """SPyNet architecture for TOF.

    Note that this implementation is specifically for TOFlow. Please use :file:`spynet_arch.py` for general use.
    They differ in the following aspects:

    1. The basic modules here contain BatchNorm.
    2. Normalization and denormalization are not done here, as they are done in TOFlow.

    ``Paper: Optical Flow Estimation using a Spatial Pyramid Network``

    Reference: https://github.com/Coldog2333/pytoflow

    Args:
        load_path (str): Path for pretrained SPyNet. Default: None.
    """

    def __init__(self, load_path=None):
        super(SPyNetTOF, self).__init__()

        self.basic_module = nn.ModuleList([BasicModule() for _ in range(4)])
        if load_path:
            self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params'])

    def forward(self, ref, supp):
        """
        Args:
            ref (Tensor): Reference image with shape of (b, 3, h, w).
            supp: The supporting image to be warped: (b, 3, h, w).

        Returns:
            Tensor: Estimated optical flow: (b, 2, h, w).
        """
        num_batches, _, h, w = ref.size()
        ref = [ref]
        supp = [supp]

        # generate downsampled frames
        for _ in range(3):
            ref.insert(0, F.avg_pool2d(input=ref[0], kernel_size=2, stride=2, count_include_pad=False))
            supp.insert(0, F.avg_pool2d(input=supp[0], kernel_size=2, stride=2, count_include_pad=False))

        # flow computation
        flow = ref[0].new_zeros(num_batches, 2, h // 16, w // 16)
        for i in range(4):
            flow_up = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0
            flow = flow_up + self.basic_module[i](
                torch.cat([ref[i], flow_warp(supp[i], flow_up.permute(0, 2, 3, 1)), flow_up], 1))
        return flow


@ARCH_REGISTRY.register()
class TOFlow(nn.Module):
    """PyTorch implementation of TOFlow.

    In TOFlow, the LR frames are pre-upsampled and have the same size as the GT frames.

    ``Paper: Video Enhancement with Task-Oriented Flow``

    Reference: https://github.com/anchen1011/toflow

    Reference: https://github.com/Coldog2333/pytoflow

    Args:
        adapt_official_weights (bool): Whether to adapt the weights translated
            from the official implementation. Set to false if you want to
            train from scratch. Default: False
    """

    def __init__(self, adapt_official_weights=False):
        super(TOFlow, self).__init__()
        self.adapt_official_weights = adapt_official_weights
        self.ref_idx = 0 if adapt_official_weights else 3

        self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

        # flow estimation module
        self.spynet = SPyNetTOF()

        # reconstruction module
        self.conv_1 = nn.Conv2d(3 * 7, 64, 9, 1, 4)
        self.conv_2 = nn.Conv2d(64, 64, 9, 1, 4)
        self.conv_3 = nn.Conv2d(64, 64, 1)
        self.conv_4 = nn.Conv2d(64, 3, 1)

        # activation function
        self.relu = nn.ReLU(inplace=True)

    def normalize(self, img):
        return (img - self.mean) / self.std

    def denormalize(self, img):
        return img * self.std + self.mean

    def forward(self, lrs):
        """
        Args:
            lrs: Input lr frames: (b, 7, 3, h, w).

        Returns:
            Tensor: SR frame: (b, 3, h, w).
        """
        # In the official implementation, the 0-th frame is the reference frame
        if self.adapt_official_weights:
            lrs = lrs[:, [3, 0, 1, 2, 4, 5, 6], :, :, :]

        num_batches, num_lrs, _, h, w = lrs.size()

        lrs = self.normalize(lrs.view(-1, 3, h, w))
        lrs = lrs.view(num_batches, num_lrs, 3, h, w)

        lr_ref = lrs[:, self.ref_idx, :, :, :]
        lr_aligned = []
        for i in range(7):  # 7 frames
            if i == self.ref_idx:
                lr_aligned.append(lr_ref)
            else:
                lr_supp = lrs[:, i, :, :, :]
                flow = self.spynet(lr_ref, lr_supp)
                lr_aligned.append(flow_warp(lr_supp, flow.permute(0, 2, 3, 1)))

        # reconstruction
        hr = torch.stack(lr_aligned, dim=1)
        hr = hr.view(num_batches, -1, h, w)
        hr = self.relu(self.conv_1(hr))
        hr = self.relu(self.conv_2(hr))
        hr = self.relu(self.conv_3(hr))
        hr = self.conv_4(hr) + lr_ref

        return self.denormalize(hr)
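
# --- editorial sketch (not part of the original file) ------------------------
# TOFlow consumes 7 pre-upsampled LR frames and returns one restored frame of
# the same spatial size. h and w must be divisible by 16 so the 4-level SPyNet
# pyramid lines up.
def _demo_toflow():  # hypothetical helper, for illustration only
    net = TOFlow()
    clip = torch.rand(1, 7, 3, 64, 64)  # (b, 7 frames, rgb, h, w) in [0, 1]
    assert net(clip).shape == (1, 3, 64, 64)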
StableSR/basicsr/archs/vgg_arch.py
ADDED
@@ -0,0 +1,161 @@
import os
import torch
from collections import OrderedDict
from torch import nn as nn
from torchvision.models import vgg as vgg

from basicsr.utils.registry import ARCH_REGISTRY

VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth'
NAMES = {
    'vgg11': [
        'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
        'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
        'pool5'
    ],
    'vgg13': [
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4',
        'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
    ],
    'vgg16': [
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
        'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
        'pool5'
    ],
    'vgg19': [
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1',
        'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
        'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'
    ]
}


def insert_bn(names):
    """Insert bn layer after each conv.

    Args:
        names (list): The list of layer names.

    Returns:
        list: The list of layer names with bn layers.
    """
    names_bn = []
    for name in names:
        names_bn.append(name)
        if 'conv' in name:
            position = name.replace('conv', '')
            names_bn.append('bn' + position)
    return names_bn


@ARCH_REGISTRY.register()
class VGGFeatureExtractor(nn.Module):
    """VGG network for feature extraction.

    In this implementation, we allow users to choose whether to use
    normalization on the input feature and the type of vgg network. Note that
    the pretrained path must fit the vgg type.

    Args:
        layer_name_list (list[str]): Forward function returns the corresponding
            features according to the layer_name_list.
            Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
        vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
        use_input_norm (bool): If True, normalize the input image. Importantly,
            the input feature must be in the range [0, 1]. Default: True.
        range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
            Default: False.
        requires_grad (bool): If true, the parameters of VGG network will be
            optimized. Default: False.
        remove_pooling (bool): If true, the max pooling operations in VGG net
            will be removed. Default: False.
        pooling_stride (int): The stride of max pooling operation. Default: 2.
    """

    def __init__(self,
                 layer_name_list,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 range_norm=False,
                 requires_grad=False,
                 remove_pooling=False,
                 pooling_stride=2):
        super(VGGFeatureExtractor, self).__init__()

        self.layer_name_list = layer_name_list
        self.use_input_norm = use_input_norm
        self.range_norm = range_norm

        self.names = NAMES[vgg_type.replace('_bn', '')]
        if 'bn' in vgg_type:
            self.names = insert_bn(self.names)

        # only borrow layers that will be used to avoid unused params
        max_idx = 0
        for v in layer_name_list:
            idx = self.names.index(v)
            if idx > max_idx:
                max_idx = idx

        if os.path.exists(VGG_PRETRAIN_PATH):
            vgg_net = getattr(vgg, vgg_type)(pretrained=False)
            state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage)
            vgg_net.load_state_dict(state_dict)
        else:
            vgg_net = getattr(vgg, vgg_type)(pretrained=True)

        features = vgg_net.features[:max_idx + 1]

        modified_net = OrderedDict()
        for k, v in zip(self.names, features):
            if 'pool' in k:
                # if remove_pooling is true, pooling operation will be removed
                if remove_pooling:
                    continue
                else:
                    # in some cases, we may want to change the default stride
                    modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
            else:
                modified_net[k] = v

        self.vgg_net = nn.Sequential(modified_net)

        if not requires_grad:
            self.vgg_net.eval()
            for param in self.parameters():
                param.requires_grad = False
        else:
            self.vgg_net.train()
            for param in self.parameters():
                param.requires_grad = True

        if self.use_input_norm:
            # the mean is for image with range [0, 1]
            self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
            # the std is for image with range [0, 1]
            self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        if self.range_norm:
            x = (x + 1) / 2
        if self.use_input_norm:
            x = (x - self.mean) / self.std

        output = {}
        for key, layer in self.vgg_net._modules.items():
            x = layer(x)
            if key in self.layer_name_list:
                output[key] = x.clone()

        return output
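
# --- editorial sketch (not part of the original file) ------------------------
# Typical perceptual-loss usage: pick ReLU feature maps by name. Note that
# constructing the extractor downloads the torchvision VGG19 weights unless
# VGG_PRETRAIN_PATH exists on disk.
def _demo_vgg_features():  # hypothetical helper, for illustration only
    extractor = VGGFeatureExtractor(layer_name_list=['relu2_2', 'relu3_4'])
    feats = extractor(torch.rand(1, 3, 128, 128))  # input in [0, 1]
    # expected: relu2_2 -> (1, 128, 64, 64), relu3_4 -> (1, 256, 32, 32)
    print({k: tuple(v.shape) for k, v in feats.items()})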
StableSR/basicsr/data/__init__.py
ADDED
@@ -0,0 +1,101 @@
import importlib
import numpy as np
import random
import torch
import torch.utils.data
from copy import deepcopy
from functools import partial
from os import path as osp

from basicsr.data.prefetch_dataloader import PrefetchDataLoader
from basicsr.utils import get_root_logger, scandir
from basicsr.utils.dist_util import get_dist_info
from basicsr.utils.registry import DATASET_REGISTRY

__all__ = ['build_dataset', 'build_dataloader']

# automatically scan and import dataset modules for registry
# scan all the files under the data folder with '_dataset' in file names
data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
# import all the dataset modules
_dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames]


def build_dataset(dataset_opt):
    """Build dataset from options.

    Args:
        dataset_opt (dict): Configuration for dataset. It must contain:
            name (str): Dataset name.
            type (str): Dataset type.
    """
    dataset_opt = deepcopy(dataset_opt)
    dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt)
    logger = get_root_logger()
    logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.')
    return dataset


def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
    """Build dataloader.

    Args:
        dataset (torch.utils.data.Dataset): Dataset.
        dataset_opt (dict): Dataset options. It contains the following keys:
            phase (str): 'train' or 'val'.
            num_worker_per_gpu (int): Number of workers for each GPU.
            batch_size_per_gpu (int): Training batch size for each GPU.
        num_gpu (int): Number of GPUs. Used only in the train phase.
            Default: 1.
        dist (bool): Whether in distributed training. Used only in the train
            phase. Default: False.
        sampler (torch.utils.data.sampler): Data sampler. Default: None.
        seed (int | None): Seed. Default: None
    """
    phase = dataset_opt['phase']
    rank, _ = get_dist_info()
    if phase == 'train':
        if dist:  # distributed training
            batch_size = dataset_opt['batch_size_per_gpu']
            num_workers = dataset_opt['num_worker_per_gpu']
        else:  # non-distributed training
            multiplier = 1 if num_gpu == 0 else num_gpu
            batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
            num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
        dataloader_args = dict(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            sampler=sampler,
            drop_last=True)
        if sampler is None:
            dataloader_args['shuffle'] = True
        dataloader_args['worker_init_fn'] = partial(
            worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
    elif phase in ['val', 'test']:  # validation
        dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
    else:
        raise ValueError(f"Wrong dataset phase: {phase}. Supported ones are 'train', 'val' and 'test'.")

    dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
    dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False)

    prefetch_mode = dataset_opt.get('prefetch_mode')
    if prefetch_mode == 'cpu':  # CPUPrefetcher
        num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
        logger = get_root_logger()
        logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}')
        return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
    else:
        # prefetch_mode=None: Normal dataloader
        # prefetch_mode='cuda': dataloader for CUDAPrefetcher
        return torch.utils.data.DataLoader(**dataloader_args)


def worker_init_fn(worker_id, num_workers, rank, seed):
    # Set the worker seed to num_workers * rank + worker_id + seed
    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)
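
# --- editorial sketch (not part of the original file) ------------------------
# How the two builders fit together. The options dict below is a hypothetical
# placeholder (dataset type, paths, sizes) that mirrors the docstrings above;
# it assumes a 'PairedImageDataset' is registered via the auto-import at the
# top of this module.
def _demo_build_train_loader():  # hypothetical helper, for illustration only
    opt = {
        'name': 'demo', 'type': 'PairedImageDataset', 'phase': 'train',
        'dataroot_gt': 'datasets/gt', 'dataroot_lq': 'datasets/lq',
        'io_backend': {'type': 'disk'}, 'scale': 4, 'gt_size': 128,
        'use_hflip': True, 'use_rot': True,
        'batch_size_per_gpu': 4, 'num_worker_per_gpu': 2,
    }
    train_set = build_dataset(opt)
    return build_dataloader(train_set, opt, num_gpu=1, dist=False, seed=0)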
StableSR/basicsr/data/data_sampler.py
ADDED
@@ -0,0 +1,48 @@
import math
import torch
from torch.utils.data.sampler import Sampler


class EnlargedSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    Modified from torch.utils.data.distributed.DistributedSampler
    Support enlarging the dataset for iteration-based training, for saving
    time when restarting the dataloader after each epoch.

    Args:
        dataset (torch.utils.data.Dataset): Dataset used for sampling.
        num_replicas (int | None): Number of processes participating in
            the training. It is usually the world_size.
        rank (int | None): Rank of the current process within num_replicas.
        ratio (int): Enlarging ratio. Default: 1.
    """

    def __init__(self, dataset, num_replicas, rank, ratio=1):
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        indices = torch.randperm(self.total_size, generator=g).tolist()

        dataset_size = len(self.dataset)
        indices = [v % dataset_size for v in indices]

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
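
# --- editorial sketch (not part of the original file) ------------------------
# With ratio > 1 the sampler virtually enlarges the dataset, so one "epoch" of
# the dataloader covers several passes without restarting it. Any object with
# __len__ works here for illustration.
def _demo_enlarged_sampler():  # hypothetical helper, for illustration only
    sampler = EnlargedSampler(list(range(100)), num_replicas=1, rank=0, ratio=10)
    assert len(sampler) == 1000
    sampler.set_epoch(0)  # reseed the shuffle each epoch
    assert all(0 <= idx < 100 for idx in sampler)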
StableSR/basicsr/data/data_util.py
ADDED
@@ -0,0 +1,362 @@
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from os import path as osp
|
5 |
+
from torch.nn import functional as F
|
6 |
+
|
7 |
+
from basicsr.data.transforms import mod_crop
|
8 |
+
from basicsr.utils import img2tensor, scandir
|
9 |
+
|
10 |
+
|
11 |
+
def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False):
|
12 |
+
"""Read a sequence of images from a given folder path.
|
13 |
+
|
14 |
+
Args:
|
15 |
+
path (list[str] | str): List of image paths or image folder path.
|
16 |
+
require_mod_crop (bool): Require mod crop for each image.
|
17 |
+
Default: False.
|
18 |
+
scale (int): Scale factor for mod_crop. Default: 1.
|
19 |
+
return_imgname(bool): Whether return image names. Default False.
|
20 |
+
|
21 |
+
Returns:
|
22 |
+
Tensor: size (t, c, h, w), RGB, [0, 1].
|
23 |
+
list[str]: Returned image name list.
|
24 |
+
"""
|
25 |
+
if isinstance(path, list):
|
26 |
+
img_paths = path
|
27 |
+
else:
|
28 |
+
img_paths = sorted(list(scandir(path, full_path=True)))
|
29 |
+
imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]
|
30 |
+
|
31 |
+
if require_mod_crop:
|
32 |
+
imgs = [mod_crop(img, scale) for img in imgs]
|
33 |
+
imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
|
34 |
+
imgs = torch.stack(imgs, dim=0)
|
35 |
+
|
36 |
+
if return_imgname:
|
37 |
+
imgnames = [osp.splitext(osp.basename(path))[0] for path in img_paths]
|
38 |
+
return imgs, imgnames
|
39 |
+
else:
|
40 |
+
return imgs
|
41 |
+
|
42 |
+
|

def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
    """Generate an index list for reading `num_frames` frames from a sequence
    of images.

    Args:
        crt_idx (int): Current center index.
        max_frame_num (int): Max number of the sequence of images (from 1).
        num_frames (int): Reading num_frames frames.
        padding (str): Padding mode, one of
            'replicate' | 'reflection' | 'reflection_circle' | 'circle'
            Examples: current_idx = 0, num_frames = 5
            The generated frame indices under different padding modes:
            replicate: [0, 0, 0, 1, 2]
            reflection: [2, 1, 0, 1, 2]
            reflection_circle: [4, 3, 0, 1, 2]
            circle: [3, 4, 0, 1, 2]

    Returns:
        list[int]: A list of indices.
    """
    assert num_frames % 2 == 1, 'num_frames should be an odd number.'
    assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'

    max_frame_num = max_frame_num - 1  # start from 0
    num_pad = num_frames // 2

    indices = []
    for i in range(crt_idx - num_pad, crt_idx + num_pad + 1):
        if i < 0:
            if padding == 'replicate':
                pad_idx = 0
            elif padding == 'reflection':
                pad_idx = -i
            elif padding == 'reflection_circle':
                pad_idx = crt_idx + num_pad - i
            else:
                pad_idx = num_frames + i
        elif i > max_frame_num:
            if padding == 'replicate':
                pad_idx = max_frame_num
            elif padding == 'reflection':
                pad_idx = max_frame_num * 2 - i
            elif padding == 'reflection_circle':
                pad_idx = (crt_idx - num_pad) - (i - max_frame_num)
            else:
                pad_idx = i - num_frames
        else:
            pad_idx = i
        indices.append(pad_idx)
    return indices
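
# --- Illustrative sketch (not part of the original file) ---
# Padding at the *end* of a 100-frame sequence (crt_idx=99, num_frames=5):
#
#     generate_frame_indices(99, 100, 5, padding='replicate')   # [97, 98, 99, 99, 99]
#     generate_frame_indices(99, 100, 5, padding='reflection')  # [97, 98, 99, 98, 97]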

def paired_paths_from_lmdb(folders, keys):
    """Generate paired paths from lmdb files.

    Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is:

    ::

        lq.lmdb
        ├── data.mdb
        ├── lock.mdb
        ├── meta_info.txt

    The data.mdb and lock.mdb are standard lmdb files and you can refer to
    https://lmdb.readthedocs.io/en/release/ for more details.

    The meta_info.txt is a specified txt file to record the meta information
    of our datasets. It will be automatically created when preparing
    datasets by our provided dataset tools.
    Each line in the txt file records
    1) image name (with extension),
    2) image shape,
    3) compression level, separated by a white space.
    Example: `baboon.png (120,125,3) 1`

    We use the image name without extension as the lmdb key.
    Note that we use the same key for the corresponding lq and gt images.

    Args:
        folders (list[str]): A list of folder paths. The order of the list
            should be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be consistent with folders, e.g., ['lq', 'gt'].
            Note that this key is different from lmdb keys.

    Returns:
        list[str]: Returned path list.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
        raise ValueError(f'{input_key} folder and {gt_key} folder should both be in lmdb '
                         f'format. But received {input_key}: {input_folder}; '
                         f'{gt_key}: {gt_folder}')
    # ensure that the two meta_info files are the same
    with open(osp.join(input_folder, 'meta_info.txt')) as fin:
        input_lmdb_keys = [line.split('.')[0] for line in fin]
    with open(osp.join(gt_folder, 'meta_info.txt')) as fin:
        gt_lmdb_keys = [line.split('.')[0] for line in fin]
    if set(input_lmdb_keys) != set(gt_lmdb_keys):
        raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')
    else:
        paths = []
        for lmdb_key in sorted(input_lmdb_keys):
            paths.append(dict([(f'{input_key}_path', lmdb_key), (f'{gt_key}_path', lmdb_key)]))
        return paths
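
# --- Illustrative sketch (not part of the original file) ---
# With keys=['lq', 'gt'] and meta_info.txt entries for 0001.png and 0002.png,
# the returned structure is a list of per-pair dicts keyed by lmdb key:
#
#     [{'lq_path': '0001', 'gt_path': '0001'},
#      {'lq_path': '0002', 'gt_path': '0002'}]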

def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
    """Generate paired paths from a meta information file.

    Each line in the meta information file contains the image names and
    image shape (usually for gt), separated by a white space.

    Example of a meta information file:
    ```
    0001_s001.png (480,480,3)
    0001_s002.png (480,480,3)
    ```

    Args:
        folders (list[str]): A list of folder paths. The order of the list
            should be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be consistent with folders, e.g., ['lq', 'gt'].
        meta_info_file (str): Path to the meta information file.
        filename_tmpl (str): Template for each filename. Note that the
            template excludes the file extension. Usually the filename_tmpl is
            for files in the input folder.

    Returns:
        list[str]: Returned path list.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    with open(meta_info_file, 'r') as fin:
        gt_names = [line.strip().split(' ')[0] for line in fin]

    paths = []
    for gt_name in gt_names:
        basename, ext = osp.splitext(osp.basename(gt_name))
        input_name = f'{filename_tmpl.format(basename)}{ext}'
        input_path = osp.join(input_folder, input_name)
        gt_path = osp.join(gt_folder, gt_name)
        paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
    return paths

def paired_paths_from_meta_info_file_2(folders, keys, meta_info_file, filename_tmpl):
    """Generate paired paths from a meta information file.

    Each line in the meta information file contains the gt and the input
    image names, separated by a white space.

    Args:
        folders (list[str]): A list of folder paths. The order of the list
            should be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be consistent with folders, e.g., ['lq', 'gt'].
        meta_info_file (str): Path to the meta information file.
        filename_tmpl (str): Template for each filename. Note that the
            template excludes the file extension. Usually the filename_tmpl is
            for files in the input folder.

    Returns:
        list[str]: Returned path list.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    # read both name columns in a single pass over the file
    with open(meta_info_file, 'r') as fin:
        name_pairs = [line.strip().split(' ') for line in fin]

    paths = []
    for pair in name_pairs:
        gt_name, lq_name = pair[0], pair[1]
        gt_path = osp.join(gt_folder, gt_name)
        input_path = osp.join(input_folder, lq_name)
        paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
    return paths

def paired_paths_from_folder(folders, keys, filename_tmpl):
    """Generate paired paths from folders.

    Args:
        folders (list[str]): A list of folder paths. The order of the list
            should be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be consistent with folders, e.g., ['lq', 'gt'].
        filename_tmpl (str): Template for each filename. Note that the
            template excludes the file extension. Usually the filename_tmpl is
            for files in the input folder.

    Returns:
        list[str]: Returned path list.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    input_paths = list(scandir(input_folder))
    gt_paths = list(scandir(gt_folder))
    assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: '
                                               f'{len(input_paths)}, {len(gt_paths)}.')
    paths = []
    for gt_path in gt_paths:
        basename, ext = osp.splitext(osp.basename(gt_path))
        input_name = f'{filename_tmpl.format(basename)}{ext}'
        input_path = osp.join(input_folder, input_name)
        assert input_name in input_paths, f'{input_name} is not in {input_key}_paths.'
        gt_path = osp.join(gt_folder, gt_path)
        paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
    return paths

def paths_from_folder(folder):
    """Generate paths from folder.

    Args:
        folder (str): Folder path.

    Returns:
        list[str]: Returned path list.
    """

    paths = list(scandir(folder))
    paths = [osp.join(folder, path) for path in paths]
    return paths


def paths_from_lmdb(folder):
    """Generate paths from lmdb.

    Args:
        folder (str): Folder path.

    Returns:
        list[str]: Returned path list.
    """
    if not folder.endswith('.lmdb'):
        raise ValueError(f'Folder {folder} should be in lmdb format.')
    with open(osp.join(folder, 'meta_info.txt')) as fin:
        paths = [line.split('.')[0] for line in fin]
    return paths

def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
    """Generate Gaussian kernel used in `duf_downsample`.

    Args:
        kernel_size (int): Kernel size. Default: 13.
        sigma (float): Sigma of the Gaussian kernel. Default: 1.6.

    Returns:
        np.array: The Gaussian kernel.
    """
    from scipy.ndimage import filters as filters
    kernel = np.zeros((kernel_size, kernel_size))
    # set element at the middle to one, a dirac delta
    kernel[kernel_size // 2, kernel_size // 2] = 1
    # gaussian-smooth the dirac, resulting in a gaussian filter
    return filters.gaussian_filter(kernel, sigma)


def duf_downsample(x, kernel_size=13, scale=4):
    """Downsampling with Gaussian kernel used in the DUF official code.

    Args:
        x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w).
        kernel_size (int): Kernel size. Default: 13.
        scale (int): Downsampling factor. Supported scale: (2, 3, 4).
            Default: 4.

    Returns:
        Tensor: DUF downsampled frames.
    """
    assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'

    squeeze_flag = False
    if x.ndim == 4:
        squeeze_flag = True
        x = x.unsqueeze(0)
    b, t, c, h, w = x.size()
    x = x.view(-1, 1, h, w)
    pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2
    x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect')

    gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale)
    gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0)
    x = F.conv2d(x, gaussian_filter, stride=scale)
    x = x[:, :, 2:-2, 2:-2]
    x = x.view(b, t, c, x.size(2), x.size(3))
    if squeeze_flag:
        x = x.squeeze(0)
    return x
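
# --- Illustrative sketch (not part of the original file) ---
# Downsample a (t, c, h, w) clip by 4x; a 4-D input is treated as a single
# batch element internally and squeezed back on return:
#
#     frames = torch.rand(5, 3, 256, 256)
#     lr = duf_downsample(frames, kernel_size=13, scale=4)  # -> (5, 3, 64, 64)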
StableSR/basicsr/data/degradations.py
ADDED
@@ -0,0 +1,935 @@
import cv2
import math
import numpy as np
import random
import torch
from scipy import special
from scipy.stats import multivariate_normal
from torchvision.transforms.functional_tensor import rgb_to_grayscale

# -------------------------------------------------------------------- #
# --------------------------- blur kernels --------------------------- #
# -------------------------------------------------------------------- #


# --------------------------- util functions --------------------------- #
def sigma_matrix2(sig_x, sig_y, theta):
    """Calculate the rotated sigma matrix (two dimensional matrix).

    Args:
        sig_x (float):
        sig_y (float):
        theta (float): Radian measurement.

    Returns:
        ndarray: Rotated sigma matrix.
    """
    d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]])
    u_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
    return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T))


def mesh_grid(kernel_size):
    """Generate the mesh grid, centering at zero.

    Args:
        kernel_size (int):

    Returns:
        xy (ndarray): with the shape (kernel_size, kernel_size, 2)
        xx (ndarray): with the shape (kernel_size, kernel_size)
        yy (ndarray): with the shape (kernel_size, kernel_size)
    """
    ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
    xx, yy = np.meshgrid(ax, ax)
    xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)), yy.reshape(kernel_size * kernel_size,
                                                                           1))).reshape(kernel_size, kernel_size, 2)
    return xy, xx, yy


def pdf2(sigma_matrix, grid):
    """Calculate PDF of the bivariate Gaussian distribution.

    Args:
        sigma_matrix (ndarray): with the shape (2, 2)
        grid (ndarray): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size.

    Returns:
        kernel (ndarray): un-normalized kernel.
    """
    inverse_sigma = np.linalg.inv(sigma_matrix)
    kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
    return kernel


def cdf2(d_matrix, grid):
    """Calculate the CDF of the standard bivariate Gaussian distribution.
    Used in skewed Gaussian distribution.

    Args:
        d_matrix (ndarray): skew matrix.
        grid (ndarray): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size.

    Returns:
        cdf (ndarray): skewed cdf.
    """
    rv = multivariate_normal([0, 0], [[1, 0], [0, 1]])
    grid = np.dot(grid, d_matrix)
    cdf = rv.cdf(grid)
    return cdf


def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True):
    """Generate a bivariate isotropic or anisotropic Gaussian kernel.

    In the isotropic mode, only `sig_x` is used; `sig_y` and `theta` are ignored.

    Args:
        kernel_size (int):
        sig_x (float):
        sig_y (float):
        theta (float): Radian measurement.
        grid (ndarray, optional): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size. Default: None
        isotropic (bool):

    Returns:
        kernel (ndarray): normalized kernel.
    """
    if grid is None:
        grid, _, _ = mesh_grid(kernel_size)
    if isotropic:
        sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
    else:
        sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
    kernel = pdf2(sigma_matrix, grid)
    kernel = kernel / np.sum(kernel)
    return kernel
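
# --- Illustrative sketch (not part of the original file) ---
# A 21x21 anisotropic Gaussian blur kernel rotated by ~30 degrees; kernels
# returned by the helpers above are normalized to sum to 1:
#
#     k = bivariate_Gaussian(21, sig_x=3.0, sig_y=1.0, theta=np.pi / 6, isotropic=False)
#     assert abs(k.sum() - 1.0) < 1e-6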

def bivariate_generalized_Gaussian(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
    """Generate a bivariate generalized Gaussian kernel.

    ``Paper: Parameter Estimation For Multivariate Generalized Gaussian Distributions``

    In the isotropic mode, only `sig_x` is used; `sig_y` and `theta` are ignored.

    Args:
        kernel_size (int):
        sig_x (float):
        sig_y (float):
        theta (float): Radian measurement.
        beta (float): shape parameter, beta = 1 is the normal distribution.
        grid (ndarray, optional): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size. Default: None

    Returns:
        kernel (ndarray): normalized kernel.
    """
    if grid is None:
        grid, _, _ = mesh_grid(kernel_size)
    if isotropic:
        sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
    else:
        sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
    inverse_sigma = np.linalg.inv(sigma_matrix)
    kernel = np.exp(-0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta))
    kernel = kernel / np.sum(kernel)
    return kernel


def bivariate_plateau(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
    """Generate a plateau-like isotropic or anisotropic kernel.

    1 / (1 + x^(beta))

    Reference: https://stats.stackexchange.com/questions/203629/is-there-a-plateau-shaped-distribution

    In the isotropic mode, only `sig_x` is used; `sig_y` and `theta` are ignored.

    Args:
        kernel_size (int):
        sig_x (float):
        sig_y (float):
        theta (float): Radian measurement.
        beta (float): shape parameter, beta = 1 is the normal distribution.
        grid (ndarray, optional): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size. Default: None

    Returns:
        kernel (ndarray): normalized kernel.
    """
    if grid is None:
        grid, _, _ = mesh_grid(kernel_size)
    if isotropic:
        sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
    else:
        sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
    inverse_sigma = np.linalg.inv(sigma_matrix)
    kernel = np.reciprocal(np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1)
    kernel = kernel / np.sum(kernel)
    return kernel

def random_bivariate_Gaussian(kernel_size,
                              sigma_x_range,
                              sigma_y_range,
                              rotation_range,
                              noise_range=None,
                              isotropic=True,
                              return_sigma=False):
    """Randomly generate bivariate isotropic or anisotropic Gaussian kernels.

    In the isotropic mode, only `sigma_x_range` is used; `sigma_y_range` and `rotation_range` are ignored.

    Args:
        kernel_size (int):
        sigma_x_range (tuple): [0.6, 5]
        sigma_y_range (tuple): [0.6, 5]
        rotation_range (tuple): [-math.pi, math.pi]
        noise_range (tuple, optional): multiplicative kernel noise,
            [0.75, 1.25]. Default: None

    Returns:
        kernel (ndarray):
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
    if isotropic is False:
        assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
        assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
        sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
        rotation = np.random.uniform(rotation_range[0], rotation_range[1])
    else:
        sigma_y = sigma_x
        rotation = 0

    kernel = bivariate_Gaussian(kernel_size, sigma_x, sigma_y, rotation, isotropic=isotropic)

    # add multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
        kernel = kernel * noise
    kernel = kernel / np.sum(kernel)
    if not return_sigma:
        return kernel
    else:
        return kernel, [sigma_x, sigma_y]


def random_bivariate_generalized_Gaussian(kernel_size,
                                          sigma_x_range,
                                          sigma_y_range,
                                          rotation_range,
                                          beta_range,
                                          noise_range=None,
                                          isotropic=True,
                                          return_sigma=False):
    """Randomly generate bivariate generalized Gaussian kernels.

    In the isotropic mode, only `sigma_x_range` is used; `sigma_y_range` and `rotation_range` are ignored.

    Args:
        kernel_size (int):
        sigma_x_range (tuple): [0.6, 5]
        sigma_y_range (tuple): [0.6, 5]
        rotation_range (tuple): [-math.pi, math.pi]
        beta_range (tuple): [0.5, 8]
        noise_range (tuple, optional): multiplicative kernel noise,
            [0.75, 1.25]. Default: None

    Returns:
        kernel (ndarray):
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
    if isotropic is False:
        assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
        assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
        sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
        rotation = np.random.uniform(rotation_range[0], rotation_range[1])
    else:
        sigma_y = sigma_x
        rotation = 0

    # assume beta_range[0] < 1 < beta_range[1]
    if np.random.uniform() < 0.5:
        beta = np.random.uniform(beta_range[0], 1)
    else:
        beta = np.random.uniform(1, beta_range[1])

    kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)

    # add multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
        kernel = kernel * noise
    kernel = kernel / np.sum(kernel)
    if not return_sigma:
        return kernel
    else:
        return kernel, [sigma_x, sigma_y]


def random_bivariate_plateau(kernel_size,
                             sigma_x_range,
                             sigma_y_range,
                             rotation_range,
                             beta_range,
                             noise_range=None,
                             isotropic=True,
                             return_sigma=False):
    """Randomly generate bivariate plateau kernels.

    In the isotropic mode, only `sigma_x_range` is used; `sigma_y_range` and `rotation_range` are ignored.

    Args:
        kernel_size (int):
        sigma_x_range (tuple): [0.6, 5]
        sigma_y_range (tuple): [0.6, 5]
        rotation_range (tuple): [-math.pi/2, math.pi/2]
        beta_range (tuple): [1, 4]
        noise_range (tuple, optional): multiplicative kernel noise,
            [0.75, 1.25]. Default: None

    Returns:
        kernel (ndarray):
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
    if isotropic is False:
        assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
        assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
        sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
        rotation = np.random.uniform(rotation_range[0], rotation_range[1])
    else:
        sigma_y = sigma_x
        rotation = 0

    # TODO: this may not be proper
    if np.random.uniform() < 0.5:
        beta = np.random.uniform(beta_range[0], 1)
    else:
        beta = np.random.uniform(1, beta_range[1])

    kernel = bivariate_plateau(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)
    # add multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
        kernel = kernel * noise
    kernel = kernel / np.sum(kernel)

    if not return_sigma:
        return kernel
    else:
        return kernel, [sigma_x, sigma_y]

def random_mixed_kernels(kernel_list,
                         kernel_prob,
                         kernel_size=21,
                         sigma_x_range=(0.6, 5),
                         sigma_y_range=(0.6, 5),
                         rotation_range=(-math.pi, math.pi),
                         betag_range=(0.5, 8),
                         betap_range=(0.5, 8),
                         noise_range=None,
                         return_sigma=False):
    """Randomly generate mixed kernels.

    Args:
        kernel_list (tuple): a list of kernel type names, supporting
            ['iso', 'aniso', 'skew', 'generalized', 'plateau_iso',
            'plateau_aniso']
        kernel_prob (tuple): corresponding kernel probability for each
            kernel type
        kernel_size (int):
        sigma_x_range (tuple): [0.6, 5]
        sigma_y_range (tuple): [0.6, 5]
        rotation_range (tuple): [-math.pi, math.pi]
        betag_range (tuple): [0.5, 8]
        betap_range (tuple): [0.5, 8]
        noise_range (tuple, optional): multiplicative kernel noise,
            [0.75, 1.25]. Default: None

    Returns:
        kernel (ndarray), or (kernel, [sigma_x, sigma_y]) if `return_sigma`.
    """
    # `return_sigma` is forwarded to the sampled generator, so one dispatch
    # covers both return shapes (kernel only, or kernel plus sigma list)
    kernel_type = random.choices(kernel_list, kernel_prob)[0]
    if kernel_type == 'iso':
        result = random_bivariate_Gaussian(
            kernel_size, sigma_x_range, sigma_y_range, rotation_range,
            noise_range=noise_range, isotropic=True, return_sigma=return_sigma)
    elif kernel_type == 'aniso':
        result = random_bivariate_Gaussian(
            kernel_size, sigma_x_range, sigma_y_range, rotation_range,
            noise_range=noise_range, isotropic=False, return_sigma=return_sigma)
    elif kernel_type == 'generalized_iso':
        result = random_bivariate_generalized_Gaussian(
            kernel_size, sigma_x_range, sigma_y_range, rotation_range, betag_range,
            noise_range=noise_range, isotropic=True, return_sigma=return_sigma)
    elif kernel_type == 'generalized_aniso':
        result = random_bivariate_generalized_Gaussian(
            kernel_size, sigma_x_range, sigma_y_range, rotation_range, betag_range,
            noise_range=noise_range, isotropic=False, return_sigma=return_sigma)
    elif kernel_type == 'plateau_iso':
        # note: kernel noise is not applied to plateau kernels here
        result = random_bivariate_plateau(
            kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range,
            noise_range=None, isotropic=True, return_sigma=return_sigma)
    elif kernel_type == 'plateau_aniso':
        result = random_bivariate_plateau(
            kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range,
            noise_range=None, isotropic=False, return_sigma=return_sigma)
    return result
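
# --- Illustrative sketch (not part of the original file) ---
# Sample one blur kernel from a mixture of types (probabilities below are
# illustrative, not values prescribed by this repo):
#
#     kernel = random_mixed_kernels(
#         ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'],
#         [0.45, 0.25, 0.12, 0.03, 0.12, 0.03],
#         kernel_size=21)
#     print(kernel.shape)  # (21, 21)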

np.seterr(divide='ignore', invalid='ignore')


def circular_lowpass_kernel(cutoff, kernel_size, pad_to=0):
    """2D sinc filter.

    Reference: https://dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter

    Args:
        cutoff (float): cutoff frequency in radians (pi is max)
        kernel_size (int): horizontal and vertical size, must be odd.
        pad_to (int): pad kernel size to desired size, must be odd or zero.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    kernel = np.fromfunction(
        lambda x, y: cutoff * special.j1(cutoff * np.sqrt(
            (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)) / (2 * np.pi * np.sqrt(
                (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)), [kernel_size, kernel_size])
    kernel[(kernel_size - 1) // 2, (kernel_size - 1) // 2] = cutoff**2 / (4 * np.pi)
    kernel = kernel / np.sum(kernel)
    if pad_to > kernel_size:
        pad_size = (pad_to - kernel_size) // 2
        kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
    return kernel
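
# --- Illustrative sketch (not part of the original file) ---
# A 2D sinc (ringing) kernel with cutoff pi / 3, zero-padded from 13x13 to 21x21:
#
#     sinc_k = circular_lowpass_kernel(np.pi / 3, kernel_size=13, pad_to=21)
#     print(sinc_k.shape)  # (21, 21)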

# ------------------------------------------------------------- #
# --------------------------- noise --------------------------- #
# ------------------------------------------------------------- #

# ----------------------- Gaussian Noise ----------------------- #


def generate_gaussian_noise(img, sigma=10, gray_noise=False):
    """Generate Gaussian noise.

    Args:
        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
        sigma (float): Noise scale (measured in range 255). Default: 10.

    Returns:
        (Numpy array): Returned noise map, shape (h, w, c), float32.
    """
    if gray_noise:
        noise = np.float32(np.random.randn(*(img.shape[0:2]))) * sigma / 255.
        noise = np.expand_dims(noise, axis=2).repeat(3, axis=2)
    else:
        noise = np.float32(np.random.randn(*(img.shape))) * sigma / 255.
    return noise


def add_gaussian_noise(img, sigma=10, clip=True, rounds=False, gray_noise=False):
    """Add Gaussian noise.

    Args:
        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
        sigma (float): Noise scale (measured in range 255). Default: 10.

    Returns:
        (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
            float32.
    """
    noise = generate_gaussian_noise(img, sigma, gray_noise)
    out = img + noise
    if clip and rounds:
        out = np.clip((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = np.clip(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out


def generate_gaussian_noise_pt(img, sigma=10, gray_noise=0):
    """Generate Gaussian noise (PyTorch version).

    Args:
        img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
        sigma (float | Tensor): Noise scale (measured in range 255).
            Default: 10.

    Returns:
        (Tensor): Returned noise map, shape (b, c, h, w), float32.
    """
    b, _, h, w = img.size()
    if not isinstance(sigma, (float, int)):
        sigma = sigma.view(img.size(0), 1, 1, 1)
    if isinstance(gray_noise, (float, int)):
        cal_gray_noise = gray_noise > 0
    else:
        gray_noise = gray_noise.view(b, 1, 1, 1)
        cal_gray_noise = torch.sum(gray_noise) > 0

    if cal_gray_noise:
        noise_gray = torch.randn(*img.size()[2:4], dtype=img.dtype, device=img.device) * sigma / 255.
        noise_gray = noise_gray.view(b, 1, h, w)

    # always calculate color noise
    noise = torch.randn(*img.size(), dtype=img.dtype, device=img.device) * sigma / 255.

    if cal_gray_noise:
        noise = noise * (1 - gray_noise) + noise_gray * gray_noise
    return noise


def add_gaussian_noise_pt(img, sigma=10, gray_noise=0, clip=True, rounds=False):
    """Add Gaussian noise (PyTorch version).

    Args:
        img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
        sigma (float | Tensor): Noise scale (measured in range 255).
            Default: 10.

    Returns:
        (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
            float32.
    """
    noise = generate_gaussian_noise_pt(img, sigma, gray_noise)
    out = img + noise
    if clip and rounds:
        out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = torch.clamp(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out
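
# --- Illustrative sketch (not part of the original file) ---
# Per-sample sigma over a batch, with the first two samples using gray noise:
#
#     imgs = torch.rand(4, 3, 64, 64)
#     sigma = torch.tensor([5., 10., 15., 20.])
#     gray = torch.tensor([1., 1., 0., 0.])
#     noisy = add_gaussian_noise_pt(imgs, sigma=sigma, gray_noise=gray)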

# ----------------------- Random Gaussian Noise ----------------------- #
def random_generate_gaussian_noise(img, sigma_range=(0, 10), gray_prob=0, return_sigma=False):
    sigma = np.random.uniform(sigma_range[0], sigma_range[1])
    if np.random.uniform() < gray_prob:
        gray_noise = True
    else:
        gray_noise = False
    if return_sigma:
        return generate_gaussian_noise(img, sigma, gray_noise), sigma
    else:
        return generate_gaussian_noise(img, sigma, gray_noise)


def random_add_gaussian_noise(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False, return_sigma=False):
    if return_sigma:
        noise, sigma = random_generate_gaussian_noise(img, sigma_range, gray_prob, return_sigma=return_sigma)
    else:
        noise = random_generate_gaussian_noise(img, sigma_range, gray_prob, return_sigma=return_sigma)
    out = img + noise
    if clip and rounds:
        out = np.clip((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = np.clip(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    if return_sigma:
        return out, sigma
    else:
        return out


def random_generate_gaussian_noise_pt(img, sigma_range=(0, 10), gray_prob=0):
    sigma = torch.rand(
        img.size(0), dtype=img.dtype, device=img.device) * (sigma_range[1] - sigma_range[0]) + sigma_range[0]
    gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
    gray_noise = (gray_noise < gray_prob).float()
    return generate_gaussian_noise_pt(img, sigma, gray_noise)


def random_add_gaussian_noise_pt(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
    noise = random_generate_gaussian_noise_pt(img, sigma_range, gray_prob)
    out = img + noise
    if clip and rounds:
        out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = torch.clamp(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out

# ----------------------- Poisson (Shot) Noise ----------------------- #


def generate_poisson_noise(img, scale=1.0, gray_noise=False):
    """Generate poisson noise.

    Reference: https://github.com/scikit-image/scikit-image/blob/main/skimage/util/noise.py#L37-L219

    Args:
        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
        scale (float): Noise scale. Default: 1.0.
        gray_noise (bool): Whether to generate gray noise. Default: False.

    Returns:
        (Numpy array): Returned noise map, shape (h, w, c), float32.
    """
    if gray_noise:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # round and clip image for counting vals correctly
    img = np.clip((img * 255.0).round(), 0, 255) / 255.
    vals = len(np.unique(img))
    vals = 2**np.ceil(np.log2(vals))
    out = np.float32(np.random.poisson(img * vals) / float(vals))
    noise = out - img
    if gray_noise:
        noise = np.repeat(noise[:, :, np.newaxis], 3, axis=2)
    return noise * scale


def add_poisson_noise(img, scale=1.0, clip=True, rounds=False, gray_noise=False):
    """Add poisson noise.

    Args:
        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
        scale (float): Noise scale. Default: 1.0.
        gray_noise (bool): Whether to generate gray noise. Default: False.

    Returns:
        (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
            float32.
    """
    noise = generate_poisson_noise(img, scale, gray_noise)
    out = img + noise
    if clip and rounds:
        out = np.clip((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = np.clip(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out


def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0):
    """Generate a batch of poisson noise (PyTorch version).

    Args:
        img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
        scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
            Default: 1.0.
        gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
            0 for False, 1 for True. Default: 0.

    Returns:
        (Tensor): Returned noise map, shape (b, c, h, w), float32.
    """
    b, _, h, w = img.size()
    if isinstance(gray_noise, (float, int)):
        cal_gray_noise = gray_noise > 0
    else:
        gray_noise = gray_noise.view(b, 1, 1, 1)
        cal_gray_noise = torch.sum(gray_noise) > 0
    if cal_gray_noise:
        img_gray = rgb_to_grayscale(img, num_output_channels=1)
        # round and clip image for counting vals correctly
        img_gray = torch.clamp((img_gray * 255.0).round(), 0, 255) / 255.
        # use for-loop to get the unique values for each sample
        vals_list = [len(torch.unique(img_gray[i, :, :, :])) for i in range(b)]
        vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
        vals = img_gray.new_tensor(vals_list).view(b, 1, 1, 1)
        out = torch.poisson(img_gray * vals) / vals
        noise_gray = out - img_gray
        noise_gray = noise_gray.expand(b, 3, h, w)

    # always calculate color noise
    # round and clip image for counting vals correctly
    img = torch.clamp((img * 255.0).round(), 0, 255) / 255.
    # use for-loop to get the unique values for each sample
    vals_list = [len(torch.unique(img[i, :, :, :])) for i in range(b)]
    vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
    vals = img.new_tensor(vals_list).view(b, 1, 1, 1)
    out = torch.poisson(img * vals) / vals
    noise = out - img
    if cal_gray_noise:
        noise = noise * (1 - gray_noise) + noise_gray * gray_noise
    if not isinstance(scale, (float, int)):
        scale = scale.view(b, 1, 1, 1)
    return noise * scale


def add_poisson_noise_pt(img, scale=1.0, clip=True, rounds=False, gray_noise=0):
    """Add poisson noise to a batch of images (PyTorch version).

    Args:
        img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
        scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
            Default: 1.0.
        gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
            0 for False, 1 for True. Default: 0.

    Returns:
        (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
            float32.
    """
    noise = generate_poisson_noise_pt(img, scale, gray_noise)
    out = img + noise
    if clip and rounds:
        out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = torch.clamp(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out
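
# --- Illustrative sketch (not part of the original file) ---
# `vals` above is the number of distinct intensity levels rounded up to a
# power of two; scaling by it before the Poisson draw makes the noise
# signal-dependent (brighter pixels receive larger absolute noise):
#
#     img = cv2.imread('input.png').astype(np.float32) / 255.  # hypothetical path
#     noisy = add_poisson_noise(img, scale=0.5)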

# ----------------------- Random Poisson (Shot) Noise ----------------------- #


def random_generate_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0):
    scale = np.random.uniform(scale_range[0], scale_range[1])
    if np.random.uniform() < gray_prob:
        gray_noise = True
    else:
        gray_noise = False
    return generate_poisson_noise(img, scale, gray_noise)


def random_add_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
    noise = random_generate_poisson_noise(img, scale_range, gray_prob)
    out = img + noise
    if clip and rounds:
        out = np.clip((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = np.clip(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out


def random_generate_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0):
    scale = torch.rand(
        img.size(0), dtype=img.dtype, device=img.device) * (scale_range[1] - scale_range[0]) + scale_range[0]
    gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
    gray_noise = (gray_noise < gray_prob).float()
    return generate_poisson_noise_pt(img, scale, gray_noise)


def random_add_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
    noise = random_generate_poisson_noise_pt(img, scale_range, gray_prob)
    out = img + noise
    if clip and rounds:
        out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = torch.clamp(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out

# ----------------------- Random speckle Noise ----------------------- #

def random_add_speckle_noise(imgs, speckle_std):
    std_range = speckle_std
    std_l = std_range[0]
    std_r = std_range[1]
    mean = 0
    std = random.uniform(std_l / 255., std_r / 255.)

    outputs = []
    for img in imgs:
        gauss = np.random.normal(loc=mean, scale=std, size=img.shape)
        noisy = img + gauss * img  # multiplicative (speckle) noise
        noisy = np.clip(noisy, 0, 1).astype(np.float32)

        outputs.append(noisy)

    return outputs


def random_add_speckle_noise_pt(img, speckle_std):
    std_range = speckle_std
    std_l = std_range[0]
    std_r = std_range[1]
    mean = 0
    std = random.uniform(std_l / 255., std_r / 255.)
    gauss = torch.normal(mean=mean, std=std, size=img.size()).to(img.device)
    noisy = img + gauss * img
    noisy = torch.clamp(noisy, 0, 1)
    return noisy


# ----------------------- Random saltpepper Noise ----------------------- #

def random_add_saltpepper_noise(imgs, saltpepper_amount, saltpepper_svsp):
    p_range = saltpepper_amount
    p = random.uniform(p_range[0], p_range[1])
    q_range = saltpepper_svsp
    q = random.uniform(q_range[0], q_range[1])

    outputs = []
    for img in imgs:
        out = img.copy()
        flipped = np.random.choice([True, False], size=img.shape, p=[p, 1 - p])
        salted = np.random.choice([True, False], size=img.shape, p=[q, 1 - q])
        peppered = ~salted
        out[flipped & salted] = 1
        out[flipped & peppered] = 0.
        noisy = np.clip(out, 0, 1).astype(np.float32)

        outputs.append(noisy)

    return outputs


def random_add_saltpepper_noise_pt(imgs, saltpepper_amount, saltpepper_svsp):
    p_range = saltpepper_amount
    p = random.uniform(p_range[0], p_range[1])
    q_range = saltpepper_svsp
    q = random.uniform(q_range[0], q_range[1])

    imgs = imgs.permute(0, 2, 3, 1)

    outputs = []
    for i in range(imgs.size(0)):
        img = imgs[i]
        out = img.clone()
        flipped = np.random.choice([True, False], size=img.shape, p=[p, 1 - p])
        salted = np.random.choice([True, False], size=img.shape, p=[q, 1 - q])
        peppered = ~salted
        out[flipped & salted] = 1
        out[flipped & peppered] = 0.
        noisy = torch.clamp(out, 0, 1)

        outputs.append(noisy.permute(2, 0, 1))
    if len(outputs) > 1:
        # stack per-image results back into a (b, c, h, w) batch
        return torch.stack(outputs, dim=0)
    else:
        return outputs[0].unsqueeze(0)

# ----------------------- Random screen Noise ----------------------- #

def random_add_screen_noise(imgs, linewidth, space):
    # screen_noise = np.random.uniform() < self.params['noise_prob'][0]
    linewidth = int(np.random.uniform(linewidth[0], linewidth[1]))
    space = int(np.random.uniform(space[0], space[1]))
    center_color = [213, 230, 230]  # RGB
    outputs = []
    for img in imgs:
        noise = img.copy()

        # NOTE: the mask is allocated as (w, h); the broadcast below assumes
        # square inputs
        tmp_mask = np.zeros((img.shape[1], img.shape[0]), dtype=np.float32)
        for i in range(0, img.shape[0], int(space + linewidth)):
            tmp_mask[:, i:(i + linewidth)] = 1
        colour_masks = np.zeros((img.shape[0], img.shape[1], 3), dtype=np.float32)
        colour_masks[:, :, 0] = (center_color[0] + np.random.uniform(-20, 20)) / 255.
        colour_masks[:, :, 1] = (center_color[1] + np.random.uniform(0, 20)) / 255.
        colour_masks[:, :, 2] = (center_color[2] + np.random.uniform(0, 20)) / 255.
        noise_color = cv2.addWeighted(noise, 0.6, colour_masks, 0.4, 0.0)
        noise = noise * (1 - (tmp_mask[:, :, np.newaxis])) + noise_color * (tmp_mask[:, :, np.newaxis])

        outputs.append(noise)

    return outputs


# ------------------------------------------------------------------------ #
# --------------------------- JPEG compression --------------------------- #
# ------------------------------------------------------------------------ #


def add_jpg_compression(img, quality=90):
    """Add JPG compression artifacts.

    Args:
        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
        quality (float): JPG compression quality. 0 for lowest quality, 100 for
            best quality. Default: 90.

    Returns:
        (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
            float32.
    """
    img = np.clip(img, 0, 1)
    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(quality)]
    _, encimg = cv2.imencode('.jpg', img * 255., encode_param)
    img = np.float32(cv2.imdecode(encimg, 1)) / 255.
    return img


def random_add_jpg_compression(img, quality_range=(90, 100), return_q=False):
    """Randomly add JPG compression artifacts.

    Args:
        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
        quality_range (tuple[float] | list[float]): JPG compression quality
            range. 0 for lowest quality, 100 for best quality.
            Default: (90, 100).

    Returns:
        (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
            float32.
    """
    quality = np.random.uniform(quality_range[0], quality_range[1])
    if return_q:
        return add_jpg_compression(img, quality), quality
    else:
        return add_jpg_compression(img, quality)
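
# --- Illustrative sketch (not part of the original file) ---
# JPEG round trip at a random quality, also returning the sampled q:
#
#     lq, q = random_add_jpg_compression(img, quality_range=(30, 95), return_q=True)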
StableSR/basicsr/data/ffhq_dataset.py
ADDED
@@ -0,0 +1,80 @@
import random
import time
from os import path as osp
from torch.utils import data as data
from torchvision.transforms.functional import normalize

from basicsr.data.transforms import augment
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class FFHQDataset(data.Dataset):
    """FFHQ dataset for StyleGAN.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            io_backend (dict): IO backend type and other kwargs.
            mean (list | tuple): Image mean.
            std (list | tuple): Image std.
            use_hflip (bool): Whether to horizontally flip.
    """

    def __init__(self, opt):
        super(FFHQDataset, self).__init__()
        self.opt = opt
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']

        self.gt_folder = opt['dataroot_gt']
        self.mean = opt['mean']
        self.std = opt['std']

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = self.gt_folder
            if not self.gt_folder.endswith('.lmdb'):
                raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
                self.paths = [line.split('.')[0] for line in fin]
        else:
            # FFHQ has 70000 images in total
            self.paths = [osp.join(self.gt_folder, f'{v:08d}.png') for v in range(70000)]

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # load gt image
        gt_path = self.paths[index]
        # avoid errors caused by high latency in reading files
        retry = 3
        while retry > 0:
            try:
                img_bytes = self.file_client.get(gt_path)
            except Exception as e:
                logger = get_root_logger()
                logger.warning(f'File client error: {e}, remaining retry times: {retry - 1}')
                # change another file to read; keep the random index in range
                index = random.randint(0, self.__len__() - 1)
                gt_path = self.paths[index]
                time.sleep(1)  # sleep 1s for occasional server congestion
            else:
                break
            finally:
                retry -= 1
        img_gt = imfrombytes(img_bytes, float32=True)

        # random horizontal flip
        img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False)
        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True)
        # normalize
        normalize(img_gt, self.mean, self.std, inplace=True)
        return {'gt': img_gt, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
StableSR/basicsr/data/ffhq_degradation_dataset.py
ADDED
@@ -0,0 +1,231 @@
import cv2
import math
import numpy as np
import os.path as osp
import torch
import torch.utils.data as data
import random
from basicsr.data import degradations as degradations
from basicsr.data.data_util import paths_from_folder
from basicsr.data.transforms import augment
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY
from pathlib import Path
from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation,
                                               normalize)


@DATASET_REGISTRY.register()
class FFHQDegradationDataset(data.Dataset):
    """FFHQ dataset for GFPGAN.

    It reads high-resolution images, and then generates low-quality (LQ) images on-the-fly.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            io_backend (dict): IO backend type and other kwargs.
            mean (list | tuple): Image mean.
            std (list | tuple): Image std.
            use_hflip (bool): Whether to horizontally flip.
        Please see more options in the codes.
    """

    def __init__(self, opt):
        super(FFHQDegradationDataset, self).__init__()
        self.opt = opt
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        if 'image_type' not in opt:
            opt['image_type'] = 'png'

        self.gt_folder = opt['dataroot_gt']
        self.mean = opt['mean']
        self.std = opt['std']
        self.out_size = opt['out_size']

        self.crop_components = opt.get('crop_components', False)  # facial components
        self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1)  # whether enlarge eye regions

        if self.crop_components:
            # load component list from a pre-processed pth file
            self.components_list = torch.load(opt.get('component_path'))

        # file client (lmdb io backend)
        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = self.gt_folder
            if not self.gt_folder.endswith('.lmdb'):
                raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
                self.paths = [line.split('.')[0] for line in fin]
        else:
            # disk backend: scan file list from a folder
            self.paths = sorted([str(x) for x in Path(self.gt_folder).glob('*.' + opt['image_type'])])

        # degradation configurations
        self.blur_kernel_size = opt['blur_kernel_size']
        self.kernel_list = opt['kernel_list']
        self.kernel_prob = opt['kernel_prob']
        self.blur_sigma = opt['blur_sigma']
        self.downsample_range = opt['downsample_range']
        self.noise_range = opt['noise_range']
        self.jpeg_range = opt['jpeg_range']

        # color jitter
        self.color_jitter_prob = opt.get('color_jitter_prob')
        self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob')
        self.color_jitter_shift = opt.get('color_jitter_shift', 20)
        # to gray
        self.gray_prob = opt.get('gray_prob')

        logger = get_root_logger()
        logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]')
        logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
        logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
        logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')

        if self.color_jitter_prob is not None:
            logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}')
        if self.gray_prob is not None:
            logger.info(f'Use random gray. Prob: {self.gray_prob}')
        self.color_jitter_shift /= 255.

    @staticmethod
    def color_jitter(img, shift):
        """jitter color: randomly jitter the RGB values, in numpy formats"""
        jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
        img = img + jitter_val
        img = np.clip(img, 0, 1)
        return img

    @staticmethod
    def color_jitter_pt(img, brightness, contrast, saturation, hue):
        """jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats"""
        fn_idx = torch.randperm(4)
        for fn_id in fn_idx:
            if fn_id == 0 and brightness is not None:
                brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
                img = adjust_brightness(img, brightness_factor)

            if fn_id == 1 and contrast is not None:
                contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
                img = adjust_contrast(img, contrast_factor)

            if fn_id == 2 and saturation is not None:
                saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
                img = adjust_saturation(img, saturation_factor)

            if fn_id == 3 and hue is not None:
                hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
                img = adjust_hue(img, hue_factor)
        return img

    def get_component_coordinates(self, index, status):
        """Get facial component (left_eye, right_eye, mouth) coordinates from a pre-loaded pth file"""
        components_bbox = self.components_list[f'{index:08d}']
        if status[0]:  # hflip
            # exchange right and left eye
            tmp = components_bbox['left_eye']
            components_bbox['left_eye'] = components_bbox['right_eye']
            components_bbox['right_eye'] = tmp
            # modify the width coordinate
            components_bbox['left_eye'][0] = self.out_size - components_bbox['left_eye'][0]
            components_bbox['right_eye'][0] = self.out_size - components_bbox['right_eye'][0]
            components_bbox['mouth'][0] = self.out_size - components_bbox['mouth'][0]

        # get coordinates
        locations = []
        for part in ['left_eye', 'right_eye', 'mouth']:
            mean = components_bbox[part][0:2]
            half_len = components_bbox[part][2]
            if 'eye' in part:
                half_len *= self.eye_enlarge_ratio
            loc = np.hstack((mean - half_len + 1, mean + half_len))
            loc = torch.from_numpy(loc).float()
            locations.append(loc)
        return locations

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # load gt image
        # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
        gt_path = self.paths[index]
        img_bytes = self.file_client.get(gt_path)
        img_gt = imfrombytes(img_bytes, float32=True)

        # random horizontal flip
        img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
        h, w, _ = img_gt.shape

        # get facial component coordinates
        if self.crop_components:
            locations = self.get_component_coordinates(index, status)
            loc_left_eye, loc_right_eye, loc_mouth = locations

        # ------------------------ generate lq image ------------------------ #
        # blur
        kernel = degradations.random_mixed_kernels(
            self.kernel_list,
            self.kernel_prob,
            self.blur_kernel_size,
            self.blur_sigma,
            self.blur_sigma, [-math.pi, math.pi],
            noise_range=None)
        img_lq = cv2.filter2D(img_gt, -1, kernel)
        # downsample
        scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
        img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR)
        # noise
        if self.noise_range is not None:
            img_lq = degradations.random_add_gaussian_noise(img_lq, self.noise_range)
        # jpeg compression
        if self.jpeg_range is not None:
            img_lq = degradations.random_add_jpg_compression(img_lq, self.jpeg_range)

        # resize to original size
        img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)

        # random color jitter (only for lq)
        if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
            img_lq = self.color_jitter(img_lq, self.color_jitter_shift)
        # random to gray (only for lq)
        if self.gray_prob and np.random.uniform() < self.gray_prob:
            img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY)
            img_lq = np.tile(img_lq[:, :, None], [1, 1, 3])
            if self.opt.get('gt_gray'):  # whether convert GT to gray images
                img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
                img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])  # repeat the color channels

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)

        # random color jitter (pytorch version) (only for lq)
        if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
            brightness = self.opt.get('brightness', (0.5, 1.5))
            contrast = self.opt.get('contrast', (0.5, 1.5))
            saturation = self.opt.get('saturation', (0, 1.5))
            hue = self.opt.get('hue', (-0.1, 0.1))
            img_lq = self.color_jitter_pt(img_lq, brightness, contrast, saturation, hue)

        # round and clip
        img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255.

        # normalize
        normalize(img_gt, self.mean, self.std, inplace=True)
        normalize(img_lq, self.mean, self.std, inplace=True)

        if self.crop_components:
            return_dict = {
                'lq': img_lq,
                'gt': img_gt,
                'gt_path': gt_path,
                'loc_left_eye': loc_left_eye,
                'loc_right_eye': loc_right_eye,
                'loc_mouth': loc_mouth
            }
            return return_dict
        else:
            return {'lq': img_lq, 'gt': img_gt, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
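A minimal illustrative config for the on-the-fly degradation pipeline above; all values are assumptions for the sketch, not the repo's training settings.

    # usage sketch for FFHQDegradationDataset (hypothetical path, illustrative ranges)
    from basicsr.data.ffhq_degradation_dataset import FFHQDegradationDataset

    opt = {
        'dataroot_gt': 'datasets/ffhq_512',   # hypothetical folder of 512x512 faces
        'io_backend': {'type': 'disk'},
        'mean': [0.5, 0.5, 0.5],
        'std': [0.5, 0.5, 0.5],
        'use_hflip': True,
        'out_size': 512,
        'blur_kernel_size': 41,
        'kernel_list': ['iso', 'aniso'],
        'kernel_prob': [0.5, 0.5],
        'blur_sigma': [0.1, 10],
        'downsample_range': [0.8, 8],
        'noise_range': [0, 20],
        'jpeg_range': [60, 100],
    }
    dataset = FFHQDegradationDataset(opt)
    item = dataset[0]  # {'lq': ..., 'gt': ..., 'gt_path': ...}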
StableSR/basicsr/data/paired_image_dataset.py
ADDED
@@ -0,0 +1,115 @@
from torch.utils import data as data
from torchvision.transforms.functional import normalize

from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file, paired_paths_from_meta_info_file_2
from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, bgr2ycbcr, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY
import cv2


@DATASET_REGISTRY.register()
class PairedImageDataset(data.Dataset):
    """Paired image dataset for image restoration.

    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.

    There are three modes:

    1. **lmdb**: Use lmdb files. If opt['io_backend'] == lmdb.
    2. **meta_info_file**: Use meta information file to generate paths. \
        If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
    3. **folder**: Scan folders to generate paths. The rest.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            meta_info_file (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwargs.
            filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
                Default: '{}'.
            gt_size (int): Cropped patch size for gt patches.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            scale (int): Scale, which will be added automatically.
            phase (str): 'train' or 'val'.
    """

    def __init__(self, opt):
        super(PairedImageDataset, self).__init__()
        self.opt = opt
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.mean = opt['mean'] if 'mean' in opt else None
        self.std = opt['std'] if 'std' in opt else None

        self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
        if 'filename_tmpl' in opt:
            self.filename_tmpl = opt['filename_tmpl']
        else:
            self.filename_tmpl = '{}'

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']
            self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
        elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None:
            self.paths = paired_paths_from_meta_info_file_2([self.lq_folder, self.gt_folder], ['lq', 'gt'],
                                                            self.opt['meta_info_file'], self.filename_tmpl)
        else:
            self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']

        # Load gt and lq images. Dimension order: HWC; channel order: BGR;
        # image range: [0, 1], float32.
        gt_path = self.paths[index]['gt_path']
        img_bytes = self.file_client.get(gt_path, 'gt')
        img_gt = imfrombytes(img_bytes, float32=True)
        lq_path = self.paths[index]['lq_path']
        img_bytes = self.file_client.get(lq_path, 'lq')
        img_lq = imfrombytes(img_bytes, float32=True)

        h, w = img_gt.shape[0:2]
        # pad
        if h < self.opt['gt_size'] or w < self.opt['gt_size']:
            pad_h = max(0, self.opt['gt_size'] - h)
            pad_w = max(0, self.opt['gt_size'] - w)
            img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)
            img_lq = cv2.copyMakeBorder(img_lq, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)

        # augmentation for training
        if self.opt['phase'] == 'train':
            gt_size = self.opt['gt_size']
            # random crop
            img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
            # flip, rotation
            img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])

        # color space transform
        if 'color' in self.opt and self.opt['color'] == 'y':
            img_gt = bgr2ycbcr(img_gt, y_only=True)[..., None]
            img_lq = bgr2ycbcr(img_lq, y_only=True)[..., None]

        # crop the unmatched GT images during validation or testing, especially for SR benchmark datasets
        # TODO: It is better to update the datasets, rather than force to crop
        if self.opt['phase'] != 'train':
            img_gt = img_gt[0:img_lq.shape[0] * scale, 0:img_lq.shape[1] * scale, :]

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
        # normalize
        if self.mean is not None or self.std is not None:
            normalize(img_lq, self.mean, self.std, inplace=True)
            normalize(img_gt, self.mean, self.std, inplace=True)

        return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
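A minimal sketch (assumed paths, not a repo config) of the "folder" mode: LQ/GT pairs matched by filename, with a filename template for suffixed LQ files.

    # usage sketch for PairedImageDataset in folder mode
    from basicsr.data.paired_image_dataset import PairedImageDataset

    opt = {
        'phase': 'train',
        'scale': 4,
        'gt_size': 128,
        'use_hflip': True,
        'use_rot': True,
        'io_backend': {'type': 'disk'},
        'dataroot_gt': 'datasets/DIV2K/HR',     # hypothetical
        'dataroot_lq': 'datasets/DIV2K/LR_x4',  # hypothetical
        'filename_tmpl': '{}x4',                # e.g. 0001.png pairs with 0001x4.png
    }
    dataset = PairedImageDataset(opt)
    item = dataset[0]  # {'lq': ..., 'gt': ..., 'lq_path': ..., 'gt_path': ...}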
StableSR/basicsr/data/prefetch_dataloader.py
ADDED
@@ -0,0 +1,122 @@
import queue as Queue
import threading
import torch
from torch.utils.data import DataLoader


class PrefetchGenerator(threading.Thread):
    """A general prefetch generator.

    Reference: https://stackoverflow.com/questions/7323664/python-generator-pre-fetch

    Args:
        generator: Python generator.
        num_prefetch_queue (int): Number of prefetch queue.
    """

    def __init__(self, generator, num_prefetch_queue):
        threading.Thread.__init__(self)
        self.queue = Queue.Queue(num_prefetch_queue)
        self.generator = generator
        self.daemon = True
        self.start()

    def run(self):
        for item in self.generator:
            self.queue.put(item)
        self.queue.put(None)

    def __next__(self):
        next_item = self.queue.get()
        if next_item is None:
            raise StopIteration
        return next_item

    def __iter__(self):
        return self


class PrefetchDataLoader(DataLoader):
    """Prefetch version of dataloader.

    Reference: https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#

    TODO:
    Need to test on single gpu and ddp (multi-gpu). There is a known issue in
    ddp.

    Args:
        num_prefetch_queue (int): Number of prefetch queue.
        kwargs (dict): Other arguments for dataloader.
    """

    def __init__(self, num_prefetch_queue, **kwargs):
        self.num_prefetch_queue = num_prefetch_queue
        super(PrefetchDataLoader, self).__init__(**kwargs)

    def __iter__(self):
        return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)


class CPUPrefetcher():
    """CPU prefetcher.

    Args:
        loader: Dataloader.
    """

    def __init__(self, loader):
        self.ori_loader = loader
        self.loader = iter(loader)

    def next(self):
        try:
            return next(self.loader)
        except StopIteration:
            return None

    def reset(self):
        self.loader = iter(self.ori_loader)


class CUDAPrefetcher():
    """CUDA prefetcher.

    Reference: https://github.com/NVIDIA/apex/issues/304#

    It may consume more GPU memory.

    Args:
        loader: Dataloader.
        opt (dict): Options.
    """

    def __init__(self, loader, opt):
        self.ori_loader = loader
        self.loader = iter(loader)
        self.opt = opt
        self.stream = torch.cuda.Stream()
        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
        self.preload()

    def preload(self):
        try:
            self.batch = next(self.loader)  # self.batch is a dict
        except StopIteration:
            self.batch = None
            return None
        # put tensors to gpu
        with torch.cuda.stream(self.stream):
            for k, v in self.batch.items():
                if torch.is_tensor(v):
                    self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)

    def next(self):
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        self.preload()
        return batch

    def reset(self):
        self.loader = iter(self.ori_loader)
        self.preload()
StableSR/basicsr/data/realesrgan_dataset.py
ADDED
@@ -0,0 +1,242 @@
import cv2
import math
import numpy as np
import os
import os.path as osp
import random
import time
import torch
from pathlib import Path
from torch.utils import data as data

from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
from basicsr.data.transforms import augment
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY


@DATASET_REGISTRY.register(suffix='basicsr')
class RealESRGANDataset(data.Dataset):
    """Modified dataset based on the dataset used for the Real-ESRGAN model:
    Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.

    It loads gt (Ground-Truth) images, and augments them.
    It also generates blur kernels and sinc kernels for generating low-quality images.
    Note that the low-quality images are processed in tensors on GPUs for faster processing.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            meta_info (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwargs.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
        Please see more options in the codes.
    """

    def __init__(self, opt):
        super(RealESRGANDataset, self).__init__()
        self.opt = opt
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        if 'crop_size' in opt:
            self.crop_size = opt['crop_size']
        else:
            self.crop_size = 512
        if 'image_type' not in opt:
            opt['image_type'] = 'png'

        # support multiple types of data: file path and meta data; lmdb support removed
        self.paths = []
        if 'meta_info' in opt:
            with open(self.opt['meta_info']) as fin:
                paths = [line.strip().split(' ')[0] for line in fin]
                self.paths = [v for v in paths]
            if 'meta_num' in opt:
                self.paths = sorted(self.paths)[:opt['meta_num']]
        if 'gt_path' in opt:
            if isinstance(opt['gt_path'], str):
                self.paths.extend(sorted([str(x) for x in Path(opt['gt_path']).glob('*.' + opt['image_type'])]))
            else:
                self.paths.extend(sorted([str(x) for x in Path(opt['gt_path'][0]).glob('*.' + opt['image_type'])]))
                if len(opt['gt_path']) > 1:
                    for i in range(len(opt['gt_path']) - 1):
                        self.paths.extend(sorted([str(x) for x in Path(opt['gt_path'][i + 1]).glob('*.' + opt['image_type'])]))
        if 'imagenet_path' in opt:
            class_list = os.listdir(opt['imagenet_path'])
            for class_file in class_list:
                self.paths.extend(sorted([str(x) for x in Path(os.path.join(opt['imagenet_path'], class_file)).glob('*.' + 'JPEG')]))
        if 'face_gt_path' in opt:
            if isinstance(opt['face_gt_path'], str):
                face_list = sorted([str(x) for x in Path(opt['face_gt_path']).glob('*.' + opt['image_type'])])
                self.paths.extend(face_list[:opt['num_face']])
            else:
                face_list = sorted([str(x) for x in Path(opt['face_gt_path'][0]).glob('*.' + opt['image_type'])])
                self.paths.extend(face_list[:opt['num_face']])
                if len(opt['face_gt_path']) > 1:
                    for i in range(len(opt['face_gt_path']) - 1):
                        # scan each of the remaining face folders
                        self.paths.extend(sorted([str(x) for x in Path(opt['face_gt_path'][i + 1]).glob('*.' + opt['image_type'])])[:opt['num_face']])

        # limit the number of pictures for val/test
        if 'num_pic' in opt:
            if 'val' in opt or 'test' in opt:
                random.shuffle(self.paths)
                self.paths = self.paths[:opt['num_pic']]
            else:
                self.paths = self.paths[:opt['num_pic']]

        if 'mul_num' in opt:
            self.paths = self.paths * opt['mul_num']
        # print('>>>>>>>>>>>>>>>>>>>>>')
        # print(self.paths)

        # blur settings for the first degradation
        self.blur_kernel_size = opt['blur_kernel_size']
        self.kernel_list = opt['kernel_list']
        self.kernel_prob = opt['kernel_prob']  # a list for each kernel probability
        self.blur_sigma = opt['blur_sigma']
        self.betag_range = opt['betag_range']  # betag used in generalized Gaussian blur kernels
        self.betap_range = opt['betap_range']  # betap used in plateau blur kernels
        self.sinc_prob = opt['sinc_prob']  # the probability for sinc filters

        # blur settings for the second degradation
        self.blur_kernel_size2 = opt['blur_kernel_size2']
        self.kernel_list2 = opt['kernel_list2']
        self.kernel_prob2 = opt['kernel_prob2']
        self.blur_sigma2 = opt['blur_sigma2']
        self.betag_range2 = opt['betag_range2']
        self.betap_range2 = opt['betap_range2']
        self.sinc_prob2 = opt['sinc_prob2']

        # a final sinc filter
        self.final_sinc_prob = opt['final_sinc_prob']

        self.kernel_range = [2 * v + 1 for v in range(3, 11)]  # kernel size ranges from 7 to 21
        # TODO: kernel range is now hard-coded, should be in the configure file
        self.pulse_tensor = torch.zeros(21, 21).float()  # convolving with pulse tensor brings no blurry effect
        self.pulse_tensor[10, 10] = 1

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # -------------------------------- Load gt images -------------------------------- #
        # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
        gt_path = self.paths[index]
        # avoid errors caused by high latency in reading files
        retry = 3
        while retry > 0:
            try:
                img_bytes = self.file_client.get(gt_path, 'gt')
            except (IOError, OSError) as e:
                # logger = get_root_logger()
                # logger.warn(f'File client error: {e}, remaining retry times: {retry - 1}')
                # change another file to read
                index = random.randint(0, self.__len__() - 1)
                gt_path = self.paths[index]
                time.sleep(1)  # sleep 1s for occasional server congestion
            else:
                break
            finally:
                retry -= 1
        img_gt = imfrombytes(img_bytes, float32=True)
        # filter the dataset and remove images with too low quality
        img_size = os.path.getsize(gt_path)
        img_size = img_size / 1024

        while img_gt.shape[0] * img_gt.shape[1] < 384 * 384 or img_size < 100:
            index = random.randint(0, self.__len__() - 1)
            gt_path = self.paths[index]

            time.sleep(0.1)  # short sleep for occasional server congestion
            img_bytes = self.file_client.get(gt_path, 'gt')
            img_gt = imfrombytes(img_bytes, float32=True)
            img_size = os.path.getsize(gt_path)
            img_size = img_size / 1024

        # -------------------- Do augmentation for training: flip, rotation -------------------- #
        img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])

        # crop or pad to crop_pad_size (self.crop_size, 512 by default; configurable via opt['crop_size'])
        h, w = img_gt.shape[0:2]
        crop_pad_size = self.crop_size
        # pad
        if h < crop_pad_size or w < crop_pad_size:
            pad_h = max(0, crop_pad_size - h)
            pad_w = max(0, crop_pad_size - w)
            img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)
        # crop
        if img_gt.shape[0] > crop_pad_size or img_gt.shape[1] > crop_pad_size:
            h, w = img_gt.shape[0:2]
            # randomly choose top and left coordinates
            top = random.randint(0, h - crop_pad_size)
            left = random.randint(0, w - crop_pad_size)
            # top = (h - crop_pad_size) // 2 -1
            # left = (w - crop_pad_size) // 2 -1
            img_gt = img_gt[top:top + crop_pad_size, left:left + crop_pad_size, ...]

        # ------------------------ Generate kernels (used in the first degradation) ------------------------ #
        kernel_size = random.choice(self.kernel_range)
        if np.random.uniform() < self.opt['sinc_prob']:
            # this sinc filter setting is for kernels ranging from [7, 21]
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel = random_mixed_kernels(
                self.kernel_list,
                self.kernel_prob,
                kernel_size,
                self.blur_sigma,
                self.blur_sigma, [-math.pi, math.pi],
                self.betag_range,
                self.betap_range,
                noise_range=None)
        # pad kernel
        pad_size = (21 - kernel_size) // 2
        kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))

        # ------------------------ Generate kernels (used in the second degradation) ------------------------ #
        kernel_size = random.choice(self.kernel_range)
        if np.random.uniform() < self.opt['sinc_prob2']:
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel2 = random_mixed_kernels(
                self.kernel_list2,
                self.kernel_prob2,
                kernel_size,
                self.blur_sigma2,
                self.blur_sigma2, [-math.pi, math.pi],
                self.betag_range2,
                self.betap_range2,
                noise_range=None)

        # pad kernel
        pad_size = (21 - kernel_size) // 2
        kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))

        # ------------------------------------- the final sinc kernel ------------------------------------- #
        if np.random.uniform() < self.opt['final_sinc_prob']:
            kernel_size = random.choice(self.kernel_range)
            omega_c = np.random.uniform(np.pi / 3, np.pi)
            sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
            sinc_kernel = torch.FloatTensor(sinc_kernel)
        else:
            sinc_kernel = self.pulse_tensor

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]
        kernel = torch.FloatTensor(kernel)
        kernel2 = torch.FloatTensor(kernel2)

        return_d = {'gt': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'gt_path': gt_path}
        return return_d

    def __len__(self):
        return len(self.paths)
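A minimal sketch (hypothetical path, illustrative hyperparameters resembling Real-ESRGAN style settings, not the repo's configs) of what this dataset yields; note the degradations themselves are applied later on GPU batches, per the class docstring, so each sample carries the kernels alongside the GT crop.

    # usage sketch for RealESRGANDataset
    from basicsr.data.realesrgan_dataset import RealESRGANDataset

    opt = {
        'io_backend': {'type': 'disk'},
        'gt_path': 'datasets/DF2K/HR',   # hypothetical folder of *.png
        'crop_size': 512,
        'use_hflip': True, 'use_rot': False,
        # first-stage blur (illustrative values)
        'blur_kernel_size': 21, 'kernel_list': ['iso', 'aniso'], 'kernel_prob': [0.6, 0.4],
        'blur_sigma': [0.2, 3], 'betag_range': [0.5, 4], 'betap_range': [1, 2], 'sinc_prob': 0.1,
        # second-stage blur (illustrative values)
        'blur_kernel_size2': 21, 'kernel_list2': ['iso', 'aniso'], 'kernel_prob2': [0.6, 0.4],
        'blur_sigma2': [0.2, 1.5], 'betag_range2': [0.5, 4], 'betap_range2': [1, 2], 'sinc_prob2': 0.1,
        'final_sinc_prob': 0.8,
    }
    dataset = RealESRGANDataset(opt)
    item = dataset[0]
    # item['gt']: (3, 512, 512); item['kernel1'], item['kernel2'], item['sinc_kernel']: (21, 21)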
StableSR/basicsr/data/realesrgan_paired_dataset.py
ADDED
@@ -0,0 +1,114 @@
import os
from torch.utils import data as data
from torchvision.transforms.functional import normalize

from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb
from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY


@DATASET_REGISTRY.register(suffix='basicsr')
class RealESRGANPairedDataset(data.Dataset):
    """Paired image dataset for image restoration.

    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.

    There are three modes:

    1. **lmdb**: Use lmdb files. If opt['io_backend'] == lmdb.
    2. **meta_info_file**: Use meta information file to generate paths. \
        If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
    3. **folder**: Scan folders to generate paths. The rest.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            meta_info (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwargs.
            filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
                Default: '{}'.
            gt_size (int): Cropped patch size for gt patches.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            scale (int): Scale, which will be added automatically.
            phase (str): 'train' or 'val'.
    """

    def __init__(self, opt):
        super(RealESRGANPairedDataset, self).__init__()
        self.opt = opt
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        # mean and std for normalizing the input images
        self.mean = opt['mean'] if 'mean' in opt else None
        self.std = opt['std'] if 'std' in opt else None

        self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
        self.filename_tmpl = opt['filename_tmpl'] if 'filename_tmpl' in opt else '{}'

        # file client (lmdb io backend)
        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']
            self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
        elif 'meta_info' in self.opt and self.opt['meta_info'] is not None:
            # disk backend with meta_info
            # Each line in the meta_info file describes the relative paths to a GT/LQ image pair
            with open(self.opt['meta_info']) as fin:
                paths = [line.strip() for line in fin]
            self.paths = []
            for path in paths:
                gt_path, lq_path = path.split(', ')
                gt_path = os.path.join(self.gt_folder, gt_path)
                lq_path = os.path.join(self.lq_folder, lq_path)
                self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)]))
        else:
            # disk backend
            # it will scan the whole folder to get meta info
            # it will be time-consuming for folders with too many files. It is recommended using an extra meta txt file
            self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)

        if 'num_pic' in self.opt:
            self.paths = self.paths[:self.opt['num_pic']]
        if 'phase' not in self.opt:
            self.opt['phase'] = 'test'
        if 'scale' not in self.opt:
            self.opt['scale'] = 1

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']

        # Load gt and lq images. Dimension order: HWC; channel order: BGR;
        # image range: [0, 1], float32.
        gt_path = self.paths[index]['gt_path']
        img_bytes = self.file_client.get(gt_path, 'gt')
        img_gt = imfrombytes(img_bytes, float32=True)
        lq_path = self.paths[index]['lq_path']
        img_bytes = self.file_client.get(lq_path, 'lq')
        img_lq = imfrombytes(img_bytes, float32=True)

        # augmentation for training
        if self.opt['phase'] == 'train':
            gt_size = self.opt['gt_size']
            # random crop
            img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
            # flip, rotation
            img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
        # normalize
        if self.mean is not None or self.std is not None:
            normalize(img_lq, self.mean, self.std, inplace=True)
            normalize(img_gt, self.mean, self.std, inplace=True)

        return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
StableSR/basicsr/data/reds_dataset.py
ADDED
@@ -0,0 +1,352 @@
1 |
+
import numpy as np
|
2 |
+
import random
|
3 |
+
import torch
|
4 |
+
from pathlib import Path
|
5 |
+
from torch.utils import data as data
|
6 |
+
|
7 |
+
from basicsr.data.transforms import augment, paired_random_crop
|
8 |
+
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
|
9 |
+
from basicsr.utils.flow_util import dequantize_flow
|
10 |
+
from basicsr.utils.registry import DATASET_REGISTRY
|
11 |
+
|
12 |
+
|
13 |
+
@DATASET_REGISTRY.register()
|
14 |
+
class REDSDataset(data.Dataset):
|
15 |
+
"""REDS dataset for training.
|
16 |
+
|
17 |
+
The keys are generated from a meta info txt file.
|
18 |
+
basicsr/data/meta_info/meta_info_REDS_GT.txt
|
19 |
+
|
20 |
+
Each line contains:
|
21 |
+
1. subfolder (clip) name; 2. frame number; 3. image shape, separated by
|
22 |
+
a white space.
|
23 |
+
Examples:
|
24 |
+
000 100 (720,1280,3)
|
25 |
+
001 100 (720,1280,3)
|
26 |
+
...
|
27 |
+
|
28 |
+
Key examples: "000/00000000"
|
29 |
+
GT (gt): Ground-Truth;
|
30 |
+
LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.
|
31 |
+
|
32 |
+
Args:
|
33 |
+
opt (dict): Config for train dataset. It contains the following keys:
|
34 |
+
dataroot_gt (str): Data root path for gt.
|
35 |
+
dataroot_lq (str): Data root path for lq.
|
36 |
+
dataroot_flow (str, optional): Data root path for flow.
|
37 |
+
meta_info_file (str): Path for meta information file.
|
38 |
+
val_partition (str): Validation partition types. 'REDS4' or 'official'.
|
39 |
+
io_backend (dict): IO backend type and other kwarg.
|
40 |
+
num_frame (int): Window size for input frames.
|
41 |
+
gt_size (int): Cropped patched size for gt patches.
|
42 |
+
interval_list (list): Interval list for temporal augmentation.
|
43 |
+
random_reverse (bool): Random reverse input frames.
|
44 |
+
use_hflip (bool): Use horizontal flips.
|
45 |
+
use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
|
46 |
+
scale (bool): Scale, which will be added automatically.
|
47 |
+
"""
|
48 |
+
|
49 |
+
def __init__(self, opt):
|
50 |
+
super(REDSDataset, self).__init__()
|
51 |
+
self.opt = opt
|
52 |
+
self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])
|
53 |
+
self.flow_root = Path(opt['dataroot_flow']) if opt['dataroot_flow'] is not None else None
|
54 |
+
assert opt['num_frame'] % 2 == 1, (f'num_frame should be odd number, but got {opt["num_frame"]}')
|
55 |
+
self.num_frame = opt['num_frame']
|
56 |
+
self.num_half_frames = opt['num_frame'] // 2
|
57 |
+
|
58 |
+
self.keys = []
|
59 |
+
with open(opt['meta_info_file'], 'r') as fin:
|
60 |
+
for line in fin:
|
61 |
+
folder, frame_num, _ = line.split(' ')
|
62 |
+
self.keys.extend([f'{folder}/{i:08d}' for i in range(int(frame_num))])
|
63 |
+
|
64 |
+
# remove the video clips used in validation
|
65 |
+
if opt['val_partition'] == 'REDS4':
|
66 |
+
val_partition = ['000', '011', '015', '020']
|
67 |
+
elif opt['val_partition'] == 'official':
|
68 |
+
val_partition = [f'{v:03d}' for v in range(240, 270)]
|
69 |
+
else:
|
70 |
+
raise ValueError(f'Wrong validation partition {opt["val_partition"]}.'
|
71 |
+
f"Supported ones are ['official', 'REDS4'].")
|
72 |
+
self.keys = [v for v in self.keys if v.split('/')[0] not in val_partition]
|
73 |
+
|
74 |
+
# file client (io backend)
|
75 |
+
self.file_client = None
|
76 |
+
self.io_backend_opt = opt['io_backend']
|
77 |
+
self.is_lmdb = False
|
78 |
+
if self.io_backend_opt['type'] == 'lmdb':
|
79 |
+
self.is_lmdb = True
|
80 |
+
if self.flow_root is not None:
|
81 |
+
self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root, self.flow_root]
|
82 |
+
self.io_backend_opt['client_keys'] = ['lq', 'gt', 'flow']
|
83 |
+
else:
|
84 |
+
self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
|
85 |
+
self.io_backend_opt['client_keys'] = ['lq', 'gt']
|
86 |
+
|
87 |
+
# temporal augmentation configs
|
88 |
+
self.interval_list = opt['interval_list']
|
89 |
+
self.random_reverse = opt['random_reverse']
|
90 |
+
interval_str = ','.join(str(x) for x in opt['interval_list'])
|
91 |
+
logger = get_root_logger()
|
92 |
+
logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
|
93 |
+
f'random reverse is {self.random_reverse}.')
|
94 |
+
|
95 |
+
def __getitem__(self, index):
|
96 |
+
if self.file_client is None:
|
97 |
+
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
|
98 |
+
|
99 |
+
scale = self.opt['scale']
|
100 |
+
gt_size = self.opt['gt_size']
|
101 |
+
key = self.keys[index]
|
102 |
+
clip_name, frame_name = key.split('/') # key example: 000/00000000
|
103 |
+
center_frame_idx = int(frame_name)
|
104 |
+
|
105 |
+
# determine the neighboring frames
|
106 |
+
interval = random.choice(self.interval_list)
|
107 |
+
|
108 |
+
# ensure not exceeding the borders
|
109 |
+
start_frame_idx = center_frame_idx - self.num_half_frames * interval
|
110 |
+
end_frame_idx = center_frame_idx + self.num_half_frames * interval
|
111 |
+
# each clip has 100 frames starting from 0 to 99
|
112 |
+
while (start_frame_idx < 0) or (end_frame_idx > 99):
|
113 |
+
center_frame_idx = random.randint(0, 99)
|
114 |
+
start_frame_idx = (center_frame_idx - self.num_half_frames * interval)
|
115 |
+
end_frame_idx = center_frame_idx + self.num_half_frames * interval
|
116 |
+
frame_name = f'{center_frame_idx:08d}'
|
117 |
+
neighbor_list = list(range(start_frame_idx, end_frame_idx + 1, interval))
|
118 |
+
# random reverse
|
119 |
+
if self.random_reverse and random.random() < 0.5:
|
120 |
+
neighbor_list.reverse()
|
121 |
+
|
122 |
+
assert len(neighbor_list) == self.num_frame, (f'Wrong length of neighbor list: {len(neighbor_list)}')
|
123 |
+
|
124 |
+
# get the GT frame (as the center frame)
|
125 |
+
if self.is_lmdb:
|
126 |
+
img_gt_path = f'{clip_name}/{frame_name}'
|
127 |
+
else:
|
128 |
+
img_gt_path = self.gt_root / clip_name / f'{frame_name}.png'
|
129 |
+
img_bytes = self.file_client.get(img_gt_path, 'gt')
|
130 |
+
img_gt = imfrombytes(img_bytes, float32=True)
|
131 |
+
|
132 |
+
# get the neighboring LQ frames
|
133 |
+
img_lqs = []
|
134 |
+
for neighbor in neighbor_list:
|
135 |
+
if self.is_lmdb:
|
136 |
+
img_lq_path = f'{clip_name}/{neighbor:08d}'
|
137 |
+
else:
|
138 |
+
img_lq_path = self.lq_root / clip_name / f'{neighbor:08d}.png'
|
139 |
+
img_bytes = self.file_client.get(img_lq_path, 'lq')
|
140 |
+
img_lq = imfrombytes(img_bytes, float32=True)
|
141 |
+
img_lqs.append(img_lq)
|
142 |
+
|
143 |
+
# get flows
|
144 |
+
if self.flow_root is not None:
|
145 |
+
img_flows = []
|
146 |
+
# read previous flows
|
147 |
+
for i in range(self.num_half_frames, 0, -1):
|
148 |
+
if self.is_lmdb:
|
149 |
+
flow_path = f'{clip_name}/{frame_name}_p{i}'
|
150 |
+
else:
|
151 |
+
flow_path = (self.flow_root / clip_name / f'{frame_name}_p{i}.png')
|
152 |
+
img_bytes = self.file_client.get(flow_path, 'flow')
|
153 |
+
cat_flow = imfrombytes(img_bytes, flag='grayscale', float32=False) # uint8, [0, 255]
|
154 |
+
dx, dy = np.split(cat_flow, 2, axis=0)
|
155 |
+
flow = dequantize_flow(dx, dy, max_val=20, denorm=False) # we use max_val 20 here.
|
156 |
+
img_flows.append(flow)
|
157 |
+
# read next flows
|
158 |
+
for i in range(1, self.num_half_frames + 1):
|
159 |
+
if self.is_lmdb:
|
160 |
+
flow_path = f'{clip_name}/{frame_name}_n{i}'
|
161 |
+
else:
|
162 |
+
flow_path = (self.flow_root / clip_name / f'{frame_name}_n{i}.png')
|
163 |
+
img_bytes = self.file_client.get(flow_path, 'flow')
|
164 |
+
cat_flow = imfrombytes(img_bytes, flag='grayscale', float32=False) # uint8, [0, 255]
|
165 |
+
dx, dy = np.split(cat_flow, 2, axis=0)
|
166 |
+
flow = dequantize_flow(dx, dy, max_val=20, denorm=False) # we use max_val 20 here.
|
167 |
+
img_flows.append(flow)
|
168 |
+
|
169 |
+
# for random crop, here, img_flows and img_lqs have the same
|
170 |
+
# spatial size
|
171 |
+
img_lqs.extend(img_flows)
|
172 |
+
|
173 |
+
# randomly crop
|
174 |
+
img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale, img_gt_path)
|
175 |
+
if self.flow_root is not None:
|
176 |
+
img_lqs, img_flows = img_lqs[:self.num_frame], img_lqs[self.num_frame:]
|
177 |
+
|
178 |
+
# augmentation - flip, rotate
|
179 |
+
img_lqs.append(img_gt)
|
180 |
+
if self.flow_root is not None:
|
181 |
+
img_results, img_flows = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'], img_flows)
|
182 |
+
else:
|
183 |
+
img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])
|
184 |
+
|
185 |
+
img_results = img2tensor(img_results)
|
186 |
+
img_lqs = torch.stack(img_results[0:-1], dim=0)
|
187 |
+
img_gt = img_results[-1]
|
188 |
+
|
189 |
+
if self.flow_root is not None:
|
190 |
+
img_flows = img2tensor(img_flows)
|
191 |
+
# add the zero center flow
|
192 |
+
img_flows.insert(self.num_half_frames, torch.zeros_like(img_flows[0]))
|
193 |
+
            img_flows = torch.stack(img_flows, dim=0)

        # img_lqs: (t, c, h, w)
        # img_flows: (t, 2, h, w)
        # img_gt: (c, h, w)
        # key: str
        if self.flow_root is not None:
            return {'lq': img_lqs, 'flow': img_flows, 'gt': img_gt, 'key': key}
        else:
            return {'lq': img_lqs, 'gt': img_gt, 'key': key}

    def __len__(self):
        return len(self.keys)


@DATASET_REGISTRY.register()
class REDSRecurrentDataset(data.Dataset):
    """REDS dataset for training recurrent networks.

    The keys are generated from a meta info txt file:
    basicsr/data/meta_info/meta_info_REDS_GT.txt

    Each line contains:
    1. subfolder (clip) name; 2. frame number; 3. image shape, separated by
    a white space.
    Examples:
    000 100 (720,1280,3)
    001 100 (720,1280,3)
    ...

    Key examples: "000/00000000"
    GT (gt): Ground-Truth;
    LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.

    Args:
        opt (dict): Config for train dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            dataroot_flow (str, optional): Data root path for flow.
            meta_info_file (str): Path for meta information file.
            val_partition (str): Validation partition types. 'REDS4' or 'official'.
            io_backend (dict): IO backend type and other kwarg.
            num_frame (int): Window size for input frames.
            gt_size (int): Cropped patch size for gt patches.
            interval_list (list): Interval list for temporal augmentation.
            random_reverse (bool): Random reverse input frames.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            scale (int): Scale, which will be added automatically.
    """

    def __init__(self, opt):
        super(REDSRecurrentDataset, self).__init__()
        self.opt = opt
        self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])
        self.num_frame = opt['num_frame']

        self.keys = []
        with open(opt['meta_info_file'], 'r') as fin:
            for line in fin:
                folder, frame_num, _ = line.split(' ')
                self.keys.extend([f'{folder}/{i:08d}' for i in range(int(frame_num))])

        # remove the video clips used in validation
        if opt['val_partition'] == 'REDS4':
            val_partition = ['000', '011', '015', '020']
        elif opt['val_partition'] == 'official':
            val_partition = [f'{v:03d}' for v in range(240, 270)]
        else:
            raise ValueError(f'Wrong validation partition {opt["val_partition"]}. '
                             f"Supported ones are ['official', 'REDS4'].")
        if opt['test_mode']:
            self.keys = [v for v in self.keys if v.split('/')[0] in val_partition]
        else:
            self.keys = [v for v in self.keys if v.split('/')[0] not in val_partition]

        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.is_lmdb = False
        if self.io_backend_opt['type'] == 'lmdb':
            self.is_lmdb = True
            if hasattr(self, 'flow_root') and self.flow_root is not None:
                self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root, self.flow_root]
                self.io_backend_opt['client_keys'] = ['lq', 'gt', 'flow']
            else:
                self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
                self.io_backend_opt['client_keys'] = ['lq', 'gt']

        # temporal augmentation configs
        self.interval_list = opt.get('interval_list', [1])
        self.random_reverse = opt.get('random_reverse', False)
        interval_str = ','.join(str(x) for x in self.interval_list)
        logger = get_root_logger()
        logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
                    f'random reverse is {self.random_reverse}.')

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']
        gt_size = self.opt['gt_size']
        key = self.keys[index]
        clip_name, frame_name = key.split('/')  # key example: 000/00000000

        # determine the neighboring frames
        interval = random.choice(self.interval_list)

        # ensure not exceeding the borders
        start_frame_idx = int(frame_name)
        if start_frame_idx > 100 - self.num_frame * interval:
            start_frame_idx = random.randint(0, 100 - self.num_frame * interval)
        end_frame_idx = start_frame_idx + self.num_frame * interval

        neighbor_list = list(range(start_frame_idx, end_frame_idx, interval))

        # random reverse
        if self.random_reverse and random.random() < 0.5:
            neighbor_list.reverse()

        # get the neighboring LQ and GT frames
        img_lqs = []
        img_gts = []
        for neighbor in neighbor_list:
            if self.is_lmdb:
                img_lq_path = f'{clip_name}/{neighbor:08d}'
                img_gt_path = f'{clip_name}/{neighbor:08d}'
            else:
                img_lq_path = self.lq_root / clip_name / f'{neighbor:08d}.png'
                img_gt_path = self.gt_root / clip_name / f'{neighbor:08d}.png'

            # get LQ
            img_bytes = self.file_client.get(img_lq_path, 'lq')
            img_lq = imfrombytes(img_bytes, float32=True)
            img_lqs.append(img_lq)

            # get GT
            img_bytes = self.file_client.get(img_gt_path, 'gt')
            img_gt = imfrombytes(img_bytes, float32=True)
            img_gts.append(img_gt)

        # randomly crop
        img_gts, img_lqs = paired_random_crop(img_gts, img_lqs, gt_size, scale, img_gt_path)

        # augmentation - flip, rotate
        img_lqs.extend(img_gts)
        img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])

        img_results = img2tensor(img_results)
        img_gts = torch.stack(img_results[len(img_lqs) // 2:], dim=0)
        img_lqs = torch.stack(img_results[:len(img_lqs) // 2], dim=0)

        # img_lqs: (t, c, h, w)
        # img_gts: (t, c, h, w)
        # key: str
        return {'lq': img_lqs, 'gt': img_gts, 'key': key}

    def __len__(self):
        return len(self.keys)
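A note on usage: both REDS datasets above are driven entirely by the `opt` dict. Below is a minimal sketch of a training-time configuration for REDSRecurrentDataset; the dataset paths are placeholders and `num_frame`/`gt_size` are arbitrary choices, not values fixed by this commit.

# Hypothetical config; all paths are placeholders.
opt = {
    'dataroot_gt': 'datasets/REDS/train_sharp',
    'dataroot_lq': 'datasets/REDS/train_sharp_bicubic/X4',
    'meta_info_file': 'basicsr/data/meta_info/meta_info_REDS_GT.txt',
    'val_partition': 'REDS4',       # hold out clips 000/011/015/020
    'test_mode': False,             # training split: exclude the validation clips
    'io_backend': {'type': 'disk'},
    'num_frame': 15,                # frames per training clip
    'gt_size': 256,                 # GT patch size; LQ patch is gt_size // scale
    'interval_list': [1],
    'random_reverse': False,
    'use_hflip': True,
    'use_rot': True,
    'scale': 4,
}
dataset = REDSRecurrentDataset(opt)
sample = dataset[0]
# sample['lq']: (15, 3, 64, 64); sample['gt']: (15, 3, 256, 256); sample['key']: e.g. '001/00000000'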
StableSR/basicsr/data/single_image_dataset.py
ADDED
@@ -0,0 +1,164 @@
from os import path as osp
from torch.utils import data as data
from torchvision.transforms.functional import normalize

from basicsr.data.data_util import paths_from_lmdb
from basicsr.utils import FileClient, imfrombytes, img2tensor, rgb2ycbcr, scandir
from basicsr.utils.registry import DATASET_REGISTRY

from pathlib import Path
import random
import cv2
import numpy as np
import torch


@DATASET_REGISTRY.register()
class SingleImageDataset(data.Dataset):
    """Read only lq images in the test phase.

    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc).

    There are two modes:
    1. 'meta_info_file': Use meta information file to generate paths.
    2. 'folder': Scan folders to generate paths.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_lq (str): Data root path for lq.
            meta_info_file (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
    """

    def __init__(self, opt):
        super(SingleImageDataset, self).__init__()
        self.opt = opt
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.mean = opt['mean'] if 'mean' in opt else None
        self.std = opt['std'] if 'std' in opt else None
        self.lq_folder = opt['dataroot_lq']

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.lq_folder]
            self.io_backend_opt['client_keys'] = ['lq']
            self.paths = paths_from_lmdb(self.lq_folder)
        elif 'meta_info_file' in self.opt:
            with open(self.opt['meta_info_file'], 'r') as fin:
                self.paths = [osp.join(self.lq_folder, line.rstrip().split(' ')[0]) for line in fin]
        else:
            self.paths = sorted(list(scandir(self.lq_folder, full_path=True)))

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # load lq image
        lq_path = self.paths[index]
        img_bytes = self.file_client.get(lq_path, 'lq')
        img_lq = imfrombytes(img_bytes, float32=True)

        # color space transform
        if 'color' in self.opt and self.opt['color'] == 'y':
            img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None]

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True)
        # normalize
        if self.mean is not None or self.std is not None:
            normalize(img_lq, self.mean, self.std, inplace=True)
        return {'lq': img_lq, 'lq_path': lq_path}

    def __len__(self):
        return len(self.paths)


@DATASET_REGISTRY.register()
class SingleImageNPDataset(data.Dataset):
    """Read only lq images in the test phase.

    Read diffusion generated data for training CFW.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            gt_path: Data root path for training data. The path needs to contain the following folders:
                gts: Ground-truth images.
                inputs: Input LQ images.
                latents: The corresponding HQ latent code generated by diffusion model given the input LQ image.
                samples: The corresponding HQ image given the HQ latent code, just for verification.
            io_backend (dict): IO backend type and other kwarg.
    """

    def __init__(self, opt):
        super(SingleImageNPDataset, self).__init__()
        self.opt = opt
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.mean = opt['mean'] if 'mean' in opt else None
        self.std = opt['std'] if 'std' in opt else None
        if 'image_type' not in opt:
            opt['image_type'] = 'png'

        if isinstance(opt['gt_path'], str):
            self.gt_paths = sorted([str(x) for x in Path(opt['gt_path']+'/gts').glob('*.'+opt['image_type'])])
            self.lq_paths = sorted([str(x) for x in Path(opt['gt_path']+'/inputs').glob('*.'+opt['image_type'])])
            self.np_paths = sorted([str(x) for x in Path(opt['gt_path']+'/latents').glob('*.npy')])
            self.sample_paths = sorted([str(x) for x in Path(opt['gt_path']+'/samples').glob('*.'+opt['image_type'])])
        else:
            self.gt_paths = sorted([str(x) for x in Path(opt['gt_path'][0]+'/gts').glob('*.'+opt['image_type'])])
            self.lq_paths = sorted([str(x) for x in Path(opt['gt_path'][0]+'/inputs').glob('*.'+opt['image_type'])])
            self.np_paths = sorted([str(x) for x in Path(opt['gt_path'][0]+'/latents').glob('*.npy')])
            self.sample_paths = sorted([str(x) for x in Path(opt['gt_path'][0]+'/samples').glob('*.'+opt['image_type'])])
            if len(opt['gt_path']) > 1:
                for i in range(len(opt['gt_path'])-1):
                    self.gt_paths.extend(sorted([str(x) for x in Path(opt['gt_path'][i+1]+'/gts').glob('*.'+opt['image_type'])]))
                    self.lq_paths.extend(sorted([str(x) for x in Path(opt['gt_path'][i+1]+'/inputs').glob('*.'+opt['image_type'])]))
                    self.np_paths.extend(sorted([str(x) for x in Path(opt['gt_path'][i+1]+'/latents').glob('*.npy')]))
                    self.sample_paths.extend(sorted([str(x) for x in Path(opt['gt_path'][i+1]+'/samples').glob('*.'+opt['image_type'])]))

        assert len(self.gt_paths) == len(self.lq_paths)
        assert len(self.gt_paths) == len(self.np_paths)
        assert len(self.gt_paths) == len(self.sample_paths)

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # load lq image
        lq_path = self.lq_paths[index]
        gt_path = self.gt_paths[index]
        sample_path = self.sample_paths[index]
        np_path = self.np_paths[index]

        img_bytes = self.file_client.get(lq_path, 'lq')
        img_lq = imfrombytes(img_bytes, float32=True)

        img_bytes_gt = self.file_client.get(gt_path, 'gt')
        img_gt = imfrombytes(img_bytes_gt, float32=True)

        img_bytes_sample = self.file_client.get(sample_path, 'sample')
        img_sample = imfrombytes(img_bytes_sample, float32=True)

        latent_np = np.load(np_path)

        # color space transform
        if 'color' in self.opt and self.opt['color'] == 'y':
            img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None]
            img_gt = rgb2ycbcr(img_gt, y_only=True)[..., None]
            img_sample = rgb2ycbcr(img_sample, y_only=True)[..., None]

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True)
        img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True)
        img_sample = img2tensor(img_sample, bgr2rgb=True, float32=True)
        latent_np = torch.from_numpy(latent_np).float()
        latent_np = latent_np.to(img_gt.device)
        # normalize
        if self.mean is not None or self.std is not None:
            normalize(img_lq, self.mean, self.std, inplace=True)
            normalize(img_gt, self.mean, self.std, inplace=True)
            normalize(img_sample, self.mean, self.std, inplace=True)
        return {'lq': img_lq, 'lq_path': lq_path, 'gt': img_gt, 'gt_path': gt_path, 'latent': latent_np[0], 'latent_path': np_path, 'sample': img_sample, 'sample_path': sample_path}

    def __len__(self):
        return len(self.gt_paths)
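For inference, SingleImageDataset in folder mode only needs `dataroot_lq` and an `io_backend`; a minimal sketch (the input folder is a placeholder):

from torch.utils.data import DataLoader

opt = {
    'dataroot_lq': 'inputs/test_lr',   # placeholder folder of LQ images
    'io_backend': {'type': 'disk'},
    # optional: 'mean' and 'std' trigger in-place normalization
}
dataset = SingleImageDataset(opt)
loader = DataLoader(dataset, batch_size=1, shuffle=False)
for batch in loader:
    lq = batch['lq']               # (1, 3, h, w), RGB, float32 in [0, 1]
    lq_path = batch['lq_path'][0]  # path of the corresponding input image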
StableSR/basicsr/data/transforms.py
ADDED
@@ -0,0 +1,240 @@
import cv2
import random
import torch


def mod_crop(img, scale):
    """Mod crop images, used during testing.

    Args:
        img (ndarray): Input image.
        scale (int): Scale factor.

    Returns:
        ndarray: Result image.
    """
    img = img.copy()
    if img.ndim in (2, 3):
        h, w = img.shape[0], img.shape[1]
        h_remainder, w_remainder = h % scale, w % scale
        img = img[:h - h_remainder, :w - w_remainder, ...]
    else:
        raise ValueError(f'Wrong img ndim: {img.ndim}.')
    return img


def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
    """Paired random crop. Supports Numpy array and Tensor inputs.

    It crops lists of lq and gt images with corresponding locations.

    Args:
        img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images
            should have the same shape. If the input is an ndarray, it will
            be transformed to a list containing itself.
        img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
            should have the same shape. If the input is an ndarray, it will
            be transformed to a list containing itself.
        gt_patch_size (int): GT patch size.
        scale (int): Scale factor.
        gt_path (str): Path to ground-truth. Default: None.

    Returns:
        list[ndarray] | ndarray: GT images and LQ images. If returned results
            only have one element, just return ndarray.
    """

    if not isinstance(img_gts, list):
        img_gts = [img_gts]
    if not isinstance(img_lqs, list):
        img_lqs = [img_lqs]

    # determine input type: Numpy array or Tensor
    input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'

    if input_type == 'Tensor':
        h_lq, w_lq = img_lqs[0].size()[-2:]
        h_gt, w_gt = img_gts[0].size()[-2:]
    else:
        h_lq, w_lq = img_lqs[0].shape[0:2]
        h_gt, w_gt = img_gts[0].shape[0:2]
    lq_patch_size = gt_patch_size // scale

    if h_gt != h_lq * scale or w_gt != w_lq * scale:
        raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
                         f'multiplication of LQ ({h_lq}, {w_lq}).')
    if h_lq < lq_patch_size or w_lq < lq_patch_size:
        raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
                         f'({lq_patch_size}, {lq_patch_size}). '
                         f'Please remove {gt_path}.')

    # randomly choose top and left coordinates for lq patch
    top = random.randint(0, h_lq - lq_patch_size)
    left = random.randint(0, w_lq - lq_patch_size)

    # crop lq patch
    if input_type == 'Tensor':
        img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
    else:
        img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]

    # crop corresponding gt patch
    top_gt, left_gt = int(top * scale), int(left * scale)
    if input_type == 'Tensor':
        img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
    else:
        img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
    if len(img_gts) == 1:
        img_gts = img_gts[0]
    if len(img_lqs) == 1:
        img_lqs = img_lqs[0]
    return img_gts, img_lqs


def triplet_random_crop(img_gts, img_lqs, img_segs, gt_patch_size, scale, gt_path=None):

    if not isinstance(img_gts, list):
        img_gts = [img_gts]
    if not isinstance(img_lqs, list):
        img_lqs = [img_lqs]
    if not isinstance(img_segs, list):
        img_segs = [img_segs]

    # determine input type: Numpy array or Tensor
    input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'

    if input_type == 'Tensor':
        h_lq, w_lq = img_lqs[0].size()[-2:]
        h_gt, w_gt = img_gts[0].size()[-2:]
        h_seg, w_seg = img_segs[0].size()[-2:]
    else:
        h_lq, w_lq = img_lqs[0].shape[0:2]
        h_gt, w_gt = img_gts[0].shape[0:2]
        h_seg, w_seg = img_segs[0].shape[0:2]
    lq_patch_size = gt_patch_size // scale

    if h_gt != h_lq * scale or w_gt != w_lq * scale:
        raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
                         f'multiplication of LQ ({h_lq}, {w_lq}).')
    if h_lq < lq_patch_size or w_lq < lq_patch_size:
        raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
                         f'({lq_patch_size}, {lq_patch_size}). '
                         f'Please remove {gt_path}.')

    # randomly choose top and left coordinates for lq patch
    top = random.randint(0, h_lq - lq_patch_size)
    left = random.randint(0, w_lq - lq_patch_size)

    # crop lq patch
    if input_type == 'Tensor':
        img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
    else:
        img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]

    # crop corresponding gt patch
    top_gt, left_gt = int(top * scale), int(left * scale)
    if input_type == 'Tensor':
        img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
    else:
        img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]

    if input_type == 'Tensor':
        img_segs = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_segs]
    else:
        img_segs = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_segs]

    if len(img_gts) == 1:
        img_gts = img_gts[0]
    if len(img_lqs) == 1:
        img_lqs = img_lqs[0]
    if len(img_segs) == 1:
        img_segs = img_segs[0]

    return img_gts, img_lqs, img_segs


def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
    """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).

    We use vertical flip and transpose for rotation implementation.
    All the images in the list use the same augmentation.

    Args:
        imgs (list[ndarray] | ndarray): Images to be augmented. If the input
            is an ndarray, it will be transformed to a list.
        hflip (bool): Horizontal flip. Default: True.
        rotation (bool): Rotation. Default: True.
        flows (list[ndarray]): Flows to be augmented. If the input is an
            ndarray, it will be transformed to a list.
            Dimension is (h, w, 2). Default: None.
        return_status (bool): Return the status of flip and rotation.
            Default: False.

    Returns:
        list[ndarray] | ndarray: Augmented images and flows. If returned
            results only have one element, just return ndarray.

    """
    hflip = hflip and random.random() < 0.5
    vflip = rotation and random.random() < 0.5
    rot90 = rotation and random.random() < 0.5

    def _augment(img):
        if hflip:  # horizontal
            cv2.flip(img, 1, img)
        if vflip:  # vertical
            cv2.flip(img, 0, img)
        if rot90:
            img = img.transpose(1, 0, 2)
        return img

    def _augment_flow(flow):
        if hflip:  # horizontal
            cv2.flip(flow, 1, flow)
            flow[:, :, 0] *= -1
        if vflip:  # vertical
            cv2.flip(flow, 0, flow)
            flow[:, :, 1] *= -1
        if rot90:
            flow = flow.transpose(1, 0, 2)
            flow = flow[:, :, [1, 0]]
        return flow

    if not isinstance(imgs, list):
        imgs = [imgs]
    imgs = [_augment(img) for img in imgs]
    if len(imgs) == 1:
        imgs = imgs[0]

    if flows is not None:
        if not isinstance(flows, list):
            flows = [flows]
        flows = [_augment_flow(flow) for flow in flows]
        if len(flows) == 1:
            flows = flows[0]
        return imgs, flows
    else:
        if return_status:
            return imgs, (hflip, vflip, rot90)
        else:
            return imgs


def img_rotate(img, angle, center=None, scale=1.0):
    """Rotate image.

    Args:
        img (ndarray): Image to be rotated.
        angle (float): Rotation angle in degrees. Positive values mean
            counter-clockwise rotation.
        center (tuple[int]): Rotation center. If the center is None,
            initialize it as the center of the image. Default: None.
        scale (float): Isotropic scale factor. Default: 1.0.
    """
    (h, w) = img.shape[:2]

    if center is None:
        center = (w // 2, h // 2)

    matrix = cv2.getRotationMatrix2D(center, angle, scale)
    rotated_img = cv2.warpAffine(img, matrix, (w, h))
    return rotated_img
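A quick shape check of paired_random_crop and augment on dummy data (a sketch; the sizes are arbitrary):

import numpy as np

scale, gt_size = 4, 128
img_gt = np.random.rand(256, 256, 3).astype(np.float32)
img_lq = np.random.rand(64, 64, 3).astype(np.float32)   # 256 / 4

gt_patch, lq_patch = paired_random_crop(img_gt, img_lq, gt_size, scale)
assert gt_patch.shape == (128, 128, 3)
assert lq_patch.shape == (32, 32, 3)   # gt_size // scale

# the same random flip/rotation is applied to every image in the list
(gt_aug, lq_aug), (hflip, vflip, rot90) = augment(
    [gt_patch, lq_patch], hflip=True, rotation=True, return_status=True)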
StableSR/basicsr/data/video_test_dataset.py
ADDED
@@ -0,0 +1,283 @@
import glob
import torch
from os import path as osp
from torch.utils import data as data

from basicsr.data.data_util import duf_downsample, generate_frame_indices, read_img_seq
from basicsr.utils import get_root_logger, scandir
from basicsr.utils.registry import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class VideoTestDataset(data.Dataset):
    """Video test dataset.

    Supported datasets: Vid4, REDS4, REDSofficial.
    More generally, it supports testing datasets with the following structure:

    ::

        dataroot
        ├── subfolder1
            ├── frame000
            ├── frame001
            ├── ...
        ├── subfolder2
            ├── frame000
            ├── frame001
            ├── ...
        ├── ...

    For testing datasets, there is no need to prepare LMDB files.

    Args:
        opt (dict): Config for the test dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            io_backend (dict): IO backend type and other kwarg.
            cache_data (bool): Whether to cache testing datasets.
            name (str): Dataset name.
            meta_info_file (str): The path to the file storing the list of test folders. If not provided, all the folders
                in the dataroot will be used.
            num_frame (int): Window size for input frames.
            padding (str): Padding mode.
    """

    def __init__(self, opt):
        super(VideoTestDataset, self).__init__()
        self.opt = opt
        self.cache_data = opt['cache_data']
        self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
        self.data_info = {'lq_path': [], 'gt_path': [], 'folder': [], 'idx': [], 'border': []}
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        assert self.io_backend_opt['type'] != 'lmdb', 'No need to use lmdb during validation/test.'

        logger = get_root_logger()
        logger.info(f'Generate data info for VideoTestDataset - {opt["name"]}')
        self.imgs_lq, self.imgs_gt = {}, {}
        if 'meta_info_file' in opt:
            with open(opt['meta_info_file'], 'r') as fin:
                subfolders = [line.split(' ')[0] for line in fin]
                subfolders_lq = [osp.join(self.lq_root, key) for key in subfolders]
                subfolders_gt = [osp.join(self.gt_root, key) for key in subfolders]
        else:
            subfolders_lq = sorted(glob.glob(osp.join(self.lq_root, '*')))
            subfolders_gt = sorted(glob.glob(osp.join(self.gt_root, '*')))

        if opt['name'].lower() in ['vid4', 'reds4', 'redsofficial']:
            for subfolder_lq, subfolder_gt in zip(subfolders_lq, subfolders_gt):
                # get frame list for lq and gt
                subfolder_name = osp.basename(subfolder_lq)
                img_paths_lq = sorted(list(scandir(subfolder_lq, full_path=True)))
                img_paths_gt = sorted(list(scandir(subfolder_gt, full_path=True)))

                max_idx = len(img_paths_lq)
                assert max_idx == len(img_paths_gt), (f'Different number of images in lq ({max_idx})'
                                                      f' and gt folders ({len(img_paths_gt)})')

                self.data_info['lq_path'].extend(img_paths_lq)
                self.data_info['gt_path'].extend(img_paths_gt)
                self.data_info['folder'].extend([subfolder_name] * max_idx)
                for i in range(max_idx):
                    self.data_info['idx'].append(f'{i}/{max_idx}')
                border_l = [0] * max_idx
                for i in range(self.opt['num_frame'] // 2):
                    border_l[i] = 1
                    border_l[max_idx - i - 1] = 1
                self.data_info['border'].extend(border_l)

                # cache data or save the frame list
                if self.cache_data:
                    logger.info(f'Cache {subfolder_name} for VideoTestDataset...')
                    self.imgs_lq[subfolder_name] = read_img_seq(img_paths_lq)
                    self.imgs_gt[subfolder_name] = read_img_seq(img_paths_gt)
                else:
                    self.imgs_lq[subfolder_name] = img_paths_lq
                    self.imgs_gt[subfolder_name] = img_paths_gt
        else:
            raise ValueError(f'Non-supported video test dataset: {opt["name"]}')

    def __getitem__(self, index):
        folder = self.data_info['folder'][index]
        idx, max_idx = self.data_info['idx'][index].split('/')
        idx, max_idx = int(idx), int(max_idx)
        border = self.data_info['border'][index]
        lq_path = self.data_info['lq_path'][index]

        select_idx = generate_frame_indices(idx, max_idx, self.opt['num_frame'], padding=self.opt['padding'])

        if self.cache_data:
            imgs_lq = self.imgs_lq[folder].index_select(0, torch.LongTensor(select_idx))
            img_gt = self.imgs_gt[folder][idx]
        else:
            img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx]
            imgs_lq = read_img_seq(img_paths_lq)
            img_gt = read_img_seq([self.imgs_gt[folder][idx]])
            img_gt.squeeze_(0)

        return {
            'lq': imgs_lq,  # (t, c, h, w)
            'gt': img_gt,  # (c, h, w)
            'folder': folder,  # folder name
            'idx': self.data_info['idx'][index],  # e.g., 0/99
            'border': border,  # 1 for border, 0 for non-border
            'lq_path': lq_path  # center frame
        }

    def __len__(self):
        return len(self.data_info['gt_path'])


@DATASET_REGISTRY.register()
class VideoTestVimeo90KDataset(data.Dataset):
    """Video test dataset for Vimeo90k-Test dataset.

    It only keeps the center frame for testing.
    For testing datasets, there is no need to prepare LMDB files.

    Args:
        opt (dict): Config for the test dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            io_backend (dict): IO backend type and other kwarg.
            cache_data (bool): Whether to cache testing datasets.
            name (str): Dataset name.
            meta_info_file (str): The path to the file storing the list of test folders. If not provided, all the folders
                in the dataroot will be used.
            num_frame (int): Window size for input frames.
            padding (str): Padding mode.
    """

    def __init__(self, opt):
        super(VideoTestVimeo90KDataset, self).__init__()
        self.opt = opt
        self.cache_data = opt['cache_data']
        if self.cache_data:
            raise NotImplementedError('cache_data in Vimeo90K-Test dataset is not implemented.')
        self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
        self.data_info = {'lq_path': [], 'gt_path': [], 'folder': [], 'idx': [], 'border': []}
        neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])]

        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        assert self.io_backend_opt['type'] != 'lmdb', 'No need to use lmdb during validation/test.'

        logger = get_root_logger()
        logger.info(f'Generate data info for VideoTestDataset - {opt["name"]}')
        with open(opt['meta_info_file'], 'r') as fin:
            subfolders = [line.split(' ')[0] for line in fin]
        for idx, subfolder in enumerate(subfolders):
            gt_path = osp.join(self.gt_root, subfolder, 'im4.png')
            self.data_info['gt_path'].append(gt_path)
            lq_paths = [osp.join(self.lq_root, subfolder, f'im{i}.png') for i in neighbor_list]
            self.data_info['lq_path'].append(lq_paths)
            self.data_info['folder'].append('vimeo90k')
            self.data_info['idx'].append(f'{idx}/{len(subfolders)}')
            self.data_info['border'].append(0)

    def __getitem__(self, index):
        lq_path = self.data_info['lq_path'][index]
        gt_path = self.data_info['gt_path'][index]
        imgs_lq = read_img_seq(lq_path)
        img_gt = read_img_seq([gt_path])
        img_gt.squeeze_(0)

        return {
            'lq': imgs_lq,  # (t, c, h, w)
            'gt': img_gt,  # (c, h, w)
            'folder': self.data_info['folder'][index],  # folder name
            'idx': self.data_info['idx'][index],  # e.g., 0/843
            'border': self.data_info['border'][index],  # 0 for non-border
            'lq_path': lq_path[self.opt['num_frame'] // 2]  # center frame
        }

    def __len__(self):
        return len(self.data_info['gt_path'])


@DATASET_REGISTRY.register()
class VideoTestDUFDataset(VideoTestDataset):
    """Video test dataset for DUF dataset.

    Args:
        opt (dict): Config for the test dataset. Most of the keys are the same as VideoTestDataset.
            It has the following extra keys:
            use_duf_downsampling (bool): Whether to use duf downsampling to generate low-resolution frames.
            scale (int): Scale, which will be added automatically.
    """

    def __getitem__(self, index):
        folder = self.data_info['folder'][index]
        idx, max_idx = self.data_info['idx'][index].split('/')
        idx, max_idx = int(idx), int(max_idx)
        border = self.data_info['border'][index]
        lq_path = self.data_info['lq_path'][index]

        select_idx = generate_frame_indices(idx, max_idx, self.opt['num_frame'], padding=self.opt['padding'])

        if self.cache_data:
            if self.opt['use_duf_downsampling']:
                # read imgs_gt to generate low-resolution frames
                imgs_lq = self.imgs_gt[folder].index_select(0, torch.LongTensor(select_idx))
                imgs_lq = duf_downsample(imgs_lq, kernel_size=13, scale=self.opt['scale'])
            else:
                imgs_lq = self.imgs_lq[folder].index_select(0, torch.LongTensor(select_idx))
            img_gt = self.imgs_gt[folder][idx]
        else:
            if self.opt['use_duf_downsampling']:
                img_paths_lq = [self.imgs_gt[folder][i] for i in select_idx]
                # read imgs_gt to generate low-resolution frames
                imgs_lq = read_img_seq(img_paths_lq, require_mod_crop=True, scale=self.opt['scale'])
                imgs_lq = duf_downsample(imgs_lq, kernel_size=13, scale=self.opt['scale'])
            else:
                img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx]
                imgs_lq = read_img_seq(img_paths_lq)
            img_gt = read_img_seq([self.imgs_gt[folder][idx]], require_mod_crop=True, scale=self.opt['scale'])
            img_gt.squeeze_(0)

        return {
            'lq': imgs_lq,  # (t, c, h, w)
            'gt': img_gt,  # (c, h, w)
            'folder': folder,  # folder name
            'idx': self.data_info['idx'][index],  # e.g., 0/99
            'border': border,  # 1 for border, 0 for non-border
            'lq_path': lq_path  # center frame
        }


@DATASET_REGISTRY.register()
class VideoRecurrentTestDataset(VideoTestDataset):
    """Video test dataset for recurrent architectures, which takes LR video
    frames as input and outputs corresponding HR video frames.

    Args:
        opt (dict): Same as VideoTestDataset. Unused opt:
            padding (str): Padding mode.
    """

    def __init__(self, opt):
        super(VideoRecurrentTestDataset, self).__init__(opt)
        # Find unique folder strings
        self.folders = sorted(list(set(self.data_info['folder'])))

    def __getitem__(self, index):
        folder = self.folders[index]

        if self.cache_data:
            imgs_lq = self.imgs_lq[folder]
            imgs_gt = self.imgs_gt[folder]
        else:
            raise NotImplementedError('Without cache_data is not implemented.')

        return {
            'lq': imgs_lq,
            'gt': imgs_gt,
            'folder': folder,
        }

    def __len__(self):
        return len(self.folders)
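A sketch of a test-time configuration for VideoTestDataset (paths are placeholders; `name` must be one of the datasets supported in __init__, and the padding value is an assumption based on the modes accepted by generate_frame_indices):

opt = {
    'name': 'Vid4',
    'dataroot_gt': 'datasets/Vid4/GT',    # placeholder
    'dataroot_lq': 'datasets/Vid4/BIx4',  # placeholder
    'io_backend': {'type': 'disk'},
    'cache_data': False,
    'num_frame': 7,                  # sliding-window size
    'padding': 'reflection_circle',  # how frame indices are padded at clip borders
}
dataset = VideoTestDataset(opt)
item = dataset[0]
# item['lq']: (7, 3, h, w); item['border'] == 1 marks frames within num_frame // 2 of a clip edge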
StableSR/basicsr/data/vimeo90k_dataset.py
ADDED
@@ -0,0 +1,199 @@
import random
import torch
from pathlib import Path
from torch.utils import data as data

from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class Vimeo90KDataset(data.Dataset):
    """Vimeo90K dataset for training.

    The keys are generated from a meta info txt file:
    basicsr/data/meta_info/meta_info_Vimeo90K_train_GT.txt

    Each line contains the following items, separated by a white space.

    1. clip name;
    2. frame number;
    3. image shape

    Examples:

    ::

        00001/0001 7 (256,448,3)
        00001/0002 7 (256,448,3)

    - Key examples: "00001/0001"
    - GT (gt): Ground-Truth;
    - LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.

    The neighboring frame list for different num_frame:

    ::

        num_frame | frame list
                1 | 4
                3 | 3,4,5
                5 | 2,3,4,5,6
                7 | 1,2,3,4,5,6,7

    Args:
        opt (dict): Config for train dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            meta_info_file (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
            num_frame (int): Window size for input frames.
            gt_size (int): Cropped patch size for gt patches.
            random_reverse (bool): Random reverse input frames.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            scale (int): Scale, which will be added automatically.
    """

    def __init__(self, opt):
        super(Vimeo90KDataset, self).__init__()
        self.opt = opt
        self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])

        with open(opt['meta_info_file'], 'r') as fin:
            self.keys = [line.split(' ')[0] for line in fin]

        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.is_lmdb = False
        if self.io_backend_opt['type'] == 'lmdb':
            self.is_lmdb = True
            self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']

        # indices of input images
        self.neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])]

        # temporal augmentation configs
        self.random_reverse = opt['random_reverse']
        logger = get_root_logger()
        logger.info(f'Random reverse is {self.random_reverse}.')

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # random reverse
        if self.random_reverse and random.random() < 0.5:
            self.neighbor_list.reverse()

        scale = self.opt['scale']
        gt_size = self.opt['gt_size']
        key = self.keys[index]
        clip, seq = key.split('/')  # key example: 00001/0001

        # get the GT frame (im4.png)
        if self.is_lmdb:
            img_gt_path = f'{key}/im4'
        else:
            img_gt_path = self.gt_root / clip / seq / 'im4.png'
        img_bytes = self.file_client.get(img_gt_path, 'gt')
        img_gt = imfrombytes(img_bytes, float32=True)

        # get the neighboring LQ frames
        img_lqs = []
        for neighbor in self.neighbor_list:
            if self.is_lmdb:
                img_lq_path = f'{clip}/{seq}/im{neighbor}'
            else:
                img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png'
            img_bytes = self.file_client.get(img_lq_path, 'lq')
            img_lq = imfrombytes(img_bytes, float32=True)
            img_lqs.append(img_lq)

        # randomly crop
        img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale, img_gt_path)

        # augmentation - flip, rotate
        img_lqs.append(img_gt)
        img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])

        img_results = img2tensor(img_results)
        img_lqs = torch.stack(img_results[0:-1], dim=0)
        img_gt = img_results[-1]

        # img_lqs: (t, c, h, w)
        # img_gt: (c, h, w)
        # key: str
        return {'lq': img_lqs, 'gt': img_gt, 'key': key}

    def __len__(self):
        return len(self.keys)


@DATASET_REGISTRY.register()
class Vimeo90KRecurrentDataset(Vimeo90KDataset):

    def __init__(self, opt):
        super(Vimeo90KRecurrentDataset, self).__init__(opt)

        self.flip_sequence = opt['flip_sequence']
        self.neighbor_list = [1, 2, 3, 4, 5, 6, 7]

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # random reverse
        if self.random_reverse and random.random() < 0.5:
            self.neighbor_list.reverse()

        scale = self.opt['scale']
        gt_size = self.opt['gt_size']
        key = self.keys[index]
        clip, seq = key.split('/')  # key example: 00001/0001

        # get the neighboring LQ and GT frames
        img_lqs = []
        img_gts = []
        for neighbor in self.neighbor_list:
            if self.is_lmdb:
                img_lq_path = f'{clip}/{seq}/im{neighbor}'
                img_gt_path = f'{clip}/{seq}/im{neighbor}'
            else:
                img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png'
                img_gt_path = self.gt_root / clip / seq / f'im{neighbor}.png'
            # LQ
            img_bytes = self.file_client.get(img_lq_path, 'lq')
            img_lq = imfrombytes(img_bytes, float32=True)
            # GT
            img_bytes = self.file_client.get(img_gt_path, 'gt')
            img_gt = imfrombytes(img_bytes, float32=True)

            img_lqs.append(img_lq)
            img_gts.append(img_gt)

        # randomly crop
        img_gts, img_lqs = paired_random_crop(img_gts, img_lqs, gt_size, scale, img_gt_path)

        # augmentation - flip, rotate
        img_lqs.extend(img_gts)
        img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])

        img_results = img2tensor(img_results)
        img_lqs = torch.stack(img_results[:7], dim=0)
        img_gts = torch.stack(img_results[7:], dim=0)

        if self.flip_sequence:  # flip the sequence: 7 frames to 14 frames
            img_lqs = torch.cat([img_lqs, img_lqs.flip(0)], dim=0)
            img_gts = torch.cat([img_gts, img_gts.flip(0)], dim=0)

        # img_lqs: (t, c, h, w)
        # img_gts: (t, c, h, w)
        # key: str
        return {'lq': img_lqs, 'gt': img_gts, 'key': key}

    def __len__(self):
        return len(self.keys)
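The neighbor_list expression in Vimeo90KDataset reproduces the frame table from its docstring; a quick check:

for num_frame in (1, 3, 5, 7):
    neighbor_list = [i + (9 - num_frame) // 2 for i in range(num_frame)]
    print(num_frame, neighbor_list)
# 1 [4]
# 3 [3, 4, 5]
# 5 [2, 3, 4, 5, 6]
# 7 [1, 2, 3, 4, 5, 6, 7]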
StableSR/basicsr/losses/__init__.py
ADDED
@@ -0,0 +1,31 @@
import importlib
from copy import deepcopy
from os import path as osp

from basicsr.utils import get_root_logger, scandir
from basicsr.utils.registry import LOSS_REGISTRY
from .gan_loss import g_path_regularize, gradient_penalty_loss, r1_penalty

__all__ = ['build_loss', 'gradient_penalty_loss', 'r1_penalty', 'g_path_regularize']

# automatically scan and import loss modules for registry
# scan all the files under the 'losses' folder and collect files ending with '_loss.py'
loss_folder = osp.dirname(osp.abspath(__file__))
loss_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(loss_folder) if v.endswith('_loss.py')]
# import all the loss modules
_model_modules = [importlib.import_module(f'basicsr.losses.{file_name}') for file_name in loss_filenames]


def build_loss(opt):
    """Build loss from options.

    Args:
        opt (dict): Configuration. It must contain:
            type (str): Loss type.
    """
    opt = deepcopy(opt)
    loss_type = opt.pop('type')
    loss = LOSS_REGISTRY.get(loss_type)(**opt)
    logger = get_root_logger()
    logger.info(f'Loss [{loss.__class__.__name__}] is created.')
    return loss
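build_loss pops `type` from a copy of the options and forwards the remaining keys as constructor arguments to the registered class, so a config dict maps one-to-one onto a loss constructor. For example:

from basicsr.losses import build_loss

pixel_loss = build_loss({'type': 'L1Loss', 'loss_weight': 1.0, 'reduction': 'mean'})
# equivalent to L1Loss(loss_weight=1.0, reduction='mean'), resolved via LOSS_REGISTRY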
StableSR/basicsr/losses/basic_loss.py
ADDED
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn as nn
|
3 |
+
from torch.nn import functional as F
|
4 |
+
|
5 |
+
from basicsr.archs.vgg_arch import VGGFeatureExtractor
|
6 |
+
from basicsr.utils.registry import LOSS_REGISTRY
|
7 |
+
from .loss_util import weighted_loss
|
8 |
+
|
9 |
+
_reduction_modes = ['none', 'mean', 'sum']
|
10 |
+
|
11 |
+
|
12 |
+
@weighted_loss
|
13 |
+
def l1_loss(pred, target):
|
14 |
+
return F.l1_loss(pred, target, reduction='none')
|
15 |
+
|
16 |
+
|
17 |
+
@weighted_loss
|
18 |
+
def mse_loss(pred, target):
|
19 |
+
return F.mse_loss(pred, target, reduction='none')
|
20 |
+
|
21 |
+
|
22 |
+
@weighted_loss
|
23 |
+
def charbonnier_loss(pred, target, eps=1e-12):
|
24 |
+
return torch.sqrt((pred - target)**2 + eps)
|
25 |
+
|
26 |
+
|
27 |
+
@LOSS_REGISTRY.register()
|
28 |
+
class L1Loss(nn.Module):
|
29 |
+
"""L1 (mean absolute error, MAE) loss.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
loss_weight (float): Loss weight for L1 loss. Default: 1.0.
|
33 |
+
reduction (str): Specifies the reduction to apply to the output.
|
34 |
+
Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
|
35 |
+
"""
|
36 |
+
|
37 |
+
def __init__(self, loss_weight=1.0, reduction='mean'):
|
38 |
+
super(L1Loss, self).__init__()
|
39 |
+
if reduction not in ['none', 'mean', 'sum']:
|
40 |
+
raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')
|
41 |
+
|
42 |
+
self.loss_weight = loss_weight
|
43 |
+
self.reduction = reduction
|
44 |
+
|
45 |
+
def forward(self, pred, target, weight=None, **kwargs):
|
46 |
+
"""
|
47 |
+
Args:
|
48 |
+
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
|
49 |
+
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
|
50 |
+
weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None.
|
51 |
+
"""
|
52 |
+
return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction)
|
53 |
+
|
54 |
+
|
55 |
+
@LOSS_REGISTRY.register()
|
56 |
+
class MSELoss(nn.Module):
|
57 |
+
"""MSE (L2) loss.
|
58 |
+
|
59 |
+
Args:
|
60 |
+
loss_weight (float): Loss weight for MSE loss. Default: 1.0.
|
61 |
+
reduction (str): Specifies the reduction to apply to the output.
|
62 |
+
Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
|
63 |
+
"""
|
64 |
+
|
65 |
+
def __init__(self, loss_weight=1.0, reduction='mean'):
|
66 |
+
super(MSELoss, self).__init__()
|
67 |
+
if reduction not in ['none', 'mean', 'sum']:
|
68 |
+
raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')
|
69 |
+
|
70 |
+
self.loss_weight = loss_weight
|
71 |
+
self.reduction = reduction
|
72 |
+
|
73 |
+
def forward(self, pred, target, weight=None, **kwargs):
|
74 |
+
"""
|
75 |
+
Args:
|
76 |
+
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
|
77 |
+
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
|
78 |
+
weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None.
|
79 |
+
"""
|
80 |
+
return self.loss_weight * mse_loss(pred, target, weight, reduction=self.reduction)
|
81 |
+
|
82 |
+
|
83 |
+
@LOSS_REGISTRY.register()
|
84 |
+
class CharbonnierLoss(nn.Module):
|
85 |
+
"""Charbonnier loss (one variant of Robust L1Loss, a differentiable
|
86 |
+
variant of L1Loss).
|
87 |
+
|
88 |
+
Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
|
89 |
+
Super-Resolution".
|
90 |
+
|
91 |
+
Args:
|
92 |
+
loss_weight (float): Loss weight for L1 loss. Default: 1.0.
|
93 |
+
reduction (str): Specifies the reduction to apply to the output.
|
94 |
+
Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
|
95 |
+
eps (float): A value used to control the curvature near zero. Default: 1e-12.
|
96 |
+
"""
|
97 |
+
|
98 |
+
def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12):
|
99 |
+
super(CharbonnierLoss, self).__init__()
|
100 |
+
if reduction not in ['none', 'mean', 'sum']:
|
101 |
+
raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')
|
102 |
+
|
103 |
+
self.loss_weight = loss_weight
|
104 |
+
self.reduction = reduction
|
105 |
+
self.eps = eps
|
106 |
+
|
107 |
+
def forward(self, pred, target, weight=None, **kwargs):
|
108 |
+
"""
|
109 |
+
Args:
|
110 |
+
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
|
111 |
+
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
|
112 |
+
weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None.
|
113 |
+
"""
|
114 |
+
return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction)
|
115 |
+
|
116 |
+
|
117 |
+
@LOSS_REGISTRY.register()
|
118 |
+
class WeightedTVLoss(L1Loss):
|
119 |
+
"""Weighted TV loss.
|
120 |
+
|
121 |
+
Args:
|
122 |
+
loss_weight (float): Loss weight. Default: 1.0.
|
123 |
+
"""
|
124 |
+
|
125 |
+
def __init__(self, loss_weight=1.0, reduction='mean'):
|
126 |
+
if reduction not in ['mean', 'sum']:
|
127 |
+
raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: mean | sum')
|
128 |
+
super(WeightedTVLoss, self).__init__(loss_weight=loss_weight, reduction=reduction)
|
129 |
+
|
130 |
+
def forward(self, pred, weight=None):
|
131 |
+
if weight is None:
|
132 |
+
y_weight = None
|
133 |
+
x_weight = None
|
134 |
+
else:
|
135 |
+
y_weight = weight[:, :, :-1, :]
|
136 |
+
x_weight = weight[:, :, :, :-1]
|
137 |
+
|
138 |
+
y_diff = super().forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight)
|
139 |
+
x_diff = super().forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight)
|
140 |
+
|
141 |
+
loss = x_diff + y_diff
|
142 |
+
|
143 |
+
return loss
|
144 |
+
|
145 |
+
|
146 |
+
@LOSS_REGISTRY.register()
|
147 |
+
class PerceptualLoss(nn.Module):
|
148 |
+
"""Perceptual loss with commonly used style loss.
|
149 |
+
|
150 |
+
Args:
|
151 |
+
layer_weights (dict): The weight for each layer of vgg feature.
|
152 |
+
Here is an example: {'conv5_4': 1.}, which means the conv5_4
|
153 |
+
feature layer (before relu5_4) will be extracted with weight
|
154 |
+
1.0 in calculating losses.
|
155 |
+
vgg_type (str): The type of vgg network used as feature extractor.
|
156 |
+
            Default: 'vgg19'.
        use_input_norm (bool): If True, normalize the input image in vgg.
            Default: True.
        range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
            Default: False.
        perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
            loss will be calculated and the loss will be multiplied by the
            weight. Default: 1.0.
        style_weight (float): If `style_weight > 0`, the style loss will be
            calculated and the loss will be multiplied by the weight.
            Default: 0.
        criterion (str): Criterion used for perceptual loss. Default: 'l1'.
    """

    def __init__(self,
                 layer_weights,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 range_norm=False,
                 perceptual_weight=1.0,
                 style_weight=0.,
                 criterion='l1'):
        super(PerceptualLoss, self).__init__()
        self.perceptual_weight = perceptual_weight
        self.style_weight = style_weight
        self.layer_weights = layer_weights
        self.vgg = VGGFeatureExtractor(
            layer_name_list=list(layer_weights.keys()),
            vgg_type=vgg_type,
            use_input_norm=use_input_norm,
            range_norm=range_norm)

        self.criterion_type = criterion
        if self.criterion_type == 'l1':
            self.criterion = torch.nn.L1Loss()
        elif self.criterion_type == 'l2':
            # torch.nn has no `L2loss`; MSELoss is the l2 criterion
            self.criterion = torch.nn.MSELoss()
        elif self.criterion_type == 'fro':
            self.criterion = None
        else:
            raise NotImplementedError(f'{criterion} criterion has not been supported.')

    def forward(self, x, gt):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).
            gt (Tensor): Ground-truth tensor with shape (n, c, h, w).

        Returns:
            tuple(Tensor | None, Tensor | None): Perceptual loss and style loss.
        """
        # extract vgg features
        x_features = self.vgg(x)
        gt_features = self.vgg(gt.detach())

        # calculate perceptual loss
        if self.perceptual_weight > 0:
            percep_loss = 0
            for k in x_features.keys():
                if self.criterion_type == 'fro':
                    percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k]
                else:
                    percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
            percep_loss *= self.perceptual_weight
        else:
            percep_loss = None

        # calculate style loss
        if self.style_weight > 0:
            style_loss = 0
            for k in x_features.keys():
                if self.criterion_type == 'fro':
                    style_loss += torch.norm(
                        self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k]
                else:
                    style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(
                        gt_features[k])) * self.layer_weights[k]
            style_loss *= self.style_weight
        else:
            style_loss = None

        return percep_loss, style_loss

    def _gram_mat(self, x):
        """Calculate Gram matrix.

        Args:
            x (torch.Tensor): Tensor with shape of (n, c, h, w).

        Returns:
            torch.Tensor: Gram matrix.
        """
        n, c, h, w = x.size()
        features = x.view(n, c, w * h)
        features_t = features.transpose(1, 2)
        gram = features.bmm(features_t) / (c * h * w)
        return gram
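A minimal usage sketch of `PerceptualLoss` as defined above. The layer weights and the 64x64 inputs are illustrative assumptions, and constructing the loss downloads pretrained VGG19 weights:

import torch
from basicsr.losses.basic_loss import PerceptualLoss

# weight only the conv5_4 VGG19 feature (a common choice in SR papers)
cri_perceptual = PerceptualLoss(layer_weights={'conv5_4': 1.0})
sr = torch.rand(1, 3, 64, 64)  # restored image in [0, 1]
gt = torch.rand(1, 3, 64, 64)  # ground-truth image in [0, 1]
l_percep, l_style = cri_perceptual(sr, gt)  # l_style is None (style_weight == 0 by default)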
StableSR/basicsr/losses/gan_loss.py
ADDED
@@ -0,0 +1,207 @@
import math
import torch
from torch import autograd as autograd
from torch import nn as nn
from torch.nn import functional as F

from basicsr.utils.registry import LOSS_REGISTRY


@LOSS_REGISTRY.register()
class GANLoss(nn.Module):
    """Define GAN loss.

    Args:
        gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'wgan_softplus'
            and 'hinge'.
        real_label_val (float): The value for real label. Default: 1.0.
        fake_label_val (float): The value for fake label. Default: 0.0.
        loss_weight (float): Loss weight. Default: 1.0.
            Note that loss_weight is only for generators; it is always 1.0
            for discriminators.
    """

    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
        super(GANLoss, self).__init__()
        self.gan_type = gan_type
        self.loss_weight = loss_weight
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val

        if self.gan_type == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif self.gan_type == 'lsgan':
            self.loss = nn.MSELoss()
        elif self.gan_type == 'wgan':
            self.loss = self._wgan_loss
        elif self.gan_type == 'wgan_softplus':
            self.loss = self._wgan_softplus_loss
        elif self.gan_type == 'hinge':
            self.loss = nn.ReLU()
        else:
            raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')

    def _wgan_loss(self, input, target):
        """wgan loss.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: wgan loss.
        """
        return -input.mean() if target else input.mean()

    def _wgan_softplus_loss(self, input, target):
        """wgan loss with softplus. Softplus is a smooth approximation to the
        ReLU function.

        In StyleGAN2, it is called:
            Logistic loss for the discriminator;
            Non-saturating loss for the generator.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: wgan loss.
        """
        return F.softplus(-input).mean() if target else F.softplus(input).mean()

    def get_target_label(self, input, target_is_real):
        """Get target label.

        Args:
            input (Tensor): Input tensor.
            target_is_real (bool): Whether the target is real or fake.

        Returns:
            (bool | Tensor): Target tensor. Return bool for wgan, otherwise,
                return Tensor.
        """

        if self.gan_type in ['wgan', 'wgan_softplus']:
            return target_is_real
        target_val = (self.real_label_val if target_is_real else self.fake_label_val)
        return input.new_ones(input.size()) * target_val

    def forward(self, input, target_is_real, is_disc=False):
        """
        Args:
            input (Tensor): The input for the loss module, i.e., the network
                prediction.
            target_is_real (bool): Whether the target is real or fake.
            is_disc (bool): Whether the loss is computed for discriminators or
                not. Default: False.

        Returns:
            Tensor: GAN loss value.
        """
        target_label = self.get_target_label(input, target_is_real)
        if self.gan_type == 'hinge':
            if is_disc:  # for discriminators in hinge-gan
                input = -input if target_is_real else input
                loss = self.loss(1 + input).mean()
            else:  # for generators in hinge-gan
                loss = -input.mean()
        else:  # other gan types
            loss = self.loss(input, target_label)

        # loss_weight is always 1.0 for discriminators
        return loss if is_disc else loss * self.loss_weight
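# Illustrative usage sketch (hypothetical `net_d`, images and 0.1 weight,
# not part of the original file): typical calls during adversarial training.
#
#   cri_gan = GANLoss(gan_type='vanilla', loss_weight=0.1)
#   l_g_gan = cri_gan(net_d(fake_img), True, is_disc=False)             # generator step
#   l_d_real = cri_gan(net_d(real_img), True, is_disc=True)             # discriminator, real
#   l_d_fake = cri_gan(net_d(fake_img.detach()), False, is_disc=True)   # discriminator, fake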
@LOSS_REGISTRY.register()
class MultiScaleGANLoss(GANLoss):
    """MultiScaleGANLoss accepts a list of predictions."""

    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
        super(MultiScaleGANLoss, self).__init__(gan_type, real_label_val, fake_label_val, loss_weight)

    def forward(self, input, target_is_real, is_disc=False):
        """
        The input is a list of tensors, or a list of (a list of tensors).
        """
        if isinstance(input, list):
            loss = 0
            for pred_i in input:
                if isinstance(pred_i, list):
                    # Only compute GAN loss for the last layer
                    # in case of multiscale feature matching
                    pred_i = pred_i[-1]
                # Safe operation: calling .mean() on a 0-dim tensor is a no-op
                loss_tensor = super().forward(pred_i, target_is_real, is_disc).mean()
                loss += loss_tensor
            return loss / len(input)
        else:
            return super().forward(input, target_is_real, is_disc)


def r1_penalty(real_pred, real_img):
    """R1 regularization for discriminator. The core idea is to
    penalize the gradient on real data alone: when the
    generator distribution produces the true data distribution
    and the discriminator is equal to 0 on the data manifold, the
    gradient penalty ensures that the discriminator cannot create
    a non-zero gradient orthogonal to the data manifold without
    suffering a loss in the GAN game.

    Reference: Eq. 9 in "Which training methods for GANs do actually converge?"
    """
    grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0]
    grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
    return grad_penalty


def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3])
    grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0]
    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))

    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)

    path_penalty = (path_lengths - path_mean).pow(2).mean()

    return path_penalty, path_lengths.detach().mean(), path_mean.detach()


def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
    """Calculate gradient penalty for wgan-gp.

    Args:
        discriminator (nn.Module): Network for the discriminator.
        real_data (Tensor): Real input data.
        fake_data (Tensor): Fake input data.
        weight (Tensor): Weight tensor. Default: None.

    Returns:
        Tensor: A tensor for gradient penalty.
    """

    batch_size = real_data.size(0)
    alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1))

    # interpolate between real_data and fake_data
    interpolates = alpha * real_data + (1. - alpha) * fake_data
    interpolates = autograd.Variable(interpolates, requires_grad=True)

    disc_interpolates = discriminator(interpolates)
    gradients = autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
        retain_graph=True,
        only_inputs=True)[0]

    if weight is not None:
        gradients = gradients * weight

    gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
    if weight is not None:
        gradients_penalty /= torch.mean(weight)

    return gradients_penalty
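A minimal sketch of how the regularizers above are typically called during a discriminator update. The toy discriminator and batch shapes are assumptions for illustration, not part of this commit:

import torch
from torch import nn

from basicsr.losses.gan_loss import gradient_penalty_loss, r1_penalty

net_d = nn.Sequential(nn.Flatten(), nn.Linear(3 * 64 * 64, 1))  # toy discriminator
real_img = torch.rand(2, 3, 64, 64, requires_grad=True)
fake_img = torch.rand(2, 3, 64, 64)

# R1: penalize discriminator gradients on real images only
l_d_r1 = r1_penalty(net_d(real_img), real_img)

# WGAN-GP: penalize gradients on real/fake interpolations
l_d_gp = gradient_penalty_loss(net_d, real_img, fake_img)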
StableSR/basicsr/losses/loss_util.py
ADDED
@@ -0,0 +1,145 @@
import functools
import torch
from torch.nn import functional as F


def reduce_loss(loss, reduction):
    """Reduce loss as specified.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): Options are 'none', 'mean' and 'sum'.

    Returns:
        Tensor: Reduced loss tensor.
    """
    reduction_enum = F._Reduction.get_enum(reduction)
    # none: 0, elementwise_mean: 1, sum: 2
    if reduction_enum == 0:
        return loss
    elif reduction_enum == 1:
        return loss.mean()
    else:
        return loss.sum()


def weight_reduce_loss(loss, weight=None, reduction='mean'):
    """Apply element-wise weight and reduce loss.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Tensor): Element-wise weights. Default: None.
        reduction (str): Same as built-in losses of PyTorch. Options are
            'none', 'mean' and 'sum'. Default: 'mean'.

    Returns:
        Tensor: Loss values.
    """
    # if weight is specified, apply element-wise weight
    if weight is not None:
        assert weight.dim() == loss.dim()
        assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
        loss = loss * weight

    # if weight is not specified or reduction is sum, just reduce the loss
    if weight is None or reduction == 'sum':
        loss = reduce_loss(loss, reduction)
    # if reduction is mean, then compute mean over weight region
    elif reduction == 'mean':
        if weight.size(1) > 1:
            weight = weight.sum()
        else:
            weight = weight.sum() * loss.size(1)
        loss = loss.sum() / weight

    return loss


def weighted_loss(loss_func):
    """Create a weighted version of a given loss function.

    To use this decorator, the loss function must have the signature like
    `loss_func(pred, target, **kwargs)`. The function only needs to compute
    element-wise loss without any reduction. This decorator will add weight
    and reduction arguments to the function. The decorated function will have
    the signature like `loss_func(pred, target, weight=None, reduction='mean',
    **kwargs)`.

    :Example:

    >>> import torch
    >>> @weighted_loss
    >>> def l1_loss(pred, target):
    >>>     return (pred - target).abs()

    >>> pred = torch.Tensor([0, 2, 3])
    >>> target = torch.Tensor([1, 1, 1])
    >>> weight = torch.Tensor([1, 0, 1])

    >>> l1_loss(pred, target)
    tensor(1.3333)
    >>> l1_loss(pred, target, weight)
    tensor(1.5000)
    >>> l1_loss(pred, target, reduction='none')
    tensor([1., 1., 2.])
    >>> l1_loss(pred, target, weight, reduction='sum')
    tensor(3.)
    """

    @functools.wraps(loss_func)
    def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
        # get element-wise loss
        loss = loss_func(pred, target, **kwargs)
        loss = weight_reduce_loss(loss, weight, reduction)
        return loss

    return wrapper


def get_local_weights(residual, ksize):
    """Get local weights for generating the artifact map of LDL.

    It is only called by the `get_refined_artifact_map` function.

    Args:
        residual (Tensor): Residual between predicted and ground truth images.
        ksize (int): Size of the local window.

    Returns:
        Tensor: Weight for each pixel to be discriminated as an artifact pixel.
    """

    pad = (ksize - 1) // 2
    residual_pad = F.pad(residual, pad=[pad, pad, pad, pad], mode='reflect')

    unfolded_residual = residual_pad.unfold(2, ksize, 1).unfold(3, ksize, 1)
    pixel_level_weight = torch.var(unfolded_residual, dim=(-1, -2), unbiased=True, keepdim=True).squeeze(-1).squeeze(-1)

    return pixel_level_weight


def get_refined_artifact_map(img_gt, img_output, img_ema, ksize):
    """Calculate the artifact map of LDL
    (Details or Artifacts: A Locally Discriminative Learning Approach to Realistic Image Super-Resolution. In CVPR 2022).

    Args:
        img_gt (Tensor): Ground truth images.
        img_output (Tensor): Output images given by the optimizing model.
        img_ema (Tensor): Output images given by the ema model.
        ksize (int): Size of the local window.

    Returns:
        overall_weight: Weight for each pixel to be discriminated as an artifact pixel
            (calculated based on both local and global observations).
    """

    residual_ema = torch.sum(torch.abs(img_gt - img_ema), 1, keepdim=True)
    residual_sr = torch.sum(torch.abs(img_gt - img_output), 1, keepdim=True)

    patch_level_weight = torch.var(residual_sr.clone(), dim=(-1, -2, -3), keepdim=True)**(1 / 5)
    pixel_level_weight = get_local_weights(residual_sr.clone(), ksize)
    overall_weight = patch_level_weight * pixel_level_weight

    overall_weight[residual_sr < residual_ema] = 0

    return overall_weight
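A minimal sketch of applying the LDL artifact map above to weight a pixel loss. The tensors and `ksize=7` are illustrative assumptions, not part of this commit:

import torch
from basicsr.losses.loss_util import get_refined_artifact_map

img_gt = torch.rand(4, 3, 128, 128)      # ground-truth batch
img_output = torch.rand(4, 3, 128, 128)  # current generator output
img_ema = torch.rand(4, 3, 128, 128)     # output of the EMA generator
pixel_weight = get_refined_artifact_map(img_gt, img_output, img_ema, ksize=7)
# weight the L1 residual per pixel (the (n, 1, h, w) map broadcasts over channels)
l_g_artifacts = (pixel_weight * torch.abs(img_output - img_gt)).mean()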
StableSR/basicsr/metrics/README.md
ADDED
@@ -0,0 +1,48 @@
# Metrics

[English](README.md) **|** [简体中文](README_CN.md)

- [Conventions](#conventions)
- [PSNR and SSIM](#psnr-and-ssim)

## Conventions

Because different input types lead to different results, we make the following conventions for the inputs:

- Numpy type (usually the output of cv2)
  - UINT8: BGR, [0, 255], (h, w, c)
  - float: BGR, [0, 1], (h, w, c). Usually used for intermediate results
- Tensor type
  - float: RGB, [0, 1], (n, c, h, w)

Other conventions:

- Results with the `_pt` suffix are PyTorch results
- The PyTorch version supports batch computation
- Color conversion is done in float32; metric calculation is done in float64
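For example, a cv2 result can be converted to the Tensor convention like this (a sketch assuming BasicSR's `img2tensor` helper; the image path is a placeholder):

```python
import cv2
from basicsr.utils import img2tensor

img = cv2.imread('path/to/img.png')  # uint8, BGR, (h, w, c)
img = img.astype('float32') / 255.   # float32, BGR, [0, 1]
tensor = img2tensor(img, bgr2rgb=True, float32=True).unsqueeze(0)  # float32, RGB, (1, c, h, w)
```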
## PSNR and SSIM

The result trends of PSNR and SSIM are consistent: in general, a higher PSNR comes with a higher SSIM.
In terms of implementation, the various PSNR implementations agree with each other. SSIM has many different implementations; here we keep consistent with the original MATLAB version (refer to the [evaluation code](https://competitions.codalab.org/my/datasets/download/ebe960d8-0ec8-4846-a1a2-7c4a586a7378) of the [NTIRE17 challenge](https://competitions.codalab.org/competitions/16306#participate)).

The comparisons among the implementations are listed below.
Summary: the PyTorch implementation basically matches the MATLAB one, with slight differences when running on GPU.

- PSNR comparison

|Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU |
|:---| :---: | :---: | :---: | :---: | :---: |
|baboon| RGB | 20.419710 | 20.419710 | 20.419710 | 20.419710 |
|baboon| Y | - | 22.441898 | 22.441899 | 22.444916 |
|comic | RGB | 20.239912 | 20.239912 | 20.239912 | 20.239912 |
|comic | Y | - | 21.720398 | 21.720398 | 21.721663 |

- SSIM comparison

|Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU |
|:---| :---: | :---: | :---: | :---: | :---: |
|baboon| RGB | 0.391853 | 0.391853 | 0.391853 | 0.391853 |
|baboon| Y | - | 0.453097 | 0.453097 | 0.453171 |
|comic | RGB | 0.567738 | 0.567738 | 0.567738 | 0.567738 |
|comic | Y | - | 0.585511 | 0.585511 | 0.585522 |
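
The metrics can be computed like this (a sketch assuming this copy of `basicsr` keeps upstream BasicSR's `calculate_psnr`/`calculate_ssim` signatures; the image paths and crop border are placeholders):

```python
import cv2
from basicsr.metrics import calculate_psnr, calculate_ssim

img = cv2.imread('restored.png')  # uint8, BGR, (h, w, c)
img2 = cv2.imread('gt.png')
psnr = calculate_psnr(img, img2, crop_border=4, test_y_channel=True)
ssim = calculate_ssim(img, img2, crop_border=4, test_y_channel=True)
print(f'PSNR: {psnr:.6f} dB  SSIM: {ssim:.6f}')
```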