diff --git a/app/api/deps.py b/app/api/deps.py index d0a17fa..033c250 100644 --- a/app/api/deps.py +++ b/app/api/deps.py @@ -4,19 +4,17 @@ from sqlalchemy.orm import Session from app.models.database import get_db from app.models.user import UserDB, UserRole from app.core.security import verify_token +from fastapi import Request async def get_current_user( authorization: Optional[str] = Header(None), - access_token: Optional[str] = Cookie(None), - db: Session = Depends(get_db) + db: Session = Depends(get_db), + request: Request = None ) -> UserDB: - - # 优先使用Header中的token,其次使用Cookie中的token - token = None if authorization and authorization.startswith("Bearer "): token = authorization.split(" ")[1] - elif access_token: - token = access_token + else: + token = request.session.get("access_token") if not token: raise HTTPException(status_code=401, detail="未提供有效的认证信息") diff --git a/app/api/endpoints/user.py b/app/api/endpoints/user.py index 4700bfa..850b909 100644 --- a/app/api/endpoints/user.py +++ b/app/api/endpoints/user.py @@ -11,7 +11,7 @@ from app.core.config import settings from tencentcloud.common import credential from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException from tencentcloud.sms.v20210111 import sms_client, models -from app.core.security import create_access_token, set_jwt_cookie, clear_jwt_cookie, get_password_hash, verify_password +from app.core.security import create_access_token, get_password_hash, verify_password from app.core.response import success_response, error_response, ResponseModel from pydantic import BaseModel, Field from typing import List @@ -27,6 +27,7 @@ from app.models.user import UserUpdateRoles, UserUpdateDeliveryCommission from app.models.order import ShippingOrderDB, OrderStatus from app.core.redis_client import redis_client import logging +from fastapi import Request router = APIRouter() @@ -68,7 +69,8 @@ async def send_verify_code(request: VerifyCodeRequest): async def login( user_login: UserLogin, db: Session = Depends(get_db), - response: Response = None + response: Response = None, + request: Request = None ): """用户登录""" phone = user_login.phone @@ -105,10 +107,8 @@ async def login( access_token = create_access_token( data={"phone": user.phone} ) - - # 设置JWT cookie - if response: - set_jwt_cookie(response, access_token) + + request.session["access_token"] = access_token return success_response( message="登录成功", @@ -173,25 +173,26 @@ async def get_user_info( @router.post("/phone-login", response_model=ResponseModel) async def phone_login( - request: PhoneLoginRequest, + login_data: PhoneLoginRequest, db: Session = Depends(get_db), - response: Response = None + response: Response = None, + request: Request = None ): """ 手机号登录(测试环境) """ if not settings.DEBUG: return error_response(code=400, message="测试环境不支持手机号登录") # 查找或创建用户 - user = db.query(UserDB).filter(UserDB.phone == request.phone).first() + user = db.query(UserDB).filter(UserDB.phone == login_data.phone).first() if not user: # 生成用户编码 user_code = generate_user_code(db) user = UserDB( - nickname=f"蜜友{request.phone[-4:]}", - phone=request.phone, + nickname=f"蜜友{login_data.phone[-4:]}", + phone=login_data.phone, user_code=user_code, - referral_code=request.referral_code, + referral_code=login_data.referral_code, password=get_password_hash("123456"), roles=[UserRole.USER] ) @@ -205,10 +206,8 @@ async def phone_login( access_token = create_access_token( data={"phone": user.phone} ) - - # 设置JWT cookie - if response: - set_jwt_cookie(response, access_token) + + request.session["access_token"] = access_token return success_response( message="登录成功", @@ -222,10 +221,11 @@ async def phone_login( @router.post("/logout", response_model=ResponseModel) async def logout( response: Response, + request: Request, current_user: UserDB = Depends(get_current_user) ): """退出登录""" - clear_jwt_cookie(response) + request.session.clear() return success_response(message="退出登录成功") @router.put("/update", response_model=ResponseModel) @@ -315,7 +315,8 @@ async def update_user_roles( async def password_login( login_data: UserPasswordLogin, db: Session = Depends(get_db), - response: Response = None + response: Response = None, + request: Request = None ): """密码登录""" user = db.query(UserDB).filter(UserDB.phone == login_data.phone).first() @@ -356,9 +357,7 @@ async def password_login( # 生成访问令牌 access_token = create_access_token(data={"phone": user.phone}) - # 设置JWT cookie - if response: - set_jwt_cookie(response, access_token) + request.session["access_token"] = access_token return success_response( data={ diff --git a/app/api/endpoints/wechat.py b/app/api/endpoints/wechat.py index fd75a4e..4013152 100644 --- a/app/api/endpoints/wechat.py +++ b/app/api/endpoints/wechat.py @@ -5,7 +5,7 @@ from app.models.user import UserInfo,UserDB, PhoneLoginRequest, generate_user_co from app.models.order import ShippingOrderDB, OrderStatus from app.core.response import success_response, error_response, ResponseModel from app.core.wechat import WeChatClient,generate_random_string -from app.core.security import create_access_token, set_jwt_cookie +from app.core.security import create_access_token from pydantic import BaseModel, Field import json import time @@ -113,10 +113,7 @@ async def wechat_phone_login( access_token = create_access_token( data={"phone": user.phone, "userid": user.userid} ) - - # 设置JWT cookie - if response: - set_jwt_cookie(response, access_token) + return success_response( message="登录成功", diff --git a/app/core/security.py b/app/core/security.py index 03010b7..e8802e6 100644 --- a/app/core/security.py +++ b/app/core/security.py @@ -22,27 +22,6 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) - encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm="HS256") return encoded_jwt -def set_jwt_cookie(response: Response, token: str): - """设置JWT cookie""" - response.set_cookie( - key="access_token", - value=token, - httponly=True, # 防止JavaScript访问 - secure=not settings.DEBUG, # 生产环境使用HTTPS - samesite="Lax", # CSRF保护 - expires=datetime.now(timezone.utc) + timedelta(days=180), - max_age=180*24*60*60 # 30天的秒数 - ) - -def clear_jwt_cookie(response: Response): - """清除JWT cookie""" - response.delete_cookie( - key="access_token", - httponly=True, - # secure=not settings.DEBUG, - samesite="lax" - ) - def verify_token(token: str) -> Optional[str]: try: payload = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"]) diff --git a/app/main.py b/app/main.py index 3d4b0ee..9ea4bd4 100644 --- a/app/main.py +++ b/app/main.py @@ -13,6 +13,8 @@ from app.core.config import settings from app.core.wecombot import WecomBot from app.api.endpoints import wecom from app.api.endpoints import feedback +from starlette.middleware.sessions import SessionMiddleware + # 创建数据库表 Base.metadata.create_all(bind=engine) @@ -24,6 +26,7 @@ app = FastAPI( docs_url="/docs" if settings.DEBUG else None ) + app.default_response_class = CustomJSONResponse # 配置 CORS @@ -37,6 +40,8 @@ app.add_middleware( # 添加请求日志中间件 app.add_middleware(RequestLoggerMiddleware) +app.add_middleware(SessionMiddleware, secret_key=settings.SECRET_KEY) + # 添加用户路由 app.include_router(ai.router, prefix="/api/ai", tags=["AI服务"]) diff --git a/requirements.txt b/requirements.txt index ea6f316..f82f265 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,4 +17,5 @@ cryptography==42.0.2 qrcode>=7.3.1 pillow>=9.0.0 pytz==2024.1 -dashscope>=1.13.0 \ No newline at end of file +dashscope>=1.13.0 +starlette-session==0.3.1 \ No newline at end of file