stock-ai-agent/backend/app/services/auth_service.py
2026-02-04 14:56:03 +08:00

195 lines
5.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
认证服务
"""
from datetime import datetime, timedelta
from typing import Dict, Optional
from sqlalchemy.orm import Session
from app.models.database import User, VerificationCode
from app.services.sms_service import sms_service
from app.services.jwt_service import jwt_service
from app.config import get_settings
from app.utils.logger import logger
class AuthService:
    """Phone-number + SMS verification-code authentication service.

    Sends verification codes (with a per-phone resend throttle and a per-IP
    hourly limit) and logs users in with a code, auto-registering phone
    numbers that have no ``User`` row yet.
    """

    def __init__(self):
        """Load code-expiry, rate-limit and whitelist settings from app config."""
        settings = get_settings()
        self.code_expire_minutes = settings.code_expire_minutes
        self.code_resend_seconds = settings.code_resend_seconds
        self.code_max_per_hour = settings.code_max_per_hour
        # Whitelisted phone numbers, comma-separated in config. An unset/None
        # setting is treated as an empty whitelist instead of raising.
        raw_whitelist = settings.whitelist_phones or ""
        self.whitelist_phones = [p.strip() for p in raw_whitelist.split(",") if p.strip()]
        # Whitelisted numbers can log in WITHOUT a valid code (see
        # login_with_code), so never write them to the log in clear text.
        logger.info(f"白名单手机号: {[self._mask_phone(p) for p in self.whitelist_phones]}")

    async def send_verification_code(
        self,
        db: Session,
        phone: str,
        ip_address: str
    ) -> Dict:
        """
        Send an SMS verification code to ``phone`` and record it.

        Args:
            db: database session
            phone: target phone number
            ip_address: client IP address, used for the hourly per-IP limit

        Returns:
            {"success": bool, "message": str, "expires_in": int}
            ``expires_in`` (seconds) is only present on success.
        """
        # One timestamp for both the window filter and the "remaining" math,
        # so the two cannot disagree (e.g. producing "0秒后再试").
        now = datetime.utcnow()

        # 1. Per-phone resend throttle (code_resend_seconds). Order newest
        #    first so "remaining" is measured from the most recent code rather
        #    than an arbitrary row inside the window.
        last_code = db.query(VerificationCode).filter(
            VerificationCode.phone == phone,
            VerificationCode.created_at > now - timedelta(seconds=self.code_resend_seconds)
        ).order_by(VerificationCode.created_at.desc()).first()
        if last_code:
            elapsed = int((now - last_code.created_at).total_seconds())
            remaining = max(1, self.code_resend_seconds - elapsed)
            return {
                "success": False,
                "message": f"{remaining}秒后再试"
            }

        # 2. Per-IP limit: at most code_max_per_hour codes in the last hour.
        ip_count = db.query(VerificationCode).filter(
            VerificationCode.ip_address == ip_address,
            VerificationCode.created_at > now - timedelta(hours=1)
        ).count()
        if ip_count >= self.code_max_per_hour:
            return {
                "success": False,
                "message": "发送次数过多,请稍后再试"
            }

        # 3. Generate the code.
        code = sms_service.generate_code()

        # 4. Send the SMS. Failed sends are not recorded, so they do not count
        #    toward the limits above.
        success = await sms_service.send_code(phone, code)
        if not success:
            return {
                "success": False,
                "message": "发送失败,请稍后重试"
            }

        # 5. Persist the code. Expiry is measured from after the send
        #    completed (fresh timestamp, not `now`).
        verification = VerificationCode(
            phone=phone,
            code=code,
            expires_at=datetime.utcnow() + timedelta(minutes=self.code_expire_minutes),
            ip_address=ip_address
        )
        db.add(verification)
        db.commit()
        logger.info(f"验证码已发送: {phone}")
        return {
            "success": True,
            "message": "验证码已发送",
            "expires_in": self.code_expire_minutes * 60
        }

    async def login_with_code(
        self,
        db: Session,
        phone: str,
        code: str,
        ip_address: str
    ) -> Dict:
        """
        Log in with a phone number and verification code.

        Unknown phone numbers are registered automatically.

        Args:
            db: database session
            phone: phone number
            code: verification code entered by the user
            ip_address: client IP address (currently unused here)

        Returns:
            On success: {"success": True, "token": str, "user": dict}
            On failure: {"success": False, "message": str}
        """
        now = datetime.utcnow()
        verification: Optional[VerificationCode] = None

        # SECURITY NOTE(review): whitelisted numbers skip code verification
        # entirely, so anyone who knows such a number can log in as it. Keep
        # the whitelist to test accounts only.
        if phone in self.whitelist_phones:
            logger.info(f"白名单手机号登录: {phone}")
        else:
            # 1. Find the newest unused, unexpired code matching phone + code.
            # NOTE(review): there is no failed-attempt counter, so a short
            # numeric code can be brute-forced within its expiry window.
            verification = db.query(VerificationCode).filter(
                VerificationCode.phone == phone,
                VerificationCode.code == code,
                VerificationCode.is_used == False,  # noqa: E712 - SQLAlchemy expression
                VerificationCode.expires_at > now
            ).order_by(VerificationCode.created_at.desc()).first()
            if not verification:
                return {
                    "success": False,
                    "message": "验证码错误或已过期"
                }
            # 2. Mark the code as consumed (committed together with the user).
            verification.is_used = True
            verification.used_at = now

        # 3. Find the user, or auto-register on first login.
        user = db.query(User).filter(User.phone == phone).first()
        if not user:
            user = User(phone=phone)
            db.add(user)
            db.flush()  # assign user.id before linking the verification row
            logger.info(f"新用户注册: {phone}")

        # 4. Record the login time.
        user.last_login_at = now

        # 5. Link the consumed code to the user (not applicable for whitelist).
        if verification is not None:
            verification.user_id = user.id

        db.commit()
        db.refresh(user)

        # 6. Issue the JWT access token.
        token = jwt_service.create_access_token(user.id, user.phone)
        logger.info(f"用户登录成功: {phone}")
        return {
            "success": True,
            "token": token,
            "user": self._user_payload(user)
        }

    def _user_payload(self, user: User) -> Dict:
        """Build the public user dict returned to clients (phone masked)."""
        return {
            "id": user.id,
            "phone": self._mask_phone(user.phone),
            "created_at": user.created_at.isoformat() if user.created_at else None,
            "last_login_at": user.last_login_at.isoformat() if user.last_login_at else None
        }

    def _mask_phone(self, phone: str) -> str:
        """
        Mask the middle four digits of an 11-digit phone number.

        Args:
            phone: phone number

        Returns:
            "138****5678" for 11-character input; any other value
            (including an empty string or None) is returned unchanged.
        """
        if phone and len(phone) == 11:
            return f"{phone[:3]}****{phone[-4:]}"
        return phone
# Shared module-level instance. It is built at import time, so AuthService
# reads its settings (and logs the whitelist) once, on first import.
auth_service: AuthService = AuthService()