u
This commit is contained in:
parent
e155274828
commit
4f4df30a37
@ -7,7 +7,7 @@ from sqlalchemy.orm import Session
|
||||
from sqlalchemy import desc, and_, or_
|
||||
|
||||
from app.models.signal import TradingSignal
|
||||
from app.models.database import SessionLocal, engine, Base
|
||||
from app.services.db_service import DatabaseService
|
||||
from app.utils.logger import logger
|
||||
|
||||
|
||||
@ -16,28 +16,21 @@ class SignalDatabaseService:
|
||||
|
||||
def __init__(self):
|
||||
"""初始化服务"""
|
||||
self.db_service = DatabaseService()
|
||||
self._ensure_tables()
|
||||
|
||||
def _ensure_tables(self):
|
||||
"""确保表已创建"""
|
||||
try:
|
||||
from app.models.database import Base, engine
|
||||
Base.metadata.create_all(bind=engine)
|
||||
logger.info("交易信号表已创建")
|
||||
except Exception as e:
|
||||
logger.error(f"创建交易信号表失败: {e}")
|
||||
|
||||
def _get_db(self) -> Session:
|
||||
"""获取数据库会话"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
return db
|
||||
except Exception as e:
|
||||
logger.error(f"获取数据库会话失败: {e}")
|
||||
raise
|
||||
|
||||
def add_signal(self, signal_data: Dict[str, Any]) -> Optional[TradingSignal]:
|
||||
"""添加信号到数据库"""
|
||||
db = self._get_db()
|
||||
db = self.db_service.get_session()
|
||||
try:
|
||||
# 创建信号对象
|
||||
signal = TradingSignal(
|
||||
@ -85,7 +78,7 @@ class SignalDatabaseService:
|
||||
days: int = 7
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取加密货币信号"""
|
||||
db = self._get_db()
|
||||
db = self.db_service.get_session()
|
||||
try:
|
||||
cutoff_time = datetime.utcnow() - timedelta(days=days)
|
||||
|
||||
@ -114,7 +107,7 @@ class SignalDatabaseService:
|
||||
days: int = 7
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取美股信号"""
|
||||
db = self._get_db()
|
||||
db = self.db_service.get_session()
|
||||
try:
|
||||
cutoff_time = datetime.utcnow() - timedelta(days=days)
|
||||
|
||||
@ -138,7 +131,7 @@ class SignalDatabaseService:
|
||||
|
||||
def get_all_signals(self, limit: int = 100, days: int = 7) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""获取所有信号"""
|
||||
db = self._get_db()
|
||||
db = self.db_service.get_session()
|
||||
try:
|
||||
cutoff_time = datetime.utcnow() - timedelta(days=days)
|
||||
|
||||
@ -169,7 +162,7 @@ class SignalDatabaseService:
|
||||
|
||||
def get_latest_signals(self, limit: int = 20, days: int = 7) -> List[Dict[str, Any]]:
|
||||
"""获取最新信号(混合)"""
|
||||
db = self._get_db()
|
||||
db = self.db_service.get_session()
|
||||
try:
|
||||
cutoff_time = datetime.utcnow() - timedelta(days=days)
|
||||
|
||||
@ -187,7 +180,7 @@ class SignalDatabaseService:
|
||||
|
||||
def get_signal_stats(self, days: int = 7) -> Dict[str, Any]:
|
||||
"""获取信号统计"""
|
||||
db = self._get_db()
|
||||
db = self.db_service.get_session()
|
||||
try:
|
||||
cutoff_time = datetime.utcnow() - timedelta(days=days)
|
||||
|
||||
@ -241,7 +234,7 @@ class SignalDatabaseService:
|
||||
|
||||
def get_latest_signal(self, signal_type: str, symbol: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取指定交易对的最新信号"""
|
||||
db = self._get_db()
|
||||
db = self.db_service.get_session()
|
||||
try:
|
||||
signal = db.query(TradingSignal).filter(
|
||||
TradingSignal.signal_type == signal_type,
|
||||
@ -260,7 +253,7 @@ class SignalDatabaseService:
|
||||
|
||||
def clear_old_signals(self, days: int = 30):
|
||||
"""清理旧信号"""
|
||||
db = self._get_db()
|
||||
db = self.db_service.get_session()
|
||||
try:
|
||||
cutoff_time = datetime.utcnow() - timedelta(days=days)
|
||||
|
||||
@ -288,3 +281,4 @@ def get_signal_db_service() -> SignalDatabaseService:
|
||||
if _signal_db_service is None:
|
||||
_signal_db_service = SignalDatabaseService()
|
||||
return _signal_db_service
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user