Add CompanyTable model and refactor query handling; update requirements for new dependencies

This commit is contained in:
bolade
2025-09-02 12:22:50 +01:00
parent 74931f235e
commit 65b5df3a43
19 changed files with 166 additions and 216 deletions
+1
View File
@@ -12,3 +12,4 @@
/*.db /*.db
/*.cypython-*
Binary file not shown.
Binary file not shown.
View File
Binary file not shown.
Binary file not shown.
+8
View File
@@ -0,0 +1,8 @@
from fastapi.routing import apirouter
router = apirouter()
@router.get("/companies")
def read_companies():
return {"message": "list of companies"}
+8
View File
@@ -0,0 +1,8 @@
from fastapi.routing import apirouter
router = apirouter()
@router.get("/investors")
def read_investors():
return {"message": "list of investors"}
Binary file not shown.
+103
View File
@@ -0,0 +1,103 @@
import datetime
import enum
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, Table, Text
from sqlalchemy.orm import relationship
from sqlalchemy.types import Enum
from db.db import Base
class InvestmentStage(enum.Enum):
SEED = "seed"
SERIES_A = "series_a"
SERIES_B = "series_b"
SERIES_C = "series_c"
GROWTH = "growth"
LATE_STAGE = "late_stage"
# Association table for many-to-many relationship between investors and companies
investor_company_association = Table(
"investor_companies",
Base.metadata,
Column("investor_id", Integer, ForeignKey("investors.id")),
Column("company_id", Integer, ForeignKey("companies.id")),
)
# Association table for investor-sector many-to-many
investor_sector_association = Table(
"investor_sectors",
Base.metadata,
Column("investor_id", Integer, ForeignKey("investors.id")),
Column("sector_id", Integer, ForeignKey("sectors.id")),
)
class InvestorTable(Base):
__tablename__ = "investors"
id = Column(Integer, primary_key=True, index=True)
name = Column(String, nullable=False)
description = Column(Text, nullable=True)
aum = Column(Integer, nullable=False) # Assets Under Management
check_size_lower = Column(Integer, nullable=False) # Lower bound
check_size_upper = Column(Integer, nullable=False) # Upper bound
geography = Column(String, nullable=False)
stage_focus = Column(Enum(InvestmentStage), nullable=False)
number_of_investments = Column(Integer, default=0)
created_at = Column(DateTime, default=datetime.datetime.now(datetime.UTC))
updated_at = Column(
DateTime,
default=datetime.datetime.now(datetime.UTC),
onupdate=datetime.datetime.now(datetime.UTC),
)
# Relationship to portfolio companies
portfolio_companies = relationship(
"CompanyTable",
secondary=investor_company_association,
back_populates="investors",
)
class CompanyTable(Base):
__tablename__ = "companies"
id = Column(Integer, primary_key=True, index=True)
name = Column(String, nullable=False)
industry = Column(String, nullable=False)
location = Column(String, nullable=False)
founded_year = Column(Integer, nullable=True)
website = Column(String, nullable=True)
created_at = Column(DateTime, default=datetime.datetime.now(datetime.UTC))
updated_at = Column(
DateTime,
default=datetime.datetime.now(datetime.UTC),
onupdate=datetime.datetime.now(datetime.UTC),
)
# Relationship back to investors
investors = relationship(
"InvestorTable",
secondary=investor_company_association,
back_populates="portfolio_companies",
)
class SectorTable(Base):
__tablename__ = "sectors"
id = Column(Integer, primary_key=True, index=True)
name = Column(String, unique=True, nullable=False)
class InvestorTeamMember(Base):
id = Column(Integer, primary_key=True, index=True)
name = Column(String, nullable=False)
role = Column(String, nullable=False)
email = Column(String, unique=True, nullable=False)
investor_id = Column(Integer, ForeignKey("investors.id"))
investor = relationship("InvestorTable", back_populates="team_members")
-23
View File
@@ -1,23 +0,0 @@
import datetime
from sqlalchemy import Column, DateTime, Integer, String
from app.db.db import Base
class InvestorTable(Base):
__tablename__ = "investors"
id = Column(Integer, primary_key=True, index=True)
name = Column(String, nullable=False)
aum = Column(Integer, nullable=False)
check_size = Column(String, nullable=False)
sector_focus = Column(String, nullable=False)
stage_focus = Column(String, nullable=False)
region = Column(String, nullable=False)
created_at = Column(DateTime, default=datetime.datetime.now(datetime.UTC))
updated_at = Column(
DateTime,
default=datetime.datetime.now(datetime.UTC),
onupdate=datetime.datetime.now(datetime.UTC),
)
+11 -13
View File
@@ -1,16 +1,14 @@
import io import io
from api import investors
import pandas as pd import pandas as pd
from app.db.db import db_dependency, init_database from db.db import db_dependency
from fastapi import FastAPI, File, UploadFile from fastapi import FastAPI, File, UploadFile
from app.services.openrouter import InvestorProcessor from services.openrouter import InvestorProcessor
from services.querying import QueryProcessor
from app.pydantic_schemas import QueryRequest, QueryResponseList
from app.services.langgraph_agent import LangGraphQueryAgent
app = FastAPI() app = FastAPI()
app.include_router(investors.router)
init_database() # init_database()
@app.get("/") @app.get("/")
@@ -32,11 +30,11 @@ async def parse_csv(db: db_dependency, file: UploadFile = File(...)):
return {"results": [r.dict() for r in results]} return {"results": [r.dict() for r in results]}
@app.post("/query", response_model=QueryResponseList) @app.post("/query")
async def query_investors(db: db_dependency, request: QueryRequest): async def query_investors(db: db_dependency, question: str):
agent = LangGraphQueryAgent(sql_session=db) processor = QueryProcessor(sql_session=db)
result = agent.run(request.question) results = processor.process_query(question)
return result return {"results": results}
if __name__ == "__main__": if __name__ == "__main__":
Binary file not shown.
Binary file not shown.
-162
View File
@@ -1,162 +0,0 @@
from __future__ import annotations
from typing import Any, Dict, List, Optional
import chromadb
from langgraph.graph import END, START, StateGraph
from pydantic import BaseModel, Field
from sqlalchemy import func
from app.db.tables import InvestorTable
from app.pydantic_schemas import QueryResponse, QueryResponseList
class AgentState(BaseModel):
question: str
sql_results: List[QueryResponse] = Field(default_factory=list)
vector_results: List[QueryResponse] = Field(default_factory=list)
class LangGraphQueryAgent:
"""Simple LangGraph agent that queries both SQL and Chroma and merges results."""
def __init__(
self,
sql_session: Optional[object] = None,
vector_db_client: Optional[object] = None,
) -> None:
self.sql_session = sql_session
# Setup Chroma collection
self.vector_db_client = vector_db_client or chromadb.PersistentClient(
path="./chroma_db"
)
self.collection = self.vector_db_client.get_or_create_collection(
name="investor_descriptions",
metadata={
"description": "Investor descriptions and investment thesis focus",
},
)
# Build graph
graph = StateGraph(AgentState)
graph.add_node("sql_search", self._sql_search)
graph.add_node("vector_search", self._vector_search)
graph.add_node("merge", self._merge)
# Parallel fan-out: START -> sql_search & vector_search -> merge -> END
graph.add_edge(START, "sql_search")
graph.add_edge(START, "vector_search")
graph.add_edge("sql_search", "merge")
graph.add_edge("vector_search", "merge")
graph.add_edge("merge", END)
self.app = graph.compile()
# Nodes
def _sql_search(self, state: AgentState) -> Dict[str, Any]:
results: List[QueryResponse] = []
if self.sql_session is None:
return {"sql_results": results}
# Simple LIKE-based search across a few fields
# Note: SQLite uses case-insensitive LIKE by default for ASCII.
q = (
self.sql_session.query(InvestorTable)
.filter(
(func.lower(InvestorTable.name).like(f"%{state.question.lower()}%"))
| (
func.lower(InvestorTable.sector_focus).like(
f"%{state.question.lower()}%"
)
)
| (
func.lower(InvestorTable.stage_focus).like(
f"%{state.question.lower()}%"
)
)
| (func.lower(InvestorTable.region).like(f"%{state.question.lower()}%"))
)
.limit(10)
)
for row in q.all():
results.append(
QueryResponse(
name=row.name,
aum=row.aum,
check_size=row.check_size,
sector_focus=row.sector_focus,
stage_focus=row.stage_focus,
region=row.region,
investment_thesis="",
investor_description="",
reason="Matched SQL fields via LIKE",
)
)
return {"sql_results": results}
def _vector_search(self, state: AgentState) -> Dict[str, Any]:
results: List[QueryResponse] = []
try:
q = self.collection.query(query_texts=[state.question], n_results=10)
# q has keys: ids, distances, documents, metadatas
docs = q.get("documents") or []
metas = q.get("metadatas") or []
if docs and metas:
for i, md in enumerate(metas[0]):
name = md.get("name", "Unknown")
results.append(
QueryResponse(
name=name,
aum=0,
check_size="",
sector_focus="",
stage_focus="",
region=md.get("headquarters", ""),
investment_thesis="",
investor_description=(docs[0][i] if docs[0] else ""),
reason="Vector similarity in Chroma",
)
)
except Exception:
# Best-effort; leave vector results empty on failure
pass
return {"vector_results": results}
def _merge(self, state: AgentState) -> Dict[str, Any]:
# Deduplicate by name, prefer SQL fields where available, keep first reason
merged: Dict[str, QueryResponse] = {}
for item in state.vector_results + state.sql_results:
if item.name not in merged:
merged[item.name] = item
else:
existing = merged[item.name]
merged[item.name] = QueryResponse(
name=existing.name,
aum=existing.aum or item.aum,
check_size=existing.check_size or item.check_size,
sector_focus=existing.sector_focus or item.sector_focus,
stage_focus=existing.stage_focus or item.stage_focus,
region=existing.region or item.region,
investment_thesis=existing.investment_thesis
or item.investment_thesis,
investor_description=existing.investor_description
or item.investor_description,
reason=existing.reason or item.reason,
)
# Store back into sql_results to pass through the END with full state
return {
"sql_results": list(merged.values()),
"vector_results": [],
}
# Public API
def run(self, question: str) -> QueryResponseList:
state = AgentState(question=question)
final_state: AgentState = self.app.invoke(state)
return QueryResponseList(responses=final_state.sql_results)
+1 -2
View File
@@ -3,13 +3,12 @@ from typing import List, Optional
import chromadb import chromadb
import pandas as pd import pandas as pd
from db.models import InvestorTable
from langchain_core.prompts import PromptTemplate from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
from pydantic_schemas import Investor, InvestorList from pydantic_schemas import Investor, InvestorList
from settings import settings from settings import settings
from app.db.tables import InvestorTable
# Add these imports for your databases # Add these imports for your databases
# from sqlalchemy.ext.asyncio import AsyncSession # from sqlalchemy.ext.asyncio import AsyncSession
# from your_vector_db import VectorDBClient # from your_vector_db import VectorDBClient
+34 -12
View File
@@ -1,13 +1,22 @@
from typing import Optional from typing import Optional
import chromadb import chromadb
from langchain import hub
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent
from pydantic_schemas import Investor, InvestorList from pydantic_schemas import Investor, InvestorList
from settings import settings from settings import settings
# Add these imports for your databases # Connect to SQLite
# from sqlalchemy.ext.asyncio import AsyncSession
# from your_vector_db import VectorDBClient prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")
db = SQLDatabase.from_uri("sqlite:///investors.db")
system_message = (
prompt_template.format(dialect="SQLite", top_k=5)
+ "\n Get answers from the Sql database and the vector database"
)
class QueryProcessor: class QueryProcessor:
@@ -19,12 +28,16 @@ class QueryProcessor:
self.llm = ChatOpenAI( self.llm = ChatOpenAI(
api_key=settings.OPENROUTER_API_KEY, api_key=settings.OPENROUTER_API_KEY,
base_url="https://openrouter.ai/api/v1", base_url="https://openrouter.ai/api/v1",
model="openai/gpt-oss-120b:free", model="google/gemini-2.5-flash-lite",
temperature=0, temperature=0,
) )
self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
self.structured_llm = self.llm.with_structured_output(InvestorList) self.agent = create_react_agent(
self.sql_session = sql_session model=self.llm,
tools=self.toolkit.get_tools() + [self.query_vector_database],
prompt=system_message,
response_format=InvestorList,
)
self.vector_db_client = vector_db_client self.vector_db_client = vector_db_client
self.vector_db_client = chromadb.PersistentClient(path="./chroma_db") self.vector_db_client = chromadb.PersistentClient(path="./chroma_db")
@@ -49,13 +62,22 @@ class QueryProcessor:
"""Query the vector database for investor information.""" """Query the vector database for investor information."""
if not self.vector_db_client: if not self.vector_db_client:
return None return None
print("VECTOR STORE WAS CALLED")
# Implement vector database querying logic here # Query the collection directly, not passing collection as parameter
results = self.vector_db_client.query(collection=self.collection, query=query) results = self.collection.query(
investors = [Investor(**doc.metadata) for doc in results.documents] query_texts=[query], # ChromaDB expects a list of query texts
return InvestorList(investors=investors) n_results=3, # Specify how many results you want
)
print(results)
# ChromaDB returns results in a different structure
# results will have 'documents', 'metadatas', 'ids', 'distances'
return results
def process_query(self, question: str) -> InvestorList: def process_query(self, question: str) -> InvestorList:
"""Process a query using the LLM and return structured investor data.""" """Process a query using the LLM and return structured investor data."""
response = self.structured_llm.predict(question=question) response = self.agent.invoke(
{"messages": [("user", question)]},
)
return response return response
-4
View File
@@ -8,13 +8,9 @@ chromadb>=0.4.0
# LLM integration # LLM integration
openai>=1.0.0 openai>=1.0.0
langchain>=0.2.0
langchain-openai>=0.1.0
langgraph>=0.2.0
# Environment management # Environment management
python-dotenv>=1.0.0 python-dotenv>=1.0.0
pydantic-settings>=2.0.0
# Additional dependencies for data processing # Additional dependencies for data processing
typing-extensions>=4.0.0 typing-extensions>=4.0.0