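"""
策略配置加载器 — 从 rules.yaml 加载所有参数,支持热更新
(Strategy config loader: loads every parameter from rules.yaml, with hot reload.)

review_engine 调整权重后直接写回 yaml,下次运行自动生效
(review_engine writes adjusted weights straight back to the YAML; they take
effect automatically on the next run.)
"""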
|
||
import copy
|
||
import datetime
|
||
import os
|
||
|
||
import yaml
|
||
|
||
|
||
# Absolute path of rules.yaml, located next to this module. abspath() pins the
# location at import time: __file__ may be a relative path, in which case a
# later os.chdir() would otherwise silently redirect every read and write.
RULES_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "rules.yaml")

# Parsed rules.yaml and the file mtime it corresponds to; both are set by
# load_rules() / save_rules() in this module. None means "not loaded yet".
_cache = None

_cache_mtime = None

# Legacy signal names still used by older code -> canonical names used in
# rules.yaml (see normalize_signal_name()).
_SIGNAL_NAME_ALIASES = {
    "N倍放量(≥10x)": "N倍放量",
    "连续3x放量(≥3根)": "连续3x放量",
    "静K→动K转折": "静K动K转折",
    "Q≥7供给区突破": "Q7供给区突破",
    "1H放量(量价背离)": "1H放量",
}
|
||
|
||
|
||
def load_rules(force_reload=False):
    """Load rules.yaml, re-reading it automatically when the file changes.

    The parsed document is cached at module level together with the file's
    mtime. Later calls return the cached object until the mtime differs or
    ``force_reload`` is set.

    Args:
        force_reload: bypass the cache and re-parse the file unconditionally.

    Returns:
        The parsed rules dict (an empty file yields ``{}``). This is the shared
        cached object, not a copy: callers that want to modify it should
        deep-copy first.

    Raises:
        FileNotFoundError: if rules.yaml does not exist.
        ValueError: if the YAML top level is not a mapping.
    """
    global _cache, _cache_mtime

    try:
        mtime = os.path.getmtime(RULES_PATH)
    except OSError:
        # Missing/unreadable file: fall through to open(), which raises.
        mtime = 0

    # Test against None, not truthiness: an empty rules file parses to {} and
    # must still be served from the cache instead of being re-read every call.
    if not force_reload and _cache is not None and _cache_mtime == mtime:
        return _cache

    with open(RULES_PATH, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f) or {}
    if not isinstance(data, dict):
        raise ValueError(
            f"{RULES_PATH}: top-level YAML must be a mapping, got {type(data).__name__}"
        )

    # mtime was taken before reading, so if the file changed in between, the
    # next call sees a newer mtime and reloads.
    _cache = data
    _cache_mtime = mtime
    return _cache
|
||
|
||
|
||
def save_rules(rules_dict):
    """Persist ``rules_dict`` to rules.yaml and refresh the module cache.

    Used after rules/weights are adjusted (e.g. by review_engine). The YAML is
    written to a temporary file in the same directory, flushed to disk, and
    then moved over rules.yaml with ``os.replace``. A failure during
    serialisation therefore never leaves a truncated rules.yaml behind, and
    hot-reloading readers see either the old or the new file, never a
    partial one.

    Args:
        rules_dict: the complete rules document to write.

    Raises:
        Whatever ``yaml.dump`` or the filesystem raises. On failure the
        in-memory cache is invalidated so the next ``load_rules()`` re-reads
        the file on disk.
    """
    global _cache, _cache_mtime
    import shutil
    import tempfile

    directory = os.path.dirname(RULES_PATH) or "."
    fd, tmp_path = tempfile.mkstemp(prefix=".rules.", suffix=".yaml.tmp", dir=directory)
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            yaml.dump(rules_dict, f, default_flow_style=False, allow_unicode=True, sort_keys=False)
            f.flush()
            os.fsync(f.fileno())
        if os.path.exists(RULES_PATH):
            # mkstemp creates the file with mode 0600; keep the original permissions.
            shutil.copymode(RULES_PATH, tmp_path)
        os.replace(tmp_path, RULES_PATH)
    except BaseException:
        # Callers typically mutate the dict returned by load_rules() (the cache
        # itself) before saving, so the cache may no longer match the file.
        _cache = None
        _cache_mtime = None
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise

    _cache = rules_dict
    _cache_mtime = os.path.getmtime(RULES_PATH)
|
||
|
||
|
||
def _get_section(section_name, default=None):
    """Return a deep copy of the top-level section ``section_name``.

    A section that is missing, or present but empty in YAML (``key:`` with no
    value, which parses to None), falls back to ``default``, or to ``{}``
    when no default is given. Callers can therefore always chain ``.get(...)``.
    """
    section = load_rules().get(section_name)
    if section is None:
        section = {} if default is None else default
    return copy.deepcopy(section)
|
||
|
||
|
||
def _get_nested(section_name, path, default=None):
    """Follow the keys in ``path`` inside top-level section ``section_name``.

    Returns a deep copy of the node reached. Returns a deep copy of
    ``default`` as soon as a step lands on a non-dict or on a missing/None
    value.
    """
    current = load_rules().get(section_name, {})
    for step in path:
        current = current.get(step) if isinstance(current, dict) else None
        if current is None:
            return copy.deepcopy(default)
    return copy.deepcopy(current)
|
||
|
||
|
||
def normalize_signal_name(signal_name):
    """Map a legacy signal name to its canonical rules.yaml spelling.

    Names with no registered alias are returned unchanged.
    """
    if signal_name in _SIGNAL_NAME_ALIASES:
        return _SIGNAL_NAME_ALIASES[signal_name]
    return signal_name
|
||
|
||
|
||
def get_strategy_params():
    """Return a deep copy of the ``strategy`` section (global strategy constraints)."""
    return _get_section(section_name="strategy")
|
||
|
||
|
||
def get_strategy_direction(default="多头启动"):
    """Return ``strategy.direction``, or ``default`` when the key is absent."""
    params = get_strategy_params()
    return params.get("direction", default)
|
||
|
||
|
||
def is_long_only():
    """Return True when ``strategy.mode`` is exactly "long_only"."""
    mode = get_strategy_params().get("mode", "")
    return mode == "long_only"
|
||
|
||
|
||
def allow_short(default=False):
    """Return ``strategy.allow_short`` coerced with bool(); ``default`` when absent."""
    flag = get_strategy_params().get("allow_short", default)
    return bool(flag)
|
||
|
||
|
||
def get_pa_params():
    """Return a deep copy of the whole ``pa_engine`` section (all PA-engine parameters)."""
    section = "pa_engine"
    return _get_section(section)
|
||
|
||
|
||
def get_pa_section(name=None):
    """Return sub-section ``name`` of ``pa_engine`` ({} if missing).

    With ``name=None`` the whole ``pa_engine`` section is returned.
    """
    return get_pa_params() if name is None else _get_nested("pa_engine", [name], {})
|
||
|
||
|
||
def get_screener_params():
    """Return a deep copy of the whole ``screener`` (coarse-screening) section."""
    return _get_section(section_name="screener")
|
||
|
||
|
||
def get_event_driven_params():
    """Return a deep copy of the ``event_driven`` section.

    These are the parameters for sentiment/event-triggered coin selection.
    """
    section = "event_driven"
    return _get_section(section)
|
||
|
||
|
||
def get_event_driven_section(name=None):
    """Return sub-section ``name`` of ``event_driven`` ({} if missing).

    With ``name=None`` the whole ``event_driven`` section is returned.
    """
    return get_event_driven_params() if name is None else _get_nested("event_driven", [name], {})
|
||
|
||
|
||
def get_screener_section(name=None):
    """Return sub-section ``name`` of ``screener`` ({} if missing).

    With ``name=None`` the whole ``screener`` section is returned.
    """
    return get_screener_params() if name is None else _get_nested("screener", [name], {})
|
||
|
||
|
||
def get_confirm_params():
    """Return a deep copy of the whole ``confirm`` (confirmation layer) section."""
    return _get_section(section_name="confirm")
|
||
|
||
|
||
def get_confirm_section(name=None):
    """Return sub-section ``name`` of ``confirm`` ({} if missing).

    With ``name=None`` the whole ``confirm`` section is returned.
    """
    return get_confirm_params() if name is None else _get_nested("confirm", [name], {})
|
||
|
||
|
||
def get_tracker_params():
    """Return a deep copy of the whole ``tracker`` (tracking layer) section."""
    section = "tracker"
    return _get_section(section)
|
||
|
||
|
||
def get_signal_weights(min_db_samples=3):
    """Return the signal-weight dict (DB dynamic weights, falling back to YAML).

    Base weights come from ``signal_weights`` in rules.yaml. They are
    overridden by the weights returned from ``altcoin_db.get_signal_weights()``
    for every signal whose DB ``total_count`` is at least ``min_db_samples``.

    Compatibility layer:
    - rules are stored under canonical keys (e.g. "1H放量");
    - older scripts may still index with a legacy key
      (e.g. "1H放量(量价背离)") via ``weights[...]``.
    The returned dict therefore exposes the canonical key plus every alias
    key that maps to it, so those callers don't hit KeyError.

    The DB lookup is best-effort. If ``altcoin_db`` cannot be imported or the
    query fails, the YAML weights are used and the failure is logged.
    Malformed DB rows (non-dict, no ``weight``, non-numeric ``total_count``)
    are skipped individually instead of aborting the whole merge.

    Args:
        min_db_samples: minimum DB ``total_count`` before a DB weight replaces
            the YAML one (default 3, the previously hard-coded threshold).
    """
    import logging

    logger = logging.getLogger(__name__)

    # `signal_weights:` with no value parses to None; treat it like missing.
    yaml_weights = copy.deepcopy(load_rules().get("signal_weights") or {})
    canonical = {normalize_signal_name(sig): weight for sig, weight in yaml_weights.items()}

    try:
        from altcoin_db import get_signal_weights as db_get_weights
        db_weights = dict(db_get_weights() or {})
    except ImportError:
        logger.debug("altcoin_db unavailable; using rules.yaml signal weights only")
        db_weights = {}
    except Exception:
        logger.warning(
            "failed to load dynamic signal weights from DB; falling back to rules.yaml",
            exc_info=True,
        )
        db_weights = {}

    for sig, data in db_weights.items():
        if not isinstance(data, dict) or data.get("weight") is None:
            continue
        total = data.get("total_count") or 0
        # Only trust DB weights backed by enough samples.
        if isinstance(total, (int, float)) and total >= min_db_samples:
            canonical[normalize_signal_name(sig)] = data["weight"]

    merged = dict(canonical)
    merged.update(
        {alias: canonical[target] for alias, target in _SIGNAL_NAME_ALIASES.items() if target in canonical}
    )
    return merged
|
||
|
||
|
||
def get_review_params():
    """Return a deep copy of the ``review`` (post-trade review) section."""
    return _get_section(section_name="review")
|
||
|
||
|
||
def get_reverse_params():
    """Return a deep copy of the ``reverse_analysis`` section."""
    section = "reverse_analysis"
    return _get_section(section)
|
||
|
||
|
||
def get_learned_rules(active_only=True):
    """Return a deep copy of the ``learned_rules`` list.

    Args:
        active_only: when True (default), keep only dict entries whose
            ``active`` flag is truthy; a missing flag counts as active.

    A ``learned_rules:`` key with no value (parsed as None) is treated as an
    empty list. Non-dict entries are dropped when filtering instead of
    raising AttributeError.
    """
    learned = copy.deepcopy(load_rules().get("learned_rules") or [])
    if not active_only:
        return learned
    return [rule for rule in learned if isinstance(rule, dict) and rule.get("active", True)]
|
||
|
||
|
||
def add_learned_rule(rule_dict):
    """Append a newly learned rule to rules.yaml and return its id.

    ``rule_dict`` is updated in place with ``id``, ``created``,
    ``hit_count=0``, ``miss_count=0`` and ``active=True``, as before. A deep
    copy of it is what gets stored, so later changes to the caller's dict do
    not leak into the cached rules. ``meta.total_rules_learned`` is set to
    the new list length.

    The id has the form ``rule_<YYYYmmdd_HHMM>_<NNN>``. The sequence number
    is bumped until it does not clash with an existing rule id.
    """
    now = datetime.datetime.now()  # one timestamp so id and created date agree

    # Work on a copy so a failed save cannot leave the shared cache modified.
    rules = copy.deepcopy(load_rules(force_reload=True))
    learned = rules.get("learned_rules") or []

    existing_ids = {r.get("id") for r in learned if isinstance(r, dict)}
    ts = now.strftime("%Y%m%d_%H%M")
    seq = len(learned) + 1
    rule_id = f"rule_{ts}_{seq:03d}"
    while rule_id in existing_ids:
        seq += 1
        rule_id = f"rule_{ts}_{seq:03d}"

    rule_dict["id"] = rule_id
    rule_dict["created"] = now.strftime("%Y-%m-%d")
    rule_dict["hit_count"] = 0
    rule_dict["miss_count"] = 0
    rule_dict["active"] = True

    learned.append(copy.deepcopy(rule_dict))
    rules["learned_rules"] = learned

    meta = rules.get("meta")
    if meta is None:
        meta = rules["meta"] = {}
    meta["total_rules_learned"] = len(learned)

    save_rules(rules)
    return rule_id
|
||
|
||
|
||
def update_learned_rule(rule_id, updates):
    """Update fields of one learned rule (e.g. hit_count / miss_count / active).

    The rule is matched by its ``id`` key, or by ``rule_id`` for entries that
    use that key instead (the same fallback promote_candidate_rule_to_learned_rule
    uses when reporting ids).

    Args:
        rule_id: id of the rule to update.
        updates: mapping of field -> new value applied to the rule.

    Returns:
        True if the rule was found and saved. False if no rule matched; in
        that case rules.yaml is not rewritten.
    """
    # Work on a copy so a failed save cannot leave the shared cache modified.
    rules = copy.deepcopy(load_rules(force_reload=True))
    learned = rules.get("learned_rules") or []

    for rule in learned:
        if isinstance(rule, dict) and (rule.get("id") or rule.get("rule_id")) == rule_id:
            rule.update(updates)
            break
    else:
        return False

    rules["learned_rules"] = learned
    save_rules(rules)
    return True
|
||
|
||
|
||
def update_signal_weight(signal_name, new_weight):
    """Set one signal's weight in rules.yaml, then mirror it to the DB.

    The name is normalised to its canonical key. Any legacy alias keys for
    the same signal are removed from ``signal_weights``. Otherwise a stale
    alias entry that comes later in the YAML would override the new value
    when get_signal_weights() folds aliases onto canonical keys.

    The DB update (``altcoin_db.update_signal_performance``) is best-effort:
    failures are logged and do not undo the YAML write.
    """
    import logging

    logger = logging.getLogger(__name__)
    canonical_name = normalize_signal_name(signal_name)

    # Work on a copy so a failed save cannot leave the shared cache modified.
    rules = copy.deepcopy(load_rules(force_reload=True))
    weights = rules.get("signal_weights")
    if weights is None:
        weights = rules["signal_weights"] = {}

    stale_aliases = [
        key for key in weights
        if key != canonical_name and normalize_signal_name(key) == canonical_name
    ]
    for key in stale_aliases:
        del weights[key]
    weights[canonical_name] = new_weight
    save_rules(rules)

    try:
        from altcoin_db import update_signal_performance
        update_signal_performance(canonical_name, category=None, is_hit=None, pnl=None, weight_override=new_weight)
    except ImportError:
        logger.debug("altcoin_db unavailable; signal weight saved to rules.yaml only")
    except Exception:
        logger.warning(
            "failed to sync weight for signal %r to DB; rules.yaml was updated",
            canonical_name,
            exc_info=True,
        )
|
||
|
||
|
||
def get_meta():
    """Return a copy of the ``meta`` (iteration metadata) section.

    When ``strategy_version`` is missing or empty, the returned copy gets one
    derived as ``v{version}.{iteration_count}``, with defaults 1 and 0.
    Nothing is written back to rules.yaml.
    """
    meta = _get_section("meta")
    if meta.get("strategy_version"):
        return meta
    version_num = meta.get("version", 1)
    iteration = meta.get("iteration_count", 0)
    meta["strategy_version"] = f"v{version_num}.{iteration}"
    return meta
|
||
|
||
|
||
def get_rules_snapshot():
    """Return an independent deep copy of the entire rules.yaml document."""
    current = load_rules()
    snapshot = copy.deepcopy(current)
    return snapshot
|
||
|
||
|
||
def diff_rule_snapshots(before, after, prefix=""):
    """Recursively diff two config snapshots.

    Dicts are walked key by key and reported with dotted paths. Any other
    pair of values, lists included, is compared as a whole with ``!=``.
    ``None`` snapshots are treated as empty dicts.

    Args:
        before: old snapshot.
        after: new snapshot.
        prefix: path prefix prepended to every reported path.

    Returns:
        ``{"changed": [{"path", "old", "new"}, ...],
        "added": [{"path", "new"}, ...],
        "removed": [{"path", "old"}, ...]}``, with values deep-copied.
    """
    result = {"changed": [], "added": [], "removed": []}

    def ordered(keys):
        try:
            return sorted(keys)
        except TypeError:
            # YAML allows mixed key types (e.g. 1 and "a") that cannot be
            # compared directly; fall back to a stable textual order.
            return sorted(keys, key=repr)

    def walk(path, a, b):
        if isinstance(a, dict) and isinstance(b, dict):
            for key in ordered(a.keys() | b.keys()):
                next_path = f"{path}.{key}" if path else str(key)
                if key not in a:
                    result["added"].append({"path": next_path, "new": copy.deepcopy(b[key])})
                elif key not in b:
                    result["removed"].append({"path": next_path, "old": copy.deepcopy(a[key])})
                else:
                    walk(next_path, a[key], b[key])
            return

        # Leaves, lists and type changes are all compared as whole values.
        if a != b:
            result["changed"].append({"path": path, "old": copy.deepcopy(a), "new": copy.deepcopy(b)})

    walk(prefix, before or {}, after or {})
    return result
|
||
|
||
|
||
|
||
def promote_candidate_rule_to_learned_rule(candidate, release_version=""):
    """Write a candidate rule that passed the release gate into learned_rules.

    Candidates come from the DB table strategy_rule_candidate. This is only
    meant to be called once the release gate has passed, so that day-to-day
    research never pollutes the main strategy directly.

    Args:
        candidate: candidate-rule mapping. Reads rule_description, rule_type,
            id, signal_name, source, sample_size and confidence_score.
        release_version: version label to record; defaults to the current
            meta strategy_version.

    Returns:
        One of the following:
        - the new rule's id;
        - the id of an existing learned rule with the same description, in
          which case nothing is written;
        - None when the candidate's description is empty.
    """
    desc = (candidate.get("rule_description") or "").strip()
    if not desc:
        return None

    # Work on a copy so a failed save cannot leave the shared cache modified.
    rules = copy.deepcopy(load_rules(force_reload=True))
    learned_rules = rules.get("learned_rules")
    if learned_rules is None:
        learned_rules = rules["learned_rules"] = []

    for existing in learned_rules:
        if isinstance(existing, dict) and existing.get("description") == desc:
            return existing.get("id") or existing.get("rule_id")

    now = datetime.datetime.now()
    # Resolve the type once so the stored "type" and the sign of
    # "score_adjust" always agree: a candidate without rule_type is a bonus
    # rule and must get +2, not -2.
    rule_type = candidate.get("rule_type") or "bonus"
    rule_id = f"candidate_{candidate.get('id')}_{now.strftime('%Y%m%d%H%M%S')}"
    rule = {
        "id": rule_id,
        "type": rule_type,
        "description": desc,
        "conditions": {
            "candidate_id": candidate.get("id"),
            "signal_name": candidate.get("signal_name") or "",
            "source": candidate.get("source") or "strategy_rule_candidate",
            "sample_size": candidate.get("sample_size") or 0,
            "confidence_score": candidate.get("confidence_score") or 0,
        },
        "score_adjust": 2 if rule_type == "bonus" else -2,
        "source": "candidate_release_gate",
        "release_version": release_version or get_meta().get("strategy_version", ""),
        "created_at": now.isoformat(),
        # Same bookkeeping fields add_learned_rule() initialises, so hit/miss
        # tracking via update_learned_rule() works on promoted rules too.
        "hit_count": 0,
        "miss_count": 0,
        "active": True,
    }

    learned_rules.append(rule)
    save_rules(rules)
    return rule_id
|
||
|
||
|
||
def bump_strategy_patch_version(note=""):
    """Increment the patch part of meta.strategy_version (formal releases only).

    "vX.Y.Z" becomes "vX.Y.(Z+1)". Any other format gets ".1" appended; for
    example, the "v1.0" that get_meta() derives becomes "v1.0.1".

    strategy_version, strategy_revision_note and
    strategy_revision_started_at are written in a single load/save cycle, so
    the three fields can never be left half-updated and rules.yaml is
    rewritten only once.

    Args:
        note: optional release note; stored as "<new_version>: <note>".

    Returns:
        (old_version, new_version)
    """
    import re

    # Work on a copy so a failed save cannot leave the shared cache modified.
    rules = copy.deepcopy(load_rules(force_reload=True))
    meta = get_meta()  # reads the cache that was just refreshed above

    current_ver = str(meta.get("strategy_version") or "v1.0.0").strip()
    m = re.match(r"^v(\d+)\.(\d+)\.(\d+)$", current_ver)
    if m:
        major, minor, patch = map(int, m.groups())
        new_ver = f"v{major}.{minor}.{patch + 1}"
    else:
        new_ver = current_ver + ".1"

    meta_section = rules.get("meta")
    if meta_section is None:
        meta_section = rules["meta"] = {}
    meta_section["strategy_version"] = new_ver
    meta_section["strategy_revision_note"] = f"{new_ver}: {note}" if note else new_ver
    meta_section["strategy_revision_started_at"] = datetime.datetime.now().isoformat()
    save_rules(rules)

    return current_ver, new_ver
|
||
|
||
def update_meta(key, value):
    """Set ``meta[key] = value`` in rules.yaml and save it.

    A missing, or empty (None), ``meta`` section is created. The document is
    deep-copied before modification, so a failed save cannot leave the shared
    cache out of sync with the file.
    """
    rules = copy.deepcopy(load_rules(force_reload=True))
    meta = rules.get("meta")
    if meta is None:
        meta = rules["meta"] = {}
    meta[key] = value
    save_rules(rules)
|
||
|
||
|
||
# === 快捷取值函数(给各模块直接 import 用)===
|
||
|
||
def dynamic_k_thresholds():
    """Return (body_ratio_min, atr_ratio_min) from pa_engine.dynamic_k.

    Raises KeyError if either key is missing.
    """
    section = get_pa_section("dynamic_k")
    body_min = section["body_ratio_min"]
    atr_min = section["atr_ratio_min"]
    return body_min, atr_min
|
||
|
||
|
||
def static_k_thresholds():
    """Return (body_ratio_max, atr_ratio_max) from pa_engine.static_k.

    Raises KeyError if either key is missing.
    """
    section = get_pa_section("static_k")
    return tuple(section[key] for key in ("body_ratio_max", "atr_ratio_max"))
|
||
|
||
|
||
def zone_params():
    """Return (lookback, min_static_count, q_score_breakpoints) from pa_engine.supply_demand.

    Raises KeyError if any of these keys is missing.
    """
    section = get_pa_section("supply_demand")
    return tuple(section[key] for key in ("lookback", "min_static_count", "q_score_breakpoints"))
|
||
|
||
|
||
def ignition_params():
    """Return (lookback, min_static_count, static_search_range, confirm_search_range) from pa_engine.ignition.

    Raises KeyError if any of these keys is missing.
    """
    section = get_pa_section("ignition")
    wanted = ("lookback", "min_static_count", "static_search_range", "confirm_search_range")
    return tuple(section[key] for key in wanted)
|
||
|
||
|
||
def continuous_k_params():
    """Return the pa_engine.continuous_k sub-section ({} if missing)."""
    return get_pa_section(name="continuous_k")
|
||
|
||
|
||
def exhaustion_params():
    """Return the pa_engine.exhaustion sub-section ({} if missing)."""
    return get_pa_section(name="exhaustion")
|
||
|
||
|
||
def entry_point_params():
    """Return the pa_engine.entry_point sub-section ({} if missing)."""
    return get_pa_section(name="entry_point")
|
||
|
||
|
||
def burst_thresholds():
    """Return (main, meme, overbought_multiplier) from screener.burst_threshold.

    Raises KeyError if any of these keys is missing.
    """
    section = get_screener_section("burst_threshold")
    return tuple(section[key] for key in ("main", "meme", "overbought_multiplier"))
|
||
|
||
|
||
def volume_thresholds():
    """Return (min_usd, meme_min_usd) from screener.volume.

    Raises KeyError if either key is missing.
    """
    section = get_screener_section("volume")
    main_min = section["min_usd"]
    meme_min = section["meme_min_usd"]
    return main_min, meme_min
|
||
|
||
|
||
def vp_fly_params():
    """Return the screener.vp_fly sub-section ({} if missing)."""
    return get_screener_section(name="vp_fly")
|
||
|
||
|
||
def bollinger_squeeze_params():
    """Return the screener.bollinger_squeeze sub-section ({} if missing)."""
    return get_screener_section(name="bollinger_squeeze")
|
||
|
||
|
||
def funding_rate_params():
    """Return the screener.funding_rate sub-section ({} if missing)."""
    return get_screener_section(name="funding_rate")
|
||
|
||
|
||
def top_trader_params():
    """Return the screener.top_trader sub-section ({} if missing)."""
    return get_screener_section(name="top_trader")
|
||
|
||
|
||
def state_score_thresholds():
    """Return (accelerate_main, accelerate_meme, accumulate) from screener.state_threshold.

    Raises KeyError if any of these keys is missing.
    """
    section = get_screener_section("state_threshold")
    return tuple(section[key] for key in ("accelerate_main", "accelerate_meme", "accumulate"))
|
||
|
||
|
||
def confirm_min_score():
    """Return confirm.min_score (default 6)."""
    params = get_confirm_params()
    return params.get("min_score", 6)
|
||
|
||
|
||
def confirm_volume_breakout_ratio():
    """Return confirm.volume_breakout_ratio (default 2.0)."""
    params = get_confirm_params()
    return params.get("volume_breakout_ratio", 2.0)
|
||
|
||
|
||
def confirm_state_cooldown_hours():
    """Return confirm.state_cooldown_hours (default 6)."""
    params = get_confirm_params()
    return params.get("state_cooldown_hours", 6)
|
||
|
||
|
||
def confirm_atr_multipliers():
    """Return the confirm.atr_multiplier sub-section ({} if missing)."""
    return get_confirm_section(name="atr_multiplier")
|
||
|
||
|
||
def confirm_stop_loss_params():
    """Return the confirm.stop_loss sub-section ({} if missing)."""
    return get_confirm_section(name="stop_loss")
|
||
|
||
|
||
def get_sentiment_params():
    """Return a deep copy of the whole ``sentiment`` (sentiment monitoring) section."""
    section = "sentiment"
    return _get_section(section)
|
||
|
||
|
||
def sentiment_max_bonus():
    """Return sentiment.max_bonus (default 2)."""
    params = get_sentiment_params()
    return params.get("max_bonus", 2)
|