105 lines
3.4 KiB
Python
105 lines
3.4 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
|
||
from typing import Dict, Any, List, Optional
|
||
from datetime import datetime
|
||
|
||
from sqlalchemy import Column, Integer, String, Text, DateTime, Index
|
||
|
||
from cryptoai.models.base import Base, logger
|
||
|
||
# 定义AI Agent信息流模型
|
||
class AgentFeed(Base):
|
||
"""AI Agent信息流表模型"""
|
||
__tablename__ = 'agent_feeds'
|
||
|
||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||
agent_name = Column(String(50), nullable=False, comment='AI Agent名称')
|
||
avatar_url = Column(String(255), nullable=True, comment='头像URL')
|
||
content = Column(Text, nullable=False, comment='内容')
|
||
create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
|
||
|
||
# 索引和表属性
|
||
__table_args__ = (
|
||
Index('idx_agent_name', 'agent_name'),
|
||
Index('idx_create_time', 'create_time'),
|
||
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
|
||
)
|
||
|
||
class AgentFeedManager:
|
||
"""Agent信息流管理类"""
|
||
|
||
def __init__(self, db_session):
|
||
self.session = db_session
|
||
|
||
def save_agent_feed(self, agent_name: str, content: str, avatar_url: Optional[str] = None) -> bool:
|
||
"""
|
||
保存AI Agent信息流到数据库
|
||
|
||
Args:
|
||
agent_name: AI Agent名称
|
||
content: 内容
|
||
avatar_url: 头像URL,可选
|
||
|
||
Returns:
|
||
保存是否成功
|
||
"""
|
||
try:
|
||
# 创建新记录
|
||
new_feed = AgentFeed(
|
||
agent_name=agent_name,
|
||
content=content,
|
||
avatar_url=avatar_url
|
||
)
|
||
|
||
# 添加并提交
|
||
self.session.add(new_feed)
|
||
self.session.commit()
|
||
|
||
logger.info(f"成功保存 {agent_name} 的信息流")
|
||
return True
|
||
|
||
except Exception as e:
|
||
self.session.rollback()
|
||
logger.error(f"保存信息流失败: {e}")
|
||
return False
|
||
|
||
def get_agent_feeds(self, agent_name: Optional[str] = None, limit: int = 20, skip: int = 0) -> List[Dict[str, Any]]:
|
||
"""
|
||
获取AI Agent信息流
|
||
|
||
Args:
|
||
agent_name: 可选,指定获取特定Agent的信息流
|
||
limit: 返回的最大记录数,默认20条
|
||
skip: 跳过的记录数,默认0条
|
||
|
||
Returns:
|
||
信息流列表,如果查询失败则返回空列表
|
||
"""
|
||
try:
|
||
# 构建查询
|
||
query = self.session.query(AgentFeed)
|
||
|
||
# 如果指定了agent_name,则筛选
|
||
if agent_name:
|
||
query = query.filter(AgentFeed.agent_name == agent_name)
|
||
|
||
# 按创建时间降序排序并限制数量
|
||
results = query.order_by(AgentFeed.create_time.desc()).offset(skip).limit(limit).all()
|
||
|
||
# 转换为字典列表
|
||
feeds = []
|
||
for result in results:
|
||
feeds.append({
|
||
'id': result.id,
|
||
'agent_name': result.agent_name,
|
||
'avatar_url': result.avatar_url,
|
||
'content': result.content,
|
||
'create_time': result.create_time
|
||
})
|
||
|
||
return feeds
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取信息流失败: {e}")
|
||
return [] |