Add CompanyTable model and refactor query handling; update requirements for new dependencies
This commit is contained in:
@@ -12,3 +12,4 @@
|
||||
|
||||
/*.db
|
||||
|
||||
/*.cypython-*
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,8 @@
|
||||
from fastapi.routing import apirouter
|
||||
|
||||
router = apirouter()
|
||||
|
||||
@router.get("/companies")
|
||||
def read_companies():
|
||||
return {"message": "list of companies"}
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
from fastapi.routing import apirouter
|
||||
|
||||
router = apirouter()
|
||||
|
||||
@router.get("/investors")
|
||||
def read_investors():
|
||||
return {"message": "list of investors"}
|
||||
|
||||
Binary file not shown.
@@ -0,0 +1,103 @@
|
||||
import datetime
|
||||
import enum
|
||||
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, Table, Text
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.types import Enum
|
||||
|
||||
from db.db import Base
|
||||
|
||||
|
||||
class InvestmentStage(enum.Enum):
|
||||
SEED = "seed"
|
||||
SERIES_A = "series_a"
|
||||
SERIES_B = "series_b"
|
||||
SERIES_C = "series_c"
|
||||
GROWTH = "growth"
|
||||
LATE_STAGE = "late_stage"
|
||||
|
||||
|
||||
# Association table for many-to-many relationship between investors and companies
|
||||
investor_company_association = Table(
|
||||
"investor_companies",
|
||||
Base.metadata,
|
||||
Column("investor_id", Integer, ForeignKey("investors.id")),
|
||||
Column("company_id", Integer, ForeignKey("companies.id")),
|
||||
)
|
||||
|
||||
|
||||
# Association table for investor-sector many-to-many
|
||||
investor_sector_association = Table(
|
||||
"investor_sectors",
|
||||
Base.metadata,
|
||||
Column("investor_id", Integer, ForeignKey("investors.id")),
|
||||
Column("sector_id", Integer, ForeignKey("sectors.id")),
|
||||
)
|
||||
class InvestorTable(Base):
|
||||
__tablename__ = "investors"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, nullable=False)
|
||||
description = Column(Text, nullable=True)
|
||||
aum = Column(Integer, nullable=False) # Assets Under Management
|
||||
check_size_lower = Column(Integer, nullable=False) # Lower bound
|
||||
check_size_upper = Column(Integer, nullable=False) # Upper bound
|
||||
geography = Column(String, nullable=False)
|
||||
stage_focus = Column(Enum(InvestmentStage), nullable=False)
|
||||
number_of_investments = Column(Integer, default=0)
|
||||
created_at = Column(DateTime, default=datetime.datetime.now(datetime.UTC))
|
||||
updated_at = Column(
|
||||
DateTime,
|
||||
default=datetime.datetime.now(datetime.UTC),
|
||||
onupdate=datetime.datetime.now(datetime.UTC),
|
||||
)
|
||||
|
||||
# Relationship to portfolio companies
|
||||
portfolio_companies = relationship(
|
||||
"CompanyTable",
|
||||
secondary=investor_company_association,
|
||||
back_populates="investors",
|
||||
)
|
||||
|
||||
|
||||
class CompanyTable(Base):
|
||||
__tablename__ = "companies"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, nullable=False)
|
||||
industry = Column(String, nullable=False)
|
||||
location = Column(String, nullable=False)
|
||||
founded_year = Column(Integer, nullable=True)
|
||||
website = Column(String, nullable=True)
|
||||
created_at = Column(DateTime, default=datetime.datetime.now(datetime.UTC))
|
||||
updated_at = Column(
|
||||
DateTime,
|
||||
default=datetime.datetime.now(datetime.UTC),
|
||||
onupdate=datetime.datetime.now(datetime.UTC),
|
||||
)
|
||||
|
||||
# Relationship back to investors
|
||||
investors = relationship(
|
||||
"InvestorTable",
|
||||
secondary=investor_company_association,
|
||||
back_populates="portfolio_companies",
|
||||
)
|
||||
|
||||
|
||||
class SectorTable(Base):
|
||||
__tablename__ = "sectors"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, unique=True, nullable=False)
|
||||
|
||||
|
||||
|
||||
|
||||
class InvestorTeamMember(Base):
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, nullable=False)
|
||||
role = Column(String, nullable=False)
|
||||
email = Column(String, unique=True, nullable=False)
|
||||
|
||||
investor_id = Column(Integer, ForeignKey("investors.id"))
|
||||
investor = relationship("InvestorTable", back_populates="team_members")
|
||||
@@ -1,23 +0,0 @@
|
||||
import datetime
|
||||
|
||||
from sqlalchemy import Column, DateTime, Integer, String
|
||||
|
||||
from app.db.db import Base
|
||||
|
||||
|
||||
class InvestorTable(Base):
|
||||
__tablename__ = "investors"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, nullable=False)
|
||||
aum = Column(Integer, nullable=False)
|
||||
check_size = Column(String, nullable=False)
|
||||
sector_focus = Column(String, nullable=False)
|
||||
stage_focus = Column(String, nullable=False)
|
||||
region = Column(String, nullable=False)
|
||||
created_at = Column(DateTime, default=datetime.datetime.now(datetime.UTC))
|
||||
updated_at = Column(
|
||||
DateTime,
|
||||
default=datetime.datetime.now(datetime.UTC),
|
||||
onupdate=datetime.datetime.now(datetime.UTC),
|
||||
)
|
||||
+11
-13
@@ -1,16 +1,14 @@
|
||||
import io
|
||||
|
||||
from api import investors
|
||||
import pandas as pd
|
||||
from app.db.db import db_dependency, init_database
|
||||
from db.db import db_dependency
|
||||
from fastapi import FastAPI, File, UploadFile
|
||||
from app.services.openrouter import InvestorProcessor
|
||||
|
||||
from app.pydantic_schemas import QueryRequest, QueryResponseList
|
||||
from app.services.langgraph_agent import LangGraphQueryAgent
|
||||
from services.openrouter import InvestorProcessor
|
||||
from services.querying import QueryProcessor
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
init_database()
|
||||
app.include_router(investors.router)
|
||||
# init_database()
|
||||
|
||||
|
||||
@app.get("/")
|
||||
@@ -32,11 +30,11 @@ async def parse_csv(db: db_dependency, file: UploadFile = File(...)):
|
||||
return {"results": [r.dict() for r in results]}
|
||||
|
||||
|
||||
@app.post("/query", response_model=QueryResponseList)
|
||||
async def query_investors(db: db_dependency, request: QueryRequest):
|
||||
agent = LangGraphQueryAgent(sql_session=db)
|
||||
result = agent.run(request.question)
|
||||
return result
|
||||
@app.post("/query")
|
||||
async def query_investors(db: db_dependency, question: str):
|
||||
processor = QueryProcessor(sql_session=db)
|
||||
results = processor.process_query(question)
|
||||
return {"results": results}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,162 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import chromadb
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import func
|
||||
|
||||
from app.db.tables import InvestorTable
|
||||
from app.pydantic_schemas import QueryResponse, QueryResponseList
|
||||
|
||||
|
||||
class AgentState(BaseModel):
|
||||
question: str
|
||||
sql_results: List[QueryResponse] = Field(default_factory=list)
|
||||
vector_results: List[QueryResponse] = Field(default_factory=list)
|
||||
|
||||
|
||||
class LangGraphQueryAgent:
|
||||
"""Simple LangGraph agent that queries both SQL and Chroma and merges results."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sql_session: Optional[object] = None,
|
||||
vector_db_client: Optional[object] = None,
|
||||
) -> None:
|
||||
self.sql_session = sql_session
|
||||
|
||||
# Setup Chroma collection
|
||||
self.vector_db_client = vector_db_client or chromadb.PersistentClient(
|
||||
path="./chroma_db"
|
||||
)
|
||||
self.collection = self.vector_db_client.get_or_create_collection(
|
||||
name="investor_descriptions",
|
||||
metadata={
|
||||
"description": "Investor descriptions and investment thesis focus",
|
||||
},
|
||||
)
|
||||
|
||||
# Build graph
|
||||
graph = StateGraph(AgentState)
|
||||
graph.add_node("sql_search", self._sql_search)
|
||||
graph.add_node("vector_search", self._vector_search)
|
||||
graph.add_node("merge", self._merge)
|
||||
|
||||
# Parallel fan-out: START -> sql_search & vector_search -> merge -> END
|
||||
graph.add_edge(START, "sql_search")
|
||||
graph.add_edge(START, "vector_search")
|
||||
graph.add_edge("sql_search", "merge")
|
||||
graph.add_edge("vector_search", "merge")
|
||||
graph.add_edge("merge", END)
|
||||
|
||||
self.app = graph.compile()
|
||||
|
||||
# Nodes
|
||||
def _sql_search(self, state: AgentState) -> Dict[str, Any]:
|
||||
results: List[QueryResponse] = []
|
||||
if self.sql_session is None:
|
||||
return {"sql_results": results}
|
||||
|
||||
# Simple LIKE-based search across a few fields
|
||||
# Note: SQLite uses case-insensitive LIKE by default for ASCII.
|
||||
q = (
|
||||
self.sql_session.query(InvestorTable)
|
||||
.filter(
|
||||
(func.lower(InvestorTable.name).like(f"%{state.question.lower()}%"))
|
||||
| (
|
||||
func.lower(InvestorTable.sector_focus).like(
|
||||
f"%{state.question.lower()}%"
|
||||
)
|
||||
)
|
||||
| (
|
||||
func.lower(InvestorTable.stage_focus).like(
|
||||
f"%{state.question.lower()}%"
|
||||
)
|
||||
)
|
||||
| (func.lower(InvestorTable.region).like(f"%{state.question.lower()}%"))
|
||||
)
|
||||
.limit(10)
|
||||
)
|
||||
|
||||
for row in q.all():
|
||||
results.append(
|
||||
QueryResponse(
|
||||
name=row.name,
|
||||
aum=row.aum,
|
||||
check_size=row.check_size,
|
||||
sector_focus=row.sector_focus,
|
||||
stage_focus=row.stage_focus,
|
||||
region=row.region,
|
||||
investment_thesis="",
|
||||
investor_description="",
|
||||
reason="Matched SQL fields via LIKE",
|
||||
)
|
||||
)
|
||||
|
||||
return {"sql_results": results}
|
||||
|
||||
def _vector_search(self, state: AgentState) -> Dict[str, Any]:
|
||||
results: List[QueryResponse] = []
|
||||
try:
|
||||
q = self.collection.query(query_texts=[state.question], n_results=10)
|
||||
# q has keys: ids, distances, documents, metadatas
|
||||
docs = q.get("documents") or []
|
||||
metas = q.get("metadatas") or []
|
||||
if docs and metas:
|
||||
for i, md in enumerate(metas[0]):
|
||||
name = md.get("name", "Unknown")
|
||||
results.append(
|
||||
QueryResponse(
|
||||
name=name,
|
||||
aum=0,
|
||||
check_size="",
|
||||
sector_focus="",
|
||||
stage_focus="",
|
||||
region=md.get("headquarters", ""),
|
||||
investment_thesis="",
|
||||
investor_description=(docs[0][i] if docs[0] else ""),
|
||||
reason="Vector similarity in Chroma",
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
# Best-effort; leave vector results empty on failure
|
||||
pass
|
||||
|
||||
return {"vector_results": results}
|
||||
|
||||
def _merge(self, state: AgentState) -> Dict[str, Any]:
|
||||
# Deduplicate by name, prefer SQL fields where available, keep first reason
|
||||
merged: Dict[str, QueryResponse] = {}
|
||||
|
||||
for item in state.vector_results + state.sql_results:
|
||||
if item.name not in merged:
|
||||
merged[item.name] = item
|
||||
else:
|
||||
existing = merged[item.name]
|
||||
merged[item.name] = QueryResponse(
|
||||
name=existing.name,
|
||||
aum=existing.aum or item.aum,
|
||||
check_size=existing.check_size or item.check_size,
|
||||
sector_focus=existing.sector_focus or item.sector_focus,
|
||||
stage_focus=existing.stage_focus or item.stage_focus,
|
||||
region=existing.region or item.region,
|
||||
investment_thesis=existing.investment_thesis
|
||||
or item.investment_thesis,
|
||||
investor_description=existing.investor_description
|
||||
or item.investor_description,
|
||||
reason=existing.reason or item.reason,
|
||||
)
|
||||
|
||||
# Store back into sql_results to pass through the END with full state
|
||||
return {
|
||||
"sql_results": list(merged.values()),
|
||||
"vector_results": [],
|
||||
}
|
||||
|
||||
# Public API
|
||||
def run(self, question: str) -> QueryResponseList:
|
||||
state = AgentState(question=question)
|
||||
final_state: AgentState = self.app.invoke(state)
|
||||
return QueryResponseList(responses=final_state.sql_results)
|
||||
|
||||
@@ -3,13 +3,12 @@ from typing import List, Optional
|
||||
|
||||
import chromadb
|
||||
import pandas as pd
|
||||
from db.models import InvestorTable
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_openai import ChatOpenAI
|
||||
from pydantic_schemas import Investor, InvestorList
|
||||
from settings import settings
|
||||
|
||||
from app.db.tables import InvestorTable
|
||||
|
||||
# Add these imports for your databases
|
||||
# from sqlalchemy.ext.asyncio import AsyncSession
|
||||
# from your_vector_db import VectorDBClient
|
||||
|
||||
+34
-12
@@ -1,13 +1,22 @@
|
||||
from typing import Optional
|
||||
|
||||
import chromadb
|
||||
from langchain import hub
|
||||
from langchain_community.agent_toolkits import SQLDatabaseToolkit
|
||||
from langchain_community.utilities import SQLDatabase
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
from pydantic_schemas import Investor, InvestorList
|
||||
from settings import settings
|
||||
|
||||
# Add these imports for your databases
|
||||
# from sqlalchemy.ext.asyncio import AsyncSession
|
||||
# from your_vector_db import VectorDBClient
|
||||
# Connect to SQLite
|
||||
|
||||
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")
|
||||
db = SQLDatabase.from_uri("sqlite:///investors.db")
|
||||
system_message = (
|
||||
prompt_template.format(dialect="SQLite", top_k=5)
|
||||
+ "\n Get answers from the Sql database and the vector database"
|
||||
)
|
||||
|
||||
|
||||
class QueryProcessor:
|
||||
@@ -19,12 +28,16 @@ class QueryProcessor:
|
||||
self.llm = ChatOpenAI(
|
||||
api_key=settings.OPENROUTER_API_KEY,
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="openai/gpt-oss-120b:free",
|
||||
model="google/gemini-2.5-flash-lite",
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
self.structured_llm = self.llm.with_structured_output(InvestorList)
|
||||
self.sql_session = sql_session
|
||||
self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
|
||||
self.agent = create_react_agent(
|
||||
model=self.llm,
|
||||
tools=self.toolkit.get_tools() + [self.query_vector_database],
|
||||
prompt=system_message,
|
||||
response_format=InvestorList,
|
||||
)
|
||||
self.vector_db_client = vector_db_client
|
||||
|
||||
self.vector_db_client = chromadb.PersistentClient(path="./chroma_db")
|
||||
@@ -49,13 +62,22 @@ class QueryProcessor:
|
||||
"""Query the vector database for investor information."""
|
||||
if not self.vector_db_client:
|
||||
return None
|
||||
print("VECTOR STORE WAS CALLED")
|
||||
|
||||
# Implement vector database querying logic here
|
||||
results = self.vector_db_client.query(collection=self.collection, query=query)
|
||||
investors = [Investor(**doc.metadata) for doc in results.documents]
|
||||
return InvestorList(investors=investors)
|
||||
# Query the collection directly, not passing collection as parameter
|
||||
results = self.collection.query(
|
||||
query_texts=[query], # ChromaDB expects a list of query texts
|
||||
n_results=3, # Specify how many results you want
|
||||
)
|
||||
print(results)
|
||||
|
||||
# ChromaDB returns results in a different structure
|
||||
# results will have 'documents', 'metadatas', 'ids', 'distances'
|
||||
return results
|
||||
|
||||
def process_query(self, question: str) -> InvestorList:
|
||||
"""Process a query using the LLM and return structured investor data."""
|
||||
response = self.structured_llm.predict(question=question)
|
||||
response = self.agent.invoke(
|
||||
{"messages": [("user", question)]},
|
||||
)
|
||||
return response
|
||||
|
||||
@@ -8,13 +8,9 @@ chromadb>=0.4.0
|
||||
|
||||
# LLM integration
|
||||
openai>=1.0.0
|
||||
langchain>=0.2.0
|
||||
langchain-openai>=0.1.0
|
||||
langgraph>=0.2.0
|
||||
|
||||
# Environment management
|
||||
python-dotenv>=1.0.0
|
||||
pydantic-settings>=2.0.0
|
||||
|
||||
# Additional dependencies for data processing
|
||||
typing-extensions>=4.0.0
|
||||
|
||||
Reference in New Issue
Block a user