alphax/app/db/postgres_connection.py
2026-05-16 14:52:10 +08:00

98 lines
2.9 KiB
Python

"""PostgreSQL runtime and migration helpers."""
from __future__ import annotations
import os
from pathlib import Path
import psycopg
# Project root: parents[2] of app/db/<this file>, i.e. the directory that contains app/.
REPO_ROOT = Path(__file__).resolve().parents[2]
# Default directory scanned by apply_migrations() for *.sql migration files.
MIGRATIONS_DIR = REPO_ROOT / "app" / "db" / "migrations"
# Module-level flag flipped by ensure_migrations_once() after a successful migration run.
_MIGRATIONS_CHECKED = False
class DbRow(dict):
    """Mapping row that also supports legacy positional reads during cutover.

    Behaves as a ``dict`` keyed by column name, but integer and slice keys
    index into the original value tuple, so code written against tuple rows
    (``row[0]``, ``row[-1]``, ``row[1:]``) keeps working.

    Only ``[]`` is positional-aware. Iteration, ``len()``, ``in`` and
    ``.get()`` keep plain ``dict`` semantics (column names), so tuple-style
    unpacking ``a, b = row`` yields column *names*, not values. When a result
    has duplicate column names, the mapping keeps only the last value for that
    name, while positional access still sees every column.
    """

    __slots__ = ("_values",)

    def __init__(self, names, values):
        # Materialise first: zip() would otherwise consume a one-shot iterator
        # and leave the positional tuple empty.
        values = tuple(values)
        super().__init__(zip(names, values))
        self._values = values

    def __getitem__(self, key):
        # Integer and slice keys are always positional, mirroring tuple rows;
        # anything else is a column-name lookup.
        if isinstance(key, (int, slice)):
            return self._values[key]
        return super().__getitem__(key)
def dict_index_row(cursor):
    """Row factory (passed as ``row_factory`` by ``connect``) producing ``DbRow`` objects.

    Returns ``None`` when the cursor has no ``description`` (the statement
    produced no result columns). Otherwise returns a callable that wraps each
    row's values in a ``DbRow`` keyed by the result's column names.
    """
    description = cursor.description
    if description is not None:
        column_names = [column.name for column in description]
        return lambda values: DbRow(column_names, values)
    return None
def get_database_url(database_url: str | None = None) -> str:
url = database_url or os.getenv("DATABASE_URL", "")
if not url:
raise RuntimeError("DATABASE_URL is required for AlphaX PostgreSQL runtime")
return url
def connect(database_url: str | None = None, *, autocommit: bool = False) -> psycopg.Connection:
    """Open a psycopg connection whose rows are ``DbRow`` objects.

    The URL is resolved by ``get_database_url`` (explicit argument, else the
    ``DATABASE_URL`` environment variable). ``autocommit`` is passed straight
    through to ``psycopg.connect``.
    """
    dsn = get_database_url(database_url)
    return psycopg.connect(dsn, autocommit=autocommit, row_factory=dict_index_row)
def apply_migrations(database_url: str | None = None, migrations_dir: Path = MIGRATIONS_DIR) -> list[str]:
    """Apply every pending ``*.sql`` migration found in ``migrations_dir``.

    Each migration's version is its file name. Applied versions are recorded in
    the ``schema_migrations`` table, which is created on first use. Files whose
    name is already recorded are skipped. Files are applied in sorted path
    order (lexicographic), so numeric prefixes should be zero-padded
    (``002_x.sql`` sorts before ``010_y.sql``, but ``2_x.sql`` would not).

    The connection runs in autocommit mode so that each ``conn.transaction()``
    block is a real top-level transaction. A migration's SQL and its
    ``schema_migrations`` row commit together. A failing file rolls back only
    itself, and earlier files from the same run stay committed. Without
    autocommit, the first ``execute`` opens an implicit transaction, each block
    degrades to a savepoint, and nothing commits until the whole run ends.

    Args:
        database_url: Explicit connection URL; falls back to ``DATABASE_URL``.
        migrations_dir: Directory holding the ``*.sql`` migration files.

    Returns:
        File names applied by this call, in application order (empty when
        everything was already applied).

    Raises:
        RuntimeError: If ``migrations_dir`` contains no ``*.sql`` files.
    """
    files = sorted(migrations_dir.glob("*.sql"))
    if not files:
        raise RuntimeError(f"No PostgreSQL migration files found in {migrations_dir}")
    applied_now: list[str] = []
    with connect(database_url, autocommit=True) as conn:
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS schema_migrations (
                version TEXT PRIMARY KEY,
                applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
            )
            """
        )
        applied = {row["version"] for row in conn.execute("SELECT version FROM schema_migrations").fetchall()}
        for path in files:
            version = path.name
            if version in applied:
                continue
            # NOTE(review): nothing serialises concurrent callers (e.g. several
            # workers starting at once). If that can happen, consider holding a
            # pg_advisory_lock around this loop and re-reading `applied` under it.
            with conn.transaction():
                conn.execute(path.read_text(encoding="utf-8"))
                conn.execute("INSERT INTO schema_migrations(version) VALUES (%s)", (version,))
            applied_now.append(version)
    return applied_now
def ensure_migrations_once() -> None:
    """Run ``apply_migrations()`` until it has succeeded once in this process.

    The module-level ``_MIGRATIONS_CHECKED`` flag is set only after
    ``apply_migrations()`` returns, so a failed attempt is retried on the next
    call. The check-then-set is not guarded by a lock.
    """
    global _MIGRATIONS_CHECKED
    if not _MIGRATIONS_CHECKED:
        apply_migrations()
        _MIGRATIONS_CHECKED = True
def table_columns(
    table_name: str,
    *,
    schema: str = "public",
    database_url: str | None = None,
) -> set[str]:
    """Return the column names of ``schema.table_name``.

    Queries ``information_schema.columns`` over a fresh connection. An
    unknown table yields an empty set rather than an error. Matching is exact
    string equality, so pass names as stored in the catalog (unquoted
    identifiers are stored lower-case). ``information_schema`` only lists
    columns the current role has some privilege on.

    Args:
        table_name: Table to inspect.
        schema: Schema to search; defaults to ``"public"`` (the previous
            hard-coded value).
        database_url: Explicit connection URL; falls back to ``DATABASE_URL``.

    Returns:
        The set of column names (unordered).
    """
    with connect(database_url) as conn:
        rows = conn.execute(
            """
            SELECT column_name
            FROM information_schema.columns
            WHERE table_schema = %s AND table_name = %s
            """,
            (schema, table_name),
        ).fetchall()
    return {row["column_name"] for row in rows}