alphax/tests/conftest.py
2026-05-22 23:17:37 +08:00

297 lines
8.6 KiB
Python

"""PostgreSQL test harness.
The application runtime is PostgreSQL-only. A number of older tests still use
``sqlite3.connect(tmp_path)`` as a shorthand for "give me an isolated DB".
This file keeps that shorthand test-local by routing it to the PostgreSQL test
database and truncating business tables before every test.
"""
from __future__ import annotations
import re
import sqlite3
import sys
import os
from datetime import datetime
from pathlib import Path
from urllib.parse import urlsplit, urlunsplit
import pytest
import psycopg
# Two levels above this file — presumably the directory that contains the
# ``app`` package imported below. Put it at the front of sys.path (once) so the
# tests can import project modules without installing the project.
PROJECT_DIR = Path(__file__).resolve().parents[1]
if str(PROJECT_DIR) not in sys.path:
    sys.path[:0] = [str(PROJECT_DIR)]
def _test_database_url() -> str:
    """Resolve the PostgreSQL URL the test suite should connect to.

    ``ALPHAX_TEST_DATABASE_URL`` wins when it is set to a non-blank value
    (surrounding whitespace is stripped). Otherwise ``DATABASE_URL`` is reused
    with its database name suffixed by ``_test`` unless it already ends in
    ``_test``. A URL without a database name falls back to ``alphax_test``.
    Scheme, credentials, host, query string and fragment are preserved.

    Raises:
        RuntimeError: if neither variable holds a non-blank value.
    """
    override = os.getenv("ALPHAX_TEST_DATABASE_URL", "").strip()
    if override:
        return override

    base_url = os.getenv("DATABASE_URL", "").strip()
    if not base_url:
        raise RuntimeError("DATABASE_URL is required for PostgreSQL tests")

    scheme, netloc, path, query, fragment = urlsplit(base_url)
    name = path.lstrip("/") or "alphax"
    if not name.endswith("_test"):
        name += "_test"
    return urlunsplit((scheme, netloc, "/" + name, query, fragment))
def _maintenance_url(database_url: str) -> str:
    """Return *database_url* re-pointed at the server's ``postgres`` database.

    Only the path component changes; scheme, credentials, host/port, query
    string and fragment are kept as they are.
    """
    return urlsplit(database_url)._replace(path="/postgres").geturl()
def _ensure_test_database(database_url: str) -> None:
    """Create the database named in *database_url* if it does not exist yet.

    The check and the ``CREATE DATABASE`` run over a connection to the same
    server's ``postgres`` database, because the target database may not exist
    yet. Autocommit is enabled since PostgreSQL does not allow
    ``CREATE DATABASE`` inside a transaction block.

    Raises:
        RuntimeError: if the URL does not name a database.
    """
    from urllib.parse import unquote

    from psycopg import errors as pg_errors
    from psycopg import sql as pg_sql

    # libpq percent-decodes the dbname of a connection URI, so decode it here
    # too; otherwise we could check/create a differently named database than
    # the one the tests later connect to.
    db_name = unquote(urlsplit(database_url).path.lstrip("/"))
    if not db_name:
        raise RuntimeError("Test database URL must include a database name")
    with psycopg.connect(_maintenance_url(database_url), autocommit=True) as conn:
        exists = conn.execute(
            "SELECT 1 FROM pg_database WHERE datname=%s", (db_name,)
        ).fetchone()
        if exists:
            return
        try:
            # Quote the name as an identifier instead of f-string interpolation,
            # so names containing quotes or other special characters work.
            conn.execute(
                pg_sql.SQL("CREATE DATABASE {}").format(pg_sql.Identifier(db_name))
            )
        except pg_errors.DuplicateDatabase:
            # Another test process created it between the existence check and
            # this statement; the database now exists, which is all we need.
            pass
# Resolve the test database URL and make sure that database exists before any
# ``app`` module is imported below.
TEST_DATABASE_URL = _test_database_url()
_ensure_test_database(TEST_DATABASE_URL)
# Point DATABASE_URL at the test database for everything imported from here on.
# NOTE(review): presumably ``app.db`` reads DATABASE_URL at import time, which
# would explain why these imports sit below this assignment rather than at the
# top of the file — verify before reordering.
os.environ["DATABASE_URL"] = TEST_DATABASE_URL
from app.db import altcoin_db, auth_db, scheduler_db, schema
from app.db import postgres_connection
# Legacy tests patch these attributes. Keep the names available only for tests;
# runtime modules no longer read them.
altcoin_db.DB_PATH = ""
auth_db.DB_PATH = ""
scheduler_db.SCHEDULER_DB_PATH = ""
schema.DB_PATH = ""
# Handle on the genuine ``sqlite3.connect``, captured before the
# ``postgres_test_db`` fixture replaces ``sqlite3.connect`` with the
# PostgreSQL-backed shim. Not referenced elsewhere in the code shown here.
_REAL_SQLITE_CONNECT = sqlite3.connect
# Tables whose INSERT statements get ``RETURNING id`` appended by the sqlite3
# compatibility shim (see ``_should_add_returning_id``), so that
# ``cursor.lastrowid`` can be filled in the way sqlite3 would.
_ID_TABLES = set(
    """
    screening_log recommendation price_tracking
    paper_orders paper_trades paper_trade_events
    live_trade_accounts live_order_intents live_order_events
    cron_run_log review_log missed_explosions
    strategy_iteration_log strategy_rule_candidate strategy_failure_pattern
    push_log sentiment_events llm_insights event_news
    onchain_token_map onchain_events onchain_token_metrics onchain_raw_events
    app_user email_verification_code user_session user_subscription
    payment_order pending_registration referral_reward
    user_activity user_watchlist user_saved_observation
    system_reset_log scheduler_manual_trigger
    strategy_runtime_config system_config
    chat_sessions chat_messages chat_user_preferences
    """.split()
)
def _translate_sql(sql: str) -> str:
    """Rewrite SQLite-flavoured SQL into a form PostgreSQL/psycopg accepts.

    * ``?`` placeholders become ``%s``. When the statement uses ``?``
      placeholders, literal ``%`` characters are first doubled to ``%%`` so
      psycopg does not mistake them (e.g. in ``LIKE 'BTC%'``) for placeholders.
      Statements without any ``?`` keep their ``%`` untouched, so queries
      already written with ``%s`` placeholders are not disturbed.
    * ``datetime('now')`` becomes a ``YYYY-MM-DDTHH:MI:SS`` text timestamp
      built with ``to_char(NOW(), ...)``.
    * ``INSERT OR IGNORE`` (any letter case) becomes ``INSERT ... ON CONFLICT
      DO NOTHING``, so duplicate rows are skipped instead of raising, as in
      SQLite. If the statement already has an ``ON CONFLICT`` clause it is
      left alone.
    """
    text = sql
    if "?" in text:
        # DB-API statements use a single paramstyle; with qmark style every
        # '%' is literal and must be escaped for psycopg's %-placeholders.
        text = text.replace("%", "%%")
    text = text.replace("datetime('now')", "to_char(NOW(), 'YYYY-MM-DD\"T\"HH24:MI:SS')")
    text, ignore_count = re.subn(
        r"\bINSERT\s+OR\s+IGNORE\b", "INSERT", text, flags=re.IGNORECASE
    )
    if ignore_count and not re.search(r"\bON\s+CONFLICT\b", text, flags=re.IGNORECASE):
        # Plain INSERT would raise on duplicates; ON CONFLICT DO NOTHING keeps
        # SQLite's "silently skip" behaviour. Drop a trailing ';' first so the
        # clause lands inside the statement.
        text = text.rstrip().rstrip(";").rstrip() + " ON CONFLICT DO NOTHING"
    return text.replace("?", "%s")
def _pragma_table_info(sql: str) -> str:
    """Return the table named by a ``PRAGMA table_info(<table>)`` statement.

    The keyword match is case-insensitive, the table name may be bare or
    wrapped in matching single/double quotes, and surrounding whitespace plus
    one trailing semicolon are tolerated. Any other statement yields ``""``.
    """
    pragma = re.match(
        r"\s*PRAGMA\s+table_info\((['\"]?)([a-zA-Z_][a-zA-Z0-9_]*)\1\)\s*;?\s*$",
        sql,
        re.IGNORECASE,
    )
    if pragma is None:
        return ""
    return pragma.group(2)
def _insert_table(sql: str) -> str:
    """Return the lower-cased target of the first ``INSERT INTO <name>`` in *sql*.

    Only unquoted identifiers are recognised; ``""`` is returned when no such
    clause is found.
    """
    found = re.search(r"insert\s+into\s+([a-zA-Z_][a-zA-Z0-9_]*)", sql, re.IGNORECASE)
    return "" if found is None else found.group(1).lower()
def _should_add_returning_id(sql: str) -> bool:
    """Decide whether ``RETURNING id`` should be appended to *sql*.

    True only for statements that start with ``INSERT``, do not already carry
    a RETURNING clause, and target a table listed in ``_ID_TABLES``.
    """
    lowered = sql.strip().lower()
    if not lowered.startswith("insert"):
        return False
    # Match RETURNING as a whole word wherever it appears (after a newline, at
    # the very end, ...). A plain `" returning " in lowered` test missed those
    # forms, and a second RETURNING clause was then appended.
    if re.search(r"\breturning\b", lowered):
        return False
    return _insert_table(lowered) in _ID_TABLES
class _PgCursorCompat:
    """Wrap a database cursor with the subset of the sqlite3 cursor API tests use.

    ``lastrowid`` is supplied by the creator (the connection shim passes the id
    read back from an appended ``RETURNING id``, or None). Everything else
    delegates to the wrapped cursor: fetching, iteration, ``rowcount``,
    ``description`` and ``close``.
    """

    def __init__(self, cursor, lastrowid=None):
        self._cursor = cursor
        self.lastrowid = lastrowid

    @property
    def rowcount(self):
        """Row count reported by the wrapped cursor for its last statement."""
        return self._cursor.rowcount

    @property
    def description(self):
        """Column description reported by the wrapped cursor."""
        return self._cursor.description

    def fetchone(self):
        return self._cursor.fetchone()

    def fetchmany(self, size=None):
        """Fetch up to *size* rows; without *size* the wrapped cursor's default applies."""
        if size is None:
            return self._cursor.fetchmany()
        return self._cursor.fetchmany(size)

    def fetchall(self):
        return self._cursor.fetchall()

    def close(self):
        self._cursor.close()

    def __iter__(self):
        return iter(self._cursor)
class _PgSqliteCompatConnection:
    """Minimal ``sqlite3.Connection`` stand-in backed by a PostgreSQL connection.

    Statement routing:

    * ``PRAGMA table_info(<table>)`` is answered from
      ``information_schema.columns`` (``pk`` is always reported as 0).
    * Everything else goes through ``_translate_sql``. INSERTs accepted by
      ``_should_add_returning_id`` get ``RETURNING id`` appended, so the
      returned cursor's ``lastrowid`` can be filled in.

    ``cursor()`` returns the connection itself. The connection therefore also
    mirrors part of the cursor API (``fetchone``, ``fetchall``, ``lastrowid``,
    ``rowcount``) for the most recent ``execute``/``executemany`` call, which
    keeps the ``cur = conn.cursor(); cur.execute(...); cur.fetchall()`` idiom
    working.
    """

    # Accepted for sqlite3 API compatibility only; nothing in this class reads it.
    row_factory = None

    # Raw cursor and lastrowid of the most recent execute()/executemany().
    # The class-level defaults represent "nothing executed yet".
    _last_cursor = None
    _last_rowid = None

    # Emulates the columns of SQLite's PRAGMA table_info from information_schema.
    _TABLE_INFO_SQL = """
        SELECT
            (ordinal_position - 1)::int AS cid,
            column_name AS name,
            data_type AS type,
            CASE WHEN is_nullable='NO' THEN 1 ELSE 0 END AS notnull,
            column_default AS dflt_value,
            0 AS pk
        FROM information_schema.columns
        WHERE table_schema='public' AND table_name=%s
        ORDER BY ordinal_position
    """

    def __init__(self):
        self._conn = postgres_connection.connect()
        self.row_factory = None

    def execute(self, sql, params=()):
        """Execute one statement and return a ``_PgCursorCompat`` for its result."""
        text = str(sql)
        table_name = _pragma_table_info(text)
        if table_name:
            cur = self._conn.execute(self._TABLE_INFO_SQL, (table_name,))
            return self._remember(cur)

        translated = _translate_sql(text)
        wants_id = _should_add_returning_id(translated)
        if wants_id:
            translated = translated.rstrip().rstrip(";") + " RETURNING id"
        cur = self._conn.execute(translated, tuple(params or ()))
        lastrowid = self._returned_id(cur) if wants_id else None
        return self._remember(cur, lastrowid)

    def executemany(self, sql, seq_of_params):
        """Execute *sql* (after ``_translate_sql``) once per parameter set."""
        translated = _translate_sql(str(sql))
        run_many = getattr(self._conn, "executemany", None)
        if run_many is not None:
            cur = run_many(translated, seq_of_params)
        else:
            # psycopg 3 connections expose executemany() only on cursors.
            cur = self._conn.cursor()
            cur.executemany(translated, seq_of_params)
        return self._remember(cur)

    @staticmethod
    def _returned_id(cur):
        """Best-effort read of the ``id`` produced by an appended ``RETURNING id``.

        Returns None when no row came back or the row cannot be read.
        """
        try:
            row = cur.fetchone()
            return row["id"] if row else None
        except (psycopg.Error, KeyError, IndexError, TypeError):
            return None

    def _remember(self, cur, lastrowid=None):
        """Record *cur* as the latest result and wrap it for the caller."""
        self._last_cursor = cur
        self._last_rowid = lastrowid
        return _PgCursorCompat(cur, lastrowid=lastrowid)

    # Cursor-API mirror: cursor() returns this object, so these read from the
    # latest execute()/executemany() result.

    @property
    def lastrowid(self):
        return self._last_rowid

    @property
    def rowcount(self):
        return -1 if self._last_cursor is None else self._last_cursor.rowcount

    def fetchone(self):
        return None if self._last_cursor is None else self._last_cursor.fetchone()

    def fetchall(self):
        return [] if self._last_cursor is None else self._last_cursor.fetchall()

    # Transaction / lifecycle.

    def commit(self):
        self._conn.commit()

    def rollback(self):
        self._conn.rollback()

    def close(self):
        self._conn.close()

    def cursor(self):
        return self

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        """Commit on success, roll back on error, then always close."""
        try:
            if exc_type:
                self.rollback()
            else:
                self.commit()
        finally:
            # Close even if commit/rollback raised, so the connection is not leaked.
            self.close()
def _pg_sqlite_connect(*_args, **_kwargs):
    """Drop-in for ``sqlite3.connect``: ignore every argument (paths included)
    and return a fresh PostgreSQL-backed ``_PgSqliteCompatConnection``."""
    connection = _PgSqliteCompatConnection()
    return connection
def _pg_compat_connect(*_args, **_kwargs):
    """Drop-in for the project's ``get_conn`` helpers: ignore every argument
    and return a fresh PostgreSQL-backed ``_PgSqliteCompatConnection``."""
    connection = _PgSqliteCompatConnection()
    return connection
def _business_tables() -> list[str]:
    """List every table in the ``public`` schema except ``schema_migrations``.

    Names are ordered by table name. A dedicated connection is opened for the
    query and is always closed, even if the query fails.
    """
    query = (
        "\n"
        "SELECT tablename\n"
        "FROM pg_tables\n"
        "WHERE schemaname='public' AND tablename != 'schema_migrations'\n"
        "ORDER BY tablename\n"
    )
    conn = postgres_connection.connect()
    try:
        cursor = conn.execute(query)
        names = [record["tablename"] for record in cursor.fetchall()]
    finally:
        conn.close()
    return names
def _truncate_business_tables():
    """Apply migrations, then empty every table reported by ``_business_tables``.

    One ``TRUNCATE ... RESTART IDENTITY CASCADE`` statement is issued, so
    identity sequences are reset to their start values and foreign-key
    references between the listed tables do not block the truncation.
    Nothing is executed when there are no tables.
    """
    postgres_connection.apply_migrations()
    tables = _business_tables()
    if not tables:
        return
    # Quote each name as a PostgreSQL identifier: wrap it in double quotes and
    # double any embedded double quote (plain '"{table}"' broke on such names).
    quoted = ", ".join('"' + table.replace('"', '""') + '"' for table in tables)
    conn = postgres_connection.connect()
    try:
        conn.execute(f"TRUNCATE TABLE {quoted} RESTART IDENTITY CASCADE")
        conn.commit()
    finally:
        conn.close()
@pytest.fixture(autouse=True)
def postgres_test_db(monkeypatch):
    """Autouse fixture routing every DB entry point to the PostgreSQL test DB.

    Before each test it:

    * sets ``ALPHAX_BOOTSTRAP_ADMIN=0`` in the environment;
    * replaces ``sqlite3.connect`` with ``_pg_sqlite_connect`` (and exposes the
      patched ``sqlite3`` module as ``altcoin_db.sqlite3``);
    * replaces ``get_conn`` in ``altcoin_db``, ``auth_db`` and ``schema``, and
      in each already-imported optional module below that defines one;
    * truncates the business tables.

    ``monkeypatch`` undoes the environment and attribute patches afterwards.
    """
    monkeypatch.setenv("ALPHAX_BOOTSTRAP_ADMIN", "0")

    monkeypatch.setattr(sqlite3, "connect", _pg_sqlite_connect)
    monkeypatch.setattr(altcoin_db, "sqlite3", sqlite3, raising=False)

    for core_module in (altcoin_db, auth_db, schema):
        monkeypatch.setattr(core_module, "get_conn", _pg_compat_connect)

    # Only modules already present in sys.modules are patched.
    optional_modules = (
        "app.db.analytics",
        "app.db.llm_insights",
        "app.db.recommendation_queries",
        "app.db.review_queries",
        "app.services.llm_insights",
        "app.services.review_engine",
        "app.analysis.reverse_analysis",
        "app.web.routes_content",
        "app.web.routes_strategy",
    )
    for loaded in (sys.modules.get(name) for name in optional_modules):
        if loaded is None or not hasattr(loaded, "get_conn"):
            continue
        monkeypatch.setattr(loaded, "get_conn", _pg_compat_connect, raising=False)

    _truncate_business_tables()
    yield
@pytest.fixture
def pg_conn():
    """Yield a connection straight from ``postgres_connection.connect()``.

    This bypasses the sqlite3 compatibility shim. The connection is closed
    after the test, whether it passed or failed.
    """
    connection = postgres_connection.connect()
    try:
        yield connection
    finally:
        connection.close()