286 lines
9.8 KiB
Python
286 lines
9.8 KiB
Python
"""
|
|
交易信号数据库服务
|
|
"""
|
|
from typing import Dict, List, Optional, Any
|
|
from datetime import datetime, timedelta
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy import desc, and_, or_
|
|
|
|
from app.models.signal import TradingSignal
|
|
from app.services.db_service import DatabaseService
|
|
from app.utils.logger import logger
|
|
|
|
|
|
class SignalDatabaseService:
    """Persistence and query service for trading signals (``TradingSignal`` rows).

    Every public method opens its own session via ``DatabaseService.get_session()``,
    closes it before returning, and turns errors raised while working with the
    session into a logged message plus an empty/neutral return value.
    """

    def __init__(self):
        """Create the underlying DatabaseService and make sure the tables exist."""
        self.db_service = DatabaseService()
        self._ensure_tables()

    def _ensure_tables(self):
        """Run ``Base.metadata.create_all`` on the db_service engine (best effort).

        Failures are logged and swallowed so that constructing the service does
        not raise because of table creation.
        """
        try:
            from app.models.database import Base

            # Reuse the engine owned by db_service.
            Base.metadata.create_all(bind=self.db_service.engine)
            logger.info("交易信号表已创建")
        except Exception as e:
            logger.error(f"创建交易信号表失败: {e}")

    @staticmethod
    def _normalize_symbol(symbol: Any) -> Any:
        """Return *symbol* upper-cased when it is a string, unchanged otherwise.

        All symbol lookups in this class compare against ``symbol.upper()``,
        so symbols have to be stored upper-case to be found again.
        """
        return symbol.upper() if isinstance(symbol, str) else symbol

    @staticmethod
    def _recent_query(
        db: "Session",
        days: int,
        signal_type: Optional[str] = None,
        symbol: Optional[str] = None,
    ):
        """Build a newest-first query for signals created within the last *days* days.

        Args:
            db: open session to query with.
            days: size of the look-back window.
            signal_type: if not None, keep only rows with this exact signal_type.
            symbol: if truthy, keep only rows whose symbol equals ``symbol.upper()``.

        Returns:
            An unexecuted query ordered by ``created_at`` descending.
        """
        # NOTE(review): naive UTC timestamp; assumes created_at is stored as naive UTC — confirm against the model.
        cutoff_time = datetime.utcnow() - timedelta(days=days)
        query = db.query(TradingSignal).filter(TradingSignal.created_at >= cutoff_time)
        if signal_type is not None:
            query = query.filter(TradingSignal.signal_type == signal_type)
        if symbol:
            query = query.filter(TradingSignal.symbol == symbol.upper())
        return query.order_by(desc(TradingSignal.created_at))

    def _list_recent_signals(
        self,
        error_label: str,
        limit: int,
        days: int,
        signal_type: Optional[str] = None,
        symbol: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """Fetch up to *limit* recent signals (see ``_recent_query``) as dicts.

        On any error while querying, logs ``"<error_label>: <exception>"`` and
        returns an empty list.
        """
        db = self.db_service.get_session()
        try:
            signals = self._recent_query(db, days, signal_type, symbol).limit(limit).all()
            # Convert while the session is still open.
            return [signal.to_dict() for signal in signals]
        except Exception as e:
            logger.error(f"{error_label}: {e}")
            return []
        finally:
            db.close()

    @staticmethod
    def _build_stats(signals: List[Any], recent_cutoff: datetime) -> Dict[str, Any]:
        """Aggregate per-type buy/sell/recent counts and per-grade counts.

        Args:
            signals: objects exposing ``signal_type``, ``action``, ``grade`` and
                ``created_at`` attributes.
            recent_cutoff: rows with ``created_at >= recent_cutoff`` count as recent.

        Returns:
            ``{'crypto': {...}, 'stock': {...}, 'grades': {grade: count}, 'total': n}``,
            where each per-type dict has ``total``, ``buy``, ``sell`` and
            ``recent_24h`` counts.
        """
        def summarize(group: List[Any]) -> Dict[str, int]:
            return {
                'total': len(group),
                'buy': sum(1 for s in group if s.action == 'buy'),
                'sell': sum(1 for s in group if s.action == 'sell'),
                # A missing timestamp is treated as "not recent" rather than
                # letting a None comparison abort the whole aggregation.
                'recent_24h': sum(
                    1 for s in group
                    if s.created_at is not None and s.created_at >= recent_cutoff
                ),
            }

        crypto_signals = [s for s in signals if s.signal_type == 'crypto']
        stock_signals = [s for s in signals if s.signal_type == 'stock']

        grade_stats: Dict[Any, int] = {}
        for signal in signals:
            grade_stats[signal.grade] = grade_stats.get(signal.grade, 0) + 1

        return {
            'crypto': summarize(crypto_signals),
            'stock': summarize(stock_signals),
            'grades': grade_stats,
            'total': len(signals),
        }

    def add_signal(self, signal_data: Dict[str, Any]) -> Optional["TradingSignal"]:
        """Insert one signal built from *signal_data* and return the saved row.

        Args:
            signal_data: mapping of signal fields. Missing keys fall back to
                ``signal_type='crypto'``, ``symbol=''``, ``action='hold'``,
                ``grade='D'``, ``confidence=0``; every other field defaults to
                ``None``. The model's ``signal_type_detail`` column is read from
                the ``'type'`` key.

        Returns:
            The persisted ``TradingSignal`` (refreshed after commit; the session
            is closed before returning), or ``None`` if saving failed, in which
            case the transaction is rolled back.
        """
        db = self.db_service.get_session()
        try:
            signal = TradingSignal(
                signal_type=signal_data.get('signal_type', 'crypto'),
                # Stored upper-case so the upper-casing lookups in this class
                # (get_crypto_signals / get_stock_signals / get_latest_signal)
                # can find it again.
                symbol=self._normalize_symbol(signal_data.get('symbol', '')),
                action=signal_data.get('action', 'hold'),
                grade=signal_data.get('grade', 'D'),
                confidence=signal_data.get('confidence', 0),
                entry_price=signal_data.get('entry_price'),
                stop_loss=signal_data.get('stop_loss'),
                take_profit=signal_data.get('take_profit'),
                current_price=signal_data.get('current_price'),
                signal_type_detail=signal_data.get('type'),
                entry_type=signal_data.get('entry_type'),
                position_size=signal_data.get('position_size'),
                reason=signal_data.get('reason'),
                risk_warning=signal_data.get('risk_warning'),
                analysis_summary=signal_data.get('analysis_summary'),
                news_sentiment=signal_data.get('news_sentiment'),
                news_impact=signal_data.get('news_impact'),
                key_levels=signal_data.get('key_levels'),
                indicators=signal_data.get('indicators'),
                # Rows saved here are always marked as already notified.
                notified=True,
                notification_sent_at=datetime.utcnow(),
            )

            db.add(signal)
            db.commit()
            # Reload server-side values (e.g. generated id/timestamps) before
            # the session is closed in `finally`.
            db.refresh(signal)

            logger.info(f"保存信号到数据库: {signal.signal_type} {signal.symbol} {signal.action} {signal.grade}")
            return signal
        except Exception as e:
            db.rollback()
            logger.error(f"保存信号失败: {e}")
            return None
        finally:
            db.close()

    def get_crypto_signals(
        self,
        limit: int = 50,
        symbol: Optional[str] = None,
        days: int = 7,
    ) -> List[Dict[str, Any]]:
        """Return up to *limit* crypto signals from the last *days* days, newest first.

        *symbol*, when truthy, is matched upper-cased. Returns ``[]`` on error.
        """
        return self._list_recent_signals("获取加密货币信号失败", limit, days, 'crypto', symbol)

    def get_stock_signals(
        self,
        limit: int = 50,
        symbol: Optional[str] = None,
        days: int = 7,
    ) -> List[Dict[str, Any]]:
        """Return up to *limit* stock signals from the last *days* days, newest first.

        *symbol*, when truthy, is matched upper-cased. Returns ``[]`` on error.
        """
        return self._list_recent_signals("获取美股信号失败", limit, days, 'stock', symbol)

    def get_all_signals(self, limit: int = 100, days: int = 7) -> Dict[str, List[Dict[str, Any]]]:
        """Return the newest *limit* signals from the last *days* days, split by type.

        The limit applies to the combined set before splitting. Rows whose
        signal_type is not ``'crypto'`` are reported under ``'stock'``.
        Returns ``{'crypto': [], 'stock': []}`` on error.
        """
        db = self.db_service.get_session()
        try:
            signals = self._recent_query(db, days).limit(limit).all()

            grouped: Dict[str, List[Dict[str, Any]]] = {'crypto': [], 'stock': []}
            for signal in signals:
                bucket = 'crypto' if signal.signal_type == 'crypto' else 'stock'
                grouped[bucket].append(signal.to_dict())
            return grouped
        except Exception as e:
            logger.error(f"获取所有信号失败: {e}")
            return {'crypto': [], 'stock': []}
        finally:
            db.close()

    def get_latest_signals(self, limit: int = 20, days: int = 7) -> List[Dict[str, Any]]:
        """Return up to *limit* signals of any type from the last *days* days, newest first.

        Returns ``[]`` on error.
        """
        return self._list_recent_signals("获取最新信号失败", limit, days)

    def get_signal_stats(self, days: int = 7) -> Dict[str, Any]:
        """Return aggregate counts for signals created in the last *days* days.

        The result has per-type (crypto/stock) total, buy, sell and last-24h
        counts, per-grade counts, and the overall total. Returns ``{}`` on error.
        """
        db = self.db_service.get_session()
        try:
            cutoff_time = datetime.utcnow() - timedelta(days=days)
            all_signals = db.query(TradingSignal).filter(
                TradingSignal.created_at >= cutoff_time
            ).all()

            recent_cutoff = datetime.utcnow() - timedelta(hours=24)
            return self._build_stats(all_signals, recent_cutoff)
        except Exception as e:
            logger.error(f"获取信号统计失败: {e}")
            return {}
        finally:
            db.close()

    def get_latest_signal(self, signal_type: str, symbol: str) -> Optional[Dict[str, Any]]:
        """Return the newest signal for (*signal_type*, ``symbol.upper()``) as a dict.

        Returns ``None`` when there is no match or when an error occurs.
        """
        db = self.db_service.get_session()
        try:
            signal = db.query(TradingSignal).filter(
                TradingSignal.signal_type == signal_type,
                TradingSignal.symbol == symbol.upper(),
            ).order_by(desc(TradingSignal.created_at)).first()

            return signal.to_dict() if signal else None
        except Exception as e:
            logger.error(f"获取最新信号失败: {e}")
            return None
        finally:
            db.close()

    def clear_old_signals(self, days: int = 30) -> int:
        """Delete signals created more than *days* days ago.

        Returns:
            The row count reported by the bulk delete, or 0 if the delete
            failed (in which case the transaction is rolled back).
        """
        db = self.db_service.get_session()
        try:
            cutoff_time = datetime.utcnow() - timedelta(days=days)

            deleted = db.query(TradingSignal).filter(
                TradingSignal.created_at < cutoff_time
            ).delete()

            db.commit()
            logger.info(f"清理了 {deleted} 条旧信号(超过 {days} 天)")
            return deleted
        except Exception as e:
            db.rollback()
            logger.error(f"清理旧信号失败: {e}")
            return 0
        finally:
            db.close()
|
|
|
|
|
|
# Lazily created, process-wide instance shared by get_signal_db_service().
_signal_db_service: Optional[SignalDatabaseService] = None


def get_signal_db_service() -> SignalDatabaseService:
    """Return the shared SignalDatabaseService, constructing it on first call.

    No locking is done here; the instance is simply cached in the module-level
    ``_signal_db_service`` variable.
    """
    global _signal_db_service
    if _signal_db_service is not None:
        return _signal_db_service
    _signal_db_service = SignalDatabaseService()
    return _signal_db_service
|
|
|