stock-ai-agent/backend/app/services/paper_trading_service.py
2026-02-08 22:38:06 +08:00

736 lines
28 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
模拟交易服务 - 订单管理和盈亏统计
"""
import uuid
from datetime import datetime, timedelta
from typing import Dict, Any, List, Optional
from app.models.paper_trading import PaperOrder, OrderStatus, OrderSide, SignalGrade, EntryType
from app.services.db_service import db_service
from app.config import get_settings
from app.utils.logger import logger
# Position size in USDT allotted to an order, keyed by signal grade.
# A grade mapped to 0 (grade D) never opens a position.
POSITION_SIZE = dict(
    A=1000,  # grade A: 1000 USDT
    B=500,   # grade B: 500 USDT
    C=200,   # grade C: 200 USDT
    D=0,     # grade D: do not open
)
class PaperTradingService:
    """Paper-trading service: order management and profit/loss statistics.

    Orders in PENDING or OPEN state are mirrored in ``self.active_orders``
    (keyed by ``order_id``) so price checks do not query the database; every
    state change is written back through a short-lived session obtained from
    ``db_service.get_session()``.
    """

    def __init__(self):
        """Load settings, create missing tables and warm the active-order cache."""
        self.settings = get_settings()
        # In-memory cache of PENDING/OPEN orders, keyed by order_id.
        self.active_orders: Dict[str, PaperOrder] = {}
        self._ensure_table_exists()
        self._load_active_orders()
        logger.info("模拟交易服务初始化完成")

    def _ensure_table_exists(self):
        """Create any tables registered on ``Base.metadata`` that do not exist yet."""
        # PaperOrder is imported at module level, which registers its table on
        # the declarative metadata before create_all() runs.
        from app.models.database import Base
        Base.metadata.create_all(bind=db_service.engine)

    def _load_active_orders(self):
        """Load PENDING/OPEN orders from the database into ``self.active_orders``.

        Best-effort: a failure is logged and the cache is left as-is.
        """
        db = db_service.get_session()
        try:
            orders = db.query(PaperOrder).filter(
                PaperOrder.status.in_([OrderStatus.PENDING, OrderStatus.OPEN])
            ).all()
            for order in orders:
                self.active_orders[order.order_id] = order
            logger.info(f"已加载 {len(orders)} 个活跃订单")
        except Exception as e:
            logger.error(f"加载活跃订单失败: {e}")
        finally:
            db.close()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _positive_price(value: Any) -> Optional[float]:
        """Return ``value`` if it is a usable (strictly positive) price, else None."""
        return value if value is not None and value > 0 else None

    @staticmethod
    def _pnl_percent(order: PaperOrder, price: float) -> Optional[float]:
        """PnL in percent of ``order`` if it were closed at ``price``.

        Returns None when the order has no positive fill price, so callers never
        divide by zero/None.
        """
        filled = order.filled_price
        if filled is None or filled <= 0:
            return None
        if order.side == OrderSide.LONG:
            return (price - filled) / filled * 100
        return (filled - price) / filled * 100

    @staticmethod
    def _persist_changes(order: PaperOrder, changes: Dict[str, Any], action: str) -> bool:
        """Apply ``changes`` to ``order`` and save them in a fresh session.

        If the commit fails, the in-memory attributes are restored to their
        previous values so the cache never disagrees with the database.

        Args:
            order: the (detached) order object held in the cache.
            changes: attribute name -> new value.
            action: text used in the failure log line ("<action>失败: ...").

        Returns:
            True if the changes were committed, False otherwise.
        """
        previous = {name: getattr(order, name) for name in changes}
        db = db_service.get_session()
        try:
            for name, value in changes.items():
                setattr(order, name, value)
            db.merge(order)
            db.commit()
            return True
        except Exception as e:
            logger.error(f"{action}失败: {e}")
            db.rollback()
            for name, value in previous.items():
                setattr(order, name, value)
            return False
        finally:
            db.close()

    @staticmethod
    def _closed_orders_query(db):
        """Query over orders closed by take-profit, stop-loss or manually."""
        return db.query(PaperOrder).filter(
            PaperOrder.status.in_([
                OrderStatus.CLOSED_TP,
                OrderStatus.CLOSED_SL,
                OrderStatus.CLOSED_MANUAL,
            ])
        )

    def _orders_for(self, symbol: str, status: OrderStatus) -> List[PaperOrder]:
        """Snapshot list of cached orders for ``symbol`` in ``status``."""
        return [
            order for order in self.active_orders.values()
            if order.symbol == symbol and order.status == status
        ]

    # ------------------------------------------------------------- order entry

    def create_order_from_signal(self, signal: Dict[str, Any], current_price: Optional[float] = None) -> Optional[PaperOrder]:
        """
        Create a paper order from a trading signal.

        Args:
            signal: trading signal. Keys read:
                - symbol: trading pair
                - action: 'buy' or 'sell' (anything else is ignored)
                - entry_type: 'market' or 'limit' (default 'market')
                - entry_price / price: entry price
                - stop_loss: stop-loss price
                - take_profit: take-profit price
                - confidence: confidence score
                - signal_grade / grade: signal grade
                - signal_type / type: signal type
                - reason / reasons: entry reason(s)
                - trend, indicators: stored as-is
            current_price: latest price; the fill price of market orders, and
                their entry price when the signal carries none.

        Returns:
            The created order, or None when the signal is skipped (non-trade
            action, position limits, grade D/unknown, no usable price) or the
            database insert fails.
        """
        action = signal.get('action')
        if action not in ('buy', 'sell'):
            return None

        symbol = signal.get('symbol', 'UNKNOWN')
        side = OrderSide.LONG if action == 'buy' else OrderSide.SHORT
        entry_type = EntryType.LIMIT if signal.get('entry_type', 'market') == 'limit' else EntryType.MARKET
        entry_price = signal.get('entry_price') or signal.get('price') or 0
        if not entry_price and entry_type == EntryType.MARKET and current_price:
            # Market signals may omit a price; record the price we actually trade at.
            entry_price = current_price

        # === Limit checks ===
        # 1. At most 3 active (pending or open) orders per symbol and direction.
        same_direction_orders = [
            order for order in self.active_orders.values()
            if order.symbol == symbol and order.side == side
        ]
        if len(same_direction_orders) >= 3:
            logger.info(f"订单限制: {symbol} {side.value} 方向已有 {len(same_direction_orders)} 个订单,跳过")
            return None

        # 2. Skip if a same-direction pending order is within 1% of this price.
        for pending in same_direction_orders:
            if pending.status != OrderStatus.PENDING:
                continue
            reference = self._positive_price(pending.entry_price)
            if reference is None:
                continue  # no usable price to compare against (avoids dividing by 0/None)
            if abs(reference - entry_price) / reference < 0.01:
                logger.info(f"订单限制: {symbol} 已有接近的挂单 @ ${pending.entry_price:,.2f},新信号 @ ${entry_price:,.2f},跳过")
                return None

        # Signal grade decides whether and how large a position is opened.
        grade = signal.get('signal_grade') or signal.get('grade', 'D')
        if grade == 'D':
            logger.info(f"D级信号不开仓: {signal.get('symbol')}")
            return None
        quantity = POSITION_SIZE.get(grade, 0)
        if quantity == 0:
            return None

        if entry_type == EntryType.MARKET:
            # Market order: opened immediately at the current price.
            status = OrderStatus.OPEN
            filled_price = current_price if current_price else entry_price
            opened_at = datetime.utcnow()
            reference_price = filled_price
        else:
            # Limit order: waits in PENDING until the entry price is reached.
            status = OrderStatus.PENDING
            filled_price = None
            opened_at = None
            reference_price = entry_price

        # A zero/None reference price would divide by zero in PnL maths or make
        # trigger checks fire on the first tick, so such orders are rejected.
        if self._positive_price(reference_price) is None:
            logger.warning(f"订单价格无效: {symbol} 缺少有效入场价,跳过")
            return None

        order_id = f"PT-{symbol}-{datetime.now().strftime('%Y%m%d%H%M%S')}-{uuid.uuid4().hex[:6]}"

        db = db_service.get_session()
        try:
            order = PaperOrder(
                order_id=order_id,
                symbol=symbol,
                side=side,
                entry_price=entry_price,
                stop_loss=signal.get('stop_loss', 0),
                take_profit=signal.get('take_profit', 0),
                filled_price=filled_price,
                quantity=quantity,
                signal_grade=SignalGrade(grade),
                signal_type=signal.get('signal_type') or signal.get('type', 'swing'),
                confidence=signal.get('confidence', 0),
                trend=signal.get('trend'),
                entry_type=entry_type,
                status=status,
                opened_at=opened_at,
                entry_reasons=[signal.get('reason', '')] if signal.get('reason') else signal.get('reasons', []),
                indicators=signal.get('indicators', {})
            )
            db.add(order)
            db.commit()
            db.refresh(order)
        except Exception as e:
            logger.error(f"创建模拟订单失败: {e}")
            db.rollback()
            return None
        finally:
            db.close()

        # Only cache and report the order once the insert has committed.
        self.active_orders[order.order_id] = order
        entry_type_text = "现价" if entry_type == EntryType.MARKET else "挂单"
        status_text = "已开仓" if status == OrderStatus.OPEN else "等待触发"
        logger.info(f"创建模拟订单: {order_id} | {symbol} {side.value} [{entry_type_text}] @ ${entry_price:,.2f} | {status_text} | 仓位: ${quantity}")
        return order

    # ---------------------------------------------------------- price triggers

    def check_price_triggers(self, symbol: str, current_price: float) -> List[Dict[str, Any]]:
        """
        Check whether ``current_price`` fills pending orders or hits TP/SL.

        Args:
            symbol: trading pair
            current_price: latest price

        Returns:
            Close results (see ``_close_order``) for orders closed by this price.
        """
        # 1. Fill pending limit orders whose entry price has been reached.
        for order in self._orders_for(symbol, OrderStatus.PENDING):
            if self._check_pending_entry(order, current_price):
                logger.info(f"挂单触发入场: {order.order_id} | {symbol} @ ${current_price:,.2f}")

        # 2. Evaluate TP/SL for open positions (including orders filled just above).
        triggered = []
        for order in self._orders_for(symbol, OrderStatus.OPEN):
            result = self._check_order_trigger(order, current_price)
            if result:
                triggered.append(result)
            else:
                # Still open: track best/worst unrealised PnL.
                self._update_order_extremes(order, current_price)
        return triggered

    def _check_pending_entry(self, order: PaperOrder, current_price: float) -> bool:
        """
        Fill a pending order if the price has reached its entry level.

        Long pending orders fill when the price falls to (or below) the entry
        price; short pending orders fill when it rises to (or above) it. Orders
        without a positive entry price never fill.

        Returns:
            True if the order was filled and persisted, otherwise False.
        """
        entry = self._positive_price(order.entry_price)
        if entry is None:
            return False
        if order.side == OrderSide.LONG:
            reached = current_price <= entry
        else:
            reached = current_price >= entry
        return self._activate_pending_order(order, current_price) if reached else False

    def _activate_pending_order(self, order: PaperOrder, filled_price: float) -> bool:
        """Turn a pending order into an open position filled at ``filled_price``."""
        saved = self._persist_changes(order, {
            'status': OrderStatus.OPEN,
            'filled_price': filled_price,
            'opened_at': datetime.utcnow(),
        }, "激活挂单")
        if saved:
            logger.info(f"挂单已激活: {order.order_id} | {order.symbol} {order.side.value} @ ${filled_price:,.2f}")
        return saved

    def _check_order_trigger(self, order: PaperOrder, current_price: float) -> Optional[Dict[str, Any]]:
        """
        Close an open order whose take-profit or stop-loss level was hit.

        Long: price >= take-profit closes as TP, price <= stop-loss as SL.
        Short: price <= take-profit closes as TP, price >= stop-loss as SL.
        Take-profit is checked first. The exit is booked at the trigger level,
        not at ``current_price``. A level that is missing or not positive is
        treated as "not set" (a 0 level would otherwise fire on every tick).

        Returns:
            The ``_close_order`` result, or None if nothing triggered.
        """
        take_profit = self._positive_price(order.take_profit)
        stop_loss = self._positive_price(order.stop_loss)

        if order.side == OrderSide.LONG:
            hit_tp = take_profit is not None and current_price >= take_profit
            hit_sl = stop_loss is not None and current_price <= stop_loss
        else:
            hit_tp = take_profit is not None and current_price <= take_profit
            hit_sl = stop_loss is not None and current_price >= stop_loss

        if hit_tp:
            return self._close_order(order, OrderStatus.CLOSED_TP, take_profit)
        if hit_sl:
            return self._close_order(order, OrderStatus.CLOSED_SL, stop_loss)
        return None

    def _close_order(self, order: PaperOrder, status: OrderStatus, exit_price: float) -> Optional[Dict[str, Any]]:
        """
        Close ``order`` at ``exit_price``, persist it and compute realised PnL.

        PnL amount is ``quantity * pnl_percent / 100`` (quantity is the USDT
        position size). On success the order leaves the active cache.

        Returns:
            A result dict describing the closed trade, or None if the order has
            no usable fill price or the database update fails.
        """
        pnl_percent = self._pnl_percent(order, exit_price)
        if pnl_percent is None:
            logger.error(f"平仓失败: 订单 {order.order_id} 缺少有效成交价")
            return None
        pnl_amount = order.quantity * pnl_percent / 100

        closed_at = datetime.utcnow()
        hold_duration = closed_at - order.opened_at if order.opened_at else timedelta(0)

        if not self._persist_changes(order, {
            'status': status,
            'exit_price': exit_price,
            'closed_at': closed_at,
            'pnl_amount': round(pnl_amount, 2),
            'pnl_percent': round(pnl_percent, 4),
        }, "平仓"):
            return None

        self.active_orders.pop(order.order_id, None)

        result = {
            'order_id': order.order_id,
            'symbol': order.symbol,
            'side': order.side.value,
            'status': status.value,
            'entry_price': order.filled_price,
            'exit_price': exit_price,
            'quantity': order.quantity,
            'pnl_amount': order.pnl_amount,
            'pnl_percent': order.pnl_percent,
            'is_win': pnl_amount > 0,
            'hold_duration': str(hold_duration).split('.')[0],  # drop microseconds
            'signal_grade': order.signal_grade.value if order.signal_grade else None
        }

        # Label the close reason correctly (a manual close is not a stop-loss).
        status_text = {
            OrderStatus.CLOSED_TP: "止盈",
            OrderStatus.CLOSED_SL: "止损",
        }.get(status, "手动平仓")
        logger.info(f"订单{status_text}: {order.order_id} | {order.symbol} | 盈亏: {pnl_percent:+.2f}% (${pnl_amount:+.2f})")
        return result

    def _update_order_extremes(self, order: PaperOrder, current_price: float):
        """Track the best (max_profit) and worst (max_drawdown) unrealised PnL %.

        Only the in-memory order is updated here; the values are written to the
        database by the merge performed when the order is next persisted.
        """
        pnl = self._pnl_percent(order, current_price)
        if pnl is None:
            return
        # `or 0` tolerates NULL columns on orders loaded from the database.
        if pnl > (order.max_profit or 0):
            order.max_profit = pnl
        if pnl < (order.max_drawdown or 0):
            order.max_drawdown = pnl

    # ------------------------------------------------------- manual operations

    def close_order_manual(self, order_id: str, exit_price: float) -> Optional[Dict[str, Any]]:
        """Manually close an open order at ``exit_price``, or cancel a pending one."""
        order = self.active_orders.get(order_id)
        if order is None:
            logger.warning(f"订单不存在或已平仓: {order_id}")
            return None
        if order.status == OrderStatus.PENDING:
            # A pending order was never filled, so it is cancelled rather than closed.
            return self._cancel_pending_order(order)
        return self._close_order(order, OrderStatus.CLOSED_MANUAL, exit_price)

    def _cancel_pending_order(self, order: PaperOrder) -> Optional[Dict[str, Any]]:
        """Cancel a pending order; returns a summary dict, or None if saving fails."""
        if not self._persist_changes(order, {
            'status': OrderStatus.CANCELLED,
            'closed_at': datetime.utcnow(),
        }, "取消挂单"):
            return None

        self.active_orders.pop(order.order_id, None)
        logger.info(f"挂单已取消: {order.order_id} | {order.symbol}")
        return {
            'order_id': order.order_id,
            'symbol': order.symbol,
            'side': order.side.value,
            'status': 'cancelled',
            'entry_price': order.entry_price,
            'message': '挂单已取消'
        }

    # ----------------------------------------------------------------- queries

    def get_active_orders(self, symbol: Optional[str] = None) -> List[Dict[str, Any]]:
        """Return cached active orders as dicts, optionally filtered by symbol."""
        orders = list(self.active_orders.values())
        if symbol:
            orders = [o for o in orders if o.symbol == symbol]
        return [o.to_dict() for o in orders]

    def get_order_by_id(self, order_id: str) -> Optional[Dict[str, Any]]:
        """Return one order as a dict: cache first, then the database."""
        if order_id in self.active_orders:
            return self.active_orders[order_id].to_dict()
        db = db_service.get_session()
        try:
            order = db.query(PaperOrder).filter(PaperOrder.order_id == order_id).first()
            return order.to_dict() if order else None
        finally:
            db.close()

    def get_order_history(self, symbol: Optional[str] = None, limit: int = 100) -> List[Dict[str, Any]]:
        """Return up to ``limit`` closed orders (TP/SL/manual), newest close first."""
        db = db_service.get_session()
        try:
            query = self._closed_orders_query(db)
            if symbol:
                query = query.filter(PaperOrder.symbol == symbol)
            orders = query.order_by(PaperOrder.closed_at.desc()).limit(limit).all()
            return [o.to_dict() for o in orders]
        finally:
            db.close()

    # -------------------------------------------------------------- statistics

    def calculate_statistics(self, symbol: Optional[str] = None,
                             start_date: Optional[datetime] = None,
                             end_date: Optional[datetime] = None) -> Dict[str, Any]:
        """
        Aggregate statistics over closed orders (TP/SL/manual).

        Args:
            symbol: restrict to one trading pair.
            start_date / end_date: inclusive bounds on ``closed_at``.

        Returns:
            Totals, win rate, averages, profit factor, extremes, plus per-grade,
            per-type and per-symbol breakdowns. ``profit_factor`` is
            ``float('inf')`` when there are winning but no losing trades and 0
            when there are neither.
        """
        db = db_service.get_session()
        try:
            query = self._closed_orders_query(db)
            if symbol:
                query = query.filter(PaperOrder.symbol == symbol)
            if start_date:
                query = query.filter(PaperOrder.closed_at >= start_date)
            if end_date:
                query = query.filter(PaperOrder.closed_at <= end_date)
            orders = query.all()

            if not orders:
                return self._empty_statistics()

            total_trades = len(orders)
            wins = [o.pnl_amount for o in orders if o.pnl_amount > 0]
            losses = [abs(o.pnl_amount) for o in orders if o.pnl_amount < 0]
            total_pnl = sum(o.pnl_amount for o in orders)
            gross_profit = sum(wins)
            gross_loss = sum(losses)

            if gross_loss > 0:
                profit_factor = round(gross_profit / gross_loss, 2)
            elif gross_profit > 0:
                profit_factor = float('inf')  # winners only
            else:
                profit_factor = 0  # break-even trades only: ratio undefined

            return {
                'total_trades': total_trades,
                'winning_trades': len(wins),
                'losing_trades': len(losses),
                'win_rate': round(len(wins) / total_trades * 100, 2),
                'total_pnl': round(total_pnl, 2),
                'total_pnl_percent': round(sum(o.pnl_percent for o in orders), 2),
                'average_pnl': round(total_pnl / total_trades, 2),
                'average_win': round(gross_profit / len(wins), 2) if wins else 0,
                'average_loss': round(gross_loss / len(losses), 2) if losses else 0,
                'profit_factor': profit_factor,
                'max_drawdown': min((o.max_drawdown or 0) for o in orders),
                'best_trade': max(o.pnl_percent for o in orders),
                'worst_trade': min(o.pnl_percent for o in orders),
                'by_grade': self._calculate_grade_statistics(orders),
                'by_type': self._calculate_type_statistics(orders),
                'by_symbol': self._calculate_symbol_statistics(orders)
            }
        finally:
            db.close()

    def _empty_statistics(self) -> Dict[str, Any]:
        """Statistics structure used when there are no closed orders."""
        return {
            'total_trades': 0,
            'winning_trades': 0,
            'losing_trades': 0,
            'win_rate': 0,
            'total_pnl': 0,
            'total_pnl_percent': 0,
            'average_pnl': 0,
            'average_win': 0,
            'average_loss': 0,
            'profit_factor': 0,
            'max_drawdown': 0,
            'best_trade': 0,
            'worst_trade': 0,
            'by_grade': {},
            'by_type': {},
            'by_symbol': {}
        }

    @staticmethod
    def _summarize_group(group: List[PaperOrder]) -> Dict[str, Any]:
        """Count, win rate (%, 1 dp) and total PnL (2 dp) of a non-empty order group."""
        wins = sum(1 for o in group if o.pnl_amount > 0)
        return {
            'count': len(group),
            'win_rate': round(wins / len(group) * 100, 1),
            'total_pnl': round(sum(o.pnl_amount for o in group), 2)
        }

    def _calculate_grade_statistics(self, orders: List[PaperOrder]) -> Dict[str, Any]:
        """Per-grade summary for grades A-D that have at least one order."""
        result = {}
        for grade in ('A', 'B', 'C', 'D'):
            group = [o for o in orders if o.signal_grade and o.signal_grade.value == grade]
            if group:
                result[grade] = self._summarize_group(group)
        return result

    def _calculate_type_statistics(self, orders: List[PaperOrder]) -> Dict[str, Any]:
        """Per-signal-type summary.

        'swing' and 'short_term' come first; any other non-empty signal type
        present in ``orders`` follows in sorted order instead of being dropped.
        """
        known = ['swing', 'short_term']
        extra = sorted(
            {o.signal_type for o in orders if o.signal_type and o.signal_type not in known},
            key=str,
        )
        result = {}
        for signal_type in known + extra:
            group = [o for o in orders if o.signal_type == signal_type]
            if group:
                result[signal_type] = self._summarize_group(group)
        return result

    def _calculate_symbol_statistics(self, orders: List[PaperOrder]) -> Dict[str, Any]:
        """Per-symbol summary, in order of first appearance in ``orders``."""
        groups: Dict[str, List[PaperOrder]] = {}
        for o in orders:
            groups.setdefault(o.symbol, []).append(o)
        return {symbol: self._summarize_group(group) for symbol, group in groups.items()}

    def get_period_statistics(self, hours: int = 4) -> Dict[str, Any]:
        """
        Statistics for the last ``hours`` hours.

        Args:
            hours: window length in hours (compared against UTC timestamps).

        Returns:
            Count of orders created and closed in the window, plus the window's
            realised PnL, wins, losses and win rate.
        """
        db = db_service.get_session()
        try:
            cutoff_time = datetime.utcnow() - timedelta(hours=hours)
            # Orders closed inside the window.
            closed_orders = self._closed_orders_query(db).filter(
                PaperOrder.closed_at >= cutoff_time
            ).all()
            # Orders created inside the window (any status, including active ones).
            new_order_count = db.query(PaperOrder).filter(
                PaperOrder.created_at >= cutoff_time
            ).count()

            period_pnl = sum(o.pnl_amount for o in closed_orders)
            period_wins = len([o for o in closed_orders if o.pnl_amount > 0])
            period_losses = len([o for o in closed_orders if o.pnl_amount < 0])
            return {
                'period_hours': hours,
                'new_orders': new_order_count,
                'closed_orders': len(closed_orders),
                'period_pnl': round(period_pnl, 2),
                'period_wins': period_wins,
                'period_losses': period_losses,
                'period_win_rate': round(period_wins / len(closed_orders) * 100, 1) if closed_orders else 0
            }
        finally:
            db.close()

    # ------------------------------------------------------------------ report

    @staticmethod
    def _format_price(price: Any) -> str:
        """Format a price for the report without losing sub-dollar precision.

        >= 100 -> thousands separator, no decimals; >= 1 -> 2 decimals;
        below 1 -> 6 significant digits.
        """
        price = price or 0
        if abs(price) >= 100:
            return f"{price:,.0f}"
        if abs(price) >= 1:
            return f"{price:,.2f}"
        return f"{price:.6g}"

    def generate_report(self, hours: int = 4) -> str:
        """
        Build the paper-trading report text (Telegram-style HTML tags).

        Args:
            hours: length of the "recent period" section, in hours.

        Returns:
            The report as a newline-joined string.
        """
        total_stats = self.calculate_statistics()
        period_stats = self.get_period_statistics(hours)
        active_orders = self.get_active_orders()

        lines = [
            f"📊 <b>模拟交易 {hours} 小时报告</b>",
            "",
            "━━━━━━ 总体情况 ━━━━━━",
            f"总交易数: {total_stats['total_trades']} | 胜率: {total_stats['win_rate']}%",
            f"总盈亏: <code>${total_stats['total_pnl']:+.2f}</code>",
            "",
            f"━━━━━━ 过去 {hours} 小时 ━━━━━━",
            f"新订单: {period_stats['new_orders']} | 已平仓: {period_stats['closed_orders']}",
            f"本期盈亏: <code>${period_stats['period_pnl']:+.2f}</code>",
        ]

        # Current positions and pending orders.
        open_orders = [o for o in active_orders if o.get('status') == 'open']
        pending_orders = [o for o in active_orders if o.get('status') == 'pending']
        if open_orders or pending_orders:
            lines.append("")
            lines.append("━━━━━━ 当前订单 ━━━━━━")
            for order in open_orders[:5]:  # show at most 5 positions
                side_text = "做多" if order.get('side') == 'long' else "做空"
                entry_price = order.get('filled_price') or order.get('entry_price') or 0
                lines.append(f"{order.get('symbol')} {side_text} @ ${self._format_price(entry_price)}")
            if len(open_orders) > 5:
                lines.append(f"... 还有 {len(open_orders) - 5} 个持仓")
            for order in pending_orders[:3]:  # show at most 3 pending orders
                side_text = "做多" if order.get('side') == 'long' else "做空"
                lines.append(f"{order.get('symbol')} {side_text} 挂单 @ ${self._format_price(order.get('entry_price'))}")
            if len(pending_orders) > 3:
                lines.append(f"... 还有 {len(pending_orders) - 3} 个挂单")

        # Per-grade breakdown (grade D never opens positions).
        by_grade = total_stats.get('by_grade', {})
        if by_grade:
            lines.append("")
            lines.append("━━━━━━ 按等级统计 ━━━━━━")
            for grade in ['A', 'B', 'C']:
                if grade in by_grade:
                    g = by_grade[grade]
                    pnl = g['total_pnl']
                    # Sign goes before the currency symbol: "+$12" / "-$12".
                    pnl_sign = "+" if pnl >= 0 else "-"
                    lines.append(f"{grade}级: {g['count']}笔 | 胜率{g['win_rate']}% | {pnl_sign}${abs(pnl):.0f}")

        lines.append("")
        lines.append(f"<i>报告时间: {datetime.now().strftime('%Y-%m-%d %H:%M')}</i>")
        return "\n".join(lines)

    # ------------------------------------------------------------------- reset

    def reset_all_data(self) -> Dict[str, Any]:
        """
        Delete every paper-trading order (active and historical) and clear the cache.

        Returns:
            ``deleted_count`` (rows deleted) and ``active_orders_cleared``
            (cache entries dropped).

        Raises:
            Re-raises any database error after rolling back.
        """
        db = db_service.get_session()
        try:
            total_count = db.query(PaperOrder).count()
            active_count = len(self.active_orders)
            deleted = db.query(PaperOrder).delete(synchronize_session='fetch')
            db.commit()
            self.active_orders.clear()
            logger.info(f"模拟交易数据已重置,删除 {deleted} 条订单(总计 {total_count} 条)")
            return {
                'deleted_count': deleted,
                'active_orders_cleared': active_count
            }
        except Exception as e:
            db.rollback()
            logger.error(f"重置模拟交易数据失败: {e}")
            raise
        finally:
            db.close()
# Process-wide service instance, created lazily on first access.
_paper_trading_service: Optional[PaperTradingService] = None


def get_paper_trading_service() -> PaperTradingService:
    """Return the shared PaperTradingService, constructing it on first call."""
    global _paper_trading_service
    if _paper_trading_service is not None:
        return _paper_trading_service
    _paper_trading_service = PaperTradingService()
    return _paper_trading_service