stock-ai-agent/backend/app/services/yfinance_service.py
2026-03-13 22:16:21 +08:00

439 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

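"""
YFinance service - US and Hong Kong stock data retrieval.

Supports fetching real-time quotes and historical candlestick (K-line) data.
Fallback data source: Stooq.
"""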
import pandas as pd
from typing import Dict, List, Optional
from datetime import datetime, timedelta
from app.utils.logger import logger
import time
class YFinanceService:
    """Market-data service backed by yfinance, with Stooq as a fallback.

    Quotes are returned as plain dictionaries and historical bars as
    DataFrames with lower-case OHLCV columns, a ``time`` column and a set of
    technical indicators. Historical frames are cached in memory for
    ``_cache_ttl`` seconds.
    """

    # (suffix, days-per-unit) pairs used to translate yfinance-style period
    # strings such as '5d', '2wk', '6mo' or '2y' into a number of days.
    _PERIOD_UNITS = (('mo', 30), ('wk', 7), ('d', 1), ('y', 365))

    # Number of days assumed when a period string cannot be parsed (6 months).
    _DEFAULT_PERIOD_DAYS = 180

    def __init__(self):
        """Probe the optional data-source libraries and create the cache.

        Raises:
            RuntimeError: If neither ``yfinance`` nor ``pandas_datareader``
                can be imported.
        """
        # Primary source: yfinance (imported lazily so it stays optional).
        try:
            import yfinance as yf
            self.yf = yf
            self._yf_available = True
            logger.info("YFinance 服务初始化成功")
        except ImportError:
            logger.warning("yfinance 未安装")
            self._yf_available = False

        # Fallback source: Stooq through pandas_datareader.
        try:
            import pandas_datareader.data as web
            self.web = web
            self._stooq_available = True
            logger.info("Stooq 备用数据源初始化成功")
        except ImportError:
            logger.warning("pandas_datareader 未安装Stooq 备用不可用")
            self._stooq_available = False

        if not self._yf_available and not self._stooq_available:
            # RuntimeError is still caught by callers using `except Exception`.
            raise RuntimeError("没有可用的数据源,请安装 yfinance 或 pandas_datareader")

        self._cache = {}  # cache_key -> formatted DataFrame
        self._cache_time = {}  # cache_key -> datetime the entry was stored
        self._cache_ttl = 300  # cache lifetime in seconds

    def _normalize_hk_symbol(self, symbol: str) -> str:
        """Normalize a Hong Kong ticker to the form used for yfinance lookups.

        - Numeric codes of 4 digits or fewer are left-padded with zeros to
          4 digits (``700.HK`` -> ``0700.HK``, ``5.HK`` -> ``0005.HK``).
        - Numeric codes of 5 digits or more have leading zeros removed
          (``09618.HK`` -> ``9618.HK``).
        - Symbols that do not end in ``.HK`` and non-numeric codes are
          returned unchanged.
        """
        if not symbol.endswith('.HK'):
            return symbol

        suffix = '.HK'
        code_part = symbol[:-len(suffix)]

        if not code_part.isdigit():
            return code_part + suffix

        if len(code_part) <= 4:
            normalized_code = code_part.zfill(4)
        else:
            # `or '0'` guards against a code made only of zeros.
            normalized_code = code_part.lstrip('0') or '0'
        return normalized_code + suffix

    def get_ticker(self, symbol: str) -> Optional[Dict]:
        """Return the latest quote for ``symbol``.

        yfinance is tried first; if it yields nothing, Stooq is tried.

        Args:
            symbol: Ticker such as ``'AAPL'`` or ``'0700.HK'``.

        Returns:
            The quote dictionary, or ``None`` when no source returned data.
        """
        if self._yf_available:
            result = self._get_yf_ticker(symbol)
            if result:
                return result
            logger.info(f"YFinance 获取失败,尝试使用 Stooq 备用数据源 ({symbol})")

        if self._stooq_available:
            result = self._get_stooq_ticker(symbol)
            if result:
                return result

        return None

    def _build_quote(
        self,
        symbol: str,
        bars: pd.DataFrame,
        source: Optional[str] = None
    ) -> Dict:
        """Build the quote dictionary from a frame of OHLCV bars.

        The frame is put into ascending index order first, so the result does
        not depend on whether the provider returns oldest-first or
        newest-first rows.

        Args:
            symbol: Symbol reported back in the ``'symbol'`` field.
            bars: Non-empty frame with ``Open``, ``High``, ``Low``, ``Close``
                and ``Volume`` columns.
            source: Optional data-source tag; added as ``'source'`` when given.

        Returns:
            Quote dictionary. ``priceChange`` / ``priceChangePercent`` are
            measured from the latest bar's open to its close. ``prevClose`` is
            the last close before the latest bar's calendar date (or the
            preceding bar for a non-datetime index); when no earlier bar
            exists it falls back to the latest close.
        """
        if not bars.index.is_monotonic_increasing:
            bars = bars.sort_index()

        latest = bars.iloc[-1]

        # Closes that precede the latest bar, used for prevClose.
        if isinstance(bars.index, pd.DatetimeIndex):
            days = bars.index.normalize()
            earlier_closes = bars.loc[days < days[-1], 'Close']
        else:
            earlier_closes = bars['Close'].iloc[:-1]
        earlier_closes = earlier_closes.dropna()

        if earlier_closes.empty:
            prev_close = float(latest['Close'])
        else:
            prev_close = float(earlier_closes.iloc[-1])

        quote = {
            'symbol': symbol,
            'lastPrice': float(latest['Close']),
            'priceChange': float(latest['Close'] - latest['Open']),
            'priceChangePercent': float((latest['Close'] - latest['Open']) / latest['Open'] * 100) if latest['Open'] > 0 else 0,
            'volume': int(latest['Volume']),
            'high': float(latest['High']),
            'low': float(latest['Low']),
            'open': float(latest['Open']),
            'prevClose': prev_close,
            'timestamp': datetime.now().isoformat()
        }
        if source is not None:
            quote['source'] = source  # marks which provider produced the quote
        return quote

    def _get_yf_ticker(self, symbol: str) -> Optional[Dict]:
        """Fetch the latest quote from yfinance.

        Requests two days of hourly bars and reports the most recent one.
        Any exception is logged and turned into ``None``.
        """
        try:
            normalized_symbol = self._normalize_hk_symbol(symbol)
            ticker = self.yf.Ticker(normalized_symbol)
            hist = ticker.history(period="2d", interval="1h")
            if hist.empty:
                logger.warning(f"YFinance 无法获取 {symbol} 的数据")
                return None
            return self._build_quote(symbol, hist)
        except Exception as e:
            error_msg = str(e)
            # Rate limiting is expected in normal operation, so it is reported
            # at warning level; everything else stays at debug level.
            if "429" in error_msg or "Too Many Requests" in error_msg:
                logger.warning(f"YFinance API 限流 ({symbol})")
            else:
                logger.debug(f"YFinance 获取失败 ({symbol}): {error_msg}")
            return None

    def _get_stooq_ticker(self, symbol: str) -> Optional[Dict]:
        """Fetch the latest quote from Stooq (fallback source).

        Requests the last five calendar days of data and reports the most
        recent bar. The result carries ``'source': 'stooq'``. Any exception is
        logged and turned into ``None``.
        """
        try:
            stooq_symbol = self._convert_to_stooq_symbol(symbol)
            start_date = (datetime.now() - timedelta(days=5)).strftime('%Y-%m-%d')
            df = self.web.DataReader(stooq_symbol, 'stooq', start=start_date)
            if df.empty:
                logger.warning(f"Stooq 无法获取 {symbol} 的数据")
                return None
            # _build_quote sorts by date, so the newest row is used even
            # though Stooq rows arrive newest-first.
            return self._build_quote(symbol, df, source='stooq')
        except Exception as e:
            logger.error(f"Stooq 获取 {symbol} 行情失败: {e}")
            return None

    def _convert_to_stooq_symbol(self, symbol: str) -> str:
        """Convert a ticker to the form passed to the Stooq reader.

        Symbols that already contain a dot (including ``.HK``) are returned
        unchanged; bare symbols are treated as US tickers and get a ``.US``
        suffix (``AAPL`` -> ``AAPL.US``).
        """
        if '.' in symbol:
            return symbol
        return f"{symbol}.US"

    def get_multi_timeframe_data(
        self,
        symbol: str,
        timeframes: Optional[Dict[str, tuple]] = None
    ) -> Dict[str, pd.DataFrame]:
        """Fetch historical bars for several timeframes.

        Args:
            symbol: Ticker symbol.
            timeframes: Mapping of result key -> ``(interval, period)``, e.g.
                ``{'1d': ('1d', '3mo')}``. Defaults to weekly / daily / hourly
                views over 2 years / 6 months / 1 month.

        Returns:
            Mapping of result key -> DataFrame. Timeframes that failed or came
            back empty are omitted.
        """
        if timeframes is None:
            timeframes = {
                '1w': ('1wk', '2y'),   # weekly bars, 2 years - long-term trend
                '1d': ('1d', '6mo'),   # daily bars, 6 months - mid-term trend
                '1h': ('1h', '1mo'),   # hourly bars, 1 month - short-term trend
            }

        result = {}
        for tf_name, (interval, period) in timeframes.items():
            try:
                df = self._get_cached_data(symbol, interval, period)
                if df is not None and not df.empty:
                    result[tf_name] = df
                    logger.debug(f"获取 {symbol} {tf_name} 数据成功: {len(df)}")
                else:
                    logger.warning(f"获取 {symbol} {tf_name} 数据失败或为空")
            except Exception as e:
                # One failing timeframe must not prevent the others.
                logger.error(f"获取 {symbol} {tf_name} 数据出错: {e}")
        return result

    def _get_cached_data(
        self,
        symbol: str,
        interval: str,
        period: str
    ) -> Optional[pd.DataFrame]:
        """Return historical bars, served from the cache when still fresh.

        On a cache miss yfinance is tried first and Stooq second; whichever
        succeeds stores its result in the cache.

        Returns:
            The formatted DataFrame, or ``None`` when no source returned data.
        """
        # The key uses the normalized symbol so equivalent spellings of an
        # HK ticker share a cache entry.
        normalized_symbol = self._normalize_hk_symbol(symbol)
        cache_key = f"{normalized_symbol}_{interval}_{period}"
        now = datetime.now()

        if cache_key in self._cache:
            cache_time = self._cache_time.get(cache_key)
            if cache_time and (now - cache_time).total_seconds() < self._cache_ttl:
                logger.debug(f"使用缓存数据: {cache_key}")
                return self._cache[cache_key]

        if self._yf_available:
            df = self._get_yf_data(symbol, interval, period, cache_key, now)
            if df is not None:
                return df
            logger.info(f"YFinance 获取历史数据失败,尝试 Stooq ({symbol})")

        if self._stooq_available:
            df = self._get_stooq_data(symbol, interval, period, cache_key, now)
            if df is not None:
                logger.info(f"✓ 使用 Stooq 数据源 ({symbol})")
                return df

        return None

    def _get_yf_data(
        self,
        symbol: str,
        interval: str,
        period: str,
        cache_key: str,
        now: datetime
    ) -> Optional[pd.DataFrame]:
        """Fetch, format and cache historical bars from yfinance.

        Returns ``None`` when the response is empty or an exception occurs.
        """
        try:
            normalized_symbol = self._normalize_hk_symbol(symbol)
            ticker = self.yf.Ticker(normalized_symbol)
            df = ticker.history(period=period, interval=interval)
            if df.empty:
                return None

            df = self._format_dataframe(df)
            self._cache[cache_key] = df
            self._cache_time[cache_key] = now
            return df
        except Exception as e:
            logger.debug(f"YFinance 获取历史数据失败: {e}")
            return None

    def _get_stooq_data(
        self,
        symbol: str,
        interval: str,
        period: str,
        cache_key: str,
        now: datetime
    ) -> Optional[pd.DataFrame]:
        """Fetch, format and cache historical bars from Stooq (fallback).

        Returns ``None`` when the response is empty or an exception occurs.
        """
        # NOTE(review): `interval` is not forwarded to the Stooq reader, so the
        # bar size is whatever Stooq serves by default (presumably daily) even
        # for '1h' / '1wk' requests - confirm callers tolerate this.
        try:
            stooq_symbol = self._convert_to_stooq_symbol(symbol)
            period_days = self._period_to_days(period)
            start_date = (datetime.now() - timedelta(days=period_days)).strftime('%Y-%m-%d')

            df = self.web.DataReader(stooq_symbol, 'stooq', start=start_date)
            if df.empty:
                return None

            # Stooq rows arrive newest-first; sort ascending so rolling
            # indicators run in chronological order regardless of the
            # provider's ordering.
            df = df.sort_index()

            df = self._format_dataframe(df)
            self._cache[cache_key] = df
            self._cache_time[cache_key] = now
            return df
        except Exception as e:
            logger.debug(f"Stooq 获取历史数据失败: {e}")
            return None

    def _period_to_days(self, period: str) -> int:
        """Convert a yfinance-style period string to a number of days.

        Understands ``<n>d``, ``<n>wk``, ``<n>mo`` (30 days each) and ``<n>y``
        (365 days each). Anything else - e.g. ``'ytd'`` or ``'max'`` - maps to
        the 180-day default.
        """
        for suffix, unit_days in self._PERIOD_UNITS:
            if not period.endswith(suffix):
                continue
            count = period[:-len(suffix)]
            if count.isdigit() and int(count) > 0:
                return int(count) * unit_days
        return self._DEFAULT_PERIOD_DAYS

    def _format_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
        """Reshape a provider frame into the layout used by the rest of the app.

        Input layout (yfinance / Stooq): capitalised columns ``Open``,
        ``High``, ``Low``, ``Close``, ``Volume`` and a date(time) index.

        Output layout: a ``time`` column followed by lower-case ``open``,
        ``high``, ``low``, ``close``, ``volume`` (those that exist), plus the
        indicator columns added by ``_add_indicators``.
        """
        # Move the index into a column BEFORE lower-casing, so the index label
        # ('Date' / 'Datetime') is lower-cased too and can be found below.
        df = df.reset_index()
        df.columns = [str(col).lower() for col in df.columns]

        if 'time' not in df.columns:
            if 'date' in df.columns:
                df = df.rename(columns={'date': 'time'})
            elif 'datetime' in df.columns:
                df = df.rename(columns={'datetime': 'time'})
            elif 'index' in df.columns and pd.api.types.is_datetime64_any_dtype(df['index']):
                # An unnamed datetime index is emitted as 'index' by reset_index.
                df = df.rename(columns={'index': 'time'})

        # Drop everything else (dividends, splits, ...).
        cols_to_keep = ['time', 'open', 'high', 'low', 'close', 'volume']
        df = df[[col for col in cols_to_keep if col in df.columns]]

        return self._add_indicators(df)

    def _add_indicators(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of ``df`` with technical-indicator columns appended.

        Args:
            df: Frame with ``close``, ``high`` and ``low`` columns in
                chronological order.

        Returns:
            A new DataFrame with ``ma5/10/20/50``, ``ema20/50/200``, ``rsi``,
            ``macd``, ``macd_signal``, ``macd_hist``, ``atr``, ``k``, ``d``
            and ``j`` columns added.
        """
        df = df.copy()

        # Simple moving averages
        df['ma5'] = df['close'].rolling(window=5).mean()
        df['ma10'] = df['close'].rolling(window=10).mean()
        df['ma20'] = df['close'].rolling(window=20).mean()
        df['ma50'] = df['close'].rolling(window=50).mean()

        # Exponential moving averages, used for trend detection
        df['ema20'] = df['close'].ewm(span=20, adjust=False).mean()
        df['ema50'] = df['close'].ewm(span=50, adjust=False).mean()
        df['ema200'] = df['close'].ewm(span=200, adjust=False).mean()

        # RSI with Wilder's smoothing (EMA with alpha = 1/14) rather than a
        # simple average of gains and losses
        delta = df['close'].diff()
        gain = delta.where(delta > 0, 0)
        loss = -delta.where(delta < 0, 0)
        avg_gain = gain.ewm(alpha=1/14, adjust=False).mean()
        avg_loss = loss.ewm(alpha=1/14, adjust=False).mean()
        rs = avg_gain / avg_loss
        df['rsi'] = 100 - (100 / (1 + rs))

        # MACD (12 / 26 / 9)
        ema_fast = df['close'].ewm(span=12, adjust=False).mean()
        ema_slow = df['close'].ewm(span=26, adjust=False).mean()
        df['macd'] = ema_fast - ema_slow
        df['macd_signal'] = df['macd'].ewm(span=9, adjust=False).mean()
        df['macd_hist'] = df['macd'] - df['macd_signal']

        # ATR: 14-bar simple mean of the true range
        high_low = df['high'] - df['low']
        high_close = abs(df['high'] - df['close'].shift())
        low_close = abs(df['low'] - df['close'].shift())
        true_range = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
        df['atr'] = true_range.rolling(window=14).mean()

        # KDJ over a 9-bar window
        low_min = df['low'].rolling(window=9).min()
        high_max = df['high'].rolling(window=9).max()
        rsv = (df['close'] - low_min) / (high_max - low_min) * 100
        df['k'] = rsv.ewm(com=2, adjust=False).mean()
        df['d'] = df['k'].ewm(com=2, adjust=False).mean()
        df['j'] = 3 * df['k'] - 2 * df['d']

        return df

    def clear_cache(self):
        """Drop every cached DataFrame and its timestamp."""
        self._cache.clear()
        self._cache_time.clear()
        logger.info("YFinance 缓存已清空")
# Process-wide singleton, created lazily on first access
_yfinance_service: Optional[YFinanceService] = None


def get_yfinance_service() -> YFinanceService:
    """Return the shared YFinanceService, constructing it on first use."""
    global _yfinance_service
    if _yfinance_service is not None:
        return _yfinance_service
    _yfinance_service = YFinanceService()
    return _yfinance_service