This commit is contained in:
aaron 2026-02-19 21:48:23 +08:00
parent e155274828
commit 4f4df30a37

View File

@ -7,7 +7,7 @@ from sqlalchemy.orm import Session
from sqlalchemy import desc, and_, or_
from app.models.signal import TradingSignal
from app.models.database import SessionLocal, engine, Base
from app.services.db_service import DatabaseService
from app.utils.logger import logger
@ -16,28 +16,21 @@ class SignalDatabaseService:
def __init__(self):
"""初始化服务"""
self.db_service = DatabaseService()
self._ensure_tables()
def _ensure_tables(self):
"""确保表已创建"""
try:
from app.models.database import Base, engine
Base.metadata.create_all(bind=engine)
logger.info("交易信号表已创建")
except Exception as e:
logger.error(f"创建交易信号表失败: {e}")
def _get_db(self) -> Session:
"""获取数据库会话"""
db = SessionLocal()
try:
return db
except Exception as e:
logger.error(f"获取数据库会话失败: {e}")
raise
def add_signal(self, signal_data: Dict[str, Any]) -> Optional[TradingSignal]:
"""添加信号到数据库"""
db = self._get_db()
db = self.db_service.get_session()
try:
# 创建信号对象
signal = TradingSignal(
@ -85,7 +78,7 @@ class SignalDatabaseService:
days: int = 7
) -> List[Dict[str, Any]]:
"""获取加密货币信号"""
db = self._get_db()
db = self.db_service.get_session()
try:
cutoff_time = datetime.utcnow() - timedelta(days=days)
@ -114,7 +107,7 @@ class SignalDatabaseService:
days: int = 7
) -> List[Dict[str, Any]]:
"""获取美股信号"""
db = self._get_db()
db = self.db_service.get_session()
try:
cutoff_time = datetime.utcnow() - timedelta(days=days)
@ -138,7 +131,7 @@ class SignalDatabaseService:
def get_all_signals(self, limit: int = 100, days: int = 7) -> Dict[str, List[Dict[str, Any]]]:
"""获取所有信号"""
db = self._get_db()
db = self.db_service.get_session()
try:
cutoff_time = datetime.utcnow() - timedelta(days=days)
@ -169,7 +162,7 @@ class SignalDatabaseService:
def get_latest_signals(self, limit: int = 20, days: int = 7) -> List[Dict[str, Any]]:
"""获取最新信号(混合)"""
db = self._get_db()
db = self.db_service.get_session()
try:
cutoff_time = datetime.utcnow() - timedelta(days=days)
@ -187,7 +180,7 @@ class SignalDatabaseService:
def get_signal_stats(self, days: int = 7) -> Dict[str, Any]:
"""获取信号统计"""
db = self._get_db()
db = self.db_service.get_session()
try:
cutoff_time = datetime.utcnow() - timedelta(days=days)
@ -241,7 +234,7 @@ class SignalDatabaseService:
def get_latest_signal(self, signal_type: str, symbol: str) -> Optional[Dict[str, Any]]:
"""获取指定交易对的最新信号"""
db = self._get_db()
db = self.db_service.get_session()
try:
signal = db.query(TradingSignal).filter(
TradingSignal.signal_type == signal_type,
@ -260,7 +253,7 @@ class SignalDatabaseService:
def clear_old_signals(self, days: int = 30):
"""清理旧信号"""
db = self._get_db()
db = self.db_service.get_session()
try:
cutoff_time = datetime.utcnow() - timedelta(days=days)
@ -288,3 +281,4 @@ def get_signal_db_service() -> SignalDatabaseService:
if _signal_db_service is None:
_signal_db_service = SignalDatabaseService()
return _signal_db_service