# Hosting-page residue captured together with this file (not Python code):
#   tykiww's picture
#   Update services/qa_service/qna.py
#   7929d48 verified
#   raw
#   history blame
#   2.46 kB
import json
from services.qa_service.utils import format_prompt
class QAService:
    """Answer a question from transcript snippets retrieved from Pinecone.

    The service embeds the question, queries a Pinecone index inside a single
    namespace, joins the matched snippets into one context string and passes
    the question plus that context to the model pipeline.

    It can be used as a context manager; entering and exiting only print
    status messages.
    """

    def __init__(self, conf, pinecone, model_pipeline, question, goals, session_key, keycheck):
        """Store the collaborators and open the Pinecone connection.

        Args:
            conf: Nested configuration mapping. ``conf["embeddings"]`` must
                provide ``"index_name"`` and, when ``keycheck`` is falsy,
                ``"demo_namespace"``.
            pinecone: Object whose ``run(namespace=...)`` returns a mapping
                with the keys ``"connection"`` and ``"embedder"``.
            model_pipeline: Object exposing ``infer(prompt)``.
            question: The question to answer.
            goals: Stored on the instance; not read by this class.
            session_key: Namespace to query when ``keycheck`` is truthy.
            keycheck: Truthy to use ``session_key``; falsy to fall back to
                the demo namespace from ``conf``.
        """
        self.conf = conf
        # The demo namespace is only looked up when no valid key was given,
        # so a config without it still works for keyed sessions.
        if keycheck:
            self.sess_key = session_key
        else:
            self.sess_key = self.conf["embeddings"]["demo_namespace"]
        self.pinecones = pinecone.run(namespace=self.sess_key)
        self.pc = self.pinecones['connection']
        self.embedder = self.pinecones['embedder']
        self.model_pipeline = model_pipeline
        self.question = question
        self.goals = goals

    def __enter__(self):
        """Announce start-up and return the service itself."""
        print("Start Q&A Service")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Announce shutdown; returns None, so exceptions are not suppressed."""
        print("Exiting Q&A Service")

    def parse_results(self, result):
        """Extract speakers and text from every match of a query result.

        Args:
            result: Query response supporting ``result["matches"]``. Each
                match carries ``["metadata"]["_node_content"]``, a JSON
                string with ``"text"`` and ``["metadata"]["speakers"]``.

        Returns:
            A list of ``{"speakers": ..., "text": ...}`` dicts in match order.

        Raises:
            KeyError: If an expected key is missing.
            json.JSONDecodeError: If ``_node_content`` is not valid JSON.
        """
        parsed = []
        for match in result['matches']:
            # The node is stored serialized, so decode it before reading.
            content = json.loads(match['metadata']['_node_content'])
            parsed.append({
                "speakers": content["metadata"]["speakers"],
                "text": content["text"],
            })
        return parsed

    def retrieve_context(self):
        """Embed the question and return the parsed top-5 Pinecone matches."""
        embedded_query = self.embedder.get_text_embedding(self.question)
        index_name = self.conf['embeddings']['index_name']
        # f-strings keep the debug output from raising on a non-string key.
        print(f"session key: {self.sess_key}")
        print(f"index name: {index_name}")
        index = self.pc.Index(index_name)
        result = index.query(
            vector=embedded_query,
            namespace=self.sess_key,  # restrict the search to this session's namespace
            top_k=5,
            include_values=False,
            include_metadata=True,
        )
        return self.parse_results(result)

    @staticmethod
    def _build_context(snippets):
        """Join snippets into one numbered, blank-line-separated string.

        Each snippet becomes ``"Transcript <n>"`` (1-based) on its own line
        followed by the snippet's ``"text"``; an empty input yields ``""``.
        """
        return "\n\n".join(
            f"Transcript {number}\n{snippet['text']}"
            for number, snippet in enumerate(snippets, start=1)
        )

    def run(self):
        """Retrieve context for the question and run model inference.

        Returns:
            A tuple ``(output, context)``: the result of
            ``model_pipeline.infer`` and the context string it was given.
        """
        full_context = self.retrieve_context()
        # The previous loop added an int to a str and subscripted the loop
        # index, so it raised TypeError whenever there was a match.
        context = self._build_context(full_context)
        prompt = format_prompt(self.question, context)
        output = self.model_pipeline.infer(prompt)
        return output, context