357 lines
10 KiB
Python
357 lines
10 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
P0问题修复验证测试 - 简化版(无需导入CryptoAgent)
|
||
直接测试核心逻辑
|
||
"""
|
||
import asyncio
import json
import re
import sys
from datetime import datetime
from pathlib import Path
|
||
|
||
|
||
def test_1_platform_rules(
    file_path=Path("/Users/aaron/source_code/Stock_Agent/backend/app/crypto_agent/crypto_agent.py"),
):
    """Test 1: check that each platform's PLATFORM_RULES entry has max_margin_pct 0.25.

    The crypto_agent source is inspected as plain text. For every platform the
    first occurrence of its quoted key is located. The check passes when
    ``'max_margin_pct': 0.25,`` appears within the next 500 characters.

    Args:
        file_path: Path of the crypto_agent.py source to inspect.

    Returns:
        ``(tests_passed, tests_total)``. An unreadable file counts every
        platform as failed instead of raising.
    """
    print("\n" + "="*60)
    print("测试1: PLATFORM_RULES 配置(max_margin_pct应该是25%)")
    print("="*60)

    platforms = ('Bitget', 'PaperTrading', 'Hyperliquid')
    needle = "'max_margin_pct': 0.25,"
    window = 500  # characters after the platform key treated as its section

    tests_passed = 0
    tests_total = len(platforms)

    try:
        content = Path(file_path).read_text(encoding='utf-8')
    except (OSError, UnicodeDecodeError) as exc:
        # Report instead of crashing so main() can still print its summary.
        print(f"❌ FAIL - 无法读取文件 {file_path}: {exc}")
        return tests_passed, tests_total

    for platform in platforms:
        start = content.find(f"'{platform}'")
        if start == -1:
            print(f"❌ FAIL - 未找到平台配置: {platform}")
            continue
        # Search only this platform's neighbourhood. A file-wide search would
        # pass as soon as *any* platform carried the expected value.
        section = content[start:start + window]
        if needle in section:
            print(f"✅ PASS - {platform} max_margin_pct = 0.25")
            tests_passed += 1
        else:
            print(f"❌ FAIL - {platform} max_margin_pct != 0.25")

    return tests_passed, tests_total
|
||
|
||
|
||
def test_2_position_sizing_logic():
    """Test 2: check A-grade signal sizing against the 25% margin cap.

    Simulates a $1000 account taking an A-grade signal (20% base margin) at
    10x leverage. It verifies, in order:

    - the margin is not clipped by the 25% cap;
    - the position value is $2000;
    - that value equals 200% of the account balance.

    Returns:
        ``(tests_passed, tests_total)``.
    """
    for header in ("\n" + "="*60, "测试2: A级信号仓位计算逻辑", "="*60):
        print(header)

    tests_total = 3
    outcomes = []

    # Scenario: $1000 account, A-grade signal (20%).
    account = dict(
        available=1000.0,
        balance=1000.0,
        current_leverage=0,
        max_leverage=10,
    )
    base_margin_pct = 0.20  # A-grade base margin: 20%
    max_margin_pct = 0.25   # margin cap after the fix: 25%

    requested_margin = account['available'] * base_margin_pct  # $200
    margin_cap = account['balance'] * max_margin_pct           # $250

    not_clipped = requested_margin <= margin_cap
    if not_clipped:
        final_margin = requested_margin
        print(f"✅ PASS - A级信号保证金: ${requested_margin:.2f} (未被截断)")
    else:
        final_margin = margin_cap
        print(f"❌ FAIL - A级信号保证金被截断: ${requested_margin:.2f} → ${margin_cap:.2f}")
    outcomes.append(not_clipped)

    # Position value at 10x: $200 * 10 = $2000.
    leverage = 10
    position_value = final_margin * leverage
    value_ok = position_value == 2000
    tag = "✅ PASS" if value_ok else "❌ FAIL"
    print(f"{tag} - 持仓价值: ${position_value:.2f} (期望 $2000)")
    outcomes.append(value_ok)

    # Position as % of balance: 20% margin * 10x leverage = 200%.
    position_pct = (position_value / account['balance']) * 100
    pct_ok = position_pct == 200
    tag = "✅ PASS" if pct_ok else "❌ FAIL"
    print(f"{tag} - 账户比例: {position_pct:.0f}% (期望 200%)")
    outcomes.append(pct_ok)

    return sum(outcomes), tests_total
|
||
|
||
|
||
def test_3_emergency_close_code(
    file_path=Path("/Users/aaron/source_code/Stock_Agent/backend/app/crypto_agent/crypto_agent.py"),
):
    """Test 3: check crypto_agent.py contains the async-aware emergency close code.

    Performs plain substring checks on the source text for
    ``asyncio.iscoroutinefunction`` and ``await close_method``.

    Args:
        file_path: Path of the crypto_agent.py source to inspect.

    Returns:
        ``(tests_passed, tests_total)``. An unreadable file counts every
        check as failed instead of raising.
    """
    print("\n" + "="*60)
    print("测试3: 紧急平仓async/await处理代码")
    print("="*60)

    tests_passed = 0
    tests_total = 2

    try:
        content = Path(file_path).read_text(encoding='utf-8')
    except (OSError, UnicodeDecodeError) as exc:
        # Report instead of crashing so main() can still print its summary.
        print(f"❌ FAIL - 无法读取文件 {file_path}: {exc}")
        return tests_passed, tests_total

    # The close method must be checked for being a coroutine function...
    if "asyncio.iscoroutinefunction" in content:
        print("✅ PASS - 包含 asyncio.iscoroutinefunction 检查")
        tests_passed += 1
    else:
        print("❌ FAIL - 缺少 asyncio.iscoroutinefunction 检查")

    # ...and awaited when it is one.
    if "await close_method" in content:
        print("✅ PASS - 包含 await close_method 调用")
        tests_passed += 1
    else:
        print("❌ FAIL - 缺少 await close_method 调用")

    return tests_passed, tests_total
|
||
|
||
|
||
def test_4_initial_balance_methods(
    file_path=Path("/Users/aaron/source_code/Stock_Agent/backend/app/crypto_agent/crypto_agent.py"),
):
    """Test 4: check crypto_agent.py defines the initial-balance persistence members.

    Each name must appear in the source as a whole identifier. A bare
    substring test would let ``_initial_balances`` match inside
    ``_load_initial_balances``, so that check could never fail.

    Args:
        file_path: Path of the crypto_agent.py source to inspect.

    Returns:
        ``(tests_passed, tests_total)``. An unreadable file counts every
        name as failed instead of raising.
    """
    print("\n" + "="*60)
    print("测试4: 初始余额持久化方法")
    print("="*60)

    names = [
        '_load_initial_balances',
        '_save_initial_balances',
        '_get_initial_balance',
        '_initial_balances',
    ]

    tests_passed = 0
    tests_total = len(names)

    try:
        content = Path(file_path).read_text(encoding='utf-8')
    except (OSError, UnicodeDecodeError) as exc:
        # Report instead of crashing so main() can still print its summary.
        print(f"❌ FAIL - 无法读取文件 {file_path}: {exc}")
        return tests_passed, tests_total

    for name in names:
        # Identifier must not be preceded or followed by another word character.
        pattern = rf"(?<!\w){re.escape(name)}(?!\w)"
        if re.search(pattern, content):
            print(f"✅ PASS - 找到方法/属性: {name}")
            tests_passed += 1
        else:
            print(f"❌ FAIL - 未找到方法/属性: {name}")

    return tests_passed, tests_total
|
||
|
||
|
||
def test_5_drawdown_calculation():
    """Test 5: verify drawdown = (initial - current) / initial on sample cases.

    A case passes when the computed drawdown is within 0.001 of the
    expected fraction.

    Returns:
        ``(tests_passed, tests_total)``.
    """
    print("\n" + "="*60)
    print("测试5: 回撤计算(各种场景)")
    print("="*60)

    # (initial balance, current balance, expected drawdown fraction, label)
    cases = (
        (10000, 10000, 0.0, "无回撤"),
        (10000, 8500, 0.15, "15%回撤(警告)"),
        (10000, 7500, 0.25, "25%回撤(止损)"),
        (10000, 5000, 0.50, "50%回撤(严重)"),
    )

    tests_passed = 0
    tests_total = len(cases)

    for initial, current, expected, desc in cases:
        drawdown = (initial - current) / initial
        if not abs(drawdown - expected) < 0.001:
            print(f"❌ FAIL - {desc}: 期望 {expected*100:.1f}%, 实际 {drawdown*100:.1f}%")
            continue
        print(f"✅ PASS - {desc}: {drawdown*100:.1f}%")
        tests_passed += 1

    return tests_passed, tests_total
|
||
|
||
|
||
async def test_6_async_await():
    """Test 6: sanity-check sync vs. async callable detection and invocation.

    Defines a plain function and a coroutine function as local stand-ins and
    checks that ``asyncio.iscoroutinefunction`` tells them apart. It then
    calls the sync one directly and awaits the async one, checking each
    returned dict's ``'success'`` flag.

    Returns:
        ``(tests_passed, tests_total)``.
    """
    print("\n" + "="*60)
    print("测试6: async/await机制验证")
    print("="*60)

    tests_total = 4
    outcomes = []

    # Mock synchronous close method.
    def sync_close(symbol: str):
        return {'success': True, 'symbol': symbol}

    # Mock asynchronous close method.
    async def async_close(symbol: str):
        await asyncio.sleep(0.01)
        return {'success': True, 'symbol': symbol}

    # NOTE(review): asyncio.iscoroutinefunction is deprecated since Python 3.14
    # in favour of inspect.iscoroutinefunction; it is kept here because it is
    # the exact call test_3_emergency_close_code looks for in crypto_agent.py.
    sync_flag = asyncio.iscoroutinefunction(sync_close)
    async_flag = asyncio.iscoroutinefunction(async_close)

    outcomes.append(not sync_flag)
    if outcomes[-1]:
        print(f"✅ PASS - sync_close 不是协程: {sync_flag}")
    else:
        print(f"❌ FAIL - sync_close 是协程: {sync_flag}")

    outcomes.append(bool(async_flag))
    if outcomes[-1]:
        print(f"✅ PASS - async_close 是协程: {async_flag}")
    else:
        print(f"❌ FAIL - async_close 不是协程: {async_flag}")

    # Actually invoke both flavours.
    sync_result = sync_close("BTC")
    outcomes.append(bool(sync_result['success']))
    if outcomes[-1]:
        print(f"✅ PASS - 同步调用成功: {sync_result}")
    else:
        print("❌ FAIL - 同步调用失败")

    async_result = await async_close("ETH")
    outcomes.append(bool(async_result['success']))
    if outcomes[-1]:
        print(f"✅ PASS - 异步调用成功: {async_result}")
    else:
        print("❌ FAIL - 异步调用失败")

    return sum(outcomes), tests_total
|
||
|
||
|
||
def test_7_base_executor_notifications(
    file_path=Path("/Users/aaron/source_code/Stock_Agent/backend/app/crypto_agent/executor/base_executor.py"),
):
    """Test 7: check base_executor.py defines the notification methods.

    Each method name must appear as a whole identifier. This keeps
    ``send_execution_notification`` from matching inside a longer name such
    as ``_send_execution_notification``.

    Args:
        file_path: Path of the base_executor.py source to inspect.

    Returns:
        ``(tests_passed, tests_total)``. An unreadable file counts every
        method as failed instead of raising.
    """
    print("\n" + "="*60)
    print("测试7: BaseExecutor飞书通知功能")
    print("="*60)

    methods = [
        'send_execution_notification',
        '_send_open_notification',
        '_send_close_notification',
    ]

    tests_passed = 0
    tests_total = len(methods)

    try:
        content = Path(file_path).read_text(encoding='utf-8')
    except (OSError, UnicodeDecodeError) as exc:
        # Report instead of crashing so main() can still print its summary.
        print(f"❌ FAIL - 无法读取文件 {file_path}: {exc}")
        return tests_passed, tests_total

    for method in methods:
        # Identifier must not be preceded or followed by another word character.
        pattern = rf"(?<!\w){re.escape(method)}(?!\w)"
        if re.search(pattern, content):
            print(f"✅ PASS - 找到通知方法: {method}")
            tests_passed += 1
        else:
            print(f"❌ FAIL - 未找到通知方法: {method}")

    return tests_passed, tests_total
|
||
|
||
|
||
async def main():
    """Run all P0 verification checks in order and print a summary.

    Each check returns ``(passed, total)``. ``test_6_async_await`` is a
    coroutine function, so its result is awaited; the others are used
    directly.

    Returns:
        True if every individual check passed, otherwise False.
    """
    rocket_line = "🚀 " * 30
    print("\n" + rocket_line)
    print("P0问题修复验证测试 - 简化版")
    print("测试时间:", datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
    print(rocket_line)

    checks = (
        test_1_platform_rules,
        test_2_position_sizing_logic,
        test_3_emergency_close_code,
        test_4_initial_balance_methods,
        test_5_drawdown_calculation,
        test_6_async_await,
        test_7_base_executor_notifications,
    )

    total_passed = 0
    total_tests = 0
    for check in checks:
        result = check()
        # Coroutine functions hand back a coroutine object that must be awaited.
        if asyncio.iscoroutine(result):
            result = await result
        passed, total = result
        total_passed += passed
        total_tests += total

    failed = total_tests - total_passed
    pass_rate = (total_passed / total_tests * 100) if total_tests > 0 else 0

    # Summary
    print("\n" + "="*60)
    print("📊 测试总结")
    print("="*60)
    print(f"总测试数: {total_tests}")
    print(f"通过: {total_passed} ✅")
    print(f"失败: {failed} ❌")
    print(f"通过率: {pass_rate:.1f}%")

    if failed:
        print(f"\n⚠️ 有 {failed} 个测试失败,请检查!")
        return False

    print("\n🎉 所有测试通过!P0问题修复验证成功!")
    return True
|
||
|
||
|
||
if __name__ == "__main__":
    # Run the async suite and map success to the process exit status.
    # sys.exit is used because the bare exit() builtin is injected by the
    # `site` module for interactive use and is absent under `python -S`.
    success = asyncio.run(main())
    sys.exit(0 if success else 1)
|