astock-agent/backend/app/catalyst/mapper.py
2026-05-14 11:10:17 +08:00

263 lines
8.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""新闻/政策催化归因。
边界:
- LLM 只负责把文本映射到题材、提炼催化类型和解释。
- 行情、资金、最终动作仍由规则引擎决定。
"""
from __future__ import annotations

import json
import logging
import math
import re
from datetime import datetime, timezone

from app.analysis.theme_mapper import THEME_ALIASES, THEME_NAMES, resolve_theme
from app.catalyst.models import CatalystAnalysis, CatalystInput, CatalystTheme
from app.config import settings
logger = logging.getLogger(__name__)

# Catalyst type -> trigger keywords (plain substring match).
# `_infer_catalyst_type` scans these groups in insertion order and returns the
# first type with any hit, so the order below is also the priority: a text
# containing both "政策" and "订单" is classified as "policy". No hit -> "news".
CATALYST_TYPE_KEYWORDS = {
    "policy": ["政策", "工信部", "发改委", "国务院", "证监会", "财政部", "规划", "指导意见", "补贴"],
    "industry": ["订单", "需求", "涨价", "产能", "景气", "出口", "交付", "装机", "销量"],
    "event": ["大会", "发布会", "展会", "会议", "试点", "招标", "中标", "事故"],
    "earnings": ["业绩", "净利润", "营收", "预增", "扭亏", "年报", "季报"],
    "announcement": ["公告", "重组", "并购", "定增", "回购", "签订合同"],
}
# Strength bonus points -> keywords. `_score_strength` adds each tier's bonus
# once if at least one of its keywords appears in the text, so the tiers are
# cumulative (a text hitting all three tiers gains 18 + 12 + 8).
STRENGTH_KEYWORDS = {
    18: ["重大", "重磅", "首次", "超预期", "全面", "国家级"],
    12: ["政策", "补贴", "涨价", "订单", "中标", "突破"],
    8: ["试点", "规划", "发布", "扩产", "合作"],
}
async def analyze_catalyst(item: CatalystInput, use_llm: bool = True) -> CatalystAnalysis:
    """Attribute a single catalyst text to themes.

    The rule-based analysis is always computed first. It is returned as-is
    when ``use_llm`` is false or no DeepSeek API key is configured, and it is
    also the fallback whenever the LLM path yields nothing.
    """
    fallback = _analyze_by_rules(item)
    llm_enabled = use_llm and bool(settings.deepseek_api_key)
    if llm_enabled:
        enriched = await _analyze_by_llm(item, fallback)
        if enriched:
            return enriched
    return fallback
def _analyze_by_rules(item: CatalystInput) -> CatalystAnalysis:
    """Attribute *item* using keyword rules only (no LLM call).

    Confidence starts at 45, gains 12 per matched theme (capped at +35),
    gets +8 for policy/announcement catalysts, and is capped at 90.
    Strength is capped at 100.

    Args:
        item: The catalyst text; ``content`` may be empty or missing.

    Returns:
        A ``generated_by="rules"`` analysis.
    """
    # Join only the parts that are present: formatting a missing body with an
    # f-string would embed the literal "None" in the matched/stored text.
    text = "\n".join(str(part) for part in (item.title, item.content) if part).strip()
    themes = _match_themes(text)
    catalyst_type = _infer_catalyst_type(text)
    strength = _score_strength(text, themes, catalyst_type)
    freshness = _score_freshness(item.published_at)
    confidence = 45 + min(len(themes) * 12, 35)
    if catalyst_type in {"policy", "announcement"}:
        confidence += 8
    return CatalystAnalysis(
        title=item.title,
        summary=_summarize_text(item.content or item.title),
        source=item.source,
        url=item.url,
        published_at=item.published_at,
        catalyst_type=catalyst_type,
        strength=min(strength, 100),
        freshness=freshness,
        confidence=min(confidence, 90),
        themes=themes,
        raw_text=text,
        generated_by="rules",
    )
async def _analyze_by_llm(item: CatalystInput, baseline: CatalystAnalysis) -> CatalystAnalysis | None:
    """Refine a rule-based attribution with the DeepSeek LLM.

    Args:
        item: The news/policy/announcement text to attribute.
        baseline: The rule-based result. Its ``freshness`` and ``raw_text``
            are always reused; its summary, type, strength, confidence and
            themes fill in any field the LLM omits or returns unusably.

    Returns:
        A ``generated_by="llm"`` analysis, or ``None`` when no client is
        available, the reply contains no JSON object, or the call/parsing
        fails (failures are logged as warnings).
    """
    from app.llm.client import get_client

    client = get_client()
    if not client:
        return None
    user_text = _build_llm_user_prompt(item)
    try:
        response = await client.chat.completions.create(
            model=settings.deepseek_model,
            messages=[
                {
                    "role": "system",
                    "content": (
                        "你是A股新闻催化归因器。"
                        "你只能做题材归因、催化类型和强度判断,不能输出买入卖出建议。"
                        "必须返回合法JSON。"
                    ),
                },
                {"role": "user", "content": user_text},
            ],
            max_tokens=700,
            temperature=0.1,
        )
        data = _extract_json(response.choices[0].message.content or "")
        if not isinstance(data, dict) or not data:
            return None
        themes = _parse_llm_themes(data.get("themes")) or baseline.themes
        return CatalystAnalysis(
            title=item.title,
            summary=_llm_text(data.get("summary")) or baseline.summary,
            source=item.source,
            url=item.url,
            published_at=item.published_at,
            catalyst_type=_normalize_type(data.get("catalyst_type")) or baseline.catalyst_type,
            strength=_clamp_float(data.get("strength"), 0, 100, baseline.strength),
            freshness=baseline.freshness,
            confidence=_clamp_float(data.get("confidence"), 0, 100, baseline.confidence),
            themes=themes,
            raw_text=baseline.raw_text,
            llm_reason=_llm_text(data.get("reason")),
            generated_by="llm",
        )
    except Exception as e:
        # Best-effort boundary: the caller falls back to the rule result.
        logger.warning("LLM 催化归因失败: %s", e)
        return None


def _build_llm_user_prompt(item: CatalystInput) -> str:
    """Render the attribution prompt, listing every system theme with up to 8 aliases."""
    aliases_text = "\n".join(
        f"- {THEME_NAMES[theme_id]}: {', '.join(aliases[:8])}"
        for theme_id, aliases in THEME_ALIASES.items()
    )
    return f"""\
请把下面新闻/政策/公告归因到 A 股题材。只做语义归因,不给买卖建议。
## 可选系统题材
{aliases_text}
## 文本
标题: {item.title}
来源: {item.source}
正文: {(item.content or '')[:1600]}
请严格输出 JSON
{{
"summary": "一句话摘要",
"catalyst_type": "policy | industry | event | earnings | announcement | news",
"strength": 0-100,
"confidence": 0-100,
"themes": [
{{"theme_name": "系统题材名或新题材名", "relevance": 0-100, "reason": "一句话"}}
],
"reason": "为什么这么归因"
}}"""


def _llm_text(value) -> str:
    """Stringify an LLM JSON field, mapping a missing/``null`` value to "" (not "None")."""
    return "" if value is None else str(value).strip()


def _parse_llm_themes(raw_themes) -> list[CatalystTheme]:
    """Convert the LLM ``themes`` array into at most five resolved, de-duplicated themes.

    A non-list value yields ``[]``. Entries that are not JSON objects or have no
    name are skipped. Relevance defaults to 60 when missing or invalid.
    """
    if not isinstance(raw_themes, list):
        return []
    themes: list[CatalystTheme] = []
    seen = set()
    for raw_theme in raw_themes:
        if len(themes) >= 5:
            break
        if not isinstance(raw_theme, dict):
            continue
        theme_name = _llm_text(raw_theme.get("theme_name"))
        if not theme_name:
            continue
        theme_id, resolved_name, _ = resolve_theme(theme_name)
        key = (theme_id, resolved_name)
        if key in seen:
            # Two aliases resolving to the same theme would otherwise be counted twice.
            continue
        seen.add(key)
        themes.append(CatalystTheme(
            theme_id=theme_id,
            theme_name=resolved_name,
            relevance=_clamp_float(raw_theme.get("relevance"), 0, 100, 60),
            reason=_llm_text(raw_theme.get("reason")),
        ))
    return themes
def _match_themes(text: str) -> list[CatalystTheme]:
    """Match *text* against every theme's alias list by substring search.

    Both the text and each alias are normalised with ``_clean`` first. A
    theme's relevance is 55 + 12 per matching alias, capped at 95. At most
    five themes are returned, highest relevance first; ties keep
    THEME_ALIASES order.
    """
    haystack = _clean(text)
    candidates: list[CatalystTheme] = []
    for theme_id, aliases in THEME_ALIASES.items():
        hits = [
            alias
            for alias, needle in zip(aliases, map(_clean, aliases))
            if needle and needle in haystack
        ]
        if hits:
            candidates.append(CatalystTheme(
                theme_id=theme_id,
                theme_name=THEME_NAMES[theme_id],
                relevance=min(55 + 12 * len(hits), 95),
                reason=f"命中关键词: {'/'.join(hits[:3])}",
            ))
    ranked = sorted(candidates, key=lambda theme: theme.relevance, reverse=True)
    return ranked[:5]
def _infer_catalyst_type(text: str) -> str:
    """Return the first CATALYST_TYPE_KEYWORDS type (in dict order) whose keywords hit *text*.

    Falls back to "news" when no keyword group matches.
    """
    matches = (
        kind
        for kind, words in CATALYST_TYPE_KEYWORDS.items()
        if any(word in text for word in words)
    )
    return next(matches, "news")
def _score_strength(text: str, themes: list[CatalystTheme], catalyst_type: str) -> float:
    """Heuristic 0-100 strength score for a catalyst.

    Starts from 35.0 and adds, in order:
    - a quarter of the best theme relevance, capped at 25 (only if there are themes);
    - the bonus of every STRENGTH_KEYWORDS tier with a keyword hit;
    - a type bonus (+10 for "policy", +6 for "announcement").
    The total is capped at 100 and rounded to one decimal place.
    """
    increments = []
    if themes:
        best = max(theme.relevance for theme in themes)
        increments.append(min(best * 0.25, 25))
    increments.extend(
        bonus
        for bonus, words in STRENGTH_KEYWORDS.items()
        if any(word in text for word in words)
    )
    type_bonus = {"policy": 10, "announcement": 6}.get(catalyst_type)
    if type_bonus is not None:
        increments.append(type_bonus)
    total = 35.0
    for step in increments:
        # Sequential addition (not sum()) keeps float rounding identical.
        total += step
    return round(min(total, 100), 1)
def _score_freshness(published_at: datetime | None) -> float:
if not published_at:
return 70
dt = published_at
if dt.tzinfo is None:
dt = dt.replace(tzinfo=timezone.utc)
hours = max((datetime.now(timezone.utc) - dt.astimezone(timezone.utc)).total_seconds() / 3600, 0)
if hours <= 6:
return 100
if hours <= 24:
return 90
if hours <= 72:
return 70
if hours <= 168:
return 45
return 20
def _summarize_text(text: str) -> str:
    """Collapse every whitespace run to one space, trim, and keep the first 120 characters.

    ``None`` or empty input yields "".
    """
    collapsed = " ".join((text or "").split())
    return collapsed[:120]
def _clean(value: str) -> str:
    """Normalise text for alias matching by deleting whitespace and common separators.

    Removes whitespace, ``_``, ``-``, brackets, and both the ASCII and the
    full-width (CJK) forms of parentheses, comma and colon. This way
    "CPO(光模块)" and "CPO（光模块）" normalise identically. ``None`` -> "".
    """
    return re.sub(r"[\s_\-()（）【】\[\]、,，。:：]+", "", value or "")
def _extract_json(text: str) -> dict:
    """Extract a JSON object from an LLM reply.

    Accepts a bare object, an object wrapped in a Markdown ```/```json code
    fence, or an object surrounded by prose (the outermost ``{...}`` span is
    tried as a fallback). Always returns a ``dict``. Unparseable text and
    non-object JSON (arrays, strings, numbers) yield ``{}`` unless the
    fallback span parses to an object.
    """
    text = (text or "").strip()
    if text.startswith("```"):
        # Strip the opening fence (optionally tagged "json") and a trailing fence.
        text = re.sub(r"^```(?:json)?", "", text).strip()
        text = re.sub(r"```$", "", text).strip()
    candidates = [text]
    start = text.find("{")
    end = text.rfind("}")
    if start >= 0 and end > start:
        candidates.append(text[start:end + 1])
    for candidate in candidates:
        try:
            data = json.loads(candidate)
        except (ValueError, RecursionError):
            # JSONDecodeError is a ValueError; absurdly deep nesting raises RecursionError.
            continue
        if isinstance(data, dict):
            return data
    return {}
def _normalize_type(value) -> str:
    """Return *value* trimmed and lower-cased if it names a known catalyst type, else ""."""
    candidate = str(value or "").strip().lower()
    allowed = ("policy", "industry", "event", "earnings", "announcement", "news")
    if candidate in allowed:
        return candidate
    return ""
def _clamp_float(value, minimum: float, maximum: float, default: float) -> float:
    """Coerce *value* to a float clamped to ``[minimum, maximum]``.

    Returns *default* (unclamped) when *value* cannot be converted, is an
    integer too large for a float, or is not finite. The finiteness check
    matters because ``json.loads`` accepts the ``NaN``/``Infinity`` literals,
    and ``min``/``max`` would silently map ``NaN`` to *maximum*.
    """
    try:
        num = float(value)
    except (TypeError, ValueError, OverflowError):
        return default
    if not math.isfinite(num):
        return default
    return max(minimum, min(maximum, num))