|
from transformers import Qwen2Tokenizer |
|
import bm25s |
|
from bm25s.hf import BM25HF |
|
|
|
|
|
|
|
|
|
class RexQwen2Tokenizer(Qwen2Tokenizer):
    """Qwen2 tokenizer that splices retrieved example dialogues into chat prompts.

    When the text passed to :meth:`tokenize` contains ``<|im_start|>user``, the
    text after the first such marker is used as a BM25 query against a BM25S
    index loaded from the Hugging Face Hub. The top ``rex_size`` hits (records
    with ``"query"`` and ``"response"`` fields) are rendered with the chat
    template as user/assistant turns and inserted immediately before that first
    user turn. Text without the marker is tokenized unchanged.
    """

    def __init__(
        self,
        rex_index_name="Llama-3-Magpie-Pro-1M-v0.1",
        rex_size=3,
        **kwargs,
    ):
        """Build the tokenizer and download the retrieval index.

        Args:
            rex_index_name: Suffix of the Hub repo
                ``yuchenlin/BM25S_index_<rex_index_name>`` holding the index.
            rex_size: Number of retrieved examples to insert; values below 1
                disable retrieval in :meth:`tokenize`.
            **kwargs: Forwarded to ``Qwen2Tokenizer.__init__``.
        """
        # Pass the retrieval settings through to the base class so they are
        # recorded in ``init_kwargs`` and written by ``save_pretrained``;
        # otherwise a reloaded tokenizer silently reverts to the defaults.
        super().__init__(rex_index_name=rex_index_name, rex_size=rex_size, **kwargs)
        self.rex_index_name = rex_index_name
        self.rex_size = rex_size
        self.user_prefix = "<|im_start|>user"
        hf_repo_name = f"yuchenlin/BM25S_index_{self.rex_index_name}"
        # Network I/O. The corpus is loaded too, because ``_rex_query`` reads
        # each hit's "query"/"response" fields rather than just document ids.
        self.retriever = BM25HF.load_from_hub(
            hf_repo_name, revision="main", load_corpus=True
        )

    def _rex_query(self, query):
        """Retrieve ``rex_size`` examples for *query* and render them as chat text.

        Args:
            query: Raw text used as the BM25 query.

        Returns:
            The retrieved examples rendered by ``apply_chat_template`` as
            alternating user/assistant turns, trimmed to start at the first
            ``self.user_prefix`` (whatever the template emits before it —
            presumably a default system message — is dropped). Returns ``""``
            when ``rex_size`` is below 1.

        Raises:
            ValueError: if the rendered template contains no ``self.user_prefix``.
        """
        k = self.rex_size
        if k < 1:
            return ""
        query_tokens = bm25s.tokenize(query, show_progress=False)
        results, _scores = self.retriever.retrieve(query_tokens, k=k, show_progress=False)
        rex_chat_history = []
        # One query was issued, so only row 0 of the result matrix is relevant.
        for doc in results[0]:
            rex_chat_history.append({"role": "user", "content": doc["query"]})
            rex_chat_history.append({"role": "assistant", "content": doc["response"]})
        rex_chat_text = self.apply_chat_template(
            rex_chat_history, tokenize=False, add_generation_prompt=False
        )
        start_user = rex_chat_text.index(self.user_prefix)
        return rex_chat_text[start_user:]

    def tokenize(self, text, **kwargs):
        """Tokenize *text*, first inserting retrieved examples into chat prompts.

        If *text* contains ``self.user_prefix`` and ``rex_size >= 1``, the text
        after the first occurrence of the prefix is used as the retrieval query
        and the rendered examples are inserted just before that occurrence.
        Otherwise *text* is tokenized unchanged. ``kwargs`` are forwarded to
        ``Qwen2Tokenizer.tokenize``.
        """
        start_index = text.find(self.user_prefix)
        if start_index < 0 or self.rex_size < 1:
            return super().tokenize(text, **kwargs)
        # NOTE(review): the query is the entire remainder of the prompt (every
        # later turn plus template markup), not only the latest user message —
        # confirm this is intended.
        query = text[start_index + len(self.user_prefix):]
        rex_text = text[:start_index] + self._rex_query(query) + text[start_index:]
        return super().tokenize(rex_text, **kwargs)
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo: render a short conversation with the chat template and print the
    # tokens produced after retrieved examples have been spliced in.
    # Usage: python <this file> [MODEL_PATH]
    import sys

    from transformers import AutoTokenizer

    # The default path is specific to the original author's cluster; allow an
    # override from the command line.
    default_model_path = "/net/nfs/mosaic/yuchenl/Qwen2-1.5B-Instruct"
    model_path = sys.argv[1] if len(sys.argv) > 1 else default_model_path

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, rex_size=3)

    messages = [
        {"role": "user", "content": "Who is Yuchen Lin?"},
        {"role": "assistant", "content": "Yuchen Lin is a NLP researcher."},
        {"role": "user", "content": "Can I ask him a question?"},
    ]

    query = tokenizer.apply_chat_template(messages, tokenize=False)
    print(tokenizer.tokenize(query))
|
|