Sunday01 commited on
Commit
9dce458
1 Parent(s): c7e2109
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .dockerignore +9 -0
  2. .gitattributes +2 -0
  3. .gitignore +41 -0
  4. CHANGELOG.md +111 -0
  5. CHANGELOG_CN.md +111 -0
  6. Dockerfile +65 -0
  7. LICENSE +674 -0
  8. Makefile +13 -0
  9. README_CN.md +413 -0
  10. devscripts/make_readme.py +98 -0
  11. devscripts/utils.py +42 -0
  12. docker_prepare.py +28 -0
  13. fonts/Arial-Unicode-Regular.ttf +3 -0
  14. fonts/anime_ace.ttf +3 -0
  15. fonts/anime_ace_3.ttf +3 -0
  16. fonts/comic shanns 2.ttf +3 -0
  17. fonts/msgothic.ttc +3 -0
  18. fonts/msyh.ttc +3 -0
  19. manga_translator/__init__.py +7 -0
  20. manga_translator/__main__.py +79 -0
  21. manga_translator/args.py +182 -0
  22. manga_translator/colorization/__init__.py +28 -0
  23. manga_translator/colorization/common.py +24 -0
  24. manga_translator/colorization/manga_colorization_v2.py +74 -0
  25. manga_translator/colorization/manga_colorization_v2_utils/denoising/denoiser.py +118 -0
  26. manga_translator/colorization/manga_colorization_v2_utils/denoising/functions.py +102 -0
  27. manga_translator/colorization/manga_colorization_v2_utils/denoising/models.py +100 -0
  28. manga_translator/colorization/manga_colorization_v2_utils/denoising/utils.py +66 -0
  29. manga_translator/colorization/manga_colorization_v2_utils/networks/extractor.py +127 -0
  30. manga_translator/colorization/manga_colorization_v2_utils/networks/models.py +319 -0
  31. manga_translator/colorization/manga_colorization_v2_utils/utils/utils.py +44 -0
  32. manga_translator/detection/__init__.py +37 -0
  33. manga_translator/detection/common.py +146 -0
  34. manga_translator/detection/craft.py +200 -0
  35. manga_translator/detection/craft_utils/refiner.py +65 -0
  36. manga_translator/detection/craft_utils/vgg16_bn.py +71 -0
  37. manga_translator/detection/ctd.py +186 -0
  38. manga_translator/detection/ctd_utils/__init__.py +5 -0
  39. manga_translator/detection/ctd_utils/basemodel.py +250 -0
  40. manga_translator/detection/ctd_utils/textmask.py +174 -0
  41. manga_translator/detection/ctd_utils/utils/db_utils.py +706 -0
  42. manga_translator/detection/ctd_utils/utils/imgproc_utils.py +180 -0
  43. manga_translator/detection/ctd_utils/utils/io_utils.py +54 -0
  44. manga_translator/detection/ctd_utils/utils/weight_init.py +103 -0
  45. manga_translator/detection/ctd_utils/utils/yolov5_utils.py +243 -0
  46. manga_translator/detection/ctd_utils/yolov5/common.py +289 -0
  47. manga_translator/detection/ctd_utils/yolov5/yolo.py +311 -0
  48. manga_translator/detection/dbnet_convnext.py +596 -0
  49. manga_translator/detection/default.py +103 -0
  50. manga_translator/detection/default_utils/CRAFT_resnet34.py +153 -0
.dockerignore ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ result
2
+ *.ckpt
3
+ *.pt
4
+ .vscode
5
+ *.onnx
6
+ __pycache__
7
+ ocrs
8
+ models/*
9
+ test/testdata/bboxes
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.ttc filter=lfs diff=lfs merge=lfs -text
37
+ *.ttf filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ result
2
+ *.ckpt
3
+ *.pt
4
+ .vscode
5
+ *.onnx
6
+ __pycache__
7
+ ocrs
8
+ Manga
9
+ Manga-translated
10
+ /models
11
+ .env
12
+ *.local
13
+ *.local.*
14
+ test/testdata
15
+ .idea
16
+ pyvenv.cfg
17
+ Scripts
18
+ Lib
19
+ include
20
+ share
21
+
22
+ # Distribution / packaging
23
+ .Python
24
+ build/
25
+ develop-eggs/
26
+ dist/
27
+ downloads/
28
+ eggs/
29
+ .eggs/
30
+ lib/
31
+ lib64/
32
+ parts/
33
+ sdist/
34
+ var/
35
+ wheels/
36
+ share/python-wheels/
37
+ *.egg-info/
38
+ .installed.cfg
39
+ *.egg
40
+ MANIFEST
41
+ .history
CHANGELOG.md ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Changelogs
2
+
3
+ ### 2023-11-11
4
+
5
+ 1. Added new OCR model `48px`
6
+
7
+ ### 2023-05-08
8
+
9
+ 1. Added [4x-UltraSharp](https://mega.nz/folder/qZRBmaIY#nIG8KyWFcGNTuMX_XNbJ_g) upscaler
10
+
11
+ ### 2023-04-30
12
+
13
+ 1. Countless bug fixes and refactor
14
+ 2. Add [CRAFT](https://github.com/clovaai/CRAFT-pytorch) detector, enable by `--detector craft`
15
+
16
+ ### 2022-06-15
17
+
18
+ 1. Added New inpainting model LaMa MPE by [dmMaze](https://github.com/dmMaze) and set as default
19
+
20
+ ### 2022-04-23
21
+
22
+ Project version is now at beta-0.3
23
+
24
+ 1. Added English text renderer by [dmMaze](https://github.com/dmMaze)
25
+ 2. Added new CTC based OCR engine, significant speed improvement
26
+ 3. The new OCR model now support Korean
27
+
28
+ ### 2022-03-19
29
+
30
+ 1. Use new font rendering method by [pokedexter](https://github.com/pokedexter)
31
+ 2. Added manual translation UI by [rspreet92](https://github.com/rspreet92)
32
+
33
+ ### 2022-01-24
34
+
35
+ 1. Added text detection model by [dmMaze](https://github.com/dmMaze)
36
+
37
+ ### 2021-08-21
38
+
39
+ 1. New MST based text region merge algorithm, huge text region merge improvement
40
+ 2. Add baidu translator in demo mode
41
+ 3. Add google translator in demo mode
42
+ 4. Various bugfixes
43
+
44
+ ### 2021-07-29
45
+
46
+ 1. Web demo adds translator, detection resolution and target language option
47
+ 2. Slight text color extraction improvement
48
+
49
+ ### 2021-07-26
50
+
51
+ Major upgrades for all components, now we are on beta! \
52
+ Note in this version all English texts are detected as capital letters, \
53
+ You need Python >= 3.8 for `cached_property` to work
54
+
55
+ 1. Detection model upgrade
56
+ 2. OCR model upgrade, better at text color extraction
57
+ 3. Inpainting model upgrade
58
+ 4. Major text rendering improvement, faster rendering and higher quality text with shadow
59
+ 5. Slight mask generation improvement
60
+ 6. Various bugfixes
61
+ 7. Default detection resolution has been dialed back to 1536 from 2048
62
+
63
+ ### 2021-07-09
64
+
65
+ 1. Fix erroneous image rendering when inpainting is not used
66
+
67
+ ### 2021-06-18
68
+
69
+ 1. Support manual translation
70
+ 2. Support detection and rendering of angled texts
71
+
72
+ ### 2021-06-13
73
+
74
+ 1. Text mask completion is now based on CRF, mask quality is drastically improved
75
+
76
+ ### 2021-06-10
77
+
78
+ 1. Improve text rendering
79
+
80
+ ### 2021-06-09
81
+
82
+ 1. New text region based text direction detection method
83
+ 2. Support running demo as web service
84
+
85
+ ### 2021-05-20
86
+
87
+ 1. Text detection model is now based on DBNet with ResNet34 backbone
88
+ 2. OCR model is now trained with more English sentences
89
+ 3. Inpaint model is now based on [AOT](https://arxiv.org/abs/2104.01431) which requires far less memory
90
+ 4. Default inpainting resolution is now increased to 2048, thanks to the new inpainting model
91
+ 5. Support merging hyphenated English words
92
+
93
+ ### 2021-05-11
94
+
95
+ 1. Add youdao translate and set as default translator
96
+
97
+ ### 2021-05-06
98
+
99
+ 1. Text detection model is now based on DBNet with ResNet101 backbone
100
+ 2. OCR model is now deeper
101
+ 3. Default detection resolution has been increased to 2048 from 1536
102
+
103
+ Note this version is slightly better at handling English texts, other than that it is worse in every other ways
104
+
105
+ ### 2021-03-04
106
+
107
+ 1. Added inpainting model
108
+
109
+ ### 2021-02-17
110
+
111
+ 1. First version launched
CHANGELOG_CN.md ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 更新日志 (中文)
2
+
3
+ ### 2023-11-11
4
+
5
+ 1. 添加了新的OCR模型`48px`
6
+
7
+ ### 2023-05-08
8
+
9
+ 1. 添加了[4x-UltraSharp](https://mega.nz/folder/qZRBmaIY#nIG8KyWFcGNTuMX_XNbJ_g)超分辨率
10
+
11
+ ### 2023-04-30
12
+
13
+ 1. 无数bug修复和重构
14
+ 2. 添加了[CRAFT](https://github.com/clovaai/CRAFT-pytorch)文本检测器,使用`--detector craft`启用
15
+
16
+ ### 2022-06-15
17
+
18
+ 1. 增加了来自[dmMaze](https://github.com/dmMaze)的LaMa MPE图像修补模型
19
+
20
+ ### 2022-04-23
21
+
22
+ 版本更新为beta-0.3
23
+
24
+ 1. 增加了来自[dmMaze](https://github.com/dmMaze)的英语文本渲染器
25
+ 2. 增加了基于CTC的OCR模型,识别速度大幅提升
26
+ 3. 新OCR模型增加韩语识别支持
27
+
28
+ ### 2022-03-19
29
+
30
+ 1. 增加了来自[pokedexter](https://github.com/pokedexter)的新文本渲染器
31
+ 2. 增加了来自[rspreet92](https://github.com/rspreet92)的人工翻译页面
32
+
33
+ ### 2022-01-24
34
+
35
+ 1. 增加了来自[dmMaze](https://github.com/dmMaze)的文本检测模型
36
+
37
+ ### 2021-08-21
38
+
39
+ 1. 文本区域合并算法更新,先已经实现几乎完美文本行合并
40
+ 2. 增加演示模式百度翻译支持
41
+ 3. 增加演示模式谷歌翻译支持
42
+ 4. 各类 bug 修复
43
+
44
+ ### 2021-07-29
45
+
46
+ 1. 网页版增加翻译器、分辨率和目标语言选项
47
+ 2. 文本颜色提取小腹提升
48
+
49
+ ### 2021-07-26
50
+
51
+ 程序所有组件都大幅升级,本程序现已进入 beta 版本! \
52
+ 注意:该版本所有英文检测只会输出大写字母。\
53
+ 你需要 Python>=3.8 版本才能运行
54
+
55
+ 1. 检测模型升级
56
+ 2. OCR 模型升级,文本颜色抽取质量大幅提升
57
+ 3. 图像修补模型升级
58
+ 4. 文本渲染升级,渲染更快,并支持更高质量的文本和文本阴影渲染
59
+ 5. 文字掩膜补全算法小幅提升
60
+ 6. 各类 BUG 修复
61
+ 7. 默认检测分辨率为 1536
62
+
63
+ ### 2021-07-09
64
+
65
+ 1. 修复不使用 inpainting 时图片错误
66
+
67
+ ### 2021-06-18
68
+
69
+ 1. 增加手动翻译选项
70
+ 2. 支持倾斜文本的识别和渲染
71
+
72
+ ### 2021-06-13
73
+
74
+ 1. 文字掩膜补全算法更新为基于 CRF 算法,补全质量大幅提升
75
+
76
+ ### 2021-06-10
77
+
78
+ 1. 完善文本渲染
79
+
80
+ ### 2021-06-09
81
+
82
+ 1. 使用基于区域的文本方向检测,文本方向检测效果大幅提升
83
+ 2. 增加 web 服务功能
84
+
85
+ ### 2021-05-20
86
+
87
+ 1. 检测模型更新为基于 ResNet34 的 DBNet
88
+ 2. OCR 模型更新增加更多英语预料训练
89
+ 3. 图像修补模型升级到基于[AOT](https://arxiv.org/abs/2104.01431)的模型,占用更少显存
90
+ 4. 图像修补默认分辨率增加到 2048
91
+ 5. 支持多行英语单词合并
92
+
93
+ ### 2021-05-11
94
+
95
+ 1. 增加并默认使用有道翻译
96
+
97
+ ### 2021-05-06
98
+
99
+ 1. 检测模型更新为基于 ResNet101 的 DBNet
100
+ 2. OCR 模型更新更深
101
+ 3. 默认检测分辨率增加到 2048
102
+
103
+ 注意这个版本除了英文检测稍微好一些,其他方面都不如之前版本
104
+
105
+ ### 2021-03-04
106
+
107
+ 1. 添加图片修补模型
108
+
109
+ ### 2021-02-17
110
+
111
+ 1. 初步版本发布
Dockerfile ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM pytorch/pytorch:latest
2
+
3
+ RUN useradd -m -u 1000 user
4
+
5
+ WORKDIR /app
6
+
7
+ RUN apt-get update
8
+ RUN DEBIAN_FRONTEND=noninteractive TZ=asia/shanghai apt-get -y install tzdata
9
+ # 设置缓存环境变量
10
+ ENV TRANSFORMERS_CACHE=/app/cache
11
+ ENV DEEPL_AUTH_KEY="6e4907cd-8926-42e7-aa5d-7561363c82b1:fx"
12
+ ENV OPENAI_API_KEY="sk-yuBWvBk2lTQoJFYP24A03515D46041429f907dE81cC3F04e"
13
+ ENV OPENAI_HTTP_PROXY="https://www.ygxdapi.top"
14
+ RUN mkdir -p /app/cache
15
+ # Assume root to install required dependencies
16
+ RUN apt-get install -y git g++ ffmpeg libsm6 libxext6 libvulkan-dev
17
+
18
+
19
+ # Install pip dependencies
20
+
21
+ COPY --chown=user requirements.txt /app/requirements.txt
22
+
23
+ RUN pip install -r /app/requirements.txt
24
+ RUN pip install torchvision --force-reinstall
25
+ RUN pip install "numpy<2.0"
26
+ # RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
27
+
28
+ RUN apt-get remove -y g++ && \
29
+ apt-get autoremove -y
30
+
31
+ # Copy app
32
+ COPY --chown=user . /app
33
+
34
+ # Prepare models
35
+ RUN python -u docker_prepare.py
36
+
37
+ RUN rm -rf /tmp
38
+
39
+ # Add /app to Python module path
40
+ ENV PYTHONPATH="${PYTHONPATH}:/app"
41
+
42
+ WORKDIR /app
43
+ RUN mkdir -p /app/result && chmod 777 /app/result
44
+ RUN mkdir -p /app/models/translators && chmod 777 /app/models/translators
45
+ RUN mkdir -p /app/models/upscaling && chmod 777 /app/models/upscaling
46
+ RUN mkdir -p /app/cache/models && chmod 777 /app/cache/models
47
+ RUN mkdir -p /app/cache/.locks && chmod 777 /app/cache/.locks
48
+ RUN mkdir -p /app/cache/models--kha-white--manga-ocr-base && chmod 777 /app/cache/models--kha-white--manga-ocr-base
49
+ RUN mkdir -p /app && chmod 777 /app
50
+
51
+ ENTRYPOINT ["python", "-m", "manga_translator", "-v", "--mode", "web", "--host", "0.0.0.0", "--port", "7860", "--font-size", "28", "--font-size-offset", "5", "--unclip-ratio", "1.1", "--det-invert"]
52
+ # # ENTRYPOINT ["python", "-m", "manga_translator", "-v", "--mode", "web", "--host", "0.0.0.0", "--port", "7860", "--use-cuda", "--use-inpainting"]
53
+
54
+
55
+ # 使用指定的基础镜像
56
+ # FROM zyddnys/manga-image-translator:main
57
+
58
+ # 复制需要的文件到容器中
59
+ # COPY ./../../translate_demo.py /app/translate_demo.py
60
+
61
+ # # 暴露端口
62
+ # EXPOSE 7860
63
+
64
+ # # 运行命令
65
+ # CMD ["--verbose", "--log-web", "--mode", "web", "--use-inpainting", "--use-cuda", "--host=0.0.0.0", "--port=7860"]
LICENSE ADDED
@@ -0,0 +1,674 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ GNU GENERAL PUBLIC LICENSE
2
+ Version 3, 29 June 2007
3
+
4
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5
+ Everyone is permitted to copy and distribute verbatim copies
6
+ of this license document, but changing it is not allowed.
7
+
8
+ Preamble
9
+
10
+ The GNU General Public License is a free, copyleft license for
11
+ software and other kinds of works.
12
+
13
+ The licenses for most software and other practical works are designed
14
+ to take away your freedom to share and change the works. By contrast,
15
+ the GNU General Public License is intended to guarantee your freedom to
16
+ share and change all versions of a program--to make sure it remains free
17
+ software for all its users. We, the Free Software Foundation, use the
18
+ GNU General Public License for most of our software; it applies also to
19
+ any other work released this way by its authors. You can apply it to
20
+ your programs, too.
21
+
22
+ When we speak of free software, we are referring to freedom, not
23
+ price. Our General Public Licenses are designed to make sure that you
24
+ have the freedom to distribute copies of free software (and charge for
25
+ them if you wish), that you receive source code or can get it if you
26
+ want it, that you can change the software or use pieces of it in new
27
+ free programs, and that you know you can do these things.
28
+
29
+ To protect your rights, we need to prevent others from denying you
30
+ these rights or asking you to surrender the rights. Therefore, you have
31
+ certain responsibilities if you distribute copies of the software, or if
32
+ you modify it: responsibilities to respect the freedom of others.
33
+
34
+ For example, if you distribute copies of such a program, whether
35
+ gratis or for a fee, you must pass on to the recipients the same
36
+ freedoms that you received. You must make sure that they, too, receive
37
+ or can get the source code. And you must show them these terms so they
38
+ know their rights.
39
+
40
+ Developers that use the GNU GPL protect your rights with two steps:
41
+ (1) assert copyright on the software, and (2) offer you this License
42
+ giving you legal permission to copy, distribute and/or modify it.
43
+
44
+ For the developers' and authors' protection, the GPL clearly explains
45
+ that there is no warranty for this free software. For both users' and
46
+ authors' sake, the GPL requires that modified versions be marked as
47
+ changed, so that their problems will not be attributed erroneously to
48
+ authors of previous versions.
49
+
50
+ Some devices are designed to deny users access to install or run
51
+ modified versions of the software inside them, although the manufacturer
52
+ can do so. This is fundamentally incompatible with the aim of
53
+ protecting users' freedom to change the software. The systematic
54
+ pattern of such abuse occurs in the area of products for individuals to
55
+ use, which is precisely where it is most unacceptable. Therefore, we
56
+ have designed this version of the GPL to prohibit the practice for those
57
+ products. If such problems arise substantially in other domains, we
58
+ stand ready to extend this provision to those domains in future versions
59
+ of the GPL, as needed to protect the freedom of users.
60
+
61
+ Finally, every program is threatened constantly by software patents.
62
+ States should not allow patents to restrict development and use of
63
+ software on general-purpose computers, but in those that do, we wish to
64
+ avoid the special danger that patents applied to a free program could
65
+ make it effectively proprietary. To prevent this, the GPL assures that
66
+ patents cannot be used to render the program non-free.
67
+
68
+ The precise terms and conditions for copying, distribution and
69
+ modification follow.
70
+
71
+ TERMS AND CONDITIONS
72
+
73
+ 0. Definitions.
74
+
75
+ "This License" refers to version 3 of the GNU General Public License.
76
+
77
+ "Copyright" also means copyright-like laws that apply to other kinds of
78
+ works, such as semiconductor masks.
79
+
80
+ "The Program" refers to any copyrightable work licensed under this
81
+ License. Each licensee is addressed as "you". "Licensees" and
82
+ "recipients" may be individuals or organizations.
83
+
84
+ To "modify" a work means to copy from or adapt all or part of the work
85
+ in a fashion requiring copyright permission, other than the making of an
86
+ exact copy. The resulting work is called a "modified version" of the
87
+ earlier work or a work "based on" the earlier work.
88
+
89
+ A "covered work" means either the unmodified Program or a work based
90
+ on the Program.
91
+
92
+ To "propagate" a work means to do anything with it that, without
93
+ permission, would make you directly or secondarily liable for
94
+ infringement under applicable copyright law, except executing it on a
95
+ computer or modifying a private copy. Propagation includes copying,
96
+ distribution (with or without modification), making available to the
97
+ public, and in some countries other activities as well.
98
+
99
+ To "convey" a work means any kind of propagation that enables other
100
+ parties to make or receive copies. Mere interaction with a user through
101
+ a computer network, with no transfer of a copy, is not conveying.
102
+
103
+ An interactive user interface displays "Appropriate Legal Notices"
104
+ to the extent that it includes a convenient and prominently visible
105
+ feature that (1) displays an appropriate copyright notice, and (2)
106
+ tells the user that there is no warranty for the work (except to the
107
+ extent that warranties are provided), that licensees may convey the
108
+ work under this License, and how to view a copy of this License. If
109
+ the interface presents a list of user commands or options, such as a
110
+ menu, a prominent item in the list meets this criterion.
111
+
112
+ 1. Source Code.
113
+
114
+ The "source code" for a work means the preferred form of the work
115
+ for making modifications to it. "Object code" means any non-source
116
+ form of a work.
117
+
118
+ A "Standard Interface" means an interface that either is an official
119
+ standard defined by a recognized standards body, or, in the case of
120
+ interfaces specified for a particular programming language, one that
121
+ is widely used among developers working in that language.
122
+
123
+ The "System Libraries" of an executable work include anything, other
124
+ than the work as a whole, that (a) is included in the normal form of
125
+ packaging a Major Component, but which is not part of that Major
126
+ Component, and (b) serves only to enable use of the work with that
127
+ Major Component, or to implement a Standard Interface for which an
128
+ implementation is available to the public in source code form. A
129
+ "Major Component", in this context, means a major essential component
130
+ (kernel, window system, and so on) of the specific operating system
131
+ (if any) on which the executable work runs, or a compiler used to
132
+ produce the work, or an object code interpreter used to run it.
133
+
134
+ The "Corresponding Source" for a work in object code form means all
135
+ the source code needed to generate, install, and (for an executable
136
+ work) run the object code and to modify the work, including scripts to
137
+ control those activities. However, it does not include the work's
138
+ System Libraries, or general-purpose tools or generally available free
139
+ programs which are used unmodified in performing those activities but
140
+ which are not part of the work. For example, Corresponding Source
141
+ includes interface definition files associated with source files for
142
+ the work, and the source code for shared libraries and dynamically
143
+ linked subprograms that the work is specifically designed to require,
144
+ such as by intimate data communication or control flow between those
145
+ subprograms and other parts of the work.
146
+
147
+ The Corresponding Source need not include anything that users
148
+ can regenerate automatically from other parts of the Corresponding
149
+ Source.
150
+
151
+ The Corresponding Source for a work in source code form is that
152
+ same work.
153
+
154
+ 2. Basic Permissions.
155
+
156
+ All rights granted under this License are granted for the term of
157
+ copyright on the Program, and are irrevocable provided the stated
158
+ conditions are met. This License explicitly affirms your unlimited
159
+ permission to run the unmodified Program. The output from running a
160
+ covered work is covered by this License only if the output, given its
161
+ content, constitutes a covered work. This License acknowledges your
162
+ rights of fair use or other equivalent, as provided by copyright law.
163
+
164
+ You may make, run and propagate covered works that you do not
165
+ convey, without conditions so long as your license otherwise remains
166
+ in force. You may convey covered works to others for the sole purpose
167
+ of having them make modifications exclusively for you, or provide you
168
+ with facilities for running those works, provided that you comply with
169
+ the terms of this License in conveying all material for which you do
170
+ not control copyright. Those thus making or running the covered works
171
+ for you must do so exclusively on your behalf, under your direction
172
+ and control, on terms that prohibit them from making any copies of
173
+ your copyrighted material outside their relationship with you.
174
+
175
+ Conveying under any other circumstances is permitted solely under
176
+ the conditions stated below. Sublicensing is not allowed; section 10
177
+ makes it unnecessary.
178
+
179
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180
+
181
+ No covered work shall be deemed part of an effective technological
182
+ measure under any applicable law fulfilling obligations under article
183
+ 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184
+ similar laws prohibiting or restricting circumvention of such
185
+ measures.
186
+
187
+ When you convey a covered work, you waive any legal power to forbid
188
+ circumvention of technological measures to the extent such circumvention
189
+ is effected by exercising rights under this License with respect to
190
+ the covered work, and you disclaim any intention to limit operation or
191
+ modification of the work as a means of enforcing, against the work's
192
+ users, your or third parties' legal rights to forbid circumvention of
193
+ technological measures.
194
+
195
+ 4. Conveying Verbatim Copies.
196
+
197
+ You may convey verbatim copies of the Program's source code as you
198
+ receive it, in any medium, provided that you conspicuously and
199
+ appropriately publish on each copy an appropriate copyright notice;
200
+ keep intact all notices stating that this License and any
201
+ non-permissive terms added in accord with section 7 apply to the code;
202
+ keep intact all notices of the absence of any warranty; and give all
203
+ recipients a copy of this License along with the Program.
204
+
205
+ You may charge any price or no price for each copy that you convey,
206
+ and you may offer support or warranty protection for a fee.
207
+
208
+ 5. Conveying Modified Source Versions.
209
+
210
+ You may convey a work based on the Program, or the modifications to
211
+ produce it from the Program, in the form of source code under the
212
+ terms of section 4, provided that you also meet all of these conditions:
213
+
214
+ a) The work must carry prominent notices stating that you modified
215
+ it, and giving a relevant date.
216
+
217
+ b) The work must carry prominent notices stating that it is
218
+ released under this License and any conditions added under section
219
+ 7. This requirement modifies the requirement in section 4 to
220
+ "keep intact all notices".
221
+
222
+ c) You must license the entire work, as a whole, under this
223
+ License to anyone who comes into possession of a copy. This
224
+ License will therefore apply, along with any applicable section 7
225
+ additional terms, to the whole of the work, and all its parts,
226
+ regardless of how they are packaged. This License gives no
227
+ permission to license the work in any other way, but it does not
228
+ invalidate such permission if you have separately received it.
229
+
230
+ d) If the work has interactive user interfaces, each must display
231
+ Appropriate Legal Notices; however, if the Program has interactive
232
+ interfaces that do not display Appropriate Legal Notices, your
233
+ work need not make them do so.
234
+
235
+ A compilation of a covered work with other separate and independent
236
+ works, which are not by their nature extensions of the covered work,
237
+ and which are not combined with it such as to form a larger program,
238
+ in or on a volume of a storage or distribution medium, is called an
239
+ "aggregate" if the compilation and its resulting copyright are not
240
+ used to limit the access or legal rights of the compilation's users
241
+ beyond what the individual works permit. Inclusion of a covered work
242
+ in an aggregate does not cause this License to apply to the other
243
+ parts of the aggregate.
244
+
245
+ 6. Conveying Non-Source Forms.
246
+
247
+ You may convey a covered work in object code form under the terms
248
+ of sections 4 and 5, provided that you also convey the
249
+ machine-readable Corresponding Source under the terms of this License,
250
+ in one of these ways:
251
+
252
+ a) Convey the object code in, or embodied in, a physical product
253
+ (including a physical distribution medium), accompanied by the
254
+ Corresponding Source fixed on a durable physical medium
255
+ customarily used for software interchange.
256
+
257
+ b) Convey the object code in, or embodied in, a physical product
258
+ (including a physical distribution medium), accompanied by a
259
+ written offer, valid for at least three years and valid for as
260
+ long as you offer spare parts or customer support for that product
261
+ model, to give anyone who possesses the object code either (1) a
262
+ copy of the Corresponding Source for all the software in the
263
+ product that is covered by this License, on a durable physical
264
+ medium customarily used for software interchange, for a price no
265
+ more than your reasonable cost of physically performing this
266
+ conveying of source, or (2) access to copy the
267
+ Corresponding Source from a network server at no charge.
268
+
269
+ c) Convey individual copies of the object code with a copy of the
270
+ written offer to provide the Corresponding Source. This
271
+ alternative is allowed only occasionally and noncommercially, and
272
+ only if you received the object code with such an offer, in accord
273
+ with subsection 6b.
274
+
275
+ d) Convey the object code by offering access from a designated
276
+ place (gratis or for a charge), and offer equivalent access to the
277
+ Corresponding Source in the same way through the same place at no
278
+ further charge. You need not require recipients to copy the
279
+ Corresponding Source along with the object code. If the place to
280
+ copy the object code is a network server, the Corresponding Source
281
+ may be on a different server (operated by you or a third party)
282
+ that supports equivalent copying facilities, provided you maintain
283
+ clear directions next to the object code saying where to find the
284
+ Corresponding Source. Regardless of what server hosts the
285
+ Corresponding Source, you remain obligated to ensure that it is
286
+ available for as long as needed to satisfy these requirements.
287
+
288
+ e) Convey the object code using peer-to-peer transmission, provided
289
+ you inform other peers where the object code and Corresponding
290
+ Source of the work are being offered to the general public at no
291
+ charge under subsection 6d.
292
+
293
+ A separable portion of the object code, whose source code is excluded
294
+ from the Corresponding Source as a System Library, need not be
295
+ included in conveying the object code work.
296
+
297
+ A "User Product" is either (1) a "consumer product", which means any
298
+ tangible personal property which is normally used for personal, family,
299
+ or household purposes, or (2) anything designed or sold for incorporation
300
+ into a dwelling. In determining whether a product is a consumer product,
301
+ doubtful cases shall be resolved in favor of coverage. For a particular
302
+ product received by a particular user, "normally used" refers to a
303
+ typical or common use of that class of product, regardless of the status
304
+ of the particular user or of the way in which the particular user
305
+ actually uses, or expects or is expected to use, the product. A product
306
+ is a consumer product regardless of whether the product has substantial
307
+ commercial, industrial or non-consumer uses, unless such uses represent
308
+ the only significant mode of use of the product.
309
+
310
+ "Installation Information" for a User Product means any methods,
311
+ procedures, authorization keys, or other information required to install
312
+ and execute modified versions of a covered work in that User Product from
313
+ a modified version of its Corresponding Source. The information must
314
+ suffice to ensure that the continued functioning of the modified object
315
+ code is in no case prevented or interfered with solely because
316
+ modification has been made.
317
+
318
+ If you convey an object code work under this section in, or with, or
319
+ specifically for use in, a User Product, and the conveying occurs as
320
+ part of a transaction in which the right of possession and use of the
321
+ User Product is transferred to the recipient in perpetuity or for a
322
+ fixed term (regardless of how the transaction is characterized), the
323
+ Corresponding Source conveyed under this section must be accompanied
324
+ by the Installation Information. But this requirement does not apply
325
+ if neither you nor any third party retains the ability to install
326
+ modified object code on the User Product (for example, the work has
327
+ been installed in ROM).
328
+
329
+ The requirement to provide Installation Information does not include a
330
+ requirement to continue to provide support service, warranty, or updates
331
+ for a work that has been modified or installed by the recipient, or for
332
+ the User Product in which it has been modified or installed. Access to a
333
+ network may be denied when the modification itself materially and
334
+ adversely affects the operation of the network or violates the rules and
335
+ protocols for communication across the network.
336
+
337
+ Corresponding Source conveyed, and Installation Information provided,
338
+ in accord with this section must be in a format that is publicly
339
+ documented (and with an implementation available to the public in
340
+ source code form), and must require no special password or key for
341
+ unpacking, reading or copying.
342
+
343
+ 7. Additional Terms.
344
+
345
+ "Additional permissions" are terms that supplement the terms of this
346
+ License by making exceptions from one or more of its conditions.
347
+ Additional permissions that are applicable to the entire Program shall
348
+ be treated as though they were included in this License, to the extent
349
+ that they are valid under applicable law. If additional permissions
350
+ apply only to part of the Program, that part may be used separately
351
+ under those permissions, but the entire Program remains governed by
352
+ this License without regard to the additional permissions.
353
+
354
+ When you convey a copy of a covered work, you may at your option
355
+ remove any additional permissions from that copy, or from any part of
356
+ it. (Additional permissions may be written to require their own
357
+ removal in certain cases when you modify the work.) You may place
358
+ additional permissions on material, added by you to a covered work,
359
+ for which you have or can give appropriate copyright permission.
360
+
361
+ Notwithstanding any other provision of this License, for material you
362
+ add to a covered work, you may (if authorized by the copyright holders of
363
+ that material) supplement the terms of this License with terms:
364
+
365
+ a) Disclaiming warranty or limiting liability differently from the
366
+ terms of sections 15 and 16 of this License; or
367
+
368
+ b) Requiring preservation of specified reasonable legal notices or
369
+ author attributions in that material or in the Appropriate Legal
370
+ Notices displayed by works containing it; or
371
+
372
+ c) Prohibiting misrepresentation of the origin of that material, or
373
+ requiring that modified versions of such material be marked in
374
+ reasonable ways as different from the original version; or
375
+
376
+ d) Limiting the use for publicity purposes of names of licensors or
377
+ authors of the material; or
378
+
379
+ e) Declining to grant rights under trademark law for use of some
380
+ trade names, trademarks, or service marks; or
381
+
382
+ f) Requiring indemnification of licensors and authors of that
383
+ material by anyone who conveys the material (or modified versions of
384
+ it) with contractual assumptions of liability to the recipient, for
385
+ any liability that these contractual assumptions directly impose on
386
+ those licensors and authors.
387
+
388
+ All other non-permissive additional terms are considered "further
389
+ restrictions" within the meaning of section 10. If the Program as you
390
+ received it, or any part of it, contains a notice stating that it is
391
+ governed by this License along with a term that is a further
392
+ restriction, you may remove that term. If a license document contains
393
+ a further restriction but permits relicensing or conveying under this
394
+ License, you may add to a covered work material governed by the terms
395
+ of that license document, provided that the further restriction does
396
+ not survive such relicensing or conveying.
397
+
398
+ If you add terms to a covered work in accord with this section, you
399
+ must place, in the relevant source files, a statement of the
400
+ additional terms that apply to those files, or a notice indicating
401
+ where to find the applicable terms.
402
+
403
+ Additional terms, permissive or non-permissive, may be stated in the
404
+ form of a separately written license, or stated as exceptions;
405
+ the above requirements apply either way.
406
+
407
+ 8. Termination.
408
+
409
+ You may not propagate or modify a covered work except as expressly
410
+ provided under this License. Any attempt otherwise to propagate or
411
+ modify it is void, and will automatically terminate your rights under
412
+ this License (including any patent licenses granted under the third
413
+ paragraph of section 11).
414
+
415
+ However, if you cease all violation of this License, then your
416
+ license from a particular copyright holder is reinstated (a)
417
+ provisionally, unless and until the copyright holder explicitly and
418
+ finally terminates your license, and (b) permanently, if the copyright
419
+ holder fails to notify you of the violation by some reasonable means
420
+ prior to 60 days after the cessation.
421
+
422
+ Moreover, your license from a particular copyright holder is
423
+ reinstated permanently if the copyright holder notifies you of the
424
+ violation by some reasonable means, this is the first time you have
425
+ received notice of violation of this License (for any work) from that
426
+ copyright holder, and you cure the violation prior to 30 days after
427
+ your receipt of the notice.
428
+
429
+ Termination of your rights under this section does not terminate the
430
+ licenses of parties who have received copies or rights from you under
431
+ this License. If your rights have been terminated and not permanently
432
+ reinstated, you do not qualify to receive new licenses for the same
433
+ material under section 10.
434
+
435
+ 9. Acceptance Not Required for Having Copies.
436
+
437
+ You are not required to accept this License in order to receive or
438
+ run a copy of the Program. Ancillary propagation of a covered work
439
+ occurring solely as a consequence of using peer-to-peer transmission
440
+ to receive a copy likewise does not require acceptance. However,
441
+ nothing other than this License grants you permission to propagate or
442
+ modify any covered work. These actions infringe copyright if you do
443
+ not accept this License. Therefore, by modifying or propagating a
444
+ covered work, you indicate your acceptance of this License to do so.
445
+
446
+ 10. Automatic Licensing of Downstream Recipients.
447
+
448
+ Each time you convey a covered work, the recipient automatically
449
+ receives a license from the original licensors, to run, modify and
450
+ propagate that work, subject to this License. You are not responsible
451
+ for enforcing compliance by third parties with this License.
452
+
453
+ An "entity transaction" is a transaction transferring control of an
454
+ organization, or substantially all assets of one, or subdividing an
455
+ organization, or merging organizations. If propagation of a covered
456
+ work results from an entity transaction, each party to that
457
+ transaction who receives a copy of the work also receives whatever
458
+ licenses to the work the party's predecessor in interest had or could
459
+ give under the previous paragraph, plus a right to possession of the
460
+ Corresponding Source of the work from the predecessor in interest, if
461
+ the predecessor has it or can get it with reasonable efforts.
462
+
463
+ You may not impose any further restrictions on the exercise of the
464
+ rights granted or affirmed under this License. For example, you may
465
+ not impose a license fee, royalty, or other charge for exercise of
466
+ rights granted under this License, and you may not initiate litigation
467
+ (including a cross-claim or counterclaim in a lawsuit) alleging that
468
+ any patent claim is infringed by making, using, selling, offering for
469
+ sale, or importing the Program or any portion of it.
470
+
471
+ 11. Patents.
472
+
473
+ A "contributor" is a copyright holder who authorizes use under this
474
+ License of the Program or a work on which the Program is based. The
475
+ work thus licensed is called the contributor's "contributor version".
476
+
477
+ A contributor's "essential patent claims" are all patent claims
478
+ owned or controlled by the contributor, whether already acquired or
479
+ hereafter acquired, that would be infringed by some manner, permitted
480
+ by this License, of making, using, or selling its contributor version,
481
+ but do not include claims that would be infringed only as a
482
+ consequence of further modification of the contributor version. For
483
+ purposes of this definition, "control" includes the right to grant
484
+ patent sublicenses in a manner consistent with the requirements of
485
+ this License.
486
+
487
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
488
+ patent license under the contributor's essential patent claims, to
489
+ make, use, sell, offer for sale, import and otherwise run, modify and
490
+ propagate the contents of its contributor version.
491
+
492
+ In the following three paragraphs, a "patent license" is any express
493
+ agreement or commitment, however denominated, not to enforce a patent
494
+ (such as an express permission to practice a patent or covenant not to
495
+ sue for patent infringement). To "grant" such a patent license to a
496
+ party means to make such an agreement or commitment not to enforce a
497
+ patent against the party.
498
+
499
+ If you convey a covered work, knowingly relying on a patent license,
500
+ and the Corresponding Source of the work is not available for anyone
501
+ to copy, free of charge and under the terms of this License, through a
502
+ publicly available network server or other readily accessible means,
503
+ then you must either (1) cause the Corresponding Source to be so
504
+ available, or (2) arrange to deprive yourself of the benefit of the
505
+ patent license for this particular work, or (3) arrange, in a manner
506
+ consistent with the requirements of this License, to extend the patent
507
+ license to downstream recipients. "Knowingly relying" means you have
508
+ actual knowledge that, but for the patent license, your conveying the
509
+ covered work in a country, or your recipient's use of the covered work
510
+ in a country, would infringe one or more identifiable patents in that
511
+ country that you have reason to believe are valid.
512
+
513
+ If, pursuant to or in connection with a single transaction or
514
+ arrangement, you convey, or propagate by procuring conveyance of, a
515
+ covered work, and grant a patent license to some of the parties
516
+ receiving the covered work authorizing them to use, propagate, modify
517
+ or convey a specific copy of the covered work, then the patent license
518
+ you grant is automatically extended to all recipients of the covered
519
+ work and works based on it.
520
+
521
+ A patent license is "discriminatory" if it does not include within
522
+ the scope of its coverage, prohibits the exercise of, or is
523
+ conditioned on the non-exercise of one or more of the rights that are
524
+ specifically granted under this License. You may not convey a covered
525
+ work if you are a party to an arrangement with a third party that is
526
+ in the business of distributing software, under which you make payment
527
+ to the third party based on the extent of your activity of conveying
528
+ the work, and under which the third party grants, to any of the
529
+ parties who would receive the covered work from you, a discriminatory
530
+ patent license (a) in connection with copies of the covered work
531
+ conveyed by you (or copies made from those copies), or (b) primarily
532
+ for and in connection with specific products or compilations that
533
+ contain the covered work, unless you entered into that arrangement,
534
+ or that patent license was granted, prior to 28 March 2007.
535
+
536
+ Nothing in this License shall be construed as excluding or limiting
537
+ any implied license or other defenses to infringement that may
538
+ otherwise be available to you under applicable patent law.
539
+
540
+ 12. No Surrender of Others' Freedom.
541
+
542
+ If conditions are imposed on you (whether by court order, agreement or
543
+ otherwise) that contradict the conditions of this License, they do not
544
+ excuse you from the conditions of this License. If you cannot convey a
545
+ covered work so as to satisfy simultaneously your obligations under this
546
+ License and any other pertinent obligations, then as a consequence you may
547
+ not convey it at all. For example, if you agree to terms that obligate you
548
+ to collect a royalty for further conveying from those to whom you convey
549
+ the Program, the only way you could satisfy both those terms and this
550
+ License would be to refrain entirely from conveying the Program.
551
+
552
+ 13. Use with the GNU Affero General Public License.
553
+
554
+ Notwithstanding any other provision of this License, you have
555
+ permission to link or combine any covered work with a work licensed
556
+ under version 3 of the GNU Affero General Public License into a single
557
+ combined work, and to convey the resulting work. The terms of this
558
+ License will continue to apply to the part which is the covered work,
559
+ but the special requirements of the GNU Affero General Public License,
560
+ section 13, concerning interaction through a network will apply to the
561
+ combination as such.
562
+
563
+ 14. Revised Versions of this License.
564
+
565
+ The Free Software Foundation may publish revised and/or new versions of
566
+ the GNU General Public License from time to time. Such new versions will
567
+ be similar in spirit to the present version, but may differ in detail to
568
+ address new problems or concerns.
569
+
570
+ Each version is given a distinguishing version number. If the
571
+ Program specifies that a certain numbered version of the GNU General
572
+ Public License "or any later version" applies to it, you have the
573
+ option of following the terms and conditions either of that numbered
574
+ version or of any later version published by the Free Software
575
+ Foundation. If the Program does not specify a version number of the
576
+ GNU General Public License, you may choose any version ever published
577
+ by the Free Software Foundation.
578
+
579
+ If the Program specifies that a proxy can decide which future
580
+ versions of the GNU General Public License can be used, that proxy's
581
+ public statement of acceptance of a version permanently authorizes you
582
+ to choose that version for the Program.
583
+
584
+ Later license versions may give you additional or different
585
+ permissions. However, no additional obligations are imposed on any
586
+ author or copyright holder as a result of your choosing to follow a
587
+ later version.
588
+
589
+ 15. Disclaimer of Warranty.
590
+
591
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592
+ APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593
+ HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594
+ OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595
+ THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596
+ PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597
+ IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598
+ ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599
+
600
+ 16. Limitation of Liability.
601
+
602
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603
+ WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604
+ THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605
+ GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606
+ USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607
+ DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608
+ PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609
+ EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610
+ SUCH DAMAGES.
611
+
612
+ 17. Interpretation of Sections 15 and 16.
613
+
614
+ If the disclaimer of warranty and limitation of liability provided
615
+ above cannot be given local legal effect according to their terms,
616
+ reviewing courts shall apply local law that most closely approximates
617
+ an absolute waiver of all civil liability in connection with the
618
+ Program, unless a warranty or assumption of liability accompanies a
619
+ copy of the Program in return for a fee.
620
+
621
+ END OF TERMS AND CONDITIONS
622
+
623
+ How to Apply These Terms to Your New Programs
624
+
625
+ If you develop a new program, and you want it to be of the greatest
626
+ possible use to the public, the best way to achieve this is to make it
627
+ free software which everyone can redistribute and change under these terms.
628
+
629
+ To do so, attach the following notices to the program. It is safest
630
+ to attach them to the start of each source file to most effectively
631
+ state the exclusion of warranty; and each file should have at least
632
+ the "copyright" line and a pointer to where the full notice is found.
633
+
634
+ <one line to give the program's name and a brief idea of what it does.>
635
+ Copyright (C) <year> <name of author>
636
+
637
+ This program is free software: you can redistribute it and/or modify
638
+ it under the terms of the GNU General Public License as published by
639
+ the Free Software Foundation, either version 3 of the License, or
640
+ (at your option) any later version.
641
+
642
+ This program is distributed in the hope that it will be useful,
643
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
644
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645
+ GNU General Public License for more details.
646
+
647
+ You should have received a copy of the GNU General Public License
648
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
649
+
650
+ Also add information on how to contact you by electronic and paper mail.
651
+
652
+ If the program does terminal interaction, make it output a short
653
+ notice like this when it starts in an interactive mode:
654
+
655
+ <program> Copyright (C) <year> <name of author>
656
+ This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657
+ This is free software, and you are welcome to redistribute it
658
+ under certain conditions; type `show c' for details.
659
+
660
+ The hypothetical commands `show w' and `show c' should show the appropriate
661
+ parts of the General Public License. Of course, your program's commands
662
+ might be different; for a GUI interface, you would use an "about box".
663
+
664
+ You should also get your employer (if you work as a programmer) or school,
665
+ if any, to sign a "copyright disclaimer" for the program, if necessary.
666
+ For more information on this, and how to apply and follow the GNU GPL, see
667
+ <https://www.gnu.org/licenses/>.
668
+
669
+ The GNU General Public License does not permit incorporating your program
670
+ into proprietary programs. If your program is a subroutine library, you
671
+ may consider it more useful to permit linking proprietary applications with
672
+ the library. If this is what you want to do, use the GNU Lesser General
673
+ Public License instead of this License. But first, please read
674
+ <https://www.gnu.org/licenses/why-not-lgpl.html>.
Makefile ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ build-image:
2
+ docker rmi manga-image-translator || true
3
+ docker build . --tag=manga-image-translator
4
+
5
+ run-web-server:
6
+ docker run --gpus all -p 5003:5003 --ipc=host --rm zyddnys/manga-image-translator:main \
7
+ --target-lang=ENG \
8
+ --manga2eng \
9
+ --verbose \
10
+ --mode=web \
11
+ --use-gpu \
12
+ --host=0.0.0.0 \
13
+ --port=5003
README_CN.md ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 漫画图片翻译器 (中文说明)
2
+
3
+ > 一键翻译各类图片内文字\
4
+ > [English](README.md) | [更新日志](CHANGELOG_CN.md) \
5
+ > 欢迎加入我们的 Discord <https://discord.gg/Ak8APNy4vb>
6
+
7
+ 针对群内、各个图站上大量不太可能会有人去翻译的图片设计,让我这种日语小白能够勉强看懂图片\
8
+ 主要支持日语,汉语、英文和韩语\
9
+ 支持图片修补和嵌字\
10
+ 该项目是[求闻转译志](https://github.com/PatchyVideo/MMDOCR-HighPerformance)的 v2 版本
11
+
12
+ **只是初步版本,我们需要您的帮助完善**\
13
+ 这个项目目前只完成了简单的 demo,依旧存在大量不完善的地方,我们需要您的帮助完善这个项目!
14
+
15
+ ## 支持我们
16
+
17
+ 请支持我们使用 GPU 服务器,谢谢!
18
+
19
+ - Ko-fi: <https://ko-fi.com/voilelabs>
20
+ - Patreon: <https://www.patreon.com/voilelabs>
21
+ - 爱发电: <https://afdian.net/@voilelabs>
22
+
23
+ ## 在线版
24
+
25
+ 官方演示站 (由 zyddnys 维护): <https://cotrans.touhou.ai/>\
26
+ 镜像站 (由 Eidenz 维护): <https://manga.eidenz.com/>\
27
+ 浏览器脚本 (由 QiroNT 维护): <https://greasyfork.org/scripts/437569>
28
+
29
+ - 注意如果在线版无法访问说明 Google GCP 又在重启我的服务器,此时请等待我重新开启服务。
30
+ - 在线版使用的是目前 main 分支最新版本。
31
+
32
+ ## 使用说明
33
+
34
+ ```bash
35
+ # 首先,确信你的机器安装了 Python 3.8 及以上版本,和相应的编译工具
36
+ $ python --version
37
+ Python 3.8.13
38
+
39
+ # 拉取仓库
40
+ $ git clone https://github.com/zyddnys/manga-image-translator.git
41
+
42
+ # 安装依赖
43
+ $ pip install -r requirements.txt
44
+ ```
45
+
46
+ 注意:`pydensecrf` 和其他pip包可能需要操作系统的相应编译工具(如Debian的build-essential)。
47
+
48
+ [使用谷歌翻译时可选]\
49
+ 申请有道翻译或者 DeepL 的 API,把你的 `APP_KEY` 和 `APP_SECRET` 或 `AUTH_KEY` 写入 `translators/key.py` 中。
50
+
51
+ ### 翻译器列表
52
+
53
+ | 名称 | 是否需要 API Key | 是否离线可用 | 其他说明 |
54
+ | -------------- | ------- | ------- | ----------------------------------------------------- |
55
+ | google | | | |
56
+ | youdao | ✔️ | | 需要 `YOUDAO_APP_KEY` 和 `YOUDAO_SECRET_KEY` |
57
+ | baidu | ✔️ | | 需要 `BAIDU_APP_ID` 和 `BAIDU_SECRET_KEY` |
58
+ | deepl | ✔️ | | 需要 `DEEPL_AUTH_KEY` |
59
+ | caiyun | ✔️ | | 需要 `CAIYUN_TOKEN` |
60
+ | gpt3 | ✔️ | | Implements text-davinci-003. Requires `OPENAI_API_KEY`|
61
+ | gpt3.5 | ✔️ | | Implements gpt-3.5-turbo. Requires `OPENAI_API_KEY` |
62
+ | gpt4 | ✔️ | | Implements gpt-4. Requires `OPENAI_API_KEY` |
63
+ | papago | | | |
64
+ | sakura | | |需要`SAKURA_API_BASE`|
65
+ | offline | | ✔️ | 自动选择可用的离线模型,只是选择器 |
66
+ | sugoi | | ✔️ | 只能翻译英文 |
67
+ | m2m100 | | ✔️ | 可以翻译所有语言 |
68
+ | m2m100_big | | ✔️ | 带big的是完整尺寸,不带是精简版 |
69
+ | none | | ✔️ | 翻译成空白文本 |
70
+ | mbart50 | | ✔️ | |
71
+ | original | | ✔️ | 翻译成源文本 |
72
+
73
+ ### 语言代码列表
74
+
75
+ 可以填入 `--target-lang` 参数
76
+
77
+ ```yaml
78
+ CHS: Chinese (Simplified)
79
+ CHT: Chinese (Traditional)
80
+ CSY: Czech
81
+ NLD: Dutch
82
+ ENG: English
83
+ FRA: French
84
+ DEU: German
85
+ HUN: Hungarian
86
+ ITA: Italian
87
+ JPN: Japanese
88
+ KOR: Korean
89
+ PLK: Polish
90
+ PTB: Portuguese (Brazil)
91
+ ROM: Romanian
92
+ RUS: Russian
93
+ ESP: Spanish
94
+ TRK: Turkish
95
+ VIN: Vietnames
96
+ ARA: Arabic
97
+ SRP: Serbian
98
+ HRV: Croatian
99
+ THA: Thai
100
+ IND: Indonesian
101
+ FIL: Filipino (Tagalog)
102
+ ```
103
+
104
+ <!-- Auto generated start (See devscripts/make_readme.py) -->
105
+ ## 选项
106
+
107
+ ```text
108
+ -h, --help show this help message and exit
109
+ -m, --mode {demo,batch,web,web_client,ws,api}
110
+ Run demo in single image demo mode (demo), batch
111
+ translation mode (batch), web service mode (web)
112
+ -i, --input INPUT [INPUT ...] Path to an image file if using demo mode, or path to an
113
+ image folder if using batch mode
114
+ -o, --dest DEST Path to the destination folder for translated images in
115
+ batch mode
116
+ -l, --target-lang {CHS,CHT,CSY,NLD,ENG,FRA,DEU,HUN,ITA,JPN,KOR,PLK,PTB,ROM,RUS,ESP,TRK,UKR,VIN,ARA,CNR,SRP,HRV,THA,IND,FIL}
117
+ Destination language
118
+ -v, --verbose Print debug info and save intermediate images in result
119
+ folder
120
+ -f, --format {png,webp,jpg,xcf,psd,pdf} Output format of the translation.
121
+ --attempts ATTEMPTS Retry attempts on encountered error. -1 means infinite
122
+ times.
123
+ --ignore-errors Skip image on encountered error.
124
+ --overwrite Overwrite already translated images in batch mode.
125
+ --skip-no-text Skip image without text (Will not be saved).
126
+ --model-dir MODEL_DIR Model directory (by default ./models in project root)
127
+ --use-gpu Turn on/off gpu (automatic selection between mps or cuda)
128
+ --use-gpu-limited Turn on/off gpu (excluding offline translator)
129
+ --detector {default,ctd,craft,none} Text detector used for creating a text mask from an
130
+ image, DO NOT use craft for manga, it's not designed
131
+ for it
132
+ --ocr {32px,48px,48px_ctc,mocr} Optical character recognition (OCR) model to use
133
+ --use-mocr-merge Use bbox merge when Manga OCR inference.
134
+ --inpainter {default,lama_large,lama_mpe,sd,none,original}
135
+ Inpainting model to use
136
+ --upscaler {waifu2x,esrgan,4xultrasharp} Upscaler to use. --upscale-ratio has to be set for it
137
+ to take effect
138
+ --upscale-ratio UPSCALE_RATIO Image upscale ratio applied before detection. Can
139
+ improve text detection.
140
+ --colorizer {mc2} Colorization model to use.
141
+ --translator {google,youdao,baidu,deepl,papago,caiyun,gpt3,gpt3.5,gpt4,none,original,offline,nllb,nllb_big,sugoi,jparacrawl,jparacrawl_big,m2m100,sakura}
142
+ Language translator to use
143
+ --translator-chain TRANSLATOR_CHAIN Output of one translator goes in another. Example:
144
+ --translator-chain "google:JPN;sugoi:ENG".
145
+ --selective-translation SELECTIVE_TRANSLATION
146
+ Select a translator based on detected language in
147
+ image. Note the first translation service acts as
148
+ default if the language isn't defined. Example:
149
+ --translator-chain "google:JPN;sugoi:ENG".
150
+ --revert-upscaling Downscales the previously upscaled image after
151
+ translation back to original size (Use with --upscale-
152
+ ratio).
153
+ --detection-size DETECTION_SIZE Size of image used for detection
154
+ --det-rotate Rotate the image for detection. Might improve
155
+ detection.
156
+ --det-auto-rotate Rotate the image for detection to prefer vertical
157
+ textlines. Might improve detection.
158
+ --det-invert Invert the image colors for detection. Might improve
159
+ detection.
160
+ --det-gamma-correct Applies gamma correction for detection. Might improve
161
+ detection.
162
+ --unclip-ratio UNCLIP_RATIO How much to extend text skeleton to form bounding box
163
+ --box-threshold BOX_THRESHOLD Threshold for bbox generation
164
+ --text-threshold TEXT_THRESHOLD Threshold for text detection
165
+ --min-text-length MIN_TEXT_LENGTH Minimum text length of a text region
166
+ --no-text-lang-skip Dont skip text that is seemingly already in the target
167
+ language.
168
+ --inpainting-size INPAINTING_SIZE Size of image used for inpainting (too large will
169
+ result in OOM)
170
+ --inpainting-precision {fp32,fp16,bf16} Inpainting precision for lama, use bf16 while you can.
171
+ --colorization-size COLORIZATION_SIZE Size of image used for colorization. Set to -1 to use
172
+ full image size
173
+ --denoise-sigma DENOISE_SIGMA Used by colorizer and affects color strength, range
174
+ from 0 to 255 (default 30). -1 turns it off.
175
+ --mask-dilation-offset MASK_DILATION_OFFSET By how much to extend the text mask to remove left-over
176
+ text pixels of the original image.
177
+ --font-size FONT_SIZE Use fixed font size for rendering
178
+ --font-size-offset FONT_SIZE_OFFSET Offset font size by a given amount, positive number
179
+ increase font size and vice versa
180
+ --font-size-minimum FONT_SIZE_MINIMUM Minimum output font size. Default is
181
+ image_sides_sum/200
182
+ --font-color FONT_COLOR Overwrite the text fg/bg color detected by the OCR
183
+ model. Use hex string without the "#" such as FFFFFF
184
+ for a white foreground or FFFFFF:000000 to also have a
185
+ black background around the text.
186
+ --line-spacing LINE_SPACING Line spacing is font_size * this value. Default is 0.01
187
+ for horizontal text and 0.2 for vertical.
188
+ --force-horizontal Force text to be rendered horizontally
189
+ --force-vertical Force text to be rendered vertically
190
+ --align-left Align rendered text left
191
+ --align-center Align rendered text centered
192
+ --align-right Align rendered text right
193
+ --uppercase Change text to uppercase
194
+ --lowercase Change text to lowercase
195
+ --no-hyphenation If renderer should be splitting up words using a hyphen
196
+ character (-)
197
+ --manga2eng Render english text translated from manga with some
198
+ additional typesetting. Ignores some other argument
199
+ options
200
+ --gpt-config GPT_CONFIG Path to GPT config file, more info in README
201
+ --use-mtpe Turn on/off machine translation post editing (MTPE) on
202
+ the command line (works only on linux right now)
203
+ --save-text Save extracted text and translations into a text file.
204
+ --save-text-file SAVE_TEXT_FILE Like --save-text but with a specified file path.
205
+ --filter-text FILTER_TEXT Filter regions by their text with a regex. Example
206
+ usage: --text-filter ".*badtext.*"
207
+ --skip-lang Skip translation if source image is one of the provide languages,
208
+ use comma to separate multiple languages. Example: JPN,ENG
209
+ --prep-manual Prepare for manual typesetting by outputting blank,
210
+ inpainted images, plus copies of the original for
211
+ reference
212
+ --font-path FONT_PATH Path to font file
213
+ --gimp-font GIMP_FONT Font family to use for gimp rendering.
214
+ --host HOST Used by web module to decide which host to attach to
215
+ --port PORT Used by web module to decide which port to attach to
216
+ --nonce NONCE Used by web module as secret for securing internal web
217
+ server communication
218
+ --ws-url WS_URL Server URL for WebSocket mode
219
+ --save-quality SAVE_QUALITY Quality of saved JPEG image, range from 0 to 100 with
220
+ 100 being best
221
+ --ignore-bubble IGNORE_BUBBLE The threshold for ignoring text in non bubble areas,
222
+ with valid values ranging from 1 to 50, does not ignore
223
+ others. Recommendation 5 to 10. If it is too low,
224
+ normal bubble areas may be ignored, and if it is too
225
+ large, non bubble areas may be considered normal
226
+ bubbles
227
+ ```
228
+
229
+ <!-- Auto generated end -->
230
+
231
+ ### 使用命令行执行
232
+
233
+ ```bash
234
+ # 如果机器有支持 CUDA 的 NVIDIA GPU,可以添加 `--use-gpu` 参数
235
+ # 使用 `--use-gpu-limited` 将需要使用大量显存的翻译交由CPU执行,这样可以减少显存占用
236
+ # 使用 `--translator=<翻译器名称>` 来指定翻译器
237
+ # 使用 `--target-lang=<语言代码>` 来指定目标语言
238
+ # 将 <图片文件路径> 替换为图片的路径
239
+ # 如果你要翻译的图片比较小或者模糊,可以使用upscaler提升图像大小与质量,从而提升检测翻译效果
240
+ $ python -m manga_translator --verbose --use-gpu --translator=google --target-lang=CHS -i <path_to_image_file>
241
+ # 结果会存放到 result 文件夹里
242
+ ```
243
+
244
+ #### 使用命令行批量翻译
245
+
246
+ ```bash
247
+ # 其它参数如上
248
+ # 使用 `--mode batch` 开启批量翻译模式
249
+ # 将 <图片文件夹路径> 替换为图片文件夹的路径
250
+ $ python -m manga_translator --verbose --mode batch --use-gpu --translator=google --target-lang=CHS -i <图片文件夹路径>
251
+ # 结果会存放到 `<图片文件夹路径>-translated` 文件夹里
252
+ ```
253
+
254
+ ### 使用浏览器 (Web 服务器)
255
+
256
+ ```bash
257
+ # 其它参数如上
258
+ # 使用 `--mode web` 开启 Web 服务器模式
259
+ $ python -m manga_translator --verbose --mode web --use-gpu
260
+ # 程序服务会开启在 http://127.0.0.1:5003
261
+ ```
262
+
263
+ 程序提供两个请求模式:同步模式和异步模式。\
264
+ 同步模式下你的 HTTP POST 请求会一直等待直到翻译完成。\
265
+ 异步模式下你的 HTTP POST 会立刻返回一个 `task_id`,你可以使用这个 `task_id` 去定期轮询得到翻译的状态。
266
+
267
+ #### 同步模式
268
+
269
+ 1. POST 提交一个带图片,名字是 file 的 form 到 <http://127.0.0.1:5003/run>
270
+ 2. 等待返回
271
+ 3. 从得到的 `task_id` 去 result 文件夹里取结果,例如通过 Nginx 暴露 result 下的内容
272
+
273
+ #### 异步模式
274
+
275
+ 1. POST 提交一个带图片,名字是 file 的 form 到<http://127.0.0.1:5003/submit>
276
+ 2. 你会得到一个 `task_id`
277
+ 3. 通过这个 `task_id` 你可以定期发送 POST 轮询请求 JSON `{"taskid": <task_id>}` 到 <http://127.0.0.1:5003/task-state>
278
+ 4. 当返回的状态是 `finished`、`error` 或 `error-lang` 时代表翻译完成
279
+ 5. 去 result 文件夹里取结果,例如通过 Nginx 暴露 result 下的内容
280
+
281
+ #### 人工翻译
282
+
283
+ 人工翻译允许代替机翻手动填入翻译后文本
284
+
285
+ POST 提交一个带图片,名字是 file 的 form 到 <http://127.0.0.1:5003/manual-translate>,并等待返回
286
+
287
+ 你会得到一个 JSON 数组,例如:
288
+
289
+ ```json
290
+ {
291
+ "task_id": "12c779c9431f954971cae720eb104499",
292
+ "status": "pending",
293
+ "trans_result": [
294
+ {
295
+ "s": "☆上司来ちゃった……",
296
+ "t": ""
297
+ }
298
+ ]
299
+ }
300
+ ```
301
+
302
+ 将翻译后内容填入 t 字符串:
303
+
304
+ ```json
305
+ {
306
+ "task_id": "12c779c9431f954971cae720eb104499",
307
+ "status": "pending",
308
+ "trans_result": [
309
+ {
310
+ "s": "☆上司来ちゃった……",
311
+ "t": "☆上司来了..."
312
+ }
313
+ ]
314
+ }
315
+ ```
316
+
317
+ 将该 JSON 发送到 <http://127.0.0.1:5003/post-manual-result>,并等待返回\
318
+ 之后就可以从得到的 `task_id` 去 result 文件夹里取结果,例如通过 Nginx 暴露 result 下的内容
319
+
320
+ ## 下一步
321
+
322
+ 列一下以后完善这个项目需要做的事,欢迎贡献!
323
+
324
+ 1. 使用基于扩散模型的图像修补算法,不过这样图像修补会慢很多
325
+ 2. ~~【重要,请求帮助】目前的文字渲染引擎只能勉强看,和 Adobe 的渲染引擎差距明显,我们需要您的帮助完善文本渲染!~~
326
+ 3. ~~我尝试了在 OCR 模型里提取文字颜色,均以失败告终,现在只能用 DPGMM 凑活提取文字颜色,但是效果欠佳,我会尽量完善文字颜色提取,如果您有好的建议请尽管提 issue~~
327
+ 4. ~~文本检测目前不能很好处理英语和韩语,等图片修补模型训练好了我就会训练新版的文字检测模型。~~ ~~韩语支持在做了~~
328
+ 5. 文本渲染区域是根据检测到的文本,而不是汽包决定的,这样可以处理没有汽包的图片但是不能很好进行英语嵌字,目前没有想到好的解决方案。
329
+ 6. [Ryota et al.](https://arxiv.org/abs/2012.14271) 提出了获取配对漫画作为训练数据,训练可以结合图片内容进行翻译的模型,未来可以考虑把大量图片 VQVAE 化,输入 nmt 的 encoder 辅助翻译,而不是分框提取 tag 辅助翻译,这样可以处理范围更广的图片。这需要我们也获取大量配对翻译漫画/图片数据,以及训练 VQVAE 模型。
330
+ 7. 求闻转译志针对视频设计,未来这个项目要能优化到可以处理视频,提取文本颜色用于生成 ass 字幕,进一步辅助东方视频字幕组工作。甚至可以涂改视频内容,去掉视频内字幕。
331
+ 8. ~~结合传统算法的 mask 生成优化,目前在测试 CRF 相关算法。~~
332
+ 9. ~~尚不支持倾斜文本区域合并~~
333
+
334
+ ## 效果图
335
+
336
+ 以下样例可能并未经常更新,可能不能代表当前主分支版本的效果。
337
+
338
+ <table>
339
+ <thead>
340
+ <tr>
341
+ <th align="center" width="50%">原始图片</th>
342
+ <th align="center" width="50%">翻译后图片</th>
343
+ </tr>
344
+ </thead>
345
+ <tbody>
346
+ <tr>
347
+ <td align="center" width="50%">
348
+ <a href="https://user-images.githubusercontent.com/31543482/232265329-6a560438-e887-4f7f-b6a1-a61b8648f781.png">
349
+ <img alt="佐藤さんは知っていた - 猫麦" src="https://user-images.githubusercontent.com/31543482/232265329-6a560438-e887-4f7f-b6a1-a61b8648f781.png" />
350
+ </a>
351
+ <br />
352
+ <a href="https://twitter.com/09ra_19ra/status/1647079591109103617/photo/1">(Source @09ra_19ra)</a>
353
+ </td>
354
+ <td align="center" width="50%">
355
+ <a href="https://user-images.githubusercontent.com/31543482/232265339-514c843a-0541-4a24-b3bc-1efa6915f757.png">
356
+ <img alt="Output" src="https://user-images.githubusercontent.com/31543482/232265339-514c843a-0541-4a24-b3bc-1efa6915f757.png" />
357
+ </a>
358
+ <br />
359
+ <a href="https://user-images.githubusercontent.com/31543482/232265376-01a4557d-8120-4b6b-b062-f271df177770.png">(Mask)</a>
360
+ </td>
361
+ </tr>
362
+ <tr>
363
+ <td align="center" width="50%">
364
+ <a href="https://user-images.githubusercontent.com/31543482/232265479-a15c43b5-0f00-489c-9b04-5dfbcd48c432.png">
365
+ <img alt="Gris finds out she's of royal blood - VERTI" src="https://user-images.githubusercontent.com/31543482/232265479-a15c43b5-0f00-489c-9b04-5dfbcd48c432.png" />
366
+ </a>
367
+ <br />
368
+ <a href="https://twitter.com/VERTIGRIS_ART/status/1644365184142647300/photo/1">(Source @VERTIGRIS_ART)</a>
369
+ </td>
370
+ <td align="center" width="50%">
371
+ <a href="https://user-images.githubusercontent.com/31543482/232265480-f8ba7a28-846f-46e7-8041-3dcb1afe3f67.png">
372
+ <img alt="Output" src="https://user-images.githubusercontent.com/31543482/232265480-f8ba7a28-846f-46e7-8041-3dcb1afe3f67.png" />
373
+ </a>
374
+ <br />
375
+ <code>--detector ctd</code>
376
+ <a href="https://user-images.githubusercontent.com/31543482/232265483-99ad20af-dca8-4b78-90f9-a6599eb0e70b.png">(Mask)</a>
377
+ </td>
378
+ </tr>
379
+ <tr>
380
+ <td align="center" width="50%">
381
+ <a href="https://user-images.githubusercontent.com/31543482/232264684-5a7bcf8e-707b-4925-86b0-4212382f1680.png">
382
+ <img alt="陰キャお嬢様の新学期🏫📔🌸 (#3) - ひづき夜宵🎀💜" src="https://user-images.githubusercontent.com/31543482/232264684-5a7bcf8e-707b-4925-86b0-4212382f1680.png" />
383
+ </a>
384
+ <br />
385
+ <a href="https://twitter.com/hiduki_yayoi/status/1645186427712573440/photo/2">(Source @hiduki_yayoi)</a>
386
+ </td>
387
+ <td align="center" width="50%">
388
+ <a href="https://user-images.githubusercontent.com/31543482/232264644-39db36c8-a8d9-4009-823d-bf85ca0609bf.png">
389
+ <img alt="Output" src="https://user-images.githubusercontent.com/31543482/232264644-39db36c8-a8d9-4009-823d-bf85ca0609bf.png" />
390
+ </a>
391
+ <br />
392
+ <code>--translator none</code>
393
+ <a href="https://user-images.githubusercontent.com/31543482/232264671-bc8dd9d0-8675-4c6d-8f86-0d5b7a342233.png">(Mask)</a>
394
+ </td>
395
+ </tr>
396
+ <tr>
397
+ <td align="center" width="50%">
398
+ <a href="https://user-images.githubusercontent.com/31543482/232265794-5ea8a0cb-42fe-4438-80b7-3bf7eaf0ff2c.png">
399
+ <img alt="幼なじみの高校デビューの癖がすごい (#1) - 神吉李花☪️🐧" src="https://user-images.githubusercontent.com/31543482/232265794-5ea8a0cb-42fe-4438-80b7-3bf7eaf0ff2c.png" />
400
+ </a>
401
+ <br />
402
+ <a href="https://twitter.com/rikak/status/1642727617886556160/photo/1">(Source @rikak)</a>
403
+ </td>
404
+ <td align="center" width="50%">
405
+ <a href="https://user-images.githubusercontent.com/31543482/232265795-4bc47589-fd97-4073-8cf4-82ae216a88bc.png">
406
+ <img alt="Output" src="https://user-images.githubusercontent.com/31543482/232265795-4bc47589-fd97-4073-8cf4-82ae216a88bc.png" />
407
+ </a>
408
+ <br />
409
+ <a href="https://user-images.githubusercontent.com/31543482/232265800-6bdc7973-41fe-4d7e-a554-98ea7ca7a137.png">(Mask)</a>
410
+ </td>
411
+ </tr>
412
+ </tbody>
413
+ </table>
devscripts/make_readme.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ # Adapted from https://github.com/yt-dlp/yt-dlp/tree/master/devscripts
4
+
5
+ import os
6
+ import sys
7
+
8
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
9
+
10
+ import functools
11
+ import re
12
+
13
+ from devscripts.utils import read_file, write_file
14
+ from manga_translator.args import HelpFormatter, parser
15
+
16
+ READMES = (
17
+ [
18
+ 'README.md',
19
+ '## Options',
20
+ '<!-- Auto generated end -->',
21
+ ],
22
+ [
23
+ 'README_CN.md',
24
+ '## 选项',
25
+ '<!-- Auto generated end -->',
26
+ ],
27
+ )
28
+
29
+ ALLOWED_OVERSHOOT = 2
30
+ DISABLE_PATCH = object()
31
+
32
+ HelpFormatter.INDENT_INCREMENT = 0
33
+ HelpFormatter.MAX_HELP_POSITION = 45
34
+ HelpFormatter.WIDTH = 100
35
+
36
+ def take_section(text, start=None, end=None, *, shift=0):
37
+ return text[
38
+ text.index(start) + shift if start else None:
39
+ text.index(end) + shift if end else None
40
+ ]
41
+
42
+
43
+ def apply_patch(text, patch):
44
+ return text if patch[0] is DISABLE_PATCH else re.sub(*patch, text)
45
+
46
+
47
+ options = take_section(parser.format_help(), '\noptions:', shift=len('\noptions:'))
48
+
49
+ max_width = max(map(len, options.split('\n')))
50
+ switch_col_width = len(re.search(r'(?m)^\s{5,}', options).group())
51
+ delim = f'\n{" " * switch_col_width}'
52
+
53
+ PATCHES = (
54
+ # ( # Headings
55
+ # r'(?m)^ (\w.+\n)( (?=\w))?',
56
+ # r'## \1'
57
+ # ),
58
+ ( # Fixup `--date` formatting
59
+ rf'(?m)( --date DATE.+({delim}[^\[]+)*)\[.+({delim}.+)*$',
60
+ (rf'\1[now|today|yesterday][-N[day|week|month|year]].{delim}'
61
+ f'E.g. "--date today-2weeks" downloads only{delim}'
62
+ 'videos uploaded on the same day two weeks ago'),
63
+ ),
64
+ ( # Do not split URLs
65
+ rf'({delim[:-1]})? (?P<label>\[\S+\] )?(?P<url>https?({delim})?:({delim})?/({delim})?/(({delim})?\S+)+)\s',
66
+ lambda mobj: ''.join((delim, mobj.group('label') or '', re.sub(r'\s+', '', mobj.group('url')), '\n'))
67
+ ),
68
+ ( # Do not split "words"
69
+ rf'(?m)({delim}\S+)+$',
70
+ lambda mobj: ''.join((delim, mobj.group(0).replace(delim, '')))
71
+ ),
72
+ # ( # Allow overshooting last line
73
+ # rf'(?m)^(?P<prev>.+)${delim}(?P<current>.+)$(?!{delim})',
74
+ # lambda mobj: (mobj.group().replace(delim, ' ')
75
+ # if len(mobj.group()) - len(delim) + 1 <= max_width + ALLOWED_OVERSHOOT
76
+ # else mobj.group())
77
+ # ),
78
+ # ( # Avoid newline when a space is available b/w switch and description
79
+ # DISABLE_PATCH, # This creates issues with prepare_manpage
80
+ # r'(?m)^(\s{4}-.{%d})(%s)' % (switch_col_width - 6, delim),
81
+ # r'\1 '
82
+ # ),
83
+ # ( # Replace brackets with a Markdown link
84
+ # r'SponsorBlock API \((http.+)\)',
85
+ # r'[SponsorBlock API](\1)'
86
+ # ),
87
+ )
88
+
89
+ for file, options_start, options_end in READMES:
90
+ readme = read_file(file)
91
+
92
+ write_file(file, ''.join((
93
+ take_section(readme, end=options_start, shift=len(options_start)),
94
+ '\n\n```text',
95
+ functools.reduce(apply_patch, PATCHES, options),
96
+ '```\n\n',
97
+ take_section(readme, options_end),
98
+ )))
devscripts/utils.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adopted from https://github.com/yt-dlp/yt-dlp/tree/master/devscripts
2
+
3
+ import argparse
4
+ import functools
5
+ import subprocess
6
+
7
+
8
+ def read_file(fname):
9
+ with open(fname, encoding='utf-8') as f:
10
+ return f.read()
11
+
12
+
13
+ def write_file(fname, content, mode='w'):
14
+ with open(fname, mode, encoding='utf-8') as f:
15
+ return f.write(content)
16
+
17
+
18
+ def get_filename_args(has_infile=False, default_outfile=None):
19
+ parser = argparse.ArgumentParser()
20
+ if has_infile:
21
+ parser.add_argument('infile', help='Input file')
22
+ kwargs = {'nargs': '?', 'default': default_outfile} if default_outfile else {}
23
+ parser.add_argument('outfile', **kwargs, help='Output file')
24
+
25
+ opts = parser.parse_args()
26
+ if has_infile:
27
+ return opts.infile, opts.outfile
28
+ return opts.outfile
29
+
30
+
31
+ def compose_functions(*functions):
32
+ return lambda x: functools.reduce(lambda y, f: f(y), functions, x)
33
+
34
+
35
+ def run_process(*args, **kwargs):
36
+ kwargs.setdefault('text', True)
37
+ kwargs.setdefault('check', True)
38
+ kwargs.setdefault('capture_output', True)
39
+ if kwargs['text']:
40
+ kwargs.setdefault('encoding', 'utf-8')
41
+ kwargs.setdefault('errors', 'replace')
42
+ return subprocess.run(args, **kwargs)
docker_prepare.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+
3
+ from manga_translator.utils import ModelWrapper
4
+ from manga_translator.detection import DETECTORS
5
+ from manga_translator.ocr import OCRS
6
+ from manga_translator.inpainting import INPAINTERS
7
+
8
+ async def download(dict):
9
+ for key, value in dict.items():
10
+ if issubclass(value, ModelWrapper):
11
+ print(' -- Downloading', key)
12
+ try:
13
+ inst = value()
14
+ await inst.download()
15
+ except Exception as e:
16
+ print('Failed to download', key, value)
17
+ print(e)
18
+
19
+ async def main():
20
+ await download(DETECTORS)
21
+ await download(OCRS)
22
+ await download({
23
+ k: v for k, v in INPAINTERS.items()
24
+ if k not in ['sd']
25
+ })
26
+
27
+ if __name__ == '__main__':
28
+ asyncio.run(main())
fonts/Arial-Unicode-Regular.ttf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:14f28249244f00c13348cb211c8a83c3e6e44dcf1874ebcb083efbfc0b9d5387
3
+ size 23892708
fonts/anime_ace.ttf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e3e311d48c305e79757cc0051aca591b735eb57002f78035969cbfc5ca4a5125
3
+ size 108036
fonts/anime_ace_3.ttf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e9b7c40b5389c511a950234fe0add8a11da9563b468e0e8a88219ccbf2257f83
3
+ size 58236
fonts/comic shanns 2.ttf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:64590b794cab741937889d379b205ae126ca4f3ed5cbe4f19839d2bfac246da6
3
+ size 73988
fonts/msgothic.ttc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ef9044f54896c6d045a425e62e38b3232d49facc5549a12837d077ff0bf74298
3
+ size 9176636
fonts/msyh.ttc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e4b3b9d058750fb80899c24f68e35beda606ca92694eff0e9f7f91eec7a846aa
3
+ size 19647736
manga_translator/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ import colorama
2
+ from dotenv import load_dotenv
3
+
4
+ colorama.init(autoreset=True)
5
+ load_dotenv()
6
+
7
+ from .manga_translator import *
manga_translator/__main__.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import asyncio
3
+ import logging
4
+ from argparse import Namespace
5
+
6
+ from .manga_translator import (
7
+ MangaTranslator,
8
+ MangaTranslatorWeb,
9
+ MangaTranslatorWS,
10
+ MangaTranslatorAPI,
11
+ set_main_logger,
12
+ )
13
+ from .args import parser
14
+ from .utils import (
15
+ BASE_PATH,
16
+ init_logging,
17
+ get_logger,
18
+ set_log_level,
19
+ natural_sort,
20
+ )
21
+
22
+ # TODO: Dynamic imports to reduce ram usage in web(-server) mode. Will require dealing with args.py imports.
23
+
24
+ async def dispatch(args: Namespace):
25
+ args_dict = vars(args)
26
+
27
+ logger.info(f'Running in {args.mode} mode')
28
+
29
+ if args.mode in ('demo', 'batch'):
30
+ if not args.input:
31
+ raise Exception('No input image was supplied. Use -i <image_path>')
32
+ translator = MangaTranslator(args_dict)
33
+ if args.mode == 'demo':
34
+ if len(args.input) != 1 or not os.path.isfile(args.input[0]):
35
+ raise FileNotFoundError(f'Invalid single image file path for demo mode: "{" ".join(args.input)}". Use `-m batch`.')
36
+ dest = os.path.join(BASE_PATH, 'result/final.png')
37
+ args.overwrite = True # Do overwrite result/final.png file
38
+ await translator.translate_path(args.input[0], dest, args_dict)
39
+ else: # batch
40
+ dest = args.dest
41
+ for path in natural_sort(args.input):
42
+ await translator.translate_path(path, dest, args_dict)
43
+
44
+ elif args.mode == 'web':
45
+ from .server.web_main import dispatch
46
+ await dispatch(args.host, args.port, translation_params=args_dict)
47
+
48
+ elif args.mode == 'web_client':
49
+ translator = MangaTranslatorWeb(args_dict)
50
+ await translator.listen(args_dict)
51
+
52
+ elif args.mode == 'ws':
53
+ translator = MangaTranslatorWS(args_dict)
54
+ await translator.listen(args_dict)
55
+
56
+ elif args.mode == 'api':
57
+ translator = MangaTranslatorAPI(args_dict)
58
+ await translator.listen(args_dict)
59
+
60
+ if __name__ == '__main__':
61
+ args = None
62
+ init_logging()
63
+ try:
64
+ args = parser.parse_args()
65
+ set_log_level(level=logging.DEBUG if args.verbose else logging.INFO)
66
+ logger = get_logger(args.mode)
67
+ set_main_logger(logger)
68
+ if args.mode != 'web':
69
+ logger.debug(args)
70
+
71
+ loop = asyncio.new_event_loop()
72
+ asyncio.set_event_loop(loop)
73
+ loop.run_until_complete(dispatch(args))
74
+ except KeyboardInterrupt:
75
+ if not args or args.mode != 'web':
76
+ print()
77
+ except Exception as e:
78
+ logger.error(f'{e.__class__.__name__}: {e}',
79
+ exc_info=e if args and args.verbose else None)
manga_translator/args.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ from urllib.parse import unquote
4
+
5
+ from .detection import DETECTORS
6
+ from .ocr import OCRS
7
+ from .inpainting import INPAINTERS
8
+ from .translators import VALID_LANGUAGES, TRANSLATORS, TranslatorChain
9
+ from .upscaling import UPSCALERS
10
+ from .colorization import COLORIZERS
11
+ from .save import OUTPUT_FORMATS
12
+
13
+ def url_decode(s):
14
+ s = unquote(s)
15
+ if s.startswith('file:///'):
16
+ s = s[len('file://'):]
17
+ return s
18
+
19
+ # Additional argparse types
20
+ def path(string):
21
+ if not string:
22
+ return ''
23
+ s = url_decode(os.path.expanduser(string))
24
+ if not os.path.exists(s):
25
+ raise argparse.ArgumentTypeError(f'No such file or directory: "{string}"')
26
+ return s
27
+
28
+ def file_path(string):
29
+ if not string:
30
+ return ''
31
+ s = url_decode(os.path.expanduser(string))
32
+ if not os.path.exists(s):
33
+ raise argparse.ArgumentTypeError(f'No such file: "{string}"')
34
+ return s
35
+
36
+ def dir_path(string):
37
+ if not string:
38
+ return ''
39
+ s = url_decode(os.path.expanduser(string))
40
+ if not os.path.exists(s):
41
+ raise argparse.ArgumentTypeError(f'No such directory: "{string}"')
42
+ return s
43
+
44
+ # def choice_chain(choices):
45
+ # """Argument type for string chains from choices separated by ':'. Example: 'choice1:choice2:choice3'"""
46
+ # def _func(string):
47
+ # if choices is not None:
48
+ # for s in string.split(':') or ['']:
49
+ # if s not in choices:
50
+ # raise argparse.ArgumentTypeError(f'Invalid choice: %s (choose from %s)' % (s, ', '.join(map(repr, choices))))
51
+ # return string
52
+ # return _func
53
+
54
+ def translator_chain(string):
55
+ try:
56
+ return TranslatorChain(string)
57
+ except ValueError as e:
58
+ raise argparse.ArgumentTypeError(e)
59
+ except Exception:
60
+ raise argparse.ArgumentTypeError(f'Invalid translator_chain value: "{string}". Example usage: --translator "google:sugoi" -l "JPN:ENG"')
61
+
62
+
63
+ class HelpFormatter(argparse.HelpFormatter):
64
+ INDENT_INCREMENT = 2
65
+ MAX_HELP_POSITION = 24
66
+ WIDTH = None
67
+
68
+ def __init__(self, prog: str, indent_increment: int = 2, max_help_position: int = 24, width: int = None):
69
+ super().__init__(prog, self.INDENT_INCREMENT, self.MAX_HELP_POSITION, self.WIDTH)
70
+
71
+ def _format_action_invocation(self, action: argparse.Action) -> str:
72
+ if action.option_strings:
73
+
74
+ # if the Optional doesn't take a value, format is:
75
+ # -s, --long
76
+ if action.nargs == 0:
77
+ return ', '.join(action.option_strings)
78
+
79
+ # if the Optional takes a value, format is:
80
+ # -s, --long ARGS
81
+ else:
82
+ default = self._get_default_metavar_for_optional(action)
83
+ args_string = self._format_args(action, default)
84
+ return ', '.join(action.option_strings) + ' ' + args_string
85
+ else:
86
+ return super()._format_action_invocation(action)
87
+
88
+
89
+ parser = argparse.ArgumentParser(prog='manga_translator', description='Seamlessly translate mangas into a chosen language', formatter_class=HelpFormatter)
90
+ parser.add_argument('-m', '--mode', default='batch', type=str, choices=['demo', 'batch', 'web', 'web_client', 'ws', 'api'], help='Run demo in single image demo mode (demo), batch translation mode (batch), web service mode (web)')
91
+ parser.add_argument('-i', '--input', default=None, type=path, nargs='+', help='Path to an image file if using demo mode, or path to an image folder if using batch mode')
92
+ parser.add_argument('-o', '--dest', default='', type=str, help='Path to the destination folder for translated images in batch mode')
93
+ parser.add_argument('-l', '--target-lang', default='CHS', type=str, choices=VALID_LANGUAGES, help='Destination language')
94
+ parser.add_argument('-v', '--verbose', action='store_true', help='Print debug info and save intermediate images in result folder')
95
+ parser.add_argument('-f', '--format', default=None, choices=OUTPUT_FORMATS, help='Output format of the translation.')
96
+ parser.add_argument('--attempts', default=0, type=int, help='Retry attempts on encountered error. -1 means infinite times.')
97
+ parser.add_argument('--ignore-errors', action='store_true', help='Skip image on encountered error.')
98
+ parser.add_argument('--overwrite', action='store_true', help='Overwrite already translated images in batch mode.')
99
+ parser.add_argument('--skip-no-text', action='store_true', help='Skip image without text (Will not be saved).')
100
+ parser.add_argument('--model-dir', default=None, type=dir_path, help='Model directory (by default ./models in project root)')
101
+ parser.add_argument('--skip-lang', default=None, type=str, help='Skip translation if source image is one of the provide languages, use comma to separate multiple languages. Example: JPN,ENG')
102
+
103
+ g = parser.add_mutually_exclusive_group()
104
+ g.add_argument('--use-gpu', action='store_true', help='Turn on/off gpu (auto switch between mps and cuda)')
105
+ g.add_argument('--use-gpu-limited', action='store_true', help='Turn on/off gpu (excluding offline translator)')
106
+
107
+ parser.add_argument('--detector', default='default', type=str, choices=DETECTORS, help='Text detector used for creating a text mask from an image, DO NOT use craft for manga, it\'s not designed for it')
108
+ parser.add_argument('--ocr', default='48px', type=str, choices=OCRS, help='Optical character recognition (OCR) model to use')
109
+ parser.add_argument('--use-mocr-merge', action='store_true', help='Use bbox merge when Manga OCR inference.')
110
+ parser.add_argument('--inpainter', default='lama_large', type=str, choices=INPAINTERS, help='Inpainting model to use')
111
+ parser.add_argument('--upscaler', default='esrgan', type=str, choices=UPSCALERS, help='Upscaler to use. --upscale-ratio has to be set for it to take effect')
112
+ parser.add_argument('--upscale-ratio', default=None, type=float, help='Image upscale ratio applied before detection. Can improve text detection.')
113
+ parser.add_argument('--colorizer', default=None, type=str, choices=COLORIZERS, help='Colorization model to use.')
114
+
115
+ g = parser.add_mutually_exclusive_group()
116
+ g.add_argument('--translator', default='google', type=str, choices=TRANSLATORS, help='Language translator to use')
117
+ g.add_argument('--translator-chain', default=None, type=translator_chain, help='Output of one translator goes in another. Example: --translator-chain "google:JPN;sugoi:ENG".')
118
+ g.add_argument('--selective-translation', default=None, type=translator_chain, help='Select a translator based on detected language in image. Note the first translation service acts as default if the language isn\'t defined. Example: --translator-chain "google:JPN;sugoi:ENG".')
119
+
120
+ parser.add_argument('--revert-upscaling', action='store_true', help='Downscales the previously upscaled image after translation back to original size (Use with --upscale-ratio).')
121
+ parser.add_argument('--detection-size', default=1536, type=int, help='Size of image used for detection')
122
+ parser.add_argument('--det-rotate', action='store_true', help='Rotate the image for detection. Might improve detection.')
123
+ parser.add_argument('--det-auto-rotate', action='store_true', help='Rotate the image for detection to prefer vertical textlines. Might improve detection.')
124
+ parser.add_argument('--det-invert', action='store_true', help='Invert the image colors for detection. Might improve detection.')
125
+ parser.add_argument('--det-gamma-correct', action='store_true', help='Applies gamma correction for detection. Might improve detection.')
126
+ parser.add_argument('--unclip-ratio', default=2.3, type=float, help='How much to extend text skeleton to form bounding box')
127
+ parser.add_argument('--box-threshold', default=0.7, type=float, help='Threshold for bbox generation')
128
+ parser.add_argument('--text-threshold', default=0.5, type=float, help='Threshold for text detection')
129
+ parser.add_argument('--min-text-length', default=0, type=int, help='Minimum text length of a text region')
130
+ parser.add_argument('--no-text-lang-skip', action='store_true', help='Dont skip text that is seemingly already in the target language.')
131
+ parser.add_argument('--inpainting-size', default=2048, type=int, help='Size of image used for inpainting (too large will result in OOM)')
132
+ parser.add_argument('--inpainting-precision', default='fp32', type=str, help='Inpainting precision for lama, use bf16 while you can.', choices=['fp32', 'fp16', 'bf16'])
133
+ parser.add_argument('--colorization-size', default=576, type=int, help='Size of image used for colorization. Set to -1 to use full image size')
134
+ parser.add_argument('--denoise-sigma', default=30, type=int, help='Used by colorizer and affects color strength, range from 0 to 255 (default 30). -1 turns it off.')
135
+ parser.add_argument('--mask-dilation-offset', default=0, type=int, help='By how much to extend the text mask to remove left-over text pixels of the original image.')
136
+
137
+ parser.add_argument('--disable-font-border', action='store_true', help='Disable font border')
138
+ parser.add_argument('--font-size', default=None, type=int, help='Use fixed font size for rendering')
139
+ parser.add_argument('--font-size-offset', default=0, type=int, help='Offset font size by a given amount, positive number increase font size and vice versa')
140
+ parser.add_argument('--font-size-minimum', default=-1, type=int, help='Minimum output font size. Default is image_sides_sum/200')
141
+ parser.add_argument('--font-color', default=None, type=str, help='Overwrite the text fg/bg color detected by the OCR model. Use hex string without the "#" such as FFFFFF for a white foreground or FFFFFF:000000 to also have a black background around the text.')
142
+ parser.add_argument('--line-spacing', default=None, type=float, help='Line spacing is font_size * this value. Default is 0.01 for horizontal text and 0.2 for vertical.')
143
+
144
+ g = parser.add_mutually_exclusive_group()
145
+ g.add_argument('--force-horizontal', action='store_true', help='Force text to be rendered horizontally')
146
+ g.add_argument('--force-vertical', action='store_true', help='Force text to be rendered vertically')
147
+
148
+ g = parser.add_mutually_exclusive_group()
149
+ g.add_argument('--align-left', action='store_true', help='Align rendered text left')
150
+ g.add_argument('--align-center', action='store_true', help='Align rendered text centered')
151
+ g.add_argument('--align-right', action='store_true', help='Align rendered text right')
152
+
153
+ g = parser.add_mutually_exclusive_group()
154
+ g.add_argument('--uppercase', action='store_true', help='Change text to uppercase')
155
+ g.add_argument('--lowercase', action='store_true', help='Change text to lowercase')
156
+
157
+ parser.add_argument('--no-hyphenation', action='store_true', help='If renderer should be splitting up words using a hyphen character (-)')
158
+ parser.add_argument('--manga2eng', action='store_true', help='Render english text translated from manga with some additional typesetting. Ignores some other argument options')
159
+ parser.add_argument('--gpt-config', type=file_path, help='Path to GPT config file, more info in README')
160
+ parser.add_argument('--use-mtpe', action='store_true', help='Turn on/off machine translation post editing (MTPE) on the command line (works only on linux right now)')
161
+
162
+ g = parser.add_mutually_exclusive_group()
163
+ g.add_argument('--save-text', action='store_true', help='Save extracted text and translations into a text file.')
164
+ g.add_argument('--save-text-file', default='', type=str, help='Like --save-text but with a specified file path.')
165
+
166
+ parser.add_argument('--filter-text', default=None, type=str, help='Filter regions by their text with a regex. Example usage: --text-filter ".*badtext.*"')
167
+ parser.add_argument('--prep-manual', action='store_true', help='Prepare for manual typesetting by outputting blank, inpainted images, plus copies of the original for reference')
168
+ parser.add_argument('--font-path', default='', type=file_path, help='Path to font file')
169
+ parser.add_argument('--gimp-font', default='Sans-serif', type=str, help='Font family to use for gimp rendering.')
170
+ parser.add_argument('--host', default='127.0.0.1', type=str, help='Used by web module to decide which host to attach to')
171
+ parser.add_argument('--port', default=5003, type=int, help='Used by web module to decide which port to attach to')
172
+ parser.add_argument('--nonce', default=os.getenv('MT_WEB_NONCE', ''), type=str, help='Used by web module as secret for securing internal web server communication')
173
+ # parser.add_argument('--log-web', action='store_true', help='Used by web module to decide if web logs should be surfaced')
174
+ parser.add_argument('--ws-url', default='ws://localhost:5000', type=str, help='Server URL for WebSocket mode')
175
+ parser.add_argument('--save-quality', default=100, type=int, help='Quality of saved JPEG image, range from 0 to 100 with 100 being best')
176
+ parser.add_argument('--ignore-bubble', default=0, type=int, help='The threshold for ignoring text in non bubble areas, with valid values ranging from 1 to 50, does not ignore others. Recommendation 5 to 10. If it is too low, normal bubble areas may be ignored, and if it is too large, non bubble areas may be considered normal bubbles')
177
+
178
+ parser.add_argument('--kernel-size', default=3, type=int, help='Set the convolution kernel size of the text erasure area to completely clean up text residues')
179
+
180
+
181
+ # Generares dict with a default value for each argument
182
+ DEFAULT_ARGS = vars(parser.parse_args([]))
manga_translator/colorization/__init__.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+
3
+ from .common import CommonColorizer, OfflineColorizer
4
+ from .manga_colorization_v2 import MangaColorizationV2
5
+
6
+ COLORIZERS = {
7
+ 'mc2': MangaColorizationV2,
8
+ }
9
+ colorizer_cache = {}
10
+
11
+ def get_colorizer(key: str, *args, **kwargs) -> CommonColorizer:
12
+ if key not in COLORIZERS:
13
+ raise ValueError(f'Could not find colorizer for: "{key}". Choose from the following: %s' % ','.join(COLORIZERS))
14
+ if not colorizer_cache.get(key):
15
+ upscaler = COLORIZERS[key]
16
+ colorizer_cache[key] = upscaler(*args, **kwargs)
17
+ return colorizer_cache[key]
18
+
19
+ async def prepare(key: str):
20
+ upscaler = get_colorizer(key)
21
+ if isinstance(upscaler, OfflineColorizer):
22
+ await upscaler.download()
23
+
24
+ async def dispatch(key: str, device: str = 'cpu', **kwargs) -> Image.Image:
25
+ colorizer = get_colorizer(key)
26
+ if isinstance(colorizer, OfflineColorizer):
27
+ await colorizer.load(device)
28
+ return await colorizer.colorize(**kwargs)
manga_translator/colorization/common.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from abc import abstractmethod
3
+
4
+ from ..utils import InfererModule, ModelWrapper
5
+
6
+ class CommonColorizer(InfererModule):
7
+ _VALID_UPSCALE_RATIOS = None
8
+
9
+ async def colorize(self, image: Image.Image, colorization_size: int, **kwargs) -> Image.Image:
10
+ return await self._colorize(image, colorization_size, **kwargs)
11
+
12
+ @abstractmethod
13
+ async def _colorize(self, image: Image.Image, colorization_size: int, **kwargs) -> Image.Image:
14
+ pass
15
+
16
+ class OfflineColorizer(CommonColorizer, ModelWrapper):
17
+ _MODEL_SUB_DIR = 'colorization'
18
+
19
+ async def _colorize(self, *args, **kwargs):
20
+ return await self.infer(*args, **kwargs)
21
+
22
+ @abstractmethod
23
+ async def _infer(self, image: Image.Image, colorization_size: int, **kwargs) -> Image.Image:
24
+ pass
manga_translator/colorization/manga_colorization_v2.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ from torchvision.transforms import ToTensor
6
+
7
+ from .common import OfflineColorizer
8
+ from .manga_colorization_v2_utils.networks.models import Colorizer
9
+ from .manga_colorization_v2_utils.denoising.denoiser import FFDNetDenoiser
10
+ from .manga_colorization_v2_utils.utils.utils import resize_pad
11
+
12
+
13
+ # https://github.com/qweasdd/manga-colorization-v2
14
+ class MangaColorizationV2(OfflineColorizer):
15
+ _MODEL_SUB_DIR = os.path.join(OfflineColorizer._MODEL_SUB_DIR, 'manga-colorization-v2')
16
+ _MODEL_MAPPING = {
17
+ # Models were in google drive so had to upload to github
18
+ 'generator': {
19
+ 'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/manga-colorization-v2-generator.zip',
20
+ 'file': 'generator.zip',
21
+ 'hash': '087e6a0bc02770e732a52f33878b71a272a6123c9ac649e9b5bfb75e39e5c1d5',
22
+ },
23
+ 'denoiser': {
24
+ 'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/manga-colorization-v2-net_rgb.pth',
25
+ 'file': 'net_rgb.pth',
26
+ 'hash': '0fe98bfd2ac870b15f360661b1c4789eecefc6dc2e4462842a0dd15e149a0433',
27
+ }
28
+ }
29
+
30
+ async def _load(self, device: str):
31
+ self.device = device
32
+ self.colorizer = Colorizer().to(device)
33
+ self.colorizer.generator.load_state_dict(
34
+ torch.load(self._get_file_path('generator.zip'), map_location=self.device))
35
+ self.colorizer = self.colorizer.eval()
36
+ self.denoiser = FFDNetDenoiser(device, _weights_dir=self.model_dir)
37
+
38
+ async def _unload(self):
39
+ del self.colorizer
40
+ del self.denoiser
41
+
42
+ async def _infer(self, image: Image.Image, colorization_size: int, denoise_sigma=25, **kwargs) -> Image.Image:
43
+ # Size has to be multiple of 32
44
+ img = np.array(image.convert('RGBA'))
45
+ max_size = min(*img.shape[:2])
46
+ max_size -= max_size % 32
47
+ if colorization_size > 0:
48
+ size = min(max_size, colorization_size - (colorization_size % 32))
49
+ else:
50
+ # size<=576 gives best results
51
+ size = min(max_size, 576)
52
+
53
+ if 0 <= denoise_sigma and denoise_sigma <= 255:
54
+ img = self.denoiser.get_denoised_image(img, sigma=denoise_sigma)
55
+
56
+ img, current_pad = resize_pad(img, size)
57
+
58
+ transform = ToTensor()
59
+ current_image = transform(img).unsqueeze(0).to(self.device)
60
+ current_hint = torch.zeros(1, 4, current_image.shape[2], current_image.shape[3]).float().to(self.device)
61
+
62
+ with torch.no_grad():
63
+ fake_color, _ = self.colorizer(torch.cat([current_image, current_hint], 1))
64
+ fake_color = fake_color.detach()
65
+
66
+ result = fake_color[0].detach().cpu().permute(1, 2, 0) * 0.5 + 0.5
67
+
68
+ if current_pad[0] != 0:
69
+ result = result[:-current_pad[0]]
70
+ if current_pad[1] != 0:
71
+ result = result[:, :-current_pad[1]]
72
+
73
+ colored_image = result.numpy() * 255
74
+ return Image.fromarray(colored_image.astype(np.uint8))
manga_translator/colorization/manga_colorization_v2_utils/denoising/denoiser.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Denoise an image with the FFDNet denoising method
3
+
4
+ Copyright (C) 2018, Matias Tassano <matias.tassano@parisdescartes.fr>
5
+
6
+ This program is free software: you can use, modify and/or
7
+ redistribute it under the terms of the GNU General Public
8
+ License as published by the Free Software Foundation, either
9
+ version 3 of the License, or (at your option) any later
10
+ version. You should have received a copy of this license along
11
+ this program. If not, see <http://www.gnu.org/licenses/>.
12
+ """
13
+ import os
14
+ import argparse
15
+ import time
16
+
17
+
18
+ import numpy as np
19
+ import cv2
20
+ import torch
21
+ import torch.nn as nn
22
+ from torch.autograd import Variable
23
+ from .models import FFDNet
24
+ from .utils import normalize, variable_to_cv2_image, remove_dataparallel_wrapper, is_rgb
25
+
26
+ class FFDNetDenoiser:
27
+ def __init__(self, _device, _sigma = 25, _weights_dir = 'denoising/models/', _in_ch = 3):
28
+ self.sigma = _sigma / 255
29
+ self.weights_dir = _weights_dir
30
+ self.channels = _in_ch
31
+ self.device = _device
32
+ self.model = FFDNet(num_input_channels = _in_ch)
33
+ self.load_weights()
34
+ self.model.eval()
35
+
36
+
37
+ def load_weights(self):
38
+ weights_name = 'net_rgb.pth' if self.channels == 3 else 'net_gray.pth'
39
+ weights_path = os.path.join(self.weights_dir, weights_name)
40
+ if self.device == 'cuda':
41
+ # data paralles only for cuda , no need for mps devices
42
+ state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
43
+ self.model = nn.DataParallel(self.model,device_ids = [0]).to(self.device)
44
+ else:
45
+ # MPS devices don't support DataParallel
46
+ state_dict = torch.load(weights_path, map_location=self.device)
47
+ # CPU mode: remove the DataParallel wrapper
48
+ state_dict = remove_dataparallel_wrapper(state_dict)
49
+ self.model.load_state_dict(state_dict)
50
+
51
+ def get_denoised_image(self, imorig, sigma = None):
52
+
53
+ if sigma is not None:
54
+ cur_sigma = sigma / 255
55
+ else:
56
+ cur_sigma = self.sigma
57
+
58
+ if len(imorig.shape) < 3 or imorig.shape[2] == 1:
59
+ imorig = np.repeat(np.expand_dims(imorig, 2), 3, 2)
60
+
61
+ imorig = imorig[..., :3]
62
+
63
+ if (max(imorig.shape[0], imorig.shape[1]) > 1200):
64
+ ratio = max(imorig.shape[0], imorig.shape[1]) / 1200
65
+ imorig = cv2.resize(imorig, (int(imorig.shape[1] / ratio), int(imorig.shape[0] / ratio)), interpolation = cv2.INTER_AREA)
66
+
67
+ imorig = imorig.transpose(2, 0, 1)
68
+
69
+ if (imorig.max() > 1.2):
70
+ imorig = normalize(imorig)
71
+ imorig = np.expand_dims(imorig, 0)
72
+
73
+ # Handle odd sizes
74
+ expanded_h = False
75
+ expanded_w = False
76
+ sh_im = imorig.shape
77
+ if sh_im[2]%2 == 1:
78
+ expanded_h = True
79
+ imorig = np.concatenate((imorig, imorig[:, :, -1, :][:, :, np.newaxis, :]), axis=2)
80
+
81
+ if sh_im[3]%2 == 1:
82
+ expanded_w = True
83
+ imorig = np.concatenate((imorig, imorig[:, :, :, -1][:, :, :, np.newaxis]), axis=3)
84
+
85
+
86
+ imorig = torch.Tensor(imorig)
87
+
88
+
89
+ # Sets data type according to CPU or GPU modes
90
+ if self.device == 'cuda':
91
+ dtype = torch.cuda.FloatTensor
92
+ else:
93
+ # for mps devices is still floatTensor
94
+ dtype = torch.FloatTensor
95
+
96
+ imnoisy = imorig#.clone()
97
+
98
+
99
+ with torch.no_grad():
100
+ imorig, imnoisy = imorig.type(dtype), imnoisy.type(dtype)
101
+ nsigma = torch.FloatTensor([cur_sigma]).type(dtype)
102
+
103
+
104
+ # Estimate noise and subtract it from the input image
105
+ im_noise_estim = self.model(imnoisy, nsigma)
106
+ outim = torch.clamp(imnoisy - im_noise_estim, 0., 1.)
107
+
108
+ if expanded_h:
109
+ # imorig = imorig[:, :, :-1, :]
110
+ outim = outim[:, :, :-1, :]
111
+ # imnoisy = imnoisy[:, :, :-1, :]
112
+
113
+ if expanded_w:
114
+ # imorig = imorig[:, :, :, :-1]
115
+ outim = outim[:, :, :, :-1]
116
+ # imnoisy = imnoisy[:, :, :, :-1]
117
+
118
+ return variable_to_cv2_image(outim)
manga_translator/colorization/manga_colorization_v2_utils/denoising/functions.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Functions implementing custom NN layers
3
+
4
+ Copyright (C) 2018, Matias Tassano <matias.tassano@parisdescartes.fr>
5
+
6
+ This program is free software: you can use, modify and/or
7
+ redistribute it under the terms of the GNU General Public
8
+ License as published by the Free Software Foundation, either
9
+ version 3 of the License, or (at your option) any later
10
+ version. You should have received a copy of this license along
11
+ this program. If not, see <http://www.gnu.org/licenses/>.
12
+ """
13
+ import torch
14
+ from torch.autograd import Function, Variable
15
+
16
+ def concatenate_input_noise_map(input, noise_sigma):
17
+ r"""Implements the first layer of FFDNet. This function returns a
18
+ torch.autograd.Variable composed of the concatenation of the downsampled
19
+ input image and the noise map. Each image of the batch of size CxHxW gets
20
+ converted to an array of size 4*CxH/2xW/2. Each of the pixels of the
21
+ non-overlapped 2x2 patches of the input image are placed in the new array
22
+ along the first dimension.
23
+
24
+ Args:
25
+ input: batch containing CxHxW images
26
+ noise_sigma: the value of the pixels of the CxH/2xW/2 noise map
27
+ """
28
+ # noise_sigma is a list of length batch_size
29
+ N, C, H, W = input.size()
30
+ dtype = input.type()
31
+ sca = 2
32
+ sca2 = sca*sca
33
+ Cout = sca2*C
34
+ Hout = H//sca
35
+ Wout = W//sca
36
+ idxL = [[0, 0], [0, 1], [1, 0], [1, 1]]
37
+
38
+ # Fill the downsampled image with zeros
39
+ if 'cuda' in dtype:
40
+ downsampledfeatures = torch.cuda.FloatTensor(N, Cout, Hout, Wout).fill_(0)
41
+ else:
42
+ # cpu and mps are the same
43
+ downsampledfeatures = torch.FloatTensor(N, Cout, Hout, Wout).fill_(0)
44
+
45
+ # Build the CxH/2xW/2 noise map
46
+ noise_map = noise_sigma.view(N, 1, 1, 1).repeat(1, C, Hout, Wout)
47
+
48
+ # Populate output
49
+ for idx in range(sca2):
50
+ downsampledfeatures[:, idx:Cout:sca2, :, :] = \
51
+ input[:, :, idxL[idx][0]::sca, idxL[idx][1]::sca]
52
+
53
+ # concatenate de-interleaved mosaic with noise map
54
+ return torch.cat((noise_map, downsampledfeatures), 1)
55
+
56
+ class UpSampleFeaturesFunction(Function):
57
+ r"""Extends PyTorch's modules by implementing a torch.autograd.Function.
58
+ This class implements the forward and backward methods of the last layer
59
+ of FFDNet. It basically performs the inverse of
60
+ concatenate_input_noise_map(): it converts each of the images of a
61
+ batch of size CxH/2xW/2 to images of size C/4xHxW
62
+ """
63
+ @staticmethod
64
+ def forward(ctx, input):
65
+ N, Cin, Hin, Win = input.size()
66
+ dtype = input.type()
67
+ sca = 2
68
+ sca2 = sca*sca
69
+ Cout = Cin//sca2
70
+ Hout = Hin*sca
71
+ Wout = Win*sca
72
+ idxL = [[0, 0], [0, 1], [1, 0], [1, 1]]
73
+
74
+ assert (Cin%sca2 == 0), 'Invalid input dimensions: number of channels should be divisible by 4'
75
+
76
+ result = torch.zeros((N, Cout, Hout, Wout)).type(dtype)
77
+ for idx in range(sca2):
78
+ result[:, :, idxL[idx][0]::sca, idxL[idx][1]::sca] = input[:, idx:Cin:sca2, :, :]
79
+
80
+ return result
81
+
82
+ @staticmethod
83
+ def backward(ctx, grad_output):
84
+ N, Cg_out, Hg_out, Wg_out = grad_output.size()
85
+ dtype = grad_output.data.type()
86
+ sca = 2
87
+ sca2 = sca*sca
88
+ Cg_in = sca2*Cg_out
89
+ Hg_in = Hg_out//sca
90
+ Wg_in = Wg_out//sca
91
+ idxL = [[0, 0], [0, 1], [1, 0], [1, 1]]
92
+
93
+ # Build output
94
+ grad_input = torch.zeros((N, Cg_in, Hg_in, Wg_in)).type(dtype)
95
+ # Populate output
96
+ for idx in range(sca2):
97
+ grad_input[:, idx:Cg_in:sca2, :, :] = grad_output.data[:, :, idxL[idx][0]::sca, idxL[idx][1]::sca]
98
+
99
+ return Variable(grad_input)
100
+
101
+ # Alias functions
102
+ upsamplefeatures = UpSampleFeaturesFunction.apply
manga_translator/colorization/manga_colorization_v2_utils/denoising/models.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Definition of the FFDNet model and its custom layers
3
+
4
+ Copyright (C) 2018, Matias Tassano <matias.tassano@parisdescartes.fr>
5
+
6
+ This program is free software: you can use, modify and/or
7
+ redistribute it under the terms of the GNU General Public
8
+ License as published by the Free Software Foundation, either
9
+ version 3 of the License, or (at your option) any later
10
+ version. You should have received a copy of this license along
11
+ this program. If not, see <http://www.gnu.org/licenses/>.
12
+ """
13
+ import torch.nn as nn
14
+ from torch.autograd import Variable
15
+ from . import functions
16
+
17
+ class UpSampleFeatures(nn.Module):
18
+ r"""Implements the last layer of FFDNet
19
+ """
20
+ def __init__(self):
21
+ super(UpSampleFeatures, self).__init__()
22
+ def forward(self, x):
23
+ return functions.upsamplefeatures(x)
24
+
25
+ class IntermediateDnCNN(nn.Module):
26
+ r"""Implements the middel part of the FFDNet architecture, which
27
+ is basically a DnCNN net
28
+ """
29
+ def __init__(self, input_features, middle_features, num_conv_layers):
30
+ super(IntermediateDnCNN, self).__init__()
31
+ self.kernel_size = 3
32
+ self.padding = 1
33
+ self.input_features = input_features
34
+ self.num_conv_layers = num_conv_layers
35
+ self.middle_features = middle_features
36
+ if self.input_features == 5:
37
+ self.output_features = 4 #Grayscale image
38
+ elif self.input_features == 15:
39
+ self.output_features = 12 #RGB image
40
+ else:
41
+ raise Exception('Invalid number of input features')
42
+
43
+ layers = []
44
+ layers.append(nn.Conv2d(in_channels=self.input_features,\
45
+ out_channels=self.middle_features,\
46
+ kernel_size=self.kernel_size,\
47
+ padding=self.padding,\
48
+ bias=False))
49
+ layers.append(nn.ReLU(inplace=True))
50
+ for _ in range(self.num_conv_layers-2):
51
+ layers.append(nn.Conv2d(in_channels=self.middle_features,\
52
+ out_channels=self.middle_features,\
53
+ kernel_size=self.kernel_size,\
54
+ padding=self.padding,\
55
+ bias=False))
56
+ layers.append(nn.BatchNorm2d(self.middle_features))
57
+ layers.append(nn.ReLU(inplace=True))
58
+ layers.append(nn.Conv2d(in_channels=self.middle_features,\
59
+ out_channels=self.output_features,\
60
+ kernel_size=self.kernel_size,\
61
+ padding=self.padding,\
62
+ bias=False))
63
+ self.itermediate_dncnn = nn.Sequential(*layers)
64
+ def forward(self, x):
65
+ out = self.itermediate_dncnn(x)
66
+ return out
67
+
68
+ class FFDNet(nn.Module):
69
+ r"""Implements the FFDNet architecture
70
+ """
71
+ def __init__(self, num_input_channels):
72
+ super(FFDNet, self).__init__()
73
+ self.num_input_channels = num_input_channels
74
+ if self.num_input_channels == 1:
75
+ # Grayscale image
76
+ self.num_feature_maps = 64
77
+ self.num_conv_layers = 15
78
+ self.downsampled_channels = 5
79
+ self.output_features = 4
80
+ elif self.num_input_channels == 3:
81
+ # RGB image
82
+ self.num_feature_maps = 96
83
+ self.num_conv_layers = 12
84
+ self.downsampled_channels = 15
85
+ self.output_features = 12
86
+ else:
87
+ raise Exception('Invalid number of input features')
88
+
89
+ self.intermediate_dncnn = IntermediateDnCNN(\
90
+ input_features=self.downsampled_channels,\
91
+ middle_features=self.num_feature_maps,\
92
+ num_conv_layers=self.num_conv_layers)
93
+ self.upsamplefeatures = UpSampleFeatures()
94
+
95
+ def forward(self, x, noise_sigma):
96
+ concat_noise_x = functions.concatenate_input_noise_map(x.data, noise_sigma.data)
97
+ concat_noise_x = Variable(concat_noise_x)
98
+ h_dncnn = self.intermediate_dncnn(concat_noise_x)
99
+ pred_noise = self.upsamplefeatures(h_dncnn)
100
+ return pred_noise
manga_translator/colorization/manga_colorization_v2_utils/denoising/utils.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Different utilities such as orthogonalization of weights, initialization of
3
+ loggers, etc
4
+
5
+ Copyright (C) 2018, Matias Tassano <matias.tassano@parisdescartes.fr>
6
+
7
+ This program is free software: you can use, modify and/or
8
+ redistribute it under the terms of the GNU General Public
9
+ License as published by the Free Software Foundation, either
10
+ version 3 of the License, or (at your option) any later
11
+ version. You should have received a copy of this license along
12
+ this program. If not, see <http://www.gnu.org/licenses/>.
13
+ """
14
+ import numpy as np
15
+ import cv2
16
+
17
+
18
+ def variable_to_cv2_image(varim):
19
+ r"""Converts a torch.autograd.Variable to an OpenCV image
20
+
21
+ Args:
22
+ varim: a torch.autograd.Variable
23
+ """
24
+ nchannels = varim.size()[1]
25
+ if nchannels == 1:
26
+ res = (varim.data.cpu().numpy()[0, 0, :]*255.).clip(0, 255).astype(np.uint8)
27
+ elif nchannels == 3:
28
+ res = varim.data.cpu().numpy()[0]
29
+ res = cv2.cvtColor(res.transpose(1, 2, 0), cv2.COLOR_RGB2BGR)
30
+ res = (res*255.).clip(0, 255).astype(np.uint8)
31
+ else:
32
+ raise Exception('Number of color channels not supported')
33
+ return res
34
+
35
+
36
+ def normalize(data):
37
+ return np.float32(data/255.)
38
+
39
+ def remove_dataparallel_wrapper(state_dict):
40
+ r"""Converts a DataParallel model to a normal one by removing the "module."
41
+ wrapper in the module dictionary
42
+
43
+ Args:
44
+ state_dict: a torch.nn.DataParallel state dictionary
45
+ """
46
+ from collections import OrderedDict
47
+
48
+ new_state_dict = OrderedDict()
49
+ for k, vl in state_dict.items():
50
+ name = k[7:] # remove 'module.' of DataParallel
51
+ new_state_dict[name] = vl
52
+
53
+ return new_state_dict
54
+
55
+ def is_rgb(im_path):
56
+ r""" Returns True if the image in im_path is an RGB image
57
+ """
58
+ from skimage.io import imread
59
+ rgb = False
60
+ im = imread(im_path)
61
+ if (len(im.shape) == 3):
62
+ if not(np.allclose(im[...,0], im[...,1]) and np.allclose(im[...,2], im[...,1])):
63
+ rgb = True
64
+ print("rgb: {}".format(rgb))
65
+ print("im shape: {}".format(im.shape))
66
+ return rgb
manga_translator/colorization/manga_colorization_v2_utils/networks/extractor.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+
5
+ '''https://github.com/blandocs/Tag2Pix/blob/master/model/pretrained.py'''
6
+
7
+ # Pretrained version
8
+ class Selayer(nn.Module):
9
+ def __init__(self, inplanes):
10
+ super(Selayer, self).__init__()
11
+ self.global_avgpool = nn.AdaptiveAvgPool2d(1)
12
+ self.conv1 = nn.Conv2d(inplanes, inplanes // 16, kernel_size=1, stride=1)
13
+ self.conv2 = nn.Conv2d(inplanes // 16, inplanes, kernel_size=1, stride=1)
14
+ self.relu = nn.ReLU(inplace=True)
15
+ self.sigmoid = nn.Sigmoid()
16
+
17
+ def forward(self, x):
18
+ out = self.global_avgpool(x)
19
+ out = self.conv1(out)
20
+ out = self.relu(out)
21
+ out = self.conv2(out)
22
+ out = self.sigmoid(out)
23
+
24
+ return x * out
25
+
26
+
27
+ class BottleneckX_Origin(nn.Module):
28
+ expansion = 4
29
+
30
+ def __init__(self, inplanes, planes, cardinality, stride=1, downsample=None):
31
+ super(BottleneckX_Origin, self).__init__()
32
+ self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False)
33
+ self.bn1 = nn.BatchNorm2d(planes * 2)
34
+
35
+ self.conv2 = nn.Conv2d(planes * 2, planes * 2, kernel_size=3, stride=stride,
36
+ padding=1, groups=cardinality, bias=False)
37
+ self.bn2 = nn.BatchNorm2d(planes * 2)
38
+
39
+ self.conv3 = nn.Conv2d(planes * 2, planes * 4, kernel_size=1, bias=False)
40
+ self.bn3 = nn.BatchNorm2d(planes * 4)
41
+
42
+ self.selayer = Selayer(planes * 4)
43
+
44
+ self.relu = nn.ReLU(inplace=True)
45
+ self.downsample = downsample
46
+ self.stride = stride
47
+
48
+ def forward(self, x):
49
+ residual = x
50
+
51
+ out = self.conv1(x)
52
+ out = self.bn1(out)
53
+ out = self.relu(out)
54
+
55
+ out = self.conv2(out)
56
+ out = self.bn2(out)
57
+ out = self.relu(out)
58
+
59
+ out = self.conv3(out)
60
+ out = self.bn3(out)
61
+
62
+ out = self.selayer(out)
63
+
64
+ if self.downsample is not None:
65
+ residual = self.downsample(x)
66
+
67
+ out += residual
68
+ out = self.relu(out)
69
+
70
+ return out
71
+
72
+ class SEResNeXt_Origin(nn.Module):
73
+ def __init__(self, block, layers, input_channels=3, cardinality=32, num_classes=1000):
74
+ super(SEResNeXt_Origin, self).__init__()
75
+ self.cardinality = cardinality
76
+ self.inplanes = 64
77
+ self.input_channels = input_channels
78
+
79
+ self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3,
80
+ bias=False)
81
+ self.bn1 = nn.BatchNorm2d(64)
82
+ self.relu = nn.ReLU(inplace=True)
83
+
84
+ self.layer1 = self._make_layer(block, 64, layers[0])
85
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
86
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
87
+
88
+ for m in self.modules():
89
+ if isinstance(m, nn.Conv2d):
90
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
91
+ m.weight.data.normal_(0, math.sqrt(2. / n))
92
+ if m.bias is not None:
93
+ m.bias.data.zero_()
94
+ elif isinstance(m, nn.BatchNorm2d):
95
+ m.weight.data.fill_(1)
96
+ m.bias.data.zero_()
97
+
98
+ def _make_layer(self, block, planes, blocks, stride=1):
99
+ downsample = None
100
+ if stride != 1 or self.inplanes != planes * block.expansion:
101
+ downsample = nn.Sequential(
102
+ nn.Conv2d(self.inplanes, planes * block.expansion,
103
+ kernel_size=1, stride=stride, bias=False),
104
+ nn.BatchNorm2d(planes * block.expansion),
105
+ )
106
+
107
+ layers = []
108
+ layers.append(block(self.inplanes, planes, self.cardinality, stride, downsample))
109
+ self.inplanes = planes * block.expansion
110
+ for i in range(1, blocks):
111
+ layers.append(block(self.inplanes, planes, self.cardinality))
112
+
113
+ return nn.Sequential(*layers)
114
+
115
+ def forward(self, x):
116
+
117
+ x = self.conv1(x)
118
+ x = self.bn1(x)
119
+ x1 = self.relu(x)
120
+
121
+ x2 = self.layer1(x1)
122
+
123
+ x3 = self.layer2(x2)
124
+
125
+ x4 = self.layer3(x3)
126
+
127
+ return x1, x2, x3, x4
manga_translator/colorization/manga_colorization_v2_utils/networks/models.py ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision.models as M
5
+ import math
6
+ from torch import Tensor
7
+ from torch.nn import Parameter
8
+
9
+ from .extractor import SEResNeXt_Origin, BottleneckX_Origin
10
+
11
+ '''https://github.com/orashi/AlacGAN/blob/master/models/standard.py'''
12
+
13
+ def l2normalize(v, eps=1e-12):
14
+ return v / (v.norm() + eps)
15
+
16
+
17
+ class SpectralNorm(nn.Module):
18
+ def __init__(self, module, name='weight', power_iterations=1):
19
+ super(SpectralNorm, self).__init__()
20
+ self.module = module
21
+ self.name = name
22
+ self.power_iterations = power_iterations
23
+ if not self._made_params():
24
+ self._make_params()
25
+
26
+ def _update_u_v(self):
27
+ u = getattr(self.module, self.name + "_u")
28
+ v = getattr(self.module, self.name + "_v")
29
+ w = getattr(self.module, self.name + "_bar")
30
+
31
+ height = w.data.shape[0]
32
+ for _ in range(self.power_iterations):
33
+ v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))
34
+ u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))
35
+
36
+ # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
37
+ sigma = u.dot(w.view(height, -1).mv(v))
38
+ setattr(self.module, self.name, w / sigma.expand_as(w))
39
+
40
+ def _made_params(self):
41
+ try:
42
+ u = getattr(self.module, self.name + "_u")
43
+ v = getattr(self.module, self.name + "_v")
44
+ w = getattr(self.module, self.name + "_bar")
45
+ return True
46
+ except AttributeError:
47
+ return False
48
+
49
+
50
+ def _make_params(self):
51
+ w = getattr(self.module, self.name)
52
+ height = w.data.shape[0]
53
+ width = w.view(height, -1).data.shape[1]
54
+
55
+ u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
56
+ v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
57
+ u.data = l2normalize(u.data)
58
+ v.data = l2normalize(v.data)
59
+ w_bar = Parameter(w.data)
60
+
61
+ del self.module._parameters[self.name]
62
+
63
+ self.module.register_parameter(self.name + "_u", u)
64
+ self.module.register_parameter(self.name + "_v", v)
65
+ self.module.register_parameter(self.name + "_bar", w_bar)
66
+
67
+
68
+ def forward(self, *args):
69
+ self._update_u_v()
70
+ return self.module.forward(*args)
71
+
72
+ class Selayer(nn.Module):
73
+ def __init__(self, inplanes):
74
+ super(Selayer, self).__init__()
75
+ self.global_avgpool = nn.AdaptiveAvgPool2d(1)
76
+ self.conv1 = nn.Conv2d(inplanes, inplanes // 16, kernel_size=1, stride=1)
77
+ self.conv2 = nn.Conv2d(inplanes // 16, inplanes, kernel_size=1, stride=1)
78
+ self.relu = nn.ReLU(inplace=True)
79
+ self.sigmoid = nn.Sigmoid()
80
+
81
+ def forward(self, x):
82
+ out = self.global_avgpool(x)
83
+ out = self.conv1(out)
84
+ out = self.relu(out)
85
+ out = self.conv2(out)
86
+ out = self.sigmoid(out)
87
+
88
+ return x * out
89
+
90
+ class SelayerSpectr(nn.Module):
91
+ def __init__(self, inplanes):
92
+ super(SelayerSpectr, self).__init__()
93
+ self.global_avgpool = nn.AdaptiveAvgPool2d(1)
94
+ self.conv1 = SpectralNorm(nn.Conv2d(inplanes, inplanes // 16, kernel_size=1, stride=1))
95
+ self.conv2 = SpectralNorm(nn.Conv2d(inplanes // 16, inplanes, kernel_size=1, stride=1))
96
+ self.relu = nn.ReLU(inplace=True)
97
+ self.sigmoid = nn.Sigmoid()
98
+
99
+ def forward(self, x):
100
+ out = self.global_avgpool(x)
101
+ out = self.conv1(out)
102
+ out = self.relu(out)
103
+ out = self.conv2(out)
104
+ out = self.sigmoid(out)
105
+
106
+ return x * out
107
+
108
+ class ResNeXtBottleneck(nn.Module):
109
+ def __init__(self, in_channels=256, out_channels=256, stride=1, cardinality=32, dilate=1):
110
+ super(ResNeXtBottleneck, self).__init__()
111
+ D = out_channels // 2
112
+ self.out_channels = out_channels
113
+ self.conv_reduce = nn.Conv2d(in_channels, D, kernel_size=1, stride=1, padding=0, bias=False)
114
+ self.conv_conv = nn.Conv2d(D, D, kernel_size=2 + stride, stride=stride, padding=dilate, dilation=dilate,
115
+ groups=cardinality,
116
+ bias=False)
117
+ self.conv_expand = nn.Conv2d(D, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
118
+ self.shortcut = nn.Sequential()
119
+ if stride != 1:
120
+ self.shortcut.add_module('shortcut',
121
+ nn.AvgPool2d(2, stride=2))
122
+
123
+ self.selayer = Selayer(out_channels)
124
+
125
+ def forward(self, x):
126
+ bottleneck = self.conv_reduce.forward(x)
127
+ bottleneck = F.leaky_relu(bottleneck, 0.2, True)
128
+ bottleneck = self.conv_conv.forward(bottleneck)
129
+ bottleneck = F.leaky_relu(bottleneck, 0.2, True)
130
+ bottleneck = self.conv_expand.forward(bottleneck)
131
+ bottleneck = self.selayer(bottleneck)
132
+
133
+ x = self.shortcut.forward(x)
134
+ return x + bottleneck
135
+
136
+ class SpectrResNeXtBottleneck(nn.Module):
137
+ def __init__(self, in_channels=256, out_channels=256, stride=1, cardinality=32, dilate=1):
138
+ super(SpectrResNeXtBottleneck, self).__init__()
139
+ D = out_channels // 2
140
+ self.out_channels = out_channels
141
+ self.conv_reduce = SpectralNorm(nn.Conv2d(in_channels, D, kernel_size=1, stride=1, padding=0, bias=False))
142
+ self.conv_conv = SpectralNorm(nn.Conv2d(D, D, kernel_size=2 + stride, stride=stride, padding=dilate, dilation=dilate,
143
+ groups=cardinality,
144
+ bias=False))
145
+ self.conv_expand = SpectralNorm(nn.Conv2d(D, out_channels, kernel_size=1, stride=1, padding=0, bias=False))
146
+ self.shortcut = nn.Sequential()
147
+ if stride != 1:
148
+ self.shortcut.add_module('shortcut',
149
+ nn.AvgPool2d(2, stride=2))
150
+
151
+ self.selayer = SelayerSpectr(out_channels)
152
+
153
+ def forward(self, x):
154
+ bottleneck = self.conv_reduce.forward(x)
155
+ bottleneck = F.leaky_relu(bottleneck, 0.2, True)
156
+ bottleneck = self.conv_conv.forward(bottleneck)
157
+ bottleneck = F.leaky_relu(bottleneck, 0.2, True)
158
+ bottleneck = self.conv_expand.forward(bottleneck)
159
+ bottleneck = self.selayer(bottleneck)
160
+
161
+ x = self.shortcut.forward(x)
162
+ return x + bottleneck
163
+
164
+ class FeatureConv(nn.Module):
165
+ def __init__(self, input_dim=512, output_dim=512):
166
+ super(FeatureConv, self).__init__()
167
+
168
+ no_bn = True
169
+
170
+ seq = []
171
+ seq.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=1, padding=1, bias=False))
172
+ if not no_bn: seq.append(nn.BatchNorm2d(output_dim))
173
+ seq.append(nn.ReLU(inplace=True))
174
+ seq.append(nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=2, padding=1, bias=False))
175
+ if not no_bn: seq.append(nn.BatchNorm2d(output_dim))
176
+ seq.append(nn.ReLU(inplace=True))
177
+ seq.append(nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=1, padding=1, bias=False))
178
+ seq.append(nn.ReLU(inplace=True))
179
+
180
+ self.network = nn.Sequential(*seq)
181
+
182
+ def forward(self, x):
183
+ return self.network(x)
184
+
185
+ class Generator(nn.Module):
186
+ def __init__(self, ngf=64):
187
+ super(Generator, self).__init__()
188
+
189
+ self.encoder = SEResNeXt_Origin(BottleneckX_Origin, [3, 4, 6, 3], num_classes= 370, input_channels=1)
190
+
191
+ self.to0 = self._make_encoder_block_first(5, 32)
192
+ self.to1 = self._make_encoder_block(32, 64)
193
+ self.to2 = self._make_encoder_block(64, 92)
194
+ self.to3 = self._make_encoder_block(92, 128)
195
+ self.to4 = self._make_encoder_block(128, 256)
196
+
197
+ self.deconv_for_decoder = nn.Sequential(
198
+ nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1), # output is 64 * 64
199
+ nn.LeakyReLU(0.2),
200
+ nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1), # output is 128 * 128
201
+ nn.LeakyReLU(0.2),
202
+ nn.ConvTranspose2d(64, 32, 3, stride=1, padding=1, output_padding=0), # output is 256 * 256
203
+ nn.LeakyReLU(0.2),
204
+ nn.ConvTranspose2d(32, 3, 3, stride=1, padding=1, output_padding=0), # output is 256 * 256
205
+ nn.Tanh(),
206
+ )
207
+
208
+ tunnel4 = nn.Sequential(*[ResNeXtBottleneck(512, 512, cardinality=32, dilate=1) for _ in range(20)])
209
+
210
+
211
+ self.tunnel4 = nn.Sequential(nn.Conv2d(1024 + 128, 512, kernel_size=3, stride=1, padding=1),
212
+ nn.LeakyReLU(0.2, True),
213
+ tunnel4,
214
+ nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=1),
215
+ nn.PixelShuffle(2),
216
+ nn.LeakyReLU(0.2, True)
217
+ ) # 64
218
+
219
+ depth = 2
220
+ tunnel = [ResNeXtBottleneck(256, 256, cardinality=32, dilate=1) for _ in range(depth)]
221
+ tunnel += [ResNeXtBottleneck(256, 256, cardinality=32, dilate=2) for _ in range(depth)]
222
+ tunnel += [ResNeXtBottleneck(256, 256, cardinality=32, dilate=4) for _ in range(depth)]
223
+ tunnel += [ResNeXtBottleneck(256, 256, cardinality=32, dilate=2),
224
+ ResNeXtBottleneck(256, 256, cardinality=32, dilate=1)]
225
+ tunnel3 = nn.Sequential(*tunnel)
226
+
227
+ self.tunnel3 = nn.Sequential(nn.Conv2d(512 + 256, 256, kernel_size=3, stride=1, padding=1),
228
+ nn.LeakyReLU(0.2, True),
229
+ tunnel3,
230
+ nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
231
+ nn.PixelShuffle(2),
232
+ nn.LeakyReLU(0.2, True)
233
+ ) # 128
234
+
235
+ tunnel = [ResNeXtBottleneck(128, 128, cardinality=32, dilate=1) for _ in range(depth)]
236
+ tunnel += [ResNeXtBottleneck(128, 128, cardinality=32, dilate=2) for _ in range(depth)]
237
+ tunnel += [ResNeXtBottleneck(128, 128, cardinality=32, dilate=4) for _ in range(depth)]
238
+ tunnel += [ResNeXtBottleneck(128, 128, cardinality=32, dilate=2),
239
+ ResNeXtBottleneck(128, 128, cardinality=32, dilate=1)]
240
+ tunnel2 = nn.Sequential(*tunnel)
241
+
242
+ self.tunnel2 = nn.Sequential(nn.Conv2d(128 + 256 + 64, 128, kernel_size=3, stride=1, padding=1),
243
+ nn.LeakyReLU(0.2, True),
244
+ tunnel2,
245
+ nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
246
+ nn.PixelShuffle(2),
247
+ nn.LeakyReLU(0.2, True)
248
+ )
249
+
250
+ tunnel = [ResNeXtBottleneck(64, 64, cardinality=16, dilate=1)]
251
+ tunnel += [ResNeXtBottleneck(64, 64, cardinality=16, dilate=2)]
252
+ tunnel += [ResNeXtBottleneck(64, 64, cardinality=16, dilate=4)]
253
+ tunnel += [ResNeXtBottleneck(64, 64, cardinality=16, dilate=2),
254
+ ResNeXtBottleneck(64, 64, cardinality=16, dilate=1)]
255
+ tunnel1 = nn.Sequential(*tunnel)
256
+
257
+ self.tunnel1 = nn.Sequential(nn.Conv2d(64 + 32, 64, kernel_size=3, stride=1, padding=1),
258
+ nn.LeakyReLU(0.2, True),
259
+ tunnel1,
260
+ nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
261
+ nn.PixelShuffle(2),
262
+ nn.LeakyReLU(0.2, True)
263
+ )
264
+
265
+ self.exit = nn.Sequential(nn.Conv2d(64 + 32, 32, kernel_size=3, stride=1, padding=1),
266
+ nn.LeakyReLU(0.2, True),
267
+ nn.Conv2d(32, 3, kernel_size= 1, stride = 1, padding = 0))
268
+
269
+
270
+ def _make_encoder_block(self, inplanes, planes):
271
+ return nn.Sequential(
272
+ nn.Conv2d(inplanes, planes, 3, 2, 1),
273
+ nn.LeakyReLU(0.2),
274
+ nn.Conv2d(planes, planes, 3, 1, 1),
275
+ nn.LeakyReLU(0.2),
276
+ )
277
+
278
+ def _make_encoder_block_first(self, inplanes, planes):
279
+ return nn.Sequential(
280
+ nn.Conv2d(inplanes, planes, 3, 1, 1),
281
+ nn.LeakyReLU(0.2),
282
+ nn.Conv2d(planes, planes, 3, 1, 1),
283
+ nn.LeakyReLU(0.2),
284
+ )
285
+
286
+ def forward(self, sketch):
287
+
288
+ x0 = self.to0(sketch)
289
+ aux_out = self.to1(x0)
290
+ aux_out = self.to2(aux_out)
291
+ aux_out = self.to3(aux_out)
292
+
293
+ x1, x2, x3, x4 = self.encoder(sketch[:, 0:1])
294
+
295
+ out = self.tunnel4(torch.cat([x4, aux_out], 1))
296
+
297
+
298
+
299
+ x = self.tunnel3(torch.cat([out, x3], 1))
300
+
301
+ x = self.tunnel2(torch.cat([x, x2, x1], 1))
302
+
303
+
304
+ x = torch.tanh(self.exit(torch.cat([x, x0], 1)))
305
+
306
+ decoder_output = self.deconv_for_decoder(out)
307
+
308
+ return x, decoder_output
309
+
310
+
311
+ class Colorizer(nn.Module):
312
+ def __init__(self):
313
+ super(Colorizer, self).__init__()
314
+
315
+ self.generator = Generator()
316
+
317
+ def forward(self, x, extractor_grad = False):
318
+ fake, guide = self.generator(x)
319
+ return fake, guide
manga_translator/colorization/manga_colorization_v2_utils/utils/utils.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+
4
+ def resize_pad(img, size = 256):
5
+
6
+ if len(img.shape) == 2:
7
+ img = np.expand_dims(img, 2)
8
+
9
+ if img.shape[2] == 1:
10
+ img = np.repeat(img, 3, 2)
11
+
12
+ if img.shape[2] == 4:
13
+ img = img[:, :, :3]
14
+
15
+ pad = None
16
+
17
+ if (img.shape[0] < img.shape[1]):
18
+ height = img.shape[0]
19
+ ratio = height / (size * 1.5)
20
+ width = int(np.ceil(img.shape[1] / ratio))
21
+ img = cv2.resize(img, (width, int(size * 1.5)), interpolation = cv2.INTER_AREA)
22
+
23
+
24
+ new_width = width + (32 - width % 32)
25
+
26
+ pad = (0, new_width - width)
27
+
28
+ img = np.pad(img, ((0, 0), (0, pad[1]), (0, 0)), 'maximum')
29
+ else:
30
+ width = img.shape[1]
31
+ ratio = width / size
32
+ height = int(np.ceil(img.shape[0] / ratio))
33
+ img = cv2.resize(img, (size, height), interpolation = cv2.INTER_AREA)
34
+
35
+ new_height = height + (32 - height % 32)
36
+
37
+ pad = (new_height - height, 0)
38
+
39
+ img = np.pad(img, ((0, pad[0]), (0, 0), (0, 0)), 'maximum')
40
+
41
+ if (img.dtype == 'float32'):
42
+ np.clip(img, 0, 1, out = img)
43
+
44
+ return img[:, :, :1], pad
manga_translator/detection/__init__.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ from .default import DefaultDetector
4
+ from .dbnet_convnext import DBConvNextDetector
5
+ from .ctd import ComicTextDetector
6
+ from .craft import CRAFTDetector
7
+ from .none import NoneDetector
8
+ from .common import CommonDetector, OfflineDetector
9
+
10
+ DETECTORS = {
11
+ 'default': DefaultDetector,
12
+ 'dbconvnext': DBConvNextDetector,
13
+ 'ctd': ComicTextDetector,
14
+ 'craft': CRAFTDetector,
15
+ 'none': NoneDetector,
16
+ }
17
+ detector_cache = {}
18
+
19
+ def get_detector(key: str, *args, **kwargs) -> CommonDetector:
20
+ if key not in DETECTORS:
21
+ raise ValueError(f'Could not find detector for: "{key}". Choose from the following: %s' % ','.join(DETECTORS))
22
+ if not detector_cache.get(key):
23
+ detector = DETECTORS[key]
24
+ detector_cache[key] = detector(*args, **kwargs)
25
+ return detector_cache[key]
26
+
27
+ async def prepare(detector_key: str):
28
+ detector = get_detector(detector_key)
29
+ if isinstance(detector, OfflineDetector):
30
+ await detector.download()
31
+
32
+ async def dispatch(detector_key: str, image: np.ndarray, detect_size: int, text_threshold: float, box_threshold: float, unclip_ratio: float,
33
+ invert: bool, gamma_correct: bool, rotate: bool, auto_rotate: bool = False, device: str = 'cpu', verbose: bool = False):
34
+ detector = get_detector(detector_key)
35
+ if isinstance(detector, OfflineDetector):
36
+ await detector.load(device)
37
+ return await detector.detect(image, detect_size, text_threshold, box_threshold, unclip_ratio, invert, gamma_correct, rotate, auto_rotate, verbose)
manga_translator/detection/common.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import abstractmethod
2
+ from typing import List, Tuple
3
+ from collections import Counter
4
+ import numpy as np
5
+ import cv2
6
+
7
+ from ..utils import InfererModule, ModelWrapper, Quadrilateral
8
+
9
+
10
+ class CommonDetector(InfererModule):
11
+
12
+ async def detect(self, image: np.ndarray, detect_size: int, text_threshold: float, box_threshold: float, unclip_ratio: float,
13
+ invert: bool, gamma_correct: bool, rotate: bool, auto_rotate: bool = False, verbose: bool = False):
14
+ '''
15
+ Returns textblock list and text mask.
16
+ '''
17
+
18
+ # Apply filters
19
+ img_h, img_w = image.shape[:2]
20
+ orig_image = image.copy()
21
+ minimum_image_size = 400
22
+ # Automatically add border if image too small (instead of simply resizing due to them more likely containing large fonts)
23
+ add_border = min(img_w, img_h) < minimum_image_size
24
+ if rotate:
25
+ self.logger.debug('Adding rotation')
26
+ image = self._add_rotation(image)
27
+ if add_border:
28
+ self.logger.debug('Adding border')
29
+ image = self._add_border(image, minimum_image_size)
30
+ if invert:
31
+ self.logger.debug('Adding inversion')
32
+ image = self._add_inversion(image)
33
+ if gamma_correct:
34
+ self.logger.debug('Adding gamma correction')
35
+ image = self._add_gamma_correction(image)
36
+ # if True:
37
+ # self.logger.debug('Adding histogram equalization')
38
+ # image = self._add_histogram_equalization(image)
39
+
40
+ # cv2.imwrite('histogram.png', image)
41
+ # cv2.waitKey(0)
42
+
43
+ # Run detection
44
+ textlines, raw_mask, mask = await self._detect(image, detect_size, text_threshold, box_threshold, unclip_ratio, verbose)
45
+ textlines = list(filter(lambda x: x.area > 1, textlines))
46
+
47
+ # Remove filters
48
+ if add_border:
49
+ textlines, raw_mask, mask = self._remove_border(image, img_w, img_h, textlines, raw_mask, mask)
50
+ if auto_rotate:
51
+ # Rotate if horizontal aspect ratios are prevalent to potentially improve detection
52
+ if len(textlines) > 0:
53
+ orientations = ['h' if txtln.aspect_ratio > 1 else 'v' for txtln in textlines]
54
+ majority_orientation = Counter(orientations).most_common(1)[0][0]
55
+ else:
56
+ majority_orientation = 'h'
57
+ if majority_orientation == 'h':
58
+ self.logger.info('Rerunning detection with 90° rotation')
59
+ return await self.detect(orig_image, detect_size, text_threshold, box_threshold, unclip_ratio, invert, gamma_correct,
60
+ rotate=(not rotate), auto_rotate=False, verbose=verbose)
61
+ if rotate:
62
+ textlines, raw_mask, mask = self._remove_rotation(textlines, raw_mask, mask, img_w, img_h)
63
+
64
+ return textlines, raw_mask, mask
65
+
66
+ @abstractmethod
67
+ async def _detect(self, image: np.ndarray, detect_size: int, text_threshold: float, box_threshold: float,
68
+ unclip_ratio: float, verbose: bool = False) -> Tuple[List[Quadrilateral], np.ndarray, np.ndarray]:
69
+ pass
70
+
71
+ def _add_border(self, image: np.ndarray, target_side_length: int):
72
+ old_h, old_w = image.shape[:2]
73
+ new_w = new_h = max(old_w, old_h, target_side_length)
74
+ new_image = np.zeros([new_h, new_w, 3]).astype(np.uint8)
75
+ # new_image[:] = np.array([255, 255, 255], np.uint8)
76
+ x, y = 0, 0
77
+ # x, y = (new_h - old_h) // 2, (new_w - old_w) // 2
78
+ new_image[y:y+old_h, x:x+old_w] = image
79
+ return new_image
80
+
81
+ def _remove_border(self, image: np.ndarray, old_w: int, old_h: int, textlines: List[Quadrilateral], raw_mask, mask):
82
+ new_h, new_w = image.shape[:2]
83
+ raw_mask = cv2.resize(raw_mask, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
84
+ raw_mask = raw_mask[:old_h, :old_w]
85
+ if mask is not None:
86
+ mask = cv2.resize(mask, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
87
+ mask = mask[:old_h, :old_w]
88
+
89
+ # Filter out regions within the border and clamp the points of the remaining regions
90
+ new_textlines = []
91
+ for txtln in textlines:
92
+ if txtln.xyxy[0] >= old_w and txtln.xyxy[1] >= old_h:
93
+ continue
94
+ points = txtln.pts
95
+ points[:,0] = np.clip(points[:,0], 0, old_w)
96
+ points[:,1] = np.clip(points[:,1], 0, old_h)
97
+ new_txtln = Quadrilateral(points, txtln.text, txtln.prob)
98
+ new_textlines.append(new_txtln)
99
+ return new_textlines, raw_mask, mask
100
+
101
+ def _add_rotation(self, image: np.ndarray):
102
+ return np.rot90(image, k=-1)
103
+
104
+ def _remove_rotation(self, textlines, raw_mask, mask, img_w, img_h):
105
+ raw_mask = np.ascontiguousarray(np.rot90(raw_mask))
106
+ if mask is not None:
107
+ mask = np.ascontiguousarray(np.rot90(mask).astype(np.uint8))
108
+
109
+ for i, txtln in enumerate(textlines):
110
+ rotated_pts = txtln.pts[:,[1,0]]
111
+ rotated_pts[:,1] = -rotated_pts[:,1] + img_h
112
+ textlines[i] = Quadrilateral(rotated_pts, txtln.text, txtln.prob)
113
+ return textlines, raw_mask, mask
114
+
115
+ def _add_inversion(self, image: np.ndarray):
116
+ return cv2.bitwise_not(image)
117
+
118
+ def _add_gamma_correction(self, image: np.ndarray):
119
+ gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
120
+ mid = 0.5
121
+ mean = np.mean(gray)
122
+ gamma = np.log(mid * 255) / np.log(mean)
123
+ img_gamma = np.power(image, gamma).clip(0,255).astype(np.uint8)
124
+ return img_gamma
125
+
126
+ def _add_histogram_equalization(self, image: np.ndarray):
127
+ img_yuv = cv2.cvtColor(image, cv2.COLOR_BGR2YUV)
128
+
129
+ # equalize the histogram of the Y channel
130
+ img_yuv[:,:,0] = cv2.equalizeHist(img_yuv[:,:,0])
131
+
132
+ # convert the YUV image back to RGB format
133
+ img_output = cv2.cvtColor(img_yuv, cv2.COLOR_YUV2BGR)
134
+ return img_output
135
+
136
+
137
+ class OfflineDetector(CommonDetector, ModelWrapper):
138
+ _MODEL_SUB_DIR = 'detection'
139
+
140
+ async def _detect(self, *args, **kwargs):
141
+ return await self.infer(*args, **kwargs)
142
+
143
+ @abstractmethod
144
+ async def _infer(self, image: np.ndarray, detect_size: int, text_threshold: float, box_threshold: float,
145
+ unclip_ratio: float, verbose: bool = False):
146
+ pass
manga_translator/detection/craft.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2019-present NAVER Corp.
3
+ MIT License
4
+ """
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ import os
12
+ import shutil
13
+ import numpy as np
14
+ import torch
15
+ import cv2
16
+ import einops
17
+ from typing import List, Tuple
18
+
19
+ from .default_utils.DBNet_resnet34 import TextDetection as TextDetectionDefault
20
+ from .default_utils import imgproc, dbnet_utils, craft_utils
21
+ from .common import OfflineDetector
22
+ from ..utils import TextBlock, Quadrilateral, det_rearrange_forward
23
+ from shapely.geometry import Polygon, MultiPoint
24
+ from shapely import affinity
25
+
26
+ from .craft_utils.vgg16_bn import vgg16_bn, init_weights
27
+ from .craft_utils.refiner import RefineNet
28
+
29
+ class double_conv(nn.Module):
30
+ def __init__(self, in_ch, mid_ch, out_ch):
31
+ super(double_conv, self).__init__()
32
+ self.conv = nn.Sequential(
33
+ nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=1),
34
+ nn.BatchNorm2d(mid_ch),
35
+ nn.ReLU(inplace=True),
36
+ nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1),
37
+ nn.BatchNorm2d(out_ch),
38
+ nn.ReLU(inplace=True)
39
+ )
40
+
41
+ def forward(self, x):
42
+ x = self.conv(x)
43
+ return x
44
+
45
+
46
+ class CRAFT(nn.Module):
47
+ def __init__(self, pretrained=False, freeze=False):
48
+ super(CRAFT, self).__init__()
49
+
50
+ """ Base network """
51
+ self.basenet = vgg16_bn(pretrained, freeze)
52
+
53
+ """ U network """
54
+ self.upconv1 = double_conv(1024, 512, 256)
55
+ self.upconv2 = double_conv(512, 256, 128)
56
+ self.upconv3 = double_conv(256, 128, 64)
57
+ self.upconv4 = double_conv(128, 64, 32)
58
+
59
+ num_class = 2
60
+ self.conv_cls = nn.Sequential(
61
+ nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
62
+ nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
63
+ nn.Conv2d(32, 16, kernel_size=3, padding=1), nn.ReLU(inplace=True),
64
+ nn.Conv2d(16, 16, kernel_size=1), nn.ReLU(inplace=True),
65
+ nn.Conv2d(16, num_class, kernel_size=1),
66
+ )
67
+
68
+ init_weights(self.upconv1.modules())
69
+ init_weights(self.upconv2.modules())
70
+ init_weights(self.upconv3.modules())
71
+ init_weights(self.upconv4.modules())
72
+ init_weights(self.conv_cls.modules())
73
+
74
+ def forward(self, x):
75
+ """ Base network """
76
+ sources = self.basenet(x)
77
+
78
+ """ U network """
79
+ y = torch.cat([sources[0], sources[1]], dim=1)
80
+ y = self.upconv1(y)
81
+
82
+ y = F.interpolate(y, size=sources[2].size()[2:], mode='bilinear', align_corners=False)
83
+ y = torch.cat([y, sources[2]], dim=1)
84
+ y = self.upconv2(y)
85
+
86
+ y = F.interpolate(y, size=sources[3].size()[2:], mode='bilinear', align_corners=False)
87
+ y = torch.cat([y, sources[3]], dim=1)
88
+ y = self.upconv3(y)
89
+
90
+ y = F.interpolate(y, size=sources[4].size()[2:], mode='bilinear', align_corners=False)
91
+ y = torch.cat([y, sources[4]], dim=1)
92
+ feature = self.upconv4(y)
93
+
94
+ y = self.conv_cls(feature)
95
+
96
+ return y.permute(0,2,3,1), feature
97
+
98
+
99
+ from collections import OrderedDict
100
+ def copyStateDict(state_dict):
101
+ if list(state_dict.keys())[0].startswith("module"):
102
+ start_idx = 1
103
+ else:
104
+ start_idx = 0
105
+ new_state_dict = OrderedDict()
106
+ for k, v in state_dict.items():
107
+ name = ".".join(k.split(".")[start_idx:])
108
+ new_state_dict[name] = v
109
+ return new_state_dict
110
+
111
+ class CRAFTDetector(OfflineDetector):
112
+ _MODEL_MAPPING = {
113
+ 'refiner': {
114
+ 'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/craft_refiner_CTW1500.pth',
115
+ 'hash': 'f7000cd3e9c76f2231b62b32182212203f73c08dfaa12bb16ffb529948a01399',
116
+ 'file': 'craft_refiner_CTW1500.pth',
117
+ },
118
+ 'craft': {
119
+ 'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/craft_mlt_25k.pth',
120
+ 'hash': '4a5efbfb48b4081100544e75e1e2b57f8de3d84f213004b14b85fd4b3748db17',
121
+ 'file': 'craft_mlt_25k.pth',
122
+ }
123
+ }
124
+
125
+ def __init__(self, *args, **kwargs):
126
+ os.makedirs(self.model_dir, exist_ok=True)
127
+ if os.path.exists('craft_mlt_25k.pth'):
128
+ shutil.move('craft_mlt_25k.pth', self._get_file_path('craft_mlt_25k.pth'))
129
+ if os.path.exists('craft_refiner_CTW1500.pth'):
130
+ shutil.move('craft_refiner_CTW1500.pth', self._get_file_path('craft_refiner_CTW1500.pth'))
131
+ super().__init__(*args, **kwargs)
132
+
133
+ async def _load(self, device: str):
134
+ self.model = CRAFT()
135
+ self.model.load_state_dict(copyStateDict(torch.load(self._get_file_path('craft_mlt_25k.pth'), map_location='cpu')))
136
+ self.model.eval()
137
+ self.model_refiner = RefineNet()
138
+ self.model_refiner.load_state_dict(copyStateDict(torch.load(self._get_file_path('craft_refiner_CTW1500.pth'), map_location='cpu')))
139
+ self.model_refiner.eval()
140
+ self.device = device
141
+ if device == 'cuda' or device == 'mps':
142
+ self.model = self.model.to(self.device)
143
+ self.model_refiner = self.model_refiner.to(self.device)
144
+ global MODEL
145
+ MODEL = self.model
146
+
147
+ async def _unload(self):
148
+ del self.model
149
+
150
+ async def _infer(self, image: np.ndarray, detect_size: int, text_threshold: float, box_threshold: float,
151
+ unclip_ratio: float, verbose: bool = False):
152
+
153
+ img_resized, target_ratio, size_heatmap, pad_w, pad_h = imgproc.resize_aspect_ratio(image, detect_size, interpolation = cv2.INTER_CUBIC, mag_ratio = 1)
154
+ ratio_h = ratio_w = 1 / target_ratio
155
+
156
+ # preprocessing
157
+ x = imgproc.normalizeMeanVariance(img_resized)
158
+ x = torch.from_numpy(x).permute(2, 0, 1) # [h, w, c] to [c, h, w]
159
+ x = x.unsqueeze(0).to(self.device) # [c, h, w] to [b, c, h, w]
160
+
161
+ with torch.no_grad() :
162
+ y, feature = self.model(x)
163
+
164
+ # make score and link map
165
+ score_text = y[0,:,:,0].cpu().data.numpy()
166
+ score_link = y[0,:,:,1].cpu().data.numpy()
167
+
168
+ # refine link
169
+ y_refiner = self.model_refiner(y, feature)
170
+ score_link = y_refiner[0,:,:,0].cpu().data.numpy()
171
+
172
+ # Post-processing
173
+ boxes, polys = craft_utils.getDetBoxes(score_text, score_link, text_threshold, box_threshold, box_threshold, True)
174
+
175
+ # coordinate adjustment
176
+ boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
177
+ polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
178
+ for k in range(len(polys)):
179
+ if polys[k] is None: polys[k] = boxes[k]
180
+
181
+ mask = np.zeros(shape = (image.shape[0], image.shape[1]), dtype = np.uint8)
182
+
183
+ for poly in polys :
184
+ mask = cv2.fillPoly(mask, [poly.reshape((-1, 1, 2)).astype(np.int32)], color = 255)
185
+
186
+ polys_ret = []
187
+ for i in range(len(polys)) :
188
+ poly = MultiPoint(polys[i])
189
+ if poly.area > 10 :
190
+ rect = poly.minimum_rotated_rectangle
191
+ rect = affinity.scale(rect, xfact = 1.2, yfact = 1.2)
192
+ polys_ret.append(np.roll(np.asarray(list(rect.exterior.coords)[:4]), 2))
193
+
194
+ kern = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (9, 9))
195
+ mask = cv2.dilate(mask, kern)
196
+
197
+ textlines = [Quadrilateral(pts.astype(int), '', 1) for pts in polys_ret]
198
+ textlines = list(filter(lambda q: q.area > 16, textlines))
199
+
200
+ return textlines, mask, None
manga_translator/detection/craft_utils/refiner.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2019-present NAVER Corp.
3
+ MIT License
4
+ """
5
+
6
+ # -*- coding: utf-8 -*-
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torch.autograd import Variable
11
+ from .vgg16_bn import init_weights
12
+
13
+
14
+ class RefineNet(nn.Module):
15
+ def __init__(self):
16
+ super(RefineNet, self).__init__()
17
+
18
+ self.last_conv = nn.Sequential(
19
+ nn.Conv2d(34, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
20
+ nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
21
+ nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)
22
+ )
23
+
24
+ self.aspp1 = nn.Sequential(
25
+ nn.Conv2d(64, 128, kernel_size=3, dilation=6, padding=6), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
26
+ nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
27
+ nn.Conv2d(128, 1, kernel_size=1)
28
+ )
29
+
30
+ self.aspp2 = nn.Sequential(
31
+ nn.Conv2d(64, 128, kernel_size=3, dilation=12, padding=12), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
32
+ nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
33
+ nn.Conv2d(128, 1, kernel_size=1)
34
+ )
35
+
36
+ self.aspp3 = nn.Sequential(
37
+ nn.Conv2d(64, 128, kernel_size=3, dilation=18, padding=18), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
38
+ nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
39
+ nn.Conv2d(128, 1, kernel_size=1)
40
+ )
41
+
42
+ self.aspp4 = nn.Sequential(
43
+ nn.Conv2d(64, 128, kernel_size=3, dilation=24, padding=24), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
44
+ nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
45
+ nn.Conv2d(128, 1, kernel_size=1)
46
+ )
47
+
48
+ init_weights(self.last_conv.modules())
49
+ init_weights(self.aspp1.modules())
50
+ init_weights(self.aspp2.modules())
51
+ init_weights(self.aspp3.modules())
52
+ init_weights(self.aspp4.modules())
53
+
54
+ def forward(self, y, upconv4):
55
+ refine = torch.cat([y.permute(0,3,1,2), upconv4], dim=1)
56
+ refine = self.last_conv(refine)
57
+
58
+ aspp1 = self.aspp1(refine)
59
+ aspp2 = self.aspp2(refine)
60
+ aspp3 = self.aspp3(refine)
61
+ aspp4 = self.aspp4(refine)
62
+
63
+ #out = torch.add([aspp1, aspp2, aspp3, aspp4], dim=1)
64
+ out = aspp1 + aspp2 + aspp3 + aspp4
65
+ return out.permute(0, 2, 3, 1) # , refine.permute(0,2,3,1)
manga_translator/detection/craft_utils/vgg16_bn.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import namedtuple
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.init as init
6
+ from torchvision import models
7
+
8
+ def init_weights(modules):
9
+ for m in modules:
10
+ if isinstance(m, nn.Conv2d):
11
+ init.xavier_uniform_(m.weight.data)
12
+ if m.bias is not None:
13
+ m.bias.data.zero_()
14
+ elif isinstance(m, nn.BatchNorm2d):
15
+ m.weight.data.fill_(1)
16
+ m.bias.data.zero_()
17
+ elif isinstance(m, nn.Linear):
18
+ m.weight.data.normal_(0, 0.01)
19
+ m.bias.data.zero_()
20
+
21
+ class vgg16_bn(torch.nn.Module):
22
+ def __init__(self, pretrained=True, freeze=True):
23
+ super(vgg16_bn, self).__init__()
24
+ vgg_pretrained_features = models.vgg16_bn().features
25
+ self.slice1 = torch.nn.Sequential()
26
+ self.slice2 = torch.nn.Sequential()
27
+ self.slice3 = torch.nn.Sequential()
28
+ self.slice4 = torch.nn.Sequential()
29
+ self.slice5 = torch.nn.Sequential()
30
+ for x in range(12): # conv2_2
31
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
32
+ for x in range(12, 19): # conv3_3
33
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
34
+ for x in range(19, 29): # conv4_3
35
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
36
+ for x in range(29, 39): # conv5_3
37
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
38
+
39
+ # fc6, fc7 without atrous conv
40
+ self.slice5 = torch.nn.Sequential(
41
+ nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
42
+ nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
43
+ nn.Conv2d(1024, 1024, kernel_size=1)
44
+ )
45
+
46
+ if not pretrained:
47
+ init_weights(self.slice1.modules())
48
+ init_weights(self.slice2.modules())
49
+ init_weights(self.slice3.modules())
50
+ init_weights(self.slice4.modules())
51
+
52
+ init_weights(self.slice5.modules()) # no pretrained model for fc6 and fc7
53
+
54
+ if freeze:
55
+ for param in self.slice1.parameters(): # only first conv
56
+ param.requires_grad= False
57
+
58
+ def forward(self, X):
59
+ h = self.slice1(X)
60
+ h_relu2_2 = h
61
+ h = self.slice2(h)
62
+ h_relu3_2 = h
63
+ h = self.slice3(h)
64
+ h_relu4_3 = h
65
+ h = self.slice4(h)
66
+ h_relu5_3 = h
67
+ h = self.slice5(h)
68
+ h_fc7 = h
69
+ vgg_outputs = namedtuple("VggOutputs", ['fc7', 'relu5_3', 'relu4_3', 'relu3_2', 'relu2_2'])
70
+ out = vgg_outputs(h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2)
71
+ return out
manga_translator/detection/ctd.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import numpy as np
4
+ import einops
5
+ from typing import Union, Tuple
6
+ import cv2
7
+ import torch
8
+
9
+ from .ctd_utils.basemodel import TextDetBase, TextDetBaseDNN
10
+ from .ctd_utils.utils.yolov5_utils import non_max_suppression
11
+ from .ctd_utils.utils.db_utils import SegDetectorRepresenter
12
+ from .ctd_utils.utils.imgproc_utils import letterbox
13
+ from .ctd_utils.textmask import REFINEMASK_INPAINT, refine_mask
14
+ from .common import OfflineDetector
15
+ from ..utils import Quadrilateral, det_rearrange_forward
16
+
17
+ def preprocess_img(img, input_size=(1024, 1024), device='cpu', bgr2rgb=True, half=False, to_tensor=True):
18
+ if bgr2rgb:
19
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
20
+ img_in, ratio, (dw, dh) = letterbox(img, new_shape=input_size, auto=False, stride=64)
21
+ if to_tensor:
22
+ img_in = img_in.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
23
+ img_in = np.array([np.ascontiguousarray(img_in)]).astype(np.float32) / 255
24
+ if to_tensor:
25
+ img_in = torch.from_numpy(img_in).to(device)
26
+ if half:
27
+ img_in = img_in.half()
28
+ return img_in, ratio, int(dw), int(dh)
29
+
30
+ def postprocess_mask(img: Union[torch.Tensor, np.ndarray], thresh=None):
31
+ # img = img.permute(1, 2, 0)
32
+ if isinstance(img, torch.Tensor):
33
+ img = img.squeeze_()
34
+ if img.device != 'cpu':
35
+ img = img.detach().cpu()
36
+ img = img.numpy()
37
+ else:
38
+ img = img.squeeze()
39
+ if thresh is not None:
40
+ img = img > thresh
41
+ img = img * 255
42
+ # if isinstance(img, torch.Tensor):
43
+
44
+ return img.astype(np.uint8)
45
+
46
+ def postprocess_yolo(det, conf_thresh, nms_thresh, resize_ratio, sort_func=None):
47
+ det = non_max_suppression(det, conf_thresh, nms_thresh)[0]
48
+ # bbox = det[..., 0:4]
49
+ if det.device != 'cpu':
50
+ det = det.detach_().cpu().numpy()
51
+ det[..., [0, 2]] = det[..., [0, 2]] * resize_ratio[0]
52
+ det[..., [1, 3]] = det[..., [1, 3]] * resize_ratio[1]
53
+ if sort_func is not None:
54
+ det = sort_func(det)
55
+
56
+ blines = det[..., 0:4].astype(np.int32)
57
+ confs = np.round(det[..., 4], 3)
58
+ cls = det[..., 5].astype(np.int32)
59
+ return blines, cls, confs
60
+
61
+
62
+ class ComicTextDetector(OfflineDetector):
63
+ _MODEL_MAPPING = {
64
+ 'model-cuda': {
65
+ 'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/comictextdetector.pt',
66
+ 'hash': '1f90fa60aeeb1eb82e2ac1167a66bf139a8a61b8780acd351ead55268540cccb',
67
+ 'file': '.',
68
+ },
69
+ 'model-cpu': {
70
+ 'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/comictextdetector.pt.onnx',
71
+ 'hash': '1a86ace74961413cbd650002e7bb4dcec4980ffa21b2f19b86933372071d718f',
72
+ 'file': '.',
73
+ },
74
+ }
75
+
76
+ def __init__(self, *args, **kwargs):
77
+ os.makedirs(self.model_dir, exist_ok=True)
78
+ if os.path.exists('comictextdetector.pt'):
79
+ shutil.move('comictextdetector.pt', self._get_file_path('comictextdetector.pt'))
80
+ if os.path.exists('comictextdetector.pt.onnx'):
81
+ shutil.move('comictextdetector.pt.onnx', self._get_file_path('comictextdetector.pt.onnx'))
82
+ super().__init__(*args, **kwargs)
83
+
84
+ async def _load(self, device: str, input_size=1024, half=False, nms_thresh=0.35, conf_thresh=0.4):
85
+ self.device = device
86
+ if self.device == 'cuda' or self.device == 'mps':
87
+ self.model = TextDetBase(self._get_file_path('comictextdetector.pt'), device=self.device, act='leaky')
88
+ self.model.to(self.device)
89
+ self.backend = 'torch'
90
+ else:
91
+ model_path = self._get_file_path('comictextdetector.pt.onnx')
92
+ self.model = cv2.dnn.readNetFromONNX(model_path)
93
+ self.model = TextDetBaseDNN(input_size, model_path)
94
+ self.backend = 'opencv'
95
+
96
+ if isinstance(input_size, int):
97
+ input_size = (input_size, input_size)
98
+ self.input_size = input_size
99
+ self.half = half
100
+ self.conf_thresh = conf_thresh
101
+ self.nms_thresh = nms_thresh
102
+ self.seg_rep = SegDetectorRepresenter(thresh=0.3)
103
+
104
+ async def _unload(self):
105
+ del self.model
106
+
107
+ def det_batch_forward_ctd(self, batch: np.ndarray, device: str) -> Tuple[np.ndarray, np.ndarray]:
108
+ if isinstance(self.model, TextDetBase):
109
+ batch = einops.rearrange(batch.astype(np.float32) / 255., 'n h w c -> n c h w')
110
+ batch = torch.from_numpy(batch).to(device)
111
+ _, mask, lines = self.model(batch)
112
+ mask = mask.detach().cpu().numpy()
113
+ lines = lines.detach().cpu().numpy()
114
+ elif isinstance(self.model, TextDetBaseDNN):
115
+ mask_lst, line_lst = [], []
116
+ for b in batch:
117
+ _, mask, lines = self.model(b)
118
+ if mask.shape[1] == 2: # some version of opencv spit out reversed result
119
+ tmp = mask
120
+ mask = lines
121
+ lines = tmp
122
+ mask_lst.append(mask)
123
+ line_lst.append(lines)
124
+ lines, mask = np.concatenate(line_lst, 0), np.concatenate(mask_lst, 0)
125
+ else:
126
+ raise NotImplementedError
127
+ return lines, mask
128
+
129
+ @torch.no_grad()
130
+ async def _infer(self, image: np.ndarray, detect_size: int, text_threshold: float, box_threshold: float,
131
+ unclip_ratio: float, verbose: bool = False):
132
+
133
+ # keep_undetected_mask = False
134
+ # refine_mode = REFINEMASK_INPAINT
135
+
136
+ im_h, im_w = image.shape[:2]
137
+ lines_map, mask = det_rearrange_forward(image, self.det_batch_forward_ctd, self.input_size[0], 4, self.device, verbose)
138
+ # blks = []
139
+ # resize_ratio = [1, 1]
140
+ if lines_map is None:
141
+ img_in, ratio, dw, dh = preprocess_img(image, input_size=self.input_size, device=self.device, half=self.half, to_tensor=self.backend=='torch')
142
+ blks, mask, lines_map = self.model(img_in)
143
+
144
+ if self.backend == 'opencv':
145
+ if mask.shape[1] == 2: # some version of opencv spit out reversed result
146
+ tmp = mask
147
+ mask = lines_map
148
+ lines_map = tmp
149
+ mask = mask.squeeze()
150
+ # resize_ratio = (im_w / (self.input_size[0] - dw), im_h / (self.input_size[1] - dh))
151
+ # blks = postprocess_yolo(blks, self.conf_thresh, self.nms_thresh, resize_ratio)
152
+ mask = mask[..., :mask.shape[0]-dh, :mask.shape[1]-dw]
153
+ lines_map = lines_map[..., :lines_map.shape[2]-dh, :lines_map.shape[3]-dw]
154
+
155
+ mask = postprocess_mask(mask)
156
+ lines, scores = self.seg_rep(None, lines_map, height=im_h, width=im_w)
157
+ box_thresh = 0.6
158
+ idx = np.where(scores[0] > box_thresh)
159
+ lines, scores = lines[0][idx], scores[0][idx]
160
+
161
+ # map output to input img
162
+ mask = cv2.resize(mask, (im_w, im_h), interpolation=cv2.INTER_LINEAR)
163
+
164
+ # if lines.size == 0:
165
+ # lines = []
166
+ # else:
167
+ # lines = lines.astype(np.int32)
168
+
169
+ # YOLO was used for finding bboxes which to order the lines into. This is now solved
170
+ # through the textline merger, which seems to work more reliably.
171
+ # The YOLO language detection seems unnecessary as it could never be as good as
172
+ # using the OCR extracted string directly.
173
+ # Doing it for increasing the textline merge accuracy doesn't really work either,
174
+ # as the merge could be postponed until after the OCR finishes.
175
+
176
+ textlines = [Quadrilateral(pts.astype(int), '', score) for pts, score in zip(lines, scores)]
177
+ mask_refined = refine_mask(image, mask, textlines, refine_mode=None)
178
+
179
+ return textlines, mask_refined, None
180
+
181
+ # blk_list = group_output(blks, lines, im_w, im_h, mask)
182
+ # mask_refined = refine_mask(image, mask, blk_list, refine_mode=refine_mode)
183
+ # if keep_undetected_mask:
184
+ # mask_refined = refine_undetected_mask(image, mask, mask_refined, blk_list, refine_mode=refine_mode)
185
+
186
+ # return blk_list, mask, mask_refined
manga_translator/detection/ctd_utils/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .basemodel import TextDetBase, TextDetBaseDNN
2
+ from .utils.yolov5_utils import non_max_suppression
3
+ from .utils.db_utils import SegDetectorRepresenter
4
+ from .utils.imgproc_utils import letterbox
5
+ from .textmask import refine_mask, refine_undetected_mask, REFINEMASK_INPAINT, REFINEMASK_ANNOTATION
manga_translator/detection/ctd_utils/basemodel.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import copy
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from .utils.yolov5_utils import fuse_conv_and_bn
7
+ from .utils.weight_init import init_weights
8
+ from .yolov5.yolo import load_yolov5_ckpt
9
+ from .yolov5.common import C3, Conv
10
+
11
+ TEXTDET_MASK = 0
12
+ TEXTDET_DET = 1
13
+ TEXTDET_INFERENCE = 2
14
+
15
+ class double_conv_up_c3(nn.Module):
16
+ def __init__(self, in_ch, mid_ch, out_ch, act=True):
17
+ super(double_conv_up_c3, self).__init__()
18
+ self.conv = nn.Sequential(
19
+ C3(in_ch+mid_ch, mid_ch, act=act),
20
+ nn.ConvTranspose2d(mid_ch, out_ch, kernel_size=4, stride = 2, padding=1, bias=False),
21
+ nn.BatchNorm2d(out_ch),
22
+ nn.ReLU(inplace=True),
23
+ )
24
+
25
+ def forward(self, x):
26
+ return self.conv(x)
27
+
28
+ class double_conv_c3(nn.Module):
29
+ def __init__(self, in_ch, out_ch, stride=1, act=True):
30
+ super(double_conv_c3, self).__init__()
31
+ if stride > 1:
32
+ self.down = nn.AvgPool2d(2,stride=2) if stride > 1 else None
33
+ self.conv = C3(in_ch, out_ch, act=act)
34
+
35
+ def forward(self, x):
36
+ if self.down is not None:
37
+ x = self.down(x)
38
+ x = self.conv(x)
39
+ return x
40
+
41
+ class UnetHead(nn.Module):
42
+ def __init__(self, act=True) -> None:
43
+
44
+ super(UnetHead, self).__init__()
45
+ self.down_conv1 = double_conv_c3(512, 512, 2, act=act)
46
+ self.upconv0 = double_conv_up_c3(0, 512, 256, act=act)
47
+ self.upconv2 = double_conv_up_c3(256, 512, 256, act=act)
48
+ self.upconv3 = double_conv_up_c3(0, 512, 256, act=act)
49
+ self.upconv4 = double_conv_up_c3(128, 256, 128, act=act)
50
+ self.upconv5 = double_conv_up_c3(64, 128, 64, act=act)
51
+ self.upconv6 = nn.Sequential(
52
+ nn.ConvTranspose2d(64, 1, kernel_size=4, stride = 2, padding=1, bias=False),
53
+ nn.Sigmoid()
54
+ )
55
+
56
+ def forward(self, f160, f80, f40, f20, f3, forward_mode=TEXTDET_MASK):
57
+ # input: 640@3
58
+ d10 = self.down_conv1(f3) # 512@10
59
+ u20 = self.upconv0(d10) # 256@10
60
+ u40 = self.upconv2(torch.cat([f20, u20], dim = 1)) # 256@40
61
+
62
+ if forward_mode == TEXTDET_DET:
63
+ return f80, f40, u40
64
+ else:
65
+ u80 = self.upconv3(torch.cat([f40, u40], dim = 1)) # 256@80
66
+ u160 = self.upconv4(torch.cat([f80, u80], dim = 1)) # 128@160
67
+ u320 = self.upconv5(torch.cat([f160, u160], dim = 1)) # 64@320
68
+ mask = self.upconv6(u320)
69
+ if forward_mode == TEXTDET_MASK:
70
+ return mask
71
+ else:
72
+ return mask, [f80, f40, u40]
73
+
74
+ def init_weight(self, init_func):
75
+ self.apply(init_func)
76
+
77
+ class DBHead(nn.Module):
78
+ def __init__(self, in_channels, k = 50, shrink_with_sigmoid=True, act=True):
79
+ super().__init__()
80
+ self.k = k
81
+ self.shrink_with_sigmoid = shrink_with_sigmoid
82
+ self.upconv3 = double_conv_up_c3(0, 512, 256, act=act)
83
+ self.upconv4 = double_conv_up_c3(128, 256, 128, act=act)
84
+ self.conv = nn.Sequential(
85
+ nn.Conv2d(128, in_channels, 1),
86
+ nn.BatchNorm2d(in_channels),
87
+ nn.ReLU(inplace=True)
88
+ )
89
+ self.binarize = nn.Sequential(
90
+ nn.Conv2d(in_channels, in_channels // 4, 3, padding=1),
91
+ nn.BatchNorm2d(in_channels // 4),
92
+ nn.ReLU(inplace=True),
93
+ nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 2, 2),
94
+ nn.BatchNorm2d(in_channels // 4),
95
+ nn.ReLU(inplace=True),
96
+ nn.ConvTranspose2d(in_channels // 4, 1, 2, 2)
97
+ )
98
+ self.thresh = self._init_thresh(in_channels)
99
+
100
+ def forward(self, f80, f40, u40, shrink_with_sigmoid=True, step_eval=False):
101
+ shrink_with_sigmoid = self.shrink_with_sigmoid
102
+ u80 = self.upconv3(torch.cat([f40, u40], dim = 1)) # 256@80
103
+ x = self.upconv4(torch.cat([f80, u80], dim = 1)) # 128@160
104
+ x = self.conv(x)
105
+ threshold_maps = self.thresh(x)
106
+ x = self.binarize(x)
107
+ shrink_maps = torch.sigmoid(x)
108
+
109
+ if self.training:
110
+ binary_maps = self.step_function(shrink_maps, threshold_maps)
111
+ if shrink_with_sigmoid:
112
+ return torch.cat((shrink_maps, threshold_maps, binary_maps), dim=1)
113
+ else:
114
+ return torch.cat((shrink_maps, threshold_maps, binary_maps, x), dim=1)
115
+ else:
116
+ if step_eval:
117
+ return self.step_function(shrink_maps, threshold_maps)
118
+ else:
119
+ return torch.cat((shrink_maps, threshold_maps), dim=1)
120
+
121
+ def init_weight(self, init_func):
122
+ self.apply(init_func)
123
+
124
+ def _init_thresh(self, inner_channels, serial=False, smooth=False, bias=False):
125
+ in_channels = inner_channels
126
+ if serial:
127
+ in_channels += 1
128
+ self.thresh = nn.Sequential(
129
+ nn.Conv2d(in_channels, inner_channels // 4, 3, padding=1, bias=bias),
130
+ nn.BatchNorm2d(inner_channels // 4),
131
+ nn.ReLU(inplace=True),
132
+ self._init_upsample(inner_channels // 4, inner_channels // 4, smooth=smooth, bias=bias),
133
+ nn.BatchNorm2d(inner_channels // 4),
134
+ nn.ReLU(inplace=True),
135
+ self._init_upsample(inner_channels // 4, 1, smooth=smooth, bias=bias),
136
+ nn.Sigmoid())
137
+ return self.thresh
138
+
139
+ def _init_upsample(self, in_channels, out_channels, smooth=False, bias=False):
140
+ if smooth:
141
+ inter_out_channels = out_channels
142
+ if out_channels == 1:
143
+ inter_out_channels = in_channels
144
+ module_list = [
145
+ nn.Upsample(scale_factor=2, mode='nearest'),
146
+ nn.Conv2d(in_channels, inter_out_channels, 3, 1, 1, bias=bias)]
147
+ if out_channels == 1:
148
+ module_list.append(nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=1, bias=True))
149
+ return nn.Sequential(module_list)
150
+ else:
151
+ return nn.ConvTranspose2d(in_channels, out_channels, 2, 2)
152
+
153
+ def step_function(self, x, y):
154
+ return torch.reciprocal(1 + torch.exp(-self.k * (x - y)))
155
+
156
+ class TextDetector(nn.Module):
157
+ def __init__(self, weights, map_location='cpu', forward_mode=TEXTDET_MASK, act=True):
158
+ super(TextDetector, self).__init__()
159
+
160
+ yolov5s_backbone = load_yolov5_ckpt(weights=weights, map_location=map_location)
161
+ yolov5s_backbone.eval()
162
+ out_indices = [1, 3, 5, 7, 9]
163
+ yolov5s_backbone.out_indices = out_indices
164
+ yolov5s_backbone.model = yolov5s_backbone.model[:max(out_indices)+1]
165
+ self.act = act
166
+ self.seg_net = UnetHead(act=act)
167
+ self.backbone = yolov5s_backbone
168
+ self.dbnet = None
169
+ self.forward_mode = forward_mode
170
+
171
+ def train_mask(self):
172
+ self.forward_mode = TEXTDET_MASK
173
+ self.backbone.eval()
174
+ self.seg_net.train()
175
+
176
+ def initialize_db(self, unet_weights):
177
+ self.dbnet = DBHead(64, act=self.act)
178
+ self.seg_net.load_state_dict(torch.load(unet_weights, map_location='cpu')['weights'])
179
+ self.dbnet.init_weight(init_weights)
180
+ self.dbnet.upconv3 = copy.deepcopy(self.seg_net.upconv3)
181
+ self.dbnet.upconv4 = copy.deepcopy(self.seg_net.upconv4)
182
+ del self.seg_net.upconv3
183
+ del self.seg_net.upconv4
184
+ del self.seg_net.upconv5
185
+ del self.seg_net.upconv6
186
+ # del self.seg_net.conv_mask
187
+
188
+ def train_db(self):
189
+ self.forward_mode = TEXTDET_DET
190
+ self.backbone.eval()
191
+ self.seg_net.eval()
192
+ self.dbnet.train()
193
+
194
+ def forward(self, x):
195
+ forward_mode = self.forward_mode
196
+ with torch.no_grad():
197
+ outs = self.backbone(x)
198
+ if forward_mode == TEXTDET_MASK:
199
+ return self.seg_net(*outs, forward_mode=forward_mode)
200
+ elif forward_mode == TEXTDET_DET:
201
+ with torch.no_grad():
202
+ outs = self.seg_net(*outs, forward_mode=forward_mode)
203
+ return self.dbnet(*outs)
204
+
205
+ def get_base_det_models(model_path, device='cpu', half=False, act='leaky'):
206
+ textdetector_dict = torch.load(model_path, map_location=device)
207
+ blk_det = load_yolov5_ckpt(textdetector_dict['blk_det'], map_location=device)
208
+ text_seg = UnetHead(act=act)
209
+ text_seg.load_state_dict(textdetector_dict['text_seg'])
210
+ text_det = DBHead(64, act=act)
211
+ text_det.load_state_dict(textdetector_dict['text_det'])
212
+ if half:
213
+ return blk_det.eval().half(), text_seg.eval().half(), text_det.eval().half()
214
+ return blk_det.eval().to(device), text_seg.eval().to(device), text_det.eval().to(device)
215
+
216
+ class TextDetBase(nn.Module):
217
+ def __init__(self, model_path, device='cpu', half=False, fuse=False, act='leaky'):
218
+ super(TextDetBase, self).__init__()
219
+ self.blk_det, self.text_seg, self.text_det = get_base_det_models(model_path, device, half, act=act)
220
+ if fuse:
221
+ self.fuse()
222
+
223
+ def fuse(self):
224
+ def _fuse(model):
225
+ for m in model.modules():
226
+ if isinstance(m, (Conv)) and hasattr(m, 'bn'):
227
+ m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
228
+ delattr(m, 'bn') # remove batchnorm
229
+ m.forward = m.forward_fuse # update forward
230
+ return model
231
+ self.text_seg = _fuse(self.text_seg)
232
+ self.text_det = _fuse(self.text_det)
233
+
234
+ def forward(self, features):
235
+ blks, features = self.blk_det(features, detect=True)
236
+ mask, features = self.text_seg(*features, forward_mode=TEXTDET_INFERENCE)
237
+ lines = self.text_det(*features, step_eval=False)
238
+ return blks[0], mask, lines
239
+
240
+ class TextDetBaseDNN:
241
+ def __init__(self, input_size, model_path):
242
+ self.input_size = input_size
243
+ self.model = cv2.dnn.readNetFromONNX(model_path)
244
+ self.uoln = self.model.getUnconnectedOutLayersNames()
245
+
246
+ def __call__(self, im_in):
247
+ blob = cv2.dnn.blobFromImage(im_in, scalefactor=1 / 255.0, size=(self.input_size, self.input_size))
248
+ self.model.setInput(blob)
249
+ blks, mask, lines_map = self.model.forward(self.uoln)
250
+ return blks, mask, lines_map
manga_translator/detection/ctd_utils/textmask.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ import cv2
3
+ import numpy as np
4
+
5
+ from .utils.imgproc_utils import union_area, enlarge_window
6
+ from ...utils import TextBlock, Quadrilateral
7
+
8
+ WHITE = (255, 255, 255)
9
+ BLACK = (0, 0, 0)
10
+ LANG_ENG = 0
11
+ LANG_JPN = 1
12
+
13
+ REFINEMASK_INPAINT = 0
14
+ REFINEMASK_ANNOTATION = 1
15
+
16
+ def get_topk_color(color_list, bins, k=3, color_var=10, bin_tol=0.001):
17
+ idx = np.argsort(bins * -1)
18
+ color_list, bins = color_list[idx], bins[idx]
19
+ top_colors = [color_list[0]]
20
+ bin_tol = np.sum(bins) * bin_tol
21
+ if len(color_list) > 1:
22
+ for color, bin in zip(color_list[1:], bins[1:]):
23
+ if np.abs(np.array(top_colors) - color).min() > color_var:
24
+ top_colors.append(color)
25
+ if len(top_colors) >= k or bin < bin_tol:
26
+ break
27
+ return top_colors
28
+
29
+ def minxor_thresh(threshed, mask, dilate=False):
30
+ neg_threshed = 255 - threshed
31
+ e_size = 1
32
+ if dilate:
33
+ element = cv2.getStructuringElement(cv2.MORPH_RECT, (2 * e_size + 1, 2 * e_size + 1),(e_size, e_size))
34
+ neg_threshed = cv2.dilate(neg_threshed, element, iterations=1)
35
+ threshed = cv2.dilate(threshed, element, iterations=1)
36
+ neg_xor_sum = cv2.bitwise_xor(neg_threshed, mask).sum()
37
+ xor_sum = cv2.bitwise_xor(threshed, mask).sum()
38
+ if neg_xor_sum < xor_sum:
39
+ return neg_threshed, neg_xor_sum
40
+ else:
41
+ return threshed, xor_sum
42
+
43
+ def get_otsuthresh_masklist(img, pred_mask, per_channel=False) -> List[np.ndarray]:
44
+ channels = [img[..., 0], img[..., 1], img[..., 2]]
45
+ mask_list = []
46
+ for c in channels:
47
+ _, threshed = cv2.threshold(c, 1, 255, cv2.THRESH_OTSU+cv2.THRESH_BINARY)
48
+ threshed, xor_sum = minxor_thresh(threshed, pred_mask, dilate=False)
49
+ mask_list.append([threshed, xor_sum])
50
+ mask_list.sort(key=lambda x: x[1])
51
+ if per_channel:
52
+ return mask_list
53
+ else:
54
+ return [mask_list[0]]
55
+
56
+ def get_topk_masklist(im_grey, pred_mask):
57
+ if len(im_grey.shape) == 3 and im_grey.shape[-1] == 3:
58
+ im_grey = cv2.cvtColor(im_grey, cv2.COLOR_BGR2GRAY)
59
+ msk = np.ascontiguousarray(pred_mask)
60
+ candidate_grey_px = im_grey[np.where(cv2.erode(msk, np.ones((3,3), np.uint8), iterations=1) > 127)]
61
+ bin, his = np.histogram(candidate_grey_px, bins=255)
62
+ topk_color = get_topk_color(his, bin, color_var=10, k=3)
63
+ color_range = 30
64
+ mask_list = list()
65
+ for ii, color in enumerate(topk_color):
66
+ c_top = min(color+color_range, 255)
67
+ c_bottom = c_top - 2 * color_range
68
+ threshed = cv2.inRange(im_grey, c_bottom, c_top)
69
+ threshed, xor_sum = minxor_thresh(threshed, msk)
70
+ mask_list.append([threshed, xor_sum])
71
+ return mask_list
72
+
73
+ def merge_mask_list(mask_list, pred_mask, blk: Quadrilateral = None, pred_thresh=30, text_window=None, filter_with_lines=False, refine_mode=REFINEMASK_INPAINT):
74
+ mask_list.sort(key=lambda x: x[1])
75
+ linemask = None
76
+ if blk is not None and filter_with_lines:
77
+ linemask = np.zeros_like(pred_mask)
78
+ lines = blk.pts.astype(np.int64)
79
+ for line in lines:
80
+ line[..., 0] -= text_window[0]
81
+ line[..., 1] -= text_window[1]
82
+ cv2.fillPoly(linemask, [line], 255)
83
+ linemask = cv2.dilate(linemask, np.ones((3, 3), np.uint8), iterations=3)
84
+
85
+ if pred_thresh > 0:
86
+ e_size = 1
87
+ element = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2 * e_size + 1, 2 * e_size + 1),(e_size, e_size))
88
+ pred_mask = cv2.erode(pred_mask, element, iterations=1)
89
+ _, pred_mask = cv2.threshold(pred_mask, 60, 255, cv2.THRESH_BINARY)
90
+ connectivity = 8
91
+ mask_merged = np.zeros_like(pred_mask)
92
+ for ii, (candidate_mask, xor_sum) in enumerate(mask_list):
93
+ num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(candidate_mask, connectivity, cv2.CV_16U)
94
+ for label_index, stat, centroid in zip(range(num_labels), stats, centroids):
95
+ if label_index != 0: # skip background label
96
+ x, y, w, h, area = stat
97
+ if w * h < 3:
98
+ continue
99
+ x1, y1, x2, y2 = x, y, x+w, y+h
100
+ label_local = labels[y1: y2, x1: x2]
101
+ label_coordinates = np.where(label_local==label_index)
102
+ tmp_merged = np.zeros_like(label_local, np.uint8)
103
+ tmp_merged[label_coordinates] = 255
104
+ tmp_merged = cv2.bitwise_or(mask_merged[y1: y2, x1: x2], tmp_merged)
105
+ xor_merged = cv2.bitwise_xor(tmp_merged, pred_mask[y1: y2, x1: x2]).sum()
106
+ xor_origin = cv2.bitwise_xor(mask_merged[y1: y2, x1: x2], pred_mask[y1: y2, x1: x2]).sum()
107
+ if xor_merged < xor_origin:
108
+ mask_merged[y1: y2, x1: x2] = tmp_merged
109
+
110
+ if refine_mode == REFINEMASK_INPAINT:
111
+ mask_merged = cv2.dilate(mask_merged, np.ones((5, 5), np.uint8), iterations=1)
112
+ # fill holes
113
+ num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(255-mask_merged, connectivity, cv2.CV_16U)
114
+ sorted_area = np.sort(stats[:, -1])
115
+ if len(sorted_area) > 1:
116
+ area_thresh = sorted_area[-2]
117
+ else:
118
+ area_thresh = sorted_area[-1]
119
+ for label_index, stat, centroid in zip(range(num_labels), stats, centroids):
120
+ x, y, w, h, area = stat
121
+ if area < area_thresh:
122
+ x1, y1, x2, y2 = x, y, x+w, y+h
123
+ label_local = labels[y1: y2, x1: x2]
124
+ label_coordinates = np.where(label_local==label_index)
125
+ tmp_merged = np.zeros_like(label_local, np.uint8)
126
+ tmp_merged[label_coordinates] = 255
127
+ tmp_merged = cv2.bitwise_or(mask_merged[y1: y2, x1: x2], tmp_merged)
128
+ xor_merged = cv2.bitwise_xor(tmp_merged, pred_mask[y1: y2, x1: x2]).sum()
129
+ xor_origin = cv2.bitwise_xor(mask_merged[y1: y2, x1: x2], pred_mask[y1: y2, x1: x2]).sum()
130
+ if xor_merged < xor_origin:
131
+ mask_merged[y1: y2, x1: x2] = tmp_merged
132
+ return mask_merged
133
+
134
+
135
+ def refine_undetected_mask(img: np.ndarray, mask_pred: np.ndarray, mask_refined: np.ndarray, blk_list: List[TextBlock], refine_mode=REFINEMASK_INPAINT):
136
+ mask_pred[np.where(mask_refined > 30)] = 0
137
+ _, pred_mask_t = cv2.threshold(mask_pred, 30, 255, cv2.THRESH_BINARY)
138
+ num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(pred_mask_t, 4, cv2.CV_16U)
139
+ valid_labels = np.where(stats[:, -1] > 50)[0]
140
+ seg_blk_list = []
141
+ if len(valid_labels) > 0:
142
+ for lab_index in valid_labels[1:]:
143
+ x, y, w, h, area = stats[lab_index]
144
+ bx1, by1 = x, y
145
+ bx2, by2 = x+w, y+h
146
+ bbox = [bx1, by1, bx2, by2]
147
+ bbox_score = -1
148
+ for blk in blk_list:
149
+ bbox_s = union_area(blk.xyxy, bbox)
150
+ if bbox_s > bbox_score:
151
+ bbox_score = bbox_s
152
+ if bbox_score / w / h < 0.5:
153
+ seg_blk_list.append(TextBlock(bbox))
154
+ if len(seg_blk_list) > 0:
155
+ mask_refined = cv2.bitwise_or(mask_refined, refine_mask(img, mask_pred, seg_blk_list, refine_mode=refine_mode))
156
+ return mask_refined
157
+
158
+ def refine_mask(img: np.ndarray, pred_mask: np.ndarray, blk_list: List[Quadrilateral], refine_mode: int = REFINEMASK_INPAINT) -> np.ndarray:
159
+ mask_refined = np.zeros_like(pred_mask)
160
+ for blk in blk_list:
161
+ bx1, by1, bx2, by2 = enlarge_window(blk.xyxy, img.shape[1], img.shape[0])
162
+ im = np.ascontiguousarray(img[by1: by2, bx1: bx2])
163
+ msk = np.ascontiguousarray(pred_mask[by1: by2, bx1: bx2])
164
+
165
+ mask_list = get_topk_masklist(im, msk)
166
+ mask_list += get_otsuthresh_masklist(im, msk, per_channel=False)
167
+ mask_merged = merge_mask_list(mask_list, msk, blk=blk, text_window=[bx1, by1, bx2, by2], refine_mode=refine_mode)
168
+ mask_refined[by1: by2, bx1: bx2] = cv2.bitwise_or(mask_refined[by1: by2, bx1: bx2], mask_merged)
169
+ # cv2.imshow('im', im)
170
+ # cv2.imshow('msk', msk)
171
+ # cv2.imshow('mask_refined', mask_refined[by1: by2, bx1: bx2])
172
+ # cv2.waitKey(0)
173
+
174
+ return mask_refined
manga_translator/detection/ctd_utils/utils/db_utils.py ADDED
@@ -0,0 +1,706 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import pyclipper
4
+ from shapely.geometry import Polygon
5
+ from collections import namedtuple
6
+ import warnings
7
+ import torch
8
+ warnings.filterwarnings('ignore')
9
+
10
+
11
+ def iou_rotate(box_a, box_b, method='union'):
12
+ rect_a = cv2.minAreaRect(box_a)
13
+ rect_b = cv2.minAreaRect(box_b)
14
+ r1 = cv2.rotatedRectangleIntersection(rect_a, rect_b)
15
+ if r1[0] == 0:
16
+ return 0
17
+ else:
18
+ inter_area = cv2.contourArea(r1[1])
19
+ area_a = cv2.contourArea(box_a)
20
+ area_b = cv2.contourArea(box_b)
21
+ union_area = area_a + area_b - inter_area
22
+ if union_area == 0 or inter_area == 0:
23
+ return 0
24
+ if method == 'union':
25
+ iou = inter_area / union_area
26
+ elif method == 'intersection':
27
+ iou = inter_area / min(area_a, area_b)
28
+ else:
29
+ raise NotImplementedError
30
+ return iou
31
+
32
+ class SegDetectorRepresenter():
33
+ def __init__(self, thresh=0.3, box_thresh=0.7, max_candidates=1000, unclip_ratio=1.5):
34
+ self.min_size = 3
35
+ self.thresh = thresh
36
+ self.box_thresh = box_thresh
37
+ self.max_candidates = max_candidates
38
+ self.unclip_ratio = unclip_ratio
39
+
40
+ def __call__(self, batch, pred, is_output_polygon=False, height=None, width=None):
41
+ '''
42
+ batch: (image, polygons, ignore_tags
43
+ batch: a dict produced by dataloaders.
44
+ image: tensor of shape (N, C, H, W).
45
+ polygons: tensor of shape (N, K, 4, 2), the polygons of objective regions.
46
+ ignore_tags: tensor of shape (N, K), indicates whether a region is ignorable or not.
47
+ shape: the original shape of images.
48
+ filename: the original filenames of images.
49
+ pred:
50
+ binary: text region segmentation map, with shape (N, H, W)
51
+ thresh: [if exists] thresh hold prediction with shape (N, H, W)
52
+ thresh_binary: [if exists] binarized with threshold, (N, H, W)
53
+ '''
54
+ pred = pred[:, 0, :, :]
55
+ segmentation = self.binarize(pred)
56
+ boxes_batch = []
57
+ scores_batch = []
58
+ # print(pred.size())
59
+ batch_size = pred.size(0) if isinstance(pred, torch.Tensor) else pred.shape[0]
60
+
61
+ if height is None:
62
+ height = pred.shape[1]
63
+ if width is None:
64
+ width = pred.shape[2]
65
+
66
+ for batch_index in range(batch_size):
67
+ if is_output_polygon:
68
+ boxes, scores = self.polygons_from_bitmap(pred[batch_index], segmentation[batch_index], width, height)
69
+ else:
70
+ boxes, scores = self.boxes_from_bitmap(pred[batch_index], segmentation[batch_index], width, height)
71
+ boxes_batch.append(boxes)
72
+ scores_batch.append(scores)
73
+ return boxes_batch, scores_batch
74
+
75
+ def binarize(self, pred) -> np.ndarray:
76
+ return pred > self.thresh
77
+
78
+ def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
79
+ '''
80
+ _bitmap: single map with shape (H, W),
81
+ whose values are binarized as {0, 1}
82
+ '''
83
+
84
+ assert len(_bitmap.shape) == 2
85
+ bitmap = _bitmap.cpu().numpy() # The first channel
86
+ pred = pred.cpu().detach().numpy()
87
+ height, width = bitmap.shape
88
+ boxes = []
89
+ scores = []
90
+
91
+ contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
92
+
93
+ for contour in contours[:self.max_candidates]:
94
+ epsilon = 0.005 * cv2.arcLength(contour, True)
95
+ approx = cv2.approxPolyDP(contour, epsilon, True)
96
+ points = approx.reshape((-1, 2))
97
+ if points.shape[0] < 4:
98
+ continue
99
+ # _, sside = self.get_mini_boxes(contour)
100
+ # if sside < self.min_size:
101
+ # continue
102
+ score = self.box_score_fast(pred, contour.squeeze(1))
103
+ if self.box_thresh > score:
104
+ continue
105
+
106
+ if points.shape[0] > 2:
107
+ box = self.unclip(points, unclip_ratio=self.unclip_ratio)
108
+ if len(box) > 1:
109
+ continue
110
+ else:
111
+ continue
112
+ box = box.reshape(-1, 2)
113
+ _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
114
+ if sside < self.min_size + 2:
115
+ continue
116
+
117
+ if not isinstance(dest_width, int):
118
+ dest_width = dest_width.item()
119
+ dest_height = dest_height.item()
120
+
121
+ box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
122
+ box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0, dest_height)
123
+ boxes.append(box)
124
+ scores.append(score)
125
+ return boxes, scores
126
+
127
+ def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
128
+ '''
129
+ _bitmap: single map with shape (H, W),
130
+ whose values are binarized as {0, 1}
131
+ '''
132
+
133
+ assert len(_bitmap.shape) == 2
134
+ if isinstance(pred, torch.Tensor):
135
+ bitmap = _bitmap.cpu().numpy() # The first channel
136
+ pred = pred.cpu().detach().numpy()
137
+ else:
138
+ bitmap = _bitmap
139
+ # cv2.imwrite('tmp.png', (bitmap*255).astype(np.uint8))
140
+ height, width = bitmap.shape
141
+ contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
142
+ num_contours = min(len(contours), self.max_candidates)
143
+ boxes = np.zeros((num_contours, 4, 2), dtype=np.int16)
144
+ scores = np.zeros((num_contours,), dtype=np.float32)
145
+
146
+ for index in range(num_contours):
147
+ contour = contours[index].squeeze(1)
148
+ points, sside = self.get_mini_boxes(contour)
149
+ # if sside < self.min_size:
150
+ # continue
151
+ if sside < 2:
152
+ continue
153
+ points = np.array(points)
154
+ score = self.box_score_fast(pred, contour)
155
+ # if self.box_thresh > score:
156
+ # continue
157
+
158
+ box = self.unclip(points, unclip_ratio=self.unclip_ratio).reshape(-1, 1, 2)
159
+ box, sside = self.get_mini_boxes(box)
160
+ # if sside < 5:
161
+ # continue
162
+ box = np.array(box)
163
+ if not isinstance(dest_width, int):
164
+ dest_width = dest_width.item()
165
+ dest_height = dest_height.item()
166
+
167
+ box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
168
+ box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0, dest_height)
169
+ boxes[index, :, :] = box.astype(np.int16)
170
+ scores[index] = score
171
+ return boxes, scores
172
+
173
+ def unclip(self, box, unclip_ratio=1.5):
174
+ poly = Polygon(box)
175
+ distance = poly.area * unclip_ratio / poly.length
176
+ offset = pyclipper.PyclipperOffset()
177
+ offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
178
+ expanded = np.array(offset.Execute(distance))
179
+ return expanded
180
+
181
+ def get_mini_boxes(self, contour):
182
+ bounding_box = cv2.minAreaRect(contour)
183
+ points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
184
+
185
+ index_1, index_2, index_3, index_4 = 0, 1, 2, 3
186
+ if points[1][1] > points[0][1]:
187
+ index_1 = 0
188
+ index_4 = 1
189
+ else:
190
+ index_1 = 1
191
+ index_4 = 0
192
+ if points[3][1] > points[2][1]:
193
+ index_2 = 2
194
+ index_3 = 3
195
+ else:
196
+ index_2 = 3
197
+ index_3 = 2
198
+
199
+ box = [points[index_1], points[index_2], points[index_3], points[index_4]]
200
+ return box, min(bounding_box[1])
201
+
202
+ def box_score_fast(self, bitmap, _box):
203
+ h, w = bitmap.shape[:2]
204
+ box = _box.copy()
205
+ xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int32), 0, w - 1)
206
+ xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int32), 0, w - 1)
207
+ ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int32), 0, h - 1)
208
+ ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int32), 0, h - 1)
209
+
210
+ mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
211
+ box[:, 0] = box[:, 0] - xmin
212
+ box[:, 1] = box[:, 1] - ymin
213
+ cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
214
+ if bitmap.dtype == np.float16:
215
+ bitmap = bitmap.astype(np.float32)
216
+ return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
217
+
218
+ class AverageMeter(object):
219
+ """Computes and stores the average and current value"""
220
+
221
+ def __init__(self):
222
+ self.reset()
223
+
224
+ def reset(self):
225
+ self.val = 0
226
+ self.avg = 0
227
+ self.sum = 0
228
+ self.count = 0
229
+
230
+ def update(self, val, n=1):
231
+ self.val = val
232
+ self.sum += val * n
233
+ self.count += n
234
+ self.avg = self.sum / self.count
235
+ return self
236
+
237
+
238
+ class DetectionIoUEvaluator(object):
239
+ def __init__(self, is_output_polygon=False, iou_constraint=0.5, area_precision_constraint=0.5):
240
+ self.is_output_polygon = is_output_polygon
241
+ self.iou_constraint = iou_constraint
242
+ self.area_precision_constraint = area_precision_constraint
243
+
244
+ def evaluate_image(self, gt, pred):
245
+
246
+ def get_union(pD, pG):
247
+ return Polygon(pD).union(Polygon(pG)).area
248
+
249
+ def get_intersection_over_union(pD, pG):
250
+ return get_intersection(pD, pG) / get_union(pD, pG)
251
+
252
+ def get_intersection(pD, pG):
253
+ return Polygon(pD).intersection(Polygon(pG)).area
254
+
255
+ def compute_ap(confList, matchList, numGtCare):
256
+ correct = 0
257
+ AP = 0
258
+ if len(confList) > 0:
259
+ confList = np.array(confList)
260
+ matchList = np.array(matchList)
261
+ sorted_ind = np.argsort(-confList)
262
+ confList = confList[sorted_ind]
263
+ matchList = matchList[sorted_ind]
264
+ for n in range(len(confList)):
265
+ match = matchList[n]
266
+ if match:
267
+ correct += 1
268
+ AP += float(correct) / (n + 1)
269
+
270
+ if numGtCare > 0:
271
+ AP /= numGtCare
272
+
273
+ return AP
274
+
275
+ perSampleMetrics = {}
276
+
277
+ matchedSum = 0
278
+
279
+ Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax')
280
+
281
+ numGlobalCareGt = 0
282
+ numGlobalCareDet = 0
283
+
284
+ arrGlobalConfidences = []
285
+ arrGlobalMatches = []
286
+
287
+ recall = 0
288
+ precision = 0
289
+ hmean = 0
290
+
291
+ detMatched = 0
292
+
293
+ iouMat = np.empty([1, 1])
294
+
295
+ gtPols = []
296
+ detPols = []
297
+
298
+ gtPolPoints = []
299
+ detPolPoints = []
300
+
301
+ # Array of Ground Truth Polygons' keys marked as don't Care
302
+ gtDontCarePolsNum = []
303
+ # Array of Detected Polygons' matched with a don't Care GT
304
+ detDontCarePolsNum = []
305
+
306
+ pairs = []
307
+ detMatchedNums = []
308
+
309
+ arrSampleConfidences = []
310
+ arrSampleMatch = []
311
+
312
+ evaluationLog = ""
313
+
314
+ for n in range(len(gt)):
315
+ points = gt[n]['points']
316
+ # transcription = gt[n]['text']
317
+ dontCare = gt[n]['ignore']
318
+
319
+ if not Polygon(points).is_valid or not Polygon(points).is_simple:
320
+ continue
321
+
322
+ gtPol = points
323
+ gtPols.append(gtPol)
324
+ gtPolPoints.append(points)
325
+ if dontCare:
326
+ gtDontCarePolsNum.append(len(gtPols) - 1)
327
+
328
+ evaluationLog += "GT polygons: " + str(len(gtPols)) + (" (" + str(len(
329
+ gtDontCarePolsNum)) + " don't care)\n" if len(gtDontCarePolsNum) > 0 else "\n")
330
+
331
+ for n in range(len(pred)):
332
+ points = pred[n]['points']
333
+ if not Polygon(points).is_valid or not Polygon(points).is_simple:
334
+ continue
335
+
336
+ detPol = points
337
+ detPols.append(detPol)
338
+ detPolPoints.append(points)
339
+ if len(gtDontCarePolsNum) > 0:
340
+ for dontCarePol in gtDontCarePolsNum:
341
+ dontCarePol = gtPols[dontCarePol]
342
+ intersected_area = get_intersection(dontCarePol, detPol)
343
+ pdDimensions = Polygon(detPol).area
344
+ precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions
345
+ if (precision > self.area_precision_constraint):
346
+ detDontCarePolsNum.append(len(detPols) - 1)
347
+ break
348
+
349
+ evaluationLog += "DET polygons: " + str(len(detPols)) + (" (" + str(len(
350
+ detDontCarePolsNum)) + " don't care)\n" if len(detDontCarePolsNum) > 0 else "\n")
351
+
352
+ if len(gtPols) > 0 and len(detPols) > 0:
353
+ # Calculate IoU and precision matrixs
354
+ outputShape = [len(gtPols), len(detPols)]
355
+ iouMat = np.empty(outputShape)
356
+ gtRectMat = np.zeros(len(gtPols), np.int8)
357
+ detRectMat = np.zeros(len(detPols), np.int8)
358
+ if self.is_output_polygon:
359
+ for gtNum in range(len(gtPols)):
360
+ for detNum in range(len(detPols)):
361
+ pG = gtPols[gtNum]
362
+ pD = detPols[detNum]
363
+ iouMat[gtNum, detNum] = get_intersection_over_union(pD, pG)
364
+ else:
365
+ # gtPols = np.float32(gtPols)
366
+ # detPols = np.float32(detPols)
367
+ for gtNum in range(len(gtPols)):
368
+ for detNum in range(len(detPols)):
369
+ pG = np.float32(gtPols[gtNum])
370
+ pD = np.float32(detPols[detNum])
371
+ iouMat[gtNum, detNum] = iou_rotate(pD, pG)
372
+ for gtNum in range(len(gtPols)):
373
+ for detNum in range(len(detPols)):
374
+ if gtRectMat[gtNum] == 0 and detRectMat[
375
+ detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum:
376
+ if iouMat[gtNum, detNum] > self.iou_constraint:
377
+ gtRectMat[gtNum] = 1
378
+ detRectMat[detNum] = 1
379
+ detMatched += 1
380
+ pairs.append({'gt': gtNum, 'det': detNum})
381
+ detMatchedNums.append(detNum)
382
+ evaluationLog += "Match GT #" + \
383
+ str(gtNum) + " with Det #" + str(detNum) + "\n"
384
+
385
+ numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
386
+ numDetCare = (len(detPols) - len(detDontCarePolsNum))
387
+ if numGtCare == 0:
388
+ recall = float(1)
389
+ precision = float(0) if numDetCare > 0 else float(1)
390
+ else:
391
+ recall = float(detMatched) / numGtCare
392
+ precision = 0 if numDetCare == 0 else float(
393
+ detMatched) / numDetCare
394
+
395
+ hmean = 0 if (precision + recall) == 0 else 2.0 * \
396
+ precision * recall / (precision + recall)
397
+
398
+ matchedSum += detMatched
399
+ numGlobalCareGt += numGtCare
400
+ numGlobalCareDet += numDetCare
401
+
402
+ perSampleMetrics = {
403
+ 'precision': precision,
404
+ 'recall': recall,
405
+ 'hmean': hmean,
406
+ 'pairs': pairs,
407
+ 'iouMat': [] if len(detPols) > 100 else iouMat.tolist(),
408
+ 'gtPolPoints': gtPolPoints,
409
+ 'detPolPoints': detPolPoints,
410
+ 'gtCare': numGtCare,
411
+ 'detCare': numDetCare,
412
+ 'gtDontCare': gtDontCarePolsNum,
413
+ 'detDontCare': detDontCarePolsNum,
414
+ 'detMatched': detMatched,
415
+ 'evaluationLog': evaluationLog
416
+ }
417
+
418
+ return perSampleMetrics
419
+
420
+ def combine_results(self, results):
421
+ numGlobalCareGt = 0
422
+ numGlobalCareDet = 0
423
+ matchedSum = 0
424
+ for result in results:
425
+ numGlobalCareGt += result['gtCare']
426
+ numGlobalCareDet += result['detCare']
427
+ matchedSum += result['detMatched']
428
+
429
+ methodRecall = 0 if numGlobalCareGt == 0 else float(
430
+ matchedSum) / numGlobalCareGt
431
+ methodPrecision = 0 if numGlobalCareDet == 0 else float(
432
+ matchedSum) / numGlobalCareDet
433
+ methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \
434
+ methodRecall * methodPrecision / (
435
+ methodRecall + methodPrecision)
436
+
437
+ methodMetrics = {'precision': methodPrecision,
438
+ 'recall': methodRecall, 'hmean': methodHmean}
439
+
440
+ return methodMetrics
441
+
442
+ class QuadMetric():
443
+ def __init__(self, is_output_polygon=False):
444
+ self.is_output_polygon = is_output_polygon
445
+ self.evaluator = DetectionIoUEvaluator(is_output_polygon=is_output_polygon)
446
+
447
+ def measure(self, batch, output, box_thresh=0.6):
448
+ '''
449
+ batch: (image, polygons, ignore_tags
450
+ batch: a dict produced by dataloaders.
451
+ image: tensor of shape (N, C, H, W).
452
+ polygons: tensor of shape (N, K, 4, 2), the polygons of objective regions.
453
+ ignore_tags: tensor of shape (N, K), indicates whether a region is ignorable or not.
454
+ shape: the original shape of images.
455
+ filename: the original filenames of images.
456
+ output: (polygons, ...)
457
+ '''
458
+ results = []
459
+ gt_polyons_batch = batch['text_polys']
460
+ ignore_tags_batch = batch['ignore_tags']
461
+ pred_polygons_batch = np.array(output[0])
462
+ pred_scores_batch = np.array(output[1])
463
+ for polygons, pred_polygons, pred_scores, ignore_tags in zip(gt_polyons_batch, pred_polygons_batch, pred_scores_batch, ignore_tags_batch):
464
+ gt = [dict(points=np.int64(polygons[i]), ignore=ignore_tags[i]) for i in range(len(polygons))]
465
+ if self.is_output_polygon:
466
+ pred = [dict(points=pred_polygons[i]) for i in range(len(pred_polygons))]
467
+ else:
468
+ pred = []
469
+ # print(pred_polygons.shape)
470
+ for i in range(pred_polygons.shape[0]):
471
+ if pred_scores[i] >= box_thresh:
472
+ # print(pred_polygons[i,:,:].tolist())
473
+ pred.append(dict(points=pred_polygons[i, :, :].astype(np.int32)))
474
+ # pred = [dict(points=pred_polygons[i,:,:].tolist()) if pred_scores[i] >= box_thresh for i in range(pred_polygons.shape[0])]
475
+ results.append(self.evaluator.evaluate_image(gt, pred))
476
+ return results
477
+
478
+ def validate_measure(self, batch, output, box_thresh=0.6):
479
+ return self.measure(batch, output, box_thresh)
480
+
481
+ def evaluate_measure(self, batch, output):
482
+ return self.measure(batch, output), np.linspace(0, batch['image'].shape[0]).tolist()
483
+
484
+ def gather_measure(self, raw_metrics):
485
+ raw_metrics = [image_metrics
486
+ for batch_metrics in raw_metrics
487
+ for image_metrics in batch_metrics]
488
+
489
+ result = self.evaluator.combine_results(raw_metrics)
490
+
491
+ precision = AverageMeter()
492
+ recall = AverageMeter()
493
+ fmeasure = AverageMeter()
494
+
495
+ precision.update(result['precision'], n=len(raw_metrics))
496
+ recall.update(result['recall'], n=len(raw_metrics))
497
+ fmeasure_score = 2 * precision.val * recall.val / (precision.val + recall.val + 1e-8)
498
+ fmeasure.update(fmeasure_score)
499
+
500
+ return {
501
+ 'precision': precision,
502
+ 'recall': recall,
503
+ 'fmeasure': fmeasure
504
+ }
505
+
506
+ def shrink_polygon_py(polygon, shrink_ratio):
507
+ """
508
+ 对框进行缩放,返回去的比例为1/shrink_ratio 即可
509
+ """
510
+ cx = polygon[:, 0].mean()
511
+ cy = polygon[:, 1].mean()
512
+ polygon[:, 0] = cx + (polygon[:, 0] - cx) * shrink_ratio
513
+ polygon[:, 1] = cy + (polygon[:, 1] - cy) * shrink_ratio
514
+ return polygon
515
+
516
+
517
+ def shrink_polygon_pyclipper(polygon, shrink_ratio):
518
+ from shapely.geometry import Polygon
519
+ import pyclipper
520
+ polygon_shape = Polygon(polygon)
521
+ distance = polygon_shape.area * (1 - np.power(shrink_ratio, 2)) / polygon_shape.length
522
+ subject = [tuple(l) for l in polygon]
523
+ padding = pyclipper.PyclipperOffset()
524
+ padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
525
+ shrunk = padding.Execute(-distance)
526
+ if shrunk == []:
527
+ shrunk = np.array(shrunk)
528
+ else:
529
+ shrunk = np.array(shrunk[0]).reshape(-1, 2)
530
+ return shrunk
531
+
532
+ class MakeShrinkMap():
533
+ r'''
534
+ Making binary mask from detection data with ICDAR format.
535
+ Typically following the process of class `MakeICDARData`.
536
+ '''
537
+
538
+ def __init__(self, min_text_size=4, shrink_ratio=0.4, shrink_type='pyclipper'):
539
+ shrink_func_dict = {'py': shrink_polygon_py, 'pyclipper': shrink_polygon_pyclipper}
540
+ self.shrink_func = shrink_func_dict[shrink_type]
541
+ self.min_text_size = min_text_size
542
+ self.shrink_ratio = shrink_ratio
543
+
544
+ def __call__(self, data: dict) -> dict:
545
+ """
546
+ 从scales中随机选择一个尺度,对图片和文本框进行缩放
547
+ :param data: {'imgs':,'text_polys':,'texts':,'ignore_tags':}
548
+ :return:
549
+ """
550
+ image = data['imgs']
551
+ text_polys = data['text_polys']
552
+ ignore_tags = data['ignore_tags']
553
+
554
+ h, w = image.shape[:2]
555
+ text_polys, ignore_tags = self.validate_polygons(text_polys, ignore_tags, h, w)
556
+ gt = np.zeros((h, w), dtype=np.float32)
557
+ mask = np.ones((h, w), dtype=np.float32)
558
+ for i in range(len(text_polys)):
559
+ polygon = text_polys[i]
560
+ height = max(polygon[:, 1]) - min(polygon[:, 1])
561
+ width = max(polygon[:, 0]) - min(polygon[:, 0])
562
+ if ignore_tags[i] or min(height, width) < self.min_text_size:
563
+ cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0)
564
+ ignore_tags[i] = True
565
+ else:
566
+ shrunk = self.shrink_func(polygon, self.shrink_ratio)
567
+ if shrunk.size == 0:
568
+ cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0)
569
+ ignore_tags[i] = True
570
+ continue
571
+ cv2.fillPoly(gt, [shrunk.astype(np.int32)], 1)
572
+
573
+ data['shrink_map'] = gt
574
+ data['shrink_mask'] = mask
575
+ return data
576
+
577
+ def validate_polygons(self, polygons, ignore_tags, h, w):
578
+ '''
579
+ polygons (numpy.array, required): of shape (num_instances, num_points, 2)
580
+ '''
581
+ if len(polygons) == 0:
582
+ return polygons, ignore_tags
583
+ assert len(polygons) == len(ignore_tags)
584
+ for polygon in polygons:
585
+ polygon[:, 0] = np.clip(polygon[:, 0], 0, w - 1)
586
+ polygon[:, 1] = np.clip(polygon[:, 1], 0, h - 1)
587
+
588
+ for i in range(len(polygons)):
589
+ area = self.polygon_area(polygons[i])
590
+ if abs(area) < 1:
591
+ ignore_tags[i] = True
592
+ if area > 0:
593
+ polygons[i] = polygons[i][::-1, :]
594
+ return polygons, ignore_tags
595
+
596
+ def polygon_area(self, polygon):
597
+ return cv2.contourArea(polygon)
598
+
599
+
600
+ class MakeBorderMap():
601
+ def __init__(self, shrink_ratio=0.4, thresh_min=0.3, thresh_max=0.7):
602
+ self.shrink_ratio = shrink_ratio
603
+ self.thresh_min = thresh_min
604
+ self.thresh_max = thresh_max
605
+
606
+ def __call__(self, data: dict) -> dict:
607
+ """
608
+ 从scales中随机选择一个尺度,对图片和文本框进行缩放
609
+ :param data: {'imgs':,'text_polys':,'texts':,'ignore_tags':}
610
+ :return:
611
+ """
612
+ im = data['imgs']
613
+ text_polys = data['text_polys']
614
+ ignore_tags = data['ignore_tags']
615
+
616
+ canvas = np.zeros(im.shape[:2], dtype=np.float32)
617
+ mask = np.zeros(im.shape[:2], dtype=np.float32)
618
+
619
+ for i in range(len(text_polys)):
620
+ if ignore_tags[i]:
621
+ continue
622
+ self.draw_border_map(text_polys[i], canvas, mask=mask)
623
+ canvas = canvas * (self.thresh_max - self.thresh_min) + self.thresh_min
624
+
625
+ data['threshold_map'] = canvas
626
+ data['threshold_mask'] = mask
627
+ return data
628
+
629
+ def draw_border_map(self, polygon, canvas, mask):
630
+ polygon = np.array(polygon)
631
+ assert polygon.ndim == 2
632
+ assert polygon.shape[1] == 2
633
+
634
+ polygon_shape = Polygon(polygon)
635
+ if polygon_shape.area <= 0:
636
+ return
637
+ distance = polygon_shape.area * (1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length
638
+ subject = [tuple(l) for l in polygon]
639
+ padding = pyclipper.PyclipperOffset()
640
+ padding.AddPath(subject, pyclipper.JT_ROUND,
641
+ pyclipper.ET_CLOSEDPOLYGON)
642
+
643
+ padded_polygon = np.array(padding.Execute(distance)[0])
644
+ cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)
645
+
646
+ xmin = padded_polygon[:, 0].min()
647
+ xmax = padded_polygon[:, 0].max()
648
+ ymin = padded_polygon[:, 1].min()
649
+ ymax = padded_polygon[:, 1].max()
650
+ width = xmax - xmin + 1
651
+ height = ymax - ymin + 1
652
+
653
+ polygon[:, 0] = polygon[:, 0] - xmin
654
+ polygon[:, 1] = polygon[:, 1] - ymin
655
+
656
+ xs = np.broadcast_to(
657
+ np.linspace(0, width - 1, num=width).reshape(1, width), (height, width))
658
+ ys = np.broadcast_to(
659
+ np.linspace(0, height - 1, num=height).reshape(height, 1), (height, width))
660
+
661
+ distance_map = np.zeros(
662
+ (polygon.shape[0], height, width), dtype=np.float32)
663
+ for i in range(polygon.shape[0]):
664
+ j = (i + 1) % polygon.shape[0]
665
+ absolute_distance = self.distance(xs, ys, polygon[i], polygon[j])
666
+ distance_map[i] = np.clip(absolute_distance / distance, 0, 1)
667
+ distance_map = distance_map.min(axis=0)
668
+
669
+ xmin_valid = min(max(0, xmin), canvas.shape[1] - 1)
670
+ xmax_valid = min(max(0, xmax), canvas.shape[1] - 1)
671
+ ymin_valid = min(max(0, ymin), canvas.shape[0] - 1)
672
+ ymax_valid = min(max(0, ymax), canvas.shape[0] - 1)
673
+ canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax(
674
+ 1 - distance_map[
675
+ ymin_valid - ymin:ymax_valid - ymax + height,
676
+ xmin_valid - xmin:xmax_valid - xmax + width],
677
+ canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1])
678
+
679
+ def distance(self, xs, ys, point_1, point_2):
680
+ '''
681
+ compute the distance from point to a line
682
+ ys: coordinates in the first axis
683
+ xs: coordinates in the second axis
684
+ point_1, point_2: (x, y), the end of the line
685
+ '''
686
+ height, width = xs.shape[:2]
687
+ square_distance_1 = np.square(xs - point_1[0]) + np.square(ys - point_1[1])
688
+ square_distance_2 = np.square(xs - point_2[0]) + np.square(ys - point_2[1])
689
+ square_distance = np.square(point_1[0] - point_2[0]) + np.square(point_1[1] - point_2[1])
690
+
691
+ cosin = (square_distance - square_distance_1 - square_distance_2) / (2 * np.sqrt(square_distance_1 * square_distance_2))
692
+ square_sin = 1 - np.square(cosin)
693
+ square_sin = np.nan_to_num(square_sin)
694
+
695
+ result = np.sqrt(square_distance_1 * square_distance_2 * square_sin / square_distance)
696
+ result[cosin < 0] = np.sqrt(np.fmin(square_distance_1, square_distance_2))[cosin < 0]
697
+ return result
698
+
699
+ def extend_line(self, point_1, point_2, result):
700
+ ex_point_1 = (int(round(point_1[0] + (point_1[0] - point_2[0]) * (1 + self.shrink_ratio))),
701
+ int(round(point_1[1] + (point_1[1] - point_2[1]) * (1 + self.shrink_ratio))))
702
+ cv2.line(result, tuple(ex_point_1), tuple(point_1), 4096.0, 1, lineType=cv2.LINE_AA, shift=0)
703
+ ex_point_2 = (int(round(point_2[0] + (point_2[0] - point_1[0]) * (1 + self.shrink_ratio))),
704
+ int(round(point_2[1] + (point_2[1] - point_1[1]) * (1 + self.shrink_ratio))))
705
+ cv2.line(result, tuple(ex_point_2), tuple(point_2), 4096.0, 1, lineType=cv2.LINE_AA, shift=0)
706
+ return ex_point_1, ex_point_2
manga_translator/detection/ctd_utils/utils/imgproc_utils.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ import random
4
+ from typing import List
5
+
6
+ def hex2bgr(hex):
7
+ gmask = 254 << 8
8
+ rmask = 254
9
+ b = hex >> 16
10
+ g = (hex & gmask) >> 8
11
+ r = hex & rmask
12
+ return np.stack([b, g, r]).transpose()
13
+
14
+ def union_area(bboxa, bboxb):
15
+ x1 = max(bboxa[0], bboxb[0])
16
+ y1 = max(bboxa[1], bboxb[1])
17
+ x2 = min(bboxa[2], bboxb[2])
18
+ y2 = min(bboxa[3], bboxb[3])
19
+ if y2 < y1 or x2 < x1:
20
+ return -1
21
+ return (y2 - y1) * (x2 - x1)
22
+
23
+ def get_yololabel_strings(clslist, labellist):
24
+ content = ''
25
+ for cls, xywh in zip(clslist, labellist):
26
+ content += str(int(cls)) + ' ' + ' '.join([str(e) for e in xywh]) + '\n'
27
+ if len(content) != 0:
28
+ content = content[:-1]
29
+ return content
30
+
31
+ # 4 points bbox to 8 points polygon
32
+ def xywh2xyxypoly(xywh, to_int=True):
33
+ xyxypoly = np.tile(xywh[:, [0, 1]], 4)
34
+ xyxypoly[:, [2, 4]] += xywh[:, [2]]
35
+ xyxypoly[:, [5, 7]] += xywh[:, [3]]
36
+ if to_int:
37
+ xyxypoly = xyxypoly.astype(np.int64)
38
+ return xyxypoly
39
+
40
+ def xyxy2yolo(xyxy, w: int, h: int):
41
+ if xyxy == [] or xyxy == np.array([]) or len(xyxy) == 0:
42
+ return None
43
+ if isinstance(xyxy, list):
44
+ xyxy = np.array(xyxy)
45
+ if len(xyxy.shape) == 1:
46
+ xyxy = np.array([xyxy])
47
+ yolo = np.copy(xyxy).astype(np.float64)
48
+ yolo[:, [0, 2]] = yolo[:, [0, 2]] / w
49
+ yolo[:, [1, 3]] = yolo[:, [1, 3]] / h
50
+ yolo[:, [2, 3]] -= yolo[:, [0, 1]]
51
+ yolo[:, [0, 1]] += yolo[:, [2, 3]] / 2
52
+ return yolo
53
+
54
+ def yolo_xywh2xyxy(xywh: np.array, w: int, h: int, to_int=True):
55
+ if xywh is None:
56
+ return None
57
+ if len(xywh) == 0:
58
+ return None
59
+ if len(xywh.shape) == 1:
60
+ xywh = np.array([xywh])
61
+ xywh[:, [0, 2]] *= w
62
+ xywh[:, [1, 3]] *= h
63
+ xywh[:, [0, 1]] -= xywh[:, [2, 3]] / 2
64
+ xywh[:, [2, 3]] += xywh[:, [0, 1]]
65
+ if to_int:
66
+ xywh = xywh.astype(np.int64)
67
+ return xywh
68
+
69
+ def letterbox(im, new_shape=(640, 640), color=(0, 0, 0), auto=False, scaleFill=False, scaleup=True, stride=128):
70
+ # Resize and pad image while meeting stride-multiple constraints
71
+ shape = im.shape[:2] # current shape [height, width]
72
+ if not isinstance(new_shape, tuple):
73
+ new_shape = (new_shape, new_shape)
74
+
75
+ # Scale ratio (new / old)
76
+ r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
77
+ if not scaleup: # only scale down, do not scale up (for better val mAP)
78
+ r = min(r, 1.0)
79
+
80
+ # Compute padding
81
+ ratio = r, r # width, height ratios
82
+ new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
83
+ dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
84
+ if auto: # minimum rectangle
85
+ dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding
86
+ elif scaleFill: # stretch
87
+ dw, dh = 0.0, 0.0
88
+ new_unpad = (new_shape[1], new_shape[0])
89
+ ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
90
+
91
+ # dw /= 2 # divide padding into 2 sides
92
+ # dh /= 2
93
+ dh, dw = int(dh), int(dw)
94
+
95
+ if shape[::-1] != new_unpad: # resize
96
+ im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
97
+ top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
98
+ left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
99
+ im = cv2.copyMakeBorder(im, 0, dh, 0, dw, cv2.BORDER_CONSTANT, value=color) # add border
100
+ return im, ratio, (dw, dh)
101
+
102
+ def resize_keepasp(im, new_shape=640, scaleup=True, interpolation=cv2.INTER_LINEAR, stride=None):
103
+ shape = im.shape[:2] # current shape [height, width]
104
+
105
+ if new_shape is not None:
106
+ if not isinstance(new_shape, tuple):
107
+ new_shape = (new_shape, new_shape)
108
+ else:
109
+ new_shape = shape
110
+
111
+ # Scale ratio (new / old)
112
+ r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
113
+ if not scaleup: # only scale down, do not scale up (for better val mAP)
114
+ r = min(r, 1.0)
115
+
116
+ new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
117
+
118
+ if stride is not None:
119
+ h, w = new_unpad
120
+ if new_shape[0] % stride != 0:
121
+ new_h = (stride - (new_shape[0] % stride)) + h
122
+ else:
123
+ new_h = h
124
+ if w % stride != 0:
125
+ new_w = (stride - (w % stride)) + w
126
+ else:
127
+ new_w = w
128
+ new_unpad = (new_h, new_w)
129
+
130
+ if shape[::-1] != new_unpad: # resize
131
+ im = cv2.resize(im, new_unpad, interpolation=interpolation)
132
+ return im
133
+
134
+ def enlarge_window(rect, im_w, im_h, ratio=2.5, aspect_ratio=1.0) -> List:
135
+ assert ratio > 1.0
136
+
137
+ x1, y1, x2, y2 = rect
138
+ w = x2 - x1
139
+ h = y2 - y1
140
+
141
+ # https://numpy.org/doc/stable/reference/generated/numpy.roots.html
142
+ coeff = [aspect_ratio, w+h*aspect_ratio, (1-ratio)*w*h]
143
+ roots = np.roots(coeff)
144
+ roots.sort()
145
+ delta = int(round(roots[-1] / 2 ))
146
+ delta_w = int(delta * aspect_ratio)
147
+ delta_w = min(x1, im_w - x2, delta_w)
148
+ delta = min(y1, im_h - y2, delta)
149
+ rect = np.array([x1-delta_w, y1-delta, x2+delta_w, y2+delta], dtype=np.int64)
150
+ return rect.tolist()
151
+
152
+ def draw_connected_labels(num_labels, labels, stats, centroids, names="draw_connected_labels", skip_background=True):
153
+ labdraw = np.zeros((labels.shape[0], labels.shape[1], 3), dtype=np.uint8)
154
+ max_ind = 0
155
+ if isinstance(num_labels, int):
156
+ num_labels = range(num_labels)
157
+
158
+ # for ind, lab in enumerate((range(num_labels))):
159
+ for lab in num_labels:
160
+ if skip_background and lab == 0:
161
+ continue
162
+ randcolor = (random.randint(0,255), random.randint(0,255), random.randint(0,255))
163
+ labdraw[np.where(labels==lab)] = randcolor
164
+ maxr, minr = 0.5, 0.001
165
+ maxw, maxh = stats[max_ind][2] * maxr, stats[max_ind][3] * maxr
166
+ minarea = labdraw.shape[0] * labdraw.shape[1] * minr
167
+
168
+ stat = stats[lab]
169
+ bboxarea = stat[2] * stat[3]
170
+ if stat[2] < maxw and stat[3] < maxh and bboxarea > minarea:
171
+ pix = np.zeros((labels.shape[0], labels.shape[1]), dtype=np.uint8)
172
+ pix[np.where(labels==lab)] = 255
173
+
174
+ rect = cv2.minAreaRect(cv2.findNonZero(pix))
175
+ box = np.int0(cv2.boxPoints(rect))
176
+ labdraw = cv2.drawContours(labdraw, [box], 0, randcolor, 2)
177
+ labdraw = cv2.circle(labdraw, (int(centroids[lab][0]),int(centroids[lab][1])), radius=5, color=(random.randint(0,255), random.randint(0,255), random.randint(0,255)), thickness=-1)
178
+
179
+ cv2.imshow(names, labdraw)
180
+ return labdraw
manga_translator/detection/ctd_utils/utils/io_utils.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import os.path as osp
3
+ import glob
4
+ from pathlib import Path
5
+ import cv2
6
+ import numpy as np
7
+ import json
8
+
9
+ IMG_EXT = ['.bmp', '.jpg', '.png', '.jpeg']
10
+
11
+ NP_BOOL_TYPES = (np.bool_, np.bool8)
12
+ NP_FLOAT_TYPES = (np.float_, np.float16, np.float32, np.float64)
13
+ NP_INT_TYPES = (np.int_, np.int8, np.int16, np.int32, np.int64, np.uint, np.uint8, np.uint16, np.uint32, np.uint64)
14
+
15
+ # https://stackoverflow.com/questions/26646362/numpy-array-is-not-json-serializable
16
+ class NumpyEncoder(json.JSONEncoder):
17
+ def default(self, obj):
18
+ if isinstance(obj, np.ndarray):
19
+ return obj.tolist()
20
+ elif isinstance(obj, np.ScalarType):
21
+ if isinstance(obj, NP_BOOL_TYPES):
22
+ return bool(obj)
23
+ elif isinstance(obj, NP_FLOAT_TYPES):
24
+ return float(obj)
25
+ elif isinstance(obj, NP_INT_TYPES):
26
+ return int(obj)
27
+ return json.JSONEncoder.default(self, obj)
28
+
29
+ def find_all_imgs(img_dir, abs_path=False):
30
+ imglist = list()
31
+ for filep in glob.glob(osp.join(img_dir, "*")):
32
+ filename = osp.basename(filep)
33
+ file_suffix = Path(filename).suffix
34
+ if file_suffix.lower() not in IMG_EXT:
35
+ continue
36
+ if abs_path:
37
+ imglist.append(filep)
38
+ else:
39
+ imglist.append(filename)
40
+ return imglist
41
+
42
+ def imread(imgpath, read_type=cv2.IMREAD_COLOR):
43
+ # img = cv2.imread(imgpath, read_type)
44
+ # if img is None:
45
+ img = cv2.imdecode(np.fromfile(imgpath, dtype=np.uint8), read_type)
46
+ return img
47
+
48
+ def imwrite(img_path, img, ext='.png'):
49
+ suffix = Path(img_path).suffix
50
+ if suffix != '':
51
+ img_path = img_path.replace(suffix, ext)
52
+ else:
53
+ img_path += ext
54
+ cv2.imencode(ext, img)[1].tofile(img_path)
manga_translator/detection/ctd_utils/utils/weight_init.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch
3
+
4
+ def constant_init(module, val, bias=0):
5
+ nn.init.constant_(module.weight, val)
6
+ if hasattr(module, 'bias') and module.bias is not None:
7
+ nn.init.constant_(module.bias, bias)
8
+
9
+ def xavier_init(module, gain=1, bias=0, distribution='normal'):
10
+ assert distribution in ['uniform', 'normal']
11
+ if distribution == 'uniform':
12
+ nn.init.xavier_uniform_(module.weight, gain=gain)
13
+ else:
14
+ nn.init.xavier_normal_(module.weight, gain=gain)
15
+ if hasattr(module, 'bias') and module.bias is not None:
16
+ nn.init.constant_(module.bias, bias)
17
+
18
+
19
+ def normal_init(module, mean=0, std=1, bias=0):
20
+ nn.init.normal_(module.weight, mean, std)
21
+ if hasattr(module, 'bias') and module.bias is not None:
22
+ nn.init.constant_(module.bias, bias)
23
+
24
+
25
+ def uniform_init(module, a=0, b=1, bias=0):
26
+ nn.init.uniform_(module.weight, a, b)
27
+ if hasattr(module, 'bias') and module.bias is not None:
28
+ nn.init.constant_(module.bias, bias)
29
+
30
+
31
+ def kaiming_init(module,
32
+ a=0,
33
+ is_rnn=False,
34
+ mode='fan_in',
35
+ nonlinearity='leaky_relu',
36
+ bias=0,
37
+ distribution='normal'):
38
+ assert distribution in ['uniform', 'normal']
39
+ if distribution == 'uniform':
40
+ if is_rnn:
41
+ for name, param in module.named_parameters():
42
+ if 'bias' in name:
43
+ nn.init.constant_(param, bias)
44
+ elif 'weight' in name:
45
+ nn.init.kaiming_uniform_(param,
46
+ a=a,
47
+ mode=mode,
48
+ nonlinearity=nonlinearity)
49
+ else:
50
+ nn.init.kaiming_uniform_(module.weight,
51
+ a=a,
52
+ mode=mode,
53
+ nonlinearity=nonlinearity)
54
+
55
+ else:
56
+ if is_rnn:
57
+ for name, param in module.named_parameters():
58
+ if 'bias' in name:
59
+ nn.init.constant_(param, bias)
60
+ elif 'weight' in name:
61
+ nn.init.kaiming_normal_(param,
62
+ a=a,
63
+ mode=mode,
64
+ nonlinearity=nonlinearity)
65
+ else:
66
+ nn.init.kaiming_normal_(module.weight,
67
+ a=a,
68
+ mode=mode,
69
+ nonlinearity=nonlinearity)
70
+
71
+ if not is_rnn and hasattr(module, 'bias') and module.bias is not None:
72
+ nn.init.constant_(module.bias, bias)
73
+
74
+
75
+ def bilinear_kernel(in_channels, out_channels, kernel_size):
76
+ factor = (kernel_size + 1) // 2
77
+ if kernel_size % 2 == 1:
78
+ center = factor - 1
79
+ else:
80
+ center = factor - 0.5
81
+ og = (torch.arange(kernel_size).reshape(-1, 1),
82
+ torch.arange(kernel_size).reshape(1, -1))
83
+ filt = (1 - torch.abs(og[0] - center) / factor) * \
84
+ (1 - torch.abs(og[1] - center) / factor)
85
+ weight = torch.zeros((in_channels, out_channels,
86
+ kernel_size, kernel_size))
87
+ weight[range(in_channels), range(out_channels), :, :] = filt
88
+ return weight
89
+
90
+
91
+ def init_weights(m):
92
+ # for m in modules:
93
+
94
+ if isinstance(m, nn.Conv2d):
95
+ kaiming_init(m)
96
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
97
+ constant_init(m, 1)
98
+ elif isinstance(m, nn.Linear):
99
+ xavier_init(m)
100
+ elif isinstance(m, (nn.LSTM, nn.LSTMCell)):
101
+ kaiming_init(m, is_rnn=True)
102
+ # elif isinstance(m, nn.ConvTranspose2d):
103
+ # m.weight.data.copy_(bilinear_kernel(m.in_channels, m.out_channels, 4));
manga_translator/detection/ctd_utils/utils/yolov5_utils.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import cv2
6
+ import numpy as np
7
+ import time
8
+ import torchvision
9
+
10
+ def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416)
11
+ # scales img(bs,3,y,x) by ratio constrained to gs-multiple
12
+ if ratio == 1.0:
13
+ return img
14
+ else:
15
+ h, w = img.shape[2:]
16
+ s = (int(h * ratio), int(w * ratio)) # new size
17
+ img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
18
+ if not same_shape: # pad/crop img
19
+ h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w))
20
+ return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
21
+
22
+ def fuse_conv_and_bn(conv, bn):
23
+ # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
24
+ fusedconv = nn.Conv2d(conv.in_channels,
25
+ conv.out_channels,
26
+ kernel_size=conv.kernel_size,
27
+ stride=conv.stride,
28
+ padding=conv.padding,
29
+ groups=conv.groups,
30
+ bias=True).requires_grad_(False).to(conv.weight.device)
31
+
32
+ # prepare filters
33
+ w_conv = conv.weight.clone().view(conv.out_channels, -1)
34
+ w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
35
+ fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
36
+
37
+ # prepare spatial bias
38
+ b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
39
+ b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
40
+ fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
41
+
42
+ return fusedconv
43
+
44
+ def check_anchor_order(m):
45
+ # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
46
+ a = m.anchors.prod(-1).view(-1) # anchor area
47
+ da = a[-1] - a[0] # delta a
48
+ ds = m.stride[-1] - m.stride[0] # delta s
49
+ if da.sign() != ds.sign(): # same order
50
+ m.anchors[:] = m.anchors.flip(0)
51
+
52
+ def initialize_weights(model):
53
+ for m in model.modules():
54
+ t = type(m)
55
+ if t is nn.Conv2d:
56
+ pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
57
+ elif t is nn.BatchNorm2d:
58
+ m.eps = 1e-3
59
+ m.momentum = 0.03
60
+ elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
61
+ m.inplace = True
62
+
63
+ def make_divisible(x, divisor):
64
+ # Returns nearest x divisible by divisor
65
+ if isinstance(divisor, torch.Tensor):
66
+ divisor = int(divisor.max()) # to int
67
+ return math.ceil(x / divisor) * divisor
68
+
69
+ def intersect_dicts(da, db, exclude=()):
70
+ # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
71
+ return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
72
+
73
+ def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False):
74
+ # Check version vs. required version
75
+ from packaging import version
76
+ current, minimum = (version.parse(x) for x in (current, minimum))
77
+ result = (current == minimum) if pinned else (current >= minimum) # bool
78
+ if hard: # assert min requirements met
79
+ assert result, f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed'
80
+ else:
81
+ return result
82
+
83
+ class Colors:
84
+ # Ultralytics color palette https://ultralytics.com/
85
+ def __init__(self):
86
+ # hex = matplotlib.colors.TABLEAU_COLORS.values()
87
+ hex = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
88
+ '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
89
+ self.palette = [self.hex2rgb('#' + c) for c in hex]
90
+ self.n = len(self.palette)
91
+
92
+ def __call__(self, i, bgr=False):
93
+ c = self.palette[int(i) % self.n]
94
+ return (c[2], c[1], c[0]) if bgr else c
95
+
96
+ @staticmethod
97
+ def hex2rgb(h): # rgb order (PIL)
98
+ return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
99
+
100
+ def box_iou(box1, box2):
101
+ # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
102
+ """
103
+ Return intersection-over-union (Jaccard index) of boxes.
104
+ Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
105
+ Arguments:
106
+ box1 (Tensor[N, 4])
107
+ box2 (Tensor[M, 4])
108
+ Returns:
109
+ iou (Tensor[N, M]): the NxM matrix containing the pairwise
110
+ IoU values for every element in boxes1 and boxes2
111
+ """
112
+
113
+ def box_area(box):
114
+ # box = 4xn
115
+ return (box[2] - box[0]) * (box[3] - box[1])
116
+
117
+ area1 = box_area(box1.T)
118
+ area2 = box_area(box2.T)
119
+
120
+ # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
121
+ inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
122
+ return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter)
123
+
124
+ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
125
+ labels=(), max_det=300):
126
+ """Runs Non-Maximum Suppression (NMS) on inference results
127
+
128
+ Returns:
129
+ list of detections, on (n,6) tensor per image [xyxy, conf, cls]
130
+ """
131
+
132
+ if isinstance(prediction, np.ndarray):
133
+ prediction = torch.from_numpy(prediction)
134
+
135
+ nc = prediction.shape[2] - 5 # number of classes
136
+ xc = prediction[..., 4] > conf_thres # candidates
137
+
138
+ # Checks
139
+ assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
140
+ assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
141
+
142
+ # Settings
143
+ min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
144
+ max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
145
+ time_limit = 10.0 # seconds to quit after
146
+ redundant = True # require redundant detections
147
+ multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
148
+ merge = False # use merge-NMS
149
+
150
+ t = time.time()
151
+ output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
152
+ for xi, x in enumerate(prediction): # image index, image inference
153
+ # Apply constraints
154
+ # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
155
+ x = x[xc[xi]] # confidence
156
+
157
+ # Cat apriori labels if autolabelling
158
+ if labels and len(labels[xi]):
159
+ l = labels[xi]
160
+ v = torch.zeros((len(l), nc + 5), device=x.device)
161
+ v[:, :4] = l[:, 1:5] # box
162
+ v[:, 4] = 1.0 # conf
163
+ v[range(len(l)), l[:, 0].long() + 5] = 1.0 # cls
164
+ x = torch.cat((x, v), 0)
165
+
166
+ # If none remain process next image
167
+ if not x.shape[0]:
168
+ continue
169
+
170
+ # Compute conf
171
+ x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
172
+
173
+ # Box (center x, center y, width, height) to (x1, y1, x2, y2)
174
+ box = xywh2xyxy(x[:, :4])
175
+
176
+ # Detections matrix nx6 (xyxy, conf, cls)
177
+ if multi_label:
178
+ i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
179
+ x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
180
+ else: # best class only
181
+ conf, j = x[:, 5:].max(1, keepdim=True)
182
+ x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
183
+
184
+ # Filter by class
185
+ if classes is not None:
186
+ x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
187
+
188
+ # Apply finite constraint
189
+ # if not torch.isfinite(x).all():
190
+ # x = x[torch.isfinite(x).all(1)]
191
+
192
+ # Check shape
193
+ n = x.shape[0] # number of boxes
194
+ if not n: # no boxes
195
+ continue
196
+ elif n > max_nms: # excess boxes
197
+ x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence
198
+
199
+ # Batched NMS
200
+ c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
201
+ boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
202
+ i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
203
+ if i.shape[0] > max_det: # limit detections
204
+ i = i[:max_det]
205
+ if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
206
+ # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
207
+ iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
208
+ weights = iou * scores[None] # box weights
209
+ x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
210
+ if redundant:
211
+ i = i[iou.sum(1) > 1] # require redundancy
212
+
213
+ output[xi] = x[i]
214
+ if (time.time() - t) > time_limit:
215
+ print(f'WARNING: NMS time limit {time_limit}s exceeded')
216
+ break # time limit exceeded
217
+
218
+ return output
219
+
220
+ def xywh2xyxy(x):
221
+ # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
222
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
223
+ y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
224
+ y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
225
+ y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
226
+ y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
227
+ return y
228
+
229
+ DEFAULT_LANG_LIST = ['eng', 'ja']
230
+ def draw_bbox(pred, img, lang_list=None):
231
+ if lang_list is None:
232
+ lang_list = DEFAULT_LANG_LIST
233
+ lw = max(round(sum(img.shape) / 2 * 0.003), 2) # line width
234
+ pred = pred.astype(np.int32)
235
+ colors = Colors()
236
+ img = np.copy(img)
237
+ for ii, obj in enumerate(pred):
238
+ p1, p2 = (obj[0], obj[1]), (obj[2], obj[3])
239
+ label = lang_list[obj[-1]] + str(ii+1)
240
+ cv2.rectangle(img, p1, p2, colors(obj[-1], bgr=True), lw, lineType=cv2.LINE_AA)
241
+ t_w, t_h = cv2.getTextSize(label, 0, fontScale=lw / 3, thickness=lw)[0]
242
+ cv2.putText(img, label, (p1[0], p1[1] + t_h + 2), 0, lw / 3, colors(obj[-1], bgr=True), max(lw-1, 1), cv2.LINE_AA)
243
+ return img
manga_translator/detection/ctd_utils/yolov5/common.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
2
+ """
3
+ Common modules
4
+ """
5
+
6
+ import json
7
+ import math
8
+ import platform
9
+ import warnings
10
+ from collections import OrderedDict, namedtuple
11
+ from copy import copy
12
+ from pathlib import Path
13
+
14
+ import cv2
15
+ import numpy as np
16
+ import requests
17
+ import torch
18
+ import torch.nn as nn
19
+ from PIL import Image
20
+ from torch.cuda import amp
21
+
22
+ from ..utils.yolov5_utils import make_divisible, initialize_weights, check_anchor_order, check_version, fuse_conv_and_bn
23
+
24
+ def autopad(k, p=None): # kernel, padding
25
+ # Pad to 'same'
26
+ if p is None:
27
+ p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
28
+ return p
29
+
30
+ class Conv(nn.Module):
31
+ # Standard convolution
32
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
33
+ super().__init__()
34
+ self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
35
+ self.bn = nn.BatchNorm2d(c2)
36
+ if isinstance(act, bool):
37
+ self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
38
+ elif isinstance(act, str):
39
+ if act == 'leaky':
40
+ self.act = nn.LeakyReLU(0.1, inplace=True)
41
+ elif act == 'relu':
42
+ self.act = nn.ReLU(inplace=True)
43
+ else:
44
+ self.act = None
45
+ def forward(self, x):
46
+ return self.act(self.bn(self.conv(x)))
47
+
48
+ def forward_fuse(self, x):
49
+ return self.act(self.conv(x))
50
+
51
+
52
+ class DWConv(Conv):
53
+ # Depth-wise convolution class
54
+ def __init__(self, c1, c2, k=1, s=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
55
+ super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), act=act)
56
+
57
+
58
+ class TransformerLayer(nn.Module):
59
+ # Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)
60
+ def __init__(self, c, num_heads):
61
+ super().__init__()
62
+ self.q = nn.Linear(c, c, bias=False)
63
+ self.k = nn.Linear(c, c, bias=False)
64
+ self.v = nn.Linear(c, c, bias=False)
65
+ self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
66
+ self.fc1 = nn.Linear(c, c, bias=False)
67
+ self.fc2 = nn.Linear(c, c, bias=False)
68
+
69
+ def forward(self, x):
70
+ x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
71
+ x = self.fc2(self.fc1(x)) + x
72
+ return x
73
+
74
+
75
+ class TransformerBlock(nn.Module):
76
+ # Vision Transformer https://arxiv.org/abs/2010.11929
77
+ def __init__(self, c1, c2, num_heads, num_layers):
78
+ super().__init__()
79
+ self.conv = None
80
+ if c1 != c2:
81
+ self.conv = Conv(c1, c2)
82
+ self.linear = nn.Linear(c2, c2) # learnable position embedding
83
+ self.tr = nn.Sequential(*(TransformerLayer(c2, num_heads) for _ in range(num_layers)))
84
+ self.c2 = c2
85
+
86
+ def forward(self, x):
87
+ if self.conv is not None:
88
+ x = self.conv(x)
89
+ b, _, w, h = x.shape
90
+ p = x.flatten(2).permute(2, 0, 1)
91
+ return self.tr(p + self.linear(p)).permute(1, 2, 0).reshape(b, self.c2, w, h)
92
+
93
+
94
+ class Bottleneck(nn.Module):
95
+ # Standard bottleneck
96
+ def __init__(self, c1, c2, shortcut=True, g=1, e=0.5, act=True): # ch_in, ch_out, shortcut, groups, expansion
97
+ super().__init__()
98
+ c_ = int(c2 * e) # hidden channels
99
+ self.cv1 = Conv(c1, c_, 1, 1, act=act)
100
+ self.cv2 = Conv(c_, c2, 3, 1, g=g, act=act)
101
+ self.add = shortcut and c1 == c2
102
+
103
+ def forward(self, x):
104
+ return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
105
+
106
+
107
+ class BottleneckCSP(nn.Module):
108
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
109
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
110
+ super().__init__()
111
+ c_ = int(c2 * e) # hidden channels
112
+ self.cv1 = Conv(c1, c_, 1, 1)
113
+ self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
114
+ self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
115
+ self.cv4 = Conv(2 * c_, c2, 1, 1)
116
+ self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)
117
+ self.act = nn.SiLU()
118
+ self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
119
+
120
+ def forward(self, x):
121
+ y1 = self.cv3(self.m(self.cv1(x)))
122
+ y2 = self.cv2(x)
123
+ return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1))))
124
+
125
+
126
+ class C3(nn.Module):
127
+ # CSP Bottleneck with 3 convolutions
128
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, act=True): # ch_in, ch_out, number, shortcut, groups, expansion
129
+ super().__init__()
130
+ c_ = int(c2 * e) # hidden channels
131
+ self.cv1 = Conv(c1, c_, 1, 1, act=act)
132
+ self.cv2 = Conv(c1, c_, 1, 1, act=act)
133
+ self.cv3 = Conv(2 * c_, c2, 1, act=act) # act=FReLU(c2)
134
+ self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0, act=act) for _ in range(n)))
135
+ # self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)])
136
+
137
+ def forward(self, x):
138
+ return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))
139
+
140
+
141
+ class C3TR(C3):
142
+ # C3 module with TransformerBlock()
143
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
144
+ super().__init__(c1, c2, n, shortcut, g, e)
145
+ c_ = int(c2 * e)
146
+ self.m = TransformerBlock(c_, c_, 4, n)
147
+
148
+
149
+ class C3SPP(C3):
150
+ # C3 module with SPP()
151
+ def __init__(self, c1, c2, k=(5, 9, 13), n=1, shortcut=True, g=1, e=0.5):
152
+ super().__init__(c1, c2, n, shortcut, g, e)
153
+ c_ = int(c2 * e)
154
+ self.m = SPP(c_, c_, k)
155
+
156
+
157
+ class C3Ghost(C3):
158
+ # C3 module with GhostBottleneck()
159
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
160
+ super().__init__(c1, c2, n, shortcut, g, e)
161
+ c_ = int(c2 * e) # hidden channels
162
+ self.m = nn.Sequential(*(GhostBottleneck(c_, c_) for _ in range(n)))
163
+
164
+
165
+ class SPP(nn.Module):
166
+ # Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729
167
+ def __init__(self, c1, c2, k=(5, 9, 13)):
168
+ super().__init__()
169
+ c_ = c1 // 2 # hidden channels
170
+ self.cv1 = Conv(c1, c_, 1, 1)
171
+ self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
172
+ self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
173
+
174
+ def forward(self, x):
175
+ x = self.cv1(x)
176
+ with warnings.catch_warnings():
177
+ warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
178
+ return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
179
+
180
+
181
+ class SPPF(nn.Module):
182
+ # Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
183
+ def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13))
184
+ super().__init__()
185
+ c_ = c1 // 2 # hidden channels
186
+ self.cv1 = Conv(c1, c_, 1, 1)
187
+ self.cv2 = Conv(c_ * 4, c2, 1, 1)
188
+ self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
189
+
190
+ def forward(self, x):
191
+ x = self.cv1(x)
192
+ with warnings.catch_warnings():
193
+ warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
194
+ y1 = self.m(x)
195
+ y2 = self.m(y1)
196
+ return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1))
197
+
198
+
199
+ class Focus(nn.Module):
200
+ # Focus wh information into c-space
201
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
202
+ super().__init__()
203
+ self.conv = Conv(c1 * 4, c2, k, s, p, g, act)
204
+ # self.contract = Contract(gain=2)
205
+
206
+ def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
207
+ return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
208
+ # return self.conv(self.contract(x))
209
+
210
+
211
+ class GhostConv(nn.Module):
212
+ # Ghost Convolution https://github.com/huawei-noah/ghostnet
213
+ def __init__(self, c1, c2, k=1, s=1, g=1, act=True): # ch_in, ch_out, kernel, stride, groups
214
+ super().__init__()
215
+ c_ = c2 // 2 # hidden channels
216
+ self.cv1 = Conv(c1, c_, k, s, None, g, act)
217
+ self.cv2 = Conv(c_, c_, 5, 1, None, c_, act)
218
+
219
+ def forward(self, x):
220
+ y = self.cv1(x)
221
+ return torch.cat([y, self.cv2(y)], 1)
222
+
223
+
224
+ class GhostBottleneck(nn.Module):
225
+ # Ghost Bottleneck https://github.com/huawei-noah/ghostnet
226
+ def __init__(self, c1, c2, k=3, s=1): # ch_in, ch_out, kernel, stride
227
+ super().__init__()
228
+ c_ = c2 // 2
229
+ self.conv = nn.Sequential(GhostConv(c1, c_, 1, 1), # pw
230
+ DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(), # dw
231
+ GhostConv(c_, c2, 1, 1, act=False)) # pw-linear
232
+ self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False),
233
+ Conv(c1, c2, 1, 1, act=False)) if s == 2 else nn.Identity()
234
+
235
+ def forward(self, x):
236
+ return self.conv(x) + self.shortcut(x)
237
+
238
+
239
+ class Contract(nn.Module):
240
+ # Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40)
241
+ def __init__(self, gain=2):
242
+ super().__init__()
243
+ self.gain = gain
244
+
245
+ def forward(self, x):
246
+ b, c, h, w = x.size() # assert (h / s == 0) and (W / s == 0), 'Indivisible gain'
247
+ s = self.gain
248
+ x = x.view(b, c, h // s, s, w // s, s) # x(1,64,40,2,40,2)
249
+ x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # x(1,2,2,64,40,40)
250
+ return x.view(b, c * s * s, h // s, w // s) # x(1,256,40,40)
251
+
252
+
253
+ class Expand(nn.Module):
254
+ # Expand channels into width-height, i.e. x(1,64,80,80) to x(1,16,160,160)
255
+ def __init__(self, gain=2):
256
+ super().__init__()
257
+ self.gain = gain
258
+
259
+ def forward(self, x):
260
+ b, c, h, w = x.size() # assert C / s ** 2 == 0, 'Indivisible gain'
261
+ s = self.gain
262
+ x = x.view(b, s, s, c // s ** 2, h, w) # x(1,2,2,16,80,80)
263
+ x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # x(1,16,80,2,80,2)
264
+ return x.view(b, c // s ** 2, h * s, w * s) # x(1,16,160,160)
265
+
266
+
267
+ class Concat(nn.Module):
268
+ # Concatenate a list of tensors along dimension
269
+ def __init__(self, dimension=1):
270
+ super().__init__()
271
+ self.d = dimension
272
+
273
+ def forward(self, x):
274
+ return torch.cat(x, self.d)
275
+
276
+
277
+ class Classify(nn.Module):
278
+ # Classification head, i.e. x(b,c1,20,20) to x(b,c2)
279
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups
280
+ super().__init__()
281
+ self.aap = nn.AdaptiveAvgPool2d(1) # to x(b,c1,1,1)
282
+ self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g) # to x(b,c2,1,1)
283
+ self.flat = nn.Flatten()
284
+
285
+ def forward(self, x):
286
+ z = torch.cat([self.aap(y) for y in (x if isinstance(x, list) else [x])], 1) # cat if list
287
+ return self.flat(self.conv(z)) # flatten to x(b,c2)
288
+
289
+
manga_translator/detection/ctd_utils/yolov5/yolo.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from operator import mod
2
+ from cv2 import imshow
3
+ # from utils.yolov5_utils import scale_img
4
+ from copy import deepcopy
5
+ from .common import *
6
+
7
+ class Detect(nn.Module):
8
+ stride = None # strides computed during build
9
+ onnx_dynamic = False # ONNX export parameter
10
+
11
+ def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer
12
+ super().__init__()
13
+ self.nc = nc # number of classes
14
+ self.no = nc + 5 # number of outputs per anchor
15
+ self.nl = len(anchors) # number of detection layers
16
+ self.na = len(anchors[0]) // 2 # number of anchors
17
+ self.grid = [torch.zeros(1)] * self.nl # init grid
18
+ self.anchor_grid = [torch.zeros(1)] * self.nl # init anchor grid
19
+ self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2)) # shape(nl,na,2)
20
+ self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
21
+ self.inplace = inplace # use in-place ops (e.g. slice assignment)
22
+
23
+ def forward(self, x):
24
+ z = [] # inference output
25
+ for i in range(self.nl):
26
+ x[i] = self.m[i](x[i]) # conv
27
+ bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
28
+ x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
29
+
30
+ if not self.training: # inference
31
+ if self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
32
+ self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)
33
+
34
+ y = x[i].sigmoid()
35
+ if self.inplace:
36
+ y[..., 0:2] = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i] # xy
37
+ y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
38
+ else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
39
+ xy = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i] # xy
40
+ wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
41
+ y = torch.cat((xy, wh, y[..., 4:]), -1)
42
+ z.append(y.view(bs, -1, self.no))
43
+
44
+ return x if self.training else (torch.cat(z, 1), x)
45
+
46
+ def _make_grid(self, nx=20, ny=20, i=0):
47
+ d = self.anchors[i].device
48
+ if check_version(torch.__version__, '1.10.0'): # torch>=1.10.0 meshgrid workaround for torch>=0.7 compatibility
49
+ yv, xv = torch.meshgrid([torch.arange(ny, device=d), torch.arange(nx, device=d)], indexing='ij')
50
+ else:
51
+ yv, xv = torch.meshgrid([torch.arange(ny, device=d), torch.arange(nx, device=d)])
52
+ grid = torch.stack((xv, yv), 2).expand((1, self.na, ny, nx, 2)).float()
53
+ anchor_grid = (self.anchors[i].clone() * self.stride[i]) \
54
+ .view((1, self.na, 1, 1, 2)).expand((1, self.na, ny, nx, 2)).float()
55
+ return grid, anchor_grid
56
+
57
+ class Model(nn.Module):
58
+ def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None): # model, input channels, number of classes
59
+ super().__init__()
60
+ self.out_indices = None
61
+ if isinstance(cfg, dict):
62
+ self.yaml = cfg # model dict
63
+ else: # is *.yaml
64
+ import yaml # for torch hub
65
+ self.yaml_file = Path(cfg).name
66
+ with open(cfg, encoding='ascii', errors='ignore') as f:
67
+ self.yaml = yaml.safe_load(f) # model dict
68
+
69
+ # Define model
70
+ ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
71
+ if nc and nc != self.yaml['nc']:
72
+ # LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
73
+ self.yaml['nc'] = nc # override yaml value
74
+ if anchors:
75
+ # LOGGER.info(f'Overriding model.yaml anchors with anchors={anchors}')
76
+ self.yaml['anchors'] = round(anchors) # override yaml value
77
+ self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
78
+ self.names = [str(i) for i in range(self.yaml['nc'])] # default names
79
+ self.inplace = self.yaml.get('inplace', True)
80
+
81
+ # Build strides, anchors
82
+ m = self.model[-1] # Detect()
83
+ # with torch.no_grad():
84
+ if isinstance(m, Detect):
85
+ s = 256 # 2x min stride
86
+ m.inplace = self.inplace
87
+ m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward
88
+ m.anchors /= m.stride.view(-1, 1, 1)
89
+ check_anchor_order(m)
90
+ self.stride = m.stride
91
+ self._initialize_biases() # only run once
92
+
93
+ # Init weights, biases
94
+ initialize_weights(self)
95
+
96
+ def forward(self, x, augment=False, profile=False, visualize=False, detect=False):
97
+ # if augment:
98
+ # return self._forward_augment(x) # augmented inference, None
99
+ return self._forward_once(x, profile, visualize, detect=detect) # single-scale inference, train
100
+
101
+ # def _forward_augment(self, x):
102
+ # img_size = x.shape[-2:] # height, width
103
+ # s = [1, 0.83, 0.67] # scales
104
+ # f = [None, 3, None] # flips (2-ud, 3-lr)
105
+ # y = [] # outputs
106
+ # for si, fi in zip(s, f):
107
+ # xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
108
+ # yi = self._forward_once(xi)[0] # forward
109
+ # # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
110
+ # yi = self._descale_pred(yi, fi, si, img_size)
111
+ # y.append(yi)
112
+ # y = self._clip_augmented(y) # clip augmented tails
113
+ # return torch.cat(y, 1), None # augmented inference, train
114
+
115
+ def _forward_once(self, x, profile=False, visualize=False, detect=False):
116
+ y, dt = [], [] # outputs
117
+ z = []
118
+ for ii, m in enumerate(self.model):
119
+ if m.f != -1: # if not from previous layer
120
+ x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
121
+ if profile:
122
+ self._profile_one_layer(m, x, dt)
123
+ x = m(x) # run
124
+ y.append(x if m.i in self.save else None) # save output
125
+ if self.out_indices is not None:
126
+ if m.i in self.out_indices:
127
+ z.append(x)
128
+ if self.out_indices is not None:
129
+ if detect:
130
+ return x, z
131
+ else:
132
+ return z
133
+ else:
134
+ return x
135
+
136
+ def _descale_pred(self, p, flips, scale, img_size):
137
+ # de-scale predictions following augmented inference (inverse operation)
138
+ if self.inplace:
139
+ p[..., :4] /= scale # de-scale
140
+ if flips == 2:
141
+ p[..., 1] = img_size[0] - p[..., 1] # de-flip ud
142
+ elif flips == 3:
143
+ p[..., 0] = img_size[1] - p[..., 0] # de-flip lr
144
+ else:
145
+ x, y, wh = p[..., 0:1] / scale, p[..., 1:2] / scale, p[..., 2:4] / scale # de-scale
146
+ if flips == 2:
147
+ y = img_size[0] - y # de-flip ud
148
+ elif flips == 3:
149
+ x = img_size[1] - x # de-flip lr
150
+ p = torch.cat((x, y, wh, p[..., 4:]), -1)
151
+ return p
152
+
153
+ def _clip_augmented(self, y):
154
+ # Clip YOLOv5 augmented inference tails
155
+ nl = self.model[-1].nl # number of detection layers (P3-P5)
156
+ g = sum(4 ** x for x in range(nl)) # grid points
157
+ e = 1 # exclude layer count
158
+ i = (y[0].shape[1] // g) * sum(4 ** x for x in range(e)) # indices
159
+ y[0] = y[0][:, :-i] # large
160
+ i = (y[-1].shape[1] // g) * sum(4 ** (nl - 1 - x) for x in range(e)) # indices
161
+ y[-1] = y[-1][:, i:] # small
162
+ return y
163
+
164
+ def _profile_one_layer(self, m, x, dt):
165
+ c = isinstance(m, Detect) # is final layer, copy input as inplace fix
166
+ for _ in range(10):
167
+ m(x.copy() if c else x)
168
+
169
+
170
+ def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency
171
+ # https://arxiv.org/abs/1708.02002 section 3.3
172
+ # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
173
+ m = self.model[-1] # Detect() module
174
+ for mi, s in zip(m.m, m.stride): # from
175
+ b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85)
176
+ b.data[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image)
177
+ b.data[:, 5:] += math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # cls
178
+ mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
179
+
180
+ def _print_biases(self):
181
+ m = self.model[-1] # Detect() module
182
+ for mi in m.m: # from
183
+ b = mi.bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85)
184
+
185
+ def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
186
+ for m in self.model.modules():
187
+ if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
188
+ m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
189
+ delattr(m, 'bn') # remove batchnorm
190
+ m.forward = m.forward_fuse # update forward
191
+ # self.info()
192
+ return self
193
+
194
+ # def info(self, verbose=False, img_size=640): # print model information
195
+ # model_info(self, verbose, img_size)
196
+
197
+ def _apply(self, fn):
198
+ # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
199
+ self = super()._apply(fn)
200
+ m = self.model[-1] # Detect()
201
+ if isinstance(m, Detect):
202
+ m.stride = fn(m.stride)
203
+ m.grid = list(map(fn, m.grid))
204
+ if isinstance(m.anchor_grid, list):
205
+ m.anchor_grid = list(map(fn, m.anchor_grid))
206
+ return self
207
+
208
+ def parse_model(d, ch): # model_dict, input_channels(3)
209
+ # LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<40}{'arguments':<30}")
210
+ anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
211
+ na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
212
+ no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
213
+
214
+ layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
215
+ for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
216
+ m = eval(m) if isinstance(m, str) else m # eval strings
217
+ for j, a in enumerate(args):
218
+ try:
219
+ args[j] = eval(a) if isinstance(a, str) else a # eval strings
220
+ except NameError:
221
+ pass
222
+
223
+ n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain
224
+ if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,
225
+ BottleneckCSP, C3, C3TR, C3SPP, C3Ghost]:
226
+ c1, c2 = ch[f], args[0]
227
+ if c2 != no: # if not output
228
+ c2 = make_divisible(c2 * gw, 8)
229
+
230
+ args = [c1, c2, *args[1:]]
231
+ if m in [BottleneckCSP, C3, C3TR, C3Ghost]:
232
+ args.insert(2, n) # number of repeats
233
+ n = 1
234
+ elif m is nn.BatchNorm2d:
235
+ args = [ch[f]]
236
+ elif m is Concat:
237
+ c2 = sum(ch[x] for x in f)
238
+ elif m is Detect:
239
+ args.append([ch[x] for x in f])
240
+ if isinstance(args[1], int): # number of anchors
241
+ args[1] = [list(range(args[1] * 2))] * len(f)
242
+ elif m is Contract:
243
+ c2 = ch[f] * args[0] ** 2
244
+ elif m is Expand:
245
+ c2 = ch[f] // args[0] ** 2
246
+ else:
247
+ c2 = ch[f]
248
+
249
+ m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
250
+ t = str(m)[8:-2].replace('__main__.', '') # module type
251
+ np = sum(x.numel() for x in m_.parameters()) # number params
252
+ m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
253
+ # LOGGER.info(f'{i:>3}{str(f):>18}{n_:>3}{np:10.0f} {t:<40}{str(args):<30}') # print
254
+ save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
255
+ layers.append(m_)
256
+ if i == 0:
257
+ ch = []
258
+ ch.append(c2)
259
+ return nn.Sequential(*layers), sorted(save)
260
+
261
+ def load_yolov5(weights, map_location='cuda', fuse=True, inplace=True, out_indices=[1, 3, 5, 7, 9]):
262
+ if isinstance(weights, str):
263
+ ckpt = torch.load(weights, map_location=map_location) # load
264
+ else:
265
+ ckpt = weights
266
+
267
+ if fuse:
268
+ model = ckpt['model'].float().fuse().eval() # FP32 model
269
+ else:
270
+ model = ckpt['model'].float().eval() # without layer fuse
271
+
272
+ # Compatibility updates
273
+ for m in model.modules():
274
+ if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]:
275
+ m.inplace = inplace # pytorch 1.7.0 compatibility
276
+ if type(m) is Detect:
277
+ if not isinstance(m.anchor_grid, list): # new Detect Layer compatibility
278
+ delattr(m, 'anchor_grid')
279
+ setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl)
280
+ elif type(m) is Conv:
281
+ m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
282
+ model.out_indices = out_indices
283
+ return model
284
+
285
+ @torch.no_grad()
286
+ def load_yolov5_ckpt(weights, map_location='cpu', fuse=True, inplace=True, out_indices=[1, 3, 5, 7, 9]):
287
+ if isinstance(weights, str):
288
+ ckpt = torch.load(weights, map_location=map_location) # load
289
+ else:
290
+ ckpt = weights
291
+
292
+ model = Model(ckpt['cfg'])
293
+ model.load_state_dict(ckpt['weights'], strict=True)
294
+
295
+ if fuse:
296
+ model = model.float().fuse().eval() # FP32 model
297
+ else:
298
+ model = model.float().eval() # without layer fuse
299
+
300
+ # Compatibility updates
301
+ for m in model.modules():
302
+ if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]:
303
+ m.inplace = inplace # pytorch 1.7.0 compatibility
304
+ if type(m) is Detect:
305
+ if not isinstance(m.anchor_grid, list): # new Detect Layer compatibility
306
+ delattr(m, 'anchor_grid')
307
+ setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl)
308
+ elif type(m) is Conv:
309
+ m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
310
+ model.out_indices = out_indices
311
+ return model
manga_translator/detection/dbnet_convnext.py ADDED
@@ -0,0 +1,596 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from functools import partial
3
+ import shutil
4
+ from typing import Callable, Optional, Tuple, Union
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import torch.nn.init as init
11
+
12
+ from torchvision.models import resnet34
13
+
14
+ import einops
15
+ import math
16
+
17
+ from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, Mlp, GlobalResponseNormMlp, \
18
+ LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple
19
+
20
+ class Downsample(nn.Module):
21
+
22
+ def __init__(self, in_chs, out_chs, stride=1, dilation=1):
23
+ super().__init__()
24
+ avg_stride = stride if dilation == 1 else 1
25
+ if stride > 1 or dilation > 1:
26
+ avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
27
+ self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
28
+ else:
29
+ self.pool = nn.Identity()
30
+
31
+ if in_chs != out_chs:
32
+ self.conv = create_conv2d(in_chs, out_chs, 1, stride=1)
33
+ else:
34
+ self.conv = nn.Identity()
35
+
36
+ def forward(self, x):
37
+ x = self.pool(x)
38
+ x = self.conv(x)
39
+ return x
40
+
41
+
42
+ class ConvNeXtBlock(nn.Module):
43
+ """ ConvNeXt Block
44
+ There are two equivalent implementations:
45
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
46
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
47
+
48
+ Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate
49
+ choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear
50
+ is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW.
51
+ """
52
+
53
+ def __init__(
54
+ self,
55
+ in_chs: int,
56
+ out_chs: Optional[int] = None,
57
+ kernel_size: int = 7,
58
+ stride: int = 1,
59
+ dilation: Union[int, Tuple[int, int]] = (1, 1),
60
+ mlp_ratio: float = 4,
61
+ conv_mlp: bool = False,
62
+ conv_bias: bool = True,
63
+ use_grn: bool = False,
64
+ ls_init_value: Optional[float] = 1e-6,
65
+ act_layer: Union[str, Callable] = 'gelu',
66
+ norm_layer: Optional[Callable] = None,
67
+ drop_path: float = 0.,
68
+ ):
69
+ """
70
+
71
+ Args:
72
+ in_chs: Block input channels.
73
+ out_chs: Block output channels (same as in_chs if None).
74
+ kernel_size: Depthwise convolution kernel size.
75
+ stride: Stride of depthwise convolution.
76
+ dilation: Tuple specifying input and output dilation of block.
77
+ mlp_ratio: MLP expansion ratio.
78
+ conv_mlp: Use 1x1 convolutions for MLP and a NCHW compatible norm layer if True.
79
+ conv_bias: Apply bias for all convolution (linear) layers.
80
+ use_grn: Use GlobalResponseNorm in MLP (from ConvNeXt-V2)
81
+ ls_init_value: Layer-scale init values, layer-scale applied if not None.
82
+ act_layer: Activation layer.
83
+ norm_layer: Normalization layer (defaults to LN if not specified).
84
+ drop_path: Stochastic depth probability.
85
+ """
86
+ super().__init__()
87
+ out_chs = out_chs or in_chs
88
+ dilation = to_ntuple(2)(dilation)
89
+ act_layer = get_act_layer(act_layer)
90
+ if not norm_layer:
91
+ norm_layer = LayerNorm2d if conv_mlp else LayerNorm
92
+ mlp_layer = partial(GlobalResponseNormMlp if use_grn else Mlp, use_conv=conv_mlp)
93
+ self.use_conv_mlp = conv_mlp
94
+ self.conv_dw = create_conv2d(
95
+ in_chs,
96
+ out_chs,
97
+ kernel_size=kernel_size,
98
+ stride=stride,
99
+ dilation=dilation[0],
100
+ depthwise=True if out_chs >= in_chs else False,
101
+ bias=conv_bias,
102
+ )
103
+ self.norm = norm_layer(out_chs)
104
+ self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer)
105
+ self.gamma = nn.Parameter(ls_init_value * torch.ones(out_chs)) if ls_init_value is not None else None
106
+ if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
107
+ self.shortcut = Downsample(in_chs, out_chs, stride=stride, dilation=dilation[0])
108
+ else:
109
+ self.shortcut = nn.Identity()
110
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
111
+
112
+ def forward(self, x):
113
+ shortcut = x
114
+ x = self.conv_dw(x)
115
+ if self.use_conv_mlp:
116
+ x = self.norm(x)
117
+ x = self.mlp(x)
118
+ else:
119
+ x = x.permute(0, 2, 3, 1)
120
+ x = self.norm(x)
121
+ x = self.mlp(x)
122
+ x = x.permute(0, 3, 1, 2)
123
+ if self.gamma is not None:
124
+ x = x.mul(self.gamma.reshape(1, -1, 1, 1))
125
+
126
+ x = self.drop_path(x) + self.shortcut(shortcut)
127
+ return x
128
+
129
+
130
+ class ConvNeXtStage(nn.Module):
131
+
132
+ def __init__(
133
+ self,
134
+ in_chs,
135
+ out_chs,
136
+ kernel_size=7,
137
+ stride=2,
138
+ depth=2,
139
+ dilation=(1, 1),
140
+ drop_path_rates=None,
141
+ ls_init_value=1.0,
142
+ conv_mlp=False,
143
+ conv_bias=True,
144
+ use_grn=False,
145
+ act_layer='gelu',
146
+ norm_layer=None,
147
+ norm_layer_cl=None
148
+ ):
149
+ super().__init__()
150
+ self.grad_checkpointing = False
151
+
152
+ if in_chs != out_chs or stride > 1 or dilation[0] != dilation[1]:
153
+ ds_ks = 2 if stride > 1 or dilation[0] != dilation[1] else 1
154
+ pad = 'same' if dilation[1] > 1 else 0 # same padding needed if dilation used
155
+ self.downsample = nn.Sequential(
156
+ norm_layer(in_chs),
157
+ create_conv2d(
158
+ in_chs,
159
+ out_chs,
160
+ kernel_size=ds_ks,
161
+ stride=stride,
162
+ dilation=dilation[0],
163
+ padding=pad,
164
+ bias=conv_bias,
165
+ ),
166
+ )
167
+ in_chs = out_chs
168
+ else:
169
+ self.downsample = nn.Identity()
170
+
171
+ drop_path_rates = drop_path_rates or [0.] * depth
172
+ stage_blocks = []
173
+ for i in range(depth):
174
+ stage_blocks.append(ConvNeXtBlock(
175
+ in_chs=in_chs,
176
+ out_chs=out_chs,
177
+ kernel_size=kernel_size,
178
+ dilation=dilation[1],
179
+ drop_path=drop_path_rates[i],
180
+ ls_init_value=ls_init_value,
181
+ conv_mlp=conv_mlp,
182
+ conv_bias=conv_bias,
183
+ use_grn=use_grn,
184
+ act_layer=act_layer,
185
+ norm_layer=norm_layer if conv_mlp else norm_layer_cl,
186
+ ))
187
+ in_chs = out_chs
188
+ self.blocks = nn.Sequential(*stage_blocks)
189
+
190
+ def forward(self, x):
191
+ x = self.downsample(x)
192
+ x = self.blocks(x)
193
+ return x
194
+
195
+
196
+ class ConvNeXt(nn.Module):
197
+ r""" ConvNeXt
198
+ A PyTorch impl of : `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
199
+ """
200
+
201
+ def __init__(
202
+ self,
203
+ in_chans: int = 3,
204
+ num_classes: int = 1000,
205
+ global_pool: str = 'avg',
206
+ output_stride: int = 32,
207
+ depths: Tuple[int, ...] = (3, 3, 9, 3),
208
+ dims: Tuple[int, ...] = (96, 192, 384, 768),
209
+ kernel_sizes: Union[int, Tuple[int, ...]] = 7,
210
+ ls_init_value: Optional[float] = 1e-6,
211
+ stem_type: str = 'patch',
212
+ patch_size: int = 4,
213
+ head_init_scale: float = 1.,
214
+ head_norm_first: bool = False,
215
+ head_hidden_size: Optional[int] = None,
216
+ conv_mlp: bool = False,
217
+ conv_bias: bool = True,
218
+ use_grn: bool = False,
219
+ act_layer: Union[str, Callable] = 'gelu',
220
+ norm_layer: Optional[Union[str, Callable]] = None,
221
+ norm_eps: Optional[float] = None,
222
+ drop_rate: float = 0.,
223
+ drop_path_rate: float = 0.,
224
+ ):
225
+ """
226
+ Args:
227
+ in_chans: Number of input image channels.
228
+ num_classes: Number of classes for classification head.
229
+ global_pool: Global pooling type.
230
+ output_stride: Output stride of network, one of (8, 16, 32).
231
+ depths: Number of blocks at each stage.
232
+ dims: Feature dimension at each stage.
233
+ kernel_sizes: Depthwise convolution kernel-sizes for each stage.
234
+ ls_init_value: Init value for Layer Scale, disabled if None.
235
+ stem_type: Type of stem.
236
+ patch_size: Stem patch size for patch stem.
237
+ head_init_scale: Init scaling value for classifier weights and biases.
238
+ head_norm_first: Apply normalization before global pool + head.
239
+ head_hidden_size: Size of MLP hidden layer in head if not None and head_norm_first == False.
240
+ conv_mlp: Use 1x1 conv in MLP, improves speed for small networks w/ chan last.
241
+ conv_bias: Use bias layers w/ all convolutions.
242
+ use_grn: Use Global Response Norm (ConvNeXt-V2) in MLP.
243
+ act_layer: Activation layer type.
244
+ norm_layer: Normalization layer type.
245
+ drop_rate: Head pre-classifier dropout rate.
246
+ drop_path_rate: Stochastic depth drop rate.
247
+ """
248
+ super().__init__()
249
+ assert output_stride in (8, 16, 32)
250
+ kernel_sizes = to_ntuple(4)(kernel_sizes)
251
+ if norm_layer is None:
252
+ norm_layer = LayerNorm2d
253
+ norm_layer_cl = norm_layer if conv_mlp else LayerNorm
254
+ if norm_eps is not None:
255
+ norm_layer = partial(norm_layer, eps=norm_eps)
256
+ norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
257
+ else:
258
+ assert conv_mlp,\
259
+ 'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
260
+ norm_layer_cl = norm_layer
261
+ if norm_eps is not None:
262
+ norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
263
+
264
+ self.num_classes = num_classes
265
+ self.drop_rate = drop_rate
266
+ self.feature_info = []
267
+
268
+ assert stem_type in ('patch', 'overlap', 'overlap_tiered')
269
+ if stem_type == 'patch':
270
+ # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
271
+ self.stem = nn.Sequential(
272
+ nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=conv_bias),
273
+ norm_layer(dims[0]),
274
+ )
275
+ stem_stride = patch_size
276
+ else:
277
+ mid_chs = make_divisible(dims[0] // 2) if 'tiered' in stem_type else dims[0]
278
+ self.stem = nn.Sequential(
279
+ nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias),
280
+ nn.Conv2d(mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias),
281
+ norm_layer(dims[0]),
282
+ )
283
+ stem_stride = 4
284
+
285
+ self.stages = nn.Sequential()
286
+ dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
287
+ stages = []
288
+ prev_chs = dims[0]
289
+ curr_stride = stem_stride
290
+ dilation = 1
291
+ # 4 feature resolution stages, each consisting of multiple residual blocks
292
+ for i in range(4):
293
+ stride = 2 if curr_stride == 2 or i > 0 else 1
294
+ if curr_stride >= output_stride and stride > 1:
295
+ dilation *= stride
296
+ stride = 1
297
+ curr_stride *= stride
298
+ first_dilation = 1 if dilation in (1, 2) else 2
299
+ out_chs = dims[i]
300
+ stages.append(ConvNeXtStage(
301
+ prev_chs,
302
+ out_chs,
303
+ kernel_size=kernel_sizes[i],
304
+ stride=stride,
305
+ dilation=(first_dilation, dilation),
306
+ depth=depths[i],
307
+ drop_path_rates=dp_rates[i],
308
+ ls_init_value=ls_init_value,
309
+ conv_mlp=conv_mlp,
310
+ conv_bias=conv_bias,
311
+ use_grn=use_grn,
312
+ act_layer=act_layer,
313
+ norm_layer=norm_layer,
314
+ norm_layer_cl=norm_layer_cl,
315
+ ))
316
+ prev_chs = out_chs
317
+ # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
318
+ self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')]
319
+ self.stages = nn.Sequential(*stages)
320
+ self.num_features = prev_chs
321
+
322
+ @torch.jit.ignore
323
+ def group_matcher(self, coarse=False):
324
+ return dict(
325
+ stem=r'^stem',
326
+ blocks=r'^stages\.(\d+)' if coarse else [
327
+ (r'^stages\.(\d+)\.downsample', (0,)), # blocks
328
+ (r'^stages\.(\d+)\.blocks\.(\d+)', None),
329
+ (r'^norm_pre', (99999,))
330
+ ]
331
+ )
332
+
333
+ @torch.jit.ignore
334
+ def set_grad_checkpointing(self, enable=True):
335
+ for s in self.stages:
336
+ s.grad_checkpointing = enable
337
+
338
+ @torch.jit.ignore
339
+ def get_classifier(self):
340
+ return self.head.fc
341
+
342
+ def forward_features(self, x):
343
+ x = self.stem(x)
344
+ x = self.stages(x)
345
+ return x
346
+
347
+ def _init_weights(module, name=None, head_init_scale=1.0):
348
+ if isinstance(module, nn.Conv2d):
349
+ trunc_normal_(module.weight, std=.02)
350
+ if module.bias is not None:
351
+ nn.init.zeros_(module.bias)
352
+ elif isinstance(module, nn.Linear):
353
+ trunc_normal_(module.weight, std=.02)
354
+ nn.init.zeros_(module.bias)
355
+ if name and 'head.' in name:
356
+ module.weight.data.mul_(head_init_scale)
357
+ module.bias.data.mul_(head_init_scale)
358
+
359
+ class UpconvSkip(nn.Module) :
360
+ def __init__(self, ch1, ch2, out_ch) -> None:
361
+ super().__init__()
362
+ self.conv = ConvNeXtBlock(
363
+ in_chs=ch1 + ch2,
364
+ out_chs=out_ch,
365
+ kernel_size=7,
366
+ dilation=1,
367
+ drop_path=0,
368
+ ls_init_value=1.0,
369
+ conv_mlp=False,
370
+ conv_bias=True,
371
+ use_grn=False,
372
+ act_layer='gelu',
373
+ norm_layer=LayerNorm,
374
+ )
375
+ self.upconv = nn.ConvTranspose2d(out_ch, out_ch, 2, 2, 0, 0)
376
+
377
+ def forward(self, x) :
378
+ x = self.conv(x)
379
+ x = self.upconv(x)
380
+ return x
381
+
382
+ class DBHead(nn.Module):
383
+ def __init__(self, in_channels, k = 50):
384
+ super().__init__()
385
+ self.k = k
386
+ self.binarize = nn.Sequential(
387
+ nn.Conv2d(in_channels, in_channels // 4, 3, padding=1),
388
+ #nn.BatchNorm2d(in_channels // 4),
389
+ nn.SiLU(inplace=True),
390
+ nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 4, 2, 1),
391
+ #nn.BatchNorm2d(in_channels // 4),
392
+ nn.SiLU(inplace=True),
393
+ nn.ConvTranspose2d(in_channels // 4, 1, 4, 2, 1),
394
+ )
395
+ self.binarize.apply(self.weights_init)
396
+
397
+ self.thresh = self._init_thresh(in_channels)
398
+ self.thresh.apply(self.weights_init)
399
+
400
+ def forward(self, x):
401
+ shrink_maps = self.binarize(x)
402
+ threshold_maps = self.thresh(x)
403
+ if self.training:
404
+ binary_maps = self.step_function(shrink_maps.sigmoid(), threshold_maps)
405
+ y = torch.cat((shrink_maps, threshold_maps, binary_maps), dim=1)
406
+ else:
407
+ y = torch.cat((shrink_maps, threshold_maps), dim=1)
408
+ return y
409
+
410
+ def weights_init(self, m):
411
+ classname = m.__class__.__name__
412
+ if classname.find('Conv') != -1:
413
+ nn.init.kaiming_normal_(m.weight.data)
414
+ elif classname.find('BatchNorm') != -1:
415
+ m.weight.data.fill_(1.)
416
+ m.bias.data.fill_(1e-4)
417
+
418
+ def _init_thresh(self, inner_channels, serial=False, smooth=False, bias=False):
419
+ in_channels = inner_channels
420
+ if serial:
421
+ in_channels += 1
422
+ self.thresh = nn.Sequential(
423
+ nn.Conv2d(in_channels, inner_channels // 4, 3, padding=1, bias=bias),
424
+ #nn.GroupNorm(inner_channels // 4),
425
+ nn.SiLU(inplace=True),
426
+ self._init_upsample(inner_channels // 4, inner_channels // 4, smooth=smooth, bias=bias),
427
+ #nn.GroupNorm(inner_channels // 4),
428
+ nn.SiLU(inplace=True),
429
+ self._init_upsample(inner_channels // 4, 1, smooth=smooth, bias=bias),
430
+ nn.Sigmoid())
431
+ return self.thresh
432
+
433
+ def _init_upsample(self, in_channels, out_channels, smooth=False, bias=False):
434
+ if smooth:
435
+ inter_out_channels = out_channels
436
+ if out_channels == 1:
437
+ inter_out_channels = in_channels
438
+ module_list = [
439
+ nn.Upsample(scale_factor=2, mode='bilinear'),
440
+ nn.Conv2d(in_channels, inter_out_channels, 3, 1, 1, bias=bias)]
441
+ if out_channels == 1:
442
+ module_list.append(nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=1, bias=True))
443
+ return nn.Sequential(module_list)
444
+ else:
445
+ return nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1)
446
+
447
+ def step_function(self, x, y):
448
+ return torch.reciprocal(1 + torch.exp(-self.k * (x - y)))
449
+
450
+ class DBNetConvNext(nn.Module) :
451
+ def __init__(self) :
452
+ super(DBNetConvNext, self).__init__()
453
+ self.backbone = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024])
454
+
455
+ self.conv_mask = nn.Sequential(
456
+ nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.SiLU(inplace=True),
457
+ nn.Conv2d(64, 32, kernel_size=3, padding=1), nn.SiLU(inplace=True),
458
+ nn.Conv2d(32, 1, kernel_size=1),
459
+ nn.Sigmoid()
460
+ )
461
+
462
+ self.down_conv1 = ConvNeXtStage(1024, 1024, depth = 2, norm_layer = LayerNorm2d)
463
+ self.down_conv2 = ConvNeXtStage(1024, 1024, depth = 2, norm_layer = LayerNorm2d)
464
+
465
+ self.upconv1 = UpconvSkip(0, 1024, 128)
466
+ self.upconv2 = UpconvSkip(128, 1024, 128)
467
+ self.upconv3 = UpconvSkip(128, 1024, 128)
468
+ self.upconv4 = UpconvSkip(128, 512, 128)
469
+ self.upconv5 = UpconvSkip(128, 256, 128)
470
+ self.upconv6 = UpconvSkip(128, 128, 64)
471
+
472
+ self.conv_db = DBHead(128)
473
+
474
+ def forward(self, x) :
475
+ # in 3@1536
476
+ x = self.backbone.stem(x) # 128@384
477
+ h4 = self.backbone.stages[0](x) # 128@384
478
+ h8 = self.backbone.stages[1](h4) # 256@192
479
+ h16 = self.backbone.stages[2](h8) # 512@96
480
+ h32 = self.backbone.stages[3](h16) # 1024@48
481
+ h64 = self.down_conv1(h32) # 1024@24
482
+ h128 = self.down_conv2(h64) # 1024@12
483
+
484
+ up128 = self.upconv1(h128)
485
+ up64 = self.upconv2(torch.cat([up128, h64], dim = 1))
486
+ up32 = self.upconv3(torch.cat([up64, h32], dim = 1))
487
+ up16 = self.upconv4(torch.cat([up32, h16], dim = 1))
488
+ up8 = self.upconv5(torch.cat([up16, h8], dim = 1))
489
+ up4 = self.upconv6(torch.cat([up8, h4], dim = 1))
490
+
491
+ return self.conv_db(up8), self.conv_mask(up4)
492
+
493
+ import os
494
+ from .default_utils import imgproc, dbnet_utils, craft_utils
495
+ from .common import OfflineDetector
496
+ from ..utils import TextBlock, Quadrilateral, det_rearrange_forward
497
+
498
+ MODEL = None
499
+ def det_batch_forward_default(batch: np.ndarray, device: str):
500
+ global MODEL
501
+ if isinstance(batch, list):
502
+ batch = np.array(batch)
503
+ batch = einops.rearrange(batch.astype(np.float32) / 127.5 - 1.0, 'n h w c -> n c h w')
504
+ batch = torch.from_numpy(batch).to(device)
505
+ with torch.no_grad():
506
+ db, mask = MODEL(batch)
507
+ db = db.sigmoid().cpu().numpy()
508
+ mask = mask.cpu().numpy()
509
+ return db, mask
510
+
511
+
512
+ class DBConvNextDetector(OfflineDetector):
513
+ _MODEL_MAPPING = {
514
+ 'model': {
515
+ 'url': '',
516
+ 'hash': '',
517
+ 'file': '.',
518
+ }
519
+ }
520
+
521
+ def __init__(self, *args, **kwargs):
522
+ os.makedirs(self.model_dir, exist_ok=True)
523
+ if os.path.exists('dbnet_convnext.ckpt'):
524
+ shutil.move('dbnet_convnext.ckpt', self._get_file_path('dbnet_convnext.ckpt'))
525
+ super().__init__(*args, **kwargs)
526
+
527
+ async def _load(self, device: str):
528
+ self.model = DBNetConvNext()
529
+ sd = torch.load(self._get_file_path('dbnet_convnext.ckpt'), map_location='cpu')
530
+ self.model.load_state_dict(sd['model'] if 'model' in sd else sd)
531
+ self.model.eval()
532
+ self.device = device
533
+ if device == 'cuda' or device == 'mps':
534
+ self.model = self.model.to(self.device)
535
+ global MODEL
536
+ MODEL = self.model
537
+
538
+ async def _unload(self):
539
+ del self.model
540
+
541
+ async def _infer(self, image: np.ndarray, detect_size: int, text_threshold: float, box_threshold: float,
542
+ unclip_ratio: float, verbose: bool = False):
543
+
544
+ # TODO: Move det_rearrange_forward to common.py and refactor
545
+ db, mask = det_rearrange_forward(image, det_batch_forward_default, detect_size, 4, device=self.device, verbose=verbose)
546
+
547
+ if db is None:
548
+ # rearrangement is not required, fallback to default forward
549
+ img_resized, target_ratio, _, pad_w, pad_h = imgproc.resize_aspect_ratio(cv2.bilateralFilter(image, 17, 80, 80), detect_size, cv2.INTER_LINEAR, mag_ratio = 1)
550
+ img_resized_h, img_resized_w = img_resized.shape[:2]
551
+ ratio_h = ratio_w = 1 / target_ratio
552
+ db, mask = det_batch_forward_default([img_resized], self.device)
553
+ else:
554
+ img_resized_h, img_resized_w = image.shape[:2]
555
+ ratio_w = ratio_h = 1
556
+ pad_h = pad_w = 0
557
+ self.logger.info(f'Detection resolution: {img_resized_w}x{img_resized_h}')
558
+
559
+ mask = mask[0, 0, :, :]
560
+ det = dbnet_utils.SegDetectorRepresenter(text_threshold, box_threshold, unclip_ratio=unclip_ratio)
561
+ # boxes, scores = det({'shape': [(img_resized.shape[0], img_resized.shape[1])]}, db)
562
+ boxes, scores = det({'shape':[(img_resized_h, img_resized_w)]}, db)
563
+ boxes, scores = boxes[0], scores[0]
564
+ if boxes.size == 0:
565
+ polys = []
566
+ else:
567
+ idx = boxes.reshape(boxes.shape[0], -1).sum(axis=1) > 0
568
+ polys, _ = boxes[idx], scores[idx]
569
+ polys = polys.astype(np.float64)
570
+ polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net=1)
571
+ polys = polys.astype(np.int16)
572
+
573
+ textlines = [Quadrilateral(pts.astype(int), '', score) for pts, score in zip(polys, scores)]
574
+ textlines = list(filter(lambda q: q.area > 16, textlines))
575
+ mask_resized = cv2.resize(mask, (mask.shape[1] * 2, mask.shape[0] * 2), interpolation=cv2.INTER_LINEAR)
576
+ if pad_h > 0:
577
+ mask_resized = mask_resized[:-pad_h, :]
578
+ elif pad_w > 0:
579
+ mask_resized = mask_resized[:, :-pad_w]
580
+ raw_mask = np.clip(mask_resized * 255, 0, 255).astype(np.uint8)
581
+
582
+ # if verbose:
583
+ # img_bbox_raw = np.copy(image)
584
+ # for txtln in textlines:
585
+ # cv2.polylines(img_bbox_raw, [txtln.pts], True, color=(255, 0, 0), thickness=2)
586
+ # cv2.imwrite(f'result/bboxes_unfiltered.png', cv2.cvtColor(img_bbox_raw, cv2.COLOR_RGB2BGR))
587
+
588
+ return textlines, raw_mask, None
589
+
590
+
591
+ if __name__ == '__main__' :
592
+ net = DBNetConvNext().cuda()
593
+ img = torch.randn(2, 3, 1536, 1536).cuda()
594
+ ret1, ret2 = net.forward(img)
595
+ print(ret1.shape)
596
+ print(ret2.shape)
manga_translator/detection/default.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import numpy as np
4
+ import torch
5
+ import cv2
6
+ import einops
7
+ from typing import List, Tuple
8
+
9
+ from .default_utils.DBNet_resnet34 import TextDetection as TextDetectionDefault
10
+ from .default_utils import imgproc, dbnet_utils, craft_utils
11
+ from .common import OfflineDetector
12
+ from ..utils import TextBlock, Quadrilateral, det_rearrange_forward
13
+
14
+ MODEL = None
15
+ def det_batch_forward_default(batch: np.ndarray, device: str):
16
+ global MODEL
17
+ if isinstance(batch, list):
18
+ batch = np.array(batch)
19
+ batch = einops.rearrange(batch.astype(np.float32) / 127.5 - 1.0, 'n h w c -> n c h w')
20
+ batch = torch.from_numpy(batch).to(device)
21
+ with torch.no_grad():
22
+ db, mask = MODEL(batch)
23
+ db = db.sigmoid().cpu().numpy()
24
+ mask = mask.cpu().numpy()
25
+ return db, mask
26
+
27
+ class DefaultDetector(OfflineDetector):
28
+ _MODEL_MAPPING = {
29
+ 'model': {
30
+ 'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/detect.ckpt',
31
+ 'hash': '69080aea78de0803092bc6b751ae283ca463011de5f07e1d20e6491b05571a30',
32
+ 'file': '.',
33
+ }
34
+ }
35
+
36
+ def __init__(self, *args, **kwargs):
37
+ os.makedirs(self.model_dir, exist_ok=True)
38
+ if os.path.exists('detect.ckpt'):
39
+ shutil.move('detect.ckpt', self._get_file_path('detect.ckpt'))
40
+ super().__init__(*args, **kwargs)
41
+
42
+ async def _load(self, device: str):
43
+ self.model = TextDetectionDefault()
44
+ sd = torch.load(self._get_file_path('detect.ckpt'), map_location='cpu')
45
+ self.model.load_state_dict(sd['model'] if 'model' in sd else sd)
46
+ self.model.eval()
47
+ self.device = device
48
+ if device == 'cuda' or device == 'mps':
49
+ self.model = self.model.to(self.device)
50
+ global MODEL
51
+ MODEL = self.model
52
+
53
+ async def _unload(self):
54
+ del self.model
55
+
56
+ async def _infer(self, image: np.ndarray, detect_size: int, text_threshold: float, box_threshold: float,
57
+ unclip_ratio: float, verbose: bool = False):
58
+
59
+ # TODO: Move det_rearrange_forward to common.py and refactor
60
+ db, mask = det_rearrange_forward(image, det_batch_forward_default, detect_size, 4, device=self.device, verbose=verbose)
61
+
62
+ if db is None:
63
+ # rearrangement is not required, fallback to default forward
64
+ img_resized, target_ratio, _, pad_w, pad_h = imgproc.resize_aspect_ratio(cv2.bilateralFilter(image, 17, 80, 80), detect_size, cv2.INTER_LINEAR, mag_ratio = 1)
65
+ img_resized_h, img_resized_w = img_resized.shape[:2]
66
+ ratio_h = ratio_w = 1 / target_ratio
67
+ db, mask = det_batch_forward_default([img_resized], self.device)
68
+ else:
69
+ img_resized_h, img_resized_w = image.shape[:2]
70
+ ratio_w = ratio_h = 1
71
+ pad_h = pad_w = 0
72
+ self.logger.info(f'Detection resolution: {img_resized_w}x{img_resized_h}')
73
+
74
+ mask = mask[0, 0, :, :]
75
+ det = dbnet_utils.SegDetectorRepresenter(text_threshold, box_threshold, unclip_ratio=unclip_ratio)
76
+ # boxes, scores = det({'shape': [(img_resized.shape[0], img_resized.shape[1])]}, db)
77
+ boxes, scores = det({'shape':[(img_resized_h, img_resized_w)]}, db)
78
+ boxes, scores = boxes[0], scores[0]
79
+ if boxes.size == 0:
80
+ polys = []
81
+ else:
82
+ idx = boxes.reshape(boxes.shape[0], -1).sum(axis=1) > 0
83
+ polys, _ = boxes[idx], scores[idx]
84
+ polys = polys.astype(np.float64)
85
+ polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net=1)
86
+ polys = polys.astype(np.int16)
87
+
88
+ textlines = [Quadrilateral(pts.astype(int), '', score) for pts, score in zip(polys, scores)]
89
+ textlines = list(filter(lambda q: q.area > 16, textlines))
90
+ mask_resized = cv2.resize(mask, (mask.shape[1] * 2, mask.shape[0] * 2), interpolation=cv2.INTER_LINEAR)
91
+ if pad_h > 0:
92
+ mask_resized = mask_resized[:-pad_h, :]
93
+ elif pad_w > 0:
94
+ mask_resized = mask_resized[:, :-pad_w]
95
+ raw_mask = np.clip(mask_resized * 255, 0, 255).astype(np.uint8)
96
+
97
+ # if verbose:
98
+ # img_bbox_raw = np.copy(image)
99
+ # for txtln in textlines:
100
+ # cv2.polylines(img_bbox_raw, [txtln.pts], True, color=(255, 0, 0), thickness=2)
101
+ # cv2.imwrite(f'result/bboxes_unfiltered.png', cv2.cvtColor(img_bbox_raw, cv2.COLOR_RGB2BGR))
102
+
103
+ return textlines, raw_mask, None
manga_translator/detection/default_utils/CRAFT_resnet34.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.nn.init as init
6
+
7
+ from torchvision.models import resnet34
8
+
9
+ import einops
10
+ import math
11
+
12
+ class ImageMultiheadSelfAttention(nn.Module):
13
+ def __init__(self, planes):
14
+ super(ImageMultiheadSelfAttention, self).__init__()
15
+ self.attn = nn.MultiheadAttention(planes, 4)
16
+ def forward(self, x):
17
+ res = x
18
+ n, c, h, w = x.shape
19
+ x = einops.rearrange(x, 'n c h w -> (h w) n c')
20
+ x = self.attn(x, x, x)[0]
21
+ x = einops.rearrange(x, '(h w) n c -> n c h w', n = n, c = c, h = h, w = w)
22
+ return res + x
23
+
24
+ class double_conv(nn.Module):
25
+ def __init__(self, in_ch, mid_ch, out_ch, stride = 1, planes = 256):
26
+ super(double_conv, self).__init__()
27
+ self.planes = planes
28
+ # down = None
29
+ # if stride > 1:
30
+ # down = nn.Sequential(
31
+ # nn.AvgPool2d(2, 2),
32
+ # nn.Conv2d(in_ch + mid_ch, self.planes * Bottleneck.expansion, kernel_size=1, stride=1, bias=False),nn.BatchNorm2d(self.planes * Bottleneck.expansion)
33
+ # )
34
+ self.down = None
35
+ if stride > 1:
36
+ self.down = nn.AvgPool2d(2,stride=2)
37
+ self.conv = nn.Sequential(
38
+ nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=3, padding=1, stride = 1, bias=False),
39
+ nn.BatchNorm2d(mid_ch),
40
+ nn.ReLU(inplace=True),
41
+ #Bottleneck(mid_ch, self.planes, stride, down, 2, 1, avd = True, norm_layer = nn.BatchNorm2d),
42
+ nn.Conv2d(mid_ch, out_ch, kernel_size=3, stride = 1, padding=1, bias=False),
43
+ nn.BatchNorm2d(out_ch),
44
+ nn.ReLU(inplace=True),
45
+ )
46
+
47
+ def forward(self, x):
48
+ if self.down is not None:
49
+ x = self.down(x)
50
+ x = self.conv(x)
51
+ return x
52
+
53
+ class CRAFT_net(nn.Module):
54
+ def __init__(self):
55
+ super(CRAFT_net, self).__init__()
56
+ self.backbone = resnet34()
57
+
58
+ self.conv_rs = nn.Sequential(
59
+ nn.Conv2d(64, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
60
+ nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
61
+ nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
62
+ nn.Conv2d(32, 32, kernel_size=1), nn.ReLU(inplace=True),
63
+ nn.Conv2d(32, 1, kernel_size=1),
64
+ nn.Sigmoid()
65
+ )
66
+
67
+ self.conv_as = nn.Sequential(
68
+ nn.Conv2d(64, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
69
+ nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
70
+ nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
71
+ nn.Conv2d(32, 32, kernel_size=1), nn.ReLU(inplace=True),
72
+ nn.Conv2d(32, 1, kernel_size=1),
73
+ nn.Sigmoid()
74
+ )
75
+
76
+ self.conv_mask = nn.Sequential(
77
+ nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True),
78
+ nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True),
79
+ nn.Conv2d(64, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
80
+ nn.Conv2d(32, 1, kernel_size=1),
81
+ nn.Sigmoid()
82
+ )
83
+
84
+ self.down_conv1 = double_conv(0, 512, 512, 2)
85
+ self.down_conv2 = double_conv(0, 512, 512, 2)
86
+ self.down_conv3 = double_conv(0, 512, 512, 2)
87
+
88
+ self.upconv1 = double_conv(0, 512, 256)
89
+ self.upconv2 = double_conv(256, 512, 256)
90
+ self.upconv3 = double_conv(256, 512, 256)
91
+ self.upconv4 = double_conv(256, 512, 256, planes = 128)
92
+ self.upconv5 = double_conv(256, 256, 128, planes = 64)
93
+ self.upconv6 = double_conv(128, 128, 64, planes = 32)
94
+ self.upconv7 = double_conv(64, 64, 64, planes = 16)
95
+
96
+ def forward_train(self, x):
97
+ x = self.backbone.conv1(x)
98
+ x = self.backbone.bn1(x)
99
+ x = self.backbone.relu(x)
100
+ x = self.backbone.maxpool(x) # 64@384
101
+
102
+ h4 = self.backbone.layer1(x) # 64@384
103
+ h8 = self.backbone.layer2(h4) # 128@192
104
+ h16 = self.backbone.layer3(h8) # 256@96
105
+ h32 = self.backbone.layer4(h16) # 512@48
106
+ h64 = self.down_conv1(h32) # 512@24
107
+ h128 = self.down_conv2(h64) # 512@12
108
+ h256 = self.down_conv3(h128) # 512@6
109
+
110
+ up256 = F.interpolate(self.upconv1(h256), scale_factor = (2, 2), mode = 'bilinear', align_corners = False) # 512@12
111
+ up128 = F.interpolate(self.upconv2(torch.cat([up256, h128], dim = 1)), scale_factor = (2, 2), mode = 'bilinear', align_corners = False) #51264@24
112
+ up64 = F.interpolate(self.upconv3(torch.cat([up128, h64], dim = 1)), scale_factor = (2, 2), mode = 'bilinear', align_corners = False) # 256@48
113
+ up32 = F.interpolate(self.upconv4(torch.cat([up64, h32], dim = 1)), scale_factor = (2, 2), mode = 'bilinear', align_corners = False) # 256@96
114
+ up16 = F.interpolate(self.upconv5(torch.cat([up32, h16], dim = 1)), scale_factor = (2, 2), mode = 'bilinear', align_corners = False) # 128@192
115
+ up8 = F.interpolate(self.upconv6(torch.cat([up16, h8], dim = 1)), scale_factor = (2, 2), mode = 'bilinear', align_corners = False) # 64@384
116
+ up4 = F.interpolate(self.upconv7(torch.cat([up8, h4], dim = 1)), scale_factor = (2, 2), mode = 'bilinear', align_corners = False) # 64@768
117
+
118
+ ascore = self.conv_as(up4)
119
+ rscore = self.conv_rs(up4)
120
+
121
+ return torch.cat([rscore, ascore], dim = 1), self.conv_mask(up4)
122
+
123
+ def forward(self, x):
124
+ x = self.backbone.conv1(x)
125
+ x = self.backbone.bn1(x)
126
+ x = self.backbone.relu(x)
127
+ x = self.backbone.maxpool(x) # 64@384
128
+
129
+ h4 = self.backbone.layer1(x) # 64@384
130
+ h8 = self.backbone.layer2(h4) # 128@192
131
+ h16 = self.backbone.layer3(h8) # 256@96
132
+ h32 = self.backbone.layer4(h16) # 512@48
133
+ h64 = self.down_conv1(h32) # 512@24
134
+ h128 = self.down_conv2(h64) # 512@12
135
+ h256 = self.down_conv3(h128) # 512@6
136
+
137
+ up256 = F.interpolate(self.upconv1(h256), scale_factor = (2, 2), mode = 'bilinear', align_corners = False) # 512@12
138
+ up128 = F.interpolate(self.upconv2(torch.cat([up256, h128], dim = 1)), scale_factor = (2, 2), mode = 'bilinear', align_corners = False) #51264@24
139
+ up64 = F.interpolate(self.upconv3(torch.cat([up128, h64], dim = 1)), scale_factor = (2, 2), mode = 'bilinear', align_corners = False) # 256@48
140
+ up32 = F.interpolate(self.upconv4(torch.cat([up64, h32], dim = 1)), scale_factor = (2, 2), mode = 'bilinear', align_corners = False) # 256@96
141
+ up16 = F.interpolate(self.upconv5(torch.cat([up32, h16], dim = 1)), scale_factor = (2, 2), mode = 'bilinear', align_corners = False) # 128@192
142
+ up8 = F.interpolate(self.upconv6(torch.cat([up16, h8], dim = 1)), scale_factor = (2, 2), mode = 'bilinear', align_corners = False) # 64@384
143
+ up4 = F.interpolate(self.upconv7(torch.cat([up8, h4], dim = 1)), scale_factor = (2, 2), mode = 'bilinear', align_corners = False) # 64@768
144
+
145
+ ascore = self.conv_as(up4)
146
+ rscore = self.conv_rs(up4)
147
+
148
+ return torch.cat([rscore, ascore], dim = 1), self.conv_mask(up4)
149
+
150
+ if __name__ == '__main__':
151
+ net = CRAFT_net().cuda()
152
+ img = torch.randn(2, 3, 1536, 1536).cuda()
153
+ print(net.forward_train(img)[0].shape)