287 lines
11 KiB
Python
287 lines
11 KiB
Python
"""Stock research note generator.
|
||
|
||
The stock research layer is explanatory. It enriches rule-selected candidates
|
||
with theme, catalyst and risk context, while deterministic notes remain the
|
||
fallback when the LLM or local catalyst data is unavailable.
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import logging
|
||
import re
|
||
from typing import Any
|
||
|
||
from sqlalchemy import text
|
||
|
||
from app.config import settings
|
||
from app.db.database import get_db
|
||
from app.llm.client import chat_completion
|
||
from app.research.industry_chain_agent import infer_chain_node, infer_chain_position_from_theme_view
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def build_stock_research_notes_sync(recommendations: list[Any], theme_views: list[dict]) -> list[dict]:
    """Return rule-based notes without touching the LLM or the database.

    Intended for tests, fallback APIs and LLM failures. Only the first 20
    recommendations are processed; each note is built with an empty catalyst
    list and the theme view matched to that recommendation.
    """
    known_themes = {view["theme"] for view in theme_views}
    return [
        _fallback_note(candidate, known_themes, [], _match_theme_view(candidate, theme_views))
        for candidate in recommendations[:20]
    ]
|
||
|
||
|
||
async def build_stock_research_notes(
|
||
recommendations: list[Any],
|
||
theme_views: list[dict],
|
||
risk_alerts: list[dict] | None = None,
|
||
) -> list[dict]:
|
||
"""Build stock research notes, using LLM for top candidates when configured."""
|
||
if not recommendations:
|
||
return []
|
||
|
||
theme_names = {item["theme"] for item in theme_views}
|
||
theme_map = {item["theme"]: item for item in theme_views}
|
||
risk_map = _group_risks(risk_alerts or [])
|
||
notes: list[dict] = []
|
||
llm_limit = max(0, int(settings.research_stock_llm_limit or 0))
|
||
|
||
for index, rec in enumerate(recommendations[:20]):
|
||
theme = getattr(rec, "sector", "") or "未归类"
|
||
theme_view = _match_theme_view(rec, theme_views)
|
||
catalysts = await _load_local_catalyst_context(rec, theme)
|
||
fallback = _fallback_note(rec, theme_names, catalysts, theme_view)
|
||
if (
|
||
not settings.research_stock_llm_enabled
|
||
or index >= llm_limit
|
||
or not settings.deepseek_api_key
|
||
):
|
||
notes.append(fallback)
|
||
continue
|
||
|
||
try:
|
||
llm_note = await _build_llm_note(
|
||
rec=rec,
|
||
fallback=fallback,
|
||
theme_view=theme_view or theme_map.get(theme, {}),
|
||
risks=risk_map.get(getattr(rec, "ts_code", ""), []),
|
||
catalysts=catalysts,
|
||
)
|
||
notes.append(llm_note or fallback)
|
||
except Exception as exc:
|
||
logger.warning("股票研究笔记 LLM 生成失败 ts_code=%s error=%s", getattr(rec, "ts_code", ""), exc)
|
||
notes.append(fallback)
|
||
|
||
return notes
|
||
|
||
|
||
def _fallback_note(rec: Any, theme_names: set[str], catalysts: list[dict], theme_view: dict | None = None) -> dict:
    """Assemble the deterministic (``generated_by="rules"``) note for one stock.

    Args:
        rec: Recommendation object; attributes are read with ``getattr`` so
            missing ones fall back to defaults.
        theme_names: Known theme names; membership adds 8 points to the score.
        catalysts: Catalyst dicts; their non-empty titles are appended to the
            evidence and add 1.5 points each (capped at 6).
        theme_view: Matched theme view, if any. When truthy it supplies the
            theme and the industry-chain position.

    Returns:
        A note dict whose ``logic_score`` is capped at 100 and rounded to one
        decimal, and whose ``evidence`` holds at most five entries.
    """
    ts_code = getattr(rec, "ts_code", "")
    name = getattr(rec, "name", "")
    sector = getattr(rec, "sector", "")
    risk_note = getattr(rec, "risk_note", "")

    decision_trace = getattr(rec, "decision_trace", {}) or {}
    evidence = decision_trace.get("evidence") or getattr(rec, "reasons", []) or []
    catalyst_titles = [entry.get("title", "") for entry in catalysts if entry.get("title")]
    theme = (theme_view or {}).get("theme") or sector or "未归类"

    if theme_view:
        position = infer_chain_position_from_theme_view(theme_view, ts_code, name)
    else:
        position = {"chain_node": infer_chain_node(theme, name, sector), "stock_role": "待归类"}
    chain_node = position["chain_node"]
    stock_role = position["stock_role"]

    base_score = float(getattr(rec, "score", 0) or 0)
    theme_bonus = 8 if theme in theme_names else 0
    catalyst_bonus = min(len(catalyst_titles) * 1.5, 6)
    logic_score = min(100, base_score + theme_bonus + catalyst_bonus)

    action = getattr(rec, "action_plan", "观察") or "观察"
    invalid_condition = (
        getattr(rec, "invalidation_condition", "")
        or risk_note
        or "板块热度回落、资金持续性不足或买点未触发。"
    )

    note: dict = {}
    note["ts_code"] = ts_code
    note["name"] = name
    note["theme"] = theme
    note["chain_node"] = chain_node
    note["stock_role"] = stock_role
    note["logic_score"] = round(logic_score, 1)
    note["logic_summary"] = f"{name} 属于 {theme} 方向,产业链位置为{chain_node}({stock_role}),当前结论为{action}。"
    note["evidence"] = (evidence + catalyst_titles)[:5]
    note["uncertainty"] = risk_note or "等待后续公告、资金持续性和板块生命周期验证。"
    note["disagreement"] = "若主线扩散失败、成交额无法维持或同板块核心股转弱,当前逻辑需要降级。"
    note["invalid_condition"] = invalid_condition
    note["generated_by"] = "rules"
    return note
|
||
|
||
|
||
async def _build_llm_note(
    rec: Any,
    fallback: dict,
    theme_view: dict,
    risks: list[dict],
    catalysts: list[dict],
) -> dict | None:
    """Ask the LLM to rewrite the explanatory fields of a rule-based note.

    The structured input (stock facts, theme view, recent catalysts, risk
    alerts) is serialised to JSON and sent with a fixed system/user prompt.
    The reply is parsed as a JSON object; each field is cleaned and truncated,
    and anything missing or empty falls back to the value from ``fallback``.

    Returns:
        A copy of ``fallback`` with the LLM fields overlaid and
        ``generated_by`` set to ``"llm"``, or ``None`` when no JSON object
        could be extracted from the reply.
    """
    stock_facts = {
        "ts_code": fallback["ts_code"],
        "name": fallback["name"],
        "theme": fallback["theme"],
        "chain_node": fallback["chain_node"],
        "stock_role": fallback.get("stock_role", "待归类"),
        "score": getattr(rec, "score", 0),
        "action_plan": getattr(rec, "action_plan", "观察"),
        "trigger": getattr(rec, "trigger_condition", "") or getattr(rec, "entry_timing", ""),
        "invalid_condition": getattr(rec, "invalidation_condition", "") or getattr(rec, "risk_note", ""),
        "decision_trace": getattr(rec, "decision_trace", {}) or {},
        "reasons": getattr(rec, "reasons", []) or [],
    }
    payload = {
        "stock": stock_facts,
        "theme_view": theme_view,
        "recent_catalysts": catalysts,
        "risk_alerts": risks,
    }

    system_prompt = (
        "你是A股研究员。你的任务是把系统筛出的候选标的整理成研究笔记,"
        "用于解释机会逻辑、证据、分歧点和失效条件。不要给无条件买入建议,"
        "不要编造未提供的数据。只输出合法JSON。"
    )
    user_prompt = (
        "请基于以下结构化输入生成研究笔记。输出JSON字段必须为:"
        "logic_score(0-100数字), logic_summary(80字内), evidence(字符串数组,最多5条), "
        "disagreement(80字内), uncertainty(80字内), invalid_condition(80字内)。\n\n"
    ) + json.dumps(payload, ensure_ascii=False, default=str)

    reply = await chat_completion(
        [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
    )
    parsed = _extract_json_object(_message_content(reply))
    if not parsed:
        return None

    # A non-list "evidence" value is ignored in favour of the fallback list.
    raw_evidence = parsed.get("evidence")
    if not isinstance(raw_evidence, list):
        raw_evidence = fallback["evidence"]
    cleaned_evidence = []
    for entry in raw_evidence:
        snippet = _clean_text(entry, "", 80)
        if snippet:
            cleaned_evidence.append(snippet)

    return {
        **fallback,
        "logic_score": _clamp_float(parsed.get("logic_score"), fallback["logic_score"], 0, 100),
        "logic_summary": _clean_text(parsed.get("logic_summary"), fallback["logic_summary"], 120),
        "evidence": cleaned_evidence[:5] or fallback["evidence"],
        "disagreement": _clean_text(parsed.get("disagreement"), fallback["disagreement"], 120),
        "uncertainty": _clean_text(parsed.get("uncertainty"), fallback["uncertainty"], 120),
        "invalid_condition": _clean_text(parsed.get("invalid_condition"), fallback["invalid_condition"], 120),
        "generated_by": "llm",
    }
|
||
|
||
|
||
async def _load_local_catalyst_context(rec: Any, theme: str) -> list[dict]:
    """Load recent active catalysts related to a stock from the local database.

    A catalyst is selected when its linked theme equals ``theme``, when its
    title or summary contains the stock name, or when its raw text contains the
    stock code. Results are ordered newest first and limited by
    ``settings.research_stock_news_limit`` (at least 1, default 6).

    Args:
        rec: Recommendation object; ``name`` and ``ts_code`` are read with
            ``getattr`` and treated as empty when missing or falsy.
        theme: Theme name to match against ``theme_catalysts.theme_name``.

    Returns:
        A list of catalyst dicts (``title``, ``summary``, ``source``,
        ``published_at``, ``theme``, ``reason``, ``strength``, ``confidence``).
        NULL text columns are returned as ``""`` and NULL numeric columns as
        ``0``. An empty list is returned when there is nothing to search for or
        when the query fails (best-effort; the failure is logged at debug level).
    """
    name = str(getattr(rec, "name", "") or "")
    ts_code = str(getattr(rec, "ts_code", "") or "")
    if not theme and not name and not ts_code:
        return []

    limit = max(1, int(settings.research_stock_news_limit or 6))

    # Emit a predicate only for criteria that are actually present. Binding a
    # bare "%" for a missing name/code would make the LIKE match every non-NULL
    # row and return catalysts unrelated to this stock.
    predicates: list[str] = []
    params: dict[str, Any] = {"limit": limit}
    if theme:
        predicates.append("tc.theme_name = :theme")
        params["theme"] = theme
    if name:
        predicates.append("c.title LIKE :name OR c.summary LIKE :name")
        params["name"] = f"%{name}%"
    if ts_code:
        predicates.append("c.raw_text LIKE :ts_code")
        params["ts_code"] = f"%{ts_code}%"
    # Only fixed SQL fragments are concatenated; all values stay bound params.
    where_clause = " OR ".join(predicates)
    query = (
        "SELECT c.title, c.summary, c.source, c.published_at, c.catalyst_type, "
        "c.strength, c.confidence, tc.theme_name, tc.reason "
        "FROM catalysts c "
        "LEFT JOIN theme_catalysts tc ON tc.catalyst_id = c.id "
        "WHERE c.is_active = 1 AND (" + where_clause + ") "
        "ORDER BY COALESCE(c.published_at, c.created_at) DESC, c.id DESC "
        "LIMIT :limit"
    )

    try:
        async with get_db() as db:
            result = await db.execute(text(query), params)
            rows = []
            for row in result.fetchall():
                item = dict(row._mapping)
                # Selected columns are always present in the mapping, so a
                # .get() default never applies to NULLs; normalise explicitly.
                rows.append({
                    "title": item.get("title") or "",
                    "summary": item.get("summary") or "",
                    "source": item.get("source") or "",
                    "published_at": str(item.get("published_at") or ""),
                    "theme": item.get("theme_name") or "",
                    "reason": item.get("reason") or "",
                    "strength": item.get("strength") or 0,
                    "confidence": item.get("confidence") or 0,
                })
            return rows
    except Exception as exc:
        # Catalyst context is optional enrichment; never fail note generation.
        logger.debug("读取本地催化上下文失败 ts_code=%s error=%s", ts_code, exc)
        return []
|
||
|
||
|
||
def _group_risks(risks: list[dict]) -> dict[str, list[dict]]:
    """Bucket risk alerts by their stringified ``ts_code``.

    Alerts whose ``ts_code`` is missing or falsy are dropped. Both the order of
    first appearance of each code and the order of alerts within a bucket are
    preserved.
    """
    buckets: dict[str, list[dict]] = {}
    for alert in risks:
        code = str(alert.get("ts_code") or "")
        if not code:
            continue
        if code in buckets:
            buckets[code].append(alert)
        else:
            buckets[code] = [alert]
    return buckets
|
||
|
||
|
||
def _match_theme_view(rec: Any, theme_views: list[dict]) -> dict:
    """Return the first theme view that corresponds to the recommendation's sector.

    A view matches when the sector equals the view's ``theme`` or
    ``raw_sector``, when the (non-empty) theme is a substring of the sector, or
    when the sector is a substring of ``raw_sector``.

    Args:
        rec: Recommendation object; ``sector`` is read with ``getattr``.
        theme_views: Candidate theme view dicts, checked in order.

    Returns:
        The matching view dict, or ``{}`` when the sector is empty or nothing
        matches.
    """
    sector = str(getattr(rec, "sector", "") or "")
    if not sector:
        return {}

    for item in theme_views:
        theme = str(item.get("theme") or "")
        raw_sector = str(item.get("raw_sector") or "")
        if (
            sector == theme
            or sector == raw_sector
            # An empty string is a substring of everything; without this guard
            # a view lacking a theme would match every sector.
            or (theme and theme in sector)
            or sector in raw_sector
        ):
            return item
    return {}
|
||
|
||
|
||
def _message_content(message: Any) -> str:
    """Extract the ``content`` of a chat message as a string.

    Accepts either a dict (key lookup) or any other object (attribute lookup).
    A falsy message, or a missing/falsy content value, yields ``""``.
    """
    if not message:
        return ""
    raw = message.get("content") if isinstance(message, dict) else getattr(message, "content", "")
    return str(raw or "")
|
||
|
||
|
||
def _extract_json_object(content: str) -> dict:
    """Parse a JSON object out of an LLM reply, tolerating fences and chatter.

    Steps:
      1. Strip surrounding whitespace and a leading/trailing Markdown code
         fence (three backticks, optionally tagged ``json``).
      2. Try to parse the whole remaining text as JSON.
      3. Otherwise decode a single JSON value starting at the first ``{`` and
         ignore whatever follows it.

    Returns:
        The decoded dict, or ``{}`` when the input is empty, is not valid JSON,
        or decodes to something other than an object.
    """
    if not content:
        return {}

    cleaned = content.strip()
    if cleaned.startswith("```"):
        cleaned = re.sub(r"^```(?:json)?", "", cleaned, flags=re.IGNORECASE).strip()
        cleaned = re.sub(r"```$", "", cleaned).strip()

    try:
        parsed = json.loads(cleaned)
    except (ValueError, RecursionError):
        # Not pure JSON; fall through to the embedded-object search.
        pass
    else:
        return parsed if isinstance(parsed, dict) else {}

    start = cleaned.find("{")
    if start < 0:
        return {}
    try:
        # raw_decode stops at the end of the first complete value, so braces
        # in trailing commentary cannot break the parse (a greedy
        # first-"{"-to-last-"}" span would include them and fail).
        parsed, _end = json.JSONDecoder().raw_decode(cleaned, start)
    except (ValueError, RecursionError):
        return {}
    return parsed if isinstance(parsed, dict) else {}
|
||
|
||
|
||
def _clean_text(value: Any, fallback: str, max_chars: int) -> str:
    """Return ``value`` as stripped text, or ``fallback`` when that is empty.

    A falsy ``value`` is treated as empty. The chosen text is truncated to at
    most ``max_chars`` characters; ``fallback`` itself is not stripped.
    """
    candidate = str(value).strip() if value else ""
    return (candidate or fallback)[:max_chars]
|
||
|
||
|
||
def _clamp_float(value: Any, fallback: float, minimum: float, maximum: float) -> float:
    """Coerce ``value`` to a float within ``[minimum, maximum]``, rounded to 0.1.

    Args:
        value: Anything ``float()`` might accept (number or numeric string).
        fallback: Used when ``value`` cannot be converted or is NaN; a falsy
            fallback counts as 0. The fallback is clamped as well.
        minimum: Lower bound of the result.
        maximum: Upper bound of the result.

    Returns:
        The clamped number rounded to one decimal place.
    """
    import math

    try:
        number = float(value)
    except (TypeError, ValueError, OverflowError):
        number = math.nan

    # NaN compares false with everything, so min(maximum, nan) would return
    # ``maximum`` and turn an invalid score into the top score.
    if math.isnan(number):
        number = float(fallback or 0)
    return round(max(minimum, min(maximum, number)), 1)
|