---
language:
  - zh
base_model: junnyu/roformer_chinese_base
tags:
  - sentence-similarity
  - cmteb
  - sentence-transformers
  - transformers
---

## <u>INF</u> <u>W</u>ord-level <u>S</u>parse <u>E</u>mbedding (INF-WSE)

**INF-WSE** is a series of word-level sparse embedding models developed by [INF TECH](https://www.infly.cn/en).
These models produce sparse, high-dimensional text embeddings that capture the information most relevant for
search and retrieval, particularly in Chinese text.

### Key Features

- **Optimized for Retrieval**: INF-WSE is designed with retrieval tasks in mind. The sparse embeddings enable efficient
  matching between queries and documents, making it highly effective for semantic search, ranking, and information
  retrieval scenarios where speed and accuracy are critical.
- **Word-level Sparse Embeddings**: The model generates sparse representations at the word level, capturing essential
  semantic details that help improve the relevance of search results. This is particularly useful for Chinese language
  retrieval tasks, where word segmentation can significantly impact performance.
- **Sparse Representation for Efficiency**: Unlike dense embeddings, which have a fixed number of dimensions,
  INF-WSE produces sparse embeddings whose dimensionality equals the vocabulary size. Most dimensions are zero,
  so only the most significant terms carry weight. This sparsity reduces the computational load, enabling faster
  retrieval without compromising precision (see the sketch below).
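
To make the sparsity concrete, here is a minimal sketch; the vocabulary size and the active dimensions below are
illustrative, not taken from the model:

```python
import torch

# Illustrative numbers: the real dimensionality is the tokenizer's vocab size
vocab_size = 30_000
embedding = torch.zeros(vocab_size)
embedding[[102, 4057, 8791]] = torch.tensor([2.6, 3.4, 1.3])  # a few significant terms

sparsity = (embedding == 0).float().mean().item()
print(f"active dims: {int((embedding != 0).sum())} / {vocab_size}, sparsity: {sparsity:.4%}")
# active dims: 3 / 30000, sparsity: 99.9900%
```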

## Usage

### Transformers

#### Infer embeddings
```python
import torch
from transformers import AutoTokenizer, AutoModel

queries = ['电脑一体机由什么构成?', '什么是掌上电脑?']  # 'What makes up an all-in-one PC?', 'What is a PDA?'
documents = [
    '电脑一体机,是由一台显示器、一个电脑键盘和一个鼠标组成的电脑。',  # 'An all-in-one PC is a computer made up of a monitor, a keyboard, and a mouse.'
    '掌上电脑是一种运行在嵌入式操作系统和内嵌式应用软件之上的、小巧、轻便、易带、实用、价廉的手持式计算设备。',  # 'A PDA is a small, light, portable, practical, and inexpensive handheld computing device running an embedded OS and built-in applications.'
]
input_texts = queries + documents

tokenizer = AutoTokenizer.from_pretrained("infly/inf-wse-v1-base-zh", trust_remote_code=True, use_fast=False)  # the fast tokenizer is not supported yet
model = AutoModel.from_pretrained("infly/inf-wse-v1-base-zh", trust_remote_code=True)
model.eval()

max_length = 512

input_batch = tokenizer(input_texts, padding=True, max_length=max_length, truncation=True, return_tensors="pt")
with torch.no_grad():
    embeddings = model(input_batch['input_ids'], input_batch['attention_mask'], return_sparse=False)  # return_sparse=True returns a torch sparse tensor; False returns a dense one

scores = embeddings[:2] @ embeddings[2:].T
print(scores.tolist())
# [[21.224790573120117, 4.520412921905518], [10.290857315063477, 19.359437942504883]]
```
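
Setting `return_sparse=True` in the call above returns the same embeddings as a torch sparse tensor, which is
cheaper to store and to hand off to a sparse index. A minimal sketch, continuing from the snippet above and
assuming the returned tensor is a COO sparse tensor:

```python
with torch.no_grad():
    sparse_embeddings = model(input_batch['input_ids'], input_batch['attention_mask'], return_sparse=True)

print(sparse_embeddings.is_sparse)  # True
print(sparse_embeddings.coalesce().values().numel())  # number of non-zero weights across all 4 texts

# The scores match the dense computation above
dense = sparse_embeddings.to_dense()
print((dense[:2] @ dense[2:].T).tolist())
```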

#### Convert embeddings to lexical weights
```python
from collections import OrderedDict
def convert_embeddings_to_weights(embeddings, tokenizer):
    values, indices = torch.sort(embeddings, dim=-1, descending=True)
    
    token2weight = []
    for i in range(embeddings.size(0)):
        token2weight.append(OrderedDict())

        non_zero_mask = values[i] != 0
        tokens = tokenizer.convert_ids_to_tokens(indices[i][non_zero_mask])
        weights = values[i][non_zero_mask].tolist()

        for token, weight in zip(tokens, weights):
            token2weight[i][token] = weight

    return token2weight

token2weight = convert_embeddings_to_weights(embeddings, tokenizer)
print(token2weight[1])
# OrderedDict([('掌上', 3.4572525024414062), ('电脑', 2.6253132820129395), ('是', 2.0787220001220703), ('什么', 1.2899624109268188)])
```
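
These weight dictionaries make the scores above easy to interpret: a query-document score is simply the sum,
over shared tokens, of the products of their weights. A small sketch using the `token2weight` entries for the
two queries (indices 0-1) and the two documents (indices 2-3):

```python
def lexical_score(query_weights, doc_weights):
    # Dot product restricted to the tokens both texts activate
    return sum(weight * doc_weights[token]
               for token, weight in query_weights.items()
               if token in doc_weights)

for qi in range(2):
    print([round(lexical_score(token2weight[qi], token2weight[di]), 2) for di in (2, 3)])
# Reproduces the score matrix from the first snippet
```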

## Evaluation

### C-MTEB Retrieval task

([Chinese Massive Text Embedding Benchmark](https://github.com/FlagOpen/FlagEmbedding/tree/master/C_MTEB))

Metric: nDCG@10

|                     Model Name                      | Max Length |  Average  |  Cmedqa   |   Covid   |    Du     |   Ecom    |  Medical  |  MMarco   |    T2     |   Video   | 
|:---------------------------------------------------:|:----------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|
|  [BM25-zh](https://github.com/castorini/pyserini)   |     -      |   50.37   |   13.70   | **86.58** |   57.13   |   44.04   |   32.08   |   48.31   |   60.48   |   60.64   |
| [bge-m3-sparse](https://huggingface.co/BAAI/bge-m3) |    512     |   57.00   | **24.50** |   76.09   |   71.51   |   50.49   |   43.93   |   59.28   |   71.76   |   58.43   |
|               **inf-wse-v1-base-zh**                |    512     | **61.16** |   20.51   |   76.41   | **79.84** | **56.78** | **46.24** | **66.40** | **76.50** | **68.57** |

All results except BM25 were measured by building a sparse index with [Qdrant](https://github.com/qdrant/qdrant); a sketch of such an index follows.
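
The exact indexing setup is not published here, but the general pattern looks like the following sketch using the
`qdrant-client` Python package (the collection name is made up, and `embeddings` is the dense tensor from the
Transformers example above):

```python
from qdrant_client import QdrantClient, models

client = QdrantClient(":memory:")  # local in-memory instance, for illustration only

client.create_collection(
    collection_name="inf-wse-demo",  # hypothetical name
    vectors_config={},               # no dense vectors, sparse only
    sparse_vectors_config={"text": models.SparseVectorParams()},
)

# Index the two documents: keep only the non-zero dimensions of each embedding
for doc_id, emb in enumerate(embeddings[2:]):
    nz = emb.nonzero(as_tuple=True)[0]
    client.upsert(
        collection_name="inf-wse-demo",
        points=[models.PointStruct(
            id=doc_id,
            vector={"text": models.SparseVector(indices=nz.tolist(), values=emb[nz].tolist())},
        )],
    )

# Retrieve with the first query's embedding
q = embeddings[0]
nz = q.nonzero(as_tuple=True)[0]
hits = client.search(
    collection_name="inf-wse-demo",
    query_vector=models.NamedSparseVector(
        name="text",
        vector=models.SparseVector(indices=nz.tolist(), values=q[nz].tolist()),
    ),
)
print([(hit.id, hit.score) for hit in hits])
```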