265 lines
9.9 KiB
Python
265 lines
9.9 KiB
Python
"""Research report APIs."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from datetime import datetime, timedelta
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException
|
|
from pydantic import BaseModel, Field
|
|
from sqlalchemy import text
|
|
|
|
from app.core.deps import get_current_admin
|
|
from app.db import tables
|
|
from app.db.database import get_db
|
|
from app.research.industry_chain_agent import ensure_theme_knowledge_seeded, load_theme_chain_library
|
|
from app.research.report_agent import load_latest_research_report
|
|
|
|
# Router for the research-report endpoints; every route declared below is served under /api/research.
router = APIRouter(tags=["research"], prefix="/api/research")
|
|
|
|
|
|
class ThemeKnowledgeUpdate(BaseModel):
    """Request body for creating/updating one theme in the theme knowledge library."""

    # Theme name; at least one character (whitespace-only is trimmed by the endpoint).
    theme_name: str = Field(min_length=1)
    # Alternative names for the theme.
    aliases: list[str] = Field(default_factory=list)
    # Free-text description of the theme's logic.
    logic: str = Field(default="")
    # Lifecycle label; defaults to "观察期" ("observation period").
    lifecycle_status: str = Field(default="观察期")
    # Stage label; defaults to "mid".
    stage: str = Field(default="mid")
    # Plain industry-chain node names (used when chain_items is empty).
    chain_nodes: list[str] = Field(default_factory=list)
    # Rich per-node dicts (chain_node, related_stocks, leader_stocks, node_role).
    chain_items: list[dict] = Field(default_factory=list)
    # Whether the theme is active.
    is_active: bool = Field(default=True)
    # Ordering key for the theme.
    sort_order: int = Field(default=0)
|
|
|
|
|
|
@router.get("/today")
async def get_today_research():
    """Return today's research report, falling back step by step.

    1. The latest persisted report from ``load_latest_research_report``.
    2. Otherwise, a report built on the fly from the latest recommendations,
       when they contain a ``latest_scan``.
    3. Otherwise, an empty placeholder payload explaining that no scan has
       completed yet.
    """
    stored = await load_latest_research_report()
    if stored:
        return stored

    # Deferred imports: only needed when no stored report exists.
    from app.engine.recommender import get_latest_recommendations
    from app.research.report_agent import build_research_report_async

    recommendations = await get_latest_recommendations()
    scan = recommendations.get("latest_scan") or {}
    if not scan:
        return {
            "trade_date": datetime.now().strftime("%Y%m%d"),
            "scan_session": "",
            "scanned_at": "",
            "market_view": {"regime": "unknown", "confidence": 0, "summary": "暂无研究报告。"},
            "theme_views": [],
            "industry_chain_map": [],
            "opportunity_cards": [],
            "risk_alerts": [],
            "no_trade_reason": {"has_scan": False, "reason": "暂无完成扫描。"},
        }

    session = scan.get("scan_session") or "latest"
    return await build_research_report_async(recommendations, session)
|
|
|
|
|
|
@router.get("/history")
async def get_research_history(days: int = 14):
    """Return summaries of research reports created in the last ``days`` days.

    Args:
        days: Look-back window in days (default 14). A negative window returns
            an empty list. Very large windows are capped at 100 years so the
            date arithmetic cannot overflow.

    Returns:
        Up to 60 summaries, newest first (``created_at`` then ``id``
        descending). Each has ``scan_session``, ``trade_date``,
        ``market_summary``/``theme_summary`` (``""`` when NULL),
        ``no_trade_reason`` (decoded dict, ``{}`` when absent/invalid) and
        ``created_at`` (stringified, ``""`` when NULL).
    """
    if days < 0:
        # The window would start after "now", so no already-created report can match.
        return []
    # timedelta/datetime raise OverflowError for windows of hundreds of thousands
    # of days, which would surface as a 500; 100 years covers any real history.
    days = min(days, 36500)
    start = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%d")

    async with get_db() as db:
        result = await db.execute(
            text(
                # report_json is intentionally not selected: this endpoint never uses it.
                "SELECT scan_session, trade_date, market_summary, theme_summary, no_trade_reason, created_at "
                "FROM research_reports WHERE created_at >= :start "
                "ORDER BY created_at DESC, id DESC LIMIT 60"
            ),
            {"start": start},
        )
        records = [row._mapping for row in result.fetchall()]

    return [
        {
            "scan_session": record["scan_session"],
            "trade_date": record["trade_date"],
            "market_summary": record["market_summary"] or "",
            "theme_summary": record["theme_summary"] or "",
            "no_trade_reason": _safe_json(record["no_trade_reason"]),
            "created_at": str(record["created_at"] or ""),
        }
        for record in records
    ]
|
|
|
|
|
|
@router.get("/themes")
async def get_research_themes():
    """List ``theme_maps`` rows for the scan session of the newest research report.

    Rows come back as plain dicts, hottest first (``heat_score`` descending).
    The list is empty when no research report exists, because the subquery
    then yields NULL.
    """
    query = text(
        "SELECT * FROM theme_maps "
        "WHERE scan_session = (SELECT scan_session FROM research_reports ORDER BY created_at DESC, id DESC LIMIT 1) "
        "ORDER BY heat_score DESC"
    )
    async with get_db() as db:
        fetched = (await db.execute(query)).fetchall()
    return [dict(record._mapping) for record in fetched]
|
|
|
|
|
|
@router.get("/opportunities")
async def get_research_opportunities():
    """List ``opportunity_cards`` rows for the scan session of the newest research report.

    Rows come back as plain dicts, best first (``score`` descending). The list
    is empty when no research report exists.
    """
    query = text(
        "SELECT * FROM opportunity_cards "
        "WHERE scan_session = (SELECT scan_session FROM research_reports ORDER BY created_at DESC, id DESC LIMIT 1) "
        "ORDER BY score DESC"
    )
    async with get_db() as db:
        fetched = (await db.execute(query)).fetchall()
    return [dict(record._mapping) for record in fetched]
|
|
|
|
|
|
@router.get("/risks")
async def get_research_risks():
    """List ``risk_events`` rows for the scan session of the newest research report.

    Rows come back as plain dicts, ordered by ``reject``, then ``severity``,
    then ``id`` (all descending). The list is empty when no research report
    exists.
    """
    query = text(
        "SELECT * FROM risk_events "
        "WHERE scan_session = (SELECT scan_session FROM research_reports ORDER BY created_at DESC, id DESC LIMIT 1) "
        "ORDER BY reject DESC, severity DESC, id DESC"
    )
    async with get_db() as db:
        fetched = (await db.execute(query)).fetchall()
    return [dict(record._mapping) for record in fetched]
|
|
|
|
|
|
@router.get("/review")
async def get_research_review(days: int = 60):
    """Return the research review covering the last ``days`` days (default 60).

    Delegates to ``app.research.review_agent.build_research_review``, which is
    imported lazily inside the handler.
    """
    from app.research.review_agent import build_research_review as _build_review

    review = await _build_review(days=days)
    return review
|
|
|
|
|
|
@router.get("/theme-knowledge")
async def get_theme_knowledge():
    """Return the theme / industry-chain knowledge library from ``load_theme_chain_library``."""
    library = await load_theme_chain_library()
    return library
|
|
|
|
|
|
@router.put("/theme-knowledge/{theme_name}")
async def update_theme_knowledge(
    theme_name: str,
    payload: ThemeKnowledgeUpdate,
    _admin: dict = Depends(get_current_admin),
):
    """Create, update or rename one theme in the knowledge library (admin only).

    The ``theme_knowledge`` row is looked up by the path ``theme_name``. If it
    exists it is updated, including a rename to ``payload.theme_name``;
    otherwise a new row is inserted. The theme's ``theme_chain_knowledge`` rows
    are replaced wholesale with the nodes from the payload.

    Args:
        theme_name: Current theme name from the URL path.
        payload: New definition. Non-empty ``chain_items`` take precedence over
            ``chain_nodes``.
        _admin: Injected by ``get_current_admin``; only used to enforce access.

    Returns:
        The refreshed library entry whose ``theme`` equals the saved name.
        Falls back to ``{"status": "ok", "theme": <name>}`` when the library
        has no such entry.

    Raises:
        HTTPException(400): The resolved name is blank, or no non-blank chain
            node was supplied.
        HTTPException(409): The new name already belongs to a different theme
            row.
    """
    await ensure_theme_knowledge_seeded()

    # The body's name wins; the path name is only a fallback when the body's is blank.
    normalized_name = payload.theme_name.strip() or theme_name.strip()
    if not normalized_name:
        raise HTTPException(status_code=400, detail="主题名称不能为空")

    if payload.chain_items:
        chain_nodes = [
            node
            for item in payload.chain_items
            if (node := str(item.get("chain_node") or "").strip())
        ]
    else:
        chain_nodes = [node.strip() for node in payload.chain_nodes if node.strip()]
    if not chain_nodes:
        raise HTTPException(status_code=400, detail="至少需要一个产业链环节")

    chain_rows = _normalize_chain_items(payload.chain_items, chain_nodes)
    values = {
        "theme_name": normalized_name,
        "aliases_json": json.dumps([alias.strip() for alias in payload.aliases if alias.strip()], ensure_ascii=False),
        "logic_summary": payload.logic.strip(),
        "lifecycle_status": payload.lifecycle_status.strip() or "观察期",
        "stage": payload.stage.strip() or "mid",
        "is_active": bool(payload.is_active),
        "sort_order": int(payload.sort_order or 0),
    }

    async with get_db() as db:
        existing = await db.execute(
            text("SELECT id FROM theme_knowledge WHERE theme_name = :theme_name LIMIT 1"),
            {"theme_name": theme_name},
        )
        row = existing.fetchone()
        row_id = row._mapping["id"] if row else None

        if normalized_name != theme_name:
            # Saving under a name owned by a different row would duplicate that
            # theme (or trip a unique constraint, if one exists). The chain-node
            # cleanup below would also wipe the other theme's nodes. Reject
            # before any write.
            clash = await db.execute(
                text("SELECT id FROM theme_knowledge WHERE theme_name = :theme_name LIMIT 1"),
                {"theme_name": normalized_name},
            )
            clash_row = clash.fetchone()
            if clash_row and clash_row._mapping["id"] != row_id:
                raise HTTPException(status_code=409, detail="主题名称已存在")

        if row:
            await db.execute(
                text(
                    "UPDATE theme_knowledge SET theme_name = :theme_name, aliases_json = :aliases_json, "
                    "logic_summary = :logic_summary, lifecycle_status = :lifecycle_status, stage = :stage, "
                    "is_active = :is_active, sort_order = :sort_order, updated_at = CURRENT_TIMESTAMP "
                    "WHERE id = :id"
                ),
                {**values, "id": row_id},
            )
        else:
            await db.execute(tables.theme_knowledge_table.insert().values(**values))

        # Replace the chain nodes: drop those stored under the old name and, on
        # rename, any stale ones already stored under the new name.
        await db.execute(text("DELETE FROM theme_chain_knowledge WHERE theme_name = :theme_name"), {"theme_name": theme_name})
        if normalized_name != theme_name:
            await db.execute(text("DELETE FROM theme_chain_knowledge WHERE theme_name = :theme_name"), {"theme_name": normalized_name})
        await db.execute(
            tables.theme_chain_knowledge_table.insert(),
            [
                {
                    "theme_name": normalized_name,
                    "chain_node": item["chain_node"],
                    "related_stocks": json.dumps(item.get("related_stocks", []), ensure_ascii=False, default=str),
                    "leader_stocks": json.dumps(item.get("leader_stocks", []), ensure_ascii=False, default=str),
                    "node_role": item.get("node_role", ""),
                    "is_active": True,
                    "sort_order": index,
                }
                for index, item in enumerate(chain_rows)
            ],
        )
        await db.commit()

    library = await load_theme_chain_library()
    return next(
        (entry for entry in library if entry.get("theme") == normalized_name),
        {"status": "ok", "theme": normalized_name},
    )
|
|
|
|
|
|
@router.post("/refresh")
async def refresh_research(_admin: dict = Depends(get_current_admin)):
    """Rebuild and persist a research report from the latest recommendations (admin only).

    The report is tagged with the latest scan's session id. When the
    recommendations carry no scan session, ``"manual_research"`` is used.

    Returns:
        ``status``, the ``scan_session`` used, and ``opportunity_count``
        (the number of opportunity cards in the rebuilt report).
    """
    from app.engine.recommender import get_latest_recommendations
    from app.research.report_agent import build_research_report_async, save_research_report

    recommendations = await get_latest_recommendations()
    scan = recommendations.get("latest_scan") or {}
    session_id = scan.get("scan_session") or "manual_research"

    report = await build_research_report_async(recommendations, session_id)
    await save_research_report(report)

    card_count = len(report.get("opportunity_cards", []))
    return {"status": "ok", "scan_session": session_id, "opportunity_count": card_count}
|
|
|
|
|
|
def _safe_json(value: str | None) -> dict:
|
|
if not value:
|
|
return {}
|
|
try:
|
|
parsed = json.loads(value)
|
|
return parsed if isinstance(parsed, dict) else {}
|
|
except Exception:
|
|
return {}
|
|
|
|
|
|
def _normalize_chain_items(chain_items: list[dict], chain_nodes: list[str]) -> list[dict]:
|
|
if not chain_items:
|
|
return [
|
|
{"chain_node": node, "related_stocks": [], "leader_stocks": [], "node_role": ""}
|
|
for node in chain_nodes
|
|
]
|
|
normalized = []
|
|
for item in chain_items:
|
|
node = str(item.get("chain_node") or "").strip()
|
|
if not node:
|
|
continue
|
|
normalized.append({
|
|
"chain_node": node,
|
|
"related_stocks": _stock_list(item.get("related_stocks")),
|
|
"leader_stocks": _stock_list(item.get("leader_stocks")),
|
|
"node_role": str(item.get("node_role") or "").strip(),
|
|
})
|
|
return normalized
|
|
|
|
|
|
def _stock_list(value) -> list:
|
|
if not isinstance(value, list):
|
|
return []
|
|
cleaned = []
|
|
for item in value:
|
|
if isinstance(item, dict):
|
|
name = str(item.get("name") or "").strip()
|
|
ts_code = str(item.get("ts_code") or item.get("code") or "").strip()
|
|
if name or ts_code:
|
|
cleaned.append({"name": name, "ts_code": ts_code})
|
|
else:
|
|
text_value = str(item).strip()
|
|
if text_value:
|
|
cleaned.append(text_value)
|
|
return cleaned
|