stock-ai-agent/backend/test_websocket.py
2026-02-25 23:28:04 +08:00

206 lines
5.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
测试 Bitget WebSocket 价格监控
"""
import asyncio
import sys
import os
# 添加项目路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from app.services.bitget_websocket import BitgetWebSocketClient
from app.utils.logger import logger
async def test_websocket():
"""测试 WebSocket 连接和价格订阅"""
logger.info("=" * 60)
logger.info("Bitget WebSocket 测试开始")
logger.info("=" * 60)
# 创建 WebSocket 客户端
client = BitgetWebSocketClient()
# 注册价格回调
def on_price_update(symbol: str, price: float, data: dict):
"""价格更新回调"""
import datetime
timestamp = datetime.datetime.now().strftime("%H:%M:%S.%f")[:-3]
print(f"[{timestamp}] {symbol}: ${price:,.2f}")
# 显示更多详细信息
if 'open24h' in data:
change_24h = ((price - float(data['open24h'])) / float(data['open24h'])) * 100
print(f" └─ 24h涨跌: {change_24h:+.2f}%")
client.on_price_update('*', on_price_update)
# 连接
logger.info("正在连接 Bitget WebSocket...")
if not await client.connect():
logger.error("❌ 连接失败")
return False
logger.info("✅ 连接成功")
# 订阅交易对
symbols = ['BTCUSDT', 'ETHUSDT', 'SOLUSDT']
logger.info(f"正在订阅: {', '.join(symbols)}")
if not await client.subscribe(symbols):
logger.error("❌ 订阅失败")
return False
logger.info(f"✅ 订阅成功: {symbols}")
# 显示当前订阅状态
logger.info(f"已订阅交易对: {client.subscribed_symbols}")
logger.info(f"连接状态: {client.is_connected}")
# 接收价格更新30秒
logger.info("\n开始接收价格更新30秒...")
logger.info("-" * 60)
try:
# 等待30秒接收数据
await asyncio.sleep(30)
except KeyboardInterrupt:
logger.info("\n用户中断")
# 显示接收到的价格
logger.info("-" * 60)
logger.info("当前价格缓存:")
prices = client.get_all_prices()
for symbol, price in prices.items():
print(f" {symbol}: ${price:,.2f}")
# 断开连接
logger.info("\n正在断开连接...")
await client.disconnect()
logger.info("✅ 已断开连接")
logger.info("=" * 60)
logger.info("测试完成")
logger.info("=" * 60)
return True
async def test_reconnect():
"""测试自动重连功能"""
logger.info("=" * 60)
logger.info("测试自动重连功能")
logger.info("=" * 60)
client = BitgetWebSocketClient()
price_updates = []
def on_price_update(symbol: str, price: float, data: dict):
price_updates.append((symbol, price))
logger.info(f"[重连测试] {symbol}: ${price:,.2f}")
client.on_price_update('*', on_price_update)
# 第一次连接
logger.info("第一次连接...")
await client.connect()
await client.subscribe(['BTCUSDT'])
# 等待5秒接收数据
await asyncio.sleep(5)
logger.info(f"接收到的价格更新: {len(price_updates)}")
# 模拟断线(手动断开)
logger.info("\n模拟断线...")
await client.disconnect()
# 等待2秒
await asyncio.sleep(2)
# 重新连接
logger.info("重新连接...")
if await client.connect():
logger.info("✅ 重连成功")
await client.subscribe(['BTCUSDT'])
# 等待5秒接收数据
await asyncio.sleep(5)
initial_count = len(price_updates)
logger.info(f"重连后接收到的价格更新: {len(price_updates) - initial_count}")
else:
logger.error("❌ 重连失败")
await client.disconnect()
logger.info("重连测试完成")
return True
def test_integration():
"""测试与 price_monitor_service 的集成"""
logger.info("=" * 60)
logger.info("测试 PriceMonitorService 集成")
logger.info("=" * 60)
# 设置环境变量启用 WebSocket
os.environ['USE_BITGET_WEBSOCKET'] = 'true'
from app.services.price_monitor_service import get_price_monitor_service
monitor = get_price_monitor_service()
# 添加回调
update_count = {'count': 0}
def on_update(symbol: str, price: float):
update_count['count'] += 1
logger.info(f"[集成测试] {symbol}: ${price:,.2f}")
monitor.add_price_callback(on_update)
# 订阅交易对
logger.info("订阅 BTCUSDT...")
monitor.subscribe_symbol('BTCUSDT')
# 等待10秒
logger.info("等待10秒接收价格更新...")
import time
time.sleep(10)
logger.info(f"接收到 {update_count['count']} 次价格更新")
logger.info(f"当前价格: {monitor.get_latest_price('BTCUSDT')}")
# 停止监控
monitor.stop()
logger.info("集成测试完成")
return True
def main():
"""主测试函数"""
import argparse
parser = argparse.ArgumentParser(description='测试 Bitget WebSocket')
parser.add_argument('--test', choices=['basic', 'reconnect', 'integration'],
default='basic', help='测试类型')
args = parser.parse_args()
if args.test == 'basic':
asyncio.run(test_websocket())
elif args.test == 'reconnect':
asyncio.run(test_reconnect())
elif args.test == 'integration':
test_integration()
if __name__ == '__main__':
main()