stock-ai-agent/backend/app/services/binance_service.py
2026-02-19 19:32:46 +08:00

278 lines
9.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Binance 数据服务 - 获取加密货币 K 线数据和技术指标
使用 requests 直接调用 REST API,避免与 WebSocket 的事件循环冲突
"""
import pandas as pd
import numpy as np
import requests
from typing import Dict, List, Optional, Any
from app.utils.logger import logger
class BinanceService:
    """Binance market-data service that calls the REST API directly via ``requests``.

    Plain synchronous HTTP is used so these calls stay independent of any
    WebSocket event loop running elsewhere in the application. Public market
    data needs no credentials; if an API key is supplied it is sent with every
    request in the ``X-MBX-APIKEY`` header.
    """

    # Supported K-line intervals -> Binance interval strings. Unknown
    # intervals are forwarded unchanged by ``get_klines``.
    INTERVALS = {
        '5m': '5m',
        '15m': '15m',
        '1h': '1h',
        '4h': '4h',
    }

    # Binance spot REST API base URL.
    BASE_URL = "https://api.binance.com"

    # Per-request HTTP timeout in seconds.
    REQUEST_TIMEOUT = 10

    def __init__(self, api_key: str = "", api_secret: str = ""):
        """Initialise the service and its shared HTTP session.

        Args:
            api_key: Binance API key (optional; public market data does not need it).
            api_secret: Binance API secret (optional; stored but not read by any
                method of this class).
        """
        self._api_key = api_key
        self._api_secret = api_secret
        self._session = requests.Session()
        if api_key:
            self._session.headers.update({'X-MBX-APIKEY': api_key})
        logger.info("Binance 服务初始化完成")

    def _get_json(self, path: str, params: Dict[str, Any]) -> Any:
        """GET ``BASE_URL + path`` with ``params`` and return the decoded JSON body.

        Errors raised by ``requests`` (connection problems, timeouts, non-2xx
        statuses via ``raise_for_status``) and by JSON decoding propagate to
        the caller; every public method below catches and logs them.
        """
        response = self._session.get(
            f"{self.BASE_URL}{path}", params=params, timeout=self.REQUEST_TIMEOUT
        )
        response.raise_for_status()
        return response.json()

    def get_klines(self, symbol: str, interval: str, limit: int = 100) -> pd.DataFrame:
        """Fetch K-line (candlestick) data.

        Args:
            symbol: Trading pair, e.g. ``'BTCUSDT'``.
            interval: K-line interval such as ``'5m'``, ``'15m'``, ``'1h'``, ``'4h'``;
                values not in ``INTERVALS`` are forwarded to Binance as-is.
            limit: Number of candles to request.

        Returns:
            DataFrame with ``open_time, open, high, low, close, volume, trades``
            columns, or an empty DataFrame if the request or parsing fails
            (the failure is logged).
        """
        try:
            params = {
                'symbol': symbol,
                'interval': self.INTERVALS.get(interval, interval),
                'limit': limit,
            }
            return self._parse_klines(self._get_json("/api/v3/klines", params))
        except Exception as e:
            # Best-effort by design: callers treat an empty frame as "no data".
            logger.error(f"获取 {symbol} {interval} K线数据失败: {e}")
            return pd.DataFrame()

    def get_multi_timeframe_data(self, symbol: str) -> Dict[str, pd.DataFrame]:
        """Fetch 5m, 15m, 1h and 4h K-lines for ``symbol`` with indicators added.

        Args:
            symbol: Trading pair.

        Returns:
            Dict keyed by interval. Intervals whose fetch produced an empty
            frame are omitted.
        """
        # Candle counts per interval, trading analysis depth against cost:
        # 5m x200 ~ 16.7h (intraday), 15m x200 ~ 2.1d (short term),
        # 1h x300 = 12.5d (medium term), 4h x200 ~ 33.3d (trend).
        limits = {
            '5m': 200,
            '15m': 200,
            '1h': 300,
            '4h': 200,
        }
        data = {}
        for interval in ['5m', '15m', '1h', '4h']:
            df = self.get_klines(symbol, interval, limit=limits.get(interval, 100))
            if not df.empty:
                data[interval] = self.calculate_indicators(df)
        logger.info(f"获取 {symbol} 多周期数据完成")
        return data

    def _parse_klines(self, klines: List) -> pd.DataFrame:
        """Convert raw ``/api/v3/klines`` rows into a typed DataFrame.

        Each row must contain the 12 fields named below, in that order.
        ``open_time`` and ``close_time`` (epoch milliseconds) become naive
        datetimes. Price and volume fields become floats, and ``trades``
        becomes an int. Only ``open_time, open, high, low, close, volume,
        trades`` are returned.
        """
        if not klines:
            return pd.DataFrame()

        df = pd.DataFrame(klines, columns=[
            'open_time', 'open', 'high', 'low', 'close', 'volume',
            'close_time', 'quote_volume', 'trades',
            'taker_buy_base', 'taker_buy_quote', 'ignore',
        ])

        df['open_time'] = pd.to_datetime(df['open_time'], unit='ms')
        df['close_time'] = pd.to_datetime(df['close_time'], unit='ms')
        for col in ['open', 'high', 'low', 'close', 'volume', 'quote_volume']:
            df[col] = df[col].astype(float)
        df['trades'] = df['trades'].astype(int)

        # Return an independent frame so calculate_indicators can add columns
        # without any chained-assignment ambiguity with the full frame.
        return df[['open_time', 'open', 'high', 'low', 'close', 'volume', 'trades']].copy()

    def calculate_indicators(self, df: pd.DataFrame) -> pd.DataFrame:
        """Add technical-indicator columns to ``df`` (modified in place and returned).

        Requires ``high``, ``low``, ``close`` and ``volume`` columns. Adds:
        ma5/10/20/50, ema12/26, rsi, macd/macd_signal/macd_hist,
        bb_upper/bb_middle/bb_lower, k/d/j, atr, volume_ma5/volume_ma20 and
        volume_ratio. Rows inside an indicator's warm-up window hold NaN.

        Args:
            df: K-line DataFrame.

        Returns:
            The same DataFrame with indicator columns added; an empty input is
            returned unchanged.
        """
        if df.empty:
            return df

        close = df['close']

        # Simple moving averages.
        df['ma5'] = self._calculate_ma(close, 5)
        df['ma10'] = self._calculate_ma(close, 10)
        df['ma20'] = self._calculate_ma(close, 20)
        df['ma50'] = self._calculate_ma(close, 50)

        # Exponential moving averages.
        df['ema12'] = self._calculate_ema(close, 12)
        df['ema26'] = self._calculate_ema(close, 26)

        df['rsi'] = self._calculate_rsi(close, 14)
        df['macd'], df['macd_signal'], df['macd_hist'] = self._calculate_macd(close)
        df['bb_upper'], df['bb_middle'], df['bb_lower'] = self._calculate_bollinger(close)
        df['k'], df['d'], df['j'] = self._calculate_kdj(df['high'], df['low'], close)
        df['atr'] = self._calculate_atr(df['high'], df['low'], close)

        # Volume moving averages and volume ratio (current volume vs. its 20-bar mean).
        df['volume_ma5'] = self._calculate_ma(df['volume'], 5)
        df['volume_ma20'] = self._calculate_ma(df['volume'], 20)
        df['volume_ratio'] = df['volume'] / df['volume_ma20']

        return df

    @staticmethod
    def _calculate_ma(data: pd.Series, period: int) -> pd.Series:
        """Simple moving average over ``period`` bars (NaN until the window is full)."""
        return data.rolling(window=period).mean()

    @staticmethod
    def _calculate_ema(data: pd.Series, period: int) -> pd.Series:
        """Exponential moving average with ``span=period`` (recursive form, adjust=False)."""
        return data.ewm(span=period, adjust=False).mean()

    @staticmethod
    def _calculate_rsi(data: pd.Series, period: int = 14) -> pd.Series:
        """Relative Strength Index.

        Average gains and losses are simple rolling means over ``period``
        changes, not Wilder's smoothing. The first value appears once
        ``period`` real price changes are available (index ``period``).
        """
        delta = data.diff()
        # clip() keeps the leading NaN produced by diff(). The previous
        # where(..., 0) replaced it with 0, which emitted a value one bar early
        # that was really computed from only period-1 changes.
        gain = delta.clip(lower=0).rolling(window=period).mean()
        loss = (-delta).clip(lower=0).rolling(window=period).mean()
        rs = gain / loss  # loss == 0 -> inf -> RSI 100; both 0 -> NaN
        return 100 - (100 / (1 + rs))

    @staticmethod
    def _calculate_macd(data: pd.Series, fast: int = 12, slow: int = 26, signal: int = 9):
        """MACD line, signal line and histogram as a ``(macd, signal, hist)`` tuple."""
        ema_fast = data.ewm(span=fast, adjust=False).mean()
        ema_slow = data.ewm(span=slow, adjust=False).mean()
        macd = ema_fast - ema_slow
        signal_line = macd.ewm(span=signal, adjust=False).mean()
        histogram = macd - signal_line
        return macd, signal_line, histogram

    @staticmethod
    def _calculate_bollinger(data: pd.Series, period: int = 20, std_dev: float = 2.0):
        """Bollinger Bands as an ``(upper, middle, lower)`` tuple.

        The band width uses pandas' default rolling sample standard deviation
        (ddof=1).
        """
        middle = data.rolling(window=period).mean()
        std = data.rolling(window=period).std()
        upper = middle + (std * std_dev)
        lower = middle - (std * std_dev)
        return upper, middle, lower

    @staticmethod
    def _calculate_kdj(high: pd.Series, low: pd.Series, close: pd.Series,
                       period: int = 9, k_period: int = 3, d_period: int = 3):
        """KDJ stochastic oscillator as a ``(k, d, j)`` tuple.

        RSV is the close's position within the ``period``-bar high/low range
        (0-100). K and D are smoothed with ``ewm(com=n-1, adjust=False)``,
        i.e. alpha = 1/n. With the default n=3 this gives
        K = 2/3*K_prev + 1/3*RSV. A flat window (high_max == low_min) gives a
        NaN RSV for that bar.
        """
        low_min = low.rolling(window=period).min()
        high_max = high.rolling(window=period).max()
        rsv = (close - low_min) / (high_max - low_min) * 100
        k = rsv.ewm(com=k_period - 1, adjust=False).mean()
        d = k.ewm(com=d_period - 1, adjust=False).mean()
        j = 3 * k - 2 * d
        return k, d, j

    @staticmethod
    def _calculate_atr(high: pd.Series, low: pd.Series, close: pd.Series, period: int = 14):
        """Average True Range: simple rolling mean of the true range (not Wilder's smoothing)."""
        prev_close = close.shift()
        tr1 = high - low
        tr2 = abs(high - prev_close)
        tr3 = abs(low - prev_close)
        # max(axis=1) skips NaN, so the first bar's TR is just high - low.
        tr = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1)
        return tr.rolling(window=period).mean()

    def get_current_price(self, symbol: str) -> Optional[float]:
        """Return the latest price for ``symbol``, or ``None`` if the request fails (logged)."""
        try:
            ticker = self._get_json("/api/v3/ticker/price", {'symbol': symbol})
            return float(ticker['price'])
        except Exception as e:
            logger.error(f"获取 {symbol} 当前价格失败: {e}")
            return None

    def get_24h_stats(self, symbol: str) -> Optional[Dict[str, Any]]:
        """Return 24-hour ticker statistics for ``symbol`` as floats.

        Returns:
            Dict with ``price, price_change, price_change_percent, high, low,
            volume, quote_volume``, or ``None`` if the request fails (logged).
        """
        try:
            stats = self._get_json("/api/v3/ticker/24hr", {'symbol': symbol})
            return {
                'price': float(stats['lastPrice']),
                'price_change': float(stats['priceChange']),
                'price_change_percent': float(stats['priceChangePercent']),
                'high': float(stats['highPrice']),
                'low': float(stats['lowPrice']),
                'volume': float(stats['volume']),
                'quote_volume': float(stats['quoteVolume'])
            }
        except Exception as e:
            logger.error(f"获取 {symbol} 24h 统计失败: {e}")
            return None
# Shared module-level instance. It is created without credentials, which is
# sufficient for the public market-data endpoints.
binance_service = BinanceService(api_key="", api_secret="")