---
license: cc-by-nc-sa-4.0
widget:
- text: ACCTGA<mask>TTCTGAGTC
datasets:
- InstaDeepAI/plant-genomic-benchmark
tags:
- biology
- genomics
- language model
- plants
---
## Model Overview
AgroNT is a DNA language model trained primarily on edible plant genomes. More specifically, AgroNT uses the transformer architecture with self-attention and a masked language modeling
objective to leverage the large amount of available genomic data from 48 different plant species to learn general representations of nucleotide sequences. AgroNT contains 1 billion parameters and has a context window of 1024 tokens.
AgroNT uses a non-overlapping 6-mer tokenizer to convert genomic nucleotide sequences to tokens. As a result, the 1024 tokens correspond to approximately 6144 base pairs.
## How to use
```python
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch
model_name = 'agro-nucleotide-transformer-1b'
# fetch model and tokenizer from InstaDeep's hf repo
agro_nt_model = AutoModelForMaskedLM.from_pretrained(f'InstaDeepAI/{model_name}')
agro_nt_tokenizer = AutoTokenizer.from_pretrained(f'InstaDeepAI/{model_name}')
print(f"Loaded the {model_name} model with {agro_nt_model.num_parameters()} parameters and corresponding tokenizer.")
# example sequence and tokenization
sequences = ['ATATACGGCCGNC', 'GGGTATCGCTTCCGAC']
batch_tokens = agro_nt_tokenizer(sequences, padding="longest")['input_ids']
print(f"Tokenized sequence: {agro_nt_tokenizer.batch_decode(batch_tokens)}")
torch_batch_tokens = torch.tensor(batch_tokens)
attention_mask = torch_batch_tokens != agro_nt_tokenizer.pad_token_id
# inference
with torch.no_grad():
    outs = agro_nt_model(
        torch_batch_tokens,
        attention_mask=attention_mask,
        encoder_attention_mask=attention_mask,
        output_hidden_states=True
    )
# get the final layer embeddings and language model head logits
embeddings = outs['hidden_states'][-1].detach().numpy()
logits = outs['logits'].detach().numpy()
```
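The final-layer `embeddings` contain one vector per token. A common way to obtain a single vector per sequence is to average over the non-padding positions only. The helper below is a minimal sketch of that idea; `mean_pool` is a hypothetical name of our own, not part of `transformers`, and it includes the `<CLS>` token in the average.

```python
import numpy as np

def mean_pool(embeddings, attention_mask):
    """Average token embeddings over non-padding positions.

    embeddings:     (batch, seq_len, hidden) float array
    attention_mask: (batch, seq_len) boolean array, True for real tokens
    """
    mask = attention_mask[..., None].astype(embeddings.dtype)  # (batch, seq_len, 1)
    summed = (embeddings * mask).sum(axis=1)                   # (batch, hidden)
    counts = np.clip(mask.sum(axis=1), 1, None)                # (batch, 1), avoids divide-by-zero
    return summed / counts
```

With the variables from the snippet above, usage would look like `sequence_embeddings = mean_pool(embeddings, attention_mask.numpy())`, giving one `hidden`-sized vector per input sequence.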
## Pre-training
#### Data
Our pre-training dataset was built from (mostly) edible plant reference genomes contained in the Ensembl Plants database.
The dataset consists of approximately 10.5 million genomic sequences across 48 different species.
#### Processing
All reference genomes for each species were assembled into a single FASTA file. In this FASTA file, all nucleotides other than A, T, C, and G were replaced by N. A tokenizer was then used to convert the strings of letters into sequences of tokens.
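The N-replacement step described above can be sketched in a few lines of plain Python. This is an illustration, not the authors' pipeline: the function names are hypothetical, and upper-casing the sequence first (so that soft-masked, lower-case bases are kept rather than turned into N) is our assumption, since the card does not say how lower-case bases were handled.

```python
import re

NON_ACGT = re.compile(r"[^ACGT]")

def parse_fasta(text):
    """Yield (header, sequence) pairs from FASTA-formatted text."""
    header, chunks = None, []
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        if line.startswith(">"):
            if header is not None:
                yield header, "".join(chunks)
            header, chunks = line[1:], []
        else:
            chunks.append(line)
    if header is not None:
        yield header, "".join(chunks)

def sanitize(seq):
    """Replace every character other than A, T, C, G with N.

    Assumption: upper-case first so soft-masked bases are preserved.
    """
    return NON_ACGT.sub("N", seq.upper())
```

For example, `sanitize("ACGTRYacgt")` returns `"ACGTNNACGT"`: the IUPAC ambiguity codes R and Y become N.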
The tokenizer's alphabet consisted of the 4<sup>6</sup> = 4096 possible 6-mers over A, T, C, and G, as well as five additional tokens
representing standalone A, T, C, G, and N. It also included three special tokens: the pad [PAD], mask [MASK], and class [CLS] tokens. This resulted in a vocabulary of 4096 + 5 + 3 = 4104 tokens. To tokenize an input sequence, the tokenizer started with a class token and
then converted the sequence from left to right, matching 6-mer tokens when possible and falling back to standalone tokens when necessary (for instance, when the letter
N was present or when the sequence length was not a multiple of 6).
**Tokenization example**
nucleotide sequence: ```ATCCCGGNNTCGACACN```\
tokens: ```<CLS> <ATCCCG> <G> <N> <N> <TCGACA> <C> <N>```
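The vocabulary and the greedy left-to-right procedure described above can be sketched as follows. This is a minimal re-implementation for illustration, not the Hugging Face tokenizer shipped with the model: the token-to-id ordering in `VOCAB` is a hypothetical choice, and the input is assumed to be already sanitized to A, T, C, G, and N.

```python
from itertools import product

SPECIAL_TOKENS = ["<PAD>", "<MASK>", "<CLS>"]
SINGLE_TOKENS = ["A", "T", "C", "G", "N"]
KMER_TOKENS = ["".join(p) for p in product("ATCG", repeat=6)]  # 4**6 = 4096 6-mers
K = 6

# Hypothetical id assignment; the real tokenizer's ordering may differ.
VOCAB = {tok: i for i, tok in enumerate(SPECIAL_TOKENS + KMER_TOKENS + SINGLE_TOKENS)}
KMER_SET = set(KMER_TOKENS)

def tokenize(seq):
    """Greedy left-to-right tokenization into non-overlapping 6-mers."""
    tokens = ["<CLS>"]
    i = 0
    while i < len(seq):
        chunk = seq[i:i + K]
        if chunk in KMER_SET:    # a full 6-mer made only of A/T/C/G
            tokens.append(chunk)
            i += K
        else:                    # N inside the window, or fewer than 6 letters left
            tokens.append(seq[i])
            i += 1
    return tokens

def encode(seq):
    """Map a sanitized sequence to (hypothetical) token ids."""
    return [VOCAB[t] for t in tokenize(seq)]
```

On the example above, `tokenize("ATCCCGGNNTCGACACN")` returns `['<CLS>', 'ATCCCG', 'G', 'N', 'N', 'TCGACA', 'C', 'N']`, matching the documented tokens.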
#### Training
The MLM objective was used to pre-train AgroNT in a self-supervised manner. In a self-supervised learning setting, annotations (supervision) are not needed for each sequence,
as we can mask some proportion of the sequence and use the information contained in the unmasked portion to predict the masked positions.
This allows us to leverage the vast amount of unlabeled genomic sequencing data available. Specifically, 15% of the tokens in the input sequence are selected for
augmentation: 80% of these are replaced with a mask token, 10% are randomly replaced by another token from the vocabulary, and the remaining 10% are left unchanged.
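The 15% selection and 80/10/10 replacement scheme described above can be sketched as below. This is a hypothetical helper (`mlm_mask` is our own name), not the authors' training code. It selects each non-special token independently with probability 0.15, so roughly 15% are selected in expectation, and it marks unselected positions with the label -100, the default `ignore_index` of PyTorch's `CrossEntropyLoss`, so that the loss is computed only on selected positions.

```python
import random

def mlm_mask(token_ids, vocab_size, mask_id, special_ids, mask_prob=0.15, seed=None):
    """Return (inputs, labels) for masked language modeling.

    labels[i] is the original token at selected positions and -100 elsewhere.
    """
    rng = random.Random(seed)
    inputs = list(token_ids)
    labels = [-100] * len(token_ids)
    for i, tok in enumerate(token_ids):
        # never select special tokens such as <CLS> or <PAD>
        if tok in special_ids or rng.random() >= mask_prob:
            continue
        labels[i] = tok                              # predict the original token here
        r = rng.random()
        if r < 0.8:
            inputs[i] = mask_id                      # 80%: replace with the mask token
        elif r < 0.9:
            inputs[i] = rng.randrange(vocab_size)    # 10%: random token from the vocabulary
        # remaining 10%: keep the original token
    return inputs, labels
```

Drawing the replacement from the full vocabulary (special tokens included) mirrors common BERT-style implementations; the exact sampling used for AgroNT is not specified here.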
The tokenized sequence is passed through the model, and a cross-entropy loss is computed for the selected tokens. Pre-training was carried out with a sequence length of 1024 tokens
and an effective batch size of 1.5M tokens for 315k update steps, so the model was trained on a total of 472.5B tokens.
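The total token count follows directly from the batch size and the number of update steps:

$$
1.5 \times 10^{6}\ \tfrac{\text{tokens}}{\text{step}} \times 3.15 \times 10^{5}\ \text{steps} = 4.725 \times 10^{11}\ \text{tokens} = 472.5\text{B tokens}
$$

Dividing the batch size by the sequence length also suggests roughly 1.5M / 1024 ≈ 1,465 sequences per update step, assuming the 1.5M figure is not rounded.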
#### Hardware
Model pre-training was carried out on Google TPU v4 accelerators, specifically a TPU v4-1024 slice containing 512 devices. Training took approximately four days in total.
## BibTeX entry and citation info
```bibtex
@article{mendoza2023foundational,
title={A Foundational Large Language Model for Edible Plant Genomes},
author={Mendoza-Revilla, Javier and Trop, Evan and Gonzalez, Liam and Roller, Masa and Dalla-Torre, Hugo and de Almeida, Bernardo P and Richard, Guillaume and Caton, Jonathan and Lopez Carranza, Nicolas and Skwark, Marcin and others},
journal={bioRxiv},
pages={2023--10},
year={2023},
publisher={Cold Spring Harbor Laboratory}
}
```