Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
romanbredehoft-zama
commited on
Commit
β’
9a997e4
1
Parent(s):
4337a72
First working demo with multi-inputs XGB
Browse files- app.py +96 -365
- backend.py +411 -0
- data/clean_data.csv +0 -0
- deployment_files/client.zip +2 -2
- deployment_files/pre_processor_third_party.pkl +3 -0
- deployment_files/pre_processor_user.pkl +3 -0
- deployment_files/server.zip +2 -2
- development.py +97 -0
- development/development.py +0 -67
- development/pre_processing.py +0 -122
- server.py +29 -20
- settings.py +41 -5
- {development β utils}/client_server_interface.py +8 -7
- {development β utils}/model.py +43 -0
- utils/pre_processing.py +85 -0
app.py
CHANGED
@@ -1,318 +1,39 @@
|
|
1 |
-
"""A
|
2 |
|
3 |
-
import os
|
4 |
-
import shutil
|
5 |
import subprocess
|
6 |
import time
|
7 |
import gradio as gr
|
8 |
-
import numpy
|
9 |
-
import requests
|
10 |
-
from itertools import chain
|
11 |
|
12 |
from settings import (
|
13 |
REPO_DIR,
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
)
|
23 |
-
|
24 |
-
from development.client_server_interface import MultiInputsFHEModelClient
|
25 |
|
26 |
|
27 |
subprocess.Popen(["uvicorn", "server:app"], cwd=REPO_DIR)
|
28 |
time.sleep(3)
|
29 |
|
30 |
|
31 |
-
def shorten_bytes_object(bytes_object, limit=500):
|
32 |
-
"""Shorten the input bytes object to a given length.
|
33 |
-
|
34 |
-
Encrypted data is too large for displaying it in the browser using Gradio. This function
|
35 |
-
provides a shorten representation of it.
|
36 |
-
|
37 |
-
Args:
|
38 |
-
bytes_object (bytes): The input to shorten
|
39 |
-
limit (int): The length to consider. Default to 500.
|
40 |
-
|
41 |
-
Returns:
|
42 |
-
str: Hexadecimal string shorten representation of the input byte object.
|
43 |
-
|
44 |
-
"""
|
45 |
-
# Define a shift for better display
|
46 |
-
shift = 100
|
47 |
-
return bytes_object[shift : limit + shift].hex()
|
48 |
-
|
49 |
-
|
50 |
-
def get_client(client_id, client_type):
|
51 |
-
"""Get the client API.
|
52 |
-
|
53 |
-
Args:
|
54 |
-
client_id (int): The client ID to consider.
|
55 |
-
client_type (str): The type of user to consider (either 'user', 'bank' or 'third_party').
|
56 |
-
|
57 |
-
Returns:
|
58 |
-
FHEModelClient: The client API.
|
59 |
-
"""
|
60 |
-
key_dir = FHE_KEYS / f"{client_type}_{client_id}"
|
61 |
-
|
62 |
-
return MultiInputsFHEModelClient(DEPLOYMENT_PATH, key_dir=key_dir)
|
63 |
-
|
64 |
-
|
65 |
-
def get_client_file_path(name, client_id, client_type):
|
66 |
-
"""Get the correct temporary file path for the client.
|
67 |
-
|
68 |
-
Args:
|
69 |
-
name (str): The desired file name (either 'evaluation_key' or 'encrypted_inputs').
|
70 |
-
client_id (int): The client ID to consider.
|
71 |
-
client_type (str): The type of user to consider (either 'user', 'bank' or 'third_party').
|
72 |
-
|
73 |
-
Returns:
|
74 |
-
pathlib.Path: The file path.
|
75 |
-
"""
|
76 |
-
return CLIENT_FILES / f"{name}_{client_type}_{client_id}"
|
77 |
-
|
78 |
-
|
79 |
-
def clean_temporary_files(n_keys=20):
|
80 |
-
"""Clean keys and encrypted images.
|
81 |
-
|
82 |
-
A maximum of n_keys keys and associated temporary files are allowed to be stored. Once this
|
83 |
-
limit is reached, the oldest files are deleted.
|
84 |
-
|
85 |
-
Args:
|
86 |
-
n_keys (int): The maximum number of keys and associated files to be stored. Default to 20.
|
87 |
-
|
88 |
-
"""
|
89 |
-
# Get the oldest key files in the key directory
|
90 |
-
key_dirs = sorted(FHE_KEYS.iterdir(), key=os.path.getmtime)
|
91 |
-
|
92 |
-
# If more than n_keys keys are found, remove the oldest
|
93 |
-
user_ids = []
|
94 |
-
if len(key_dirs) > n_keys:
|
95 |
-
n_keys_to_delete = len(key_dirs) - n_keys
|
96 |
-
for key_dir in key_dirs[:n_keys_to_delete]:
|
97 |
-
user_ids.append(key_dir.name)
|
98 |
-
shutil.rmtree(key_dir)
|
99 |
-
|
100 |
-
# Get all the encrypted objects in the temporary folder
|
101 |
-
client_files = CLIENT_FILES.iterdir()
|
102 |
-
server_files = SERVER_FILES.iterdir()
|
103 |
-
|
104 |
-
# Delete all files related to the ids whose keys were deleted
|
105 |
-
for file in chain(client_files, server_files):
|
106 |
-
for user_id in user_ids:
|
107 |
-
if user_id in file.name:
|
108 |
-
file.unlink()
|
109 |
-
|
110 |
-
|
111 |
-
def keygen(client_id, client_type):
|
112 |
-
"""Generate the private key associated to a filter.
|
113 |
-
|
114 |
-
Args:
|
115 |
-
client_id (int): The client ID to consider.
|
116 |
-
client_type (str): The type of client to consider (either 'user', 'bank' or 'third_party').
|
117 |
-
"""
|
118 |
-
# Clean temporary files
|
119 |
-
clean_temporary_files()
|
120 |
-
|
121 |
-
# Retrieve the client instance
|
122 |
-
client = get_client(client_id, client_type)
|
123 |
-
|
124 |
-
# Generate a private key
|
125 |
-
client.generate_private_and_evaluation_keys(force=True)
|
126 |
-
|
127 |
-
# Retrieve the serialized evaluation key. In this case, as circuits are fully leveled, this
|
128 |
-
# evaluation key is empty. However, for software reasons, it is still needed for proper FHE
|
129 |
-
# execution
|
130 |
-
evaluation_key = client.get_serialized_evaluation_keys()
|
131 |
-
|
132 |
-
# Save evaluation_key as bytes in a file as it is too large to pass through regular Gradio
|
133 |
-
# buttons (see https://github.com/gradio-app/gradio/issues/1877)
|
134 |
-
evaluation_key_path = get_client_file_path("evaluation_key", client_id, client_type)
|
135 |
-
|
136 |
-
with evaluation_key_path.open("wb") as evaluation_key_file:
|
137 |
-
evaluation_key_file.write(evaluation_key)
|
138 |
-
|
139 |
-
|
140 |
-
def send_input(client_id, client_type):
|
141 |
-
"""Send the encrypted input image as well as the evaluation key to the server.
|
142 |
-
|
143 |
-
Args:
|
144 |
-
client_id (int): The client ID to consider.
|
145 |
-
client_type (str): The type of client to consider (either 'user', 'bank' or 'third_party').
|
146 |
-
"""
|
147 |
-
# Get the paths to the evaluation key and encrypted inputs
|
148 |
-
evaluation_key_path = get_client_file_path("evaluation_key", client_id, client_type)
|
149 |
-
encrypted_input_path = get_client_file_path("encrypted_inputs", client_id, client_type)
|
150 |
-
|
151 |
-
# Define the data and files to post
|
152 |
-
data = {
|
153 |
-
"client_id": client_id,
|
154 |
-
"client_type": client_type,
|
155 |
-
}
|
156 |
-
|
157 |
-
files = [
|
158 |
-
("files", open(encrypted_input_path, "rb")),
|
159 |
-
("files", open(evaluation_key_path, "rb")),
|
160 |
-
]
|
161 |
-
|
162 |
-
# Send the encrypted input image and evaluation key to the server
|
163 |
-
url = SERVER_URL + "send_input"
|
164 |
-
with requests.post(
|
165 |
-
url=url,
|
166 |
-
data=data,
|
167 |
-
files=files,
|
168 |
-
) as response:
|
169 |
-
return response.ok
|
170 |
-
|
171 |
-
|
172 |
-
def keygen_encrypt_send(inputs, client_type):
|
173 |
-
"""Encrypt the given inputs for a specific client.
|
174 |
-
|
175 |
-
Args:
|
176 |
-
inputs (numpy.ndarray): The inputs to encrypt.
|
177 |
-
client_type (str): The type of client to consider (either 'user', 'bank' or 'third_party').
|
178 |
-
|
179 |
-
Returns:
|
180 |
-
|
181 |
-
"""
|
182 |
-
# Create an ID for the current client to consider
|
183 |
-
client_id = numpy.random.randint(0, 2**32)
|
184 |
-
|
185 |
-
keygen(client_id, client_type)
|
186 |
-
|
187 |
-
# Retrieve the client instance
|
188 |
-
client = get_client(client_id, client_type)
|
189 |
-
|
190 |
-
# TODO : pre-process the data first
|
191 |
-
|
192 |
-
# Quantize, encrypt and serialize the inputs
|
193 |
-
encrypted_inputs = client.quantize_encrypt_serialize_multi_inputs(
|
194 |
-
inputs,
|
195 |
-
input_index=INPUT_INDEXES[client_type],
|
196 |
-
initial_input_shape=INITIAL_INPUT_SHAPE,
|
197 |
-
start_position=START_POSITIONS[client_type],
|
198 |
-
)
|
199 |
-
|
200 |
-
# Save encrypted_inputs to bytes in a file, since too large to pass through regular Gradio
|
201 |
-
# buttons, https://github.com/gradio-app/gradio/issues/1877
|
202 |
-
encrypted_inputs_path = get_client_file_path("encrypted_inputs", client_id, client_type)
|
203 |
-
|
204 |
-
with encrypted_inputs_path.open("wb") as encrypted_inputs_file:
|
205 |
-
encrypted_inputs_file.write(encrypted_inputs)
|
206 |
-
|
207 |
-
# Create a truncated version of the encrypted image for display
|
208 |
-
encrypted_inputs_short = shorten_bytes_object(encrypted_inputs)
|
209 |
-
|
210 |
-
send_input(client_id, client_type)
|
211 |
-
|
212 |
-
# TODO: also return private key representation if possible
|
213 |
-
return encrypted_inputs_short
|
214 |
-
|
215 |
-
|
216 |
-
def run_fhe(client_id):
|
217 |
-
"""Run the model on the encrypted inputs previously sent using FHE.
|
218 |
-
|
219 |
-
Args:
|
220 |
-
client_id (int): The client ID to consider.
|
221 |
-
"""
|
222 |
-
|
223 |
-
# TODO : add a warning for users to send all client types' inputs
|
224 |
-
|
225 |
-
data = {
|
226 |
-
"client_id": client_id,
|
227 |
-
}
|
228 |
-
|
229 |
-
# Trigger the FHE execution on the encrypted inputs previously sent
|
230 |
-
url = SERVER_URL + "run_fhe"
|
231 |
-
with requests.post(
|
232 |
-
url=url,
|
233 |
-
data=data,
|
234 |
-
) as response:
|
235 |
-
if response.ok:
|
236 |
-
return response.json()
|
237 |
-
else:
|
238 |
-
raise gr.Error("Please wait for the inputs to be sent to the server.")
|
239 |
-
|
240 |
-
|
241 |
-
def get_output(client_id):
|
242 |
-
"""Retrieve the encrypted output.
|
243 |
-
|
244 |
-
Args:
|
245 |
-
client_id (int): The client ID to consider.
|
246 |
-
|
247 |
-
Returns:
|
248 |
-
output_encrypted_representation (numpy.ndarray): A representation of the encrypted output.
|
249 |
-
|
250 |
-
"""
|
251 |
-
data = {
|
252 |
-
"client_id": client_id,
|
253 |
-
}
|
254 |
-
|
255 |
-
# Retrieve the encrypted output image
|
256 |
-
url = SERVER_URL + "get_output"
|
257 |
-
with requests.post(
|
258 |
-
url=url,
|
259 |
-
data=data,
|
260 |
-
) as response:
|
261 |
-
if response.ok:
|
262 |
-
encrypted_output = response.content
|
263 |
-
|
264 |
-
# Save the encrypted output to bytes in a file as it is too large to pass through regular
|
265 |
-
# Gradio buttons (see https://github.com/gradio-app/gradio/issues/1877)
|
266 |
-
# TODO : check if output to user is relevant
|
267 |
-
encrypted_output_path = get_client_file_path("encrypted_output", client_id, "user")
|
268 |
-
|
269 |
-
with encrypted_output_path.open("wb") as encrypted_output_file:
|
270 |
-
encrypted_output_file.write(encrypted_output)
|
271 |
-
|
272 |
-
# TODO
|
273 |
-
# Decrypt the output using a different (wrong) key for display
|
274 |
-
# output_encrypted_representation = decrypt_output_with_wrong_key(encrypted_output, client_type)
|
275 |
-
|
276 |
-
# return output_encrypted_representation
|
277 |
-
|
278 |
-
return None
|
279 |
-
else:
|
280 |
-
raise gr.Error("Please wait for the FHE execution to be completed.")
|
281 |
-
|
282 |
-
|
283 |
-
def decrypt_output(client_id, client_type):
|
284 |
-
"""Decrypt the result.
|
285 |
-
|
286 |
-
Args:
|
287 |
-
client_id (int): The client ID to consider.
|
288 |
-
client_type (str): The type of client to consider (either 'user', 'bank' or 'third_party').
|
289 |
-
|
290 |
-
Returns:
|
291 |
-
output(numpy.ndarray): The decrypted output
|
292 |
-
|
293 |
-
"""
|
294 |
-
# Get the encrypted output path
|
295 |
-
encrypted_output_path = get_client_file_path("encrypted_output", client_id, client_type)
|
296 |
-
|
297 |
-
if not encrypted_output_path.is_file():
|
298 |
-
raise gr.Error("Please run the FHE execution first.")
|
299 |
-
|
300 |
-
# Load the encrypted output as bytes
|
301 |
-
with encrypted_output_path.open("rb") as encrypted_output_file:
|
302 |
-
encrypted_output_proba = encrypted_output_file.read()
|
303 |
-
|
304 |
-
# Retrieve the client API
|
305 |
-
client = get_client(client_id, client_type)
|
306 |
-
|
307 |
-
# Deserialize, decrypt and post-process the encrypted output
|
308 |
-
output_proba = client.deserialize_decrypt_post_process(encrypted_output_proba)
|
309 |
-
|
310 |
-
# Determine the predicted class
|
311 |
-
output = numpy.argmax(output_proba, axis=1)
|
312 |
-
|
313 |
-
return output
|
314 |
-
|
315 |
-
|
316 |
demo = gr.Blocks()
|
317 |
|
318 |
|
@@ -330,60 +51,68 @@ with demo:
|
|
330 |
with gr.Row():
|
331 |
with gr.Column():
|
332 |
gr.Markdown("### User")
|
333 |
-
|
334 |
-
|
335 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
336 |
|
337 |
with gr.Column():
|
338 |
gr.Markdown("### Bank ")
|
339 |
-
|
340 |
-
checkbox_1 = gr.CheckboxGroup(["USA", "Japan", "Pakistan"], label="Countries", info="Where are they from?")
|
341 |
|
342 |
with gr.Column():
|
343 |
-
gr.Markdown("### Third
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
|
348 |
gr.Markdown("### Step 2: Keygen, encrypt using FHE and send the inputs to the server.")
|
349 |
with gr.Row():
|
350 |
with gr.Column():
|
351 |
gr.Markdown("### User")
|
352 |
encrypt_button_user = gr.Button("Encrypt the inputs and send to server.")
|
353 |
-
|
354 |
-
|
355 |
-
)
|
356 |
encrypted_input_user = gr.Textbox(
|
357 |
label="Encrypted input representation:", max_lines=2, interactive=False
|
358 |
)
|
|
|
|
|
|
|
359 |
|
360 |
-
user_id = gr.Textbox(label="", max_lines=2, interactive=False, visible=False)
|
361 |
|
362 |
|
363 |
with gr.Column():
|
364 |
gr.Markdown("### Bank ")
|
365 |
encrypt_button_bank = gr.Button("Encrypt the inputs and send to server.")
|
366 |
-
|
367 |
-
|
368 |
-
)
|
369 |
encrypted_input_bank = gr.Textbox(
|
370 |
label="Encrypted input representation:", max_lines=2, interactive=False
|
371 |
)
|
372 |
-
|
373 |
-
|
|
|
374 |
|
375 |
|
376 |
with gr.Column():
|
377 |
gr.Markdown("### Third Party ")
|
378 |
encrypt_button_third_party = gr.Button("Encrypt the inputs and send to server.")
|
379 |
-
keys_3 = gr.Textbox(
|
380 |
-
label="Keys representation:", max_lines=2, interactive=False
|
381 |
-
)
|
382 |
-
encrypted_input__third_party = gr.Textbox(
|
383 |
-
label="Encrypted input representation:", max_lines=2, interactive=False
|
384 |
-
)
|
385 |
|
386 |
third_party_id = gr.Textbox(label="", max_lines=2, interactive=False, visible=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
387 |
|
388 |
gr.Markdown("## Server side")
|
389 |
gr.Markdown(
|
@@ -412,9 +141,9 @@ with demo:
|
|
412 |
)
|
413 |
get_output_button = gr.Button("Receive the encrypted output from the server.")
|
414 |
|
415 |
-
encrypted_output_representation = gr.Textbox(
|
416 |
-
|
417 |
-
)
|
418 |
|
419 |
gr.Markdown("### Step 8: Decrypt the output.")
|
420 |
decrypt_button = gr.Button("Decrypt the output")
|
@@ -423,48 +152,50 @@ with demo:
|
|
423 |
label="Credit card approval decision: ", max_lines=1, interactive=False
|
424 |
)
|
425 |
|
426 |
-
# Button to encrypt inputs
|
427 |
-
#
|
428 |
-
|
429 |
-
|
430 |
-
|
431 |
-
|
432 |
-
|
433 |
-
|
434 |
-
# encrypt_button_bank.click(
|
435 |
-
# encrypt,
|
436 |
-
# inputs=[user_id, input_image, filter_name],
|
437 |
-
# outputs=[original_image, encrypted_input],
|
438 |
-
# )
|
439 |
|
440 |
-
#
|
441 |
-
#
|
442 |
-
|
443 |
-
|
444 |
-
|
445 |
-
|
|
|
446 |
|
447 |
-
#
|
448 |
-
#
|
449 |
-
|
450 |
-
|
|
|
|
|
|
|
451 |
|
452 |
-
#
|
453 |
-
#
|
|
|
454 |
|
455 |
-
#
|
456 |
-
#
|
457 |
-
|
458 |
-
|
459 |
-
|
460 |
-
|
|
|
461 |
|
462 |
-
#
|
463 |
-
#
|
464 |
-
|
465 |
-
|
466 |
-
|
467 |
-
|
|
|
468 |
|
469 |
gr.Markdown(
|
470 |
"The app was built with [Concrete-ML](https://github.com/zama-ai/concrete-ml), a "
|
|
|
1 |
+
"""A gradio app for credit card approval prediction using FHE."""
|
2 |
|
|
|
|
|
3 |
import subprocess
|
4 |
import time
|
5 |
import gradio as gr
|
|
|
|
|
|
|
6 |
|
7 |
from settings import (
|
8 |
REPO_DIR,
|
9 |
+
ACCOUNT_MIN_MAX,
|
10 |
+
CHILDREN_MIN_MAX,
|
11 |
+
INCOME_MIN_MAX,
|
12 |
+
AGE_MIN_MAX,
|
13 |
+
EMPLOYED_MIN_MAX,
|
14 |
+
FAMILY_MIN_MAX,
|
15 |
+
INCOME_TYPES,
|
16 |
+
OCCUPATION_TYPES,
|
17 |
+
HOUSING_TYPES,
|
18 |
+
EDUCATION_TYPES,
|
19 |
+
FAMILY_STATUS,
|
20 |
+
)
|
21 |
+
from backend import (
|
22 |
+
shorten_bytes_object,
|
23 |
+
clean_temporary_files,
|
24 |
+
pre_process_keygen_encrypt_send_user,
|
25 |
+
pre_process_keygen_encrypt_send_bank,
|
26 |
+
pre_process_keygen_encrypt_send_third_party,
|
27 |
+
run_fhe,
|
28 |
+
get_output,
|
29 |
+
decrypt_output,
|
30 |
)
|
|
|
|
|
31 |
|
32 |
|
33 |
subprocess.Popen(["uvicorn", "server:app"], cwd=REPO_DIR)
|
34 |
time.sleep(3)
|
35 |
|
36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
demo = gr.Blocks()
|
38 |
|
39 |
|
|
|
51 |
with gr.Row():
|
52 |
with gr.Column():
|
53 |
gr.Markdown("### User")
|
54 |
+
gender = gr.Radio(["Female", "Male"], label="Gender")
|
55 |
+
bool_inputs = gr.CheckboxGroup(["Car", "Property", "Work phone", "Phone", "Email"], label="What do you own ?")
|
56 |
+
num_children = gr.Slider(**CHILDREN_MIN_MAX, step=1, label="Number of children", info="How many children do you have (0 to 19) ?")
|
57 |
+
num_family = gr.Slider(**FAMILY_MIN_MAX, step=1, label="Family", info="How many members does your family have? (1 to 20) ?")
|
58 |
+
total_income = gr.Slider(**INCOME_MIN_MAX, label="Income", info="What's you total yearly income (in euros, 3780 to 220500) ?")
|
59 |
+
age = gr.Slider(**AGE_MIN_MAX, step=1, label="Age", info="How old are you (20 to 68) ?")
|
60 |
+
income_type = gr.Dropdown(choices=INCOME_TYPES, label="Income type", info="What is your main type of income ?")
|
61 |
+
education_type = gr.Dropdown(choices=EDUCATION_TYPES, label="Education", info="What is your education background ?")
|
62 |
+
family_status = gr.Dropdown(choices=FAMILY_STATUS, label="Family", info="What is your family status ?")
|
63 |
+
occupation_type = gr.Dropdown(choices=OCCUPATION_TYPES, label="Occupation", info="What is your main occupation ?")
|
64 |
+
housing_type = gr.Dropdown(choices=HOUSING_TYPES, label="Housing", info="In what type of housing do you live ?")
|
65 |
|
66 |
with gr.Column():
|
67 |
gr.Markdown("### Bank ")
|
68 |
+
account_length = gr.Slider(**ACCOUNT_MIN_MAX, step=1, label="Account length", info="How long have this person had this account (in months, 0 to 60) ?")
|
|
|
69 |
|
70 |
with gr.Column():
|
71 |
+
gr.Markdown("### Third party ")
|
72 |
+
employed = gr.Radio(["Yes", "No"], label="Is the person employed ?")
|
73 |
+
years_employed = gr.Slider(**EMPLOYED_MIN_MAX, step=1, label="Years of employment", info="How long have this person been employed (in years, 0 to 43) ?")
|
74 |
+
|
75 |
|
76 |
gr.Markdown("### Step 2: Keygen, encrypt using FHE and send the inputs to the server.")
|
77 |
with gr.Row():
|
78 |
with gr.Column():
|
79 |
gr.Markdown("### User")
|
80 |
encrypt_button_user = gr.Button("Encrypt the inputs and send to server.")
|
81 |
+
|
82 |
+
user_id = gr.Textbox(label="", max_lines=2, interactive=False, visible=False)
|
|
|
83 |
encrypted_input_user = gr.Textbox(
|
84 |
label="Encrypted input representation:", max_lines=2, interactive=False
|
85 |
)
|
86 |
+
# keys_user = gr.Textbox(
|
87 |
+
# label="Keys representation:", max_lines=2, interactive=False
|
88 |
+
# )
|
89 |
|
|
|
90 |
|
91 |
|
92 |
with gr.Column():
|
93 |
gr.Markdown("### Bank ")
|
94 |
encrypt_button_bank = gr.Button("Encrypt the inputs and send to server.")
|
95 |
+
|
96 |
+
bank_id = gr.Textbox(label="", max_lines=2, interactive=False, visible=False)
|
|
|
97 |
encrypted_input_bank = gr.Textbox(
|
98 |
label="Encrypted input representation:", max_lines=2, interactive=False
|
99 |
)
|
100 |
+
# keys_bank = gr.Textbox(
|
101 |
+
# label="Keys representation:", max_lines=2, interactive=False
|
102 |
+
# )
|
103 |
|
104 |
|
105 |
with gr.Column():
|
106 |
gr.Markdown("### Third Party ")
|
107 |
encrypt_button_third_party = gr.Button("Encrypt the inputs and send to server.")
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
|
109 |
third_party_id = gr.Textbox(label="", max_lines=2, interactive=False, visible=False)
|
110 |
+
encrypted_input_third_party = gr.Textbox(
|
111 |
+
label="Encrypted input representation:", max_lines=2, interactive=False
|
112 |
+
)
|
113 |
+
# keys_3 = gr.Textbox(
|
114 |
+
# label="Keys representation:", max_lines=2, interactive=False
|
115 |
+
# )
|
116 |
|
117 |
gr.Markdown("## Server side")
|
118 |
gr.Markdown(
|
|
|
141 |
)
|
142 |
get_output_button = gr.Button("Receive the encrypted output from the server.")
|
143 |
|
144 |
+
# encrypted_output_representation = gr.Textbox(
|
145 |
+
# label="Encrypted output representation: ", max_lines=1, interactive=False
|
146 |
+
# )
|
147 |
|
148 |
gr.Markdown("### Step 8: Decrypt the output.")
|
149 |
decrypt_button = gr.Button("Decrypt the output")
|
|
|
152 |
label="Credit card approval decision: ", max_lines=1, interactive=False
|
153 |
)
|
154 |
|
155 |
+
# Button to pre-process, generate the key, encrypt and send the user inputs from the client
|
156 |
+
# side to the server
|
157 |
+
encrypt_button_user.click(
|
158 |
+
pre_process_keygen_encrypt_send_user,
|
159 |
+
inputs=[gender, bool_inputs, num_children, num_family, total_income, age, income_type, \
|
160 |
+
education_type, family_status, occupation_type, housing_type],
|
161 |
+
outputs=[user_id, encrypted_input_user],
|
162 |
+
)
|
|
|
|
|
|
|
|
|
|
|
163 |
|
164 |
+
# Button to pre-process, generate the key, encrypt and send the bank inputs from the client
|
165 |
+
# side to the server
|
166 |
+
encrypt_button_bank.click(
|
167 |
+
pre_process_keygen_encrypt_send_bank,
|
168 |
+
inputs=[account_length],
|
169 |
+
outputs=[bank_id, encrypted_input_bank],
|
170 |
+
)
|
171 |
|
172 |
+
# Button to pre-process, generate the key, encrypt and send the third party inputs from the
|
173 |
+
# client side to the server
|
174 |
+
encrypt_button_third_party.click(
|
175 |
+
pre_process_keygen_encrypt_send_third_party,
|
176 |
+
inputs=[employed, years_employed],
|
177 |
+
outputs=[third_party_id, encrypted_input_third_party],
|
178 |
+
)
|
179 |
|
180 |
+
# TODO : ID should be unique
|
181 |
+
# Button to send the encodings to the server using post method
|
182 |
+
execute_fhe_button.click(run_fhe, inputs=[user_id, bank_id, third_party_id], outputs=[fhe_execution_time])
|
183 |
|
184 |
+
# TODO : ID should be unique
|
185 |
+
# Button to send the encodings to the server using post method
|
186 |
+
get_output_button.click(
|
187 |
+
get_output,
|
188 |
+
inputs=[user_id, bank_id, third_party_id],
|
189 |
+
# outputs=[encrypted_output_representation]
|
190 |
+
)
|
191 |
|
192 |
+
# TODO : ID should be unique
|
193 |
+
# Button to decrypt the output as the user
|
194 |
+
decrypt_button.click(
|
195 |
+
decrypt_output,
|
196 |
+
inputs=[user_id, bank_id, third_party_id],
|
197 |
+
outputs=[prediction_output],
|
198 |
+
)
|
199 |
|
200 |
gr.Markdown(
|
201 |
"The app was built with [Concrete-ML](https://github.com/zama-ai/concrete-ml), a "
|
backend.py
ADDED
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Backend functions used in the app."""
|
2 |
+
|
3 |
+
import os
|
4 |
+
import shutil
|
5 |
+
import gradio as gr
|
6 |
+
import numpy
|
7 |
+
import requests
|
8 |
+
import pickle
|
9 |
+
import pandas
|
10 |
+
from itertools import chain
|
11 |
+
|
12 |
+
from settings import (
|
13 |
+
SERVER_URL,
|
14 |
+
FHE_KEYS,
|
15 |
+
CLIENT_FILES,
|
16 |
+
SERVER_FILES,
|
17 |
+
DEPLOYMENT_PATH,
|
18 |
+
INITIAL_INPUT_SHAPE,
|
19 |
+
INPUT_INDEXES,
|
20 |
+
INPUT_SLICES,
|
21 |
+
PRE_PROCESSOR_USER_PATH,
|
22 |
+
PRE_PROCESSOR_THIRD_PARTY_PATH,
|
23 |
+
CLIENT_TYPES,
|
24 |
+
)
|
25 |
+
|
26 |
+
from utils.client_server_interface import MultiInputsFHEModelClient
|
27 |
+
|
28 |
+
# Load pre-processor instances
|
29 |
+
with PRE_PROCESSOR_USER_PATH.open('rb') as file:
|
30 |
+
PRE_PROCESSOR_USER = pickle.load(file)
|
31 |
+
|
32 |
+
with PRE_PROCESSOR_THIRD_PARTY_PATH.open('rb') as file:
|
33 |
+
PRE_PROCESSOR_THIRD_PARTY = pickle.load(file)
|
34 |
+
|
35 |
+
|
36 |
+
def shorten_bytes_object(bytes_object, limit=500):
|
37 |
+
"""Shorten the input bytes object to a given length.
|
38 |
+
|
39 |
+
Encrypted data is too large for displaying it in the browser using Gradio. This function
|
40 |
+
provides a shorten representation of it.
|
41 |
+
|
42 |
+
Args:
|
43 |
+
bytes_object (bytes): The input to shorten
|
44 |
+
limit (int): The length to consider. Default to 500.
|
45 |
+
|
46 |
+
Returns:
|
47 |
+
str: Hexadecimal string shorten representation of the input byte object.
|
48 |
+
|
49 |
+
"""
|
50 |
+
# Define a shift for better display
|
51 |
+
shift = 100
|
52 |
+
return bytes_object[shift : limit + shift].hex()
|
53 |
+
|
54 |
+
|
55 |
+
def clean_temporary_files(n_keys=20):
|
56 |
+
"""Clean keys and encrypted images.
|
57 |
+
|
58 |
+
A maximum of n_keys keys and associated temporary files are allowed to be stored. Once this
|
59 |
+
limit is reached, the oldest files are deleted.
|
60 |
+
|
61 |
+
Args:
|
62 |
+
n_keys (int): The maximum number of keys and associated files to be stored. Default to 20.
|
63 |
+
|
64 |
+
"""
|
65 |
+
# Get the oldest key files in the key directory
|
66 |
+
key_dirs = sorted(FHE_KEYS.iterdir(), key=os.path.getmtime)
|
67 |
+
|
68 |
+
# If more than n_keys keys are found, remove the oldest
|
69 |
+
user_ids = []
|
70 |
+
if len(key_dirs) > n_keys:
|
71 |
+
n_keys_to_delete = len(key_dirs) - n_keys
|
72 |
+
for key_dir in key_dirs[:n_keys_to_delete]:
|
73 |
+
user_ids.append(key_dir.name)
|
74 |
+
shutil.rmtree(key_dir)
|
75 |
+
|
76 |
+
# Get all the encrypted objects in the temporary folder
|
77 |
+
client_files = CLIENT_FILES.iterdir()
|
78 |
+
server_files = SERVER_FILES.iterdir()
|
79 |
+
|
80 |
+
# Delete all files related to the ids whose keys were deleted
|
81 |
+
for file in chain(client_files, server_files):
|
82 |
+
for user_id in user_ids:
|
83 |
+
if user_id in file.name:
|
84 |
+
file.unlink()
|
85 |
+
|
86 |
+
|
87 |
+
def _get_client(client_id, client_type):
|
88 |
+
"""Get the client API.
|
89 |
+
|
90 |
+
Args:
|
91 |
+
client_id (int): The client ID to consider.
|
92 |
+
client_type (str): The type of user to consider (either 'user', 'bank' or 'third_party').
|
93 |
+
|
94 |
+
Returns:
|
95 |
+
FHEModelClient: The client API.
|
96 |
+
"""
|
97 |
+
key_dir = FHE_KEYS / f"{client_type}_{client_id}"
|
98 |
+
|
99 |
+
return MultiInputsFHEModelClient(DEPLOYMENT_PATH, key_dir=key_dir, nb_inputs=len(CLIENT_TYPES))
|
100 |
+
|
101 |
+
|
102 |
+
def _keygen(client_id, client_type):
|
103 |
+
"""Generate the private key associated to a filter.
|
104 |
+
|
105 |
+
Args:
|
106 |
+
client_id (int): The client ID to consider.
|
107 |
+
client_type (str): The type of client to consider (either 'user', 'bank' or 'third_party').
|
108 |
+
"""
|
109 |
+
# Clean temporary files
|
110 |
+
clean_temporary_files()
|
111 |
+
|
112 |
+
# Retrieve the client instance
|
113 |
+
client = _get_client(client_id, client_type)
|
114 |
+
|
115 |
+
# Generate a private key
|
116 |
+
client.generate_private_and_evaluation_keys(force=True)
|
117 |
+
|
118 |
+
# Retrieve the serialized evaluation key. In this case, as circuits are fully leveled, this
|
119 |
+
# evaluation key is empty. However, for software reasons, it is still needed for proper FHE
|
120 |
+
# execution
|
121 |
+
evaluation_key = client.get_serialized_evaluation_keys()
|
122 |
+
|
123 |
+
# Save evaluation_key as bytes in a file as it is too large to pass through regular Gradio
|
124 |
+
# buttons (see https://github.com/gradio-app/gradio/issues/1877)
|
125 |
+
evaluation_key_path = _get_client_file_path("evaluation_key", client_id, client_type)
|
126 |
+
|
127 |
+
with evaluation_key_path.open("wb") as evaluation_key_file:
|
128 |
+
evaluation_key_file.write(evaluation_key)
|
129 |
+
|
130 |
+
|
131 |
+
def _send_input(client_id, client_type):
|
132 |
+
"""Send the encrypted input image as well as the evaluation key to the server.
|
133 |
+
|
134 |
+
Args:
|
135 |
+
client_id (int): The client ID to consider.
|
136 |
+
client_type (str): The type of client to consider (either 'user', 'bank' or 'third_party').
|
137 |
+
"""
|
138 |
+
# Get the paths to the evaluation key and encrypted inputs
|
139 |
+
evaluation_key_path = _get_client_file_path("evaluation_key", client_id, client_type)
|
140 |
+
encrypted_input_path = _get_client_file_path("encrypted_inputs", client_id, client_type)
|
141 |
+
|
142 |
+
# Define the data and files to post
|
143 |
+
data = {
|
144 |
+
"client_id": client_id,
|
145 |
+
"client_type": client_type,
|
146 |
+
}
|
147 |
+
|
148 |
+
files = [
|
149 |
+
("files", open(encrypted_input_path, "rb")),
|
150 |
+
("files", open(evaluation_key_path, "rb")),
|
151 |
+
]
|
152 |
+
|
153 |
+
# Send the encrypted input image and evaluation key to the server
|
154 |
+
url = SERVER_URL + "send_input"
|
155 |
+
with requests.post(
|
156 |
+
url=url,
|
157 |
+
data=data,
|
158 |
+
files=files,
|
159 |
+
) as response:
|
160 |
+
return response.ok
|
161 |
+
|
162 |
+
|
163 |
+
def _get_client_file_path(name, client_id, client_type):
|
164 |
+
"""Get the correct temporary file path for the client.
|
165 |
+
|
166 |
+
Args:
|
167 |
+
name (str): The desired file name (either 'evaluation_key' or 'encrypted_inputs').
|
168 |
+
client_id (int): The client ID to consider.
|
169 |
+
client_type (str): The type of user to consider (either 'user', 'bank' or 'third_party').
|
170 |
+
|
171 |
+
Returns:
|
172 |
+
pathlib.Path: The file path.
|
173 |
+
"""
|
174 |
+
return CLIENT_FILES / f"{name}_{client_type}_{client_id}"
|
175 |
+
|
176 |
+
|
177 |
+
def _keygen_encrypt_send(inputs, client_type):
    """Generate keys for a client, encrypt its inputs and send them to the server.

    Args:
        inputs (numpy.ndarray): The inputs to encrypt.
        client_type (str): The type of client to consider (either 'user', 'bank' or 'third_party').

    Returns:
        (int, bytes): Integer ID representing the current client and a short hexadecimal
            representation of the encrypted input that was sent.
    """
    # Draw a random identifier for this client session
    current_client_id = numpy.random.randint(0, 2**32)

    # Generate the FHE keys for this client
    _keygen(current_client_id, client_type)

    # Retrieve the client API instance
    client_api = _get_client(current_client_id, client_type)

    # TODO : pre-process the data first

    # Quantize, encrypt and serialize the inputs, placing them at the position expected by
    # the multi-inputs circuit for this client type
    serialized_encrypted_inputs = client_api.quantize_encrypt_serialize_multi_inputs(
        inputs,
        input_index=INPUT_INDEXES[client_type],
        initial_input_shape=INITIAL_INPUT_SHAPE,
        input_slice=INPUT_SLICES[client_type],
    )

    # Encrypted data is too large to fit through regular Gradio components, so it is written
    # to disk instead (see https://github.com/gradio-app/gradio/issues/1877)
    encrypted_inputs_path = _get_client_file_path(
        "encrypted_inputs", current_client_id, client_type
    )
    encrypted_inputs_path.write_bytes(serialized_encrypted_inputs)

    # Build a truncated hexadecimal version of the encrypted inputs for display
    short_representation = shorten_bytes_object(serialized_encrypted_inputs)

    _send_input(current_client_id, client_type)

    # TODO: also return private key representation if possible
    return current_client_id, short_representation
|
220 |
+
|
221 |
+
|
222 |
+
def pre_process_keygen_encrypt_send_user(*inputs):
    """Pre-process the user's Gradio inputs, then encrypt and send them.

    Args:
        *inputs (Tuple[numpy.ndarray]): The inputs to pre-process.

    Returns:
        (int, bytes): Integer ID representing the current client and a byte short
            representation of the encrypted input to send.
    """
    (
        gender,
        bool_inputs,
        num_children,
        num_family,
        total_income,
        age,
        income_type,
        education_type,
        family_status,
        occupation_type,
        housing_type,
    ) = inputs

    # Encoding given in https://www.kaggle.com/code/samuelcortinhas/credit-cards-data-cleaning
    # for "Gender" is M ('Male') -> 1 and F ('Female') -> 0
    is_male = gender == "Male"

    # The checkbox group is turned into one boolean flag per expected feature; all values
    # are wrapped in single-element lists to build a one-row data frame
    user_inputs = pandas.DataFrame(
        {
            "Gender": [is_male],
            "Own_car": ["Car" in bool_inputs],
            "Own_property": ["Property" in bool_inputs],
            "Work_phone": ["Work phone" in bool_inputs],
            "Phone": ["Phone" in bool_inputs],
            "Email": ["Email" in bool_inputs],
            "Num_children": [num_children],
            "Num_family": [num_family],
            "Total_income": [total_income],
            "Age": [age],
            "Income_type": [income_type],
            "Education_type": [education_type],
            "Family_status": [family_status],
            "Occupation_type": [occupation_type],
            "Housing_type": [housing_type],
        }
    )

    # Apply the pre-processor fitted at development time
    preprocessed_user_inputs = PRE_PROCESSOR_USER.transform(user_inputs)

    return _keygen_encrypt_send(preprocessed_user_inputs, "user")
|
267 |
+
|
268 |
+
|
269 |
+
def pre_process_keygen_encrypt_send_bank(*inputs):
    """Pre-process the bank's Gradio inputs, then encrypt and send them.

    Args:
        *inputs (Tuple[numpy.ndarray]): The inputs to pre-process.

    Returns:
        (int, bytes): Integer ID representing the current client and a byte short
            representation of the encrypted input to send.
    """
    # The bank provides a single feature, which requires no further pre-processing
    account_length, *_ = inputs

    return _keygen_encrypt_send(account_length, "bank")
|
282 |
+
|
283 |
+
|
284 |
+
def pre_process_keygen_encrypt_send_third_party(*inputs):
    """Pre-process the third party's Gradio inputs, then encrypt and send them.

    Args:
        *inputs (Tuple[numpy.ndarray]): The inputs to pre-process.

    Returns:
        (int, bytes): Integer ID representing the current client and a byte short
            representation of the encrypted input to send.
    """
    employed, years_employed = inputs

    # Original dataset contains an "unemployed" feature instead of "employed", so the
    # answer is inverted here
    is_unemployed = employed == "No"

    third_party_inputs = pandas.DataFrame(
        {
            "Unemployed": [is_unemployed],
            "Years_employed": [years_employed],
        }
    )

    # Apply the pre-processor fitted at development time
    preprocessed_third_party_inputs = PRE_PROCESSOR_THIRD_PARTY.transform(third_party_inputs)

    return _keygen_encrypt_send(preprocessed_third_party_inputs, "third_party")
|
307 |
+
|
308 |
+
|
309 |
+
def run_fhe(user_id, bank_id, third_party_id):
    """Trigger the FHE execution on the encrypted inputs previously sent.

    Args:
        user_id (int): The user ID to consider.
        bank_id (int): The bank ID to consider.
        third_party_id (int): The third party ID to consider.

    Returns:
        The server's JSON response.

    Raises:
        gr.Error: If the server does not yet hold all the required inputs.
    """
    # TODO : add a warning for users to send all client types' inputs

    payload = {
        "user_id": user_id,
        "bank_id": bank_id,
        "third_party_id": third_party_id,
    }

    # Trigger the FHE execution on the encrypted inputs previously sent
    with requests.post(url=SERVER_URL + "run_fhe", data=payload) as response:
        if not response.ok:
            raise gr.Error("Please wait for the inputs to be sent to the server.")

        return response.json()
|
336 |
+
|
337 |
+
|
338 |
+
def get_output(user_id, bank_id, third_party_id):
    """Retrieve the encrypted output from the server and store it on disk.

    Args:
        user_id (int): The user ID to consider.
        bank_id (int): The bank ID to consider.
        third_party_id (int): The third party ID to consider.

    Raises:
        gr.Error: If the FHE execution has not been completed yet.
    """
    payload = {
        "user_id": user_id,
        "bank_id": bank_id,
        "third_party_id": third_party_id,
    }

    # Retrieve the encrypted output
    with requests.post(url=SERVER_URL + "get_output", data=payload) as response:
        if not response.ok:
            raise gr.Error("Please wait for the FHE execution to be completed.")

        # The encrypted output is too large to pass through regular Gradio buttons, so it
        # is stored on disk instead (see https://github.com/gradio-app/gradio/issues/1877)
        # NOTE(review): the file is keyed on the *sum* of the three IDs — collisions are
        # possible across sessions; confirm this is acceptable for the demo
        # TODO : check if output to user is relevant
        encrypted_output_path = _get_client_file_path(
            "encrypted_output", user_id + bank_id + third_party_id, "output"
        )
        encrypted_output_path.write_bytes(response.content)

        # TODO
        # Decrypt the output using a different (wrong) key for display
        # output_encrypted_representation = decrypt_output_with_wrong_key(encrypted_output, client_type)

        # return output_encrypted_representation

        return None
|
378 |
+
|
379 |
+
|
380 |
+
def decrypt_output(user_id, bank_id, third_party_id):
    """Decrypt the model's encrypted output using the user's private key.

    Args:
        user_id (int): The user ID to consider.
        bank_id (int): The bank ID to consider.
        third_party_id (int): The third party ID to consider.

    Returns:
        output (numpy.ndarray): The decrypted output.

    Raises:
        gr.Error: If the FHE execution has not been run yet.
    """
    # Locate the encrypted output previously stored by `get_output` (keyed on the sum of
    # the three IDs)
    encrypted_output_path = _get_client_file_path(
        "encrypted_output", user_id + bank_id + third_party_id, "output"
    )

    if not encrypted_output_path.is_file():
        raise gr.Error("Please run the FHE execution first.")

    # Load the serialized encrypted probabilities
    encrypted_output_proba = encrypted_output_path.read_bytes()

    # Retrieve the user's client API, which holds the private key
    client_api = _get_client(user_id, "user")

    # Deserialize, decrypt and de-quantize the encrypted output probabilities
    output_proba = client_api.deserialize_decrypt_dequantize(encrypted_output_proba)

    # The predicted class is the one with the highest probability
    return numpy.argmax(output_proba, axis=1)
|
data/clean_data.csv
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
deployment_files/client.zip
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:06c7bd8264089eb169342aa5c3f638b11d894c54d054511a91523bfdfab69487
|
3 |
+
size 76130
|
deployment_files/pre_processor_third_party.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ee39c00c8ca119a4e61f6905687c9bb540352b5ce4005aaba125290679722587
|
3 |
+
size 1590
|
deployment_files/pre_processor_user.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:af3db3f40e0e38febb8efb858e07df1f432458cc66f2edb38bedbd4d35520802
|
3 |
+
size 6207
|
deployment_files/server.zip
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:04c3f1de7261abe6ad075f6cc13885677ddf4ca0b03d6a31f26a60f94d5aa2ae
|
3 |
+
size 10975
|
development.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Train and compile the model."""
|
2 |
+
|
3 |
+
import shutil
|
4 |
+
import numpy
|
5 |
+
import pandas
|
6 |
+
import pickle
|
7 |
+
|
8 |
+
from sklearn.model_selection import train_test_split
|
9 |
+
from sklearn.metrics import accuracy_score
|
10 |
+
from imblearn.over_sampling import SMOTE
|
11 |
+
|
12 |
+
from settings import DEPLOYMENT_PATH, RANDOM_STATE, DATA_PATH, INPUT_SLICES, PRE_PROCESSOR_USER_PATH, PRE_PROCESSOR_THIRD_PARTY_PATH
|
13 |
+
from utils.client_server_interface import MultiInputsFHEModelDev
|
14 |
+
from utils.model import MultiInputXGBClassifier
|
15 |
+
from utils.pre_processing import get_pre_processors, select_and_pop_features
|
16 |
+
|
17 |
+
|
18 |
+
def get_processed_multi_inputs(data):
|
19 |
+
return (
|
20 |
+
data[:, INPUT_SLICES["user"]],
|
21 |
+
data[:, INPUT_SLICES["bank"]],
|
22 |
+
data[:, INPUT_SLICES["third_party"]]
|
23 |
+
)
|
24 |
+
|
25 |
+
print("Load and pre-process the data")
|
26 |
+
|
27 |
+
data = pandas.read_csv(DATA_PATH, encoding="utf-8")
|
28 |
+
|
29 |
+
# Define input and target data
|
30 |
+
data_y = data.pop("Target").copy()
|
31 |
+
data_x = data.copy()
|
32 |
+
|
33 |
+
# Get data from all parties
|
34 |
+
data_third_party = select_and_pop_features(data_x, ["Years_employed", "Unemployed"])
|
35 |
+
data_bank = select_and_pop_features(data_x, ["Account_length"])
|
36 |
+
data_user = data_x.copy()
|
37 |
+
|
38 |
+
# Feature engineer the data
|
39 |
+
pre_processor_user, pre_processor_third_party = get_pre_processors()
|
40 |
+
|
41 |
+
preprocessed_data_user = pre_processor_user.fit_transform(data_user)
|
42 |
+
preprocessed_data_bank = data_bank.to_numpy()
|
43 |
+
preprocessed_data_third_party = pre_processor_third_party.fit_transform(data_third_party)
|
44 |
+
|
45 |
+
preprocessed_data_x = numpy.concatenate((preprocessed_data_user, preprocessed_data_bank, preprocessed_data_third_party), axis=1)
|
46 |
+
|
47 |
+
# The initial data-set is very imbalanced: use SMOTE to get better results
|
48 |
+
x, y = SMOTE().fit_resample(preprocessed_data_x, data_y)
|
49 |
+
|
50 |
+
# Retrieve the training and testing data
|
51 |
+
X_train, X_test, y_train, y_test = train_test_split(
|
52 |
+
x, y, stratify=y, test_size=0.3, random_state=RANDOM_STATE
|
53 |
+
)
|
54 |
+
|
55 |
+
|
56 |
+
print("\nTrain and compile the model")
|
57 |
+
|
58 |
+
model = MultiInputXGBClassifier(max_depth=3, n_estimators=40)
|
59 |
+
|
60 |
+
model, sklearn_model = model.fit_benchmark(X_train, y_train)
|
61 |
+
|
62 |
+
multi_inputs_train = get_processed_multi_inputs(X_train)
|
63 |
+
|
64 |
+
model.compile(*multi_inputs_train, inputs_encryption_status=["encrypted", "encrypted", "encrypted"])
|
65 |
+
|
66 |
+
# Delete the deployment folder and its content if it already exists
|
67 |
+
if DEPLOYMENT_PATH.is_dir():
|
68 |
+
shutil.rmtree(DEPLOYMENT_PATH)
|
69 |
+
|
70 |
+
|
71 |
+
print("\nEvaluate the models")
|
72 |
+
|
73 |
+
y_pred_sklearn = sklearn_model.predict(X_test)
|
74 |
+
|
75 |
+
print(f"Sklearn accuracy score : {accuracy_score(y_test, y_pred_sklearn )*100:.2f}%")
|
76 |
+
|
77 |
+
multi_inputs_test = get_processed_multi_inputs(X_test)
|
78 |
+
|
79 |
+
y_pred_simulated = model.predict_multi_inputs(*multi_inputs_test, simulate=True)
|
80 |
+
|
81 |
+
print(f"Concrete ML accuracy score (simulated) : {accuracy_score(y_test, y_pred_simulated)*100:.2f}%")
|
82 |
+
|
83 |
+
|
84 |
+
print("\nSave deployment files")
|
85 |
+
|
86 |
+
# Save files needed for deployment
|
87 |
+
fhe_dev = MultiInputsFHEModelDev(DEPLOYMENT_PATH, model)
|
88 |
+
fhe_dev.save()
|
89 |
+
|
90 |
+
# Save pre-processors
|
91 |
+
with PRE_PROCESSOR_USER_PATH.open('wb') as file:
|
92 |
+
pickle.dump(pre_processor_user, file)
|
93 |
+
|
94 |
+
with PRE_PROCESSOR_THIRD_PARTY_PATH.open('wb') as file:
|
95 |
+
pickle.dump(pre_processor_third_party, file)
|
96 |
+
|
97 |
+
print("\nDone !")
|
development/development.py
DELETED
@@ -1,67 +0,0 @@
|
|
1 |
-
"A script to generate all development files necessary for the project."
|
2 |
-
|
3 |
-
import shutil
|
4 |
-
import numpy
|
5 |
-
import pandas
|
6 |
-
|
7 |
-
from sklearn.model_selection import train_test_split
|
8 |
-
from imblearn.over_sampling import SMOTE
|
9 |
-
|
10 |
-
from ..settings import DEPLOYMENT_PATH, RANDOM_STATE
|
11 |
-
from client_server_interface import MultiInputsFHEModelDev
|
12 |
-
from model import MultiInputXGBClassifier
|
13 |
-
from development.pre_processing import pre_process_data
|
14 |
-
|
15 |
-
|
16 |
-
print("Load and pre-process the data")
|
17 |
-
|
18 |
-
data = pandas.read_csv("data/clean_data.csv", encoding="utf-8")
|
19 |
-
|
20 |
-
# Make median annual salary similar to France (2023): from 157500 to 22050
|
21 |
-
data["Total_income"] = data["Total_income"] * 0.14
|
22 |
-
|
23 |
-
# Remove ID feature
|
24 |
-
data.drop("ID", axis=1, inplace=True)
|
25 |
-
|
26 |
-
# Feature engineer the data
|
27 |
-
pre_processed_data, training_bins = pre_process_data(data)
|
28 |
-
|
29 |
-
# Define input and target data
|
30 |
-
y = pre_processed_data.pop("Target")
|
31 |
-
x = pre_processed_data
|
32 |
-
|
33 |
-
# The initial data-set is very imbalanced: use SMOTE to get better results
|
34 |
-
x, y = SMOTE().fit_resample(x, y)
|
35 |
-
|
36 |
-
# Retrieve the training data
|
37 |
-
X_train, _, y_train, _ = train_test_split(
|
38 |
-
x, y, stratify=y, test_size=0.3, random_state=RANDOM_STATE
|
39 |
-
)
|
40 |
-
|
41 |
-
# Convert the Pandas data frames into Numpy arrays
|
42 |
-
X_train_np = X_train.to_numpy()
|
43 |
-
y_train_np = y_train.to_numpy()
|
44 |
-
|
45 |
-
|
46 |
-
print("Train and compile the model")
|
47 |
-
|
48 |
-
model = MultiInputXGBClassifier(max_depth=3, n_estimators=40)
|
49 |
-
|
50 |
-
model.fit(X_train_np, y_train_np)
|
51 |
-
|
52 |
-
multi_inputs_train = numpy.array_split(X_train_np, 3, axis=1)
|
53 |
-
|
54 |
-
model.compile(*multi_inputs_train, inputs_encryption_status=["encrypted", "encrypted", "encrypted"])
|
55 |
-
|
56 |
-
# Delete the deployment folder and its content if it already exists
|
57 |
-
if DEPLOYMENT_PATH.is_dir():
|
58 |
-
shutil.rmtree(DEPLOYMENT_PATH)
|
59 |
-
|
60 |
-
|
61 |
-
print("Save deployment files")
|
62 |
-
|
63 |
-
# Save the files needed for deployment
|
64 |
-
fhe_dev = MultiInputsFHEModelDev(model, DEPLOYMENT_PATH)
|
65 |
-
fhe_dev.save()
|
66 |
-
|
67 |
-
print("Done !")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
development/pre_processing.py
DELETED
@@ -1,122 +0,0 @@
|
|
1 |
-
import pandas
|
2 |
-
from copy import deepcopy
|
3 |
-
|
4 |
-
|
5 |
-
def convert_dummy(df, feature):
|
6 |
-
pos = pandas.get_dummies(df[feature], prefix=feature)
|
7 |
-
|
8 |
-
df.drop([feature], axis=1, inplace=True)
|
9 |
-
df = df.join(pos)
|
10 |
-
return df
|
11 |
-
|
12 |
-
|
13 |
-
def get_category(df, col, labels, qcut=False, binsnum=None, bins=None, retbins=False):
|
14 |
-
assert binsnum is not None or bins is not None
|
15 |
-
|
16 |
-
if qcut and binsnum is not None:
|
17 |
-
localdf, bin_edges = pandas.qcut(df[col], q=binsnum, labels=labels, retbins=True) # quantile cut
|
18 |
-
else:
|
19 |
-
input_bins = bins if bins is not None else binsnum
|
20 |
-
localdf, bin_edges = pandas.cut(df[col], bins=input_bins, labels=labels, retbins=True) # equal-length cut
|
21 |
-
|
22 |
-
df.drop(col, axis=1, inplace=True)
|
23 |
-
|
24 |
-
localdf = pandas.DataFrame(localdf)
|
25 |
-
df = df.join(localdf[col])
|
26 |
-
|
27 |
-
if retbins:
|
28 |
-
return df, bin_edges
|
29 |
-
|
30 |
-
return df
|
31 |
-
|
32 |
-
|
33 |
-
def pre_process_data(input_data, bins=None, columns=None):
|
34 |
-
assert bins is None or ("bin_edges_income" in bins and "bin_edges_age" in bins and "bin_edges_years_employed" in bins and columns is not None)
|
35 |
-
|
36 |
-
training_bins = {}
|
37 |
-
|
38 |
-
input_data = deepcopy(input_data)
|
39 |
-
bins = deepcopy(bins) if bins is not None else None
|
40 |
-
|
41 |
-
input_data.loc[input_data["Num_children"] >= 2, "Num_children"] = "2_or_more"
|
42 |
-
|
43 |
-
input_data = convert_dummy(input_data, "Num_children")
|
44 |
-
|
45 |
-
if bins is None:
|
46 |
-
input_data, bin_edges_income = get_category(input_data, "Total_income", ["low", "medium", "high"], qcut=True, binsnum=3, retbins=True)
|
47 |
-
training_bins["bin_edges_income"] = bin_edges_income
|
48 |
-
else:
|
49 |
-
input_data = get_category(input_data, "Total_income", ["low", "medium", "high"], bins=bins["bin_edges_income"])
|
50 |
-
|
51 |
-
input_data = convert_dummy(input_data, "Total_income")
|
52 |
-
|
53 |
-
if bins is None:
|
54 |
-
input_data, bin_edges_age = get_category(input_data, "Age", ["lowest", "low", "medium", "high", "highest"], binsnum=5, retbins=True)
|
55 |
-
training_bins["bin_edges_age"] = bin_edges_age
|
56 |
-
else:
|
57 |
-
input_data = get_category(input_data, "Age", ["lowest", "low", "medium", "high", "highest"], bins=bins["bin_edges_age"])
|
58 |
-
|
59 |
-
input_data = convert_dummy(input_data, "Age")
|
60 |
-
|
61 |
-
if bins is None:
|
62 |
-
input_data, bin_edges_years_employed = get_category(input_data, "Years_employed", ["lowest", "low", "medium", "high", "highest"], binsnum=5, retbins=True)
|
63 |
-
training_bins["bin_edges_years_employed"] = bin_edges_years_employed
|
64 |
-
else:
|
65 |
-
input_data = get_category(input_data, "Years_employed", ["lowest", "low", "medium", "high", "highest"], bins=bins["bin_edges_years_employed"])
|
66 |
-
|
67 |
-
input_data = convert_dummy(input_data, "Years_employed")
|
68 |
-
|
69 |
-
input_data.loc[input_data["Num_family"] >= 3, "Num_family"] = "3_or_more"
|
70 |
-
|
71 |
-
input_data = convert_dummy(input_data, "Num_family")
|
72 |
-
|
73 |
-
input_data.loc[input_data["Income_type"] == "Pensioner", "Income_type"] = "State servant"
|
74 |
-
input_data.loc[input_data["Income_type"] == "Student", "Income_type"] = "State servant"
|
75 |
-
|
76 |
-
input_data = convert_dummy(input_data, "Income_type")
|
77 |
-
|
78 |
-
input_data.loc[
|
79 |
-
(input_data["Occupation_type"] == "Cleaning staff")
|
80 |
-
| (input_data["Occupation_type"] == "Cooking staff")
|
81 |
-
| (input_data["Occupation_type"] == "Drivers")
|
82 |
-
| (input_data["Occupation_type"] == "Laborers")
|
83 |
-
| (input_data["Occupation_type"] == "Low-skill Laborers")
|
84 |
-
| (input_data["Occupation_type"] == "Security staff")
|
85 |
-
| (input_data["Occupation_type"] == "Waiters/barmen staff"),
|
86 |
-
"Occupation_type",
|
87 |
-
] = "Labor_work"
|
88 |
-
input_data.loc[
|
89 |
-
(input_data["Occupation_type"] == "Accountants")
|
90 |
-
| (input_data["Occupation_type"] == "Core staff")
|
91 |
-
| (input_data["Occupation_type"] == "HR staff")
|
92 |
-
| (input_data["Occupation_type"] == "Medicine staff")
|
93 |
-
| (input_data["Occupation_type"] == "Private service staff")
|
94 |
-
| (input_data["Occupation_type"] == "Realty agents")
|
95 |
-
| (input_data["Occupation_type"] == "Sales staff")
|
96 |
-
| (input_data["Occupation_type"] == "Secretaries"),
|
97 |
-
"Occupation_type",
|
98 |
-
] = "Office_work"
|
99 |
-
input_data.loc[
|
100 |
-
(input_data["Occupation_type"] == "Managers")
|
101 |
-
| (input_data["Occupation_type"] == "High skill tech staff")
|
102 |
-
| (input_data["Occupation_type"] == "IT staff"),
|
103 |
-
"Occupation_type",
|
104 |
-
] = "High_tech_work"
|
105 |
-
|
106 |
-
input_data = convert_dummy(input_data, "Occupation_type")
|
107 |
-
|
108 |
-
input_data = convert_dummy(input_data, "Housing_type")
|
109 |
-
|
110 |
-
input_data.loc[input_data["Education_type"] == "Academic degree", "Education_type"] = "Higher education"
|
111 |
-
input_data = convert_dummy(input_data, "Education_type")
|
112 |
-
|
113 |
-
input_data = convert_dummy(input_data, "Family_status")
|
114 |
-
|
115 |
-
input_data = input_data.astype("int")
|
116 |
-
|
117 |
-
if training_bins:
|
118 |
-
return input_data, training_bins
|
119 |
-
|
120 |
-
input_data = input_data.reindex(columns=columns, fill_value=0)
|
121 |
-
|
122 |
-
return input_data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
server.py
CHANGED
@@ -6,12 +6,13 @@ from fastapi import FastAPI, File, Form, UploadFile
|
|
6 |
from fastapi.responses import JSONResponse, Response
|
7 |
|
8 |
from settings import DEPLOYMENT_PATH, SERVER_FILES, CLIENT_TYPES
|
9 |
-
from
|
10 |
|
11 |
# Load the server objects related to all currently available filters once and for all
|
12 |
FHE_SERVER = MultiInputsFHEModelServer(DEPLOYMENT_PATH)
|
13 |
|
14 |
-
|
|
|
15 |
"""Get the correct temporary file path for the server.
|
16 |
|
17 |
Args:
|
@@ -42,8 +43,8 @@ def send_input(
|
|
42 |
):
|
43 |
"""Send the inputs to the server."""
|
44 |
# Retrieve the encrypted inputs and the evaluation key paths
|
45 |
-
encrypted_inputs_path =
|
46 |
-
evaluation_key_path =
|
47 |
|
48 |
# Write the files using the above paths
|
49 |
with encrypted_inputs_path.open("wb") as encrypted_inputs, evaluation_key_path.open(
|
@@ -55,23 +56,30 @@ def send_input(
|
|
55 |
|
56 |
@app.post("/run_fhe")
|
57 |
def run_fhe(
|
58 |
-
|
|
|
|
|
59 |
):
|
60 |
"""Execute the model on the encrypted inputs using FHE."""
|
61 |
-
# Retrieve the evaluation key
|
62 |
-
evaluation_key_path =
|
63 |
|
64 |
-
# Get the
|
65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
evaluation_key = evaluation_key_file.read()
|
|
|
|
|
|
|
67 |
|
68 |
-
|
69 |
-
encrypted_inputs = []
|
70 |
-
for client_type in CLIENT_TYPES:
|
71 |
-
encrypted_inputs_path = get_server_file_path("encrypted_inputs", client_id, client_type)
|
72 |
-
with encrypted_inputs_path.open("rb") as encrypted_inputs_file:
|
73 |
-
encrypted_input = encrypted_inputs_file.read()
|
74 |
-
encrypted_inputs.append(encrypted_input)
|
75 |
|
76 |
# Run the FHE execution
|
77 |
start = time.time()
|
@@ -79,7 +87,7 @@ def run_fhe(
|
|
79 |
fhe_execution_time = round(time.time() - start, 2)
|
80 |
|
81 |
# Retrieve the encrypted output path
|
82 |
-
encrypted_output_path =
|
83 |
|
84 |
# Write the file using the above path
|
85 |
with encrypted_output_path.open("wb") as output_file:
|
@@ -90,12 +98,13 @@ def run_fhe(
|
|
90 |
|
91 |
@app.post("/get_output")
|
92 |
def get_output(
|
93 |
-
|
94 |
-
|
|
|
95 |
):
|
96 |
"""Retrieve the encrypted output."""
|
97 |
# Retrieve the encrypted output path
|
98 |
-
encrypted_output_path =
|
99 |
|
100 |
# Read the file using the above path
|
101 |
with encrypted_output_path.open("rb") as encrypted_output_file:
|
|
|
6 |
from fastapi.responses import JSONResponse, Response
|
7 |
|
8 |
from settings import DEPLOYMENT_PATH, SERVER_FILES, CLIENT_TYPES
|
9 |
+
from utils.client_server_interface import MultiInputsFHEModelServer
|
10 |
|
11 |
# Load the server objects related to all currently available filters once and for all
|
12 |
FHE_SERVER = MultiInputsFHEModelServer(DEPLOYMENT_PATH)
|
13 |
|
14 |
+
|
15 |
+
def _get_server_file_path(name, client_id, client_type):
|
16 |
"""Get the correct temporary file path for the server.
|
17 |
|
18 |
Args:
|
|
|
43 |
):
|
44 |
"""Send the inputs to the server."""
|
45 |
# Retrieve the encrypted inputs and the evaluation key paths
|
46 |
+
encrypted_inputs_path = _get_server_file_path("encrypted_inputs", client_id, client_type)
|
47 |
+
evaluation_key_path = _get_server_file_path("evaluation_key", client_id, client_type)
|
48 |
|
49 |
# Write the files using the above paths
|
50 |
with encrypted_inputs_path.open("wb") as encrypted_inputs, evaluation_key_path.open(
|
|
|
56 |
|
57 |
@app.post("/run_fhe")
|
58 |
def run_fhe(
|
59 |
+
user_id: str = Form(),
|
60 |
+
bank_id: str = Form(),
|
61 |
+
third_party_id: str = Form(),
|
62 |
):
|
63 |
"""Execute the model on the encrypted inputs using FHE."""
|
64 |
+
# Retrieve the evaluation key (from the user, as all evaluation keys should be the same)
|
65 |
+
evaluation_key_path = _get_server_file_path("evaluation_key", user_id, "user")
|
66 |
|
67 |
+
# Get the encrypted inputs
|
68 |
+
encrypted_user_inputs_path = _get_server_file_path("encrypted_inputs", user_id, "user")
|
69 |
+
encrypted_bank_inputs_path = _get_server_file_path("encrypted_inputs", bank_id, "bank")
|
70 |
+
encrypted_third_party_inputs_path = _get_server_file_path("encrypted_inputs", third_party_id, "third_party")
|
71 |
+
with (
|
72 |
+
evaluation_key_path.open("rb") as evaluation_key_file,
|
73 |
+
encrypted_user_inputs_path.open("rb") as encrypted_user_inputs_file,
|
74 |
+
encrypted_bank_inputs_path.open("rb") as encrypted_bank_inputs_file,
|
75 |
+
encrypted_third_party_inputs_path.open("rb") as encrypted_third_party_inputs_file,
|
76 |
+
):
|
77 |
evaluation_key = evaluation_key_file.read()
|
78 |
+
encrypted_user_input = encrypted_user_inputs_file.read()
|
79 |
+
encrypted_bank_input = encrypted_bank_inputs_file.read()
|
80 |
+
encrypted_third_party_input = encrypted_third_party_inputs_file.read()
|
81 |
|
82 |
+
encrypted_inputs = (encrypted_user_input, encrypted_bank_input, encrypted_third_party_input)
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
|
84 |
# Run the FHE execution
|
85 |
start = time.time()
|
|
|
87 |
fhe_execution_time = round(time.time() - start, 2)
|
88 |
|
89 |
# Retrieve the encrypted output path
|
90 |
+
encrypted_output_path = _get_server_file_path("encrypted_output", user_id + bank_id + third_party_id, "output")
|
91 |
|
92 |
# Write the file using the above path
|
93 |
with encrypted_output_path.open("wb") as output_file:
|
|
|
98 |
|
99 |
@app.post("/get_output")
|
100 |
def get_output(
|
101 |
+
user_id: str = Form(),
|
102 |
+
bank_id: str = Form(),
|
103 |
+
third_party_id: str = Form(),
|
104 |
):
|
105 |
"""Retrieve the encrypted output."""
|
106 |
# Retrieve the encrypted output path
|
107 |
+
encrypted_output_path = _get_server_file_path("encrypted_output", user_id + bank_id + third_party_id, "output")
|
108 |
|
109 |
# Read the file using the above path
|
110 |
with encrypted_output_path.open("rb") as encrypted_output_file:
|
settings.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
"All constants used in the project."
|
2 |
|
3 |
from pathlib import Path
|
|
|
4 |
|
5 |
# The directory of this project
|
6 |
REPO_DIR = Path(__file__).parent
|
@@ -11,6 +12,10 @@ FHE_KEYS = REPO_DIR / ".fhe_keys"
|
|
11 |
CLIENT_FILES = REPO_DIR / "client_files"
|
12 |
SERVER_FILES = REPO_DIR / "server_files"
|
13 |
|
|
|
|
|
|
|
|
|
14 |
# Create the necessary directories
|
15 |
FHE_KEYS.mkdir(exist_ok=True)
|
16 |
CLIENT_FILES.mkdir(exist_ok=True)
|
@@ -19,8 +24,14 @@ SERVER_FILES.mkdir(exist_ok=True)
|
|
19 |
# Store the server's URL
|
20 |
SERVER_URL = "http://localhost:8000/"
|
21 |
|
22 |
-
|
|
|
|
|
|
|
|
|
23 |
|
|
|
|
|
24 |
INITIAL_INPUT_SHAPE = (1, 49)
|
25 |
|
26 |
CLIENT_TYPES = ["user", "bank", "third_party"]
|
@@ -29,8 +40,33 @@ INPUT_INDEXES = {
|
|
29 |
"bank": 1,
|
30 |
"third_party": 2,
|
31 |
}
|
32 |
-
|
33 |
-
"user": 0, # First position: start from 0
|
34 |
-
"bank":
|
35 |
-
"third_party":
|
36 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
"All constants used in the project."
|
2 |
|
3 |
from pathlib import Path
|
4 |
+
import pandas
|
5 |
|
6 |
# The directory of this project
|
7 |
REPO_DIR = Path(__file__).parent
|
|
|
12 |
CLIENT_FILES = REPO_DIR / "client_files"
|
13 |
SERVER_FILES = REPO_DIR / "server_files"
|
14 |
|
15 |
+
# Path targeting pre-processor saved files
|
16 |
+
PRE_PROCESSOR_USER_PATH = DEPLOYMENT_PATH / 'pre_processor_user.pkl'
|
17 |
+
PRE_PROCESSOR_THIRD_PARTY_PATH = DEPLOYMENT_PATH / 'pre_processor_third_party.pkl'
|
18 |
+
|
19 |
# Create the necessary directories
|
20 |
FHE_KEYS.mkdir(exist_ok=True)
|
21 |
CLIENT_FILES.mkdir(exist_ok=True)
|
|
|
24 |
# Store the server's URL
|
25 |
SERVER_URL = "http://localhost:8000/"
|
26 |
|
27 |
+
# Path to data file
|
28 |
+
# The data was previously cleaned using this notebook : https://www.kaggle.com/code/samuelcortinhas/credit-cards-data-cleaning
|
29 |
+
# Additionally, the "ID" columns has been removed and the "Total_income" has been adjusted so that
|
30 |
+
# its median value corresponds to France's 2023 median annual salary (22050 euros)
|
31 |
+
DATA_PATH = "data/clean_data.csv"
|
32 |
|
33 |
+
# Developement settings
|
34 |
+
RANDOM_STATE = 0
|
35 |
INITIAL_INPUT_SHAPE = (1, 49)
|
36 |
|
37 |
CLIENT_TYPES = ["user", "bank", "third_party"]
|
|
|
40 |
"bank": 1,
|
41 |
"third_party": 2,
|
42 |
}
|
43 |
+
INPUT_SLICES = {
|
44 |
+
"user": slice(0, 42), # First position: start from 0
|
45 |
+
"bank": slice(42, 43), # Second position: start from n_feature_user
|
46 |
+
"third_party": slice(43, 49), # Third position: start from n_feature_user + n_feature_bank
|
47 |
}
|
48 |
+
|
49 |
+
_data = pandas.read_csv(DATA_PATH, encoding="utf-8")
|
50 |
+
|
51 |
+
def get_min_max(data, column):
    """Return a column's min and max (cast to int), keyed as Gradio component kwargs."""
    series = data[column]
    return {
        "minimum": int(series.min()),
        "maximum": int(series.max()),
    }
|
57 |
+
|
58 |
+
# App data min and max values
|
59 |
+
ACCOUNT_MIN_MAX = get_min_max(_data, "Account_length")
|
60 |
+
CHILDREN_MIN_MAX = get_min_max(_data, "Num_children")
|
61 |
+
INCOME_MIN_MAX = get_min_max(_data, "Total_income")
|
62 |
+
AGE_MIN_MAX = get_min_max(_data, "Age")
|
63 |
+
EMPLOYED_MIN_MAX = get_min_max(_data, "Years_employed")
|
64 |
+
FAMILY_MIN_MAX = get_min_max(_data, "Num_family")
|
65 |
+
|
66 |
+
# App data choices
|
67 |
+
INCOME_TYPES = list(_data["Income_type"].unique())
|
68 |
+
OCCUPATION_TYPES = list(_data["Occupation_type"].unique())
|
69 |
+
HOUSING_TYPES = list(_data["Housing_type"].unique())
|
70 |
+
EDUCATION_TYPES = list(_data["Education_type"].unique())
|
71 |
+
FAMILY_STATUS = list(_data["Family_status"].unique())
|
72 |
+
|
{development β utils}/client_server_interface.py
RENAMED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
import numpy
|
2 |
import copy
|
3 |
|
@@ -25,22 +27,21 @@ class MultiInputsFHEModelClient(FHEModelClient):
|
|
25 |
|
26 |
super().__init__(*args, **kwargs)
|
27 |
|
28 |
-
def quantize_encrypt_serialize_multi_inputs(self, x: numpy.ndarray, input_index, initial_input_shape,
|
29 |
|
30 |
x_padded = numpy.zeros(initial_input_shape)
|
31 |
|
32 |
-
|
33 |
-
x_padded[:, start_position:end] = x
|
34 |
|
35 |
q_x_padded = self.model.quantize_input(x_padded)
|
36 |
|
37 |
-
q_x = q_x_padded[:,
|
38 |
|
39 |
-
|
40 |
-
|
41 |
|
42 |
# Encrypt the values
|
43 |
-
q_x_enc = self.client.encrypt(*
|
44 |
|
45 |
# Serialize the encrypted values to be sent to the server
|
46 |
q_x_enc_ser = q_x_enc[input_index].serialize()
|
|
|
1 |
+
"""Modified classes for use for Client-Server interface with multi-inputs circuits."""
|
2 |
+
|
3 |
import numpy
|
4 |
import copy
|
5 |
|
|
|
27 |
|
28 |
super().__init__(*args, **kwargs)
|
29 |
|
30 |
+
def quantize_encrypt_serialize_multi_inputs(self, x: numpy.ndarray, input_index, initial_input_shape, input_slice) -> bytes:
|
31 |
|
32 |
x_padded = numpy.zeros(initial_input_shape)
|
33 |
|
34 |
+
x_padded[:, input_slice] = x
|
|
|
35 |
|
36 |
q_x_padded = self.model.quantize_input(x_padded)
|
37 |
|
38 |
+
q_x = q_x_padded[:, input_slice]
|
39 |
|
40 |
+
q_x_inputs = [None for _ in range(self.nb_inputs)]
|
41 |
+
q_x_inputs[input_index] = q_x
|
42 |
|
43 |
# Encrypt the values
|
44 |
+
q_x_enc = self.client.encrypt(*q_x_inputs)
|
45 |
|
46 |
# Serialize the encrypted values to be sent to the server
|
47 |
q_x_enc_ser = q_x_enc[input_index].serialize()
|
{development β utils}/model.py
RENAMED
@@ -1,4 +1,7 @@
|
|
|
|
|
|
1 |
import numpy
|
|
|
2 |
from typing import Optional, Sequence, Union
|
3 |
|
4 |
from concrete.fhe.compilation.compiler import Compiler, Configuration, DebugArtifacts, Circuit
|
@@ -128,3 +131,43 @@ class MultiInputXGBClassifier(ConcreteXGBClassifier):
|
|
128 |
)
|
129 |
|
130 |
return compiler
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Modified model class to handles multi-inputs circuit."""
|
2 |
+
|
3 |
import numpy
|
4 |
+
import time
|
5 |
from typing import Optional, Sequence, Union
|
6 |
|
7 |
from concrete.fhe.compilation.compiler import Compiler, Configuration, DebugArtifacts, Circuit
|
|
|
131 |
)
|
132 |
|
133 |
return compiler
|
134 |
+
|
135 |
+
def predict_multi_inputs(self, *multi_inputs, simulate=True):
|
136 |
+
"""Run the inference with multiple inputs, with simulation or in FHE."""
|
137 |
+
assert all(isinstance(inputs, numpy.ndarray) for inputs in multi_inputs)
|
138 |
+
|
139 |
+
if not simulate:
|
140 |
+
self.fhe_circuit.keygen()
|
141 |
+
|
142 |
+
y_preds = []
|
143 |
+
execution_times = []
|
144 |
+
for inputs in zip(*multi_inputs):
|
145 |
+
inputs = tuple(numpy.expand_dims(input, axis=0) for input in inputs)
|
146 |
+
|
147 |
+
q_inputs = self.quantize_input(*inputs)
|
148 |
+
|
149 |
+
if simulate:
|
150 |
+
q_y_proba = self.fhe_circuit.simulate(*q_inputs)
|
151 |
+
else:
|
152 |
+
q_inputs_enc = self.fhe_circuit.encrypt(*q_inputs)
|
153 |
+
|
154 |
+
start = time.time()
|
155 |
+
q_y_proba_enc = self.fhe_circuit.run(*q_inputs_enc)
|
156 |
+
end = time.time() - start
|
157 |
+
|
158 |
+
execution_times.append(end)
|
159 |
+
|
160 |
+
q_y_proba = self.fhe_circuit.decrypt(q_y_proba_enc)
|
161 |
+
|
162 |
+
y_proba = self.dequantize_output(q_y_proba)
|
163 |
+
|
164 |
+
y_proba = self.post_processing(y_proba)
|
165 |
+
|
166 |
+
y_pred = numpy.argmax(y_proba, axis=1)
|
167 |
+
|
168 |
+
y_preds.append(y_pred)
|
169 |
+
|
170 |
+
if not simulate:
|
171 |
+
print(f"FHE execution time per inference: {numpy.mean(execution_times) :.2}s")
|
172 |
+
|
173 |
+
return numpy.array(y_preds)
|
utils/pre_processing.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Data pre-processing functions."""
|
2 |
+
|
3 |
+
import numpy
|
4 |
+
from sklearn.compose import ColumnTransformer
|
5 |
+
from sklearn.pipeline import Pipeline
|
6 |
+
from sklearn.preprocessing import OneHotEncoder, FunctionTransformer, KBinsDiscretizer
|
7 |
+
|
8 |
+
|
9 |
+
def _get_pipeline_replace_one_hot(func, value):
    """Chain a value-replacement step (calling *func* with keyword arg *value*) with a one-hot encoder."""
    replace_step = FunctionTransformer(
        func,
        kw_args={"value": value},
        feature_names_out='one-to-one',
    )
    return Pipeline([
        ("replace", replace_step),
        ("one_hot", OneHotEncoder()),
    ])
|
18 |
+
|
19 |
+
|
20 |
+
def _replace_values_geq(column, value):
|
21 |
+
return numpy.where(column >= value, f"{value}_or_more", column)
|
22 |
+
|
23 |
+
def _replace_values_eq(column, value):
|
24 |
+
for desired_value, values_to_replace in value.items():
|
25 |
+
column = numpy.where(numpy.isin(column, values_to_replace), desired_value, column)
|
26 |
+
return column
|
27 |
+
|
28 |
+
def get_pre_processors():
    """Build the pre-processing pipelines for user and third-party features.

    Returns:
        tuple: (pre_processor_user, pre_processor_third_party), two unfitted
            scikit-learn ColumnTransformer instances. Columns not listed in a
            transformer are passed through unchanged.
    """
    pre_processor_user = ColumnTransformer(
        transformers=[
            # Group 2+ children under a single "2_or_more" category, then one-hot encode
            (
                "replace_num_children",
                _get_pipeline_replace_one_hot(_replace_values_geq, 2),
                ['Num_children']
            ),
            # Group families of 3+ members under a single "3_or_more" category
            (
                "replace_num_family",
                _get_pipeline_replace_one_hot(_replace_values_geq, 3),
                ['Num_family']
            ),
            # Merge the rare "Pensioner"/"Student" income types into "State servant"
            (
                "replace_income_type",
                _get_pipeline_replace_one_hot(_replace_values_eq, {"State servant": ["Pensioner", "Student"]}),
                ['Income_type']
            ),
            # Merge "Academic degree" into "Higher education"
            (
                "replace_education_type",
                _get_pipeline_replace_one_hot(_replace_values_eq, {"Higher education": ["Academic degree"]}),
                ['Education_type']
            ),
            # Bucket detailed occupations into three broad work categories
            (
                "replace_occupation_type_labor",
                _get_pipeline_replace_one_hot(
                    _replace_values_eq,
                    {
                        "Labor_work": ["Cleaning staff", "Cooking staff", "Drivers", "Laborers", "Low-skill Laborers", "Security staff", "Waiters/barmen staff"],
                        "Office_work": ["Accountants", "Core staff", "HR staff", "Medicine staff", "Private service staff", "Realty agents", "Sales staff", "Secretaries"],
                        "High_tech_work": ["Managers", "High skill tech staff", "IT staff"],
                    },
                ),
                ['Occupation_type']
            ),
            ('one_hot_housing_fam_status', OneHotEncoder(), ['Housing_type', 'Family_status']),
            # Discretize income into 3 quantile bins and age into 5 uniform bins
            ('qbin_total_income', KBinsDiscretizer(n_bins=3, strategy='quantile', encode="onehot"), ['Total_income']),
            ('bin_age', KBinsDiscretizer(n_bins=5, strategy='uniform', encode="onehot"), ['Age']),
        ],
        remainder='passthrough',
        verbose_feature_names_out=False,
    )

    pre_processor_third_party = ColumnTransformer(
        transformers=[
            # Discretize employment duration into 5 uniform bins
            ('bin_years_employed', KBinsDiscretizer(n_bins=5, strategy='uniform', encode="onehot"), ['Years_employed'])
        ],
        remainder='passthrough',
        verbose_feature_names_out=False,
    )

    return pre_processor_user, pre_processor_third_party
|
80 |
+
|
81 |
+
|
82 |
+
def select_and_pop_features(data, columns):
    """Return a copy of *columns* from *data*, removing those columns from *data* in place."""
    extracted = data[columns].copy()
    data.drop(columns=columns, inplace=True)
    return extracted
|