File size: 1,953 Bytes
74b7d77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import json
from src.prompts.abstractPromptStrategy import AbstractPromptStrategy


class JsonPromptSegment:
    """
    One prompt entry taken from the JSON prompt specification.

    Attributes
    ----------
        prompt: str
            The prompt text configured for the segment.
        segment_index: int
            The segment index this entry was declared for.
        format_prompt: bool
            When true, JsonPromptStrategy treats `prompt` as a str.format
            template instead of concatenating it with the whisper prompt.
    """

    def __init__(self, segment_index: int, prompt: str, format_prompt: bool = False):
        # Assigned in the same order as before so instance __dict__ ordering is unchanged.
        self.prompt, self.segment_index, self.format_prompt = prompt, segment_index, format_prompt

class JsonPromptStrategy(AbstractPromptStrategy):
    """Prompt strategy that looks up a per-segment prompt from a JSON specification."""

    def __init__(self, initial_json_prompt: str):
        """
        Parameters
        ----------
            initial_json_prompt: str
                The initial prompts for each segment in JSON form.

                Format:
                [
                    {"segment_index": 0, "prompt": "Hello, how are you?"},
                    {"segment_index": 1, "prompt": "I'm doing well, how are you?"},
                    {"segment_index": 2, "prompt": "{0} Fine, thank you.", "format_prompt": true}
                ]

                "segment_index" and "prompt" are required in every entry;
                "format_prompt" is optional and defaults to false. If the same
                index appears more than once, the last entry wins.
        """
        parsed_json = json.loads(initial_json_prompt)
        # Keys are stringified, so an index written as 0 or "0" in the JSON
        # maps to the same entry (lookups below stringify as well).
        self.segment_lookup: dict[str, JsonPromptSegment] = {}

        for prompt_entry in parsed_json:
            segment_index = prompt_entry["segment_index"]
            prompt = prompt_entry["prompt"]
            format_prompt = prompt_entry.get("format_prompt", False)
            self.segment_lookup[str(segment_index)] = JsonPromptSegment(segment_index, prompt, format_prompt)

    def get_segment_prompt(self, segment_index: int, whisper_prompt: str, detected_language: str) -> str:
        """
        Return the prompt to use for the given segment.

        Parameters
        ----------
            segment_index: int
                Index of the segment being processed.
            whisper_prompt: str
                The prompt supplied by the caller for this segment (possibly None).
                It is returned unchanged when no entry matches the segment.
            detected_language: str
                Not used by this strategy.

        Returns
        -------
            str
                If the matching entry has "format_prompt" set, its prompt is used as a
                str.format template with the whisper prompt as positional argument {0}.
                An empty string is substituted when the whisper prompt is None.
                Otherwise the entry's prompt is combined with the whisper prompt via
                the inherited _concat_prompt.
        """
        segment = self.segment_lookup.get(str(segment_index))

        if segment is None:
            # No prompt configured for this segment: fall back to the caller's prompt.
            print(f"Could not find prompt for segment {segment_index}, returning whisper prompt")
            return whisper_prompt

        if segment.format_prompt:
            # str.format would render None as the literal text "None"; use "" instead.
            return segment.prompt.format(whisper_prompt if whisper_prompt is not None else "")

        return self._concat_prompt(segment.prompt, whisper_prompt)