kcelia committed on
Commit
174cd37
1 Parent(s): bbc133a

chore: update the space with layout

README.md CHANGED
@@ -15,7 +15,7 @@ tags:
  - data anonymization
  - homomorphic encryption
  - security
- python_version: 3.10
+ python_version: 3.10.12
  ---
 
  # Data Anonymization using FHE
anonymize_file_clear.py CHANGED
@@ -1,25 +1,28 @@
1
  import argparse
2
- import json
3
  import re
4
  import uuid
5
- from pathlib import Path
6
- import gensim
7
  from concrete.ml.common.serialization.loaders import load
8
- from transformers import AutoTokenizer, AutoModel
9
- from utils_demo import get_batch_text_representation
10
 
11
  def load_models():
12
- base_dir = Path(__file__).parent / "models"
13
 
14
- # Load tokenizer and model
15
- tokenizer = AutoTokenizer.from_pretrained("obi/deid_roberta_i2b2")
16
- embeddings_model = AutoModel.from_pretrained("obi/deid_roberta_i2b2")
17
 
18
- with open(base_dir / "cml_logreg.model", "r") as model_file:
19
- fhe_ner_detection = load(file=model_file)
20
- return embeddings_model, tokenizer, fhe_ner_detection
21
 
22
- def anonymize_text(text, embeddings_model, tokenizer, fhe_ner_detection):
 
23
  token_pattern = r"(\b[\w\.\/\-@]+\b|[\s,.!?;:'\"-]+)"
24
  tokens = re.findall(token_pattern, text)
25
  uuid_map = {}
@@ -28,9 +31,9 @@ def anonymize_text(text, embeddings_model, tokenizer, fhe_ner_detection):
28
  for token in tokens:
29
  if token.strip() and re.match(r"\w+", token): # If the token is a word
30
  x = get_batch_text_representation([token], embeddings_model, tokenizer)
31
- prediction_proba = fhe_ner_detection.predict_proba(x)
32
  probability = prediction_proba[0][1]
33
- prediction = probability >= 0.5
34
  if prediction:
35
  if token not in uuid_map:
36
  uuid_map[token] = str(uuid.uuid4())[:8]
@@ -40,41 +43,69 @@ def anonymize_text(text, embeddings_model, tokenizer, fhe_ner_detection):
40
  else:
41
  processed_tokens.append(token) # Preserve punctuation and spaces as is
42
 
43
- anonymized_text = ''.join(processed_tokens)
44
  return anonymized_text, uuid_map
45
 
46
- def main():
47
- parser = argparse.ArgumentParser(description="Anonymize named entities in a text file and save the mapping to a JSON file.")
48
- parser.add_argument("file_path", type=str, help="The path to the file to be processed.")
49
- args = parser.parse_args()
50
 
51
- embeddings_model, tokenizer, fhe_ner_detection = load_models()
52
 
53
- # Read the input file
54
- with open(args.file_path, 'r', encoding='utf-8') as file:
55
- text = file.read()
 
 
 
 
56
 
57
  # Save the original text to its specified file
58
- original_file_path = Path(__file__).parent / "files" / "original_document.txt"
59
- with open(original_file_path, 'w', encoding='utf-8') as original_file:
60
- original_file.write(text)
61
-
62
  # Anonymize the text
63
- anonymized_text, uuid_map = anonymize_text(text, embeddings_model, tokenizer, fhe_ner_detection)
64
 
65
  # Save the anonymized text to its specified file
66
- anonymized_file_path = Path(__file__).parent / "files" / "anonymized_document.txt"
67
- with open(anonymized_file_path, 'w', encoding='utf-8') as anonymized_file:
68
- anonymized_file.write(anonymized_text)
 
 
 
 
69
 
70
  # Save the UUID mapping to a JSON file
71
- mapping_path = Path(args.file_path).stem + "_uuid_mapping.json"
72
- with open(mapping_path, 'w', encoding='utf-8') as file:
73
- json.dump(uuid_map, file, indent=4, sort_keys=True)
74
 
75
- print(f"Original text saved to {original_file_path}")
76
- print(f"Anonymized text saved to {anonymized_file_path}")
77
- print(f"UUID mapping saved to {mapping_path}")
78
 
79
  if __name__ == "__main__":
80
- main()
1
  import argparse
 
2
  import re
3
  import uuid
4
+
5
+ from transformers import AutoModel, AutoTokenizer
6
  from concrete.ml.common.serialization.loaders import load
7
+ from utils_demo import *
 
8
 
9
  def load_models():
 
10
 
11
+ # Load the tokenizer and the embedding model
12
+ try:
13
+ tokenizer = AutoTokenizer.from_pretrained("obi/deid_roberta_i2b2")
14
+ embeddings_model = AutoModel.from_pretrained("obi/deid_roberta_i2b2")
15
+ except Exception as e:
16
+ print(f"Error while loading RoBERTa: {e}")
17
+
18
+ # Load the CML trained model
19
+ with open(LOGREG_MODEL_PATH, "r") as model_file:
20
+ cml_ner_model = load(file=model_file)
21
 
22
+ return embeddings_model, tokenizer, cml_ner_model
 
 
23
 
24
+
25
+ def anonymize_with_cml(text, embeddings_model, tokenizer, cml_ner_model):
26
  token_pattern = r"(\b[\w\.\/\-@]+\b|[\s,.!?;:'\"-]+)"
27
  tokens = re.findall(token_pattern, text)
28
  uuid_map = {}
 
31
  for token in tokens:
32
  if token.strip() and re.match(r"\w+", token): # If the token is a word
33
  x = get_batch_text_representation([token], embeddings_model, tokenizer)
34
+ prediction_proba = cml_ner_model.predict_proba(x, fhe="disable")
35
  probability = prediction_proba[0][1]
36
+ prediction = probability >= 0.77
37
  if prediction:
38
  if token not in uuid_map:
39
  uuid_map[token] = str(uuid.uuid4())[:8]
 
43
  else:
44
  processed_tokens.append(token) # Preserve punctuation and spaces as is
45
 
46
+ anonymized_text = "".join(processed_tokens)
47
  return anonymized_text, uuid_map
48
 
 
 
 
 
49
 
50
+ def anonymize_text(text, verbose=False, save=False):
51
 
52
+ # Load models
53
+ if verbose:
54
+ print("Loading models..")
55
+ embeddings_model, tokenizer, cml_ner_model = load_models()
56
+
57
+ if verbose:
58
+ print(f"\nText to process:--------------------\n{text}\n--------------------\n")
59
 
60
  # Save the original text to its specified file
61
+ if save:
62
+ write_txt(ORIGINAL_FILE_PATH, text)
63
+
 
64
  # Anonymize the text
65
+ anonymized_text, uuid_map = anonymize_with_cml(text, embeddings_model, tokenizer, cml_ner_model)
66
 
67
  # Save the anonymized text to its specified file
68
+ if save:
69
+ mapping = {o: (i, a) for i, (o, a) in enumerate(zip(text.split("\n\n"), anonymized_text.split("\n\n")))}
70
+ write_txt(ANONYMIZED_FILE_PATH, anonymized_text)
71
+ write_pickle(MAPPING_SENTENCES_PATH, mapping)
72
+
73
+ if verbose:
74
+ print(f"\nAnonymized text:--------------------\n{anonymized_text}\n--------------------\n")
75
 
76
  # Save the UUID mapping to a JSON file
77
+ if save:
78
+ write_json(MAPPING_UUID_PATH, uuid_map)
79
+
80
+ if verbose and save:
81
+ print(f"Original text saved to :{ORIGINAL_FILE_PATH}")
82
+ print(f"Anonymized text saved to :{ANONYMIZED_FILE_PATH}")
83
+ print(f"UUID mapping saved to :{MAPPING_UUID_PATH}")
84
+ print(f"Sentence mapping saved to :{MAPPING_SENTENCES_PATH}")
85
+
86
+ return anonymized_text
87
 
 
 
 
88
 
89
  if __name__ == "__main__":
90
+ parser = argparse.ArgumentParser(
91
+ description="Anonymize named entities in a text file and save the mapping to a JSON file."
92
+ )
93
+ parser.add_argument(
94
+ "--file_path",
95
+ type=str,
96
+ default="files/original_document.txt",
97
+ help="The path to the file to be processed.",
98
+ )
99
+ parser.add_argument(
100
+ "--verbose",
101
+ type=bool,
102
+ default=True,
103
+ help="This provides additional details about the program's execution.",
104
+ )
105
+ parser.add_argument("--save", type=bool, default=True, help="Save the files.")
106
+
107
+ args = parser.parse_args()
108
+
109
+ text = read_txt(args.file_path)
110
+
111
+ anonymize_text(text, verbose=args.verbose, save=args.save)
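
Besides the CLI shown above, the refactored anonymize_text function can be driven directly from Python. The snippet below is an illustrative sketch only (it is not part of this commit) and assumes utils_demo exposes read_txt and ORIGINAL_FILE_PATH as in the utils_demo.py changes further down. Note also that --verbose and --save are declared with type=bool, so any non-empty string (even "False") parses as True; calling the function directly sidesteps that quirk.

# Illustrative sketch (not part of the commit): drive the refactored anonymizer from Python.
from anonymize_file_clear import anonymize_text
from utils_demo import read_txt, ORIGINAL_FILE_PATH

text = read_txt(ORIGINAL_FILE_PATH)                           # load the clear-text document
anonymized = anonymize_text(text, verbose=True, save=False)   # save=False: nothing is written to disk
print(anonymized)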
app.py CHANGED
@@ -1,35 +1,102 @@
1
  """A Gradio app for anonymizing text data using FHE."""
2
 
 
 
 
 
3
  import gradio as gr
4
- from fhe_anonymizer import FHEAnonymizer
5
  import pandas as pd
 
6
  from openai import OpenAI
7
- import os
8
- import json
9
- import re
10
  from utils_demo import *
11
- from typing import List, Dict, Tuple
 
 
 
 
 
12
 
13
  anonymizer = FHEAnonymizer()
14
 
15
- client = OpenAI(
16
- api_key=os.environ.get("openaikey"),
17
- )
18
 
19
 
20
- def check_user_query_fn(user_query: str) -> Dict:
21
- if is_user_query_valid(user_query):
22
- # TODO: check if the query is related to our context
23
- error_msg = ("Unable to process ❌: The request exceeds the length limit or falls "
24
- "outside the scope of this document. Please refine your query.")
25
- print(error_msg)
26
- return {input_text: gr.update(value=error_msg)}
27
  else:
28
- # Collapsing Multiple Spaces
29
- return {input_text: gr.update(value=re.sub(" +", " ", user_query))}
30
-
31
- def deidentify_text(input_text):
32
- anonymized_text, identified_words_with_prob = anonymizer(input_text)
33
 
34
  # Convert the list of identified words and probabilities into a DataFrame
35
  if identified_words_with_prob:
@@ -41,18 +108,35 @@ def deidentify_text(input_text):
41
  return anonymized_text, identified_df
42
 
43
 
44
- def query_chatgpt(anonymized_query):
 
 
 
 
 
45
 
46
- with open("files/anonymized_document.txt", "r") as file:
47
- anonymized_document = file.read()
48
- with open("files/chatgpt_prompt.txt", "r") as file:
49
- prompt = file.read()
50
 
51
  # Prepare prompt
52
- full_prompt = (
53
- prompt + "\n"
54
  )
55
- query = "Document content:\n```\n" + anonymized_document + "\n\n```" + "Query:\n```\n" + anonymized_query + "\n```"
56
  print(full_prompt)
57
 
58
  completion = client.chat.completions.create(
@@ -63,16 +147,16 @@ def query_chatgpt(anonymized_query):
63
  ],
64
  )
65
  anonymized_response = completion.choices[0].message.content
66
- with open("original_document_uuid_mapping.json", "r") as file:
67
- uuid_map = json.load(file)
68
- inverse_uuid_map = {v: k for k, v in uuid_map.items()} # TODO load the inverse mapping from disk for efficiency
 
 
69
 
70
  # Pattern to identify words and non-words (including punctuation, spaces, etc.)
71
- token_pattern = r"(\b[\w\.\/\-@]+\b|[\s,.!?;:'\"-]+)"
72
- tokens = re.findall(token_pattern, anonymized_response)
73
  processed_tokens = []
74
 
75
-
76
  for token in tokens:
77
  # Directly append non-word tokens or whitespace to processed_tokens
78
  if not token.strip() or not re.match(r"\w+", token):
@@ -87,12 +171,6 @@ def query_chatgpt(anonymized_query):
87
  return anonymized_response, deanonymized_response
88
 
89
 
90
- with open("files/original_document.txt", "r") as file:
91
- original_document = file.read()
92
-
93
- with open("files/anonymized_document.txt", "r") as file:
94
- anonymized_document = file.read()
95
-
96
  demo = gr.Blocks(css=".markdown-body { font-size: 18px; }")
97
 
98
  with demo:
@@ -108,80 +186,204 @@ with demo:
108
 
109
  <a href="https://docs.zama.ai/concrete-ml"> <img style="vertical-align: middle; display:inline-block; margin-right: 3px;" width=15 src="file/images/logos/documentation.png">Documentation</a>
110
 
111
- <a href="https://zama.ai/community"> <img style="vertical-align: middle; display:inline-block; margin-right: 3px;" width=15 src="file/images/logos/community.png">Community</a>
112
 
113
  <a href="https://twitter.com/zama_fhe"> <img style="vertical-align: middle; display:inline-block; margin-right: 3px;" width=15 src="file/images/logos/x.png">@zama_fhe</a>
114
  </p>
115
  """
116
  )
117
 
118
- gr.Markdown(
119
  """
120
- <p align="center">
121
- <img width="30%" height="25%" src="./encrypted_anonymization_diagram.jpg">
122
- </p>
123
  """
124
  )
125
 
126
- with gr.Accordion("What is Encrypted Anonymization?", open=False):
127
- gr.Markdown(
128
- """
129
- Encrypted Anonymization leverages Fully Homomorphic Encryption (FHE) to
130
- protect sensitive information during data processing. This approach allows for the
131
- anonymization of text data, such as personal identifiers, while ensuring that the data
132
- remains encrypted throughout the entire process.
133
- """
134
- )
135
 
136
  ########################## Main document Part ##########################
137
 
 
 
138
  with gr.Row():
139
  with gr.Column():
140
- original_doc_box = gr.Textbox(label="Original Document:", value=original_document, interactive=True)
141
  with gr.Column():
142
- anonymized_doc_box = gr.Textbox(label="Anonymized Document:", value=anonymized_document, interactive=False)
143
 
144
  ########################## User Query Part ##########################
145
-
146
  with gr.Row():
147
- input_text = gr.Textbox(value="Who lives in Maine?", label="User query", interactive=True)
148
-
149
- default_query_box = gr.Radio(choices=list(DEFAULT_QUERIES.keys()), label="Example Queries")
150
-
151
- default_query_box.change(
152
- fn=lambda default_query_box: DEFAULT_QUERIES[default_query_box],
153
- inputs=[default_query_box],
154
- outputs=[input_text]
155
- )
156
 
157
- input_text.change(
158
- check_user_query_fn,
159
- inputs=[input_text],
160
- outputs=[input_text],
161
- )
 
 
 
 
 
162
 
163
- anonymized_text_output = gr.Textbox(label="Anonymized Text with FHE", lines=1, interactive=True)
164
 
165
- identified_words_output = gr.Dataframe(label="Identified Words", visible=False)
 
 
166
 
167
- submit_button = gr.Button("Anonymize with FHE")
168
 
169
- submit_button.click(
170
- deidentify_text,
171
- inputs=[input_text],
172
  outputs=[anonymized_text_output, identified_words_output],
173
  )
174
 
175
- with gr.Row():
176
- chatgpt_response_anonymized = gr.Textbox(label="ChatGPT Anonymized Response", lines=13)
177
- chatgpt_response_deanonymized = gr.Textbox(label="ChatGPT Deanonymized Response", lines=13)
 
 
 
 
178
 
179
  chatgpt_button = gr.Button("Query ChatGPT")
180
  chatgpt_button.click(
181
- query_chatgpt,
182
- inputs=[anonymized_text_output],
183
  outputs=[chatgpt_response_anonymized, chatgpt_response_deanonymized],
184
  )
185
 
186
  # Launch the app
187
  demo.launch(share=False)
 
1
  """A Gradio app for anonymizing text data using FHE."""
2
 
3
+ import os
4
+ import re
5
+ from typing import Dict, List
6
+
7
  import gradio as gr
 
8
  import pandas as pd
9
+ from fhe_anonymizer import FHEAnonymizer
10
  from openai import OpenAI
 
 
 
11
  from utils_demo import *
12
+
13
+ ORIGINAL_DOCUMENT = read_txt(ORIGINAL_FILE_PATH).split("\n\n")
14
+ ANONYMIZED_DOCUMENT = read_txt(ANONYMIZED_FILE_PATH)
15
+ MAPPING_SENTENCES = read_pickle(MAPPING_SENTENCES_PATH)
16
+
17
+ clean_directory()
18
 
19
  anonymizer = FHEAnonymizer()
20
 
21
+ client = OpenAI(api_key=os.environ.get("openaikey"))
 
 
22
 
23
 
24
+ def select_static_sentences_fn(selected_sentences: List):
25
+
26
+ selected_sentences = [MAPPING_SENTENCES[sentence] for sentence in selected_sentences]
27
+
28
+ anonymized_selected_sentence = sorted(selected_sentences, key=lambda x: x[0])
29
+
30
+ anonymized_selected_sentence = [sentence for _, sentence in anonymized_selected_sentence]
31
+
32
+ return {anonymized_doc_box: gr.update(value="\n\n".join(anonymized_selected_sentence))}
33
+
34
+
35
+ def key_gen_fn() -> Dict:
36
+ """Generate keys for a given user.
37
+
38
+ Returns:
39
+ dict: A dictionary containing the generated keys and related information.
40
+ """
41
+ print("Key Gen..")
42
+
43
+ anonymizer.generate_key()
44
+
45
+ evaluation_key_path = KEYS_DIR / "evaluation_key"
46
+
47
+ if not evaluation_key_path.is_file():
48
+ error_message = (
49
+ f"Error Encountered While generating the evaluation {evaluation_key_path.is_file()=}"
50
+ )
51
+ print(error_message)
52
+ return {gen_key_btn: gr.update(value=error_message)}
53
  else:
54
+ return {gen_key_btn: gr.update(value="Keys have been generated ✅")}
55
+
56
+
57
+ def encrypt_query_fn(query):
58
+ print(f"Query: {query}")
59
+
60
+ evaluation_key_path = KEYS_DIR / "evaluation_key"
61
+
62
+ if not evaluation_key_path.is_file():
63
+ error_message = "Error ❌: Please generate the key first!"
64
+ return {output_encrypted_box: gr.update(value=error_message)}
65
+
66
+ if is_user_query_valid(query):
67
+ # TODO: check if the query is related to our context
68
+ error_msg = (
69
+ "Unable to process ❌: The request exceeds the length limit or falls "
70
+ "outside the scope of this document. Please refine your query."
71
+ )
72
+ print(error_msg)
73
+ return {query_box: gr.update(value=error_msg)}
74
+
75
+ anonymizer.encrypt_query(query)
76
+
77
+ encrypted_tokens = read_pickle(KEYS_DIR / "encrypted_quantized_query")
78
+
79
+ encrypted_quant_tokens_hex = [token.hex()[500:510] for token in encrypted_tokens]
80
+
81
+ return {output_encrypted_box: gr.update(value=" ".join(encrypted_quant_tokens_hex))}
82
+
83
+
84
+ def run_fhe_fn(query_box):
85
+
86
+ evaluation_key_path = KEYS_DIR / "evaluation_key"
87
+ if not evaluation_key_path.is_file():
88
+ error_message = "Error ❌: Please generate the key first!"
89
+ return {anonymized_text_output: gr.update(value=error_message)}
90
+
91
+ encrypted_query_path = KEYS_DIR / "encrypted_quantized_query"
92
+ if not encrypted_query_path.is_file():
93
+ error_message = "Error ❌: Please encrypt your query first!"
94
+ return {anonymized_text_output: gr.update(value=error_message)}
95
+
96
+ anonymizer.run_server_and_decrypt_output(query_box)
97
+
98
+ anonymized_text = read_pickle(KEYS_DIR / "reconstructed_sentence")
99
+ identified_words_with_prob = read_pickle(KEYS_DIR / "identified_words_with_prob")
100
 
101
  # Convert the list of identified words and probabilities into a DataFrame
102
  if identified_words_with_prob:
 
108
  return anonymized_text, identified_df
109
 
110
 
111
+ def query_chatgpt_fn(anonymized_query, anonymized_document):
112
+
113
+ evaluation_key_path = KEYS_DIR / "evaluation_key"
114
+ if not evaluation_key_path.is_file():
115
+ error_message = "Error ❌: Please generate the key first!"
116
+ return {anonymized_text_output: gr.update(value=error_message)}
117
 
118
+ encrypted_query_path = KEYS_DIR / "encrypted_quantized_query"
119
+ if not encrypted_query_path.is_file():
120
+ error_message = "Error ❌: Please encrypt your query first!"
121
+ return {anonymized_text_output: gr.update(value=error_message)}
122
+
123
+ decrypted_query_path = KEYS_DIR / "reconstructed_sentence"
124
+ if not decrypted_query_path.is_file():
125
+ error_message = "Error ❌: Please run the FHE computation first!"
126
+ return {anonymized_text_output: gr.update(value=error_message)}
127
+
128
+ prompt = read_txt(PROMPT_PATH)
129
 
130
  # Prepare prompt
131
+ full_prompt = prompt + "\n"
132
+ query = (
133
+ "Document content:\n```\n"
134
+ + anonymized_document
135
+ + "\n\n```"
136
+ + "Query:\n```\n"
137
+ + anonymized_query
138
+ + "\n```"
139
  )
 
140
  print(full_prompt)
141
 
142
  completion = client.chat.completions.create(
 
147
  ],
148
  )
149
  anonymized_response = completion.choices[0].message.content
150
+ uuid_map = read_json(MAPPING_UUID_PATH)
151
+
152
+ inverse_uuid_map = {
153
+ v: k for k, v in uuid_map.items()
154
+ } # TODO load the inverse mapping from disk for efficiency
155
 
156
  # Pattern to identify words and non-words (including punctuation, spaces, etc.)
157
+ tokens = re.findall(r"(\b[\w\.\/\-@]+\b|[\s,.!?;:'\"-]+)", anonymized_response)
 
158
  processed_tokens = []
159
 
 
160
  for token in tokens:
161
  # Directly append non-word tokens or whitespace to processed_tokens
162
  if not token.strip() or not re.match(r"\w+", token):
 
171
  return anonymized_response, deanonymized_response
172
 
173
 
174
  demo = gr.Blocks(css=".markdown-body { font-size: 18px; }")
175
 
176
  with demo:
 
186
 
187
  <a href="https://docs.zama.ai/concrete-ml"> <img style="vertical-align: middle; display:inline-block; margin-right: 3px;" width=15 src="file/images/logos/documentation.png">Documentation</a>
188
 
189
+ <a href="https://community.zama.ai/c/concrete-ml/8"> <img style="vertical-align: middle; display:inline-block; margin-right: 3px;" width=15 src="file/images/logos/community.png">Community</a>
190
 
191
  <a href="https://twitter.com/zama_fhe"> <img style="vertical-align: middle; display:inline-block; margin-right: 3px;" width=15 src="file/images/logos/x.png">@zama_fhe</a>
192
  </p>
193
  """
194
  )
195
 
196
+ # gr.Markdown(
197
+ # """
198
+ # <p align="center">
199
+ # <img width="15%" height="15%" src="./encrypted_anonymization_diagram.jpg">
200
+ # </p>
201
+ # """
202
+ # )
203
+
204
+ with gr.Accordion("What is encrypted anonymization?", open=False):
205
+ gr.Markdown(
206
+ """
207
+ Anonymization is the process of removing personally identifiable information (PII)
212
+ from data to protect individual privacy.
213
+
214
+ To resolve trust issues when deploying anonymization as a cloud service, Fully Homomorphic
215
+ Encryption (FHE) can be used to preserve the privacy of the original data using
216
+ encryption.
217
+
218
+ The data remains encrypted throughout the anonymization process, eliminating the need for
219
+ third-party access to the raw data. Once the data is anonymized, it can safely be sent
220
+ to GenAI services such as ChatGPT.
221
  """
222
+ )
223
+
224
+ ########################## Key Gen Part ##########################
225
+
226
+ gr.Markdown(
227
+ "### Key generation\n\n"
228
+ """In FHE schemes, two sets of keys are generated. First, secret keys are used for
229
+ encrypting and decrypting data owned by the client. Second, evaluation keys allow a server
230
+ to blindly process the encrypted data. """
231
  )
232
 
233
+ gen_key_btn = gr.Button("Generate the private and evaluation keys")
234
+
235
+ gen_key_btn.click(
236
+ key_gen_fn,
237
+ inputs=[],
238
+ outputs=[gen_key_btn],
239
+ )
 
 
240
 
241
  ########################## Main document Part ##########################
242
 
243
+ gr.Markdown("## Private document")
244
+
245
  with gr.Row():
246
  with gr.Column():
247
+ gr.Markdown(
248
+ """This document was retrieved from the [Microsoft Presidio](https://huggingface.co/spaces/presidio/presidio_demo) demo.\n\n
249
+ You can select and deselect sentences to customize the document that will be used
250
+ as the initial prompt for ChatGPT in this space's final stage.\n\n
251
+ """
252
+ )
253
+ with gr.Column():
254
+ gr.Markdown(
255
+ """You can see the anonymized document that is sent to ChatGPT here.
256
+ ChatGPT will answer any queries that you have about the document below.
257
+ The anonymized information is replaced with hexadecimal strings.
258
+ """
259
+ )
260
+
261
+ with gr.Row():
262
  with gr.Column():
263
+ original_sentences_box = gr.CheckboxGroup(
264
+ ORIGINAL_DOCUMENT, value=ORIGINAL_DOCUMENT, label="Original document:"
265
+ )
266
+
267
+ with gr.Column():
268
+ anonymized_doc_box = gr.Textbox(
269
+ label="Anonymized document:", value=ANONYMIZED_DOCUMENT, interactive=False, lines=11
270
+ )
271
+
272
+ original_sentences_box.change(
273
+ fn=select_static_sentences_fn,
274
+ inputs=[original_sentences_box],
275
+ outputs=[anonymized_doc_box],
276
+ )
277
 
278
  ########################## User Query Part ##########################
279
+
280
+ gr.Markdown("<hr />")
281
+ gr.Markdown("## Private query")
282
+
283
+ gr.Markdown(
284
+ """Now, formulate a query regarding the selected document.\n\n
285
+
286
+ Choose from predefined options in 'Example Queries' or craft a custom query
287
+ in the 'User Query' box. Keep your question concise and relevant to the text's
288
+ context. Any off-topic question will not be processed.
289
+ """
290
+ )
291
+
292
  with gr.Row():
293
+ with gr.Column(scale=5):
294
+
295
+ with gr.Column(scale=5):
296
+ default_query_box = gr.Dropdown(
297
+ list(DEFAULT_QUERIES.values()), label="Example queries"
298
+ )
299
+
300
+ query_box = gr.Textbox(
301
+ value="Who lives in Maine?", label="User query", interactive=True
302
+ )
303
+
304
+ default_query_box.change(
305
+ fn=lambda default_query_box: default_query_box,
306
+ inputs=[default_query_box],
307
+ outputs=[query_box],
308
+ )
309
+
310
+ with gr.Column(scale=1, min_width=6):
311
+ gr.HTML("<div style='height: 25px;'></div>")
312
+
313
+ gr.Markdown(
314
+ """
315
+ <p align="center">
316
+ Encrypt data locally with FHE 💻 ⚙️
317
+ </p>
318
+ """
319
+ )
320
+ encrypt_btn = gr.Button("Encrypt data")
321
+ gr.HTML("<div style='height: 25px;'></div>")
322
+
323
+ with gr.Column(scale=5):
324
+ output_encrypted_box = gr.Textbox(
325
+ label="Encrypted anonymized query that is sent to the anonymization server", lines=6
326
+ )
327
+
328
+ encrypt_btn.click(
329
+ fn=encrypt_query_fn, inputs=[query_box], outputs=[query_box, output_encrypted_box]
330
+ )
331
 
332
+ gr.Markdown("<hr />")
333
+ gr.Markdown("## Secure anonymization with FHE")
334
+ gr.Markdown(
335
+ """
336
+ Once the client encrypts the private query locally,
337
+ the client transmits it to a remote server to perform the
338
+ anonymization on encrypted data. When the computation is finished, the server returns
339
+ the result to the client for decryption.
340
+ """
341
+ )
342
 
343
+ run_fhe_btn = gr.Button("Anonymize with FHE")
344
 
345
+ anonymized_text_output = gr.Textbox(
346
+ label="Decrypted anonymized query that will be sent to ChatGPT", lines=1, interactive=True
347
+ )
348
 
349
+ identified_words_output = gr.Dataframe(label="Identified words", visible=False)
350
 
351
+ run_fhe_btn.click(
352
+ run_fhe_fn,
353
+ inputs=[query_box],
354
  outputs=[anonymized_text_output, identified_words_output],
355
  )
356
 
357
+ gr.Markdown("<hr />")
358
+
359
+ gr.Markdown("## Secure your communication on ChatGPT with anonymized queries")
360
+ gr.Markdown(
361
+ """After securely anonymizing the query with FHE,
362
+ you can forward it to ChatGPT without any concern for information leakage."""
363
+ )
364
 
365
  chatgpt_button = gr.Button("Query ChatGPT")
366
+
367
+ with gr.Row():
368
+ chatgpt_response_anonymized = gr.Textbox(label="ChatGPT anonymized response", lines=13)
369
+ chatgpt_response_deanonymized = gr.Textbox(
370
+ label="ChatGPT de-anonymized response", lines=13
371
+ )
372
+
373
  chatgpt_button.click(
374
+ query_chatgpt_fn,
375
+ inputs=[anonymized_text_output, anonymized_doc_box],
376
  outputs=[chatgpt_response_anonymized, chatgpt_response_deanonymized],
377
  )
378
 
379
+ gr.Markdown(
380
+ """**Please Note**: As this space is intended solely for demonstration purposes, some
381
+ private information may be missed by the anonymization algorithm. Please validate the
382
+ following query before sending it to ChatGPT."""
383
+ )
384
+
388
  # Launch the app
389
  demo.launch(share=False)
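
For clarity, here is a minimal sketch of the de-anonymization pass that query_chatgpt_fn applies to ChatGPT's reply: it inverts the UUID mapping and swaps the placeholders back into the response. The token regex and the mapping file come from the diff above, but the replacement loop itself is abbreviated in the diff, so treat this as an illustration rather than the exact committed code.

import re

def deanonymize(response, uuid_map):
    # Invert the {original_word: short_uuid} mapping produced during anonymization.
    inverse_uuid_map = {v: k for k, v in uuid_map.items()}
    tokens = re.findall(r"(\b[\w\.\/\-@]+\b|[\s,.!?;:'\"-]+)", response)
    processed_tokens = []
    for token in tokens:
        if not token.strip() or not re.match(r"\w+", token):
            processed_tokens.append(token)  # keep whitespace and punctuation as-is
        else:
            # Replace a known UUID placeholder with the original word, otherwise keep the token.
            processed_tokens.append(inverse_uuid_map.get(token, token))
    return "".join(processed_tokens)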
demo_text.txt DELETED
@@ -1 +0,0 @@
1
- who lives in Maine?
 
 
encrypted_anonymization_diagram.jpg DELETED
Binary file (94.7 kB)
 
fhe_anonymizer.py CHANGED
@@ -1,73 +1,123 @@
1
- import gensim
2
  import re
3
- from concrete.ml.deployment import FHEModelClient, FHEModelServer
 
4
  from pathlib import Path
 
 
 
 
5
  from concrete.ml.common.serialization.loaders import load
6
- import uuid
7
- import json
8
- from transformers import AutoTokenizer, AutoModel
9
- from utils_demo import get_batch_text_representation
10
 
11
- base_dir = Path(__file__).parent
 
 
 
 
 
12
 
13
 
14
  class FHEAnonymizer:
15
- def __init__(self, punctuation_list=".,!?:;"):
16
 
17
  # Load tokenizer and model
18
  self.tokenizer = AutoTokenizer.from_pretrained("obi/deid_roberta_i2b2")
19
  self.embeddings_model = AutoModel.from_pretrained("obi/deid_roberta_i2b2")
20
 
21
- self.punctuation_list = punctuation_list
 
 
 
 
 
 
22
 
23
- with open(base_dir / "original_document_uuid_mapping.json", 'r') as file:
24
- self.uuid_map = json.load(file)
25
 
26
- path_to_model = (base_dir / "deployment").resolve()
27
- self.client = FHEModelClient(path_to_model)
28
- self.server = FHEModelServer(path_to_model)
29
  self.client.generate_private_and_evaluation_keys()
 
 
30
  self.evaluation_key = self.client.get_serialized_evaluation_keys()
 
 
 
31
 
32
- def fhe_inference(self, x):
33
- enc_x = self.client.quantize_encrypt_serialize(x)
34
- enc_y = self.server.run(enc_x, self.evaluation_key)
35
- y = self.client.deserialize_decrypt_dequantize(enc_y)
36
- return y
37
 
38
- def __call__(self, text: str):
39
  # Pattern to identify words and non-words (including punctuation, spaces, etc.)
40
- token_pattern = r"(\b[\w\.\/\-@]+\b|[\s,.!?;:'\"-]+)"
41
- tokens = re.findall(token_pattern, text)
42
- identified_words_with_prob = []
43
- processed_tokens = []
44
 
45
  for token in tokens:
46
- # Directly append non-word tokens or whitespace to processed_tokens
47
- if not token.strip() or not re.match(r"\w+", token):
48
- processed_tokens.append(token)
49
  continue
 
50
 
51
  # Prediction for each word
52
- x = get_batch_text_representation([token], self.embeddings_model, self.tokenizer)
 
 
53
 
54
- prediction_proba = self.fhe_inference(x)
55
- probability = prediction_proba[0][1]
56
 
57
- if probability >= 0.5:
58
- identified_words_with_prob.append((token, probability))
59
 
60
- # Use the existing UUID if available, otherwise generate a new one
61
- tmp_uuid = self.uuid_map.get(token, str(uuid.uuid4())[:8])
62
- processed_tokens.append(tmp_uuid)
63
- self.uuid_map[token] = tmp_uuid
 
64
  else:
65
- processed_tokens.append(token)
66
 
67
- # Update the UUID map with query.
68
- with open(base_dir / "original_document_uuid_mapping.json", 'w') as file:
69
- json.dump(self.uuid_map, file)
70
 
71
- # Reconstruct the sentence
72
- reconstructed_sentence = ''.join(processed_tokens)
73
- return reconstructed_sentence, identified_words_with_prob
 
1
+ import json
2
  import re
3
+ import time
4
+ import uuid
5
  from pathlib import Path
6
+
7
+ from transformers import AutoModel, AutoTokenizer
8
+ from utils_demo import *
9
+
10
  from concrete.ml.common.serialization.loaders import load
11
+ from concrete.ml.deployment import FHEModelClient, FHEModelServer
 
 
 
12
 
13
+ TOLERANCE_PROBA = 0.77
14
+
15
+ CURRENT_DIR = Path(__file__).parent
16
+
17
+ DEPLOYMENT_DIR = CURRENT_DIR / "deployment"
18
+ KEYS_DIR = DEPLOYMENT_DIR / ".fhe_keys"
19
 
20
 
21
  class FHEAnonymizer:
22
+ def __init__(self):
23
 
24
  # Load tokenizer and model
25
  self.tokenizer = AutoTokenizer.from_pretrained("obi/deid_roberta_i2b2")
26
  self.embeddings_model = AutoModel.from_pretrained("obi/deid_roberta_i2b2")
27
 
28
+ self.punctuation_list = PUNCTUATION_LIST
29
+ self.uuid_map = read_json(MAPPING_UUID_PATH)
30
+
31
+ self.client = FHEModelClient(DEPLOYMENT_DIR, key_dir=KEYS_DIR)
32
+ self.server = FHEModelServer(DEPLOYMENT_DIR)
33
+
34
+ def generate_key(self):
35
 
36
+ clean_directory()
 
37
 
38
+ # Creates the private and evaluation keys on the client side
 
 
39
  self.client.generate_private_and_evaluation_keys()
40
+
41
+ # Get the serialized evaluation keys
42
  self.evaluation_key = self.client.get_serialized_evaluation_keys()
43
+ assert isinstance(self.evaluation_key, bytes)
44
+
45
+ evaluation_key_path = KEYS_DIR / "evaluation_key"
46
 
47
+ with evaluation_key_path.open("wb") as f:
48
+ f.write(self.evaluation_key)
 
 
 
49
 
50
+ def encrypt_query(self, text: str):
51
  # Pattern to identify words and non-words (including punctuation, spaces, etc.)
52
+ tokens = re.findall(r"(\b[\w\.\/\-@]+\b|[\s,.!?;:'\"-]+)", text)
53
+ encrypted_tokens = []
 
 
54
 
55
  for token in tokens:
56
+ if bool(re.match(r"^\s+$", token)):
 
 
57
  continue
58
+ # Directly append non-word tokens or whitespace to processed_tokens
59
 
60
  # Prediction for each word
61
+ emb_x = get_batch_text_representation([token], self.embeddings_model, self.tokenizer)
62
+ encrypted_x = self.client.quantize_encrypt_serialize(emb_x)
63
+ assert isinstance(encrypted_x, bytes)
64
 
65
+ encrypted_tokens.append(encrypted_x)
 
66
 
67
+ write_pickle(KEYS_DIR / f"encrypted_quantized_query", encrypted_tokens)
 
68
 
69
+ def run_server(self):
70
+
71
+ encrypted_tokens = read_pickle(KEYS_DIR / f"encrypted_quantized_query")
72
+
73
+ encrypted_output, timing = [], []
74
+ for enc_x in encrypted_tokens:
75
+ start_time = time.time()
76
+ enc_y = self.server.run(enc_x, self.evaluation_key)
77
+ timing.append((time.time() - start_time) / 60.0)
78
+ encrypted_output.append(enc_y)
79
+
80
+ write_pickle(KEYS_DIR / f"encrypted_output", encrypted_output)
81
+ write_pickle(KEYS_DIR / f"encrypted_timing", timing)
82
+
83
+ return encrypted_output, timing
84
+
85
+ def decrypt_output(self, text):
86
+
87
+ encrypted_output = read_pickle(KEYS_DIR / f"encrypted_output")
88
+
89
+ tokens = re.findall(r"(\b[\w\.\/\-@]+\b|[\s,.!?;:'\"-]+)", text)
90
+ decrypted_output, identified_words_with_prob = [], []
91
+
92
+ i = 0
93
+ for token in tokens:
94
+ # Directly append non-word tokens or whitespace to processed_tokens
95
+ if bool(re.match(r"^\s+$", token)):
96
+ continue
97
  else:
98
+ encrypted_token = encrypted_output[i]
99
+ prediction_proba = self.client.deserialize_decrypt_dequantize(encrypted_token)
100
+ probability = prediction_proba[0][1]
101
+ i += 1
102
+
103
+ if probability >= TOLERANCE_PROBA:
104
+ identified_words_with_prob.append((token, probability))
105
+
106
+ # Use the existing UUID if available, otherwise generate a new one
107
+ tmp_uuid = self.uuid_map.get(token, str(uuid.uuid4())[:8])
108
+ decrypted_output.append(tmp_uuid)
109
+ self.uuid_map[token] = tmp_uuid
110
+ else:
111
+ decrypted_output.append(token)
112
+
113
+ # Update the UUID map with query.
114
+ with open(MAPPING_UUID_PATH, "w") as file:
115
+ json.dump(self.uuid_map, file)
116
+
117
+ write_pickle(KEYS_DIR / f"reconstructed_sentence", " ".join(decrypted_output))
118
+ write_pickle(KEYS_DIR / f"identified_words_with_prob", identified_words_with_prob)
119
 
 
 
 
120
 
121
+ def run_server_and_decrypt_output(self, text):
122
+ self.run_server()
123
+ self.decrypt_output(text)
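
Putting the new FHEAnonymizer API together, the client-side flow is roughly the following. This is a sketch based only on the method and file names visible in this diff (generate_key, encrypt_query, run_server_and_decrypt_output, and the pickles written under deployment/.fhe_keys); it is not code taken from the repository.

from fhe_anonymizer import FHEAnonymizer
from utils_demo import KEYS_DIR, read_pickle

anonymizer = FHEAnonymizer()
anonymizer.generate_key()                                  # writes deployment/.fhe_keys/evaluation_key
anonymizer.encrypt_query("Who lives in Maine?")            # writes encrypted_quantized_query
anonymizer.run_server_and_decrypt_output("Who lives in Maine?")

print(read_pickle(KEYS_DIR / "reconstructed_sentence"))         # anonymized query
print(read_pickle(KEYS_DIR / "identified_words_with_prob"))     # detected PII tokens with probabilities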
files/anonymized_document.txt CHANGED
@@ -1,10 +1,10 @@
1
- 84381322, my name is 8b9ec610 8c6d3442 and I live in 269b9686.
2
- My credit card number is c075beec and my crypto wallet id is 54344fd4.
3
 
4
- On 9d6193ab 57c4ba7a I visited ea9cc7db and sent an email to d2934e4f, from the IP 1a26727d.
5
 
6
- My 694a9044: 8d6f2b87 and my phone number: 6491a9cd 2a61cfbc.
7
 
8
- This is a valid a1cc4c7e 46e4a44b Account Number: de6fd087 . Can you please check the status on bank account 9277229c?
9
 
10
- 4571d08d's social security number is 095fa9c8. 290451c3 driver license? it is 778679d7.
 
1
+ Hello, my name is ebe99761 53a9291d and I live in 6337f12f.
2
+ My credit card number is e5b499b0 and my crypto wallet id is ac41d58b.
3
 
4
+ On September 18 I visited 0d574451 and sent an email to 1f78e797, from the IP 116fe81e.
5
 
6
+ My passport: 59a83e41 and my phone number: 144a2acc d9e5704e.
7
 
8
+ This is a valid 71d0f51c Bank Account Number: 5ca977a4. Can you please check the status on bank account 9eb07461?
9
 
10
+ b474d794's social security number is d8da62f1. Her driver license? it is 5e63c327.
files/chatgpt_prompt.txt CHANGED
@@ -5,5 +5,6 @@ Details:
5
  - Sensitive information includes: names, locations, credit card numbers, email addresses, IP addresses, passport details, phone numbers, bank accounts, social security numbers, and driver's licenses.
6
  - Each piece of information is represented by a unique identifier, maintaining privacy while discussing document content.
7
  - Your role is to interpret the document's anonymized content and accurately respond to queries using the identifiers.
 
8
  - Consistency in identifiers is crucial for connecting the text with the queries correctly.
9
  - You must not discuss the anonymized nature of the text and use the identifiers as if they were real words for a smooth chat with users.
 
5
  - Sensitive information includes: names, locations, credit card numbers, email addresses, IP addresses, passport details, phone numbers, bank accounts, social security numbers, and driver's licenses.
6
  - Each piece of information is represented by a unique identifier, maintaining privacy while discussing document content.
7
  - Your role is to interpret the document's anonymized content and accurately respond to queries using the identifiers.
8
+ - Any question outside the content of the document is forbidden; reply that it is out of scope, do not answer it, and ask the user to try another question.
9
  - Consistency in identifiers is crucial for connecting the text with the queries correctly.
10
  - You must not discuss the anonymized nature of the text and use the identifiers as if they were real words for a smooth chat with users.
models/embedded_model.model → files/mapping_clear_to_anonymized.pkl RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:28fcf483356bf2bef29b8220b84803acf9518f19fbc9342e76cac06b30803f28
3
- size 73056
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:944e5c32bd04e955194c513d35b91467615c08973c767745a1756d015b3e6ebb
3
+ size 1085
files/original_document.txt CHANGED
@@ -5,6 +5,6 @@ On September 18 I visited microsoft.com and sent an email to test@presidio.site,
5
 
6
  My passport: 191280342 and my phone number: (212) 555-1234.
7
 
8
- This is a valid International Bank Account Number: IL150120690000003111111 . Can you please check the status on bank account 954567876544?
9
 
10
  Kate's social security number is 078-05-1126. Her driver license? it is 1234567A.
 
5
 
6
  My passport: 191280342 and my phone number: (212) 555-1234.
7
 
8
+ This is a valid International Bank Account Number: IL150120690000003111111. Can you please check the status on bank account 954567876544?
9
 
10
  Kate's social security number is 078-05-1126. Her driver license? it is 1234567A.
files/original_document_uuid_mapping.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"078-05-1126": "d8da62f1", "1234567A": "5e63c327", "16Yeky6GMjeNkAiNcBY7ZhrLoMSgg1BoyZ": "ac41d58b", "191280342": "59a83e41", "192.168.0.1": "116fe81e", "212": "144a2acc", "4095-2609-9393-4932": "e5b499b0", "555-1234": "d9e5704e", "954567876544": "9eb07461", "David": "ebe99761", "IL150120690000003111111": "5ca977a4", "International": "71d0f51c", "Johnson": "53a9291d", "Kate": "b474d794", "Maine": "6337f12f", "microsoft.com": "0d574451", "test@presidio.site": "1f78e797"}
models/embedded_model.model.wv.vectors_ngrams.npy DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:faf08ed9c3bc29cf71c16f5d2b311f3bfb730a92f12c2e52d742bc6b59bf9e5f
3
- size 800000128
 
 
 
 
models/without_pronoun_cml_xgboost.model DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:933d1d5c5f83c30211dd9a497482c517a822df809c0498fed164de72bd7bf910
3
- size 1085795
 
 
 
 
models/without_pronoun_embedded_model.model DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:762240ca4040c68e44c403f16abce5683a0c4a005ec10f3dd0135a0e429a66c1
3
- size 1189196
 
 
 
 
models/without_pronoun_embedded_model.model.wv.vectors_ngrams.npy DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:5cf06fe78185b373c97ee0616f599ce6b1aceb6445b8f666fac6cd4cd307fe46
3
- size 400000128
 
 
 
 
original_document_uuid_mapping.json DELETED
@@ -1 +0,0 @@
1
- {"078-05-1126": "095fa9c8", "1234567A": "778679d7", "16Yeky6GMjeNkAiNcBY7ZhrLoMSgg1BoyZ": "54344fd4", "18": "57c4ba7a", "191280342": "8d6f2b87", "192.168.0.1": "1a26727d", "212": "6491a9cd", "4095-2609-9393-4932": "c075beec", "555-1234": "2a61cfbc", "954567876544": "9277229c", "Bank": "46e4a44b", "David": "8b9ec610", "Hello": "84381322", "Her": "290451c3", "IL150120690000003111111": "de6fd087", "International": "a1cc4c7e", "Johnson": "8c6d3442", "Kate": "4571d08d", "Maine": "269b9686", "September": "9d6193ab", "microsoft.com": "ea9cc7db", "passport": "694a9044", "test@presidio.site": "d2934e4f"}
 
 
utils_demo.py CHANGED
@@ -1,28 +1,68 @@
1
- import torch
2
- import numpy as np
 
 
 
 
4
 
5
- MAX_USER_QUERY_LEN = 35
6
 
7
  # List of example queries for easy access
8
  DEFAULT_QUERIES = {
9
  "Example Query 1": "Who visited microsoft.com on September 18?",
10
- "Example Query 2": "Does Kate has drive ?",
11
- "Example Query 3": "What phone number can be used to contact David Johnson?",
12
  }
13
 
 
14
  def get_batch_text_representation(texts, model, tokenizer, batch_size=1):
15
- """
16
- Get mean-pooled representations of given texts in batches.
17
- """
18
  mean_pooled_batch = []
19
  for i in range(0, len(texts), batch_size):
20
- batch_texts = texts[i:i+batch_size]
21
  inputs = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True)
22
  with torch.no_grad():
23
  outputs = model(**inputs, output_hidden_states=False)
24
  last_hidden_states = outputs.last_hidden_state
25
- input_mask_expanded = inputs['attention_mask'].unsqueeze(-1).expand(last_hidden_states.size()).float()
 
 
26
  sum_embeddings = torch.sum(last_hidden_states * input_mask_expanded, 1)
27
  sum_mask = input_mask_expanded.sum(1)
28
  mean_pooled = sum_embeddings / sum_mask
@@ -39,11 +79,82 @@ def is_user_query_valid(user_query: str) -> bool:
39
  bool: True if the `user_query` is None or empty, False otherwise.
40
  """
41
  # If the query is not part of the default queries
42
- is_default_query = user_query in DEFAULT_QUERIES.values()
43
-
44
  # Check if the query exceeds the length limit
45
  is_exceeded_max_length = user_query is not None and len(user_query) <= MAX_USER_QUERY_LEN
46
-
47
  return not is_default_query and not is_exceeded_max_length
48
 
49
 
 
1
+ import json
2
+ import os
3
+ import pickle as pkl
4
+ import re
5
+ import shutil
6
+ import string
7
+ from collections import Counter
8
+ from pathlib import Path
9
 
10
+ import numpy as np
11
+ import torch
12
 
13
+ MAX_USER_QUERY_LEN = 80
14
 
15
  # List of example queries for easy access
16
  DEFAULT_QUERIES = {
17
  "Example Query 1": "Who visited microsoft.com on September 18?",
18
+ "Example Query 2": "Does Kate have a driving licence?",
19
+ "Example Query 3": "What's David Johnson's phone number?",
20
  }
21
 
22
+
23
+ CURRENT_DIR = Path(__file__).parent
24
+
25
+ DATA_PATH = CURRENT_DIR / "files"
26
+ LOGREG_MODEL_PATH = CURRENT_DIR / "models" / "cml_logreg.model"
27
+ DEPLOYMENT_DIR = CURRENT_DIR / "deployment"
28
+ KEYS_DIR = DEPLOYMENT_DIR / ".fhe_keys"
29
+
30
+ ORIGINAL_FILE_PATH = DATA_PATH / "original_document.txt"
31
+ ANONYMIZED_FILE_PATH = DATA_PATH / "anonymized_document.txt"
32
+ MAPPING_UUID_PATH = DATA_PATH / "original_document_uuid_mapping.json"
33
+ MAPPING_SENTENCES_PATH = DATA_PATH / "mapping_clear_to_anonymized.pkl"
34
+ PROMPT_PATH = DATA_PATH / "chatgpt_prompt.txt"
35
+
36
+ ALL_DIRS = [KEYS_DIR]
37
+
38
+ PUNCTUATION_LIST = list(string.punctuation)
39
+ PUNCTUATION_LIST.remove("%")
40
+ PUNCTUATION_LIST.remove("$")
41
+ PUNCTUATION_LIST = "".join(PUNCTUATION_LIST)
42
+
43
+
44
+ def clean_directory() -> None:
45
+ """Clear directories."""
46
+
47
+ print("Cleaning...\n")
48
+ for target_dir in ALL_DIRS:
49
+ if os.path.exists(target_dir) and os.path.isdir(target_dir):
50
+ shutil.rmtree(target_dir)
51
+ target_dir.mkdir(exist_ok=True, parents=True)
52
+
53
+
54
  def get_batch_text_representation(texts, model, tokenizer, batch_size=1):
55
+ """Get mean-pooled representations of given texts in batches."""
 
 
56
  mean_pooled_batch = []
57
  for i in range(0, len(texts), batch_size):
58
+ batch_texts = texts[i : i + batch_size]
59
  inputs = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True)
60
  with torch.no_grad():
61
  outputs = model(**inputs, output_hidden_states=False)
62
  last_hidden_states = outputs.last_hidden_state
63
+ input_mask_expanded = (
64
+ inputs["attention_mask"].unsqueeze(-1).expand(last_hidden_states.size()).float()
65
+ )
66
  sum_embeddings = torch.sum(last_hidden_states * input_mask_expanded, 1)
67
  sum_mask = input_mask_expanded.sum(1)
68
  mean_pooled = sum_embeddings / sum_mask
 
79
  bool: True if the `user_query` is None or empty, False otherwise.
80
  """
81
  # If the query is not part of the default queries
82
+ is_default_query = user_query in DEFAULT_QUERIES.values()
83
+
84
  # Check if the query exceeds the length limit
85
  is_exceeded_max_length = user_query is not None and len(user_query) <= MAX_USER_QUERY_LEN
86
+
87
  return not is_default_query and not is_exceeded_max_length
88
 
89
 
90
+ def compare_texts_ignoring_extra_spaces(original_text, modified_text):
91
+ """Check if the modified_text is identical to the original_text except for additional spaces.
92
+
93
+ Args:
94
+ original_text (str): The original text for comparison.
95
+ modified_text (str): The modified text to compare against the original.
96
+
97
+ Returns:
98
+ (bool): True if the modified_text is the same as the original_text except for
99
+ additional spaces; False otherwise.
100
+ """
101
+ normalized_original = " ".join(original_text.split())
102
+ normalized_modified = " ".join(modified_text.split())
103
+
104
+ return normalized_original == normalized_modified
105
+
106
+
107
+ def is_strict_deletion_only(original_text, modified_text):
108
+
109
+ # Define a regex pattern that matches a word character next to a punctuation
110
+ # or a punctuation next to a word character, without a space between them.
111
+ pattern = r"(?<=[\w])(?=[^\w\s])|(?<=[^\w\s])(?=[\w])"
112
+
113
+ # Replace instances found by the pattern with a space
114
+ original_text = re.sub(pattern, " ", original_text)
115
+ modified_text = re.sub(pattern, " ", modified_text)
116
+
117
+ # Tokenize the texts into words, considering also punctuation
118
+ original_words = Counter(original_text.lower().split())
119
+ modified_words = Counter(modified_text.lower().split())
120
+
121
+ base_words = all(item in original_words.keys() for item in modified_words.keys())
122
+ base_count = all(original_words[k] >= v for k, v in modified_words.items())
123
+
124
+ return base_words and base_count
125
+
126
+
127
+ def read_txt(file_path):
128
+ """Read text from a file."""
129
+ with open(file_path, "r", encoding="utf-8") as file:
130
+ return file.read()
131
+
132
+
133
+ def write_txt(file_path, data):
134
+ """Write text to a file."""
135
+ with open(file_path, "w", encoding="utf-8") as file:
136
+ file.write(data)
137
+
138
+
139
+ def write_pickle(file_path, data):
140
+ """Save data to a pickle file."""
141
+ with open(file_path, "wb") as f:
142
+ pkl.dump(data, f)
143
+
144
+
145
+ def read_pickle(file_name):
146
+ """Load data from a pickle file."""
147
+ with open(file_name, "rb") as file:
148
+ return pkl.load(file)
149
+
150
+
151
+ def read_json(file_name):
152
+ """Load data from a json file."""
153
+ with open(file_name, "r") as file:
154
+ return json.load(file)
155
+
156
+
157
+ def write_json(file_name, data):
158
+ """Save data to a json file."""
159
+ with open(file_name, "w", encoding="utf-8") as file:
160
+ json.dump(data, file, indent=4, sort_keys=True)
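
As a usage illustration (not part of the commit), the two new text-comparison helpers behave as follows, assuming the definitions above:

from utils_demo import compare_texts_ignoring_extra_spaces, is_strict_deletion_only

original = "Kate's social security number is 078-05-1126."
spaced = "Kate's  social  security number is 078-05-1126."

print(compare_texts_ignoring_extra_spaces(original, spaced))         # True: only extra spaces differ
print(is_strict_deletion_only(original, "social security number"))   # True: words were only removed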