tradusai/web/api.py
2025-12-09 12:57:31 +08:00

294 lines
8.9 KiB
Python

"""
FastAPI Web Service - 多周期交易状态展示 API
"""
import json
import asyncio
from datetime import datetime
from pathlib import Path
from typing import Dict, Any, List
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
# Paths of the JSON files this service reads (it never writes them; presumably
# they are produced by the trading process — confirm against the writer).
_OUTPUT_DIR = Path(__file__).parent.parent / 'output'
STATE_FILE = _OUTPUT_DIR / 'paper_trading_state.json'
SIGNAL_FILE = _OUTPUT_DIR / 'latest_signal.json'

app = FastAPI(version="2.0.0", title="Trading Dashboard")
class ConnectionManager:
    """Keep track of open WebSocket connections and fan messages out to them."""

    def __init__(self):
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept the WebSocket handshake and start tracking *websocket*."""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        """Stop tracking *websocket*; does nothing if it is not tracked."""
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)

    async def broadcast(self, message: dict):
        """Send *message* as JSON to every tracked connection.

        A connection whose send fails is dropped, so later broadcasts do not
        keep retrying a dead socket. Cancellation (``asyncio.CancelledError``,
        a ``BaseException``) is deliberately not caught, so it propagates.
        """
        # Iterate over a snapshot: send_json awaits, and while this coroutine
        # is suspended another one may call disconnect() and mutate the list.
        for connection in list(self.active_connections):
            try:
                await connection.send_json(message)
            except Exception as e:
                print(f"Broadcast failed, dropping connection: {e}")
                self.disconnect(connection)
def load_trading_state() -> Dict[str, Any]:
    """Load the paper-trading state from ``STATE_FILE``.

    Returns:
        The parsed JSON object. If the file is missing, cannot be read or
        parsed, or does not contain a JSON object, a default state is returned
        instead. The default has three fresh accounts ('short', 'medium',
        'long', 10000 each) and ``last_updated`` set to None, so callers can
        always call ``.get`` on the result.
    """
    state = None
    try:
        # Open directly rather than exists()-then-open(), which races with the
        # file being replaced. JSON is UTF-8 by spec, so don't rely on the
        # platform's default encoding.
        with open(STATE_FILE, 'r', encoding='utf-8') as f:
            state = json.load(f)
    except FileNotFoundError:
        pass  # No state yet: fall through to the default state silently.
    except Exception as e:
        # Best effort: e.g. a half-written file (JSONDecodeError) or a
        # permission problem must not take the endpoint down.
        print(f"Error loading state: {e}")

    if isinstance(state, dict):
        return state
    if state is not None:
        print(f"Error loading state: expected a JSON object, got {type(state).__name__}")

    # Default state
    return {
        'accounts': {
            'short': _default_account('short', 10000),
            'medium': _default_account('medium', 10000),
            'long': _default_account('long', 10000),
        },
        'last_updated': None,
    }
def _default_account(timeframe: str, balance: float) -> Dict:
return {
'timeframe': timeframe,
'balance': balance,
'initial_balance': balance,
'leverage': 10, # 所有周期统一 10 倍杠杆
'position': None,
'trades': [],
'stats': {
'total_trades': 0,
'winning_trades': 0,
'losing_trades': 0,
'total_pnl': 0.0,
'max_drawdown': 0.0,
'peak_balance': balance,
'win_rate': 0.0,
},
'equity_curve': [],
}
def load_latest_signal() -> Dict[str, Any]:
    """Load the most recent trading signal from ``SIGNAL_FILE``.

    Returns:
        The parsed JSON object, or an empty dict if the file is missing,
        cannot be read or parsed, or does not contain a JSON object.
    """
    try:
        # Open directly rather than exists()-then-open(), which races with the
        # file being replaced. JSON is UTF-8 by spec, so don't rely on the
        # platform's default encoding.
        with open(SIGNAL_FILE, 'r', encoding='utf-8') as f:
            signal = json.load(f)
    except FileNotFoundError:
        return {}
    except Exception as e:
        # Best effort: e.g. a half-written file (JSONDecodeError) or a
        # permission problem must not take the endpoint down.
        print(f"Error loading signal: {e}")
        return {}

    if not isinstance(signal, dict):
        # Callers use .get() on the result, so never hand back a list/scalar.
        print(f"Error loading signal: expected a JSON object, got {type(signal).__name__}")
        return {}
    return signal
@app.get("/")
async def root():
    """Serve the dashboard front end (``static/index.html`` next to this module).

    If the page is absent, a JSON error object is returned instead.
    """
    index_page = Path(__file__).parent / 'static' / 'index.html'
    if not index_page.exists():
        return {"error": "Static files not found"}
    return FileResponse(index_page)
# Display labels per account key: (Chinese label, English label).
_TIMEFRAME_NAMES = {
    'short': ('短周期', 'Short-term'),
    'medium': ('中周期', 'Medium-term'),
    'long': ('长周期', 'Long-term'),
}


@app.get("/api/status")
async def get_status():
    """Return a per-timeframe account summary plus portfolio totals.

    Response fields:
        timestamp: current server time (ISO 8601, naive local time).
        total_balance / total_initial_balance: sums over all accounts.
        total_return: portfolio return in percent (0 if no initial capital).
        timeframes: per-account summary keyed by account key.
        last_updated: copied verbatim from the state file.
    """
    state = load_trading_state()
    # `or {}` also tolerates an explicit null, which .get(key, {}) would not.
    accounts = state.get('accounts') or {}

    # Portfolio totals across all accounts.
    total_balance = sum(acc.get('balance', 0) for acc in accounts.values())
    total_initial = sum(acc.get('initial_balance', 0) for acc in accounts.values())
    total_return = (total_balance - total_initial) / total_initial * 100 if total_initial > 0 else 0

    # Per-timeframe status.
    timeframes = {}
    for tf_key, acc in accounts.items():
        initial = acc.get('initial_balance', 0)
        balance = acc.get('balance', 0)
        return_pct = (balance - initial) / initial * 100 if initial > 0 else 0
        # Unrecognised keys are labelled with the raw key rather than being
        # silently presented as the long-term account.
        name, name_en = _TIMEFRAME_NAMES.get(tf_key, (tf_key, tf_key))
        timeframes[tf_key] = {
            'name': name,
            'name_en': name_en,
            'balance': balance,
            'initial_balance': initial,
            'return_pct': return_pct,
            'leverage': acc.get('leverage', 1),
            'position': acc.get('position'),
            'stats': acc.get('stats', {}),
        }

    return {
        'timestamp': datetime.now().isoformat(),
        'total_balance': total_balance,
        'total_initial_balance': total_initial,
        'total_return': total_return,
        'timeframes': timeframes,
        'last_updated': state.get('last_updated'),
    }
@app.get("/api/trades")
async def get_trades(timeframe: str = None, limit: int = 50):
    """Return trade records, newest exit first.

    Args:
        timeframe: Optional account key; when given, only that account's
            trades are included.
        limit: Maximum number of trades returned; 0 or negative returns all.

    Returns:
        ``{'total': <matching trades before limiting>, 'trades': [...]}``.
    """
    state = load_trading_state()
    # `or {}` / `or []` also tolerate explicit nulls in the state file.
    accounts = state.get('accounts') or {}

    all_trades = []
    for tf_key, acc in accounts.items():
        if timeframe and tf_key != timeframe:
            continue
        all_trades.extend(acc.get('trades') or [])

    def _by_exit_time(trade):
        # Trades with a missing or null 'exit_time' sort as oldest. The tuple
        # keeps None out of the value comparison: None and str (or numbers)
        # are not orderable, so a null exit_time used to raise TypeError here.
        exit_time = trade.get('exit_time')
        return (exit_time is not None, exit_time if exit_time is not None else '')

    # Newest first.
    all_trades.sort(key=_by_exit_time, reverse=True)

    return {
        'total': len(all_trades),
        'trades': all_trades[:limit] if limit > 0 else all_trades,
    }
@app.get("/api/equity")
async def get_equity_curve(timeframe: str = None, limit: int = 500):
    """Return equity-curve points keyed by account.

    Only the account named by *timeframe* is included when it is given. Each
    curve is cut to its last *limit* points; a *limit* of 0 or less keeps the
    whole curve.
    """
    accounts = load_trading_state().get('accounts', {})

    def _tail(points):
        return points[-limit:] if limit > 0 else points

    data = {
        tf_key: _tail(acc.get('equity_curve', []))
        for tf_key, acc in accounts.items()
        if not timeframe or tf_key == timeframe
    }
    return {'data': data}
@app.get("/api/signal")
async def get_signal():
    """Return a flattened view of the latest trading signal.

    Pulls the final signal, confidence, price, per-timeframe opportunities,
    reasoning and per-timeframe recommendations out of the nested signal file.
    Missing sections yield None (or an empty dict for ``recommendations``).
    """
    signal = load_latest_signal()
    # Use `or {}` at every level: these sections may be present but null in
    # the JSON, and .get(key, {}) only covers a missing key, not a null value.
    agg = signal.get('aggregated_signal') or {}
    llm = agg.get('llm_signal') or {}
    market = signal.get('market_analysis') or {}
    levels = agg.get('levels') or {}

    # Per-timeframe opportunities. 'intraday' / 'swing' are accepted as
    # alternative key names (presumably an older schema — confirm).
    opportunities = llm.get('opportunities') or {}

    return {
        'timestamp': agg.get('timestamp'),
        'final_signal': agg.get('final_signal'),
        'final_confidence': agg.get('final_confidence'),
        'current_price': levels.get('current_price') or market.get('price'),
        'opportunities': {
            'short': opportunities.get('short_term_5m_15m_1h') or opportunities.get('intraday'),
            'medium': opportunities.get('medium_term_4h_1d') or opportunities.get('swing'),
            'long': opportunities.get('long_term_1d_1w'),
        },
        'reasoning': llm.get('reasoning'),
        'recommendations': llm.get('recommendations_by_timeframe', {}),
    }
@app.get("/api/timeframe/{timeframe}")
async def get_timeframe_detail(timeframe: str):
    """Return the detail view of one account.

    The view contains balance, return, leverage, position and stats, the last
    20 trades and the last 200 equity points. An unknown *timeframe* yields a
    JSON error object.
    """
    accounts = load_trading_state().get('accounts', {})
    if timeframe not in accounts:
        return {"error": f"Timeframe '{timeframe}' not found"}

    account = accounts[timeframe]
    start = account.get('initial_balance', 0)
    current = account.get('balance', 0)
    pct = (current - start) / start * 100 if start > 0 else 0

    return {
        'timeframe': timeframe,
        'balance': current,
        'initial_balance': start,
        'return_pct': pct,
        'leverage': account.get('leverage', 1),
        'position': account.get('position'),
        'stats': account.get('stats', {}),
        'recent_trades': account.get('trades', [])[-20:],
        'equity_curve': account.get('equity_curve', [])[-200:],
    }
def _file_mtime(path: Path) -> float:
    """Return *path*'s modification time, or 0 if the file does not exist."""
    # EAFP: an exists()/stat() pair can race with the file being replaced.
    try:
        return path.stat().st_mtime
    except FileNotFoundError:
        return 0


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Push trading state and signal updates to a WebSocket client.

    The client first gets an ``init`` message with the full state and signal.
    After that the two files are polled once per second, and a
    ``state_update`` / ``signal_update`` message is sent whenever a file's
    modification time increases.
    """
    await manager.connect(websocket)
    try:
        # Record mtimes BEFORE the initial read: a write that lands between
        # the read and the stat is then still picked up by the first poll,
        # instead of being skipped until the next write.
        last_state_mtime = _file_mtime(STATE_FILE)
        last_signal_mtime = _file_mtime(SIGNAL_FILE)

        # Send the initial snapshot.
        state = load_trading_state()
        signal = load_latest_signal()
        await websocket.send_json({
            'type': 'init',
            'state': state,
            'signal': signal,
        })

        # Poll for changes.
        while True:
            await asyncio.sleep(1)
            current_state_mtime = _file_mtime(STATE_FILE)
            current_signal_mtime = _file_mtime(SIGNAL_FILE)

            if current_state_mtime > last_state_mtime:
                last_state_mtime = current_state_mtime
                await websocket.send_json({
                    'type': 'state_update',
                    'state': load_trading_state(),
                })

            if current_signal_mtime > last_signal_mtime:
                last_signal_mtime = current_signal_mtime
                await websocket.send_json({
                    'type': 'signal_update',
                    'signal': load_latest_signal(),
                })
    except WebSocketDisconnect:
        pass
    except Exception as e:
        print(f"WebSocket error: {e}")
    finally:
        # Runs on every exit path, including task cancellation at shutdown
        # (CancelledError is not an Exception), so the manager never keeps
        # a dead connection.
        manager.disconnect(websocket)
# Static files: the bundled static/ directory next to this module is served
# under /static, but only if that directory exists.
static_dir = Path(__file__).parent / 'static'
if static_dir.exists():
    _static_app = StaticFiles(directory=str(static_dir))
    app.mount("/static", _static_app, name="static")
if __name__ == "__main__":
    # Running this module directly serves the app with uvicorn on all
    # interfaces, port 8000.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")