astock-agent/backend/app/api/auth.py
2026-04-17 00:32:21 +08:00

335 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""认证 API
登录、密码修改、用户管理(管理员)、数据重置(管理员)
"""
import functools
import logging
import secrets
from datetime import date

from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from sqlalchemy import select, update, text, func

from app.core.auth import hash_password, verify_password, create_access_token
from app.core.deps import get_current_user, get_current_admin
from app.db.database import get_db
from app.db.tables import users_table
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

# Every endpoint in this module is mounted under /api/auth and grouped under the "auth" tag.
router = APIRouter(tags=["auth"], prefix="/api/auth")
# ---------- Request/Response Models ----------
class LoginRequest(BaseModel):
    """Request body for POST /api/auth/login."""

    username: str  # account name; matched exactly against users.username
    password: str  # plaintext password; checked with verify_password against the stored hash
class ChangePasswordRequest(BaseModel):
    """Request body for POST /api/auth/change-password (the caller changes their own password)."""

    old_password: str  # current password; must verify against the caller's stored hash
    new_password: str  # replacement password; only its hash is stored
class CreateUserRequest(BaseModel):
    """Request body for the admin-only POST /api/auth/users endpoint.

    There is no password field: the server generates a random password and
    returns it once in the response.
    """

    username: str
    # Expected to be "admin" or "user"; the endpoint (not this model) rejects other values.
    role: str = "user"
class DataResetRequest(BaseModel):
    """Request body for the admin-only POST /api/auth/data-reset endpoint."""

    mode: str  # one of "all", "recommendations", "date_range", "low_score"
    before_date: str | None = None  # required for "date_range" mode, ISO date such as "2026-04-10"
    min_score: int | None = None  # used by "low_score" mode; the endpoint uses 60 when it is not provided
# ---------- Public Endpoints ----------
@functools.lru_cache(maxsize=1)
def _dummy_password_hash():
    """Return a throwaway password hash used to equalize login timing.

    Built lazily, once per process, from a random secret that is never stored,
    so verifying any password against it fails. It is produced by the same
    hash_password used for real accounts, so checking it should cost about as
    much as checking a real hash.
    """
    return hash_password(secrets.token_urlsafe(16))


@router.post("/login")
async def login(req: LoginRequest):
    """Authenticate a user and return a JWT access token.

    Returns:
        {"token": <JWT>, "user": {"id", "username", "role"}}. The token's
        payload carries the user id (as a string, under "sub") and the role.

    Raises:
        HTTPException(401): unknown username, inactive account, or wrong
            password. All three use the same message, and all three perform
            exactly one password verification, so neither the response body
            nor its timing reveals which case occurred.
    """
    async with get_db() as db:
        result = await db.execute(
            select(users_table).where(users_table.c.username == req.username)
        )
        user = result.mappings().first()

    # Always run one password check, even for unknown users, so that response
    # time does not leak whether an account exists or is disabled.
    stored_hash = user["password_hash"] if user is not None else _dummy_password_hash()
    password_ok = verify_password(req.password, stored_hash)
    if user is None or not user["is_active"] or not password_ok:
        raise HTTPException(status_code=401, detail="用户名或密码错误")

    token = create_access_token({"sub": str(user["id"]), "role": user["role"]})
    return {
        "token": token,
        "user": {
            "id": user["id"],
            "username": user["username"],
            "role": user["role"],
        },
    }
# ---------- Authenticated Endpoints ----------
@router.get("/me")
async def get_me(current_user: dict = Depends(get_current_user)):
    """Return the authenticated caller's public profile.

    Only id, username, role and is_active are exposed; the password hash and
    any other columns of the user record are left out.
    """
    public_fields = ("id", "username", "role", "is_active")
    return {field: current_user[field] for field in public_fields}
@router.post("/change-password")
async def change_password(
    req: ChangePasswordRequest,
    current_user: dict = Depends(get_current_user),
):
    """Change the authenticated caller's own password.

    The old password must verify against the caller's stored hash. The new
    password is hashed with hash_password before it is written to the database.

    Raises:
        HTTPException(400): the old password is wrong, or the new password is
            empty / whitespace-only.
    """
    if not verify_password(req.old_password, current_user["password_hash"]):
        raise HTTPException(status_code=400, detail="旧密码错误")
    # A blank password would otherwise be hashed and accepted, leaving the
    # account protected by an empty secret.
    if not req.new_password.strip():
        raise HTTPException(status_code=400, detail="新密码不能为空")

    new_hash = hash_password(req.new_password)
    async with get_db() as db:
        await db.execute(
            update(users_table)
            .where(users_table.c.id == current_user["id"])
            .values(password_hash=new_hash)
        )
        await db.commit()
    return {"message": "密码修改成功"}
# ---------- Admin Endpoints ----------
@router.get("/users")
async def list_users(admin: dict = Depends(get_current_admin)):
    """List every user account ordered by id (admin only).

    The password hash is never included. ``created_at`` is serialized with
    ``isoformat()``, or returned as None when the stored value is empty.
    """
    query = select(users_table).order_by(users_table.c.id)
    async with get_db() as db:
        rows = (await db.execute(query)).mappings().all()

    def _serialize(row):
        created = row["created_at"]
        return {
            "id": row["id"],
            "username": row["username"],
            "role": row["role"],
            "is_active": row["is_active"],
            "created_at": created.isoformat() if created else None,
        }

    return [_serialize(row) for row in rows]
@router.post("/users")
async def create_user(req: CreateUserRequest, admin: dict = Depends(get_current_admin)):
    """Create a user with a server-generated random password (admin only).

    Only the password hash is stored. The plaintext password is returned once,
    in this response.

    Raises:
        HTTPException(400): blank username, role other than "admin"/"user",
            or the username is already taken.
    """
    # Validate the request itself before spending a database round trip on it.
    if not req.username.strip():
        raise HTTPException(status_code=400, detail="用户名不能为空")
    if req.role not in ("admin", "user"):
        raise HTTPException(status_code=400, detail="角色必须是 admin 或 user")

    async with get_db() as db:
        # Reject duplicate usernames.
        result = await db.execute(
            select(users_table).where(users_table.c.username == req.username)
        )
        if result.first() is not None:
            raise HTTPException(status_code=400, detail="用户名已存在")

        # token_urlsafe(9): 9 random bytes -> 12 URL-safe characters.
        raw_password = secrets.token_urlsafe(9)
        password_hash = hash_password(raw_password)
        await db.execute(
            users_table.insert().values(
                username=req.username,
                password_hash=password_hash,
                role=req.role,
            )
        )
        await db.commit()

    logger.info(f"管理员 {admin['username']} 创建了用户 {req.username} ({req.role})")
    return {
        "username": req.username,
        "password": raw_password,
        "role": req.role,
        "message": "请妥善保管密码,此密码仅显示一次",
    }
@router.delete("/users/{user_id}")
async def disable_user(user_id: int, admin: dict = Depends(get_current_admin)):
    """Disable a user account (admin only).

    This is a soft delete: the row is kept and ``is_active`` is set to False.

    Raises:
        HTTPException(400): the admin tries to disable their own account.
        HTTPException(404): no user with ``user_id`` exists.
    """
    # Prevent an admin from locking themselves out.
    if user_id == admin["id"]:
        raise HTTPException(status_code=400, detail="不能禁用自己")
    async with get_db() as db:
        result = await db.execute(
            select(users_table).where(users_table.c.id == user_id)
        )
        user = result.mappings().first()
        if not user:
            raise HTTPException(status_code=404, detail="用户不存在")
        await db.execute(
            update(users_table)
            .where(users_table.c.id == user_id)
            .values(is_active=False)
        )
        await db.commit()
    # Audit trail for an admin action, as done for user creation and data resets.
    logger.info(f"管理员 {admin['username']} 禁用了用户 {user['username']} (id={user_id})")
    return {"message": f"用户 {user['username']} 已禁用"}
@router.post("/users/{user_id}/reset-password")
async def reset_password(user_id: int, admin: dict = Depends(get_current_admin)):
    """Replace a user's password with a new random one (admin only).

    Only the new hash is stored. The plaintext password is returned once, in
    this response.

    Raises:
        HTTPException(404): no user with ``user_id`` exists.
    """
    async with get_db() as db:
        result = await db.execute(
            select(users_table).where(users_table.c.id == user_id)
        )
        user = result.mappings().first()
        if not user:
            raise HTTPException(status_code=404, detail="用户不存在")
        # token_urlsafe(9): 9 random bytes -> 12 URL-safe characters.
        raw_password = secrets.token_urlsafe(9)
        password_hash = hash_password(raw_password)
        await db.execute(
            update(users_table)
            .where(users_table.c.id == user_id)
            .values(password_hash=password_hash)
        )
        await db.commit()
    # Audit trail for a security-sensitive admin action. Never log the password itself.
    logger.info(f"管理员 {admin['username']} 重置了用户 {user['username']} (id={user_id}) 的密码")
    return {
        "username": user["username"],
        "password": raw_password,
        "message": "请妥善保管新密码,此密码仅显示一次",
    }
# ---------- Data Reset Endpoints (Admin Only) ----------
@router.get("/data-stats")
async def get_data_stats(admin: dict = Depends(get_current_admin)):
    """Report business-data row counts and the recommendation date span (admin only).

    Counts default to 0. ``latest_date`` / ``earliest_date`` are strings and
    are "" when the recommendations table is empty.
    """
    # Response key -> fixed SQL literal. No user input reaches these statements.
    # Queries run in insertion order.
    count_queries = {
        "recommendations": "SELECT COUNT(*) FROM recommendations",
        "tracking": "SELECT COUNT(*) FROM recommendation_tracking",
        "sector_heat": "SELECT COUNT(*) FROM sector_heat",
        "market_temperature": "SELECT COUNT(*) FROM market_temperature",
        "daily_reviews": "SELECT COUNT(*) FROM daily_reviews",
        "low_score_count": "SELECT COUNT(*) FROM recommendations WHERE score < 60",
    }
    date_queries = {
        "latest_date": "SELECT MAX(date(created_at)) FROM recommendations",
        "earliest_date": "SELECT MIN(date(created_at)) FROM recommendations",
    }

    stats = {}
    async with get_db() as db:
        for key, sql in count_queries.items():
            stats[key] = (await db.execute(text(sql))).scalar() or 0
        # MAX/MIN yield NULL on an empty table; report that as "".
        for key, sql in date_queries.items():
            stats[key] = str((await db.execute(text(sql))).scalar() or "")
    return stats
# Modes accepted by POST /data-reset.
_RESET_MODES = ("all", "recommendations", "date_range", "low_score")
# Score threshold used by "low_score" mode when the request omits min_score.
_DEFAULT_MIN_SCORE = 60


async def _exec_delete(db, sql, params=None):
    """Run one DELETE statement and return how many rows it removed."""
    if params is None:
        result = await db.execute(text(sql))
    else:
        result = await db.execute(text(sql), params)
    return result.rowcount or 0


def _parse_before_date(raw):
    """Validate ``before_date`` for "date_range" mode and return it as YYYY-MM-DD.

    Rejecting junk matters here: the deletes compare with ``< :bd``, and an
    arbitrary string such as "abc" sorts after every ISO date, which would
    wipe every row.

    Raises:
        HTTPException(400): value missing or not a valid ISO calendar date.
    """
    if not raw:
        raise HTTPException(status_code=400, detail="date_range 模式需要 before_date 参数")
    try:
        return date.fromisoformat(raw).isoformat()
    except ValueError:
        raise HTTPException(
            status_code=400, detail=f"before_date 格式无效 (应为 YYYY-MM-DD): {raw}"
        ) from None


async def _reset_all(db):
    """Empty every business table. The users table is not touched."""
    # Tracking rows go first; they point at recommendations via recommendation_id.
    return {
        "tracking": await _exec_delete(db, "DELETE FROM recommendation_tracking"),
        "recommendations": await _exec_delete(db, "DELETE FROM recommendations"),
        "sector_heat": await _exec_delete(db, "DELETE FROM sector_heat"),
        "market_temperature": await _exec_delete(db, "DELETE FROM market_temperature"),
        "daily_reviews": await _exec_delete(db, "DELETE FROM daily_reviews"),
        "stock_diagnoses": await _exec_delete(db, "DELETE FROM stock_diagnoses"),
    }


async def _reset_recommendations(db):
    """Empty recommendations and their tracking rows. Sector and market tables are kept."""
    return {
        "tracking": await _exec_delete(db, "DELETE FROM recommendation_tracking"),
        "recommendations": await _exec_delete(db, "DELETE FROM recommendations"),
    }


async def _reset_before_date(db, before_date):
    """Delete rows dated strictly before ``before_date`` (a YYYY-MM-DD string)."""
    params = {"bd": before_date}
    return {
        # Besides old tracking rows, also remove tracking rows of recommendations
        # that the next statement deletes, so none is left pointing at a removed
        # recommendation.
        "tracking": await _exec_delete(
            db,
            "DELETE FROM recommendation_tracking WHERE track_date < :bd "
            "OR recommendation_id IN "
            "(SELECT id FROM recommendations WHERE date(created_at) < :bd)",
            params,
        ),
        "recommendations": await _exec_delete(
            db, "DELETE FROM recommendations WHERE date(created_at) < :bd", params
        ),
        "sector_heat": await _exec_delete(
            db, "DELETE FROM sector_heat WHERE trade_date < :bd", params
        ),
        "market_temperature": await _exec_delete(
            db, "DELETE FROM market_temperature WHERE trade_date < :bd", params
        ),
    }


async def _reset_low_score(db, threshold):
    """Delete recommendations scoring below ``threshold``, plus their tracking rows."""
    params = {"ms": threshold}
    return {
        # Tracking first, while the low-score recommendation ids still exist.
        "tracking": await _exec_delete(
            db,
            "DELETE FROM recommendation_tracking WHERE recommendation_id IN "
            "(SELECT id FROM recommendations WHERE score < :ms)",
            params,
        ),
        "recommendations": await _exec_delete(
            db, "DELETE FROM recommendations WHERE score < :ms", params
        ),
    }


@router.post("/data-reset")
async def data_reset(req: DataResetRequest, admin: dict = Depends(get_current_admin)):
    """Bulk-delete business data (admin only). The users table is never touched.

    Modes (``req.mode``):
      - "all": clear recommendations, tracking, sector_heat, market_temperature,
        daily_reviews and stock_diagnoses.
      - "recommendations": clear recommendations and tracking; keep sector and
        market temperature data.
      - "date_range": delete tracking, recommendation, sector_heat and
        market_temperature rows dated strictly before ``req.before_date``.
        Also delete tracking rows belonging to the removed recommendations.
      - "low_score": delete recommendations with score < ``req.min_score``
        (60 when omitted) and their tracking rows.

    Returns:
        {"status": "ok", "mode": <mode>, "deleted": {<table key>: <rows deleted>}}

    Raises:
        HTTPException(400): unknown mode, or missing/invalid ``before_date`` in
            "date_range" mode.
    """
    # Validate the whole request before opening a DB session.
    if req.mode not in _RESET_MODES:
        raise HTTPException(status_code=400, detail=f"不支持的模式: {req.mode}")
    before_date = _parse_before_date(req.before_date) if req.mode == "date_range" else None
    # `is None` rather than `or`: an explicit min_score of 0 must not silently become 60.
    threshold = _DEFAULT_MIN_SCORE if req.min_score is None else req.min_score

    async with get_db() as db:
        if req.mode == "all":
            deleted = await _reset_all(db)
        elif req.mode == "recommendations":
            deleted = await _reset_recommendations(db)
        elif req.mode == "date_range":
            deleted = await _reset_before_date(db, before_date)
        else:  # "low_score"
            deleted = await _reset_low_score(db, threshold)
        # One commit: all deletes for the mode become visible together.
        await db.commit()

    logger.info(f"管理员 {admin['username']} 执行数据重置: mode={req.mode}, deleted={deleted}")
    return {
        "status": "ok",
        "mode": req.mode,
        "deleted": deleted,
    }