Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
"""A local gradio app that filters images using FHE.""" | |
import os | |
import shutil | |
import subprocess | |
import time | |
import gradio as gr | |
import numpy | |
import requests | |
from itertools import chain | |
from settings import ( | |
REPO_DIR, | |
SERVER_URL, | |
FHE_KEYS, | |
CLIENT_FILES, | |
SERVER_FILES, | |
DEPLOYMENT_PATH, | |
INITIAL_INPUT_SHAPE, | |
INPUT_INDEXES, | |
START_POSITIONS, | |
) | |
from development.client_server_interface import MultiInputsFHEModelClient | |
# Launch the "server:app" application with uvicorn as a background process, using the
# repository directory as its working directory.
subprocess.Popen(
    args=["uvicorn", "server:app"],
    cwd=REPO_DIR,
)

# Pause before building the interface, presumably so the server has time to start up.
time.sleep(3)
def shorten_bytes_object(bytes_object, limit=500):
    """Return a short hexadecimal preview of a bytes object.

    Encrypted data is too large to be displayed in the browser through Gradio, so only a window
    of it is kept and rendered as a hexadecimal string.

    Args:
        bytes_object (bytes): The bytes to preview.
        limit (int): The maximum number of bytes kept in the preview. Default to 500.

    Returns:
        str: The hexadecimal representation of at most `limit` bytes taken after the first 100
            bytes of the input. The string is empty if the input has 100 bytes or fewer.
    """
    # The window does not start at the beginning of the object: a fixed offset is skipped for a
    # better display
    start = 100
    stop = start + limit

    preview = bytes_object[start:stop]
    return preview.hex()
def get_client(client_id, client_type):
    """Build the client API for a given client.

    The client is given a key directory located under FHE_KEYS and named
    "<client_type>_<client_id>".

    Args:
        client_id (int): The client ID to consider.
        client_type (str): The type of client to consider (either 'user', 'bank' or
            'third_party').

    Returns:
        MultiInputsFHEModelClient: The client API, created from DEPLOYMENT_PATH.
    """
    key_dir_name = f"{client_type}_{client_id}"
    return MultiInputsFHEModelClient(DEPLOYMENT_PATH, key_dir=FHE_KEYS / key_dir_name)
def get_client_file_path(name, client_id, client_type):
    """Build the path of a client-side temporary file.

    Args:
        name (str): The kind of file (for example 'evaluation_key', 'encrypted_inputs' or
            'encrypted_output').
        client_id (int): The client ID to consider.
        client_type (str): The type of client to consider (either 'user', 'bank' or
            'third_party').

    Returns:
        pathlib.Path: The path "<name>_<client_type>_<client_id>" inside CLIENT_FILES.
    """
    file_name = f"{name}_{client_type}_{client_id}"
    return CLIENT_FILES / file_name
def clean_temporary_files(n_keys=20):
    """Clean the oldest keys and their associated temporary files.

    At most n_keys key directories are kept in FHE_KEYS. When more are found, the oldest ones
    (by modification time) are removed, along with every file in CLIENT_FILES and SERVER_FILES
    whose name contains the name of a removed key directory.

    Args:
        n_keys (int): The maximum number of key directories to keep. Default to 20.
    """
    # Sort the key directories from the oldest to the most recent one
    key_dirs = sorted(FHE_KEYS.iterdir(), key=os.path.getmtime)

    # Nothing to evict: skip scanning the temporary folders altogether
    n_keys_to_delete = len(key_dirs) - n_keys
    if n_keys_to_delete <= 0:
        return

    # Remove the oldest key directories and remember their names
    user_ids = []
    for key_dir in key_dirs[:n_keys_to_delete]:
        user_ids.append(key_dir.name)
        shutil.rmtree(key_dir)

    # Delete all files related to the ids whose keys were deleted. Each file is unlinked at most
    # once, even when its name matches several removed ids (e.g. "user_12" and "user_123"), as a
    # second unlink would raise FileNotFoundError.
    # NOTE(review): matching is done by substring, so a removed id that is a prefix of a kept id
    # also removes the kept id's files; confirm the server file naming before tightening this.
    for file in chain(CLIENT_FILES.iterdir(), SERVER_FILES.iterdir()):
        if any(user_id in file.name for user_id in user_ids):
            file.unlink()
def keygen(client_id, client_type):
    """Generate the keys of a client and store its serialized evaluation key on disk.

    Old keys and temporary files are cleaned first. The serialized evaluation key is written to
    the client's 'evaluation_key' temporary file.

    Args:
        client_id (int): The client ID to consider.
        client_type (str): The type of client to consider (either 'user', 'bank' or
            'third_party').
    """
    # Make room before creating new keys
    clean_temporary_files()

    # Build the client API and (re)generate its private and evaluation keys
    fhe_client = get_client(client_id, client_type)
    fhe_client.generate_private_and_evaluation_keys(force=True)

    # Serialize the evaluation key. The original author notes that it is empty here since the
    # circuits are fully leveled, but that it is still needed for proper FHE execution.
    serialized_evaluation_key = fhe_client.get_serialized_evaluation_keys()

    # Store the key in a file, as it is too large to pass through regular Gradio buttons
    # (see https://github.com/gradio-app/gradio/issues/1877)
    key_path = get_client_file_path("evaluation_key", client_id, client_type)
    with key_path.open("wb") as key_file:
        key_file.write(serialized_evaluation_key)
def send_input(client_id, client_type):
    """Send the encrypted inputs as well as the evaluation key to the server.

    Both files are read from the client's temporary files and posted to the server's
    "send_input" endpoint.

    Args:
        client_id (int): The client ID to consider.
        client_type (str): The type of client to consider (either 'user', 'bank' or
            'third_party').

    Returns:
        bool: Whether the server answered with a successful status (`response.ok`).
    """
    # Get the paths to the evaluation key and encrypted inputs
    evaluation_key_path = get_client_file_path("evaluation_key", client_id, client_type)
    encrypted_input_path = get_client_file_path("encrypted_inputs", client_id, client_type)

    # Define the data to post
    data = {
        "client_id": client_id,
        "client_type": client_type,
    }

    url = SERVER_URL + "send_input"

    # Open both files in context managers so that their handles are always closed once the
    # request is done, even if it fails
    with open(encrypted_input_path, "rb") as encrypted_input_file, open(
        evaluation_key_path, "rb"
    ) as evaluation_key_file:
        files = [
            ("files", encrypted_input_file),
            ("files", evaluation_key_file),
        ]

        # Send the encrypted inputs and evaluation key to the server
        with requests.post(
            url=url,
            data=data,
            files=files,
        ) as response:
            return response.ok
def keygen_encrypt_send(inputs, client_type):
    """Generate keys, encrypt the given inputs and send them to the server for a client.

    A new random client ID is drawn at each call.

    Args:
        inputs (numpy.ndarray): The inputs to encrypt.
        client_type (str): The type of client to consider (either 'user', 'bank' or
            'third_party').

    Returns:
        str: A shortened hexadecimal representation of the encrypted inputs.

    Raises:
        gr.Error: If the server did not accept the encrypted inputs and evaluation key.
    """
    # Create an ID for the current client to consider. The dtype is set explicitly because the
    # default integer type is 32-bit on some platforms, where an upper bound of 2**32 raises a
    # ValueError.
    client_id = numpy.random.randint(0, 2**32, dtype=numpy.int64)

    keygen(client_id, client_type)

    # Retrieve the client instance
    client = get_client(client_id, client_type)

    # TODO : pre-process the data first

    # Quantize, encrypt and serialize the inputs
    encrypted_inputs = client.quantize_encrypt_serialize_multi_inputs(
        inputs,
        input_index=INPUT_INDEXES[client_type],
        initial_input_shape=INITIAL_INPUT_SHAPE,
        start_position=START_POSITIONS[client_type],
    )

    # Save encrypted_inputs to bytes in a file, since too large to pass through regular Gradio
    # buttons, https://github.com/gradio-app/gradio/issues/1877
    encrypted_inputs_path = get_client_file_path("encrypted_inputs", client_id, client_type)
    with encrypted_inputs_path.open("wb") as encrypted_inputs_file:
        encrypted_inputs_file.write(encrypted_inputs)

    # Report a failed upload instead of silently ignoring it, consistently with the other steps
    if not send_input(client_id, client_type):
        raise gr.Error("Sending the encrypted inputs to the server failed. Please try again.")

    # TODO: also return private key representation if possible

    # Return a truncated version of the encrypted inputs for display
    return shorten_bytes_object(encrypted_inputs)
def run_fhe(client_id):
    """Ask the server to run the FHE execution on the inputs previously sent.

    Args:
        client_id (int): The client ID to consider.

    Returns:
        The JSON-decoded body of the server's answer.

    Raises:
        gr.Error: If the server does not answer with a successful status.
    """
    # TODO : add a warning for users to send all client types' inputs

    payload = {"client_id": client_id}
    endpoint = SERVER_URL + "run_fhe"

    # Trigger the FHE execution on the encrypted inputs previously sent
    with requests.post(url=endpoint, data=payload) as response:
        if not response.ok:
            raise gr.Error("Please wait for the inputs to be sent to the server.")

        return response.json()
def get_output(client_id):
    """Retrieve the encrypted output from the server and store it on disk.

    The encrypted output is written to the 'encrypted_output' temporary file of the 'user'
    client with the given ID.

    Args:
        client_id (int): The client ID to consider.

    Returns:
        None: No representation of the encrypted output is built yet (see the TODO below).

    Raises:
        gr.Error: If the server does not answer with a successful status.
    """
    payload = {"client_id": client_id}
    endpoint = SERVER_URL + "get_output"

    # Retrieve the encrypted output
    with requests.post(url=endpoint, data=payload) as response:
        if not response.ok:
            raise gr.Error("Please wait for the FHE execution to be completed.")

        output_bytes = response.content

        # Save the encrypted output to bytes in a file as it is too large to pass through
        # regular Gradio buttons (see https://github.com/gradio-app/gradio/issues/1877)
        # TODO : check if output to user is relevant
        output_path = get_client_file_path("encrypted_output", client_id, "user")
        with output_path.open("wb") as output_file:
            output_file.write(output_bytes)

        # TODO
        # Decrypt the output using a different (wrong) key for display
        # output_encrypted_representation = decrypt_output_with_wrong_key(
        #     encrypted_output, client_type
        # )
        # return output_encrypted_representation
        return None
def decrypt_output(client_id, client_type):
    """Decrypt the result stored in the client's 'encrypted_output' temporary file.

    Args:
        client_id (int): The client ID to consider.
        client_type (str): The type of client to consider (either 'user', 'bank' or
            'third_party').

    Returns:
        numpy.ndarray: The index of the largest value along axis 1 of the decrypted and
            post-processed output, i.e. the predicted class.

    Raises:
        gr.Error: If the encrypted output file does not exist.
    """
    # Locate the encrypted output and make sure it has already been retrieved
    output_path = get_client_file_path("encrypted_output", client_id, client_type)
    if not output_path.is_file():
        raise gr.Error("Please run the FHE execution first.")

    # Load the encrypted output as bytes
    serialized_output = output_path.read_bytes()

    # Deserialize, decrypt and post-process the encrypted output with the client API
    fhe_client = get_client(client_id, client_type)
    probabilities = fhe_client.deserialize_decrypt_post_process(serialized_output)

    # Determine the predicted class
    return numpy.argmax(probabilities, axis=1)
# Build the Gradio interface. It follows the demo's workflow: the three parties (user, bank and
# third party) fill in their inputs and encrypt/send them, the server runs the FHE execution,
# then the client receives and decrypts the output.
demo = gr.Blocks()

print("Starting the demo...")
with demo:
    gr.Markdown(
        "\n"
        '<h1 align="center">Credit Card Approval Prediction Using Fully Homomorphic '
        "Encryption</h1>"
        "\n"
    )

    gr.Markdown("## Client side")
    gr.Markdown("### Step 1: Infos. ")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### User")
            # TODO : change infos
            # "Yes" and "No" are two separate choices, not a single "Yes, No" entry
            choice_1 = gr.Dropdown(choices=["Yes", "No"], label="Choose", interactive=True)
            slide_1 = gr.Slider(2, 20, value=4, label="Count", info="Choose between 2 and 20")

        with gr.Column():
            gr.Markdown("### Bank ")
            # TODO : change infos
            checkbox_1 = gr.CheckboxGroup(
                ["USA", "Japan", "Pakistan"], label="Countries", info="Where are they from?"
            )

        with gr.Column():
            gr.Markdown("### Third Party ")
            # TODO : change infos
            radio_1 = gr.Radio(
                ["park", "zoo", "road"], label="Location", info="Where did they go?"
            )

    gr.Markdown("### Step 2: Keygen, encrypt using FHE and send the inputs to the server.")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### User")
            encrypt_button_user = gr.Button("Encrypt the inputs and send to server.")
            keys_user = gr.Textbox(
                label="Keys representation:", max_lines=2, interactive=False
            )
            encrypted_input_user = gr.Textbox(
                label="Encrypted input representation:", max_lines=2, interactive=False
            )
            # Hidden field holding the client ID
            user_id = gr.Textbox(label="", max_lines=2, interactive=False, visible=False)

        with gr.Column():
            gr.Markdown("### Bank ")
            encrypt_button_bank = gr.Button("Encrypt the inputs and send to server.")
            keys_bank = gr.Textbox(
                label="Keys representation:", max_lines=2, interactive=False
            )
            encrypted_input_bank = gr.Textbox(
                label="Encrypted input representation:", max_lines=2, interactive=False
            )
            # Hidden field holding the client ID
            bank_id = gr.Textbox(label="", max_lines=2, interactive=False, visible=False)

        with gr.Column():
            gr.Markdown("### Third Party ")
            encrypt_button_third_party = gr.Button("Encrypt the inputs and send to server.")
            keys_3 = gr.Textbox(
                label="Keys representation:", max_lines=2, interactive=False
            )
            encrypted_input__third_party = gr.Textbox(
                label="Encrypted input representation:", max_lines=2, interactive=False
            )
            # Hidden field holding the client ID
            third_party_id = gr.Textbox(
                label="", max_lines=2, interactive=False, visible=False
            )

    gr.Markdown("## Server side")
    gr.Markdown(
        "The encrypted values are received by the server. The server can then compute the prediction "
        "directly over them. Once the computation is finished, the server returns "
        "the encrypted result to the client."
    )

    # Steps are numbered consecutively after Step 2
    gr.Markdown("### Step 3: Run FHE execution.")
    execute_fhe_button = gr.Button("Run FHE execution.")
    fhe_execution_time = gr.Textbox(
        label="Total FHE execution time (in seconds):", max_lines=1, interactive=False
    )

    gr.Markdown("## Client side")
    gr.Markdown(
        "The encrypted output is sent back to the client, who can finally decrypt it with the "
        "private key."
    )

    gr.Markdown("### Step 4: Receive the encrypted output from the server.")
    gr.Markdown(
        "The output displayed here is the encrypted result sent by the server, which has been "
        "decrypted using a different private key. This is only used to visually represent an "
        "encrypted output."
    )
    get_output_button = gr.Button("Receive the encrypted output from the server.")

    encrypted_output_representation = gr.Textbox(
        label="Encrypted output representation: ", max_lines=1, interactive=False
    )

    gr.Markdown("### Step 5: Decrypt the output.")
    decrypt_button = gr.Button("Decrypt the output")

    prediction_output = gr.Textbox(
        label="Credit card approval decision: ", max_lines=1, interactive=False
    )

    # TODO: wire the buttons to the functions defined above. The disabled calls below are kept
    # as a starting point; they reference names that are not defined in this file (such as
    # `encrypt`, `input_image` or `filter_name`) and need to be adapted first.

    # Button to encrypt inputs on the client side
    # encrypt_button_user.click(
    #     encrypt,
    #     inputs=[user_id, input_image, filter_name],
    #     outputs=[original_image, encrypted_input],
    # )

    # # Button to encrypt inputs on the client side
    # encrypt_button_bank.click(
    #     encrypt,
    #     inputs=[user_id, input_image, filter_name],
    #     outputs=[original_image, encrypted_input],
    # )

    # # Button to encrypt inputs on the client side
    # encrypt_button_third_party.click(
    #     encrypt,
    #     inputs=[user_id, input_image, filter_name],
    #     outputs=[original_image, encrypted_input],
    # )

    # # Button to send the encodings to the server using post method
    # send_input_button.click(
    #     send_input, inputs=[user_id, filter_name], outputs=[send_input_checkbox]
    # )

    # # Button to send the encodings to the server using post method
    # execute_fhe_button.click(run_fhe, inputs=[user_id, filter_name], outputs=[fhe_execution_time])

    # # Button to send the encodings to the server using post method
    # get_output_button.click(
    #     get_output,
    #     inputs=[user_id, filter_name],
    #     outputs=[encrypted_output_representation]
    # )

    # # Button to decrypt the output on the client side
    # decrypt_button.click(
    #     decrypt_output,
    #     inputs=[user_id, filter_name],
    #     outputs=[output_image, keygen_checkbox, send_input_checkbox],
    # )

    gr.Markdown(
        "The app was built with [Concrete-ML](https://github.com/zama-ai/concrete-ml), a "
        "Privacy-Preserving Machine Learning (PPML) open-source set of tools by [Zama](https://zama.ai/). "
        "Try it yourself and don't forget to star on Github ⭐."
    )

demo.launch(share=False)