alphax/scripts/postgres/validate_import.py
2026-05-16 14:52:10 +08:00

199 lines
6.8 KiB
Python
Executable File

#!/usr/bin/env python3
"""Compare SQLite and PostgreSQL row counts after import."""
from __future__ import annotations
import argparse
import json
import sqlite3
import sys
from pathlib import Path
from psycopg import sql
# Put the repository root (three levels above this file) at the front of
# sys.path so the project-local imports that follow can be resolved when the
# script is executed directly.
REPO_ROOT = Path(__file__).resolve().parents[2]
if sys.path.count(str(REPO_ROOT)) == 0:
    sys.path.insert(0, str(REPO_ROOT))
from app.db.postgres_connection import connect # noqa: E402
from scripts.postgres.import_from_sqlite import ( # noqa: E402
DEFAULT_SCHEDULER_SQLITE_PATH,
DEFAULT_SQLITE_PATH,
EXCLUDED_TABLES,
)
# Tables compared by default. validate() checks only the entries that exist
# in both SQLite and PostgreSQL, in the order listed here, unless the caller
# asks for all tables.
KEY_TABLES = [
    "recommendation",
    "price_tracking",
    "screening_log",
    "coin_state",
    "cron_run_log",
    "review_log",
    "app_user",
    "user_subscription",
    "event_news",
    "sentiment_events",
    "onchain_events",
    "onchain_raw_events",
    "llm_insights",
    "system_reset_log",
    "scheduler_job_config",
    "scheduler_runtime_status",
    "scheduler_manual_trigger",
]
def _sqlite_tables(conn: sqlite3.Connection) -> set[str]:
    """Return the user table names found in *conn*, minus ``EXCLUDED_TABLES``.

    Tables whose name matches ``sqlite_%`` are filtered out by the query
    itself. Rows are read by column name, so *conn* must produce rows that
    support ``row["name"]`` (e.g. ``sqlite3.Row``).
    """
    query = "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
    names = set()
    for record in conn.execute(query).fetchall():
        table_name = record["name"]
        if table_name not in EXCLUDED_TABLES:
            names.add(table_name)
    return names
def _postgres_tables(conn) -> set[str]:
    """Return the base-table names of the PostgreSQL ``public`` schema.

    Names listed in ``EXCLUDED_TABLES`` are dropped from the result. The
    first column of each returned row is taken as the table name.
    """
    cursor = conn.execute(
        """
SELECT table_name
FROM information_schema.tables
WHERE table_schema='public'
AND table_type='BASE TABLE'
"""
    )
    found = {record[0] for record in cursor.fetchall()}
    return {name for name in found if name not in EXCLUDED_TABLES}
def _sqlite_count(conn: sqlite3.Connection, table: str) -> int:
    """Return the number of rows in SQLite table *table*.

    The table name is embedded as a double-quoted identifier with any
    embedded double quote doubled, so a name containing ``"`` cannot
    terminate the identifier early and break the statement.

    The count is read by column name, so *conn* must produce rows that
    support ``row["n"]`` (e.g. ``sqlite3.Row``).
    """
    quoted = '"' + table.replace('"', '""') + '"'
    row = conn.execute(f"SELECT COUNT(*) AS n FROM {quoted}").fetchone()
    return int(row["n"])
def _sqlite_max_id(conn: sqlite3.Connection, table: str) -> int | None:
    """Return ``MAX(id)`` of SQLite table *table*, or ``None``.

    ``None`` is returned when the table has no column named ``id`` or when
    ``MAX(id)`` itself is NULL (for example, an empty table).

    The table name is used as a double-quoted identifier (embedded double
    quotes doubled) in both the PRAGMA and the SELECT. Using the bare name in
    the PRAGMA fails with a syntax error for names that are not plain
    identifiers, such as ones containing ``-`` or a space.

    Rows are read by column name, so *conn* must produce rows that support
    ``row["name"]`` (e.g. ``sqlite3.Row``).
    """
    quoted = '"' + table.replace('"', '""') + '"'
    columns = {row["name"] for row in conn.execute(f"PRAGMA table_info({quoted})").fetchall()}
    if "id" not in columns:
        return None
    value = conn.execute(f"SELECT MAX(id) AS max_id FROM {quoted}").fetchone()["max_id"]
    return int(value) if value is not None else None
def _postgres_count(conn, table: str) -> int:
    """Return the row count of PostgreSQL table *table*.

    The table name is composed into the statement with ``sql.Identifier``
    rather than string formatting; the count is read from the first column
    of the single result row.
    """
    statement = sql.SQL("SELECT COUNT(*) FROM {table}").format(table=sql.Identifier(table))
    row = conn.execute(statement).fetchone()
    return int(row[0])
def _postgres_max_id(conn, table: str) -> int | None:
    """Return ``MAX(id)`` of PostgreSQL table *table*, or ``None``.

    ``information_schema.columns`` is consulted first; when the ``public``
    table has no ``id`` column the function returns ``None`` without querying
    the table. ``None`` is also returned when ``MAX(id)`` is NULL.
    """
    id_column = conn.execute(
        """
SELECT 1
FROM information_schema.columns
WHERE table_schema='public'
AND table_name=%s
AND column_name='id'
""",
        (table,),
    ).fetchone()
    if id_column:
        statement = sql.SQL("SELECT MAX(id) FROM {table}").format(table=sql.Identifier(table))
        max_id = conn.execute(statement).fetchone()[0]
        if max_id is not None:
            return int(max_id)
    return None
def _collect_sqlite_sources(sqlite_path: Path, scheduler_sqlite_path: Path | None) -> list[tuple[str, Path]]:
    """Return the labelled SQLite files that should be validated.

    The main database is mandatory and always comes first, labelled
    ``"main"``. The scheduler database is appended, labelled ``"scheduler"``,
    only when a path was given and that path exists.

    Raises:
        FileNotFoundError: if *sqlite_path* does not exist.
    """
    if not sqlite_path.exists():
        raise FileNotFoundError(f"SQLite database not found: {sqlite_path}")
    main_entry = ("main", sqlite_path)
    if scheduler_sqlite_path and scheduler_sqlite_path.exists():
        return [main_entry, ("scheduler", scheduler_sqlite_path)]
    return [main_entry]
def validate(
    sqlite_path: Path,
    database_url: str | None = None,
    *,
    scheduler_sqlite_path: Path | None = None,
    all_tables: bool = False,
) -> dict:
    """Compare row counts and ``MAX(id)`` between SQLite and PostgreSQL.

    Args:
        sqlite_path: Main SQLite database; it must exist.
        database_url: Passed unchanged to ``connect``.
        scheduler_sqlite_path: Optional second SQLite file; it is used only
            when it exists.
        all_tables: When true, every table present on both sides is checked
            (sorted by name). Otherwise only the ``KEY_TABLES`` entries
            present on both sides are checked, in ``KEY_TABLES`` order.

    Returns:
        A report dict with the keys ``ok``, ``checked_tables``,
        ``sqlite_only_tables``, ``postgres_only_tables`` and ``tables`` (one
        entry per compared table).

    Raises:
        FileNotFoundError: if *sqlite_path* does not exist.
    """
    sources = _collect_sqlite_sources(sqlite_path, scheduler_sqlite_path)
    opened = []
    try:
        # Map each table name to the (label, connection) that holds it. When
        # two sources contain the same table name, the later source wins.
        owner_of = {}
        for label, db_path in sources:
            handle = sqlite3.connect(str(db_path))
            handle.row_factory = sqlite3.Row
            opened.append(handle)
            owner_of.update(dict.fromkeys(_sqlite_tables(handle), (label, handle)))

        with connect(database_url) as pg_conn:
            sqlite_names = set(owner_of)
            pg_names = _postgres_tables(pg_conn)
            shared = sqlite_names & pg_names
            if all_tables:
                selected = sorted(shared)
            else:
                selected = [name for name in KEY_TABLES if name in shared]

            results = []
            for name in selected:
                label, sqlite_conn = owner_of[name]
                sqlite_count = _sqlite_count(sqlite_conn, name)
                pg_count = _postgres_count(pg_conn, name)
                sqlite_max_id = _sqlite_max_id(sqlite_conn, name)
                pg_max_id = _postgres_max_id(pg_conn, name)
                results.append(
                    {
                        "table": name,
                        "source": label,
                        "sqlite_count": sqlite_count,
                        "postgres_count": pg_count,
                        "sqlite_max_id": sqlite_max_id,
                        "postgres_max_id": pg_max_id,
                        "ok": sqlite_count == pg_count and sqlite_max_id == pg_max_id,
                    }
                )

            # NOTE(review): tables present on only one side are reported but
            # do not influence "ok"; only the compared tables do.
            return {
                "ok": all(entry["ok"] for entry in results),
                "checked_tables": len(results),
                "sqlite_only_tables": sorted(sqlite_names - pg_names),
                "postgres_only_tables": sorted(pg_names - sqlite_names),
                "tables": results,
            }
    finally:
        # Close every SQLite handle that was opened, even on failure.
        for handle in opened:
            handle.close()
def main() -> int:
    """Parse the command line, run ``validate`` and print the report.

    With ``--json`` the full report is printed as JSON; otherwise a one-line
    summary is printed, followed by one line per compared table and the lists
    of tables that exist on only one side (when non-empty).

    Returns:
        ``0`` when the report is ok, ``1`` otherwise.
    """
    parser = argparse.ArgumentParser(description="Validate AlphaX SQLite -> PostgreSQL import.")
    parser.add_argument("--sqlite-path", type=Path, default=DEFAULT_SQLITE_PATH)
    parser.add_argument("--scheduler-sqlite-path", type=Path, default=DEFAULT_SCHEDULER_SQLITE_PATH)
    parser.add_argument("--database-url", default=None, help="Override DATABASE_URL.")
    parser.add_argument("--all-tables", action="store_true")
    parser.add_argument("--json", action="store_true", help="Print full JSON report.")
    args = parser.parse_args()

    report = validate(
        args.sqlite_path,
        args.database_url,
        scheduler_sqlite_path=args.scheduler_sqlite_path,
        all_tables=args.all_tables,
    )
    exit_code = 0 if report["ok"] else 1

    if args.json:
        print(json.dumps(report, ensure_ascii=False, indent=2))
        return exit_code

    # Human-readable report: collect the lines first, then emit them in order.
    status = "PASS" if report["ok"] else "FAIL"
    lines = [f"[validate] {status}: checked {report['checked_tables']} table(s)"]
    for item in report["tables"]:
        mark = "OK" if item["ok"] else "DIFF"
        lines.append(
            f"[validate] {mark} {item['table']} ({item['source']}): "
            f"count {item['sqlite_count']} -> {item['postgres_count']}, "
            f"max_id {item['sqlite_max_id']} -> {item['postgres_max_id']}"
        )
    if report["sqlite_only_tables"]:
        lines.append(f"[validate] sqlite-only tables: {', '.join(report['sqlite_only_tables'])}")
    if report["postgres_only_tables"]:
        lines.append(f"[validate] postgres-only tables: {', '.join(report['postgres_only_tables'])}")
    for line in lines:
        print(line)
    return exit_code
# Run the CLI only when executed as a script; main()'s return value becomes
# the process exit status.
if __name__ == "__main__":
    sys.exit(main())