update
This commit is contained in:
parent
3eac517d9c
commit
c5c88bd73e
@ -3,6 +3,7 @@
|
||||
"""
|
||||
import re
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Dict, Any, Optional, List
|
||||
from app.config import get_settings
|
||||
from app.agent.context import ContextManager
|
||||
@ -37,6 +38,14 @@ class SmartStockAgent:
|
||||
else:
|
||||
logger.warning("Smart Agent初始化完成(规则模式,建议配置LLM)")
|
||||
|
||||
async def _call_llm_async(self, messages: List[Dict[str, str]], temperature: float = 0.7, max_tokens: int = 2000) -> Optional[str]:
|
||||
"""异步调用LLM,避免阻塞事件循环"""
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
None,
|
||||
lambda: llm_service.chat(messages, temperature, max_tokens)
|
||||
)
|
||||
|
||||
def _register_skills(self):
|
||||
"""注册所有技能"""
|
||||
skill_manager.register(MarketDataSkill())
|
||||
@ -411,7 +420,7 @@ DIF和DEA的位置关系,MACD柱状图变化,判断动能强弱和买卖信
|
||||
"""
|
||||
|
||||
try:
|
||||
analysis = llm_service.chat(
|
||||
analysis = await self._call_llm_async(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=0.7,
|
||||
max_tokens=3000 # 增加到3000,因为分析更详细了
|
||||
@ -564,7 +573,7 @@ DIF和DEA的位置关系,MACD柱状图变化,判断动能强弱和买卖信
|
||||
"""
|
||||
|
||||
try:
|
||||
analysis = llm_service.chat(
|
||||
analysis = await self._call_llm_async(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=0.7,
|
||||
max_tokens=1500
|
||||
@ -751,7 +760,7 @@ MA60:{f"{ma['ma60']:.2f}" if ma['ma60'] else '计算中'}
|
||||
"""
|
||||
|
||||
try:
|
||||
analysis = llm_service.chat(
|
||||
analysis = await self._call_llm_async(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=0.7,
|
||||
max_tokens=2500
|
||||
@ -1032,7 +1041,7 @@ MA60:{f"{ma['ma60']:.2f}" if ma['ma60'] else '计算中'}
|
||||
只返回JSON,不要有任何其他内容。"""
|
||||
|
||||
try:
|
||||
result = llm_service.chat(
|
||||
result = await self._call_llm_async(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=0.3,
|
||||
max_tokens=300
|
||||
@ -1094,50 +1103,18 @@ MA60:{f"{ma['ma60']:.2f}" if ma['ma60'] else '计算中'}
|
||||
# 处理美股
|
||||
return await self._handle_us_stock(stock_keyword, message)
|
||||
|
||||
# 处理A股和指数
|
||||
# 指数映射表
|
||||
index_mapping = {
|
||||
"上证指数": "000001.SH",
|
||||
"上证": "000001.SH",
|
||||
"大盘": "000001.SH",
|
||||
"沪指": "000001.SH",
|
||||
"深证成指": "399001.SZ",
|
||||
"深证": "399001.SZ",
|
||||
"深指": "399001.SZ",
|
||||
"创业板指": "399006.SZ",
|
||||
"创业板": "399006.SZ",
|
||||
"科创50": "000688.SH",
|
||||
"沪深300": "000300.SH",
|
||||
"中证500": "000905.SH",
|
||||
"A股": "000001.SH" # 默认用上证指数代表A股
|
||||
}
|
||||
# 处理A股和指数 - 使用LLM进行智能匹配
|
||||
stock_info = await self._match_stock_with_llm(stock_keyword)
|
||||
|
||||
# 检查是否是指数查询
|
||||
stock_code = None
|
||||
stock_name = None
|
||||
is_index = False
|
||||
if not stock_info:
|
||||
return {
|
||||
"message": f"抱歉,未找到股票或指数\"{stock_keyword}\"。请确认名称或代码是否正确。",
|
||||
"metadata": {"type": "error"}
|
||||
}
|
||||
|
||||
for key, code in index_mapping.items():
|
||||
if key in stock_keyword or stock_keyword in key:
|
||||
stock_code = code
|
||||
stock_name = key if key in stock_keyword else stock_keyword
|
||||
is_index = True
|
||||
logger.info(f"识别为指数查询: {stock_name} -> {stock_code}")
|
||||
break
|
||||
|
||||
# 如果不是指数,使用Tushare搜索股票
|
||||
if not is_index:
|
||||
search_results = tushare_service.search_stock(stock_keyword)
|
||||
|
||||
if not search_results:
|
||||
return {
|
||||
"message": f"抱歉,未找到股票\"{stock_keyword}\"。请确认股票名称或代码是否正确。",
|
||||
"metadata": {"type": "error"}
|
||||
}
|
||||
|
||||
stock = search_results[0]
|
||||
stock_code = stock['symbol']
|
||||
stock_name = stock['name']
|
||||
stock_code = stock_info['code']
|
||||
stock_name = stock_info['name']
|
||||
is_index = stock_info['is_index']
|
||||
|
||||
logger.info(f"处理{'指数' if is_index else '股票'}问题: {stock_name}({stock_code})")
|
||||
|
||||
@ -1152,6 +1129,126 @@ MA60:{f"{ma['ma60']:.2f}" if ma['ma60'] else '计算中'}
|
||||
else:
|
||||
return await self._single_query(stock_code, stock_name, message)
|
||||
|
||||
async def _match_stock_with_llm(self, keyword: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
使用LLM智能匹配股票或指数
|
||||
|
||||
Args:
|
||||
keyword: 用户输入的关键词
|
||||
|
||||
Returns:
|
||||
匹配结果: {'code': '股票代码', 'name': '股票名称', 'is_index': bool}
|
||||
"""
|
||||
if not self.use_llm:
|
||||
# 降级方案:使用Tushare搜索
|
||||
search_results = tushare_service.search_stock(keyword)
|
||||
if search_results:
|
||||
return {
|
||||
'code': search_results[0]['symbol'],
|
||||
'name': search_results[0]['name'],
|
||||
'is_index': False
|
||||
}
|
||||
return None
|
||||
|
||||
prompt = f"""你是一个专业的A股市场专家。请根据用户输入的关键词,识别对应的股票代码或指数代码。
|
||||
|
||||
用户输入:{keyword}
|
||||
|
||||
常见指数代码:
|
||||
- 上证指数/大盘/沪指/A股 → 000001.SH
|
||||
- 深证成指/深证/深指 → 399001.SZ
|
||||
- 创业板指/创业板 → 399006.SZ
|
||||
- 科创50 → 000688.SH
|
||||
- 沪深300 → 000300.SH
|
||||
- 中证500 → 000905.SH
|
||||
|
||||
如果是指数,请直接返回对应的指数代码。
|
||||
如果是股票名称或代码,请使用Tushare数据库进行搜索匹配。
|
||||
|
||||
请以JSON格式返回:
|
||||
{{
|
||||
"is_index": true/false,
|
||||
"code": "股票或指数代码(如000001.SH)",
|
||||
"name": "股票或指数名称",
|
||||
"confidence": 0.0-1.0
|
||||
}}
|
||||
|
||||
如果无法匹配,返回:
|
||||
{{
|
||||
"is_index": false,
|
||||
"code": null,
|
||||
"name": null,
|
||||
"confidence": 0.0
|
||||
}}
|
||||
|
||||
只返回JSON,不要有任何其他内容。"""
|
||||
|
||||
try:
|
||||
result = await self._call_llm_async(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=0.3,
|
||||
max_tokens=200
|
||||
)
|
||||
|
||||
if not result:
|
||||
logger.warning("LLM匹配返回空结果")
|
||||
return None
|
||||
|
||||
# 清理结果
|
||||
result = result.strip()
|
||||
if result.startswith("```json"):
|
||||
result = result[7:]
|
||||
if result.startswith("```"):
|
||||
result = result[3:]
|
||||
if result.endswith("```"):
|
||||
result = result[:-3]
|
||||
result = result.strip()
|
||||
|
||||
# 解析JSON
|
||||
match_result = json.loads(result)
|
||||
|
||||
# 如果LLM无法匹配或置信度太低,使用Tushare搜索
|
||||
if not match_result.get('code') or match_result.get('confidence', 0) < 0.5:
|
||||
logger.info(f"LLM匹配置信度低,使用Tushare搜索: {keyword}")
|
||||
search_results = tushare_service.search_stock(keyword)
|
||||
if search_results:
|
||||
return {
|
||||
'code': search_results[0]['symbol'],
|
||||
'name': search_results[0]['name'],
|
||||
'is_index': False
|
||||
}
|
||||
return None
|
||||
|
||||
logger.info(f"LLM匹配成功: {keyword} -> {match_result['name']}({match_result['code']})")
|
||||
return {
|
||||
'code': match_result['code'],
|
||||
'name': match_result['name'],
|
||||
'is_index': match_result['is_index']
|
||||
}
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"LLM匹配JSON解析失败: {e}, 原始响应: {result[:200] if result else 'None'}")
|
||||
# 降级方案
|
||||
search_results = tushare_service.search_stock(keyword)
|
||||
if search_results:
|
||||
return {
|
||||
'code': search_results[0]['symbol'],
|
||||
'name': search_results[0]['name'],
|
||||
'is_index': False
|
||||
}
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"LLM匹配失败: {e}")
|
||||
# 降级方案
|
||||
search_results = tushare_service.search_stock(keyword)
|
||||
if search_results:
|
||||
return {
|
||||
'code': search_results[0]['symbol'],
|
||||
'name': search_results[0]['name'],
|
||||
'is_index': False
|
||||
}
|
||||
return None
|
||||
|
||||
async def _handle_macro_question(
|
||||
self,
|
||||
intent_analysis: Dict[str, Any],
|
||||
@ -1193,7 +1290,7 @@ MA60:{f"{ma['ma60']:.2f}" if ma['ma60'] else '计算中'}
|
||||
5. 最后声明:"以上分析仅供参考,不构成投资建议。股市有风险,投资需谨慎。"
|
||||
"""
|
||||
|
||||
analysis = llm_service.chat(
|
||||
analysis = await self._call_llm_async(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=0.7,
|
||||
max_tokens=1500
|
||||
@ -1262,7 +1359,7 @@ MA60:{f"{ma['ma60']:.2f}" if ma['ma60'] else '计算中'}
|
||||
"""
|
||||
|
||||
try:
|
||||
answer = llm_service.chat(
|
||||
answer = await self._call_llm_async(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=0.7,
|
||||
max_tokens=1200
|
||||
@ -1318,7 +1415,7 @@ MA60:{f"{ma['ma60']:.2f}" if ma['ma60'] else '计算中'}
|
||||
直接返回回复内容,不要有其他格式。"""
|
||||
|
||||
try:
|
||||
reply = llm_service.chat(
|
||||
reply = await self._call_llm_async(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=0.8,
|
||||
max_tokens=400
|
||||
@ -1584,7 +1681,7 @@ MACD:{f"{technical.get('macd'):.4f}" if technical.get('macd') else '计算中'
|
||||
"""
|
||||
|
||||
try:
|
||||
analysis = llm_service.chat(
|
||||
analysis = await self._call_llm_async(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=0.7,
|
||||
max_tokens=2000
|
||||
@ -1746,34 +1843,16 @@ RSI:{technical.get('rsi', 0):.2f if technical.get('rsi') else '计算中'}
|
||||
|
||||
async def _handle_a_stock_stream(self, stock_keyword: str, message: str):
|
||||
"""流式处理A股分析"""
|
||||
# 指数映射
|
||||
index_mapping = {
|
||||
"上证指数": "000001.SH", "上证": "000001.SH", "大盘": "000001.SH", "沪指": "000001.SH",
|
||||
"深证成指": "399001.SZ", "深证": "399001.SZ", "深指": "399001.SZ",
|
||||
"创业板指": "399006.SZ", "创业板": "399006.SZ",
|
||||
"科创50": "000688.SH", "沪深300": "000300.SH", "中证500": "000905.SH",
|
||||
"A股": "000001.SH"
|
||||
}
|
||||
# 使用LLM进行智能匹配
|
||||
stock_info = await self._match_stock_with_llm(stock_keyword)
|
||||
|
||||
stock_code = None
|
||||
stock_name = None
|
||||
is_index = False
|
||||
if not stock_info:
|
||||
yield f"抱歉,未找到股票或指数\"{stock_keyword}\"。请确认名称或代码是否正确。"
|
||||
return
|
||||
|
||||
for key, code in index_mapping.items():
|
||||
if key in stock_keyword or stock_keyword in key:
|
||||
stock_code = code
|
||||
stock_name = key if key in stock_keyword else stock_keyword
|
||||
is_index = True
|
||||
break
|
||||
|
||||
if not is_index:
|
||||
search_results = tushare_service.search_stock(stock_keyword)
|
||||
if not search_results:
|
||||
yield f"抱歉,未找到股票\"{stock_keyword}\"。请确认股票名称或代码是否正确。"
|
||||
return
|
||||
stock = search_results[0]
|
||||
stock_code = stock['symbol']
|
||||
stock_name = stock['name']
|
||||
stock_code = stock_info['code']
|
||||
stock_name = stock_info['name']
|
||||
is_index = stock_info['is_index']
|
||||
|
||||
# 获取数据(非流式)
|
||||
try:
|
||||
@ -1926,13 +2005,19 @@ RSI:{technical.get('rsi', 0):.2f if technical.get('rsi') else '计算中'}
|
||||
5. 最后声明:"以上分析仅供参考,不构成投资建议。股市有风险,投资需谨慎。"
|
||||
"""
|
||||
|
||||
# 流式调用LLM(同步生成器)
|
||||
# 流式调用LLM(同步生成器,使用线程避免阻塞)
|
||||
import asyncio
|
||||
stream = llm_service.chat_stream(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=0.7,
|
||||
max_tokens=2000
|
||||
)
|
||||
|
||||
# 在线程中迭代同步生成器,避免阻塞事件循环
|
||||
loop = asyncio.get_event_loop()
|
||||
for chunk in stream:
|
||||
# 每次yield后让出控制权
|
||||
await asyncio.sleep(0)
|
||||
yield chunk
|
||||
|
||||
async def _llm_us_stock_analysis_stream(self, data: Dict[str, Any], user_message: str):
|
||||
@ -2022,13 +2107,18 @@ MACD:{f"{technical.get('macd'):.4f}" if technical.get('macd') else '计算中'
|
||||
6. 最后声明:"以上分析仅供参考,不构成投资建议。美股投资有风险,请谨慎决策。"
|
||||
"""
|
||||
|
||||
# 流式调用LLM(同步生成器)
|
||||
# 流式调用LLM(同步生成器,使用线程避免阻塞)
|
||||
import asyncio
|
||||
stream = llm_service.chat_stream(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=0.7,
|
||||
max_tokens=2000
|
||||
)
|
||||
|
||||
# 在线程中迭代同步生成器,避免阻塞事件循环
|
||||
for chunk in stream:
|
||||
# 每次yield后让出控制权
|
||||
await asyncio.sleep(0)
|
||||
yield chunk
|
||||
|
||||
|
||||
|
||||
2126
backend/app/agent/smart_agent.py.bak
Normal file
2126
backend/app/agent/smart_agent.py.bak
Normal file
File diff suppressed because it is too large
Load Diff
@ -146,10 +146,12 @@ class MultiLLMService:
|
||||
)
|
||||
elif provider == 'deepseek':
|
||||
# DeepSeek调用(OpenAI兼容)
|
||||
# DeepSeek对参数更严格,确保temperature在有效范围内
|
||||
safe_temperature = max(0.0, min(2.0, temperature))
|
||||
response = client.chat.completions.create(
|
||||
model=model_id,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
temperature=safe_temperature,
|
||||
max_tokens=max_tokens
|
||||
)
|
||||
else:
|
||||
@ -216,10 +218,12 @@ class MultiLLMService:
|
||||
|
||||
elif provider == 'deepseek':
|
||||
# DeepSeek流式调用(OpenAI兼容)
|
||||
# DeepSeek对参数更严格,确保temperature在有效范围内
|
||||
safe_temperature = max(0.0, min(2.0, temperature))
|
||||
response = client.chat.completions.create(
|
||||
model=model_id,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
temperature=safe_temperature,
|
||||
max_tokens=max_tokens,
|
||||
stream=True
|
||||
)
|
||||
|
||||
@ -64,10 +64,9 @@
|
||||
<div class="example-queries">
|
||||
<button class="example-btn" @click="sendExample('分析贵州茅台')">分析贵州茅台</button>
|
||||
<button class="example-btn" @click="sendExample('比亚迪怎么样')">比亚迪怎么样</button>
|
||||
<button class="example-btn" @click="sendExample('上证指数走势')">上证指数走势</button>
|
||||
<button class="example-btn" @click="sendExample('分析特斯拉')">分析特斯拉</button>
|
||||
<button class="example-btn" @click="sendExample('苹果股票')">苹果股票</button>
|
||||
<button class="example-btn" @click="sendExample('NVDA基本面')">NVDA基本面</button>
|
||||
<button class="example-btn" @click="sendExample('英伟达股票怎么样')">英伟达股票怎么样</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user