diff --git a/app/__pycache__/main.cpython-312.pyc b/app/__pycache__/main.cpython-312.pyc index d9c2a4c..9cbd458 100644 Binary files a/app/__pycache__/main.cpython-312.pyc and b/app/__pycache__/main.cpython-312.pyc differ diff --git a/app/db/__pycache__/tables.cpython-312.pyc b/app/db/__pycache__/tables.cpython-312.pyc index 2086f0d..8e6b31c 100644 Binary files a/app/db/__pycache__/tables.cpython-312.pyc and b/app/db/__pycache__/tables.cpython-312.pyc differ diff --git a/app/db/tables.py b/app/db/tables.py index c66d162..3153b82 100644 --- a/app/db/tables.py +++ b/app/db/tables.py @@ -2,7 +2,7 @@ import datetime from sqlalchemy import Column, DateTime, Integer, String -from db.db import Base +from app.db.db import Base class InvestorTable(Base): diff --git a/app/main.py b/app/main.py index a09042e..400730f 100644 --- a/app/main.py +++ b/app/main.py @@ -1,11 +1,12 @@ import io import pandas as pd -from db.db import db_dependency, init_database +from app.db.db import db_dependency, init_database from fastapi import FastAPI, File, UploadFile -from services.openrouter import InvestorProcessor +from app.services.openrouter import InvestorProcessor -from app.services.querying import QueryProcessor +from app.pydantic_schemas import QueryRequest, QueryResponseList +from app.services.langgraph_agent import LangGraphQueryAgent app = FastAPI() @@ -31,11 +32,11 @@ async def parse_csv(db: db_dependency, file: UploadFile = File(...)): return {"results": [r.dict() for r in results]} -@app.post("/query") -async def query_investors(db: db_dependency, question: str): - processor = QueryProcessor(sql_session=db) - results = processor.process_query(question) - return {"results": [r.dict() for r in results]} +@app.post("/query", response_model=QueryResponseList) +async def query_investors(db: db_dependency, request: QueryRequest): + agent = LangGraphQueryAgent(sql_session=db) + result = agent.run(request.question) + return result if __name__ == "__main__": diff --git a/app/pydantic_schemas.py b/app/pydantic_schemas.py index 08588b5..af54a14 100644 --- a/app/pydantic_schemas.py +++ b/app/pydantic_schemas.py @@ -35,4 +35,4 @@ class QueryRequest(BaseModel): class QueryResponseList(BaseModel): - responses: List[QueryResponse] \ No newline at end of file + responses: List[QueryResponse] diff --git a/app/services/__pycache__/langgraph_agent.cpython-312.pyc b/app/services/__pycache__/langgraph_agent.cpython-312.pyc index fe524c0..06c5a6d 100644 Binary files a/app/services/__pycache__/langgraph_agent.cpython-312.pyc and b/app/services/__pycache__/langgraph_agent.cpython-312.pyc differ diff --git a/app/services/langgraph_agent.py b/app/services/langgraph_agent.py index e69de29..9bf1866 100644 --- a/app/services/langgraph_agent.py +++ b/app/services/langgraph_agent.py @@ -0,0 +1,162 @@ +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +import chromadb +from langgraph.graph import END, START, StateGraph +from pydantic import BaseModel, Field +from sqlalchemy import func + +from app.db.tables import InvestorTable +from app.pydantic_schemas import QueryResponse, QueryResponseList + + +class AgentState(BaseModel): + question: str + sql_results: List[QueryResponse] = Field(default_factory=list) + vector_results: List[QueryResponse] = Field(default_factory=list) + + +class LangGraphQueryAgent: + """Simple LangGraph agent that queries both SQL and Chroma and merges results.""" + + def __init__( + self, + sql_session: Optional[object] = None, + vector_db_client: Optional[object] = None, + ) -> None: + self.sql_session = sql_session + + # Setup Chroma collection + self.vector_db_client = vector_db_client or chromadb.PersistentClient( + path="./chroma_db" + ) + self.collection = self.vector_db_client.get_or_create_collection( + name="investor_descriptions", + metadata={ + "description": "Investor descriptions and investment thesis focus", + }, + ) + + # Build graph + graph = StateGraph(AgentState) + graph.add_node("sql_search", self._sql_search) + graph.add_node("vector_search", self._vector_search) + graph.add_node("merge", self._merge) + + # Parallel fan-out: START -> sql_search & vector_search -> merge -> END + graph.add_edge(START, "sql_search") + graph.add_edge(START, "vector_search") + graph.add_edge("sql_search", "merge") + graph.add_edge("vector_search", "merge") + graph.add_edge("merge", END) + + self.app = graph.compile() + + # Nodes + def _sql_search(self, state: AgentState) -> Dict[str, Any]: + results: List[QueryResponse] = [] + if self.sql_session is None: + return {"sql_results": results} + + # Simple LIKE-based search across a few fields + # Note: SQLite uses case-insensitive LIKE by default for ASCII. + q = ( + self.sql_session.query(InvestorTable) + .filter( + (func.lower(InvestorTable.name).like(f"%{state.question.lower()}%")) + | ( + func.lower(InvestorTable.sector_focus).like( + f"%{state.question.lower()}%" + ) + ) + | ( + func.lower(InvestorTable.stage_focus).like( + f"%{state.question.lower()}%" + ) + ) + | (func.lower(InvestorTable.region).like(f"%{state.question.lower()}%")) + ) + .limit(10) + ) + + for row in q.all(): + results.append( + QueryResponse( + name=row.name, + aum=row.aum, + check_size=row.check_size, + sector_focus=row.sector_focus, + stage_focus=row.stage_focus, + region=row.region, + investment_thesis="", + investor_description="", + reason="Matched SQL fields via LIKE", + ) + ) + + return {"sql_results": results} + + def _vector_search(self, state: AgentState) -> Dict[str, Any]: + results: List[QueryResponse] = [] + try: + q = self.collection.query(query_texts=[state.question], n_results=10) + # q has keys: ids, distances, documents, metadatas + docs = q.get("documents") or [] + metas = q.get("metadatas") or [] + if docs and metas: + for i, md in enumerate(metas[0]): + name = md.get("name", "Unknown") + results.append( + QueryResponse( + name=name, + aum=0, + check_size="", + sector_focus="", + stage_focus="", + region=md.get("headquarters", ""), + investment_thesis="", + investor_description=(docs[0][i] if docs[0] else ""), + reason="Vector similarity in Chroma", + ) + ) + except Exception: + # Best-effort; leave vector results empty on failure + pass + + return {"vector_results": results} + + def _merge(self, state: AgentState) -> Dict[str, Any]: + # Deduplicate by name, prefer SQL fields where available, keep first reason + merged: Dict[str, QueryResponse] = {} + + for item in state.vector_results + state.sql_results: + if item.name not in merged: + merged[item.name] = item + else: + existing = merged[item.name] + merged[item.name] = QueryResponse( + name=existing.name, + aum=existing.aum or item.aum, + check_size=existing.check_size or item.check_size, + sector_focus=existing.sector_focus or item.sector_focus, + stage_focus=existing.stage_focus or item.stage_focus, + region=existing.region or item.region, + investment_thesis=existing.investment_thesis + or item.investment_thesis, + investor_description=existing.investor_description + or item.investor_description, + reason=existing.reason or item.reason, + ) + + # Store back into sql_results to pass through the END with full state + return { + "sql_results": list(merged.values()), + "vector_results": [], + } + + # Public API + def run(self, question: str) -> QueryResponseList: + state = AgentState(question=question) + final_state: AgentState = self.app.invoke(state) + return QueryResponseList(responses=final_state.sql_results) diff --git a/app/services/openrouter.py b/app/services/openrouter.py index 6fa8a01..6a36a61 100644 --- a/app/services/openrouter.py +++ b/app/services/openrouter.py @@ -3,12 +3,13 @@ from typing import List, Optional import chromadb import pandas as pd -from db.tables import InvestorTable from langchain_core.prompts import PromptTemplate from langchain_openai import ChatOpenAI from pydantic_schemas import Investor, InvestorList from settings import settings +from app.db.tables import InvestorTable + # Add these imports for your databases # from sqlalchemy.ext.asyncio import AsyncSession # from your_vector_db import VectorDBClient diff --git a/requirements.txt b/requirements.txt index 10ba213..6a7dbd3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,9 +8,13 @@ chromadb>=0.4.0 # LLM integration openai>=1.0.0 +langchain>=0.2.0 +langchain-openai>=0.1.0 +langgraph>=0.2.0 # Environment management python-dotenv>=1.0.0 +pydantic-settings>=2.0.0 # Additional dependencies for data processing typing-extensions>=4.0.0