diff --git a/.gitignore b/.gitignore index 698ec62..2ae4ed8 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,4 @@ /*.db +__pycache__/ \ No newline at end of file diff --git a/app/__pycache__/main.cpython-312.pyc b/app/__pycache__/main.cpython-312.pyc index 9cbd458..6d7f7be 100644 Binary files a/app/__pycache__/main.cpython-312.pyc and b/app/__pycache__/main.cpython-312.pyc differ diff --git a/app/__pycache__/pydantic_schemas.cpython-312.pyc b/app/__pycache__/pydantic_schemas.cpython-312.pyc index 49fa62c..761fc87 100644 Binary files a/app/__pycache__/pydantic_schemas.cpython-312.pyc and b/app/__pycache__/pydantic_schemas.cpython-312.pyc differ diff --git a/app/api/__init__.py b/app/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/api/__pycache__/__init__.cpython-312.pyc b/app/api/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000..707593a Binary files /dev/null and b/app/api/__pycache__/__init__.cpython-312.pyc differ diff --git a/app/api/__pycache__/investors.cpython-312.pyc b/app/api/__pycache__/investors.cpython-312.pyc new file mode 100644 index 0000000..837640b Binary files /dev/null and b/app/api/__pycache__/investors.cpython-312.pyc differ diff --git a/app/api/companies.py b/app/api/companies.py new file mode 100644 index 0000000..d718c44 --- /dev/null +++ b/app/api/companies.py @@ -0,0 +1,9 @@ +from fastapi.routing import APIRouter + +router = APIRouter() + +@router.get("/companies") +def read_companies(): + """Return a static message payload for the companies endpoint.""" + return {"message": "list of companies"} + diff --git a/app/api/investors.py b/app/api/investors.py new file mode 100644 index 0000000..42d5b39 --- /dev/null +++ b/app/api/investors.py @@ -0,0 +1,9 @@ +from fastapi.routing import APIRouter + +router = APIRouter() + +@router.get("/investors") +def read_investors(): + """Return a static message payload for the investors endpoint.""" + return {"message": "list of investors"} + diff --git a/app/db/__pycache__/tables.cpython-312.pyc b/app/db/__pycache__/tables.cpython-312.pyc index 8e6b31c..18d72cf 100644 Binary files 
a/app/db/__pycache__/tables.cpython-312.pyc and b/app/db/__pycache__/tables.cpython-312.pyc differ diff --git a/app/db/models.py b/app/db/models.py new file mode 100644 index 0000000..27c09ac --- /dev/null +++ b/app/db/models.py @@ -0,0 +1,115 @@ +import datetime +import enum + +from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, Table, Text +from sqlalchemy.orm import relationship +from sqlalchemy.types import Enum + +from db.db import Base + + +class InvestmentStage(enum.Enum): + SEED = "seed" + SERIES_A = "series_a" + SERIES_B = "series_b" + SERIES_C = "series_c" + GROWTH = "growth" + LATE_STAGE = "late_stage" + + +# Association table for many-to-many relationship between investors and companies +investor_company_association = Table( + "investor_companies", + Base.metadata, + Column("investor_id", Integer, ForeignKey("investors.id")), + Column("company_id", Integer, ForeignKey("companies.id")), +) + + +# Association table for investor-sector many-to-many +investor_sector_association = Table( + "investor_sectors", + Base.metadata, + Column("investor_id", Integer, ForeignKey("investors.id")), + Column("sector_id", Integer, ForeignKey("sectors.id")), +) +class InvestorTable(Base): + __tablename__ = "investors" + + id = Column(Integer, primary_key=True, index=True) + name = Column(String, nullable=False) + description = Column(Text, nullable=True) + aum = Column(Integer, nullable=False) # Assets Under Management + check_size_lower = Column(Integer, nullable=False) # Lower bound + check_size_upper = Column(Integer, nullable=False) # Upper bound + geography = Column(String, nullable=False) + stage_focus = Column(Enum(InvestmentStage), nullable=False) + number_of_investments = Column(Integer, default=0) + # Defaults are callables so the timestamp is taken per row, not once at import + created_at = Column(DateTime, default=lambda: datetime.datetime.now(datetime.UTC)) + updated_at = Column( + DateTime, + default=lambda: datetime.datetime.now(datetime.UTC), + onupdate=lambda: datetime.datetime.now(datetime.UTC), + ) + + # Relationship to portfolio companies 
+ portfolio_companies = relationship( + "CompanyTable", + secondary=investor_company_association, + back_populates="investors", + ) + + # Relationship to team members; pairs with InvestorTeamMember.investor + team_members = relationship( + "InvestorTeamMember", + back_populates="investor", + ) + + +class CompanyTable(Base): + __tablename__ = "companies" + + id = Column(Integer, primary_key=True, index=True) + name = Column(String, nullable=False) + industry = Column(String, nullable=False) + location = Column(String, nullable=False) + founded_year = Column(Integer, nullable=True) + website = Column(String, nullable=True) + # Defaults are callables so the timestamp is taken per row, not once at import + created_at = Column(DateTime, default=lambda: datetime.datetime.now(datetime.UTC)) + updated_at = Column( + DateTime, + default=lambda: datetime.datetime.now(datetime.UTC), + onupdate=lambda: datetime.datetime.now(datetime.UTC), + ) + + # Relationship back to investors + investors = relationship( + "InvestorTable", + secondary=investor_company_association, + back_populates="portfolio_companies", + ) + + +class SectorTable(Base): + __tablename__ = "sectors" + + id = Column(Integer, primary_key=True, index=True) + name = Column(String, unique=True, nullable=False) + + + + +class InvestorTeamMember(Base): + """Team member linked to one investor through investor_id / investor.""" + + __tablename__ = "investor_team_members" + + id = Column(Integer, primary_key=True, index=True) + name = Column(String, nullable=False) + role = Column(String, nullable=False) + email = Column(String, unique=True, nullable=False) + + investor_id = Column(Integer, ForeignKey("investors.id")) + investor = relationship("InvestorTable", back_populates="team_members") diff --git a/app/db/tables.py b/app/db/tables.py deleted file mode 100644 index 3153b82..0000000 --- a/app/db/tables.py +++ /dev/null @@ -1,23 +0,0 @@ -import datetime - -from sqlalchemy import Column, DateTime, Integer, String - -from app.db.db import Base - - -class InvestorTable(Base): - __tablename__ = "investors" - - id = Column(Integer, primary_key=True, index=True) - name = Column(String, nullable=False) - aum = Column(Integer, nullable=False) - check_size = Column(String, nullable=False) - sector_focus = Column(String, nullable=False) - stage_focus = Column(String, nullable=False) - region = Column(String, 
nullable=False) - created_at = Column(DateTime, default=datetime.datetime.now(datetime.UTC)) - updated_at = Column( - DateTime, - default=datetime.datetime.now(datetime.UTC), - onupdate=datetime.datetime.now(datetime.UTC), - ) diff --git a/app/main.py b/app/main.py index 400730f..88c2143 100644 --- a/app/main.py +++ b/app/main.py @@ -1,16 +1,14 @@ import io - +from api import investors import pandas as pd -from app.db.db import db_dependency, init_database +from db.db import db_dependency from fastapi import FastAPI, File, UploadFile -from app.services.openrouter import InvestorProcessor - -from app.pydantic_schemas import QueryRequest, QueryResponseList -from app.services.langgraph_agent import LangGraphQueryAgent +from services.openrouter import InvestorProcessor +from services.querying import QueryProcessor app = FastAPI() - -init_database() +app.include_router(investors.router) +# init_database() @app.get("/") @@ -32,11 +30,11 @@ async def parse_csv(db: db_dependency, file: UploadFile = File(...)): return {"results": [r.dict() for r in results]} -@app.post("/query", response_model=QueryResponseList) -async def query_investors(db: db_dependency, request: QueryRequest): - agent = LangGraphQueryAgent(sql_session=db) - result = agent.run(request.question) - return result +@app.post("/query") +async def query_investors(db: db_dependency, question: str): + processor = QueryProcessor(sql_session=db) + results = processor.process_query(question) + return {"results": results} if __name__ == "__main__": diff --git a/app/services/__pycache__/langgraph_agent.cpython-312.pyc b/app/services/__pycache__/langgraph_agent.cpython-312.pyc index 06c5a6d..5168061 100644 Binary files a/app/services/__pycache__/langgraph_agent.cpython-312.pyc and b/app/services/__pycache__/langgraph_agent.cpython-312.pyc differ diff --git a/app/services/__pycache__/openrouter.cpython-312.pyc b/app/services/__pycache__/openrouter.cpython-312.pyc index 46d57bd..8fcc5d3 100644 Binary files 
a/app/services/__pycache__/openrouter.cpython-312.pyc and b/app/services/__pycache__/openrouter.cpython-312.pyc differ diff --git a/app/services/__pycache__/querying.cpython-312.pyc b/app/services/__pycache__/querying.cpython-312.pyc new file mode 100644 index 0000000..f323760 Binary files /dev/null and b/app/services/__pycache__/querying.cpython-312.pyc differ diff --git a/app/services/langgraph_agent.py b/app/services/langgraph_agent.py index 9bf1866..e69de29 100644 --- a/app/services/langgraph_agent.py +++ b/app/services/langgraph_agent.py @@ -1,162 +0,0 @@ -from __future__ import annotations - -from typing import Any, Dict, List, Optional - -import chromadb -from langgraph.graph import END, START, StateGraph -from pydantic import BaseModel, Field -from sqlalchemy import func - -from app.db.tables import InvestorTable -from app.pydantic_schemas import QueryResponse, QueryResponseList - - -class AgentState(BaseModel): - question: str - sql_results: List[QueryResponse] = Field(default_factory=list) - vector_results: List[QueryResponse] = Field(default_factory=list) - - -class LangGraphQueryAgent: - """Simple LangGraph agent that queries both SQL and Chroma and merges results.""" - - def __init__( - self, - sql_session: Optional[object] = None, - vector_db_client: Optional[object] = None, - ) -> None: - self.sql_session = sql_session - - # Setup Chroma collection - self.vector_db_client = vector_db_client or chromadb.PersistentClient( - path="./chroma_db" - ) - self.collection = self.vector_db_client.get_or_create_collection( - name="investor_descriptions", - metadata={ - "description": "Investor descriptions and investment thesis focus", - }, - ) - - # Build graph - graph = StateGraph(AgentState) - graph.add_node("sql_search", self._sql_search) - graph.add_node("vector_search", self._vector_search) - graph.add_node("merge", self._merge) - - # Parallel fan-out: START -> sql_search & vector_search -> merge -> END - graph.add_edge(START, "sql_search") - 
graph.add_edge(START, "vector_search") - graph.add_edge("sql_search", "merge") - graph.add_edge("vector_search", "merge") - graph.add_edge("merge", END) - - self.app = graph.compile() - - # Nodes - def _sql_search(self, state: AgentState) -> Dict[str, Any]: - results: List[QueryResponse] = [] - if self.sql_session is None: - return {"sql_results": results} - - # Simple LIKE-based search across a few fields - # Note: SQLite uses case-insensitive LIKE by default for ASCII. - q = ( - self.sql_session.query(InvestorTable) - .filter( - (func.lower(InvestorTable.name).like(f"%{state.question.lower()}%")) - | ( - func.lower(InvestorTable.sector_focus).like( - f"%{state.question.lower()}%" - ) - ) - | ( - func.lower(InvestorTable.stage_focus).like( - f"%{state.question.lower()}%" - ) - ) - | (func.lower(InvestorTable.region).like(f"%{state.question.lower()}%")) - ) - .limit(10) - ) - - for row in q.all(): - results.append( - QueryResponse( - name=row.name, - aum=row.aum, - check_size=row.check_size, - sector_focus=row.sector_focus, - stage_focus=row.stage_focus, - region=row.region, - investment_thesis="", - investor_description="", - reason="Matched SQL fields via LIKE", - ) - ) - - return {"sql_results": results} - - def _vector_search(self, state: AgentState) -> Dict[str, Any]: - results: List[QueryResponse] = [] - try: - q = self.collection.query(query_texts=[state.question], n_results=10) - # q has keys: ids, distances, documents, metadatas - docs = q.get("documents") or [] - metas = q.get("metadatas") or [] - if docs and metas: - for i, md in enumerate(metas[0]): - name = md.get("name", "Unknown") - results.append( - QueryResponse( - name=name, - aum=0, - check_size="", - sector_focus="", - stage_focus="", - region=md.get("headquarters", ""), - investment_thesis="", - investor_description=(docs[0][i] if docs[0] else ""), - reason="Vector similarity in Chroma", - ) - ) - except Exception: - # Best-effort; leave vector results empty on failure - pass - - return 
{"vector_results": results} - - def _merge(self, state: AgentState) -> Dict[str, Any]: - # Deduplicate by name, prefer SQL fields where available, keep first reason - merged: Dict[str, QueryResponse] = {} - - for item in state.vector_results + state.sql_results: - if item.name not in merged: - merged[item.name] = item - else: - existing = merged[item.name] - merged[item.name] = QueryResponse( - name=existing.name, - aum=existing.aum or item.aum, - check_size=existing.check_size or item.check_size, - sector_focus=existing.sector_focus or item.sector_focus, - stage_focus=existing.stage_focus or item.stage_focus, - region=existing.region or item.region, - investment_thesis=existing.investment_thesis - or item.investment_thesis, - investor_description=existing.investor_description - or item.investor_description, - reason=existing.reason or item.reason, - ) - - # Store back into sql_results to pass through the END with full state - return { - "sql_results": list(merged.values()), - "vector_results": [], - } - - # Public API - def run(self, question: str) -> QueryResponseList: - state = AgentState(question=question) - final_state: AgentState = self.app.invoke(state) - return QueryResponseList(responses=final_state.sql_results) diff --git a/app/services/openrouter.py b/app/services/openrouter.py index 6a36a61..057a5dc 100644 --- a/app/services/openrouter.py +++ b/app/services/openrouter.py @@ -3,13 +3,12 @@ from typing import List, Optional import chromadb import pandas as pd +from db.models import InvestorTable from langchain_core.prompts import PromptTemplate from langchain_openai import ChatOpenAI from pydantic_schemas import Investor, InvestorList from settings import settings -from app.db.tables import InvestorTable - # Add these imports for your databases # from sqlalchemy.ext.asyncio import AsyncSession # from your_vector_db import VectorDBClient diff --git a/app/services/querying.py b/app/services/querying.py index ea80784..e253b66 100644 --- 
a/app/services/querying.py +++ b/app/services/querying.py @@ -1,13 +1,22 @@ from typing import Optional import chromadb +from langchain import hub +from langchain_community.agent_toolkits import SQLDatabaseToolkit +from langchain_community.utilities import SQLDatabase from langchain_openai import ChatOpenAI +from langgraph.prebuilt import create_react_agent from pydantic_schemas import Investor, InvestorList from settings import settings -# Add these imports for your databases -# from sqlalchemy.ext.asyncio import AsyncSession -# from your_vector_db import VectorDBClient +# Connect to SQLite + +prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt") +db = SQLDatabase.from_uri("sqlite:///investors.db") +system_message = ( + prompt_template.format(dialect="SQLite", top_k=5) + + "\n Get answers from the Sql database and the vector database" +) class QueryProcessor: @@ -19,12 +28,16 @@ class QueryProcessor: self.llm = ChatOpenAI( api_key=settings.OPENROUTER_API_KEY, base_url="https://openrouter.ai/api/v1", - model="openai/gpt-oss-120b:free", + model="google/gemini-2.5-flash-lite", temperature=0, ) - - self.structured_llm = self.llm.with_structured_output(InvestorList) - self.sql_session = sql_session + self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm) + self.agent = create_react_agent( + model=self.llm, + tools=self.toolkit.get_tools() + [self.query_vector_database], + prompt=system_message, + response_format=InvestorList, + ) self.vector_db_client = vector_db_client self.vector_db_client = chromadb.PersistentClient(path="./chroma_db") @@ -49,13 +62,22 @@ class QueryProcessor: """Query the vector database for investor information.""" if not self.vector_db_client: return None + print("VECTOR STORE WAS CALLED") - # Implement vector database querying logic here - results = self.vector_db_client.query(collection=self.collection, query=query) - investors = [Investor(**doc.metadata) for doc in results.documents] - return InvestorList(investors=investors) + # 
Query the collection directly, not passing collection as parameter + results = self.collection.query( + query_texts=[query], # ChromaDB expects a list of query texts + n_results=3, # Specify how many results you want + ) + print(results) + + # ChromaDB returns results in a different structure + # results will have 'documents', 'metadatas', 'ids', 'distances' + return results def process_query(self, question: str) -> InvestorList: """Process a query using the LLM and return structured investor data.""" - response = self.structured_llm.predict(question=question) - return response + response = self.agent.invoke( + {"messages": [("user", question)]}, + ) + return response["structured_response"] # InvestorList built via response_format diff --git a/requirements.txt b/requirements.txt index 6a7dbd3..10ba213 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,13 +8,14 @@ chromadb>=0.4.0 # LLM integration openai>=1.0.0 langchain>=0.2.0 langchain-openai>=0.1.0 langgraph>=0.2.0 +langchain-community>=0.2.0 # Environment management python-dotenv>=1.0.0 pydantic-settings>=2.0.0 # Additional dependencies for data processing typing-extensions>=4.0.0