"""
数据库服务 (Database service)

提供数据库操作功能 — database access helpers for conversations, messages
and user preferences.
"""
import uuid
from datetime import datetime, timezone
from typing import Optional, List

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session

from app.config import get_settings
from app.models.database import Base, Conversation, Message, UserPreference
from app.utils.logger import logger


class DatabaseService:
    """Database access service.

    Holds the SQLAlchemy engine and session factory, creates the tables
    declared on ``Base.metadata`` when constructed, and offers helpers for
    conversations, messages and user preferences. Each public method opens
    its own session and closes it before returning, so any ORM object it
    returns is detached: attributes that were already loaded remain readable,
    but lazy relationships / unloaded attributes may not be accessible.
    """

    def __init__(self):
        """Build the engine and session factory, then create missing tables."""
        settings = get_settings()
        self.engine = create_engine(
            settings.database_url,
            connect_args=self._build_connect_args(settings.database_url),
        )
        self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self.engine)

        # Create any table declared on Base.metadata that does not exist yet.
        Base.metadata.create_all(bind=self.engine)
        logger.info("数据库初始化成功")

    @staticmethod
    def _build_connect_args(database_url) -> dict:
        """Return the DBAPI ``connect_args`` appropriate for *database_url*.

        For SQLite, ``check_same_thread`` is disabled so a connection may be
        used from a thread other than the one that created it. The backend is
        detected from the URL's dialect prefix (``sqlite://``,
        ``sqlite+pysqlite://``, ...) instead of a substring search, so a
        non-SQLite URL that merely contains "sqlite" (in a database name,
        host or password) does not receive this SQLite-only argument.
        """
        if str(database_url).strip().lower().startswith("sqlite"):
            return {"check_same_thread": False}
        return {}

    @staticmethod
    def _query_conversation(db: Session, session_id: str) -> Optional[Conversation]:
        """Return the first conversation whose ``session_id`` matches, or None."""
        return db.query(Conversation).filter(Conversation.session_id == session_id).first()

    def get_session(self) -> Session:
        """Return a new session from the session factory (caller must close it)."""
        return self.SessionLocal()

    def create_conversation(self, session_id: Optional[str] = None, user_id: Optional[str] = None) -> Conversation:
        """Create and persist a new conversation.

        Args:
            session_id: Session ID; a random UUID4 string is generated when
                omitted or empty.
            user_id: Owning user's ID.

        Returns:
            The committed and refreshed ``Conversation`` (detached).
        """
        db = self.get_session()
        try:
            if not session_id:
                session_id = str(uuid.uuid4())

            conversation = Conversation(session_id=session_id, user_id=user_id)
            db.add(conversation)
            db.commit()
            db.refresh(conversation)
            return conversation
        finally:
            db.close()

    def get_conversation(self, session_id: str) -> Optional[Conversation]:
        """Return the conversation for *session_id*, or None if there is none."""
        db = self.get_session()
        try:
            return self._query_conversation(db, session_id)
        finally:
            db.close()

    def add_message(
        self,
        session_id: str,
        role: str,
        content: str,
        metadata: Optional[dict] = None,
        # NOTE(review): annotated int here but str in create_conversation /
        # user preferences — confirm the Conversation.user_id column type.
        user_id: Optional[int] = None
    ) -> Message:
        """Append a message to a conversation, creating the conversation if needed.

        When no conversation exists for *session_id*, one is created for
        *user_id*. The new conversation and the message are committed in a
        single transaction, so a failed message insert does not leave an
        empty conversation behind.

        Args:
            session_id: Session ID of the conversation.
            role: Message role (e.g. user / assistant).
            content: Message text.
            metadata: Optional metadata stored in ``Message.msg_metadata``.
            user_id: Required only when the conversation must be created.

        Returns:
            The committed and refreshed ``Message`` (detached).

        Raises:
            ValueError: The conversation does not exist and no (truthy)
                *user_id* was supplied.
        """
        db = self.get_session()
        try:
            conversation = self._query_conversation(db, session_id)

            if conversation is None:
                if not user_id:
                    raise ValueError("创建新对话时必须提供 user_id")
                conversation = Conversation(session_id=session_id, user_id=user_id)
                db.add(conversation)
                # Flush (not commit) so conversation.id is assigned while the
                # conversation and its first message share one transaction.
                db.flush()

            message = Message(
                conversation_id=conversation.id,
                role=role,
                content=content,
                msg_metadata=metadata
            )
            db.add(message)
            db.commit()
            db.refresh(message)
            return message
        finally:
            # close() also rolls back any transaction left open by an error.
            db.close()

    def get_conversation_history(self, session_id: str, limit: int = 50) -> List[Message]:
        """Return the most recent messages of a conversation, oldest first.

        Args:
            session_id: Session ID of the conversation.
            limit: Maximum number of messages; values <= 0 yield ``[]``
                (a negative SQL LIMIT is backend-dependent, e.g. SQLite treats
                it as "no limit").

        Returns:
            Up to *limit* newest messages by ``created_at``, in chronological
            order; ``[]`` when the conversation does not exist.
        """
        if limit <= 0:
            return []

        db = self.get_session()
        try:
            conversation = self._query_conversation(db, session_id)
            if conversation is None:
                return []

            # Take the newest `limit` rows, then flip them back to oldest-first.
            newest_first = (
                db.query(Message)
                .filter(Message.conversation_id == conversation.id)
                .order_by(Message.created_at.desc())
                .limit(limit)
                .all()
            )
            return list(reversed(newest_first))
        finally:
            db.close()

    def get_user_preference(self, user_id: str) -> Optional[dict]:
        """Return the stored preferences for *user_id*, or None if absent."""
        db = self.get_session()
        try:
            pref = db.query(UserPreference).filter(
                UserPreference.user_id == user_id
            ).first()
            return pref.preferences if pref is not None else None
        finally:
            db.close()

    def set_user_preference(self, user_id: str, preferences: dict) -> bool:
        """Create or replace the preferences for *user_id*.

        Args:
            user_id: User ID.
            preferences: Preference mapping to store (replaces any existing one).

        Returns:
            True on success; False if any exception occurred (it is logged
            with its traceback and the transaction is rolled back).
        """
        db = self.get_session()
        try:
            pref = db.query(UserPreference).filter(
                UserPreference.user_id == user_id
            ).first()

            if pref is not None:
                pref.preferences = preferences
                # Naive UTC timestamp, equivalent to the deprecated datetime.utcnow().
                pref.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
            else:
                pref = UserPreference(user_id=user_id, preferences=preferences)
                db.add(pref)

            db.commit()
            return True
        except Exception as e:
            # Deliberate boundary: callers receive False instead of an exception.
            logger.exception(f"设置用户偏好失败: {e}")
            db.rollback()
            return False
        finally:
            db.close()
# Create the module-level shared instance. Constructing it runs
# DatabaseService.__init__ at import time, which builds the engine and runs
# Base.metadata.create_all.
db_service: DatabaseService = DatabaseService()