diff --git a/app/api/deps.py b/app/api/deps.py index d972835..614868c 100644 --- a/app/api/deps.py +++ b/app/api/deps.py @@ -1,13 +1,30 @@ -from fastapi import Depends, HTTPException +from fastapi import Depends, HTTPException, Header +from typing import Optional from sqlalchemy.orm import Session from app.models.database import get_db from app.models.user import UserDB +from app.core.security import verify_token async def get_current_user( - phone: str, + authorization: Optional[str] = Header(None), db: Session = Depends(get_db) ) -> UserDB: + if not authorization or not authorization.startswith("Bearer "): + raise HTTPException(status_code=401, detail="未提供有效的认证信息") + + token = authorization.split(" ")[1] + phone = verify_token(token) + if not phone: + raise HTTPException(status_code=401, detail="Token已过期或无效") + user = db.query(UserDB).filter(UserDB.phone == phone).first() if not user: raise HTTPException(status_code=401, detail="用户未登录") - return user \ No newline at end of file + return user + +async def get_admin_user( + current_user: UserDB = Depends(get_current_user) +) -> UserDB: + if not current_user.is_admin: + raise HTTPException(status_code=403, detail="需要管理员权限") + return current_user \ No newline at end of file diff --git a/app/api/endpoints/address.py b/app/api/endpoints/address.py index 9db51db..6d2545c 100644 --- a/app/api/endpoints/address.py +++ b/app/api/endpoints/address.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter, HTTPException, Depends +from fastapi import APIRouter, Depends from sqlalchemy.orm import Session from sqlalchemy import and_ from typing import List @@ -6,17 +6,17 @@ from app.models.address import AddressDB, AddressCreate, AddressUpdate, AddressI from app.models.database import get_db from app.api.deps import get_current_user from app.models.user import UserDB +from app.core.response import success_response, error_response, ResponseModel router = APIRouter() -@router.post("/", response_model=AddressInfo) +@router.post("/", response_model=ResponseModel) async def create_address( address: AddressCreate, db: Session = Depends(get_db), current_user: UserDB = Depends(get_current_user) ): """创建配送地址""" - # 如果设置为默认地址,先将其他地址的默认状态取消 if address.is_default: db.query(AddressDB).filter( and_( @@ -32,9 +32,9 @@ async def create_address( db.add(db_address) db.commit() db.refresh(db_address) - return db_address + return success_response(data=AddressInfo.model_validate(db_address)) -@router.get("/", response_model=List[AddressInfo]) +@router.get("/", response_model=ResponseModel) async def get_addresses( db: Session = Depends(get_db), current_user: UserDB = Depends(get_current_user) @@ -43,9 +43,9 @@ async def get_addresses( addresses = db.query(AddressDB).filter( AddressDB.user_id == current_user.userid ).all() - return addresses + return success_response(data=[AddressInfo.model_validate(a) for a in addresses]) -@router.put("/{address_id}", response_model=AddressInfo) +@router.put("/{address_id}", response_model=ResponseModel) async def update_address( address_id: int, address: AddressUpdate, @@ -61,9 +61,8 @@ async def update_address( ).first() if not db_address: - raise HTTPException(status_code=404, detail="地址不存在") + return error_response(code=404, message="地址不存在") - # 如果设置为默认地址,先将其他地址的默认状态取消 update_data = address.model_dump(exclude_unset=True) if update_data.get("is_default"): db.query(AddressDB).filter( @@ -78,9 +77,9 @@ async def update_address( db.commit() db.refresh(db_address) - return db_address + return success_response(data=AddressInfo.model_validate(db_address)) -@router.delete("/{address_id}") +@router.delete("/{address_id}", response_model=ResponseModel) async def delete_address( address_id: int, db: Session = Depends(get_db), @@ -95,19 +94,18 @@ async def delete_address( ).delete() if not result: - raise HTTPException(status_code=404, detail="地址不存在") + return error_response(code=404, message="地址不存在") db.commit() - return {"message": "地址已删除"} + return success_response(message="地址已删除") -@router.post("/{address_id}/set-default", response_model=AddressInfo) +@router.post("/{address_id}/set-default", response_model=ResponseModel) async def set_default_address( address_id: int, db: Session = Depends(get_db), current_user: UserDB = Depends(get_current_user) ): """设置默认地址""" - # 取消其他默认地址 db.query(AddressDB).filter( and_( AddressDB.user_id == current_user.userid, @@ -115,7 +113,6 @@ async def set_default_address( ) ).update({"is_default": False}) - # 设置新的默认地址 db_address = db.query(AddressDB).filter( and_( AddressDB.id == address_id, @@ -124,9 +121,9 @@ async def set_default_address( ).first() if not db_address: - raise HTTPException(status_code=404, detail="地址不存在") + return error_response(code=404, message="地址不存在") db_address.is_default = True db.commit() db.refresh(db_address) - return db_address \ No newline at end of file + return success_response(data=AddressInfo.model_validate(db_address)) \ No newline at end of file diff --git a/app/api/endpoints/community.py b/app/api/endpoints/community.py index 268d9ea..1a39d2c 100644 --- a/app/api/endpoints/community.py +++ b/app/api/endpoints/community.py @@ -1,24 +1,28 @@ -from fastapi import APIRouter, HTTPException, Depends +from fastapi import APIRouter, Depends from sqlalchemy.orm import Session from typing import List from app.models.community import CommunityDB, CommunityCreate, CommunityUpdate, CommunityInfo from app.models.database import get_db +from app.api.deps import get_admin_user +from app.models.user import UserDB +from app.core.response import success_response, error_response, ResponseModel router = APIRouter() -@router.post("/", response_model=CommunityInfo) +@router.post("/", response_model=ResponseModel) async def create_community( community: CommunityCreate, - db: Session = Depends(get_db) + db: Session = Depends(get_db), + admin: UserDB = Depends(get_admin_user) ): """创建社区""" db_community = CommunityDB(**community.model_dump()) db.add(db_community) db.commit() db.refresh(db_community) - return db_community + return success_response(data=CommunityInfo.model_validate(db_community)) -@router.get("/", response_model=List[CommunityInfo]) +@router.get("/", response_model=ResponseModel) async def get_communities( skip: int = 0, limit: int = 10, @@ -26,9 +30,9 @@ async def get_communities( ): """获取社区列表""" communities = db.query(CommunityDB).offset(skip).limit(limit).all() - return communities + return success_response(data=[CommunityInfo.model_validate(c) for c in communities]) -@router.get("/{community_id}", response_model=CommunityInfo) +@router.get("/{community_id}", response_model=ResponseModel) async def get_community( community_id: int, db: Session = Depends(get_db) @@ -36,19 +40,20 @@ async def get_community( """获取社区详情""" community = db.query(CommunityDB).filter(CommunityDB.id == community_id).first() if not community: - raise HTTPException(status_code=404, detail="社区不存在") - return community + return error_response(code=404, message="社区不存在") + return success_response(data=CommunityInfo.model_validate(community)) -@router.put("/{community_id}", response_model=CommunityInfo) +@router.put("/{community_id}", response_model=ResponseModel) async def update_community( community_id: int, community: CommunityUpdate, - db: Session = Depends(get_db) + db: Session = Depends(get_db), + admin: UserDB = Depends(get_admin_user) ): """更新社区信息""" db_community = db.query(CommunityDB).filter(CommunityDB.id == community_id).first() if not db_community: - raise HTTPException(status_code=404, detail="社区不存在") + return error_response(code=404, message="社区不存在") update_data = community.model_dump(exclude_unset=True) for key, value in update_data.items(): @@ -56,17 +61,18 @@ async def update_community( db.commit() db.refresh(db_community) - return db_community + return success_response(data=CommunityInfo.model_validate(db_community)) -@router.delete("/{community_id}") +@router.delete("/{community_id}", response_model=ResponseModel) async def delete_community( community_id: int, - db: Session = Depends(get_db) + db: Session = Depends(get_db), + admin: UserDB = Depends(get_admin_user) ): """删除社区""" result = db.query(CommunityDB).filter(CommunityDB.id == community_id).delete() if not result: - raise HTTPException(status_code=404, detail="社区不存在") + return error_response(code=404, message="社区不存在") db.commit() - return {"message": "社区已删除"} \ No newline at end of file + return success_response(message="社区已删除") \ No newline at end of file diff --git a/app/api/endpoints/user.py b/app/api/endpoints/user.py index c54b518..862f136 100644 --- a/app/api/endpoints/user.py +++ b/app/api/endpoints/user.py @@ -8,6 +8,9 @@ import redis from app.core.config import settings from unisdk.sms import UniSMS from unisdk.exception import UniException +from datetime import timedelta +from app.core.security import create_access_token +from app.core.response import success_response, error_response router = APIRouter() @@ -40,8 +43,8 @@ async def send_verify_code(request: VerifyCodeRequest): } }) - if res.code != "0": # 0 表示发送成功 - raise HTTPException(status_code=500, detail=f"短信发送失败: {res.message}") + if res.code != "0": + return error_response(message=f"短信发送失败: {res.message}") # 存储验证码到 Redis redis_client.setex( @@ -50,10 +53,10 @@ async def send_verify_code(request: VerifyCodeRequest): code ) - return {"message": "验证码已发送"} + return success_response(message="验证码已发送") except UniException as e: - raise HTTPException(status_code=500, detail=f"发送验证码失败: {str(e)}") + return error_response(message=f"发送验证码失败: {str(e)}") @router.post("/login") async def login(user_login: UserLogin, db: Session = Depends(get_db)): @@ -64,14 +67,13 @@ async def login(user_login: UserLogin, db: Session = Depends(get_db)): # 验证验证码 stored_code = redis_client.get(f"verify_code:{phone}") if not stored_code or stored_code != verify_code: - raise HTTPException(status_code=400, detail="验证码错误或已过期") + return error_response(message="验证码错误或已过期") redis_client.delete(f"verify_code:{phone}") # 查找或创建用户 user = db.query(UserDB).filter(UserDB.phone == phone).first() if not user: - # 创建新用户 user = UserDB( username=f"user_{phone[-4:]}", phone=phone @@ -80,13 +82,26 @@ async def login(user_login: UserLogin, db: Session = Depends(get_db)): db.commit() db.refresh(user) - return {"message": "登录成功", "user": UserInfo.model_validate(user)} + # 创建访问令牌 + access_token = create_access_token( + data={"sub": user.phone}, + expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) + ) + + return success_response( + message="登录成功", + data={ + "user": UserInfo.model_validate(user), + "access_token": access_token, + "token_type": "bearer" + } + ) @router.get("/info") async def get_user_info(phone: str, db: Session = Depends(get_db)): """获取用户信息""" user = db.query(UserDB).filter(UserDB.phone == phone).first() if not user: - raise HTTPException(status_code=404, detail="用户不存在") + return error_response(code=404, message="用户不存在") - return UserInfo.model_validate(user) \ No newline at end of file + return success_response(data=UserInfo.model_validate(user)) \ No newline at end of file diff --git a/app/core/config.py b/app/core/config.py index 1b67e16..a271a26 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -9,7 +9,7 @@ class Settings(BaseSettings): # JWT 配置 SECRET_KEY: str = "your-secret-key-here" - ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 + ACCESS_TOKEN_EXPIRE_MINUTES: int | None = None # None 表示永不过期 REDIS_HOST: str = "101.36.120.145" REDIS_PORT: int = 6379 diff --git a/app/core/response.py b/app/core/response.py new file mode 100644 index 0000000..21165fe --- /dev/null +++ b/app/core/response.py @@ -0,0 +1,21 @@ +from typing import Any, Optional +from fastapi.responses import JSONResponse +from pydantic import BaseModel + +class ResponseModel(BaseModel): + code: int = 200 + message: str = "success" + data: Optional[Any] = None + +def success_response(*, data: Any = None, message: str = "success") -> dict: + return ResponseModel( + code=200, + message=message, + data=data + ).model_dump() + +def error_response(*, code: int = 400, message: str) -> dict: + return ResponseModel( + code=code, + message=message + ).model_dump() \ No newline at end of file diff --git a/app/core/security.py b/app/core/security.py new file mode 100644 index 0000000..c689a55 --- /dev/null +++ b/app/core/security.py @@ -0,0 +1,23 @@ +from datetime import datetime, timedelta, UTC +from typing import Optional +from jose import JWTError, jwt +from app.core.config import settings + +def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str: + to_encode = data.copy() + if expires_delta is not None: + expire = datetime.now(UTC) + expires_delta + to_encode.update({"exp": expire}) + + encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm="HS256") + return encoded_jwt + +def verify_token(token: str) -> Optional[str]: + try: + payload = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"]) + phone: str = payload.get("sub") + if phone is None: + return None + return phone + except JWTError: + return None \ No newline at end of file diff --git a/app/main.py b/app/main.py index bc3aa89..f69945b 100644 --- a/app/main.py +++ b/app/main.py @@ -2,6 +2,10 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from app.api.endpoints import user, address, community from app.models.database import Base, engine +from fastapi.exceptions import RequestValidationError +from fastapi.responses import JSONResponse +from app.core.response import error_response +from fastapi import HTTPException # 创建数据库表 Base.metadata.create_all(bind=engine) @@ -32,4 +36,24 @@ async def root(): @app.get("/health") async def health_check(): - return {"status": "healthy"} \ No newline at end of file + return {"status": "healthy"} + +@app.exception_handler(RequestValidationError) +async def validation_exception_handler(request, exc): + return JSONResponse( + status_code=400, + content=error_response( + code=400, + message=str(exc) + ) + ) + +@app.exception_handler(HTTPException) +async def http_exception_handler(request, exc): + return JSONResponse( + status_code=exc.status_code, + content=error_response( + code=exc.status_code, + message=exc.detail + ) + ) \ No newline at end of file diff --git a/app/models/address.py b/app/models/address.py index 3f67322..9b2a890 100644 --- a/app/models/address.py +++ b/app/models/address.py @@ -1,4 +1,3 @@ -from typing import Optional from sqlalchemy import Column, Integer, String, ForeignKey, DateTime, Boolean from sqlalchemy.sql import func from pydantic import BaseModel, Field @@ -27,11 +26,11 @@ class AddressCreate(BaseModel): is_default: bool = False class AddressUpdate(BaseModel): - community_id: Optional[int] = None - address_detail: Optional[str] = Field(None, max_length=200) - name: Optional[str] = Field(None, max_length=50) - phone: Optional[str] = Field(None, pattern="^1[3-9]\d{9}$") - is_default: Optional[bool] = None + community_id: int | None = None + address_detail: str | None = Field(None, max_length=200) + name: str | None = Field(None, max_length=50) + phone: str | None = Field(None, pattern="^1[3-9]\d{9}$") + is_default: bool | None = None class AddressInfo(BaseModel): id: int diff --git a/app/models/user.py b/app/models/user.py index 4035252..834951f 100644 --- a/app/models/user.py +++ b/app/models/user.py @@ -1,4 +1,4 @@ -from sqlalchemy import Column, String, DateTime,Integer +from sqlalchemy import Column, String, DateTime,Integer, Boolean from sqlalchemy.sql import func from pydantic import BaseModel, Field from .database import Base @@ -10,6 +10,7 @@ class UserDB(Base): userid = Column(Integer, primary_key=True,autoincrement=True, index=True) username = Column(String(50)) phone = Column(String(11), unique=True, index=True) + is_admin = Column(Boolean, default=False) create_time = Column(DateTime(timezone=True), server_default=func.now()) update_time = Column(DateTime(timezone=True), onupdate=func.now()) @@ -22,6 +23,7 @@ class UserInfo(BaseModel): userid: int username: str phone: str + is_admin: bool class Config: from_attributes = True