stock-ai-agent/backend/app/services/us_stock_service.py
2026-02-27 11:27:27 +08:00

322 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
美股数据服务 - 使用 yfinance 获取美股和港股数据
US stock data service: fetches US and Hong Kong stock data via yfinance.
"""
from typing import Optional, Dict, Any, List
import yfinance as yf
from datetime import datetime, timedelta
import pandas as pd
from app.utils.logger import logger
class USStockService:
    """US stock data service backed by yfinance (also handles Hong Kong tickers)."""

    def __init__(self):
        """Initialize the service."""
        # Simple in-memory cache; currently not used by any method of this class.
        self.cache = {}

    @staticmethod
    def _normalize_hk_symbol(symbol: str) -> str:
        """
        Normalize a Hong Kong ticker to the format yfinance expects.

        Numeric HK codes have their leading zeros stripped and are then
        left-padded to at least 4 digits:

        - 700.HK -> 0700.HK, 5.HK -> 0005.HK
        - 09618.HK -> 9618.HK
        - 00700.HK -> 0700.HK (zero-padded 5-digit HKEX-style codes)
        - 12345.HK -> 12345.HK

        The ``.HK`` suffix is matched case-insensitively and surrounding
        whitespace is ignored for HK tickers, so ``" 700.hk"`` -> ``"0700.HK"``.
        Any symbol that does not end in ``.HK`` is returned unchanged.

        Args:
            symbol: Raw ticker symbol.

        Returns:
            The normalized ticker.
        """
        candidate = symbol.strip()
        if not candidate.upper().endswith('.HK'):
            return symbol

        code_part = candidate[:-3]  # drop the ".HK" suffix
        if code_part.isdigit():
            # Strip before padding so a zero-padded 5-digit code such as 00700
            # maps to yfinance's 4-digit form 0700 rather than to 700.
            normalized_code = code_part.lstrip('0').zfill(4)
        else:
            normalized_code = code_part
        return normalized_code + '.HK'

    @staticmethod
    def _first_not_none(mapping: Dict[str, Any], keys: List[str], default: Any) -> Any:
        """Return the value of the first key in ``keys`` that maps to a non-None value, else ``default``."""
        for key in keys:
            value = mapping.get(key)
            if value is not None:
                return value
        return default

    @staticmethod
    def _to_float(value: Any) -> Optional[float]:
        """Convert a scalar to ``float``, mapping None/NaN to None (0.0 is kept as 0.0)."""
        if value is None or pd.isna(value):
            return None
        return float(value)

    @staticmethod
    def _frame_to_dict(frame: Optional[pd.DataFrame]) -> Dict[Any, Any]:
        """Return ``frame.to_dict()``, or ``{}`` when the frame is None or empty."""
        if frame is None or frame.empty:
            return {}
        return frame.to_dict()

    def get_stock_info(self, symbol: str) -> Optional[Dict[str, Any]]:
        """
        Fetch basic information for a stock.

        Args:
            symbol: Ticker symbol, e.g. AAPL, TSLA or 0700.HK.

        Returns:
            Dict of key fields read from ``yf.Ticker(...).info``, or None when
            the info payload has no ``symbol`` key or any error occurs. The
            returned ``symbol`` field is the caller's original input, not the
            normalized ticker.
        """
        try:
            normalized_symbol = self._normalize_hk_symbol(symbol)
            info = yf.Ticker(normalized_symbol).info

            # An info payload without a "symbol" key is treated as "not found".
            if not info or 'symbol' not in info:
                logger.warning(f"未找到股票: {symbol}")
                return None

            result = {
                "symbol": symbol,
                # `or` instead of a .get() default: the key may be present with a None value.
                "name": info.get("longName") or info.get("shortName") or symbol,
                "sector": info.get("sector", "未知"),
                "industry": info.get("industry", "未知"),
                "market_cap": info.get("marketCap", 0),
                "current_price": self._first_not_none(
                    info, ["currentPrice", "regularMarketPrice"], 0
                ),
                "previous_close": info.get("previousClose", 0),
                "open": info.get("open", 0),
                "day_high": info.get("dayHigh", 0),
                "day_low": info.get("dayLow", 0),
                "volume": info.get("volume", 0),
                "avg_volume": info.get("averageVolume", 0),
                "pe_ratio": info.get("trailingPE", 0),
                "forward_pe": info.get("forwardPE", 0),
                "pb_ratio": info.get("priceToBook", 0),
                "dividend_yield": info.get("dividendYield", 0),
                "52_week_high": info.get("fiftyTwoWeekHigh", 0),
                "52_week_low": info.get("fiftyTwoWeekLow", 0),
                "50_day_avg": info.get("fiftyDayAverage", 0),
                "200_day_avg": info.get("twoHundredDayAverage", 0),
                "beta": info.get("beta", 0),
                "eps": info.get("trailingEps", 0),
                "description": info.get("longBusinessSummary", ""),
            }
            logger.info(f"获取美股信息成功: {symbol}")
            return result
        except Exception as e:
            logger.error(f"获取美股信息失败 {symbol}: {e}")
            return None

    def get_historical_data(
        self,
        symbol: str,
        period: str = "1mo",
        interval: str = "1d"
    ) -> Optional[pd.DataFrame]:
        """
        Fetch historical OHLCV bars for a stock.

        Args:
            symbol: Ticker symbol.
            period: Lookback period passed to yfinance
                (1d, 5d, 1mo, 3mo, 6mo, 1y, 2y, 5y, 10y, ytd, max).
            interval: Bar interval passed to yfinance
                (1m, 2m, 5m, 15m, 30m, 60m, 90m, 1h, 1d, 5d, 1wk, 1mo, 3mo).

        Returns:
            The DataFrame returned by ``Ticker.history``, or None when it is
            empty or any error occurs.
        """
        try:
            normalized_symbol = self._normalize_hk_symbol(symbol)
            hist = yf.Ticker(normalized_symbol).history(period=period, interval=interval)
            if hist.empty:
                logger.warning(f"未找到历史数据: {symbol}")
                return None
            logger.info(f"获取美股历史数据成功: {symbol}, 周期: {period}")
            return hist
        except Exception as e:
            logger.error(f"获取美股历史数据失败 {symbol}: {e}")
            return None

    def get_financial_data(self, symbol: str) -> Optional[Dict[str, Any]]:
        """
        Fetch financial statements and key financial metrics for a stock.

        Args:
            symbol: Ticker symbol.

        Returns:
            Dict with ``income_statement``, ``balance_sheet`` and ``cash_flow``
            (each ``DataFrame.to_dict()`` output, or ``{}`` when unavailable)
            plus ``key_metrics`` read from ``Ticker.info``; None on any error.
        """
        try:
            normalized_symbol = self._normalize_hk_symbol(symbol)
            stock = yf.Ticker(normalized_symbol)

            result = {
                "symbol": symbol,
                "income_statement": self._frame_to_dict(stock.financials),
                "balance_sheet": self._frame_to_dict(stock.balance_sheet),
                "cash_flow": self._frame_to_dict(stock.cashflow),
            }

            # Guard against a falsy/None info payload so .get() below is safe.
            info = stock.info or {}
            result["key_metrics"] = {
                "revenue": info.get("totalRevenue", 0),
                "gross_profit": info.get("grossProfits", 0),
                "ebitda": info.get("ebitda", 0),
                "net_income": info.get("netIncomeToCommon", 0),
                "total_assets": info.get("totalAssets", 0),
                "total_debt": info.get("totalDebt", 0),
                "total_cash": info.get("totalCash", 0),
                "operating_cash_flow": info.get("operatingCashflow", 0),
                "free_cash_flow": info.get("freeCashflow", 0),
                "roe": info.get("returnOnEquity", 0),
                "roa": info.get("returnOnAssets", 0),
                "profit_margin": info.get("profitMargins", 0),
                "operating_margin": info.get("operatingMargins", 0),
            }
            logger.info(f"获取美股财务数据成功: {symbol}")
            return result
        except Exception as e:
            logger.error(f"获取美股财务数据失败 {symbol}: {e}")
            return None

    def calculate_technical_indicators(self, hist: pd.DataFrame) -> Dict[str, Any]:
        """
        Compute moving averages, RSI(14), MACD(12, 26, 9) and Bollinger Bands(20, 2).

        Args:
            hist: Historical bars; must contain a ``Close`` column.

        Returns:
            Dict of the latest indicator values (floats, or None where the value
            is undefined). Returns ``{}`` when ``hist`` is None, empty, has fewer
            than 20 rows, or any error occurs.
        """
        try:
            if hist is None or hist.empty or len(hist) < 20:
                return {}

            close = hist['Close']
            length = len(close)

            # Moving averages. length >= 20 is guaranteed above; only MA60 needs a check.
            ma5 = close.rolling(window=5).mean().iloc[-1]
            ma10 = close.rolling(window=10).mean().iloc[-1]
            ma20 = close.rolling(window=20).mean().iloc[-1]
            ma60 = close.rolling(window=60).mean().iloc[-1] if length >= 60 else None

            # RSI with Wilder's smoothing, expressed as an EMA with alpha = 1/14.
            delta = close.diff()
            gain = delta.where(delta > 0, 0)
            loss = -delta.where(delta < 0, 0)
            avg_gain = gain.ewm(alpha=1/14, adjust=False).mean()
            avg_loss = loss.ewm(alpha=1/14, adjust=False).mean()
            rs = avg_gain / avg_loss
            rsi = 100 - (100 / (1 + rs))
            rsi_value = rsi.iloc[-1]

            # MACD: only reported once the slow (26-period) EMA has enough data.
            exp1 = close.ewm(span=12, adjust=False).mean()
            exp2 = close.ewm(span=26, adjust=False).mean()
            macd = exp1 - exp2
            signal = macd.ewm(span=9, adjust=False).mean()
            macd_value = macd.iloc[-1] if length >= 26 else None
            signal_value = signal.iloc[-1] if length >= 26 else None

            # Bollinger Bands: 20-period mean +/- 2 rolling standard deviations.
            bb_middle = close.rolling(window=20).mean()
            bb_std = close.rolling(window=20).std()
            bb_upper = bb_middle + (bb_std * 2)
            bb_lower = bb_middle - (bb_std * 2)

            # _to_float uses explicit None/NaN checks, so a genuine 0.0 (e.g. RSI
            # of a steadily falling series) is reported instead of becoming None.
            return {
                "ma5": self._to_float(ma5),
                "ma10": self._to_float(ma10),
                "ma20": self._to_float(ma20),
                "ma60": self._to_float(ma60),
                "rsi": self._to_float(rsi_value),
                "macd": self._to_float(macd_value),
                "macd_signal": self._to_float(signal_value),
                "bb_upper": self._to_float(bb_upper.iloc[-1]),
                "bb_middle": self._to_float(bb_middle.iloc[-1]),
                "bb_lower": self._to_float(bb_lower.iloc[-1]),
            }
        except Exception as e:
            logger.error(f"计算技术指标失败: {e}")
            return {}

    def get_comprehensive_analysis(self, symbol: str) -> Optional[Dict[str, Any]]:
        """
        Build a combined snapshot: basic info, latest price change and technical indicators.

        Args:
            symbol: Ticker symbol.

        Returns:
            None when basic info cannot be fetched; ``{"success": False, "error": ...}``
            when there is no usable price history or any error occurs; otherwise
            a dict with ``"success": True`` and the analysis fields.
        """
        try:
            info = self.get_stock_info(symbol)
            if not info:
                return None

            hist = self.get_historical_data(symbol, period="6mo", interval="1d")
            if hist is not None:
                # Drop bars without a close so the latest price and the
                # indicators are computed from real numbers, not NaN.
                hist = hist.dropna(subset=['Close'])
            if hist is None or hist.empty:
                return {
                    "success": False,
                    "error": "无法获取历史数据"
                }

            technical = self.calculate_technical_indicators(hist)

            latest = hist.iloc[-1]
            prev = hist.iloc[-2] if len(hist) > 1 else latest

            latest_close = float(latest['Close'])
            prev_close = float(prev['Close'])
            change = latest_close - prev_close
            change_pct = (change / prev_close * 100) if prev_close != 0 else 0

            # Volume can be NaN on some bars; int(NaN) would raise.
            raw_volume = latest['Volume']
            volume = int(raw_volume) if not pd.isna(raw_volume) else 0

            description = info["description"]
            result = {
                "success": True,
                "symbol": symbol,
                "name": info["name"],
                "sector": info["sector"],
                "industry": info["industry"],
                "current_price": latest_close,
                "change": float(change),
                "change_percent": float(change_pct),
                "volume": volume,
                "market_cap": info["market_cap"],
                "pe_ratio": info["pe_ratio"],
                "pb_ratio": info["pb_ratio"],
                "dividend_yield": info["dividend_yield"],
                "52_week_high": info["52_week_high"],
                "52_week_low": info["52_week_low"],
                "technical_indicators": technical,
                "description": description[:500] if description else "",
            }
            logger.info(f"获取美股综合分析成功: {symbol}")
            return result
        except Exception as e:
            logger.error(f"获取美股综合分析失败 {symbol}: {e}")
            return {
                "success": False,
                "error": str(e)
            }
# Shared module-level service instance.
us_stock_service: USStockService = USStockService()