diff --git a/app/main.py b/app/main.py index a95376d..10cf78a 100644 --- a/app/main.py +++ b/app/main.py @@ -8,6 +8,7 @@ from fastapi import FastAPI, File, Form, HTTPException, UploadFile from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from routers import ( + addition, companies, folk_crm, insight_route, @@ -154,6 +155,7 @@ app.include_router(projects.router) app.include_router(folk_crm.router) app.include_router(insight_route.router) app.include_router(report_route.router) +app.include_router(addition.router) if __name__ == "__main__": import uvicorn diff --git a/app/routers/addition.py b/app/routers/addition.py new file mode 100644 index 0000000..252c687 --- /dev/null +++ b/app/routers/addition.py @@ -0,0 +1,370 @@ +from typing import Optional + +from db.db import get_db +from db.models import FundTable, InvestorTable, SectorTable +from fastapi import APIRouter, Depends +from pydantic import BaseModel +from sqlalchemy.orm import Session + +router = APIRouter(tags=["Additional Routes"]) + + +# Response schemas +class SectorsResponse(BaseModel): + sectors: list[str] + total: int + + +class CountryInfo(BaseModel): + name: str + + +class ContinentInfo(BaseModel): + name: str + countries: list[str] + + +class GeographyResponse(BaseModel): + continents: list[ContinentInfo] + total_continents: int + total_countries: int + + +# Mapping of countries to continents +COUNTRY_TO_CONTINENT = { + # Africa + "Algeria": "Africa", + "Angola": "Africa", + "Benin": "Africa", + "Botswana": "Africa", + "Burkina Faso": "Africa", + "Burundi": "Africa", + "Cameroon": "Africa", + "Cape Verde": "Africa", + "Central African Republic": "Africa", + "Chad": "Africa", + "Comoros": "Africa", + "Congo": "Africa", + "Democratic Republic of the Congo": "Africa", + "Djibouti": "Africa", + "Egypt": "Africa", + "Equatorial Guinea": "Africa", + "Eritrea": "Africa", + "Eswatini": "Africa", + "Ethiopia": "Africa", + "Gabon": "Africa", + "Gambia": "Africa", + "Ghana": 
"Africa", + "Guinea": "Africa", + "Guinea-Bissau": "Africa", + "Ivory Coast": "Africa", + "Kenya": "Africa", + "Lesotho": "Africa", + "Liberia": "Africa", + "Libya": "Africa", + "Madagascar": "Africa", + "Malawi": "Africa", + "Mali": "Africa", + "Mauritania": "Africa", + "Mauritius": "Africa", + "Morocco": "Africa", + "Mozambique": "Africa", + "Namibia": "Africa", + "Niger": "Africa", + "Nigeria": "Africa", + "Rwanda": "Africa", + "Sao Tome and Principe": "Africa", + "Senegal": "Africa", + "Seychelles": "Africa", + "Sierra Leone": "Africa", + "Somalia": "Africa", + "South Africa": "Africa", + "South Sudan": "Africa", + "Sudan": "Africa", + "Tanzania": "Africa", + "Togo": "Africa", + "Tunisia": "Africa", + "Uganda": "Africa", + "Zambia": "Africa", + "Zimbabwe": "Africa", + # Asia + "Afghanistan": "Asia", + "Armenia": "Asia", + "Azerbaijan": "Asia", + "Bahrain": "Asia", + "Bangladesh": "Asia", + "Bhutan": "Asia", + "Brunei": "Asia", + "Cambodia": "Asia", + "China": "Asia", + "Cyprus": "Asia", + "Georgia": "Asia", + "Hong Kong": "Asia", + "India": "Asia", + "Indonesia": "Asia", + "Iran": "Asia", + "Iraq": "Asia", + "Israel": "Asia", + "Japan": "Asia", + "Jordan": "Asia", + "Kazakhstan": "Asia", + "Kuwait": "Asia", + "Kyrgyzstan": "Asia", + "Laos": "Asia", + "Lebanon": "Asia", + "Malaysia": "Asia", + "Maldives": "Asia", + "Mongolia": "Asia", + "Myanmar": "Asia", + "Nepal": "Asia", + "North Korea": "Asia", + "Oman": "Asia", + "Pakistan": "Asia", + "Palestine": "Asia", + "Philippines": "Asia", + "Qatar": "Asia", + "Saudi Arabia": "Asia", + "Singapore": "Asia", + "South Korea": "Asia", + "Sri Lanka": "Asia", + "Syria": "Asia", + "Taiwan": "Asia", + "Tajikistan": "Asia", + "Thailand": "Asia", + "Timor-Leste": "Asia", + "Turkey": "Asia", + "Turkmenistan": "Asia", + "United Arab Emirates": "Asia", + "UAE": "Asia", + "Uzbekistan": "Asia", + "Vietnam": "Asia", + "Yemen": "Asia", + # Europe + "Albania": "Europe", + "Andorra": "Europe", + "Austria": "Europe", + "Belarus": 
"Europe", + "Belgium": "Europe", + "Bosnia and Herzegovina": "Europe", + "Bulgaria": "Europe", + "Croatia": "Europe", + "Czech Republic": "Europe", + "Czechia": "Europe", + "Denmark": "Europe", + "Estonia": "Europe", + "Finland": "Europe", + "France": "Europe", + "Germany": "Europe", + "Greece": "Europe", + "Hungary": "Europe", + "Iceland": "Europe", + "Ireland": "Europe", + "Italy": "Europe", + "Kosovo": "Europe", + "Latvia": "Europe", + "Liechtenstein": "Europe", + "Lithuania": "Europe", + "Luxembourg": "Europe", + "Malta": "Europe", + "Moldova": "Europe", + "Monaco": "Europe", + "Montenegro": "Europe", + "Netherlands": "Europe", + "North Macedonia": "Europe", + "Norway": "Europe", + "Poland": "Europe", + "Portugal": "Europe", + "Romania": "Europe", + "Russia": "Europe", + "San Marino": "Europe", + "Serbia": "Europe", + "Slovakia": "Europe", + "Slovenia": "Europe", + "Spain": "Europe", + "Sweden": "Europe", + "Switzerland": "Europe", + "Ukraine": "Europe", + "United Kingdom": "Europe", + "UK": "Europe", + "Vatican City": "Europe", + # North America + "Antigua and Barbuda": "North America", + "Bahamas": "North America", + "Barbados": "North America", + "Belize": "North America", + "Canada": "North America", + "Costa Rica": "North America", + "Cuba": "North America", + "Dominica": "North America", + "Dominican Republic": "North America", + "El Salvador": "North America", + "Grenada": "North America", + "Guatemala": "North America", + "Haiti": "North America", + "Honduras": "North America", + "Jamaica": "North America", + "Mexico": "North America", + "Nicaragua": "North America", + "Panama": "North America", + "Saint Kitts and Nevis": "North America", + "Saint Lucia": "North America", + "Saint Vincent and the Grenadines": "North America", + "Trinidad and Tobago": "North America", + "United States": "North America", + "USA": "North America", + "US": "North America", + # South America + "Argentina": "South America", + "Bolivia": "South America", + "Brazil": "South 
# Valid continent names for direct matching
VALID_CONTINENTS = {
    "Africa",
    "Asia",
    "Europe",
    "North America",
    "South America",
    "Oceania",
    "Antarctica",
}


def _normalize_place(text: str) -> str:
    """Lower-case *text* and reduce it to single-space-separated alphanumeric words."""
    return " ".join(
        "".join(ch if ch.isalnum() else " " for ch in text.lower()).split()
    )


def _split_geographic_focus(geographic_focus: str) -> list[str]:
    """Split a focus string on ",", "/" and ";" into stripped, non-empty parts."""
    pieces = geographic_focus.replace("/", ",").replace(";", ",").split(",")
    return [piece.strip() for piece in pieces if piece.strip()]


def _scan_for_countries(
    words: list[str], by_normalized: dict[str, str], longest: int
) -> set[str]:
    """Return every known country spelled out in *words*, preferring the longest name."""
    found: set[str] = set()
    index = 0
    while index < len(words):
        # Try the longest candidate first so "papua new guinea" wins over "guinea".
        for size in range(min(longest, len(words) - index), 0, -1):
            name = by_normalized.get(" ".join(words[index:index + size]))
            if name is not None:
                found.add(name)
                index += size
                break
        else:
            index += 1
    return found


def extract_countries_from_geographic_focus(
    geographic_focus: str,
    country_to_continent: Optional[dict[str, str]] = None,
) -> set[str]:
    """
    Extract country names from a geographic_focus string.

    The string is split on commas, slashes and semicolons. Each part is matched
    against the known country names, first exactly and then case-insensitively
    on whole words, so "United States of America" yields "United States" and
    "UK and Ireland" yields both countries. Matching on whole words (rather
    than raw substrings) prevents false hits such as "Romania" -> "Oman".

    A part that is only a fragment of a country name (e.g. "Czech") is accepted
    when exactly one known country contains it; ambiguous fragments such as
    "Korea" are ignored instead of being resolved by dictionary order. Parts
    that are continent names are never reported as countries.

    Args:
        geographic_focus: Free-text focus value; empty/None yields an empty set.
        country_to_continent: Optional country -> continent mapping whose keys
            are the recognised country names. Defaults to the module-level
            COUNTRY_TO_CONTINENT table.

    Returns:
        The set of recognised country names, spelled as in the mapping keys.
    """
    if not geographic_focus:
        return set()

    mapping = COUNTRY_TO_CONTINENT if country_to_continent is None else country_to_continent
    # Lookup tables are built once per call, not once per part.
    by_normalized = {_normalize_place(name): name for name in mapping}
    longest = max((len(key.split()) for key in by_normalized), default=0)
    continents = {_normalize_place(name) for name in VALID_CONTINENTS}

    countries: set[str] = set()
    for part in _split_geographic_focus(geographic_focus):
        if part in mapping:
            countries.add(part)
            continue

        normalized = _normalize_place(part)
        if not normalized or normalized in continents:
            # A bare continent ("Africa") must not match "South Africa" below.
            continue

        found = _scan_for_countries(normalized.split(), by_normalized, longest)
        if not found:
            # Fragment of a longer name: accept only an unambiguous candidate.
            padded = f" {normalized} "
            candidates = [
                name for key, name in by_normalized.items() if padded in f" {key} "
            ]
            if len(candidates) == 1:
                found = set(candidates)
        countries |= found

    return countries


def organize_geography(
    geographic_data: list[str],
    country_to_continent: Optional[dict[str, str]] = None,
) -> dict[str, set[str]]:
    """
    Organize geographic data into continents and their countries.

    Every recognised country is filed under its continent. A continent that is
    named directly (in any letter case, as the whole value or as one delimited
    part such as "Europe, Asia") is registered as well, with an empty country
    set if no country of that continent was seen.

    Args:
        geographic_data: Raw geographic_focus values; falsy entries are skipped.
        country_to_continent: Optional country -> continent mapping. Defaults to
            the module-level COUNTRY_TO_CONTINENT table.

    Returns:
        A dict with continent names as keys and sets of country names as values.
    """
    mapping = COUNTRY_TO_CONTINENT if country_to_continent is None else country_to_continent
    continent_by_normalized = {_normalize_place(name): name for name in VALID_CONTINENTS}
    continent_countries: dict[str, set[str]] = {}

    for geo_focus in geographic_data:
        if not geo_focus:
            continue

        for country in extract_countries_from_geographic_focus(geo_focus, mapping):
            continent = mapping.get(country)
            if continent:
                continent_countries.setdefault(continent, set()).add(country)

        # Register continents that are mentioned by name, part by part.
        for part in _split_geographic_focus(geo_focus):
            continent = continent_by_normalized.get(_normalize_place(part))
            if continent is not None:
                continent_countries.setdefault(continent, set())

    return continent_countries
@router.get("/projects/archived", response_model=PaginatedResponse[ProjectData])
def read_archived_projects(
    page: int = Query(1, ge=1, description="Page number (starts at 1)"),
    page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
    db: Session = Depends(get_db),
):
    """Return one page of archived projects.

    Selects ProjectTable rows whose is_archived flag equals 1, counts them,
    then loads the requested page with the sector, investors and companies
    relationships eagerly loaded, and wraps each row in ProjectData.

    NOTE(review): no ORDER BY is applied before offset/limit, so the order of
    rows across pages is whatever the database returns - confirm whether a
    stable ordering is required.
    NOTE(review): if this router also declares a "/projects/{project_id}" style
    route earlier, that route may capture "/projects/archived" - verify the
    registration order.
    """
    archived = db.query(ProjectTable).filter(ProjectTable.is_archived == 1)

    # Count before paginating so the total covers every archived project.
    total_count = archived.count()

    eager_loads = (
        selectinload(ProjectTable.sector),
        selectinload(ProjectTable.investors),
        selectinload(ProjectTable.companies),
    )
    rows = (
        archived.options(*eager_loads)
        .offset((page - 1) * page_size)
        .limit(page_size)
        .all()
    )

    items = [
        ProjectData(
            project=row,
            sector=row.sector,
            investors=row.investors,
            companies=row.companies,
        )
        for row in rows
    ]

    # Ceiling division: a partially filled last page still counts as a page.
    full_pages, remainder = divmod(total_count, page_size)
    total_pages = full_pages + (1 if remainder else 0)

    return PaginatedResponse(
        items=items,
        total=total_count,
        page=page,
        page_size=page_size,
        total_pages=total_pages,
    )