feat: Enhance data models and sorting logic for investors and projects
This commit is contained in:
@@ -162,6 +162,7 @@ class InvestorMember(Base, TimestampMixin):
|
|||||||
role = Column(String, nullable=True)
|
role = Column(String, nullable=True)
|
||||||
title = Column(String, nullable=True) # Alternative to role
|
title = Column(String, nullable=True) # Alternative to role
|
||||||
email = Column(String, nullable=True)
|
email = Column(String, nullable=True)
|
||||||
|
linkedin = Column(String, nullable=True) # LinkedIn profile URL
|
||||||
source_url = Column(String, nullable=True) # URL where member info was found
|
source_url = Column(String, nullable=True) # URL where member info was found
|
||||||
|
|
||||||
investor_id = Column(Integer, ForeignKey("investors.id"))
|
investor_id = Column(Integer, ForeignKey("investors.id"))
|
||||||
@@ -215,6 +216,8 @@ class CompanyTable(Base, TimestampMixin):
|
|||||||
description = Column(String, nullable=True)
|
description = Column(String, nullable=True)
|
||||||
founded_year = Column(Integer, nullable=True)
|
founded_year = Column(Integer, nullable=True)
|
||||||
website = Column(String, nullable=True)
|
website = Column(String, nullable=True)
|
||||||
|
product_service = Column(Text, nullable=True) # Product/service description
|
||||||
|
clients = Column(JSON, nullable=True) # List of client names or client information
|
||||||
|
|
||||||
members = relationship(
|
members = relationship(
|
||||||
"CompanyMember", back_populates="company", cascade="all, delete-orphan"
|
"CompanyMember", back_populates="company", cascade="all, delete-orphan"
|
||||||
@@ -300,6 +303,7 @@ class ProjectTable(Base, TimestampMixin):
|
|||||||
description = Column(Text, nullable=True)
|
description = Column(Text, nullable=True)
|
||||||
start_date = Column(DateTime, nullable=True)
|
start_date = Column(DateTime, nullable=True)
|
||||||
end_date = Column(DateTime, nullable=True)
|
end_date = Column(DateTime, nullable=True)
|
||||||
|
is_archived = Column(Integer, default=0, nullable=False) # 0 = active, 1 = archived
|
||||||
|
|
||||||
sector = relationship(
|
sector = relationship(
|
||||||
"SectorTable", secondary=project_sector_association, back_populates="projects"
|
"SectorTable", secondary=project_sector_association, back_populates="projects"
|
||||||
|
|||||||
@@ -63,11 +63,13 @@ def read_companies(
|
|||||||
# Transform CompanyTable objects to CompanyData format
|
# Transform CompanyTable objects to CompanyData format
|
||||||
company_data_list = []
|
company_data_list = []
|
||||||
for company in companies:
|
for company in companies:
|
||||||
|
# Sort sectors alphabetically
|
||||||
|
sorted_sectors = sorted(company.sectors, key=lambda s: s.name) if company.sectors else []
|
||||||
company_data = CompanyData(
|
company_data = CompanyData(
|
||||||
company=company,
|
company=company,
|
||||||
investors=company.investors,
|
investors=company.investors,
|
||||||
members=company.members,
|
members=company.members,
|
||||||
sectors=company.sectors,
|
sectors=sorted_sectors,
|
||||||
)
|
)
|
||||||
company_data_list.append(company_data)
|
company_data_list.append(company_data)
|
||||||
|
|
||||||
@@ -147,11 +149,13 @@ def filter_companies(
|
|||||||
# Transform to CompanyData format
|
# Transform to CompanyData format
|
||||||
company_data_list = []
|
company_data_list = []
|
||||||
for company in companies:
|
for company in companies:
|
||||||
|
# Sort sectors alphabetically
|
||||||
|
sorted_sectors = sorted(company.sectors, key=lambda s: s.name) if company.sectors else []
|
||||||
company_data = CompanyData(
|
company_data = CompanyData(
|
||||||
company=company,
|
company=company,
|
||||||
investors=company.investors,
|
investors=company.investors,
|
||||||
members=company.members,
|
members=company.members,
|
||||||
sectors=company.sectors,
|
sectors=sorted_sectors,
|
||||||
)
|
)
|
||||||
company_data_list.append(company_data)
|
company_data_list.append(company_data)
|
||||||
|
|
||||||
@@ -184,12 +188,15 @@ def read_company(company_id: int, db: Session = Depends(get_db)):
|
|||||||
if not company:
|
if not company:
|
||||||
raise HTTPException(status_code=404, detail="Company not found")
|
raise HTTPException(status_code=404, detail="Company not found")
|
||||||
|
|
||||||
|
# Sort sectors alphabetically
|
||||||
|
sorted_sectors = sorted(company.sectors, key=lambda s: s.name) if company.sectors else []
|
||||||
|
|
||||||
# Transform to CompanyData format
|
# Transform to CompanyData format
|
||||||
return CompanyData(
|
return CompanyData(
|
||||||
company=company,
|
company=company,
|
||||||
investors=company.investors,
|
investors=company.investors,
|
||||||
members=company.members,
|
members=company.members,
|
||||||
sectors=company.sectors,
|
sectors=sorted_sectors,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -250,12 +257,15 @@ def update_company(
|
|||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Sort sectors alphabetically
|
||||||
|
sorted_sectors = sorted(company_with_relations.sectors, key=lambda s: s.name) if company_with_relations.sectors else []
|
||||||
|
|
||||||
# Transform to CompanyData format
|
# Transform to CompanyData format
|
||||||
return CompanyData(
|
return CompanyData(
|
||||||
company=company_with_relations,
|
company=company_with_relations,
|
||||||
investors=company_with_relations.investors,
|
investors=company_with_relations.investors,
|
||||||
members=company_with_relations.members,
|
members=company_with_relations.members,
|
||||||
sectors=company_with_relations.sectors,
|
sectors=sorted_sectors,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -81,7 +81,25 @@ def read_investors(
|
|||||||
if not project:
|
if not project:
|
||||||
raise HTTPException(status_code=404, detail="Project not found")
|
raise HTTPException(status_code=404, detail="Project not found")
|
||||||
|
|
||||||
# Get paginated results
|
# When project_id is provided, we need to get all investors first to sort by compatibility score
|
||||||
|
# Otherwise, we can paginate at the database level
|
||||||
|
if project is not None:
|
||||||
|
# Get all investors (we'll sort by compatibility score, then paginate)
|
||||||
|
all_investors = (
|
||||||
|
db.query(InvestorTable)
|
||||||
|
.options(
|
||||||
|
selectinload(InvestorTable.portfolio_companies),
|
||||||
|
selectinload(InvestorTable.team_members),
|
||||||
|
selectinload(InvestorTable.sectors),
|
||||||
|
selectinload(InvestorTable.funds).selectinload(FundTable.investment_stages),
|
||||||
|
selectinload(InvestorTable.funds).selectinload(FundTable.sectors),
|
||||||
|
)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
# We'll paginate after sorting by compatibility score
|
||||||
|
investors = all_investors
|
||||||
|
else:
|
||||||
|
# Get paginated results (no sorting needed)
|
||||||
investors = (
|
investors = (
|
||||||
db.query(InvestorTable)
|
db.query(InvestorTable)
|
||||||
.options(
|
.options(
|
||||||
@@ -122,10 +140,10 @@ def read_investors(
|
|||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get top 3 sectors from fund (id and name only)
|
# Get top 3 sectors from fund (id and name only) - sorted alphabetically
|
||||||
fund_sectors = [
|
fund_sectors = [
|
||||||
SectorMinimal(id=sector.id, name=sector.name)
|
SectorMinimal(id=sector.id, name=sector.name)
|
||||||
for sector in (fund.sectors[:3] if fund.sectors else [])
|
for sector in sorted(fund.sectors[:3] if fund.sectors else [], key=lambda s: s.name)
|
||||||
]
|
]
|
||||||
|
|
||||||
investment_response = InvestmentResponse(
|
investment_response = InvestmentResponse(
|
||||||
@@ -166,6 +184,12 @@ def read_investors(
|
|||||||
)
|
)
|
||||||
investment_responses.append(investment_response)
|
investment_responses.append(investment_response)
|
||||||
|
|
||||||
|
# Sort by compatibility score (descending) when project_id is provided
|
||||||
|
if project is not None:
|
||||||
|
investment_responses.sort(key=lambda x: x.compatibility_score, reverse=True)
|
||||||
|
# Apply pagination after sorting
|
||||||
|
investment_responses = investment_responses[offset:offset + page_size]
|
||||||
|
|
||||||
# Calculate total pages
|
# Calculate total pages
|
||||||
total_pages = (total_count + page_size - 1) // page_size
|
total_pages = (total_count + page_size - 1) // page_size
|
||||||
|
|
||||||
@@ -257,7 +281,14 @@ def filter_investors(
|
|||||||
# Get total count before pagination
|
# Get total count before pagination
|
||||||
total_count = query.count()
|
total_count = query.count()
|
||||||
|
|
||||||
# Calculate offset and apply pagination
|
# When project_id is provided, we need to get all funds first to sort by compatibility score
|
||||||
|
# Otherwise, we can paginate at the database level
|
||||||
|
if project is not None:
|
||||||
|
# Get all funds (we'll sort by compatibility score, then paginate)
|
||||||
|
all_funds = query.all()
|
||||||
|
funds = all_funds
|
||||||
|
else:
|
||||||
|
# Calculate offset and apply pagination (no sorting needed)
|
||||||
offset = (page - 1) * page_size
|
offset = (page - 1) * page_size
|
||||||
funds = query.offset(offset).limit(page_size).all()
|
funds = query.offset(offset).limit(page_size).all()
|
||||||
|
|
||||||
@@ -286,10 +317,10 @@ def filter_investors(
|
|||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get top 3 sectors from fund (id and name only)
|
# Get top 3 sectors from fund (id and name only) - sorted alphabetically
|
||||||
fund_sectors = [
|
fund_sectors = [
|
||||||
SectorMinimal(id=sector.id, name=sector.name)
|
SectorMinimal(id=sector.id, name=sector.name)
|
||||||
for sector in (fund.sectors[:3] if fund.sectors else [])
|
for sector in sorted(fund.sectors[:3] if fund.sectors else [], key=lambda s: s.name)
|
||||||
]
|
]
|
||||||
|
|
||||||
investment_response = InvestmentResponse(
|
investment_response = InvestmentResponse(
|
||||||
@@ -308,6 +339,13 @@ def filter_investors(
|
|||||||
)
|
)
|
||||||
investment_responses.append(investment_response)
|
investment_responses.append(investment_response)
|
||||||
|
|
||||||
|
# Sort by compatibility score (descending) when project_id is provided
|
||||||
|
if project is not None:
|
||||||
|
investment_responses.sort(key=lambda x: x.compatibility_score, reverse=True)
|
||||||
|
# Apply pagination after sorting
|
||||||
|
offset = (page - 1) * page_size
|
||||||
|
investment_responses = investment_responses[offset:offset + page_size]
|
||||||
|
|
||||||
# Calculate total pages
|
# Calculate total pages
|
||||||
total_pages = (total_count + page_size - 1) // page_size
|
total_pages = (total_count + page_size - 1) // page_size
|
||||||
|
|
||||||
|
|||||||
+45
-5
@@ -24,19 +24,29 @@ router = APIRouter(tags=["Project Routes"])
|
|||||||
def read_projects(
|
def read_projects(
|
||||||
page: int = Query(1, ge=1, description="Page number (starts at 1)"),
|
page: int = Query(1, ge=1, description="Page number (starts at 1)"),
|
||||||
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
|
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
|
||||||
|
include_archived: bool = Query(False, description="Include archived projects"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""Get all projects with their related data (paginated)"""
|
"""Get all projects with their related data (paginated)
|
||||||
|
|
||||||
|
By default, archived projects are excluded. Set include_archived=True to include them.
|
||||||
|
"""
|
||||||
# Calculate offset
|
# Calculate offset
|
||||||
offset = (page - 1) * page_size
|
offset = (page - 1) * page_size
|
||||||
|
|
||||||
|
# Start with base query
|
||||||
|
query = db.query(ProjectTable)
|
||||||
|
|
||||||
|
# Filter out archived projects by default
|
||||||
|
if not include_archived:
|
||||||
|
query = query.filter(ProjectTable.is_archived == 0)
|
||||||
|
|
||||||
# Get total count
|
# Get total count
|
||||||
total_count = db.query(ProjectTable).count()
|
total_count = query.count()
|
||||||
|
|
||||||
# Get paginated results
|
# Get paginated results
|
||||||
projects = (
|
projects = (
|
||||||
db.query(ProjectTable)
|
query.options(
|
||||||
.options(
|
|
||||||
selectinload(ProjectTable.sector),
|
selectinload(ProjectTable.sector),
|
||||||
selectinload(ProjectTable.investors),
|
selectinload(ProjectTable.investors),
|
||||||
selectinload(ProjectTable.companies),
|
selectinload(ProjectTable.companies),
|
||||||
@@ -162,7 +172,7 @@ def update_project(
|
|||||||
|
|
||||||
@router.delete("/projects/{project_id}")
|
@router.delete("/projects/{project_id}")
|
||||||
def delete_project(project_id: int, db: Session = Depends(get_db)):
|
def delete_project(project_id: int, db: Session = Depends(get_db)):
|
||||||
"""Delete a project"""
|
"""Delete a project permanently"""
|
||||||
db_project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
|
db_project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
|
||||||
|
|
||||||
if not db_project:
|
if not db_project:
|
||||||
@@ -174,6 +184,36 @@ def delete_project(project_id: int, db: Session = Depends(get_db)):
|
|||||||
return {"message": "Project deleted successfully"}
|
return {"message": "Project deleted successfully"}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/projects/{project_id}/archive")
|
||||||
|
def archive_project(project_id: int, db: Session = Depends(get_db)):
|
||||||
|
"""Archive a project (soft delete)"""
|
||||||
|
db_project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
|
||||||
|
|
||||||
|
if not db_project:
|
||||||
|
raise HTTPException(status_code=404, detail="Project not found")
|
||||||
|
|
||||||
|
db_project.is_archived = 1
|
||||||
|
db.commit()
|
||||||
|
db.refresh(db_project)
|
||||||
|
|
||||||
|
return {"message": "Project archived successfully", "project_id": project_id}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/projects/{project_id}/unarchive")
|
||||||
|
def unarchive_project(project_id: int, db: Session = Depends(get_db)):
|
||||||
|
"""Unarchive a project (restore from archive)"""
|
||||||
|
db_project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
|
||||||
|
|
||||||
|
if not db_project:
|
||||||
|
raise HTTPException(status_code=404, detail="Project not found")
|
||||||
|
|
||||||
|
db_project.is_archived = 0
|
||||||
|
db.commit()
|
||||||
|
db.refresh(db_project)
|
||||||
|
|
||||||
|
return {"message": "Project unarchived successfully", "project_id": project_id}
|
||||||
|
|
||||||
|
|
||||||
@router.get("/projects/filter", response_model=PaginatedResponse[ProjectData])
|
@router.get("/projects/filter", response_model=PaginatedResponse[ProjectData])
|
||||||
def filter_projects(
|
def filter_projects(
|
||||||
stage: Optional[InvestmentStage] = Query(
|
stage: Optional[InvestmentStage] = Query(
|
||||||
|
|||||||
@@ -174,9 +174,10 @@ def _calculate_project_fund_compatibility(
|
|||||||
or fund_geo_lower in project_location_lower
|
or fund_geo_lower in project_location_lower
|
||||||
):
|
):
|
||||||
geo_score = 15
|
geo_score = 15
|
||||||
# Check for common geographic terms or regional overlap
|
# Check for common geographic terms or regional overlap (continent/country matching)
|
||||||
elif _check_geographic_overlap(project_location_lower, fund_geo_lower):
|
elif _check_geographic_overlap(project_location_lower, fund_geo_lower):
|
||||||
geo_score = 12
|
# Give higher score for continent/country matches (e.g., Germany -> Europe)
|
||||||
|
geo_score = 18
|
||||||
|
|
||||||
total_score += geo_score
|
total_score += geo_score
|
||||||
|
|
||||||
@@ -298,9 +299,10 @@ def _calculate_project_investor_direct_compatibility(
|
|||||||
project_location_lower in investor_geo_lower
|
project_location_lower in investor_geo_lower
|
||||||
or investor_geo_lower in project_location_lower
|
or investor_geo_lower in project_location_lower
|
||||||
):
|
):
|
||||||
geo_score = 10
|
geo_score = 15
|
||||||
elif _check_geographic_overlap(project_location_lower, investor_geo_lower):
|
elif _check_geographic_overlap(project_location_lower, investor_geo_lower):
|
||||||
geo_score = 5
|
# Give higher score for continent/country matches (e.g., Germany -> Europe)
|
||||||
|
geo_score = 18
|
||||||
|
|
||||||
total_score += geo_score
|
total_score += geo_score
|
||||||
|
|
||||||
|
|||||||
@@ -258,10 +258,10 @@ Return ONLY the SQL query, no explanations or markdown.""",
|
|||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get top 3 sectors from fund (id and name only)
|
# Get top 3 sectors from fund (id and name only) - sorted alphabetically
|
||||||
fund_sectors = [
|
fund_sectors = [
|
||||||
SectorMinimal(id=sector.id, name=sector.name)
|
SectorMinimal(id=sector.id, name=sector.name)
|
||||||
for sector in (fund.sectors[:3] if fund.sectors else [])
|
for sector in sorted(fund.sectors[:3] if fund.sectors else [], key=lambda s: s.name)
|
||||||
]
|
]
|
||||||
|
|
||||||
investment_response = InvestmentResponse(
|
investment_response = InvestmentResponse(
|
||||||
|
|||||||
Reference in New Issue
Block a user