astock-agent/backend/app/llm/batch_screener.py
2026-04-30 20:28:19 +08:00

379 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""LLM 候选预筛 + 逐股深度分析
先做轻量预筛,控制深度裁决成本;
再对重点股票单独调用 LLM 做深度分析。
"""
import asyncio
import json
import logging
import re
from app.config import settings
from app.db.error_logger import log_error
# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
async def prefilter_single_stock(candidate: dict, market_summary: str) -> dict:
    """Run a lightweight LLM pre-screen on one candidate stock.

    The pre-screen only decides how much deep-analysis budget the stock gets;
    it does not produce a final trading verdict.

    Args:
        candidate: Candidate record. ``name`` and ``ts_code`` are required;
            every other key read below is optional and falls back to a
            placeholder.
        market_summary: Market-environment summary embedded in the prompt.

    Returns:
        ``{"decision", "confidence", "reason", "focus_points"}`` as produced by
        ``_parse_prefilter_response``. If the LLM call fails, the error is
        logged and a conservative "watch" result is returned instead.
    """
    from app.llm.prompts import STOCK_PREFILTER_PROMPT
    from app.llm.strategy_config import get_prompt_content
    from app.llm.client import get_client

    # `or` also covers keys that are present but None, which would otherwise
    # crash the string concatenation / join below (outside the try block).
    theme_name = candidate.get("hot_theme_name") or ""
    theme_aliases = [
        str(alias) for alias in (candidate.get("hot_theme_aliases") or []) if alias
    ]
    # Spell out the negative case ("否" / "无") instead of leaving the field
    # blank, so the model is told explicitly that there is no theme match.
    theme_match = f"是,匹配 {theme_name}" if candidate.get("hot_theme_matched") else "否"
    stock_text = f"""\
股票: {candidate['name']}({candidate['ts_code']})
主题: {candidate.get('sector', '未知')}
主线主题匹配: {theme_match}
主题别名: {", ".join(theme_aliases) or "无"}
召回来源: {', '.join(candidate.get('recall_tags', []) or ['未标注'])}
规则参考分: {candidate.get('quant_score', 0)}/100
资金顺势分: {candidate.get('flow_momentum_score', 0)}/100
位置安全: {candidate.get('position_score', 50)}/100
当前价: {candidate.get('current_price', '未知')}
主题阶段: {candidate.get('sector_stage', '未知')}
个股角色线索: {candidate.get('stock_role_hint', '待判断')}"""
    if candidate.get("kline_summary"):
        stock_text += f"\n\n## 技术结构摘要\n{candidate['kline_summary']}"
    if candidate.get("capital_flow_summary"):
        stock_text += f"\n\n## 资金与活跃度摘要\n{candidate['capital_flow_summary']}"
    if candidate.get("intraday_volume"):
        stock_text += f"\n\n## 分时量能摘要\n{candidate['intraday_volume']}"

    prompt = await get_prompt_content("stock_prefilter", STOCK_PREFILTER_PROMPT)
    user_msg = f"{prompt}\n\n## 市场环境\n{market_summary}\n\n{stock_text}\n\n请输出 JSON。"
    try:
        client = get_client()
        response = await client.chat.completions.create(
            model=settings.deepseek_model,
            messages=[
                {
                    "role": "system",
                    "content": (
                        "你是A股候选池预审官。"
                        "你只负责决定资源分配优先级,不直接下最终交易结论。"
                        "必须返回合法JSON。"
                    ),
                },
                {"role": "user", "content": user_msg},
            ],
            max_tokens=400,
            temperature=0.2,
        )
        # message.content may be None (e.g. an empty/filtered completion);
        # treat it as unparseable output rather than as a call failure.
        content = (response.choices[0].message.content or "").strip()
        return _parse_prefilter_response(content)
    except Exception as e:
        logger.error(f"LLM 预筛 {candidate.get('ts_code')} 失败: {e}")
        await log_error(
            "llm_prefilter",
            f"LLM 预筛 {candidate.get('ts_code')} 失败: {e}",
            detail=f"candidate={candidate.get('ts_code')}|{candidate.get('name', '')}",
        )
        # Conservative fallback: keep the stock under observation.
        return {
            "decision": "watch",
            "confidence": 5,
            "reason": "AI 预筛暂不可用,保留观察",
            "focus_points": [],
        }
async def analyze_single_stock(candidate: dict, market_summary: str) -> dict:
    """Run the per-stock LLM deep analysis (final verdict) for one candidate.

    Args:
        candidate: Candidate record. ``name`` and ``ts_code`` are required;
            sector, hot-theme match fields, quant_score, flow_momentum_score,
            position_score, current_price, kline_summary and
            capital_flow_summary are optional.
        market_summary: Market-environment summary embedded in the prompt.

    Returns:
        {
            "verdict": "execute" / "watch" / "skip",
            "action_plan": "可操作" / "重点关注" / "观察",
            "conviction": int,
            "timing": str,
            "entry_price": float or None,
            "target_price": float or None,
            "stop_loss": float or None,
            "trigger_condition": str,
            "invalidation_condition": str,
            "position_pct": int,
            "risk_flag": str,
            "analysis": str,
        }
        If the LLM call fails, the error is logged and a "watch" / "重点关注"
        placeholder with conviction 4 is returned instead of raising.
    """
    from app.llm.prompts import SINGLE_STOCK_ANALYSIS_PROMPT
    from app.llm.strategy_config import get_prompt_content
    from app.llm.client import get_client

    # Build the prompt. signal_type is deliberately not passed so the LLM
    # judges independently.
    # `or ""` also covers a present-but-None theme name.
    theme_name = candidate.get("hot_theme_name") or ""
    # State "否" explicitly: the system prompt below hinges on whether the
    # stock matches today's main theme, so a blank value is ambiguous.
    theme_match = f"是,匹配 {theme_name}" if candidate.get("hot_theme_matched") else "否"
    stock_text = f"""\
股票: {candidate['name']}({candidate['ts_code']})
主题: {candidate.get('sector', '未知')}
主线主题匹配: {theme_match}
规则参考分: {candidate.get('quant_score', 0)}/100
资金顺势分: {candidate.get('flow_momentum_score', 0)}/100
位置安全: {candidate.get('position_score', 50)}/100
当前价: {candidate.get('current_price', '未知')}"""
    if candidate.get("kline_summary"):
        stock_text += f"\n\n## 技术分析结论\n{candidate['kline_summary']}"
    if candidate.get("capital_flow_summary"):
        stock_text += f"\n\n## 资金流向\n{candidate['capital_flow_summary']}"

    prompt = await get_prompt_content("single_stock_analysis", SINGLE_STOCK_ANALYSIS_PROMPT)
    user_msg = f"{prompt}\n\n## 市场环境\n{market_summary}\n\n{stock_text}\n\n请给出你的分析。"
    try:
        client = get_client()
        response = await client.chat.completions.create(
            model=settings.deepseek_model,
            messages=[
                {
                    "role": "system",
                    "content": (
                        "你是一位A股短线交易裁决员。"
                        "你的任务是决定这只股票今天是否该进入推荐前列,以及应该归入可操作、重点关注还是观察。"
                        "若候选股没有匹配今日主线主题,除非它是全市场极强异动且触发条件清晰,否则不要给可操作或重点关注。"
                        "不要复述数据,不要写成长文,不要被规则参考分绑架。"
                        "必须返回合法JSON。"
                    ),
                },
                {"role": "user", "content": user_msg},
            ],
            max_tokens=800,
            temperature=0.3,
        )
        # message.content may be None; parse it as empty output instead of
        # letting .strip() raise and masquerade as an API failure.
        content = (response.choices[0].message.content or "").strip()
        return _parse_single_response(content)
    except Exception as e:
        logger.error(f"LLM 分析 {candidate.get('ts_code')} 失败: {e}")
        await log_error(
            "llm_final",
            f"LLM 分析 {candidate.get('ts_code')} 失败: {e}",
            detail=f"candidate={candidate.get('ts_code')}|{candidate.get('name', '')}",
        )
        return {
            "verdict": "watch",
            "action_plan": "重点关注",
            "conviction": 4,
            "timing": "",
            "entry_price": None,
            "target_price": None,
            "stop_loss": None,
            "trigger_condition": "",
            "invalidation_condition": "",
            "position_pct": 0,
            "risk_flag": "AI 裁决暂不可用",
            "analysis": "AI分析暂不可用",
        }
def _parse_prefilter_response(text: str) -> dict:
    """Normalise the raw pre-screen LLM reply into a fixed-shape dict.

    Unparseable output degrades to a neutral "watch" decision; malformed or
    out-of-range fields are coerced to safe defaults instead of rejected.

    Returns:
        ``{"decision": "priority"|"watch"|"ignore", "confidence": 1-10,
        "reason": str, "focus_points": list[str] (at most 3)}``
    """

    def _text(value) -> str:
        # JSON null must not leak through as the literal string "None".
        return "" if value is None else str(value).strip()

    data = _extract_json_object(text)
    if not data:
        return {
            "decision": "watch",
            "confidence": 5,
            "reason": "预筛输出不可解析,默认保留观察",
            "focus_points": [],
        }

    decision = _text(data.get("decision")).lower()
    if decision not in {"priority", "watch", "ignore"}:
        decision = "watch"

    raw_points = data.get("focus_points") or []
    if not isinstance(raw_points, list):
        raw_points = []
    # Drop blank entries first, then cap at three, so empty items cannot
    # crowd out real ones.
    focus_points = [point for point in (_text(item) for item in raw_points) if point][:3]

    return {
        "decision": decision,
        "confidence": _clamp_int(data.get("confidence"), minimum=1, maximum=10, default=5),
        "reason": _text(data.get("reason")) or "暂无说明",
        "focus_points": focus_points,
    }
# Action plan implied by each verdict when the model gives none (or an invalid one).
_VERDICT_TO_PLAN = {"execute": "可操作", "watch": "重点关注", "skip": "观察"}
# Conviction implied by the legacy free-text "信号强度" (signal strength) field.
_LEGACY_STRENGTH_CONVICTION = {"强": 8, "中": 6, "弱": 4}
# Separator allowed after a legacy field label: ASCII colon, full-width colon, whitespace.
_LEGACY_SEP = r"[::\s]*"


def _parse_single_response(text: str) -> dict:
    """Parse the deep-analysis LLM reply into the verdict dict.

    A JSON object in the reply is preferred. If none can be extracted, the
    reply is treated as the older free-text format and scraped with regexes.
    Both paths return the same keys as ``analyze_single_stock``.
    """
    data = _extract_json_object(text)
    if data:
        return _parse_structured_single(data)
    # Compatibility with the old free-text format.
    return _parse_legacy_single(text)


def _clean_text(value) -> str:
    """Return *value* as a stripped string, mapping JSON null to "" (not "None")."""
    return "" if value is None else str(value).strip()


def _parse_structured_single(data: dict) -> dict:
    """Normalise a JSON deep-analysis reply, coercing bad fields to safe defaults."""
    verdict = _clean_text(data.get("verdict")).lower()
    if verdict not in _VERDICT_TO_PLAN:
        verdict = "watch"
    action_plan = _clean_text(data.get("action_plan"))
    if action_plan not in _VERDICT_TO_PLAN.values():
        action_plan = _VERDICT_TO_PLAN[verdict]
    return {
        "verdict": verdict,
        "action_plan": action_plan,
        "conviction": _clamp_int(data.get("conviction"), minimum=1, maximum=10, default=6),
        "timing": _clean_text(data.get("timing")),
        "entry_price": _parse_float(data.get("entry_price")),
        "target_price": _parse_float(data.get("target_price")),
        "stop_loss": _parse_float(data.get("stop_loss")),
        "trigger_condition": _clean_text(data.get("trigger_condition")),
        "invalidation_condition": _clean_text(data.get("invalidation_condition")),
        "position_pct": _clamp_int(data.get("position_pct"), minimum=0, maximum=35, default=0),
        "analysis": _clean_text(data.get("analysis")) or "暂无分析",
        "risk_flag": _clean_text(data.get("risk_flag")),
    }


def _parse_legacy_single(text: str) -> dict:
    """Best-effort regex scrape of the pre-JSON free-text reply format."""

    def _find(label: str, value_pattern: str, flags: int = 0) -> str | None:
        # Capture group 1 of `<label><sep><value_pattern>`, or None if absent.
        match = re.search(f"{label}{_LEGACY_SEP}{value_pattern}", text, flags)
        return match.group(1) if match else None

    def _price(label: str) -> float | None:
        raw = _find(label, r"(\d+(?:\.\d+)?)")
        return float(raw) if raw is not None else None

    signal = (_find("信号", r"(BUY|HOLD|SKIP)", re.IGNORECASE) or "HOLD").upper()
    verdict = {"BUY": "execute", "SKIP": "skip"}.get(signal, "watch")
    # Keys must be the exact characters the regex captures; unknown/missing -> 6.
    conviction = _LEGACY_STRENGTH_CONVICTION.get(_find("信号强度", r"(强|中|弱)"), 6)
    analysis = (_find("分析", r"(.+)", re.DOTALL) or "").strip()

    return {
        "verdict": verdict,
        "action_plan": _VERDICT_TO_PLAN[verdict],
        "conviction": conviction,
        "timing": "",
        "entry_price": _price("买入价"),
        "target_price": _price("止盈价"),
        "stop_loss": _price("止损价"),
        "trigger_condition": "",
        "invalidation_condition": "",
        "position_pct": 20 if verdict == "execute" else 0,
        "analysis": analysis or "暂无分析",
        "risk_flag": "",
    }
def _extract_json_object(text: str) -> dict | None:
match = re.search(r"\{[\s\S]*\}", text)
if not match:
return None
try:
parsed = json.loads(match.group(0))
return parsed if isinstance(parsed, dict) else None
except Exception:
return None
def _parse_float(value) -> float | None:
try:
if value in (None, "", 0, "0"):
return None
return float(value)
except Exception:
return None
def _clamp_int(value, minimum: int, maximum: int, default: int) -> int:
try:
parsed = int(round(float(value)))
except Exception:
return default
return max(minimum, min(maximum, parsed))
async def analyze_candidates_individually(
    candidates: list[dict], market_summary: str, max_concurrent: int = 3
) -> dict[str, dict]:
    """Deep-analyse each candidate with the LLM, at most ``max_concurrent`` at a time.

    Args:
        candidates: Candidate dicts; each must carry ``ts_code``.
        market_summary: Market-environment summary shared by every prompt.
        max_concurrent: Upper bound on simultaneous LLM calls.

    Returns:
        ``{ts_code: analyze_single_stock(...) result}``, i.e. dicts with
        "verdict", "action_plan", "conviction", "entry_price", ... Candidates
        whose task raised or was cancelled are logged and omitted. Returns
        ``{}`` when no DeepSeek API key is configured or ``candidates`` is empty.
    """
    if not settings.deepseek_api_key or not candidates:
        return {}

    semaphore = asyncio.Semaphore(max_concurrent)

    async def _analyze_with_semaphore(c: dict):
        async with semaphore:
            ts_code = c["ts_code"]
            logger.info(f"LLM 分析: {c.get('name', ts_code)}")
            result = await analyze_single_stock(c, market_summary)
            logger.info(
                f"LLM 结果: {c.get('name', ts_code)} | "
                f"裁决={result['verdict']} 计划={result['action_plan']} 置信度={result['conviction']} "
                f"买入={result.get('entry_price')} 止盈={result.get('target_price')} "
                f"止损={result.get('stop_loss')}"
            )
            return ts_code, result

    completed = await asyncio.gather(
        *(_analyze_with_semaphore(c) for c in candidates), return_exceptions=True
    )

    results: dict[str, dict] = {}
    # gather() preserves input order, so zip recovers which stock each outcome belongs to.
    for c, item in zip(candidates, completed):
        # A cancelled task comes back as CancelledError, a BaseException rather
        # than an Exception; checking only Exception would crash the unpack below.
        if isinstance(item, BaseException):
            logger.error(f"LLM 分析任务异常 {c.get('ts_code')}: {item!r}")
            continue
        ts_code, result = item
        results[ts_code] = result

    logger.info(f"LLM 逐股分析完成: {len(results)}/{len(candidates)}")
    return results
async def prefilter_candidates_individually(
    candidates: list[dict], market_summary: str, max_concurrent: int = 6
) -> dict[str, dict]:
    """Pre-screen each candidate with the LLM, at most ``max_concurrent`` at a time.

    Args:
        candidates: Candidate dicts; each must carry ``ts_code``.
        market_summary: Market-environment summary shared by every prompt.
        max_concurrent: Upper bound on simultaneous LLM calls.

    Returns:
        ``{ts_code: prefilter_single_stock(...) result}``. Candidates whose task
        raised or was cancelled are logged and omitted. Returns ``{}`` when no
        DeepSeek API key is configured or ``candidates`` is empty.
    """
    if not settings.deepseek_api_key or not candidates:
        return {}

    semaphore = asyncio.Semaphore(max_concurrent)

    async def _prefilter_with_semaphore(c: dict):
        async with semaphore:
            ts_code = c["ts_code"]
            logger.info(f"LLM 预筛: {c.get('name', ts_code)}")
            result = await prefilter_single_stock(c, market_summary)
            logger.info(
                f"LLM 预筛结果: {c.get('name', ts_code)} | "
                f"decision={result['decision']} confidence={result['confidence']}"
            )
            return ts_code, result

    completed = await asyncio.gather(
        *(_prefilter_with_semaphore(c) for c in candidates), return_exceptions=True
    )

    results: dict[str, dict] = {}
    # gather() preserves input order, so zip recovers which stock each outcome belongs to.
    for c, item in zip(candidates, completed):
        # CancelledError is a BaseException (not an Exception) and is returned
        # by gather(return_exceptions=True); treat it like any other failure.
        if isinstance(item, BaseException):
            logger.error(f"LLM 预筛任务异常 {c.get('ts_code')}: {item!r}")
            continue
        ts_code, result = item
        results[ts_code] = result

    logger.info(f"LLM 预筛完成: {len(results)}/{len(candidates)}")
    return results