Spaces:
Running
on
Zero
Running
on
Zero
import yaml | |
from easydict import EasyDict as edict | |
from langchain.prompts import PromptTemplate | |
from langchain_community.chat_models import ChatOpenAI | |
from pathlib import Path | |
from langchain.llms.huggingface_pipeline import HuggingFacePipeline | |
from langchain_community.chat_models import AzureChatOpenAI | |
from langchain.chains import LLMChain | |
import logging | |
LLM_ENV = yaml.safe_load(open('config/llm_env.yml', 'r')) | |
class Color: | |
RED = '\033[91m' | |
GREEN = '\033[92m' | |
YELLOW = '\033[93m' | |
BLUE = '\033[94m' | |
END = '\033[0m' # Reset to default color | |
def get_llm(config: dict): | |
""" | |
Returns the LLM model | |
:param config: dictionary with the configuration | |
:return: The llm model | |
""" | |
if 'temperature' not in config: | |
temperature = 0 | |
else: | |
temperature = config['temperature'] | |
if 'model_kwargs' in config: | |
model_kwargs = config['model_kwargs'] | |
else: | |
model_kwargs = {} | |
if config['type'] == 'OpenAI': | |
if LLM_ENV['openai']['OPENAI_ORGANIZATION'] == '': | |
return ChatOpenAI(temperature=temperature, model_name=config['name'], | |
openai_api_key=config.get('openai_api_key', LLM_ENV['openai']['OPENAI_API_KEY']), | |
openai_api_base=config.get('openai_api_base', LLM_ENV['openai']['OPENAI_API_BASE']), | |
model_kwargs=model_kwargs) | |
else: | |
return ChatOpenAI(temperature=temperature, model_name=config['name'], | |
openai_api_key=config.get('openai_api_key', LLM_ENV['openai']['OPENAI_API_KEY']), | |
openai_api_base=config.get('openai_api_base', 'https://api.openai.com/v1'), | |
openai_organization=config.get('openai_organization', LLM_ENV['openai']['OPENAI_ORGANIZATION']), | |
model_kwargs=model_kwargs) | |
elif config['type'] == 'Azure': | |
return AzureChatOpenAI(temperature=temperature, azure_deployment=config['name'], | |
openai_api_key=config.get('openai_api_key', LLM_ENV['azure']['AZURE_OPENAI_API_KEY']), | |
azure_endpoint=config.get('azure_endpoint', LLM_ENV['azure']['AZURE_OPENAI_ENDPOINT']), | |
openai_api_version=config.get('openai_api_version', LLM_ENV['azure']['OPENAI_API_VERSION'])) | |
elif config['type'] == 'Google': | |
from langchain_google_genai import ChatGoogleGenerativeAI | |
return ChatGoogleGenerativeAI(temperature=temperature, model=config['name'], | |
google_api_key=LLM_ENV['google']['GOOGLE_API_KEY'], | |
model_kwargs=model_kwargs) | |
elif config['type'] == 'HuggingFacePipeline': | |
device = config.get('gpu_device', -1) | |
device_map = config.get('device_map', None) | |
return HuggingFacePipeline.from_model_id( | |
model_id=config['name'], | |
task="text-generation", | |
pipeline_kwargs={"max_new_tokens": config['max_new_tokens']}, | |
device=device, | |
device_map=device_map | |
) | |
else: | |
raise NotImplementedError("LLM not implemented") | |
def load_yaml(yaml_path: str, as_edict: bool = True) -> edict: | |
""" | |
Reads the yaml file and enrich it with more vales. | |
:param yaml_path: The path to the yaml file | |
:param as_edict: If True, returns an EasyDict configuration | |
:return: An EasyDict configuration | |
""" | |
with open(yaml_path, 'r') as file: | |
yaml_data = yaml.safe_load(file) | |
if 'meta_prompts' in yaml_data.keys() and 'folder' in yaml_data['meta_prompts']: | |
yaml_data['meta_prompts']['folder'] = Path(yaml_data['meta_prompts']['folder']) | |
if as_edict: | |
yaml_data = edict(yaml_data) | |
return yaml_data | |
def load_prompt(prompt_path: str) -> PromptTemplate: | |
""" | |
Reads and returns the contents of a prompt file. | |
:param prompt_path: The path to the prompt file | |
""" | |
with open(prompt_path, 'r') as file: | |
prompt = file.read().rstrip() | |
return PromptTemplate.from_template(prompt) | |
def validate_generation_config(base_config, generation_config): | |
if "annotator" not in generation_config: | |
raise Exception("Generation config must contain an empty annotator.") | |
if "label_schema" not in generation_config.dataset or \ | |
base_config.dataset.label_schema != generation_config.dataset.label_schema: | |
raise Exception("Generation label schema must match the basic config.") | |
def modify_input_for_ranker(config, task_description, initial_prompt): | |
modifiers_config = yaml.safe_load(open('prompts/modifiers/modifiers.yml', 'r')) | |
task_desc_setup = load_prompt(modifiers_config['ranker']['task_desc_mod']) | |
init_prompt_setup = load_prompt(modifiers_config['ranker']['prompt_mod']) | |
llm = get_llm(config.llm) | |
task_llm_chain = LLMChain(llm=llm, prompt=task_desc_setup) | |
task_result = task_llm_chain( | |
{"task_description": task_description}) | |
mod_task_desc = task_result['text'] | |
logging.info(f"Task description modified for ranking to: \n{mod_task_desc}") | |
prompt_llm_chain = LLMChain(llm=llm, prompt=init_prompt_setup) | |
prompt_result = prompt_llm_chain({"prompt": initial_prompt, 'label_schema': config.dataset.label_schema}) | |
mod_prompt = prompt_result['text'] | |
logging.info(f"Initial prompt modified for ranking to: \n{mod_prompt}") | |
return mod_prompt, mod_task_desc | |
def override_config(override_config_file, config_file='config/config_default.yml'): | |
""" | |
Override the default configuration file with the override configuration file | |
:param config_file: The default configuration file | |
:param override_config_file: The override configuration file | |
""" | |
def override_dict(config_dict, override_config_dict): | |
for key, value in override_config_dict.items(): | |
if isinstance(value, dict): | |
if key not in config_dict: | |
config_dict[key] = value | |
else: | |
override_dict(config_dict[key], value) | |
else: | |
config_dict[key] = value | |
return config_dict | |
default_config_dict = load_yaml(config_file, as_edict=False) | |
override_config_dict = load_yaml(override_config_file, as_edict=False) | |
config_dict = override_dict(default_config_dict, override_config_dict) | |
return edict(config_dict) | |