Add CompanyTable model and refactor query handling; update requirements for new dependencies
This commit is contained in:
@@ -1,162 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import chromadb
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import func
|
||||
|
||||
from app.db.tables import InvestorTable
|
||||
from app.pydantic_schemas import QueryResponse, QueryResponseList
|
||||
|
||||
|
||||
class AgentState(BaseModel):
|
||||
question: str
|
||||
sql_results: List[QueryResponse] = Field(default_factory=list)
|
||||
vector_results: List[QueryResponse] = Field(default_factory=list)
|
||||
|
||||
|
||||
class LangGraphQueryAgent:
|
||||
"""Simple LangGraph agent that queries both SQL and Chroma and merges results."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sql_session: Optional[object] = None,
|
||||
vector_db_client: Optional[object] = None,
|
||||
) -> None:
|
||||
self.sql_session = sql_session
|
||||
|
||||
# Setup Chroma collection
|
||||
self.vector_db_client = vector_db_client or chromadb.PersistentClient(
|
||||
path="./chroma_db"
|
||||
)
|
||||
self.collection = self.vector_db_client.get_or_create_collection(
|
||||
name="investor_descriptions",
|
||||
metadata={
|
||||
"description": "Investor descriptions and investment thesis focus",
|
||||
},
|
||||
)
|
||||
|
||||
# Build graph
|
||||
graph = StateGraph(AgentState)
|
||||
graph.add_node("sql_search", self._sql_search)
|
||||
graph.add_node("vector_search", self._vector_search)
|
||||
graph.add_node("merge", self._merge)
|
||||
|
||||
# Parallel fan-out: START -> sql_search & vector_search -> merge -> END
|
||||
graph.add_edge(START, "sql_search")
|
||||
graph.add_edge(START, "vector_search")
|
||||
graph.add_edge("sql_search", "merge")
|
||||
graph.add_edge("vector_search", "merge")
|
||||
graph.add_edge("merge", END)
|
||||
|
||||
self.app = graph.compile()
|
||||
|
||||
# Nodes
|
||||
def _sql_search(self, state: AgentState) -> Dict[str, Any]:
|
||||
results: List[QueryResponse] = []
|
||||
if self.sql_session is None:
|
||||
return {"sql_results": results}
|
||||
|
||||
# Simple LIKE-based search across a few fields
|
||||
# Note: SQLite uses case-insensitive LIKE by default for ASCII.
|
||||
q = (
|
||||
self.sql_session.query(InvestorTable)
|
||||
.filter(
|
||||
(func.lower(InvestorTable.name).like(f"%{state.question.lower()}%"))
|
||||
| (
|
||||
func.lower(InvestorTable.sector_focus).like(
|
||||
f"%{state.question.lower()}%"
|
||||
)
|
||||
)
|
||||
| (
|
||||
func.lower(InvestorTable.stage_focus).like(
|
||||
f"%{state.question.lower()}%"
|
||||
)
|
||||
)
|
||||
| (func.lower(InvestorTable.region).like(f"%{state.question.lower()}%"))
|
||||
)
|
||||
.limit(10)
|
||||
)
|
||||
|
||||
for row in q.all():
|
||||
results.append(
|
||||
QueryResponse(
|
||||
name=row.name,
|
||||
aum=row.aum,
|
||||
check_size=row.check_size,
|
||||
sector_focus=row.sector_focus,
|
||||
stage_focus=row.stage_focus,
|
||||
region=row.region,
|
||||
investment_thesis="",
|
||||
investor_description="",
|
||||
reason="Matched SQL fields via LIKE",
|
||||
)
|
||||
)
|
||||
|
||||
return {"sql_results": results}
|
||||
|
||||
def _vector_search(self, state: AgentState) -> Dict[str, Any]:
|
||||
results: List[QueryResponse] = []
|
||||
try:
|
||||
q = self.collection.query(query_texts=[state.question], n_results=10)
|
||||
# q has keys: ids, distances, documents, metadatas
|
||||
docs = q.get("documents") or []
|
||||
metas = q.get("metadatas") or []
|
||||
if docs and metas:
|
||||
for i, md in enumerate(metas[0]):
|
||||
name = md.get("name", "Unknown")
|
||||
results.append(
|
||||
QueryResponse(
|
||||
name=name,
|
||||
aum=0,
|
||||
check_size="",
|
||||
sector_focus="",
|
||||
stage_focus="",
|
||||
region=md.get("headquarters", ""),
|
||||
investment_thesis="",
|
||||
investor_description=(docs[0][i] if docs[0] else ""),
|
||||
reason="Vector similarity in Chroma",
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
# Best-effort; leave vector results empty on failure
|
||||
pass
|
||||
|
||||
return {"vector_results": results}
|
||||
|
||||
def _merge(self, state: AgentState) -> Dict[str, Any]:
|
||||
# Deduplicate by name, prefer SQL fields where available, keep first reason
|
||||
merged: Dict[str, QueryResponse] = {}
|
||||
|
||||
for item in state.vector_results + state.sql_results:
|
||||
if item.name not in merged:
|
||||
merged[item.name] = item
|
||||
else:
|
||||
existing = merged[item.name]
|
||||
merged[item.name] = QueryResponse(
|
||||
name=existing.name,
|
||||
aum=existing.aum or item.aum,
|
||||
check_size=existing.check_size or item.check_size,
|
||||
sector_focus=existing.sector_focus or item.sector_focus,
|
||||
stage_focus=existing.stage_focus or item.stage_focus,
|
||||
region=existing.region or item.region,
|
||||
investment_thesis=existing.investment_thesis
|
||||
or item.investment_thesis,
|
||||
investor_description=existing.investor_description
|
||||
or item.investor_description,
|
||||
reason=existing.reason or item.reason,
|
||||
)
|
||||
|
||||
# Store back into sql_results to pass through the END with full state
|
||||
return {
|
||||
"sql_results": list(merged.values()),
|
||||
"vector_results": [],
|
||||
}
|
||||
|
||||
# Public API
|
||||
def run(self, question: str) -> QueryResponseList:
|
||||
state = AgentState(question=question)
|
||||
final_state: AgentState = self.app.invoke(state)
|
||||
return QueryResponseList(responses=final_state.sql_results)
|
||||
|
||||
Reference in New Issue
Block a user