astock-agent/backend/app/llm/strategy_selector.py
2026-04-28 13:15:11 +08:00

258 lines
9.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""动态策略选择器
在固定筛选引擎前增加一层“先选打法,再选股票”的策略决策。
规则负责稳定分类,LLM 负责补充语义判断和操作建议。
"""
import json
import logging
from pydantic import BaseModel
from app.config import settings
from app.data.models import MarketTemperature, SectorInfo
logger = logging.getLogger(__name__)
class StrategyProfile(BaseModel):
    """Parameter set describing one strategy template.

    Instances are produced by ``get_strategy_profile_by_id`` and may be
    adjusted afterwards by the feedback and LLM steps in this module.
    """

    # --- Identity -------------------------------------------------------
    strategy_id: str
    name: str
    description: str

    # --- Signal ordering and scoring -----------------------------------
    # Entry signal names listed in order of preference.
    entry_signal_priority: list[str]
    # Weight per scoring dimension; the templates in this module use the
    # keys "supply_demand", "price_action" and "trend".
    score_weights: dict[str, float]
    # Presumably the floor a candidate must reach to be kept at all, and
    # the higher bar for a buy decision -- TODO confirm against the
    # screening engine that consumes this profile.
    min_score: float
    buy_threshold: float

    # --- Position and list sizing --------------------------------------
    max_position_pct: float
    allow_trading: bool = True
    actionable_limit: int = 2
    watch_limit: int = 4
    target_focus_sectors: int = 2

    # --- Human-readable guidance ---------------------------------------
    market_stance: str = ""
    decision_note: str = ""
    notes: list[str] = []

    # --- Provenance ------------------------------------------------------
    # Set by the feedback step when it modifies the profile.
    feedback_applied: bool = False
    feedback_notes: list[str] = []
    # Tag recording which layers produced the profile ("rules" by default).
    generated_by: str = "rules"
def get_strategy_profile_by_id(strategy_id: str) -> StrategyProfile:
    """Build a fresh strategy template for ``strategy_id``.

    A falsy id and any unknown id resolve to ``defensive_watch``; the id
    ``trend_breakout`` is treated as an alias of ``breakout_attack``.
    Actionable and watch list sizes are capped by
    ``settings.actionable_limit`` / ``settings.watch_limit`` (floored at 0).

    Args:
        strategy_id: Identifier of the requested template.

    Returns:
        A new ``StrategyProfile``; callers may mutate it freely.
    """
    key = strategy_id or "defensive_watch"
    if key == "trend_breakout":
        key = "breakout_attack"

    max_actionable = max(0, settings.actionable_limit)
    max_watch = max(0, settings.watch_limit)

    # Keyword arguments per template; only the requested one is turned
    # into a model instance below.
    templates = {
        "breakout_attack": {
            "name": "主线突破",
            "description": "市场偏强,优先寻找主线板块内的突破和突破确认。",
            "entry_signal_priority": [
                "breakout", "breakout_confirm", "launch", "pullback", "reversal",
            ],
            "score_weights": {"supply_demand": 0.45, "price_action": 0.35, "trend": 0.20},
            "min_score": 62,
            "buy_threshold": 66,
            "max_position_pct": 30,
            "allow_trading": True,
            "actionable_limit": min(3, max_actionable),
            "watch_limit": min(4, max_watch),
            "target_focus_sectors": 2,
            "market_stance": "主线进攻",
            "decision_note": "只处理最强主线前排,不扩散到跟风和后排。",
            "notes": ["优先做主线早中期板块", "放量突破优先于回踩低吸"],
        },
        "pullback_rotation": {
            "name": "回踩轮动",
            "description": "市场震荡分化,优先做回踩支撑和板块轮动中的低吸确认。",
            "entry_signal_priority": [
                "pullback", "breakout_confirm", "launch", "breakout", "reversal",
            ],
            "score_weights": {"supply_demand": 0.40, "price_action": 0.30, "trend": 0.30},
            "min_score": 60,
            "buy_threshold": 63,
            "max_position_pct": 20,
            "allow_trading": True,
            "actionable_limit": min(2, max_actionable),
            "watch_limit": min(5, max_watch),
            "target_focus_sectors": 2,
            "market_stance": "轮动低吸",
            "decision_note": "先等回踩承接和板块回流,再决定是否出手。",
            "notes": ["降低追高仓位", "更看重位置安全和回踩承接"],
        },
        "launch_probe": {
            "name": "启动试错",
            "description": "市场偏弱,适合少量观察启动型和反转型机会,不做强追涨。",
            "entry_signal_priority": [
                "launch", "reversal", "pullback", "breakout_confirm", "breakout",
            ],
            "score_weights": {"supply_demand": 0.35, "price_action": 0.35, "trend": 0.30},
            "min_score": 58,
            "buy_threshold": 61,
            "max_position_pct": 10,
            "allow_trading": True,
            "actionable_limit": min(1, max_actionable),
            "watch_limit": min(4, max_watch),
            "target_focus_sectors": 1,
            "market_stance": "轻仓试错",
            "decision_note": "只有极少数启动确认标的值得小仓试错。",
            "notes": ["仅做小仓位试错", "突破型需要更强板块一致性才可介入"],
        },
        "defensive_watch": {
            "name": "防守观察",
            "description": "市场退潮,系统以观察池为主,不主动扩大出手。",
            "entry_signal_priority": [
                "pullback", "launch", "reversal", "breakout_confirm", "breakout",
            ],
            "score_weights": {"supply_demand": 0.35, "price_action": 0.40, "trend": 0.25},
            "min_score": 56,
            "buy_threshold": 64,
            "max_position_pct": 5,
            "allow_trading": False,
            "actionable_limit": 0,
            "watch_limit": min(3, max_watch),
            "target_focus_sectors": 1,
            "market_stance": "防守观察",
            "decision_note": "今天不主动出手,只保留少量观察名单。",
            "notes": ["原则上只保留观察池", "等待市场温度修复后再转入主动进攻"],
        },
    }

    spec = templates.get(key)
    if spec is None:
        # Unknown ids degrade to the most conservative template.
        key = "defensive_watch"
        spec = templates[key]
    return StrategyProfile(strategy_id=key, **spec)
async def select_strategy_profile(
    market_temp: MarketTemperature | None,
    hot_sectors: list[SectorInfo],
    intraday: bool,
) -> StrategyProfile:
    """Choose the strategy profile for the current market snapshot.

    Steps, in order: pick a template by rule, pass it through
    ``_apply_strategy_feedback``, and -- only when
    ``settings.deepseek_api_key`` is set -- ask ``_select_llm_profile``;
    a truthy LLM result replaces the profile.

    Args:
        market_temp: Market temperature snapshot, or ``None`` if unavailable.
        hot_sectors: Hot sectors; the helpers only look at the first five.
        intraday: Forwarded unchanged to the rule and LLM helpers.

    Returns:
        The final ``StrategyProfile``.
    """
    rule_based = _select_rule_profile(market_temp, hot_sectors, intraday)
    adjusted = await _apply_strategy_feedback(rule_based)

    # Without an API key the LLM step is skipped entirely.
    if not settings.deepseek_api_key:
        return adjusted

    llm_choice = await _select_llm_profile(market_temp, hot_sectors, intraday, adjusted)
    return llm_choice if llm_choice else adjusted
def _select_rule_profile(
    market_temp: MarketTemperature | None,
    hot_sectors: list[SectorInfo],
    intraday: bool,
) -> StrategyProfile:
    """Pick a strategy template from market temperature and sector stages.

    Only the first five entries of ``hot_sectors`` are inspected. A missing
    ``market_temp`` counts as temperature 0. ``intraday`` is accepted for
    signature parity with the other selectors and is not used here.

    Decision ladder (first match wins):
        temperature >= 65 and at least one "early" sector -> breakout_attack
        temperature >= 45 and fewer than two "late"/"end"  -> pullback_rotation
        temperature >= 30                                  -> launch_probe
        otherwise                                          -> defensive_watch
    """
    temperature = market_temp.temperature if market_temp else 0

    stages = [sector.stage for sector in hot_sectors[:5]]
    n_early = len([stage for stage in stages if stage == "early"])
    n_late = len([stage for stage in stages if stage in ("late", "end")])

    if temperature >= 65 and n_early >= 1:
        chosen_id = "breakout_attack"
    elif temperature >= 45 and n_late < 2:
        chosen_id = "pullback_rotation"
    elif temperature >= 30:
        chosen_id = "launch_probe"
    else:
        chosen_id = "defensive_watch"
    return get_strategy_profile_by_id(chosen_id)
async def _apply_strategy_feedback(profile: StrategyProfile) -> StrategyProfile:
    """Adjust ``profile`` using the strategy-feedback controls.

    The controls mapping comes from ``build_strategy_feedback_controls``.
    If building it raises, or the controls are missing or not ``enabled``,
    the input profile is returned unchanged. Otherwise a deep copy is
    adjusted and returned; the input is never mutated.

    Keys read from the controls: ``enabled``, ``force_defensive``,
    ``buy_threshold_delta``, ``max_position_pct_delta``,
    ``actionable_limit_delta``, ``watch_limit_delta`` and ``notes``.

    Args:
        profile: Profile to adjust.

    Returns:
        The original profile, or an adjusted copy with
        ``feedback_applied=True`` and ``"+feedback"`` appended to
        ``generated_by``.
    """
    # Imported lazily, as in the original code (presumably to avoid an
    # import cycle -- TODO confirm).
    from app.llm.strategy_iteration import build_strategy_feedback_controls

    try:
        controls = await build_strategy_feedback_controls(limit=50)
    except Exception as e:
        # Feedback is best-effort: a failure must not block strategy selection.
        logger.debug("策略反馈控制生成失败: %s", e)
        return profile
    if not controls or not controls.get("enabled"):
        return profile

    def _delta(key: str) -> int:
        # Missing / None / 0 all mean "no adjustment".
        return int(controls.get(key) or 0)

    force_defensive = bool(controls.get("force_defensive"))

    updated = profile.model_copy(deep=True)
    updated.feedback_applied = True
    if force_defensive:
        updated.allow_trading = False
        updated.actionable_limit = 0
        updated.watch_limit = min(updated.watch_limit, 3)
        updated.max_position_pct = min(updated.max_position_pct, 10)
        updated.market_stance = "防守观察"

    updated.buy_threshold = max(
        updated.min_score,
        min(updated.buy_threshold + _delta("buy_threshold_delta"), 80),
    )
    updated.max_position_pct = max(
        0,
        min(updated.max_position_pct + _delta("max_position_pct_delta"), 40),
    )
    updated.actionable_limit = max(
        0,
        min(updated.actionable_limit + _delta("actionable_limit_delta"), settings.actionable_limit),
    )
    updated.watch_limit = max(
        1,
        min(updated.watch_limit + _delta("watch_limit_delta"), settings.watch_limit),
    )

    if force_defensive:
        # Re-assert the defensive caps: a positive delta above must not
        # re-open actionable slots or lift the position/watch ceilings
        # while trading is disabled.
        updated.actionable_limit = 0
        updated.max_position_pct = min(updated.max_position_pct, 10)
        updated.watch_limit = min(updated.watch_limit, 3)

    notes = controls.get("notes") or []
    if notes:
        updated.feedback_notes = notes[:3]
        updated.notes.extend(notes[:2])
        updated.decision_note = notes[0]
    updated.generated_by = f"{updated.generated_by}+feedback"
    return updated
async def _select_llm_profile(
    market_temp: MarketTemperature | None,
    hot_sectors: list[SectorInfo],
    intraday: bool,
    fallback: StrategyProfile,
) -> StrategyProfile | None:
    """Ask the LLM to pick one of the four strategy templates.

    The prompt contains the market temperature figures, the intraday flag
    and the first five hot sectors. The reply must be a JSON object with
    ``strategy_id``, optional ``notes`` and optional
    ``buy_threshold_delta``.

    When the LLM confirms the strategy already held by ``fallback``, the
    result is built on a copy of ``fallback`` so earlier adjustments are
    kept. When it picks a different strategy, a fresh template is used.

    Args:
        market_temp: Market temperature snapshot, or ``None``.
        hot_sectors: Hot sectors; only the first five are sent.
        intraday: Whether the selection runs during the trading session.
        fallback: Profile chosen before the LLM step; never mutated.

    Returns:
        ``None`` when the reply is empty or names an unknown strategy;
        ``fallback`` when the LLM call or reply parsing fails; otherwise a
        new profile with ``"+llm"`` appended to ``generated_by``.
    """
    from app.llm.client import chat_completion

    sector_text = "\n".join(
        f"- {s.sector_name}: 涨幅{s.pct_change}%, 热度{s.heat_score}, 阶段{s.stage}, 涨停{s.limit_up_count}"
        for s in hot_sectors[:5]
    ) or "暂无板块数据"
    # The intraday line previously rendered an empty string for both
    # branches, so the flag never reached the model.
    user_msg = f"""你需要为今日A股环境选择一个短线策略模板。
市场温度: {market_temp.temperature if market_temp else 0}
上涨家数: {market_temp.up_count if market_temp else 0}
下跌家数: {market_temp.down_count if market_temp else 0}
涨停数: {market_temp.limit_up_count if market_temp else 0}
炸板率: {market_temp.broken_rate if market_temp else 0}
盘中模式: {'是' if intraday else '否'}
热门板块:
{sector_text}
规则候选策略:
- breakout_attack: 主线突破
- pullback_rotation: 回踩轮动
- launch_probe: 启动试错
- defensive_watch: 防守观察
请输出 JSON格式:
{{
"strategy_id": "上面四选一",
"notes": ["两条以内理由"],
"buy_threshold_delta": -3到3之间的整数
}}
"""
    try:
        resp = await chat_completion([
            {"role": "system", "content": "你是一位A股短线策略研究员只能在给定策略模板中选择不要发明新策略。回复必须是 JSON。"},
            {"role": "user", "content": user_msg},
        ])
    except Exception as e:
        # The LLM is advisory; an outage must not break strategy selection.
        logger.debug("LLM 策略选择调用失败: %s", e)
        return fallback
    if not resp or not resp.content:
        return None

    try:
        data = json.loads(_extract_json_object(resp.content))
        if not isinstance(data, dict):
            raise ValueError("LLM reply is not a JSON object")

        strategy_id = data.get("strategy_id")
        if strategy_id not in {"breakout_attack", "pullback_rotation", "launch_probe", "defensive_watch"}:
            return None

        if strategy_id == fallback.strategy_id:
            # LLM agrees with the incoming profile: keep its adjustments.
            selected = fallback.model_copy(deep=True)
        else:
            # NOTE(review): adjustments made to ``fallback`` are not carried
            # over when the LLM overrides the strategy -- confirm this is
            # the intended precedence.
            selected = get_strategy_profile_by_id(strategy_id)

        # A null/missing delta means "no change"; the range is enforced
        # here because the model may ignore the prompt's -3..3 instruction.
        delta = max(-3, min(3, int(data.get("buy_threshold_delta") or 0)))
        # Same bounds the feedback step uses for buy_threshold.
        selected.buy_threshold = max(
            selected.min_score,
            min(selected.buy_threshold + delta, 80),
        )

        raw_notes = data.get("notes") or []
        if isinstance(raw_notes, str):
            # A bare string would otherwise be extended character by character.
            raw_notes = [raw_notes]
        elif not isinstance(raw_notes, list):
            raw_notes = []
        selected.notes.extend(str(note) for note in raw_notes[:2])

        selected.generated_by = f"{selected.generated_by}+llm"
        return selected
    except Exception as e:
        logger.debug("LLM 策略选择解析失败: %s", e)
        return fallback


def _extract_json_object(text: str) -> str:
    """Return the outermost ``{...}`` span of ``text``, or ``text`` unchanged.

    Lets replies wrapped in markdown code fences or surrounding prose still
    parse as JSON.
    """
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end <= start:
        return text
    return text[start:end + 1]