"""Interface for storing vectors."""
import abc
import os
import pickle
from typing import Iterable, Optional, Type
import numpy as np
from ..schema import SpanVector, VectorKey
from ..utils import open_file
class VectorStore(abc.ABC):
    """Abstract base class for keyed storage and retrieval of embedding vectors."""

    # Global registry name identifying this store implementation.
    name: str

    @abc.abstractmethod
    def save(self, base_path: str) -> None:
        """Persist the store's contents under `base_path`."""

    @abc.abstractmethod
    def load(self, base_path: str) -> None:
        """Restore the store's contents from `base_path`."""

    @abc.abstractmethod
    def size(self) -> int:
        """Return how many vectors the store currently holds."""

    @abc.abstractmethod
    def add(self, keys: list[VectorKey], embeddings: np.ndarray) -> None:
        """Add or edit the given keyed embeddings in the store.

        Keys that already exist are overwritten, giving "upsert" semantics.

        Args:
          keys: The keys to add the embeddings for.
          embeddings: The embeddings to add; a 2D matrix with one row per key.
        """

    @abc.abstractmethod
    def get(self, keys: Optional[Iterable[VectorKey]] = None) -> np.ndarray:
        """Look up embeddings by key.

        Args:
          keys: The keys to return the embeddings for. If None, return all embeddings.

        Returns:
          The embeddings for the given keys.
        """

    def topk(self,
             query: np.ndarray,
             k: int,
             keys: Optional[Iterable[VectorKey]] = None) -> list[tuple[VectorKey, float]]:
        """Return the `k` vectors most similar to `query`.

        Args:
          query: The query vector.
          k: The number of results to return.
          keys: Optional keys to restrict the search to.

        Returns:
          A list of (key, score) tuples.
        """
        raise NotImplementedError
# A path key identifies a whole document, e.g. (rowid1, 0); span keys are path
# keys with a span index appended. Alias of `VectorKey` for readability.
PathKey = VectorKey

# Filename of the pickled path-key -> spans mapping inside a saved index.
_SPANS_PICKLE_NAME = 'spans.pkl'
class VectorDBIndex:
    """Stores and retrieves span vectors.

    This wraps a regular vector store by adding a mapping from path keys, such as (rowid1, 0),
    to span keys, such as (rowid1, 0, 0), which denotes the first span in the (rowid1, 0) document.
    """

    def __init__(self, vector_store: str) -> None:
        self._vector_store: VectorStore = get_vector_store_cls(vector_store)()
        # Map a path key to spans for that path.
        self._id_to_spans: dict[PathKey, list[tuple[int, int]]] = {}

    def load(self, base_path: str) -> None:
        """Load the vector index from disk.

        Args:
          base_path: Directory the index was previously saved to.
        """
        assert not self._id_to_spans, 'Cannot load into a non-empty index.'
        with open_file(os.path.join(base_path, _SPANS_PICKLE_NAME), 'rb') as f:
            self._id_to_spans.update(pickle.load(f))
        # The wrapped store lives in a subdirectory named after the store type.
        self._vector_store.load(os.path.join(base_path, self._vector_store.name))

    def save(self, base_path: str) -> None:
        """Save the vector index to disk.

        Args:
          base_path: Directory to write the span mapping and wrapped store into.
        """
        assert self._id_to_spans, 'Cannot save an empty index.'
        with open_file(os.path.join(base_path, _SPANS_PICKLE_NAME), 'wb') as f:
            pickle.dump(list(self._id_to_spans.items()), f)
        self._vector_store.save(os.path.join(base_path, self._vector_store.name))

    def add(self, all_spans: list[tuple[PathKey, list[tuple[int, int]]]],
            embeddings: np.ndarray) -> None:
        """Add the given spans and embeddings.

        Args:
          all_spans: The spans to initialize the index with.
          embeddings: The embeddings to initialize the index with; one row per span,
            in the same order the spans appear in `all_spans`.
        """
        assert not self._id_to_spans, 'Cannot add to a non-empty index.'
        self._id_to_spans.update(all_spans)
        # Span key = path key + span index within that path.
        vector_keys = [(*path_key, i) for path_key, spans in all_spans for i in range(len(spans))]
        assert len(vector_keys) == len(embeddings), (
            f'Number of spans ({len(vector_keys)}) and embeddings ({len(embeddings)}) must match.')
        self._vector_store.add(vector_keys, embeddings)

    def get_vector_store(self) -> VectorStore:
        """Return the underlying vector store."""
        return self._vector_store

    def get(self, keys: Iterable[PathKey]) -> Iterable[list[SpanVector]]:
        """Return the spans with vectors for each key in `keys`.

        Args:
          keys: The keys to return the vectors for.

        Returns:
          The span vectors for the given keys, yielded in the same order as `keys`.
        """
        all_spans: list[list[tuple[int, int]]] = []
        vector_keys: list[VectorKey] = []
        for path_key in keys:
            spans = self._id_to_spans[path_key]
            all_spans.append(spans)
            vector_keys.extend([(*path_key, i) for i in range(len(spans))])
        # One flat batched lookup, then slice the result back per document.
        all_vectors = self._vector_store.get(vector_keys)
        offset = 0
        for spans in all_spans:
            vectors = all_vectors[offset:offset + len(spans)]
            yield [{'span': span, 'vector': vector} for span, vector in zip(spans, vectors)]
            offset += len(spans)

    def topk(self,
             query: np.ndarray,
             k: int,
             path_keys: Optional[Iterable[PathKey]] = None) -> list[tuple[PathKey, float]]:
        """Return the top k most similar vectors.

        Args:
          query: The query vector.
          k: The number of results to return.
          path_keys: Optional key prefixes to restrict the search to.

        Returns:
          A list of (key, score) tuples.
        """
        span_keys: Optional[list[VectorKey]] = None
        if path_keys is not None:
            span_keys = [
                (*path_key, i)
                for path_key in path_keys
                for i in range(len(self._id_to_spans[path_key]))
            ]
        # Upper bound on how many span results any search can return. Note an
        # explicitly empty `span_keys` restriction must bound to 0, not fall back
        # to the full store size (the original truthiness check got this wrong).
        max_span_keys = len(span_keys) if span_keys is not None else self._vector_store.size()
        span_k = k
        path_key_scores: dict[PathKey, float] = {}
        # Several spans can share one path key, so k span hits may collapse into
        # fewer than k path hits; widen the span search until we have k distinct
        # paths or we have searched every candidate span. The loop always runs at
        # least once — the original pre-checked `span_k < max` and silently
        # returned [] whenever the store held k or fewer spans.
        while True:
            vector_key_scores = self._vector_store.topk(query, span_k, span_keys)
            for (*path_key_list, _), score in vector_key_scores:
                path_key = tuple(path_key_list)
                # Span results arrive best-first, so the first score seen for a
                # path is its best score.
                if path_key not in path_key_scores:
                    path_key_scores[path_key] = score
            if len(path_key_scores) >= k or span_k >= max_span_keys:
                break
            span_k += k
        return list(path_key_scores.items())[:k]
VECTOR_STORE_REGISTRY: dict[str, Type[VectorStore]] = {}
def register_vector_store(vector_store_cls: Type[VectorStore]) -> None:
    """Register a vector store in the global registry.

    Args:
      vector_store_cls: The VectorStore subclass to register under its `name`.

    Raises:
      ValueError: If a store with the same name was registered before.
    """
    if vector_store_cls.name not in VECTOR_STORE_REGISTRY:
        VECTOR_STORE_REGISTRY[vector_store_cls.name] = vector_store_cls
        return
    raise ValueError(f'Vector store "{vector_store_cls.name}" has already been registered!')
def get_vector_store_cls(vector_store_name: str) -> Type[VectorStore]:
    """Return a registered vector store given the name in the registry.

    Raises:
      KeyError: If no store was registered under `vector_store_name`.
    """
    store_cls = VECTOR_STORE_REGISTRY[vector_store_name]
    return store_cls
def clear_vector_store_registry() -> None:
    """Remove every registered store from the global vector store registry."""
    VECTOR_STORE_REGISTRY.clear()