from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser, SimpleJsonOutputParser
from langchain_openai import ChatOpenAI
import re
import concurrent.futures
import copy
import os


class LangChainExecutor:
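    """Unified executor that routes prompts to OpenAI (GPT) or Google (Gemini) chat models via LangChain."""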
    def __init__(self, model_name):
        self.model_name = model_name
        self.platform = 'gpt' if 'gpt' in model_name else 'gemini'
        self.api_key = os.getenv("OPEN_AI_API_KEY") if self.platform == "gpt" else os.getenv("GEMINI_API_KEY")
        if self.platform == "gpt":
            self.default_config = {
                "temperature": 1,
                "max_tokens": None,
            }
        elif self.platform == "gemini":
            self.default_config = {
                "temperature": 1,
                "top_p": 0.95,
                "top_k": 64,
                "max_output_tokens": 8192,
            }

    def create_model(self, model_name, cp_config):
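        """Build a ChatOpenAI or ChatGoogleGenerativeAI instance for `model_name`, configured from `cp_config`."""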
        # Re-detect the platform and reset self.default_config based on the requested model name.
        self.platform = 'gpt' if 'gpt' in model_name else 'gemini'
        self.api_key = os.getenv("OPEN_AI_API_KEY") if self.platform == "gpt" else os.getenv("GEMINI_API_KEY")
        if self.platform == "gpt":
            self.default_config = {
                "temperature": 1,
                "max_tokens": None,
            }
        elif self.platform == "gemini":
            self.default_config = {
                "temperature": 1,
                "top_p": 0.95,
                "top_k": 64,
                "max_output_tokens": None,
            }
        if self.platform == "gpt":
            return ChatOpenAI(
                model=model_name,
                api_key=self.api_key,
                temperature=cp_config["temperature"],
                max_tokens=cp_config.get("max_tokens"),
            )
        elif self.platform == "gemini":
            return ChatGoogleGenerativeAI(
                model=model_name,
                google_api_key=self.api_key,
                temperature=cp_config["temperature"],
                top_p=cp_config.get("top_p"),
                top_k=cp_config.get("top_k"),
                max_output_tokens=cp_config.get("max_output_tokens"),
            )

    def clean_response(self, response):
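        """Strip a wrapping Markdown code fence (```json, ```csv, or bare ```) from the model output."""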
        if response.startswith("```") and response.endswith("```"):
            pattern = r'^(?:```json|```csv|```)\s*(.*?)\s*```$'
            return re.sub(pattern, r'\1', response, flags=re.DOTALL).strip()
        return response.strip()

    def execute(self, model_input, user_input, model_name="", temperature=0, prefix=None, infix=None, suffix=None, json_output=False):
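        """Run a text-only prompt: `model_input` (with optional prefix/infix/suffix) forms the system message, `user_input` the human message."""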
        cp_config = copy.deepcopy(self.default_config)
        cp_config["temperature"] = temperature
        if model_name == "":
            model_name = self.model_name
        model = self.create_model(model_name, cp_config)

        full_prompt_parts = []
        if prefix:
            full_prompt_parts.append(prefix)
        if infix:
            full_prompt_parts.append(infix)
        full_prompt_parts.append(model_input)
        if suffix:
            full_prompt_parts.append(suffix)
        # Combine the parts into a single prompt string.
        full_prompt = "\n".join(full_prompt_parts)
        chat_template = ChatPromptTemplate.from_messages(
            [
                ("system", "{full_prompt}"),
                ("human", "{user_input}"),
            ]
        )
        if json_output:
            parser = SimpleJsonOutputParser()
        else:
            parser = StrOutputParser()
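        # Compose an LCEL pipeline: prompt template -> chat model -> output parser.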
        run_chain = chat_template | model | parser
        map_args = {
            "full_prompt": full_prompt,
            "user_input": user_input,
        }
        response = run_chain.invoke(map_args)
        if not json_output:
            response = self.clean_response(response)
        return response

    def execute_with_image(self, model_input, user_input, base64_image, model_name="", temperature=0, prefix=None, infix=None, suffix=None, json_output=False):
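        """Run a multimodal prompt: the text goes into the system message and the image is attached as a base64 data URL."""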
        full_prompt_parts = []
        if prefix:
            full_prompt_parts.append(prefix)
        if infix:
            full_prompt_parts.append(infix)
        full_prompt_parts.append(model_input)
        if suffix:
            full_prompt_parts.append(suffix)
        # Combine the parts into a single prompt string.
        full_prompt = "\n".join(full_prompt_parts)
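        # The image travels as a separate multimodal user message, embedded as a base64-encoded JPEG data URL.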
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", "{full_prompt}\n{user_input}"),
                (
                    "user",
                    [
                        {
                            "type": "image_url",
                            "image_url": {"url": "data:image/jpeg;base64,{image_data}"},
                        }
                    ],
                ),
            ]
        )
        cp_config = copy.deepcopy(self.default_config)
        cp_config["temperature"] = temperature
        if model_name == "":
            model_name = self.model_name
        model = self.create_model(model_name, cp_config)
        if json_output:
            parser = SimpleJsonOutputParser()
        else:
            parser = StrOutputParser()
        run_chain = prompt | model | parser
        response = run_chain.invoke({
            "image_data": base64_image,
            "full_prompt": full_prompt,
            "user_input": user_input,
        })
        if not json_output:
            response = self.clean_response(response)
        return response

    def batch_execute(self, requests):
        """
        Execute multiple requests in parallel, dispatching to `execute` or `execute_with_image`.

        Args:
            requests (list of dict): Each request contains `model_input` and `user_input`, and
                optionally `model_name`, `temperature`, `prefix`, `infix`, `suffix`, and `base64_image`.

        Returns:
            list of str: One response per request, in the same order as the input list.
        """
        responses = [None] * len(requests)

        def process_request(index, request):
            model_input = request.get("model_input", "")
            user_input = request.get("user_input", "")
            prefix = request.get("prefix", None)
            infix = request.get("infix", None)
            suffix = request.get("suffix", None)
            model_name = request.get("model_name", self.model_name)
            temperature = request.get("temperature", 0)
            base64_image = request.get("base64_image", None)
            if base64_image:
                result = self.execute_with_image(model_input, user_input, base64_image, model_name, temperature, prefix, infix, suffix)
            else:
                result = self.execute(model_input, user_input, model_name, temperature, prefix, infix, suffix)
            responses[index] = result
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = {executor.submit(process_request, i, request): i for i, request in enumerate(requests)}
            for future in concurrent.futures.as_completed(futures):
                index = futures[future]
                try:
                    future.result()
                except Exception as exc:
                    responses[index] = f"Exception occurred: {exc}"
        return responses
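

# Usage sketch (illustrative only, not part of the original module): the model name below is an
# assumption; any name containing "gpt" routes to OpenAI, everything else to Gemini. Requires the
# OPEN_AI_API_KEY or GEMINI_API_KEY environment variable to be set for the chosen platform.
if __name__ == "__main__":
    executor = LangChainExecutor("gemini-1.5-flash")

    # Single text-only call: system prompt plus user message, deterministic at temperature 0.
    answer = executor.execute(
        model_input="You are a concise assistant.",
        user_input="Explain top_p sampling in one sentence.",
        temperature=0,
    )
    print(answer)

    # Parallel fan-out over several prompts via the thread pool; results keep input order.
    requests = [
        {"model_input": "You are a concise assistant.", "user_input": "What does top_k do?"},
        {"model_input": "You are a concise assistant.", "user_input": "What does temperature do?"},
    ]
    print(executor.batch_execute(requests))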