stock-ai-agent/backend/app/astock_agent/tushare_client.py
2026-02-27 09:54:17 +08:00

347 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Tushare data wrapper.
Provides access to A-share sector and individual-stock market data (using the Tonghuashun/THS family of endpoints).
"""
import time
import tushare as ts
import pandas as pd
from typing import Dict, List, Optional
from datetime import datetime, timedelta
from app.utils.logger import logger
class TushareClient:
    """Tushare data client built on the Tonghuashun (THS) family of endpoints.

    Every ``get_*`` method returns a ``pandas.DataFrame`` and yields an empty
    frame when nothing could be fetched. Most requests go through
    :meth:`_get_cached`, which adds a short-lived cache, request spacing and
    retry with exponential back-off.
    """

    # Response cache and the time each entry was stored. These are class
    # attributes, so every instance in the process shares them.
    _cache: Dict[str, pd.DataFrame] = {}
    _cache_time: Dict[str, datetime] = {}
    # time.time() value of the most recent request issued via _get_cached.
    _last_request_time = 0

    def __init__(self, token: str):
        """Create a client bound to ``token``.

        Args:
            token: Tushare API token.
        """
        self.token = token
        ts.set_token(token)
        self.pro = ts.pro_api()
        self.cache_ttl = 300  # seconds a cached response stays valid (5 minutes)
        self.request_delay = 0.5  # minimum seconds between requests; Tushare rate-limits calls

    def _get_cached(self, key: str, fetch_func) -> pd.DataFrame:
        """Return the data for ``key`` from the cache, or fetch and cache it.

        A fetch that raises is retried up to three times with exponential
        back-off. A fetch that returns no rows is neither cached nor retried.

        Args:
            key: Cache key. It must encode every input that affects the result.
            fetch_func: Zero-argument callable returning a DataFrame (or None).

        Returns:
            A copy of the cached or freshly fetched DataFrame, so callers can
            modify it without corrupting the cache. An empty DataFrame when the
            fetch produced no rows or failed on every attempt.
        """
        if key in self._cache:
            cached_at = self._cache_time.get(key)
            # total_seconds() includes whole days; timedelta.seconds does not,
            # which would make an entry older than a day look fresh again.
            if cached_at and (datetime.now() - cached_at).total_seconds() < self.cache_ttl:
                logger.debug(f"使用缓存数据: {key}")
                return self._cache[key].copy()

        # Space requests out to respect the rate limit.
        elapsed = time.time() - self._last_request_time
        if elapsed < self.request_delay:
            time.sleep(self.request_delay - elapsed)

        max_retries = 3
        for attempt in range(max_retries):
            try:
                self._last_request_time = time.time()
                df = fetch_func()
                if df is None or df.empty:
                    # "No data" is a valid answer, not an error: asking again
                    # immediately would only burn rate-limited requests.
                    break
                self._cache[key] = df
                # Stamp at store time so retry sleeps do not shorten the TTL.
                self._cache_time[key] = datetime.now()
                logger.debug(f"获取数据成功: {key}")
                return df.copy()
            except Exception as e:
                # Exponential back-off: wait 2s, then 4s, before retrying.
                if attempt < max_retries - 1:
                    wait_time = (2 ** attempt) * 2
                    logger.warning(
                        f"获取数据失败 {key} (尝试 {attempt + 1}/{max_retries}): {e}, "
                        f"等待 {wait_time}秒后重试..."
                    )
                    time.sleep(wait_time)
                    continue
                logger.error(f"获取数据失败 {key}: {e}")
                break
        return pd.DataFrame()

    def get_concept_sectors(self) -> pd.DataFrame:
        """Fetch the list of concept sectors.

        Calls the ``ths_index`` endpoint with ``type='N'``, which selects
        concept sectors.

        Returns:
            The concept sector list, or an empty DataFrame on failure.
        """
        def fetch():
            # ths_index: list of THS concept indices
            return self.pro.ths_index(type='N')

        return self._get_cached('concept_sectors', fetch)

    def get_sector_daily(
        self,
        ts_code: str,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
    ) -> pd.DataFrame:
        """Fetch daily bars for a sector index.

        Args:
            ts_code: Sector index code (e.g. ``885823.TI``).
            start_date: Start date, ``YYYYMMDD``. Defaults to 30 days ago.
            end_date: End date, ``YYYYMMDD``. Defaults to today.

        Returns:
            Daily sector data, or an empty DataFrame on failure.
        """
        if not start_date:
            start_date = (datetime.now() - timedelta(days=30)).strftime('%Y%m%d')
        if not end_date:
            end_date = datetime.now().strftime('%Y%m%d')

        def fetch():
            # ths_daily: historical bars for a sector index
            return self.pro.ths_daily(
                ts_code=ts_code,
                start_date=start_date,
                end_date=end_date,
            )

        # The key carries both ends of the range so that different windows
        # for the same sector never share a cache entry.
        return self._get_cached(
            f'sector_daily_{ts_code}_{start_date}_{end_date}', fetch
        )

    def get_sector_members(self, ts_code: str) -> pd.DataFrame:
        """Fetch the constituent stocks of a sector.

        Args:
            ts_code: Sector index code (e.g. ``885823.TI``).

        Returns:
            The constituent list, or an empty DataFrame on failure.
        """
        def fetch():
            # ths_member: constituents of a sector
            return self.pro.ths_member(ts_code=ts_code)

        return self._get_cached(f'sector_members_{ts_code}', fetch)

    def get_stock_daily(
        self,
        ts_code: str,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
    ) -> pd.DataFrame:
        """Fetch daily bars for a single stock.

        Args:
            ts_code: Stock code (e.g. ``000001.SZ``).
            start_date: Start date, ``YYYYMMDD``. Defaults to 30 days ago.
            end_date: End date, ``YYYYMMDD``. Defaults to today.

        Returns:
            Daily bar data, or an empty DataFrame on failure.
        """
        if not start_date:
            start_date = (datetime.now() - timedelta(days=30)).strftime('%Y%m%d')
        if not end_date:
            end_date = datetime.now().strftime('%Y%m%d')

        def fetch():
            # daily: daily bars
            return self.pro.daily(
                ts_code=ts_code,
                start_date=start_date,
                end_date=end_date,
            )

        # Both ends of the range are part of the key (see get_sector_daily).
        return self._get_cached(
            f'stock_daily_{ts_code}_{start_date}_{end_date}', fetch
        )

    def get_stock_daily_basic(
        self,
        ts_codes: List[str],
        trade_date: Optional[str] = None,
    ) -> pd.DataFrame:
        """Fetch daily indicators (turnover rate, volume ratio, PE, PB) for stocks.

        Starting at ``trade_date`` the method walks back one calendar day at a
        time, for at most three days, and returns the first day that has data.
        This covers the case where the requested day has not been published
        yet. Only the first 300 codes are queried.

        Args:
            ts_codes: Stock codes.
            trade_date: Trade date, ``YYYYMMDD``. Defaults to today. A value
                that cannot be parsed falls back to today.

        Returns:
            Indicator rows for the first date with data, or an empty DataFrame.
        """
        if not ts_codes:
            return pd.DataFrame()

        if not trade_date:
            trade_date = datetime.now().strftime('%Y%m%d')
        try:
            base_date = datetime.strptime(trade_date, '%Y%m%d')
        except (TypeError, ValueError):
            logger.warning(f"无效的交易日期 {trade_date},改用当前日期")
            base_date = datetime.now()

        codes_str = ','.join(ts_codes[:300])  # cap the number of codes per query

        def fetch():
            # daily_basic: per-day indicators
            for days_back in range(3):
                try_date = (base_date - timedelta(days=days_back)).strftime('%Y%m%d')
                df = self.pro.daily_basic(
                    ts_code=codes_str,
                    trade_date=try_date,
                    fields='ts_code,trade_date,turnover_rate,volume_ratio,pe,pb',
                )
                if df is not None and not df.empty:
                    # Stop at the first date that has data.
                    return df.reset_index(drop=True)
            return pd.DataFrame()

        # The requested codes are part of the key; otherwise a second call for
        # different stocks on the same date would get the first call's rows.
        return self._get_cached(
            f'stock_daily_basic_{trade_date}_{codes_str}', fetch
        )

    def get_stock_basic(self) -> pd.DataFrame:
        """Fetch basic information for all currently listed stocks.

        Returns:
            Columns ``ts_code, symbol, name, area, industry, list_date``, or an
            empty DataFrame on failure.
        """
        def fetch():
            # stock_basic: basic stock information
            return self.pro.stock_basic(
                exchange='',
                list_status='L',
                fields='ts_code,symbol,name,area,industry,list_date',
            )

        return self._get_cached('stock_basic', fetch)

    def get_realtime_data(self, ts_codes: List[str]) -> pd.DataFrame:
        """Return the most recent daily bar for each requested stock.

        Tushare does not provide true real-time quotes, so this returns the
        latest daily bar found in the last 10 calendar days. Only the first 100
        codes are queried.

        Note: the ``amount`` column is in thousands of yuan; multiply by 1000
        to get yuan.

        Args:
            ts_codes: Stock codes.

        Returns:
            One row per stock (latest trade date), or an empty DataFrame.
        """
        if not ts_codes:
            return pd.DataFrame()

        today = datetime.now().strftime('%Y%m%d')
        # Look back 10 calendar days so weekends and holidays are covered.
        window_start = (datetime.now() - timedelta(days=10)).strftime('%Y%m%d')
        codes_str = ','.join(ts_codes[:100])  # cap the number of codes per query

        def fetch():
            # Use the daily endpoint to get recent bars.
            df = self.pro.daily(
                ts_code=codes_str,
                start_date=window_start,
                end_date=today,
            )
            # Keep only the latest bar of each stock.
            if df is not None and not df.empty:
                df = df.sort_values('trade_date').groupby('ts_code').tail(1)
            return df

        # The requested codes are part of the key; otherwise every call made on
        # the same day would return the first caller's stocks.
        return self._get_cached(f'realtime_{today}_{codes_str}', fetch)

    def get_hot_sectors(self, threshold: float = 2.0) -> pd.DataFrame:
        """Find concept sectors whose latest daily gain reaches ``threshold``.

        Only the first 100 concept sectors are checked, with one ``ths_daily``
        request per sector over the last 10 calendar days. Sectors whose
        request or row parsing fails are skipped.

        NOTE(review): the per-sector requests call ``self.pro.ths_daily``
        directly, so they bypass the cache, request spacing and retry of
        ``_get_cached``. Confirm this is acceptable under the API rate limit.
        NOTE(review): only gains (``pct_change >= threshold``) are selected;
        declines are never reported. Confirm this is the intended meaning.

        Args:
            threshold: Minimum percentage change (in %) for a sector to qualify.

        Returns:
            Qualifying sectors sorted by ``change_pct`` descending, with columns
            ``ts_code, name, change_pct, change, close, amount, volume,
            turnover_rate, trade_date``; an empty DataFrame if none qualify or
            on failure.
        """
        try:
            # 1. Fetch all concept sectors.
            sectors_df = self.get_concept_sectors()
            if sectors_df.empty:
                logger.warning("获取概念板块列表失败")
                return pd.DataFrame()
            logger.info(f"获取到 {len(sectors_df)} 个概念板块")

            # 2. Query window: the last 10 calendar days up to today.
            today = datetime.now().strftime('%Y%m%d')
            window_start = (datetime.now() - timedelta(days=10)).strftime('%Y%m%d')

            # 3. Fetch sector quotes (count is capped for efficiency).
            hot_sectors = []
            max_sectors = 100  # check at most 100 sectors
            for _, row in sectors_df.head(max_sectors).iterrows():
                ts_code = row['ts_code']
                name = row.get('name', '')
                try:
                    # Latest quotes for this sector.
                    daily_df = self.pro.ths_daily(
                        ts_code=ts_code,
                        start_date=window_start,
                        end_date=today,
                    )
                    if daily_df is None or daily_df.empty:
                        continue

                    # Most recent trading day.
                    latest = daily_df.sort_values('trade_date').iloc[-1]
                    # The column is named pct_change here, not pct_chg.
                    change_pct = float(latest.get('pct_change', 0))
                    if change_pct >= threshold:
                        hot_sectors.append({
                            'ts_code': ts_code,
                            'name': name,
                            'change_pct': change_pct,
                            'change': float(latest.get('change', 0)),  # price change
                            'close': float(latest.get('close', 0)),
                            # turnover amount (yuan per the original note; TODO confirm unit)
                            'amount': float(latest.get('amount', 0)),
                            'volume': float(latest.get('vol', 0)),  # volume (lots)
                            'turnover_rate': float(latest.get('turnover_rate', 0)),  # turnover rate
                            'trade_date': str(latest.get('trade_date', '')),
                        })
                except Exception as e:
                    # Best effort: one failing sector must not abort the scan.
                    logger.debug(f"获取板块 {name} 行情失败: {e}")
                    continue

            result_df = pd.DataFrame(hot_sectors)
            if not result_df.empty:
                result_df = result_df.sort_values('change_pct', ascending=False)
            return result_df
        except Exception as e:
            logger.error(f"获取异动板块失败: {e}")
            return pd.DataFrame()
# Process-wide singleton; stays None until a client has been created.
_tushare_client: Optional[TushareClient] = None


def get_tushare_client(token: str = None) -> Optional[TushareClient]:
    """Return the shared Tushare client, creating it on first use.

    Args:
        token: Tushare token. It is used only when no client exists yet; once
            a client has been created the argument is ignored.

    Returns:
        The shared ``TushareClient``, or ``None`` when no client exists and no
        token was supplied.
    """
    global _tushare_client
    # Reuse the existing instance whenever there is one.
    if _tushare_client is not None:
        return _tushare_client
    # First use: a client can only be built when a token is available.
    if token:
        _tushare_client = TushareClient(token)
    return _tushare_client