stock-ai-agent/backend/app/agent/context.py
2026-02-03 10:08:15 +08:00

94 lines
2.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
上下文管理器
管理对话历史和上下文
"""
from typing import List, Dict, Optional
from app.services.db_service import db_service
from app.utils.logger import logger
class ContextManager:
"""上下文管理器"""
def __init__(self, max_history: int = 10):
"""
初始化上下文管理器
Args:
max_history: 最大历史消息数
"""
self.max_history = max_history
def get_context(self, session_id: str) -> List[Dict[str, str]]:
"""
获取对话上下文
Args:
session_id: 会话ID
Returns:
消息列表
"""
messages = db_service.get_conversation_history(session_id, limit=self.max_history)
context = []
for msg in messages:
context.append({
"role": msg.role,
"content": msg.content
})
return context
def add_message(
self,
session_id: str,
role: str,
content: str,
metadata: Optional[dict] = None
):
"""
添加消息到上下文
Args:
session_id: 会话ID
role: 角色user/assistant
content: 消息内容
metadata: 元数据
"""
db_service.add_message(session_id, role, content, metadata)
logger.info(f"添加消息到上下文: {session_id}, {role}")
def clear_context(self, session_id: str):
"""
清除上下文(暂不实现删除,保留历史)
Args:
session_id: 会话ID
"""
logger.info(f"清除上下文请求: {session_id}")
# 实际不删除,只是标记
pass
def format_context_for_llm(self, session_id: str) -> str:
"""
格式化上下文供LLM使用
Args:
session_id: 会话ID
Returns:
格式化的上下文字符串
"""
context = self.get_context(session_id)
if not context:
return ""
formatted = []
for msg in context:
role = "用户" if msg["role"] == "user" else "助手"
formatted.append(f"{role}: {msg['content']}")
return "\n".join(formatted)