208 lines
6.7 KiB
Python
208 lines
6.7 KiB
Python
"""
|
|
价格监控服务 - 使用轮询方式获取实时价格(更稳定)
|
|
"""
|
|
import threading
|
|
import time
|
|
import requests
|
|
from typing import Dict, List, Callable, Optional, Set
|
|
from app.utils.logger import logger
|
|
from app.config import get_settings
|
|
|
|
|
|
class PriceMonitorService:
|
|
"""实时价格监控服务(轮询模式)"""
|
|
|
|
# Binance API
|
|
BASE_URL = "https://api.binance.com"
|
|
|
|
def __init__(self):
|
|
"""初始化价格监控服务"""
|
|
self.settings = get_settings()
|
|
self.running = False
|
|
self.subscribed_symbols: Set[str] = set()
|
|
self.price_callbacks: List[Callable[[str, float], None]] = []
|
|
self.latest_prices: Dict[str, float] = {}
|
|
self._lock = threading.Lock()
|
|
self._poll_thread: Optional[threading.Thread] = None
|
|
self._poll_interval = 3 # 轮询间隔(秒)
|
|
self._session = requests.Session()
|
|
|
|
logger.info("价格监控服务初始化完成(轮询模式)")
|
|
|
|
def start(self):
|
|
"""启动价格轮询"""
|
|
if self.running:
|
|
logger.warning("价格监控服务已在运行")
|
|
return
|
|
|
|
self.running = True
|
|
|
|
def _poll_loop():
|
|
logger.info(f"价格轮询已启动,间隔 {self._poll_interval} 秒")
|
|
while self.running:
|
|
try:
|
|
self._fetch_prices()
|
|
except Exception as e:
|
|
logger.error(f"获取价格失败: {e}")
|
|
|
|
# 等待下一次轮询
|
|
for _ in range(self._poll_interval * 10):
|
|
if not self.running:
|
|
break
|
|
time.sleep(0.1)
|
|
|
|
self._poll_thread = threading.Thread(target=_poll_loop, daemon=True)
|
|
self._poll_thread.start()
|
|
|
|
def stop(self):
|
|
"""停止价格轮询"""
|
|
if not self.running:
|
|
return
|
|
|
|
self.running = False
|
|
logger.info("价格监控服务已停止")
|
|
|
|
def _fetch_prices(self):
|
|
"""获取所有订阅交易对的价格"""
|
|
if not self.subscribed_symbols:
|
|
return
|
|
|
|
symbols = list(self.subscribed_symbols)
|
|
|
|
# 如果只有少量交易对,逐个获取
|
|
if len(symbols) <= 3:
|
|
for symbol in symbols:
|
|
self._fetch_single_price(symbol)
|
|
else:
|
|
# 批量获取所有价格
|
|
self._fetch_all_prices(symbols)
|
|
|
|
def _fetch_single_price(self, symbol: str):
|
|
"""获取单个交易对价格"""
|
|
try:
|
|
url = f"{self.BASE_URL}/api/v3/ticker/price"
|
|
response = self._session.get(url, params={'symbol': symbol}, timeout=5)
|
|
response.raise_for_status()
|
|
data = response.json()
|
|
price = float(data['price'])
|
|
self._update_price(symbol, price)
|
|
except Exception as e:
|
|
logger.debug(f"获取 {symbol} 价格失败: {e}")
|
|
|
|
def _fetch_all_prices(self, symbols: List[str]):
|
|
"""批量获取价格"""
|
|
try:
|
|
url = f"{self.BASE_URL}/api/v3/ticker/price"
|
|
response = self._session.get(url, timeout=10)
|
|
response.raise_for_status()
|
|
all_prices = response.json()
|
|
|
|
# 过滤出订阅的交易对
|
|
symbol_set = set(symbols)
|
|
for item in all_prices:
|
|
symbol = item['symbol']
|
|
if symbol in symbol_set:
|
|
price = float(item['price'])
|
|
self._update_price(symbol, price)
|
|
except Exception as e:
|
|
logger.debug(f"批量获取价格失败: {e}")
|
|
|
|
def _update_price(self, symbol: str, price: float):
|
|
"""更新价格并触发回调"""
|
|
old_price = self.latest_prices.get(symbol)
|
|
|
|
# 只有价格变化时才触发回调
|
|
if old_price != price:
|
|
self.latest_prices[symbol] = price
|
|
|
|
# 调用所有注册的回调函数
|
|
with self._lock:
|
|
callbacks = self.price_callbacks.copy()
|
|
|
|
for callback in callbacks:
|
|
try:
|
|
callback(symbol, price)
|
|
except Exception as e:
|
|
logger.error(f"价格回调执行出错: {e}")
|
|
|
|
def subscribe_symbol(self, symbol: str):
|
|
"""
|
|
订阅交易对的实时价格
|
|
|
|
Args:
|
|
symbol: 交易对,如 "BTCUSDT"
|
|
"""
|
|
symbol = symbol.upper()
|
|
|
|
if symbol in self.subscribed_symbols:
|
|
logger.debug(f"已订阅 {symbol}")
|
|
return
|
|
|
|
self.subscribed_symbols.add(symbol)
|
|
logger.info(f"已订阅 {symbol} 价格更新")
|
|
|
|
# 如果服务未启动,自动启动
|
|
if not self.running:
|
|
self.start()
|
|
|
|
# 立即获取一次价格
|
|
self._fetch_single_price(symbol)
|
|
|
|
def unsubscribe_symbol(self, symbol: str):
|
|
"""取消订阅交易对"""
|
|
symbol = symbol.upper()
|
|
if symbol in self.subscribed_symbols:
|
|
self.subscribed_symbols.discard(symbol)
|
|
self.latest_prices.pop(symbol, None)
|
|
logger.info(f"已取消订阅 {symbol}")
|
|
|
|
def add_price_callback(self, callback: Callable[[str, float], None]):
|
|
"""
|
|
添加价格更新回调函数
|
|
|
|
Args:
|
|
callback: 回调函数,签名为 (symbol: str, price: float) -> None
|
|
"""
|
|
with self._lock:
|
|
if callback not in self.price_callbacks:
|
|
self.price_callbacks.append(callback)
|
|
|
|
def remove_price_callback(self, callback: Callable):
|
|
"""移除价格回调函数"""
|
|
with self._lock:
|
|
if callback in self.price_callbacks:
|
|
self.price_callbacks.remove(callback)
|
|
|
|
def get_latest_price(self, symbol: str) -> Optional[float]:
|
|
"""获取交易对的最新缓存价格"""
|
|
return self.latest_prices.get(symbol.upper())
|
|
|
|
def get_all_prices(self) -> Dict[str, float]:
|
|
"""获取所有订阅交易对的最新价格"""
|
|
return self.latest_prices.copy()
|
|
|
|
def get_subscribed_symbols(self) -> List[str]:
|
|
"""获取已订阅的交易对列表"""
|
|
return list(self.subscribed_symbols)
|
|
|
|
def is_running(self) -> bool:
|
|
"""检查服务是否在运行"""
|
|
return self.running
|
|
|
|
def set_poll_interval(self, seconds: int):
|
|
"""设置轮询间隔(秒)"""
|
|
self._poll_interval = max(1, seconds)
|
|
logger.info(f"轮询间隔已设置为 {self._poll_interval} 秒")
|
|
|
|
|
|
# 全局单例
|
|
_price_monitor_service: Optional[PriceMonitorService] = None
|
|
|
|
|
|
def get_price_monitor_service() -> PriceMonitorService:
|
|
"""获取价格监控服务单例"""
|
|
global _price_monitor_service
|
|
if _price_monitor_service is None:
|
|
_price_monitor_service = PriceMonitorService()
|
|
return _price_monitor_service
|