# astock-agent/backend/app/api/auth.py
# 2026-05-28 23:32:52 +08:00 | 512 lines | 20 KiB | Python
"""认证 API。
邮箱+密码登录。
邀请码 + 邮箱验证码 + 密码注册。
"""
from __future__ import annotations

import asyncio
import logging
import random
import re
import secrets
import string
from datetime import datetime, timedelta, timezone

from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy import select, update, text, func, delete
from sqlalchemy.exc import IntegrityError

from app.config import settings
from app.core.auth import hash_password, verify_password, create_access_token
from app.core.deps import get_current_user, get_current_admin
from app.core.email import send_email, build_register_code_email
from app.db.database import get_db
from app.db.tables import users_table, email_verification_codes_table, invite_codes_table
# Module logger; used below for email-send failures and admin data-reset audit lines.
logger = logging.getLogger(__name__)
# Every endpoint in this module is mounted under /api/auth and tagged "auth".
router = APIRouter(prefix="/api/auth", tags=["auth"])
class LoginRequest(BaseModel):
    # Body of POST /api/auth/login; the handler normalizes and validates the email.
    email: str = Field(...)
    password: str = Field(...)
class SendRegisterCodeRequest(BaseModel):
    # Body of POST /api/auth/send-register-code.
    email: str = Field(...)
    # Invite code string, 4-64 characters (checked again against the DB by the handler).
    invite_code: str = Field(..., max_length=64, min_length=4)
class RegisterRequest(BaseModel):
    # Body of POST /api/auth/register.
    email: str = Field(...)
    invite_code: str = Field(..., max_length=64, min_length=4)
    # Exactly six characters, the length produced by _generate_email_code.
    email_code: str = Field(..., max_length=6, min_length=6)
    # Schema-level floor only; the handler also enforces settings.auth_min_password_length.
    password: str = Field(..., min_length=6)
class ChangePasswordRequest(BaseModel):
    # Body of POST /api/auth/change-password.
    old_password: str = Field(...)
    # Schema-level floor only; the handler also enforces settings.auth_min_password_length.
    new_password: str = Field(..., min_length=6)
class CreateInviteCodeRequest(BaseModel):
    # Body of POST /api/auth/invite-codes (admin only).
    code: str = Field(..., max_length=64, min_length=4)
    description: str = Field(default="")
    # create_invite_code clamps values below 1 up to 1.
    max_uses: int = Field(default=1)
class DataResetRequest(BaseModel):
    # Body of POST /api/auth/data-reset (admin only).
    # mode is one of: all, recommendations, market_cache, diagnostics, logs,
    # date_range, low_score; anything else is rejected by the handler with 400.
    mode: str = Field(...)
    # Cut-off date string; required by (and only read in) date_range mode.
    before_date: str | None = Field(default=None)
    # Score threshold for low_score mode; the handler falls back to 60 when omitted.
    min_score: int | None = Field(default=None)
def _normalize_email(email: str) -> str:
return str(email or "").strip().lower()
def _validate_email(email: str) -> str:
    """Normalize ``email`` and check it has the rough shape ``local@domain.tld``.

    Returns:
        The normalized (trimmed, lowercased) email.

    Raises:
        HTTPException(400): when the normalized value does not match the pattern.
    """
    candidate = _normalize_email(email)
    # Deliberately loose: one "@", no whitespace, at least one dot after the "@".
    shape_ok = re.fullmatch(r"[^@\s]+@[^@\s]+\.[^@\s]+", candidate) is not None
    if shape_ok:
        return candidate
    raise HTTPException(status_code=400, detail="邮箱格式错误")
def _validate_password(password: str) -> None:
    """Reject passwords shorter than ``settings.auth_min_password_length``.

    ``None`` is treated as an empty password.

    Raises:
        HTTPException(400): when the password is too short.
    """
    min_length = settings.auth_min_password_length
    if len(password or "") < min_length:
        # The original message ended after the number ("密码至少 6"); add the unit.
        raise HTTPException(status_code=400, detail=f"密码至少 {min_length} 位")
def _build_username_from_email(email: str) -> str:
    """Derive the stored username for a new account: the normalized email itself."""
    username = _normalize_email(email)
    return username
def _generate_email_code() -> str:
return "".join(random.choices(string.digits, k=6))
def _coerce_naive_datetime(value) -> datetime | None:
if value is None:
return None
if isinstance(value, datetime):
return value.replace(tzinfo=None)
if isinstance(value, str):
return datetime.fromisoformat(value.replace("Z", "+00:00")).replace(tzinfo=None)
return None
async def _get_user_by_email(email: str) -> dict | None:
    """Fetch the user row whose email equals the normalized ``email``.

    Returns:
        The row as a plain ``dict``, or ``None`` when no user matches.
    """
    lookup = select(users_table).where(users_table.c.email == _normalize_email(email))
    async with get_db() as db:
        found = (await db.execute(lookup)).mappings().first()
        return dict(found) if found else None
async def _get_invite_code(code: str) -> dict | None:
    """Fetch the invite-code row matching ``code`` after trimming whitespace.

    Returns:
        The row as a plain ``dict``, or ``None`` when the code does not exist.
    """
    lookup = select(invite_codes_table).where(invite_codes_table.c.code == code.strip())
    async with get_db() as db:
        found = (await db.execute(lookup)).mappings().first()
        return dict(found) if found else None
def _assert_invite_code_valid(invite_row: dict | None) -> None:
    """Ensure an invite-code row is usable; a no-op when invite codes are not required.

    Checks run in this order: row exists, ``is_active`` is truthy, ``used_count`` is
    below ``max_uses`` (when ``max_uses`` is not None), and ``expires_at`` (if set)
    is not in the past (compared as naive UTC).

    Raises:
        HTTPException(400): with a message naming the first failed check.
    """
    if not settings.invite_code_required:
        return
    if not invite_row:
        failure = "邀请码无效"
    elif not invite_row["is_active"]:
        failure = "邀请码已停用"
    elif invite_row["max_uses"] is not None and invite_row["used_count"] >= invite_row["max_uses"]:
        failure = "邀请码已用完"
    else:
        deadline = _coerce_naive_datetime(invite_row.get("expires_at"))
        failure = "邀请码已过期" if deadline and deadline < datetime.utcnow() else None
    if failure is not None:
        raise HTTPException(status_code=400, detail=failure)
async def _consume_invite_code(code: str) -> None:
    """Increment ``used_count`` (and touch ``updated_at``) for invite ``code``.

    Does nothing when ``settings.invite_code_required`` is false. The increment is
    unconditional: it does not re-check ``max_uses``.
    """
    if settings.invite_code_required:
        invites = invite_codes_table
        bump = (
            update(invites)
            .where(invites.c.code == code.strip())
            .values(used_count=invites.c.used_count + 1, updated_at=func.now())
        )
        async with get_db() as db:
            await db.execute(bump)
            await db.commit()
async def _save_email_code(email: str, code: str, purpose: str) -> None:
    """Store ``code`` as the only verification code for (``email``, ``purpose``).

    Any earlier codes for the same pair are deleted first. The new row expires
    ``settings.email_code_expiry_minutes`` after now (naive UTC) and starts unused.
    Both statements are committed together.
    """
    codes = email_verification_codes_table
    expires_at = datetime.utcnow() + timedelta(minutes=settings.email_code_expiry_minutes)
    drop_previous = delete(codes).where(codes.c.email == email, codes.c.purpose == purpose)
    add_new = codes.insert().values(
        email=email,
        code=code,
        purpose=purpose,
        expires_at=expires_at,
        used=False,
    )
    async with get_db() as db:
        for statement in (drop_previous, add_new):
            await db.execute(statement)
        await db.commit()
async def _assert_email_code_valid(email: str, code: str, purpose: str) -> None:
    """Ensure an unused, unexpired verification code exists for (email, code, purpose).

    Only the newest matching unused row is considered. This function does not mark
    the code used.

    NOTE(review): there is no attempt counter here, so a 6-digit code can be guessed
    by repeated requests unless rate limiting exists upstream; confirm.

    Raises:
        HTTPException(400): "邮箱验证码错误" when no unused row matches;
            "邮箱验证码已过期" when the newest match has expired or has no expiry.
    """
    codes = email_verification_codes_table
    async with get_db() as db:
        result = await db.execute(
            select(codes)
            .where(
                codes.c.email == email,
                codes.c.code == code,
                codes.c.purpose == purpose,
                codes.c.used == False,  # noqa: E712
            )
            .order_by(codes.c.id.desc())
            .limit(1)
        )
        row = result.mappings().first()
    if not row:
        raise HTTPException(status_code=400, detail="邮箱验证码错误")
    # Normalize like the rest of this module does. A driver may return an aware
    # datetime or an ISO string, and comparing either directly with the naive
    # utcnow() raises TypeError (-> 500).
    expires_at = _coerce_naive_datetime(row["expires_at"])
    if expires_at is None or expires_at < datetime.utcnow():
        raise HTTPException(status_code=400, detail="邮箱验证码已过期")
async def _mark_email_code_used(email: str, code: str, purpose: str) -> None:
    """Flag every row matching (email, code, purpose) as used, then commit."""
    codes = email_verification_codes_table
    match = (
        codes.c.email == email,
        codes.c.code == code,
        codes.c.purpose == purpose,
    )
    statement = update(codes).where(*match).values(used=True)
    async with get_db() as db:
        await db.execute(statement)
        await db.commit()
@router.post("/login")
async def login(req: LoginRequest):
    """Authenticate with email + password and issue an access token.

    Unknown emails, disabled accounts and wrong passwords all produce the same 401
    detail. The password is only verified for an existing, active user.

    Returns:
        ``{"token": <access token>, "user": {id, username, email, role}}``.
    """
    email = _validate_email(req.email)
    user = await _get_user_by_email(email)
    authenticated = (
        user is not None
        and user["is_active"]
        and verify_password(req.password, user["password_hash"])
    )
    if not authenticated:
        raise HTTPException(status_code=401, detail="邮箱或密码错误")
    claims = {"sub": str(user["id"]), "role": user["role"]}
    token = create_access_token(claims)
    profile = {key: user[key] for key in ("id", "username", "email", "role")}
    return {"token": token, "user": profile}
@router.post("/send-register-code")
async def send_register_code(req: SendRegisterCodeRequest):
    """Email a registration verification code to a not-yet-registered address.

    Checks, in order:
    1. The email is well-formed and unregistered.
    2. The invite code is valid.
    3. The last "register" code for this email is older than
       ``settings.email_code_cooldown_seconds``.

    The code is stored only after the email was sent successfully.

    Raises:
        HTTPException(400): bad email, email already registered, or invalid invite code.
        HTTPException(429): a code was sent within the cooldown window.
        HTTPException(500): email delivery failed.
    """
    email = _validate_email(req.email)
    if await _get_user_by_email(email):
        raise HTTPException(status_code=400, detail="邮箱已注册")
    invite_row = await _get_invite_code(req.invite_code)
    _assert_invite_code_valid(invite_row)

    codes = email_verification_codes_table
    async with get_db() as db:
        result = await db.execute(
            select(codes)
            .where(codes.c.email == email, codes.c.purpose == "register")
            .order_by(codes.c.id.desc())
            .limit(1)
        )
        last_code = result.mappings().first()
    if last_code and last_code["created_at"]:
        created_at = _coerce_naive_datetime(last_code["created_at"])
        delta = datetime.utcnow() - created_at if created_at else timedelta.max
        if delta.total_seconds() < settings.email_code_cooldown_seconds:
            raise HTTPException(status_code=429, detail=f"发送过于频繁,请 {settings.email_code_cooldown_seconds} 秒后再试")

    code = _generate_email_code()
    # Named text_body (not `text`) so the module-level sqlalchemy `text` is not shadowed.
    subject, html_body, text_body = build_register_code_email(code)
    try:
        # send_email is called synchronously (no await) in this codebase. Run it in a
        # worker thread so a slow mail server does not stall the event loop.
        await asyncio.to_thread(
            send_email, subject=subject, to_email=email, html=html_body, text=text_body
        )
    except Exception as e:
        logger.exception("发送注册验证码失败: %s", e)
        raise HTTPException(status_code=500, detail="验证码发送失败") from e
    await _save_email_code(email, code, "register")
    return {"message": "验证码已发送,请查收邮箱"}
@router.post("/register")
async def register(req: RegisterRequest):
    """Create a user from email + invite code + email verification code + password.

    Read-only pre-checks run first and give specific error messages: email format,
    password length, email not registered, invite code valid, email code valid.

    A single transaction then does three things:
    1. Claims the email code (unused -> used).
    2. Consumes one invite use, when invite codes are required; the update is
       guarded by ``used_count < max_uses``.
    3. Inserts the user.

    The original did these in three separate commits. Concurrent requests could then
    reuse one email code or push an invite past ``max_uses``, and a duplicate insert
    surfaced as a 500.

    Raises:
        HTTPException(400): any validation failure, a code claimed concurrently,
            an exhausted invite code, or a duplicate account.
    """
    email = _validate_email(req.email)
    _validate_password(req.password)
    invite_code = req.invite_code.strip()
    email_code = req.email_code.strip()
    if await _get_user_by_email(email):
        raise HTTPException(status_code=400, detail="邮箱已注册")
    invite_row = await _get_invite_code(invite_code)
    _assert_invite_code_valid(invite_row)
    await _assert_email_code_valid(email, email_code, "register")
    username = _build_username_from_email(email)
    # Hash before opening the transaction so slow hashing does not hold DB locks.
    password_hash = hash_password(req.password)

    codes = email_verification_codes_table
    invites = invite_codes_table
    # NOTE(review): relies on get_db() NOT committing when the block exits with an
    # exception (the original always commits explicitly), so a failed step rolls
    # back the earlier claims. Confirm.
    async with get_db() as db:
        claimed = await db.execute(
            update(codes)
            .where(
                codes.c.email == email,
                codes.c.code == email_code,
                codes.c.purpose == "register",
                codes.c.used == False,  # noqa: E712
            )
            .values(used=True)
        )
        if not claimed.rowcount:
            # Another request consumed this code between the check and now.
            raise HTTPException(status_code=400, detail="邮箱验证码错误")
        if settings.invite_code_required:
            consumed = await db.execute(
                update(invites)
                .where(
                    invites.c.code == invite_code,
                    invites.c.max_uses.is_(None) | (invites.c.used_count < invites.c.max_uses),
                )
                .values(used_count=invites.c.used_count + 1, updated_at=func.now())
            )
            if not consumed.rowcount:
                raise HTTPException(status_code=400, detail="邀请码已用完")
        try:
            await db.execute(
                users_table.insert().values(
                    username=username,
                    email=email,
                    password_hash=password_hash,
                    role="user",
                    is_active=True,
                    invite_code_used=invite_code,
                )
            )
        except IntegrityError as exc:
            # Most likely a concurrent registration of the same email/username.
            logger.warning("注册插入冲突: %s", exc)
            raise HTTPException(status_code=400, detail="邮箱已注册") from exc
        await db.commit()
    return {"message": "注册成功,请使用邮箱和密码登录"}
@router.get("/me")
async def get_me(current_user: dict = Depends(get_current_user)):
    """Return the authenticated user's public profile (no password hash)."""
    public_fields = ("id", "username", "email", "role", "is_active")
    return {field: current_user[field] for field in public_fields}
@router.post("/change-password")
async def change_password(req: ChangePasswordRequest, current_user: dict = Depends(get_current_user)):
    """Change the current user's password after verifying the old one.

    Raises:
        HTTPException(400): new password too short, or old password incorrect.
    """
    _validate_password(req.new_password)
    old_password_ok = verify_password(req.old_password, current_user["password_hash"])
    if not old_password_ok:
        raise HTTPException(status_code=400, detail="旧密码错误")
    statement = (
        update(users_table)
        .where(users_table.c.id == current_user["id"])
        .values(password_hash=hash_password(req.new_password), updated_at=func.now())
    )
    async with get_db() as db:
        await db.execute(statement)
        await db.commit()
    return {"message": "密码修改成功"}
@router.get("/users")
async def list_users(admin: dict = Depends(get_current_admin)):
    """List all users ordered by id (admin only).

    Each entry has id, username, email, role, is_active, invite_code_used ("" when
    unset) and created_at as an ISO string (None when missing).
    """
    query = select(users_table).order_by(users_table.c.id)
    async with get_db() as db:
        rows = (await db.execute(query)).mappings().all()
    payload = []
    for row in rows:
        created = row["created_at"]
        payload.append(
            {
                "id": row["id"],
                "username": row["username"],
                "email": row["email"],
                "role": row["role"],
                "is_active": row["is_active"],
                "invite_code_used": row.get("invite_code_used") or "",
                "created_at": created.isoformat() if created else None,
            }
        )
    return payload
@router.delete("/users/{user_id}")
async def disable_user(user_id: int, admin: dict = Depends(get_current_admin)):
    """Soft-disable a user by setting ``is_active`` to false (admin only).

    Raises:
        HTTPException(400): an admin tries to disable their own account.
        HTTPException(404): no user with ``user_id``.
    """
    if user_id == admin["id"]:
        raise HTTPException(status_code=400, detail="不能禁用自己")
    target_filter = users_table.c.id == user_id
    async with get_db() as db:
        target = (await db.execute(select(users_table).where(target_filter))).mappings().first()
        if not target:
            raise HTTPException(status_code=404, detail="用户不存在")
        await db.execute(
            update(users_table).where(target_filter).values(is_active=False, updated_at=func.now())
        )
        await db.commit()
    return {"message": f"用户 {target['email']} 已禁用"}
@router.post("/users/{user_id}/reset-password")
async def reset_password(user_id: int, admin: dict = Depends(get_current_admin)):
    """Replace a user's password with a freshly generated 10-character one (admin only).

    Only the hash is stored. The plaintext is returned once in the response.

    Raises:
        HTTPException(404): no user with ``user_id``.
    """
    # Temporary credentials must be unguessable. Draw from the OS CSPRNG via
    # `secrets` instead of the predictable Mersenne Twister behind random.choices.
    alphabet = string.ascii_letters + string.digits
    new_password = "".join(secrets.choice(alphabet) for _ in range(10))
    async with get_db() as db:
        result = await db.execute(select(users_table).where(users_table.c.id == user_id))
        user = result.mappings().first()
        if not user:
            raise HTTPException(status_code=404, detail="用户不存在")
        await db.execute(
            update(users_table)
            .where(users_table.c.id == user_id)
            .values(password_hash=hash_password(new_password), updated_at=func.now())
        )
        await db.commit()
    return {
        "email": user["email"],
        "password": new_password,
        "message": "请妥善保管新密码,此密码仅显示一次",
    }
@router.get("/invite-codes")
async def list_invite_codes(admin: dict = Depends(get_current_admin)):
    """List all invite codes, newest id first (admin only).

    Each entry has id, code, description ("" when unset), is_active, max_uses,
    used_count and created_at as an ISO string (None when missing).
    """
    query = select(invite_codes_table).order_by(invite_codes_table.c.id.desc())
    async with get_db() as db:
        rows = (await db.execute(query)).mappings().all()
    payload = []
    for row in rows:
        created = row["created_at"]
        payload.append(
            {
                "id": row["id"],
                "code": row["code"],
                "description": row["description"] or "",
                "is_active": row["is_active"],
                "max_uses": row["max_uses"],
                "used_count": row["used_count"],
                "created_at": created.isoformat() if created else None,
            }
        )
    return payload
@router.post("/invite-codes")
async def create_invite_code(req: CreateInviteCodeRequest, admin: dict = Depends(get_current_admin)):
    """Create a new active invite code owned by the calling admin.

    The code and description are whitespace-trimmed. ``max_uses`` is clamped to at
    least 1.

    Raises:
        HTTPException(400): an invite code with the same (trimmed) text exists.
    """
    code = req.code.strip()
    async with get_db() as db:
        existing = await db.execute(select(invite_codes_table).where(invite_codes_table.c.code == code))
        if existing.first():
            raise HTTPException(status_code=400, detail="邀请码已存在")
        new_row = invite_codes_table.insert().values(
            code=code,
            description=req.description.strip(),
            max_uses=max(1, req.max_uses),
            used_count=0,
            is_active=True,
            created_by=admin["id"],
        )
        await db.execute(new_row)
        await db.commit()
    return {"message": "邀请码创建成功", "code": code}
@router.post("/invite-codes/{invite_id}/toggle")
async def toggle_invite_code(invite_id: int, admin: dict = Depends(get_current_admin)):
    """Flip an invite code's ``is_active`` flag (admin only).

    Raises:
        HTTPException(404): no invite code with ``invite_id``.
    """
    by_id = invite_codes_table.c.id == invite_id
    async with get_db() as db:
        found = (await db.execute(select(invite_codes_table).where(by_id))).mappings().first()
        if not found:
            raise HTTPException(status_code=404, detail="邀请码不存在")
        flipped = not found["is_active"]
        await db.execute(
            update(invite_codes_table).where(by_id).values(is_active=flipped, updated_at=func.now())
        )
        await db.commit()
    return {"message": "邀请码状态已更新"}
@router.get("/data-stats")
async def get_data_stats(admin: dict = Depends(get_current_admin)):
    """Return row counts for the main tables plus recommendation summary stats (admin only).

    The response contains, in this order:
    * one count per entry of ``counted_tables``;
    * ``low_score_count``: recommendations with score < 60;
    * ``latest_date`` / ``earliest_date``: the max/min ``date(created_at)`` over
      recommendations, as strings ("" when there are none).
    """
    # response key -> table counted under it; dict order fixes the response key order
    counted_tables = {
        "recommendations": "recommendations",
        "tracking": "recommendation_tracking",
        "sector_heat": "sector_heat",
        "market_temperature": "market_temperature",
        "stock_diagnoses": "stock_diagnoses",
        "watchlist_analyses": "watchlist_analyses",
        "users": "users",
        "invite_codes": "invite_codes",
        "error_logs": "error_logs",
        "scan_logs": "scan_process_logs",
    }
    stats: dict = {}
    async with get_db() as db:
        for key, table in counted_tables.items():
            # Table names come from the fixed mapping above, never from the request.
            stats[key] = (await db.execute(text(f"SELECT COUNT(*) FROM {table}"))).scalar() or 0
        low_score = (await db.execute(text("SELECT COUNT(*) FROM recommendations WHERE score < 60"))).scalar()
        latest = (await db.execute(text("SELECT MAX(date(created_at)) FROM recommendations"))).scalar()
        earliest = (await db.execute(text("SELECT MIN(date(created_at)) FROM recommendations"))).scalar()
    stats["low_score_count"] = low_score or 0
    stats["latest_date"] = str(latest or "")
    stats["earliest_date"] = str(earliest or "")
    return stats
@router.post("/data-reset")
async def data_reset(req: DataResetRequest, admin: dict = Depends(get_current_admin)):
    """Delete stored data according to ``req.mode`` (admin only); commits once at the end.

    Modes:
        all / recommendations / market_cache / diagnostics / logs:
            delete every row of a fixed set of tables (see ``wipe_modes``).
        date_range:
            delete rows dated before ``req.before_date`` (YYYY-MM-DD or YYYYMMDD)
            from the tables in ``date_range_deletes``.
        low_score:
            delete recommendations with ``score < req.min_score`` (60 when omitted),
            together with their tracking rows.

    Returns:
        ``{"status": "ok", "mode": mode, "deleted": {name: deleted_row_count}}``.

    Raises:
        HTTPException(400): unknown mode, or a missing/malformed ``before_date``.
    """
    # Modes that wipe whole tables. Order matters: recommendation_tracking rows are
    # removed before the recommendations they point at.
    wipe_modes: dict[str, tuple[str, ...]] = {
        "all": (
            "recommendation_tracking",
            "recommendations",
            "sector_heat",
            "market_temperature",
            "stock_diagnoses",
            "watchlist_analyses",
        ),
        "recommendations": ("recommendation_tracking", "recommendations"),
        "market_cache": ("sector_heat", "market_temperature"),
        "diagnostics": ("stock_diagnoses", "watchlist_analyses"),
        "logs": ("error_logs", "scan_process_logs", "research_observations"),
    }
    # (result key, parameterized DELETE) pairs for date_range mode; :bd is before_date.
    date_range_deletes = (
        ("tracking", "DELETE FROM recommendation_tracking WHERE track_date < :bd"),
        ("recommendations", "DELETE FROM recommendations WHERE date(created_at) < :bd"),
        ("sector_heat", "DELETE FROM sector_heat WHERE trade_date < :bd"),
        ("market_temperature", "DELETE FROM market_temperature WHERE trade_date < :bd"),
        ("stock_diagnoses", "DELETE FROM stock_diagnoses WHERE date(created_at) < :bd"),
        ("watchlist_analyses", "DELETE FROM watchlist_analyses WHERE date(created_at) < :bd"),
    )

    if req.mode not in wipe_modes and req.mode not in ("date_range", "low_score"):
        raise HTTPException(status_code=400, detail=f"不支持的模式: {req.mode}")
    if req.mode == "date_range":
        if not req.before_date:
            raise HTTPException(status_code=400, detail="date_range 模式需要 before_date 参数")
        # before_date is compared as a string in SQL. Without this check, values like
        # "z" or "9" sort after every real date and would wipe all rows.
        if not re.fullmatch(r"\d{4}-\d{2}-\d{2}|\d{8}", req.before_date):
            raise HTTPException(status_code=400, detail="before_date 格式错误,应为 YYYY-MM-DD")
    # `req.min_score or 60` turned an explicit 0 into 60 (deleting everything below
    # 60); only fall back to the default when the field was omitted.
    threshold = 60 if req.min_score is None else req.min_score

    deleted: dict[str, int] = {}
    async with get_db() as db:
        if req.mode in wipe_modes:
            for table in wipe_modes[req.mode]:
                # Table names are the hard-coded literals above, never request data.
                result = await db.execute(text(f"DELETE FROM {table}"))
                deleted[table] = result.rowcount or 0
        elif req.mode == "date_range":
            for key, sql in date_range_deletes:
                result = await db.execute(text(sql), {"bd": req.before_date})
                deleted[key] = result.rowcount or 0
        else:  # low_score
            result = await db.execute(
                text("DELETE FROM recommendation_tracking WHERE recommendation_id IN (SELECT id FROM recommendations WHERE score < :ms)"),
                {"ms": threshold},
            )
            deleted["tracking"] = result.rowcount or 0
            result = await db.execute(text("DELETE FROM recommendations WHERE score < :ms"), {"ms": threshold})
            deleted["recommendations"] = result.rowcount or 0
        await db.commit()
    logger.info("管理员 %s 执行数据重置: mode=%s deleted=%s", admin["email"], req.mode, deleted)
    return {"status": "ok", "mode": req.mode, "deleted": deleted}