stock-ai-agent/backend/test_p0_fixes_simple.py
2026-03-28 22:56:16 +08:00

357 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
P0问题修复验证测试 - 简化版无需导入CryptoAgent
直接测试核心逻辑
"""
import asyncio
import json
from pathlib import Path
from datetime import datetime
def _platform_sections(content, platform, platforms, window=500):
    """Yield every span of ``content`` that starts at a ``'<platform>':`` key.

    Each span is at most ``window`` characters long. It stops early at the
    next ``'<other>':`` key for any other name in ``platforms``, so one
    platform's settings are never credited to its neighbour.
    """
    import re

    def key_pattern(name):
        # Dict-key form of the name, in either quote style.
        return re.compile(rf"""['"]{re.escape(name)}['"]\s*:""")

    others = [key_pattern(name) for name in platforms if name != platform]
    for key in key_pattern(platform).finditer(content):
        end = key.start() + window
        for pattern in others:
            following = pattern.search(content, key.end())
            if following is not None:
                end = min(end, following.start())
        yield content[key.start():end]


def test_1_platform_rules(file_path=None):
    """Test 1: every platform in PLATFORM_RULES must cap max_margin_pct at 25%.

    Reads crypto_agent.py as plain text (CryptoAgent is never imported). For
    each of Bitget, PaperTrading and Hyperliquid, it looks for
    ``'max_margin_pct': 0.25`` inside that platform's own section only. This
    keeps one platform's value from satisfying another platform's check.

    Args:
        file_path: crypto_agent.py to inspect. Defaults to
            ``app/crypto_agent/crypto_agent.py`` next to this script.

    Returns:
        (tests_passed, tests_total). An unreadable file, or a platform key
        that cannot be found, counts as a failed check instead of raising.
    """
    import re

    print("\n" + "="*60)
    print("测试1: PLATFORM_RULES 配置max_margin_pct应该是25%")
    print("="*60)

    platforms = ('Bitget', 'PaperTrading', 'Hyperliquid')
    tests_passed = 0
    tests_total = len(platforms)

    if file_path is None:
        file_path = (Path(__file__).resolve().parent
                     / "app" / "crypto_agent" / "crypto_agent.py")
    try:
        content = Path(file_path).read_text(encoding='utf-8')
    except OSError as exc:
        print(f"❌ FAIL - 无法读取 {file_path}: {exc}")
        return tests_passed, tests_total

    # Tolerates spacing, quote style and a missing trailing comma, but
    # requires exactly the literal 0.25 (0.255 etc. do not match).
    value_pattern = re.compile(r"""['"]max_margin_pct['"]\s*:\s*0\.25(?!\d)""")

    for platform in platforms:
        sections = list(_platform_sections(content, platform, platforms))
        if not sections:
            # Previously a missing platform was neither passed nor failed.
            print(f"❌ FAIL - 未找到平台配置: {platform}")
        elif any(value_pattern.search(section) for section in sections):
            print(f"✅ PASS - {platform} max_margin_pct = 0.25")
            tests_passed += 1
        else:
            print(f"❌ FAIL - {platform} max_margin_pct != 0.25")
    return tests_passed, tests_total


# Driven by main() and reports via its return value, not asserts; keep pytest
# from collecting it when it scans this test_*.py file.
test_1_platform_rules.__test__ = False
def test_2_position_sizing_logic():
    """Test 2: an A-grade signal's margin must not be cut by the 25% cap.

    Re-computes the sizing arithmetic locally; nothing from CryptoAgent is
    called. The scenario is a $1000 account, 20% base margin, 25% max margin
    and 10x leverage. The test checks three things:

    - the $200 margin is not truncated;
    - the position is worth $2000;
    - the position equals 200% of the balance.

    Returns:
        (tests_passed, tests_total)
    """
    import math

    print("\n" + "="*60)
    print("测试2: A级信号仓位计算逻辑")
    print("="*60)
    tests_passed = 0
    tests_total = 3

    # Scenario: $1000 account, A-grade signal (20%)
    account = {
        'available': 1000.0,
        'balance': 1000.0,
        'current_leverage': 0,
        'max_leverage': 10,
    }
    base_margin_pct = 0.20  # A-grade signal margin
    max_margin_pct = 0.25   # platform cap after the P0 fix

    margin = account['available'] * base_margin_pct   # $200
    max_margin = account['balance'] * max_margin_pct  # $250
    if margin <= max_margin:
        final_margin = margin
        print(f"✅ PASS - A级信号保证金: ${margin:.2f} (未被截断)")
        tests_passed += 1
    else:
        final_margin = max_margin
        print(f"❌ FAIL - A级信号保证金被截断: ${margin:.2f} → ${max_margin:.2f}")

    leverage = account['max_leverage']  # 10x
    position_value = final_margin * leverage
    # Compare floats with a tolerance: exact == only holds while these
    # particular constants happen to multiply out exactly in binary.
    if math.isclose(position_value, 2000):  # $200 * 10x = $2000
        print(f"✅ PASS - 持仓价值: ${position_value:.2f} (期望 $2000)")
        tests_passed += 1
    else:
        print(f"❌ FAIL - 持仓价值: ${position_value:.2f} (期望 $2000)")

    position_pct = (position_value / account['balance']) * 100
    if math.isclose(position_pct, 200):  # 20% margin * 10x leverage
        print(f"✅ PASS - 账户比例: {position_pct:.0f}% (期望 200%)")
        tests_passed += 1
    else:
        print(f"❌ FAIL - 账户比例: {position_pct:.0f}% (期望 200%)")
    return tests_passed, tests_total


# Driven by main() and reports via its return value, not asserts; keep pytest
# from collecting it when it scans this test_*.py file.
test_2_position_sizing_logic.__test__ = False
def test_3_emergency_close_code(file_path=None):
    """Test 3: emergency close must detect and await async close methods.

    Scans crypto_agent.py as text for two things:

    - a coroutine-function check (``asyncio.`` or ``inspect.iscoroutinefunction``);
    - an ``await close_method`` call.

    Args:
        file_path: crypto_agent.py to inspect. Defaults to
            ``app/crypto_agent/crypto_agent.py`` next to this script.

    Returns:
        (tests_passed, tests_total). An unreadable file counts as both
        checks failed instead of raising.
    """
    import re

    print("\n" + "="*60)
    print("测试3: 紧急平仓async/await处理代码")
    print("="*60)
    tests_passed = 0
    tests_total = 2

    if file_path is None:
        file_path = (Path(__file__).resolve().parent
                     / "app" / "crypto_agent" / "crypto_agent.py")
    try:
        content = Path(file_path).read_text(encoding='utf-8')
    except OSError as exc:
        print(f"❌ FAIL - 无法读取 {file_path}: {exc}")
        return tests_passed, tests_total

    # asyncio.iscoroutinefunction is deprecated (Python 3.14) in favour of
    # inspect.iscoroutinefunction; either one satisfies the check.
    checker = next(
        (name for name in ("asyncio.iscoroutinefunction",
                           "inspect.iscoroutinefunction")
         if name in content),
        None,
    )
    if checker is not None:
        print(f"✅ PASS - 包含 {checker} 检查")
        tests_passed += 1
    else:
        print("❌ FAIL - 缺少 asyncio.iscoroutinefunction 检查")

    # Whole identifier only, with any whitespace after `await`.
    if re.search(r"\bawait\s+close_method\b", content):
        print("✅ PASS - 包含 await close_method 调用")
        tests_passed += 1
    else:
        print("❌ FAIL - 缺少 await close_method 调用")
    return tests_passed, tests_total


# Driven by main() and reports via its return value, not asserts; keep pytest
# from collecting it when it scans this test_*.py file.
test_3_emergency_close_code.__test__ = False
def test_4_initial_balance_methods(file_path=None):
    """Test 4: initial-balance persistence helpers must exist in crypto_agent.py.

    Looks for each method/attribute name as a whole identifier. A plain
    substring test would let ``_initial_balances`` pass merely because
    ``_load_initial_balances`` exists.

    Args:
        file_path: crypto_agent.py to inspect. Defaults to
            ``app/crypto_agent/crypto_agent.py`` next to this script.

    Returns:
        (tests_passed, tests_total). An unreadable file counts as every
        check failed instead of raising.
    """
    import re

    print("\n" + "="*60)
    print("测试4: 初始余额持久化方法")
    print("="*60)

    methods = [
        '_load_initial_balances',
        '_save_initial_balances',
        '_get_initial_balance',
        '_initial_balances',
    ]
    tests_passed = 0
    tests_total = len(methods)

    if file_path is None:
        file_path = (Path(__file__).resolve().parent
                     / "app" / "crypto_agent" / "crypto_agent.py")
    try:
        content = Path(file_path).read_text(encoding='utf-8')
    except OSError as exc:
        print(f"❌ FAIL - 无法读取 {file_path}: {exc}")
        return tests_passed, tests_total

    for method in methods:
        if re.search(rf"\b{re.escape(method)}\b", content):
            print(f"✅ PASS - 找到方法/属性: {method}")
            tests_passed += 1
        else:
            print(f"❌ FAIL - 未找到方法/属性: {method}")
    return tests_passed, tests_total


# Driven by main() and reports via its return value, not asserts; keep pytest
# from collecting it when it scans this test_*.py file.
test_4_initial_balance_methods.__test__ = False
def test_5_drawdown_calculation():
    """Test 5: drawdown = (initial - current) / initial across sample scenarios.

    Each scenario's computed drawdown must match its expected value to
    within 0.001.

    Returns:
        (tests_passed, tests_total), where tests_total is the number of
        scenarios.
    """
    print("\n" + "="*60)
    print("测试5: 回撤计算(各种场景)")
    print("="*60)

    scenarios = [
        {"initial": 10000, "current": 10000, "expected_dd": 0.0, "desc": "无回撤"},
        {"initial": 10000, "current": 8500, "expected_dd": 0.15, "desc": "15%回撤(警告)"},
        {"initial": 10000, "current": 7500, "expected_dd": 0.25, "desc": "25%回撤(止损)"},
        {"initial": 10000, "current": 5000, "expected_dd": 0.50, "desc": "50%回撤(严重)"},
    ]
    tests_passed = 0
    tests_total = len(scenarios)

    for scenario in scenarios:
        initial = scenario['initial']
        expected = scenario['expected_dd']
        drawdown = (initial - scenario['current']) / initial
        if abs(drawdown - expected) < 0.001:
            print(f"✅ PASS - {scenario['desc']}: {drawdown*100:.1f}%")
            tests_passed += 1
        else:
            print(f"❌ FAIL - {scenario['desc']}: 期望 {expected*100:.1f}%, 实际 {drawdown*100:.1f}%")
    return tests_passed, tests_total


# Driven by main() and reports via its return value, not asserts; keep pytest
# from collecting it when it scans this test_*.py file.
test_5_drawdown_calculation.__test__ = False
async def test_6_async_await():
    """Test 6: sanity-check the sync/async dispatch that emergency close relies on.

    Uses two local mocks, one plain function and one coroutine function. It
    checks that ``asyncio.iscoroutinefunction`` tells them apart, and that
    calling the first and awaiting the second both return success.

    Returns:
        (tests_passed, tests_total)
    """
    print("\n" + "="*60)
    print("测试6: async/await机制验证")
    print("="*60)
    tests_passed = 0
    tests_total = 4

    def sync_close(symbol: str):
        # Mock synchronous close method
        return {'success': True, 'symbol': symbol}

    async def async_close(symbol: str):
        # Mock asynchronous close method
        await asyncio.sleep(0.01)
        return {'success': True, 'symbol': symbol}

    is_sync = asyncio.iscoroutinefunction(sync_close)
    is_async = asyncio.iscoroutinefunction(async_close)

    if not is_sync:
        print(f"✅ PASS - sync_close 不是协程: {is_sync}")
        tests_passed += 1
    else:
        print(f"❌ FAIL - sync_close 是协程: {is_sync}")

    if is_async:
        print(f"✅ PASS - async_close 是协程: {is_async}")
        tests_passed += 1
    else:
        print(f"❌ FAIL - async_close 不是协程: {is_async}")

    result_sync = sync_close("BTC")
    if result_sync['success']:
        print(f"✅ PASS - 同步调用成功: {result_sync}")
        tests_passed += 1
    else:
        print("❌ FAIL - 同步调用失败")

    result_async = await async_close("ETH")
    if result_async['success']:
        print(f"✅ PASS - 异步调用成功: {result_async}")
        tests_passed += 1
    else:
        print("❌ FAIL - 异步调用失败")
    return tests_passed, tests_total


# Driven (and awaited) by main(); pytest cannot run a bare coroutine test
# without a plugin and would ignore the returned tuple anyway, so opt out of
# collection.
test_6_async_await.__test__ = False
def test_7_base_executor_notifications(file_path=None):
    """Test 7: BaseExecutor must define its Feishu notification methods.

    Looks for each method name as a whole identifier in base_executor.py. A
    plain substring test would accept ``_send_execution_notification`` as
    ``send_execution_notification``.

    Args:
        file_path: base_executor.py to inspect. Defaults to
            ``app/crypto_agent/executor/base_executor.py`` next to this script.

    Returns:
        (tests_passed, tests_total). An unreadable file counts as every
        check failed instead of raising.
    """
    import re

    print("\n" + "="*60)
    print("测试7: BaseExecutor飞书通知功能")
    print("="*60)

    methods = [
        'send_execution_notification',
        '_send_open_notification',
        '_send_close_notification',
    ]
    tests_passed = 0
    tests_total = len(methods)

    if file_path is None:
        file_path = (Path(__file__).resolve().parent
                     / "app" / "crypto_agent" / "executor" / "base_executor.py")
    try:
        content = Path(file_path).read_text(encoding='utf-8')
    except OSError as exc:
        print(f"❌ FAIL - 无法读取 {file_path}: {exc}")
        return tests_passed, tests_total

    for method in methods:
        if re.search(rf"\b{re.escape(method)}\b", content):
            print(f"✅ PASS - 找到通知方法: {method}")
            tests_passed += 1
        else:
            print(f"❌ FAIL - 未找到通知方法: {method}")
    return tests_passed, tests_total


# Driven by main() and reports via its return value, not asserts; keep pytest
# from collecting it when it scans this test_*.py file.
test_7_base_executor_notifications.__test__ = False
async def main():
    """Run every P0 check in order, print a summary and report overall success.

    Each check returns ``(passed, total)``; coroutine checks are awaited.

    Returns:
        True when every individual check passed, otherwise False.
    """
    banner = "🚀 " * 30
    print("\n" + banner)
    print("P0问题修复验证测试 - 简化版")
    print("测试时间:", datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
    print(banner)

    # Executed strictly in this order.
    suite = (
        test_1_platform_rules,
        test_2_position_sizing_logic,
        test_3_emergency_close_code,
        test_4_initial_balance_methods,
        test_5_drawdown_calculation,
        test_6_async_await,
        test_7_base_executor_notifications,
    )
    total_passed = total_tests = 0
    for check in suite:
        outcome = check()
        if asyncio.iscoroutine(outcome):
            outcome = await outcome
        passed, total = outcome
        total_passed += passed
        total_tests += total

    failed = total_tests - total_passed
    pass_rate = total_passed / total_tests * 100 if total_tests > 0 else 0
    divider = "=" * 60
    print("\n" + divider)
    print("📊 测试总结")
    print(divider)
    print(f"总测试数: {total_tests}")
    print(f"通过: {total_passed}")
    print(f"失败: {failed}")
    print(f"通过率: {pass_rate:.1f}%")

    if total_passed != total_tests:
        print(f"\n⚠️ 有 {failed} 个测试失败,请检查!")
        return False
    print("\n🎉 所有测试通过P0问题修复验证成功")
    return True
if __name__ == "__main__":
    import sys

    # Exit status 0 only when every check passed, so shells/CI can gate on it.
    # sys.exit, not the site-injected exit(), which is meant for the
    # interactive interpreter and is absent under `python -S`.
    success = asyncio.run(main())
    sys.exit(0 if success else 1)