---
tags:
- generated_from_trainer
- chemistry
- medical
- drug_drug_interaction
metrics:
- f2-score
- recall
- precision
- mcc
model-index:
- name: Bio_ClinicalBERT_DDI_finetuned
  results:
  - task:
      name: Drug - Drug Interaction Classification
      type: text-classification
    dataset:
      name: DrugBank
      type: REST API
    metrics:
    - name: Recall
      type: recall
      value: 0.7849
widget:
- text: '[Ca++].[O-]C([O-])=O [SEP] OC[C@H](O)[C@@H](O)[C@H](O)[C@H](O)CO'
  example_title: Drug1 [SEP] Drug2
pipeline_tag: text-classification
---
# Bio_ClinicalBERT_DDI_finetuned
This model was initialized from Bio_ClinicalBERT, with three hidden layers added after the BERT pooler layer. It was trained on a drug-drug interaction (DDI) dataset extracted from the DrugBank database and the National Library of Medicine API. It achieves the following results on the test set:
- F2: 0.7872
- AUPRC: 0.869
- Recall: 0.7849
- Precision: 0.7967
- MCC: 0.3779
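
F2 is the F-beta score with beta = 2, which weights recall more heavily than precision. The minimal sketch below applies the standard formula F_beta = (1 + beta^2) * P * R / (beta^2 * P + R) and shows that the reported test-set F2 agrees, to rounding, with the reported precision and recall. The function name `f_beta` is illustrative and is not part of this model's code.

```python
def f_beta(precision: float, recall: float, beta: float = 2.0) -> float:
    """Standard F-beta score; beta > 1 weights recall more than precision."""
    b2 = beta ** 2
    denom = b2 * precision + recall
    if denom == 0:
        return 0.0
    return (1 + b2) * precision * recall / denom


# Test-set precision and recall as reported on this card.
f2 = f_beta(precision=0.7967, recall=0.7849)
print(round(f2, 4))  # 0.7872
```
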
## Model description
The model predicts drug-drug interaction (DDI) from the chemical structures of two drugs, given as SMILES strings. It returns the probability that the two drugs interact with each other.
## Intended uses & limitations
To construct the input, use the `[SEP]` token to separate the two drugs. A properly constructed input looks like this:
```python
drug1 = "[Ca++].[O-]C([O-])=O"  # Calcium carbonate
drug2 = "OC[C@H](O)[C@@H](O)[C@H](O)[C@H](O)CO"  # Sorbitol
correct_input = "[Ca++].[O-]C([O-])=O [SEP] OC[C@H](O)[C@@H](O)[C@H](O)[C@H](O)CO"
```
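
When scoring many drug pairs, the same `[SEP]` convention can be wrapped in a small helper. The sketch below is hypothetical: `build_ddi_input` and its empty-string check are illustrative and not part of the model's published code. It only joins two SMILES strings with `" [SEP] "`, which reproduces the example input above.

```python
SEP = " [SEP] "  # separator between the two drugs, as in the example input


def build_ddi_input(drug1_smiles: str, drug2_smiles: str) -> str:
    """Hypothetical helper: join two SMILES strings with the [SEP] token."""
    d1, d2 = drug1_smiles.strip(), drug2_smiles.strip()
    if not d1 or not d2:
        raise ValueError("both SMILES strings must be non-empty")
    return f"{d1}{SEP}{d2}"


pairs = [
    # (Calcium carbonate, Sorbitol)
    ("[Ca++].[O-]C([O-])=O", "OC[C@H](O)[C@@H](O)[C@H](O)[C@H](O)CO"),
]
inputs = [build_ddi_input(a, b) for a, b in pairs]
print(inputs[0])
```
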
## Training and evaluation data
To avoid data leakage and to let the model predict DDI for new drugs, the drugs (drug1 or drug2) in the validation and test sets were excluded from the training set, so their SMILES chemical structures were never exposed during training.
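
The card does not say how this split was implemented. A drug-level split is one way to get the property described above: drugs in the evaluation sets never appear in training. The sketch below is an assumption, not the authors' procedure. `drug_level_split` and `holdout_fraction` are hypothetical names, and the letters are placeholder drug IDs rather than real data. Any pair that contains a held-out drug goes to evaluation, and those pairs could then be divided further into validation and test sets.

```python
import random


def drug_level_split(pairs, holdout_fraction=0.2, seed=7):
    """Hypothetical drug-level split of (drug1, drug2, label) records.

    Pairs containing any held-out drug go to the evaluation set; the
    remaining pairs contain only training drugs and form the training set.
    """
    drugs = sorted({d for d1, d2, _ in pairs for d in (d1, d2)})  # deterministic order
    rng = random.Random(seed)
    n_holdout = max(1, int(len(drugs) * holdout_fraction))
    held_out = set(rng.sample(drugs, n_holdout))
    train = [p for p in pairs if p[0] not in held_out and p[1] not in held_out]
    evaluation = [p for p in pairs if p[0] in held_out or p[1] in held_out]
    return train, evaluation, held_out


# Placeholder drug IDs and labels, for illustration only.
toy_pairs = [("A", "B", 1), ("A", "C", 0), ("B", "D", 1),
             ("C", "E", 0), ("D", "F", 1), ("E", "F", 0)]
train, evaluation, held_out = drug_level_split(toy_pairs, holdout_fraction=0.34)
print(sorted(held_out), len(train), len(evaluation))
```

Note that a split like this discards pairs that mix a held-out drug with a training drug from the training set, so the training set shrinks as `holdout_fraction` grows.
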
## Training procedure
Training was run on an AWS EC2 g5.4xlarge instance with a 24 GB GPU.
### Training hyperparameters
The following hyperparameters were used during training:
- learning_rate: 0.01
- train_batch_size: 32
- eval_batch_size: 32
- seed: 7
- optimizer: Adadelta with weight_decay=1e-04
- lr_scheduler_type: CosineAnnealingLR
- num_epochs: 4
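
In PyTorch these settings would correspond to `torch.optim.Adadelta(params, lr=0.01, weight_decay=1e-4)` combined with `torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=...)`. The card does not give `T_max` or `eta_min`, nor say whether the scheduler steps per batch or per epoch. The stdlib sketch below evaluates the closed-form cosine annealing formula, eta_t = eta_min + (eta_max - eta_min) * (1 + cos(pi * t / T_max)) / 2. It assumes `eta_min = 0` and one step per epoch over the 4 epochs, and both assumptions are labeled in the code.

```python
import math


def cosine_annealing_lr(step: int, t_max: int, base_lr: float = 0.01, eta_min: float = 0.0) -> float:
    """Closed-form cosine annealing from base_lr down to eta_min over t_max steps."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max))


# Assumptions (not stated on the card): eta_min = 0 and one scheduler step per epoch.
T_MAX = 4
for epoch in range(T_MAX + 1):
    print(epoch, round(cosine_annealing_lr(epoch, T_MAX), 6))
```
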
### Training results
Training Loss | Epoch | Validation Loss | F2 | Recall | Precision | MCC |
---|---|---|---|---|---|---|
0.6068 | 1.0 | 0.7061 | 0.6508 | 0.6444 | 0.6778 | 0.2514 |
0.4529 | 2.0 | 0.8334 | 0.7555 | 0.7727 | 0.6939 | 0.3451 |
0.3375 | 3.0 | 0.9582 | 0.7636 | 0.7840 | 0.6915 | 0.3474 |
0.2624 | 4.0 | 1.2588 | 0.7770 | 0.8004 | 0.6954 | 0.3654 |
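
In the table, training loss falls every epoch and validation loss rises from 0.7061 to 1.2588, while validation F2, recall and MCC keep improving. The "best" checkpoint therefore depends on the selection criterion, and the card does not say which one was used. The snippet below uses only the numbers from the table. It shows that ranking by validation loss and ranking by validation F2 pick different epochs.

```python
# Per-epoch values copied from the training results table.
history = [
    {"epoch": 1.0, "train_loss": 0.6068, "val_loss": 0.7061, "f2": 0.6508, "mcc": 0.2514},
    {"epoch": 2.0, "train_loss": 0.4529, "val_loss": 0.8334, "f2": 0.7555, "mcc": 0.3451},
    {"epoch": 3.0, "train_loss": 0.3375, "val_loss": 0.9582, "f2": 0.7636, "mcc": 0.3474},
    {"epoch": 4.0, "train_loss": 0.2624, "val_loss": 1.2588, "f2": 0.7770, "mcc": 0.3654},
]

best_by_loss = min(history, key=lambda row: row["val_loss"])  # lowest validation loss
best_by_f2 = max(history, key=lambda row: row["f2"])          # highest validation F2
print("lowest validation loss at epoch", best_by_loss["epoch"])
print("highest validation F2 at epoch", best_by_f2["epoch"])
```
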
### Framework versions
- Transformers 4.30.2
- Pytorch 2.0.1
- Datasets 2.13.1
- Tokenizers 0.13.3
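
To check whether a local environment matches the versions listed above, the standard-library `importlib.metadata.version` can be queried per package. The sketch below assumes the PyPI distribution names `transformers`, `torch`, `datasets` and `tokenizers`. It ignores local build suffixes such as `+cu118`, and `version_mismatches` is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions listed on this card; PyPI distribution names are an assumption.
EXPECTED = {
    "transformers": "4.30.2",
    "torch": "2.0.1",
    "datasets": "2.13.1",
    "tokenizers": "0.13.3",
}


def version_mismatches(expected):
    """Return {package: (expected, installed_or_None)} for packages that differ."""
    mismatches = {}
    for package, wanted in expected.items():
        try:
            installed = version(package).split("+")[0]  # drop local suffix like "+cu118"
        except PackageNotFoundError:
            installed = None
        if installed != wanted:
            mismatches[package] = (wanted, installed)
    return mismatches


for package, (wanted, installed) in version_mismatches(EXPECTED).items():
    print(f"{package}: expected {wanted}, found {installed}")
```
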