"""Stock research note generator. The stock research layer is explanatory. It enriches rule-selected candidates with theme, catalyst and risk context, while deterministic notes remain the fallback when the LLM or local catalyst data is unavailable. """ from __future__ import annotations import json import logging import re from typing import Any from sqlalchemy import text from app.config import settings from app.db.database import get_db from app.llm.client import chat_completion from app.research.industry_chain_agent import infer_chain_node, infer_chain_position_from_theme_view logger = logging.getLogger(__name__) def build_stock_research_notes_sync(recommendations: list[Any], theme_views: list[dict]) -> list[dict]: """Build deterministic notes used by tests, fallback APIs and LLM failures.""" theme_names = {item["theme"] for item in theme_views} notes: list[dict] = [] for rec in recommendations[:20]: notes.append(_fallback_note(rec, theme_names, [], _match_theme_view(rec, theme_views))) return notes async def build_stock_research_notes( recommendations: list[Any], theme_views: list[dict], risk_alerts: list[dict] | None = None, ) -> list[dict]: """Build stock research notes, using LLM for top candidates when configured.""" if not recommendations: return [] theme_names = {item["theme"] for item in theme_views} theme_map = {item["theme"]: item for item in theme_views} risk_map = _group_risks(risk_alerts or []) notes: list[dict] = [] llm_limit = max(0, int(settings.research_stock_llm_limit or 0)) for index, rec in enumerate(recommendations[:20]): theme = getattr(rec, "sector", "") or "未归类" theme_view = _match_theme_view(rec, theme_views) catalysts = await _load_local_catalyst_context(rec, theme) fallback = _fallback_note(rec, theme_names, catalysts, theme_view) if ( not settings.research_stock_llm_enabled or index >= llm_limit or not settings.deepseek_api_key ): notes.append(fallback) continue try: llm_note = await _build_llm_note( rec=rec, fallback=fallback, theme_view=theme_view or theme_map.get(theme, {}), risks=risk_map.get(getattr(rec, "ts_code", ""), []), catalysts=catalysts, ) notes.append(llm_note or fallback) except Exception as exc: logger.warning("股票研究笔记 LLM 生成失败 ts_code=%s error=%s", getattr(rec, "ts_code", ""), exc) notes.append(fallback) return notes def _fallback_note(rec: Any, theme_names: set[str], catalysts: list[dict], theme_view: dict | None = None) -> dict: trace = getattr(rec, "decision_trace", {}) or {} evidence = trace.get("evidence") or getattr(rec, "reasons", []) or [] catalyst_titles = [item.get("title", "") for item in catalysts if item.get("title")] theme = (theme_view or {}).get("theme") or getattr(rec, "sector", "") or "未归类" position = ( infer_chain_position_from_theme_view(theme_view, getattr(rec, "ts_code", ""), getattr(rec, "name", "")) if theme_view else {"chain_node": infer_chain_node(theme, getattr(rec, "name", ""), getattr(rec, "sector", "")), "stock_role": "待归类"} ) chain_node = position["chain_node"] stock_role = position["stock_role"] base = float(getattr(rec, "score", 0) or 0) logic_score = min(100, base + (8 if theme in theme_names else 0) + min(len(catalyst_titles) * 1.5, 6)) action = getattr(rec, "action_plan", "观察") or "观察" invalid = getattr(rec, "invalidation_condition", "") or getattr(rec, "risk_note", "") or "板块热度回落、资金持续性不足或买点未触发。" return { "ts_code": getattr(rec, "ts_code", ""), "name": getattr(rec, "name", ""), "theme": theme, "chain_node": chain_node, "stock_role": stock_role, "logic_score": round(logic_score, 1), "logic_summary": f"{getattr(rec, 'name', '')} 属于 {theme} 方向,产业链位置为{chain_node}({stock_role}),当前结论为{action}。", "evidence": (evidence + catalyst_titles)[:5], "uncertainty": getattr(rec, "risk_note", "") or "等待后续公告、资金持续性和板块生命周期验证。", "disagreement": "若主线扩散失败、成交额无法维持或同板块核心股转弱,当前逻辑需要降级。", "invalid_condition": invalid, "generated_by": "rules", } async def _build_llm_note( rec: Any, fallback: dict, theme_view: dict, risks: list[dict], catalysts: list[dict], ) -> dict | None: payload = { "stock": { "ts_code": fallback["ts_code"], "name": fallback["name"], "theme": fallback["theme"], "chain_node": fallback["chain_node"], "stock_role": fallback.get("stock_role", "待归类"), "score": getattr(rec, "score", 0), "action_plan": getattr(rec, "action_plan", "观察"), "trigger": getattr(rec, "trigger_condition", "") or getattr(rec, "entry_timing", ""), "invalid_condition": getattr(rec, "invalidation_condition", "") or getattr(rec, "risk_note", ""), "decision_trace": getattr(rec, "decision_trace", {}) or {}, "reasons": getattr(rec, "reasons", []) or [], }, "theme_view": theme_view, "recent_catalysts": catalysts, "risk_alerts": risks, } messages = [ { "role": "system", "content": ( "你是A股研究员。你的任务是把系统筛出的候选标的整理成研究笔记," "用于解释机会逻辑、证据、分歧点和失效条件。不要给无条件买入建议," "不要编造未提供的数据。只输出合法JSON。" ), }, { "role": "user", "content": ( "请基于以下结构化输入生成研究笔记。输出JSON字段必须为:" "logic_score(0-100数字), logic_summary(80字内), evidence(字符串数组,最多5条), " "disagreement(80字内), uncertainty(80字内), invalid_condition(80字内)。\n\n" f"{json.dumps(payload, ensure_ascii=False, default=str)}" ), }, ] message = await chat_completion(messages) content = _message_content(message) parsed = _extract_json_object(content) if not parsed: return None evidence = parsed.get("evidence") if isinstance(parsed.get("evidence"), list) else fallback["evidence"] note = dict(fallback) note.update({ "logic_score": _clamp_float(parsed.get("logic_score"), fallback["logic_score"], 0, 100), "logic_summary": _clean_text(parsed.get("logic_summary"), fallback["logic_summary"], 120), "evidence": [_clean_text(item, "", 80) for item in evidence if _clean_text(item, "", 80)][:5], "disagreement": _clean_text(parsed.get("disagreement"), fallback["disagreement"], 120), "uncertainty": _clean_text(parsed.get("uncertainty"), fallback["uncertainty"], 120), "invalid_condition": _clean_text(parsed.get("invalid_condition"), fallback["invalid_condition"], 120), "generated_by": "llm", }) if not note["evidence"]: note["evidence"] = fallback["evidence"] return note async def _load_local_catalyst_context(rec: Any, theme: str) -> list[dict]: name = str(getattr(rec, "name", "") or "") ts_code = str(getattr(rec, "ts_code", "") or "") if not theme and not name and not ts_code: return [] limit = max(1, int(settings.research_stock_news_limit or 6)) params = { "theme": theme, "name": f"%{name}%" if name else "%", "ts_code": f"%{ts_code}%" if ts_code else "%", "limit": limit, } try: async with get_db() as db: result = await db.execute( text( "SELECT c.title, c.summary, c.source, c.published_at, c.catalyst_type, " "c.strength, c.confidence, tc.theme_name, tc.reason " "FROM catalysts c " "LEFT JOIN theme_catalysts tc ON tc.catalyst_id = c.id " "WHERE c.is_active = 1 AND (" "tc.theme_name = :theme OR c.title LIKE :name OR c.summary LIKE :name OR c.raw_text LIKE :ts_code" ") " "ORDER BY COALESCE(c.published_at, c.created_at) DESC, c.id DESC " "LIMIT :limit" ), params, ) rows = [] for row in result.fetchall(): item = dict(row._mapping) rows.append({ "title": item.get("title", ""), "summary": item.get("summary", ""), "source": item.get("source", ""), "published_at": str(item.get("published_at") or ""), "theme": item.get("theme_name", ""), "reason": item.get("reason", ""), "strength": item.get("strength", 0), "confidence": item.get("confidence", 0), }) return rows except Exception as exc: logger.debug("读取本地催化上下文失败 ts_code=%s error=%s", ts_code, exc) return [] def _group_risks(risks: list[dict]) -> dict[str, list[dict]]: grouped: dict[str, list[dict]] = {} for risk in risks: ts_code = str(risk.get("ts_code") or "") if ts_code: grouped.setdefault(ts_code, []).append(risk) return grouped def _match_theme_view(rec: Any, theme_views: list[dict]) -> dict: sector = str(getattr(rec, "sector", "") or "") for item in theme_views: theme = str(item.get("theme") or "") raw_sector = str(item.get("raw_sector") or "") if sector and (sector == theme or sector == raw_sector or theme in sector or sector in raw_sector): return item return {} def _message_content(message: Any) -> str: if not message: return "" if isinstance(message, dict): return str(message.get("content") or "") return str(getattr(message, "content", "") or "") def _extract_json_object(content: str) -> dict: if not content: return {} cleaned = content.strip() if cleaned.startswith("```"): cleaned = re.sub(r"^```(?:json)?", "", cleaned, flags=re.IGNORECASE).strip() cleaned = re.sub(r"```$", "", cleaned).strip() try: parsed = json.loads(cleaned) return parsed if isinstance(parsed, dict) else {} except Exception: pass match = re.search(r"\{.*\}", cleaned, flags=re.DOTALL) if not match: return {} try: parsed = json.loads(match.group(0)) return parsed if isinstance(parsed, dict) else {} except Exception: return {} def _clean_text(value: Any, fallback: str, max_chars: int) -> str: text_value = str(value or "").strip() if not text_value: text_value = fallback return text_value[:max_chars] def _clamp_float(value: Any, fallback: float, minimum: float, maximum: float) -> float: try: number = float(value) except Exception: number = float(fallback or 0) return round(max(minimum, min(maximum, number)), 1)