From 648eda8b80b8b87f95af9fa56dfa1828e814538d Mon Sep 17 00:00:00 2001 From: aaron <> Date: Fri, 17 Jan 2025 17:09:28 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E5=BE=AE=E4=BF=A1=E6=94=AF?= =?UTF-8?q?=E4=BB=98=E7=9B=B8=E5=85=B3=E4=BB=A3=E7=A0=81=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/endpoints/wechat.py | 17 ++++++-- app/core/config.py | 6 ++- app/core/wechat.py | 84 ++++++++++++++++++++++++++++++++++++- app/models/user.py | 2 + 4 files changed, 103 insertions(+), 6 deletions(-) diff --git a/app/api/endpoints/wechat.py b/app/api/endpoints/wechat.py index d31f7d3..47e7966 100644 --- a/app/api/endpoints/wechat.py +++ b/app/api/endpoints/wechat.py @@ -4,7 +4,7 @@ from app.models.database import get_db from app.models.user import UserInfo,UserDB, PhoneLoginRequest, generate_user_code from app.models.order import ShippingOrderDB, OrderStatus from app.core.response import success_response, error_response, ResponseModel -from app.core.wechat import WeChatClient, get_wechat_pay_client +from app.core.wechat import WeChatClient from app.core.security import create_access_token, set_jwt_cookie from pydantic import BaseModel import json @@ -18,7 +18,7 @@ import string router = APIRouter() class PhoneNumberRequest(BaseModel): - code: str # 手机号获取凭证 + code: str # 登录凭证 referral_code: str = None # 推荐码(可选) @router.post("/phone-login", response_model=ResponseModel) @@ -32,6 +32,10 @@ async def wechat_phone_login( # 初始化微信客户端 wechat = WeChatClient() + # 获取用户 openid + session_info = await wechat.code2session(request.code) + openid = session_info["openid"] + # 获取用户手机号 phone_info = await wechat.get_phone_number(request.code) @@ -50,10 +54,11 @@ async def wechat_phone_login( username=f"user_{phone[-4:]}", phone=phone, user_code=user_code, - referral_code=request.referral_code + referral_code=request.referral_code, + openid=openid # 保存 openid ) db.add(user) - db.flush() # 获取用户ID + db.flush() # 发放优惠券 from app.api.endpoints.user import issue_register_coupons @@ -61,6 +66,10 @@ async def wechat_phone_login( db.commit() db.refresh(user) + else: + # 更新现有用户的 openid + user.openid = openid + db.commit() # 创建访问令牌 access_token = create_access_token( diff --git a/app/core/config.py b/app/core/config.py index 244af76..5d7eb4b 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -6,6 +6,8 @@ class Settings(BaseSettings): API_V1_STR: str = "/api/v1" PROJECT_NAME: str = "FastAPI 项目" + API_BASE_URL: str = "https://api.beefast.co" + # 订单价格配置 ORDER_BASE_PRICE: float = 3.0 # 基础费用 ORDER_EXTRA_PACKAGE_PRICE: float = 0.5 # 额外包裹费用 @@ -61,7 +63,9 @@ class Settings(BaseSettings): WECHAT_MCH_ID: str = "1688852888" WECHAT_PRIVATE_KEY_PATH: str = "app/core/wechat_private_key.pem" WECHAT_CERT_SERIAL_NO: str = "1688852888" - + WECHAT_API_V3_KEY: str = "your-api-v3-key" # API v3密钥 + WECHAT_PLATFORM_CERT_PATH: str = "app/core/wechat_platform_cert.pem" # 平台证书路径 + class Config: case_sensitive = True env_file = ".env" diff --git a/app/core/wechat.py b/app/core/wechat.py index 8f85ceb..10f15cc 100644 --- a/app/core/wechat.py +++ b/app/core/wechat.py @@ -9,6 +9,8 @@ import base64 import time import random import string +from cryptography.x509 import load_pem_x509_certificate +from cryptography.hazmat.primitives.ciphers.aead import AESGCM def generate_random_string(length=32): """生成指定长度的随机字符串""" @@ -25,6 +27,8 @@ class WeChatClient: self.private_key = self._load_private_key() self.cert_serial_no = settings.WECHAT_CERT_SERIAL_NO self.access_token = None + self.api_v3_key = settings.WECHAT_API_V3_KEY + self.platform_cert = self._load_platform_cert() def _load_private_key(self): """加载商户私钥""" @@ -34,6 +38,12 @@ class WeChatClient: password=None ) + def _load_platform_cert(self): + """加载微信支付平台证书""" + with open(settings.WECHAT_PLATFORM_CERT_PATH, 'rb') as f: + cert_data = f.read() + return load_pem_x509_certificate(cert_data) + def sign(self, data: bytes) -> str: """签名数据""" signature = self.private_key.sign( @@ -132,4 +142,76 @@ class WeChatClient: if resp.status_code != 200: raise Exception(f"微信支付下单失败: {resp.json().get('message')}") - return resp.json() \ No newline at end of file + return resp.json() + + async def code2session(self, code: str) -> dict: + """通过 code 获取用户 openid + + Args: + code: 登录凭证 + Returns: + dict: 包含 openid 等信息 + """ + async with aiohttp.ClientSession() as session: + url = "https://api.weixin.qq.com/sns/jscode2session" + params = { + "appid": self.appid, + "secret": self.secret, + "js_code": code, + "grant_type": "authorization_code" + } + + async with session.get(url, params=params) as response: + result = await response.json() + if "openid" in result: + return result + raise Exception(result.get("errmsg", "获取openid失败")) + + def verify_signature(self, message: bytes, signature: str, serial_no: str) -> bool: + """验证微信支付回调签名 + + Args: + message: 待验证的消息 + signature: 签名字符串 + serial_no: 证书序列号 + """ + if serial_no != self.cert_serial_no: + return False + + try: + # 解码签名 + signature_bytes = base64.b64decode(signature) + + # 使用公钥验证签名 + self.platform_cert.public_key().verify( + signature_bytes, + message, + padding.PKCS1v15(), + hashes.SHA256() + ) + return True + except Exception: + return False + + def decrypt_callback(self, associated_data: str, nonce: str, ciphertext: str) -> str: + """解密回调数据 + + Args: + associated_data: 附加数据 + nonce: 随机串 + ciphertext: 密文 + Returns: + 解密后的明文 + """ + # 解码密文 + encrypted_data = base64.b64decode(ciphertext) + + # 使用 AEAD_AES_256_GCM 算法解密 + aesgcm = AESGCM(base64.b64decode(self.api_v3_key)) + data = aesgcm.decrypt( + nonce.encode(), + encrypted_data, + associated_data.encode() + ) + + return data.decode() \ No newline at end of file diff --git a/app/models/user.py b/app/models/user.py index 7a7643a..7428b28 100644 --- a/app/models/user.py +++ b/app/models/user.py @@ -25,6 +25,7 @@ class UserDB(Base): __tablename__ = "users" userid = Column(Integer, primary_key=True,autoincrement=True, index=True) + openid = Column(String(64), unique=True, nullable=True) username = Column(String(50)) phone = Column(String(11), unique=True, index=True) user_code = Column(String(6), unique=True, nullable=False) @@ -46,6 +47,7 @@ class UserLogin(BaseModel): class UserInfo(BaseModel): userid: int + openid: str username: str phone: str user_code: str