from __future__ import annotations from datetime import datetime, timedelta, timezone import httpx from fastapi import HTTPException from jose import JWTError, jwt from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload from app.config import settings from app.core.auth import create_access_token from app.db.models import ClassMembership, User from app.schemas.user import build_user_out WECHAT_BIND_PURPOSE = "wechat_bind" WECHAT_BIND_TOKEN_EXP_SECONDS = 15 * 60 def _require_wechat_config() -> None: if not settings.wechat_mini_app_id or not settings.wechat_mini_app_secret: raise HTTPException(status_code=500, detail="微信小程序配置未完成") async def exchange_code_for_session(code: str) -> dict: _require_wechat_config() async with httpx.AsyncClient(timeout=settings.wechat_api_timeout_seconds) as client: response = await client.get( "https://api.weixin.qq.com/sns/jscode2session", params={ "appid": settings.wechat_mini_app_id, "secret": settings.wechat_mini_app_secret, "js_code": code, "grant_type": "authorization_code", }, ) data = response.json() if data.get("errcode"): raise HTTPException(status_code=400, detail=data.get("errmsg") or "微信登录失败") if not data.get("openid"): raise HTTPException(status_code=400, detail="微信登录未返回 openid") return data async def _get_access_token() -> str: _require_wechat_config() async with httpx.AsyncClient(timeout=settings.wechat_api_timeout_seconds) as client: response = await client.get( "https://api.weixin.qq.com/cgi-bin/token", params={ "grant_type": "client_credential", "appid": settings.wechat_mini_app_id, "secret": settings.wechat_mini_app_secret, }, ) data = response.json() if data.get("errcode"): raise HTTPException(status_code=400, detail=data.get("errmsg") or "获取微信 access_token 失败") access_token = data.get("access_token") if not access_token: raise HTTPException(status_code=400, detail="微信未返回 access_token") return access_token async def exchange_phone_code(phone_code: str) -> str: access_token = await _get_access_token() async with httpx.AsyncClient(timeout=settings.wechat_api_timeout_seconds) as client: response = await client.post( "https://api.weixin.qq.com/wxa/business/getuserphonenumber", params={"access_token": access_token}, json={"code": phone_code}, ) data = response.json() if data.get("errcode"): raise HTTPException(status_code=400, detail=data.get("errmsg") or "获取微信手机号失败") phone_info = data.get("phone_info") or {} phone = phone_info.get("purePhoneNumber") or phone_info.get("phoneNumber") if not phone: raise HTTPException(status_code=400, detail="微信未返回手机号") return str(phone) def create_bind_token(openid: str, unionid: str | None = None) -> str: payload = { "purpose": WECHAT_BIND_PURPOSE, "openid": openid, "unionid": unionid, "exp": datetime.now(timezone.utc) + timedelta(seconds=WECHAT_BIND_TOKEN_EXP_SECONDS), } return jwt.encode(payload, settings.jwt_secret, algorithm=settings.jwt_algorithm) def decode_bind_token(token: str) -> tuple[str, str | None]: try: payload = jwt.decode( token, settings.jwt_secret, algorithms=[settings.jwt_algorithm], ) except JWTError: raise HTTPException(status_code=401, detail="微信绑定凭证无效或已过期") if payload.get("purpose") != WECHAT_BIND_PURPOSE or not payload.get("openid"): raise HTTPException(status_code=401, detail="微信绑定凭证无效") return str(payload["openid"]), payload.get("unionid") async def get_user_by_openid(db: AsyncSession, openid: str) -> User | None: result = await db.execute( select(User) .options( selectinload(User.memberships), selectinload(User.memberships).selectinload(ClassMembership.class_), ) .where(User.wechat_openid == openid) ) return result.scalar_one_or_none() def build_token_payload(user: User) -> dict: return {"sub": str(user.id), "role": user.role} def mark_wechat_bound(user: User, openid: str, unionid: str | None, phone: str) -> None: user.wechat_openid = openid user.wechat_unionid = unionid user.phone = phone user.phone_verified_at = datetime.now(timezone.utc).replace(tzinfo=None) user.status = "approved" def build_wechat_login_result(user: User, class_id: int | None = None) -> dict: return { "binding_required": False, "token": create_access_token(build_token_payload(user)), "user": build_user_out(user, class_id), }