alphax/app/db/paper_trading.py
2026-05-18 00:58:19 +08:00

601 lines
21 KiB
Python

"""Paper trading ledger for separating signal quality from trade PnL."""
from __future__ import annotations
import json
import os
from datetime import datetime, timedelta
from app.config.system_config import paper_trading_config
from app.db.schema import get_conn
def _now() -> str:
    """Return the current local (naive) wall-clock time as an ISO-8601 string."""
    stamp = datetime.now()
    return stamp.isoformat()
def _safe_float(value, default: float = 0.0) -> float:
    """Coerce *value* to ``float``.

    ``None`` and the empty string count as missing and yield *default*; any
    value ``float()`` rejects also yields *default*.
    """
    try:
        is_blank = value is None or value == ""
        return default if is_blank else float(value)
    except Exception:
        return default
def _safe_int(value, default: int = 0) -> int:
    """Coerce *value* to ``int``, returning *default* for blank or bad input.

    Mirrors ``_safe_float``: ``None`` and ``""`` are treated as missing and
    yield *default*. (Previously ``int(value or 0)`` turned them into ``0``,
    silently ignoring the caller's default, e.g. ``_safe_int(None, 30) == 0``.)
    Floats are truncated by ``int()``; strings such as ``"3.5"`` that
    ``int()`` rejects, and non-finite floats, fall back to *default*.
    """
    try:
        if value is None or value == "":
            return default
        return int(value)
    except (TypeError, ValueError, OverflowError):
        return default
def paper_trading_enabled() -> bool:
    """Return whether the paper-trading ledger is on (config key ``enabled``, default True)."""
    flag = paper_trading_config().get("enabled", True)
    return bool(flag)
def default_account_equity_usdt() -> float:
    """Configured simulated account equity in USDT (default 20000, never below 1)."""
    configured = paper_trading_config().get("account_equity_usdt")
    equity = _safe_float(configured, 20000.0)
    return max(1.0, equity)
def default_leverage() -> float:
    """Configured trade leverage (default 5x, never below 1x)."""
    configured = paper_trading_config().get("trade_leverage")
    leverage = _safe_float(configured, 5.0)
    return max(1.0, leverage)
def default_notional_usdt() -> float:
    """Configured position notional per trade in USDT (default 5000, never below 1)."""
    configured = paper_trading_config().get("trade_notional_usdt")
    notional = _safe_float(configured, 5000.0)
    return max(1.0, notional)
def default_margin_usdt() -> float:
    """Margin required per trade: configured notional divided by configured leverage."""
    notional = default_notional_usdt()
    leverage = default_leverage()
    return round(notional / leverage, 8)
def default_fee_rate() -> float:
    """Configured fee rate as a fraction of notional (default 0.001, never negative)."""
    configured = paper_trading_config().get("fee_rate")
    rate = _safe_float(configured, 0.001)
    return max(0.0, rate)
def default_slippage_pct() -> float:
    """Configured slippage in percent, e.g. 0.05 means 0.05% (default 0.05, never negative)."""
    configured = paper_trading_config().get("slippage_pct")
    slippage = _safe_float(configured, 0.05)
    return max(0.0, slippage)
def _trailing_config() -> dict:
    """Read and sanitize the trailing-stop settings from the paper-trading config.

    Percent values are clamped (activation/lock >= 0, distance >= 0.1) and
    ``tiers`` is forced to a list; a non-list config value becomes ``[]``.
    """
    cfg = paper_trading_config()
    raw_tiers = cfg.get("trailing_tiers")
    enabled = bool(cfg.get("trailing_stop_enabled", True))
    activate = max(0.0, _safe_float(cfg.get("trailing_activate_pnl_pct"), 3.0))
    min_lock = max(0.0, _safe_float(cfg.get("trailing_min_lock_profit_pct"), 0.5))
    distance = max(0.1, _safe_float(cfg.get("trailing_distance_pct"), 1.5))
    return {
        "enabled": enabled,
        "activate_pnl_pct": activate,
        "min_lock_profit_pct": min_lock,
        "distance_pct": distance,
        "tiers": raw_tiers if isinstance(raw_tiers, list) else [],
    }
def _trailing_distance_pct(pnl_pct: float, cfg: dict) -> tuple[float, str]:
    """Choose the trailing distance (percent) for the current profit level.

    Dict tiers are scanned from the highest ``min_pnl_pct`` downwards. The
    first tier whose threshold *pnl_pct* reaches supplies the distance
    (floored at 0.1) and its label. If no tier matches, the base
    ``distance_pct`` is returned with an empty label.
    """
    base_distance = _safe_float(cfg.get("distance_pct"), 1.5)
    candidates = [t for t in (cfg.get("tiers") or []) if isinstance(t, dict)]
    candidates.sort(key=lambda t: _safe_float(t.get("min_pnl_pct")), reverse=True)
    for tier in candidates:
        if pnl_pct >= _safe_float(tier.get("min_pnl_pct")):
            tier_distance = max(0.1, _safe_float(tier.get("distance_pct"), base_distance))
            return tier_distance, str(tier.get("label") or "")
    return base_distance, ""
def _loads_json(value, fallback=None):
    """Best-effort decode of a JSON column value.

    - Non-blank strings are parsed with ``json.loads``; undecodable text
      yields the fallback.
    - Blank or whitespace-only strings yield the fallback. Previously a
      whitespace-only string was returned verbatim, leaking a ``str`` to
      callers that expect a decoded object.
    - Any other truthy value (e.g. an already-decoded dict) is returned
      unchanged; falsy values yield the fallback.

    *fallback* defaults to a fresh empty dict on every call.
    """
    default = fallback if fallback is not None else {}
    if isinstance(value, str):
        if not value.strip():
            return default
        try:
            return json.loads(value)
        except (ValueError, RecursionError):
            # JSONDecodeError is a ValueError; RecursionError covers pathologically nested input.
            return default
    try:
        return value if value else default
    except ValueError:
        # Objects with ambiguous truthiness are treated as missing, as before.
        return default
def _entry_plan(rec: dict) -> dict:
    """Return the recommendation's entry plan, always as a dict.

    Uses an already-decoded ``rec["entry_plan"]`` dict when present, otherwise
    decodes ``rec["entry_plan_json"]``. Anything that does not decode to a
    dict (a JSON list, scalar or string) yields ``{}``. Callers such as
    ``_open_trade`` use ``plan.get(...)`` and would otherwise raise
    ``AttributeError``.
    """
    plan = rec.get("entry_plan")
    if isinstance(plan, dict):
        return plan
    decoded = _loads_json(rec.get("entry_plan_json"), {})
    return decoded if isinstance(decoded, dict) else {}
def _open_price(current_price: float) -> float:
    """Simulated long entry fill: the market price raised by the slippage percentage."""
    slip_factor = 1 + default_slippage_pct() / 100
    return round(current_price * slip_factor, 12)
def _close_price(current_price: float) -> float:
    """Simulated long exit fill: the market price lowered by the slippage percentage."""
    slip_factor = 1 - default_slippage_pct() / 100
    return round(current_price * slip_factor, 12)
def _trade_pnl_pct(entry_price: float, current_price: float) -> float:
    """Percentage move of a long position from *entry_price* to *current_price*.

    Rounded to 4 decimals; returns 0.0 when either price is non-positive.
    """
    if entry_price <= 0 or current_price <= 0:
        return 0.0
    ratio = current_price / entry_price
    change_pct = (ratio - 1) * 100
    return round(change_pct, 4)
def _account_return_pct(pnl_usdt: float, account_equity: float | None = None) -> float:
    """Express *pnl_usdt* as a percentage of account equity (4 decimals).

    *account_equity* falls back to the configured equity and is floored at 1
    USDT so the division is always defined.
    """
    fallback_equity = default_account_equity_usdt()
    equity = max(1.0, _safe_float(account_equity, fallback_equity))
    pnl = _safe_float(pnl_usdt)
    return round(pnl / equity * 100, 4)
def _margin_roi_pct(pnl_usdt: float, margin_usdt: float) -> float:
    """Return on posted margin in percent (4 decimals); margin is floored at 1 USDT."""
    fallback_margin = default_margin_usdt()
    margin = max(1.0, _safe_float(margin_usdt, fallback_margin))
    pnl = _safe_float(pnl_usdt)
    return round(pnl / margin * 100, 4)
def _trade_margin(trade: dict) -> float:
    """Margin for *trade*: the stored ``margin_usdt`` if positive, else notional / leverage."""
    stored = _safe_float(trade.get("margin_usdt"))
    if stored > 0:
        return stored
    lev = max(1.0, _safe_float(trade.get("leverage"), default_leverage()))
    position_notional = _safe_float(trade.get("notional_usdt"))
    return round(position_notional / lev, 8)
def _decorate_trade(trade: dict) -> dict:
    """Return a copy of a paper_trades row enriched with derived display fields.

    Adds or normalizes notional, leverage and margin, plus
    ``unrealized_pnl_usdt`` (notional * pnl_pct, before fees). It also adds
    margin ROI and account return, computed from realized PnL for closed
    trades and from unrealized PnL otherwise. ``latest_price`` prefers a
    joined ``latest_market_price`` and falls back to the stored
    ``current_price``.
    NOTE(review): ``unrealized_pnl_usdt`` is also filled for closed trades
    (from their final pnl_pct).
    """
    decorated = dict(trade)
    notional = _safe_float(decorated.get("notional_usdt"), default_notional_usdt())
    leverage = max(1.0, _safe_float(decorated.get("leverage"), default_leverage()))
    margin = _trade_margin({"margin_usdt": decorated.get("margin_usdt"), "notional_usdt": notional, "leverage": leverage})
    unrealized = round(notional * _safe_float(decorated.get("pnl_pct")) / 100, 8)
    if decorated.get("status") == "closed":
        effective_pnl = _safe_float(decorated.get("realized_pnl_usdt"))
    else:
        effective_pnl = unrealized
    decorated.update(
        notional_usdt=notional,
        leverage=leverage,
        margin_usdt=margin,
        unrealized_pnl_usdt=unrealized,
        margin_roi_pct=_margin_roi_pct(effective_pnl, margin),
        account_return_pct=_account_return_pct(effective_pnl),
        account_equity_usdt=default_account_equity_usdt(),
    )
    market_price = _safe_float(decorated.get("latest_market_price"))
    decorated["latest_price"] = market_price if market_price > 0 else _safe_float(decorated.get("current_price"))
    decorated["latest_price_updated_at"] = (
        decorated.get("latest_market_price_updated_at") or decorated.get("updated_at") or ""
    )
    return decorated
def _record_event(conn, trade_id: int, rec_id: int, symbol: str, event_type: str, price: float, pnl_pct: float, message: str, detail=None, event_time: str = ""):
    """Append one row to ``paper_trade_events`` using *conn* (no commit here).

    *detail* is stored as JSON text (non-ASCII kept, unknown types stringified);
    an empty *event_time* is replaced with the current time.
    """
    stamped_at = event_time or _now()
    detail_json = json.dumps(detail or {}, ensure_ascii=False, default=str)
    row_values = (trade_id, rec_id, symbol, event_type, stamped_at, price, pnl_pct, message, detail_json)
    conn.execute(
        """
        INSERT INTO paper_trade_events (
            trade_id, recommendation_id, symbol, event_type, event_time,
            price, pnl_pct, message, detail_json
        ) VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s)
        """,
        row_values,
    )
def _open_trade(conn, rec: dict, current_price: float, event_time: str) -> dict:
    """Insert a new simulated long position for *rec* and log an ``open`` event.

    Sizing (notional, leverage, margin), fee rate and slippage come from the
    paper-trading config. Stop-loss / take-profit levels are taken from the
    recommendation row first, then from its entry plan.

    Args:
        conn: open DB connection; the caller commits or rolls back.
        rec: recommendation row (dict-like).
        current_price: latest market price; the fill is this price plus slippage.
        event_time: ISO timestamp for the trade/event; empty means "now".

    Returns:
        ``{"opened": False, "reason": "already_exists"}`` when the INSERT
        returns no row (ON CONFLICT DO NOTHING on ``recommendation_id``).
        Otherwise a dict with the new trade id, entry price, qty and sizing.
    """
    rec_id = _safe_int(rec.get("id"))
    symbol = str(rec.get("symbol") or "").strip().upper()
    plan = _entry_plan(rec)
    # Long entry fills above the market by the configured slippage.
    entry_price = _open_price(current_price)
    notional = default_notional_usdt()
    leverage = default_leverage()
    margin = default_margin_usdt()
    qty = round(notional / entry_price, 12) if entry_price > 0 else 0
    # Explicit recommendation fields win; otherwise use the entry plan (including its alternate key names).
    stop_loss = _safe_float(rec.get("stop_loss") or plan.get("stop_loss"))
    tp1 = _safe_float(rec.get("tp1") or plan.get("tp1") or plan.get("take_profit_1"))
    tp2 = _safe_float(rec.get("tp2") or plan.get("tp2") or plan.get("take_profit_2"))
    # Opening fee is charged on the full notional, not on the margin.
    fee = round(notional * default_fee_rate(), 8)
    now = event_time or _now()
    # The positional parameter tuple below must stay in the exact column order of the
    # INSERT; side/status/pnl_pct are literals in the SQL ('long', 'open', 0).
    row = conn.execute(
        """
        INSERT INTO paper_trades (
            recommendation_id, symbol, side, status, opened_at,
            entry_price, qty, notional_usdt, margin_usdt, leverage, stop_loss, tp1, tp2,
            max_price, min_price, current_price, pnl_pct, fee_usdt,
            source_status, source_action, strategy_version, created_at, updated_at
        ) VALUES (%s,%s,'long','open',%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,0,%s,%s,%s,%s,%s,%s)
        ON CONFLICT(recommendation_id) DO NOTHING
        RETURNING id
        """,
        (
            rec_id,
            symbol,
            now,  # opened_at
            entry_price,
            qty,
            notional,
            margin,
            leverage,
            stop_loss,
            tp1,
            tp2,
            entry_price,  # max_price starts at the fill
            entry_price,  # min_price starts at the fill
            entry_price,  # current_price starts at the fill
            fee,
            rec.get("execution_status") or "",  # source_status
            rec.get("action_status") or "",  # source_action
            rec.get("strategy_version") or "",
            now,  # created_at
            now,  # updated_at
        ),
    ).fetchone()
    # No returned row means a trade for this recommendation already exists.
    if not row:
        return {"opened": False, "reason": "already_exists"}
    trade_id = row["id"]
    # Message text: "Paper trade opened: for strategy return validation only, not a real fill".
    _record_event(
        conn,
        trade_id,
        rec_id,
        symbol,
        "open",
        entry_price,
        0.0,
        "模拟交易开仓:仅用于策略收益验证,不代表真实成交",
        {
            "notional_usdt": notional,
            "margin_usdt": margin,
            "leverage": leverage,
            "qty": qty,
            "fee_usdt": fee,
            "slippage_pct": default_slippage_pct(),
            "source_status": rec.get("execution_status") or "",
            "source_action": rec.get("action_status") or "",
        },
        now,
    )
    return {
        "opened": True,
        "trade_id": trade_id,
        "entry_price": entry_price,
        "qty": qty,
        "notional_usdt": notional,
        "margin_usdt": margin,
        "leverage": leverage,
    }
def _close_trade(conn, trade: dict, current_price: float, reason: str, event_time: str) -> dict:
    """Close an open paper trade and log a ``close`` event (no commit here).

    The exit fills at *current_price* minus slippage. ``pnl_pct`` /
    ``realized_pnl_pct`` are the gross price move. ``realized_pnl_usdt`` is
    net of the stored opening fee plus a closing fee charged on the notional.
    NOTE(review): the result reports ``closed: True`` even if the guarded
    UPDATE matched no row (trade no longer open).
    """
    fill_price = _close_price(current_price)
    move_pct = _trade_pnl_pct(_safe_float(trade.get("entry_price")), fill_price)
    position_notional = _safe_float(trade.get("notional_usdt"))
    opening_fee = _safe_float(trade.get("fee_usdt"))
    closing_fee = round(position_notional * default_fee_rate(), 8)
    fees_paid = round(opening_fee + closing_fee, 8)
    net_pnl = round(position_notional * move_pct / 100 - fees_paid, 8)
    closed_at = event_time or _now()
    update_params = (
        closed_at,   # closed_at
        fill_price,  # exit_price
        fill_price,  # current_price
        move_pct,    # pnl_pct
        move_pct,    # realized_pnl_pct
        net_pnl,     # realized_pnl_usdt
        fees_paid,   # fee_usdt (open + close)
        reason,      # exit_reason
        closed_at,   # updated_at
        trade["id"],
    )
    conn.execute(
        """
        UPDATE paper_trades
        SET status='closed',
            closed_at=%s,
            exit_price=%s,
            current_price=%s,
            pnl_pct=%s,
            realized_pnl_pct=%s,
            realized_pnl_usdt=%s,
            fee_usdt=%s,
            exit_reason=%s,
            updated_at=%s
        WHERE id=%s AND status='open'
        """,
        update_params,
    )
    _record_event(
        conn,
        trade["id"],
        trade["recommendation_id"],
        trade["symbol"],
        "close",
        fill_price,
        move_pct,
        f"模拟交易平仓:{reason}",
        {"realized_pnl_usdt": net_pnl, "fee_usdt": fees_paid},
        closed_at,
    )
    return {
        "closed": True,
        "trade_id": trade["id"],
        "exit_reason": reason,
        "pnl_pct": move_pct,
        "pnl_usdt": net_pnl,
    }
def _update_trailing_stop(conn, trade: dict, current_price: float, pnl_pct: float, event_time: str) -> tuple[float, dict]:
    """Ratchet the trade's trailing stop upwards once profit reaches the activation level.

    The proposed stop is the highest of the existing stop, a profit-lock floor
    (entry * (1 + min_lock_profit_pct%)) and current_price minus the tiered
    distance. It never moves down. When the stop is newly set or raised, a
    ``trailing_activate`` / ``trailing_move`` event is recorded.

    Returns:
        (stop_price, info). ``stop_price`` is the unchanged stored stop when
        nothing moved. ``info`` always has ``activated``/``moved`` and carries
        details when the stop changed.
    """
    cfg = _trailing_config()
    previous = _safe_float(trade.get("trailing_stop"))
    if not cfg.get("enabled") or pnl_pct < _safe_float(cfg.get("activate_pnl_pct")):
        return previous, {"activated": False, "moved": False}
    entry_price = _safe_float(trade.get("entry_price"))
    if entry_price <= 0 or current_price <= 0:
        return previous, {"activated": False, "moved": False}
    distance_pct, tier_label = _trailing_distance_pct(pnl_pct, cfg)
    floor_price = entry_price * (1 + _safe_float(cfg.get("min_lock_profit_pct")) / 100)
    trail_candidate = current_price * (1 - distance_pct / 100)
    proposed = round(max(previous, floor_price, trail_candidate), 12)
    activated = previous <= 0 and proposed > 0
    # The tiny epsilon ignores float noise so an unchanged stop is not re-logged.
    moved = previous > 0 and proposed > previous + 1e-12
    if not (activated or moved):
        return previous, {"activated": False, "moved": False}
    if activated:
        event_type, action_text = "trailing_activate", "激活"
    else:
        event_type, action_text = "trailing_move", "上移"
    _record_event(
        conn,
        trade["id"],
        trade["recommendation_id"],
        trade["symbol"],
        event_type,
        proposed,
        pnl_pct,
        f"模拟交易移动止盈{action_text}:保护价 {proposed:.8g}",
        {
            "current_price": current_price,
            "previous_trailing_stop": previous,
            "trailing_stop": proposed,
            "activate_pnl_pct": cfg.get("activate_pnl_pct"),
            "distance_pct": distance_pct,
            "tier_label": tier_label,
            "min_lock_profit_pct": cfg.get("min_lock_profit_pct"),
        },
        event_time,
    )
    details = {
        "activated": activated,
        "moved": moved,
        "trailing_stop": proposed,
        "previous_trailing_stop": previous,
        "distance_pct": distance_pct,
        "tier_label": tier_label,
    }
    return proposed, details
def _update_open_trade(conn, trade: dict, current_price: float, event_time: str) -> dict:
    """Mark an open paper trade to *current_price*, closing it if an exit triggers.

    Exit checks run in priority order: stop_loss, then the stored trailing_stop,
    then tp2, then tp1. Any hit closes the whole position via ``_close_trade``.
    Otherwise the trailing stop may be ratcheted and the row's price, extremes,
    trailing stop and pnl_pct are updated. No commit is done here.

    Two fixes over the original:
    - The tick's max/min excursion is persisted before closing.
      ``_close_trade`` only writes exit fields, so the price that triggered
      the exit never reached max_price/min_price.
    - The timestamp is resolved once, so the trailing event, updated_at and
      the close all share it.
    """
    now = event_time or _now()
    entry_price = _safe_float(trade.get("entry_price"))
    old_max = _safe_float(trade.get("max_price")) or entry_price
    old_min = _safe_float(trade.get("min_price")) or entry_price
    new_max = max(old_max, current_price)
    new_min = min(old_min, current_price)
    pnl_pct = _trade_pnl_pct(entry_price, current_price)
    stop_loss = _safe_float(trade.get("stop_loss"))
    trailing_stop = _safe_float(trade.get("trailing_stop"))
    tp2 = _safe_float(trade.get("tp2"))
    tp1 = _safe_float(trade.get("tp1"))
    reason = ""
    if stop_loss > 0 and current_price <= stop_loss:
        reason = "stop_loss"
    elif trailing_stop > 0 and current_price <= trailing_stop:
        reason = "trailing_stop"
    elif tp2 > 0 and current_price >= tp2:
        reason = "tp2"
    elif tp1 > 0 and current_price >= tp1:
        reason = "tp1"
    if reason:
        # Record the excursion including the triggering tick before closing.
        conn.execute(
            """
            UPDATE paper_trades
            SET max_price=%s,
                min_price=%s
            WHERE id=%s AND status='open'
            """,
            (new_max, new_min, trade["id"]),
        )
        return _close_trade(conn, trade, current_price, reason, now)
    trailing_stop, trailing_result = _update_trailing_stop(conn, trade, current_price, pnl_pct, now)
    conn.execute(
        """
        UPDATE paper_trades
        SET current_price=%s,
            max_price=%s,
            min_price=%s,
            trailing_stop=%s,
            pnl_pct=%s,
            updated_at=%s
        WHERE id=%s AND status='open'
        """,
        (current_price, new_max, new_min, trailing_stop, pnl_pct, now, trade["id"]),
    )
    return {"updated": True, "trade_id": trade["id"], "pnl_pct": pnl_pct, **trailing_result}
def sync_recommendation(rec: dict, current_price: float, event_time: str = "") -> dict:
    """Open/update paper trade for one recommendation.

    This is intentionally independent from recommendation PnL fields. A
    recommendation can be a signal; only this ledger represents simulated
    execution.

    Flow: an existing open trade is marked to market (and possibly closed).
    An existing closed trade is skipped. With no trade, a new one is opened
    only for a "buy now" recommendation (``execution_status == "buy_now"`` or
    ``action_status == "可即刻买入"``, i.e. "buy immediately").

    The connection is closed exactly once, in ``finally``; the original also
    closed it explicitly in two branches. A failed ``rollback()`` no longer
    replaces the exception that caused it.
    """
    if not paper_trading_enabled():
        return {"enabled": False, "skipped": True, "reason": "disabled"}
    rec_id = _safe_int(rec.get("id"))
    symbol = str(rec.get("symbol") or "").strip().upper()
    current_price = _safe_float(current_price)
    if rec_id <= 0 or not symbol or current_price <= 0:
        return {"enabled": True, "skipped": True, "reason": "invalid_input"}
    execution_status = str(rec.get("execution_status") or "").strip()
    action_status = str(rec.get("action_status") or "").strip()
    event_time = event_time or _now()
    conn = get_conn()
    try:
        row = conn.execute("SELECT * FROM paper_trades WHERE recommendation_id=%s", (rec_id,)).fetchone()
        if row:
            trade = dict(row)
            if trade.get("status") != "open":
                return {"skipped": True, "reason": "already_closed", "trade_id": trade.get("id")}
            result = _update_open_trade(conn, trade, current_price, event_time)
        elif execution_status != "buy_now" and action_status != "可即刻买入":
            return {"skipped": True, "reason": "not_buy_now"}
        else:
            result = _open_trade(conn, rec, current_price, event_time)
        conn.commit()
        return result
    except Exception:
        try:
            conn.rollback()
        except Exception:
            # Best effort: a broken connection must not mask the original error re-raised below.
            pass
        raise
    finally:
        try:
            conn.close()
        except Exception:
            pass
def get_paper_trading_summary(days: int = 30) -> dict:
    """Aggregate paper-trading performance over the last *days* (clamped to 1..365).

    The window selects trades opened within the last *days*. Trades that are
    still open are always included, whatever their age. Previously an open
    position opened before the window vanished from open_count,
    allocated_margin_usdt and unrealized PnL, which overstated
    available_equity_usdt.

    Win/loss is judged on ``realized_pnl_pct`` (gross price move), while the
    USDT totals are net of fees.
    """
    days = max(1, min(_safe_int(days, 30), 365))
    cutoff = (datetime.now() - timedelta(days=days)).isoformat()
    conn = get_conn()
    try:
        rows = conn.execute(
            """
            SELECT * FROM paper_trades
            WHERE opened_at >= %s OR status = 'open'
            ORDER BY opened_at DESC, id DESC
            """,
            (cutoff,),
        ).fetchall()
    finally:
        conn.close()
    items = [_decorate_trade(dict(r)) for r in rows]
    open_items = [x for x in items if x.get("status") == "open"]
    closed_items = [x for x in items if x.get("status") == "closed"]
    wins = [x for x in closed_items if _safe_float(x.get("realized_pnl_pct")) > 0]
    losses = [x for x in closed_items if _safe_float(x.get("realized_pnl_pct")) <= 0]
    total_realized = round(sum(_safe_float(x.get("realized_pnl_usdt")) for x in closed_items), 4)
    avg_realized_pct = (
        round(sum(_safe_float(x.get("realized_pnl_pct")) for x in closed_items) / len(closed_items), 4)
        if closed_items
        else 0
    )
    open_unrealized = round(sum(_safe_float(x.get("unrealized_pnl_usdt")) for x in open_items), 4)
    total_pnl = round(total_realized + open_unrealized, 4)
    allocated_margin = round(sum(_safe_float(x.get("margin_usdt")) for x in open_items), 4)
    open_position_value = round(sum(_safe_float(x.get("notional_usdt")) for x in open_items), 4)
    initial_equity = default_account_equity_usdt()
    current_balance = round(initial_equity + total_pnl, 4)
    # Gross exposure relative to the marked-to-market balance.
    cumulative_leverage = round(open_position_value / current_balance, 4) if current_balance > 0 else 0
    return {
        "days": days,
        "total": len(items),
        "open_count": len(open_items),
        "closed_count": len(closed_items),
        "win_count": len(wins),
        "loss_count": len(losses),
        "win_rate": round(len(wins) / len(closed_items) * 100, 2) if closed_items else 0,
        "realized_pnl_usdt": total_realized,
        "avg_realized_pnl_pct": avg_realized_pct,
        "open_unrealized_pnl_usdt": open_unrealized,
        "total_pnl_usdt": total_pnl,
        "initial_equity_usdt": initial_equity,
        "account_equity_usdt": initial_equity,
        "current_balance_usdt": current_balance,
        "account_realized_return_pct": _account_return_pct(total_realized),
        "account_unrealized_return_pct": _account_return_pct(open_unrealized),
        "account_total_return_pct": _account_return_pct(total_pnl),
        "allocated_margin_usdt": allocated_margin,
        "open_position_value_usdt": open_position_value,
        "cumulative_leverage": cumulative_leverage,
        "available_equity_usdt": round(current_balance - allocated_margin, 4),
        "margin_usdt": default_margin_usdt(),
        "leverage": default_leverage(),
        "notional_usdt": default_notional_usdt(),
        "fee_rate": default_fee_rate(),
        "slippage_pct": default_slippage_pct(),
    }
def list_paper_trades(limit: int = 50, offset: int = 0, status: str = "") -> dict:
    """Page through paper trades, newest first, optionally filtered by status.

    ``limit`` is clamped to 1..200 and ``offset`` to >= 0. The status filter
    applies only for exactly ``"open"`` or ``"closed"``. Rows are LEFT JOINed
    to ``latest_price_cache`` by symbol so each decorated item can expose the
    latest market price.
    """
    limit = max(1, min(_safe_int(limit, 50), 200))
    offset = max(0, _safe_int(offset, 0))
    wanted = str(status or "").strip()
    if wanted in {"open", "closed"}:
        where, params = "WHERE status=%s", [wanted]
    else:
        where, params = "", []
    conn = get_conn()
    try:
        total = conn.execute(f"SELECT COUNT(*) FROM paper_trades {where}", tuple(params)).fetchone()[0]
        page_params = (*params, limit, offset)
        rows = conn.execute(
            f"""
            SELECT pt.*, lpc.price AS latest_market_price, lpc.updated_at AS latest_market_price_updated_at
            FROM paper_trades pt
            LEFT JOIN latest_price_cache lpc ON lpc.symbol = pt.symbol
            {where}
            ORDER BY pt.opened_at DESC, pt.id DESC
            LIMIT %s OFFSET %s
            """,
            page_params,
        ).fetchall()
    finally:
        conn.close()
    total_count = int(total or 0)
    return {
        "items": [_decorate_trade(dict(r)) for r in rows],
        "total": total_count,
        "limit": limit,
        "offset": offset,
        "has_more": offset + len(rows) < total_count,
    }
def list_paper_trade_events(limit: int = 80, offset: int = 0, symbol: str = "", event_type: str = "") -> dict:
    """Page through paper-trade events, newest first, with optional filters.

    ``symbol`` is upper-cased and ``event_type`` stripped; empty values apply
    no filter. Each event is LEFT JOINed to its trade for context columns,
    and its ``detail_json`` text is decoded into a ``detail`` field.
    """
    limit = max(1, min(_safe_int(limit, 80), 200))
    offset = max(0, _safe_int(offset, 0))
    symbol = str(symbol or "").strip().upper()
    event_type = str(event_type or "").strip()
    candidate_filters = [("e.symbol=%s", symbol), ("e.event_type=%s", event_type)]
    active = [(clause, value) for clause, value in candidate_filters if value]
    params = [value for _, value in active]
    where_sql = ("WHERE " + " AND ".join(clause for clause, _ in active)) if active else ""
    conn = get_conn()
    try:
        total = conn.execute(
            f"SELECT COUNT(*) FROM paper_trade_events e {where_sql}",
            tuple(params),
        ).fetchone()[0]
        rows = conn.execute(
            f"""
            SELECT
                e.*,
                t.status AS trade_status,
                t.entry_price,
                t.exit_price,
                t.notional_usdt,
                t.margin_usdt,
                t.leverage,
                t.exit_reason,
                t.opened_at,
                t.closed_at
            FROM paper_trade_events e
            LEFT JOIN paper_trades t ON t.id = e.trade_id
            {where_sql}
            ORDER BY e.event_time DESC, e.id DESC
            LIMIT %s OFFSET %s
            """,
            (*params, limit, offset),
        ).fetchall()
    finally:
        conn.close()
    events = []
    for raw in rows:
        record = dict(raw)
        raw_detail = record.pop("detail_json", "{}")
        record["detail"] = _loads_json(raw_detail, {})
        events.append(record)
    total_count = int(total or 0)
    return {
        "items": events,
        "total": total_count,
        "limit": limit,
        "offset": offset,
        "has_more": offset + len(events) < total_count,
    }