stock-ai-agent/backend/app/services/price_monitor_service.py
2026-02-25 23:28:04 +08:00

356 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
价格监控服务 - 支持 WebSocket 实时推送和轮询两种模式(统一使用 Bitget 数据源)
"""
import asyncio
import threading
import time
import requests
from typing import Dict, List, Callable, Optional, Set
from app.utils.logger import logger
from app.config import get_settings
class PriceMonitorService:
"""实时价格监控服务(支持 WebSocket 和轮询两种模式,统一使用 Bitget"""
_instance = None
_initialized = False
def __new__(cls):
"""单例模式 - 确保只有一个实例"""
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
"""初始化价格监控服务"""
# 防止重复初始化
if PriceMonitorService._initialized:
return
PriceMonitorService._initialized = True
self.settings = get_settings()
self.running = False
self.subscribed_symbols: Set[str] = set()
self.price_callbacks: List[Callable[[str, float], None]] = []
self.latest_prices: Dict[str, float] = {}
self._lock = threading.Lock()
# 模式选择
self._use_websocket = getattr(self.settings, 'use_bitget_websocket', False)
# 轮询模式相关(使用 Bitget REST API
self._poll_thread: Optional[threading.Thread] = None
self._poll_interval = 3 # 轮询间隔(秒)
self._session = requests.Session()
self._bitget_rest_url = "https://api.bitget.com" # Bitget REST API
# WebSocket 模式相关
self._ws_thread: Optional[threading.Thread] = None
self._ws_loop: Optional[asyncio.AbstractEventLoop] = None
self._ws_client = None
logger.info(f"价格监控服务初始化完成 (模式: {'Bitget WebSocket' if self._use_websocket else 'Bitget REST 轮询'})")
def start(self):
"""启动价格监控"""
if self.running:
logger.debug("价格监控服务已在运行")
return
self.running = True
if self._use_websocket:
self._start_websocket()
else:
self._start_polling()
def _start_polling(self):
"""启动轮询模式"""
def _poll_loop():
logger.info(f"价格轮询已启动,间隔 {self._poll_interval}")
while self.running:
try:
self._fetch_prices()
except Exception as e:
logger.error(f"获取价格失败: {e}")
# 等待下一次轮询
for _ in range(self._poll_interval * 10):
if not self.running:
break
time.sleep(0.1)
# 防止重复创建线程
if self._poll_thread is None or not self._poll_thread.is_alive():
self._poll_thread = threading.Thread(target=_poll_loop, daemon=True)
self._poll_thread.start()
def _start_websocket(self):
"""启动 WebSocket 模式"""
def _run_ws():
"""在新线程中运行 WebSocket"""
# 创建新的事件循环
self._ws_loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._ws_loop)
try:
# 导入 WebSocket 客户端
from app.services.bitget_websocket import get_bitget_ws_client
self._ws_client = get_bitget_ws_client()
if not self._ws_client:
logger.error("无法创建 WebSocket 客户端")
return
# 连接 - 在事件循环中运行
connect_task = self._ws_loop.create_task(self._ws_client.connect())
self._ws_loop.run_until_complete(connect_task)
if not connect_task.result():
logger.error("WebSocket 连接失败")
return
# 注册价格更新回调
self._ws_client.on_price_update('*', self._on_ws_price_update)
# 订阅已有的交易对
if self.subscribed_symbols:
subscribe_task = self._ws_loop.create_task(
self._ws_client.subscribe(list(self.subscribed_symbols))
)
self._ws_loop.run_until_complete(subscribe_task)
# 运行事件循环
self._ws_loop.run_forever()
except Exception as e:
logger.error(f"WebSocket 线程错误: {e}")
finally:
logger.info("WebSocket 线程已退出")
# 防止重复创建线程
if self._ws_thread is None or not self._ws_thread.is_alive():
self._ws_thread = threading.Thread(target=_run_ws, daemon=True)
self._ws_thread.start()
def _on_ws_price_update(self, symbol: str, price: float, data: Dict):
"""
WebSocket 价格更新回调
Args:
symbol: 交易对
price: 新价格
data: 完整的 ticker 数据
"""
self._update_price(symbol, price)
def stop(self):
"""停止价格监控"""
if not self.running:
return
self.running = False
# 停止 WebSocket
if self._use_websocket and self._ws_loop and self._ws_client:
# 在事件循环中停止 WebSocket
asyncio.run_coroutine_threadsafe(
self._ws_client.disconnect(),
self._ws_loop
)
# 停止事件循环
self._ws_loop.call_soon_threadsafe(self._ws_loop.stop)
# 停止轮询
if self._poll_thread:
self._poll_thread.join(timeout=2)
logger.info("价格监控服务已停止")
def _fetch_prices(self):
"""获取所有订阅交易对的价格(轮询模式)"""
if not self.subscribed_symbols:
return
symbols = list(self.subscribed_symbols)
# 如果只有少量交易对,逐个获取
if len(symbols) <= 3:
for symbol in symbols:
self._fetch_single_price(symbol)
else:
# 批量获取所有价格
self._fetch_all_prices(symbols)
def _fetch_single_price(self, symbol: str):
"""获取单个交易对价格(使用 Bitget REST API"""
try:
url = f"{self._bitget_rest_url}/api/v3/market/tickers"
params = {
'category': 'USDT-FUTURES',
'symbol': symbol
}
response = self._session.get(url, params=params, timeout=5)
response.raise_for_status()
result = response.json()
if result.get('code') != '00000':
logger.debug(f"Bitget API 错误: {result.get('msg')}")
return
data = result.get('data', [])
if data:
price = float(data[0].get('lastPr', '0'))
if price > 0:
self._update_price(symbol, price)
except Exception as e:
logger.debug(f"获取 {symbol} 价格失败: {e}")
def _fetch_all_prices(self, symbols: List[str]):
"""批量获取价格(使用 Bitget REST API"""
try:
url = f"{self._bitget_rest_url}/api/v3/market/tickers"
params = {'category': 'USDT-FUTURES'}
response = self._session.get(url, params=params, timeout=10)
response.raise_for_status()
result = response.json()
if result.get('code') != '00000':
logger.debug(f"Bitget API 错误: {result.get('msg')}")
return
all_tickers = result.get('data', [])
symbol_set = set(symbols)
# 过滤出订阅的交易对
for ticker in all_tickers:
symbol = ticker.get('instId')
if symbol in symbol_set:
price = float(ticker.get('lastPr', '0'))
if price > 0:
self._update_price(symbol, price)
except Exception as e:
logger.debug(f"批量获取价格失败: {e}")
def _update_price(self, symbol: str, price: float):
"""更新价格并触发回调"""
old_price = self.latest_prices.get(symbol)
# 只有价格变化时才触发回调
if old_price != price:
self.latest_prices[symbol] = price
# 调用所有注册的回调函数
with self._lock:
callbacks = self.price_callbacks.copy()
for callback in callbacks:
try:
callback(symbol, price)
except Exception as e:
logger.error(f"价格回调执行出错: {e}")
def subscribe_symbol(self, symbol: str):
"""
订阅交易对的实时价格
Args:
symbol: 交易对,如 "BTCUSDT"
"""
symbol = symbol.upper()
if symbol in self.subscribed_symbols:
logger.debug(f"{symbol} 已订阅,跳过")
return
self.subscribed_symbols.add(symbol)
logger.info(f"已订阅 {symbol} 价格更新 (当前订阅: {len(self.subscribed_symbols)} 个)")
# WebSocket 模式:立即订阅
if self._use_websocket and self._ws_client and self._ws_loop and self._ws_loop.is_running():
asyncio.run_coroutine_threadsafe(
self._ws_client.subscribe([symbol]),
self._ws_loop
)
elif self._use_websocket:
# WebSocket 还未就绪,将在连接后自动订阅
logger.debug(f"WebSocket 未就绪,{symbol} 将在连接后自动订阅")
# 轮询模式:立即获取一次价格
elif not self._use_websocket:
self._fetch_single_price(symbol)
# 如果服务未启动,自动启动
if not self.running:
self.start()
def unsubscribe_symbol(self, symbol: str):
"""取消订阅交易对"""
symbol = symbol.upper()
if symbol in self.subscribed_symbols:
self.subscribed_symbols.discard(symbol)
self.latest_prices.pop(symbol, None)
# WebSocket 模式:取消订阅
if self._use_websocket and self._ws_client:
asyncio.run_coroutine_threadsafe(
self._ws_client.unsubscribe([symbol]),
self._ws_loop
)
logger.info(f"已取消订阅 {symbol}")
def add_price_callback(self, callback: Callable[[str, float], None]):
"""
添加价格更新回调函数
Args:
callback: 回调函数,签名为 (symbol: str, price: float) -> None
"""
with self._lock:
if callback not in self.price_callbacks:
self.price_callbacks.append(callback)
def remove_price_callback(self, callback: Callable):
"""移除价格回调函数"""
with self._lock:
if callback in self.price_callbacks:
self.price_callbacks.remove(callback)
def get_latest_price(self, symbol: str) -> Optional[float]:
"""获取交易对的最新缓存价格"""
return self.latest_prices.get(symbol.upper())
def get_all_prices(self) -> Dict[str, float]:
"""获取所有订阅交易对的最新价格"""
return self.latest_prices.copy()
def get_subscribed_symbols(self) -> List[str]:
"""获取已订阅的交易对列表"""
return list(self.subscribed_symbols)
def is_running(self) -> bool:
"""检查服务是否在运行"""
return self.running
def set_poll_interval(self, seconds: int):
"""设置轮询间隔(秒),仅对轮询模式有效"""
if self._use_websocket:
logger.warning("WebSocket 模式下不支持设置轮询间隔")
return
self._poll_interval = max(1, seconds)
logger.info(f"轮询间隔已设置为 {self._poll_interval}")
# 全局单例
_price_monitor_service: Optional[PriceMonitorService] = None
def get_price_monitor_service() -> PriceMonitorService:
"""获取价格监控服务单例"""
# 直接使用类单例,不使用全局变量(避免 reload 时重置)
return PriceMonitorService()