"""Strategy configuration loader.

Loads every strategy parameter from ``rules.yaml`` and supports hot reload:
the parsed file is cached and transparently re-read whenever its modification
time changes. review_engine writes adjusted weights straight back into the
YAML (via ``save_rules`` and the ``update_*`` helpers), so the next run picks
them up automatically.
"""

import contextlib
import copy
import datetime
import logging
import os
import re
import shutil
import tempfile

import yaml

logger = logging.getLogger(__name__)

RULES_PATH = os.path.join(os.path.dirname(__file__), "rules.yaml")

# Parsed contents of rules.yaml and the mtime of the file they were read from.
_cache = None
_cache_mtime = None

# Legacy signal-name spellings still used by older Python code, mapped to the
# canonical names stored in rules.yaml.
_SIGNAL_NAME_ALIASES = {
    "N倍放量(≥10x)": "N倍放量",
    "连续3x放量(≥3根)": "连续3x放量",
    "静K→动K转折": "静K动K转折",
    "Q≥7供给区突破": "Q7供给区突破",
    "1H放量(量价背离)": "1H放量",
}

# Minimum number of recorded outcomes (DB ``total_count``) before a
# DB-learned weight overrides the weight configured in rules.yaml.
_MIN_DB_SAMPLES_FOR_WEIGHT = 3


# === Core load / save ===

def load_rules(force_reload=False):
    """Load rules.yaml, re-reading it whenever the file's mtime changes.

    Args:
        force_reload: Bypass the cache and always re-read the file.

    Returns:
        The cached parsed dict itself (not a copy). Callers that want to
        modify it should work on a deep copy (see ``get_rules_snapshot``).

    Raises:
        FileNotFoundError: rules.yaml does not exist.
        yaml.YAMLError: rules.yaml is not valid YAML.
    """
    global _cache, _cache_mtime
    mtime = os.path.getmtime(RULES_PATH) if os.path.exists(RULES_PATH) else 0
    # `is not None` rather than truthiness, so an empty rules file is cached
    # too instead of being re-parsed on every call.
    if not force_reload and _cache is not None and _cache_mtime == mtime:
        return _cache
    with open(RULES_PATH, "r", encoding="utf-8") as f:
        _cache = yaml.safe_load(f) or {}
    _cache_mtime = mtime
    return _cache


def _load_rules_for_update():
    """Re-read rules.yaml from disk and return a private deep copy to modify.

    Mutating a copy (instead of the shared cache) keeps the cache identical to
    the file on disk if the subsequent ``save_rules`` call fails.
    """
    return copy.deepcopy(load_rules(force_reload=True))


def save_rules(rules_dict):
    """Atomically write ``rules_dict`` to rules.yaml and refresh the cache.

    Used by review_engine after adjusting weights. The YAML is written to a
    temporary file in the same directory and then moved over rules.yaml with
    ``os.replace``, so a failed dump never leaves a truncated config behind
    and a concurrent hot-reload sees either the old or the new file.

    ``yaml.safe_dump`` is used so the output is always readable by the
    ``yaml.safe_load`` in ``load_rules``. Values it cannot represent raise
    ``yaml.representer.RepresenterError`` here (and the existing file is left
    untouched) instead of producing a file the next run cannot load.
    """
    global _cache, _cache_mtime
    target = os.path.realpath(RULES_PATH)  # write through a symlinked rules.yaml
    fd, tmp_path = tempfile.mkstemp(
        prefix=".rules-", suffix=".yaml.tmp", dir=os.path.dirname(target)
    )
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            yaml.safe_dump(rules_dict, f, default_flow_style=False,
                           allow_unicode=True, sort_keys=False)
        if os.path.exists(target):
            # mkstemp creates the file with mode 0600; keep the original mode.
            shutil.copymode(target, tmp_path)
        os.replace(tmp_path, target)
    except BaseException:
        with contextlib.suppress(OSError):
            os.remove(tmp_path)
        raise
    # Cache a copy so later mutations of the caller's dict do not leak into it.
    _cache = copy.deepcopy(rules_dict)
    _cache_mtime = os.path.getmtime(RULES_PATH)


def _get_section(section_name, default=None):
    """Return a deep copy of top-level section ``section_name``.

    A missing section, or one present but empty in the YAML (parsed as None),
    yields ``default`` (``{}`` when no default is given), so callers can always
    call ``.get`` on the result.
    """
    section = load_rules().get(section_name)
    if section is None:
        section = {} if default is None else default
    return copy.deepcopy(section)


def _get_nested(section_name, path, default=None):
    """Return a deep copy of ``rules[section_name][path[0]][path[1]]...``.

    Returns a deep copy of ``default`` as soon as an intermediate node is not
    a dict or a key is missing / None.
    """
    node = load_rules().get(section_name, {})
    for key in path:
        if not isinstance(node, dict):
            return copy.deepcopy(default)
        node = node.get(key)
        if node is None:
            return copy.deepcopy(default)
    return copy.deepcopy(node)


def _ensure_meta(rules):
    """Return ``rules["meta"]``, creating an empty dict if missing or None."""
    meta = rules.get("meta")
    if meta is None:
        meta = rules["meta"] = {}
    return meta


def normalize_signal_name(signal_name):
    """Map a legacy signal-name spelling to its canonical rules.yaml name.

    Names without a registered alias are returned unchanged.
    """
    return _SIGNAL_NAME_ALIASES.get(signal_name, signal_name)


# === Section accessors ===

def get_strategy_params():
    """Return the global strategy constraints (``strategy`` section)."""
    return _get_section("strategy")


def get_strategy_direction(default="多头启动"):
    """Return ``strategy.direction``, or ``default`` when not configured."""
    return get_strategy_params().get("direction", default)


def is_long_only():
    """Return True when ``strategy.mode`` is ``"long_only"``."""
    return get_strategy_params().get("mode", "") == "long_only"


def allow_short(default=False):
    """Return ``strategy.allow_short`` as a bool (``default`` when unset)."""
    return bool(get_strategy_params().get("allow_short", default))


def get_pa_params():
    """Return all PA-engine parameters (``pa_engine`` section)."""
    return _get_section("pa_engine")


def get_pa_section(name=None):
    """Return sub-section ``name`` of pa_engine; the whole section if None."""
    if name is None:
        return get_pa_params()
    return _get_nested("pa_engine", [name], {})


def get_screener_params():
    """Return all coarse-screener parameters (``screener`` section)."""
    return _get_section("screener")


def get_event_driven_params():
    """Return the event-driven (news/sentiment triggered) selection parameters."""
    return _get_section("event_driven")


def get_event_driven_section(name=None):
    """Return sub-section ``name`` of event_driven; the whole section if None."""
    if name is None:
        return get_event_driven_params()
    return _get_nested("event_driven", [name], {})


def get_screener_section(name=None):
    """Return sub-section ``name`` of screener; the whole section if None."""
    if name is None:
        return get_screener_params()
    return _get_nested("screener", [name], {})


def get_confirm_params():
    """Return all confirmation-layer parameters (``confirm`` section)."""
    return _get_section("confirm")


def get_confirm_section(name=None):
    """Return sub-section ``name`` of confirm; the whole section if None."""
    if name is None:
        return get_confirm_params()
    return _get_nested("confirm", [name], {})


def get_tracker_params():
    """Return all tracking-layer parameters (``tracker`` section)."""
    return _get_section("tracker")


# === Signal weights ===

def _db_signal_weight_overrides(min_samples=_MIN_DB_SAMPLES_FOR_WEIGHT):
    """Return ``{canonical_signal_name: weight}`` learned in the DB.

    Only signals with at least ``min_samples`` recorded outcomes
    (``total_count``) are included. The DB is optional, so any failure
    degrades to "no overrides" (logged) instead of raising.
    """
    try:
        from altcoin_db import get_signal_weights as db_get_weights
        rows = list((db_get_weights() or {}).items())
    except ImportError:
        logger.debug("altcoin_db unavailable; using rules.yaml signal weights only")
        return {}
    except Exception:
        logger.warning("Could not load signal weights from DB; using rules.yaml weights only",
                       exc_info=True)
        return {}

    overrides = {}
    for sig, data in rows:
        try:
            if (data.get("total_count") or 0) >= min_samples:
                overrides[normalize_signal_name(sig)] = data["weight"]
        except (AttributeError, KeyError, TypeError):
            # One malformed row must not discard the remaining overrides.
            logger.warning("Ignoring malformed DB signal weight row for %r", sig)
    return overrides


def get_signal_weights():
    """Return signal weights, preferring DB-learned weights over rules.yaml.

    Compatibility layer:
      - rules.yaml stores canonical keys (e.g. "1H放量");
      - older scripts may still index with legacy keys
        (e.g. "1H放量(量价背离)") via ``weights[...]``.
    The returned dict therefore exposes both the canonical keys and every
    legacy alias whose target is present, so old callers do not hit KeyError.
    """
    yaml_weights = copy.deepcopy(load_rules().get("signal_weights") or {})
    canonical = {normalize_signal_name(sig): weight for sig, weight in yaml_weights.items()}
    canonical.update(_db_signal_weight_overrides())

    merged = dict(canonical)
    for alias, target in _SIGNAL_NAME_ALIASES.items():
        if target in canonical:
            merged[alias] = canonical[target]
    return merged


def get_review_params():
    """Return the post-trade review parameters (``review`` section)."""
    return _get_section("review")


def get_reverse_params():
    """Return the reverse-analysis parameters (``reverse_analysis`` section)."""
    return _get_section("reverse_analysis")


# === Learned rules ===

def get_learned_rules(active_only=True):
    """Return a deep copy of ``learned_rules``.

    Args:
        active_only: Drop rules whose ``active`` flag is falsy (rules without
            the flag count as active).
    """
    learned = copy.deepcopy(load_rules().get("learned_rules") or [])
    if active_only:
        return [r for r in learned if r.get("active", True)]
    return learned


def add_learned_rule(rule_dict):
    """Append a newly learned rule to rules.yaml and return its generated id.

    ``rule_dict`` is updated in place with ``id``, ``created``, ``hit_count``,
    ``miss_count`` and ``active`` before it is stored.
    """
    rules = _load_rules_for_update()
    learned = rules.get("learned_rules") or []
    now = datetime.datetime.now()
    rule_dict["id"] = f"rule_{now.strftime('%Y%m%d_%H%M')}_{len(learned) + 1:03d}"
    rule_dict["created"] = now.strftime("%Y-%m-%d")
    rule_dict["hit_count"] = 0
    rule_dict["miss_count"] = 0
    rule_dict["active"] = True
    learned.append(rule_dict)
    rules["learned_rules"] = learned
    _ensure_meta(rules)["total_rules_learned"] = len(learned)
    save_rules(rules)
    return rule_dict["id"]


def update_learned_rule(rule_id, updates):
    """Apply ``updates`` (e.g. hit_count / miss_count / active) to one rule.

    Returns:
        True if a rule with ``rule_id`` was found and saved, False otherwise
        (in which case rules.yaml is not rewritten).
    """
    rules = _load_rules_for_update()
    for rule in rules.get("learned_rules") or []:
        if rule.get("id") == rule_id:
            rule.update(updates)
            save_rules(rules)
            return True
    return False


def update_signal_weight(signal_name, new_weight):
    """Set one signal's weight in rules.yaml, then best-effort sync it to the DB.

    ``signal_name`` may be a legacy alias; it is stored under its canonical
    name. A DB failure is logged and does not undo the yaml update.
    """
    canonical_name = normalize_signal_name(signal_name)
    rules = _load_rules_for_update()
    if rules.get("signal_weights") is None:
        rules["signal_weights"] = {}
    rules["signal_weights"][canonical_name] = new_weight
    save_rules(rules)
    try:
        from altcoin_db import update_signal_performance
        update_signal_performance(canonical_name, category=None, is_hit=None, pnl=None,
                                  weight_override=new_weight)
    except ImportError:
        logger.debug("altcoin_db unavailable; weight for %s saved to rules.yaml only",
                     canonical_name)
    except Exception:
        logger.warning("Failed to sync weight for %s to DB (rules.yaml already updated)",
                       canonical_name, exc_info=True)


# === Meta / versioning ===

def get_meta():
    """Return iteration metadata (``meta`` section, deep copy).

    When ``strategy_version`` is not set it is derived as
    ``v{version}.{iteration_count}`` (defaults 1 and 0).
    """
    meta = _get_section("meta")
    if not meta.get("strategy_version"):
        version_num = meta.get("version", 1)
        iteration = meta.get("iteration_count", 0)
        meta["strategy_version"] = f"v{version_num}.{iteration}"
    return meta


def get_rules_snapshot():
    """Return a deep copy of the full rules.yaml contents."""
    return copy.deepcopy(load_rules())


def _sorted_keys(keys):
    """Sort dict keys; fall back to ``repr`` order when key types are mixed."""
    try:
        return sorted(keys)
    except TypeError:
        # YAML allows e.g. int and str keys side by side, which are unorderable.
        return sorted(keys, key=repr)


def diff_rule_snapshots(before, after, prefix=""):
    """Recursively diff two config snapshots.

    Dicts are compared key by key; lists and scalars (and any type change
    between snapshots) are compared as whole values. Paths are dotted
    (``"a.b.c"``) and start with ``prefix`` when given.

    Returns:
        ``{"changed": [{"path", "old", "new"}], "added": [{"path", "new"}],
        "removed": [{"path", "old"}]}``; all values are deep copies.
    """
    result = {"changed": [], "added": [], "removed": []}

    def walk(path, a, b):
        if isinstance(a, dict) and isinstance(b, dict):
            for key in _sorted_keys(set(a) | set(b)):
                next_path = f"{path}.{key}" if path else str(key)
                if key not in a:
                    result["added"].append({"path": next_path, "new": copy.deepcopy(b[key])})
                elif key not in b:
                    result["removed"].append({"path": next_path, "old": copy.deepcopy(a[key])})
                else:
                    walk(next_path, a[key], b[key])
            return
        if a != b:
            result["changed"].append({"path": path, "old": copy.deepcopy(a), "new": copy.deepcopy(b)})

    walk(prefix, before or {}, after or {})
    return result


def promote_candidate_rule_to_learned_rule(candidate, release_version=""):
    """Write a candidate rule that passed the release gate into learned_rules.

    Candidates come from the DB table strategy_rule_candidate. Only call this
    when the release gate has passed, so day-to-day research does not pollute
    the main strategy.

    Returns:
        The new rule id; the id of an existing rule with the same description
        (nothing is written in that case); or None when the candidate has no
        ``rule_description``.
    """
    desc = (candidate.get("rule_description") or "").strip()
    if not desc:
        return None

    rules = _load_rules_for_update()
    learned_rules = rules.get("learned_rules")
    if learned_rules is None:
        learned_rules = rules["learned_rules"] = []
    for existing in learned_rules:
        if existing.get("description") == desc:
            return existing.get("id") or existing.get("rule_id")

    now = datetime.datetime.now()
    rule_type = candidate.get("rule_type") or "bonus"
    rule_id = f"candidate_{candidate.get('id')}_{now.strftime('%Y%m%d%H%M%S')}"
    rule = {
        "id": rule_id,
        "type": rule_type,
        "description": desc,
        "conditions": {
            "candidate_id": candidate.get("id"),
            "signal_name": candidate.get("signal_name") or "",
            "source": candidate.get("source") or "strategy_rule_candidate",
            "sample_size": candidate.get("sample_size") or 0,
            "confidence_score": candidate.get("confidence_score") or 0,
        },
        # Derived from the resolved type, so an untyped candidate (stored as
        # "bonus") is not given a penalty.
        "score_adjust": 2 if rule_type == "bonus" else -2,
        "source": "candidate_release_gate",
        "release_version": release_version or get_meta().get("strategy_version", ""),
        "created_at": now.isoformat(),
        # Same bookkeeping fields add_learned_rule initialises.
        "hit_count": 0,
        "miss_count": 0,
        "active": True,
    }
    learned_rules.append(rule)
    _ensure_meta(rules)["total_rules_learned"] = len(learned_rules)
    save_rules(rules)
    return rule_id


def _update_meta_fields(fields):
    """Merge ``fields`` into ``meta`` and save rules.yaml in a single write."""
    rules = _load_rules_for_update()
    _ensure_meta(rules).update(fields)
    save_rules(rules)


def bump_strategy_patch_version(note=""):
    """Bump the patch version; call only on a formal release.

    ``vX.Y.Z`` becomes ``vX.Y.(Z+1)``; any other format gets ``.1`` appended.
    Version, revision note and start time are written in one save.

    Returns:
        ``(old_version, new_version)``.
    """
    meta = get_meta()
    current_ver = str(meta.get("strategy_version") or "v1.0.0").strip()
    m = re.match(r"^v(\d+)\.(\d+)\.(\d+)$", current_ver)
    if m:
        major, minor, patch = map(int, m.groups())
        new_ver = f"v{major}.{minor}.{patch + 1}"
    else:
        new_ver = current_ver + ".1"
    _update_meta_fields({
        "strategy_version": new_ver,
        "strategy_revision_note": f"{new_ver}: {note}" if note else new_ver,
        "strategy_revision_started_at": datetime.datetime.now().isoformat(),
    })
    return current_ver, new_ver


def update_meta(key, value):
    """Set ``meta[key] = value`` in rules.yaml."""
    _update_meta_fields({key: value})


# === Shortcut accessors (imported directly by the other modules) ===
# The tuple-returning helpers index required keys and raise KeyError when
# rules.yaml lacks them.

def dynamic_k_thresholds():
    p = get_pa_section("dynamic_k")
    return p["body_ratio_min"], p["atr_ratio_min"]


def static_k_thresholds():
    p = get_pa_section("static_k")
    return p["body_ratio_max"], p["atr_ratio_max"]


def zone_params():
    p = get_pa_section("supply_demand")
    return p["lookback"], p["min_static_count"], p["q_score_breakpoints"]


def ignition_params():
    p = get_pa_section("ignition")
    return p["lookback"], p["min_static_count"], p["static_search_range"], p["confirm_search_range"]


def continuous_k_params():
    return get_pa_section("continuous_k")


def exhaustion_params():
    return get_pa_section("exhaustion")


def entry_point_params():
    return get_pa_section("entry_point")


def burst_thresholds():
    p = get_screener_section("burst_threshold")
    return p["main"], p["meme"], p["overbought_multiplier"]


def volume_thresholds():
    p = get_screener_section("volume")
    return p["min_usd"], p["meme_min_usd"]


def vp_fly_params():
    return get_screener_section("vp_fly")


def bollinger_squeeze_params():
    return get_screener_section("bollinger_squeeze")


def funding_rate_params():
    return get_screener_section("funding_rate")


def top_trader_params():
    return get_screener_section("top_trader")


def state_score_thresholds():
    p = get_screener_section("state_threshold")
    return p["accelerate_main"], p["accelerate_meme"], p["accumulate"]


def confirm_min_score():
    return get_confirm_params().get("min_score", 6)


def confirm_volume_breakout_ratio():
    return get_confirm_params().get("volume_breakout_ratio", 2.0)


def confirm_state_cooldown_hours():
    return get_confirm_params().get("state_cooldown_hours", 6)


def confirm_atr_multipliers():
    return get_confirm_section("atr_multiplier")


def confirm_stop_loss_params():
    return get_confirm_section("stop_loss")


def get_sentiment_params():
    """Return all sentiment-monitoring parameters (``sentiment`` section)."""
    return _get_section("sentiment")


def sentiment_max_bonus():
    return get_sentiment_params().get("max_bonus", 2)