astock-agent/backend/app/llm/strategy_iteration.py
2026-04-28 12:46:10 +08:00

354 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""策略复盘迭代 Agent
基于推荐生命周期表现,输出可审查的策略调整建议。
不直接修改策略参数,只给出建议和证据。
"""
import json
import logging
from collections import defaultdict
from datetime import datetime
from app.config import settings
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
async def build_strategy_iteration_report(limit: int = 50, include_llm: bool = False) -> dict:
    """Build the strategy review report for the most recent recommendations.

    The rule-based report is always produced. When ``include_llm`` is true,
    ``settings.deepseek_api_key`` is set and at least one recommendation row
    was loaded, an LLM analysis is requested. If the LLM returns non-empty
    text, it is stored in ``ai_analysis`` and ``generated_by`` becomes
    ``"rules+llm"``.

    Args:
        limit: maximum number of recent recommendations to load.
        include_llm: whether to request the optional LLM analysis.

    Returns:
        The report dict from ``_build_rule_report``, possibly augmented with
        the LLM analysis.
    """
    recent = await _load_recent_tracking(limit)
    report = _build_rule_report(recent)

    wants_llm = include_llm and settings.deepseek_api_key and recent
    if not wants_llm:
        return report

    analysis = await _generate_ai_iteration(report, recent)
    if analysis:
        report.update(ai_analysis=analysis, generated_by="rules+llm")
    return report
async def build_strategy_feedback_controls(limit: int = 50) -> dict:
    """Derive bounded strategy feedback controls from recent recommendation outcomes.

    Loads up to ``limit`` recent recommendations, builds the rule-based review
    report from them, and converts that report's adjustment suggestions into
    control deltas via ``_derive_feedback_controls``.
    """
    return _derive_feedback_controls(
        _build_rule_report(await _load_recent_tracking(limit))
    )
async def _load_recent_tracking(limit: int) -> list[dict]:
    """Load the ``limit`` newest recommendations joined with their latest tracking row.

    Each recommendation is LEFT JOINed with its ``recommendation_tracking``
    row that has the highest ``id``. Recommendations without any tracking
    row are still returned, with NULL tracking fields. Optional columns that
    do not exist in the current schema are replaced by SQL default
    expressions (see ``_column_or_default``).

    Args:
        limit: maximum number of recommendations, bound as the ``:limit``
            query parameter.

    Returns:
        One plain dict per row, ordered by ``created_at`` descending.
    """
    from sqlalchemy import text
    from app.db.database import get_db

    async with get_db() as db:
        # Probe the live schema so the query also works when optional columns are missing.
        rec_cols = await _get_table_columns(db, "recommendations")
        trk_cols = await _get_table_columns(db, "recommendation_tracking")

        action_plan = _column_or_default(rec_cols, "action_plan", "'观察'", "r")
        position_score = _column_or_default(rec_cols, "position_score", "50", "r")
        lifecycle_status = _column_or_default(rec_cols, "lifecycle_status", "'candidate'", "r")
        # If max return/drawdown columns are absent, fall back to the latest return itself.
        max_return = _column_or_default(trk_cols, "max_return_pct", "t.pct_from_entry", "t")
        max_drawdown = _column_or_default(trk_cols, "max_drawdown_pct", "t.pct_from_entry", "t")
        days_since = _column_or_default(trk_cols, "days_since_recommendation", "0", "t")
        close_reason = _column_or_default(trk_cols, "close_reason", "''", "t")
        review_note = _column_or_default(trk_cols, "review_note", "''", "t")

        select_list = (
            "SELECT r.id, r.ts_code, r.name, r.sector, r.strategy, r.entry_signal_type, "
            f"{action_plan} AS action_plan, r.score, r.market_temp_score, r.sector_score, "
            f"{position_score} AS position_score, {lifecycle_status} AS lifecycle_status, r.created_at, "
            f"t.pct_from_entry, {max_return} AS max_return_pct, {max_drawdown} AS max_drawdown_pct, "
            f"{days_since} AS days_since_recommendation, {close_reason} AS close_reason, "
            f"{review_note} AS review_note, t.track_date "
        )
        # Derived table ``t``: only the newest tracking row (highest id) per recommendation.
        latest_tracking_join = (
            "LEFT JOIN ("
            " SELECT t.* FROM recommendation_tracking t "
            " INNER JOIN ("
            " SELECT recommendation_id, MAX(id) AS max_id "
            " FROM recommendation_tracking GROUP BY recommendation_id"
            " ) latest ON t.id = latest.max_id"
            ") t ON t.recommendation_id = r.id "
        )
        sql = (
            select_list
            + "FROM recommendations r "
            + latest_tracking_join
            + "ORDER BY r.created_at DESC LIMIT :limit"
        )

        fetched = await db.execute(text(sql), {"limit": limit})
        return [dict(record._mapping) for record in fetched.fetchall()]
async def _get_table_columns(db, table_name: str) -> set[str]:
    """Return the column names of ``table_name`` using ``PRAGMA table_info``.

    ``table_name`` is interpolated directly into the statement, so it must be
    a trusted identifier; the callers in this module pass string literals.
    """
    from sqlalchemy import text

    pragma = text(f"PRAGMA table_info({table_name})")
    info = await db.execute(pragma)
    return set(column._mapping["name"] for column in info.fetchall())
def _column_or_default(columns: set[str], column_name: str, default_sql: str, alias: str = "") -> str:
if column_name in columns:
return f"{alias}.{column_name}" if alias else column_name
return default_sql
def _build_rule_report(rows: list[dict]) -> dict:
    """Build the rule-based (non-LLM) strategy review report.

    Only rows whose ``pct_from_entry`` is not None (i.e. that have tracking
    data) feed the statistics. The full ``rows`` list is used only for the
    total count quoted in the summary. With no rows at all, a fixed
    "no samples yet" report is returned.

    Returns:
        A dict with ``generated_at``, ``sample_size`` (number of tracked
        rows), ``summary``, ``strategy_stats``, ``signal_stats``,
        ``failure_patterns``, ``adjustment_suggestions``, an empty
        ``ai_analysis`` and ``generated_by`` set to ``"rules"``.
    """
    if not rows:
        return {
            "generated_at": datetime.now().isoformat(),
            "sample_size": 0,
            "summary": "暂无可复盘的推荐样本。",
            "strategy_stats": [],
            "signal_stats": [],
            "failure_patterns": ["样本不足,先积累推荐生命周期数据。"],
            "adjustment_suggestions": [],
            "ai_analysis": "",
            "generated_by": "rules",
        }

    tracked = [row for row in rows if row.get("pct_from_entry") is not None]
    tracked_count = len(tracked)

    by_strategy = _group_stats(tracked, "strategy")
    by_signal = _group_stats(tracked, "entry_signal_type")
    patterns = _detect_failure_patterns(tracked)
    suggestions = _build_adjustment_suggestions(by_strategy, by_signal, patterns, tracked_count)

    winners = len([row for row in tracked if (row.get("pct_from_entry") or 0) > 0])
    avg_return = _avg([row.get("pct_from_entry") for row in tracked])
    avg_drawdown = _avg([row.get("max_drawdown_pct") for row in tracked])
    win_rate = round(winners / tracked_count * 100, 1) if tracked_count else 0

    summary = (
        f"最近 {len(rows)} 条推荐中,{tracked_count} 条已有跟踪数据;"
        f"胜率 {win_rate}%,平均收益 {avg_return}%,平均最大回撤 {avg_drawdown}%。"
    )
    return {
        "generated_at": datetime.now().isoformat(),
        "sample_size": tracked_count,
        "summary": summary,
        "strategy_stats": by_strategy,
        "signal_stats": by_signal,
        "failure_patterns": patterns,
        "adjustment_suggestions": suggestions,
        "ai_analysis": "",
        "generated_by": "rules",
    }
def _group_stats(rows: list[dict], key: str) -> list[dict]:
    """Aggregate per-group outcome statistics for tracked recommendations.

    Rows are bucketed by ``row[key]``; missing or falsy keys go into the
    ``"unknown"`` bucket. For each bucket the result holds:

    * its size;
    * its win rate, i.e. the percentage of rows with a positive
      ``pct_from_entry``, rounded to 1 decimal;
    * the averages of ``pct_from_entry``, ``max_return_pct`` and
      ``max_drawdown_pct`` (via ``_avg``);
    * how many rows have ``close_reason`` equal to ``"hit_target"`` and to
      ``"hit_stop_loss"``.

    Returns:
        Group summaries sorted by ``count`` and then ``avg_return``, both
        descending. The sort is stable, so ties keep first-seen order.
    """
    buckets: dict[str, list[dict]] = defaultdict(list)
    for row in rows:
        buckets[row.get(key) or "unknown"].append(row)

    summaries = []
    for label, members in buckets.items():
        size = len(members)
        winners = sum((m.get("pct_from_entry") or 0) > 0 for m in members)
        reasons = [m.get("close_reason") for m in members]
        summaries.append({
            "name": label,
            "count": size,
            "win_rate": round(winners / size * 100, 1) if size else 0,
            "avg_return": _avg([m.get("pct_from_entry") for m in members]),
            "avg_max_return": _avg([m.get("max_return_pct") for m in members]),
            "avg_max_drawdown": _avg([m.get("max_drawdown_pct") for m in members]),
            "hit_target": reasons.count("hit_target"),
            "hit_stop": reasons.count("hit_stop_loss"),
        })
    return sorted(summaries, key=lambda s: (s["count"], s["avg_return"]), reverse=True)
def _detect_failure_patterns(rows: list[dict]) -> list[str]:
patterns = []
if not rows:
return ["暂无跟踪样本。"]
weak_market_losses = [
r for r in rows
if (r.get("market_temp_score") or 0) < 45 and (r.get("pct_from_entry") or 0) < 0
]
if len(weak_market_losses) >= 2:
patterns.append("弱势市场中仍有亏损推荐,低温环境下应进一步减少 BUY 或提高确认门槛。")
high_position_losses = [
r for r in rows
if (r.get("position_score") or 50) < 40 and (r.get("pct_from_entry") or 0) < 0
]
if len(high_position_losses) >= 2:
patterns.append("位置安全分偏低的推荐亏损较多,追高惩罚需要增强。")
stop_losses = [r for r in rows if r.get("close_reason") == "hit_stop_loss"]
if len(stop_losses) >= 2:
patterns.append("触发止损样本偏多,需要复查止损位置和入场触发条件是否过宽。")
expired_flat = [
r for r in rows
if r.get("close_reason") in ("review_expired_flat", "review_expired_loss")
]
if len(expired_flat) >= 3:
patterns.append("多只推荐到期未形成有效进攻,观察池转可操作的条件需要更严格。")
if not patterns:
patterns.append("暂无明显集中失败模式,继续积累样本并按策略分组观察。")
return patterns
def _build_adjustment_suggestions(
strategy_stats: list[dict],
signal_stats: list[dict],
failure_patterns: list[str],
sample_size: int,
) -> list[dict]:
suggestions = []
if sample_size < 10:
return [{
"target": "全局策略",
"action": "observe",
"reason": "跟踪样本少于10条暂不建议调整参数。",
"confidence": "low",
}]
for stat in strategy_stats:
if stat["count"] >= 3 and stat["win_rate"] < 40 and stat["avg_return"] < 0:
suggestions.append({
"target": stat["name"],
"action": "tighten",
"reason": f"{stat['name']} 胜率{stat['win_rate']}%,平均收益{stat['avg_return']}%,建议提高买入门槛。",
"confidence": "medium",
})
elif stat["count"] >= 3 and stat["win_rate"] >= 60 and stat["avg_return"] > 1:
suggestions.append({
"target": stat["name"],
"action": "promote",
"reason": f"{stat['name']} 近期表现较好,可在相似市场环境下优先使用。",
"confidence": "medium",
})
for stat in signal_stats:
if stat["count"] >= 3 and stat["avg_max_drawdown"] < -5:
suggestions.append({
"target": stat["name"],
"action": "reduce",
"reason": f"{stat['name']} 平均最大回撤{stat['avg_max_drawdown']}%,建议降低排序权重或增加位置过滤。",
"confidence": "medium",
})
if any("弱势市场" in p for p in failure_patterns):
suggestions.append({
"target": "defensive_watch",
"action": "tighten",
"reason": "弱势市场亏损样本集中,防守策略下应只保留观察池,减少 BUY。",
"confidence": "high",
})
if not suggestions:
suggestions.append({
"target": "全局策略",
"action": "keep",
"reason": "当前样本未显示需要立即调整的集中问题。",
"confidence": "medium",
})
return suggestions[:6]
def _derive_feedback_controls(report: dict) -> dict:
suggestions = report.get("adjustment_suggestions", []) or []
sample_size = int(report.get("sample_size") or 0)
controls = {
"sample_size": sample_size,
"enabled": sample_size >= 10,
"buy_threshold_delta": 0,
"max_position_pct_delta": 0,
"actionable_limit_delta": 0,
"watch_limit_delta": 0,
"force_defensive": False,
"notes": [],
}
if sample_size < 10:
controls["notes"].append("样本不足,暂不启用自动回写。")
return controls
promote_count = 0
tighten_count = 0
reduce_count = 0
for item in suggestions[:6]:
action = item.get("action")
reason = item.get("reason", "")
if action == "promote":
promote_count += 1
controls["buy_threshold_delta"] -= 1
controls["watch_limit_delta"] += 1
elif action == "tighten":
tighten_count += 1
controls["buy_threshold_delta"] += 1
controls["actionable_limit_delta"] -= 1
controls["max_position_pct_delta"] -= 5
elif action == "reduce":
reduce_count += 1
controls["buy_threshold_delta"] += 1
controls["watch_limit_delta"] -= 1
if "弱势市场" in reason or item.get("target") == "defensive_watch":
controls["force_defensive"] = True
controls["buy_threshold_delta"] = max(-2, min(3, controls["buy_threshold_delta"]))
controls["max_position_pct_delta"] = max(-10, min(5, controls["max_position_pct_delta"]))
controls["actionable_limit_delta"] = max(-2, min(1, controls["actionable_limit_delta"]))
controls["watch_limit_delta"] = max(-2, min(2, controls["watch_limit_delta"]))
if controls["force_defensive"]:
controls["notes"].append("最近弱市亏损样本偏多,优先启用防守约束。")
elif tighten_count > promote_count:
controls["notes"].append("最近失效样本偏多,整体建议略收紧。")
elif promote_count > 0 and reduce_count == 0:
controls["notes"].append("最近有效样本改善,可适度放宽观察与出手空间。")
else:
controls["notes"].append("最近样本无明显单边倾向,仅做轻微校正。")
return controls
async def _generate_ai_iteration(rule_report: dict, rows: list[dict]) -> str:
    """Ask the LLM for a short strategy-iteration commentary on the review data.

    Sends the full rule report plus a compact sample of the first 20 rows
    (name, strategy, signal, returns, drawdown, close reason, market
    temperature, position score) to ``chat_completion``.

    The LLM step is optional. Any error raised by the call is therefore
    logged and treated as "no analysis", instead of failing the whole report
    that already has a rules-only result.

    Args:
        rule_report: report dict from the rule-based review.
        rows: recommendation rows the report was built from.

    Returns:
        The stripped response text, or ``""`` when the call fails or returns
        no content.
    """
    from app.llm.client import chat_completion

    sample = [
        {
            "name": r.get("name"),
            "strategy": r.get("strategy"),
            "signal": r.get("entry_signal_type"),
            "return": r.get("pct_from_entry"),
            "max_return": r.get("max_return_pct"),
            "drawdown": r.get("max_drawdown_pct"),
            "reason": r.get("close_reason"),
            "market_temp": r.get("market_temp_score"),
            "position_score": r.get("position_score"),
        }
        for r in rows[:20]
    ]
    user_msg = f"""请基于以下推荐复盘数据,输出策略迭代建议。
要求:
1. 明确指出最该收紧、保留、加强的策略或信号;
2. 只提出可执行调整建议,不要泛泛而谈;
3. 不要承诺收益;
4. 180字以内。
规则复盘:
{json.dumps(rule_report, ensure_ascii=False)}
样本:
{json.dumps(sample, ensure_ascii=False)}
"""
    try:
        resp = await chat_completion([
            {"role": "system", "content": "你是一位A股策略复盘研究员负责基于推荐结果提出保守、可验证的策略迭代建议。"},
            {"role": "user", "content": user_msg},
        ])
    except Exception:
        # Best-effort enhancement: log the failure and let the caller fall
        # back to the rules-only report.
        logger.exception("Strategy iteration LLM analysis failed; returning rules-only report")
        return ""
    return resp.content.strip() if resp and resp.content else ""
def _avg(values: list) -> float:
clean = [float(v) for v in values if v is not None]
if not clean:
return 0
return round(sum(clean) / len(clean), 2)