astock-agent/backend/app/research/industry_chain_agent.py
2026-06-10 08:36:25 +08:00

246 lines
11 KiB
Python

"""Theme and industry-chain mapping.
The first version is intentionally deterministic and editable. It gives the
research layer a stable vocabulary before LLM-assisted updates are introduced.
"""
from __future__ import annotations
from dataclasses import dataclass
import json
from typing import Any
from sqlalchemy import text
from app.db import tables
from app.db.database import get_db
@dataclass(frozen=True)
class ThemeChain:
    """Immutable description of one built-in theme and its industry chain.

    Instances are positional-friendly: ``ThemeChain(theme, aliases, nodes, logic)``
    with ``stage`` and ``lifecycle_status`` falling back to their defaults.
    """

    # Canonical theme name.
    theme: str
    # Keywords tested by substring match when resolving a name to this theme.
    aliases: tuple[str, ...]
    # Ordered chain nodes; the first one is used as a fallback by node inference.
    nodes: tuple[str, ...]
    # One-sentence investment logic for the theme.
    logic: str
    # Default stage label for the theme.
    stage: str = "mid"
    # Default lifecycle label for the theme.
    lifecycle_status: str = "观察期"
# Built-in (v1) theme map. Order matters: name resolution returns the first
# entry whose theme name or one of whose aliases is contained in the input.
#
# Every alias must be a non-empty string: ``"" in s`` is true for any ``s``,
# so an empty alias makes its theme match every input and shadows all entries
# listed after it.
THEME_CHAINS: tuple[ThemeChain, ...] = (
    ThemeChain("AI算力", ("AI", "算力", "人工智能", "光模块", "服务器", "液冷", "PCB", "CPO"), ("光模块", "PCB", "服务器", "液冷", "IDC", "电源", "铜缆高速连接"), "AI 基础设施扩张带动上游硬件和数据中心链条。"),
    ThemeChain("机器人", ("机器人", "人形机器人", "机器人装备", "减速器", "伺服", "控制器"), ("减速器", "伺服系统", "控制器", "传感器", "本体制造", "机器视觉"), "人形机器人产业趋势扩散,重点观察核心零部件和前排整机。"),
    ThemeChain("低空经济", ("低空", "飞行汽车", "eVTOL", "通航", "无人机"), ("整机", "飞控", "电池", "空管", "材料", "运营服务"), "政策驱动低空基础设施和飞行器产业链加速。"),
    ThemeChain("创新药", ("创新药", "医药", "CXO", "生物医药", "减肥药"), ("创新药企", "CXO", "原料药", "医疗器械", "商业化平台"), "政策、临床进展和出海订单共同驱动医药成长线。"),
    ThemeChain("军工", ("军工", "航天", "航空", "卫星", "船舶"), ("航空装备", "航天电子", "卫星互联网", "船舶", "军工材料"), "订单周期、装备升级和地缘扰动共同驱动军工链。"),
    ThemeChain("固态电池", ("固态电池", "锂电", "电池", "新能源车"), ("电解质", "正极材料", "负极材料", "隔膜", "设备", "整车验证"), "电池技术迭代推动材料和设备环节重估。"),
    ThemeChain("半导体", ("半导体", "芯片", "集成电路", "存储", "先进封装"), ("设备", "材料", "设计", "制造", "封测", "存储"), "国产替代和周期复苏共同影响半导体链。"),
    ThemeChain("传媒游戏", ("传媒", "游戏", "短剧", "影视", "AIGC"), ("游戏", "影视院线", "短剧", "营销", "版权IP", "AIGC应用"), "内容供给、AI 应用和监管周期变化驱动传媒弹性。"),
    ThemeChain("化工材料", ("化工", "材料", "有机硅", "氟化工", "化纤"), ("基础化工", "新材料", "氟化工", "有机硅", "化纤", "电子化学品"), "价格周期和新材料需求共同影响化工材料链。"),
    # NOTE(review): this entry previously carried three empty-string aliases,
    # which matched every name. They are removed here; if they were meant to be
    # single-character aliases (possibly 铜 / 铝 / 锂, judging by the node
    # names), restore them explicitly after confirming.
    ThemeChain("有色金属", ("有色", "黄金", "稀土"), ("铜铝", "贵金属", "稀土", "锂资源", "加工材"), "通胀预期、供需缺口和新能源需求影响资源品。"),
    ThemeChain("新能源", ("光伏", "风电", "储能", "新能源", "逆变器"), ("光伏组件", "逆变器", "储能", "风电设备", "电网设备"), "装机需求和价格出清影响新能源链条。"),
    ThemeChain("消费电子", ("消费电子", "苹果", "MR", "折叠屏", "智能穿戴"), ("结构件", "面板", "光学", "声学", "组装", "芯片"), "终端创新周期带动零部件弹性。"),
)
def resolve_theme(name: str, chains: tuple[ThemeChain, ...] | None = None) -> ThemeChain | None:
    """Return the first theme whose name or alias occurs in ``name``.

    Args:
        name: Sector/theme text to classify. ``None`` or ``""`` never matches.
        chains: Theme definitions to search, in priority order. Defaults to
            the module-level ``THEME_CHAINS``.

    Returns:
        The first matching theme, or ``None`` when nothing matches.
    """
    normalized = name or ""
    candidates = THEME_CHAINS if chains is None else chains
    for theme in candidates:
        # Blank names/aliases are skipped: ``"" in normalized`` is always
        # true and would turn the entry into a catch-all that hides every
        # theme listed after it.
        if theme.theme and theme.theme in normalized:
            return theme
        if any(alias and alias in normalized for alias in theme.aliases):
            return theme
    return None
async def ensure_theme_knowledge_seeded() -> None:
    """Fill the theme tables from ``THEME_CHAINS`` when ``theme_knowledge`` is empty.

    If the table already holds at least one row nothing is written. Otherwise
    one ``theme_knowledge`` row is inserted per built-in theme and one
    ``theme_chain_knowledge`` row per chain node, followed by a commit.
    """
    async with get_db() as db:
        count_result = await db.execute(text("SELECT COUNT(*) AS count FROM theme_knowledge"))
        existing_rows = int(count_result.fetchone()._mapping["count"])
        if existing_rows:
            return

        theme_rows = [
            {
                "theme_name": chain.theme,
                "aliases_json": json.dumps(list(chain.aliases), ensure_ascii=False),
                "logic_summary": chain.logic,
                "lifecycle_status": chain.lifecycle_status,
                "stage": chain.stage,
                "is_active": True,
                "sort_order": position,
            }
            for position, chain in enumerate(THEME_CHAINS)
        ]
        await db.execute(tables.theme_knowledge_table.insert(), theme_rows)

        # sort_order = theme position * 100 + node position, so nodes sort by
        # theme first and then by their order inside the theme.
        node_rows = [
            {
                "theme_name": chain.theme,
                "chain_node": node_name,
                "related_stocks": "[]",
                "leader_stocks": "[]",
                "node_role": "",
                "is_active": True,
                "sort_order": chain_position * 100 + node_position,
            }
            for chain_position, chain in enumerate(THEME_CHAINS)
            for node_position, node_name in enumerate(chain.nodes)
        ]
        if node_rows:
            await db.execute(tables.theme_chain_knowledge_table.insert(), node_rows)
        await db.commit()
async def load_theme_chain_library() -> list[dict[str, Any]]:
    """Load the active theme library from the database.

    Seeds the tables first via ``ensure_theme_knowledge_seeded``. Each returned
    entry has the keys ``theme``, ``aliases``, ``logic``, ``stage``,
    ``lifecycle_status``, ``chain_nodes`` (node names, or ``["未归类"]`` when
    the theme has none) and ``chain_items`` (one dict per node row).
    """
    await ensure_theme_knowledge_seeded()
    async with get_db() as db:
        theme_cursor = await db.execute(
            text(
                "SELECT * FROM theme_knowledge "
                "WHERE is_active = 1 ORDER BY sort_order ASC, id ASC"
            )
        )
        node_cursor = await db.execute(
            text(
                "SELECT * FROM theme_chain_knowledge "
                "WHERE is_active = 1 ORDER BY sort_order ASC, id ASC"
            )
        )

        # Group node rows by owning theme name, keeping query order.
        nodes_by_theme: dict[str, list[dict[str, Any]]] = {}
        for node_row in node_cursor.fetchall():
            node_record = dict(node_row._mapping)
            owner = str(node_record.get("theme_name") or "")
            if owner not in nodes_by_theme:
                nodes_by_theme[owner] = []
            nodes_by_theme[owner].append(node_record)

        library: list[dict[str, Any]] = []
        for theme_row in theme_cursor.fetchall():
            record = dict(theme_row._mapping)
            name = str(record.get("theme_name") or "")
            members = nodes_by_theme.get(name, [])
            node_names = [
                member.get("chain_node") for member in members if member.get("chain_node")
            ]
            chain_items = [
                {
                    "chain_node": member.get("chain_node") or "",
                    "related_stocks": _safe_json_list(member.get("related_stocks")),
                    "leader_stocks": _safe_json_list(member.get("leader_stocks")),
                    "node_role": member.get("node_role") or "",
                }
                for member in members
            ]
            library.append({
                "theme": name,
                "aliases": _safe_json_list(record.get("aliases_json")),
                "logic": record.get("logic_summary") or "",
                "stage": record.get("stage") or "mid",
                "lifecycle_status": record.get("lifecycle_status") or "观察期",
                "chain_nodes": node_names or ["未归类"],
                "chain_items": chain_items,
            })
    return library
def resolve_theme_from_library(name: str, library: list[dict[str, Any]]) -> dict[str, Any] | None:
    """Return the first library entry whose theme name or alias occurs in ``name``.

    Args:
        name: Sector/theme text to classify. ``None`` or ``""`` never matches.
        library: Theme entries carrying a ``theme`` name and an ``aliases``
            list, searched in order.

    Returns:
        The matching entry (the same dict object held by ``library``), or
        ``None`` when nothing matches.
    """
    normalized = name or ""
    for theme in library:
        theme_name = str(theme.get("theme") or "")
        # A blank theme name must not match: ``"" in normalized`` is always
        # true and would make the entry swallow every lookup.
        if theme_name and theme_name in normalized:
            return theme
        # ``aliases`` may be absent or explicitly None; treat both as empty.
        for alias in theme.get("aliases") or []:
            alias_text = "" if alias is None else str(alias)
            if alias_text and alias_text in normalized:
                return theme
    return None
def map_sector_to_chain(sector_name: str, leading_stocks: list[dict[str, Any]] | None = None) -> dict[str, Any]:
    """Map a sector name onto the built-in industry-chain view.

    Args:
        sector_name: Sector/theme text, resolved with ``resolve_theme``.
        leading_stocks: Optional leader list passed through unchanged
            (``None`` becomes ``[]``).

    Returns:
        A dict with ``theme``, ``logic``, ``chain_nodes``, ``chain_items`` and
        ``leader_stocks``. When the sector is not recognised, ``theme`` is the
        input name (or ``"未归类"`` if blank), ``chain_nodes`` is
        ``["未归类"]`` and ``chain_items`` is empty.
    """
    leaders = leading_stocks or []
    theme = resolve_theme(sector_name)
    if not theme:
        return {
            "theme": sector_name or "未归类",
            "logic": "暂未命中内置产业链图谱,保留为未归类主题等待后续研究补全。",
            "chain_nodes": ["未归类"],
            # Always present so both branches expose the same keys and
            # consumers can index ``chain_items`` without a KeyError.
            "chain_items": [],
            "leader_stocks": leaders,
        }
    return {
        "theme": theme.theme,
        "logic": theme.logic,
        "chain_nodes": list(theme.nodes),
        "chain_items": [
            {"chain_node": node, "related_stocks": [], "leader_stocks": [], "node_role": ""}
            for node in theme.nodes
        ],
        "leader_stocks": leaders,
    }
def map_sector_to_chain_from_library(
    sector_name: str,
    leading_stocks: list[dict[str, Any]] | None,
    library: list[dict[str, Any]],
) -> dict[str, Any]:
    """Map a sector name onto a chain view using the editable library.

    The library is consulted first via ``resolve_theme_from_library``; when it
    yields nothing the built-in ``map_sector_to_chain`` result is returned
    instead. Library hits additionally carry ``stage`` and
    ``lifecycle_status``.
    """
    matched = resolve_theme_from_library(sector_name, library)
    if matched:
        stage = matched.get("stage") or "mid"
        lifecycle = matched.get("lifecycle_status") or "观察期"
        return {
            "theme": matched["theme"],
            "logic": matched["logic"] or "主题逻辑待补充。",
            "chain_nodes": matched["chain_nodes"],
            "chain_items": matched.get("chain_items", []),
            "leader_stocks": leading_stocks or [],
            "stage": stage,
            "lifecycle_status": lifecycle,
        }
    # No usable library entry: fall back to the built-in map.
    return map_sector_to_chain(sector_name, leading_stocks)
def infer_chain_node(theme_name: str, stock_name: str = "", sector_name: str = "") -> str:
    """Guess which chain node a stock belongs to using the built-in theme map.

    The theme is resolved from ``theme_name`` first and ``sector_name`` second.
    The first node whose name appears in the concatenated stock/sector/theme
    text wins; otherwise the theme's first node is returned. ``"未归类"`` is
    returned when no theme resolves or the theme has no nodes.
    """
    chain = resolve_theme(theme_name) or resolve_theme(sector_name)
    if not chain:
        return "未归类"
    haystack = f"{stock_name}{sector_name}{theme_name}"
    matched_node = next((node for node in chain.nodes if node in haystack), None)
    if matched_node is not None:
        return matched_node
    if chain.nodes:
        return chain.nodes[0]
    return "未归类"
def infer_chain_position_from_theme_view(theme_view: dict[str, Any], ts_code: str = "", stock_name: str = "") -> dict[str, str]:
    """Infer a candidate's chain node and role from a theme view.

    Resolution order:
      1. Per chain item, membership in ``leader_stocks`` and then
         ``related_stocks`` (role defaults: ``核心股`` / ``相关股``).
      2. The first chain item whose node name appears in the stock name,
         ``raw_sector`` or ``theme`` text (role default: ``环节标的``).
      3. The first entry of ``chain_nodes``, or ``infer_chain_node`` when that
         list is empty, with the role ``待归类``.

    Returns:
        A dict with the keys ``chain_node`` and ``stock_role``.
    """
    entries = theme_view.get("chain_items") or []

    # Pass 1: explicit stock lists; leaders are checked before related stocks.
    list_defaults = (("leader_stocks", "核心股"), ("related_stocks", "相关股"))
    for entry in entries:
        node_name = str(entry.get("chain_node") or "")
        role = str(entry.get("node_role") or "")
        for list_key, default_role in list_defaults:
            if _stock_in_list(ts_code, stock_name, entry.get(list_key) or []):
                return {"chain_node": node_name or "未归类", "stock_role": role or default_role}

    # Pass 2: node name mentioned in the descriptive text.
    haystack = f"{stock_name}{theme_view.get('raw_sector', '')}{theme_view.get('theme', '')}"
    for entry in entries:
        node_name = str(entry.get("chain_node") or "")
        if node_name and node_name in haystack:
            return {"chain_node": node_name, "stock_role": str(entry.get("node_role") or "环节标的")}

    # Pass 3: fall back to the view's first node, or the built-in inference.
    known_nodes = theme_view.get("chain_nodes") or []
    if known_nodes:
        fallback = str(known_nodes[0])
    else:
        fallback = infer_chain_node(
            str(theme_view.get("theme") or ""),
            stock_name,
            str(theme_view.get("raw_sector") or ""),
        )
    return {"chain_node": fallback or "未归类", "stock_role": "待归类"}
def _safe_json_list(value: Any) -> list[Any]:
if isinstance(value, list):
return value
if not value:
return []
try:
parsed = json.loads(str(value))
return parsed if isinstance(parsed, list) else []
except Exception:
return []
def _stock_in_list(ts_code: str, stock_name: str, values: list[Any]) -> bool:
code = (ts_code or "").upper()
name = stock_name or ""
for value in values:
if isinstance(value, dict):
raw_code = str(value.get("ts_code") or value.get("code") or "").upper()
raw_name = str(value.get("name") or "")
if raw_code and raw_code == code:
return True
if raw_name and raw_name == name:
return True
else:
text_value = str(value)
if text_value and (text_value.upper() == code or text_value == name):
return True
return False