335 lines
12 KiB
Python
335 lines
12 KiB
Python
"""Authentication API.

Login, password change, user management (admin) and data reset (admin).
"""

import logging
import secrets
from datetime import date

from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from sqlalchemy import select, update, text, func

from app.core.auth import hash_password, verify_password, create_access_token
from app.core.deps import get_current_user, get_current_admin
from app.db.database import get_db
from app.db.tables import users_table


# Module logger; used below to record admin actions (user creation, data resets).
logger = logging.getLogger(__name__)

# All routes in this module live under /api/auth and are tagged "auth".
router = APIRouter(tags=["auth"], prefix="/api/auth")


# ---------- Request/Response Models ----------


class LoginRequest(BaseModel):
    # Body of POST /login.
    username: str  # matched exactly against users.username
    password: str  # plaintext; checked with verify_password against the stored hash


class ChangePasswordRequest(BaseModel):
    # Body of POST /change-password.
    old_password: str  # must verify against the caller's current password hash
    new_password: str  # plaintext; hashed with hash_password before it is stored


class CreateUserRequest(BaseModel):
    # Body of POST /users (admin only); the password is generated server-side.
    username: str
    role: str = "user"  # create_user only accepts "admin" or "user"


class DataResetRequest(BaseModel):
    # Body of POST /data-reset.
    mode: str  # "all", "recommendations", "date_range" or "low_score"; anything else gets a 400
    before_date: str | None = None  # required for date_range mode, e.g. "2026-04-10"
    min_score: int | None = None  # low_score mode threshold; field default is None, the endpoint falls back to 60


# ---------- Public Endpoints ----------


@router.post("/login")
async def login(req: LoginRequest):
    """Authenticate a user and issue a JWT.

    An unknown username, a disabled account and a wrong password all produce
    the same 401 response, so callers cannot tell which check failed.

    Returns:
        ``{"token": <jwt>, "user": {"id", "username", "role"}}``
    """
    lookup = select(users_table).where(users_table.c.username == req.username)
    async with get_db() as db:
        account = (await db.execute(lookup)).mappings().first()

    # Short-circuit order matches the original checks: existence, then active
    # flag, and only then the (comparatively expensive) password verification.
    credentials_ok = (
        account is not None
        and account["is_active"]
        and verify_password(req.password, account["password_hash"])
    )
    if not credentials_ok:
        raise HTTPException(status_code=401, detail="用户名或密码错误")

    claims = {"sub": str(account["id"]), "role": account["role"]}
    profile = {field: account[field] for field in ("id", "username", "role")}
    return {"token": create_access_token(claims), "user": profile}


# ---------- Authenticated Endpoints ----------


@router.get("/me")
async def get_me(current_user: dict = Depends(get_current_user)):
    """Return the authenticated user's profile (the password hash is not exposed)."""
    exposed_fields = ("id", "username", "role", "is_active")
    return {field: current_user[field] for field in exposed_fields}


@router.post("/change-password")
async def change_password(
    req: ChangePasswordRequest,
    current_user: dict = Depends(get_current_user),
):
    """Change the authenticated user's own password.

    Raises:
        HTTPException(400): the old password does not verify, or the new
            password is empty / whitespace-only.
    """
    if not verify_password(req.old_password, current_user["password_hash"]):
        raise HTTPException(status_code=400, detail="旧密码错误")

    # Without this guard a user could store a hash of "" (or only spaces) as
    # their password.
    if not req.new_password.strip():
        raise HTTPException(status_code=400, detail="新密码不能为空")

    new_hash = hash_password(req.new_password)
    async with get_db() as db:
        await db.execute(
            update(users_table)
            .where(users_table.c.id == current_user["id"])
            .values(password_hash=new_hash)
        )
        await db.commit()

    return {"message": "密码修改成功"}


# ---------- Admin Endpoints ----------


@router.get("/users")
async def list_users(admin: dict = Depends(get_current_admin)):
    """List all users ordered by id (admin only); password hashes are omitted.

    ``created_at`` is rendered as an ISO-8601 string, or None when unset.
    """
    async with get_db() as db:
        ordered = select(users_table).order_by(users_table.c.id)
        rows = (await db.execute(ordered)).mappings().all()

    users = []
    for row in rows:
        users.append(
            {
                "id": row["id"],
                "username": row["username"],
                "role": row["role"],
                "is_active": row["is_active"],
                "created_at": (
                    row["created_at"].isoformat() if row["created_at"] else None
                ),
            }
        )
    return users


@router.post("/users")
async def create_user(req: CreateUserRequest, admin: dict = Depends(get_current_admin)):
    """Create a user with a randomly generated password (admin only).

    Only the hash is stored; the plaintext password is returned once in the
    response.

    Raises:
        HTTPException(400): blank username, unsupported role, or the username
            already exists.
    """
    # Validate the request itself before opening a database session.
    if not req.username.strip():
        raise HTTPException(status_code=400, detail="用户名不能为空")
    if req.role not in ("admin", "user"):
        raise HTTPException(status_code=400, detail="角色必须是 admin 或 user")

    async with get_db() as db:
        # NOTE(review): check-then-insert is not atomic; two concurrent requests
        # for the same name can both pass this check. Confirm a UNIQUE
        # constraint on users.username backs it up.
        existing = await db.execute(
            select(users_table.c.id).where(users_table.c.username == req.username)
        )
        if existing.first():
            raise HTTPException(status_code=400, detail="用户名已存在")

        # token_urlsafe(9): 9 random bytes -> 12 URL-safe base64 characters.
        raw_password = secrets.token_urlsafe(9)
        await db.execute(
            users_table.insert().values(
                username=req.username,
                password_hash=hash_password(raw_password),
                role=req.role,
            )
        )
        await db.commit()

    logger.info("管理员 %s 创建了用户 %s (%s)", admin["username"], req.username, req.role)

    return {
        "username": req.username,
        "password": raw_password,
        "role": req.role,
        "message": "请妥善保管密码,此密码仅显示一次",
    }


@router.delete("/users/{user_id}")
async def disable_user(user_id: int, admin: dict = Depends(get_current_admin)):
    """Soft-delete a user by setting ``is_active`` to False (admin only).

    Raises:
        HTTPException(400): the admin targets their own account.
        HTTPException(404): no user has ``user_id``.
    """
    # An admin must not be able to deactivate (and thus lock out) themselves.
    if admin["id"] == user_id:
        raise HTTPException(status_code=400, detail="不能禁用自己")

    async with get_db() as db:
        lookup = await db.execute(
            select(users_table).where(users_table.c.id == user_id)
        )
        target = lookup.mappings().first()
        if not target:
            raise HTTPException(status_code=404, detail="用户不存在")

        deactivate = (
            update(users_table)
            .where(users_table.c.id == user_id)
            .values(is_active=False)
        )
        await db.execute(deactivate)
        await db.commit()

    return {"message": f"用户 {target['username']} 已禁用"}


@router.post("/users/{user_id}/reset-password")
async def reset_password(user_id: int, admin: dict = Depends(get_current_admin)):
    """Replace a user's password with a fresh random one (admin only).

    The new plaintext password is returned once; only its hash is stored.

    Raises:
        HTTPException(404): no user has ``user_id``.
    """
    async with get_db() as db:
        found = (
            await db.execute(select(users_table).where(users_table.c.id == user_id))
        ).mappings().first()
        if not found:
            raise HTTPException(status_code=404, detail="用户不存在")

        # 9 random bytes encode to a 12-character URL-safe string.
        fresh_password = secrets.token_urlsafe(9)
        fresh_hash = hash_password(fresh_password)
        await db.execute(
            update(users_table)
            .where(users_table.c.id == user_id)
            .values(password_hash=fresh_hash)
        )
        await db.commit()

    return {
        "username": found["username"],
        "password": fresh_password,
        "message": "请妥善保管新密码,此密码仅显示一次",
    }


# ---------- Data Reset Endpoints (Admin Only) ----------


@router.get("/data-stats")
async def get_data_stats(admin: dict = Depends(get_current_admin)):
    """Return table row counts and the recommendation date span (admin only).

    ``low_score_count`` counts recommendations with ``score < 60``. The
    ``latest_date`` / ``earliest_date`` fields are "" when no recommendation
    rows exist.
    """
    # Response key -> COUNT query; insertion order is both the execution order
    # and the key order of the returned dict.
    count_queries = {
        "recommendations": "SELECT COUNT(*) FROM recommendations",
        "tracking": "SELECT COUNT(*) FROM recommendation_tracking",
        "sector_heat": "SELECT COUNT(*) FROM sector_heat",
        "market_temperature": "SELECT COUNT(*) FROM market_temperature",
        "daily_reviews": "SELECT COUNT(*) FROM daily_reviews",
        "low_score_count": "SELECT COUNT(*) FROM recommendations WHERE score < 60",
    }
    span_queries = {
        "latest_date": "SELECT MAX(date(created_at)) FROM recommendations",
        "earliest_date": "SELECT MIN(date(created_at)) FROM recommendations",
    }

    stats = {}
    async with get_db() as db:
        for key, sql in count_queries.items():
            stats[key] = (await db.execute(text(sql))).scalar() or 0
        for key, sql in span_queries.items():
            # MAX/MIN yield NULL on an empty table; report "" in that case.
            stats[key] = str((await db.execute(text(sql))).scalar() or "")
    return stats


def _parse_before_date(raw: str | None) -> str:
    """Validate ``before_date`` and return it normalized to ``YYYY-MM-DD``.

    The date_range filters compare against the bound value, and an unvalidated
    string such as "abc" would compare greater than every ISO date as text and
    match (i.e. delete) every row.

    Raises:
        HTTPException(400): the value is missing or not an ISO date.
    """
    if not raw:
        raise HTTPException(status_code=400, detail="date_range 模式需要 before_date 参数")
    try:
        return date.fromisoformat(raw).isoformat()
    except ValueError:
        raise HTTPException(
            status_code=400,
            detail=f"before_date 格式无效,应为 YYYY-MM-DD: {raw}",
        ) from None


def _reset_statements(req: DataResetRequest) -> list[tuple[str, str, dict | None]]:
    """Build the ordered ``(result_key, sql, params)`` DELETE plan for ``req.mode``.

    Tracking rows are always deleted before the recommendations they reference.

    Raises:
        HTTPException(400): unsupported mode or invalid mode parameters.
    """
    if req.mode == "all":
        # Wipe every business table; the users table is left untouched.
        return [
            ("tracking", "DELETE FROM recommendation_tracking", None),
            ("recommendations", "DELETE FROM recommendations", None),
            ("sector_heat", "DELETE FROM sector_heat", None),
            ("market_temperature", "DELETE FROM market_temperature", None),
            ("daily_reviews", "DELETE FROM daily_reviews", None),
            ("stock_diagnoses", "DELETE FROM stock_diagnoses", None),
        ]

    if req.mode == "recommendations":
        return [
            ("tracking", "DELETE FROM recommendation_tracking", None),
            ("recommendations", "DELETE FROM recommendations", None),
        ]

    if req.mode == "date_range":
        params = {"bd": _parse_before_date(req.before_date)}
        return [
            # Besides rows tracked before the cutoff, also drop tracking rows of
            # recommendations removed below; otherwise rows tracked on/after the
            # cutoff would be left pointing at deleted recommendations.
            (
                "tracking",
                "DELETE FROM recommendation_tracking WHERE track_date < :bd "
                "OR recommendation_id IN "
                "(SELECT id FROM recommendations WHERE date(created_at) < :bd)",
                params,
            ),
            (
                "recommendations",
                "DELETE FROM recommendations WHERE date(created_at) < :bd",
                params,
            ),
            ("sector_heat", "DELETE FROM sector_heat WHERE trade_date < :bd", params),
            (
                "market_temperature",
                "DELETE FROM market_temperature WHERE trade_date < :bd",
                params,
            ),
        ]

    if req.mode == "low_score":
        # `is None` rather than `or`: an explicit min_score of 0 must be honored
        # instead of silently becoming 60.
        threshold = 60 if req.min_score is None else req.min_score
        params = {"ms": threshold}
        return [
            (
                "tracking",
                "DELETE FROM recommendation_tracking WHERE recommendation_id IN "
                "(SELECT id FROM recommendations WHERE score < :ms)",
                params,
            ),
            ("recommendations", "DELETE FROM recommendations WHERE score < :ms", params),
        ]

    raise HTTPException(status_code=400, detail=f"不支持的模式: {req.mode}")


@router.post("/data-reset")
async def data_reset(req: DataResetRequest, admin: dict = Depends(get_current_admin)):
    """Bulk-delete business data (admin only).

    Modes:
        - "all": empty recommendation_tracking, recommendations, sector_heat,
          market_temperature, daily_reviews and stock_diagnoses (users kept).
        - "recommendations": empty recommendation_tracking and recommendations;
          sector_heat and market_temperature are kept.
        - "date_range": delete rows dated before ``before_date`` (YYYY-MM-DD)
          from recommendation_tracking (by track_date, plus all tracking of the
          recommendations being removed), recommendations (by created_at date),
          sector_heat and market_temperature (by trade_date). daily_reviews and
          stock_diagnoses are kept.
        - "low_score": delete recommendations with ``score < min_score``
          (default 60) together with their tracking rows.

    Returns:
        ``{"status": "ok", "mode": <mode>, "deleted": {<table key>: <rowcount>}}``

    Raises:
        HTTPException(400): unsupported mode, or missing/invalid before_date in
            date_range mode.
    """
    # Build and validate the whole plan first so a bad request fails before any
    # database work happens.
    statements = _reset_statements(req)

    deleted = {}
    async with get_db() as db:
        for key, sql, params in statements:
            result = await db.execute(text(sql), params)
            deleted[key] = result.rowcount or 0
        # One commit for the whole plan, matching the original single-commit flow.
        await db.commit()

    logger.info(
        "管理员 %s 执行数据重置: mode=%s, deleted=%s", admin["username"], req.mode, deleted
    )

    return {
        "status": "ok",
        "mode": req.mode,
        "deleted": deleted,
    }