import asyncio
import concurrent.futures
import importlib
import importlib.util
import logging
from pathlib import Path
from typing import Callable, Optional

from langchain.chains import LLMChain
from langchain.chains.openai_functions import create_structured_output_runnable
from langchain_community.callbacks import get_openai_callback
from tqdm import tqdm, trange

from utils.config import get_llm, load_prompt


class DummyCallback:
    """
    A dummy callback for the LLM.
    This is a trick to handle an empty callback: it exposes the same
    context-manager interface and a `total_cost` attribute that is always 0.
    """

    def __enter__(self):
        self.total_cost = 0
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        pass


def get_dummy_callback():
    """
    Create a callback that tracks no cost.
    :return: A new DummyCallback instance
    """
    return DummyCallback()


class ChainWrapper:
    """
    A wrapper for a LLM chain
    """

    def __init__(self, llm_config, prompt_path: str, json_schema: Optional[dict] = None,
                 parser_func: Optional[Callable] = None):
        """
        Initialize a new instance of the ChainWrapper class.
        :param llm_config: The config for the LLM
        :param prompt_path: A path to the prompt file (text file)
        :param json_schema: A dict for the json schema, to get a structured output for the LLM
        :param parser_func: A function to parse the output of the LLM
        """
        self.llm_config = llm_config
        self.llm = get_llm(llm_config)
        self.json_schema = json_schema
        self.parser_func = parser_func
        self.prompt = load_prompt(prompt_path)
        self.build_chain()
        self.accumulate_usage = 0
        # Cost tracking is only wired up for the 'OpenAI' type; every other type gets a zero-cost callback.
        if self.llm_config.type == 'OpenAI':
            self.callback = get_openai_callback
        else:
            self.callback = get_dummy_callback

    def invoke(self, chain_input: dict) -> Optional[dict]:
        """
        Invoke the chain on a single input
        :param chain_input: The input for the chain
        :return: The chain output (passed through parser_func when one is defined),
                 or None if the chain or the parser raised an exception
        """
        with self.callback() as cb:
            try:
                result = self.chain.invoke(chain_input)
                if self.parser_func is not None:
                    result = self.parser_func(result)
            except Exception as e:
                # Best effort: a failed sample is reported and dropped instead of aborting the whole run.
                logging.error('Error in chain invoke: %s', e)
                result = None
            self.accumulate_usage += cb.total_cost
            return result

    async def retry_operation(self, tasks):
        """
        Wait for a group of async operations, polling every `retry_interval` seconds
        for at most `retry_interval * max_retries` seconds in total.
        :param tasks: The awaitables (coroutines, tasks or futures) to wait for
        :return: The tasks that finished within the time budget, in the order they were given
        """
        # asyncio.wait no longer accepts bare coroutines (Python 3.11+), so wrap everything in a future.
        # ensure_future returns an existing Task/Future unchanged.
        scheduled = [asyncio.ensure_future(task) for task in tasks]
        if not scheduled:
            # asyncio.wait raises ValueError on an empty collection.
            return []

        delay = self.llm_config.async_params.retry_interval
        timeout = delay * self.llm_config.async_params.max_retries
        loop = asyncio.get_running_loop()
        end_time = loop.time() + timeout

        finished = set()
        waiting = scheduled
        while True:
            remaining_time = end_time - loop.time()
            if remaining_time <= 0:
                print("Timeout reached. Operation incomplete.")
                break
            done, pending = await asyncio.wait(waiting, timeout=delay)
            finished |= done
            if not pending:
                print("All tasks completed successfully.")
                break
            waiting = list(pending)  # Keep waiting on the tasks that are still pending

        # `done` is an unordered set; rebuild the result in submission order.
        return [task for task in scheduled if task in finished]

    async def async_batch_invoke(self, inputs: list[dict]) -> list[Optional[dict]]:
        """
        Invoke the chain on a batch of inputs in async mode
        :param inputs: A batch of inputs
        :return: A list with one entry per input that finished before the timeout, in input order.
                 An entry is None if its invocation or parsing raised an exception.
        """
        with self.callback() as cb:
            tasks = [asyncio.ensure_future(self.chain.ainvoke(chain_input)) for chain_input in inputs]
            finished_tasks = await self.retry_operation(tasks)
            self.accumulate_usage += cb.total_cost

        results = []
        for task in finished_tasks:
            try:
                result = task.result()
                if self.parser_func is not None:
                    result = self.parser_func(result)
            except Exception as e:
                # Same best-effort policy as `invoke`: report the failure and yield None for this sample.
                logging.error('Error in chain async invoke: %s', e)
                result = None
            results.append(result)
        return results

    def batch_invoke(self, inputs: list[dict], num_workers: int):
        """
        Invoke the chain on a batch of inputs either async or not
        :param inputs: The list of all inputs
        :param num_workers: The number of workers (thread count in regular mode, batch size in async mode)
        :return: A list of results, without the failed (None) ones
        """

        def process_sample_with_progress(sample):
            result = self.invoke(sample)
            pbar.update(1)  # Update the progress bar
            return result

        if 'async_params' not in self.llm_config:
            # non async mode, use regular workers
            with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
                with tqdm(total=len(inputs), desc="Processing samples") as pbar:
                    all_results = list(executor.map(process_sample_with_progress, inputs))
        else:
            all_results = []
            for i in trange(0, len(inputs), num_workers, desc='Predicting'):
                all_results += asyncio.run(self.async_batch_invoke(inputs[i:i + num_workers]))
        return [res for res in all_results if res is not None]

    def build_chain(self):
        """
        Build the chain according to the LLM type
        """
        # A structured-output runnable is only built for the 'OpenAI' / 'Azure' types when a schema is given.
        if self.llm_config.type in ('OpenAI', 'Azure') and self.json_schema is not None:
            self.chain = create_structured_output_runnable(self.json_schema, self.llm, self.prompt)
        else:
            self.chain = LLMChain(llm=self.llm, prompt=self.prompt)


def get_chain_metadata(prompt_fn: Path, retrieve_module: bool = False) -> dict:
    """
    Get the metadata of the chain
    :param prompt_fn: The path to the prompt file
    :param retrieve_module: If True, retrieve the module
    :return: A dict with the keys 'json_schema' and 'parser_func' (each None when the
             `output_schemes.py` module next to the prompt does not define
             `<prompt stem>_schema` / `<prompt stem>_parser`), plus 'module' when
             retrieve_module is True (None if the module failed to import)
    """
    prompt_directory = str(prompt_fn.parent)
    prompt_name = str(prompt_fn.stem)

    # Stays None when loading fails, so the lookups below fall back to None instead of raising.
    schema_parser = None
    try:
        spec = importlib.util.spec_from_file_location('output_schemes', prompt_directory + '/output_schemes.py')
        loaded_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(loaded_module)
        schema_parser = loaded_module
    except ImportError as e:
        print(f"Error loading module {prompt_directory + '/output_schemes'}: {e}")

    result = {
        'json_schema': getattr(schema_parser, '{}_schema'.format(prompt_name), None),
        'parser_func': getattr(schema_parser, '{}_parser'.format(prompt_name), None),
    }
    if retrieve_module:
        result['module'] = schema_parser
    return result


class MetaChain:
    """
    A wrapper for the meta-prompts chain
    """

    def __init__(self, config):
        """
        Initialize a new instance of the MetaChain class. Loading all the meta-prompts
        :param config: An EasyDict configuration
        """
        self.config = config
        self.initial_chain = self.load_chain('initial')
        self.step_prompt_chain = self.load_chain('step_prompt')
        self.step_samples = self.load_chain('step_samples')
        self.error_analysis = self.load_chain('error_analysis')

    def load_chain(self, chain_name: str) -> ChainWrapper:
        """
        Load a chain according to the chain name
        :param chain_name: The name of the chain
        :return: A ChainWrapper built from `<meta_prompts.folder>/<chain_name>.prompt`
        """
        prompt_path = self.config.meta_prompts.folder / '{}.prompt'.format(chain_name)
        metadata = get_chain_metadata(prompt_path)
        return ChainWrapper(self.config.llm, prompt_path, metadata['json_schema'], metadata['parser_func'])

    def calc_usage(self) -> float:
        """
        Calculate the usage of the meta-prompts
        :return: The total usage value
        """
        # NOTE(review): the usage of `self.error_analysis` is not part of this sum - confirm this is intended.
        return (self.initial_chain.accumulate_usage
                + self.step_prompt_chain.accumulate_usage
                + self.step_samples.accumulate_usage)