356 lines
12 KiB
Python
356 lines
12 KiB
Python
"""
|
||
价格监控服务 - 支持 WebSocket 实时推送和轮询两种模式(统一使用 Bitget 数据源)
|
||
"""
|
||
import asyncio
|
||
import threading
|
||
import time
|
||
import requests
|
||
from typing import Dict, List, Callable, Optional, Set
|
||
from app.utils.logger import logger
|
||
from app.config import get_settings
|
||
|
||
|
||
class PriceMonitorService:
|
||
"""实时价格监控服务(支持 WebSocket 和轮询两种模式,统一使用 Bitget)"""
|
||
|
||
_instance = None
|
||
_initialized = False
|
||
|
||
def __new__(cls):
|
||
"""单例模式 - 确保只有一个实例"""
|
||
if cls._instance is None:
|
||
cls._instance = super().__new__(cls)
|
||
return cls._instance
|
||
|
||
def __init__(self):
|
||
"""初始化价格监控服务"""
|
||
# 防止重复初始化
|
||
if PriceMonitorService._initialized:
|
||
return
|
||
|
||
PriceMonitorService._initialized = True
|
||
self.settings = get_settings()
|
||
self.running = False
|
||
self.subscribed_symbols: Set[str] = set()
|
||
self.price_callbacks: List[Callable[[str, float], None]] = []
|
||
self.latest_prices: Dict[str, float] = {}
|
||
self._lock = threading.Lock()
|
||
|
||
# 模式选择
|
||
self._use_websocket = getattr(self.settings, 'use_bitget_websocket', False)
|
||
|
||
# 轮询模式相关(使用 Bitget REST API)
|
||
self._poll_thread: Optional[threading.Thread] = None
|
||
self._poll_interval = 3 # 轮询间隔(秒)
|
||
self._session = requests.Session()
|
||
self._bitget_rest_url = "https://api.bitget.com" # Bitget REST API
|
||
|
||
# WebSocket 模式相关
|
||
self._ws_thread: Optional[threading.Thread] = None
|
||
self._ws_loop: Optional[asyncio.AbstractEventLoop] = None
|
||
self._ws_client = None
|
||
|
||
logger.info(f"价格监控服务初始化完成 (模式: {'Bitget WebSocket' if self._use_websocket else 'Bitget REST 轮询'})")
|
||
|
||
def start(self):
|
||
"""启动价格监控"""
|
||
if self.running:
|
||
logger.debug("价格监控服务已在运行")
|
||
return
|
||
|
||
self.running = True
|
||
|
||
if self._use_websocket:
|
||
self._start_websocket()
|
||
else:
|
||
self._start_polling()
|
||
|
||
def _start_polling(self):
|
||
"""启动轮询模式"""
|
||
def _poll_loop():
|
||
logger.info(f"价格轮询已启动,间隔 {self._poll_interval} 秒")
|
||
while self.running:
|
||
try:
|
||
self._fetch_prices()
|
||
except Exception as e:
|
||
logger.error(f"获取价格失败: {e}")
|
||
|
||
# 等待下一次轮询
|
||
for _ in range(self._poll_interval * 10):
|
||
if not self.running:
|
||
break
|
||
time.sleep(0.1)
|
||
|
||
# 防止重复创建线程
|
||
if self._poll_thread is None or not self._poll_thread.is_alive():
|
||
self._poll_thread = threading.Thread(target=_poll_loop, daemon=True)
|
||
self._poll_thread.start()
|
||
|
||
def _start_websocket(self):
|
||
"""启动 WebSocket 模式"""
|
||
def _run_ws():
|
||
"""在新线程中运行 WebSocket"""
|
||
# 创建新的事件循环
|
||
self._ws_loop = asyncio.new_event_loop()
|
||
asyncio.set_event_loop(self._ws_loop)
|
||
|
||
try:
|
||
# 导入 WebSocket 客户端
|
||
from app.services.bitget_websocket import get_bitget_ws_client
|
||
|
||
self._ws_client = get_bitget_ws_client()
|
||
if not self._ws_client:
|
||
logger.error("无法创建 WebSocket 客户端")
|
||
return
|
||
|
||
# 连接 - 在事件循环中运行
|
||
connect_task = self._ws_loop.create_task(self._ws_client.connect())
|
||
self._ws_loop.run_until_complete(connect_task)
|
||
|
||
if not connect_task.result():
|
||
logger.error("WebSocket 连接失败")
|
||
return
|
||
|
||
# 注册价格更新回调
|
||
self._ws_client.on_price_update('*', self._on_ws_price_update)
|
||
|
||
# 订阅已有的交易对
|
||
if self.subscribed_symbols:
|
||
subscribe_task = self._ws_loop.create_task(
|
||
self._ws_client.subscribe(list(self.subscribed_symbols))
|
||
)
|
||
self._ws_loop.run_until_complete(subscribe_task)
|
||
|
||
# 运行事件循环
|
||
self._ws_loop.run_forever()
|
||
|
||
except Exception as e:
|
||
logger.error(f"WebSocket 线程错误: {e}")
|
||
finally:
|
||
logger.info("WebSocket 线程已退出")
|
||
|
||
# 防止重复创建线程
|
||
if self._ws_thread is None or not self._ws_thread.is_alive():
|
||
self._ws_thread = threading.Thread(target=_run_ws, daemon=True)
|
||
self._ws_thread.start()
|
||
|
||
def _on_ws_price_update(self, symbol: str, price: float, data: Dict):
|
||
"""
|
||
WebSocket 价格更新回调
|
||
|
||
Args:
|
||
symbol: 交易对
|
||
price: 新价格
|
||
data: 完整的 ticker 数据
|
||
"""
|
||
self._update_price(symbol, price)
|
||
|
||
def stop(self):
|
||
"""停止价格监控"""
|
||
if not self.running:
|
||
return
|
||
|
||
self.running = False
|
||
|
||
# 停止 WebSocket
|
||
if self._use_websocket and self._ws_loop and self._ws_client:
|
||
# 在事件循环中停止 WebSocket
|
||
asyncio.run_coroutine_threadsafe(
|
||
self._ws_client.disconnect(),
|
||
self._ws_loop
|
||
)
|
||
# 停止事件循环
|
||
self._ws_loop.call_soon_threadsafe(self._ws_loop.stop)
|
||
|
||
# 停止轮询
|
||
if self._poll_thread:
|
||
self._poll_thread.join(timeout=2)
|
||
|
||
logger.info("价格监控服务已停止")
|
||
|
||
def _fetch_prices(self):
|
||
"""获取所有订阅交易对的价格(轮询模式)"""
|
||
if not self.subscribed_symbols:
|
||
return
|
||
|
||
symbols = list(self.subscribed_symbols)
|
||
|
||
# 如果只有少量交易对,逐个获取
|
||
if len(symbols) <= 3:
|
||
for symbol in symbols:
|
||
self._fetch_single_price(symbol)
|
||
else:
|
||
# 批量获取所有价格
|
||
self._fetch_all_prices(symbols)
|
||
|
||
def _fetch_single_price(self, symbol: str):
|
||
"""获取单个交易对价格(使用 Bitget REST API)"""
|
||
try:
|
||
url = f"{self._bitget_rest_url}/api/v3/market/tickers"
|
||
params = {
|
||
'category': 'USDT-FUTURES',
|
||
'symbol': symbol
|
||
}
|
||
response = self._session.get(url, params=params, timeout=5)
|
||
response.raise_for_status()
|
||
result = response.json()
|
||
|
||
if result.get('code') != '00000':
|
||
logger.debug(f"Bitget API 错误: {result.get('msg')}")
|
||
return
|
||
|
||
data = result.get('data', [])
|
||
if data:
|
||
price = float(data[0].get('lastPr', '0'))
|
||
if price > 0:
|
||
self._update_price(symbol, price)
|
||
|
||
except Exception as e:
|
||
logger.debug(f"获取 {symbol} 价格失败: {e}")
|
||
|
||
def _fetch_all_prices(self, symbols: List[str]):
|
||
"""批量获取价格(使用 Bitget REST API)"""
|
||
try:
|
||
url = f"{self._bitget_rest_url}/api/v3/market/tickers"
|
||
params = {'category': 'USDT-FUTURES'}
|
||
response = self._session.get(url, params=params, timeout=10)
|
||
response.raise_for_status()
|
||
result = response.json()
|
||
|
||
if result.get('code') != '00000':
|
||
logger.debug(f"Bitget API 错误: {result.get('msg')}")
|
||
return
|
||
|
||
all_tickers = result.get('data', [])
|
||
symbol_set = set(symbols)
|
||
|
||
# 过滤出订阅的交易对
|
||
for ticker in all_tickers:
|
||
symbol = ticker.get('instId')
|
||
if symbol in symbol_set:
|
||
price = float(ticker.get('lastPr', '0'))
|
||
if price > 0:
|
||
self._update_price(symbol, price)
|
||
|
||
except Exception as e:
|
||
logger.debug(f"批量获取价格失败: {e}")
|
||
|
||
def _update_price(self, symbol: str, price: float):
|
||
"""更新价格并触发回调"""
|
||
old_price = self.latest_prices.get(symbol)
|
||
|
||
# 只有价格变化时才触发回调
|
||
if old_price != price:
|
||
self.latest_prices[symbol] = price
|
||
|
||
# 调用所有注册的回调函数
|
||
with self._lock:
|
||
callbacks = self.price_callbacks.copy()
|
||
|
||
for callback in callbacks:
|
||
try:
|
||
callback(symbol, price)
|
||
except Exception as e:
|
||
logger.error(f"价格回调执行出错: {e}")
|
||
|
||
def subscribe_symbol(self, symbol: str):
|
||
"""
|
||
订阅交易对的实时价格
|
||
|
||
Args:
|
||
symbol: 交易对,如 "BTCUSDT"
|
||
"""
|
||
symbol = symbol.upper()
|
||
|
||
if symbol in self.subscribed_symbols:
|
||
logger.debug(f"{symbol} 已订阅,跳过")
|
||
return
|
||
|
||
self.subscribed_symbols.add(symbol)
|
||
logger.info(f"已订阅 {symbol} 价格更新 (当前订阅: {len(self.subscribed_symbols)} 个)")
|
||
|
||
# WebSocket 模式:立即订阅
|
||
if self._use_websocket and self._ws_client and self._ws_loop and self._ws_loop.is_running():
|
||
asyncio.run_coroutine_threadsafe(
|
||
self._ws_client.subscribe([symbol]),
|
||
self._ws_loop
|
||
)
|
||
elif self._use_websocket:
|
||
# WebSocket 还未就绪,将在连接后自动订阅
|
||
logger.debug(f"WebSocket 未就绪,{symbol} 将在连接后自动订阅")
|
||
# 轮询模式:立即获取一次价格
|
||
elif not self._use_websocket:
|
||
self._fetch_single_price(symbol)
|
||
|
||
# 如果服务未启动,自动启动
|
||
if not self.running:
|
||
self.start()
|
||
|
||
def unsubscribe_symbol(self, symbol: str):
|
||
"""取消订阅交易对"""
|
||
symbol = symbol.upper()
|
||
|
||
if symbol in self.subscribed_symbols:
|
||
self.subscribed_symbols.discard(symbol)
|
||
self.latest_prices.pop(symbol, None)
|
||
|
||
# WebSocket 模式:取消订阅
|
||
if self._use_websocket and self._ws_client:
|
||
asyncio.run_coroutine_threadsafe(
|
||
self._ws_client.unsubscribe([symbol]),
|
||
self._ws_loop
|
||
)
|
||
|
||
logger.info(f"已取消订阅 {symbol}")
|
||
|
||
def add_price_callback(self, callback: Callable[[str, float], None]):
|
||
"""
|
||
添加价格更新回调函数
|
||
|
||
Args:
|
||
callback: 回调函数,签名为 (symbol: str, price: float) -> None
|
||
"""
|
||
with self._lock:
|
||
if callback not in self.price_callbacks:
|
||
self.price_callbacks.append(callback)
|
||
|
||
def remove_price_callback(self, callback: Callable):
|
||
"""移除价格回调函数"""
|
||
with self._lock:
|
||
if callback in self.price_callbacks:
|
||
self.price_callbacks.remove(callback)
|
||
|
||
def get_latest_price(self, symbol: str) -> Optional[float]:
|
||
"""获取交易对的最新缓存价格"""
|
||
return self.latest_prices.get(symbol.upper())
|
||
|
||
def get_all_prices(self) -> Dict[str, float]:
|
||
"""获取所有订阅交易对的最新价格"""
|
||
return self.latest_prices.copy()
|
||
|
||
def get_subscribed_symbols(self) -> List[str]:
|
||
"""获取已订阅的交易对列表"""
|
||
return list(self.subscribed_symbols)
|
||
|
||
def is_running(self) -> bool:
|
||
"""检查服务是否在运行"""
|
||
return self.running
|
||
|
||
def set_poll_interval(self, seconds: int):
|
||
"""设置轮询间隔(秒),仅对轮询模式有效"""
|
||
if self._use_websocket:
|
||
logger.warning("WebSocket 模式下不支持设置轮询间隔")
|
||
return
|
||
self._poll_interval = max(1, seconds)
|
||
logger.info(f"轮询间隔已设置为 {self._poll_interval} 秒")
|
||
|
||
|
||
# 全局单例
|
||
_price_monitor_service: Optional[PriceMonitorService] = None
|
||
|
||
|
||
def get_price_monitor_service() -> PriceMonitorService:
|
||
"""获取价格监控服务单例"""
|
||
# 直接使用类单例,不使用全局变量(避免 reload 时重置)
|
||
return PriceMonitorService()
|