stock-ai-agent/backend/app/services/yfinance_service.py
2026-02-19 21:20:20 +08:00

247 lines
8.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
YFinance 服务 - 美股数据获取
支持获取美股的实时行情和历史 K 线数据
"""
import pandas as pd
from typing import Dict, List, Optional
from datetime import datetime, timedelta
from app.utils.logger import logger
import time
class YFinanceService:
    """Fetch US-equity quotes and OHLCV candles through yfinance.

    Formatted candle DataFrames are cached in memory per
    ``(symbol, interval, period)`` for ``_cache_ttl`` seconds so repeated
    requests do not hit the upstream API.
    """

    def __init__(self):
        """Import yfinance lazily and set up the in-memory candle cache.

        Raises:
            ImportError: if the ``yfinance`` package is not installed.
        """
        try:
            import yfinance as yf
        except ImportError:
            logger.error("yfinance 未安装,请运行: pip install yfinance")
            raise
        self.yf = yf
        self._cache = {}  # cache_key -> formatted DataFrame
        self._cache_time = {}  # cache_key -> time.monotonic() at fetch time
        self._cache_ttl = 300  # cache lifetime in seconds
        logger.info("YFinance 服务初始化成功")

    @staticmethod
    def _is_rate_limit_error(error: Exception) -> bool:
        """Return True if ``error``'s message looks like an HTTP 429 rate-limit."""
        message = str(error)
        return "429" in message or "Too Many Requests" in message

    @staticmethod
    def _previous_close(hist: pd.DataFrame) -> float:
        """Return the close that precedes the latest bar's trading session.

        Resolution order:
        1. the last ``Close`` whose calendar date (in the index's own timezone)
           is earlier than the latest bar's date, i.e. the prior session close;
        2. otherwise the previous bar's ``Close``;
        3. with a single bar, that bar's ``Open`` (the only earlier price known).

        Args:
            hist: non-empty yfinance ``history()`` frame with ``Open``/``Close``.
        """
        closes = hist['Close']
        if len(hist) < 2:
            return float(hist['Open'].iloc[-1])
        if isinstance(hist.index, pd.DatetimeIndex):
            session_days = hist.index.normalize()
            earlier = closes[session_days < session_days[-1]]
            if not earlier.empty:
                return float(earlier.iloc[-1])
        return float(closes.iloc[-2])

    def get_ticker(self, symbol: str) -> Optional[Dict]:
        """
        Get a near-real-time quote for a stock.

        The quote is built from the last 2 days of 1h bars. ``priceChange`` /
        ``priceChangePercent`` / ``open`` / ``high`` / ``low`` / ``volume``
        describe the latest hourly bar; ``prevClose`` is the previous
        session's close.

        Args:
            symbol: ticker symbol, e.g. 'AAPL'
        Returns:
            Quote dict, or None when no data is available or the request failed.
        """
        try:
            ticker = self.yf.Ticker(symbol)
            # history() is used instead of the quote endpoints: more reliable
            # and less prone to HTTP 429 rate limiting.
            hist = ticker.history(period="2d", interval="1h")
            if hist.empty:
                logger.warning(f"无法获取 {symbol} 的历史数据")
                return None
            latest = hist.iloc[-1]
            last_price = float(latest['Close'])
            bar_open = float(latest['Open'])
            change = last_price - bar_open
            return {
                'symbol': symbol,
                'lastPrice': last_price,
                'priceChange': change,
                'priceChangePercent': change / bar_open * 100 if bar_open > 0 else 0.0,
                'volume': int(latest['Volume']),
                'high': float(latest['High']),
                'low': float(latest['Low']),
                'open': bar_open,
                'prevClose': self._previous_close(hist),
                'timestamp': datetime.now().isoformat(),
            }
        except Exception as e:
            # Rate limiting is expected under load; log it quietly.
            if self._is_rate_limit_error(e):
                logger.warning(f"YFinance API 限流,请稍后再试 ({symbol})")
            else:
                logger.error(f"获取 {symbol} 行情失败: {e}")
            return None

    def get_multi_timeframe_data(
        self,
        symbol: str,
        timeframes: Optional[Dict[str, tuple]] = None
    ) -> Dict[str, pd.DataFrame]:
        """
        Get candle data for several timeframes.

        Args:
            symbol: ticker symbol
            timeframes: mapping name -> (interval, period),
                e.g. {'1d': ('1d', '3mo'), ...}; defaults to daily/3mo and
                hourly/1mo.
        Returns:
            Dict of timeframe name -> DataFrame, e.g. {'1d': df, '1h': df}.
            Timeframes that failed or returned no rows are omitted.
        """
        if timeframes is None:
            timeframes = {
                '1d': ('1d', '3mo'),  # daily bars, 3 months
                '1h': ('1h', '1mo'),  # hourly bars, 1 month
            }
        result = {}
        for tf_name, (interval, period) in timeframes.items():
            try:
                df = self._get_cached_data(symbol, interval, period)
                if df is not None and not df.empty:
                    result[tf_name] = df
                    logger.debug(f"获取 {symbol} {tf_name} 数据成功: {len(df)}")
                else:
                    logger.warning(f"获取 {symbol} {tf_name} 数据失败或为空")
            except Exception as e:
                logger.error(f"获取 {symbol} {tf_name} 数据出错: {e}")
        return result

    def _get_cached_data(
        self,
        symbol: str,
        interval: str,
        period: str
    ) -> Optional[pd.DataFrame]:
        """Return formatted candles, served from the TTL cache when fresh.

        A copy is always returned so callers can mutate the result without
        corrupting the cached frame. Returns None on empty data or errors.
        """
        cache_key = f"{symbol}_{interval}_{period}"
        # Monotonic clock: TTL checks are immune to wall-clock adjustments.
        now = time.monotonic()
        cached_at = self._cache_time.get(cache_key)
        if (
            cache_key in self._cache
            and cached_at is not None
            and now - cached_at < self._cache_ttl
        ):
            logger.debug(f"使用缓存数据: {cache_key}")
            return self._cache[cache_key].copy()
        try:
            ticker = self.yf.Ticker(symbol)
            df = ticker.history(period=period, interval=interval)
            if df.empty:
                return None
            df = self._format_dataframe(df)
            self._cache[cache_key] = df
            self._cache_time[cache_key] = now
            return df.copy()
        except Exception as e:
            if self._is_rate_limit_error(e):
                logger.warning(f"YFinance API 限流,请稍后再试 ({cache_key})")
            else:
                logger.error(f"获取数据失败 {cache_key}: {e}")
            return None

    def _format_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Normalise a yfinance frame to the layout used by the rest of the app.

        yfinance input:
        - capitalised columns: Open, High, Low, Close, Volume (plus extras
          such as Dividends / Stock Splits)
        - the timestamps live in the index (named 'Date' or 'Datetime')
        Output:
        - lower-case columns: time, open, high, low, close, volume
        - timestamps moved from the index into the 'time' column
        - technical-indicator columns appended (see ``_add_indicators``)
        """
        # Reset the index BEFORE lower-casing so the index name ('Date' /
        # 'Datetime') is lower-cased too and can be renamed to 'time'.
        df = df.reset_index()
        df.columns = [str(col).lower() for col in df.columns]
        # 'index' is what reset_index() produces for an unnamed index.
        for time_col in ('date', 'datetime', 'index'):
            if time_col in df.columns:
                df = df.rename(columns={time_col: 'time'})
                break
        cols_to_keep = ['time', 'open', 'high', 'low', 'close', 'volume']
        df = df[[col for col in cols_to_keep if col in df.columns]]
        # Same indicator set as binance_service.
        return self._add_indicators(df)

    def _add_indicators(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Append technical-indicator columns to a candle DataFrame.

        Adds ma5/ma10/ma20/ma50, rsi (14, simple rolling means), macd /
        macd_signal / macd_hist (12/26/9 EMAs), atr (14, simple rolling mean
        of true range) and KDJ k/d/j (9-bar RSV).

        Args:
            df: frame with close/high/low columns
        Returns:
            A new DataFrame with the indicator columns added; leading rows are
            NaN until each rolling window fills.
        """
        df = df.copy()
        # Moving averages
        df['ma5'] = df['close'].rolling(window=5).mean()
        df['ma10'] = df['close'].rolling(window=10).mean()
        df['ma20'] = df['close'].rolling(window=20).mean()
        df['ma50'] = df['close'].rolling(window=50).mean()
        # RSI
        delta = df['close'].diff()
        gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
        loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
        rs = gain / loss
        df['rsi'] = 100 - (100 / (1 + rs))
        # MACD (same formulation as binance_service)
        ema_fast = df['close'].ewm(span=12, adjust=False).mean()
        ema_slow = df['close'].ewm(span=26, adjust=False).mean()
        df['macd'] = ema_fast - ema_slow
        df['macd_signal'] = df['macd'].ewm(span=9, adjust=False).mean()
        df['macd_hist'] = df['macd'] - df['macd_signal']
        # ATR
        high_low = df['high'] - df['low']
        high_close = abs(df['high'] - df['close'].shift())
        low_close = abs(df['low'] - df['close'].shift())
        true_range = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
        df['atr'] = true_range.rolling(window=14).mean()
        # KDJ
        low_min = df['low'].rolling(window=9).min()
        high_max = df['high'].rolling(window=9).max()
        rsv = (df['close'] - low_min) / (high_max - low_min) * 100
        df['k'] = rsv.ewm(com=2, adjust=False).mean()
        df['d'] = df['k'].ewm(com=2, adjust=False).mean()
        df['j'] = 3 * df['k'] - 2 * df['d']
        return df

    def clear_cache(self):
        """Drop every cached candle frame and its timestamp."""
        self._cache.clear()
        self._cache_time.clear()
        logger.info("YFinance 缓存已清空")
# Process-wide shared instance; created lazily on the first accessor call.
_yfinance_service: Optional[YFinanceService] = None


def get_yfinance_service() -> YFinanceService:
    """Return the shared YFinanceService, constructing it on first use."""
    global _yfinance_service
    if _yfinance_service is not None:
        return _yfinance_service
    _yfinance_service = YFinanceService()
    return _yfinance_service