astock-agent/backend/app/llm/strategy_iteration.py
2026-04-30 20:28:19 +08:00

577 lines
23 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""策略复盘迭代 Agent
基于推荐生命周期表现,输出可审查的策略调整建议。
不直接修改策略参数,只给出建议和证据。
"""
import json
import logging
from collections import defaultdict
from datetime import datetime
from app.config import settings
# Module logger; the review pipeline reports best-effort stage failures here.
logger = logging.getLogger(__name__)
async def build_strategy_iteration_report(
    limit: int = 50,
    include_llm: bool = False,
    apply_auto_config: bool = False,
) -> dict:
    """Build the strategy review report for the most recent recommendations.

    The report is always produced by the deterministic rules in
    ``_build_rule_report``. Two optional stages may enrich it:

    * ``apply_auto_config``: pass the rule report to
      ``maybe_auto_apply_review_adjustment``; a truthy result is attached as
      ``auto_config_change``.
    * ``include_llm``: when a DeepSeek API key is configured and rows exist,
      request an LLM narrative and attach it as ``ai_analysis``.

    Both optional stages are best-effort. A failure in either one is logged
    and the rule-based report is still returned.

    Args:
        limit: Maximum number of recent recommendations to load.
        include_llm: Whether to request an LLM analysis.
        apply_auto_config: Whether to try applying an automatic config change.

    Returns:
        The report dict. ``generated_by`` names every contributing stage,
        joined with ``+``: ``"rules"``, ``"rules+auto_config"``,
        ``"rules+llm"`` or ``"rules+auto_config+llm"``.
    """
    rows = await _load_recent_tracking(limit)
    rule_report = _build_rule_report(rows)
    sources = ["rules"]

    if apply_auto_config:
        from app.llm.strategy_config import maybe_auto_apply_review_adjustment

        auto_change = None
        try:
            auto_change = await maybe_auto_apply_review_adjustment(rule_report)
        except Exception as e:  # best-effort: never lose the rule report
            logger.warning("自动策略配置调整失败: %s", e)
        if auto_change:
            rule_report["auto_config_change"] = auto_change
            sources.append("auto_config")

    if include_llm and settings.deepseek_api_key and rows:
        try:
            ai_text = await _generate_ai_iteration(rule_report, rows)
        except Exception as e:  # LLM/network errors must not discard the report
            logger.warning("策略复盘 LLM 分析失败: %s", e)
            ai_text = ""
        if ai_text:
            rule_report["ai_analysis"] = ai_text
            sources.append("llm")

    # Previously the LLM stage overwrote the auto-config marker; keep both.
    rule_report["generated_by"] = "+".join(sources)
    return rule_report
async def build_strategy_feedback_controls(limit: int = 50) -> dict:
    """Derive bounded feedback controls from recent recommendation outcomes.

    Loads up to ``limit`` recent recommendations, runs the rule-based review
    over them, and converts the resulting suggestions into parameter deltas
    via ``_derive_feedback_controls``.
    """
    return _derive_feedback_controls(
        _build_rule_report(await _load_recent_tracking(limit))
    )
async def _load_recent_tracking(limit: int) -> list[dict]:
    """Load the newest recommendations joined with their latest tracking row.

    Selects from ``recommendations`` ordered by ``created_at`` descending (at
    most ``limit`` rows) and LEFT JOINs, per recommendation, the single
    ``recommendation_tracking`` row with the highest ``id``. Recommendations
    with no tracking row are still returned, with NULL tracking fields.

    Optional columns are probed first and replaced by SQL literal defaults
    when absent, so the query also works against older schemas.

    Args:
        limit: Maximum number of recommendations to return (bound as a
            query parameter).

    Returns:
        One plain dict per result row, keyed by the selected column aliases.
    """
    from sqlalchemy import text
    from app.db.database import get_db
    async with get_db() as db:
        # Probe the live schema so optional columns degrade to literals
        # instead of making the query fail.
        rec_columns = await _get_table_columns(db, "recommendations")
        tracking_columns = await _get_table_columns(db, "recommendation_tracking")
        r_action_plan = _column_or_default(rec_columns, "action_plan", "'观察'", "r")
        r_position_score = _column_or_default(rec_columns, "position_score", "50", "r")
        r_lifecycle_status = _column_or_default(rec_columns, "lifecycle_status", "'candidate'", "r")
        r_capital_score = _column_or_default(rec_columns, "capital_score", "0", "r")
        r_recall_tags = _column_or_default(rec_columns, "recall_tags", "'[]'", "r")
        # Without dedicated columns, max return / max drawdown fall back to
        # the latest pct_from_entry.
        t_max_return = _column_or_default(tracking_columns, "max_return_pct", "t.pct_from_entry", "t")
        t_max_drawdown = _column_or_default(tracking_columns, "max_drawdown_pct", "t.pct_from_entry", "t")
        t_days_since = _column_or_default(tracking_columns, "days_since_recommendation", "0", "t")
        t_close_reason = _column_or_default(tracking_columns, "close_reason", "''", "t")
        t_review_note = _column_or_default(tracking_columns, "review_note", "''", "t")
        # Only the expressions above (qualified column names or fixed
        # literals) are interpolated into the SQL; ``limit`` is bound.
        # The derived table keeps, per recommendation_id, the tracking row
        # with MAX(id) as the "latest" one.
        # NOTE(review): assumes tracking ids grow with track_date -- confirm.
        result = await db.execute(
            text(
                "SELECT r.id, r.ts_code, r.name, r.sector, r.strategy, r.entry_signal_type, "
                f"{r_action_plan} AS action_plan, r.score, r.market_temp_score, r.sector_score, "
                f"{r_capital_score} AS capital_score, {r_position_score} AS position_score, "
                f"{r_lifecycle_status} AS lifecycle_status, {r_recall_tags} AS recall_tags, "
                "r.entry_price, r.target_price, r.stop_loss, r.created_at, "
                f"t.pct_from_entry, {t_max_return} AS max_return_pct, {t_max_drawdown} AS max_drawdown_pct, "
                f"{t_days_since} AS days_since_recommendation, {t_close_reason} AS close_reason, "
                f"{t_review_note} AS review_note, t.track_date "
                "FROM recommendations r "
                "LEFT JOIN ("
                " SELECT t.* FROM recommendation_tracking t "
                " INNER JOIN ("
                " SELECT recommendation_id, MAX(id) AS max_id "
                " FROM recommendation_tracking GROUP BY recommendation_id"
                " ) latest ON t.id = latest.max_id"
                ") t ON t.recommendation_id = r.id "
                "ORDER BY r.created_at DESC LIMIT :limit"
            ),
            {"limit": limit},
        )
        # Materialize rows as plain dicts while the session is still open.
        return [dict(row._mapping) for row in result.fetchall()]
async def _get_table_columns(db, table_name: str) -> set[str]:
    """Return the set of column names defined on ``table_name``.

    Issues ``PRAGMA table_info(...)``, which is SQLite syntax. ``table_name``
    is interpolated into the statement text, so callers must pass a trusted,
    literal table name.
    """
    from sqlalchemy import text

    pragma = text(f"PRAGMA table_info({table_name})")
    info_rows = (await db.execute(pragma)).fetchall()
    names = set()
    for info in info_rows:
        names.add(info._mapping["name"])
    return names
def _column_or_default(columns: set[str], column_name: str, default_sql: str, alias: str = "") -> str:
    """Return a SQL expression selecting ``column_name``, or a fallback.

    Args:
        columns: Column names that actually exist in the table.
        column_name: Column to select when present.
        default_sql: Raw SQL expression used when the column is absent.
        alias: Optional table alias used to qualify the column.

    Returns:
        ``"<alias>.<column_name>"`` (or the bare name when ``alias`` is
        empty) if the column exists, otherwise ``default_sql`` unchanged.
    """
    if column_name not in columns:
        return default_sql
    prefix = f"{alias}." if alias else ""
    return prefix + column_name
def _build_rule_report(rows: list[dict]) -> dict:
    """Build the deterministic, rule-based review report.

    Only rows carrying a ``pct_from_entry`` value (i.e. joined to a tracking
    record) are analysed; ``sample_size`` counts those rows, while the
    summary also mentions the total number of loaded recommendations.

    Args:
        rows: Recommendation rows as produced by ``_load_recent_tracking``.

    Returns:
        A report dict with aggregate stats, detected patterns, example cases,
        adjustment suggestions and agent patch prompts.
        ``auto_config_change`` and ``ai_analysis`` are placeholders for later
        stages, and ``generated_by`` starts as ``"rules"``.
    """
    section_keys = (
        "strategy_stats",
        "signal_stats",
        "failure_patterns",
        "review_windows",
        "failure_cases",
        "success_patterns",
        "adjustment_suggestions",
        "agent_patch_prompts",
    )

    def assemble(sample_size, summary, **sections):
        # Both branches share the same key order; missing sections are [].
        report = {
            "generated_at": datetime.now().isoformat(),
            "sample_size": sample_size,
            "summary": summary,
        }
        for section in section_keys:
            report[section] = sections.get(section, [])
        report["auto_config_change"] = None
        report["ai_analysis"] = ""
        report["generated_by"] = "rules"
        return report

    if not rows:
        return assemble(
            0,
            "暂无可复盘的推荐样本。",
            failure_patterns=["样本不足,先积累推荐生命周期数据。"],
        )

    tracked = [row for row in rows if row.get("pct_from_entry") is not None]
    tracked_count = len(tracked)

    by_strategy = _group_stats(tracked, "strategy")
    by_signal = _group_stats(tracked, "entry_signal_type")
    patterns = _detect_failure_patterns(tracked)
    advice = _build_adjustment_suggestions(by_strategy, by_signal, patterns, tracked_count)
    windows = _build_review_windows(tracked)
    bad_cases = _build_failure_cases(tracked)
    good_cases = _build_success_patterns(tracked)
    prompts = _build_agent_patch_prompts(
        strategy_stats=by_strategy,
        signal_stats=by_signal,
        failure_patterns=patterns,
        failure_cases=bad_cases,
        sample_size=tracked_count,
    )

    winners = sum(1 for row in tracked if (row.get("pct_from_entry") or 0) > 0)
    mean_return = _avg([row.get("pct_from_entry") for row in tracked])
    mean_drawdown = _avg([row.get("max_drawdown_pct") for row in tracked])
    if tracked:
        hit_ratio = round(winners / tracked_count * 100, 1)
    else:
        hit_ratio = 0

    summary_text = (
        f"最近 {len(rows)} 条推荐中,{tracked_count} 条已有跟踪数据;"
        f"胜率 {hit_ratio}%,平均收益 {mean_return}%,平均最大回撤 {mean_drawdown}%。"
    )
    return assemble(
        tracked_count,
        summary_text,
        strategy_stats=by_strategy,
        signal_stats=by_signal,
        failure_patterns=patterns,
        review_windows=windows,
        failure_cases=bad_cases,
        success_patterns=good_cases,
        adjustment_suggestions=advice,
        agent_patch_prompts=prompts,
    )
def _group_stats(rows: list[dict], key: str) -> list[dict]:
    """Aggregate performance statistics per value of ``row[key]``.

    Rows whose ``key`` value is missing or falsy are grouped under
    ``"unknown"``. Each group reports its size, win rate (share of rows with
    positive ``pct_from_entry``), average current/max return, average max
    drawdown, and how many rows closed at target or at stop loss.

    Returns:
        One stats dict per group, sorted by ``(count, avg_return)``
        descending.
    """
    buckets: dict[str, list[dict]] = defaultdict(list)
    for row in rows:
        buckets[row.get(key) or "unknown"].append(row)

    def summarize(label, members):
        size = len(members)
        positives = sum(1 for m in members if (m.get("pct_from_entry") or 0) > 0)
        reasons = [m.get("close_reason") for m in members]
        return {
            "name": label,
            "count": size,
            "win_rate": round(positives / size * 100, 1) if members else 0,
            "avg_return": _avg([m.get("pct_from_entry") for m in members]),
            "avg_max_return": _avg([m.get("max_return_pct") for m in members]),
            "avg_max_drawdown": _avg([m.get("max_drawdown_pct") for m in members]),
            "hit_target": reasons.count("hit_target"),
            "hit_stop": reasons.count("hit_stop_loss"),
        }

    summaries = [summarize(label, members) for label, members in buckets.items()]
    return sorted(summaries, key=lambda s: (s["count"], s["avg_return"]), reverse=True)
def _detect_failure_patterns(rows: list[dict]) -> list[str]:
    """Detect recurring failure patterns among tracked recommendations.

    A pattern is reported once the number of matching rows reaches its
    threshold:

    * >= 2 losing rows whose recorded ``market_temp_score`` is below 45;
    * >= 2 losing rows whose ``position_score`` is below 40 (a missing score
      is treated as the neutral 50);
    * >= 2 rows closed with ``hit_stop_loss``;
    * >= 3 rows closed with ``review_expired_flat``/``review_expired_loss``.

    A row with an unknown (``None``) ``market_temp_score`` is not counted
    as a weak-market row, and a genuine score of 0 is not mistaken for a
    missing one.

    NOTE: other code matches on the "弱势市场" substring of the first
    message, so keep that wording stable.

    Returns:
        Human-readable pattern descriptions; never empty.
    """
    if not rows:
        return ["暂无跟踪样本。"]

    def is_loss(row):
        return (row.get("pct_from_entry") or 0) < 0

    def value_or(row, field, default):
        # Only a missing value falls back; 0 is a real score.
        value = row.get(field)
        return default if value is None else value

    patterns = []

    weak_market_losses = [
        r for r in rows
        if r.get("market_temp_score") is not None
        and r.get("market_temp_score") < 45
        and is_loss(r)
    ]
    if len(weak_market_losses) >= 2:
        patterns.append("弱势市场中仍有亏损推荐,低温环境下应进一步减少 BUY 或提高确认门槛。")

    high_position_losses = [
        r for r in rows
        if value_or(r, "position_score", 50) < 40 and is_loss(r)
    ]
    if len(high_position_losses) >= 2:
        patterns.append("位置安全分偏低的推荐亏损较多,追高惩罚需要增强。")

    stop_losses = [r for r in rows if r.get("close_reason") == "hit_stop_loss"]
    if len(stop_losses) >= 2:
        patterns.append("触发止损样本偏多,需要复查止损位置和入场触发条件是否过宽。")

    expired = [
        r for r in rows
        if r.get("close_reason") in ("review_expired_flat", "review_expired_loss")
    ]
    if len(expired) >= 3:
        patterns.append("多只推荐到期未形成有效进攻,观察池转可操作的条件需要更严格。")

    if not patterns:
        patterns.append("暂无明显集中失败模式,继续积累样本并按策略分组观察。")
    return patterns
def _build_adjustment_suggestions(
    strategy_stats: list[dict],
    signal_stats: list[dict],
    failure_patterns: list[str],
    sample_size: int,
) -> list[dict]:
    """Translate review statistics into at most six adjustment suggestions.

    Rules (each needs a group with at least 3 samples):

    * strategy with win rate < 40% and negative average return -> ``tighten``;
    * strategy with win rate >= 60% and average return > 1% -> ``promote``;
    * signal with average max drawdown below -5% -> ``reduce``;
    * a "弱势市场" failure pattern -> high-confidence ``tighten`` of
      ``defensive_watch``.

    The defensive suggestion always keeps its slot. Previously it was
    appended last and could be cut by the six-item cap, silently disabling
    the defensive mode derived from it.

    Args:
        strategy_stats: Per-strategy stats from ``_group_stats``.
        signal_stats: Per-signal stats from ``_group_stats``.
        failure_patterns: Messages from ``_detect_failure_patterns``.
        sample_size: Number of tracked rows; below 10 only ``observe`` is
            suggested.

    Returns:
        A non-empty list of suggestion dicts (``target``, ``action``,
        ``reason``, ``confidence``).
    """
    max_items = 6
    if sample_size < 10:
        return [{
            "target": "全局策略",
            "action": "observe",
            "reason": "跟踪样本少于10条暂不建议调整参数。",
            "confidence": "low",
        }]

    suggestions = []
    for stat in strategy_stats:
        if stat["count"] >= 3 and stat["win_rate"] < 40 and stat["avg_return"] < 0:
            suggestions.append({
                "target": stat["name"],
                "action": "tighten",
                "reason": f"{stat['name']} 胜率{stat['win_rate']}%,平均收益{stat['avg_return']}%,建议提高买入门槛。",
                "confidence": "medium",
            })
        elif stat["count"] >= 3 and stat["win_rate"] >= 60 and stat["avg_return"] > 1:
            suggestions.append({
                "target": stat["name"],
                "action": "promote",
                "reason": f"{stat['name']} 近期表现较好,可在相似市场环境下优先使用。",
                "confidence": "medium",
            })

    for stat in signal_stats:
        if stat["count"] >= 3 and stat["avg_max_drawdown"] < -5:
            suggestions.append({
                "target": stat["name"],
                "action": "reduce",
                "reason": f"{stat['name']} 平均最大回撤{stat['avg_max_drawdown']}%,建议降低排序权重或增加位置过滤。",
                "confidence": "medium",
            })

    if any("弱势市场" in p for p in failure_patterns):
        # Reserve the last slot so the cap below never drops this
        # high-confidence suggestion.
        suggestions = suggestions[: max_items - 1]
        suggestions.append({
            "target": "defensive_watch",
            "action": "tighten",
            "reason": "弱势市场亏损样本集中,防守策略下应只保留观察池,减少 BUY。",
            "confidence": "high",
        })

    if not suggestions:
        suggestions.append({
            "target": "全局策略",
            "action": "keep",
            "reason": "当前样本未显示需要立即调整的集中问题。",
            "confidence": "medium",
        })
    return suggestions[:max_items]
def _build_review_windows(rows: list[dict]) -> list[dict]:
    """Summarize outcomes of rows tracked for at least N days.

    For N in 3, 5 and 10 the window contains every row whose
    ``days_since_recommendation`` is >= N, so the windows are nested rather
    than disjoint. Statistics use each row's latest tracked values, not the
    values observed on day N. An empty window reports all-zero stats.
    """
    stat_fields = (
        "count",
        "win_rate",
        "avg_return",
        "hit_target_rate",
        "hit_stop_rate",
        "avg_max_return",
        "avg_max_drawdown",
    )

    def pct_of(part, whole):
        return round(part / whole * 100, 1)

    summaries = []
    for horizon in (3, 5, 10):
        matured = [
            row for row in rows
            if int(row.get("days_since_recommendation") or 0) >= horizon
        ]
        if not matured:
            summaries.append({"window_days": horizon, **dict.fromkeys(stat_fields, 0)})
            continue

        size = len(matured)
        reasons = [row.get("close_reason") for row in matured]
        positives = sum(1 for row in matured if (row.get("pct_from_entry") or 0) > 0)
        summaries.append({
            "window_days": horizon,
            "count": size,
            "win_rate": pct_of(positives, size),
            "avg_return": _avg([row.get("pct_from_entry") for row in matured]),
            "hit_target_rate": pct_of(reasons.count("hit_target"), size),
            "hit_stop_rate": pct_of(reasons.count("hit_stop_loss"), size),
            "avg_max_return": _avg([row.get("max_return_pct") for row in matured]),
            "avg_max_drawdown": _avg([row.get("max_drawdown_pct") for row in matured]),
        })
    return summaries
def _build_failure_cases(rows: list[dict]) -> list[dict]:
failures = [
r for r in rows
if (r.get("pct_from_entry") or 0) < 0
or (r.get("close_reason") in {"hit_stop_loss", "review_expired_loss", "review_expired_flat"})
or (r.get("max_drawdown_pct") or 0) < -5
]
failures.sort(key=lambda r: ((r.get("pct_from_entry") or 0), (r.get("max_drawdown_pct") or 0)))
return [_case_summary(r) for r in failures[:8]]
def _build_success_patterns(rows: list[dict]) -> list[dict]:
successes = [
r for r in rows
if r.get("close_reason") == "hit_target"
or (r.get("max_return_pct") or 0) >= 3
or (r.get("pct_from_entry") or 0) > 2
]
successes.sort(key=lambda r: (r.get("max_return_pct") or 0), reverse=True)
return [_case_summary(r) for r in successes[:8]]
def _case_summary(row: dict) -> dict:
    """Project a joined recommendation/tracking row into a compact case dict.

    Missing or falsy values fall back to display defaults: ``0`` for numeric
    fields and ``""``, ``"unknown"`` or ``"观察"`` for text fields.
    ``position_score`` falls back to the neutral 50 only when it is
    missing, so a genuine score of 0 is preserved.

    ``recall_tags`` may arrive as JSON text (the query's default is
    ``'[]'``) or as an already-decoded sequence. The result is always a
    list; invalid JSON or a non-list payload (e.g. ``"null"``) becomes
    ``[]``.
    """
    raw_tags = row.get("recall_tags") or "[]"
    if isinstance(raw_tags, str):
        try:
            tags = json.loads(raw_tags)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            tags = []
    else:
        tags = raw_tags
    if isinstance(tags, tuple):
        tags = list(tags)
    elif not isinstance(tags, list):
        tags = []

    position_score = row.get("position_score")
    return {
        "ts_code": row.get("ts_code"),
        "name": row.get("name"),
        "sector": row.get("sector") or "",
        "strategy": row.get("strategy") or "unknown",
        "entry_signal_type": row.get("entry_signal_type") or "unknown",
        "action_plan": row.get("action_plan") or "观察",
        "score": row.get("score") or 0,
        "market_temp_score": row.get("market_temp_score") or 0,
        "sector_score": row.get("sector_score") or 0,
        "capital_score": row.get("capital_score") or 0,
        "position_score": 50 if position_score is None else position_score,
        "pct_from_entry": row.get("pct_from_entry") or 0,
        "max_return_pct": row.get("max_return_pct") or 0,
        "max_drawdown_pct": row.get("max_drawdown_pct") or 0,
        "days_since_recommendation": row.get("days_since_recommendation") or 0,
        "close_reason": row.get("close_reason") or "",
        "review_note": row.get("review_note") or "",
        "recall_tags": tags,
    }
def _build_agent_patch_prompts(
    strategy_stats: list[dict],
    signal_stats: list[dict],
    failure_patterns: list[str],
    failure_cases: list[dict],
    sample_size: int,
) -> list[dict]:
    """Turn review findings into ready-to-use coding-agent patch prompts.

    Emits at most one prompt each for:

    * the first strategy with >= 3 samples, win rate < 40% and negative
      average return;
    * the first signal with >= 3 samples and average max drawdown below -5;
    * a "弱势市场" failure pattern.

    If none of those fire but failure cases exist, a single generic prompt
    built from ``failure_cases[0]`` is emitted instead.

    Adjacent prompt fragments are now joined with explicit separators.
    Previously numbers and labels ran together (e.g. "...-1.2%平均最大回撤").

    Args:
        strategy_stats: Per-strategy stats from ``_group_stats``.
        signal_stats: Per-signal stats from ``_group_stats``.
        failure_patterns: Messages from ``_detect_failure_patterns``.
        failure_cases: Case summaries, worst first.
        sample_size: Number of tracked rows; below 10 nothing is emitted.

    Returns:
        Up to four prompt dicts built by ``_patch_prompt``.
    """
    if sample_size < 10:
        return []
    prompts = []

    weak_strategy = next(
        (
            s for s in strategy_stats
            if s["count"] >= 3 and s["win_rate"] < 40 and s["avg_return"] < 0
        ),
        None,
    )
    if weak_strategy:
        prompts.append(_patch_prompt(
            title=f"收紧 {weak_strategy['name']} 策略配置",
            severity="high",
            evidence=f"样本{weak_strategy['count']}条,胜率{weak_strategy['win_rate']}%,平均收益{weak_strategy['avg_return']}%。",
            target_files=["backend/app/llm/strategy_config.py", "backend/app/llm/strategy_selector.py"],
            prompt=(
                f"请基于策略复盘收紧 {weak_strategy['name']}。优先通过策略配置版本调整完成,不要改无关代码。"
                f"证据:{weak_strategy['count']}条样本,胜率{weak_strategy['win_rate']}%,平均收益{weak_strategy['avg_return']}%,"
                f"平均最大回撤{weak_strategy['avg_max_drawdown']}%。"
                "请提高 buy_threshold 1-2 分,降低 actionable_limit 或 max_position_pct,并保留回滚记录。"
            ),
        ))

    weak_signal = next(
        (
            s for s in signal_stats
            if s["count"] >= 3 and s["avg_max_drawdown"] < -5
        ),
        None,
    )
    if weak_signal:
        prompts.append(_patch_prompt(
            title=f"降低 {weak_signal['name']} 信号风险暴露",
            severity="medium",
            evidence=f"样本{weak_signal['count']}条,平均最大回撤{weak_signal['avg_max_drawdown']}%。",
            target_files=["backend/app/engine/screener.py", "backend/app/analysis/breakout_signals.py"],
            prompt=(
                f"请基于复盘结果降低 {weak_signal['name']} 信号的风险暴露。"
                f"证据:样本{weak_signal['count']}条,平均最大回撤{weak_signal['avg_max_drawdown']}%,"
                f"命中止损{weak_signal['hit_stop']}次。"
                "请检查该信号的入场质量、位置过滤和止损设置,给出最小代码补丁,并保持其他信号行为不变。"
            ),
        ))

    if any("弱势市场" in p for p in failure_patterns):
        prompts.append(_patch_prompt(
            title="强化弱势市场防守配置",
            severity="high",
            evidence="复盘显示弱势市场亏损样本集中。",
            target_files=["backend/app/llm/strategy_config.py", "backend/app/llm/strategy_selector.py"],
            prompt=(
                "请强化弱势市场防守配置。证据:复盘显示市场温度低于45时亏损样本集中。"
                "优先把低温环境下的 allow_trading、actionable_limit、buy_threshold 做成可配置护栏,"
                "小幅收紧可自动生效,大幅禁用策略需生成待确认变更。"
            ),
        ))

    if failure_cases and not prompts:
        worst = failure_cases[0]
        prompts.append(_patch_prompt(
            title="复查推荐失效样本的共同过滤条件",
            severity="medium",
            evidence=f"最差样本 {worst['name']} 收益{worst['pct_from_entry']}%,最大回撤{worst['max_drawdown_pct']}%。",
            target_files=["backend/app/engine/screener.py", "backend/app/llm/strategy_iteration.py"],
            prompt=(
                "请复查最近推荐失效样本的共同过滤条件,优先寻找可配置化的收紧项。"
                f"最差样本:{worst['name']}({worst['ts_code']}),策略{worst['strategy']},"
                f"信号{worst['entry_signal_type']},收益{worst['pct_from_entry']}%,"
                f"最大回撤{worst['max_drawdown_pct']}%。"
                "请不要凭单一样本大改策略,必须保留样本数门槛。"
            ),
        ))
    return prompts[:4]
def _patch_prompt(title: str, severity: str, evidence: str, target_files: list[str], prompt: str) -> dict:
    """Package one agent patch suggestion with fixed acceptance criteria.

    The arguments are stored unchanged. A fresh ``acceptance_criteria``
    list is created on every call, so callers may mutate it safely.
    """
    criteria = [
        "python3 -m compileall backend/app 通过",
        "策略配置变更有版本记录且可回滚",
        "历史推荐和跟踪数据读取不受影响",
    ]
    return dict(
        title=title,
        severity=severity,
        evidence=evidence,
        target_files=target_files,
        prompt=prompt,
        acceptance_criteria=criteria,
    )
def _derive_feedback_controls(report: dict) -> dict:
    """Convert a review report's suggestions into bounded control deltas.

    With fewer than 10 tracked samples the controls are returned disabled,
    with all deltas at zero. Otherwise, each of the first six suggestions
    adjusts the deltas according to its ``action``:

    * ``promote``: buy threshold -1, watch limit +1;
    * ``tighten``: buy threshold +1, actionable limit -1, max position -5;
    * ``reduce``: buy threshold +1, watch limit -1.

    A suggestion mentioning "弱势市场" or targeting ``defensive_watch`` sets
    ``force_defensive``. The deltas are then clamped to fixed ranges, and a
    single explanatory note is appended.
    """
    suggestions = report.get("adjustment_suggestions", []) or []
    sample_size = int(report.get("sample_size") or 0)
    enabled = sample_size >= 10
    controls = {
        "sample_size": sample_size,
        "enabled": enabled,
        "buy_threshold_delta": 0,
        "max_position_pct_delta": 0,
        "actionable_limit_delta": 0,
        "watch_limit_delta": 0,
        "force_defensive": False,
        "notes": [],
    }
    if not enabled:
        controls["notes"].append("样本不足,暂不启用自动回写。")
        return controls

    action_deltas = {
        "promote": {"buy_threshold_delta": -1, "watch_limit_delta": 1},
        "tighten": {"buy_threshold_delta": 1, "actionable_limit_delta": -1, "max_position_pct_delta": -5},
        "reduce": {"buy_threshold_delta": 1, "watch_limit_delta": -1},
    }
    tally = dict.fromkeys(action_deltas, 0)

    for suggestion in suggestions[:6]:
        action = suggestion.get("action")
        reason = suggestion.get("reason", "")
        for name, deltas in action_deltas.items():
            if action == name:
                tally[name] += 1
                for field, step in deltas.items():
                    controls[field] += step
                break
        if "弱势市场" in reason or suggestion.get("target") == "defensive_watch":
            controls["force_defensive"] = True

    bounds = {
        "buy_threshold_delta": (-2, 3),
        "max_position_pct_delta": (-10, 5),
        "actionable_limit_delta": (-2, 1),
        "watch_limit_delta": (-2, 2),
    }
    for field, (low, high) in bounds.items():
        controls[field] = max(low, min(high, controls[field]))

    if controls["force_defensive"]:
        note = "最近弱市亏损样本偏多,优先启用防守约束。"
    elif tally["tighten"] > tally["promote"]:
        note = "最近失效样本偏多,整体建议略收紧。"
    elif tally["promote"] > 0 and tally["reduce"] == 0:
        note = "最近有效样本改善,可适度放宽观察与出手空间。"
    else:
        note = "最近样本无明显单边倾向,仅做轻微校正。"
    controls["notes"].append(note)
    return controls
async def _generate_ai_iteration(rule_report: dict, rows: list[dict]) -> str:
    """Ask the LLM for a narrative strategy review.

    The prompt text comes from ``get_prompt_content("strategy_iteration", ...)``,
    with ``STRATEGY_ITERATION_PROMPT`` passed as the second argument
    (presumably the fallback). The rule report and a compact projection of
    the first 20 rows are appended to the prompt as JSON.

    Returns:
        The stripped response text, or ``""`` when there is no response or
        its content is empty.
    """
    from app.llm.client import chat_completion
    from app.llm.prompts import STRATEGY_ITERATION_PROMPT
    from app.llm.strategy_config import get_prompt_content

    # (output key, source row key) pairs; order fixes the JSON key order.
    field_map = (
        ("name", "name"),
        ("strategy", "strategy"),
        ("signal", "entry_signal_type"),
        ("return", "pct_from_entry"),
        ("max_return", "max_return_pct"),
        ("drawdown", "max_drawdown_pct"),
        ("reason", "close_reason"),
        ("market_temp", "market_temp_score"),
        ("position_score", "position_score"),
    )
    sample = [
        {out_key: row.get(src_key) for out_key, src_key in field_map}
        for row in rows[:20]
    ]

    prompt = await get_prompt_content("strategy_iteration", STRATEGY_ITERATION_PROMPT)
    user_msg = f"""{prompt}
规则复盘:
{json.dumps(rule_report, ensure_ascii=False)}
样本:
{json.dumps(sample, ensure_ascii=False)}
"""
    messages = [
        {"role": "system", "content": "你是一位A股策略复盘研究员负责基于推荐结果提出保守、可验证的策略迭代建议。"},
        {"role": "user", "content": user_msg},
    ]
    resp = await chat_completion(messages)
    if resp and resp.content:
        return resp.content.strip()
    return ""
def _avg(values: list) -> float:
    """Return the mean of the non-``None`` values, rounded to 2 decimals.

    Each value is converted with ``float()``. If nothing remains after
    dropping ``None``, the integer ``0`` is returned.
    """
    present = list(map(float, (v for v in values if v is not None)))
    return round(sum(present) / len(present), 2) if present else 0