This commit is contained in:
aaron 2025-03-09 16:56:00 +08:00
parent f194a99d5b
commit 0b9fd541cc
6 changed files with 34 additions and 55 deletions

View File

@ -4,19 +4,17 @@ from sqlalchemy.orm import Session
from app.models.database import get_db from app.models.database import get_db
from app.models.user import UserDB, UserRole from app.models.user import UserDB, UserRole
from app.core.security import verify_token from app.core.security import verify_token
from fastapi import Request
async def get_current_user( async def get_current_user(
authorization: Optional[str] = Header(None), authorization: Optional[str] = Header(None),
access_token: Optional[str] = Cookie(None), db: Session = Depends(get_db),
db: Session = Depends(get_db) request: Request = None
) -> UserDB: ) -> UserDB:
# 优先使用Header中的token其次使用Cookie中的token
token = None
if authorization and authorization.startswith("Bearer "): if authorization and authorization.startswith("Bearer "):
token = authorization.split(" ")[1] token = authorization.split(" ")[1]
elif access_token: else:
token = access_token token = request.session.get("access_token")
if not token: if not token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息") raise HTTPException(status_code=401, detail="未提供有效的认证信息")

View File

@ -11,7 +11,7 @@ from app.core.config import settings
from tencentcloud.common import credential from tencentcloud.common import credential
from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
from tencentcloud.sms.v20210111 import sms_client, models from tencentcloud.sms.v20210111 import sms_client, models
from app.core.security import create_access_token, set_jwt_cookie, clear_jwt_cookie, get_password_hash, verify_password from app.core.security import create_access_token, get_password_hash, verify_password
from app.core.response import success_response, error_response, ResponseModel from app.core.response import success_response, error_response, ResponseModel
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import List from typing import List
@ -27,6 +27,7 @@ from app.models.user import UserUpdateRoles, UserUpdateDeliveryCommission
from app.models.order import ShippingOrderDB, OrderStatus from app.models.order import ShippingOrderDB, OrderStatus
from app.core.redis_client import redis_client from app.core.redis_client import redis_client
import logging import logging
from fastapi import Request
router = APIRouter() router = APIRouter()
@ -68,7 +69,8 @@ async def send_verify_code(request: VerifyCodeRequest):
async def login( async def login(
user_login: UserLogin, user_login: UserLogin,
db: Session = Depends(get_db), db: Session = Depends(get_db),
response: Response = None response: Response = None,
request: Request = None
): ):
"""用户登录""" """用户登录"""
phone = user_login.phone phone = user_login.phone
@ -105,10 +107,8 @@ async def login(
access_token = create_access_token( access_token = create_access_token(
data={"phone": user.phone} data={"phone": user.phone}
) )
# 设置JWT cookie request.session["access_token"] = access_token
if response:
set_jwt_cookie(response, access_token)
return success_response( return success_response(
message="登录成功", message="登录成功",
@ -173,25 +173,26 @@ async def get_user_info(
@router.post("/phone-login", response_model=ResponseModel) @router.post("/phone-login", response_model=ResponseModel)
async def phone_login( async def phone_login(
request: PhoneLoginRequest, login_data: PhoneLoginRequest,
db: Session = Depends(get_db), db: Session = Depends(get_db),
response: Response = None response: Response = None,
request: Request = None
): ):
""" 手机号登录(测试环境) """ """ 手机号登录(测试环境) """
if not settings.DEBUG: if not settings.DEBUG:
return error_response(code=400, message="测试环境不支持手机号登录") return error_response(code=400, message="测试环境不支持手机号登录")
# 查找或创建用户 # 查找或创建用户
user = db.query(UserDB).filter(UserDB.phone == request.phone).first() user = db.query(UserDB).filter(UserDB.phone == login_data.phone).first()
if not user: if not user:
# 生成用户编码 # 生成用户编码
user_code = generate_user_code(db) user_code = generate_user_code(db)
user = UserDB( user = UserDB(
nickname=f"蜜友{request.phone[-4:]}", nickname=f"蜜友{login_data.phone[-4:]}",
phone=request.phone, phone=login_data.phone,
user_code=user_code, user_code=user_code,
referral_code=request.referral_code, referral_code=login_data.referral_code,
password=get_password_hash("123456"), password=get_password_hash("123456"),
roles=[UserRole.USER] roles=[UserRole.USER]
) )
@ -205,10 +206,8 @@ async def phone_login(
access_token = create_access_token( access_token = create_access_token(
data={"phone": user.phone} data={"phone": user.phone}
) )
# 设置JWT cookie request.session["access_token"] = access_token
if response:
set_jwt_cookie(response, access_token)
return success_response( return success_response(
message="登录成功", message="登录成功",
@ -222,10 +221,11 @@ async def phone_login(
@router.post("/logout", response_model=ResponseModel) @router.post("/logout", response_model=ResponseModel)
async def logout( async def logout(
response: Response, response: Response,
request: Request,
current_user: UserDB = Depends(get_current_user) current_user: UserDB = Depends(get_current_user)
): ):
"""退出登录""" """退出登录"""
clear_jwt_cookie(response) request.session.clear()
return success_response(message="退出登录成功") return success_response(message="退出登录成功")
@router.put("/update", response_model=ResponseModel) @router.put("/update", response_model=ResponseModel)
@ -315,7 +315,8 @@ async def update_user_roles(
async def password_login( async def password_login(
login_data: UserPasswordLogin, login_data: UserPasswordLogin,
db: Session = Depends(get_db), db: Session = Depends(get_db),
response: Response = None response: Response = None,
request: Request = None
): ):
"""密码登录""" """密码登录"""
user = db.query(UserDB).filter(UserDB.phone == login_data.phone).first() user = db.query(UserDB).filter(UserDB.phone == login_data.phone).first()
@ -356,9 +357,7 @@ async def password_login(
# 生成访问令牌 # 生成访问令牌
access_token = create_access_token(data={"phone": user.phone}) access_token = create_access_token(data={"phone": user.phone})
# 设置JWT cookie request.session["access_token"] = access_token
if response:
set_jwt_cookie(response, access_token)
return success_response( return success_response(
data={ data={

View File

@ -5,7 +5,7 @@ from app.models.user import UserInfo,UserDB, PhoneLoginRequest, generate_user_co
from app.models.order import ShippingOrderDB, OrderStatus from app.models.order import ShippingOrderDB, OrderStatus
from app.core.response import success_response, error_response, ResponseModel from app.core.response import success_response, error_response, ResponseModel
from app.core.wechat import WeChatClient,generate_random_string from app.core.wechat import WeChatClient,generate_random_string
from app.core.security import create_access_token, set_jwt_cookie from app.core.security import create_access_token
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
import json import json
import time import time
@ -113,10 +113,7 @@ async def wechat_phone_login(
access_token = create_access_token( access_token = create_access_token(
data={"phone": user.phone, "userid": user.userid} data={"phone": user.phone, "userid": user.userid}
) )
# 设置JWT cookie
if response:
set_jwt_cookie(response, access_token)
return success_response( return success_response(
message="登录成功", message="登录成功",

View File

@ -22,27 +22,6 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm="HS256") encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm="HS256")
return encoded_jwt return encoded_jwt
def set_jwt_cookie(response: Response, token: str):
"""设置JWT cookie"""
response.set_cookie(
key="access_token",
value=token,
httponly=True, # 防止JavaScript访问
secure=not settings.DEBUG, # 生产环境使用HTTPS
samesite="Lax", # CSRF保护
expires=datetime.now(timezone.utc) + timedelta(days=180),
max_age=180*24*60*60 # 30天的秒数
)
def clear_jwt_cookie(response: Response):
"""清除JWT cookie"""
response.delete_cookie(
key="access_token",
httponly=True,
# secure=not settings.DEBUG,
samesite="lax"
)
def verify_token(token: str) -> Optional[str]: def verify_token(token: str) -> Optional[str]:
try: try:
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"]) payload = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"])

View File

@ -13,6 +13,8 @@ from app.core.config import settings
from app.core.wecombot import WecomBot from app.core.wecombot import WecomBot
from app.api.endpoints import wecom from app.api.endpoints import wecom
from app.api.endpoints import feedback from app.api.endpoints import feedback
from starlette.middleware.sessions import SessionMiddleware
# 创建数据库表 # 创建数据库表
Base.metadata.create_all(bind=engine) Base.metadata.create_all(bind=engine)
@ -24,6 +26,7 @@ app = FastAPI(
docs_url="/docs" if settings.DEBUG else None docs_url="/docs" if settings.DEBUG else None
) )
app.default_response_class = CustomJSONResponse app.default_response_class = CustomJSONResponse
# 配置 CORS # 配置 CORS
@ -37,6 +40,8 @@ app.add_middleware(
# 添加请求日志中间件 # 添加请求日志中间件
app.add_middleware(RequestLoggerMiddleware) app.add_middleware(RequestLoggerMiddleware)
app.add_middleware(SessionMiddleware, secret_key=settings.SECRET_KEY)
# 添加用户路由 # 添加用户路由
app.include_router(ai.router, prefix="/api/ai", tags=["AI服务"]) app.include_router(ai.router, prefix="/api/ai", tags=["AI服务"])

View File

@ -17,4 +17,5 @@ cryptography==42.0.2
qrcode>=7.3.1 qrcode>=7.3.1
pillow>=9.0.0 pillow>=9.0.0
pytz==2024.1 pytz==2024.1
dashscope>=1.13.0 dashscope>=1.13.0
starlette-session==0.3.1