This commit is contained in:
aaron 2025-05-30 22:14:59 +08:00
parent 08c4efcfd3
commit da278dc033
4 changed files with 30 additions and 19 deletions

View File

@ -12,6 +12,8 @@ from fastapi.responses import StreamingResponse
from cryptoai.routes.user import get_current_user from cryptoai.routes.user import get_current_user
import requests import requests
from cryptoai.models.astock import AStockManager from cryptoai.models.astock import AStockManager
from sqlalchemy.orm import Session
from cryptoai.utils.db_manager import get_db
# 创建路由 # 创建路由
router = APIRouter() router = APIRouter()
@ -19,16 +21,17 @@ logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
@router.get("/stock/search") @router.get("/stock/search")
async def search_stock(key: str, limit: int = 10): async def search_stock(key: str, limit: int = 10, session: Session = Depends(get_db)):
manager = AStockManager() manager = AStockManager(session)
result = manager.search_stock(key, limit) result = manager.search_stock(key, limit)
return result return result
@router.get("/stock/base", summary="获取股票基础信息") @router.get("/stock/base", summary="获取股票基础信息")
async def get_stock_base(stock_code: str): async def get_stock_base(stock_code: str, session: Session = Depends(get_db)):
api = AStockAPI() api = AStockAPI()
manager = AStockManager(session)
result = {} result = {}

View File

@ -10,6 +10,8 @@ from datetime import date, timedelta
from cryptoai.models.analysis_history import AnalysisHistoryManager from cryptoai.models.analysis_history import AnalysisHistoryManager
from cryptoai.models.user_question import UserQuestionManager from cryptoai.models.user_question import UserQuestionManager
from cryptoai.models.token import TokenManager from cryptoai.models.token import TokenManager
from sqlalchemy.orm import Session
from cryptoai.utils.db_manager import get_db
class AnalysisHistoryRequest(BaseModel): class AnalysisHistoryRequest(BaseModel):
symbol: str symbol: str
content: str content: str
@ -20,9 +22,10 @@ router = APIRouter()
@router.post("/analysis_history") @router.post("/analysis_history")
async def analysis_history(request: AnalysisHistoryRequest, async def analysis_history(request: AnalysisHistoryRequest,
current_user: dict = Depends(get_current_user)): current_user: dict = Depends(get_current_user),
session: Session = Depends(get_db)):
manager = AnalysisHistoryManager() manager = AnalysisHistoryManager(session)
manager.add_analysis_history(current_user["id"], request.type, request.symbol, request.content, request.timeframe) manager.add_analysis_history(current_user["id"], request.type, request.symbol, request.content, request.timeframe)
return {"message": "ok"} return {"message": "ok"}
@ -30,8 +33,9 @@ async def analysis_history(request: AnalysisHistoryRequest,
@router.get("/analysis_histories") @router.get("/analysis_histories")
async def get_analysis_histories(current_user: dict = Depends(get_current_user), async def get_analysis_histories(current_user: dict = Depends(get_current_user),
limit: int = 10, limit: int = 10,
offset: int = 0): offset: int = 0,
manager = AnalysisHistoryManager() session: Session = Depends(get_db)):
manager = AnalysisHistoryManager(session)
history = manager.get_user_analysis_history(current_user["id"], limit=limit, offset=offset) history = manager.get_user_analysis_history(current_user["id"], limit=limit, offset=offset)
return history return history
@ -48,7 +52,8 @@ class ChatRequest(BaseModel):
@router.post("/chat-messages") @router.post("/chat-messages")
async def chat(request: ChatRequest, async def chat(request: ChatRequest,
current_user: dict = Depends(get_current_user)): current_user: dict = Depends(get_current_user),
session: Session = Depends(get_db)):
token = 'app-pPtva2AdJ8hJzkBKu12ThWjD' token = 'app-pPtva2AdJ8hJzkBKu12ThWjD'
@ -68,7 +73,7 @@ async def chat(request: ChatRequest,
'Content-Type': 'application/json' 'Content-Type': 'application/json'
} }
manager = UserQuestionManager() manager = UserQuestionManager(session)
manager.save_user_question(current_user["id"],"chat-messages", request.message) manager.save_user_question(current_user["id"],"chat-messages", request.message)
response = requests.post(url, headers=headers, json=payload, stream=True) response = requests.post(url, headers=headers, json=payload, stream=True)
@ -90,11 +95,12 @@ async def chat(request: ChatRequest,
@router.post("/analysis") @router.post("/analysis")
async def analysis(request: AnalysisRequest, async def analysis(request: AnalysisRequest,
current_user: dict = Depends(get_current_user)): current_user: dict = Depends(get_current_user),
session: Session = Depends(get_db)):
if request.type == 'crypto': if request.type == 'crypto':
# 检查symbol是否存在 # 检查symbol是否存在
manager = TokenManager() manager = TokenManager(session)
tokens = manager.search_token(request.symbol) tokens = manager.search_token(request.symbol)
if not tokens or len(tokens) == 0: if not tokens or len(tokens) == 0:
raise HTTPException(status_code=400, detail="您输入的币种在币安不存在,请检查后重新输入。") raise HTTPException(status_code=400, detail="您输入的币种在币安不存在,请检查后重新输入。")
@ -111,7 +117,7 @@ async def analysis(request: AnalysisRequest,
"user": current_user["mail"] "user": current_user["mail"]
} }
manager = UserQuestionManager() manager = UserQuestionManager(session)
manager.save_user_question(current_user["id"], symbol, "请分析以下加密货币:" + symbol + ",并给出分析报告。") manager.save_user_question(current_user["id"], symbol, "请分析以下加密货币:" + symbol + ",并给出分析报告。")
@ -127,7 +133,7 @@ async def analysis(request: AnalysisRequest,
"user": current_user["mail"] "user": current_user["mail"]
} }
manager = UserQuestionManager() manager = UserQuestionManager(session)
manager.save_user_question(current_user["id"], stock_code, "请分析以下A股股票" + stock_code + ",并给出分析报告。") manager.save_user_question(current_user["id"], stock_code, "请分析以下A股股票" + stock_code + ",并给出分析报告。")
elif request.type == 'usstock': elif request.type == 'usstock':
@ -143,7 +149,7 @@ async def analysis(request: AnalysisRequest,
"response_mode": "streaming", "response_mode": "streaming",
"user": current_user["mail"] "user": current_user["mail"]
} }
manager = UserQuestionManager() manager = UserQuestionManager(session)
manager.save_user_question(current_user["id"], stock_code, "请分析以下美股股票:" + stock_code + ",并给出分析报告。") manager.save_user_question(current_user["id"], stock_code, "请分析以下美股股票:" + stock_code + ",并给出分析报告。")
else: else:

View File

@ -15,6 +15,8 @@ from cryptoai.api.binance_api import get_binance_api
from cryptoai.models.data_processor import DataProcessor from cryptoai.models.data_processor import DataProcessor
from datetime import timedelta from datetime import timedelta
from cryptoai.models.token import TokenManager from cryptoai.models.token import TokenManager
from sqlalchemy.orm import Session
from cryptoai.utils.db_manager import get_db
# 创建路由 # 创建路由
router = APIRouter() router = APIRouter()
@ -22,8 +24,8 @@ logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
@router.get("/search/{key}") @router.get("/search/{key}")
async def search_crypto(key: str): async def search_crypto(key: str, session: Session = Depends(get_db)):
manager = TokenManager() manager = TokenManager(session)
result = manager.search_token(key) result = manager.search_token(key)
return result return result
@ -32,9 +34,9 @@ class CryptoAnalysisRequest(BaseModel):
timeframe: Optional[str] = None timeframe: Optional[str] = None
@router.get("/kline/{symbol}") @router.get("/kline/{symbol}")
async def get_crypto_kline(symbol: str, timeframe: Optional[str] = None, limit: Optional[int] = 100): async def get_crypto_kline(symbol: str, timeframe: Optional[str] = None, limit: Optional[int] = 100, session: Session = Depends(get_db)):
# 检查symbol是否存在 # 检查symbol是否存在
token_manager = TokenManager() token_manager = TokenManager(session)
tokens = token_manager.search_token(symbol) tokens = token_manager.search_token(symbol)
if not tokens or len(tokens) == 0: if not tokens or len(tokens) == 0:
raise HTTPException(status_code=400, detail="您输入的币种在币安不存在,请检查后重新输入。") raise HTTPException(status_code=400, detail="您输入的币种在币安不存在,请检查后重新输入。")

View File

@ -29,7 +29,7 @@ services:
cryptoai-api: cryptoai-api:
build: . build: .
container_name: cryptoai-api container_name: cryptoai-api
image: cryptoai-api:0.1.33 image: cryptoai-api:0.1.34
restart: always restart: always
ports: ports:
- "8000:8000" - "8000:8000"