"""
|
||
多模型LLM服务 - 支持智谱AI和DeepSeek
|
||
"""
|
||
from typing import Optional, List, Dict, Any
|
||
from app.config import get_settings
|
||
from app.utils.logger import logger
|
||
|
||
# 智谱AI
|
||
try:
|
||
from zhipuai import ZhipuAI
|
||
ZHIPUAI_AVAILABLE = True
|
||
except ImportError:
|
||
ZHIPUAI_AVAILABLE = False
|
||
logger.warning("zhipuai包未安装")
|
||
|
||
# DeepSeek (使用OpenAI兼容接口)
|
||
try:
|
||
from openai import OpenAI
|
||
OPENAI_AVAILABLE = True
|
||
except ImportError:
|
||
OPENAI_AVAILABLE = False
|
||
logger.warning("openai包未安装")
|
||
|
||
|
||
class MultiLLMService:
|
||
"""多模型LLM服务类"""
|
||
|
||
def __init__(self):
|
||
"""初始化多模型LLM服务"""
|
||
settings = get_settings()
|
||
|
||
self.clients = {}
|
||
self.current_model = None
|
||
self.model_info = {}
|
||
|
||
# 初始化智谱AI
|
||
if ZHIPUAI_AVAILABLE and settings.zhipuai_api_key:
|
||
try:
|
||
api_key = settings.zhipuai_api_key.strip()
|
||
if '.' in api_key and len(api_key) > 10:
|
||
self.clients['zhipu'] = ZhipuAI(api_key=api_key)
|
||
self.model_info['zhipu'] = {
|
||
'name': '智谱AI GLM-4',
|
||
'model_id': 'glm-4',
|
||
'provider': 'zhipu',
|
||
'available': True
|
||
}
|
||
logger.info("智谱AI初始化成功")
|
||
except Exception as e:
|
||
logger.error(f"智谱AI初始化失败: {e}")
|
||
|
||
# 初始化DeepSeek
|
||
if OPENAI_AVAILABLE and settings.deepseek_api_key:
|
||
try:
|
||
self.clients['deepseek'] = OpenAI(
|
||
api_key=settings.deepseek_api_key,
|
||
base_url="https://api.deepseek.com"
|
||
)
|
||
self.model_info['deepseek'] = {
|
||
'name': 'DeepSeek Chat',
|
||
'model_id': 'deepseek-chat',
|
||
'provider': 'deepseek',
|
||
'available': True
|
||
}
|
||
logger.info("DeepSeek初始化成功")
|
||
except Exception as e:
|
||
logger.error(f"DeepSeek初始化失败: {e}")
|
||
|
||
# 设置默认模型(优先DeepSeek,因为更便宜)
|
||
if 'deepseek' in self.clients:
|
||
self.current_model = 'deepseek'
|
||
elif 'zhipu' in self.clients:
|
||
self.current_model = 'zhipu'
|
||
|
||
if self.current_model:
|
||
logger.info(f"当前使用模型: {self.model_info[self.current_model]['name']}")
|
||
else:
|
||
logger.warning("没有可用的LLM模型")
|
||
|
||
def get_available_models(self) -> List[Dict[str, Any]]:
|
||
"""获取所有可用的模型列表"""
|
||
return [info for info in self.model_info.values() if info['available']]
|
||
|
||
def get_current_model_info(self) -> Optional[Dict[str, Any]]:
|
||
"""获取当前使用的模型信息"""
|
||
if self.current_model:
|
||
return self.model_info[self.current_model]
|
||
return None
|
||
|
||
def switch_model(self, provider: str) -> bool:
|
||
"""
|
||
切换模型
|
||
|
||
Args:
|
||
provider: 模型提供商 ('zhipu' 或 'deepseek')
|
||
|
||
Returns:
|
||
是否切换成功
|
||
"""
|
||
if provider in self.clients:
|
||
self.current_model = provider
|
||
logger.info(f"切换到模型: {self.model_info[provider]['name']}")
|
||
return True
|
||
else:
|
||
logger.error(f"模型不可用: {provider}")
|
||
return False
|
||
|
||
def chat(
|
||
self,
|
||
messages: List[Dict[str, str]],
|
||
temperature: float = 0.7,
|
||
max_tokens: int = 2000,
|
||
model_override: Optional[str] = None
|
||
) -> Optional[str]:
|
||
"""
|
||
调用LLM进行对话
|
||
|
||
Args:
|
||
messages: 消息列表
|
||
temperature: 温度参数
|
||
max_tokens: 最大token数
|
||
model_override: 临时覆盖使用的模型
|
||
|
||
Returns:
|
||
LLM响应文本
|
||
"""
|
||
provider = model_override or self.current_model
|
||
|
||
if not provider or provider not in self.clients:
|
||
logger.error("没有可用的LLM客户端")
|
||
return None
|
||
|
||
try:
|
||
client = self.clients[provider]
|
||
model_id = self.model_info[provider]['model_id']
|
||
|
||
logger.info(f"调用LLM: provider={provider}, model={model_id}, messages={len(messages)}条")
|
||
|
||
if provider == 'zhipu':
|
||
# 智谱AI调用
|
||
response = client.chat.completions.create(
|
||
model=model_id,
|
||
messages=messages,
|
||
temperature=temperature,
|
||
max_tokens=max_tokens
|
||
)
|
||
elif provider == 'deepseek':
|
||
# DeepSeek调用(OpenAI兼容)
|
||
response = client.chat.completions.create(
|
||
model=model_id,
|
||
messages=messages,
|
||
temperature=temperature,
|
||
max_tokens=max_tokens
|
||
)
|
||
else:
|
||
logger.error(f"未知的模型提供商: {provider}")
|
||
return None
|
||
|
||
if response.choices:
|
||
content = response.choices[0].message.content
|
||
logger.info(f"LLM响应成功,长度: {len(content) if content else 0}")
|
||
return content
|
||
else:
|
||
logger.warning("LLM响应中没有choices")
|
||
return None
|
||
|
||
except Exception as e:
|
||
logger.error(f"LLM调用失败: {type(e).__name__}: {e}")
|
||
import traceback
|
||
logger.error(f"详细错误: {traceback.format_exc()}")
|
||
return None
|
||
|
||
def chat_stream(
|
||
self,
|
||
messages: List[Dict[str, str]],
|
||
temperature: float = 0.7,
|
||
max_tokens: int = 2000,
|
||
model_override: Optional[str] = None
|
||
):
|
||
"""
|
||
流式调用LLM进行对话
|
||
|
||
Args:
|
||
messages: 消息列表
|
||
temperature: 温度参数
|
||
max_tokens: 最大token数
|
||
model_override: 临时覆盖使用的模型
|
||
|
||
Yields:
|
||
LLM响应的文本片段
|
||
"""
|
||
provider = model_override or self.current_model
|
||
|
||
if not provider or provider not in self.clients:
|
||
logger.error("没有可用的LLM客户端")
|
||
return
|
||
|
||
try:
|
||
client = self.clients[provider]
|
||
model_id = self.model_info[provider]['model_id']
|
||
|
||
logger.info(f"流式调用LLM: provider={provider}, model={model_id}, messages={len(messages)}条")
|
||
|
||
if provider == 'zhipu':
|
||
# 智谱AI流式调用
|
||
response = client.chat.completions.create(
|
||
model=model_id,
|
||
messages=messages,
|
||
temperature=temperature,
|
||
max_tokens=max_tokens,
|
||
stream=True
|
||
)
|
||
for chunk in response:
|
||
if chunk.choices and chunk.choices[0].delta.content:
|
||
yield chunk.choices[0].delta.content
|
||
|
||
elif provider == 'deepseek':
|
||
# DeepSeek流式调用(OpenAI兼容)
|
||
response = client.chat.completions.create(
|
||
model=model_id,
|
||
messages=messages,
|
||
temperature=temperature,
|
||
max_tokens=max_tokens,
|
||
stream=True
|
||
)
|
||
for chunk in response:
|
||
if chunk.choices and chunk.choices[0].delta.content:
|
||
yield chunk.choices[0].delta.content
|
||
|
||
else:
|
||
logger.error(f"未知的模型提供商: {provider}")
|
||
return
|
||
|
||
logger.info("LLM流式响应完成")
|
||
|
||
except Exception as e:
|
||
logger.error(f"LLM流式调用失败: {type(e).__name__}: {e}")
|
||
import traceback
|
||
logger.error(f"详细错误: {traceback.format_exc()}")
|
||
return
|
||
|
||
def analyze_intent(self, user_message: str) -> Dict[str, Any]:
|
||
"""使用LLM分析用户意图"""
|
||
if not self.current_model:
|
||
return {"type": "unknown", "confidence": 0}
|
||
|
||
prompt = f"""你是一个股票分析助手的意图识别模块。请分析用户的查询意图。
|
||
|
||
用户消息:{user_message}
|
||
|
||
请识别以下意图类型之一:
|
||
1. market_data - 查询实时行情、价格
|
||
2. technical_analysis - 技术分析、技术指标
|
||
3. fundamental - 基本面信息、公司信息
|
||
4. visualization - K线图、图表
|
||
5. unknown - 无法识别
|
||
|
||
请以JSON格式返回:
|
||
{{
|
||
"type": "意图类型",
|
||
"confidence": 0.0-1.0,
|
||
"stock_name": "提取的股票名称(如果有)"
|
||
}}
|
||
"""
|
||
|
||
try:
|
||
response = self.chat([{"role": "user", "content": prompt}], temperature=0.3)
|
||
if response:
|
||
import json
|
||
result = json.loads(response)
|
||
return result
|
||
except Exception as e:
|
||
logger.error(f"意图分析失败: {e}")
|
||
|
||
return {"type": "unknown", "confidence": 0}
|
||
|
||
|
||
# 创建全局实例
|
||
multi_llm_service = MultiLLMService()
|
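

# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the service itself): shows
# how the global `multi_llm_service` instance is expected to be called. It
# assumes the ZhipuAI / DeepSeek API keys are configured via `app.config`
# settings; the example prompts and this "__main__" guard are demonstration
# material only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    service = multi_llm_service

    print("Available models:", service.get_available_models())
    print("Current model:", service.get_current_model_info())

    # Blocking call: `messages` uses the OpenAI-style role/content format.
    reply = service.chat(
        messages=[{"role": "user", "content": "Introduce yourself in one sentence."}],
        temperature=0.5,
        max_tokens=200,
    )
    print("Reply:", reply)

    # Streaming call: text chunks are yielded as they arrive.
    for piece in service.chat_stream(
        messages=[{"role": "user", "content": "Say hello."}]
    ):
        print(piece, end="", flush=True)
    print()

    # Optionally switch provider; returns False if that client was not configured.
    if service.switch_model("zhipu"):
        print("Switched to:", service.get_current_model_info())

    # Intent analysis returns a dict such as
    # {"type": "market_data", "confidence": 0.9, "stock_name": "..."}.
    print(service.analyze_intent("What is the latest price of Tesla stock?"))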