alphax/tools/backtest.py
2026-05-13 22:49:47 +08:00

206 lines
6.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
山寨币策略回测脚本
对 DB 中所有有完整入场方案stop_loss/tp1/tp2的推荐做模拟跟踪。
"""
import json
import os
import sqlite3
import sys
from datetime import datetime
from pathlib import Path
import ccxt
import pandas as pd
# Directory two levels above this file; prepended to sys.path, presumably so
# project modules are importable when the script is run directly.
REPO_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, os.fspath(REPO_ROOT))

# Shared Binance client used for every K-line request (ccxt rate limiting on).
exchange = ccxt.binance(dict(enableRateLimit=True))

# SQLite database path; the ALPHAX_DB_PATH environment variable overrides it.
DB = os.environ.get("ALPHAX_DB_PATH", str(REPO_ROOT.joinpath("data", "altcoin_monitor.db")))
# Destination of the JSON summary written by main().
OUTPUT_PATH = REPO_ROOT.joinpath("reports", "backtest_result.json")
def fetch_klines_since(symbol, timeframe, since_ms, limit=500, page_size=1000):
    """Fetch up to ``limit`` K-lines (OHLCV candles) starting at ``since_ms``.

    One ``fetch_ohlcv`` call returns a bounded number of candles (Binance's
    spot klines endpoint documents a 1000-candle cap), so a larger ``limit``
    is served by requesting consecutive pages of at most ``page_size``
    candles, each resuming just after the previous page's last candle.

    Args:
        symbol: ccxt market symbol.
        timeframe: ccxt timeframe string, e.g. ``'15m'``.
        since_ms: Start time as Unix epoch milliseconds.
        limit: Maximum total number of candles to return.
        page_size: Maximum number of candles requested per API call.

    Returns:
        A DataFrame with columns ``timestamp`` (naive datetime64 from epoch
        ms, i.e. UTC), ``open``, ``high``, ``low``, ``close``, ``volume``; or
        ``None`` when nothing was returned or any request failed (the error
        is printed).
    """
    try:
        candles = []
        cursor = since_ms
        while len(candles) < limit:
            batch = exchange.fetch_ohlcv(
                symbol, timeframe, since=cursor,
                limit=min(page_size, limit - len(candles)),
            )
            if not batch:
                break
            candles.extend(batch)
            # Resume 1 ms after the last candle's open time so pages do not
            # overlap; stop if the cursor fails to advance (no endless loop).
            next_cursor = batch[-1][0] + 1
            if cursor is not None and next_cursor <= cursor:
                break
            cursor = next_cursor
        if not candles:
            return None
        df = pd.DataFrame(candles[:limit],
                          columns=['timestamp', 'open', 'high', 'low', 'close', 'volume'])
        df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
        return df
    except Exception as e:
        # Best effort per symbol: report and let the caller skip it.
        print(f" fetch_klines error for {symbol}: {e}")
        return None
def simulate_trade(rec, klines_df):
    """Simulate one long trade: walk the K-lines forward and check TP/stop.

    The trade exits on the first candle whose high reaches TP2 or TP1, or
    whose low reaches the stop loss. If no level is touched, the result is
    ``'expired'`` and the exit price equals the entry price.

    Args:
        rec: Mapping (e.g. ``sqlite3.Row`` or ``dict``) providing
            ``entry_price``, ``stop_loss``, ``tp1``, ``tp2`` and ``rec_time``
            (an ISO-8601 string).
        klines_df: DataFrame with ``timestamp``, ``high`` and ``low`` columns,
            iterated in row order (assumed chronological).

    Returns:
        dict with ``result`` (``'hit_tp2'``, ``'hit_tp1'``, ``'stopped_out'``,
        ``'expired'`` or ``'no_entry_plan'``), ``exit_price``, ``exit_time``
        (``str`` of the exit candle's timestamp, ``''`` if none), ``pnl_pct``,
        ``max_profit_pct``, ``max_loss_pct`` and ``hours`` (time from
        ``rec_time`` to the exit candle, ``0`` if no exit or not computable).
        Every result, including ``'no_entry_plan'``, has the same keys.
    """
    entry_price = float(rec['entry_price'] or 0)
    stop_loss = float(rec['stop_loss'] or 0)
    tp1 = float(rec['tp1'] or 0)
    tp2 = float(rec['tp2'] or 0)
    if entry_price <= 0 or stop_loss <= 0 or tp1 <= 0:
        # Incomplete plan (a missing/zero entry price would also divide by
        # zero below). Keep the full key set: main() prints max_profit_pct /
        # max_loss_pct for every result.
        return {'result': 'no_entry_plan', 'exit_price': 0, 'exit_time': '',
                'pnl_pct': 0, 'max_profit_pct': 0, 'max_loss_pct': 0, 'hours': 0}

    result = 'expired'
    exit_price = entry_price
    exit_ts = None
    max_profit_pct = 0
    max_loss_pct = 0
    for ts, raw_high, raw_low in zip(klines_df['timestamp'], klines_df['high'], klines_df['low']):
        high = float(raw_high)
        low = float(raw_low)
        # Excursions include the exit candle itself.
        max_profit_pct = max(max_profit_pct, (high / entry_price - 1) * 100)
        max_loss_pct = min(max_loss_pct, (low / entry_price - 1) * 100)
        # NOTE(review): targets are checked before the stop, so a candle whose
        # range covers both a target and the stop is scored as a win. OHLC data
        # cannot tell which came first; this is the optimistic choice.
        if tp2 > 0 and high >= tp2:
            result, exit_price, exit_ts = 'hit_tp2', tp2, ts
            break
        if high >= tp1:
            result, exit_price, exit_ts = 'hit_tp1', tp1, ts
            break
        if low <= stop_loss:
            result, exit_price, exit_ts = 'stopped_out', stop_loss, ts
            break

    pnl_pct = round((exit_price / entry_price - 1) * 100, 2)

    hours = 0
    if exit_ts is not None:
        try:
            # Compare epoch seconds instead of subtracting datetimes: kline
            # timestamps are naive UTC and pd.Timestamp.timestamp() treats naive
            # values as UTC; a naive rec_time is read as local time by
            # datetime.timestamp(), the same way main() derives since_ms.
            # Direct subtraction was off by the UTC offset, or raised TypeError
            # for a tz-aware rec_time.
            rec_epoch = datetime.fromisoformat(rec['rec_time']).timestamp()
            exit_epoch = pd.Timestamp(exit_ts).timestamp()
            hours = round((exit_epoch - rec_epoch) / 3600, 1)
        except (TypeError, ValueError, OverflowError):
            hours = 0

    return {
        'result': result,
        'exit_price': round(exit_price, 6),
        'exit_time': '' if exit_ts is None else str(exit_ts),
        'pnl_pct': pnl_pct,
        'max_profit_pct': round(max_profit_pct, 2),
        'max_loss_pct': round(max_loss_pct, 2),
        'hours': hours,
    }
def _load_candidates():
    """Return all recommendation rows whose stop_loss and tp1 are both > 0."""
    conn = sqlite3.connect(DB)
    try:
        conn.row_factory = sqlite3.Row
        # NOTE(review): an earlier comment said only breakout-state
        # recommendations were meant to be backtested, but this query has no
        # rec_state filter, so every state is included. Confirm the intent.
        return conn.execute("""
            SELECT id, symbol, rec_time, rec_state, rec_score, entry_price,
                   stop_loss, tp1, tp2, status, entry_plan_json, signals, sector
            FROM recommendation
            WHERE stop_loss > 0 AND tp1 > 0
            ORDER BY id
        """).fetchall()
    finally:
        # Release the connection even if the query raises.
        conn.close()


def _summarize(results):
    """Aggregate per-trade simulation results into summary statistics.

    ``max_win`` / ``max_loss`` are ``None`` when there is no winning / losing
    trade, instead of leaking -999 / 999 placeholder values into the report.
    """
    win_pnls = [r['pnl_pct'] for r in results if r['result'] in ('hit_tp1', 'hit_tp2')]
    loss_pnls = [r['pnl_pct'] for r in results if r['result'] == 'stopped_out']
    # Summed in result order, exactly as the running total used to be.
    closed_pnls = [r['pnl_pct'] for r in results
                   if r['result'] in ('hit_tp1', 'hit_tp2', 'stopped_out')]
    closed = len(closed_pnls)
    total_pnl = sum(closed_pnls)
    return {
        'wins': len(win_pnls),
        'losses': len(loss_pnls),
        # Everything that is neither a TP nor a stop (expired, no_entry_plan).
        'expired': len(results) - closed,
        'closed': closed,
        'win_rate': round(len(win_pnls) / closed * 100, 1) if closed > 0 else 0,
        'avg_pnl': round(total_pnl / closed, 2) if closed > 0 else 0,
        'max_win': max(win_pnls, default=None),
        'max_loss': min(loss_pnls, default=None),
    }


def _fmt_pct(value):
    """Format a percentage for the console, or 'N/A' when it is undefined."""
    return 'N/A' if value is None else f"{value}%"


def _print_summary(total, stats):
    """Print the aggregate statistics block for ``total`` simulated trades."""
    print(f"\n{'='*60}")
    print(f"回测汇总 (n={total})")
    print(f"{'='*60}")
    print(f"止盈(TP): {stats['wins']}")
    print(f"止损: {stats['losses']}")
    print(f"未触达: {stats['expired']}")
    closed = stats['closed']
    if closed > 0:
        max_win, max_loss = stats['max_win'], stats['max_loss']
        print(f"胜率: {stats['wins']}/{closed} = {stats['win_rate']}%")
        print(f"平均盈亏: {stats['avg_pnl']}%")
        print(f"最大盈利: {_fmt_pct(max_win)}")
        print(f"最大亏损: {_fmt_pct(max_loss)}")
        # Ratio of the largest win to the largest loss.
        if max_win is None:
            ratio = 0.0  # no winning trade
        elif not max_loss:
            ratio = 99  # no losing trade (or a zero-loss stop): capped value
        else:
            ratio = abs(max_win / max_loss)
        print(f"盈亏比: {round(ratio, 1)}")
    print(f"{'='*60}")


def _print_details(results):
    """Print one table row per simulated trade."""
    print(f"\n{'symbol':<14} {'time':<17} {'result':<14} {'pnl':>7} {'max+':>7} {'max-':>7} {'h':>5}")
    print("-" * 70)
    for r in results:
        print(f"{r['symbol']:<14} {r['rec_time'][:16]:<17} {r['result']:<14} "
              f"{r['pnl_pct']:>6.1f}% {r['max_profit_pct']:>6.1f}% {r['max_loss_pct']:>6.1f}% {r['hours']:>5.1f}")


def _save_results(results, stats):
    """Write the summary and per-trade details to OUTPUT_PATH as JSON."""
    OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
    with open(OUTPUT_PATH, 'w', encoding='utf-8') as f:
        json.dump({
            'generated_at': datetime.now().isoformat(),
            'total': len(results),
            'wins': stats['wins'],
            'losses': stats['losses'],
            'expired': stats['expired'],
            'win_rate': stats['win_rate'],
            'avg_pnl': stats['avg_pnl'],
            # null (not a ±999 placeholder) when there is no win / no loss.
            'max_win': stats['max_win'],
            'max_loss': stats['max_loss'],
            'details': [{k: str(v) if isinstance(v, (datetime, pd.Timestamp)) else v
                         for k, v in r.items()} for r in results],
        }, f, ensure_ascii=False, indent=2, default=str)
    print(f"\n结果已保存: {OUTPUT_PATH}")


def main():
    """Backtest DB recommendations on 15m K-lines, print a report, save JSON."""
    rows = _load_candidates()
    print(f"回测样本: {len(rows)} 条 (有完整入场方案)\n")

    # Console marker per simulation result (loop-invariant, built once).
    tag = {'hit_tp1': '🟢', 'hit_tp2': '🟢🟢', 'stopped_out': '🔴', 'expired': '', 'no_entry_plan': ''}
    results = []
    for i, rec in enumerate(rows, 1):
        symbol = rec['symbol']
        print(f"[{i}/{len(rows)}] {symbol} rec_time={str(rec['rec_time'])[:19]} "
              f"entry={rec['entry_price']} stop={rec['stop_loss']} tp1={rec['tp1']} tp2={rec['tp2']}", end='')
        try:
            rec_time = datetime.fromisoformat(rec['rec_time'])
        except (TypeError, ValueError):
            # One malformed row should not abort the whole backtest.
            print(" → rec_time 无效,跳过")
            continue
        # datetime.timestamp() reads a naive rec_time as local time.
        since_ms = int(rec_time.timestamp() * 1000)
        klines = fetch_klines_since(symbol, '15m', since_ms, limit=2000)
        if klines is None or len(klines) < 2:
            print(" → 数据不足,跳过")
            continue
        sim = simulate_trade(rec, klines)
        results.append({**sim, 'symbol': symbol, 'rec_time': rec['rec_time'],
                        'entry_price': rec['entry_price'], 'rec_state': rec['rec_state']})
        print(f"{tag.get(sim['result'], '?')} {sim['result']} pnl={sim['pnl_pct']}% "
              f"max_profit={sim['max_profit_pct']}% max_loss={sim['max_loss_pct']}% {sim['hours']}h")

    stats = _summarize(results)
    _print_summary(len(results), stats)
    _print_details(results)
    _save_results(results, stats)


if __name__ == '__main__':
    main()