347 lines
11 KiB
Python
347 lines
11 KiB
Python
"""
|
||
Tushare 数据封装
|
||
提供 A 股板块、个股行情数据获取接口(使用同花顺系列接口)
|
||
"""
|
||
import time
|
||
import tushare as ts
|
||
import pandas as pd
|
||
from typing import Dict, List, Optional
|
||
from datetime import datetime, timedelta
|
||
from app.utils.logger import logger
|
||
|
||
|
||
class TushareClient:
    """Tushare data client built on the THS (同花顺) family of endpoints.

    Data methods go through ``_get_cached``, which adds a per-key TTL cache,
    a minimum spacing between requests and retry with exponential backoff.
    """

    # Caches are class attributes, so every TushareClient instance shares
    # them. They avoid re-requesting the same data within ``cache_ttl``.
    _cache = {}
    _cache_time = {}
    _last_request_time = 0

    def __init__(self, token: str):
        """Initialise the client.

        Args:
            token: Tushare API token.
        """
        self.token = token
        ts.set_token(token)
        self.pro = ts.pro_api()
        self.cache_ttl = 300  # cache entries stay fresh for 5 minutes
        self.request_delay = 0.5  # min seconds between requests; tushare rate-limits calls

    def _throttle(self) -> None:
        """Sleep until ``request_delay`` seconds have passed since the last request.

        Records the current time as the new last-request time afterwards.
        """
        elapsed = time.time() - self._last_request_time
        if elapsed < self.request_delay:
            time.sleep(self.request_delay - elapsed)
        self._last_request_time = time.time()

    def _get_cached(self, key: str, fetch_func) -> pd.DataFrame:
        """Return cached data for ``key`` or fetch it, with rate limiting and retries.

        Args:
            key: Cache key; must encode every argument that affects the result.
            fetch_func: Zero-argument callable returning a DataFrame (or None).

        Returns:
            The fetched (or still-fresh cached) DataFrame. Returns an empty
            DataFrame if every attempt raised or returned no rows. Only
            non-empty results are cached.
        """
        cache_time = self._cache_time.get(key)
        if key in self._cache and cache_time is not None:
            # total_seconds(): `.seconds` is only the sub-day component and
            # would make day-old entries look fresh.
            age = (datetime.now() - cache_time).total_seconds()
            if age < self.cache_ttl:
                logger.debug(f"使用缓存数据: {key}")
                return self._cache[key]

        max_retries = 3
        for attempt in range(max_retries):
            # Throttle every attempt, including retries after an empty result.
            self._throttle()
            try:
                df = fetch_func()
                if df is not None and not df.empty:
                    self._cache[key] = df
                    self._cache_time[key] = datetime.now()
                    logger.debug(f"获取数据成功: {key}")
                    return df
            except Exception as e:
                if attempt < max_retries - 1:
                    # Exponential backoff: 2s, 4s, ...
                    wait_time = (2 ** attempt) * 2
                    logger.warning(
                        f"获取数据失败 {key} (尝试 {attempt + 1}/{max_retries}): {e},"
                        f"等待 {wait_time}秒后重试..."
                    )
                    time.sleep(wait_time)
                    continue
                logger.error(f"获取数据失败 {key}: {e}")
                break

        return pd.DataFrame()

    def get_concept_sectors(self) -> pd.DataFrame:
        """Return the list of concept sectors.

        Uses the ``ths_index`` endpoint with ``type="N"`` (concept indices).

        Returns:
            Concept sector list, or an empty DataFrame on failure.
        """
        return self._get_cached('concept_sectors', lambda: self.pro.ths_index(type='N'))

    def get_sector_daily(self, ts_code: str, start_date: str = None, end_date: str = None) -> pd.DataFrame:
        """Return daily bars for a sector index via ``ths_daily``.

        Args:
            ts_code: Sector index code (e.g. 885823.TI).
            start_date: Start date (YYYYMMDD); defaults to 30 days ago.
            end_date: End date (YYYYMMDD); defaults to today.

        Returns:
            Sector daily data, or an empty DataFrame on failure.
        """
        now = datetime.now()
        if not start_date:
            start_date = (now - timedelta(days=30)).strftime('%Y%m%d')
        if not end_date:
            end_date = now.strftime('%Y%m%d')

        def fetch():
            return self.pro.ths_daily(ts_code=ts_code, start_date=start_date, end_date=end_date)

        # The key includes start_date so different date ranges don't collide.
        return self._get_cached(f'sector_daily_{ts_code}_{start_date}_{end_date}', fetch)

    def get_sector_members(self, ts_code: str) -> pd.DataFrame:
        """Return the constituent stocks of a sector via ``ths_member``.

        Args:
            ts_code: Sector index code (e.g. 885823.TI).

        Returns:
            Constituent list, or an empty DataFrame on failure.
        """
        return self._get_cached(
            f'sector_members_{ts_code}', lambda: self.pro.ths_member(ts_code=ts_code)
        )

    def get_stock_daily(self, ts_code: str, start_date: str = None, end_date: str = None) -> pd.DataFrame:
        """Return daily bars for one stock via ``daily``.

        Args:
            ts_code: Stock code (e.g. 000001.SZ).
            start_date: Start date (YYYYMMDD); defaults to 30 days ago.
            end_date: End date (YYYYMMDD); defaults to today.

        Returns:
            Daily data, or an empty DataFrame on failure.
        """
        now = datetime.now()
        if not start_date:
            start_date = (now - timedelta(days=30)).strftime('%Y%m%d')
        if not end_date:
            end_date = now.strftime('%Y%m%d')

        def fetch():
            return self.pro.daily(ts_code=ts_code, start_date=start_date, end_date=end_date)

        # The key includes start_date so different date ranges don't collide.
        return self._get_cached(f'stock_daily_{ts_code}_{start_date}_{end_date}', fetch)

    def get_stock_daily_basic(self, ts_codes: List[str], trade_date: str = None,
                              lookback_days: int = 3) -> pd.DataFrame:
        """Return daily indicators (turnover rate, volume ratio, PE, PB) via ``daily_basic``.

        Starting at ``trade_date``, walks back one calendar day at a time, up
        to ``lookback_days`` days. It returns the first day that has data, in
        case that day's data is not published yet. Note that 3 calendar days
        does not span a weekend when called on a Monday before the data is
        out.

        Args:
            ts_codes: Stock codes; only the first 300 are queried.
            trade_date: Trade date (YYYYMMDD); defaults to today. An
                unparsable value falls back to today with a warning.
            lookback_days: Number of calendar days to try, including trade_date.

        Returns:
            Indicator data, or an empty DataFrame if nothing was found.
        """
        if not ts_codes:
            return pd.DataFrame()

        if not trade_date:
            trade_date = datetime.now().strftime('%Y%m%d')
        try:
            base_date = datetime.strptime(trade_date, '%Y%m%d')
        except ValueError:
            logger.warning(f"无效的交易日期 {trade_date},改用今天")
            base_date = datetime.now()
            trade_date = base_date.strftime('%Y%m%d')

        limit = 300  # per-request code limit
        if len(ts_codes) > limit:
            logger.warning(f"股票代码数量 {len(ts_codes)} 超过单次查询上限 {limit},仅查询前 {limit} 个")
        codes_str = ','.join(ts_codes[:limit])

        def fetch():
            for offset in range(max(1, lookback_days)):
                if offset:
                    # _get_cached throttled the first call; throttle the fallbacks too.
                    self._throttle()
                try_date = (base_date - timedelta(days=offset)).strftime('%Y%m%d')
                df = self.pro.daily_basic(
                    ts_code=codes_str,
                    trade_date=try_date,
                    fields='ts_code,trade_date,turnover_rate,volume_ratio,pe,pb'
                )
                if df is not None and not df.empty:
                    return df.reset_index(drop=True)
            return pd.DataFrame()

        # The key includes the codes: different code lists must not share an entry.
        return self._get_cached(
            f'stock_daily_basic_{trade_date}_{lookback_days}_{codes_str}', fetch
        )

    def get_stock_basic(self) -> pd.DataFrame:
        """Return basic information for all listed stocks via ``stock_basic``.

        Returns:
            Stock basic info, or an empty DataFrame on failure.
        """
        def fetch():
            return self.pro.stock_basic(
                exchange='',
                list_status='L',
                fields='ts_code,symbol,name,area,industry,list_date'
            )

        return self._get_cached('stock_basic', fetch)

    def get_realtime_data(self, ts_codes: List[str]) -> pd.DataFrame:
        """Return the latest available daily bar per stock, as a stand-in for realtime quotes.

        tushare does not offer true realtime data here. This queries
        ``daily`` over the last 10 days and keeps each stock's most recent
        row.

        Note: ``amount`` is in thousands of yuan; multiply by 1000 for yuan.

        Args:
            ts_codes: Stock codes; only the first 100 are queried.

        Returns:
            One row per stock (``amount`` in thousands of yuan), or an empty
            DataFrame.
        """
        if not ts_codes:
            return pd.DataFrame()

        now = datetime.now()
        today = now.strftime('%Y%m%d')
        start_date = (now - timedelta(days=10)).strftime('%Y%m%d')

        limit = 100  # per-request code limit
        if len(ts_codes) > limit:
            logger.warning(f"股票代码数量 {len(ts_codes)} 超过单次查询上限 {limit},仅查询前 {limit} 个")
        codes_str = ','.join(ts_codes[:limit])

        def fetch():
            df = self.pro.daily(ts_code=codes_str, start_date=start_date, end_date=today)
            if df is not None and not df.empty:
                # Keep only each stock's latest trading day.
                df = df.sort_values('trade_date').groupby('ts_code').tail(1)
            return df

        # The key includes the codes: different code lists must not share an entry.
        return self._get_cached(f'realtime_{today}_{codes_str}', fetch)

    def get_hot_sectors(self, threshold: float = 2.0, max_sectors: int = 100) -> pd.DataFrame:
        """Return concept sectors whose latest change is at least ``threshold`` percent.

        Checks the first ``max_sectors`` concept sectors. It makes one
        rate-limited ``ths_daily`` request per sector over the last 10 days
        and keeps the latest row. Per-sector failures are logged at debug
        level and skipped.

        Args:
            threshold: Minimum percentage change (%).
            max_sectors: Maximum number of sectors to check.

        Returns:
            Hot sectors sorted by ``change_pct`` descending, or an empty
            DataFrame.
        """
        try:
            sectors_df = self.get_concept_sectors()
            if sectors_df.empty:
                logger.warning("获取概念板块列表失败")
                return pd.DataFrame()

            logger.info(f"获取到 {len(sectors_df)} 个概念板块")

            now = datetime.now()
            today = now.strftime('%Y%m%d')
            start_date = (now - timedelta(days=10)).strftime('%Y%m%d')

            hot_sectors = []
            for _, row in sectors_df.head(max_sectors).iterrows():
                ts_code = row['ts_code']
                name = row.get('name', '')

                try:
                    # Respect request_delay: this loop may issue many requests.
                    self._throttle()
                    daily_df = self.pro.ths_daily(
                        ts_code=ts_code,
                        start_date=start_date,
                        end_date=today
                    )
                    if daily_df is None or daily_df.empty:
                        continue

                    latest = daily_df.sort_values('trade_date').iloc[-1]

                    # ths_daily names the column pct_change, not pct_chg.
                    change_pct = float(latest.get('pct_change', 0))
                    if change_pct >= threshold:
                        hot_sectors.append({
                            'ts_code': ts_code,
                            'name': name,
                            'change_pct': change_pct,
                            'change': float(latest.get('change', 0)),  # absolute change
                            'close': float(latest.get('close', 0)),
                            'amount': float(latest.get('amount', 0)),  # turnover; originally noted as yuan — confirm unit
                            'volume': float(latest.get('vol', 0)),  # volume; originally noted as lots (手)
                            'turnover_rate': float(latest.get('turnover_rate', 0)),
                            'trade_date': str(latest.get('trade_date', ''))
                        })

                except Exception as e:
                    logger.debug(f"获取板块 {name} 行情失败: {e}")
                    continue

            result_df = pd.DataFrame(hot_sectors)
            if not result_df.empty:
                result_df = result_df.sort_values('change_pct', ascending=False)

            return result_df

        except Exception as e:
            logger.error(f"获取异动板块失败: {e}")
            return pd.DataFrame()
|
||
|
||
|
||
# Module-level shared client (singleton).
_tushare_client: Optional[TushareClient] = None


def get_tushare_client(token: Optional[str] = None) -> Optional[TushareClient]:
    """Return the shared TushareClient, creating it on first use.

    Args:
        token: Tushare API token. It is required the first time. If a token
            different from the current client's is supplied, the client is
            rebuilt with it, so a rotated token is not silently ignored.

    Returns:
        The shared client, or None if none exists yet and no token was given.
    """
    global _tushare_client
    if token and (_tushare_client is None or _tushare_client.token != token):
        _tushare_client = TushareClient(token)
    return _tushare_client
|