astock-agent/backend/app/llm/batch_screener.py
2026-04-22 11:56:23 +08:00

241 lines
8.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""LLM 逐股深度分析
量化筛选完成后,对每只候选股票单独调用 LLM 做深度分析,
让 AI 独立判断入场时机并给出具体买卖价格。
"""
import asyncio
import json
import logging
import re
from app.config import settings
logger = logging.getLogger(__name__)
async def analyze_single_stock(candidate: dict, market_summary: str) -> dict:
    """Run one LLM deep-dive on a single candidate stock.

    Args:
        candidate: Candidate record. ``ts_code`` is required. ``name``,
            ``sector``, ``quant_score``, ``position_score``,
            ``current_price``, ``kline_summary`` and
            ``capital_flow_summary`` are optional and fall back to defaults.
        market_summary: Market-environment summary inserted into the prompt.

    Returns:
        A dict with the keys ``verdict`` ("execute"/"watch"/"skip"),
        ``action_plan`` ("可操作"/"重点关注"/"观察"), ``conviction``,
        ``timing``, ``entry_price``, ``target_price``, ``stop_loss``,
        ``trigger_condition``, ``invalidation_condition``, ``position_pct``,
        ``risk_flag`` and ``analysis``. When the request or the parsing
        fails, a neutral "watch" placeholder with no prices is returned
        instead of raising.

    Raises:
        KeyError: If ``candidate`` has no ``ts_code``.
    """
    from app.llm.prompts import SINGLE_STOCK_ANALYSIS_PROMPT
    from app.llm.client import get_client

    ts_code = candidate["ts_code"]
    # The batch caller in this module already treats the name as optional
    # (it logs c.get('name', ts_code)), so a missing name must not abort the
    # analysis before the request is even attempted.
    name = candidate.get("name") or ts_code

    # Build the prompt. signal_type is deliberately not passed so that the
    # LLM forms its own judgement.
    stock_text = "\n".join(
        [
            f"股票: {name}({ts_code})",
            f"板块: {candidate.get('sector', '未知')}",
            f"规则参考分: {candidate.get('quant_score', 0)}/100",
            f"位置安全: {candidate.get('position_score', 50)}/100",
            f"当前价: {candidate.get('current_price', '未知')}",
        ]
    )
    if candidate.get("kline_summary"):
        stock_text += f"\n\n## 技术分析结论\n{candidate['kline_summary']}"
    if candidate.get("capital_flow_summary"):
        stock_text += f"\n\n## 资金流向\n{candidate['capital_flow_summary']}"

    user_msg = (
        f"{SINGLE_STOCK_ANALYSIS_PROMPT}\n\n## 市场环境\n{market_summary}"
        f"\n\n{stock_text}\n\n请给出你的分析。"
    )

    try:
        client = get_client()
        response = await client.chat.completions.create(
            model=settings.deepseek_model,
            messages=[
                {
                    "role": "system",
                    "content": (
                        "你是一位A股短线交易裁决员。"
                        "你的任务是决定这只股票今天是否该进入推荐前列,以及应该归入可操作、重点关注还是观察。"
                        "不要复述数据,不要写成长文,不要被规则参考分绑架。"
                        "必须返回合法JSON。"
                    ),
                },
                {"role": "user", "content": user_msg},
            ],
            max_tokens=800,
            temperature=0.3,
        )
        # A None content raises AttributeError on .strip() and therefore ends
        # up in the fallback below, which flags the verdict as unavailable.
        content = response.choices[0].message.content.strip()
        return _parse_single_response(content)
    except Exception as e:
        # Best-effort boundary: one failed stock must not break the batch.
        logger.error(f"LLM 分析 {ts_code} 失败: {e}")
        return {
            "verdict": "watch",
            "action_plan": "重点关注",
            "conviction": 4,
            "timing": "",
            "entry_price": None,
            "target_price": None,
            "stop_loss": None,
            "trigger_condition": "",
            "invalidation_condition": "",
            "position_pct": 0,
            "risk_flag": "AI 裁决暂不可用",
            "analysis": "AI分析暂不可用",
        }
def _parse_single_response(text: str) -> dict:
    """Parse the LLM reply for one stock into the normalised result dict.

    The reply is first treated as JSON. If no non-empty JSON object can be
    extracted, the text is scanned with regexes for the legacy line-based
    format (信号 / 信号强度 / 买入价 / 止盈价 / 止损价 / 分析).

    Args:
        text: Raw reply text from the LLM.

    Returns:
        A dict with the keys ``verdict``, ``action_plan``, ``conviction``,
        ``timing``, ``entry_price``, ``target_price``, ``stop_loss``,
        ``trigger_condition``, ``invalidation_condition``, ``position_pct``,
        ``analysis`` and ``risk_flag``. Unknown verdicts become "watch";
        ``conviction`` is clamped to 1..10 and ``position_pct`` to 0..35.
    """
    plan_by_verdict = {"execute": "可操作", "watch": "重点关注", "skip": "观察"}

    data = _extract_json_object(text)
    if data:

        def _text(key: str) -> str:
            # A JSON null must become "" rather than the literal text "None".
            raw = data.get(key)
            return "" if raw is None else str(raw).strip()

        verdict = _text("verdict").lower()
        if verdict not in plan_by_verdict:
            verdict = "watch"
        action_plan = _text("action_plan")
        if action_plan not in {"可操作", "重点关注", "观察"}:
            # Derive the plan from the verdict when the model's label is unusable.
            action_plan = plan_by_verdict[verdict]
        conviction = _clamp_int(data.get("conviction"), minimum=1, maximum=10, default=6)
        position_pct = _clamp_int(data.get("position_pct"), minimum=0, maximum=35, default=0)
        return {
            "verdict": verdict,
            "action_plan": action_plan,
            "conviction": conviction,
            "timing": _text("timing"),
            "entry_price": _parse_float(data.get("entry_price")),
            "target_price": _parse_float(data.get("target_price")),
            "stop_loss": _parse_float(data.get("stop_loss")),
            "trigger_condition": _text("trigger_condition"),
            "invalidation_condition": _text("invalidation_condition"),
            "position_pct": position_pct,
            "analysis": _text("analysis") or "暂无分析",
            "risk_flag": _text("risk_flag"),
        }

    # Legacy line-based format. Labels may be followed by an ASCII or a
    # full-width colon, so both are accepted.
    def _labelled_price(label: str) -> float | None:
        found = re.search(label + r"[:：\s]*(\d+(?:\.\d+)?)", text)
        return float(found.group(1)) if found else None

    signal_match = re.search(r"信号[:：\s]*(BUY|HOLD|SKIP)", text)
    signal = signal_match.group(1) if signal_match else "HOLD"
    verdict = {"BUY": "execute", "SKIP": "skip"}.get(signal, "watch")

    strength_match = re.search(r"信号强度[:：\s]*(强|中|弱)", text)
    strength = strength_match.group(1) if strength_match else ""
    # Keys must be the captured strength labels; an absent strength gets the
    # neutral default of 6, the same default the JSON path uses.
    conviction = {"强": 8, "中": 6, "弱": 4}.get(strength, 6)

    analysis_match = re.search(r"分析[:：\s]*(.+)", text, re.DOTALL)
    analysis = analysis_match.group(1).strip() if analysis_match else ""

    return {
        "verdict": verdict,
        "action_plan": plan_by_verdict[verdict],
        "conviction": conviction,
        "timing": "",
        "entry_price": _labelled_price("买入价"),
        "target_price": _labelled_price("止盈价"),
        "stop_loss": _labelled_price("止损价"),
        "trigger_condition": "",
        "invalidation_condition": "",
        "position_pct": 20 if verdict == "execute" else 0,
        "analysis": analysis or "暂无分析",
        "risk_flag": "",
    }
def _extract_json_object(text: str) -> dict | None:
match = re.search(r"\{[\s\S]*\}", text)
if not match:
return None
try:
parsed = json.loads(match.group(0))
return parsed if isinstance(parsed, dict) else None
except Exception:
return None
def _parse_float(value) -> float | None:
try:
if value in (None, "", 0, "0"):
return None
return float(value)
except Exception:
return None
def _clamp_int(value, minimum: int, maximum: int, default: int) -> int:
try:
parsed = int(round(float(value)))
except Exception:
return default
return max(minimum, min(maximum, parsed))
async def analyze_candidates_individually(
    candidates: list[dict], market_summary: str, max_concurrent: int = 3
) -> dict[str, dict]:
    """Analyse every candidate with the LLM, limiting concurrent requests.

    Args:
        candidates: Candidate records; each must contain ``ts_code``.
        market_summary: Market-environment summary passed to each analysis.
        max_concurrent: Maximum number of analyses in flight at once. Values
            below 1 are treated as 1.

    Returns:
        A mapping ``{ts_code: result}`` where each result is the dict
        produced by ``analyze_single_stock``. Candidates whose task raised
        are logged and left out. An empty dict is returned when no API key
        is configured or ``candidates`` is empty.
    """
    if not settings.deepseek_api_key or not candidates:
        return {}

    results = {}
    # Semaphore(0) would never admit a task and a negative value raises, so
    # keep at least one slot.
    semaphore = asyncio.Semaphore(max(1, max_concurrent))

    async def _analyze_with_semaphore(c: dict):
        async with semaphore:
            ts_code = c["ts_code"]
            logger.info(f"LLM 分析: {c.get('name', ts_code)}")
            result = await analyze_single_stock(c, market_summary)
            logger.info(
                f"LLM 结果: {c.get('name', ts_code)}"
                f"裁决={result['verdict']} 计划={result['action_plan']} 置信度={result['conviction']} "
                f"买入={result.get('entry_price')} 止盈={result.get('target_price')} "
                f"止损={result.get('stop_loss')}"
            )
            return ts_code, result

    tasks = [_analyze_with_semaphore(c) for c in candidates]
    completed = await asyncio.gather(*tasks, return_exceptions=True)

    for item in completed:
        # With return_exceptions=True a cancelled child shows up here as a
        # CancelledError, which derives from BaseException rather than
        # Exception; it must be skipped too instead of being unpacked.
        if isinstance(item, BaseException):
            logger.error(f"LLM 分析任务异常: {item}")
            continue
        ts_code, result = item
        results[ts_code] = result

    logger.info(f"LLM 逐股分析完成: {len(results)}/{len(candidates)}")
    return results