285 lines
9.1 KiB
Python
Executable File
285 lines
9.1 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""Import current SQLite data into the PostgreSQL migration target."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import sqlite3
|
|
import sys
|
|
from pathlib import Path
|
|
from typing import Iterable
|
|
|
|
from psycopg import sql
|
|
|
|
# The repository root sits two directories above this script's directory
# (parents[2] of the resolved file path). Putting it on sys.path lets the
# `app` package import below resolve when the script is executed directly.
REPO_ROOT = Path(__file__).resolve().parents[2]
_REPO_ROOT_ENTRY = str(REPO_ROOT)
if _REPO_ROOT_ENTRY not in sys.path:
    sys.path.insert(0, _REPO_ROOT_ENTRY)
|
|
|
|
from app.db.postgres_connection import connect # noqa: E402
|
|
|
|
|
|
# Default SQLite source files, both under <repo>/data.
DEFAULT_SQLITE_PATH = REPO_ROOT.joinpath("data", "altcoin_monitor.db")
DEFAULT_SCHEDULER_SQLITE_PATH = REPO_ROOT.joinpath("data", "scheduler_state.db")
# Tables never copied: SQLite's AUTOINCREMENT bookkeeping table and
# schema_migrations (presumably each side's own migration ledger — confirm).
EXCLUDED_TABLES = {"schema_migrations", "sqlite_sequence"}
|
|
|
|
|
|
def _sqlite_tables(conn: sqlite3.Connection) -> list[str]:
    """Return the user tables of a SQLite database, sorted by name.

    SQLite's internal ``sqlite_*`` tables and every name in
    ``EXCLUDED_TABLES`` are left out. Rows are read positionally, so the
    connection works with or without ``sqlite3.Row`` as its row factory.
    """
    rows = conn.execute(
        """
        SELECT name
        FROM sqlite_master
        WHERE type='table'
          AND name NOT LIKE 'sqlite!_%' ESCAPE '!'
        ORDER BY name
        """
    ).fetchall()
    # An unescaped ``_`` in LIKE matches any single character, which would
    # also hide user tables such as "sqlitexdata"; the ESCAPE clause makes
    # the underscore literal so only genuine ``sqlite_`` tables are skipped.
    return [row[0] for row in rows if row[0] not in EXCLUDED_TABLES]
|
|
|
|
|
|
def _sqlite_columns(conn: sqlite3.Connection, table: str) -> list[str]:
    """Return the column names of SQLite ``table`` in declaration order.

    Uses the ``pragma_table_info`` table-valued function with a bound
    parameter, so table names containing spaces, quotes or SQL keywords need
    no manual quoting. Rows are read positionally, so any row factory that
    supports indexing works. Returns an empty list for an unknown table.
    """
    rows = conn.execute(
        "SELECT name FROM pragma_table_info(?) ORDER BY cid",
        (table,),
    ).fetchall()
    return [row[0] for row in rows]
|
|
|
|
|
|
def _postgres_tables(conn) -> set[str]:
    """Return the names of base tables in the PostgreSQL ``public`` schema.

    Views and other non-base relations are ignored, and any name listed in
    ``EXCLUDED_TABLES`` is removed from the result.
    """
    cursor = conn.execute(
        """
        SELECT table_name
        FROM information_schema.tables
        WHERE table_schema='public'
          AND table_type='BASE TABLE'
        """
    )
    found = {record[0] for record in cursor.fetchall()}
    return found.difference(EXCLUDED_TABLES)
|
|
|
|
|
|
def _postgres_columns(conn, table: str) -> list[str]:
    """Return the column names of ``public.<table>`` in ordinal order.

    The table name is passed as a bound query parameter. An unknown table
    yields an empty list.
    """
    query = """
        SELECT column_name
        FROM information_schema.columns
        WHERE table_schema='public'
          AND table_name=%s
        ORDER BY ordinal_position
    """
    records = conn.execute(query, (table,)).fetchall()
    names: list[str] = []
    for record in records:
        names.append(record[0])
    return names
|
|
|
|
|
|
def _serial_columns(conn) -> list[tuple[str, str]]:
    """Return ``(table, column)`` pairs in ``public`` whose values come from a sequence.

    Two kinds of columns are covered:

    * ``serial``-style columns, whose default is a ``nextval(...)`` call.
    * Identity columns (``is_identity = 'YES'``), which report no
      ``nextval`` default and would otherwise be missed. Their backing
      sequence still needs realigning after rows are inserted with explicit
      values.

    Results are ordered by table name, then column name.
    """
    rows = conn.execute(
        """
        SELECT table_name, column_name
        FROM information_schema.columns
        WHERE table_schema='public'
          AND (column_default LIKE 'nextval(%' OR is_identity = 'YES')
        ORDER BY table_name, column_name
        """
    ).fetchall()
    return [(row[0], row[1]) for row in rows]
|
|
|
|
|
|
def _batched(rows: Iterable[sqlite3.Row], size: int) -> Iterable[list[sqlite3.Row]]:
    """Yield consecutive lists of at most ``size`` rows from ``rows``.

    Each yielded list is a fresh object. The final list may be shorter, and
    nothing is yielded for an empty input. The input is consumed lazily: a
    chunk is handed out as soon as it is full.
    """
    source = iter(rows)
    while True:
        chunk: list[sqlite3.Row] = []
        for item in source:
            chunk.append(item)
            if len(chunk) >= size:
                break
        if not chunk:
            return
        yield chunk
|
|
|
|
|
|
def _truncate_tables(conn, tables: list[str]) -> None:
    """Empty ``tables`` in one TRUNCATE, restarting identities and cascading.

    Every name is quoted as an identifier. An empty list is a no-op, and no
    statement is sent.
    """
    if tables:
        target_list = sql.SQL(", ").join([sql.Identifier(name) for name in tables])
        statement = sql.SQL("TRUNCATE TABLE {tables} RESTART IDENTITY CASCADE").format(
            tables=target_list
        )
        conn.execute(statement)
|
|
|
|
|
|
def _reset_sequences(conn) -> None:
    """Move each sequence-backed column's sequence past the imported data.

    Rows are inserted with explicit ids, so afterwards each sequence reported
    by ``_serial_columns`` is set to ``MAX(column) + 1`` (or 1 for an empty
    table) with ``is_called = false``; that value is the next one handed out.
    """
    for table, column in _serial_columns(conn):
        conn.execute(
            sql.SQL(
                """
                SELECT setval(
                    pg_get_serial_sequence('public.' || quote_ident({table_name}), {column_name}),
                    COALESCE((SELECT MAX({column}) FROM {table}), 0) + 1,
                    false
                )
                """
            ).format(
                # pg_get_serial_sequence parses its first argument as an SQL
                # name (unquoted parts are lower-cased), so the raw table name
                # is quoted with quote_ident and qualified with the same
                # 'public' schema _serial_columns inspected. The column name
                # argument is taken literally and needs no quoting.
                table_name=sql.Literal(table),
                column_name=sql.Literal(column),
                column=sql.Identifier(column),
                table=sql.Identifier("public", table),
            )
        )
|
|
|
|
|
|
def _apply_post_import_fixes(conn) -> None:
    """Bring imported legacy ``recommendation`` rows in line with runtime invariants.

    Rows whose ``status`` is expired, invalid, archived or stopped_out and
    whose ``display_bucket`` is not already 'history' (NULL counts as not)
    are updated as follows:

    * moved to the 'history' bucket;
    * marked ``execution_status='invalid'`` and ``lifecycle_state='closed'``;
    * ``entry_triggered`` reset to 0;
    * given a default ``state_reason`` when theirs is NULL or empty.
    """
    # The default state_reason literal below roughly reads "opportunity
    # invalidated, moved to historical review".
    statement = """
        UPDATE recommendation
        SET display_bucket='history',
            execution_status='invalid',
            lifecycle_state='closed',
            entry_triggered=0,
            state_reason=COALESCE(NULLIF(state_reason, ''), '机会失效,归入历史复盘')
        WHERE status IN ('expired', 'invalid', 'archived', 'stopped_out')
          AND COALESCE(display_bucket, '') != 'history'
    """
    conn.execute(statement)
|
|
|
|
|
|
def _copy_table_rows(
    sqlite_conn: sqlite3.Connection,
    pg_conn,
    table: str,
    columns: list[str],
    *,
    batch_size: int,
    skip_conflicts: bool,
) -> int:
    """Stream ``columns`` of ``table`` from SQLite into PostgreSQL; return rows sent.

    The count is the number of rows handed to PostgreSQL. With
    ``skip_conflicts`` some of them may have been discarded by
    ``ON CONFLICT DO NOTHING``.
    """

    def quote(name: str) -> str:
        # SQLite identifier quoting: wrap in double quotes, doubling any
        # embedded double quote so unusual names cannot break the SELECT.
        return '"' + name.replace('"', '""') + '"'

    select_sql = "SELECT {} FROM {}".format(
        ", ".join(quote(col) for col in columns),
        quote(table),
    )
    insert_sql = sql.SQL("INSERT INTO {table} ({cols}) VALUES ({values}) {conflict}").format(
        table=sql.Identifier(table),
        cols=sql.SQL(", ").join(sql.Identifier(col) for col in columns),
        values=sql.SQL(", ").join(sql.Placeholder() for _ in columns),
        conflict=sql.SQL("ON CONFLICT DO NOTHING") if skip_conflicts else sql.SQL(""),
    )

    count = 0
    rows = sqlite_conn.execute(select_sql)
    with pg_conn.cursor() as cur:
        for batch in _batched(rows, batch_size):
            # The SELECT lists exactly ``columns`` in order, so each row maps
            # positionally onto the INSERT placeholders (no row factory needed).
            values = [tuple(row) for row in batch]
            cur.executemany(insert_sql, values)
            count += len(values)
    return count


def _import_one_sqlite(
    sqlite_conn: sqlite3.Connection,
    pg_conn,
    *,
    truncate: bool,
    batch_size: int,
    skip_conflicts: bool,
    source_label: str,
) -> dict[str, int]:
    """Copy every table shared by one SQLite source and PostgreSQL.

    Tables that exist only in SQLite are reported and skipped. For each shared
    table, only the columns present on both sides are copied; SQLite-only
    columns are reported so the data loss is visible.

    Args:
        sqlite_conn: Source SQLite connection.
        pg_conn: Target PostgreSQL connection.
        truncate: If true, TRUNCATE all shared tables first.
        batch_size: Rows per ``executemany`` call.
        skip_conflicts: Append ``ON CONFLICT DO NOTHING`` to each INSERT.
        source_label: Tag used in progress messages.

    Returns:
        Mapping of table name to number of rows sent (0 when no columns match).
    """
    prefix = f"[import:{source_label}]"
    imported: dict[str, int] = {}
    sqlite_tables = _sqlite_tables(sqlite_conn)
    pg_tables = _postgres_tables(pg_conn)
    common_tables = [table for table in sqlite_tables if table in pg_tables]
    missing_tables = [table for table in sqlite_tables if table not in pg_tables]

    if missing_tables:
        print(f"{prefix} skip tables absent in PostgreSQL: {', '.join(missing_tables)}")

    if truncate:
        print(f"{prefix} truncate {len(common_tables)} table(s)")
        _truncate_tables(pg_conn, common_tables)

    for table in common_tables:
        sqlite_cols = _sqlite_columns(sqlite_conn, table)
        pg_cols = _postgres_columns(pg_conn, table)
        sqlite_col_set = set(sqlite_cols)
        pg_col_set = set(pg_cols)
        columns = [col for col in pg_cols if col in sqlite_col_set]
        dropped = [col for col in sqlite_cols if col not in pg_col_set]
        if dropped:
            print(f"{prefix} {table}: skip columns absent in PostgreSQL: {', '.join(dropped)}")
        if not columns:
            imported[table] = 0
            continue

        count = _copy_table_rows(
            sqlite_conn,
            pg_conn,
            table,
            columns,
            batch_size=batch_size,
            skip_conflicts=skip_conflicts,
        )
        imported[table] = count
        print(f"{prefix} {table}: {count}")

    return imported
|
|
|
|
|
|
def import_sqlite(
    sqlite_path: Path,
    database_url: str | None = None,
    *,
    scheduler_sqlite_path: Path | None = None,
    truncate: bool = False,
    batch_size: int = 1000,
    skip_conflicts: bool = False,
) -> dict[str, int]:
    """Import the main SQLite DB (and optionally the scheduler DB) into PostgreSQL.

    Everything runs inside one PostgreSQL transaction:

    1. optional truncation;
    2. the main import, then the scheduler import;
    3. sequence realignment;
    4. post-import fixes.

    Args:
        sqlite_path: Main SQLite database; must exist.
        database_url: Connection string passed to ``connect``; ``None`` defers
            to its default.
        scheduler_sqlite_path: Optional second SQLite source. A path that does
            not exist is reported and skipped.
        truncate: Clear every target table that will receive data, once, before
            any source is imported.
        batch_size: Rows per insert batch.
        skip_conflicts: Use ``ON CONFLICT DO NOTHING`` on inserts.

    Returns:
        Mapping of table name to rows sent, summed over both sources.

    Raises:
        FileNotFoundError: If ``sqlite_path`` does not exist.
    """
    if not sqlite_path.exists():
        raise FileNotFoundError(f"SQLite database not found: {sqlite_path}")

    sqlite_conn = sqlite3.connect(str(sqlite_path))
    sqlite_conn.row_factory = sqlite3.Row
    imported: dict[str, int] = {}
    scheduler_conn: sqlite3.Connection | None = None

    try:
        if scheduler_sqlite_path and scheduler_sqlite_path.exists():
            scheduler_conn = sqlite3.connect(str(scheduler_sqlite_path))
            scheduler_conn.row_factory = sqlite3.Row
        elif scheduler_sqlite_path:
            print(f"[import:scheduler] skip missing scheduler db: {scheduler_sqlite_path}")

        sources: list[tuple[str, sqlite3.Connection]] = [("main", sqlite_conn)]
        if scheduler_conn is not None:
            sources.append(("scheduler", scheduler_conn))

        with connect(database_url) as pg_conn:
            with pg_conn.transaction():
                if truncate:
                    # Truncate the union of all target tables once, up front.
                    # Truncating per source would let the second source's
                    # TRUNCATE ... CASCADE wipe rows the first source had
                    # just inserted (shared or referencing tables).
                    pg_tables = _postgres_tables(pg_conn)
                    targets = sorted(
                        {
                            table
                            for _, source_conn in sources
                            for table in _sqlite_tables(source_conn)
                            if table in pg_tables
                        }
                    )
                    print(f"[import] truncate {len(targets)} table(s)")
                    _truncate_tables(pg_conn, targets)

                for label, source_conn in sources:
                    counts = _import_one_sqlite(
                        source_conn,
                        pg_conn,
                        truncate=False,
                        batch_size=batch_size,
                        skip_conflicts=skip_conflicts,
                        source_label=label,
                    )
                    # Sum rather than overwrite: a table present in both
                    # sources receives rows from each.
                    for table, count in counts.items():
                        imported[table] = imported.get(table, 0) + count

                _reset_sequences(pg_conn)
                _apply_post_import_fixes(pg_conn)
    finally:
        sqlite_conn.close()
        if scheduler_conn is not None:
            scheduler_conn.close()

    return imported
|
|
|
|
|
|
def _optional_path(value: str) -> Path | None:
    """argparse type: map a blank value to ``None`` (skip), otherwise a ``Path``.

    Needed because ``Path("")`` normalizes to ``Path(".")``; a blank string
    must be recognized before it becomes a ``Path``.
    """
    return Path(value) if value.strip() else None


def _positive_int(value: str) -> int:
    """argparse type: accept only integers >= 1."""
    number = int(value)
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value!r}")
    return number


def main(argv: list[str] | None = None) -> int:
    """Parse CLI arguments, run the import and print a summary.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Process exit code (0 on success; failures propagate as exceptions).
    """
    parser = argparse.ArgumentParser(description="Import AlphaX SQLite data into PostgreSQL.")
    parser.add_argument("--sqlite-path", type=Path, default=DEFAULT_SQLITE_PATH)
    parser.add_argument(
        "--scheduler-sqlite-path",
        # A blank value becomes None (skip the scheduler DB). The default is
        # already a Path, so argparse does not run it through this type.
        type=_optional_path,
        default=DEFAULT_SCHEDULER_SQLITE_PATH,
        help="Optional scheduler_state.db path. Use an empty string to skip.",
    )
    parser.add_argument("--database-url", default=None, help="Override DATABASE_URL.")
    parser.add_argument("--truncate", action="store_true", help="Clear target tables before import.")
    parser.add_argument("--batch-size", type=_positive_int, default=1000)
    parser.add_argument(
        "--skip-conflicts",
        action="store_true",
        help="Use ON CONFLICT DO NOTHING. Prefer --truncate for clean migrations.",
    )
    args = parser.parse_args(argv)

    imported = import_sqlite(
        args.sqlite_path,
        args.database_url,
        scheduler_sqlite_path=args.scheduler_sqlite_path,
        truncate=args.truncate,
        batch_size=args.batch_size,
        skip_conflicts=args.skip_conflicts,
    )
    print(f"[import] completed {len(imported)} table(s), {sum(imported.values())} row(s)")
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())
|