tradusai/trading/paper_trading.py
2025-12-09 22:46:04 +08:00

1056 lines
38 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Paper Trading Module - 多币种多周期独立仓位管理
支持多币种 (BTC/USDT, ETH/USDT 等) 和三个独立周期的模拟交易:
- 短周期 (5m/15m/1h): short_term_5m_15m_1h / intraday
- 中周期 (4h/1d): medium_term_4h_1d / swing
- 长周期 (1d/1w): long_term_1d_1w
每个币种的每个周期独立管理:
- 独立仓位
- 独立止盈止损
- 独立统计数据
- 独立权益曲线
"""
import json
import logging
from datetime import datetime, timedelta
from typing import Dict, Any, Optional, List
from pathlib import Path
from dataclasses import dataclass, asdict, field
from enum import Enum
from config.settings import settings
logger = logging.getLogger(__name__)  # module logger used by MultiTimeframePaperTrader
class TimeFrame(Enum):
    """Trading timeframe; each symbol gets one independent account per member (iterated in declaration order)."""
    SHORT = "short"  # short-term: 5m/15m/1h
    MEDIUM = "medium"  # medium-term: 4h/1d
    LONG = "long"  # long-term: 1d/1w
# Per-timeframe parameters, keyed by TimeFrame member.
#   name / name_en      -- display labels (Chinese / English)
#   signal_keys         -- keys tried, in order, under llm_signal['opportunities']
#   leverage            -- leverage of every account on this timeframe
#   initial_balance     -- independent starting capital per symbol and timeframe
#   max_price_deviation -- largest relative gap between the current price and the
#                          signal's suggested entry price that still allows entry
TIMEFRAME_CONFIG = {
    TimeFrame.SHORT: dict(
        name='短周期',
        name_en='Short-term',
        signal_keys=['short_term_5m_15m_1h', 'intraday'],
        leverage=10,
        initial_balance=10000.0,
        max_price_deviation=0.001,  # 0.1%: short-term entries must be precise
    ),
    TimeFrame.MEDIUM: dict(
        name='中周期',
        name_en='Medium-term',
        signal_keys=['medium_term_4h_1d', 'swing'],
        leverage=10,
        initial_balance=10000.0,
        max_price_deviation=0.003,  # 0.3%: moderate tolerance
    ),
    TimeFrame.LONG: dict(
        name='长周期',
        name_en='Long-term',
        signal_keys=['long_term_1d_1w'],
        leverage=10,
        initial_balance=10000.0,
        max_price_deviation=0.005,  # 0.5%: long-term trades follow the broad trend
    ),
}
# Pyramid sizing: fraction of the maximum position value (initial_balance * leverage)
# committed at each entry level; the fractions sum to 100%.
PYRAMID_LEVELS = [0.4, 0.3, 0.2, 0.1]  # initial entry 40%, then add-ons of 30%, 20%, 10%
@dataclass
class PositionEntry:
    """One fill contributing to a position (a single pyramid level)."""
    price: float  # fill price
    size: float  # quantity in base-asset units (e.g. BTC for BTCUSDT)
    margin: float  # margin committed by this fill
    timestamp: str  # ISO-8601 time of the fill
    level: int  # pyramid level: 0 = initial entry, 1 = first add-on, ...

    def to_dict(self) -> dict:
        """Serialize to a plain dict with one key per field."""
        serialized = asdict(self)
        return serialized

    @classmethod
    def from_dict(cls, data: dict) -> 'PositionEntry':
        """Inverse of ``to_dict``; ``data`` must contain exactly the dataclass fields."""
        field_values = dict(data)
        return cls(**field_values)
@dataclass
class Position:
    """An open position, possibly built from several pyramid entries.

    Aggregate figures (average entry price, total size, total margin, pyramid
    level) are derived from ``entries`` on demand rather than stored.
    """
    side: str  # 'LONG', 'SHORT' or 'FLAT'
    entries: List['PositionEntry'] = field(default_factory=list)  # one record per fill
    stop_loss: float = 0.0
    take_profit: float = 0.0
    created_at: str = ""
    signal_reasoning: str = ""

    @property
    def entry_price(self) -> float:
        """Size-weighted average entry price; 0.0 without entries or size."""
        if not self.entries:
            return 0.0
        notional = 0
        quantity = 0
        for entry in self.entries:
            notional += entry.price * entry.size
            quantity += entry.size
        if quantity > 0:
            return notional / quantity
        return 0.0

    @property
    def size(self) -> float:
        """Total quantity across all entries."""
        return sum(entry.size for entry in self.entries)

    @property
    def margin(self) -> float:
        """Total margin locked across all entries."""
        return sum(entry.margin for entry in self.entries)

    @property
    def pyramid_level(self) -> int:
        """Number of entries so far (1 after the initial entry)."""
        return len(self.entries)

    def to_dict(self) -> dict:
        """Serialize, including the derived aggregate figures."""
        return dict(
            side=self.side,
            entry_price=self.entry_price,
            size=self.size,
            margin=self.margin,
            pyramid_level=self.pyramid_level,
            entries=[entry.to_dict() for entry in self.entries],
            stop_loss=self.stop_loss,
            take_profit=self.take_profit,
            created_at=self.created_at,
            signal_reasoning=self.signal_reasoning,
        )

    @classmethod
    def from_dict(cls, data: dict) -> 'Position':
        """Rebuild a position saved by ``to_dict``.

        Older state files stored a flat ``entry_price``/``size``/``margin``
        instead of ``entries``; such data becomes a single level-0 entry.
        """
        restored = [PositionEntry.from_dict(raw) for raw in data.get('entries', [])]
        is_legacy = not restored and data.get('entry_price') and data.get('size')
        if is_legacy:
            restored = [
                PositionEntry(
                    price=data['entry_price'],
                    size=data['size'],
                    margin=data.get('margin', 0),
                    timestamp=data.get('created_at', ''),
                    level=0,
                )
            ]
        return cls(
            side=data['side'],
            entries=restored,
            stop_loss=data.get('stop_loss', 0),
            take_profit=data.get('take_profit', 0),
            created_at=data.get('created_at', ''),
            signal_reasoning=data.get('signal_reasoning', ''),
        )
@dataclass
class Trade:
    """A closed trade as recorded in an account's history."""
    id: str
    timeframe: str  # TimeFrame.value of the owning account
    side: str
    entry_price: float  # size-weighted average entry price of the position
    entry_time: str
    exit_price: float
    exit_time: str
    size: float
    pnl: float
    pnl_pct: float  # PnL as a percentage of the position margin
    exit_reason: str
    symbol: str = "BTCUSDT"  # trading pair

    def to_dict(self) -> dict:
        """Serialize to a plain dict with one key per field."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict) -> 'Trade':
        """Rebuild a trade saved by ``to_dict``.

        Records written before multi-symbol support have no ``symbol``; they
        default to 'BTCUSDT'. The input mapping is left unmodified.
        """
        # Merge into a new dict instead of writing the default into ``data``.
        return cls(**{'symbol': 'BTCUSDT', **data})
@dataclass
class TimeFrameAccount:
    """Account for one symbol on one timeframe.

    Capital model:
    - ``initial_balance``: starting capital.
    - ``realized_pnl``: cumulative PnL of closed trades.
    - ``position.margin``: margin locked by the open position.
    - unrealized PnL: not stored; passed in by the caller when needed.

    equity            = initial_balance + realized_pnl + unrealized_pnl
    available balance = initial_balance + realized_pnl - position.margin
    """
    timeframe: str  # TimeFrame.value, e.g. 'short'
    initial_balance: float
    leverage: int
    symbol: str = "BTCUSDT"  # trading pair
    realized_pnl: float = 0.0  # cumulative realized PnL
    # String annotations so the class does not require Position/Trade to be
    # defined first (same style as Position.entries).
    position: Optional['Position'] = None
    trades: List['Trade'] = field(default_factory=list)
    stats: Dict = field(default_factory=dict)
    equity_curve: List[Dict] = field(default_factory=list)

    def __post_init__(self):
        if not self.stats:
            self.stats = self._init_stats()
        else:
            # Older or partial state files may lack some counters; fill them in
            # (keeping existing values) so later ``stats[key] += ...`` cannot
            # raise KeyError.
            for key, default in self._init_stats().items():
                self.stats.setdefault(key, default)

    def _init_stats(self) -> dict:
        """Return a fresh statistics dict for an account with no trades."""
        return {
            'total_trades': 0,
            'winning_trades': 0,
            'losing_trades': 0,
            'total_pnl': 0.0,
            'max_drawdown': 0.0,
            'peak_balance': self.initial_balance,
            'win_rate': 0.0,
            'avg_win': 0.0,
            'avg_loss': 0.0,
            'profit_factor': 0.0,
        }

    def get_used_margin(self) -> float:
        """Margin locked by the open position (0.0 when flat or no position)."""
        if self.position and self.position.side != 'FLAT':
            return self.position.margin
        return 0.0

    def get_available_balance(self) -> float:
        """Balance available for new entries: capital + realized PnL - locked margin."""
        return self.initial_balance + self.realized_pnl - self.get_used_margin()

    def get_equity(self, unrealized_pnl: float = 0.0) -> float:
        """Account equity including the supplied unrealized PnL."""
        return self.initial_balance + self.realized_pnl + unrealized_pnl

    def to_dict(self) -> dict:
        """Serialize; only the last 100 trades and 500 equity points are kept."""
        return {
            'timeframe': self.timeframe,
            'symbol': self.symbol,
            'initial_balance': self.initial_balance,
            'realized_pnl': self.realized_pnl,
            'leverage': self.leverage,
            'position': self.position.to_dict() if self.position else None,
            'trades': [t.to_dict() for t in self.trades[-100:]],
            'stats': self.stats,
            'equity_curve': self.equity_curve[-500:],
        }

    @classmethod
    def from_dict(cls, data: dict) -> 'TimeFrameAccount':
        """Rebuild an account saved by ``to_dict``.

        Older formats are accepted: a missing ``symbol`` defaults to 'BTCUSDT',
        and a legacy ``balance`` field (without ``realized_pnl``) is converted
        to ``realized_pnl = balance - initial_balance``.
        """
        realized_pnl = data.get('realized_pnl', 0.0)
        if 'realized_pnl' not in data and 'balance' in data:
            realized_pnl = data['balance'] - data['initial_balance']
        account = cls(
            timeframe=data['timeframe'],
            initial_balance=data['initial_balance'],
            leverage=data['leverage'],
            symbol=data.get('symbol', 'BTCUSDT'),
            realized_pnl=realized_pnl,
            stats=data.get('stats', {}),
            equity_curve=data.get('equity_curve', []),
        )
        if data.get('position'):
            account.position = Position.from_dict(data['position'])
        account.trades = [Trade.from_dict(t) for t in data.get('trades', [])]
        return account
class MultiTimeframePaperTrader:
    """Paper trader with independent accounts per symbol and per timeframe.

    ``self.accounts`` maps ``symbol -> {TimeFrame -> TimeFrameAccount}``. Each
    account has its own capital, position (with pyramid entries), stop-loss /
    take-profit, statistics and equity curve. State is written to a JSON file
    after every processed signal and after ``reset``.
    """

    # In-memory cap on equity-curve points per account. TimeFrameAccount.to_dict
    # persists only the last 500 points, so older points were never saved.
    _EQUITY_CURVE_MAX_POINTS = 500

    def __init__(
        self,
        initial_balance: float = 10000.0,
        state_file: str = None,
        symbols: List[str] = None
    ):
        """Create the trader and load (or initialize) persisted state.

        Args:
            initial_balance: stored on the instance; each account's capital is
                taken from ``TIMEFRAME_CONFIG[tf]['initial_balance']``.
            state_file: JSON state path. Defaults to
                ``<grandparent of this file>/output/paper_trading_state.json``.
            symbols: trading pairs such as 'BTCUSDT'. Defaults to
                ``settings.symbols_list``.
        """
        self.initial_balance = initial_balance
        self.symbols = symbols or settings.symbols_list
        logger.info(f"支持的交易对: {', '.join(self.symbols)}")
        if state_file:
            self.state_file = Path(state_file)
        else:
            self.state_file = Path(__file__).parent.parent / 'output' / 'paper_trading_state.json'
        # {symbol: {TimeFrame: TimeFrameAccount}}
        self.accounts: Dict[str, Dict[TimeFrame, TimeFrameAccount]] = {}
        self._load_state()
        logger.info(f"Multi-symbol Multi-timeframe Paper Trader initialized: {len(self.symbols)} symbols")

    # ------------------------------------------------------------------ state

    def _load_state(self):
        """Load accounts from ``self.state_file``, or start with fresh accounts.

        Two layouts are understood:
        - current: ``{'symbols': {SYMBOL: {tf.value: account_dict}}, ...}``
        - legacy:  ``{'accounts': {tf.value: account_dict}}`` (single symbol),
          migrated onto the first configured symbol.

        Every configured symbol ends up with an account for every TimeFrame.
        If the file cannot be read or parsed, it is moved aside to
        ``<name>.corrupt`` (best effort) before fresh accounts are created,
        because the next save would otherwise overwrite it.
        """
        if not self.state_file.exists():
            self._init_all_accounts()
            return
        try:
            state = self._read_state_file()
            if 'symbols' in state:
                self._load_symbols_state(state.get('symbols') or {})
            else:
                self._load_legacy_state(state.get('accounts') or {})
            logger.info(f"Loaded state from {self.state_file}")
        except Exception as e:
            logger.error(f"Failed to load state: {e}")
            self._backup_corrupt_state()
            # Discard anything loaded before the failure.
            self.accounts = {}
            self._init_all_accounts()

    def _read_state_file(self) -> dict:
        """Parse the state file as UTF-8, falling back to the platform encoding."""
        try:
            with open(self.state_file, 'r', encoding='utf-8') as f:
                return json.load(f)
        except UnicodeDecodeError:
            # Older versions wrote with the platform default encoding.
            with open(self.state_file, 'r') as f:
                return json.load(f)

    def _load_symbols_state(self, symbols_state: dict):
        """Load the current multi-symbol layout.

        Symbols present in the file but not configured are loaded as well, so
        the next save does not silently discard their history.
        """
        extra_symbols = [s for s in symbols_state if s not in self.symbols]
        for symbol in list(self.symbols) + extra_symbols:
            symbol_data = symbols_state.get(symbol) or {}
            self.accounts[symbol] = {}
            for tf in TimeFrame:
                tf_data = symbol_data.get(tf.value)
                if tf_data:
                    self.accounts[symbol][tf] = TimeFrameAccount.from_dict(tf_data)
                else:
                    self._init_account(symbol, tf)

    def _load_legacy_state(self, legacy_accounts: dict):
        """Migrate the single-symbol legacy layout onto the first configured symbol."""
        first_symbol = self.symbols[0] if self.symbols else 'BTCUSDT'
        self.accounts[first_symbol] = {}
        for tf in TimeFrame:
            tf_data = legacy_accounts.get(tf.value)
            if tf_data:
                tf_data['symbol'] = first_symbol
                self.accounts[first_symbol][tf] = TimeFrameAccount.from_dict(tf_data)
            else:
                self._init_account(first_symbol, tf)
        for symbol in self.symbols[1:]:
            self._init_symbol_accounts(symbol)

    def _backup_corrupt_state(self):
        """Move an unreadable state file aside so it is not overwritten (best effort)."""
        backup = self.state_file.with_name(self.state_file.name + '.corrupt')
        try:
            self.state_file.replace(backup)
            logger.warning(f"Unreadable state file moved to {backup}")
        except OSError as e:
            logger.error(f"Failed to back up unreadable state file: {e}")

    def _init_all_accounts(self):
        """Create fresh accounts for every configured symbol and timeframe."""
        for symbol in self.symbols:
            self._init_symbol_accounts(symbol)

    def _init_symbol_accounts(self, symbol: str):
        """Create (or replace) fresh accounts for every timeframe of ``symbol``."""
        self.accounts[symbol] = {}
        for tf in TimeFrame:
            self._init_account(symbol, tf)

    def _init_account(self, symbol: str, tf: TimeFrame):
        """Create a fresh account for one symbol/timeframe from TIMEFRAME_CONFIG.

        With the shipped config this is 10000 USD of capital at 10x leverage,
        i.e. a maximum position value of 100000 USD.
        """
        config = TIMEFRAME_CONFIG[tf]
        self.accounts[symbol][tf] = TimeFrameAccount(
            timeframe=tf.value,
            initial_balance=config['initial_balance'],
            leverage=config['leverage'],
            symbol=symbol,
            realized_pnl=0.0,
        )

    def _save_state(self):
        """Write all accounts to ``self.state_file`` as UTF-8 JSON.

        The payload holds the multi-symbol layout under ``'symbols'`` plus the
        first symbol's accounts under the legacy ``'accounts'`` key. Data is
        written to a sibling ``.tmp`` file and then renamed over the state file,
        so an interrupted write leaves the previous file intact.
        """
        self.state_file.parent.mkdir(parents=True, exist_ok=True)
        symbols_data = {
            symbol: {tf.value: acc.to_dict() for tf, acc in tf_accounts.items()}
            for symbol, tf_accounts in self.accounts.items()
        }
        first_symbol = self.symbols[0] if self.symbols else 'BTCUSDT'
        state = {
            'symbols': symbols_data,
            'accounts': symbols_data.get(first_symbol, {}),  # backward compatibility
            'last_updated': datetime.now().isoformat(),
        }
        tmp_file = self.state_file.with_name(self.state_file.name + '.tmp')
        with open(tmp_file, 'w', encoding='utf-8') as f:
            json.dump(state, f, indent=2, ensure_ascii=False)
        tmp_file.replace(self.state_file)

    # ---------------------------------------------------------------- signals

    def process_signal(
        self,
        signal: Dict[str, Any],
        current_price: float,
        symbol: str = None
    ) -> Dict[str, Any]:
        """Process one symbol's signal across every timeframe, then save state.

        Args:
            signal: the symbol's signal payload.
            current_price: the symbol's current price.
            symbol: trading pair such as 'BTCUSDT'; defaults to the first
                configured symbol. Unknown symbols get fresh accounts.

        Returns:
            ``{'timestamp', 'symbol', 'current_price', 'timeframes': {tf.value: result}}``.
        """
        symbol = symbol or (self.symbols[0] if self.symbols else 'BTCUSDT')
        if symbol not in self.accounts:
            self._init_symbol_accounts(symbol)
        results = {
            'timestamp': datetime.now().isoformat(),
            'symbol': symbol,
            'current_price': current_price,
            'timeframes': {},
        }
        for tf in TimeFrame:
            result = self._process_timeframe_signal(symbol, tf, signal, current_price)
            results['timeframes'][tf.value] = result
        self._save_state()
        return results

    def process_all_signals(
        self,
        signals: Dict[str, Dict[str, Any]],
        prices: Dict[str, float]
    ) -> Dict[str, Any]:
        """Process signals for every configured symbol that has both a signal and a price.

        Args:
            signals: ``{symbol: signal_data}``.
            prices: ``{symbol: current_price}``.
        """
        results = {
            'timestamp': datetime.now().isoformat(),
            'symbols': {},
        }
        for symbol in self.symbols:
            if symbol in signals and symbol in prices:
                result = self.process_signal(
                    signal=signals[symbol],
                    current_price=prices[symbol],
                    symbol=symbol
                )
                results['symbols'][symbol] = result
        return results

    def _process_timeframe_signal(
        self, symbol: str, tf: TimeFrame, signal: Dict[str, Any], current_price: float
    ) -> Dict[str, Any]:
        """Run one timeframe of one symbol through the trading rules.

        Order of operations:
        1. record an equity-curve point;
        2. close the open position if its stop-loss / take-profit is hit;
        3. extract and validate this timeframe's signal. The direction must be
           LONG or SHORT. Stop-loss and take-profit must be positive and lie on
           the correct side of the current price. The current price must be
           within ``max_price_deviation`` of the suggested entry price;
        4. with an open position, close it on an opposite signal (without
           opening the reverse side); otherwise try a pyramid add-on;
        5. without a position, open the first pyramid level.

        Returns:
            ``{'action': ..., 'details': ...}``. The action is one of 'CLOSE',
            'NO_SIGNAL', 'PRICE_DEVIATION', 'ADD', 'HOLD', 'OPEN' or 'WAIT'.
        """
        account = self.accounts[symbol][tf]
        config = TIMEFRAME_CONFIG[tf]
        result = {
            'action': 'NONE',
            'details': None,
        }

        self._update_equity_curve(symbol, tf, current_price)

        # 1. Protective exits take priority over any new signal.
        if account.position and account.position.side != 'FLAT':
            close_result = self._check_close_position(symbol, tf, current_price)
            if close_result:
                result['action'] = 'CLOSE'
                result['details'] = close_result
                return result

        # 2. This timeframe's opportunity.
        tf_signal = self._extract_timeframe_signal(signal, config['signal_keys'])
        if not isinstance(tf_signal, dict) or not tf_signal.get('exists'):
            result['action'] = 'NO_SIGNAL'
            return result

        direction = tf_signal.get('direction')
        if not direction:
            result['action'] = 'NO_SIGNAL'
            return result
        # Any other side would be opened and then treated as SHORT by the
        # exit and PnL logic.
        direction = str(direction).upper()
        if direction not in ('LONG', 'SHORT'):
            result['action'] = 'NO_SIGNAL'
            result['details'] = {'reason': f'无效的信号方向: {direction}'}
            return result

        # Signal fields may be null or numeric strings; coerce before comparing.
        signal_stop_loss = self._to_price(tf_signal.get('stop_loss'))
        signal_take_profit = self._to_price(tf_signal.get('take_profit'))
        signal_entry_price = self._to_price(tf_signal.get('entry_price'))
        reasoning = str(tf_signal.get('reasoning') or '')[:100]

        if signal_stop_loss <= 0 or signal_take_profit <= 0:
            result['action'] = 'NO_SIGNAL'
            result['details'] = {'reason': '缺少有效止盈止损'}
            return result

        # Levels on the wrong side of the price would fire on the next check.
        if not self._levels_match_direction(
            direction, current_price, signal_stop_loss, signal_take_profit
        ):
            result['action'] = 'NO_SIGNAL'
            result['details'] = {
                'reason': '止盈止损与信号方向不符',
                'direction': direction,
                'stop_loss': signal_stop_loss,
                'take_profit': signal_take_profit,
                'current_price': current_price,
            }
            return result

        # Skip entries when the price has drifted too far from the suggested entry.
        max_deviation = config.get('max_price_deviation', 0.002)
        if signal_entry_price > 0:
            price_deviation = abs(current_price - signal_entry_price) / signal_entry_price
            if price_deviation > max_deviation:
                result['action'] = 'PRICE_DEVIATION'
                result['details'] = {
                    'reason': f'价格偏差过大: {price_deviation*100:.2f}% > {max_deviation*100:.1f}%',
                    'signal_entry': signal_entry_price,
                    'current_price': current_price,
                    'deviation_pct': price_deviation * 100,
                    'max_deviation_pct': max_deviation * 100,
                }
                logger.info(
                    f"[{config['name']}] 跳过开仓: 价格偏差 {price_deviation*100:.2f}% > {max_deviation*100:.1f}% "
                    f"(信号价: ${signal_entry_price:.2f}, 当前价: ${current_price:.2f})"
                )
                return result

        # 3. With an open position: reverse-close or pyramid add-on.
        position = account.position
        if position and position.side != 'FLAT':
            is_reverse = (
                (position.side == 'LONG' and direction == 'SHORT')
                or (position.side == 'SHORT' and direction == 'LONG')
            )
            if is_reverse:
                close_result = self._close_position(symbol, tf, current_price, 'SIGNAL_REVERSE')
                result['action'] = 'CLOSE'
                result['details'] = close_result
                logger.info(
                    f"[{symbol}][{config['name']}] 反向信号平仓,等待下一周期新信号"
                )
                return result

            add_result = self._add_position(
                symbol, tf, current_price,
                signal_stop_loss, signal_take_profit, reasoning
            )
            if add_result:
                result['action'] = 'ADD'
                result['details'] = add_result
            else:
                at_max_level = position.pyramid_level >= len(PYRAMID_LEVELS)
                result['action'] = 'HOLD'
                result['details'] = {
                    'position': position.to_dict(),
                    'unrealized_pnl': self._calc_unrealized_pnl(symbol, tf, current_price),
                    'reason': '已达最大仓位层级' if at_max_level else '可用余额不足,无法加仓',
                }
            return result

        # 4. No position: open the first pyramid level.
        open_result = self._open_position(
            symbol, tf, direction, current_price,
            signal_stop_loss, signal_take_profit, reasoning
        )
        if open_result:
            result['action'] = 'OPEN'
            result['details'] = open_result
        else:
            result['action'] = 'WAIT'
        return result

    @staticmethod
    def _to_price(value: Any) -> float:
        """Coerce a signal price field to float; missing or unparsable values become 0.0."""
        if value is None:
            return 0.0
        try:
            return float(value)
        except (TypeError, ValueError):
            return 0.0

    @staticmethod
    def _levels_match_direction(
        direction: str, price: float, stop_loss: float, take_profit: float
    ) -> bool:
        """True when stop-loss and take-profit bracket ``price`` for ``direction``.

        LONG requires ``stop_loss < price < take_profit``; SHORT requires
        ``take_profit < price < stop_loss``.
        """
        if direction == 'LONG':
            return stop_loss < price < take_profit
        return take_profit < price < stop_loss

    def _extract_timeframe_signal(
        self, signal: Dict[str, Any], signal_keys: List[str]
    ) -> Optional[Dict[str, Any]]:
        """Find this timeframe's opportunity in the signal payload.

        The function searches ``signal['llm_signal']['opportunities']`` first. It
        then searches ``signal['aggregated_signal']['llm_signal']['opportunities']``.
        Within each, ``signal_keys`` are tried in order and the first truthy
        value is returned. Returns None if nothing matches or the payload is
        malformed.
        """
        try:
            aggregated = signal.get('aggregated_signal') or {}
            candidates = [signal.get('llm_signal'), aggregated.get('llm_signal')]
            for llm_signal in candidates:
                if not isinstance(llm_signal, dict):
                    continue
                opportunities = llm_signal.get('opportunities') or {}
                for key in signal_keys:
                    if opportunities.get(key):
                        return opportunities[key]
            return None
        except Exception as e:
            logger.error(f"Error extracting signal: {e}")
            return None

    # -------------------------------------------------------------- positions

    def _get_max_position_value(self, symbol: str, tf: TimeFrame) -> float:
        """Maximum position value: initial capital x leverage."""
        account = self.accounts[symbol][tf]
        return account.initial_balance * account.leverage

    def _get_current_position_value(self, symbol: str, tf: TimeFrame, current_price: float) -> float:
        """Value of the open position at ``current_price`` (0.0 when flat)."""
        account = self.accounts[symbol][tf]
        if not account.position or account.position.side == 'FLAT':
            return 0.0
        return account.position.size * current_price

    def _open_position(
        self, symbol: str, tf: TimeFrame, direction: str, price: float,
        stop_loss: float, take_profit: float, reasoning: str
    ) -> Optional[Dict]:
        """Open the first pyramid level (``PYRAMID_LEVELS[0]`` of the max position value).

        Returns the opening details. Returns None when ``price`` is not positive,
        when the available balance cannot cover the margin, or when the size
        would be non-positive.
        """
        account = self.accounts[symbol][tf]
        config = TIMEFRAME_CONFIG[tf]
        if price <= 0:
            return None

        max_position_value = self._get_max_position_value(symbol, tf)
        first_level_ratio = PYRAMID_LEVELS[0]
        position_value = max_position_value * first_level_ratio
        margin = position_value / account.leverage
        size = position_value / price

        available_balance = account.get_available_balance()
        if available_balance < margin:
            logger.warning(f"[{symbol}][{config['name']}] 可用余额不足: ${available_balance:.2f} < ${margin:.2f}")
            return None
        if size <= 0:
            return None

        now = datetime.now().isoformat()
        entry = PositionEntry(
            price=price,
            size=size,
            margin=margin,
            timestamp=now,
            level=0,
        )
        account.position = Position(
            side=direction,
            entries=[entry],
            stop_loss=stop_loss,
            take_profit=take_profit,
            created_at=now,
            signal_reasoning=reasoning,
        )

        unit = symbol.replace('USDT', '') if symbol.endswith('USDT') else symbol
        logger.info(
            f"[{symbol}][{config['name']}] OPEN {direction} [L1/{len(PYRAMID_LEVELS)}]: price=${price:.2f}, "
            f"size={size:.6f} {unit}, margin=${margin:.2f}, value=${position_value:.2f}, "
            f"SL=${stop_loss:.2f}, TP=${take_profit:.2f}"
        )
        return {
            'symbol': symbol,
            'timeframe': tf.value,
            'side': direction,
            'entry_price': price,
            'size': size,
            'margin': margin,
            'position_value': position_value,
            'pyramid_level': 1,
            'max_levels': len(PYRAMID_LEVELS),
            'stop_loss': stop_loss,
            'take_profit': take_profit,
        }

    def _add_position(
        self, symbol: str, tf: TimeFrame, price: float,
        stop_loss: float, take_profit: float, reasoning: str
    ) -> Optional[Dict]:
        """Add the next pyramid level to the open position.

        The position's stop-loss / take-profit are replaced by the new values.
        ``reasoning`` is accepted for interface symmetry but not stored.
        Returns None when there is no position, all levels are used, ``price``
        is not positive, or the balance cannot cover the margin.
        """
        account = self.accounts[symbol][tf]
        config = TIMEFRAME_CONFIG[tf]
        pos = account.position
        if not pos or pos.side == 'FLAT' or price <= 0:
            return None

        current_level = pos.pyramid_level
        if current_level >= len(PYRAMID_LEVELS):
            logger.info(f"[{symbol}][{config['name']}] 已达最大仓位层级 {current_level}/{len(PYRAMID_LEVELS)}")
            return None

        max_position_value = self._get_max_position_value(symbol, tf)
        level_ratio = PYRAMID_LEVELS[current_level]
        add_position_value = max_position_value * level_ratio
        add_margin = add_position_value / account.leverage
        add_size = add_position_value / price

        available_balance = account.get_available_balance()
        if available_balance < add_margin:
            logger.warning(
                f"[{symbol}][{config['name']}] 加仓余额不足: ${available_balance:.2f} < ${add_margin:.2f}"
            )
            return None

        pos.entries.append(PositionEntry(
            price=price,
            size=add_size,
            margin=add_margin,
            timestamp=datetime.now().isoformat(),
            level=current_level,
        ))
        pos.stop_loss = stop_loss
        pos.take_profit = take_profit

        unit = symbol.replace('USDT', '') if symbol.endswith('USDT') else symbol
        new_level = pos.pyramid_level
        logger.info(
            f"[{symbol}][{config['name']}] ADD {pos.side} [L{new_level}/{len(PYRAMID_LEVELS)}]: price=${price:.2f}, "
            f"add_size={add_size:.6f} {unit}, add_margin=${add_margin:.2f}, "
            f"total_size={pos.size:.6f} {unit}, total_margin=${pos.margin:.2f}, "
            f"avg_price=${pos.entry_price:.2f}"
        )
        return {
            'symbol': symbol,
            'timeframe': tf.value,
            'side': pos.side,
            'add_price': price,
            'add_size': add_size,
            'add_margin': add_margin,
            'add_position_value': add_position_value,
            'total_size': pos.size,
            'total_margin': pos.margin,
            'avg_entry_price': pos.entry_price,
            'pyramid_level': new_level,
            'max_levels': len(PYRAMID_LEVELS),
            'stop_loss': stop_loss,
            'take_profit': take_profit,
        }

    def _check_close_position(self, symbol: str, tf: TimeFrame, current_price: float) -> Optional[Dict]:
        """Close the position if its take-profit or stop-loss is reached.

        Levels <= 0 (e.g. positions restored from old state that lacked them)
        are ignored instead of triggering an immediate exit.
        """
        account = self.accounts[symbol][tf]
        pos = account.position
        if not pos or pos.side == 'FLAT':
            return None
        take_profit = pos.take_profit or 0
        stop_loss = pos.stop_loss or 0
        if pos.side == 'LONG':
            if take_profit > 0 and current_price >= take_profit:
                return self._close_position(symbol, tf, current_price, 'TAKE_PROFIT')
            if stop_loss > 0 and current_price <= stop_loss:
                return self._close_position(symbol, tf, current_price, 'STOP_LOSS')
        else:  # SHORT
            if take_profit > 0 and current_price <= take_profit:
                return self._close_position(symbol, tf, current_price, 'TAKE_PROFIT')
            if stop_loss > 0 and current_price >= stop_loss:
                return self._close_position(symbol, tf, current_price, 'STOP_LOSS')
        return None

    def _close_position(self, symbol: str, tf: TimeFrame, price: float, reason: str) -> Dict:
        """Close the whole position at ``price`` and record the trade.

        PnL is measured against the size-weighted average entry price, and
        ``pnl_pct`` is relative to the locked margin. Realized PnL is added to
        the account; the margin is released by clearing the position.
        """
        account = self.accounts[symbol][tf]
        config = TIMEFRAME_CONFIG[tf]
        pos = account.position
        if not pos or pos.side == 'FLAT':
            return {'error': 'No position'}

        if pos.side == 'LONG':
            pnl = (price - pos.entry_price) * pos.size
        else:
            pnl = (pos.entry_price - price) * pos.size
        pnl_pct = (pnl / pos.margin * 100) if pos.margin > 0 else 0

        account.realized_pnl += pnl

        # The persisted trade list is truncated to 100 entries, so the trade
        # count in stats is the reliable sequence number across reloads.
        sequence = max(account.stats.get('total_trades', 0), len(account.trades)) + 1
        trade = Trade(
            id=f"{symbol[0]}{tf.value[0].upper()}{sequence:04d}",
            timeframe=tf.value,
            side=pos.side,
            entry_price=pos.entry_price,
            entry_time=pos.created_at,
            exit_price=price,
            exit_time=datetime.now().isoformat(),
            size=pos.size,
            pnl=pnl,
            pnl_pct=pnl_pct,
            exit_reason=reason,
            symbol=symbol,
        )
        account.trades.append(trade)
        self._update_stats(symbol, tf, trade)

        new_equity = account.get_equity()
        result = {
            'symbol': symbol,
            'timeframe': tf.value,
            'side': pos.side,
            'entry_price': pos.entry_price,
            'exit_price': price,
            'size': pos.size,
            'margin': pos.margin,
            'pnl': pnl,
            'pnl_pct': pnl_pct,
            'reason': reason,
            'new_equity': new_equity,
            'realized_pnl': account.realized_pnl,
        }
        logger.info(
            f"[{symbol}][{config['name']}] CLOSE {pos.side}: entry=${pos.entry_price:.2f}, "
            f"exit=${price:.2f}, PnL=${pnl:.2f} ({pnl_pct:.2f}%), reason={reason}, "
            f"equity=${new_equity:.2f}"
        )
        account.position = None
        return result

    def _calc_unrealized_pnl(self, symbol: str, tf: TimeFrame, current_price: float) -> Dict[str, float]:
        """Unrealized PnL at ``current_price`` as ``{'pnl', 'pnl_pct'}`` (pct of margin)."""
        account = self.accounts[symbol][tf]
        pos = account.position
        if not pos or pos.side == 'FLAT':
            return {'pnl': 0, 'pnl_pct': 0}
        if pos.side == 'LONG':
            pnl = (current_price - pos.entry_price) * pos.size
        else:
            pnl = (pos.entry_price - current_price) * pos.size
        pnl_pct = (pnl / pos.margin * 100) if pos.margin > 0 else 0
        return {'pnl': pnl, 'pnl_pct': pnl_pct}

    def _update_equity_curve(self, symbol: str, tf: TimeFrame, current_price: float):
        """Append an equity point, keeping at most ``_EQUITY_CURVE_MAX_POINTS`` in memory."""
        account = self.accounts[symbol][tf]
        unrealized = self._calc_unrealized_pnl(symbol, tf, current_price)
        equity = account.get_equity(unrealized['pnl'])
        curve = account.equity_curve
        curve.append({
            'timestamp': datetime.now().isoformat(),
            'equity': equity,
            'initial_balance': account.initial_balance,
            'realized_pnl': account.realized_pnl,
            'unrealized_pnl': unrealized['pnl'],
            'price': current_price,
        })
        overflow = len(curve) - self._EQUITY_CURVE_MAX_POINTS
        if overflow > 0:
            del curve[:overflow]

    def _update_stats(self, symbol: str, tf: TimeFrame, trade: Trade):
        """Fold a newly closed trade into the account's statistics.

        Averages and the profit factor use running totals (``gross_profit`` and
        ``gross_loss``) kept in ``stats``. They are not recomputed from
        ``account.trades``, because only the latest 100 trades are persisted.
        ``profit_factor`` is gross profit / |gross loss|. Break-even trades
        count as losses. Drawdown is measured on realized equity at close time.
        """
        account = self.accounts[symbol][tf]
        stats = account.stats
        # Tolerate partial stats dicts from older state files.
        for key, default in account._init_stats().items():
            stats.setdefault(key, default)
        # Seed running totals for state saved before they existed. This must
        # happen before the win/loss counters are incremented below.
        stats.setdefault('gross_profit', stats['avg_win'] * stats['winning_trades'])
        stats.setdefault('gross_loss', stats['avg_loss'] * stats['losing_trades'])

        stats['total_trades'] += 1
        stats['total_pnl'] += trade.pnl
        if trade.pnl > 0:
            stats['winning_trades'] += 1
            stats['gross_profit'] += trade.pnl
        else:
            stats['losing_trades'] += 1
            stats['gross_loss'] += trade.pnl

        stats['win_rate'] = stats['winning_trades'] / stats['total_trades'] * 100
        if stats['winning_trades']:
            stats['avg_win'] = stats['gross_profit'] / stats['winning_trades']
        if stats['losing_trades']:
            stats['avg_loss'] = stats['gross_loss'] / stats['losing_trades']
        if stats['gross_loss'] != 0:
            stats['profit_factor'] = stats['gross_profit'] / abs(stats['gross_loss'])

        equity = account.get_equity()
        if equity > stats['peak_balance']:
            stats['peak_balance'] = equity
        peak = stats['peak_balance']
        if peak > 0:
            drawdown = (peak - equity) / peak * 100
            if drawdown > stats['max_drawdown']:
                stats['max_drawdown'] = drawdown

    # ----------------------------------------------------------------- status

    def get_status(
        self,
        current_price: float = None,
        symbol: str = None,
        prices: Dict[str, float] = None
    ) -> Dict[str, Any]:
        """Return account status.

        Args:
            current_price: single-symbol price (backward compatibility). With
                ``symbol`` it takes precedence over ``prices``. Without
                ``symbol`` it applies to the first configured symbol only when
                ``prices`` is empty.
            symbol: return only this symbol's status.
            prices: ``{symbol: price}`` used for unrealized PnL.
        """
        if symbol:
            price = current_price or (prices.get(symbol) if prices else None)
            return self._get_symbol_status(symbol, price)
        if prices:
            all_prices = prices
        elif current_price and self.symbols:
            all_prices = {self.symbols[0]: current_price}
        else:
            all_prices = {}
        return self._get_all_status(all_prices)

    def _get_symbol_status(self, symbol: str, current_price: float = None) -> Dict[str, Any]:
        """Per-timeframe status and totals for one symbol.

        Unrealized PnL is included only when ``current_price`` is truthy.
        """
        if symbol not in self.accounts:
            return {'error': f'Symbol {symbol} not found'}

        total_equity = 0
        total_initial = 0
        total_realized_pnl = 0
        total_unrealized_pnl = 0
        timeframes_data = {}

        for tf in TimeFrame:
            account = self.accounts[symbol][tf]
            config = TIMEFRAME_CONFIG[tf]
            unrealized = self._calc_unrealized_pnl(symbol, tf, current_price) if current_price else {'pnl': 0, 'pnl_pct': 0}
            equity = account.get_equity(unrealized['pnl'])

            total_initial += account.initial_balance
            total_realized_pnl += account.realized_pnl
            total_unrealized_pnl += unrealized['pnl']
            total_equity += equity

            return_pct = (equity - account.initial_balance) / account.initial_balance * 100 if account.initial_balance > 0 else 0
            tf_status = {
                'name': config['name'],
                'name_en': config['name_en'],
                'symbol': symbol,
                'initial_balance': account.initial_balance,
                'realized_pnl': account.realized_pnl,
                'unrealized_pnl': unrealized['pnl'],
                'equity': equity,
                'available_balance': account.get_available_balance(),
                'used_margin': account.get_used_margin(),
                'return_pct': return_pct,
                'leverage': account.leverage,
                'position': None,
                'stats': account.stats,
                'recent_trades': [t.to_dict() for t in account.trades[-10:]],
                'equity_curve': account.equity_curve[-100:],
            }
            if account.position and account.position.side != 'FLAT':
                pos_dict = account.position.to_dict()
                if current_price:
                    pos_dict['current_price'] = current_price
                    pos_dict['unrealized_pnl'] = unrealized['pnl']
                    pos_dict['unrealized_pnl_pct'] = unrealized['pnl_pct']
                tf_status['position'] = pos_dict
            timeframes_data[tf.value] = tf_status

        total_return = (total_equity - total_initial) / total_initial * 100 if total_initial > 0 else 0
        return {
            'timestamp': datetime.now().isoformat(),
            'symbol': symbol,
            'total_initial_balance': total_initial,
            'total_realized_pnl': total_realized_pnl,
            'total_unrealized_pnl': total_unrealized_pnl,
            'total_equity': total_equity,
            'total_return': total_return,
            'timeframes': timeframes_data,
        }

    def _get_all_status(self, prices: Dict[str, float] = None) -> Dict[str, Any]:
        """Status of every configured symbol plus grand totals.

        ``timeframes`` and the ``total_*`` keys duplicate data for backward
        compatibility: the first symbol's timeframes and the grand totals.
        """
        prices = prices or {}
        grand_total_equity = 0
        grand_total_initial = 0
        grand_total_realized_pnl = 0
        grand_total_unrealized_pnl = 0
        symbols_data = {}

        for symbol in self.symbols:
            if symbol not in self.accounts:
                continue
            symbol_status = self._get_symbol_status(symbol, prices.get(symbol))
            symbols_data[symbol] = symbol_status
            grand_total_initial += symbol_status.get('total_initial_balance', 0)
            grand_total_realized_pnl += symbol_status.get('total_realized_pnl', 0)
            grand_total_unrealized_pnl += symbol_status.get('total_unrealized_pnl', 0)
            grand_total_equity += symbol_status.get('total_equity', 0)

        grand_total_return = (grand_total_equity - grand_total_initial) / grand_total_initial * 100 if grand_total_initial > 0 else 0
        first_symbol = self.symbols[0] if self.symbols else None
        legacy_timeframes = symbols_data.get(first_symbol, {}).get('timeframes', {}) if first_symbol else {}

        return {
            'timestamp': datetime.now().isoformat(),
            'symbols': symbols_data,
            'timeframes': legacy_timeframes,
            'grand_total_initial_balance': grand_total_initial,
            'grand_total_realized_pnl': grand_total_realized_pnl,
            'grand_total_unrealized_pnl': grand_total_unrealized_pnl,
            'grand_total_equity': grand_total_equity,
            'grand_total_return': grand_total_return,
            'total_initial_balance': grand_total_initial,
            'total_realized_pnl': grand_total_realized_pnl,
            'total_unrealized_pnl': grand_total_unrealized_pnl,
            'total_equity': grand_total_equity,
            'total_return': grand_total_return,
        }

    def reset(self, symbol: str = None):
        """Reset accounts to fresh state and save.

        Args:
            symbol: reset only this symbol (ignored if unknown); None resets all.
        """
        if symbol:
            if symbol in self.accounts:
                self._init_symbol_accounts(symbol)
                logger.info(f"{symbol} accounts reset")
        else:
            self._init_all_accounts()
            logger.info("All accounts reset")
        self._save_state()
# Backward-compatible alias for the former PaperTrader interface.
PaperTrader = MultiTimeframePaperTrader