From 3729b2de1e946fede1d72344cbd836dbe0d2a3fc Mon Sep 17 00:00:00 2001 From: aaron <> Date: Wed, 25 Feb 2026 23:28:04 +0800 Subject: [PATCH] update --- backend/app/config.py | 3 + backend/app/services/bitget_websocket.py | 433 ++++++++++++++++++ backend/app/services/price_monitor_service.py | 204 +++++++-- backend/requirements.txt | 1 + backend/test_websocket.py | 205 +++++++++ 5 files changed, 808 insertions(+), 38 deletions(-) create mode 100644 backend/app/services/bitget_websocket.py create mode 100644 backend/test_websocket.py diff --git a/backend/app/config.py b/backend/app/config.py index 9d1faaf..806f05c 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -109,6 +109,9 @@ class Settings(BaseSettings): crypto_analysis_interval: int = 60 # 分析间隔(秒) crypto_llm_threshold: float = 0.70 # 触发 LLM 分析的置信度阈值 + # 价格监控模式配置 + use_bitget_websocket: bool = True # 是否使用 Bitget WebSocket 实时价格(默认 False 使用 Binance 轮询) + # 波动率过滤配置(节省 LLM 调用) crypto_volatility_filter_enabled: bool = True # 是否启用波动率过滤 crypto_min_volatility_percent: float = 0.5 # 最小波动率(百分比),低于此值跳过分析 diff --git a/backend/app/services/bitget_websocket.py b/backend/app/services/bitget_websocket.py new file mode 100644 index 0000000..16e0bdc --- /dev/null +++ b/backend/app/services/bitget_websocket.py @@ -0,0 +1,433 @@ +""" +Bitget WebSocket 实时价格服务 +通过 WebSocket 订阅实时 ticker 价格更新 +""" +import asyncio +import json +import logging +from typing import Dict, Callable, Optional, Set, Any +from datetime import datetime + +try: + import websockets + WEBSOCKETS_AVAILABLE = True +except ImportError: + WEBSOCKETS_AVAILABLE = False + +from app.utils.logger import logger + + +class BitgetWebSocketClient: + """ + Bitget WebSocket 客户端 - 实时价格订阅 + + 使用异步 WebSocket 连接获取实时价格更新 + """ + + # Bitget WebSocket v2 端点 + WS_URL = "wss://ws.bitget.com/v2/ws/public" # 公共频道 v2 + + # 心跳间隔(秒) + HEARTBEAT_INTERVAL = 25 + + # 重连间隔(秒) + RECONNECT_INTERVAL = 5 + + # 订阅限制:每个连接最多订阅 50 个交易对 + MAX_SUBSCRIPTIONS = 50 + + def __init__(self): + """初始化 WebSocket 客户端""" + if not WEBSOCKETS_AVAILABLE: + raise ImportError("需要安装 websockets 库: pip install websockets") + + self._ws: Optional[websockets.WebSocketClientProtocol] = None + self._running = False + self._subscribed_symbols: Set[str] = set() + + # 价格缓存 + self._prices: Dict[str, float] = {} + + # 回调函数 + self._callbacks: Dict[str, Set[Callable]] = {} + + # 心跳任务 + self._heartbeat_task: Optional[asyncio.Task] = None + self._reconnect_task: Optional[asyncio.Task] = None + + logger.info("Bitget WebSocket 客户端初始化完成") + + async def connect(self) -> bool: + """ + 连接到 Bitget WebSocket + + Returns: + 连接是否成功 + """ + try: + logger.info(f"正在连接 Bitget WebSocket: {self.WS_URL}") + self._ws = await websockets.connect( + self.WS_URL, + ping_interval=self.HEARTBEAT_INTERVAL, + ping_timeout=10, + close_timeout=10 + ) + + self._running = True + logger.info("✅ Bitget WebSocket 连接成功") + + # 启动消息接收循环 + asyncio.create_task(self._message_loop()) + + # 启动心跳任务 + self._heartbeat_task = asyncio.create_task(self._heartbeat_loop()) + + return True + + except Exception as e: + logger.error(f"❌ Bitget WebSocket 连接失败: {e}") + return False + + async def disconnect(self): + """断开 WebSocket 连接""" + logger.info("正在断开 Bitget WebSocket...") + + self._running = False + + # 取消心跳任务 + if self._heartbeat_task: + self._heartbeat_task.cancel() + self._heartbeat_task = None + + # 取消重连任务 + if self._reconnect_task: + self._reconnect_task.cancel() + self._reconnect_task = None + + # 关闭 WebSocket 连接 + if self._ws: + await self._ws.close() + self._ws = None + + logger.info("Bitget WebSocket 已断开") + + async def subscribe(self, symbols: list) -> bool: + """ + 订阅交易对价格 + + Args: + symbols: 交易对列表,如 ['BTCUSDT', 'ETHUSDT'] + + Returns: + 是否订阅成功 + """ + if not self._ws or not self._running: + logger.warning("WebSocket 未连接,无法订阅") + return False + + try: + # 构建订阅消息 (根据 Bitget 官方文档) + # USDT-FUTURES = USDT 永续合约 + args = [] + for symbol in symbols: + if symbol not in self._subscribed_symbols: + args.append({ + "instType": "USDT-FUTURES", + "channel": "ticker", + "instId": symbol + }) + self._subscribed_symbols.add(symbol) + + if not args: + logger.info("所有交易对已订阅") + return True + + message = { + "op": "subscribe", + "args": args + } + + await self._ws.send(json.dumps(message)) + logger.info(f"✅ 订阅 {len(args)} 个交易对: {[s['instId'] for s in args]}") + + return True + + except Exception as e: + logger.error(f"订阅失败: {e}") + return False + + async def unsubscribe(self, symbols: list) -> bool: + """ + 取消订阅交易对价格 + + Args: + symbols: 交易对列表 + + Returns: + 是否取消成功 + """ + if not self._ws or not self._running: + return False + + try: + args = [] + for symbol in symbols: + if symbol in self._subscribed_symbols: + args.append({ + "instType": "USDT-FUTURES", + "channel": "ticker", + "instId": symbol + }) + self._subscribed_symbols.discard(symbol) + + if not args: + return True + + message = { + "op": "unsubscribe", + "args": args + } + + await self._ws.send(json.dumps(message)) + logger.info(f"取消订阅 {len(args)} 个交易对") + + return True + + except Exception as e: + logger.error(f"取消订阅失败: {e}") + return False + + def get_price(self, symbol: str) -> Optional[float]: + """ + 获取交易对的最新价格(从缓存) + + Args: + symbol: 交易对 + + Returns: + 最新价格,如果未订阅则返回 None + """ + return self._prices.get(symbol) + + def get_all_prices(self) -> Dict[str, float]: + """ + 获取所有已订阅交易对的价格 + + Returns: + {symbol: price} 字典 + """ + return self._prices.copy() + + def on_price_update(self, symbol: str, callback: Callable[[str, float, Dict], None]): + """ + 注册价格更新回调 + + Args: + symbol: 交易对,'*' 表示所有交易对 + callback: 回调函数 (symbol, price, data) -> None + """ + if symbol not in self._callbacks: + self._callbacks[symbol] = set() + self._callbacks[symbol].add(callback) + logger.debug(f"注册价格回调: {symbol}") + + def off_price_update(self, symbol: str, callback: Callable): + """ + 取消价格更新回调 + + Args: + symbol: 交易对 + callback: 回调函数 + """ + if symbol in self._callbacks: + self._callbacks[symbol].discard(callback) + + async def _message_loop(self): + """消息接收循环""" + try: + async for message in self._ws: + await self._handle_message(message) + + except websockets.ConnectionClosed: + logger.warning("WebSocket 连接已关闭") + if self._running: + # 自动重连 + self._schedule_reconnect() + + except Exception as e: + logger.error(f"消息循环错误: {e}") + if self._running: + self._schedule_reconnect() + + async def _handle_message(self, message: str): + """ + 处理接收到的消息 (v2 API 格式) + + Args: + message: WebSocket 消息 + """ + try: + data = json.loads(message) + + # 调试:记录所有收到的消息 + logger.info(f"📨 收到消息: action={data.get('action', 'unknown')}, event={data.get('event', 'none')}") + + # v2 API: 订阅/取消订阅确认事件 (使用 event 字段) + if data.get('event') == 'subscribe': + logger.info(f"✅ 订阅确认: {data}") + return + + if data.get('event') == 'unsubscribe': + logger.info(f"✅ 取消订阅确认: {data}") + return + + if data.get('event') == 'error': + logger.error(f"❌ WebSocket 错误: {data}") + return + + # v2 API: ticker 数据格式 (使用 action 字段) + # {"action": "snapshot" or "update", "data": [...], "arg": {...}} + if 'data' in data and isinstance(data['data'], list): + # 处理 data 数组中的每个 ticker + for ticker_item in data['data']: + if 'instId' in ticker_item or 'lastPr' in ticker_item: + self._process_ticker(ticker_item) + + except json.JSONDecodeError: + logger.warning(f"无法解析消息: {message[:100]}") + except Exception as e: + logger.error(f"处理消息错误: {e}") + + def _process_ticker(self, ticker: Dict[str, Any]): + """ + 处理 ticker 数据 (v2 API 格式) + + Args: + ticker: ticker 数据,包含 instId 和 lastPr 字段 + """ + try: + # v2 API: 直接从 ticker 获取 instId 和 lastPr + symbol = ticker.get('instId', '') + price_str = ticker.get('lastPr', '0') + + if not symbol: + logger.debug(f"跳过无效 ticker (无 instId): {ticker}") + return + + try: + price = float(price_str) + except (ValueError, TypeError): + logger.debug(f"跳过无效 ticker (价格无效): symbol={symbol}, price_str={price_str}") + return + + if price == 0: + logger.debug(f"跳过无效 ticker (价格为0): symbol={symbol}") + return + + # 更新价格缓存 + old_price = self._prices.get(symbol) + self._prices[symbol] = price + + # 触发回调 + self._trigger_callbacks(symbol, price, ticker) + + # 价格变化日志 - 改为 info 级别方便调试 + logger.info(f"💰 {symbol}: ${price:,.2f}") + if old_price and old_price != price: + change = ((price - old_price) / old_price) * 100 + logger.debug(f" 变化: {change:+.2f}%") + + except Exception as e: + logger.error(f"解析 ticker 数据错误: {e}") + + def _trigger_callbacks(self, symbol: str, price: float, data: Dict): + """ + 触发价格更新回调 + + Args: + symbol: 交易对 + price: 价格 + data: 完整的 ticker 数据 + """ + # 触发该交易对的回调 + if symbol in self._callbacks: + for callback in self._callbacks[symbol]: + try: + callback(symbol, price, data) + except Exception as e: + logger.error(f"回调函数错误 ({symbol}): {e}") + + # 触发全局回调('*') + if '*' in self._callbacks: + for callback in self._callbacks['*']: + try: + callback(symbol, price, data) + except Exception as e: + logger.error(f"全局回调函数错误: {e}") + + async def _heartbeat_loop(self): + """心跳循环""" + while self._running and self._ws: + try: + await asyncio.sleep(self.HEARTBEAT_INTERVAL) + + # 发送 ping + await self._ws.ping() + + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"心跳错误: {e}") + self._schedule_reconnect() + break + + def _schedule_reconnect(self): + """安排重连""" + if not self._running: + return + + if self._reconnect_task and not self._reconnect_task.done(): + return # 已经有重连任务在运行 + + async def reconnect(): + await asyncio.sleep(self.RECONNECT_INTERVAL) + if self._running: + logger.info("尝试重新连接...") + await self.disconnect() + if await self.connect(): + # 重新订阅之前的交易对 + if self._subscribed_symbols: + await self.subscribe(list(self._subscribed_symbols)) + + self._reconnect_task = asyncio.create_task(reconnect()) + + @property + def is_connected(self) -> bool: + """是否已连接""" + return self._ws is not None and self._running + + @property + def subscribed_symbols(self) -> Set[str]: + """已订阅的交易对""" + return self._subscribed_symbols.copy() + + +# 全局实例 +_bitget_ws_client: Optional[BitgetWebSocketClient] = None + + +def get_bitget_ws_client() -> Optional[BitgetWebSocketClient]: + """ + 获取 Bitget WebSocket 客户端单例 + + Returns: + WebSocket 客户端实例,如果 websockets 库未安装则返回 None + """ + global _bitget_ws_client + if _bitget_ws_client is None: + try: + _bitget_ws_client = BitgetWebSocketClient() + except ImportError: + logger.warning("websockets 库未安装,WebSocket 功能不可用") + return None + return _bitget_ws_client diff --git a/backend/app/services/price_monitor_service.py b/backend/app/services/price_monitor_service.py index 67cf467..2212491 100644 --- a/backend/app/services/price_monitor_service.py +++ b/backend/app/services/price_monitor_service.py @@ -1,6 +1,7 @@ """ -价格监控服务 - 使用轮询方式获取实时价格(更稳定) +价格监控服务 - 支持 WebSocket 实时推送和轮询两种模式(统一使用 Bitget 数据源) """ +import asyncio import threading import time import requests @@ -10,10 +11,7 @@ from app.config import get_settings class PriceMonitorService: - """实时价格监控服务(轮询模式)""" - - # Binance API - BASE_URL = "https://api.binance.com" + """实时价格监控服务(支持 WebSocket 和轮询两种模式,统一使用 Bitget)""" _instance = None _initialized = False @@ -37,22 +35,40 @@ class PriceMonitorService: self.price_callbacks: List[Callable[[str, float], None]] = [] self.latest_prices: Dict[str, float] = {} self._lock = threading.Lock() + + # 模式选择 + self._use_websocket = getattr(self.settings, 'use_bitget_websocket', False) + + # 轮询模式相关(使用 Bitget REST API) self._poll_thread: Optional[threading.Thread] = None self._poll_interval = 3 # 轮询间隔(秒) self._session = requests.Session() + self._bitget_rest_url = "https://api.bitget.com" # Bitget REST API - logger.info(f"[PriceMonitor:{id(self)}] 价格监控服务初始化完成(轮询模式)") + # WebSocket 模式相关 + self._ws_thread: Optional[threading.Thread] = None + self._ws_loop: Optional[asyncio.AbstractEventLoop] = None + self._ws_client = None + + logger.info(f"价格监控服务初始化完成 (模式: {'Bitget WebSocket' if self._use_websocket else 'Bitget REST 轮询'})") def start(self): - """启动价格轮询""" + """启动价格监控""" if self.running: - logger.debug(f"[PriceMonitor:{id(self)}] 价格监控服务已在运行") + logger.debug("价格监控服务已在运行") return self.running = True + if self._use_websocket: + self._start_websocket() + else: + self._start_polling() + + def _start_polling(self): + """启动轮询模式""" def _poll_loop(): - logger.info(f"[PriceMonitor:{id(self)}] 价格轮询已启动,间隔 {self._poll_interval} 秒") + logger.info(f"价格轮询已启动,间隔 {self._poll_interval} 秒") while self.running: try: self._fetch_prices() @@ -70,16 +86,90 @@ class PriceMonitorService: self._poll_thread = threading.Thread(target=_poll_loop, daemon=True) self._poll_thread.start() + def _start_websocket(self): + """启动 WebSocket 模式""" + def _run_ws(): + """在新线程中运行 WebSocket""" + # 创建新的事件循环 + self._ws_loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._ws_loop) + + try: + # 导入 WebSocket 客户端 + from app.services.bitget_websocket import get_bitget_ws_client + + self._ws_client = get_bitget_ws_client() + if not self._ws_client: + logger.error("无法创建 WebSocket 客户端") + return + + # 连接 - 在事件循环中运行 + connect_task = self._ws_loop.create_task(self._ws_client.connect()) + self._ws_loop.run_until_complete(connect_task) + + if not connect_task.result(): + logger.error("WebSocket 连接失败") + return + + # 注册价格更新回调 + self._ws_client.on_price_update('*', self._on_ws_price_update) + + # 订阅已有的交易对 + if self.subscribed_symbols: + subscribe_task = self._ws_loop.create_task( + self._ws_client.subscribe(list(self.subscribed_symbols)) + ) + self._ws_loop.run_until_complete(subscribe_task) + + # 运行事件循环 + self._ws_loop.run_forever() + + except Exception as e: + logger.error(f"WebSocket 线程错误: {e}") + finally: + logger.info("WebSocket 线程已退出") + + # 防止重复创建线程 + if self._ws_thread is None or not self._ws_thread.is_alive(): + self._ws_thread = threading.Thread(target=_run_ws, daemon=True) + self._ws_thread.start() + + def _on_ws_price_update(self, symbol: str, price: float, data: Dict): + """ + WebSocket 价格更新回调 + + Args: + symbol: 交易对 + price: 新价格 + data: 完整的 ticker 数据 + """ + self._update_price(symbol, price) + def stop(self): - """停止价格轮询""" + """停止价格监控""" if not self.running: return self.running = False + + # 停止 WebSocket + if self._use_websocket and self._ws_loop and self._ws_client: + # 在事件循环中停止 WebSocket + asyncio.run_coroutine_threadsafe( + self._ws_client.disconnect(), + self._ws_loop + ) + # 停止事件循环 + self._ws_loop.call_soon_threadsafe(self._ws_loop.stop) + + # 停止轮询 + if self._poll_thread: + self._poll_thread.join(timeout=2) + logger.info("价格监控服务已停止") def _fetch_prices(self): - """获取所有订阅交易对的价格""" + """获取所有订阅交易对的价格(轮询模式)""" if not self.subscribed_symbols: return @@ -94,32 +184,54 @@ class PriceMonitorService: self._fetch_all_prices(symbols) def _fetch_single_price(self, symbol: str): - """获取单个交易对价格""" + """获取单个交易对价格(使用 Bitget REST API)""" try: - url = f"{self.BASE_URL}/api/v3/ticker/price" - response = self._session.get(url, params={'symbol': symbol}, timeout=5) + url = f"{self._bitget_rest_url}/api/v3/market/tickers" + params = { + 'category': 'USDT-FUTURES', + 'symbol': symbol + } + response = self._session.get(url, params=params, timeout=5) response.raise_for_status() - data = response.json() - price = float(data['price']) - self._update_price(symbol, price) + result = response.json() + + if result.get('code') != '00000': + logger.debug(f"Bitget API 错误: {result.get('msg')}") + return + + data = result.get('data', []) + if data: + price = float(data[0].get('lastPr', '0')) + if price > 0: + self._update_price(symbol, price) + except Exception as e: logger.debug(f"获取 {symbol} 价格失败: {e}") def _fetch_all_prices(self, symbols: List[str]): - """批量获取价格""" + """批量获取价格(使用 Bitget REST API)""" try: - url = f"{self.BASE_URL}/api/v3/ticker/price" - response = self._session.get(url, timeout=10) + url = f"{self._bitget_rest_url}/api/v3/market/tickers" + params = {'category': 'USDT-FUTURES'} + response = self._session.get(url, params=params, timeout=10) response.raise_for_status() - all_prices = response.json() + result = response.json() + + if result.get('code') != '00000': + logger.debug(f"Bitget API 错误: {result.get('msg')}") + return + + all_tickers = result.get('data', []) + symbol_set = set(symbols) # 过滤出订阅的交易对 - symbol_set = set(symbols) - for item in all_prices: - symbol = item['symbol'] + for ticker in all_tickers: + symbol = ticker.get('instId') if symbol in symbol_set: - price = float(item['price']) - self._update_price(symbol, price) + price = float(ticker.get('lastPr', '0')) + if price > 0: + self._update_price(symbol, price) + except Exception as e: logger.debug(f"批量获取价格失败: {e}") @@ -148,33 +260,47 @@ class PriceMonitorService: Args: symbol: 交易对,如 "BTCUSDT" """ - import traceback symbol = symbol.upper() - # 添加调用栈追踪 - stack = traceback.extract_stack() - caller = stack[-2] if len(stack) >= 2 else None - if symbol in self.subscribed_symbols: - logger.debug(f"[PriceMonitor:{id(self)}] {symbol} 已订阅,跳过 (来自: {caller})") + logger.debug(f"{symbol} 已订阅,跳过") return self.subscribed_symbols.add(symbol) - logger.info(f"[PriceMonitor:{id(self)}] 已订阅 {symbol} 价格更新 (来自: {caller},当前订阅: {self.subscribed_symbols})") + logger.info(f"已订阅 {symbol} 价格更新 (当前订阅: {len(self.subscribed_symbols)} 个)") + + # WebSocket 模式:立即订阅 + if self._use_websocket and self._ws_client and self._ws_loop and self._ws_loop.is_running(): + asyncio.run_coroutine_threadsafe( + self._ws_client.subscribe([symbol]), + self._ws_loop + ) + elif self._use_websocket: + # WebSocket 还未就绪,将在连接后自动订阅 + logger.debug(f"WebSocket 未就绪,{symbol} 将在连接后自动订阅") + # 轮询模式:立即获取一次价格 + elif not self._use_websocket: + self._fetch_single_price(symbol) # 如果服务未启动,自动启动 if not self.running: self.start() - # 立即获取一次价格 - self._fetch_single_price(symbol) - def unsubscribe_symbol(self, symbol: str): """取消订阅交易对""" symbol = symbol.upper() + if symbol in self.subscribed_symbols: self.subscribed_symbols.discard(symbol) self.latest_prices.pop(symbol, None) + + # WebSocket 模式:取消订阅 + if self._use_websocket and self._ws_client: + asyncio.run_coroutine_threadsafe( + self._ws_client.unsubscribe([symbol]), + self._ws_loop + ) + logger.info(f"已取消订阅 {symbol}") def add_price_callback(self, callback: Callable[[str, float], None]): @@ -211,7 +337,10 @@ class PriceMonitorService: return self.running def set_poll_interval(self, seconds: int): - """设置轮询间隔(秒)""" + """设置轮询间隔(秒),仅对轮询模式有效""" + if self._use_websocket: + logger.warning("WebSocket 模式下不支持设置轮询间隔") + return self._poll_interval = max(1, seconds) logger.info(f"轮询间隔已设置为 {self._poll_interval} 秒") @@ -224,4 +353,3 @@ def get_price_monitor_service() -> PriceMonitorService: """获取价格监控服务单例""" # 直接使用类单例,不使用全局变量(避免 reload 时重置) return PriceMonitorService() - diff --git a/backend/requirements.txt b/backend/requirements.txt index 020dede..43b4ac0 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -23,6 +23,7 @@ python-jose[cryptography]==3.3.0 python-binance>=1.0.19 httpx>=0.27.0 ccxt>=4.0.0 # 统一交易所API接口,支持Bitget等主流交易所 +websockets>=12.0 # WebSocket 支持,用于实时价格更新 # 新闻智能体依赖 feedparser>=6.0.10 diff --git a/backend/test_websocket.py b/backend/test_websocket.py new file mode 100644 index 0000000..3b21ddf --- /dev/null +++ b/backend/test_websocket.py @@ -0,0 +1,205 @@ +""" +测试 Bitget WebSocket 价格监控 +""" +import asyncio +import sys +import os + +# 添加项目路径 +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) + +from app.services.bitget_websocket import BitgetWebSocketClient +from app.utils.logger import logger + + +async def test_websocket(): + """测试 WebSocket 连接和价格订阅""" + + logger.info("=" * 60) + logger.info("Bitget WebSocket 测试开始") + logger.info("=" * 60) + + # 创建 WebSocket 客户端 + client = BitgetWebSocketClient() + + # 注册价格回调 + def on_price_update(symbol: str, price: float, data: dict): + """价格更新回调""" + import datetime + timestamp = datetime.datetime.now().strftime("%H:%M:%S.%f")[:-3] + print(f"[{timestamp}] {symbol}: ${price:,.2f}") + + # 显示更多详细信息 + if 'open24h' in data: + change_24h = ((price - float(data['open24h'])) / float(data['open24h'])) * 100 + print(f" └─ 24h涨跌: {change_24h:+.2f}%") + + client.on_price_update('*', on_price_update) + + # 连接 + logger.info("正在连接 Bitget WebSocket...") + if not await client.connect(): + logger.error("❌ 连接失败") + return False + + logger.info("✅ 连接成功") + + # 订阅交易对 + symbols = ['BTCUSDT', 'ETHUSDT', 'SOLUSDT'] + logger.info(f"正在订阅: {', '.join(symbols)}") + + if not await client.subscribe(symbols): + logger.error("❌ 订阅失败") + return False + + logger.info(f"✅ 订阅成功: {symbols}") + + # 显示当前订阅状态 + logger.info(f"已订阅交易对: {client.subscribed_symbols}") + logger.info(f"连接状态: {client.is_connected}") + + # 接收价格更新(30秒) + logger.info("\n开始接收价格更新(30秒)...") + logger.info("-" * 60) + + try: + # 等待30秒接收数据 + await asyncio.sleep(30) + + except KeyboardInterrupt: + logger.info("\n用户中断") + + # 显示接收到的价格 + logger.info("-" * 60) + logger.info("当前价格缓存:") + prices = client.get_all_prices() + for symbol, price in prices.items(): + print(f" {symbol}: ${price:,.2f}") + + # 断开连接 + logger.info("\n正在断开连接...") + await client.disconnect() + logger.info("✅ 已断开连接") + + logger.info("=" * 60) + logger.info("测试完成") + logger.info("=" * 60) + + return True + + +async def test_reconnect(): + """测试自动重连功能""" + + logger.info("=" * 60) + logger.info("测试自动重连功能") + logger.info("=" * 60) + + client = BitgetWebSocketClient() + + price_updates = [] + + def on_price_update(symbol: str, price: float, data: dict): + price_updates.append((symbol, price)) + logger.info(f"[重连测试] {symbol}: ${price:,.2f}") + + client.on_price_update('*', on_price_update) + + # 第一次连接 + logger.info("第一次连接...") + await client.connect() + await client.subscribe(['BTCUSDT']) + + # 等待5秒接收数据 + await asyncio.sleep(5) + logger.info(f"接收到的价格更新: {len(price_updates)} 次") + + # 模拟断线(手动断开) + logger.info("\n模拟断线...") + await client.disconnect() + + # 等待2秒 + await asyncio.sleep(2) + + # 重新连接 + logger.info("重新连接...") + if await client.connect(): + logger.info("✅ 重连成功") + await client.subscribe(['BTCUSDT']) + + # 等待5秒接收数据 + await asyncio.sleep(5) + + initial_count = len(price_updates) + logger.info(f"重连后接收到的价格更新: {len(price_updates) - initial_count} 次") + else: + logger.error("❌ 重连失败") + + await client.disconnect() + logger.info("重连测试完成") + + return True + + +def test_integration(): + """测试与 price_monitor_service 的集成""" + + logger.info("=" * 60) + logger.info("测试 PriceMonitorService 集成") + logger.info("=" * 60) + + # 设置环境变量启用 WebSocket + os.environ['USE_BITGET_WEBSOCKET'] = 'true' + + from app.services.price_monitor_service import get_price_monitor_service + + monitor = get_price_monitor_service() + + # 添加回调 + update_count = {'count': 0} + + def on_update(symbol: str, price: float): + update_count['count'] += 1 + logger.info(f"[集成测试] {symbol}: ${price:,.2f}") + + monitor.add_price_callback(on_update) + + # 订阅交易对 + logger.info("订阅 BTCUSDT...") + monitor.subscribe_symbol('BTCUSDT') + + # 等待10秒 + logger.info("等待10秒接收价格更新...") + import time + time.sleep(10) + + logger.info(f"接收到 {update_count['count']} 次价格更新") + logger.info(f"当前价格: {monitor.get_latest_price('BTCUSDT')}") + + # 停止监控 + monitor.stop() + + logger.info("集成测试完成") + + return True + + +def main(): + """主测试函数""" + import argparse + + parser = argparse.ArgumentParser(description='测试 Bitget WebSocket') + parser.add_argument('--test', choices=['basic', 'reconnect', 'integration'], + default='basic', help='测试类型') + args = parser.parse_args() + + if args.test == 'basic': + asyncio.run(test_websocket()) + elif args.test == 'reconnect': + asyncio.run(test_reconnect()) + elif args.test == 'integration': + test_integration() + + +if __name__ == '__main__': + main()