update
This commit is contained in:
parent
64119d5391
commit
1debbc7dce
@ -285,7 +285,7 @@ class CryptoAgent:
|
||||
# 保存分析结果到数据库
|
||||
try:
|
||||
# 保存到数据库
|
||||
saved = self.db_manager.save_analysis_result(
|
||||
saved = self.db_manager.analysis_result_manager.save_analysis_result(
|
||||
agent='crypto',
|
||||
symbol=symbol,
|
||||
time_interval=self.time_interval,
|
||||
@ -363,7 +363,7 @@ class CryptoAgent:
|
||||
|
||||
# 保存交易建议到数据库
|
||||
try:
|
||||
saved = self.db_manager.save_agent_feed(
|
||||
saved = self.db_manager.agent_feed_manager.save_agent_feed(
|
||||
agent_name="Crypto Agent",
|
||||
content=message
|
||||
)
|
||||
|
||||
@ -1,4 +1,14 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""模型模块,提供数据处理和模型定义。"""
|
||||
"""模型模块,提供数据处理和模型定义。"""
|
||||
|
||||
from cryptoai.models.base import Base
|
||||
from cryptoai.models.token import Token, TokenManager
|
||||
from cryptoai.models.analysis_result import AnalysisResult, AnalysisResultManager
|
||||
from cryptoai.models.agent_feed import AgentFeed, AgentFeedManager
|
||||
from cryptoai.models.user import User, UserManager
|
||||
from cryptoai.models.user_question import UserQuestion, UserQuestionManager
|
||||
from cryptoai.models.agent import Agent, AgentManager
|
||||
from cryptoai.models.astock import AStock, AStockManager
|
||||
from cryptoai.models.analysis_history import AnalysisHistory, AnalysisHistoryManager
|
||||
Binary file not shown.
268
cryptoai/models/agent.py
Normal file
268
cryptoai/models/agent.py
Normal file
@ -0,0 +1,268 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, Index
|
||||
from sqlalchemy.dialects.mysql import JSON
|
||||
|
||||
from cryptoai.models.base import Base, logger
|
||||
|
||||
# 定义Agent数据模型
|
||||
class Agent(Base):
|
||||
"""Agent数据表模型"""
|
||||
__tablename__ = 'agents'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
name = Column(String(100), nullable=False, unique=True, comment='Agent名称')
|
||||
hello_prompt = Column(Text, nullable=True, comment='欢迎提示语')
|
||||
description = Column(Text, nullable=True, comment='Agent描述')
|
||||
dify_token = Column(String(255), nullable=True, comment='Dify API令牌')
|
||||
inputs = Column(JSON, nullable=True, comment='输入参数JSON')
|
||||
create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
|
||||
update_time = Column(DateTime, nullable=False, default=datetime.now, onupdate=datetime.now, comment='更新时间')
|
||||
|
||||
# 索引和表属性
|
||||
__table_args__ = (
|
||||
Index('idx_name', 'name'),
|
||||
Index('idx_create_time', 'create_time'),
|
||||
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
|
||||
)
|
||||
|
||||
class AgentManager:
|
||||
"""Agent管理类"""
|
||||
|
||||
def __init__(self, db_session):
|
||||
self.session = db_session
|
||||
|
||||
def create_agent(self, name: str, description: Optional[str] = None,
|
||||
hello_prompt: Optional[str] = None, dify_token: Optional[str] = None,
|
||||
inputs: Optional[Dict[str, Any]] = None) -> Optional[int]:
|
||||
"""
|
||||
创建新的Agent
|
||||
|
||||
Args:
|
||||
name: Agent名称
|
||||
description: Agent描述
|
||||
hello_prompt: 欢迎提示语
|
||||
dify_token: Dify API令牌
|
||||
inputs: 输入参数JSON
|
||||
|
||||
Returns:
|
||||
新Agent的ID,如果创建失败则返回None
|
||||
"""
|
||||
try:
|
||||
# 检查名称是否已存在
|
||||
existing_agent = self.session.query(Agent).filter(Agent.name == name).first()
|
||||
if existing_agent:
|
||||
logger.warning(f"Agent名称 {name} 已存在")
|
||||
return None
|
||||
|
||||
# 创建新Agent
|
||||
new_agent = Agent(
|
||||
name=name,
|
||||
description=description,
|
||||
hello_prompt=hello_prompt,
|
||||
dify_token=dify_token,
|
||||
inputs=inputs,
|
||||
create_time=datetime.now(),
|
||||
update_time=datetime.now()
|
||||
)
|
||||
|
||||
# 添加并提交
|
||||
self.session.add(new_agent)
|
||||
self.session.commit()
|
||||
|
||||
logger.info(f"成功创建Agent: {name}")
|
||||
return new_agent.id
|
||||
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
logger.error(f"创建Agent失败: {e}")
|
||||
return None
|
||||
|
||||
def update_agent(self, agent_id: int, name: Optional[str] = None,
|
||||
description: Optional[str] = None, hello_prompt: Optional[str] = None,
|
||||
dify_token: Optional[str] = None, inputs: Optional[Dict[str, Any]] = None) -> bool:
|
||||
"""
|
||||
更新Agent信息
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
name: Agent名称
|
||||
description: Agent描述
|
||||
hello_prompt: 欢迎提示语
|
||||
dify_token: Dify API令牌
|
||||
inputs: 输入参数JSON
|
||||
|
||||
Returns:
|
||||
更新是否成功
|
||||
"""
|
||||
try:
|
||||
# 查询Agent
|
||||
agent = self.session.query(Agent).filter(Agent.id == agent_id).first()
|
||||
if not agent:
|
||||
logger.warning(f"Agent ID {agent_id} 不存在")
|
||||
return False
|
||||
|
||||
# 更新字段
|
||||
if name is not None:
|
||||
# 检查名称是否与其他Agent冲突
|
||||
if name != agent.name:
|
||||
existing = self.session.query(Agent).filter(Agent.name == name).first()
|
||||
if existing:
|
||||
logger.warning(f"Agent名称 {name} 已被其他Agent使用")
|
||||
return False
|
||||
agent.name = name
|
||||
|
||||
if description is not None:
|
||||
agent.description = description
|
||||
|
||||
if hello_prompt is not None:
|
||||
agent.hello_prompt = hello_prompt
|
||||
|
||||
if dify_token is not None:
|
||||
agent.dify_token = dify_token
|
||||
|
||||
if inputs is not None:
|
||||
agent.inputs = inputs
|
||||
|
||||
agent.update_time = datetime.now()
|
||||
|
||||
# 提交更改
|
||||
self.session.commit()
|
||||
|
||||
logger.info(f"成功更新Agent ID: {agent_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
logger.error(f"更新Agent失败: {e}")
|
||||
return False
|
||||
|
||||
def get_agent_by_id(self, agent_id: int) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
通过ID获取Agent信息
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
|
||||
Returns:
|
||||
Agent信息,如果不存在则返回None
|
||||
"""
|
||||
try:
|
||||
# 查询Agent
|
||||
agent = self.session.query(Agent).filter(Agent.id == agent_id).first()
|
||||
|
||||
if agent:
|
||||
# 转换为字典
|
||||
return {
|
||||
'id': agent.id,
|
||||
'name': agent.name,
|
||||
'description': agent.description,
|
||||
'hello_prompt': agent.hello_prompt,
|
||||
'dify_token': agent.dify_token,
|
||||
'inputs': agent.inputs,
|
||||
'create_time': agent.create_time,
|
||||
'update_time': agent.update_time
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取Agent信息失败: {e}")
|
||||
return None
|
||||
|
||||
def get_agent_by_name(self, name: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
通过名称获取Agent信息
|
||||
|
||||
Args:
|
||||
name: Agent名称
|
||||
|
||||
Returns:
|
||||
Agent信息,如果不存在则返回None
|
||||
"""
|
||||
try:
|
||||
# 查询Agent
|
||||
agent = self.session.query(Agent).filter(Agent.name == name).first()
|
||||
|
||||
if agent:
|
||||
# 转换为字典
|
||||
return {
|
||||
'id': agent.id,
|
||||
'name': agent.name,
|
||||
'description': agent.description,
|
||||
'hello_prompt': agent.hello_prompt,
|
||||
'dify_token': agent.dify_token,
|
||||
'inputs': agent.inputs,
|
||||
'create_time': agent.create_time,
|
||||
'update_time': agent.update_time
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取Agent信息失败: {e}")
|
||||
return None
|
||||
|
||||
def list_agents(self, limit: int = 100, skip: int = 0) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取Agent列表
|
||||
|
||||
Args:
|
||||
limit: 返回的最大数量
|
||||
skip: 跳过的数量
|
||||
|
||||
Returns:
|
||||
Agent列表
|
||||
"""
|
||||
try:
|
||||
# 查询Agent列表
|
||||
agents = self.session.query(Agent).order_by(Agent.create_time.asc()).offset(skip).limit(limit).all()
|
||||
|
||||
# 转换为字典列表
|
||||
return [{
|
||||
'id': agent.id,
|
||||
'name': agent.name,
|
||||
'description': agent.description,
|
||||
'hello_prompt': agent.hello_prompt,
|
||||
'dify_token': agent.dify_token,
|
||||
'inputs': agent.inputs,
|
||||
'create_time': agent.create_time,
|
||||
'update_time': agent.update_time
|
||||
} for agent in agents]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取Agent列表失败: {e}")
|
||||
return []
|
||||
|
||||
def delete_agent(self, agent_id: int) -> bool:
|
||||
"""
|
||||
删除Agent
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
|
||||
Returns:
|
||||
删除是否成功
|
||||
"""
|
||||
try:
|
||||
# 查询Agent
|
||||
agent = self.session.query(Agent).filter(Agent.id == agent_id).first()
|
||||
if not agent:
|
||||
logger.warning(f"Agent ID {agent_id} 不存在")
|
||||
return False
|
||||
|
||||
# 删除Agent
|
||||
self.session.delete(agent)
|
||||
self.session.commit()
|
||||
|
||||
logger.info(f"成功删除Agent ID: {agent_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
logger.error(f"删除Agent失败: {e}")
|
||||
return False
|
||||
105
cryptoai/models/agent_feed.py
Normal file
105
cryptoai/models/agent_feed.py
Normal file
@ -0,0 +1,105 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, Index
|
||||
|
||||
from cryptoai.models.base import Base, logger
|
||||
|
||||
# 定义AI Agent信息流模型
|
||||
class AgentFeed(Base):
|
||||
"""AI Agent信息流表模型"""
|
||||
__tablename__ = 'agent_feeds'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
agent_name = Column(String(50), nullable=False, comment='AI Agent名称')
|
||||
avatar_url = Column(String(255), nullable=True, comment='头像URL')
|
||||
content = Column(Text, nullable=False, comment='内容')
|
||||
create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
|
||||
|
||||
# 索引和表属性
|
||||
__table_args__ = (
|
||||
Index('idx_agent_name', 'agent_name'),
|
||||
Index('idx_create_time', 'create_time'),
|
||||
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
|
||||
)
|
||||
|
||||
class AgentFeedManager:
|
||||
"""Agent信息流管理类"""
|
||||
|
||||
def __init__(self, db_session):
|
||||
self.session = db_session
|
||||
|
||||
def save_agent_feed(self, agent_name: str, content: str, avatar_url: Optional[str] = None) -> bool:
|
||||
"""
|
||||
保存AI Agent信息流到数据库
|
||||
|
||||
Args:
|
||||
agent_name: AI Agent名称
|
||||
content: 内容
|
||||
avatar_url: 头像URL,可选
|
||||
|
||||
Returns:
|
||||
保存是否成功
|
||||
"""
|
||||
try:
|
||||
# 创建新记录
|
||||
new_feed = AgentFeed(
|
||||
agent_name=agent_name,
|
||||
content=content,
|
||||
avatar_url=avatar_url
|
||||
)
|
||||
|
||||
# 添加并提交
|
||||
self.session.add(new_feed)
|
||||
self.session.commit()
|
||||
|
||||
logger.info(f"成功保存 {agent_name} 的信息流")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
logger.error(f"保存信息流失败: {e}")
|
||||
return False
|
||||
|
||||
def get_agent_feeds(self, agent_name: Optional[str] = None, limit: int = 20, skip: int = 0) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取AI Agent信息流
|
||||
|
||||
Args:
|
||||
agent_name: 可选,指定获取特定Agent的信息流
|
||||
limit: 返回的最大记录数,默认20条
|
||||
skip: 跳过的记录数,默认0条
|
||||
|
||||
Returns:
|
||||
信息流列表,如果查询失败则返回空列表
|
||||
"""
|
||||
try:
|
||||
# 构建查询
|
||||
query = self.session.query(AgentFeed)
|
||||
|
||||
# 如果指定了agent_name,则筛选
|
||||
if agent_name:
|
||||
query = query.filter(AgentFeed.agent_name == agent_name)
|
||||
|
||||
# 按创建时间降序排序并限制数量
|
||||
results = query.order_by(AgentFeed.create_time.desc()).offset(skip).limit(limit).all()
|
||||
|
||||
# 转换为字典列表
|
||||
feeds = []
|
||||
for result in results:
|
||||
feeds.append({
|
||||
'id': result.id,
|
||||
'agent_name': result.agent_name,
|
||||
'avatar_url': result.avatar_url,
|
||||
'content': result.content,
|
||||
'create_time': result.create_time
|
||||
})
|
||||
|
||||
return feeds
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取信息流失败: {e}")
|
||||
return []
|
||||
189
cryptoai/models/analysis_history.py
Normal file
189
cryptoai/models/analysis_history.py
Normal file
@ -0,0 +1,189 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, Index, ForeignKey
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from cryptoai.models.base import Base, logger
|
||||
|
||||
# 定义分析历史模型
|
||||
class AnalysisHistory(Base):
|
||||
"""分析历史表模型"""
|
||||
__tablename__ = 'analysis_history'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
user_id = Column(Integer, ForeignKey('users.id'), nullable=False, comment='用户ID')
|
||||
type = Column(String(20), nullable=False, comment='分析类型(crypto, astock)')
|
||||
symbol = Column(String(50), nullable=False, comment='交易符号')
|
||||
timeframe = Column(String(20), nullable=True, comment='时间框架')
|
||||
content = Column(Text, nullable=False, comment='分析内容')
|
||||
create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
|
||||
|
||||
# 关系
|
||||
user = relationship("User", back_populates="analysis_histories")
|
||||
|
||||
# 索引和表属性
|
||||
__table_args__ = (
|
||||
Index('idx_user_id', 'user_id'),
|
||||
Index('idx_type', 'type'),
|
||||
Index('idx_symbol', 'symbol'),
|
||||
Index('idx_create_time', 'create_time'),
|
||||
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
|
||||
)
|
||||
|
||||
class AnalysisHistoryManager:
|
||||
"""分析历史管理类"""
|
||||
|
||||
def __init__(self, db_session):
|
||||
self.session = db_session
|
||||
|
||||
def add_analysis_history(self, user_id: int, type: str, symbol: str,
|
||||
content: str, timeframe: str = None) -> bool:
|
||||
"""
|
||||
添加分析历史记录
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
type: 分析类型(crypto, astock)
|
||||
symbol: 交易符号
|
||||
content: 分析内容
|
||||
timeframe: 时间框架(可选)
|
||||
|
||||
Returns:
|
||||
添加是否成功
|
||||
"""
|
||||
try:
|
||||
# 创建新记录
|
||||
new_history = AnalysisHistory(
|
||||
user_id=user_id,
|
||||
type=type,
|
||||
symbol=symbol,
|
||||
timeframe=timeframe,
|
||||
content=content,
|
||||
create_time=datetime.now()
|
||||
)
|
||||
|
||||
# 添加并提交
|
||||
self.session.add(new_history)
|
||||
self.session.commit()
|
||||
|
||||
logger.info(f"成功添加分析历史记录,用户ID: {user_id}, 类型: {type}, 交易符号: {symbol}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
logger.error(f"添加分析历史记录失败: {e}")
|
||||
return False
|
||||
|
||||
def delete_analysis_history(self, history_id: int) -> bool:
|
||||
"""
|
||||
删除分析历史记录
|
||||
|
||||
Args:
|
||||
history_id: 历史记录ID
|
||||
|
||||
Returns:
|
||||
删除是否成功
|
||||
"""
|
||||
try:
|
||||
# 查询记录
|
||||
history = self.session.query(AnalysisHistory).filter(AnalysisHistory.id == history_id).first()
|
||||
|
||||
if not history:
|
||||
logger.warning(f"分析历史记录ID {history_id} 不存在")
|
||||
return False
|
||||
|
||||
# 删除记录
|
||||
self.session.delete(history)
|
||||
self.session.commit()
|
||||
|
||||
logger.info(f"成功删除分析历史记录,ID: {history_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
logger.error(f"删除分析历史记录失败: {e}")
|
||||
return False
|
||||
|
||||
def get_analysis_history_by_id(self, history_id: int) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
根据ID获取分析历史记录
|
||||
|
||||
Args:
|
||||
history_id: 历史记录ID
|
||||
|
||||
Returns:
|
||||
分析历史记录,如果不存在则返回None
|
||||
"""
|
||||
try:
|
||||
# 查询记录
|
||||
history = self.session.query(AnalysisHistory).filter(AnalysisHistory.id == history_id).first()
|
||||
|
||||
if history:
|
||||
# 转换为字典
|
||||
return {
|
||||
'id': history.id,
|
||||
'user_id': history.user_id,
|
||||
'type': history.type,
|
||||
'symbol': history.symbol,
|
||||
'timeframe': history.timeframe,
|
||||
'content': history.content,
|
||||
'create_time': history.create_time
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取分析历史记录失败: {e}")
|
||||
return None
|
||||
|
||||
def get_user_analysis_history(self, user_id: int, type: str = None,
|
||||
symbol: str = None, limit: int = 10,
|
||||
offset: int = 0) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取用户的分析历史记录
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
type: 分析类型筛选(可选)
|
||||
symbol: 交易符号筛选(可选)
|
||||
limit: 返回记录数量限制
|
||||
offset: 分页偏移量
|
||||
|
||||
Returns:
|
||||
分析历史记录列表
|
||||
"""
|
||||
try:
|
||||
# 构建查询
|
||||
query = self.session.query(AnalysisHistory).filter(AnalysisHistory.user_id == user_id)
|
||||
|
||||
# 添加可选过滤条件
|
||||
if type:
|
||||
query = query.filter(AnalysisHistory.type == type)
|
||||
if symbol:
|
||||
query = query.filter(AnalysisHistory.symbol == symbol)
|
||||
|
||||
# 按创建时间降序排序,并应用分页
|
||||
results = query.order_by(AnalysisHistory.create_time.desc()).limit(limit).offset(offset).all()
|
||||
|
||||
# 转换为字典列表
|
||||
history_list = []
|
||||
for history in results:
|
||||
history_list.append({
|
||||
'id': history.id,
|
||||
'user_id': history.user_id,
|
||||
'type': history.type,
|
||||
'symbol': history.symbol,
|
||||
'timeframe': history.timeframe,
|
||||
'content': history.content,
|
||||
'create_time': history.create_time
|
||||
})
|
||||
|
||||
return history_list
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户分析历史记录失败: {e}")
|
||||
return []
|
||||
111
cryptoai/models/analysis_result.py
Normal file
111
cryptoai/models/analysis_result.py
Normal file
@ -0,0 +1,111 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Index
|
||||
from sqlalchemy.dialects.mysql import JSON
|
||||
|
||||
from cryptoai.models.base import Base, logger
|
||||
|
||||
# 定义分析结果模型
|
||||
class AnalysisResult(Base):
|
||||
"""分析结果表模型"""
|
||||
__tablename__ = 'analysis_results'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
agent = Column(String(50), nullable=False, comment='智能体类型(crypto, gold)')
|
||||
symbol = Column(String(50), nullable=False, comment='交易对符号')
|
||||
time_interval = Column(String(20), nullable=False, comment='时间间隔')
|
||||
completion_result = Column(JSON, nullable=False, comment='分析结果JSON')
|
||||
created_at = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
|
||||
updated_at = Column(DateTime, nullable=False, default=datetime.now, onupdate=datetime.now, comment='更新时间')
|
||||
|
||||
# 索引
|
||||
__table_args__ = (
|
||||
Index('idx_agent', 'agent'),
|
||||
Index('idx_symbol', 'symbol'),
|
||||
Index('idx_time_interval', 'time_interval'),
|
||||
Index('idx_created_at', 'created_at'),
|
||||
)
|
||||
|
||||
class AnalysisResultManager:
|
||||
"""分析结果管理类"""
|
||||
|
||||
def __init__(self, db_session):
|
||||
self.session = db_session
|
||||
|
||||
def save_analysis_result(self, agent: str, symbol: str, time_interval: str,
|
||||
analysis_result: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
保存分析结果到数据库
|
||||
|
||||
Args:
|
||||
agent: 智能体类型,例如 'crypto' 或 'gold'
|
||||
symbol: 交易对符号,例如 'BTCUSDT'
|
||||
time_interval: 时间间隔,例如 '1h', '4h', '1d'
|
||||
analysis_result: 分析结果数据
|
||||
|
||||
Returns:
|
||||
保存是否成功
|
||||
"""
|
||||
try:
|
||||
# 创建新记录
|
||||
new_result = AnalysisResult(
|
||||
agent=agent,
|
||||
symbol=symbol,
|
||||
time_interval=time_interval,
|
||||
completion_result=analysis_result,
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now()
|
||||
)
|
||||
|
||||
# 添加并提交
|
||||
self.session.add(new_result)
|
||||
self.session.commit()
|
||||
|
||||
logger.info(f"成功保存 {agent} 分析结果,交易对: {symbol}, 时间间隔: {time_interval}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
logger.error(f"保存分析结果失败: {e}")
|
||||
return False
|
||||
|
||||
def get_latest_result(self, agent: str, symbol: str, time_interval: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取最新的分析结果
|
||||
|
||||
Args:
|
||||
agent: 智能体类型,例如 'crypto' 或 'gold'
|
||||
symbol: 交易对符号,例如 'BTCUSDT'
|
||||
time_interval: 时间间隔,例如 '1h', '4h', '1d'
|
||||
|
||||
Returns:
|
||||
最新分析结果,如果查询失败则返回None
|
||||
"""
|
||||
try:
|
||||
# 查询最新的结果
|
||||
result = self.session.query(AnalysisResult).filter(
|
||||
AnalysisResult.agent == agent,
|
||||
AnalysisResult.symbol == symbol,
|
||||
AnalysisResult.time_interval == time_interval
|
||||
).order_by(AnalysisResult.created_at.desc()).first()
|
||||
|
||||
if result:
|
||||
# 转换为字典
|
||||
return {
|
||||
'id': result.id,
|
||||
'agent': result.agent,
|
||||
'symbol': result.symbol,
|
||||
'time_interval': result.time_interval,
|
||||
'completion_result': result.completion_result,
|
||||
'created_at': result.created_at
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取最新分析结果失败: {e}")
|
||||
return None
|
||||
314
cryptoai/models/astock.py
Normal file
314
cryptoai/models/astock.py
Normal file
@ -0,0 +1,314 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing import Dict, Any, List, Optional, Union
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Index
|
||||
|
||||
from cryptoai.models.base import Base, logger
|
||||
from cryptoai.utils.db_utils import convert_timestamp_to_datetime
|
||||
|
||||
# 定义 A 股数据模型
|
||||
class AStock(Base):
|
||||
"""A股股票基本信息表模型"""
|
||||
__tablename__ = 'astock'
|
||||
|
||||
stock_code = Column(String(10), primary_key=True, comment='股票代码')
|
||||
short_name = Column(String(50), nullable=False, comment='股票简称')
|
||||
exchange = Column(String(20), nullable=True, comment='交易所')
|
||||
list_date = Column(DateTime, nullable=True, comment='上市日期')
|
||||
created_at = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
|
||||
|
||||
# 索引
|
||||
__table_args__ = (
|
||||
Index('idx_stock_code', 'stock_code', unique=True),
|
||||
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
|
||||
)
|
||||
|
||||
class AStockManager:
|
||||
"""A股管理类"""
|
||||
|
||||
def __init__(self, db_session):
|
||||
self.session = db_session
|
||||
|
||||
def create_stocks(self, stocks: List[Dict[str, Any]]) -> bool:
|
||||
"""
|
||||
批量创建股票信息
|
||||
|
||||
Args:
|
||||
stocks: 股票信息列表,每个元素为包含以下键的字典:
|
||||
- stock_code: 股票代码
|
||||
- short_name: 股票简称
|
||||
- exchange: 交易所(可选)
|
||||
- list_date: 上市日期(可选,Unix时间戳,毫秒级)
|
||||
|
||||
Returns:
|
||||
创建是否成功
|
||||
"""
|
||||
try:
|
||||
for stock_data in stocks:
|
||||
# 检查股票代码是否已存在
|
||||
existing_stock = self.session.query(AStock).filter(
|
||||
AStock.stock_code == stock_data['stock_code']
|
||||
).first()
|
||||
|
||||
if existing_stock:
|
||||
logger.warning(f"股票代码 {stock_data['stock_code']} 已存在,跳过")
|
||||
continue
|
||||
|
||||
# 创建新股票记录
|
||||
new_stock = AStock(
|
||||
stock_code=stock_data['stock_code'],
|
||||
short_name=stock_data['short_name'],
|
||||
exchange=stock_data.get('exchange')
|
||||
)
|
||||
|
||||
# 处理上市日期
|
||||
if 'list_date' in stock_data and stock_data['list_date']:
|
||||
list_date = convert_timestamp_to_datetime(stock_data['list_date'])
|
||||
if list_date:
|
||||
new_stock.list_date = list_date
|
||||
|
||||
self.session.add(new_stock)
|
||||
|
||||
# 批量提交
|
||||
self.session.commit()
|
||||
logger.info(f"成功批量创建股票信息,共 {len(stocks)} 条记录")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
logger.error(f"批量创建股票信息失败: {e}")
|
||||
return False
|
||||
|
||||
def create_stock(self, stock_code: str, short_name: str,
|
||||
exchange: Optional[str] = None,
|
||||
list_date: Optional[Union[int, float, str]] = None) -> bool:
|
||||
"""
|
||||
创建新的股票信息
|
||||
|
||||
Args:
|
||||
stock_code: 股票代码
|
||||
short_name: 股票简称
|
||||
exchange: 交易所(可选)
|
||||
list_date: 上市日期(可选,Unix时间戳,毫秒级)
|
||||
|
||||
Returns:
|
||||
创建是否成功
|
||||
"""
|
||||
try:
|
||||
# 检查股票代码是否已存在
|
||||
existing_stock = self.session.query(AStock).filter(AStock.stock_code == stock_code).first()
|
||||
if existing_stock:
|
||||
logger.warning(f"股票代码 {stock_code} 已存在")
|
||||
return False
|
||||
|
||||
# 创建新股票记录
|
||||
new_stock = AStock(
|
||||
stock_code=stock_code,
|
||||
short_name=short_name,
|
||||
exchange=exchange
|
||||
)
|
||||
|
||||
# 处理上市日期
|
||||
if list_date:
|
||||
converted_date = convert_timestamp_to_datetime(list_date)
|
||||
if converted_date:
|
||||
new_stock.list_date = converted_date
|
||||
|
||||
self.session.add(new_stock)
|
||||
self.session.commit()
|
||||
|
||||
logger.info(f"成功创建股票信息: {stock_code} - {short_name}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
logger.error(f"创建股票信息失败: {e}")
|
||||
return False
|
||||
|
||||
def update_stock(self, stock_code: str, short_name: Optional[str] = None,
|
||||
exchange: Optional[str] = None,
|
||||
list_date: Optional[Union[int, float, str, datetime]] = None) -> bool:
|
||||
"""
|
||||
更新股票信息
|
||||
|
||||
Args:
|
||||
stock_code: 股票代码
|
||||
short_name: 股票简称
|
||||
exchange: 交易所
|
||||
list_date: 上市日期(可以是datetime对象或时间戳)
|
||||
|
||||
Returns:
|
||||
更新是否成功
|
||||
"""
|
||||
try:
|
||||
# 查询股票
|
||||
stock = self.session.query(AStock).filter(AStock.stock_code == stock_code).first()
|
||||
if not stock:
|
||||
logger.warning(f"股票代码 {stock_code} 不存在")
|
||||
return False
|
||||
|
||||
# 更新字段
|
||||
if short_name is not None:
|
||||
stock.short_name = short_name
|
||||
if exchange is not None:
|
||||
stock.exchange = exchange
|
||||
if list_date is not None:
|
||||
if isinstance(list_date, datetime):
|
||||
stock.list_date = list_date
|
||||
else:
|
||||
converted_date = convert_timestamp_to_datetime(list_date)
|
||||
if converted_date:
|
||||
stock.list_date = converted_date
|
||||
|
||||
self.session.commit()
|
||||
|
||||
logger.info(f"成功更新股票信息: {stock_code}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
logger.error(f"更新股票信息失败: {e}")
|
||||
return False
|
||||
|
||||
def delete_stock(self, stock_code: str) -> bool:
|
||||
"""
|
||||
删除股票信息
|
||||
|
||||
Args:
|
||||
stock_code: 股票代码
|
||||
|
||||
Returns:
|
||||
删除是否成功
|
||||
"""
|
||||
try:
|
||||
# 查询股票
|
||||
stock = self.session.query(AStock).filter(AStock.stock_code == stock_code).first()
|
||||
if not stock:
|
||||
logger.warning(f"股票代码 {stock_code} 不存在")
|
||||
return False
|
||||
|
||||
# 删除股票
|
||||
self.session.delete(stock)
|
||||
self.session.commit()
|
||||
|
||||
logger.info(f"成功删除股票信息: {stock_code}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
logger.error(f"删除股票信息失败: {e}")
|
||||
return False
|
||||
|
||||
def get_stock_by_code(self, stock_code: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
通过股票代码获取股票信息
|
||||
|
||||
Args:
|
||||
stock_code: 股票代码
|
||||
|
||||
Returns:
|
||||
股票信息,如果不存在则返回None
|
||||
"""
|
||||
try:
|
||||
# 查询股票
|
||||
stock = self.session.query(AStock).filter(AStock.stock_code == stock_code).first()
|
||||
|
||||
if stock:
|
||||
return {
|
||||
'stock_code': stock.stock_code,
|
||||
'short_name': stock.short_name,
|
||||
'exchange': stock.exchange,
|
||||
'list_date': stock.list_date,
|
||||
'created_at': stock.created_at
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取股票信息失败: {e}")
|
||||
return None
|
||||
|
||||
def search_stock(self, key: str, limit: int = 10) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
搜索股票
|
||||
|
||||
Args:
|
||||
key: 搜索关键词
|
||||
limit: 最大返回数量
|
||||
|
||||
Returns:
|
||||
股票信息列表
|
||||
"""
|
||||
try:
|
||||
# 查询股票
|
||||
stocks = self.session.query(AStock).filter(
|
||||
AStock.short_name.like(f"{key}%") | AStock.stock_code.like(f"{key}%")
|
||||
).limit(limit).all()
|
||||
|
||||
return [{
|
||||
'stock_code': stock.stock_code,
|
||||
'short_name': stock.short_name,
|
||||
'exchange': stock.exchange,
|
||||
'list_date': stock.list_date,
|
||||
'created_at': stock.created_at
|
||||
} for stock in stocks]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取股票信息失败: {e}")
|
||||
return []
|
||||
|
||||
def get_stock_by_name(self, short_name: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
通过股票简称获取股票信息(可能有多个结果)
|
||||
|
||||
Args:
|
||||
short_name: 股票简称
|
||||
|
||||
Returns:
|
||||
股票信息列表
|
||||
"""
|
||||
try:
|
||||
# 查询股票
|
||||
stocks = self.session.query(AStock).filter(AStock.short_name == short_name).all()
|
||||
|
||||
return [{
|
||||
'stock_code': stock.stock_code,
|
||||
'short_name': stock.short_name,
|
||||
'exchange': stock.exchange,
|
||||
'list_date': stock.list_date,
|
||||
'created_at': stock.created_at
|
||||
} for stock in stocks]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取股票信息失败: {e}")
|
||||
return []
|
||||
|
||||
def list_stocks(self, limit: int = 100, skip: int = 0) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取股票列表
|
||||
|
||||
Args:
|
||||
limit: 返回的最大数量
|
||||
skip: 跳过的数量
|
||||
|
||||
Returns:
|
||||
股票信息列表
|
||||
"""
|
||||
try:
|
||||
# 查询股票列表
|
||||
stocks = self.session.query(AStock).order_by(AStock.stock_code).offset(skip).limit(limit).all()
|
||||
|
||||
return [{
|
||||
'stock_code': stock.stock_code,
|
||||
'short_name': stock.short_name,
|
||||
'exchange': stock.exchange,
|
||||
'list_date': stock.list_date,
|
||||
'created_at': stock.created_at
|
||||
} for stock in stocks]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取股票列表失败: {e}")
|
||||
return []
|
||||
20
cryptoai/models/base.py
Normal file
20
cryptoai/models/base.py
Normal file
@ -0,0 +1,20 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, Index, text, ForeignKey
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker, relationship
|
||||
from sqlalchemy.dialects.mysql import JSON
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger('db_models')
|
||||
|
||||
# 创建模型基类
|
||||
Base = declarative_base()
|
||||
127
cryptoai/models/token.py
Normal file
127
cryptoai/models/token.py
Normal file
@ -0,0 +1,127 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Index
|
||||
|
||||
from cryptoai.models.base import Base, logger
|
||||
|
||||
# 定义Token模型
|
||||
class Token(Base):
|
||||
"""Token信息表模型"""
|
||||
__tablename__ = 'tokens'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
symbol = Column(String(50), nullable=False, unique=True, comment='交易对符号')
|
||||
base_asset = Column(String(20), nullable=False, comment='基础资产')
|
||||
quote_asset = Column(String(20), nullable=False, comment='计价资产')
|
||||
created_at = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
|
||||
|
||||
# 索引
|
||||
__table_args__ = (
|
||||
Index('idx_symbol', 'symbol'),
|
||||
Index('idx_base_asset', 'base_asset'),
|
||||
Index('idx_quote_asset', 'quote_asset'),
|
||||
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
|
||||
)
|
||||
|
||||
class TokenManager:
|
||||
"""Token管理类"""
|
||||
|
||||
def __init__(self, db_session):
|
||||
self.session = db_session
|
||||
|
||||
def create_token(self, symbol: str, base_asset: str, quote_asset: str) -> bool:
|
||||
"""
|
||||
创建新的Token信息
|
||||
|
||||
Args:
|
||||
symbol: 交易对符号
|
||||
base_asset: 基础资产
|
||||
quote_asset: 计价资产
|
||||
|
||||
Returns:
|
||||
创建是否成功
|
||||
"""
|
||||
try:
|
||||
# 检查交易对是否已存在
|
||||
existing_token = self.session.query(Token).filter(Token.symbol == symbol).first()
|
||||
if existing_token:
|
||||
logger.warning(f"交易对 {symbol} 已存在")
|
||||
return False
|
||||
|
||||
# 创建新Token记录
|
||||
new_token = Token(
|
||||
symbol=symbol,
|
||||
base_asset=base_asset,
|
||||
quote_asset=quote_asset
|
||||
)
|
||||
|
||||
self.session.add(new_token)
|
||||
self.session.commit()
|
||||
|
||||
logger.info(f"成功创建Token信息: {symbol}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
logger.error(f"创建Token信息失败: {e}")
|
||||
return False
|
||||
|
||||
def delete_token(self, symbol: str) -> bool:
|
||||
"""
|
||||
删除Token信息
|
||||
|
||||
Args:
|
||||
symbol: 交易对符号
|
||||
|
||||
Returns:
|
||||
删除是否成功
|
||||
"""
|
||||
try:
|
||||
# 查询Token
|
||||
token = self.session.query(Token).filter(Token.symbol == symbol).first()
|
||||
if not token:
|
||||
logger.warning(f"交易对 {symbol} 不存在")
|
||||
return False
|
||||
|
||||
# 删除Token
|
||||
self.session.delete(token)
|
||||
self.session.commit()
|
||||
|
||||
logger.info(f"成功删除Token信息: {symbol}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
logger.error(f"删除Token信息失败: {e}")
|
||||
return False
|
||||
|
||||
def search_token(self, key: str, limit: int = 10) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
搜索Token
|
||||
|
||||
Args:
|
||||
key: 搜索关键词
|
||||
limit: 最大返回数量
|
||||
|
||||
Returns:
|
||||
Token信息列表
|
||||
"""
|
||||
try:
|
||||
# 使用 SQLAlchemy 的 ORM 查询
|
||||
tokens = self.session.query(Token).filter(
|
||||
Token.symbol.like(f"{key}%") | Token.base_asset.like(f"{key}%")
|
||||
).limit(limit).all()
|
||||
|
||||
return [{
|
||||
'symbol': token.symbol,
|
||||
'base_asset': token.base_asset,
|
||||
'quote_asset': token.quote_asset
|
||||
} for token in tokens]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"搜索Token信息失败: {e}")
|
||||
return []
|
||||
280
cryptoai/models/user.py
Normal file
280
cryptoai/models/user.py
Normal file
@ -0,0 +1,280 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Index
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from cryptoai.models.base import Base, logger
|
||||
|
||||
# 定义用户数据模型
|
||||
class User(Base):
|
||||
"""用户数据表模型"""
|
||||
__tablename__ = 'users'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
mail = Column(String(100), nullable=False, unique=True, comment='邮箱')
|
||||
nickname = Column(String(50), nullable=False, comment='昵称')
|
||||
password = Column(String(100), nullable=False, comment='密码')
|
||||
level = Column(Integer, nullable=False, default=0, comment='用户级别(0=普通用户,1=VIP,2=SVIP)')
|
||||
points = Column(Integer, nullable=False, default=0, comment='用户积分')
|
||||
create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
|
||||
|
||||
# 关系
|
||||
questions = relationship("UserQuestion", back_populates="user")
|
||||
analysis_histories = relationship("AnalysisHistory", back_populates="user")
|
||||
|
||||
# 索引和表属性
|
||||
__table_args__ = (
|
||||
Index('idx_mail', 'mail'),
|
||||
Index('idx_level', 'level'),
|
||||
Index('idx_create_time', 'create_time'),
|
||||
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
|
||||
)
|
||||
|
||||
class UserManager:
|
||||
"""用户管理类"""
|
||||
|
||||
def __init__(self, db_session):
|
||||
self.session = db_session
|
||||
|
||||
def register_user(self, mail: str, nickname: str, password: str, level: int = 0, points: int = 0) -> bool:
|
||||
"""
|
||||
注册新用户
|
||||
|
||||
Args:
|
||||
mail: 邮箱
|
||||
nickname: 昵称
|
||||
password: 密码
|
||||
level: 用户级别,默认为0(普通用户)
|
||||
points: 初始积分,默认为0
|
||||
|
||||
Returns:
|
||||
注册是否成功
|
||||
"""
|
||||
try:
|
||||
# 检查邮箱是否已存在
|
||||
existing_user = self.session.query(User).filter(User.mail == mail).first()
|
||||
if existing_user:
|
||||
logger.warning(f"邮箱 {mail} 已被注册")
|
||||
return False
|
||||
|
||||
# 创建新用户
|
||||
new_user = User(
|
||||
mail=mail,
|
||||
nickname=nickname,
|
||||
password=password, # 实际应用中应该对密码进行哈希处理
|
||||
level=level,
|
||||
points=points,
|
||||
create_time=datetime.now()
|
||||
)
|
||||
|
||||
# 添加并提交
|
||||
self.session.add(new_user)
|
||||
self.session.commit()
|
||||
|
||||
logger.info(f"成功注册用户: {mail},初始积分: {points}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
logger.error(f"注册用户失败: {e}")
|
||||
return False
|
||||
|
||||
def get_user_count(self) -> int:
|
||||
"""
|
||||
获取用户数量
|
||||
"""
|
||||
try:
|
||||
# 查询用户数量
|
||||
user_count = self.session.query(User).count()
|
||||
|
||||
return user_count
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户数量失败: {e}")
|
||||
return 0
|
||||
|
||||
def get_user_by_id(self, user_id: int) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
通过ID获取用户信息
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
用户信息,如果用户不存在则返回None
|
||||
"""
|
||||
try:
|
||||
# 查询用户
|
||||
user = self.session.query(User).filter(User.id == user_id).first()
|
||||
|
||||
if user:
|
||||
# 转换为字典
|
||||
return {
|
||||
'id': user.id,
|
||||
'mail': user.mail,
|
||||
'nickname': user.nickname,
|
||||
'level': user.level,
|
||||
'points': user.points,
|
||||
'create_time': user.create_time
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户信息失败: {e}")
|
||||
return None
|
||||
|
||||
def login(self, mail: str, password: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
登录
|
||||
"""
|
||||
|
||||
user = self.session.query(User).filter(User.mail == mail).first()
|
||||
|
||||
if not user:
|
||||
return None
|
||||
|
||||
if user.password != password:
|
||||
return None
|
||||
|
||||
return {'id': user.id, 'mail': user.mail, 'nickname': user.nickname, 'level': user.level, 'points': user.points, 'create_time': user.create_time}
|
||||
|
||||
def get_user_by_mail(self, mail: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
通过邮箱获取用户信息
|
||||
|
||||
Args:
|
||||
mail: 邮箱
|
||||
|
||||
Returns:
|
||||
用户信息,如果用户不存在则返回None
|
||||
"""
|
||||
try:
|
||||
# 查询用户
|
||||
user = self.session.query(User).filter(User.mail == mail).first()
|
||||
|
||||
if user:
|
||||
# 转换为字典
|
||||
return {
|
||||
'id': user.id,
|
||||
'mail': user.mail,
|
||||
'nickname': user.nickname,
|
||||
'level': user.level,
|
||||
'points': user.points,
|
||||
'create_time': user.create_time
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户信息失败: {e}")
|
||||
return None
|
||||
|
||||
def update_user_level(self, user_id: int, level: int) -> bool:
|
||||
"""
|
||||
更新用户级别
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
level: 新的用户级别
|
||||
|
||||
Returns:
|
||||
更新是否成功
|
||||
"""
|
||||
try:
|
||||
# 查询用户
|
||||
user = self.session.query(User).filter(User.id == user_id).first()
|
||||
|
||||
if not user:
|
||||
logger.warning(f"用户ID {user_id} 不存在")
|
||||
return False
|
||||
|
||||
# 更新级别
|
||||
user.level = level
|
||||
self.session.commit()
|
||||
|
||||
logger.info(f"成功更新用户 {user.mail} 的级别为 {level}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
logger.error(f"更新用户级别失败: {e}")
|
||||
return False
|
||||
|
||||
def add_user_points(self, user_id: int, points: int) -> bool:
|
||||
"""
|
||||
为用户增加积分
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
points: 增加的积分数量(正数)
|
||||
|
||||
Returns:
|
||||
操作是否成功
|
||||
"""
|
||||
if points <= 0:
|
||||
logger.warning(f"增加的积分必须是正数: {points}")
|
||||
return False
|
||||
|
||||
try:
|
||||
# 查询用户
|
||||
user = self.session.query(User).filter(User.id == user_id).first()
|
||||
|
||||
if not user:
|
||||
logger.warning(f"用户ID {user_id} 不存在")
|
||||
return False
|
||||
|
||||
# 增加积分
|
||||
user.points += points
|
||||
self.session.commit()
|
||||
|
||||
logger.info(f"成功为用户 {user.mail} 增加 {points} 积分,当前积分: {user.points}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
logger.error(f"增加用户积分失败: {e}")
|
||||
return False
|
||||
|
||||
def consume_user_points(self, user_id: int, points: int) -> bool:
|
||||
"""
|
||||
用户消费积分
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
points: 消费的积分数量(正数)
|
||||
|
||||
Returns:
|
||||
操作是否成功
|
||||
"""
|
||||
if points <= 0:
|
||||
logger.warning(f"消费的积分必须是正数: {points}")
|
||||
return False
|
||||
|
||||
try:
|
||||
# 查询用户
|
||||
user = self.session.query(User).filter(User.id == user_id).first()
|
||||
|
||||
if not user:
|
||||
logger.warning(f"用户ID {user_id} 不存在")
|
||||
return False
|
||||
|
||||
# 检查积分是否足够
|
||||
if user.points < points:
|
||||
logger.warning(f"用户 {user.mail} 积分不足,当前积分: {user.points},需要消费: {points}")
|
||||
return False
|
||||
|
||||
# 消费积分
|
||||
user.points -= points
|
||||
self.session.commit()
|
||||
|
||||
logger.info(f"成功从用户 {user.mail} 消费 {points} 积分,剩余积分: {user.points}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
logger.error(f"消费用户积分失败: {e}")
|
||||
return False
|
||||
164
cryptoai/models/user_question.py
Normal file
164
cryptoai/models/user_question.py
Normal file
@ -0,0 +1,164 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, Index, ForeignKey
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from cryptoai.models.base import Base, logger
|
||||
from cryptoai.models.user import User
|
||||
|
||||
# 定义用户提问数据模型
|
||||
class UserQuestion(Base):
|
||||
"""用户提问数据表模型"""
|
||||
__tablename__ = 'user_questions'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
user_id = Column(Integer, ForeignKey('users.id'), nullable=False, comment='用户ID')
|
||||
agent_id = Column(String(50), nullable=False, comment='AI Agent ID')
|
||||
question = Column(Text, nullable=False, comment='提问内容')
|
||||
create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
|
||||
|
||||
# 关系
|
||||
user = relationship("User", back_populates="questions")
|
||||
|
||||
# 索引和表属性
|
||||
__table_args__ = (
|
||||
Index('idx_user_id', 'user_id'),
|
||||
Index('idx_agent_id', 'agent_id'),
|
||||
Index('idx_create_time', 'create_time'),
|
||||
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
|
||||
)
|
||||
|
||||
class UserQuestionManager:
|
||||
"""用户提问管理类"""
|
||||
|
||||
def __init__(self, db_session):
|
||||
self.session = db_session
|
||||
|
||||
def save_user_question(self, user_id: int, agent_id: str, question: str) -> bool:
|
||||
"""
|
||||
保存用户提问数据
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
agent_id: AI Agent ID
|
||||
question: 提问内容
|
||||
|
||||
Returns:
|
||||
保存是否成功
|
||||
"""
|
||||
try:
|
||||
# 创建新记录
|
||||
new_question = UserQuestion(
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
question=question,
|
||||
create_time=datetime.now()
|
||||
)
|
||||
|
||||
# 添加并提交
|
||||
self.session.add(new_question)
|
||||
self.session.commit()
|
||||
|
||||
logger.info(f"成功保存用户 {user_id} 对 Agent {agent_id} 的提问")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
logger.error(f"保存用户提问失败: {e}")
|
||||
return False
|
||||
|
||||
def get_user_question_count(self) -> int:
|
||||
"""
|
||||
获取用户提问数量
|
||||
|
||||
Returns:
|
||||
提问总数
|
||||
"""
|
||||
try:
|
||||
# 查询用户提问数量
|
||||
question_count = self.session.query(UserQuestion).count()
|
||||
|
||||
return question_count
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户提问数量失败: {e}")
|
||||
return 0
|
||||
|
||||
def get_user_questions(self, user_id: Optional[int] = None, agent_id: Optional[str] = None,
|
||||
limit: int = 20, skip: int = 0) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取用户提问数据
|
||||
|
||||
Args:
|
||||
user_id: 可选,指定获取特定用户的提问
|
||||
agent_id: 可选,指定获取特定Agent的提问
|
||||
limit: 返回的最大记录数,默认20条
|
||||
skip: 跳过的记录数,默认0条
|
||||
|
||||
Returns:
|
||||
提问数据列表,如果查询失败则返回空列表
|
||||
"""
|
||||
try:
|
||||
# 构建查询
|
||||
query = self.session.query(UserQuestion)
|
||||
|
||||
# 如果指定了user_id,则筛选
|
||||
if user_id:
|
||||
query = query.filter(UserQuestion.user_id == user_id)
|
||||
|
||||
# 如果指定了agent_id,则筛选
|
||||
if agent_id:
|
||||
query = query.filter(UserQuestion.agent_id == agent_id)
|
||||
|
||||
# 按创建时间降序排序并限制数量
|
||||
results = query.order_by(UserQuestion.create_time.desc()).offset(skip).limit(limit).all()
|
||||
|
||||
# 转换为字典列表
|
||||
questions = []
|
||||
for result in results:
|
||||
questions.append({
|
||||
'id': result.id,
|
||||
'user_id': result.user_id,
|
||||
'agent_id': result.agent_id,
|
||||
'question': result.question,
|
||||
'create_time': result.create_time
|
||||
})
|
||||
|
||||
return questions
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户提问失败: {e}")
|
||||
return []
|
||||
|
||||
def get_user_question_by_id(self, question_id: int) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
通过ID获取用户提问数据
|
||||
|
||||
Args:
|
||||
question_id: 提问ID
|
||||
|
||||
Returns:
|
||||
提问数据,如果不存在则返回None
|
||||
"""
|
||||
try:
|
||||
# 查询提问
|
||||
result = self.session.query(UserQuestion).filter(UserQuestion.id == question_id).first()
|
||||
|
||||
if result:
|
||||
# 转换为字典
|
||||
return {
|
||||
'id': result.id,
|
||||
'user_id': result.user_id,
|
||||
'agent_id': result.agent_id,
|
||||
'question': result.question,
|
||||
'create_time': result.create_time
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户提问失败: {e}")
|
||||
return None
|
||||
@ -22,7 +22,7 @@ logger.setLevel(logging.DEBUG)
|
||||
@router.get("/stock/search")
|
||||
async def search_stock(key: str, limit: int = 10):
|
||||
manager = get_db_manager()
|
||||
result = manager.search_stock(key, limit)
|
||||
result = manager.astock_manager.search_stock(key, limit)
|
||||
|
||||
return result
|
||||
|
||||
@ -123,16 +123,9 @@ async def get_stock_data_all(stock_code: str):
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.post('/{stock_code}/analysis', summary="获取股票分析数据")
|
||||
async def get_stock_analysis(stock_code: str, current_user: Dict[str, Any] = Depends(get_current_user)):
|
||||
|
||||
# 检查stock_code是否存在
|
||||
# codes = get_db_manager().search_stock(stock_code)
|
||||
# if not codes or len(codes) == 0:
|
||||
# raise HTTPException(status_code=400, detail="您输入的股票代码不存在,请检查后重新输入。")
|
||||
|
||||
# stock_code = codes[0]["stock_code"]
|
||||
|
||||
url = 'https://mate.aimateplus.com/v1/workflows/run'
|
||||
token = 'app-nWuCOa0YfQVtAosTY3Jr5vFV'
|
||||
@ -150,7 +143,7 @@ async def get_stock_analysis(stock_code: str, current_user: Dict[str, Any] = Dep
|
||||
}
|
||||
|
||||
# 保存用户提问
|
||||
get_db_manager().save_user_question(current_user["id"], stock_code, "请分析以下股票:" + stock_code + ",并给出分析报告。")
|
||||
get_db_manager().user_question_manager.save_user_question(current_user["id"], stock_code, "请分析以下股票:" + stock_code + ",并给出分析报告。")
|
||||
|
||||
response = requests.post(url, headers=headers, json=data, stream=True)
|
||||
|
||||
|
||||
@ -49,7 +49,7 @@ async def create_agent(agent: AgentCreate):
|
||||
创建新的AI Agent
|
||||
"""
|
||||
|
||||
agent = get_db_manager().create_agent(agent.name, agent.description, agent.hello_prompt, agent.dify_token, agent.inputs)
|
||||
agent = get_db_manager().agent_manager.create_agent(agent.name, agent.description, agent.hello_prompt, agent.dify_token, agent.inputs)
|
||||
|
||||
return agent
|
||||
|
||||
@ -64,7 +64,7 @@ async def get_agents(
|
||||
"""
|
||||
|
||||
# 从数据库获取Agent列表
|
||||
agents = get_db_manager().list_agents(limit=limit, skip=skip)
|
||||
agents = get_db_manager().agent_manager.list_agents(limit=limit, skip=skip)
|
||||
return agents
|
||||
|
||||
|
||||
@ -76,7 +76,7 @@ async def chat(request: ChatRequest, current_user: Dict[str, Any] = Depends(get_
|
||||
# 尝试从数据库获取Agent
|
||||
try:
|
||||
agent_id = int(request.agent_id)
|
||||
agent = get_db_manager().get_agent_by_id(agent_id)
|
||||
agent = get_db_manager().agent_manager.get_agent_by_id(agent_id)
|
||||
|
||||
if not agent:
|
||||
raise HTTPException(status_code=400, detail="Invalid agent ID")
|
||||
@ -106,7 +106,7 @@ async def chat(request: ChatRequest, current_user: Dict[str, Any] = Depends(get_
|
||||
logging.info(f"Chat request data: {data}")
|
||||
|
||||
# 保存用户提问
|
||||
get_db_manager().save_user_question(current_user["id"], request.agent_id, request.user_prompt)
|
||||
get_db_manager().user_question_manager.save_user_question(current_user["id"], request.agent_id, request.user_prompt)
|
||||
|
||||
response = requests.post(url, headers=headers, json=data, stream=True)
|
||||
|
||||
|
||||
30
cryptoai/routes/analysis.py
Normal file
30
cryptoai/routes/analysis.py
Normal file
@ -0,0 +1,30 @@
|
||||
from fastapi import APIRouter
|
||||
from typing import Optional
|
||||
from cryptoai.utils.db_manager import get_db_manager
|
||||
from fastapi import Depends
|
||||
from pydantic import BaseModel
|
||||
from cryptoai.routes.user import get_current_user
|
||||
|
||||
|
||||
class AnalysisHistoryRequest(BaseModel):
|
||||
symbol: str
|
||||
content: str
|
||||
timeframe: Optional[str] = None
|
||||
type: str
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/analysis_history")
|
||||
async def analysis_history(request: AnalysisHistoryRequest,
|
||||
current_user: dict = Depends(get_current_user)):
|
||||
|
||||
get_db_manager().analysis_history_manager.add_analysis_history(current_user["id"], request.type, request.symbol, request.content, request.timeframe)
|
||||
|
||||
return {"message": "ok"}
|
||||
|
||||
@router.get("/analysis_histories")
|
||||
async def get_analysis_histories(current_user: dict = Depends(get_current_user),
|
||||
limit: int = 10,
|
||||
offset: int = 0):
|
||||
history = get_db_manager().analysis_history_manager.get_user_analysis_history(current_user["id"], limit=limit, offset=offset)
|
||||
return history
|
||||
@ -34,7 +34,7 @@ class CryptoAnalysisRequest(BaseModel):
|
||||
@router.get("/kline/{symbol}")
|
||||
async def get_crypto_kline(symbol: str, timeframe: Optional[str] = None, limit: Optional[int] = 100):
|
||||
# 检查symbol是否存在
|
||||
tokens = get_db_manager().search_token(symbol)
|
||||
tokens = get_db_manager().token_manager.search_token(symbol)
|
||||
if not tokens or len(tokens) == 0:
|
||||
raise HTTPException(status_code=400, detail="您输入的币种在币安不存在,请检查后重新输入。")
|
||||
|
||||
@ -57,7 +57,7 @@ async def get_crypto_kline(symbol: str, timeframe: Optional[str] = None, limit:
|
||||
async def analysis_crypto_v2(request: CryptoAnalysisRequest,
|
||||
current_user: dict = Depends(get_current_user)):
|
||||
# 检查symbol是否存在
|
||||
tokens = get_db_manager().search_token(request.symbol)
|
||||
tokens = get_db_manager().token_manager.search_token(request.symbol)
|
||||
if not tokens or len(tokens) == 0:
|
||||
raise HTTPException(status_code=400, detail="您输入的币种在币安不存在,请检查后重新输入。")
|
||||
|
||||
@ -80,7 +80,7 @@ async def analysis_crypto_v2(request: CryptoAnalysisRequest,
|
||||
}
|
||||
|
||||
# 保存用户提问
|
||||
get_db_manager().save_user_question(current_user["id"], symbol, "请分析以下加密货币:" + symbol + ",并给出分析报告。")
|
||||
get_db_manager().user_question_manager.save_user_question(current_user["id"], symbol, "请分析以下加密货币:" + symbol + ",并给出分析报告。")
|
||||
|
||||
response = requests.post(url, headers=headers, json=data, stream=True)
|
||||
|
||||
@ -111,7 +111,7 @@ async def analysis_crypto(request: CryptoAnalysisRequest,
|
||||
else:
|
||||
user_prompt = f"请分析以下加密货币:{request.symbol},并给出分析报告。"
|
||||
|
||||
agent = get_db_manager().get_agent_by_id(agent_id)
|
||||
agent = get_db_manager().agent_manager.get_agent_by_id(agent_id)
|
||||
|
||||
if not agent:
|
||||
raise HTTPException(status_code=400, detail="Invalid agent ID")
|
||||
@ -138,7 +138,7 @@ async def analysis_crypto(request: CryptoAnalysisRequest,
|
||||
logging.info(f"Chat request data: {data}")
|
||||
|
||||
# 保存用户提问
|
||||
get_db_manager().save_user_question(current_user["id"], agent_id, user_prompt)
|
||||
get_db_manager().user_question_manager.save_user_question(current_user["id"], agent_id, user_prompt)
|
||||
|
||||
response = requests.post(url, headers=headers, json=data, stream=True)
|
||||
|
||||
|
||||
@ -22,6 +22,7 @@ from cryptoai.routes.adata import router as adata_router
|
||||
from cryptoai.routes.question import router as question_router
|
||||
from cryptoai.routes.crypto import router as crypto_router
|
||||
from cryptoai.routes.platform import router as platform_router
|
||||
from cryptoai.routes.analysis import router as analysis_router
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
@ -58,6 +59,8 @@ app.include_router(user_router, prefix="/user", tags=["用户管理"])
|
||||
app.include_router(question_router, prefix="/question", tags=["用户提问"])
|
||||
app.include_router(adata_router, prefix="/adata", tags=["A股数据"])
|
||||
app.include_router(crypto_router, prefix="/crypto", tags=["加密货币数据"])
|
||||
app.include_router(analysis_router, prefix="/analysis", tags=["分析历史"])
|
||||
|
||||
# 请求计时中间件
|
||||
@app.middleware("http")
|
||||
async def add_process_time_header(request: Request, call_next):
|
||||
|
||||
@ -52,7 +52,7 @@ async def create_feed(feed: AgentFeedCreate) -> Dict[str, Any]:
|
||||
db_manager = get_db_manager()
|
||||
|
||||
# 保存信息流
|
||||
success = db_manager.save_agent_feed(
|
||||
success = db_manager.agent_feed_manager.save_agent_feed(
|
||||
agent_name=feed.agent_name,
|
||||
content=feed.content,
|
||||
avatar_url=feed.avatar_url
|
||||
@ -98,7 +98,7 @@ async def get_feeds(
|
||||
db_manager = get_db_manager()
|
||||
|
||||
# 获取信息流
|
||||
feeds = db_manager.get_agent_feeds(agent_name=agent_name, limit=limit, skip=skip)
|
||||
feeds = db_manager.agent_feed_manager.get_agent_feeds(agent_name=agent_name, limit=limit, skip=skip)
|
||||
|
||||
return feeds
|
||||
|
||||
|
||||
@ -14,8 +14,8 @@ async def get_platform_info():
|
||||
result = {}
|
||||
|
||||
try:
|
||||
result["user_count"] = db_manager.get_user_count()
|
||||
result["question_count"] = db_manager.get_user_question_count()
|
||||
result["user_count"] = db_manager.user_manager.get_user_count()
|
||||
result["question_count"] = db_manager.user_question_manager.get_user_question_count()
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
|
||||
@ -55,7 +55,7 @@ async def create_question(
|
||||
db_manager = get_db_manager()
|
||||
|
||||
# 保存提问
|
||||
success = db_manager.save_user_question(
|
||||
success = db_manager.user_question_manager.save_user_question(
|
||||
user_id=current_user["id"],
|
||||
agent_id=question.agent_id,
|
||||
question=question.question
|
||||
@ -103,7 +103,7 @@ async def get_questions(
|
||||
db_manager = get_db_manager()
|
||||
|
||||
# 获取提问记录
|
||||
questions = db_manager.get_user_questions(
|
||||
questions = db_manager.user_question_manager.get_user_questions(
|
||||
user_id=current_user["id"],
|
||||
agent_id=agent_id,
|
||||
limit=limit,
|
||||
@ -139,7 +139,7 @@ async def get_question(
|
||||
db_manager = get_db_manager()
|
||||
|
||||
# 获取提问记录
|
||||
question = db_manager.get_user_question_by_id(question_id)
|
||||
question = db_manager.user_question_manager.get_user_question_by_id(question_id)
|
||||
|
||||
if not question:
|
||||
raise HTTPException(
|
||||
|
||||
@ -104,7 +104,7 @@ async def get_current_user(request: Request) -> Dict[str, Any]:
|
||||
raise credentials_exception
|
||||
|
||||
db_manager = get_db_manager()
|
||||
user = db_manager.get_user_by_mail(mail)
|
||||
user = db_manager.user_manager.get_user_by_mail(mail)
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
return user
|
||||
@ -125,7 +125,7 @@ async def send_verification_code(request: SendVerificationCodeRequest) -> Dict[s
|
||||
db_manager = get_db_manager()
|
||||
|
||||
# 检查邮箱是否已被注册
|
||||
user = db_manager.get_user_by_mail(request.mail)
|
||||
user = db_manager.user_manager.get_user_by_mail(request.mail)
|
||||
if user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
@ -231,7 +231,7 @@ async def login(loginData: UserLogin) -> TokenResponse:
|
||||
db_manager = get_db_manager()
|
||||
|
||||
# 获取用户信息
|
||||
user = db_manager.get_user_by_mail(loginData.mail)
|
||||
user = db_manager.user_manager.get_user_by_mail(loginData.mail)
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
@ -246,9 +246,8 @@ async def login(loginData: UserLogin) -> TokenResponse:
|
||||
# 查询用户的密码哈希
|
||||
session = db_manager.Session()
|
||||
try:
|
||||
from cryptoai.utils.db_manager import User
|
||||
db_user = session.query(User).filter(User.mail == loginData.mail).first()
|
||||
if not db_user or db_user.password != hashed_password:
|
||||
user = db_manager.user_manager.login(loginData.mail, hashed_password)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="邮箱或密码错误",
|
||||
@ -334,7 +333,7 @@ async def update_user_level(
|
||||
db_manager = get_db_manager()
|
||||
|
||||
# 更新用户级别
|
||||
success = db_manager.update_user_level(user_id, level)
|
||||
success = db_manager.user_manager.update_user_level(user_id, level)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
@ -373,7 +372,7 @@ async def get_user_points(
|
||||
db_manager = get_db_manager()
|
||||
|
||||
# 获取用户信息
|
||||
user = db_manager.get_user_by_id(user_id)
|
||||
user = db_manager.user_manager.get_user_by_id(user_id)
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
@ -416,7 +415,7 @@ async def add_user_points(
|
||||
db_manager = get_db_manager()
|
||||
|
||||
# 添加积分
|
||||
success = db_manager.add_user_points(user_id, points)
|
||||
success = db_manager.user_manager.add_user_points(user_id, points)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
@ -425,7 +424,7 @@ async def add_user_points(
|
||||
)
|
||||
|
||||
# 获取更新后的用户信息
|
||||
user = db_manager.get_user_by_id(user_id)
|
||||
user = db_manager.user_manager.get_user_by_id(user_id)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
@ -461,7 +460,7 @@ async def consume_user_points(
|
||||
db_manager = get_db_manager()
|
||||
|
||||
# 消费积分
|
||||
success = db_manager.consume_user_points(user_id, points)
|
||||
success = db_manager.user_manager.consume_user_points(user_id, points)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
@ -470,7 +469,7 @@ async def consume_user_points(
|
||||
)
|
||||
|
||||
# 获取更新后的用户信息
|
||||
user = db_manager.get_user_by_id(user_id)
|
||||
user = db_manager.user_manager.get_user_by_id(user_id)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
36
cryptoai/utils/db_utils.py
Normal file
36
cryptoai/utils/db_utils.py
Normal file
@ -0,0 +1,36 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import logging
|
||||
from typing import Optional, Union
|
||||
from datetime import datetime
|
||||
|
||||
# 配置日志
|
||||
logger = logging.getLogger('db_utils')
|
||||
|
||||
def convert_timestamp_to_datetime(timestamp: Union[int, float, str, None]) -> Optional[datetime]:
|
||||
"""
|
||||
将时间戳转换为datetime对象
|
||||
|
||||
Args:
|
||||
timestamp: Unix时间戳(毫秒级)
|
||||
|
||||
Returns:
|
||||
datetime对象或None
|
||||
"""
|
||||
if timestamp is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
# 转换为整数
|
||||
if isinstance(timestamp, str):
|
||||
timestamp = int(timestamp)
|
||||
|
||||
# 如果是毫秒级时间戳,转换为秒级
|
||||
if timestamp > 1e11: # 判断是否为毫秒级时间戳
|
||||
timestamp = timestamp / 1000
|
||||
|
||||
return datetime.fromtimestamp(timestamp)
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.error(f"时间戳转换失败: {timestamp}, 错误: {e}")
|
||||
return None
|
||||
@ -29,7 +29,7 @@ services:
|
||||
cryptoai-api:
|
||||
build: .
|
||||
container_name: cryptoai-api
|
||||
image: cryptoai-api:0.1.25
|
||||
image: cryptoai-api:0.1.26
|
||||
restart: always
|
||||
ports:
|
||||
- "8000:8000"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user