"""Modified classes for use for Client-Server interface with multi-inputs circuits."""

import copy

import numpy

from concrete.fhe import Value, EvaluationKeys

from concrete.ml.deployment.fhe_client_server import FHEModelClient, FHEModelDev, FHEModelServer
from concrete.ml.sklearn import XGBClassifier as ConcreteXGBClassifier


class MultiInputsFHEModelDev(FHEModelDev):
    """FHEModelDev variant that saves the model as a Concrete ML XGBClassifier."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Shallow-copy the model and re-type it as a Concrete ML XGBClassifier so
        # that the deployment files are generated with the expected class
        model = copy.copy(self.model)
        model.__class__ = ConcreteXGBClassifier
        self.model = model


class MultiInputsFHEModelClient(FHEModelClient):
    """FHEModelClient variant that quantizes, encrypts and serializes a single input
    of a multi-input circuit."""

    def __init__(self, *args, nb_inputs=1, **kwargs):
        self.nb_inputs = nb_inputs

        super().__init__(*args, **kwargs)

    def quantize_encrypt_serialize_multi_inputs(
        self,
        x: numpy.ndarray,
        input_index: int,
        initial_input_shape: tuple,
        input_slice: slice,
    ) -> bytes:
        """Quantize, encrypt and serialize one input of a multi-input circuit.

        Args:
            x (numpy.ndarray): the clear values held by this party
            input_index (int): the position of this input in the circuit
            initial_input_shape (tuple): the shape of the full (concatenated) input
            input_slice (slice): the columns of the full input that belong to this party

        Returns:
            bytes: the serialized encrypted quantized input
        """
        # Pad the input to the full input shape so that the model's quantizer,
        # which expects the complete feature set, can be applied
        x_padded = numpy.zeros(initial_input_shape)
        x_padded[:, input_slice] = x

        q_x_padded = self.model.quantize_input(x_padded)

        # Keep only the quantized columns that belong to this input
        q_x = q_x_padded[:, input_slice]

        q_x_inputs = [None for _ in range(self.nb_inputs)]
        q_x_inputs[input_index] = q_x

        # Encrypt the values
        q_x_enc = self.client.encrypt(*q_x_inputs)

        # Serialize the encrypted values to be sent to the server
        q_x_enc_ser = q_x_enc[input_index].serialize()
        return q_x_enc_ser
    

class MultiInputsFHEModelServer(FHEModelServer):
    """FHEModelServer variant that runs the circuit over several encrypted inputs."""

    def run(
        self,
        *serialized_encrypted_quantized_data: bytes,
        serialized_evaluation_keys: bytes,
    ) -> bytes:
        """Run the model on the server over encrypted data.

        Args:
            serialized_encrypted_quantized_data (bytes): the encrypted, quantized
                and serialized inputs, one per circuit input
            serialized_evaluation_keys (bytes): the serialized evaluation keys

        Returns:
            bytes: the serialized result of the model
        """
        assert self.server is not None, "Model has not been loaded."

        deserialized_encrypted_quantized_data = tuple(
            Value.deserialize(data) for data in serialized_encrypted_quantized_data
        )

        deserialized_evaluation_keys = EvaluationKeys.deserialize(serialized_evaluation_keys)

        result = self.server.run(
            *deserialized_encrypted_quantized_data, evaluation_keys=deserialized_evaluation_keys
        )
        serialized_result = result.serialize()
        return serialized_result