update.
This commit is contained in:
parent
36f8ef7da9
commit
772fc25a9d
@ -1,13 +1,30 @@
|
||||
from fastapi import Depends, HTTPException
|
||||
from fastapi import Depends, HTTPException, Header
|
||||
from typing import Optional
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models.database import get_db
|
||||
from app.models.user import UserDB
|
||||
from app.core.security import verify_token
|
||||
|
||||
async def get_current_user(
|
||||
phone: str,
|
||||
authorization: Optional[str] = Header(None),
|
||||
db: Session = Depends(get_db)
|
||||
) -> UserDB:
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
token = authorization.split(" ")[1]
|
||||
phone = verify_token(token)
|
||||
if not phone:
|
||||
raise HTTPException(status_code=401, detail="Token已过期或无效")
|
||||
|
||||
user = db.query(UserDB).filter(UserDB.phone == phone).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="用户未登录")
|
||||
return user
|
||||
return user
|
||||
|
||||
async def get_admin_user(
|
||||
current_user: UserDB = Depends(get_current_user)
|
||||
) -> UserDB:
|
||||
if not current_user.is_admin:
|
||||
raise HTTPException(status_code=403, detail="需要管理员权限")
|
||||
return current_user
|
||||
@ -1,4 +1,4 @@
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_
|
||||
from typing import List
|
||||
@ -6,17 +6,17 @@ from app.models.address import AddressDB, AddressCreate, AddressUpdate, AddressI
|
||||
from app.models.database import get_db
|
||||
from app.api.deps import get_current_user
|
||||
from app.models.user import UserDB
|
||||
from app.core.response import success_response, error_response, ResponseModel
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/", response_model=AddressInfo)
|
||||
@router.post("/", response_model=ResponseModel)
|
||||
async def create_address(
|
||||
address: AddressCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: UserDB = Depends(get_current_user)
|
||||
):
|
||||
"""创建配送地址"""
|
||||
# 如果设置为默认地址,先将其他地址的默认状态取消
|
||||
if address.is_default:
|
||||
db.query(AddressDB).filter(
|
||||
and_(
|
||||
@ -32,9 +32,9 @@ async def create_address(
|
||||
db.add(db_address)
|
||||
db.commit()
|
||||
db.refresh(db_address)
|
||||
return db_address
|
||||
return success_response(data=AddressInfo.model_validate(db_address))
|
||||
|
||||
@router.get("/", response_model=List[AddressInfo])
|
||||
@router.get("/", response_model=ResponseModel)
|
||||
async def get_addresses(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: UserDB = Depends(get_current_user)
|
||||
@ -43,9 +43,9 @@ async def get_addresses(
|
||||
addresses = db.query(AddressDB).filter(
|
||||
AddressDB.user_id == current_user.userid
|
||||
).all()
|
||||
return addresses
|
||||
return success_response(data=[AddressInfo.model_validate(a) for a in addresses])
|
||||
|
||||
@router.put("/{address_id}", response_model=AddressInfo)
|
||||
@router.put("/{address_id}", response_model=ResponseModel)
|
||||
async def update_address(
|
||||
address_id: int,
|
||||
address: AddressUpdate,
|
||||
@ -61,9 +61,8 @@ async def update_address(
|
||||
).first()
|
||||
|
||||
if not db_address:
|
||||
raise HTTPException(status_code=404, detail="地址不存在")
|
||||
return error_response(code=404, message="地址不存在")
|
||||
|
||||
# 如果设置为默认地址,先将其他地址的默认状态取消
|
||||
update_data = address.model_dump(exclude_unset=True)
|
||||
if update_data.get("is_default"):
|
||||
db.query(AddressDB).filter(
|
||||
@ -78,9 +77,9 @@ async def update_address(
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_address)
|
||||
return db_address
|
||||
return success_response(data=AddressInfo.model_validate(db_address))
|
||||
|
||||
@router.delete("/{address_id}")
|
||||
@router.delete("/{address_id}", response_model=ResponseModel)
|
||||
async def delete_address(
|
||||
address_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
@ -95,19 +94,18 @@ async def delete_address(
|
||||
).delete()
|
||||
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="地址不存在")
|
||||
return error_response(code=404, message="地址不存在")
|
||||
|
||||
db.commit()
|
||||
return {"message": "地址已删除"}
|
||||
return success_response(message="地址已删除")
|
||||
|
||||
@router.post("/{address_id}/set-default", response_model=AddressInfo)
|
||||
@router.post("/{address_id}/set-default", response_model=ResponseModel)
|
||||
async def set_default_address(
|
||||
address_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: UserDB = Depends(get_current_user)
|
||||
):
|
||||
"""设置默认地址"""
|
||||
# 取消其他默认地址
|
||||
db.query(AddressDB).filter(
|
||||
and_(
|
||||
AddressDB.user_id == current_user.userid,
|
||||
@ -115,7 +113,6 @@ async def set_default_address(
|
||||
)
|
||||
).update({"is_default": False})
|
||||
|
||||
# 设置新的默认地址
|
||||
db_address = db.query(AddressDB).filter(
|
||||
and_(
|
||||
AddressDB.id == address_id,
|
||||
@ -124,9 +121,9 @@ async def set_default_address(
|
||||
).first()
|
||||
|
||||
if not db_address:
|
||||
raise HTTPException(status_code=404, detail="地址不存在")
|
||||
return error_response(code=404, message="地址不存在")
|
||||
|
||||
db_address.is_default = True
|
||||
db.commit()
|
||||
db.refresh(db_address)
|
||||
return db_address
|
||||
return success_response(data=AddressInfo.model_validate(db_address))
|
||||
@ -1,24 +1,28 @@
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
from app.models.community import CommunityDB, CommunityCreate, CommunityUpdate, CommunityInfo
|
||||
from app.models.database import get_db
|
||||
from app.api.deps import get_admin_user
|
||||
from app.models.user import UserDB
|
||||
from app.core.response import success_response, error_response, ResponseModel
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/", response_model=CommunityInfo)
|
||||
@router.post("/", response_model=ResponseModel)
|
||||
async def create_community(
|
||||
community: CommunityCreate,
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
admin: UserDB = Depends(get_admin_user)
|
||||
):
|
||||
"""创建社区"""
|
||||
db_community = CommunityDB(**community.model_dump())
|
||||
db.add(db_community)
|
||||
db.commit()
|
||||
db.refresh(db_community)
|
||||
return db_community
|
||||
return success_response(data=CommunityInfo.model_validate(db_community))
|
||||
|
||||
@router.get("/", response_model=List[CommunityInfo])
|
||||
@router.get("/", response_model=ResponseModel)
|
||||
async def get_communities(
|
||||
skip: int = 0,
|
||||
limit: int = 10,
|
||||
@ -26,9 +30,9 @@ async def get_communities(
|
||||
):
|
||||
"""获取社区列表"""
|
||||
communities = db.query(CommunityDB).offset(skip).limit(limit).all()
|
||||
return communities
|
||||
return success_response(data=[CommunityInfo.model_validate(c) for c in communities])
|
||||
|
||||
@router.get("/{community_id}", response_model=CommunityInfo)
|
||||
@router.get("/{community_id}", response_model=ResponseModel)
|
||||
async def get_community(
|
||||
community_id: int,
|
||||
db: Session = Depends(get_db)
|
||||
@ -36,19 +40,20 @@ async def get_community(
|
||||
"""获取社区详情"""
|
||||
community = db.query(CommunityDB).filter(CommunityDB.id == community_id).first()
|
||||
if not community:
|
||||
raise HTTPException(status_code=404, detail="社区不存在")
|
||||
return community
|
||||
return error_response(code=404, message="社区不存在")
|
||||
return success_response(data=CommunityInfo.model_validate(community))
|
||||
|
||||
@router.put("/{community_id}", response_model=CommunityInfo)
|
||||
@router.put("/{community_id}", response_model=ResponseModel)
|
||||
async def update_community(
|
||||
community_id: int,
|
||||
community: CommunityUpdate,
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
admin: UserDB = Depends(get_admin_user)
|
||||
):
|
||||
"""更新社区信息"""
|
||||
db_community = db.query(CommunityDB).filter(CommunityDB.id == community_id).first()
|
||||
if not db_community:
|
||||
raise HTTPException(status_code=404, detail="社区不存在")
|
||||
return error_response(code=404, message="社区不存在")
|
||||
|
||||
update_data = community.model_dump(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
@ -56,17 +61,18 @@ async def update_community(
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_community)
|
||||
return db_community
|
||||
return success_response(data=CommunityInfo.model_validate(db_community))
|
||||
|
||||
@router.delete("/{community_id}")
|
||||
@router.delete("/{community_id}", response_model=ResponseModel)
|
||||
async def delete_community(
|
||||
community_id: int,
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
admin: UserDB = Depends(get_admin_user)
|
||||
):
|
||||
"""删除社区"""
|
||||
result = db.query(CommunityDB).filter(CommunityDB.id == community_id).delete()
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="社区不存在")
|
||||
return error_response(code=404, message="社区不存在")
|
||||
|
||||
db.commit()
|
||||
return {"message": "社区已删除"}
|
||||
return success_response(message="社区已删除")
|
||||
@ -8,6 +8,9 @@ import redis
|
||||
from app.core.config import settings
|
||||
from unisdk.sms import UniSMS
|
||||
from unisdk.exception import UniException
|
||||
from datetime import timedelta
|
||||
from app.core.security import create_access_token
|
||||
from app.core.response import success_response, error_response
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@ -40,8 +43,8 @@ async def send_verify_code(request: VerifyCodeRequest):
|
||||
}
|
||||
})
|
||||
|
||||
if res.code != "0": # 0 表示发送成功
|
||||
raise HTTPException(status_code=500, detail=f"短信发送失败: {res.message}")
|
||||
if res.code != "0":
|
||||
return error_response(message=f"短信发送失败: {res.message}")
|
||||
|
||||
# 存储验证码到 Redis
|
||||
redis_client.setex(
|
||||
@ -50,10 +53,10 @@ async def send_verify_code(request: VerifyCodeRequest):
|
||||
code
|
||||
)
|
||||
|
||||
return {"message": "验证码已发送"}
|
||||
return success_response(message="验证码已发送")
|
||||
|
||||
except UniException as e:
|
||||
raise HTTPException(status_code=500, detail=f"发送验证码失败: {str(e)}")
|
||||
return error_response(message=f"发送验证码失败: {str(e)}")
|
||||
|
||||
@router.post("/login")
|
||||
async def login(user_login: UserLogin, db: Session = Depends(get_db)):
|
||||
@ -64,14 +67,13 @@ async def login(user_login: UserLogin, db: Session = Depends(get_db)):
|
||||
# 验证验证码
|
||||
stored_code = redis_client.get(f"verify_code:{phone}")
|
||||
if not stored_code or stored_code != verify_code:
|
||||
raise HTTPException(status_code=400, detail="验证码错误或已过期")
|
||||
return error_response(message="验证码错误或已过期")
|
||||
|
||||
redis_client.delete(f"verify_code:{phone}")
|
||||
|
||||
# 查找或创建用户
|
||||
user = db.query(UserDB).filter(UserDB.phone == phone).first()
|
||||
if not user:
|
||||
# 创建新用户
|
||||
user = UserDB(
|
||||
username=f"user_{phone[-4:]}",
|
||||
phone=phone
|
||||
@ -80,13 +82,26 @@ async def login(user_login: UserLogin, db: Session = Depends(get_db)):
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
|
||||
return {"message": "登录成功", "user": UserInfo.model_validate(user)}
|
||||
# 创建访问令牌
|
||||
access_token = create_access_token(
|
||||
data={"sub": user.phone},
|
||||
expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
)
|
||||
|
||||
return success_response(
|
||||
message="登录成功",
|
||||
data={
|
||||
"user": UserInfo.model_validate(user),
|
||||
"access_token": access_token,
|
||||
"token_type": "bearer"
|
||||
}
|
||||
)
|
||||
|
||||
@router.get("/info")
|
||||
async def get_user_info(phone: str, db: Session = Depends(get_db)):
|
||||
"""获取用户信息"""
|
||||
user = db.query(UserDB).filter(UserDB.phone == phone).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
return error_response(code=404, message="用户不存在")
|
||||
|
||||
return UserInfo.model_validate(user)
|
||||
return success_response(data=UserInfo.model_validate(user))
|
||||
@ -9,7 +9,7 @@ class Settings(BaseSettings):
|
||||
|
||||
# JWT 配置
|
||||
SECRET_KEY: str = "your-secret-key-here"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int | None = None # None 表示永不过期
|
||||
|
||||
REDIS_HOST: str = "101.36.120.145"
|
||||
REDIS_PORT: int = 6379
|
||||
|
||||
21
app/core/response.py
Normal file
21
app/core/response.py
Normal file
@ -0,0 +1,21 @@
|
||||
from typing import Any, Optional
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
class ResponseModel(BaseModel):
|
||||
code: int = 200
|
||||
message: str = "success"
|
||||
data: Optional[Any] = None
|
||||
|
||||
def success_response(*, data: Any = None, message: str = "success") -> dict:
|
||||
return ResponseModel(
|
||||
code=200,
|
||||
message=message,
|
||||
data=data
|
||||
).model_dump()
|
||||
|
||||
def error_response(*, code: int = 400, message: str) -> dict:
|
||||
return ResponseModel(
|
||||
code=code,
|
||||
message=message
|
||||
).model_dump()
|
||||
23
app/core/security.py
Normal file
23
app/core/security.py
Normal file
@ -0,0 +1,23 @@
|
||||
from datetime import datetime, timedelta, UTC
|
||||
from typing import Optional
|
||||
from jose import JWTError, jwt
|
||||
from app.core.config import settings
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
to_encode = data.copy()
|
||||
if expires_delta is not None:
|
||||
expire = datetime.now(UTC) + expires_delta
|
||||
to_encode.update({"exp": expire})
|
||||
|
||||
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm="HS256")
|
||||
return encoded_jwt
|
||||
|
||||
def verify_token(token: str) -> Optional[str]:
|
||||
try:
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"])
|
||||
phone: str = payload.get("sub")
|
||||
if phone is None:
|
||||
return None
|
||||
return phone
|
||||
except JWTError:
|
||||
return None
|
||||
26
app/main.py
26
app/main.py
@ -2,6 +2,10 @@ from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from app.api.endpoints import user, address, community
|
||||
from app.models.database import Base, engine
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
from app.core.response import error_response
|
||||
from fastapi import HTTPException
|
||||
|
||||
# 创建数据库表
|
||||
Base.metadata.create_all(bind=engine)
|
||||
@ -32,4 +36,24 @@ async def root():
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
return {"status": "healthy"}
|
||||
return {"status": "healthy"}
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(request, exc):
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=error_response(
|
||||
code=400,
|
||||
message=str(exc)
|
||||
)
|
||||
)
|
||||
|
||||
@app.exception_handler(HTTPException)
|
||||
async def http_exception_handler(request, exc):
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content=error_response(
|
||||
code=exc.status_code,
|
||||
message=exc.detail
|
||||
)
|
||||
)
|
||||
@ -1,4 +1,3 @@
|
||||
from typing import Optional
|
||||
from sqlalchemy import Column, Integer, String, ForeignKey, DateTime, Boolean
|
||||
from sqlalchemy.sql import func
|
||||
from pydantic import BaseModel, Field
|
||||
@ -27,11 +26,11 @@ class AddressCreate(BaseModel):
|
||||
is_default: bool = False
|
||||
|
||||
class AddressUpdate(BaseModel):
|
||||
community_id: Optional[int] = None
|
||||
address_detail: Optional[str] = Field(None, max_length=200)
|
||||
name: Optional[str] = Field(None, max_length=50)
|
||||
phone: Optional[str] = Field(None, pattern="^1[3-9]\d{9}$")
|
||||
is_default: Optional[bool] = None
|
||||
community_id: int | None = None
|
||||
address_detail: str | None = Field(None, max_length=200)
|
||||
name: str | None = Field(None, max_length=50)
|
||||
phone: str | None = Field(None, pattern="^1[3-9]\d{9}$")
|
||||
is_default: bool | None = None
|
||||
|
||||
class AddressInfo(BaseModel):
|
||||
id: int
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from sqlalchemy import Column, String, DateTime,Integer
|
||||
from sqlalchemy import Column, String, DateTime,Integer, Boolean
|
||||
from sqlalchemy.sql import func
|
||||
from pydantic import BaseModel, Field
|
||||
from .database import Base
|
||||
@ -10,6 +10,7 @@ class UserDB(Base):
|
||||
userid = Column(Integer, primary_key=True,autoincrement=True, index=True)
|
||||
username = Column(String(50))
|
||||
phone = Column(String(11), unique=True, index=True)
|
||||
is_admin = Column(Boolean, default=False)
|
||||
create_time = Column(DateTime(timezone=True), server_default=func.now())
|
||||
update_time = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
||||
@ -22,6 +23,7 @@ class UserInfo(BaseModel):
|
||||
userid: int
|
||||
username: str
|
||||
phone: str
|
||||
is_admin: bool
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
Loading…
Reference in New Issue
Block a user