This commit is contained in:
aaron 2025-05-24 12:08:57 +08:00
parent 64119d5391
commit 1debbc7dce
25 changed files with 1764 additions and 1798 deletions

View File

@ -285,7 +285,7 @@ class CryptoAgent:
# 保存分析结果到数据库
try:
# 保存到数据库
saved = self.db_manager.save_analysis_result(
saved = self.db_manager.analysis_result_manager.save_analysis_result(
agent='crypto',
symbol=symbol,
time_interval=self.time_interval,
@ -363,7 +363,7 @@ class CryptoAgent:
# 保存交易建议到数据库
try:
saved = self.db_manager.save_agent_feed(
saved = self.db_manager.agent_feed_manager.save_agent_feed(
agent_name="Crypto Agent",
content=message
)

View File

@ -2,3 +2,13 @@
# -*- coding: utf-8 -*-
"""模型模块,提供数据处理和模型定义。"""
from cryptoai.models.base import Base
from cryptoai.models.token import Token, TokenManager
from cryptoai.models.analysis_result import AnalysisResult, AnalysisResultManager
from cryptoai.models.agent_feed import AgentFeed, AgentFeedManager
from cryptoai.models.user import User, UserManager
from cryptoai.models.user_question import UserQuestion, UserQuestionManager
from cryptoai.models.agent import Agent, AgentManager
from cryptoai.models.astock import AStock, AStockManager
from cryptoai.models.analysis_history import AnalysisHistory, AnalysisHistoryManager

268
cryptoai/models/agent.py Normal file
View File

@ -0,0 +1,268 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from typing import Dict, Any, List, Optional
from datetime import datetime
from sqlalchemy import Column, Integer, String, Text, DateTime, Index
from sqlalchemy.dialects.mysql import JSON
from cryptoai.models.base import Base, logger
# 定义Agent数据模型
class Agent(Base):
"""Agent数据表模型"""
__tablename__ = 'agents'
id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(String(100), nullable=False, unique=True, comment='Agent名称')
hello_prompt = Column(Text, nullable=True, comment='欢迎提示语')
description = Column(Text, nullable=True, comment='Agent描述')
dify_token = Column(String(255), nullable=True, comment='Dify API令牌')
inputs = Column(JSON, nullable=True, comment='输入参数JSON')
create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
update_time = Column(DateTime, nullable=False, default=datetime.now, onupdate=datetime.now, comment='更新时间')
# 索引和表属性
__table_args__ = (
Index('idx_name', 'name'),
Index('idx_create_time', 'create_time'),
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
)
class AgentManager:
"""Agent管理类"""
def __init__(self, db_session):
self.session = db_session
def create_agent(self, name: str, description: Optional[str] = None,
hello_prompt: Optional[str] = None, dify_token: Optional[str] = None,
inputs: Optional[Dict[str, Any]] = None) -> Optional[int]:
"""
创建新的Agent
Args:
name: Agent名称
description: Agent描述
hello_prompt: 欢迎提示语
dify_token: Dify API令牌
inputs: 输入参数JSON
Returns:
新Agent的ID如果创建失败则返回None
"""
try:
# 检查名称是否已存在
existing_agent = self.session.query(Agent).filter(Agent.name == name).first()
if existing_agent:
logger.warning(f"Agent名称 {name} 已存在")
return None
# 创建新Agent
new_agent = Agent(
name=name,
description=description,
hello_prompt=hello_prompt,
dify_token=dify_token,
inputs=inputs,
create_time=datetime.now(),
update_time=datetime.now()
)
# 添加并提交
self.session.add(new_agent)
self.session.commit()
logger.info(f"成功创建Agent: {name}")
return new_agent.id
except Exception as e:
self.session.rollback()
logger.error(f"创建Agent失败: {e}")
return None
def update_agent(self, agent_id: int, name: Optional[str] = None,
description: Optional[str] = None, hello_prompt: Optional[str] = None,
dify_token: Optional[str] = None, inputs: Optional[Dict[str, Any]] = None) -> bool:
"""
更新Agent信息
Args:
agent_id: Agent ID
name: Agent名称
description: Agent描述
hello_prompt: 欢迎提示语
dify_token: Dify API令牌
inputs: 输入参数JSON
Returns:
更新是否成功
"""
try:
# 查询Agent
agent = self.session.query(Agent).filter(Agent.id == agent_id).first()
if not agent:
logger.warning(f"Agent ID {agent_id} 不存在")
return False
# 更新字段
if name is not None:
# 检查名称是否与其他Agent冲突
if name != agent.name:
existing = self.session.query(Agent).filter(Agent.name == name).first()
if existing:
logger.warning(f"Agent名称 {name} 已被其他Agent使用")
return False
agent.name = name
if description is not None:
agent.description = description
if hello_prompt is not None:
agent.hello_prompt = hello_prompt
if dify_token is not None:
agent.dify_token = dify_token
if inputs is not None:
agent.inputs = inputs
agent.update_time = datetime.now()
# 提交更改
self.session.commit()
logger.info(f"成功更新Agent ID: {agent_id}")
return True
except Exception as e:
self.session.rollback()
logger.error(f"更新Agent失败: {e}")
return False
def get_agent_by_id(self, agent_id: int) -> Optional[Dict[str, Any]]:
"""
通过ID获取Agent信息
Args:
agent_id: Agent ID
Returns:
Agent信息如果不存在则返回None
"""
try:
# 查询Agent
agent = self.session.query(Agent).filter(Agent.id == agent_id).first()
if agent:
# 转换为字典
return {
'id': agent.id,
'name': agent.name,
'description': agent.description,
'hello_prompt': agent.hello_prompt,
'dify_token': agent.dify_token,
'inputs': agent.inputs,
'create_time': agent.create_time,
'update_time': agent.update_time
}
else:
return None
except Exception as e:
logger.error(f"获取Agent信息失败: {e}")
return None
def get_agent_by_name(self, name: str) -> Optional[Dict[str, Any]]:
"""
通过名称获取Agent信息
Args:
name: Agent名称
Returns:
Agent信息如果不存在则返回None
"""
try:
# 查询Agent
agent = self.session.query(Agent).filter(Agent.name == name).first()
if agent:
# 转换为字典
return {
'id': agent.id,
'name': agent.name,
'description': agent.description,
'hello_prompt': agent.hello_prompt,
'dify_token': agent.dify_token,
'inputs': agent.inputs,
'create_time': agent.create_time,
'update_time': agent.update_time
}
else:
return None
except Exception as e:
logger.error(f"获取Agent信息失败: {e}")
return None
def list_agents(self, limit: int = 100, skip: int = 0) -> List[Dict[str, Any]]:
"""
获取Agent列表
Args:
limit: 返回的最大数量
skip: 跳过的数量
Returns:
Agent列表
"""
try:
# 查询Agent列表
agents = self.session.query(Agent).order_by(Agent.create_time.asc()).offset(skip).limit(limit).all()
# 转换为字典列表
return [{
'id': agent.id,
'name': agent.name,
'description': agent.description,
'hello_prompt': agent.hello_prompt,
'dify_token': agent.dify_token,
'inputs': agent.inputs,
'create_time': agent.create_time,
'update_time': agent.update_time
} for agent in agents]
except Exception as e:
logger.error(f"获取Agent列表失败: {e}")
return []
def delete_agent(self, agent_id: int) -> bool:
"""
删除Agent
Args:
agent_id: Agent ID
Returns:
删除是否成功
"""
try:
# 查询Agent
agent = self.session.query(Agent).filter(Agent.id == agent_id).first()
if not agent:
logger.warning(f"Agent ID {agent_id} 不存在")
return False
# 删除Agent
self.session.delete(agent)
self.session.commit()
logger.info(f"成功删除Agent ID: {agent_id}")
return True
except Exception as e:
self.session.rollback()
logger.error(f"删除Agent失败: {e}")
return False

View File

@ -0,0 +1,105 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from typing import Dict, Any, List, Optional
from datetime import datetime
from sqlalchemy import Column, Integer, String, Text, DateTime, Index
from cryptoai.models.base import Base, logger
# 定义AI Agent信息流模型
class AgentFeed(Base):
"""AI Agent信息流表模型"""
__tablename__ = 'agent_feeds'
id = Column(Integer, primary_key=True, autoincrement=True)
agent_name = Column(String(50), nullable=False, comment='AI Agent名称')
avatar_url = Column(String(255), nullable=True, comment='头像URL')
content = Column(Text, nullable=False, comment='内容')
create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
# 索引和表属性
__table_args__ = (
Index('idx_agent_name', 'agent_name'),
Index('idx_create_time', 'create_time'),
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
)
class AgentFeedManager:
"""Agent信息流管理类"""
def __init__(self, db_session):
self.session = db_session
def save_agent_feed(self, agent_name: str, content: str, avatar_url: Optional[str] = None) -> bool:
"""
保存AI Agent信息流到数据库
Args:
agent_name: AI Agent名称
content: 内容
avatar_url: 头像URL可选
Returns:
保存是否成功
"""
try:
# 创建新记录
new_feed = AgentFeed(
agent_name=agent_name,
content=content,
avatar_url=avatar_url
)
# 添加并提交
self.session.add(new_feed)
self.session.commit()
logger.info(f"成功保存 {agent_name} 的信息流")
return True
except Exception as e:
self.session.rollback()
logger.error(f"保存信息流失败: {e}")
return False
def get_agent_feeds(self, agent_name: Optional[str] = None, limit: int = 20, skip: int = 0) -> List[Dict[str, Any]]:
"""
获取AI Agent信息流
Args:
agent_name: 可选指定获取特定Agent的信息流
limit: 返回的最大记录数默认20条
skip: 跳过的记录数默认0条
Returns:
信息流列表如果查询失败则返回空列表
"""
try:
# 构建查询
query = self.session.query(AgentFeed)
# 如果指定了agent_name则筛选
if agent_name:
query = query.filter(AgentFeed.agent_name == agent_name)
# 按创建时间降序排序并限制数量
results = query.order_by(AgentFeed.create_time.desc()).offset(skip).limit(limit).all()
# 转换为字典列表
feeds = []
for result in results:
feeds.append({
'id': result.id,
'agent_name': result.agent_name,
'avatar_url': result.avatar_url,
'content': result.content,
'create_time': result.create_time
})
return feeds
except Exception as e:
logger.error(f"获取信息流失败: {e}")
return []

View File

@ -0,0 +1,189 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from typing import Dict, Any, List, Optional
from datetime import datetime
from sqlalchemy import Column, Integer, String, Text, DateTime, Index, ForeignKey
from sqlalchemy.orm import relationship
from cryptoai.models.base import Base, logger
# 定义分析历史模型
class AnalysisHistory(Base):
"""分析历史表模型"""
__tablename__ = 'analysis_history'
id = Column(Integer, primary_key=True, autoincrement=True)
user_id = Column(Integer, ForeignKey('users.id'), nullable=False, comment='用户ID')
type = Column(String(20), nullable=False, comment='分析类型(crypto, astock)')
symbol = Column(String(50), nullable=False, comment='交易符号')
timeframe = Column(String(20), nullable=True, comment='时间框架')
content = Column(Text, nullable=False, comment='分析内容')
create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
# 关系
user = relationship("User", back_populates="analysis_histories")
# 索引和表属性
__table_args__ = (
Index('idx_user_id', 'user_id'),
Index('idx_type', 'type'),
Index('idx_symbol', 'symbol'),
Index('idx_create_time', 'create_time'),
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
)
class AnalysisHistoryManager:
"""分析历史管理类"""
def __init__(self, db_session):
self.session = db_session
def add_analysis_history(self, user_id: int, type: str, symbol: str,
content: str, timeframe: str = None) -> bool:
"""
添加分析历史记录
Args:
user_id: 用户ID
type: 分析类型(crypto, astock)
symbol: 交易符号
content: 分析内容
timeframe: 时间框架可选
Returns:
添加是否成功
"""
try:
# 创建新记录
new_history = AnalysisHistory(
user_id=user_id,
type=type,
symbol=symbol,
timeframe=timeframe,
content=content,
create_time=datetime.now()
)
# 添加并提交
self.session.add(new_history)
self.session.commit()
logger.info(f"成功添加分析历史记录用户ID: {user_id}, 类型: {type}, 交易符号: {symbol}")
return True
except Exception as e:
self.session.rollback()
logger.error(f"添加分析历史记录失败: {e}")
return False
def delete_analysis_history(self, history_id: int) -> bool:
"""
删除分析历史记录
Args:
history_id: 历史记录ID
Returns:
删除是否成功
"""
try:
# 查询记录
history = self.session.query(AnalysisHistory).filter(AnalysisHistory.id == history_id).first()
if not history:
logger.warning(f"分析历史记录ID {history_id} 不存在")
return False
# 删除记录
self.session.delete(history)
self.session.commit()
logger.info(f"成功删除分析历史记录ID: {history_id}")
return True
except Exception as e:
self.session.rollback()
logger.error(f"删除分析历史记录失败: {e}")
return False
def get_analysis_history_by_id(self, history_id: int) -> Optional[Dict[str, Any]]:
"""
根据ID获取分析历史记录
Args:
history_id: 历史记录ID
Returns:
分析历史记录如果不存在则返回None
"""
try:
# 查询记录
history = self.session.query(AnalysisHistory).filter(AnalysisHistory.id == history_id).first()
if history:
# 转换为字典
return {
'id': history.id,
'user_id': history.user_id,
'type': history.type,
'symbol': history.symbol,
'timeframe': history.timeframe,
'content': history.content,
'create_time': history.create_time
}
else:
return None
except Exception as e:
logger.error(f"获取分析历史记录失败: {e}")
return None
def get_user_analysis_history(self, user_id: int, type: str = None,
symbol: str = None, limit: int = 10,
offset: int = 0) -> List[Dict[str, Any]]:
"""
获取用户的分析历史记录
Args:
user_id: 用户ID
type: 分析类型筛选可选
symbol: 交易符号筛选可选
limit: 返回记录数量限制
offset: 分页偏移量
Returns:
分析历史记录列表
"""
try:
# 构建查询
query = self.session.query(AnalysisHistory).filter(AnalysisHistory.user_id == user_id)
# 添加可选过滤条件
if type:
query = query.filter(AnalysisHistory.type == type)
if symbol:
query = query.filter(AnalysisHistory.symbol == symbol)
# 按创建时间降序排序,并应用分页
results = query.order_by(AnalysisHistory.create_time.desc()).limit(limit).offset(offset).all()
# 转换为字典列表
history_list = []
for history in results:
history_list.append({
'id': history.id,
'user_id': history.user_id,
'type': history.type,
'symbol': history.symbol,
'timeframe': history.timeframe,
'content': history.content,
'create_time': history.create_time
})
return history_list
except Exception as e:
logger.error(f"获取用户分析历史记录失败: {e}")
return []

View File

@ -0,0 +1,111 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from typing import Dict, Any, List, Optional
from datetime import datetime
from sqlalchemy import Column, Integer, String, DateTime, Index
from sqlalchemy.dialects.mysql import JSON
from cryptoai.models.base import Base, logger
# 定义分析结果模型
class AnalysisResult(Base):
"""分析结果表模型"""
__tablename__ = 'analysis_results'
id = Column(Integer, primary_key=True, autoincrement=True)
agent = Column(String(50), nullable=False, comment='智能体类型(crypto, gold)')
symbol = Column(String(50), nullable=False, comment='交易对符号')
time_interval = Column(String(20), nullable=False, comment='时间间隔')
completion_result = Column(JSON, nullable=False, comment='分析结果JSON')
created_at = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
updated_at = Column(DateTime, nullable=False, default=datetime.now, onupdate=datetime.now, comment='更新时间')
# 索引
__table_args__ = (
Index('idx_agent', 'agent'),
Index('idx_symbol', 'symbol'),
Index('idx_time_interval', 'time_interval'),
Index('idx_created_at', 'created_at'),
)
class AnalysisResultManager:
"""分析结果管理类"""
def __init__(self, db_session):
self.session = db_session
def save_analysis_result(self, agent: str, symbol: str, time_interval: str,
analysis_result: Dict[str, Any]) -> bool:
"""
保存分析结果到数据库
Args:
agent: 智能体类型例如 'crypto' 'gold'
symbol: 交易对符号例如 'BTCUSDT'
time_interval: 时间间隔例如 '1h', '4h', '1d'
analysis_result: 分析结果数据
Returns:
保存是否成功
"""
try:
# 创建新记录
new_result = AnalysisResult(
agent=agent,
symbol=symbol,
time_interval=time_interval,
completion_result=analysis_result,
created_at=datetime.now(),
updated_at=datetime.now()
)
# 添加并提交
self.session.add(new_result)
self.session.commit()
logger.info(f"成功保存 {agent} 分析结果,交易对: {symbol}, 时间间隔: {time_interval}")
return True
except Exception as e:
self.session.rollback()
logger.error(f"保存分析结果失败: {e}")
return False
def get_latest_result(self, agent: str, symbol: str, time_interval: str) -> Optional[Dict[str, Any]]:
"""
获取最新的分析结果
Args:
agent: 智能体类型例如 'crypto' 'gold'
symbol: 交易对符号例如 'BTCUSDT'
time_interval: 时间间隔例如 '1h', '4h', '1d'
Returns:
最新分析结果如果查询失败则返回None
"""
try:
# 查询最新的结果
result = self.session.query(AnalysisResult).filter(
AnalysisResult.agent == agent,
AnalysisResult.symbol == symbol,
AnalysisResult.time_interval == time_interval
).order_by(AnalysisResult.created_at.desc()).first()
if result:
# 转换为字典
return {
'id': result.id,
'agent': result.agent,
'symbol': result.symbol,
'time_interval': result.time_interval,
'completion_result': result.completion_result,
'created_at': result.created_at
}
else:
return None
except Exception as e:
logger.error(f"获取最新分析结果失败: {e}")
return None

314
cryptoai/models/astock.py Normal file
View File

@ -0,0 +1,314 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from typing import Dict, Any, List, Optional, Union
from datetime import datetime
from sqlalchemy import Column, Integer, String, DateTime, Index
from cryptoai.models.base import Base, logger
from cryptoai.utils.db_utils import convert_timestamp_to_datetime
# 定义 A 股数据模型
class AStock(Base):
"""A股股票基本信息表模型"""
__tablename__ = 'astock'
stock_code = Column(String(10), primary_key=True, comment='股票代码')
short_name = Column(String(50), nullable=False, comment='股票简称')
exchange = Column(String(20), nullable=True, comment='交易所')
list_date = Column(DateTime, nullable=True, comment='上市日期')
created_at = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
# 索引
__table_args__ = (
Index('idx_stock_code', 'stock_code', unique=True),
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
)
class AStockManager:
"""A股管理类"""
def __init__(self, db_session):
self.session = db_session
def create_stocks(self, stocks: List[Dict[str, Any]]) -> bool:
"""
批量创建股票信息
Args:
stocks: 股票信息列表每个元素为包含以下键的字典
- stock_code: 股票代码
- short_name: 股票简称
- exchange: 交易所可选
- list_date: 上市日期可选Unix时间戳毫秒级
Returns:
创建是否成功
"""
try:
for stock_data in stocks:
# 检查股票代码是否已存在
existing_stock = self.session.query(AStock).filter(
AStock.stock_code == stock_data['stock_code']
).first()
if existing_stock:
logger.warning(f"股票代码 {stock_data['stock_code']} 已存在,跳过")
continue
# 创建新股票记录
new_stock = AStock(
stock_code=stock_data['stock_code'],
short_name=stock_data['short_name'],
exchange=stock_data.get('exchange')
)
# 处理上市日期
if 'list_date' in stock_data and stock_data['list_date']:
list_date = convert_timestamp_to_datetime(stock_data['list_date'])
if list_date:
new_stock.list_date = list_date
self.session.add(new_stock)
# 批量提交
self.session.commit()
logger.info(f"成功批量创建股票信息,共 {len(stocks)} 条记录")
return True
except Exception as e:
self.session.rollback()
logger.error(f"批量创建股票信息失败: {e}")
return False
def create_stock(self, stock_code: str, short_name: str,
exchange: Optional[str] = None,
list_date: Optional[Union[int, float, str]] = None) -> bool:
"""
创建新的股票信息
Args:
stock_code: 股票代码
short_name: 股票简称
exchange: 交易所可选
list_date: 上市日期可选Unix时间戳毫秒级
Returns:
创建是否成功
"""
try:
# 检查股票代码是否已存在
existing_stock = self.session.query(AStock).filter(AStock.stock_code == stock_code).first()
if existing_stock:
logger.warning(f"股票代码 {stock_code} 已存在")
return False
# 创建新股票记录
new_stock = AStock(
stock_code=stock_code,
short_name=short_name,
exchange=exchange
)
# 处理上市日期
if list_date:
converted_date = convert_timestamp_to_datetime(list_date)
if converted_date:
new_stock.list_date = converted_date
self.session.add(new_stock)
self.session.commit()
logger.info(f"成功创建股票信息: {stock_code} - {short_name}")
return True
except Exception as e:
self.session.rollback()
logger.error(f"创建股票信息失败: {e}")
return False
def update_stock(self, stock_code: str, short_name: Optional[str] = None,
exchange: Optional[str] = None,
list_date: Optional[Union[int, float, str, datetime]] = None) -> bool:
"""
更新股票信息
Args:
stock_code: 股票代码
short_name: 股票简称
exchange: 交易所
list_date: 上市日期可以是datetime对象或时间戳
Returns:
更新是否成功
"""
try:
# 查询股票
stock = self.session.query(AStock).filter(AStock.stock_code == stock_code).first()
if not stock:
logger.warning(f"股票代码 {stock_code} 不存在")
return False
# 更新字段
if short_name is not None:
stock.short_name = short_name
if exchange is not None:
stock.exchange = exchange
if list_date is not None:
if isinstance(list_date, datetime):
stock.list_date = list_date
else:
converted_date = convert_timestamp_to_datetime(list_date)
if converted_date:
stock.list_date = converted_date
self.session.commit()
logger.info(f"成功更新股票信息: {stock_code}")
return True
except Exception as e:
self.session.rollback()
logger.error(f"更新股票信息失败: {e}")
return False
def delete_stock(self, stock_code: str) -> bool:
"""
删除股票信息
Args:
stock_code: 股票代码
Returns:
删除是否成功
"""
try:
# 查询股票
stock = self.session.query(AStock).filter(AStock.stock_code == stock_code).first()
if not stock:
logger.warning(f"股票代码 {stock_code} 不存在")
return False
# 删除股票
self.session.delete(stock)
self.session.commit()
logger.info(f"成功删除股票信息: {stock_code}")
return True
except Exception as e:
self.session.rollback()
logger.error(f"删除股票信息失败: {e}")
return False
def get_stock_by_code(self, stock_code: str) -> Optional[Dict[str, Any]]:
"""
通过股票代码获取股票信息
Args:
stock_code: 股票代码
Returns:
股票信息如果不存在则返回None
"""
try:
# 查询股票
stock = self.session.query(AStock).filter(AStock.stock_code == stock_code).first()
if stock:
return {
'stock_code': stock.stock_code,
'short_name': stock.short_name,
'exchange': stock.exchange,
'list_date': stock.list_date,
'created_at': stock.created_at
}
else:
return None
except Exception as e:
logger.error(f"获取股票信息失败: {e}")
return None
def search_stock(self, key: str, limit: int = 10) -> List[Dict[str, Any]]:
"""
搜索股票
Args:
key: 搜索关键词
limit: 最大返回数量
Returns:
股票信息列表
"""
try:
# 查询股票
stocks = self.session.query(AStock).filter(
AStock.short_name.like(f"{key}%") | AStock.stock_code.like(f"{key}%")
).limit(limit).all()
return [{
'stock_code': stock.stock_code,
'short_name': stock.short_name,
'exchange': stock.exchange,
'list_date': stock.list_date,
'created_at': stock.created_at
} for stock in stocks]
except Exception as e:
logger.error(f"获取股票信息失败: {e}")
return []
def get_stock_by_name(self, short_name: str) -> List[Dict[str, Any]]:
"""
通过股票简称获取股票信息可能有多个结果
Args:
short_name: 股票简称
Returns:
股票信息列表
"""
try:
# 查询股票
stocks = self.session.query(AStock).filter(AStock.short_name == short_name).all()
return [{
'stock_code': stock.stock_code,
'short_name': stock.short_name,
'exchange': stock.exchange,
'list_date': stock.list_date,
'created_at': stock.created_at
} for stock in stocks]
except Exception as e:
logger.error(f"获取股票信息失败: {e}")
return []
def list_stocks(self, limit: int = 100, skip: int = 0) -> List[Dict[str, Any]]:
"""
获取股票列表
Args:
limit: 返回的最大数量
skip: 跳过的数量
Returns:
股票信息列表
"""
try:
# 查询股票列表
stocks = self.session.query(AStock).order_by(AStock.stock_code).offset(skip).limit(limit).all()
return [{
'stock_code': stock.stock_code,
'short_name': stock.short_name,
'exchange': stock.exchange,
'list_date': stock.list_date,
'created_at': stock.created_at
} for stock in stocks]
except Exception as e:
logger.error(f"获取股票列表失败: {e}")
return []

20
cryptoai/models/base.py Normal file
View File

@ -0,0 +1,20 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import logging
from datetime import datetime
from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, Index, text, ForeignKey
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, relationship
from sqlalchemy.dialects.mysql import JSON
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('db_models')
# 创建模型基类
Base = declarative_base()

127
cryptoai/models/token.py Normal file
View File

@ -0,0 +1,127 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from typing import Dict, Any, List, Optional
from datetime import datetime
from sqlalchemy import Column, Integer, String, DateTime, Index
from cryptoai.models.base import Base, logger
# 定义Token模型
class Token(Base):
"""Token信息表模型"""
__tablename__ = 'tokens'
id = Column(Integer, primary_key=True, autoincrement=True)
symbol = Column(String(50), nullable=False, unique=True, comment='交易对符号')
base_asset = Column(String(20), nullable=False, comment='基础资产')
quote_asset = Column(String(20), nullable=False, comment='计价资产')
created_at = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
# 索引
__table_args__ = (
Index('idx_symbol', 'symbol'),
Index('idx_base_asset', 'base_asset'),
Index('idx_quote_asset', 'quote_asset'),
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
)
class TokenManager:
"""Token管理类"""
def __init__(self, db_session):
self.session = db_session
def create_token(self, symbol: str, base_asset: str, quote_asset: str) -> bool:
"""
创建新的Token信息
Args:
symbol: 交易对符号
base_asset: 基础资产
quote_asset: 计价资产
Returns:
创建是否成功
"""
try:
# 检查交易对是否已存在
existing_token = self.session.query(Token).filter(Token.symbol == symbol).first()
if existing_token:
logger.warning(f"交易对 {symbol} 已存在")
return False
# 创建新Token记录
new_token = Token(
symbol=symbol,
base_asset=base_asset,
quote_asset=quote_asset
)
self.session.add(new_token)
self.session.commit()
logger.info(f"成功创建Token信息: {symbol}")
return True
except Exception as e:
self.session.rollback()
logger.error(f"创建Token信息失败: {e}")
return False
def delete_token(self, symbol: str) -> bool:
"""
删除Token信息
Args:
symbol: 交易对符号
Returns:
删除是否成功
"""
try:
# 查询Token
token = self.session.query(Token).filter(Token.symbol == symbol).first()
if not token:
logger.warning(f"交易对 {symbol} 不存在")
return False
# 删除Token
self.session.delete(token)
self.session.commit()
logger.info(f"成功删除Token信息: {symbol}")
return True
except Exception as e:
self.session.rollback()
logger.error(f"删除Token信息失败: {e}")
return False
def search_token(self, key: str, limit: int = 10) -> List[Dict[str, Any]]:
"""
搜索Token
Args:
key: 搜索关键词
limit: 最大返回数量
Returns:
Token信息列表
"""
try:
# 使用 SQLAlchemy 的 ORM 查询
tokens = self.session.query(Token).filter(
Token.symbol.like(f"{key}%") | Token.base_asset.like(f"{key}%")
).limit(limit).all()
return [{
'symbol': token.symbol,
'base_asset': token.base_asset,
'quote_asset': token.quote_asset
} for token in tokens]
except Exception as e:
logger.error(f"搜索Token信息失败: {e}")
return []

280
cryptoai/models/user.py Normal file
View File

@ -0,0 +1,280 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from typing import Dict, Any, List, Optional
from datetime import datetime
from sqlalchemy import Column, Integer, String, DateTime, Index
from sqlalchemy.orm import relationship
from cryptoai.models.base import Base, logger
# 定义用户数据模型
class User(Base):
"""用户数据表模型"""
__tablename__ = 'users'
id = Column(Integer, primary_key=True, autoincrement=True)
mail = Column(String(100), nullable=False, unique=True, comment='邮箱')
nickname = Column(String(50), nullable=False, comment='昵称')
password = Column(String(100), nullable=False, comment='密码')
level = Column(Integer, nullable=False, default=0, comment='用户级别(0=普通用户,1=VIP,2=SVIP)')
points = Column(Integer, nullable=False, default=0, comment='用户积分')
create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
# 关系
questions = relationship("UserQuestion", back_populates="user")
analysis_histories = relationship("AnalysisHistory", back_populates="user")
# 索引和表属性
__table_args__ = (
Index('idx_mail', 'mail'),
Index('idx_level', 'level'),
Index('idx_create_time', 'create_time'),
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
)
class UserManager:
"""用户管理类"""
def __init__(self, db_session):
self.session = db_session
def register_user(self, mail: str, nickname: str, password: str, level: int = 0, points: int = 0) -> bool:
"""
注册新用户
Args:
mail: 邮箱
nickname: 昵称
password: 密码
level: 用户级别默认为0普通用户
points: 初始积分默认为0
Returns:
注册是否成功
"""
try:
# 检查邮箱是否已存在
existing_user = self.session.query(User).filter(User.mail == mail).first()
if existing_user:
logger.warning(f"邮箱 {mail} 已被注册")
return False
# 创建新用户
new_user = User(
mail=mail,
nickname=nickname,
password=password, # 实际应用中应该对密码进行哈希处理
level=level,
points=points,
create_time=datetime.now()
)
# 添加并提交
self.session.add(new_user)
self.session.commit()
logger.info(f"成功注册用户: {mail},初始积分: {points}")
return True
except Exception as e:
self.session.rollback()
logger.error(f"注册用户失败: {e}")
return False
def get_user_count(self) -> int:
"""
获取用户数量
"""
try:
# 查询用户数量
user_count = self.session.query(User).count()
return user_count
except Exception as e:
logger.error(f"获取用户数量失败: {e}")
return 0
def get_user_by_id(self, user_id: int) -> Optional[Dict[str, Any]]:
"""
通过ID获取用户信息
Args:
user_id: 用户ID
Returns:
用户信息如果用户不存在则返回None
"""
try:
# 查询用户
user = self.session.query(User).filter(User.id == user_id).first()
if user:
# 转换为字典
return {
'id': user.id,
'mail': user.mail,
'nickname': user.nickname,
'level': user.level,
'points': user.points,
'create_time': user.create_time
}
else:
return None
except Exception as e:
logger.error(f"获取用户信息失败: {e}")
return None
def login(self, mail: str, password: str) -> Optional[Dict[str, Any]]:
"""
登录
"""
user = self.session.query(User).filter(User.mail == mail).first()
if not user:
return None
if user.password != password:
return None
return {'id': user.id, 'mail': user.mail, 'nickname': user.nickname, 'level': user.level, 'points': user.points, 'create_time': user.create_time}
def get_user_by_mail(self, mail: str) -> Optional[Dict[str, Any]]:
"""
通过邮箱获取用户信息
Args:
mail: 邮箱
Returns:
用户信息如果用户不存在则返回None
"""
try:
# 查询用户
user = self.session.query(User).filter(User.mail == mail).first()
if user:
# 转换为字典
return {
'id': user.id,
'mail': user.mail,
'nickname': user.nickname,
'level': user.level,
'points': user.points,
'create_time': user.create_time
}
else:
return None
except Exception as e:
logger.error(f"获取用户信息失败: {e}")
return None
def update_user_level(self, user_id: int, level: int) -> bool:
"""
更新用户级别
Args:
user_id: 用户ID
level: 新的用户级别
Returns:
更新是否成功
"""
try:
# 查询用户
user = self.session.query(User).filter(User.id == user_id).first()
if not user:
logger.warning(f"用户ID {user_id} 不存在")
return False
# 更新级别
user.level = level
self.session.commit()
logger.info(f"成功更新用户 {user.mail} 的级别为 {level}")
return True
except Exception as e:
self.session.rollback()
logger.error(f"更新用户级别失败: {e}")
return False
def add_user_points(self, user_id: int, points: int) -> bool:
"""
为用户增加积分
Args:
user_id: 用户ID
points: 增加的积分数量正数
Returns:
操作是否成功
"""
if points <= 0:
logger.warning(f"增加的积分必须是正数: {points}")
return False
try:
# 查询用户
user = self.session.query(User).filter(User.id == user_id).first()
if not user:
logger.warning(f"用户ID {user_id} 不存在")
return False
# 增加积分
user.points += points
self.session.commit()
logger.info(f"成功为用户 {user.mail} 增加 {points} 积分,当前积分: {user.points}")
return True
except Exception as e:
self.session.rollback()
logger.error(f"增加用户积分失败: {e}")
return False
def consume_user_points(self, user_id: int, points: int) -> bool:
"""
用户消费积分
Args:
user_id: 用户ID
points: 消费的积分数量正数
Returns:
操作是否成功
"""
if points <= 0:
logger.warning(f"消费的积分必须是正数: {points}")
return False
try:
# 查询用户
user = self.session.query(User).filter(User.id == user_id).first()
if not user:
logger.warning(f"用户ID {user_id} 不存在")
return False
# 检查积分是否足够
if user.points < points:
logger.warning(f"用户 {user.mail} 积分不足,当前积分: {user.points},需要消费: {points}")
return False
# 消费积分
user.points -= points
self.session.commit()
logger.info(f"成功从用户 {user.mail} 消费 {points} 积分,剩余积分: {user.points}")
return True
except Exception as e:
self.session.rollback()
logger.error(f"消费用户积分失败: {e}")
return False

View File

@ -0,0 +1,164 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from typing import Dict, Any, List, Optional
from datetime import datetime
from sqlalchemy import Column, Integer, String, Text, DateTime, Index, ForeignKey
from sqlalchemy.orm import relationship
from cryptoai.models.base import Base, logger
from cryptoai.models.user import User
# 定义用户提问数据模型
class UserQuestion(Base):
"""用户提问数据表模型"""
__tablename__ = 'user_questions'
id = Column(Integer, primary_key=True, autoincrement=True)
user_id = Column(Integer, ForeignKey('users.id'), nullable=False, comment='用户ID')
agent_id = Column(String(50), nullable=False, comment='AI Agent ID')
question = Column(Text, nullable=False, comment='提问内容')
create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
# 关系
user = relationship("User", back_populates="questions")
# 索引和表属性
__table_args__ = (
Index('idx_user_id', 'user_id'),
Index('idx_agent_id', 'agent_id'),
Index('idx_create_time', 'create_time'),
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
)
class UserQuestionManager:
"""用户提问管理类"""
def __init__(self, db_session):
self.session = db_session
def save_user_question(self, user_id: int, agent_id: str, question: str) -> bool:
"""
保存用户提问数据
Args:
user_id: 用户ID
agent_id: AI Agent ID
question: 提问内容
Returns:
保存是否成功
"""
try:
# 创建新记录
new_question = UserQuestion(
user_id=user_id,
agent_id=agent_id,
question=question,
create_time=datetime.now()
)
# 添加并提交
self.session.add(new_question)
self.session.commit()
logger.info(f"成功保存用户 {user_id} 对 Agent {agent_id} 的提问")
return True
except Exception as e:
self.session.rollback()
logger.error(f"保存用户提问失败: {e}")
return False
def get_user_question_count(self) -> int:
"""
获取用户提问数量
Returns:
提问总数
"""
try:
# 查询用户提问数量
question_count = self.session.query(UserQuestion).count()
return question_count
except Exception as e:
logger.error(f"获取用户提问数量失败: {e}")
return 0
def get_user_questions(self, user_id: Optional[int] = None, agent_id: Optional[str] = None,
limit: int = 20, skip: int = 0) -> List[Dict[str, Any]]:
"""
获取用户提问数据
Args:
user_id: 可选指定获取特定用户的提问
agent_id: 可选指定获取特定Agent的提问
limit: 返回的最大记录数默认20条
skip: 跳过的记录数默认0条
Returns:
提问数据列表如果查询失败则返回空列表
"""
try:
# 构建查询
query = self.session.query(UserQuestion)
# 如果指定了user_id则筛选
if user_id:
query = query.filter(UserQuestion.user_id == user_id)
# 如果指定了agent_id则筛选
if agent_id:
query = query.filter(UserQuestion.agent_id == agent_id)
# 按创建时间降序排序并限制数量
results = query.order_by(UserQuestion.create_time.desc()).offset(skip).limit(limit).all()
# 转换为字典列表
questions = []
for result in results:
questions.append({
'id': result.id,
'user_id': result.user_id,
'agent_id': result.agent_id,
'question': result.question,
'create_time': result.create_time
})
return questions
except Exception as e:
logger.error(f"获取用户提问失败: {e}")
return []
def get_user_question_by_id(self, question_id: int) -> Optional[Dict[str, Any]]:
"""
通过ID获取用户提问数据
Args:
question_id: 提问ID
Returns:
提问数据如果不存在则返回None
"""
try:
# 查询提问
result = self.session.query(UserQuestion).filter(UserQuestion.id == question_id).first()
if result:
# 转换为字典
return {
'id': result.id,
'user_id': result.user_id,
'agent_id': result.agent_id,
'question': result.question,
'create_time': result.create_time
}
else:
return None
except Exception as e:
logger.error(f"获取用户提问失败: {e}")
return None

View File

@ -22,7 +22,7 @@ logger.setLevel(logging.DEBUG)
@router.get("/stock/search")
async def search_stock(key: str, limit: int = 10):
manager = get_db_manager()
result = manager.search_stock(key, limit)
result = manager.astock_manager.search_stock(key, limit)
return result
@ -123,16 +123,9 @@ async def get_stock_data_all(stock_code: str):
return result
@router.post('/{stock_code}/analysis', summary="获取股票分析数据")
async def get_stock_analysis(stock_code: str, current_user: Dict[str, Any] = Depends(get_current_user)):
# 检查stock_code是否存在
# codes = get_db_manager().search_stock(stock_code)
# if not codes or len(codes) == 0:
# raise HTTPException(status_code=400, detail="您输入的股票代码不存在,请检查后重新输入。")
# stock_code = codes[0]["stock_code"]
url = 'https://mate.aimateplus.com/v1/workflows/run'
token = 'app-nWuCOa0YfQVtAosTY3Jr5vFV'
@ -150,7 +143,7 @@ async def get_stock_analysis(stock_code: str, current_user: Dict[str, Any] = Dep
}
# 保存用户提问
get_db_manager().save_user_question(current_user["id"], stock_code, "请分析以下股票:" + stock_code + ",并给出分析报告。")
get_db_manager().user_question_manager.save_user_question(current_user["id"], stock_code, "请分析以下股票:" + stock_code + ",并给出分析报告。")
response = requests.post(url, headers=headers, json=data, stream=True)

View File

@ -49,7 +49,7 @@ async def create_agent(agent: AgentCreate):
创建新的AI Agent
"""
agent = get_db_manager().create_agent(agent.name, agent.description, agent.hello_prompt, agent.dify_token, agent.inputs)
agent = get_db_manager().agent_manager.create_agent(agent.name, agent.description, agent.hello_prompt, agent.dify_token, agent.inputs)
return agent
@ -64,7 +64,7 @@ async def get_agents(
"""
# 从数据库获取Agent列表
agents = get_db_manager().list_agents(limit=limit, skip=skip)
agents = get_db_manager().agent_manager.list_agents(limit=limit, skip=skip)
return agents
@ -76,7 +76,7 @@ async def chat(request: ChatRequest, current_user: Dict[str, Any] = Depends(get_
# 尝试从数据库获取Agent
try:
agent_id = int(request.agent_id)
agent = get_db_manager().get_agent_by_id(agent_id)
agent = get_db_manager().agent_manager.get_agent_by_id(agent_id)
if not agent:
raise HTTPException(status_code=400, detail="Invalid agent ID")
@ -106,7 +106,7 @@ async def chat(request: ChatRequest, current_user: Dict[str, Any] = Depends(get_
logging.info(f"Chat request data: {data}")
# 保存用户提问
get_db_manager().save_user_question(current_user["id"], request.agent_id, request.user_prompt)
get_db_manager().user_question_manager.save_user_question(current_user["id"], request.agent_id, request.user_prompt)
response = requests.post(url, headers=headers, json=data, stream=True)

View File

@ -0,0 +1,30 @@
from fastapi import APIRouter
from typing import Optional
from cryptoai.utils.db_manager import get_db_manager
from fastapi import Depends
from pydantic import BaseModel
from cryptoai.routes.user import get_current_user
class AnalysisHistoryRequest(BaseModel):
symbol: str
content: str
timeframe: Optional[str] = None
type: str
router = APIRouter()
@router.post("/analysis_history")
async def analysis_history(request: AnalysisHistoryRequest,
current_user: dict = Depends(get_current_user)):
get_db_manager().analysis_history_manager.add_analysis_history(current_user["id"], request.type, request.symbol, request.content, request.timeframe)
return {"message": "ok"}
@router.get("/analysis_histories")
async def get_analysis_histories(current_user: dict = Depends(get_current_user),
limit: int = 10,
offset: int = 0):
history = get_db_manager().analysis_history_manager.get_user_analysis_history(current_user["id"], limit=limit, offset=offset)
return history

View File

@ -34,7 +34,7 @@ class CryptoAnalysisRequest(BaseModel):
@router.get("/kline/{symbol}")
async def get_crypto_kline(symbol: str, timeframe: Optional[str] = None, limit: Optional[int] = 100):
# 检查symbol是否存在
tokens = get_db_manager().search_token(symbol)
tokens = get_db_manager().token_manager.search_token(symbol)
if not tokens or len(tokens) == 0:
raise HTTPException(status_code=400, detail="您输入的币种在币安不存在,请检查后重新输入。")
@ -57,7 +57,7 @@ async def get_crypto_kline(symbol: str, timeframe: Optional[str] = None, limit:
async def analysis_crypto_v2(request: CryptoAnalysisRequest,
current_user: dict = Depends(get_current_user)):
# 检查symbol是否存在
tokens = get_db_manager().search_token(request.symbol)
tokens = get_db_manager().token_manager.search_token(request.symbol)
if not tokens or len(tokens) == 0:
raise HTTPException(status_code=400, detail="您输入的币种在币安不存在,请检查后重新输入。")
@ -80,7 +80,7 @@ async def analysis_crypto_v2(request: CryptoAnalysisRequest,
}
# 保存用户提问
get_db_manager().save_user_question(current_user["id"], symbol, "请分析以下加密货币:" + symbol + ",并给出分析报告。")
get_db_manager().user_question_manager.save_user_question(current_user["id"], symbol, "请分析以下加密货币:" + symbol + ",并给出分析报告。")
response = requests.post(url, headers=headers, json=data, stream=True)
@ -111,7 +111,7 @@ async def analysis_crypto(request: CryptoAnalysisRequest,
else:
user_prompt = f"请分析以下加密货币:{request.symbol},并给出分析报告。"
agent = get_db_manager().get_agent_by_id(agent_id)
agent = get_db_manager().agent_manager.get_agent_by_id(agent_id)
if not agent:
raise HTTPException(status_code=400, detail="Invalid agent ID")
@ -138,7 +138,7 @@ async def analysis_crypto(request: CryptoAnalysisRequest,
logging.info(f"Chat request data: {data}")
# 保存用户提问
get_db_manager().save_user_question(current_user["id"], agent_id, user_prompt)
get_db_manager().user_question_manager.save_user_question(current_user["id"], agent_id, user_prompt)
response = requests.post(url, headers=headers, json=data, stream=True)

View File

@ -22,6 +22,7 @@ from cryptoai.routes.adata import router as adata_router
from cryptoai.routes.question import router as question_router
from cryptoai.routes.crypto import router as crypto_router
from cryptoai.routes.platform import router as platform_router
from cryptoai.routes.analysis import router as analysis_router
# 配置日志
logging.basicConfig(
@ -58,6 +59,8 @@ app.include_router(user_router, prefix="/user", tags=["用户管理"])
app.include_router(question_router, prefix="/question", tags=["用户提问"])
app.include_router(adata_router, prefix="/adata", tags=["A股数据"])
app.include_router(crypto_router, prefix="/crypto", tags=["加密货币数据"])
app.include_router(analysis_router, prefix="/analysis", tags=["分析历史"])
# 请求计时中间件
@app.middleware("http")
async def add_process_time_header(request: Request, call_next):

View File

@ -52,7 +52,7 @@ async def create_feed(feed: AgentFeedCreate) -> Dict[str, Any]:
db_manager = get_db_manager()
# 保存信息流
success = db_manager.save_agent_feed(
success = db_manager.agent_feed_manager.save_agent_feed(
agent_name=feed.agent_name,
content=feed.content,
avatar_url=feed.avatar_url
@ -98,7 +98,7 @@ async def get_feeds(
db_manager = get_db_manager()
# 获取信息流
feeds = db_manager.get_agent_feeds(agent_name=agent_name, limit=limit, skip=skip)
feeds = db_manager.agent_feed_manager.get_agent_feeds(agent_name=agent_name, limit=limit, skip=skip)
return feeds

View File

@ -14,8 +14,8 @@ async def get_platform_info():
result = {}
try:
result["user_count"] = db_manager.get_user_count()
result["question_count"] = db_manager.get_user_question_count()
result["user_count"] = db_manager.user_manager.get_user_count()
result["question_count"] = db_manager.user_question_manager.get_user_question_count()
return result
except Exception as e:

View File

@ -55,7 +55,7 @@ async def create_question(
db_manager = get_db_manager()
# 保存提问
success = db_manager.save_user_question(
success = db_manager.user_question_manager.save_user_question(
user_id=current_user["id"],
agent_id=question.agent_id,
question=question.question
@ -103,7 +103,7 @@ async def get_questions(
db_manager = get_db_manager()
# 获取提问记录
questions = db_manager.get_user_questions(
questions = db_manager.user_question_manager.get_user_questions(
user_id=current_user["id"],
agent_id=agent_id,
limit=limit,
@ -139,7 +139,7 @@ async def get_question(
db_manager = get_db_manager()
# 获取提问记录
question = db_manager.get_user_question_by_id(question_id)
question = db_manager.user_question_manager.get_user_question_by_id(question_id)
if not question:
raise HTTPException(

View File

@ -104,7 +104,7 @@ async def get_current_user(request: Request) -> Dict[str, Any]:
raise credentials_exception
db_manager = get_db_manager()
user = db_manager.get_user_by_mail(mail)
user = db_manager.user_manager.get_user_by_mail(mail)
if user is None:
raise credentials_exception
return user
@ -125,7 +125,7 @@ async def send_verification_code(request: SendVerificationCodeRequest) -> Dict[s
db_manager = get_db_manager()
# 检查邮箱是否已被注册
user = db_manager.get_user_by_mail(request.mail)
user = db_manager.user_manager.get_user_by_mail(request.mail)
if user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@ -231,7 +231,7 @@ async def login(loginData: UserLogin) -> TokenResponse:
db_manager = get_db_manager()
# 获取用户信息
user = db_manager.get_user_by_mail(loginData.mail)
user = db_manager.user_manager.get_user_by_mail(loginData.mail)
if not user:
raise HTTPException(
@ -246,9 +246,8 @@ async def login(loginData: UserLogin) -> TokenResponse:
# 查询用户的密码哈希
session = db_manager.Session()
try:
from cryptoai.utils.db_manager import User
db_user = session.query(User).filter(User.mail == loginData.mail).first()
if not db_user or db_user.password != hashed_password:
user = db_manager.user_manager.login(loginData.mail, hashed_password)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="邮箱或密码错误",
@ -334,7 +333,7 @@ async def update_user_level(
db_manager = get_db_manager()
# 更新用户级别
success = db_manager.update_user_level(user_id, level)
success = db_manager.user_manager.update_user_level(user_id, level)
if not success:
raise HTTPException(
@ -373,7 +372,7 @@ async def get_user_points(
db_manager = get_db_manager()
# 获取用户信息
user = db_manager.get_user_by_id(user_id)
user = db_manager.user_manager.get_user_by_id(user_id)
if not user:
raise HTTPException(
@ -416,7 +415,7 @@ async def add_user_points(
db_manager = get_db_manager()
# 添加积分
success = db_manager.add_user_points(user_id, points)
success = db_manager.user_manager.add_user_points(user_id, points)
if not success:
raise HTTPException(
@ -425,7 +424,7 @@ async def add_user_points(
)
# 获取更新后的用户信息
user = db_manager.get_user_by_id(user_id)
user = db_manager.user_manager.get_user_by_id(user_id)
return {
"status": "success",
@ -461,7 +460,7 @@ async def consume_user_points(
db_manager = get_db_manager()
# 消费积分
success = db_manager.consume_user_points(user_id, points)
success = db_manager.user_manager.consume_user_points(user_id, points)
if not success:
raise HTTPException(
@ -470,7 +469,7 @@ async def consume_user_points(
)
# 获取更新后的用户信息
user = db_manager.get_user_by_id(user_id)
user = db_manager.user_manager.get_user_by_id(user_id)
return {
"status": "success",

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,36 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import logging
from typing import Optional, Union
from datetime import datetime
# 配置日志
logger = logging.getLogger('db_utils')
def convert_timestamp_to_datetime(timestamp: Union[int, float, str, None]) -> Optional[datetime]:
"""
将时间戳转换为datetime对象
Args:
timestamp: Unix时间戳毫秒级
Returns:
datetime对象或None
"""
if timestamp is None:
return None
try:
# 转换为整数
if isinstance(timestamp, str):
timestamp = int(timestamp)
# 如果是毫秒级时间戳,转换为秒级
if timestamp > 1e11: # 判断是否为毫秒级时间戳
timestamp = timestamp / 1000
return datetime.fromtimestamp(timestamp)
except (ValueError, TypeError) as e:
logger.error(f"时间戳转换失败: {timestamp}, 错误: {e}")
return None

View File

@ -29,7 +29,7 @@ services:
cryptoai-api:
build: .
container_name: cryptoai-api
image: cryptoai-api:0.1.25
image: cryptoai-api:0.1.26
restart: always
ports:
- "8000:8000"

View File

@ -17,4 +17,4 @@ if __name__ == "__main__":
base_asset = symbol.split('USDT')[0]
quote_asset = 'USDT'
manager.create_token(symbol,base_asset, quote_asset)
manager.token_manager.create_token(symbol,base_asset, quote_asset)