diff --git a/app/db/models.py b/app/db/models.py index 983533b..f9badcc 100644 --- a/app/db/models.py +++ b/app/db/models.py @@ -296,6 +296,7 @@ class ProjectTable(Base, TimestampMixin): stage = Column(Enum(InvestmentStage), nullable=True) location = Column(String, nullable=True) + industry = Column(String, nullable=True) description = Column(Text, nullable=True) start_date = Column(DateTime, nullable=True) end_date = Column(DateTime, nullable=True) diff --git a/app/routers/projects.py b/app/routers/projects.py index 6ebc7df..bc276aa 100644 --- a/app/routers/projects.py +++ b/app/routers/projects.py @@ -182,6 +182,7 @@ def filter_projects( min_valuation: Optional[int] = Query(None, description="Minimum valuation"), max_valuation: Optional[int] = Query(None, description="Maximum valuation"), location: Optional[str] = Query(None, description="Location (partial match)"), + industry: Optional[str] = Query(None, description="Industry (partial match)"), sector: Optional[str] = Query(None, description="Sector name (partial match)"), investor_name: Optional[str] = Query( None, description="Investor name (partial match)" @@ -215,6 +216,9 @@ def filter_projects( if location: query = query.filter(ProjectTable.location.ilike(f"%{location}%")) + if industry: + query = query.filter(ProjectTable.industry.ilike(f"%{industry}%")) + if sector: query = query.join(ProjectTable.sector).filter( SectorTable.name.ilike(f"%{sector}%") diff --git a/app/schemas/project_schemas.py b/app/schemas/project_schemas.py index c084fd1..7138f26 100644 --- a/app/schemas/project_schemas.py +++ b/app/schemas/project_schemas.py @@ -60,6 +60,7 @@ class ProjectSchema(BaseModel): valuation: int | None stage: InvestmentStage | None location: str | None + industry: str | None description: Optional[str] start_date: Optional[datetime] end_date: Optional[datetime] @@ -75,6 +76,7 @@ class ProjectCreate(BaseModel): valuation: Optional[int] = None stage: Optional[InvestmentStage] = None location: Optional[str] = None + industry: Optional[str] = None description: Optional[str] = None start_date: Optional[datetime] = None end_date: Optional[datetime] = None @@ -85,6 +87,7 @@ class ProjectUpdate(BaseModel): valuation: Optional[int] = None stage: Optional[InvestmentStage] = None location: Optional[str] = None + industry: Optional[str] = None description: Optional[str] = None start_date: Optional[datetime] = None end_date: Optional[datetime] = None diff --git a/app/services/querying.py b/app/services/querying.py index 252fce9..5bd0219 100644 --- a/app/services/querying.py +++ b/app/services/querying.py @@ -37,7 +37,7 @@ class QueryProcessor: self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm) # Update system message to specifically request only fund IDs system_message_updated = ( - prompt_template.format(dialect="SQLite", top_k=5) + prompt_template.format(dialect="SQLite", top_k=100) + "\n\n=== IMPORTANT TERMINOLOGY ===" + "\n- When users say 'investors' or 'find me investors', they mean FUNDS" + "\n- Always query the 'funds' table for investment opportunities" @@ -51,8 +51,19 @@ class QueryProcessor: + "\n1. For geographic searches: use funds.geographic_focus" + "\n2. For sector searches: JOIN with fund_sectors table" + "\n3. For stage searches: JOIN with fund_investment_stages table" - + "\n4. If no results: respond with 'NO_RESULTS'" - + "\n5. Never repeat the same failed query" + + "\n4. Return ALL matching fund IDs, not just the first few" + + "\n5. If no results: respond with 'NO_RESULTS'" + + "\n6. Never repeat the same failed query" + + "\n\n=== GEOGRAPHIC SEARCH RULES (VERY IMPORTANT) ===" + + "\n- ALWAYS use LIKE '%keyword%' for geographic searches, NEVER use exact equality (=)" + + "\n- When user says 'Europe', match ANY location containing 'Europe' (e.g., 'Northern Europe', 'Western Europe', 'Europe', 'Central Europe')" + + "\n- When user says 'America', match locations like 'North America', 'South America', 'Latin America', 'United States'" + + "\n- When user says 'Asia', match 'Asia', 'Southeast Asia', 'East Asia', etc." + + "\n- Examples:" + + "\n * User: 'Europe' → SQL: WHERE geographic_focus LIKE '%Europe%'" + + "\n * User: 'America' → SQL: WHERE geographic_focus LIKE '%America%'" + + "\n * User: 'UK' → SQL: WHERE geographic_focus LIKE '%UK%' OR geographic_focus LIKE '%United Kingdom%'" + + "\n- Be INCLUSIVE: capture all relevant regional variations" ) self.agent = create_react_agent( model=self.llm, diff --git a/investors.db b/investors.db index ccc9762..c29bad8 100644 Binary files a/investors.db and b/investors.db differ diff --git a/migrations/add_industry_to_projects.py b/migrations/add_industry_to_projects.py new file mode 100644 index 0000000..f277f5a --- /dev/null +++ b/migrations/add_industry_to_projects.py @@ -0,0 +1,67 @@ +""" +Migration: Add industry column to projects table +Date: 2025-10-23 +""" + +import os +import sys +from pathlib import Path + +# Add parent directory to path to import app modules +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from sqlalchemy import create_engine, text +from app.db.db import DATABASE_URL, engine + + +def upgrade(): + """Add industry column to projects table""" + print("Running migration: Add industry column to projects table") + + with engine.connect() as conn: + # Check if column already exists + result = conn.execute(text("PRAGMA table_info(projects)")) + columns = [row[1] for row in result] + + if 'industry' in columns: + print("Column 'industry' already exists in projects table. Skipping migration.") + return + + # Add the industry column + conn.execute(text("ALTER TABLE projects ADD COLUMN industry VARCHAR")) + conn.commit() + + print("Successfully added 'industry' column to projects table") + + +def downgrade(): + """Remove industry column from projects table""" + print("Running downgrade: Remove industry column from projects table") + + # Note: SQLite doesn't support DROP COLUMN directly + # This is a simplified version - in production you'd need to recreate the table + print("Warning: SQLite doesn't support DROP COLUMN.") + print("To remove the column, you would need to:") + print("1. Create a new table without the industry column") + print("2. Copy data from old table to new table") + print("3. Drop old table and rename new table") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Run database migration") + parser.add_argument( + "direction", + choices=["upgrade", "downgrade"], + default="upgrade", + nargs="?", + help="Migration direction (default: upgrade)" + ) + + args = parser.parse_args() + + if args.direction == "upgrade": + upgrade() + else: + downgrade()