354 lines
14 KiB
Python
354 lines
14 KiB
Python
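"""策略复盘迭代 Agent (strategy review & iteration agent)

基于推荐生命周期表现,输出可审查的策略调整建议。
不直接修改策略参数,只给出建议和证据。

Produces reviewable strategy-adjustment suggestions from the lifecycle
performance of past recommendations. It never modifies strategy parameters
directly; it only emits suggestions together with their supporting evidence.
"""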
|
||
|
||
import json
|
||
import logging
|
||
from collections import defaultdict
|
||
from datetime import datetime
|
||
|
||
from app.config import settings
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
async def build_strategy_iteration_report(limit: int = 50, include_llm: bool = False) -> dict:
    """Build the strategy review report for the most recent recommendations.

    The report is always produced by deterministic rules. When ``include_llm``
    is true, a DeepSeek API key is configured and at least one row was loaded,
    an LLM-written commentary is additionally attached as ``ai_analysis``.

    Args:
        limit: Maximum number of recent recommendations to review.
        include_llm: Whether to try adding LLM commentary to the report.

    Returns:
        The rule report dict; ``generated_by`` becomes ``"rules+llm"`` only
        when non-empty LLM commentary was attached.
    """
    tracking_rows = await _load_recent_tracking(limit)
    report = _build_rule_report(tracking_rows)

    wants_llm = include_llm and settings.deepseek_api_key and tracking_rows
    if not wants_llm:
        return report

    commentary = await _generate_ai_iteration(report, tracking_rows)
    if commentary:
        report["ai_analysis"] = commentary
        report["generated_by"] = "rules+llm"
    return report
|
||
|
||
|
||
async def build_strategy_feedback_controls(limit: int = 50) -> dict:
    """Derive bounded strategy feedback controls from recent recommendation outcomes.

    Loads the latest ``limit`` recommendations (with their most recent
    tracking row), builds the rule report from them, and converts the
    report's adjustment suggestions into numeric control deltas.
    """
    report = _build_rule_report(await _load_recent_tracking(limit))
    return _derive_feedback_controls(report)
|
||
|
||
|
||
async def _load_recent_tracking(limit: int) -> list[dict]:
    """Load the newest ``limit`` recommendations joined with their latest tracking row.

    Optional columns that may be missing from older schemas are replaced by
    SQL fallback expressions, so the query also runs against databases that
    predate them. The join is a LEFT JOIN: recommendations without any
    tracking row are still returned, with NULL tracking fields.

    Args:
        limit: Maximum number of recommendations, newest ``created_at`` first.

    Returns:
        One plain dict per result row, keyed by the selected column aliases.
    """
    from sqlalchemy import text

    from app.db.database import get_db

    # (column, fallback SQL) for optional columns of `recommendations`.
    rec_optional = (
        ("action_plan", "'观察'"),
        ("position_score", "50"),
        ("lifecycle_status", "'candidate'"),
    )
    # (column, fallback SQL) for optional columns of `recommendation_tracking`.
    # Without stored running extremes, the current return is the best proxy.
    trk_optional = (
        ("max_return_pct", "t.pct_from_entry"),
        ("max_drawdown_pct", "t.pct_from_entry"),
        ("days_since_recommendation", "0"),
        ("close_reason", "''"),
        ("review_note", "''"),
    )

    async with get_db() as db:
        rec_cols = await _get_table_columns(db, "recommendations")
        trk_cols = await _get_table_columns(db, "recommendation_tracking")
        rec_expr = {col: _column_or_default(rec_cols, col, fallback, "r") for col, fallback in rec_optional}
        trk_expr = {col: _column_or_default(trk_cols, col, fallback, "t") for col, fallback in trk_optional}

        sql = (
            "SELECT r.id, r.ts_code, r.name, r.sector, r.strategy, r.entry_signal_type, "
            f"{rec_expr['action_plan']} AS action_plan, r.score, r.market_temp_score, r.sector_score, "
            f"{rec_expr['position_score']} AS position_score, "
            f"{rec_expr['lifecycle_status']} AS lifecycle_status, r.created_at, "
            f"t.pct_from_entry, {trk_expr['max_return_pct']} AS max_return_pct, "
            f"{trk_expr['max_drawdown_pct']} AS max_drawdown_pct, "
            f"{trk_expr['days_since_recommendation']} AS days_since_recommendation, "
            f"{trk_expr['close_reason']} AS close_reason, "
            f"{trk_expr['review_note']} AS review_note, t.track_date "
            "FROM recommendations r "
            # Sub-select keeps only the highest-id tracking row per recommendation.
            "LEFT JOIN ("
            " SELECT t.* FROM recommendation_tracking t "
            " INNER JOIN ("
            " SELECT recommendation_id, MAX(id) AS max_id "
            " FROM recommendation_tracking GROUP BY recommendation_id"
            " ) latest ON t.id = latest.max_id"
            ") t ON t.recommendation_id = r.id "
            "ORDER BY r.created_at DESC LIMIT :limit"
        )
        result = await db.execute(text(sql), {"limit": limit})
        return [dict(record._mapping) for record in result.fetchall()]
|
||
|
||
|
||
async def _get_table_columns(db, table_name: str) -> set[str]:
    """Return the column names of ``table_name`` using SQLite ``PRAGMA table_info``.

    ``table_name`` is interpolated directly into the statement, so it must be
    a trusted identifier; the callers in this module pass fixed table names.
    """
    from sqlalchemy import text

    pragma = await db.execute(text(f"PRAGMA table_info({table_name})"))
    columns: set[str] = set()
    for info in pragma.fetchall():
        columns.add(info._mapping["name"])
    return columns
|
||
|
||
|
||
def _column_or_default(columns: set[str], column_name: str, default_sql: str, alias: str = "") -> str:
|
||
if column_name in columns:
|
||
return f"{alias}.{column_name}" if alias else column_name
|
||
return default_sql
|
||
|
||
|
||
def _build_rule_report(rows: list[dict]) -> dict:
    """Build the rule-based review report from recommendation rows.

    Rows whose ``pct_from_entry`` is None have no tracking data yet. They are
    counted in the summary text but excluded from every statistic.

    Args:
        rows: Recommendation rows, each a dict of column values.

    Returns:
        Report dict with keys ``generated_at``, ``sample_size`` (number of
        tracked rows), ``summary``, ``strategy_stats``, ``signal_stats``,
        ``failure_patterns``, ``adjustment_suggestions``, ``ai_analysis``
        (always empty here) and ``generated_by`` (``"rules"``).
    """
    if not rows:
        return {
            "generated_at": datetime.now().isoformat(),
            "sample_size": 0,
            "summary": "暂无可复盘的推荐样本。",
            "strategy_stats": [],
            "signal_stats": [],
            "failure_patterns": ["样本不足,先积累推荐生命周期数据。"],
            "adjustment_suggestions": [],
            "ai_analysis": "",
            "generated_by": "rules",
        }

    tracked = [row for row in rows if row.get("pct_from_entry") is not None]
    tracked_count = len(tracked)

    by_strategy = _group_stats(tracked, "strategy")
    by_signal = _group_stats(tracked, "entry_signal_type")
    patterns = _detect_failure_patterns(tracked)
    advice = _build_adjustment_suggestions(by_strategy, by_signal, patterns, tracked_count)

    winners = len([row for row in tracked if (row.get("pct_from_entry") or 0) > 0])
    mean_return = _avg([row.get("pct_from_entry") for row in tracked])
    mean_drawdown = _avg([row.get("max_drawdown_pct") for row in tracked])
    win_rate = round(winners / tracked_count * 100, 1) if tracked else 0

    summary = (
        f"最近 {len(rows)} 条推荐中,{tracked_count} 条已有跟踪数据;"
        f"胜率 {win_rate}%,平均收益 {mean_return}%,平均最大回撤 {mean_drawdown}%。"
    )
    return {
        "generated_at": datetime.now().isoformat(),
        "sample_size": tracked_count,
        "summary": summary,
        "strategy_stats": by_strategy,
        "signal_stats": by_signal,
        "failure_patterns": patterns,
        "adjustment_suggestions": advice,
        "ai_analysis": "",
        "generated_by": "rules",
    }
|
||
|
||
|
||
def _group_stats(rows: list[dict], key: str) -> list[dict]:
    """Aggregate per-group performance statistics.

    Rows are bucketed by ``row.get(key)``; missing or falsy keys go to the
    ``"unknown"`` bucket.

    Returns:
        One dict per group with ``name``, ``count``, ``win_rate`` (percent of
        rows with positive ``pct_from_entry``), ``avg_return``,
        ``avg_max_return``, ``avg_max_drawdown``, and ``hit_target`` /
        ``hit_stop`` (counts of the matching ``close_reason``). The list is
        ordered by ``(count, avg_return)`` descending; ties keep first-seen order.
    """
    buckets: dict[str, list[dict]] = defaultdict(list)
    for row in rows:
        buckets[row.get(key) or "unknown"].append(row)

    def summarize(label: str, members: list[dict]) -> dict:
        reasons = [member.get("close_reason") for member in members]
        positive = len([m for m in members if (m.get("pct_from_entry") or 0) > 0])
        return {
            "name": label,
            "count": len(members),
            "win_rate": round(positive / len(members) * 100, 1) if members else 0,
            "avg_return": _avg([m.get("pct_from_entry") for m in members]),
            "avg_max_return": _avg([m.get("max_return_pct") for m in members]),
            "avg_max_drawdown": _avg([m.get("max_drawdown_pct") for m in members]),
            "hit_target": reasons.count("hit_target"),
            "hit_stop": reasons.count("hit_stop_loss"),
        }

    summaries = [summarize(label, members) for label, members in buckets.items()]
    return sorted(summaries, key=lambda s: (s["count"], s["avg_return"]), reverse=True)
|
||
|
||
|
||
def _detect_failure_patterns(rows: list[dict]) -> list[str]:
|
||
patterns = []
|
||
if not rows:
|
||
return ["暂无跟踪样本。"]
|
||
|
||
weak_market_losses = [
|
||
r for r in rows
|
||
if (r.get("market_temp_score") or 0) < 45 and (r.get("pct_from_entry") or 0) < 0
|
||
]
|
||
if len(weak_market_losses) >= 2:
|
||
patterns.append("弱势市场中仍有亏损推荐,低温环境下应进一步减少 BUY 或提高确认门槛。")
|
||
|
||
high_position_losses = [
|
||
r for r in rows
|
||
if (r.get("position_score") or 50) < 40 and (r.get("pct_from_entry") or 0) < 0
|
||
]
|
||
if len(high_position_losses) >= 2:
|
||
patterns.append("位置安全分偏低的推荐亏损较多,追高惩罚需要增强。")
|
||
|
||
stop_losses = [r for r in rows if r.get("close_reason") == "hit_stop_loss"]
|
||
if len(stop_losses) >= 2:
|
||
patterns.append("触发止损样本偏多,需要复查止损位置和入场触发条件是否过宽。")
|
||
|
||
expired_flat = [
|
||
r for r in rows
|
||
if r.get("close_reason") in ("review_expired_flat", "review_expired_loss")
|
||
]
|
||
if len(expired_flat) >= 3:
|
||
patterns.append("多只推荐到期未形成有效进攻,观察池转可操作的条件需要更严格。")
|
||
|
||
if not patterns:
|
||
patterns.append("暂无明显集中失败模式,继续积累样本并按策略分组观察。")
|
||
return patterns
|
||
|
||
|
||
def _build_adjustment_suggestions(
|
||
strategy_stats: list[dict],
|
||
signal_stats: list[dict],
|
||
failure_patterns: list[str],
|
||
sample_size: int,
|
||
) -> list[dict]:
|
||
suggestions = []
|
||
|
||
if sample_size < 10:
|
||
return [{
|
||
"target": "全局策略",
|
||
"action": "observe",
|
||
"reason": "跟踪样本少于10条,暂不建议调整参数。",
|
||
"confidence": "low",
|
||
}]
|
||
|
||
for stat in strategy_stats:
|
||
if stat["count"] >= 3 and stat["win_rate"] < 40 and stat["avg_return"] < 0:
|
||
suggestions.append({
|
||
"target": stat["name"],
|
||
"action": "tighten",
|
||
"reason": f"{stat['name']} 胜率{stat['win_rate']}%,平均收益{stat['avg_return']}%,建议提高买入门槛。",
|
||
"confidence": "medium",
|
||
})
|
||
elif stat["count"] >= 3 and stat["win_rate"] >= 60 and stat["avg_return"] > 1:
|
||
suggestions.append({
|
||
"target": stat["name"],
|
||
"action": "promote",
|
||
"reason": f"{stat['name']} 近期表现较好,可在相似市场环境下优先使用。",
|
||
"confidence": "medium",
|
||
})
|
||
|
||
for stat in signal_stats:
|
||
if stat["count"] >= 3 and stat["avg_max_drawdown"] < -5:
|
||
suggestions.append({
|
||
"target": stat["name"],
|
||
"action": "reduce",
|
||
"reason": f"{stat['name']} 平均最大回撤{stat['avg_max_drawdown']}%,建议降低排序权重或增加位置过滤。",
|
||
"confidence": "medium",
|
||
})
|
||
|
||
if any("弱势市场" in p for p in failure_patterns):
|
||
suggestions.append({
|
||
"target": "defensive_watch",
|
||
"action": "tighten",
|
||
"reason": "弱势市场亏损样本集中,防守策略下应只保留观察池,减少 BUY。",
|
||
"confidence": "high",
|
||
})
|
||
|
||
if not suggestions:
|
||
suggestions.append({
|
||
"target": "全局策略",
|
||
"action": "keep",
|
||
"reason": "当前样本未显示需要立即调整的集中问题。",
|
||
"confidence": "medium",
|
||
})
|
||
|
||
return suggestions[:6]
|
||
|
||
|
||
def _derive_feedback_controls(report: dict) -> dict:
|
||
suggestions = report.get("adjustment_suggestions", []) or []
|
||
sample_size = int(report.get("sample_size") or 0)
|
||
|
||
controls = {
|
||
"sample_size": sample_size,
|
||
"enabled": sample_size >= 10,
|
||
"buy_threshold_delta": 0,
|
||
"max_position_pct_delta": 0,
|
||
"actionable_limit_delta": 0,
|
||
"watch_limit_delta": 0,
|
||
"force_defensive": False,
|
||
"notes": [],
|
||
}
|
||
|
||
if sample_size < 10:
|
||
controls["notes"].append("样本不足,暂不启用自动回写。")
|
||
return controls
|
||
|
||
promote_count = 0
|
||
tighten_count = 0
|
||
reduce_count = 0
|
||
|
||
for item in suggestions[:6]:
|
||
action = item.get("action")
|
||
reason = item.get("reason", "")
|
||
|
||
if action == "promote":
|
||
promote_count += 1
|
||
controls["buy_threshold_delta"] -= 1
|
||
controls["watch_limit_delta"] += 1
|
||
elif action == "tighten":
|
||
tighten_count += 1
|
||
controls["buy_threshold_delta"] += 1
|
||
controls["actionable_limit_delta"] -= 1
|
||
controls["max_position_pct_delta"] -= 5
|
||
elif action == "reduce":
|
||
reduce_count += 1
|
||
controls["buy_threshold_delta"] += 1
|
||
controls["watch_limit_delta"] -= 1
|
||
|
||
if "弱势市场" in reason or item.get("target") == "defensive_watch":
|
||
controls["force_defensive"] = True
|
||
|
||
controls["buy_threshold_delta"] = max(-2, min(3, controls["buy_threshold_delta"]))
|
||
controls["max_position_pct_delta"] = max(-10, min(5, controls["max_position_pct_delta"]))
|
||
controls["actionable_limit_delta"] = max(-2, min(1, controls["actionable_limit_delta"]))
|
||
controls["watch_limit_delta"] = max(-2, min(2, controls["watch_limit_delta"]))
|
||
|
||
if controls["force_defensive"]:
|
||
controls["notes"].append("最近弱市亏损样本偏多,优先启用防守约束。")
|
||
elif tighten_count > promote_count:
|
||
controls["notes"].append("最近失效样本偏多,整体建议略收紧。")
|
||
elif promote_count > 0 and reduce_count == 0:
|
||
controls["notes"].append("最近有效样本改善,可适度放宽观察与出手空间。")
|
||
else:
|
||
controls["notes"].append("最近样本无明显单边倾向,仅做轻微校正。")
|
||
|
||
return controls
|
||
|
||
|
||
async def _generate_ai_iteration(rule_report: dict, rows: list[dict]) -> str:
    """Ask the LLM for a short, actionable strategy-iteration commentary.

    Sends the full rule report plus up to the first 20 rows (reduced to the
    fields relevant for review) to ``chat_completion``. The commentary is an
    optional add-on to the rule report. Any failure while importing the
    client or calling the API is logged, and an empty string is returned, so
    the caller still gets the complete rule-based report.

    Args:
        rule_report: Rule report dict to include in the prompt.
        rows: Recommendation rows; only the first 20 are sampled.

    Returns:
        The stripped response text, or ``""`` when there is no usable response.
    """
    sample = [
        {
            "name": r.get("name"),
            "strategy": r.get("strategy"),
            "signal": r.get("entry_signal_type"),
            "return": r.get("pct_from_entry"),
            "max_return": r.get("max_return_pct"),
            "drawdown": r.get("max_drawdown_pct"),
            "reason": r.get("close_reason"),
            "market_temp": r.get("market_temp_score"),
            "position_score": r.get("position_score"),
        }
        for r in rows[:20]
    ]

    # default=str: DB drivers may return Decimal/datetime values. Their text
    # form is fine for a prompt and avoids a TypeError that would abort the call.
    user_msg = f"""请基于以下推荐复盘数据,输出策略迭代建议。
要求:
1. 明确指出最该收紧、保留、加强的策略或信号;
2. 只提出可执行调整建议,不要泛泛而谈;
3. 不要承诺收益;
4. 180字以内。

规则复盘:
{json.dumps(rule_report, ensure_ascii=False, default=str)}

样本:
{json.dumps(sample, ensure_ascii=False, default=str)}
"""

    try:
        from app.llm.client import chat_completion

        resp = await chat_completion([
            {"role": "system", "content": "你是一位A股策略复盘研究员,负责基于推荐结果提出保守、可验证的策略迭代建议。"},
            {"role": "user", "content": user_msg},
        ])
    except Exception:
        # Boundary of an optional enhancement: log and degrade to the
        # rule-only report instead of failing the whole request.
        logger.exception("Strategy iteration LLM analysis failed; returning rule report only")
        return ""

    return resp.content.strip() if resp and resp.content else ""
|
||
|
||
|
||
def _avg(values: list) -> float:
|
||
clean = [float(v) for v in values if v is not None]
|
||
if not clean:
|
||
return 0
|
||
return round(sum(clean) / len(clean), 2)
|