alphax/app/config/config_loader.py
2026-05-29 10:09:30 +08:00

581 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
策略配置加载器。
rules.yaml 只作为只读 baseline线上运行期变化写 PostgreSQL runtime config
避免代码部署覆盖生产策略迭代状态。
"""
import copy
import datetime
import os
from pathlib import Path
import yaml
from app.config.rules_schema import validate_rules_payload
# Directory that contains the ``app/`` package (parents[2] of this file).
REPO_ROOT = Path(__file__).resolve().parents[2]
RULES_PATH = str(REPO_ROOT.joinpath("rules.yaml"))

# Merged-rules cache (baseline + runtime overrides) and the rules.yaml mtime it was built for.
_cache = _cache_mtime = None
# Validated rules.yaml baseline cache and the file mtime it was read at.
_yaml_cache = _yaml_cache_mtime = None
# Legacy signal labels (spellings used by older code / rules.yaml) mapped to
# their canonical snake_case signal ids. Listed per canonical id; the resulting
# insertion order matches listing every pair one by one.
_SIGNAL_NAME_ALIASES = {
    alias: canonical_id
    for canonical_id, legacy_labels in (
        ("vp_fly_1h_current", ("N倍放量(≥10x)", "N倍放量", "量价齐飞")),
        ("volume_consecutive_1h", ("连续3x放量(≥3根)", "连续3x放量")),
        ("ignition_1h_current", ("静K→动K转折", "静K动K转折")),
        ("breakout_pullback_d1", ("Q≥7供给区突破", "Q7供给区突破")),
        ("volume_divergence_1h", ("1H放量(量价背离)", "量价背离")),
        ("static_accum_4h", ("静K蓄力",)),
        ("top_trader_long", ("大户偏多",)),
        ("sentiment_resonance", ("舆情共振",)),
        ("sector_rotation", ("板块联动",)),
        ("dex_volume_spike", ("DEX 放量", "链上成交放量")),
        ("liquidity_add", ("流动性增加",)),
        ("exchange_outflow", ("交易所流出",)),
        ("whale_accumulation", ("鲸鱼增持",)),
        ("smart_money_buying", ("聪明钱买入",)),
        ("liquidity_remove_risk", ("流动性撤出风险",)),
        ("exchange_inflow_risk", ("交易所流入风险",)),
        ("holder_concentration_risk", ("持仓集中风险",)),
    )
    for alias in legacy_labels
}
def _load_yaml_baseline(force_reload=False):
    """Return the validated rules.yaml baseline, cached by the file's mtime.

    Args:
        force_reload: re-read and re-validate the file even if the cached
            mtime still matches.

    Returns:
        The shared cached dict produced by ``validate_rules_payload``. Callers
        must not mutate it; ``load_rules`` deep-copies it before merging.

    Raises:
        OSError (e.g. FileNotFoundError) when rules.yaml cannot be opened —
        a missing file yields mtime 0, which never matches a real cached mtime,
        so the open below is attempted and fails.
    """
    global _yaml_cache, _yaml_cache_mtime
    mtime = os.path.getmtime(RULES_PATH) if os.path.exists(RULES_PATH) else 0
    # Compare against None rather than truthiness so that an empty but valid
    # baseline is cached too instead of being re-read on every call.
    if not force_reload and _yaml_cache is not None and _yaml_cache_mtime == mtime:
        return _yaml_cache
    with open(RULES_PATH, "r", encoding="utf-8") as f:
        payload = yaml.safe_load(f) or {}
    _yaml_cache = validate_rules_payload(payload)
    _yaml_cache_mtime = mtime
    return _yaml_cache
def _runtime_overrides():
    """Collect the PostgreSQL runtime overrides as one nested dict.

    Reads the strategy override plus the separately stored runtime sections
    (meta, learned_rules, event_driven, event sources, sentiment, monitoring)
    through ``app.db.runtime_config_db``. When event_driven / sentiment /
    monitoring have no stored value yet, the corresponding rules.yaml section
    is handed to ``set_*_config(..., source="seed_from_rules_yaml")`` —
    presumably persisting it — so this function may WRITE on first run.

    The baseline is taken from the module-level ``_yaml_cache``; callers are
    expected to have loaded it first (``load_rules`` does).

    Returns:
        dict to deep-merge over the baseline, or ``{}`` if anything fails
        (the failure is logged as a warning).
    """
    try:
        from app.db.runtime_config_db import (
            deep_merge,
            get_event_driven_config,
            get_event_sources,
            get_learned_rules_config,
            get_monitoring_config,
            get_sentiment_config,
            get_strategy_meta,
            get_strategy_override,
            set_event_driven_config,
            set_event_sources,
            set_monitoring_config,
            set_sentiment_config,
        )

        override = get_strategy_override() or {}
        meta = get_strategy_meta(default=None)
        learned = get_learned_rules_config(default=None)
        event_driven = get_event_driven_config(default=None)
        event_sources = get_event_sources(default=None)
        sentiment = get_sentiment_config(default=None)
        monitoring = get_monitoring_config(default=None)
        baseline = _yaml_cache or {}

        if event_driven is None:
            # Nothing stored yet: seed event_driven from the YAML baseline,
            # folding in any separately stored event sources.
            baseline_event = copy.deepcopy(baseline.get("event_driven", {}))
            if isinstance(baseline_event, dict) and baseline_event:
                if event_sources is not None and isinstance(event_sources, dict):
                    baseline_event = deep_merge(baseline_event, {"sources": event_sources})
                event_driven = set_event_driven_config(baseline_event, source="seed_from_rules_yaml")
                baseline_sources = baseline_event.get("sources", {})
                # Re-queries the DB instead of reusing event_sources, presumably
                # because the seeding call above may already have stored
                # sources — TODO confirm.
                if isinstance(baseline_sources, dict) and baseline_sources and get_event_sources(default=None) is None:
                    set_event_sources(baseline_sources, source="seed_from_rules_yaml")
        elif event_sources is not None and isinstance(event_driven, dict):
            # Separately stored sources are merged over event_driven's own
            # (assuming deep_merge lets the right-hand side win).
            event_driven = deep_merge(event_driven, {"sources": event_sources})

        if sentiment is None:
            baseline_sentiment = copy.deepcopy(baseline.get("sentiment", {}))
            if isinstance(baseline_sentiment, dict) and baseline_sentiment:
                sentiment = set_sentiment_config(baseline_sentiment, source="seed_from_rules_yaml")
        if monitoring is None:
            baseline_monitoring = copy.deepcopy(baseline.get("monitoring", {}))
            if isinstance(baseline_monitoring, dict) and baseline_monitoring:
                monitoring = set_monitoring_config(baseline_monitoring, source="seed_from_rules_yaml")

        # Dedicated runtime sections are layered over the generic strategy override.
        if isinstance(meta, dict) and meta:
            override = deep_merge(override, {"meta": meta})
        if isinstance(learned, list):
            override = deep_merge(override, {"learned_rules": learned})
        if isinstance(event_driven, dict) and event_driven:
            override = deep_merge(override, {"event_driven": event_driven})
        if isinstance(sentiment, dict) and sentiment:
            override = deep_merge(override, {"sentiment": sentiment})
        if isinstance(monitoring, dict) and monitoring:
            override = deep_merge(override, {"monitoring": monitoring})
        return override
    except Exception:
        # Best effort: fall back to the YAML baseline, but leave a trace so a
        # DB outage does not silently drop production strategy overrides.
        import logging

        logging.getLogger(__name__).warning(
            "Runtime config overrides unavailable; using rules.yaml baseline only",
            exc_info=True,
        )
        return {}
def load_rules(force_reload=False):
    """Return the effective rules: rules.yaml baseline + PostgreSQL runtime overrides.

    The merged, validated result is cached and reused while rules.yaml's
    mtime is unchanged. Consequences:

    * Runtime overrides written by another process are only picked up after
      ``force_reload=True`` or a change to rules.yaml.
    * If the runtime overrides cannot be fetched, the baseline-only result is
      cached just the same.

    Args:
        force_reload: rebuild from the YAML file and the DB even if cached.

    Returns:
        The shared cached dict itself — copy it before mutating.
    """
    global _cache, _cache_mtime
    baseline = _load_yaml_baseline(force_reload=force_reload)
    mtime = os.path.getmtime(RULES_PATH) if os.path.exists(RULES_PATH) else 0
    # ``is not None`` (not truthiness) so an empty-but-valid result is cached too.
    if not force_reload and _cache is not None and _cache_mtime == mtime:
        return _cache
    rules = copy.deepcopy(baseline)
    try:
        from app.db.runtime_config_db import deep_merge
        rules = deep_merge(rules, _runtime_overrides())
    except Exception:
        # Best effort: keep the baseline, but do not fail silently.
        import logging

        logging.getLogger(__name__).warning(
            "Could not merge runtime overrides; using rules.yaml baseline only",
            exc_info=True,
        )
    _cache = validate_rules_payload(rules or {})
    _cache_mtime = mtime
    return _cache
def save_rules(rules_dict):
    """Persist ``rules_dict`` as the PostgreSQL runtime override; rules.yaml is not written.

    Only the paths that are changed or added relative to the rules.yaml
    baseline are stored. Keys present in the baseline but absent from
    ``rules_dict`` are NOT recorded as deletions, so they come back on the
    next reload even though the in-process cache (set from ``rules_dict``)
    does not contain them.

    If the DB write fails, the complete ``rules_dict`` is dumped to rules.yaml
    only when ``ALPHAX_ALLOW_YAML_RUNTIME_WRITE=1``. Otherwise RuntimeError is
    raised (chained to the DB error) and the cache is left untouched.
    """
    global _cache, _cache_mtime
    baseline = _load_yaml_baseline(force_reload=True)
    delta = diff_rule_snapshots(baseline, rules_dict)

    # Rebuild a sparse nested dict holding only the differing leaves.
    override = {}
    for entry in delta.get("changed", []) + delta.get("added", []):
        keys = entry["path"].split(".") if entry["path"] else []
        if not keys:
            continue
        *parents, leaf = keys
        cursor = override
        for key in parents:
            cursor = cursor.setdefault(key, {})
        cursor[leaf] = copy.deepcopy(entry.get("new"))

    try:
        from app.db.runtime_config_db import set_strategy_override
        set_strategy_override(override, source="runtime_save_rules")
    except Exception as exc:
        yaml_write_allowed = os.getenv("ALPHAX_ALLOW_YAML_RUNTIME_WRITE", "0").strip() == "1"
        if not yaml_write_allowed:
            raise RuntimeError("Runtime rules writes must go to PostgreSQL; rules.yaml is read-only") from exc
        with open(RULES_PATH, "w", encoding="utf-8") as f:
            yaml.dump(rules_dict, f, default_flow_style=False, allow_unicode=True, sort_keys=False)

    _cache = validate_rules_payload(copy.deepcopy(rules_dict))
    _cache_mtime = os.path.getmtime(RULES_PATH) if os.path.exists(RULES_PATH) else 0
def _get_section(section_name, default=None):
    """Return a deep copy of top-level section ``section_name`` of the effective rules.

    Args:
        section_name: top-level key in the merged rules.
        default: value used when the section is missing OR explicitly ``None``
            (e.g. a YAML key with no value); ``{}`` when not given.

    Treating an explicit ``None`` like a missing section keeps accessors that
    immediately call ``.get()`` on the result from raising AttributeError.
    """
    fallback = {} if default is None else default
    section = load_rules().get(section_name)
    if section is None:
        section = fallback
    return copy.deepcopy(section)
def _get_nested(section_name, path, default=None):
    """Walk ``path`` (a sequence of keys) inside top-level section ``section_name``.

    Returns a deep copy of the value reached, or a deep copy of ``default``
    as soon as a step lands on a non-dict or on a missing / ``None`` value.
    """
    current = load_rules().get(section_name, {})
    for step in path:
        current = current.get(step) if isinstance(current, dict) else None
        if current is None:
            return copy.deepcopy(default)
    return copy.deepcopy(current)
def normalize_signal_name(signal_name):
    """Map a legacy signal label to its canonical id; unknown names pass through unchanged.

    Keeps rules.yaml and older Python code that spell signal names
    differently pointing at the same signal.
    """
    canonical = _SIGNAL_NAME_ALIASES.get(signal_name)
    return signal_name if canonical is None else canonical
def get_strategy_params():
    """Return a deep copy of the global ``strategy`` constraints section."""
    return _get_section("strategy")


def get_strategy_direction(default="多头启动"):
    """Return ``strategy.direction``, or ``default`` when the key is absent."""
    params = get_strategy_params()
    return params.get("direction", default)


def is_long_only():
    """Return True when ``strategy.mode`` is exactly ``"long_only"``."""
    mode = get_strategy_params().get("mode", "")
    return mode == "long_only"


def allow_short(default=False):
    """Return ``strategy.allow_short`` coerced to bool (``default`` when absent)."""
    flag = get_strategy_params().get("allow_short", default)
    return bool(flag)
def get_pa_params():
    """Return the whole ``pa_engine`` (PA engine) section."""
    return _get_section("pa_engine")


def get_pa_section(name=None):
    """Return ``pa_engine[name]`` ({} if absent), or the whole section when ``name`` is None."""
    if name is not None:
        return _get_nested("pa_engine", [name], {})
    return get_pa_params()


def get_screener_params():
    """Return the whole ``screener`` (coarse screening) section."""
    return _get_section("screener")


def get_screener_section(name=None):
    """Return ``screener[name]`` ({} if absent), or the whole section when ``name`` is None."""
    if name is not None:
        return _get_nested("screener", [name], {})
    return get_screener_params()


def get_event_driven_params():
    """Return the whole ``event_driven`` section (event/sentiment-triggered coin selection)."""
    return _get_section("event_driven")


def get_event_driven_section(name=None):
    """Return ``event_driven[name]`` ({} if absent), or the whole section when ``name`` is None."""
    if name is not None:
        return _get_nested("event_driven", [name], {})
    return get_event_driven_params()


def get_confirm_params():
    """Return the whole ``confirm`` (confirmation layer) section."""
    return _get_section("confirm")


def get_confirm_section(name=None):
    """Return ``confirm[name]`` ({} if absent), or the whole section when ``name`` is None."""
    if name is not None:
        return _get_nested("confirm", [name], {})
    return get_confirm_params()


def get_tracker_params():
    """Return the whole ``tracker`` (tracking layer) section."""
    return _get_section("tracker")
def get_signal_weights():
    """Return signal weights keyed by canonical id AND by every legacy alias.

    Baseline weights come from rules.yaml's ``signal_weights`` (the runtime
    strategy_override is deliberately ignored here). A weight from the DB
    ``signal_performance`` data replaces the baseline only when that signal's
    ``total_count`` reaches the sample gate: the larger of
    ``review.min_samples_for_weight`` and ``review.signal_deprecation.min_samples``
    (each defaulting to 12).

    Keys are normalized with ``normalize_signal_name`` (for example
    ``"量价背离"`` -> ``"volume_divergence_1h"``). Then every alias in
    ``_SIGNAL_NAME_ALIASES`` whose target has a weight is added as well, so
    legacy callers indexing ``weights["1H放量(量价背离)"]`` do not hit KeyError.

    DB access is best-effort. If the DB weights or review params cannot be
    read, the baseline alone is returned. A malformed individual DB row is
    skipped instead of discarding every DB weight.
    """
    # Signal weights need a stable baseline. Runtime strategy_override may
    # contain small-sample governance writes; those are only trusted through
    # signal_performance after the sample-size gate below.
    baseline = _load_yaml_baseline()
    yaml_weights = copy.deepcopy(baseline.get("signal_weights") or {})
    canonical = {}
    for sig, weight in yaml_weights.items():
        canonical[normalize_signal_name(sig)] = weight

    try:
        from app.db.altcoin_db import get_signal_weights as db_get_weights
        db_rows = list((db_get_weights() or {}).items())
        review_params = get_review_params()
        deprecation_params = review_params.get("signal_deprecation") or {}
        min_samples = max(
            int(review_params.get("min_samples_for_weight", 12) or 12),
            int(deprecation_params.get("min_samples", 12) or 12),
        )
    except Exception:
        # Best effort: DB or review config unavailable -> baseline weights only.
        db_rows, min_samples = [], 0

    for sig, data in db_rows:
        try:
            trusted = (data.get("total_count") or 0) >= min_samples
            weight = data["weight"] if trusted else None
        except (AttributeError, KeyError, TypeError):
            continue  # malformed row: skip it, keep the other DB weights
        if trusted:
            canonical[normalize_signal_name(sig)] = weight

    merged = dict(canonical)
    for alias, target in _SIGNAL_NAME_ALIASES.items():
        if target in canonical:
            merged[alias] = canonical[target]
    return merged
def get_review_params():
    """Return the whole ``review`` (post-trade review) section."""
    section_name = "review"
    return _get_section(section_name)


def get_reverse_params():
    """Return the whole ``reverse_analysis`` section."""
    section_name = "reverse_analysis"
    return _get_section(section_name)
def get_learned_rules(active_only=True):
    """Return the learned rules, preferring the PostgreSQL copy over the rules file.

    Args:
        active_only: when True (default), keep only dict entries whose
            ``active`` flag is truthy; entries without the flag count as active.

    Returns:
        list of rules. The DB list is used when it is a list. Otherwise a deep
        copy of the effective rules' ``learned_rules`` is used, and a
        missing/``None`` value there yields ``[]``.
    """
    learned = None
    try:
        from app.db.runtime_config_db import get_learned_rules_config
        learned = get_learned_rules_config(default=None)
    except Exception:
        learned = None
    if not isinstance(learned, list):
        learned = copy.deepcopy(load_rules().get("learned_rules") or [])
    if active_only:
        return [r for r in learned if isinstance(r, dict) and r.get("active", True)]
    return learned
def add_learned_rule(rule_dict):
    """Append a new learned rule and return its generated id.

    ``rule_dict`` is modified in place: ``id``, ``created``, ``hit_count``,
    ``miss_count`` and ``active`` are (over)written before it is appended.
    The full list is saved to PostgreSQL; if that fails it is written through
    ``save_rules`` instead. Afterwards ``meta.total_rules_learned`` is set to
    the new list length.
    """
    learned = get_learned_rules(active_only=False)
    # Read the clock once so the id timestamp and the created date agree
    # (two separate reads could straddle midnight).
    now = datetime.datetime.now()
    rule_dict["id"] = f"rule_{now.strftime('%Y%m%d_%H%M')}_{len(learned) + 1:03d}"
    rule_dict["created"] = now.strftime("%Y-%m-%d")
    rule_dict["hit_count"] = 0
    rule_dict["miss_count"] = 0
    rule_dict["active"] = True
    learned.append(rule_dict)
    try:
        from app.db.runtime_config_db import set_learned_rules_config
        set_learned_rules_config(learned, source="add_learned_rule")
    except Exception:
        # load_rules() returns the shared cache; copy it so a failing
        # save_rules() cannot leave unsaved rules in memory.
        rules = copy.deepcopy(load_rules(force_reload=True))
        rules["learned_rules"] = learned
        save_rules(rules)
    update_meta("total_rules_learned", len(learned))
    return rule_dict["id"]
def update_learned_rule(rule_id, updates):
    """Apply ``updates`` (e.g. hit_count / miss_count / active) to one learned rule.

    Args:
        rule_id: value matched against each rule's ``id``.
        updates: mapping of fields to overwrite on the matching rule.

    Returns:
        True if a rule matched and the list was written to PostgreSQL; False
        if no rule has that id (nothing is written in that case).
    """
    learned = get_learned_rules(active_only=False)
    target = next((r for r in learned if r.get("id") == rule_id), None)
    if target is None:
        return False
    target.update(updates)
    from app.db.runtime_config_db import set_learned_rules_config
    set_learned_rules_config(learned, source="update_learned_rule")
    return True
def update_signal_weight(signal_name, new_weight):
    """Set one signal's weight in the runtime override and in signal_performance.

    The name is normalized first, so legacy labels update the canonical id.
    Note that ``get_signal_weights`` reads the rules.yaml baseline and only
    trusts signal_performance rows above its sample gate, so the
    strategy_override write made here is not what it reads. Presumably the
    signal_performance write takes effect there only once the signal has
    enough samples — TODO confirm.

    The signal_performance update is best-effort; the override write via
    ``save_rules`` is not and may raise.
    """
    canonical_name = normalize_signal_name(signal_name)
    # load_rules() returns the shared cache; edit a copy so a failing
    # save_rules() cannot leave an unpersisted weight in memory.
    rules = copy.deepcopy(load_rules(force_reload=True))
    weights = rules.get("signal_weights")
    if not isinstance(weights, dict):
        weights = rules["signal_weights"] = {}
    weights[canonical_name] = new_weight
    save_rules(rules)
    try:
        from app.db.altcoin_db import update_signal_performance
        update_signal_performance(canonical_name, category=None, is_hit=None, pnl=None, weight_override=new_weight)
    except Exception:
        pass  # best effort: the runtime override above is already saved
def get_meta():
    """Return the strategy iteration metadata.

    Source order: the PostgreSQL strategy meta if it is a non-empty dict,
    otherwise the rules' ``meta`` section. When no ``strategy_version`` is
    set, one is derived as ``v{version}.{iteration_count}`` (defaults 1 and 0)
    and placed on the returned dict; it is not persisted here.
    """
    try:
        from app.db.runtime_config_db import get_strategy_meta
        stored = get_strategy_meta(default=None)
        meta = stored if isinstance(stored, dict) and stored else _get_section("meta")
    except Exception:
        meta = _get_section("meta")
    if not meta.get("strategy_version"):
        meta["strategy_version"] = "v{}.{}".format(
            meta.get("version", 1), meta.get("iteration_count", 0)
        )
    return meta
def get_rules_snapshot():
    """Return an independent deep copy of the full effective rules (baseline + runtime overrides)."""
    snapshot = load_rules()
    return copy.deepcopy(snapshot)
def diff_rule_snapshots(before, after, prefix=""):
    """Recursively compare two config snapshots.

    Nested dicts are walked key by key. Any other pair of values (scalars,
    lists, mismatched types) is compared as a whole with ``!=``.

    Args:
        before: old snapshot; a falsy value (e.g. ``None``) is treated as ``{}``.
        after: new snapshot; a falsy value is treated as ``{}``.
        prefix: dotted path prepended to every reported path.

    Returns:
        ``{"changed": [...], "added": [...], "removed": [...]}``. Changed
        items carry ``path``/``old``/``new``, added ones ``path``/``new``,
        and removed ones ``path``/``old``. All values are deep copies. Paths
        join keys with ``"."``, so a key that itself contains a dot produces
        an ambiguous path.
    """
    result = {"changed": [], "added": [], "removed": []}

    def ordered_keys(a, b):
        keys = set(a) | set(b)
        try:
            return sorted(keys)
        except TypeError:
            # YAML allows mixed key types (e.g. ints next to strings), which
            # plain sorted() cannot compare; use a deterministic fallback.
            return sorted(keys, key=lambda k: (type(k).__name__, repr(k)))

    def walk(path, a, b):
        if isinstance(a, dict) and isinstance(b, dict):
            for key in ordered_keys(a, b):
                next_path = f"{path}.{key}" if path else str(key)
                if key not in a:
                    result["added"].append({"path": next_path, "new": copy.deepcopy(b[key])})
                elif key not in b:
                    result["removed"].append({"path": next_path, "old": copy.deepcopy(a[key])})
                else:
                    walk(next_path, a[key], b[key])
        elif a != b:
            result["changed"].append({"path": path, "old": copy.deepcopy(a), "new": copy.deepcopy(b)})

    walk(prefix, before or {}, after or {})
    return result
def _build_candidate_learned_rule(candidate, desc, release_version, now):
    """Build the learned_rules entry for a released candidate (pure; no I/O)."""
    # Normalize once so the stored "type" and "score_adjust" always agree:
    # a candidate without rule_type is a bonus rule and gets +2.
    rule_type = candidate.get("rule_type") or "bonus"
    return {
        "id": f"candidate_{candidate.get('id')}_{now.strftime('%Y%m%d%H%M%S')}",
        "type": rule_type,
        "description": desc,
        "conditions": {
            "candidate_id": candidate.get("id"),
            "signal_name": candidate.get("signal_name") or "",
            "source": candidate.get("source") or "strategy_rule_candidate",
            "sample_size": candidate.get("sample_size") or 0,
            "confidence_score": candidate.get("confidence_score") or 0,
        },
        "score_adjust": 2 if rule_type == "bonus" else -2,
        "source": "candidate_release_gate",
        "release_version": release_version,
        "created_at": now.isoformat(),
        # Same bookkeeping fields that add_learned_rule initializes.
        "active": True,
        "hit_count": 0,
        "miss_count": 0,
    }


def promote_candidate_rule_to_learned_rule(candidate, release_version=""):
    """Write a candidate rule that passed the release gate into learned_rules.

    Candidates come from the DB ``strategy_rule_candidate`` data; this is only
    meant to be called once the release gate has passed, so day-to-day
    research never leaks into the published strategy.

    Args:
        candidate: candidate row as a dict (``rule_description``, ``rule_type``,
            ``id``, ``signal_name``, ``source``, ``sample_size``,
            ``confidence_score`` are read).
        release_version: version to stamp on the rule; defaults to the current
            ``strategy_version`` from ``get_meta()``.

    Returns:
        None when the candidate has no description. The id of an existing
        rule with the same description (nothing is written). Otherwise the new
        rule's id.
    """
    desc = (candidate.get("rule_description") or "").strip()
    if not desc:
        return None
    learned_rules = get_learned_rules(active_only=False)
    for existing in learned_rules:
        if existing.get("description") == desc:
            return existing.get("id") or existing.get("rule_id")
    rule = _build_candidate_learned_rule(
        candidate,
        desc,
        release_version or get_meta().get("strategy_version", ""),
        datetime.datetime.now(),
    )
    learned_rules.append(rule)
    from app.db.runtime_config_db import set_learned_rules_config
    set_learned_rules_config(learned_rules, source="candidate_release_gate")
    return rule["id"]
def _next_patch_version(current_ver):
    """Return the version string one patch step after ``current_ver``.

    ``vX.Y.Z`` -> ``vX.Y.(Z+1)``. Shorter numeric versions are zero-padded to
    three components first (``v1.0`` -> ``v1.0.1``). Longer ones bump their
    last component (``v1.0.0.1`` -> ``v1.0.0.2``), so repeated bumps never grow
    the string. An optional leading ``v`` is preserved. Anything
    non-numeric falls back to appending ``.1``.
    """
    import re

    m = re.fullmatch(r"(v?)(\d+(?:\.\d+)*)", current_ver)
    if not m:
        return current_ver + ".1"
    numbers = [int(part) for part in m.group(2).split(".")]
    numbers += [0] * (3 - len(numbers))
    numbers[-1] += 1
    return m.group(1) + ".".join(str(n) for n in numbers)


def bump_strategy_patch_version(note=""):
    """Bump the strategy's patch version; only meant for formal releases.

    Stores the new ``strategy_version``, a ``strategy_revision_note``
    (``"<new>: <note>"`` or just ``"<new>"``) and
    ``strategy_revision_started_at`` via ``update_meta``.

    Returns:
        ``(old_version, new_version)``. A missing or blank current version is
        treated as ``"v1.0.0"``.
    """
    meta = get_meta()
    current_ver = str(meta.get("strategy_version") or "v1.0.0").strip() or "v1.0.0"
    new_ver = _next_patch_version(current_ver)
    update_meta("strategy_version", new_ver)
    update_meta("strategy_revision_note", f"{new_ver}: {note}" if note else new_ver)
    update_meta("strategy_revision_started_at", datetime.datetime.now().isoformat())
    return current_ver, new_ver
def _read_stored_meta():
    """Return a shallow copy of the stored meta, without get_meta()'s derived version.

    Same source order as get_meta(): the non-empty PostgreSQL meta if it is
    a dict, else the rules' ``meta`` section (``{}`` if that is not a dict).
    """
    try:
        from app.db.runtime_config_db import get_strategy_meta
        meta = get_strategy_meta(default=None)
    except Exception:
        meta = None
    if not isinstance(meta, dict) or not meta:
        meta = _get_section("meta")
    return dict(meta) if isinstance(meta, dict) else {}


def update_meta(key, value):
    """Set ``meta[key] = value`` and persist it.

    The stored meta is written, not get_meta()'s view of it. This way the
    ``v{version}.{iteration_count}`` fallback that get_meta() derives is not
    frozen into storage as a side effect of updating an unrelated key; the
    YAML fallback branch below never persisted it either. When the
    PostgreSQL write fails, the change goes through ``save_rules`` instead.
    """
    meta = _read_stored_meta()
    meta[key] = value
    try:
        from app.db.runtime_config_db import set_strategy_meta
        set_strategy_meta(meta, source="update_meta")
    except Exception:
        # load_rules() returns the shared cache; edit a copy so a failing
        # save_rules() cannot leave an unpersisted value in memory.
        rules = copy.deepcopy(load_rules(force_reload=True))
        section = rules.get("meta")
        if not isinstance(section, dict):
            section = rules["meta"] = {}
        section[key] = value
        save_rules(rules)
# === Shortcut accessors (imported directly by other modules) ===
def dynamic_k_thresholds():
    """Return ``(body_ratio_min, atr_ratio_min)`` from ``pa_engine.dynamic_k``."""
    cfg = get_pa_section("dynamic_k")
    return (cfg["body_ratio_min"], cfg["atr_ratio_min"])


def static_k_thresholds():
    """Return ``(body_ratio_max, atr_ratio_max)`` from ``pa_engine.static_k``."""
    cfg = get_pa_section("static_k")
    return (cfg["body_ratio_max"], cfg["atr_ratio_max"])


def zone_params():
    """Return ``(lookback, min_static_count, q_score_breakpoints)`` from ``pa_engine.supply_demand``."""
    cfg = get_pa_section("supply_demand")
    return tuple(cfg[key] for key in ("lookback", "min_static_count", "q_score_breakpoints"))


def ignition_params():
    """Return ``(lookback, min_static_count, static_search_range, confirm_search_range)`` from ``pa_engine.ignition``."""
    cfg = get_pa_section("ignition")
    keys = ("lookback", "min_static_count", "static_search_range", "confirm_search_range")
    return tuple(cfg[key] for key in keys)


def continuous_k_params():
    """Return the ``pa_engine.continuous_k`` sub-block."""
    return get_pa_section("continuous_k")


def exhaustion_params():
    """Return the ``pa_engine.exhaustion`` sub-block."""
    return get_pa_section("exhaustion")


def entry_point_params():
    """Return the ``pa_engine.entry_point`` sub-block."""
    return get_pa_section("entry_point")
def burst_thresholds():
    """Return ``(main, meme, overbought_multiplier)`` from ``screener.burst_threshold``."""
    cfg = get_screener_section("burst_threshold")
    return tuple(cfg[key] for key in ("main", "meme", "overbought_multiplier"))


def volume_thresholds():
    """Return ``(min_usd, meme_min_usd)`` from ``screener.volume``."""
    cfg = get_screener_section("volume")
    return (cfg["min_usd"], cfg["meme_min_usd"])


def vp_fly_params():
    """Return the ``screener.vp_fly`` sub-block."""
    return get_screener_section("vp_fly")


def bollinger_squeeze_params():
    """Return the ``screener.bollinger_squeeze`` sub-block."""
    return get_screener_section("bollinger_squeeze")


def funding_rate_params():
    """Return the ``screener.funding_rate`` sub-block."""
    return get_screener_section("funding_rate")


def top_trader_params():
    """Return the ``screener.top_trader`` sub-block."""
    return get_screener_section("top_trader")


def state_score_thresholds():
    """Return ``(accelerate_main, accelerate_meme, accumulate)`` from ``screener.state_threshold``."""
    cfg = get_screener_section("state_threshold")
    return tuple(cfg[key] for key in ("accelerate_main", "accelerate_meme", "accumulate"))
def confirm_min_score():
    """Return ``confirm.min_score`` (default 6)."""
    params = get_confirm_params()
    return params.get("min_score", 6)


def confirm_volume_breakout_ratio():
    """Return ``confirm.volume_breakout_ratio`` (default 2.0)."""
    params = get_confirm_params()
    return params.get("volume_breakout_ratio", 2.0)


def confirm_state_cooldown_hours():
    """Return ``confirm.state_cooldown_hours`` (default 6)."""
    params = get_confirm_params()
    return params.get("state_cooldown_hours", 6)


def confirm_atr_multipliers():
    """Return the ``confirm.atr_multiplier`` sub-block."""
    return get_confirm_section("atr_multiplier")


def confirm_stop_loss_params():
    """Return the ``confirm.stop_loss`` sub-block."""
    return get_confirm_section("stop_loss")
def get_sentiment_params():
    """Return the whole ``sentiment`` (sentiment monitoring) section."""
    section_name = "sentiment"
    return _get_section(section_name)


def sentiment_max_bonus():
    """Return ``sentiment.max_bonus`` (default 2)."""
    params = get_sentiment_params()
    return params.get("max_bonus", 2)