update
This commit is contained in:
parent
f194a99d5b
commit
0b9fd541cc
@ -4,19 +4,17 @@ from sqlalchemy.orm import Session
|
||||
from app.models.database import get_db
|
||||
from app.models.user import UserDB, UserRole
|
||||
from app.core.security import verify_token
|
||||
from fastapi import Request
|
||||
|
||||
async def get_current_user(
|
||||
authorization: Optional[str] = Header(None),
|
||||
access_token: Optional[str] = Cookie(None),
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
request: Request = None
|
||||
) -> UserDB:
|
||||
|
||||
# 优先使用Header中的token,其次使用Cookie中的token
|
||||
token = None
|
||||
if authorization and authorization.startswith("Bearer "):
|
||||
token = authorization.split(" ")[1]
|
||||
elif access_token:
|
||||
token = access_token
|
||||
else:
|
||||
token = request.session.get("access_token")
|
||||
|
||||
if not token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
@ -11,7 +11,7 @@ from app.core.config import settings
|
||||
from tencentcloud.common import credential
|
||||
from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
|
||||
from tencentcloud.sms.v20210111 import sms_client, models
|
||||
from app.core.security import create_access_token, set_jwt_cookie, clear_jwt_cookie, get_password_hash, verify_password
|
||||
from app.core.security import create_access_token, get_password_hash, verify_password
|
||||
from app.core.response import success_response, error_response, ResponseModel
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List
|
||||
@ -27,6 +27,7 @@ from app.models.user import UserUpdateRoles, UserUpdateDeliveryCommission
|
||||
from app.models.order import ShippingOrderDB, OrderStatus
|
||||
from app.core.redis_client import redis_client
|
||||
import logging
|
||||
from fastapi import Request
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@ -68,7 +69,8 @@ async def send_verify_code(request: VerifyCodeRequest):
|
||||
async def login(
|
||||
user_login: UserLogin,
|
||||
db: Session = Depends(get_db),
|
||||
response: Response = None
|
||||
response: Response = None,
|
||||
request: Request = None
|
||||
):
|
||||
"""用户登录"""
|
||||
phone = user_login.phone
|
||||
@ -106,9 +108,7 @@ async def login(
|
||||
data={"phone": user.phone}
|
||||
)
|
||||
|
||||
# 设置JWT cookie
|
||||
if response:
|
||||
set_jwt_cookie(response, access_token)
|
||||
request.session["access_token"] = access_token
|
||||
|
||||
return success_response(
|
||||
message="登录成功",
|
||||
@ -173,25 +173,26 @@ async def get_user_info(
|
||||
|
||||
@router.post("/phone-login", response_model=ResponseModel)
|
||||
async def phone_login(
|
||||
request: PhoneLoginRequest,
|
||||
login_data: PhoneLoginRequest,
|
||||
db: Session = Depends(get_db),
|
||||
response: Response = None
|
||||
response: Response = None,
|
||||
request: Request = None
|
||||
):
|
||||
""" 手机号登录(测试环境) """
|
||||
if not settings.DEBUG:
|
||||
return error_response(code=400, message="测试环境不支持手机号登录")
|
||||
|
||||
# 查找或创建用户
|
||||
user = db.query(UserDB).filter(UserDB.phone == request.phone).first()
|
||||
user = db.query(UserDB).filter(UserDB.phone == login_data.phone).first()
|
||||
if not user:
|
||||
# 生成用户编码
|
||||
user_code = generate_user_code(db)
|
||||
|
||||
user = UserDB(
|
||||
nickname=f"蜜友{request.phone[-4:]}",
|
||||
phone=request.phone,
|
||||
nickname=f"蜜友{login_data.phone[-4:]}",
|
||||
phone=login_data.phone,
|
||||
user_code=user_code,
|
||||
referral_code=request.referral_code,
|
||||
referral_code=login_data.referral_code,
|
||||
password=get_password_hash("123456"),
|
||||
roles=[UserRole.USER]
|
||||
)
|
||||
@ -206,9 +207,7 @@ async def phone_login(
|
||||
data={"phone": user.phone}
|
||||
)
|
||||
|
||||
# 设置JWT cookie
|
||||
if response:
|
||||
set_jwt_cookie(response, access_token)
|
||||
request.session["access_token"] = access_token
|
||||
|
||||
return success_response(
|
||||
message="登录成功",
|
||||
@ -222,10 +221,11 @@ async def phone_login(
|
||||
@router.post("/logout", response_model=ResponseModel)
|
||||
async def logout(
|
||||
response: Response,
|
||||
request: Request,
|
||||
current_user: UserDB = Depends(get_current_user)
|
||||
):
|
||||
"""退出登录"""
|
||||
clear_jwt_cookie(response)
|
||||
request.session.clear()
|
||||
return success_response(message="退出登录成功")
|
||||
|
||||
@router.put("/update", response_model=ResponseModel)
|
||||
@ -315,7 +315,8 @@ async def update_user_roles(
|
||||
async def password_login(
|
||||
login_data: UserPasswordLogin,
|
||||
db: Session = Depends(get_db),
|
||||
response: Response = None
|
||||
response: Response = None,
|
||||
request: Request = None
|
||||
):
|
||||
"""密码登录"""
|
||||
user = db.query(UserDB).filter(UserDB.phone == login_data.phone).first()
|
||||
@ -356,9 +357,7 @@ async def password_login(
|
||||
# 生成访问令牌
|
||||
access_token = create_access_token(data={"phone": user.phone})
|
||||
|
||||
# 设置JWT cookie
|
||||
if response:
|
||||
set_jwt_cookie(response, access_token)
|
||||
request.session["access_token"] = access_token
|
||||
|
||||
return success_response(
|
||||
data={
|
||||
|
||||
@ -5,7 +5,7 @@ from app.models.user import UserInfo,UserDB, PhoneLoginRequest, generate_user_co
|
||||
from app.models.order import ShippingOrderDB, OrderStatus
|
||||
from app.core.response import success_response, error_response, ResponseModel
|
||||
from app.core.wechat import WeChatClient,generate_random_string
|
||||
from app.core.security import create_access_token, set_jwt_cookie
|
||||
from app.core.security import create_access_token
|
||||
from pydantic import BaseModel, Field
|
||||
import json
|
||||
import time
|
||||
@ -114,9 +114,6 @@ async def wechat_phone_login(
|
||||
data={"phone": user.phone, "userid": user.userid}
|
||||
)
|
||||
|
||||
# 设置JWT cookie
|
||||
if response:
|
||||
set_jwt_cookie(response, access_token)
|
||||
|
||||
return success_response(
|
||||
message="登录成功",
|
||||
|
||||
@ -22,27 +22,6 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -
|
||||
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm="HS256")
|
||||
return encoded_jwt
|
||||
|
||||
def set_jwt_cookie(response: Response, token: str):
|
||||
"""设置JWT cookie"""
|
||||
response.set_cookie(
|
||||
key="access_token",
|
||||
value=token,
|
||||
httponly=True, # 防止JavaScript访问
|
||||
secure=not settings.DEBUG, # 生产环境使用HTTPS
|
||||
samesite="Lax", # CSRF保护
|
||||
expires=datetime.now(timezone.utc) + timedelta(days=180),
|
||||
max_age=180*24*60*60 # 30天的秒数
|
||||
)
|
||||
|
||||
def clear_jwt_cookie(response: Response):
|
||||
"""清除JWT cookie"""
|
||||
response.delete_cookie(
|
||||
key="access_token",
|
||||
httponly=True,
|
||||
# secure=not settings.DEBUG,
|
||||
samesite="lax"
|
||||
)
|
||||
|
||||
def verify_token(token: str) -> Optional[str]:
|
||||
try:
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"])
|
||||
|
||||
@ -13,6 +13,8 @@ from app.core.config import settings
|
||||
from app.core.wecombot import WecomBot
|
||||
from app.api.endpoints import wecom
|
||||
from app.api.endpoints import feedback
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
|
||||
|
||||
# 创建数据库表
|
||||
Base.metadata.create_all(bind=engine)
|
||||
@ -24,6 +26,7 @@ app = FastAPI(
|
||||
docs_url="/docs" if settings.DEBUG else None
|
||||
)
|
||||
|
||||
|
||||
app.default_response_class = CustomJSONResponse
|
||||
|
||||
# 配置 CORS
|
||||
@ -37,6 +40,8 @@ app.add_middleware(
|
||||
|
||||
# 添加请求日志中间件
|
||||
app.add_middleware(RequestLoggerMiddleware)
|
||||
app.add_middleware(SessionMiddleware, secret_key=settings.SECRET_KEY)
|
||||
|
||||
|
||||
# 添加用户路由
|
||||
app.include_router(ai.router, prefix="/api/ai", tags=["AI服务"])
|
||||
|
||||
@ -18,3 +18,4 @@ qrcode>=7.3.1
|
||||
pillow>=9.0.0
|
||||
pytz==2024.1
|
||||
dashscope>=1.13.0
|
||||
starlette-session==0.3.1
|
||||
Loading…
Reference in New Issue
Block a user