from langchain.chains.openai_functions import create_structured_output_runnable
from utils.config import get_llm, load_prompt
from langchain_community.callbacks import get_openai_callback
import asyncio
from langchain.chains import LLMChain
import importlib.util
from pathlib import Path
from tqdm import trange, tqdm
import concurrent.futures
import logging


class DummyCallback:
    """
    A no-op callback for LLM providers that do not report token usage.
    It mimics the context-manager interface of get_openai_callback so callers
    can track cost uniformly.
    """

    def __enter__(self):
        self.total_cost = 0
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        pass


def get_dummy_callback():
    return DummyCallback()
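
# Usage sketch (illustrative, not part of the original module): the dummy callback
# is a drop-in replacement for langchain_community's get_openai_callback, so cost
# tracking works the same way for any provider:
#
#     with get_dummy_callback() as cb:
#         ...  # run the chain
#         cost = cb.total_cost  # always 0 for the dummy callback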


class ChainWrapper:
    """
    A wrapper for an LLM chain
    """

    def __init__(self, llm_config, prompt_path: str, json_schema: dict = None, parser_func=None):
        """
        Initialize a new instance of the ChainWrapper class.
        :param llm_config: The config for the LLM
        :param prompt_path: A path to the prompt file (text file)
        :param json_schema: A dict with the JSON schema used to obtain structured output from the LLM
        :param parser_func: A function to parse the output of the LLM
        """
        self.llm_config = llm_config
        self.llm = get_llm(llm_config)
        self.json_schema = json_schema
        self.parser_func = parser_func
        self.prompt = load_prompt(prompt_path)
        self.build_chain()
        self.accumulate_usage = 0
        if self.llm_config.type == 'OpenAI':
            self.callback = get_openai_callback
        else:
            self.callback = get_dummy_callback

    def invoke(self, chain_input: dict) -> dict:
        """
        Invoke the chain on a single input
        :param chain_input: The input for the chain
        :return: A dict with the defined json schema, or None if the invocation failed
        """
        with self.callback() as cb:
            try:
                result = self.chain.invoke(chain_input)
                if self.parser_func is not None:
                    result = self.parser_func(result)
            except Exception as e:
                logging.error('Error in chain invoke: {}'.format(e))
                result = None
            self.accumulate_usage += cb.total_cost
        return result
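
    # Usage sketch (the input keys below are hypothetical; the real keys depend
    # on the placeholders in the prompt file):
    #
    #     wrapper = ChainWrapper(llm_config, 'prompts/initial.prompt', json_schema)
    #     output = wrapper.invoke({'task_description': '...', 'num_samples': 5})
    #
    # invoke() returns None on failure, so callers should filter missing results.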

    async def retry_operation(self, tasks):
        """
        Retry an async operation until all tasks complete or the timeout is reached
        :param tasks: A list of asyncio tasks to await
        :return: A list of the tasks that completed
        """
        delay = self.llm_config.async_params.retry_interval
        timeout = delay * self.llm_config.async_params.max_retries
        start_time = asyncio.get_event_loop().time()
        end_time = start_time + timeout
        results = []
        while True:
            remaining_time = end_time - asyncio.get_event_loop().time()
            if remaining_time <= 0:
                print("Timeout reached. Operation incomplete.")
                break

            done, pending = await asyncio.wait(tasks, timeout=delay)
            results += list(done)

            if len(done) == len(tasks):
                print("All tasks completed successfully.")
                break

            if not pending:
                print("No pending tasks. Operation incomplete.")
                break

            tasks = list(pending)  # Retry with the pending tasks
        return results

    async def async_batch_invoke(self, inputs: list[dict]) -> list[dict]:
        """
        Invoke the chain on a batch of inputs in async mode
        :param inputs: A batch of inputs
        :return: A list of dicts with the defined json schema
        """
        with self.callback() as cb:
            # Wrap the coroutines in tasks; asyncio.wait() rejects bare coroutines on Python >= 3.11
            tasks = [asyncio.create_task(self.chain.ainvoke(chain_input)) for chain_input in inputs]
            all_res = await self.retry_operation(tasks)
            self.accumulate_usage += cb.total_cost
            if self.parser_func is not None:
                return [self.parser_func(t.result()) for t in all_res]
            return [t.result() for t in all_res]

    def batch_invoke(self, inputs: list[dict], num_workers: int):
        """
        Invoke the chain on a batch of inputs, either asynchronously or with a thread pool
        :param inputs: The list of all inputs
        :param num_workers: The number of workers
        :return: A list of results
        """

        def sample_generator():
            for sample in inputs:
                yield sample

        def process_sample_with_progress(sample):
            result = self.invoke(sample)
            pbar.update(1)  # Update the progress bar
            return result

        if 'async_params' not in self.llm_config.keys():  # non-async mode, use regular workers
            with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
                with tqdm(total=len(inputs), desc="Processing samples") as pbar:
                    all_results = list(executor.map(process_sample_with_progress, sample_generator()))
        else:
            all_results = []
            for i in trange(0, len(inputs), num_workers, desc='Predicting'):
                results = asyncio.run(self.async_batch_invoke(inputs[i:i + num_workers]))
                all_results += results
        all_results = [res for res in all_results if res is not None]
        return all_results
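
    # Dispatch sketch: when the LLM config defines async_params, inputs are sent in
    # async batches of num_workers; otherwise a thread pool of num_workers threads
    # calls invoke() per sample. For example (hypothetical sample list):
    #
    #     results = wrapper.batch_invoke(samples, num_workers=4)
    #     # len(results) may be smaller than len(samples): failed invocations are dropped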

    def build_chain(self):
        """
        Build the chain according to the LLM type
        """
        if (self.llm_config.type == 'OpenAI' or self.llm_config.type == 'Azure') and self.json_schema is not None:
            self.chain = create_structured_output_runnable(self.json_schema, self.llm, self.prompt)
        else:
            self.chain = LLMChain(llm=self.llm, prompt=self.prompt)


def get_chain_metadata(prompt_fn: Path, retrieve_module: bool = False) -> dict:
    """
    Get the metadata of the chain
    :param prompt_fn: The path to the prompt file
    :param retrieve_module: If True, also return the loaded output_schemes module
    :return: A dict with the metadata
    """
    prompt_directory = str(prompt_fn.parent)
    prompt_name = str(prompt_fn.stem)
    schema_parser = None  # Fallback in case the output_schemes module cannot be loaded
    try:
        spec = importlib.util.spec_from_file_location('output_schemes', prompt_directory + '/output_schemes.py')
        schema_parser = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(schema_parser)
    except ImportError as e:
        print(f"Error loading module {prompt_directory + '/output_schemes'}: {e}")
    if hasattr(schema_parser, '{}_schema'.format(prompt_name)):
        json_schema = getattr(schema_parser, '{}_schema'.format(prompt_name))
    else:
        json_schema = None
    if hasattr(schema_parser, '{}_parser'.format(prompt_name)):
        parser_func = getattr(schema_parser, '{}_parser'.format(prompt_name))
    else:
        parser_func = None
    result = {'json_schema': json_schema, 'parser_func': parser_func}
    if retrieve_module:
        result['module'] = schema_parser
    return result
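
# get_chain_metadata looks up attributes by naming convention in an output_schemes.py
# that lives next to the prompt file. A minimal sketch of such a file, for a prompt
# named initial.prompt (the schema contents below are hypothetical examples, not the
# project's actual schemas):
#
#     # output_schemes.py
#     initial_schema = {
#         'name': 'initial_prompt',
#         'description': 'Generates an initial prompt for the task',
#         'type': 'object',
#         'properties': {'prompt': {'type': 'string'}},
#         'required': ['prompt'],
#     }
#
#     def initial_parser(llm_output):
#         return llm_output
#
# Either attribute may be omitted; the corresponding metadata entry is then None.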


class MetaChain:
    """
    A wrapper for the meta-prompt chains
    """

    def __init__(self, config):
        """
        Initialize a new instance of the MetaChain class, loading all the meta-prompt chains
        :param config: An EasyDict configuration
        """
        self.config = config
        self.initial_chain = self.load_chain('initial')
        self.step_prompt_chain = self.load_chain('step_prompt')
        self.step_samples = self.load_chain('step_samples')
        self.error_analysis = self.load_chain('error_analysis')

    def load_chain(self, chain_name: str) -> ChainWrapper:
        """
        Load a chain according to the chain name
        :param chain_name: The name of the chain
        :return: The loaded ChainWrapper
        """
        metadata = get_chain_metadata(self.config.meta_prompts.folder / '{}.prompt'.format(chain_name))
        return ChainWrapper(self.config.llm, self.config.meta_prompts.folder / '{}.prompt'.format(chain_name),
                            metadata['json_schema'], metadata['parser_func'])

    def calc_usage(self) -> float:
        """
        Calculate the usage of all the meta-prompts
        :return: The total usage value
        """
        return self.initial_chain.accumulate_usage + self.step_prompt_chain.accumulate_usage \
            + self.step_samples.accumulate_usage
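

# Minimal end-to-end sketch (assumes an EasyDict-style LLM config like the one the
# project passes to get_llm; the field names, values, and prompt path below are
# placeholders, not the project's actual configuration):
#
#     if __name__ == '__main__':
#         from easydict import EasyDict
#         llm_config = EasyDict({'type': 'OpenAI', 'name': 'gpt-4-0125-preview', 'temperature': 0.8})
#         prompt_fn = Path('prompts/initial.prompt')
#         metadata = get_chain_metadata(prompt_fn)
#         chain = ChainWrapper(llm_config, str(prompt_fn),
#                              metadata['json_schema'], metadata['parser_func'])
#         print(chain.invoke({'task_description': 'Classify movie review sentiment'}))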