crypto.ai/cryptoai/models/user_question.py
2025-05-30 22:09:45 +08:00

164 lines
5.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
from typing import Dict, Any, List, Optional
from datetime import datetime
from sqlalchemy.orm import Session
from sqlalchemy import Column, Integer, String, Text, DateTime, Index, ForeignKey
from sqlalchemy.orm import relationship
from cryptoai.models.base import Base, logger
from cryptoai.models.user import User
# 定义用户提问数据模型
class UserQuestion(Base):
"""用户提问数据表模型"""
__tablename__ = 'user_questions'
id = Column(Integer, primary_key=True, autoincrement=True)
user_id = Column(Integer, ForeignKey('users.id'), nullable=False, comment='用户ID')
agent_id = Column(String(50), nullable=False, comment='AI Agent ID')
question = Column(Text, nullable=False, comment='提问内容')
create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
# 关系
user = relationship("User", back_populates="questions")
# 索引和表属性
__table_args__ = (
Index('idx_user_id', 'user_id'),
Index('idx_agent_id', 'agent_id'),
Index('idx_create_time', 'create_time'),
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
)
class UserQuestionManager:
"""用户提问管理类"""
def __init__(self, session: Session = None):
self.session = session
def save_user_question(self, user_id: int, agent_id: str, question: str) -> bool:
"""
保存用户提问数据
Args:
user_id: 用户ID
agent_id: AI Agent ID
question: 提问内容
Returns:
保存是否成功
"""
try:
# 创建新记录
new_question = UserQuestion(
user_id=user_id,
agent_id=agent_id,
question=question,
create_time=datetime.now()
)
# 添加并提交
self.session.add(new_question)
self.session.commit()
logger.info(f"成功保存用户 {user_id} 对 Agent {agent_id} 的提问")
return True
except Exception as e:
self.session.rollback()
logger.error(f"保存用户提问失败: {e}")
return False
def get_user_question_count(self) -> int:
"""
获取用户提问数量
Returns:
提问总数
"""
try:
# 查询用户提问数量
question_count = self.session.query(UserQuestion).count()
return question_count
except Exception as e:
logger.error(f"获取用户提问数量失败: {e}")
return 0
def get_user_questions(self, user_id: Optional[int] = None, agent_id: Optional[str] = None,
limit: int = 20, skip: int = 0) -> List[Dict[str, Any]]:
"""
获取用户提问数据
Args:
user_id: 可选,指定获取特定用户的提问
agent_id: 可选指定获取特定Agent的提问
limit: 返回的最大记录数默认20条
skip: 跳过的记录数默认0条
Returns:
提问数据列表,如果查询失败则返回空列表
"""
try:
# 构建查询
query = self.session.query(UserQuestion)
# 如果指定了user_id则筛选
if user_id:
query = query.filter(UserQuestion.user_id == user_id)
# 如果指定了agent_id则筛选
if agent_id:
query = query.filter(UserQuestion.agent_id == agent_id)
# 按创建时间降序排序并限制数量
results = query.order_by(UserQuestion.create_time.desc()).offset(skip).limit(limit).all()
# 转换为字典列表
questions = []
for result in results:
questions.append({
'id': result.id,
'user_id': result.user_id,
'agent_id': result.agent_id,
'question': result.question,
'create_time': result.create_time
})
return questions
except Exception as e:
logger.error(f"获取用户提问失败: {e}")
return []
def get_user_question_by_id(self, question_id: int) -> Optional[Dict[str, Any]]:
"""
通过ID获取用户提问数据
Args:
question_id: 提问ID
Returns:
提问数据如果不存在则返回None
"""
try:
# 查询提问
result = self.session.query(UserQuestion).filter(UserQuestion.id == question_id).first()
if result:
# 转换为字典
return {
'id': result.id,
'user_id': result.user_id,
'agent_id': result.agent_id,
'question': result.question,
'create_time': result.create_time
}
else:
return None
except Exception as e:
logger.error(f"获取用户提问失败: {e}")
return None