stock-ai-agent/backend/app/services/signal_database_service.py
2026-03-02 22:58:04 +08:00

305 lines
11 KiB
Python

"""
交易信号数据库服务
"""
import math
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any

from sqlalchemy import desc, and_, or_
from sqlalchemy.orm import Session

from app.models.signal import TradingSignal
from app.services.db_service import DatabaseService
from app.utils.logger import logger

class SignalDatabaseService:
    """Persistence service for trading signals.

    Wraps a ``DatabaseService`` and provides helpers to store
    ``TradingSignal`` rows and read them back as dicts via
    ``TradingSignal.to_dict()``. Each public method opens its own session and
    closes it before returning. Failures are logged rather than raised:
    writers return None / 0 and readers return an empty result (or None).
    """

    def __init__(self):
        """Create the underlying ``DatabaseService`` and ensure tables exist."""
        self.db_service = DatabaseService()
        self._ensure_tables()

    def _ensure_tables(self):
        """Create the tables registered on ``app.models.database.Base``.

        ``create_all`` skips tables that already exist (SQLAlchemy's default
        ``checkfirst=True``). Errors are logged and swallowed so constructing
        the service never raises.
        """
        try:
            # Imported lazily (presumably to avoid an import cycle at module
            # load time -- confirm).
            from app.models.database import Base
            # Reuse the engine owned by db_service.
            Base.metadata.create_all(bind=self.db_service.engine)
            logger.info("交易信号表已创建")
        except Exception as e:
            logger.error(f"创建交易信号表失败: {e}")

    # ------------------------------------------------------------------
    # Pure helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _cutoff(days: int) -> datetime:
        """Return the naive UTC timestamp ``days`` days before now.

        Naive UTC matches ``notification_sent_at`` as written by
        ``add_signal``. NOTE(review): assumes ``created_at`` is stored as
        naive UTC as well -- confirm against the model definition.
        """
        return datetime.utcnow() - timedelta(days=days)

    @staticmethod
    def _normalize_symbol(symbol: Any) -> Any:
        """Strip and upper-case a string symbol; other values pass through.

        Used on both write and read so that stored symbols and lookup keys
        always compare in the same canonical (upper-case) form.
        """
        if isinstance(symbol, str):
            return symbol.strip().upper()
        return symbol

    @staticmethod
    def _clean_price(price_value: Any) -> Optional[float]:
        """Convert a raw price field to a finite float, or None.

        Accepts ints/floats and strings such as ``"$1,234.50"`` (``$`` signs
        and comma thousands separators are removed). Returns None for None,
        booleans, other types, empty or unparseable strings, values too large
        for a float, and non-finite results (NaN / +-inf).
        """
        if isinstance(price_value, str):
            price_value = price_value.replace('$', '').replace(',', '').strip()
            if not price_value:
                return None
        elif isinstance(price_value, bool) or not isinstance(price_value, (int, float)):
            # bool is an int subclass, but True/False are never prices.
            return None
        try:
            value = float(price_value)
        except (ValueError, OverflowError):
            # ValueError: unparseable text; OverflowError: int beyond float range.
            return None
        # NaN/inf are never valid prices and break comparisons downstream.
        return value if math.isfinite(value) else None

    @staticmethod
    def _summarize_group(signals: List[Any], recent_cutoff: datetime) -> Dict[str, int]:
        """Count total / buy / sell / recent (``created_at >= recent_cutoff``) signals."""
        return {
            'total': len(signals),
            'buy': sum(1 for s in signals if s.action == 'buy'),
            'sell': sum(1 for s in signals if s.action == 'sell'),
            'recent_24h': sum(1 for s in signals if s.created_at >= recent_cutoff),
        }

    def _recent_query(self, db: "Session", days: int):
        """Return a query over signals created within the last ``days`` days."""
        return db.query(TradingSignal).filter(
            TradingSignal.created_at >= self._cutoff(days)
        )

    # ------------------------------------------------------------------
    # Writes
    # ------------------------------------------------------------------

    def add_signal(self, signal_data: Dict[str, Any]) -> Optional["TradingSignal"]:
        """Persist one trading signal and return the stored row.

        Args:
            signal_data: Raw signal fields. Price fields may be numbers or
                strings like ``"$1,234.5"``. The rationale is read from
                ``reasoning`` (LLM output), falling back to ``reason``. The
                detail type is read from ``type``, falling back to ``timeframe``.

        Returns:
            The committed and refreshed ``TradingSignal``, or None if saving
            failed (the error is logged and the transaction rolled back). The
            session is closed before returning, so the instance is detached;
            NOTE(review): lazy-loaded attributes would not be available on it.
        """
        db = self.db_service.get_session()
        try:
            signal = TradingSignal(
                signal_type=signal_data.get('signal_type', 'crypto'),
                # Normalized so the upper-cased lookups in the readers match.
                symbol=self._normalize_symbol(signal_data.get('symbol', '')),
                action=signal_data.get('action', 'hold'),
                grade=signal_data.get('grade', 'D'),
                confidence=signal_data.get('confidence', 0),
                entry_price=self._clean_price(signal_data.get('entry_price')),
                stop_loss=self._clean_price(signal_data.get('stop_loss')),
                take_profit=self._clean_price(signal_data.get('take_profit')),
                current_price=self._clean_price(signal_data.get('current_price')),
                signal_type_detail=signal_data.get('type') or signal_data.get('timeframe'),
                entry_type=signal_data.get('entry_type'),
                position_size=signal_data.get('position_size'),
                reason=signal_data.get('reasoning') or signal_data.get('reason'),
                risk_warning=signal_data.get('risk_warning'),
                analysis_summary=signal_data.get('analysis_summary'),
                news_sentiment=signal_data.get('news_sentiment'),
                news_impact=signal_data.get('news_impact'),
                key_levels=signal_data.get('key_levels'),
                indicators=signal_data.get('indicators'),
                # Marked as notified at insert time; presumably the caller has
                # already sent the notification -- confirm.
                notified=True,
                notification_sent_at=datetime.utcnow(),
            )
            db.add(signal)
            db.commit()
            db.refresh(signal)
            logger.info(f"保存信号到数据库: {signal.signal_type} {signal.symbol} {signal.action} {signal.grade}")
            return signal
        except Exception as e:
            db.rollback()
            logger.error(f"保存信号失败: {e}")
            return None
        finally:
            db.close()

    def clear_old_signals(self, days: int = 30) -> int:
        """Delete signals created more than ``days`` days ago.

        Returns:
            The number of deleted rows, or 0 if the delete failed (the error
            is logged and the transaction rolled back).
        """
        db = self.db_service.get_session()
        try:
            deleted = db.query(TradingSignal).filter(
                TradingSignal.created_at < self._cutoff(days)
            ).delete()
            db.commit()
            logger.info(f"清理了 {deleted} 条旧信号(超过 {days} 天)")
            return deleted
        except Exception as e:
            db.rollback()
            logger.error(f"清理旧信号失败: {e}")
            return 0
        finally:
            db.close()

    # ------------------------------------------------------------------
    # Reads
    # ------------------------------------------------------------------

    def _get_typed_signals(
        self,
        signal_type: str,
        limit: int,
        symbol: Optional[str],
        days: int,
        error_message: str,
    ) -> List[Dict[str, Any]]:
        """Return the newest ``signal_type`` signals from the last ``days`` days.

        Shared by the crypto and stock getters. A truthy ``symbol`` is
        normalized with ``_normalize_symbol`` before filtering. On failure,
        logs ``error_message`` followed by the exception and returns ``[]``.
        """
        db = self.db_service.get_session()
        try:
            query = self._recent_query(db, days).filter(
                TradingSignal.signal_type == signal_type
            )
            if symbol:
                query = query.filter(
                    TradingSignal.symbol == self._normalize_symbol(symbol)
                )
            signals = query.order_by(desc(TradingSignal.created_at)).limit(limit).all()
            return [signal.to_dict() for signal in signals]
        except Exception as e:
            logger.error(f"{error_message}: {e}")
            return []
        finally:
            db.close()

    def get_crypto_signals(
        self,
        limit: int = 50,
        symbol: Optional[str] = None,
        days: int = 7
    ) -> List[Dict[str, Any]]:
        """Return up to ``limit`` crypto signals from the last ``days`` days, newest first.

        ``symbol``, if given, is stripped and upper-cased before comparison.
        Returns ``[]`` on error.
        """
        return self._get_typed_signals('crypto', limit, symbol, days, "获取加密货币信号失败")

    def get_stock_signals(
        self,
        limit: int = 50,
        symbol: Optional[str] = None,
        days: int = 7
    ) -> List[Dict[str, Any]]:
        """Return up to ``limit`` stock signals from the last ``days`` days, newest first.

        ``symbol``, if given, is stripped and upper-cased before comparison.
        Returns ``[]`` on error.
        """
        return self._get_typed_signals('stock', limit, symbol, days, "获取美股信号失败")

    def get_all_signals(self, limit: int = 100, days: int = 7) -> Dict[str, List[Dict[str, Any]]]:
        """Return recent signals split into ``'crypto'`` and ``'stock'`` lists.

        ``limit`` applies to the combined result, newest first. Any signal
        whose type is not ``'crypto'`` is placed in the ``'stock'`` list.
        Returns two empty lists on error.
        """
        db = self.db_service.get_session()
        try:
            signals = (
                self._recent_query(db, days)
                .order_by(desc(TradingSignal.created_at))
                .limit(limit)
                .all()
            )
            grouped: Dict[str, List[Dict[str, Any]]] = {'crypto': [], 'stock': []}
            for signal in signals:
                bucket = 'crypto' if signal.signal_type == 'crypto' else 'stock'
                grouped[bucket].append(signal.to_dict())
            return grouped
        except Exception as e:
            logger.error(f"获取所有信号失败: {e}")
            return {'crypto': [], 'stock': []}
        finally:
            db.close()

    def get_latest_signals(self, limit: int = 20, days: int = 7) -> List[Dict[str, Any]]:
        """Return up to ``limit`` recent signals of any type, newest first ([] on error)."""
        db = self.db_service.get_session()
        try:
            signals = (
                self._recent_query(db, days)
                .order_by(desc(TradingSignal.created_at))
                .limit(limit)
                .all()
            )
            return [signal.to_dict() for signal in signals]
        except Exception as e:
            logger.error(f"获取最新信号失败: {e}")
            return []
        finally:
            db.close()

    def get_signal_stats(self, days: int = 7) -> Dict[str, Any]:
        """Summarize signals created in the last ``days`` days.

        Returns:
            ``{'crypto': {...}, 'stock': {...}, 'grades': {grade: count},
            'total': n}``. Each per-type dict has ``total``, ``buy``, ``sell``
            and ``recent_24h`` counts. ``'total'`` counts every signal in the
            window, including types other than crypto/stock. Returns ``{}``
            on error.
        """
        db = self.db_service.get_session()
        try:
            all_signals = self._recent_query(db, days).all()
            crypto_signals = [s for s in all_signals if s.signal_type == 'crypto']
            stock_signals = [s for s in all_signals if s.signal_type == 'stock']

            grade_stats: Dict[Any, int] = {}
            for signal in all_signals:
                grade_stats[signal.grade] = grade_stats.get(signal.grade, 0) + 1

            recent_cutoff = datetime.utcnow() - timedelta(hours=24)
            return {
                'crypto': self._summarize_group(crypto_signals, recent_cutoff),
                'stock': self._summarize_group(stock_signals, recent_cutoff),
                'grades': grade_stats,
                'total': len(all_signals),
            }
        except Exception as e:
            logger.error(f"获取信号统计失败: {e}")
            return {}
        finally:
            db.close()

    def get_latest_signal(self, signal_type: str, symbol: str) -> Optional[Dict[str, Any]]:
        """Return the newest signal for ``signal_type``/``symbol`` as a dict, or None.

        ``symbol`` is stripped and upper-cased before comparison. Returns None
        when ``symbol`` is None, when nothing matches, or on error.
        """
        if symbol is None:
            return None
        db = self.db_service.get_session()
        try:
            signal = db.query(TradingSignal).filter(
                TradingSignal.signal_type == signal_type,
                TradingSignal.symbol == self._normalize_symbol(symbol)
            ).order_by(desc(TradingSignal.created_at)).first()
            return signal.to_dict() if signal else None
        except Exception as e:
            logger.error(f"获取最新信号失败: {e}")
            return None
        finally:
            db.close()
# Global singleton: created lazily by get_signal_db_service().
_signal_db_service: Optional[SignalDatabaseService] = None


def get_signal_db_service() -> SignalDatabaseService:
    """Return the shared SignalDatabaseService, constructing it on first call.

    The instance is cached in the module-level ``_signal_db_service``, and
    later calls return that same object. No lock is taken here, so concurrent
    first calls could each construct an instance.
    """
    global _signal_db_service
    if _signal_db_service is not None:
        return _signal_db_service
    _signal_db_service = SignalDatabaseService()
    return _signal_db_service