DmitrMakeev committed
Commit 39946a9
1 Parent(s): a3e6b3a
Upload 9 files
Files changed:
- .gitattributes +33 -0
- LICENSE +21 -0
- README.md +13 -0
- app.py +91 -0
- cog.yaml +41 -0
- environment.yml +121 -0
- inference.py +105 -0
- predict.py +104 -0
- requirements.txt +17 -0
.gitattributes
ADDED
@@ -0,0 +1,33 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2022 Menghan Xia

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
ADDED
@@ -0,0 +1,13 @@
---
title: colorizator
emoji: 🐢
colorFrom: pink
colorTo: indigo
sdk: gradio
sdk_version: 3.9
app_file: app.py
pinned: false
license: openrail
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,91 @@
import gradio as gr
import os, requests
import numpy as np
from inference import setup_model, colorize_grayscale, predict_anchors

## local | remote
RUN_MODE = "remote"
if RUN_MODE != "local":
    os.makedirs("checkpoints", exist_ok=True)  # os.rename below fails if the target dir is missing
    os.system("wget https://huggingface.co/menghanxia/disco/resolve/main/disco-beta.pth.rar")
    os.rename("disco-beta.pth.rar", "./checkpoints/disco-beta.pth.rar")
    ## examples
    os.system("wget https://huggingface.co/menghanxia/disco/resolve/main/01.jpg")
    os.system("wget https://huggingface.co/menghanxia/disco/resolve/main/02.jpg")
    os.system("wget https://huggingface.co/menghanxia/disco/resolve/main/03.jpg")
    os.system("wget https://huggingface.co/menghanxia/disco/resolve/main/04.jpg")

## step 1: set up model
device = "cpu"
checkpt_path = "checkpoints/disco-beta.pth.rar"
colorizer, colorLabeler = setup_model(checkpt_path, device=device)

def click_colorize(rgb_img, hint_img, n_anchors, is_high_res, is_editable):
    if hint_img is None:
        hint_img = rgb_img
    ## the radio choice arrives as is_high_res, but both resolutions are rendered
    output = colorize_grayscale(colorizer, colorLabeler, rgb_img, hint_img, n_anchors, True, is_editable, device)
    output1 = colorize_grayscale(colorizer, colorLabeler, rgb_img, hint_img, n_anchors, False, is_editable, device)
    return output, output1

def click_predanchors(rgb_img, n_anchors, is_high_res, is_editable):
    output = predict_anchors(colorizer, colorLabeler, rgb_img, n_anchors, is_high_res, is_editable, device)
    return output

## step 2: configure interface
def switch_states(is_checked):
    if is_checked:
        return gr.Image.update(visible=True), gr.Button.update(visible=True)
    else:
        return gr.Image.update(visible=False), gr.Button.update(visible=False)

demo = gr.Blocks(title="DISCO")
with demo:
    gr.Markdown(value="""
        **Gradio demo for DISCO: Disentangled Image Colorization via Global Anchors**. Check our [project page](https://menghanxia.github.io/projects/disco.html) 😛.
        """)
    with gr.Row():
        with gr.Column():
            with gr.Row():
                Image_input = gr.Image(type="numpy", label="Input", interactive=True)
                Image_anchor = gr.Image(type="numpy", label="Anchor", tool="color-sketch", interactive=True, visible=False)
            with gr.Row():
                Num_anchor = gr.Number(type="int", value=8, label="Num. of anchors (3~14)")
                Radio_resolution = gr.Radio(type="index", choices=["Low (256x256)", "High (512x512)"], \
                                            label="Colorization resolution (Low is more stable)", value="Low (256x256)")
            with gr.Row():
                Checkbox_editable = gr.Checkbox(value=False, label='Show editable anchors')  # gradio 3.x takes value= (the 2.x name was default=)
                Button_show_anchor = gr.Button(value="Predict anchors", visible=False)
            Button_run = gr.Button(value="Colorize")
        with gr.Column():
            Image_output = [gr.Image(type="numpy", label="Output").style(height=480), gr.Image(type="numpy", label="Output").style(height=480)]

    Checkbox_editable.change(fn=switch_states, inputs=Checkbox_editable, outputs=[Image_anchor, Button_show_anchor])
    Button_show_anchor.click(fn=click_predanchors, inputs=[Image_input, Num_anchor, Radio_resolution, Checkbox_editable], outputs=Image_anchor)
    Button_run.click(fn=click_colorize, inputs=[Image_input, Image_anchor, Num_anchor, Radio_resolution, Checkbox_editable], \
                     outputs=Image_output)

    ## guideline
    gr.Markdown(value="""
        🔔**Guideline**
        1. Upload your image or select one from the examples.
        2. Set up the arguments: "Num. of anchors" and "Colorization resolution".
        3. Run the colorization (two modes supported):
            - 📀Automatic mode: **Click** "Colorize" to get the automatically colorized output.
            - ✏️Editable mode: **Check** "Show editable anchors"; **Click** "Predict anchors"; **Redraw** the anchor colors (only the anchor regions will be used); **Click** "Colorize" to get the result.
        """)
    if RUN_MODE != "local":
        gr.Examples(examples=[
                    ['01.jpg', 8, "Low (256x256)"],
                    ['02.jpg', 8, "Low (256x256)"],
                    ['03.jpg', 8, "Low (256x256)"],
                    ['04.jpg', 8, "Low (256x256)"],
                    ],
                    inputs=[Image_input, Num_anchor, Radio_resolution], outputs=[Image_output], label="Examples", cache_examples=False)
    gr.HTML(value="""
        <p style="text-align:center; color:orange"><a href='https://menghanxia.github.io/projects/disco.html' target='_blank'>DISCO Project Page</a> | <a href='https://github.com/MenghanXia/DisentangledColorization' target='_blank'>Github Repo</a></p>
        """)

if RUN_MODE == "local":
    demo.launch(server_name='9.134.253.83', server_port=7788)
else:
    demo.queue(default_enabled=True, status_update_rate=5)
    demo.launch()
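Note that the checkpoint keeps its .rar suffix only as a naming quirk: setup_model passes the path straight to torch.load, so nothing is extracted. The Gradio layer above is also optional, since everything runs through colorize_grayscale from inference.py. A minimal headless sketch of the same call path (assuming the models/ and utils/ packages from the DISCO repo are importable, the checkpoint was downloaded as above, and a local test image input.jpg exists; the file names are hypothetical):

# headless sketch, not part of this upload
import cv2
from inference import setup_model, colorize_grayscale

colorizer, colorLabeler = setup_model("checkpoints/disco-beta.pth.rar", device="cpu")
bgr = cv2.imread("input.jpg", cv2.IMREAD_COLOR)   # hypothetical test image
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)        # the inference helpers expect RGB arrays
# hint_img == rgb_img with is_editable=False reproduces the app's automatic mode
out = colorize_grayscale(colorizer, colorLabeler, rgb, rgb, 8, False, False, "cpu")
cv2.imwrite("output.png", cv2.cvtColor(out, cv2.COLOR_RGB2BGR))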
cog.yaml
ADDED
@@ -0,0 +1,41 @@
# Configuration for Cog ⚙️
# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md

build:
  # set to true if your model requires a GPU
  cuda: "10.2"
  gpu: true

  # a list of ubuntu apt packages to install
  system_packages:
    # - "libgl1-mesa-glx"
    # - "libglib2.0-0"
    - "libgl1-mesa-dev"

  # python version in the form '3.8' or '3.8.12'
  python_version: "3.8"

  # a list of packages in the format <package-name>==<version>
  python_packages:
    # - "numpy==1.19.4"
    # - "torch==1.8.0"
    # - "torchvision==0.9.0"
    - "numpy==1.23.1"
    - "torch==1.8.0"
    - "torchvision==0.9.0"
    - "opencv-python==4.6.0.66"
    - "pandas==1.4.3"
    - "pillow==9.2.0"
    - "tqdm==4.64.0"
    - "scikit-image==0.19.3"
    - "scikit-learn==1.1.2"
    - "scipy==1.9.1"

  # commands run after the environment is setup
  # run:
  #   - "echo env is ready!"
  #   - "echo another command if needed"

# predict.py defines how predictions are run on your model
predict: "predict.py:Predictor"
#image: "r8.im/menghanxia/disco"
environment.yml
ADDED
@@ -0,0 +1,121 @@
name: DISCO
channels:
  - pytorch
  - defaults
  - conda-forge
dependencies:
  - blas=1.0=mkl
  - bzip2=1.0.8=h7b6447c_0
  - ca-certificates=2022.07.19=h06a4308_0
  - certifi=2022.6.15=py38h06a4308_0
  - cudatoolkit=10.2.89=hfd86e86_1
  - freetype=2.11.0=h70c0345_0
  - giflib=5.2.1=h7b6447c_0
  - gmp=6.2.1=h295c915_3
  - gnutls=3.6.15=he1e5248_0
  - intel-openmp=2021.4.0=h06a4308_3561
  - jpeg=9b=h024ee3a_2
  - lame=3.100=h7b6447c_0
  - lcms2=2.12=h3be6417_0
  - ld_impl_linux-64=2.38=h1181459_1
  - libffi=3.3=he6710b0_2
  - libgcc-ng=11.2.0=h1234567_1
  - libiconv=1.16=h7f8727e_2
  - libidn2=2.3.2=h7f8727e_0
  - libpng=1.6.37=hbc83047_0
  - libstdcxx-ng=11.2.0=h1234567_1
  - libtasn1=4.16.0=h27cfd23_0
  - libtiff=4.1.0=h2733197_1
  - libunistring=0.9.10=h27cfd23_0
  - libuv=1.40.0=h7b6447c_0
  - libwebp=1.2.0=h89dd481_0
  - lz4-c=1.9.3=h295c915_1
  - mkl=2021.4.0=h06a4308_640
  - mkl-service=2.4.0=py38h7f8727e_0
  - mkl_fft=1.3.1=py38hd3c417c_0
  - mkl_random=1.2.2=py38h51133e4_0
  - ncurses=6.3=h5eee18b_3
  - nettle=3.7.3=hbbd107a_1
  - ninja=1.10.2=h06a4308_5
  - ninja-base=1.10.2=hd09550d_5
  - numpy=1.23.1=py38h6c91a56_0
  - numpy-base=1.23.1=py38ha15fc14_0
  - openh264=2.1.1=h4ff587b_0
  - openssl=1.1.1q=h7f8727e_0
  - pillow=9.2.0=py38hace64e9_1
  - pip=22.1.2=py38h06a4308_0
  - python=3.8.13=h12debd9_0
  - readline=8.1.2=h7f8727e_1
  - setuptools=63.4.1=py38h06a4308_0
  - six=1.16.0=pyhd3eb1b0_1
  - sqlite=3.39.2=h5082296_0
  - tk=8.6.12=h1ccaba5_0
  - typing_extensions=4.3.0=py38h06a4308_0
  - wheel=0.37.1=pyhd3eb1b0_0
  - xz=5.2.5=h7f8727e_1
  - zlib=1.2.12=h7f8727e_2
  - zstd=1.4.9=haebb681_0
  - ffmpeg=4.3=hf484d3e_0
  - pytorch=1.8.0=py3.8_cuda10.2_cudnn7.6.5_0
  - torchaudio=0.8.0=py38
  - torchvision=0.9.0=py38_cu102
  - pip:
    - addict==2.4.0
    - astunparse==1.6.3
    - cachetools==4.2.4
    - charset-normalizer==2.0.7
    - clang==5.0
    - cycler==0.11.0
    - flatbuffers==1.12
    - fonttools==4.37.1
    - future==0.18.2
    - gast==0.4.0
    - google-auth==2.3.2
    - google-auth-oauthlib==0.4.6
    - google-pasta==0.2.0
    - grpcio==1.41.1
    - h5py==3.1.0
    - idna==3.3
    - imageio==2.21.1
    - joblib==1.1.0
    - keras==2.6.0
    - keras-preprocessing==1.1.2
    - kiwisolver==1.4.4
    - lpips==0.1.4
    - markdown==3.3.4
    - matplotlib==3.5.3
    - networkx==2.8.6
    - oauthlib==3.1.1
    - opencv-python==4.6.0.66
    - opt-einsum==3.3.0
    - packaging==21.3
    - pandas==1.4.3
    - protobuf==3.19.0
    - pyasn1==0.4.8
    - pyasn1-modules==0.2.8
    - pyparsing==3.0.9
    - python-dateutil==2.8.2
    - pytz==2022.2.1
    - pywavelets==1.3.0
    - pyyaml==6.0
    - requests==2.26.0
    - requests-oauthlib==1.3.0
    - rsa==4.7.2
    - scikit-image==0.19.3
    - scikit-learn==1.1.2
    - scipy==1.9.1
    - tensorboard-data-server==0.6.1
    - tensorboard-plugin-wit==1.8.0
    - tensorflow-estimator==2.6.0
    - tensorflow-gpu==2.6.0
    - termcolor==1.1.0
    - threadpoolctl==3.1.0
    - tifffile==2022.8.12
    - torch==1.8.0
    - tqdm==4.64.0
    - urllib3==1.26.7
    - werkzeug==2.0.2
    - wrapt==1.12.1
    - yapf==0.32.0
prefix: /root/data/programs/anaconda3/envs/DISCO
inference.py
ADDED
@@ -0,0 +1,105 @@
import os, glob, sys, logging
import argparse, datetime, time
import numpy as np
import cv2
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from models import model, basic
from utils import util


def setup_model(checkpt_path, device="cuda"):
    """Load the model into memory to make running multiple predictions efficient"""
    #print('--------------', torch.cuda.is_available())
    colorLabeler = basic.ColorLabel(device=device)
    colorizer = model.AnchorColorProb(inChannel=1, outChannel=313, enhanced=True, colorLabeler=colorLabeler)
    colorizer = colorizer.to(device)
    #checkpt_path = "./checkpoints/disco-beta.pth.rar"
    assert os.path.exists(checkpt_path), "No checkpoint found!"
    data_dict = torch.load(checkpt_path, map_location=torch.device('cpu'))
    colorizer.load_state_dict(data_dict['state_dict'])
    colorizer.eval()
    return colorizer, colorLabeler


def resize_ab2l(gray_img, lab_imgs, vis=False):
    ## upsample the predicted ab channels back to the original resolution
    ## and recombine them with the full-resolution luminance channel
    H, W = gray_img.shape[:2]
    resized_ab = cv2.resize(lab_imgs[:,:,1:], (W,H), interpolation=cv2.INTER_LINEAR)
    if vis:
        gray_img = cv2.resize(lab_imgs[:,:,:1], (W,H), interpolation=cv2.INTER_LINEAR)
        return np.concatenate((gray_img[:,:,np.newaxis], resized_ab), axis=2)
    else:
        return np.concatenate((gray_img, resized_ab), axis=2)

def prepare_data(rgb_img, target_res):
    ## normalize: L in [0,100] -> [-1,1], ab in about [-110,110] -> [-1,1]
    rgb_img = np.array(rgb_img / 255., np.float32)
    lab_img = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2LAB)
    org_grays = (lab_img[:,:,[0]]-50.) / 50.
    lab_img = cv2.resize(lab_img, target_res, interpolation=cv2.INTER_LINEAR)

    lab_img = torch.from_numpy(lab_img.transpose((2, 0, 1)))
    gray_img = (lab_img[0:1,:,:]-50.) / 50.
    ab_chans = lab_img[1:3,:,:] / 110.
    input_grays = gray_img.unsqueeze(0)
    input_colors = ab_chans.unsqueeze(0)
    return input_grays, input_colors, org_grays


def colorize_grayscale(colorizer, color_class, rgb_img, hint_img, n_anchors, is_high_res, is_editable, device="cuda"):
    n_anchors = int(n_anchors)
    n_anchors = max(n_anchors, 3)
    n_anchors = min(n_anchors, 14)
    target_res = (512,512) if is_high_res else (256,256)
    input_grays, input_colors, org_grays = prepare_data(rgb_img, target_res)
    input_grays = input_grays.to(device)
    input_colors = input_colors.to(device)

    if is_editable:
        print('>>>:editable mode')
        sampled_T = -1
        ## take the anchor colors from the user-drawn hint image
        _, input_colors, _ = prepare_data(hint_img, target_res)
        input_colors = input_colors.to(device)
        pal_logit, ref_logit, enhanced_ab, affinity_map, spix_colors, hint_mask = colorizer(input_grays, \
            input_colors, n_anchors, sampled_T)
    else:
        print('>>>:automatic mode')
        sampled_T = 0
        pal_logit, ref_logit, enhanced_ab, affinity_map, spix_colors, hint_mask = colorizer(input_grays, \
            input_colors, n_anchors, sampled_T)

    pred_labs = torch.cat((input_grays, enhanced_ab), dim=1)
    lab_imgs = basic.tensor2array(pred_labs).squeeze(axis=0)
    lab_imgs = resize_ab2l(org_grays, lab_imgs)

    ## denormalize and convert Lab -> RGB
    lab_imgs[:,:,0] = lab_imgs[:,:,0] * 50.0 + 50.0
    lab_imgs[:,:,1:3] = lab_imgs[:,:,1:3] * 110.0
    rgb_output = cv2.cvtColor(lab_imgs[:,:,:], cv2.COLOR_LAB2RGB)
    return (rgb_output*255.0).astype(np.uint8)


def predict_anchors(colorizer, color_class, rgb_img, n_anchors, is_high_res, is_editable, device="cuda"):
    n_anchors = int(n_anchors)
    n_anchors = max(n_anchors, 3)
    n_anchors = min(n_anchors, 14)
    target_res = (512,512) if is_high_res else (256,256)
    input_grays, input_colors, org_grays = prepare_data(rgb_img, target_res)
    input_grays = input_grays.to(device)
    input_colors = input_colors.to(device)

    sampled_T, sp_size = 0, 16
    pal_logit, ref_logit, enhanced_ab, affinity_map, spix_colors, hint_mask = colorizer(input_grays, \
        input_colors, n_anchors, sampled_T)
    pred_probs = pal_logit
    ## paint the predicted anchor colors onto their superpixel regions for visualization
    guided_colors = color_class.decode_ind2ab(ref_logit, T=0)
    guided_colors = basic.upfeat(guided_colors, affinity_map, sp_size, sp_size)
    anchor_masks = basic.upfeat(hint_mask, affinity_map, sp_size, sp_size)
    marked_labs = basic.mark_color_hints(input_grays, guided_colors, anchor_masks, base_ABs=None)
    lab_imgs = basic.tensor2array(marked_labs).squeeze(axis=0)
    lab_imgs = resize_ab2l(org_grays, lab_imgs, vis=True)

    lab_imgs[:,:,0] = lab_imgs[:,:,0] * 50.0 + 50.0
    lab_imgs[:,:,1:3] = lab_imgs[:,:,1:3] * 110.0
    rgb_output = cv2.cvtColor(lab_imgs[:,:,:], cv2.COLOR_LAB2RGB)
    return (rgb_output*255.0).astype(np.uint8)
predict.py
ADDED
@@ -0,0 +1,104 @@
# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md

from cog import BasePredictor, Input, Path
import tempfile
import os, glob
import numpy as np
import cv2
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from models import model, basic
from utils import util

class Predictor(BasePredictor):
    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""
        seed = 130
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        #print('--------------', torch.cuda.is_available())
        self.colorizer = model.AnchorColorProb(inChannel=1, outChannel=313, enhanced=True)
        self.colorizer = self.colorizer.cuda()
        checkpt_path = "./checkpoints/disco-beta.pth.rar"
        assert os.path.exists(checkpt_path), "No checkpoint found!"
        data_dict = torch.load(checkpt_path, map_location=torch.device('cpu'))
        self.colorizer.load_state_dict(data_dict['state_dict'])
        self.colorizer.eval()
        self.color_class = basic.ColorLabel(lambda_=0.5, device='cuda')

    def resize_ab2l(self, gray_img, lab_imgs):
        H, W = gray_img.shape[:2]
        resized_ab = cv2.resize(lab_imgs[:,:,1:], (W,H), interpolation=cv2.INTER_LINEAR)
        return np.concatenate((gray_img, resized_ab), axis=2)

    def predict(
        self,
        image: Path = Input(description="input image. Output will be one or multiple colorized images."),
        n_anchors: int = Input(
            description="number of color anchors", ge=3, le=14, default=8
        ),
        multi_result: bool = Input(
            description="to generate diverse results", default=False
        ),
        vis_anchors: bool = Input(
            description="to visualize the anchor locations", default=False
        )
    ) -> Path:
        """Run a single prediction on the model"""
        bgr_img = cv2.imread(str(image), cv2.IMREAD_COLOR)
        rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
        rgb_img = np.array(rgb_img / 255., np.float32)
        lab_img = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2LAB)
        org_grays = (lab_img[:,:,[0]]-50.) / 50.
        lab_img = cv2.resize(lab_img, (256,256), interpolation=cv2.INTER_LINEAR)

        lab_img = torch.from_numpy(lab_img.transpose((2, 0, 1)))
        gray_img = (lab_img[0:1,:,:]-50.) / 50.
        ab_chans = lab_img[1:3,:,:] / 110.
        input_grays = gray_img.unsqueeze(0)
        input_colors = ab_chans.unsqueeze(0)
        input_grays = input_grays.cuda(non_blocking=True)
        input_colors = input_colors.cuda(non_blocking=True)

        sampled_T = 2 if multi_result else 0
        pal_logit, ref_logit, enhanced_ab, affinity_map, spix_colors, hint_mask = self.colorizer(input_grays, \
            input_colors, n_anchors, True, sampled_T)
        pred_probs = pal_logit
        guided_colors = self.color_class.decode_ind2ab(ref_logit, T=0)
        sp_size = 16
        guided_colors = basic.upfeat(guided_colors, affinity_map, sp_size, sp_size)
        res_list = []
        if multi_result:
            for no in range(3):
                pred_labs = torch.cat((input_grays, enhanced_ab[no:no+1,:,:,:]), dim=1)
                lab_imgs = basic.tensor2array(pred_labs).squeeze(axis=0)
                lab_imgs = self.resize_ab2l(org_grays, lab_imgs)
                #util.save_normLabs_from_batch(lab_imgs, save_dir, [file_name], -1, suffix='c%d'%no)
                res_list.append(lab_imgs)
        else:
            pred_labs = torch.cat((input_grays, enhanced_ab), dim=1)
            lab_imgs = basic.tensor2array(pred_labs).squeeze(axis=0)
            lab_imgs = self.resize_ab2l(org_grays, lab_imgs)
            #util.save_normLabs_from_batch(lab_imgs, save_dir, [file_name], -1)#, suffix='enhanced')
            res_list.append(lab_imgs)

        if vis_anchors:
            ## visualize anchor locations
            anchor_masks = basic.upfeat(hint_mask, affinity_map, sp_size, sp_size)
            marked_labs = basic.mark_color_hints(input_grays, enhanced_ab, anchor_masks, base_ABs=enhanced_ab)
            hint_imgs = basic.tensor2array(marked_labs).squeeze(axis=0)
            hint_imgs = self.resize_ab2l(org_grays, hint_imgs)
            #util.save_normLabs_from_batch(hint_imgs, save_dir, [file_name], -1, suffix='anchors')
            res_list.append(hint_imgs)

        ## stack all results vertically, denormalize, and convert Lab -> BGR for cv2.imwrite
        output = cv2.vconcat(res_list)
        output[:,:,0] = output[:,:,0] * 50.0 + 50.0
        output[:,:,1:3] = output[:,:,1:3] * 110.0
        bgr_output = cv2.cvtColor(output[:,:,:], cv2.COLOR_LAB2BGR)
        out_path = Path(tempfile.mkdtemp()) / "out.png"
        cv2.imwrite(str(out_path), (bgr_output*255.0).astype(np.uint8))
        return out_path
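When multi_result is set, predict stacks three diverse results (plus an optional anchor visualization) into one PNG with cv2.vconcat; every panel keeps the input's original height, so the strip splits back evenly. A sketch of recovering the individual panels (the panel count is an assumption tied to the flags used; out.png stands in for the returned path):

# splitting the stacked output back into panels, a sketch
import cv2

strip = cv2.imread("out.png")            # the stacked PNG returned by Predictor.predict
n_panels = 3                             # 3 when multi_result=True and vis_anchors=False
h = strip.shape[0] // n_panels           # each panel has the original input height
for i in range(n_panels):
    cv2.imwrite(f"panel_{i}.png", strip[i*h:(i+1)*h])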
requirements.txt
ADDED
@@ -0,0 +1,17 @@
addict
future
numpy
opencv-python
pandas
Pillow
pyyaml
requests
scikit-image
scikit-learn
scipy
torch>=1.8.0
torchvision
tensorboardx>=2.4
tqdm
yapf
lpips