chat-llm-streaming / retrieval.py
star-nox
fixed pinecone
d87617b
raw
history blame
No virus
2.64 kB
import json
import logging
import os
import pathlib
import sys
import time
from typing import Any, Dict, List, Optional

import pinecone  # cloud-hosted vector database for context retrieval
from dotenv import load_dotenv
# for vector search
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Pinecone
from PIL import Image
from transformers import (AutoModelForSequenceClassification, AutoTokenizer, GPT2Tokenizer, OPTForCausalLM, T5ForConditionalGeneration)
class Retrieval:
    """Retrieve textbook passages relevant to a user question from a Pinecone index.

    Constructing the object does no network I/O. The Pinecone-backed LangChain
    vector store is created on first retrieval, or by an explicit call to
    ``_load_pinecone_vectorstore``.
    """

    def __init__(self,
                 device='cuda',
                 use_clip=True):
        """Store retrieval configuration.

        Args:
            device: Device string kept on the instance (not read by any method
                of this class).
            use_clip: Flag kept on the instance (not read by any method of
                this class).
        """
        self.user_question = ''
        self.max_text_length = None
        self.pinecone_index_name = 'uiuc-chatbot'  # alternative index: uiuc-chatbot-v2
        self.pinecone_environment = "us-west1-gcp"
        self.use_clip = use_clip
        # init parameters
        self.device = device
        # Default number of contexts returned when `topk` is not given.
        self.num_answers_generated = 3
        # LangChain Pinecone vector store; populated by _load_pinecone_vectorstore().
        self.vectorstore = None

    def _load_pinecone_vectorstore(self,):
        """Connect to Pinecone and build the LangChain vector store in ``self.vectorstore``.

        The API key is read from the ``PINECONE_API_KEY`` environment variable.
        A ``.env`` file is loaded first, if present.

        Raises:
            RuntimeError: if ``PINECONE_API_KEY`` is not set.
        """
        logger = logging.getLogger(__name__)
        # Populate os.environ from a .env file; already-set variables are not overridden.
        load_dotenv()
        api_key = os.environ.get('PINECONE_API_KEY')
        if not api_key:
            # Fail fast, before downloading/loading the embedding model.
            raise RuntimeError("PINECONE_API_KEY is not set; export it or add it to a .env file.")

        model_name = "intfloat/e5-large"  # text embedding model, 1024-dim vectors
        embeddings = HuggingFaceEmbeddings(model_name=model_name)

        pinecone.init(api_key=api_key, environment=self.pinecone_environment)
        pinecone_index = pinecone.Index(self.pinecone_index_name)
        # These diagnostics are extra network round-trips; only make them when debugging.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug("Pinecone indexes: %s", pinecone.list_indexes())
            logger.debug("Pinecone index stats: %s", pinecone_index.describe_index_stats())
        self.vectorstore = Pinecone(index=pinecone_index, embedding_function=embeddings.embed_query, text_key="text")

    def retrieve_contexts_from_pinecone(self, user_question: str, topk: Optional[int] = None) -> List[str]:
        """Return the `topk` passages most similar to `user_question`, with source info.

        Vector search runs against the Pinecone index built in the notebooks
        `data_formatting_patel.ipynb` and `data_formatting_student_notes.ipynb`.
        The vector store is loaded on first use if it has not been loaded yet.

        Each returned string is the passage text followed by its source, e.g.
        ``"<page_content>. Source: page 12 in <textbook_name>"``. Every matched
        document must carry ``page_number`` and ``textbook_name`` metadata.

        Args:
            user_question: Query text to search with.
            topk: Number of passages to request; defaults to
                ``self.num_answers_generated``.

        Returns:
            List of formatted context strings, in the order the vector store
            returned the matching documents.
        """
        logger = logging.getLogger(__name__)
        if topk is None:
            topk = self.num_answers_generated
        logger.debug("Pinecone query: %r (topk=%s)", user_question, topk)

        # Nothing else in this class loads the store, so create it lazily.
        if self.vectorstore is None:
            self._load_pinecone_vectorstore()

        docs = self.vectorstore.similarity_search(user_question, k=topk)
        # Append the source info to the bottom of each context.
        return [
            f"{doc.page_content}. Source: page {doc.metadata['page_number']} in {doc.metadata['textbook_name']}"
            for doc in docs
        ]