stock-ai-agent/backend/app/crypto_agent/setup_policy.py
2026-04-27 11:47:27 +08:00

126 lines
5.8 KiB
Python

"""
市场状态到交易行为的硬约束策略
"""
from collections import defaultdict
from typing import Any, Dict, List, Tuple
class SetupPolicy:
    """Trading-behaviour constraint policy.

    Applies a regime profile as a hard filter over candidate signals: a signal
    is kept only if its lane (timeframe/type) and its inferred setup type are
    both allowed by the profile. Kept signals are annotated with
    ``setup_type``, ``setup_basis`` and ``entry_basis``.
    """

    # Setup types grouped by regime family. Not referenced inside this class;
    # presumably consumed by callers / regime-profile builders — verify.
    RANGE_ONLY_SETUPS = {"range_reversal"}
    TREND_ONLY_SETUPS = {"trend_continuation_pullback", "deep_pullback_continuation", "trend_reversal"}
    BREAKOUT_ONLY_SETUPS = {"breakout_confirmation", "breakout_pullback"}

    def filter_signals(
        self,
        signals: List[Dict[str, Any]],
        regime_profile: Dict[str, Any],
    ) -> Tuple[List[Dict[str, Any]], List[str], Dict[str, int]]:
        """Keep only the signals that the current regime profile allows.

        Args:
            signals: Candidate signal dicts; ``None`` is treated as empty.
            regime_profile: Reads ``allowed_lanes``, ``allowed_setups``,
                ``tradability`` (defaults to ``"avoid"``) and, when trading
                is blocked outright, ``no_trade_reasons``.

        Returns:
            ``(kept, reasons, reason_counts)``:
            ``kept`` holds shallow copies of the allowed signals enriched with
            ``setup_type``, ``setup_basis`` and ``entry_basis``;
            ``reasons`` lists human-readable rejection messages in order;
            ``reason_counts`` tallies rejections by a machine-readable key.
        """
        allowed_lanes = set(regime_profile.get("allowed_lanes") or [])
        allowed_setups = set(regime_profile.get("allowed_setups") or [])
        tradability = regime_profile.get("tradability", "avoid")
        signals = signals or []
        reasons: List[str] = []
        reason_counts: Dict[str, int] = defaultdict(int)

        # The whole regime is untradeable (or nothing is whitelisted):
        # reject every signal at once.
        if tradability == "avoid" or not allowed_lanes or not allowed_setups:
            base_reasons = regime_profile.get("no_trade_reasons") or ["当前市场状态不允许交易"]
            # Guard: list.extend() on a bare string would add it char by char.
            if isinstance(base_reasons, str):
                base_reasons = [base_reasons]
            reasons.extend(base_reasons)
            # Record at least one rejection even when no signals were supplied.
            reason_counts["tradability_avoid"] += max(1, len(signals))
            return [], reasons, dict(reason_counts)

        kept: List[Dict[str, Any]] = []
        for signal in signals:
            lane = signal.get("timeframe") or signal.get("type") or "unknown"
            if lane not in allowed_lanes:
                reasons.append(f"{lane} 不在允许交易周期内")
                reason_counts[f"lane_blocked:{lane}"] += 1
                continue
            # Only classify signals whose lane passed the first gate.
            setup_type = self._infer_setup_type(signal)
            if setup_type not in allowed_setups:
                reasons.append(f"{setup_type} 不在允许 setup 内")
                reason_counts[f"setup_blocked:{setup_type}"] += 1
                continue
            # Basis strings are only built for signals that survive both gates.
            kept.append({
                **signal,
                "setup_type": setup_type,
                "setup_basis": self._build_setup_basis(signal, setup_type),
                "entry_basis": self._build_entry_basis(signal, setup_type),
            })
        return kept, reasons, dict(reason_counts)

    def _infer_setup_type(self, signal: Dict[str, Any]) -> str:
        """Classify a signal into a setup type; the first matching rule wins.

        Returns one of ``range_reversal``, ``breakout_confirmation``,
        ``breakout_pullback``, ``trend_continuation_pullback``,
        ``deep_pullback_continuation``, ``trend_reversal`` or ``unknown``.
        Volume/price fields are read from ``volume_price_context`` first and
        fall back to same-named top-level keys on the signal.
        """
        # NOTE(review): the lane fallback here is "medium_term", while
        # filter_signals falls back to "unknown" — confirm this is intended.
        lane = signal.get("timeframe") or signal.get("type") or "medium_term"
        action = signal.get("action")
        entry_type = signal.get("entry_type", "market")
        location_tag = ((signal.get("market_location") or {}).get("location_tag") or "unknown")
        regime = signal.get("regime", "")
        trend_stage = (signal.get("trend_stage") or {}).get("stage") or "unknown"
        volume_context = signal.get("volume_price_context") or {}
        breakout_quality = volume_context.get("breakout_quality") or signal.get("breakout_quality")
        pullback_quality = volume_context.get("pullback_quality") or signal.get("pullback_quality")
        rejection_signal = volume_context.get("rejection_signal") or signal.get("rejection_signal")
        volume_price_state = volume_context.get("volume_price_state") or signal.get("volume_price_state")

        # A ranging regime, or price sitting at a range edge, takes precedence
        # over every other rule.
        if regime == "ranging" or location_tag in {"near_range_support", "near_range_resistance"}:
            return "range_reversal"
        # Transitional regime with accepted breakout evidence, entered at market.
        if regime == "transitional" and (
            breakout_quality in {"acceptance_breakout_up", "acceptance_breakout_down"} or
            volume_price_state in {"bullish_acceptance", "bearish_acceptance"}
        ) and entry_type == "market":
            return "breakout_confirmation"
        # Transitional regime entered with a resting limit order.
        if regime == "transitional" and entry_type == "limit":
            return "breakout_pullback"
        if lane == "medium_term" and entry_type == "limit" and pullback_quality == "healthy_pullback":
            return "trend_continuation_pullback"
        if lane == "short_term" and entry_type == "limit" and location_tag in {"near_long_zone", "near_short_zone"} and pullback_quality == "healthy_pullback":
            return "deep_pullback_continuation"
        if lane == "medium_term" and action in {"buy", "sell"} and (
            rejection_signal in {"bullish_rejection", "bearish_rejection"} or trend_stage == "early"
        ):
            return "trend_reversal"
        return "unknown"

    def _build_setup_basis(self, signal: Dict[str, Any], setup_type: str) -> str:
        """Build a short ``" | "``-joined description of why the setup applies.

        Includes ``setup=<type>`` (unless the type is ``unknown``), the
        location tag (unless missing/``unknown``), and any informative
        volume/price fields. At most four parts are kept. Falls back to
        ``setup=<type>`` when nothing else is available.
        """
        market_location = signal.get("market_location") or {}
        volume_context = signal.get("volume_price_context") or {}
        parts: List[str] = []
        location_tag = market_location.get("location_tag")
        if location_tag and location_tag != "unknown":
            parts.append(f"location={location_tag}")
        for key in ("volume_price_state", "breakout_quality", "pullback_quality", "rejection_signal"):
            value = volume_context.get(key) or signal.get(key)
            # Skip placeholder values that carry no information.
            if value and value not in {"none", "neutral", "unknown"}:
                parts.append(f"{key}={value}")
        if setup_type != "unknown":
            parts.insert(0, f"setup={setup_type}")
        return " | ".join(parts[:4]) if parts else f"setup={setup_type}"

    def _build_entry_basis(self, signal: Dict[str, Any], setup_type: str) -> str:
        """Return a short tag describing how the entry is justified for ``setup_type``."""
        entry_type = signal.get("entry_type", "market")
        market_location = signal.get("market_location") or {}
        if setup_type == "breakout_confirmation":
            return "breakout_acceptance_follow_through"
        if setup_type in {"breakout_pullback", "trend_continuation_pullback", "deep_pullback_continuation"}:
            return "pullback_into_trade_zone" if entry_type == "limit" else "pullback_confirmed"
        if setup_type == "range_reversal":
            # `or` (not a .get default) so a present-but-None/empty tag also
            # falls back, instead of producing "reversal_from_None".
            location_tag = market_location.get("location_tag") or "range_edge"
            return f"reversal_from_{location_tag}"
        if setup_type == "trend_reversal":
            return "rejection_or_structure_shift"
        return "generic_entry"