Refactor imports and enhance query functionality with LangGraph integration; update requirements for new dependencies

This commit is contained in:
bolade
2025-08-30 13:56:19 +01:00
parent ba0ed169ce
commit 74931f235e
9 changed files with 179 additions and 11 deletions
Binary file not shown.
Binary file not shown.
+1 -1
View File
@@ -2,7 +2,7 @@ import datetime
from sqlalchemy import Column, DateTime, Integer, String from sqlalchemy import Column, DateTime, Integer, String
from db.db import Base from app.db.db import Base
class InvestorTable(Base): class InvestorTable(Base):
+9 -8
View File
@@ -1,11 +1,12 @@
import io import io
import pandas as pd import pandas as pd
from db.db import db_dependency, init_database from app.db.db import db_dependency, init_database
from fastapi import FastAPI, File, UploadFile from fastapi import FastAPI, File, UploadFile
from services.openrouter import InvestorProcessor from app.services.openrouter import InvestorProcessor
from app.services.querying import QueryProcessor from app.pydantic_schemas import QueryRequest, QueryResponseList
from app.services.langgraph_agent import LangGraphQueryAgent
app = FastAPI() app = FastAPI()
@@ -31,11 +32,11 @@ async def parse_csv(db: db_dependency, file: UploadFile = File(...)):
return {"results": [r.dict() for r in results]} return {"results": [r.dict() for r in results]}
@app.post("/query") @app.post("/query", response_model=QueryResponseList)
async def query_investors(db: db_dependency, question: str): async def query_investors(db: db_dependency, request: QueryRequest):
processor = QueryProcessor(sql_session=db) agent = LangGraphQueryAgent(sql_session=db)
results = processor.process_query(question) result = agent.run(request.question)
return {"results": [r.dict() for r in results]} return result
if __name__ == "__main__": if __name__ == "__main__":
+162
View File
@@ -0,0 +1,162 @@
from __future__ import annotations
from typing import Any, Dict, List, Optional
import chromadb
from langgraph.graph import END, START, StateGraph
from pydantic import BaseModel, Field
from sqlalchemy import func
from app.db.tables import InvestorTable
from app.pydantic_schemas import QueryResponse, QueryResponseList
class AgentState(BaseModel):
question: str
sql_results: List[QueryResponse] = Field(default_factory=list)
vector_results: List[QueryResponse] = Field(default_factory=list)
class LangGraphQueryAgent:
"""Simple LangGraph agent that queries both SQL and Chroma and merges results."""
def __init__(
self,
sql_session: Optional[object] = None,
vector_db_client: Optional[object] = None,
) -> None:
self.sql_session = sql_session
# Setup Chroma collection
self.vector_db_client = vector_db_client or chromadb.PersistentClient(
path="./chroma_db"
)
self.collection = self.vector_db_client.get_or_create_collection(
name="investor_descriptions",
metadata={
"description": "Investor descriptions and investment thesis focus",
},
)
# Build graph
graph = StateGraph(AgentState)
graph.add_node("sql_search", self._sql_search)
graph.add_node("vector_search", self._vector_search)
graph.add_node("merge", self._merge)
# Parallel fan-out: START -> sql_search & vector_search -> merge -> END
graph.add_edge(START, "sql_search")
graph.add_edge(START, "vector_search")
graph.add_edge("sql_search", "merge")
graph.add_edge("vector_search", "merge")
graph.add_edge("merge", END)
self.app = graph.compile()
# Nodes
def _sql_search(self, state: AgentState) -> Dict[str, Any]:
results: List[QueryResponse] = []
if self.sql_session is None:
return {"sql_results": results}
# Simple LIKE-based search across a few fields
# Note: SQLite uses case-insensitive LIKE by default for ASCII.
q = (
self.sql_session.query(InvestorTable)
.filter(
(func.lower(InvestorTable.name).like(f"%{state.question.lower()}%"))
| (
func.lower(InvestorTable.sector_focus).like(
f"%{state.question.lower()}%"
)
)
| (
func.lower(InvestorTable.stage_focus).like(
f"%{state.question.lower()}%"
)
)
| (func.lower(InvestorTable.region).like(f"%{state.question.lower()}%"))
)
.limit(10)
)
for row in q.all():
results.append(
QueryResponse(
name=row.name,
aum=row.aum,
check_size=row.check_size,
sector_focus=row.sector_focus,
stage_focus=row.stage_focus,
region=row.region,
investment_thesis="",
investor_description="",
reason="Matched SQL fields via LIKE",
)
)
return {"sql_results": results}
def _vector_search(self, state: AgentState) -> Dict[str, Any]:
results: List[QueryResponse] = []
try:
q = self.collection.query(query_texts=[state.question], n_results=10)
# q has keys: ids, distances, documents, metadatas
docs = q.get("documents") or []
metas = q.get("metadatas") or []
if docs and metas:
for i, md in enumerate(metas[0]):
name = md.get("name", "Unknown")
results.append(
QueryResponse(
name=name,
aum=0,
check_size="",
sector_focus="",
stage_focus="",
region=md.get("headquarters", ""),
investment_thesis="",
investor_description=(docs[0][i] if docs[0] else ""),
reason="Vector similarity in Chroma",
)
)
except Exception:
# Best-effort; leave vector results empty on failure
pass
return {"vector_results": results}
def _merge(self, state: AgentState) -> Dict[str, Any]:
# Deduplicate by name, prefer SQL fields where available, keep first reason
merged: Dict[str, QueryResponse] = {}
for item in state.vector_results + state.sql_results:
if item.name not in merged:
merged[item.name] = item
else:
existing = merged[item.name]
merged[item.name] = QueryResponse(
name=existing.name,
aum=existing.aum or item.aum,
check_size=existing.check_size or item.check_size,
sector_focus=existing.sector_focus or item.sector_focus,
stage_focus=existing.stage_focus or item.stage_focus,
region=existing.region or item.region,
investment_thesis=existing.investment_thesis
or item.investment_thesis,
investor_description=existing.investor_description
or item.investor_description,
reason=existing.reason or item.reason,
)
# Store back into sql_results to pass through the END with full state
return {
"sql_results": list(merged.values()),
"vector_results": [],
}
# Public API
def run(self, question: str) -> QueryResponseList:
state = AgentState(question=question)
final_state: AgentState = self.app.invoke(state)
return QueryResponseList(responses=final_state.sql_results)
+2 -1
View File
@@ -3,12 +3,13 @@ from typing import List, Optional
import chromadb import chromadb
import pandas as pd import pandas as pd
from db.tables import InvestorTable
from langchain_core.prompts import PromptTemplate from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
from pydantic_schemas import Investor, InvestorList from pydantic_schemas import Investor, InvestorList
from settings import settings from settings import settings
from app.db.tables import InvestorTable
# Add these imports for your databases # Add these imports for your databases
# from sqlalchemy.ext.asyncio import AsyncSession # from sqlalchemy.ext.asyncio import AsyncSession
# from your_vector_db import VectorDBClient # from your_vector_db import VectorDBClient
+4
View File
@@ -8,9 +8,13 @@ chromadb>=0.4.0
# LLM integration # LLM integration
openai>=1.0.0 openai>=1.0.0
langchain>=0.2.0
langchain-openai>=0.1.0
langgraph>=0.2.0
# Environment management # Environment management
python-dotenv>=1.0.0 python-dotenv>=1.0.0
pydantic-settings>=2.0.0
# Additional dependencies for data processing # Additional dependencies for data processing
typing-extensions>=4.0.0 typing-extensions>=4.0.0