import re
from typing import Optional

from db.db import get_db
from db.models import FundTable, InvestorTable, SectorTable
from fastapi import APIRouter, Depends
from pydantic import BaseModel
from sqlalchemy.orm import Session

router = APIRouter(tags=["Additional Routes"])


# Response schemas
class SectorsResponse(BaseModel):
    sectors: list[str]
    total: int


class CountryInfo(BaseModel):
    name: str


class ContinentInfo(BaseModel):
    name: str
    countries: list[str]


class GeographyResponse(BaseModel):
    continents: list[ContinentInfo]
    total_continents: int
    total_countries: int


# Mapping of countries (plus a few common abbreviations such as "USA", "UK")
# to continents. Abbreviations are folded into one canonical name via
# COUNTRY_ALIASES before being reported.
COUNTRY_TO_CONTINENT = {
    # Africa
    "Algeria": "Africa",
    "Angola": "Africa",
    "Benin": "Africa",
    "Botswana": "Africa",
    "Burkina Faso": "Africa",
    "Burundi": "Africa",
    "Cameroon": "Africa",
    "Cape Verde": "Africa",
    "Central African Republic": "Africa",
    "Chad": "Africa",
    "Comoros": "Africa",
    "Congo": "Africa",
    "Democratic Republic of the Congo": "Africa",
    "Djibouti": "Africa",
    "Egypt": "Africa",
    "Equatorial Guinea": "Africa",
    "Eritrea": "Africa",
    "Eswatini": "Africa",
    "Ethiopia": "Africa",
    "Gabon": "Africa",
    "Gambia": "Africa",
    "Ghana": "Africa",
    "Guinea": "Africa",
    "Guinea-Bissau": "Africa",
    "Ivory Coast": "Africa",
    "Kenya": "Africa",
    "Lesotho": "Africa",
    "Liberia": "Africa",
    "Libya": "Africa",
    "Madagascar": "Africa",
    "Malawi": "Africa",
    "Mali": "Africa",
    "Mauritania": "Africa",
    "Mauritius": "Africa",
    "Morocco": "Africa",
    "Mozambique": "Africa",
    "Namibia": "Africa",
    "Niger": "Africa",
    "Nigeria": "Africa",
    "Rwanda": "Africa",
    "Sao Tome and Principe": "Africa",
    "Senegal": "Africa",
    "Seychelles": "Africa",
    "Sierra Leone": "Africa",
    "Somalia": "Africa",
    "South Africa": "Africa",
    "South Sudan": "Africa",
    "Sudan": "Africa",
    "Tanzania": "Africa",
    "Togo": "Africa",
    "Tunisia": "Africa",
    "Uganda": "Africa",
    "Zambia": "Africa",
    "Zimbabwe": "Africa",
    # Asia
    "Afghanistan": "Asia",
    "Armenia": "Asia",
    "Azerbaijan": "Asia",
    "Bahrain": "Asia",
    "Bangladesh": "Asia",
    "Bhutan": "Asia",
    "Brunei": "Asia",
    "Cambodia": "Asia",
    "China": "Asia",
    "Cyprus": "Asia",
    "Georgia": "Asia",
    "Hong Kong": "Asia",
    "India": "Asia",
    "Indonesia": "Asia",
    "Iran": "Asia",
    "Iraq": "Asia",
    "Israel": "Asia",
    "Japan": "Asia",
    "Jordan": "Asia",
    "Kazakhstan": "Asia",
    "Kuwait": "Asia",
    "Kyrgyzstan": "Asia",
    "Laos": "Asia",
    "Lebanon": "Asia",
    "Malaysia": "Asia",
    "Maldives": "Asia",
    "Mongolia": "Asia",
    "Myanmar": "Asia",
    "Nepal": "Asia",
    "North Korea": "Asia",
    "Oman": "Asia",
    "Pakistan": "Asia",
    "Palestine": "Asia",
    "Philippines": "Asia",
    "Qatar": "Asia",
    "Saudi Arabia": "Asia",
    "Singapore": "Asia",
    "South Korea": "Asia",
    "Sri Lanka": "Asia",
    "Syria": "Asia",
    "Taiwan": "Asia",
    "Tajikistan": "Asia",
    "Thailand": "Asia",
    "Timor-Leste": "Asia",
    "Turkey": "Asia",
    "Turkmenistan": "Asia",
    "United Arab Emirates": "Asia",
    "UAE": "Asia",
    "Uzbekistan": "Asia",
    "Vietnam": "Asia",
    "Yemen": "Asia",
    # Europe
    "Albania": "Europe",
    "Andorra": "Europe",
    "Austria": "Europe",
    "Belarus": "Europe",
    "Belgium": "Europe",
    "Bosnia and Herzegovina": "Europe",
    "Bulgaria": "Europe",
    "Croatia": "Europe",
    "Czech Republic": "Europe",
    "Czechia": "Europe",
    "Denmark": "Europe",
    "Estonia": "Europe",
    "Finland": "Europe",
    "France": "Europe",
    "Germany": "Europe",
    "Greece": "Europe",
    "Hungary": "Europe",
    "Iceland": "Europe",
    "Ireland": "Europe",
    "Italy": "Europe",
    "Kosovo": "Europe",
    "Latvia": "Europe",
    "Liechtenstein": "Europe",
    "Lithuania": "Europe",
    "Luxembourg": "Europe",
    "Malta": "Europe",
    "Moldova": "Europe",
    "Monaco": "Europe",
    "Montenegro": "Europe",
    "Netherlands": "Europe",
    "North Macedonia": "Europe",
    "Norway": "Europe",
    "Poland": "Europe",
    "Portugal": "Europe",
    "Romania": "Europe",
    "Russia": "Europe",
    "San Marino": "Europe",
    "Serbia": "Europe",
    "Slovakia": "Europe",
    "Slovenia": "Europe",
    "Spain": "Europe",
    "Sweden": "Europe",
    "Switzerland": "Europe",
    "Ukraine": "Europe",
    "United Kingdom": "Europe",
    "UK": "Europe",
    "Vatican City": "Europe",
    # North America
    "Antigua and Barbuda": "North America",
    "Bahamas": "North America",
    "Barbados": "North America",
    "Belize": "North America",
    "Canada": "North America",
    "Costa Rica": "North America",
    "Cuba": "North America",
    "Dominica": "North America",
    "Dominican Republic": "North America",
    "El Salvador": "North America",
    "Grenada": "North America",
    "Guatemala": "North America",
    "Haiti": "North America",
    "Honduras": "North America",
    "Jamaica": "North America",
    "Mexico": "North America",
    "Nicaragua": "North America",
    "Panama": "North America",
    "Saint Kitts and Nevis": "North America",
    "Saint Lucia": "North America",
    "Saint Vincent and the Grenadines": "North America",
    "Trinidad and Tobago": "North America",
    "United States": "North America",
    "USA": "North America",
    "US": "North America",
    # South America
    "Argentina": "South America",
    "Bolivia": "South America",
    "Brazil": "South America",
    "Chile": "South America",
    "Colombia": "South America",
    "Ecuador": "South America",
    "Guyana": "South America",
    "Paraguay": "South America",
    "Peru": "South America",
    "Suriname": "South America",
    "Uruguay": "South America",
    "Venezuela": "South America",
    # Oceania
    "Australia": "Oceania",
    "Fiji": "Oceania",
    "Kiribati": "Oceania",
    "Marshall Islands": "Oceania",
    "Micronesia": "Oceania",
    "Nauru": "Oceania",
    "New Zealand": "Oceania",
    "Palau": "Oceania",
    "Papua New Guinea": "Oceania",
    "Samoa": "Oceania",
    "Solomon Islands": "Oceania",
    "Tonga": "Oceania",
    "Tuvalu": "Oceania",
    "Vanuatu": "Oceania",
}

# Alternate spellings / abbreviations mapped to the canonical name reported in
# the API, so that e.g. "USA" and "United States" are listed (and counted) as a
# single country. Every value here is a key of COUNTRY_TO_CONTINENT.
COUNTRY_ALIASES = {
    "USA": "United States",
    "US": "United States",
    "United States of America": "United States",
    "UK": "United Kingdom",
    "UAE": "United Arab Emirates",
    "Czechia": "Czech Republic",
    "Russian Federation": "Russia",
    "Côte d'Ivoire": "Ivory Coast",
    "Cote d'Ivoire": "Ivory Coast",
}

# Valid continent names for direct matching
VALID_CONTINENTS = {
    "Africa",
    "Asia",
    "Europe",
    "North America",
    "South America",
    "Oceania",
    "Antarctica",
}

# Case-insensitive lookups from a case-folded name back to its stored spelling.
_NAME_BY_FOLDED = {
    name.casefold(): name for name in (*COUNTRY_TO_CONTINENT, *COUNTRY_ALIASES)
}
_CONTINENT_BY_FOLDED = {name.casefold(): name for name in VALID_CONTINENTS}

# Matches any known country name or alias as a whole word, case-insensitively.
# Longer names are tried first so "Papua New Guinea" wins over "Guinea" and
# "South Sudan" over "Sudan"; the word boundaries keep "Niger" from matching
# inside "Nigeria" and "US" from matching inside "Australasia".
_COUNTRY_PATTERN = re.compile(
    r"\b(?:"
    + "|".join(
        re.escape(name)
        for name in sorted(_NAME_BY_FOLDED.values(), key=len, reverse=True)
    )
    + r")\b",
    re.IGNORECASE,
)


def _split_geographic_focus(geographic_focus: str) -> list[str]:
    """Split on ``,``, ``/`` and ``;`` into trimmed, non-empty parts."""
    parts = geographic_focus.replace("/", ",").replace(";", ",").split(",")
    return [cleaned for cleaned in (part.strip() for part in parts) if cleaned]


def extract_countries_from_geographic_focus(
    geographic_focus: Optional[str],
) -> set[str]:
    """
    Extract country names from a geographic_focus string.

    The string is split on commas, slashes and semicolons, and each part is
    scanned for known country names or aliases appearing as whole words,
    case-insensitively (so "usa", "United States of America" and
    "US and Canada" are all recognised). Aliases are normalised to their
    canonical name, e.g. "USA" -> "United States".

    Continent names ("Africa", "Europe", ...) and regions that contain no
    country name ("Southeast Asia", "Global") produce no countries.

    Args:
        geographic_focus: Free-text geographic focus; empty or None yields an
            empty set.

    Returns:
        Canonical country names, each of which is a key of
        COUNTRY_TO_CONTINENT.
    """
    if not geographic_focus:
        return set()

    countries: set[str] = set()
    for part in _split_geographic_focus(geographic_focus):
        for match in _COUNTRY_PATTERN.finditer(part):
            # IGNORECASE matching and casefold() agree for these names; .get()
            # guards against exotic Unicode case pairs rather than crashing.
            spelled = _NAME_BY_FOLDED.get(match.group(0).casefold())
            if spelled is not None:
                countries.add(COUNTRY_ALIASES.get(spelled, spelled))
    return countries


def organize_geography(geographic_data: list[str]) -> dict[str, set[str]]:
    """
    Organize geographic data into continents and their countries.

    Every country found by extract_countries_from_geographic_focus is filed
    under its continent. In addition, any delimited part that is itself a
    continent name (case-insensitive, e.g. both parts of "Europe, Asia")
    registers that continent, even if none of its countries are mentioned.

    Args:
        geographic_data: geographic_focus strings; falsy entries are skipped.

    Returns:
        A dict with continent names as keys and sets of canonical country
        names as values (a set may be empty).
    """
    continent_countries: dict[str, set[str]] = {}

    for geo_focus in geographic_data:
        if not geo_focus:
            continue

        for country in extract_countries_from_geographic_focus(geo_focus):
            continent = COUNTRY_TO_CONTINENT.get(country)
            if continent:
                continent_countries.setdefault(continent, set()).add(country)

        # A part may name a continent directly rather than a country.
        for part in _split_geographic_focus(geo_focus):
            continent = _CONTINENT_BY_FOLDED.get(part.casefold())
            if continent:
                continent_countries.setdefault(continent, set())

    return continent_countries


@router.get("/sectors", response_model=SectorsResponse)
def get_unique_sectors(db: Session = Depends(get_db)):
    """
    Get all unique sectors from the database.

    Returns a list of sector names sorted alphabetically; null or empty names
    are dropped.
    """
    sectors = db.query(SectorTable.name).distinct().order_by(SectorTable.name).all()
    sector_names = [s[0] for s in sectors if s[0]]
    return SectorsResponse(sectors=sector_names, total=len(sector_names))


@router.get("/geography", response_model=GeographyResponse)
def get_arranged_geography(db: Session = Depends(get_db)):
    """
    Get all unique geographic locations arranged by continent and countries.

    Extracts geography from both the investors and funds tables and returns
    the continents (sorted by name) with their countries (sorted by name).
    """
    # Collect all geographic focus data from investors
    investor_geo = (
        db.query(InvestorTable.geographic_focus)
        .filter(InvestorTable.geographic_focus.isnot(None))
        .distinct()
        .all()
    )

    # Collect all geographic focus data from funds
    fund_geo = (
        db.query(FundTable.geographic_focus)
        .filter(FundTable.geographic_focus.isnot(None))
        .distinct()
        .all()
    )

    # Combine all geographic data; duplicates across tables are harmless
    # because organize_geography accumulates into sets.
    all_geo_data = [g[0] for g in investor_geo] + [g[0] for g in fund_geo]

    # Organize into continents and countries
    continent_countries = organize_geography(all_geo_data)

    # Build response
    continents = []
    total_countries = 0
    for continent_name in sorted(continent_countries):
        countries = sorted(continent_countries[continent_name])
        total_countries += len(countries)
        continents.append(ContinentInfo(name=continent_name, countries=countries))

    return GeographyResponse(
        continents=continents,
        total_continents=len(continents),
        total_countries=total_countries,
    )