update
This commit is contained in:
parent
4027f68dca
commit
08c4efcfd3
@ -17,7 +17,6 @@ from api.deepseek_api import DeepSeekAPI
|
|||||||
from models.data_processor import DataProcessor
|
from models.data_processor import DataProcessor
|
||||||
from utils.config_loader import ConfigLoader
|
from utils.config_loader import ConfigLoader
|
||||||
from utils.dingtalk_bot import DingTalkBot
|
from utils.dingtalk_bot import DingTalkBot
|
||||||
from utils.db_manager import get_db_manager
|
|
||||||
from utils.discord_bot import DiscordBot
|
from utils.discord_bot import DiscordBot
|
||||||
|
|
||||||
class CryptoAgent:
|
class CryptoAgent:
|
||||||
@ -79,9 +78,6 @@ class CryptoAgent:
|
|||||||
)
|
)
|
||||||
print("Discord机器人已启用")
|
print("Discord机器人已启用")
|
||||||
|
|
||||||
# 初始化数据库管理器
|
|
||||||
self.db_manager = get_db_manager()
|
|
||||||
|
|
||||||
# 设置支持的加密货币
|
# 设置支持的加密货币
|
||||||
self.base_currencies = self.crypto_config['base_currencies']
|
self.base_currencies = self.crypto_config['base_currencies']
|
||||||
self.quote_currency = self.crypto_config['quote_currency']
|
self.quote_currency = self.crypto_config['quote_currency']
|
||||||
|
|||||||
@ -7,7 +7,7 @@ import adata
|
|||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
from cryptoai.utils.db_manager import get_db_context
|
||||||
class AStockAPI:
|
class AStockAPI:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_all_stock_codes() -> List[str]:
|
def get_all_stock_codes() -> List[str]:
|
||||||
@ -156,8 +156,11 @@ if __name__ == "__main__":
|
|||||||
print(list)
|
print(list)
|
||||||
|
|
||||||
# 保存到数据库
|
# 保存到数据库
|
||||||
import cryptoai.utils.db_manager as db_manager
|
from cryptoai.models.astock import AStockManager
|
||||||
db_manager.get_db_manager().create_stocks(list)
|
|
||||||
|
session = get_db_context()
|
||||||
|
manager = AStockManager(session)
|
||||||
|
manager.create_stocks(list)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -3,12 +3,12 @@
|
|||||||
|
|
||||||
from typing import Dict, Any, List, Optional
|
from typing import Dict, Any, List, Optional
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
from sqlalchemy import Column, Integer, String, Text, DateTime, Index, ForeignKey
|
from sqlalchemy import Column, Integer, String, Text, DateTime, Index, ForeignKey
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
from cryptoai.models.base import Base, logger
|
from cryptoai.models.base import Base, logger
|
||||||
|
from cryptoai.utils.db_manager import get_db_context
|
||||||
# 定义分析历史模型
|
# 定义分析历史模型
|
||||||
class AnalysisHistory(Base):
|
class AnalysisHistory(Base):
|
||||||
"""分析历史表模型"""
|
"""分析历史表模型"""
|
||||||
@ -37,8 +37,8 @@ class AnalysisHistory(Base):
|
|||||||
class AnalysisHistoryManager:
|
class AnalysisHistoryManager:
|
||||||
"""分析历史管理类"""
|
"""分析历史管理类"""
|
||||||
|
|
||||||
def __init__(self, db_session):
|
def __init__(self, session: Session = None):
|
||||||
self.session = db_session
|
self.session = session
|
||||||
|
|
||||||
def add_analysis_history(self, user_id: int, type: str, symbol: str,
|
def add_analysis_history(self, user_id: int, type: str, symbol: str,
|
||||||
content: str, timeframe: str = None) -> bool:
|
content: str, timeframe: str = None) -> bool:
|
||||||
|
|||||||
@ -3,12 +3,12 @@
|
|||||||
|
|
||||||
from typing import Dict, Any, List, Optional
|
from typing import Dict, Any, List, Optional
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
from sqlalchemy import Column, Integer, String, DateTime, Index
|
from sqlalchemy import Column, Integer, String, DateTime, Index
|
||||||
from sqlalchemy.dialects.mysql import JSON
|
from sqlalchemy.dialects.mysql import JSON
|
||||||
|
|
||||||
from cryptoai.models.base import Base, logger
|
from cryptoai.models.base import Base, logger
|
||||||
|
from cryptoai.utils.db_manager import get_db_context
|
||||||
# 定义分析结果模型
|
# 定义分析结果模型
|
||||||
class AnalysisResult(Base):
|
class AnalysisResult(Base):
|
||||||
"""分析结果表模型"""
|
"""分析结果表模型"""
|
||||||
@ -33,8 +33,8 @@ class AnalysisResult(Base):
|
|||||||
class AnalysisResultManager:
|
class AnalysisResultManager:
|
||||||
"""分析结果管理类"""
|
"""分析结果管理类"""
|
||||||
|
|
||||||
def __init__(self, db_session):
|
def __init__(self, session: Session = None):
|
||||||
self.session = db_session
|
self.session = session
|
||||||
|
|
||||||
def save_analysis_result(self, agent: str, symbol: str, time_interval: str,
|
def save_analysis_result(self, agent: str, symbol: str, time_interval: str,
|
||||||
analysis_result: Dict[str, Any]) -> bool:
|
analysis_result: Dict[str, Any]) -> bool:
|
||||||
|
|||||||
@ -5,9 +5,10 @@ from typing import Dict, Any, List, Optional, Union
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from sqlalchemy import Column, Integer, String, DateTime, Index
|
from sqlalchemy import Column, Integer, String, DateTime, Index
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
from cryptoai.models.base import Base, logger
|
from cryptoai.models.base import Base, logger
|
||||||
from cryptoai.utils.db_utils import convert_timestamp_to_datetime
|
from cryptoai.utils.db_utils import convert_timestamp_to_datetime
|
||||||
|
from cryptoai.utils.db_manager import get_db_context
|
||||||
|
|
||||||
# 定义 A 股数据模型
|
# 定义 A 股数据模型
|
||||||
class AStock(Base):
|
class AStock(Base):
|
||||||
@ -29,8 +30,8 @@ class AStock(Base):
|
|||||||
class AStockManager:
|
class AStockManager:
|
||||||
"""A股管理类"""
|
"""A股管理类"""
|
||||||
|
|
||||||
def __init__(self, db_session):
|
def __init__(self, session: Session = None):
|
||||||
self.session = db_session
|
self.session = session
|
||||||
|
|
||||||
def create_stocks(self, stocks: List[Dict[str, Any]]) -> bool:
|
def create_stocks(self, stocks: List[Dict[str, Any]]) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -3,11 +3,11 @@
|
|||||||
|
|
||||||
from typing import Dict, Any, List, Optional
|
from typing import Dict, Any, List, Optional
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
from sqlalchemy import Column, Integer, String, DateTime, Index
|
from sqlalchemy import Column, Integer, String, DateTime, Index
|
||||||
|
|
||||||
from cryptoai.models.base import Base, logger
|
from cryptoai.models.base import Base, logger
|
||||||
|
from cryptoai.utils.db_manager import get_db_context
|
||||||
# 定义Token模型
|
# 定义Token模型
|
||||||
class Token(Base):
|
class Token(Base):
|
||||||
"""Token信息表模型"""
|
"""Token信息表模型"""
|
||||||
@ -30,8 +30,8 @@ class Token(Base):
|
|||||||
class TokenManager:
|
class TokenManager:
|
||||||
"""Token管理类"""
|
"""Token管理类"""
|
||||||
|
|
||||||
def __init__(self, db_session):
|
def __init__(self, session: Session = None):
|
||||||
self.session = db_session
|
self.session = session
|
||||||
|
|
||||||
def create_token(self, symbol: str, base_asset: str, quote_asset: str) -> bool:
|
def create_token(self, symbol: str, base_asset: str, quote_asset: str) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -6,7 +6,7 @@ from datetime import datetime
|
|||||||
from fastapi import HTTPException, status
|
from fastapi import HTTPException, status
|
||||||
from sqlalchemy import Column, Integer, String, DateTime, Index
|
from sqlalchemy import Column, Integer, String, DateTime, Index
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
from cryptoai.models.base import Base, logger
|
from cryptoai.models.base import Base, logger
|
||||||
|
|
||||||
# 定义用户数据模型
|
# 定义用户数据模型
|
||||||
@ -37,8 +37,8 @@ class User(Base):
|
|||||||
class UserManager:
|
class UserManager:
|
||||||
"""用户管理类"""
|
"""用户管理类"""
|
||||||
|
|
||||||
def __init__(self, db_session):
|
def __init__(self, session: Session = None):
|
||||||
self.session = db_session
|
self.session = session
|
||||||
|
|
||||||
def register_user(self, mail: str, nickname: str, password: str, level: int = 0, points: int = 0) -> bool:
|
def register_user(self, mail: str, nickname: str, password: str, level: int = 0, points: int = 0) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
from typing import Dict, Any, List, Optional
|
from typing import Dict, Any, List, Optional
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
from sqlalchemy import Column, Integer, String, Text, DateTime, Index, ForeignKey
|
from sqlalchemy import Column, Integer, String, Text, DateTime, Index, ForeignKey
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
@ -35,8 +35,8 @@ class UserQuestion(Base):
|
|||||||
class UserQuestionManager:
|
class UserQuestionManager:
|
||||||
"""用户提问管理类"""
|
"""用户提问管理类"""
|
||||||
|
|
||||||
def __init__(self, db_session):
|
def __init__(self, session: Session = None):
|
||||||
self.session = db_session
|
self.session = session
|
||||||
|
|
||||||
def save_user_question(self, user_id: int, agent_id: str, question: str) -> bool:
|
def save_user_question(self, user_id: int, agent_id: str, question: str) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -2,7 +2,6 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status, Body, Query, Path
|
from fastapi import APIRouter, Depends, HTTPException, status, Body, Query, Path
|
||||||
from cryptoai.api.adata_api import AStockAPI
|
from cryptoai.api.adata_api import AStockAPI
|
||||||
from cryptoai.utils.db_manager import get_db_manager
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Dict, Any, List, Optional
|
from typing import Dict, Any, List, Optional
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@ -12,7 +11,7 @@ from cryptoai.utils.config_loader import ConfigLoader
|
|||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from cryptoai.routes.user import get_current_user
|
from cryptoai.routes.user import get_current_user
|
||||||
import requests
|
import requests
|
||||||
|
from cryptoai.models.astock import AStockManager
|
||||||
# 创建路由
|
# 创建路由
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@ -21,8 +20,8 @@ logger.setLevel(logging.DEBUG)
|
|||||||
|
|
||||||
@router.get("/stock/search")
|
@router.get("/stock/search")
|
||||||
async def search_stock(key: str, limit: int = 10):
|
async def search_stock(key: str, limit: int = 10):
|
||||||
manager = get_db_manager()
|
manager = AStockManager()
|
||||||
result = manager.astock_manager.search_stock(key, limit)
|
result = manager.search_stock(key, limit)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from cryptoai.utils.db_manager import get_db_manager
|
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from cryptoai.routes.user import get_current_user
|
from cryptoai.routes.user import get_current_user
|
||||||
@ -8,7 +7,9 @@ from fastapi import HTTPException
|
|||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
import requests
|
import requests
|
||||||
from datetime import date, timedelta
|
from datetime import date, timedelta
|
||||||
|
from cryptoai.models.analysis_history import AnalysisHistoryManager
|
||||||
|
from cryptoai.models.user_question import UserQuestionManager
|
||||||
|
from cryptoai.models.token import TokenManager
|
||||||
class AnalysisHistoryRequest(BaseModel):
|
class AnalysisHistoryRequest(BaseModel):
|
||||||
symbol: str
|
symbol: str
|
||||||
content: str
|
content: str
|
||||||
@ -21,7 +22,8 @@ router = APIRouter()
|
|||||||
async def analysis_history(request: AnalysisHistoryRequest,
|
async def analysis_history(request: AnalysisHistoryRequest,
|
||||||
current_user: dict = Depends(get_current_user)):
|
current_user: dict = Depends(get_current_user)):
|
||||||
|
|
||||||
get_db_manager().analysis_history_manager.add_analysis_history(current_user["id"], request.type, request.symbol, request.content, request.timeframe)
|
manager = AnalysisHistoryManager()
|
||||||
|
manager.add_analysis_history(current_user["id"], request.type, request.symbol, request.content, request.timeframe)
|
||||||
|
|
||||||
return {"message": "ok"}
|
return {"message": "ok"}
|
||||||
|
|
||||||
@ -29,7 +31,8 @@ async def analysis_history(request: AnalysisHistoryRequest,
|
|||||||
async def get_analysis_histories(current_user: dict = Depends(get_current_user),
|
async def get_analysis_histories(current_user: dict = Depends(get_current_user),
|
||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
offset: int = 0):
|
offset: int = 0):
|
||||||
history = get_db_manager().analysis_history_manager.get_user_analysis_history(current_user["id"], limit=limit, offset=offset)
|
manager = AnalysisHistoryManager()
|
||||||
|
history = manager.get_user_analysis_history(current_user["id"], limit=limit, offset=offset)
|
||||||
return history
|
return history
|
||||||
|
|
||||||
|
|
||||||
@ -65,7 +68,8 @@ async def chat(request: ChatRequest,
|
|||||||
'Content-Type': 'application/json'
|
'Content-Type': 'application/json'
|
||||||
}
|
}
|
||||||
|
|
||||||
get_db_manager().user_question_manager.save_user_question(current_user["id"],"chat-messages", request.message)
|
manager = UserQuestionManager()
|
||||||
|
manager.save_user_question(current_user["id"],"chat-messages", request.message)
|
||||||
|
|
||||||
response = requests.post(url, headers=headers, json=payload, stream=True)
|
response = requests.post(url, headers=headers, json=payload, stream=True)
|
||||||
|
|
||||||
@ -90,7 +94,8 @@ async def analysis(request: AnalysisRequest,
|
|||||||
|
|
||||||
if request.type == 'crypto':
|
if request.type == 'crypto':
|
||||||
# 检查symbol是否存在
|
# 检查symbol是否存在
|
||||||
tokens = get_db_manager().token_manager.search_token(request.symbol)
|
manager = TokenManager()
|
||||||
|
tokens = manager.search_token(request.symbol)
|
||||||
if not tokens or len(tokens) == 0:
|
if not tokens or len(tokens) == 0:
|
||||||
raise HTTPException(status_code=400, detail="您输入的币种在币安不存在,请检查后重新输入。")
|
raise HTTPException(status_code=400, detail="您输入的币种在币安不存在,请检查后重新输入。")
|
||||||
|
|
||||||
@ -106,7 +111,8 @@ async def analysis(request: AnalysisRequest,
|
|||||||
"user": current_user["mail"]
|
"user": current_user["mail"]
|
||||||
}
|
}
|
||||||
|
|
||||||
get_db_manager().user_question_manager.save_user_question(current_user["id"], symbol, "请分析以下加密货币:" + symbol + ",并给出分析报告。")
|
manager = UserQuestionManager()
|
||||||
|
manager.save_user_question(current_user["id"], symbol, "请分析以下加密货币:" + symbol + ",并给出分析报告。")
|
||||||
|
|
||||||
|
|
||||||
elif request.type == 'astock':
|
elif request.type == 'astock':
|
||||||
@ -121,7 +127,8 @@ async def analysis(request: AnalysisRequest,
|
|||||||
"user": current_user["mail"]
|
"user": current_user["mail"]
|
||||||
}
|
}
|
||||||
|
|
||||||
get_db_manager().user_question_manager.save_user_question(current_user["id"], stock_code, "请分析以下A股股票:" + stock_code + ",并给出分析报告。")
|
manager = UserQuestionManager()
|
||||||
|
manager.save_user_question(current_user["id"], stock_code, "请分析以下A股股票:" + stock_code + ",并给出分析报告。")
|
||||||
|
|
||||||
elif request.type == 'usstock':
|
elif request.type == 'usstock':
|
||||||
stock_code = request.stock_code
|
stock_code = request.stock_code
|
||||||
@ -136,7 +143,8 @@ async def analysis(request: AnalysisRequest,
|
|||||||
"response_mode": "streaming",
|
"response_mode": "streaming",
|
||||||
"user": current_user["mail"]
|
"user": current_user["mail"]
|
||||||
}
|
}
|
||||||
get_db_manager().user_question_manager.save_user_question(current_user["id"], stock_code, "请分析以下美股股票:" + stock_code + ",并给出分析报告。")
|
manager = UserQuestionManager()
|
||||||
|
manager.save_user_question(current_user["id"], stock_code, "请分析以下美股股票:" + stock_code + ",并给出分析报告。")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise HTTPException(status_code=400, detail="不支持的类型")
|
raise HTTPException(status_code=400, detail="不支持的类型")
|
||||||
|
|||||||
@ -2,7 +2,6 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status, Body, Query, Path
|
from fastapi import APIRouter, Depends, HTTPException, status, Body, Query, Path
|
||||||
from cryptoai.api.adata_api import AStockAPI
|
from cryptoai.api.adata_api import AStockAPI
|
||||||
from cryptoai.utils.db_manager import get_db_manager
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Dict, Any, List, Optional
|
from typing import Dict, Any, List, Optional
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@ -15,6 +14,7 @@ import requests
|
|||||||
from cryptoai.api.binance_api import get_binance_api
|
from cryptoai.api.binance_api import get_binance_api
|
||||||
from cryptoai.models.data_processor import DataProcessor
|
from cryptoai.models.data_processor import DataProcessor
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
from cryptoai.models.token import TokenManager
|
||||||
# 创建路由
|
# 创建路由
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@ -23,7 +23,7 @@ logger.setLevel(logging.DEBUG)
|
|||||||
|
|
||||||
@router.get("/search/{key}")
|
@router.get("/search/{key}")
|
||||||
async def search_crypto(key: str):
|
async def search_crypto(key: str):
|
||||||
manager = get_db_manager()
|
manager = TokenManager()
|
||||||
result = manager.search_token(key)
|
result = manager.search_token(key)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@ -34,7 +34,8 @@ class CryptoAnalysisRequest(BaseModel):
|
|||||||
@router.get("/kline/{symbol}")
|
@router.get("/kline/{symbol}")
|
||||||
async def get_crypto_kline(symbol: str, timeframe: Optional[str] = None, limit: Optional[int] = 100):
|
async def get_crypto_kline(symbol: str, timeframe: Optional[str] = None, limit: Optional[int] = 100):
|
||||||
# 检查symbol是否存在
|
# 检查symbol是否存在
|
||||||
tokens = get_db_manager().token_manager.search_token(symbol)
|
token_manager = TokenManager()
|
||||||
|
tokens = token_manager.search_token(symbol)
|
||||||
if not tokens or len(tokens) == 0:
|
if not tokens or len(tokens) == 0:
|
||||||
raise HTTPException(status_code=400, detail="您输入的币种在币安不存在,请检查后重新输入。")
|
raise HTTPException(status_code=400, detail="您输入的币种在币安不存在,请检查后重新输入。")
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,10 @@
|
|||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from cryptoai.utils.db_manager import get_db_manager
|
|
||||||
import logging
|
import logging
|
||||||
|
from cryptoai.models.user import UserManager
|
||||||
|
from cryptoai.models.user_question import UserQuestionManager
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from fastapi import Depends
|
||||||
|
from cryptoai.utils.db_manager import get_db
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@ -8,14 +12,16 @@ logger = logging.getLogger("platform_router")
|
|||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
@router.get("/info")
|
@router.get("/info")
|
||||||
async def get_platform_info():
|
async def get_platform_info(session: Session = Depends(get_db)):
|
||||||
db_manager = get_db_manager()
|
|
||||||
|
user_manager = UserManager(session)
|
||||||
|
question_manager = UserQuestionManager(session)
|
||||||
|
|
||||||
result = {}
|
result = {}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result["user_count"] = db_manager.user_manager.get_user_count()
|
result["user_count"] = user_manager.get_user_count()
|
||||||
result["question_count"] = db_manager.user_question_manager.get_user_question_count()
|
result["question_count"] = question_manager.get_user_question_count()
|
||||||
|
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@ -16,9 +16,10 @@ from datetime import datetime, timedelta
|
|||||||
import jwt
|
import jwt
|
||||||
from jwt.exceptions import PyJWTError
|
from jwt.exceptions import PyJWTError
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
from cryptoai.utils.db_manager import get_db_manager
|
from cryptoai.utils.db_manager import get_db
|
||||||
from cryptoai.utils.email_service import get_email_service
|
from cryptoai.utils.email_service import get_email_service
|
||||||
|
from cryptoai.models.user import UserManager
|
||||||
|
|
||||||
# 配置日志
|
# 配置日志
|
||||||
logger = logging.getLogger("user_router")
|
logger = logging.getLogger("user_router")
|
||||||
@ -87,7 +88,7 @@ def create_access_token(data: Dict[str, Any], expires_delta: timedelta = None) -
|
|||||||
encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
|
encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
|
||||||
return encoded_jwt
|
return encoded_jwt
|
||||||
|
|
||||||
async def get_current_user(request: Request) -> Dict[str, Any]:
|
async def get_current_user(request: Request, session: Session = Depends(get_db)) -> Dict[str, Any]:
|
||||||
"""获取当前用户"""
|
"""获取当前用户"""
|
||||||
credentials_exception = HTTPException(
|
credentials_exception = HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
@ -109,8 +110,8 @@ async def get_current_user(request: Request) -> Dict[str, Any]:
|
|||||||
print(f"PyJWTError: {e}")
|
print(f"PyJWTError: {e}")
|
||||||
raise credentials_exception
|
raise credentials_exception
|
||||||
|
|
||||||
db_manager = get_db_manager()
|
manager = UserManager(session)
|
||||||
user = db_manager.user_manager.get_user_by_mail(mail)
|
user = manager.get_user_by_mail(mail)
|
||||||
if user is None:
|
if user is None:
|
||||||
raise credentials_exception
|
raise credentials_exception
|
||||||
return user
|
return user
|
||||||
@ -154,13 +155,13 @@ async def send_verification_code(request: SendVerificationCodeRequest) -> Dict[s
|
|||||||
)
|
)
|
||||||
|
|
||||||
@router.put("/reset_password", response_model=Dict[str, Any])
|
@router.put("/reset_password", response_model=Dict[str, Any])
|
||||||
async def reset_password(request: ResetPasswordRequest) -> Dict[str, Any]:
|
async def reset_password(request: ResetPasswordRequest, session: Session = Depends(get_db)) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
修改密码
|
修改密码
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 获取数据库管理器
|
# 获取数据库管理器
|
||||||
db_manager = get_db_manager()
|
|
||||||
|
|
||||||
# 验证验证码
|
# 验证验证码
|
||||||
email_service = get_email_service()
|
email_service = get_email_service()
|
||||||
@ -173,7 +174,8 @@ async def reset_password(request: ResetPasswordRequest) -> Dict[str, Any]:
|
|||||||
# 更新密码
|
# 更新密码
|
||||||
hashed_password = hash_password(request.new_password)
|
hashed_password = hash_password(request.new_password)
|
||||||
|
|
||||||
success = db_manager.user_manager.update_password(request.mail, hashed_password)
|
manager = UserManager(session)
|
||||||
|
success = manager.update_password(request.mail, hashed_password)
|
||||||
if not success:
|
if not success:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
@ -193,7 +195,7 @@ async def reset_password(request: ResetPasswordRequest) -> Dict[str, Any]:
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/register", response_model=Dict[str, Any], status_code=status.HTTP_201_CREATED)
|
@router.post("/register", response_model=Dict[str, Any], status_code=status.HTTP_201_CREATED)
|
||||||
async def register_user(user: UserRegister) -> Dict[str, Any]:
|
async def register_user(user: UserRegister, session: Session = Depends(get_db)) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
注册新用户
|
注册新用户
|
||||||
|
|
||||||
@ -215,13 +217,13 @@ async def register_user(user: UserRegister) -> Dict[str, Any]:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 获取数据库管理器
|
# 获取数据库管理器
|
||||||
db_manager = get_db_manager()
|
manager = UserManager(session)
|
||||||
|
|
||||||
# 对密码进行哈希处理
|
# 对密码进行哈希处理
|
||||||
hashed_password = hash_password(user.password)
|
hashed_password = hash_password(user.password)
|
||||||
|
|
||||||
# 注册用户
|
# 注册用户
|
||||||
success = db_manager.user_manager.register_user(
|
success = manager.register_user(
|
||||||
mail=user.mail,
|
mail=user.mail,
|
||||||
nickname=user.nickname,
|
nickname=user.nickname,
|
||||||
password=hashed_password,
|
password=hashed_password,
|
||||||
@ -250,7 +252,7 @@ async def register_user(user: UserRegister) -> Dict[str, Any]:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@router.post("/login", response_model=TokenResponse)
|
@router.post("/login", response_model=TokenResponse)
|
||||||
async def login(loginData: UserLogin) -> TokenResponse:
|
async def login(loginData: UserLogin, session: Session = Depends(get_db)) -> TokenResponse:
|
||||||
"""
|
"""
|
||||||
用户登录
|
用户登录
|
||||||
|
|
||||||
@ -262,10 +264,10 @@ async def login(loginData: UserLogin) -> TokenResponse:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 获取数据库管理器
|
# 获取数据库管理器
|
||||||
db_manager = get_db_manager()
|
manager = UserManager(session)
|
||||||
|
|
||||||
# 获取用户信息
|
# 获取用户信息
|
||||||
user = db_manager.user_manager.get_user_by_mail(loginData.mail)
|
user = manager.get_user_by_mail(loginData.mail)
|
||||||
|
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@ -278,17 +280,14 @@ async def login(loginData: UserLogin) -> TokenResponse:
|
|||||||
hashed_password = hash_password(loginData.password)
|
hashed_password = hash_password(loginData.password)
|
||||||
|
|
||||||
# 查询用户的密码哈希
|
# 查询用户的密码哈希
|
||||||
session = db_manager.Session()
|
user = manager.get_user_by_mail_and_password(loginData.mail, hashed_password)
|
||||||
try:
|
if not user:
|
||||||
user = db_manager.user_manager.get_user_by_mail_and_password(loginData.mail, hashed_password)
|
raise HTTPException(
|
||||||
if not user:
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
raise HTTPException(
|
detail="邮箱或密码错误",
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
detail="邮箱或密码错误",
|
)
|
||||||
headers={"WWW-Authenticate": "Bearer"},
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
session.close()
|
|
||||||
|
|
||||||
# 创建访问令牌,不过期
|
# 创建访问令牌,不过期
|
||||||
access_token = create_access_token(data={"sub": user["mail"]})
|
access_token = create_access_token(data={"sub": user["mail"]})
|
||||||
@ -317,7 +316,7 @@ async def login(loginData: UserLogin) -> TokenResponse:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@router.get("/me", response_model=UserResponse)
|
@router.get("/me", response_model=UserResponse)
|
||||||
async def get_user_info(current_user: Dict[str, Any] = Depends(get_current_user)) -> UserResponse:
|
async def get_user_info(current_user: Dict[str, Any] = Depends(get_current_user), session: Session = Depends(get_db)) -> UserResponse:
|
||||||
"""
|
"""
|
||||||
获取当前登录用户信息
|
获取当前登录用户信息
|
||||||
|
|
||||||
@ -340,7 +339,8 @@ async def get_user_info(current_user: Dict[str, Any] = Depends(get_current_user)
|
|||||||
async def update_user_level(
|
async def update_user_level(
|
||||||
user_id: int,
|
user_id: int,
|
||||||
level: int = Query(..., description="用户级别(0=普通用户,1=VIP,2=SVIP)"),
|
level: int = Query(..., description="用户级别(0=普通用户,1=VIP,2=SVIP)"),
|
||||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||||
|
session: Session = Depends(get_db)
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
更新用户级别(需要管理员权限)
|
更新用户级别(需要管理员权限)
|
||||||
@ -361,10 +361,10 @@ async def update_user_level(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 获取数据库管理器
|
# 获取数据库管理器
|
||||||
db_manager = get_db_manager()
|
manager = UserManager(session)
|
||||||
|
|
||||||
# 更新用户级别
|
# 更新用户级别
|
||||||
success = db_manager.user_manager.update_user_level(user_id, level)
|
success = manager.update_user_level(user_id, level)
|
||||||
|
|
||||||
if not success:
|
if not success:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@ -380,7 +380,8 @@ async def update_user_level(
|
|||||||
@router.get("/points/{user_id}", response_model=Dict[str, Any])
|
@router.get("/points/{user_id}", response_model=Dict[str, Any])
|
||||||
async def get_user_points(
|
async def get_user_points(
|
||||||
user_id: int,
|
user_id: int,
|
||||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||||
|
session: Session = Depends(get_db)
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
获取用户积分
|
获取用户积分
|
||||||
@ -400,10 +401,10 @@ async def get_user_points(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 获取数据库管理器
|
# 获取数据库管理器
|
||||||
db_manager = get_db_manager()
|
manager = UserManager(session)
|
||||||
|
|
||||||
# 获取用户信息
|
# 获取用户信息
|
||||||
user = db_manager.user_manager.get_user_by_id(user_id)
|
user = manager.get_user_by_id(user_id)
|
||||||
|
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@ -422,7 +423,8 @@ async def get_user_points(
|
|||||||
async def add_user_points(
|
async def add_user_points(
|
||||||
user_id: int,
|
user_id: int,
|
||||||
points: int = Query(..., gt=0, description="增加的积分数量(必须大于0)"),
|
points: int = Query(..., gt=0, description="增加的积分数量(必须大于0)"),
|
||||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||||
|
session: Session = Depends(get_db)
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
为用户增加积分(需要管理员权限)
|
为用户增加积分(需要管理员权限)
|
||||||
@ -443,10 +445,10 @@ async def add_user_points(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 获取数据库管理器
|
# 获取数据库管理器
|
||||||
db_manager = get_db_manager()
|
manager = UserManager(session)
|
||||||
|
|
||||||
# 添加积分
|
# 添加积分
|
||||||
success = db_manager.user_manager.add_user_points(user_id, points)
|
success = manager.add_user_points(user_id, points)
|
||||||
|
|
||||||
if not success:
|
if not success:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@ -455,7 +457,7 @@ async def add_user_points(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 获取更新后的用户信息
|
# 获取更新后的用户信息
|
||||||
user = db_manager.user_manager.get_user_by_id(user_id)
|
user = manager.get_user_by_id(user_id)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
@ -467,7 +469,8 @@ async def add_user_points(
|
|||||||
async def consume_user_points(
|
async def consume_user_points(
|
||||||
user_id: int,
|
user_id: int,
|
||||||
points: int = Query(..., gt=0, description="消费的积分数量(必须大于0)"),
|
points: int = Query(..., gt=0, description="消费的积分数量(必须大于0)"),
|
||||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||||
|
session: Session = Depends(get_db)
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
用户消费积分
|
用户消费积分
|
||||||
@ -488,10 +491,10 @@ async def consume_user_points(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 获取数据库管理器
|
# 获取数据库管理器
|
||||||
db_manager = get_db_manager()
|
manager = UserManager(session)
|
||||||
|
|
||||||
# 消费积分
|
# 消费积分
|
||||||
success = db_manager.user_manager.consume_user_points(user_id, points)
|
success = manager.consume_user_points(user_id, points)
|
||||||
|
|
||||||
if not success:
|
if not success:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@ -500,7 +503,7 @@ async def consume_user_points(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 获取更新后的用户信息
|
# 获取更新后的用户信息
|
||||||
user = db_manager.user_manager.get_user_by_id(user_id)
|
user = manager.get_user_by_id(user_id)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
|
|||||||
@ -7,17 +7,11 @@ from typing import Dict, Any, List, Optional, Union
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker, scoped_session
|
||||||
from sqlalchemy.pool import QueuePool
|
from sqlalchemy.pool import QueuePool
|
||||||
|
|
||||||
from cryptoai.utils.config_loader import ConfigLoader
|
from cryptoai.utils.config_loader import ConfigLoader
|
||||||
from cryptoai.models.base import Base
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
from cryptoai.models.token import TokenManager
|
|
||||||
from cryptoai.models.analysis_result import AnalysisResultManager
|
|
||||||
from cryptoai.models.user import UserManager
|
|
||||||
from cryptoai.models.user_question import UserQuestionManager
|
|
||||||
from cryptoai.models.astock import AStockManager
|
|
||||||
from cryptoai.models.analysis_history import AnalysisHistoryManager
|
|
||||||
|
|
||||||
# 配置日志
|
# 配置日志
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
@ -26,207 +20,49 @@ logging.basicConfig(
|
|||||||
)
|
)
|
||||||
logger = logging.getLogger('db_manager')
|
logger = logging.getLogger('db_manager')
|
||||||
|
|
||||||
class DBManager:
|
|
||||||
"""
|
|
||||||
数据库管理工具,用于连接MySQL数据库并提供各个模型的管理器
|
|
||||||
|
|
||||||
使用方法:
|
|
||||||
- 调用 get_db_manager() 获取数据库管理器实例
|
|
||||||
- 使用 db_manager.token_manager.xxx() 直接访问各个模型管理器的方法
|
|
||||||
- 使用 db_manager.get_session() 可以获取一个新的数据库会话
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, host: str, port: int, user: str, password: str, db_name: str):
|
|
||||||
"""
|
|
||||||
初始化数据库管理器
|
|
||||||
|
|
||||||
Args:
|
|
||||||
host: 数据库主机地址
|
|
||||||
port: 数据库端口
|
|
||||||
user: 用户名
|
|
||||||
password: 密码
|
|
||||||
db_name: 数据库名
|
|
||||||
"""
|
|
||||||
self.host = host
|
|
||||||
self.port = port
|
|
||||||
self.user = user
|
|
||||||
self.password = password
|
|
||||||
self.db_name = db_name
|
|
||||||
self.engine = None
|
|
||||||
self.Session = None
|
|
||||||
|
|
||||||
# 初始化数据库连接
|
|
||||||
self._init_db()
|
|
||||||
|
|
||||||
# 初始化各个管理器
|
|
||||||
self._init_managers()
|
|
||||||
|
|
||||||
def _init_db(self) -> None:
|
|
||||||
"""初始化数据库连接和表"""
|
|
||||||
try:
|
|
||||||
# 创建数据库连接
|
|
||||||
connection_string = f"mysql+pymysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.db_name}?charset=utf8mb4"
|
|
||||||
|
|
||||||
# 创建引擎,设置连接池
|
|
||||||
self.engine = create_engine(
|
|
||||||
connection_string,
|
|
||||||
echo=False, # 设置为True可以输出SQL语句(调试用)
|
|
||||||
pool_size=5, # 连接池大小
|
|
||||||
max_overflow=10, # 最大溢出连接数
|
|
||||||
pool_timeout=30, # 连接超时时间
|
|
||||||
pool_recycle=1800, # 连接回收时间(秒)
|
|
||||||
pool_pre_ping=True, # 在使用连接前先ping一下,确保连接有效
|
|
||||||
connect_args={'charset': 'utf8mb4'}
|
|
||||||
)
|
|
||||||
|
|
||||||
# 创建会话工厂
|
|
||||||
self.Session = sessionmaker(bind=self.engine)
|
|
||||||
|
|
||||||
# 创建表(如果不存在)
|
|
||||||
Base.metadata.create_all(self.engine)
|
|
||||||
|
|
||||||
logger.info(f"成功连接到数据库 {self.db_name}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"数据库初始化失败: {e}")
|
|
||||||
self.engine = None
|
|
||||||
|
|
||||||
def _init_managers(self) -> None:
|
|
||||||
"""初始化各个模型的管理器"""
|
|
||||||
if not self.engine:
|
|
||||||
logger.error("引擎未初始化,无法创建管理器")
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
session = self.Session()
|
|
||||||
|
|
||||||
# 初始化各个模型的管理器
|
|
||||||
self.token_manager = TokenManager(session)
|
|
||||||
self.analysis_result_manager = AnalysisResultManager(session)
|
|
||||||
self.user_manager = UserManager(session)
|
|
||||||
self.user_question_manager = UserQuestionManager(session)
|
|
||||||
self.astock_manager = AStockManager(session)
|
|
||||||
self.analysis_history_manager = AnalysisHistoryManager(session)
|
|
||||||
|
|
||||||
logger.info("成功初始化所有模型管理器")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"管理器初始化失败: {e}")
|
|
||||||
if session:
|
|
||||||
session.close()
|
|
||||||
|
|
||||||
def get_session(self):
|
|
||||||
"""
|
|
||||||
获取新的数据库会话
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
SQLAlchemy session对象,如果初始化失败则返回None
|
|
||||||
"""
|
|
||||||
if not self.Session:
|
|
||||||
try:
|
|
||||||
self._init_db()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"重新初始化数据库失败: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
return self.Session()
|
|
||||||
|
|
||||||
def refresh_managers(self) -> bool:
|
|
||||||
"""
|
|
||||||
刷新所有管理器,重新建立会话
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
刷新是否成功
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 关闭旧会话(如果有)
|
|
||||||
self.close()
|
|
||||||
|
|
||||||
# 重新初始化数据库连接
|
|
||||||
self._init_db()
|
|
||||||
|
|
||||||
# 重新初始化管理器
|
|
||||||
self._init_managers()
|
|
||||||
|
|
||||||
return self.engine is not None
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"刷新管理器失败: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def close(self) -> None:
|
|
||||||
"""关闭数据库连接(如果存在)"""
|
|
||||||
if self.engine:
|
|
||||||
self.engine.dispose()
|
|
||||||
logger.info("数据库连接已关闭")
|
|
||||||
self.engine = None
|
|
||||||
self.Session = None
|
|
||||||
|
|
||||||
# 单例模式
|
config_loader = ConfigLoader()
|
||||||
_db_instance = None
|
db_config = config_loader.get_database_config()
|
||||||
|
engine = create_engine(f"mysql+pymysql://{db_config['user']}:{db_config['password']}@{db_config['host']}:{db_config['port']}/{db_config['db_name']}?charset=utf8mb4",
|
||||||
|
echo=False, # 设置为True可以输出SQL语句(调试用)
|
||||||
|
pool_size=5, # 连接池大小
|
||||||
|
max_overflow=10, # 最大溢出连接数
|
||||||
|
pool_timeout=30, # 连接超时时间
|
||||||
|
pool_recycle=1800, # 连接回收时间(秒)
|
||||||
|
pool_pre_ping=True, # 在使用连接前先ping一下,确保连接有效
|
||||||
|
connect_args={'charset': 'utf8mb4'})
|
||||||
|
|
||||||
|
# 声明基类
|
||||||
|
Base = declarative_base()
|
||||||
|
Base.metadata.create_all(bind=engine)
|
||||||
|
|
||||||
|
# 创建线程安全的会话工厂
|
||||||
|
SessionLocal = scoped_session(
|
||||||
|
sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_db():
|
||||||
|
db = SessionLocal()
|
||||||
|
try:
|
||||||
|
yield db
|
||||||
|
finally:
|
||||||
|
if db:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
def get_db_context():
|
||||||
|
try:
|
||||||
|
db = SessionLocal()
|
||||||
|
yield db
|
||||||
|
db.commit()
|
||||||
|
except Exception as e:
|
||||||
|
if db:
|
||||||
|
db.rollback()
|
||||||
|
raise e
|
||||||
|
finally:
|
||||||
|
if db:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_db_manager(host: Optional[str] = None,
|
|
||||||
port: Optional[int] = None,
|
|
||||||
user: Optional[str] = None,
|
|
||||||
password: Optional[str] = None,
|
|
||||||
db_name: Optional[str] = None) -> DBManager:
|
|
||||||
"""
|
|
||||||
获取数据库管理器实例(单例模式)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
host: 数据库主机地址
|
|
||||||
port: 数据库端口
|
|
||||||
user: 用户名
|
|
||||||
password: 密码
|
|
||||||
db_name: 数据库名
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
数据库管理器实例
|
|
||||||
|
|
||||||
使用示例:
|
|
||||||
```python
|
|
||||||
# 获取数据库管理器
|
|
||||||
db_manager = get_db_manager()
|
|
||||||
|
|
||||||
# 使用Token管理器
|
|
||||||
tokens = db_manager.token_manager.search_token("BTC")
|
|
||||||
|
|
||||||
# 使用用户管理器
|
|
||||||
user = db_manager.user_manager.get_user_by_mail("example@test.com")
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
global _db_instance
|
|
||||||
|
|
||||||
# 如果已经初始化过,直接返回
|
|
||||||
if _db_instance is not None:
|
|
||||||
return _db_instance
|
|
||||||
|
|
||||||
# 如果未指定参数,从配置加载器获取数据库配置
|
|
||||||
if host is None or port is None or user is None or password is None or db_name is None:
|
|
||||||
config_loader = ConfigLoader()
|
|
||||||
db_config = config_loader.get_database_config()
|
|
||||||
|
|
||||||
# 使用配置中的值或默认值
|
|
||||||
db_host = host or db_config.get('host')
|
|
||||||
db_port = port or db_config.get('port')
|
|
||||||
db_user = user or db_config.get('user')
|
|
||||||
db_password = password or db_config.get('password')
|
|
||||||
db_name = db_name or db_config.get('db_name')
|
|
||||||
|
|
||||||
logger.info(f"从配置加载数据库连接信息: {db_host}:{db_port}/{db_name}")
|
|
||||||
else:
|
|
||||||
db_host = host
|
|
||||||
db_port = port
|
|
||||||
db_user = user
|
|
||||||
db_password = password
|
|
||||||
db_name = db_name
|
|
||||||
|
|
||||||
# 创建实例
|
|
||||||
_db_instance = DBManager(
|
|
||||||
host=db_host,
|
|
||||||
port=db_port,
|
|
||||||
user=db_user,
|
|
||||||
password=db_password,
|
|
||||||
db_name=db_name
|
|
||||||
)
|
|
||||||
|
|
||||||
return _db_instance
|
|
||||||
@ -1,120 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
|
|
||||||
"""
|
|
||||||
更新数据库表字符集为utf8mb4,以支持emoji和其他特殊字符
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from cryptoai.utils.db_manager import get_db_manager
|
|
||||||
from sqlalchemy import text
|
|
||||||
|
|
||||||
# 配置日志
|
|
||||||
logging.basicConfig(
|
|
||||||
level=logging.INFO,
|
|
||||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
|
||||||
)
|
|
||||||
logger = logging.getLogger('update_db_charset')
|
|
||||||
|
|
||||||
def update_table_charset():
|
|
||||||
"""更新数据库表字符集为utf8mb4"""
|
|
||||||
try:
|
|
||||||
# 获取数据库管理器
|
|
||||||
db_manager = get_db_manager()
|
|
||||||
|
|
||||||
if not db_manager.engine:
|
|
||||||
logger.error("数据库连接失败")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 创建会话
|
|
||||||
session = db_manager.Session()
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 更新数据库字符集
|
|
||||||
session.execute(text("ALTER DATABASE CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;"))
|
|
||||||
|
|
||||||
# 更新agent_feeds表字符集
|
|
||||||
session.execute(text("""
|
|
||||||
ALTER TABLE agent_feeds
|
|
||||||
CONVERT TO CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
|
|
||||||
"""))
|
|
||||||
|
|
||||||
# 特别更新content列的字符集
|
|
||||||
session.execute(text("""
|
|
||||||
ALTER TABLE agent_feeds
|
|
||||||
MODIFY content TEXT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
|
|
||||||
"""))
|
|
||||||
|
|
||||||
# 检查users表是否存在
|
|
||||||
result = session.execute(text("""
|
|
||||||
SELECT COUNT(*)
|
|
||||||
FROM information_schema.tables
|
|
||||||
WHERE table_schema = DATABASE()
|
|
||||||
AND table_name = 'users';
|
|
||||||
"""))
|
|
||||||
|
|
||||||
table_exists = result.scalar() > 0
|
|
||||||
|
|
||||||
# 如果users表存在,更新其字符集
|
|
||||||
if table_exists:
|
|
||||||
session.execute(text("""
|
|
||||||
ALTER TABLE users
|
|
||||||
CONVERT TO CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
|
|
||||||
"""))
|
|
||||||
|
|
||||||
# 特别更新nickname和mail列的字符集
|
|
||||||
session.execute(text("""
|
|
||||||
ALTER TABLE users
|
|
||||||
MODIFY nickname VARCHAR(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
|
|
||||||
"""))
|
|
||||||
|
|
||||||
session.execute(text("""
|
|
||||||
ALTER TABLE users
|
|
||||||
MODIFY mail VARCHAR(100) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
|
|
||||||
"""))
|
|
||||||
|
|
||||||
logger.info("成功更新users表字符集为utf8mb4")
|
|
||||||
|
|
||||||
# 检查user_questions表是否存在
|
|
||||||
result = session.execute(text("""
|
|
||||||
SELECT COUNT(*)
|
|
||||||
FROM information_schema.tables
|
|
||||||
WHERE table_schema = DATABASE()
|
|
||||||
AND table_name = 'user_questions';
|
|
||||||
"""))
|
|
||||||
|
|
||||||
table_exists = result.scalar() > 0
|
|
||||||
|
|
||||||
# 如果user_questions表存在,更新其字符集
|
|
||||||
if table_exists:
|
|
||||||
session.execute(text("""
|
|
||||||
ALTER TABLE user_questions
|
|
||||||
CONVERT TO CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
|
|
||||||
"""))
|
|
||||||
|
|
||||||
# 特别更新question列的字符集
|
|
||||||
session.execute(text("""
|
|
||||||
ALTER TABLE user_questions
|
|
||||||
MODIFY question TEXT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
|
|
||||||
"""))
|
|
||||||
|
|
||||||
logger.info("成功更新user_questions表字符集为utf8mb4")
|
|
||||||
|
|
||||||
session.commit()
|
|
||||||
logger.info("成功更新数据库表字符集为utf8mb4")
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
session.rollback()
|
|
||||||
logger.error(f"更新数据库表字符集失败: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
finally:
|
|
||||||
session.close()
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"更新数据库表字符集失败: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
update_table_charset()
|
|
||||||
@ -29,7 +29,7 @@ services:
|
|||||||
cryptoai-api:
|
cryptoai-api:
|
||||||
build: .
|
build: .
|
||||||
container_name: cryptoai-api
|
container_name: cryptoai-api
|
||||||
image: cryptoai-api:0.1.32
|
image: cryptoai-api:0.1.33
|
||||||
restart: always
|
restart: always
|
||||||
ports:
|
ports:
|
||||||
- "8000:8000"
|
- "8000:8000"
|
||||||
|
|||||||
8
test.py
8
test.py
@ -1,8 +1,9 @@
|
|||||||
from cryptoai.utils.db_manager import get_db_manager
|
|
||||||
from cryptoai.api.adata_api import AStockAPI
|
from cryptoai.api.adata_api import AStockAPI
|
||||||
from cryptoai.api.binance_api import get_binance_api
|
from cryptoai.api.binance_api import get_binance_api
|
||||||
import json
|
import json
|
||||||
from time import sleep
|
from time import sleep
|
||||||
|
from cryptoai.models.token import TokenManager
|
||||||
|
from cryptoai.utils.db_manager import get_db_context
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
symbols = get_binance_api().get_all_symbols()
|
symbols = get_binance_api().get_all_symbols()
|
||||||
@ -11,10 +12,11 @@ if __name__ == "__main__":
|
|||||||
# symbol = symbol.replace('USDT', '')
|
# symbol = symbol.replace('USDT', '')
|
||||||
# print(symbol)
|
# print(symbol)
|
||||||
|
|
||||||
manager = get_db_manager()
|
session = get_db_context()
|
||||||
|
manager = TokenManager(session)
|
||||||
|
|
||||||
for symbol in symbols:
|
for symbol in symbols:
|
||||||
base_asset = symbol.split('USDT')[0]
|
base_asset = symbol.split('USDT')[0]
|
||||||
quote_asset = 'USDT'
|
quote_asset = 'USDT'
|
||||||
|
|
||||||
manager.token_manager.create_token(symbol,base_asset, quote_asset)
|
manager.create_token(symbol,base_asset, quote_asset)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user