Refactor investor-related schemas and models; update database configuration and enhance investor processing logic
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,6 +1,6 @@
|
|||||||
from fastapi.routing import apirouter
|
from fastapi import APIRouter
|
||||||
|
|
||||||
router = apirouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@router.get("/investors")
|
@router.get("/investors")
|
||||||
def read_investors():
|
def read_investors():
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
+1
-1
@@ -9,7 +9,7 @@ from sqlalchemy.orm import Session, sessionmaker
|
|||||||
Base = declarative_base()
|
Base = declarative_base()
|
||||||
|
|
||||||
# Database configuration
|
# Database configuration
|
||||||
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///investors.db")
|
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///investors_2.db")
|
||||||
|
|
||||||
# Create engine
|
# Create engine
|
||||||
engine = create_engine(DATABASE_URL, echo=False)
|
engine = create_engine(DATABASE_URL, echo=False)
|
||||||
|
|||||||
+18
-4
@@ -33,6 +33,8 @@ investor_sector_association = Table(
|
|||||||
Column("investor_id", Integer, ForeignKey("investors.id")),
|
Column("investor_id", Integer, ForeignKey("investors.id")),
|
||||||
Column("sector_id", Integer, ForeignKey("sectors.id")),
|
Column("sector_id", Integer, ForeignKey("sectors.id")),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class InvestorTable(Base):
|
class InvestorTable(Base):
|
||||||
__tablename__ = "investors"
|
__tablename__ = "investors"
|
||||||
|
|
||||||
@@ -42,7 +44,7 @@ class InvestorTable(Base):
|
|||||||
aum = Column(Integer, nullable=False) # Assets Under Management
|
aum = Column(Integer, nullable=False) # Assets Under Management
|
||||||
check_size_lower = Column(Integer, nullable=False) # Lower bound
|
check_size_lower = Column(Integer, nullable=False) # Lower bound
|
||||||
check_size_upper = Column(Integer, nullable=False) # Upper bound
|
check_size_upper = Column(Integer, nullable=False) # Upper bound
|
||||||
geography = Column(String, nullable=False)
|
geographic_focus = Column(String, nullable=False)
|
||||||
stage_focus = Column(Enum(InvestmentStage), nullable=False)
|
stage_focus = Column(Enum(InvestmentStage), nullable=False)
|
||||||
number_of_investments = Column(Integer, default=0)
|
number_of_investments = Column(Integer, default=0)
|
||||||
created_at = Column(DateTime, default=datetime.datetime.now(datetime.UTC))
|
created_at = Column(DateTime, default=datetime.datetime.now(datetime.UTC))
|
||||||
@@ -58,6 +60,12 @@ class InvestorTable(Base):
|
|||||||
secondary=investor_company_association,
|
secondary=investor_company_association,
|
||||||
back_populates="investors",
|
back_populates="investors",
|
||||||
)
|
)
|
||||||
|
team_members = relationship("InvestorTeamMember", back_populates="investor")
|
||||||
|
sectors = relationship(
|
||||||
|
"SectorTable",
|
||||||
|
secondary=investor_sector_association,
|
||||||
|
back_populates="investors",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class CompanyTable(Base):
|
class CompanyTable(Base):
|
||||||
@@ -88,16 +96,22 @@ class SectorTable(Base):
|
|||||||
__tablename__ = "sectors"
|
__tablename__ = "sectors"
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, index=True)
|
id = Column(Integer, primary_key=True, index=True)
|
||||||
name = Column(String, unique=True, nullable=False)
|
name = Column(String, nullable=False)
|
||||||
|
|
||||||
|
|
||||||
|
# Add relationship back to investors
|
||||||
|
investors = relationship(
|
||||||
|
"InvestorTable",
|
||||||
|
secondary=investor_sector_association,
|
||||||
|
back_populates="sectors",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class InvestorTeamMember(Base):
|
class InvestorTeamMember(Base):
|
||||||
|
__tablename__ = "investor_team"
|
||||||
id = Column(Integer, primary_key=True, index=True)
|
id = Column(Integer, primary_key=True, index=True)
|
||||||
name = Column(String, nullable=False)
|
name = Column(String, nullable=False)
|
||||||
role = Column(String, nullable=False)
|
role = Column(String, nullable=False)
|
||||||
email = Column(String, unique=True, nullable=False)
|
email = Column(String, nullable=False)
|
||||||
|
|
||||||
investor_id = Column(Integer, ForeignKey("investors.id"))
|
investor_id = Column(Integer, ForeignKey("investors.id"))
|
||||||
investor = relationship("InvestorTable", back_populates="team_members")
|
investor = relationship("InvestorTable", back_populates="team_members")
|
||||||
|
|||||||
+4
-3
@@ -1,14 +1,15 @@
|
|||||||
import io
|
import io
|
||||||
from api import investors
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from db.db import db_dependency
|
from api import investors
|
||||||
|
from db.db import db_dependency, init_database
|
||||||
from fastapi import FastAPI, File, UploadFile
|
from fastapi import FastAPI, File, UploadFile
|
||||||
from services.openrouter import InvestorProcessor
|
from services.openrouter import InvestorProcessor
|
||||||
from services.querying import QueryProcessor
|
from services.querying import QueryProcessor
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
app.include_router(investors.router)
|
app.include_router(investors.router)
|
||||||
# init_database()
|
init_database()
|
||||||
|
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
|
|||||||
@@ -0,0 +1,77 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Optional
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class InvestmentStage(str, Enum):
|
||||||
|
SEED = "seed"
|
||||||
|
SERIES_A = "series_a"
|
||||||
|
SERIES_B = "series_b"
|
||||||
|
SERIES_C = "series_c"
|
||||||
|
GROWTH = "growth"
|
||||||
|
LATE_STAGE = "late_stage"
|
||||||
|
|
||||||
|
|
||||||
|
class SectorSchema(BaseModel):
|
||||||
|
id: int
|
||||||
|
name: str
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
|
class CompanySchema(BaseModel):
|
||||||
|
id: int
|
||||||
|
name: str
|
||||||
|
industry: str
|
||||||
|
location: str
|
||||||
|
founded_year: Optional[int]
|
||||||
|
website: Optional[str]
|
||||||
|
created_at: Optional[datetime]
|
||||||
|
updated_at: Optional[datetime]
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
|
class InvestorTeamMemberSchema(BaseModel):
|
||||||
|
id: int
|
||||||
|
name: str
|
||||||
|
role: str
|
||||||
|
email: str
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
|
class InvestorSchema(BaseModel):
|
||||||
|
id: int
|
||||||
|
name: str
|
||||||
|
description: Optional[str]
|
||||||
|
aum: int
|
||||||
|
check_size_lower: int
|
||||||
|
check_size_upper: int
|
||||||
|
geographic_focus: str
|
||||||
|
stage_focus: InvestmentStage
|
||||||
|
number_of_investments: int
|
||||||
|
created_at: Optional[datetime]
|
||||||
|
updated_at: Optional[datetime]
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
|
class InvestorData(BaseModel):
|
||||||
|
"""Comprehensive investor data schema for LLM processing"""
|
||||||
|
investor: InvestorSchema
|
||||||
|
portfolio_companies: List[CompanySchema] = []
|
||||||
|
team_members: List[InvestorTeamMemberSchema] = []
|
||||||
|
sectors: List[SectorSchema] = []
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
|
class InvestorList(BaseModel):
|
||||||
|
investors: List[InvestorData]
|
||||||
Binary file not shown.
+150
-35
@@ -3,15 +3,18 @@ from typing import List, Optional
|
|||||||
|
|
||||||
import chromadb
|
import chromadb
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from db.models import InvestorTable
|
from db.models import CompanyTable, InvestorTable, InvestorTeamMember, SectorTable
|
||||||
from langchain_core.prompts import PromptTemplate
|
from langchain_core.prompts import PromptTemplate
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
from pydantic_schemas import Investor, InvestorList
|
from py_schemas import InvestorData
|
||||||
|
from pydantic import BaseModel
|
||||||
from settings import settings
|
from settings import settings
|
||||||
|
|
||||||
# Add these imports for your databases
|
|
||||||
# from sqlalchemy.ext.asyncio import AsyncSession
|
class InvestorList(BaseModel):
|
||||||
# from your_vector_db import VectorDBClient
|
"""Schema for LLM structured output"""
|
||||||
|
|
||||||
|
investor_list: List[InvestorData]
|
||||||
|
|
||||||
|
|
||||||
class InvestorProcessor:
|
class InvestorProcessor:
|
||||||
@@ -25,22 +28,27 @@ class InvestorProcessor:
|
|||||||
Given the following CSV data rows:
|
Given the following CSV data rows:
|
||||||
{question}
|
{question}
|
||||||
|
|
||||||
For each row, extract and structure the following fields:
|
For each row, extract and structure the following fields for the investor:
|
||||||
- name: The investor's full name
|
- name: The investor's full name
|
||||||
|
- description: Description of the investor
|
||||||
- aum: Assets under management (as integer, use 0 if not available)
|
- aum: Assets under management (as integer, use 0 if not available)
|
||||||
- check_size: Investment check size (as string)
|
- check_size_lower: Lower bound of investment check size (as integer)
|
||||||
- sector_focus: Sector focus (as string)
|
- check_size_upper: Upper bound of investment check size (as integer)
|
||||||
- stage_focus: Investment stage focus (as string)
|
- geographic_focus: Geographic region focus
|
||||||
- region: Geographic region (as string)
|
- stage_focus: Investment stage focus (must be one of: seed, series_a, series_b, series_c, growth, late_stage)
|
||||||
- investment_thesis: Investment thesis (as string)
|
- number_of_investments: Number of investments made (default 0)
|
||||||
- investor_description: Description of the investor (as string)
|
|
||||||
|
Also extract related data:
|
||||||
|
- portfolio_companies: List of companies they've invested in
|
||||||
|
- team_members: List of team members with name, role, email
|
||||||
|
- sectors: List of sectors they focus on
|
||||||
|
|
||||||
Important:
|
Important:
|
||||||
- If a field is not available in the data, use appropriate default values (empty string for text fields, 0 for numbers)
|
- If a field is not available, use appropriate defaults
|
||||||
- Ensure all text fields are properly escaped and contain no control characters
|
- stage_focus must be one of the valid enum values
|
||||||
- Return clean, valid JSON only
|
- Return clean, valid JSON only
|
||||||
|
|
||||||
Return the data as a structured list of investors."""
|
Return the data as a structured list of comprehensive investor data."""
|
||||||
|
|
||||||
self.prompt = PromptTemplate(
|
self.prompt = PromptTemplate(
|
||||||
template=self.template, input_variables=["question"]
|
template=self.template, input_variables=["question"]
|
||||||
@@ -49,7 +57,7 @@ Return the data as a structured list of investors."""
|
|||||||
self.llm = ChatOpenAI(
|
self.llm = ChatOpenAI(
|
||||||
api_key=settings.OPENROUTER_API_KEY,
|
api_key=settings.OPENROUTER_API_KEY,
|
||||||
base_url="https://openrouter.ai/api/v1",
|
base_url="https://openrouter.ai/api/v1",
|
||||||
model="openai/gpt-oss-120b:fre",
|
model="google/gemini-2.5-flash-lite",
|
||||||
temperature=0,
|
temperature=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -65,7 +73,9 @@ Return the data as a structured list of investors."""
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _process_batch(self, batch: pd.DataFrame, batch_idx: int) -> List:
|
async def _process_batch(
|
||||||
|
self, batch: pd.DataFrame, batch_idx: int
|
||||||
|
) -> List[InvestorData]:
|
||||||
"""Process a single batch of data"""
|
"""Process a single batch of data"""
|
||||||
# Convert batch to string representation - clean the data
|
# Convert batch to string representation - clean the data
|
||||||
batch_str = ""
|
batch_str = ""
|
||||||
@@ -102,25 +112,101 @@ Return the data as a structured list of investors."""
|
|||||||
print(f"Error processing batch {batch_idx + 1}: {e}")
|
print(f"Error processing batch {batch_idx + 1}: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async def _save_to_sql(self, investors: List[Investor]) -> None:
|
async def _save_to_sql(self, investor_data_list: List[InvestorData]) -> None:
|
||||||
"""Save investors to SQL database"""
|
"""Save investors and related data to SQL database"""
|
||||||
if not self.sql_session:
|
if not self.sql_session:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Implement SQL saving logic here
|
try:
|
||||||
for investor in investors:
|
for investor_data in investor_data_list:
|
||||||
|
# Save investor
|
||||||
db_investor = InvestorTable(
|
db_investor = InvestorTable(
|
||||||
name=investor.name,
|
name=investor_data.investor.name,
|
||||||
aum=investor.aum,
|
description=investor_data.investor.description,
|
||||||
check_size=investor.check_size,
|
aum=investor_data.investor.aum,
|
||||||
sector_focus=investor.sector_focus,
|
check_size_lower=investor_data.investor.check_size_lower,
|
||||||
stage_focus=investor.stage_focus,
|
check_size_upper=investor_data.investor.check_size_upper,
|
||||||
region=investor.region,
|
geographic_focus=investor_data.investor.geographic_focus,
|
||||||
|
stage_focus=investor_data.investor.stage_focus,
|
||||||
|
number_of_investments=investor_data.investor.number_of_investments,
|
||||||
)
|
)
|
||||||
self.sql_session.add(db_investor)
|
self.sql_session.add(db_investor)
|
||||||
self.sql_session.commit()
|
self.sql_session.flush() # Get the ID
|
||||||
|
|
||||||
async def _save_to_vector_db(self, investors: List[Investor]) -> None:
|
# Save sectors and create associations
|
||||||
|
for sector_data in investor_data.sectors:
|
||||||
|
# Check if sector exists, create if not
|
||||||
|
existing_sector = (
|
||||||
|
self.sql_session.query(SectorTable)
|
||||||
|
.filter(SectorTable.name == sector_data.name)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not existing_sector:
|
||||||
|
db_sector = SectorTable(name=sector_data.name)
|
||||||
|
self.sql_session.add(db_sector)
|
||||||
|
self.sql_session.flush()
|
||||||
|
# Add sector to investor's sectors
|
||||||
|
db_investor.sectors.append(db_sector)
|
||||||
|
else:
|
||||||
|
# Add existing sector to investor if not already there
|
||||||
|
if existing_sector not in db_investor.sectors:
|
||||||
|
db_investor.sectors.append(existing_sector)
|
||||||
|
|
||||||
|
# Save companies and create portfolio associations
|
||||||
|
for company_data in investor_data.portfolio_companies:
|
||||||
|
# Check if company exists, create if not
|
||||||
|
existing_company = (
|
||||||
|
self.sql_session.query(CompanyTable)
|
||||||
|
.filter(CompanyTable.name == company_data.name)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not existing_company:
|
||||||
|
db_company = CompanyTable(
|
||||||
|
name=company_data.name,
|
||||||
|
industry=company_data.industry,
|
||||||
|
location=company_data.location,
|
||||||
|
founded_year=company_data.founded_year,
|
||||||
|
website=company_data.website,
|
||||||
|
)
|
||||||
|
self.sql_session.add(db_company)
|
||||||
|
self.sql_session.flush()
|
||||||
|
|
||||||
|
# Add to investor's portfolio
|
||||||
|
db_investor.portfolio_companies.append(db_company)
|
||||||
|
else:
|
||||||
|
# Add existing company to portfolio if not already there
|
||||||
|
if existing_company not in db_investor.portfolio_companies:
|
||||||
|
db_investor.portfolio_companies.append(existing_company)
|
||||||
|
|
||||||
|
# Save team members
|
||||||
|
for team_member_data in investor_data.team_members:
|
||||||
|
# Check if team member exists
|
||||||
|
existing_member = (
|
||||||
|
self.sql_session.query(InvestorTeamMember)
|
||||||
|
.filter(InvestorTeamMember.email == team_member_data.email)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not existing_member:
|
||||||
|
db_team_member = InvestorTeamMember(
|
||||||
|
name=team_member_data.name,
|
||||||
|
role=team_member_data.role,
|
||||||
|
email=team_member_data.email,
|
||||||
|
investor_id=db_investor.id,
|
||||||
|
)
|
||||||
|
self.sql_session.add(db_team_member)
|
||||||
|
|
||||||
|
self.sql_session.commit()
|
||||||
|
print(f"Successfully saved {len(investor_data_list)} investors to database")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.sql_session.rollback()
|
||||||
|
print(f"Error saving to SQL database: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _save_to_vector_db(self, investor_data_list: List[InvestorData]) -> None:
|
||||||
"""Save investors to vector database"""
|
"""Save investors to vector database"""
|
||||||
if not self.vector_db_client:
|
if not self.vector_db_client:
|
||||||
return
|
return
|
||||||
@@ -129,19 +215,47 @@ Return the data as a structured list of investors."""
|
|||||||
metadatas = []
|
metadatas = []
|
||||||
ids = []
|
ids = []
|
||||||
|
|
||||||
for i, investor in enumerate(investors):
|
for i, investor_data in enumerate(investor_data_list):
|
||||||
doc_text = f"{investor.investor_description}\nInvestment Thesis: {investor.investment_thesis}"
|
investor = investor_data.investor
|
||||||
|
sectors = ", ".join([s.name for s in investor_data.sectors])
|
||||||
|
companies = ", ".join([c.name for c in investor_data.portfolio_companies])
|
||||||
|
|
||||||
|
doc_text = f"""
|
||||||
|
Investor: {investor.name}
|
||||||
|
Description: {investor.description or "N/A"}
|
||||||
|
AUM: ${investor.aum:,}
|
||||||
|
Check Size: ${investor.check_size_lower:,} - ${investor.check_size_upper:,}
|
||||||
|
Geographic Focus: {investor.geographic_focus}
|
||||||
|
Stage Focus: {investor.stage_focus.value}
|
||||||
|
Sectors: {sectors}
|
||||||
|
Portfolio Companies: {companies}
|
||||||
|
""".strip()
|
||||||
|
|
||||||
documents.append(doc_text)
|
documents.append(doc_text)
|
||||||
metadatas.append({"name": investor.name})
|
metadatas.append(
|
||||||
ids.append(f"investor_{i}_{investor.name.replace(' ', '_')}")
|
{
|
||||||
|
"name": investor.name,
|
||||||
|
"stage_focus": investor.stage_focus.value,
|
||||||
|
"geographic_focus": investor.geographic_focus,
|
||||||
|
"aum": investor.aum,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
ids.append(
|
||||||
|
f"investor_{i}_{investor.name.replace(' ', '_').replace('/', '_')}"
|
||||||
|
)
|
||||||
|
|
||||||
if documents:
|
if documents:
|
||||||
# Use add method with proper parameters
|
try:
|
||||||
self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
|
self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
|
||||||
|
print(
|
||||||
|
f"Successfully saved {len(documents)} investors to vector database"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error saving to vector database: {e}")
|
||||||
|
|
||||||
async def process_csv(
|
async def process_csv(
|
||||||
self, df: pd.DataFrame, batch_size: int = 10, max_concurrent: int = 10
|
self, df: pd.DataFrame, batch_size: int = 10, max_concurrent: int = 10
|
||||||
) -> List:
|
) -> List[InvestorData]:
|
||||||
"""Process CSV data in parallel batches and save to databases"""
|
"""Process CSV data in parallel batches and save to databases"""
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
@@ -172,6 +286,7 @@ Return the data as a structured list of investors."""
|
|||||||
|
|
||||||
# Save to databases
|
# Save to databases
|
||||||
if results:
|
if results:
|
||||||
|
print(f"Successfully processed {len(results)} investors")
|
||||||
await self._save_to_sql(results)
|
await self._save_to_sql(results)
|
||||||
await self._save_to_vector_db(results)
|
await self._save_to_vector_db(results)
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from langchain_community.agent_toolkits import SQLDatabaseToolkit
|
|||||||
from langchain_community.utilities import SQLDatabase
|
from langchain_community.utilities import SQLDatabase
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
from langgraph.prebuilt import create_react_agent
|
from langgraph.prebuilt import create_react_agent
|
||||||
from pydantic_schemas import Investor, InvestorList
|
from py_schemas import InvestorList
|
||||||
from settings import settings
|
from settings import settings
|
||||||
|
|
||||||
# Connect to SQLite
|
# Connect to SQLite
|
||||||
@@ -29,7 +29,7 @@ class QueryProcessor:
|
|||||||
api_key=settings.OPENROUTER_API_KEY,
|
api_key=settings.OPENROUTER_API_KEY,
|
||||||
base_url="https://openrouter.ai/api/v1",
|
base_url="https://openrouter.ai/api/v1",
|
||||||
model="google/gemini-2.5-flash-lite",
|
model="google/gemini-2.5-flash-lite",
|
||||||
temperature=0,
|
temperature=0.3,
|
||||||
)
|
)
|
||||||
self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
|
self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
|
||||||
self.agent = create_react_agent(
|
self.agent = create_react_agent(
|
||||||
|
|||||||
Reference in New Issue
Block a user