294 lines
8.9 KiB
Python
294 lines
8.9 KiB
Python
"""
|
|
FastAPI Web Service - API for displaying multi-timeframe trading status
|
|
"""
|
|
import json
|
|
import asyncio
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Dict, Any, List
|
|
|
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
|
from fastapi.staticfiles import StaticFiles
|
|
from fastapi.responses import FileResponse
|
|
|
|
# Paths of the JSON files this service reads. Both live in the `output`
# directory one level above the directory that contains this module.
_OUTPUT_DIR = Path(__file__).parent.parent / 'output'
STATE_FILE = _OUTPUT_DIR / 'paper_trading_state.json'
SIGNAL_FILE = _OUTPUT_DIR / 'latest_signal.json'
|
|
|
|
# Application object; the @app.get / @app.websocket decorators in this module
# register their handlers on it.
app = FastAPI(title="Trading Dashboard", version="2.0.0")
|
|
|
|
|
|
class ConnectionManager:
    """Keep track of the WebSocket connections that are currently open.

    Connections are stored in ``active_connections`` in the order in which
    they were accepted.
    """

    def __init__(self):
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept the WebSocket handshake and start tracking ``websocket``."""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        """Stop tracking ``websocket``; does nothing if it is not tracked."""
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)

    async def broadcast(self, message: dict):
        """Send ``message`` as JSON to every tracked connection.

        Delivery is best-effort: a connection whose send fails is reported,
        removed from ``active_connections`` and the remaining connections are
        still served. Task cancellation and interpreter-exit signals are not
        intercepted.
        """
        # Iterate over a snapshot: the list can change while this coroutine
        # is suspended in ``await``, and failed connections are removed below.
        for connection in list(self.active_connections):
            try:
                await connection.send_json(message)
            except Exception as e:
                # ``Exception`` (not a bare ``except``) so CancelledError and
                # KeyboardInterrupt keep propagating.
                print(f"Error broadcasting to client: {e}")
                self.disconnect(connection)
|
|
|
|
|
|
# Module-level instance; the /ws endpoint registers and unregisters its
# client connections through it.
manager = ConnectionManager()
|
|
|
|
|
|
def load_trading_state() -> Dict[str, Any]:
    """Load the persisted trading state from ``STATE_FILE``.

    Returns:
        The decoded JSON object. When the file is missing, cannot be read,
        is not valid JSON, or does not contain a JSON object, a default
        state is returned instead: three accounts (``short``, ``medium``,
        ``long``) built by ``_default_account`` with a balance of 10000
        each, and ``last_updated`` set to ``None``.
    """
    state = None
    try:
        # Open directly instead of checking exists() first, so a file that
        # disappears between the check and the open cannot raise here.
        # NOTE(review): assumes the writer emits UTF-8 (or plain ASCII)
        # JSON rather than the platform's locale encoding - confirm.
        with open(STATE_FILE, 'r', encoding='utf-8') as f:
            state = json.load(f)
    except FileNotFoundError:
        # No state has been written yet; silently use the defaults.
        pass
    except Exception as e:
        print(f"Error loading state: {e}")
    else:
        if not isinstance(state, dict):
            # Callers use dict methods on the result; reject other payloads.
            print(
                "Error loading state: expected a JSON object, "
                f"got {type(state).__name__}"
            )
            state = None

    if state is not None:
        return state

    # Default state
    return {
        'accounts': {
            'short': _default_account('short', 10000),
            'medium': _default_account('medium', 10000),
            'long': _default_account('long', 10000),
        },
        'last_updated': None,
    }
|
|
|
|
|
|
def _default_account(timeframe: str, balance: float) -> Dict:
|
|
return {
|
|
'timeframe': timeframe,
|
|
'balance': balance,
|
|
'initial_balance': balance,
|
|
'leverage': 10, # 所有周期统一 10 倍杠杆
|
|
'position': None,
|
|
'trades': [],
|
|
'stats': {
|
|
'total_trades': 0,
|
|
'winning_trades': 0,
|
|
'losing_trades': 0,
|
|
'total_pnl': 0.0,
|
|
'max_drawdown': 0.0,
|
|
'peak_balance': balance,
|
|
'win_rate': 0.0,
|
|
},
|
|
'equity_curve': [],
|
|
}
|
|
|
|
|
|
def load_latest_signal() -> Dict[str, Any]:
    """Load the most recent signal from ``SIGNAL_FILE``.

    Returns:
        The decoded JSON object, or an empty dict when the file is missing,
        cannot be read, is not valid JSON, or does not contain a JSON
        object.
    """
    try:
        # Open directly instead of checking exists() first, so a file that
        # disappears between the check and the open cannot raise here.
        # NOTE(review): assumes the writer emits UTF-8 (or plain ASCII)
        # JSON rather than the platform's locale encoding - confirm.
        with open(SIGNAL_FILE, 'r', encoding='utf-8') as f:
            signal = json.load(f)
    except FileNotFoundError:
        # No signal has been written yet; not an error.
        return {}
    except Exception as e:
        print(f"Error loading signal: {e}")
        return {}

    if not isinstance(signal, dict):
        # Callers use dict methods on the result; reject other payloads.
        print(
            "Error loading signal: expected a JSON object, "
            f"got {type(signal).__name__}"
        )
        return {}
    return signal
|
|
|
|
|
|
@app.get("/")
async def root():
    """Serve the front-end page.

    Returns ``static/index.html`` (looked up next to this module) as a file
    response, or a JSON error object when that file does not exist.
    """
    index_page = Path(__file__).parent / 'static' / 'index.html'
    if not index_page.exists():
        return {"error": "Static files not found"}
    return FileResponse(index_page)
|
|
|
|
|
|
@app.get("/api/status")
async def get_status():
    """Return the trading status of every timeframe account.

    The response carries the summed balance, summed initial balance and the
    overall percentage return, plus one entry per account with its display
    names, balances, percentage return, leverage, position and stats.
    """
    # (Chinese, English) display names. Any key other than 'short' or
    # 'medium' is labelled as the long timeframe.
    display_names = {
        'short': ('短周期', 'Short-term'),
        'medium': ('中周期', 'Medium-term'),
    }
    long_names = ('长周期', 'Long-term')

    snapshot = load_trading_state()
    accounts = snapshot.get('accounts', {})

    # Totals across all accounts
    balances = [acc.get('balance', 0) for acc in accounts.values()]
    initials = [acc.get('initial_balance', 0) for acc in accounts.values()]
    total_balance = sum(balances)
    total_initial = sum(initials)
    if total_initial > 0:
        total_return = (total_balance - total_initial) / total_initial * 100
    else:
        total_return = 0

    # Per-timeframe summaries
    timeframes = {}
    for tf_key, acc in accounts.items():
        name_cn, name_en = display_names.get(tf_key, long_names)
        start = acc.get('initial_balance', 0)
        current = acc.get('balance', 0)
        if start > 0:
            return_pct = (current - start) / start * 100
        else:
            return_pct = 0

        timeframes[tf_key] = {
            'name': name_cn,
            'name_en': name_en,
            'balance': current,
            'initial_balance': start,
            'return_pct': return_pct,
            'leverage': acc.get('leverage', 1),
            'position': acc.get('position'),
            'stats': acc.get('stats', {}),
        }

    return {
        'timestamp': datetime.now().isoformat(),
        'total_balance': total_balance,
        'total_initial_balance': total_initial,
        'total_return': total_return,
        'timeframes': timeframes,
        'last_updated': snapshot.get('last_updated'),
    }
|
|
|
|
|
|
@app.get("/api/trades")
async def get_trades(timeframe: str = None, limit: int = 50):
    """Return recorded trades, newest exit first.

    Args:
        timeframe: When given, only trades of the account stored under this
            key are returned; otherwise trades of all accounts are merged.
        limit: Maximum number of trades to return. Zero or a negative value
            returns every trade.

    Returns:
        A dict with ``total`` (number of matching trades before ``limit``
        is applied) and ``trades`` (the possibly truncated list).
    """
    state = load_trading_state()
    accounts = state.get('accounts', {})

    all_trades = [
        trade
        for tf_key, acc in accounts.items()
        if not timeframe or tf_key == timeframe
        # `or []`: tolerate an account whose 'trades' value is null.
        for trade in (acc.get('trades') or [])
    ]

    # Sort by exit time, newest first. `or ''` rather than a .get() default:
    # a trade whose 'exit_time' is present but null would otherwise put None
    # next to strings in the sort keys and raise TypeError.
    all_trades.sort(key=lambda trade: trade.get('exit_time') or '', reverse=True)

    return {
        'total': len(all_trades),
        'trades': all_trades[:limit] if limit > 0 else all_trades,
    }
|
|
|
|
|
|
@app.get("/api/equity")
async def get_equity_curve(timeframe: str = None, limit: int = 500):
    """Return the equity curve of each account, keyed by timeframe.

    Args:
        timeframe: When given, only the account stored under this key is
            included.
        limit: Keep only the last ``limit`` points of each curve. Zero or a
            negative value returns the full curve.
    """
    accounts = load_trading_state().get('accounts', {})

    # (key, full curve) pairs for the accounts that pass the filter.
    selected = (
        (tf_key, acc.get('equity_curve', []))
        for tf_key, acc in accounts.items()
        if not timeframe or tf_key == timeframe
    )
    curves = {
        tf_key: (curve[-limit:] if limit > 0 else curve)
        for tf_key, curve in selected
    }

    return {'data': curves}
|
|
|
|
|
|
@app.get("/api/signal")
async def get_signal():
    """Return a flattened view of the latest signal.

    Reads the ``aggregated_signal`` and ``market_analysis`` sections of the
    signal loaded by ``load_latest_signal``. Sections that are missing or
    null are treated as empty, so the corresponding response fields are
    ``None`` instead of the request failing.
    """
    signal = load_latest_signal()

    # `or {}` rather than a .get() default: the default only applies when
    # the key is absent, not when it is present with a null value.
    agg = signal.get('aggregated_signal') or {}
    llm = agg.get('llm_signal') or {}
    market = signal.get('market_analysis') or {}
    levels = agg.get('levels') or {}

    # Per-timeframe opportunities
    opportunities = llm.get('opportunities') or {}

    return {
        'timestamp': agg.get('timestamp'),
        'final_signal': agg.get('final_signal'),
        'final_confidence': agg.get('final_confidence'),
        # Fall back to the market-analysis price when no level price is set.
        'current_price': levels.get('current_price') or market.get('price'),
        'opportunities': {
            # Each slot tries its primary key first, then an alternate key.
            'short': opportunities.get('short_term_5m_15m_1h') or opportunities.get('intraday'),
            'medium': opportunities.get('medium_term_4h_1d') or opportunities.get('swing'),
            'long': opportunities.get('long_term_1d_1w'),
        },
        'reasoning': llm.get('reasoning'),
        'recommendations': llm.get('recommendations_by_timeframe', {}),
    }
|
|
|
|
|
|
@app.get("/api/timeframe/{timeframe}")
async def get_timeframe_detail(timeframe: str):
    """Return the details of a single timeframe account.

    Args:
        timeframe: Key of the account inside the state's ``accounts`` dict.

    Returns:
        The account's balances, percentage return, leverage, position and
        stats together with its last 20 trades and last 200 equity-curve
        points, or a JSON error object when no such account exists.
    """
    accounts = load_trading_state().get('accounts', {})

    if timeframe in accounts:
        acc = accounts[timeframe]
        start = acc.get('initial_balance', 0)
        current = acc.get('balance', 0)

        if start > 0:
            return_pct = (current - start) / start * 100
        else:
            return_pct = 0

        detail = {
            'timeframe': timeframe,
            'balance': current,
            'initial_balance': start,
            'return_pct': return_pct,
            'leverage': acc.get('leverage', 1),
            'position': acc.get('position'),
            'stats': acc.get('stats', {}),
        }
        # Only the tail of each history list is returned.
        detail['recent_trades'] = acc.get('trades', [])[-20:]
        detail['equity_curve'] = acc.get('equity_curve', [])[-200:]
        return detail

    return {"error": f"Timeframe '{timeframe}' not found"}
|
|
|
|
|
|
def _file_mtime(path: Path) -> float:
    """Return the modification time of ``path``, or 0 if it cannot be stat'ed."""
    try:
        # A single stat() call: checking exists() first would leave a window
        # in which the file can vanish and stat() would raise.
        return path.stat().st_mtime
    except OSError:
        return 0.0


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Push trading state and signal updates over a WebSocket.

    After the connection is accepted an ``init`` message with the current
    state and signal is sent. The state and signal files are then polled
    once per second; whenever a file's modification time increases, a
    ``state_update`` or ``signal_update`` message with the reloaded content
    is sent. The connection is unregistered from ``manager`` however the
    handler ends.
    """
    await manager.connect(websocket)

    try:
        # Take the baseline mtimes before the initial load, so a write that
        # lands while the init message is being built or sent still shows up
        # as newer on the first poll.
        last_state_mtime = _file_mtime(STATE_FILE)
        last_signal_mtime = _file_mtime(SIGNAL_FILE)

        # Initial snapshot
        await websocket.send_json({
            'type': 'init',
            'state': load_trading_state(),
            'signal': load_latest_signal(),
        })

        # Keep pushing updates
        while True:
            await asyncio.sleep(1)

            current_state_mtime = _file_mtime(STATE_FILE)
            current_signal_mtime = _file_mtime(SIGNAL_FILE)

            if current_state_mtime > last_state_mtime:
                last_state_mtime = current_state_mtime
                await websocket.send_json({
                    'type': 'state_update',
                    'state': load_trading_state(),
                })

            if current_signal_mtime > last_signal_mtime:
                last_signal_mtime = current_signal_mtime
                await websocket.send_json({
                    'type': 'signal_update',
                    'signal': load_latest_signal(),
                })

    except WebSocketDisconnect:
        # Normal client disconnect; cleanup happens in `finally`.
        pass
    except Exception as e:
        print(f"WebSocket error: {e}")
    finally:
        # Also runs when the task is cancelled, which neither handler above
        # catches; disconnect() is a no-op for an untracked socket.
        manager.disconnect(websocket)
|
|
|
|
|
|
# Static files: served under /static when a `static` entry exists next to
# this module.
static_dir = Path(__file__).parent.joinpath('static')
if static_dir.exists():
    app.mount(
        "/static",
        StaticFiles(directory=str(static_dir)),
        name="static",
    )
|
|
|
|
|
|
def _serve():
    """Run ``app`` under uvicorn on all interfaces, port 8000."""
    # Imported here so uvicorn is only needed when running as a script.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)


if __name__ == "__main__":
    _serve()
|