512 lines
20 KiB
Python
512 lines
20 KiB
Python
"""认证 API。
|
|
|
|
邮箱+密码登录。
|
|
邀请码 + 邮箱验证码 + 密码注册。
|
|
"""
|
|
|
|
from __future__ import annotations

import asyncio
import logging
import random
import re
import secrets
import string
from datetime import datetime, timedelta, timezone

from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy import select, update, text, func, delete

from app.config import settings
from app.core.auth import hash_password, verify_password, create_access_token
from app.core.deps import get_current_user, get_current_admin
from app.core.email import send_email, build_register_code_email
from app.db.database import get_db
from app.db.tables import users_table, email_verification_codes_table, invite_codes_table
|
|
|
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Every endpoint in this module is mounted under /api/auth and tagged "auth".
router = APIRouter(tags=["auth"], prefix="/api/auth")
|
|
|
|
|
|
class LoginRequest(BaseModel):
    # Request body for POST /login.
    email: str  # normalized and format-checked by the login handler
    password: str  # plaintext; checked with verify_password against the stored hash
|
|
|
|
|
|
class SendRegisterCodeRequest(BaseModel):
    # Request body for POST /send-register-code.
    email: str
    # Length bounds are enforced by pydantic on the raw (unstripped) value.
    invite_code: str = Field(max_length=64, min_length=4)
|
|
|
|
|
|
class RegisterRequest(BaseModel):
    # Request body for POST /register.
    email: str
    invite_code: str = Field(max_length=64, min_length=4)
    # Exactly six characters, matching the emailed verification code length.
    email_code: str = Field(max_length=6, min_length=6)
    # NOTE(review): this fixed floor of 6 is independent of
    # settings.auth_min_password_length, which the handler checks separately.
    password: str = Field(min_length=6)
|
|
|
|
|
|
class ChangePasswordRequest(BaseModel):
    # Request body for POST /change-password.
    old_password: str  # must match the current password hash
    # Fixed floor of 6; the handler also applies settings.auth_min_password_length.
    new_password: str = Field(min_length=6)
|
|
|
|
|
|
class CreateInviteCodeRequest(BaseModel):
    # Admin request body for POST /invite-codes.
    code: str = Field(max_length=64, min_length=4)
    description: str = ""
    # The handler clamps values below 1 up to 1.
    max_uses: int = 1
|
|
|
|
|
|
class DataResetRequest(BaseModel):
    # Admin request body for POST /data-reset.
    # Supported modes: all, recommendations, market_cache, diagnostics, logs,
    # date_range, low_score (anything else is rejected by the handler).
    mode: str
    # Cutoff for mode="date_range": rows dated strictly before it are deleted.
    before_date: str | None = None
    # Score threshold for mode="low_score"; the handler uses 60 when unset.
    min_score: int | None = None
|
|
|
|
|
|
def _normalize_email(email: str) -> str:
|
|
return str(email or "").strip().lower()
|
|
|
|
|
|
def _validate_email(email: str) -> str:
    """Normalize *email* and reject values that do not look like an address.

    Returns the normalized address. Raises HTTPException(400) unless it has the
    shape ``local@domain.tld``: no whitespace, exactly one ``@``, and at least
    one dot in the part after the ``@``.
    """
    normalized = _normalize_email(email)
    looks_valid = re.fullmatch(r"[^@\s]+@[^@\s]+\.[^@\s]+", normalized) is not None
    if looks_valid:
        return normalized
    raise HTTPException(status_code=400, detail="邮箱格式错误")
|
|
|
|
|
|
def _validate_password(password: str) -> None:
    """Raise HTTPException(400) if *password* is shorter than the configured minimum.

    ``None`` is treated as an empty password. The minimum comes from
    ``settings.auth_min_password_length``.
    """
    min_length = settings.auth_min_password_length
    if len(password or "") >= min_length:
        return
    raise HTTPException(status_code=400, detail=f"密码至少 {min_length} 位")
|
|
|
|
|
|
def _build_username_from_email(email: str) -> str:
|
|
return _normalize_email(email)
|
|
|
|
|
|
def _generate_email_code() -> str:
|
|
return "".join(random.choices(string.digits, k=6))
|
|
|
|
|
|
def _coerce_naive_datetime(value) -> datetime | None:
|
|
if value is None:
|
|
return None
|
|
if isinstance(value, datetime):
|
|
return value.replace(tzinfo=None)
|
|
if isinstance(value, str):
|
|
return datetime.fromisoformat(value.replace("Z", "+00:00")).replace(tzinfo=None)
|
|
return None
|
|
|
|
|
|
async def _get_user_by_email(email: str) -> dict | None:
    """Fetch the users_table row whose email equals the normalized *email*.

    Returns the row as a plain ``dict``, or ``None`` when no user matches.
    """
    query = select(users_table).where(users_table.c.email == _normalize_email(email))
    async with get_db() as db:
        row = (await db.execute(query)).mappings().first()
        return None if row is None else dict(row)
|
|
|
|
|
|
async def _get_invite_code(code: str) -> dict | None:
    """Look up an invite code row by its whitespace-stripped value.

    Returns the row as a plain ``dict``, or ``None`` if the code does not exist.
    """
    wanted = code.strip()
    query = select(invite_codes_table).where(invite_codes_table.c.code == wanted)
    async with get_db() as db:
        found = (await db.execute(query)).mappings().first()
        return dict(found) if found is not None else None
|
|
|
|
|
|
def _assert_invite_code_valid(invite_row: dict | None) -> None:
    """Reject an invite code that cannot currently be used.

    Does nothing when ``settings.invite_code_required`` is false. Otherwise it
    raises HTTPException(400) if the row is missing, inactive, has reached
    ``max_uses`` (``None`` means unlimited), or has an ``expires_at`` in the past.
    """
    if not settings.invite_code_required:
        return
    if not invite_row:
        raise HTTPException(status_code=400, detail="邀请码无效")
    if not invite_row["is_active"]:
        raise HTTPException(status_code=400, detail="邀请码已停用")

    limit = invite_row["max_uses"]
    exhausted = limit is not None and invite_row["used_count"] >= limit
    if exhausted:
        raise HTTPException(status_code=400, detail="邀请码已用完")

    deadline = _coerce_naive_datetime(invite_row.get("expires_at"))
    if deadline and deadline < datetime.utcnow():
        raise HTTPException(status_code=400, detail="邀请码已过期")
|
|
|
|
|
|
async def _consume_invite_code(code: str) -> None:
    """Increment ``used_count`` of invite *code* (stripped) and touch ``updated_at``.

    Does nothing when ``settings.invite_code_required`` is false. The increment
    is expressed in SQL (``used_count + 1``) rather than read-modify-write in Python.
    """
    if settings.invite_code_required:
        table = invite_codes_table
        stmt = (
            update(table)
            .where(table.c.code == code.strip())
            .values(used_count=table.c.used_count + 1, updated_at=func.now())
        )
        async with get_db() as db:
            await db.execute(stmt)
            await db.commit()
|
|
|
|
|
|
async def _save_email_code(email: str, code: str, purpose: str) -> None:
    """Store *code* as the only active verification code for (*email*, *purpose*).

    Earlier codes for the same email and purpose are deleted, and both changes
    are committed together. The new row expires
    ``settings.email_code_expiry_minutes`` minutes from now (naive UTC).
    """
    codes = email_verification_codes_table
    expiry = datetime.utcnow() + timedelta(minutes=settings.email_code_expiry_minutes)
    purge_stmt = delete(codes).where(codes.c.email == email, codes.c.purpose == purpose)
    insert_stmt = codes.insert().values(
        email=email,
        code=code,
        purpose=purpose,
        expires_at=expiry,
        used=False,
    )
    async with get_db() as db:
        await db.execute(purge_stmt)
        await db.execute(insert_stmt)
        await db.commit()
|
|
|
|
|
|
async def _assert_email_code_valid(email: str, code: str, purpose: str) -> None:
    """Ensure *code* is an unused, unexpired verification code for (*email*, *purpose*).

    The newest matching unused row (highest id) is checked.

    Raises:
        HTTPException(400) "邮箱验证码错误": no unused row matches.
        HTTPException(400) "邮箱验证码已过期": the row's ``expires_at`` is in the
            past, or cannot be interpreted as a timestamp (fail closed).
    """
    codes = email_verification_codes_table
    async with get_db() as db:
        result = await db.execute(
            select(codes)
            .where(
                codes.c.email == email,
                codes.c.code == code,
                codes.c.purpose == purpose,
                codes.c.used == False,  # noqa: E712
            )
            .order_by(codes.c.id.desc())
        )
        row = result.mappings().first()
        if not row:
            raise HTTPException(status_code=400, detail="邮箱验证码错误")
        # Normalize like the other timestamp checks in this module. A raw aware
        # or string value would raise TypeError against naive utcnow() (-> 500).
        expires_at = _coerce_naive_datetime(row["expires_at"])
        if expires_at is None or expires_at < datetime.utcnow():
            raise HTTPException(status_code=400, detail="邮箱验证码已过期")
|
|
|
|
|
|
async def _mark_email_code_used(email: str, code: str, purpose: str) -> None:
    """Set ``used=True`` on every verification-code row matching (*email*, *code*, *purpose*)."""
    table = email_verification_codes_table
    conditions = (
        table.c.email == email,
        table.c.code == code,
        table.c.purpose == purpose,
    )
    async with get_db() as db:
        await db.execute(update(table).where(*conditions).values(used=True))
        await db.commit()
|
|
|
|
|
|
@router.post("/login")
async def login(req: LoginRequest):
    """Authenticate with email + password and return an access token plus profile.

    An unknown email, a disabled account and a wrong password all produce the
    same 401 status and message. The password is only checked for an existing,
    active user.
    """
    email = _validate_email(req.email)
    user = await _get_user_by_email(email)
    authenticated = (
        user is not None
        and user["is_active"]
        and verify_password(req.password, user["password_hash"])
    )
    if not authenticated:
        raise HTTPException(status_code=401, detail="邮箱或密码错误")

    claims = {"sub": str(user["id"]), "role": user["role"]}
    profile = {key: user[key] for key in ("id", "username", "email", "role")}
    return {"token": create_access_token(claims), "user": profile}
|
|
|
|
|
|
@router.post("/send-register-code")
async def send_register_code(req: SendRegisterCodeRequest):
    """Email a registration verification code to a not-yet-registered address.

    Checks run in this order:
      1. email format;
      2. email not already registered;
      3. invite code valid;
      4. per-email cooldown (``settings.email_code_cooldown_seconds`` since the
         last "register" code row was created).

    The code is stored only after the email has been sent successfully.

    Raises:
        HTTPException(400): bad email, already registered, or invalid invite code.
        HTTPException(429): a code was sent to this email too recently.
        HTTPException(500): sending the email failed.
    """
    email = _validate_email(req.email)
    if await _get_user_by_email(email):
        raise HTTPException(status_code=400, detail="邮箱已注册")

    invite_row = await _get_invite_code(req.invite_code)
    _assert_invite_code_valid(invite_row)

    codes = email_verification_codes_table
    async with get_db() as db:
        result = await db.execute(
            select(codes)
            .where(codes.c.email == email, codes.c.purpose == "register")
            .order_by(codes.c.id.desc())
        )
        last_code = result.mappings().first()
    if last_code and last_code["created_at"]:
        created_at = _coerce_naive_datetime(last_code["created_at"])
        delta = datetime.utcnow() - created_at if created_at else timedelta.max
        if delta.total_seconds() < settings.email_code_cooldown_seconds:
            raise HTTPException(status_code=429, detail=f"发送过于频繁,请 {settings.email_code_cooldown_seconds} 秒后再试")

    code = _generate_email_code()
    # Distinct local names so the module-level ``sqlalchemy.text`` is not shadowed.
    subject, html_body, text_body = build_register_code_email(code)
    try:
        # send_email is called without ``await``, so it is presumably synchronous.
        # Run it in a worker thread so a slow mail server does not stall the event loop.
        await asyncio.to_thread(send_email, subject=subject, to_email=email, html=html_body, text=text_body)
    except Exception as e:
        # Boundary handler: log the traceback and return a generic 500.
        logger.exception("发送注册验证码失败: %s", e)
        raise HTTPException(status_code=500, detail="验证码发送失败") from e

    await _save_email_code(email, code, "register")
    return {"message": "验证码已发送,请查收邮箱"}
|
|
|
|
|
|
@router.post("/register")
async def register(req: RegisterRequest):
    """Create a user account from email + invite code + email code + password.

    All validation runs first: email format, password length, email not yet
    registered, invite code, email code. Then three writes run in one session
    and are committed together:
      - claiming one use of the invite code;
      - consuming the email code;
      - inserting the user.

    Both claims are conditional UPDATEs whose affected-row count is checked.
    This means two concurrent requests cannot both spend the last use of an
    invite code, or both register with the same email code.

    Raises:
        HTTPException(400): invalid input, email already registered, invite
            code invalid/exhausted, or email code wrong/expired/already used.
    """
    email = _validate_email(req.email)
    _validate_password(req.password)
    if await _get_user_by_email(email):
        raise HTTPException(status_code=400, detail="邮箱已注册")

    invite_code = req.invite_code.strip()
    email_code = req.email_code.strip()

    invite_row = await _get_invite_code(invite_code)
    _assert_invite_code_valid(invite_row)
    await _assert_email_code_valid(email, email_code, "register")

    invites = invite_codes_table
    codes = email_verification_codes_table
    async with get_db() as db:
        # Raising before commit() leaves this session's pending UPDATEs uncommitted.
        # NOTE(review): assumes get_db() discards uncommitted work on exit
        # (SQLAlchemy sessions roll back on close) - confirm.
        if settings.invite_code_required:
            claimed = await db.execute(
                update(invites)
                .where(
                    invites.c.code == invite_code,
                    invites.c.max_uses.is_(None) | (invites.c.used_count < invites.c.max_uses),
                )
                .values(used_count=invites.c.used_count + 1, updated_at=func.now())
            )
            if not claimed.rowcount:
                # Another request took the last use after our validity check.
                raise HTTPException(status_code=400, detail="邀请码已用完")

        consumed = await db.execute(
            update(codes)
            .where(
                codes.c.email == email,
                codes.c.code == email_code,
                codes.c.purpose == "register",
                codes.c.used == False,  # noqa: E712
            )
            .values(used=True)
        )
        if not consumed.rowcount:
            # The code was consumed by a concurrent request after our check.
            raise HTTPException(status_code=400, detail="邮箱验证码错误")

        await db.execute(
            users_table.insert().values(
                username=_build_username_from_email(email),
                email=email,
                password_hash=hash_password(req.password),
                role="user",
                is_active=True,
                invite_code_used=invite_code,
            )
        )
        await db.commit()

    return {"message": "注册成功,请使用邮箱和密码登录"}
|
|
|
|
|
|
@router.get("/me")
async def get_me(current_user: dict = Depends(get_current_user)):
    """Return the profile fields of the authenticated user."""
    exposed = ("id", "username", "email", "role", "is_active")
    return {field: current_user[field] for field in exposed}
|
|
|
|
|
|
@router.post("/change-password")
async def change_password(req: ChangePasswordRequest, current_user: dict = Depends(get_current_user)):
    """Replace the current user's password after verifying the old one.

    Raises HTTPException(400) if the new password is too short or the old
    password does not match the stored hash.
    """
    _validate_password(req.new_password)
    old_matches = verify_password(req.old_password, current_user["password_hash"])
    if not old_matches:
        raise HTTPException(status_code=400, detail="旧密码错误")

    stmt = (
        update(users_table)
        .where(users_table.c.id == current_user["id"])
        .values(password_hash=hash_password(req.new_password), updated_at=func.now())
    )
    async with get_db() as db:
        await db.execute(stmt)
        await db.commit()
    return {"message": "密码修改成功"}
|
|
|
|
|
|
@router.get("/users")
async def list_users(admin: dict = Depends(get_current_admin)):
    """List every user ordered by id (admin only).

    ``created_at`` is ISO-formatted, or ``None`` when unset. A missing
    ``invite_code_used`` is reported as ``""``.
    """
    async with get_db() as db:
        result = await db.execute(select(users_table).order_by(users_table.c.id))
        users = []
        for row in result.mappings().all():
            created = row["created_at"]
            users.append(
                {
                    "id": row["id"],
                    "username": row["username"],
                    "email": row["email"],
                    "role": row["role"],
                    "is_active": row["is_active"],
                    "invite_code_used": row.get("invite_code_used") or "",
                    "created_at": created.isoformat() if created else None,
                }
            )
        return users
|
|
|
|
|
|
@router.delete("/users/{user_id}")
async def disable_user(user_id: int, admin: dict = Depends(get_current_admin)):
    """Deactivate a user by setting ``is_active=False`` (admin only).

    Raises HTTPException(400) when an admin targets themself, and 404 when no
    user has this id.
    """
    if user_id == admin["id"]:
        raise HTTPException(status_code=400, detail="不能禁用自己")

    by_id = users_table.c.id == user_id
    async with get_db() as db:
        target = (await db.execute(select(users_table).where(by_id))).mappings().first()
        if target is None:
            raise HTTPException(status_code=404, detail="用户不存在")
        await db.execute(update(users_table).where(by_id).values(is_active=False, updated_at=func.now()))
        await db.commit()
    return {"message": f"用户 {target['email']} 已禁用"}
|
|
|
|
|
|
@router.post("/users/{user_id}/reset-password")
async def reset_password(user_id: int, admin: dict = Depends(get_current_admin)):
    """Assign a random 10-character alphanumeric password to a user (admin only).

    The plaintext password is returned once in the response. It is drawn with
    :mod:`secrets` rather than the predictable :mod:`random` PRNG, because it
    is a credential.

    Raises HTTPException(404) when no user has this id.
    """
    alphabet = string.ascii_letters + string.digits
    new_password = "".join(secrets.choice(alphabet) for _ in range(10))
    async with get_db() as db:
        result = await db.execute(select(users_table).where(users_table.c.id == user_id))
        user = result.mappings().first()
        if not user:
            raise HTTPException(status_code=404, detail="用户不存在")
        await db.execute(
            update(users_table)
            .where(users_table.c.id == user_id)
            .values(password_hash=hash_password(new_password), updated_at=func.now())
        )
        await db.commit()
    return {
        "email": user["email"],
        "password": new_password,
        "message": "请妥善保管新密码,此密码仅显示一次",
    }
|
|
|
|
|
|
@router.get("/invite-codes")
async def list_invite_codes(admin: dict = Depends(get_current_admin)):
    """List all invite codes, highest id first (admin only)."""

    def serialize(row) -> dict:
        # One API record per invite_codes row; NULL description becomes "".
        created = row["created_at"]
        return {
            "id": row["id"],
            "code": row["code"],
            "description": row["description"] or "",
            "is_active": row["is_active"],
            "max_uses": row["max_uses"],
            "used_count": row["used_count"],
            "created_at": created.isoformat() if created else None,
        }

    query = select(invite_codes_table).order_by(invite_codes_table.c.id.desc())
    async with get_db() as db:
        rows = (await db.execute(query)).mappings().all()
        return [serialize(row) for row in rows]
|
|
|
|
|
|
@router.post("/invite-codes")
async def create_invite_code(req: CreateInviteCodeRequest, admin: dict = Depends(get_current_admin)):
    """Create a new active invite code (admin only).

    The code is stripped of surrounding whitespace before the length check,
    the uniqueness check and storage. ``max_uses`` is clamped to at least 1.

    Raises:
        HTTPException(400): stripped code shorter than 4 characters, or the
            code already exists.
    """
    code = req.code.strip()
    # The model's min_length=4 is evaluated before stripping. Without this check,
    # "    " would be stored as a blank code (and " ab " as a 2-char one).
    if len(code) < 4:
        raise HTTPException(status_code=400, detail="邀请码至少 4 位")

    async with get_db() as db:
        result = await db.execute(select(invite_codes_table).where(invite_codes_table.c.code == code))
        if result.first():
            raise HTTPException(status_code=400, detail="邀请码已存在")
        await db.execute(
            invite_codes_table.insert().values(
                code=code,
                description=req.description.strip(),
                max_uses=max(1, req.max_uses),
                used_count=0,
                is_active=True,
                created_by=admin["id"],
            )
        )
        await db.commit()
    return {"message": "邀请码创建成功", "code": code}
|
|
|
|
|
|
@router.post("/invite-codes/{invite_id}/toggle")
async def toggle_invite_code(invite_id: int, admin: dict = Depends(get_current_admin)):
    """Flip the ``is_active`` flag of an invite code (admin only).

    Raises HTTPException(404) if no invite code has this id.
    """
    by_id = invite_codes_table.c.id == invite_id
    async with get_db() as db:
        current = (await db.execute(select(invite_codes_table).where(by_id))).mappings().first()
        if current is None:
            raise HTTPException(status_code=404, detail="邀请码不存在")
        flipped = not current["is_active"]
        await db.execute(
            update(invite_codes_table).where(by_id).values(is_active=flipped, updated_at=func.now())
        )
        await db.commit()
    return {"message": "邀请码状态已更新"}
|
|
|
|
|
|
@router.get("/data-stats")
async def get_data_stats(admin: dict = Depends(get_current_admin)):
    """Return row counts for the main tables plus recommendation stats (admin only).

    The extra fields are:
      - the number of recommendations scoring below 60;
      - the earliest and latest recommendation dates, as strings, or ``""``
        when there are none.
    """
    # response key -> table whose rows are counted (order = query/response order)
    counted_tables = {
        "recommendations": "recommendations",
        "tracking": "recommendation_tracking",
        "sector_heat": "sector_heat",
        "market_temperature": "market_temperature",
        "stock_diagnoses": "stock_diagnoses",
        "watchlist_analyses": "watchlist_analyses",
        "users": "users",
        "invite_codes": "invite_codes",
        "error_logs": "error_logs",
        "scan_logs": "scan_process_logs",
    }
    async with get_db() as db:

        async def scalar(sql: str):
            return (await db.execute(text(sql))).scalar()

        stats: dict = {}
        for key, table in counted_tables.items():
            stats[key] = (await scalar(f"SELECT COUNT(*) FROM {table}")) or 0
        low_score = (await scalar("SELECT COUNT(*) FROM recommendations WHERE score < 60")) or 0
        latest = (await scalar("SELECT MAX(date(created_at)) FROM recommendations")) or ""
        earliest = (await scalar("SELECT MIN(date(created_at)) FROM recommendations")) or ""

    stats["low_score_count"] = low_score
    stats["latest_date"] = str(latest)
    stats["earliest_date"] = str(earliest)
    return stats
|
|
|
|
|
|
@router.post("/data-reset")
async def data_reset(req: DataResetRequest, admin: dict = Depends(get_current_admin)):
    """Bulk-delete analytics data according to ``req.mode`` (admin only).

    Modes:
        all / recommendations / market_cache / diagnostics / logs:
            empty a fixed group of tables.
        date_range:
            delete rows dated strictly before ``req.before_date``, which is
            required and must be ``YYYY-MM-DD``.
        low_score:
            delete recommendations (and their tracking rows) with score below
            ``req.min_score``, defaulting to 60 only when omitted.

    All deletes run in one session and are committed together.

    Returns:
        ``{"status": "ok", "mode": ..., "deleted": {name: rowcount}}``.

    Raises:
        HTTPException(400): unknown mode, or missing/malformed ``before_date``.
    """
    # Fixed table groups, emptied in the listed order (tracking before
    # recommendations, i.e. referencing rows first). Names never come from the request.
    wipe_groups: dict[str, tuple[str, ...]] = {
        "all": (
            "recommendation_tracking",
            "recommendations",
            "sector_heat",
            "market_temperature",
            "stock_diagnoses",
            "watchlist_analyses",
        ),
        "recommendations": ("recommendation_tracking", "recommendations"),
        "market_cache": ("sector_heat", "market_temperature"),
        "diagnostics": ("stock_diagnoses", "watchlist_analyses"),
        "logs": ("error_logs", "scan_process_logs", "research_observations"),
    }
    # (result key, statement) pairs for mode="date_range".
    date_range_deletes = (
        ("tracking", "DELETE FROM recommendation_tracking WHERE track_date < :bd"),
        ("recommendations", "DELETE FROM recommendations WHERE date(created_at) < :bd"),
        ("sector_heat", "DELETE FROM sector_heat WHERE trade_date < :bd"),
        ("market_temperature", "DELETE FROM market_temperature WHERE trade_date < :bd"),
        ("stock_diagnoses", "DELETE FROM stock_diagnoses WHERE date(created_at) < :bd"),
        ("watchlist_analyses", "DELETE FROM watchlist_analyses WHERE date(created_at) < :bd"),
    )

    mode = req.mode
    before_date: str | None = None
    if mode == "date_range":
        if not req.before_date:
            raise HTTPException(status_code=400, detail="date_range 模式需要 before_date 参数")
        try:
            # Dates are compared as strings. An unvalidated value like "abc"
            # sorts after every ISO date and would delete all dated rows.
            before_date = datetime.strptime(req.before_date.strip(), "%Y-%m-%d").date().isoformat()
        except ValueError as exc:
            raise HTTPException(status_code=400, detail="before_date 格式应为 YYYY-MM-DD") from exc
    elif mode != "low_score" and mode not in wipe_groups:
        raise HTTPException(status_code=400, detail=f"不支持的模式: {mode}")
    # Only None falls back to 60. An explicit 0 must mean "below 0", not "below 60".
    threshold = 60 if req.min_score is None else req.min_score

    deleted: dict[str, int] = {}
    async with get_db() as db:
        if mode == "date_range":
            for key, sql in date_range_deletes:
                result = await db.execute(text(sql), {"bd": before_date})
                deleted[key] = result.rowcount or 0
        elif mode == "low_score":
            result = await db.execute(
                text("DELETE FROM recommendation_tracking WHERE recommendation_id IN (SELECT id FROM recommendations WHERE score < :ms)"),
                {"ms": threshold},
            )
            deleted["tracking"] = result.rowcount or 0
            result = await db.execute(text("DELETE FROM recommendations WHERE score < :ms"), {"ms": threshold})
            deleted["recommendations"] = result.rowcount or 0
        else:
            for table in wipe_groups[mode]:
                result = await db.execute(text(f"DELETE FROM {table}"))
                deleted[table] = result.rowcount or 0
        await db.commit()

    logger.info("管理员 %s 执行数据重置: mode=%s deleted=%s", admin["email"], mode, deleted)
    return {"status": "ok", "mode": mode, "deleted": deleted}
|