"""Authentication API.

Endpoints for login, changing one's own password, user management
(admin only) and business-data reset (admin only).
"""

import secrets
import logging

from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from sqlalchemy import select, update, text, func

from app.core.auth import hash_password, verify_password, create_access_token
from app.core.deps import get_current_user, get_current_admin
from app.db.database import get_db
from app.db.tables import users_table

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/auth", tags=["auth"])


# ---------- Request/Response Models ----------


class LoginRequest(BaseModel):
    # Credentials submitted to POST /login.
    username: str
    password: str


class ChangePasswordRequest(BaseModel):
    # The current password must be re-entered to authorise the change.
    old_password: str
    new_password: str


class CreateUserRequest(BaseModel):
    # No password field: the server generates one (see POST /users).
    username: str
    role: str = "user"


class DataResetRequest(BaseModel):
    # One of "all", "recommendations", "date_range", "low_score".
    mode: str
    # Used by "date_range" mode, e.g. "2026-04-10".
    before_date: str | None = None
    # Used by "low_score" mode; the handler falls back to 60 when omitted.
    min_score: int | None = None


# ---------- Public Endpoints ----------


@router.post("/login")
async def login(req: LoginRequest):
    """Authenticate by username/password and return a JWT plus basic profile.

    Raises:
        HTTPException(401): unknown user, disabled user or wrong password.
    """
    async with get_db() as db:
        lookup = select(users_table).where(users_table.c.username == req.username)
        user = (await db.execute(lookup)).mappings().first()

        # Every failure case gets the same message so the response does not
        # reveal whether the username exists.
        credentials_ok = (
            user is not None
            and user["is_active"]
            and verify_password(req.password, user["password_hash"])
        )
        if not credentials_ok:
            raise HTTPException(status_code=401, detail="用户名或密码错误")

    token = create_access_token({"sub": str(user["id"]), "role": user["role"]})
    profile = {key: user[key] for key in ("id", "username", "role")}
    return {"token": token, "user": profile}


# ---------- Authenticated Endpoints ----------


@router.get("/me")
async def get_me(current_user: dict = Depends(get_current_user)):
    """Return the profile of the currently authenticated user."""
    return {
        field: current_user[field]
        for field in ("id", "username", "role", "is_active")
    }
# Score below which a recommendation counts as "low score". Shared by the
# stats endpoint and by the default of the "low_score" reset mode.
_LOW_SCORE_THRESHOLD = 60

# Modes accepted by POST /data-reset.
_RESET_MODES = ("all", "recommendations", "date_range", "low_score")

# Statements run by the "all" reset mode, keyed by the name reported in the
# response. Tracking rows are deleted first because they reference
# recommendations (see the "low_score" subquery below).
_ALL_MODE_DELETES = (
    ("tracking", "DELETE FROM recommendation_tracking"),
    ("recommendations", "DELETE FROM recommendations"),
    ("sector_heat", "DELETE FROM sector_heat"),
    ("market_temperature", "DELETE FROM market_temperature"),
    ("daily_reviews", "DELETE FROM daily_reviews"),
    ("stock_diagnoses", "DELETE FROM stock_diagnoses"),
)


def _generate_password() -> str:
    """Return a random 12-character URL-safe password (9 random bytes)."""
    return secrets.token_urlsafe(9)


async def _execute_delete(db, sql: str, params: dict | None = None) -> int:
    """Run a DELETE statement and return its affected row count (0 if unknown)."""
    result = await db.execute(text(sql), params)
    return result.rowcount or 0


async def _scalar(db, sql: str, params: dict | None = None):
    """Run a single-value query and return that value (may be None)."""
    return (await db.execute(text(sql), params)).scalar()


@router.post("/change-password")
async def change_password(
    req: ChangePasswordRequest,
    current_user: dict = Depends(get_current_user),
):
    """Change the caller's own password after re-verifying the old one.

    Raises:
        HTTPException(400): the old password is wrong, or the new one is empty.
    """
    if not verify_password(req.old_password, current_user["password_hash"]):
        raise HTTPException(status_code=400, detail="旧密码错误")
    # An empty password is trivially guessable; refuse to store it.
    if not req.new_password:
        raise HTTPException(status_code=400, detail="新密码不能为空")

    new_hash = hash_password(req.new_password)
    async with get_db() as db:
        await db.execute(
            update(users_table)
            .where(users_table.c.id == current_user["id"])
            .values(password_hash=new_hash)
        )
        await db.commit()
    return {"message": "密码修改成功"}


# ---------- Admin Endpoints ----------


@router.get("/users")
async def list_users(admin: dict = Depends(get_current_admin)):
    """List all users ordered by id, including disabled ones (admin only)."""
    async with get_db() as db:
        result = await db.execute(select(users_table).order_by(users_table.c.id))
        rows = result.mappings().all()
    return [
        {
            "id": r["id"],
            "username": r["username"],
            "role": r["role"],
            "is_active": r["is_active"],
            "created_at": r["created_at"].isoformat() if r["created_at"] else None,
        }
        for r in rows
    ]


@router.post("/users")
async def create_user(req: CreateUserRequest, admin: dict = Depends(get_current_admin)):
    """Create a user with a server-generated random password (admin only).

    The plaintext password is returned once in the response and never stored.

    Raises:
        HTTPException(400): invalid role, blank username, or username taken.
    """
    # Validate input before touching the database.
    if req.role not in ("admin", "user"):
        raise HTTPException(status_code=400, detail="角色必须是 admin 或 user")
    if not req.username.strip():
        raise HTTPException(status_code=400, detail="用户名不能为空")

    raw_password = _generate_password()
    password_hash = hash_password(raw_password)

    async with get_db() as db:
        existing = await db.execute(
            select(users_table.c.id).where(users_table.c.username == req.username)
        )
        if existing.first():
            raise HTTPException(status_code=400, detail="用户名已存在")

        await db.execute(
            users_table.insert().values(
                username=req.username,
                password_hash=password_hash,
                role=req.role,
            )
        )
        await db.commit()

    logger.info("管理员 %s 创建了用户 %s (%s)", admin["username"], req.username, req.role)
    return {
        "username": req.username,
        "password": raw_password,
        "role": req.role,
        "message": "请妥善保管密码,此密码仅显示一次",
    }


@router.delete("/users/{user_id}")
async def disable_user(user_id: int, admin: dict = Depends(get_current_admin)):
    """Soft-delete a user by setting ``is_active`` to False (admin only).

    Raises:
        HTTPException(400): the admin tries to disable themself.
        HTTPException(404): no user with ``user_id``.
    """
    if user_id == admin["id"]:
        raise HTTPException(status_code=400, detail="不能禁用自己")

    async with get_db() as db:
        result = await db.execute(select(users_table).where(users_table.c.id == user_id))
        user = result.mappings().first()
        if not user:
            raise HTTPException(status_code=404, detail="用户不存在")

        await db.execute(
            update(users_table)
            .where(users_table.c.id == user_id)
            .values(is_active=False)
        )
        await db.commit()

    # Audit trail, consistent with the other admin actions in this module.
    logger.info("管理员 %s 禁用了用户 %s", admin["username"], user["username"])
    return {"message": f"用户 {user['username']} 已禁用"}


@router.post("/users/{user_id}/reset-password")
async def reset_password(user_id: int, admin: dict = Depends(get_current_admin)):
    """Replace a user's password with a new random one (admin only).

    The plaintext password is returned once in the response and never logged.

    Raises:
        HTTPException(404): no user with ``user_id``.
    """
    async with get_db() as db:
        result = await db.execute(select(users_table).where(users_table.c.id == user_id))
        user = result.mappings().first()
        if not user:
            raise HTTPException(status_code=404, detail="用户不存在")

        raw_password = _generate_password()
        await db.execute(
            update(users_table)
            .where(users_table.c.id == user_id)
            .values(password_hash=hash_password(raw_password))
        )
        await db.commit()

    logger.info("管理员 %s 重置了用户 %s 的密码", admin["username"], user["username"])
    return {
        "username": user["username"],
        "password": raw_password,
        "message": "请妥善保管新密码,此密码仅显示一次",
    }


# ---------- Data Reset Endpoints (Admin Only) ----------


@router.get("/data-stats")
async def get_data_stats(admin: dict = Depends(get_current_admin)):
    """Return row counts of business tables and the recommendation date span (admin only)."""
    async with get_db() as db:
        rec_count = (await _scalar(db, "SELECT COUNT(*) FROM recommendations")) or 0
        track_count = (await _scalar(db, "SELECT COUNT(*) FROM recommendation_tracking")) or 0
        sector_count = (await _scalar(db, "SELECT COUNT(*) FROM sector_heat")) or 0
        temp_count = (await _scalar(db, "SELECT COUNT(*) FROM market_temperature")) or 0
        review_count = (await _scalar(db, "SELECT COUNT(*) FROM daily_reviews")) or 0
        low_score = (
            await _scalar(
                db,
                "SELECT COUNT(*) FROM recommendations WHERE score < :ms",
                {"ms": _LOW_SCORE_THRESHOLD},
            )
        ) or 0
        # Date span of recommendations; empty string when the table is empty.
        latest_rec = (
            await _scalar(db, "SELECT MAX(date(created_at)) FROM recommendations")
        ) or ""
        earliest_rec = (
            await _scalar(db, "SELECT MIN(date(created_at)) FROM recommendations")
        ) or ""

    return {
        "recommendations": rec_count,
        "tracking": track_count,
        "sector_heat": sector_count,
        "market_temperature": temp_count,
        "daily_reviews": review_count,
        "low_score_count": low_score,
        "latest_date": str(latest_rec),
        "earliest_date": str(earliest_rec),
    }


@router.post("/data-reset")
async def data_reset(req: DataResetRequest, admin: dict = Depends(get_current_admin)):
    """Delete business data (admin only). User accounts are never touched.

    mode:
        - "all": delete recommendations, tracking, sector heat, market
          temperature, daily reviews and stock diagnoses.
        - "recommendations": delete recommendations and their tracking data;
          keep sector heat and market temperature.
        - "date_range": delete recommendations created before ``before_date``
          (plus all of their tracking rows), tracking rows tracked before it,
          and sector heat / market temperature traded before it.
        - "low_score": delete recommendations with ``score < min_score``
          (default 60) and their tracking rows.

    Returns:
        ``{"status": "ok", "mode": ..., "deleted": {name: row_count}}``.

    Raises:
        HTTPException(400): unknown mode, or missing/invalid ``before_date``.
    """
    from datetime import date

    # Validate everything up front so no DB session is opened for bad input.
    if req.mode not in _RESET_MODES:
        raise HTTPException(status_code=400, detail=f"不支持的模式: {req.mode}")

    before_date = None
    if req.mode == "date_range":
        if not req.before_date:
            raise HTTPException(status_code=400, detail="date_range 模式需要 before_date 参数")
        try:
            # Normalise to zero-padded YYYY-MM-DD: the cutoff is compared with
            # `<`, and e.g. "2026-4-1" would compare wrongly as a string.
            # NOTE(review): assumes track_date/trade_date compare correctly
            # against an ISO date string — confirm for the DB backend.
            before_date = date.fromisoformat(req.before_date).isoformat()
        except ValueError:
            raise HTTPException(
                status_code=400,
                detail=f"before_date 格式错误,应为 YYYY-MM-DD: {req.before_date}",
            ) from None

    # `is None` rather than `or`: an explicit min_score of 0 must stay 0.
    threshold = _LOW_SCORE_THRESHOLD if req.min_score is None else req.min_score

    deleted: dict[str, int] = {}
    async with get_db() as db:
        if req.mode == "all":
            for key, sql in _ALL_MODE_DELETES:
                deleted[key] = await _execute_delete(db, sql)

        elif req.mode == "recommendations":
            deleted["tracking"] = await _execute_delete(
                db, "DELETE FROM recommendation_tracking"
            )
            deleted["recommendations"] = await _execute_delete(
                db, "DELETE FROM recommendations"
            )

        elif req.mode == "date_range":
            params = {"bd": before_date}
            # Tracking rows of recommendations about to be deleted are removed
            # too, even if their own track_date is on/after the cutoff;
            # otherwise they would be left orphaned.
            deleted["tracking"] = await _execute_delete(
                db,
                "DELETE FROM recommendation_tracking WHERE track_date < :bd "
                "OR recommendation_id IN "
                "(SELECT id FROM recommendations WHERE date(created_at) < :bd)",
                params,
            )
            deleted["recommendations"] = await _execute_delete(
                db, "DELETE FROM recommendations WHERE date(created_at) < :bd", params
            )
            deleted["sector_heat"] = await _execute_delete(
                db, "DELETE FROM sector_heat WHERE trade_date < :bd", params
            )
            deleted["market_temperature"] = await _execute_delete(
                db, "DELETE FROM market_temperature WHERE trade_date < :bd", params
            )

        else:  # "low_score" (mode was validated above)
            params = {"ms": threshold}
            # Delete dependent tracking rows before their recommendations.
            deleted["tracking"] = await _execute_delete(
                db,
                "DELETE FROM recommendation_tracking WHERE recommendation_id IN "
                "(SELECT id FROM recommendations WHERE score < :ms)",
                params,
            )
            deleted["recommendations"] = await _execute_delete(
                db, "DELETE FROM recommendations WHERE score < :ms", params
            )

        await db.commit()

    logger.info("管理员 %s 执行数据重置: mode=%s, deleted=%s", admin["username"], req.mode, deleted)
    return {
        "status": "ok",
        "mode": req.mode,
        "deleted": deleted,
    }