alphax/app/db/strategy_direction_repair.py
2026-06-02 06:49:48 +08:00

200 lines
8.1 KiB
Python

"""Repair recommendations whose strategy identity or evidence conflicts with trade side."""
from __future__ import annotations
import json
from datetime import datetime
from app.core.signal_direction import excluded_factor_delta, sanitize_factor_breakdown_for_side, sanitize_signals_for_side
from app.core.signal_taxonomy import signal_codes, signal_labels
from app.core.strategy_registry import BREAKDOWN_RETEST_SHORT_1H_STRATEGY, MAIN_COMPOSITE_STRATEGY, is_strategy_allowed_for_side, strategy_label
from app.core.trade_direction import trade_side_from_payload
from app.db.postgres_connection import connect
def _loads(value) -> dict:
if isinstance(value, dict):
return dict(value)
if not value:
return {}
try:
parsed = json.loads(value)
return parsed if isinstance(parsed, dict) else {}
except Exception:
return {}
def _loads_list(value) -> list:
if isinstance(value, list):
return list(value)
if not value:
return []
try:
parsed = json.loads(value)
return parsed if isinstance(parsed, list) else []
except Exception:
return [str(value)]
def _dumps(value: dict) -> str:
return json.dumps(value or {}, ensure_ascii=False, default=str)
def _has_short_breakdown_context(*payloads: dict) -> bool:
for payload in payloads:
if not isinstance(payload, dict):
continue
short_ctx = payload.get("short_breakdown_retest_1h")
if isinstance(short_ctx, dict) and short_ctx.get("detected"):
return True
nested = payload.get("market_context")
if isinstance(nested, dict):
short_ctx = nested.get("short_breakdown_retest_1h")
if isinstance(short_ctx, dict) and short_ctx.get("detected"):
return True
return False
def repair_strategy_direction_mismatches(limit: int = 500, dry_run: bool = False) -> dict:
    """Re-align recommendation rows whose strategy or evidence contradicts their trade side.

    Scans up to ``limit`` rows of ``recommendation`` with a non-empty
    ``strategy_code``, newest ``id`` first. ``limit`` is clamped to 1..5000, and
    a falsy value falls back to 500. For each row the side is resolved with
    ``trade_side_from_payload``, and then:

    * If ``is_strategy_allowed_for_side`` rejects the stored strategy, it is
      replaced by ``BREAKDOWN_RETEST_SHORT_1H_STRATEGY``. This happens when the
      side is ``"short"`` and a detected short breakdown-retest context exists.
      Otherwise it is replaced by ``MAIN_COMPOSITE_STRATEGY``.
    * For ``"short"`` rows, signals and factor-breakdown items flagged by the
      ``sanitize_*_for_side`` helpers are removed. The removal is recorded in
      ``market_context.direction_conflict_filter`` and as (at most three)
      ``direction_conflict:<signal>`` risk flags in the decision log.

    Rows that need neither change are skipped. For every repaired row:

    * ``rec_score`` is recomputed as
      ``max(0, round(rec_score - delta * 100 / 30))``.
    * The ``signals``, ``signal_codes_json`` and ``signal_labels_json`` columns
      are rewritten from ``signal_labels(signals)``.
    * The JSON payload columns are rewritten.

    All writes happen in one transaction. It is committed, or it is rolled back
    when ``dry_run`` is true or an error occurs. The connection is always
    closed.

    Args:
        limit: Maximum number of rows to scan.
        dry_run: Compute and report repairs without persisting them.

    Returns:
        A dict with ``scanned``, ``repaired_count``, ``dry_run`` and ``items``.
        ``items`` holds one summary per repaired row.

    Raises:
        Any exception raised by the database layer or the helper functions is
        re-raised after the transaction is rolled back.
    """
    limit = max(1, min(int(limit or 500), 5000))
    repaired: list[dict] = []
    scanned = 0
    now = datetime.now().isoformat()
    conn = connect()
    try:
        # The SELECT runs inside the try so the connection is closed even if it fails.
        rows = conn.execute(
            """
            SELECT id, symbol, direction, strategy_code, signals, rec_score,
                   entry_plan_json, market_context_json, strategy_snapshot_json, factor_roles_json
            FROM recommendation
            WHERE strategy_code IS NOT NULL AND strategy_code != ''
            ORDER BY id DESC
            LIMIT %s
            """,
            (limit,),
        ).fetchall()
        for row in rows:
            scanned += 1
            item = dict(row)
            entry_plan = _loads(item.get("entry_plan_json"))
            market_context = _loads(item.get("market_context_json"))
            snapshot = _loads(item.get("strategy_snapshot_json"))
            factor_roles = _loads(item.get("factor_roles_json"))
            side = trade_side_from_payload(entry_plan, snapshot, item.get("direction"))
            old_code = str(item.get("strategy_code") or "").strip()
            new_code = old_code
            reasons: list[str] = []

            # 1) Strategy identity must be permitted for the resolved side.
            if not is_strategy_allowed_for_side(old_code, side):
                if side == "short" and _has_short_breakdown_context(entry_plan, market_context, snapshot):
                    new_code = BREAKDOWN_RETEST_SHORT_1H_STRATEGY
                    reasons.append("short_breakdown_context")
                else:
                    new_code = MAIN_COMPOSITE_STRATEGY
                    reasons.append("fallback_composite_direction_mismatch")

            # 2) Short rows: strip evidence that conflicts with the short side.
            signals = _loads_list(item.get("signals"))
            removed_signals: list[str] = []
            removed_factors: list[dict] = []
            removed_factor_codes: list[str] = []
            score_delta_to_remove = 0.0
            if side == "short":
                clean_signals, removed_signals = sanitize_signals_for_side(signals, "short")
                breakdown = _loads(entry_plan.get("factor_score_breakdown")) or _loads(
                    market_context.get("factor_score_breakdown")
                )
                score_delta_to_remove = excluded_factor_delta(breakdown.get("items") or [], "short")
                clean_breakdown, removed_factors = sanitize_factor_breakdown_for_side(breakdown, "short")
                # Assumes removed factor entries are dicts carrying "factor_code" — TODO confirm.
                removed_factor_codes = [str(x.get("factor_code") or "") for x in removed_factors]
                if removed_signals or removed_factors:
                    reasons.append("short_direction_signal_cleanup")
                    # NOTE(review): when every signal conflicts, the original list is kept
                    # (presumably to avoid an empty signal set), so removed signals can
                    # still reach the written `signals` column — confirm this is intended.
                    signals = clean_signals or signals
                    entry_plan["factor_score_breakdown"] = clean_breakdown
                    market_context["factor_score_breakdown"] = clean_breakdown
                    market_context["direction_conflict_filter"] = {
                        "side": side,
                        "removed_signals": removed_signals,
                        "removed_factor_codes": removed_factor_codes,
                    }
                    decision_log = (
                        market_context.get("decision_log")
                        if isinstance(market_context.get("decision_log"), dict)
                        else {}
                    )
                    risk_flags = list(decision_log.get("risk_flags") or [])
                    # Only the first three removed signals become risk flags.
                    for sig in removed_signals[:3]:
                        flag = f"direction_conflict:{sig}"
                        if flag not in risk_flags:
                            risk_flags.append(flag)
                    decision_log["risk_flags"] = risk_flags
                    market_context["decision_log"] = decision_log
                    entry_plan["decision_log"] = decision_log

            if not reasons:
                continue

            entry_plan["side"] = side
            entry_plan["strategy_code"] = new_code
            entry_plan["strategy_direction_repair"] = {
                "old_strategy_code": old_code,
                "new_strategy_code": new_code,
                "side": side,
                "reason": ",".join(reasons),
                "repaired_at": now,
            }
            snapshot["strategy_code"] = new_code
            snapshot["strategy_name"] = strategy_label(new_code)
            snapshot["direction"] = side
            # Repaired entry-plan fields must override the stale copy held in the
            # snapshot; keys that exist only in the snapshot copy are preserved.
            snapshot["entry_plan"] = {**_loads(snapshot.get("entry_plan")), **entry_plan}

            try:
                # NOTE(review): the 100/30 factor presumably rescales a 0-30 factor
                # delta onto the 0-100 rec_score scale — confirm against the scorer.
                rec_score = max(
                    0,
                    round(float(item.get("rec_score") or 0) - (score_delta_to_remove * 100.0 / 30.0)),
                )
            except (TypeError, ValueError, OverflowError):
                rec_score = item.get("rec_score") or 0

            labels = signal_labels(signals)
            codes = signal_codes(labels)
            reason_text = ",".join(reasons)
            repaired.append(
                {
                    "id": item["id"],
                    "symbol": item["symbol"],
                    "side": side,
                    "old_strategy_code": old_code,
                    "new_strategy_code": new_code,
                    "reason": reason_text,
                    "removed_signals": removed_signals,
                    "removed_factor_codes": removed_factor_codes,
                }
            )
            if not dry_run:
                conn.execute(
                    """
                    UPDATE recommendation
                    SET strategy_code=%s,
                        rec_score=%s,
                        signals=%s,
                        signal_codes_json=%s,
                        signal_labels_json=%s,
                        entry_plan_json=%s,
                        market_context_json=%s,
                        strategy_snapshot_json=%s,
                        factor_roles_json=%s
                    WHERE id=%s
                    """,
                    (
                        new_code,
                        rec_score,
                        json.dumps(labels, ensure_ascii=False),
                        json.dumps(codes, ensure_ascii=False),
                        json.dumps(labels, ensure_ascii=False),
                        _dumps(entry_plan),
                        _dumps(market_context),
                        _dumps(snapshot),
                        _dumps(factor_roles),
                        item["id"],
                    ),
                )
        if dry_run:
            conn.rollback()
        else:
            conn.commit()
    except Exception:
        conn.rollback()
        raise
    finally:
        conn.close()
    return {"scanned": scanned, "repaired_count": len(repaired), "dry_run": dry_run, "items": repaired}