Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 6,225 Bytes
9a997e4 c119738 9a997e4 c119738 9a997e4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 |
"""Modified model class to handle multi-input circuits."""
import numpy
import time
from typing import Optional, Sequence, Union
from concrete.fhe.compilation.compiler import Compiler, Configuration, DebugArtifacts, Circuit
from concrete.ml.common.check_inputs import check_array_and_assert
from concrete.ml.common.utils import (
generate_proxy_function,
manage_parameters_for_pbs_errors,
check_there_is_no_p_error_options_in_configuration
)
from concrete.ml.quantization.quantized_module import QuantizedModule, _get_inputset_generator
from concrete.ml.sklearn import XGBClassifier as ConcreteXGBClassifier
class MultiInputXGBClassifier(ConcreteXGBClassifier):
    """XGBoost classifier whose FHE circuit takes its features as several input arrays.

    The model's feature columns are split across consecutive input arrays. Each array is
    quantized with its own slice of input quantizers and passed as a separate circuit
    argument; inside the circuit the arrays are concatenated back along axis 1 before the
    tree inference runs.
    """

    def quantize_input(self, *X: numpy.ndarray) -> numpy.ndarray:
        """Quantize each input array with its matching slice of input quantizers.

        The arrays are interpreted as consecutive blocks of the model's feature columns:
        the first ``X[0].shape[1]`` quantizers apply to ``X[0]``, the next
        ``X[1].shape[1]`` to ``X[1]``, and so on.

        Args:
            *X: 2D arrays, one per circuit input, indexed as (sample, feature).

        Returns:
            A tuple of int64 arrays (one per input) when several inputs are given, or a
            single int64 array when only one input is given.
        """
        self.check_model_is_fitted()

        n_features_total = sum(x.shape[1] for x in X)
        assert n_features_total == len(self.input_quantizers), (
            f"Expected {len(self.input_quantizers)} feature columns in total across all "
            f"inputs, got {n_features_total}."
        )

        q_inputs = []
        offset = 0  # index of the first quantizer belonging to the current input
        for i, x in enumerate(X):
            n_features = x.shape[1]
            q_input = numpy.zeros_like(x, dtype=numpy.int64)
            for j in range(n_features):
                q_input[:, j] = self.input_quantizers[offset + j].quant(x[:, j])
            assert q_input.dtype == numpy.int64, f"Inputs {i} were not quantized to int64 values"
            q_inputs.append(q_input)
            offset += n_features

        return tuple(q_inputs) if len(q_inputs) > 1 else q_inputs[0]

    def compile(
        self,
        *inputs,
        configuration: Optional[Configuration] = None,
        artifacts: Optional[DebugArtifacts] = None,
        show_mlir: bool = False,
        p_error: Optional[float] = None,
        global_p_error: Optional[float] = None,
        verbose: bool = False,
        inputs_encryption_status: Optional[Sequence[str]] = None,
    ) -> Circuit:
        """Compile the model into an FHE circuit taking one argument per input array.

        Args:
            *inputs: representative input-sets, one per circuit input. They are
                quantized and used to build the compilation input-set.
            configuration, artifacts, show_mlir, verbose: forwarded to
                `Compiler.compile`.
            p_error, global_p_error: normalized by `manage_parameters_for_pbs_errors`,
                then forwarded to `Compiler.compile`.
            inputs_encryption_status: one status per input (e.g. "encrypted" or
                "clear"). Defaults to encrypting every input.

        Returns:
            The compiled FHE circuit (also stored as `self.fhe_circuit_`).

        Raises:
            ValueError: if the number of encryption statuses differs from the number of
                inputs.
        """
        # Check that the model is correctly fitted
        self.check_model_is_fitted()

        # Cast pandas, list or torch to numpy
        inputs_as_array = tuple(check_array_and_assert(x) for x in inputs)

        # The circuit takes one argument per input array, so statuses must match them
        # one-to-one (a mismatch would otherwise be silently truncated by `zip` later).
        if inputs_encryption_status is None:
            inputs_encryption_status = ["encrypted"] * len(inputs_as_array)
        elif len(inputs_encryption_status) != len(inputs_as_array):
            raise ValueError(
                f"Got {len(inputs_encryption_status)} encryption statuses for "
                f"{len(inputs_as_array)} inputs; expected exactly one status per input."
            )

        # p_error or global_p_error should not be set in both the configuration and direct arguments
        check_there_is_no_p_error_options_in_configuration(configuration)

        # Find the right way to set parameters for compiler, depending on the way we want to default
        p_error, global_p_error = manage_parameters_for_pbs_errors(p_error, global_p_error)

        # Quantize the inputs
        quantized_inputs = self.quantize_input(*inputs_as_array)

        # Generate the compilation input-set with proper dimensions
        inputset = _get_inputset_generator(quantized_inputs)

        # Reset for double compile
        self._is_compiled = False

        # Retrieve the compiler instance
        module_to_compile = self._get_module_to_compile(inputs_encryption_status)

        # Compiling using a QuantizedModule requires different steps and should not be done here
        assert isinstance(module_to_compile, Compiler), (
            "Wrong module to compile. Expected to be of type `Compiler` but got "
            f"{type(module_to_compile)}."
        )

        # Jit compiler is now deprecated and will soon be removed, it is thus forced to False
        # by default
        self.fhe_circuit_ = module_to_compile.compile(
            inputset,
            configuration=configuration,
            artifacts=artifacts,
            show_mlir=show_mlir,
            p_error=p_error,
            global_p_error=global_p_error,
            verbose=verbose,
            single_precision=False,
            fhe_simulation=False,
            fhe_execution=True,
            jit=False,
        )

        self._is_compiled = True

        # For mypy
        assert isinstance(self.fhe_circuit, Circuit)
        return self.fhe_circuit

    def _get_module_to_compile(self, inputs_encryption_status) -> Union[Compiler, QuantizedModule]:
        """Build a Compiler for the multi-input tree inference.

        Args:
            inputs_encryption_status: one encryption status per circuit input, in input
                order.

        Returns:
            A `Compiler` whose function takes one argument per input array.
        """
        assert self._tree_inference is not None, self._is_not_fitted_error_message()

        # Wrap the single-input tree inference so it accepts several arrays and
        # concatenates them along the feature axis. The wrapper is tagged so it is never
        # wrapped again: `compile` resets `_is_compiled` before calling this method, so
        # that flag cannot serve as the guard.
        if not getattr(self._tree_inference, "_concatenates_multi_inputs", False):
            xgb_inference = self._tree_inference

            def multi_input_tree_inference(*args):
                return xgb_inference(numpy.concatenate(args, axis=1))

            multi_input_tree_inference._concatenates_multi_inputs = True
            self._tree_inference = multi_input_tree_inference

        input_names = [f"input_{i}_encrypted" for i in range(len(inputs_encryption_status))]

        # Generate the proxy function to compile
        _tree_inference_proxy, function_arg_names = generate_proxy_function(
            self._tree_inference, input_names
        )

        # Map each proxy argument name to its status, preserving input order
        inputs_encryption_statuses = dict(
            zip(function_arg_names.values(), inputs_encryption_status)
        )

        # Create the compiler instance
        return Compiler(_tree_inference_proxy, inputs_encryption_statuses)

    def predict_multi_inputs(self, *multi_inputs, simulate=True):
        """Run the inference with multiple inputs, with simulation or in FHE.

        Samples are processed one at a time: row k of every input array together forms
        the k-th query sent to the circuit.

        Args:
            *multi_inputs: numpy arrays, one per circuit input. Rows are paired with
                `zip`, so extra rows in longer arrays are ignored.
            simulate: if True, use `fhe_circuit.simulate`. Otherwise generate keys once,
                then encrypt, run and decrypt each sample, and print the mean run time.

        Returns:
            A numpy array with one entry per sample, each holding the argmax over
            axis 1 of that sample's post-processed output.
        """
        assert all(isinstance(inputs, numpy.ndarray) for inputs in multi_inputs)

        if not simulate:
            self.fhe_circuit.keygen()

        y_preds = []
        execution_times = []
        for sample in zip(*multi_inputs):
            # Restore the batch axis: each row becomes a single-sample 2D array
            sample_inputs = tuple(numpy.expand_dims(row, axis=0) for row in sample)
            q_inputs = self.quantize_input(*sample_inputs)

            # quantize_input returns a bare array for a single input; wrap it so the
            # unpacking below passes one argument per circuit input, not one per row.
            if not isinstance(q_inputs, tuple):
                q_inputs = (q_inputs,)

            if simulate:
                q_y_proba = self.fhe_circuit.simulate(*q_inputs)
            else:
                q_inputs_enc = self.fhe_circuit.encrypt(*q_inputs)
                start = time.perf_counter()
                q_y_proba_enc = self.fhe_circuit.run(*q_inputs_enc)
                execution_times.append(time.perf_counter() - start)
                q_y_proba = self.fhe_circuit.decrypt(q_y_proba_enc)

            y_proba = self.post_processing(self.dequantize_output(q_y_proba))
            y_preds.append(numpy.argmax(y_proba, axis=1))

        # Skip the report when nothing ran (numpy.mean of an empty list is nan)
        if not simulate and execution_times:
            print(f"FHE execution time per inference: {numpy.mean(execution_times) :.2}s")

        return numpy.array(y_preds)
|