stock-ai-agent/backend/app/astock_agent/tushare_stock_selector.py
2026-02-27 09:54:17 +08:00

245 lines
8.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Leading-stock selection (Tushare version).

Picks the leading stocks out of the sectors that show unusual movement.
"""
import pandas as pd
from typing import Dict, List
from datetime import datetime
from app.utils.logger import logger
class TushareStockSelector:
    """Pick the leading stocks of a sector using data from a Tushare client.

    Sector constituents are joined with their quotes and daily indicators,
    each stock is scored by ``_calculate_score`` and the ``top_n`` highest
    scoring stocks are returned as plain dictionaries.
    """

    # Cap on the number of constituents queried per sector, to avoid
    # oversized requests.
    MAX_MEMBERS = 50

    # Tier tables: ``(threshold, outcome)`` pairs checked top-down; the first
    # pair whose threshold is <= the value wins.
    _CHANGE_TIERS = ((7, 40), (5, 35), (3, 30), (2, 25), (1, 20))
    _AMOUNT_TIERS = (
        (1_000_000_000, 30),  # >= 1 billion
        (500_000_000, 25),    # >= 500 million
        (100_000_000, 20),    # >= 100 million
        (50_000_000, 15),     # >= 50 million
        (10_000_000, 10),     # >= 10 million
    )
    _SPEED_TIERS = ((5, 20), (3, 15), (1, 10))
    _TURNOVER_TIERS = (
        (15, 10), (10, 9), (7, 8), (5, 7), (3, 6), (1, 4), (0.5, 2),
    )
    _SPEED_LABELS = ((5, "⚡⚡⚡ 极快"), (3, "⚡⚡ 快速"), (1, "⚡ 较快"))

    def __init__(self, tushare_client, top_n: int = 3):
        """Store the data client and the size of the result list.

        Args:
            tushare_client: Client object; this class calls its
                ``get_sector_members``, ``get_realtime_data`` and
                ``get_stock_daily_basic`` methods.
            top_n: Number of leading stocks to return per sector.
        """
        self.top_n = top_n
        self.ts_client = tushare_client

    def select_leading_stocks(self, ts_code: str, sector_name: str) -> List[Dict]:
        """Return the top-scoring stocks of one sector.

        Args:
            ts_code: Index code of the sector.
            sector_name: Sector name, used only in log messages.

        Returns:
            Up to ``top_n`` dictionaries sorted by descending score, with
            the keys ``code``, ``name``, ``price``, ``change_pct``,
            ``change_amount``, ``amount``, ``turnover``, ``volume_ratio``,
            ``amplitude``, ``score`` and ``speed_level``. An empty list is
            returned when data is missing or any step fails; failures are
            logged, never raised.
        """
        try:
            members_df = self.ts_client.get_sector_members(ts_code)
            if members_df is None or members_df.empty:
                logger.warning(f"获取板块 {sector_name} 成分股失败")
                return []

            # The member table identifies stocks by ``con_code``; the same
            # codes are used to query quotes.
            stock_codes = members_df['con_code'].tolist()[:self.MAX_MEMBERS]

            realtime_df = self.ts_client.get_realtime_data(stock_codes)
            if realtime_df is None or realtime_df.empty:
                logger.warning(f"获取板块 {sector_name} 成分股行情失败")
                return []

            # Daily indicators (turnover rate, volume ratio) for today.
            # They are optional: when unavailable, defaults are used.
            trade_date = datetime.now().strftime('%Y%m%d')
            basic_df = self.ts_client.get_stock_daily_basic(stock_codes, trade_date)

            merged = self._merge_frames(members_df, realtime_df, basic_df)
            if merged.empty:
                return []

            results = self._rank_candidates(merged)
            if not results:
                return []

            logger.info(f"板块 {sector_name} 龙头股筛选完成Top {len(results)}")
            return results
        except Exception as e:
            # Deliberate best-effort boundary: one failing sector must not
            # break the caller, so the error is logged and swallowed.
            logger.error(f"筛选龙头股失败 {sector_name}: {e}")
            return []

    @staticmethod
    def _merge_frames(members_df, realtime_df, basic_df) -> pd.DataFrame:
        """Join members, quotes and (optionally) daily indicators by stock code.

        ``con_code`` in the member table and ``ts_code`` in the other two
        tables are both renamed to ``stock_code`` before joining. Members
        without a quote are dropped (inner join); quotes without indicators
        are kept (left join). ``basic_df`` may be ``None``, empty, or lack
        some indicator columns.
        """
        members = members_df.rename(columns={'con_code': 'stock_code'})
        quotes = realtime_df.rename(columns={'ts_code': 'stock_code'})
        merged = pd.merge(
            members[['stock_code', 'con_name']],
            quotes,
            on='stock_code',
            how='inner',
        )
        if basic_df is None or basic_df.empty:
            return merged

        basic = basic_df.rename(columns={'ts_code': 'stock_code'})
        extra = [c for c in ('turnover_rate', 'volume_ratio') if c in basic.columns]
        if 'stock_code' not in basic.columns or not extra:
            return merged
        return pd.merge(
            merged,
            basic[['stock_code'] + extra],
            on='stock_code',
            how='left',
        )

    @staticmethod
    def _prepare_numeric(merged: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of ``merged`` with the scoring columns made numeric.

        ``close``, ``pct_chg``, ``change``, ``vol`` and ``amount`` are
        required; ``high`` and ``low`` are converted when present. Values
        that cannot be parsed become NaN. Missing or unparsable
        ``turnover_rate`` / ``volume_ratio`` fall back to 0.0 / 1.0.
        """
        frame = merged.copy()
        for col in ('close', 'pct_chg', 'change', 'vol', 'amount'):
            frame[col] = pd.to_numeric(frame[col], errors='coerce')
        for col in ('high', 'low'):
            if col in frame.columns:
                frame[col] = pd.to_numeric(frame[col], errors='coerce')

        # The daily interface reports ``amount`` in thousands of yuan;
        # convert to yuan so the scoring thresholds apply directly.
        frame['amount'] = frame['amount'] * 1000

        for col, default in (('turnover_rate', 0.0), ('volume_ratio', 1.0)):
            if col in frame.columns:
                frame[col] = pd.to_numeric(frame[col], errors='coerce').fillna(default)
            else:
                frame[col] = default
        return frame

    def _rank_candidates(self, merged: pd.DataFrame) -> List[Dict]:
        """Score the merged rows and return the best ``top_n`` as dicts."""
        frame = self._prepare_numeric(merged)

        # Rows without a usable price or percent change cannot be scored
        # meaningfully and would leak NaN into the output.
        frame = frame.dropna(subset=['close', 'pct_chg'])
        # Keep only stocks that actually traded.
        frame = frame[frame['amount'] > 0].copy()
        if frame.empty:
            return []

        frame['score'] = frame.apply(self._calculate_score, axis=1)
        # Scores are bucketed, so ties are common; a stable sort keeps the
        # incoming order among equal scores instead of an arbitrary one.
        top_stocks = frame.sort_values(
            'score', ascending=False, kind='mergesort'
        ).head(self.top_n)
        return [self._build_result(row) for _, row in top_stocks.iterrows()]

    def _speed_level(self, change_pct) -> str:
        """Map a percent change to its display label."""
        return self._first_tier(change_pct, self._SPEED_LABELS, "🐌 平稳")

    def _build_result(self, row: pd.Series) -> Dict:
        """Convert one scored row into the output dictionary."""
        # Amplitude: intraday range relative to the low, in percent.
        # Stays 0.0 when high/low are absent, invalid, or low is not positive.
        amplitude = 0.0
        high = row.get('high')
        low = row.get('low')
        if pd.notna(high) and pd.notna(low) and low > 0:
            amplitude = float((high - low) / low * 100)

        return {
            'code': row['stock_code'],
            'name': row['con_name'],
            'price': float(row['close']),
            'change_pct': float(row['pct_chg']),
            'change_amount': float(row['change']),
            'amount': float(row['amount']),
            'turnover': float(row.get('turnover_rate', 0)),
            'volume_ratio': float(row.get('volume_ratio', 1.0)),
            'amplitude': amplitude,
            'score': float(row['score']),
            'speed_level': self._speed_level(row['pct_chg']),
        }

    @staticmethod
    def _first_tier(value, tiers, default):
        """Return the outcome of the first tier with ``value >= threshold``.

        ``default`` is returned when no tier matches (including NaN input,
        for which every comparison is false).
        """
        for threshold, outcome in tiers:
            if value >= threshold:
                return outcome
        return default

    def _calculate_score(self, row: pd.Series) -> float:
        """Compute the composite score of one stock (maximum 100).

        Components:
            - percent change: up to 40 points
            - traded amount (in yuan): up to 30 points
            - speed, approximated by percent change: up to 20 points
            - turnover rate: up to 10 points

        Args:
            row: Mapping with ``pct_chg`` and ``amount``; ``turnover_rate``
                is optional and treated as 0 when missing, ``None`` or NaN.

        Returns:
            The composite score.
        """
        score = 0.0

        # 1. Percent change (40). Falling stocks keep a base score that
        #    shrinks by 5 points per percent lost and bottoms out at 0.
        change_pct = row['pct_chg']
        change_points = self._first_tier(change_pct, self._CHANGE_TIERS, None)
        if change_points is None:
            if change_pct > 0:
                change_points = 15
            else:
                change_points = max(0, 10 + change_pct * 5)
        score += change_points

        # 2. Traded amount (30); expected in yuan.
        score += self._first_tier(row['amount'], self._AMOUNT_TIERS, 5)

        # 3. Speed (20); simplified to reuse the percent change.
        score += self._first_tier(change_pct, self._SPEED_TIERS, 5)

        # 4. Turnover rate (10).
        turnover_rate = row.get('turnover_rate', 0)
        if pd.isna(turnover_rate):
            turnover_rate = 0
        score += self._first_tier(turnover_rate, self._TURNOVER_TIERS, 1)

        return score