"""Embedding registry.""" | |
from concurrent.futures import ThreadPoolExecutor | |
from typing import Callable, Generator, Iterable, Iterator, Optional, Union, cast | |
import numpy as np | |
from pydantic import StrictStr | |
from sklearn.preprocessing import normalize | |
from ..schema import ( | |
EMBEDDING_KEY, | |
TEXT_SPAN_END_FEATURE, | |
TEXT_SPAN_START_FEATURE, | |
VALUE_KEY, | |
Item, | |
RichData, | |
SpanVector, | |
lilac_embedding, | |
) | |
from ..signal import TextEmbeddingSignal, get_signal_by_type | |
from ..splitters.chunk_splitter import TextChunk | |
from ..utils import chunks | |
EmbeddingId = Union[StrictStr, TextEmbeddingSignal] | |
EmbedFn = Callable[[Iterable[RichData]], Iterator[list[SpanVector]]] | |
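
# An EmbedFn consumes a batch of documents and yields, for each document, the
# list of span vectors produced by the embedding signal.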


def get_embed_fn(embedding_name: str, split: bool) -> EmbedFn:
  """Return a function that computes span vectors for the given embedding signal."""
  embedding_cls = get_signal_by_type(embedding_name, TextEmbeddingSignal)
  embedding = embedding_cls(split=split)
  embedding.setup()

  def _embed_fn(data: Iterable[RichData]) -> Iterator[list[SpanVector]]:
    items = embedding.compute(data)

    for item in items:
      if not item:
        raise ValueError('Embedding signal returned None.')

      yield [{
        'vector': item_val[EMBEDDING_KEY].reshape(-1),
        'span': (item_val[VALUE_KEY][TEXT_SPAN_START_FEATURE],
                 item_val[VALUE_KEY][TEXT_SPAN_END_FEATURE])
      } for item_val in item]

  return _embed_fn
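
# Usage sketch (illustrative only; assumes an embedding signal registered under
# the name 'gte-small'):
#
#   embed_fn = get_embed_fn('gte-small', split=True)
#   for span_vectors in embed_fn(['first document', 'second document']):
#     for sv in span_vectors:
#       print(sv['span'], sv['vector'].shape)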


def compute_split_embeddings(docs: Iterable[str],
                             batch_size: int,
                             embed_fn: Callable[[list[str]], list[np.ndarray]],
                             split_fn: Optional[Callable[[str], list[TextChunk]]] = None,
                             num_parallel_requests: int = 1
                            ) -> Generator[Optional[Item], None, None]:
  """Compute text embeddings in batches of chunks, using the provided splitter and embedding fn.

  Yields one value per input document: a list of embedding items for its chunks, or None for a
  document that produces no chunks (e.g. an empty string).
  """
  pool = ThreadPoolExecutor()

  def _splitter(doc: str) -> list[TextChunk]:
    if not doc:
      return []
    if split_fn:
      return split_fn(doc)
    else:
      # Return a single chunk that spans the entire document.
      return [(doc, (0, len(doc)))]
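
  # For illustration: with no split_fn, _splitter('hello') returns a single chunk
  # spanning the whole document: [('hello', (0, 5))].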

  num_docs = 0

  def _flat_split_batch_docs(docs: Iterable[str]) -> Generator[tuple[int, TextChunk], None, None]:
    """Split a batch of documents into chunks and yield them."""
    nonlocal num_docs
    for i, doc in enumerate(docs):
      num_docs += 1
      text_chunks = _splitter(doc)
      for chunk in text_chunks:
        yield (i, chunk)
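
  # Each element pairs a chunk with the index of its source document, e.g.
  # (0, ('first chunk', (0, 11))), (0, ('second chunk', (12, 24))), (1, ...).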

  doc_chunks = _flat_split_batch_docs(docs)
  items_to_yield: Optional[list[Item]] = None
  current_index = 0

  mega_batch_size = batch_size * num_parallel_requests

  for batch in chunks(doc_chunks, mega_batch_size):
    texts = [text for _, (text, _) in batch]
    embeddings: list[np.ndarray] = []

    # Embed each sub-batch on the thread pool; map() preserves input order.
    for batch_embeddings in pool.map(embed_fn, chunks(texts, batch_size)):
      embeddings.extend(batch_embeddings)
    matrix = cast(np.ndarray, normalize(np.array(embeddings, dtype=np.float32)))
    # np.split returns a shallow copy of each embedding so we don't increase the mem footprint.
    embeddings_batch = cast(list[np.ndarray], np.split(matrix, matrix.shape[0]))
    for (index, (_, (start, end))), embedding in zip(batch, embeddings_batch):
      embedding = embedding.reshape(-1)
      if index == current_index:
        if items_to_yield is None:
          items_to_yield = []
        items_to_yield.append(lilac_embedding(start, end, embedding))
      else:
        # We crossed a document boundary: flush the finished document.
        yield items_to_yield
        current_index += 1
        # Documents that produced no chunks yield None.
        while current_index < index:
          yield None
          current_index += 1
        items_to_yield = [lilac_embedding(start, end, embedding)]

  # Flush pending items for the last document with chunks, then yield None for
  # any remaining documents that produced no chunks.
  while current_index < num_docs:
    yield items_to_yield
    items_to_yield = None
    current_index += 1
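

# Usage sketch (illustrative only; `fake_embed` below is a stand-in for a real
# embedding function, e.g. a remote API call):
#
#   def fake_embed(texts: list[str]) -> list[np.ndarray]:
#     return [np.ones(8, dtype=np.float32) for _ in texts]
#
#   for item in compute_split_embeddings(['doc one', '', 'doc two'],
#                                        batch_size=2,
#                                        embed_fn=fake_embed,
#                                        num_parallel_requests=2):
#     print(item)  # A list of lilac_embedding items per doc; None for the empty doc.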