Refactor investor-related schemas and models; update database configuration and enhance investor processing logic
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,6 +1,6 @@
|
||||
from fastapi.routing import apirouter
|
||||
from fastapi import APIRouter
|
||||
|
||||
router = apirouter()
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/investors")
|
||||
def read_investors():
|
||||
|
||||
Binary file not shown.
Binary file not shown.
+1
-1
@@ -9,7 +9,7 @@ from sqlalchemy.orm import Session, sessionmaker
|
||||
Base = declarative_base()
|
||||
|
||||
# Database configuration
|
||||
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///investors.db")
|
||||
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///investors_2.db")
|
||||
|
||||
# Create engine
|
||||
engine = create_engine(DATABASE_URL, echo=False)
|
||||
|
||||
+18
-4
@@ -33,6 +33,8 @@ investor_sector_association = Table(
|
||||
Column("investor_id", Integer, ForeignKey("investors.id")),
|
||||
Column("sector_id", Integer, ForeignKey("sectors.id")),
|
||||
)
|
||||
|
||||
|
||||
class InvestorTable(Base):
|
||||
__tablename__ = "investors"
|
||||
|
||||
@@ -42,7 +44,7 @@ class InvestorTable(Base):
|
||||
aum = Column(Integer, nullable=False) # Assets Under Management
|
||||
check_size_lower = Column(Integer, nullable=False) # Lower bound
|
||||
check_size_upper = Column(Integer, nullable=False) # Upper bound
|
||||
geography = Column(String, nullable=False)
|
||||
geographic_focus = Column(String, nullable=False)
|
||||
stage_focus = Column(Enum(InvestmentStage), nullable=False)
|
||||
number_of_investments = Column(Integer, default=0)
|
||||
created_at = Column(DateTime, default=datetime.datetime.now(datetime.UTC))
|
||||
@@ -58,6 +60,12 @@ class InvestorTable(Base):
|
||||
secondary=investor_company_association,
|
||||
back_populates="investors",
|
||||
)
|
||||
team_members = relationship("InvestorTeamMember", back_populates="investor")
|
||||
sectors = relationship(
|
||||
"SectorTable",
|
||||
secondary=investor_sector_association,
|
||||
back_populates="investors",
|
||||
)
|
||||
|
||||
|
||||
class CompanyTable(Base):
|
||||
@@ -88,16 +96,22 @@ class SectorTable(Base):
|
||||
__tablename__ = "sectors"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, unique=True, nullable=False)
|
||||
|
||||
name = Column(String, nullable=False)
|
||||
|
||||
# Add relationship back to investors
|
||||
investors = relationship(
|
||||
"InvestorTable",
|
||||
secondary=investor_sector_association,
|
||||
back_populates="sectors",
|
||||
)
|
||||
|
||||
|
||||
class InvestorTeamMember(Base):
|
||||
__tablename__ = "investor_team"
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, nullable=False)
|
||||
role = Column(String, nullable=False)
|
||||
email = Column(String, unique=True, nullable=False)
|
||||
email = Column(String, nullable=False)
|
||||
|
||||
investor_id = Column(Integer, ForeignKey("investors.id"))
|
||||
investor = relationship("InvestorTable", back_populates="team_members")
|
||||
|
||||
+4
-3
@@ -1,14 +1,15 @@
|
||||
import io
|
||||
from api import investors
|
||||
|
||||
import pandas as pd
|
||||
from db.db import db_dependency
|
||||
from api import investors
|
||||
from db.db import db_dependency, init_database
|
||||
from fastapi import FastAPI, File, UploadFile
|
||||
from services.openrouter import InvestorProcessor
|
||||
from services.querying import QueryProcessor
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(investors.router)
|
||||
# init_database()
|
||||
init_database()
|
||||
|
||||
|
||||
@app.get("/")
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class InvestmentStage(str, Enum):
|
||||
SEED = "seed"
|
||||
SERIES_A = "series_a"
|
||||
SERIES_B = "series_b"
|
||||
SERIES_C = "series_c"
|
||||
GROWTH = "growth"
|
||||
LATE_STAGE = "late_stage"
|
||||
|
||||
|
||||
class SectorSchema(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class CompanySchema(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
industry: str
|
||||
location: str
|
||||
founded_year: Optional[int]
|
||||
website: Optional[str]
|
||||
created_at: Optional[datetime]
|
||||
updated_at: Optional[datetime]
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class InvestorTeamMemberSchema(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
role: str
|
||||
email: str
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class InvestorSchema(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: Optional[str]
|
||||
aum: int
|
||||
check_size_lower: int
|
||||
check_size_upper: int
|
||||
geographic_focus: str
|
||||
stage_focus: InvestmentStage
|
||||
number_of_investments: int
|
||||
created_at: Optional[datetime]
|
||||
updated_at: Optional[datetime]
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class InvestorData(BaseModel):
|
||||
"""Comprehensive investor data schema for LLM processing"""
|
||||
investor: InvestorSchema
|
||||
portfolio_companies: List[CompanySchema] = []
|
||||
team_members: List[InvestorTeamMemberSchema] = []
|
||||
sectors: List[SectorSchema] = []
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class InvestorList(BaseModel):
|
||||
investors: List[InvestorData]
|
||||
Binary file not shown.
+154
-39
@@ -3,15 +3,18 @@ from typing import List, Optional
|
||||
|
||||
import chromadb
|
||||
import pandas as pd
|
||||
from db.models import InvestorTable
|
||||
from db.models import CompanyTable, InvestorTable, InvestorTeamMember, SectorTable
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_openai import ChatOpenAI
|
||||
from pydantic_schemas import Investor, InvestorList
|
||||
from py_schemas import InvestorData
|
||||
from pydantic import BaseModel
|
||||
from settings import settings
|
||||
|
||||
# Add these imports for your databases
|
||||
# from sqlalchemy.ext.asyncio import AsyncSession
|
||||
# from your_vector_db import VectorDBClient
|
||||
|
||||
class InvestorList(BaseModel):
|
||||
"""Schema for LLM structured output"""
|
||||
|
||||
investor_list: List[InvestorData]
|
||||
|
||||
|
||||
class InvestorProcessor:
|
||||
@@ -25,22 +28,27 @@ class InvestorProcessor:
|
||||
Given the following CSV data rows:
|
||||
{question}
|
||||
|
||||
For each row, extract and structure the following fields:
|
||||
For each row, extract and structure the following fields for the investor:
|
||||
- name: The investor's full name
|
||||
- description: Description of the investor
|
||||
- aum: Assets under management (as integer, use 0 if not available)
|
||||
- check_size: Investment check size (as string)
|
||||
- sector_focus: Sector focus (as string)
|
||||
- stage_focus: Investment stage focus (as string)
|
||||
- region: Geographic region (as string)
|
||||
- investment_thesis: Investment thesis (as string)
|
||||
- investor_description: Description of the investor (as string)
|
||||
- check_size_lower: Lower bound of investment check size (as integer)
|
||||
- check_size_upper: Upper bound of investment check size (as integer)
|
||||
- geographic_focus: Geographic region focus
|
||||
- stage_focus: Investment stage focus (must be one of: seed, series_a, series_b, series_c, growth, late_stage)
|
||||
- number_of_investments: Number of investments made (default 0)
|
||||
|
||||
Also extract related data:
|
||||
- portfolio_companies: List of companies they've invested in
|
||||
- team_members: List of team members with name, role, email
|
||||
- sectors: List of sectors they focus on
|
||||
|
||||
Important:
|
||||
- If a field is not available in the data, use appropriate default values (empty string for text fields, 0 for numbers)
|
||||
- Ensure all text fields are properly escaped and contain no control characters
|
||||
- If a field is not available, use appropriate defaults
|
||||
- stage_focus must be one of the valid enum values
|
||||
- Return clean, valid JSON only
|
||||
|
||||
Return the data as a structured list of investors."""
|
||||
Return the data as a structured list of comprehensive investor data."""
|
||||
|
||||
self.prompt = PromptTemplate(
|
||||
template=self.template, input_variables=["question"]
|
||||
@@ -49,7 +57,7 @@ Return the data as a structured list of investors."""
|
||||
self.llm = ChatOpenAI(
|
||||
api_key=settings.OPENROUTER_API_KEY,
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="openai/gpt-oss-120b:fre",
|
||||
model="google/gemini-2.5-flash-lite",
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
@@ -65,7 +73,9 @@ Return the data as a structured list of investors."""
|
||||
},
|
||||
)
|
||||
|
||||
async def _process_batch(self, batch: pd.DataFrame, batch_idx: int) -> List:
|
||||
async def _process_batch(
|
||||
self, batch: pd.DataFrame, batch_idx: int
|
||||
) -> List[InvestorData]:
|
||||
"""Process a single batch of data"""
|
||||
# Convert batch to string representation - clean the data
|
||||
batch_str = ""
|
||||
@@ -102,25 +112,101 @@ Return the data as a structured list of investors."""
|
||||
print(f"Error processing batch {batch_idx + 1}: {e}")
|
||||
return []
|
||||
|
||||
async def _save_to_sql(self, investors: List[Investor]) -> None:
|
||||
"""Save investors to SQL database"""
|
||||
async def _save_to_sql(self, investor_data_list: List[InvestorData]) -> None:
|
||||
"""Save investors and related data to SQL database"""
|
||||
if not self.sql_session:
|
||||
return
|
||||
|
||||
# Implement SQL saving logic here
|
||||
for investor in investors:
|
||||
db_investor = InvestorTable(
|
||||
name=investor.name,
|
||||
aum=investor.aum,
|
||||
check_size=investor.check_size,
|
||||
sector_focus=investor.sector_focus,
|
||||
stage_focus=investor.stage_focus,
|
||||
region=investor.region,
|
||||
)
|
||||
self.sql_session.add(db_investor)
|
||||
self.sql_session.commit()
|
||||
try:
|
||||
for investor_data in investor_data_list:
|
||||
# Save investor
|
||||
db_investor = InvestorTable(
|
||||
name=investor_data.investor.name,
|
||||
description=investor_data.investor.description,
|
||||
aum=investor_data.investor.aum,
|
||||
check_size_lower=investor_data.investor.check_size_lower,
|
||||
check_size_upper=investor_data.investor.check_size_upper,
|
||||
geographic_focus=investor_data.investor.geographic_focus,
|
||||
stage_focus=investor_data.investor.stage_focus,
|
||||
number_of_investments=investor_data.investor.number_of_investments,
|
||||
)
|
||||
self.sql_session.add(db_investor)
|
||||
self.sql_session.flush() # Get the ID
|
||||
|
||||
async def _save_to_vector_db(self, investors: List[Investor]) -> None:
|
||||
# Save sectors and create associations
|
||||
for sector_data in investor_data.sectors:
|
||||
# Check if sector exists, create if not
|
||||
existing_sector = (
|
||||
self.sql_session.query(SectorTable)
|
||||
.filter(SectorTable.name == sector_data.name)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not existing_sector:
|
||||
db_sector = SectorTable(name=sector_data.name)
|
||||
self.sql_session.add(db_sector)
|
||||
self.sql_session.flush()
|
||||
# Add sector to investor's sectors
|
||||
db_investor.sectors.append(db_sector)
|
||||
else:
|
||||
# Add existing sector to investor if not already there
|
||||
if existing_sector not in db_investor.sectors:
|
||||
db_investor.sectors.append(existing_sector)
|
||||
|
||||
# Save companies and create portfolio associations
|
||||
for company_data in investor_data.portfolio_companies:
|
||||
# Check if company exists, create if not
|
||||
existing_company = (
|
||||
self.sql_session.query(CompanyTable)
|
||||
.filter(CompanyTable.name == company_data.name)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not existing_company:
|
||||
db_company = CompanyTable(
|
||||
name=company_data.name,
|
||||
industry=company_data.industry,
|
||||
location=company_data.location,
|
||||
founded_year=company_data.founded_year,
|
||||
website=company_data.website,
|
||||
)
|
||||
self.sql_session.add(db_company)
|
||||
self.sql_session.flush()
|
||||
|
||||
# Add to investor's portfolio
|
||||
db_investor.portfolio_companies.append(db_company)
|
||||
else:
|
||||
# Add existing company to portfolio if not already there
|
||||
if existing_company not in db_investor.portfolio_companies:
|
||||
db_investor.portfolio_companies.append(existing_company)
|
||||
|
||||
# Save team members
|
||||
for team_member_data in investor_data.team_members:
|
||||
# Check if team member exists
|
||||
existing_member = (
|
||||
self.sql_session.query(InvestorTeamMember)
|
||||
.filter(InvestorTeamMember.email == team_member_data.email)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not existing_member:
|
||||
db_team_member = InvestorTeamMember(
|
||||
name=team_member_data.name,
|
||||
role=team_member_data.role,
|
||||
email=team_member_data.email,
|
||||
investor_id=db_investor.id,
|
||||
)
|
||||
self.sql_session.add(db_team_member)
|
||||
|
||||
self.sql_session.commit()
|
||||
print(f"Successfully saved {len(investor_data_list)} investors to database")
|
||||
|
||||
except Exception as e:
|
||||
self.sql_session.rollback()
|
||||
print(f"Error saving to SQL database: {e}")
|
||||
raise
|
||||
|
||||
async def _save_to_vector_db(self, investor_data_list: List[InvestorData]) -> None:
|
||||
"""Save investors to vector database"""
|
||||
if not self.vector_db_client:
|
||||
return
|
||||
@@ -129,19 +215,47 @@ Return the data as a structured list of investors."""
|
||||
metadatas = []
|
||||
ids = []
|
||||
|
||||
for i, investor in enumerate(investors):
|
||||
doc_text = f"{investor.investor_description}\nInvestment Thesis: {investor.investment_thesis}"
|
||||
for i, investor_data in enumerate(investor_data_list):
|
||||
investor = investor_data.investor
|
||||
sectors = ", ".join([s.name for s in investor_data.sectors])
|
||||
companies = ", ".join([c.name for c in investor_data.portfolio_companies])
|
||||
|
||||
doc_text = f"""
|
||||
Investor: {investor.name}
|
||||
Description: {investor.description or "N/A"}
|
||||
AUM: ${investor.aum:,}
|
||||
Check Size: ${investor.check_size_lower:,} - ${investor.check_size_upper:,}
|
||||
Geographic Focus: {investor.geographic_focus}
|
||||
Stage Focus: {investor.stage_focus.value}
|
||||
Sectors: {sectors}
|
||||
Portfolio Companies: {companies}
|
||||
""".strip()
|
||||
|
||||
documents.append(doc_text)
|
||||
metadatas.append({"name": investor.name})
|
||||
ids.append(f"investor_{i}_{investor.name.replace(' ', '_')}")
|
||||
metadatas.append(
|
||||
{
|
||||
"name": investor.name,
|
||||
"stage_focus": investor.stage_focus.value,
|
||||
"geographic_focus": investor.geographic_focus,
|
||||
"aum": investor.aum,
|
||||
}
|
||||
)
|
||||
ids.append(
|
||||
f"investor_{i}_{investor.name.replace(' ', '_').replace('/', '_')}"
|
||||
)
|
||||
|
||||
if documents:
|
||||
# Use add method with proper parameters
|
||||
self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
|
||||
try:
|
||||
self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
|
||||
print(
|
||||
f"Successfully saved {len(documents)} investors to vector database"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error saving to vector database: {e}")
|
||||
|
||||
async def process_csv(
|
||||
self, df: pd.DataFrame, batch_size: int = 10, max_concurrent: int = 10
|
||||
) -> List:
|
||||
) -> List[InvestorData]:
|
||||
"""Process CSV data in parallel batches and save to databases"""
|
||||
results = []
|
||||
|
||||
@@ -172,6 +286,7 @@ Return the data as a structured list of investors."""
|
||||
|
||||
# Save to databases
|
||||
if results:
|
||||
print(f"Successfully processed {len(results)} investors")
|
||||
await self._save_to_sql(results)
|
||||
await self._save_to_vector_db(results)
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from langchain_community.agent_toolkits import SQLDatabaseToolkit
|
||||
from langchain_community.utilities import SQLDatabase
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
from pydantic_schemas import Investor, InvestorList
|
||||
from py_schemas import InvestorList
|
||||
from settings import settings
|
||||
|
||||
# Connect to SQLite
|
||||
@@ -29,7 +29,7 @@ class QueryProcessor:
|
||||
api_key=settings.OPENROUTER_API_KEY,
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="google/gemini-2.5-flash-lite",
|
||||
temperature=0,
|
||||
temperature=0.3,
|
||||
)
|
||||
self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
|
||||
self.agent = create_react_agent(
|
||||
|
||||
Reference in New Issue
Block a user