---
library_name: sentence-transformers
pipeline_tag: sentence-similarity
tags:
- sentence-transformers
- sentence-similarity
- feature-extraction
---
<div align="center">
<img src="https://raw.githubusercontent.com/Anditty/OASIS/refs/heads/main/Group.svg" width="60%" alt="Kwaipilot" />
</div>
<hr>
# Kwaipilot OASIS-1.3B
## Model Details
**Model Name**: OASIS (Optimized Augmentation Strategy for Improved code Search)
**Introduction**
OASIS is a state-of-the-art code embedding model developed by Kwaipilot. This model incorporates unique, proprietary methods including **repository-level program analysis**, the **OASIS-instruct data synthesis** algorithm, and a **specialized fusion loss function**, setting new benchmarks in code search efficiency and accuracy.
**Intended Use**
This model is ideal for developers and researchers engaged in enhancing **code retrieval systems**. OASIS excels in scenarios requiring semantic understanding and retrieval of code snippets within varied programming contexts.
**Training and Performance**
OASIS was trained on a synthetic dataset created through repository-level analysis, which is intended to give it broad coverage of coding styles and languages. It has demonstrated state-of-the-art performance on recent code search benchmarks (see the Performance section).
## Future Directions
Kwaipilot's upcoming initiatives include:
- Open sourcing improved models.
- Releasing technical reports.
- Releasing natural language processing models.
- ...
## Performance
| Model | Size | CoSQA | AdvTest | CSN-Py | CSN-Java | CSN-JS | CSN-PHP | CSN-Go | CSN-Ruby | Avg |
|-----------------|:-----:|:------:|:---------:|:--------:|:-------:|:-------:|:-------:|:-------:|:-------:|:-------:|
|Openai-Embedding-Ada-002 | Unknown | 0.4423| 0.3808 | 0.6802 | 0.7149| 0.6750| 0.6062| 0.8563| **0.7472**|0.6378|
|jina-embeddings-v2-base-code | 161M |**0.6837** |0.385 | 0.6634 | 0.6803| 0.6304| 0.5701| 0.8595| 0.7095|0.6477|
| CodeSage-large | 1.3B | 0.4753| **0.5267** | 0.7077 | 0.7021| **0.695** | 0.6133| 0.8371| 0.7192|0.6595|
| CodeFuse-CGE-Small | 3.8B | 0.5619| 0.4639 | 0.6958 | 0.6863| 0.6564| 0.6133| 0.8637| 0.7341|0.6594|
| OASIS-1.3B | 1.3B | 0.5532| 0.4861 | **0.7110** | **0.7199**| 0.6727| **0.6217**| **0.8732**| 0.7333|**0.6713**|
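The card does not say how the Avg column is computed. The sketch below assumes it is the unweighted mean of the eight benchmark columns and checks that assumption against the published figures, allowing a small tolerance because the table values are given to four decimals. The `rows` dictionary and the `mean` helper are illustrative names, not part of any OASIS tooling.

```python
# Scores copied from the table above, in column order CoSQA ... CSN-Ruby,
# paired with the reported Avg value for the same row.
rows = {
    "Openai-Embedding-Ada-002": ([0.4423, 0.3808, 0.6802, 0.7149, 0.6750, 0.6062, 0.8563, 0.7472], 0.6378),
    "jina-embeddings-v2-base-code": ([0.6837, 0.385, 0.6634, 0.6803, 0.6304, 0.5701, 0.8595, 0.7095], 0.6477),
    "CodeSage-large": ([0.4753, 0.5267, 0.7077, 0.7021, 0.695, 0.6133, 0.8371, 0.7192], 0.6595),
    "CodeFuse-CGE-Small": ([0.5619, 0.4639, 0.6958, 0.6863, 0.6564, 0.6133, 0.8637, 0.7341], 0.6594),
    "OASIS-1.3B": ([0.5532, 0.4861, 0.7110, 0.7199, 0.6727, 0.6217, 0.8732, 0.7333], 0.6713),
}


def mean(scores):
    """Unweighted arithmetic mean (the assumed definition of Avg)."""
    return sum(scores) / len(scores)


for name, (scores, reported) in rows.items():
    computed = mean(scores)
    # Tolerance covers rounding or truncation in the four-decimal table values.
    assert abs(computed - reported) < 1e-4, name
    print(f"{name}: computed {computed:.4f}, reported {reported:.4f}")

# Model with the highest computed mean under this assumption.
best_model = max(rows, key=lambda name: mean(rows[name][0]))
print(best_model)
```

Under this assumption every reported Avg agrees with the recomputed mean to within 0.0001.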
## Usage
### Direct Usage
```bash
pip install -U torch
pip install -U transformers
```
Avoid `torch==2.5.0` when loading the model with `torch_dtype=torch.bfloat16`. For stable results, use PyTorch 2.4.1 or earlier, or upgrade to 2.5.1 or later.
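One way to guard against the affected release is to check the version string before choosing a dtype. The following is a minimal sketch; `parse_release` and `bf16_load_is_discouraged` are hypothetical helper names introduced here for illustration, and the check only looks at the numeric `major.minor.patch` prefix.

```python
import re


def parse_release(version: str) -> tuple:
    """Extract (major, minor, patch) from a string such as '2.5.0+cu121'."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(group) for group in match.groups())


def bf16_load_is_discouraged(torch_version: str) -> bool:
    """True only for the 2.5.0 release named in the note above.

    Pre-release builds labelled 2.5.0 (for example nightlies) are flagged too,
    which is a conservative choice made for this sketch.
    """
    return parse_release(torch_version) == (2, 5, 0)


print(bf16_load_is_discouraged("2.5.0+cu121"))
print(bf16_load_is_discouraged("2.5.1"))
```

In practice the argument would be `str(torch.__version__)`, and a `True` result suggests either loading in the default dtype or changing the installed PyTorch version.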
```python
import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoModel, AutoTokenizer


def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    # With left padding, the final position of every sequence holds a real token.
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return last_hidden_states[:, -1]
    else:
        # With right padding, pick the last non-padding token of each sequence.
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]


# Add query prompt
def get_query_prompt(query: str):
    query_description = 'Given a code search query, retrieve relevant code snippet that answer the query'
    prompt = f'Instruct: {query_description}\nQuery: {query}'
    return prompt

query = "How to do quicksort in python?"
code1 = """def bubble_sort(arr):
    n = len(arr)
    for i in range(n):
        swapped = False
        for j in range(1, n - i):
            if arr[j - 1] > arr[j]:
                arr[j - 1], arr[j] = arr[j], arr[j - 1]
                swapped = True
        if not swapped:
            break
    return arr"""
code2 = """def quick_sort(arr):
    if len(arr) <= 1:
        return arr
    else:
        pivot = arr[0]
        less = [x for x in arr[1:] if x <= pivot]
        greater = [x for x in arr[1:] if x > pivot]
        return quick_sort(less) + [pivot] + quick_sort(greater)"""
model = AutoModel.from_pretrained("Kwaipilot/OASIS-code-1.3B", output_hidden_states=True)
tokenizer = AutoTokenizer.from_pretrained("Kwaipilot/OASIS-code-1.3B")
# Tokenize and inference
inputs = tokenizer([get_query_prompt(query), code1, code2], max_length=8192, padding=True, truncation=True, return_tensors='pt')
# Inference only, so skip gradient tracking
with torch.no_grad():
    outputs = model(**inputs)
# Last token pooling
embeddings = last_token_pool(outputs.hidden_states[-1], inputs['attention_mask'])
print(embeddings.shape)
# torch.Size([3, 2048])
embeddings = F.normalize(embeddings, dim=1, p=2)
similarity = embeddings @ embeddings.T
print(similarity[0, 1:])
# tensor([0.6495, 0.8036])
```
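The pooling step in `last_token_pool` can be hard to follow from the tensor indexing alone. The sketch below restates the same selection rule on plain Python lists, using made-up one-dimensional "hidden states" so that the chosen token is easy to see. `last_token_pool_lists` is an illustrative name and is not part of the model's API.

```python
def last_token_pool_lists(hidden, mask):
    """Pick the hidden state of the last real token in each sequence.

    hidden: nested list shaped [batch][seq_len][dim]
    mask:   nested list shaped [batch][seq_len], 1 for real tokens, 0 for padding
    """
    # If every sequence ends in a real token, the batch is treated as left-padded.
    left_padded = all(row[-1] == 1 for row in mask)
    pooled = []
    for states, row in zip(hidden, mask):
        if left_padded:
            pooled.append(states[-1])
        else:
            # Right padding: the last real token sits at index (number of ones - 1).
            pooled.append(states[sum(row) - 1])
    return pooled


# Toy batch of two sequences; the second has one padding position.
right_hidden = [[[1.0], [2.0], [3.0]], [[4.0], [5.0], [0.0]]]
right_mask = [[1, 1, 1], [1, 1, 0]]

left_hidden = [[[1.0], [2.0], [3.0]], [[0.0], [4.0], [5.0]]]
left_mask = [[1, 1, 1], [0, 1, 1]]

print(last_token_pool_lists(right_hidden, right_mask))  # [[3.0], [5.0]]
print(last_token_pool_lists(left_hidden, left_mask))    # [[3.0], [5.0]]
```

Both padding layouts yield the same pooled vectors, which is why the tensor version branches on the padding side before indexing.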
### Sentence Transformers
First install the Sentence Transformers library:
```bash
pip install -U sentence-transformers
```
Then you can load this model and run inference.
```python
from sentence_transformers import SentenceTransformer
# Download from the 🤗 Hub
# Optional: pass model_kwargs={"torch_dtype": torch.bfloat16} as a second argument (requires `import torch`)
model = SentenceTransformer("Kwaipilot/OASIS-code-1.3B")
query = "How to do quicksort in python?"
code1 = """def bubble_sort(arr):
    n = len(arr)
    for i in range(n):
        swapped = False
        for j in range(1, n - i):
            if arr[j - 1] > arr[j]:
                arr[j - 1], arr[j] = arr[j], arr[j - 1]
                swapped = True
        if not swapped:
            break
    return arr"""
code2 = """def quick_sort(arr):
    if len(arr) <= 1:
        return arr
    else:
        pivot = arr[0]
        less = [x for x in arr[1:] if x <= pivot]
        greater = [x for x in arr[1:] if x > pivot]
        return quick_sort(less) + [pivot] + quick_sort(greater)"""
# Run inference
query_embedding = model.encode([query], prompt_name="query")
code_embeddings = model.encode([code1, code2])
print(code_embeddings.shape)
# (2, 2048)
# Get the similarity scores for the embeddings
print(model.similarity(query_embedding[0], code_embeddings[0]))
print(model.similarity(query_embedding[0], code_embeddings[1]))
# tensor([[0.6495]])
# tensor([[0.8036]])
```
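In a retrieval setting the similarity scores are typically used to order candidate snippets for a query. The sketch below shows that ranking step with plain Python and cosine similarity. The three-dimensional vectors are invented stand-ins, not real OASIS embeddings, and `cosine` and `rank_candidates` are illustrative helpers.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def rank_candidates(query_vec, candidates):
    """Return (name, score) pairs sorted from most to least similar."""
    scored = [(name, cosine(query_vec, vec)) for name, vec in candidates.items()]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)


# Made-up vectors standing in for a query embedding and two code embeddings.
toy_query = [1.0, 0.0, 1.0]
toy_candidates = {
    "bubble_sort": [0.0, 1.0, 1.0],
    "quick_sort": [1.0, 0.0, 0.9],
}

ranking = rank_candidates(toy_query, toy_candidates)
for name, score in ranking:
    print(f"{name}: {score:.4f}")
```

With real embeddings, `toy_query` would be replaced by the query embedding and `toy_candidates` by the code embeddings keyed by whatever identifier the retrieval system uses.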
### BibTeX
```bibtex
@misc{kwaipilotoasis,
  title = {Optimized Augmentation Strategy for Improved code Search},
  author = {Kwaipilot team},
  year = {2024},
}
```