270 lines
10 KiB
Python
270 lines
10 KiB
Python
"""Tushare Pro API 客户端封装
|
||
|
||
负责所有 Tushare 数据获取,内置缓存和限流。
|
||
"""
|
||
|
||
import time
|
||
import logging
|
||
import tushare as ts
|
||
import pandas as pd
|
||
from datetime import datetime, timedelta
|
||
|
||
from app.config import settings
|
||
from app.data.cache import cache
|
||
from app.db.error_logger import log_error_background
|
||
|
||
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class TushareClient:
    """Thin wrapper around the Tushare Pro API.

    Every public ``get_*`` method goes through ``_cached_fetch``, which
    consults the shared ``cache`` first and otherwise calls Tushare via
    ``_fetch_with_retry`` (rate-limited, retried with exponential backoff).
    The underlying ``pro`` API handle is created lazily on the first fetch.

    Failed requests never raise to the caller: after the last retry the error
    is logged and an empty ``DataFrame`` is returned. The only exception that
    propagates is ``ValueError`` when no token is configured.
    """

    # Column lists shared by more than one endpoint call.
    _DAILY_FIELDS = "ts_code,trade_date,open,high,low,close,pre_close,change,pct_chg,vol,amount"
    _MONEYFLOW_FIELDS = "ts_code,trade_date,buy_elg_amount,sell_elg_amount,buy_lg_amount,sell_lg_amount,buy_md_amount,sell_md_amount,buy_sm_amount,sell_sm_amount,net_mf_amount"

    def __init__(self):
        """Record the configured token; no Tushare call is made here."""
        self.token = settings.tushare_token
        self.pro: ts.pro_api | None = None
        # Timestamp (time.time()) of the most recent request, for rate limiting.
        self._last_request_time: float = 0
        self._initialized = False

    def _ensure_init(self):
        """Create the ``pro`` API handle on first use.

        Raises:
            ValueError: if no Tushare token is configured.
        """
        if self._initialized:
            return
        if not self.token:
            raise ValueError("Tushare token 未配置,请设置 ASTOCK_TUSHARE_TOKEN")
        # Pass the token straight to pro_api so tushare does not try to
        # write tk.csv into the user's home directory.
        self.pro = ts.pro_api(self.token)
        self._initialized = True

    def _rate_limit(self):
        """Sleep as needed so consecutive requests are at least
        ``settings.tushare_request_delay`` seconds apart."""
        elapsed = time.time() - self._last_request_time
        if elapsed < settings.tushare_request_delay:
            time.sleep(settings.tushare_request_delay - elapsed)
        self._last_request_time = time.time()

    def _fetch_with_retry(self, func, *args, **kwargs) -> pd.DataFrame:
        """Call ``func(*args, **kwargs)`` with rate limiting and retries.

        Args:
            func: callable performing the actual Tushare request.

        Returns:
            The non-empty result of ``func``, or an empty ``DataFrame`` when
            the result is ``None``/empty or every attempt raised.

        Raises:
            ValueError: from ``_ensure_init`` when no token is configured.
        """
        self._ensure_init()
        max_retry = settings.tushare_max_retry
        for attempt in range(max_retry):
            try:
                self._rate_limit()
                result = func(*args, **kwargs)
                # An empty/None answer is a valid outcome, not an error:
                # return immediately instead of retrying.
                if result is not None and not result.empty:
                    return result
                return pd.DataFrame()
            except Exception as e:
                logger.warning(f"Tushare 请求失败 (attempt {attempt + 1}): {e}")
                if attempt < max_retry - 1:
                    # Exponential backoff: 1s, 2s, 4s, ...
                    time.sleep(2 ** attempt)
                    continue
                logger.error(f"Tushare 请求最终失败: {e}")
                log_error_background(
                    "tushare",
                    f"Tushare 请求最终失败: {e}",
                )
                return pd.DataFrame()
        # Reached only when max_retry <= 0 (the loop body never ran).
        return pd.DataFrame()

    def _cached_fetch(self, cache_key: str, ttl: int, func, *args, **kwargs) -> pd.DataFrame:
        """Return the cached frame for ``cache_key`` or fetch and cache it.

        Empty results are returned but never cached, so a later call retries
        the request.

        Args:
            cache_key: key used for ``cache.get`` / ``cache.set``.
            ttl: time-to-live handed to ``cache.set``.
            func: callable performing the actual Tushare request.
        """
        cached = cache.get(cache_key)
        if cached is not None:
            return cached

        result = self._fetch_with_retry(func, *args, **kwargs)
        if result.empty:
            return result

        # Tushare returns rows in descending order (newest first); sort
        # ascending (oldest first) to make downstream analysis easier.
        if "trade_date" in result.columns:
            result = result.sort_values("trade_date").reset_index(drop=True)
        cache.set(cache_key, result, ttl)
        return result

    # ── Trading calendar ──

    def get_trade_dates(self, start: str = None, end: str = None) -> list[str]:
        """Return the open SSE trading days in ``[start, end]``, ascending.

        Args:
            start: first date as ``YYYYMMDD``; defaults to 60 days ago.
            end: last date as ``YYYYMMDD``; defaults to today.

        Returns:
            Sorted list of ``cal_date`` strings whose ``is_open`` is 1, or an
            empty list when no calendar data could be fetched.
        """
        now = datetime.now()
        if not start:
            start = (now - timedelta(days=60)).strftime("%Y%m%d")
        if not end:
            end = now.strftime("%Y%m%d")
        key = f"trade_cal:{start}:{end}"
        df = self._cached_fetch(
            key, settings.cache_ttl_static,
            lambda: self.pro.trade_cal(
                exchange="SSE", start_date=start, end_date=end,
                fields="cal_date,is_open"
            )
        )
        if df.empty:
            return []
        open_days = df[df["is_open"] == 1]
        return open_days["cal_date"].sort_values().tolist()

    def get_latest_trade_date(self) -> str:
        """Return the most recent trading day that should already have data.

        During the session the ``daily`` data is not yet updated, so the
        previous trading day is used. Strategy: if today is a trading day and
        the market has closed (>= 15:30), use today; otherwise use the
        previous trading day.

        Falls back to today's date when the calendar is empty or when today
        is the only known trading day.
        """
        dates = self.get_trade_dates()
        # Take one clock snapshot so date, hour and minute are consistent
        # with each other even across a minute/hour/midnight boundary.
        # NOTE(review): uses naive local time; presumably the host runs in
        # China Standard Time -- confirm.
        now = datetime.now()
        today = now.strftime("%Y%m%d")

        past_dates = [d for d in dates if d <= today]
        if not past_dates:
            return today

        # Today is a trading day but the market has not closed (<15:30):
        # daily data is unavailable, fall back to the previous trading day.
        before_close = (now.hour, now.minute) < (15, 30)
        if past_dates[-1] == today and before_close:
            return past_dates[-2] if len(past_dates) >= 2 else today

        return past_dates[-1]

    # ── Market-wide daily bars (used for advancing/declining counts) ──

    def get_daily_all(self, trade_date: str) -> pd.DataFrame:
        """Fetch daily bars of every stock for ``trade_date`` (``YYYYMMDD``)."""
        key = f"daily_all:{trade_date}"
        return self._cached_fetch(
            key, settings.cache_ttl_daily,
            lambda: self.pro.daily(
                trade_date=trade_date,
                fields=self._DAILY_FIELDS
            )
        )

    # ── Limit-up / limit-down data ──

    def get_limit_list(self, trade_date: str) -> pd.DataFrame:
        """Fetch the ``limit_list_d`` rows for ``trade_date`` (``YYYYMMDD``)."""
        key = f"limit_list:{trade_date}"
        return self._cached_fetch(
            key, settings.cache_ttl_daily,
            lambda: self.pro.limit_list_d(
                trade_date=trade_date,
                fields="ts_code,trade_date,name,close,pct_chg,fd_amount,first_time,last_time,open_times,up_stat,limit_times,limit"
            )
        )

    # ── Index daily bars ──

    def get_index_daily(self, ts_code: str = "000001.SH", days: int = 30) -> pd.DataFrame:
        """Fetch daily bars of index ``ts_code`` for roughly the last ``days`` days.

        The request window is ``days + 15`` calendar days ending today.
        """
        now = datetime.now()
        end = now.strftime("%Y%m%d")
        start = (now - timedelta(days=days + 15)).strftime("%Y%m%d")
        # `days` is part of the key so that requests for different windows on
        # the same day do not return each other's cached frame.
        key = f"index_daily:{ts_code}:{end}:{days}"
        return self._cached_fetch(
            key, settings.cache_ttl_daily,
            lambda: self.pro.index_daily(
                ts_code=ts_code, start_date=start, end_date=end,
                fields="ts_code,trade_date,close,open,high,low,vol,amount"
            )
        )

    # ── Sector data ──

    def get_ths_index_list(self, index_type: str = "N") -> pd.DataFrame:
        """Fetch the THS (Tonghuashun) sector list; ``N`` = concept sectors."""
        key = f"ths_index_list:{index_type}"
        return self._cached_fetch(
            key, settings.cache_ttl_static,
            lambda: self.pro.ths_index(
                exchange="A", type=index_type,
                fields="ts_code,name,count,list_date"
            )
        )

    def get_ths_daily(self, ts_code: str, days: int = 5) -> pd.DataFrame:
        """Fetch daily bars of THS sector ``ts_code``.

        The request window is ``days + 10`` calendar days ending today.
        """
        now = datetime.now()
        end = now.strftime("%Y%m%d")
        start = (now - timedelta(days=days + 10)).strftime("%Y%m%d")
        key = f"ths_daily:{ts_code}:{end}:{days}"
        return self._cached_fetch(
            key, settings.cache_ttl_sector,
            lambda: self.pro.ths_daily(
                ts_code=ts_code, start_date=start, end_date=end,
                fields="ts_code,trade_date,close,open,high,low,pct_change,vol,turnover_rate"
            )
        )

    def get_ths_members(self, ts_code: str) -> pd.DataFrame:
        """Fetch the constituent stocks of sector ``ts_code``.

        Returned columns: ts_code (sector code), con_code (constituent
        stock code), con_name (constituent stock name).
        """
        key = f"ths_member:{ts_code}"
        return self._cached_fetch(
            key, settings.cache_ttl_static,
            lambda: self.pro.ths_member(
                ts_code=ts_code,
                fields="ts_code,con_code,con_name"
            )
        )

    # ── Sector money flow ──

    def get_sector_moneyflow(self, trade_date: str) -> pd.DataFrame:
        """Fetch THS industry money flow for ``trade_date`` (``YYYYMMDD``)."""
        key = f"sector_mf:{trade_date}"
        return self._cached_fetch(
            key, settings.cache_ttl_sector,
            lambda: self.pro.moneyflow_ind_ths(
                trade_date=trade_date,
                fields="ts_code,trade_date,industry,net_amount,buy_elg_amount,sell_elg_amount,buy_lg_amount,sell_lg_amount"
            )
        )

    # ── Per-stock money flow ──

    def get_stock_moneyflow(self, ts_code: str, days: int = 5) -> pd.DataFrame:
        """Fetch money flow of stock ``ts_code``.

        The request window is ``days + 10`` calendar days ending at
        ``get_latest_trade_date()``.
        """
        end = self.get_latest_trade_date()
        start_dt = datetime.strptime(end, "%Y%m%d") - timedelta(days=days + 10)
        start = start_dt.strftime("%Y%m%d")
        key = f"stock_mf:{ts_code}:{end}:{days}"
        return self._cached_fetch(
            key, settings.cache_ttl_daily,
            lambda: self.pro.moneyflow(
                ts_code=ts_code, start_date=start, end_date=end,
                fields=self._MONEYFLOW_FIELDS
            )
        )

    def get_moneyflow_batch(self, trade_date: str) -> pd.DataFrame:
        """Fetch the whole market's money flow for ``trade_date`` in one call."""
        key = f"mf_batch:{trade_date}"
        return self._cached_fetch(
            key, settings.cache_ttl_daily,
            lambda: self.pro.moneyflow(
                trade_date=trade_date,
                fields=self._MONEYFLOW_FIELDS
            )
        )

    # ── Daily basic indicators (PE / PB / turnover rate / market cap) ──

    def get_daily_basic(self, trade_date: str) -> pd.DataFrame:
        """Fetch ``daily_basic`` indicators of every stock for ``trade_date``."""
        key = f"daily_basic:{trade_date}"
        return self._cached_fetch(
            key, settings.cache_ttl_daily,
            lambda: self.pro.daily_basic(
                trade_date=trade_date,
                fields="ts_code,trade_date,turnover_rate,volume_ratio,pe,pb,total_mv,circ_mv"
            )
        )

    # ── Per-stock daily bars (historical candlesticks) ──

    def get_stock_daily(self, ts_code: str, days: int = 120) -> pd.DataFrame:
        """Fetch daily bars of stock ``ts_code``.

        The request window is ``days + 30`` calendar days ending today.
        """
        now = datetime.now()
        end = now.strftime("%Y%m%d")
        start = (now - timedelta(days=days + 30)).strftime("%Y%m%d")
        key = f"stock_daily:{ts_code}:{end}:{days}"
        return self._cached_fetch(
            key, settings.cache_ttl_daily,
            lambda: self.pro.daily(
                ts_code=ts_code, start_date=start, end_date=end,
                fields=self._DAILY_FIELDS
            )
        )

    # ── Stock basic information ──

    def get_stock_basic(self) -> pd.DataFrame:
        """Fetch basic information for all stocks with ``list_status="L"``."""
        key = "stock_basic"
        return self._cached_fetch(
            key, settings.cache_ttl_static,
            lambda: self.pro.stock_basic(
                exchange="", list_status="L",
                fields="ts_code,symbol,name,area,industry,market,list_date"
            )
        )
|
||
|
||
|
||
# Shared module-level client instance. Constructing it only reads the
# configured token; the Tushare API handle is created on the first fetch.
tushare_client = TushareClient()
|