tykiww's picture
Update services/qa_service/qna.py
82fdbb3 verified
raw
history blame
1.33 kB
import json
from services.qa_service.utils import format_prompt
class QAService:
    """Answer a question by retrieving context from Pinecone and running a model.

    Usable as a context manager (``with QAService(...) as qa:``); entering
    and exiting only print status messages.
    """

    def __init__(self, conf, pinecone, model_pipeline, question, context):
        """Store the collaborators needed for a single Q&A request.

        Args:
            conf: Configuration mapping; ``conf['embeddings']['index_name']``
                names the Pinecone index that ``retrieve_context`` queries.
            pinecone: Mapping with a ``'connection'`` entry (an object exposing
                ``Index(name)``) and an ``'embedder'`` entry (an object
                exposing ``get_text_embedding(text)``).
            model_pipeline: Object exposing ``infer(prompt)``; used by ``run``.
            question: The question text that is embedded and answered.
            context: Stored on the instance but not read by any method of this
                class. NOTE(review): confirm whether external callers use it.
        """
        self.conf = conf
        self.pc = pinecone['connection']
        self.embedder = pinecone['embedder']
        self.model_pipeline = model_pipeline
        self.question = question
        self.context = context

    def __enter__(self):
        """Announce the start of the service and return this instance."""
        print("Start Q&A Service")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Announce the end of the service.

        Returns None (falsy), so any exception raised inside the ``with``
        body still propagates to the caller.
        """
        print("Exiting Q&A Service")

    def retrieve_context(self):
        """Embed the question, query Pinecone, and decode the best match.

        Returns:
            The JSON-decoded value stored under
            ``metadata['_node_content']`` of the single top match.

        Raises:
            IndexError: If the Pinecone query returns no matches.
        """
        embedded_query = self.embedder.get_text_embedding(self.question)
        # The config lives on the instance; a bare ``conf`` is not defined in
        # this scope and would raise NameError.
        pinecone_index = self.pc.Index(self.conf['embeddings']['index_name'])
        result = pinecone_index.query(
            vector=embedded_query,
            top_k=1,
            include_values=False,
            include_metadata=True,
        )
        matches = result['matches']
        if not matches:
            # Keep IndexError (what ``matches[0]`` raised before) so existing
            # handlers still work, but say why it happened.
            raise IndexError(
                "Pinecone query returned no matches for the question; "
                "no context available."
            )
        # The node content is stored as a JSON string in the match metadata.
        return json.loads(matches[0]['metadata']['_node_content'])

    def run(self):
        """Retrieve context, build a prompt from it, and run model inference.

        Returns:
            Whatever ``self.model_pipeline.infer`` returns for the prompt that
            ``format_prompt`` builds from the retrieved context.
        """
        output = self.retrieve_context()
        # presumably format_prompt turns the decoded node into the model's
        # prompt string - defined in services.qa_service.utils; confirm there.
        output = format_prompt(output)
        output = self.model_pipeline.infer(output)
        return output