# File listing metadata: 138 lines, 4.9 KiB, Python.
from __future__ import annotations
|
|
|
|
from datetime import datetime, timedelta, timezone
|
|
|
|
import httpx
|
|
from fastapi import HTTPException
|
|
from jose import JWTError, jwt
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.orm import selectinload
|
|
|
|
from app.config import settings
|
|
from app.core.auth import create_access_token
|
|
from app.db.models import ClassMembership, User
|
|
from app.schemas.user import build_user_out
|
|
|
|
# Value of the "purpose" claim carried by bind tokens; decode_bind_token
# rejects any token whose claim differs from it.
WECHAT_BIND_PURPOSE = "wechat_bind"

# Bind-token lifetime in seconds, derived from a 15-minute TTL.
_BIND_TOKEN_TTL_MINUTES = 15
WECHAT_BIND_TOKEN_EXP_SECONDS = _BIND_TOKEN_TTL_MINUTES * 60
|
|
|
|
|
|
def _require_wechat_config() -> None:
    """Verify that the WeChat mini-program credentials are configured.

    Raises:
        HTTPException: status 500 when ``settings.wechat_mini_app_id`` or
            ``settings.wechat_mini_app_secret`` is falsy.
    """
    # Guard clause: both credentials present means there is nothing to report.
    if settings.wechat_mini_app_id and settings.wechat_mini_app_secret:
        return
    raise HTTPException(status_code=500, detail="微信小程序配置未完成")
|
|
|
|
|
|
async def exchange_code_for_session(code: str) -> dict:
    """Exchange a mini-program login ``code`` for a WeChat session payload.

    Sends a GET request to WeChat's ``jscode2session`` endpoint using the
    configured app id and secret.

    Args:
        code: The ``js_code`` value forwarded to WeChat.

    Returns:
        The decoded JSON object returned by WeChat. It always contains a
        truthy ``"openid"`` entry.

    Raises:
        HTTPException: 500 when the WeChat credentials are not configured
            (raised by ``_require_wechat_config``); 502 when WeChat cannot
            be reached or answers with something that is not a JSON object;
            400 when WeChat reports an ``errcode`` or omits ``openid``.
    """
    _require_wechat_config()
    query = {
        "appid": settings.wechat_mini_app_id,
        "secret": settings.wechat_mini_app_secret,
        "js_code": code,
        "grant_type": "authorization_code",
    }
    try:
        async with httpx.AsyncClient(timeout=settings.wechat_api_timeout_seconds) as client:
            response = await client.get(
                "https://api.weixin.qq.com/sns/jscode2session",
                params=query,
            )
            data = response.json()
    except (httpx.HTTPError, ValueError) as exc:
        # Timeouts, connection failures and undecodable bodies are upstream
        # faults; surface them as 502 instead of an unhandled 500.
        raise HTTPException(status_code=502, detail="微信服务请求失败") from exc
    if not isinstance(data, dict):
        # A non-object body would make the .get() calls below blow up.
        raise HTTPException(status_code=502, detail="微信服务响应格式异常")
    if data.get("errcode"):
        raise HTTPException(status_code=400, detail=data.get("errmsg") or "微信登录失败")
    if not data.get("openid"):
        raise HTTPException(status_code=400, detail="微信登录未返回 openid")
    return data
|
|
|
|
|
|
async def _get_access_token() -> str:
    """Fetch a WeChat API ``access_token`` using the client-credential grant.

    Sends a GET request to ``https://api.weixin.qq.com/cgi-bin/token`` with
    the configured app id and secret.

    Returns:
        The ``access_token`` value from WeChat's response.

    Raises:
        HTTPException: 500 when the WeChat credentials are not configured
            (raised by ``_require_wechat_config``); 502 when WeChat cannot
            be reached or answers with something that is not a JSON object;
            400 when WeChat reports an ``errcode`` or omits ``access_token``.
    """
    # NOTE(review): a fresh token is requested on every call; whether this
    # needs shared caching depends on WeChat's quota rules — confirm.
    _require_wechat_config()
    query = {
        "grant_type": "client_credential",
        "appid": settings.wechat_mini_app_id,
        "secret": settings.wechat_mini_app_secret,
    }
    try:
        async with httpx.AsyncClient(timeout=settings.wechat_api_timeout_seconds) as client:
            response = await client.get(
                "https://api.weixin.qq.com/cgi-bin/token",
                params=query,
            )
            data = response.json()
    except (httpx.HTTPError, ValueError) as exc:
        # Timeouts, connection failures and undecodable bodies are upstream
        # faults; surface them as 502 instead of an unhandled 500.
        raise HTTPException(status_code=502, detail="微信服务请求失败") from exc
    if not isinstance(data, dict):
        # A non-object body would make the .get() calls below blow up.
        raise HTTPException(status_code=502, detail="微信服务响应格式异常")
    if data.get("errcode"):
        raise HTTPException(status_code=400, detail=data.get("errmsg") or "获取微信 access_token 失败")
    access_token = data.get("access_token")
    if not access_token:
        raise HTTPException(status_code=400, detail="微信未返回 access_token")
    return access_token
|
|
|
|
|
|
async def exchange_phone_code(phone_code: str) -> str:
    """Exchange a WeChat phone-number ``code`` for the user's phone number.

    Obtains an access token via ``_get_access_token`` and POSTs the code to
    WeChat's ``getuserphonenumber`` endpoint.

    Args:
        phone_code: The code sent to WeChat as the ``"code"`` JSON field.

    Returns:
        ``phone_info["purePhoneNumber"]`` when truthy, otherwise
        ``phone_info["phoneNumber"]``, converted to ``str``.

    Raises:
        HTTPException: anything raised by ``_get_access_token``; 502 when
            WeChat cannot be reached or answers with something that is not
            a JSON object; 400 when WeChat reports an ``errcode`` or the
            response carries no phone number.
    """
    access_token = await _get_access_token()
    try:
        async with httpx.AsyncClient(timeout=settings.wechat_api_timeout_seconds) as client:
            response = await client.post(
                "https://api.weixin.qq.com/wxa/business/getuserphonenumber",
                params={"access_token": access_token},
                json={"code": phone_code},
            )
            data = response.json()
    except (httpx.HTTPError, ValueError) as exc:
        # Timeouts, connection failures and undecodable bodies are upstream
        # faults; surface them as 502 instead of an unhandled 500.
        raise HTTPException(status_code=502, detail="微信服务请求失败") from exc
    if not isinstance(data, dict):
        # A non-object body would make the .get() calls below blow up.
        raise HTTPException(status_code=502, detail="微信服务响应格式异常")
    if data.get("errcode"):
        raise HTTPException(status_code=400, detail=data.get("errmsg") or "获取微信手机号失败")
    phone_info = data.get("phone_info") or {}
    if not isinstance(phone_info, dict):
        # Treat a malformed phone_info the same as a missing one.
        phone_info = {}
    phone = phone_info.get("purePhoneNumber") or phone_info.get("phoneNumber")
    if not phone:
        raise HTTPException(status_code=400, detail="微信未返回手机号")
    return str(phone)
|
|
|
|
|
|
def create_bind_token(openid: str, unionid: str | None = None) -> str:
    """Issue a signed JWT that carries a WeChat identity awaiting binding.

    Args:
        openid: Stored in the ``"openid"`` claim.
        unionid: Stored in the ``"unionid"`` claim; may be ``None``.

    Returns:
        The token encoded with ``settings.jwt_secret`` and
        ``settings.jwt_algorithm``. Its ``"exp"`` claim is set
        ``WECHAT_BIND_TOKEN_EXP_SECONDS`` seconds after the current UTC time.
    """
    lifetime = timedelta(seconds=WECHAT_BIND_TOKEN_EXP_SECONDS)
    claims = {
        "purpose": WECHAT_BIND_PURPOSE,
        "openid": openid,
        "unionid": unionid,
        "exp": datetime.now(timezone.utc) + lifetime,
    }
    signing_key = settings.jwt_secret
    return jwt.encode(claims, signing_key, algorithm=settings.jwt_algorithm)
|
|
|
|
|
|
def decode_bind_token(token: str) -> tuple[str, str | None]:
    """Validate a bind token and extract the WeChat identity it carries.

    Args:
        token: A JWT as produced by ``create_bind_token``.

    Returns:
        A ``(openid, unionid)`` tuple; ``openid`` is converted to ``str`` and
        ``unionid`` is whatever the ``"unionid"`` claim holds (possibly
        ``None``).

    Raises:
        HTTPException: 401 when decoding raises ``JWTError``, when the
            ``"purpose"`` claim is not ``WECHAT_BIND_PURPOSE``, or when the
            ``"openid"`` claim is missing or empty.
    """
    try:
        payload = jwt.decode(
            token,
            settings.jwt_secret,
            algorithms=[settings.jwt_algorithm],
        )
    except JWTError as exc:
        # Chain explicitly so the original decoding failure stays attached
        # as the cause of the HTTP error.
        raise HTTPException(status_code=401, detail="微信绑定凭证无效或已过期") from exc
    # Reject tokens signed with the same secret but minted for another purpose.
    if payload.get("purpose") != WECHAT_BIND_PURPOSE:
        raise HTTPException(status_code=401, detail="微信绑定凭证无效")
    if not payload.get("openid"):
        raise HTTPException(status_code=401, detail="微信绑定凭证无效")
    return str(payload["openid"]), payload.get("unionid")
|
|
|
|
|
|
async def get_user_by_openid(db: AsyncSession, openid: str) -> User | None:
    """Look up the user whose ``wechat_openid`` equals ``openid``.

    The query requests select-in loading of ``User.memberships`` and of each
    membership's ``class_`` relationship.

    Args:
        db: Session used to execute the query.
        openid: Value compared against ``User.wechat_openid``.

    Returns:
        The result of ``scalar_one_or_none()`` on the executed query.
    """
    eager_loads = (
        selectinload(User.memberships),
        selectinload(User.memberships).selectinload(ClassMembership.class_),
    )
    stmt = select(User).options(*eager_loads).where(User.wechat_openid == openid)
    rows = await db.execute(stmt)
    return rows.scalar_one_or_none()
|
|
|
|
|
|
def build_token_payload(user: User) -> dict:
    """Build the claims dict used when issuing an access token for ``user``.

    Returns:
        ``{"sub": str(user.id), "role": user.role}``.
    """
    subject = str(user.id)
    claims = {"sub": subject, "role": user.role}
    return claims
|
|
|
|
|
|
def mark_wechat_bound(user: User, openid: str, unionid: str | None, phone: str) -> None:
    """Record a completed WeChat binding on ``user`` (mutates it in place).

    Sets ``wechat_openid``, ``wechat_unionid`` and ``phone`` from the
    arguments, stamps ``phone_verified_at`` with the current UTC time as a
    naive ``datetime``, and sets ``status`` to ``"approved"``.
    """
    # Take the time in UTC, then drop tzinfo so a naive value is stored.
    verified_at = datetime.now(timezone.utc).replace(tzinfo=None)
    updates = {
        "wechat_openid": openid,
        "wechat_unionid": unionid,
        "phone": phone,
        "phone_verified_at": verified_at,
        "status": "approved",
    }
    for field_name, value in updates.items():
        setattr(user, field_name, value)
|
|
|
|
|
|
def build_wechat_login_result(user: User, class_id: int | None = None) -> dict:
    """Assemble the login response for a user who is already bound.

    Args:
        user: The user to issue an access token for.
        class_id: Passed through to ``build_user_out``.

    Returns:
        A dict with ``"binding_required"`` set to ``False``, ``"token"`` from
        ``create_access_token(build_token_payload(user))`` and ``"user"``
        from ``build_user_out(user, class_id)``.
    """
    # Token first, then the user view — same evaluation order as the keys.
    access_token = create_access_token(build_token_payload(user))
    user_out = build_user_out(user, class_id)
    return {
        "binding_required": False,
        "token": access_token,
        "user": user_out,
    }
|