tradusai/core/websocket_client.py
2025-12-02 22:54:03 +08:00

210 lines
7.0 KiB
Python

"""
Binance WebSocket Client with auto-reconnection and exponential backoff
"""
import asyncio
import json
import logging
import time
from datetime import datetime, timezone
from typing import Any, Callable, Dict, Optional

import websockets
from websockets.exceptions import ConnectionClosed, WebSocketException

from config import settings
# Module-level logger; every message from the client is emitted under this module's name.
logger = logging.getLogger(__name__)
class BinanceWebSocketClient:
    """
    Binance Futures WebSocket client with production-grade features:
    - Auto-reconnection with exponential backoff
    - Multi-stream subscription
    - Heartbeat monitoring
    - Graceful shutdown

    Every decoded message is passed to ``on_message`` as a dict.
    Combined-stream frames are first unwrapped to their ``data`` payload and
    tagged with ``_stream``, and every dispatched dict gets ``_received_at``.
    Connection failures (other than a plain close) are reported to the
    optional ``on_error`` callback. Both callbacks may be plain functions or
    coroutine functions.
    """

    def __init__(
        self,
        symbol: str,
        on_message: Callable[[Dict[str, Any]], None],
        on_error: Optional[Callable[[Exception], None]] = None,
    ):
        """
        Args:
            symbol: Trading symbol; lower-cased before use in stream names.
            on_message: Called with each decoded message dict.
            on_error: Optional; called with the exception on WebSocket or
                unexpected connection errors.
        """
        self.symbol = symbol.lower()
        self.on_message = on_message
        self.on_error = on_error
        self.ws: Optional[websockets.WebSocketClientProtocol] = None
        self.is_running = False
        self.reconnect_count = 0
        # Wall-clock (time.time()) timestamp of the last received frame;
        # read by is_healthy().
        self.last_message_time = time.time()
        # Set by stop() to cut a pending backoff sleep short. Created in
        # start() so that it belongs to the running event loop.
        self._stop_event: Optional[asyncio.Event] = None
        # Reconnection settings
        self.reconnect_delay = settings.RECONNECT_INITIAL_DELAY
        self.max_reconnect_delay = settings.RECONNECT_MAX_DELAY
        self.reconnect_multiplier = settings.RECONNECT_MULTIPLIER
        # Build stream URL
        self.ws_url = self._build_stream_url()

    def _build_stream_url(self) -> str:
        """Build the combined-stream URL: all kline intervals, depth and aggTrade."""
        streams = [
            f"{self.symbol}@kline_{interval}"
            for interval in settings.kline_intervals_list
        ]
        streams.append(f"{self.symbol}@depth20@100ms")  # Top 20 depth, 100ms updates
        streams.append(f"{self.symbol}@aggTrade")  # Aggregated trades
        stream_path = "/".join(streams)
        url = f"{settings.BINANCE_WS_BASE_URL}/stream?streams={stream_path}"
        logger.info(f"WebSocket URL: {url}")
        logger.info(f"Subscribing to kline intervals: {', '.join(settings.kline_intervals_list)}")
        return url

    async def connect(self) -> None:
        """
        Establish the WebSocket connection and keep it alive while running.

        Loops until ``is_running`` is cleared, either by stop() or by
        _handle_reconnect() giving up. Every disconnect while running is
        followed by _handle_reconnect(), so reconnects are always paced by
        the backoff delay. This covers an abnormal close, a clean close, a
        handshake failure and an unexpected error.
        """
        attempt = 0
        while self.is_running:
            attempt += 1
            logger.info(f"Connecting to Binance WebSocket (attempt {attempt})...")
            try:
                async with websockets.connect(
                    self.ws_url,
                    ping_interval=20,  # Send ping every 20s
                    ping_timeout=10,  # Wait 10s for pong
                    close_timeout=10,
                ) as websocket:
                    self.ws = websocket
                    # stop() may have run during the handshake, while self.ws
                    # was still None, so nothing closed this socket. Breaking
                    # out of the `async with` closes it.
                    if not self.is_running:
                        break
                    self.reconnect_delay = settings.RECONNECT_INITIAL_DELAY
                    self.reconnect_count = 0
                    logger.info("✓ WebSocket connected successfully")
                    # Message receiving loop
                    await self._receive_messages()
                # Iterating the connection ends without raising on a clean
                # close (e.g. code 1000/1001). Fall through to the backoff
                # below rather than reconnecting in a tight loop.
                if self.is_running:
                    logger.warning("WebSocket connection closed cleanly")
            except ConnectionClosed as e:
                logger.warning(f"WebSocket connection closed: {e.code} - {e.reason}")
            except WebSocketException as e:
                logger.error(f"WebSocket error: {e}")
                await self._notify_error(e)
            except Exception as e:
                logger.error(f"Unexpected error: {e}", exc_info=True)
                await self._notify_error(e)
            finally:
                # Clear the reference before backing off. While waiting to
                # reconnect, is_healthy() then reports False and stop() has
                # no stale socket to close.
                self.ws = None
            await self._handle_reconnect()
        logger.info("WebSocket client stopped")

    async def _receive_messages(self) -> None:
        """
        Decode each incoming frame and dispatch it to _process_message().

        A combined-stream frame (``{"stream": ..., "data": ...}``) is
        unwrapped to its ``data`` dict and tagged with ``_stream``. Every
        dispatched dict gets ``_received_at`` as a naive-UTC ISO-8601 string.
        A bad frame (invalid JSON or a processing error) is logged and
        skipped. Returns when iteration over ``self.ws`` ends.
        """
        if not self.ws:
            return
        async for message in self.ws:
            try:
                self.last_message_time = time.time()
                # Parse JSON message
                data = json.loads(message)
                # Naive UTC timestamp, same format as the deprecated
                # datetime.utcnow().isoformat().
                received_at = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
                if "stream" in data and "data" in data:
                    # Combined stream format
                    payload = data["data"]
                    payload["_stream"] = data["stream"]
                else:
                    # Single stream format
                    payload = data
                payload["_received_at"] = received_at
                await self._process_message(payload)
            except json.JSONDecodeError as e:
                logger.error(f"Failed to parse JSON: {e}, message: {message[:200]}")
            except Exception as e:
                logger.error(f"Error processing message: {e}", exc_info=True)

    @staticmethod
    async def _invoke_callback(callback: Callable[[Any], Any], arg: Any) -> None:
        """
        Call ``callback(arg)`` and await the result if it is a coroutine.

        This covers coroutine functions and plain callables that return a
        coroutine (lambdas, partials). Other return values, including Tasks,
        are left alone.
        """
        result = callback(arg)
        if asyncio.iscoroutine(result):
            await result

    async def _process_message(self, data: Dict[str, Any]) -> None:
        """Pass one message to ``on_message``; log (never raise) handler errors."""
        try:
            await self._invoke_callback(self.on_message, data)
        except Exception as e:
            logger.error(f"Error in message handler: {e}", exc_info=True)

    async def _notify_error(self, error: Exception) -> None:
        """
        Report ``error`` to ``on_error``, if one is set.

        Errors raised by the callback itself are logged and swallowed, so a
        faulty handler cannot terminate the reconnect loop.
        """
        if not self.on_error:
            return
        try:
            await self._invoke_callback(self.on_error, error)
        except Exception as callback_error:
            logger.error(f"Error in error handler: {callback_error}", exc_info=True)

    def _backoff_delay(self, attempt: int) -> float:
        """
        Return the delay in seconds before reconnect number ``attempt`` (1-based).

        The delay is ``reconnect_delay * reconnect_multiplier ** (attempt - 1)``,
        capped at ``max_reconnect_delay``. With unlimited retries the exponent
        grows without bound and the float arithmetic eventually raises
        OverflowError; in that case the cap is returned directly.
        """
        try:
            delay = self.reconnect_delay * (self.reconnect_multiplier ** (attempt - 1))
        except OverflowError:
            return self.max_reconnect_delay
        return min(delay, self.max_reconnect_delay)

    async def _sleep_unless_stopped(self, delay: float) -> None:
        """Sleep up to ``delay`` seconds, returning early once stop() is called."""
        stop_event = self._stop_event
        if stop_event is None:
            # connect() was driven without start(): there is no event to wait on.
            await asyncio.sleep(delay)
            return
        try:
            await asyncio.wait_for(stop_event.wait(), timeout=delay)
        except asyncio.TimeoutError:
            pass

    async def _handle_reconnect(self) -> None:
        """
        Wait out the next exponential-backoff delay before reconnecting.

        Does nothing if the client is no longer running. Increments
        ``reconnect_count``. If ``settings.MAX_RECONNECT_ATTEMPTS`` is
        positive and has been exceeded, stops the client instead of sleeping.
        """
        if not self.is_running:
            return
        self.reconnect_count += 1
        # Check max attempts
        if (
            settings.MAX_RECONNECT_ATTEMPTS > 0
            and self.reconnect_count > settings.MAX_RECONNECT_ATTEMPTS
        ):
            logger.error("Max reconnection attempts reached. Stopping client.")
            self.is_running = False
            return
        delay = self._backoff_delay(self.reconnect_count)
        logger.info(f"Reconnecting in {delay:.1f}s (attempt {self.reconnect_count})...")
        await self._sleep_unless_stopped(delay)

    async def start(self) -> None:
        """Run the client until stop() is called or reconnection gives up."""
        if self.is_running:
            logger.warning("Client is already running")
            return
        self.is_running = True
        # Created here, inside the running loop, rather than in __init__.
        self._stop_event = asyncio.Event()
        logger.info("Starting WebSocket client...")
        try:
            await self.connect()
        finally:
            # Also reached on cancellation, so a later start() is not
            # rejected with "already running".
            self.is_running = False

    async def stop(self) -> None:
        """Stop gracefully: end the run loop, wake any backoff sleep, close the socket."""
        logger.info("Stopping WebSocket client...")
        self.is_running = False
        if self._stop_event is not None:
            self._stop_event.set()
        if self.ws:
            await self.ws.close()
        self.ws = None

    def is_healthy(self, max_silence: float = 60.0) -> bool:
        """
        Return True if the client is running, connected and recently active.

        Args:
            max_silence: Maximum number of seconds since the last received
                message for the client to still count as healthy.
        """
        if not self.is_running or not self.ws:
            return False
        time_since_last_message = time.time() - self.last_message_time
        return time_since_last_message < max_silence