angry-meow committed on
Commit
b210243
1 Parent(s): 5aaa7d9

done for the night

Browse files
Files changed (4) hide show
  1. .gitignore +2 -0
  2. graph.py +159 -0
  3. helper_functions.py +5 -2
  4. prompts.py +42 -1
.gitignore CHANGED
@@ -1,2 +1,4 @@
1
  .env
2
  /__pycache__
 
 
 
1
  .env
2
  /__pycache__
3
+ testing.py
4
+ /models
graph.py CHANGED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, TypedDict, Annotated, Sequence
2
+ from langgraph.graph import Graph, StateGraph, END
3
+ from langgraph.prebuilt import ToolExecutor
4
+ from langchain.schema import StrOutputParser
5
+ from langchain.schema.runnable import RunnablePassthrough
6
+ from langchain_community.tools.tavily_search import TavilySearchResults
7
+ import models
8
+ import prompts
9
+ from helper_functions import format_docs
10
+ from operator import itemgetter
11
+
12
# Define the state structure
class State(TypedDict):
    """Shared state passed between every node of the research/writing graphs."""
    # Conversation history; nodes read the LAST entry as the current topic.
    messages: Sequence[str]
    # Accumulated research output keyed by source,
    # e.g. "qdrant_results", "web_search_results".
    research_data: Dict[str, str]
    # Draft produced by the writing team.
    draft_post: str
    # Post after copy/voice editing and review.
    final_post: str
18
+
19
+
20
# Research Agent Pieces
# Retrieval chain: route the topic through the Qdrant-backed compression
# retriever to collect context, then answer the research query prompt with it.
# Output: {"response": <LLM answer>, "context": <retrieved docs>}.
# NOTE(review): RunnablePassthrough.assign(context=itemgetter("context")) looks
# like a no-op — "context" is already present from the previous step; confirm
# whether it can be dropped.
qdrant_research_chain = (
    {"context": itemgetter("topic") | models.compression_retriever, "topic": itemgetter("topic")}
    | RunnablePassthrough.assign(context=itemgetter("context"))
    | {"response": prompts.research_query_prompt | models.gpt4o_mini | StrOutputParser(), "context": itemgetter("context")}
)
26
+
27
# Web Search Agent Pieces
tavily_tool = TavilySearchResults(max_results=5)

# Chain: write a web-search query from the topic + Qdrant findings, run it
# through Tavily, then summarize the results alongside the original inputs.
#
# BUG FIX: the original piped tavily_tool's output (a list of search results)
# directly into a dict step that called itemgetter("topic") and
# itemgetter("qdrant_results") on it. Those keys exist only on the ORIGINAL
# input dict, not on Tavily's output, so the chain failed at runtime.
# RunnablePassthrough.assign keeps the input dict ({"topic", "qdrant_results"})
# intact and adds the Tavily results under "search_results", which is exactly
# the shape summarize_prompt expects.
web_search_chain = (
    RunnablePassthrough.assign(
        search_results=(
            prompts.search_query_prompt
            | models.gpt4o_mini
            | StrOutputParser()
            | tavily_tool
        )
    )
    | prompts.summarize_prompt
    | models.gpt4o_mini
    | StrOutputParser()
)
47
+
48
def query_qdrant(state: State) -> State:
    """Research node: run the Qdrant retrieval chain on the latest message.

    Reads the last entry of ``state["messages"]`` as the research topic,
    invokes ``qdrant_research_chain``, and stores the result under
    ``state["research_data"]["qdrant_results"]``. Returns the mutated state.
    """
    # The most recent message is treated as the topic to research.
    input_text = state["messages"][-1]

    # Run the retrieval + research-query chain.
    result = qdrant_research_chain.invoke({"topic": input_text})

    # setdefault guards against callers that seed the state without an
    # (even empty) research_data dict — the original raised KeyError then.
    state.setdefault("research_data", {})["qdrant_results"] = result

    return state
59
+
60
def web_search(state: State) -> State:
    """Research node: supplement Qdrant findings with a Tavily web search.

    Uses the last message as the topic and any previously stored
    ``qdrant_results`` as context, invokes ``web_search_chain``, and stores
    the summary under ``state["research_data"]["web_search_results"]``.
    Returns the mutated state.
    """
    # The most recent message is the topic to search for.
    topic = state["messages"][-1]

    # setdefault guards against a state seeded without research_data —
    # the original raised KeyError on state["research_data"] in that case.
    research_data = state.setdefault("research_data", {})
    qdrant_results = research_data.get("qdrant_results", "No previous results available.")

    # Run the web search chain (query generation -> Tavily -> summary).
    result = web_search_chain.invoke({
        "topic": topic,
        "qdrant_results": qdrant_results
    })

    research_data["web_search_results"] = result

    return state
77
+
78
def research_supervisor(state: State) -> State:
    # TODO: implement research supervision logic — currently a pass-through stub.
    return state
81
+
82
def post_creation(state: State) -> State:
    # TODO: implement post creation logic — currently a pass-through stub.
    return state
85
+
86
def copy_editing(state: State) -> State:
    # TODO: implement copy editing logic — currently a pass-through stub.
    return state
89
+
90
def voice_editing(state: State) -> State:
    # TODO: implement voice editing logic — currently a pass-through stub.
    return state
93
+
94
def post_review(state: State) -> State:
    # TODO: implement post review logic — currently a pass-through stub.
    return state
97
+
98
def writing_supervisor(state: State) -> State:
    # TODO: implement writing supervision logic — currently a pass-through stub.
    return state
101
+
102
def overall_supervisor(state: State) -> State:
    # TODO: implement overall supervision logic — currently a pass-through stub.
    return state
105
+
106
# Create the research team graph
research_graph = StateGraph(State)

research_graph.add_node("query_qdrant", query_qdrant)
research_graph.add_node("web_search", web_search)
research_graph.add_node("research_supervisor", research_supervisor)

# NOTE(review): these are all UNconditional edges — the supervisor fans out to
# query_qdrant, web_search, AND END simultaneously, and both workers route
# straight back to it, which loops. If the supervisor is meant to CHOOSE a
# destination, this needs add_conditional_edges with a router function; confirm.
research_graph.add_edge("query_qdrant", "research_supervisor")
research_graph.add_edge("web_search", "research_supervisor")
research_graph.add_edge("research_supervisor", "query_qdrant")
research_graph.add_edge("research_supervisor", "web_search")
research_graph.add_edge("research_supervisor", END)

research_graph.set_entry_point("research_supervisor")

# Create the writing team graph
writing_graph = StateGraph(State)

writing_graph.add_node("post_creation", post_creation)
writing_graph.add_node("copy_editing", copy_editing)
writing_graph.add_node("voice_editing", voice_editing)
writing_graph.add_node("post_review", post_review)
writing_graph.add_node("writing_supervisor", writing_supervisor)

# Linear pipeline: supervisor -> create -> copy edit -> voice edit -> review
# -> back to supervisor, which also has an unconditional edge to END
# (NOTE(review): same conditional-edge concern as the research graph).
writing_graph.add_edge("writing_supervisor", "post_creation")
writing_graph.add_edge("post_creation", "copy_editing")
writing_graph.add_edge("copy_editing", "voice_editing")
writing_graph.add_edge("voice_editing", "post_review")
writing_graph.add_edge("post_review", "writing_supervisor")
writing_graph.add_edge("writing_supervisor", END)

writing_graph.set_entry_point("writing_supervisor")

# Create the overall graph
overall_graph = StateGraph(State)

# Add the research and writing team graphs as nodes
# NOTE(review): raw StateGraph objects are passed to add_node; most langgraph
# versions expect a compiled graph (e.g. research_graph.compile()) here —
# verify against the installed version.
overall_graph.add_node("research_team", research_graph)
overall_graph.add_node("writing_team", writing_graph)

# Add the overall supervisor node
overall_graph.add_node("overall_supervisor", overall_supervisor)

overall_graph.set_entry_point("overall_supervisor")

# Connect the nodes (same unconditional fan-out caveat as above).
overall_graph.add_edge("overall_supervisor", "research_team")
overall_graph.add_edge("research_team", "overall_supervisor")
overall_graph.add_edge("overall_supervisor", "writing_team")
overall_graph.add_edge("writing_team", "overall_supervisor")
overall_graph.add_edge("overall_supervisor", END)

# Compile the graph
app = overall_graph.compile()
helper_functions.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import List
2
  from langchain.agents import AgentExecutor, create_openai_functions_agent
3
  from langchain_community.document_loaders import PyMuPDFLoader, TextLoader, UnstructuredURLLoader, WebBaseLoader
4
  from langchain_community.vectorstores import Qdrant
@@ -141,4 +141,7 @@ def create_agent(
141
  )
142
  agent = create_openai_functions_agent(llm, tools, prompt)
143
  executor = AgentExecutor(agent=agent, tools=tools)
144
- return executor
 
 
 
 
1
+ from typing import Dict, List
2
  from langchain.agents import AgentExecutor, create_openai_functions_agent
3
  from langchain_community.document_loaders import PyMuPDFLoader, TextLoader, UnstructuredURLLoader, WebBaseLoader
4
  from langchain_community.vectorstores import Qdrant
 
141
  )
142
  agent = create_openai_functions_agent(llm, tools, prompt)
143
  executor = AgentExecutor(agent=agent, tools=tools)
144
+ return executor
145
+
146
def format_docs(docs: List) -> str:
    """Render retrieved documents as one plain-text block.

    Each document contributes its ``page_content`` and its metadata
    ``"source"`` (falling back to ``"Unknown"``), with entries separated
    by a blank line. An empty list yields an empty string.

    NOTE(review): items must expose ``.page_content`` and ``.metadata``
    (LangChain ``Document`` objects) — the original ``List[Dict]``
    annotation was wrong, since dicts have neither attribute.
    """
    return "\n\n".join(
        f"Content: {doc.page_content}\nSource: {doc.metadata.get('source', 'Unknown')}"
        for doc in docs
    )
prompts.py CHANGED
@@ -21,4 +21,45 @@ chat_prompt = ChatPromptTemplate.from_messages([("system", rag_system_prompt_tem
21
 
22
  style_guide_path = "./public/CoExperiences Writing Style Guide V1 (2024).pdf"
23
  style_guide_docs = PyMuPDFLoader(style_guide_path).load()
24
- style_guide_text = "\n".join([doc.page_content for doc in style_guide_docs])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  style_guide_path = "./public/CoExperiences Writing Style Guide V1 (2024).pdf"
23
  style_guide_docs = PyMuPDFLoader(style_guide_path).load()
24
# Flatten the style-guide PDF pages into one text blob for prompt injection.
style_guide_text = "\n".join([doc.page_content for doc in style_guide_docs])

# Research agent prompt: extract facts/statistics/quotes (with sources)
# about {topic} from the retrieved {context}.
research_query_prompt = ChatPromptTemplate.from_template("""
Given a provided context and a topic, compile facts, statistics, quotes, or other related pieces of information that relate to the topic. Make sure to include the source of any such pieces of information in your response.

Context:
{context}

Topic:
{topic}

Answer:
"""
)

# Prompt that turns {topic} + existing database findings ({qdrant_results})
# into a supplementary web-search query.
search_query_prompt = ChatPromptTemplate.from_template(
    """Given the following topic and information from our database, create a search query to find supplementary information:

Topic: {topic}

Information from our database:
{qdrant_results}

Generate a search query to find additional, up-to-date information that complements what we already know:
"""
)

# Create a prompt for summarizing the search results: condenses
# {search_results} relative to what {qdrant_results} already covers,
# keeping sources.
summarize_prompt = ChatPromptTemplate.from_template(
    """Summarize the following search results, focusing on information that is complementary to what we already know from our database. Include sources for each piece of information:

Topic: {topic}

Information from our database:
{qdrant_results}

Search results:
{search_results}

Complementary summary with sources:
"""
)