stock-ai-agent/backend/app/services/db_service.py
2026-02-04 14:56:03 +08:00

215 lines
5.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
数据库服务
提供数据库操作功能
"""
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session
from typing import Optional, List
from datetime import datetime
import uuid
from app.config import get_settings
from app.models.database import Base, Conversation, Message, UserPreference
from app.utils.logger import logger
class DatabaseService:
    """Database access service.

    Wraps a SQLAlchemy engine plus a session factory and exposes helpers for
    conversations, messages and user preferences. Every public method opens
    its own short-lived session and closes it before returning. Any ORM
    objects it returns are therefore detached. Their already-loaded column
    attributes can still be read, but touching attributes that were never
    loaded raises ``DetachedInstanceError``. This includes lazy
    relationships, if the models define any.
    """

    def __init__(self):
        """Build the engine and session factory, then create missing tables.

        Reads ``database_url`` from the application settings. For SQLite URLs
        ``check_same_thread=False`` is passed to the sqlite3 driver. The
        driver then will not reject use of a connection from a thread other
        than the one that created it.
        """
        settings = get_settings()
        self.engine = create_engine(
            settings.database_url,
            connect_args={"check_same_thread": False} if "sqlite" in settings.database_url else {},
        )
        self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self.engine)
        # Create every table registered on Base.metadata that does not exist yet.
        Base.metadata.create_all(bind=self.engine)
        logger.info("数据库初始化成功")

    def get_session(self) -> Session:
        """Return a new session from the factory; the caller must close it."""
        return self.SessionLocal()

    def create_conversation(self, session_id: Optional[str] = None, user_id: Optional[str] = None) -> Conversation:
        """Create and persist a new conversation.

        Args:
            session_id: Conversation session ID. When empty or None, a random
                UUID4 string is generated.
            user_id: Owner's user ID; may be None.

        Returns:
            The committed and refreshed ``Conversation`` (detached once this
            method returns).

        Raises:
            Any database error raised by the insert. For example, the insert
            fails if the schema enforces a unique session_id and it already
            exists. The transaction is rolled back before the error propagates.
        """
        db = self.get_session()
        try:
            if not session_id:
                session_id = str(uuid.uuid4())
            conversation = Conversation(session_id=session_id, user_id=user_id)
            db.add(conversation)
            db.commit()
            # commit() expires loaded attributes; reload them while the
            # session is still open so they stay readable after close().
            db.refresh(conversation)
            return conversation
        except Exception:
            db.rollback()
            raise
        finally:
            db.close()

    def get_conversation(self, session_id: str) -> Optional[Conversation]:
        """Look up a conversation by its session ID.

        Args:
            session_id: Conversation session ID.

        Returns:
            The first matching ``Conversation`` (detached), or None if there
            is no match.
        """
        db = self.get_session()
        try:
            return db.query(Conversation).filter(Conversation.session_id == session_id).first()
        finally:
            db.close()

    def add_message(
        self,
        session_id: str,
        role: str,
        content: str,
        metadata: Optional[dict] = None,
        user_id: Optional[int] = None,
    ) -> Message:
        """Append a message to a conversation, creating the conversation if needed.

        A newly created conversation and its first message are committed in
        the same transaction. If inserting the message fails, no orphan
        conversation is left behind.

        Args:
            session_id: Conversation session ID.
            role: Message role, e.g. "user" or "assistant".
            content: Message text.
            metadata: Optional extra data, stored in ``Message.msg_metadata``.
            user_id: Owner's user ID. It is required only when no
                conversation exists yet for ``session_id``.
                NOTE(review): annotated ``int`` here but ``str`` in
                ``create_conversation`` / the preference methods. Confirm
                the real column type.

        Returns:
            The committed and refreshed ``Message`` (detached).

        Raises:
            ValueError: If the conversation does not exist and ``user_id`` is
                falsy.
            Any database error from the inserts. The transaction is rolled
            back first.
        """
        db = self.get_session()
        try:
            conversation = (
                db.query(Conversation)
                .filter(Conversation.session_id == session_id)
                .first()
            )
            if not conversation:
                if not user_id:
                    raise ValueError("创建新对话时必须提供 user_id")
                conversation = Conversation(session_id=session_id, user_id=user_id)
                db.add(conversation)
                # flush (not commit): emits the INSERT so conversation.id is
                # assigned, but leaves the transaction open so the
                # conversation and its message commit atomically below.
                db.flush()

            message = Message(
                conversation_id=conversation.id,
                role=role,
                content=content,
                msg_metadata=metadata,
            )
            db.add(message)
            db.commit()
            db.refresh(message)
            return message
        except Exception:
            db.rollback()
            raise
        finally:
            db.close()

    def get_conversation_history(self, session_id: str, limit: int = 50) -> List[Message]:
        """Return the most recent messages of a conversation, oldest first.

        Args:
            session_id: Conversation session ID.
            limit: Maximum number of messages to return. Values <= 0 yield an
                empty list; None applies no limit.

        Returns:
            Up to ``limit`` of the newest messages in chronological order.
            Returns an empty list if the conversation does not exist.
        """
        # A negative SQL LIMIT means "unbounded" on SQLite and is an error on
        # other backends, so non-positive limits are handled here instead.
        if limit is not None and limit <= 0:
            return []
        db = self.get_session()
        try:
            conversation = (
                db.query(Conversation)
                .filter(Conversation.session_id == session_id)
                .first()
            )
            if not conversation:
                return []
            newest_first = (
                db.query(Message)
                .filter(Message.conversation_id == conversation.id)
                .order_by(Message.created_at.desc())
                .limit(limit)
                .all()
            )
            # The query selected the newest `limit` rows; flip them back into
            # chronological order for the caller.
            return list(reversed(newest_first))
        finally:
            db.close()

    def get_user_preference(self, user_id: str) -> Optional[dict]:
        """Return the stored preference dict for a user.

        Args:
            user_id: User ID.

        Returns:
            The ``preferences`` value of the user's ``UserPreference`` row,
            or None if the user has no row.
        """
        db = self.get_session()
        try:
            pref = (
                db.query(UserPreference)
                .filter(UserPreference.user_id == user_id)
                .first()
            )
            return pref.preferences if pref else None
        finally:
            db.close()

    def set_user_preference(self, user_id: str, preferences: dict) -> bool:
        """Create or replace a user's preferences (best effort).

        Args:
            user_id: User ID.
            preferences: New preference dict. It replaces the stored value
                wholesale rather than being merged into it.

        Returns:
            True on success. Returns False if any exception occurred; the
            error is logged and the transaction rolled back.
        """
        db = self.get_session()
        try:
            pref = (
                db.query(UserPreference)
                .filter(UserPreference.user_id == user_id)
                .first()
            )
            if pref:
                pref.preferences = preferences
                # Naive UTC timestamp. NOTE(review): datetime.utcnow() is
                # deprecated since Python 3.12; switching to an aware value
                # must match the column's timezone handling.
                pref.updated_at = datetime.utcnow()
            else:
                pref = UserPreference(user_id=user_id, preferences=preferences)
                db.add(pref)
            db.commit()
            return True
        except Exception as e:
            logger.error(f"设置用户偏好失败: {e}")
            db.rollback()
            return False
        finally:
            db.close()
# Create the shared module-level instance.
# NOTE: this runs at import time. DatabaseService.__init__ reads the settings,
# builds the engine and calls Base.metadata.create_all, so importing this
# module touches the configured database.
db_service: DatabaseService = DatabaseService()