from __future__ import annotations

import logging
from typing import Any, Dict, List, Optional

import chromadb
from langgraph.graph import END, START, StateGraph
from pydantic import BaseModel, Field
from sqlalchemy import func, or_

from app.db.tables import InvestorTable
from app.pydantic_schemas import QueryResponse, QueryResponseList

logger = logging.getLogger(__name__)


class AgentState(BaseModel):
    """State passed between the search and merge nodes of the graph."""

    # The user's free-text query.
    question: str
    # Hits from the SQL search; after ``merge`` runs this holds the final,
    # de-duplicated result list.
    sql_results: List[QueryResponse] = Field(default_factory=list)
    # Hits from the Chroma similarity search; cleared by ``merge``.
    vector_results: List[QueryResponse] = Field(default_factory=list)


class LangGraphQueryAgent:
    """Simple LangGraph agent that queries both SQL and Chroma and merges results.

    Graph shape: START fans out to ``sql_search`` and ``vector_search``, both
    feed ``merge``, and ``merge`` leads to END.
    """

    def __init__(
        self,
        sql_session: Optional[object] = None,
        vector_db_client: Optional[object] = None,
        max_results: int = 10,
    ) -> None:
        """Set up the Chroma collection and compile the query graph.

        Args:
            sql_session: Session object exposing ``query()``; when ``None`` the
                SQL branch contributes no results.
            vector_db_client: Chroma client; when not given, a
                ``chromadb.PersistentClient`` rooted at ``./chroma_db`` is used.
            max_results: Maximum number of hits requested from each backend
                (SQL ``LIMIT`` and Chroma ``n_results``).
        """
        self.sql_session = sql_session
        self.max_results = max_results

        # Setup Chroma collection
        self.vector_db_client = vector_db_client or chromadb.PersistentClient(
            path="./chroma_db"
        )
        self.collection = self.vector_db_client.get_or_create_collection(
            name="investor_descriptions",
            metadata={
                "description": "Investor descriptions and investment thesis focus",
            },
        )

        # Build graph
        graph = StateGraph(AgentState)
        graph.add_node("sql_search", self._sql_search)
        graph.add_node("vector_search", self._vector_search)
        graph.add_node("merge", self._merge)

        # Parallel fan-out: START -> sql_search & vector_search -> merge -> END
        graph.add_edge(START, "sql_search")
        graph.add_edge(START, "vector_search")
        graph.add_edge("sql_search", "merge")
        graph.add_edge("vector_search", "merge")
        graph.add_edge("merge", END)

        self.app = graph.compile()

    # Nodes

    def _sql_search(self, state: AgentState) -> Dict[str, Any]:
        """Case-insensitive substring search of the question over SQL columns.

        The whole question is matched as one substring against the investor's
        name, sector focus, stage focus and region.

        Returns:
            ``{"sql_results": [...]}`` with at most ``max_results`` hits;
            empty when no SQL session was configured.
        """
        if self.sql_session is None:
            return {"sql_results": []}

        needle = state.question.lower()
        searchable_columns = (
            InvestorTable.name,
            InvestorTable.sector_focus,
            InvestorTable.stage_focus,
            InvestorTable.region,
        )
        # ``contains`` renders ``LIKE '%' || :param || '%'`` with a bound
        # parameter; ``autoescape`` makes any ``%``/``_`` typed by the user
        # match literally instead of acting as LIKE wildcards.
        any_column_matches = or_(
            *(
                func.lower(column).contains(needle, autoescape=True)
                for column in searchable_columns
            )
        )
        rows = (
            self.sql_session.query(InvestorTable)
            .filter(any_column_matches)
            .limit(self.max_results)
            .all()
        )

        results = [
            QueryResponse(
                name=row.name,
                aum=row.aum,
                check_size=row.check_size,
                sector_focus=row.sector_focus,
                stage_focus=row.stage_focus,
                region=row.region,
                investment_thesis="",
                investor_description="",
                reason="Matched SQL fields via LIKE",
            )
            for row in rows
        ]
        return {"sql_results": results}

    def _vector_search(self, state: AgentState) -> Dict[str, Any]:
        """Similarity search of the question against the Chroma collection.

        Best-effort: failures are logged and whatever hits were parsed before
        the failure (usually none) are returned, so the SQL branch can still
        answer.

        Returns:
            ``{"vector_results": [...]}`` with at most ``max_results`` hits.
        """
        results: List[QueryResponse] = []
        try:
            q = self.collection.query(
                query_texts=[state.question], n_results=self.max_results
            )
            # q has keys: ids, distances, documents, metadatas. Each value holds
            # one inner list per query text, and exactly one text was sent.
            docs = (q.get("documents") or [[]])[0] or []
            metas = (q.get("metadatas") or [[]])[0] or []
            for i, md in enumerate(metas):
                # Metadata may be None for records stored without it.
                md = md or {}
                doc = docs[i] if i < len(docs) else None
                results.append(
                    QueryResponse(
                        name=md.get("name") or "Unknown",
                        aum=0,
                        check_size="",
                        sector_focus="",
                        stage_focus="",
                        region=md.get("headquarters", ""),
                        investment_thesis="",
                        investor_description=doc or "",
                        reason="Vector similarity in Chroma",
                    )
                )
        except Exception:
            # Deliberately best-effort, but never silent: a Chroma outage
            # should show up in the logs rather than as "no matches".
            logger.exception("Chroma vector search failed")
        return {"vector_results": results}

    def _merge(self, state: AgentState) -> Dict[str, Any]:
        """De-duplicate SQL and vector hits by investor name.

        SQL hits are visited first, so their column values take precedence
        over the placeholder/metadata values produced by the vector search.
        Blank fields are filled from later duplicates, and the first non-empty
        reason is kept. Output order is SQL hits first, then vector-only hits.

        Returns:
            The merged list under ``sql_results`` (so ``run`` can read it from
            the final state) and an empty ``vector_results``.
        """
        merged: Dict[str, QueryResponse] = {}
        for item in state.sql_results + state.vector_results:
            existing = merged.get(item.name)
            merged[item.name] = (
                item if existing is None else self._fill_blanks(existing, item)
            )
        # Store back into sql_results to pass through the END with full state
        return {
            "sql_results": list(merged.values()),
            "vector_results": [],
        }

    @staticmethod
    def _fill_blanks(primary: QueryResponse, fallback: QueryResponse) -> QueryResponse:
        """Return a copy of ``primary`` whose falsy fields come from ``fallback``."""
        return QueryResponse(
            name=primary.name,
            aum=primary.aum or fallback.aum,
            check_size=primary.check_size or fallback.check_size,
            sector_focus=primary.sector_focus or fallback.sector_focus,
            stage_focus=primary.stage_focus or fallback.stage_focus,
            region=primary.region or fallback.region,
            investment_thesis=primary.investment_thesis or fallback.investment_thesis,
            investor_description=(
                primary.investor_description or fallback.investor_description
            ),
            reason=primary.reason or fallback.reason,
        )

    # Public API

    def run(self, question: str) -> QueryResponseList:
        """Run the graph for ``question`` and return the merged hits."""
        final_state = self.app.invoke(AgentState(question=question))
        # LangGraph returns the final state as a dict even when the state
        # schema is a Pydantic model; accept either shape.
        if isinstance(final_state, dict):
            responses = final_state.get("sql_results") or []
        else:
            responses = final_state.sql_results
        return QueryResponseList(responses=list(responses))