from typing import Dict, List, TypedDict, Sequence
from langgraph.graph import StateGraph, END
from langchain.schema import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain_community.tools.tavily_search import TavilySearchResults
import models
import prompts
import json
from operator import itemgetter
from langgraph.errors import GraphRecursionError


#######################################
###          Shared helpers         ###
#######################################

# Line prefixes emitted by the team supervisors (research and writing), mapped
# to the state key that receives the text after the prefix.
_TEAM_SUPERVISOR_LABELS: Dict[str, str] = {
    "Next Action: ": "next",
    "Message to project manager: ": "message_to_manager",
}

# Line prefixes emitted by the overall supervisor, mapped to the state key that
# receives the text after the prefix.
_OVERALL_SUPERVISOR_LABELS: Dict[str, str] = {
    "Next Action: ": "next_team",
    "Extracted Topic: ": "topic",
    "Message to supervisor: ": "message_from_manager",
}


def _parse_labeled_lines(text: str, labels: Dict[str, str]) -> Dict[str, str]:
    """Extract ``Label: value`` lines from a supervisor reply.

    Args:
        text: Raw LLM output, one field per line.
        labels: Maps each line prefix (including its trailing ": ") to the
            state key that should receive the text following the prefix.

    Returns:
        State key -> stripped value. A key whose label never appears is absent,
        so callers that ``update`` their state with the result leave the old
        value untouched. If a label appears on several lines, the last one wins.
    """
    fields: Dict[str, str] = {}
    for raw_line in text.split('\n'):
        # Models sometimes indent the labeled lines; ignore leading whitespace.
        line = raw_line.lstrip()
        for label, key in labels.items():
            if line.startswith(label):
                fields[key] = line[len(label):].strip()
                break
    return fields


def _parse_json_response(text: str) -> dict:
    """Parse an LLM reply as JSON, tolerating a surrounding Markdown code fence.

    Raises:
        json.JSONDecodeError: if the text (with any fence removed) is not valid JSON.
    """
    cleaned = text.strip()
    if cleaned.startswith("```"):
        # Drop the opening fence line (e.g. "```json") and the closing fence.
        cleaned = cleaned.split("\n", 1)[1] if "\n" in cleaned else ""
        cleaned = cleaned.rsplit("```", 1)[0]
    return json.loads(cleaned)


#######################################
###    Research Team Components     ###
#######################################

class ResearchState(TypedDict):
    """State shared by the nodes of the research sub-graph."""
    workflow: List[str]            # names of the nodes visited, in order
    topic: str
    research_data: Dict[str, str]  # filled under "qdrant_results" / "web_search_results"
    next: str                      # routing key chosen by research_supervisor
    message_to_manager: str
    message_from_manager: str


#
# Research Chains and Tools
#

# Retrieve context for the topic via models.compression_retriever (presumably
# Qdrant-backed, per the chain name), then answer prompts.research_query_prompt
# with it. Output: {"response": str, "context": <retrieved documents>}.
qdrant_research_chain = (
    {"context": itemgetter("topic") | models.compression_retriever, "topic": itemgetter("topic")}
    | RunnablePassthrough.assign(context=itemgetter("context"))
    | {"response": prompts.research_query_prompt | models.gpt4o_mini | StrOutputParser(),
       "context": itemgetter("context")}
)

tavily_tool = TavilySearchResults(max_results=3)

# Produces a search query string from the node's inputs.
query_chain = (
    prompts.search_query_prompt
    | models.gpt4o_mini
    | StrOutputParser()
)

# Runs the Tavily search and feeds its results through prompts.tavily_prompt.
tavily_simple = (
    {"tav_results": tavily_tool}
    | prompts.tavily_prompt
    | models.gpt4o_mini
    | StrOutputParser()
)

tavily_chain = (
    {"query": query_chain}
    | tavily_simple
)

research_supervisor_chain = (
    prompts.research_supervisor_prompt
    | models.gpt4o
    | StrOutputParser()
)


#
# Research Node Defs
#

def query_qdrant(state: ResearchState) -> ResearchState:
    """Answer the topic from retrieved context; store it as research_data['qdrant_results']."""
    topic = state["topic"]
    result = qdrant_research_chain.invoke({"topic": topic})
    print(result)
    state["research_data"]["qdrant_results"] = result["response"]
    state['workflow'].append("query_qdrant")
    print(state['workflow'])
    return state


def web_search(state: ResearchState) -> ResearchState:
    """Run a Tavily web search for the topic; store it as research_data['web_search_results'].

    Any earlier Qdrant answer is passed along so the query chain can build on it.
    """
    topic = state["topic"]
    qdrant_results = state["research_data"].get("qdrant_results", "No previous results available.")
    result = tavily_chain.invoke({"topic": topic, "qdrant_results": qdrant_results})
    print(result)
    state["research_data"]["web_search_results"] = result
    state['workflow'].append("web_search")
    print(state['workflow'])
    return state


def research_supervisor(state: ResearchState) -> ResearchState:
    """Ask the research supervisor LLM for the next step and record its decision.

    Sets ``next`` (routing key of the research graph) and ``message_to_manager``
    from the reply's labeled lines; a field the reply omits keeps its old value.
    """
    supervisor_result = research_supervisor_chain.invoke({
        "message_from_manager": state["message_from_manager"],
        "collected_data": state["research_data"],
        "topic": state['topic'],
    })
    print(supervisor_result)
    state.update(_parse_labeled_lines(supervisor_result, _TEAM_SUPERVISOR_LABELS))
    state['workflow'].append("research_supervisor")
    print(state['workflow'])
    return state


def research_end(state: ResearchState) -> ResearchState:
    """Terminal node of the research sub-graph; only records the visit."""
    state['workflow'].append("research_end")
    print(state['workflow'])
    return state


#######################################
###     Writing Team Components     ###
#######################################

class WritingState(TypedDict):
    """State shared by the nodes of the writing sub-graph."""
    workflow: List[str]
    topic: str
    research_data: Dict[str, str]
    draft_posts: List[str]   # appended to by post_creation / post_editor; last entry is current
    final_post: str          # set by post_review once a draft is accepted
    next: str                # routing key chosen by writing_supervisor
    message_to_manager: str
    message_from_manager: str
    review_comments: str
    style_checked: bool      # NOTE(review): not read or written by any node here


#
# Writing Chains
#

writing_supervisor_chain = (
    prompts.writing_supervisor_prompt
    | models.gpt4o
    | StrOutputParser()
)

post_creation_chain = (
    prompts.post_creation_prompt
    | models.gpt4o_mini
    | StrOutputParser()
)

post_editor_chain = (
    prompts.post_editor_prompt
    | models.gpt4o
    | StrOutputParser()
)

post_review_chain = (
    prompts.post_review_prompt
    | models.gpt4o
    | StrOutputParser()
)


#
# Writing Node Defs
#

def post_creation(state: WritingState) -> WritingState:
    """Write a new draft from the research data, prior drafts and review comments."""
    results = post_creation_chain.invoke({
        "topic": state['topic'],
        "collected_data": state["research_data"],
        "drafts": state['draft_posts'],
        "review_comments": state['review_comments'],
    })
    print(results)
    state['draft_posts'].append(results)
    state['workflow'].append("post_creation")
    print(state['workflow'])
    return state


def post_editor(state: WritingState) -> WritingState:
    """Edit the latest draft against the style guide and append the edited version."""
    results = post_editor_chain.invoke({
        "current_draft": state['draft_posts'][-1],
        "styleguide": prompts.style_guide_text,
        "review_comments": state['review_comments'],
    })
    print(results)
    state['draft_posts'].append(results)
    state['workflow'].append("post_editor")
    print(state['workflow'])
    return state


def post_review(state: WritingState) -> WritingState:
    """Review the latest draft against the style guide; accept it if approved.

    The reviewer is expected to reply with a JSON object holding
    "Comments on current draft" and "Draft Acceptable". The comments are stored
    for the next creation/edit pass; an affirmative verdict copies the latest
    draft into ``final_post``.

    Raises:
        json.JSONDecodeError: if the reply is not JSON (after removing a code fence).
        KeyError: if either expected key is missing from the reply.
    """
    print("post_review node")
    current_draft = state['draft_posts'][-1]
    results = post_review_chain.invoke({
        "current_draft": current_draft,
        "styleguide": prompts.style_guide_text,
    })
    print(results)
    # Models frequently wrap JSON in ```json fences; plain json.loads would fail.
    data = _parse_json_response(results)
    state['review_comments'] = data["Comments on current draft"]
    # Accept "Yes" regardless of case/whitespace, and a JSON true.
    if str(data["Draft Acceptable"]).strip().lower() in ("yes", "true"):
        state['final_post'] = current_draft
    state['workflow'].append("post_review")
    print(state['workflow'])
    return state


def writing_end(state: WritingState) -> WritingState:
    """Terminal node of the writing sub-graph; only records the visit."""
    print("writing_end node")
    state['workflow'].append("writing_end")
    print(state['workflow'])
    return state


def writing_supervisor(state: WritingState) -> WritingState:
    """Ask the writing supervisor LLM for the next step and record its decision.

    Sets ``next`` (routing key of the writing graph) and ``message_to_manager``
    from the reply's labeled lines; a field the reply omits keeps its old value.
    """
    print("writing_supervisor node")
    supervisor_result = writing_supervisor_chain.invoke({
        "review_comments": state['review_comments'],
        "message_from_manager": state['message_from_manager'],
        "topic": state['topic'],
        "drafts": state['draft_posts'],
        "final_draft": state['final_post'],
    })
    print(supervisor_result)
    state.update(_parse_labeled_lines(supervisor_result, _TEAM_SUPERVISOR_LABELS))
    state['workflow'].append("writing_supervisor")
    print(state['workflow'])
    return state


#######################################
###   Overarching Graph Components  ###
#######################################

class State(TypedDict):
    """State of the top-level graph; a superset of the team sub-graph states."""
    workflow: List[str]
    topic: str
    research_data: Dict[str, str]
    draft_posts: List[str]
    final_post: str
    next: str
    user_input: str
    message_to_manager: str
    message_from_manager: str
    last_active_team: str    # team whose sub-graph most recently returned control
    next_team: str           # routing key chosen by overall_supervisor
    review_comments: str


#
# Complete Graph Chains
#

overall_supervisor_chain = (
    prompts.overall_supervisor_prompt
    | models.gpt4o
    | StrOutputParser()
)


#
# Complete Graph Node defs
#

def overall_supervisor(state: State) -> State:
    """Decide which team runs next, extracting the topic and a brief for that team.

    Sets ``next_team`` (routing key of the overall graph), ``topic`` and
    ``message_from_manager`` from the reply's labeled lines; a field the reply
    omits keeps its old value.
    """
    # ``next_team`` still holds the previous routing decision, i.e. the team
    # whose sub-graph just handed control back here (falsy on the first pass).
    # Nothing else writes ``last_active_team``, so without this the prompt
    # would always receive an empty value.
    state['last_active_team'] = state.get('next_team') or ""
    supervisor_result = overall_supervisor_chain.invoke({
        "query": state["user_input"],
        "message_to_manager": state['message_to_manager'],
        "last_active_team": state['last_active_team'],
        "final_post": state['final_post'],
    })
    print(supervisor_result)
    state.update(_parse_labeled_lines(supervisor_result, _OVERALL_SUPERVISOR_LABELS))
    state['workflow'].append("overall_supervisor")
    print(state['workflow'])
    return state


#######################################
###         Graph structures        ###
#######################################

#
# Research Graph Nodes
#
research_graph = StateGraph(ResearchState)
research_graph.add_node("query_qdrant", query_qdrant)
research_graph.add_node("web_search", web_search)
research_graph.add_node("research_supervisor", research_supervisor)
research_graph.add_node("research_end", research_end)
#
# Research Graph Edges
#
# Every worker reports back to the research supervisor, which routes on state["next"].
research_graph.set_entry_point("research_supervisor")
research_graph.add_edge("query_qdrant", "research_supervisor")
research_graph.add_edge("web_search", "research_supervisor")
research_graph.add_conditional_edges(
    "research_supervisor",
    lambda x: x["next"],
    {"query_qdrant": "query_qdrant", "web_search": "web_search", "FINISH": "research_end"},
)
# Make termination of the sub-graph explicit.
research_graph.add_edge("research_end", END)
research_graph_comp = research_graph.compile()

#
# Writing Graph Nodes
#
writing_graph = StateGraph(WritingState)
writing_graph.add_node("post_creation", post_creation)
writing_graph.add_node("post_editor", post_editor)
writing_graph.add_node("post_review", post_review)
writing_graph.add_node("writing_supervisor", writing_supervisor)
writing_graph.add_node("writing_end", writing_end)

#
# Writing Graph Edges
#
# A new draft goes create -> edit -> review and then back to the supervisor,
# which routes on state["next"].
writing_graph.set_entry_point("writing_supervisor")
writing_graph.add_edge("post_creation", "post_editor")
writing_graph.add_edge("post_editor", "post_review")
writing_graph.add_edge("post_review", "writing_supervisor")
writing_graph.add_conditional_edges(
    "writing_supervisor",
    lambda x: x["next"],
    {"NEW DRAFT": "post_creation", "FINISH": "writing_end"},
)
# Make termination of the sub-graph explicit.
writing_graph.add_edge("writing_end", END)
writing_graph_comp = writing_graph.compile()

#
# Complete Graph Nodes
#
overall_graph = StateGraph(State)
overall_graph.add_node("overall_supervisor", overall_supervisor)
overall_graph.add_node("research_team_graph", research_graph_comp)
overall_graph.add_node("writing_team_graph", writing_graph_comp)

#
# Complete Graph Edges
#
# Both team sub-graphs return control to the overall supervisor, which routes
# on state["next_team"].
overall_graph.set_entry_point("overall_supervisor")
overall_graph.add_edge("research_team_graph", "overall_supervisor")
overall_graph.add_edge("writing_team_graph", "overall_supervisor")
overall_graph.add_conditional_edges(
    "overall_supervisor",
    lambda x: x["next_team"],
    {"research_team": "research_team_graph", "writing_team": "writing_team_graph", "FINISH": END},
)

app = overall_graph.compile()


#######################################
###            Run method           ###
#######################################

def getSocialMediaPost(userInput: str, recursion_limit: int = 40) -> str:
    """Run the full research + writing workflow for a user request.

    Args:
        userInput: The user's raw request, handed to the overall supervisor.
        recursion_limit: Maximum number of graph steps before LangGraph aborts
            (defaults to the previously hard-coded 40).

    Returns:
        The accepted post (``final_post``); an empty string if no draft was
        accepted; or the literal "Recursion Error" if the step limit is exceeded.
    """
    initial_state = State(
        workflow=[],
        topic="",
        research_data={},
        draft_posts=[],
        # String fields start as "" (not []) so the declared types hold and the
        # function returns a str even when no draft is ever accepted.
        final_post="",
        next="",
        next_team="",
        user_input=userInput,
        message_to_manager="",
        message_from_manager="",
        last_active_team="",
        review_comments="",
    )
    try:
        # Invoke exactly once, inside the try: a second unguarded call would run
        # the whole LLM pipeline twice and let GraphRecursionError escape.
        results = app.invoke(initial_state, {"recursion_limit": recursion_limit})
    except GraphRecursionError:
        return "Recursion Error"
    return results['final_post']