#!/usr/bin/env python3 """Import current SQLite data into the PostgreSQL migration target.""" from __future__ import annotations import argparse import sqlite3 import sys from pathlib import Path from typing import Iterable from psycopg import sql REPO_ROOT = Path(__file__).resolve().parents[2] if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) from app.db.postgres_connection import connect # noqa: E402 DEFAULT_SQLITE_PATH = REPO_ROOT / "data" / "altcoin_monitor.db" DEFAULT_SCHEDULER_SQLITE_PATH = REPO_ROOT / "data" / "scheduler_state.db" EXCLUDED_TABLES = {"sqlite_sequence", "schema_migrations"} def _sqlite_tables(conn: sqlite3.Connection) -> list[str]: rows = conn.execute( """ SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY name """ ).fetchall() return [row["name"] for row in rows if row["name"] not in EXCLUDED_TABLES] def _sqlite_columns(conn: sqlite3.Connection, table: str) -> list[str]: return [row["name"] for row in conn.execute(f"PRAGMA table_info({table})").fetchall()] def _postgres_tables(conn) -> set[str]: rows = conn.execute( """ SELECT table_name FROM information_schema.tables WHERE table_schema='public' AND table_type='BASE TABLE' """ ).fetchall() return {row[0] for row in rows if row[0] not in EXCLUDED_TABLES} def _postgres_columns(conn, table: str) -> list[str]: rows = conn.execute( """ SELECT column_name FROM information_schema.columns WHERE table_schema='public' AND table_name=%s ORDER BY ordinal_position """, (table,), ).fetchall() return [row[0] for row in rows] def _serial_columns(conn) -> list[tuple[str, str]]: rows = conn.execute( """ SELECT table_name, column_name FROM information_schema.columns WHERE table_schema='public' AND column_default LIKE 'nextval(%' ORDER BY table_name, column_name """ ).fetchall() return [(row[0], row[1]) for row in rows] def _batched(rows: Iterable[sqlite3.Row], size: int) -> Iterable[list[sqlite3.Row]]: batch: list[sqlite3.Row] = [] for row in rows: batch.append(row) if len(batch) >= size: yield batch batch = [] if batch: yield batch def _truncate_tables(conn, tables: list[str]) -> None: if not tables: return stmt = sql.SQL("TRUNCATE TABLE {tables} RESTART IDENTITY CASCADE").format( tables=sql.SQL(", ").join(sql.Identifier(t) for t in tables) ) conn.execute(stmt) def _reset_sequences(conn) -> None: for table, column in _serial_columns(conn): conn.execute( sql.SQL( """ SELECT setval( pg_get_serial_sequence({table_name}, {column_name}), COALESCE((SELECT MAX({column}) FROM {table}), 0) + 1, false ) """ ).format( table_name=sql.Literal(table), column_name=sql.Literal(column), column=sql.Identifier(column), table=sql.Identifier(table), ) ) def _apply_post_import_fixes(conn) -> None: """Bring imported legacy rows up to current PostgreSQL runtime invariants.""" conn.execute( """ UPDATE recommendation SET display_bucket='history', execution_status='invalid', lifecycle_state='closed', entry_triggered=0, state_reason=COALESCE(NULLIF(state_reason, ''), '机会失效,归入历史复盘') WHERE status IN ('expired', 'invalid', 'archived', 'stopped_out') AND COALESCE(display_bucket, '') != 'history' """ ) def _import_one_sqlite( sqlite_conn: sqlite3.Connection, pg_conn, *, truncate: bool, batch_size: int, skip_conflicts: bool, source_label: str, ) -> dict[str, int]: imported: dict[str, int] = {} sqlite_tables = _sqlite_tables(sqlite_conn) pg_tables = _postgres_tables(pg_conn) common_tables = [table for table in sqlite_tables if table in pg_tables] missing_tables = [table for table in sqlite_tables if table not in pg_tables] if missing_tables: print(f"[import:{source_label}] skip tables absent in PostgreSQL: {', '.join(missing_tables)}") if truncate: print(f"[import:{source_label}] truncate {len(common_tables)} table(s)") _truncate_tables(pg_conn, common_tables) for table in common_tables: sqlite_cols = _sqlite_columns(sqlite_conn, table) pg_cols = _postgres_columns(pg_conn, table) columns = [col for col in pg_cols if col in sqlite_cols] if not columns: imported[table] = 0 continue select_sql = "SELECT {} FROM {}".format( ", ".join(f'"{col}"' for col in columns), f'"{table}"', ) rows = sqlite_conn.execute(select_sql) insert_sql = sql.SQL("INSERT INTO {table} ({cols}) VALUES ({values}) {conflict}").format( table=sql.Identifier(table), cols=sql.SQL(", ").join(sql.Identifier(col) for col in columns), values=sql.SQL(", ").join(sql.Placeholder() for _ in columns), conflict=sql.SQL("ON CONFLICT DO NOTHING") if skip_conflicts else sql.SQL(""), ) count = 0 for batch in _batched(rows, batch_size): values = [tuple(row[col] for col in columns) for row in batch] with pg_conn.cursor() as cur: cur.executemany(insert_sql, values) count += len(values) imported[table] = count print(f"[import:{source_label}] {table}: {count}") return imported def import_sqlite( sqlite_path: Path, database_url: str | None = None, *, scheduler_sqlite_path: Path | None = None, truncate: bool = False, batch_size: int = 1000, skip_conflicts: bool = False, ) -> dict[str, int]: if not sqlite_path.exists(): raise FileNotFoundError(f"SQLite database not found: {sqlite_path}") sqlite_conn = sqlite3.connect(str(sqlite_path)) sqlite_conn.row_factory = sqlite3.Row imported: dict[str, int] = {} scheduler_conn: sqlite3.Connection | None = None try: if scheduler_sqlite_path and scheduler_sqlite_path.exists(): scheduler_conn = sqlite3.connect(str(scheduler_sqlite_path)) scheduler_conn.row_factory = sqlite3.Row elif scheduler_sqlite_path: print(f"[import:scheduler] skip missing scheduler db: {scheduler_sqlite_path}") with connect(database_url) as pg_conn: with pg_conn.transaction(): imported.update( _import_one_sqlite( sqlite_conn, pg_conn, truncate=truncate, batch_size=batch_size, skip_conflicts=skip_conflicts, source_label="main", ) ) if scheduler_conn: imported.update( _import_one_sqlite( scheduler_conn, pg_conn, truncate=truncate, batch_size=batch_size, skip_conflicts=skip_conflicts, source_label="scheduler", ) ) _reset_sequences(pg_conn) _apply_post_import_fixes(pg_conn) finally: sqlite_conn.close() if scheduler_conn: scheduler_conn.close() return imported def main() -> int: parser = argparse.ArgumentParser(description="Import AlphaX SQLite data into PostgreSQL.") parser.add_argument("--sqlite-path", type=Path, default=DEFAULT_SQLITE_PATH) parser.add_argument( "--scheduler-sqlite-path", type=Path, default=DEFAULT_SCHEDULER_SQLITE_PATH, help="Optional scheduler_state.db path. Use an empty string to skip.", ) parser.add_argument("--database-url", default=None, help="Override DATABASE_URL.") parser.add_argument("--truncate", action="store_true", help="Clear target tables before import.") parser.add_argument("--batch-size", type=int, default=1000) parser.add_argument( "--skip-conflicts", action="store_true", help="Use ON CONFLICT DO NOTHING. Prefer --truncate for clean migrations.", ) args = parser.parse_args() scheduler_sqlite_path = args.scheduler_sqlite_path if str(scheduler_sqlite_path).strip() == "": scheduler_sqlite_path = None imported = import_sqlite( args.sqlite_path, args.database_url, scheduler_sqlite_path=scheduler_sqlite_path, truncate=args.truncate, batch_size=args.batch_size, skip_conflicts=args.skip_conflicts, ) print(f"[import] completed {len(imported)} table(s), {sum(imported.values())} row(s)") return 0 if __name__ == "__main__": raise SystemExit(main())