alphax/app/services/live_trading_account.py
2026-06-07 23:43:54 +08:00

361 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Account-centric read model for live trading console."""
from __future__ import annotations
from datetime import datetime
from app.db.live_trading import (
_safe_float,
get_live_account,
get_live_account_snapshot,
list_enabled_live_accounts,
list_live_order_events,
list_live_order_intents,
upsert_live_account_snapshot,
)
from app.integrations.binance_live import LiveTradingConfigError, build_binance_client
# In-memory cache of the latest exchange-backed overview, keyed by integer account id.
# Written by _cache_overview() and read by _cached_overview(); nothing in this module
# evicts or expires entries.
_ACCOUNT_OVERVIEW_CACHE: dict[int, dict] = {}
def _now() -> str:
    """Return the current local wall-clock time as an ISO-8601 string.

    The value comes from ``datetime.now()`` with no ``tz`` argument, so it is
    naive (carries no UTC offset).
    """
    current = datetime.now()
    return current.isoformat()
def _compact_balance(balance: dict) -> dict:
    """Reduce an exchange balance payload to non-zero assets plus a USDT summary.

    Args:
        balance: Mapping that may carry ``total`` / ``free`` / ``used``
            sub-mappings keyed by asset code. Any of the three that is missing
            or not a dict is treated as empty.
            (Presumably the ccxt ``fetch_balance`` structure -- confirm.)

    Returns:
        ``{"assets": [...], "usdt": {...}}`` where ``assets`` lists, in sorted
        asset order, every asset with a non-zero total, free or used amount,
        and ``usdt`` always reports the three USDT figures.
    """

    def _section(name: str) -> dict:
        # Anything that is not a dict is ignored rather than raising.
        section = balance.get(name)
        return section if isinstance(section, dict) else {}

    total = _section("total")
    free = _section("free")
    used = _section("used")

    assets = []
    for code in sorted({*total, *free, *used}):
        total_amount = _safe_float(total.get(code))
        free_amount = _safe_float(free.get(code))
        used_amount = _safe_float(used.get(code))
        # Skip assets whose three figures are all zero.
        if not any(abs(amount) > 0 for amount in (total_amount, free_amount, used_amount)):
            continue
        assets.append({"asset": code, "free": free_amount, "used": used_amount, "total": total_amount})

    usdt_summary = {
        "free": _safe_float(free.get("USDT")),
        "used": _safe_float(used.get("USDT")),
        "total": _safe_float(total.get("USDT")),
    }
    return {"assets": assets, "usdt": usdt_summary}
def _position_side_label(side: str) -> str:
    """Map a position/order side to a short display label.

    Args:
        side: Side token; compared case-insensitively after trimming.
            ``None`` and other falsy values are treated as an empty string.

    Returns:
        ``"做多"`` for ``long``/``buy``, ``"做空"`` for ``short``/``sell`` and
        ``"--"`` for anything else.
    """
    normalized = str(side or "").strip().lower()
    # Both known sides previously produced an empty string, so the label could
    # not distinguish long from short; give each a distinct, non-empty label.
    if normalized in {"long", "buy"}:
        return "做多"
    if normalized in {"short", "sell"}:
        return "做空"
    return "--"
def _compact_position(item: dict, account: dict | None = None) -> dict:
    """Flatten one exchange position record into the console's position row.

    Each figure is read from the top-level key first and falls back to the raw
    ``info`` sub-dict (ignored when it is not a dict).
    (Presumably a ccxt position with Binance raw fields under ``info`` -- confirm.)

    Args:
        item: Position record.
        account: Optional account row; its ``risk_config.max_symbol_leverage``
            is the last-resort leverage when neither the record nor the
            value/margin ratio yields one.

    Returns:
        Dict with symbol, side, side label, absolute contract size, prices,
        notional, position value, margin, unrealized PnL, leverage and
        ``leverage_source`` (``exchange`` / ``computed`` / ``account_config`` /
        ``missing``).
    """
    info = item.get("info") if isinstance(item.get("info"), dict) else {}
    contracts = _safe_float(item.get("contracts") or info.get("positionAmt"))
    notional = _safe_float(item.get("notional") or info.get("notional"))
    entry_price = _safe_float(item.get("entryPrice") or info.get("entryPrice"))
    mark_price = _safe_float(item.get("markPrice") or info.get("markPrice"))

    # Prefer the reported notional; otherwise derive the value from size * mark.
    position_value = abs(notional)
    if position_value <= 0 and abs(contracts) > 0 and mark_price > 0:
        position_value = abs(contracts) * mark_price

    margin = _safe_float(
        item.get("initialMargin")
        or item.get("collateral")
        or info.get("initialMargin")
        or info.get("positionInitialMargin")
        or info.get("isolatedWallet")
    )

    # Leverage resolution order: reported -> value / margin -> account risk config.
    leverage = _safe_float(item.get("leverage") or info.get("leverage"))
    leverage_source = "exchange"
    if leverage <= 0 and position_value > 0 and margin > 0:
        leverage = position_value / margin
        leverage_source = "computed"
    if leverage <= 0 and account:
        risk = account.get("risk_config") if isinstance(account.get("risk_config"), dict) else {}
        leverage = _safe_float(risk.get("max_symbol_leverage"), 0)
        leverage_source = "account_config"
    if leverage <= 0:
        # Nothing produced a usable leverage. Report that explicitly, including
        # when no account was supplied (previously this case still said "exchange").
        leverage_source = "missing"

    # Fall back to the sign of the contract size when no side is reported.
    side = item.get("side") or ("long" if contracts > 0 else ("short" if contracts < 0 else ""))
    return {
        "symbol": item.get("symbol") or info.get("symbol"),
        "side": side,
        "side_label": _position_side_label(side),
        "contracts": abs(contracts),
        "entry_price": entry_price,
        "mark_price": mark_price,
        "notional": notional,
        "position_value_usdt": position_value,
        "margin_usdt": margin,
        "unrealized_pnl": _safe_float(item.get("unrealizedPnl") or info.get("unrealizedProfit")),
        "leverage": leverage,
        "leverage_source": leverage_source,
    }
def _compact_order(item: dict) -> dict:
    """Flatten one exchange order record into the console's order row.

    Every field is taken from the top-level key when truthy and otherwise from
    the matching key of the raw ``info`` sub-dict (ignored when not a dict).

    Returns:
        Dict with ``id`` (always a string), ``client_order_id``, ``symbol``,
        ``type``, ``side``, ``status``, numeric ``price`` / ``amount`` /
        ``filled`` / ``average`` and a ``timestamp`` that is the first truthy
        value among ``datetime``, ``timestamp``, ``info.updateTime`` and
        ``info.time``.
    """
    raw = item.get("info")
    info = raw if isinstance(raw, dict) else {}

    def _pick(top_key: str, raw_key: str):
        # Top-level value wins; fall back to the raw exchange field.
        return item.get(top_key) or info.get(raw_key)

    order_id = str(_pick("id", "orderId") or "")
    client_order_id = _pick("clientOrderId", "clientOrderId") or ""
    when = item.get("datetime") or item.get("timestamp") or info.get("updateTime") or info.get("time")
    return {
        "id": order_id,
        "client_order_id": client_order_id,
        "symbol": _pick("symbol", "symbol"),
        "type": _pick("type", "type"),
        "side": _pick("side", "side"),
        "status": _pick("status", "status"),
        "price": _safe_float(_pick("price", "price")),
        "amount": _safe_float(_pick("amount", "origQty")),
        "filled": _safe_float(_pick("filled", "executedQty")),
        "average": _safe_float(_pick("average", "avgPrice")),
        "timestamp": when,
    }
def _normalize_symbol(symbol: str) -> str:
    """Upper-case a symbol and insert the ``/`` before a trailing ``USDT`` quote.

    ``"btcusdt"`` becomes ``"BTC/USDT"``. Values that already contain ``/`` or
    do not end in ``USDT`` are only trimmed and upper-cased. Falsy input yields
    ``""``.
    """
    quote = "USDT"
    value = str(symbol or "").strip().upper()
    # Require a non-empty base asset: a bare "USDT" used to become "/USDT".
    if "/" not in value and len(value) > len(quote) and value.endswith(quote):
        value = f"{value[:-len(quote)]}/{quote}"
    return value
def _order_history_symbols(account: dict, overview: dict) -> list[str]:
    """Build the smallest safe symbol set for Binance order-history queries.

    Candidates are gathered, in this order, from the account's
    ``risk_config.allowed_symbols``, the overview's positions and open orders,
    and the overview's intent history. Each is normalized, blanks are dropped,
    duplicates keep their first occurrence, and at most 20 symbols are returned.
    """
    risk_config = account.get("risk_config")
    risk = risk_config if isinstance(risk_config, dict) else {}

    candidates = list(risk.get("allowed_symbols") or [])
    exchange_rows = (overview.get("positions") or []) + (overview.get("open_orders") or [])
    candidates.extend(row.get("symbol") for row in exchange_rows)
    candidates.extend(row.get("symbol") for row in overview.get("intent_history") or [])

    # Insertion-ordered dict gives first-seen de-duplication.
    unique: dict[str, str] = {}
    for raw in candidates:
        normalized = _normalize_symbol(raw)
        if normalized:
            unique.setdefault(normalized.upper(), normalized)
    return list(unique.values())[:20]
def _fetch_order_history_by_symbol(client, symbols: list[str], limit: int) -> tuple[list[dict], list[str]]:
    """Fetch recent orders symbol by symbol and merge them newest-first.

    Args:
        client: Object exposing ``fetch_orders(symbol, limit=...)``.
        symbols: Symbols to query; an empty list short-circuits to ``([], [])``.
        limit: Maximum number of orders returned overall (falsy means 30,
            minimum 1). The per-symbol request size is the same value capped
            at 50.

    Returns:
        ``(orders, errors)``: compacted orders sorted by the string form of
        their timestamp, descending, and one message per symbol whose fetch
        failed. A failing symbol does not abort the remaining symbols.
    """
    orders: list[dict] = []
    errors: list[str] = []
    if not symbols:
        return orders, errors

    overall_limit = max(1, int(limit or 30))
    per_symbol_limit = min(overall_limit, 50)
    for symbol in symbols:
        try:
            orders.extend(_compact_order(o) for o in client.fetch_orders(symbol, limit=per_symbol_limit))
        except Exception as exc:
            # Best effort: record the failure and move on. The colon separates
            # the symbol from the exception text, which used to run together.
            errors.append(f"订单历史读取失败 {symbol}:{exc}")
    orders.sort(key=lambda row: str(row.get("timestamp") or ""), reverse=True)
    return orders[:overall_limit], errors
def _account_risk_view(account: dict) -> dict:
    """Project the account's ``risk_config`` into the risk summary for the console.

    A ``risk_config`` that is missing or not a dict is treated as empty.

    Returns:
        Dict with the per-order margin cap, per-symbol leverage cap (default 1),
        per-order notional cap (default: margin * max(1, leverage)), cumulative
        leverage cap (default 1), daily order count (default 0), the upper-cased
        allow-list and ``symbol_policy`` (``"all"`` when the allow-list is
        empty, otherwise ``"allowlist"``).
    """
    risk = account.get("risk_config") if isinstance(account.get("risk_config"), dict) else {}
    # `or []` also covers an explicit None, matching _order_history_symbols;
    # `.get(key, [])` alone would raise TypeError when iterating None.
    allowed = [str(x).strip().upper() for x in (risk.get("allowed_symbols") or []) if str(x).strip()]
    max_leverage = _safe_float(risk.get("max_symbol_leverage"), 1)
    margin = _safe_float(risk.get("max_order_margin_usdt"), 0)
    return {
        "max_order_margin_usdt": margin,
        "max_symbol_leverage": max_leverage,
        "max_order_notional_usdt": _safe_float(risk.get("max_order_notional_usdt"), margin * max(1.0, max_leverage)),
        "max_cumulative_leverage": _safe_float(risk.get("max_cumulative_leverage"), 1),
        "max_daily_order_count": int(risk.get("max_daily_order_count") or 0),
        "allowed_symbols": allowed,
        "symbol_policy": "all" if not allowed else "allowlist",
    }
def _cache_overview(account_id: int, overview: dict) -> dict:
    """Store ``overview`` in the module cache under the integer account id.

    The same dict object is stored and returned; no copy is made.
    """
    cache_key = int(account_id)
    _ACCOUNT_OVERVIEW_CACHE[cache_key] = overview
    return overview
def _cached_overview(account_id: int) -> dict | None:
    """Return a shallow copy of the cached overview flagged as cached, or ``None``.

    The copy's ``exchange_cache`` is a new dict carrying the stored metadata
    plus ``"cached": True``; the stored entry itself is left unmodified.
    """
    stored = _ACCOUNT_OVERVIEW_CACHE.get(int(account_id))
    if stored:
        previous_meta = stored.get("exchange_cache") or {}
        return {**stored, "exchange_cache": {**previous_meta, "cached": True}}
    return None
def _exchange_snapshot_payload(overview: dict) -> dict:
    """Extract the exchange-derived portion of an overview for persistence.

    Only ``balance``, ``positions``, ``open_orders``, ``order_history``,
    ``exchange_cache`` and ``errors`` are kept. Missing or falsy values are
    replaced by an empty default (a zeroed balance, ``[]`` or ``{}``).
    """
    zero_balance = {"assets": [], "usdt": {"free": 0, "used": 0, "total": 0}}
    payload = {"balance": overview.get("balance") or zero_balance}
    for list_key in ("positions", "open_orders", "order_history"):
        payload[list_key] = overview.get(list_key) or []
    payload["exchange_cache"] = overview.get("exchange_cache") or {}
    payload["errors"] = overview.get("errors") or []
    return payload
def _merge_snapshot(overview: dict, snapshot_row: dict) -> dict:
    """Overlay a persisted exchange snapshot onto ``overview`` and return it.

    ``overview`` is modified in place. When the row has no usable ``snapshot``
    dict the overview is returned untouched.

    Args:
        overview: Overview dict to update.
        snapshot_row: Stored row with ``snapshot`` (payload dict), ``status``,
            ``synced_at`` and ``error_message``.

    Returns:
        The same ``overview`` object, with the snapshot's balance, positions,
        open orders, order history and errors copied in, and ``exchange_cache``
        rebuilt to mark the data as cached and sourced from the database.
    """
    payload = snapshot_row.get("snapshot") if isinstance(snapshot_row.get("snapshot"), dict) else {}
    if not payload:
        return overview

    for key in ("balance", "positions", "open_orders", "order_history", "errors"):
        if key in payload:
            overview[key] = payload.get(key)

    status = snapshot_row.get("status") or ""
    error_message = snapshot_row.get("error_message") or ""
    stored_meta = payload.get("exchange_cache") or {}
    overview["exchange_cache"] = {
        **stored_meta,
        "cached": True,
        "loaded": bool(stored_meta.get("loaded", True)),
        "requires_refresh": False,
        "source": "database",
        "status": status,
        "synced_at": snapshot_row.get("synced_at") or "",
    }

    if status == "error" and error_message:
        # Work on a copy: the stored list may be None (which used to crash on
        # .append) and must not be mutated inside the snapshot row itself.
        errors = list(overview.get("errors") or [])
        if error_message not in errors:
            errors.append(error_message)
        overview["errors"] = errors
    return overview
def get_live_account_overview(account_id: int, *, history_limit: int = 30, refresh: bool = False, client_factory=None) -> dict:
    """Build the console overview for one live trading account.

    Without ``refresh`` no exchange call is made: the persisted snapshot is
    used if present, then the in-memory cache, and otherwise an empty overview
    flagged as requiring a refresh. With ``refresh=True`` the exchange is
    queried, the result is persisted as the account snapshot and cached.

    Args:
        account_id: Id of the live account.
        history_limit: Row limit for intent history, events and order history.
        refresh: When true, query the exchange instead of stored data.
        client_factory: Optional callable ``(account) -> client`` used in place
            of ``build_binance_client(account, require_testnet=True)``.

    Returns:
        Overview dict with ``account``, ``risk``, ``balance``, ``positions``,
        ``open_orders``, ``order_history``, ``intent_history``, ``events``,
        ``exchange_cache`` and ``errors``. Accounts whose status is not
        ``"enabled"`` get the empty overview with no exchange or snapshot access.

    Raises:
        LiveTradingConfigError: If the account does not exist.
    """
    account = get_live_account(account_id)
    if not account:
        raise LiveTradingConfigError("live account not found")
    overview = {
        "account": account,
        "risk": _account_risk_view(account),
        "balance": {"assets": [], "usdt": {"free": 0, "used": 0, "total": 0}},
        "positions": [],
        "open_orders": [],
        "order_history": [],
        "intent_history": list_live_order_intents(limit=history_limit, account_id=account_id).get("items", []),
        # NOTE(review): unlike the intents above, events are not filtered by
        # account_id here -- confirm whether that is intended.
        "events": list_live_order_events(limit=history_limit).get("items", []),
        "exchange_cache": {"cached": False, "loaded": False, "requires_refresh": True},
        "errors": [],
    }
    if account.get("status") != "enabled":
        return overview

    if not refresh:
        snapshot = get_live_account_snapshot(account_id)
        if snapshot:
            return _merge_snapshot(overview, snapshot)
        cached = _cached_overview(account_id)
        if cached:
            # Exchange data comes from the cache; database-backed fields are
            # replaced with the values just loaded.
            cached["account"] = account
            cached["risk"] = overview["risk"]
            cached["intent_history"] = overview["intent_history"]
            cached["events"] = overview["events"]
            return cached
        overview["exchange_cache"]["reason"] = "等待后台实盘同步生成账户快照"
        return overview

    synced_at = _now()
    try:
        client = client_factory(account) if client_factory else build_binance_client(account, require_testnet=True)
        client.load_markets()
    except Exception as exc:
        # Connection failure: persist an error snapshot and stop; nothing else
        # can be fetched without a client.
        overview["errors"].append(f"账户连接失败:{exc}")
        overview["exchange_cache"] = {
            "cached": False,
            "loaded": False,
            "requires_refresh": True,
            "source": "exchange",
            "status": "error",
            "synced_at": synced_at,
        }
        upsert_live_account_snapshot(
            account_id,
            _exchange_snapshot_payload(overview),
            status="error",
            error_message=overview["errors"][0],
            synced_at=synced_at,
        )
        return overview

    # Each section is fetched independently so one failure does not hide the rest.
    try:
        overview["balance"] = _compact_balance(client.fetch_balance())
    except Exception as exc:
        overview["errors"].append(f"余额读取失败:{exc}")
    try:
        overview["positions"] = [
            item for item in (_compact_position(p, account) for p in client.fetch_positions(None))
            if abs(_safe_float(item.get("contracts"))) > 0
        ]
    except Exception as exc:
        overview["errors"].append(f"持仓读取失败:{exc}")
    try:
        overview["open_orders"] = [_compact_order(o) for o in client.fetch_open_orders(None)]
    except Exception as exc:
        overview["errors"].append(f"挂单读取失败:{exc}")

    symbols = _order_history_symbols(account, overview)
    order_history, order_history_errors = _fetch_order_history_by_symbol(client, symbols, history_limit)
    overview["order_history"] = order_history
    overview["errors"].extend(order_history_errors)

    sync_status = "error" if overview["errors"] else "ok"
    overview["exchange_cache"] = {
        "cached": False,
        "loaded": True,
        "requires_refresh": False,
        "source": "exchange",
        "status": sync_status,
        "synced_at": synced_at,
    }
    upsert_live_account_snapshot(
        account_id,
        _exchange_snapshot_payload(overview),
        status=sync_status,
        # Separate the (at most three) messages; an empty separator ran them
        # together into one unreadable string.
        error_message="; ".join(overview["errors"][:3]),
        synced_at=synced_at,
    )
    return _cache_overview(account_id, overview)
def sync_live_account_snapshots(*, account_ids: list[int] | None = None, history_limit: int = 30, client_factory=None) -> dict:
    """Refresh the exchange snapshot of enabled live accounts and summarize the run.

    Args:
        account_ids: Optional ids to restrict the run to. Only positive ids
            count; when none remain, every enabled account is synced.
        history_limit: Passed through to ``get_live_account_overview``.
        client_factory: Passed through to ``get_live_account_overview``.

    Returns:
        ``{"ok", "total", "ok_count", "error_count", "items"}`` where ``items``
        holds one entry per account with its id, code, status and errors
        (plus ``synced_at`` when the overview call returned). An exception from
        one account is recorded as an error entry and does not stop the others.
    """
    wanted_ids = {int(raw) for raw in (account_ids or []) if int(raw or 0) > 0}
    accounts = list_enabled_live_accounts()
    if wanted_ids:
        accounts = [acct for acct in accounts if int(acct.get("id") or 0) in wanted_ids]

    items = []
    for account in accounts:
        try:
            overview = get_live_account_overview(
                account["id"],
                history_limit=history_limit,
                refresh=True,
                client_factory=client_factory,
            )
            cache_meta = overview.get("exchange_cache") or {}
            # Prefer the status recorded by the overview; otherwise infer it
            # from whether any errors were collected.
            status = cache_meta.get("status") or ("error" if overview.get("errors") else "ok")
            entry = {
                "account_id": account["id"],
                "account_code": account.get("account_code"),
                "status": status,
                "synced_at": cache_meta.get("synced_at"),
                "errors": overview.get("errors") or [],
            }
        except Exception as exc:
            entry = {
                "account_id": account.get("id"),
                "account_code": account.get("account_code"),
                "status": "error",
                "errors": [str(exc)],
            }
        items.append(entry)

    # Anything that is not exactly "ok" counts as an error.
    ok_count = sum(1 for entry in items if entry["status"] == "ok")
    error_count = len(items) - ok_count
    return {
        "ok": error_count == 0,
        "total": len(accounts),
        "ok_count": ok_count,
        "error_count": error_count,
        "items": items,
    }