164 lines
5.4 KiB
Python
164 lines
5.4 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
|
||
from typing import Dict, Any, List, Optional
|
||
from datetime import datetime
|
||
from sqlalchemy.orm import Session
|
||
from sqlalchemy import Column, Integer, String, Text, DateTime, Index, ForeignKey
|
||
from sqlalchemy.orm import relationship
|
||
|
||
from cryptoai.models.base import Base, logger
|
||
from cryptoai.models.user import User
|
||
|
||
# 定义用户提问数据模型
|
||
class UserQuestion(Base):
|
||
"""用户提问数据表模型"""
|
||
__tablename__ = 'user_questions'
|
||
|
||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||
user_id = Column(Integer, ForeignKey('users.id'), nullable=False, comment='用户ID')
|
||
agent_id = Column(String(50), nullable=False, comment='AI Agent ID')
|
||
question = Column(Text, nullable=False, comment='提问内容')
|
||
create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
|
||
|
||
# 关系
|
||
user = relationship("User", back_populates="questions")
|
||
|
||
# 索引和表属性
|
||
__table_args__ = (
|
||
Index('idx_user_id', 'user_id'),
|
||
Index('idx_agent_id', 'agent_id'),
|
||
Index('idx_create_time', 'create_time'),
|
||
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
|
||
)
|
||
|
||
class UserQuestionManager:
|
||
"""用户提问管理类"""
|
||
|
||
def __init__(self, session: Session = None):
|
||
self.session = session
|
||
|
||
def save_user_question(self, user_id: int, agent_id: str, question: str) -> bool:
|
||
"""
|
||
保存用户提问数据
|
||
|
||
Args:
|
||
user_id: 用户ID
|
||
agent_id: AI Agent ID
|
||
question: 提问内容
|
||
|
||
Returns:
|
||
保存是否成功
|
||
"""
|
||
try:
|
||
# 创建新记录
|
||
new_question = UserQuestion(
|
||
user_id=user_id,
|
||
agent_id=agent_id,
|
||
question=question,
|
||
create_time=datetime.now()
|
||
)
|
||
|
||
# 添加并提交
|
||
self.session.add(new_question)
|
||
self.session.commit()
|
||
|
||
logger.info(f"成功保存用户 {user_id} 对 Agent {agent_id} 的提问")
|
||
return True
|
||
|
||
except Exception as e:
|
||
self.session.rollback()
|
||
logger.error(f"保存用户提问失败: {e}")
|
||
return False
|
||
|
||
def get_user_question_count(self) -> int:
|
||
"""
|
||
获取用户提问数量
|
||
|
||
Returns:
|
||
提问总数
|
||
"""
|
||
try:
|
||
# 查询用户提问数量
|
||
question_count = self.session.query(UserQuestion).count()
|
||
|
||
return question_count
|
||
except Exception as e:
|
||
logger.error(f"获取用户提问数量失败: {e}")
|
||
return 0
|
||
|
||
def get_user_questions(self, user_id: Optional[int] = None, agent_id: Optional[str] = None,
|
||
limit: int = 20, skip: int = 0) -> List[Dict[str, Any]]:
|
||
"""
|
||
获取用户提问数据
|
||
|
||
Args:
|
||
user_id: 可选,指定获取特定用户的提问
|
||
agent_id: 可选,指定获取特定Agent的提问
|
||
limit: 返回的最大记录数,默认20条
|
||
skip: 跳过的记录数,默认0条
|
||
|
||
Returns:
|
||
提问数据列表,如果查询失败则返回空列表
|
||
"""
|
||
try:
|
||
# 构建查询
|
||
query = self.session.query(UserQuestion)
|
||
|
||
# 如果指定了user_id,则筛选
|
||
if user_id:
|
||
query = query.filter(UserQuestion.user_id == user_id)
|
||
|
||
# 如果指定了agent_id,则筛选
|
||
if agent_id:
|
||
query = query.filter(UserQuestion.agent_id == agent_id)
|
||
|
||
# 按创建时间降序排序并限制数量
|
||
results = query.order_by(UserQuestion.create_time.desc()).offset(skip).limit(limit).all()
|
||
|
||
# 转换为字典列表
|
||
questions = []
|
||
for result in results:
|
||
questions.append({
|
||
'id': result.id,
|
||
'user_id': result.user_id,
|
||
'agent_id': result.agent_id,
|
||
'question': result.question,
|
||
'create_time': result.create_time
|
||
})
|
||
|
||
return questions
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取用户提问失败: {e}")
|
||
return []
|
||
|
||
def get_user_question_by_id(self, question_id: int) -> Optional[Dict[str, Any]]:
|
||
"""
|
||
通过ID获取用户提问数据
|
||
|
||
Args:
|
||
question_id: 提问ID
|
||
|
||
Returns:
|
||
提问数据,如果不存在则返回None
|
||
"""
|
||
try:
|
||
# 查询提问
|
||
result = self.session.query(UserQuestion).filter(UserQuestion.id == question_id).first()
|
||
|
||
if result:
|
||
# 转换为字典
|
||
return {
|
||
'id': result.id,
|
||
'user_id': result.user_id,
|
||
'agent_id': result.agent_id,
|
||
'question': result.question,
|
||
'create_time': result.create_time
|
||
}
|
||
else:
|
||
return None
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取用户提问失败: {e}")
|
||
return None |