import shutil
import time
from glob import glob
from pathlib import Path
from typing import List

import numpy
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import JSONResponse, Response
from tqdm import tqdm

from utils import *

from concrete.ml.deployment import FHEModelClient, FHEModelServer

# NOTE(review): every endpoint builds filesystem paths from the client-supplied
# `root_dir` and `user_id` form fields without validation, so a caller can make
# the server read or write files anywhere the process has access to (and
# `/get_output` unpickles a file located through such a path). This is only
# acceptable for a local, trusted demo; confine paths to a fixed server-side
# root before exposing this API.

# Initialize an instance of FastAPI
app = FastAPI()


def _server_dir(user_id: str, root_dir: str) -> Path:
    """Return the per-user server-side working directory."""
    return Path(root_dir) / f"{user_id}/server"


def _save_upload(upload: UploadFile, destination: Path) -> None:
    """Write an uploaded file to `destination`, creating parent directories."""
    destination.parent.mkdir(parents=True, exist_ok=True)
    with destination.open("wb") as out:
        # Stream in chunks rather than `upload.file.read()`, which would load the
        # whole (possibly large) evaluation key or ciphertext into memory at once.
        shutil.copyfileobj(upload.file, out)


def _read_eval_key(key_dir: Path) -> bytes:
    """Return the serialized evaluation key stored as `key_dir / "eval_key"`."""
    return (key_dir / "eval_key").read_bytes()


def _write_progress(start_time: float, nb_done: int, nb_total: int) -> None:
    """Overwrite FHE_COMPUTATION_TIMELINE with the elapsed time and completion ratio."""
    with open(FHE_COMPUTATION_TIMELINE, "w", encoding="utf-8") as f:
        f.write(f"{time.time() - start_time:.2f} seconds ({nb_done/nb_total:.0%})")


def _run_fhe_pipeline(user_id: str, root_dir: str, track_progress: bool) -> float:
    """Run both FHE inference stages for one user and store the encrypted result.

    Stage 1 evaluates each shared base module (sorted by `extract_model_number`)
    on the matching encrypted input window uploaded through `/send_input`.
    Stage 2 slides a window (`META["SS"]`) over the stacked stage-1 outputs and
    evaluates the shared smoother module on each slice. The list of encrypted
    smoother outputs is pickled to
    `<server dir>/ENCRYPTED_OUTPUT_DIR/encrypted_final_output.pkl`.

    Args:
        user_id: Name of the user's folder under `root_dir`.
        root_dir: Directory holding the per-user client/server folders.
        track_progress: When True, rewrite FHE_COMPUTATION_TIMELINE after every
            base-module and smoother step.

    Returns:
        float: Wall-clock duration of both stages in seconds, rounded to 2 decimals.

    Raises:
        AssertionError: If the number of base modules or encrypted windows does
            not match `META["NW"]`, or a server output is not `bytes`.
    """
    print("------------ Step 4.2: Run in FHE on the Server Side")
    print(f"{user_id=}, {root_dir=}")

    server_dir = _server_dir(user_id, root_dir)
    server_encrypted_input_dir = server_dir / ENCRYPTED_INPUT_DIR
    server_encrypted_output_dir = server_dir / ENCRYPTED_OUTPUT_DIR
    server_encrypted_output_dir.mkdir(parents=True, exist_ok=True)

    eval_key_base_module = _read_eval_key(server_dir / KEY_BASE_MODULE_DIR)
    eval_key_smoother_module = _read_eval_key(server_dir / KEY_SMOOTHER_MODULE_DIR)

    shared_base_modules_path = glob(f"{SHARED_BASE_MODULE_DIR}/model_*")
    shared_base_modules_path = sorted(shared_base_modules_path, key=extract_model_number)
    print(f"{len(shared_base_modules_path)=}")
    assert len(shared_base_modules_path) == META["NW"]

    client_encrypted_input_path = glob(f"{server_encrypted_input_dir}/encrypted_window_*")
    client_encrypted_input_path = sorted(client_encrypted_input_path, key=extract_model_number)
    print(f"{len(client_encrypted_input_path)=}")
    # One encrypted window is expected per base module. Checking this list (not
    # the model list a second time) stops `zip` below from silently dropping a
    # mismatch between models and windows.
    assert len(client_encrypted_input_path) == META["NW"]

    # Progress counts one step per base module plus one per smoother window;
    # `/get_output` asserts the smoother stage also yields META["NW"] outputs.
    nb_total_iterations = META["NW"] * 2
    start_time = time.time()

    # ---- Stage 1: base modules, one per encrypted window.
    y_proba = []
    for i, (model_path, encrypted_window_path) in tqdm(
        enumerate(zip(shared_base_modules_path, client_encrypted_input_path)),
        total=min(len(shared_base_modules_path), len(client_encrypted_input_path)),
    ):
        server = FHEModelServer(model_path)
        encrypted_window = Path(encrypted_window_path).read_bytes()
        encrypted_output = server.run(
            encrypted_window, serialized_evaluation_keys=eval_key_base_module
        )
        assert isinstance(encrypted_output, bytes)

        # NOTE(review): the server decrypts the intermediate result with keys
        # taken from the shared model directory, so stage-1 outputs are in clear
        # on the server. This presumably only works because the user's keys are
        # generated from that same shared key_dir — confirm this is intended.
        client = FHEModelClient(model_path, key_dir=model_path)
        decrypted_output = client.deserialize_decrypt_dequantize(encrypted_output)

        # NOTE(review): despite its name, this file receives the encrypted
        # *input* window, not the decrypted output; nothing in this module reads it.
        (server_encrypted_output_dir / f"decrypted_window_{i}").write_bytes(encrypted_window)

        y_proba.append(decrypted_output)
        if track_progress:
            _write_progress(start_time, i + 1, nb_total_iterations)

    nb_base_steps = len(y_proba)

    # ---- Stage 2: smoother module over sliding windows of stage-1 outputs.
    # NOTE(review): inputs are re-encrypted here with the server's own client
    # built from the shared smoother directory, then evaluated with the user's
    # smoother evaluation key — same shared-key assumption as above.
    client = FHEModelClient(SHARED_SMOOTHER_MODULE_DIR, key_dir=SHARED_SMOOTHER_MODULE_DIR)
    server = FHEModelServer(SHARED_SMOOTHER_MODULE_DIR)

    # Swap the first two axes of the stacked outputs (presumably
    # models <-> samples — TODO confirm against slide_window's expectations).
    y_proba = numpy.transpose(numpy.array(y_proba), (1, 0, 2))
    y_proba = y_proba.astype(numpy.int8)
    print(f"{y_proba.shape=}, {type(y_proba)}")

    X_slide, _ = slide_window(y_proba, META["SS"])

    yhat_encrypted = []
    for i, window in tqdm(enumerate(X_slide), total=len(X_slide)):
        window_input = window.reshape(1, -1)
        encrypted_input = client.quantize_encrypt_serialize(window_input)
        encrypted_output = server.run(
            encrypted_input, serialized_evaluation_keys=eval_key_smoother_module
        )
        yhat_encrypted.append(encrypted_output)
        if track_progress:
            # Continue counting after stage 1; restarting at i + 1 made the
            # reported percentage fall back to ~0% and never exceed 50%.
            _write_progress(start_time, nb_base_steps + i + 1, nb_total_iterations)

    write_pickle(server_encrypted_output_dir / "encrypted_final_output.pkl", yhat_encrypted)
    return round(time.time() - start_time, 2)


# Define the default route
@app.get("/")
def root():
    """
    Root endpoint of the encrypted DNA testing API.

    Returns:
        dict: The welcome message.
    """
    return {"message": "Welcome to your encrypted DNA testing use-case with FHE!"}


@app.post("/send_input")
def send_input(
    user_id: str = Form(...), root_dir: str = Form(...), files: List[UploadFile] = File(...)
):
    """Send the inputs to the server.

    Expected upload order: `files[0]` is the base-module evaluation key,
    `files[1]` the smoother-module evaluation key, and every remaining file an
    encrypted input window.

    Raises:
        HTTPException: 422 if fewer than the two evaluation keys are uploaded.
    """
    print("------------ Step 3.2: Send the data to the server")
    print(f"{user_id=}, {root_dir=}, {len(files)=}")

    if len(files) < 2:
        raise HTTPException(
            status_code=422,
            detail="Expected at least 2 files: the base and smoother evaluation keys.",
        )

    server_dir = _server_dir(user_id, root_dir)

    _save_upload(files[0], server_dir / KEY_BASE_MODULE_DIR / "eval_key")
    _save_upload(files[1], server_dir / KEY_SMOOTHER_MODULE_DIR / "eval_key")

    print(f"{len(files)=}")
    encrypted_input_dir = server_dir / ENCRYPTED_INPUT_DIR
    # Windows keep their upload index (starting at 2) in the file name; the FHE
    # endpoints only rely on their relative order via `extract_model_number`.
    for i in tqdm(range(2, len(files))):
        _save_upload(files[i], encrypted_input_dir / f"encrypted_window_{i}")


@app.post("/run_fhe")
def run_fhe(
    user_id: str = Form(),
    root_dir: str = Form(...),
):
    """Inference in FHE, reporting progress to FHE_COMPUTATION_TIMELINE.

    Returns:
        JSONResponse: The FHE execution time in seconds.
    """
    fhe_execution_time = _run_fhe_pipeline(user_id, root_dir, track_progress=True)
    return JSONResponse(content=fhe_execution_time)


@app.post("/run_fhe_stage1")
def run_fhe_stage1(
    user_id: str = Form(),
    root_dir: str = Form(...),
):
    """Inference in FHE, without progress reporting.

    NOTE(review): this endpoint and `/run_fhe_stage2` both run the full
    two-stage pipeline, identical to `/run_fhe` except that the timeline file is
    not updated. Confirm whether they were meant to split the two stages.

    Returns:
        JSONResponse: The FHE execution time in seconds.
    """
    fhe_execution_time = _run_fhe_pipeline(user_id, root_dir, track_progress=False)
    return JSONResponse(content=fhe_execution_time)


@app.post("/run_fhe_stage2")
def run_fhe_stage2(
    user_id: str = Form(),
    root_dir: str = Form(...),
):
    """Inference in FHE, without progress reporting (same pipeline as stage 1).

    Returns:
        JSONResponse: The FHE execution time in seconds.
    """
    fhe_execution_time = _run_fhe_pipeline(user_id, root_dir, track_progress=False)
    return JSONResponse(content=fhe_execution_time)


@app.post("/get_output")
def get_output(user_id: str = Form(), root_dir: str = Form()):
    """Retrieve the encrypted output from the server.

    Copies the pickled encrypted predictions from the user's server folder to
    the user's client folder.

    Returns:
        Response: Plain "OK" once the copy is done.
    """
    print("\nStep 5.2: Get the output from the server ............\n")

    server_encrypted_output_dir = _server_dir(user_id, root_dir) / ENCRYPTED_OUTPUT_DIR
    # NOTE(review): unpickling a file whose location is derived from
    # client-supplied form fields can execute arbitrary code if an attacker can
    # place a file there; only safe in a trusted, local setup.
    yhat_encrypted = load_pickle(server_encrypted_output_dir / "encrypted_final_output.pkl")

    client_encrypted_output_dir = Path(root_dir) / f"{user_id}/client" / ENCRYPTED_OUTPUT_DIR
    client_encrypted_output_dir.mkdir(parents=True, exist_ok=True)
    write_pickle(client_encrypted_output_dir / "encrypted_final_output.pkl", yhat_encrypted)

    assert len(yhat_encrypted) == META["NW"]

    # Kept from the original implementation; its purpose is not evident here
    # (presumably lets the client-side file settle before it is read — TODO confirm).
    time.sleep(1)

    # Send the encrypted output
    return Response("OK")