import gradio as gr
from model.nets import my_model
import torch
import cv2
import torch.utils.data as data
import torchvision.transforms as transforms
import PIL
from PIL import Image
from PIL import ImageFile
import math
import os
import torch.nn.functional as F

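# expose only physical GPU 1; it is then addressed as cuda:0 below
# (falls back to CPU when CUDA is unavailable)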
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model1 = my_model(en_feature_num=48,
                  en_inter_num=32,
                  de_feature_num=64,
                  de_inter_num=32,
                  sam_number=1,
                  ).to(device)

load_path1 = "./mix.pth"
model_state_dict1 = torch.load(load_path1, map_location=device)
model1.load_state_dict(model_state_dict1)
model1.eval()  # run the network in inference mode


def default_toTensor(img):
    t_list = [transforms.ToTensor()]
    composed_transform = transforms.Compose(t_list)
    return composed_transform(img)

def predict1(img):
    in_img = transforms.ToTensor()(img).to(device).unsqueeze(0)
    b, c, h, w = in_img.size()
    # pad image such that the resolution is a multiple of 32
    w_pad = (math.ceil(w / 32) * 32 - w) // 2
    w_odd_pad = w_pad
    h_pad = (math.ceil(h / 32) * 32 - h) // 2
    h_odd_pad = h_pad
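    # if a dimension is odd, the total padding needed is odd as well, so the
    # extra pixel is added to one side only (the "odd" pad)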

    if w % 2 == 1:
        w_odd_pad += 1
    if h % 2 == 1:
        h_odd_pad += 1

    in_img = img_pad(in_img, w_pad=w_pad, h_pad=h_pad, w_odd_pad=w_odd_pad, h_odd_pad=h_odd_pad)
    with torch.no_grad():
        out_1, out_2, out_3 = model1(in_img)
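        # crop the padded borders from the network output so it matches the
        # original input resolution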
        if h_pad != 0:
            out_1 = out_1[:, :, h_pad:-h_odd_pad, :]
        if w_pad != 0:
            out_1 = out_1[:, :, :, w_pad:-w_odd_pad]
    out_1 = out_1.squeeze(0)
    out_1 = PIL.Image.fromarray(
        torch.clamp(out_1 * 255, min=0, max=255).byte().permute(1, 2, 0).cpu().numpy())

    return out_1

def img_pad(x, w_pad, h_pad, w_odd_pad, h_odd_pad):
    '''
    Pad each channel with a constant value: the average R, G, B values over the
    FHDMi training set. For evaluation on UHDM, mean values computed from the
    UHDM training set can be used instead, yielding similar performance.
    '''
    x1 = F.pad(x[:, 0:1, ...], (w_pad, w_odd_pad, h_pad, h_odd_pad), value=0.3827)
    x2 = F.pad(x[:, 1:2, ...], (w_pad, w_odd_pad, h_pad, h_odd_pad), value=0.4141)
    x3 = F.pad(x[:, 2:3, ...], (w_pad, w_odd_pad, h_pad, h_odd_pad), value=0.3912)

    y = torch.cat([x1, x2, x3], dim=1)

    return y
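
# A minimal sketch (not called by the demo) of how per-channel means like the
# padding constants above could be computed. The directory path and helper name
# are hypothetical; only the libraries already imported in this file are used.
def compute_channel_means(image_dir):
    # Accumulate per-channel pixel sums across all images, then divide by the
    # total pixel count to obtain dataset-wide R, G, B means.
    channel_sum = torch.zeros(3)
    pixel_count = 0
    for name in os.listdir(image_dir):  # assumes the folder contains only images
        img = Image.open(os.path.join(image_dir, name)).convert("RGB")
        t = transforms.ToTensor()(img)  # (3, H, W), values in [0, 1]
        channel_sum += t.sum(dim=(1, 2))
        pixel_count += t.shape[1] * t.shape[2]
    return channel_sum / pixel_count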


title = "Clean Your Moire Images!"
description = "The model was trained to remove moire patterns from your captured screen images! In particular, it can handle \
images up to 4K resolution, which covers most modern mobile phones. \
<br /> \
(Note: processing may take about 80 s per 4K image (e.g., the iPhone resolution of 4032x3024) since this demo runs on the CPU; \
on an NVIDIA 3090 GPU the model takes about 17 ms per standard 4K image.) \
<br /> \
The best way to test the demo is to capture a screen image with your mobile phone, which is likely to produce moire patterns. \
You can scan the [QR code](https://github.com/CVMI-Lab/UHDM/blob/main/figures/QR.jpg) to try it on your phone. "

article = "Check out the [ECCV 2022 paper](https://arxiv.org/abs/2207.09935) and the \
            [official training code](https://github.com/CVMI-Lab/UHDM) which the demo is based on.\
            <center><img src='https://visitor-badge.glitch.me/badge?page_id=Andyx_screen_image_demoire' alt='visitor badge'></center>"


iface1 = gr.Interface(fn=predict1,
                      inputs=gr.inputs.Image(type="pil"),
                      outputs=gr.outputs.Image(type="pil"),
                      examples=['001.jpg',
                                '002.jpg',
                                '005.jpg'],
                      title=title,
                      description=description,
                      article=article
                      )


iface1.launch()