""" 数据库服务 提供数据库操作功能 """ from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker, Session from typing import Optional, List from datetime import datetime import uuid from app.config import get_settings from app.models.database import Base, Conversation, Message, UserPreference from app.utils.logger import logger class DatabaseService: """数据库服务类""" def __init__(self): """初始化数据库连接""" settings = get_settings() self.engine = create_engine( settings.database_url, connect_args={"check_same_thread": False} if "sqlite" in settings.database_url else {} ) self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self.engine) # 创建表 Base.metadata.create_all(bind=self.engine) logger.info("数据库初始化成功") def get_session(self) -> Session: """获取数据库会话""" return self.SessionLocal() def create_conversation(self, session_id: Optional[str] = None, user_id: Optional[str] = None) -> Conversation: """ 创建新对话 Args: session_id: 会话ID(可选,自动生成) user_id: 用户ID Returns: 对话对象 """ db = self.get_session() try: if not session_id: session_id = str(uuid.uuid4()) conversation = Conversation( session_id=session_id, user_id=user_id ) db.add(conversation) db.commit() db.refresh(conversation) return conversation finally: db.close() def get_conversation(self, session_id: str) -> Optional[Conversation]: """ 获取对话 Args: session_id: 会话ID Returns: 对话对象或None """ db = self.get_session() try: return db.query(Conversation).filter(Conversation.session_id == session_id).first() finally: db.close() def add_message( self, session_id: str, role: str, content: str, metadata: Optional[dict] = None, user_id: Optional[int] = None ) -> Message: """ 添加消息 Args: session_id: 会话ID role: 角色(user/assistant) content: 消息内容 metadata: 元数据 user_id: 用户ID(创建新对话时需要) Returns: 消息对象 """ db = self.get_session() try: # 获取或创建对话 conversation = db.query(Conversation).filter( Conversation.session_id == session_id ).first() if not conversation: if not user_id: raise ValueError("创建新对话时必须提供 user_id") conversation = Conversation(session_id=session_id, user_id=user_id) db.add(conversation) db.commit() db.refresh(conversation) # 创建消息 message = Message( conversation_id=conversation.id, role=role, content=content, msg_metadata=metadata ) db.add(message) db.commit() db.refresh(message) return message finally: db.close() def get_conversation_history(self, session_id: str, limit: int = 50) -> List[Message]: """ 获取对话历史 Args: session_id: 会话ID limit: 最大消息数 Returns: 消息列表 """ db = self.get_session() try: conversation = db.query(Conversation).filter( Conversation.session_id == session_id ).first() if not conversation: return [] messages = db.query(Message).filter( Message.conversation_id == conversation.id ).order_by(Message.created_at.desc()).limit(limit).all() return list(reversed(messages)) finally: db.close() def get_user_preference(self, user_id: str) -> Optional[dict]: """ 获取用户偏好 Args: user_id: 用户ID Returns: 偏好字典或None """ db = self.get_session() try: pref = db.query(UserPreference).filter( UserPreference.user_id == user_id ).first() return pref.preferences if pref else None finally: db.close() def set_user_preference(self, user_id: str, preferences: dict) -> bool: """ 设置用户偏好 Args: user_id: 用户ID preferences: 偏好字典 Returns: 是否成功 """ db = self.get_session() try: pref = db.query(UserPreference).filter( UserPreference.user_id == user_id ).first() if pref: pref.preferences = preferences pref.updated_at = datetime.utcnow() else: pref = UserPreference( user_id=user_id, preferences=preferences ) db.add(pref) db.commit() return True except Exception as e: logger.error(f"设置用户偏好失败: {e}") db.rollback() return False finally: db.close() # 创建全局实例 db_service = DatabaseService()