File size: 6,904 Bytes
d4d76e3 c889936 d4d76e3 c889936 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 |
import streamlit as st
import hashlib
import uuid
from streamlit_card import card
import streamlit.components.v1 as components
import time
import json
def generate_mock_hash():
    """Return a random 64-character hex string shaped like a SHA-256 digest.

    Only used to mock blockchain hashes in the UI. Random UUID4 bytes are hashed
    instead of ``time.time()``: back-to-back calls can read the same clock value
    (coarse timer resolution), which could give identical mock hashes within
    one block.

    Returns:
        str: 64 lowercase hexadecimal characters.
    """
    return hashlib.sha256(uuid.uuid4().bytes).hexdigest()
from utils import (
CLIENT_DIR,
CURRENT_DIR,
DEPLOYMENT_DIR,
KEYS_DIR,
INPUT_BROWSER_LIMIT,
clean_directory,
SERVER_DIR,
)
from concrete.ml.deployment import FHEModelClient
# --- Page layout and contact sidebar -----------------------------------------
# set_page_config stays the first Streamlit call of the script.
st.set_page_config(layout="wide")

# Markdown bullet list of the project members shown in the sidebar.
_CONTACT_MARKDOWN = """
- Reda Bellafqira
- Mehdi Ben Ghali
- Pierre-Elisée Flory
- Mohammed Lansari
- Thomas Winninger
"""

st.sidebar.title("Contact")
st.sidebar.info(_CONTACT_MARKDOWN)

st.title("Secure Watermarking Service")
# Illustration currently disabled:
# st.image(
#     "llm_watermarking.png",
#     caption="A Watermark for Large Language Models (https://doi.org/10.48550/arXiv.2301.10226)",
# )
def todo():
    """Display a Streamlit warning flagging a feature as not implemented yet."""
    message = "Not implemented yet"
    st.warning(message, icon="⚠️")
def key_gen_fn(client_id):
    """Generate FHE keys for a client and save its evaluation key under KEYS_DIR.

    Requires the deployed client model in DEPLOYMENT_DIR ("client.zip", per the
    original note). Private and evaluation keys are generated via
    ``FHEModelClient`` with ``key_dir=KEYS_DIR / client_id``. The serialized
    evaluation key is written to ``KEYS_DIR / client_id / "evaluation_key"``.
    The key size and a truncated hex preview are then shown in the UI.

    Args:
        client_id (str): The client id typed by the user in Streamlit. It is
            used as a single directory name below KEYS_DIR, so empty values,
            path separators and ".." are rejected (an error is shown and
            nothing is written or cleaned).

    Raises:
        TypeError: if the client returns evaluation keys that are not ``bytes``.
    """
    key_dir = KEYS_DIR / str(client_id)
    # client_id comes straight from a text input: refuse anything that does not
    # resolve to a direct child directory of KEYS_DIR ("", "a/b", "/abs", "..").
    if key_dir.parent != KEYS_DIR or key_dir.name == "..":
        st.error(
            "Invalid identification string: use a single name without path separators."
        )
        return

    clean_directory()

    client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=key_dir)
    client.load()

    # Create the private and evaluation keys on the client side.
    client.generate_private_and_evaluation_keys()

    serialized_evaluation_keys = client.get_serialized_evaluation_keys()
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not isinstance(serialized_evaluation_keys, bytes):
        raise TypeError(
            "Expected serialized evaluation keys as bytes, got "
            f"{type(serialized_evaluation_keys).__name__}"
        )

    # The key directory may not exist yet; create it before writing into it.
    key_dir.mkdir(parents=True, exist_ok=True)
    evaluation_key_path = key_dir / "evaluation_key"
    with evaluation_key_path.open("wb") as f:
        f.write(serialized_evaluation_keys)

    # Only the first INPUT_BROWSER_LIMIT hex characters of the key are displayed.
    serialized_evaluation_keys_shorten_hex = serialized_evaluation_keys.hex()[
        :INPUT_BROWSER_LIMIT
    ]
    with st.expander("Generated keys"):
        # Key size in megabytes (10**6 bytes).
        st.write(f"{len(serialized_evaluation_keys) / (10**6):.2f} MB")
        st.code(serialized_evaluation_keys_shorten_hex)
    st.success("Keys have been generated!", icon="✅")
def gen_trigger_set(client_id, hf_id):
    """Build the (work-in-progress) trigger set for ``client_id``.

    Intended design: the inputs are random images seeded by the client id, and
    the labels are the bits of the encoded id. Currently every input is the
    placeholder ``1``, and a "not implemented" warning is shown.

    Args:
        client_id (str): The client identifier whose encoded bits become the labels.
        hf_id (str): Second identifier mixed into ``hf_seed`` (presumably a
            Hugging Face model/user id — TODO confirm).

    Returns:
        list[dict]: One ``{"input": 1, "label": bit}`` entry per character of
        ``encode_id(client_id, 128)``; ``bit`` is the character "0" or "1".
    """
    watermark_uuid = uuid.uuid1()
    # hashlib only accepts bytes: the previous str argument always raised TypeError.
    client_seed = hashlib.sha256((client_id + str(watermark_uuid)).encode()).digest()
    hf_seed = hashlib.sha256((hf_id + str(watermark_uuid)).encode()).digest()
    # TODO: use client_seed / hf_seed to generate the random input images.

    trigger_set_size = 128
    trigger_set_client = [
        {"input": 1, "label": digit}
        for digit in encode_id(client_id, trigger_set_size)
    ]
    todo()
    return trigger_set_client
def encode_id(ascii_rep, size=128):
    """Encode a string id as a string of bits, 8 bits per UTF-8 byte.

    Args:
        ascii_rep (str): The id string to encode.
        size (int): Maximum length of the returned bit string; longer
            encodings are truncated.

    Returns:
        str: A string of "0"/"1" characters, at most ``size`` long. ASCII
        characters produce the same bits as ``format(ord(c), "b").zfill(8)``.

    >>> encode_id("a")
    '01100001'
    """
    # Use the parameter (the old code read the module-level `client_id` global),
    # and go through UTF-8 so every byte maps to exactly 8 bits and the result
    # can be turned back into text with bytes.decode().
    bits = "".join(format(byte, "08b") for byte in ascii_rep.encode("utf-8"))
    return bits[:size]
def decode_id(binary_rep):
    """Decode a string of bits (most significant bit first) back to text.

    Args:
        binary_rep (str): A base-2 string such as "0110000101100010".

    Returns:
        str: The bytes decoded as UTF-8. Leading all-zero bytes are dropped,
        because the value is rebuilt through an integer.

    Raises:
        ValueError: if ``binary_rep`` is not a valid base-2 literal (e.g. empty),
            or the bytes are not valid UTF-8 (``UnicodeDecodeError``).

    >>> decode_id("0110000101100010")
    'ab'
    """
    binary_int = int(binary_rep, 2)
    # Round the bit count up to whole bytes. The parentheses matter: the old
    # `bit_length() + 7 // 8` allocated one byte per bit and prepended NULs.
    byte_number = (binary_int.bit_length() + 7) // 8
    binary_array = binary_int.to_bytes(byte_number, "big")
    return binary_array.decode()
def compare_id(client_id, binary_triggert_set_result):
    """Return the fraction of trigger-set results that match the encoded id.

    The expected bits are ``encode_id(client_id, 128)``. They are compared
    position by position with the bits returned by the tested API. Only
    positions present in both strings are compared.

    Args:
        client_id (str): The expected id string.
        binary_triggert_set_result (str): The "0"/"1" string obtained from the
            trigger set on the tested API.

    Returns:
        float: Number of matching bits divided by
        ``len(binary_triggert_set_result)``, or 0.0 when the result is empty.
    """
    if not binary_triggert_set_result:
        # Nothing to compare; previously this raised ZeroDivisionError.
        return 0.0
    ground_truth = encode_id(client_id, 128)
    # Count agreements (the old loop counted mismatches as "correct" bits).
    correct_bit = sum(
        true_bit == real_bit
        for true_bit, real_bit in zip(ground_truth, binary_triggert_set_result)
    )
    return correct_bit / len(binary_triggert_set_result)
def watermark(model, trigger_set):
    """Watermark ``model`` with ``trigger_set`` and offer the results for download.

    Not implemented yet. The function shows a "not implemented" warning,
    creates empty placeholder files ``watermarked_model`` and ``trigger_set``
    in SERVER_DIR, and exposes them through two download buttons.

    Args:
        model: The model to watermark (currently unused).
        trigger_set: The trigger set to embed (currently unused).
    """
    todo()
    model_file_path = SERVER_DIR / "watermarked_model"
    trigger_set_file_path = SERVER_DIR / "trigger_set"

    # TODO: remove once model correctly watermarked
    # SERVER_DIR may not exist yet; touch() raises FileNotFoundError in that case.
    SERVER_DIR.mkdir(parents=True, exist_ok=True)
    model_file_path.touch()
    trigger_set_file_path.touch()

    # Once the model is watermarked and dumped to files (model + trigger set),
    # the user can download them under their on-disk names.
    with open(model_file_path, "rb") as model_file:
        st.download_button(
            label="Download the watermarked file",
            data=model_file,
            file_name=model_file_path.name,
            mime="application/octet-stream",
        )
    with open(trigger_set_file_path, "rb") as trigger_set_file:
        st.download_button(
            label="Download the trigger set",
            data=trigger_set_file,
            file_name=trigger_set_file_path.name,
            mime="application/octet-stream",
        )
# --- Client configuration: choose an identifier and generate its FHE keys ----
st.header(body="Client Configuration", divider=True)
client_id = st.text_input(label="Identification string", value="team-8-uuid")
if st.button(label="Generate keys"):
    key_gen_fn(client_id)

# --- Model watermarking: upload the encrypted model and start watermarking ---
st.header(body="Model Watermarking", divider=True)
encrypted_model = st.file_uploader(label="Upload your encrypted model")
if st.button(label="Start Watermarking"):
    # The uploaded model is not passed on yet: watermark() is still a stub.
    watermark(None, None)
# --- Watermarking verification (no content yet) ------------------------------
st.header("Watermarking Verification", divider=True)

# --- Mock blockchain registration --------------------------------------------
st.header("Update Blockchain", divider=True)

# Keep the latest mock block across Streamlit reruns.
if "block_data" not in st.session_state:
    st.session_state["block_data"] = None

if st.button("Update Blockchain"):
    parent_hash = generate_mock_hash()
    created_at_ms = int(time.time() * 1000)  # current time in milliseconds
    model_transaction = {
        "type": "Watermarked Model Hash",
        "hash": generate_mock_hash(),
    }
    trigger_transaction = {
        "type": "Trigger Set Hash",
        "hash": generate_mock_hash(),
    }
    # Key order is kept stable because it drives the JSON rendering below.
    st.session_state["block_data"] = {
        "blockNumber": 42,
        "previousHash": parent_hash,
        "timestamp": created_at_ms,
        "transactions": [model_transaction, trigger_transaction],
    }
    st.success("Blockchain updated successfully!")

latest_block = st.session_state["block_data"]
if latest_block:
    st.subheader("Latest Block Data (JSON)")
    st.code(json.dumps(latest_block, indent=2), language="json")
|