248 lines
7.7 KiB
Python
248 lines
7.7 KiB
Python
"""
Redis Stream writer with batch support and error handling
"""
|
|
import asyncio
|
|
import logging
|
|
from typing import Dict, Any, Optional
|
|
import orjson
|
|
import redis.asyncio as redis
|
|
from redis.exceptions import RedisError, ConnectionError as RedisConnectionError
|
|
|
|
from config import settings
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RedisStreamWriter:
    """
    Redis Stream writer for real-time market data.

    Features:
    - Async Redis client (redis.asyncio)
    - Automatic stream trimming (MAXLEN, approximate)
    - JSON serialization with orjson
    - Batched writes through a non-transactional pipeline
    - Write/error counters exposed through get_stats()

    Messages are routed to a kline, depth or trade stream based on their
    "_stream" name or their "e" event field; see _determine_stream_key().
    """

    def __init__(self):
        # Created by connect(); None until a connection has been verified.
        self.redis_client: Optional[redis.Redis] = None
        self.is_connected = False

        # Statistics
        self.stats = {
            "messages_written": 0,
            "kline_count": 0,
            "depth_count": 0,
            "trade_count": 0,
            "errors": 0,
        }

    @staticmethod
    async def _release_client(client) -> None:
        """
        Close a Redis client.

        Prefers aclose() when the installed redis-py provides it (newer
        releases deprecate the async close()), and falls back to close().
        """
        closer = getattr(client, "aclose", None) or client.close
        await closer()

    def _record_write(self, stream_key: str, count: int = 1) -> None:
        """
        Update the counters for messages written to a stream.

        Args:
            stream_key: Redis stream key the messages were written to
            count: Number of messages written
        """
        self.stats["messages_written"] += count
        # Per-type counters are keyed off the stream name, first match wins.
        if "kline" in stream_key:
            self.stats["kline_count"] += count
        elif "depth" in stream_key:
            self.stats["depth_count"] += count
        elif "trade" in stream_key:
            self.stats["trade_count"] += count

    async def connect(self) -> None:
        """
        Establish Redis connection and verify it with PING.

        On failure the half-initialized client is closed, redis_client is
        reset to None and is_connected to False, and the original exception
        is re-raised after being logged.
        """
        client = None
        try:
            client = redis.Redis(
                host=settings.REDIS_HOST,
                port=settings.REDIS_PORT,
                db=settings.REDIS_DB,
                password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
                encoding="utf-8",
                decode_responses=False,  # We'll handle JSON encoding
                socket_connect_timeout=5,
                socket_keepalive=True,
                health_check_interval=30,
            )

            # Test connection
            await client.ping()

        except Exception as e:
            if isinstance(e, RedisConnectionError):
                logger.error(f"Failed to connect to Redis: {e}")
            else:
                logger.error(f"Unexpected error connecting to Redis: {e}")

            # Never leave a client that failed verification installed.
            self.redis_client = None
            self.is_connected = False
            if client is not None:
                try:
                    await self._release_client(client)
                except Exception as cleanup_error:
                    # Cleanup is best effort; the connect error is what matters.
                    logger.warning(
                        f"Error releasing failed Redis client: {cleanup_error}"
                    )
            raise

        # Publish the client only after it has been verified.
        self.redis_client = client
        self.is_connected = True
        logger.info("✓ Redis connection established")

    async def close(self) -> None:
        """
        Close Redis connection.

        Does nothing if no client was ever created. The writer is marked
        disconnected even if closing the client raises.
        """
        if not self.redis_client:
            return

        try:
            await self._release_client(self.redis_client)
        finally:
            self.is_connected = False
        logger.info("Redis connection closed")

    def _serialize_message(self, message: Dict[str, Any]) -> bytes:
        """
        Serialize message to JSON bytes using orjson

        Args:
            message: Message data

        Returns:
            JSON bytes
        """
        return orjson.dumps(message)

    def _determine_stream_key(self, message: Dict[str, Any]) -> Optional[str]:
        """
        Determine which Redis Stream to write to based on message type

        The "_stream" name is checked first, then the "e" event field.
        A missing or non-string "_stream" is treated as empty.

        Args:
            message: Message data

        Returns:
            Redis stream key or None if unknown type
        """
        raw_stream = message.get("_stream")
        stream = raw_stream if isinstance(raw_stream, str) else ""
        event = message.get("e")

        # Kline stream - one Redis stream per interval
        if "kline" in stream or event == "kline":
            # Extract interval from stream name (e.g., "btcusdt@kline_5m" -> "5m")
            if "@kline_" in stream:
                interval = stream.split("@kline_")[1]
                if interval:
                    return f"{settings.REDIS_STREAM_KLINE_PREFIX}:{interval}"

            # Fallback: extract from message data
            kline = message.get("k")
            if isinstance(kline, dict) and "i" in kline:
                return f"{settings.REDIS_STREAM_KLINE_PREFIX}:{kline['i']}"

        # Depth stream
        if "depth" in stream or event == "depthUpdate":
            return settings.REDIS_STREAM_DEPTH

        # Trade stream ("aggTrade" is checked separately: match is case-sensitive)
        if "trade" in stream or "aggTrade" in stream or event in ("trade", "aggTrade"):
            return settings.REDIS_STREAM_TRADE

        logger.warning(f"Unknown message type, stream: {stream}, message: {message}")
        return None

    async def write_message(self, message: Dict[str, Any]) -> bool:
        """
        Write single message to appropriate Redis Stream

        Args:
            message: Message data

        Returns:
            True if successful, False otherwise (not connected, unknown
            message type, or an error that was logged and counted)
        """
        if not self.is_connected or not self.redis_client:
            logger.error("Redis client not connected")
            return False

        try:
            # Determine stream key
            stream_key = self._determine_stream_key(message)
            if not stream_key:
                return False

            # Serialize message
            payload = self._serialize_message(message)

            # Write to Redis Stream with MAXLEN
            await self.redis_client.xadd(
                name=stream_key,
                fields={"data": payload},
                maxlen=settings.REDIS_STREAM_MAXLEN,
                approximate=True,  # Use ~ for better performance
            )

        except RedisError as e:
            logger.error(f"Redis error writing message: {e}")
            self.stats["errors"] += 1
            return False

        except Exception as e:
            logger.error(f"Unexpected error writing message: {e}", exc_info=True)
            self.stats["errors"] += 1
            return False

        # Update statistics
        self._record_write(stream_key)
        return True

    async def write_batch(self, messages: list[Dict[str, Any]]) -> int:
        """
        Write batch of messages using pipeline

        Messages of unknown type are skipped. Commands are sent in one
        non-transactional pipeline; a command rejected by the server does
        not discard the others, so the return value counts only the
        commands that succeeded.

        Args:
            messages: List of messages

        Returns:
            Number of successfully written messages
        """
        if not self.is_connected or not self.redis_client:
            logger.error("Redis client not connected")
            return 0

        if not messages:
            return 0

        try:
            # Route and serialize everything before touching Redis
            queued: list[tuple[str, bytes]] = []
            for message in messages:
                stream_key = self._determine_stream_key(message)
                if not stream_key:
                    continue
                queued.append((stream_key, self._serialize_message(message)))

            # Nothing routable: skip the round trip entirely
            if not queued:
                return 0

            # Write using pipeline
            async with self.redis_client.pipeline(transaction=False) as pipe:
                for stream_key, payload in queued:
                    pipe.xadd(
                        name=stream_key,
                        fields={"data": payload},
                        maxlen=settings.REDIS_STREAM_MAXLEN,
                        approximate=True,
                    )

                # Collect per-command errors instead of raising on the first,
                # so successful writes are still counted.
                results = await pipe.execute(raise_on_error=False)

        except RedisError as e:
            logger.error(f"Redis error in batch write: {e}")
            self.stats["errors"] += 1
            return 0

        except Exception as e:
            logger.error(f"Unexpected error in batch write: {e}", exc_info=True)
            self.stats["errors"] += 1
            return 0

        # Update statistics
        total_written = 0
        failures = []
        for (stream_key, _), result in zip(queued, results):
            if isinstance(result, Exception):
                failures.append(result)
            else:
                self._record_write(stream_key)
                total_written += 1

        if failures:
            self.stats["errors"] += len(failures)
            logger.error(
                f"Redis error in batch write: {len(failures)} of {len(queued)} "
                f"commands failed, first error: {failures[0]}"
            )

        return total_written

    async def health_check(self) -> bool:
        """
        Check Redis connection health

        Returns:
            True if a client exists and answers PING, False otherwise
        """
        if not self.redis_client:
            return False

        try:
            await self.redis_client.ping()
        except Exception as e:
            # Best effort probe: report unhealthy, keep the reason for debugging
            logger.debug(f"Redis health check failed: {e}")
            return False

        return True

    def get_stats(self) -> Dict[str, Any]:
        """Get writer statistics (a copy; mutating it does not affect the writer)"""
        return dict(self.stats)
|