This commit is contained in:
aaron 2025-01-02 18:11:01 +08:00
parent 36f8ef7da9
commit 772fc25a9d
10 changed files with 160 additions and 56 deletions

View File

@ -1,13 +1,30 @@
from fastapi import Depends, HTTPException
from fastapi import Depends, HTTPException, Header
from typing import Optional
from sqlalchemy.orm import Session
from app.models.database import get_db
from app.models.user import UserDB
from app.core.security import verify_token
async def get_current_user(
phone: str,
authorization: Optional[str] = Header(None),
db: Session = Depends(get_db)
) -> UserDB:
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
token = authorization.split(" ")[1]
phone = verify_token(token)
if not phone:
raise HTTPException(status_code=401, detail="Token已过期或无效")
user = db.query(UserDB).filter(UserDB.phone == phone).first()
if not user:
raise HTTPException(status_code=401, detail="用户未登录")
return user
return user
async def get_admin_user(
current_user: UserDB = Depends(get_current_user)
) -> UserDB:
if not current_user.is_admin:
raise HTTPException(status_code=403, detail="需要管理员权限")
return current_user

View File

@ -1,4 +1,4 @@
from fastapi import APIRouter, HTTPException, Depends
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from sqlalchemy import and_
from typing import List
@ -6,17 +6,17 @@ from app.models.address import AddressDB, AddressCreate, AddressUpdate, AddressI
from app.models.database import get_db
from app.api.deps import get_current_user
from app.models.user import UserDB
from app.core.response import success_response, error_response, ResponseModel
router = APIRouter()
@router.post("/", response_model=AddressInfo)
@router.post("/", response_model=ResponseModel)
async def create_address(
address: AddressCreate,
db: Session = Depends(get_db),
current_user: UserDB = Depends(get_current_user)
):
"""创建配送地址"""
# 如果设置为默认地址,先将其他地址的默认状态取消
if address.is_default:
db.query(AddressDB).filter(
and_(
@ -32,9 +32,9 @@ async def create_address(
db.add(db_address)
db.commit()
db.refresh(db_address)
return db_address
return success_response(data=AddressInfo.model_validate(db_address))
@router.get("/", response_model=List[AddressInfo])
@router.get("/", response_model=ResponseModel)
async def get_addresses(
db: Session = Depends(get_db),
current_user: UserDB = Depends(get_current_user)
@ -43,9 +43,9 @@ async def get_addresses(
addresses = db.query(AddressDB).filter(
AddressDB.user_id == current_user.userid
).all()
return addresses
return success_response(data=[AddressInfo.model_validate(a) for a in addresses])
@router.put("/{address_id}", response_model=AddressInfo)
@router.put("/{address_id}", response_model=ResponseModel)
async def update_address(
address_id: int,
address: AddressUpdate,
@ -61,9 +61,8 @@ async def update_address(
).first()
if not db_address:
raise HTTPException(status_code=404, detail="地址不存在")
return error_response(code=404, message="地址不存在")
# 如果设置为默认地址,先将其他地址的默认状态取消
update_data = address.model_dump(exclude_unset=True)
if update_data.get("is_default"):
db.query(AddressDB).filter(
@ -78,9 +77,9 @@ async def update_address(
db.commit()
db.refresh(db_address)
return db_address
return success_response(data=AddressInfo.model_validate(db_address))
@router.delete("/{address_id}")
@router.delete("/{address_id}", response_model=ResponseModel)
async def delete_address(
address_id: int,
db: Session = Depends(get_db),
@ -95,19 +94,18 @@ async def delete_address(
).delete()
if not result:
raise HTTPException(status_code=404, detail="地址不存在")
return error_response(code=404, message="地址不存在")
db.commit()
return {"message": "地址已删除"}
return success_response(message="地址已删除")
@router.post("/{address_id}/set-default", response_model=AddressInfo)
@router.post("/{address_id}/set-default", response_model=ResponseModel)
async def set_default_address(
address_id: int,
db: Session = Depends(get_db),
current_user: UserDB = Depends(get_current_user)
):
"""设置默认地址"""
# 取消其他默认地址
db.query(AddressDB).filter(
and_(
AddressDB.user_id == current_user.userid,
@ -115,7 +113,6 @@ async def set_default_address(
)
).update({"is_default": False})
# 设置新的默认地址
db_address = db.query(AddressDB).filter(
and_(
AddressDB.id == address_id,
@ -124,9 +121,9 @@ async def set_default_address(
).first()
if not db_address:
raise HTTPException(status_code=404, detail="地址不存在")
return error_response(code=404, message="地址不存在")
db_address.is_default = True
db.commit()
db.refresh(db_address)
return db_address
return success_response(data=AddressInfo.model_validate(db_address))

View File

@ -1,24 +1,28 @@
from fastapi import APIRouter, HTTPException, Depends
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from typing import List
from app.models.community import CommunityDB, CommunityCreate, CommunityUpdate, CommunityInfo
from app.models.database import get_db
from app.api.deps import get_admin_user
from app.models.user import UserDB
from app.core.response import success_response, error_response, ResponseModel
router = APIRouter()
@router.post("/", response_model=CommunityInfo)
@router.post("/", response_model=ResponseModel)
async def create_community(
community: CommunityCreate,
db: Session = Depends(get_db)
db: Session = Depends(get_db),
admin: UserDB = Depends(get_admin_user)
):
"""创建社区"""
db_community = CommunityDB(**community.model_dump())
db.add(db_community)
db.commit()
db.refresh(db_community)
return db_community
return success_response(data=CommunityInfo.model_validate(db_community))
@router.get("/", response_model=List[CommunityInfo])
@router.get("/", response_model=ResponseModel)
async def get_communities(
skip: int = 0,
limit: int = 10,
@ -26,9 +30,9 @@ async def get_communities(
):
"""获取社区列表"""
communities = db.query(CommunityDB).offset(skip).limit(limit).all()
return communities
return success_response(data=[CommunityInfo.model_validate(c) for c in communities])
@router.get("/{community_id}", response_model=CommunityInfo)
@router.get("/{community_id}", response_model=ResponseModel)
async def get_community(
community_id: int,
db: Session = Depends(get_db)
@ -36,19 +40,20 @@ async def get_community(
"""获取社区详情"""
community = db.query(CommunityDB).filter(CommunityDB.id == community_id).first()
if not community:
raise HTTPException(status_code=404, detail="社区不存在")
return community
return error_response(code=404, message="社区不存在")
return success_response(data=CommunityInfo.model_validate(community))
@router.put("/{community_id}", response_model=CommunityInfo)
@router.put("/{community_id}", response_model=ResponseModel)
async def update_community(
community_id: int,
community: CommunityUpdate,
db: Session = Depends(get_db)
db: Session = Depends(get_db),
admin: UserDB = Depends(get_admin_user)
):
"""更新社区信息"""
db_community = db.query(CommunityDB).filter(CommunityDB.id == community_id).first()
if not db_community:
raise HTTPException(status_code=404, detail="社区不存在")
return error_response(code=404, message="社区不存在")
update_data = community.model_dump(exclude_unset=True)
for key, value in update_data.items():
@ -56,17 +61,18 @@ async def update_community(
db.commit()
db.refresh(db_community)
return db_community
return success_response(data=CommunityInfo.model_validate(db_community))
@router.delete("/{community_id}")
@router.delete("/{community_id}", response_model=ResponseModel)
async def delete_community(
community_id: int,
db: Session = Depends(get_db)
db: Session = Depends(get_db),
admin: UserDB = Depends(get_admin_user)
):
"""删除社区"""
result = db.query(CommunityDB).filter(CommunityDB.id == community_id).delete()
if not result:
raise HTTPException(status_code=404, detail="社区不存在")
return error_response(code=404, message="社区不存在")
db.commit()
return {"message": "社区已删除"}
return success_response(message="社区已删除")

View File

@ -8,6 +8,9 @@ import redis
from app.core.config import settings
from unisdk.sms import UniSMS
from unisdk.exception import UniException
from datetime import timedelta
from app.core.security import create_access_token
from app.core.response import success_response, error_response
router = APIRouter()
@ -40,8 +43,8 @@ async def send_verify_code(request: VerifyCodeRequest):
}
})
if res.code != "0": # 0 表示发送成功
raise HTTPException(status_code=500, detail=f"短信发送失败: {res.message}")
if res.code != "0":
return error_response(message=f"短信发送失败: {res.message}")
# 存储验证码到 Redis
redis_client.setex(
@ -50,10 +53,10 @@ async def send_verify_code(request: VerifyCodeRequest):
code
)
return {"message": "验证码已发送"}
return success_response(message="验证码已发送")
except UniException as e:
raise HTTPException(status_code=500, detail=f"发送验证码失败: {str(e)}")
return error_response(message=f"发送验证码失败: {str(e)}")
@router.post("/login")
async def login(user_login: UserLogin, db: Session = Depends(get_db)):
@ -64,14 +67,13 @@ async def login(user_login: UserLogin, db: Session = Depends(get_db)):
# 验证验证码
stored_code = redis_client.get(f"verify_code:{phone}")
if not stored_code or stored_code != verify_code:
raise HTTPException(status_code=400, detail="验证码错误或已过期")
return error_response(message="验证码错误或已过期")
redis_client.delete(f"verify_code:{phone}")
# 查找或创建用户
user = db.query(UserDB).filter(UserDB.phone == phone).first()
if not user:
# 创建新用户
user = UserDB(
username=f"user_{phone[-4:]}",
phone=phone
@ -80,13 +82,26 @@ async def login(user_login: UserLogin, db: Session = Depends(get_db)):
db.commit()
db.refresh(user)
return {"message": "登录成功", "user": UserInfo.model_validate(user)}
# 创建访问令牌
access_token = create_access_token(
data={"sub": user.phone},
expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
)
return success_response(
message="登录成功",
data={
"user": UserInfo.model_validate(user),
"access_token": access_token,
"token_type": "bearer"
}
)
@router.get("/info")
async def get_user_info(phone: str, db: Session = Depends(get_db)):
"""获取用户信息"""
user = db.query(UserDB).filter(UserDB.phone == phone).first()
if not user:
raise HTTPException(status_code=404, detail="用户不存在")
return error_response(code=404, message="用户不存在")
return UserInfo.model_validate(user)
return success_response(data=UserInfo.model_validate(user))

View File

@ -9,7 +9,7 @@ class Settings(BaseSettings):
# JWT 配置
SECRET_KEY: str = "your-secret-key-here"
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
ACCESS_TOKEN_EXPIRE_MINUTES: int | None = None # None 表示永不过期
REDIS_HOST: str = "101.36.120.145"
REDIS_PORT: int = 6379

21
app/core/response.py Normal file
View File

@ -0,0 +1,21 @@
from typing import Any, Optional
from fastapi.responses import JSONResponse
from pydantic import BaseModel
class ResponseModel(BaseModel):
code: int = 200
message: str = "success"
data: Optional[Any] = None
def success_response(*, data: Any = None, message: str = "success") -> dict:
return ResponseModel(
code=200,
message=message,
data=data
).model_dump()
def error_response(*, code: int = 400, message: str) -> dict:
return ResponseModel(
code=code,
message=message
).model_dump()

23
app/core/security.py Normal file
View File

@ -0,0 +1,23 @@
from datetime import datetime, timedelta, UTC
from typing import Optional
from jose import JWTError, jwt
from app.core.config import settings
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
to_encode = data.copy()
if expires_delta is not None:
expire = datetime.now(UTC) + expires_delta
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm="HS256")
return encoded_jwt
def verify_token(token: str) -> Optional[str]:
try:
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"])
phone: str = payload.get("sub")
if phone is None:
return None
return phone
except JWTError:
return None

View File

@ -2,6 +2,10 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.api.endpoints import user, address, community
from app.models.database import Base, engine
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from app.core.response import error_response
from fastapi import HTTPException
# 创建数据库表
Base.metadata.create_all(bind=engine)
@ -32,4 +36,24 @@ async def root():
@app.get("/health")
async def health_check():
return {"status": "healthy"}
return {"status": "healthy"}
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
return JSONResponse(
status_code=400,
content=error_response(
code=400,
message=str(exc)
)
)
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
return JSONResponse(
status_code=exc.status_code,
content=error_response(
code=exc.status_code,
message=exc.detail
)
)

View File

@ -1,4 +1,3 @@
from typing import Optional
from sqlalchemy import Column, Integer, String, ForeignKey, DateTime, Boolean
from sqlalchemy.sql import func
from pydantic import BaseModel, Field
@ -27,11 +26,11 @@ class AddressCreate(BaseModel):
is_default: bool = False
class AddressUpdate(BaseModel):
community_id: Optional[int] = None
address_detail: Optional[str] = Field(None, max_length=200)
name: Optional[str] = Field(None, max_length=50)
phone: Optional[str] = Field(None, pattern="^1[3-9]\d{9}$")
is_default: Optional[bool] = None
community_id: int | None = None
address_detail: str | None = Field(None, max_length=200)
name: str | None = Field(None, max_length=50)
phone: str | None = Field(None, pattern="^1[3-9]\d{9}$")
is_default: bool | None = None
class AddressInfo(BaseModel):
id: int

View File

@ -1,4 +1,4 @@
from sqlalchemy import Column, String, DateTime,Integer
from sqlalchemy import Column, String, DateTime,Integer, Boolean
from sqlalchemy.sql import func
from pydantic import BaseModel, Field
from .database import Base
@ -10,6 +10,7 @@ class UserDB(Base):
userid = Column(Integer, primary_key=True,autoincrement=True, index=True)
username = Column(String(50))
phone = Column(String(11), unique=True, index=True)
is_admin = Column(Boolean, default=False)
create_time = Column(DateTime(timezone=True), server_default=func.now())
update_time = Column(DateTime(timezone=True), onupdate=func.now())
@ -22,6 +23,7 @@ class UserInfo(BaseModel):
userid: int
username: str
phone: str
is_admin: bool
class Config:
from_attributes = True