299 lines
8.6 KiB
Python
299 lines
8.6 KiB
Python
"""PostgreSQL test harness.
|
|
|
|
The application runtime is PostgreSQL-only. A number of older tests still use
|
|
``sqlite3.connect(tmp_path)`` as a shorthand for "give me an isolated DB".
|
|
This file keeps that shorthand test-local by routing it to the PostgreSQL test
|
|
database and truncating business tables before every test.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import re
|
|
import sqlite3
|
|
import sys
|
|
import os
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from urllib.parse import urlsplit, urlunsplit
|
|
|
|
import pytest
|
|
import psycopg
|
|
|
|
# Put the project root (the parent of this file's directory) at the front of
# sys.path, once, so the ``app`` package imported further down can be found.
PROJECT_DIR = Path(__file__).resolve().parent.parent
if all(entry != str(PROJECT_DIR) for entry in sys.path):
    sys.path.insert(0, str(PROJECT_DIR))
|
|
|
|
|
|
def _test_database_url() -> str:
    """Return the URL of the database the test-suite should use.

    ``ALPHAX_TEST_DATABASE_URL`` (whitespace-stripped) wins when it is set.
    Otherwise the database name in ``DATABASE_URL`` gets a ``_test`` suffix,
    unless it already ends with one; ``alphax`` is assumed when that URL has
    no database path. Scheme, credentials, host, query and fragment are kept.

    Raises:
        RuntimeError: neither environment variable provides a URL.
    """
    override = os.getenv("ALPHAX_TEST_DATABASE_URL", "").strip()
    if override:
        return override

    base_url = os.getenv("DATABASE_URL", "").strip()
    if not base_url:
        raise RuntimeError("DATABASE_URL is required for PostgreSQL tests")

    scheme, netloc, path, query, fragment = urlsplit(base_url)
    name = path.lstrip("/") or "alphax"
    if not name.endswith("_test"):
        name += "_test"
    return urlunsplit((scheme, netloc, "/" + name, query, fragment))
|
|
|
|
|
|
def _maintenance_url(database_url: str) -> str:
    """Return *database_url* re-pointed at the ``postgres`` maintenance database.

    Only the path component changes; scheme, credentials, host, port, query
    string and fragment are carried over unchanged.
    """
    return urlsplit(database_url)._replace(path="/postgres").geturl()
|
|
|
|
|
|
def _ensure_test_database(database_url: str) -> None:
    """Create the database named in *database_url* if it does not exist yet.

    Connects to the ``postgres`` maintenance database on the same server in
    autocommit mode, because CREATE DATABASE cannot run inside a transaction
    block.

    Raises:
        RuntimeError: *database_url* has no database name in its path.
    """
    from urllib.parse import unquote

    from psycopg import errors, sql

    # libpq percent-decodes the dbname part of a connection URI, so the name
    # we look up and create must be decoded the same way.
    db_name = unquote(urlsplit(database_url).path.lstrip("/"))
    if not db_name:
        raise RuntimeError("Test database URL must include a database name")

    with psycopg.connect(_maintenance_url(database_url), autocommit=True) as conn:
        exists = conn.execute(
            "SELECT 1 FROM pg_database WHERE datname=%s", (db_name,)
        ).fetchone()
        if exists:
            return
        try:
            # Identifier() quotes the name correctly (embedded quotes included)
            # instead of splicing it into the statement by hand.
            conn.execute(sql.SQL("CREATE DATABASE {}").format(sql.Identifier(db_name)))
        except errors.DuplicateDatabase:
            # Someone else (presumably a parallel pytest worker) created it
            # between our existence check and the CREATE; that is fine.
            pass
|
|
|
|
|
|
# Resolve the dedicated test database, create it when missing, and export it
# as DATABASE_URL before the ``app`` modules below are imported (they
# presumably read DATABASE_URL when connecting).
TEST_DATABASE_URL = _test_database_url()
_ensure_test_database(TEST_DATABASE_URL)
os.environ.update({"DATABASE_URL": TEST_DATABASE_URL})
|
|
|
|
from app.db import altcoin_db, auth_db, scheduler_db, schema
|
|
from app.db import postgres_connection
|
|
|
|
|
|
# Legacy tests still monkeypatch these path attributes, so they are kept
# defined (as empty strings) for tests only; runtime modules no longer read them.
for _legacy_module, _legacy_attr in (
    (altcoin_db, "DB_PATH"),
    (auth_db, "DB_PATH"),
    (scheduler_db, "SCHEDULER_DB_PATH"),
    (schema, "DB_PATH"),
):
    setattr(_legacy_module, _legacy_attr, "")
del _legacy_module, _legacy_attr
|
|
|
|
|
|
# Handle on the genuine ``sqlite3.connect``, captured at import time, i.e.
# before the autouse fixture monkeypatches ``sqlite3.connect``.
_REAL_SQLITE_CONNECT = sqlite3.connect

# Tables whose plain INSERTs get ``RETURNING id`` appended by the compat
# connection, so the returned cursor can expose ``lastrowid`` like sqlite3.
_ID_TABLES = set(
    """
    screening_log recommendation price_tracking
    paper_orders paper_trades paper_trade_events
    live_account_snapshots live_trade_accounts live_order_intents live_order_events
    cron_run_log review_log missed_explosions
    strategy_iteration_log strategy_signals strategy_rule_candidate strategy_failure_pattern
    push_log sentiment_events llm_insights event_news
    onchain_token_map onchain_events onchain_token_metrics onchain_raw_events
    app_user email_verification_code user_session user_subscription
    payment_order pending_registration referral_reward
    user_activity user_watchlist user_saved_observation
    system_reset_log scheduler_manual_trigger strategy_runtime_config system_config
    chat_sessions chat_messages chat_user_preferences
    """.split()
)
|
|
|
|
|
|
# Matches a quoted SQL span (single-quoted literal or double-quoted identifier,
# with doubled-quote escapes) or a bare ``?``. Quoted spans are consumed first
# so a ``?`` inside them is never treated as a placeholder.
_SQL_QMARK_RE = re.compile(
    r"'(?:[^']|'')*'"  # single-quoted literal ('' escapes a quote)
    r'|"(?:[^"]|"")*"'  # double-quoted identifier ("" escapes a quote)
    r"|\?"  # bare positional placeholder
)
_INSERT_OR_IGNORE_RE = re.compile(r"\bINSERT\s+OR\s+IGNORE\b", re.IGNORECASE)
_ON_CONFLICT_RE = re.compile(r"\bON\s+CONFLICT\b", re.IGNORECASE)
_RETURNING_RE = re.compile(r"\bRETURNING\b", re.IGNORECASE)


def _qmark_to_pyformat(sql: str) -> str:
    """Replace ``?`` placeholders outside quoted spans with ``%s``."""
    return _SQL_QMARK_RE.sub(lambda m: "%s" if m.group(0) == "?" else m.group(0), sql)


def _translate_sql(sql: str) -> str:
    """Rewrite a sqlite-flavoured statement into PostgreSQL/psycopg syntax.

    * ``datetime('now')`` becomes ``to_char(NOW(), ...)`` producing a
      ``YYYY-MM-DDTHH:MI:SS`` string.
    * ``INSERT OR IGNORE`` (any case) becomes ``INSERT ... ON CONFLICT DO
      NOTHING``, so duplicate rows are skipped as in sqlite instead of raising.
      The clause is placed before an explicit ``RETURNING`` and is not added
      when the statement already has an ``ON CONFLICT`` clause. A trailing
      ``;`` is dropped in that case.
    * ``?`` placeholders become ``%s``, except inside quoted literals or
      identifiers (``'why?'`` is left intact).

    NOTE(review): a literal ``%`` is passed through unchanged; psycopg parses
    placeholders whenever parameters are supplied, so such SQL presumably has
    to spell it ``%%``.
    """
    text = sql.replace("datetime('now')", "to_char(NOW(), 'YYYY-MM-DD\"T\"HH24:MI:SS')")

    text, ignored = _INSERT_OR_IGNORE_RE.subn("INSERT", text)
    if ignored and not _ON_CONFLICT_RE.search(text):
        body = text.rstrip().rstrip(";").rstrip()
        clause = " ON CONFLICT DO NOTHING"
        returning = _RETURNING_RE.search(body)
        if returning:
            # ON CONFLICT must precede RETURNING in PostgreSQL's INSERT grammar.
            head = body[: returning.start()].rstrip()
            body = head + clause + " " + body[returning.start():]
        else:
            body += clause
        text = body

    return _qmark_to_pyformat(text)
|
|
|
|
|
|
def _pragma_table_info(sql: str) -> str:
    """Extract the table name from a sqlite ``PRAGMA table_info(<name>)`` statement.

    The name may be bare, single-quoted or double-quoted (quotes must match)
    and may be followed by an optional ``;``. Matching is case-insensitive and
    the name is returned as written. Returns ``""`` for any other statement.
    """
    found = re.match(
        r"\s*PRAGMA\s+table_info\((['\"]?)([a-zA-Z_][a-zA-Z0-9_]*)\1\)\s*;?\s*$",
        sql,
        re.IGNORECASE,
    )
    if found is None:
        return ""
    return found.group(2)
|
|
|
|
|
|
# ``INSERT INTO [schema.]table`` where either part may be double-quoted; only
# the table part is captured.
_INSERT_TARGET_RE = re.compile(
    r'insert\s+into\s+(?:"?[a-z_][a-z0-9_]*"?\s*\.\s*)?"?([a-z_][a-z0-9_]*)',
    re.IGNORECASE,
)


def _insert_table(sql: str) -> str:
    """Return the lower-cased target table of the first ``INSERT INTO`` in *sql*.

    Schema-qualified (``public.t``) and double-quoted (``"t"``) targets are
    recognised, so inserts written that way still resolve to the bare table
    name. Returns ``""`` when *sql* contains no ``INSERT INTO``.
    """
    match = _INSERT_TARGET_RE.search(sql)
    return match.group(1).lower() if match else ""
|
|
|
|
|
|
def _should_add_returning_id(sql: str) -> bool:
    """Decide whether ``RETURNING id`` must be appended to emulate ``lastrowid``.

    True only for an INSERT into a table listed in ``_ID_TABLES`` that does
    not already carry its own RETURNING clause.
    """
    lowered = sql.strip().lower()
    if not lowered.startswith("insert"):
        return False
    # Word-boundary search so a RETURNING that follows a newline/tab or ends
    # the statement is still detected; otherwise a second one would be added.
    if re.search(r"\breturning\b", lowered):
        return False
    return _insert_table(lowered) in _ID_TABLES
|
|
|
|
|
|
class _PgCursorCompat:
    """Wrapper that gives an underlying DB cursor a sqlite3-style ``lastrowid``.

    ``fetchone``, ``fetchall`` and iteration go straight to the wrapped cursor.
    Any other attribute not defined here (for example ``rowcount`` on a
    psycopg cursor) is looked up on the wrapped cursor, so sqlite-style callers
    can keep using it.
    """

    def __init__(self, cursor, lastrowid=None):
        self._cursor = cursor
        # id captured from an appended ``RETURNING id``; None when not applicable.
        self.lastrowid = lastrowid

    def fetchone(self):
        return self._cursor.fetchone()

    def fetchall(self):
        return self._cursor.fetchall()

    def __iter__(self):
        return iter(self._cursor)

    def __getattr__(self, name):
        # Only invoked for attributes missing on the wrapper. ``_cursor`` is read
        # through __dict__ so a half-initialised instance raises AttributeError
        # instead of recursing forever.
        cursor = self.__dict__.get("_cursor")
        if cursor is None:
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")
        return getattr(cursor, name)
|
|
|
|
|
|
class _PgSqliteCompatCursor:
    """Minimal ``sqlite3.Cursor`` look-alike returned by ``_PgSqliteCompatConnection.cursor()``.

    Statements are forwarded to the owning compat connection and the latest
    result is remembered. That makes the classic sqlite3 sequence
    ``cur = conn.cursor(); cur.execute(...); cur.fetchall()`` work. The cursor
    can also be used as ``with conn.cursor() as cur:``. Leaving that block only
    drops the remembered result; it neither commits nor closes the connection.
    """

    def __init__(self, connection):
        self.connection = connection
        self.lastrowid = None
        self._result = None

    @property
    def rowcount(self):
        """Row count of the last statement, or -1 when it is not known."""
        if self._result is None:
            return -1
        return getattr(self._result, "rowcount", -1)

    def execute(self, sql, params=()):
        self._result = self.connection.execute(sql, params)
        # As with sqlite3, only overwrite lastrowid when the statement produced one.
        new_id = getattr(self._result, "lastrowid", None)
        if new_id is not None:
            self.lastrowid = new_id
        return self

    def executemany(self, sql, seq_of_params):
        self._result = self.connection.executemany(sql, seq_of_params)
        return self

    def fetchone(self):
        return None if self._result is None else self._result.fetchone()

    def fetchall(self):
        return [] if self._result is None else self._result.fetchall()

    def __iter__(self):
        return iter(()) if self._result is None else iter(self._result)

    def close(self):
        self._result = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()


class _PgSqliteCompatConnection:
    """``sqlite3.Connection`` stand-in backed by ``postgres_connection.connect()``.

    * sqlite-flavoured SQL is rewritten by ``_translate_sql`` before execution.
    * ``PRAGMA table_info(<table>)`` is answered from ``information_schema.columns``.
    * INSERTs selected by ``_should_add_returning_id`` get ``RETURNING id``
      appended, so the returned cursor exposes ``lastrowid``.

    ``row_factory`` can be assigned as on sqlite3 but is never consulted.
    NOTE: unlike sqlite3, leaving a ``with`` block also closes the connection.
    """

    row_factory = None

    # Rows shaped like sqlite's PRAGMA table_info output
    # (cid, name, type, notnull, dflt_value, pk); pk is always reported as 0.
    _TABLE_INFO_SQL = """
        SELECT
            (ordinal_position - 1)::int AS cid,
            column_name AS name,
            data_type AS type,
            CASE WHEN is_nullable='NO' THEN 1 ELSE 0 END AS notnull,
            column_default AS dflt_value,
            0 AS pk
        FROM information_schema.columns
        WHERE table_schema='public' AND table_name=%s
        ORDER BY ordinal_position
    """

    def __init__(self):
        self._conn = postgres_connection.connect()
        self.row_factory = None

    def execute(self, sql, params=()):
        """Run one statement and return a ``_PgCursorCompat`` over its result."""
        text = str(sql)
        table_name = _pragma_table_info(text)
        if table_name:
            return _PgCursorCompat(self._conn.execute(self._TABLE_INFO_SQL, (table_name,)))

        translated = _translate_sql(text)
        add_returning = _should_add_returning_id(translated)
        if add_returning:
            translated = translated.rstrip().rstrip(";") + " RETURNING id"
        cur = self._conn.execute(translated, tuple(params or ()))
        lastrowid = self._returned_id(cur) if add_returning else None
        return _PgCursorCompat(cur, lastrowid=lastrowid)

    @staticmethod
    def _returned_id(cur):
        """Best-effort read of the value produced by an appended ``RETURNING id``.

        Handles dict rows (``row["id"]``) and sequence rows (``row[0]``, the
        only returned column). Returns None when no row came back or the id
        cannot be read.
        """
        try:
            row = cur.fetchone()
            if row is None:
                return None
            if isinstance(row, dict):
                return row.get("id")
            return row[0]
        except (psycopg.Error, LookupError, TypeError):
            return None

    def executemany(self, sql, seq_of_params):
        translated = _translate_sql(str(sql))
        cur = self._conn.executemany(translated, seq_of_params)
        return _PgCursorCompat(cur)

    def commit(self):
        self._conn.commit()

    def rollback(self):
        self._conn.rollback()

    def close(self):
        self._conn.close()

    def cursor(self):
        """Return a sqlite3-style cursor bound to this connection."""
        return _PgSqliteCompatCursor(self)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Commit on success, roll back on error; always close, even when
        # commit/rollback itself raises, so the server connection never leaks.
        try:
            if exc_type is None:
                self.commit()
            else:
                self.rollback()
        finally:
            self.close()
|
|
|
|
|
|
def _pg_sqlite_connect(*args, **kwargs):
    """Replacement for ``sqlite3.connect`` installed by the autouse fixture.

    Whatever path or options the caller passes are ignored; every call opens
    a fresh ``_PgSqliteCompatConnection``.
    """
    del args, kwargs  # accepted only to mirror sqlite3.connect's signature
    return _PgSqliteCompatConnection()
|
|
|
|
|
|
def _pg_compat_connect(*args, **kwargs):
    """Replacement for the modules' ``get_conn`` helpers; arguments are ignored.

    Behaves exactly like ``_pg_sqlite_connect``: returns a new
    ``_PgSqliteCompatConnection`` on every call.
    """
    return _pg_sqlite_connect(*args, **kwargs)
|
|
|
|
|
|
def _business_tables() -> list[str]:
    """Return the names of all ``public`` tables except ``schema_migrations``, sorted."""
    query = (
        "SELECT tablename FROM pg_tables "
        "WHERE schemaname='public' AND tablename != 'schema_migrations' "
        "ORDER BY tablename"
    )
    conn = postgres_connection.connect()
    try:
        result = conn.execute(query)
        names = [record["tablename"] for record in result.fetchall()]
    finally:
        conn.close()
    return names
|
|
|
|
|
|
def _truncate_business_tables():
    """Apply pending migrations, then empty every business table in one statement.

    Tables come from ``_business_tables()``, which lists schema ``public``.
    They are therefore schema-qualified as ``"public"."<table>"`` (embedded
    ``"`` doubled), so the TRUNCATE hits exactly those tables whatever the
    session ``search_path`` is. ``RESTART IDENTITY`` resets the tables'
    sequences and ``CASCADE`` also truncates tables referencing them.
    """
    postgres_connection.apply_migrations()
    tables = _business_tables()
    if not tables:
        return
    qualified = ", ".join(
        '"public"."' + table.replace('"', '""') + '"' for table in tables
    )
    conn = postgres_connection.connect()
    try:
        conn.execute(f"TRUNCATE TABLE {qualified} RESTART IDENTITY CASCADE")
        conn.commit()
    finally:
        conn.close()
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def postgres_test_db(monkeypatch):
    """Autouse fixture: point every legacy DB entry point at PostgreSQL, then empty the tables.

    The fixture:
    * sets ``ALPHAX_BOOTSTRAP_ADMIN=0`` (presumably disabling admin bootstrap);
    * replaces ``sqlite3.connect`` with the PostgreSQL compat connection;
    * swaps ``get_conn`` on the core db modules and on any listed module that
      is already imported;
    * truncates all business tables.

    ``monkeypatch`` undoes every patch after the test.
    """
    monkeypatch.setenv("ALPHAX_BOOTSTRAP_ADMIN", "0")
    monkeypatch.setattr(sqlite3, "connect", _pg_sqlite_connect)
    monkeypatch.setattr(altcoin_db, "sqlite3", sqlite3, raising=False)

    for core_module in (altcoin_db, auth_db, schema):
        monkeypatch.setattr(core_module, "get_conn", _pg_compat_connect)

    # Only modules already present in sys.modules are patched.
    optional_module_names = (
        "app.db.analytics",
        "app.db.llm_insights",
        "app.db.recommendation_queries",
        "app.db.review_queries",
        "app.services.llm_insights",
        "app.services.review_engine",
        "app.analysis.reverse_analysis",
        "app.web.routes_content",
        "app.web.routes_strategy",
    )
    for loaded in (sys.modules.get(name) for name in optional_module_names):
        if loaded is None or not hasattr(loaded, "get_conn"):
            continue
        monkeypatch.setattr(loaded, "get_conn", _pg_compat_connect, raising=False)

    _truncate_business_tables()
    yield
|
|
|
|
|
|
@pytest.fixture
def pg_conn():
    """Yield a raw connection from ``postgres_connection.connect()``; closed at teardown."""
    connection = postgres_connection.connect()
    try:
        yield connection
    finally:
        connection.close()
|