tradusai/trading/paper_trading.py
2025-12-11 22:51:51 +08:00

1639 lines
62 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Paper Trading Module - 多币种多周期独立仓位管理
支持多币种 (BTC/USDT, ETH/USDT 等) 和三个独立周期的模拟交易:
- 短周期 (5m/15m/1h): short_term_5m_15m_1h / intraday
- 中周期 (4h/1d): medium_term_4h_1d / swing
- 长周期 (1d/1w): long_term_1d_1w
每个币种的每个周期独立管理:
- 独立仓位
- 独立止盈止损
- 独立统计数据
- 独立权益曲线
"""
import json
import logging
from datetime import datetime, timedelta
from typing import Dict, Any, Optional, List
from pathlib import Path
from dataclasses import dataclass, asdict, field
from enum import Enum
from config.settings import settings
# Module-level logger used by every component in this file.
logger = logging.getLogger(__name__)
class TimeFrame(Enum):
    """Trading timeframe bucket; each symbol keeps one account per member."""

    # Short-term bucket: 5m / 15m / 1h
    SHORT = "short"
    # Medium-term bucket: 4h / 1d
    MEDIUM = "medium"
    # Long-term bucket: 1d / 1w
    LONG = "long"
# Per-timeframe parameters. All three timeframes share leverage and starting
# capital; they differ in the signal keys they read, how long a signal stays
# actionable, the minimum reward/risk they accept, and how far the live price
# may drift from the signal's entry price (base + ATR-scaled component).
TIMEFRAME_CONFIG = {
    TimeFrame.SHORT: dict(
        name='短周期',
        name_en='Short-term',
        signal_keys=['short_term_5m_15m_1h', 'intraday'],
        leverage=10,
        initial_balance=10000.0,
        signal_expiry_minutes=5,        # signal stays valid for 5 minutes
        min_risk_reward_ratio=1.5,      # minimum reward/risk ratio
        base_price_deviation=0.003,     # base allowed price deviation: 0.3%
        atr_deviation_multiplier=0.5,   # weight of ATR% added on top of the base
    ),
    TimeFrame.MEDIUM: dict(
        name='中周期',
        name_en='Medium-term',
        signal_keys=['medium_term_4h_1d', 'swing'],
        leverage=10,
        initial_balance=10000.0,
        signal_expiry_minutes=30,       # signal stays valid for 30 minutes
        min_risk_reward_ratio=1.5,
        base_price_deviation=0.005,     # base allowed price deviation: 0.5%
        atr_deviation_multiplier=0.8,
    ),
    TimeFrame.LONG: dict(
        name='长周期',
        name_en='Long-term',
        signal_keys=['long_term_1d_1w'],
        leverage=10,
        initial_balance=10000.0,
        signal_expiry_minutes=120,      # signal stays valid for 2 hours
        min_risk_reward_ratio=2.0,      # long-term demands a higher reward/risk
        base_price_deviation=0.01,      # base allowed price deviation: 1%
        atr_deviation_multiplier=1.0,
    ),
}
# Pyramid sizing: fraction of the maximum position opened at each level.
# Initial entry 40%, then adds of 30%, 20% and 10% (100% in total).
PYRAMID_LEVELS = [0.4, 0.3, 0.2, 0.1]

# Fallback add-on rule (used when the signal provides no usable entry level):
# an add must beat the average entry price by this fraction, i.e. 0.5% lower
# for longs and 0.5% higher for shorts.
PYRAMID_PRICE_IMPROVEMENT = 0.005

# Multi-timeframe coordination: a new entry on a timeframe must not conflict
# with the direction of the timeframes listed for it.
TIMEFRAME_HIERARCHY = {
    TimeFrame.SHORT: [TimeFrame.MEDIUM, TimeFrame.LONG],
    TimeFrame.MEDIUM: [TimeFrame.LONG],
    TimeFrame.LONG: [],  # top of the hierarchy: never constrained
}

# Number of consecutive same-direction signals required before opening.
SIGNAL_CONFIRMATION_COUNT = 2
@dataclass
class PositionEntry:
    """One fill of a (possibly pyramided) position."""

    price: float      # fill price
    size: float       # quantity in the base asset (e.g. BTC for BTCUSDT)
    margin: float     # margin locked by this fill
    timestamp: str    # fill time as an ISO-format string
    level: int        # pyramid level: 0 = initial entry, 1 = first add, ...

    def to_dict(self) -> dict:
        """Serialize the fill to a plain ``field -> value`` dict."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict) -> 'PositionEntry':
        """Rebuild a fill from the output of :meth:`to_dict`."""
        return cls(**data)
@dataclass
class Position:
    """An open position, built from one or more pyramid fills."""

    side: str  # 'LONG', 'SHORT' or 'FLAT'
    entries: List['PositionEntry'] = field(default_factory=list)  # fills, oldest first
    stop_loss: float = 0.0
    take_profit: float = 0.0
    created_at: str = ""
    signal_reasoning: str = ""

    @property
    def entry_price(self) -> float:
        """Size-weighted average fill price; 0.0 if there is nothing to average."""
        if not self.entries:
            return 0.0
        notional = 0
        quantity = 0
        for fill in self.entries:
            notional += fill.price * fill.size
            quantity += fill.size
        if quantity > 0:
            return notional / quantity
        return 0.0

    @property
    def size(self) -> float:
        """Total quantity across all fills."""
        return sum(fill.size for fill in self.entries)

    @property
    def margin(self) -> float:
        """Total margin locked across all fills."""
        return sum(fill.margin for fill in self.entries)

    @property
    def pyramid_level(self) -> int:
        """Number of fills so far (1 after the initial entry)."""
        return len(self.entries)

    def to_dict(self) -> dict:
        """Serialize, including the derived aggregate fields for readers."""
        return dict(
            side=self.side,
            entry_price=self.entry_price,
            size=self.size,
            margin=self.margin,
            pyramid_level=self.pyramid_level,
            entries=[fill.to_dict() for fill in self.entries],
            stop_loss=self.stop_loss,
            take_profit=self.take_profit,
            created_at=self.created_at,
            signal_reasoning=self.signal_reasoning,
        )

    @classmethod
    def from_dict(cls, data: dict) -> 'Position':
        """Rebuild a position from :meth:`to_dict` output.

        Older records have no ``entries`` list, only flat ``entry_price`` /
        ``size`` / ``margin`` fields; those become a single level-0 fill.
        """
        fills = [PositionEntry.from_dict(raw) for raw in data.get('entries', [])]
        is_legacy = not fills and data.get('entry_price') and data.get('size')
        if is_legacy:
            fills = [
                PositionEntry(
                    price=data['entry_price'],
                    size=data['size'],
                    margin=data.get('margin', 0),
                    timestamp=data.get('created_at', ''),
                    level=0,
                )
            ]
        return cls(
            side=data['side'],
            entries=fills,
            stop_loss=data.get('stop_loss', 0),
            take_profit=data.get('take_profit', 0),
            created_at=data.get('created_at', ''),
            signal_reasoning=data.get('signal_reasoning', ''),
        )
@dataclass
class Trade:
    """A closed trade (entry plus exit) of one timeframe account."""

    id: str
    timeframe: str
    side: str
    entry_price: float
    entry_time: str
    exit_price: float
    exit_time: str
    size: float
    pnl: float
    pnl_pct: float
    exit_reason: str
    symbol: str = "BTCUSDT"  # trading pair

    def to_dict(self) -> dict:
        """Serialize to a plain ``field -> value`` dict."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict) -> 'Trade':
        """Rebuild a trade from :meth:`to_dict` output without modifying *data*.

        Records written before multi-symbol support have no ``symbol`` and
        are attributed to BTCUSDT.
        """
        # Merge into a new dict instead of writing the default into the
        # caller's mapping (the old code mutated its argument).
        return cls(**{'symbol': 'BTCUSDT', **data})
@dataclass
class SignalHistory:
    """One observed signal, kept so consecutive signals can be compared."""

    direction: str  # 'LONG', 'SHORT' or 'NONE'
    timestamp: str  # time the signal was recorded
    entry_price: float = 0.0
    stop_loss: float = 0.0
    take_profit: float = 0.0

    def to_dict(self) -> dict:
        """Serialize the record to a plain dict."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict) -> 'SignalHistory':
        """Rebuild a record from :meth:`to_dict` output."""
        return cls(**data)
@dataclass
class TimeFrameAccount:
    """Account for one timeframe of one symbol.

    Funds:
    - initial_balance: starting capital
    - realized_pnl: cumulative P&L of closed trades
    - position.margin: margin locked by the open position
    - unrealized_pnl: floating P&L, computed on demand (not stored here)

    equity            = initial_balance + realized_pnl + unrealized_pnl
    available balance = initial_balance + realized_pnl - position.margin

    ``signal_history`` and ``last_atr`` are persisted as well, so signal
    confirmation and ATR-based thresholds survive a restart.
    """

    timeframe: str
    initial_balance: float
    leverage: int
    symbol: str = "BTCUSDT"  # trading pair
    realized_pnl: float = 0.0  # cumulative realized P&L
    # Quoted forward references keep this class importable on its own.
    position: Optional['Position'] = None
    trades: List['Trade'] = field(default_factory=list)
    stats: Dict = field(default_factory=dict)
    equity_curve: List[Dict] = field(default_factory=list)
    signal_history: List['SignalHistory'] = field(default_factory=list)  # recent signals, for confirmation
    last_atr: float = 0.0  # most recent ATR, used for dynamic thresholds

    def __post_init__(self):
        # New accounts (or ones restored without stats) start with zeroed stats.
        if not self.stats:
            self.stats = self._init_stats()

    def _init_stats(self) -> dict:
        """Return a fresh statistics dict seeded with the starting balance."""
        return {
            'total_trades': 0,
            'winning_trades': 0,
            'losing_trades': 0,
            'total_pnl': 0.0,
            'max_drawdown': 0.0,
            'peak_balance': self.initial_balance,
            'win_rate': 0.0,
            'avg_win': 0.0,
            'avg_loss': 0.0,
            'profit_factor': 0.0,
        }

    def get_used_margin(self) -> float:
        """Margin locked by the open position (0.0 when flat)."""
        if self.position and self.position.side != 'FLAT':
            return self.position.margin
        return 0.0

    def get_available_balance(self) -> float:
        """Balance available for opening new positions."""
        return self.initial_balance + self.realized_pnl - self.get_used_margin()

    def get_equity(self, unrealized_pnl: float = 0.0) -> float:
        """Account equity including the supplied unrealized P&L."""
        return self.initial_balance + self.realized_pnl + unrealized_pnl

    def to_dict(self) -> dict:
        """Serialize the account (last 100 trades, last 500 equity points)."""
        return {
            'timeframe': self.timeframe,
            'symbol': self.symbol,
            'initial_balance': self.initial_balance,
            'realized_pnl': self.realized_pnl,
            'leverage': self.leverage,
            'position': self.position.to_dict() if self.position else None,
            'trades': [t.to_dict() for t in self.trades[-100:]],
            'stats': self.stats,
            'equity_curve': self.equity_curve[-500:],
            'signal_history': [h.to_dict() for h in self.signal_history],
            'last_atr': self.last_atr,
        }

    @classmethod
    def from_dict(cls, data: dict) -> 'TimeFrameAccount':
        """Rebuild an account from :meth:`to_dict` output.

        Older files are accepted: a missing ``symbol`` means BTCUSDT; a legacy
        ``balance`` without ``realized_pnl`` becomes
        ``realized_pnl = balance - initial_balance``; missing
        ``signal_history`` / ``last_atr`` start empty / at 0.
        """
        realized_pnl = data.get('realized_pnl', 0.0)
        if 'realized_pnl' not in data and 'balance' in data:
            realized_pnl = data['balance'] - data['initial_balance']
        account = cls(
            timeframe=data['timeframe'],
            initial_balance=data['initial_balance'],
            leverage=data['leverage'],
            symbol=data.get('symbol', 'BTCUSDT'),
            realized_pnl=realized_pnl,
            stats=data.get('stats', {}),
            equity_curve=data.get('equity_curve', []),
        )
        if data.get('position'):
            account.position = Position.from_dict(data['position'])
        account.trades = [Trade.from_dict(t) for t in data.get('trades', [])]
        account.signal_history = [
            SignalHistory.from_dict(h) for h in (data.get('signal_history') or [])
        ]
        account.last_atr = float(data.get('last_atr') or 0.0)
        return account
class MultiTimeframePaperTrader:
"""多币种多周期模拟盘交易器"""
def __init__(
    self,
    initial_balance: float = 10000.0,
    state_file: str = None,
    symbols: List[str] = None
):
    """Set up the trader and restore every account from disk (or create them).

    Args:
        initial_balance: stored as ``self.initial_balance``. Accounts created
            by ``_init_account`` take their starting capital from
            TIMEFRAME_CONFIG instead.
        state_file: JSON state path. Defaults to
            ``output/paper_trading_state.json`` under the parent of this
            module's directory.
        symbols: trading pairs to manage; falls back to
            ``settings.symbols_list`` when empty or None.
    """
    self.initial_balance = initial_balance
    self.symbols = symbols or settings.symbols_list
    logger.info(f"支持的交易对: {', '.join(self.symbols)}")
    self.state_file = (
        Path(state_file)
        if state_file
        else Path(__file__).parent.parent / 'output' / 'paper_trading_state.json'
    )
    # {symbol: {TimeFrame: TimeFrameAccount}}, populated by _load_state()
    self.accounts: Dict[str, Dict[TimeFrame, TimeFrameAccount]] = {}
    self._load_state()
    logger.info(f"Multi-symbol Multi-timeframe Paper Trader initialized: {len(self.symbols)} symbols")
def _load_state(self):
    """Restore accounts from ``self.state_file``, or start fresh.

    Two on-disk layouts are understood:
    - current: ``{"symbols": {SYMBOL: {"short": {...}, "medium": {...}, "long": {...}}}}``
    - legacy:  ``{"accounts": {"short": {...}, ...}}`` (single symbol); its
      accounts are migrated to the first configured symbol.
    Configured symbol/timeframe pairs missing from the file get fresh accounts.

    If the file cannot be read or parsed, it is first moved aside to
    ``<name>.corrupt-<timestamp>`` and then all accounts are re-initialised.
    Otherwise the next ``_save_state()`` would silently overwrite the only
    copy of the trading history.
    """
    if not self.state_file.exists():
        self._init_all_accounts()
        return
    try:
        with open(self.state_file, 'r') as f:
            state = json.load(f)
        if 'symbols' in state:
            # Current layout: one sub-dict per symbol, keyed by timeframe value.
            stored_symbols = state.get('symbols') or {}
            for symbol in self.symbols:
                symbol_data = stored_symbols.get(symbol) or {}
                self.accounts[symbol] = {}
                for tf in TimeFrame:
                    tf_data = symbol_data.get(tf.value)
                    if tf_data:
                        self.accounts[symbol][tf] = TimeFrameAccount.from_dict(tf_data)
                    else:
                        self._init_account(symbol, tf)
        else:
            # Legacy single-symbol layout: migrate into the first symbol.
            first_symbol = self.symbols[0] if self.symbols else 'BTCUSDT'
            legacy_accounts = state.get('accounts') or {}
            self.accounts[first_symbol] = {}
            for tf in TimeFrame:
                tf_data = legacy_accounts.get(tf.value)
                if tf_data:
                    tf_data['symbol'] = first_symbol  # legacy records carry no symbol
                    self.accounts[first_symbol][tf] = TimeFrameAccount.from_dict(tf_data)
                else:
                    self._init_account(first_symbol, tf)
            # Every other configured symbol starts fresh.
            for symbol in self.symbols[1:]:
                self._init_symbol_accounts(symbol)
        logger.info(f"Loaded state from {self.state_file}")
    except Exception as e:
        # Best-effort recovery at the persistence boundary: keep running with
        # fresh accounts, but never destroy the file we could not read.
        logger.exception(f"Failed to load state: {e}")
        backup = self.state_file.with_name(
            f"{self.state_file.name}.corrupt-{datetime.now().strftime('%Y%m%d%H%M%S')}"
        )
        try:
            self.state_file.replace(backup)
            logger.warning(f"Unreadable state file moved to {backup}")
        except OSError as backup_error:
            logger.error(f"Could not back up unreadable state file: {backup_error}")
        self._init_all_accounts()
def _init_all_accounts(self):
"""初始化所有币种所有周期账户"""
for symbol in self.symbols:
self._init_symbol_accounts(symbol)
def _init_symbol_accounts(self, symbol: str):
    """Replace *symbol*'s accounts with fresh ones, one per timeframe."""
    # The per-symbol dict must exist before _init_account writes into it.
    fresh = {}
    self.accounts[symbol] = fresh
    for timeframe in TimeFrame:
        self._init_account(symbol, timeframe)
def _init_account(self, symbol: str, tf: TimeFrame):
    """Create a fresh account for one symbol/timeframe pair.

    Starting capital and leverage come from TIMEFRAME_CONFIG (currently
    10,000 USD at 10x for every timeframe, i.e. up to 100,000 USD notional).
    """
    cfg = TIMEFRAME_CONFIG[tf]
    self.accounts[symbol][tf] = TimeFrameAccount(
        symbol=symbol,
        timeframe=tf.value,
        initial_balance=cfg['initial_balance'],
        leverage=cfg['leverage'],
        realized_pnl=0.0,
    )
def _save_state(self):
"""保存状态到文件"""
self.state_file.parent.mkdir(parents=True, exist_ok=True)
# 新格式: {symbols: {BTCUSDT: {short: {...}, ...}, ETHUSDT: {...}}, accounts: {...}}
symbols_data = {}
for symbol, tf_accounts in self.accounts.items():
symbols_data[symbol] = {
tf.value: acc.to_dict() for tf, acc in tf_accounts.items()
}
# 同时保留旧格式兼容 (使用第一个币种)
first_symbol = self.symbols[0] if self.symbols else 'BTCUSDT'
legacy_accounts = {}
if first_symbol in self.accounts:
legacy_accounts = {
tf.value: acc.to_dict() for tf, acc in self.accounts[first_symbol].items()
}
state = {
'symbols': symbols_data,
'accounts': legacy_accounts, # 向后兼容
'last_updated': datetime.now().isoformat(),
}
with open(self.state_file, 'w') as f:
json.dump(state, f, indent=2, ensure_ascii=False)
def process_signal(
    self,
    signal: Dict[str, Any],
    current_price: float,
    symbol: str = None
) -> Dict[str, Any]:
    """Run one symbol's signal through every timeframe, then save state.

    Args:
        signal: signal payload for this symbol.
        current_price: latest price of this symbol.
        symbol: trading pair such as ``'BTCUSDT'``; defaults to the first
            configured symbol (or ``'BTCUSDT'`` when none are configured).

    Returns:
        ``{'timestamp', 'symbol', 'current_price', 'timeframes'}``, where
        ``timeframes`` maps each timeframe value to its per-timeframe result.
    """
    if not symbol:
        symbol = self.symbols[0] if self.symbols else 'BTCUSDT'
    # Symbols first seen here get fresh accounts on the fly.
    if symbol not in self.accounts:
        self._init_symbol_accounts(symbol)
    started_at = datetime.now().isoformat()
    per_timeframe = {
        tf.value: self._process_timeframe_signal(symbol, tf, signal, current_price)
        for tf in TimeFrame
    }
    self._save_state()
    return {
        'timestamp': started_at,
        'symbol': symbol,
        'current_price': current_price,
        'timeframes': per_timeframe,
    }
def process_all_signals(
    self,
    signals: Dict[str, Dict[str, Any]],
    prices: Dict[str, float]
) -> Dict[str, Any]:
    """Process every configured symbol that has both a signal and a price.

    Args:
        signals: ``{symbol: signal_payload}``.
        prices: ``{symbol: current_price}``.

    Returns:
        ``{'timestamp': ..., 'symbols': {symbol: process_signal(...) result}}``;
        configured symbols missing from either mapping are skipped.
    """
    summary = {
        'timestamp': datetime.now().isoformat(),
        'symbols': {},
    }
    for sym in self.symbols:
        if sym not in signals or sym not in prices:
            continue
        summary['symbols'][sym] = self.process_signal(
            signal=signals[sym],
            current_price=prices[sym],
            symbol=sym,
        )
    return summary
def _process_timeframe_signal(
    self, symbol: str, tf: TimeFrame, signal: Dict[str, Any], current_price: float
) -> Dict[str, Any]:
    """Run the decision pipeline for one timeframe of one symbol.

    Checks run in order; the first one that fires decides the result:
    1. stop-loss / take-profit on the open position -> ``CLOSE``
    2. no usable signal for this timeframe -> ``NO_SIGNAL``; a direction
       other than ``LONG``/``SHORT`` (e.g. ``NONE``, ``HOLD``) also counts
       as no signal
    3. signal older than the timeframe's expiry -> ``SIGNAL_EXPIRED``
    4. missing stop-loss / take-profit -> ``NO_SIGNAL``
    5. reward/risk below the timeframe minimum -> ``LOW_RISK_REWARD``
    6. live price too far from the signal entry price -> ``PRICE_DEVIATION``
    7. (flat only) conflict with a higher timeframe -> ``TREND_CONFLICT``
    8. (flat only) not yet confirmed -> ``AWAITING_CONFIRMATION``
    9. open position: opposite signal closes it (``CLOSE``); same direction
       tries a pyramid add (``ADD`` / ``ADD_PRICE_NOT_IMPROVED`` / ``HOLD``)
    10. flat: wait for the first entry level (``WAIT_ENTRY_LEVEL``) or open
        (``OPEN``; ``WAIT`` if opening fails)

    Returns:
        ``{'action': str, 'details': dict or None}``.
    """
    account = self.accounts[symbol][tf]
    config = TIMEFRAME_CONFIG[tf]
    result = {
        'action': 'NONE',
        'details': None,
    }

    # Keep the equity curve current on every call, whatever the outcome.
    self._update_equity_curve(symbol, tf, current_price)

    # 1. Stop-loss / take-profit on an open position.
    if account.position and account.position.side != 'FLAT':
        close_result = self._check_close_position(symbol, tf, current_price)
        if close_result:
            result['action'] = 'CLOSE'
            result['details'] = close_result
            return result

    # 2. This timeframe's signal. Only LONG/SHORT are tradable. Previously any
    #    truthy direction (e.g. 'NONE' from an ACTIVE trade) fell through and
    #    could open a position with an invalid side.
    tf_signal = self._extract_timeframe_signal(signal, config['signal_keys'])
    direction = tf_signal.get('direction') if tf_signal and tf_signal.get('exists') else None
    if direction not in ('LONG', 'SHORT'):
        self._record_signal_history(account, 'NONE', 0, 0, 0)
        result['action'] = 'NO_SIGNAL'
        return result

    # LLM output may carry explicit nulls; read them as "missing".
    signal_stop_loss = tf_signal.get('stop_loss') or 0
    signal_take_profit = tf_signal.get('take_profit') or 0
    signal_entry_price = tf_signal.get('entry_price') or 0
    reasoning = (tf_signal.get('reasoning') or '')[:100]
    entry_levels = tf_signal.get('entry_levels') or []
    signal_timestamp = signal.get('timestamp') or (signal.get('aggregated_signal') or {}).get('timestamp')

    # 3. Signal freshness.
    if signal_timestamp:
        expiry_check = self._check_signal_expiry(signal_timestamp, config)
        if not expiry_check['valid']:
            result['action'] = 'SIGNAL_EXPIRED'
            result['details'] = expiry_check
            logger.info(f"[{symbol}][{config['name']}] 信号已过期: {expiry_check['age_minutes']:.1f}分钟")
            return result

    # 4. Both protective levels are required.
    if signal_stop_loss <= 0 or signal_take_profit <= 0:
        result['action'] = 'NO_SIGNAL'
        result['details'] = {'reason': '缺少有效止盈止损'}
        return result

    # 5. Reward/risk, measured from the live price.
    rr_check = self._check_risk_reward_ratio(
        direction, current_price, signal_stop_loss, signal_take_profit, config
    )
    if not rr_check['valid']:
        result['action'] = 'LOW_RISK_REWARD'
        result['details'] = rr_check
        logger.info(
            f"[{symbol}][{config['name']}] 风险回报比不足: {rr_check['ratio']:.2f} < {rr_check['min_ratio']}"
        )
        return result

    # 6. ATR-scaled tolerance between the live price and the signal's entry.
    atr = self._get_atr_from_signal(signal)
    if atr > 0:
        account.last_atr = atr
    max_deviation = self._calculate_dynamic_deviation(config, account.last_atr, current_price)
    if signal_entry_price > 0:
        price_deviation = abs(current_price - signal_entry_price) / signal_entry_price
        if price_deviation > max_deviation:
            result['action'] = 'PRICE_DEVIATION'
            result['details'] = {
                'reason': f'价格偏差过大: {price_deviation*100:.2f}% > {max_deviation*100:.2f}%',
                'signal_entry': signal_entry_price,
                'current_price': current_price,
                'deviation_pct': price_deviation * 100,
                'max_deviation_pct': max_deviation * 100,
                'atr_used': account.last_atr,
            }
            logger.info(
                f"[{symbol}][{config['name']}] 跳过开仓: 价格偏差 {price_deviation*100:.2f}% > {max_deviation*100:.2f}% "
                f"(信号价: ${signal_entry_price:.2f}, 当前价: ${current_price:.2f})"
            )
            return result

    # Nothing between here and step 9 changes the position.
    is_flat = not account.position or account.position.side == 'FLAT'

    # 7. Higher-timeframe trend filter (new positions only).
    if is_flat:
        trend_check = self._check_higher_timeframe_trend(symbol, tf, direction, signal)
        if not trend_check['aligned']:
            result['action'] = 'TREND_CONFLICT'
            result['details'] = trend_check
            logger.info(
                f"[{symbol}][{config['name']}] 与大周期趋势冲突: {direction} vs {trend_check['higher_tf_trend']}"
            )
            return result

    # 8. Record every tradable signal; new positions need N in a row.
    self._record_signal_history(account, direction, signal_entry_price, signal_stop_loss, signal_take_profit)
    if is_flat:
        confirm_check = self._check_signal_confirmation(account, direction)
        if not confirm_check['confirmed']:
            result['action'] = 'AWAITING_CONFIRMATION'
            result['details'] = confirm_check
            logger.debug(
                f"[{symbol}][{config['name']}] 等待信号确认: {confirm_check['count']}/{SIGNAL_CONFIRMATION_COUNT}"
            )
            return result

    # 9. Existing position.
    if not is_flat:
        opposite = {'LONG': 'SHORT', 'SHORT': 'LONG'}
        if opposite.get(account.position.side) == direction:
            # Reverse signal: close only; a new position needs a fresh signal.
            close_result = self._close_position(symbol, tf, current_price, 'SIGNAL_REVERSE')
            result['action'] = 'CLOSE'
            result['details'] = close_result
            logger.info(
                f"[{symbol}][{config['name']}] 反向信号平仓,等待下一周期新信号"
            )
            return result
        # Same direction: try a pyramid add at the planned level / better price.
        add_result = self._add_position_with_price_check(
            symbol, tf, current_price,
            signal_stop_loss, signal_take_profit,
            reasoning,
            entry_levels=entry_levels
        )
        if add_result:
            result['action'] = 'ADD_PRICE_NOT_IMPROVED' if add_result.get('skipped') else 'ADD'
            result['details'] = add_result
        else:
            # Already at the maximum pyramid level: keep holding.
            result['action'] = 'HOLD'
            result['details'] = {
                'position': account.position.to_dict(),
                'unrealized_pnl': self._calc_unrealized_pnl(symbol, tf, current_price),
                'reason': '已达最大仓位层级',
            }
        return result

    # 10. Flat: initial entry, with the stop-loss sanity-checked against ATR.
    adjusted_sl, adjusted_tp = self._adjust_stop_loss_take_profit(
        direction, current_price, signal_stop_loss, signal_take_profit,
        account.last_atr, config
    )
    if entry_levels:
        target_price = entry_levels[0].get('price') or 0
        if target_price > 0:
            price_tolerance = 0.003  # 0.3% tolerance
            if direction == 'LONG':
                # Long: enter only at or below the first planned level.
                not_reached = current_price > target_price * (1 + price_tolerance)
            else:
                # Short: enter only at or above the first planned level.
                not_reached = current_price < target_price * (1 - price_tolerance)
            if not_reached:
                result['action'] = 'WAIT_ENTRY_LEVEL'
                result['details'] = {
                    'reason': f'等待首仓价位 ${target_price:.2f}',
                    'current_price': current_price,
                    'target_price': target_price,
                    'direction': direction,
                    'entry_levels': entry_levels,
                }
                logger.info(
                    f"[{symbol}][{config['name']}] 等待首仓价位: "
                    f"目标=${target_price:.2f}, 当前=${current_price:.2f}"
                )
                return result

    open_result = self._open_position(
        symbol, tf, direction, current_price,
        adjusted_sl, adjusted_tp,
        reasoning
    )
    if open_result:
        result['action'] = 'OPEN'
        result['details'] = open_result
    else:
        result['action'] = 'WAIT'
    return result
# ==================== 新增优化方法 ====================
def _check_signal_expiry(self, signal_timestamp: str, config: Dict) -> Dict:
"""检查信号是否过期"""
try:
# 解析信号时间
if 'T' in signal_timestamp:
signal_time = datetime.fromisoformat(signal_timestamp.replace('Z', '+00:00'))
else:
signal_time = datetime.fromisoformat(signal_timestamp)
# 移除时区信息进行比较
if signal_time.tzinfo:
signal_time = signal_time.replace(tzinfo=None)
now = datetime.now()
age = now - signal_time
age_minutes = age.total_seconds() / 60
expiry_minutes = config.get('signal_expiry_minutes', 15)
return {
'valid': age_minutes <= expiry_minutes,
'age_minutes': age_minutes,
'expiry_minutes': expiry_minutes,
'signal_time': signal_timestamp,
}
except Exception as e:
logger.warning(f"信号时间解析失败: {e}")
return {'valid': True, 'age_minutes': 0, 'expiry_minutes': 0}
def _check_risk_reward_ratio(
self, direction: str, entry_price: float,
stop_loss: float, take_profit: float, config: Dict
) -> Dict:
"""验证风险回报比"""
min_ratio = config.get('min_risk_reward_ratio', 1.5)
if direction == 'LONG':
risk = entry_price - stop_loss
reward = take_profit - entry_price
else: # SHORT
risk = stop_loss - entry_price
reward = entry_price - take_profit
if risk <= 0:
return {'valid': False, 'ratio': 0, 'min_ratio': min_ratio, 'reason': '止损设置错误'}
ratio = reward / risk
return {
'valid': ratio >= min_ratio,
'ratio': round(ratio, 2),
'min_ratio': min_ratio,
'risk': risk,
'reward': reward,
}
def _calculate_dynamic_deviation(self, config: Dict, atr: float, current_price: float) -> float:
"""计算动态价格偏差阈值"""
base_deviation = config.get('base_price_deviation', 0.005)
atr_multiplier = config.get('atr_deviation_multiplier', 0.5)
if atr > 0 and current_price > 0:
# ATR 百分比
atr_pct = atr / current_price
# 动态偏差 = 基础偏差 + ATR偏差
dynamic_deviation = base_deviation + (atr_pct * atr_multiplier)
return min(dynamic_deviation, 0.02) # 最大2%
else:
return base_deviation
def _get_atr_from_signal(self, signal: Dict) -> float:
"""从信号中提取ATR值"""
try:
# 尝试多个路径
atr = signal.get('market_analysis', {}).get('volatility_analysis', {}).get('atr', 0)
if not atr:
atr = signal.get('aggregated_signal', {}).get('levels', {}).get('atr', 0)
if not atr:
atr = signal.get('quantitative_signal', {}).get('indicators', {}).get('atr', 0)
return float(atr) if atr else 0.0
except:
return 0.0
def _check_higher_timeframe_trend(
    self, symbol: str, tf: TimeFrame, direction: str, signal: Dict
) -> Dict:
    """Check that *direction* does not contradict any higher timeframe.

    Two payload formats are supported:
    1. new: ``llm_signal.trades`` array (used when it has at least 3 entries);
       only ``ACTIVE`` trades are considered;
    2. legacy: ``llm_signal.opportunities`` mapping; only entries with
       ``exists`` are considered.
    In both formats only an explicit opposite ``LONG``/``SHORT`` counts as a
    conflict. Neutral values (``NONE``, ``HOLD``, missing) are allowed; the
    legacy branch previously rejected them, contrary to its own comment.

    Returns:
        ``{'aligned': True, 'reason': ...}`` or, on conflict,
        ``{'aligned': False, 'higher_tf', 'higher_tf_trend',
        'current_direction', 'reason'}``.
    """
    higher_tfs = TIMEFRAME_HIERARCHY.get(tf, [])
    if not higher_tfs:
        return {'aligned': True, 'reason': '无需检查大周期'}

    llm_signal = (
        signal.get('llm_signal')
        or (signal.get('aggregated_signal') or {}).get('llm_signal')
        or {}
    )
    if not llm_signal:
        return {'aligned': True, 'reason': '无LLM信号数据'}

    def is_conflict(higher_direction):
        # Only an explicit opposite direction blocks the trade.
        return higher_direction in ('LONG', 'SHORT') and higher_direction != direction

    def conflict(higher_tf, higher_direction):
        return {
            'aligned': False,
            'higher_tf': higher_tf.value,
            'higher_tf_trend': higher_direction,
            'current_direction': direction,
            'reason': f'{higher_tf.value}周期为{higher_direction},与{direction}冲突',
        }

    # ---- New format: trades array ----
    trades = llm_signal.get('trades', [])
    if trades and isinstance(trades, list) and len(trades) >= 3:
        trades_by_tf = {t.get('timeframe'): t for t in trades if t.get('timeframe')}
        for higher_tf in higher_tfs:
            higher_trade = trades_by_tf.get(higher_tf.value, {})
            if higher_trade and higher_trade.get('status') == 'ACTIVE':
                higher_direction = higher_trade.get('direction')
                if is_conflict(higher_direction):
                    return conflict(higher_tf, higher_direction)
        return {'aligned': True, 'reason': '大周期趋势一致或无明确方向'}

    # ---- Legacy format: opportunities mapping ----
    opportunities = llm_signal.get('opportunities') or {}
    for higher_tf in higher_tfs:
        for key in TIMEFRAME_CONFIG[higher_tf]['signal_keys']:
            higher_opp = opportunities.get(key, {})
            if higher_opp and higher_opp.get('exists'):
                higher_direction = higher_opp.get('direction')
                if is_conflict(higher_direction):
                    return conflict(higher_tf, higher_direction)
    return {'aligned': True, 'reason': '大周期趋势一致或无明确方向'}
def _record_signal_history(
    self, account: TimeFrameAccount, direction: str,
    entry_price: float, stop_loss: float, take_profit: float
):
    """Append a timestamped signal record, keeping only the 10 most recent."""
    record = SignalHistory(
        direction=direction,
        timestamp=datetime.now().isoformat(),
        entry_price=entry_price,
        stop_loss=stop_loss,
        take_profit=take_profit,
    )
    history = account.signal_history
    history.append(record)
    max_kept = 10
    if len(history) > max_kept:
        account.signal_history = history[-max_kept:]
def _check_signal_confirmation(self, account: TimeFrameAccount, direction: str) -> Dict:
    """Report whether the last N recorded signals all point in *direction*.

    N is SIGNAL_CONFIRMATION_COUNT. With fewer than N records the result is
    unconfirmed, and ``count`` is the number of records so far.
    """
    needed = SIGNAL_CONFIRMATION_COUNT
    history = account.signal_history
    if len(history) < needed:
        return {'confirmed': False, 'count': len(history), 'required': needed}
    window = history[-needed:]
    matches = len([record for record in window if record.direction == direction])
    return {
        'confirmed': matches >= needed,
        'count': matches,
        'required': needed,
        'recent_signals': [record.direction for record in window],
    }
def _add_position_with_price_check(
    self, symbol: str, tf: TimeFrame, price: float,
    stop_loss: float, take_profit: float, reasoning: str,
    entry_levels: List[Dict] = None
) -> Optional[Dict]:
    """Pyramid into the open position only at a planned or improved price.

    Args:
        entry_levels: planned entry levels from the LLM signal, e.g.
            ``[{'price': 90000, 'ratio': 0.4, 'level': 0}, ...]``. When the
            level for the next add has a positive price, price must reach it
            (0.2% tolerance). Otherwise price must beat the average entry by
            PYRAMID_PRICE_IMPROVEMENT.

    Returns:
        None when flat or already at the last pyramid level; a dict with
        ``'skipped': True`` when the price condition is not met; otherwise the
        result of ``_add_position``.
    """
    account = self.accounts[symbol][tf]
    pos = account.position
    if not pos or pos.side == 'FLAT':
        return None
    level = pos.pyramid_level
    if level >= len(PYRAMID_LEVELS):
        return None
    is_long = pos.side == 'LONG'

    # Preferred rule: the signal's planned price for this level.
    planned = entry_levels[level] if entry_levels and len(entry_levels) > level else None
    planned_price = planned.get('price', 0) if planned is not None else 0
    if planned_price > 0:
        tolerance = 0.002  # 0.2%
        if is_long:
            # Long: add only once price has dropped to the planned level.
            missed = price > planned_price * (1 + tolerance)
        else:
            # Short: add only once price has risen to the planned level.
            missed = price < planned_price * (1 - tolerance)
        if missed:
            return {
                'skipped': True,
                'reason': f'未触及加仓价位 L{level+1}',
                'current_price': price,
                'target_price': planned_price,
                'next_level': level + 1,
                'entry_levels': entry_levels,
            }
        logger.info(
            f"[{symbol}] 触及加仓价位 L{level+1}: "
            f"目标=${planned_price:.2f}, 当前=${price:.2f}"
        )
        return self._add_position(symbol, tf, price, stop_loss, take_profit, reasoning)

    # Fallback rule: beat the average entry price.
    avg_price = pos.entry_price
    required = PYRAMID_PRICE_IMPROVEMENT
    if is_long:
        improvement = (avg_price - price) / avg_price
    else:
        improvement = (price - avg_price) / avg_price
    if improvement < required:
        if is_long:
            reason = f'加仓价格未改善: 需低于均价{required*100:.1f}%'
        else:
            reason = f'加仓价格未改善: 需高于均价{required*100:.1f}%'
        return {
            'skipped': True,
            'reason': reason,
            'avg_price': avg_price,
            'current_price': price,
            'improvement_pct': improvement * 100,
            'required_improvement_pct': required * 100,
        }
    return self._add_position(symbol, tf, price, stop_loss, take_profit, reasoning)
def _adjust_stop_loss_take_profit(
self, direction: str, entry_price: float,
signal_sl: float, signal_tp: float,
atr: float, config: Dict
) -> tuple:
"""调整止损止盈基于ATR验证合理性"""
if atr <= 0:
return signal_sl, signal_tp
# 计算最小止损距离 (1.5 ATR)
min_sl_distance = atr * 1.5
# 计算当前止损距离
if direction == 'LONG':
current_sl_distance = entry_price - signal_sl
# 如果止损太近,调整
if current_sl_distance < min_sl_distance:
adjusted_sl = entry_price - min_sl_distance
logger.info(f"止损调整: ${signal_sl:.2f} -> ${adjusted_sl:.2f} (基于ATR)")
signal_sl = adjusted_sl
else: # SHORT
current_sl_distance = signal_sl - entry_price
if current_sl_distance < min_sl_distance:
adjusted_sl = entry_price + min_sl_distance
logger.info(f"止损调整: ${signal_sl:.2f} -> ${adjusted_sl:.2f} (基于ATR)")
signal_sl = adjusted_sl
return signal_sl, signal_tp
def _extract_timeframe_signal(
    self, signal: Dict[str, Any], signal_keys: List[str]
) -> Optional[Dict[str, Any]]:
    """Return this timeframe's opportunity from *signal*, or None.

    Candidate LLM payloads, in order:
    1. ``signal['llm_signal']`` (or, when that is absent,
       ``signal['aggregated_signal']['llm_signal']``);
    2. ``signal['aggregated_signal']['llm_signal']`` as a fallback, unless it
       is the payload already checked.
    In each payload the new ``trades`` array is preferred over the legacy
    ``opportunities`` mapping. Errors are logged and reported as None.
    """
    try:
        agg = signal.get('aggregated_signal', {})
        agg_llm = agg.get('llm_signal') if agg else None
        primary = signal.get('llm_signal') or agg_llm
        candidates = []
        if primary and isinstance(primary, dict):
            candidates.append(primary)
        # Re-scanning the same payload could not find anything new.
        if agg_llm and agg_llm is not primary:
            candidates.append(agg_llm)
        for llm in candidates:
            found = self._extract_from_llm_payload(llm, signal_keys)
            if found:
                return found
        return None
    except Exception as e:
        logger.error(f"Error extracting signal: {e}")
        return None

def _extract_from_llm_payload(
    self, llm: Dict[str, Any], signal_keys: List[str]
) -> Optional[Dict[str, Any]]:
    """Look up one timeframe in a single LLM payload (trades first, then opportunities)."""
    # Legacy opportunity keys -> 'timeframe' value used by the trades array.
    tf_mapping = {
        'short_term_5m_15m_1h': 'short',
        'intraday': 'short',
        'medium_term_4h_1d': 'medium',
        'swing': 'medium',
        'long_term_1d_1w': 'long',
    }
    trades = llm.get('trades', [])
    if trades and isinstance(trades, list) and len(trades) >= 3:
        target_tf = next((tf_mapping[key] for key in signal_keys if key in tf_mapping), None)
        if target_tf:
            for trade in trades:
                if trade.get('timeframe') == target_tf:
                    return self._convert_trade_to_opportunity(trade)
    opportunities = llm.get('opportunities') or {}
    for key in signal_keys:
        if key in opportunities and opportunities[key]:
            return opportunities[key]
    return None
def _convert_trade_to_opportunity(self, trade: Dict[str, Any]) -> Dict[str, Any]:
"""将新格式 trade 转换为旧格式 opportunity
新格式:
{
"id": "short_001",
"timeframe": "short",
"status": "ACTIVE|INACTIVE",
"direction": "LONG|SHORT|NONE",
"entry": {"price_1": 90000, "price_2": 89700, ...},
"exit": {"stop_loss": 88500, "take_profit_1": 91000, ...},
"position": {"size_pct_1": 40, "size_pct_2": 30, ...},
"risk_reward": 2.5,
"expected_profit_pct": 1.5,
"reasoning": "..."
}
转换为:
{
"exists": True,
"direction": "LONG",
"entry_price": 90000,
"entry_levels": [...],
"stop_loss": 88500,
"take_profit": 91000,
"reasoning": "..."
}
"""
status = trade.get('status', 'INACTIVE')
is_active = status == 'ACTIVE'
if not is_active:
return {
'exists': False,
'direction': None,
'entry_price': 0,
'stop_loss': 0,
'take_profit': 0,
'reasoning': trade.get('reasoning', ''),
}
entry = trade.get('entry', {})
exit_data = trade.get('exit', {})
position = trade.get('position', {})
# 构建 entry_levels金字塔入场价位
entry_levels = []
for i in range(1, 5):
price = entry.get(f'price_{i}', 0)
ratio = position.get(f'size_pct_{i}', [40, 30, 20, 10][i-1]) / 100
if price > 0:
entry_levels.append({
'price': float(price),
'ratio': ratio,
'level': i - 1,
})
# 第一个入场价作为主入场价
entry_price = float(entry.get('price_1', 0))
return {
'exists': True,
'direction': trade.get('direction', 'NONE'),
'entry_price': entry_price,
'entry_levels': entry_levels,
'stop_loss': float(exit_data.get('stop_loss', 0)),
'take_profit': float(exit_data.get('take_profit_1', 0)),
'take_profit_2': float(exit_data.get('take_profit_2', 0)),
'take_profit_3': float(exit_data.get('take_profit_3', 0)),
'risk_reward': trade.get('risk_reward', 0),
'expected_profit_pct': trade.get('expected_profit_pct', 0),
'reasoning': trade.get('reasoning', ''),
}
def _get_max_position_value(self, symbol: str, tf: TimeFrame) -> float:
"""获取最大仓位价值(本金 × 杠杆)"""
account = self.accounts[symbol][tf]
return account.initial_balance * account.leverage
def _get_current_position_value(self, symbol: str, tf: TimeFrame, current_price: float) -> float:
"""获取当前仓位价值"""
account = self.accounts[symbol][tf]
if not account.position or account.position.side == 'FLAT':
return 0.0
return account.position.size * current_price
def _open_position(
    self, symbol: str, tf: TimeFrame, direction: str, price: float,
    stop_loss: float, take_profit: float, reasoning: str
) -> Optional[Dict]:
    """Open the first pyramid layer of a new position.

    The first layer's notional value is the account's maximum position value
    (initial balance x leverage) times ``PYRAMID_LEVELS[0]``. The margin it
    locks is that notional divided by the account leverage.

    Args:
        symbol: Trading pair, e.g. ``BTCUSDT``.
        tf: Timeframe whose account is used.
        direction: Side stored on the new ``Position`` (e.g. ``'LONG'``).
        price: Entry price. Missing or non-positive prices are rejected.
        stop_loss: Stop-loss level stored on the position.
        take_profit: Take-profit level stored on the position.
        reasoning: Signal reasoning stored on the position.

    Returns:
        A summary dict of the opened layer, or ``None`` if the price is not
        positive, the computed size is not positive, or the available balance
        cannot cover the required margin.

    NOTE(review): any existing position on this account is replaced without
    being closed; callers presumably only call this when flat. Confirm.
    """
    # Signals without an entry price are converted to 0 upstream. Reject them
    # here instead of raising ZeroDivisionError in ``position_value / price``.
    if price is None or price <= 0:
        return None

    account = self.accounts[symbol][tf]
    config = TIMEFRAME_CONFIG[tf]

    # First-layer size: max position value x first pyramid ratio
    max_position_value = self._get_max_position_value(symbol, tf)
    first_level_ratio = PYRAMID_LEVELS[0]
    position_value = max_position_value * first_level_ratio
    margin = position_value / account.leverage
    size = position_value / price

    # Refuse to open if the free balance cannot cover the margin
    available_balance = account.get_available_balance()
    if available_balance < margin:
        logger.warning(f"[{symbol}][{config['name']}] 可用余额不足: ${available_balance:.2f} < ${margin:.2f}")
        return None

    if size <= 0:
        return None

    # Record the first entry (pyramid level 0)
    entry = PositionEntry(
        price=price,
        size=size,
        margin=margin,
        timestamp=datetime.now().isoformat(),
        level=0,
    )

    account.position = Position(
        side=direction,
        entries=[entry],
        stop_loss=stop_loss,
        take_profit=take_profit,
        created_at=datetime.now().isoformat(),
        signal_reasoning=reasoning,
    )

    # Base-asset unit for logging (e.g. BTCUSDT -> BTC)
    unit = symbol.replace('USDT', '') if symbol.endswith('USDT') else symbol
    logger.info(
        f"[{symbol}][{config['name']}] OPEN {direction} [L1/{len(PYRAMID_LEVELS)}]: price=${price:.2f}, "
        f"size={size:.6f} {unit}, margin=${margin:.2f}, value=${position_value:.2f}, "
        f"SL=${stop_loss:.2f}, TP=${take_profit:.2f}"
    )

    return {
        'symbol': symbol,
        'timeframe': tf.value,
        'side': direction,
        'entry_price': price,
        'size': size,
        'margin': margin,
        'position_value': position_value,
        'pyramid_level': 1,
        'max_levels': len(PYRAMID_LEVELS),
        'stop_loss': stop_loss,
        'take_profit': take_profit,
    }
def _add_position(
    self, symbol: str, tf: TimeFrame, price: float,
    stop_loss: float, take_profit: float, reasoning: str
) -> Optional[Dict]:
    """Add the next pyramid layer to the open position.

    The new layer's notional value is the maximum position value times
    ``PYRAMID_LEVELS[pos.pyramid_level]``. The position's stop-loss and
    take-profit are overwritten with the levels passed in.

    Args:
        symbol: Trading pair, e.g. ``BTCUSDT``.
        tf: Timeframe whose account holds the position.
        price: Fill price of the new layer. Missing or non-positive prices
            are rejected.
        stop_loss: New stop-loss level for the whole position.
        take_profit: New take-profit level for the whole position.
        reasoning: Accepted but not used by this method.

    Returns:
        A summary dict of the added layer and the resulting totals, or
        ``None`` if the price is not positive, there is no open position,
        all pyramid layers are already used, or the available balance cannot
        cover the extra margin.
    """
    # Signals without an entry price are converted to 0 upstream. Reject them
    # before ``add_position_value / price`` can raise ZeroDivisionError.
    if price is None or price <= 0:
        return None

    account = self.accounts[symbol][tf]
    config = TIMEFRAME_CONFIG[tf]
    pos = account.position

    if not pos or pos.side == 'FLAT':
        return None

    # Stop once every pyramid layer has been filled
    current_level = pos.pyramid_level
    if current_level >= len(PYRAMID_LEVELS):
        logger.info(f"[{symbol}][{config['name']}] 已达最大仓位层级 {current_level}/{len(PYRAMID_LEVELS)}")
        return None

    # Size of this layer: max position value x the ratio for the next level
    max_position_value = self._get_max_position_value(symbol, tf)
    level_ratio = PYRAMID_LEVELS[current_level]
    add_position_value = max_position_value * level_ratio
    add_margin = add_position_value / account.leverage
    add_size = add_position_value / price

    # Refuse to add if the free balance cannot cover the extra margin
    available_balance = account.get_available_balance()
    if available_balance < add_margin:
        logger.warning(
            f"[{symbol}][{config['name']}] 加仓余额不足: ${available_balance:.2f} < ${add_margin:.2f}"
        )
        return None

    # Append the new entry to the position
    entry = PositionEntry(
        price=price,
        size=add_size,
        margin=add_margin,
        timestamp=datetime.now().isoformat(),
        level=current_level,
    )
    pos.entries.append(entry)

    # The latest signal's exit levels apply to the whole position
    pos.stop_loss = stop_loss
    pos.take_profit = take_profit

    # Base-asset unit for logging (e.g. BTCUSDT -> BTC)
    unit = symbol.replace('USDT', '') if symbol.endswith('USDT') else symbol
    new_level = pos.pyramid_level

    logger.info(
        f"[{symbol}][{config['name']}] ADD {pos.side} [L{new_level}/{len(PYRAMID_LEVELS)}]: price=${price:.2f}, "
        f"add_size={add_size:.6f} {unit}, add_margin=${add_margin:.2f}, "
        f"total_size={pos.size:.6f} {unit}, total_margin=${pos.margin:.2f}, "
        f"avg_price=${pos.entry_price:.2f}"
    )

    return {
        'symbol': symbol,
        'timeframe': tf.value,
        'side': pos.side,
        'add_price': price,
        'add_size': add_size,
        'add_margin': add_margin,
        'add_position_value': add_position_value,
        'total_size': pos.size,
        'total_margin': pos.margin,
        'avg_entry_price': pos.entry_price,
        'pyramid_level': new_level,
        'max_levels': len(PYRAMID_LEVELS),
        'stop_loss': stop_loss,
        'take_profit': take_profit,
    }
def _check_close_position(self, symbol: str, tf: TimeFrame, current_price: float) -> Optional[Dict]:
    """Close the position if ``current_price`` reaches its take-profit or stop-loss.

    For a LONG, take-profit triggers at or above ``take_profit`` and stop-loss
    at or below ``stop_loss``. For any other side (SHORT) the comparisons are
    mirrored. Take-profit is checked before stop-loss.

    A level that is ``None`` or not positive counts as unset and never
    triggers. Signals without exit levels are converted to ``0``. Comparing
    against 0 would otherwise close a LONG as TAKE_PROFIT, or a SHORT as
    STOP_LOSS, on the very next price check.

    Returns:
        The result of ``_close_position`` when a level is hit, else ``None``.
    """
    account = self.accounts[symbol][tf]
    pos = account.position

    if not pos or pos.side == 'FLAT':
        return None

    # Disable levels that were never set (None / 0)
    tp = pos.take_profit if pos.take_profit and pos.take_profit > 0 else None
    sl = pos.stop_loss if pos.stop_loss and pos.stop_loss > 0 else None

    if pos.side == 'LONG':
        tp_hit = tp is not None and current_price >= tp
        sl_hit = sl is not None and current_price <= sl
    else:  # SHORT
        tp_hit = tp is not None and current_price <= tp
        sl_hit = sl is not None and current_price >= sl

    if tp_hit:
        return self._close_position(symbol, tf, current_price, 'TAKE_PROFIT')
    if sl_hit:
        return self._close_position(symbol, tf, current_price, 'STOP_LOSS')
    return None
def _close_position(self, symbol: str, tf: TimeFrame, price: float, reason: str) -> Dict:
    """Close the whole position at ``price`` and record the resulting trade.

    PnL is ``(price - pos.entry_price) * pos.size`` for LONG and the mirror
    image for any other side. ``pnl_pct`` is that PnL as a percentage of the
    position margin.

    Side effects, in order:
    1. The PnL is added to ``account.realized_pnl``.
    2. A ``Trade`` is appended to ``account.trades``.
    3. ``_update_stats`` runs.
    4. ``account.position`` is set to ``None``.

    Args:
        symbol: Trading pair.
        tf: Timeframe whose account holds the position.
        price: Exit price.
        reason: Exit reason recorded on the trade (``_check_close_position``
            passes ``'TAKE_PROFIT'`` or ``'STOP_LOSS'``).

    Returns:
        A summary dict of the closed trade, or ``{'error': 'No position'}``
        when there is nothing to close.
    """
    account = self.accounts[symbol][tf]
    config = TIMEFRAME_CONFIG[tf]
    pos = account.position

    if not pos or pos.side == 'FLAT':
        return {'error': 'No position'}

    # PnL against the position's (average) entry price
    if pos.side == 'LONG':
        # Long: (sell price - buy price) * size
        pnl = (price - pos.entry_price) * pos.size
    else:
        # Short: (sell price - buy-back price) * size
        pnl = (pos.entry_price - price) * pos.size

    # Return on margin, in percent
    pnl_pct = (pnl / pos.margin * 100) if pos.margin > 0 else 0

    # Realize the PnL (margin is released and the PnL is settled)
    account.realized_pnl += pnl

    # Record the trade. Id = first char of symbol + first letter of timeframe
    # + 1-based per-account sequence number.
    # NOTE(review): symbols sharing a first letter (e.g. BTCUSDT / BNBUSDT)
    # yield identical ids; confirm ids are only compared within one account.
    trade = Trade(
        id=f"{symbol[0]}{tf.value[0].upper()}{len(account.trades)+1:04d}",
        timeframe=tf.value,
        side=pos.side,
        entry_price=pos.entry_price,
        entry_time=pos.created_at,
        exit_price=price,
        exit_time=datetime.now().isoformat(),
        size=pos.size,
        pnl=pnl,
        pnl_pct=pnl_pct,
        exit_reason=reason,
        symbol=symbol,
    )
    account.trades.append(trade)
    # Stats are updated after the trade is appended (they rescan account.trades)
    # and while account.position is still set.
    self._update_stats(symbol, tf, trade)

    # Account equity after closing. get_equity is called without an
    # unrealized-PnL argument; the PnL is already in realized_pnl.
    new_equity = account.get_equity()

    result = {
        'symbol': symbol,
        'timeframe': tf.value,
        'side': pos.side,
        'entry_price': pos.entry_price,
        'exit_price': price,
        'size': pos.size,
        'margin': pos.margin,
        'pnl': pnl,
        'pnl_pct': pnl_pct,
        'reason': reason,
        'new_equity': new_equity,
        'realized_pnl': account.realized_pnl,
    }

    logger.info(
        f"[{symbol}][{config['name']}] CLOSE {pos.side}: entry=${pos.entry_price:.2f}, "
        f"exit=${price:.2f}, PnL=${pnl:.2f} ({pnl_pct:.2f}%), reason={reason}, "
        f"equity=${new_equity:.2f}"
    )

    # Clear the position only after everything above has read it
    account.position = None
    return result
def _calc_unrealized_pnl(self, symbol: str, tf: TimeFrame, current_price: float) -> Dict[str, float]:
    """Mark the open position to ``current_price``.

    Returns ``{'pnl': ..., 'pnl_pct': ...}``:
    - ``pnl`` is the price move times size, with the sign flipped for non-LONG sides.
    - ``pnl_pct`` is that PnL as a percentage of the position margin, or 0 when
      the margin is not positive.

    A missing or FLAT position yields zeros.
    """
    pos = self.accounts[symbol][tf].position
    if not pos or pos.side == 'FLAT':
        return {'pnl': 0, 'pnl_pct': 0}

    if pos.side == 'LONG':
        price_move = current_price - pos.entry_price
    else:
        price_move = pos.entry_price - current_price
    pnl = price_move * pos.size

    margin = pos.margin
    return {'pnl': pnl, 'pnl_pct': pnl / margin * 100 if margin > 0 else 0}
def _update_equity_curve(self, symbol: str, tf: TimeFrame, current_price: float):
    """Append one equity snapshot, marked at ``current_price``, to the account's curve.

    Each point records the timestamp, equity including unrealized PnL, initial
    balance, realized PnL, unrealized PnL and the mark price.
    """
    acct = self.accounts[symbol][tf]
    upnl = self._calc_unrealized_pnl(symbol, tf, current_price)['pnl']
    equity_now = acct.get_equity(upnl)

    acct.equity_curve.append(dict(
        timestamp=datetime.now().isoformat(),
        equity=equity_now,
        initial_balance=acct.initial_balance,
        realized_pnl=acct.realized_pnl,
        unrealized_pnl=upnl,
        price=current_price,
    ))
def _update_stats(self, symbol: str, tf: TimeFrame, trade: Trade):
    """Fold a just-closed trade into the account's running statistics.

    ``trade`` must already be appended to ``account.trades``, because the
    averages and the profit factor are recomputed from that list.

    Updated keys:
    - ``total_trades`` and ``total_pnl``.
    - ``winning_trades`` / ``losing_trades``; a trade with ``pnl <= 0`` counts
      as a loss.
    - ``win_rate`` (%).
    - ``avg_win`` and ``avg_loss``.
    - ``profit_factor`` (gross profit / gross loss).
    - ``peak_balance`` and ``max_drawdown`` (%).

    Drawdown is sampled only here, i.e. when a trade closes. Equity swings
    while a position is open are not captured.
    """
    account = self.accounts[symbol][tf]
    stats = account.stats

    stats['total_trades'] += 1
    stats['total_pnl'] += trade.pnl

    if trade.pnl > 0:
        stats['winning_trades'] += 1
    else:
        stats['losing_trades'] += 1

    # total_trades was just incremented, so it is always >= 1 here
    stats['win_rate'] = stats['winning_trades'] / stats['total_trades'] * 100

    win_pnls = [t.pnl for t in account.trades if t.pnl > 0]
    loss_pnls = [t.pnl for t in account.trades if t.pnl <= 0]

    if win_pnls:
        stats['avg_win'] = sum(win_pnls) / len(win_pnls)
    if loss_pnls:
        stats['avg_loss'] = sum(loss_pnls) / len(loss_pnls)

    # Profit factor = gross profit / gross loss. avg_win / avg_loss would be
    # the payoff ratio, which ignores how many trades won or lost.
    gross_profit = sum(win_pnls)
    gross_loss = -sum(loss_pnls)
    if gross_loss > 0:
        stats['profit_factor'] = gross_profit / gross_loss

    # Peak and drawdown, based on account equity
    equity = account.get_equity()
    if equity > stats['peak_balance']:
        stats['peak_balance'] = equity

    peak = stats['peak_balance']
    if peak > 0:  # avoid ZeroDivisionError on an empty account
        drawdown = (peak - equity) / peak * 100
        if drawdown > stats['max_drawdown']:
            stats['max_drawdown'] = drawdown
def get_status(
    self,
    current_price: float = None,
    symbol: str = None,
    prices: Dict[str, float] = None
) -> Dict[str, Any]:
    """Return account status for one symbol, or aggregated over all symbols.

    Args:
        current_price: Single-symbol price (backward compatible). Without
            ``symbol``, it is applied to the first configured symbol, but only
            when ``prices`` is not given.
        symbol: If given, return only this symbol's per-timeframe status.
        prices: Multi-symbol prices ``{symbol: price}``.

    Returns:
        ``_get_symbol_status`` output when ``symbol`` is given, otherwise
        ``_get_all_status`` output.
    """
    # Single symbol requested: an explicit current_price wins over prices
    if symbol:
        price = current_price or (prices.get(symbol) if prices else None)
        return self._get_symbol_status(symbol, price)

    # Resolve the price map explicitly. A one-liner
    # ``prices or {...} if current_price else {}`` parses as
    # ``(prices or {...}) if current_price else {}`` and would drop
    # ``prices`` whenever ``current_price`` is absent.
    if prices:
        all_prices = prices
    elif current_price and self.symbols:
        all_prices = {self.symbols[0]: current_price}
    else:
        all_prices = {}
    return self._get_all_status(all_prices)
def _get_symbol_status(self, symbol: str, current_price: float = None) -> Dict[str, Any]:
    """Return the status of every timeframe account for ``symbol``.

    Args:
        symbol: Trading pair whose accounts are reported.
        current_price: Mark price for unrealized PnL. When falsy, unrealized
            PnL is reported as 0 and no mark data is attached to open
            positions.

    Returns:
        Per-symbol totals plus a ``timeframes`` mapping keyed by
        ``TimeFrame.value``, or ``{'error': ...}`` for an unknown symbol.
    """
    if symbol not in self.accounts:
        return {'error': f'Symbol {symbol} not found'}

    sum_initial = sum_realized = sum_unrealized = sum_equity = 0
    per_timeframe = {}

    for tf in TimeFrame:
        acct = self.accounts[symbol][tf]
        cfg = TIMEFRAME_CONFIG[tf]

        # Without a mark price, unrealized PnL is reported as zero
        if current_price:
            marked = self._calc_unrealized_pnl(symbol, tf, current_price)
        else:
            marked = {'pnl': 0, 'pnl_pct': 0}
        equity = acct.get_equity(marked['pnl'])
        base = acct.initial_balance

        sum_initial += base
        sum_realized += acct.realized_pnl
        sum_unrealized += marked['pnl']
        sum_equity += equity

        ret_pct = 0
        if base > 0:
            ret_pct = (equity - base) / base * 100

        entry = {
            'name': cfg['name'],
            'name_en': cfg['name_en'],
            'symbol': symbol,
            'initial_balance': base,
            'realized_pnl': acct.realized_pnl,
            'unrealized_pnl': marked['pnl'],
            'equity': equity,
            'available_balance': acct.get_available_balance(),
            'used_margin': acct.get_used_margin(),
            'return_pct': ret_pct,
            'leverage': acct.leverage,
            'position': None,
            'stats': acct.stats,
            'recent_trades': [trade.to_dict() for trade in acct.trades[-10:]],
            'equity_curve': acct.equity_curve[-100:],
        }

        if acct.position and acct.position.side != 'FLAT':
            pos_view = acct.position.to_dict()
            if current_price:
                pos_view.update(
                    current_price=current_price,
                    unrealized_pnl=marked['pnl'],
                    unrealized_pnl_pct=marked['pnl_pct'],
                )
            entry['position'] = pos_view

        per_timeframe[tf.value] = entry

    if sum_initial > 0:
        total_return = (sum_equity - sum_initial) / sum_initial * 100
    else:
        total_return = 0

    return {
        'timestamp': datetime.now().isoformat(),
        'symbol': symbol,
        'total_initial_balance': sum_initial,
        'total_realized_pnl': sum_realized,
        'total_unrealized_pnl': sum_unrealized,
        'total_equity': sum_equity,
        'total_return': total_return,
        'timeframes': per_timeframe,
    }
def _get_all_status(self, prices: Dict[str, float] = None) -> Dict[str, Any]:
    """Aggregate per-symbol status into grand totals across all symbols.

    Symbols listed in ``self.symbols`` but missing from ``self.accounts`` are
    skipped. A top-level ``timeframes`` field and un-prefixed ``total_*``
    fields are kept for backward compatibility. ``timeframes`` mirrors the
    first configured symbol.

    Args:
        prices: Optional ``{symbol: price}`` used to mark open positions.
    """
    price_map = prices or {}
    per_symbol = {}
    initial = realized = unrealized = equity = 0

    for sym in (s for s in self.symbols if s in self.accounts):
        report = self._get_symbol_status(sym, price_map.get(sym))
        per_symbol[sym] = report
        initial += report.get('total_initial_balance', 0)
        realized += report.get('total_realized_pnl', 0)
        unrealized += report.get('total_unrealized_pnl', 0)
        equity += report.get('total_equity', 0)

    grand_return = (equity - initial) / initial * 100 if initial > 0 else 0

    # Backward compatibility: expose the first symbol's timeframes at top level
    legacy = {}
    if self.symbols:
        legacy = per_symbol.get(self.symbols[0], {}).get('timeframes', {})

    return {
        'timestamp': datetime.now().isoformat(),
        'symbols': per_symbol,
        'timeframes': legacy,  # backward compatibility
        'grand_total_initial_balance': initial,
        'grand_total_realized_pnl': realized,
        'grand_total_unrealized_pnl': unrealized,
        'grand_total_equity': equity,
        'grand_total_return': grand_return,
        # backward-compatible aliases
        'total_initial_balance': initial,
        'total_realized_pnl': realized,
        'total_unrealized_pnl': unrealized,
        'total_equity': equity,
        'total_return': grand_return,
    }
def reset(self, symbol: str = None):
    """Reset trading accounts and persist the resulting state.

    Args:
        symbol: Reset only this symbol's accounts. When empty, every account
            is reset. An unknown symbol resets nothing, but the state is
            still saved.
    """
    if not symbol:
        self._init_all_accounts()
        logger.info("All accounts reset")
    elif symbol in self.accounts:
        self._init_symbol_accounts(symbol)
        logger.info(f"{symbol} accounts reset")
    self._save_state()
# Alias kept for compatibility with the legacy ``PaperTrader`` interface.
PaperTrader = MultiTimeframePaperTrader