"""LangGraph agent that searches investors in SQL and Chroma and merges the hits."""
from __future__ import annotations

import logging
from typing import Any, Dict, List, Optional

import chromadb
from langgraph.graph import END, START, StateGraph
from pydantic import BaseModel, Field
from sqlalchemy import func

from app.db.tables import InvestorTable
from app.pydantic_schemas import QueryResponse, QueryResponseList
|
|
|
|
|
|
class AgentState(BaseModel):
    """State passed between the nodes of the investor query graph."""

    # Free-text question supplied by the caller.
    question: str

    # Hits written by the SQL search node. The merge node later overwrites
    # this field with the final, de-duplicated result list.
    sql_results: List[QueryResponse] = Field(default_factory=list)

    # Hits written by the Chroma vector search node. The merge node resets
    # this field to an empty list once the hits have been folded together.
    vector_results: List[QueryResponse] = Field(default_factory=list)
|
|
|
|
|
|
class LangGraphQueryAgent:
    """Simple LangGraph agent that queries both SQL and Chroma and merges results.

    The compiled graph fans out from ``START`` to the ``sql_search`` and
    ``vector_search`` nodes, joins both in ``merge`` and then ends. The merged,
    de-duplicated hits are exposed through :meth:`run`.
    """

    def __init__(
        self,
        sql_session: Optional[object] = None,
        vector_db_client: Optional[object] = None,
    ) -> None:
        """Create the agent and compile its graph.

        Args:
            sql_session: Object exposing a SQLAlchemy-style ``query()`` method.
                When ``None`` the SQL search node returns no hits.
            vector_db_client: Chroma client to use. When ``None`` a
                ``chromadb.PersistentClient`` rooted at ``./chroma_db`` is
                created.
        """
        self.sql_session = sql_session

        # Setup Chroma collection. An explicit ``is None`` test is used so a
        # caller-supplied client is never replaced just because it is falsy.
        if vector_db_client is None:
            vector_db_client = chromadb.PersistentClient(path="./chroma_db")
        self.vector_db_client = vector_db_client
        self.collection = self.vector_db_client.get_or_create_collection(
            name="investor_descriptions",
            metadata={
                "description": "Investor descriptions and investment thesis focus",
            },
        )

        # Build graph
        graph = StateGraph(AgentState)
        graph.add_node("sql_search", self._sql_search)
        graph.add_node("vector_search", self._vector_search)
        graph.add_node("merge", self._merge)

        # Parallel fan-out: START -> sql_search & vector_search -> merge -> END
        graph.add_edge(START, "sql_search")
        graph.add_edge(START, "vector_search")
        graph.add_edge("sql_search", "merge")
        graph.add_edge("vector_search", "merge")
        graph.add_edge("merge", END)

        self.app = graph.compile()

    # Nodes
    def _sql_search(self, state: AgentState) -> Dict[str, Any]:
        """Substring-match the question against a few ``InvestorTable`` columns.

        The question is stripped and lower-cased, then matched with ``LIKE``
        against the lower-cased ``name``, ``sector_focus``, ``stage_focus`` and
        ``region`` columns. At most 10 rows are returned.

        Returns:
            ``{"sql_results": [...]}``. The list is empty when no SQL session
            was configured or when the question is blank.
        """
        results: List[QueryResponse] = []
        if self.sql_session is None:
            return {"sql_results": results}

        term = state.question.strip().lower()
        if not term:
            # A blank question would become the pattern "%%", which matches
            # every row and would report arbitrary investors as "matched".
            return {"sql_results": results}

        # The question is user input: escape LIKE wildcards so that "%" and
        # "_" are matched literally instead of widening the search.
        escaped = term.replace("/", "//").replace("%", "/%").replace("_", "/_")
        pattern = f"%{escaped}%"

        def matches(column: Any) -> Any:
            # Lower-casing both sides keeps the comparison case-insensitive.
            return func.lower(column).like(pattern, escape="/")

        rows = (
            self.sql_session.query(InvestorTable)
            .filter(
                matches(InvestorTable.name)
                | matches(InvestorTable.sector_focus)
                | matches(InvestorTable.stage_focus)
                | matches(InvestorTable.region)
            )
            .limit(10)
            .all()
        )

        results.extend(
            QueryResponse(
                name=row.name,
                aum=row.aum,
                check_size=row.check_size,
                sector_focus=row.sector_focus,
                stage_focus=row.stage_focus,
                region=row.region,
                investment_thesis="",
                investor_description="",
                reason="Matched SQL fields via LIKE",
            )
            for row in rows
        )
        return {"sql_results": results}

    def _vector_search(self, state: AgentState) -> Dict[str, Any]:
        """Query the Chroma collection for the 10 nearest entries.

        The search is best-effort: any failure is logged and the hits
        collected so far (possibly none) are returned instead of raising.

        Returns:
            ``{"vector_results": [...]}``.
        """
        results: List[QueryResponse] = []
        try:
            q = self.collection.query(query_texts=[state.question], n_results=10)
            # q has keys: ids, distances, documents, metadatas
            docs = q.get("documents") or []
            metas = q.get("metadatas") or []
            if docs and metas:
                # Only one query text is sent, so only the first result list
                # is read.
                first_docs = docs[0] or []
                for i, md in enumerate(metas[0] or []):
                    # Tolerate a missing metadata entry instead of aborting
                    # the whole loop on ``None.get``.
                    md = md or {}
                    # Tolerate a document list shorter than the metadata list.
                    description = first_docs[i] if i < len(first_docs) else ""
                    results.append(
                        QueryResponse(
                            name=md.get("name", "Unknown"),
                            aum=0,
                            check_size="",
                            sector_focus="",
                            stage_focus="",
                            region=md.get("headquarters", ""),
                            investment_thesis="",
                            investor_description=description or "",
                            reason="Vector similarity in Chroma",
                        )
                    )
        except Exception:
            # Best-effort: keep the agent usable without vector hits, but
            # leave a trace so the failure is not invisible.
            logging.getLogger(__name__).exception(
                "Chroma vector search failed; continuing with partial results"
            )

        return {"vector_results": results}

    def _merge(self, state: AgentState) -> Dict[str, Any]:
        """Fold vector and SQL hits into one list, de-duplicated by name.

        Output order is vector hits first, then SQL-only hits. When a name has
        both kinds of hit, non-empty SQL values take precedence over vector
        values, and the reason of the first hit seen is kept.

        Returns:
            ``{"sql_results": merged, "vector_results": []}``.
        """
        merged: Dict[str, QueryResponse] = {}
        # Names whose merged record already carries SQL values.
        sql_backed = set()

        tagged = [(hit, False) for hit in state.vector_results]
        tagged += [(hit, True) for hit in state.sql_results]

        for item, is_sql in tagged:
            existing = merged.get(item.name)
            if existing is None:
                merged[item.name] = item
            else:
                # The first SQL hit for a name outranks the vector values
                # collected so far; otherwise the earlier record wins.
                sql_takes_over = is_sql and item.name not in sql_backed
                if sql_takes_over:
                    primary, secondary = item, existing
                else:
                    primary, secondary = existing, item
                merged[item.name] = QueryResponse(
                    name=existing.name,
                    aum=primary.aum or secondary.aum,
                    check_size=primary.check_size or secondary.check_size,
                    sector_focus=primary.sector_focus or secondary.sector_focus,
                    stage_focus=primary.stage_focus or secondary.stage_focus,
                    region=primary.region or secondary.region,
                    investment_thesis=primary.investment_thesis
                    or secondary.investment_thesis,
                    investor_description=primary.investor_description
                    or secondary.investor_description,
                    reason=existing.reason or item.reason,
                )
            if is_sql:
                sql_backed.add(item.name)

        # Store back into sql_results to pass through the END with full state
        return {
            "sql_results": list(merged.values()),
            "vector_results": [],
        }

    # Public API
    def run(self, question: str) -> QueryResponseList:
        """Run the graph for ``question`` and return the merged hits.

        Args:
            question: Free-text query forwarded to both search nodes.

        Returns:
            A ``QueryResponseList`` wrapping the merged results.
        """
        state = AgentState(question=question)
        final_state = self.app.invoke(state)
        # A compiled LangGraph graph returns the final state as a plain
        # mapping even when the schema is a Pydantic model; accept both forms.
        if isinstance(final_state, dict):
            responses = final_state.get("sql_results", [])
        else:
            responses = final_state.sql_results
        return QueryResponseList(responses=responses)
|