import gradio as gr from model.nets import my_model import torch import cv2 import torch.utils.data as data import torchvision.transforms as transforms import PIL from PIL import Image from PIL import ImageFile import math import os import torch.nn.functional as F os.environ["CUDA_VISIBLE_DEVICES"] = "1" device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") model1 = my_model(en_feature_num=48, en_inter_num=32, de_feature_num=64, de_inter_num=32, sam_number=1, ).to(device) load_path1 = "./mix.pth" model_state_dict1 = torch.load(load_path1, map_location=device) model1.load_state_dict(model_state_dict1) def default_toTensor(img): t_list = [transforms.ToTensor()] composed_transform = transforms.Compose(t_list) return composed_transform(img) def predict1(img): in_img = transforms.ToTensor()(img).to(device).unsqueeze(0) b, c, h, w = in_img.size() # pad image such that the resolution is a multiple of 32 w_pad = (math.ceil(w / 32) * 32 - w) // 2 w_odd_pad = w_pad h_pad = (math.ceil(h / 32) * 32 - h) // 2 h_odd_pad = h_pad if w % 2 == 1: w_odd_pad += 1 if h % 2 == 1: h_odd_pad += 1 in_img = img_pad(in_img, w_pad=w_pad, h_pad=h_pad, w_odd_pad=w_odd_pad, h_odd_pad=h_odd_pad) with torch.no_grad(): out_1, out_2, out_3 = model1(in_img) if h_pad != 0: out_1 = out_1[:, :, h_pad:-h_odd_pad, :] if w_pad != 0: out_1 = out_1[:, :, :, w_pad:-w_odd_pad] out_1 = out_1.squeeze(0) out_1 = PIL.Image.fromarray(torch.clamp(out_1 * 255, min=0, max=255 ).byte().permute(1, 2, 0).cpu().numpy()) return out_1 def img_pad(x, w_pad, h_pad, w_odd_pad, h_odd_pad): ''' Here the padding values are determined by the average r,g,b values across the training set in FHDMi dataset. For the evaluation on the UHDM, you can also try the commented lines where the mean values are calculated from UHDM training set, yielding similar performance. ''' x1 = F.pad(x[:, 0:1, ...], (w_pad, w_odd_pad, h_pad, h_odd_pad), value=0.3827) x2 = F.pad(x[:, 1:2, ...], (w_pad, w_odd_pad, h_pad, h_odd_pad), value=0.4141) x3 = F.pad(x[:, 2:3, ...], (w_pad, w_odd_pad, h_pad, h_odd_pad), value=0.3912) y = torch.cat([x1, x2, x3], dim=1) return y title = "Clean Your Moire Images!" description = """ The model was trained to remove the moire patterns from your captured screen images! Specially, this model is capable of tackling images up to 4K resolution, which adapts to most of the modern mobile phones. (Note: It may cost 80s per 4K image (e.g., iPhone's resolution: 4032x3024) since this demo runs on the CPU. The model can run on a NVIDIA 3090 GPU 17ms per standard 4K image). The best way for a demo testing is using your mobile phone to capture a screen image, which may cause moire patterns. You can scan the QR code to play on your mobile phone.