The textual ONNX models seem to have issues with more than 16 tokens on input

#22
by ondrejnespor - opened

It is my understanding that the textual ONNX model expects token IDs as input:

import onnxruntime as ort

ort_sess = ort.InferenceSession('text_model.onnx')
for i in ort_sess.get_inputs():
    print(i)

says

NodeArg(name='input_ids', type='tensor(int64)', shape=['batch_size', 'sequence_length'])

so the expected use would be something like:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('jinaai/jina-clip-v1')

# return_tensors='np' so ONNX Runtime gets NumPy arrays, cast to the declared int64
inputs = tokenizer(['hello world'], return_tensors='np')
output = ort_sess.run(None, {'input_ids': inputs['input_ids'].astype('int64')})

The issue is that the model stops working with inputs of 17 tokens or longer.

For 16 tokens, it returns an embedding:

inputs = tokenizer(['hello world hello world hello world hello world hello world hello world hello world'], return_tensors='np')
print(len(inputs['input_ids'][0]))  # 16
output = ort_sess.run(None, {'input_ids': inputs['input_ids'].astype('int64')})

with 17 or more, it throws an error:

inputs = tokenizer(['hello world hello world hello world hello world hello world hello world hello world hello'], return_tensors='np')
print(len(inputs['input_ids'][0]))  # 17
output = ort_sess.run(None, {'input_ids': inputs['input_ids'].astype('int64')})

outputs

RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Sub node. Name:'/transformer/encoder/layers.0/mixer/inner_attn/Sub' Status Message: D:\a\_work\1\s\onnxruntime\core/providers/cpu/math/element_wise_ops.h:540 onnxruntime::BroadcastIterator::Init axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 16 by 17
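
The shapes in the message tell the story: a tensor whose sequence dimension is frozen at 16 is being subtracted from attention scores whose sequence dimension follows the input. A NumPy analogue of what the failing Sub node attempts (the head count here is made up for illustration):

import numpy as np

scores = np.zeros((1, 8, 17, 17))  # attention scores for a 17-token input
biases = np.zeros((1, 8, 16, 16))  # bias tensor frozen at 16 positions
scores - biases                    # ValueError: operands could not be broadcast together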

The same error occurs with all three exports: text_model, text_model_quantized, and text_model_int8.

When I try to export an ONNX model myself:

import torch
from transformers import AutoModel

model = AutoModel.from_pretrained('jinaai/jina-clip-v1', trust_remote_code=True)
tm = model.text_model
tm.eval()

torch.onnx.export(
    tm,
    torch.randint(1, 1024, (1, 1024)),  # dummy input: one sequence of 1024 random token IDs
    "./manual.onnx",
    export_params=True,
    do_constant_folding=True,
    input_names=['input_ids'],
    output_names=['text_embeds'],
    dynamic_axes={
        'input_ids': {0: 'batch_size', 1: 'sequence_length'},
        'text_embeds': {0: 'batch_size'}
    }
)

I should end up with a model very similar to text_model.onnx from this repo. And indeed it seems to return the same embeddings, but unlike the repo's export it also supports inputs of 17 tokens and more.
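
A quick way to sanity-check both claims (a sketch; file names as above, tolerance picked arbitrarily):

import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('jinaai/jina-clip-v1')
repo_sess = ort.InferenceSession('text_model.onnx')  # export from this repo
mine_sess = ort.InferenceSession('manual.onnx')      # manual export from above

# short input (<= 16 tokens): both models accept it, embeddings should match
ids = tokenizer(['hello world'], return_tensors='np')['input_ids'].astype('int64')
repo_out = repo_sess.run(None, {'input_ids': ids})[0]
mine_out = mine_sess.run(None, {'input_ids': ids})[0]
print(np.allclose(repo_out, mine_out, atol=1e-4))

# long input (> 16 tokens): only the manual export gets past the Sub node
long_ids = tokenizer(['hello world ' * 10], return_tensors='np')['input_ids'].astype('int64')
mine_sess.run(None, {'input_ids': long_ids})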

Am I missing something or do the textual ONNX exports have an issue?

Yes, I've encountered the same issue. I found that the cause is related to the need to update variables at runtime in the SelfAttention code, specifically in the mha file of the jina-bert-flash-implementation:

if self.alibi_slopes is not None:
    if seqlen > self.linear_biases.shape[-1]:
        self.linear_biases = self._build_linear_biases(seqlen)
    cropped_biases = self.linear_biases[..., :seqlen, :seqlen]
    scores = scores - cropped_biases
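
That would explain the error: torch.onnx.export traces the model once, so this Python-level branch and the cached self.linear_biases are evaluated with the dummy input's length and then frozen into the graph. If the repo's export was traced with a 16-token dummy (an assumption, but it matches the "16 by 17" in the traceback), the cropped biases get baked in at 16x16, while the manual export above, traced with a 1024-token dummy, covers every length tested. A minimal self-contained sketch of the same failure mode (hypothetical module, not the actual jina-bert code):

import torch

class CachedBias(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bias = torch.zeros(16, 16)  # cache sized for 16 positions

    def forward(self, x):
        n = x.shape[-1]
        if n > self.bias.shape[-1]:        # plain Python branch: evaluated once
            self.bias = torch.zeros(n, n)  # at trace time, absent from the graph
        return x - self.bias[:n, :n]

torch.onnx.export(
    CachedBias(),
    torch.zeros(1, 16, 16),  # 16-position dummy, as the repo export presumably used
    'repro.onnx',
    input_names=['x'],
    dynamic_axes={'x': {1: 'seqlen', 2: 'seqlen'}},
)
# feeding repro.onnx a (1, 17, 17) input fails in its Sub node with the same
# kind of broadcast error as in the traceback above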
Jina AI org

Thanks, I'll take a look today!

Jina AI org

@Riddler2024 @ondrejnespor the issue with the text_model.onnx should be fixed now!

bwang0911 changed discussion status to closed
