alphax/app/services/live_trading_account.py
2026-06-07 23:53:07 +08:00

491 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Account-centric read model for live trading console."""
from __future__ import annotations
from datetime import datetime
from app.db.live_trading import (
_safe_float,
get_live_account,
get_live_account_snapshot,
list_live_account_equity_history,
list_enabled_live_accounts,
list_live_order_events,
list_live_order_intents,
record_live_account_equity_snapshot,
upsert_live_account_snapshot,
)
from app.integrations.binance_live import LiveTradingConfigError, build_binance_client
# Process-local cache of the latest exchange-refreshed overview, keyed by account id.
# Written by _cache_overview() at the end of get_live_account_overview(refresh=True)
# and read by _cached_overview() as a fallback when no database snapshot exists.
# Entries are never evicted in this module.
_ACCOUNT_OVERVIEW_CACHE: dict[int, dict] = {}
def _now() -> str:
    """Return the current local wall-clock time as a naive ISO-8601 string."""
    moment = datetime.now()
    return moment.isoformat()
def _compact_balance(balance: dict) -> dict:
    """Reduce a raw balance payload to non-zero asset rows plus a USDT summary.

    ``balance`` is read as holding ``total``/``free``/``used`` mappings keyed by
    asset code (the payload of the client's ``fetch_balance()``, presumably
    ccxt's unified balance structure). A mapping that is missing or is not a
    dict is treated as empty.

    Returns:
        ``{"assets": [...], "usdt": {"free", "used", "total"}}``. ``assets`` is
        sorted by asset code and omits assets whose free, used and total
        amounts are all zero.
    """
    def bucket(name: str) -> dict:
        value = balance.get(name)
        return value if isinstance(value, dict) else {}

    total, free, used = bucket("total"), bucket("free"), bucket("used")

    rows = []
    for code in sorted(set(total).union(free, used)):
        row = {
            "asset": code,
            "free": _safe_float(free.get(code)),
            "used": _safe_float(used.get(code)),
            "total": _safe_float(total.get(code)),
        }
        # Skip assets that only appear with zero amounts everywhere.
        if any(abs(row[field]) > 0 for field in ("total", "free", "used")):
            rows.append(row)

    usdt_summary = {
        "free": _safe_float(free.get("USDT")),
        "used": _safe_float(used.get("USDT")),
        "total": _safe_float(total.get("USDT")),
    }
    return {"assets": rows, "usdt": usdt_summary}
def _position_side_label(side: str) -> str:
    """Map a position/order side to its display label.

    ``long``/``buy`` map to the long label and ``short``/``sell`` to the short
    label. Anything else, including ``None`` or an empty string, maps to
    ``"--"``. Matching ignores case and surrounding whitespace.
    """
    normalized = str(side or "").strip().lower()
    # Each side needs a distinct, non-empty label; returning the same empty
    # string for both made long and short indistinguishable in the console.
    if normalized in {"long", "buy"}:
        return "做多"
    if normalized in {"short", "sell"}:
        return "做空"
    return "--"
def _position_pnl_pct(unrealized_pnl: float, margin: float, position_value: float) -> float:
    """Express unrealized PnL as a percentage, rounded to 6 decimals.

    The denominator is ``margin`` when positive, otherwise ``position_value``
    when positive. When neither is positive the result is ``0.0``.
    """
    if margin > 0:
        denominator = margin
    elif position_value > 0:
        denominator = position_value
    else:
        return 0.0
    return round(unrealized_pnl / denominator * 100.0, 6)
def _compact_position(item: dict, account: dict | None = None) -> dict:
    """Normalize one exchange position into the console's flat position row.

    Each value is read from the unified field on ``item`` first and falls back
    to the raw exchange payload in ``item["info"]`` (Binance-style names such
    as ``positionAmt``, ``isolatedWallet`` and ``unrealizedProfit``).

    Leverage is resolved in this order:
    1. the exchange-reported value;
    2. ``position_value / margin``;
    3. the account's ``risk_config["max_symbol_leverage"]``.
    ``leverage_source`` records which step produced it (``"exchange"``,
    ``"computed"`` or ``"account_config"``), or ``"missing"`` when none
    yielded a positive value.

    Args:
        item: one entry from the client's ``fetch_positions`` result.
        account: optional account row used for the configured-leverage fallback.

    Returns:
        A dict with symbol, side, absolute contract size, prices, notional,
        position value, margin, unrealized PnL and leverage fields.
    """
    info = item.get("info") if isinstance(item.get("info"), dict) else {}
    contracts = _safe_float(item.get("contracts") or info.get("positionAmt"))
    notional = _safe_float(item.get("notional") or info.get("notional"))
    entry_price = _safe_float(item.get("entryPrice") or info.get("entryPrice"))
    mark_price = _safe_float(item.get("markPrice") or info.get("markPrice"))

    position_value = abs(notional)
    if position_value <= 0 and abs(contracts) > 0 and mark_price > 0:
        # No notional reported: estimate it from size x mark price.
        position_value = abs(contracts) * mark_price

    margin = _safe_float(
        item.get("initialMargin")
        or item.get("collateral")
        or info.get("initialMargin")
        or info.get("positionInitialMargin")
        or info.get("isolatedWallet")
    )

    leverage = _safe_float(item.get("leverage") or info.get("leverage"))
    leverage_source = "exchange"
    if leverage <= 0 and position_value > 0 and margin > 0:
        leverage = position_value / margin
        leverage_source = "computed"
    if leverage <= 0 and account:
        risk = account.get("risk_config") if isinstance(account.get("risk_config"), dict) else {}
        leverage = _safe_float(risk.get("max_symbol_leverage"), 0)
        leverage_source = "account_config" if leverage > 0 else "missing"
    if leverage <= 0:
        # Nothing produced a usable leverage (e.g. no account was passed), so
        # do not keep claiming the value came from the exchange.
        leverage_source = "missing"

    # Without an explicit side, infer it from the sign of the contract amount.
    side = item.get("side") or ("long" if contracts > 0 else ("short" if contracts < 0 else ""))
    return {
        "symbol": item.get("symbol") or info.get("symbol"),
        "side": side,
        "side_label": _position_side_label(side),
        "contracts": abs(contracts),
        "entry_price": entry_price,
        "mark_price": mark_price,
        "notional": notional,
        "position_value_usdt": position_value,
        "margin_usdt": margin,
        "unrealized_pnl": _safe_float(item.get("unrealizedPnl") or info.get("unrealizedProfit")),
        "leverage": leverage,
        "leverage_source": leverage_source,
    }
def _compact_order(item: dict) -> dict:
    """Normalize one exchange order into the console's flat order row.

    Each field prefers the unified value on ``item`` and falls back to the raw
    exchange payload in ``item["info"]`` (Binance-style names such as
    ``orderId``, ``origQty``, ``executedQty`` and ``avgPrice``). Numeric fields
    go through ``_safe_float``. ``id`` is always a string.
    """
    raw = item.get("info")
    raw = raw if isinstance(raw, dict) else {}

    def first(*candidates):
        # Same result as chaining ``a or b or c``: the first truthy candidate,
        # otherwise the last one.
        for candidate in candidates[:-1]:
            if candidate:
                return candidate
        return candidates[-1]

    return {
        "id": str(first(item.get("id"), raw.get("orderId"), "")),
        "client_order_id": first(item.get("clientOrderId"), raw.get("clientOrderId"), ""),
        "symbol": first(item.get("symbol"), raw.get("symbol")),
        "type": first(item.get("type"), raw.get("type")),
        "side": first(item.get("side"), raw.get("side")),
        "status": first(item.get("status"), raw.get("status")),
        "price": _safe_float(first(item.get("price"), raw.get("price"))),
        "amount": _safe_float(first(item.get("amount"), raw.get("origQty"))),
        "filled": _safe_float(first(item.get("filled"), raw.get("executedQty"))),
        "average": _safe_float(first(item.get("average"), raw.get("avgPrice"))),
        "realized_pnl": _safe_float(
            first(item.get("realizedPnl"), raw.get("realizedPnl"), raw.get("realizedProfit"))
        ),
        "reduce_only": bool(first(item.get("reduceOnly"), raw.get("reduceOnly"))),
        "position_side": first(item.get("positionSide"), raw.get("positionSide"), ""),
        "timestamp": first(
            item.get("datetime"), item.get("timestamp"), raw.get("updateTime"), raw.get("time")
        ),
    }
def _enrich_positions(positions: list[dict]) -> list[dict]:
    """Return shallow copies of ``positions`` annotated with PnL fields.

    Each copy gains ``pnl_pct`` (via ``_position_pnl_pct``) and ``pnl_status``,
    which is ``"profit"``, ``"loss"`` or ``"flat"`` depending on the sign of
    ``unrealized_pnl``. The input rows are not modified.
    """
    def annotate(position: dict) -> dict:
        copy = dict(position)
        pnl = _safe_float(copy.get("unrealized_pnl"))
        copy["pnl_pct"] = _position_pnl_pct(
            pnl,
            _safe_float(copy.get("margin_usdt")),
            _safe_float(copy.get("position_value_usdt")),
        )
        if pnl > 0:
            status = "profit"
        elif pnl < 0:
            status = "loss"
        else:
            status = "flat"
        copy["pnl_status"] = status
        return copy

    return [annotate(position) for position in positions or []]
def _normalize_symbol(symbol: str) -> str:
    """Upper-case ``symbol`` and rewrite a bare ``<BASE>USDT`` pair as ``<BASE>/USDT``.

    Symbols that already contain ``/`` (e.g. ``BTC/USDT:USDT``) are only
    stripped and upper-cased. ``None`` or empty input yields ``""``.
    """
    quote = "USDT"
    value = str(symbol or "").strip().upper()
    # Require a non-empty base so a lone "USDT" is not turned into "/USDT".
    if "/" not in value and value.endswith(quote) and len(value) > len(quote):
        value = value[: -len(quote)] + "/" + quote
    return value
def _order_history_symbols(account: dict, overview: dict) -> list[str]:
    """Build the smallest safe symbol set for Binance order-history queries.

    Symbols are gathered, in priority order, from:
    1. the account's ``risk_config["allowed_symbols"]``;
    2. current positions and open orders;
    3. intent history.
    Each is normalized with ``_normalize_symbol``. Empty results are dropped,
    duplicates are removed case-insensitively (the first occurrence wins), and
    at most 20 symbols are returned.
    """
    risk_config = account.get("risk_config")
    if not isinstance(risk_config, dict):
        risk_config = {}

    normalized = [_normalize_symbol(raw) for raw in risk_config.get("allowed_symbols") or []]
    live_rows = (overview.get("positions") or []) + (overview.get("open_orders") or [])
    normalized += [_normalize_symbol(row.get("symbol")) for row in live_rows]
    normalized += [_normalize_symbol(row.get("symbol")) for row in overview.get("intent_history") or []]

    unique_by_key = {}
    for candidate in normalized:
        if candidate:
            unique_by_key.setdefault(candidate.upper(), candidate)
    return list(unique_by_key.values())[:20]
def _fetch_order_history_by_symbol(client, symbols: list[str], limit: int) -> tuple[list[dict], list[str]]:
    """Fetch recent orders for each symbol and merge them newest-first.

    Args:
        client: exchange client exposing ``fetch_orders(symbol, limit=...)``.
        symbols: symbols to query, one request each.
        limit: maximum number of rows returned overall; a falsy value means 30.
            Each per-symbol request asks for ``limit`` clamped to 1..50.

    Returns:
        ``(orders, errors)``. ``orders`` holds compacted orders sorted by the
        string form of their ``timestamp`` in descending order, truncated to
        ``max(1, limit)``. ``errors`` has one message per symbol whose request
        raised.
    """
    orders: list[dict] = []
    errors: list[str] = []
    if not symbols:
        return orders, errors
    requested = int(limit or 30)
    per_symbol_limit = max(1, min(requested, 50))
    for symbol in symbols:
        try:
            fetched = client.fetch_orders(symbol, limit=per_symbol_limit)
            orders.extend(_compact_order(o) for o in fetched)
        except Exception as exc:  # one failing symbol must not abort the others
            # Separate symbol and error like the other read-failure messages.
            errors.append(f"订单历史读取失败 {symbol}:{exc}")
    orders.sort(key=lambda x: str(x.get("timestamp") or ""), reverse=True)
    return orders[: max(1, requested)], errors
def _account_risk_view(account: dict) -> dict:
    """Summarize the account's ``risk_config`` for display.

    Fallbacks passed to ``_safe_float``: symbol leverage 1, per-order margin 0,
    cumulative leverage 1. ``max_order_notional_usdt`` falls back to
    ``margin * max(1, max_symbol_leverage)``. A missing daily order count
    becomes 0. An empty allowlist yields ``symbol_policy == "all"``.
    """
    risk = account.get("risk_config") if isinstance(account.get("risk_config"), dict) else {}
    # ``or []`` so an explicit null allowlist behaves like a missing one
    # (as in _order_history_symbols) instead of raising TypeError.
    allowed = [str(x).strip().upper() for x in (risk.get("allowed_symbols") or []) if str(x).strip()]
    max_leverage = _safe_float(risk.get("max_symbol_leverage"), 1)
    margin = _safe_float(risk.get("max_order_margin_usdt"), 0)
    return {
        "max_order_margin_usdt": margin,
        "max_symbol_leverage": max_leverage,
        "max_order_notional_usdt": _safe_float(risk.get("max_order_notional_usdt"), margin * max(1.0, max_leverage)),
        "max_cumulative_leverage": _safe_float(risk.get("max_cumulative_leverage"), 1),
        "max_daily_order_count": int(risk.get("max_daily_order_count") or 0),
        "allowed_symbols": allowed,
        "symbol_policy": "all" if not allowed else "allowlist",
    }
def _cache_overview(account_id: int, overview: dict) -> dict:
    """Store ``overview`` as the account's latest in-process overview and hand it back.

    The same object is stored and returned (no copy).
    """
    _ACCOUNT_OVERVIEW_CACHE.update({int(account_id): overview})
    return overview
def _cached_overview(account_id: int) -> dict | None:
    """Return a shallow copy of the account's in-process cached overview, or ``None``.

    The copy gets a new ``exchange_cache`` dict with ``cached`` forced to
    ``True``; the stored entry itself is left untouched. A falsy entry counts
    as a miss.
    """
    entry = _ACCOUNT_OVERVIEW_CACHE.get(int(account_id))
    if entry:
        snapshot = dict(entry)
        previous_meta = snapshot.get("exchange_cache") or {}
        snapshot["exchange_cache"] = {**previous_meta, "cached": True}
        return snapshot
    return None
def _exchange_snapshot_payload(overview: dict) -> dict:
    """Select the exchange-derived sections of ``overview`` for persistence.

    Falsy or missing sections are replaced with fresh empty defaults: an empty
    balance with zero USDT, empty lists, or empty dicts. Present sections are
    passed through by reference. Other overview keys (account, risk, intents,
    events) are not included.
    """
    sections = (
        ("balance", lambda: {"assets": [], "usdt": {"free": 0, "used": 0, "total": 0}}),
        ("positions", list),
        ("open_orders", list),
        ("order_history", list),
        ("exchange_cache", dict),
        ("errors", list),
        ("performance", dict),
        ("historical_positions", list),
    )
    return {name: overview.get(name) or make_default() for name, make_default in sections}
def _current_equity_metrics(overview: dict) -> dict:
    """Derive the current equity figures from the overview's balance and positions.

    Equity, available and used amounts come from ``balance["usdt"]``
    (``total``/``free``/``used``). Wallet balance is equity minus the summed
    unrealized PnL of ``positions``. Monetary values are rounded to 8 decimals.
    """
    balance = overview.get("balance")
    if not isinstance(balance, dict):
        balance = {}
    usdt = balance.get("usdt")
    if not isinstance(usdt, dict):
        usdt = {}
    positions = overview.get("positions") or []

    equity = _safe_float(usdt.get("total"))
    available = _safe_float(usdt.get("free"))
    used = _safe_float(usdt.get("used"))
    unrealized = sum(_safe_float(row.get("unrealized_pnl")) for row in positions)
    exposure = sum(abs(_safe_float(row.get("position_value_usdt"))) for row in positions)

    def money(amount):
        return round(amount, 8)

    return {
        "equity_usdt": money(equity),
        "wallet_balance_usdt": money(equity - unrealized),
        "available_usdt": money(available),
        "used_margin_usdt": money(used),
        "unrealized_pnl_usdt": money(unrealized),
        "open_position_value_usdt": money(exposure),
        "position_count": len(positions),
    }
def _max_drawdown(history: list[dict]) -> tuple[float, float]:
    """Compute the maximum peak-to-trough drawdown over an equity history.

    Points whose ``equity_usdt`` is not positive are skipped. Returns
    ``(max_drawdown_pct, max_drawdown_usdt)``, rounded to 6 and 8 decimals
    respectively. The USDT figure belongs to the point with the largest
    percentage drawdown.
    """
    running_peak = 0.0
    worst_pct = 0.0
    worst_usdt = 0.0
    for point in history:
        equity = _safe_float(point.get("equity_usdt"))
        if equity <= 0:
            continue
        running_peak = max(running_peak, equity)
        # Still needed: a NaN equity passes the check above but leaves the
        # peak at 0.0, and dividing by it would raise.
        if running_peak <= 0:
            continue
        gap = max(0.0, running_peak - equity)
        gap_pct = gap / running_peak * 100.0
        if gap_pct > worst_pct:
            worst_pct, worst_usdt = gap_pct, gap
    return round(worst_pct, 6), round(worst_usdt, 8)
def _historical_positions_from_orders(orders: list[dict]) -> list[dict]:
    """Turn completed orders into display rows for the "historical positions" table.

    Only orders whose status (case-insensitive) is ``closed`` or ``filled`` are
    kept, in input order, capped at 30 rows. ``result`` labels the realized
    PnL as profit, loss, or unknown when it is zero.
    """
    def is_completed(order: dict) -> bool:
        return str(order.get("status") or "").lower() in {"closed", "filled"}

    def outcome(pnl: float) -> str:
        if pnl > 0:
            return "盈利"
        if pnl < 0:
            return "亏损"
        return "未知"

    def to_row(order: dict) -> dict:
        pnl = _safe_float(order.get("realized_pnl"))
        return {
            "time": order.get("timestamp") or "",
            "symbol": order.get("symbol") or "",
            "side": order.get("side") or "",
            "side_label": _position_side_label(order.get("side")),
            "price": _safe_float(order.get("average") or order.get("price")),
            "amount": _safe_float(order.get("filled") or order.get("amount")),
            "realized_pnl": pnl,
            "result": outcome(pnl),
            "status": order.get("status") or "",
        }

    completed = [to_row(order) for order in orders or [] if is_completed(order)]
    return completed[:30]
def _attach_performance(overview: dict, account_id: int, *, record_history: bool = False, snapshot_at: str = "") -> dict:
    """Enrich ``overview`` in place with PnL-annotated positions and a performance summary.

    Args:
        overview: the overview being built. ``positions``, ``performance`` and
            ``historical_positions`` are (re)written.
        account_id: account whose stored equity history is read and, when
            ``record_history`` is set, appended to.
        record_history: persist the current equity point via
            ``record_live_account_equity_snapshot``; this only happens when
            the current equity is positive.
        snapshot_at: timestamp for the current point; defaults to ``_now()``.

    Returns:
        The same ``overview`` object.
    """
    overview["positions"] = _enrich_positions(overview.get("positions") or [])
    metrics = _current_equity_metrics(overview)
    current_equity = metrics["equity_usdt"]
    # Resolve the timestamp once so the recorded point, a synthesized history
    # point and the drawdown probe all carry the same time.
    stamp = snapshot_at or _now()
    if record_history and current_equity > 0:
        record_live_account_equity_snapshot(account_id, **metrics, snapshot_at=stamp)
    history = list_live_account_equity_history(account_id)
    if not history and current_equity > 0:
        # Nothing stored yet: treat the current reading as the first point.
        history = [{**metrics, "snapshot_at": stamp}]
    # NOTE(review): history[0] is used as the baseline, which assumes the DB
    # helper returns points oldest-first; confirm.
    baseline = _safe_float(history[0].get("equity_usdt")) if history else current_equity
    peak = max([_safe_float(x.get("equity_usdt")) for x in history] + [current_equity, 0.0])
    # NOTE(review): if the balance read failed, current_equity is 0 here and the
    # figures below report a full loss / 100% drawdown against the baseline,
    # while _max_drawdown() skips non-positive points; confirm this is intended.
    total_pnl = current_equity - baseline if baseline > 0 else 0.0
    return_pct = total_pnl / baseline * 100.0 if baseline > 0 else 0.0
    current_dd_usdt = max(0.0, peak - current_equity)
    current_dd_pct = current_dd_usdt / peak * 100.0 if peak > 0 else 0.0
    max_dd_pct, max_dd_usdt = _max_drawdown(history + [{**metrics, "snapshot_at": stamp}])
    overview["performance"] = {
        **metrics,
        "baseline_equity_usdt": round(baseline, 8),
        "total_pnl_usdt": round(total_pnl, 8),
        "return_pct": round(return_pct, 6),
        "peak_equity_usdt": round(peak, 8),
        "current_drawdown_usdt": round(current_dd_usdt, 8),
        "current_drawdown_pct": round(current_dd_pct, 6),
        "max_drawdown_usdt": max_dd_usdt,
        "max_drawdown_pct": max_dd_pct,
        "history_points": len(history),
        "basis": "按首次同步净值计算,未单独扣除充值/提现影响",
    }
    overview["historical_positions"] = _historical_positions_from_orders(overview.get("order_history") or [])
    return overview
def _merge_snapshot(overview: dict, snapshot_row: dict) -> dict:
    """Overlay a persisted account snapshot onto ``overview`` and recompute performance.

    Args:
        overview: the base overview built by ``get_live_account_overview``;
            ``overview["account"]["id"]`` selects the equity history.
        snapshot_row: a stored snapshot row whose ``snapshot`` dict holds the
            payload produced by ``_exchange_snapshot_payload``, alongside
            ``status``, ``error_message`` and ``synced_at``.

    Returns:
        ``overview`` unchanged when the row has no usable payload; otherwise
        the merged overview with ``exchange_cache`` marked as loaded from the
        database.
    """
    payload = snapshot_row.get("snapshot") if isinstance(snapshot_row.get("snapshot"), dict) else {}
    if not payload:
        return overview
    for key in ("balance", "positions", "open_orders", "order_history", "errors", "performance", "historical_positions"):
        if key in payload:
            overview[key] = payload.get(key)
    synced_at = snapshot_row.get("synced_at") or ""
    status = snapshot_row.get("status") or ""
    error_message = snapshot_row.get("error_message") or ""
    cache_meta = payload.get("exchange_cache") or {}
    overview["exchange_cache"] = {
        **cache_meta,
        "cached": True,
        "loaded": bool(cache_meta.get("loaded", True)),
        "requires_refresh": False,
        "source": "database",
        "status": status,
        "synced_at": synced_at,
    }
    # Work on a copy so the snapshot row's own list is never mutated.
    errors = list(overview.get("errors") or [])
    # In this module the stored error_message is built from the stored errors
    # (the first one, or the first few joined), so appending it again would
    # show a combined duplicate. Use it only when no detailed errors exist.
    if status == "error" and error_message and not errors:
        errors.append(error_message)
    overview["errors"] = errors
    return _attach_performance(overview, int(overview.get("account", {}).get("id") or 0))
def _read_exchange_state(client, account: dict, overview: dict, history_limit: int) -> None:
    """Fill balance, positions, open orders and order history on ``overview`` from ``client``.

    Each read is independent. A failing read appends a message to
    ``overview["errors"]`` and leaves its section at the current value.
    """
    try:
        overview["balance"] = _compact_balance(client.fetch_balance())
    except Exception as exc:
        overview["errors"].append(f"余额读取失败:{exc}")
    try:
        # Drop flat entries: only positions with a non-zero size are kept.
        overview["positions"] = [
            item for item in (_compact_position(p, account) for p in client.fetch_positions(None))
            if abs(_safe_float(item.get("contracts"))) > 0
        ]
    except Exception as exc:
        overview["errors"].append(f"持仓读取失败:{exc}")
    try:
        overview["open_orders"] = [_compact_order(o) for o in client.fetch_open_orders(None)]
    except Exception as exc:
        overview["errors"].append(f"挂单读取失败:{exc}")
    # Must run after positions/open orders are filled in: they contribute symbols.
    symbols = _order_history_symbols(account, overview)
    order_history, order_history_errors = _fetch_order_history_by_symbol(client, symbols, history_limit)
    overview["order_history"] = order_history
    overview["errors"].extend(order_history_errors)


def get_live_account_overview(account_id: int, *, history_limit: int = 30, refresh: bool = False, client_factory=None) -> dict:
    """Build the account-centric overview shown by the live trading console.

    Without ``refresh`` the exchange is not contacted. The result comes from
    the stored snapshot if there is one, else from the in-process cache, else
    it is a placeholder flagged ``requires_refresh``.

    With ``refresh`` the exchange is queried, the result is persisted via
    ``upsert_live_account_snapshot`` (including when the connection fails),
    and the equity history is appended.

    Args:
        account_id: live account to describe.
        history_limit: row limit for intents, events and order history.
        refresh: query the exchange instead of reading stored state.
        client_factory: optional ``account -> client`` callable replacing
            ``build_binance_client`` (e.g. for tests).

    Returns:
        The overview dict (account, risk, balance, positions, orders,
        performance, exchange_cache metadata, errors, ...).

    Raises:
        LiveTradingConfigError: if the account does not exist.
    """
    account = get_live_account(account_id)
    if not account:
        raise LiveTradingConfigError("live account not found")
    overview = {
        "account": account,
        "risk": _account_risk_view(account),
        "balance": {"assets": [], "usdt": {"free": 0, "used": 0, "total": 0}},
        "positions": [],
        "open_orders": [],
        "order_history": [],
        "historical_positions": [],
        "intent_history": list_live_order_intents(limit=history_limit, account_id=account_id).get("items", []),
        # NOTE(review): unlike intents, events are not filtered by account_id,
        # so other accounts' events may appear here; confirm.
        "events": list_live_order_events(limit=history_limit).get("items", []),
        "performance": {},
        "exchange_cache": {"cached": False, "loaded": False, "requires_refresh": True},
        "errors": [],
    }
    if account.get("status") != "enabled":
        return overview
    if not refresh:
        snapshot = get_live_account_snapshot(account_id)
        if snapshot:
            return _merge_snapshot(overview, snapshot)
        cached = _cached_overview(account_id)
        if cached:
            # Exchange data comes from the cache; account/config data is fresh.
            cached["account"] = account
            cached["risk"] = overview["risk"]
            cached["intent_history"] = overview["intent_history"]
            cached["events"] = overview["events"]
            return cached
        overview["exchange_cache"]["reason"] = "等待后台实盘同步生成账户快照"
        return overview
    synced_at = _now()
    try:
        # require_testnet=True is always requested here, presumably as a safety
        # guard; confirm before pointing this at mainnet.
        client = client_factory(account) if client_factory else build_binance_client(account, require_testnet=True)
        client.load_markets()
    except Exception as exc:
        overview["errors"].append(f"账户连接失败:{exc}")
        overview["exchange_cache"] = {
            "cached": False,
            "loaded": False,
            "requires_refresh": True,
            "source": "exchange",
            "status": "error",
            "synced_at": synced_at,
        }
        overview = _attach_performance(overview, account_id, record_history=False, snapshot_at=synced_at)
        upsert_live_account_snapshot(
            account_id,
            _exchange_snapshot_payload(overview),
            status="error",
            error_message=overview["errors"][0],
            synced_at=synced_at,
        )
        return overview
    _read_exchange_state(client, account, overview, history_limit)
    status = "error" if overview["errors"] else "ok"
    overview["exchange_cache"] = {
        "cached": False,
        "loaded": True,
        "requires_refresh": False,
        "source": "exchange",
        "status": status,
        "synced_at": synced_at,
    }
    overview = _attach_performance(overview, account_id, record_history=True, snapshot_at=synced_at)
    upsert_live_account_snapshot(
        account_id,
        _exchange_snapshot_payload(overview),
        status=status,
        # Separate the messages; joining with "" ran them together unreadably.
        error_message="; ".join(overview["errors"][:3]),
        synced_at=synced_at,
    )
    return _cache_overview(account_id, overview)
def sync_live_account_snapshots(*, account_ids: list[int] | None = None, history_limit: int = 30, client_factory=None) -> dict:
    """Refresh and persist the exchange snapshot of enabled live accounts.

    Args:
        account_ids: restrict the sync to these account ids. ``None`` or an
            empty list syncs every enabled account. A non-empty list is
            honoured as given, so if none of its ids is positive nothing is
            synced.
        history_limit: forwarded to ``get_live_account_overview``.
        client_factory: forwarded to ``get_live_account_overview``.

    Returns:
        A summary with ``ok`` (no account failed), ``total``, ``ok_count``,
        ``error_count`` and one ``items`` entry per account. Every item has
        ``account_id``, ``account_code``, ``status``, ``synced_at`` and
        ``errors``.
    """
    accounts = list_enabled_live_accounts()
    if account_ids:
        # An explicit selection must not widen to "all accounts" just because
        # every id in it was invalid.
        selected = {int(x) for x in account_ids if int(x or 0) > 0}
        accounts = [account for account in accounts if int(account.get("id") or 0) in selected]
    items = []
    ok_count = 0
    error_count = 0
    for account in accounts:
        try:
            overview = get_live_account_overview(
                account["id"],
                history_limit=history_limit,
                refresh=True,
                client_factory=client_factory,
            )
            cache_meta = overview.get("exchange_cache") or {}
            status = cache_meta.get("status") or ("error" if overview.get("errors") else "ok")
            if status == "ok":
                ok_count += 1
            else:
                error_count += 1
            items.append({
                "account_id": account["id"],
                "account_code": account.get("account_code"),
                "status": status,
                "synced_at": cache_meta.get("synced_at"),
                "errors": overview.get("errors") or [],
            })
        except Exception as exc:  # one account's failure must not stop the batch
            error_count += 1
            items.append({
                "account_id": account.get("id"),
                "account_code": account.get("account_code"),
                "status": "error",
                "synced_at": None,
                "errors": [str(exc)],
            })
    return {
        "ok": error_count == 0,
        "total": len(accounts),
        "ok_count": ok_count,
        "error_count": error_count,
        "items": items,
    }