root commited on
Commit
c839b4c
1 Parent(s): 4fd61d3

deepnote update

Browse files
Files changed (1) hide show
  1. app.py +55 -56
app.py CHANGED
@@ -3,6 +3,7 @@ import pandas as pd
3
  import chainlit as cl
4
  from chainlit import user_session
5
  from chainlit.types import LLMSettings
 
6
  from langchain import LLMChain
7
  from langchain.prompts import PromptTemplate
8
  from langchain.llms import AzureOpenAI
@@ -13,7 +14,8 @@ from langchain.vectorstores import Chroma
13
  from langchain.vectorstores.base import VectorStoreRetriever
14
 
15
 
16
- current_agent = os.environ["AGENT"]
 
17
 
18
 
19
  def load_dialogues():
@@ -28,10 +30,8 @@ def load_persona():
28
  return df.astype(str)
29
 
30
 
31
- def load_prompt_engineering():
32
- df = pd.read_excel(
33
- os.environ["PROMPT_ENGINEERING_SHEET"], header=0, keep_default_na=False
34
- )
35
  df = df[df["Agent"] == current_agent]
36
  return df.astype(str)
37
 
@@ -50,20 +50,25 @@ def init_embedding_function():
50
 
51
 
52
  def load_vectordb(init: bool = False):
53
- vectordb = None
54
  VECTORDB_FOLDER = ".vectordb"
55
- if not init:
56
  vectordb = Chroma(
57
  embedding_function=init_embedding_function(),
58
  persist_directory=VECTORDB_FOLDER,
59
  )
60
- if init or not vectordb.get()["ids"]:
 
 
 
 
61
  vectordb = Chroma.from_documents(
62
  documents=load_documents(load_dialogues(), page_content_column="Utterance"),
63
  embedding=init_embedding_function(),
64
  persist_directory=VECTORDB_FOLDER,
65
  )
66
  vectordb.persist()
 
67
  return vectordb
68
 
69
 
@@ -80,17 +85,15 @@ def get_retriever(context_state: str, vectordb):
80
  )
81
 
82
 
83
- vectordb = load_vectordb()
84
-
85
-
86
  @cl.langchain_factory(use_async=True)
87
  def factory():
88
- df_prompt_engineering = load_prompt_engineering()
 
89
  user_session.set("context_state", "")
90
 
91
  llm_settings = LLMSettings(
92
  model_name="text-davinci-003",
93
- temperature=df_prompt_engineering["Temperature"].values[0],
94
  )
95
  user_session.set("llm_settings", llm_settings)
96
 
@@ -101,14 +104,12 @@ def factory():
101
  streaming=True,
102
  )
103
 
104
- utterance_prompt = PromptTemplate.from_template(
105
- df_prompt_engineering["Utterance-Prompt"].values[0]
106
- )
107
 
108
  chat_memory = ConversationBufferWindowMemory(
109
  memory_key="History",
110
  input_key="Utterance",
111
- k=df_prompt_engineering["History"].values[0],
112
  )
113
 
114
  utterance_chain = LLMChain(
@@ -118,9 +119,7 @@ def factory():
118
  memory=chat_memory,
119
  )
120
 
121
- continuation_prompt = PromptTemplate.from_template(
122
- df_prompt_engineering["Continuation-Prompt"].values[0]
123
- )
124
 
125
  continuation_chain = LLMChain(
126
  prompt=continuation_prompt,
@@ -139,52 +138,52 @@ async def run(agent, input_str):
139
  global vectordb
140
  if input_str == "/reload":
141
  vectordb = load_vectordb(True)
142
- await cl.Message(content="Data loaded").send()
143
- else:
144
- df_persona = load_persona()
145
 
146
- retriever = get_retriever(user_session.get("context_state"), vectordb)
147
 
148
- document = retriever.get_relevant_documents(query=input_str)
149
 
150
- response = await agent.acall(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  {
152
  "Persona": df_persona.loc[
153
- df_persona["AI"] == document[0].metadata["AI"]
154
  ]["Persona"].values[0],
155
- "Utterance": input_str,
156
- "Response": document[0].metadata["Response"],
157
  },
158
  callbacks=[cl.AsyncLangchainCallbackHandler()],
159
  )
160
  await cl.Message(
161
  content=response["text"],
162
- author=document[0].metadata["AI"],
163
  llm_settings=user_session.get("llm_settings"),
164
  ).send()
165
- user_session.set("context_state", document[0].metadata["Contextualisation"])
166
- continuation = document[0].metadata["Continuation"]
167
-
168
- while continuation != "":
169
- document_continuation = vectordb.get(where={"Intent": continuation})
170
- continuation_chain = user_session.get("continuation_chain")
171
- response = await continuation_chain.acall(
172
- {
173
- "Persona": df_persona.loc[
174
- df_persona["AI"] == document_continuation["metadatas"][0]["AI"]
175
- ]["Persona"].values[0],
176
- "Utterance": "",
177
- "Response": document_continuation["metadatas"][0]["Response"],
178
- },
179
- callbacks=[cl.AsyncLangchainCallbackHandler()],
180
- )
181
- await cl.Message(
182
- content=response["text"],
183
- author=document_continuation["metadatas"][0]["AI"],
184
- llm_settings=user_session.get("llm_settings"),
185
- ).send()
186
- user_session.set(
187
- "context_state",
188
- document_continuation["metadatas"][0]["Contextualisation"],
189
- )
190
- continuation = document_continuation["metadatas"][0]["Continuation"]
 
3
  import chainlit as cl
4
  from chainlit import user_session
5
  from chainlit.types import LLMSettings
6
+ from chainlit.logger import logger
7
  from langchain import LLMChain
8
  from langchain.prompts import PromptTemplate
9
  from langchain.llms import AzureOpenAI
 
14
  from langchain.vectorstores.base import VectorStoreRetriever
15
 
16
 
17
+ current_agent = os.environ["AGENT_SHEET"]
18
+ vectordb = None
19
 
20
 
21
  def load_dialogues():
 
30
  return df.astype(str)
31
 
32
 
33
+ def load_prompts():
34
+ df = pd.read_excel(os.environ["PROMPT_SHEET"], header=0, keep_default_na=False)
 
 
35
  df = df[df["Agent"] == current_agent]
36
  return df.astype(str)
37
 
 
50
 
51
 
52
  def load_vectordb(init: bool = False):
53
+ global vectordb
54
  VECTORDB_FOLDER = ".vectordb"
55
+ if not init and vectordb is None:
56
  vectordb = Chroma(
57
  embedding_function=init_embedding_function(),
58
  persist_directory=VECTORDB_FOLDER,
59
  )
60
+ if not vectordb.get()["ids"]:
61
+ init = True
62
+ else:
63
+ logger.info(f"Vector DB loaded")
64
+ if init:
65
  vectordb = Chroma.from_documents(
66
  documents=load_documents(load_dialogues(), page_content_column="Utterance"),
67
  embedding=init_embedding_function(),
68
  persist_directory=VECTORDB_FOLDER,
69
  )
70
  vectordb.persist()
71
+ logger.info(f"Vector DB initialised")
72
  return vectordb
73
 
74
 
 
85
  )
86
 
87
 
 
 
 
88
  @cl.langchain_factory(use_async=True)
89
  def factory():
90
+ load_vectordb()
91
+ df_prompts = load_prompts()
92
  user_session.set("context_state", "")
93
 
94
  llm_settings = LLMSettings(
95
  model_name="text-davinci-003",
96
+ temperature=df_prompts["Temperature"].values[0],
97
  )
98
  user_session.set("llm_settings", llm_settings)
99
 
 
104
  streaming=True,
105
  )
106
 
107
+ utterance_prompt = PromptTemplate.from_template(df_prompts["Template"].values[0])
 
 
108
 
109
  chat_memory = ConversationBufferWindowMemory(
110
  memory_key="History",
111
  input_key="Utterance",
112
+ k=df_prompts["History"].values[0],
113
  )
114
 
115
  utterance_chain = LLMChain(
 
119
  memory=chat_memory,
120
  )
121
 
122
+ continuation_prompt = PromptTemplate.from_template(df_prompts["Template"].values[1])
 
 
123
 
124
  continuation_chain = LLMChain(
125
  prompt=continuation_prompt,
 
138
  global vectordb
139
  if input_str == "/reload":
140
  vectordb = load_vectordb(True)
141
+ return await cl.Message(content="Data loaded").send()
142
+
143
+ df_persona = load_persona()
144
 
145
+ retriever = get_retriever(user_session.get("context_state"), vectordb)
146
 
147
+ document = retriever.get_relevant_documents(query=input_str)
148
 
149
+ response = await agent.acall(
150
+ {
151
+ "Persona": df_persona.loc[df_persona["AI"] == document[0].metadata["AI"]][
152
+ "Persona"
153
+ ].values[0],
154
+ "Utterance": input_str,
155
+ "Response": document[0].metadata["Response"],
156
+ },
157
+ callbacks=[cl.AsyncLangchainCallbackHandler()],
158
+ )
159
+ await cl.Message(
160
+ content=response["text"],
161
+ author=document[0].metadata["AI"],
162
+ llm_settings=user_session.get("llm_settings"),
163
+ ).send()
164
+ user_session.set("context_state", document[0].metadata["Contextualisation"])
165
+ continuation = document[0].metadata["Continuation"]
166
+
167
+ while continuation != "":
168
+ document_continuation = vectordb.get(where={"Intent": continuation})
169
+ continuation_chain = user_session.get("continuation_chain")
170
+ response = await continuation_chain.acall(
171
  {
172
  "Persona": df_persona.loc[
173
+ df_persona["AI"] == document_continuation["metadatas"][0]["AI"]
174
  ]["Persona"].values[0],
175
+ "Utterance": "",
176
+ "Response": document_continuation["metadatas"][0]["Response"],
177
  },
178
  callbacks=[cl.AsyncLangchainCallbackHandler()],
179
  )
180
  await cl.Message(
181
  content=response["text"],
182
+ author=document_continuation["metadatas"][0]["AI"],
183
  llm_settings=user_session.get("llm_settings"),
184
  ).send()
185
+ user_session.set(
186
+ "context_state",
187
+ document_continuation["metadatas"][0]["Contextualisation"],
188
+ )
189
+ continuation = document_continuation["metadatas"][0]["Continuation"]