This commit is contained in:
aaron 2025-05-30 22:14:59 +08:00
parent 08c4efcfd3
commit da278dc033
4 changed files with 30 additions and 19 deletions

View File

@ -12,6 +12,8 @@ from fastapi.responses import StreamingResponse
from cryptoai.routes.user import get_current_user
import requests
from cryptoai.models.astock import AStockManager
from sqlalchemy.orm import Session
from cryptoai.utils.db_manager import get_db
# 创建路由
router = APIRouter()
@ -19,16 +21,17 @@ logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@router.get("/stock/search")
async def search_stock(key: str, limit: int = 10):
manager = AStockManager()
async def search_stock(key: str, limit: int = 10, session: Session = Depends(get_db)):
manager = AStockManager(session)
result = manager.search_stock(key, limit)
return result
@router.get("/stock/base", summary="获取股票基础信息")
async def get_stock_base(stock_code: str):
async def get_stock_base(stock_code: str, session: Session = Depends(get_db)):
api = AStockAPI()
manager = AStockManager(session)
result = {}

View File

@ -10,6 +10,8 @@ from datetime import date, timedelta
from cryptoai.models.analysis_history import AnalysisHistoryManager
from cryptoai.models.user_question import UserQuestionManager
from cryptoai.models.token import TokenManager
from sqlalchemy.orm import Session
from cryptoai.utils.db_manager import get_db
class AnalysisHistoryRequest(BaseModel):
symbol: str
content: str
@ -20,9 +22,10 @@ router = APIRouter()
@router.post("/analysis_history")
async def analysis_history(request: AnalysisHistoryRequest,
current_user: dict = Depends(get_current_user)):
current_user: dict = Depends(get_current_user),
session: Session = Depends(get_db)):
manager = AnalysisHistoryManager()
manager = AnalysisHistoryManager(session)
manager.add_analysis_history(current_user["id"], request.type, request.symbol, request.content, request.timeframe)
return {"message": "ok"}
@ -30,8 +33,9 @@ async def analysis_history(request: AnalysisHistoryRequest,
@router.get("/analysis_histories")
async def get_analysis_histories(current_user: dict = Depends(get_current_user),
limit: int = 10,
offset: int = 0):
manager = AnalysisHistoryManager()
offset: int = 0,
session: Session = Depends(get_db)):
manager = AnalysisHistoryManager(session)
history = manager.get_user_analysis_history(current_user["id"], limit=limit, offset=offset)
return history
@ -48,7 +52,8 @@ class ChatRequest(BaseModel):
@router.post("/chat-messages")
async def chat(request: ChatRequest,
current_user: dict = Depends(get_current_user)):
current_user: dict = Depends(get_current_user),
session: Session = Depends(get_db)):
token = 'app-pPtva2AdJ8hJzkBKu12ThWjD'
@ -68,7 +73,7 @@ async def chat(request: ChatRequest,
'Content-Type': 'application/json'
}
manager = UserQuestionManager()
manager = UserQuestionManager(session)
manager.save_user_question(current_user["id"],"chat-messages", request.message)
response = requests.post(url, headers=headers, json=payload, stream=True)
@ -90,11 +95,12 @@ async def chat(request: ChatRequest,
@router.post("/analysis")
async def analysis(request: AnalysisRequest,
current_user: dict = Depends(get_current_user)):
current_user: dict = Depends(get_current_user),
session: Session = Depends(get_db)):
if request.type == 'crypto':
# 检查symbol是否存在
manager = TokenManager()
manager = TokenManager(session)
tokens = manager.search_token(request.symbol)
if not tokens or len(tokens) == 0:
raise HTTPException(status_code=400, detail="您输入的币种在币安不存在,请检查后重新输入。")
@ -111,7 +117,7 @@ async def analysis(request: AnalysisRequest,
"user": current_user["mail"]
}
manager = UserQuestionManager()
manager = UserQuestionManager(session)
manager.save_user_question(current_user["id"], symbol, "请分析以下加密货币:" + symbol + ",并给出分析报告。")
@ -127,7 +133,7 @@ async def analysis(request: AnalysisRequest,
"user": current_user["mail"]
}
manager = UserQuestionManager()
manager = UserQuestionManager(session)
manager.save_user_question(current_user["id"], stock_code, "请分析以下A股股票" + stock_code + ",并给出分析报告。")
elif request.type == 'usstock':
@ -143,7 +149,7 @@ async def analysis(request: AnalysisRequest,
"response_mode": "streaming",
"user": current_user["mail"]
}
manager = UserQuestionManager()
manager = UserQuestionManager(session)
manager.save_user_question(current_user["id"], stock_code, "请分析以下美股股票:" + stock_code + ",并给出分析报告。")
else:

View File

@ -15,6 +15,8 @@ from cryptoai.api.binance_api import get_binance_api
from cryptoai.models.data_processor import DataProcessor
from datetime import timedelta
from cryptoai.models.token import TokenManager
from sqlalchemy.orm import Session
from cryptoai.utils.db_manager import get_db
# 创建路由
router = APIRouter()
@ -22,8 +24,8 @@ logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@router.get("/search/{key}")
async def search_crypto(key: str):
manager = TokenManager()
async def search_crypto(key: str, session: Session = Depends(get_db)):
manager = TokenManager(session)
result = manager.search_token(key)
return result
@ -32,9 +34,9 @@ class CryptoAnalysisRequest(BaseModel):
timeframe: Optional[str] = None
@router.get("/kline/{symbol}")
async def get_crypto_kline(symbol: str, timeframe: Optional[str] = None, limit: Optional[int] = 100):
async def get_crypto_kline(symbol: str, timeframe: Optional[str] = None, limit: Optional[int] = 100, session: Session = Depends(get_db)):
# 检查symbol是否存在
token_manager = TokenManager()
token_manager = TokenManager(session)
tokens = token_manager.search_token(symbol)
if not tokens or len(tokens) == 0:
raise HTTPException(status_code=400, detail="您输入的币种在币安不存在,请检查后重新输入。")

View File

@ -29,7 +29,7 @@ services:
cryptoai-api:
build: .
container_name: cryptoai-api
image: cryptoai-api:0.1.33
image: cryptoai-api:0.1.34
restart: always
ports:
- "8000:8000"