# astock-agent/backend/app/data/tushare_client.py
# Snapshot: 2026-04-15 23:32:22 +08:00 (265 lines, 9.9 KiB, Python)
#
# NOTE: this file contains Unicode characters that might be confused with
# other characters. If that is intentional, this warning can be ignored.

"""Tushare Pro API 客户端封装
负责所有 Tushare 数据获取,内置缓存和限流。
"""
import logging
import time
from datetime import datetime, timedelta, timezone

import pandas as pd
import tushare as ts

from app.config import settings
from app.data.cache import cache
logger = logging.getLogger(__name__)


class TushareClient:
    """Thin wrapper around the Tushare Pro API.

    Every ``get_*`` data method goes through ``_cached_fetch``, which adds
    caching (``app.data.cache.cache``), request throttling and retries with
    exponential backoff. Failed or empty requests produce an empty
    ``pd.DataFrame`` (or ``[]`` for ``get_trade_dates``) rather than raising;
    a missing token raises ``ValueError`` on the first request.
    """

    # A-share trading hours are defined in China Standard Time (UTC+8, no
    # DST). Pinning "now" to this zone keeps the trade-date and 15:30 cutoff
    # logic correct even when the server clock runs in another zone.
    _MARKET_TZ = timezone(timedelta(hours=8))

    def __init__(self):
        self.token = settings.tushare_token
        # Tushare Pro API handle; created lazily by _ensure_init() on the
        # first request.
        self.pro = None
        # Monotonic time of the previous request; -inf means "no request yet",
        # so the first request is never delayed.
        self._last_request_time: float = float("-inf")
        self._initialized = False

    @classmethod
    def _now(cls) -> datetime:
        """Return the current time in the A-share market time zone (UTC+8)."""
        return datetime.now(cls._MARKET_TZ)

    def _ensure_init(self):
        """Create the Tushare Pro API handle on first use.

        Raises:
            ValueError: if no Tushare token is configured.
        """
        if self._initialized:
            return
        if not self.token:
            raise ValueError("Tushare token 未配置,请设置 ASTOCK_TUSHARE_TOKEN")
        ts.set_token(self.token)
        self.pro = ts.pro_api()
        self._initialized = True

    def _rate_limit(self):
        """Sleep so that requests are at least ``tushare_request_delay`` s apart."""
        # monotonic() is immune to wall-clock jumps; with time.time() a
        # backwards jump makes `elapsed` negative and the sleep unbounded.
        elapsed = time.monotonic() - self._last_request_time
        delay = settings.tushare_request_delay
        if elapsed < delay:
            time.sleep(delay - elapsed)
        self._last_request_time = time.monotonic()

    def _fetch_with_retry(self, func, *args, **kwargs) -> pd.DataFrame:
        """Call ``func(*args, **kwargs)`` with throttling and retries.

        Makes up to ``settings.tushare_max_retry`` attempts, waiting 1 s, 2 s,
        4 s, ... between failed attempts. A ``None`` or empty result is not
        retried.

        Returns:
            The non-empty result, or an empty DataFrame if the result was
            empty or every attempt failed.

        Raises:
            ValueError: from _ensure_init() when no token is configured.
        """
        self._ensure_init()
        max_retry = settings.tushare_max_retry
        for attempt in range(max_retry):
            try:
                self._rate_limit()
                result = func(*args, **kwargs)
                if result is not None and not result.empty:
                    return result
                return pd.DataFrame()
            except Exception as e:  # any API/network failure is retried
                logger.warning("Tushare 请求失败 (attempt %d): %s", attempt + 1, e)
                if attempt < max_retry - 1:
                    time.sleep(2 ** attempt)
                else:
                    logger.error("Tushare 请求最终失败: %s", e)
                    return pd.DataFrame()
        return pd.DataFrame()

    def _cached_fetch(self, cache_key: str, ttl: int, func, *args, **kwargs) -> pd.DataFrame:
        """Return the cached frame for ``cache_key``, fetching it on a miss.

        Only non-empty results are stored (with ``ttl`` passed to
        ``cache.set``), so a failed fetch is retried on the next call. Frames
        with a ``trade_date`` column are sorted ascending before caching.
        """
        cached = cache.get(cache_key)
        if cached is not None:
            return cached
        result = self._fetch_with_retry(func, *args, **kwargs)
        if not result.empty:
            # Tushare returns newest rows first; sort oldest-first for analysis.
            if "trade_date" in result.columns:
                result = result.sort_values("trade_date").reset_index(drop=True)
            cache.set(cache_key, result, ttl)
        return result

    # ── Trading calendar ──
    def get_trade_dates(self, start: str = None, end: str = None) -> list[str]:
        """Return SSE open days between ``start`` and ``end`` (YYYYMMDD), ascending.

        ``start`` defaults to 60 days ago and ``end`` to today (market time).
        Returns ``[]`` if the calendar cannot be fetched.
        """
        now = self._now()
        if not start:
            start = (now - timedelta(days=60)).strftime("%Y%m%d")
        if not end:
            end = now.strftime("%Y%m%d")
        key = f"trade_cal:{start}:{end}"
        df = self._cached_fetch(
            key, settings.cache_ttl_static,
            lambda: self.pro.trade_cal(
                exchange="SSE", start_date=start, end_date=end,
                fields="cal_date,is_open",
            ),
        )
        if df.empty:
            return []
        # Coerce so the filter works whether is_open arrives as int or "1"/"0".
        is_open = pd.to_numeric(df["is_open"], errors="coerce") == 1
        return df.loc[is_open, "cal_date"].sort_values().tolist()

    def get_latest_trade_date(self) -> str:
        """Return the latest trade date whose daily data should be available.

        Daily bars are not available intraday. If today is a trade day and it
        is before 15:30 market time, the previous trade day is returned; from
        15:30 on, today is returned. Falls back to today's date when the
        calendar is unavailable.
        """
        dates = self.get_trade_dates()
        # Read the clock once so the date and the time of day are consistent;
        # separate now() calls can straddle midnight or the 15:30 cutoff.
        now = self._now()
        today = now.strftime("%Y%m%d")
        past_dates = [d for d in dates if d <= today]
        if not past_dates:
            return today
        before_cutoff = (now.hour, now.minute) < (15, 30)
        if past_dates[-1] == today and before_cutoff:
            return past_dates[-2] if len(past_dates) >= 2 else today
        return past_dates[-1]

    # ── Market-wide daily bars (used to count advancers / decliners) ──
    def get_daily_all(self, trade_date: str) -> pd.DataFrame:
        """Return daily bars for every stock on ``trade_date``."""
        key = f"daily_all:{trade_date}"
        return self._cached_fetch(
            key, settings.cache_ttl_daily,
            lambda: self.pro.daily(
                trade_date=trade_date,
                fields="ts_code,trade_date,open,high,low,close,pre_close,change,pct_chg,vol,amount",
            ),
        )

    # ── Limit-up / limit-down list ──
    def get_limit_list(self, trade_date: str) -> pd.DataFrame:
        """Return the limit-up/limit-down list (``limit_list_d``) for ``trade_date``."""
        key = f"limit_list:{trade_date}"
        return self._cached_fetch(
            key, settings.cache_ttl_daily,
            lambda: self.pro.limit_list_d(
                trade_date=trade_date,
                fields="ts_code,trade_date,name,close,pct_chg,fd_amount,first_time,last_time,open_times,up_stat,limit_times,limit",
            ),
        )

    # ── Index daily bars ──
    def get_index_daily(self, ts_code: str = "000001.SH", days: int = 30) -> pd.DataFrame:
        """Return index daily bars covering roughly the last ``days`` days.

        The window is padded by 15 calendar days, presumably to cover
        weekends and holidays; rows are not trimmed to ``days``.
        """
        now = self._now()
        end = now.strftime("%Y%m%d")
        start = (now - timedelta(days=days + 15)).strftime("%Y%m%d")
        # `days` must be part of the key: the window depends on it, and
        # without it a smaller cached window would be served for a larger one.
        key = f"index_daily:{ts_code}:{end}:{days}"
        return self._cached_fetch(
            key, settings.cache_ttl_daily,
            lambda: self.pro.index_daily(
                ts_code=ts_code, start_date=start, end_date=end,
                fields="ts_code,trade_date,close,open,high,low,vol,amount",
            ),
        )

    # ── Sector (THS) data ──
    def get_ths_index_list(self, index_type: str = "N") -> pd.DataFrame:
        """Return the Tonghuashun (THS) sector list; ``index_type="N"`` = concept sectors."""
        key = f"ths_index_list:{index_type}"
        return self._cached_fetch(
            key, settings.cache_ttl_static,
            lambda: self.pro.ths_index(
                exchange="A", type=index_type,
                fields="ts_code,name,count,list_date",
            ),
        )

    def get_ths_daily(self, ts_code: str, days: int = 5) -> pd.DataFrame:
        """Return THS sector daily bars over the last ``days`` + 10 calendar days."""
        now = self._now()
        end = now.strftime("%Y%m%d")
        start = (now - timedelta(days=days + 10)).strftime("%Y%m%d")
        key = f"ths_daily:{ts_code}:{end}:{days}"
        return self._cached_fetch(
            key, settings.cache_ttl_sector,
            lambda: self.pro.ths_daily(
                ts_code=ts_code, start_date=start, end_date=end,
                fields="ts_code,trade_date,close,open,high,low,pct_change,vol,turnover_rate",
            ),
        )

    def get_ths_members(self, ts_code: str) -> pd.DataFrame:
        """Return the constituent stocks of a THS sector.

        Columns: ts_code (sector code), con_code (constituent stock code),
        con_name (constituent stock name).
        """
        key = f"ths_member:{ts_code}"
        return self._cached_fetch(
            key, settings.cache_ttl_static,
            lambda: self.pro.ths_member(
                ts_code=ts_code,
                fields="ts_code,con_code,con_name",
            ),
        )

    # ── Sector money flow ──
    def get_sector_moneyflow(self, trade_date: str) -> pd.DataFrame:
        """Return THS industry money flow (``moneyflow_ind_ths``) for ``trade_date``."""
        key = f"sector_mf:{trade_date}"
        # NOTE(review): confirm these field names against the moneyflow_ind_ths
        # docs; the buy_/sell_*_amount columns match the per-stock `moneyflow`
        # API used below and may not be provided by this endpoint.
        return self._cached_fetch(
            key, settings.cache_ttl_sector,
            lambda: self.pro.moneyflow_ind_ths(
                trade_date=trade_date,
                fields="ts_code,trade_date,industry,net_amount,buy_elg_amount,sell_elg_amount,buy_lg_amount,sell_lg_amount",
            ),
        )

    # ── Per-stock money flow ──
    def get_stock_moneyflow(self, ts_code: str, days: int = 5) -> pd.DataFrame:
        """Return money flow for one stock, ending at the latest data-bearing trade date."""
        end = self.get_latest_trade_date()
        start_dt = datetime.strptime(end, "%Y%m%d") - timedelta(days=days + 10)
        start = start_dt.strftime("%Y%m%d")
        key = f"stock_mf:{ts_code}:{end}:{days}"
        return self._cached_fetch(
            key, settings.cache_ttl_daily,
            lambda: self.pro.moneyflow(
                ts_code=ts_code, start_date=start, end_date=end,
                fields="ts_code,trade_date,buy_elg_amount,sell_elg_amount,buy_lg_amount,sell_lg_amount,buy_md_amount,sell_md_amount,buy_sm_amount,sell_sm_amount,net_mf_amount",
            ),
        )

    def get_moneyflow_batch(self, trade_date: str) -> pd.DataFrame:
        """Return money flow for the whole market on ``trade_date`` in one request."""
        key = f"mf_batch:{trade_date}"
        return self._cached_fetch(
            key, settings.cache_ttl_daily,
            lambda: self.pro.moneyflow(
                trade_date=trade_date,
                fields="ts_code,trade_date,buy_elg_amount,sell_elg_amount,buy_lg_amount,sell_lg_amount,buy_md_amount,sell_md_amount,buy_sm_amount,sell_sm_amount,net_mf_amount",
            ),
        )

    # ── Daily basic indicators (PE / PB / turnover / market cap) ──
    def get_daily_basic(self, trade_date: str) -> pd.DataFrame:
        """Return daily basic indicators for every stock on ``trade_date``."""
        key = f"daily_basic:{trade_date}"
        return self._cached_fetch(
            key, settings.cache_ttl_daily,
            lambda: self.pro.daily_basic(
                trade_date=trade_date,
                fields="ts_code,trade_date,turnover_rate,volume_ratio,pe,pb,total_mv,circ_mv",
            ),
        )

    # ── Per-stock daily history (K-line) ──
    def get_stock_daily(self, ts_code: str, days: int = 120) -> pd.DataFrame:
        """Return daily bars for one stock over the last ``days`` + 30 calendar days."""
        now = self._now()
        end = now.strftime("%Y%m%d")
        start = (now - timedelta(days=days + 30)).strftime("%Y%m%d")
        key = f"stock_daily:{ts_code}:{end}:{days}"
        return self._cached_fetch(
            key, settings.cache_ttl_daily,
            lambda: self.pro.daily(
                ts_code=ts_code, start_date=start, end_date=end,
                fields="ts_code,trade_date,open,high,low,close,pre_close,change,pct_chg,vol,amount",
            ),
        )

    # ── Stock basic info ──
    def get_stock_basic(self) -> pd.DataFrame:
        """Return basic info for all currently listed (``list_status="L"``) stocks."""
        key = "stock_basic"
        return self._cached_fetch(
            key, settings.cache_ttl_static,
            lambda: self.pro.stock_basic(
                exchange="", list_status="L",
                fields="ts_code,symbol,name,area,industry,market,list_date",
            ),
        )
# Shared module-level client. Construction only reads settings.tushare_token;
# the Tushare API handle is created lazily on the first request.
tushare_client = TushareClient()