update
This commit is contained in:
parent
08c4efcfd3
commit
da278dc033
@ -12,6 +12,8 @@ from fastapi.responses import StreamingResponse
|
|||||||
from cryptoai.routes.user import get_current_user
|
from cryptoai.routes.user import get_current_user
|
||||||
import requests
|
import requests
|
||||||
from cryptoai.models.astock import AStockManager
|
from cryptoai.models.astock import AStockManager
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from cryptoai.utils.db_manager import get_db
|
||||||
# 创建路由
|
# 创建路由
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@ -19,16 +21,17 @@ logger = logging.getLogger(__name__)
|
|||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
@router.get("/stock/search")
|
@router.get("/stock/search")
|
||||||
async def search_stock(key: str, limit: int = 10):
|
async def search_stock(key: str, limit: int = 10, session: Session = Depends(get_db)):
|
||||||
manager = AStockManager()
|
manager = AStockManager(session)
|
||||||
result = manager.search_stock(key, limit)
|
result = manager.search_stock(key, limit)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@router.get("/stock/base", summary="获取股票基础信息")
|
@router.get("/stock/base", summary="获取股票基础信息")
|
||||||
async def get_stock_base(stock_code: str):
|
async def get_stock_base(stock_code: str, session: Session = Depends(get_db)):
|
||||||
api = AStockAPI()
|
api = AStockAPI()
|
||||||
|
manager = AStockManager(session)
|
||||||
|
|
||||||
result = {}
|
result = {}
|
||||||
|
|
||||||
|
|||||||
@ -10,6 +10,8 @@ from datetime import date, timedelta
|
|||||||
from cryptoai.models.analysis_history import AnalysisHistoryManager
|
from cryptoai.models.analysis_history import AnalysisHistoryManager
|
||||||
from cryptoai.models.user_question import UserQuestionManager
|
from cryptoai.models.user_question import UserQuestionManager
|
||||||
from cryptoai.models.token import TokenManager
|
from cryptoai.models.token import TokenManager
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from cryptoai.utils.db_manager import get_db
|
||||||
class AnalysisHistoryRequest(BaseModel):
|
class AnalysisHistoryRequest(BaseModel):
|
||||||
symbol: str
|
symbol: str
|
||||||
content: str
|
content: str
|
||||||
@ -20,9 +22,10 @@ router = APIRouter()
|
|||||||
|
|
||||||
@router.post("/analysis_history")
|
@router.post("/analysis_history")
|
||||||
async def analysis_history(request: AnalysisHistoryRequest,
|
async def analysis_history(request: AnalysisHistoryRequest,
|
||||||
current_user: dict = Depends(get_current_user)):
|
current_user: dict = Depends(get_current_user),
|
||||||
|
session: Session = Depends(get_db)):
|
||||||
|
|
||||||
manager = AnalysisHistoryManager()
|
manager = AnalysisHistoryManager(session)
|
||||||
manager.add_analysis_history(current_user["id"], request.type, request.symbol, request.content, request.timeframe)
|
manager.add_analysis_history(current_user["id"], request.type, request.symbol, request.content, request.timeframe)
|
||||||
|
|
||||||
return {"message": "ok"}
|
return {"message": "ok"}
|
||||||
@ -30,8 +33,9 @@ async def analysis_history(request: AnalysisHistoryRequest,
|
|||||||
@router.get("/analysis_histories")
|
@router.get("/analysis_histories")
|
||||||
async def get_analysis_histories(current_user: dict = Depends(get_current_user),
|
async def get_analysis_histories(current_user: dict = Depends(get_current_user),
|
||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
offset: int = 0):
|
offset: int = 0,
|
||||||
manager = AnalysisHistoryManager()
|
session: Session = Depends(get_db)):
|
||||||
|
manager = AnalysisHistoryManager(session)
|
||||||
history = manager.get_user_analysis_history(current_user["id"], limit=limit, offset=offset)
|
history = manager.get_user_analysis_history(current_user["id"], limit=limit, offset=offset)
|
||||||
return history
|
return history
|
||||||
|
|
||||||
@ -48,7 +52,8 @@ class ChatRequest(BaseModel):
|
|||||||
|
|
||||||
@router.post("/chat-messages")
|
@router.post("/chat-messages")
|
||||||
async def chat(request: ChatRequest,
|
async def chat(request: ChatRequest,
|
||||||
current_user: dict = Depends(get_current_user)):
|
current_user: dict = Depends(get_current_user),
|
||||||
|
session: Session = Depends(get_db)):
|
||||||
|
|
||||||
token = 'app-pPtva2AdJ8hJzkBKu12ThWjD'
|
token = 'app-pPtva2AdJ8hJzkBKu12ThWjD'
|
||||||
|
|
||||||
@ -68,7 +73,7 @@ async def chat(request: ChatRequest,
|
|||||||
'Content-Type': 'application/json'
|
'Content-Type': 'application/json'
|
||||||
}
|
}
|
||||||
|
|
||||||
manager = UserQuestionManager()
|
manager = UserQuestionManager(session)
|
||||||
manager.save_user_question(current_user["id"],"chat-messages", request.message)
|
manager.save_user_question(current_user["id"],"chat-messages", request.message)
|
||||||
|
|
||||||
response = requests.post(url, headers=headers, json=payload, stream=True)
|
response = requests.post(url, headers=headers, json=payload, stream=True)
|
||||||
@ -90,11 +95,12 @@ async def chat(request: ChatRequest,
|
|||||||
|
|
||||||
@router.post("/analysis")
|
@router.post("/analysis")
|
||||||
async def analysis(request: AnalysisRequest,
|
async def analysis(request: AnalysisRequest,
|
||||||
current_user: dict = Depends(get_current_user)):
|
current_user: dict = Depends(get_current_user),
|
||||||
|
session: Session = Depends(get_db)):
|
||||||
|
|
||||||
if request.type == 'crypto':
|
if request.type == 'crypto':
|
||||||
# 检查symbol是否存在
|
# 检查symbol是否存在
|
||||||
manager = TokenManager()
|
manager = TokenManager(session)
|
||||||
tokens = manager.search_token(request.symbol)
|
tokens = manager.search_token(request.symbol)
|
||||||
if not tokens or len(tokens) == 0:
|
if not tokens or len(tokens) == 0:
|
||||||
raise HTTPException(status_code=400, detail="您输入的币种在币安不存在,请检查后重新输入。")
|
raise HTTPException(status_code=400, detail="您输入的币种在币安不存在,请检查后重新输入。")
|
||||||
@ -111,7 +117,7 @@ async def analysis(request: AnalysisRequest,
|
|||||||
"user": current_user["mail"]
|
"user": current_user["mail"]
|
||||||
}
|
}
|
||||||
|
|
||||||
manager = UserQuestionManager()
|
manager = UserQuestionManager(session)
|
||||||
manager.save_user_question(current_user["id"], symbol, "请分析以下加密货币:" + symbol + ",并给出分析报告。")
|
manager.save_user_question(current_user["id"], symbol, "请分析以下加密货币:" + symbol + ",并给出分析报告。")
|
||||||
|
|
||||||
|
|
||||||
@ -127,7 +133,7 @@ async def analysis(request: AnalysisRequest,
|
|||||||
"user": current_user["mail"]
|
"user": current_user["mail"]
|
||||||
}
|
}
|
||||||
|
|
||||||
manager = UserQuestionManager()
|
manager = UserQuestionManager(session)
|
||||||
manager.save_user_question(current_user["id"], stock_code, "请分析以下A股股票:" + stock_code + ",并给出分析报告。")
|
manager.save_user_question(current_user["id"], stock_code, "请分析以下A股股票:" + stock_code + ",并给出分析报告。")
|
||||||
|
|
||||||
elif request.type == 'usstock':
|
elif request.type == 'usstock':
|
||||||
@ -143,7 +149,7 @@ async def analysis(request: AnalysisRequest,
|
|||||||
"response_mode": "streaming",
|
"response_mode": "streaming",
|
||||||
"user": current_user["mail"]
|
"user": current_user["mail"]
|
||||||
}
|
}
|
||||||
manager = UserQuestionManager()
|
manager = UserQuestionManager(session)
|
||||||
manager.save_user_question(current_user["id"], stock_code, "请分析以下美股股票:" + stock_code + ",并给出分析报告。")
|
manager.save_user_question(current_user["id"], stock_code, "请分析以下美股股票:" + stock_code + ",并给出分析报告。")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -15,6 +15,8 @@ from cryptoai.api.binance_api import get_binance_api
|
|||||||
from cryptoai.models.data_processor import DataProcessor
|
from cryptoai.models.data_processor import DataProcessor
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from cryptoai.models.token import TokenManager
|
from cryptoai.models.token import TokenManager
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from cryptoai.utils.db_manager import get_db
|
||||||
# 创建路由
|
# 创建路由
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@ -22,8 +24,8 @@ logger = logging.getLogger(__name__)
|
|||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
@router.get("/search/{key}")
|
@router.get("/search/{key}")
|
||||||
async def search_crypto(key: str):
|
async def search_crypto(key: str, session: Session = Depends(get_db)):
|
||||||
manager = TokenManager()
|
manager = TokenManager(session)
|
||||||
result = manager.search_token(key)
|
result = manager.search_token(key)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@ -32,9 +34,9 @@ class CryptoAnalysisRequest(BaseModel):
|
|||||||
timeframe: Optional[str] = None
|
timeframe: Optional[str] = None
|
||||||
|
|
||||||
@router.get("/kline/{symbol}")
|
@router.get("/kline/{symbol}")
|
||||||
async def get_crypto_kline(symbol: str, timeframe: Optional[str] = None, limit: Optional[int] = 100):
|
async def get_crypto_kline(symbol: str, timeframe: Optional[str] = None, limit: Optional[int] = 100, session: Session = Depends(get_db)):
|
||||||
# 检查symbol是否存在
|
# 检查symbol是否存在
|
||||||
token_manager = TokenManager()
|
token_manager = TokenManager(session)
|
||||||
tokens = token_manager.search_token(symbol)
|
tokens = token_manager.search_token(symbol)
|
||||||
if not tokens or len(tokens) == 0:
|
if not tokens or len(tokens) == 0:
|
||||||
raise HTTPException(status_code=400, detail="您输入的币种在币安不存在,请检查后重新输入。")
|
raise HTTPException(status_code=400, detail="您输入的币种在币安不存在,请检查后重新输入。")
|
||||||
|
|||||||
@ -29,7 +29,7 @@ services:
|
|||||||
cryptoai-api:
|
cryptoai-api:
|
||||||
build: .
|
build: .
|
||||||
container_name: cryptoai-api
|
container_name: cryptoai-api
|
||||||
image: cryptoai-api:0.1.33
|
image: cryptoai-api:0.1.34
|
||||||
restart: always
|
restart: always
|
||||||
ports:
|
ports:
|
||||||
- "8000:8000"
|
- "8000:8000"
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user