From da278dc033b65b94c3152aba71943b0906fa94ec Mon Sep 17 00:00:00 2001 From: aaron <> Date: Fri, 30 May 2025 22:14:59 +0800 Subject: [PATCH] update --- cryptoai/routes/adata.py | 9 ++++++--- cryptoai/routes/analysis.py | 28 +++++++++++++++++----------- cryptoai/routes/crypto.py | 10 ++++++---- docker-compose.yml | 2 +- 4 files changed, 30 insertions(+), 19 deletions(-) diff --git a/cryptoai/routes/adata.py b/cryptoai/routes/adata.py index 927c329..42d5745 100644 --- a/cryptoai/routes/adata.py +++ b/cryptoai/routes/adata.py @@ -12,6 +12,8 @@ from fastapi.responses import StreamingResponse from cryptoai.routes.user import get_current_user import requests from cryptoai.models.astock import AStockManager +from sqlalchemy.orm import Session +from cryptoai.utils.db_manager import get_db # 创建路由 router = APIRouter() @@ -19,16 +21,17 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @router.get("/stock/search") -async def search_stock(key: str, limit: int = 10): - manager = AStockManager() +async def search_stock(key: str, limit: int = 10, session: Session = Depends(get_db)): + manager = AStockManager(session) result = manager.search_stock(key, limit) return result @router.get("/stock/base", summary="获取股票基础信息") -async def get_stock_base(stock_code: str): +async def get_stock_base(stock_code: str, session: Session = Depends(get_db)): api = AStockAPI() + manager = AStockManager(session) result = {} diff --git a/cryptoai/routes/analysis.py b/cryptoai/routes/analysis.py index 93aacda..1e72e2d 100644 --- a/cryptoai/routes/analysis.py +++ b/cryptoai/routes/analysis.py @@ -10,6 +10,8 @@ from datetime import date, timedelta from cryptoai.models.analysis_history import AnalysisHistoryManager from cryptoai.models.user_question import UserQuestionManager from cryptoai.models.token import TokenManager +from sqlalchemy.orm import Session +from cryptoai.utils.db_manager import get_db class AnalysisHistoryRequest(BaseModel): symbol: str content: str @@ -20,9 +22,10 @@ router = APIRouter() @router.post("/analysis_history") async def analysis_history(request: AnalysisHistoryRequest, - current_user: dict = Depends(get_current_user)): + current_user: dict = Depends(get_current_user), + session: Session = Depends(get_db)): - manager = AnalysisHistoryManager() + manager = AnalysisHistoryManager(session) manager.add_analysis_history(current_user["id"], request.type, request.symbol, request.content, request.timeframe) return {"message": "ok"} @@ -30,8 +33,9 @@ async def analysis_history(request: AnalysisHistoryRequest, @router.get("/analysis_histories") async def get_analysis_histories(current_user: dict = Depends(get_current_user), limit: int = 10, - offset: int = 0): - manager = AnalysisHistoryManager() + offset: int = 0, + session: Session = Depends(get_db)): + manager = AnalysisHistoryManager(session) history = manager.get_user_analysis_history(current_user["id"], limit=limit, offset=offset) return history @@ -48,7 +52,8 @@ class ChatRequest(BaseModel): @router.post("/chat-messages") async def chat(request: ChatRequest, - current_user: dict = Depends(get_current_user)): + current_user: dict = Depends(get_current_user), + session: Session = Depends(get_db)): token = 'app-pPtva2AdJ8hJzkBKu12ThWjD' @@ -68,7 +73,7 @@ async def chat(request: ChatRequest, 'Content-Type': 'application/json' } - manager = UserQuestionManager() + manager = UserQuestionManager(session) manager.save_user_question(current_user["id"],"chat-messages", request.message) response = requests.post(url, headers=headers, json=payload, stream=True) @@ -90,11 +95,12 @@ async def chat(request: ChatRequest, @router.post("/analysis") async def analysis(request: AnalysisRequest, - current_user: dict = Depends(get_current_user)): + current_user: dict = Depends(get_current_user), + session: Session = Depends(get_db)): if request.type == 'crypto': # 检查symbol是否存在 - manager = TokenManager() + manager = TokenManager(session) tokens = manager.search_token(request.symbol) if not tokens or len(tokens) == 0: raise HTTPException(status_code=400, detail="您输入的币种在币安不存在,请检查后重新输入。") @@ -111,7 +117,7 @@ async def analysis(request: AnalysisRequest, "user": current_user["mail"] } - manager = UserQuestionManager() + manager = UserQuestionManager(session) manager.save_user_question(current_user["id"], symbol, "请分析以下加密货币:" + symbol + ",并给出分析报告。") @@ -127,7 +133,7 @@ async def analysis(request: AnalysisRequest, "user": current_user["mail"] } - manager = UserQuestionManager() + manager = UserQuestionManager(session) manager.save_user_question(current_user["id"], stock_code, "请分析以下A股股票:" + stock_code + ",并给出分析报告。") elif request.type == 'usstock': @@ -143,7 +149,7 @@ async def analysis(request: AnalysisRequest, "response_mode": "streaming", "user": current_user["mail"] } - manager = UserQuestionManager() + manager = UserQuestionManager(session) manager.save_user_question(current_user["id"], stock_code, "请分析以下美股股票:" + stock_code + ",并给出分析报告。") else: diff --git a/cryptoai/routes/crypto.py b/cryptoai/routes/crypto.py index e677654..6e09d65 100644 --- a/cryptoai/routes/crypto.py +++ b/cryptoai/routes/crypto.py @@ -15,6 +15,8 @@ from cryptoai.api.binance_api import get_binance_api from cryptoai.models.data_processor import DataProcessor from datetime import timedelta from cryptoai.models.token import TokenManager +from sqlalchemy.orm import Session +from cryptoai.utils.db_manager import get_db # 创建路由 router = APIRouter() @@ -22,8 +24,8 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @router.get("/search/{key}") -async def search_crypto(key: str): - manager = TokenManager() +async def search_crypto(key: str, session: Session = Depends(get_db)): + manager = TokenManager(session) result = manager.search_token(key) return result @@ -32,9 +34,9 @@ class CryptoAnalysisRequest(BaseModel): timeframe: Optional[str] = None @router.get("/kline/{symbol}") -async def get_crypto_kline(symbol: str, timeframe: Optional[str] = None, limit: Optional[int] = 100): +async def get_crypto_kline(symbol: str, timeframe: Optional[str] = None, limit: Optional[int] = 100, session: Session = Depends(get_db)): # 检查symbol是否存在 - token_manager = TokenManager() + token_manager = TokenManager(session) tokens = token_manager.search_token(symbol) if not tokens or len(tokens) == 0: raise HTTPException(status_code=400, detail="您输入的币种在币安不存在,请检查后重新输入。") diff --git a/docker-compose.yml b/docker-compose.yml index 5c8daf3..8e36b26 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -29,7 +29,7 @@ services: cryptoai-api: build: . container_name: cryptoai-api - image: cryptoai-api:0.1.33 + image: cryptoai-api:0.1.34 restart: always ports: - "8000:8000"