126 lines
5.8 KiB
Python
126 lines
5.8 KiB
Python
"""
Hard-constraint policy mapping market state to permitted trading behavior.
"""
|
|
from collections import defaultdict
|
|
from typing import Any, Dict, List, Tuple
|
|
|
|
|
|
class SetupPolicy:
    """Policy that constrains which trade signals may be acted on.

    Each signal is classified into a setup type from its own fields, then
    kept only if both its lane (timeframe) and its setup type are permitted
    by the supplied regime profile.
    """

    # Setup-type names grouped into range / trend / breakout sets. They are
    # not referenced by the methods of this class; presumably consumed by
    # callers elsewhere -- verify before removing.
    RANGE_ONLY_SETUPS = {"range_reversal"}
    TREND_ONLY_SETUPS = {
        "trend_continuation_pullback",
        "deep_pullback_continuation",
        "trend_reversal",
    }
    BREAKOUT_ONLY_SETUPS = {"breakout_confirmation", "breakout_pullback"}

    def filter_signals(
        self,
        signals: List[Dict[str, Any]],
        regime_profile: Dict[str, Any],
    ) -> Tuple[List[Dict[str, Any]], List[str], Dict[str, int]]:
        """Filter ``signals`` down to those the regime profile allows.

        Args:
            signals: Candidate signal dicts; ``None`` or empty is accepted.
            regime_profile: Dict read for ``allowed_lanes``,
                ``allowed_setups``, ``tradability`` (defaults to ``"avoid"``)
                and ``no_trade_reasons``.

        Returns:
            A ``(kept, reasons, reason_counts)`` tuple. ``kept`` holds copies
            of the accepted signals extended with ``setup_type``,
            ``setup_basis`` and ``entry_basis``. ``reasons`` holds one
            human-readable message per rejection (or the profile's no-trade
            reasons). ``reason_counts`` maps a rejection key to the number
            of signals rejected for it.
        """
        allowed_lanes = set(regime_profile.get("allowed_lanes") or [])
        allowed_setups = set(regime_profile.get("allowed_setups") or [])
        tradability = regime_profile.get("tradability", "avoid")
        reasons: List[str] = []
        reason_counts: Dict[str, int] = defaultdict(int)

        # Whole-regime veto: nothing is tradable, so every signal is dropped.
        if tradability == "avoid" or not allowed_lanes or not allowed_setups:
            base_reasons = regime_profile.get("no_trade_reasons") or ["当前市场状态不允许交易"]
            reasons.extend(base_reasons)
            # Count at least one block even when no signals were supplied.
            reason_counts["tradability_avoid"] += max(1, len(signals or []))
            return [], reasons, dict(reason_counts)

        kept: List[Dict[str, Any]] = []
        for signal in signals or []:
            lane = signal.get("timeframe") or signal.get("type") or "unknown"
            if lane not in allowed_lanes:
                reasons.append(f"{lane} 不在允许交易周期内")
                reason_counts[f"lane_blocked:{lane}"] += 1
                continue

            setup_type = self._infer_setup_type(signal)
            if setup_type not in allowed_setups:
                reasons.append(f"{setup_type} 不在允许 setup 内")
                reason_counts[f"setup_blocked:{setup_type}"] += 1
                continue

            # Descriptive strings are built only for signals that survive
            # both checks; rejected signals never need them.
            kept.append({
                **signal,
                "setup_type": setup_type,
                "setup_basis": self._build_setup_basis(signal, setup_type),
                "entry_basis": self._build_entry_basis(signal, setup_type),
            })

        return kept, reasons, dict(reason_counts)

    def _infer_setup_type(self, signal: Dict[str, Any]) -> str:
        """Classify ``signal`` into a setup-type name.

        Rules are evaluated in order and the first match wins:
        range reversal, breakout confirmation, breakout pullback,
        trend-continuation pullback, deep-pullback continuation,
        trend reversal. Returns ``"unknown"`` when no rule matches.
        """
        lane = signal.get("timeframe") or signal.get("type") or "medium_term"
        action = signal.get("action")
        entry_type = signal.get("entry_type", "market")
        location_tag = (signal.get("market_location") or {}).get("location_tag") or "unknown"
        regime = signal.get("regime", "")
        trend_stage = (signal.get("trend_stage") or {}).get("stage") or "unknown"

        # Volume/price fields may live in the nested context or directly on
        # the signal; the nested value takes precedence when truthy.
        volume_context = signal.get("volume_price_context") or {}
        breakout_quality = volume_context.get("breakout_quality") or signal.get("breakout_quality")
        pullback_quality = volume_context.get("pullback_quality") or signal.get("pullback_quality")
        rejection_signal = volume_context.get("rejection_signal") or signal.get("rejection_signal")
        volume_price_state = volume_context.get("volume_price_state") or signal.get("volume_price_state")

        if regime == "ranging" or location_tag in {"near_range_support", "near_range_resistance"}:
            return "range_reversal"

        if regime == "transitional":
            accepted_breakout = (
                breakout_quality in {"acceptance_breakout_up", "acceptance_breakout_down"}
                or volume_price_state in {"bullish_acceptance", "bearish_acceptance"}
            )
            if accepted_breakout and entry_type == "market":
                return "breakout_confirmation"
            if entry_type == "limit":
                return "breakout_pullback"

        healthy_limit_pullback = entry_type == "limit" and pullback_quality == "healthy_pullback"
        if lane == "medium_term" and healthy_limit_pullback:
            return "trend_continuation_pullback"
        if (
            lane == "short_term"
            and healthy_limit_pullback
            and location_tag in {"near_long_zone", "near_short_zone"}
        ):
            return "deep_pullback_continuation"

        if lane == "medium_term" and action in {"buy", "sell"} and (
            rejection_signal in {"bullish_rejection", "bearish_rejection"}
            or trend_stage == "early"
        ):
            return "trend_reversal"

        return "unknown"

    def _build_setup_basis(self, signal: Dict[str, Any], setup_type: str) -> str:
        """Return a short ``" | "``-joined summary of why the setup applies.

        The summary leads with ``setup=<type>`` (when the type is known),
        followed by the location tag and any informative volume/price
        fields. At most four segments are kept.
        """
        market_location = signal.get("market_location") or {}
        volume_context = signal.get("volume_price_context") or {}
        parts: List[str] = []

        location_tag = market_location.get("location_tag")
        if location_tag and location_tag != "unknown":
            parts.append(f"location={location_tag}")

        for key in ("volume_price_state", "breakout_quality", "pullback_quality", "rejection_signal"):
            value = volume_context.get(key) or signal.get(key)
            # Skip placeholder values that carry no information.
            if value and value not in {"none", "neutral", "unknown"}:
                parts.append(f"{key}={value}")

        if setup_type != "unknown":
            parts.insert(0, f"setup={setup_type}")

        return " | ".join(parts[:4]) if parts else f"setup={setup_type}"

    def _build_entry_basis(self, signal: Dict[str, Any], setup_type: str) -> str:
        """Return a label describing the entry rationale for ``setup_type``.

        Unrecognised setup types map to ``"generic_entry"``.
        """
        if setup_type == "breakout_confirmation":
            return "breakout_acceptance_follow_through"

        if setup_type in {"breakout_pullback", "trend_continuation_pullback", "deep_pullback_continuation"}:
            entry_type = signal.get("entry_type", "market")
            return "pullback_into_trade_zone" if entry_type == "limit" else "pullback_confirmed"

        if setup_type == "range_reversal":
            market_location = signal.get("market_location") or {}
            # `or` rather than a .get default: a tag that is present but
            # None/empty must also fall back, otherwise the label becomes
            # "reversal_from_None".
            location_tag = market_location.get("location_tag") or "range_edge"
            return f"reversal_from_{location_tag}"

        if setup_type == "trend_reversal":
            return "rejection_or_structure_shift"

        return "generic_entry"