251 lines
8.2 KiB
Python
251 lines
8.2 KiB
Python
"""
|
||
YFinance 服务 - 美股数据获取
|
||
支持获取美股的实时行情和历史 K 线数据
|
||
"""
|
||
import pandas as pd
|
||
from typing import Dict, List, Optional
|
||
from datetime import datetime, timedelta
|
||
from app.utils.logger import logger
|
||
import time
|
||
|
||
|
||
class YFinanceService:
|
||
"""YFinance 服务类"""
|
||
|
||
def __init__(self):
|
||
"""初始化服务"""
|
||
try:
|
||
import yfinance as yf
|
||
self.yf = yf
|
||
self._cache = {} # 数据缓存
|
||
self._cache_time = {} # 缓存时间
|
||
self._cache_ttl = 300 # 缓存有效期(秒)
|
||
logger.info("YFinance 服务初始化成功")
|
||
except ImportError:
|
||
logger.error("yfinance 未安装,请运行: pip install yfinance")
|
||
raise
|
||
|
||
def get_ticker(self, symbol: str) -> Optional[Dict]:
|
||
"""
|
||
获取股票实时行情
|
||
|
||
Args:
|
||
symbol: 股票代码,如 'AAPL'
|
||
|
||
Returns:
|
||
行情数据字典
|
||
"""
|
||
try:
|
||
ticker = self.yf.Ticker(symbol)
|
||
|
||
# 使用 history 方法获取数据(更可靠,避免 429 错误)
|
||
hist = ticker.history(period="2d", interval="1h")
|
||
if hist.empty:
|
||
logger.warning(f"无法获取 {symbol} 的历史数据")
|
||
return None
|
||
|
||
latest = hist.iloc[-1]
|
||
|
||
return {
|
||
'symbol': symbol,
|
||
'lastPrice': float(latest['Close']),
|
||
'priceChange': float(latest['Close'] - latest['Open']),
|
||
'priceChangePercent': float((latest['Close'] - latest['Open']) / latest['Open'] * 100) if latest['Open'] > 0 else 0,
|
||
'volume': int(latest['Volume']),
|
||
'high': float(latest['High']),
|
||
'low': float(latest['Low']),
|
||
'open': float(latest['Open']),
|
||
'prevClose': float(latest['Close']),
|
||
'timestamp': datetime.now().isoformat()
|
||
}
|
||
except Exception as e:
|
||
error_msg = str(e)
|
||
# 过滤掉常见的 429 错误信息
|
||
if "429" in error_msg or "Too Many Requests" in error_msg:
|
||
logger.warning(f"YFinance API 限流,请稍后再试 ({symbol})")
|
||
else:
|
||
logger.error(f"获取 {symbol} 行情失败: {error_msg}")
|
||
return None
|
||
|
||
def get_multi_timeframe_data(
|
||
self,
|
||
symbol: str,
|
||
timeframes: Optional[Dict[str, tuple]] = None
|
||
) -> Dict[str, pd.DataFrame]:
|
||
"""
|
||
获取多时间周期的 K 线数据
|
||
|
||
Args:
|
||
symbol: 股票代码
|
||
timeframes: 时间周期配置 {'1d': ('1d', '3mo'), ...}
|
||
|
||
Returns:
|
||
多时间周期数据字典 {'1d': df, '1h': df, ...}
|
||
"""
|
||
if timeframes is None:
|
||
# 默认时间周期配置 - 股票不需要太实时的数据
|
||
timeframes = {
|
||
'1w': ('1wk', '2y'), # 周级别,2年
|
||
'1d': ('1d', '6mo'), # 日级别,6个月
|
||
'1h': ('1h', '1mo'), # 小时级别,1个月
|
||
}
|
||
|
||
result = {}
|
||
|
||
for tf_name, (interval, period) in timeframes.items():
|
||
try:
|
||
df = self._get_cached_data(symbol, interval, period)
|
||
if df is not None and not df.empty:
|
||
result[tf_name] = df
|
||
logger.debug(f"获取 {symbol} {tf_name} 数据成功: {len(df)} 条")
|
||
else:
|
||
logger.warning(f"获取 {symbol} {tf_name} 数据失败或为空")
|
||
except Exception as e:
|
||
logger.error(f"获取 {symbol} {tf_name} 数据出错: {e}")
|
||
|
||
return result
|
||
|
||
def _get_cached_data(
|
||
self,
|
||
symbol: str,
|
||
interval: str,
|
||
period: str
|
||
) -> Optional[pd.DataFrame]:
|
||
"""获取带缓存的数据"""
|
||
cache_key = f"{symbol}_{interval}_{period}"
|
||
now = datetime.now()
|
||
|
||
# 检查缓存
|
||
if cache_key in self._cache:
|
||
cache_time = self._cache_time.get(cache_key)
|
||
if cache_time and (now - cache_time).total_seconds() < self._cache_ttl:
|
||
logger.debug(f"使用缓存数据: {cache_key}")
|
||
return self._cache[cache_key]
|
||
|
||
# 获取新数据
|
||
try:
|
||
ticker = self.yf.Ticker(symbol)
|
||
df = ticker.history(period=period, interval=interval)
|
||
|
||
if df.empty:
|
||
return None
|
||
|
||
# 转换数据格式(兼容现有代码)
|
||
df = self._format_dataframe(df)
|
||
|
||
# 更新缓存
|
||
self._cache[cache_key] = df
|
||
self._cache_time[cache_key] = now
|
||
|
||
return df
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取数据失败 {cache_key}: {e}")
|
||
return None
|
||
|
||
def _format_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
|
||
"""
|
||
格式化 DataFrame 以兼容现有代码
|
||
|
||
yfinance 原始格式:
|
||
- 列名大写: Open, High, Low, Close, Volume
|
||
- 索引是 Datetime
|
||
|
||
转换后格式:
|
||
- 列名小写: open, high, low, close, volume
|
||
- 重置索引,time 作为一列
|
||
- 添加技术指标
|
||
"""
|
||
df = df.copy()
|
||
|
||
# 列名转为小写
|
||
df.columns = [col.lower() for col in df.columns]
|
||
|
||
# 重置索引
|
||
df = df.reset_index()
|
||
|
||
# 重命名日期列
|
||
if 'date' in df.columns:
|
||
df = df.rename(columns={'date': 'time'})
|
||
elif 'datetime' in df.columns:
|
||
df = df.rename(columns={'datetime': 'time'})
|
||
|
||
# 删除不需要的列
|
||
cols_to_keep = ['time', 'open', 'high', 'low', 'close', 'volume']
|
||
df = df[[col for col in cols_to_keep if col in df.columns]]
|
||
|
||
# 添加技术指标(与 binance_service 一致)
|
||
df = self._add_indicators(df)
|
||
|
||
return df
|
||
|
||
def _add_indicators(self, df: pd.DataFrame) -> pd.DataFrame:
|
||
"""
|
||
添加技术指标到 DataFrame
|
||
|
||
Args:
|
||
df: 原始数据
|
||
|
||
Returns:
|
||
添加了技术指标的 DataFrame
|
||
"""
|
||
df = df.copy()
|
||
|
||
# 移动平均线
|
||
df['ma5'] = df['close'].rolling(window=5).mean()
|
||
df['ma10'] = df['close'].rolling(window=10).mean()
|
||
df['ma20'] = df['close'].rolling(window=20).mean()
|
||
df['ma50'] = df['close'].rolling(window=50).mean()
|
||
|
||
# RSI(使用 Wilder's Smoothing 方法)
|
||
delta = df['close'].diff()
|
||
gain = delta.where(delta > 0, 0)
|
||
loss = -delta.where(delta < 0, 0)
|
||
# 使用 EMA (Wilder's Smoothing) 而不是简单平均
|
||
avg_gain = gain.ewm(alpha=1/14, adjust=False).mean()
|
||
avg_loss = loss.ewm(alpha=1/14, adjust=False).mean()
|
||
rs = avg_gain / avg_loss
|
||
df['rsi'] = 100 - (100 / (1 + rs))
|
||
|
||
# MACD (使用与 binance_service 相同的计算方法)
|
||
ema_fast = df['close'].ewm(span=12, adjust=False).mean()
|
||
ema_slow = df['close'].ewm(span=26, adjust=False).mean()
|
||
df['macd'] = ema_fast - ema_slow
|
||
df['macd_signal'] = df['macd'].ewm(span=9, adjust=False).mean()
|
||
df['macd_hist'] = df['macd'] - df['macd_signal']
|
||
|
||
# ATR
|
||
high_low = df['high'] - df['low']
|
||
high_close = abs(df['high'] - df['close'].shift())
|
||
low_close = abs(df['low'] - df['close'].shift())
|
||
true_range = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
|
||
df['atr'] = true_range.rolling(window=14).mean()
|
||
|
||
# KDJ 指标
|
||
low_min = df['low'].rolling(window=9).min()
|
||
high_max = df['high'].rolling(window=9).max()
|
||
rsv = (df['close'] - low_min) / (high_max - low_min) * 100
|
||
df['k'] = rsv.ewm(com=2, adjust=False).mean()
|
||
df['d'] = df['k'].ewm(com=2, adjust=False).mean()
|
||
df['j'] = 3 * df['k'] - 2 * df['d']
|
||
|
||
return df
|
||
|
||
def clear_cache(self):
|
||
"""清空缓存"""
|
||
self._cache.clear()
|
||
self._cache_time.clear()
|
||
logger.info("YFinance 缓存已清空")
|
||
|
||
|
||
# 全局单例
|
||
_yfinance_service: Optional[YFinanceService] = None
|
||
|
||
|
||
def get_yfinance_service() -> YFinanceService:
|
||
"""获取 YFinance 服务单例"""
|
||
global _yfinance_service
|
||
if _yfinance_service is None:
|
||
_yfinance_service = YFinanceService()
|
||
return _yfinance_service
|