"""PostgreSQL test harness. The application runtime is PostgreSQL-only. A number of older tests still use ``sqlite3.connect(tmp_path)`` as a shorthand for "give me an isolated DB". This file keeps that shorthand test-local by routing it to the PostgreSQL test database and truncating business tables before every test. """ from __future__ import annotations import re import sqlite3 import sys import os from datetime import datetime from pathlib import Path from urllib.parse import urlsplit, urlunsplit import pytest import psycopg PROJECT_DIR = Path(__file__).resolve().parents[1] if str(PROJECT_DIR) not in sys.path: sys.path.insert(0, str(PROJECT_DIR)) def _test_database_url() -> str: explicit = os.getenv("ALPHAX_TEST_DATABASE_URL", "").strip() if explicit: return explicit runtime_url = os.getenv("DATABASE_URL", "").strip() if not runtime_url: raise RuntimeError("DATABASE_URL is required for PostgreSQL tests") parts = urlsplit(runtime_url) db_name = parts.path.lstrip("/") or "alphax" test_name = db_name if db_name.endswith("_test") else f"{db_name}_test" return urlunsplit((parts.scheme, parts.netloc, f"/{test_name}", parts.query, parts.fragment)) def _maintenance_url(database_url: str) -> str: parts = urlsplit(database_url) return urlunsplit((parts.scheme, parts.netloc, "/postgres", parts.query, parts.fragment)) def _ensure_test_database(database_url: str) -> None: parts = urlsplit(database_url) db_name = parts.path.lstrip("/") if not db_name: raise RuntimeError("Test database URL must include a database name") with psycopg.connect(_maintenance_url(database_url), autocommit=True) as conn: exists = conn.execute("SELECT 1 FROM pg_database WHERE datname=%s", (db_name,)).fetchone() if not exists: conn.execute(f'CREATE DATABASE "{db_name}"') TEST_DATABASE_URL = _test_database_url() _ensure_test_database(TEST_DATABASE_URL) os.environ["DATABASE_URL"] = TEST_DATABASE_URL from app.db import altcoin_db, auth_db, scheduler_db, schema from app.db import postgres_connection # 
# Legacy tests patch these attributes. Keep the names available only for tests;
# runtime modules no longer read them.
altcoin_db.DB_PATH = ""
auth_db.DB_PATH = ""
scheduler_db.SCHEDULER_DB_PATH = ""
schema.DB_PATH = ""

# Captured at import time, before the autouse fixture below replaces
# ``sqlite3.connect`` with the PostgreSQL-backed shim.
_REAL_SQLITE_CONNECT = sqlite3.connect

# Tables whose INSERTs get ``RETURNING id`` appended so ``lastrowid`` can be
# emulated. Every table listed here must have an ``id`` column.
_ID_TABLES = {
    "screening_log",
    "recommendation",
    "price_tracking",
    "paper_orders",
    "paper_trades",
    "paper_trade_events",
    "live_account_snapshots",
    "live_trade_accounts",
    "live_order_intents",
    "live_order_events",
    "cron_run_log",
    "review_log",
    "missed_explosions",
    "strategy_iteration_log",
    "strategy_signals",
    "strategy_rule_candidate",
    "strategy_failure_pattern",
    "push_log",
    "sentiment_events",
    "llm_insights",
    "event_news",
    "onchain_token_map",
    "onchain_events",
    "onchain_token_metrics",
    "onchain_raw_events",
    "app_user",
    "email_verification_code",
    "user_session",
    "user_subscription",
    "payment_order",
    "pending_registration",
    "referral_reward",
    "user_activity",
    "user_watchlist",
    "user_saved_observation",
    "system_reset_log",
    "scheduler_manual_trigger",
    "strategy_runtime_config",
    "system_config",
    "chat_sessions",
    "chat_messages",
    "chat_user_preferences",
}

_INSERT_OR_IGNORE_RE = re.compile(r"\bINSERT\s+OR\s+IGNORE\b", re.IGNORECASE)
_ON_CONFLICT_RE = re.compile(r"\bON\s+CONFLICT\b", re.IGNORECASE)
# Word-boundary match so "...\nRETURNING id" is detected, not only " returning ".
_RETURNING_RE = re.compile(r"\bRETURNING\b", re.IGNORECASE)
_PRAGMA_TABLE_INFO_RE = re.compile(
    r"\s*PRAGMA\s+table_info\((['\"]?)([a-zA-Z_][a-zA-Z0-9_]*)\1\)\s*;?\s*$",
    re.IGNORECASE,
)
# Optional (possibly quoted) schema qualifier, then the (possibly quoted) table name.
_INSERT_TABLE_RE = re.compile(
    r'insert\s+into\s+(?:"?[a-zA-Z_][a-zA-Z0-9_]*"?\s*\.\s*)?"?([a-zA-Z_][a-zA-Z0-9_]*)"?',
    re.IGNORECASE,
)

# Emulates the columns of SQLite's ``PRAGMA table_info`` from
# information_schema. ``pk`` is always reported as 0.
_TABLE_INFO_SQL = """
    SELECT
        (ordinal_position - 1)::int AS cid,
        column_name AS name,
        data_type AS type,
        CASE WHEN is_nullable='NO' THEN 1 ELSE 0 END AS notnull,
        column_default AS dflt_value,
        0 AS pk
    FROM information_schema.columns
    WHERE table_schema='public' AND table_name=%s
    ORDER BY ordinal_position
"""


def _translate_sql(sql: str) -> str:
    """Rewrite the SQLite-isms used by legacy tests into PostgreSQL SQL.

    - ``datetime('now')`` becomes an ISO-8601 text timestamp built from ``NOW()``.
    - ``INSERT OR IGNORE`` becomes ``INSERT ... ON CONFLICT DO NOTHING`` (the
      closest PostgreSQL equivalent; it only skips unique/exclusion conflicts,
      not NOT NULL/CHECK failures the way SQLite does).
    - Every ``?`` becomes psycopg's ``%s`` placeholder. This is a blind textual
      replacement, so it also hits ``?`` inside string literals.
    """
    text = sql.replace("datetime('now')", "to_char(NOW(), 'YYYY-MM-DD\"T\"HH24:MI:SS')")
    if _INSERT_OR_IGNORE_RE.search(text):
        text = _INSERT_OR_IGNORE_RE.sub("INSERT", text)
        if not _ON_CONFLICT_RE.search(text):
            # Must precede any RETURNING clause, which execute() appends later.
            text = text.rstrip().rstrip(";") + " ON CONFLICT DO NOTHING"
    return text.replace("?", "%s")


def _pragma_table_info(sql: str) -> str:
    """Return the table name of a ``PRAGMA table_info(...)`` statement, else ``""``."""
    match = _PRAGMA_TABLE_INFO_RE.match(sql)
    return match.group(2) if match else ""


def _insert_table(sql: str) -> str:
    """Return the lower-cased target table of the first ``INSERT INTO``, else ``""``."""
    match = _INSERT_TABLE_RE.search(sql)
    return match.group(1).lower() if match else ""


def _should_add_returning_id(sql: str) -> bool:
    """Return True for an INSERT into an ``_ID_TABLES`` table lacking RETURNING."""
    lowered = sql.strip().lower()
    if not lowered.startswith("insert") or _RETURNING_RE.search(lowered):
        return False
    return _insert_table(lowered) in _ID_TABLES


def _bind_params(params):
    """Normalize query parameters for psycopg.

    Mappings (named parameters) are passed through unchanged; converting them
    with ``tuple()`` would keep only their keys. Anything else becomes a tuple,
    with ``None``/empty meaning "no parameters".
    """
    from collections.abc import Mapping

    if isinstance(params, Mapping):
        return params
    return tuple(params or ())


def _row_id(row):
    """Return the ``id`` value of a RETURNING row, or None if it can't be read.

    Accepts mapping rows (``row["id"]``) and sequence rows (first column).
    """
    if row is None:
        return None
    try:
        return row["id"]
    except (KeyError, TypeError, IndexError):
        pass
    try:
        return row[0]
    except (KeyError, TypeError, IndexError):
        return None


class _PgCursorCompat:
    """Wrap a psycopg cursor and add a sqlite-style ``lastrowid`` attribute.

    Attributes not defined here (``rowcount``, ``description``, ``fetchmany``,
    ``close``, ...) are delegated to the wrapped cursor.
    """

    def __init__(self, cursor, lastrowid=None):
        self._cursor = cursor
        self.lastrowid = lastrowid

    def fetchone(self):
        return self._cursor.fetchone()

    def fetchall(self):
        return self._cursor.fetchall()

    def __iter__(self):
        return iter(self._cursor)

    def __getattr__(self, name):
        # Only reached for attributes missing on the wrapper. Guard ``_cursor``
        # itself so a half-initialized instance fails cleanly, not recursively.
        if name == "_cursor":
            raise AttributeError(name)
        return getattr(self._cursor, name)


class _PgSqliteCompatCursor:
    """``sqlite3.Cursor``-shaped handle returned by the connection's ``cursor()``.

    ``execute``/``executemany`` run on the owning connection and remember the
    result, so ``cur.execute(...)`` followed by ``cur.fetchone()`` or
    ``cur.lastrowid`` works. Closing it does not close the connection.
    """

    def __init__(self, connection):
        self.connection = connection
        self._result = None

    def execute(self, sql, params=()):
        self._result = self.connection.execute(sql, params)
        return self

    def executemany(self, sql, seq_of_params):
        self._result = self.connection.executemany(sql, seq_of_params)
        return self

    @property
    def lastrowid(self):
        return None if self._result is None else self._result.lastrowid

    @property
    def rowcount(self):
        return -1 if self._result is None else self._result.rowcount

    @property
    def description(self):
        return None if self._result is None else self._result.description

    def fetchone(self):
        return None if self._result is None else self._result.fetchone()

    def fetchall(self):
        return [] if self._result is None else self._result.fetchall()

    def __iter__(self):
        return iter(()) if self._result is None else iter(self._result)

    def close(self):
        self._result = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()


class _PgSqliteCompatConnection:
    """sqlite3-``Connection``-shaped wrapper around ``postgres_connection.connect()``.

    ``row_factory`` is accepted for compatibility but never consulted; rows
    come back in whatever shape the PostgreSQL connection produces (this file
    reads them by column name, e.g. ``row["tablename"]``).
    """

    row_factory = None

    def __init__(self):
        self._conn = postgres_connection.connect()
        self.row_factory = None

    def execute(self, sql, params=()):
        """Execute ``sql`` after SQLite->PostgreSQL translation.

        ``PRAGMA table_info(t)`` is answered from information_schema. INSERTs
        into ``_ID_TABLES`` get ``RETURNING id`` so ``lastrowid`` is populated
        with the last inserted id (as in sqlite3); those returned rows are
        consumed here and are not visible to the caller.
        """
        if not isinstance(sql, str):
            # Composed queries (e.g. psycopg.sql objects) cannot be rewritten
            # textually; str() would turn them into their repr.
            return _PgCursorCompat(self._conn.execute(sql, _bind_params(params)))

        table_name = _pragma_table_info(sql)
        if table_name:
            return _PgCursorCompat(self._conn.execute(_TABLE_INFO_SQL, (table_name,)))

        translated = _translate_sql(sql)
        add_returning_id = _should_add_returning_id(translated)
        if add_returning_id:
            translated = translated.rstrip().rstrip(";") + " RETURNING id"

        cur = self._conn.execute(translated, _bind_params(params))
        lastrowid = None
        if add_returning_id:
            returned = cur.fetchall()
            lastrowid = _row_id(returned[-1]) if returned else None
        return _PgCursorCompat(cur, lastrowid=lastrowid)

    def executemany(self, sql, seq_of_params):
        """Execute ``sql`` (translated) once per parameter set."""
        translated = _translate_sql(str(sql))
        run_many = getattr(self._conn, "executemany", None)
        if run_many is not None:
            cur = run_many(translated, seq_of_params)
        else:
            # psycopg 3 offers executemany() on cursors only, not connections.
            cur = self._conn.cursor()
            cur.executemany(translated, seq_of_params)
        return _PgCursorCompat(cur)

    def commit(self):
        self._conn.commit()

    def rollback(self):
        self._conn.rollback()

    def close(self):
        self._conn.close()

    def cursor(self):
        return _PgSqliteCompatCursor(self)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Unlike sqlite3, leaving the block also closes the connection; the
        # finally ensures that happens even when commit/rollback raises.
        try:
            if exc_type:
                self.rollback()
            else:
                self.commit()
        finally:
            self.close()


def _pg_sqlite_connect(*args, **kwargs):
    """Stand-in for ``sqlite3.connect``; the path/arguments are ignored."""
    return _PgSqliteCompatConnection()


def _pg_compat_connect(*args, **kwargs):
    """Stand-in for the app modules' ``get_conn``; arguments are ignored."""
    return _PgSqliteCompatConnection()


def _business_tables() -> list[str]:
    """Return all ``public`` tables except ``schema_migrations``, sorted by name."""
    conn = postgres_connection.connect()
    try:
        rows = conn.execute(
            """
            SELECT tablename
            FROM pg_tables
            WHERE schemaname='public' AND tablename != 'schema_migrations'
            ORDER BY tablename
            """
        ).fetchall()
        return [row["tablename"] for row in rows]
    finally:
        conn.close()


def _truncate_business_tables():
    """Run migrations, then empty every business table in one TRUNCATE.

    ``RESTART IDENTITY`` restarts owned sequences at their start values and
    ``CASCADE`` also truncates tables referencing them by foreign key.
    """
    postgres_connection.apply_migrations()
    tables = _business_tables()
    if not tables:
        return
    # Quote each identifier, doubling any embedded double quote.
    quoted = ", ".join('"' + table.replace('"', '""') + '"' for table in tables)
    conn = postgres_connection.connect()
    try:
        conn.execute(f"TRUNCATE TABLE {quoted} RESTART IDENTITY CASCADE")
        conn.commit()
    finally:
        conn.close()


@pytest.fixture(autouse=True)
def postgres_test_db(monkeypatch):
    """Route sqlite3/``get_conn`` entry points to PostgreSQL and reset data per test."""
    # NOTE(review): presumably disables admin bootstrapping in the app — confirm.
    monkeypatch.setenv("ALPHAX_BOOTSTRAP_ADMIN", "0")
    monkeypatch.setattr(sqlite3, "connect", _pg_sqlite_connect)
    monkeypatch.setattr(altcoin_db, "sqlite3", sqlite3, raising=False)
    monkeypatch.setattr(altcoin_db, "get_conn", _pg_compat_connect)
    monkeypatch.setattr(auth_db, "get_conn", _pg_compat_connect)
    monkeypatch.setattr(schema, "get_conn", _pg_compat_connect)
    # Only modules already imported at this point are patched; modules that
    # bound ``get_conn`` via ``from ... import`` need their own attribute swapped.
    for module_name in (
        "app.db.analytics",
        "app.db.llm_insights",
        "app.db.recommendation_queries",
        "app.db.review_queries",
        "app.services.llm_insights",
        "app.services.review_engine",
        "app.analysis.reverse_analysis",
        "app.web.routes_content",
        "app.web.routes_strategy",
    ):
        module = sys.modules.get(module_name)
        if module is not None and hasattr(module, "get_conn"):
            monkeypatch.setattr(module, "get_conn", _pg_compat_connect, raising=False)
    _truncate_business_tables()
    yield


@pytest.fixture
def pg_conn():
    """Yield a raw ``postgres_connection.connect()`` connection, closed afterwards."""
    conn = postgres_connection.connect()
    try:
        yield conn
    finally:
        conn.close()