This commit is contained in:
aaron 2025-03-09 16:56:00 +08:00
parent f194a99d5b
commit 0b9fd541cc
6 changed files with 34 additions and 55 deletions

View File

@ -4,19 +4,17 @@ from sqlalchemy.orm import Session
from app.models.database import get_db
from app.models.user import UserDB, UserRole
from app.core.security import verify_token
from fastapi import Request
async def get_current_user(
authorization: Optional[str] = Header(None),
access_token: Optional[str] = Cookie(None),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
request: Request = None
) -> UserDB:
# 优先使用Header中的token其次使用Cookie中的token
token = None
if authorization and authorization.startswith("Bearer "):
token = authorization.split(" ")[1]
elif access_token:
token = access_token
else:
token = request.session.get("access_token")
if not token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息")

View File

@ -11,7 +11,7 @@ from app.core.config import settings
from tencentcloud.common import credential
from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
from tencentcloud.sms.v20210111 import sms_client, models
from app.core.security import create_access_token, set_jwt_cookie, clear_jwt_cookie, get_password_hash, verify_password
from app.core.security import create_access_token, get_password_hash, verify_password
from app.core.response import success_response, error_response, ResponseModel
from pydantic import BaseModel, Field
from typing import List
@ -27,6 +27,7 @@ from app.models.user import UserUpdateRoles, UserUpdateDeliveryCommission
from app.models.order import ShippingOrderDB, OrderStatus
from app.core.redis_client import redis_client
import logging
from fastapi import Request
router = APIRouter()
@ -68,7 +69,8 @@ async def send_verify_code(request: VerifyCodeRequest):
async def login(
user_login: UserLogin,
db: Session = Depends(get_db),
response: Response = None
response: Response = None,
request: Request = None
):
"""用户登录"""
phone = user_login.phone
@ -105,10 +107,8 @@ async def login(
access_token = create_access_token(
data={"phone": user.phone}
)
# 设置JWT cookie
if response:
set_jwt_cookie(response, access_token)
request.session["access_token"] = access_token
return success_response(
message="登录成功",
@ -173,25 +173,26 @@ async def get_user_info(
@router.post("/phone-login", response_model=ResponseModel)
async def phone_login(
request: PhoneLoginRequest,
login_data: PhoneLoginRequest,
db: Session = Depends(get_db),
response: Response = None
response: Response = None,
request: Request = None
):
""" 手机号登录(测试环境) """
if not settings.DEBUG:
return error_response(code=400, message="测试环境不支持手机号登录")
# 查找或创建用户
user = db.query(UserDB).filter(UserDB.phone == request.phone).first()
user = db.query(UserDB).filter(UserDB.phone == login_data.phone).first()
if not user:
# 生成用户编码
user_code = generate_user_code(db)
user = UserDB(
nickname=f"蜜友{request.phone[-4:]}",
phone=request.phone,
nickname=f"蜜友{login_data.phone[-4:]}",
phone=login_data.phone,
user_code=user_code,
referral_code=request.referral_code,
referral_code=login_data.referral_code,
password=get_password_hash("123456"),
roles=[UserRole.USER]
)
@ -205,10 +206,8 @@ async def phone_login(
access_token = create_access_token(
data={"phone": user.phone}
)
# 设置JWT cookie
if response:
set_jwt_cookie(response, access_token)
request.session["access_token"] = access_token
return success_response(
message="登录成功",
@ -222,10 +221,11 @@ async def phone_login(
@router.post("/logout", response_model=ResponseModel)
async def logout(
response: Response,
request: Request,
current_user: UserDB = Depends(get_current_user)
):
"""退出登录"""
clear_jwt_cookie(response)
request.session.clear()
return success_response(message="退出登录成功")
@router.put("/update", response_model=ResponseModel)
@ -315,7 +315,8 @@ async def update_user_roles(
async def password_login(
login_data: UserPasswordLogin,
db: Session = Depends(get_db),
response: Response = None
response: Response = None,
request: Request = None
):
"""密码登录"""
user = db.query(UserDB).filter(UserDB.phone == login_data.phone).first()
@ -356,9 +357,7 @@ async def password_login(
# 生成访问令牌
access_token = create_access_token(data={"phone": user.phone})
# 设置JWT cookie
if response:
set_jwt_cookie(response, access_token)
request.session["access_token"] = access_token
return success_response(
data={

View File

@ -5,7 +5,7 @@ from app.models.user import UserInfo,UserDB, PhoneLoginRequest, generate_user_co
from app.models.order import ShippingOrderDB, OrderStatus
from app.core.response import success_response, error_response, ResponseModel
from app.core.wechat import WeChatClient,generate_random_string
from app.core.security import create_access_token, set_jwt_cookie
from app.core.security import create_access_token
from pydantic import BaseModel, Field
import json
import time
@ -113,10 +113,7 @@ async def wechat_phone_login(
access_token = create_access_token(
data={"phone": user.phone, "userid": user.userid}
)
# 设置JWT cookie
if response:
set_jwt_cookie(response, access_token)
return success_response(
message="登录成功",

View File

@ -22,27 +22,6 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm="HS256")
return encoded_jwt
def set_jwt_cookie(response: Response, token: str):
"""设置JWT cookie"""
response.set_cookie(
key="access_token",
value=token,
httponly=True, # 防止JavaScript访问
secure=not settings.DEBUG, # 生产环境使用HTTPS
samesite="Lax", # CSRF保护
expires=datetime.now(timezone.utc) + timedelta(days=180),
max_age=180*24*60*60 # 30天的秒数
)
def clear_jwt_cookie(response: Response):
"""清除JWT cookie"""
response.delete_cookie(
key="access_token",
httponly=True,
# secure=not settings.DEBUG,
samesite="lax"
)
def verify_token(token: str) -> Optional[str]:
try:
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"])

View File

@ -13,6 +13,8 @@ from app.core.config import settings
from app.core.wecombot import WecomBot
from app.api.endpoints import wecom
from app.api.endpoints import feedback
from starlette.middleware.sessions import SessionMiddleware
# 创建数据库表
Base.metadata.create_all(bind=engine)
@ -24,6 +26,7 @@ app = FastAPI(
docs_url="/docs" if settings.DEBUG else None
)
app.default_response_class = CustomJSONResponse
# 配置 CORS
@ -37,6 +40,8 @@ app.add_middleware(
# 添加请求日志中间件
app.add_middleware(RequestLoggerMiddleware)
app.add_middleware(SessionMiddleware, secret_key=settings.SECRET_KEY)
# 添加用户路由
app.include_router(ai.router, prefix="/api/ai", tags=["AI服务"])

View File

@ -17,4 +17,5 @@ cryptography==42.0.2
qrcode>=7.3.1
pillow>=9.0.0
pytz==2024.1
dashscope>=1.13.0
dashscope>=1.13.0
starlette-session==0.3.1