Refactor imports and enhance query functionality with LangGraph integration; update requirements for new dependencies
This commit is contained in:
Binary file not shown.
Binary file not shown.
+1
-1
@@ -2,7 +2,7 @@ import datetime
|
|||||||
|
|
||||||
from sqlalchemy import Column, DateTime, Integer, String
|
from sqlalchemy import Column, DateTime, Integer, String
|
||||||
|
|
||||||
from db.db import Base
|
from app.db.db import Base
|
||||||
|
|
||||||
|
|
||||||
class InvestorTable(Base):
|
class InvestorTable(Base):
|
||||||
|
|||||||
+9
-8
@@ -1,11 +1,12 @@
|
|||||||
import io
|
import io
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from db.db import db_dependency, init_database
|
from app.db.db import db_dependency, init_database
|
||||||
from fastapi import FastAPI, File, UploadFile
|
from fastapi import FastAPI, File, UploadFile
|
||||||
from services.openrouter import InvestorProcessor
|
from app.services.openrouter import InvestorProcessor
|
||||||
|
|
||||||
from app.services.querying import QueryProcessor
|
from app.pydantic_schemas import QueryRequest, QueryResponseList
|
||||||
|
from app.services.langgraph_agent import LangGraphQueryAgent
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
@@ -31,11 +32,11 @@ async def parse_csv(db: db_dependency, file: UploadFile = File(...)):
|
|||||||
return {"results": [r.dict() for r in results]}
|
return {"results": [r.dict() for r in results]}
|
||||||
|
|
||||||
|
|
||||||
@app.post("/query")
|
@app.post("/query", response_model=QueryResponseList)
|
||||||
async def query_investors(db: db_dependency, question: str):
|
async def query_investors(db: db_dependency, request: QueryRequest):
|
||||||
processor = QueryProcessor(sql_session=db)
|
agent = LangGraphQueryAgent(sql_session=db)
|
||||||
results = processor.process_query(question)
|
result = agent.run(request.question)
|
||||||
return {"results": [r.dict() for r in results]}
|
return result
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Binary file not shown.
@@ -0,0 +1,162 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import chromadb
|
||||||
|
from langgraph.graph import END, START, StateGraph
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from sqlalchemy import func
|
||||||
|
|
||||||
|
from app.db.tables import InvestorTable
|
||||||
|
from app.pydantic_schemas import QueryResponse, QueryResponseList
|
||||||
|
|
||||||
|
|
||||||
|
class AgentState(BaseModel):
|
||||||
|
question: str
|
||||||
|
sql_results: List[QueryResponse] = Field(default_factory=list)
|
||||||
|
vector_results: List[QueryResponse] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class LangGraphQueryAgent:
|
||||||
|
"""Simple LangGraph agent that queries both SQL and Chroma and merges results."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
sql_session: Optional[object] = None,
|
||||||
|
vector_db_client: Optional[object] = None,
|
||||||
|
) -> None:
|
||||||
|
self.sql_session = sql_session
|
||||||
|
|
||||||
|
# Setup Chroma collection
|
||||||
|
self.vector_db_client = vector_db_client or chromadb.PersistentClient(
|
||||||
|
path="./chroma_db"
|
||||||
|
)
|
||||||
|
self.collection = self.vector_db_client.get_or_create_collection(
|
||||||
|
name="investor_descriptions",
|
||||||
|
metadata={
|
||||||
|
"description": "Investor descriptions and investment thesis focus",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build graph
|
||||||
|
graph = StateGraph(AgentState)
|
||||||
|
graph.add_node("sql_search", self._sql_search)
|
||||||
|
graph.add_node("vector_search", self._vector_search)
|
||||||
|
graph.add_node("merge", self._merge)
|
||||||
|
|
||||||
|
# Parallel fan-out: START -> sql_search & vector_search -> merge -> END
|
||||||
|
graph.add_edge(START, "sql_search")
|
||||||
|
graph.add_edge(START, "vector_search")
|
||||||
|
graph.add_edge("sql_search", "merge")
|
||||||
|
graph.add_edge("vector_search", "merge")
|
||||||
|
graph.add_edge("merge", END)
|
||||||
|
|
||||||
|
self.app = graph.compile()
|
||||||
|
|
||||||
|
# Nodes
|
||||||
|
def _sql_search(self, state: AgentState) -> Dict[str, Any]:
|
||||||
|
results: List[QueryResponse] = []
|
||||||
|
if self.sql_session is None:
|
||||||
|
return {"sql_results": results}
|
||||||
|
|
||||||
|
# Simple LIKE-based search across a few fields
|
||||||
|
# Note: SQLite uses case-insensitive LIKE by default for ASCII.
|
||||||
|
q = (
|
||||||
|
self.sql_session.query(InvestorTable)
|
||||||
|
.filter(
|
||||||
|
(func.lower(InvestorTable.name).like(f"%{state.question.lower()}%"))
|
||||||
|
| (
|
||||||
|
func.lower(InvestorTable.sector_focus).like(
|
||||||
|
f"%{state.question.lower()}%"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
| (
|
||||||
|
func.lower(InvestorTable.stage_focus).like(
|
||||||
|
f"%{state.question.lower()}%"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
| (func.lower(InvestorTable.region).like(f"%{state.question.lower()}%"))
|
||||||
|
)
|
||||||
|
.limit(10)
|
||||||
|
)
|
||||||
|
|
||||||
|
for row in q.all():
|
||||||
|
results.append(
|
||||||
|
QueryResponse(
|
||||||
|
name=row.name,
|
||||||
|
aum=row.aum,
|
||||||
|
check_size=row.check_size,
|
||||||
|
sector_focus=row.sector_focus,
|
||||||
|
stage_focus=row.stage_focus,
|
||||||
|
region=row.region,
|
||||||
|
investment_thesis="",
|
||||||
|
investor_description="",
|
||||||
|
reason="Matched SQL fields via LIKE",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"sql_results": results}
|
||||||
|
|
||||||
|
def _vector_search(self, state: AgentState) -> Dict[str, Any]:
|
||||||
|
results: List[QueryResponse] = []
|
||||||
|
try:
|
||||||
|
q = self.collection.query(query_texts=[state.question], n_results=10)
|
||||||
|
# q has keys: ids, distances, documents, metadatas
|
||||||
|
docs = q.get("documents") or []
|
||||||
|
metas = q.get("metadatas") or []
|
||||||
|
if docs and metas:
|
||||||
|
for i, md in enumerate(metas[0]):
|
||||||
|
name = md.get("name", "Unknown")
|
||||||
|
results.append(
|
||||||
|
QueryResponse(
|
||||||
|
name=name,
|
||||||
|
aum=0,
|
||||||
|
check_size="",
|
||||||
|
sector_focus="",
|
||||||
|
stage_focus="",
|
||||||
|
region=md.get("headquarters", ""),
|
||||||
|
investment_thesis="",
|
||||||
|
investor_description=(docs[0][i] if docs[0] else ""),
|
||||||
|
reason="Vector similarity in Chroma",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
# Best-effort; leave vector results empty on failure
|
||||||
|
pass
|
||||||
|
|
||||||
|
return {"vector_results": results}
|
||||||
|
|
||||||
|
def _merge(self, state: AgentState) -> Dict[str, Any]:
|
||||||
|
# Deduplicate by name, prefer SQL fields where available, keep first reason
|
||||||
|
merged: Dict[str, QueryResponse] = {}
|
||||||
|
|
||||||
|
for item in state.vector_results + state.sql_results:
|
||||||
|
if item.name not in merged:
|
||||||
|
merged[item.name] = item
|
||||||
|
else:
|
||||||
|
existing = merged[item.name]
|
||||||
|
merged[item.name] = QueryResponse(
|
||||||
|
name=existing.name,
|
||||||
|
aum=existing.aum or item.aum,
|
||||||
|
check_size=existing.check_size or item.check_size,
|
||||||
|
sector_focus=existing.sector_focus or item.sector_focus,
|
||||||
|
stage_focus=existing.stage_focus or item.stage_focus,
|
||||||
|
region=existing.region or item.region,
|
||||||
|
investment_thesis=existing.investment_thesis
|
||||||
|
or item.investment_thesis,
|
||||||
|
investor_description=existing.investor_description
|
||||||
|
or item.investor_description,
|
||||||
|
reason=existing.reason or item.reason,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store back into sql_results to pass through the END with full state
|
||||||
|
return {
|
||||||
|
"sql_results": list(merged.values()),
|
||||||
|
"vector_results": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Public API
|
||||||
|
def run(self, question: str) -> QueryResponseList:
|
||||||
|
state = AgentState(question=question)
|
||||||
|
final_state: AgentState = self.app.invoke(state)
|
||||||
|
return QueryResponseList(responses=final_state.sql_results)
|
||||||
|
|||||||
@@ -3,12 +3,13 @@ from typing import List, Optional
|
|||||||
|
|
||||||
import chromadb
|
import chromadb
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from db.tables import InvestorTable
|
|
||||||
from langchain_core.prompts import PromptTemplate
|
from langchain_core.prompts import PromptTemplate
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
from pydantic_schemas import Investor, InvestorList
|
from pydantic_schemas import Investor, InvestorList
|
||||||
from settings import settings
|
from settings import settings
|
||||||
|
|
||||||
|
from app.db.tables import InvestorTable
|
||||||
|
|
||||||
# Add these imports for your databases
|
# Add these imports for your databases
|
||||||
# from sqlalchemy.ext.asyncio import AsyncSession
|
# from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
# from your_vector_db import VectorDBClient
|
# from your_vector_db import VectorDBClient
|
||||||
|
|||||||
@@ -8,9 +8,13 @@ chromadb>=0.4.0
|
|||||||
|
|
||||||
# LLM integration
|
# LLM integration
|
||||||
openai>=1.0.0
|
openai>=1.0.0
|
||||||
|
langchain>=0.2.0
|
||||||
|
langchain-openai>=0.1.0
|
||||||
|
langgraph>=0.2.0
|
||||||
|
|
||||||
# Environment management
|
# Environment management
|
||||||
python-dotenv>=1.0.0
|
python-dotenv>=1.0.0
|
||||||
|
pydantic-settings>=2.0.0
|
||||||
|
|
||||||
# Additional dependencies for data processing
|
# Additional dependencies for data processing
|
||||||
typing-extensions>=4.0.0
|
typing-extensions>=4.0.0
|
||||||
|
|||||||
Reference in New Issue
Block a user