update
This commit is contained in:
parent
4027f68dca
commit
08c4efcfd3
@ -17,7 +17,6 @@ from api.deepseek_api import DeepSeekAPI
|
||||
from models.data_processor import DataProcessor
|
||||
from utils.config_loader import ConfigLoader
|
||||
from utils.dingtalk_bot import DingTalkBot
|
||||
from utils.db_manager import get_db_manager
|
||||
from utils.discord_bot import DiscordBot
|
||||
|
||||
class CryptoAgent:
|
||||
@ -79,9 +78,6 @@ class CryptoAgent:
|
||||
)
|
||||
print("Discord机器人已启用")
|
||||
|
||||
# 初始化数据库管理器
|
||||
self.db_manager = get_db_manager()
|
||||
|
||||
# 设置支持的加密货币
|
||||
self.base_currencies = self.crypto_config['base_currencies']
|
||||
self.quote_currency = self.crypto_config['quote_currency']
|
||||
|
||||
@ -7,7 +7,7 @@ import adata
|
||||
from typing import Dict, List, Optional
|
||||
import pandas as pd
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from cryptoai.utils.db_manager import get_db_context
|
||||
class AStockAPI:
|
||||
@staticmethod
|
||||
def get_all_stock_codes() -> List[str]:
|
||||
@ -156,8 +156,11 @@ if __name__ == "__main__":
|
||||
print(list)
|
||||
|
||||
# 保存到数据库
|
||||
import cryptoai.utils.db_manager as db_manager
|
||||
db_manager.get_db_manager().create_stocks(list)
|
||||
from cryptoai.models.astock import AStockManager
|
||||
|
||||
session = get_db_context()
|
||||
manager = AStockManager(session)
|
||||
manager.create_stocks(list)
|
||||
|
||||
|
||||
|
||||
|
||||
@ -3,12 +3,12 @@
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, Index, ForeignKey
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from cryptoai.models.base import Base, logger
|
||||
|
||||
from cryptoai.utils.db_manager import get_db_context
|
||||
# 定义分析历史模型
|
||||
class AnalysisHistory(Base):
|
||||
"""分析历史表模型"""
|
||||
@ -37,8 +37,8 @@ class AnalysisHistory(Base):
|
||||
class AnalysisHistoryManager:
|
||||
"""分析历史管理类"""
|
||||
|
||||
def __init__(self, db_session):
|
||||
self.session = db_session
|
||||
def __init__(self, session: Session = None):
|
||||
self.session = session
|
||||
|
||||
def add_analysis_history(self, user_id: int, type: str, symbol: str,
|
||||
content: str, timeframe: str = None) -> bool:
|
||||
|
||||
@ -3,12 +3,12 @@
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Index
|
||||
from sqlalchemy.dialects.mysql import JSON
|
||||
|
||||
from cryptoai.models.base import Base, logger
|
||||
|
||||
from cryptoai.utils.db_manager import get_db_context
|
||||
# 定义分析结果模型
|
||||
class AnalysisResult(Base):
|
||||
"""分析结果表模型"""
|
||||
@ -33,8 +33,8 @@ class AnalysisResult(Base):
|
||||
class AnalysisResultManager:
|
||||
"""分析结果管理类"""
|
||||
|
||||
def __init__(self, db_session):
|
||||
self.session = db_session
|
||||
def __init__(self, session: Session = None):
|
||||
self.session = session
|
||||
|
||||
def save_analysis_result(self, agent: str, symbol: str, time_interval: str,
|
||||
analysis_result: Dict[str, Any]) -> bool:
|
||||
|
||||
@ -5,9 +5,10 @@ from typing import Dict, Any, List, Optional, Union
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Index
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from cryptoai.models.base import Base, logger
|
||||
from cryptoai.utils.db_utils import convert_timestamp_to_datetime
|
||||
from cryptoai.utils.db_manager import get_db_context
|
||||
|
||||
# 定义 A 股数据模型
|
||||
class AStock(Base):
|
||||
@ -29,8 +30,8 @@ class AStock(Base):
|
||||
class AStockManager:
|
||||
"""A股管理类"""
|
||||
|
||||
def __init__(self, db_session):
|
||||
self.session = db_session
|
||||
def __init__(self, session: Session = None):
|
||||
self.session = session
|
||||
|
||||
def create_stocks(self, stocks: List[Dict[str, Any]]) -> bool:
|
||||
"""
|
||||
|
||||
@ -3,11 +3,11 @@
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Index
|
||||
|
||||
from cryptoai.models.base import Base, logger
|
||||
|
||||
from cryptoai.utils.db_manager import get_db_context
|
||||
# 定义Token模型
|
||||
class Token(Base):
|
||||
"""Token信息表模型"""
|
||||
@ -30,8 +30,8 @@ class Token(Base):
|
||||
class TokenManager:
|
||||
"""Token管理类"""
|
||||
|
||||
def __init__(self, db_session):
|
||||
self.session = db_session
|
||||
def __init__(self, session: Session = None):
|
||||
self.session = session
|
||||
|
||||
def create_token(self, symbol: str, base_asset: str, quote_asset: str) -> bool:
|
||||
"""
|
||||
|
||||
@ -6,7 +6,7 @@ from datetime import datetime
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Index
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from cryptoai.models.base import Base, logger
|
||||
|
||||
# 定义用户数据模型
|
||||
@ -37,8 +37,8 @@ class User(Base):
|
||||
class UserManager:
|
||||
"""用户管理类"""
|
||||
|
||||
def __init__(self, db_session):
|
||||
self.session = db_session
|
||||
def __init__(self, session: Session = None):
|
||||
self.session = session
|
||||
|
||||
def register_user(self, mail: str, nickname: str, password: str, level: int = 0, points: int = 0) -> bool:
|
||||
"""
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, Index, ForeignKey
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
@ -35,8 +35,8 @@ class UserQuestion(Base):
|
||||
class UserQuestionManager:
|
||||
"""用户提问管理类"""
|
||||
|
||||
def __init__(self, db_session):
|
||||
self.session = db_session
|
||||
def __init__(self, session: Session = None):
|
||||
self.session = session
|
||||
|
||||
def save_user_question(self, user_id: int, agent_id: str, question: str) -> bool:
|
||||
"""
|
||||
|
||||
@ -2,7 +2,6 @@ import json
|
||||
import logging
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Body, Query, Path
|
||||
from cryptoai.api.adata_api import AStockAPI
|
||||
from cryptoai.utils.db_manager import get_db_manager
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, List, Optional
|
||||
from pydantic import BaseModel
|
||||
@ -12,7 +11,7 @@ from cryptoai.utils.config_loader import ConfigLoader
|
||||
from fastapi.responses import StreamingResponse
|
||||
from cryptoai.routes.user import get_current_user
|
||||
import requests
|
||||
|
||||
from cryptoai.models.astock import AStockManager
|
||||
# 创建路由
|
||||
router = APIRouter()
|
||||
|
||||
@ -21,8 +20,8 @@ logger.setLevel(logging.DEBUG)
|
||||
|
||||
@router.get("/stock/search")
|
||||
async def search_stock(key: str, limit: int = 10):
|
||||
manager = get_db_manager()
|
||||
result = manager.astock_manager.search_stock(key, limit)
|
||||
manager = AStockManager()
|
||||
result = manager.search_stock(key, limit)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
from fastapi import APIRouter
|
||||
from typing import Optional
|
||||
from cryptoai.utils.db_manager import get_db_manager
|
||||
from fastapi import Depends
|
||||
from pydantic import BaseModel
|
||||
from cryptoai.routes.user import get_current_user
|
||||
@ -8,7 +7,9 @@ from fastapi import HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
import requests
|
||||
from datetime import date, timedelta
|
||||
|
||||
from cryptoai.models.analysis_history import AnalysisHistoryManager
|
||||
from cryptoai.models.user_question import UserQuestionManager
|
||||
from cryptoai.models.token import TokenManager
|
||||
class AnalysisHistoryRequest(BaseModel):
|
||||
symbol: str
|
||||
content: str
|
||||
@ -21,7 +22,8 @@ router = APIRouter()
|
||||
async def analysis_history(request: AnalysisHistoryRequest,
|
||||
current_user: dict = Depends(get_current_user)):
|
||||
|
||||
get_db_manager().analysis_history_manager.add_analysis_history(current_user["id"], request.type, request.symbol, request.content, request.timeframe)
|
||||
manager = AnalysisHistoryManager()
|
||||
manager.add_analysis_history(current_user["id"], request.type, request.symbol, request.content, request.timeframe)
|
||||
|
||||
return {"message": "ok"}
|
||||
|
||||
@ -29,7 +31,8 @@ async def analysis_history(request: AnalysisHistoryRequest,
|
||||
async def get_analysis_histories(current_user: dict = Depends(get_current_user),
|
||||
limit: int = 10,
|
||||
offset: int = 0):
|
||||
history = get_db_manager().analysis_history_manager.get_user_analysis_history(current_user["id"], limit=limit, offset=offset)
|
||||
manager = AnalysisHistoryManager()
|
||||
history = manager.get_user_analysis_history(current_user["id"], limit=limit, offset=offset)
|
||||
return history
|
||||
|
||||
|
||||
@ -65,7 +68,8 @@ async def chat(request: ChatRequest,
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
get_db_manager().user_question_manager.save_user_question(current_user["id"],"chat-messages", request.message)
|
||||
manager = UserQuestionManager()
|
||||
manager.save_user_question(current_user["id"],"chat-messages", request.message)
|
||||
|
||||
response = requests.post(url, headers=headers, json=payload, stream=True)
|
||||
|
||||
@ -90,7 +94,8 @@ async def analysis(request: AnalysisRequest,
|
||||
|
||||
if request.type == 'crypto':
|
||||
# 检查symbol是否存在
|
||||
tokens = get_db_manager().token_manager.search_token(request.symbol)
|
||||
manager = TokenManager()
|
||||
tokens = manager.search_token(request.symbol)
|
||||
if not tokens or len(tokens) == 0:
|
||||
raise HTTPException(status_code=400, detail="您输入的币种在币安不存在,请检查后重新输入。")
|
||||
|
||||
@ -106,7 +111,8 @@ async def analysis(request: AnalysisRequest,
|
||||
"user": current_user["mail"]
|
||||
}
|
||||
|
||||
get_db_manager().user_question_manager.save_user_question(current_user["id"], symbol, "请分析以下加密货币:" + symbol + ",并给出分析报告。")
|
||||
manager = UserQuestionManager()
|
||||
manager.save_user_question(current_user["id"], symbol, "请分析以下加密货币:" + symbol + ",并给出分析报告。")
|
||||
|
||||
|
||||
elif request.type == 'astock':
|
||||
@ -121,7 +127,8 @@ async def analysis(request: AnalysisRequest,
|
||||
"user": current_user["mail"]
|
||||
}
|
||||
|
||||
get_db_manager().user_question_manager.save_user_question(current_user["id"], stock_code, "请分析以下A股股票:" + stock_code + ",并给出分析报告。")
|
||||
manager = UserQuestionManager()
|
||||
manager.save_user_question(current_user["id"], stock_code, "请分析以下A股股票:" + stock_code + ",并给出分析报告。")
|
||||
|
||||
elif request.type == 'usstock':
|
||||
stock_code = request.stock_code
|
||||
@ -136,7 +143,8 @@ async def analysis(request: AnalysisRequest,
|
||||
"response_mode": "streaming",
|
||||
"user": current_user["mail"]
|
||||
}
|
||||
get_db_manager().user_question_manager.save_user_question(current_user["id"], stock_code, "请分析以下美股股票:" + stock_code + ",并给出分析报告。")
|
||||
manager = UserQuestionManager()
|
||||
manager.save_user_question(current_user["id"], stock_code, "请分析以下美股股票:" + stock_code + ",并给出分析报告。")
|
||||
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="不支持的类型")
|
||||
|
||||
@ -2,7 +2,6 @@ import json
|
||||
import logging
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Body, Query, Path
|
||||
from cryptoai.api.adata_api import AStockAPI
|
||||
from cryptoai.utils.db_manager import get_db_manager
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, List, Optional
|
||||
from pydantic import BaseModel
|
||||
@ -15,6 +14,7 @@ import requests
|
||||
from cryptoai.api.binance_api import get_binance_api
|
||||
from cryptoai.models.data_processor import DataProcessor
|
||||
from datetime import timedelta
|
||||
from cryptoai.models.token import TokenManager
|
||||
# 创建路由
|
||||
router = APIRouter()
|
||||
|
||||
@ -23,7 +23,7 @@ logger.setLevel(logging.DEBUG)
|
||||
|
||||
@router.get("/search/{key}")
|
||||
async def search_crypto(key: str):
|
||||
manager = get_db_manager()
|
||||
manager = TokenManager()
|
||||
result = manager.search_token(key)
|
||||
return result
|
||||
|
||||
@ -34,7 +34,8 @@ class CryptoAnalysisRequest(BaseModel):
|
||||
@router.get("/kline/{symbol}")
|
||||
async def get_crypto_kline(symbol: str, timeframe: Optional[str] = None, limit: Optional[int] = 100):
|
||||
# 检查symbol是否存在
|
||||
tokens = get_db_manager().token_manager.search_token(symbol)
|
||||
token_manager = TokenManager()
|
||||
tokens = token_manager.search_token(symbol)
|
||||
if not tokens or len(tokens) == 0:
|
||||
raise HTTPException(status_code=400, detail="您输入的币种在币安不存在,请检查后重新输入。")
|
||||
|
||||
|
||||
@ -1,6 +1,10 @@
|
||||
from fastapi import APIRouter
|
||||
from cryptoai.utils.db_manager import get_db_manager
|
||||
import logging
|
||||
from cryptoai.models.user import UserManager
|
||||
from cryptoai.models.user_question import UserQuestionManager
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import Depends
|
||||
from cryptoai.utils.db_manager import get_db
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@ -8,14 +12,16 @@ logger = logging.getLogger("platform_router")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
@router.get("/info")
|
||||
async def get_platform_info():
|
||||
db_manager = get_db_manager()
|
||||
async def get_platform_info(session: Session = Depends(get_db)):
|
||||
|
||||
user_manager = UserManager(session)
|
||||
question_manager = UserQuestionManager(session)
|
||||
|
||||
result = {}
|
||||
|
||||
try:
|
||||
result["user_count"] = db_manager.user_manager.get_user_count()
|
||||
result["question_count"] = db_manager.user_question_manager.get_user_question_count()
|
||||
result["user_count"] = user_manager.get_user_count()
|
||||
result["question_count"] = question_manager.get_user_question_count()
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
|
||||
@ -16,9 +16,10 @@ from datetime import datetime, timedelta
|
||||
import jwt
|
||||
from jwt.exceptions import PyJWTError
|
||||
from fastapi import Request
|
||||
|
||||
from cryptoai.utils.db_manager import get_db_manager
|
||||
from sqlalchemy.orm import Session
|
||||
from cryptoai.utils.db_manager import get_db
|
||||
from cryptoai.utils.email_service import get_email_service
|
||||
from cryptoai.models.user import UserManager
|
||||
|
||||
# 配置日志
|
||||
logger = logging.getLogger("user_router")
|
||||
@ -87,7 +88,7 @@ def create_access_token(data: Dict[str, Any], expires_delta: timedelta = None) -
|
||||
encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
async def get_current_user(request: Request) -> Dict[str, Any]:
|
||||
async def get_current_user(request: Request, session: Session = Depends(get_db)) -> Dict[str, Any]:
|
||||
"""获取当前用户"""
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
@ -109,8 +110,8 @@ async def get_current_user(request: Request) -> Dict[str, Any]:
|
||||
print(f"PyJWTError: {e}")
|
||||
raise credentials_exception
|
||||
|
||||
db_manager = get_db_manager()
|
||||
user = db_manager.user_manager.get_user_by_mail(mail)
|
||||
manager = UserManager(session)
|
||||
user = manager.get_user_by_mail(mail)
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
return user
|
||||
@ -154,13 +155,13 @@ async def send_verification_code(request: SendVerificationCodeRequest) -> Dict[s
|
||||
)
|
||||
|
||||
@router.put("/reset_password", response_model=Dict[str, Any])
|
||||
async def reset_password(request: ResetPasswordRequest) -> Dict[str, Any]:
|
||||
async def reset_password(request: ResetPasswordRequest, session: Session = Depends(get_db)) -> Dict[str, Any]:
|
||||
"""
|
||||
修改密码
|
||||
"""
|
||||
try:
|
||||
# 获取数据库管理器
|
||||
db_manager = get_db_manager()
|
||||
|
||||
|
||||
# 验证验证码
|
||||
email_service = get_email_service()
|
||||
@ -173,7 +174,8 @@ async def reset_password(request: ResetPasswordRequest) -> Dict[str, Any]:
|
||||
# 更新密码
|
||||
hashed_password = hash_password(request.new_password)
|
||||
|
||||
success = db_manager.user_manager.update_password(request.mail, hashed_password)
|
||||
manager = UserManager(session)
|
||||
success = manager.update_password(request.mail, hashed_password)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
@ -193,7 +195,7 @@ async def reset_password(request: ResetPasswordRequest) -> Dict[str, Any]:
|
||||
|
||||
|
||||
@router.post("/register", response_model=Dict[str, Any], status_code=status.HTTP_201_CREATED)
|
||||
async def register_user(user: UserRegister) -> Dict[str, Any]:
|
||||
async def register_user(user: UserRegister, session: Session = Depends(get_db)) -> Dict[str, Any]:
|
||||
"""
|
||||
注册新用户
|
||||
|
||||
@ -215,13 +217,13 @@ async def register_user(user: UserRegister) -> Dict[str, Any]:
|
||||
)
|
||||
|
||||
# 获取数据库管理器
|
||||
db_manager = get_db_manager()
|
||||
manager = UserManager(session)
|
||||
|
||||
# 对密码进行哈希处理
|
||||
hashed_password = hash_password(user.password)
|
||||
|
||||
# 注册用户
|
||||
success = db_manager.user_manager.register_user(
|
||||
success = manager.register_user(
|
||||
mail=user.mail,
|
||||
nickname=user.nickname,
|
||||
password=hashed_password,
|
||||
@ -250,7 +252,7 @@ async def register_user(user: UserRegister) -> Dict[str, Any]:
|
||||
)
|
||||
|
||||
@router.post("/login", response_model=TokenResponse)
|
||||
async def login(loginData: UserLogin) -> TokenResponse:
|
||||
async def login(loginData: UserLogin, session: Session = Depends(get_db)) -> TokenResponse:
|
||||
"""
|
||||
用户登录
|
||||
|
||||
@ -262,10 +264,10 @@ async def login(loginData: UserLogin) -> TokenResponse:
|
||||
"""
|
||||
try:
|
||||
# 获取数据库管理器
|
||||
db_manager = get_db_manager()
|
||||
manager = UserManager(session)
|
||||
|
||||
# 获取用户信息
|
||||
user = db_manager.user_manager.get_user_by_mail(loginData.mail)
|
||||
user = manager.get_user_by_mail(loginData.mail)
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
@ -278,17 +280,14 @@ async def login(loginData: UserLogin) -> TokenResponse:
|
||||
hashed_password = hash_password(loginData.password)
|
||||
|
||||
# 查询用户的密码哈希
|
||||
session = db_manager.Session()
|
||||
try:
|
||||
user = db_manager.user_manager.get_user_by_mail_and_password(loginData.mail, hashed_password)
|
||||
user = manager.get_user_by_mail_and_password(loginData.mail, hashed_password)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="邮箱或密码错误",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
# 创建访问令牌,不过期
|
||||
access_token = create_access_token(data={"sub": user["mail"]})
|
||||
@ -317,7 +316,7 @@ async def login(loginData: UserLogin) -> TokenResponse:
|
||||
)
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
async def get_user_info(current_user: Dict[str, Any] = Depends(get_current_user)) -> UserResponse:
|
||||
async def get_user_info(current_user: Dict[str, Any] = Depends(get_current_user), session: Session = Depends(get_db)) -> UserResponse:
|
||||
"""
|
||||
获取当前登录用户信息
|
||||
|
||||
@ -340,7 +339,8 @@ async def get_user_info(current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
async def update_user_level(
|
||||
user_id: int,
|
||||
level: int = Query(..., description="用户级别(0=普通用户,1=VIP,2=SVIP)"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
session: Session = Depends(get_db)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
更新用户级别(需要管理员权限)
|
||||
@ -361,10 +361,10 @@ async def update_user_level(
|
||||
)
|
||||
|
||||
# 获取数据库管理器
|
||||
db_manager = get_db_manager()
|
||||
manager = UserManager(session)
|
||||
|
||||
# 更新用户级别
|
||||
success = db_manager.user_manager.update_user_level(user_id, level)
|
||||
success = manager.update_user_level(user_id, level)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
@ -380,7 +380,8 @@ async def update_user_level(
|
||||
@router.get("/points/{user_id}", response_model=Dict[str, Any])
|
||||
async def get_user_points(
|
||||
user_id: int,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
session: Session = Depends(get_db)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取用户积分
|
||||
@ -400,10 +401,10 @@ async def get_user_points(
|
||||
)
|
||||
|
||||
# 获取数据库管理器
|
||||
db_manager = get_db_manager()
|
||||
manager = UserManager(session)
|
||||
|
||||
# 获取用户信息
|
||||
user = db_manager.user_manager.get_user_by_id(user_id)
|
||||
user = manager.get_user_by_id(user_id)
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
@ -422,7 +423,8 @@ async def get_user_points(
|
||||
async def add_user_points(
|
||||
user_id: int,
|
||||
points: int = Query(..., gt=0, description="增加的积分数量(必须大于0)"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
session: Session = Depends(get_db)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
为用户增加积分(需要管理员权限)
|
||||
@ -443,10 +445,10 @@ async def add_user_points(
|
||||
)
|
||||
|
||||
# 获取数据库管理器
|
||||
db_manager = get_db_manager()
|
||||
manager = UserManager(session)
|
||||
|
||||
# 添加积分
|
||||
success = db_manager.user_manager.add_user_points(user_id, points)
|
||||
success = manager.add_user_points(user_id, points)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
@ -455,7 +457,7 @@ async def add_user_points(
|
||||
)
|
||||
|
||||
# 获取更新后的用户信息
|
||||
user = db_manager.user_manager.get_user_by_id(user_id)
|
||||
user = manager.get_user_by_id(user_id)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
@ -467,7 +469,8 @@ async def add_user_points(
|
||||
async def consume_user_points(
|
||||
user_id: int,
|
||||
points: int = Query(..., gt=0, description="消费的积分数量(必须大于0)"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
session: Session = Depends(get_db)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
用户消费积分
|
||||
@ -488,10 +491,10 @@ async def consume_user_points(
|
||||
)
|
||||
|
||||
# 获取数据库管理器
|
||||
db_manager = get_db_manager()
|
||||
manager = UserManager(session)
|
||||
|
||||
# 消费积分
|
||||
success = db_manager.user_manager.consume_user_points(user_id, points)
|
||||
success = manager.consume_user_points(user_id, points)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
@ -500,7 +503,7 @@ async def consume_user_points(
|
||||
)
|
||||
|
||||
# 获取更新后的用户信息
|
||||
user = db_manager.user_manager.get_user_by_id(user_id)
|
||||
user = manager.get_user_by_id(user_id)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
|
||||
@ -7,17 +7,11 @@ from typing import Dict, Any, List, Optional, Union
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import sessionmaker, scoped_session
|
||||
from sqlalchemy.pool import QueuePool
|
||||
|
||||
from cryptoai.utils.config_loader import ConfigLoader
|
||||
from cryptoai.models.base import Base
|
||||
from cryptoai.models.token import TokenManager
|
||||
from cryptoai.models.analysis_result import AnalysisResultManager
|
||||
from cryptoai.models.user import UserManager
|
||||
from cryptoai.models.user_question import UserQuestionManager
|
||||
from cryptoai.models.astock import AStockManager
|
||||
from cryptoai.models.analysis_history import AnalysisHistoryManager
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
@ -26,207 +20,49 @@ logging.basicConfig(
|
||||
)
|
||||
logger = logging.getLogger('db_manager')
|
||||
|
||||
class DBManager:
|
||||
"""
|
||||
数据库管理工具,用于连接MySQL数据库并提供各个模型的管理器
|
||||
|
||||
使用方法:
|
||||
- 调用 get_db_manager() 获取数据库管理器实例
|
||||
- 使用 db_manager.token_manager.xxx() 直接访问各个模型管理器的方法
|
||||
- 使用 db_manager.get_session() 可以获取一个新的数据库会话
|
||||
"""
|
||||
|
||||
def __init__(self, host: str, port: int, user: str, password: str, db_name: str):
|
||||
"""
|
||||
初始化数据库管理器
|
||||
|
||||
Args:
|
||||
host: 数据库主机地址
|
||||
port: 数据库端口
|
||||
user: 用户名
|
||||
password: 密码
|
||||
db_name: 数据库名
|
||||
"""
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.user = user
|
||||
self.password = password
|
||||
self.db_name = db_name
|
||||
self.engine = None
|
||||
self.Session = None
|
||||
|
||||
# 初始化数据库连接
|
||||
self._init_db()
|
||||
|
||||
# 初始化各个管理器
|
||||
self._init_managers()
|
||||
|
||||
def _init_db(self) -> None:
|
||||
"""初始化数据库连接和表"""
|
||||
try:
|
||||
# 创建数据库连接
|
||||
connection_string = f"mysql+pymysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.db_name}?charset=utf8mb4"
|
||||
|
||||
# 创建引擎,设置连接池
|
||||
self.engine = create_engine(
|
||||
connection_string,
|
||||
config_loader = ConfigLoader()
|
||||
db_config = config_loader.get_database_config()
|
||||
engine = create_engine(f"mysql+pymysql://{db_config['user']}:{db_config['password']}@{db_config['host']}:{db_config['port']}/{db_config['db_name']}?charset=utf8mb4",
|
||||
echo=False, # 设置为True可以输出SQL语句(调试用)
|
||||
pool_size=5, # 连接池大小
|
||||
max_overflow=10, # 最大溢出连接数
|
||||
pool_timeout=30, # 连接超时时间
|
||||
pool_recycle=1800, # 连接回收时间(秒)
|
||||
pool_pre_ping=True, # 在使用连接前先ping一下,确保连接有效
|
||||
connect_args={'charset': 'utf8mb4'}
|
||||
connect_args={'charset': 'utf8mb4'})
|
||||
|
||||
# 声明基类
|
||||
Base = declarative_base()
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
# 创建线程安全的会话工厂
|
||||
SessionLocal = scoped_session(
|
||||
sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
)
|
||||
|
||||
# 创建会话工厂
|
||||
self.Session = sessionmaker(bind=self.engine)
|
||||
|
||||
# 创建表(如果不存在)
|
||||
Base.metadata.create_all(self.engine)
|
||||
|
||||
logger.info(f"成功连接到数据库 {self.db_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"数据库初始化失败: {e}")
|
||||
self.engine = None
|
||||
|
||||
def _init_managers(self) -> None:
|
||||
"""初始化各个模型的管理器"""
|
||||
if not self.engine:
|
||||
logger.error("引擎未初始化,无法创建管理器")
|
||||
return
|
||||
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
session = self.Session()
|
||||
yield db
|
||||
finally:
|
||||
if db:
|
||||
db.close()
|
||||
|
||||
# 初始化各个模型的管理器
|
||||
self.token_manager = TokenManager(session)
|
||||
self.analysis_result_manager = AnalysisResultManager(session)
|
||||
self.user_manager = UserManager(session)
|
||||
self.user_question_manager = UserQuestionManager(session)
|
||||
self.astock_manager = AStockManager(session)
|
||||
self.analysis_history_manager = AnalysisHistoryManager(session)
|
||||
|
||||
logger.info("成功初始化所有模型管理器")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"管理器初始化失败: {e}")
|
||||
if session:
|
||||
session.close()
|
||||
|
||||
def get_session(self):
|
||||
"""
|
||||
获取新的数据库会话
|
||||
|
||||
Returns:
|
||||
SQLAlchemy session对象,如果初始化失败则返回None
|
||||
"""
|
||||
if not self.Session:
|
||||
def get_db_context():
|
||||
try:
|
||||
self._init_db()
|
||||
db = SessionLocal()
|
||||
yield db
|
||||
db.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"重新初始化数据库失败: {e}")
|
||||
return None
|
||||
if db:
|
||||
db.rollback()
|
||||
raise e
|
||||
finally:
|
||||
if db:
|
||||
db.close()
|
||||
|
||||
return self.Session()
|
||||
|
||||
def refresh_managers(self) -> bool:
|
||||
"""
|
||||
刷新所有管理器,重新建立会话
|
||||
|
||||
Returns:
|
||||
刷新是否成功
|
||||
"""
|
||||
try:
|
||||
# 关闭旧会话(如果有)
|
||||
self.close()
|
||||
|
||||
# 重新初始化数据库连接
|
||||
self._init_db()
|
||||
|
||||
# 重新初始化管理器
|
||||
self._init_managers()
|
||||
|
||||
return self.engine is not None
|
||||
except Exception as e:
|
||||
logger.error(f"刷新管理器失败: {e}")
|
||||
return False
|
||||
|
||||
def close(self) -> None:
|
||||
"""关闭数据库连接(如果存在)"""
|
||||
if self.engine:
|
||||
self.engine.dispose()
|
||||
logger.info("数据库连接已关闭")
|
||||
self.engine = None
|
||||
self.Session = None
|
||||
|
||||
# 单例模式
|
||||
_db_instance = None
|
||||
|
||||
def get_db_manager(host: Optional[str] = None,
|
||||
port: Optional[int] = None,
|
||||
user: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
db_name: Optional[str] = None) -> DBManager:
|
||||
"""
|
||||
获取数据库管理器实例(单例模式)
|
||||
|
||||
Args:
|
||||
host: 数据库主机地址
|
||||
port: 数据库端口
|
||||
user: 用户名
|
||||
password: 密码
|
||||
db_name: 数据库名
|
||||
|
||||
Returns:
|
||||
数据库管理器实例
|
||||
|
||||
使用示例:
|
||||
```python
|
||||
# 获取数据库管理器
|
||||
db_manager = get_db_manager()
|
||||
|
||||
# 使用Token管理器
|
||||
tokens = db_manager.token_manager.search_token("BTC")
|
||||
|
||||
# 使用用户管理器
|
||||
user = db_manager.user_manager.get_user_by_mail("example@test.com")
|
||||
```
|
||||
"""
|
||||
global _db_instance
|
||||
|
||||
# 如果已经初始化过,直接返回
|
||||
if _db_instance is not None:
|
||||
return _db_instance
|
||||
|
||||
# 如果未指定参数,从配置加载器获取数据库配置
|
||||
if host is None or port is None or user is None or password is None or db_name is None:
|
||||
config_loader = ConfigLoader()
|
||||
db_config = config_loader.get_database_config()
|
||||
|
||||
# 使用配置中的值或默认值
|
||||
db_host = host or db_config.get('host')
|
||||
db_port = port or db_config.get('port')
|
||||
db_user = user or db_config.get('user')
|
||||
db_password = password or db_config.get('password')
|
||||
db_name = db_name or db_config.get('db_name')
|
||||
|
||||
logger.info(f"从配置加载数据库连接信息: {db_host}:{db_port}/{db_name}")
|
||||
else:
|
||||
db_host = host
|
||||
db_port = port
|
||||
db_user = user
|
||||
db_password = password
|
||||
db_name = db_name
|
||||
|
||||
# 创建实例
|
||||
_db_instance = DBManager(
|
||||
host=db_host,
|
||||
port=db_port,
|
||||
user=db_user,
|
||||
password=db_password,
|
||||
db_name=db_name
|
||||
)
|
||||
|
||||
return _db_instance
|
||||
@ -1,120 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
更新数据库表字符集为utf8mb4,以支持emoji和其他特殊字符
|
||||
"""
|
||||
|
||||
import logging
|
||||
from cryptoai.utils.db_manager import get_db_manager
|
||||
from sqlalchemy import text
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger('update_db_charset')
|
||||
|
||||
def update_table_charset():
|
||||
"""更新数据库表字符集为utf8mb4"""
|
||||
try:
|
||||
# 获取数据库管理器
|
||||
db_manager = get_db_manager()
|
||||
|
||||
if not db_manager.engine:
|
||||
logger.error("数据库连接失败")
|
||||
return False
|
||||
|
||||
# 创建会话
|
||||
session = db_manager.Session()
|
||||
|
||||
try:
|
||||
# 更新数据库字符集
|
||||
session.execute(text("ALTER DATABASE CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;"))
|
||||
|
||||
# 更新agent_feeds表字符集
|
||||
session.execute(text("""
|
||||
ALTER TABLE agent_feeds
|
||||
CONVERT TO CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
|
||||
"""))
|
||||
|
||||
# 特别更新content列的字符集
|
||||
session.execute(text("""
|
||||
ALTER TABLE agent_feeds
|
||||
MODIFY content TEXT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
|
||||
"""))
|
||||
|
||||
# 检查users表是否存在
|
||||
result = session.execute(text("""
|
||||
SELECT COUNT(*)
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_name = 'users';
|
||||
"""))
|
||||
|
||||
table_exists = result.scalar() > 0
|
||||
|
||||
# 如果users表存在,更新其字符集
|
||||
if table_exists:
|
||||
session.execute(text("""
|
||||
ALTER TABLE users
|
||||
CONVERT TO CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
|
||||
"""))
|
||||
|
||||
# 特别更新nickname和mail列的字符集
|
||||
session.execute(text("""
|
||||
ALTER TABLE users
|
||||
MODIFY nickname VARCHAR(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
|
||||
"""))
|
||||
|
||||
session.execute(text("""
|
||||
ALTER TABLE users
|
||||
MODIFY mail VARCHAR(100) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
|
||||
"""))
|
||||
|
||||
logger.info("成功更新users表字符集为utf8mb4")
|
||||
|
||||
# 检查user_questions表是否存在
|
||||
result = session.execute(text("""
|
||||
SELECT COUNT(*)
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_name = 'user_questions';
|
||||
"""))
|
||||
|
||||
table_exists = result.scalar() > 0
|
||||
|
||||
# 如果user_questions表存在,更新其字符集
|
||||
if table_exists:
|
||||
session.execute(text("""
|
||||
ALTER TABLE user_questions
|
||||
CONVERT TO CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
|
||||
"""))
|
||||
|
||||
# 特别更新question列的字符集
|
||||
session.execute(text("""
|
||||
ALTER TABLE user_questions
|
||||
MODIFY question TEXT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
|
||||
"""))
|
||||
|
||||
logger.info("成功更新user_questions表字符集为utf8mb4")
|
||||
|
||||
session.commit()
|
||||
logger.info("成功更新数据库表字符集为utf8mb4")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"更新数据库表字符集失败: {e}")
|
||||
return False
|
||||
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新数据库表字符集失败: {e}")
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
update_table_charset()
|
||||
@ -29,7 +29,7 @@ services:
|
||||
cryptoai-api:
|
||||
build: .
|
||||
container_name: cryptoai-api
|
||||
image: cryptoai-api:0.1.32
|
||||
image: cryptoai-api:0.1.33
|
||||
restart: always
|
||||
ports:
|
||||
- "8000:8000"
|
||||
|
||||
8
test.py
8
test.py
@ -1,8 +1,9 @@
|
||||
from cryptoai.utils.db_manager import get_db_manager
|
||||
from cryptoai.api.adata_api import AStockAPI
|
||||
from cryptoai.api.binance_api import get_binance_api
|
||||
import json
|
||||
from time import sleep
|
||||
from cryptoai.models.token import TokenManager
|
||||
from cryptoai.utils.db_manager import get_db_context
|
||||
if __name__ == "__main__":
|
||||
|
||||
symbols = get_binance_api().get_all_symbols()
|
||||
@ -11,10 +12,11 @@ if __name__ == "__main__":
|
||||
# symbol = symbol.replace('USDT', '')
|
||||
# print(symbol)
|
||||
|
||||
manager = get_db_manager()
|
||||
session = get_db_context()
|
||||
manager = TokenManager(session)
|
||||
|
||||
for symbol in symbols:
|
||||
base_asset = symbol.split('USDT')[0]
|
||||
quote_asset = 'USDT'
|
||||
|
||||
manager.token_manager.create_token(symbol,base_asset, quote_asset)
|
||||
manager.create_token(symbol,base_asset, quote_asset)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user