Add CompanyTable model and refactor query handling; update requirements for new dependencies
This commit is contained in:
@@ -12,3 +12,4 @@
|
|||||||
|
|
||||||
/*.db
|
/*.db
|
||||||
|
|
||||||
|
/*.cypython-*
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,8 @@
|
|||||||
|
from fastapi.routing import APIRouter

# Router holding the company endpoints defined in this module.
# NOTE: the class exported by fastapi.routing is ``APIRouter``; the lowercase
# ``apirouter`` spelling does not exist and fails at import time.
router = APIRouter()


@router.get("/companies")
def read_companies():
    """Handle ``GET /companies``.

    Returns:
        A dict with a single ``"message"`` key holding a fixed placeholder
        string; no data source is queried here.
    """
    return {"message": "list of companies"}
|
||||||
|
|
||||||
@@ -0,0 +1,8 @@
|
|||||||
|
from fastapi.routing import APIRouter

# Router holding the investor endpoints defined in this module.
# NOTE: the class exported by fastapi.routing is ``APIRouter``; the lowercase
# ``apirouter`` spelling does not exist and fails at import time.
router = APIRouter()


@router.get("/investors")
def read_investors():
    """Handle ``GET /investors``.

    Returns:
        A dict with a single ``"message"`` key holding a fixed placeholder
        string; no data source is queried here.
    """
    return {"message": "list of investors"}
|
||||||
|
|
||||||
Binary file not shown.
@@ -0,0 +1,103 @@
|
|||||||
|
import datetime
|
||||||
|
import enum
|
||||||
|
|
||||||
|
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, Table, Text
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
from sqlalchemy.types import Enum
|
||||||
|
|
||||||
|
from db.db import Base
|
||||||
|
|
||||||
|
|
||||||
|
@enum.unique
class InvestmentStage(enum.Enum):
    """Closed set of investment stages.

    Each member's value is the lowercase string form of its name.
    ``@enum.unique`` makes an accidental duplicate value fail at import
    instead of silently creating an alias.
    """

    SEED = "seed"
    SERIES_A = "series_a"
    SERIES_B = "series_b"
    SERIES_C = "series_c"
    GROWTH = "growth"
    LATE_STAGE = "late_stage"
|
||||||
|
|
||||||
|
|
||||||
|
# Association table for the many-to-many relationship between investors and
# companies. Each row links one ``investors.id`` to one ``companies.id``.
# NOTE(review): neither column is part of a primary key, so the same
# (investor, company) pair can be stored more than once — confirm whether a
# composite primary key is wanted before changing the schema.
investor_company_association = Table(
    "investor_companies",
    Base.metadata,
    Column("investor_id", Integer, ForeignKey("investors.id")),
    Column("company_id", Integer, ForeignKey("companies.id")),
)
|
||||||
|
|
||||||
|
|
||||||
|
# Association table for the investor-sector many-to-many link. Each row links
# one ``investors.id`` to one ``sectors.id``.
# NOTE(review): neither column is part of a primary key, so the same
# (investor, sector) pair can be stored more than once — confirm whether a
# composite primary key is wanted before changing the schema.
investor_sector_association = Table(
    "investor_sectors",
    Base.metadata,
    Column("investor_id", Integer, ForeignKey("investors.id")),
    Column("sector_id", Integer, ForeignKey("sectors.id")),
)
|
||||||
|
class InvestorTable(Base):
    """ORM model mapped to the ``investors`` table.

    Linked many-to-many to ``CompanyTable`` through
    ``investor_company_association`` and one-to-many to
    ``InvestorTeamMember`` through its ``investor_id`` foreign key.
    """

    __tablename__ = "investors"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, nullable=False)
    description = Column(Text, nullable=True)
    aum = Column(Integer, nullable=False)  # Assets Under Management
    check_size_lower = Column(Integer, nullable=False)  # Lower bound
    check_size_upper = Column(Integer, nullable=False)  # Upper bound
    geography = Column(String, nullable=False)
    stage_focus = Column(Enum(InvestmentStage), nullable=False)
    number_of_investments = Column(Integer, default=0)

    # Timestamps must be given as callables: passing the result of
    # ``datetime.now(...)`` directly would evaluate it once at import time and
    # stamp every row with the moment the module was loaded.
    created_at = Column(
        DateTime,
        default=lambda: datetime.datetime.now(datetime.UTC),
    )
    updated_at = Column(
        DateTime,
        default=lambda: datetime.datetime.now(datetime.UTC),
        onupdate=lambda: datetime.datetime.now(datetime.UTC),
    )

    # Relationship to portfolio companies
    portfolio_companies = relationship(
        "CompanyTable",
        secondary=investor_company_association,
        back_populates="investors",
    )

    # Counterpart of ``InvestorTeamMember.investor``, which names
    # ``team_members`` in its ``back_populates``; without this attribute the
    # mapper configuration fails.
    team_members = relationship(
        "InvestorTeamMember",
        back_populates="investor",
    )
|
||||||
|
|
||||||
|
|
||||||
|
class CompanyTable(Base):
    """ORM model mapped to the ``companies`` table.

    Linked many-to-many to ``InvestorTable`` through
    ``investor_company_association``.
    """

    __tablename__ = "companies"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, nullable=False)
    industry = Column(String, nullable=False)
    location = Column(String, nullable=False)
    founded_year = Column(Integer, nullable=True)
    website = Column(String, nullable=True)

    # Timestamps must be given as callables: passing the result of
    # ``datetime.now(...)`` directly would evaluate it once at import time and
    # stamp every row with the moment the module was loaded.
    created_at = Column(
        DateTime,
        default=lambda: datetime.datetime.now(datetime.UTC),
    )
    updated_at = Column(
        DateTime,
        default=lambda: datetime.datetime.now(datetime.UTC),
        onupdate=lambda: datetime.datetime.now(datetime.UTC),
    )

    # Relationship back to investors
    investors = relationship(
        "InvestorTable",
        secondary=investor_company_association,
        back_populates="portfolio_companies",
    )
|
||||||
|
|
||||||
|
|
||||||
|
class SectorTable(Base):
    """ORM model mapped to the ``sectors`` table.

    Rows are referenced by ``investor_sectors.sector_id``.
    """

    __tablename__ = "sectors"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, unique=True, nullable=False)  # one row per sector name
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class InvestorTeamMember(Base):
    """ORM model for a person on an investor's team.

    Each row points at its owning investor through ``investor_id``.
    """

    # A declarative model needs ``__tablename__`` (or ``__table__``); without
    # it SQLAlchemy cannot map the class.
    __tablename__ = "investor_team_members"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, nullable=False)
    role = Column(String, nullable=False)
    email = Column(String, unique=True, nullable=False)

    investor_id = Column(Integer, ForeignKey("investors.id"))
    # NOTE(review): ``back_populates`` requires a ``team_members``
    # relationship on ``InvestorTable`` — confirm that it is defined there.
    investor = relationship("InvestorTable", back_populates="team_members")
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
import datetime
|
|
||||||
|
|
||||||
from sqlalchemy import Column, DateTime, Integer, String
|
|
||||||
|
|
||||||
from app.db.db import Base
|
|
||||||
|
|
||||||
|
|
||||||
class InvestorTable(Base):
|
|
||||||
__tablename__ = "investors"
|
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, index=True)
|
|
||||||
name = Column(String, nullable=False)
|
|
||||||
aum = Column(Integer, nullable=False)
|
|
||||||
check_size = Column(String, nullable=False)
|
|
||||||
sector_focus = Column(String, nullable=False)
|
|
||||||
stage_focus = Column(String, nullable=False)
|
|
||||||
region = Column(String, nullable=False)
|
|
||||||
created_at = Column(DateTime, default=datetime.datetime.now(datetime.UTC))
|
|
||||||
updated_at = Column(
|
|
||||||
DateTime,
|
|
||||||
default=datetime.datetime.now(datetime.UTC),
|
|
||||||
onupdate=datetime.datetime.now(datetime.UTC),
|
|
||||||
)
|
|
||||||
+11
-13
@@ -1,16 +1,14 @@
|
|||||||
import io
|
import io
|
||||||
|
from api import investors
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from app.db.db import db_dependency, init_database
|
from db.db import db_dependency
|
||||||
from fastapi import FastAPI, File, UploadFile
|
from fastapi import FastAPI, File, UploadFile
|
||||||
from app.services.openrouter import InvestorProcessor
|
from services.openrouter import InvestorProcessor
|
||||||
|
from services.querying import QueryProcessor
|
||||||
from app.pydantic_schemas import QueryRequest, QueryResponseList
|
|
||||||
from app.services.langgraph_agent import LangGraphQueryAgent
|
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
app.include_router(investors.router)
|
||||||
init_database()
|
# init_database()
|
||||||
|
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
@@ -32,11 +30,11 @@ async def parse_csv(db: db_dependency, file: UploadFile = File(...)):
|
|||||||
return {"results": [r.dict() for r in results]}
|
return {"results": [r.dict() for r in results]}
|
||||||
|
|
||||||
|
|
||||||
@app.post("/query", response_model=QueryResponseList)
|
@app.post("/query")
|
||||||
async def query_investors(db: db_dependency, request: QueryRequest):
|
async def query_investors(db: db_dependency, question: str):
|
||||||
agent = LangGraphQueryAgent(sql_session=db)
|
processor = QueryProcessor(sql_session=db)
|
||||||
result = agent.run(request.question)
|
results = processor.process_query(question)
|
||||||
return result
|
return {"results": results}
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,162 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
|
|
||||||
import chromadb
|
|
||||||
from langgraph.graph import END, START, StateGraph
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from sqlalchemy import func
|
|
||||||
|
|
||||||
from app.db.tables import InvestorTable
|
|
||||||
from app.pydantic_schemas import QueryResponse, QueryResponseList
|
|
||||||
|
|
||||||
|
|
||||||
class AgentState(BaseModel):
|
|
||||||
question: str
|
|
||||||
sql_results: List[QueryResponse] = Field(default_factory=list)
|
|
||||||
vector_results: List[QueryResponse] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
class LangGraphQueryAgent:
|
|
||||||
"""Simple LangGraph agent that queries both SQL and Chroma and merges results."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
sql_session: Optional[object] = None,
|
|
||||||
vector_db_client: Optional[object] = None,
|
|
||||||
) -> None:
|
|
||||||
self.sql_session = sql_session
|
|
||||||
|
|
||||||
# Setup Chroma collection
|
|
||||||
self.vector_db_client = vector_db_client or chromadb.PersistentClient(
|
|
||||||
path="./chroma_db"
|
|
||||||
)
|
|
||||||
self.collection = self.vector_db_client.get_or_create_collection(
|
|
||||||
name="investor_descriptions",
|
|
||||||
metadata={
|
|
||||||
"description": "Investor descriptions and investment thesis focus",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Build graph
|
|
||||||
graph = StateGraph(AgentState)
|
|
||||||
graph.add_node("sql_search", self._sql_search)
|
|
||||||
graph.add_node("vector_search", self._vector_search)
|
|
||||||
graph.add_node("merge", self._merge)
|
|
||||||
|
|
||||||
# Parallel fan-out: START -> sql_search & vector_search -> merge -> END
|
|
||||||
graph.add_edge(START, "sql_search")
|
|
||||||
graph.add_edge(START, "vector_search")
|
|
||||||
graph.add_edge("sql_search", "merge")
|
|
||||||
graph.add_edge("vector_search", "merge")
|
|
||||||
graph.add_edge("merge", END)
|
|
||||||
|
|
||||||
self.app = graph.compile()
|
|
||||||
|
|
||||||
# Nodes
|
|
||||||
def _sql_search(self, state: AgentState) -> Dict[str, Any]:
|
|
||||||
results: List[QueryResponse] = []
|
|
||||||
if self.sql_session is None:
|
|
||||||
return {"sql_results": results}
|
|
||||||
|
|
||||||
# Simple LIKE-based search across a few fields
|
|
||||||
# Note: SQLite uses case-insensitive LIKE by default for ASCII.
|
|
||||||
q = (
|
|
||||||
self.sql_session.query(InvestorTable)
|
|
||||||
.filter(
|
|
||||||
(func.lower(InvestorTable.name).like(f"%{state.question.lower()}%"))
|
|
||||||
| (
|
|
||||||
func.lower(InvestorTable.sector_focus).like(
|
|
||||||
f"%{state.question.lower()}%"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
| (
|
|
||||||
func.lower(InvestorTable.stage_focus).like(
|
|
||||||
f"%{state.question.lower()}%"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
| (func.lower(InvestorTable.region).like(f"%{state.question.lower()}%"))
|
|
||||||
)
|
|
||||||
.limit(10)
|
|
||||||
)
|
|
||||||
|
|
||||||
for row in q.all():
|
|
||||||
results.append(
|
|
||||||
QueryResponse(
|
|
||||||
name=row.name,
|
|
||||||
aum=row.aum,
|
|
||||||
check_size=row.check_size,
|
|
||||||
sector_focus=row.sector_focus,
|
|
||||||
stage_focus=row.stage_focus,
|
|
||||||
region=row.region,
|
|
||||||
investment_thesis="",
|
|
||||||
investor_description="",
|
|
||||||
reason="Matched SQL fields via LIKE",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return {"sql_results": results}
|
|
||||||
|
|
||||||
def _vector_search(self, state: AgentState) -> Dict[str, Any]:
|
|
||||||
results: List[QueryResponse] = []
|
|
||||||
try:
|
|
||||||
q = self.collection.query(query_texts=[state.question], n_results=10)
|
|
||||||
# q has keys: ids, distances, documents, metadatas
|
|
||||||
docs = q.get("documents") or []
|
|
||||||
metas = q.get("metadatas") or []
|
|
||||||
if docs and metas:
|
|
||||||
for i, md in enumerate(metas[0]):
|
|
||||||
name = md.get("name", "Unknown")
|
|
||||||
results.append(
|
|
||||||
QueryResponse(
|
|
||||||
name=name,
|
|
||||||
aum=0,
|
|
||||||
check_size="",
|
|
||||||
sector_focus="",
|
|
||||||
stage_focus="",
|
|
||||||
region=md.get("headquarters", ""),
|
|
||||||
investment_thesis="",
|
|
||||||
investor_description=(docs[0][i] if docs[0] else ""),
|
|
||||||
reason="Vector similarity in Chroma",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
# Best-effort; leave vector results empty on failure
|
|
||||||
pass
|
|
||||||
|
|
||||||
return {"vector_results": results}
|
|
||||||
|
|
||||||
def _merge(self, state: AgentState) -> Dict[str, Any]:
|
|
||||||
# Deduplicate by name, prefer SQL fields where available, keep first reason
|
|
||||||
merged: Dict[str, QueryResponse] = {}
|
|
||||||
|
|
||||||
for item in state.vector_results + state.sql_results:
|
|
||||||
if item.name not in merged:
|
|
||||||
merged[item.name] = item
|
|
||||||
else:
|
|
||||||
existing = merged[item.name]
|
|
||||||
merged[item.name] = QueryResponse(
|
|
||||||
name=existing.name,
|
|
||||||
aum=existing.aum or item.aum,
|
|
||||||
check_size=existing.check_size or item.check_size,
|
|
||||||
sector_focus=existing.sector_focus or item.sector_focus,
|
|
||||||
stage_focus=existing.stage_focus or item.stage_focus,
|
|
||||||
region=existing.region or item.region,
|
|
||||||
investment_thesis=existing.investment_thesis
|
|
||||||
or item.investment_thesis,
|
|
||||||
investor_description=existing.investor_description
|
|
||||||
or item.investor_description,
|
|
||||||
reason=existing.reason or item.reason,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store back into sql_results to pass through the END with full state
|
|
||||||
return {
|
|
||||||
"sql_results": list(merged.values()),
|
|
||||||
"vector_results": [],
|
|
||||||
}
|
|
||||||
|
|
||||||
# Public API
|
|
||||||
def run(self, question: str) -> QueryResponseList:
|
|
||||||
state = AgentState(question=question)
|
|
||||||
final_state: AgentState = self.app.invoke(state)
|
|
||||||
return QueryResponseList(responses=final_state.sql_results)
|
|
||||||
|
|||||||
@@ -3,13 +3,12 @@ from typing import List, Optional
|
|||||||
|
|
||||||
import chromadb
|
import chromadb
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
from db.models import InvestorTable
|
||||||
from langchain_core.prompts import PromptTemplate
|
from langchain_core.prompts import PromptTemplate
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
from pydantic_schemas import Investor, InvestorList
|
from pydantic_schemas import Investor, InvestorList
|
||||||
from settings import settings
|
from settings import settings
|
||||||
|
|
||||||
from app.db.tables import InvestorTable
|
|
||||||
|
|
||||||
# Add these imports for your databases
|
# Add these imports for your databases
|
||||||
# from sqlalchemy.ext.asyncio import AsyncSession
|
# from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
# from your_vector_db import VectorDBClient
|
# from your_vector_db import VectorDBClient
|
||||||
|
|||||||
+34
-12
@@ -1,13 +1,22 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import chromadb
|
import chromadb
|
||||||
|
from langchain import hub
|
||||||
|
from langchain_community.agent_toolkits import SQLDatabaseToolkit
|
||||||
|
from langchain_community.utilities import SQLDatabase
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
|
from langgraph.prebuilt import create_react_agent
|
||||||
from pydantic_schemas import Investor, InvestorList
|
from pydantic_schemas import Investor, InvestorList
|
||||||
from settings import settings
|
from settings import settings
|
||||||
|
|
||||||
# Add these imports for your databases
|
# Connect to SQLite
|
||||||
# from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
# from your_vector_db import VectorDBClient
|
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")
|
||||||
|
db = SQLDatabase.from_uri("sqlite:///investors.db")
|
||||||
|
system_message = (
|
||||||
|
prompt_template.format(dialect="SQLite", top_k=5)
|
||||||
|
+ "\n Get answers from the Sql database and the vector database"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class QueryProcessor:
|
class QueryProcessor:
|
||||||
@@ -19,12 +28,16 @@ class QueryProcessor:
|
|||||||
self.llm = ChatOpenAI(
|
self.llm = ChatOpenAI(
|
||||||
api_key=settings.OPENROUTER_API_KEY,
|
api_key=settings.OPENROUTER_API_KEY,
|
||||||
base_url="https://openrouter.ai/api/v1",
|
base_url="https://openrouter.ai/api/v1",
|
||||||
model="openai/gpt-oss-120b:free",
|
model="google/gemini-2.5-flash-lite",
|
||||||
temperature=0,
|
temperature=0,
|
||||||
)
|
)
|
||||||
|
self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
|
||||||
self.structured_llm = self.llm.with_structured_output(InvestorList)
|
self.agent = create_react_agent(
|
||||||
self.sql_session = sql_session
|
model=self.llm,
|
||||||
|
tools=self.toolkit.get_tools() + [self.query_vector_database],
|
||||||
|
prompt=system_message,
|
||||||
|
response_format=InvestorList,
|
||||||
|
)
|
||||||
self.vector_db_client = vector_db_client
|
self.vector_db_client = vector_db_client
|
||||||
|
|
||||||
self.vector_db_client = chromadb.PersistentClient(path="./chroma_db")
|
self.vector_db_client = chromadb.PersistentClient(path="./chroma_db")
|
||||||
@@ -49,13 +62,22 @@ class QueryProcessor:
|
|||||||
"""Query the vector database for investor information."""
|
"""Query the vector database for investor information."""
|
||||||
if not self.vector_db_client:
|
if not self.vector_db_client:
|
||||||
return None
|
return None
|
||||||
|
print("VECTOR STORE WAS CALLED")
|
||||||
|
|
||||||
# Implement vector database querying logic here
|
# Query the collection directly, not passing collection as parameter
|
||||||
results = self.vector_db_client.query(collection=self.collection, query=query)
|
results = self.collection.query(
|
||||||
investors = [Investor(**doc.metadata) for doc in results.documents]
|
query_texts=[query], # ChromaDB expects a list of query texts
|
||||||
return InvestorList(investors=investors)
|
n_results=3, # Specify how many results you want
|
||||||
|
)
|
||||||
|
print(results)
|
||||||
|
|
||||||
|
# ChromaDB returns results in a different structure
|
||||||
|
# results will have 'documents', 'metadatas', 'ids', 'distances'
|
||||||
|
return results
|
||||||
|
|
||||||
def process_query(self, question: str) -> InvestorList:
|
def process_query(self, question: str) -> InvestorList:
|
||||||
"""Process a query using the LLM and return structured investor data."""
|
"""Process a query using the LLM and return structured investor data."""
|
||||||
response = self.structured_llm.predict(question=question)
|
response = self.agent.invoke(
|
||||||
|
{"messages": [("user", question)]},
|
||||||
|
)
|
||||||
return response
|
return response
|
||||||
|
|||||||
@@ -8,13 +8,9 @@ chromadb>=0.4.0
|
|||||||
|
|
||||||
# LLM integration
|
# LLM integration
|
||||||
openai>=1.0.0
|
openai>=1.0.0
|
||||||
langchain>=0.2.0
|
|
||||||
langchain-openai>=0.1.0
|
|
||||||
langgraph>=0.2.0
|
|
||||||
|
|
||||||
# Environment management
|
# Environment management
|
||||||
python-dotenv>=1.0.0
|
python-dotenv>=1.0.0
|
||||||
pydantic-settings>=2.0.0
|
|
||||||
|
|
||||||
# Additional dependencies for data processing
|
# Additional dependencies for data processing
|
||||||
typing-extensions>=4.0.0
|
typing-extensions>=4.0.0
|
||||||
|
|||||||
Reference in New Issue
Block a user