import yaml
from easydict import EasyDict as edict
from langchain.prompts import PromptTemplate
from langchain_community.chat_models import ChatOpenAI
from pathlib import Path
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain_community.chat_models import AzureChatOpenAI
from langchain.chains import LLMChain
import logging

with open('config/llm_env.yml', 'r') as llm_env_file:
    LLM_ENV = yaml.safe_load(llm_env_file)


class Color:
    RED = '\033[91m'
    GREEN = '\033[92m'
    YELLOW = '\033[93m'
    BLUE = '\033[94m'
    END = '\033[0m'  # Reset to default color
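    # Illustrative usage (not part of the original file):
    #   print(f"{Color.GREEN}All good{Color.END}")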


def get_llm(config: dict):
    """
    Builds and returns the LLM specified by the configuration.
    :param config: dictionary with the LLM configuration
    :return: The LLM model
    """
    temperature = config.get('temperature', 0)
    model_kwargs = config.get('model_kwargs', {})
    if config['type'] == 'OpenAI':
        if LLM_ENV['openai']['OPENAI_ORGANIZATION'] == '':
            return ChatOpenAI(temperature=temperature, model_name=config['name'],
                              openai_api_key=config.get('openai_api_key', LLM_ENV['openai']['OPENAI_API_KEY']),
                              openai_api_base=config.get('openai_api_base', LLM_ENV['openai']['OPENAI_API_BASE']),
                              model_kwargs=model_kwargs)
        else:
            return ChatOpenAI(temperature=temperature, model_name=config['name'],
                              openai_api_key=config.get('openai_api_key', LLM_ENV['openai']['OPENAI_API_KEY']),
                              openai_api_base=config.get('openai_api_base', 'https://api.openai.com/v1'),
                              openai_organization=config.get('openai_organization', LLM_ENV['openai']['OPENAI_ORGANIZATION']),
                              model_kwargs=model_kwargs)
    elif config['type'] == 'Azure':
        return AzureChatOpenAI(temperature=temperature, azure_deployment=config['name'],
                        openai_api_key=config.get('openai_api_key', LLM_ENV['azure']['AZURE_OPENAI_API_KEY']),
                        azure_endpoint=config.get('azure_endpoint', LLM_ENV['azure']['AZURE_OPENAI_ENDPOINT']),
                        openai_api_version=config.get('openai_api_version', LLM_ENV['azure']['OPENAI_API_VERSION']))

    elif config['type'] == 'Google':
        from langchain_google_genai import ChatGoogleGenerativeAI
        return ChatGoogleGenerativeAI(temperature=temperature, model=config['name'],
                              google_api_key=LLM_ENV['google']['GOOGLE_API_KEY'],
                              model_kwargs=model_kwargs)

    elif config['type'] == 'HuggingFacePipeline':
        device = config.get('gpu_device', -1)
        device_map = config.get('device_map', None)

        return HuggingFacePipeline.from_model_id(
            model_id=config['name'],
            task="text-generation",
            pipeline_kwargs={"max_new_tokens": config['max_new_tokens']},
            device=device,
            device_map=device_map
        )
    else:
        raise NotImplementedError("LLM not implemented")
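
# Illustrative usage sketch (not part of the original file): building an OpenAI-backed chat model
# from a minimal config dict. The model name is a placeholder; the API key, base URL and
# organization default to the values loaded from config/llm_env.yml.
#   llm = get_llm({'type': 'OpenAI', 'name': 'gpt-4o', 'temperature': 0})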


def load_yaml(yaml_path: str, as_edict: bool = True) -> edict:
    """
    Reads the yaml file and enriches it with additional values.
    :param yaml_path: The path to the yaml file
    :param as_edict: If True, returns the configuration as an EasyDict
    :return: The configuration
    """
    with open(yaml_path, 'r') as file:
        yaml_data = yaml.safe_load(file)
        if 'meta_prompts' in yaml_data.keys() and 'folder' in yaml_data['meta_prompts']:
            yaml_data['meta_prompts']['folder'] = Path(yaml_data['meta_prompts']['folder'])
    if as_edict:
        yaml_data = edict(yaml_data)
    return yaml_data
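
# Illustrative usage (not part of the original file), assuming the config defines
# dataset.label_schema as the functions below expect:
#   cfg = load_yaml('config/config_default.yml')
#   print(cfg.dataset.label_schema)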


def load_prompt(prompt_path: str) -> PromptTemplate:
    """
    Reads a prompt file and returns it as a PromptTemplate.
    :param prompt_path: The path to the prompt file
    :return: The loaded PromptTemplate
    """
    with open(prompt_path, 'r') as file:
        prompt = file.read().rstrip()
    return PromptTemplate.from_template(prompt)
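
# Illustrative usage (not part of the original file; the prompt path is hypothetical, real paths
# are listed in prompts/modifiers/modifiers.yml used below):
#   template = load_prompt('prompts/modifiers/ranker_task_desc.prompt')  # hypothetical path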


def validate_generation_config(base_config, generation_config):
    """
    Validates that the generation configuration is compatible with the base configuration.
    """
    if "annotator" not in generation_config:
        raise Exception("Generation config must contain an empty annotator.")
    if "label_schema" not in generation_config.dataset or \
            base_config.dataset.label_schema != generation_config.dataset.label_schema:
        raise Exception("Generation label schema must match the base config.")


def modify_input_for_ranker(config, task_description, initial_prompt):
    """
    Modifies the task description and the initial prompt so they are suitable for the ranker.
    :return: The modified prompt and the modified task description
    """
    with open('prompts/modifiers/modifiers.yml', 'r') as modifiers_file:
        modifiers_config = yaml.safe_load(modifiers_file)
    task_desc_setup = load_prompt(modifiers_config['ranker']['task_desc_mod'])
    init_prompt_setup = load_prompt(modifiers_config['ranker']['prompt_mod'])

    llm = get_llm(config.llm)
    task_llm_chain = LLMChain(llm=llm, prompt=task_desc_setup)
    task_result = task_llm_chain(
        {"task_description": task_description})
    mod_task_desc = task_result['text']
    logging.info(f"Task description modified for ranking to: \n{mod_task_desc}")

    prompt_llm_chain = LLMChain(llm=llm, prompt=init_prompt_setup)
    prompt_result = prompt_llm_chain({"prompt": initial_prompt, 'label_schema': config.dataset.label_schema})
    mod_prompt = prompt_result['text']
    logging.info(f"Initial prompt modified for ranking to: \n{mod_prompt}")

    return mod_prompt, mod_task_desc
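
# Illustrative usage (not part of the original file), assuming `config` came from override_config
# and contains an `llm` section plus `dataset.label_schema`:
#   mod_prompt, mod_task_desc = modify_input_for_ranker(
#       config,
#       task_description='Classify movie reviews as positive or negative',
#       initial_prompt='Does the following review express a positive sentiment?')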


def override_config(override_config_file, config_file='config/config_default.yml'):
    """
    Overrides the default configuration file with the override configuration file.
    :param override_config_file: The override configuration file
    :param config_file: The default configuration file
    :return: The merged configuration as an EasyDict
    """

    def override_dict(config_dict, override_config_dict):
        for key, value in override_config_dict.items():
            if isinstance(value, dict):
                if key not in config_dict:
                    config_dict[key] = value
                else:
                    override_dict(config_dict[key], value)
            else:
                config_dict[key] = value
        return config_dict

    default_config_dict = load_yaml(config_file, as_edict=False)
    override_config_dict = load_yaml(override_config_file, as_edict=False)
    config_dict = override_dict(default_config_dict, override_config_dict)
    return edict(config_dict)
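

# Illustrative usage (not part of the original file; the override file name is hypothetical):
#   config = override_config('config/my_experiment.yml')
#   llm = get_llm(config.llm)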