118 lines
4.1 KiB
Python
118 lines
4.1 KiB
Python
"""
|
|
Migration: Add fields from feedback fixes
|
|
Date: 2025-01-07
|
|
|
|
Adds the following fields:
|
|
- projects.is_archived (INTEGER, default 0)
|
|
- companies.product_service (TEXT, nullable)
|
|
- companies.clients (TEXT, nullable - stored as JSON string)
|
|
- investor_members.linkedin (VARCHAR, nullable)
|
|
"""
|
|
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
# Add parent directory to path to import app modules
|
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
|
|
from sqlalchemy import text
|
|
|
|
from app.db.db import engine
|
|
|
|
|
|
def check_column_exists(conn, table_name, column_name):
    """Report whether ``table_name`` already has a column called ``column_name``.

    Issues SQLite's ``PRAGMA table_info`` for the table on ``conn`` and
    compares ``column_name`` against index 1 of every returned row.

    Args:
        conn: Connection object exposing ``execute()``.
        table_name: Table to inspect; interpolated directly into the PRAGMA
            statement, so it must be a trusted identifier.
        column_name: Column name to look for.

    Returns:
        True if the column is present, otherwise False.
    """
    pragma = text(f"PRAGMA table_info({table_name})")
    # Read every row before answering so the result is fully consumed.
    existing = {info[1] for info in conn.execute(pragma)}
    return column_name in existing
|
|
|
|
|
|
def upgrade():
    """Add the feedback-fix columns to their tables, skipping existing ones.

    Each column is checked with ``check_column_exists`` first, so running the
    migration again on an already-migrated database only prints "Skipping"
    messages. All statements run on a single connection obtained from
    ``engine.begin()``.
    """
    # (table, column, column definition appended to ADD COLUMN), in run order.
    new_columns = (
        ("projects", "is_archived", "INTEGER DEFAULT 0 NOT NULL"),
        ("companies", "product_service", "TEXT"),
        ("companies", "clients", "TEXT"),
        ("investor_members", "linkedin", "VARCHAR"),
    )

    print("Running migration: Add feedback fixes fields")
    print("=" * 60)

    with engine.begin() as conn:  # Use begin() for transaction management
        for step, (table, column, definition) in enumerate(new_columns, start=1):
            print(f"\n{step}. Adding '{column}' column to {table} table...")

            if check_column_exists(conn, table, column):
                print(f" ✓ Column '{column}' already exists. Skipping.")
                continue

            # No backfill UPDATE is needed for is_archived: the column is
            # declared NOT NULL DEFAULT 0, so existing rows read as 0 and a
            # "WHERE is_archived IS NULL" filter could never match a row.
            conn.execute(
                text(f"ALTER TABLE {table} ADD COLUMN {column} {definition}")
            )
            print(f" ✓ Successfully added '{column}' column to {table} table")

    # Reported after the ``with`` block so success is only announced once the
    # transaction context has exited without raising.
    print("\n" + "=" * 60)
    print("Migration completed successfully!")
|
|
|
|
|
|
def downgrade():
    """Print manual instructions for removing the columns added by ``upgrade``.

    This function does not touch the database; it only writes guidance to
    stdout listing the rebuild steps and the affected columns.

    NOTE(review): the message states that SQLite cannot drop columns directly;
    newer SQLite releases document ``ALTER TABLE ... DROP COLUMN`` - confirm
    the deployed version before relying on this guidance.
    """
    report = (
        "Running downgrade: Remove feedback fixes fields",
        "=" * 60,
        "\nWarning: SQLite doesn't support DROP COLUMN directly.",
        "To remove these columns, you would need to:",
        "1. Create new tables without the columns",
        "2. Copy data from old tables to new tables",
        "3. Drop old tables and rename new tables",
        "\nColumns to remove:",
        " - projects.is_archived",
        " - companies.product_service",
        " - companies.clients",
        " - investor_members.linkedin",
    )
    for message in report:
        print(message)
|
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    # Command-line entry point: one optional positional argument selects the
    # migration direction; omitting it runs the upgrade.
    cli = argparse.ArgumentParser(description="Run database migration")
    cli.add_argument(
        "direction",
        nargs="?",
        default="upgrade",
        choices=["upgrade", "downgrade"],
        help="Migration direction (default: upgrade)",
    )
    requested = cli.parse_args().direction

    # argparse restricts the value to the two choices above, so the lookup
    # cannot miss.
    actions = {"upgrade": upgrade, "downgrade": downgrade}
    actions[requested]()
|