refactor: Clean up migration script and improve readability by removing unnecessary imports and formatting

This commit is contained in:
bolade
2025-11-11 20:28:20 +01:00
parent 215fec2895
commit b92feaa13a
+24 -17
View File
@@ -9,15 +9,15 @@ Adds the following fields:
- investor_members.linkedin (VARCHAR, nullable)
"""
import os
import sys
from pathlib import Path
# Add parent directory to path to import app modules
sys.path.insert(0, str(Path(__file__).parent.parent))
from sqlalchemy import create_engine, text
from app.db.db import DATABASE_URL, engine
from sqlalchemy import text
from app.db.db import engine
def check_column_exists(conn, table_name, column_name):
@@ -31,18 +31,24 @@ def upgrade():
"""Add new columns to tables"""
print("Running migration: Add feedback fixes fields")
print("=" * 60)
with engine.begin() as conn: # Use begin() for transaction management
# 1. Add is_archived to projects table
print("\n1. Adding 'is_archived' column to projects table...")
if check_column_exists(conn, "projects", "is_archived"):
print(" ✓ Column 'is_archived' already exists. Skipping.")
else:
conn.execute(text("ALTER TABLE projects ADD COLUMN is_archived INTEGER DEFAULT 0 NOT NULL"))
conn.execute(
text(
"ALTER TABLE projects ADD COLUMN is_archived INTEGER DEFAULT 0 NOT NULL"
)
)
# Set default value for existing rows
conn.execute(text("UPDATE projects SET is_archived = 0 WHERE is_archived IS NULL"))
conn.execute(
text("UPDATE projects SET is_archived = 0 WHERE is_archived IS NULL")
)
print(" ✓ Successfully added 'is_archived' column to projects table")
# 2. Add product_service to companies table
print("\n2. Adding 'product_service' column to companies table...")
if check_column_exists(conn, "companies", "product_service"):
@@ -50,7 +56,7 @@ def upgrade():
else:
conn.execute(text("ALTER TABLE companies ADD COLUMN product_service TEXT"))
print(" ✓ Successfully added 'product_service' column to companies table")
# 3. Add clients to companies table
print("\n3. Adding 'clients' column to companies table...")
if check_column_exists(conn, "companies", "clients"):
@@ -58,15 +64,17 @@ def upgrade():
else:
conn.execute(text("ALTER TABLE companies ADD COLUMN clients TEXT"))
print(" ✓ Successfully added 'clients' column to companies table")
# 4. Add linkedin to investor_members table
print("\n4. Adding 'linkedin' column to investor_members table...")
if check_column_exists(conn, "investor_members", "linkedin"):
print(" ✓ Column 'linkedin' already exists. Skipping.")
else:
conn.execute(text("ALTER TABLE investor_members ADD COLUMN linkedin VARCHAR"))
conn.execute(
text("ALTER TABLE investor_members ADD COLUMN linkedin VARCHAR")
)
print(" ✓ Successfully added 'linkedin' column to investor_members table")
print("\n" + "=" * 60)
print("Migration completed successfully!")
@@ -75,7 +83,7 @@ def downgrade():
"""Remove added columns from tables"""
print("Running downgrade: Remove feedback fixes fields")
print("=" * 60)
# Note: SQLite doesn't support DROP COLUMN directly
print("\nWarning: SQLite doesn't support DROP COLUMN directly.")
print("To remove these columns, you would need to:")
@@ -91,20 +99,19 @@ def downgrade():
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Run database migration")
parser.add_argument(
"direction",
choices=["upgrade", "downgrade"],
default="upgrade",
nargs="?",
help="Migration direction (default: upgrade)"
help="Migration direction (default: upgrade)",
)
args = parser.parse_args()
if args.direction == "upgrade":
upgrade()
else:
downgrade()