alphax/app/config/config_loader.py
2026-05-13 22:49:47 +08:00

433 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
策略配置加载器 — 从 rules.yaml 加载所有参数,支持热更新
review_engine 调整权重后直接写回 yaml，下次运行自动生效
"""
import copy
import datetime
import os
from pathlib import Path
import yaml
# Repository root: the parent of the ``app/`` package, i.e. two directory
# levels above this file's own directory (config/ -> app/ -> root).
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
RULES_PATH = str(REPO_ROOT.joinpath("rules.yaml"))

# In-process cache of the parsed rules.yaml and the file mtime it was read at.
_cache = _cache_mtime = None

# Legacy signal spellings (still used by older Python code) mapped to the
# canonical names used as keys in rules.yaml.
_SIGNAL_NAME_ALIASES = {
    "N倍放量(≥10x)": "N倍放量",
    "连续3x放量(≥3根)": "连续3x放量",
    "静K→动K转折": "静K动K转折",
    "Q≥7供给区突破": "Q7供给区突破",
    "1H放量(量价背离)": "1H放量",
}
def load_rules(force_reload=False):
    """Load rules.yaml, re-reading it whenever the file's mtime changes.

    Args:
        force_reload: when true, bypass the cache and re-read the file.

    Returns:
        The parsed rules mapping (``{}`` for an empty file). This is the shared
        cached object itself, not a copy; the accessor functions in this
        module deep-copy before handing data out.

    Raises:
        FileNotFoundError: if RULES_PATH does not exist.
    """
    global _cache, _cache_mtime
    mtime = os.path.getmtime(RULES_PATH) if os.path.exists(RULES_PATH) else 0
    # Test against None, not truthiness: an empty rules file parses to {},
    # which is falsy, and must still be served from cache until the file changes.
    if not force_reload and _cache is not None and _cache_mtime == mtime:
        return _cache
    with open(RULES_PATH, "r", encoding="utf-8") as f:
        _cache = yaml.safe_load(f) or {}
    _cache_mtime = mtime
    return _cache
def save_rules(rules_dict):
    """Persist *rules_dict* to rules.yaml and refresh the in-process cache.

    Used by review_engine and the mutators in this module after adjusting
    weights or rules. The write is atomic. The YAML is dumped to a temporary
    file in the same directory and then moved over the target with
    ``os.replace``. A concurrent ``load_rules()``, which hot-reloads on mtime
    change, therefore never sees a truncated or half-written file. A failure
    while dumping leaves the previous rules.yaml intact.

    Args:
        rules_dict: the complete rules mapping to write. It becomes the new
            cached object (no copy is made).
    """
    global _cache, _cache_mtime
    import shutil
    import tempfile

    # Resolve symlinks so the real file is replaced rather than the link itself.
    target = os.path.realpath(RULES_PATH)
    fd, tmp_path = tempfile.mkstemp(
        prefix=".rules.", suffix=".yaml.tmp", dir=os.path.dirname(target)
    )
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            yaml.dump(rules_dict, f, default_flow_style=False, allow_unicode=True, sort_keys=False)
        # mkstemp creates the file with mode 0600; keep the existing file's mode.
        if os.path.exists(target):
            shutil.copymode(target, tmp_path)
        os.replace(tmp_path, target)
    except BaseException:
        # Best-effort cleanup of the temp file; the original error is re-raised.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise
    _cache = rules_dict
    _cache_mtime = os.path.getmtime(RULES_PATH)
def _get_section(section_name, default=None):
    """Return a deep copy of the top-level rules.yaml section *section_name*.

    Falls back to *default* (or ``{}`` when no default is given) when the
    section is missing, or present but null. A bare ``screener:`` line in YAML
    parses to None, which would otherwise break every caller that does
    ``.get(...)`` on the result.
    """
    fallback = default if default is not None else {}
    section = load_rules().get(section_name)
    if section is None:
        section = fallback
    return copy.deepcopy(section)
def _get_nested(section_name, path, default=None):
    """Follow the key sequence *path* inside top-level section *section_name*.

    Returns a deep copy of the value reached. Returns a deep copy of *default*
    as soon as an intermediate node is not a dict, or a key is missing or
    maps to None.
    """
    current = load_rules().get(section_name, {})
    for step in path:
        child = current.get(step) if isinstance(current, dict) else None
        if child is None:
            return copy.deepcopy(default)
        current = child
    return copy.deepcopy(current)
def normalize_signal_name(signal_name):
    """Map a legacy signal spelling to its canonical rules.yaml name.

    Names without an alias entry are returned unchanged.
    """
    if signal_name in _SIGNAL_NAME_ALIASES:
        return _SIGNAL_NAME_ALIASES[signal_name]
    return signal_name
def get_strategy_params():
    """Return a copy of the global ``strategy`` constraints section."""
    return _get_section("strategy")


def get_strategy_direction(default="多头启动"):
    """Return ``strategy.direction``, or *default* when it is not set."""
    params = get_strategy_params()
    return params.get("direction", default)


def is_long_only():
    """Return True when ``strategy.mode`` is exactly ``"long_only"``."""
    mode = get_strategy_params().get("mode", "")
    return mode == "long_only"


def allow_short(default=False):
    """Return ``strategy.allow_short`` (or *default*) coerced to bool."""
    flag = get_strategy_params().get("allow_short", default)
    return bool(flag)
def get_pa_params():
    """Return a copy of the whole ``pa_engine`` section (PA engine parameters)."""
    return _get_section("pa_engine")


def get_pa_section(name=None):
    """Return ``pa_engine.<name>``; with ``name=None`` return all of ``pa_engine``.

    A missing sub-section yields ``{}``.
    """
    return get_pa_params() if name is None else _get_nested("pa_engine", [name], {})
def get_screener_params():
    """Return a copy of the whole ``screener`` (coarse screening) section."""
    return _get_section("screener")


def get_event_driven_params():
    """Return a copy of the ``event_driven`` section (sentiment-triggered coin selection)."""
    return _get_section("event_driven")


def get_event_driven_section(name=None):
    """Return ``event_driven.<name>``, or the whole section when *name* is None.

    A missing sub-section yields ``{}``.
    """
    if name is not None:
        return _get_nested("event_driven", [name], {})
    return get_event_driven_params()


def get_screener_section(name=None):
    """Return ``screener.<name>``, or the whole section when *name* is None.

    A missing sub-section yields ``{}``.
    """
    if name is not None:
        return _get_nested("screener", [name], {})
    return get_screener_params()
def get_confirm_params():
    """Return a copy of the whole ``confirm`` (confirmation layer) section."""
    return _get_section("confirm")


def get_confirm_section(name=None):
    """Return ``confirm.<name>``, or the whole section when *name* is None.

    A missing sub-section yields ``{}``.
    """
    if name is not None:
        return _get_nested("confirm", [name], {})
    return get_confirm_params()


def get_tracker_params():
    """Return a copy of the whole ``tracker`` (tracking layer) section."""
    return _get_section("tracker")
def get_signal_weights(min_samples=3):
    """Return signal weights: dynamic DB weights where available, else rules.yaml.

    Resolution per signal:
      1. rules.yaml ``signal_weights``, with keys normalized to canonical names;
      2. overridden by the DB ``signal_performance`` weight when that row has
         ``total_count >= min_samples``.

    Compatibility: rules store canonical keys (e.g. "1H放量"), but older
    scripts still index by legacy keys (e.g. "1H放量(量价背离)"). The result
    therefore exposes the canonical key plus every alias key whose target is
    present, so old callers do not hit KeyError.

    The DB lookup is best-effort. If the DB module cannot be imported, the
    YAML weights are used (logged at DEBUG). If the lookup fails, a warning is
    logged and the YAML weights are used. Malformed DB rows are skipped
    individually instead of aborting the whole override pass.

    Args:
        min_samples: minimum DB ``total_count`` before a DB weight replaces the
            YAML weight. Defaults to 3, the previously hard-coded threshold.

    Returns:
        A new dict ``{signal_name: weight}``.
    """
    import logging

    logger = logging.getLogger(__name__)

    # A null ``signal_weights:`` section parses to None; treat it as empty.
    yaml_weights = load_rules().get("signal_weights") or {}
    canonical = {
        normalize_signal_name(sig): copy.deepcopy(weight)
        for sig, weight in yaml_weights.items()
    }

    try:
        from app.db.altcoin_db import get_signal_weights as db_get_weights
        db_rows = list((db_get_weights() or {}).items())
    except ImportError:
        logger.debug("altcoin_db unavailable; using rules.yaml signal weights only")
        db_rows = []
    except Exception:
        logger.warning("Failed to load DB signal weights; using rules.yaml only", exc_info=True)
        db_rows = []

    for sig, data in db_rows:
        try:
            if data.get("total_count", 0) < min_samples:
                continue
            weight = data["weight"]
        except (AttributeError, KeyError, TypeError):
            logger.warning("Skipping malformed DB signal weight row for %r", sig)
            continue
        canonical[normalize_signal_name(sig)] = weight

    merged = dict(canonical)
    for alias, target in _SIGNAL_NAME_ALIASES.items():
        if target in canonical:
            merged[alias] = canonical[target]
    return merged
def get_review_params():
    """Return a copy of the ``review`` (post-trade review) section."""
    section = _get_section("review")
    return section


def get_reverse_params():
    """Return a copy of the ``reverse_analysis`` section."""
    section = _get_section("reverse_analysis")
    return section
def get_learned_rules(active_only=True):
    """Return a deep copy of the learned rules list.

    Args:
        active_only: when true (default), drop rules whose ``active`` flag is
            falsy. Rules without an ``active`` key count as active.

    Returns:
        A list of rule dicts. A missing or null ``learned_rules:`` section
        yields ``[]``.
    """
    learned = copy.deepcopy(load_rules().get("learned_rules") or [])
    if not active_only:
        return learned
    return [rule for rule in learned if rule.get("active", True)]
def add_learned_rule(rule_dict):
    """Append a newly learned rule to rules.yaml and return its generated id.

    As before, *rule_dict* is stamped in place with ``id``, ``created``,
    ``hit_count``, ``miss_count`` and ``active``. A deep copy is what gets
    stored, so later mutation of the caller's dict cannot leak into the cached
    rules. ``meta.total_rules_learned`` is refreshed to the new rule count.

    Returns:
        The new rule id, formatted ``rule_<YYYYmmdd_HHMM>_<NNN>``.
    """
    rules = load_rules(force_reload=True)
    # A null ``learned_rules:`` / ``meta:`` section parses to None.
    learned = rules.get("learned_rules") or []
    # One timestamp, so the id and the created date cannot straddle a boundary.
    now = datetime.datetime.now()
    ts = now.strftime("%Y%m%d_%H%M")
    existing_ids = {r.get("id") for r in learned if isinstance(r, dict)}
    seq = len(learned) + 1
    rule_id = f"rule_{ts}_{seq:03d}"
    # Guard against collisions, e.g. if rules were removed from the YAML by hand.
    while rule_id in existing_ids:
        seq += 1
        rule_id = f"rule_{ts}_{seq:03d}"
    rule_dict["id"] = rule_id
    rule_dict["created"] = now.strftime("%Y-%m-%d")
    rule_dict["hit_count"] = 0
    rule_dict["miss_count"] = 0
    rule_dict["active"] = True
    learned.append(copy.deepcopy(rule_dict))
    rules["learned_rules"] = learned
    meta = rules.get("meta")
    if meta is None:
        meta = rules["meta"] = {}
    meta["total_rules_learned"] = len(learned)
    save_rules(rules)
    return rule_id
def update_learned_rule(rule_id, updates):
    """Apply *updates* (e.g. hit_count / miss_count / active) to one learned rule.

    Args:
        rule_id: the ``id`` of the rule to update; the first match is used.
        updates: mapping of field -> new value. Values are deep-copied so the
            stored rule does not alias caller-owned objects.

    Returns:
        True if a matching rule was found and rules.yaml was saved. False
        otherwise, in which case the file is not rewritten. Callers that ignore
        the return value are unaffected.
    """
    rules = load_rules(force_reload=True)
    learned = rules.get("learned_rules") or []
    target = next(
        (r for r in learned if isinstance(r, dict) and r.get("id") == rule_id),
        None,
    )
    if target is None:
        return False
    target.update(copy.deepcopy(updates))
    rules["learned_rules"] = learned
    save_rules(rules)
    return True
def update_signal_weight(signal_name, new_weight):
    """Set one signal's weight in rules.yaml (under its canonical key) and mirror it to the DB.

    Any legacy alias keys in ``signal_weights`` that normalize to the same
    canonical name are removed. get_signal_weights() folds aliases into the
    canonical key in file order, so a stale alias listed after the canonical
    key would otherwise silently override the weight written here.

    The DB mirror is best-effort. The YAML change always stands. A missing DB
    module is logged at DEBUG; a failed DB update is logged as a warning.
    """
    import logging

    canonical_name = normalize_signal_name(signal_name)
    rules = load_rules(force_reload=True)
    weights = rules.get("signal_weights")
    if weights is None:  # missing or null ``signal_weights:`` section
        weights = rules["signal_weights"] = {}
    stale_aliases = [
        key for key in weights
        if key != canonical_name and normalize_signal_name(key) == canonical_name
    ]
    for key in stale_aliases:
        del weights[key]
    weights[canonical_name] = new_weight
    save_rules(rules)

    logger = logging.getLogger(__name__)
    try:
        from app.db.altcoin_db import update_signal_performance
        update_signal_performance(canonical_name, category=None, is_hit=None, pnl=None, weight_override=new_weight)
    except ImportError:
        logger.debug("altcoin_db unavailable; signal weight for %s saved to rules.yaml only", canonical_name)
    except Exception:
        logger.warning("Failed to mirror signal weight for %s to DB", canonical_name, exc_info=True)
def get_meta():
    """Return a copy of the ``meta`` (iteration metadata) section.

    When ``strategy_version`` is missing or empty, it is synthesized as
    ``v<version>.<iteration_count>`` (defaults 1 and 0). This happens on the
    returned copy only; rules.yaml is not modified.
    """
    meta = _get_section("meta")
    if meta.get("strategy_version"):
        return meta
    major = meta.get("version", 1)
    iteration_count = meta.get("iteration_count", 0)
    meta["strategy_version"] = f"v{major}.{iteration_count}"
    return meta
def get_rules_snapshot():
    """Return an independent deep copy of the entire parsed rules.yaml."""
    snapshot = load_rules()
    return copy.deepcopy(snapshot)
def diff_rule_snapshots(before, after, prefix=""):
    """Recursively diff two config snapshots.

    Nested dicts are walked key by key. Any other pair of values (scalars,
    lists, or a dict on only one side) is compared as a whole with ``!=``.

    Args:
        before: the old snapshot; a falsy value is treated as ``{}``.
        after: the new snapshot; a falsy value is treated as ``{}``.
        prefix: dotted path prepended to every reported path.

    Returns:
        ``{"changed": [...], "added": [...], "removed": [...]}``. Each entry
        carries a dotted ``path`` plus deep-copied ``old`` and/or ``new``
        values. Keys are visited in sorted order. If a dict mixes key types
        that cannot be compared (YAML allows e.g. int and str keys side by
        side), they are ordered by (type name, str(key)) instead of raising.
    """
    result = {"changed": [], "added": [], "removed": []}

    def ordered_keys(a, b):
        # Union of both dicts' keys in a deterministic order.
        keys = a.keys() | b.keys()
        try:
            return sorted(keys)
        except TypeError:
            return sorted(keys, key=lambda k: (type(k).__name__, str(k)))

    def walk(path, a, b):
        if isinstance(a, dict) and isinstance(b, dict):
            for key in ordered_keys(a, b):
                next_path = f"{path}.{key}" if path else str(key)
                if key not in a:
                    result["added"].append({"path": next_path, "new": copy.deepcopy(b[key])})
                elif key not in b:
                    result["removed"].append({"path": next_path, "old": copy.deepcopy(a[key])})
                else:
                    walk(next_path, a[key], b[key])
            return
        # Scalars, lists and dict-vs-non-dict pairs are compared as a whole.
        if a != b:
            result["changed"].append({"path": path, "old": copy.deepcopy(a), "new": copy.deepcopy(b)})

    walk(prefix, before or {}, after or {})
    return result
def promote_candidate_rule_to_learned_rule(candidate, release_version=""):
    """Write a candidate rule that passed the release gate into learned_rules.

    Candidates come from the DB table strategy_rule_candidate. This should only
    be called once the release gate has passed, so day-to-day research does not
    pollute the main strategy.

    Args:
        candidate: mapping with at least ``rule_description``. Optional keys
            are ``id``, ``rule_type``, ``signal_name``, ``source``,
            ``sample_size`` and ``confidence_score``.
        release_version: version tag to record; defaults to the current
            ``meta.strategy_version``.

    Returns:
        The new rule id. If a learned rule with the same description already
        exists, that rule's id is returned and nothing is written. Returns None
        when the candidate has no description.
    """
    desc = (candidate.get("rule_description") or "").strip()
    if not desc:
        return None
    rules = load_rules(force_reload=True)
    learned_rules = rules.get("learned_rules")
    if learned_rules is None:  # missing or null ``learned_rules:`` section
        learned_rules = rules["learned_rules"] = []
    for existing in learned_rules:
        if existing.get("description") == desc:
            return existing.get("id") or existing.get("rule_id")
    now = datetime.datetime.now()
    # Resolve the type once so "type" and "score_adjust" always agree. Previously
    # a missing rule_type was stored as "bonus" but received a -2 penalty.
    rule_type = candidate.get("rule_type") or "bonus"
    rule_id = f"candidate_{candidate.get('id')}_{now.strftime('%Y%m%d%H%M%S')}"
    rule = {
        "id": rule_id,
        "type": rule_type,
        "description": desc,
        "conditions": {
            "candidate_id": candidate.get("id"),
            "signal_name": candidate.get("signal_name") or "",
            "source": candidate.get("source") or "strategy_rule_candidate",
            "sample_size": candidate.get("sample_size") or 0,
            "confidence_score": candidate.get("confidence_score") or 0,
        },
        "score_adjust": 2 if rule_type == "bonus" else -2,
        "source": "candidate_release_gate",
        "release_version": release_version or get_meta().get("strategy_version", ""),
        "created_at": now.isoformat(),
        # Bookkeeping fields that add_learned_rule stamps on every rule.
        "created": now.strftime("%Y-%m-%d"),
        "hit_count": 0,
        "miss_count": 0,
        "active": True,
    }
    learned_rules.append(rule)
    meta = rules.get("meta")
    if meta is None:
        meta = rules["meta"] = {}
    # Keep the learned-rule counter in sync, as add_learned_rule does.
    meta["total_rules_learned"] = len(learned_rules)
    save_rules(rules)
    return rule_id
def bump_strategy_patch_version(note=""):
    """Bump the patch number of meta.strategy_version; call only on a formal release.

    A version of the form ``vX.Y.Z`` becomes ``vX.Y.(Z+1)``. Any other form
    gets ``.1`` appended (e.g. a synthesized ``v1.0`` becomes ``v1.0.1``).
    ``strategy_version``, ``strategy_revision_note`` and
    ``strategy_revision_started_at`` are written in a single load/modify/save,
    instead of three separate rewrites of rules.yaml.

    Args:
        note: optional release note appended to the revision note.

    Returns:
        ``(old_version, new_version)``.
    """
    import re

    current_ver = str(get_meta().get("strategy_version") or "v1.0.0").strip()
    m = re.match(r"^v(\d+)\.(\d+)\.(\d+)$", current_ver)
    if m:
        major, minor, patch = map(int, m.groups())
        new_ver = f"v{major}.{minor}.{patch + 1}"
    else:
        new_ver = current_ver + ".1"

    rules = load_rules(force_reload=True)
    meta = rules.get("meta")
    if meta is None:  # missing or null ``meta:`` section
        meta = rules["meta"] = {}
    meta["strategy_version"] = new_ver
    meta["strategy_revision_note"] = f"{new_ver}: {note}" if note else new_ver
    meta["strategy_revision_started_at"] = datetime.datetime.now().isoformat()
    save_rules(rules)
    return current_ver, new_ver
def update_meta(key, value):
    """Set ``meta[key] = value`` in rules.yaml and save it.

    A missing or null ``meta:`` section is replaced with a new dict.
    Previously, ``setdefault`` returned None for a null section and the item
    assignment raised TypeError.
    """
    rules = load_rules(force_reload=True)
    meta = rules.get("meta")
    if meta is None:
        meta = rules["meta"] = {}
    meta[key] = value
    save_rules(rules)
# === Shortcut accessors (imported directly by other modules) ===
def dynamic_k_thresholds():
    """Return ``(body_ratio_min, atr_ratio_min)`` from ``pa_engine.dynamic_k``.

    Raises KeyError if either key is missing.
    """
    cfg = get_pa_section("dynamic_k")
    return tuple(cfg[key] for key in ("body_ratio_min", "atr_ratio_min"))


def static_k_thresholds():
    """Return ``(body_ratio_max, atr_ratio_max)`` from ``pa_engine.static_k``.

    Raises KeyError if either key is missing.
    """
    cfg = get_pa_section("static_k")
    return tuple(cfg[key] for key in ("body_ratio_max", "atr_ratio_max"))


def zone_params():
    """Return ``(lookback, min_static_count, q_score_breakpoints)`` from ``pa_engine.supply_demand``."""
    cfg = get_pa_section("supply_demand")
    return tuple(cfg[key] for key in ("lookback", "min_static_count", "q_score_breakpoints"))


def ignition_params():
    """Return ``(lookback, min_static_count, static_search_range, confirm_search_range)`` from ``pa_engine.ignition``."""
    cfg = get_pa_section("ignition")
    wanted = ("lookback", "min_static_count", "static_search_range", "confirm_search_range")
    return tuple(cfg[key] for key in wanted)


def continuous_k_params():
    """Return a copy of ``pa_engine.continuous_k`` (``{}`` if missing)."""
    section = get_pa_section("continuous_k")
    return section


def exhaustion_params():
    """Return a copy of ``pa_engine.exhaustion`` (``{}`` if missing)."""
    section = get_pa_section("exhaustion")
    return section


def entry_point_params():
    """Return a copy of ``pa_engine.entry_point`` (``{}`` if missing)."""
    section = get_pa_section("entry_point")
    return section
def burst_thresholds():
    """Return ``(main, meme, overbought_multiplier)`` from ``screener.burst_threshold``."""
    cfg = get_screener_section("burst_threshold")
    return tuple(cfg[key] for key in ("main", "meme", "overbought_multiplier"))


def volume_thresholds():
    """Return ``(min_usd, meme_min_usd)`` from ``screener.volume``."""
    cfg = get_screener_section("volume")
    return tuple(cfg[key] for key in ("min_usd", "meme_min_usd"))


def vp_fly_params():
    """Return a copy of ``screener.vp_fly`` (``{}`` if missing)."""
    section = get_screener_section("vp_fly")
    return section


def bollinger_squeeze_params():
    """Return a copy of ``screener.bollinger_squeeze`` (``{}`` if missing)."""
    section = get_screener_section("bollinger_squeeze")
    return section


def funding_rate_params():
    """Return a copy of ``screener.funding_rate`` (``{}`` if missing)."""
    section = get_screener_section("funding_rate")
    return section


def top_trader_params():
    """Return a copy of ``screener.top_trader`` (``{}`` if missing)."""
    section = get_screener_section("top_trader")
    return section


def state_score_thresholds():
    """Return ``(accelerate_main, accelerate_meme, accumulate)`` from ``screener.state_threshold``."""
    cfg = get_screener_section("state_threshold")
    return tuple(cfg[key] for key in ("accelerate_main", "accelerate_meme", "accumulate"))
def confirm_min_score():
    """Return ``confirm.min_score`` (default 6)."""
    params = get_confirm_params()
    return params.get("min_score", 6)


def confirm_volume_breakout_ratio():
    """Return ``confirm.volume_breakout_ratio`` (default 2.0)."""
    params = get_confirm_params()
    return params.get("volume_breakout_ratio", 2.0)


def confirm_state_cooldown_hours():
    """Return ``confirm.state_cooldown_hours`` (default 6)."""
    params = get_confirm_params()
    return params.get("state_cooldown_hours", 6)


def confirm_atr_multipliers():
    """Return a copy of ``confirm.atr_multiplier`` (``{}`` if missing)."""
    section = get_confirm_section("atr_multiplier")
    return section


def confirm_stop_loss_params():
    """Return a copy of ``confirm.stop_loss`` (``{}`` if missing)."""
    section = get_confirm_section("stop_loss")
    return section
def get_sentiment_params():
    """Return a copy of the whole ``sentiment`` (sentiment monitoring) section."""
    section = _get_section("sentiment")
    return section


def sentiment_max_bonus():
    """Return ``sentiment.max_bonus`` (default 2)."""
    params = get_sentiment_params()
    return params.get("max_bonus", 2)