update
This commit is contained in:
parent
08c4efcfd3
commit
da278dc033
@ -12,6 +12,8 @@ from fastapi.responses import StreamingResponse
|
||||
from cryptoai.routes.user import get_current_user
|
||||
import requests
|
||||
from cryptoai.models.astock import AStockManager
|
||||
from sqlalchemy.orm import Session
|
||||
from cryptoai.utils.db_manager import get_db
|
||||
# 创建路由
|
||||
router = APIRouter()
|
||||
|
||||
@ -19,16 +21,17 @@ logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
@router.get("/stock/search")
|
||||
async def search_stock(key: str, limit: int = 10):
|
||||
manager = AStockManager()
|
||||
async def search_stock(key: str, limit: int = 10, session: Session = Depends(get_db)):
|
||||
manager = AStockManager(session)
|
||||
result = manager.search_stock(key, limit)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/stock/base", summary="获取股票基础信息")
|
||||
async def get_stock_base(stock_code: str):
|
||||
async def get_stock_base(stock_code: str, session: Session = Depends(get_db)):
|
||||
api = AStockAPI()
|
||||
manager = AStockManager(session)
|
||||
|
||||
result = {}
|
||||
|
||||
|
||||
@ -10,6 +10,8 @@ from datetime import date, timedelta
|
||||
from cryptoai.models.analysis_history import AnalysisHistoryManager
|
||||
from cryptoai.models.user_question import UserQuestionManager
|
||||
from cryptoai.models.token import TokenManager
|
||||
from sqlalchemy.orm import Session
|
||||
from cryptoai.utils.db_manager import get_db
|
||||
class AnalysisHistoryRequest(BaseModel):
|
||||
symbol: str
|
||||
content: str
|
||||
@ -20,9 +22,10 @@ router = APIRouter()
|
||||
|
||||
@router.post("/analysis_history")
|
||||
async def analysis_history(request: AnalysisHistoryRequest,
|
||||
current_user: dict = Depends(get_current_user)):
|
||||
current_user: dict = Depends(get_current_user),
|
||||
session: Session = Depends(get_db)):
|
||||
|
||||
manager = AnalysisHistoryManager()
|
||||
manager = AnalysisHistoryManager(session)
|
||||
manager.add_analysis_history(current_user["id"], request.type, request.symbol, request.content, request.timeframe)
|
||||
|
||||
return {"message": "ok"}
|
||||
@ -30,8 +33,9 @@ async def analysis_history(request: AnalysisHistoryRequest,
|
||||
@router.get("/analysis_histories")
|
||||
async def get_analysis_histories(current_user: dict = Depends(get_current_user),
|
||||
limit: int = 10,
|
||||
offset: int = 0):
|
||||
manager = AnalysisHistoryManager()
|
||||
offset: int = 0,
|
||||
session: Session = Depends(get_db)):
|
||||
manager = AnalysisHistoryManager(session)
|
||||
history = manager.get_user_analysis_history(current_user["id"], limit=limit, offset=offset)
|
||||
return history
|
||||
|
||||
@ -48,7 +52,8 @@ class ChatRequest(BaseModel):
|
||||
|
||||
@router.post("/chat-messages")
|
||||
async def chat(request: ChatRequest,
|
||||
current_user: dict = Depends(get_current_user)):
|
||||
current_user: dict = Depends(get_current_user),
|
||||
session: Session = Depends(get_db)):
|
||||
|
||||
token = 'app-pPtva2AdJ8hJzkBKu12ThWjD'
|
||||
|
||||
@ -68,7 +73,7 @@ async def chat(request: ChatRequest,
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
manager = UserQuestionManager()
|
||||
manager = UserQuestionManager(session)
|
||||
manager.save_user_question(current_user["id"],"chat-messages", request.message)
|
||||
|
||||
response = requests.post(url, headers=headers, json=payload, stream=True)
|
||||
@ -90,11 +95,12 @@ async def chat(request: ChatRequest,
|
||||
|
||||
@router.post("/analysis")
|
||||
async def analysis(request: AnalysisRequest,
|
||||
current_user: dict = Depends(get_current_user)):
|
||||
current_user: dict = Depends(get_current_user),
|
||||
session: Session = Depends(get_db)):
|
||||
|
||||
if request.type == 'crypto':
|
||||
# 检查symbol是否存在
|
||||
manager = TokenManager()
|
||||
manager = TokenManager(session)
|
||||
tokens = manager.search_token(request.symbol)
|
||||
if not tokens or len(tokens) == 0:
|
||||
raise HTTPException(status_code=400, detail="您输入的币种在币安不存在,请检查后重新输入。")
|
||||
@ -111,7 +117,7 @@ async def analysis(request: AnalysisRequest,
|
||||
"user": current_user["mail"]
|
||||
}
|
||||
|
||||
manager = UserQuestionManager()
|
||||
manager = UserQuestionManager(session)
|
||||
manager.save_user_question(current_user["id"], symbol, "请分析以下加密货币:" + symbol + ",并给出分析报告。")
|
||||
|
||||
|
||||
@ -127,7 +133,7 @@ async def analysis(request: AnalysisRequest,
|
||||
"user": current_user["mail"]
|
||||
}
|
||||
|
||||
manager = UserQuestionManager()
|
||||
manager = UserQuestionManager(session)
|
||||
manager.save_user_question(current_user["id"], stock_code, "请分析以下A股股票:" + stock_code + ",并给出分析报告。")
|
||||
|
||||
elif request.type == 'usstock':
|
||||
@ -143,7 +149,7 @@ async def analysis(request: AnalysisRequest,
|
||||
"response_mode": "streaming",
|
||||
"user": current_user["mail"]
|
||||
}
|
||||
manager = UserQuestionManager()
|
||||
manager = UserQuestionManager(session)
|
||||
manager.save_user_question(current_user["id"], stock_code, "请分析以下美股股票:" + stock_code + ",并给出分析报告。")
|
||||
|
||||
else:
|
||||
|
||||
@ -15,6 +15,8 @@ from cryptoai.api.binance_api import get_binance_api
|
||||
from cryptoai.models.data_processor import DataProcessor
|
||||
from datetime import timedelta
|
||||
from cryptoai.models.token import TokenManager
|
||||
from sqlalchemy.orm import Session
|
||||
from cryptoai.utils.db_manager import get_db
|
||||
# 创建路由
|
||||
router = APIRouter()
|
||||
|
||||
@ -22,8 +24,8 @@ logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
@router.get("/search/{key}")
|
||||
async def search_crypto(key: str):
|
||||
manager = TokenManager()
|
||||
async def search_crypto(key: str, session: Session = Depends(get_db)):
|
||||
manager = TokenManager(session)
|
||||
result = manager.search_token(key)
|
||||
return result
|
||||
|
||||
@ -32,9 +34,9 @@ class CryptoAnalysisRequest(BaseModel):
|
||||
timeframe: Optional[str] = None
|
||||
|
||||
@router.get("/kline/{symbol}")
|
||||
async def get_crypto_kline(symbol: str, timeframe: Optional[str] = None, limit: Optional[int] = 100):
|
||||
async def get_crypto_kline(symbol: str, timeframe: Optional[str] = None, limit: Optional[int] = 100, session: Session = Depends(get_db)):
|
||||
# 检查symbol是否存在
|
||||
token_manager = TokenManager()
|
||||
token_manager = TokenManager(session)
|
||||
tokens = token_manager.search_token(symbol)
|
||||
if not tokens or len(tokens) == 0:
|
||||
raise HTTPException(status_code=400, detail="您输入的币种在币安不存在,请检查后重新输入。")
|
||||
|
||||
@ -29,7 +29,7 @@ services:
|
||||
cryptoai-api:
|
||||
build: .
|
||||
container_name: cryptoai-api
|
||||
image: cryptoai-api:0.1.33
|
||||
image: cryptoai-api:0.1.34
|
||||
restart: always
|
||||
ports:
|
||||
- "8000:8000"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user