from typing import Dict, List, TypedDict, Sequence
from langgraph.graph import StateGraph, END
from langchain.schema import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain_community.tools.tavily_search import TavilySearchResults
import models
import prompts
import json
from operator import itemgetter
from langgraph.errors import GraphRecursionError
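# `models` and `prompts` are local project modules; they are assumed to expose the
# chat models (gpt4o, gpt4o_mini), the Qdrant-backed compression_retriever, and the
# prompt templates / style guide text referenced throughout this file.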
#######################################
### Research Team Components ###
#######################################
class ResearchState(TypedDict):
    workflow: List[str]
    topic: str
    research_data: Dict[str, str]
    next: str
    message_to_manager: str
    message_from_manager: str

#
# Research Chains and Tools
#
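# qdrant_research_chain: pulls context for the topic from the Qdrant compression
# retriever, then has models.gpt4o_mini answer the research query; the output carries
# both the "response" and the retrieved "context".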
qdrant_research_chain = (
    {"context": itemgetter("topic") | models.compression_retriever, "topic": itemgetter("topic")}
    | RunnablePassthrough.assign(context=itemgetter("context"))
    | {"response": prompts.research_query_prompt | models.gpt4o_mini | StrOutputParser(), "context": itemgetter("context")}
)
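# Web search: query_chain turns the topic (plus any Qdrant findings) into a search
# query, and tavily_simple runs that query through Tavily and summarizes the results
# with models.gpt4o_mini.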
tavily_tool = TavilySearchResults(max_results=3)
query_chain = (prompts.search_query_prompt | models.gpt4o_mini | StrOutputParser())
tavily_simple = ({"tav_results": tavily_tool} | prompts.tavily_prompt | models.gpt4o_mini | StrOutputParser())
tavily_chain = (
    {"query": query_chain} | tavily_simple
)
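# The research supervisor reviews the data collected so far and replies in plain text
# with a "Next Action" (query_qdrant, web_search, or FINISH) plus a message for the
# project manager; research_supervisor() parses those lines below.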
research_supervisor_chain = (
    prompts.research_supervisor_prompt | models.gpt4o | StrOutputParser()
)

#
# Research Node Defs
#
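# Each node takes the shared ResearchState, records its results, appends itself to
# the workflow log, and returns the updated state.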
def query_qdrant(state: ResearchState) -> ResearchState:
    topic = state["topic"]
    result = qdrant_research_chain.invoke({"topic": topic})
    print(result)
    state["research_data"]["qdrant_results"] = result["response"]
    state['workflow'].append("query_qdrant")
    print(state['workflow'])
    return state

def web_search(state: ResearchState) -> ResearchState:
    topic = state["topic"]
    qdrant_results = state["research_data"].get("qdrant_results", "No previous results available.")
    result = tavily_chain.invoke({"topic": topic, "qdrant_results": qdrant_results})
    print(result)
    state["research_data"]["web_search_results"] = result
    state['workflow'].append("web_search")
    print(state['workflow'])
    return state

def research_supervisor(state):
    message_from_manager = state["message_from_manager"]
    collected_data = state["research_data"]
    topic = state['topic']
    supervisor_result = research_supervisor_chain.invoke({"message_from_manager": message_from_manager, "collected_data": collected_data, "topic": topic})
    lines = supervisor_result.split('\n')
    print(supervisor_result)
    for line in lines:
        if line.startswith('Next Action: '):
            state['next'] = line[len('Next Action: '):].strip()  # Extract the next action content
        elif line.startswith('Message to project manager: '):
            state['message_to_manager'] = line[len('Message to project manager: '):].strip()
    state['workflow'].append("research_supervisor")
    print(state['workflow'])
    return state

def research_end(state):
    state['workflow'].append("research_end")
    print(state['workflow'])
    return state
#######################################
### Writing Team Components ###
#######################################
class WritingState(TypedDict):
    workflow: List[str]
    topic: str
    research_data: Dict[str, str]
    draft_posts: Sequence[str]
    final_post: str
    next: str
    message_to_manager: str
    message_from_manager: str
    review_comments: str
    style_checked: bool

#
# Writing Chains
#
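# Four chains drive the writing team: the supervisor routes between drafting and
# finishing, post_creation drafts a post from the research data, post_editor revises
# the latest draft against the style guide, and post_review returns a JSON verdict.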
writing_supervisor_chain = (
    prompts.writing_supervisor_prompt | models.gpt4o | StrOutputParser()
)
post_creation_chain = (
    prompts.post_creation_prompt | models.gpt4o_mini | StrOutputParser()
)
post_editor_chain = (
    prompts.post_editor_prompt | models.gpt4o | StrOutputParser()
)
post_review_chain = (
    prompts.post_review_prompt | models.gpt4o | StrOutputParser()
)

#
# Writing Node Defs
#
def post_creation(state):
    topic = state['topic']
    drafts = state['draft_posts']
    collected_data = state["research_data"]
    review_comments = state['review_comments']
    results = post_creation_chain.invoke({"topic": topic, "collected_data": collected_data, "drafts": drafts, "review_comments": review_comments})
    print(results)
    state['draft_posts'].append(results)
    state['workflow'].append("post_creation")
    print(state['workflow'])
    return state

def post_editor(state):
    current_draft = state['draft_posts'][-1]
    styleguide = prompts.style_guide_text
    review_comments = state['review_comments']
    results = post_editor_chain.invoke({"current_draft": current_draft, "styleguide": styleguide, "review_comments": review_comments})
    print(results)
    state['draft_posts'].append(results)
    state['workflow'].append("post_editor")
    print(state['workflow'])
    return state
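# post_review expects the reviewer model to answer with a JSON object containing
# "Draft Acceptable" and "Comments on current draft"; the draft is promoted to
# final_post only when the verdict is "Yes".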
def post_review(state):
    print("post_review node")
    current_draft = state['draft_posts'][-1]
    styleguide = prompts.style_guide_text
    results = post_review_chain.invoke({"current_draft": current_draft, "styleguide": styleguide})
    print(results)
    data = json.loads(results.strip())
    state['review_comments'] = data["Comments on current draft"]
    if data["Draft Acceptable"] == 'Yes':
        state['final_post'] = state['draft_posts'][-1]
    state['workflow'].append("post_review")
    print(state['workflow'])
    return state
def writing_end(state):
    print("writing_end node")
    state['workflow'].append("writing_end")
    print(state['workflow'])
    return state

def writing_supervisor(state):
    print("writing_supervisor node")
    message_from_manager = state['message_from_manager']
    topic = state['topic']
    drafts = state['draft_posts']
    final_draft = state['final_post']
    review_comments = state['review_comments']
    supervisor_result = writing_supervisor_chain.invoke({"review_comments": review_comments, "message_from_manager": message_from_manager, "topic": topic, "drafts": drafts, "final_draft": final_draft})
    print(supervisor_result)
    lines = supervisor_result.split('\n')
    for line in lines:
        if line.startswith('Next Action: '):
            state['next'] = line[len('Next Action: '):].strip()  # Extract the next action content
        elif line.startswith('Message to project manager: '):
            state['message_to_manager'] = line[len('Message to project manager: '):].strip()
    state['workflow'].append("writing_supervisor")
    print(state['workflow'])
    return state

#######################################
### Overarching Graph Components ###
#######################################
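# Top-level state for the overall graph. Keys that also appear in ResearchState and
# WritingState (topic, research_data, draft_posts, ...) are expected to flow through
# to the compiled sub-graphs when they run as nodes.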
class State(TypedDict):
    workflow: List[str]
    topic: str
    research_data: Dict[str, str]
    draft_posts: Sequence[str]
    final_post: str
    next: str
    user_input: str
    message_to_manager: str
    message_from_manager: str
    last_active_team: str
    next_team: str
    review_comments: str

#
# Complete Graph Chains
#
overall_supervisor_chain = (
    prompts.overall_supervisor_prompt | models.gpt4o | StrOutputParser()
)

#
# Complete Graph Node defs
#
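# The overall supervisor parses its plain-text reply for the routing decision
# ("Next Action"), the topic extracted from the user's request, and the message to
# pass down to the team supervisors.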
def overall_supervisor(state):
    init_user_query = state["user_input"]
    message_to_manager = state['message_to_manager']
    last_active_team = state['last_active_team']
    final_post = state['final_post']
    supervisor_result = overall_supervisor_chain.invoke({"query": init_user_query, "message_to_manager": message_to_manager, "last_active_team": last_active_team, "final_post": final_post})
    print(supervisor_result)
    lines = supervisor_result.split('\n')
    for line in lines:
        if line.startswith('Next Action: '):
            state['next_team'] = line[len('Next Action: '):].strip()  # Extract the next team to route to
        elif line.startswith('Extracted Topic: '):
            state['topic'] = line[len('Extracted Topic: '):].strip()  # Extract the topic
        elif line.startswith('Message to supervisor: '):
            state['message_from_manager'] = line[len('Message to supervisor: '):].strip()  # Extract the message for the team supervisor
    state['workflow'].append("overall_supervisor")
    print(state['workflow'])
    return state
#######################################
### Graph structures ###
#######################################
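# The research and writing teams are each built as their own StateGraph, compiled,
# and then mounted as single nodes inside the overall graph below.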
#
# Research Graph Nodes
#
research_graph = StateGraph(ResearchState)
research_graph.add_node("query_qdrant", query_qdrant)
research_graph.add_node("web_search", web_search)
research_graph.add_node("research_supervisor", research_supervisor)
research_graph.add_node("research_end", research_end)

#
# Research Graph Edges
#
research_graph.set_entry_point("research_supervisor")
research_graph.add_edge("query_qdrant", "research_supervisor")
research_graph.add_edge("web_search", "research_supervisor")
research_graph.add_conditional_edges(
    "research_supervisor",
    lambda x: x["next"],
    {"query_qdrant": "query_qdrant", "web_search": "web_search", "FINISH": "research_end"},
)
research_graph_comp = research_graph.compile()
#
# Writing Graph Nodes
#
writing_graph = StateGraph(WritingState)
writing_graph.add_node("post_creation", post_creation)
writing_graph.add_node("post_editor", post_editor)
writing_graph.add_node("post_review", post_review)
writing_graph.add_node("writing_supervisor", writing_supervisor)
writing_graph.add_node("writing_end", writing_end)

#
# Writing Graph Edges
#
writing_graph.set_entry_point("writing_supervisor")
writing_graph.add_edge("post_creation", "post_editor")
writing_graph.add_edge("post_editor", "post_review")
writing_graph.add_edge("post_review", "writing_supervisor")
writing_graph.add_conditional_edges(
    "writing_supervisor",
    lambda x: x["next"],
    {"NEW DRAFT": "post_creation",
     "FINISH": "writing_end"},
)
writing_graph_comp = writing_graph.compile()
#
# Complete Graph Nodes
#
overall_graph = StateGraph(State)
overall_graph.add_node("overall_supervisor", overall_supervisor)
overall_graph.add_node("research_team_graph", research_graph_comp)
overall_graph.add_node("writing_team_graph", writing_graph_comp)

#
# Complete Graph Edges
#
overall_graph.set_entry_point("overall_supervisor")
overall_graph.add_edge("research_team_graph", "overall_supervisor")
overall_graph.add_edge("writing_team_graph", "overall_supervisor")
overall_graph.add_conditional_edges(
    "overall_supervisor",
    lambda x: x["next_team"],
    {"research_team": "research_team_graph",
     "writing_team": "writing_team_graph",
     "FINISH": END},
)
app = overall_graph.compile()
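# app is the fully wired application: the overall supervisor keeps routing work to
# the research and writing sub-graphs until it decides to FINISH, which maps to END.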
#######################################
### Run method ###
#######################################
def getSocialMediaPost(userInput: str) -> str:
    finalPost = ""
    initial_state = State(
        workflow=[],
        topic="",
        research_data={},
        draft_posts=[],
        final_post="",
        next="",
        next_team="",
        user_input=userInput,
        message_to_manager="",
        message_from_manager="",
        last_active_team="",
        review_comments=""
    )
    try:
        results = app.invoke(initial_state, {"recursion_limit": 40})
    except GraphRecursionError:
        return "Recursion Error"
    finalPost = results['final_post']
    return finalPost
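# Example usage (illustrative sketch; assumes the OpenAI, Tavily, and Qdrant
# credentials required by the models module are configured in the environment):
if __name__ == "__main__":
    post = getSocialMediaPost("Write a LinkedIn post about retrieval-augmented generation.")
    print(post)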