Files
Anton_wireframe/app/services/langgraph_agent.py
T

163 lines
6.0 KiB
Python

from __future__ import annotations

import logging
from typing import Any, Dict, List, Optional

import chromadb
from langgraph.graph import END, START, StateGraph
from pydantic import BaseModel, Field
from sqlalchemy import func

from app.db.tables import InvestorTable
from app.pydantic_schemas import QueryResponse, QueryResponseList
class AgentState(BaseModel):
    """State object passed between the nodes of the query graph.

    Holds the user's question plus one result list per search branch.
    """

    # Raw question text supplied by the caller.
    question: str
    # Hits produced by the SQL search node (the merge node also writes its
    # de-duplicated output here).
    sql_results: List[QueryResponse] = Field(default_factory=list)
    # Hits produced by the Chroma vector search node.
    vector_results: List[QueryResponse] = Field(default_factory=list)
class LangGraphQueryAgent:
    """LangGraph agent that queries SQL and Chroma in parallel and merges the hits.

    The compiled graph fans out from ``START`` to two independent search nodes
    (``sql_search`` and ``vector_search``), joins them in ``merge`` and ends.
    The merged, de-duplicated hits are carried in ``AgentState.sql_results``.
    """

    # Used to report (rather than silently swallow) vector-search failures.
    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        sql_session: Optional[object] = None,
        vector_db_client: Optional[object] = None,
        *,
        chroma_path: str = "./chroma_db",
        max_results: int = 10,
    ) -> None:
        """Wire up the data sources and compile the search graph.

        Args:
            sql_session: Object exposing ``query(...)`` (presumably a SQLAlchemy
                ``Session``). When ``None`` the SQL branch returns no hits.
            vector_db_client: Chroma client exposing ``get_or_create_collection``.
                When ``None`` a ``chromadb.PersistentClient`` is created at
                *chroma_path*.
            chroma_path: Storage path for the default persistent client.
            max_results: Maximum number of hits each search branch returns.
        """
        self.sql_session = sql_session
        self.max_results = max_results
        if vector_db_client is None:
            vector_db_client = chromadb.PersistentClient(path=chroma_path)
        self.vector_db_client = vector_db_client
        self.collection = self.vector_db_client.get_or_create_collection(
            name="investor_descriptions",
            metadata={
                "description": "Investor descriptions and investment thesis focus",
            },
        )

        graph = StateGraph(AgentState)
        graph.add_node("sql_search", self._sql_search)
        graph.add_node("vector_search", self._vector_search)
        graph.add_node("merge", self._merge)
        # Parallel fan-out: START -> sql_search & vector_search -> merge -> END
        graph.add_edge(START, "sql_search")
        graph.add_edge(START, "vector_search")
        graph.add_edge("sql_search", "merge")
        graph.add_edge("vector_search", "merge")
        graph.add_edge("merge", END)
        self.app = graph.compile()

    # Helpers

    @staticmethod
    def _like_pattern(text: str) -> str:
        """Build a lower-cased ``%text%`` LIKE pattern that matches *text* literally.

        Surrounding whitespace is stripped. Backslash, ``%`` and ``_`` are
        escaped with a backslash, so the pattern must be used with
        ``escape="\\"``.
        """
        escaped = (
            text.strip()
            .lower()
            .replace("\\", "\\\\")
            .replace("%", "\\%")
            .replace("_", "\\_")
        )
        return f"%{escaped}%"

    @staticmethod
    def _first_batch(result: Any, key: str) -> List[Any]:
        """Return the list for the first query text under *key* of a Chroma result.

        Chroma returns one inner list per query text. Missing or ``None``
        entries yield ``[]``.
        """
        batches = result.get(key) or []
        if not batches:
            return []
        return list(batches[0] or [])

    @staticmethod
    def _fill_gaps(primary: QueryResponse, fallback: QueryResponse) -> QueryResponse:
        """Return *primary* with its falsy (empty/zero) fields filled from *fallback*."""
        return QueryResponse(
            name=primary.name,
            aum=primary.aum or fallback.aum,
            check_size=primary.check_size or fallback.check_size,
            sector_focus=primary.sector_focus or fallback.sector_focus,
            stage_focus=primary.stage_focus or fallback.stage_focus,
            region=primary.region or fallback.region,
            investment_thesis=primary.investment_thesis or fallback.investment_thesis,
            investor_description=primary.investor_description
            or fallback.investor_description,
            reason=primary.reason or fallback.reason,
        )

    # Nodes

    def _sql_search(self, state: AgentState) -> Dict[str, Any]:
        """Find investors whose name/sector/stage/region contains the question text.

        Both the column (via SQL ``lower()``) and the question are lower-cased.
        Note that SQLite's ``lower()`` folds ASCII only. LIKE wildcards in the
        question are escaped so they match literally.

        Returns:
            ``{"sql_results": [...]}`` with at most ``max_results`` hits. The
            list is empty when no SQL session is configured.
        """
        if self.sql_session is None:
            return {"sql_results": []}

        pattern = self._like_pattern(state.question)

        def matches(column: Any) -> Any:
            return func.lower(column).like(pattern, escape="\\")

        rows = (
            self.sql_session.query(InvestorTable)
            .filter(
                matches(InvestorTable.name)
                | matches(InvestorTable.sector_focus)
                | matches(InvestorTable.stage_focus)
                | matches(InvestorTable.region)
            )
            .limit(self.max_results)
            .all()
        )
        return {
            "sql_results": [
                QueryResponse(
                    name=row.name,
                    aum=row.aum,
                    check_size=row.check_size,
                    sector_focus=row.sector_focus,
                    stage_focus=row.stage_focus,
                    region=row.region,
                    investment_thesis="",
                    investor_description="",
                    reason="Matched SQL fields via LIKE",
                )
                for row in rows
            ]
        }

    def _vector_search(self, state: AgentState) -> Dict[str, Any]:
        """Query the Chroma collection for documents similar to the question.

        Each hit takes its ``name`` from metadata (default ``"Unknown"``), its
        ``region`` from the ``headquarters`` metadata key, and its description
        from the stored document text. All other fields are zero/empty
        placeholders.

        This is best-effort: on failure the exception is logged, and any hits
        built before the failure are still returned.
        """
        results: List[QueryResponse] = []
        try:
            raw = self.collection.query(
                query_texts=[state.question], n_results=self.max_results
            )
            documents = self._first_batch(raw, "documents")
            metadatas = self._first_batch(raw, "metadatas")
            for index, raw_meta in enumerate(metadatas):
                # Chroma may return None for entries stored without metadata.
                meta = raw_meta or {}
                document = documents[index] if index < len(documents) else None
                results.append(
                    QueryResponse(
                        name=meta.get("name", "Unknown"),
                        aum=0,
                        check_size="",
                        sector_focus="",
                        stage_focus="",
                        region=meta.get("headquarters", ""),
                        investment_thesis="",
                        investor_description=document or "",
                        reason="Vector similarity in Chroma",
                    )
                )
        except Exception:
            # Deliberately best-effort so the SQL branch still answers, but
            # never silent.
            self._logger.exception(
                "Chroma vector search failed; keeping %d partial result(s)",
                len(results),
            )
        return {"vector_results": results}

    def _merge(self, state: AgentState) -> Dict[str, Any]:
        """Deduplicate hits by investor name, letting SQL data take precedence.

        Output order is vector-similarity rank first, then SQL-only hits in
        query order. When an investor appears in both branches, the first SQL
        row's non-empty fields (including ``reason``) win. Vector data only
        fills the gaps.
        """
        merged: Dict[str, QueryResponse] = {}
        for hit in state.vector_results:
            prior = merged.get(hit.name)
            merged[hit.name] = hit if prior is None else self._fill_gaps(prior, hit)

        sql_seen = set()
        for row in state.sql_results:
            prior = merged.get(row.name)
            if prior is None:
                merged[row.name] = row
            elif row.name in sql_seen:
                # A later duplicate SQL row only fills gaps left by the first one.
                merged[row.name] = self._fill_gaps(prior, row)
            else:
                # First SQL row for this name: authoritative over vector metadata.
                merged[row.name] = self._fill_gaps(row, prior)
            sql_seen.add(row.name)

        # Carry the merged list in sql_results so it survives to END.
        return {
            "sql_results": list(merged.values()),
            "vector_results": [],
        }

    # Public API

    def run(self, question: str) -> QueryResponseList:
        """Run both searches for *question* and return the merged hits."""
        final_state = self.app.invoke(AgentState(question=question))
        # LangGraph's invoke returns the final state values as a dict rather
        # than an AgentState instance. Accept either shape to stay robust
        # across versions.
        if isinstance(final_state, dict):
            responses = final_state.get("sql_results") or []
        else:
            responses = final_state.sql_results
        return QueryResponseList(responses=list(responses))