340 lines
11 KiB
Python
340 lines
11 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
P0问题修复验证测试
|
||
测试3个严重问题的修复情况
|
||
"""
|
||
import sys
|
||
import os
|
||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
||
import asyncio
|
||
import json
|
||
from pathlib import Path
|
||
from datetime import datetime
|
||
from typing import Dict, Any
|
||
|
||
|
||
class TestP0Fixes:
|
||
"""P0问题修复测试类"""
|
||
|
||
def __init__(self):
|
||
self.test_results = []
|
||
self.passed = 0
|
||
self.failed = 0
|
||
|
||
def log(self, test_name: str, passed: bool, message: str):
|
||
"""记录测试结果"""
|
||
status = "✅ PASS" if passed else "❌ FAIL"
|
||
result = f"{status} - {test_name}: {message}"
|
||
self.test_results.append(result)
|
||
|
||
if passed:
|
||
self.passed += 1
|
||
else:
|
||
self.failed += 1
|
||
|
||
print(result)
|
||
|
||
def test_1_max_margin_pct(self):
|
||
"""测试1: 配置不一致修复(max_margin_pct 应该是 25%)"""
|
||
print("\n" + "="*60)
|
||
print("测试1: max_margin_pct 配置(应该支持25%)")
|
||
print("="*60)
|
||
|
||
try:
|
||
from app.crypto_agent.crypto_agent import CryptoAgent
|
||
|
||
agent = CryptoAgent()
|
||
|
||
# 检查 Bitget
|
||
bitget_max = agent.PLATFORM_RULES['Bitget']['max_margin_pct']
|
||
self.log(
|
||
"Bitget max_margin_pct",
|
||
bitget_max == 0.25,
|
||
f"期望 0.25, 实际 {bitget_max}"
|
||
)
|
||
|
||
# 检查 PaperTrading
|
||
paper_max = agent.PLATFORM_RULES['PaperTrading']['max_margin_pct']
|
||
self.log(
|
||
"PaperTrading max_margin_pct",
|
||
paper_max == 0.25,
|
||
f"期望 0.25, 实际 {paper_max}"
|
||
)
|
||
|
||
# 检查 Hyperliquid
|
||
hyper_max = agent.PLATFORM_RULES['Hyperliquid']['max_margin_pct']
|
||
self.log(
|
||
"Hyperliquid max_margin_pct",
|
||
hyper_max == 0.25,
|
||
f"期望 0.25, 实际 {hyper_max}"
|
||
)
|
||
|
||
except Exception as e:
|
||
self.log("max_margin_pct 测试", False, f"异常: {e}")
|
||
|
||
def test_2_position_sizing(self):
|
||
"""测试2: A级信号应该能使用20%保证金(不被截断到10%)"""
|
||
print("\n" + "="*60)
|
||
print("测试2: A级信号仓位计算($1000账户应该能用$200保证金)")
|
||
print("="*60)
|
||
|
||
try:
|
||
from app.crypto_agent.crypto_agent import CryptoAgent
|
||
|
||
agent = CryptoAgent()
|
||
|
||
# 模拟A级信号
|
||
signal = {
|
||
'symbol': 'BTCUSDT',
|
||
'confidence': 92, # A级
|
||
}
|
||
|
||
# 模拟账户
|
||
account = {
|
||
'available': 1000.0,
|
||
'current_balance': 1000.0,
|
||
'current_total_leverage': 0,
|
||
'max_total_leverage': 10,
|
||
}
|
||
|
||
# 计算仓位
|
||
margin, reason = agent._calculate_position_size(signal, account, 'PaperTrading')
|
||
|
||
# A级信号期望 20% = $200
|
||
expected_min = 190 # 考虑一些边界情况,至少应该接近$200
|
||
expected_max = 210
|
||
|
||
is_correct = expected_min <= margin <= expected_max
|
||
|
||
self.log(
|
||
"A级信号保证金计算",
|
||
is_correct,
|
||
f"期望 ${expected_min}-${expected_max}, 实际 ${margin:.2f} ({reason})"
|
||
)
|
||
|
||
except Exception as e:
|
||
self.log("仓位计算测试", False, f"异常: {e}")
|
||
|
||
def test_3_initial_balance_persistence(self):
|
||
"""测试3: 初始余额持久化机制"""
|
||
print("\n" + "="*60)
|
||
print("测试3: 初始余额持久化(应该保存到data/initial_balances.json)")
|
||
print("="*60)
|
||
|
||
try:
|
||
from app.crypto_agent.crypto_agent import CryptoAgent
|
||
|
||
# 清理测试文件
|
||
test_file = Path("data/initial_balances.json")
|
||
if test_file.exists():
|
||
test_file.unlink()
|
||
|
||
# 创建新的agent实例
|
||
agent = CryptoAgent()
|
||
|
||
# 第一次获取(应该记录)
|
||
platform_name = "TestPlatform"
|
||
current_balance = 10000.0
|
||
|
||
initial_1 = agent._get_initial_balance(platform_name, current_balance)
|
||
|
||
self.log(
|
||
"首次获取初始余额",
|
||
initial_1 == current_balance,
|
||
f"期望 {current_balance}, 实际 {initial_1}"
|
||
)
|
||
|
||
# 检查文件是否创建
|
||
self.log(
|
||
"持久化文件创建",
|
||
test_file.exists(),
|
||
f"文件存在: {test_file.exists()}"
|
||
)
|
||
|
||
# 检查文件内容
|
||
with open(test_file, 'r') as f:
|
||
saved_data = json.load(f)
|
||
|
||
self.log(
|
||
"持久化内容正确",
|
||
saved_data.get(platform_name) == current_balance,
|
||
f"文件中 {platform_name}: {saved_data.get(platform_name)}"
|
||
)
|
||
|
||
# 模拟余额变化,再次获取(应该返回初始值)
|
||
current_balance_2 = 9000.0 # 亏损10%
|
||
initial_2 = agent._get_initial_balance(platform_name, current_balance_2)
|
||
|
||
self.log(
|
||
"再次获取初始余额(余额变化后)",
|
||
initial_2 == current_balance, # 仍然是初始值
|
||
f"期望 {current_balance}, 实际 {initial_2}"
|
||
)
|
||
|
||
# 计算回撤
|
||
drawdown = (initial_2 - current_balance_2) / initial_2
|
||
expected_drawdown = 0.10 # 10%
|
||
|
||
self.log(
|
||
"回撤计算正确",
|
||
abs(drawdown - expected_drawdown) < 0.001,
|
||
f"期望 {expected_drawdown*100:.1f}%, 实际 {drawdown*100:.1f}%"
|
||
)
|
||
|
||
except Exception as e:
|
||
self.log("初始余额持久化测试", False, f"异常: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
|
||
async def test_4_emergency_close_await(self):
|
||
"""测试4: 紧急平仓async/await处理"""
|
||
print("\n" + "="*60)
|
||
print("测试4: 紧急平仓async/await检测")
|
||
print("="*60)
|
||
|
||
try:
|
||
import asyncio
|
||
|
||
# Mock同步方法
|
||
def sync_close(symbol: str):
|
||
return {'success': True, 'symbol': symbol}
|
||
|
||
# Mock异步方法
|
||
async def async_close(symbol: str):
|
||
await asyncio.sleep(0.01)
|
||
return {'success': True, 'symbol': symbol}
|
||
|
||
# 测试asyncio.iscoroutinefunction
|
||
is_sync = asyncio.iscoroutinefunction(sync_close)
|
||
is_async = asyncio.iscoroutinefunction(async_close)
|
||
|
||
self.log(
|
||
"同步方法检测",
|
||
not is_sync,
|
||
f"sync_close 是协程: {is_sync} (应该是False)"
|
||
)
|
||
|
||
self.log(
|
||
"异步方法检测",
|
||
is_async,
|
||
f"async_close 是协程: {is_async} (应该是True)"
|
||
)
|
||
|
||
# 测试调用
|
||
try:
|
||
# 同步调用
|
||
result_sync = sync_close("BTC")
|
||
self.log(
|
||
"同步方法调用",
|
||
result_sync['success'],
|
||
f"结果: {result_sync}"
|
||
)
|
||
|
||
# 异步调用
|
||
result_async = await async_close("ETH")
|
||
self.log(
|
||
"异步方法调用",
|
||
result_async['success'],
|
||
f"结果: {result_async}"
|
||
)
|
||
|
||
except Exception as e:
|
||
self.log("方法调用测试", False, f"异常: {e}")
|
||
|
||
except Exception as e:
|
||
self.log("紧急平仓await测试", False, f"异常: {e}")
|
||
|
||
def test_5_account_drawdown_calculation(self):
|
||
"""测试5: 账户回撤计算"""
|
||
print("\n" + "="*60)
|
||
print("测试5: 账户回撤计算(模拟各种场景)")
|
||
print("="*60)
|
||
|
||
try:
|
||
from app.crypto_agent.crypto_agent import CryptoAgent
|
||
|
||
# 测试场景
|
||
scenarios = [
|
||
{"initial": 10000, "current": 10000, "expected_dd": 0.0, "desc": "无回撤"},
|
||
{"initial": 10000, "current": 8500, "expected_dd": 0.15, "desc": "15%回撤(警告)"},
|
||
{"initial": 10000, "current": 7500, "expected_dd": 0.25, "desc": "25%回撤(止损)"},
|
||
{"initial": 10000, "current": 5000, "expected_dd": 0.50, "desc": "50%回撤(严重)"},
|
||
]
|
||
|
||
agent = CryptoAgent()
|
||
|
||
for scenario in scenarios:
|
||
initial = scenario['initial']
|
||
current = scenario['current']
|
||
expected = scenario['expected_dd']
|
||
|
||
# 模拟计算
|
||
drawdown = (initial - current) / initial
|
||
|
||
is_correct = abs(drawdown - expected) < 0.001
|
||
|
||
self.log(
|
||
f"回撤计算: {scenario['desc']}",
|
||
is_correct,
|
||
f"期望 {expected*100:.1f}%, 实际 {drawdown*100:.1f}%"
|
||
)
|
||
|
||
except Exception as e:
|
||
self.log("账户回撤计算测试", False, f"异常: {e}")
|
||
|
||
def print_summary(self):
|
||
"""打印测试总结"""
|
||
print("\n" + "="*60)
|
||
print("📊 测试总结")
|
||
print("="*60)
|
||
|
||
total = self.passed + self.failed
|
||
pass_rate = (self.passed / total * 100) if total > 0 else 0
|
||
|
||
print(f"总测试数: {total}")
|
||
print(f"通过: {self.passed} ✅")
|
||
print(f"失败: {self.failed} ❌")
|
||
print(f"通过率: {pass_rate:.1f}%")
|
||
|
||
if self.failed == 0:
|
||
print("\n🎉 所有测试通过!P0问题修复验证成功!")
|
||
else:
|
||
print(f"\n⚠️ 有 {self.failed} 个测试失败,请检查!")
|
||
|
||
print("="*60 + "\n")
|
||
|
||
async def run_all_tests(self):
|
||
"""运行所有测试"""
|
||
print("\n" + "🚀 " * 30)
|
||
print("P0问题修复验证测试")
|
||
print("测试时间:", datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
|
||
print("🚀 " * 30)
|
||
|
||
# 运行测试
|
||
self.test_1_max_margin_pct()
|
||
self.test_2_position_sizing()
|
||
self.test_3_initial_balance_persistence()
|
||
await self.test_4_emergency_close_await()
|
||
self.test_5_account_drawdown_calculation()
|
||
|
||
# 打印总结
|
||
self.print_summary()
|
||
|
||
return self.failed == 0
|
||
|
||
|
||
def main():
|
||
"""主函数"""
|
||
test = TestP0Fixes()
|
||
|
||
# 运行异步测试
|
||
success = asyncio.run(test.run_all_tests())
|
||
|
||
# 返回退出码
|
||
sys.exit(0 if success else 1)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|