422 lines
15 KiB
Python
422 lines
15 KiB
Python
"""策略配置中心
|
|
|
|
把可迭代的策略参数和 Prompt 版本持久化到数据库。
|
|
代码里的默认策略只作为兜底;一旦数据库有激活配置,下一轮扫描直接读取配置。
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
from datetime import datetime
|
|
from typing import Any
|
|
|
|
from sqlalchemy import text
|
|
|
|
from app.db.database import get_db
|
|
from app.db import tables
|
|
|
|
logger = logging.getLogger(__name__)

# Keys that profile_to_config copies out of a strategy profile and that
# apply_config_to_profile is allowed to override from a stored config.
CONFIG_FIELDS = {
    "name", "description",
    "entry_signal_priority", "score_weights",
    "min_score", "buy_threshold",
    "max_position_pct", "allow_trading",
    "actionable_limit", "watch_limit",
    "target_focus_sectors", "market_stance",
    "decision_note", "notes",
}

# prompt_key stored in prompt_configs -> attribute name of its code default
# in app.llm.prompts.
PROMPT_DEFAULT_KEYS = dict(
    stock_prefilter="STOCK_PREFILTER_PROMPT",
    single_stock_analysis="SINGLE_STOCK_ANALYSIS_PROMPT",
    strategy_iteration="STRATEGY_ITERATION_PROMPT",
)
|
|
|
|
|
|
def profile_to_config(profile) -> dict[str, Any]:
    """Extract the persistable config fields from a strategy profile.

    *profile* may be a pydantic model (anything with ``model_dump``) or any
    value accepted by ``dict()``. Only keys listed in CONFIG_FIELDS are kept.

    Keys are emitted in the profile's own field order. Iterating the
    CONFIG_FIELDS *set* instead would give an order that changes from process
    to process (string hashing is randomized), making the ``config_json``
    written from this dict non-deterministic.
    """
    data = profile.model_dump() if hasattr(profile, "model_dump") else dict(profile)
    return {key: value for key, value in data.items() if key in CONFIG_FIELDS}
|
|
|
|
|
|
def apply_config_to_profile(profile, config: dict[str, Any] | None, generated_by: str = "config"):
    """Return a deep copy of *profile* with the overrides in *config* applied.

    Only keys that are in CONFIG_FIELDS *and* exist as attributes on the
    profile are copied; everything else in *config* is ignored. The copy's
    ``generated_by`` is set to *generated_by*.

    When *config* is None or empty, the original *profile* object itself is
    returned (no copy, ``generated_by`` untouched).
    """
    if not config:
        return profile

    patched = profile.model_copy(deep=True)
    overrides = (
        (field_name, field_value)
        for field_name, field_value in config.items()
        if field_name in CONFIG_FIELDS and hasattr(patched, field_name)
    )
    for field_name, field_value in overrides:
        setattr(patched, field_name, field_value)
    patched.generated_by = generated_by
    return patched
|
|
|
|
|
|
async def load_active_strategy_profile(profile):
    """Overlay the active DB config version for ``profile.strategy_id`` onto *profile*.

    Returns *profile* unchanged when the strategy has no active row, or when
    the active row's ``config_json`` does not decode to a non-empty dict.
    Otherwise returns a copy (built by apply_config_to_profile) with the
    overrides applied, ``generated_by`` set to ``config:v<version>``, and
    ``feedback_applied``/``feedback_notes`` describing the version that took
    effect. The caller's *profile* object is never mutated.
    """
    row = await _load_active_strategy_row(profile.strategy_id)
    if not row:
        return profile

    config = _json_loads(row["config_json"], {})
    if not isinstance(config, dict) or not config:
        # apply_config_to_profile hands back the *same* object for an empty
        # config, so setting the feedback fields below would mutate the
        # caller's profile and claim a version took effect when nothing was
        # applied. A non-dict payload would also crash on ``.items()``.
        logger.warning(
            "Active strategy config %s v%s has empty or invalid config_json; keeping code defaults",
            profile.strategy_id,
            row["version"],
        )
        return profile

    updated = apply_config_to_profile(profile, config, generated_by=f"config:v{row['version']}")
    updated.feedback_applied = True
    updated.feedback_notes = [
        f"策略配置版本 v{row['version']} ({row['source']}) 已生效",
        row["change_reason"] or "使用配置中心激活版本",
    ]
    return updated
|
|
|
|
|
|
async def get_active_strategy_configs() -> list[dict]:
    """List every active strategy config row, formatted for output.

    Seeds the code defaults first (ensure_default_configs), so a fresh
    database still reports the built-in strategies. Rows come back ordered by
    strategy_id, and by version descending within a strategy.
    """
    await ensure_default_configs()
    query = text(
        "SELECT * FROM strategy_configs "
        "WHERE is_active = 1 "
        "ORDER BY strategy_id ASC, version DESC"
    )
    async with get_db() as session:
        cursor = await session.execute(query)
        return [_format_strategy_row(entry._mapping) for entry in cursor.fetchall()]
|
|
|
|
|
|
async def get_recent_config_changes(limit: int = 20) -> list[dict]:
    """Return the newest *limit* entries of the config change log, highest id first."""
    query = text(
        "SELECT * FROM strategy_config_changes "
        "ORDER BY id DESC LIMIT :limit"
    )
    params = {"limit": limit}
    async with get_db() as session:
        cursor = await session.execute(query, params)
        return [_format_change_row(entry._mapping) for entry in cursor.fetchall()]
|
|
|
|
|
|
async def get_active_prompt_configs() -> list[dict]:
    """List every active prompt config row, formatted for output.

    Seeds the code defaults first (ensure_default_configs). Rows come back
    ordered by prompt_key, and by version descending within a key.
    """
    await ensure_default_configs()
    query = text(
        "SELECT * FROM prompt_configs "
        "WHERE is_active = 1 "
        "ORDER BY prompt_key ASC, version DESC"
    )
    async with get_db() as session:
        cursor = await session.execute(query)
        return [_format_prompt_row(entry._mapping) for entry in cursor.fetchall()]
|
|
|
|
|
|
async def get_prompt_content(prompt_key: str, default: str) -> str:
    """Return the newest active prompt text for *prompt_key*, or *default*.

    Falls back to *default* when there is no active row, or when the active
    row's ``content`` is NULL or empty. Previously a NULL content was
    stringified and returned as the literal text "None".
    """
    async with get_db() as db:
        result = await db.execute(
            text(
                "SELECT content FROM prompt_configs "
                "WHERE prompt_key = :key AND is_active = 1 "
                "ORDER BY version DESC LIMIT 1"
            ),
            {"key": prompt_key},
        )
        row = result.fetchone()

    if row is None:
        return default
    content = row._mapping["content"]
    return str(content) if content else default
|
|
|
|
|
|
async def create_active_strategy_config(
    strategy_id: str,
    config: dict[str, Any],
    *,
    source: str,
    reason: str,
    evidence: dict[str, Any] | None = None,
    change_type: str = "manual",
) -> dict:
    """Write a new active strategy config version and record the change.

    Within one session, this:
    - deactivates every existing row of *strategy_id*;
    - inserts *config* as the new active version;
    - appends an ``applied`` entry to strategy_config_changes, with a diff
      against the previously active config;
    - commits.

    Returns the newly active row as formatted by _format_strategy_row.

    The new version number is ``MAX(version) + 1`` over *all* rows of the
    strategy. Using ``active.version + 1`` is wrong after
    rollback_strategy_config: it re-activates an older row, but the newer
    (now inactive) rows still exist, so their number would be reused.

    Args:
        strategy_id: Strategy whose config is replaced.
        config: Complete config to store. It is not merged with the previous one.
        source: Origin tag stored on the new row.
        reason: Stored as the row's change_reason and as the change-log reason.
        evidence: Optional JSON-serialisable supporting data. Stored as "{}" when None.
        change_type: Value written to strategy_config_changes.change_type.
    """
    # Serialise up front so a non-serialisable payload fails before any
    # UPDATE/INSERT is issued in the session.
    config_json = json.dumps(config, ensure_ascii=False)
    evidence_json = json.dumps(evidence or {}, ensure_ascii=False)

    async with get_db() as db:
        base = await _load_active_strategy_row(strategy_id, db=db)
        base_version = int(base["version"]) if base else 0
        before = _json_loads(base["config_json"], {}) if base else {}
        if not isinstance(before, dict):
            before = {}
        diff = _build_diff(before, config)

        highest = (await db.execute(
            text("SELECT MAX(version) FROM strategy_configs WHERE strategy_id = :sid"),
            {"sid": strategy_id},
        )).scalar()
        version = max(int(highest or 0), base_version) + 1

        await db.execute(
            text("UPDATE strategy_configs SET is_active = 0 WHERE strategy_id = :sid"),
            {"sid": strategy_id},
        )
        await db.execute(
            tables.strategy_configs_table.insert().values(
                strategy_id=strategy_id,
                version=version,
                config_json=config_json,
                is_active=True,
                source=source,
                change_reason=reason,
                evidence_json=evidence_json,
            )
        )
        await db.execute(
            tables.strategy_config_changes_table.insert().values(
                change_type=change_type,
                status="applied",
                strategy_id=strategy_id,
                base_version=base_version,
                new_version=version,
                diff_json=json.dumps(diff, ensure_ascii=False),
                evidence_json=evidence_json,
                reason=reason,
                applied_at=datetime.now(),
            )
        )
        await db.commit()

    row = await _load_active_strategy_row(strategy_id)
    return _format_strategy_row(row)
|
|
|
|
|
|
async def rollback_strategy_config(strategy_id: str) -> dict:
    """Re-activate the version just below the currently active one.

    This deactivates every row of *strategy_id* and re-activates the closest
    lower version, setting that row's ``source`` to 'rollback'. It logs a
    'rollback' change carrying the config diff, commits, and returns the newly
    active row as formatted by _format_strategy_row.

    NOTE(review): overwriting ``source`` discards that version's original
    provenance. The rollback is already recorded in strategy_config_changes.

    Raises:
        ValueError: the strategy has no active config, or no lower version exists.
    """
    async with get_db() as db:
        current = await _load_active_strategy_row(strategy_id, db=db)
        if not current:
            raise ValueError("当前策略没有激活配置")

        lookup = await db.execute(
            text(
                "SELECT * FROM strategy_configs "
                "WHERE strategy_id = :sid AND version < :version "
                "ORDER BY version DESC LIMIT 1"
            ),
            {"sid": strategy_id, "version": current["version"]},
        )
        target_row = lookup.fetchone()
        if not target_row:
            raise ValueError("没有可回滚的上一版本")
        target = target_row._mapping

        # Turn everything off, then switch only the rollback target back on.
        await db.execute(
            text("UPDATE strategy_configs SET is_active = 0 WHERE strategy_id = :sid"),
            {"sid": strategy_id},
        )
        await db.execute(
            text("UPDATE strategy_configs SET is_active = 1, source = 'rollback' WHERE id = :id"),
            {"id": target["id"]},
        )

        config_diff = _build_diff(
            _json_loads(current["config_json"], {}),
            _json_loads(target["config_json"], {}),
        )
        await db.execute(
            tables.strategy_config_changes_table.insert().values(
                change_type="rollback",
                status="applied",
                strategy_id=strategy_id,
                base_version=int(current["version"]),
                new_version=int(target["version"]),
                diff_json=json.dumps(config_diff, ensure_ascii=False),
                reason=f"回滚 {strategy_id} 到 v{target['version']}",
                applied_at=datetime.now(),
            )
        )
        await db.commit()

    reloaded = await _load_active_strategy_row(strategy_id)
    return _format_strategy_row(reloaded)
|
|
|
|
|
|
async def ensure_default_configs() -> None:
    """On first start, seed the code-default strategies and prompts into the database.

    For each built-in strategy id with no rows in strategy_configs, this
    inserts version 1. The config is built from ``get_strategy_profile_by_id``
    via profile_to_config.

    For each prompt in PROMPT_DEFAULT_KEYS, it inserts version 1 with the
    default text, but only when the matching constant in ``app.llm.prompts``
    is non-empty and the prompt has no rows in prompt_configs.

    Existing rows are never modified. Everything is committed once at the end.
    """
    from app.llm.strategy_selector import get_strategy_profile_by_id
    from app.llm import prompts

    strategy_ids = ["breakout_attack", "pullback_rotation", "launch_probe", "defensive_watch"]
    # Resolve default prompt texts through the module-level PROMPT_DEFAULT_KEYS
    # mapping so the prompt_key -> constant-name table is defined in one place.
    prompt_defaults = {
        prompt_key: getattr(prompts, constant_name, "")
        for prompt_key, constant_name in PROMPT_DEFAULT_KEYS.items()
    }

    async with get_db() as db:
        for strategy_id in strategy_ids:
            count = (await db.execute(
                text("SELECT COUNT(*) FROM strategy_configs WHERE strategy_id = :sid"),
                {"sid": strategy_id},
            )).scalar() or 0
            if count:
                continue
            profile = get_strategy_profile_by_id(strategy_id)
            await db.execute(
                tables.strategy_configs_table.insert().values(
                    strategy_id=strategy_id,
                    version=1,
                    config_json=json.dumps(profile_to_config(profile), ensure_ascii=False),
                    is_active=True,
                    source="default_seed",
                    change_reason="初始化默认策略配置",
                )
            )

        for prompt_key, content in prompt_defaults.items():
            if not content:
                continue
            count = (await db.execute(
                text("SELECT COUNT(*) FROM prompt_configs WHERE prompt_key = :key"),
                {"key": prompt_key},
            )).scalar() or 0
            if count:
                continue
            await db.execute(
                tables.prompt_configs_table.insert().values(
                    prompt_key=prompt_key,
                    version=1,
                    content=content,
                    is_active=True,
                    source="default_seed",
                    change_reason="初始化默认 Prompt 配置",
                )
            )

        await db.commit()
|
|
|
|
|
|
async def maybe_auto_apply_review_adjustment(report: dict) -> dict | None:
    """Apply a small automatic config adjustment based on a review report.

    Large structural adjustments still only go into the report's suggestions
    and never change config automatically.

    Nothing is done when the report has fewer than 10 samples, or when an
    ``auto_applied`` change was already recorded today. Otherwise the
    function walks ``report["adjustment_suggestions"]`` in order. It applies
    the first suggestion that targets a known strategy with an active config
    and whose action actually changes that config.

    Returns the new active row (from create_active_strategy_config), or
    None if nothing was applied.
    """
    sample_size = int(report.get("sample_size") or 0)
    if sample_size < 10:
        return None
    if await _has_auto_change_today():
        return None

    adjustable = {"breakout_attack", "pullback_rotation", "launch_probe", "defensive_watch"}
    for suggestion in report.get("adjustment_suggestions", []) or []:
        if not isinstance(suggestion, dict):
            continue
        strategy_id = suggestion.get("target", "")
        if strategy_id not in adjustable:
            continue
        active = await _load_active_strategy_row(strategy_id)
        if not active:
            continue
        config = _json_loads(active["config_json"], {})
        if not isinstance(config, dict):
            continue

        before = dict(config)
        changed = _apply_small_adjustment(config, suggestion.get("action", ""))
        # When every touched value is already at its cap, the config is
        # identical. Applying it would create an empty-diff version and use
        # up the once-per-day auto-change slot checked above.
        if not changed or config == before:
            continue

        evidence = {
            "sample_size": sample_size,
            "summary": report.get("summary", ""),
            "suggestion": suggestion,
        }
        return await create_active_strategy_config(
            strategy_id,
            config,
            source="auto_review",
            reason=suggestion.get("reason") or "复盘触发小幅自动配置调整",
            evidence=evidence,
            change_type="auto_applied",
        )
    return None
|
|
|
|
|
|
async def _load_active_strategy_row(strategy_id: str, db=None):
    """Fetch the highest-version active strategy_configs row for *strategy_id*.

    When *db* is given, the query runs on that session, so callers can read
    inside their own transaction. Otherwise a short-lived session is opened
    just for this lookup. Returns the row's mapping, or None if the strategy
    has no active row.
    """
    if db is not None:
        cursor = await db.execute(
            text(
                "SELECT * FROM strategy_configs "
                "WHERE strategy_id = :sid AND is_active = 1 "
                "ORDER BY version DESC LIMIT 1"
            ),
            {"sid": strategy_id},
        )
        hit = cursor.fetchone()
        return hit._mapping if hit else None

    async with get_db() as session:
        return await _load_active_strategy_row(strategy_id, db=session)
|
|
|
|
|
|
async def _has_auto_change_today() -> bool:
    """Return True if an ``auto_applied`` strategy change was recorded today (local date).

    The day is taken from ``applied_at`` when present. In this module,
    create_active_strategy_config sets applied_at explicitly from local
    ``datetime.now()``, so it lines up with SQLite's ``date('now', 'localtime')``.

    ``created_at`` is filled by a table default that is not visible here. If
    that default is SQLite's CURRENT_TIMESTAMP, it is UTC and would shift the
    day boundary, so it is only used as a fallback when applied_at is NULL.
    """
    async with get_db() as db:
        count = (await db.execute(
            text(
                "SELECT COUNT(*) FROM strategy_config_changes "
                "WHERE change_type = 'auto_applied' "
                "AND date(COALESCE(applied_at, created_at)) = date('now', 'localtime')"
            )
        )).scalar() or 0
    return count > 0
|
|
|
|
|
|
def _apply_small_adjustment(config: dict[str, Any], action: str) -> bool:
|
|
if action == "tighten":
|
|
config["buy_threshold"] = min(float(config.get("buy_threshold", 60)) + 1, 80)
|
|
config["max_position_pct"] = max(float(config.get("max_position_pct", 10)) - 5, 0)
|
|
config["actionable_limit"] = max(int(config.get("actionable_limit", 1)) - 1, 0)
|
|
return True
|
|
if action == "promote":
|
|
config["buy_threshold"] = max(float(config.get("buy_threshold", 60)) - 1, float(config.get("min_score", 0)))
|
|
config["watch_limit"] = min(int(config.get("watch_limit", 3)) + 1, 8)
|
|
return True
|
|
if action == "reduce":
|
|
config["buy_threshold"] = min(float(config.get("buy_threshold", 60)) + 1, 80)
|
|
config["watch_limit"] = max(int(config.get("watch_limit", 3)) - 1, 1)
|
|
return True
|
|
return False
|
|
|
|
|
|
def _build_diff(before: dict[str, Any], after: dict[str, Any]) -> dict[str, dict[str, Any]]:
|
|
diff: dict[str, dict[str, Any]] = {}
|
|
for key in sorted(set(before) | set(after)):
|
|
if before.get(key) != after.get(key):
|
|
diff[key] = {"from": before.get(key), "to": after.get(key)}
|
|
return diff
|
|
|
|
|
|
def _json_loads(value: str | None, default):
|
|
try:
|
|
return json.loads(value or "")
|
|
except Exception:
|
|
return default
|
|
|
|
|
|
def _format_strategy_row(row) -> dict:
|
|
if not row:
|
|
return {}
|
|
return {
|
|
"id": row["id"],
|
|
"strategy_id": row["strategy_id"],
|
|
"version": row["version"],
|
|
"config": _json_loads(row["config_json"], {}),
|
|
"is_active": bool(row["is_active"]),
|
|
"source": row["source"] or "",
|
|
"change_reason": row["change_reason"] or "",
|
|
"evidence": _json_loads(row["evidence_json"], {}),
|
|
"effective_from": str(row["effective_from"] or ""),
|
|
"created_at": str(row["created_at"] or ""),
|
|
}
|
|
|
|
|
|
def _format_prompt_row(row) -> dict:
    """Convert a prompt_configs row mapping into a plain dict for output.

    id/prompt_key/version/content are passed through unchanged. The evidence
    JSON is decoded (falling back to {}), and nullable text/timestamp columns
    become "".
    """
    formatted = {column: row[column] for column in ("id", "prompt_key", "version", "content")}
    formatted["is_active"] = bool(row["is_active"])
    formatted["source"] = row["source"] or ""
    formatted["change_reason"] = row["change_reason"] or ""
    formatted["evidence"] = _json_loads(row["evidence_json"], {})
    formatted["created_at"] = str(row["created_at"] or "")
    return formatted
|
|
|
|
|
|
def _format_change_row(row) -> dict:
    """Convert a strategy_config_changes row mapping into a plain dict for output.

    NULL text columns become "", NULL version numbers become 0, the diff and
    evidence JSON columns are decoded (falling back to {}), and timestamps are
    stringified.
    """
    def or_empty(column):
        return row[column] or ""

    def or_zero(column):
        return row[column] or 0

    return {
        "id": row["id"],
        "change_type": row["change_type"],
        "status": row["status"],
        "strategy_id": or_empty("strategy_id"),
        "prompt_key": or_empty("prompt_key"),
        "base_version": or_zero("base_version"),
        "new_version": or_zero("new_version"),
        "diff": _json_loads(row["diff_json"], {}),
        "evidence": _json_loads(row["evidence_json"], {}),
        "reason": or_empty("reason"),
        "created_at": str(or_empty("created_at")),
        "applied_at": str(or_empty("applied_at")),
    }
|