feat: Add industry column to ProjectTable and update related schemas and query filters
This commit is contained in:
@@ -296,6 +296,7 @@ class ProjectTable(Base, TimestampMixin):
|
|||||||
|
|
||||||
stage = Column(Enum(InvestmentStage), nullable=True)
|
stage = Column(Enum(InvestmentStage), nullable=True)
|
||||||
location = Column(String, nullable=True)
|
location = Column(String, nullable=True)
|
||||||
|
industry = Column(String, nullable=True)
|
||||||
description = Column(Text, nullable=True)
|
description = Column(Text, nullable=True)
|
||||||
start_date = Column(DateTime, nullable=True)
|
start_date = Column(DateTime, nullable=True)
|
||||||
end_date = Column(DateTime, nullable=True)
|
end_date = Column(DateTime, nullable=True)
|
||||||
|
|||||||
@@ -182,6 +182,7 @@ def filter_projects(
|
|||||||
min_valuation: Optional[int] = Query(None, description="Minimum valuation"),
|
min_valuation: Optional[int] = Query(None, description="Minimum valuation"),
|
||||||
max_valuation: Optional[int] = Query(None, description="Maximum valuation"),
|
max_valuation: Optional[int] = Query(None, description="Maximum valuation"),
|
||||||
location: Optional[str] = Query(None, description="Location (partial match)"),
|
location: Optional[str] = Query(None, description="Location (partial match)"),
|
||||||
|
industry: Optional[str] = Query(None, description="Industry (partial match)"),
|
||||||
sector: Optional[str] = Query(None, description="Sector name (partial match)"),
|
sector: Optional[str] = Query(None, description="Sector name (partial match)"),
|
||||||
investor_name: Optional[str] = Query(
|
investor_name: Optional[str] = Query(
|
||||||
None, description="Investor name (partial match)"
|
None, description="Investor name (partial match)"
|
||||||
@@ -215,6 +216,9 @@ def filter_projects(
|
|||||||
if location:
|
if location:
|
||||||
query = query.filter(ProjectTable.location.ilike(f"%{location}%"))
|
query = query.filter(ProjectTable.location.ilike(f"%{location}%"))
|
||||||
|
|
||||||
|
if industry:
|
||||||
|
query = query.filter(ProjectTable.industry.ilike(f"%{industry}%"))
|
||||||
|
|
||||||
if sector:
|
if sector:
|
||||||
query = query.join(ProjectTable.sector).filter(
|
query = query.join(ProjectTable.sector).filter(
|
||||||
SectorTable.name.ilike(f"%{sector}%")
|
SectorTable.name.ilike(f"%{sector}%")
|
||||||
|
|||||||
@@ -60,6 +60,7 @@ class ProjectSchema(BaseModel):
|
|||||||
valuation: int | None
|
valuation: int | None
|
||||||
stage: InvestmentStage | None
|
stage: InvestmentStage | None
|
||||||
location: str | None
|
location: str | None
|
||||||
|
industry: str | None
|
||||||
description: Optional[str]
|
description: Optional[str]
|
||||||
start_date: Optional[datetime]
|
start_date: Optional[datetime]
|
||||||
end_date: Optional[datetime]
|
end_date: Optional[datetime]
|
||||||
@@ -75,6 +76,7 @@ class ProjectCreate(BaseModel):
|
|||||||
valuation: Optional[int] = None
|
valuation: Optional[int] = None
|
||||||
stage: Optional[InvestmentStage] = None
|
stage: Optional[InvestmentStage] = None
|
||||||
location: Optional[str] = None
|
location: Optional[str] = None
|
||||||
|
industry: Optional[str] = None
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
start_date: Optional[datetime] = None
|
start_date: Optional[datetime] = None
|
||||||
end_date: Optional[datetime] = None
|
end_date: Optional[datetime] = None
|
||||||
@@ -85,6 +87,7 @@ class ProjectUpdate(BaseModel):
|
|||||||
valuation: Optional[int] = None
|
valuation: Optional[int] = None
|
||||||
stage: Optional[InvestmentStage] = None
|
stage: Optional[InvestmentStage] = None
|
||||||
location: Optional[str] = None
|
location: Optional[str] = None
|
||||||
|
industry: Optional[str] = None
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
start_date: Optional[datetime] = None
|
start_date: Optional[datetime] = None
|
||||||
end_date: Optional[datetime] = None
|
end_date: Optional[datetime] = None
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ class QueryProcessor:
|
|||||||
self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
|
self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
|
||||||
# Update system message to specifically request only fund IDs
|
# Update system message to specifically request only fund IDs
|
||||||
system_message_updated = (
|
system_message_updated = (
|
||||||
prompt_template.format(dialect="SQLite", top_k=5)
|
prompt_template.format(dialect="SQLite", top_k=100)
|
||||||
+ "\n\n=== IMPORTANT TERMINOLOGY ==="
|
+ "\n\n=== IMPORTANT TERMINOLOGY ==="
|
||||||
+ "\n- When users say 'investors' or 'find me investors', they mean FUNDS"
|
+ "\n- When users say 'investors' or 'find me investors', they mean FUNDS"
|
||||||
+ "\n- Always query the 'funds' table for investment opportunities"
|
+ "\n- Always query the 'funds' table for investment opportunities"
|
||||||
@@ -51,8 +51,19 @@ class QueryProcessor:
|
|||||||
+ "\n1. For geographic searches: use funds.geographic_focus"
|
+ "\n1. For geographic searches: use funds.geographic_focus"
|
||||||
+ "\n2. For sector searches: JOIN with fund_sectors table"
|
+ "\n2. For sector searches: JOIN with fund_sectors table"
|
||||||
+ "\n3. For stage searches: JOIN with fund_investment_stages table"
|
+ "\n3. For stage searches: JOIN with fund_investment_stages table"
|
||||||
+ "\n4. If no results: respond with 'NO_RESULTS'"
|
+ "\n4. Return ALL matching fund IDs, not just the first few"
|
||||||
+ "\n5. Never repeat the same failed query"
|
+ "\n5. If no results: respond with 'NO_RESULTS'"
|
||||||
|
+ "\n6. Never repeat the same failed query"
|
||||||
|
+ "\n\n=== GEOGRAPHIC SEARCH RULES (VERY IMPORTANT) ==="
|
||||||
|
+ "\n- ALWAYS use LIKE '%keyword%' for geographic searches, NEVER use exact equality (=)"
|
||||||
|
+ "\n- When user says 'Europe', match ANY location containing 'Europe' (e.g., 'Northern Europe', 'Western Europe', 'Europe', 'Central Europe')"
|
||||||
|
+ "\n- When user says 'America', match locations like 'North America', 'South America', 'Latin America', 'United States'"
|
||||||
|
+ "\n- When user says 'Asia', match 'Asia', 'Southeast Asia', 'East Asia', etc."
|
||||||
|
+ "\n- Examples:"
|
||||||
|
+ "\n * User: 'Europe' → SQL: WHERE geographic_focus LIKE '%Europe%'"
|
||||||
|
+ "\n * User: 'America' → SQL: WHERE geographic_focus LIKE '%America%'"
|
||||||
|
+ "\n * User: 'UK' → SQL: WHERE geographic_focus LIKE '%UK%' OR geographic_focus LIKE '%United Kingdom%'"
|
||||||
|
+ "\n- Be INCLUSIVE: capture all relevant regional variations"
|
||||||
)
|
)
|
||||||
self.agent = create_react_agent(
|
self.agent = create_react_agent(
|
||||||
model=self.llm,
|
model=self.llm,
|
||||||
|
|||||||
Binary file not shown.
@@ -0,0 +1,67 @@
|
|||||||
|
"""
|
||||||
|
Migration: Add industry column to projects table
|
||||||
|
Date: 2025-10-23
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add parent directory to path to import app modules
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
|
from sqlalchemy import create_engine, text
|
||||||
|
from app.db.db import DATABASE_URL, engine
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
"""Add industry column to projects table"""
|
||||||
|
print("Running migration: Add industry column to projects table")
|
||||||
|
|
||||||
|
with engine.connect() as conn:
|
||||||
|
# Check if column already exists
|
||||||
|
result = conn.execute(text("PRAGMA table_info(projects)"))
|
||||||
|
columns = [row[1] for row in result]
|
||||||
|
|
||||||
|
if 'industry' in columns:
|
||||||
|
print("Column 'industry' already exists in projects table. Skipping migration.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Add the industry column
|
||||||
|
conn.execute(text("ALTER TABLE projects ADD COLUMN industry VARCHAR"))
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
print("Successfully added 'industry' column to projects table")
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
"""Remove industry column from projects table"""
|
||||||
|
print("Running downgrade: Remove industry column from projects table")
|
||||||
|
|
||||||
|
# Note: SQLite doesn't support DROP COLUMN directly
|
||||||
|
# This is a simplified version - in production you'd need to recreate the table
|
||||||
|
print("Warning: SQLite doesn't support DROP COLUMN.")
|
||||||
|
print("To remove the column, you would need to:")
|
||||||
|
print("1. Create a new table without the industry column")
|
||||||
|
print("2. Copy data from old table to new table")
|
||||||
|
print("3. Drop old table and rename new table")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description="Run database migration")
|
||||||
|
parser.add_argument(
|
||||||
|
"direction",
|
||||||
|
choices=["upgrade", "downgrade"],
|
||||||
|
default="upgrade",
|
||||||
|
nargs="?",
|
||||||
|
help="Migration direction (default: upgrade)"
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.direction == "upgrade":
|
||||||
|
upgrade()
|
||||||
|
else:
|
||||||
|
downgrade()
|
||||||
Reference in New Issue
Block a user