feat: Enhance data models and sorting logic for investors and projects

This commit is contained in:
bolade
2025-11-11 13:09:30 +01:00
parent 0e4763bf4f
commit 5e83734acf
6 changed files with 129 additions and 35 deletions
+4
View File
@@ -162,6 +162,7 @@ class InvestorMember(Base, TimestampMixin):
role = Column(String, nullable=True) role = Column(String, nullable=True)
title = Column(String, nullable=True) # Alternative to role title = Column(String, nullable=True) # Alternative to role
email = Column(String, nullable=True) email = Column(String, nullable=True)
linkedin = Column(String, nullable=True) # LinkedIn profile URL
source_url = Column(String, nullable=True) # URL where member info was found source_url = Column(String, nullable=True) # URL where member info was found
investor_id = Column(Integer, ForeignKey("investors.id")) investor_id = Column(Integer, ForeignKey("investors.id"))
@@ -215,6 +216,8 @@ class CompanyTable(Base, TimestampMixin):
description = Column(String, nullable=True) description = Column(String, nullable=True)
founded_year = Column(Integer, nullable=True) founded_year = Column(Integer, nullable=True)
website = Column(String, nullable=True) website = Column(String, nullable=True)
product_service = Column(Text, nullable=True) # Product/service description
clients = Column(JSON, nullable=True) # List of client names or client information
members = relationship( members = relationship(
"CompanyMember", back_populates="company", cascade="all, delete-orphan" "CompanyMember", back_populates="company", cascade="all, delete-orphan"
@@ -300,6 +303,7 @@ class ProjectTable(Base, TimestampMixin):
description = Column(Text, nullable=True) description = Column(Text, nullable=True)
start_date = Column(DateTime, nullable=True) start_date = Column(DateTime, nullable=True)
end_date = Column(DateTime, nullable=True) end_date = Column(DateTime, nullable=True)
is_archived = Column(Integer, default=0, nullable=False) # 0 = active, 1 = archived
sector = relationship( sector = relationship(
"SectorTable", secondary=project_sector_association, back_populates="projects" "SectorTable", secondary=project_sector_association, back_populates="projects"
+14 -4
View File
@@ -63,11 +63,13 @@ def read_companies(
# Transform CompanyTable objects to CompanyData format # Transform CompanyTable objects to CompanyData format
company_data_list = [] company_data_list = []
for company in companies: for company in companies:
# Sort sectors alphabetically
sorted_sectors = sorted(company.sectors, key=lambda s: s.name) if company.sectors else []
company_data = CompanyData( company_data = CompanyData(
company=company, company=company,
investors=company.investors, investors=company.investors,
members=company.members, members=company.members,
sectors=company.sectors, sectors=sorted_sectors,
) )
company_data_list.append(company_data) company_data_list.append(company_data)
@@ -147,11 +149,13 @@ def filter_companies(
# Transform to CompanyData format # Transform to CompanyData format
company_data_list = [] company_data_list = []
for company in companies: for company in companies:
# Sort sectors alphabetically
sorted_sectors = sorted(company.sectors, key=lambda s: s.name) if company.sectors else []
company_data = CompanyData( company_data = CompanyData(
company=company, company=company,
investors=company.investors, investors=company.investors,
members=company.members, members=company.members,
sectors=company.sectors, sectors=sorted_sectors,
) )
company_data_list.append(company_data) company_data_list.append(company_data)
@@ -184,12 +188,15 @@ def read_company(company_id: int, db: Session = Depends(get_db)):
if not company: if not company:
raise HTTPException(status_code=404, detail="Company not found") raise HTTPException(status_code=404, detail="Company not found")
# Sort sectors alphabetically
sorted_sectors = sorted(company.sectors, key=lambda s: s.name) if company.sectors else []
# Transform to CompanyData format # Transform to CompanyData format
return CompanyData( return CompanyData(
company=company, company=company,
investors=company.investors, investors=company.investors,
members=company.members, members=company.members,
sectors=company.sectors, sectors=sorted_sectors,
) )
@@ -250,12 +257,15 @@ def update_company(
.first() .first()
) )
# Sort sectors alphabetically
sorted_sectors = sorted(company_with_relations.sectors, key=lambda s: s.name) if company_with_relations.sectors else []
# Transform to CompanyData format # Transform to CompanyData format
return CompanyData( return CompanyData(
company=company_with_relations, company=company_with_relations,
investors=company_with_relations.investors, investors=company_with_relations.investors,
members=company_with_relations.members, members=company_with_relations.members,
sectors=company_with_relations.sectors, sectors=sorted_sectors,
) )
+58 -20
View File
@@ -81,20 +81,38 @@ def read_investors(
if not project: if not project:
raise HTTPException(status_code=404, detail="Project not found") raise HTTPException(status_code=404, detail="Project not found")
# Get paginated results # When project_id is provided, we need to get all investors first to sort by compatibility score
investors = ( # Otherwise, we can paginate at the database level
db.query(InvestorTable) if project is not None:
.options( # Get all investors (we'll sort by compatibility score, then paginate)
selectinload(InvestorTable.portfolio_companies), all_investors = (
selectinload(InvestorTable.team_members), db.query(InvestorTable)
selectinload(InvestorTable.sectors), .options(
selectinload(InvestorTable.funds).selectinload(FundTable.investment_stages), selectinload(InvestorTable.portfolio_companies),
selectinload(InvestorTable.funds).selectinload(FundTable.sectors), selectinload(InvestorTable.team_members),
selectinload(InvestorTable.sectors),
selectinload(InvestorTable.funds).selectinload(FundTable.investment_stages),
selectinload(InvestorTable.funds).selectinload(FundTable.sectors),
)
.all()
)
# We'll paginate after sorting by compatibility score
investors = all_investors
else:
# Get paginated results (no sorting needed)
investors = (
db.query(InvestorTable)
.options(
selectinload(InvestorTable.portfolio_companies),
selectinload(InvestorTable.team_members),
selectinload(InvestorTable.sectors),
selectinload(InvestorTable.funds).selectinload(FundTable.investment_stages),
selectinload(InvestorTable.funds).selectinload(FundTable.sectors),
)
.offset(offset)
.limit(page_size)
.all()
) )
.offset(offset)
.limit(page_size)
.all()
)
# Transform to InvestmentResponse format (one row per investor-fund combination) # Transform to InvestmentResponse format (one row per investor-fund combination)
investment_responses = [] investment_responses = []
@@ -122,10 +140,10 @@ def read_investors(
else None else None
) )
# Get top 3 sectors from fund (id and name only) # Get top 3 sectors from fund (id and name only) - sorted alphabetically
fund_sectors = [ fund_sectors = [
SectorMinimal(id=sector.id, name=sector.name) SectorMinimal(id=sector.id, name=sector.name)
for sector in (fund.sectors[:3] if fund.sectors else []) for sector in sorted(fund.sectors[:3] if fund.sectors else [], key=lambda s: s.name)
] ]
investment_response = InvestmentResponse( investment_response = InvestmentResponse(
@@ -166,6 +184,12 @@ def read_investors(
) )
investment_responses.append(investment_response) investment_responses.append(investment_response)
# Sort by compatibility score (descending) when project_id is provided
if project is not None:
investment_responses.sort(key=lambda x: x.compatibility_score, reverse=True)
# Apply pagination after sorting
investment_responses = investment_responses[offset:offset + page_size]
# Calculate total pages # Calculate total pages
total_pages = (total_count + page_size - 1) // page_size total_pages = (total_count + page_size - 1) // page_size
@@ -257,9 +281,16 @@ def filter_investors(
# Get total count before pagination # Get total count before pagination
total_count = query.count() total_count = query.count()
# Calculate offset and apply pagination # When project_id is provided, we need to get all funds first to sort by compatibility score
offset = (page - 1) * page_size # Otherwise, we can paginate at the database level
funds = query.offset(offset).limit(page_size).all() if project is not None:
# Get all funds (we'll sort by compatibility score, then paginate)
all_funds = query.all()
funds = all_funds
else:
# Calculate offset and apply pagination (no sorting needed)
offset = (page - 1) * page_size
funds = query.offset(offset).limit(page_size).all()
# Transform to InvestmentResponse format (one row per fund) # Transform to InvestmentResponse format (one row per fund)
investment_responses = [] investment_responses = []
@@ -286,10 +317,10 @@ def filter_investors(
else None else None
) )
# Get top 3 sectors from fund (id and name only) # Get top 3 sectors from fund (id and name only) - sorted alphabetically
fund_sectors = [ fund_sectors = [
SectorMinimal(id=sector.id, name=sector.name) SectorMinimal(id=sector.id, name=sector.name)
for sector in (fund.sectors[:3] if fund.sectors else []) for sector in sorted(fund.sectors[:3] if fund.sectors else [], key=lambda s: s.name)
] ]
investment_response = InvestmentResponse( investment_response = InvestmentResponse(
@@ -308,6 +339,13 @@ def filter_investors(
) )
investment_responses.append(investment_response) investment_responses.append(investment_response)
# Sort by compatibility score (descending) when project_id is provided
if project is not None:
investment_responses.sort(key=lambda x: x.compatibility_score, reverse=True)
# Apply pagination after sorting
offset = (page - 1) * page_size
investment_responses = investment_responses[offset:offset + page_size]
# Calculate total pages # Calculate total pages
total_pages = (total_count + page_size - 1) // page_size total_pages = (total_count + page_size - 1) // page_size
+45 -5
View File
@@ -24,19 +24,29 @@ router = APIRouter(tags=["Project Routes"])
def read_projects( def read_projects(
page: int = Query(1, ge=1, description="Page number (starts at 1)"), page: int = Query(1, ge=1, description="Page number (starts at 1)"),
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"), page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
include_archived: bool = Query(False, description="Include archived projects"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
"""Get all projects with their related data (paginated)""" """Get all projects with their related data (paginated)
By default, archived projects are excluded. Set include_archived=True to include them.
"""
# Calculate offset # Calculate offset
offset = (page - 1) * page_size offset = (page - 1) * page_size
# Start with base query
query = db.query(ProjectTable)
# Filter out archived projects by default
if not include_archived:
query = query.filter(ProjectTable.is_archived == 0)
# Get total count # Get total count
total_count = db.query(ProjectTable).count() total_count = query.count()
# Get paginated results # Get paginated results
projects = ( projects = (
db.query(ProjectTable) query.options(
.options(
selectinload(ProjectTable.sector), selectinload(ProjectTable.sector),
selectinload(ProjectTable.investors), selectinload(ProjectTable.investors),
selectinload(ProjectTable.companies), selectinload(ProjectTable.companies),
@@ -162,7 +172,7 @@ def update_project(
@router.delete("/projects/{project_id}") @router.delete("/projects/{project_id}")
def delete_project(project_id: int, db: Session = Depends(get_db)): def delete_project(project_id: int, db: Session = Depends(get_db)):
"""Delete a project""" """Delete a project permanently"""
db_project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first() db_project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
if not db_project: if not db_project:
@@ -174,6 +184,36 @@ def delete_project(project_id: int, db: Session = Depends(get_db)):
return {"message": "Project deleted successfully"} return {"message": "Project deleted successfully"}
@router.post("/projects/{project_id}/archive")
def archive_project(project_id: int, db: Session = Depends(get_db)):
"""Archive a project (soft delete)"""
db_project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
if not db_project:
raise HTTPException(status_code=404, detail="Project not found")
db_project.is_archived = 1
db.commit()
db.refresh(db_project)
return {"message": "Project archived successfully", "project_id": project_id}
@router.post("/projects/{project_id}/unarchive")
def unarchive_project(project_id: int, db: Session = Depends(get_db)):
"""Unarchive a project (restore from archive)"""
db_project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
if not db_project:
raise HTTPException(status_code=404, detail="Project not found")
db_project.is_archived = 0
db.commit()
db.refresh(db_project)
return {"message": "Project unarchived successfully", "project_id": project_id}
@router.get("/projects/filter", response_model=PaginatedResponse[ProjectData]) @router.get("/projects/filter", response_model=PaginatedResponse[ProjectData])
def filter_projects( def filter_projects(
stage: Optional[InvestmentStage] = Query( stage: Optional[InvestmentStage] = Query(
+6 -4
View File
@@ -174,9 +174,10 @@ def _calculate_project_fund_compatibility(
or fund_geo_lower in project_location_lower or fund_geo_lower in project_location_lower
): ):
geo_score = 15 geo_score = 15
# Check for common geographic terms or regional overlap # Check for common geographic terms or regional overlap (continent/country matching)
elif _check_geographic_overlap(project_location_lower, fund_geo_lower): elif _check_geographic_overlap(project_location_lower, fund_geo_lower):
geo_score = 12 # Give higher score for continent/country matches (e.g., Germany -> Europe)
geo_score = 18
total_score += geo_score total_score += geo_score
@@ -298,9 +299,10 @@ def _calculate_project_investor_direct_compatibility(
project_location_lower in investor_geo_lower project_location_lower in investor_geo_lower
or investor_geo_lower in project_location_lower or investor_geo_lower in project_location_lower
): ):
geo_score = 10 geo_score = 15
elif _check_geographic_overlap(project_location_lower, investor_geo_lower): elif _check_geographic_overlap(project_location_lower, investor_geo_lower):
geo_score = 5 # Give higher score for continent/country matches (e.g., Germany -> Europe)
geo_score = 18
total_score += geo_score total_score += geo_score
+2 -2
View File
@@ -258,10 +258,10 @@ Return ONLY the SQL query, no explanations or markdown.""",
else None else None
) )
# Get top 3 sectors from fund (id and name only) # Get top 3 sectors from fund (id and name only) - sorted alphabetically
fund_sectors = [ fund_sectors = [
SectorMinimal(id=sector.id, name=sector.name) SectorMinimal(id=sector.id, name=sector.name)
for sector in (fund.sectors[:3] if fund.sectors else []) for sector in sorted(fund.sectors[:3] if fund.sectors else [], key=lambda s: s.name)
] ]
investment_response = InvestmentResponse( investment_response = InvestmentResponse(