astock-agent/backend/app/api/research.py
2026-06-10 08:36:25 +08:00

265 lines
9.9 KiB
Python

"""Research report APIs."""
from __future__ import annotations
import json
from datetime import datetime, timedelta
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy import text
from app.core.deps import get_current_admin
from app.db import tables
from app.db.database import get_db
from app.research.industry_chain_agent import ensure_theme_knowledge_seeded, load_theme_chain_library
from app.research.report_agent import load_latest_research_report
# Every endpoint in this module is mounted under the /api/research prefix.
router = APIRouter(prefix="/api/research", tags=["research"])
class ThemeKnowledgeUpdate(BaseModel):
    """Request body for ``PUT /api/research/theme-knowledge/{theme_name}``.

    Describes one theme and its industry-chain nodes. The handler strips
    string fields before storing them.
    """

    # Name to store for the theme; when it differs from the path parameter the
    # theme is renamed.
    theme_name: str = Field(min_length=1)
    # Alternative names; blank entries are dropped and the rest stored as JSON.
    aliases: list[str] = Field(default_factory=list)
    # Free-text rationale, stored as ``logic_summary``.
    logic: str = ""
    # Lifecycle label; "观察期" means "observation period".
    lifecycle_status: str = "观察期"
    # Stage label; NOTE(review): the set of valid values is not visible here.
    stage: str = "mid"
    # Plain chain-node names; ignored whenever ``chain_items`` is non-empty.
    chain_nodes: list[str] = Field(default_factory=list)
    # Rich chain nodes: dicts read via the keys ``chain_node``,
    # ``related_stocks``, ``leader_stocks`` and ``node_role``.
    chain_items: list[dict] = Field(default_factory=list)
    is_active: bool = True
    sort_order: int = 0
@router.get("/today")
async def get_today_research():
    """Return the current research report.

    Prefers the most recently persisted report. When none is stored, a report
    is built on the fly from the latest recommendations if a scan exists;
    otherwise an empty placeholder report is returned.
    """
    stored = await load_latest_research_report()
    if stored:
        return stored

    from app.engine.recommender import get_latest_recommendations
    from app.research.report_agent import build_research_report_async

    recommendations = await get_latest_recommendations()
    scan = recommendations.get("latest_scan") or {}
    if not scan:
        # No completed scan yet: hand back a well-formed but empty report.
        return {
            "trade_date": datetime.now().strftime("%Y%m%d"),
            "scan_session": "",
            "scanned_at": "",
            "market_view": {"regime": "unknown", "confidence": 0, "summary": "暂无研究报告。"},
            "theme_views": [],
            "industry_chain_map": [],
            "opportunity_cards": [],
            "risk_alerts": [],
            "no_trade_reason": {"has_scan": False, "reason": "暂无完成扫描。"},
        }
    session = scan.get("scan_session") or "latest"
    return await build_research_report_async(recommendations, session)
@router.get("/history")
async def get_research_history(days: int = 14):
    """List summaries of research reports created in the last ``days`` days.

    Args:
        days: Look-back window in days. Negative values are treated as 0
            (today only). A window reaching before year 1 imposes no lower
            bound instead of raising.

    Returns:
        Up to 60 dicts, newest first, with ``scan_session``, ``trade_date``,
        ``market_summary``, ``theme_summary``, ``no_trade_reason`` (decoded to
        a dict, ``{}`` when missing or malformed) and ``created_at`` (str).
    """
    try:
        start_dt = datetime.now() - timedelta(days=max(days, 0))
    except OverflowError:
        # An absurdly large window would otherwise surface as a 500 error.
        start_dt = datetime.min
    # isoformat() always zero-pads the year, unlike strftime("%Y") on some
    # platforms; for ordinary dates it equals "%Y-%m-%d".
    start = start_dt.date().isoformat()
    async with get_db() as db:
        # report_json is deliberately not selected: it is never returned and
        # can be large.
        result = await db.execute(
            text(
                "SELECT scan_session, trade_date, market_summary, theme_summary, no_trade_reason, created_at "
                "FROM research_reports WHERE created_at >= :start "
                "ORDER BY created_at DESC, id DESC LIMIT 60"
            ),
            {"start": start},
        )
        rows = []
        for row in result.fetchall():
            r = row._mapping
            rows.append({
                "scan_session": r["scan_session"],
                "trade_date": r["trade_date"],
                "market_summary": r["market_summary"] or "",
                "theme_summary": r["theme_summary"] or "",
                "no_trade_reason": _safe_json(r["no_trade_reason"]),
                "created_at": str(r["created_at"] or ""),
            })
    return rows
@router.get("/themes")
async def get_research_themes():
    """Return all theme_maps rows for the latest research report's scan session.

    The latest session is taken from the newest research_reports row (by
    created_at, then id). Rows come back as plain dicts, hottest first; if no
    report exists the subquery yields NULL and the list is empty.
    """
    query = text(
        "SELECT * FROM theme_maps "
        "WHERE scan_session = ("
        "SELECT scan_session FROM research_reports ORDER BY created_at DESC, id DESC LIMIT 1"
        ") "
        "ORDER BY heat_score DESC"
    )
    async with get_db() as db:
        fetched = (await db.execute(query)).fetchall()
    return [dict(record._mapping) for record in fetched]
@router.get("/opportunities")
async def get_research_opportunities():
    """Return opportunity_cards rows for the latest research report's scan session.

    Rows are plain dicts ordered by descending score; the list is empty when
    no research report exists.
    """
    latest_session = (
        "(SELECT scan_session FROM research_reports ORDER BY created_at DESC, id DESC LIMIT 1) "
    )
    sql = "SELECT * FROM opportunity_cards WHERE scan_session = " + latest_session + "ORDER BY score DESC"
    async with get_db() as db:
        result = await db.execute(text(sql))
        cards = []
        for record in result.fetchall():
            cards.append(dict(record._mapping))
    return cards
@router.get("/risks")
async def get_research_risks():
    """Return risk_events rows for the latest research report's scan session.

    Ordering: rows with higher ``reject`` first, then by descending severity,
    then newest id first. Empty when no research report exists.
    """
    statement = text(
        "SELECT * FROM risk_events "
        "WHERE scan_session = (SELECT scan_session FROM research_reports "
        "ORDER BY created_at DESC, id DESC LIMIT 1) "
        "ORDER BY reject DESC, severity DESC, id DESC"
    )
    async with get_db() as db:
        cursor = await db.execute(statement)
        entries = cursor.fetchall()
        return [dict(entry._mapping) for entry in entries]
@router.get("/review")
async def get_research_review(days: int = 60):
    """Return the research review over the last ``days`` days (built by review_agent)."""
    from app.research.review_agent import build_research_review as _build_review

    review = await _build_review(days=days)
    return review
@router.get("/theme-knowledge")
async def get_theme_knowledge():
    """Return the theme industry-chain library exactly as load_theme_chain_library yields it."""
    library = await load_theme_chain_library()
    return library
@router.put("/theme-knowledge/{theme_name}")
async def update_theme_knowledge(
    theme_name: str,
    payload: ThemeKnowledgeUpdate,
    _admin: dict = Depends(get_current_admin),
):
    """Create or update a theme and replace its industry-chain nodes (admin only).

    The path ``theme_name`` selects the theme to edit. ``payload.theme_name``
    (stripped) is the name actually stored, so a different payload name
    renames the theme. Chain nodes come from ``payload.chain_items`` when that
    list is non-empty, otherwise from ``payload.chain_nodes``. Every existing
    chain row under the old and new names is replaced.

    Returns:
        The theme's entry from ``load_theme_chain_library()`` when present,
        otherwise ``{"status": "ok", "theme": <stored name>}``.

    Raises:
        HTTPException(400): the theme name is blank or no chain node remains.
        HTTPException(409): the new name already belongs to a different theme.
    """
    await ensure_theme_knowledge_seeded()
    normalized_name = payload.theme_name.strip() or theme_name.strip()
    if not normalized_name:
        raise HTTPException(status_code=400, detail="主题名称不能为空")
    # Normalize once. Rich chain_items take precedence over the bare
    # chain_nodes list, and blank node names are dropped either way.
    chain_items = _normalize_chain_items(
        payload.chain_items,
        [node.strip() for node in payload.chain_nodes if node.strip()],
    )
    if not chain_items:
        raise HTTPException(status_code=400, detail="至少需要一个产业链环节")
    values = {
        "theme_name": normalized_name,
        "aliases_json": json.dumps([item.strip() for item in payload.aliases if item.strip()], ensure_ascii=False),
        "logic_summary": payload.logic.strip(),
        "lifecycle_status": payload.lifecycle_status.strip() or "观察期",
        "stage": payload.stage.strip() or "mid",
        "is_active": bool(payload.is_active),
        "sort_order": int(payload.sort_order or 0),
    }
    async with get_db() as db:
        theme_id = await _theme_knowledge_id(db, theme_name)
        if normalized_name != theme_name:
            target_id = await _theme_knowledge_id(db, normalized_name)
            if theme_id is None:
                # The path name is not stored (e.g. stray whitespace). Edit the
                # row that already uses the target name instead of inserting a
                # duplicate.
                theme_id = target_id
            elif target_id is not None and target_id != theme_id:
                # Renaming onto another theme's name would leave two rows with
                # the same theme_name.
                raise HTTPException(status_code=409, detail="主题名称已存在")
        if theme_id is not None:
            await db.execute(
                text(
                    "UPDATE theme_knowledge SET theme_name = :theme_name, aliases_json = :aliases_json, "
                    "logic_summary = :logic_summary, lifecycle_status = :lifecycle_status, stage = :stage, "
                    "is_active = :is_active, sort_order = :sort_order, updated_at = CURRENT_TIMESTAMP "
                    "WHERE id = :id"
                ),
                {**values, "id": theme_id},
            )
        else:
            await db.execute(tables.theme_knowledge_table.insert().values(**values))
        # Replace chain rows under both the old and the new name so a rename
        # leaves no orphans behind.
        await db.execute(text("DELETE FROM theme_chain_knowledge WHERE theme_name = :theme_name"), {"theme_name": theme_name})
        if normalized_name != theme_name:
            await db.execute(text("DELETE FROM theme_chain_knowledge WHERE theme_name = :theme_name"), {"theme_name": normalized_name})
        await db.execute(
            tables.theme_chain_knowledge_table.insert(),
            [
                {
                    "theme_name": normalized_name,
                    "chain_node": item["chain_node"],
                    "related_stocks": json.dumps(item["related_stocks"], ensure_ascii=False, default=str),
                    "leader_stocks": json.dumps(item["leader_stocks"], ensure_ascii=False, default=str),
                    "node_role": item["node_role"],
                    "is_active": True,
                    "sort_order": index,
                }
                for index, item in enumerate(chain_items)
            ],
        )
        await db.commit()
    library = await load_theme_chain_library()
    for item in library:
        if item.get("theme") == normalized_name:
            return item
    return {"status": "ok", "theme": normalized_name}


async def _theme_knowledge_id(db, name: str):
    """Return the id of the theme_knowledge row named ``name``, or None if absent."""
    result = await db.execute(
        text("SELECT id FROM theme_knowledge WHERE theme_name = :theme_name LIMIT 1"),
        {"theme_name": name},
    )
    row = result.fetchone()
    return row._mapping["id"] if row else None
@router.post("/refresh")
async def refresh_research(_admin: dict = Depends(get_current_admin)):
    """Rebuild and persist the research report from the latest recommendations (admin only).

    Uses the latest scan's session id, or "manual_research" when there is none,
    and reports how many opportunity cards the new report contains.
    """
    from app.engine.recommender import get_latest_recommendations
    from app.research.report_agent import build_research_report_async, save_research_report

    recommendations = await get_latest_recommendations()
    session = (recommendations.get("latest_scan") or {}).get("scan_session") or "manual_research"
    report = await build_research_report_async(recommendations, session)
    await save_research_report(report)
    cards = report.get("opportunity_cards", [])
    return {"status": "ok", "scan_session": session, "opportunity_count": len(cards)}
def _safe_json(value: str | None) -> dict:
if not value:
return {}
try:
parsed = json.loads(value)
return parsed if isinstance(parsed, dict) else {}
except Exception:
return {}
def _normalize_chain_items(chain_items: list[dict], chain_nodes: list[str]) -> list[dict]:
if not chain_items:
return [
{"chain_node": node, "related_stocks": [], "leader_stocks": [], "node_role": ""}
for node in chain_nodes
]
normalized = []
for item in chain_items:
node = str(item.get("chain_node") or "").strip()
if not node:
continue
normalized.append({
"chain_node": node,
"related_stocks": _stock_list(item.get("related_stocks")),
"leader_stocks": _stock_list(item.get("leader_stocks")),
"node_role": str(item.get("node_role") or "").strip(),
})
return normalized
def _stock_list(value) -> list:
if not isinstance(value, list):
return []
cleaned = []
for item in value:
if isinstance(item, dict):
name = str(item.get("name") or "").strip()
ts_code = str(item.get("ts_code") or item.get("code") or "").strip()
if name or ts_code:
cleaned.append({"name": name, "ts_code": ts_code})
else:
text_value = str(item).strip()
if text_value:
cleaned.append(text_value)
return cleaned