# Author: Pierre Fernandez
# Commit 9e6cbab: "added encoding and decoding" (original file size: 3.64 kB)
import gradio as gr
import gradio.inputs as grinputs
import gradio.outputs as groutputs
import numpy as np
import json
import torch
from torchvision import transforms
import utils
import utils_img
# Run on GPU when available; every tensor that meets the model must live on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Fixed seeds so the randomly drawn carrier below is identical on every launch
# (images watermarked in one session stay detectable in the next).
torch.manual_seed(0)
np.random.seed(0)

print('Building backbone and normalization layer...')
backbone = utils.build_backbone(path='dino_r50.pth')
normlayer = utils.load_normalization_layer(path='out2048.pth')
model = utils.NormLayerWrapper(backbone, normlayer)
# The feature extractor is used as a fixed function: encode() optimizes the image
# pixels, never the weights. Freezing the parameters keeps backward() from computing
# (and endlessly accumulating, since the optimizer never clears them) weight gradients.
# NOTE: assumes NormLayerWrapper is a torch.nn.Module (it is called like one below).
model = model.to(device)
model.eval()
for param in model.parameters():
    param.requires_grad = False

print('Building the hypercone...')
FPR = 1e-6
# Hypercone half-angle around the carrier. Per the original note this is the value
# for FPR=1e-6 and D=2048, i.e. what utils.pvalue_angle(2048, 1, proba=FPR) yields;
# it is hard-coded here rather than recomputed at startup.
angle = 1.462771101178447
# rho = 1 + tan^2(angle) = 1 / cos^2(angle). With a unit-norm carrier,
# rho * <ft, carrier>^2 - ||ft||^2 > 0  <=>  |cos(ft, carrier)| > cos(angle),
# i.e. the feature lies inside the (two-sided) hypercone.
rho = 1 + np.tan(angle)**2

# Secret direction in the 2048-d feature space, normalized to unit length.
# Drawn on the CPU (so it is reproducible from the seed regardless of hardware)
# and then moved to the model's device so `ft @ carrier.T` works on GPU too.
carrier = torch.randn(1, 2048)
carrier /= torch.norm(carrier, dim=1, keepdim=True)
carrier = carrier.to(device, non_blocking=True)

# Image -> normalized tensor, using the ImageNet channel mean/std.
default_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def encode(image, epochs=10, psnr=44, lambda_w=1, lambda_i=1):
    """Embed the watermark into ``image`` by optimizing its pixels.

    The pixels are updated with Adam so that the image's feature vector
    ``model(x)`` is pushed inside the hypercone around ``carrier`` (controlled by
    the module-level ``rho``), while an L2 term keeps ``x`` close to the original.
    At every step, and once more on the final result, the candidate image is
    passed through ``utils_img.ssim_attenuation`` and ``utils_img.psnr_clip``.

    Args:
        image: input accepted by ``transforms.ToTensor`` (presumably the HxWxC
            array delivered by the Gradio Image input -- confirm).
        epochs: number of optimization steps (one Adam step each).
        psnr: PSNR target forwarded to ``utils_img.psnr_clip``.
        lambda_w: weight of the hypercone (watermark) loss.
        lambda_i: weight of the squared L2 image-distance loss.

    Returns:
        The watermarked image as a PIL image, after ``utils_img.round_pixel``
        and ``utils_img.unnormalize_img``.
    """
    img_orig = default_transform(image).to(device, non_blocking=True).unsqueeze(0)
    img = img_orig.clone()  # clone() keeps img_orig's device
    img.requires_grad = True
    optimizer = torch.optim.Adam([img], lr=1e-2)
    # Make sure the carrier sits on the same device as the features.
    carrier_dev = carrier.to(device, non_blocking=True)

    for iteration in range(epochs):
        x = utils_img.ssim_attenuation(img, img_orig)
        x = utils_img.psnr_clip(x, img_orig, psnr)
        ft = model(x)  # BxCxWxH -> BxD
        dot_product = ft @ carrier_dev.T  # BxD @ Dx1 -> Bx1
        norm = torch.norm(ft, dim=-1, keepdim=True)  # Bx1
        # p-value is for logging only; .item() assumes a batch of one image.
        cosines = torch.abs(dot_product / norm)
        log10_pvalue = np.log10(utils.cosine_pvalue(cosines.item(), ft.shape[-1]))
        # Negative once |cos(ft, carrier)| > cos(angle), i.e. ft is inside the hypercone.
        loss_R = -(rho * dot_product**2 - norm**2)  # Bx1
        loss_l2_img = torch.norm(x - img_orig)**2  # scalar
        loss = lambda_w * loss_R + lambda_i * loss_l2_img
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        logs = {
            "keyword": "img_optim",
            "iteration": iteration,
            "loss": loss.item(),
            "loss_R": loss_R.item(),
            "loss_l2_img": loss_l2_img.item(),
            "log10_pvalue": float(log10_pvalue),
        }
        print("__log__:%s" % json.dumps(logs))

    # Final projection and conversion: no gradients needed, so skip graph building.
    with torch.no_grad():
        img = utils_img.ssim_attenuation(img, img_orig)
        img = utils_img.psnr_clip(img, img_orig, psnr)
        img = utils_img.round_pixel(img)
        img = img.squeeze(0).detach().cpu()
        img = transforms.ToPILImage()(utils_img.unnormalize_img(img).squeeze(0))
    return img
def decode(image):
    """Test whether ``image`` carries the watermark.

    The image's feature vector is compared to ``carrier``. The image is reported
    as marked when ``rho * <ft, carrier>^2 > ||ft||^2``, i.e. when
    ``|cos(ft, carrier)| > cos(angle)`` (the feature lies inside the hypercone).

    Args:
        image: input accepted by ``transforms.ToTensor`` (presumably the Gradio
            Image input array -- confirm).

    Returns:
        A human-readable string with the verdict and the p-value obtained from
        ``utils.cosine_pvalue``.
    """
    img = default_transform(image).to(device, non_blocking=True).unsqueeze(0)
    # Pure inference: avoid building an autograd graph through the backbone.
    with torch.no_grad():
        ft = model(img)  # BxCxWxH -> BxD
        dot_product = ft @ carrier.to(ft.device).T  # BxD @ Dx1 -> Bx1
        norm = torch.norm(ft, dim=-1, keepdim=True)  # Bx1
        cosines = torch.abs(dot_product / norm)
        # .item() assumes a single image (batch of one).
        log10_pvalue = np.log10(utils.cosine_pvalue(cosines.item(), ft.shape[-1]))
        loss_R = -(rho * dot_product**2 - norm**2)  # Bx1
    text_marked = "marked" if loss_R.item() < 0 else "unmarked"
    return 'Image is {s}, with p-value={p}'.format(s=text_marked, p=10**log10_pvalue)
def on_submit(image, mode):
    """Gradio callback: run ``encode`` or ``decode`` depending on the radio choice.

    Returns an ``(image, message)`` pair matching the interface's two outputs.
    Any ``mode`` other than ``'Encode'`` falls through to decoding, in which case
    the input image is echoed back unchanged next to the detection message.
    """
    print('{} mode'.format(mode))
    if mode != 'Encode':
        return image, decode(image)
    watermarked = encode(image)
    return watermarked, 'Successfully encoded'
# Inputs map positionally onto on_submit(image, mode);
# outputs map onto its (image, message) return pair.
_input_components = [
    grinputs.Image(),
    grinputs.Radio(['Encode', 'Decode'], label="Encode or Decode mode"),
]
_output_components = [
    groutputs.Image(label='Watermarked image'),
    groutputs.Textbox(label='Information'),
]
# NOTE(review): allow_flagging="auto" presumably records every submission in
# Gradio 2.x -- confirm that storing user images is intended.
iface = gr.Interface(
    fn=on_submit,
    inputs=_input_components,
    outputs=_output_components,
    allow_screenshot=False,
    allow_flagging="auto",
)
iface.launch()