astock-agent/backend/app/llm/strategy_config.py
2026-04-30 20:28:19 +08:00

422 lines
15 KiB
Python

"""策略配置中心
把可迭代的策略参数和 Prompt 版本持久化到数据库。
代码里的默认策略只作为兜底;一旦数据库有激活配置,下一轮扫描直接读取配置。
"""
from __future__ import annotations
import json
import logging
from datetime import datetime
from typing import Any
from sqlalchemy import text
from app.db.database import get_db
from app.db import tables
logger = logging.getLogger(__name__)
# Profile attributes that may be persisted to / overridden from the config store.
# Anything not listed here is ignored when converting profiles <-> config dicts.
CONFIG_FIELDS = {
    # Identity / descriptive text
    "name",
    "description",
    "notes",
    "decision_note",
    # Scoring and entry rules
    "entry_signal_priority",
    "score_weights",
    "min_score",
    "buy_threshold",
    # Position sizing and trading switches
    "max_position_pct",
    "allow_trading",
    # Output limits
    "actionable_limit",
    "watch_limit",
    # Market context
    "target_focus_sectors",
    "market_stance",
}
# Maps each prompt key used in the config store to the name of the constant
# that holds its default prompt text.
PROMPT_DEFAULT_KEYS = dict(
    stock_prefilter="STOCK_PREFILTER_PROMPT",
    single_stock_analysis="SINGLE_STOCK_ANALYSIS_PROMPT",
    strategy_iteration="STRATEGY_ITERATION_PROMPT",
)
def profile_to_config(profile) -> dict[str, Any]:
    """Extract the persistable configuration fields from a strategy profile.

    Args:
        profile: An object exposing ``model_dump()`` (used when present), or
            anything accepted by ``dict()``.

    Returns:
        A dict containing only the entries whose keys are in ``CONFIG_FIELDS``.
    """
    if hasattr(profile, "model_dump"):
        raw = profile.model_dump()
    else:
        raw = dict(profile)

    config: dict[str, Any] = {}
    for field in CONFIG_FIELDS:
        if field in raw:
            config[field] = raw[field]
    return config
def apply_config_to_profile(profile, config: dict[str, Any] | None, generated_by: str = "config"):
    """Return a copy of ``profile`` with the stored configuration applied.

    Args:
        profile: The base profile; it is deep-copied via ``model_copy`` and
            never mutated.
        config: Overrides to apply. When empty or ``None`` the original
            ``profile`` object is returned unchanged.
        generated_by: Value written to ``generated_by`` on the copy.

    Returns:
        The original profile (no config) or the updated deep copy.
    """
    if not config:
        return profile

    patched = profile.model_copy(deep=True)
    for field, new_value in config.items():
        # Only whitelisted fields that the profile actually has are overridden.
        if field not in CONFIG_FIELDS:
            continue
        if not hasattr(patched, field):
            continue
        setattr(patched, field, new_value)

    patched.generated_by = generated_by
    return patched
async def load_active_strategy_profile(profile):
    """Overlay the active stored configuration onto a default profile.

    Looks up the active row for ``profile.strategy_id``. When none exists the
    given profile is returned as-is; otherwise a patched copy is returned with
    ``generated_by``, ``feedback_applied`` and ``feedback_notes`` describing
    the configuration version that took effect.
    """
    active_row = await _load_active_strategy_row(profile.strategy_id)
    if not active_row:
        return profile

    stored_config = _json_loads(active_row["config_json"], {})
    version = active_row["version"]
    patched = apply_config_to_profile(profile, stored_config, generated_by=f"config:v{version}")

    effective_note = f"策略配置版本 v{version} ({active_row['source']}) 已生效"
    reason_note = active_row["change_reason"] or "使用配置中心激活版本"
    patched.feedback_applied = True
    patched.feedback_notes = [effective_note, reason_note]
    return patched
async def get_active_strategy_configs() -> list[dict]:
    """Return every active strategy configuration as a formatted dict.

    Default configurations are seeded first, so the list is populated even on
    a fresh database. Rows are ordered by strategy id, then newest version.
    """
    await ensure_default_configs()

    query = text(
        "SELECT * FROM strategy_configs "
        "WHERE is_active = 1 "
        "ORDER BY strategy_id ASC, version DESC"
    )
    async with get_db() as db:
        result = await db.execute(query)
        configs: list[dict] = []
        for record in result.fetchall():
            configs.append(_format_strategy_row(record._mapping))
        return configs
async def get_recent_config_changes(limit: int = 20) -> list[dict]:
    """Return the most recent configuration change records, newest first.

    Args:
        limit: Maximum number of change rows to return.
    """
    query = text(
        "SELECT * FROM strategy_config_changes "
        "ORDER BY id DESC LIMIT :limit"
    )
    params = {"limit": limit}
    async with get_db() as db:
        result = await db.execute(query, params)
        changes: list[dict] = []
        for record in result.fetchall():
            changes.append(_format_change_row(record._mapping))
        return changes
async def get_active_prompt_configs() -> list[dict]:
    """Return every active prompt configuration as a formatted dict.

    Default configurations are seeded first. Rows are ordered by prompt key,
    then newest version.
    """
    await ensure_default_configs()

    query = text(
        "SELECT * FROM prompt_configs "
        "WHERE is_active = 1 "
        "ORDER BY prompt_key ASC, version DESC"
    )
    async with get_db() as db:
        result = await db.execute(query)
        prompts_found: list[dict] = []
        for record in result.fetchall():
            prompts_found.append(_format_prompt_row(record._mapping))
        return prompts_found
async def get_prompt_content(prompt_key: str, default: str) -> str:
    """Return the active prompt text for ``prompt_key``.

    Args:
        prompt_key: Key identifying the prompt in ``prompt_configs``.
        default: Text returned when no active row exists for the key.

    Returns:
        The stored content of the newest active version (as ``str``), or
        ``default``.
    """
    query = text(
        "SELECT content FROM prompt_configs "
        "WHERE prompt_key = :key AND is_active = 1 "
        "ORDER BY version DESC LIMIT 1"
    )
    async with get_db() as db:
        result = await db.execute(query, {"key": prompt_key})
        record = result.fetchone()
        if not record:
            return default
        return str(record._mapping["content"])
async def create_active_strategy_config(
    strategy_id: str,
    config: dict[str, Any],
    *,
    source: str,
    reason: str,
    evidence: dict[str, Any] | None = None,
    change_type: str = "manual",
) -> dict:
    """Write a new active configuration version for a strategy and log the change.

    All existing rows for the strategy are deactivated, the new configuration
    is inserted as the active row, and a change record holding the diff
    against the previously active configuration is stored. Everything is
    committed in one transaction.

    Args:
        strategy_id: Strategy whose configuration is being replaced.
        config: Full configuration to store (serialized as JSON).
        source: Origin label stored on the new row.
        reason: Human-readable reason, stored on both the row and the change.
        evidence: Optional supporting data; stored as JSON (``{}`` when absent).
        change_type: Change category recorded in the change log.

    Returns:
        The newly active configuration, formatted by ``_format_strategy_row``.
    """
    evidence_json = json.dumps(evidence or {}, ensure_ascii=False)

    async with get_db() as db:
        base = await _load_active_strategy_row(strategy_id, db=db)
        base_version = int(base["version"]) if base else 0

        # Number the new version after the highest version that exists for this
        # strategy, not after the active one. After a rollback the active row
        # is an older version while newer rows remain in the table, so
        # "active + 1" would reuse an existing version number.
        latest_version = (await db.execute(
            text("SELECT MAX(version) FROM strategy_configs WHERE strategy_id = :sid"),
            {"sid": strategy_id},
        )).scalar() or 0
        version = int(latest_version) + 1

        before = _json_loads(base["config_json"], {}) if base else {}
        diff = _build_diff(before, config)

        # Only one row per strategy may be active: deactivate all, then insert.
        await db.execute(
            text("UPDATE strategy_configs SET is_active = 0 WHERE strategy_id = :sid"),
            {"sid": strategy_id},
        )
        await db.execute(
            tables.strategy_configs_table.insert().values(
                strategy_id=strategy_id,
                version=version,
                config_json=json.dumps(config, ensure_ascii=False),
                is_active=True,
                source=source,
                change_reason=reason,
                evidence_json=evidence_json,
            )
        )
        await db.execute(
            tables.strategy_config_changes_table.insert().values(
                change_type=change_type,
                status="applied",
                strategy_id=strategy_id,
                base_version=base_version,
                new_version=version,
                diff_json=json.dumps(diff, ensure_ascii=False),
                evidence_json=evidence_json,
                reason=reason,
                applied_at=datetime.now(),
            )
        )
        await db.commit()

    # Re-read the row so database-generated columns are included in the result.
    row = await _load_active_strategy_row(strategy_id)
    return _format_strategy_row(row)
async def rollback_strategy_config(strategy_id: str) -> dict:
    """Re-activate the version immediately below the currently active one.

    The previous row is marked active with ``source = 'rollback'`` and a
    ``rollback`` change record containing the diff is written.

    Args:
        strategy_id: Strategy to roll back.

    Returns:
        The re-activated configuration, formatted by ``_format_strategy_row``.

    Raises:
        ValueError: If the strategy has no active configuration, or there is
            no lower version to roll back to.
    """
    previous_query = text(
        "SELECT * FROM strategy_configs "
        "WHERE strategy_id = :sid AND version < :version "
        "ORDER BY version DESC LIMIT 1"
    )

    async with get_db() as db:
        current = await _load_active_strategy_row(strategy_id, db=db)
        if not current:
            raise ValueError("当前策略没有激活配置")

        lookup = await db.execute(
            previous_query,
            {"sid": strategy_id, "version": current["version"]},
        )
        previous_record = lookup.fetchone()
        if not previous_record:
            raise ValueError("没有可回滚的上一版本")
        previous = previous_record._mapping

        # Deactivate everything for this strategy, then switch the target on.
        await db.execute(
            text("UPDATE strategy_configs SET is_active = 0 WHERE strategy_id = :sid"),
            {"sid": strategy_id},
        )
        await db.execute(
            text("UPDATE strategy_configs SET is_active = 1, source = 'rollback' WHERE id = :id"),
            {"id": previous["id"]},
        )

        current_config = _json_loads(current["config_json"], {})
        previous_config = _json_loads(previous["config_json"], {})
        rollback_diff = _build_diff(current_config, previous_config)
        await db.execute(
            tables.strategy_config_changes_table.insert().values(
                change_type="rollback",
                status="applied",
                strategy_id=strategy_id,
                base_version=int(current["version"]),
                new_version=int(previous["version"]),
                diff_json=json.dumps(rollback_diff, ensure_ascii=False),
                reason=f"回滚 {strategy_id} 到 v{previous['version']}",
                applied_at=datetime.now(),
            )
        )
        await db.commit()

    row = await _load_active_strategy_row(strategy_id)
    return _format_strategy_row(row)
async def ensure_default_configs() -> None:
    """Seed the database with the code-defined default strategies and prompts.

    For each built-in strategy id and each default prompt, a version-1 active
    row is inserted only when no row exists yet for that id/key. Prompts whose
    default text is empty or missing are skipped.
    """
    # Imported here rather than at module level, as in the original.
    from app.llm.strategy_selector import get_strategy_profile_by_id
    from app.llm import prompts

    strategy_ids = ["breakout_attack", "pullback_rotation", "launch_probe", "defensive_watch"]
    prompt_defaults = {
        "stock_prefilter": getattr(prompts, "STOCK_PREFILTER_PROMPT", ""),
        "single_stock_analysis": getattr(prompts, "SINGLE_STOCK_ANALYSIS_PROMPT", ""),
        "strategy_iteration": getattr(prompts, "STRATEGY_ITERATION_PROMPT", ""),
    }
    strategy_count_query = text("SELECT COUNT(*) FROM strategy_configs WHERE strategy_id = :sid")
    prompt_count_query = text("SELECT COUNT(*) FROM prompt_configs WHERE prompt_key = :key")

    async with get_db() as db:
        for strategy_id in strategy_ids:
            existing = await db.execute(strategy_count_query, {"sid": strategy_id})
            if existing.scalar() or 0:
                continue  # already seeded or customised
            default_profile = get_strategy_profile_by_id(strategy_id)
            seed_config = json.dumps(profile_to_config(default_profile), ensure_ascii=False)
            await db.execute(
                tables.strategy_configs_table.insert().values(
                    strategy_id=strategy_id,
                    version=1,
                    config_json=seed_config,
                    is_active=True,
                    source="default_seed",
                    change_reason="初始化默认策略配置",
                )
            )

        for prompt_key, content in prompt_defaults.items():
            if not content:
                continue
            existing = await db.execute(prompt_count_query, {"key": prompt_key})
            if existing.scalar() or 0:
                continue
            await db.execute(
                tables.prompt_configs_table.insert().values(
                    prompt_key=prompt_key,
                    version=1,
                    content=content,
                    is_active=True,
                    source="default_seed",
                    change_reason="初始化默认 Prompt 配置",
                )
            )

        await db.commit()
async def maybe_auto_apply_review_adjustment(report: dict) -> dict | None:
    """Apply at most one small automatic config adjustment from a review report.

    Larger structural changes stay as report suggestions only and are never
    applied automatically.

    Nothing is applied when the report's ``sample_size`` is below 10 or an
    ``auto_applied`` change already exists for today. Otherwise the first
    suggestion that targets a known strategy with an active configuration and
    yields an adjustment is applied.

    Args:
        report: Review report; reads ``sample_size``, ``summary`` and
            ``adjustment_suggestions`` (items with ``target``, ``action`` and
            ``reason``).

    Returns:
        The newly created active configuration, or ``None`` when nothing was
        applied.
    """
    known_strategies = {"breakout_attack", "pullback_rotation", "launch_probe", "defensive_watch"}

    sample_size = int(report.get("sample_size") or 0)
    if sample_size < 10:
        return None
    if await _has_auto_change_today():
        return None

    suggestions = report.get("adjustment_suggestions", []) or []
    for suggestion in suggestions:
        target = suggestion.get("target", "")
        if target not in known_strategies:
            continue

        active_row = await _load_active_strategy_row(target)
        if not active_row:
            continue

        adjusted = _json_loads(active_row["config_json"], {})
        if not _apply_small_adjustment(adjusted, suggestion.get("action", "")):
            continue

        # The first applicable suggestion wins; the rest are ignored.
        return await create_active_strategy_config(
            target,
            adjusted,
            source="auto_review",
            reason=suggestion.get("reason", "复盘触发小幅自动配置调整"),
            evidence={
                "sample_size": sample_size,
                "summary": report.get("summary", ""),
                "suggestion": suggestion,
            },
            change_type="auto_applied",
        )

    return None
async def _load_active_strategy_row(strategy_id: str, db=None):
    """Fetch the newest active ``strategy_configs`` row for a strategy.

    Args:
        strategy_id: Strategy to look up.
        db: Session to run the query on. When omitted, a session is opened
            via ``get_db()`` for the duration of the lookup.

    Returns:
        The row mapping, or ``None`` when the strategy has no active row.
    """
    if db is None:
        # No session supplied: open one and re-enter with it.
        async with get_db() as session:
            return await _load_active_strategy_row(strategy_id, db=session)

    query = text(
        "SELECT * FROM strategy_configs "
        "WHERE strategy_id = :sid AND is_active = 1 "
        "ORDER BY version DESC LIMIT 1"
    )
    result = await db.execute(query, {"sid": strategy_id})
    record = result.fetchone()
    if not record:
        return None
    return record._mapping
async def _has_auto_change_today() -> bool:
    """Report whether an ``auto_applied`` change has already been recorded today.

    "Today" is decided in SQL by comparing ``date(created_at)`` with
    ``date('now', 'localtime')``.
    NOTE(review): this assumes ``created_at`` is stored in local time; confirm
    against the table definition.
    """
    query = text(
        "SELECT COUNT(*) FROM strategy_config_changes "
        "WHERE change_type = 'auto_applied' "
        "AND date(created_at) = date('now', 'localtime')"
    )
    async with get_db() as db:
        result = await db.execute(query)
        todays_changes = result.scalar() or 0
    return todays_changes > 0
def _apply_small_adjustment(config: dict[str, Any], action: str) -> bool:
if action == "tighten":
config["buy_threshold"] = min(float(config.get("buy_threshold", 60)) + 1, 80)
config["max_position_pct"] = max(float(config.get("max_position_pct", 10)) - 5, 0)
config["actionable_limit"] = max(int(config.get("actionable_limit", 1)) - 1, 0)
return True
if action == "promote":
config["buy_threshold"] = max(float(config.get("buy_threshold", 60)) - 1, float(config.get("min_score", 0)))
config["watch_limit"] = min(int(config.get("watch_limit", 3)) + 1, 8)
return True
if action == "reduce":
config["buy_threshold"] = min(float(config.get("buy_threshold", 60)) + 1, 80)
config["watch_limit"] = max(int(config.get("watch_limit", 3)) - 1, 1)
return True
return False
def _build_diff(before: dict[str, Any], after: dict[str, Any]) -> dict[str, dict[str, Any]]:
diff: dict[str, dict[str, Any]] = {}
for key in sorted(set(before) | set(after)):
if before.get(key) != after.get(key):
diff[key] = {"from": before.get(key), "to": after.get(key)}
return diff
def _json_loads(value: str | None, default):
try:
return json.loads(value or "")
except Exception:
return default
def _format_strategy_row(row) -> dict:
    """Convert a ``strategy_configs`` row mapping into a plain response dict.

    JSON columns are decoded (``{}`` on failure), empty text columns become
    ``""`` and timestamps are stringified. A falsy ``row`` yields ``{}``.
    """
    if not row:
        return {}

    formatted = {column: row[column] for column in ("id", "strategy_id", "version")}
    formatted["config"] = _json_loads(row["config_json"], {})
    formatted["is_active"] = bool(row["is_active"])
    for column in ("source", "change_reason"):
        formatted[column] = row[column] or ""
    formatted["evidence"] = _json_loads(row["evidence_json"], {})
    for column in ("effective_from", "created_at"):
        formatted[column] = str(row[column] or "")
    return formatted
def _format_prompt_row(row) -> dict:
    """Convert a ``prompt_configs`` row mapping into a plain response dict.

    The evidence JSON is decoded (``{}`` on failure), empty text columns
    become ``""`` and ``created_at`` is stringified.
    """
    formatted = {column: row[column] for column in ("id", "prompt_key", "version", "content")}
    formatted["is_active"] = bool(row["is_active"])
    for column in ("source", "change_reason"):
        formatted[column] = row[column] or ""
    formatted["evidence"] = _json_loads(row["evidence_json"], {})
    formatted["created_at"] = str(row["created_at"] or "")
    return formatted
def _format_change_row(row) -> dict:
    """Convert a ``strategy_config_changes`` row mapping into a response dict.

    Empty identifiers become ``""``, empty version numbers become ``0``, JSON
    columns are decoded (``{}`` on failure) and timestamps are stringified.
    """
    formatted = {column: row[column] for column in ("id", "change_type", "status")}
    for column in ("strategy_id", "prompt_key"):
        formatted[column] = row[column] or ""
    for column in ("base_version", "new_version"):
        formatted[column] = row[column] or 0
    formatted["diff"] = _json_loads(row["diff_json"], {})
    formatted["evidence"] = _json_loads(row["evidence_json"], {})
    formatted["reason"] = row["reason"] or ""
    for column in ("created_at", "applied_at"):
        formatted[column] = str(row[column] or "")
    return formatted