103 lines
3.8 KiB
Python
103 lines
3.8 KiB
Python
"""DB helpers for standard strategy signals."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from datetime import datetime
|
|
|
|
from app.core.strategy_contract import StrategySignal, strategy_context_payload
|
|
from app.core.strategy_registry import normalize_strategy_code
|
|
from app.db.schema import get_conn
|
|
|
|
|
|
def _json_column(value: object) -> str:
    """Serialize *value* for one of the ``*_json`` text columns.

    Falsy values (``None``, ``{}``, ``[]``, ``""``) are stored as ``{}``.
    Non-ASCII characters are written as-is, and values ``json`` cannot encode
    natively fall back to ``str()``.
    """
    return json.dumps(value or {}, ensure_ascii=False, default=str)


def insert_strategy_signal(signal: StrategySignal | dict) -> dict:
    """Insert one row into ``strategy_signals`` and return the stored payload.

    Args:
        signal: A ``StrategySignal`` (serialized via ``to_json_dict()``) or a
            plain dict (normalized via ``strategy_context_payload``).

    Returns:
        The payload dict with ``created_at`` set to the timestamp that was
        written, and both ``strategy_signal_id`` and ``id`` set to the new row
        id (``0`` if the INSERT returned no row).
    """
    if isinstance(signal, StrategySignal):
        payload = signal.to_json_dict()
    else:
        payload = strategy_context_payload(signal)

    # Use the payload's timestamp if present; otherwise naive local time in ISO format.
    now = payload.get("created_at") or datetime.now().isoformat()
    trigger = payload.get("trigger") or {}

    conn = get_conn()
    try:
        row = conn.execute(
            """
            INSERT INTO strategy_signals (
                run_id, strategy_code, strategy_version, symbol, direction, signal_status,
                confidence, score, market_regime, trigger_json, factor_roles_json,
                entry_plan_json, risk_plan_json, decision_log_json, created_at
            ) VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)
            RETURNING id
            """,
            (
                payload.get("run_id") or "",
                normalize_strategy_code(payload.get("strategy_code")),
                payload.get("strategy_version") or "",
                payload.get("symbol") or "",
                payload.get("direction") or "long",
                # Accept either key; "status" wins when both are present.
                payload.get("status") or payload.get("signal_status") or "candidate",
                float(payload.get("confidence") or 0),
                float(payload.get("score") or 0),
                # market_regime is copied out of the trigger dict into its own column.
                trigger.get("market_regime") or "",
                _json_column(trigger),
                _json_column(payload.get("factor_roles")),
                _json_column(payload.get("entry_plan")),
                _json_column(payload.get("risk_plan")),
                _json_column(payload.get("decision_log")),
                now,
            ),
        ).fetchone()
        signal_id = int(row["id"] if row else 0)
        conn.commit()
    finally:
        conn.close()

    # Echo back exactly what was persisted, including a generated timestamp.
    payload["created_at"] = now
    payload["strategy_signal_id"] = signal_id
    payload["id"] = signal_id
    return payload
|
|
|
|
|
|
def list_recent_strategy_signals(strategy_code: str = "", symbol: str = "", limit: int = 50) -> list[dict]:
    """Return the most recent strategy signals, newest first.

    Args:
        strategy_code: Optional filter, normalized with
            ``normalize_strategy_code`` before matching. An empty string
            disables the filter.
        symbol: Optional exact-match symbol filter. An empty string disables it.
        limit: Maximum number of rows. Falsy values fall back to 50, and the
            value is clamped to the range [1, 500].

    Returns:
        One dict per ``strategy_signals`` row (all columns), ordered by
        ``created_at`` then ``id``, both descending.
    """
    row_cap = max(1, min(int(limit or 50), 500))

    # Each active filter contributes a (SQL fragment, bound value) pair.
    filters: list[tuple[str, object]] = []
    if strategy_code:
        filters.append(("strategy_code=%s", normalize_strategy_code(strategy_code)))
    if symbol:
        filters.append(("symbol=%s", symbol))

    # Only fixed fragments are interpolated into the SQL; user values stay bound parameters.
    clause = ("WHERE " + " AND ".join(fragment for fragment, _ in filters)) if filters else ""
    bound = tuple(value for _, value in filters) + (row_cap,)

    conn = get_conn()
    try:
        cursor = conn.execute(
            f"""
            SELECT *
            FROM strategy_signals
            {clause}
            ORDER BY created_at DESC, id DESC
            LIMIT %s
            """,
            bound,
        )
        return list(map(dict, cursor.fetchall()))
    finally:
        conn.close()
|
|
|
|
|
|
def get_strategy_signal_summary(days: int = 7) -> list[dict]:
    """Aggregate signal statistics per strategy over a trailing window.

    Args:
        days: Window length in days. Falsy values fall back to 7, and the
            value is clamped to the range [1, 365].

    Returns:
        One dict per ``strategy_code`` with ``signal_count``, ``avg_score``,
        ``avg_confidence`` and ``latest_at``. Rows are ordered by
        ``signal_count`` then ``latest_at``, both descending.
    """
    from datetime import timedelta

    days = max(1, min(int(days or 7), 365))

    # created_at is compared as text. When no timestamp is supplied, this
    # module stores datetime.now().isoformat() ("YYYY-MM-DDTHH:MM:SS[.ffffff]",
    # naive local time), so the cutoff is built in that same format.
    # A DB-side NOW()::TEXT cutoff renders (under the default ISO DateStyle) as
    # "YYYY-MM-DD HH:MM:SS.ffffff+TZ": space separator, session time zone. A
    # text comparison against it ranks every row on the cutoff date as newer
    # than the cutoff.
    # NOTE(review): rows whose created_at was supplied by the caller may use a
    # different format — confirm all writers emit isoformat() text.
    cutoff = (datetime.now() - timedelta(days=days)).isoformat()

    conn = get_conn()
    try:
        rows = conn.execute(
            """
            SELECT strategy_code,
                   COUNT(*) AS signal_count,
                   AVG(score) AS avg_score,
                   AVG(confidence) AS avg_confidence,
                   MAX(created_at) AS latest_at
            FROM strategy_signals
            WHERE created_at >= %s
            GROUP BY strategy_code
            ORDER BY signal_count DESC, latest_at DESC
            """,
            (cutoff,),
        ).fetchall()
        return [dict(r) for r in rows]
    finally:
        conn.close()
|