This commit is contained in:
aaron 2025-05-30 22:09:45 +08:00
parent 4027f68dca
commit 08c4efcfd3
17 changed files with 158 additions and 423 deletions

View File

@ -17,7 +17,6 @@ from api.deepseek_api import DeepSeekAPI
from models.data_processor import DataProcessor from models.data_processor import DataProcessor
from utils.config_loader import ConfigLoader from utils.config_loader import ConfigLoader
from utils.dingtalk_bot import DingTalkBot from utils.dingtalk_bot import DingTalkBot
from utils.db_manager import get_db_manager
from utils.discord_bot import DiscordBot from utils.discord_bot import DiscordBot
class CryptoAgent: class CryptoAgent:
@ -79,9 +78,6 @@ class CryptoAgent:
) )
print("Discord机器人已启用") print("Discord机器人已启用")
# 初始化数据库管理器
self.db_manager = get_db_manager()
# 设置支持的加密货币 # 设置支持的加密货币
self.base_currencies = self.crypto_config['base_currencies'] self.base_currencies = self.crypto_config['base_currencies']
self.quote_currency = self.crypto_config['quote_currency'] self.quote_currency = self.crypto_config['quote_currency']

View File

@ -7,7 +7,7 @@ import adata
from typing import Dict, List, Optional from typing import Dict, List, Optional
import pandas as pd import pandas as pd
from datetime import datetime, timedelta from datetime import datetime, timedelta
from cryptoai.utils.db_manager import get_db_context
class AStockAPI: class AStockAPI:
@staticmethod @staticmethod
def get_all_stock_codes() -> List[str]: def get_all_stock_codes() -> List[str]:
@ -156,8 +156,11 @@ if __name__ == "__main__":
print(list) print(list)
# 保存到数据库 # 保存到数据库
import cryptoai.utils.db_manager as db_manager from cryptoai.models.astock import AStockManager
db_manager.get_db_manager().create_stocks(list)
session = get_db_context()
manager = AStockManager(session)
manager.create_stocks(list)

View File

@ -3,12 +3,12 @@
from typing import Dict, Any, List, Optional from typing import Dict, Any, List, Optional
from datetime import datetime from datetime import datetime
from sqlalchemy.orm import Session
from sqlalchemy import Column, Integer, String, Text, DateTime, Index, ForeignKey from sqlalchemy import Column, Integer, String, Text, DateTime, Index, ForeignKey
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
from cryptoai.models.base import Base, logger from cryptoai.models.base import Base, logger
from cryptoai.utils.db_manager import get_db_context
# 定义分析历史模型 # 定义分析历史模型
class AnalysisHistory(Base): class AnalysisHistory(Base):
"""分析历史表模型""" """分析历史表模型"""
@ -37,8 +37,8 @@ class AnalysisHistory(Base):
class AnalysisHistoryManager: class AnalysisHistoryManager:
"""分析历史管理类""" """分析历史管理类"""
def __init__(self, db_session): def __init__(self, session: Session = None):
self.session = db_session self.session = session
def add_analysis_history(self, user_id: int, type: str, symbol: str, def add_analysis_history(self, user_id: int, type: str, symbol: str,
content: str, timeframe: str = None) -> bool: content: str, timeframe: str = None) -> bool:

View File

@ -3,12 +3,12 @@
from typing import Dict, Any, List, Optional from typing import Dict, Any, List, Optional
from datetime import datetime from datetime import datetime
from sqlalchemy.orm import Session
from sqlalchemy import Column, Integer, String, DateTime, Index from sqlalchemy import Column, Integer, String, DateTime, Index
from sqlalchemy.dialects.mysql import JSON from sqlalchemy.dialects.mysql import JSON
from cryptoai.models.base import Base, logger from cryptoai.models.base import Base, logger
from cryptoai.utils.db_manager import get_db_context
# 定义分析结果模型 # 定义分析结果模型
class AnalysisResult(Base): class AnalysisResult(Base):
"""分析结果表模型""" """分析结果表模型"""
@ -33,8 +33,8 @@ class AnalysisResult(Base):
class AnalysisResultManager: class AnalysisResultManager:
"""分析结果管理类""" """分析结果管理类"""
def __init__(self, db_session): def __init__(self, session: Session = None):
self.session = db_session self.session = session
def save_analysis_result(self, agent: str, symbol: str, time_interval: str, def save_analysis_result(self, agent: str, symbol: str, time_interval: str,
analysis_result: Dict[str, Any]) -> bool: analysis_result: Dict[str, Any]) -> bool:

View File

@ -5,9 +5,10 @@ from typing import Dict, Any, List, Optional, Union
from datetime import datetime from datetime import datetime
from sqlalchemy import Column, Integer, String, DateTime, Index from sqlalchemy import Column, Integer, String, DateTime, Index
from sqlalchemy.orm import Session
from cryptoai.models.base import Base, logger from cryptoai.models.base import Base, logger
from cryptoai.utils.db_utils import convert_timestamp_to_datetime from cryptoai.utils.db_utils import convert_timestamp_to_datetime
from cryptoai.utils.db_manager import get_db_context
# 定义 A 股数据模型 # 定义 A 股数据模型
class AStock(Base): class AStock(Base):
@ -29,8 +30,8 @@ class AStock(Base):
class AStockManager: class AStockManager:
"""A股管理类""" """A股管理类"""
def __init__(self, db_session): def __init__(self, session: Session = None):
self.session = db_session self.session = session
def create_stocks(self, stocks: List[Dict[str, Any]]) -> bool: def create_stocks(self, stocks: List[Dict[str, Any]]) -> bool:
""" """

View File

@ -3,11 +3,11 @@
from typing import Dict, Any, List, Optional from typing import Dict, Any, List, Optional
from datetime import datetime from datetime import datetime
from sqlalchemy.orm import Session
from sqlalchemy import Column, Integer, String, DateTime, Index from sqlalchemy import Column, Integer, String, DateTime, Index
from cryptoai.models.base import Base, logger from cryptoai.models.base import Base, logger
from cryptoai.utils.db_manager import get_db_context
# 定义Token模型 # 定义Token模型
class Token(Base): class Token(Base):
"""Token信息表模型""" """Token信息表模型"""
@ -30,8 +30,8 @@ class Token(Base):
class TokenManager: class TokenManager:
"""Token管理类""" """Token管理类"""
def __init__(self, db_session): def __init__(self, session: Session = None):
self.session = db_session self.session = session
def create_token(self, symbol: str, base_asset: str, quote_asset: str) -> bool: def create_token(self, symbol: str, base_asset: str, quote_asset: str) -> bool:
""" """

View File

@ -6,7 +6,7 @@ from datetime import datetime
from fastapi import HTTPException, status from fastapi import HTTPException, status
from sqlalchemy import Column, Integer, String, DateTime, Index from sqlalchemy import Column, Integer, String, DateTime, Index
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
from sqlalchemy.orm import Session
from cryptoai.models.base import Base, logger from cryptoai.models.base import Base, logger
# 定义用户数据模型 # 定义用户数据模型
@ -37,8 +37,8 @@ class User(Base):
class UserManager: class UserManager:
"""用户管理类""" """用户管理类"""
def __init__(self, db_session): def __init__(self, session: Session = None):
self.session = db_session self.session = session
def register_user(self, mail: str, nickname: str, password: str, level: int = 0, points: int = 0) -> bool: def register_user(self, mail: str, nickname: str, password: str, level: int = 0, points: int = 0) -> bool:
""" """

View File

@ -3,7 +3,7 @@
from typing import Dict, Any, List, Optional from typing import Dict, Any, List, Optional
from datetime import datetime from datetime import datetime
from sqlalchemy.orm import Session
from sqlalchemy import Column, Integer, String, Text, DateTime, Index, ForeignKey from sqlalchemy import Column, Integer, String, Text, DateTime, Index, ForeignKey
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
@ -35,8 +35,8 @@ class UserQuestion(Base):
class UserQuestionManager: class UserQuestionManager:
"""用户提问管理类""" """用户提问管理类"""
def __init__(self, db_session): def __init__(self, session: Session = None):
self.session = db_session self.session = session
def save_user_question(self, user_id: int, agent_id: str, question: str) -> bool: def save_user_question(self, user_id: int, agent_id: str, question: str) -> bool:
""" """

View File

@ -2,7 +2,6 @@ import json
import logging import logging
from fastapi import APIRouter, Depends, HTTPException, status, Body, Query, Path from fastapi import APIRouter, Depends, HTTPException, status, Body, Query, Path
from cryptoai.api.adata_api import AStockAPI from cryptoai.api.adata_api import AStockAPI
from cryptoai.utils.db_manager import get_db_manager
from datetime import datetime from datetime import datetime
from typing import Dict, Any, List, Optional from typing import Dict, Any, List, Optional
from pydantic import BaseModel from pydantic import BaseModel
@ -12,7 +11,7 @@ from cryptoai.utils.config_loader import ConfigLoader
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from cryptoai.routes.user import get_current_user from cryptoai.routes.user import get_current_user
import requests import requests
from cryptoai.models.astock import AStockManager
# 创建路由 # 创建路由
router = APIRouter() router = APIRouter()
@ -21,8 +20,8 @@ logger.setLevel(logging.DEBUG)
@router.get("/stock/search") @router.get("/stock/search")
async def search_stock(key: str, limit: int = 10): async def search_stock(key: str, limit: int = 10):
manager = get_db_manager() manager = AStockManager()
result = manager.astock_manager.search_stock(key, limit) result = manager.search_stock(key, limit)
return result return result

View File

@ -1,6 +1,5 @@
from fastapi import APIRouter from fastapi import APIRouter
from typing import Optional from typing import Optional
from cryptoai.utils.db_manager import get_db_manager
from fastapi import Depends from fastapi import Depends
from pydantic import BaseModel from pydantic import BaseModel
from cryptoai.routes.user import get_current_user from cryptoai.routes.user import get_current_user
@ -8,7 +7,9 @@ from fastapi import HTTPException
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
import requests import requests
from datetime import date, timedelta from datetime import date, timedelta
from cryptoai.models.analysis_history import AnalysisHistoryManager
from cryptoai.models.user_question import UserQuestionManager
from cryptoai.models.token import TokenManager
class AnalysisHistoryRequest(BaseModel): class AnalysisHistoryRequest(BaseModel):
symbol: str symbol: str
content: str content: str
@ -21,7 +22,8 @@ router = APIRouter()
async def analysis_history(request: AnalysisHistoryRequest, async def analysis_history(request: AnalysisHistoryRequest,
current_user: dict = Depends(get_current_user)): current_user: dict = Depends(get_current_user)):
get_db_manager().analysis_history_manager.add_analysis_history(current_user["id"], request.type, request.symbol, request.content, request.timeframe) manager = AnalysisHistoryManager()
manager.add_analysis_history(current_user["id"], request.type, request.symbol, request.content, request.timeframe)
return {"message": "ok"} return {"message": "ok"}
@ -29,7 +31,8 @@ async def analysis_history(request: AnalysisHistoryRequest,
async def get_analysis_histories(current_user: dict = Depends(get_current_user), async def get_analysis_histories(current_user: dict = Depends(get_current_user),
limit: int = 10, limit: int = 10,
offset: int = 0): offset: int = 0):
history = get_db_manager().analysis_history_manager.get_user_analysis_history(current_user["id"], limit=limit, offset=offset) manager = AnalysisHistoryManager()
history = manager.get_user_analysis_history(current_user["id"], limit=limit, offset=offset)
return history return history
@ -65,7 +68,8 @@ async def chat(request: ChatRequest,
'Content-Type': 'application/json' 'Content-Type': 'application/json'
} }
get_db_manager().user_question_manager.save_user_question(current_user["id"],"chat-messages", request.message) manager = UserQuestionManager()
manager.save_user_question(current_user["id"],"chat-messages", request.message)
response = requests.post(url, headers=headers, json=payload, stream=True) response = requests.post(url, headers=headers, json=payload, stream=True)
@ -90,7 +94,8 @@ async def analysis(request: AnalysisRequest,
if request.type == 'crypto': if request.type == 'crypto':
# 检查symbol是否存在 # 检查symbol是否存在
tokens = get_db_manager().token_manager.search_token(request.symbol) manager = TokenManager()
tokens = manager.search_token(request.symbol)
if not tokens or len(tokens) == 0: if not tokens or len(tokens) == 0:
raise HTTPException(status_code=400, detail="您输入的币种在币安不存在,请检查后重新输入。") raise HTTPException(status_code=400, detail="您输入的币种在币安不存在,请检查后重新输入。")
@ -106,7 +111,8 @@ async def analysis(request: AnalysisRequest,
"user": current_user["mail"] "user": current_user["mail"]
} }
get_db_manager().user_question_manager.save_user_question(current_user["id"], symbol, "请分析以下加密货币:" + symbol + ",并给出分析报告。") manager = UserQuestionManager()
manager.save_user_question(current_user["id"], symbol, "请分析以下加密货币:" + symbol + ",并给出分析报告。")
elif request.type == 'astock': elif request.type == 'astock':
@ -121,7 +127,8 @@ async def analysis(request: AnalysisRequest,
"user": current_user["mail"] "user": current_user["mail"]
} }
get_db_manager().user_question_manager.save_user_question(current_user["id"], stock_code, "请分析以下A股股票" + stock_code + ",并给出分析报告。") manager = UserQuestionManager()
manager.save_user_question(current_user["id"], stock_code, "请分析以下A股股票" + stock_code + ",并给出分析报告。")
elif request.type == 'usstock': elif request.type == 'usstock':
stock_code = request.stock_code stock_code = request.stock_code
@ -136,7 +143,8 @@ async def analysis(request: AnalysisRequest,
"response_mode": "streaming", "response_mode": "streaming",
"user": current_user["mail"] "user": current_user["mail"]
} }
get_db_manager().user_question_manager.save_user_question(current_user["id"], stock_code, "请分析以下美股股票:" + stock_code + ",并给出分析报告。") manager = UserQuestionManager()
manager.save_user_question(current_user["id"], stock_code, "请分析以下美股股票:" + stock_code + ",并给出分析报告。")
else: else:
raise HTTPException(status_code=400, detail="不支持的类型") raise HTTPException(status_code=400, detail="不支持的类型")

View File

@ -2,7 +2,6 @@ import json
import logging import logging
from fastapi import APIRouter, Depends, HTTPException, status, Body, Query, Path from fastapi import APIRouter, Depends, HTTPException, status, Body, Query, Path
from cryptoai.api.adata_api import AStockAPI from cryptoai.api.adata_api import AStockAPI
from cryptoai.utils.db_manager import get_db_manager
from datetime import datetime from datetime import datetime
from typing import Dict, Any, List, Optional from typing import Dict, Any, List, Optional
from pydantic import BaseModel from pydantic import BaseModel
@ -15,6 +14,7 @@ import requests
from cryptoai.api.binance_api import get_binance_api from cryptoai.api.binance_api import get_binance_api
from cryptoai.models.data_processor import DataProcessor from cryptoai.models.data_processor import DataProcessor
from datetime import timedelta from datetime import timedelta
from cryptoai.models.token import TokenManager
# 创建路由 # 创建路由
router = APIRouter() router = APIRouter()
@ -23,7 +23,7 @@ logger.setLevel(logging.DEBUG)
@router.get("/search/{key}") @router.get("/search/{key}")
async def search_crypto(key: str): async def search_crypto(key: str):
manager = get_db_manager() manager = TokenManager()
result = manager.search_token(key) result = manager.search_token(key)
return result return result
@ -34,7 +34,8 @@ class CryptoAnalysisRequest(BaseModel):
@router.get("/kline/{symbol}") @router.get("/kline/{symbol}")
async def get_crypto_kline(symbol: str, timeframe: Optional[str] = None, limit: Optional[int] = 100): async def get_crypto_kline(symbol: str, timeframe: Optional[str] = None, limit: Optional[int] = 100):
# 检查symbol是否存在 # 检查symbol是否存在
tokens = get_db_manager().token_manager.search_token(symbol) token_manager = TokenManager()
tokens = token_manager.search_token(symbol)
if not tokens or len(tokens) == 0: if not tokens or len(tokens) == 0:
raise HTTPException(status_code=400, detail="您输入的币种在币安不存在,请检查后重新输入。") raise HTTPException(status_code=400, detail="您输入的币种在币安不存在,请检查后重新输入。")

View File

@ -1,6 +1,10 @@
from fastapi import APIRouter from fastapi import APIRouter
from cryptoai.utils.db_manager import get_db_manager
import logging import logging
from cryptoai.models.user import UserManager
from cryptoai.models.user_question import UserQuestionManager
from sqlalchemy.orm import Session
from fastapi import Depends
from cryptoai.utils.db_manager import get_db
router = APIRouter() router = APIRouter()
@ -8,14 +12,16 @@ logger = logging.getLogger("platform_router")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
@router.get("/info") @router.get("/info")
async def get_platform_info(): async def get_platform_info(session: Session = Depends(get_db)):
db_manager = get_db_manager()
user_manager = UserManager(session)
question_manager = UserQuestionManager(session)
result = {} result = {}
try: try:
result["user_count"] = db_manager.user_manager.get_user_count() result["user_count"] = user_manager.get_user_count()
result["question_count"] = db_manager.user_question_manager.get_user_question_count() result["question_count"] = question_manager.get_user_question_count()
return result return result
except Exception as e: except Exception as e:

View File

@ -16,9 +16,10 @@ from datetime import datetime, timedelta
import jwt import jwt
from jwt.exceptions import PyJWTError from jwt.exceptions import PyJWTError
from fastapi import Request from fastapi import Request
from sqlalchemy.orm import Session
from cryptoai.utils.db_manager import get_db_manager from cryptoai.utils.db_manager import get_db
from cryptoai.utils.email_service import get_email_service from cryptoai.utils.email_service import get_email_service
from cryptoai.models.user import UserManager
# 配置日志 # 配置日志
logger = logging.getLogger("user_router") logger = logging.getLogger("user_router")
@ -87,7 +88,7 @@ def create_access_token(data: Dict[str, Any], expires_delta: timedelta = None) -
encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM) encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
return encoded_jwt return encoded_jwt
async def get_current_user(request: Request) -> Dict[str, Any]: async def get_current_user(request: Request, session: Session = Depends(get_db)) -> Dict[str, Any]:
"""获取当前用户""" """获取当前用户"""
credentials_exception = HTTPException( credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -109,8 +110,8 @@ async def get_current_user(request: Request) -> Dict[str, Any]:
print(f"PyJWTError: {e}") print(f"PyJWTError: {e}")
raise credentials_exception raise credentials_exception
db_manager = get_db_manager() manager = UserManager(session)
user = db_manager.user_manager.get_user_by_mail(mail) user = manager.get_user_by_mail(mail)
if user is None: if user is None:
raise credentials_exception raise credentials_exception
return user return user
@ -154,13 +155,13 @@ async def send_verification_code(request: SendVerificationCodeRequest) -> Dict[s
) )
@router.put("/reset_password", response_model=Dict[str, Any]) @router.put("/reset_password", response_model=Dict[str, Any])
async def reset_password(request: ResetPasswordRequest) -> Dict[str, Any]: async def reset_password(request: ResetPasswordRequest, session: Session = Depends(get_db)) -> Dict[str, Any]:
""" """
修改密码 修改密码
""" """
try: try:
# 获取数据库管理器 # 获取数据库管理器
db_manager = get_db_manager()
# 验证验证码 # 验证验证码
email_service = get_email_service() email_service = get_email_service()
@ -173,7 +174,8 @@ async def reset_password(request: ResetPasswordRequest) -> Dict[str, Any]:
# 更新密码 # 更新密码
hashed_password = hash_password(request.new_password) hashed_password = hash_password(request.new_password)
success = db_manager.user_manager.update_password(request.mail, hashed_password) manager = UserManager(session)
success = manager.update_password(request.mail, hashed_password)
if not success: if not success:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
@ -193,7 +195,7 @@ async def reset_password(request: ResetPasswordRequest) -> Dict[str, Any]:
@router.post("/register", response_model=Dict[str, Any], status_code=status.HTTP_201_CREATED) @router.post("/register", response_model=Dict[str, Any], status_code=status.HTTP_201_CREATED)
async def register_user(user: UserRegister) -> Dict[str, Any]: async def register_user(user: UserRegister, session: Session = Depends(get_db)) -> Dict[str, Any]:
""" """
注册新用户 注册新用户
@ -215,13 +217,13 @@ async def register_user(user: UserRegister) -> Dict[str, Any]:
) )
# 获取数据库管理器 # 获取数据库管理器
db_manager = get_db_manager() manager = UserManager(session)
# 对密码进行哈希处理 # 对密码进行哈希处理
hashed_password = hash_password(user.password) hashed_password = hash_password(user.password)
# 注册用户 # 注册用户
success = db_manager.user_manager.register_user( success = manager.register_user(
mail=user.mail, mail=user.mail,
nickname=user.nickname, nickname=user.nickname,
password=hashed_password, password=hashed_password,
@ -250,7 +252,7 @@ async def register_user(user: UserRegister) -> Dict[str, Any]:
) )
@router.post("/login", response_model=TokenResponse) @router.post("/login", response_model=TokenResponse)
async def login(loginData: UserLogin) -> TokenResponse: async def login(loginData: UserLogin, session: Session = Depends(get_db)) -> TokenResponse:
""" """
用户登录 用户登录
@ -262,10 +264,10 @@ async def login(loginData: UserLogin) -> TokenResponse:
""" """
try: try:
# 获取数据库管理器 # 获取数据库管理器
db_manager = get_db_manager() manager = UserManager(session)
# 获取用户信息 # 获取用户信息
user = db_manager.user_manager.get_user_by_mail(loginData.mail) user = manager.get_user_by_mail(loginData.mail)
if not user: if not user:
raise HTTPException( raise HTTPException(
@ -278,17 +280,14 @@ async def login(loginData: UserLogin) -> TokenResponse:
hashed_password = hash_password(loginData.password) hashed_password = hash_password(loginData.password)
# 查询用户的密码哈希 # 查询用户的密码哈希
session = db_manager.Session() user = manager.get_user_by_mail_and_password(loginData.mail, hashed_password)
try: if not user:
user = db_manager.user_manager.get_user_by_mail_and_password(loginData.mail, hashed_password) raise HTTPException(
if not user: status_code=status.HTTP_401_UNAUTHORIZED,
raise HTTPException( detail="邮箱或密码错误",
status_code=status.HTTP_401_UNAUTHORIZED, headers={"WWW-Authenticate": "Bearer"},
detail="邮箱或密码错误", )
headers={"WWW-Authenticate": "Bearer"},
)
finally:
session.close()
# 创建访问令牌,不过期 # 创建访问令牌,不过期
access_token = create_access_token(data={"sub": user["mail"]}) access_token = create_access_token(data={"sub": user["mail"]})
@ -317,7 +316,7 @@ async def login(loginData: UserLogin) -> TokenResponse:
) )
@router.get("/me", response_model=UserResponse) @router.get("/me", response_model=UserResponse)
async def get_user_info(current_user: Dict[str, Any] = Depends(get_current_user)) -> UserResponse: async def get_user_info(current_user: Dict[str, Any] = Depends(get_current_user), session: Session = Depends(get_db)) -> UserResponse:
""" """
获取当前登录用户信息 获取当前登录用户信息
@ -340,7 +339,8 @@ async def get_user_info(current_user: Dict[str, Any] = Depends(get_current_user)
async def update_user_level( async def update_user_level(
user_id: int, user_id: int,
level: int = Query(..., description="用户级别(0=普通用户,1=VIP,2=SVIP)"), level: int = Query(..., description="用户级别(0=普通用户,1=VIP,2=SVIP)"),
current_user: Dict[str, Any] = Depends(get_current_user) current_user: Dict[str, Any] = Depends(get_current_user),
session: Session = Depends(get_db)
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
更新用户级别需要管理员权限 更新用户级别需要管理员权限
@ -361,10 +361,10 @@ async def update_user_level(
) )
# 获取数据库管理器 # 获取数据库管理器
db_manager = get_db_manager() manager = UserManager(session)
# 更新用户级别 # 更新用户级别
success = db_manager.user_manager.update_user_level(user_id, level) success = manager.update_user_level(user_id, level)
if not success: if not success:
raise HTTPException( raise HTTPException(
@ -380,7 +380,8 @@ async def update_user_level(
@router.get("/points/{user_id}", response_model=Dict[str, Any]) @router.get("/points/{user_id}", response_model=Dict[str, Any])
async def get_user_points( async def get_user_points(
user_id: int, user_id: int,
current_user: Dict[str, Any] = Depends(get_current_user) current_user: Dict[str, Any] = Depends(get_current_user),
session: Session = Depends(get_db)
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
获取用户积分 获取用户积分
@ -400,10 +401,10 @@ async def get_user_points(
) )
# 获取数据库管理器 # 获取数据库管理器
db_manager = get_db_manager() manager = UserManager(session)
# 获取用户信息 # 获取用户信息
user = db_manager.user_manager.get_user_by_id(user_id) user = manager.get_user_by_id(user_id)
if not user: if not user:
raise HTTPException( raise HTTPException(
@ -422,7 +423,8 @@ async def get_user_points(
async def add_user_points( async def add_user_points(
user_id: int, user_id: int,
points: int = Query(..., gt=0, description="增加的积分数量必须大于0"), points: int = Query(..., gt=0, description="增加的积分数量必须大于0"),
current_user: Dict[str, Any] = Depends(get_current_user) current_user: Dict[str, Any] = Depends(get_current_user),
session: Session = Depends(get_db)
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
为用户增加积分需要管理员权限 为用户增加积分需要管理员权限
@ -443,10 +445,10 @@ async def add_user_points(
) )
# 获取数据库管理器 # 获取数据库管理器
db_manager = get_db_manager() manager = UserManager(session)
# 添加积分 # 添加积分
success = db_manager.user_manager.add_user_points(user_id, points) success = manager.add_user_points(user_id, points)
if not success: if not success:
raise HTTPException( raise HTTPException(
@ -455,7 +457,7 @@ async def add_user_points(
) )
# 获取更新后的用户信息 # 获取更新后的用户信息
user = db_manager.user_manager.get_user_by_id(user_id) user = manager.get_user_by_id(user_id)
return { return {
"status": "success", "status": "success",
@ -467,7 +469,8 @@ async def add_user_points(
async def consume_user_points( async def consume_user_points(
user_id: int, user_id: int,
points: int = Query(..., gt=0, description="消费的积分数量必须大于0"), points: int = Query(..., gt=0, description="消费的积分数量必须大于0"),
current_user: Dict[str, Any] = Depends(get_current_user) current_user: Dict[str, Any] = Depends(get_current_user),
session: Session = Depends(get_db)
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
用户消费积分 用户消费积分
@ -488,10 +491,10 @@ async def consume_user_points(
) )
# 获取数据库管理器 # 获取数据库管理器
db_manager = get_db_manager() manager = UserManager(session)
# 消费积分 # 消费积分
success = db_manager.user_manager.consume_user_points(user_id, points) success = manager.consume_user_points(user_id, points)
if not success: if not success:
raise HTTPException( raise HTTPException(
@ -500,7 +503,7 @@ async def consume_user_points(
) )
# 获取更新后的用户信息 # 获取更新后的用户信息
user = db_manager.user_manager.get_user_by_id(user_id) user = manager.get_user_by_id(user_id)
return { return {
"status": "success", "status": "success",

View File

@ -7,17 +7,11 @@ from typing import Dict, Any, List, Optional, Union
from datetime import datetime from datetime import datetime
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.pool import QueuePool from sqlalchemy.pool import QueuePool
from cryptoai.utils.config_loader import ConfigLoader from cryptoai.utils.config_loader import ConfigLoader
from cryptoai.models.base import Base from sqlalchemy.ext.declarative import declarative_base
from cryptoai.models.token import TokenManager
from cryptoai.models.analysis_result import AnalysisResultManager
from cryptoai.models.user import UserManager
from cryptoai.models.user_question import UserQuestionManager
from cryptoai.models.astock import AStockManager
from cryptoai.models.analysis_history import AnalysisHistoryManager
# 配置日志 # 配置日志
logging.basicConfig( logging.basicConfig(
@ -26,207 +20,49 @@ logging.basicConfig(
) )
logger = logging.getLogger('db_manager') logger = logging.getLogger('db_manager')
class DBManager:
"""
数据库管理工具用于连接MySQL数据库并提供各个模型的管理器
使用方法: config_loader = ConfigLoader()
- 调用 get_db_manager() 获取数据库管理器实例 db_config = config_loader.get_database_config()
- 使用 db_manager.token_manager.xxx() 直接访问各个模型管理器的方法 engine = create_engine(f"mysql+pymysql://{db_config['user']}:{db_config['password']}@{db_config['host']}:{db_config['port']}/{db_config['db_name']}?charset=utf8mb4",
- 使用 db_manager.get_session() 可以获取一个新的数据库会话 echo=False, # 设置为True可以输出SQL语句调试用
""" pool_size=5, # 连接池大小
max_overflow=10, # 最大溢出连接数
pool_timeout=30, # 连接超时时间
pool_recycle=1800, # 连接回收时间(秒)
pool_pre_ping=True, # 在使用连接前先ping一下确保连接有效
connect_args={'charset': 'utf8mb4'})
def __init__(self, host: str, port: int, user: str, password: str, db_name: str): # 声明基类
""" Base = declarative_base()
初始化数据库管理器 Base.metadata.create_all(bind=engine)
Args: # 创建线程安全的会话工厂
host: 数据库主机地址 SessionLocal = scoped_session(
port: 数据库端口 sessionmaker(autocommit=False, autoflush=False, bind=engine)
user: 用户名 )
password: 密码
db_name: 数据库名
"""
self.host = host
self.port = port
self.user = user
self.password = password
self.db_name = db_name
self.engine = None
self.Session = None
# 初始化数据库连接 def get_db():
self._init_db() db = SessionLocal()
try:
yield db
finally:
if db:
db.close()
# 初始化各个管理器 def get_db_context():
self._init_managers() try:
db = SessionLocal()
yield db
db.commit()
except Exception as e:
if db:
db.rollback()
raise e
finally:
if db:
db.close()
def _init_db(self) -> None:
"""初始化数据库连接和表"""
try:
# 创建数据库连接
connection_string = f"mysql+pymysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.db_name}?charset=utf8mb4"
# 创建引擎,设置连接池
self.engine = create_engine(
connection_string,
echo=False, # 设置为True可以输出SQL语句调试用
pool_size=5, # 连接池大小
max_overflow=10, # 最大溢出连接数
pool_timeout=30, # 连接超时时间
pool_recycle=1800, # 连接回收时间(秒)
pool_pre_ping=True, # 在使用连接前先ping一下确保连接有效
connect_args={'charset': 'utf8mb4'}
)
# 创建会话工厂
self.Session = sessionmaker(bind=self.engine)
# 创建表(如果不存在)
Base.metadata.create_all(self.engine)
logger.info(f"成功连接到数据库 {self.db_name}")
except Exception as e:
logger.error(f"数据库初始化失败: {e}")
self.engine = None
def _init_managers(self) -> None:
"""初始化各个模型的管理器"""
if not self.engine:
logger.error("引擎未初始化,无法创建管理器")
return
try:
session = self.Session()
# 初始化各个模型的管理器
self.token_manager = TokenManager(session)
self.analysis_result_manager = AnalysisResultManager(session)
self.user_manager = UserManager(session)
self.user_question_manager = UserQuestionManager(session)
self.astock_manager = AStockManager(session)
self.analysis_history_manager = AnalysisHistoryManager(session)
logger.info("成功初始化所有模型管理器")
except Exception as e:
logger.error(f"管理器初始化失败: {e}")
if session:
session.close()
def get_session(self):
"""
获取新的数据库会话
Returns:
SQLAlchemy session对象如果初始化失败则返回None
"""
if not self.Session:
try:
self._init_db()
except Exception as e:
logger.error(f"重新初始化数据库失败: {e}")
return None
return self.Session()
def refresh_managers(self) -> bool:
"""
刷新所有管理器重新建立会话
Returns:
刷新是否成功
"""
try:
# 关闭旧会话(如果有)
self.close()
# 重新初始化数据库连接
self._init_db()
# 重新初始化管理器
self._init_managers()
return self.engine is not None
except Exception as e:
logger.error(f"刷新管理器失败: {e}")
return False
def close(self) -> None:
"""关闭数据库连接(如果存在)"""
if self.engine:
self.engine.dispose()
logger.info("数据库连接已关闭")
self.engine = None
self.Session = None
# 单例模式
_db_instance = None
def get_db_manager(host: Optional[str] = None,
port: Optional[int] = None,
user: Optional[str] = None,
password: Optional[str] = None,
db_name: Optional[str] = None) -> DBManager:
"""
获取数据库管理器实例单例模式
Args:
host: 数据库主机地址
port: 数据库端口
user: 用户名
password: 密码
db_name: 数据库名
Returns:
数据库管理器实例
使用示例:
```python
# 获取数据库管理器
db_manager = get_db_manager()
# 使用Token管理器
tokens = db_manager.token_manager.search_token("BTC")
# 使用用户管理器
user = db_manager.user_manager.get_user_by_mail("example@test.com")
```
"""
global _db_instance
# 如果已经初始化过,直接返回
if _db_instance is not None:
return _db_instance
# 如果未指定参数,从配置加载器获取数据库配置
if host is None or port is None or user is None or password is None or db_name is None:
config_loader = ConfigLoader()
db_config = config_loader.get_database_config()
# 使用配置中的值或默认值
db_host = host or db_config.get('host')
db_port = port or db_config.get('port')
db_user = user or db_config.get('user')
db_password = password or db_config.get('password')
db_name = db_name or db_config.get('db_name')
logger.info(f"从配置加载数据库连接信息: {db_host}:{db_port}/{db_name}")
else:
db_host = host
db_port = port
db_user = user
db_password = password
db_name = db_name
# 创建实例
_db_instance = DBManager(
host=db_host,
port=db_port,
user=db_user,
password=db_password,
db_name=db_name
)
return _db_instance

View File

@ -1,120 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
更新数据库表字符集为utf8mb4以支持emoji和其他特殊字符
"""
import logging
from cryptoai.utils.db_manager import get_db_manager
from sqlalchemy import text
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('update_db_charset')
def update_table_charset():
"""更新数据库表字符集为utf8mb4"""
try:
# 获取数据库管理器
db_manager = get_db_manager()
if not db_manager.engine:
logger.error("数据库连接失败")
return False
# 创建会话
session = db_manager.Session()
try:
# 更新数据库字符集
session.execute(text("ALTER DATABASE CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;"))
# 更新agent_feeds表字符集
session.execute(text("""
ALTER TABLE agent_feeds
CONVERT TO CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
"""))
# 特别更新content列的字符集
session.execute(text("""
ALTER TABLE agent_feeds
MODIFY content TEXT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
"""))
# 检查users表是否存在
result = session.execute(text("""
SELECT COUNT(*)
FROM information_schema.tables
WHERE table_schema = DATABASE()
AND table_name = 'users';
"""))
table_exists = result.scalar() > 0
# 如果users表存在更新其字符集
if table_exists:
session.execute(text("""
ALTER TABLE users
CONVERT TO CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
"""))
# 特别更新nickname和mail列的字符集
session.execute(text("""
ALTER TABLE users
MODIFY nickname VARCHAR(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
"""))
session.execute(text("""
ALTER TABLE users
MODIFY mail VARCHAR(100) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
"""))
logger.info("成功更新users表字符集为utf8mb4")
# 检查user_questions表是否存在
result = session.execute(text("""
SELECT COUNT(*)
FROM information_schema.tables
WHERE table_schema = DATABASE()
AND table_name = 'user_questions';
"""))
table_exists = result.scalar() > 0
# 如果user_questions表存在更新其字符集
if table_exists:
session.execute(text("""
ALTER TABLE user_questions
CONVERT TO CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
"""))
# 特别更新question列的字符集
session.execute(text("""
ALTER TABLE user_questions
MODIFY question TEXT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
"""))
logger.info("成功更新user_questions表字符集为utf8mb4")
session.commit()
logger.info("成功更新数据库表字符集为utf8mb4")
return True
except Exception as e:
session.rollback()
logger.error(f"更新数据库表字符集失败: {e}")
return False
finally:
session.close()
except Exception as e:
logger.error(f"更新数据库表字符集失败: {e}")
return False
if __name__ == "__main__":
update_table_charset()

View File

@ -29,7 +29,7 @@ services:
cryptoai-api: cryptoai-api:
build: . build: .
container_name: cryptoai-api container_name: cryptoai-api
image: cryptoai-api:0.1.32 image: cryptoai-api:0.1.33
restart: always restart: always
ports: ports:
- "8000:8000" - "8000:8000"

View File

@ -1,8 +1,9 @@
from cryptoai.utils.db_manager import get_db_manager
from cryptoai.api.adata_api import AStockAPI from cryptoai.api.adata_api import AStockAPI
from cryptoai.api.binance_api import get_binance_api from cryptoai.api.binance_api import get_binance_api
import json import json
from time import sleep from time import sleep
from cryptoai.models.token import TokenManager
from cryptoai.utils.db_manager import get_db_context
if __name__ == "__main__": if __name__ == "__main__":
symbols = get_binance_api().get_all_symbols() symbols = get_binance_api().get_all_symbols()
@ -11,10 +12,11 @@ if __name__ == "__main__":
# symbol = symbol.replace('USDT', '') # symbol = symbol.replace('USDT', '')
# print(symbol) # print(symbol)
manager = get_db_manager() session = get_db_context()
manager = TokenManager(session)
for symbol in symbols: for symbol in symbols:
base_asset = symbol.split('USDT')[0] base_asset = symbol.split('USDT')[0]
quote_asset = 'USDT' quote_asset = 'USDT'
manager.token_manager.create_token(symbol,base_asset, quote_asset) manager.create_token(symbol,base_asset, quote_asset)