From 08c4efcfd36a248d5d802f826cd1c3367afe66ef Mon Sep 17 00:00:00 2001 From: aaron <> Date: Fri, 30 May 2025 22:09:45 +0800 Subject: [PATCH] update --- cryptoai/agents/crypto_agent.py | 4 - cryptoai/api/adata_api.py | 9 +- cryptoai/models/analysis_history.py | 8 +- cryptoai/models/analysis_result.py | 8 +- cryptoai/models/astock.py | 7 +- cryptoai/models/token.py | 8 +- cryptoai/models/user.py | 6 +- cryptoai/models/user_question.py | 6 +- cryptoai/routes/adata.py | 7 +- cryptoai/routes/analysis.py | 26 ++- cryptoai/routes/crypto.py | 7 +- cryptoai/routes/platform.py | 16 +- cryptoai/routes/user.py | 83 ++++----- cryptoai/utils/db_manager.py | 256 +++++----------------------- cryptoai/utils/update_db_charset.py | 120 ------------- docker-compose.yml | 2 +- test.py | 8 +- 17 files changed, 158 insertions(+), 423 deletions(-) delete mode 100644 cryptoai/utils/update_db_charset.py diff --git a/cryptoai/agents/crypto_agent.py b/cryptoai/agents/crypto_agent.py index c73a994..e8d215f 100644 --- a/cryptoai/agents/crypto_agent.py +++ b/cryptoai/agents/crypto_agent.py @@ -17,7 +17,6 @@ from api.deepseek_api import DeepSeekAPI from models.data_processor import DataProcessor from utils.config_loader import ConfigLoader from utils.dingtalk_bot import DingTalkBot -from utils.db_manager import get_db_manager from utils.discord_bot import DiscordBot class CryptoAgent: @@ -79,9 +78,6 @@ class CryptoAgent: ) print("Discord机器人已启用") - # 初始化数据库管理器 - self.db_manager = get_db_manager() - # 设置支持的加密货币 self.base_currencies = self.crypto_config['base_currencies'] self.quote_currency = self.crypto_config['quote_currency'] diff --git a/cryptoai/api/adata_api.py b/cryptoai/api/adata_api.py index 9726076..20f0c5a 100644 --- a/cryptoai/api/adata_api.py +++ b/cryptoai/api/adata_api.py @@ -7,7 +7,7 @@ import adata from typing import Dict, List, Optional import pandas as pd from datetime import datetime, timedelta - +from cryptoai.utils.db_manager import get_db_context class AStockAPI: @staticmethod def get_all_stock_codes() -> List[str]: @@ -156,8 +156,11 @@ if __name__ == "__main__": print(list) # 保存到数据库 - import cryptoai.utils.db_manager as db_manager - db_manager.get_db_manager().create_stocks(list) + from cryptoai.models.astock import AStockManager + + session = get_db_context() + manager = AStockManager(session) + manager.create_stocks(list) diff --git a/cryptoai/models/analysis_history.py b/cryptoai/models/analysis_history.py index 57cc1d0..d3bafc4 100644 --- a/cryptoai/models/analysis_history.py +++ b/cryptoai/models/analysis_history.py @@ -3,12 +3,12 @@ from typing import Dict, Any, List, Optional from datetime import datetime - +from sqlalchemy.orm import Session from sqlalchemy import Column, Integer, String, Text, DateTime, Index, ForeignKey from sqlalchemy.orm import relationship from cryptoai.models.base import Base, logger - +from cryptoai.utils.db_manager import get_db_context # 定义分析历史模型 class AnalysisHistory(Base): """分析历史表模型""" @@ -37,8 +37,8 @@ class AnalysisHistory(Base): class AnalysisHistoryManager: """分析历史管理类""" - def __init__(self, db_session): - self.session = db_session + def __init__(self, session: Session = None): + self.session = session def add_analysis_history(self, user_id: int, type: str, symbol: str, content: str, timeframe: str = None) -> bool: diff --git a/cryptoai/models/analysis_result.py b/cryptoai/models/analysis_result.py index b522e85..3ff956d 100644 --- a/cryptoai/models/analysis_result.py +++ b/cryptoai/models/analysis_result.py @@ -3,12 +3,12 @@ from typing import Dict, Any, List, Optional from datetime import datetime - +from sqlalchemy.orm import Session from sqlalchemy import Column, Integer, String, DateTime, Index from sqlalchemy.dialects.mysql import JSON from cryptoai.models.base import Base, logger - +from cryptoai.utils.db_manager import get_db_context # 定义分析结果模型 class AnalysisResult(Base): """分析结果表模型""" @@ -33,8 +33,8 @@ class AnalysisResult(Base): class AnalysisResultManager: """分析结果管理类""" - def __init__(self, db_session): - self.session = db_session + def __init__(self, session: Session = None): + self.session = session def save_analysis_result(self, agent: str, symbol: str, time_interval: str, analysis_result: Dict[str, Any]) -> bool: diff --git a/cryptoai/models/astock.py b/cryptoai/models/astock.py index 90e179e..631e89e 100644 --- a/cryptoai/models/astock.py +++ b/cryptoai/models/astock.py @@ -5,9 +5,10 @@ from typing import Dict, Any, List, Optional, Union from datetime import datetime from sqlalchemy import Column, Integer, String, DateTime, Index - +from sqlalchemy.orm import Session from cryptoai.models.base import Base, logger from cryptoai.utils.db_utils import convert_timestamp_to_datetime +from cryptoai.utils.db_manager import get_db_context # 定义 A 股数据模型 class AStock(Base): @@ -29,8 +30,8 @@ class AStock(Base): class AStockManager: """A股管理类""" - def __init__(self, db_session): - self.session = db_session + def __init__(self, session: Session = None): + self.session = session def create_stocks(self, stocks: List[Dict[str, Any]]) -> bool: """ diff --git a/cryptoai/models/token.py b/cryptoai/models/token.py index 45df755..66df959 100644 --- a/cryptoai/models/token.py +++ b/cryptoai/models/token.py @@ -3,11 +3,11 @@ from typing import Dict, Any, List, Optional from datetime import datetime - +from sqlalchemy.orm import Session from sqlalchemy import Column, Integer, String, DateTime, Index from cryptoai.models.base import Base, logger - +from cryptoai.utils.db_manager import get_db_context # 定义Token模型 class Token(Base): """Token信息表模型""" @@ -30,8 +30,8 @@ class Token(Base): class TokenManager: """Token管理类""" - def __init__(self, db_session): - self.session = db_session + def __init__(self, session: Session = None): + self.session = session def create_token(self, symbol: str, base_asset: str, quote_asset: str) -> bool: """ diff --git a/cryptoai/models/user.py b/cryptoai/models/user.py index 1087113..7229c26 100644 --- a/cryptoai/models/user.py +++ b/cryptoai/models/user.py @@ -6,7 +6,7 @@ from datetime import datetime from fastapi import HTTPException, status from sqlalchemy import Column, Integer, String, DateTime, Index from sqlalchemy.orm import relationship - +from sqlalchemy.orm import Session from cryptoai.models.base import Base, logger # 定义用户数据模型 @@ -37,8 +37,8 @@ class User(Base): class UserManager: """用户管理类""" - def __init__(self, db_session): - self.session = db_session + def __init__(self, session: Session = None): + self.session = session def register_user(self, mail: str, nickname: str, password: str, level: int = 0, points: int = 0) -> bool: """ diff --git a/cryptoai/models/user_question.py b/cryptoai/models/user_question.py index 02c6b12..6793c2f 100644 --- a/cryptoai/models/user_question.py +++ b/cryptoai/models/user_question.py @@ -3,7 +3,7 @@ from typing import Dict, Any, List, Optional from datetime import datetime - +from sqlalchemy.orm import Session from sqlalchemy import Column, Integer, String, Text, DateTime, Index, ForeignKey from sqlalchemy.orm import relationship @@ -35,8 +35,8 @@ class UserQuestion(Base): class UserQuestionManager: """用户提问管理类""" - def __init__(self, db_session): - self.session = db_session + def __init__(self, session: Session = None): + self.session = session def save_user_question(self, user_id: int, agent_id: str, question: str) -> bool: """ diff --git a/cryptoai/routes/adata.py b/cryptoai/routes/adata.py index cbd6f6f..927c329 100644 --- a/cryptoai/routes/adata.py +++ b/cryptoai/routes/adata.py @@ -2,7 +2,6 @@ import json import logging from fastapi import APIRouter, Depends, HTTPException, status, Body, Query, Path from cryptoai.api.adata_api import AStockAPI -from cryptoai.utils.db_manager import get_db_manager from datetime import datetime from typing import Dict, Any, List, Optional from pydantic import BaseModel @@ -12,7 +11,7 @@ from cryptoai.utils.config_loader import ConfigLoader from fastapi.responses import StreamingResponse from cryptoai.routes.user import get_current_user import requests - +from cryptoai.models.astock import AStockManager # 创建路由 router = APIRouter() @@ -21,8 +20,8 @@ logger.setLevel(logging.DEBUG) @router.get("/stock/search") async def search_stock(key: str, limit: int = 10): - manager = get_db_manager() - result = manager.astock_manager.search_stock(key, limit) + manager = AStockManager() + result = manager.search_stock(key, limit) return result diff --git a/cryptoai/routes/analysis.py b/cryptoai/routes/analysis.py index f9240fe..93aacda 100644 --- a/cryptoai/routes/analysis.py +++ b/cryptoai/routes/analysis.py @@ -1,6 +1,5 @@ from fastapi import APIRouter from typing import Optional -from cryptoai.utils.db_manager import get_db_manager from fastapi import Depends from pydantic import BaseModel from cryptoai.routes.user import get_current_user @@ -8,7 +7,9 @@ from fastapi import HTTPException from fastapi.responses import StreamingResponse import requests from datetime import date, timedelta - +from cryptoai.models.analysis_history import AnalysisHistoryManager +from cryptoai.models.user_question import UserQuestionManager +from cryptoai.models.token import TokenManager class AnalysisHistoryRequest(BaseModel): symbol: str content: str @@ -21,7 +22,8 @@ router = APIRouter() async def analysis_history(request: AnalysisHistoryRequest, current_user: dict = Depends(get_current_user)): - get_db_manager().analysis_history_manager.add_analysis_history(current_user["id"], request.type, request.symbol, request.content, request.timeframe) + manager = AnalysisHistoryManager() + manager.add_analysis_history(current_user["id"], request.type, request.symbol, request.content, request.timeframe) return {"message": "ok"} @@ -29,7 +31,8 @@ async def analysis_history(request: AnalysisHistoryRequest, async def get_analysis_histories(current_user: dict = Depends(get_current_user), limit: int = 10, offset: int = 0): - history = get_db_manager().analysis_history_manager.get_user_analysis_history(current_user["id"], limit=limit, offset=offset) + manager = AnalysisHistoryManager() + history = manager.get_user_analysis_history(current_user["id"], limit=limit, offset=offset) return history @@ -65,7 +68,8 @@ async def chat(request: ChatRequest, 'Content-Type': 'application/json' } - get_db_manager().user_question_manager.save_user_question(current_user["id"],"chat-messages", request.message) + manager = UserQuestionManager() + manager.save_user_question(current_user["id"],"chat-messages", request.message) response = requests.post(url, headers=headers, json=payload, stream=True) @@ -90,7 +94,8 @@ async def analysis(request: AnalysisRequest, if request.type == 'crypto': # 检查symbol是否存在 - tokens = get_db_manager().token_manager.search_token(request.symbol) + manager = TokenManager() + tokens = manager.search_token(request.symbol) if not tokens or len(tokens) == 0: raise HTTPException(status_code=400, detail="您输入的币种在币安不存在,请检查后重新输入。") @@ -106,7 +111,8 @@ async def analysis(request: AnalysisRequest, "user": current_user["mail"] } - get_db_manager().user_question_manager.save_user_question(current_user["id"], symbol, "请分析以下加密货币:" + symbol + ",并给出分析报告。") + manager = UserQuestionManager() + manager.save_user_question(current_user["id"], symbol, "请分析以下加密货币:" + symbol + ",并给出分析报告。") elif request.type == 'astock': @@ -121,7 +127,8 @@ async def analysis(request: AnalysisRequest, "user": current_user["mail"] } - get_db_manager().user_question_manager.save_user_question(current_user["id"], stock_code, "请分析以下A股股票:" + stock_code + ",并给出分析报告。") + manager = UserQuestionManager() + manager.save_user_question(current_user["id"], stock_code, "请分析以下A股股票:" + stock_code + ",并给出分析报告。") elif request.type == 'usstock': stock_code = request.stock_code @@ -136,7 +143,8 @@ async def analysis(request: AnalysisRequest, "response_mode": "streaming", "user": current_user["mail"] } - get_db_manager().user_question_manager.save_user_question(current_user["id"], stock_code, "请分析以下美股股票:" + stock_code + ",并给出分析报告。") + manager = UserQuestionManager() + manager.save_user_question(current_user["id"], stock_code, "请分析以下美股股票:" + stock_code + ",并给出分析报告。") else: raise HTTPException(status_code=400, detail="不支持的类型") diff --git a/cryptoai/routes/crypto.py b/cryptoai/routes/crypto.py index 28191c5..e677654 100644 --- a/cryptoai/routes/crypto.py +++ b/cryptoai/routes/crypto.py @@ -2,7 +2,6 @@ import json import logging from fastapi import APIRouter, Depends, HTTPException, status, Body, Query, Path from cryptoai.api.adata_api import AStockAPI -from cryptoai.utils.db_manager import get_db_manager from datetime import datetime from typing import Dict, Any, List, Optional from pydantic import BaseModel @@ -15,6 +14,7 @@ import requests from cryptoai.api.binance_api import get_binance_api from cryptoai.models.data_processor import DataProcessor from datetime import timedelta +from cryptoai.models.token import TokenManager # 创建路由 router = APIRouter() @@ -23,7 +23,7 @@ logger.setLevel(logging.DEBUG) @router.get("/search/{key}") async def search_crypto(key: str): - manager = get_db_manager() + manager = TokenManager() result = manager.search_token(key) return result @@ -34,7 +34,8 @@ class CryptoAnalysisRequest(BaseModel): @router.get("/kline/{symbol}") async def get_crypto_kline(symbol: str, timeframe: Optional[str] = None, limit: Optional[int] = 100): # 检查symbol是否存在 - tokens = get_db_manager().token_manager.search_token(symbol) + token_manager = TokenManager() + tokens = token_manager.search_token(symbol) if not tokens or len(tokens) == 0: raise HTTPException(status_code=400, detail="您输入的币种在币安不存在,请检查后重新输入。") diff --git a/cryptoai/routes/platform.py b/cryptoai/routes/platform.py index 516648b..ef0ddac 100644 --- a/cryptoai/routes/platform.py +++ b/cryptoai/routes/platform.py @@ -1,6 +1,10 @@ from fastapi import APIRouter -from cryptoai.utils.db_manager import get_db_manager import logging +from cryptoai.models.user import UserManager +from cryptoai.models.user_question import UserQuestionManager +from sqlalchemy.orm import Session +from fastapi import Depends +from cryptoai.utils.db_manager import get_db router = APIRouter() @@ -8,14 +12,16 @@ logger = logging.getLogger("platform_router") logger.setLevel(logging.INFO) @router.get("/info") -async def get_platform_info(): - db_manager = get_db_manager() +async def get_platform_info(session: Session = Depends(get_db)): + + user_manager = UserManager(session) + question_manager = UserQuestionManager(session) result = {} try: - result["user_count"] = db_manager.user_manager.get_user_count() - result["question_count"] = db_manager.user_question_manager.get_user_question_count() + result["user_count"] = user_manager.get_user_count() + result["question_count"] = question_manager.get_user_question_count() return result except Exception as e: diff --git a/cryptoai/routes/user.py b/cryptoai/routes/user.py index 743b873..276f61f 100644 --- a/cryptoai/routes/user.py +++ b/cryptoai/routes/user.py @@ -16,9 +16,10 @@ from datetime import datetime, timedelta import jwt from jwt.exceptions import PyJWTError from fastapi import Request - -from cryptoai.utils.db_manager import get_db_manager +from sqlalchemy.orm import Session +from cryptoai.utils.db_manager import get_db from cryptoai.utils.email_service import get_email_service +from cryptoai.models.user import UserManager # 配置日志 logger = logging.getLogger("user_router") @@ -87,7 +88,7 @@ def create_access_token(data: Dict[str, Any], expires_delta: timedelta = None) - encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM) return encoded_jwt -async def get_current_user(request: Request) -> Dict[str, Any]: +async def get_current_user(request: Request, session: Session = Depends(get_db)) -> Dict[str, Any]: """获取当前用户""" credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -109,8 +110,8 @@ async def get_current_user(request: Request) -> Dict[str, Any]: print(f"PyJWTError: {e}") raise credentials_exception - db_manager = get_db_manager() - user = db_manager.user_manager.get_user_by_mail(mail) + manager = UserManager(session) + user = manager.get_user_by_mail(mail) if user is None: raise credentials_exception return user @@ -154,13 +155,13 @@ async def send_verification_code(request: SendVerificationCodeRequest) -> Dict[s ) @router.put("/reset_password", response_model=Dict[str, Any]) -async def reset_password(request: ResetPasswordRequest) -> Dict[str, Any]: +async def reset_password(request: ResetPasswordRequest, session: Session = Depends(get_db)) -> Dict[str, Any]: """ 修改密码 """ try: # 获取数据库管理器 - db_manager = get_db_manager() + # 验证验证码 email_service = get_email_service() @@ -173,7 +174,8 @@ async def reset_password(request: ResetPasswordRequest) -> Dict[str, Any]: # 更新密码 hashed_password = hash_password(request.new_password) - success = db_manager.user_manager.update_password(request.mail, hashed_password) + manager = UserManager(session) + success = manager.update_password(request.mail, hashed_password) if not success: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -193,7 +195,7 @@ async def reset_password(request: ResetPasswordRequest) -> Dict[str, Any]: @router.post("/register", response_model=Dict[str, Any], status_code=status.HTTP_201_CREATED) -async def register_user(user: UserRegister) -> Dict[str, Any]: +async def register_user(user: UserRegister, session: Session = Depends(get_db)) -> Dict[str, Any]: """ 注册新用户 @@ -215,13 +217,13 @@ async def register_user(user: UserRegister) -> Dict[str, Any]: ) # 获取数据库管理器 - db_manager = get_db_manager() + manager = UserManager(session) # 对密码进行哈希处理 hashed_password = hash_password(user.password) # 注册用户 - success = db_manager.user_manager.register_user( + success = manager.register_user( mail=user.mail, nickname=user.nickname, password=hashed_password, @@ -250,7 +252,7 @@ async def register_user(user: UserRegister) -> Dict[str, Any]: ) @router.post("/login", response_model=TokenResponse) -async def login(loginData: UserLogin) -> TokenResponse: +async def login(loginData: UserLogin, session: Session = Depends(get_db)) -> TokenResponse: """ 用户登录 @@ -262,10 +264,10 @@ async def login(loginData: UserLogin) -> TokenResponse: """ try: # 获取数据库管理器 - db_manager = get_db_manager() + manager = UserManager(session) # 获取用户信息 - user = db_manager.user_manager.get_user_by_mail(loginData.mail) + user = manager.get_user_by_mail(loginData.mail) if not user: raise HTTPException( @@ -278,17 +280,14 @@ async def login(loginData: UserLogin) -> TokenResponse: hashed_password = hash_password(loginData.password) # 查询用户的密码哈希 - session = db_manager.Session() - try: - user = db_manager.user_manager.get_user_by_mail_and_password(loginData.mail, hashed_password) - if not user: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="邮箱或密码错误", - headers={"WWW-Authenticate": "Bearer"}, - ) - finally: - session.close() + user = manager.get_user_by_mail_and_password(loginData.mail, hashed_password) + if not user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="邮箱或密码错误", + headers={"WWW-Authenticate": "Bearer"}, + ) + # 创建访问令牌,不过期 access_token = create_access_token(data={"sub": user["mail"]}) @@ -317,7 +316,7 @@ async def login(loginData: UserLogin) -> TokenResponse: ) @router.get("/me", response_model=UserResponse) -async def get_user_info(current_user: Dict[str, Any] = Depends(get_current_user)) -> UserResponse: +async def get_user_info(current_user: Dict[str, Any] = Depends(get_current_user), session: Session = Depends(get_db)) -> UserResponse: """ 获取当前登录用户信息 @@ -340,7 +339,8 @@ async def get_user_info(current_user: Dict[str, Any] = Depends(get_current_user) async def update_user_level( user_id: int, level: int = Query(..., description="用户级别(0=普通用户,1=VIP,2=SVIP)"), - current_user: Dict[str, Any] = Depends(get_current_user) + current_user: Dict[str, Any] = Depends(get_current_user), + session: Session = Depends(get_db) ) -> Dict[str, Any]: """ 更新用户级别(需要管理员权限) @@ -361,10 +361,10 @@ async def update_user_level( ) # 获取数据库管理器 - db_manager = get_db_manager() + manager = UserManager(session) # 更新用户级别 - success = db_manager.user_manager.update_user_level(user_id, level) + success = manager.update_user_level(user_id, level) if not success: raise HTTPException( @@ -380,7 +380,8 @@ async def update_user_level( @router.get("/points/{user_id}", response_model=Dict[str, Any]) async def get_user_points( user_id: int, - current_user: Dict[str, Any] = Depends(get_current_user) + current_user: Dict[str, Any] = Depends(get_current_user), + session: Session = Depends(get_db) ) -> Dict[str, Any]: """ 获取用户积分 @@ -400,10 +401,10 @@ async def get_user_points( ) # 获取数据库管理器 - db_manager = get_db_manager() + manager = UserManager(session) # 获取用户信息 - user = db_manager.user_manager.get_user_by_id(user_id) + user = manager.get_user_by_id(user_id) if not user: raise HTTPException( @@ -422,7 +423,8 @@ async def get_user_points( async def add_user_points( user_id: int, points: int = Query(..., gt=0, description="增加的积分数量(必须大于0)"), - current_user: Dict[str, Any] = Depends(get_current_user) + current_user: Dict[str, Any] = Depends(get_current_user), + session: Session = Depends(get_db) ) -> Dict[str, Any]: """ 为用户增加积分(需要管理员权限) @@ -443,10 +445,10 @@ async def add_user_points( ) # 获取数据库管理器 - db_manager = get_db_manager() + manager = UserManager(session) # 添加积分 - success = db_manager.user_manager.add_user_points(user_id, points) + success = manager.add_user_points(user_id, points) if not success: raise HTTPException( @@ -455,7 +457,7 @@ async def add_user_points( ) # 获取更新后的用户信息 - user = db_manager.user_manager.get_user_by_id(user_id) + user = manager.get_user_by_id(user_id) return { "status": "success", @@ -467,7 +469,8 @@ async def add_user_points( async def consume_user_points( user_id: int, points: int = Query(..., gt=0, description="消费的积分数量(必须大于0)"), - current_user: Dict[str, Any] = Depends(get_current_user) + current_user: Dict[str, Any] = Depends(get_current_user), + session: Session = Depends(get_db) ) -> Dict[str, Any]: """ 用户消费积分 @@ -488,10 +491,10 @@ async def consume_user_points( ) # 获取数据库管理器 - db_manager = get_db_manager() + manager = UserManager(session) # 消费积分 - success = db_manager.user_manager.consume_user_points(user_id, points) + success = manager.consume_user_points(user_id, points) if not success: raise HTTPException( @@ -500,7 +503,7 @@ async def consume_user_points( ) # 获取更新后的用户信息 - user = db_manager.user_manager.get_user_by_id(user_id) + user = manager.get_user_by_id(user_id) return { "status": "success", diff --git a/cryptoai/utils/db_manager.py b/cryptoai/utils/db_manager.py index a79c071..3e982e2 100644 --- a/cryptoai/utils/db_manager.py +++ b/cryptoai/utils/db_manager.py @@ -7,17 +7,11 @@ from typing import Dict, Any, List, Optional, Union from datetime import datetime from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import sessionmaker, scoped_session from sqlalchemy.pool import QueuePool from cryptoai.utils.config_loader import ConfigLoader -from cryptoai.models.base import Base -from cryptoai.models.token import TokenManager -from cryptoai.models.analysis_result import AnalysisResultManager -from cryptoai.models.user import UserManager -from cryptoai.models.user_question import UserQuestionManager -from cryptoai.models.astock import AStockManager -from cryptoai.models.analysis_history import AnalysisHistoryManager +from sqlalchemy.ext.declarative import declarative_base # 配置日志 logging.basicConfig( @@ -26,207 +20,49 @@ logging.basicConfig( ) logger = logging.getLogger('db_manager') -class DBManager: - """ - 数据库管理工具,用于连接MySQL数据库并提供各个模型的管理器 - - 使用方法: - - 调用 get_db_manager() 获取数据库管理器实例 - - 使用 db_manager.token_manager.xxx() 直接访问各个模型管理器的方法 - - 使用 db_manager.get_session() 可以获取一个新的数据库会话 - """ - - def __init__(self, host: str, port: int, user: str, password: str, db_name: str): - """ - 初始化数据库管理器 - - Args: - host: 数据库主机地址 - port: 数据库端口 - user: 用户名 - password: 密码 - db_name: 数据库名 - """ - self.host = host - self.port = port - self.user = user - self.password = password - self.db_name = db_name - self.engine = None - self.Session = None - - # 初始化数据库连接 - self._init_db() - - # 初始化各个管理器 - self._init_managers() - - def _init_db(self) -> None: - """初始化数据库连接和表""" - try: - # 创建数据库连接 - connection_string = f"mysql+pymysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.db_name}?charset=utf8mb4" - - # 创建引擎,设置连接池 - self.engine = create_engine( - connection_string, - echo=False, # 设置为True可以输出SQL语句(调试用) - pool_size=5, # 连接池大小 - max_overflow=10, # 最大溢出连接数 - pool_timeout=30, # 连接超时时间 - pool_recycle=1800, # 连接回收时间(秒) - pool_pre_ping=True, # 在使用连接前先ping一下,确保连接有效 - connect_args={'charset': 'utf8mb4'} - ) - - # 创建会话工厂 - self.Session = sessionmaker(bind=self.engine) - - # 创建表(如果不存在) - Base.metadata.create_all(self.engine) - - logger.info(f"成功连接到数据库 {self.db_name}") - - except Exception as e: - logger.error(f"数据库初始化失败: {e}") - self.engine = None - - def _init_managers(self) -> None: - """初始化各个模型的管理器""" - if not self.engine: - logger.error("引擎未初始化,无法创建管理器") - return - - try: - session = self.Session() - - # 初始化各个模型的管理器 - self.token_manager = TokenManager(session) - self.analysis_result_manager = AnalysisResultManager(session) - self.user_manager = UserManager(session) - self.user_question_manager = UserQuestionManager(session) - self.astock_manager = AStockManager(session) - self.analysis_history_manager = AnalysisHistoryManager(session) - - logger.info("成功初始化所有模型管理器") - - except Exception as e: - logger.error(f"管理器初始化失败: {e}") - if session: - session.close() - - def get_session(self): - """ - 获取新的数据库会话 - - Returns: - SQLAlchemy session对象,如果初始化失败则返回None - """ - if not self.Session: - try: - self._init_db() - except Exception as e: - logger.error(f"重新初始化数据库失败: {e}") - return None - - return self.Session() - - def refresh_managers(self) -> bool: - """ - 刷新所有管理器,重新建立会话 - - Returns: - 刷新是否成功 - """ - try: - # 关闭旧会话(如果有) - self.close() - - # 重新初始化数据库连接 - self._init_db() - - # 重新初始化管理器 - self._init_managers() - - return self.engine is not None - except Exception as e: - logger.error(f"刷新管理器失败: {e}") - return False - - def close(self) -> None: - """关闭数据库连接(如果存在)""" - if self.engine: - self.engine.dispose() - logger.info("数据库连接已关闭") - self.engine = None - self.Session = None -# 单例模式 -_db_instance = None +config_loader = ConfigLoader() +db_config = config_loader.get_database_config() +engine = create_engine(f"mysql+pymysql://{db_config['user']}:{db_config['password']}@{db_config['host']}:{db_config['port']}/{db_config['db_name']}?charset=utf8mb4", + echo=False, # 设置为True可以输出SQL语句(调试用) + pool_size=5, # 连接池大小 + max_overflow=10, # 最大溢出连接数 + pool_timeout=30, # 连接超时时间 + pool_recycle=1800, # 连接回收时间(秒) + pool_pre_ping=True, # 在使用连接前先ping一下,确保连接有效 + connect_args={'charset': 'utf8mb4'}) + +# 声明基类 +Base = declarative_base() +Base.metadata.create_all(bind=engine) + +# 创建线程安全的会话工厂 +SessionLocal = scoped_session( + sessionmaker(autocommit=False, autoflush=False, bind=engine) +) + +def get_db(): + db = SessionLocal() + try: + yield db + finally: + if db: + db.close() + +def get_db_context(): + try: + db = SessionLocal() + yield db + db.commit() + except Exception as e: + if db: + db.rollback() + raise e + finally: + if db: + db.close() + + + + -def get_db_manager(host: Optional[str] = None, - port: Optional[int] = None, - user: Optional[str] = None, - password: Optional[str] = None, - db_name: Optional[str] = None) -> DBManager: - """ - 获取数据库管理器实例(单例模式) - - Args: - host: 数据库主机地址 - port: 数据库端口 - user: 用户名 - password: 密码 - db_name: 数据库名 - - Returns: - 数据库管理器实例 - - 使用示例: - ```python - # 获取数据库管理器 - db_manager = get_db_manager() - - # 使用Token管理器 - tokens = db_manager.token_manager.search_token("BTC") - - # 使用用户管理器 - user = db_manager.user_manager.get_user_by_mail("example@test.com") - ``` - """ - global _db_instance - - # 如果已经初始化过,直接返回 - if _db_instance is not None: - return _db_instance - - # 如果未指定参数,从配置加载器获取数据库配置 - if host is None or port is None or user is None or password is None or db_name is None: - config_loader = ConfigLoader() - db_config = config_loader.get_database_config() - - # 使用配置中的值或默认值 - db_host = host or db_config.get('host') - db_port = port or db_config.get('port') - db_user = user or db_config.get('user') - db_password = password or db_config.get('password') - db_name = db_name or db_config.get('db_name') - - logger.info(f"从配置加载数据库连接信息: {db_host}:{db_port}/{db_name}") - else: - db_host = host - db_port = port - db_user = user - db_password = password - db_name = db_name - - # 创建实例 - _db_instance = DBManager( - host=db_host, - port=db_port, - user=db_user, - password=db_password, - db_name=db_name - ) - - return _db_instance \ No newline at end of file diff --git a/cryptoai/utils/update_db_charset.py b/cryptoai/utils/update_db_charset.py deleted file mode 100644 index 2397686..0000000 --- a/cryptoai/utils/update_db_charset.py +++ /dev/null @@ -1,120 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -""" -更新数据库表字符集为utf8mb4,以支持emoji和其他特殊字符 -""" - -import logging -from cryptoai.utils.db_manager import get_db_manager -from sqlalchemy import text - -# 配置日志 -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger('update_db_charset') - -def update_table_charset(): - """更新数据库表字符集为utf8mb4""" - try: - # 获取数据库管理器 - db_manager = get_db_manager() - - if not db_manager.engine: - logger.error("数据库连接失败") - return False - - # 创建会话 - session = db_manager.Session() - - try: - # 更新数据库字符集 - session.execute(text("ALTER DATABASE CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;")) - - # 更新agent_feeds表字符集 - session.execute(text(""" - ALTER TABLE agent_feeds - CONVERT TO CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; - """)) - - # 特别更新content列的字符集 - session.execute(text(""" - ALTER TABLE agent_feeds - MODIFY content TEXT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; - """)) - - # 检查users表是否存在 - result = session.execute(text(""" - SELECT COUNT(*) - FROM information_schema.tables - WHERE table_schema = DATABASE() - AND table_name = 'users'; - """)) - - table_exists = result.scalar() > 0 - - # 如果users表存在,更新其字符集 - if table_exists: - session.execute(text(""" - ALTER TABLE users - CONVERT TO CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; - """)) - - # 特别更新nickname和mail列的字符集 - session.execute(text(""" - ALTER TABLE users - MODIFY nickname VARCHAR(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; - """)) - - session.execute(text(""" - ALTER TABLE users - MODIFY mail VARCHAR(100) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; - """)) - - logger.info("成功更新users表字符集为utf8mb4") - - # 检查user_questions表是否存在 - result = session.execute(text(""" - SELECT COUNT(*) - FROM information_schema.tables - WHERE table_schema = DATABASE() - AND table_name = 'user_questions'; - """)) - - table_exists = result.scalar() > 0 - - # 如果user_questions表存在,更新其字符集 - if table_exists: - session.execute(text(""" - ALTER TABLE user_questions - CONVERT TO CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; - """)) - - # 特别更新question列的字符集 - session.execute(text(""" - ALTER TABLE user_questions - MODIFY question TEXT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; - """)) - - logger.info("成功更新user_questions表字符集为utf8mb4") - - session.commit() - logger.info("成功更新数据库表字符集为utf8mb4") - return True - - except Exception as e: - session.rollback() - logger.error(f"更新数据库表字符集失败: {e}") - return False - - finally: - session.close() - - except Exception as e: - logger.error(f"更新数据库表字符集失败: {e}") - return False - -if __name__ == "__main__": - update_table_charset() \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml index 4a02abf..5c8daf3 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -29,7 +29,7 @@ services: cryptoai-api: build: . container_name: cryptoai-api - image: cryptoai-api:0.1.32 + image: cryptoai-api:0.1.33 restart: always ports: - "8000:8000" diff --git a/test.py b/test.py index 05ed715..ffb9a7f 100644 --- a/test.py +++ b/test.py @@ -1,8 +1,9 @@ -from cryptoai.utils.db_manager import get_db_manager from cryptoai.api.adata_api import AStockAPI from cryptoai.api.binance_api import get_binance_api import json from time import sleep +from cryptoai.models.token import TokenManager +from cryptoai.utils.db_manager import get_db_context if __name__ == "__main__": symbols = get_binance_api().get_all_symbols() @@ -11,10 +12,11 @@ if __name__ == "__main__": # symbol = symbol.replace('USDT', '') # print(symbol) - manager = get_db_manager() + session = get_db_context() + manager = TokenManager(session) for symbol in symbols: base_asset = symbol.split('USDT')[0] quote_asset = 'USDT' - manager.token_manager.create_token(symbol,base_asset, quote_asset) + manager.create_token(symbol,base_asset, quote_asset)