Spaces:
Sleeping
Sleeping
from __future__ import annotations | |
import contextlib | |
import json | |
import logging | |
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Type | |
import pandas as pd | |
import sqlalchemy | |
from langchain.docstore.document import Document | |
from langchain.schema.embeddings import Embeddings | |
from langchain.vectorstores.base import VectorStore | |
from models.article import Article | |
from models.distance import DistanceStrategy, distance_strategy_limit | |
from sqlalchemy import delete, text | |
from sqlalchemy.orm import Session | |
from utils import str_to_list | |
# Distance metric used when the caller does not choose one explicitly.
DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.COSINE
# Fallback value for the ``table_name`` argument of ``from_texts``.
_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
def _results_to_docs(docs_and_scores: Any) -> List[Document]: | |
"""Return docs from docs and scores.""" | |
return [doc for doc, _ in docs_and_scores] | |
class CustomVectorStore(VectorStore): | |
"""`Postgres`/`PGVector` vector store. | |

    To use, you should have the ``pgvector`` python package installed.

    Args:
        connection: SQLAlchemy connection to the Postgres database.
        embedding_function: Any embedding function implementing the
            `langchain.embeddings.base.Embeddings` interface.
        table_name: The name of the collection to use. (default: langchain)
            NOTE: This is not the name of the table, but the name of the collection.
            The tables will be created when initializing the store (if they do
            not exist), so make sure the user has the right permissions to
            create tables.
        distance_strategy: The distance strategy to use. (default: COSINE)
        pre_delete_collection: If True, will delete the collection if it exists.
            (default: False). Useful for testing.

    Example:
        .. code-block:: python

            from langchain.vectorstores import PGVector
            from langchain.embeddings.openai import OpenAIEmbeddings

            COLLECTION_NAME = "state_of_the_union_test"
            embeddings = OpenAIEmbeddings()
            vectorstore = PGVector.from_documents(
                embedding=embeddings,
                documents=docs,
                table_name=COLLECTION_NAME,
                connection=connection,
            )
""" | |
def __init__(
    self,
    connection: sqlalchemy.engine.Connection,
    embedding_function: Embeddings,
    table_name: str,
    column_name: str,
    collection_metadata: Optional[dict] = None,
    distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
    pre_delete_collection: bool = False,
    logger: Optional[logging.Logger] = None,
) -> None:
    """Record the store configuration and run the post-init hook.

    Args:
        connection: Connection that sessions created by this store bind to.
        embedding_function: Object used to embed documents and queries.
        table_name: Stored on the instance as ``table_name``.
        column_name: Stored on the instance as ``column_name``.
        collection_metadata: Optional metadata dict, stored as given.
        distance_strategy: Metric used to pick the distance operator and
            the distance threshold for queries.
        pre_delete_collection: Stored on the instance as given.
        logger: Logger to use; a module-named logger is used when omitted.
    """
    # Fall back to a logger named after this module.
    if not logger:
        logger = logging.getLogger(__name__)
    self.logger = logger

    # Database handle and embedding model.
    self._conn = connection
    self.embedding_function = embedding_function

    # Plain configuration values, kept exactly as passed in.
    self.table_name = table_name
    self.column_name = column_name
    self.collection_metadata = collection_metadata
    self.pre_delete_collection = pre_delete_collection
    self._distance_strategy = distance_strategy

    self.__post_init__()
def __post_init__(self) -> None:
    """Finish construction by selecting the ORM model used for new rows."""
    # ``Article`` is the class instantiated for every row that gets added.
    # The connection comes from the caller; this hook does not open one.
    self.EmbeddingStore = Article
@property
def embeddings(self) -> Embeddings:
    """Return the embedding model this store was configured with.

    Exposed as a read-only property so that ``store.embeddings`` yields the
    model itself, matching the ``embeddings`` property declared on
    LangChain's ``VectorStore`` base class.
    """
    return self.embedding_function
def _make_session(self) -> Generator[Session, None, None]: | |
"""Create a context manager for the session, bind to _conn string.""" | |
yield Session(self._conn) | |
def add_embeddings(
    self,
    texts: Iterable[str],
    embeddings: List[List[float]],
    metadatas: Optional[List[dict]] = None,
    ids: Optional[List[str]] = None,
    **kwargs: Any,
) -> List[str]:
    """Add pre-computed embeddings to the vectorstore.

    One ``self.EmbeddingStore`` row is created per
    ``(text, metadata, embedding, id)`` tuple and all rows are committed in
    a single transaction. The inputs are combined with ``zip``, so the
    shortest of them determines how many rows are written.

    Args:
        texts: Iterable of strings to add to the vectorstore.
        embeddings: List of embedding vectors, one per text.
        metadatas: Optional list of metadata dicts, one per text. An empty
            dict is used for every text when omitted or empty.
        ids: Optional list of ids, one per text. Random UUID strings are
            generated when omitted.
        kwargs: vectorstore specific parameters (currently unused).

    Returns:
        The ids associated with the added rows.
    """
    # Function-scope import keeps this fix independent of the module header.
    import uuid

    # ``texts`` may be a one-shot iterator; materialize it once so building
    # the defaults below does not consume it before the insert loop.
    texts = list(texts)

    if ids is None:
        # ``zip`` cannot iterate ``None``; give every text a fresh id.
        # NOTE(review): assumes ``custom_id`` accepts a string - confirm
        # against the ``Article`` model.
        ids = [str(uuid.uuid4()) for _ in texts]
    if not metadatas:
        metadatas = [{} for _ in texts]

    with Session(self._conn) as session:
        for document, metadata, embedding, custom_id in zip(
            texts, metadatas, embeddings, ids
        ):
            session.add(
                self.EmbeddingStore(
                    embedding=embedding,
                    document=document,
                    cmetadata=metadata,
                    custom_id=custom_id,
                )
            )
        session.commit()

    return ids
def add_texts(
    self,
    texts: Iterable[str],
    metadatas: Optional[List[dict]] = None,
    ids: Optional[List[str]] = None,
    **kwargs: Any,
) -> List[str]:
    """Run more texts through the embeddings and add to the vectorstore.

    Args:
        texts: Iterable of strings to add to the vectorstore.
        metadatas: Optional list of metadatas associated with the texts.
        ids: Optional list of ids associated with the texts.
        kwargs: vectorstore specific parameters, forwarded to
            ``add_embeddings``.

    Returns:
        List of ids from adding the texts into the vectorstore.
    """
    # Materialize once: ``texts`` may be a generator, and it is needed both
    # for embedding and for storage. Reusing the raw iterable would hand an
    # already-exhausted iterator to ``add_embeddings``.
    texts = list(texts)
    embeddings = self.embedding_function.embed_documents(texts)
    return self.add_embeddings(
        texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs
    )
def similarity_search(
    self,
    query: str,
    k: int = 4,
    filter: Optional[dict] = None,
    **kwargs: Any,
) -> List[Document]:
    """Embed ``query`` and return the documents closest to it.

    Args:
        query (str): Query text to search for.
        k (int): Number of results requested. Defaults to 4.
        filter (Optional[Dict[str, str]]): Accepted for interface
            compatibility; it is not forwarded to the vector search.
        kwargs: Accepted for interface compatibility; not forwarded.

    Returns:
        List of Documents most similar to the query.
    """
    query_vector = self.embedding_function.embed_query(text=query)
    documents = self.similarity_search_by_vector(embedding=query_vector, k=k)
    return documents
def similarity_search_with_score(
    self,
    query: str,
    k: int = 4,
    filter: Optional[dict] = None,
) -> List[Tuple[Document, float]]:
    """Embed ``query`` and return the closest documents with their scores.

    Args:
        query: Text to look up documents similar to.
        k: Number of Documents requested. Defaults to 4.
        filter (Optional[Dict[str, str]]): Accepted for interface
            compatibility; it is not forwarded to the vector search.

    Returns:
        List of ``(document, score)`` pairs for the documents most similar
        to the query.
    """
    return self.similarity_search_with_score_by_vector(
        embedding=self.embedding_function.embed_query(query),
        k=k,
    )
@property
def distance_strategy(self) -> str:
    """Return the SQL distance operator for the configured strategy.

    This is a property because the query builder interpolates
    ``self.distance_strategy`` directly into SQL text; as a plain method,
    the bound-method repr would be interpolated in place of the operator.

    Returns:
        ``"<->"`` for EUCLIDEAN, ``"<=>"`` for COSINE, ``"<#>"`` for
        MAX_INNER_PRODUCT.

    Raises:
        ValueError: If the configured strategy is none of the above.
    """
    if self._distance_strategy == DistanceStrategy.EUCLIDEAN:
        return "<->"
    if self._distance_strategy == DistanceStrategy.COSINE:
        return "<=>"
    if self._distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
        return "<#>"
    raise ValueError(
        f"Got unexpected value for distance: {self._distance_strategy}. "
        f"Should be one of {', '.join([ds.value for ds in DistanceStrategy])}."
    )
def similarity_search_with_score_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
) -> List[Tuple[Document, float]]:
    """Query the collection by vector and wrap the rows as scored documents.

    Args:
        embedding: Query vector to compare stored rows against.
        k: Number of documents requested; passed through to the query.

    Returns:
        List of ``(document, score)`` pairs built from the query rows.
    """
    return self._results_to_docs_and_scores(
        self.__query_collection(embedding=embedding, k=k)
    )
def _fetch_title(title: str, abstract: str): | |
if len(title) > 0: | |
return title | |
return abstract.split(".")[0] | |
def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, float]]:
    """Convert query result rows into ``(Document, score)`` pairs.

    Each row is a mapping with the keys ``id``, ``title``, ``doi``,
    ``hal_id``, ``abstract``, ``keywords``, ``authors`` and ``distance``;
    ``title`` and ``abstract`` are indexed with ``[0]``. The page content is
    a JSON string holding the title, authors and keywords; the remaining
    fields go into the document metadata.

    Args:
        results: Iterable of row mappings as described above.

    Returns:
        One ``(Document, distance)`` pair per row. The score is ``None``
        when no embedding function is configured.
    """
    scored_documents = []
    for row in results:
        content = json.dumps(
            {
                "title": self._fetch_title(row["title"][0], row["abstract"][0]),
                "authors": row["authors"],
                "keywords": row["keywords"],
            }
        )
        document = Document(
            page_content=content,
            metadata={
                "id": row["id"],
                "doi": row["doi"],
                "hal_id": row["hal_id"],
                "distance": row["distance"],
                "abstract": row["abstract"][0],
            },
        )
        if self.embedding_function is not None:
            score = row["distance"]
        else:
            score = None
        scored_documents.append((document, score))
    return scored_documents
def __query_collection(
    self,
    embedding: List[float],
    k: int = 4,
) -> List[Any]:
    """Fetch the articles whose English abstract embedding is near ``embedding``.

    Rows are kept when the abstract is neither empty nor the literal string
    ``'None'`` and the distance is below the threshold configured for the
    active distance strategy. Results are ordered by distance.

    NOTE(review): ``k`` is accepted but not applied; the statement returns
    at most 100 rows (hard-coded ``LIMIT 100``). Confirm whether callers
    rely on this before wiring ``k`` into the query.

    Args:
        embedding: Query vector.
        k: Number of rows requested (currently unused, see note above).

    Returns:
        A list of dicts with the keys ``id``, ``title``, ``doi``,
        ``hal_id``, ``abstract``, ``keywords``, ``authors`` and
        ``distance``. ``title`` and ``abstract`` are post-processed with
        ``str_to_list`` (presumably turning a stored string into a list;
        verify against ``utils.str_to_list``).

    Raises:
        ValueError: If the configured distance strategy is not supported.
    """
    limit = distance_strategy_limit[self._distance_strategy]

    # The operator is spliced into the SQL text, so it must be the operator
    # string itself. ``distance_strategy`` is read as an attribute, but a
    # plain method is tolerated too so this block works either way.
    operator = self.distance_strategy
    if callable(operator):
        operator = operator()

    # Only the operator (one of a fixed set of strings) is interpolated.
    # The vector and the threshold are bound parameters, so they are passed
    # to the driver as data and the large literal is not formatted by hand.
    statement = text(
        f"""
        select
            a.id,
            a.title_en,
            a.doi,
            a.hal_id,
            a.abstract_en,
            string_agg(distinct keyword."name", ', ') as keywords,
            string_agg(distinct author."name", ', ') as authors,
            abstract_embedding_en {operator} CAST(:embedding AS vector) as distance
        from article a
        left join article_keyword ON article_keyword.article_id = a.id
        left join keyword on article_keyword.keyword_id = keyword.id
        left join article_author ON article_author.article_id = a.id
        left join author on author.id = article_author.author_id
        where
            abstract_en != '' and
            abstract_en != 'None' and
            abstract_embedding_en {operator} CAST(:embedding AS vector) < :limit
        GROUP BY a.id
        ORDER BY distance
        LIMIT 100;
        """
    )

    with Session(self._conn) as session:
        rows = session.execute(
            statement,
            {"embedding": str(embedding), "limit": limit},
        ).fetchall()

    # Column names follow the select list order; ``title_en`` and
    # ``abstract_en`` are exposed as ``title`` and ``abstract``.
    frame = pd.DataFrame(
        rows,
        columns=[
            "id",
            "title",
            "doi",
            "hal_id",
            "abstract",
            "keywords",
            "authors",
            "distance",
        ],
    )
    frame["abstract"] = frame["abstract"].apply(str_to_list)
    frame["title"] = frame["title"].apply(str_to_list)
    return frame.to_dict(orient="records")
def similarity_search_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    **kwargs: Any,
) -> List[Document]:
    """Return the documents most similar to an embedding vector.

    Args:
        embedding: Embedding to look up documents similar to.
        k: Number of Documents requested. Defaults to 4.
        kwargs: Accepted for interface compatibility; not forwarded.

    Returns:
        List of Documents most similar to the query vector, without scores.
    """
    scored = self.similarity_search_with_score_by_vector(embedding=embedding, k=k)
    # Keep only the document of every ``(document, score)`` pair.
    return [document for document, _score in scored]
@classmethod
def from_texts(
    cls: Type[CustomVectorStore],
    texts: List[str],
    embedding: Embeddings,
    metadatas: Optional[List[dict]] = None,
    table_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
    distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
    ids: Optional[List[str]] = None,
    pre_delete_collection: bool = False,
    **kwargs: Any,
) -> CustomVectorStore:
    """Return a VectorStore initialized from texts and their embeddings.

    Declared as a class method: the first parameter is ``cls`` and the
    method is meant to be called on the class itself.

    Args:
        texts: Texts to embed and store.
        embedding: Embedding model used to embed ``texts``.
        metadatas: Optional list of metadata dicts, one per text.
        table_name: Collection name. Defaults to ``"langchain"``.
        distance_strategy: Distance metric. Defaults to COSINE.
        ids: Optional list of ids, one per text.
        pre_delete_collection: Forwarded unchanged to the private builder.
        kwargs: Extra arguments forwarded unchanged to the private builder.
            NOTE(review): presumably this is how the database connection is
            supplied; the builder ``__from`` is not visible here - confirm.

    Returns:
        Whatever the private ``__from`` builder returns for these inputs.
    """
    embeddings = embedding.embed_documents(list(texts))

    return cls.__from(
        texts,
        embeddings,
        embedding,
        metadatas=metadatas,
        ids=ids,
        table_name=table_name,
        distance_strategy=distance_strategy,
        pre_delete_collection=pre_delete_collection,
        **kwargs,
    )