File size: 3,652 Bytes
39946a9
 
 
 
 
 
 
 
 
 
 
71aa43a
24bdcbc
39946a9
 
 
 
 
 
 
 
 
5ceb3c0
 
39946a9
 
 
 
 
 
 
 
 
 
 
ad90eb1
c9f3f27
 
5f46008
d9d20bd
5f46008
2024b6d
1566524
 
 
c9f3f27
 
09920f1
c9f3f27
9d7d381
 
 
c9f3f27
4d2ff54
c9f3f27
4b16c0f
c9f3f27
 
39946a9
c9f3f27
 
 
 
 
 
 
09920f1
c9f3f27
 
 
71aa43a
c9f3f27
 
 
 
71aa43a
c9f3f27
39946a9
c9f3f27
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import gradio as gr
import os, requests
import numpy as np
from inference import setup_model, colorize_grayscale, predict_anchors

## RUN_MODE: "local" | "remote"
# In any mode other than "local", the checkpoint and the example image are
# fetched from the Hugging Face Hub before the model is built.
RUN_MODE = "remote"


def _download(url, dest):
    """Download ``url`` to the file ``dest`` unless ``dest`` already exists.

    The response is streamed into ``dest + ".part"`` and only moved into place
    once the whole body has been written. An interrupted or failed download
    therefore never leaves a truncated file at ``dest``. HTTP errors are
    raised by ``raise_for_status`` instead of being silently ignored.
    """
    if os.path.exists(dest):
        return
    dest_dir = os.path.dirname(dest)
    if dest_dir:
        # The target directory (e.g. ./checkpoints) may not exist yet.
        os.makedirs(dest_dir, exist_ok=True)
    tmp_path = dest + ".part"
    with requests.get(url, stream=True, timeout=60) as resp:
        resp.raise_for_status()
        with open(tmp_path, "wb") as fh:
            for chunk in resp.iter_content(chunk_size=1 << 20):
                fh.write(chunk)
    # Atomic rename on the same filesystem: dest is either absent or complete.
    os.replace(tmp_path, dest)


if RUN_MODE != "local":
    _download("https://huggingface.co/menghanxia/disco/resolve/main/disco-beta.pth.rar",
              "./checkpoints/disco-beta.pth.rar")
    ## examples
    _download("https://huggingface.co/menghanxia/disco/resolve/main/04.jpg", "04.jpg")

## step 1: set up model
# Weights are loaded from the same location the remote-mode download writes to
# (./checkpoints/disco-beta.pth.rar), and inference runs on the CPU.
checkpt_path = "checkpoints/disco-beta.pth.rar"
device = "cpu"
colorizer, colorLabeler = setup_model(checkpt_path, device=device)

def click_colorize(rgb_img, hint_img, n_anchors, is_high_res, is_editable):
    """Handler for the run button: colorize ``rgb_img`` with the loaded model.

    Args:
        rgb_img: image from the "Input" component.
        hint_img: image from the "Anchor" editor. When it is None, the input
            image itself is passed as the hint.
        n_anchors: value of the anchor-count number field.
        is_high_res: index of the selected resolution radio option.
        is_editable: state of the editor checkbox.

    Returns:
        Whatever ``colorize_grayscale`` returns; it is shown in the "Output" image.
    """
    hints = rgb_img if hint_img is None else hint_img
    return colorize_grayscale(
        colorizer, colorLabeler, rgb_img, hints,
        n_anchors, is_high_res, is_editable, device,
    )

def click_predanchors(rgb_img, n_anchors, is_high_res, is_editable):
    """Handler for the "Predict anchors" button.

    Forwards the UI values, together with the module-level model objects and
    device, to ``predict_anchors``. The result is displayed in the "Anchor"
    editor image.
    """
    return predict_anchors(
        colorizer, colorLabeler, rgb_img, n_anchors,
        is_high_res, is_editable, device,
    )

## step 2: configure interface
def switch_states(is_checked):
    """Toggle visibility of the anchor editor widgets.

    Returns two gradio update objects, one for an Image and one for a Button.
    Both are visible exactly when ``is_checked`` is truthy.
    """
    visible = bool(is_checked)
    return gr.Image.update(visible=visible), gr.Button.update(visible=visible)

# Build the Blocks UI: inputs and controls in the left column, result on the right.
demo = gr.Blocks(title="DISCO")
with demo:
    # Page heading ("Colorizing a black-and-white picture").
    gr.HTML(value="""
                        <div style="text-align:center; font-size: 32px;">Раскрашивание черно-белой картинки</div>
                    """)

    with gr.Row():
        with gr.Column():
            with gr.Row():
                Image_input = gr.Image(type="numpy", label="Input", interactive=True)
                # Sketchable hint image; hidden until the editor checkbox is ticked.
                Image_anchor = gr.Image(type="numpy", label="Anchor", tool="color-sketch", interactive=True, visible=False)

            with gr.Row():
                # gr.Number has no `type` argument (the old type="int" was dropped as an
                # unused kwarg, so handlers received a float). precision=0 makes it
                # deliver an int.
                Num_anchor = gr.Number(value=8, precision=0, label="Количество опорных точек (3~14)")
                # type="index" passes the position of the chosen option (0, 1, 2) to
                # handlers rather than its text.
                Radio_resolution = gr.Radio(
                    type="index",
                    choices=["Low (256x256)", "Medium (512x512)", "High (1024x1024)"],
                    label="Область для раскрашивания кистью",
                    value="Low (256x256)",
                )
            with gr.Row():
                # `value` replaces the removed `default` kwarg.
                Ckeckbox_editable = gr.Checkbox(value=False, label='Загрузить редактор')
                Button_show_anchor = gr.Button(value="Predict anchors", visible=False)
            Button_run = gr.Button(value="Исполнить")
        with gr.Column():
            Image_output = gr.Image(type="numpy", label="Output").style(height=480)

    # Event wiring.
    Ckeckbox_editable.change(fn=switch_states, inputs=Ckeckbox_editable, outputs=[Image_anchor, Button_show_anchor])
    Button_show_anchor.click(
        fn=click_predanchors,
        inputs=[Image_input, Num_anchor, Radio_resolution, Ckeckbox_editable],
        outputs=Image_anchor,
    )
    Button_run.click(
        fn=click_colorize,
        inputs=[Image_input, Image_anchor, Num_anchor, Radio_resolution, Ckeckbox_editable],
        outputs=Image_output,
    )

    ## guideline (placeholder; currently empty)
    gr.Markdown(value="""    
                    
                    """)
    # The example image is only downloaded outside local mode.
    if RUN_MODE != "local":
        gr.Examples(
            examples=[
                ['04.jpg', 8, "Low (256x256)"],
            ],
            inputs=[Image_input, Num_anchor, Radio_resolution],
            outputs=[Image_output],
            label="Examples",
        )
    gr.HTML(value="""
                
                    """)

# Local runs bind to a fixed address and port; any other mode uses gradio's
# default host and port.
if RUN_MODE != "local":
    demo.launch()
else:
    demo.launch(server_name='9.134.253.83', server_port=7788)