diff --git a/app/api/endpoints/user.py b/app/api/endpoints/user.py index 4c5e67c..4a1bf0d 100644 --- a/app/api/endpoints/user.py +++ b/app/api/endpoints/user.py @@ -1,6 +1,6 @@ from fastapi import APIRouter, HTTPException, Depends, Response from sqlalchemy.orm import Session -from app.models.user import UserLogin, UserInfo, VerifyCodeRequest, UserDB +from app.models.user import UserLogin, UserInfo, VerifyCodeRequest, UserDB, UserUpdate from app.api.deps import get_current_user from app.models.database import get_db import random @@ -110,14 +110,12 @@ async def login( } ) -@router.get("/info") -async def get_user_info(phone: str, db: Session = Depends(get_db)): +@router.get("/info", response_model=ResponseModel) +async def get_user_info( + current_user: UserDB = Depends(get_current_user) +): """获取用户信息""" - user = db.query(UserDB).filter(UserDB.phone == phone).first() - if not user: - return error_response(code=404, message="用户不存在") - - return success_response(data=UserInfo.model_validate(user)) + return success_response(data=UserInfo.model_validate(current_user)) @router.post("/mock-login", response_model=ResponseModel) async def mock_login( @@ -165,4 +163,32 @@ async def logout( ): """退出登录""" clear_jwt_cookie(response) - return success_response(message="退出登录成功") \ No newline at end of file + return success_response(message="退出登录成功") + +@router.put("/update", response_model=ResponseModel) +async def update_user_info( + update_data: UserUpdate, + db: Session = Depends(get_db), + current_user: UserDB = Depends(get_current_user) +): + """更新用户信息""" + # 获取非空的更新字段 + update_fields = update_data.model_dump(exclude_unset=True) + + if not update_fields: + return error_response(code=400, message="没有提供要更新的字段") + + # 更新字段 + for field, value in update_fields.items(): + setattr(current_user, field, value) + + try: + db.commit() + db.refresh(current_user) + return success_response( + message="用户信息更新成功", + data=UserInfo.model_validate(current_user) + ) + except Exception as e: + db.rollback() + return error_response(code=500, message=f"更新失败: {str(e)}") \ No newline at end of file diff --git a/app/models/community.py b/app/models/community.py index d375a9f..e97dc5e 100644 --- a/app/models/community.py +++ b/app/models/community.py @@ -1,5 +1,5 @@ from typing import Optional -from sqlalchemy import Column, Integer, String, Float, DateTime +from sqlalchemy import Column, Integer, String, DECIMAL, DateTime from sqlalchemy.sql import func from pydantic import BaseModel, Field from .database import Base @@ -11,8 +11,8 @@ class CommunityDB(Base): id = Column(Integer, primary_key=True, autoincrement=True) name = Column(String(100), nullable=False) address = Column(String(200), nullable=False) - longitude = Column(Float, nullable=False) # 经度 - latitude = Column(Float, nullable=False) # 纬度 + longitude = Column(DECIMAL(9,6), nullable=False) # 经度,精确到小数点后6位 + latitude = Column(DECIMAL(9,6), nullable=False) # 纬度,精确到小数点后6位 create_time = Column(DateTime(timezone=True), server_default=func.now()) update_time = Column(DateTime(timezone=True), onupdate=func.now()) @@ -21,13 +21,13 @@ class CommunityCreate(BaseModel): name: str = Field(..., max_length=100) address: str = Field(..., max_length=200) longitude: float = Field(..., ge=-180, le=180) - latitude: float = Field(..., ge=-90, le=90) + latitude: float = Field(..., ge=-180, le=180) class CommunityUpdate(BaseModel): name: Optional[str] = Field(None, max_length=100) address: Optional[str] = Field(None, max_length=200) longitude: Optional[float] = Field(None, ge=-180, le=180) - latitude: Optional[float] = Field(None, ge=-90, le=90) + latitude: Optional[float] = Field(None, ge=-180, le=180) class CommunityInfo(BaseModel): id: int diff --git a/app/models/user.py b/app/models/user.py index 834951f..08d7899 100644 --- a/app/models/user.py +++ b/app/models/user.py @@ -2,6 +2,7 @@ from sqlalchemy import Column, String, DateTime,Integer, Boolean from sqlalchemy.sql import func from pydantic import BaseModel, Field from .database import Base +from typing import Optional # 数据库模型 class UserDB(Base): @@ -29,4 +30,11 @@ class UserInfo(BaseModel): from_attributes = True class VerifyCodeRequest(BaseModel): - phone: str = Field(..., pattern="^1[3-9]\d{9}$") \ No newline at end of file + phone: str = Field(..., pattern="^1[3-9]\d{9}$") + +class UserUpdate(BaseModel): + username: Optional[str] = Field(None, min_length=2, max_length=50) + avatar: Optional[str] = Field(None, max_length=200) + + class Config: + extra = "forbid" # 禁止额外字段 \ No newline at end of file