tradusai/trading/paper_trading.py
2025-12-09 12:27:47 +08:00

804 lines
28 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Paper Trading Module - 模拟盘交易系统
支持仓位管理:
- 分批建仓(信号重复时加仓)
- 金字塔加仓策略
- 最大持仓限制
- 动态止盈止损
"""
import json
import logging
from datetime import datetime, timedelta
from typing import Dict, Any, Optional, List
from pathlib import Path
from dataclasses import dataclass, asdict, field
from enum import Enum
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class PositionSide(Enum):
    """Direction of a position: long, short, or flat (no exposure).

    Each member's value is its own name, so the enum converts to and from
    the plain "LONG" / "SHORT" / "FLAT" strings used for ``Position.side``.
    """

    LONG = "LONG"    # long exposure
    SHORT = "SHORT"  # short exposure
    FLAT = "FLAT"    # no open position
@dataclass
class PositionEntry:
    """A single fill that contributed to a position.

    Fields, in constructor order: fill price, quantity, timestamp string,
    and the identifier of the signal that triggered the fill.
    """

    price: float     # fill price
    size: float      # quantity, in BTC
    time: str        # timestamp of the fill, as a string
    signal_id: str   # identifier of the triggering signal
@dataclass
class Position:
    """An open position that may have been built from several entries.

    Each entry is a plain dict with ``price``, ``size``, ``time`` and
    ``signal_id`` keys, so the whole position serializes directly to JSON.
    ``total_size`` and ``avg_entry_price`` are derived from ``entries`` and
    are recomputed whenever entries are added or reduced.
    """

    side: str  # "LONG", "SHORT" or "FLAT"
    entries: List[Dict] = field(default_factory=list)  # individual fills, oldest first
    total_size: float = 0.0  # sum of entry sizes (BTC)
    avg_entry_price: float = 0.0  # size-weighted average entry price
    stop_loss: float = 0.0
    take_profit: float = 0.0
    created_at: str = ""
    last_updated: str = ""

    def to_dict(self) -> dict:
        """Return a plain-dict (deep-copied) representation of this position."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict) -> 'Position':
        """Rebuild a position from the output of :meth:`to_dict`.

        The entry dicts are copied, so later changes to the returned position
        (adding or reducing entries) do not mutate ``data``.
        """
        kwargs = dict(data)
        kwargs['entries'] = [dict(e) for e in (data.get('entries') or [])]
        return cls(**kwargs)

    def _recalculate(self) -> None:
        """Refresh ``total_size``/``avg_entry_price`` from ``entries`` and stamp ``last_updated``."""
        self.total_size = sum((e['size'] for e in self.entries), 0.0)
        if self.total_size > 0:
            total_value = sum(e['price'] * e['size'] for e in self.entries)
            self.avg_entry_price = total_value / self.total_size
        else:
            self.avg_entry_price = 0.0
        self.last_updated = datetime.now().isoformat()

    def add_entry(self, price: float, size: float, signal_id: str):
        """Append a fill and update the total size and average entry price.

        Args:
            price: Fill price.
            size: Filled quantity (BTC).
            signal_id: Identifier of the signal that triggered the fill.
        """
        self.entries.append({
            'price': price,
            'size': size,
            'time': datetime.now().isoformat(),
            'signal_id': signal_id,
        })
        self._recalculate()

    def reduce_position(self, reduce_size: float) -> float:
        """Reduce the position by ``reduce_size`` using FIFO cost accounting.

        Args:
            reduce_size: Quantity (BTC) to remove. Values >= ``total_size``
                close the whole position; values <= 0 leave it unchanged.

        Returns:
            Size-weighted average entry price of the removed quantity
            (0.0 if nothing was removed).
        """
        if reduce_size <= 0:
            return 0.0

        if reduce_size >= self.total_size:
            # Full close: the removed quantity's cost basis is the whole average.
            avg_cost = self.avg_entry_price
            self.entries = []
            self._recalculate()
            return avg_cost

        # Partial close: consume the oldest entries first (FIFO).
        remaining = reduce_size
        removed_value = 0.0
        removed_size = 0.0
        while remaining > 0 and self.entries:
            oldest = self.entries[0]
            take = min(oldest['size'], remaining)
            removed_value += oldest['price'] * take
            removed_size += take
            remaining -= take
            if take >= oldest['size']:
                self.entries.pop(0)
            else:
                oldest['size'] -= take

        self._recalculate()
        return removed_value / removed_size if removed_size > 0 else 0.0
@dataclass
class Trade:
    """Record of one realized trade: a full or partial close of a position.

    Instances round-trip through :meth:`to_dict` / :meth:`from_dict`, so a
    trade history can be persisted as JSON.
    """

    id: str
    side: str            # "LONG" or "SHORT"
    entry_price: float
    entry_time: str
    exit_price: float
    exit_time: str
    size: float          # closed quantity (BTC)
    pnl: float           # realized profit/loss
    pnl_pct: float       # percentage return
    exit_reason: str
    signal_source: str

    def to_dict(self) -> dict:
        """Serialize into a dict keyed by field name."""
        serialized = asdict(self)
        return serialized

    @classmethod
    def from_dict(cls, data: dict) -> 'Trade':
        """Construct a trade from a dict produced by :meth:`to_dict`."""
        trade = cls(**data)
        return trade
class PositionManager:
    """Sizing and scaling rules for the paper trader.

    * Sizes new entries with a pyramid schedule (each add is smaller).
    * Caps total exposure and the number of entries per position.
    * Throttles repeated same-direction entries with a cooldown.
    * Decides when to take tiered partial profits (each tier once per position).
    """

    # Default partial take-profit tiers, as fractional price moves (1%, 2%, 3%).
    DEFAULT_PROFIT_LEVELS = (0.01, 0.02, 0.03)

    def __init__(
        self,
        max_position_pct: float = 0.5,   # max exposure as a fraction of balance (x leverage for notional)
        base_position_pct: float = 0.1,  # first-entry size as a fraction of balance (x leverage for notional)
        max_entries: int = 5,            # maximum number of entries per position
        pyramid_factor: float = 0.8,     # each add is this fraction of the previous one
        signal_cooldown: int = 300,      # seconds between accepted same-direction entries
    ):
        self.max_position_pct = max_position_pct
        self.base_position_pct = base_position_pct
        self.max_entries = max_entries
        self.pyramid_factor = pyramid_factor
        self.signal_cooldown = signal_cooldown
        # Time of the last accepted entry per direction.
        self.last_signal_time: Dict[str, datetime] = {}
        # Accepted entries per direction since the last reset.
        self.signal_count: Dict[str, int] = {}
        # Partial take-profit tiers already executed, keyed by (side, created_at)
        # of the position. NOTE: kept in memory only, not persisted across restarts.
        self.taken_profit_levels: Dict[tuple, set] = {}

    def calculate_entry_size(
        self,
        balance: float,
        current_position: Optional['Position'],
        signal_direction: str,
        current_price: float,
        leverage: int
    ) -> float:
        """Compute the quantity to enter for a new signal.

        Args:
            balance: Account balance used as the sizing base.
            current_position: The open position, if any.
            signal_direction: "LONG" or "SHORT".
            current_price: Price used to convert notional value into quantity.
            leverage: Multiplier applied to the balance-based notional caps.

        Returns:
            Quantity in BTC. 0 means "do not enter": non-positive price,
            cooldown active, opposite position open, or entry/size caps reached.
            A non-zero result also records the entry time for the cooldown.
        """
        if current_price <= 0:
            return 0

        now = datetime.now()
        last_time = self.last_signal_time.get(signal_direction)
        if last_time and (now - last_time).total_seconds() < self.signal_cooldown:
            logger.info(f"Signal cooldown: {signal_direction}, skip entry")
            return 0

        # Notional cap for the whole position.
        max_position_value = balance * self.max_position_pct * leverage

        current_position_value = 0
        num_entries = 0
        if current_position and current_position.side != 'FLAT':
            if current_position.side != signal_direction:
                # Opposite side: the caller is expected to close first.
                return 0
            current_position_value = current_position.total_size * current_price
            num_entries = len(current_position.entries)
            if num_entries >= self.max_entries:
                logger.info(f"Max entries reached: {num_entries}")
                return 0

        remaining_value = max_position_value - current_position_value
        if remaining_value <= 0:
            logger.info("Max position reached")
            return 0

        # Pyramid: the n-th add is base * factor**n, capped by the remaining room.
        base_value = balance * self.base_position_pct * leverage
        entry_value = min(base_value * (self.pyramid_factor ** num_entries), remaining_value)
        if entry_value <= 0:
            return 0

        self.last_signal_time[signal_direction] = now
        self.signal_count[signal_direction] = self.signal_count.get(signal_direction, 0) + 1
        return entry_value / current_price

    def should_take_partial_profit(
        self,
        position: 'Position',
        current_price: float,
        profit_levels: Optional[List[float]] = None,
    ) -> Optional[Dict]:
        """Decide whether to take a tiered partial profit on ``position``.

        Each tier in ``profit_levels`` (fractional price move; default
        1%/2%/3%) fires at most once per position. Thresholds shrink by 10%
        per extra entry, so heavily pyramided positions lock in profit sooner,
        but never below 10% of the nominal tier. If several new tiers are
        crossed at once, a single close is taken and all of them are marked done.

        Returns:
            ``{'size': reduce_size, 'reason': ..., 'profit_pct': ...}`` for the
            highest newly reached tier, or None.
        """
        if not position or position.side == 'FLAT' or position.total_size == 0:
            return None
        avg = position.avg_entry_price
        if avg <= 0 or current_price <= 0:
            return None

        if position.side == 'LONG':
            profit_pct = (current_price - avg) / avg
        else:
            profit_pct = (avg - current_price) / avg

        levels = sorted(self.DEFAULT_PROFIT_LEVELS if profit_levels is None else profit_levels)
        num_entries = len(position.entries)
        scale = max(1 - 0.1 * (num_entries - 1), 0.1)

        key = (position.side, getattr(position, 'created_at', ''))
        taken = self.taken_profit_levels.setdefault(key, set())
        reached = [lvl for lvl in levels if lvl not in taken and profit_pct >= lvl * scale]
        if not reached:
            return None

        reduce_size = position.total_size / 3  # close one third per tier
        if reduce_size * current_price < 10:  # skip closes worth less than $10
            return None

        taken.update(reached)
        top = reached[-1]
        return {
            'size': reduce_size,
            'reason': f'PARTIAL_TP_{int(round(top * 100))}PCT',
            'profit_pct': profit_pct,
        }

    def reset_signal_count(self, direction: str):
        """Reset per-direction bookkeeping after a ``direction`` position closes."""
        self.signal_count[direction] = 0
        # Forget partial-TP tiers belonging to positions on this side.
        self.taken_profit_levels = {
            k: v for k, v in self.taken_profit_levels.items() if k[0] != direction
        }
class PaperTrader:
    """Simulated (paper) trading account with position management.

    Balance, open position, trade history, statistics and the equity curve
    are persisted to a JSON ``state_file`` and restored on construction.

    PnL convention: position sizes come from
    ``PositionManager.calculate_entry_size``, whose notional already includes
    ``leverage``. Realized PnL is therefore ``size * price_move`` (leverage is
    not applied a second time). ``pnl_pct`` is reported as return on margin:
    the percentage price move multiplied by ``leverage``.
    """

    # Equity-curve points kept in memory and persisted; older points are dropped.
    MAX_EQUITY_POINTS = 1000
    # Most recent trades written to the state file.
    MAX_SAVED_TRADES = 200

    def __init__(
        self,
        initial_balance: float = 10000.0,
        leverage: int = 5,
        max_position_pct: float = 0.5,
        base_position_pct: float = 0.1,
        state_file: Optional[str] = None
    ):
        """Create the trader and restore persisted state if the file exists.

        Args:
            initial_balance: Starting balance for a fresh or reset account.
            leverage: Multiplier used for sizing and for ``pnl_pct``.
            max_position_pct: Forwarded to ``PositionManager``.
            base_position_pct: Forwarded to ``PositionManager``.
            state_file: JSON state path. Defaults to
                ``output/paper_trading_state.json`` in the parent of this
                module's directory.
        """
        self.initial_balance = initial_balance
        self.leverage = leverage
        self.position_manager = PositionManager(
            max_position_pct=max_position_pct,
            base_position_pct=base_position_pct,
        )
        if state_file:
            self.state_file = Path(state_file)
        else:
            self.state_file = Path(__file__).parent.parent / 'output' / 'paper_trading_state.json'
        self._load_state()
        logger.info(f"Paper Trader initialized: balance=${self.balance:.2f}, leverage={leverage}x")

    # ------------------------------------------------------------------ state

    def _load_state(self):
        """Load persisted state; fall back to a fresh account on any error."""
        if not self.state_file.exists():
            self._init_state()
            return
        try:
            with open(self.state_file, 'r', encoding='utf-8') as f:
                state = json.load(f)
            self.balance = state.get('balance', self.initial_balance)
            self.position = Position.from_dict(state['position']) if state.get('position') else None
            self.trades = [Trade.from_dict(t) for t in state.get('trades', [])]
            # Merge over defaults so stats saved by older versions gain new keys.
            self.stats = {**self._init_stats(), **(state.get('stats') or {})}
            self.equity_curve = list(state.get('equity_curve') or [])[-self.MAX_EQUITY_POINTS:]
            logger.info(f"Loaded state: balance=${self.balance:.2f}, trades={len(self.trades)}")
        except Exception as e:
            logger.error(f"Failed to load state: {e}")
            self._init_state()

    def _init_state(self):
        """Reset to a fresh account holding ``initial_balance`` and no position."""
        self.balance = self.initial_balance
        self.position: Optional[Position] = None
        self.trades: List[Trade] = []
        self.stats = self._init_stats()
        self.equity_curve = []  # equity snapshots, oldest first

    def _init_stats(self) -> dict:
        """Return zeroed performance statistics."""
        return {
            'total_trades': 0,
            'winning_trades': 0,
            'losing_trades': 0,
            'total_pnl': 0.0,
            'max_drawdown': 0.0,
            'peak_balance': self.initial_balance,
            'win_rate': 0.0,
            'avg_win': 0.0,
            'avg_loss': 0.0,
            'profit_factor': 0.0,
            'total_long_trades': 0,
            'total_short_trades': 0,
            'consecutive_wins': 0,
            'consecutive_losses': 0,
            'max_consecutive_wins': 0,
            'max_consecutive_losses': 0,
        }

    def _save_state(self):
        """Persist state by writing a temp file and then replacing the old file.

        A crash mid-write therefore leaves the previous state file intact.
        """
        self.state_file.parent.mkdir(parents=True, exist_ok=True)
        state = {
            'balance': self.balance,
            'position': self.position.to_dict() if self.position else None,
            'trades': [t.to_dict() for t in self.trades[-self.MAX_SAVED_TRADES:]],
            'stats': self.stats,
            'equity_curve': self.equity_curve[-self.MAX_EQUITY_POINTS:],
            'last_updated': datetime.now().isoformat(),
        }
        tmp_file = self.state_file.with_name(self.state_file.name + '.tmp')
        with open(tmp_file, 'w', encoding='utf-8') as f:
            json.dump(state, f, indent=2, ensure_ascii=False)
        tmp_file.replace(self.state_file)

    # ---------------------------------------------------------------- signals

    def process_signal(self, signal: Dict[str, Any], current_price: float) -> Dict[str, Any]:
        """Run one trading step for ``signal`` at ``current_price``.

        Steps, in order:
          1. full close on stop-loss / take-profit;
          2. tiered partial take-profit;
          3. extract the short-term signal (NO_SIGNAL if absent or invalid);
          4. reverse on an opposite signal, or try to add on a same-side one;
          5. otherwise try to open a new position.

        Returns:
            ``{'timestamp', 'current_price', 'action', 'details'}``. ``action``
            is one of CLOSE, PARTIAL_CLOSE, NO_SIGNAL, REVERSE, ADD, HOLD,
            OPEN or WAIT.
        """
        result = {
            'timestamp': datetime.now().isoformat(),
            'current_price': current_price,
            'action': 'NONE',
            'details': None,
        }
        self._update_equity_curve(current_price)

        # 1-2. Exits on the existing position.
        if self.position and self.position.side != 'FLAT':
            close_result = self._check_close_position(current_price)
            if close_result:
                result['action'] = 'CLOSE'
                result['details'] = close_result
                self._save_state()
                return result
            partial_tp = self.position_manager.should_take_partial_profit(
                self.position, current_price
            )
            if partial_tp:
                result['action'] = 'PARTIAL_CLOSE'
                result['details'] = self._partial_close(
                    current_price, partial_tp['size'], partial_tp['reason']
                )
                self._save_state()
                return result

        # 3. Signal extraction and validation.
        short_term = self._extract_short_term_signal(signal)
        direction = ''
        if isinstance(short_term, dict) and short_term.get('exists'):
            direction = str(short_term.get('direction') or '').upper()
        if direction not in ('LONG', 'SHORT'):
            result['action'] = 'NO_SIGNAL'
            result['details'] = {'reason': '无有效短期信号'}
            return result

        stop_loss = self._as_price(short_term.get('stop_loss'))
        take_profit = self._as_price(short_term.get('take_profit'))
        signal_source = str(short_term.get('reasoning') or '')[:100]

        # 4. Existing position: reverse or add.
        if self.position and self.position.side != 'FLAT':
            is_reverse = (self.position.side, direction) in (('LONG', 'SHORT'), ('SHORT', 'LONG'))
            if is_reverse:
                close_result = self._close_position(current_price, 'SIGNAL_REVERSE')
                result['action'] = 'REVERSE'
                result['details'] = {'close': close_result}
                open_result = self._try_open_position(
                    direction, current_price, stop_loss, take_profit, signal_source
                )
                if open_result:
                    result['details']['open'] = open_result
                self._save_state()
                return result

            add_result = self._try_add_position(
                direction, current_price, stop_loss, take_profit, signal_source
            )
            if add_result:
                result['action'] = 'ADD'
                result['details'] = add_result
                self._save_state()
                return result
            result['action'] = 'HOLD'
            result['details'] = {
                'position': self.position.to_dict(),
                'unrealized_pnl': self._calc_unrealized_pnl(current_price),
                'reason': '已有持仓,加仓条件不满足'
            }
            return result

        # 5. Flat: open a new position.
        open_result = self._try_open_position(
            direction, current_price, stop_loss, take_profit, signal_source
        )
        if open_result:
            result['action'] = 'OPEN'
            result['details'] = open_result
        else:
            result['action'] = 'WAIT'
            result['details'] = {'reason': '仓位条件不满足'}
        self._save_state()
        return result

    @staticmethod
    def _as_price(value: Any) -> float:
        """Coerce a signal-supplied price to float.

        None, non-numeric values and non-positive numbers map to 0.0, which
        this class treats as "not provided".
        """
        try:
            price = float(value)
        except (TypeError, ValueError):
            return 0.0
        return price if price > 0 else 0.0

    @staticmethod
    def _is_valid_level(side: str, price: float, level: float, is_stop: bool) -> bool:
        """Return True if ``level`` is positive and on the correct side of ``price``.

        LONG: stop below price, target above. SHORT: stop above, target below.
        """
        if level <= 0:
            return False
        if side == 'LONG':
            return level < price if is_stop else level > price
        return level > price if is_stop else level < price

    def _extract_short_term_signal(self, signal: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Find the short-term opportunity inside ``signal``.

        Searches ``signal['llm_signal']`` first, then
        ``signal['aggregated_signal']['llm_signal']``. Within each, it prefers
        ``opportunities['short_term_5m_15m_1h']`` over
        ``opportunities['intraday']``. Returns the first truthy match, or None.
        """
        try:
            aggregated = signal.get('aggregated_signal') or {}
            for llm in (signal.get('llm_signal'), aggregated.get('llm_signal')):
                if not isinstance(llm, dict):
                    continue
                opportunities = llm.get('opportunities') or {}
                short_term = opportunities.get('short_term_5m_15m_1h') or opportunities.get('intraday')
                if short_term:
                    return short_term
            return None
        except Exception as e:
            logger.error(f"Error extracting short term signal: {e}")
            return None

    # ------------------------------------------------------------ entries

    def _try_open_position(
        self, direction: str, price: float,
        stop_loss: float, take_profit: float, signal_source: str
    ) -> Optional[Dict]:
        """Open a new ``direction`` position sized by the position manager.

        A stop-loss/take-profit that is missing (<= 0) or on the wrong side of
        ``price`` is replaced by the default (0.5% / 1.5% away).
        ``signal_source`` is currently not stored on the position.

        Returns:
            Summary of the new position, or None if sizing returned 0.
        """
        entry_size = self.position_manager.calculate_entry_size(
            self.balance, self.position, direction, price, self.leverage
        )
        if entry_size <= 0:
            return None

        if self._is_valid_level(direction, price, stop_loss, is_stop=True):
            sl = stop_loss
        else:
            sl = self._calc_default_stop(direction, price)
        if self._is_valid_level(direction, price, take_profit, is_stop=False):
            tp = take_profit
        else:
            tp = self._calc_default_tp(direction, price)

        self.position = Position(
            side=direction,
            stop_loss=sl,
            take_profit=tp,
            created_at=datetime.now().isoformat(),
        )
        signal_id = f"S{datetime.now().strftime('%H%M%S')}"
        self.position.add_entry(price, entry_size, signal_id)
        logger.info(
            f"OPEN {direction}: price=${price:.2f}, size={entry_size:.6f} BTC, "
            f"SL=${self.position.stop_loss:.2f}, TP=${self.position.take_profit:.2f}"
        )
        return {
            'side': direction,
            'entry_price': price,
            'size': entry_size,
            'total_size': self.position.total_size,
            'stop_loss': self.position.stop_loss,
            'take_profit': self.position.take_profit,
            'num_entries': 1,
        }

    def _try_add_position(
        self, direction: str, price: float,
        stop_loss: float, take_profit: float, signal_source: str
    ) -> Optional[Dict]:
        """Add to an existing same-direction position.

        New stop-loss/take-profit levels replace the current ones only if
        they are on the correct side of ``price``.

        Returns:
            Summary of the add, or None if there is no matching position or
            sizing returned 0.
        """
        if not self.position or self.position.side != direction:
            return None
        entry_size = self.position_manager.calculate_entry_size(
            self.balance, self.position, direction, price, self.leverage
        )
        if entry_size <= 0:
            return None

        signal_id = f"S{datetime.now().strftime('%H%M%S')}"
        old_avg = self.position.avg_entry_price
        self.position.add_entry(price, entry_size, signal_id)
        if self._is_valid_level(direction, price, stop_loss, is_stop=True):
            self.position.stop_loss = stop_loss
        if self._is_valid_level(direction, price, take_profit, is_stop=False):
            self.position.take_profit = take_profit

        logger.info(
            f"ADD {direction}: price=${price:.2f}, size={entry_size:.6f} BTC, "
            f"avg_entry=${old_avg:.2f}->${self.position.avg_entry_price:.2f}, "
            f"total_size={self.position.total_size:.6f}"
        )
        return {
            'side': direction,
            'add_price': price,
            'add_size': entry_size,
            'total_size': self.position.total_size,
            'avg_entry_price': self.position.avg_entry_price,
            'num_entries': len(self.position.entries),
        }

    def _calc_default_stop(self, side: str, price: float) -> float:
        """Default stop-loss: 0.5% against the position."""
        return price * 0.995 if side == 'LONG' else price * 1.005

    def _calc_default_tp(self, side: str, price: float) -> float:
        """Default take-profit: 1.5% in favor of the position."""
        return price * 1.015 if side == 'LONG' else price * 0.985

    # -------------------------------------------------------------- exits

    def _check_close_position(self, current_price: float) -> Optional[Dict[str, Any]]:
        """Close the whole position at ``current_price`` if TP or SL is hit.

        Take-profit is checked before stop-loss.
        """
        if not self.position or self.position.side == 'FLAT':
            return None
        if self.position.side == 'LONG':
            if current_price >= self.position.take_profit:
                return self._close_position(current_price, 'TAKE_PROFIT')
            if current_price <= self.position.stop_loss:
                return self._close_position(current_price, 'STOP_LOSS')
        else:
            if current_price <= self.position.take_profit:
                return self._close_position(current_price, 'TAKE_PROFIT')
            if current_price >= self.position.stop_loss:
                return self._close_position(current_price, 'STOP_LOSS')
        return None

    def _next_trade_id(self) -> str:
        """Sequential trade id based on the persisted trade counter.

        Uses ``stats['total_trades']`` rather than ``len(self.trades)``, so ids
        stay unique after the saved history is truncated.
        """
        return f"T{self.stats.get('total_trades', 0) + 1:04d}"

    def _close_position(self, price: float, reason: str) -> Dict[str, Any]:
        """Close the whole position at ``price``, record the trade and go flat."""
        if not self.position or self.position.side == 'FLAT':
            return {'error': 'No position to close'}

        position = self.position
        pnl, pnl_pct = self._calc_pnl(price)
        self.balance += pnl

        trade = Trade(
            id=self._next_trade_id(),
            side=position.side,
            entry_price=position.avg_entry_price,
            entry_time=position.created_at,
            exit_price=price,
            exit_time=datetime.now().isoformat(),
            size=position.total_size,
            pnl=pnl,
            pnl_pct=pnl_pct,
            exit_reason=reason,
            signal_source=f"{len(position.entries)} entries",
        )
        self.trades.append(trade)
        self._update_stats(trade)

        result = {
            'side': position.side,
            'entry_price': position.avg_entry_price,
            'exit_price': price,
            'size': position.total_size,
            'num_entries': len(position.entries),
            'pnl': pnl,
            'pnl_pct': pnl_pct,
            'reason': reason,
            'new_balance': self.balance,
        }
        logger.info(
            f"CLOSE {position.side}: avg_entry=${position.avg_entry_price:.2f}, "
            f"exit=${price:.2f}, PnL=${pnl:.2f} ({pnl_pct:.2f}%), reason={reason}"
        )
        self.position_manager.reset_signal_count(position.side)
        self.position = None
        return result

    def _partial_close(self, price: float, size: float, reason: str) -> Dict[str, Any]:
        """Close ``size`` BTC (FIFO cost basis) at ``price`` and record the trade."""
        if not self.position or self.position.side == 'FLAT':
            return {'error': 'No position'}
        if size <= 0:
            return {'error': 'Invalid size'}

        position = self.position
        closed_size = min(size, position.total_size)
        avg_cost = position.reduce_position(closed_size)
        pnl, pnl_pct = self._pnl_for(position.side, avg_cost, price, closed_size)
        self.balance += pnl

        trade = Trade(
            id=self._next_trade_id(),
            side=position.side,
            entry_price=avg_cost,
            entry_time=position.created_at,
            exit_price=price,
            exit_time=datetime.now().isoformat(),
            size=closed_size,
            pnl=pnl,
            pnl_pct=pnl_pct,
            exit_reason=reason,
            signal_source="partial",
        )
        self.trades.append(trade)
        self._update_stats(trade)
        logger.info(
            f"PARTIAL CLOSE: size={closed_size:.6f}, PnL=${pnl:.2f} ({pnl_pct:.2f}%), "
            f"remaining={position.total_size:.6f}"
        )

        if position.total_size <= 0:
            self.position_manager.reset_signal_count(position.side)
            self.position = None

        return {
            'side': self.position.side if self.position else 'FLAT',
            'closed_size': closed_size,
            'exit_price': price,
            'pnl': pnl,
            'pnl_pct': pnl_pct,
            'reason': reason,
            'remaining_size': self.position.total_size if self.position else 0,
            'new_balance': self.balance,
        }

    # ---------------------------------------------------------------- PnL

    def _pnl_for(self, side: str, entry_price: float, exit_price: float, size: float) -> tuple:
        """Return ``(pnl, pnl_pct)`` for closing ``size`` BTC opened at ``entry_price``.

        ``pnl`` is ``size * price_move``; leverage is already in ``size``.
        ``pnl_pct`` is the percentage price move times ``self.leverage``
        (return on margin). Returns (0.0, 0.0) when ``entry_price`` <= 0.
        """
        if entry_price <= 0:
            return 0.0, 0.0
        move = exit_price - entry_price if side == 'LONG' else entry_price - exit_price
        pnl = size * move
        pnl_pct = move / entry_price * 100 * self.leverage
        return pnl, pnl_pct

    def _calc_pnl(self, current_price: float) -> tuple:
        """Return ``(pnl, pnl_pct)`` of the whole open position at ``current_price``."""
        if not self.position:
            return 0.0, 0.0
        return self._pnl_for(
            self.position.side, self.position.avg_entry_price,
            current_price, self.position.total_size,
        )

    def _calc_unrealized_pnl(self, current_price: float) -> Dict[str, float]:
        """Unrealized PnL of the open position as ``{'pnl', 'pnl_pct'}``."""
        pnl, pnl_pct = self._calc_pnl(current_price)
        return {'pnl': pnl, 'pnl_pct': pnl_pct}

    def _update_equity_curve(self, current_price: float):
        """Append an equity snapshot, keeping at most ``MAX_EQUITY_POINTS`` in memory."""
        equity = self.balance
        if self.position and self.position.total_size > 0:
            equity += self._calc_unrealized_pnl(current_price)['pnl']
        self.equity_curve.append({
            'timestamp': datetime.now().isoformat(),
            'equity': equity,
            'balance': self.balance,
            'price': current_price,
        })
        if len(self.equity_curve) > self.MAX_EQUITY_POINTS:
            del self.equity_curve[:-self.MAX_EQUITY_POINTS]

    def _update_stats(self, trade: 'Trade'):
        """Fold a just-recorded trade into ``self.stats``.

        Expects ``trade`` to already be in ``self.trades`` and ``self.balance``
        to already include its PnL. Trades with pnl <= 0 count as losses.
        """
        stats = self.stats
        stats['total_trades'] += 1
        stats['total_pnl'] += trade.pnl
        if trade.side == 'LONG':
            stats['total_long_trades'] += 1
        else:
            stats['total_short_trades'] += 1

        if trade.pnl > 0:
            stats['winning_trades'] += 1
            stats['consecutive_wins'] += 1
            stats['consecutive_losses'] = 0
            stats['max_consecutive_wins'] = max(stats['max_consecutive_wins'], stats['consecutive_wins'])
        else:
            stats['losing_trades'] += 1
            stats['consecutive_losses'] += 1
            stats['consecutive_wins'] = 0
            stats['max_consecutive_losses'] = max(stats['max_consecutive_losses'], stats['consecutive_losses'])

        stats['win_rate'] = stats['winning_trades'] / stats['total_trades'] * 100

        win_pnls = [t.pnl for t in self.trades if t.pnl > 0]
        loss_pnls = [t.pnl for t in self.trades if t.pnl <= 0]
        gross_profit = sum(win_pnls)
        gross_loss = abs(sum(loss_pnls))
        if win_pnls:
            stats['avg_win'] = gross_profit / len(win_pnls)
        if loss_pnls:
            stats['avg_loss'] = sum(loss_pnls) / len(loss_pnls)
        # Profit factor: gross profit divided by gross loss.
        if gross_loss > 0:
            stats['profit_factor'] = gross_profit / gross_loss

        # Drawdown from the realized-balance peak, in percent.
        if self.balance > stats['peak_balance']:
            stats['peak_balance'] = self.balance
        if stats['peak_balance'] > 0:
            drawdown = (stats['peak_balance'] - self.balance) / stats['peak_balance'] * 100
            stats['max_drawdown'] = max(stats['max_drawdown'], drawdown)

    # ------------------------------------------------------------- public

    def get_status(self, current_price: float = None) -> Dict[str, Any]:
        """Snapshot of the account.

        Includes unrealized PnL for the open position when ``current_price``
        is given.
        """
        if self.initial_balance:
            total_return = (self.balance - self.initial_balance) / self.initial_balance * 100
        else:
            total_return = 0.0
        status = {
            'timestamp': datetime.now().isoformat(),
            'balance': self.balance,
            'initial_balance': self.initial_balance,
            'total_return': total_return,
            'leverage': self.leverage,
            'position': None,
            'stats': self.stats,
            'recent_trades': [t.to_dict() for t in self.trades[-10:]],
            'equity_curve': self.equity_curve[-100:],
        }
        if self.position and self.position.total_size > 0:
            pos_dict = self.position.to_dict()
            if current_price:
                unrealized = self._calc_unrealized_pnl(current_price)
                pos_dict['current_price'] = current_price
                pos_dict['unrealized_pnl'] = unrealized['pnl']
                pos_dict['unrealized_pnl_pct'] = unrealized['pnl_pct']
            status['position'] = pos_dict
        return status

    def reset(self):
        """Reset the paper account to its initial state and persist it."""
        self._init_state()
        self._save_state()
        logger.info("Paper trading account reset")