""" 上下文管理器 管理对话历史和上下文 """ from typing import List, Dict, Optional from app.services.db_service import db_service from app.utils.logger import logger class ContextManager: """上下文管理器""" def __init__(self, max_history: int = 10): """ 初始化上下文管理器 Args: max_history: 最大历史消息数 """ self.max_history = max_history def get_context(self, session_id: str) -> List[Dict[str, str]]: """ 获取对话上下文 Args: session_id: 会话ID Returns: 消息列表 """ messages = db_service.get_conversation_history(session_id, limit=self.max_history) context = [] for msg in messages: context.append({ "role": msg.role, "content": msg.content, "metadata": msg.metadata if hasattr(msg, 'metadata') else {} }) return context def add_message( self, session_id: str, role: str, content: str, metadata: Optional[dict] = None, user_id: Optional[int] = None ): """ 添加消息到上下文 Args: session_id: 会话ID role: 角色(user/assistant) content: 消息内容 metadata: 元数据 user_id: 用户ID(创建新对话时需要) """ db_service.add_message(session_id, role, content, metadata, user_id) logger.info(f"添加消息到上下文: {session_id}, {role}") def clear_context(self, session_id: str): """ 清除上下文(暂不实现删除,保留历史) Args: session_id: 会话ID """ logger.info(f"清除上下文请求: {session_id}") # 实际不删除,只是标记 pass def format_context_for_llm(self, session_id: str) -> str: """ 格式化上下文供LLM使用 Args: session_id: 会话ID Returns: 格式化的上下文字符串 """ context = self.get_context(session_id) if not context: return "" formatted = [] for msg in context: role = "用户" if msg["role"] == "user" else "助手" formatted.append(f"{role}: {msg['content']}") return "\n".join(formatted) def extract_context_info(self, session_id: str) -> Dict: """ 提取上下文信息 Args: session_id: 会话ID Returns: ContextInfo: { 'last_stock': str | None, # 上次讨论的股票 'last_topic': str | None, # 上次的话题 'user_preferences': dict # 用户偏好 } """ history = self.get_context(session_id) return { 'last_stock': self._extract_last_stock(history), 'last_topic': self._extract_last_topic(history), 'user_preferences': self._analyze_user_preferences(history) } def _extract_last_stock(self, history: List[Dict]) -> Optional[str]: """ 从历史对话中提取最后讨论的股票 Args: history: 对话历史 Returns: 股票代码或None """ # 从后往前查找 for msg in reversed(history): if msg['role'] == 'assistant': metadata = msg.get('metadata', {}) if isinstance(metadata, dict): # 尝试从不同位置提取股票代码 if 'data' in metadata: data = metadata['data'] if isinstance(data, dict): if 'stock_code' in data: return data['stock_code'] if 'ts_code' in data: return data['ts_code'] # 尝试从intent中提取 if 'intent' in metadata: intent = metadata['intent'] if isinstance(intent, dict) and 'target' in intent: target = intent['target'] if isinstance(target, dict) and 'stock_code' in target: return target['stock_code'] return None def _extract_last_topic(self, history: List[Dict]) -> Optional[str]: """ 从历史对话中提取最后的话题 Args: history: 对话历史 Returns: 话题或None """ if not history: return None # 获取最后一条用户消息 for msg in reversed(history): if msg['role'] == 'user': content = msg['content'] # 简单提取话题(前50个字符) return content[:50] if len(content) > 50 else content return None def _analyze_user_preferences(self, history: List[Dict]) -> Dict: """ 分析用户偏好 Args: history: 对话历史 Returns: 用户偏好字典 """ preferences = { 'preferred_style': 'casual', 'typical_time_scope': 'short_term', 'frequent_dimensions': [] } # 简单的偏好分析(可以后续扩展) if len(history) > 5: # 如果对话较多,可能是专业用户 preferences['preferred_style'] = 'professional' return preferences