import os # Uncomment if run locally # import sys # sys.path.append(os.path.abspath("../../../molvault")) # sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from requests import head from concrete.ml.deployment import FHEModelClient import numpy import os from pathlib import Path import requests import json import base64 import subprocess import shutil import time from chemdata import get_ECFP_AND_FEATURES import streamlit as st import cairosvg from rdkit import Chem from rdkit.Chem import AllChem from rdkit.Chem.Draw import rdMolDraw2D import pandas as pd from st_keyup import st_keyup import pickle import numpy as np st.set_page_config(layout="wide", page_title="VaultChem") def local_css(file_name): with open(file_name) as f: st.markdown(f"", unsafe_allow_html=True) local_css("style.css") def img_to_bytes(img_path): img_bytes = Path(img_path).read_bytes() encoded = base64.b64encode(img_bytes).decode() return encoded def img_to_html(img_path, width=None): img_bytes = img_to_bytes(img_path) if width: img_html = "".format( img_bytes, width ) else: img_html = "".format( img_bytes ) return img_html # Start timing formatted_text = ( "

" "Pharmacokinetics" " of " "🀫confidential" " molecules" "

" ) st.markdown(formatted_text, unsafe_allow_html=True) # make a small hint that the app needs a few seconds to start st.markdown( "

" + "The app needs a second to start...not optimized for mobile yet. πŸš€" + "

", unsafe_allow_html=True, ) interesting_text = """ Machine learning (**ML**) has become a cornerstone of modern drug discovery. However, the data used to evaluate the ML models is often **confidential**. This is especially true for the pharmaceutical industry where new drug candidates are considered as the most valuable asset. Therefore chemical companies are reluctant to share their data with third parties, for instance, to use ML services provided by other companies. πŸ”’**We implemented a workflow that allows predicting properties of a molecule with third-party models without sharing them**πŸ”’. That means an organization "A" can use any server - even an untrusted environment - outside of their infrastructure to perform the prediction. This way organization "A" can benefit from ML services provided by organization "B" without sharing their confidential data. πŸͺ„ **The magic?** πŸͺ„ The server on which the prediction is computed will never see the molecule in clear text, but will still compute an encrypted prediction. Why is this **magic**? Because this is equivalent to computing the prediction on the molecule in clear text, but without sharing the molecule with the server. Even if organization "B" - or in fact any other party - would try to steal the data, they would only see the encrypted molecular data. **Only the party that has the private key (organization "A") can decrypt the prediction**. This is possible using a method called "Fully Homomorphic Encryption" (FHE). This special encryption scheme allows to perform computations on encrypted data, to learn more about FHE, click [here](https://fhe.org/resources/). We use the open-source library Concrete-ML to develop safe and robust encryption technology. The code used for the FHE prediction is available in the open-source library \n **What are the steps involved?** \n Find out below! πŸ‘‡ You can try it for yourself! πŸŽ‰ """ st.markdown( f"{interesting_text}", unsafe_allow_html=True, ) st.divider() st.markdown( "

" + img_to_html("scheme2.png", width="65%") + "

", unsafe_allow_html=True, ) # Define your data st.divider() # This repository's directory REPO_DIR = Path(__file__).parent subprocess.Popen(["uvicorn", "server:app"], cwd=REPO_DIR) # if not exists, create a directory for the FHE keys called .fhe_keys if not os.path.exists(".fhe_keys"): os.mkdir(".fhe_keys") # if not exists, create a directory for the tmp files called tmp if not os.path.exists("tmp"): os.mkdir("tmp") # Wait 4 sec for the server to start time.sleep(4) # Encrypted data limit for the browser to display # (encrypted data is too large to display in the browser) ENCRYPTED_DATA_BROWSER_LIMIT = 500 N_USER_KEY_STORED = 20 def clean_tmp_directory(): # Allow 20 user keys to be stored. # Once that limitation is reached, deleted the oldest. path_sub_directories = sorted( [f for f in Path(".fhe_keys/").iterdir() if f.is_dir()], key=os.path.getmtime ) user_ids = [] if len(path_sub_directories) > N_USER_KEY_STORED: n_files_to_delete = len(path_sub_directories) - N_USER_KEY_STORED for p in path_sub_directories[:n_files_to_delete]: user_ids.append(p.name) shutil.rmtree(p) list_files_tmp = Path("tmp/").iterdir() # Delete all files related to user_id for file in list_files_tmp: for user_id in user_ids: if file.name.endswith(f"{user_id}.npy"): file.unlink() def keygen(): # Clean tmp directory if needed clean_tmp_directory() print("Initializing FHEModelClient...") task = st.session_state["task"] # Let's create a user_id user_id = numpy.random.randint(0, 2**32) fhe_api = FHEModelClient(f"deployment/deployment_{task}", f".fhe_keys/{user_id}") fhe_api.load() # Generate a fresh key fhe_api.generate_private_and_evaluation_keys(force=True) evaluation_key = fhe_api.get_serialized_evaluation_keys() numpy.save(f"tmp/tmp_evaluation_key_{user_id}.npy", evaluation_key) return [list(evaluation_key)[:ENCRYPTED_DATA_BROWSER_LIMIT], user_id] def encode_quantize_encrypt(text, user_id): task = st.session_state["task"] fhe_api = FHEModelClient(f"deployment/deployment_{task}", f".fhe_keys/{user_id}") fhe_api.load() encodings = get_ECFP_AND_FEATURES(text, radius=2, bits=1024).reshape(1, -1) quantized_encodings = fhe_api.model.quantize_input(encodings).astype(numpy.uint8) encrypted_quantized_encoding = fhe_api.quantize_encrypt_serialize(encodings) # Save encrypted_quantized_encoding in a file, since too large to pass through regular Gradio # buttons, https://github.com/gradio-app/gradio/issues/1877 numpy.save( f"tmp/tmp_encrypted_quantized_encoding_{user_id}.npy", encrypted_quantized_encoding, ) # Compute size encrypted_quantized_encoding_shorten = list(encrypted_quantized_encoding)[ :ENCRYPTED_DATA_BROWSER_LIMIT ] encrypted_quantized_encoding_shorten_hex = "".join( f"{i:02x}" for i in encrypted_quantized_encoding_shorten ) return ( encodings[0], quantized_encodings[0], encrypted_quantized_encoding_shorten_hex, ) def run_fhe(user_id): encoded_data_path = Path(f"tmp/tmp_encrypted_quantized_encoding_{user_id}.npy") # if not user_id: # print("You need to generate FHE keys first.") # if not encoded_data_path.is_file(): # print("No encrypted data was found. Encrypt the data before trying to predict.") # Read encrypted_quantized_encoding from the file task = st.session_state["task"] if st.session_state["fhe_prediction"] == "": encrypted_quantized_encoding = numpy.load(encoded_data_path) # Read evaluation_key from the file evaluation_key = numpy.load(f"tmp/tmp_evaluation_key_{user_id}.npy") # Use base64 to encode the encodings and evaluation key encrypted_quantized_encoding = base64.b64encode( encrypted_quantized_encoding ).decode() encoded_evaluation_key = base64.b64encode(evaluation_key).decode() query = {} query["evaluation_key"] = encoded_evaluation_key query["encrypted_encoding"] = encrypted_quantized_encoding headers = {"Content-type": "application/json"} if task == "0": response = requests.post( "http://localhost:8000/predict_HLM", data=json.dumps(query), headers=headers, ) elif task == "1": response = requests.post( "http://localhost:8000/predict_MDR1MDCK", data=json.dumps(query), headers=headers, ) elif task == "2": response = requests.post( "http://localhost:8000/predict_SOLUBILITY", data=json.dumps(query), headers=headers, ) elif task == "3": response = requests.post( "http://localhost:8000/predict_PROTEIN_BINDING_HUMAN", data=json.dumps(query), headers=headers, ) elif task == "4": response = requests.post( "http://localhost:8000/predict_PROTEIN_BINDING_RAT", data=json.dumps(query), headers=headers, ) elif task == "5": response = requests.post( "http://localhost:8000/predict_RLM_CLint", data=json.dumps(query), headers=headers, ) else: print("Invalid task number") encrypted_prediction = base64.b64decode(response.json()["encrypted_prediction"]) numpy.save(f"tmp/tmp_encrypted_prediction_{user_id}.npy", encrypted_prediction) encrypted_prediction_shorten = list(encrypted_prediction)[ :ENCRYPTED_DATA_BROWSER_LIMIT ] encrypted_prediction_shorten_hex = "".join( f"{i:02x}" for i in encrypted_prediction_shorten ) st.session_state["fhe_prediction"] = encrypted_prediction_shorten_hex st.session_state["fhe_done"] = True def decrypt_prediction(user_id): encoded_data_path = Path(f"tmp/tmp_encrypted_prediction_{user_id}.npy") # Read encrypted_prediction from the file task = st.session_state["task"] if st.session_state["decryption_done"] == False: encrypted_prediction = numpy.load(encoded_data_path).tobytes() fhe_api = FHEModelClient( f"deployment/deployment_{task}", f".fhe_keys/{user_id}" ) fhe_api.load() # We need to retrieve the private key that matches the client specs (see issue #18) fhe_api.generate_private_and_evaluation_keys(force=False) predictions = fhe_api.deserialize_decrypt_dequantize(encrypted_prediction) st.session_state["decryption_done"] = True st.session_state["decrypted_prediction"] = predictions def init_session_state(): if "molecule_submitted" not in st.session_state: st.session_state["molecule_submitted"] = False if "input_molecule" not in st.session_state: st.session_state["input_molecule"] = "" if "key_generated" not in st.session_state: st.session_state["key_generated"] = False if "evaluation_key" not in st.session_state: st.session_state["evaluation_key"] = [] if "user_id" not in st.session_state: st.session_state["user_id"] = -100 if "encrypt" not in st.session_state: st.session_state["encrypt"] = False if "molecule_info_list" not in st.session_state: st.session_state["molecule_info_list"] = [] if "encrypt_tuple" not in st.session_state: st.session_state["encrypt_tuple"] = () if "fhe_prediction" not in st.session_state: st.session_state["fhe_prediction"] = "" if "fhe_done" not in st.session_state: st.session_state["fhe_done"] = False if "decryption_done" not in st.session_state: st.session_state["decryption_done"] = False if "decrypted_prediction" not in st.session_state: st.session_state[ "decrypted_prediction" ] = "" # actually a list of list. But python takes care as it is dynamically typed. def molecule_submitted(text: str = st.session_state.get("molecule_to_test", "")): msg_to_user = "" if len(text) == 0: msg_to_user = "Enter a non-empty molecule formula." molecule_present = False elif Chem.MolFromSmiles(text) == None: msg_to_user = "Invalid Molecule. Please enter a valid molecule. How about trying Aspirin or Ibuprofen?" molecule_present = False else: st.session_state["molecule_submitted"] = True st.session_state["input_molecule"] = text molecule_present = True msg_to_user = "Molecule Submitted for Prediction" st.session_state["molecule_info_list"].clear() st.session_state["molecule_info_list"].append(molecule_present) st.session_state["molecule_info_list"].append(msg_to_user) def keygen_util(): if st.session_state["molecule_submitted"] == False: pass else: if st.session_state["user_id"] == -100: (st.session_state["evaluation_key"], st.session_state["user_id"]) = keygen() st.session_state["key_generated"] = True def encrpyt_data_util(): if st.session_state["key_generated"] == False: pass else: if len(st.session_state["encrypt_tuple"]) == 0: st.session_state["encrypt_tuple"] = encode_quantize_encrypt( st.session_state["input_molecule"], st.session_state["user_id"] ) st.session_state["encrypt"] = True def mol_to_img(mol): mol = Chem.MolFromSmiles(mol) mol = AllChem.RemoveHs(mol) AllChem.Compute2DCoords(mol) drawer = rdMolDraw2D.MolDraw2DSVG(300, 300) drawer.DrawMolecule(mol) drawer.FinishDrawing() svg = drawer.GetDrawingText() return cairosvg.svg2png(bytestring=svg.encode("utf-8")) def FHE_util(): run_fhe(st.session_state["user_id"]) def decrypt_util(): decrypt_prediction(st.session_state["user_id"]) def clear_session_state(): st.session_state.clear() # Define global variables outside main function scope. task_options = ["0", "1", "2", "3", "4", "5"] task_mapping = { "0": "HLM", "1": "MDR-1-MDCK-ER", "2": "Solubility", "3": "Protein bind. human", "4": "Protein bind. rat", "5": "RLM", } task_mapping_2 = { "0": "LOG HLM_CLint (mL/min/kg)", "1": "LOG MDR1-MDCK ER (B-A/A-B)", "2": "LOG SOLUBILITY PH 6.8 (ug/mL)", "3": "LOG PLASMA PROTEIN BINDING (HUMAN) (% unbound)", "4": "LOG PLASMA PROTEIN BINDING (RAT) (% unbound)", "5": "LOG RLM_CLint (mL/min/kg)" } unit_mapping = { "0": "(mL/min/kg)", "1": " ", "2": "(ug/mL)", "3": " (%)", "4": " (%)", "5": "(mL/min/kg)", } task_options = list(task_mapping.values()) # Create the dropdown menu data_dict = { "HLM": "Human Liver Microsomes: drug is metabolized by the liver", "MDR-1-MDCK-ER": "MDR-1-MDCK-ER: drug is transported by the P-glycoprotein", "Solubility": "How soluble a drug is in water", "Protein bind. human": "Drug binding to human plasma proteins", "Protein bind. rat": "Drug binding to rat plasma proteins", "RLM": "Rat Liver Microsomes: Drug metabolism by a rat liver", } # Convert the dictionary to a DataFrame data = pd.DataFrame(list(data_dict.items()), columns=["Property", "Explanation"]) user_id = 0 css_styling = """""" if __name__ == "__main__": # Set up the Streamlit interface init_session_state() with st.container(): st.header("Start") st.markdown( "Run all the steps in order to predict a property for a molecule of your choice. Why not all steps at once? Because we want to show you the steps involved in the process (see figure above)." ) st.subheader(":red[Step 0: Which property do you want to predict?]") st.markdown( "This app can predict the following properties of confidential molecules:" ) # Check if 'task' is not already in session_state if "task" not in st.session_state: # Initialize it with the first value of your options st.session_state["task"] = "0" # Custom HTML and CSS styling html = data.to_html(index=False, classes="table table-striped table-hover") # Custom styling st.markdown(css_styling, unsafe_allow_html=True) # Display the HTML table st.write(html, unsafe_allow_html=True) st.markdown("Which one do you want to predict?") selected_label = st.selectbox( "Choose a property", task_options, index=task_options.index(task_mapping[st.session_state["task"]]), ) st.session_state["task"] = list(task_mapping.keys())[ task_options.index(selected_label) ] st.subheader(":red[Step 1: Submit a molecule]") x, y, z = st.columns(3) with x: st.text("") with y: submit_molecule = st.button( "Try Aspirin", on_click=molecule_submitted, args=("CC(=O)OC1=CC=CC=C1C(=O)O",), ) with z: submit_molecule = st.button( "Try Ibuprofen", on_click=molecule_submitted, args=("CC(Cc1ccc(cc1)C(C(=O)O)C)C",), ) # Use the custom keyup component for text input molecule_to_test = st_keyup( label="Enter a molecular SMILES string or choose one of the two options", value=st.session_state.get("molecule_to_test", ""), ) submit_molecule = st.button( "Submit", on_click=molecule_submitted, args=(molecule_to_test,), ) if len(st.session_state["molecule_info_list"]) != 0: if st.session_state["molecule_info_list"][0] == True: st.success(st.session_state["molecule_info_list"][1]) mol_image = mol_to_img(st.session_state["input_molecule"]) # center the image col1, col2, col3 = st.columns([1, 2, 1]) with col2: st.image(mol_image) st.caption(f"Input molecule {st.session_state['input_molecule']}") else: st.warning(st.session_state["molecule_info_list"][1], icon="⚠️") with st.container(): st.subheader( f":red[Step 2 : Generate encryption key (private to you) and an evaluation key (public).]" ) bullet_points = """ - Evaluation key is public and accessible by server. - Private Keys are solely accessible by client for encrypting the information before sending to the server. The same key is used for decryption after FHE inference. """ st.markdown(bullet_points, unsafe_allow_html=True) button_gen_key = st.button( "Generate Keys for this session", on_click=keygen_util ) if st.session_state["key_generated"] == True: st.success("Keys generated successfully", icon="πŸ™Œ") st.code(f'The user id for this session is {st.session_state["user_id"]} ') else: task = st.session_state["task"] task_label = task_mapping[task] st.warning( f"Please submit the molecule first to test its {task_label} value", icon="⚠️", ) with st.container(): st.subheader( f":red[Step 3 : Encrypt molecule using private key and send it to server.]" ) encrypt_button = st.button("Encrypt molecule", on_click=encrpyt_data_util) if st.session_state["encrypt"] == True: st.success("Successfully Encrypted Data", icon="πŸ™Œ") st.text("The server can only see the encrypted data:") st.code( f"The encrypted quantized encoding is \n {st.session_state['encrypt_tuple'][2]}..." ) else: st.warning( "Keys Not Yet Generated. Encryption can be done only after you generate keys." ) with st.container(): st.subheader(f":blue[Step 4 : Run encrypted prediction on server side.]") fhe_button = st.button("Predict in FHE domain", on_click=FHE_util) if st.session_state["fhe_done"]: st.success("Prediction Done Successfuly in FHE domain", icon="πŸ™Œ") st.code( f"The encrypted prediction is {st.session_state['fhe_prediction']}..." ) else: st.warning("Check if you have generated keys correctly.") with st.container(): st.subheader(f":red[Step 5 : Decrypt the predictions with your private key.]") decrypt_button = st.button( "Perform Decryption on FHE inferred prediction", on_click=decrypt_util ) if st.session_state["decryption_done"]: st.success("Decryption Done successfully!", icon="πŸ™Œ") value = st.session_state["decrypted_prediction"][0][0] # 2 digit precision value = round(value, 2) unit = unit_mapping[st.session_state["task"]] task_label = task_mapping[st.session_state["task"]] st.code( f"The Molecule {st.session_state['input_molecule']} has a {task_label} value of {value} {unit}" ) st.toast("Session successfully completed!!!") st.markdown("Is this a large, average or small value for this property? πŸ€” Find out by comparing with the property distribution of the training dataset") # now load the data from the pkl with open("all_data.pkl", "rb") as f: all_data = pickle.load(f) import plotly.graph_objects as go # Assuming task_mapping_2, all_data, and 'value' are defined elsewhere in your code. task_label_2 = task_mapping_2[st.session_state["task"]] data = all_data[task_label_2] # Create a histogram fig = go.Figure( go.Histogram( x=data, nbinsx=20, marker_color="blue", opacity=0.5, name="ADME dataset", ) ) # If you don't have specific y-values for the vertical line, you can set them to ensure the line spans the plot. # Here, we're assuming a static range. You might want to adjust these based on your dataset's characteristics. max_y_value = np.max(np.histogram(data, bins=20)[0]) # Calculate the max height of the histogram bars fig.add_trace(go.Scatter(x=[value, value], y=[0, max_y_value * 1.1], mode="lines", name="Prediction", line=dict(color="red", dash="dash"))) # Update layout if necessary fig.update_layout( title="Comparison of the molecule's value with the distribution of the ADME dataset", xaxis_title=task_label_2, yaxis_title="Count", bargap=0.2, ) # Display the figure in the Streamlit app st.plotly_chart(fig) else: st.warning("Check if FHE computation has been done.") with st.container(): st.subheader(f"Step 6 : Reset to predict a new molecule") reset_button = st.button("Reset app", on_click=clear_session_state) x, y, z = st.columns(3) with x: st.write("") with y: st.markdown( "

" + img_to_html("VaultChem.png", width="50%") + "

", unsafe_allow_html=True, ) st.markdown( "
Visit our website : VaultChem
", unsafe_allow_html=True, ) st.markdown( "
Visit our Github Repo : Github
", unsafe_allow_html=True, ) st.markdown( "
Built with Streamlit🎈
", unsafe_allow_html=True, ) with z: st.write("") st.markdown( """
The app was built with Concrete-ML, an open-source library by Zama.
""", unsafe_allow_html=True, ) st.write( ":red[Please Note]: The content of your app is purely for educational and illustrative purposes and is not intended for the management of sensitive information. We disclaim any liability for potential financial or other damages. This platform is not a substitute for professional health advice, diagnosis, or treatment. Health-related inquiries should be directed to qualified medical professionals. Use of this app implies acknowledgment of these terms and understanding of its intended educational use." ) st.write( ":red[Copyright notice]: Server_rack icon by DBCLS https://togotv.dbcls.jp/en/pics.html is licensed CC-BY 4.0 Unported, rat-adult icon by Servier https://smart.servier.com/ is licensed under CC-BY 3.0 Unported https://creativecommons.org/licenses/by/3.0/" )