198 lines
6.3 KiB
Python
198 lines
6.3 KiB
Python
import json
|
|
from typing import List, Dict, Optional, TypedDict, Sequence, Annotated
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from datetime import datetime
|
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
|
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
|
|
from langgraph.checkpoint.memory import MemorySaver
|
|
from langgraph.graph import START, MessagesState, StateGraph
|
|
from utils.utils import format_questions_text
|
|
from src.prompts import chat_prompt
|
|
from langchain_openai import ChatOpenAI
|
|
@dataclass
class Message:
    """A single turn of a stored conversation.

    Instances are rebuilt from the JSON conversation store, one per stored
    record, with the same three fields as those records.
    """

    # Who produced the turn: 'human' or 'ai'.
    role: str
    # Text of the turn exactly as stored.
    content: str
    # Creation time string (add_message writes datetime.now().isoformat()).
    timestamp: str
|
|
|
|
# Location of the question configuration file, relative to the working directory.
QUESTIONS_PATH = "./data/config_files/questions.json"

# Loaded once at import time; raises FileNotFoundError / json.JSONDecodeError
# if the file is missing or malformed. JSON text is UTF-8 by specification, so
# decode it explicitly instead of relying on the platform default encoding.
# NOTE(review): `questions` is not referenced elsewhere in the visible code.
with open(QUESTIONS_PATH, "r", encoding="utf-8") as f:
    questions = json.load(f)

# Module-level prompt template read by initialize_workflow's model node when no
# template is supplied to it; stays None until a caller assigns one.
prompt_template = None

# OpenAI chat model name used by ai_chat.
MODEL = "gpt-4o-mini"
|
|
def initialize_workflow(
    model,
    template: Optional[ChatPromptTemplate] = None,
    language: str = "English",
) -> StateGraph:
    """Build and compile a single-node LangGraph chat workflow around ``model``.

    The graph has one node, ``"model"``. It renders a prompt template with the
    conversation messages and a ``language`` value, then invokes ``model`` and
    appends the reply to ``messages``. State is checkpointed in an in-process
    ``MemorySaver``.

    Args:
        model: Chat model exposing ``invoke(prompt)`` (for example a
            ``ChatOpenAI`` instance).
        template: Prompt template to render. When omitted, the module-level
            ``prompt_template`` is looked up each time the node runs, which
            preserves the original behaviour for callers that assign that
            global.
        language: Fallback ``language`` value, used when the graph input does
            not supply one.

    Returns:
        The compiled graph returned by ``StateGraph.compile``, not the builder.
        The ``StateGraph`` annotation is kept only for compatibility.

    Raises:
        RuntimeError: When the node runs with no ``template`` argument and the
            module-level ``prompt_template`` is still ``None``.
    """

    class State(MessagesState):
        # MessagesState only declares "messages" (with its append reducer);
        # without this field the graph had no "language" channel and the
        # node's state["language"] lookup could not succeed.
        language: str

    workflow = StateGraph(state_schema=State)
    memory = MemorySaver()

    def call_model(state: State):
        active_template = template if template is not None else prompt_template
        if active_template is None:
            # Previously surfaced as "'NoneType' object has no attribute 'invoke'".
            raise RuntimeError(
                "No prompt template configured: pass `template` to "
                "initialize_workflow or set the module-level `prompt_template`."
            )
        prompt = active_template.invoke({
            "messages": state["messages"],
            "language": state.get("language") or language,
        })
        response = model.invoke(prompt)
        return {"messages": [response]}

    workflow.add_node("model", call_model)
    workflow.add_edge(START, "model")

    return workflow.compile(checkpointer=memory)
|
|
|
|
|
|
def setup_prompt_template(theme: int, resume: str) -> ChatPromptTemplate:
    """Create the chat prompt used by the model node.

    The prompt is a system message whose text comes from
    ``chat_prompt(theme, resume)``, followed by a placeholder that is filled
    with the conversation's ``messages`` at render time.
    """
    # NOTE(review): ChatPromptTemplate parses the system text as a template, so
    # literal '{' / '}' characters (e.g. in the resume) may be read as template
    # variables. Confirm that chat_prompt escapes them.
    system_text = chat_prompt(theme, resume)
    history_slot = MessagesPlaceholder(variable_name="messages")
    parts = [("system", system_text), history_slot]
    return ChatPromptTemplate.from_messages(parts)
|
|
|
|
def parse_ai_response(content: str) -> Dict:
    """Parse a model reply into ``{"message": ..., "end": bool}``.

    The reply is expected to be a JSON object with a ``"message"`` field and an
    ``"end"`` field. ``end`` is ``True`` only when that field equals the string
    ``"yes"``; a missing ``message`` becomes ``""``.

    Anything that is not a JSON object is returned verbatim as the message,
    with ``end`` set to ``False``. That covers invalid JSON as well as valid
    JSON of another type, such as a list, string, number or ``null``.
    """
    try:
        response = json.loads(content)
    except json.JSONDecodeError:
        response = None

    if not isinstance(response, dict):
        # Valid-but-non-object JSON used to crash on ``.get``; treat it like
        # unparseable text.
        return {
            "message": content,
            "end": False,
        }

    return {
        "message": response.get("message", ""),
        "end": response.get("end", "no") == "yes",
    }
|
|
|
|
def add_message(storage_path: Path, conversation_id: str, role: str, content: str) -> None:
    """Append one message to a conversation and write the whole store back.

    The full conversation store is read from ``storage_path``. An empty entry
    is created for ``conversation_id`` if it is absent. A record with ``role``,
    ``content`` and a local-time ISO timestamp is appended, and the store is
    saved.
    """
    entry = {
        "role": role,
        "content": content,
        "timestamp": datetime.now().isoformat(),
    }

    store = load_conversations(storage_path)
    thread = store.setdefault(conversation_id, {"messages": []})
    thread["messages"].append(entry)
    save_conversations(storage_path, store)
|
|
|
|
|
|
def get_conversation_history(conversation_id: str, storage_path: Path) -> Optional[List[Message]]:
    """Return one conversation's stored messages as ``Message`` objects.

    Args:
        conversation_id: Key of the conversation in the store.
        storage_path: JSON file holding all conversations.

    Returns:
        The messages in stored order, or ``None`` when ``conversation_id`` is
        not in the store. The previous ``List[Message]`` annotation hid this
        ``None`` case; callers must handle it.
    """
    conversations = load_conversations(storage_path)
    if conversation_id not in conversations:
        return None

    return [
        Message(
            role=record["role"],
            content=record["content"],
            timestamp=record["timestamp"],
        )
        for record in conversations[conversation_id]["messages"]
    ]
|
|
|
|
def load_conversations(storage_path: Path) -> Dict:
    """Read every stored conversation from the JSON file at ``storage_path``.

    A missing file yields an empty dict. Any other problem, such as malformed
    JSON or a permission error, propagates to the caller.
    """
    try:
        handle = open(storage_path, 'r')
    except FileNotFoundError:
        return {}
    with handle:
        return json.load(handle)
|
|
|
|
def save_conversations(storage_path: Path, conversations: Dict) -> None:
    """Write all conversations to ``storage_path`` as JSON indented by 2 spaces.

    The data is first written to a sibling ``<name>.tmp`` file, which is then
    moved over the target. Writing in place truncated the store before
    serializing, so a failure inside ``json.dump`` (e.g. an unserializable
    value) left an empty or partial file that ``load_conversations`` cannot
    parse. With this approach the previous store stays intact.

    Raises:
        TypeError, ValueError: From ``json.dump`` for data it cannot encode.
        OSError: On filesystem errors.
    """
    target = Path(storage_path)
    tmp_path = target.with_name(target.name + ".tmp")
    try:
        with open(tmp_path, 'w', encoding="utf-8") as f:
            json.dump(conversations, f, indent=2)
        # Path.replace overwrites an existing target (os.replace semantics).
        tmp_path.replace(target)
    except BaseException:
        # Don't leave a half-written temp file behind; re-raise the original error.
        tmp_path.unlink(missing_ok=True)
        raise
|
|
|
|
def convert_to_langchain_messages(messages: List[Message]) -> List[HumanMessage | AIMessage]:
    """Map stored ``Message`` records onto LangChain message objects.

    A role of ``"human"`` becomes a ``HumanMessage``. Every other role,
    including ``"ai"``, becomes an ``AIMessage``. Input order is preserved.
    """
    return [
        (HumanMessage if record.role == "human" else AIMessage)(content=record.content)
        for record in messages
    ]
|
|
|
|
|
|
def ai_chat(
    query: str,
    conversation_id: str,
    theme_id: int,
    resume: str,
    *,
    language: str = "English",
    storage_path: Optional[Path] = None,
) -> str:
    """Send ``query`` within a persisted conversation and return the model's reply.

    The stored history for ``conversation_id`` is replayed to the model along
    with the new query. The query (when non-empty) and the reply are then
    appended to the JSON store.

    Args:
        query: The user's message. When empty, a new conversation is opened
            with the kickoff text "Let's get started", which is not stored. An
            existing conversation is continued from its history alone.
        conversation_id: Key of the conversation in the store. It is also used
            as the LangGraph ``thread_id``.
        theme_id: Forwarded to ``setup_prompt_template``.
        resume: Forwarded to ``setup_prompt_template``.
        language: Value passed to the prompt as ``language``. Previously this
            was hard-coded to "English", which remains the default.
        storage_path: JSON conversation store. Defaults to
            ``conversations.json`` in the working directory.

    Returns:
        The text content of the model's reply.
    """
    if storage_path is None:
        storage_path = Path("conversations.json")

    class State(TypedDict):
        # NOTE(review): the Annotated metadata is a plain string rather than a
        # reducer, so LangGraph presumably treats "messages" as last-value: the
        # node's return replaces the list (output["messages"] == [reply]).
        messages: Annotated[Sequence[BaseMessage], "The messages in the conversation"]
        language: str

    model = ChatOpenAI(model=MODEL)
    # Built once per call rather than inside the node on every execution.
    template = setup_prompt_template(theme_id, resume)

    def call_model(state: State):
        prompt = template.invoke({
            "messages": state["messages"],
            "language": state["language"],
        })
        return {"messages": [model.invoke(prompt)]}

    workflow = StateGraph(state_schema=State)
    workflow.add_node("model", call_model)
    workflow.add_edge(START, "model")
    # NOTE(review): a fresh MemorySaver is created on every call, so nothing is
    # checkpointed across calls; the JSON store is the only source of history.
    app = workflow.compile(checkpointer=MemorySaver())

    history = get_conversation_history(conversation_id, storage_path)
    if history:
        # Existing conversation: replay it, plus the new query if there is one.
        input_messages = convert_to_langchain_messages(history)
        if query:
            input_messages.append(HumanMessage(content=query))
    else:
        # New conversation (None or empty history).
        input_messages = [HumanMessage(content=query or "Let's get started")]

    output = app.invoke(
        {"messages": input_messages, "language": language},
        {"configurable": {"thread_id": conversation_id}},
    )
    reply = output["messages"][-1].content

    # Persist the turn; the kickoff text for an empty query is not stored.
    if query:
        add_message(storage_path, conversation_id, "human", query)
    add_message(storage_path, conversation_id, "ai", reply)

    return reply
|
|
|
|
|
|
# Example usage:
if __name__ == "__main__":
    # Sample resume
    sample_resume = """
    John Doe
    EMT-B Certified
    5 years experience as volunteer firefighter
    Bachelor's in Fire Science
    """

    # Sample conversation
    conversation_id = "12345"
    theme_id = 1  # NOTE(review): originally labelled "Customer Service theme"; confirm against chat_prompt

    # Start conversation. An empty query opens a new conversation; if one is
    # already stored under this id, the model continues from its history.
    opening = ai_chat(
        query="",
        conversation_id=conversation_id,
        theme_id=theme_id,
        resume=sample_resume,
    )
    print("AI:", opening)

    # Continue conversation
    follow_up = ai_chat(
        query="What was my last questions?",
        conversation_id=conversation_id,
        theme_id=theme_id,
        resume=sample_resume,
    )
    print("AI:", follow_up)
|