smu-lib-chat / app_modules /llm_summarize_chain.py
inflaton's picture
Duplicate from inflaton/smu-ai
d5edf96
raw
history blame
No virus
2.42 kB
import os
from typing import List, Optional
from langchain import PromptTemplate
from langchain.chains.base import Chain
from langchain.chains.summarize import load_summarize_chain
from app_modules.llm_inference import LLMInference
def get_llama_2_prompt_template(
    instruction,
    system_prompt=(
        "You are a helpful assistant, you always only answer for the assistant "
        "then you stop. Read the text to get context"
    ),
):
    """Wrap *instruction* in Llama-2 style ``[INST]`` / ``<<SYS>>`` markers.

    The returned string has the layout::

        [INST]<<SYS>>\\n{system_prompt}\\n<</SYS>>\\n\\n{instruction}[/INST]

    Args:
        instruction: The user instruction text. Any ``{placeholders}`` in it
            are kept verbatim, so the result can still be fed to a prompt
            template formatter.
        system_prompt: Text placed inside the ``<<SYS>>`` section. Defaults to
            the assistant instruction this module has always used.

    Returns:
        The wrapped prompt string.
    """
    b_inst, e_inst = "[INST]", "[/INST]"
    b_sys, e_sys = "<<SYS>>\n", "\n<</SYS>>\n\n"
    return f"{b_inst}{b_sys}{system_prompt}{e_sys}{instruction}{e_inst}"
class SummarizeChain(LLMInference):
    """Summarize documents with a LangChain "refine" summarization chain.

    The LLM comes from ``llm_loader.llm``. When the environment variable
    ``USE_LLAMA_2_PROMPT_TEMPLATE`` is exactly the string ``"true"``, both the
    initial and the refine prompt are wrapped with
    ``get_llama_2_prompt_template``.
    """

    def __init__(self, llm_loader):
        super().__init__(llm_loader)

    @staticmethod
    def _build_prompt(template, use_llama_2_prompt_template):
        """Optionally Llama-2-wrap *template*, then compile it to a PromptTemplate."""
        if use_llama_2_prompt_template:
            template = get_llama_2_prompt_template(template)
        return PromptTemplate.from_template(template)

    def create_chain(self) -> Chain:
        """Build the refine summarization chain.

        Returns:
            The chain from ``load_summarize_chain(chain_type="refine")``. It reads
            documents from the ``"input_documents"`` key, writes the final
            summary to ``"output_text"``, and is configured with
            ``return_intermediate_steps=True``.
        """
        # Exact, case-sensitive match: only the literal "true" enables wrapping.
        use_llama_2_prompt_template = (
            os.environ.get("USE_LLAMA_2_PROMPT_TEMPLATE") == "true"
        )

        # Prompt for the first document; {text} is filled by the chain.
        prompt_template = (
            "Write a concise summary of the following:\n"
            "{text}\n"
            "CONCISE SUMMARY:"
        )
        prompt = self._build_prompt(prompt_template, use_llama_2_prompt_template)

        # Prompt for each later document. {existing_answer} receives the
        # summary so far and {text} receives the next document. Adjacent
        # literals concatenate with no separator, so trailing spaces matter.
        refine_template = (
            "Your job is to produce a final summary\n"
            "We have provided an existing summary up to a certain point: {existing_answer}\n"
            "We have the opportunity to refine the existing summary "
            "(only if needed) with some more context below.\n"
            "------------\n"
            "{text}\n"
            "------------\n"
            "Given the new context, refine the original summary. "
            "If the context isn't useful, return the original summary."
        )
        refine_prompt = self._build_prompt(
            refine_template, use_llama_2_prompt_template
        )

        return load_summarize_chain(
            llm=self.llm_loader.llm,
            chain_type="refine",
            question_prompt=prompt,
            refine_prompt=refine_prompt,
            return_intermediate_steps=True,
            input_key="input_documents",
            output_key="output_text",
        )

    def run_chain(self, chain, inputs, callbacks: Optional[List] = None):
        """Invoke *chain* on *inputs* and return only its output keys.

        Args:
            chain: The chain to call, typically the one built by ``create_chain``.
            inputs: Inputs passed straight to ``chain``. For the chain from
                ``create_chain`` this presumably is a dict holding an
                ``"input_documents"`` entry (TODO: confirm against callers).
            callbacks: Optional callback handlers forwarded to the chain call.
                ``None`` means no extra callbacks. A ``None`` default avoids a
                shared mutable ``[]`` default.

        Returns:
            Whatever ``chain(...)`` returns with ``return_only_outputs=True``.
        """
        return chain(inputs, return_only_outputs=True, callbacks=callbacks)