"""PostgreSQL runtime and migration helpers.""" from __future__ import annotations import os from pathlib import Path import psycopg REPO_ROOT = Path(__file__).resolve().parents[2] MIGRATIONS_DIR = REPO_ROOT / "app" / "db" / "migrations" _MIGRATIONS_CHECKED = False class DbRow(dict): """Mapping row that also supports legacy positional reads during cutover.""" __slots__ = ("_values",) def __init__(self, names, values): super().__init__(zip(names, values)) self._values = tuple(values) def __getitem__(self, key): if isinstance(key, int): return self._values[key] return super().__getitem__(key) def dict_index_row(cursor): if cursor.description is None: return None names = [col.name for col in cursor.description] def make_row(values): return DbRow(names, values) return make_row def get_database_url(database_url: str | None = None) -> str: url = database_url or os.getenv("DATABASE_URL", "") if not url: raise RuntimeError("DATABASE_URL is required for AlphaX PostgreSQL runtime") return url def connect(database_url: str | None = None, *, autocommit: bool = False) -> psycopg.Connection: return psycopg.connect(get_database_url(database_url), autocommit=autocommit, row_factory=dict_index_row) def apply_migrations(database_url: str | None = None, migrations_dir: Path = MIGRATIONS_DIR) -> list[str]: files = sorted(migrations_dir.glob("*.sql")) if not files: raise RuntimeError(f"No PostgreSQL migration files found in {migrations_dir}") applied_now: list[str] = [] with connect(database_url) as conn: conn.execute( """ CREATE TABLE IF NOT EXISTS schema_migrations ( version TEXT PRIMARY KEY, applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ) """ ) applied = {row["version"] for row in conn.execute("SELECT version FROM schema_migrations").fetchall()} for path in files: version = path.name if version in applied: continue with conn.transaction(): conn.execute(path.read_text(encoding="utf-8")) conn.execute("INSERT INTO schema_migrations(version) VALUES (%s)", (version,)) applied_now.append(version) return applied_now def ensure_migrations_once() -> None: global _MIGRATIONS_CHECKED if _MIGRATIONS_CHECKED: return apply_migrations() _MIGRATIONS_CHECKED = True def table_columns(table_name: str) -> set[str]: with connect() as conn: rows = conn.execute( """ SELECT column_name FROM information_schema.columns WHERE table_schema='public' AND table_name=%s """, (table_name,), ).fetchall() return {row["column_name"] for row in rows}