stock-ai-agent/backend/app/services/price_monitor_service.py
2026-02-06 23:35:16 +08:00

364 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
价格监控服务 - 使用 Binance WebSocket 实时监控价格
"""
import threading
import time
import sys
import os
import logging
from typing import Dict, List, Callable, Optional, Set
from app.utils.logger import logger
from app.config import get_settings
# 抑制 binance 库的 WebSocket 错误日志(正确的 logger 名称)
logging.getLogger('binance.ws.threaded_stream').setLevel(logging.CRITICAL)
logging.getLogger('binance.ws.reconnecting_websocket').setLevel(logging.CRITICAL)
logging.getLogger('binance.ws').setLevel(logging.WARNING) # 只显示警告及以上
class SuppressOutput:
"""临时抑制 stdout/stderr 输出"""
def __init__(self, suppress_stderr=True, suppress_stdout=False):
self.suppress_stderr = suppress_stderr
self.suppress_stdout = suppress_stdout
self._stderr = None
self._stdout = None
self._devnull = None
def __enter__(self):
self._devnull = open(os.devnull, 'w')
if self.suppress_stderr:
self._stderr = sys.stderr
sys.stderr = self._devnull
if self.suppress_stdout:
self._stdout = sys.stdout
sys.stdout = self._devnull
return self
def __exit__(self, *args):
if self._stderr:
sys.stderr = self._stderr
if self._stdout:
sys.stdout = self._stdout
if self._devnull:
self._devnull.close()
class PriceMonitorService:
"""实时价格监控服务"""
def __init__(self):
"""初始化价格监控服务"""
self.settings = get_settings()
self.twm = None
self.running = False
self.subscribed_symbols: Dict[str, str] = {} # symbol -> stream_name
self.price_callbacks: List[Callable[[str, float], None]] = []
self.latest_prices: Dict[str, float] = {}
self._lock = threading.Lock()
self._pending_symbols: List[str] = [] # 待订阅的交易对
self._reconnecting = False # 是否正在重连
self._desired_symbols: Set[str] = set() # 期望订阅的交易对(用于重连)
self._stop_requested = False # 是否请求停止(区分主动停止和意外断开)
self._last_message_time: Dict[str, float] = {} # 上次收到消息的时间
self._health_check_thread = None
logger.info("价格监控服务初始化完成")
def start(self):
"""启动 WebSocket 管理器(在独立线程中)"""
if self.running:
logger.warning("价格监控服务已在运行")
return
self._stop_requested = False
def _start_in_thread():
try:
# 延迟导入,避免在模块加载时就创建事件循环
from binance import ThreadedWebsocketManager
self.twm = ThreadedWebsocketManager(
api_key=self.settings.binance_api_key or "",
api_secret=self.settings.binance_api_secret or ""
)
self.twm.start()
self.running = True
self._reconnecting = False
logger.info("WebSocket 管理器已启动")
# 等待 WebSocket 完全启动
time.sleep(1)
# 订阅待处理的交易对
for symbol in self._pending_symbols:
self._do_subscribe(symbol)
self._pending_symbols.clear()
# 重连时恢复之前的订阅
for symbol in self._desired_symbols:
if symbol not in self.subscribed_symbols:
self._do_subscribe(symbol)
# 启动健康检查
self._start_health_check()
except Exception as e:
logger.error(f"启动 WebSocket 管理器失败: {e}")
import traceback
logger.error(traceback.format_exc())
# 启动失败,尝试重连
if not self._stop_requested:
self._schedule_reconnect()
# 在独立线程中启动
thread = threading.Thread(target=_start_in_thread, daemon=True)
thread.start()
def _start_health_check(self):
"""启动健康检查线程"""
def _check_health():
while self.running and not self._stop_requested:
time.sleep(30) # 每30秒检查一次
if not self.running or self._stop_requested:
break
# 检查是否有超过60秒没收到消息的交易对
now = time.time()
for symbol in list(self._desired_symbols):
last_time = self._last_message_time.get(symbol, now)
if now - last_time > 60:
logger.warning(f"{symbol} 超过60秒未收到数据触发重连")
self._schedule_reconnect()
break
self._health_check_thread = threading.Thread(target=_check_health, daemon=True)
self._health_check_thread.start()
def stop(self):
"""停止 WebSocket 管理器"""
# 标记为主动停止
self._stop_requested = True
if not self.running:
return
# 先标记为停止,防止回调继续处理
self.running = False
try:
# 抑制关闭时的错误输出binance 库用 print 输出错误)
with SuppressOutput(suppress_stderr=True, suppress_stdout=True):
# 先停止所有 socket 订阅
if self.twm:
for _, stream_name in list(self.subscribed_symbols.items()):
try:
self.twm.stop_socket(stream_name)
except:
pass
# 等待一小段时间让 socket 关闭
time.sleep(0.5)
# 然后停止管理器
try:
self.twm.stop()
except:
pass
self.subscribed_symbols.clear()
self._desired_symbols.clear()
self._last_message_time.clear()
logger.info("WebSocket 管理器已停止")
except Exception as e:
# 忽略关闭时的错误
pass
def _schedule_reconnect(self, delay: int = 5):
"""安排重连"""
if self._stop_requested or self._reconnecting:
return
self._reconnecting = True
logger.warning(f"WebSocket 连接断开,{delay} 秒后尝试重连...")
def _reconnect():
time.sleep(delay)
if not self._stop_requested:
self._do_reconnect()
thread = threading.Thread(target=_reconnect, daemon=True)
thread.start()
def _do_reconnect(self):
"""执行重连"""
if self._stop_requested:
return
logger.info("正在重新连接 WebSocket...")
# 清理旧连接(抑制错误输出)
with SuppressOutput(suppress_stderr=True, suppress_stdout=True):
try:
if self.twm:
self.twm.stop()
except:
pass
self.twm = None
self.running = False
self.subscribed_symbols.clear()
# 重新启动
self.start()
def subscribe_symbol(self, symbol: str):
"""
订阅交易对的实时价格
Args:
symbol: 交易对,如 "BTCUSDT"
"""
symbol = symbol.upper()
# 记录期望订阅的交易对(用于重连恢复)
self._desired_symbols.add(symbol)
if symbol in self.subscribed_symbols:
logger.debug(f"已订阅 {symbol}")
return
if not self.running:
# 如果还没启动,先加入待订阅列表,然后启动
if symbol not in self._pending_symbols:
self._pending_symbols.append(symbol)
self.start()
return
self._do_subscribe(symbol)
def _do_subscribe(self, symbol: str):
"""实际执行订阅"""
if not self.twm or not self.running:
return
try:
stream_name = self.twm.start_symbol_ticker_socket(
callback=self._handle_price_update,
symbol=symbol
)
self.subscribed_symbols[symbol] = stream_name
self._last_message_time[symbol] = time.time()
logger.info(f"已订阅 {symbol} 价格更新")
except Exception as e:
logger.error(f"订阅 {symbol} 失败: {e}")
def unsubscribe_symbol(self, symbol: str):
"""取消订阅交易对"""
symbol = symbol.upper()
if symbol not in self.subscribed_symbols:
return
try:
stream_name = self.subscribed_symbols[symbol]
self.twm.stop_socket(stream_name)
del self.subscribed_symbols[symbol]
self._desired_symbols.discard(symbol)
logger.info(f"已取消订阅 {symbol}")
except Exception as e:
logger.error(f"取消订阅 {symbol} 失败: {e}")
def add_price_callback(self, callback: Callable[[str, float], None]):
"""
添加价格更新回调函数
Args:
callback: 回调函数,签名为 (symbol: str, price: float) -> None
"""
with self._lock:
if callback not in self.price_callbacks:
self.price_callbacks.append(callback)
def remove_price_callback(self, callback: Callable):
"""移除价格回调函数"""
with self._lock:
if callback in self.price_callbacks:
self.price_callbacks.remove(callback)
def _handle_price_update(self, msg: Dict):
"""处理 WebSocket 价格更新消息"""
# 如果服务已停止或正在重连,忽略消息
if not self.running or self._reconnecting or self._stop_requested:
return
try:
# 检查错误消息
if msg.get('e') == 'error':
error_type = msg.get('type', '')
error_msg = str(msg.get('m', ''))
# 这些错误通常是正常的连接关闭,不需要记录
ignored_errors = ['ReadLoopClosed', 'ConnectionClosed', 'WebSocketClosed', 'read loop']
if error_type in ignored_errors or any(e.lower() in error_msg.lower() for e in ignored_errors):
# 如果不是主动停止,触发重连
if not self._stop_requested and self.running:
self.running = False
self._schedule_reconnect()
return
# 其他错误记录日志(但不刷屏)
if self.running and not self._stop_requested:
logger.warning(f"WebSocket 消息: {msg}")
return
symbol = msg.get('s') # 交易对
price_str = msg.get('c') # 最新价格
if not symbol or not price_str:
return
price = float(price_str)
# 更新最新价格缓存和消息时间
self.latest_prices[symbol] = price
self._last_message_time[symbol] = time.time()
# 调用所有注册的回调函数
with self._lock:
callbacks = self.price_callbacks.copy()
for callback in callbacks:
try:
callback(symbol, price)
except Exception as e:
logger.error(f"价格回调执行出错: {e}")
except Exception as e:
if self.running and not self._stop_requested:
logger.error(f"处理价格更新出错: {e}")
def get_latest_price(self, symbol: str) -> Optional[float]:
"""获取交易对的最新缓存价格"""
return self.latest_prices.get(symbol.upper())
def get_subscribed_symbols(self) -> List[str]:
"""获取已订阅的交易对列表"""
return list(self.subscribed_symbols.keys())
def is_running(self) -> bool:
"""检查服务是否在运行"""
return self.running
# 全局单例
_price_monitor_service: Optional[PriceMonitorService] = None
def get_price_monitor_service() -> PriceMonitorService:
"""获取价格监控服务单例"""
global _price_monitor_service
if _price_monitor_service is None:
_price_monitor_service = PriceMonitorService()
return _price_monitor_service