from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser, SimpleJsonOutputParser
from langchain_openai import ChatOpenAI
import re
import concurrent.futures
import copy
import os


class LangChainExecutor:
    def __init__(self, model_name):
        self.model_name = model_name
        self.platform = "gpt" if "gpt" in model_name else "gemini"
        self.api_key = os.getenv("OPEN_AI_API_KEY") if self.platform == "gpt" else os.getenv("GEMINI_API_KEY")
        if self.platform == "gpt":
            self.default_config = {
                "temperature": 1,
                "max_tokens": None,
            }
        elif self.platform == "gemini":
            self.default_config = {
                "temperature": 1,
                "top_p": 0.95,
                "top_k": 64,
                "max_output_tokens": 8192,
            }

    def create_model(self, model_name, cp_config):
        # Re-resolve the platform, API key, and default config from model_name,
        # since a per-call model_name may target a different provider.
        # Note: this mutates shared instance state, which is not thread-safe
        # if batch_execute mixes providers across threads.
        self.platform = "gpt" if "gpt" in model_name else "gemini"
        self.api_key = os.getenv("OPEN_AI_API_KEY") if self.platform == "gpt" else os.getenv("GEMINI_API_KEY")
        if self.platform == "gpt":
            self.default_config = {
                "temperature": 1,
                "max_tokens": None,
            }
        elif self.platform == "gemini":
            self.default_config = {
                "temperature": 1,
                "top_p": 0.95,
                "top_k": 64,
                "max_output_tokens": None,
            }
        if self.platform == "gpt":
            return ChatOpenAI(
                model=model_name,
                api_key=self.api_key,
                temperature=cp_config["temperature"],
                max_tokens=cp_config.get("max_tokens"),
            )
        elif self.platform == "gemini":
            return ChatGoogleGenerativeAI(
                model=model_name,
                google_api_key=self.api_key,
                temperature=cp_config["temperature"],
                top_p=cp_config.get("top_p"),
                top_k=cp_config.get("top_k"),
                max_output_tokens=cp_config.get("max_output_tokens"),
            )

    def clean_response(self, response):
        # Strip a surrounding Markdown code fence (```json, ```csv, or plain ```)
        # that models often wrap around structured output.
        if response.startswith("```") and response.endswith("```"):
            pattern = r'^(?:```json|```csv|```)\s*(.*?)\s*```$'
            return re.sub(pattern, r'\1', response, flags=re.DOTALL).strip()
        return response.strip()

    def execute(self, model_input, user_input, model_name="", temperature=0,
                prefix=None, infix=None, suffix=None, json_output=False):
        cp_config = copy.deepcopy(self.default_config)
        cp_config["temperature"] = temperature
        if model_name == "":
            model_name = self.model_name
        model = self.create_model(model_name, cp_config)

        full_prompt_parts = []
        if prefix:
            full_prompt_parts.append(prefix)
        if infix:
            full_prompt_parts.append(infix)
        full_prompt_parts.append(model_input)
        if suffix:
            full_prompt_parts.append(suffix)
        # Combine the parts into a single string
        full_prompt = "\n".join(full_prompt_parts)

        # Pass the prompt text in as a template variable rather than inlining it,
        # so literal braces in the prompt are not treated as template placeholders.
        chat_template = ChatPromptTemplate.from_messages(
            [
                ("system", "{full_prompt}"),
                ("human", "{user_input}"),
            ]
        )
        parser = SimpleJsonOutputParser() if json_output else StrOutputParser()
        run_chain = chat_template | model | parser
        response = run_chain.invoke({
            "full_prompt": full_prompt,
            "user_input": user_input,
        })
        if not json_output:
            response = self.clean_response(response)
        return response

    def execute_with_image(self, model_input, user_input, base64_image, model_name="",
                           temperature=0, prefix=None, infix=None, suffix=None, json_output=False):
        full_prompt_parts = []
        if prefix:
            full_prompt_parts.append(prefix)
        if infix:
            full_prompt_parts.append(infix)
        full_prompt_parts.append(model_input)
        if suffix:
            full_prompt_parts.append(suffix)
        # Combine the parts into a single string
        full_prompt = "\n".join(full_prompt_parts)

        # Multimodal message: the image travels as a base64 data URI inside an
        # image_url content block alongside the text prompt.
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", "{full_prompt}\n{user_input}"),
                (
                    "user",
                    [
                        {
                            "type": "image_url",
                            "image_url": {"url": "data:image/jpeg;base64,{image_data}"},
                        }
                    ],
                ),
            ]
        )
        cp_config = copy.deepcopy(self.default_config)
        cp_config["temperature"] = temperature
        if model_name == "":
            model_name = self.model_name
        model = self.create_model(model_name, cp_config)
        parser = SimpleJsonOutputParser() if json_output else StrOutputParser()
        run_chain = prompt | model | parser
        response = run_chain.invoke({
            "image_data": base64_image,
            "full_prompt": full_prompt,
            "user_input": user_input,
        })
        if not json_output:
            response = self.clean_response(response)
        return response

    def batch_execute(self, requests):
        """
        Execute multiple requests in parallel, dispatching to `execute` or
        `execute_with_image` depending on whether a `base64_image` is supplied.

        Args:
            requests (list of dict): Each request contains `model_input` and
                `user_input`, and optionally `model_name`, `temperature`,
                `prefix`, `infix`, `suffix`, and `base64_image`.

        Returns:
            list of str: Responses in the same order as the input requests.
        """
        responses = [None] * len(requests)

        def process_request(index, request):
            model_input = request.get("model_input", "")
            user_input = request.get("user_input", "")
            prefix = request.get("prefix", None)
            infix = request.get("infix", None)
            suffix = request.get("suffix", None)
            model_name = request.get("model_name", self.model_name)
            temperature = request.get("temperature", 0)
            base64_image = request.get("base64_image", None)
            if base64_image:
                result = self.execute_with_image(model_input, user_input, base64_image,
                                                 model_name, temperature, prefix, infix, suffix)
            else:
                result = self.execute(model_input, user_input, model_name,
                                      temperature, prefix, infix, suffix)
            # Write into a preallocated slot so results stay aligned with inputs.
            responses[index] = result

        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = {executor.submit(process_request, i, request): i
                       for i, request in enumerate(requests)}
            for future in concurrent.futures.as_completed(futures):
                index = futures[future]
                try:
                    future.result()
                except Exception as exc:
                    responses[index] = f"Exception occurred: {exc}"

        return responses
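

# --- Usage sketch (not part of the original module) ---
# A minimal, hedged example of how this class might be driven. It assumes the
# OPEN_AI_API_KEY / GEMINI_API_KEY environment variables are set and that the
# model name below ("gpt-4o-mini" is a placeholder) is available to your
# account; adjust both to match your setup.
if __name__ == "__main__":
    executor = LangChainExecutor("gpt-4o-mini")  # hypothetical model name

    # Single call: system prompt + user message, plain-text output.
    answer = executor.execute(
        model_input="You are a concise assistant.",
        user_input="Summarize what this class does in one sentence.",
        temperature=0,
    )
    print(answer)

    # Parallel fan-out over several prompts via batch_execute; results come
    # back in the same order as the request dicts.
    requests = [
        {"model_input": "You are a translator.",
         "user_input": "Translate 'hello' into French."},
        {"model_input": "You are a careful arithmetic checker.",
         "user_input": "What is 17 * 23?",
         "temperature": 0},
    ]
    for reply in executor.batch_execute(requests):
        print(reply)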