108 lines
3.0 KiB
Python
108 lines
3.0 KiB
Python
import hashlib
|
|
import secrets
|
|
from datetime import datetime, timedelta, timezone
|
|
|
|
from fastapi import HTTPException
|
|
from sqlalchemy import select, update
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.db.models import EmailVerificationCode
|
|
|
|
|
|
# Purpose tag stored on verification codes; the default ``purpose`` for
# issuing and verifying codes in this module.
ACTIVATION_EMAIL_PURPOSE: str = "account_activation"
# Lifetime of an issued code in minutes (added to the issue time to form
# ``expires_at``).
EMAIL_CODE_TTL_MINUTES: int = 10
# Minimum spacing in seconds between two codes issued for the same
# email + purpose pair.
EMAIL_CODE_RESEND_SECONDS: int = 60
|
|
|
|
|
|
def _hash_code(code: str) -> str:
    """Return the hex-encoded SHA-256 digest of ``code`` encoded as UTF-8.

    NOTE(review): the digest is unsalted and the generated codes are six
    digits, so the hash alone does not resist brute force if the table leaks;
    the short TTL is the main protection.
    """
    hasher = hashlib.sha256()
    hasher.update(code.encode("utf-8"))
    return hasher.hexdigest()
|
|
|
|
|
|
def generate_email_code() -> str:
    """Return a random six-digit numeric code, zero-padded (e.g. ``"004217"``).

    The value comes from :mod:`secrets`, which is intended for
    security-sensitive randomness.
    """
    number = secrets.randbelow(10**6)
    return str(number).zfill(6)
|
|
|
|
|
|
async def issue_email_verification_code(
    db: AsyncSession,
    *,
    email: str,
    purpose: str = ACTIVATION_EMAIL_PURPOSE,
) -> str:
    """Create, persist and return a fresh email verification code.

    The email is normalized (surrounding whitespace stripped, lower-cased)
    before any lookup or insert. If the most recent unconsumed code for the
    same email and purpose was created less than ``EMAIL_CODE_RESEND_SECONDS``
    ago, the request is rejected. Otherwise every outstanding unconsumed code
    for that email/purpose is marked consumed, and a new code is stored (as a
    SHA-256 hash) with an expiry ``EMAIL_CODE_TTL_MINUTES`` minutes from now.
    The session is then committed.

    Args:
        db: Async SQLAlchemy session; this function commits it.
        email: Recipient address (normalized internally).
        purpose: Purpose tag the code is scoped to.

    Returns:
        The plaintext code for the caller to deliver. Only its hash is stored.

    Raises:
        HTTPException: 429 when called again inside the resend window. The
            ``Retry-After`` header carries the seconds left (at least 1).
    """
    normalized_email = email.strip().lower()
    now = datetime.now(timezone.utc)

    latest_result = await db.execute(
        select(EmailVerificationCode)
        .where(
            EmailVerificationCode.email == normalized_email,
            EmailVerificationCode.purpose == purpose,
            EmailVerificationCode.consumed_at.is_(None),
        )
        .order_by(EmailVerificationCode.created_at.desc())
        .limit(1)
    )
    latest = latest_result.scalar_one_or_none()
    if latest is not None and latest.created_at is not None:
        created_at = latest.created_at
        # Naive timestamps are assumed to be stored in UTC. Aware ones are
        # converted rather than relabelled, so a non-UTC offset is honoured.
        if created_at.tzinfo is None:
            created_at = created_at.replace(tzinfo=timezone.utc)
        else:
            created_at = created_at.astimezone(timezone.utc)
        elapsed = (now - created_at).total_seconds()
        if elapsed < EMAIL_CODE_RESEND_SECONDS:
            retry_after = max(1, EMAIL_CODE_RESEND_SECONDS - int(elapsed))
            # Message (zh): "Codes are being sent too frequently; please retry
            # in N seconds."
            raise HTTPException(
                status_code=429,
                detail=f"验证码发送过于频繁,请在 {EMAIL_CODE_RESEND_SECONDS} 秒后重试",
                headers={"Retry-After": str(retry_after)},
            )

    # NOTE(review): the check above and the writes below are not atomic, so
    # concurrent requests may both pass the resend check.

    # Invalidate every outstanding code so that only the one issued below
    # remains unconsumed (and therefore verifiable).
    await db.execute(
        update(EmailVerificationCode)
        .where(
            EmailVerificationCode.email == normalized_email,
            EmailVerificationCode.purpose == purpose,
            EmailVerificationCode.consumed_at.is_(None),
        )
        .values(consumed_at=now)
    )

    code = generate_email_code()
    db.add(
        EmailVerificationCode(
            email=normalized_email,
            purpose=purpose,
            code_hash=_hash_code(code),
            expires_at=now + timedelta(minutes=EMAIL_CODE_TTL_MINUTES),
        )
    )
    await db.commit()
    return code
|
|
|
|
|
|
async def verify_email_code(
    db: AsyncSession,
    *,
    email: str,
    code: str,
    purpose: str = ACTIVATION_EMAIL_PURPOSE,
) -> bool:
    """Check ``code`` against the newest unconsumed code for ``email``/``purpose``.

    Only the most recently created unconsumed record is considered. On an
    unexpired match, the record is marked consumed and the session is
    committed. A record cannot match again once it is consumed, because the
    lookup only selects rows whose ``consumed_at`` is NULL.

    Args:
        db: Async SQLAlchemy session; committed only on success.
        email: Address the code was sent to (normalized internally).
        code: User-supplied code; surrounding whitespace is ignored.
        purpose: Purpose tag the code is scoped to.

    Returns:
        True if the code matched an unexpired record, which is now consumed.
        False when the code is blank, no outstanding record exists, the
        record has expired, or the code does not match. Nothing is written on
        the False paths.
    """
    normalized_email = email.strip().lower()
    normalized_code = code.strip()
    if not normalized_code:
        return False

    now = datetime.now(timezone.utc)
    result = await db.execute(
        select(EmailVerificationCode)
        .where(
            EmailVerificationCode.email == normalized_email,
            EmailVerificationCode.purpose == purpose,
            EmailVerificationCode.consumed_at.is_(None),
        )
        .order_by(EmailVerificationCode.created_at.desc())
        .limit(1)
    )
    record = result.scalar_one_or_none()
    if record is None:
        return False

    expires_at = record.expires_at
    # Naive timestamps are assumed to be stored in UTC. Aware ones are
    # converted rather than relabelled, so a non-UTC offset is honoured.
    if expires_at.tzinfo is None:
        expires_at = expires_at.replace(tzinfo=timezone.utc)
    else:
        expires_at = expires_at.astimezone(timezone.utc)
    if expires_at < now:
        return False

    # Constant-time comparison so response timing does not reveal how much
    # of the stored hash matched. Bytes are compared so that any str value
    # is accepted, and a missing hash never matches.
    stored_hash = (record.code_hash or "").encode("utf-8")
    candidate_hash = _hash_code(normalized_code).encode("utf-8")
    if not secrets.compare_digest(stored_hash, candidate_hash):
        return False

    # NOTE(review): failed attempts are not counted here. A 6-digit code may
    # be guessable within the TTL unless attempts are throttled upstream; confirm.
    record.consumed_at = now
    await db.commit()
    return True
|