File size: 1,752 Bytes
b3fc90b
 
82fdbb3
 
b579ccf
 
b3fc90b
1fac75f
b3fc90b
 
0cd2aee
b3fc90b
23e9e9c
b3fc90b
1fac75f
b3fc90b
 
 
 
 
 
 
0cd2aee
 
 
8f62329
b579ccf
 
 
 
5794520
b579ccf
 
8f62329
b579ccf
b3fc90b
 
 
 
 
0cd2aee
b3fc90b
3d397a8
b3fc90b
 
 
0cd2aee
c99c9ac
b3fc90b
 
 
 
b579ccf
97fb67e
a91dc62
 
6afbc36
 
b3fc90b
a91dc62
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import json

from services.qa_service.utils import format_prompt



class QAService:
    """Retrieval-augmented Q&A service.

    Embeds a question, retrieves the most relevant context chunks from a
    Pinecone index, and runs a model pipeline on a prompt built from that
    context. Usable as a context manager (``with QAService(...) as qa:``).
    """

    def __init__(self, conf, pinecone, model_pipeline, question, goals):
        """
        Args:
            conf: Config mapping; must provide ``conf['embeddings']['index_name']``.
            pinecone: Mapping with ``'connection'`` (Pinecone client) and
                ``'embedder'`` (object exposing ``get_text_embedding``).
            model_pipeline: Object exposing ``infer(prompt)``.
            question: The user question string to answer.
            goals: Caller-supplied goals (stored; not used within this class).
        """
        self.conf = conf
        self.pc = pinecone['connection']
        self.pc_index = self.pc.Index(self.conf['embeddings']['index_name'])
        self.embedder = pinecone['embedder']
        self.model_pipeline = model_pipeline
        self.question = question
        self.goals = goals

    def __enter__(self):
        print("Start Q&A Service")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        print("Exiting Q&A Service")
        return False  # never suppress exceptions raised inside the with-block

    def parse_results(self, result):
        """Flatten Pinecone query matches into speaker/text records.

        Args:
            result: Pinecone query response; matches carry a JSON-encoded
                ``'_node_content'`` payload in their metadata.

        Returns:
            list[dict]: One ``{'speakers': ..., 'text': ...}`` dict per match.
                Empty list when the response has no matches.

        Raises:
            KeyError / json.JSONDecodeError: On malformed match metadata.
        """
        parsed = []
        # Tolerate a response without a 'matches' key instead of raising.
        for match in result.get('matches', []):
            # '_node_content' is a JSON-serialized node payload; decode it and
            # keep only the fields downstream consumers need.
            content = json.loads(match['metadata']['_node_content'])
            parsed.append({
                "speakers": content["metadata"]["speakers"],
                "text": content["text"],
            })
        return parsed

    def retrieve_context(self, top_k=5):
        """Embed the question and fetch the nearest chunks from Pinecone.

        Args:
            top_k: Number of nearest matches to retrieve (default 5,
                matching the previous hard-coded value).

        Returns:
            list[dict]: Parsed matches from :meth:`parse_results`.
        """
        embedded_query = self.embedder.get_text_embedding(self.question)

        result = self.pc_index.query(
            vector=embedded_query,
            top_k=top_k,
            include_values=False,
            include_metadata=True,
        )

        return self.parse_results(result)

    def run(self):
        """Retrieve context, build a prompt, and run model inference.

        Returns:
            tuple: ``(output, context)`` — the model's inference result and
                the newline-joined context string that was fed to the prompt.
        """
        full_context = self.retrieve_context()

        context = '\n'.join(chunk["text"] for chunk in full_context)

        prompt = format_prompt(self.question, context)
        output = self.model_pipeline.infer(prompt)

        return output, context