diff --git a/app/db/models.py b/app/db/models.py index f9badcc..038b3d3 100644 --- a/app/db/models.py +++ b/app/db/models.py @@ -162,6 +162,7 @@ class InvestorMember(Base, TimestampMixin): role = Column(String, nullable=True) title = Column(String, nullable=True) # Alternative to role email = Column(String, nullable=True) + linkedin = Column(String, nullable=True) # LinkedIn profile URL source_url = Column(String, nullable=True) # URL where member info was found investor_id = Column(Integer, ForeignKey("investors.id")) @@ -215,6 +216,8 @@ class CompanyTable(Base, TimestampMixin): description = Column(String, nullable=True) founded_year = Column(Integer, nullable=True) website = Column(String, nullable=True) + product_service = Column(Text, nullable=True) # Product/service description + clients = Column(JSON, nullable=True) # List of client names or client information members = relationship( "CompanyMember", back_populates="company", cascade="all, delete-orphan" @@ -300,6 +303,7 @@ class ProjectTable(Base, TimestampMixin): description = Column(Text, nullable=True) start_date = Column(DateTime, nullable=True) end_date = Column(DateTime, nullable=True) + is_archived = Column(Integer, default=0, nullable=False) # 0 = active, 1 = archived sector = relationship( "SectorTable", secondary=project_sector_association, back_populates="projects" diff --git a/app/routers/companies.py b/app/routers/companies.py index d97a41b..887e438 100644 --- a/app/routers/companies.py +++ b/app/routers/companies.py @@ -63,11 +63,13 @@ def read_companies( # Transform CompanyTable objects to CompanyData format company_data_list = [] for company in companies: + # Sort sectors alphabetically + sorted_sectors = sorted(company.sectors, key=lambda s: s.name) if company.sectors else [] company_data = CompanyData( company=company, investors=company.investors, members=company.members, - sectors=company.sectors, + sectors=sorted_sectors, ) company_data_list.append(company_data) @@ -147,11 +149,13 @@ def filter_companies( # Transform to CompanyData format company_data_list = [] for company in companies: + # Sort sectors alphabetically + sorted_sectors = sorted(company.sectors, key=lambda s: s.name) if company.sectors else [] company_data = CompanyData( company=company, investors=company.investors, members=company.members, - sectors=company.sectors, + sectors=sorted_sectors, ) company_data_list.append(company_data) @@ -184,12 +188,15 @@ def read_company(company_id: int, db: Session = Depends(get_db)): if not company: raise HTTPException(status_code=404, detail="Company not found") + # Sort sectors alphabetically + sorted_sectors = sorted(company.sectors, key=lambda s: s.name) if company.sectors else [] + # Transform to CompanyData format return CompanyData( company=company, investors=company.investors, members=company.members, - sectors=company.sectors, + sectors=sorted_sectors, ) @@ -250,12 +257,15 @@ def update_company( .first() ) + # Sort sectors alphabetically + sorted_sectors = sorted(company_with_relations.sectors, key=lambda s: s.name) if company_with_relations.sectors else [] + # Transform to CompanyData format return CompanyData( company=company_with_relations, investors=company_with_relations.investors, members=company_with_relations.members, - sectors=company_with_relations.sectors, + sectors=sorted_sectors, ) diff --git a/app/routers/investors.py b/app/routers/investors.py index 4448c60..5d41df0 100644 --- a/app/routers/investors.py +++ b/app/routers/investors.py @@ -81,20 +81,38 @@ def read_investors( if not project: raise HTTPException(status_code=404, detail="Project not found") - # Get paginated results - investors = ( - db.query(InvestorTable) - .options( - selectinload(InvestorTable.portfolio_companies), - selectinload(InvestorTable.team_members), - selectinload(InvestorTable.sectors), - selectinload(InvestorTable.funds).selectinload(FundTable.investment_stages), - selectinload(InvestorTable.funds).selectinload(FundTable.sectors), + # When project_id is provided, we need to get all investors first to sort by compatibility score + # Otherwise, we can paginate at the database level + if project is not None: + # Get all investors (we'll sort by compatibility score, then paginate) + all_investors = ( + db.query(InvestorTable) + .options( + selectinload(InvestorTable.portfolio_companies), + selectinload(InvestorTable.team_members), + selectinload(InvestorTable.sectors), + selectinload(InvestorTable.funds).selectinload(FundTable.investment_stages), + selectinload(InvestorTable.funds).selectinload(FundTable.sectors), + ) + .all() + ) + # We'll paginate after sorting by compatibility score + investors = all_investors + else: + # Get paginated results (no sorting needed) + investors = ( + db.query(InvestorTable) + .options( + selectinload(InvestorTable.portfolio_companies), + selectinload(InvestorTable.team_members), + selectinload(InvestorTable.sectors), + selectinload(InvestorTable.funds).selectinload(FundTable.investment_stages), + selectinload(InvestorTable.funds).selectinload(FundTable.sectors), + ) + .offset(offset) + .limit(page_size) + .all() ) - .offset(offset) - .limit(page_size) - .all() - ) # Transform to InvestmentResponse format (one row per investor-fund combination) investment_responses = [] @@ -122,10 +140,10 @@ def read_investors( else None ) - # Get top 3 sectors from fund (id and name only) + # Get top 3 sectors from fund (id and name only) - sorted alphabetically fund_sectors = [ SectorMinimal(id=sector.id, name=sector.name) - for sector in (fund.sectors[:3] if fund.sectors else []) + for sector in sorted(fund.sectors[:3] if fund.sectors else [], key=lambda s: s.name) ] investment_response = InvestmentResponse( @@ -166,6 +184,12 @@ def read_investors( ) investment_responses.append(investment_response) + # Sort by compatibility score (descending) when project_id is provided + if project is not None: + investment_responses.sort(key=lambda x: x.compatibility_score, reverse=True) + # Apply pagination after sorting + investment_responses = investment_responses[offset:offset + page_size] + # Calculate total pages total_pages = (total_count + page_size - 1) // page_size @@ -257,9 +281,16 @@ def filter_investors( # Get total count before pagination total_count = query.count() - # Calculate offset and apply pagination - offset = (page - 1) * page_size - funds = query.offset(offset).limit(page_size).all() + # When project_id is provided, we need to get all funds first to sort by compatibility score + # Otherwise, we can paginate at the database level + if project is not None: + # Get all funds (we'll sort by compatibility score, then paginate) + all_funds = query.all() + funds = all_funds + else: + # Calculate offset and apply pagination (no sorting needed) + offset = (page - 1) * page_size + funds = query.offset(offset).limit(page_size).all() # Transform to InvestmentResponse format (one row per fund) investment_responses = [] @@ -286,10 +317,10 @@ def filter_investors( else None ) - # Get top 3 sectors from fund (id and name only) + # Get top 3 sectors from fund (id and name only) - sorted alphabetically fund_sectors = [ SectorMinimal(id=sector.id, name=sector.name) - for sector in (fund.sectors[:3] if fund.sectors else []) + for sector in sorted(fund.sectors[:3] if fund.sectors else [], key=lambda s: s.name) ] investment_response = InvestmentResponse( @@ -308,6 +339,13 @@ def filter_investors( ) investment_responses.append(investment_response) + # Sort by compatibility score (descending) when project_id is provided + if project is not None: + investment_responses.sort(key=lambda x: x.compatibility_score, reverse=True) + # Apply pagination after sorting + offset = (page - 1) * page_size + investment_responses = investment_responses[offset:offset + page_size] + # Calculate total pages total_pages = (total_count + page_size - 1) // page_size diff --git a/app/routers/projects.py b/app/routers/projects.py index bc276aa..c4fcd17 100644 --- a/app/routers/projects.py +++ b/app/routers/projects.py @@ -24,19 +24,29 @@ router = APIRouter(tags=["Project Routes"]) def read_projects( page: int = Query(1, ge=1, description="Page number (starts at 1)"), page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"), + include_archived: bool = Query(False, description="Include archived projects"), db: Session = Depends(get_db), ): - """Get all projects with their related data (paginated)""" + """Get all projects with their related data (paginated) + + By default, archived projects are excluded. Set include_archived=True to include them. + """ # Calculate offset offset = (page - 1) * page_size + # Start with base query + query = db.query(ProjectTable) + + # Filter out archived projects by default + if not include_archived: + query = query.filter(ProjectTable.is_archived == 0) + # Get total count - total_count = db.query(ProjectTable).count() + total_count = query.count() # Get paginated results projects = ( - db.query(ProjectTable) - .options( + query.options( selectinload(ProjectTable.sector), selectinload(ProjectTable.investors), selectinload(ProjectTable.companies), @@ -162,7 +172,7 @@ def update_project( @router.delete("/projects/{project_id}") def delete_project(project_id: int, db: Session = Depends(get_db)): - """Delete a project""" + """Delete a project permanently""" db_project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first() if not db_project: @@ -174,6 +184,36 @@ def delete_project(project_id: int, db: Session = Depends(get_db)): return {"message": "Project deleted successfully"} +@router.post("/projects/{project_id}/archive") +def archive_project(project_id: int, db: Session = Depends(get_db)): + """Archive a project (soft delete)""" + db_project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first() + + if not db_project: + raise HTTPException(status_code=404, detail="Project not found") + + db_project.is_archived = 1 + db.commit() + db.refresh(db_project) + + return {"message": "Project archived successfully", "project_id": project_id} + + +@router.post("/projects/{project_id}/unarchive") +def unarchive_project(project_id: int, db: Session = Depends(get_db)): + """Unarchive a project (restore from archive)""" + db_project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first() + + if not db_project: + raise HTTPException(status_code=404, detail="Project not found") + + db_project.is_archived = 0 + db.commit() + db.refresh(db_project) + + return {"message": "Project unarchived successfully", "project_id": project_id} + + @router.get("/projects/filter", response_model=PaginatedResponse[ProjectData]) def filter_projects( stage: Optional[InvestmentStage] = Query( diff --git a/app/services/compatibility_score.py b/app/services/compatibility_score.py index 1576fdb..707c0ba 100644 --- a/app/services/compatibility_score.py +++ b/app/services/compatibility_score.py @@ -174,9 +174,10 @@ def _calculate_project_fund_compatibility( or fund_geo_lower in project_location_lower ): geo_score = 15 - # Check for common geographic terms or regional overlap + # Check for common geographic terms or regional overlap (continent/country matching) elif _check_geographic_overlap(project_location_lower, fund_geo_lower): - geo_score = 12 + # Give higher score for continent/country matches (e.g., Germany -> Europe) + geo_score = 18 total_score += geo_score @@ -298,9 +299,10 @@ def _calculate_project_investor_direct_compatibility( project_location_lower in investor_geo_lower or investor_geo_lower in project_location_lower ): - geo_score = 10 + geo_score = 15 elif _check_geographic_overlap(project_location_lower, investor_geo_lower): - geo_score = 5 + # Give higher score for continent/country matches (e.g., Germany -> Europe) + geo_score = 18 total_score += geo_score diff --git a/app/services/querying.py b/app/services/querying.py index 2a8bca5..94fb252 100644 --- a/app/services/querying.py +++ b/app/services/querying.py @@ -258,10 +258,10 @@ Return ONLY the SQL query, no explanations or markdown.""", else None ) - # Get top 3 sectors from fund (id and name only) + # Get top 3 sectors from fund (id and name only) - sorted alphabetically fund_sectors = [ SectorMinimal(id=sector.id, name=sector.name) - for sector in (fund.sectors[:3] if fund.sectors else []) + for sector in sorted(fund.sectors[:3] if fund.sectors else [], key=lambda s: s.name) ] investment_response = InvestmentResponse(