stock-agent/src/data/tushare_client.py
2025-12-28 10:12:30 +08:00

142 lines
5.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import tushare as ts
import pandas as pd
import logging
from typing import Optional, Dict, List
from datetime import datetime, timedelta
from config.config import Config
logger = logging.getLogger(__name__)


class TushareClient:
    """Convenience wrapper around the Tushare Pro API.

    Every query method is best-effort: any exception raised while talking to
    the API is logged and an empty result (empty ``DataFrame``, empty ``dict``
    or ``None``, depending on the method) is returned instead of propagating.
    """

    # Number of stock codes sent per ``stock_basic`` request when resolving
    # index constituents.
    _STOCK_BASIC_BATCH_SIZE = 50

    def __init__(self, token: Optional[str] = None):
        """Create a client.

        Args:
            token: Tushare API token. Falls back to ``Config.TUSHARE_TOKEN``
                when omitted or empty.

        Raises:
            ValueError: If neither ``token`` nor the configured token is set.
        """
        self.token = token or Config.TUSHARE_TOKEN
        if not self.token:
            raise ValueError("Tushare token is required")
        ts.set_token(self.token)
        self.pro = ts.pro_api()

    def get_stock_list(self, exchange: Optional[str] = None) -> pd.DataFrame:
        """Return the ``stock_basic`` listing for stocks with ``list_status='L'``.

        Args:
            exchange: Optional exchange filter passed through to the API.

        Returns:
            The API response, or an empty ``DataFrame`` on failure.
        """
        try:
            return self.pro.stock_basic(exchange=exchange, list_status='L')
        except Exception as e:
            logger.error("Failed to get stock list: %s", e)
            return pd.DataFrame()

    def get_stock_daily(self, ts_code: str, start_date: Optional[str] = None,
                        end_date: Optional[str] = None) -> pd.DataFrame:
        """Return daily bars for ``ts_code``.

        Args:
            ts_code: Stock code to query.
            start_date: ``YYYYMMDD`` lower bound; defaults to 365 days ago.
            end_date: ``YYYYMMDD`` upper bound; defaults to today.

        Returns:
            The API response, or an empty ``DataFrame`` on failure.
        """
        try:
            if not start_date:
                start_date = (datetime.now() - timedelta(days=365)).strftime('%Y%m%d')
            if not end_date:
                end_date = datetime.now().strftime('%Y%m%d')
            return self.pro.daily(ts_code=ts_code, start_date=start_date, end_date=end_date)
        except Exception as e:
            logger.error("Failed to get daily data for %s: %s", ts_code, e)
            return pd.DataFrame()

    def get_financial_data(self, ts_code: str,
                           period: Optional[str] = None) -> Dict[str, pd.DataFrame]:
        """Fetch the three financial statements for ``ts_code``.

        Args:
            ts_code: Stock code to query.
            period: Optional reporting period passed through to the API.

        Returns:
            A dict with keys ``'income'``, ``'balance'`` and ``'cashflow'``.
            If any of the three requests fails, an empty dict is returned
            (partial results are discarded).
        """
        try:
            return {
                'income': self.pro.income(ts_code=ts_code, period=period),
                'balance': self.pro.balancesheet(ts_code=ts_code, period=period),
                'cashflow': self.pro.cashflow(ts_code=ts_code, period=period),
            }
        except Exception as e:
            logger.error("Failed to get financial data for %s: %s", ts_code, e)
            return {}

    def get_industry_classify(self, ts_code: Optional[str] = None) -> pd.DataFrame:
        """Return the industry classification for ``ts_code``.

        Returns:
            The API response, or an empty ``DataFrame`` on failure.
        """
        # NOTE(review): confirm that ``industry_classify`` is a valid Tushare
        # Pro endpoint name. If it is not, the call fails on every invocation
        # and this method silently degrades to an empty DataFrame.
        try:
            return self.pro.industry_classify(ts_code=ts_code)
        except Exception as e:
            logger.error("Failed to get industry classify: %s", e)
            return pd.DataFrame()

    def get_hs300_stocks(self) -> pd.DataFrame:
        """Return ``stock_basic`` rows for the current HS300 constituents.

        The constituent list is read from ``index_weight`` for ``399300.SZ``,
        falling back to ``000300.SH`` when the first query is empty. Only the
        most recent snapshot (latest ``trade_date``) is used and constituent
        codes are de-duplicated, so each stock appears at most once.

        Returns:
            Concatenated ``stock_basic`` rows, or an empty ``DataFrame`` when
            no constituents or no basic info could be retrieved.
        """
        try:
            weights = self.pro.index_weight(index_code='399300.SZ')
            if weights.empty:
                logger.warning("No HS300 stocks found, fallback to CSI300")
                weights = self.pro.index_weight(index_code='000300.SH')
            if weights.empty:
                logger.error("No HS300 component stocks found")
                return pd.DataFrame()

            # The query has no date bounds, so the response may contain
            # several snapshots; keep only the newest one so stale
            # constituents and repeated codes are not looked up.
            if 'trade_date' in weights.columns:
                snapshot_dates = weights['trade_date'].dropna()
                if not snapshot_dates.empty:
                    weights = weights[weights['trade_date'] == snapshot_dates.max()]
            ts_codes = weights['con_code'].dropna().drop_duplicates().tolist()

            # Resolve basic info in batches; a failed batch is skipped so the
            # remaining batches can still contribute.
            batch_size = self._STOCK_BASIC_BATCH_SIZE
            stock_basic_list = []
            for start in range(0, len(ts_codes), batch_size):
                batch_codes = ts_codes[start:start + batch_size]
                try:
                    batch_info = self.pro.stock_basic(ts_code=','.join(batch_codes))
                except Exception as e:
                    logger.warning("Failed to get batch stock info: %s", e)
                    continue
                if not batch_info.empty:
                    stock_basic_list.append(batch_info)

            if not stock_basic_list:
                logger.error("Failed to get any HS300 stock basic info")
                return pd.DataFrame()

            result = pd.concat(stock_basic_list, ignore_index=True)
            logger.info("Successfully retrieved %s HS300 component stocks", len(result))
            return result
        except Exception as e:
            logger.error("Failed to get HS300 stocks: %s", e)
            return pd.DataFrame()

    def get_stock_basic(self, ts_code: str) -> Optional[Dict]:
        """Return a summary dict for a single stock.

        Combines the first ``stock_basic`` row with the first ``daily_basic``
        row (requested with ``limit=1``).

        Returns:
            A dict with keys ``ts_code``, ``name``, ``industry``,
            ``list_date``, ``current_price`` and ``market_cap``.
            ``current_price`` / ``market_cap`` are ``0`` when no quote row or
            the corresponding column (``close`` / ``total_mv``) is available.
            ``None`` when the stock is unknown or a request fails.
        """
        try:
            basic_info = self.pro.stock_basic(ts_code=ts_code)
            if basic_info.empty:
                return None
            profile = basic_info.iloc[0]

            # Quote row used for the current price and total market value.
            latest_daily = self.pro.daily_basic(ts_code=ts_code, limit=1)
            if latest_daily.empty:
                current_price = 0
                market_cap = 0
            else:
                quote = latest_daily.iloc[0]
                # ``Series.get`` treats a missing column the same way as a
                # missing row: fall back to 0 instead of raising KeyError.
                current_price = quote.get('close', 0)
                market_cap = quote.get('total_mv', 0)

            return {
                'ts_code': ts_code,
                'name': profile['name'],
                'industry': profile['industry'],
                'list_date': profile['list_date'],
                'current_price': current_price,
                'market_cap': market_cap,
            }
        except Exception as e:
            logger.error("Failed to get stock basic info for %s: %s", ts_code, e)
            return None

    def get_trade_cal(self, start_date: Optional[str] = None,
                      end_date: Optional[str] = None) -> pd.DataFrame:
        """Return the trading calendar between two dates.

        Args:
            start_date: ``YYYYMMDD`` lower bound; defaults to today.
            end_date: ``YYYYMMDD`` upper bound; defaults to 30 days from now.

        Returns:
            The API response, or an empty ``DataFrame`` on failure.
        """
        try:
            if not start_date:
                start_date = datetime.now().strftime('%Y%m%d')
            if not end_date:
                end_date = (datetime.now() + timedelta(days=30)).strftime('%Y%m%d')
            return self.pro.trade_cal(start_date=start_date, end_date=end_date)
        except Exception as e:
            logger.error("Failed to get trade calendar: %s", e)
            return pd.DataFrame()