tradusai/core/redis_writer.py
2025-12-02 22:54:03 +08:00

248 lines
7.7 KiB
Python

"""
Redis Stream writer with batch support and error handling
"""
import asyncio
import logging
from typing import Dict, Any, Optional
import orjson
import redis.asyncio as redis
from redis.exceptions import RedisError, ConnectionError as RedisConnectionError
from config import settings
# Module-level logger, named after this module (__name__).
logger = logging.getLogger(__name__)
class RedisStreamWriter:
    """
    Redis Stream writer for real-time market data.

    Features:
    - Async Redis client (redis.asyncio) using its built-in connection pool
    - Automatic stream trimming (approximate MAXLEN on every XADD)
    - JSON serialization with orjson
    - Connection health checks (socket keepalive + periodic PING)
    - Per-stream-type write counters and an error counter
    """

    def __init__(self):
        # Created by connect(); None before that and after close().
        self.redis_client: Optional[redis.Redis] = None
        self.is_connected = False

        # Statistics (read a snapshot via get_stats()).
        self.stats = {
            "messages_written": 0,
            "kline_count": 0,
            "depth_count": 0,
            "trade_count": 0,
            "errors": 0,
        }

    async def connect(self) -> None:
        """
        Create the Redis client and verify it with a PING.

        On success, any previously held client is closed and replaced. On
        failure, the newly created client is closed so its connection pool is
        not leaked, the writer's previous state is left unchanged, and the
        error is re-raised.

        Raises:
            redis.exceptions.ConnectionError: Redis is unreachable.
            Exception: Any other error while creating or pinging the client.
        """
        client = None
        try:
            client = redis.Redis(
                host=settings.REDIS_HOST,
                port=settings.REDIS_PORT,
                db=settings.REDIS_DB,
                password=settings.REDIS_PASSWORD or None,
                encoding="utf-8",
                decode_responses=False,  # payloads are already JSON bytes
                socket_connect_timeout=5,
                socket_keepalive=True,
                health_check_interval=30,
            )
            # Test connection
            await client.ping()
        except RedisConnectionError as e:
            logger.error("Failed to connect to Redis: %s", e)
            await self._close_quietly(client)
            raise
        except Exception as e:
            logger.error("Unexpected error connecting to Redis: %s", e)
            await self._close_quietly(client)
            raise

        previous, self.redis_client = self.redis_client, client
        self.is_connected = True
        logger.info("✓ Redis connection established")
        if previous is not None:
            # Reconnect: release the old client's pool instead of leaking it.
            await self._close_quietly(previous)

    async def close(self) -> None:
        """
        Close the Redis connection, if one is open.

        The writer is marked disconnected and its client reference is cleared
        before closing. As a result, write_*() and health_check() stop using it
        even if closing raises; any such error still propagates to the caller.
        """
        client, self.redis_client = self.redis_client, None
        self.is_connected = False
        if client is None:
            return
        await self._close_client(client)
        logger.info("Redis connection closed")

    @staticmethod
    async def _close_client(client: Any) -> None:
        """Close *client*, preferring aclose() (redis-py >= 5.0.1) over the deprecated close()."""
        closer = getattr(client, "aclose", None) or client.close
        await closer()

    async def _close_quietly(self, client: Any) -> None:
        """Best-effort close used on cleanup paths; never raises."""
        if client is None:
            return
        try:
            await self._close_client(client)
        except Exception as e:
            logger.debug("Ignoring error while closing Redis client: %s", e)

    def _serialize_message(self, message: Dict[str, Any]) -> bytes:
        """
        Serialize message to JSON bytes using orjson.

        Args:
            message: Message data

        Returns:
            JSON bytes

        Raises:
            orjson.JSONEncodeError (a TypeError subclass) for unserializable data.
        """
        return orjson.dumps(message)

    def _determine_stream_key(self, message: Dict[str, Any]) -> Optional[str]:
        """
        Map a message to the Redis Stream it should be written to.

        The "_stream" field (e.g. "btcusdt@kline_5m") is checked first, then
        the "e" event-type field. Kline stream keys are suffixed with the
        interval, taken from the stream name or, failing that, from
        message["k"]["i"].

        Args:
            message: Message data

        Returns:
            Redis stream key, or None (after logging a warning) if the type is unknown
        """
        stream = message.get("_stream") or ""
        event = message.get("e")

        # Kline stream - the interval becomes part of the key
        if "kline" in stream or event == "kline":
            if "@kline_" in stream:
                # e.g. "btcusdt@kline_5m" -> "5m"
                interval = stream.split("@kline_")[1]
                return f"{settings.REDIS_STREAM_KLINE_PREFIX}:{interval}"
            kline = message.get("k")
            if isinstance(kline, dict) and "i" in kline:
                return f"{settings.REDIS_STREAM_KLINE_PREFIX}:{kline['i']}"
            # No interval available: fall through to the checks below.

        # Depth stream
        if "depth" in stream or event == "depthUpdate":
            return settings.REDIS_STREAM_DEPTH

        # Trade stream ("aggTrade" needs its own check: matching is case-sensitive)
        if "trade" in stream or "aggTrade" in stream or event in ("trade", "aggTrade"):
            return settings.REDIS_STREAM_TRADE

        logger.warning("Unknown message type, stream: %s, message: %s", stream, message)
        return None

    def _record_write(self, stream_key: str, count: int = 1) -> None:
        """Add *count* successful writes to the total and to the per-type counter for *stream_key*."""
        self.stats["messages_written"] += count
        # The type is inferred from a substring of the configured stream key.
        if "kline" in stream_key:
            self.stats["kline_count"] += count
        elif "depth" in stream_key:
            self.stats["depth_count"] += count
        elif "trade" in stream_key:
            self.stats["trade_count"] += count

    async def write_message(self, message: Dict[str, Any]) -> bool:
        """
        Write a single message to its Redis Stream.

        Args:
            message: Message data

        Returns:
            True if written; False if not connected, the type is unknown, or
            the write failed (failures increment stats["errors"])
        """
        if not self.is_connected or self.redis_client is None:
            logger.error("Redis client not connected")
            return False

        try:
            stream_key = self._determine_stream_key(message)
            if not stream_key:
                return False
            # Approximate MAXLEN ("~") lets Redis trim lazily, which is cheaper.
            await self.redis_client.xadd(
                name=stream_key,
                fields={"data": self._serialize_message(message)},
                maxlen=settings.REDIS_STREAM_MAXLEN,
                approximate=True,
            )
        except RedisError as e:
            logger.error("Redis error writing message: %s", e)
            self.stats["errors"] += 1
            return False
        except Exception as e:
            logger.error("Unexpected error writing message: %s", e, exc_info=True)
            self.stats["errors"] += 1
            return False

        self._record_write(stream_key)
        return True

    def _prepare_batch(self, messages: list[Dict[str, Any]]) -> list[tuple[str, bytes]]:
        """Resolve stream keys and serialize; skip unknown types and count unserializable messages as errors."""
        entries: list[tuple[str, bytes]] = []
        for message in messages:
            stream_key = self._determine_stream_key(message)
            if not stream_key:
                continue
            try:
                payload = self._serialize_message(message)
            except TypeError as e:  # orjson.JSONEncodeError subclasses TypeError
                logger.error("Failed to serialize message for %s: %s", stream_key, e)
                self.stats["errors"] += 1
                continue
            entries.append((stream_key, payload))
        return entries

    async def write_batch(self, messages: list[Dict[str, Any]]) -> int:
        """
        Write a batch of messages using one non-transactional pipeline.

        Messages of unknown type are skipped. Messages that fail to serialize
        are skipped and counted as errors. Individual commands rejected by
        Redis are counted as errors one by one, so the return value reflects
        what was actually written.

        Args:
            messages: List of messages

        Returns:
            Number of successfully written messages (0 if not connected, if
            nothing was writable, or if the pipeline as a whole failed)
        """
        if not self.is_connected or self.redis_client is None:
            logger.error("Redis client not connected")
            return 0
        if not messages:
            return 0

        try:
            entries = self._prepare_batch(messages)
            if not entries:
                return 0
            async with self.redis_client.pipeline(transaction=False) as pipe:
                for stream_key, payload in entries:
                    pipe.xadd(
                        name=stream_key,
                        fields={"data": payload},
                        maxlen=settings.REDIS_STREAM_MAXLEN,
                        approximate=True,
                    )
                # Collect per-command errors instead of raising on the first one,
                # so writes that did succeed in this batch are still counted.
                results = await pipe.execute(raise_on_error=False)
        except RedisError as e:
            logger.error("Redis error in batch write: %s", e)
            self.stats["errors"] += 1
            return 0
        except Exception as e:
            logger.error("Unexpected error in batch write: %s", e, exc_info=True)
            self.stats["errors"] += 1
            return 0

        written = 0
        failed = 0
        first_error = None
        for (stream_key, _), result in zip(entries, results):
            if isinstance(result, Exception):
                failed += 1
                if first_error is None:
                    first_error = result
                continue
            self._record_write(stream_key)
            written += 1

        if failed:
            self.stats["errors"] += failed
            logger.error(
                "Redis error in batch write: %d of %d commands failed (first: %s)",
                failed, len(entries), first_error,
            )
        return written

    async def health_check(self) -> bool:
        """Return True if a client exists and answers PING; never raises."""
        if self.redis_client is None:
            return False
        try:
            await self.redis_client.ping()
        except Exception:
            # Deliberately best-effort: any failure just means "unhealthy".
            return False
        return True

    def get_stats(self) -> Dict[str, Any]:
        """Return a shallow copy of the writer statistics."""
        return dict(self.stats)