stock-ai-agent/backend/app/agent/context.py
2026-02-04 14:56:03 +08:00

198 lines
5.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
上下文管理器
管理对话历史和上下文
"""
from typing import List, Dict, Optional
from app.services.db_service import db_service
from app.utils.logger import logger
class ContextManager:
"""上下文管理器"""
def __init__(self, max_history: int = 10):
"""
初始化上下文管理器
Args:
max_history: 最大历史消息数
"""
self.max_history = max_history
def get_context(self, session_id: str) -> List[Dict[str, str]]:
"""
获取对话上下文
Args:
session_id: 会话ID
Returns:
消息列表
"""
messages = db_service.get_conversation_history(session_id, limit=self.max_history)
context = []
for msg in messages:
context.append({
"role": msg.role,
"content": msg.content,
"metadata": msg.metadata if hasattr(msg, 'metadata') else {}
})
return context
def add_message(
self,
session_id: str,
role: str,
content: str,
metadata: Optional[dict] = None,
user_id: Optional[int] = None
):
"""
添加消息到上下文
Args:
session_id: 会话ID
role: 角色user/assistant
content: 消息内容
metadata: 元数据
user_id: 用户ID创建新对话时需要
"""
db_service.add_message(session_id, role, content, metadata, user_id)
logger.info(f"添加消息到上下文: {session_id}, {role}")
def clear_context(self, session_id: str):
"""
清除上下文(暂不实现删除,保留历史)
Args:
session_id: 会话ID
"""
logger.info(f"清除上下文请求: {session_id}")
# 实际不删除,只是标记
pass
def format_context_for_llm(self, session_id: str) -> str:
"""
格式化上下文供LLM使用
Args:
session_id: 会话ID
Returns:
格式化的上下文字符串
"""
context = self.get_context(session_id)
if not context:
return ""
formatted = []
for msg in context:
role = "用户" if msg["role"] == "user" else "助手"
formatted.append(f"{role}: {msg['content']}")
return "\n".join(formatted)
def extract_context_info(self, session_id: str) -> Dict:
"""
提取上下文信息
Args:
session_id: 会话ID
Returns:
ContextInfo: {
'last_stock': str | None, # 上次讨论的股票
'last_topic': str | None, # 上次的话题
'user_preferences': dict # 用户偏好
}
"""
history = self.get_context(session_id)
return {
'last_stock': self._extract_last_stock(history),
'last_topic': self._extract_last_topic(history),
'user_preferences': self._analyze_user_preferences(history)
}
def _extract_last_stock(self, history: List[Dict]) -> Optional[str]:
"""
从历史对话中提取最后讨论的股票
Args:
history: 对话历史
Returns:
股票代码或None
"""
# 从后往前查找
for msg in reversed(history):
if msg['role'] == 'assistant':
metadata = msg.get('metadata', {})
if isinstance(metadata, dict):
# 尝试从不同位置提取股票代码
if 'data' in metadata:
data = metadata['data']
if isinstance(data, dict):
if 'stock_code' in data:
return data['stock_code']
if 'ts_code' in data:
return data['ts_code']
# 尝试从intent中提取
if 'intent' in metadata:
intent = metadata['intent']
if isinstance(intent, dict) and 'target' in intent:
target = intent['target']
if isinstance(target, dict) and 'stock_code' in target:
return target['stock_code']
return None
def _extract_last_topic(self, history: List[Dict]) -> Optional[str]:
"""
从历史对话中提取最后的话题
Args:
history: 对话历史
Returns:
话题或None
"""
if not history:
return None
# 获取最后一条用户消息
for msg in reversed(history):
if msg['role'] == 'user':
content = msg['content']
# 简单提取话题前50个字符
return content[:50] if len(content) > 50 else content
return None
def _analyze_user_preferences(self, history: List[Dict]) -> Dict:
"""
分析用户偏好
Args:
history: 对话历史
Returns:
用户偏好字典
"""
preferences = {
'preferred_style': 'casual',
'typical_time_scope': 'short_term',
'frequent_dimensions': []
}
# 简单的偏好分析(可以后续扩展)
if len(history) > 5:
# 如果对话较多,可能是专业用户
preferences['preferred_style'] = 'professional'
return preferences