team14 / client_server_interface.py
tagny's picture
image pairs matching - issue on config file in server.zip
0908a41
"Client-server interface custom implementation for matcher models."
from concrete import fhe
from matchers import Matcher
class FHEServer:
"""Server interface run a FHE circuit."""
def __init__(self, path_dir):
"""Initialize the FHE interface.
Args:
path_dir (Path): The path to the directory where the circuit is saved.
"""
self.path_dir = path_dir
# Load the FHE circuit
self.server = fhe.Server.load(self.path_dir / "server.zip")
def run(
self,
serialized_encrypted_query_image,
serialized_encrypted_reference_image,
serialized_evaluation_keys,
):
"""Run the matcher on the server over an encrypted image.
Args:
serialized_encrypted_query_image (bytes): The encrypted and serialized query image.
serialized_encrypted_reference_image (bytes): The encrypted and serialized referenceimage.
serialized_evaluation_keys (bytes): The serialized evaluation keys.
Returns:
bytes: The matcher's output.
"""
# Deserialize the encrypted input image and the evaluation keys
encrypted_query_image = fhe.Value.deserialize(serialized_encrypted_query_image)
encrypted_reference_image = fhe.Value.deserialize(
serialized_encrypted_reference_image
)
evaluation_keys = fhe.EvaluationKeys.deserialize(serialized_evaluation_keys)
# Execute the matcher in FHE
encrypted_output = self.server.run(
encrypted_query_image,
encrypted_reference_image,
evaluation_keys=evaluation_keys,
)
# Serialize the encrypted output image
serialized_encrypted_output = encrypted_output.serialize()
return serialized_encrypted_output
class FHEDev:
"""Development interface to save and load the matcher."""
def __init__(self, matcher, path_dir):
"""Initialize the FHE interface.
Args:
matcher (Matcher): The matcher to use in the FHE interface.
path_dir (str): The path to the directory where the circuit is saved.
"""
self.matcher = matcher
self.path_dir = path_dir
self.path_dir.mkdir(parents=True, exist_ok=True)
def save(self):
"""Export all needed artifacts for the client and server interfaces."""
assert (
self.matcher.fhe_circuit is not None
), "The model must be compiled before saving it."
# Save the circuit for the server, using the via_mlir in order to handle cross-platform
# execution
path_circuit_server = self.path_dir / "server.zip"
self.matcher.fhe_circuit.server.save(path_circuit_server, via_mlir=True)
# Save the circuit for the client
path_circuit_client = self.path_dir / "client.zip"
self.matcher.fhe_circuit.client.save(path_circuit_client)
class FHEClient:
"""Client interface to encrypt and decrypt FHE data associated to a matcher."""
def __init__(self, path_dir, matcher_name, key_dir=None):
"""Initialize the FHE interface.
Args:
path_dir (Path): The path to the directory where the circuit is saved.
matcher_name (str): The matcher's name to consider.
key_dir (Path): The path to the directory where the keys are stored. Default to None.
"""
self.path_dir = path_dir
self.key_dir = key_dir
# If path_dir does not exist raise
assert (
path_dir.exists()
), f"{path_dir} does not exist. Please specify a valid path."
# Load the client
self.client = fhe.Client.load(self.path_dir / "client.zip", self.key_dir)
# Instantiate the matcher
self.matcher = Matcher(matcher_name)
def generate_private_and_evaluation_keys(self, force=False):
"""Generate the private and evaluation keys.
Args:
force (bool): If True, regenerate the keys even if they already exist.
"""
self.client.keygen(force)
def get_serialized_evaluation_keys(self):
"""Get the serialized evaluation keys.
Returns:
bytes: The evaluation keys.
"""
return self.client.evaluation_keys.serialize()
def encrypt_serialize(self, input_image):
"""Encrypt and serialize the input image in the clear.
Args:
input_image (numpy.ndarray): The image to encrypt and serialize.
Returns:
bytes: The pre-processed, encrypted and serialized image.
"""
# Encrypt the image
encrypted_image = self.client.encrypt(input_image)
# Serialize the encrypted image to be sent to the server
serialized_encrypted_image = encrypted_image.serialize()
return serialized_encrypted_image
def deserialize_decrypt_post_process(self, serialized_encrypted_output_image):
"""Deserialize, decrypt and post-process the output result in the clear.
Args:
serialized_encrypted_output_image (bytes): The serialized and encrypted output image.
Returns:
numpy.ndarray: The decrypted, deserialized and post-processed image.
"""
# Deserialize the encrypted image
encrypted_output = fhe.Value.deserialize(serialized_encrypted_output_image)
# Decrypt the result
output_result = self.client.decrypt(encrypted_output)
# Post-process the result
post_processed_output = self.matcher.post_processing(output_result)
return post_processed_output