From 73d1c3eef298d542239c553ebeaf9ab42f2d8f9c Mon Sep 17 00:00:00 2001 From: aaron <> Date: Wed, 11 Jun 2025 19:45:53 +0800 Subject: [PATCH] update --- cryptoai/examples/upay_example.py | 56 ++++ cryptoai/models/__init__.py | 4 +- .../__pycache__/__init__.cpython-313.pyc | Bin 818 -> 1039 bytes cryptoai/models/subscription_order.py | 281 ++++++++++++++++++ cryptoai/models/user_subscription.py | 281 ++++++++++++++++++ cryptoai/requirements.txt | 1 + cryptoai/routes/analysis.py | 18 +- cryptoai/routes/fastapi_app.py | 4 + cryptoai/routes/payment.py | 154 ++++++++++ cryptoai/routes/user.py | 58 +++- cryptoai/tasks/user.py | 4 +- cryptoai/utils/db_manager.py | 30 +- cryptoai/utils/upay.py | 107 +++++++ docker-compose.yml | 4 +- 14 files changed, 979 insertions(+), 23 deletions(-) create mode 100644 cryptoai/examples/upay_example.py create mode 100644 cryptoai/models/subscription_order.py create mode 100644 cryptoai/models/user_subscription.py create mode 100644 cryptoai/requirements.txt create mode 100644 cryptoai/routes/payment.py create mode 100644 cryptoai/utils/upay.py diff --git a/cryptoai/examples/upay_example.py b/cryptoai/examples/upay_example.py new file mode 100644 index 0000000..02c98e7 --- /dev/null +++ b/cryptoai/examples/upay_example.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import sys +import os +from datetime import datetime + +# 添加项目根目录到 Python 路径 +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from utils.upay import Upay + +def main(): + """ + UPay 创建订单示例 + """ + # 初始化 UPay 客户端 + upay = Upay() + + # 生成唯一的商户订单号 + merchant_order_no = f"Order{datetime.now().strftime('%Y%m%d%H%M%S')}" + + # 订单参数 + order_params = { + 'merchant_order_no': merchant_order_no, + 'chain_type': '1', # 1: 波场TRC20 + 'fiat_amount': '50.00', # 法币金额 + 'fiat_currency': 'USD', # 美元 + 'notify_url': 'https://your-domain.com/callback', # 回调地址 + 'attach': 'custom_data_123', # 自定义数据 + 'product_name': 'T-shirt', # 商品名称 + 'redirect_url': 'https://your-domain.com/success' # 成功跳转地址 + } + + print("正在创建订单...") + print(f"商户订单号: {merchant_order_no}") + + # 创建订单 + result = upay.create_order(**order_params) + + # 处理结果 + if 'code' in result and result['code'] == 200: + data = result['data'] + print("\n✅ 订单创建成功!") + print(f"UPay 订单号: {data['orderNo']}") + print(f"汇率: {data['exchangeRate']}") + print(f"加密货币金额: {data['crypto']} USDT") + print(f"订单状态: {data['status']}") + print(f"收银台链接: {data['payUrl']}") + print("\n用户可以通过收银台链接进行支付") + else: + print("\n❌ 订单创建失败!") + print(f"错误信息: {result}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/cryptoai/models/__init__.py b/cryptoai/models/__init__.py index 46c0258..c737361 100644 --- a/cryptoai/models/__init__.py +++ b/cryptoai/models/__init__.py @@ -9,4 +9,6 @@ from cryptoai.models.analysis_result import AnalysisResult, AnalysisResultManage from cryptoai.models.user import User, UserManager from cryptoai.models.user_question import UserQuestion, UserQuestionManager from cryptoai.models.astock import AStock, AStockManager -from cryptoai.models.analysis_history import AnalysisHistory, AnalysisHistoryManager \ No newline at end of file +from cryptoai.models.analysis_history import AnalysisHistory, AnalysisHistoryManager +from cryptoai.models.subscription_order import SubscriptionOrder, SubscriptionOrderManager +from cryptoai.models.user_subscription import UserSubscription, UserSubscriptionManager \ No newline at end of file diff --git a/cryptoai/models/__pycache__/__init__.cpython-313.pyc b/cryptoai/models/__pycache__/__init__.cpython-313.pyc index 4f56f4e09366b0a026fef27f75fe05e50ffcb652..89125b67824a5d121e200ded8b17c421e9907ae9 100644 GIT binary patch delta 291 zcmdnQ*3ZHFnU|M~0SNdbJTq=gBA{}C z(Bjl0Wcgd-=ps1TqA*cE-sq%z<3e XKy{2jT&y(tKC=?X1%{9!Ngx*hVcrj+ diff --git a/cryptoai/models/subscription_order.py b/cryptoai/models/subscription_order.py new file mode 100644 index 0000000..64fe854 --- /dev/null +++ b/cryptoai/models/subscription_order.py @@ -0,0 +1,281 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from typing import Dict, Any, List, Optional +from datetime import datetime +from fastapi import HTTPException, status +from sqlalchemy import Column, Integer, String, DateTime, Float, Index +from sqlalchemy.orm import Session +from cryptoai.models.base import Base, logger + +# 定义订阅订单数据模型 +class SubscriptionOrder(Base): + """订阅订单数据表模型""" + __tablename__ = 'subscription_orders' + + id = Column(Integer, primary_key=True, autoincrement=True) + user_id = Column(Integer, nullable=False, comment='用户ID') + order_id = Column(String(100), nullable=False, unique=True, comment='订单ID') + payment_order_id = Column(String(100), nullable=False, comment='支付订单ID') + amount = Column(Float, nullable=False, comment='金额') + member_type = Column(Integer, nullable=False, comment='会员类型(1=VIP,2=SVIP)') + currency = Column(String(10), nullable=False, comment='货币类型') + time_type = Column(Integer, nullable=False, comment='时间类型(1=包月,2=包年)') + status = Column(Integer, nullable=False, default=1, comment='订单状态(1=待支付,2=已完成)') + create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间') + + # 索引和表属性 + __table_args__ = ( + Index('idx_user_id', 'user_id'), + Index('idx_order_id', 'order_id'), + Index('idx_payment_order_id', 'payment_order_id'), + Index('idx_member_type', 'member_type'), + Index('idx_status', 'status'), + Index('idx_create_time', 'create_time'), + {'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'} + ) + +class SubscriptionOrderManager: + """订阅订单管理类""" + + def __init__(self, session: Session = None): + self.session = session + + def create_order(self, user_id: int, order_id: str, payment_order_id: str, + amount: float, member_type: int, currency: str, time_type: int, status: int = 1) -> bool: + """ + 创建订阅订单 + + Args: + user_id: 用户ID + order_id: 订单ID + payment_order_id: 支付订单ID + amount: 金额 + member_type: 会员类型(1=VIP,2=SVIP) + currency: 货币类型 + time_type: 时间类型(1=包月,2=包年) + status: 订单状态(1=待支付,2=已完成),默认为1 + + Returns: + 创建是否成功 + """ + try: + # 验证参数 + if member_type not in [1, 2]: + logger.warning(f"无效的会员类型: {member_type}") + return False + + if time_type not in [1, 2]: + logger.warning(f"无效的时间类型: {time_type}") + return False + + if status not in [1, 2]: + logger.warning(f"无效的订单状态: {status}") + return False + + if amount <= 0: + logger.warning(f"无效的金额: {amount}") + return False + + # 检查订单ID是否已存在 + existing_order = self.session.query(SubscriptionOrder).filter( + SubscriptionOrder.order_id == order_id + ).first() + if existing_order: + logger.warning(f"订单ID {order_id} 已存在") + return False + + # 创建新订单 + new_order = SubscriptionOrder( + user_id=user_id, + order_id=order_id, + payment_order_id=payment_order_id, + amount=amount, + member_type=member_type, + currency=currency, + time_type=time_type, + status=status, + create_time=datetime.now() + ) + + # 添加并提交 + self.session.add(new_order) + self.session.commit() + + logger.info(f"成功创建订阅订单: {order_id},用户: {user_id}") + return True + + except Exception as e: + self.session.rollback() + logger.error(f"创建订阅订单失败: {e}") + return False + + def get_order_by_id(self, order_id: str) -> Optional[Dict[str, Any]]: + """ + 通过订单ID获取订单信息 + + Args: + order_id: 订单ID + + Returns: + 订单信息,如果订单不存在则返回None + """ + try: + # 查询订单 + order = self.session.query(SubscriptionOrder).filter( + SubscriptionOrder.order_id == order_id + ).first() + + if order: + # 转换为字典 + return { + 'id': order.id, + 'user_id': order.user_id, + 'order_id': order.order_id, + 'payment_order_id': order.payment_order_id, + 'amount': order.amount, + 'member_type': order.member_type, + 'currency': order.currency, + 'time_type': order.time_type, + 'status': order.status, + 'create_time': order.create_time + } + else: + return None + + except Exception as e: + logger.error(f"获取订单信息失败: {e}") + return None + + def get_orders_by_user_id(self, user_id: int) -> List[Dict[str, Any]]: + """ + 通过用户ID获取用户的所有订单 + + Args: + user_id: 用户ID + + Returns: + 订单列表 + """ + try: + # 查询用户的所有订单 + orders = self.session.query(SubscriptionOrder).filter( + SubscriptionOrder.user_id == user_id + ).order_by(SubscriptionOrder.create_time.desc()).all() + + # 转换为字典列表 + order_list = [] + for order in orders: + order_list.append({ + 'id': order.id, + 'user_id': order.user_id, + 'order_id': order.order_id, + 'payment_order_id': order.payment_order_id, + 'amount': order.amount, + 'member_type': order.member_type, + 'currency': order.currency, + 'time_type': order.time_type, + 'status': order.status, + 'create_time': order.create_time + }) + + return order_list + + except Exception as e: + logger.error(f"获取用户订单列表失败: {e}") + return [] + + def get_order_count(self) -> int: + """ + 获取订单总数 + """ + try: + # 查询订单数量 + order_count = self.session.query(SubscriptionOrder).count() + + return order_count + except Exception as e: + logger.error(f"获取订单数量失败: {e}") + return 0 + + def update_order_status(self, order_id: str, status: int) -> bool: + """ + 更新订单状态 + + Args: + order_id: 订单ID + status: 新的订单状态(1=待支付,2=已完成) + + Returns: + 更新是否成功 + """ + try: + # 验证状态值 + if status not in [1, 2]: + logger.warning(f"无效的订单状态: {status}") + return False + + # 查询订单 + order = self.session.query(SubscriptionOrder).filter( + SubscriptionOrder.order_id == order_id + ).first() + + if not order: + logger.warning(f"订单ID {order_id} 不存在") + return False + + # 更新状态 + order.status = status + self.session.commit() + + status_name = "待支付" if status == 1 else "已完成" + logger.info(f"成功更新订单 {order_id} 状态为: {status_name}") + return True + + except Exception as e: + self.session.rollback() + logger.error(f"更新订单状态失败: {e}") + return False + + def get_orders_by_status(self, status: int) -> List[Dict[str, Any]]: + """ + 根据状态获取订单列表 + + Args: + status: 订单状态(1=待支付,2=已完成) + + Returns: + 订单列表 + """ + try: + # 验证状态值 + if status not in [1, 2]: + logger.warning(f"无效的订单状态: {status}") + return [] + + # 查询指定状态的订单 + orders = self.session.query(SubscriptionOrder).filter( + SubscriptionOrder.status == status + ).order_by(SubscriptionOrder.create_time.desc()).all() + + # 转换为字典列表 + order_list = [] + for order in orders: + order_list.append({ + 'id': order.id, + 'user_id': order.user_id, + 'order_id': order.order_id, + 'payment_order_id': order.payment_order_id, + 'amount': order.amount, + 'member_type': order.member_type, + 'currency': order.currency, + 'time_type': order.time_type, + 'status': order.status, + 'create_time': order.create_time + }) + + return order_list + + except Exception as e: + logger.error(f"获取订单列表失败: {e}") + return [] \ No newline at end of file diff --git a/cryptoai/models/user_subscription.py b/cryptoai/models/user_subscription.py new file mode 100644 index 0000000..a8581ba --- /dev/null +++ b/cryptoai/models/user_subscription.py @@ -0,0 +1,281 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from typing import Dict, Any, List, Optional +from datetime import datetime +from fastapi import HTTPException, status +from sqlalchemy import Column, Integer, String, DateTime, Index +from sqlalchemy.orm import Session +from cryptoai.models.base import Base, logger + +# 定义用户订阅数据模型 +class UserSubscription(Base): + """用户订阅数据表模型""" + __tablename__ = 'user_subscriptions' + + id = Column(Integer, primary_key=True, autoincrement=True) + user_id = Column(Integer, nullable=False, comment='用户ID') + member_type = Column(Integer, nullable=False, comment='会员类型(1=VIP,2=SVIP)') + time_type = Column(Integer, nullable=False, comment='时间类型(1=包月,2=包年)') + sub_order_id = Column(String(100), nullable=False, comment='订阅订单ID') + expire_time = Column(DateTime, nullable=False, comment='过期时间') + create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间') + + # 索引和表属性 + __table_args__ = ( + Index('idx_user_id', 'user_id'), + Index('idx_member_type', 'member_type'), + Index('idx_time_type', 'time_type'), + Index('idx_sub_order_id', 'sub_order_id'), + Index('idx_expire_time', 'expire_time'), + Index('idx_create_time', 'create_time'), + {'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'} + ) + +class UserSubscriptionManager: + """用户订阅管理类""" + + def __init__(self, session: Session = None): + self.session = session + + def create_subscription(self, user_id: int, member_type: int, time_type: int, + sub_order_id: str, expire_time: datetime) -> bool: + """ + 创建用户订阅 + + Args: + user_id: 用户ID + member_type: 会员类型(1=VIP,2=SVIP) + time_type: 时间类型(1=包月,2=包年) + sub_order_id: 订阅订单ID + expire_time: 过期时间 + + Returns: + 创建是否成功 + """ + try: + # 验证参数 + if member_type not in [1, 2]: + logger.warning(f"无效的会员类型: {member_type}") + return False + + if time_type not in [1, 2]: + logger.warning(f"无效的时间类型: {time_type}") + return False + + # 检查订单ID是否已存在 + existing_subscription = self.session.query(UserSubscription).filter( + UserSubscription.sub_order_id == sub_order_id + ).first() + if existing_subscription: + logger.warning(f"订阅订单ID {sub_order_id} 已存在") + return False + + # 创建新订阅 + new_subscription = UserSubscription( + user_id=user_id, + member_type=member_type, + time_type=time_type, + sub_order_id=sub_order_id, + expire_time=expire_time, + create_time=datetime.now() + ) + + # 添加并提交 + self.session.add(new_subscription) + self.session.commit() + + logger.info(f"成功创建用户订阅: 用户 {user_id},订单 {sub_order_id}") + return True + + except Exception as e: + self.session.rollback() + logger.error(f"创建用户订阅失败: {e}") + return False + + def get_subscription_by_user_id(self, user_id: int) -> Optional[Dict[str, Any]]: + """ + 通过用户ID获取用户当前有效订阅 + + Args: + user_id: 用户ID + + Returns: + 订阅信息,如果没有有效订阅则返回None + """ + try: + # 查询用户当前有效的订阅(未过期) + subscription = self.session.query(UserSubscription).filter( + UserSubscription.user_id == user_id, + UserSubscription.expire_time > datetime.now() + ).order_by(UserSubscription.expire_time.desc()).first() + + if subscription: + # 转换为字典 + return { + 'id': subscription.id, + 'user_id': subscription.user_id, + 'member_type': subscription.member_type, + 'time_type': subscription.time_type, + 'sub_order_id': subscription.sub_order_id, + 'expire_time': subscription.expire_time, + 'create_time': subscription.create_time + } + else: + return None + + except Exception as e: + logger.error(f"获取用户订阅失败: {e}") + return None + + def get_all_subscriptions_by_user_id(self, user_id: int) -> List[Dict[str, Any]]: + """ + 通过用户ID获取用户的所有订阅记录 + + Args: + user_id: 用户ID + + Returns: + 订阅列表 + """ + try: + # 查询用户的所有订阅记录 + subscriptions = self.session.query(UserSubscription).filter( + UserSubscription.user_id == user_id + ).order_by(UserSubscription.create_time.desc()).all() + + # 转换为字典列表 + subscription_list = [] + for subscription in subscriptions: + subscription_list.append({ + 'id': subscription.id, + 'user_id': subscription.user_id, + 'member_type': subscription.member_type, + 'time_type': subscription.time_type, + 'sub_order_id': subscription.sub_order_id, + 'expire_time': subscription.expire_time, + 'create_time': subscription.create_time + }) + + return subscription_list + + except Exception as e: + logger.error(f"获取用户订阅列表失败: {e}") + return [] + + def get_subscription_by_order_id(self, sub_order_id: str) -> Optional[Dict[str, Any]]: + """ + 通过订单ID获取订阅信息 + + Args: + sub_order_id: 订阅订单ID + + Returns: + 订阅信息,如果订阅不存在则返回None + """ + try: + # 查询订阅 + subscription = self.session.query(UserSubscription).filter( + UserSubscription.sub_order_id == sub_order_id + ).first() + + if subscription: + # 转换为字典 + return { + 'id': subscription.id, + 'user_id': subscription.user_id, + 'member_type': subscription.member_type, + 'time_type': subscription.time_type, + 'sub_order_id': subscription.sub_order_id, + 'expire_time': subscription.expire_time, + 'create_time': subscription.create_time + } + else: + return None + + except Exception as e: + logger.error(f"获取订阅信息失败: {e}") + return None + + def is_user_subscribed(self, user_id: int) -> bool: + """ + 检查用户是否有有效订阅 + + Args: + user_id: 用户ID + + Returns: + 是否有有效订阅 + """ + try: + # 查询用户是否有未过期的订阅 + subscription = self.session.query(UserSubscription).filter( + UserSubscription.user_id == user_id, + UserSubscription.expire_time > datetime.now() + ).first() + + return subscription is not None + + except Exception as e: + logger.error(f"检查用户订阅状态失败: {e}") + return False + + def get_expired_subscriptions(self) -> List[Dict[str, Any]]: + """ + 获取已过期的订阅列表 + + Returns: + 已过期的订阅列表 + """ + try: + # 查询已过期的订阅 + expired_subscriptions = self.session.query(UserSubscription).filter( + UserSubscription.expire_time <= datetime.now() + ).order_by(UserSubscription.expire_time.desc()).all() + + # 转换为字典列表 + subscription_list = [] + for subscription in expired_subscriptions: + subscription_list.append({ + 'id': subscription.id, + 'user_id': subscription.user_id, + 'member_type': subscription.member_type, + 'time_type': subscription.time_type, + 'sub_order_id': subscription.sub_order_id, + 'expire_time': subscription.expire_time, + 'create_time': subscription.create_time + }) + + return subscription_list + + except Exception as e: + logger.error(f"获取过期订阅列表失败: {e}") + return [] + + def get_subscription_count(self) -> int: + """ + 获取订阅总数 + """ + try: + # 查询订阅数量 + subscription_count = self.session.query(UserSubscription).count() + + return subscription_count + except Exception as e: + logger.error(f"获取订阅数量失败: {e}") + return 0 + + def get_active_subscription_count(self) -> int: + """ + 获取有效订阅数量 + """ + try: + # 查询有效订阅数量 + active_count = self.session.query(UserSubscription).filter( + UserSubscription.expire_time > datetime.now() + ).count() + + return active_count + except Exception as e: + logger.error(f"获取有效订阅数量失败: {e}") + return 0 \ No newline at end of file diff --git a/cryptoai/requirements.txt b/cryptoai/requirements.txt new file mode 100644 index 0000000..7bb88b7 --- /dev/null +++ b/cryptoai/requirements.txt @@ -0,0 +1 @@ +requests>=2.25.1 \ No newline at end of file diff --git a/cryptoai/routes/analysis.py b/cryptoai/routes/analysis.py index 498ac8d..cc6977e 100644 --- a/cryptoai/routes/analysis.py +++ b/cryptoai/routes/analysis.py @@ -13,6 +13,8 @@ from cryptoai.models.token import TokenManager from sqlalchemy.orm import Session from cryptoai.utils.db_manager import get_db from cryptoai.models.user import UserManager +from cryptoai.models.user_subscription import UserSubscriptionManager +from datetime import datetime class AnalysisHistoryRequest(BaseModel): symbol: str @@ -209,8 +211,14 @@ async def chat(request: ChatRequest, session: Session = Depends(get_db)): # 检查用户积分 - if current_user["points"] < 20: - raise HTTPException(status_code=400, detail="您的积分不足,请先充值。") + if current_user["points"] < 1: + raise HTTPException(status_code=400, detail="您的免费次数不足,你可以订阅会员。") + + # 检查用户是否订阅 + user_subscription_manager = UserSubscriptionManager(session) + user_subscription = user_subscription_manager.get_subscription_by_user_id(current_user["id"]) + if not user_subscription or user_subscription["expire_time"] < datetime.now(): + raise HTTPException(status_code=400, detail="您的会员已过期,请续订会员。") payload = { "inputs" : {}, @@ -241,10 +249,8 @@ async def chat(request: ChatRequest, ) # 扣除用户积分 - print(f"current_user: {current_user}") - if current_user['level'] < 2: - manager = UserManager(session) - manager.consume_user_points(current_user["id"], 20) + manager = UserManager(session) + manager.consume_user_points(current_user["id"], 1) # 获取response的stream def stream_response(): diff --git a/cryptoai/routes/fastapi_app.py b/cryptoai/routes/fastapi_app.py index 1538ccf..35b13ed 100644 --- a/cryptoai/routes/fastapi_app.py +++ b/cryptoai/routes/fastapi_app.py @@ -14,6 +14,7 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse import time from typing import Dict, Any +from cryptoai.utils.db_manager import init_db from cryptoai.routes.user import router as user_router from cryptoai.routes.adata import router as adata_router @@ -21,6 +22,7 @@ from cryptoai.routes.crypto import router as crypto_router from cryptoai.routes.platform import router as platform_router from cryptoai.routes.analysis import router as analysis_router from cryptoai.routes.alltick import router as alltick_router +from cryptoai.routes.payment import router as payment_router # 配置日志 logging.basicConfig( level=logging.INFO, @@ -55,6 +57,7 @@ app.include_router(adata_router, prefix="/adata", tags=["A股数据"]) app.include_router(crypto_router, prefix="/crypto", tags=["加密货币数据"]) app.include_router(analysis_router, prefix="/analysis", tags=["分析历史"]) app.include_router(alltick_router, prefix="/alltick", tags=["AllTick数据"]) +app.include_router(payment_router, prefix="/payment", tags=["支付"]) # 请求计时中间件 @app.middleware("http") async def add_process_time_header(request: Request, call_next): @@ -115,4 +118,5 @@ def start(): ) if __name__ == "__main__": + init_db() start() \ No newline at end of file diff --git a/cryptoai/routes/payment.py b/cryptoai/routes/payment.py new file mode 100644 index 0000000..4d764e9 --- /dev/null +++ b/cryptoai/routes/payment.py @@ -0,0 +1,154 @@ +from fastapi import APIRouter +from pydantic import BaseModel +from cryptoai.utils.upay import Upay +from datetime import datetime +from cryptoai.models.subscription_order import SubscriptionOrderManager +import random +from cryptoai.routes.user import get_current_user +from cryptoai.models.user import User +from fastapi import Depends +from sqlalchemy.orm import Session +from cryptoai.utils.db_manager import get_db +from cryptoai.models.user_subscription import UserSubscriptionManager +from datetime import timedelta +import logging + +router = APIRouter() + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +class CreateOrderRequest(BaseModel): + subscribe_type: int + +@router.get('/pricing') +async def pricing(): + return { + "code": 200, + "data": { + "price_month": 29, + "price_year": 219 + } + } + +@router.post("/create_order") +async def create_order(request: CreateOrderRequest, + current_user: User = Depends(get_current_user), + session: Session = Depends(get_db)): + + if current_user["id"] > 2: + return { + "code": 500, + "message": "暂时还没有开放订阅功能" + } + + upay = Upay() + + # 生成唯一的商户订单号 + merchant_order_no = f"D{datetime.now().strftime('%Y%m%d%H%M%S')}{random.randint(100, 999)}" + + ## 1=包月,2=包年 + if request.subscribe_type == 1: + fiat_amount = "29" + product_name = "会员订阅:1个月" + elif request.subscribe_type == 2: + fiat_amount = "219" + product_name = "会员订阅:1年" + else: + return { + "code": 500, + "message": "Invalid subscribe type" + } + + + result = upay.create_order( + merchant_order_no=merchant_order_no, + chain_type="1", ## TRC20 + fiat_amount=fiat_amount, + fiat_currency="USD", + notify_url="https://api.ibtc.work/payment/notify", + product_name=product_name, + redirect_url="https://tradus.vip/subscription-success" + ) + print(result) + + if result['code'] == "1": + payment_order_id = result['data']['orderNo'] + order_id = result['data']['merchantOrderNo'] + + ## 创建订阅记录 + subscription_order_manager = SubscriptionOrderManager(session) + subscription_order_manager.create_order( + order_id=order_id, + user_id=current_user["id"], + payment_order_id=payment_order_id, + amount=float(fiat_amount), + member_type=1, + currency="USD", + time_type=request.subscribe_type, + status=1 + ) + + return { + "code": 200, + "data": { + "order_no": result['data']['orderNo'], + "pay_url": result['data']['payUrl'], + "status": result['data']['status'] + } + } + else: + return { + "code": 500, + "message": result['message'] + } + + +class NotifyRequest(BaseModel): + order_no: str + merchant_order_no: str + status: str + amount: str + currency: str + chain_type: str + +@router.post("/notify") +async def notify(request: NotifyRequest, session: Session = Depends(get_db)): + + try: + # 更新订单状态 + subscription_order_manager = SubscriptionOrderManager(session) + subscription_order_manager.update_order_status(request.merchant_order_no, 2) + + order = subscription_order_manager.get_order_by_id(request.merchant_order_no) + + if order is None: + return { + "code": 500, + "message": "Order not found" + } + + user_id = order['user_id'] + member_type = order['member_type'] + time_type = order['time_type'] + + if time_type == 1: + expire_time = datetime.now() + timedelta(days=30) + elif time_type == 2: + expire_time = datetime.now() + timedelta(days=365) + + #增加用户订阅记录 + user_subscription_manager = UserSubscriptionManager(session) + user_subscription_manager.create_subscription(user_id, + member_type, + time_type, + request.merchant_order_no, + expire_time) + + return { + "code": 200, + "message": "success" + } + except Exception as e: + logger.error(f"创建用户订阅失败: {e}") + raise e \ No newline at end of file diff --git a/cryptoai/routes/user.py b/cryptoai/routes/user.py index 276f61f..1523964 100644 --- a/cryptoai/routes/user.py +++ b/cryptoai/routes/user.py @@ -20,6 +20,7 @@ from sqlalchemy.orm import Session from cryptoai.utils.db_manager import get_db from cryptoai.utils.email_service import get_email_service from cryptoai.models.user import UserManager +from cryptoai.models.user_subscription import UserSubscriptionManager # 配置日志 logger = logging.getLogger("user_router") @@ -64,6 +65,9 @@ class UserResponse(BaseModel): level: int points: int create_time: datetime + is_subscribed: bool + member_name: str + expire_time: datetime = None class TokenResponse(BaseModel): """令牌响应模型""" @@ -292,6 +296,24 @@ async def login(loginData: UserLogin, session: Session = Depends(get_db)) -> Tok # 创建访问令牌,不过期 access_token = create_access_token(data={"sub": user["mail"]}) + user_subscription_manager = UserSubscriptionManager(session) + user_subscription = user_subscription_manager.get_subscription_by_user_id(user["id"]) + + is_subscribed = False + expire_time = None + + if user_subscription: + member_name = "VIP" + expire_time = user_subscription["expire_time"] + if expire_time > datetime.now(): + member_name = "SVIP" + is_subscribed = True + else: + member_name = "VIP" + is_subscribed = True + else: + member_name = "普通会员" + return TokenResponse( access_token=access_token, token_type="bearer", @@ -302,7 +324,10 @@ async def login(loginData: UserLogin, session: Session = Depends(get_db)) -> Tok nickname=user["nickname"], level=user["level"], points=user["points"], - create_time=user["create_time"] + create_time=user["create_time"], + is_subscribed=is_subscribed, + member_name=member_name, + expire_time=expire_time ) ) @@ -326,15 +351,42 @@ async def get_user_info(current_user: Dict[str, Any] = Depends(get_current_user) Returns: 用户信息 """ - return UserResponse( + + user_subscription_manager = UserSubscriptionManager(session) + user_subscription = user_subscription_manager.get_subscription_by_user_id(current_user["id"]) + + is_subscribed = False + expire_time = None + + if user_subscription: + member_name = "VIP" + expire_time = user_subscription["expire_time"] + if expire_time > datetime.now(): + member_name = "SVIP" + is_subscribed = True + else: + member_name = "VIP" + is_subscribed = True + else: + member_name = "普通会员" + + + user = UserResponse( id=current_user["id"], mail=current_user["mail"], nickname=current_user["nickname"], level=current_user["level"], points=current_user["points"], - create_time=current_user["create_time"] + create_time=current_user["create_time"], + is_subscribed=is_subscribed, + member_name=member_name ) + if is_subscribed and expire_time: + user.expire_time = expire_time + + return user + @router.put("/level/{user_id}", response_model=Dict[str, Any]) async def update_user_level( user_id: int, diff --git a/cryptoai/tasks/user.py b/cryptoai/tasks/user.py index bf4e3ed..3fc4672 100644 --- a/cryptoai/tasks/user.py +++ b/cryptoai/tasks/user.py @@ -9,9 +9,9 @@ logger.setLevel(logging.DEBUG) def task_run(): try: session = SessionLocal() - users = session.query(User).filter(User.points < 100).all() + users = session.query(User).filter(User.points < 1).all() for user in users: - user.points = 100 + user.points = 1 session.commit() logger.info(f"用户 {user.mail} 积分复位成功") except Exception as e: diff --git a/cryptoai/utils/db_manager.py b/cryptoai/utils/db_manager.py index 3e982e2..abade56 100644 --- a/cryptoai/utils/db_manager.py +++ b/cryptoai/utils/db_manager.py @@ -11,14 +11,10 @@ from sqlalchemy.orm import sessionmaker, scoped_session from sqlalchemy.pool import QueuePool from cryptoai.utils.config_loader import ConfigLoader -from sqlalchemy.ext.declarative import declarative_base # 配置日志 -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) logger = logging.getLogger('db_manager') +logger.setLevel(logging.DEBUG) config_loader = ConfigLoader() @@ -32,15 +28,31 @@ engine = create_engine(f"mysql+pymysql://{db_config['user']}:{db_config['passwor pool_pre_ping=True, # 在使用连接前先ping一下,确保连接有效 connect_args={'charset': 'utf8mb4'}) -# 声明基类 -Base = declarative_base() -Base.metadata.create_all(bind=engine) - # 创建线程安全的会话工厂 SessionLocal = scoped_session( sessionmaker(autocommit=False, autoflush=False, bind=engine) ) +def init_db(): + try: + # 导入 Base 和所有模型(避免循环导入) + from cryptoai.models.base import Base + from cryptoai.models import ( + User, Token, AnalysisResult, UserQuestion, + AStock, AnalysisHistory, SubscriptionOrder, UserSubscription + ) + + Base.metadata.create_all(bind=engine, checkfirst=True) + logger.info("数据库初始化成功") + + # 输出已创建的表列表 + tables = list(Base.metadata.tables.keys()) + logger.info(f"已创建的数据表: {tables}") + + except Exception as e: + logger.error(f"初始化数据库失败: {e}") + raise e + def get_db(): db = SessionLocal() try: diff --git a/cryptoai/utils/upay.py b/cryptoai/utils/upay.py new file mode 100644 index 0000000..6f80edf --- /dev/null +++ b/cryptoai/utils/upay.py @@ -0,0 +1,107 @@ +import hashlib +import requests +from typing import Dict, Any, Optional + +class Upay: + def __init__(self): + self.app_id="4NgfCm1e" + self.app_secret="4gDTZDXfpKQBboT6" + self.base_url = "https://api-test.upay.ink" + # self.base_url = "https://api.upay.ink" + + def _generate_signature(self, params: Dict[str, Any]) -> str: + """ + 生成签名 + 根据 UPay API 签名算法生成签名 + + 步骤: + 1. 将需要加签的参数按照参数名 ASCII 码从小到大排序(字典序) + 2. 使用 URL 键值对的格式拼接成字符串 stringA + 3. 在 stringA 最后拼接上 &appSecret=密钥 得到 stringSignTemp + 4. 对 stringSignTemp 进行 MD5 运算,再转换为大写 + """ + # 过滤空值和签名字段(signature 参数不参与签名) + filtered_params = {k: v for k, v in params.items() if v is not None and k != 'signature'} + + # 按键名升序排序(字典序) + sorted_params = sorted(filtered_params.items()) + + # 拼接字符串 stringA + string_a = '&'.join([f"{k}={v}" for k, v in sorted_params]) + + # 拼接 appSecret 得到 stringSignTemp + string_sign_temp = f"{string_a}&appSecret={self.app_secret}" + + # MD5 运算并转为大写 + signature = hashlib.md5(string_sign_temp.encode('utf-8')).hexdigest().upper() + + return signature + + def create_order(self, + merchant_order_no: str, + chain_type: str, + fiat_amount: str, + fiat_currency: str, + notify_url: str, + attach: Optional[str] = None, + product_name: Optional[str] = None, + redirect_url: Optional[str] = None) -> Dict[str, Any]: + """ + 创建订单 + + Args: + merchant_order_no: 商户端自主生成的订单号,在商户端要保证唯一性 + chain_type: 链路类型(1: 波场TRC20, 2: 以太坊ERC20, 3: PayPal PYUSD) + fiat_amount: 法币金额,精确到小数点后4位 + fiat_currency: 法币类型(USD, CNY, INR, JPY, KRW, PHP, EUR, GBP, CHF, TWD, HKD, MOP, SGD, NZD, THB, CAD, ZAR, BRL) + notify_url: 接收异步通知的回调地址 + attach: 用户自定义数据(可选) + product_name: 商品名称(可选) + redirect_url: 支付成功后的前端重定向地址(可选) + + Returns: + Dict: API 响应数据 + """ + # 构建请求参数 + sign_params = { + 'appId': self.app_id, + 'merchantOrderNo': merchant_order_no, + 'chainType': chain_type, + 'fiatAmount': fiat_amount, + 'fiatCurrency': fiat_currency, + 'notifyUrl': notify_url, + } + + params = sign_params.copy() + + # 添加可选参数 + if attach is not None: + params['attach'] = attach + if product_name is not None: + params['productName'] = product_name + if redirect_url is not None: + params['redirectUrl'] = redirect_url + + print("请求参数:", params) + + # 生成签名 + signature = self._generate_signature(sign_params) + params['signature'] = signature + + # 发送请求 + url = f"{self.base_url}/v1/api/open/order/apply" + + try: + response = requests.post(url, json=params) + response.raise_for_status() + return response.json() + except requests.exceptions.RequestException as e: + return { + 'success': False, + 'error': f'请求失败: {str(e)}' + } + except Exception as e: + return { + 'success': False, + 'error': f'处理异常: {str(e)}' + } \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml index 1a9d79c..d3e2a03 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -4,7 +4,7 @@ services: cryptoai-task: build: . container_name: cryptoai-task - image: cryptoai:0.0.19 + image: cryptoai:0.0.20 restart: always volumes: - ./cryptoai/data:/app/cryptoai/data @@ -29,7 +29,7 @@ services: cryptoai-api: build: . container_name: cryptoai-api - image: cryptoai-api:0.1.42 + image: cryptoai-api:0.2.0 restart: always ports: - "8000:8000"