198 lines
5.6 KiB
Python
198 lines
5.6 KiB
Python
"""
|
||
上下文管理器
|
||
管理对话历史和上下文
|
||
"""
|
||
from typing import List, Dict, Optional
|
||
from app.services.db_service import db_service
|
||
from app.utils.logger import logger
|
||
|
||
|
||
class ContextManager:
|
||
"""上下文管理器"""
|
||
|
||
def __init__(self, max_history: int = 10):
|
||
"""
|
||
初始化上下文管理器
|
||
|
||
Args:
|
||
max_history: 最大历史消息数
|
||
"""
|
||
self.max_history = max_history
|
||
|
||
def get_context(self, session_id: str) -> List[Dict[str, str]]:
|
||
"""
|
||
获取对话上下文
|
||
|
||
Args:
|
||
session_id: 会话ID
|
||
|
||
Returns:
|
||
消息列表
|
||
"""
|
||
messages = db_service.get_conversation_history(session_id, limit=self.max_history)
|
||
|
||
context = []
|
||
for msg in messages:
|
||
context.append({
|
||
"role": msg.role,
|
||
"content": msg.content,
|
||
"metadata": msg.metadata if hasattr(msg, 'metadata') else {}
|
||
})
|
||
|
||
return context
|
||
|
||
def add_message(
|
||
self,
|
||
session_id: str,
|
||
role: str,
|
||
content: str,
|
||
metadata: Optional[dict] = None,
|
||
user_id: Optional[int] = None
|
||
):
|
||
"""
|
||
添加消息到上下文
|
||
|
||
Args:
|
||
session_id: 会话ID
|
||
role: 角色(user/assistant)
|
||
content: 消息内容
|
||
metadata: 元数据
|
||
user_id: 用户ID(创建新对话时需要)
|
||
"""
|
||
db_service.add_message(session_id, role, content, metadata, user_id)
|
||
logger.info(f"添加消息到上下文: {session_id}, {role}")
|
||
|
||
def clear_context(self, session_id: str):
|
||
"""
|
||
清除上下文(暂不实现删除,保留历史)
|
||
|
||
Args:
|
||
session_id: 会话ID
|
||
"""
|
||
logger.info(f"清除上下文请求: {session_id}")
|
||
# 实际不删除,只是标记
|
||
pass
|
||
|
||
def format_context_for_llm(self, session_id: str) -> str:
|
||
"""
|
||
格式化上下文供LLM使用
|
||
|
||
Args:
|
||
session_id: 会话ID
|
||
|
||
Returns:
|
||
格式化的上下文字符串
|
||
"""
|
||
context = self.get_context(session_id)
|
||
|
||
if not context:
|
||
return ""
|
||
|
||
formatted = []
|
||
for msg in context:
|
||
role = "用户" if msg["role"] == "user" else "助手"
|
||
formatted.append(f"{role}: {msg['content']}")
|
||
|
||
return "\n".join(formatted)
|
||
|
||
def extract_context_info(self, session_id: str) -> Dict:
|
||
"""
|
||
提取上下文信息
|
||
|
||
Args:
|
||
session_id: 会话ID
|
||
|
||
Returns:
|
||
ContextInfo: {
|
||
'last_stock': str | None, # 上次讨论的股票
|
||
'last_topic': str | None, # 上次的话题
|
||
'user_preferences': dict # 用户偏好
|
||
}
|
||
"""
|
||
history = self.get_context(session_id)
|
||
|
||
return {
|
||
'last_stock': self._extract_last_stock(history),
|
||
'last_topic': self._extract_last_topic(history),
|
||
'user_preferences': self._analyze_user_preferences(history)
|
||
}
|
||
|
||
def _extract_last_stock(self, history: List[Dict]) -> Optional[str]:
|
||
"""
|
||
从历史对话中提取最后讨论的股票
|
||
|
||
Args:
|
||
history: 对话历史
|
||
|
||
Returns:
|
||
股票代码或None
|
||
"""
|
||
# 从后往前查找
|
||
for msg in reversed(history):
|
||
if msg['role'] == 'assistant':
|
||
metadata = msg.get('metadata', {})
|
||
if isinstance(metadata, dict):
|
||
# 尝试从不同位置提取股票代码
|
||
if 'data' in metadata:
|
||
data = metadata['data']
|
||
if isinstance(data, dict):
|
||
if 'stock_code' in data:
|
||
return data['stock_code']
|
||
if 'ts_code' in data:
|
||
return data['ts_code']
|
||
|
||
# 尝试从intent中提取
|
||
if 'intent' in metadata:
|
||
intent = metadata['intent']
|
||
if isinstance(intent, dict) and 'target' in intent:
|
||
target = intent['target']
|
||
if isinstance(target, dict) and 'stock_code' in target:
|
||
return target['stock_code']
|
||
|
||
return None
|
||
|
||
def _extract_last_topic(self, history: List[Dict]) -> Optional[str]:
|
||
"""
|
||
从历史对话中提取最后的话题
|
||
|
||
Args:
|
||
history: 对话历史
|
||
|
||
Returns:
|
||
话题或None
|
||
"""
|
||
if not history:
|
||
return None
|
||
|
||
# 获取最后一条用户消息
|
||
for msg in reversed(history):
|
||
if msg['role'] == 'user':
|
||
content = msg['content']
|
||
# 简单提取话题(前50个字符)
|
||
return content[:50] if len(content) > 50 else content
|
||
|
||
return None
|
||
|
||
def _analyze_user_preferences(self, history: List[Dict]) -> Dict:
|
||
"""
|
||
分析用户偏好
|
||
|
||
Args:
|
||
history: 对话历史
|
||
|
||
Returns:
|
||
用户偏好字典
|
||
"""
|
||
preferences = {
|
||
'preferred_style': 'casual',
|
||
'typical_time_scope': 'short_term',
|
||
'frequent_dimensions': []
|
||
}
|
||
|
||
# 简单的偏好分析(可以后续扩展)
|
||
if len(history) > 5:
|
||
# 如果对话较多,可能是专业用户
|
||
preferences['preferred_style'] = 'professional'
|
||
|
||
return preferences
|