# File size: 1,906 Bytes
# 697eefa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from pathlib import Path

import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
from transformers import pipeline
from preprocessing import stride_sentences
from fetch_transcript import zip_transcript


class Engine:
    """Question answering and semantic search over a timestamped transcript.

    Two local models are loaded from ``base_path``: a Hugging Face
    question-answering pipeline (``QA_Model``) and a SentenceTransformer
    (``Similarity_Model``). The transcript's texts are grouped with
    ``stride_sentences`` and every group is embedded once at construction
    time, so ``find_similar`` only has to encode the query.
    """

    def __init__(self, transcript: list, *, stride: int = 10,
                 base_path: 'str | Path' = './models') -> None:
        """Load both models and pre-compute embeddings for *transcript*.

        Args:
            transcript: Raw transcript, passed unchanged to ``zip_transcript``.
            stride: Value passed to ``stride_sentences``; also used to map a
                group index back to a timestamp index. Defaults to 10, the
                previously hard-coded value.
            base_path: Directory holding the ``QA_Model`` and
                ``Similarity_Model`` sub-directories. Defaults to ``./models``.
        """
        self.base_path = Path(base_path)

        self.qa_model_name = 'QA_Model'
        self.qa_model_path = self.base_path / self.qa_model_name
        self.qa_model = pipeline('question-answering', model=str(self.qa_model_path))

        self.sim_model_name = 'Similarity_Model'
        self.sim_model_path = self.base_path / self.sim_model_name
        # SentenceTransformer's ``model_name_or_path`` is typed as ``str``;
        # pass a string, exactly as is done for the QA pipeline above.
        self.sim_model = SentenceTransformer(str(self.sim_model_path))

        # NOTE(review): this unpacking relies on zip_transcript returning a
        # mapping with exactly two values, timestamps first and texts second
        # (dict insertion order) — confirm against fetch_transcript.
        self.timestamps, self.texts = zip_transcript(transcript).values()

        self.stride = stride
        self.text_groups = stride_sentences(self.texts, self.stride)

        self.embeddings = self._encode_transcript()

    def _encode_transcript(self):
        """Embed every text group with the similarity model."""
        return self.sim_model.encode(self.text_groups)

    def ask(self, question_text: str):
        """Answer *question_text* using the whole transcript as context.

        The text groups are joined with single spaces to form the context.
        Long contexts are windowed by the pipeline with ``doc_stride=256``.

        Returns:
            The ``'answer'`` string from the question-answering pipeline.
        """
        result = self.qa_model(
            question=question_text,
            context=' '.join(self.text_groups).strip(),
            doc_stride=256,
            max_answer_len=512,
            max_question_len=128,
        )
        return result['answer']

    def find_similar(self, txt: str, top_k: int = 1):
        """Return the ``top_k`` text groups most similar to *txt*.

        Similarity is the cosine similarity between the embedding of *txt*
        and each pre-computed group embedding.

        Returns:
            ``(groups, timestamps)``: the matching text groups ordered from
            most to least similar, and for each one the entry of
            ``self.timestamps`` at index ``stride * group_index``. Both lists
            are empty when ``top_k < 1``. If ``top_k`` exceeds the number of
            groups, every group is returned.
        """
        query_embedding = self.sim_model.encode(txt)
        similarities: torch.Tensor = cos_sim(query_embedding, self.embeddings)
        indices = self._top_k_indices(similarities, top_k)
        groups = [self.text_groups[i] for i in indices]
        # NOTE(review): assumes stride_sentences yields non-overlapping groups
        # of `stride` texts, so group i starts at text index stride * i —
        # confirm against preprocessing.stride_sentences.
        timestamps = [self.timestamps[self.stride * i] for i in indices]
        return groups, timestamps

    @staticmethod
    def _top_k_indices(similarities: torch.Tensor, top_k: int) -> list:
        """Return the indices of the ``top_k`` largest scores, highest first.

        *similarities* is flattened first, so a ``(1, N)`` score matrix works
        as well as a 1-D tensor. A non-positive ``top_k`` yields ``[]``;
        previously a negative value silently returned all but the last match.
        """
        if top_k < 1:
            return []
        order = torch.argsort(similarities.reshape(-1), descending=True)
        return order[:top_k].tolist()


if __name__ == '__main__':
    # Command-line entry point: build an Engine from a transcript file and
    # optionally answer a question about it. Engine requires a transcript;
    # the previous bare `Engine()` call always raised TypeError.
    import json
    import sys

    if len(sys.argv) < 2:
        sys.exit(f'usage: {sys.argv[0]} TRANSCRIPT_JSON [QUESTION ...]')

    # NOTE(review): assumes the JSON file holds the list that zip_transcript
    # expects — confirm the on-disk transcript format.
    with open(sys.argv[1], encoding='utf-8') as fh:
        transcript = json.load(fh)

    model = Engine(transcript)

    if len(sys.argv) > 2:
        question = ' '.join(sys.argv[2:])
        print(model.ask(question))