File size: 6,583 Bytes
bfc0ec6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
"""Interface for storing vectors."""

import abc
import os
import pickle
from typing import Iterable, Optional, Type

import numpy as np

from ..schema import SpanVector, VectorKey
from ..utils import open_file


class VectorStore(abc.ABC):
  """Interface for storing and retrieving vectors."""

  # The global name of the vector store, used to register and look up implementations.
  name: str

  @abc.abstractmethod
  def save(self, base_path: str) -> None:
    """Persist the store's contents under `base_path`."""

  @abc.abstractmethod
  def load(self, base_path: str) -> None:
    """Restore the store's contents from `base_path`."""

  @abc.abstractmethod
  def size(self) -> int:
    """Return how many vectors the store currently holds."""

  @abc.abstractmethod
  def add(self, keys: list[VectorKey], embeddings: np.ndarray) -> None:
    """Add or overwrite the keyed embeddings in the store ("upsert" semantics).

    Args:
      keys: The keys identifying each embedding.
      embeddings: A 2D matrix of embeddings, one row per entry of `keys`.
    """

  @abc.abstractmethod
  def get(self, keys: Optional[Iterable[VectorKey]] = None) -> np.ndarray:
    """Return the embeddings for the given keys.

    Args:
      keys: The keys to look up. When None, every stored embedding is returned.

    Returns:
      The embeddings for the given keys.
    """

  def topk(self,
           query: np.ndarray,
           k: int,
           keys: Optional[Iterable[VectorKey]] = None) -> list[tuple[VectorKey, float]]:
    """Return the top k most similar vectors.

    Args:
      query: The query vector.
      k: The number of results to return.
      keys: Optional keys restricting which vectors are searched.

    Returns:
      A list of (key, score) tuples.
    """
    # Optional capability: not every store supports similarity search.
    raise NotImplementedError


# A path key, e.g. (rowid1, 0), identifies a whole document; structurally it is the same
# tuple type as VectorKey, which carries one extra trailing span index, e.g. (rowid1, 0, 0).
PathKey = VectorKey

# Filename under which VectorDBIndex pickles its path-key -> spans mapping.
_SPANS_PICKLE_NAME = 'spans.pkl'


class VectorDBIndex:
  """Stores and retrieves span vectors.

  This wraps a regular vector store by adding a mapping from path keys, such as (rowid1, 0),
  to span keys, such as (rowid1, 0, 0), which denotes the first span in the (rowid1, 0) document.
  """

  def __init__(self, vector_store: str) -> None:
    """Create an empty index backed by the registered vector store named `vector_store`."""
    self._vector_store: 'VectorStore' = get_vector_store_cls(vector_store)()
    # Map a path key to the (start, end) spans for that path, in span-index order.
    self._id_to_spans: dict['PathKey', list[tuple[int, int]]] = {}

  def load(self, base_path: str) -> None:
    """Load the vector index from disk.

    Args:
      base_path: The directory a previous `save` wrote the index to.
    """
    assert not self._id_to_spans, 'Cannot load into a non-empty index.'
    with open_file(os.path.join(base_path, _SPANS_PICKLE_NAME), 'rb') as f:
      self._id_to_spans.update(pickle.load(f))
    # The wrapped store persists itself in a subdirectory named after the store implementation.
    self._vector_store.load(os.path.join(base_path, self._vector_store.name))

  def save(self, base_path: str) -> None:
    """Save the vector index to disk.

    Args:
      base_path: The directory to write the span mapping and the vector store into.
    """
    assert self._id_to_spans, 'Cannot save an empty index.'
    with open_file(os.path.join(base_path, _SPANS_PICKLE_NAME), 'wb') as f:
      pickle.dump(list(self._id_to_spans.items()), f)
    self._vector_store.save(os.path.join(base_path, self._vector_store.name))

  def add(self, all_spans: list[tuple['PathKey', list[tuple[int, int]]]],
          embeddings: np.ndarray) -> None:
    """Add the given spans and embeddings.

    Args:
      all_spans: The spans to initialize the index with, one (path_key, spans) pair per document.
      embeddings: The embeddings to initialize the index with. Row j is the vector for the j-th
        span key obtained by enumerating `all_spans` in order.
    """
    assert not self._id_to_spans, 'Cannot add to a non-empty index.'
    self._id_to_spans.update(all_spans)
    # Span key (*path_key, i) addresses the i-th span of the document at `path_key`.
    vector_keys = [(*path_key, i) for path_key, spans in all_spans for i in range(len(spans))]
    assert len(vector_keys) == len(embeddings), (
      f'Number of spans ({len(vector_keys)}) and embeddings ({len(embeddings)}) must match.')
    self._vector_store.add(vector_keys, embeddings)

  def get_vector_store(self) -> 'VectorStore':
    """Return the underlying vector store."""
    return self._vector_store

  def get(self, keys: Iterable['PathKey']) -> Iterable[list['SpanVector']]:
    """Return the spans with vectors for each key in `keys`.

    Args:
      keys: The keys to return the vectors for.

    Returns:
      The span vectors for the given keys, one list of {'span', 'vector'} dicts per key.
    """
    all_spans: list[list[tuple[int, int]]] = []
    vector_keys: list['VectorKey'] = []
    for path_key in keys:
      spans = self._id_to_spans[path_key]
      all_spans.append(spans)
      vector_keys.extend([(*path_key, i) for i in range(len(spans))])

    # Fetch every span vector in one store call, then slice the flat result back per document.
    all_vectors = self._vector_store.get(vector_keys)
    offset = 0
    for spans in all_spans:
      vectors = all_vectors[offset:offset + len(spans)]
      yield [{'span': span, 'vector': vector} for span, vector in zip(spans, vectors)]
      offset += len(spans)

  def topk(self,
           query: np.ndarray,
           k: int,
           path_keys: Optional[Iterable['PathKey']] = None) -> list[tuple['PathKey', float]]:
    """Return the top k most similar documents as (path_key, score) tuples.

    A document's score is the score of its best-matching span. Because several spans of one
    document can occupy the store's top results, the span-level search is widened repeatedly
    until `k` distinct documents are found or every candidate span has been considered.

    Args:
      query: The query vector.
      k: The number of results to return.
      path_keys: Optional key prefixes to restrict the search to.

    Returns:
      A list of (path_key, score) tuples.
    """
    span_keys: Optional[list['VectorKey']] = None
    if path_keys is not None:
      span_keys = [
        (*path_key, i) for path_key in path_keys for i in range(len(self._id_to_spans[path_key]))
      ]
    # Total number of span vectors this query could ever retrieve.
    num_candidates = len(span_keys) if span_keys is not None else self._vector_store.size()
    path_key_scores: dict['PathKey', float] = {}
    span_k = k
    # BUGFIX: the previous loop tested `span_k < num_candidates` *before* the first query, so
    # any k >= the candidate count skipped the search entirely and returned []. Query first,
    # then decide whether widening the span-level search can still surface new documents.
    while True:
      vector_key_scores = self._vector_store.topk(query, span_k, span_keys)
      for (*path_key_list, _), score in vector_key_scores:
        path_key = tuple(path_key_list)
        # Results are assumed best-first, so the first score seen per path key is its best span.
        if path_key not in path_key_scores:
          path_key_scores[path_key] = score
      if len(path_key_scores) >= k or span_k >= num_candidates:
        break
      span_k += k

    return list(path_key_scores.items())[:k]


VECTOR_STORE_REGISTRY: dict[str, Type[VectorStore]] = {}


def register_vector_store(vector_store_cls: Type[VectorStore]) -> None:
  """Register a vector store in the global registry.

  Raises:
    ValueError: If a store with the same name was already registered.
  """
  store_name = vector_store_cls.name
  if store_name in VECTOR_STORE_REGISTRY:
    raise ValueError(f'Vector store "{store_name}" has already been registered!')

  VECTOR_STORE_REGISTRY[store_name] = vector_store_cls


def get_vector_store_cls(vector_store_name: str) -> Type[VectorStore]:
  """Return a registered vector store given the name in the registry.

  Args:
    vector_store_name: The name the store class was registered under.

  Raises:
    KeyError: If no vector store with this name has been registered.
  """
  try:
    return VECTOR_STORE_REGISTRY[vector_store_name]
  except KeyError as e:
    # Keep KeyError for backward compatibility, but explain what went wrong instead of
    # surfacing a bare, context-free lookup failure.
    raise KeyError(f'Vector store "{vector_store_name}" is not registered. '
                   f'Registered stores: {list(VECTOR_STORE_REGISTRY.keys())}') from e


def clear_vector_store_registry() -> None:
  """Remove every registered vector store from the global registry."""
  # Mutate the shared dict in place so existing references observe the cleared state.
  VECTOR_STORE_REGISTRY.clear()