stock-ai-agent/backend/test_p0_fixes.py
2026-03-28 22:56:16 +08:00

340 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
P0问题修复验证测试
测试3个严重问题的修复情况
"""
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import asyncio
import json
from pathlib import Path
from datetime import datetime
from typing import Dict, Any
class TestP0Fixes:
"""P0问题修复测试类"""
def __init__(self):
self.test_results = []
self.passed = 0
self.failed = 0
def log(self, test_name: str, passed: bool, message: str):
"""记录测试结果"""
status = "✅ PASS" if passed else "❌ FAIL"
result = f"{status} - {test_name}: {message}"
self.test_results.append(result)
if passed:
self.passed += 1
else:
self.failed += 1
print(result)
def test_1_max_margin_pct(self):
"""测试1: 配置不一致修复max_margin_pct 应该是 25%"""
print("\n" + "="*60)
print("测试1: max_margin_pct 配置应该支持25%")
print("="*60)
try:
from app.crypto_agent.crypto_agent import CryptoAgent
agent = CryptoAgent()
# 检查 Bitget
bitget_max = agent.PLATFORM_RULES['Bitget']['max_margin_pct']
self.log(
"Bitget max_margin_pct",
bitget_max == 0.25,
f"期望 0.25, 实际 {bitget_max}"
)
# 检查 PaperTrading
paper_max = agent.PLATFORM_RULES['PaperTrading']['max_margin_pct']
self.log(
"PaperTrading max_margin_pct",
paper_max == 0.25,
f"期望 0.25, 实际 {paper_max}"
)
# 检查 Hyperliquid
hyper_max = agent.PLATFORM_RULES['Hyperliquid']['max_margin_pct']
self.log(
"Hyperliquid max_margin_pct",
hyper_max == 0.25,
f"期望 0.25, 实际 {hyper_max}"
)
except Exception as e:
self.log("max_margin_pct 测试", False, f"异常: {e}")
def test_2_position_sizing(self):
"""测试2: A级信号应该能使用20%保证金不被截断到10%"""
print("\n" + "="*60)
print("测试2: A级信号仓位计算$1000账户应该能用$200保证金")
print("="*60)
try:
from app.crypto_agent.crypto_agent import CryptoAgent
agent = CryptoAgent()
# 模拟A级信号
signal = {
'symbol': 'BTCUSDT',
'confidence': 92, # A级
}
# 模拟账户
account = {
'available': 1000.0,
'current_balance': 1000.0,
'current_total_leverage': 0,
'max_total_leverage': 10,
}
# 计算仓位
margin, reason = agent._calculate_position_size(signal, account, 'PaperTrading')
# A级信号期望 20% = $200
expected_min = 190 # 考虑一些边界情况,至少应该接近$200
expected_max = 210
is_correct = expected_min <= margin <= expected_max
self.log(
"A级信号保证金计算",
is_correct,
f"期望 ${expected_min}-${expected_max}, 实际 ${margin:.2f} ({reason})"
)
except Exception as e:
self.log("仓位计算测试", False, f"异常: {e}")
def test_3_initial_balance_persistence(self):
"""测试3: 初始余额持久化机制"""
print("\n" + "="*60)
print("测试3: 初始余额持久化应该保存到data/initial_balances.json")
print("="*60)
try:
from app.crypto_agent.crypto_agent import CryptoAgent
# 清理测试文件
test_file = Path("data/initial_balances.json")
if test_file.exists():
test_file.unlink()
# 创建新的agent实例
agent = CryptoAgent()
# 第一次获取(应该记录)
platform_name = "TestPlatform"
current_balance = 10000.0
initial_1 = agent._get_initial_balance(platform_name, current_balance)
self.log(
"首次获取初始余额",
initial_1 == current_balance,
f"期望 {current_balance}, 实际 {initial_1}"
)
# 检查文件是否创建
self.log(
"持久化文件创建",
test_file.exists(),
f"文件存在: {test_file.exists()}"
)
# 检查文件内容
with open(test_file, 'r') as f:
saved_data = json.load(f)
self.log(
"持久化内容正确",
saved_data.get(platform_name) == current_balance,
f"文件中 {platform_name}: {saved_data.get(platform_name)}"
)
# 模拟余额变化,再次获取(应该返回初始值)
current_balance_2 = 9000.0 # 亏损10%
initial_2 = agent._get_initial_balance(platform_name, current_balance_2)
self.log(
"再次获取初始余额(余额变化后)",
initial_2 == current_balance, # 仍然是初始值
f"期望 {current_balance}, 实际 {initial_2}"
)
# 计算回撤
drawdown = (initial_2 - current_balance_2) / initial_2
expected_drawdown = 0.10 # 10%
self.log(
"回撤计算正确",
abs(drawdown - expected_drawdown) < 0.001,
f"期望 {expected_drawdown*100:.1f}%, 实际 {drawdown*100:.1f}%"
)
except Exception as e:
self.log("初始余额持久化测试", False, f"异常: {e}")
import traceback
traceback.print_exc()
async def test_4_emergency_close_await(self):
"""测试4: 紧急平仓async/await处理"""
print("\n" + "="*60)
print("测试4: 紧急平仓async/await检测")
print("="*60)
try:
import asyncio
# Mock同步方法
def sync_close(symbol: str):
return {'success': True, 'symbol': symbol}
# Mock异步方法
async def async_close(symbol: str):
await asyncio.sleep(0.01)
return {'success': True, 'symbol': symbol}
# 测试asyncio.iscoroutinefunction
is_sync = asyncio.iscoroutinefunction(sync_close)
is_async = asyncio.iscoroutinefunction(async_close)
self.log(
"同步方法检测",
not is_sync,
f"sync_close 是协程: {is_sync} (应该是False)"
)
self.log(
"异步方法检测",
is_async,
f"async_close 是协程: {is_async} (应该是True)"
)
# 测试调用
try:
# 同步调用
result_sync = sync_close("BTC")
self.log(
"同步方法调用",
result_sync['success'],
f"结果: {result_sync}"
)
# 异步调用
result_async = await async_close("ETH")
self.log(
"异步方法调用",
result_async['success'],
f"结果: {result_async}"
)
except Exception as e:
self.log("方法调用测试", False, f"异常: {e}")
except Exception as e:
self.log("紧急平仓await测试", False, f"异常: {e}")
def test_5_account_drawdown_calculation(self):
"""测试5: 账户回撤计算"""
print("\n" + "="*60)
print("测试5: 账户回撤计算(模拟各种场景)")
print("="*60)
try:
from app.crypto_agent.crypto_agent import CryptoAgent
# 测试场景
scenarios = [
{"initial": 10000, "current": 10000, "expected_dd": 0.0, "desc": "无回撤"},
{"initial": 10000, "current": 8500, "expected_dd": 0.15, "desc": "15%回撤(警告)"},
{"initial": 10000, "current": 7500, "expected_dd": 0.25, "desc": "25%回撤(止损)"},
{"initial": 10000, "current": 5000, "expected_dd": 0.50, "desc": "50%回撤(严重)"},
]
agent = CryptoAgent()
for scenario in scenarios:
initial = scenario['initial']
current = scenario['current']
expected = scenario['expected_dd']
# 模拟计算
drawdown = (initial - current) / initial
is_correct = abs(drawdown - expected) < 0.001
self.log(
f"回撤计算: {scenario['desc']}",
is_correct,
f"期望 {expected*100:.1f}%, 实际 {drawdown*100:.1f}%"
)
except Exception as e:
self.log("账户回撤计算测试", False, f"异常: {e}")
def print_summary(self):
"""打印测试总结"""
print("\n" + "="*60)
print("📊 测试总结")
print("="*60)
total = self.passed + self.failed
pass_rate = (self.passed / total * 100) if total > 0 else 0
print(f"总测试数: {total}")
print(f"通过: {self.passed}")
print(f"失败: {self.failed}")
print(f"通过率: {pass_rate:.1f}%")
if self.failed == 0:
print("\n🎉 所有测试通过P0问题修复验证成功")
else:
print(f"\n⚠️ 有 {self.failed} 个测试失败,请检查!")
print("="*60 + "\n")
async def run_all_tests(self):
"""运行所有测试"""
print("\n" + "🚀 " * 30)
print("P0问题修复验证测试")
print("测试时间:", datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
print("🚀 " * 30)
# 运行测试
self.test_1_max_margin_pct()
self.test_2_position_sizing()
self.test_3_initial_balance_persistence()
await self.test_4_emergency_close_await()
self.test_5_account_drawdown_calculation()
# 打印总结
self.print_summary()
return self.failed == 0
def main():
"""主函数"""
test = TestP0Fixes()
# 运行异步测试
success = asyncio.run(test.run_all_tests())
# 返回退出码
sys.exit(0 if success else 1)
if __name__ == "__main__":
main()