|
|
|
|
|
""" |
|
streamlit app demo |
|
how to run: |
|
streamlit run app.py --server.port 8501 |
|
|
|
@author: Tu Bui @surrey.ac.uk |
|
""" |
|
import os, sys, torch |
|
import argparse |
|
from pathlib import Path |
|
import numpy as np |
|
import pickle |
|
import pytorch_lightning as pl |
|
from torchvision import transforms |
|
import argparse |
|
from ldm.util import instantiate_from_config |
|
from omegaconf import OmegaConf |
|
from PIL import Image |
|
from tools.augment_imagenetc import RandomImagenetC |
|
from io import BytesIO |
|
from tools.helpers import welcome_message |
|
from tools.ecc import BCH, RSC |
|
|
|
import streamlit as st |
|
from streamlit.source_util import ( |
|
page_icon_and_name, |
|
calc_md5, |
|
get_pages, |
|
_on_pages_changed |
|
) |
|
|
|
# Model choices offered in the UI selectbox; each must be handled by load_model().
model_names = ["UNet"]
|
|
|
|
|
def delete_page(main_script_path_str, page_name):
    """Remove the page named *page_name* from the dict returned by ``get_pages``.

    The first entry whose ``'page_name'`` equals *page_name* is deleted in
    place, then ``_on_pages_changed`` is signalled. If nothing matches, no
    entry is removed but the signal is still sent.

    NOTE(review): this mutates the object returned by ``get_pages`` and relies
    on private ``streamlit.source_util`` internals (``_on_pages_changed``);
    presumably that object is Streamlit's shared page registry, and the API is
    version-specific — confirm against the pinned Streamlit version.

    Args:
        main_script_path_str: path of the app's main script.
        page_name: display name of the page to remove.
    """
    current_pages = get_pages(main_script_path_str)
    # Locate the key first so the dict is never mutated while being iterated.
    key_to_remove = next(
        (key for key, value in current_pages.items() if value['page_name'] == page_name),
        None,
    )
    if key_to_remove is not None:
        del current_pages[key_to_remove]
    _on_pages_changed.send()
|
|
|
|
|
def add_page(main_script_path_str, page_name):
    """Register the script whose filename contains *page_name* as a page.

    Candidates are the ``*.py`` files in ``<main script dir>/pages`` followed
    by the ``*.py`` files next to the main script; the first candidate whose
    file name contains *page_name* is used. Its entry is written into the dict
    returned by ``get_pages`` (keyed by ``calc_md5`` of the resolved script
    path) and ``_on_pages_changed`` is signalled.

    NOTE(review): relies on private ``streamlit.source_util`` internals that
    are version-specific — confirm against the pinned Streamlit version.

    Args:
        main_script_path_str: path of the app's main script.
        page_name: substring to look for in candidate script file names.

    Raises:
        FileNotFoundError: if no candidate script name contains *page_name*.
    """
    pages = get_pages(main_script_path_str)
    main_script_path = Path(main_script_path_str)
    pages_dir = main_script_path.parent / "pages"

    candidates = list(pages_dir.glob("*.py")) + list(main_script_path.parent.glob("*.py"))
    script_path = next((f for f in candidates if page_name in f.name), None)
    if script_path is None:
        # Previously an opaque IndexError from `[...][0]`; say what was searched.
        raise FileNotFoundError(
            f"No page script whose name contains {page_name!r} in "
            f"{pages_dir} or {main_script_path.parent}"
        )

    script_path_str = str(script_path.resolve())
    pi, pn = page_icon_and_name(script_path)
    psh = calc_md5(script_path_str)
    pages[psh] = {
        "page_script_hash": psh,
        "page_name": pn,
        "icon": pi,
        "script_path": script_path_str,
    }
    _on_pages_changed.send()
|
|
|
def unormalize(x):
    """Map a batch of images from [-1, 1] floats to uint8 channels-last arrays.

    Args:
        x: tensor whose dim 1 is the channel axis (``permute(0, 2, 3, 1)`` is
            applied, so 4-D input is required).

    Returns:
        numpy uint8 array with channels moved last; values are
        ``(x + 1) * 127.5`` clamped to [0, 255] and truncated.
    """
    scaled = ((x + 1) * 127.5).clamp(0, 255)
    channels_last = scaled.permute(0, 2, 3, 1)
    return channels_last.cpu().numpy().astype(np.uint8)
|
|
|
def to_bytes(x, mime):
    """Encode an image array into JPEG or PNG file bytes.

    Args:
        x: array accepted by ``PIL.Image.fromarray`` (the app passes the
            uint8 stego image).
        mime: target MIME type. ``image/jpeg`` or ``image/jpg`` (any letter
            case) selects JPEG; anything else, including ``None``, selects PNG.

    Returns:
        The encoded image as ``bytes``.
    """
    img = Image.fromarray(x)
    # Accept the common non-standard "image/jpg" and uppercase variants so a
    # .JPG upload is not silently re-encoded as PNG under a .jpg name.
    fmt = "JPEG" if (mime or "").lower() in ("image/jpeg", "image/jpg") else "PNG"
    buf = BytesIO()
    img.save(buf, format=fmt)
    return buf.getvalue()
|
|
|
|
|
def _fetch_if_remote(path_or_url, cache_dir='./weights'):
    """Return a local path for *path_or_url*, downloading http(s) URLs once.

    Inputs not starting with ``'http'`` are returned unchanged. A URL is saved
    as ``<cache_dir>/<last URL path component>``; if that file already exists
    it is reused without downloading again.
    """
    if not path_or_url.startswith('http'):
        return path_or_url
    cache_dir = Path(cache_dir)
    cache_dir.mkdir(parents=True, exist_ok=True)
    local_path = cache_dir / path_or_url.split('/')[-1]
    if not local_path.exists():
        import wget  # only needed when a remote file must be fetched
        print(f'Downloading {path_or_url}...')
        with st.spinner("Downloading model... this may take awhile!"):
            wget.download(path_or_url, str(local_path))
    return str(local_path)


def load_UNet(args):
    """Build the UNet watermarking model from a config and load its weights.

    Args:
        args: namespace with ``weight`` (checkpoint) and ``config`` (OmegaConf
            YAML); each may be a local path or an http(s) URL, in which case it
            is downloaded into ``./weights`` on first use.

    Returns:
        ``(model, secret_len)``: the model moved to CUDA when available (else
        CPU) and put in eval mode, and ``config.model.params.secret_len``.
    """
    print('args: ', args)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Resolve each file independently: a cached checkpoint must not prevent the
    # config from being fetched, and a local config must not be re-routed.
    weight_file = _fetch_if_remote(args.weight)
    config_file = _fetch_if_remote(args.config)

    config = OmegaConf.load(config_file).model
    secret_len = config.params.secret_len
    print(f'Secret length: {secret_len}')
    model = instantiate_from_config(config)

    # SECURITY: torch.load unpickles the file; only load checkpoints from
    # trusted sources (this one may come from a URL).
    state_dict = torch.load(weight_file, map_location=torch.device('cpu'))
    if 'global_step' in state_dict:
        print(f'Global step: {state_dict["global_step"]}, epoch: {state_dict.get("epoch")}')
    # Lightning-style checkpoints nest the weights under 'state_dict'.
    if 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']
    misses, ignores = model.load_state_dict(state_dict, strict=False)
    print(f'Missed keys: {misses}\nIgnore keys: {ignores}')
    model = model.to(device)
    model.eval()
    return model, secret_len
|
|
|
def embed_secret(model_name, model, cover, tform, secret):
    """Embed *secret* into *cover* and return the watermarked image at full size.

    The model runs at the resolution produced by *tform*. Only the residual
    ``clamp(stego, -1, 1) - input`` is bilinearly resized back to the cover's
    size and added to the original pixels, so cover detail is kept.

    Args:
        model_name: only ``'UNet'`` is supported.
        model: callable as ``model(im, secret) -> (stego, _)`` with a
            ``.device`` attribute.
        cover: image with ``.size == (width, height)`` convertible by
            ``np.array`` (the app passes an RGB PIL image).
        tform: maps *cover* to a CHW tensor; assumed to normalize to [-1, 1]
            like the model input — TODO confirm for new transforms.
        secret: bit tensor passed straight to the model.

    Returns:
        uint8 array shaped like ``np.array(cover)``.

    Raises:
        NotImplementedError: for any other model name.
    """
    if model_name != 'UNet':
        raise NotImplementedError
    w, h = cover.size
    with torch.no_grad():
        im = tform(cover).unsqueeze(0).to(model.device)
        stego, _ = model(im, secret)
        res = stego.clamp(-1, 1) - im
        res = torch.nn.functional.interpolate(res, (h, w), mode='bilinear')
        res = res.permute(0, 2, 3, 1).cpu().numpy()
    stego_norm = np.clip(res[0] + np.array(cover) / 127.5 - 1., -1, 1)
    # Round before casting: astype(uint8) truncates, which biased every pixel
    # down and could turn an unmodified pixel p into p - 1.
    return np.round(stego_norm * 127.5 + 127.5).astype(np.uint8)
|
|
|
def identity(x):
    """Return the argument itself, unchanged (a no-op transform)."""
    result = x
    return result
|
|
|
def decode_secret(model_name, model, im, tform):
    """Predict the embedded bits of *im* with the model's decoder.

    Args:
        model_name: ``'RoSteALS'`` or ``'UNet'``.
        model: object with a ``.device`` attribute and a ``decoder`` callable.
        im: input accepted by *tform*.
        tform: maps *im* to an unbatched tensor; a batch axis is prepended.

    Returns:
        numpy boolean array, ``True`` where the decoder output is positive.

    Raises:
        NotImplementedError: for any other model name.
    """
    if model_name not in ('RoSteALS', 'UNet'):
        raise NotImplementedError
    with torch.no_grad():
        batch = tform(im).unsqueeze(0).to(model.device)
        logits = model.decoder(batch)
        bits = logits > 0
        return bits.cpu().numpy()
|
|
|
|
|
@st.cache_resource
def load_model(model_name, _args):
    """Build the model and its embed/decode pre-processing (cached resource).

    Args:
        model_name: only ``'UNet'`` is supported.
        _args: CLI namespace forwarded to ``load_UNet``; the leading underscore
            makes ``st.cache_resource`` skip it when hashing the cache key.

    Returns:
        ``(model, tform_emb, tform_det, secret_len)``: ``tform_emb`` resizes to
        256x256 and ``tform_det`` to 224x224, both followed by ``ToTensor`` and
        ``Normalize`` with mean/std 0.5 per channel.

    Raises:
        NotImplementedError: for any other model name.
    """
    if model_name != 'UNet':
        raise NotImplementedError

    def _pipeline(side):
        # Resize to a square, then scale pixels from [0, 1] to [-1, 1].
        return transforms.Compose([
            transforms.Resize((side, side)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
        ])

    tform_emb = _pipeline(256)
    tform_det = _pipeline(224)
    model, secret_len = load_UNet(_args)
    return model, tform_emb, tform_det, secret_len
|
|
|
|
|
@st.cache_resource
def load_ecc(ecc_name, secret_len):
    """Construct the error-correcting code used to encode/decode the secret.

    Args:
        ecc_name: ``'BCH'`` or ``'RSC'``.
        secret_len: number of secret bits the model carries. For BCH only 160
            and 100 are configured. RSC ignores it and always uses 16 data
            bytes + 4 ECC bytes.

    Returns:
        The ECC object (``BCH`` or ``RSC`` instance).

    Raises:
        NotImplementedError: for an unknown *ecc_name*, or a BCH *secret_len*
            with no configuration (previously this surfaced as an
            UnboundLocalError).
    """
    if ecc_name == 'BCH':
        if secret_len == 160:
            return BCH(285, 10, secret_len, verbose=True)
        if secret_len == 100:
            return BCH(137, 5, payload_len=secret_len, verbose=True)
        raise NotImplementedError(
            f'BCH is not configured for secret_len={secret_len} (supported: 100, 160)'
        )
    if ecc_name == 'RSC':
        return RSC(data_bytes=16, ecc_bytes=4, verbose=True)
    raise NotImplementedError(f'Unknown ECC: {ecc_name!r}')
|
|
|
|
|
class Resize(object):
    """Resize an image to a fixed size and return it as a numpy array.

    LANCZOS is used when the image's shorter side exceeds the target's shorter
    side (downscaling); otherwise BILINEAR.
    """

    def __init__(self, size=None) -> None:
        # Default target size, passed to PIL ``Image.resize`` as (width, height).
        self.size = size

    def __call__(self, x, size=None):
        """Resize *x* (numpy array or PIL image) to *size*, else ``self.size``."""
        img = Image.fromarray(x) if isinstance(x, np.ndarray) else x
        target = self.size if size is None else size
        shrinking = min(img.size) > min(target)
        resample = Image.LANCZOS if shrinking else Image.BILINEAR
        return np.array(img.resize(target, resample))
|
|
|
|
|
def parse_st_args():
    """Parse the script's command-line options.

    Returns:
        ``argparse.Namespace`` with ``weight`` (checkpoint path or URL) and
        ``config`` (model config path or URL), each defaulting to the author's
        cluster paths below.
    """
    defaults = {
        '--weight': '/mnt/fast/nobackup/scratch4weeks/tb0035/projects/diffsteg/FLAE/simple_t2_croprs/checkpoints/epoch=000070-step=000219999.ckpt',
        '--config': '/mnt/fast/nobackup/scratch4weeks/tb0035/projects/diffsteg/FLAE/simple_t2_croprs/configs/-project.yaml',
    }
    parser = argparse.ArgumentParser()
    for flag, default in defaults.items():
        parser.add_argument(flag, default=default)
    return parser.parse_args()
|
|
|
|
|
def app(args):
    """Render the watermarking demo page.

    Flow:
    1. Pick a model.
    2. Upload an image and type a secret.
    3. The page ECC-encodes the secret and embeds it.
    4. It shows the watermarked image and offers it for download.
    5. It decodes the image again after a resize + center-crop, then reports
       the recovered text and the bit accuracy.

    Args:
        args: parsed CLI namespace forwarded to ``load_model``.
    """
    st.title('Watermarking Demo')

    model_name = st.selectbox("Choose the model", model_names)
    model, tform_emb, tform_det, secret_len = load_model(model_name, args)
    display_width = 300

    ecc = load_ecc('BCH', secret_len)

    st.subheader("Input")
    image_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
    if image_file is not None:
        print('Image: ', image_file.name)
        # Lower-case so "photo.JPG" gets the same MIME type as "photo.jpg".
        ext = image_file.name.split('.')[-1].lower()
        im = Image.open(image_file).convert('RGB')
        st.image(im, width=display_width)
        secret_text = st.text_input(f'Input the secret (max {ecc.data_len} chars)', 'A secret')
        # Explicit validation instead of `assert` (stripped under -O, and shown
        # to the user as a raw traceback).
        if len(secret_text) > ecc.data_len:
            st.error(f'The secret must be at most {ecc.data_len} characters.')
            st.stop()

    st.subheader("Embed results")
    status = st.empty()
    # Resize + center-crop applied to the stego image before decoding.
    prep = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop((224, 224))
    ])
    if image_file is not None and secret_text is not None:
        secret = ecc.encode_text([secret_text])
        secret = torch.from_numpy(secret).float().to(model.device)

        stego = embed_secret(model_name, model, im, tform_emb, secret)
        st.image(stego, width=display_width)

        mime = 'image/jpeg' if ext in ('jpg', 'jpeg') else f'image/{ext}'
        stego_bytes = to_bytes(stego, mime)
        st.download_button(label='Download image', data=stego_bytes, file_name=f'stego.{ext}', mime=mime)

        stego_processed = prep(Image.fromarray(stego))
        secret_pred = decode_secret(model_name, model, stego_processed, tform_det)
        bit_acc = (secret_pred == secret.cpu().numpy()).mean()
        secret_pred = ecc.decode_text(secret_pred)[0]
        # st.empty() holds a single element: two .markdown() calls on it would
        # overwrite each other, so group both lines in one container. HTML is
        # not enabled because the decoded text derives from user input.
        with status.container():
            st.markdown('**Secret recovery check:** ' + str(secret_pred))
            st.markdown('**Bit accuracy:** ' + str(bit_acc))
|
|
|
if __name__ == '__main__':
    # Entry point under `streamlit run app.py [-- script args]`: parse the
    # script's CLI options and render the page.
    args = parse_st_args()
    app(args=args)
|
|
|
|
|
|