Modified version of xlm-roberta-flash-implementation for ONNX conversion
Brief Summary of Challenges and Modifications:
Dynamic Matrix Calculation in RoPE
The original RoPE implementation did not compute the entire rotation matrix at the start. Instead, it calculated the matrix only for the required sequence length, cached it, and recalculated if a longer sequence came as input. This approach isn't compatible with ONNX, which requires a fixed graph during inference. To solve this, I now calculate the entire rotation matrix in advance.
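The difference between the two strategies can be sketched without any framework. This is a minimal illustration, not the model's actual code: the names `LazyRotaryCache`, `PrecomputedRotary`, `build_rotary_table`, `MAX_SEQ_LEN`, and the base of 10000 are assumptions chosen for the example.

```python
import math

BASE = 10000.0    # assumed RoPE base, not taken from the model config
MAX_SEQ_LEN = 16  # hypothetical maximum length, kept tiny for the example


def build_rotary_table(seq_len, dim, base=BASE):
    """Return (cos, sin) tables, each shaped [seq_len][dim // 2]."""
    inv_freq = [base ** (-2.0 * i / dim) for i in range(dim // 2)]
    cos = [[math.cos(pos * f) for f in inv_freq] for pos in range(seq_len)]
    sin = [[math.sin(pos * f) for f in inv_freq] for pos in range(seq_len)]
    return cos, sin


class LazyRotaryCache:
    """Original strategy: rebuild the cache when a longer sequence arrives.

    The `if` in `get` depends on the input length, which is the kind of
    data-dependent branch a traced, static graph cannot represent.
    """

    def __init__(self, dim):
        self.dim = dim
        self.cached_len = 0
        self.cos, self.sin = [], []

    def get(self, seq_len):
        if seq_len > self.cached_len:
            self.cos, self.sin = build_rotary_table(seq_len, self.dim)
            self.cached_len = seq_len
        return self.cos[:seq_len], self.sin[:seq_len]


class PrecomputedRotary:
    """Export-friendly strategy: build the full table once, then only slice."""

    def __init__(self, dim, max_seq_len=MAX_SEQ_LEN):
        self.cos, self.sin = build_rotary_table(max_seq_len, dim)

    def get(self, seq_len):
        return self.cos[:seq_len], self.sin[:seq_len]
```

Both classes return the same values for any length up to `MAX_SEQ_LEN`; the precomputed variant trades some memory for a `get` with no branching.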
Custom Backward Functions for RoPE
We have custom forward and backward functions for RoPE. ONNX does not support custom backward functions, but since we only need forward passes for inference with ONNX, I removed the backward function completely.
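With only the forward computation left, the rotation can be written as an ordinary function with no gradient logic attached. The sketch below assumes a half-split pairing (element i is rotated together with element i + dim/2); that layout and the name `apply_rotary_forward` are assumptions for illustration, and the actual implementation may pair elements differently.

```python
import math


def apply_rotary_forward(x, cos, sin):
    """Rotate one vector `x` (length dim) by per-pair angles.

    `cos` and `sin` each have length dim // 2. Element i is paired with
    element i + dim // 2 (assumed layout). No backward pass is defined.
    """
    half = len(x) // 2
    x1, x2 = x[:half], x[half:]
    out1 = [a * c - b * s for a, b, c, s in zip(x1, x2, cos, sin)]
    out2 = [a * s + b * c for a, b, c, s in zip(x1, x2, cos, sin)]
    return out1 + out2


# A quarter turn maps the pair (1, 0) to approximately (0, 1).
quarter_turn = apply_rotary_forward(
    [1.0, 0.0], [math.cos(math.pi / 2)], [math.sin(math.pi / 2)]
)
```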
ONNX Model Size Limitation
ONNX stores the model in a protobuf format, which has a maximum size limit of 2GB. Our model was too large to fit this limit, so I had to store the model's parameters as external data files.
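A rough size estimate shows how a model crosses this limit. The parameter count below is a hypothetical round number rather than this model's real count, float32 storage is assumed, and the helper names are made up for the example.

```python
# Protobuf message sizes are held in a signed 32-bit integer, hence ~2 GiB.
PROTOBUF_LIMIT_BYTES = 2 ** 31 - 1


def weights_size_bytes(num_params, bytes_per_param=4):
    """Approximate size of the raw weights (float32 by default)."""
    return num_params * bytes_per_param


def exceeds_protobuf_limit(num_params, bytes_per_param=4):
    """True if the weights alone would not fit in a single protobuf file."""
    return weights_size_bytes(num_params, bytes_per_param) > PROTOBUF_LIMIT_BYTES


hypothetical_params = 600_000_000  # illustrative round number only
print(weights_size_bytes(hypothetical_params) / 2 ** 30)  # size in GiB
print(exceeds_protobuf_limit(hypothetical_params))        # float32 weights
print(exceeds_protobuf_limit(hypothetical_params, 2))     # float16 weights
```

When the estimate exceeds the limit, the weights have to live outside the `.onnx` file; in the `onnx` Python package this is typically requested with `save_as_external_data=True` in `onnx.save_model`.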
Lack of Support for the unique() Function
We used the unique() function to identify unique task types in a batch, which is important when there are multiple task types. However, ONNX does not support the unique() function. For inference, having multiple task types in a batch is not important. Therefore, I modified the code to use the task_id argument (an integer that applies to every text in a batch) instead of the adapter_mask, which was a tensor specifying an independent task ID for each text in the batch.
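The change can be pictured with a toy router. Everything here is hypothetical (the `ADAPTER_SCALE` table, both function names, and adapters modelled as a plain scale factor); it only contrasts per-row routing, which needs the set of unique IDs, with a single batch-wide `task_id`.

```python
# Toy stand-in for per-task adapters: each task just scales its input.
ADAPTER_SCALE = {0: 1.0, 1: 0.5, 2: 2.0}  # hypothetical task IDs and weights


def apply_with_adapter_mask(batch, adapter_mask):
    """Original style: one task ID per row, grouped via the unique IDs."""
    out = [None] * len(batch)
    for task in sorted(set(adapter_mask)):  # plays the role of unique()
        for i, row_task in enumerate(adapter_mask):
            if row_task == task:
                out[i] = [v * ADAPTER_SCALE[task] for v in batch[i]]
    return out


def apply_with_task_id(batch, task_id):
    """Export-friendly style: one integer selects the adapter for all rows."""
    scale = ADAPTER_SCALE[task_id]
    return [[v * scale for v in row] for row in batch]


batch = [[1.0, 2.0], [3.0, 4.0]]
# When every row shares a task, both paths agree.
same = apply_with_adapter_mask(batch, [2, 2]) == apply_with_task_id(batch, 2)
```

The second function has no grouping step at all, which is why it needs nothing like unique(); the cost is that one call can serve only one task type.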
Code
import torch
import torch.onnx
from transformers import AutoModel, AutoTokenizer

# Load the modified model; flash attention is disabled for the export.
model = AutoModel.from_pretrained('/home/admin/saba/jina-embeddings-v3', trust_remote_code=True, use_flash_attn=False)
model.eval()

onnx_path = "/home/admin/saba/jina-embeddings-v3/onnx/model.onnx"

# Build a small example batch to trace the model with.
tokenizer = AutoTokenizer.from_pretrained('/home/admin/saba/jina-embeddings-v3')
inputs = tokenizer(["jina", 'ai'], return_tensors="pt", padding='longest')
inps = inputs['input_ids']
mask = inputs['attention_mask']
task_id = 2  # a single integer task ID shared by the whole batch

torch.onnx.export(
    model,
    (inps, mask, task_id),
    onnx_path,
    export_params=True,
    do_constant_folding=True,
    input_names=['input_ids', 'attention_mask', 'task_id'],
    output_names=['text_embeds'],
    opset_version=16,
    # Batch size and sequence length may vary at inference time.
    dynamic_axes={
        'input_ids': {0: 'batch_size', 1: 'sequence_length'},
        'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
        'text_embeds': {0: 'batch_size'},
    },
)