# aie4-final / graph.py
# danicafisher's picture
# Update graph.py
# dc4e03f verified
from typing import Dict, List, TypedDict, Sequence
from langgraph.graph import StateGraph, END
from langchain.schema import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain_community.tools.tavily_search import TavilySearchResults
import models
import prompts
import json
from operator import itemgetter
from langgraph.errors import GraphRecursionError
#######################################
### Research Team Components ###
#######################################
class ResearchState(TypedDict):
    """State schema used to build the research team StateGraph."""
    # Trace of node names; every node in this file appends its own name here.
    workflow: List[str]
    # Subject passed to the research chains as "topic".
    topic: str
    # Collected findings; nodes in this file write the keys
    # "qdrant_results" and "web_search_results".
    research_data: Dict[str, str]
    # Routing value read by the research graph's conditional edge
    # (parsed from the supervisor reply's "Next Action: " line).
    next: str
    # Parsed from the supervisor reply's "Message to project manager: " line.
    message_to_manager: str
    # Forwarded to the research supervisor chain as "message_from_manager".
    message_from_manager: str
#
# Research Chains and Tools
#
# Retrieval-backed research chain.
# Step 1 builds {"context", "topic"}: "context" is the input's "topic" piped
# through models.compression_retriever; "topic" is passed through unchanged.
# Step 3 returns a dict with the model's string "response" plus the "context".
# NOTE(review): the RunnablePassthrough.assign step re-assigns "context" to its
# own value and looks like a no-op — confirm before removing.
qdrant_research_chain = (
    {"context": itemgetter("topic") | models.compression_retriever, "topic": itemgetter("topic")}
    | RunnablePassthrough.assign(context=itemgetter("context"))
    | {"response": prompts.research_query_prompt | models.gpt4o_mini | StrOutputParser(), "context": itemgetter("context")}
)
# Tavily web search tool, limited to 3 results per query.
tavily_tool = TavilySearchResults(max_results=3)
# Prompt -> model -> plain string; presumably produces the web search query
# (the prompt text lives in the prompts module — verify there).
query_chain = ( prompts.search_query_prompt | models.gpt4o_mini | StrOutputParser() )
# Runs the Tavily tool on the incoming value, exposes the hits to the prompt as
# "tav_results", and returns the model's reply as a string.
tavily_simple = ({"tav_results": tavily_tool} | prompts.tavily_prompt | models.gpt4o_mini | StrOutputParser())
# Full web-search pipeline: the output of query_chain is wrapped as
# {"query": ...} and handed to tavily_simple.
tavily_chain = (
    {"query": query_chain} | tavily_simple
)
# Supervisor chain for the research team: prompt -> gpt4o -> plain string.
# research_supervisor() parses that string line by line.
research_supervisor_chain = (
    prompts.research_supervisor_prompt | models.gpt4o | StrOutputParser()
)
#
# Research Node Defs
#
def query_qdrant(state: ResearchState) -> ResearchState:
    """Run the Qdrant research chain for the current topic and record its answer.

    The chain's "response" value is stored under
    state["research_data"]["qdrant_results"], "query_qdrant" is appended to
    the workflow trace, and the same (mutated) state object is returned.
    """
    chain_output = qdrant_research_chain.invoke({"topic": state["topic"]})
    print(chain_output)

    findings = state["research_data"]
    findings["qdrant_results"] = chain_output["response"]

    trace = state['workflow']
    trace.append("query_qdrant")
    print(trace)
    return state
def web_search(state: ResearchState) -> ResearchState:
    """Run the Tavily web-search chain and record its summary.

    The chain receives the topic plus any earlier Qdrant results (or a
    placeholder sentence when none exist). Its output is stored under
    state["research_data"]["web_search_results"], "web_search" is appended to
    the workflow trace, and the same (mutated) state object is returned.
    """
    subject = state["topic"]
    findings = state["research_data"]
    prior_results = findings.get("qdrant_results", "No previous results available.")

    search_summary = tavily_chain.invoke(
        {"topic": subject, "qdrant_results": prior_results}
    )
    print(search_summary)
    findings["web_search_results"] = search_summary

    trace = state['workflow']
    trace.append("web_search")
    print(trace)
    return state
def research_supervisor(state):
    """Ask the research supervisor chain what to do next and store its decision.

    The chain's text reply is scanned line by line. A line starting with
    'Next Action: ' sets state['next']; a line starting with
    'Message to project manager: ' sets state['message_to_manager'].
    Lines matching neither prefix are ignored, so a key keeps its previous
    value when the reply omits it. Appends "research_supervisor" to the
    workflow trace and returns the same (mutated) state object.
    """
    # Reply-line prefix -> state key that receives the remainder of the line.
    reply_fields = (
        ('Next Action: ', 'next'),
        ('Message to project manager: ', 'message_to_manager'),
    )

    manager_message = state["message_from_manager"]
    findings = state["research_data"]
    subject = state['topic']
    reply = research_supervisor_chain.invoke(
        {"message_from_manager": manager_message, "collected_data": findings, "topic": subject}
    )
    reply_lines = reply.split('\n')
    print(reply)

    for reply_line in reply_lines:
        for prefix, key in reply_fields:
            if reply_line.startswith(prefix):
                state[key] = reply_line[len(prefix):].strip()
                break  # at most one prefix applies per line

    trace = state['workflow']
    trace.append("research_supervisor")
    print(trace)
    return state
def research_end(state):
    """Final node of the research graph: log the visit and return state unchanged otherwise."""
    trace = state['workflow']
    trace.append("research_end")
    print(trace)
    return state
#######################################
### Writing Team Components ###
#######################################
class WritingState(TypedDict):
    """State schema used to build the writing team StateGraph."""
    # Trace of node names; every node in this file appends its own name here.
    workflow: List[str]
    # Subject passed to the writing chains as "topic".
    topic: str
    # Research findings forwarded to the post creation chain as "collected_data".
    research_data: Dict[str, str]
    # Drafts in creation order; post_creation and post_editor append to it and
    # the latest entry ([-1]) is treated as the current draft.
    draft_posts: Sequence[str]
    # Set by post_review to the latest draft when the reviewer accepts it.
    final_post: str
    # Routing value read by the writing graph's conditional edge
    # (parsed from the supervisor reply's "Next Action: " line).
    next: str
    # Parsed from the supervisor reply's "Message to project manager: " line.
    message_to_manager: str
    # Forwarded to the writing supervisor chain as "message_from_manager".
    message_from_manager: str
    # Reviewer feedback written by post_review and fed back into later drafts.
    review_comments: str
    # NOTE(review): not read or written by any node in the visible source.
    style_checked: bool
#
# Writing Chains
#
# Supervisor chain for the writing team: prompt -> gpt4o -> plain string.
# writing_supervisor() parses that string line by line.
writing_supervisor_chain = (
    prompts.writing_supervisor_prompt | models.gpt4o | StrOutputParser()
)
# Produces a new draft (used by post_creation); runs on gpt4o_mini.
post_creation_chain = (
    prompts.post_creation_prompt | models.gpt4o_mini | StrOutputParser()
)
# Revises the current draft (used by post_editor); runs on gpt4o.
post_editor_chain = (
    prompts.post_editor_prompt | models.gpt4o | StrOutputParser()
)
# Reviews the current draft (used by post_review); post_review decodes the
# returned string with json.loads, so the reply is expected to be JSON.
post_review_chain = (
    prompts.post_review_prompt | models.gpt4o | StrOutputParser()
)
#
# Writing Node Defs
#
def post_creation(state):
    """Generate a new draft post and append it to state['draft_posts'].

    The post creation chain receives the topic, the collected research data,
    all existing drafts and the current review comments. Appends
    "post_creation" to the workflow trace and returns the same (mutated)
    state object.
    """
    subject = state['topic']
    existing_drafts = state['draft_posts']
    findings = state["research_data"]
    feedback = state['review_comments']

    chain_input = {
        "topic": subject,
        "collected_data": findings,
        "drafts": existing_drafts,
        "review_comments": feedback,
    }
    new_draft = post_creation_chain.invoke(chain_input)
    print(new_draft)

    state['draft_posts'].append(new_draft)
    trace = state['workflow']
    trace.append("post_creation")
    print(trace)
    return state
def post_editor(state):
    """Revise the latest draft against the style guide and append the revision.

    The post editor chain receives the most recent draft, the style guide
    text from the prompts module and the current review comments. The edited
    text is appended to state['draft_posts'] (earlier drafts are kept).
    Appends "post_editor" to the workflow trace and returns the same
    (mutated) state object.
    """
    latest_draft = state['draft_posts'][-1]
    chain_input = {
        "current_draft": latest_draft,
        "styleguide": prompts.style_guide_text,
        "review_comments": state['review_comments'],
    }
    edited_draft = post_editor_chain.invoke(chain_input)
    print(edited_draft)

    state['draft_posts'].append(edited_draft)
    trace = state['workflow']
    trace.append("post_editor")
    print(trace)
    return state
def _parse_review_reply(raw_reply):
    """Decode the reviewer's reply into a dict; return None if it is not a JSON object."""
    text = raw_reply.strip()
    # Chat models often wrap JSON in a Markdown code fence (```json ... ```).
    # Drop the backticks and the optional "json" language tag before decoding.
    if text.startswith("```"):
        text = text.strip("`").strip()
        if text.lower().startswith("json"):
            text = text[len("json"):].strip()
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        return None
    return parsed if isinstance(parsed, dict) else None


def post_review(state):
    """Review the latest draft and record the reviewer's verdict in the state.

    The post review chain is invoked with the most recent draft and the style
    guide text. Its reply is expected to be a JSON object with the keys
    "Comments on current draft" and "Draft Acceptable".

    * state['review_comments'] is set to the reviewer's comments.
    * state['final_post'] is set to the latest draft when "Draft Acceptable"
      is "Yes" (compared case-insensitively).

    If the reply cannot be decoded as a JSON object, the raw reply text is
    stored as the review comments and the draft is not accepted, so the graph
    keeps running instead of failing on a formatting slip by the model.

    Appends "post_review" to the workflow trace and returns the same
    (mutated) state object.
    """
    print("post_review node")
    current_draft = state['draft_posts'][-1]
    styleguide = prompts.style_guide_text
    results = post_review_chain.invoke({"current_draft": current_draft, "styleguide": styleguide})
    print(results)

    review = _parse_review_reply(results)
    if review is None:
        print("post_review: reply was not a JSON object; using raw text as review comments")
        state['review_comments'] = results.strip()
    else:
        state['review_comments'] = review.get("Comments on current draft", "")
        verdict = str(review.get("Draft Acceptable", "")).strip().lower()
        if verdict == 'yes':
            state['final_post'] = current_draft

    state['workflow'].append("post_review")
    print(state['workflow'])
    return state
def writing_end(state):
    """Final node of the writing graph: log the visit and return state unchanged otherwise."""
    print("writing_end node")
    trace = state['workflow']
    trace.append("writing_end")
    print(trace)
    return state
def writing_supervisor(state):
    """Ask the writing supervisor chain what to do next and store its decision.

    The chain receives the review comments, the manager's message, the topic,
    all drafts and the current final post. Its text reply is scanned line by
    line. A line starting with 'Next Action: ' sets state['next']; a line
    starting with 'Message to project manager: ' sets
    state['message_to_manager']. Other lines are ignored. Appends
    "writing_supervisor" to the workflow trace and returns the same (mutated)
    state object.
    """
    print("writing_supervisor node")

    # Reply-line prefix -> state key that receives the remainder of the line.
    reply_fields = (
        ('Next Action: ', 'next'),
        ('Message to project manager: ', 'message_to_manager'),
    )

    manager_message = state['message_from_manager']
    subject = state['topic']
    all_drafts = state['draft_posts']
    accepted_post = state['final_post']
    feedback = state['review_comments']
    reply = writing_supervisor_chain.invoke(
        {
            "review_comments": feedback,
            "message_from_manager": manager_message,
            "topic": subject,
            "drafts": all_drafts,
            "final_draft": accepted_post,
        }
    )
    print(reply)

    for reply_line in reply.split('\n'):
        for prefix, key in reply_fields:
            if reply_line.startswith(prefix):
                state[key] = reply_line[len(prefix):].strip()
                break  # at most one prefix applies per line

    trace = state['workflow']
    trace.append("writing_supervisor")
    print(trace)
    return state
#######################################
### Overarching Graph Components ###
#######################################
class State(TypedDict):
    """State schema used to build the top-level (overall) StateGraph."""
    # Trace of node names; every node in this file appends its own name here.
    workflow: List[str]
    # Set by overall_supervisor from the reply's "Extracted Topic: " line.
    topic: str
    # Findings written by the research nodes.
    research_data: Dict[str, str]
    # Drafts appended by the writing nodes.
    draft_posts: Sequence[str]
    # Accepted post; returned to the caller by getSocialMediaPost.
    final_post: str
    # Routing value used by the conditional edges of the team graphs.
    next: str
    # Original user request, forwarded to the overall supervisor as "query".
    user_input: str
    # Message written by a team supervisor for the overall supervisor.
    message_to_manager: str
    # Set by overall_supervisor from the reply's "Message to supervisor: " line.
    message_from_manager: str
    # Forwarded to the overall supervisor chain.
    # NOTE(review): no node in the visible source updates it after initialisation.
    last_active_team :str
    # Routing value read by the overall graph's conditional edge
    # (parsed from the overall supervisor reply's "Next Action: " line).
    next_team: str
    # Reviewer feedback written by post_review.
    review_comments: str
#
# Complete Graph Chains
#
# Top-level supervisor chain: prompt -> gpt4o -> plain string.
# overall_supervisor() parses that string line by line.
overall_supervisor_chain = (
    prompts.overall_supervisor_prompt | models.gpt4o | StrOutputParser()
)
#
# Complete Graph Node defs
#
def overall_supervisor(state):
    """Ask the overall supervisor chain which team runs next and store its decision.

    The chain receives the user's query, the last message from a team
    supervisor, the last active team and the current final post. Its text
    reply is scanned line by line:

    * 'Next Action: '           -> state['next_team']
    * 'Extracted Topic: '       -> state['topic']
    * 'Message to supervisor: ' -> state['message_from_manager']

    Lines matching none of the prefixes are ignored. Appends
    "overall_supervisor" to the workflow trace and returns the same (mutated)
    state object.
    """
    # Reply-line prefix -> state key that receives the remainder of the line.
    reply_fields = (
        ('Next Action: ', 'next_team'),
        ('Extracted Topic: ', 'topic'),
        ('Message to supervisor: ', 'message_from_manager'),
    )

    user_query = state["user_input"]
    team_message = state['message_to_manager']
    previous_team = state['last_active_team']
    accepted_post = state['final_post']
    reply = overall_supervisor_chain.invoke(
        {
            "query": user_query,
            "message_to_manager": team_message,
            "last_active_team": previous_team,
            "final_post": accepted_post,
        }
    )
    print(reply)

    for reply_line in reply.split('\n'):
        for prefix, key in reply_fields:
            if reply_line.startswith(prefix):
                state[key] = reply_line[len(prefix):].strip()
                break  # at most one prefix applies per line

    trace = state['workflow']
    trace.append("overall_supervisor")
    print(trace)
    return state
#######################################
### Graph structures ###
#######################################
#
# Research Graph Nodes
#
# Research team graph: the supervisor picks a research node, each research
# node reports back to the supervisor, and "FINISH" routes to research_end.
research_graph = StateGraph(ResearchState)
research_graph.add_node("query_qdrant", query_qdrant)
research_graph.add_node("web_search", web_search)
research_graph.add_node("research_supervisor", research_supervisor)
research_graph.add_node("research_end", research_end)
#
# Research Graph Edges
#
research_graph.set_entry_point("research_supervisor")
research_graph.add_edge("query_qdrant", "research_supervisor")
research_graph.add_edge("web_search", "research_supervisor")
# Route on state["next"]; the value must be one of the mapping keys below.
research_graph.add_conditional_edges(
    "research_supervisor",
    lambda x: x["next"],
    {"query_qdrant": "query_qdrant", "web_search": "web_search", "FINISH": "research_end"},
)
# Compiled form; added as the "research_team_graph" node of the overall graph.
research_graph_comp = research_graph.compile()
#
# Writing Graph Nodes
#
# Writing team graph: the supervisor either starts a new
# create -> edit -> review cycle ("NEW DRAFT") or ends ("FINISH").
writing_graph = StateGraph(WritingState)
writing_graph.add_node("post_creation", post_creation)
writing_graph.add_node("post_editor", post_editor)
writing_graph.add_node("post_review", post_review)
writing_graph.add_node("writing_supervisor", writing_supervisor)
writing_graph.add_node("writing_end", writing_end)
#
# Writing Graph Edges
#
writing_graph.set_entry_point("writing_supervisor")
# Fixed pipeline for one drafting cycle; the review always returns to the supervisor.
writing_graph.add_edge("post_creation", "post_editor")
writing_graph.add_edge("post_editor", "post_review")
writing_graph.add_edge("post_review", "writing_supervisor")
# Route on state["next"]; the value must be one of the mapping keys below.
writing_graph.add_conditional_edges(
    "writing_supervisor",
    lambda x: x["next"],
    {"NEW DRAFT": "post_creation",
    "FINISH": "writing_end"},
)
# Compiled form; added as the "writing_team_graph" node of the overall graph.
writing_graph_comp = writing_graph.compile()
#
# Complete Graph Nodes
#
# Top-level graph: the overall supervisor dispatches to one of the compiled
# team graphs (added here as nodes) until it answers "FINISH".
overall_graph = StateGraph(State)
overall_graph.add_node("overall_supervisor", overall_supervisor)
overall_graph.add_node("research_team_graph", research_graph_comp)
overall_graph.add_node("writing_team_graph", writing_graph_comp)
#
# Complete Graph Edges
#
overall_graph.set_entry_point("overall_supervisor")
# Both teams hand control back to the overall supervisor when they finish.
overall_graph.add_edge("research_team_graph", "overall_supervisor")
overall_graph.add_edge("writing_team_graph", "overall_supervisor")
# Route on state["next_team"]; the value must be one of the mapping keys below.
overall_graph.add_conditional_edges(
    "overall_supervisor",
    lambda x: x["next_team"],
    {"research_team": "research_team_graph",
    "writing_team": "writing_team_graph",
    "FINISH": END},
)
# Compiled application invoked by getSocialMediaPost.
app = overall_graph.compile()
#######################################
### Run method ###
#######################################
def getSocialMediaPost(userInput: str) -> str:
    """Run the full research-and-writing graph for a user request.

    Args:
        userInput: The user's request; stored in the state as "user_input".

    Returns:
        The value of "final_post" in the graph's final state, or the string
        "Recursion Error" when the graph exceeds the recursion limit of 40.
    """
    # String-typed fields start as "" (not []) so the initial state matches
    # the State schema and the function returns a str even when no post was
    # accepted.
    initial_state = State(
        workflow=[],
        topic="",
        research_data={},
        draft_posts=[],
        final_post="",
        next="",
        next_team="",
        user_input=userInput,
        message_to_manager="",
        message_from_manager="",
        last_active_team="",
        review_comments="",
    )
    # Invoke exactly once, inside the try block: a second, unguarded call would
    # run the whole pipeline twice and let GraphRecursionError escape.
    try:
        results = app.invoke(initial_state, {"recursion_limit": 40})
    except GraphRecursionError:
        return "Recursion Error"
    return results['final_post']