alphax/scripts/postgres/import_from_sqlite.py
2026-05-16 14:52:10 +08:00

285 lines
9.1 KiB
Python
Executable File

#!/usr/bin/env python3
"""Import current SQLite data into the PostgreSQL migration target."""
from __future__ import annotations
import argparse
import sqlite3
import sys
from pathlib import Path
from typing import Iterable
from psycopg import sql
# Repository root: ``parents[2]`` of this file's resolved path, i.e. the
# directory that contains ``scripts/postgres/``.
REPO_ROOT = Path(__file__).resolve().parents[2]
# Put the repository root at the front of sys.path (only if it is not already
# there) so the project-local ``app`` package imported below resolves when
# this file is executed directly as a script.
if str(REPO_ROOT) not in sys.path:
    sys.path[:0] = [str(REPO_ROOT)]
from app.db.postgres_connection import connect # noqa: E402
# Default SQLite sources, both located under <repo>/data/.
DEFAULT_SQLITE_PATH = REPO_ROOT.joinpath("data", "altcoin_monitor.db")
DEFAULT_SCHEDULER_SQLITE_PATH = REPO_ROOT.joinpath("data", "scheduler_state.db")
# Table names skipped on both the SQLite and the PostgreSQL side of the import
# (SQLite's autoincrement bookkeeping table and, presumably, migration-tracking
# state that each database maintains for itself).
EXCLUDED_TABLES = {"schema_migrations", "sqlite_sequence"}
def _sqlite_tables(conn: sqlite3.Connection) -> list[str]:
    """Return the user tables of a SQLite database, sorted by name.

    SQLite-internal tables (names beginning with ``sqlite_``, matched
    case-insensitively like the rest of ``LIKE``) and every name in
    ``EXCLUDED_TABLES`` are omitted.

    Args:
        conn: An open SQLite connection. Rows are read positionally, so any
            ``row_factory`` (including the default tuple rows) works.

    Returns:
        Table names in ascending ``ORDER BY name`` order.
    """
    rows = conn.execute(
        # In LIKE, ``_`` is a single-character wildcard; escape it so only names
        # that literally start with "sqlite_" are excluded (not e.g. "sqlitefoo").
        """
        SELECT name
        FROM sqlite_master
        WHERE type='table'
          AND name NOT LIKE 'sqlite!_%' ESCAPE '!'
        ORDER BY name
        """
    ).fetchall()
    return [row[0] for row in rows if row[0] not in EXCLUDED_TABLES]
def _sqlite_columns(conn: sqlite3.Connection, table: str) -> list[str]:
return [row["name"] for row in conn.execute(f"PRAGMA table_info({table})").fetchall()]
def _postgres_tables(conn) -> set[str]:
    """Return the base-table names of PostgreSQL's ``public`` schema.

    Only relations whose ``table_type`` is ``BASE TABLE`` are considered, and
    names listed in ``EXCLUDED_TABLES`` are removed from the result.
    """
    query = """
        SELECT table_name
        FROM information_schema.tables
        WHERE table_schema='public'
          AND table_type='BASE TABLE'
    """
    names = {record[0] for record in conn.execute(query).fetchall()}
    return names.difference(EXCLUDED_TABLES)
def _postgres_columns(conn, table: str) -> list[str]:
    """List the columns of ``public.<table>`` in ordinal (declaration) order.

    The table name is passed as a bound query parameter, not interpolated.
    """
    query = """
        SELECT column_name
        FROM information_schema.columns
        WHERE table_schema='public'
          AND table_name=%s
        ORDER BY ordinal_position
    """
    result = conn.execute(query, (table,))
    return [record[0] for record in result.fetchall()]
def _serial_columns(conn) -> list[tuple[str, str]]:
    """Return ``(table, column)`` pairs in ``public`` that draw values from a sequence.

    Two kinds of column qualify:

    * serial-style columns, whose ``column_default`` is a ``nextval(...)`` call;
    * identity columns (``GENERATED ... AS IDENTITY``). These report a NULL
      ``column_default`` and ``is_identity = 'YES'``, so the default check
      alone misses them.

    Both kinds need their sequence moved forward after rows are imported with
    explicit ids. Otherwise the next generated value collides with existing
    rows.

    Returns:
        Pairs sorted by table name, then column name.
    """
    rows = conn.execute(
        """
        SELECT table_name, column_name
        FROM information_schema.columns
        WHERE table_schema='public'
          AND (column_default LIKE 'nextval(%' OR is_identity='YES')
        ORDER BY table_name, column_name
        """
    ).fetchall()
    return [(row[0], row[1]) for row in rows]
def _batched(rows: Iterable[sqlite3.Row], size: int) -> Iterable[list[sqlite3.Row]]:
batch: list[sqlite3.Row] = []
for row in rows:
batch.append(row)
if len(batch) >= size:
yield batch
batch = []
if batch:
yield batch
def _truncate_tables(conn, tables: list[str]) -> None:
    """Empty *tables* with a single ``TRUNCATE`` statement.

    ``RESTART IDENTITY`` resets sequences owned by the truncated tables' columns.
    ``CASCADE`` also truncates any table with a foreign key referencing one of
    them. An empty *tables* list is a no-op.
    """
    if tables:
        targets = sql.SQL(", ").join([sql.Identifier(name) for name in tables])
        conn.execute(
            sql.SQL("TRUNCATE TABLE {tables} RESTART IDENTITY CASCADE").format(tables=targets)
        )
def _reset_sequences(conn) -> None:
    """Move each sequence-backed column's sequence just past the current maximum.

    For every ``(table, column)`` from ``_serial_columns``, the sequence is
    ``setval``-ed with ``is_called = false``, so the next ``nextval`` returns
    ``MAX(column) + 1``, or ``1`` when the table is empty.

    ``pg_get_serial_sequence`` parses its first argument like an SQL
    identifier, and unquoted names are folded to lower case. A bare
    mixed-case table name such as ``UserData`` would therefore not resolve
    and would abort the transaction. The table is instead passed as
    ``public.<quote_ident(table)>``. The ``MAX`` lookup is qualified with
    ``public`` too, matching the schema that ``_serial_columns`` inspects.
    """
    for table, column in _serial_columns(conn):
        conn.execute(
            sql.SQL(
                """
                SELECT setval(
                    pg_get_serial_sequence('public.' || quote_ident({table_name}), {column_name}),
                    COALESCE((SELECT MAX({column}) FROM {table}), 0) + 1,
                    false
                )
                """
            ).format(
                table_name=sql.Literal(table),
                # The column argument is taken literally (case preserved).
                column_name=sql.Literal(column),
                column=sql.Identifier(column),
                table=sql.Identifier("public", table),
            )
        )
def _apply_post_import_fixes(conn) -> None:
    """Bring imported legacy ``recommendation`` rows up to current runtime invariants.

    A row is rewritten when its ``status`` is ``expired``, ``invalid``,
    ``archived`` or ``stopped_out`` and its ``display_bucket`` is not already
    ``'history'`` (NULL counts as "not history"). The rewrite sets:

    * ``display_bucket='history'``, ``execution_status='invalid'``,
      ``lifecycle_state='closed'`` and ``entry_triggered=0``;
    * ``state_reason``: a non-empty value is kept; otherwise it is filled with
      a default Chinese message meaning "opportunity invalidated, moved into
      history for review".
    """
    # Rows already in the history bucket are excluded, so re-running this
    # statement leaves previously normalized rows untouched.
    move_closed_to_history = """
        UPDATE recommendation
        SET display_bucket='history',
            execution_status='invalid',
            lifecycle_state='closed',
            entry_triggered=0,
            state_reason=COALESCE(NULLIF(state_reason, ''), '机会失效,归入历史复盘')
        WHERE status IN ('expired', 'invalid', 'archived', 'stopped_out')
          AND COALESCE(display_bucket, '') != 'history'
    """
    conn.execute(move_closed_to_history)
def _quote_sqlite_ident(name: str) -> str:
    """Return *name* as a double-quoted SQLite identifier (embedded ``"`` doubled)."""
    return '"' + name.replace('"', '""') + '"'


def _import_one_sqlite(
    sqlite_conn: sqlite3.Connection,
    pg_conn,
    *,
    truncate: bool,
    batch_size: int,
    skip_conflicts: bool,
    source_label: str,
) -> dict[str, int]:
    """Copy the rows of every table one SQLite database shares with PostgreSQL.

    Tables that exist only in SQLite are reported and skipped. For each shared
    table, only columns present on both sides are copied, in PostgreSQL's
    column order. Columns that exist only in PostgreSQL are omitted from the
    INSERT, so they receive their defaults. This function does not commit; the
    caller controls the transaction.

    Args:
        sqlite_conn: Source SQLite connection.
        pg_conn: Target PostgreSQL connection.
        truncate: Truncate the shared target tables (``RESTART IDENTITY
            CASCADE``) before copying.
        batch_size: Number of rows sent per ``executemany`` call.
        skip_conflicts: Append ``ON CONFLICT DO NOTHING`` to every INSERT.
        source_label: Tag used in progress messages (e.g. ``"main"``).

    Returns:
        ``{table: rows}`` for every shared table. The count is the number of
        rows read from SQLite and sent to PostgreSQL. With *skip_conflicts*,
        rows that PostgreSQL silently skipped are still counted.
    """
    sqlite_tables = _sqlite_tables(sqlite_conn)
    pg_tables = _postgres_tables(pg_conn)
    common_tables = [table for table in sqlite_tables if table in pg_tables]
    missing_tables = [table for table in sqlite_tables if table not in pg_tables]
    if missing_tables:
        print(f"[import:{source_label}] skip tables absent in PostgreSQL: {', '.join(missing_tables)}")
    if truncate:
        print(f"[import:{source_label}] truncate {len(common_tables)} table(s)")
        _truncate_tables(pg_conn, common_tables)

    imported: dict[str, int] = {}
    for table in common_tables:
        sqlite_cols = set(_sqlite_columns(sqlite_conn, table))
        columns = [col for col in _postgres_columns(pg_conn, table) if col in sqlite_cols]
        if not columns:
            print(f"[import:{source_label}] {table}: skipped, no shared columns")
            imported[table] = 0
            continue

        # Quote identifiers with embedded quotes doubled, so unusual table or
        # column names cannot break the SELECT.
        select_sql = "SELECT {} FROM {}".format(
            ", ".join(_quote_sqlite_ident(col) for col in columns),
            _quote_sqlite_ident(table),
        )
        insert_sql = sql.SQL("INSERT INTO {table} ({cols}) VALUES ({values}) {conflict}").format(
            table=sql.Identifier(table),
            cols=sql.SQL(", ").join(sql.Identifier(col) for col in columns),
            values=sql.SQL(", ").join(sql.Placeholder() for _ in columns),
            conflict=sql.SQL("ON CONFLICT DO NOTHING") if skip_conflicts else sql.SQL(""),
        )

        count = 0
        # One PostgreSQL cursor per table. Source rows are streamed from the
        # SQLite cursor batch by batch rather than loaded all at once.
        with pg_conn.cursor() as cur:
            for batch in _batched(sqlite_conn.execute(select_sql), batch_size):
                # The SELECT lists exactly `columns`, in order, so each row's
                # positional values already match the INSERT placeholders.
                values = [tuple(row) for row in batch]
                cur.executemany(insert_sql, values)
                count += len(values)
        imported[table] = count
        print(f"[import:{source_label}] {table}: {count}")
    return imported
def import_sqlite(
    sqlite_path: Path,
    database_url: str | None = None,
    *,
    scheduler_sqlite_path: Path | None = None,
    truncate: bool = False,
    batch_size: int = 1000,
    skip_conflicts: bool = False,
) -> dict[str, int]:
    """Import the main (and optional scheduler) SQLite database into PostgreSQL.

    All PostgreSQL work runs inside one ``pg_conn.transaction()`` block, so an
    error rolls back the entire import. After copying, sequences are reset and
    legacy rows are normalized.

    With *truncate*, every target table fed by either source is truncated once,
    before any rows are copied. Truncating per source would let the scheduler
    pass wipe rows just loaded from the main database, either because both
    databases contain a table of the same name or because ``TRUNCATE ...
    CASCADE`` follows foreign keys into main tables.

    Args:
        sqlite_path: Main SQLite database; must exist.
        database_url: PostgreSQL URL passed to ``connect``. With ``None``,
            ``connect`` chooses its default (presumably ``DATABASE_URL``,
            per the CLI help).
        scheduler_sqlite_path: Optional second database. A path that does not
            exist is reported and skipped.
        truncate: Empty the target tables before importing.
        batch_size: Rows per ``executemany`` batch.
        skip_conflicts: Use ``ON CONFLICT DO NOTHING`` on inserts.

    Returns:
        ``{table: rows}``. A table present in both sources gets the sum.

    Raises:
        FileNotFoundError: If *sqlite_path* does not exist.
    """
    if not sqlite_path.exists():
        raise FileNotFoundError(f"SQLite database not found: {sqlite_path}")

    imported: dict[str, int] = {}
    sources: list[tuple[str, sqlite3.Connection]] = []
    try:
        main_conn = sqlite3.connect(str(sqlite_path))
        sources.append(("main", main_conn))  # register before configuring, so it is always closed
        main_conn.row_factory = sqlite3.Row
        if scheduler_sqlite_path and scheduler_sqlite_path.exists():
            scheduler_conn = sqlite3.connect(str(scheduler_sqlite_path))
            sources.append(("scheduler", scheduler_conn))
            scheduler_conn.row_factory = sqlite3.Row
        elif scheduler_sqlite_path:
            print(f"[import:scheduler] skip missing scheduler db: {scheduler_sqlite_path}")

        with connect(database_url) as pg_conn:
            with pg_conn.transaction():
                if truncate:
                    pg_tables = _postgres_tables(pg_conn)
                    # Ordered, de-duplicated union of target tables across all sources.
                    targets = list(
                        dict.fromkeys(
                            table
                            for _, conn in sources
                            for table in _sqlite_tables(conn)
                            if table in pg_tables
                        )
                    )
                    print(f"[import] truncate {len(targets)} table(s)")
                    _truncate_tables(pg_conn, targets)
                for label, conn in sources:
                    counts = _import_one_sqlite(
                        conn,
                        pg_conn,
                        truncate=False,  # already handled once, above
                        batch_size=batch_size,
                        skip_conflicts=skip_conflicts,
                        source_label=label,
                    )
                    for table, count in counts.items():
                        imported[table] = imported.get(table, 0) + count
                _reset_sequences(pg_conn)
                _apply_post_import_fixes(pg_conn)
    finally:
        for _, conn in sources:
            conn.close()
    return imported
def main() -> int:
parser = argparse.ArgumentParser(description="Import AlphaX SQLite data into PostgreSQL.")
parser.add_argument("--sqlite-path", type=Path, default=DEFAULT_SQLITE_PATH)
parser.add_argument(
"--scheduler-sqlite-path",
type=Path,
default=DEFAULT_SCHEDULER_SQLITE_PATH,
help="Optional scheduler_state.db path. Use an empty string to skip.",
)
parser.add_argument("--database-url", default=None, help="Override DATABASE_URL.")
parser.add_argument("--truncate", action="store_true", help="Clear target tables before import.")
parser.add_argument("--batch-size", type=int, default=1000)
parser.add_argument(
"--skip-conflicts",
action="store_true",
help="Use ON CONFLICT DO NOTHING. Prefer --truncate for clean migrations.",
)
args = parser.parse_args()
scheduler_sqlite_path = args.scheduler_sqlite_path
if str(scheduler_sqlite_path).strip() == "":
scheduler_sqlite_path = None
imported = import_sqlite(
args.sqlite_path,
args.database_url,
scheduler_sqlite_path=scheduler_sqlite_path,
truncate=args.truncate,
batch_size=args.batch_size,
skip_conflicts=args.skip_conflicts,
)
print(f"[import] completed {len(imported)} table(s), {sum(imported.values())} row(s)")
return 0
# Script entry point: the return value of main() becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())