Commit
•
9c9e6d9
1
Parent(s):
c445d96
readme: add onnx mean pool function (#82)
Browse files
- Add onnx mean pool function (d6fc33f53863b99071ffc22ead5273301daf3d59)
- compress normalization func (5316bd959084d8c6672afd0bfde314034ce84ed6)
- modify mean pooling annotations in onnx model (f96969c4e88351ae57c82a2a0c321429a277c783)
Co-authored-by: knysfh <knysfh@users.noreply.huggingface.co>
README.md
CHANGED
@@ -25206,6 +25206,15 @@ import onnxruntime
|
|
25206 |
import numpy as np
|
25207 |
from transformers import AutoTokenizer, PretrainedConfig
|
25208 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25209 |
# Load tokenizer and model config
|
25210 |
tokenizer = AutoTokenizer.from_pretrained('jinaai/jina-embeddings-v3')
|
25211 |
config = PretrainedConfig.from_pretrained('jinaai/jina-embeddings-v3')
|
@@ -25229,7 +25238,9 @@ inputs = {
|
|
25229 |
# Run model
|
25230 |
outputs = session.run(None, inputs)[0]
|
25231 |
|
25232 |
-
# Apply mean pooling
|
|
|
|
|
25233 |
```
|
25234 |
|
25235 |
</p>
|
|
|
25206 |
import numpy as np
|
25207 |
from transformers import AutoTokenizer, PretrainedConfig
|
25208 |
|
25209 |
+
# Mean pool function
|
25210 |
+
def mean_pooling(model_output: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Average token embeddings over the sequence axis, skipping masked tokens.

    Args:
        model_output: Token embeddings indexed as (batch, sequence, hidden).
        attention_mask: Mask indexed as (batch, sequence); tokens with a
            zero entry contribute nothing to the average.

    Returns:
        Array indexed as (batch, hidden) holding the masked mean of each row.
    """
    token_embeddings = np.asarray(model_output)
    # A trailing unit axis is enough: NumPy broadcasts it across the hidden
    # axis, so the mask never has to be expanded to the full embedding shape.
    mask = np.expand_dims(np.asarray(attention_mask), axis=-1)
    sum_embeddings = np.sum(token_embeddings * mask, axis=1)
    # Floor the token count at 1e-9 so a fully masked row yields zeros
    # instead of dividing by zero.
    token_counts = np.clip(np.sum(mask, axis=1), a_min=1e-9, a_max=None)
    return sum_embeddings / token_counts
|
25217 |
+
|
25218 |
# Load tokenizer and model config
|
25219 |
tokenizer = AutoTokenizer.from_pretrained('jinaai/jina-embeddings-v3')
|
25220 |
config = PretrainedConfig.from_pretrained('jinaai/jina-embeddings-v3')
|
|
|
25238 |
# Run model
|
25239 |
outputs = session.run(None, inputs)[0]
|
25240 |
|
25241 |
+
# Apply mean pooling and normalization to the model outputs
|
25242 |
+
embeddings = mean_pooling(outputs, input_text["attention_mask"])
|
25243 |
+
embeddings = embeddings / np.linalg.norm(embeddings, ord=2, axis=1, keepdims=True)
|
25244 |
```
|
25245 |
|
25246 |
</p>
|