---
language:
- zh
base_model: junnyu/roformer_chinese_base
tags:
- sentence-similarity
- cmteb
- sentence-transformers
- transformers
---

## <u>INF</u> <u>W</u>ord-level <u>S</u>parse <u>E</u>mbedding (INF-WSE)

**INF-WSE** is a series of word-level sparse embedding models developed by [INF TECH](https://www.infly.cn/en). These models are optimized to generate sparse, high-dimensional text embeddings that excel at capturing the most relevant information for search and retrieval, particularly in Chinese text.

### Key Features:

- **Optimized for Retrieval**: INF-WSE is designed with retrieval tasks in mind. Its sparse embeddings enable efficient matching between queries and documents, making it highly effective for semantic search, ranking, and information retrieval scenarios where speed and accuracy are critical.
- **Word-level Sparse Embeddings**: The model generates sparse representations at the word level, capturing the essential semantic details that improve the relevance of search results. This is particularly useful for Chinese retrieval tasks, where word segmentation can significantly impact performance.
- **Sparse Representation for Efficiency**: Unlike dense embeddings, which pack information into a fixed number of dense dimensions, INF-WSE produces sparse embeddings whose dimensionality matches the vocabulary size. Most dimensions are zero, and only the most significant terms carry weight. This sparsity reduces the computational load, enabling faster retrieval without compromising precision, as the scoring sketch below illustrates.
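
The sketch below is purely illustrative (it is not part of the model's code, and the weights are made up): each text becomes a mapping from tokens to weights, and the query-document score is a dot product that only touches the tokens both sides share, which is exactly what an inverted index computes efficiently.

```python
# Illustrative only: hand-written (token -> weight) vectors, not real model output.
query_vec = {"掌上": 3.46, "电脑": 2.63, "是": 2.08, "什么": 1.29}
doc_vec = {"掌上": 2.10, "电脑": 1.95, "手持": 1.20, "设备": 0.80}

# Score = dot product restricted to the shared tokens; every other dimension is zero.
score = sum(weight * doc_vec[token] for token, weight in query_vec.items() if token in doc_vec)
print(round(score, 2))  # 3.46*2.10 + 2.63*1.95 ≈ 12.39
```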

## Usage

### Transformers

#### Infer embeddings

```python
import torch
from transformers import AutoTokenizer, AutoModel

queries = ['电脑一体机由什么构成?', '什么是掌上电脑?']
documents = [
    '电脑一体机,是由一台显示器、一个电脑键盘和一个鼠标组成的电脑。',
    '掌上电脑是一种运行在嵌入式操作系统和内嵌式应用软件之上的、小巧、轻便、易带、实用、价廉的手持式计算设备。',
]
input_texts = queries + documents

tokenizer = AutoTokenizer.from_pretrained("infly/inf-wse-v1-base-zh", trust_remote_code=True, use_fast=False)  # the fast tokenizer is not supported yet
model = AutoModel.from_pretrained("infly/inf-wse-v1-base-zh", trust_remote_code=True)
model.eval()

max_length = 512

input_batch = tokenizer(input_texts, padding=True, max_length=max_length, truncation=True, return_tensors="pt")
with torch.no_grad():
    # return_sparse=False returns a dense tensor; return_sparse=True returns a sparse tensor instead
    embeddings = model(input_batch['input_ids'], input_batch['attention_mask'], return_sparse=False)

scores = embeddings[:2] @ embeddings[2:].T
print(scores.tolist())
# [[21.224790573120117, 4.520412921905518], [10.290857315063477, 19.359437942504883]]
```
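
Setting `return_sparse=True` returns the same embeddings in sparse form. The snippet below is a small sketch that assumes the returned object is a standard PyTorch sparse tensor; under that assumption, `to_dense()` recovers the dense matrix used above.

```python
# Sketch only: assumes return_sparse=True yields a torch sparse tensor.
with torch.no_grad():
    sparse_embeddings = model(input_batch['input_ids'], input_batch['attention_mask'], return_sparse=True)

dense_view = sparse_embeddings.to_dense()  # should match the dense output above
print((dense_view[:2] @ dense_view[2:].T).tolist())
```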

#### Convert embeddings to lexical weights

```python
# Continues from the snippet above: reuses `torch`, `embeddings`, and `tokenizer`.
from collections import OrderedDict


def convert_embeddings_to_weights(embeddings, tokenizer):
    # Sort each row so the highest-weighted vocabulary entries come first.
    values, indices = torch.sort(embeddings, dim=-1, descending=True)

    token2weight = []
    for i in range(embeddings.size(0)):
        token2weight.append(OrderedDict())

        # Keep only the non-zero dimensions and map them back to tokens.
        non_zero_mask = values[i] != 0
        tokens = tokenizer.convert_ids_to_tokens(indices[i][non_zero_mask])
        weights = values[i][non_zero_mask].tolist()

        for token, weight in zip(tokens, weights):
            token2weight[i][token] = weight

    return token2weight


token2weight = convert_embeddings_to_weights(embeddings, tokenizer)
print(token2weight[1])
# OrderedDict([('掌上', 3.4572525024414062), ('电脑', 2.6253132820129395), ('是', 2.0787220001220703), ('什么', 1.2899624109268188)])
```
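
To feed a sparse search engine, the raw `(indices, values)` pairs of each embedding row are often more convenient than token strings. The helper below is a hypothetical sketch (`to_index_format` is not part of the model's code) that extracts the non-zero dimensions of each row:

```python
def to_index_format(embeddings):
    # Hypothetical helper: (vocabulary indices, weights) of each row's non-zero dimensions.
    rows = []
    for row in embeddings:
        non_zero = torch.nonzero(row, as_tuple=True)[0]
        rows.append((non_zero.tolist(), row[non_zero].tolist()))
    return rows


doc_rows = to_index_format(embeddings[2:])  # the two documents
print(len(doc_rows[0][0]))  # number of non-zero dimensions in the first document
```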

## Evaluation

### C-MTEB Retrieval task

([Chinese Massive Text Embedding Benchmark](https://github.com/FlagOpen/FlagEmbedding/tree/master/C_MTEB))

Metric: nDCG@10

| Model Name | Max Length | Average | Cmedqa | Covid | Du | Ecom | Medical | MMarco | T2 | Video |
|:---------------------------------------------------:|:----------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|
| [BM25-zh](https://github.com/castorini/pyserini) | - | 50.37 | 13.70 | **86.58** | 57.13 | 44.04 | 32.08 | 48.31 | 60.48 | 60.64 |
| [bge-m3-sparse](https://huggingface.co/BAAI/bge-m3) | 512 | 57.00 | **24.50** | 76.09 | 71.51 | 50.49 | 43.93 | 59.28 | 71.76 | 58.43 |
| **inf-wse-v1-base-zh** | 512 | **61.16** | 20.51 | 76.41 | **79.84** | **56.78** | **46.24** | **66.40** | **76.50** | **68.57** |

All results, except for BM25, are measured by building the sparse index via [Qdrant](https://github.com/qdrant/qdrant).
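
For reference, the sketch below shows one way such a sparse index could be set up with the `qdrant-client` Python package. It is only an illustration under assumed settings (an in-memory client, collection name "docs", sparse vector name "text", and the hypothetical `to_index_format` helper from above); it is not the exact configuration used for the benchmark.

```python
from qdrant_client import QdrantClient, models

client = QdrantClient(":memory:")  # in-memory instance, for illustration only

# A collection with a single named sparse vector and no dense vectors.
client.create_collection(
    collection_name="docs",
    vectors_config={},
    sparse_vectors_config={"text": models.SparseVectorParams()},
)

# Index the two example documents using their non-zero (index, weight) pairs.
client.upsert(
    collection_name="docs",
    points=[
        models.PointStruct(
            id=i,
            vector={"text": models.SparseVector(indices=idx, values=val)},
            payload={"text": documents[i]},
        )
        for i, (idx, val) in enumerate(doc_rows)
    ],
)

# Query with the sparse embedding of the first query.
query_idx, query_val = to_index_format(embeddings[:1])[0]
hits = client.query_points(
    collection_name="docs",
    query=models.SparseVector(indices=query_idx, values=query_val),
    using="text",
    limit=2,
)
print(hits)
```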