stock-ai-agent/backend/app/services/websocket_monitor.py
2026-02-20 22:53:42 +08:00

285 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
WebSocket 价格监控服务 - 使用 Binance WebSocket API 实现实时价格推送
"""
import json
import asyncio
import threading
from typing import Dict, List, Callable, Optional, Set
from datetime import datetime
import websockets
from app.utils.logger import logger
class WebSocketPriceMonitor:
"""WebSocket 实时价格监控服务"""
# Binance WebSocket 端点
BASE_WS_URL = "wss://stream.binance.com:9443/ws"
def __init__(self):
"""初始化 WebSocket 价格监控服务"""
self._ws = None
self._loop = None
self._thread = None
self._running = False
self._subscribed_symbols: Set[str] = set()
self._price_callbacks: List[Callable[[str, float], None]] = []
self._latest_prices: Dict[str, float] = {}
self._lock = threading.Lock()
self._last_heartbeat: Optional[datetime] = None
# 连接和重连配置
self._reconnect_delay = 5 # 重连延迟(秒)
self._max_reconnect_attempts = 10
logger.info("WebSocket 价格监控服务初始化完成")
def is_running(self) -> bool:
"""检查服务是否在运行"""
return self._running and self._ws is not None and self._running
def subscribe_symbol(self, symbol: str):
"""
订阅交易对的价格推送
Args:
symbol: 交易对,如 "BTCUSDT"
"""
symbol = symbol.upper()
need_start = False
with self._lock:
if symbol in self._subscribed_symbols:
logger.debug(f"[WS:{id(self)}] {symbol} 已订阅,跳过")
return
self._subscribed_symbols.add(symbol)
# 检查是否需要启动服务
if not self.is_running():
need_start = True
# 在锁外启动服务(避免死锁)
if need_start:
self.start()
# 在锁外获取当前价格(避免阻塞)
self._fetch_current_price(symbol)
logger.info(f"[WS:{id(self)}] 已订阅 {symbol} 价格更新 (当前订阅: {self._subscribed_symbols})")
def unsubscribe_symbol(self, symbol: str):
"""取消订阅交易对"""
symbol = symbol.upper()
with self._lock:
if symbol in self._subscribed_symbols:
self._subscribed_symbols.discard(symbol)
self._latest_prices.pop(symbol, None)
logger.info(f"[WS:{id(self)}] 已取消订阅 {symbol}")
# 如果没有订阅了,可以考虑断开连接
if not self._subscribed_symbols:
logger.info(f"[WS:{id(self)}] 没有订阅的交易对,准备断开连接")
def add_price_callback(self, callback: Callable[[str, float], None]):
"""添加价格更新回调函数"""
with self._lock:
if callback not in self._price_callbacks:
self._price_callbacks.append(callback)
def remove_price_callback(self, callback: Callable):
"""移除价格回调函数"""
with self._lock:
if callback in self._price_callbacks:
self._price_callbacks.remove(callback)
def get_latest_price(self, symbol: str) -> Optional[float]:
"""获取交易对的最新缓存价格"""
return self._latest_prices.get(symbol.upper())
def get_subscribed_symbols(self) -> List[str]:
"""获取已订阅的交易对列表"""
with self._lock:
return list(self._subscribed_symbols)
def start(self):
"""启动 WebSocket 连接"""
with self._lock:
if self._running:
logger.debug(f"[WS:{id(self)}] WebSocket 服务已在运行")
return
self._running = True
# 在新线程中运行事件循环
self._thread = threading.Thread(target=self._run_event_loop, daemon=True)
self._thread.start()
def stop(self):
"""停止 WebSocket 连接"""
with self._lock:
if not self._running:
return
self._running = False
# 关闭 WebSocket 连接
if self._loop and self._loop.is_running():
self._loop.call_soon_threadsafe(self._close_ws())
if self._thread and self._thread.is_alive():
self._thread.join(timeout=5)
logger.info(f"[WS:{id(self)}] WebSocket 价格监控服务已停止")
def _run_event_loop(self):
"""运行 WebSocket 事件循环(在单独线程中)"""
# 创建新的事件循环
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
try:
self._loop.run_until_complete(self._connect_and_listen())
except Exception as e:
logger.error(f"[WS:{id(self)}] WebSocket 事件循环出错: {e}")
finally:
self._loop.close()
async def _connect_and_listen(self):
"""连接并监听 WebSocket 消息"""
retry_count = 0
while self._running and retry_count < self._max_reconnect_attempts:
try:
# 构建订阅流
with self._lock:
symbols = list(self._subscribed_symbols)
if not symbols:
# 没有订阅的交易对,等待订阅
logger.debug(f"[WS:{id(self)}] 没有订阅的交易对,等待 5 秒")
await asyncio.sleep(5)
continue
# 构建 WebSocket 流
streams = []
for symbol in symbols:
streams.append(f"{symbol.lower()}@ticker")
# Binance 组合流 URL 格式: /stream?streams=btcusdt@ticker/ethusdt@ticker
url = f"{self.BASE_WS_URL}/stream?streams={'/'.join(streams)}"
logger.info(f"[WS:{id(self)}] 正在连接 WebSocket... (订阅: {', '.join(symbols)})")
logger.debug(f"[WS:{id(self)}] WebSocket URL: {url}")
async with websockets.connect(url, ping_interval=30) as ws:
self._ws = ws
retry_count = 0 # 连接成功,重置重试计数
self._last_heartbeat = datetime.now()
logger.info(f"[WS:{id(self)}] WebSocket 已连接")
# 监听消息
async for message in self._ws:
await self._on_message(message)
except websockets.exceptions.ConnectionClosed as e:
logger.warning(f"[WS:{id(self)}] WebSocket 连接关闭: {e}")
except websockets.exceptions.ConnectionError as e:
logger.error(f"[WS:{id(self)}] WebSocket 连接错误: {e}")
except Exception as e:
logger.error(f"[WS:{id(self)}] WebSocket 异常: {e}")
# 检查是否需要重连
with self._lock:
should_reconnect = self._running and self._subscribed_symbols and retry_count < self._max_reconnect_attempts
if should_reconnect:
retry_count += 1
logger.info(f"[WS:{id(self)}] 将在 {self._reconnect_delay} 秒后重连... (尝试 {retry_count}/{self._max_reconnect_attempts})")
await asyncio.sleep(self._reconnect_delay)
else:
if self._running:
logger.warning(f"[WS:{id(self)}] 达到最大重连次数,停止服务")
self._running = False
break
async def _on_message(self, message):
"""处理 WebSocket 消息"""
try:
data = json.loads(message)
# 处理不同的消息类型
if data.get('e') == '24hrTicker': # 24小时价格变动
symbol = data.get('s')
if symbol:
# 解析价格
price = float(data.get('c', 0)) # 当前价格
self._update_price(symbol.upper(), price)
elif data.get('result') is not None and isinstance(data['result'], list):
# 多个交易对的价格推送
for item in data['result']:
symbol = item.get('s')
if symbol:
price = float(item.get('c', 0))
self._update_price(symbol.upper(), price)
except json.JSONDecodeError as e:
logger.error(f"[WS:{id(self)}] 解析 WebSocket 消息失败: {e}")
except Exception as e:
logger.error(f"[WS:{id(self)}] 处理 WebSocket 消息出错: {e}")
def _update_price(self, symbol: str, price: float):
"""更新价格并触发回调"""
old_price = self._latest_prices.get(symbol)
# 只有价格变化时才触发回调
if old_price != price:
self._latest_prices[symbol] = price
# 调用所有注册的回调函数
with self._lock:
callbacks = self._price_callbacks.copy()
# 在线程中执行回调
for callback in callbacks:
try:
callback(symbol, price)
except Exception as e:
logger.error(f"[WS:{id(self)}] 价格回调执行出错: {e}")
async def _close_ws(self):
"""关闭 WebSocket 连接"""
if self._ws:
await self._ws.close()
self._ws = None
logger.info(f"[WS:{id(self)}] WebSocket 连接已关闭")
def _fetch_current_price(self, symbol: str):
"""立即获取当前价格WebSocket 连接建立前的临时方案)"""
try:
import requests
url = f"https://api.binance.com/api/v3/ticker/price?symbol={symbol}"
response = requests.get(url, timeout=5)
if response.status_code == 200:
data = response.json()
price = float(data['price'])
self._latest_prices[symbol] = price
logger.debug(f"[WS:{id(self)}] 获取 {symbol} 当前价格: ${price:,.2f}")
except Exception as e:
logger.warning(f"[WS:{id(self)}] 获取 {symbol} 当前价格失败: {e}")
# 全局单例
_ws_monitor: Optional[WebSocketPriceMonitor] = None
def get_ws_price_monitor() -> WebSocketPriceMonitor:
"""获取 WebSocket 价格监控服务单例"""
global _ws_monitor
if _ws_monitor is None:
_ws_monitor = WebSocketPriceMonitor()
return _ws_monitor