581 lines
20 KiB
Python
581 lines
20 KiB
Python
"""
|
||
策略配置加载器。
|
||
|
||
rules.yaml 只作为只读 baseline;线上运行期变化写 PostgreSQL runtime config,
|
||
避免代码部署覆盖生产策略迭代状态。
|
||
"""
|
||
import copy
|
||
import datetime
|
||
import os
|
||
from pathlib import Path
|
||
|
||
import yaml
|
||
|
||
from app.config.rules_schema import validate_rules_payload
|
||
|
||
# Repository root: ``parents[2]`` of this module's resolved path, i.e. two
# directory levels above the directory that contains this file.
REPO_ROOT = Path(__file__).resolve().parents[2]
# Read-only strategy baseline shipped with the code.
RULES_PATH = str(REPO_ROOT.joinpath("rules.yaml"))

# Merged rules (baseline + runtime override) and the rules.yaml mtime they were built at.
_cache, _cache_mtime = None, None
# Validated rules.yaml baseline and the rules.yaml mtime it was parsed at.
_yaml_cache, _yaml_cache_mtime = None, None
|
||
|
||
# Legacy signal spellings (older Python code / rules.yaml wording) mapped to
# canonical signal ids. Listed per canonical id; flattening keeps the aliases
# in exactly this order.
_SIGNAL_NAME_ALIASES = {
    alias: canonical_id
    for canonical_id, aliases in (
        ("vp_fly_1h_current", ("N倍放量(≥10x)", "N倍放量", "量价齐飞")),
        ("volume_consecutive_1h", ("连续3x放量(≥3根)", "连续3x放量")),
        ("ignition_1h_current", ("静K→动K转折", "静K动K转折")),
        ("breakout_pullback_d1", ("Q≥7供给区突破", "Q7供给区突破")),
        ("volume_divergence_1h", ("1H放量(量价背离)", "量价背离")),
        ("static_accum_4h", ("静K蓄力",)),
        ("top_trader_long", ("大户偏多",)),
        ("sentiment_resonance", ("舆情共振",)),
        ("sector_rotation", ("板块联动",)),
        ("dex_volume_spike", ("DEX 放量", "链上成交放量")),
        ("liquidity_add", ("流动性增加",)),
        ("exchange_outflow", ("交易所流出",)),
        ("whale_accumulation", ("鲸鱼增持",)),
        ("smart_money_buying", ("聪明钱买入",)),
        ("liquidity_remove_risk", ("流动性撤出风险",)),
        ("exchange_inflow_risk", ("交易所流入风险",)),
        ("holder_concentration_risk", ("持仓集中风险",)),
    )
    for alias in aliases
}
|
||
|
||
|
||
def _load_yaml_baseline(force_reload=False):
    """Load, validate and cache the read-only rules.yaml baseline.

    The validated result is cached together with the file's mtime and is
    re-read when the mtime changes or ``force_reload`` is true. An empty YAML
    document is treated as ``{}`` before validation.

    Returns:
        The value returned by ``validate_rules_payload`` (used as a dict by
        the rest of this module).

    Raises:
        OSError (e.g. FileNotFoundError) if rules.yaml cannot be opened, plus
        anything raised by YAML parsing or ``validate_rules_payload``.
    """
    global _yaml_cache, _yaml_cache_mtime
    mtime = os.path.getmtime(RULES_PATH) if os.path.exists(RULES_PATH) else 0
    # Test against None rather than truthiness so a legitimately empty
    # baseline is served from cache instead of being re-parsed on every call.
    if not force_reload and _yaml_cache is not None and _yaml_cache_mtime == mtime:
        return _yaml_cache
    with open(RULES_PATH, "r", encoding="utf-8") as f:
        _yaml_cache = validate_rules_payload(yaml.safe_load(f) or {})
    _yaml_cache_mtime = mtime
    return _yaml_cache
|
||
|
||
|
||
def _runtime_overrides():
    """Build the runtime (PostgreSQL) override dict that load_rules merges over rules.yaml.

    Starts from ``get_strategy_override()`` and deep-merges, under their
    top-level keys: ``meta`` (if a non-empty dict), ``learned_rules`` (if a
    list, even an empty one), and ``event_driven`` / ``sentiment`` /
    ``monitoring`` (if non-empty dicts). Separately stored event sources are
    folded into ``event_driven["sources"]``.

    Side effect (seeding): when the DB has no event_driven, sentiment or
    monitoring config yet, the matching non-empty rules.yaml section is
    written to the DB with source "seed_from_rules_yaml". The rules.yaml data
    comes from the module-level ``_yaml_cache``, so ``_load_yaml_baseline``
    must already have run (``load_rules`` calls it first); otherwise nothing
    is seeded.

    Returns:
        The override dict, or ``{}`` if anything raises (including a failed
        import of the DB layer). Errors are swallowed without logging.
    """
    try:
        from app.db.runtime_config_db import (
            deep_merge,
            get_event_driven_config,
            get_event_sources,
            get_learned_rules_config,
            get_monitoring_config,
            get_sentiment_config,
            get_strategy_meta,
            get_strategy_override,
            set_event_driven_config,
            set_event_sources,
            set_monitoring_config,
            set_sentiment_config,
        )
        override = get_strategy_override() or {}
        # default=None lets "not stored in the DB yet" be told apart from an empty value.
        meta = get_strategy_meta(default=None)
        learned = get_learned_rules_config(default=None)
        event_driven = get_event_driven_config(default=None)
        event_sources = get_event_sources(default=None)
        sentiment = get_sentiment_config(default=None)
        monitoring = get_monitoring_config(default=None)
        # Parsed rules.yaml from the module cache ({} if it was never loaded).
        baseline = _yaml_cache or {}
        if event_driven is None:
            # No event_driven config in the DB yet: seed it from rules.yaml,
            # folding in any separately stored event sources first.
            baseline_event = copy.deepcopy(baseline.get("event_driven", {}))
            if isinstance(baseline_event, dict) and baseline_event:
                if event_sources is not None and isinstance(event_sources, dict):
                    baseline_event = deep_merge(baseline_event, {"sources": event_sources})
                # NOTE(review): assumes set_* returns the stored config — confirm.
                event_driven = set_event_driven_config(baseline_event, source="seed_from_rules_yaml")
                baseline_sources = baseline_event.get("sources", {})
                # Also seed the standalone event-sources entry if it is still absent.
                if isinstance(baseline_sources, dict) and baseline_sources and get_event_sources(default=None) is None:
                    set_event_sources(baseline_sources, source="seed_from_rules_yaml")
        elif event_sources is not None and isinstance(event_driven, dict):
            # Fold the separately stored sources into the DB event_driven config
            # (precedence depends on deep_merge's argument order — confirm).
            event_driven = deep_merge(event_driven, {"sources": event_sources})
        if sentiment is None:
            # Seed sentiment config from rules.yaml on first use.
            baseline_sentiment = copy.deepcopy(baseline.get("sentiment", {}))
            if isinstance(baseline_sentiment, dict) and baseline_sentiment:
                sentiment = set_sentiment_config(baseline_sentiment, source="seed_from_rules_yaml")
        if monitoring is None:
            # Seed monitoring config from rules.yaml on first use.
            baseline_monitoring = copy.deepcopy(baseline.get("monitoring", {}))
            if isinstance(baseline_monitoring, dict) and baseline_monitoring:
                monitoring = set_monitoring_config(baseline_monitoring, source="seed_from_rules_yaml")
        # Nest each separately stored config under its top-level rules key.
        if isinstance(meta, dict) and meta:
            override = deep_merge(override, {"meta": meta})
        if isinstance(learned, list):
            override = deep_merge(override, {"learned_rules": learned})
        if isinstance(event_driven, dict) and event_driven:
            override = deep_merge(override, {"event_driven": event_driven})
        if isinstance(sentiment, dict) and sentiment:
            override = deep_merge(override, {"sentiment": sentiment})
        if isinstance(monitoring, dict) and monitoring:
            override = deep_merge(override, {"monitoring": monitoring})
        return override
    except Exception:
        # Best effort: any DB/import failure means "no runtime overrides".
        return {}
|
||
|
||
|
||
def load_rules(force_reload=False):
    """Return the effective rules: rules.yaml baseline deep-merged with PostgreSQL runtime overrides.

    The merged, validated result is cached and reused while rules.yaml's mtime
    is unchanged and ``force_reload`` is false. Runtime overrides that change
    only in the database do NOT invalidate this cache. If the DB layer cannot
    be imported or merging raises, the baseline alone is used.

    NOTE: the cached object itself is returned, not a copy; deep-copy it
    before mutating.
    """
    global _cache, _cache_mtime
    # Read the mtime *before* (re)loading the baseline. If rules.yaml changes in
    # between, the cache is tagged with the older mtime and rebuilt on the next
    # call, rather than stale content being tagged with the new mtime.
    mtime = os.path.getmtime(RULES_PATH) if os.path.exists(RULES_PATH) else 0
    # `is not None` so that an empty merged result is cached as well.
    if not force_reload and _cache is not None and _cache_mtime == mtime:
        return _cache
    baseline = _load_yaml_baseline(force_reload=force_reload)
    rules = copy.deepcopy(baseline)
    try:
        from app.db.runtime_config_db import deep_merge

        rules = deep_merge(rules, _runtime_overrides())
    except Exception:
        # Best effort: fall back to the baseline alone.
        pass
    _cache = validate_rules_payload(rules or {})
    _cache_mtime = mtime
    return _cache
|
||
|
||
|
||
def _changed_overlay(base, new):
    """Return the part of ``new`` that differs from ``base``, as a nested dict.

    Recurses only where both sides hold dicts. Any other differing value, and
    any key absent from ``base``, is deep-copied whole. Keys are kept exactly
    as they are, so non-string keys and keys containing "." survive. Keys
    present only in ``base`` are ignored.
    """
    if not (isinstance(base, dict) and isinstance(new, dict)):
        return {}
    out = {}
    for key, value in new.items():
        if key not in base:
            out[key] = copy.deepcopy(value)
        elif isinstance(base[key], dict) and isinstance(value, dict):
            nested = _changed_overlay(base[key], value)
            if nested:
                out[key] = nested
        elif base[key] != value:
            out[key] = copy.deepcopy(value)
    return out


def save_rules(rules_dict):
    """Persist runtime strategy changes to PostgreSQL; rules.yaml is not written.

    Only values in ``rules_dict`` that differ from, or are absent in, the
    rules.yaml baseline are stored as the strategy override. Keys present in
    the baseline but missing from ``rules_dict`` are not recorded, so removals
    cannot be expressed through this function.

    If the database write fails, rules.yaml is rewritten with ``rules_dict``
    only when the environment variable ALPHAX_ALLOW_YAML_RUNTIME_WRITE is "1".
    Otherwise a RuntimeError chained to the original error is raised.

    On success the in-process rules cache is replaced by the validated
    ``rules_dict``.
    """
    global _cache, _cache_mtime
    # Validate first so a payload rejected by the schema is never persisted.
    # NOTE(review): assumes validate_rules_payload has no side effects — confirm.
    validated = validate_rules_payload(copy.deepcopy(rules_dict))
    baseline = _load_yaml_baseline(force_reload=True)
    # Built structurally instead of via dotted paths, which would split keys
    # containing "." and turn non-string keys into strings.
    override = _changed_overlay(baseline or {}, rules_dict or {})
    try:
        from app.db.runtime_config_db import set_strategy_override

        set_strategy_override(override, source="runtime_save_rules")
    except Exception as exc:
        # Escape hatch for environments without the DB layer; off by default.
        if os.getenv("ALPHAX_ALLOW_YAML_RUNTIME_WRITE", "0").strip() == "1":
            with open(RULES_PATH, "w", encoding="utf-8") as f:
                yaml.dump(rules_dict, f, default_flow_style=False, allow_unicode=True, sort_keys=False)
        else:
            raise RuntimeError("Runtime rules writes must go to PostgreSQL; rules.yaml is read-only") from exc
    _cache = validated
    _cache_mtime = os.path.getmtime(RULES_PATH) if os.path.exists(RULES_PATH) else 0
|
||
|
||
|
||
def _get_section(section_name, default=None):
    """Return a deep copy of top-level section ``section_name`` of the merged rules.

    A missing section, or one whose value is None (e.g. a YAML key with no
    value), yields ``default``, or ``{}`` when ``default`` is None. This means
    accessors that chain ``.get(...)`` on the result never receive None,
    mirroring how ``_get_nested`` treats None as "missing".
    """
    section = load_rules().get(section_name)
    if section is None:
        section = {} if default is None else default
    return copy.deepcopy(section)
|
||
|
||
|
||
def _get_nested(section_name, path, default=None):
    """Follow the keys in ``path`` inside top-level section ``section_name``.

    Returns a deep copy of the value reached. Returns a deep copy of
    ``default`` if some intermediate node is not a dict, or if a key is missing
    or maps to None.
    """
    current = load_rules().get(section_name, {})
    for step in path:
        current = current.get(step) if isinstance(current, dict) else None
        if current is None:
            return copy.deepcopy(default)
    return copy.deepcopy(current)
|
||
|
||
|
||
def normalize_signal_name(signal_name):
    """Map a legacy signal spelling to its canonical id.

    Names not listed in ``_SIGNAL_NAME_ALIASES`` are returned unchanged. This
    reconciles the wording used in rules.yaml with that of older Python code.
    """
    try:
        return _SIGNAL_NAME_ALIASES[signal_name]
    except KeyError:
        return signal_name
|
||
|
||
|
||
def get_strategy_params():
    """Return a copy of the global strategy constraints (the ``strategy`` section)."""
    section = _get_section("strategy")
    return section


def get_strategy_direction(default="多头启动"):
    """Return ``strategy.direction``, or ``default`` when it is not configured."""
    params = get_strategy_params()
    return params.get("direction", default)


def is_long_only():
    """True when ``strategy.mode`` equals ``"long_only"``."""
    mode = get_strategy_params().get("mode", "")
    return mode == "long_only"


def allow_short(default=False):
    """Return ``strategy.allow_short`` as a bool (``default`` when not configured)."""
    flag = get_strategy_params().get("allow_short", default)
    return bool(flag)
|
||
|
||
|
||
def get_pa_params():
    """Return a copy of every PA-engine parameter (the ``pa_engine`` section)."""
    section = _get_section("pa_engine")
    return section


def get_pa_section(name=None):
    """Return sub-block ``name`` of ``pa_engine`` ({} if absent); the whole section when ``name`` is None."""
    return get_pa_params() if name is None else _get_nested("pa_engine", [name], {})


def get_screener_params():
    """Return a copy of every coarse-screener parameter (the ``screener`` section)."""
    section = _get_section("screener")
    return section


def get_event_driven_params():
    """Return a copy of the sentiment-triggered, event-driven coin-selection parameters."""
    section = _get_section("event_driven")
    return section


def get_event_driven_section(name=None):
    """Return sub-block ``name`` of ``event_driven`` ({} if absent); the whole section when ``name`` is None."""
    return get_event_driven_params() if name is None else _get_nested("event_driven", [name], {})


def get_screener_section(name=None):
    """Return sub-block ``name`` of ``screener`` ({} if absent); the whole section when ``name`` is None."""
    return get_screener_params() if name is None else _get_nested("screener", [name], {})


def get_confirm_params():
    """Return a copy of every confirmation-layer parameter (the ``confirm`` section)."""
    section = _get_section("confirm")
    return section


def get_confirm_section(name=None):
    """Return sub-block ``name`` of ``confirm`` ({} if absent); the whole section when ``name`` is None."""
    return get_confirm_params() if name is None else _get_nested("confirm", [name], {})


def get_tracker_params():
    """Return a copy of every tracker-layer parameter (the ``tracker`` section)."""
    section = _get_section("tracker")
    return section
|
||
|
||
|
||
def get_signal_weights():
    """Return signal weights keyed by normalized signal name, plus legacy alias keys.

    The baseline comes from ``signal_weights`` in rules.yaml, with keys passed
    through ``normalize_signal_name``. A weight from
    ``app.db.altcoin_db.get_signal_weights`` replaces the baseline value only
    when that signal's ``total_count`` reaches the sample-size gate. The gate
    is the larger of ``review.min_samples_for_weight`` and
    ``review.signal_deprecation.min_samples``; each counts as 12 when unset or
    0. Malformed DB rows are skipped, and any other DB/config failure leaves
    the baseline weights in place.

    Compatibility: rules use canonical ids (e.g. "vp_fly_1h_current"), but old
    scripts may still index ``weights[...]`` with a legacy name such as
    "1H放量(量价背离)". Every alias whose canonical target has a weight is
    therefore included too, so old callers do not hit KeyError.
    """
    # Signal weights need a stable baseline. Runtime strategy_override may
    # contain small-sample governance writes; those are only trusted through
    # signal_performance after the sample-size gate below.
    baseline = _load_yaml_baseline()
    yaml_weights = copy.deepcopy(baseline.get("signal_weights") or {})
    canonical = {}
    for sig, weight in yaml_weights.items():
        canonical[normalize_signal_name(sig)] = weight

    try:
        from app.db.altcoin_db import get_signal_weights as db_get_weights

        db_weights = db_get_weights() or {}
        review_params = get_review_params()
        deprecation_params = review_params.get("signal_deprecation") or {}
        min_samples = max(
            int(review_params.get("min_samples_for_weight", 12) or 12),
            int(deprecation_params.get("min_samples", 12) or 12),
        )
        for sig, data in db_weights.items():
            try:
                total_count = data.get("total_count") or 0
                if total_count < min_samples:
                    continue
                weight = data["weight"]
            except (AttributeError, KeyError, TypeError):
                # A single malformed row must not abort the loop and leave
                # the remaining DB weights unapplied.
                continue
            canonical[normalize_signal_name(sig)] = weight
    except Exception:
        # Best effort: without DB stats the baseline weights stand.
        pass

    merged = dict(canonical)
    for alias, target in _SIGNAL_NAME_ALIASES.items():
        if target in canonical:
            merged[alias] = canonical[target]
    return merged
|
||
|
||
|
||
def get_review_params():
    """Return a copy of the post-trade review parameters (the ``review`` section)."""
    section = _get_section("review")
    return section


def get_reverse_params():
    """Return a copy of the reverse-analysis parameters (the ``reverse_analysis`` section)."""
    section = _get_section("reverse_analysis")
    return section
|
||
|
||
|
||
def get_learned_rules(active_only=True):
    """Return the list of learned rules.

    The list comes from the DB runtime config (``get_learned_rules_config``)
    when it holds a value. Otherwise, or if the DB layer raises, it is a copy
    of ``learned_rules`` from the merged rules. A missing or None
    ``learned_rules`` yields an empty list.

    Args:
        active_only: when true, drop rules whose ``active`` flag is falsy.
            Rules without the flag count as active.
    """
    try:
        from app.db.runtime_config_db import get_learned_rules_config

        learned = get_learned_rules_config(default=None)
    except Exception:
        learned = None
    if learned is None:
        # `or []` covers an empty `learned_rules:` key, which would otherwise
        # be None and break the iteration below.
        learned = copy.deepcopy(load_rules().get("learned_rules") or [])
    if active_only:
        return [r for r in learned if r.get("active", True)]
    return learned
|
||
|
||
|
||
def add_learned_rule(rule_dict):
    """Append a new learned rule and persist the full list.

    ``rule_dict`` is modified in place. ``id``, ``created``, ``hit_count``,
    ``miss_count`` and ``active`` are set (overwriting existing values) before
    it is appended. The list is written via ``set_learned_rules_config``; if
    that fails, ``save_rules`` is called on a copy of the merged rules. Finally
    meta ``total_rules_learned`` is updated.

    Returns:
        The generated id, ``rule_<YYYYmmdd_HHMM>_<NNN>``, where NNN is the
        rule's 1-based position in the list.
    """
    learned = get_learned_rules(active_only=False)
    # One timestamp for both fields so the id and the created date agree.
    now = datetime.datetime.now()
    rule_dict["id"] = f"rule_{now.strftime('%Y%m%d_%H%M')}_{len(learned) + 1:03d}"
    rule_dict["created"] = now.strftime("%Y-%m-%d")
    rule_dict["hit_count"] = 0
    rule_dict["miss_count"] = 0
    rule_dict["active"] = True
    learned.append(rule_dict)
    try:
        from app.db.runtime_config_db import set_learned_rules_config

        set_learned_rules_config(learned, source="add_learned_rule")
    except Exception:
        # Copy first: load_rules returns its cache object, and mutating it
        # directly would leave unpersisted rules cached if save_rules fails.
        rules = copy.deepcopy(load_rules(force_reload=True))
        rules["learned_rules"] = learned
        save_rules(rules)
    update_meta("total_rules_learned", len(learned))
    return rule_dict["id"]
|
||
|
||
|
||
def update_learned_rule(rule_id, updates):
    """Apply ``updates`` (e.g. hit_count / miss_count / active) to the rule with ``id == rule_id``.

    Only the first matching rule is changed. The whole list, changed or not,
    is then written back via ``set_learned_rules_config``. DB errors
    propagate to the caller.
    """
    rules = get_learned_rules(active_only=False)
    target = next((rule for rule in rules if rule.get("id") == rule_id), None)
    if target is not None:
        target.update(updates.items())
    from app.db.runtime_config_db import set_learned_rules_config

    set_learned_rules_config(rules, source="update_learned_rule")
|
||
|
||
|
||
def update_signal_weight(signal_name, new_weight):
    """Set one signal's weight in the runtime strategy override and in signal_performance.

    ``signal_name`` is first normalized with ``normalize_signal_name``. The
    weight is written into ``signal_weights`` of a copy of the merged rules and
    persisted with ``save_rules``; exceptions from ``save_rules`` propagate.
    The follow-up ``update_signal_performance`` call is best effort, and its
    errors are ignored.
    """
    canonical_name = normalize_signal_name(signal_name)
    # Work on a copy: load_rules returns its cache object, and mutating it in
    # place would leave the new weight cached even if save_rules fails.
    rules = copy.deepcopy(load_rules(force_reload=True))
    rules.setdefault("signal_weights", {})[canonical_name] = new_weight
    save_rules(rules)
    try:
        from app.db.altcoin_db import update_signal_performance

        update_signal_performance(canonical_name, category=None, is_hit=None, pnl=None, weight_override=new_weight)
    except Exception:
        pass
|
||
|
||
|
||
def get_meta():
    """Return strategy iteration metadata, always including ``strategy_version``.

    Uses the DB value from ``get_strategy_meta`` when it is a non-empty dict.
    Otherwise, or if the DB layer raises, uses the ``meta`` section of the
    merged rules. A missing or empty ``strategy_version`` is derived as
    ``v{version}.{iteration_count}``, with defaults 1 and 0.
    """
    try:
        from app.db.runtime_config_db import get_strategy_meta

        stored = get_strategy_meta(default=None)
        meta = stored if isinstance(stored, dict) and stored else _get_section("meta")
    except Exception:
        meta = _get_section("meta")
    if not meta.get("strategy_version"):
        meta["strategy_version"] = f"v{meta.get('version', 1)}.{meta.get('iteration_count', 0)}"
    return meta
|
||
|
||
|
||
def get_rules_snapshot():
    """Return a deep copy of the merged rules (rules.yaml baseline + runtime overrides)."""
    snapshot = load_rules()
    return copy.deepcopy(snapshot)
|
||
|
||
|
||
def diff_rule_snapshots(before, after, prefix=""):
    """Recursively compare two config snapshots.

    Dicts are compared key by key, recursing where both sides are dicts. Any
    other pair of values, lists included, is compared as a whole with ``!=``.
    Paths are dot-joined keys, starting from ``prefix``. A falsy ``before`` or
    ``after`` is treated as ``{}``.

    Returns:
        ``{"changed": [...], "added": [...], "removed": [...]}`` with entries
        ``{"path", "old", "new"}``, ``{"path", "new"}`` and
        ``{"path", "old"}`` respectively. All values are deep copies.
    """
    result = {"changed": [], "added": [], "removed": []}

    def ordered(keys):
        # Natural order when the keys are mutually comparable. YAML mappings
        # can mix key types (e.g. 1 and "a"), where plain sorted() raises
        # TypeError; fall back to ordering by string form in that case.
        try:
            return sorted(keys)
        except TypeError:
            return sorted(keys, key=str)

    def walk(path, a, b):
        if isinstance(a, dict) and isinstance(b, dict):
            for key in ordered(set(a) | set(b)):
                next_path = f"{path}.{key}" if path else str(key)
                if key not in a:
                    result["added"].append({"path": next_path, "new": copy.deepcopy(b[key])})
                elif key not in b:
                    result["removed"].append({"path": next_path, "old": copy.deepcopy(a[key])})
                else:
                    walk(next_path, a[key], b[key])
            return
        # Lists and scalars alike are compared as whole values.
        if a != b:
            result["changed"].append({"path": path, "old": copy.deepcopy(a), "new": copy.deepcopy(b)})

    walk(prefix, before or {}, after or {})
    return result
|
||
|
||
|
||
|
||
def promote_candidate_rule_to_learned_rule(candidate, release_version=""):
    """Write a candidate rule that passed the release gate into learned_rules.

    Candidates come from the DB ``strategy_rule_candidate``. This is meant to
    be called only when the release gate passes, so day-to-day research does
    not leak directly into the released strategy.

    Returns:
        - None when the candidate has no non-blank ``rule_description``.
        - The id of an existing learned rule with the same description
          (nothing is written in that case).
        - Otherwise, the id of the newly stored rule.
    """
    desc = (candidate.get("rule_description") or "").strip()
    if not desc:
        return None
    learned_rules = get_learned_rules(active_only=False)
    for existing in learned_rules:
        if existing.get("description") == desc:
            return existing.get("id") or existing.get("rule_id")
    # Resolve the type once so "type" and "score_adjust" always agree: a
    # candidate without rule_type is stored as "bonus" and must get +2, not -2.
    rule_type = candidate.get("rule_type") or "bonus"
    now = datetime.datetime.now()
    rule_id = f"candidate_{candidate.get('id')}_{now.strftime('%Y%m%d%H%M%S')}"
    rule = {
        "id": rule_id,
        "type": rule_type,
        "description": desc,
        "conditions": {
            "candidate_id": candidate.get("id"),
            "signal_name": candidate.get("signal_name") or "",
            "source": candidate.get("source") or "strategy_rule_candidate",
            "sample_size": candidate.get("sample_size") or 0,
            "confidence_score": candidate.get("confidence_score") or 0,
        },
        "score_adjust": 2 if rule_type == "bonus" else -2,
        "source": "candidate_release_gate",
        "release_version": release_version or get_meta().get("strategy_version", ""),
        "created_at": now.isoformat(),
    }
    learned_rules.append(rule)
    from app.db.runtime_config_db import set_learned_rules_config

    set_learned_rules_config(learned_rules, source="candidate_release_gate")
    return rule_id
|
||
|
||
|
||
def bump_strategy_patch_version(note=""):
    """Increment the patch number of ``strategy_version``; only for formal releases.

    ``vMAJOR.MINOR.PATCH`` becomes ``vMAJOR.MINOR.PATCH+1``. Any other format
    gets ``.1`` appended. ``strategy_revision_note`` and
    ``strategy_revision_started_at`` are recorded as well.

    Returns:
        ``(old_version, new_version)``.
    """
    import re

    old_version = str(get_meta().get("strategy_version") or "v1.0.0").strip()
    parsed = re.fullmatch(r"v(\d+)\.(\d+)\.(\d+)", old_version)
    if parsed is None:
        new_version = old_version + ".1"
    else:
        major, minor, patch = (int(part) for part in parsed.groups())
        new_version = f"v{major}.{minor}.{patch + 1}"
    update_meta("strategy_version", new_version)
    update_meta("strategy_revision_note", f"{new_version}: {note}" if note else new_version)
    update_meta("strategy_revision_started_at", datetime.datetime.now().isoformat())
    return old_version, new_version
|
||
|
||
def update_meta(key, value):
    """Set one strategy-meta field and persist the whole meta dict.

    The current meta is read via ``get_meta``, so a derived
    ``strategy_version`` gets persisted too. ``key`` is then set and the dict
    is written with ``set_strategy_meta``. If that fails, the function falls
    back to ``save_rules`` with ``meta[key]`` set on a copy of the merged
    rules.
    """
    meta = get_meta()
    meta[key] = value
    try:
        from app.db.runtime_config_db import set_strategy_meta

        set_strategy_meta(meta, source="update_meta")
    except Exception:
        # Copy first: load_rules returns its cache object, and mutating it in
        # place would leave the change cached even if save_rules fails.
        rules = copy.deepcopy(load_rules(force_reload=True))
        rules.setdefault("meta", {})[key] = value
        save_rules(rules)
|
||
|
||
|
||
# === Shortcut getters (imported directly by other modules) ===


def dynamic_k_thresholds():
    """Return ``(body_ratio_min, atr_ratio_min)`` from ``pa_engine.dynamic_k``."""
    cfg = get_pa_section("dynamic_k")
    body_ratio_min = cfg["body_ratio_min"]
    return body_ratio_min, cfg["atr_ratio_min"]


def static_k_thresholds():
    """Return ``(body_ratio_max, atr_ratio_max)`` from ``pa_engine.static_k``."""
    cfg = get_pa_section("static_k")
    body_ratio_max = cfg["body_ratio_max"]
    return body_ratio_max, cfg["atr_ratio_max"]


def zone_params():
    """Return ``(lookback, min_static_count, q_score_breakpoints)`` from ``pa_engine.supply_demand``."""
    cfg = get_pa_section("supply_demand")
    lookback, min_static_count = cfg["lookback"], cfg["min_static_count"]
    return lookback, min_static_count, cfg["q_score_breakpoints"]


def ignition_params():
    """Return ``(lookback, min_static_count, static_search_range, confirm_search_range)`` from ``pa_engine.ignition``."""
    cfg = get_pa_section("ignition")
    lookback, min_static_count = cfg["lookback"], cfg["min_static_count"]
    static_range, confirm_range = cfg["static_search_range"], cfg["confirm_search_range"]
    return lookback, min_static_count, static_range, confirm_range


def continuous_k_params():
    """Return ``pa_engine.continuous_k`` ({} if absent)."""
    params = get_pa_section("continuous_k")
    return params


def exhaustion_params():
    """Return ``pa_engine.exhaustion`` ({} if absent)."""
    params = get_pa_section("exhaustion")
    return params


def entry_point_params():
    """Return ``pa_engine.entry_point`` ({} if absent)."""
    params = get_pa_section("entry_point")
    return params
|
||
|
||
|
||
def burst_thresholds():
    """Return ``(main, meme, overbought_multiplier)`` from ``screener.burst_threshold``."""
    cfg = get_screener_section("burst_threshold")
    main, meme = cfg["main"], cfg["meme"]
    return main, meme, cfg["overbought_multiplier"]


def volume_thresholds():
    """Return ``(min_usd, meme_min_usd)`` from ``screener.volume``."""
    cfg = get_screener_section("volume")
    min_usd = cfg["min_usd"]
    return min_usd, cfg["meme_min_usd"]


def vp_fly_params():
    """Return ``screener.vp_fly`` ({} if absent)."""
    params = get_screener_section("vp_fly")
    return params


def bollinger_squeeze_params():
    """Return ``screener.bollinger_squeeze`` ({} if absent)."""
    params = get_screener_section("bollinger_squeeze")
    return params


def funding_rate_params():
    """Return ``screener.funding_rate`` ({} if absent)."""
    params = get_screener_section("funding_rate")
    return params


def top_trader_params():
    """Return ``screener.top_trader`` ({} if absent)."""
    params = get_screener_section("top_trader")
    return params


def state_score_thresholds():
    """Return ``(accelerate_main, accelerate_meme, accumulate)`` from ``screener.state_threshold``."""
    cfg = get_screener_section("state_threshold")
    accelerate_main, accelerate_meme = cfg["accelerate_main"], cfg["accelerate_meme"]
    return accelerate_main, accelerate_meme, cfg["accumulate"]
|
||
|
||
|
||
def confirm_min_score():
    """Return ``confirm.min_score`` (6 when not configured)."""
    params = get_confirm_params()
    return params.get("min_score", 6)


def confirm_volume_breakout_ratio():
    """Return ``confirm.volume_breakout_ratio`` (2.0 when not configured)."""
    params = get_confirm_params()
    return params.get("volume_breakout_ratio", 2.0)


def confirm_state_cooldown_hours():
    """Return ``confirm.state_cooldown_hours`` (6 when not configured)."""
    params = get_confirm_params()
    return params.get("state_cooldown_hours", 6)


def confirm_atr_multipliers():
    """Return ``confirm.atr_multiplier`` ({} if absent)."""
    section = get_confirm_section("atr_multiplier")
    return section


def confirm_stop_loss_params():
    """Return ``confirm.stop_loss`` ({} if absent)."""
    section = get_confirm_section("stop_loss")
    return section
|
||
|
||
|
||
def get_sentiment_params():
    """Return a copy of all sentiment-monitoring parameters (the ``sentiment`` section)."""
    section = _get_section("sentiment")
    return section


def sentiment_max_bonus():
    """Return ``sentiment.max_bonus`` (2 when not configured)."""
    params = get_sentiment_params()
    return params.get("max_bonus", 2)
|