""" 交易信号数据库服务 """ from typing import Dict, List, Optional, Any from datetime import datetime, timedelta from sqlalchemy.orm import Session from sqlalchemy import desc, and_, or_ from app.models.signal import TradingSignal from app.services.db_service import DatabaseService from app.utils.logger import logger class SignalDatabaseService: """交易信号数据库服务""" def __init__(self): """初始化服务""" self.db_service = DatabaseService() self._ensure_tables() def _ensure_tables(self): """确保表已创建""" try: from app.models.database import Base # 使用 db_service 的 engine Base.metadata.create_all(bind=self.db_service.engine) logger.info("交易信号表已创建") except Exception as e: logger.error(f"创建交易信号表失败: {e}") def add_signal(self, signal_data: Dict[str, Any]) -> Optional[TradingSignal]: """添加信号到数据库""" db = self.db_service.get_session() try: # 清理价格字段 - 移除 $ 符号和逗号 def clean_price(price_value): """清理价格字段,转换为 float""" if price_value is None: return None if isinstance(price_value, (int, float)): return float(price_value) if isinstance(price_value, str): # 移除 $ 符号和逗号 cleaned = price_value.replace('$', '').replace(',', '').strip() if cleaned: try: return float(cleaned) except ValueError: return None return None # 创建信号对象 signal = TradingSignal( signal_type=signal_data.get('signal_type', 'crypto'), symbol=signal_data.get('symbol', ''), action=signal_data.get('action', 'hold'), grade=signal_data.get('grade', 'D'), confidence=signal_data.get('confidence', 0), entry_price=clean_price(signal_data.get('entry_price')), stop_loss=clean_price(signal_data.get('stop_loss')), take_profit=clean_price(signal_data.get('take_profit')), current_price=clean_price(signal_data.get('current_price')), signal_type_detail=signal_data.get('type'), entry_type=signal_data.get('entry_type'), position_size=signal_data.get('position_size'), reason=signal_data.get('reason'), risk_warning=signal_data.get('risk_warning'), analysis_summary=signal_data.get('analysis_summary'), news_sentiment=signal_data.get('news_sentiment'), news_impact=signal_data.get('news_impact'), key_levels=signal_data.get('key_levels'), indicators=signal_data.get('indicators'), notified=True, notification_sent_at=datetime.utcnow() ) db.add(signal) db.commit() db.refresh(signal) logger.info(f"保存信号到数据库: {signal.signal_type} {signal.symbol} {signal.action} {signal.grade}") return signal except Exception as e: db.rollback() logger.error(f"保存信号失败: {e}") return None finally: db.close() def get_crypto_signals( self, limit: int = 50, symbol: Optional[str] = None, days: int = 7 ) -> List[Dict[str, Any]]: """获取加密货币信号""" db = self.db_service.get_session() try: cutoff_time = datetime.utcnow() - timedelta(days=days) query = db.query(TradingSignal).filter( TradingSignal.signal_type == 'crypto', TradingSignal.created_at >= cutoff_time ) if symbol: query = query.filter(TradingSignal.symbol == symbol.upper()) signals = query.order_by(desc(TradingSignal.created_at)).limit(limit).all() return [signal.to_dict() for signal in signals] except Exception as e: logger.error(f"获取加密货币信号失败: {e}") return [] finally: db.close() def get_stock_signals( self, limit: int = 50, symbol: Optional[str] = None, days: int = 7 ) -> List[Dict[str, Any]]: """获取美股信号""" db = self.db_service.get_session() try: cutoff_time = datetime.utcnow() - timedelta(days=days) query = db.query(TradingSignal).filter( TradingSignal.signal_type == 'stock', TradingSignal.created_at >= cutoff_time ) if symbol: query = query.filter(TradingSignal.symbol == symbol.upper()) signals = query.order_by(desc(TradingSignal.created_at)).limit(limit).all() return [signal.to_dict() for signal in signals] except Exception as e: logger.error(f"获取美股信号失败: {e}") return [] finally: db.close() def get_all_signals(self, limit: int = 100, days: int = 7) -> Dict[str, List[Dict[str, Any]]]: """获取所有信号""" db = self.db_service.get_session() try: cutoff_time = datetime.utcnow() - timedelta(days=days) signals = db.query(TradingSignal).filter( TradingSignal.created_at >= cutoff_time ).order_by(desc(TradingSignal.created_at)).limit(limit).all() crypto_signals = [] stock_signals = [] for signal in signals: signal_dict = signal.to_dict() if signal.signal_type == 'crypto': crypto_signals.append(signal_dict) else: stock_signals.append(signal_dict) return { 'crypto': crypto_signals, 'stock': stock_signals } except Exception as e: logger.error(f"获取所有信号失败: {e}") return {'crypto': [], 'stock': []} finally: db.close() def get_latest_signals(self, limit: int = 20, days: int = 7) -> List[Dict[str, Any]]: """获取最新信号(混合)""" db = self.db_service.get_session() try: cutoff_time = datetime.utcnow() - timedelta(days=days) signals = db.query(TradingSignal).filter( TradingSignal.created_at >= cutoff_time ).order_by(desc(TradingSignal.created_at)).limit(limit).all() return [signal.to_dict() for signal in signals] except Exception as e: logger.error(f"获取最新信号失败: {e}") return [] finally: db.close() def get_signal_stats(self, days: int = 7) -> Dict[str, Any]: """获取信号统计""" db = self.db_service.get_session() try: cutoff_time = datetime.utcnow() - timedelta(days=days) # 获取所有信号 all_signals = db.query(TradingSignal).filter( TradingSignal.created_at >= cutoff_time ).all() # 统计加密货币信号 crypto_signals = [s for s in all_signals if s.signal_type == 'crypto'] crypto_buy = sum(1 for s in crypto_signals if s.action == 'buy') crypto_sell = sum(1 for s in crypto_signals if s.action == 'sell') # 统计美股信号 stock_signals = [s for s in all_signals if s.signal_type == 'stock'] stock_buy = sum(1 for s in stock_signals if s.action == 'buy') stock_sell = sum(1 for s in stock_signals if s.action == 'sell') # 按等级统计 grade_stats = {} for signal in all_signals: grade_stats[signal.grade] = grade_stats.get(signal.grade, 0) + 1 # 最近24小时信号 recent_cutoff = datetime.utcnow() - timedelta(hours=24) recent_crypto = sum(1 for s in crypto_signals if s.created_at >= recent_cutoff) recent_stock = sum(1 for s in stock_signals if s.created_at >= recent_cutoff) return { 'crypto': { 'total': len(crypto_signals), 'buy': crypto_buy, 'sell': crypto_sell, 'recent_24h': recent_crypto }, 'stock': { 'total': len(stock_signals), 'buy': stock_buy, 'sell': stock_sell, 'recent_24h': recent_stock }, 'grades': grade_stats, 'total': len(all_signals) } except Exception as e: logger.error(f"获取信号统计失败: {e}") return {} finally: db.close() def get_latest_signal(self, signal_type: str, symbol: str) -> Optional[Dict[str, Any]]: """获取指定交易对的最新信号""" db = self.db_service.get_session() try: signal = db.query(TradingSignal).filter( TradingSignal.signal_type == signal_type, TradingSignal.symbol == symbol.upper() ).order_by(desc(TradingSignal.created_at)).first() if signal: return signal.to_dict() return None except Exception as e: logger.error(f"获取最新信号失败: {e}") return None finally: db.close() def clear_old_signals(self, days: int = 30): """清理旧信号""" db = self.db_service.get_session() try: cutoff_time = datetime.utcnow() - timedelta(days=days) deleted = db.query(TradingSignal).filter( TradingSignal.created_at < cutoff_time ).delete() db.commit() logger.info(f"清理了 {deleted} 条旧信号(超过 {days} 天)") except Exception as e: db.rollback() logger.error(f"清理旧信号失败: {e}") finally: db.close() # 全局单例 _signal_db_service: Optional[SignalDatabaseService] = None def get_signal_db_service() -> SignalDatabaseService: """获取信号数据库服务单例""" global _signal_db_service if _signal_db_service is None: _signal_db_service = SignalDatabaseService() return _signal_db_service