142 lines
5.8 KiB
Python
142 lines
5.8 KiB
Python
import tushare as ts
|
||
import pandas as pd
|
||
import logging
|
||
from typing import Optional, Dict, List
|
||
from datetime import datetime, timedelta
|
||
from config.config import Config
|
||
|
||
# Module-level logger; TushareClient's data-fetching methods report failures
# here and return empty results instead of raising.
logger = logging.getLogger(__name__)


class TushareClient:
    """Thin wrapper around the Tushare Pro API.

    Every data-fetching method catches exceptions raised by the API, logs them
    and returns an empty result (empty DataFrame, empty dict or ``None``)
    instead of propagating, so callers must check for emptiness.
    """

    # Number of ts_codes sent per stock_basic request in get_hs300_stocks.
    _BASIC_BATCH_SIZE = 50

    def __init__(self, token: Optional[str] = None):
        """Create a client authenticated with ``token``.

        Args:
            token: Tushare Pro token. Falls back to ``Config.TUSHARE_TOKEN``
                when omitted or empty.

        Raises:
            ValueError: if no token is available from either source.
        """
        self.token = token or Config.TUSHARE_TOKEN
        if not self.token:
            raise ValueError("Tushare token is required")

        ts.set_token(self.token)
        self.pro = ts.pro_api()

    def get_stock_list(self, exchange: Optional[str] = None) -> pd.DataFrame:
        """Return currently listed stocks (``list_status='L'``).

        Args:
            exchange: Optional exchange filter passed straight to
                ``pro.stock_basic``.

        Returns:
            The ``stock_basic`` result, or an empty DataFrame on failure.
        """
        try:
            return self.pro.stock_basic(exchange=exchange, list_status='L')
        except Exception as e:
            logger.error(f"Failed to get stock list: {e}")
            return pd.DataFrame()

    def get_stock_daily(self, ts_code: str, start_date: Optional[str] = None,
                        end_date: Optional[str] = None) -> pd.DataFrame:
        """Fetch daily bars for one stock.

        Args:
            ts_code: Tushare security code, e.g. ``'000001.SZ'``.
            start_date: First date as ``YYYYMMDD``; defaults to 365 days ago.
            end_date: Last date as ``YYYYMMDD``; defaults to today.

        Returns:
            The ``pro.daily`` result, or an empty DataFrame on failure.
        """
        try:
            # One clock read so both defaults come from the same instant,
            # even if the call happens right at midnight.
            now = datetime.now()
            if not start_date:
                start_date = (now - timedelta(days=365)).strftime('%Y%m%d')
            if not end_date:
                end_date = now.strftime('%Y%m%d')

            return self.pro.daily(ts_code=ts_code, start_date=start_date, end_date=end_date)
        except Exception as e:
            logger.error(f"Failed to get daily data for {ts_code}: {e}")
            return pd.DataFrame()

    def get_financial_data(self, ts_code: str, period: Optional[str] = None) -> Dict[str, pd.DataFrame]:
        """Fetch income statement, balance sheet and cash flow for one stock.

        Args:
            ts_code: Tushare security code.
            period: Optional report period passed to each statement query.

        Returns:
            ``{'income': ..., 'balance': ..., 'cashflow': ...}``. If any of the
            three requests fails, an empty dict is returned (all or nothing).
        """
        try:
            # The dict literal evaluates in order: income, balancesheet, cashflow.
            return {
                'income': self.pro.income(ts_code=ts_code, period=period),
                'balance': self.pro.balancesheet(ts_code=ts_code, period=period),
                'cashflow': self.pro.cashflow(ts_code=ts_code, period=period),
            }
        except Exception as e:
            logger.error(f"Failed to get financial data for {ts_code}: {e}")
            return {}

    def get_industry_classify(self, ts_code: Optional[str] = None) -> pd.DataFrame:
        """Return industry classification data, or an empty DataFrame on failure.

        NOTE(review): confirm ``industry_classify`` is a documented Tushare Pro
        interface name — TODO verify.
        """
        try:
            return self.pro.industry_classify(ts_code=ts_code)
        except Exception as e:
            logger.error(f"Failed to get industry classify: {e}")
            return pd.DataFrame()

    def get_hs300_stocks(self) -> pd.DataFrame:
        """Return ``stock_basic`` rows for the current CSI 300 constituents.

        Constituents come from ``pro.index_weight`` for ``399300.SZ``. When that
        query is empty, the Shanghai code ``000300.SH`` of the same index is
        tried instead. Only the most recent weight snapshot is used and each
        code is kept once, so older snapshots do not duplicate constituents.

        Returns:
            Concatenated ``stock_basic`` rows, or an empty DataFrame when no
            constituents or no basic info could be retrieved.
        """
        try:
            weights = self.pro.index_weight(index_code='399300.SZ')
            if weights.empty:
                logger.warning("No HS300 weights for 399300.SZ, falling back to 000300.SH")
                weights = self.pro.index_weight(index_code='000300.SH')

            if weights.empty:
                logger.error("No HS300 component stocks found")
                return pd.DataFrame()

            ts_codes = self._latest_constituents(weights)
            frames = self._fetch_basic_in_batches(ts_codes)
            if not frames:
                logger.error("Failed to get any HS300 stock basic info")
                return pd.DataFrame()

            result = pd.concat(frames, ignore_index=True)
            logger.info(f"Successfully retrieved {len(result)} HS300 component stocks")
            return result
        except Exception as e:
            logger.error(f"Failed to get HS300 stocks: {e}")
            return pd.DataFrame()

    @staticmethod
    def _latest_constituents(weights: pd.DataFrame) -> List[str]:
        """Return the unique ``con_code`` values of the newest snapshot."""
        if 'trade_date' in weights.columns:
            # An undated index_weight query may return several snapshots.
            # Keep only the newest; trade_date is assumed to be a YYYYMMDD
            # string, which sorts chronologically.
            weights = weights[weights['trade_date'] == weights['trade_date'].max()]
        return weights['con_code'].drop_duplicates().tolist()

    def _fetch_basic_in_batches(self, ts_codes: List[str]) -> List[pd.DataFrame]:
        """Query ``stock_basic`` for ``ts_codes`` in fixed-size batches.

        A failing batch is logged and skipped so the remaining batches still
        contribute. Empty batch results are dropped.
        """
        frames = []
        size = self._BASIC_BATCH_SIZE
        for start in range(0, len(ts_codes), size):
            batch = ts_codes[start:start + size]
            try:
                info = self.pro.stock_basic(ts_code=','.join(batch))
            except Exception as e:
                logger.warning(f"Failed to get batch stock info: {e}")
                continue
            if info is not None and not info.empty:
                frames.append(info)
        return frames

    def get_stock_basic(self, ts_code: str) -> Optional[Dict]:
        """Return a summary dict for one stock, or ``None``.

        Combines ``stock_basic`` (name, industry, list date) with the first row
        of ``daily_basic`` (close price, total market value).

        Returns:
            A dict with keys ``ts_code``, ``name``, ``industry``, ``list_date``,
            ``current_price`` and ``market_cap``. Missing text fields default
            to ``''`` and missing quote fields to ``0``. Returns ``None`` if
            the stock is unknown or a request fails.
        """
        try:
            basic_info = self.pro.stock_basic(ts_code=ts_code)
            if basic_info.empty:
                return None

            # NOTE(review): assumes daily_basic returns newest rows first, so
            # limit=1 yields the latest quote — confirm against Tushare docs.
            latest_daily = self.pro.daily_basic(ts_code=ts_code, limit=1)

            return {
                'ts_code': ts_code,
                'name': self._first_value(basic_info, 'name', ''),
                'industry': self._first_value(basic_info, 'industry', ''),
                'list_date': self._first_value(basic_info, 'list_date', ''),
                'current_price': self._first_value(latest_daily, 'close', 0),
                'market_cap': self._first_value(latest_daily, 'total_mv', 0),
            }
        except Exception as e:
            logger.error(f"Failed to get stock basic info for {ts_code}: {e}")
            return None

    @staticmethod
    def _first_value(df: Optional[pd.DataFrame], column: str, default):
        """Return ``column`` of the first row of ``df``, or ``default`` if absent."""
        if df is None or df.empty or column not in df.columns:
            return default
        return df.iloc[0][column]

    def get_trade_cal(self, start_date: Optional[str] = None, end_date: Optional[str] = None) -> pd.DataFrame:
        """Fetch the exchange trading calendar.

        Args:
            start_date: First date as ``YYYYMMDD``; defaults to today.
            end_date: Last date as ``YYYYMMDD``; defaults to 30 days from now.

        Returns:
            The ``pro.trade_cal`` result, or an empty DataFrame on failure.
        """
        try:
            # Single clock read keeps both defaults consistent.
            now = datetime.now()
            if not start_date:
                start_date = now.strftime('%Y%m%d')
            if not end_date:
                end_date = (now + timedelta(days=30)).strftime('%Y%m%d')

            return self.pro.trade_cal(start_date=start_date, end_date=end_date)
        except Exception as e:
            logger.error(f"Failed to get trade calendar: {e}")
            return pd.DataFrame()