98 lines
2.9 KiB
Python
98 lines
2.9 KiB
Python
"""PostgreSQL runtime and migration helpers."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
from pathlib import Path
|
|
|
|
import psycopg
|
|
|
|
# Repository root: the third ancestor directory of this module's resolved path.
REPO_ROOT = Path(__file__).resolve().parents[2]
# Default location of the ``*.sql`` migration files applied by apply_migrations().
MIGRATIONS_DIR = REPO_ROOT.joinpath("app", "db", "migrations")
# Set to True by ensure_migrations_once() after a successful migration run.
_MIGRATIONS_CHECKED = False
|
|
|
|
|
|
class DbRow(dict):
    """Mapping row that also supports legacy positional reads during cutover.

    Column names map to values as in a regular ``dict`` (``row["version"]``),
    while integer and slice keys index the raw value sequence the way a tuple
    row would (``row[0]``, ``row[-1]``, ``row[1:3]``), so code written against
    tuple rows keeps working.

    NOTE(review): iteration, ``len``, ``in`` and ``get`` keep ``dict``
    semantics (column names), not tuple semantics -- e.g. ``a, b = row``
    unpacks column names, not values.
    """

    __slots__ = ("_values",)

    def __init__(self, names, values):
        # Materialize once: ``values`` is used twice (for the name->value
        # mapping and for positional access), and a one-shot iterator would
        # otherwise be exhausted by ``zip`` and leave ``_values`` empty.
        values = tuple(values)
        super().__init__(zip(names, values))
        self._values = values

    def __getitem__(self, key):
        # Integer and slice keys are positional (legacy tuple-row access);
        # any other key is a column-name lookup.
        if isinstance(key, (int, slice)):
            return self._values[key]
        return super().__getitem__(key)
|
|
|
|
|
|
def dict_index_row(cursor):
    """Row factory (used as psycopg's ``row_factory`` in ``connect``) yielding DbRow objects.

    Column names are read from ``cursor.description`` once, here, and captured
    by the returned row maker, which wraps each raw value sequence in a
    ``DbRow``. When ``cursor.description`` is ``None`` (the statement produced
    no result set) no row maker is built and ``None`` is returned.
    """
    description = cursor.description
    if description is not None:
        column_names = [column.name for column in description]

        def build_row(values):
            return DbRow(column_names, values)

        return build_row
    return None
|
|
|
|
|
|
def get_database_url(database_url: str | None = None) -> str:
    """Resolve the PostgreSQL connection URL.

    Args:
        database_url: Explicit URL. When falsy (``None`` or ``""``), the
            ``DATABASE_URL`` environment variable is used instead.

    Returns:
        The non-empty connection URL.

    Raises:
        RuntimeError: If neither source provides a non-empty URL.
    """
    resolved = database_url if database_url else os.environ.get("DATABASE_URL", "")
    if resolved:
        return resolved
    raise RuntimeError("DATABASE_URL is required for AlphaX PostgreSQL runtime")
|
|
|
|
|
|
def connect(database_url: str | None = None, *, autocommit: bool = False) -> psycopg.Connection:
    """Open a psycopg connection whose rows are built by ``dict_index_row``.

    Args:
        database_url: Explicit connection URL; when falsy, ``DATABASE_URL``
            from the environment is used (see ``get_database_url``).
        autocommit: Passed straight through to ``psycopg.connect``.

    Raises:
        RuntimeError: If no database URL is available.
    """
    conninfo = get_database_url(database_url)
    return psycopg.connect(
        conninfo,
        autocommit=autocommit,
        row_factory=dict_index_row,
    )
|
|
|
|
|
|
def apply_migrations(database_url: str | None = None, migrations_dir: Path = MIGRATIONS_DIR) -> list[str]:
    """Apply every pending ``*.sql`` migration found in ``migrations_dir``.

    Files run in sorted filename order, so numeric prefixes should be
    zero-padded (``0002_...`` sorts before ``0010_...``, ``2_...`` does not).
    Each file is recorded by filename in the ``schema_migrations`` table,
    which is created on first use; files already recorded are skipped.
    Every file runs in its own transaction together with its bookkeeping row,
    so migrations applied before a failing one stay committed.

    NOTE(review): concurrent callers (e.g. several processes starting at once)
    are not serialized and may race on the same pending file.

    Args:
        database_url: Explicit connection URL; falls back to ``DATABASE_URL``.
        migrations_dir: Directory containing the ``.sql`` files.

    Returns:
        Filenames of the migrations applied by this call, in order.

    Raises:
        RuntimeError: If ``migrations_dir`` has no ``.sql`` files, or no
            database URL is configured.
    """
    files = sorted(migrations_dir.glob("*.sql"))
    if not files:
        raise RuntimeError(f"No PostgreSQL migration files found in {migrations_dir}")

    applied_now: list[str] = []
    # Autocommit so each ``conn.transaction()`` below is a real BEGIN/COMMIT.
    # On a non-autocommit connection the first execute() opens an implicit
    # transaction, every ``transaction()`` block degrades to a savepoint, and a
    # failure in a later file rolls back all migrations applied earlier in the run.
    with connect(database_url, autocommit=True) as conn:
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS schema_migrations (
                version TEXT PRIMARY KEY,
                applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
            )
            """
        )
        applied = {row["version"] for row in conn.execute("SELECT version FROM schema_migrations").fetchall()}
        for path in files:
            version = path.name
            if version in applied:
                continue
            # The script and its schema_migrations row commit together or not at all.
            # The script is sent without parameters, which is what lets a file
            # contain several semicolon-separated statements.
            with conn.transaction():
                conn.execute(path.read_text(encoding="utf-8"))
                conn.execute("INSERT INTO schema_migrations(version) VALUES (%s)", (version,))
            applied_now.append(version)
    return applied_now
|
|
|
|
|
|
def ensure_migrations_once() -> None:
    """Run ``apply_migrations()`` with default arguments until it succeeds once.

    The module-level ``_MIGRATIONS_CHECKED`` flag is only set after
    ``apply_migrations()`` returns, so if it raises, the next call retries.
    Subsequent calls after a success are no-ops.

    NOTE(review): the check-then-set on the flag is not synchronized.
    """
    global _MIGRATIONS_CHECKED
    if not _MIGRATIONS_CHECKED:
        apply_migrations()
        _MIGRATIONS_CHECKED = True
|
|
|
|
|
|
def table_columns(
    table_name: str,
    *,
    schema: str = "public",
    database_url: str | None = None,
) -> set[str]:
    """Return the column names of ``schema.table_name``.

    Args:
        table_name: Unqualified table name, compared exactly against
            ``information_schema.columns.table_name``.
        schema: Schema to search; defaults to ``public``.
        database_url: Explicit connection URL, as accepted by ``connect``;
            falls back to ``DATABASE_URL``.

    Returns:
        The set of column names. It is empty when the query matches no rows,
        e.g. for a table that does not exist or is not visible to the
        connected role.
    """
    with connect(database_url) as conn:
        rows = conn.execute(
            """
            SELECT column_name
            FROM information_schema.columns
            WHERE table_schema=%s AND table_name=%s
            """,
            (schema, table_name),
        ).fetchall()
    return {row["column_name"] for row in rows}
|