astock-agent/backend/app/research/stock_research_agent.py
2026-06-10 08:36:25 +08:00

287 lines
11 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Stock research note generator.
The stock research layer is explanatory. It enriches rule-selected candidates
with theme, catalyst and risk context, while deterministic notes remain the
fallback when the LLM or local catalyst data is unavailable.
"""
from __future__ import annotations

import json
import logging
import math
import re
from typing import Any

from sqlalchemy import text

from app.config import settings
from app.db.database import get_db
from app.llm.client import chat_completion
from app.research.industry_chain_agent import infer_chain_node, infer_chain_position_from_theme_view
# Module logger; used below to report LLM-generation and local-catalyst fallbacks.
logger = logging.getLogger(__name__)
def build_stock_research_notes_sync(recommendations: list[Any], theme_views: list[dict]) -> list[dict]:
    """Return rule-based notes for at most the first 20 recommendations.

    No catalysts are loaded and no LLM is called, so every note is the
    deterministic output of ``_fallback_note`` with an empty catalyst list.
    """
    known_themes = {view["theme"] for view in theme_views}
    return [
        _fallback_note(candidate, known_themes, [], _match_theme_view(candidate, theme_views))
        for candidate in recommendations[:20]
    ]
async def build_stock_research_notes(
    recommendations: list[Any],
    theme_views: list[dict],
    risk_alerts: list[dict] | None = None,
) -> list[dict]:
    """Produce research notes for at most the first 20 recommendations.

    Every candidate first gets a rule-based note enriched with local catalysts.
    The first ``settings.research_stock_llm_limit`` candidates are then rewritten
    by the LLM, but only when ``settings.research_stock_llm_enabled`` is set and
    ``settings.deepseek_api_key`` is configured. An LLM exception, or a falsy
    LLM result, keeps the rule-based note.

    Args:
        recommendations: Rule-selected candidates, read via ``getattr``.
        theme_views: Theme view dicts; each must have a ``"theme"`` key.
        risk_alerts: Optional risk alert dicts, grouped by ``ts_code``.

    Returns:
        One note dict per processed candidate, in input order.
    """
    if not recommendations:
        return []
    known_themes = {view["theme"] for view in theme_views}
    views_by_theme = {view["theme"]: view for view in theme_views}
    risks_by_code = _group_risks(risk_alerts or [])
    results: list[dict] = []
    llm_quota = max(0, int(settings.research_stock_llm_limit or 0))
    for position, candidate in enumerate(recommendations[:20]):
        sector_theme = getattr(candidate, "sector", "") or "未归类"
        matched_view = _match_theme_view(candidate, theme_views)
        local_catalysts = await _load_local_catalyst_context(candidate, sector_theme)
        rule_note = _fallback_note(candidate, known_themes, local_catalysts, matched_view)
        llm_allowed = (
            settings.research_stock_llm_enabled
            and position < llm_quota
            and settings.deepseek_api_key
        )
        if not llm_allowed:
            results.append(rule_note)
            continue
        try:
            llm_note = await _build_llm_note(
                rec=candidate,
                fallback=rule_note,
                theme_view=matched_view or views_by_theme.get(sector_theme, {}),
                risks=risks_by_code.get(getattr(candidate, "ts_code", ""), []),
                catalysts=local_catalysts,
            )
        except Exception as exc:
            # Best effort boundary: any LLM failure degrades to the rule-based note.
            logger.warning("股票研究笔记 LLM 生成失败 ts_code=%s error=%s", getattr(candidate, "ts_code", ""), exc)
            results.append(rule_note)
        else:
            results.append(llm_note or rule_note)
    return results
def _fallback_note(rec: Any, theme_names: set[str], catalysts: list[dict], theme_view: dict | None = None) -> dict:
    """Build the deterministic, rule-based research note for one candidate.

    Args:
        rec: Recommendation object, read via ``getattr`` (``ts_code``, ``name``,
            ``sector``, ``score``, ``decision_trace``, ``reasons``,
            ``action_plan``, ``invalidation_condition``, ``risk_note``).
        theme_names: Names of the current theme views. A candidate whose theme
            is in this set gets an 8-point bonus.
        catalysts: Local catalyst rows. Each non-empty ``title`` adds 1.5 points
            (capped at 6) and is appended to the evidence list.
        theme_view: Matched theme view. When present, it supplies the theme
            name and the chain position.

    Returns:
        A note dict with ``generated_by == "rules"`` and ``logic_score`` in [0, 100].
    """
    name = getattr(rec, "name", "")
    sector = getattr(rec, "sector", "")
    trace = getattr(rec, "decision_trace", {}) or {}
    raw_evidence = trace.get("evidence") or getattr(rec, "reasons", []) or []
    # Normalize to a list: a tuple or bare string cannot be concatenated with the
    # catalyst titles below and would raise TypeError for the whole batch.
    evidence = [raw_evidence] if isinstance(raw_evidence, str) else list(raw_evidence)
    catalyst_titles = [item.get("title", "") for item in catalysts if item.get("title")]
    theme = (theme_view or {}).get("theme") or sector or "未归类"
    if theme_view:
        position = infer_chain_position_from_theme_view(theme_view, getattr(rec, "ts_code", ""), name)
    else:
        position = {"chain_node": infer_chain_node(theme, name, sector), "stock_role": "待归类"}
    chain_node = position["chain_node"]
    stock_role = position["stock_role"]
    base = float(getattr(rec, "score", 0) or 0)
    theme_bonus = 8 if theme in theme_names else 0
    catalyst_bonus = min(len(catalyst_titles) * 1.5, 6)
    # Keep the score in [0, 100], the same range the LLM path enforces.
    logic_score = max(0, min(100, base + theme_bonus + catalyst_bonus))
    action = getattr(rec, "action_plan", "观察") or "观察"
    invalid = getattr(rec, "invalidation_condition", "") or getattr(rec, "risk_note", "") or "板块热度回落、资金持续性不足或买点未触发。"
    return {
        "ts_code": getattr(rec, "ts_code", ""),
        "name": name,
        "theme": theme,
        "chain_node": chain_node,
        "stock_role": stock_role,
        "logic_score": round(logic_score, 1),
        # The stock role is shown in parentheses right after the chain node.
        "logic_summary": f"{name} 属于 {theme} 方向,产业链位置为{chain_node}({stock_role}),当前结论为{action}",
        "evidence": (evidence + catalyst_titles)[:5],
        "uncertainty": getattr(rec, "risk_note", "") or "等待后续公告、资金持续性和板块生命周期验证。",
        "disagreement": "若主线扩散失败、成交额无法维持或同板块核心股转弱,当前逻辑需要降级。",
        "invalid_condition": invalid,
        "generated_by": "rules",
    }
async def _build_llm_note(
    rec: Any,
    fallback: dict,
    theme_view: dict,
    risks: list[dict],
    catalysts: list[dict],
) -> dict | None:
    """Ask the LLM to rewrite ``fallback`` as a research note.

    The candidate, theme view, catalysts and risk alerts are sent as one JSON
    payload. The reply's JSON fields are merged over a copy of ``fallback``:
    text fields are capped at 120 characters, evidence at 5 items of 80
    characters each, and ``logic_score`` is clamped to [0, 100]. Any missing or
    empty field keeps the fallback value.

    Returns:
        The merged note with ``generated_by == "llm"``, or ``None`` when the reply
        contains no JSON object or none of the requested fields.
    """
    payload = {
        "stock": {
            "ts_code": fallback["ts_code"],
            "name": fallback["name"],
            "theme": fallback["theme"],
            "chain_node": fallback["chain_node"],
            "stock_role": fallback.get("stock_role", "待归类"),
            "score": getattr(rec, "score", 0),
            "action_plan": getattr(rec, "action_plan", "观察"),
            "trigger": getattr(rec, "trigger_condition", "") or getattr(rec, "entry_timing", ""),
            "invalid_condition": getattr(rec, "invalidation_condition", "") or getattr(rec, "risk_note", ""),
            "decision_trace": getattr(rec, "decision_trace", {}) or {},
            "reasons": getattr(rec, "reasons", []) or [],
        },
        "theme_view": theme_view,
        "recent_catalysts": catalysts,
        "risk_alerts": risks,
    }
    messages = [
        {
            "role": "system",
            "content": (
                "你是A股研究员。你的任务是把系统筛出的候选标的整理成研究笔记"
                "用于解释机会逻辑、证据、分歧点和失效条件。不要给无条件买入建议,"
                "不要编造未提供的数据。只输出合法JSON。"
            ),
        },
        {
            "role": "user",
            "content": (
                "请基于以下结构化输入生成研究笔记。输出JSON字段必须为"
                "logic_score(0-100数字), logic_summary(80字内), evidence(字符串数组,最多5条), "
                "disagreement(80字内), uncertainty(80字内), invalid_condition(80字内)。\n\n"
                f"{json.dumps(payload, ensure_ascii=False, default=str)}"
            ),
        },
    ]
    message = await chat_completion(messages)
    parsed = _extract_json_object(_message_content(message))
    requested_fields = (
        "logic_score",
        "logic_summary",
        "evidence",
        "disagreement",
        "uncertainty",
        "invalid_condition",
    )
    # A reply carrying none of the requested fields has no LLM content; returning
    # None lets the caller keep the rules note instead of mislabelling it "llm".
    if not parsed or not any(field in parsed for field in requested_fields):
        return None
    raw_evidence = parsed.get("evidence")
    if not isinstance(raw_evidence, list):
        raw_evidence = fallback["evidence"]
    # Clean each item once, then drop the ones that end up empty.
    cleaned_evidence = [
        snippet for snippet in (_clean_text(item, "", 80) for item in raw_evidence) if snippet
    ][:5]
    note = dict(fallback)
    note.update({
        "logic_score": _clamp_float(parsed.get("logic_score"), fallback["logic_score"], 0, 100),
        "logic_summary": _clean_text(parsed.get("logic_summary"), fallback["logic_summary"], 120),
        "evidence": cleaned_evidence,
        "disagreement": _clean_text(parsed.get("disagreement"), fallback["disagreement"], 120),
        "uncertainty": _clean_text(parsed.get("uncertainty"), fallback["uncertainty"], 120),
        "invalid_condition": _clean_text(parsed.get("invalid_condition"), fallback["invalid_condition"], 120),
        "generated_by": "llm",
    })
    if not note["evidence"]:
        note["evidence"] = fallback["evidence"]
    return note
async def _load_local_catalyst_context(rec: Any, theme: str) -> list[dict]:
    """Fetch recent active catalysts related to one candidate from the local DB.

    A catalyst qualifies when any of the following holds:

    - it is linked to ``theme`` through ``theme_catalysts``;
    - its title or summary contains the stock name;
    - its raw text contains the ts_code.

    At most ``settings.research_stock_news_limit`` rows are returned (6 when
    unset), newest first. This is best effort: any query error is logged at
    debug level and yields ``[]``.
    """
    name = str(getattr(rec, "name", "") or "")
    ts_code = str(getattr(rec, "ts_code", "") or "")
    if not theme and not name and not ts_code:
        return []
    limit = max(1, int(settings.research_stock_news_limit or 6))
    # An empty key is bound as NULL so its predicate never matches. A "%"
    # placeholder would make e.g. "c.title LIKE '%'" true for every catalyst,
    # attaching the latest unrelated news to a candidate without a name.
    params = {
        "theme": theme or None,
        "name": f"%{name}%" if name else None,
        "ts_code": f"%{ts_code}%" if ts_code else None,
        "limit": limit,
    }
    try:
        async with get_db() as db:
            result = await db.execute(
                text(
                    "SELECT c.title, c.summary, c.source, c.published_at, c.catalyst_type, "
                    "c.strength, c.confidence, tc.theme_name, tc.reason "
                    "FROM catalysts c "
                    "LEFT JOIN theme_catalysts tc ON tc.catalyst_id = c.id "
                    "WHERE c.is_active = 1 AND ("
                    "tc.theme_name = :theme OR c.title LIKE :name OR c.summary LIKE :name OR c.raw_text LIKE :ts_code"
                    ") "
                    "ORDER BY COALESCE(c.published_at, c.created_at) DESC, c.id DESC "
                    "LIMIT :limit"
                ),
                params,
            )
            rows = []
            for row in result.fetchall():
                item = dict(row._mapping)
                rows.append({
                    "title": item.get("title") or "",
                    "summary": item.get("summary") or "",
                    "source": item.get("source") or "",
                    "published_at": str(item.get("published_at") or ""),
                    # theme_name/reason come back NULL when the LEFT JOIN found no theme link.
                    "theme": item.get("theme_name") or "",
                    "reason": item.get("reason") or "",
                    "strength": item.get("strength", 0),
                    "confidence": item.get("confidence", 0),
                })
            return rows
    except Exception as exc:
        logger.debug("读取本地催化上下文失败 ts_code=%s error=%s", ts_code, exc)
        return []
def _group_risks(risks: list[dict]) -> dict[str, list[dict]]:
    """Bucket risk alerts by stock code, keeping first-seen code order.

    Codes are normalized with ``str``. Alerts whose ``ts_code`` is missing or
    falsy are dropped.
    """
    buckets: dict[str, list[dict]] = {}
    for alert in risks:
        code = str(alert.get("ts_code") or "")
        if not code:
            continue
        if code not in buckets:
            buckets[code] = []
        buckets[code].append(alert)
    return buckets
def _match_theme_view(rec: Any, theme_views: list[dict]) -> dict:
    """Return the theme view that best matches ``rec.sector``, or ``{}``.

    An exact match on a view's ``theme`` or ``raw_sector`` wins over substring
    matches. For example, a sector such as "AI computing" prefers its own view
    over an earlier, broader "AI" view. Otherwise, the first view whose
    non-empty theme occurs in the sector, or whose ``raw_sector`` contains the
    sector, is returned. Empty themes never match by substring, because
    ``"" in s`` is always true.
    """
    sector = str(getattr(rec, "sector", "") or "")
    if not sector:
        return {}
    fuzzy_match: dict | None = None
    for item in theme_views:
        theme = str(item.get("theme") or "")
        raw_sector = str(item.get("raw_sector") or "")
        if sector == theme or sector == raw_sector:
            return item
        if fuzzy_match is None and ((theme and theme in sector) or sector in raw_sector):
            fuzzy_match = item
    return fuzzy_match if fuzzy_match is not None else {}
def _message_content(message: Any) -> str:
    """Return the text content of an LLM reply, or "" when there is none.

    Accepts either a dict with a ``"content"`` key or an object exposing a
    ``content`` attribute. Falsy content is normalized to "", and anything else
    is converted with ``str``.
    """
    if not message:
        return ""
    raw = message.get("content") if isinstance(message, dict) else getattr(message, "content", "")
    return str(raw or "")
def _extract_json_object(content: str) -> dict:
    """Parse the JSON object contained in an LLM reply, or return ``{}``.

    Three input shapes are accepted:

    - a bare JSON document;
    - a document wrapped in a Markdown code fence;
    - an object embedded in surrounding prose.

    Only a JSON object counts: a whole reply that parses to an array or scalar
    yields ``{}``.
    """
    if not content:
        return {}
    cleaned = content.strip()
    if cleaned.startswith("```"):
        cleaned = re.sub(r"^```(?:json)?", "", cleaned, flags=re.IGNORECASE).strip()
        cleaned = re.sub(r"```$", "", cleaned).strip()
    try:
        parsed = json.loads(cleaned)
    except (ValueError, RecursionError):
        pass
    else:
        return parsed if isinstance(parsed, dict) else {}
    # Decode the object starting at the first "{" and ignore whatever follows it.
    # A greedy "{.*}" regex would also swallow any later "}" in the prose and
    # make an otherwise valid object fail to parse.
    start = cleaned.find("{")
    if start == -1:
        return {}
    try:
        parsed, _end = json.JSONDecoder().raw_decode(cleaned[start:])
    except (ValueError, RecursionError):
        return {}
    return parsed if isinstance(parsed, dict) else {}
def _clean_text(value: Any, fallback: str, max_chars: int) -> str:
    """Convert ``value`` to stripped text capped at ``max_chars`` characters.

    Falsy or whitespace-only input is replaced by ``fallback``, which is not
    stripped but is truncated to the same limit.
    """
    stripped = str(value or "").strip()
    chosen = stripped if stripped else fallback
    return chosen[:max_chars]
def _clamp_float(value: Any, fallback: float, minimum: float, maximum: float) -> float:
    """Coerce ``value`` to a number clamped to ``[minimum, maximum]``, rounded to 1 decimal.

    ``fallback`` is used when ``value`` cannot be converted with ``float`` or is
    NaN. ``0`` is used when the fallback is unusable too. NaN must be rejected
    explicitly: ``min``/``max`` do not order NaN, so ``min(maximum, nan)``
    returns ``maximum`` and a garbage score would become the top score.
    """
    number = 0.0
    for candidate in (value, fallback):
        try:
            converted = float(candidate)
        except (TypeError, ValueError, OverflowError):
            continue
        if not math.isnan(converted):
            number = converted
            break
    return round(max(minimum, min(maximum, number)), 1)