stock-ai-agent/backend/app/services/tushare_service.py
2026-02-03 10:08:15 +08:00

264 lines
7.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Tushare数据服务
封装Tushare API调用
"""
import tushare as ts
import pandas as pd
from typing import Optional, List
from datetime import datetime, timedelta
from app.config import get_settings
from app.utils.logger import logger
from app.utils.validators import normalize_stock_code
class TushareService:
    """Thin wrapper around the Tushare Pro API.

    Every public method returns ``None`` instead of raising when the client
    is not initialised, the stock code is invalid, no data is found, or the
    underlying API call fails; such failures are logged.
    """

    # Default look-back window (calendar days) used by ``get_kline_data`` when
    # no start date is given. ~90 calendar days is roughly 60 trading days of
    # daily bars; weekly/monthly bars need a longer window to produce a useful
    # number of candles (90 days would give only ~13 weekly or ~3 monthly bars).
    _DEFAULT_LOOKBACK_DAYS = {'D': 90, 'W': 730, 'M': 1825}

    # Columns requested from ``stock_basic`` for basic info and for search.
    _BASIC_FIELDS = 'ts_code,symbol,name,area,industry,market,list_date'

    # Numeric columns copied from a daily bar into the quote dict (in order).
    _QUOTE_FIELDS = ('open', 'high', 'low', 'close', 'pre_close',
                     'change', 'pct_chg', 'vol', 'amount')

    # Numeric columns copied from each bar into the K-line dicts (in order).
    _KLINE_FIELDS = ('open', 'high', 'low', 'close', 'vol', 'amount')

    # Maximum number of fuzzy-search results returned by ``search_stock``.
    _MAX_SEARCH_RESULTS = 5

    def __init__(self):
        """Create the Tushare Pro client from application settings.

        If no token is configured, ``self.pro`` is set to ``None`` and every
        data method logs an error and returns ``None``.
        """
        settings = get_settings()
        if not settings.tushare_token:
            logger.warning("Tushare token未配置")
            self.pro = None
        else:
            ts.set_token(settings.tushare_token)
            self.pro = ts.pro_api()
            logger.info("Tushare服务初始化成功")

    def _is_ready(self) -> bool:
        """Return True if the Pro client exists; otherwise log and return False."""
        if not self.pro:
            logger.error("Tushare服务未初始化")
            return False
        return True

    @staticmethod
    def _to_ts_code(stock_code: str) -> Optional[str]:
        """Normalize *stock_code* to Tushare form; log and return None if invalid."""
        ts_code = normalize_stock_code(stock_code)
        if not ts_code:
            logger.error(f"无效的股票代码: {stock_code}")
            return None
        return ts_code

    @staticmethod
    def _safe_float(value: object) -> Optional[float]:
        """Convert a cell value to ``float``; missing values (None/NaN) become None.

        NaN is not valid JSON, so mapping it to ``None`` keeps the returned
        dicts serializable and treats missing data uniformly across methods.
        """
        if value is None or pd.isna(value):
            return None
        return float(value)

    def get_realtime_quote(self, stock_code: str) -> Optional[dict]:
        """Return the most recent daily bar for a stock, plus its name.

        Note: this reads Tushare's ``daily`` endpoint, so the result is the
        newest end-of-day bar available, not an intraday tick.

        Args:
            stock_code: Stock code in any form accepted by ``normalize_stock_code``.

        Returns:
            Dict with ``ts_code``, ``name``, ``trade_date`` and the numeric
            fields open/high/low/close/pre_close/change/pct_chg/vol/amount
            (floats, or ``None`` where the source value is missing), or
            ``None`` on any failure.
        """
        if not self._is_ready():
            return None
        try:
            ts_code = self._to_ts_code(stock_code)
            if not ts_code:
                return None

            # No date range is passed; the newest row is selected below.
            df = self.pro.daily(ts_code=ts_code, start_date='', end_date='')
            if df is None or df.empty:
                logger.warning(f"未找到股票数据: {ts_code}")
                return None

            # Pick the newest bar explicitly rather than trusting API row order.
            latest = df.sort_values('trade_date', ascending=False).iloc[0]

            stock_info = self.pro.stock_basic(ts_code=ts_code, fields='ts_code,name')
            name = None
            if stock_info is not None and not stock_info.empty:
                name = stock_info.iloc[0]['name']

            quote = {
                'ts_code': ts_code,
                'name': name,
                'trade_date': latest['trade_date'],
            }
            for field in self._QUOTE_FIELDS:
                quote[field] = self._safe_float(latest[field])
            return quote
        except Exception as e:
            logger.error(f"获取实时行情失败: {e}")
            return None

    def get_kline_data(
        self,
        stock_code: str,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        period: str = 'D'
    ) -> Optional[List[dict]]:
        """Return K-line (candlestick) bars for a stock, oldest first.

        Args:
            stock_code: Stock code in any form accepted by ``normalize_stock_code``.
            start_date: Start date ``YYYYMMDD``. Defaults to a period-dependent
                look-back window ending at ``end_date``.
            end_date: End date ``YYYYMMDD``. Defaults to today.
            period: Bar period: ``'D'`` (daily), ``'W'`` (weekly) or ``'M'``
                (monthly), case-insensitive.

        Returns:
            List of dicts with ``ts_code``, ``trade_date`` and
            open/high/low/close/vol/amount (``None`` where missing), sorted
            ascending by ``trade_date``; or ``None`` on failure / no data.
        """
        if not self._is_ready():
            return None
        try:
            ts_code = self._to_ts_code(stock_code)
            if not ts_code:
                return None

            fetchers = {
                'D': self.pro.daily,
                'W': self.pro.weekly,
                'M': self.pro.monthly,
            }
            period_key = period.upper() if isinstance(period, str) else period
            fetch = fetchers.get(period_key)
            if fetch is None:
                logger.error(f"不支持的周期: {period}")
                return None

            if not end_date:
                end_date = datetime.now().strftime('%Y%m%d')
            if not start_date:
                # Anchor the default window on end_date so that passing only a
                # historical end_date does not produce start_date > end_date.
                try:
                    anchor = datetime.strptime(end_date, '%Y%m%d')
                except (TypeError, ValueError):
                    anchor = datetime.now()
                lookback = timedelta(days=self._DEFAULT_LOOKBACK_DAYS[period_key])
                start_date = (anchor - lookback).strftime('%Y%m%d')

            df = fetch(ts_code=ts_code, start_date=start_date, end_date=end_date)
            if df is None or df.empty:
                logger.warning(f"未找到K线数据: {ts_code}")
                return None

            kline_data = []
            for record in df.sort_values('trade_date').to_dict('records'):
                bar = {'ts_code': ts_code, 'trade_date': record['trade_date']}
                for field in self._KLINE_FIELDS:
                    bar[field] = self._safe_float(record[field])
                kline_data.append(bar)
            return kline_data
        except Exception as e:
            logger.error(f"获取K线数据失败: {e}")
            return None

    def get_stock_basic(self, stock_code: str) -> Optional[dict]:
        """Return basic listing information for a stock.

        Args:
            stock_code: Stock code in any form accepted by ``normalize_stock_code``.

        Returns:
            Dict with ts_code, symbol, name, area, industry, market and
            list_date as returned by ``stock_basic``, or ``None`` on failure /
            no data.
        """
        if not self._is_ready():
            return None
        try:
            ts_code = self._to_ts_code(stock_code)
            if not ts_code:
                return None

            df = self.pro.stock_basic(ts_code=ts_code, fields=self._BASIC_FIELDS)
            if df is None or df.empty:
                return None

            info = df.iloc[0]
            return {key: info[key] for key in self._BASIC_FIELDS.split(',')}
        except Exception as e:
            logger.error(f"获取股票基本信息失败: {e}")
            return None

    def search_stock(self, keyword: str) -> Optional[List[dict]]:
        """Search stocks by code or name.

        Matching order:

        1. Exact match on the bare symbol (``600000``) or the full ts_code
           (``600000.SH``, case-insensitive). Returns one result.
        2. Substring match on the name.
        3. Substring match on the symbol.

        Stages 2 and 3 return at most ``_MAX_SEARCH_RESULTS`` results. Matching
        is literal, so keywords such as ``*ST`` are not treated as regexes.

        Args:
            keyword: Search text, either a stock name fragment or a code.
                Surrounding whitespace is ignored.

        Returns:
            List of dicts with the ``stock_basic`` columns, or ``None`` if the
            keyword is empty, nothing matches, or the API call fails.
        """
        if not self._is_ready():
            return None
        try:
            keyword = '' if keyword is None else str(keyword).strip()
            if not keyword:
                # An empty pattern would match every stock; treat as no query.
                return None

            df = self.pro.stock_basic(fields=self._BASIC_FIELDS)
            if df is None or df.empty:
                return None

            exact = df[(df['symbol'] == keyword) | (df['ts_code'] == keyword.upper())]
            if not exact.empty:
                return exact.head(1).to_dict('records')

            # regex=False: match literally. Otherwise "*ST" raises re.error
            # and characters such as "(" or "." are misinterpreted.
            for column in ('name', 'symbol'):
                hits = df[df[column].str.contains(keyword, regex=False, na=False)]
                if not hits.empty:
                    return hits.head(self._MAX_SEARCH_RESULTS).to_dict('records')
            return None
        except Exception as e:
            logger.error(f"搜索股票失败: {e}")
            return None
# Module-level shared instance, created once when this module is imported.
tushare_service: TushareService = TushareService()