243 lines
7.1 KiB
Python
243 lines
7.1 KiB
Python
"""
|
|
FastAPI Web Service - 模拟盘状态展示 API
|
|
"""
|
|
import json
|
|
import asyncio
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Dict, Any, Optional, List
|
|
|
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
|
from fastapi.staticfiles import StaticFiles
|
|
from fastapi.responses import HTMLResponse, FileResponse
|
|
from pydantic import BaseModel
|
|
|
|
# Directory (two levels above this module) holding the engine's JSON snapshots.
_OUTPUT_DIR = Path(__file__).parent.parent / 'output'

# Persisted account state: balance, position, trades, stats, equity curve.
STATE_FILE = _OUTPUT_DIR / 'paper_trading_state.json'
# Latest signal snapshot (aggregated signal plus market analysis).
SIGNAL_FILE = _OUTPUT_DIR / 'latest_signal.json'
|
|
|
|
# ASGI application object; every route and mount below registers on it.
app = FastAPI(
    title="Paper Trading Dashboard",
    version="1.0.0",
)
|
|
|
|
# WebSocket connection management
class ConnectionManager:
    """Track accepted WebSocket connections and fan messages out to them."""

    def __init__(self):
        # Sockets accepted via connect() and not yet removed via disconnect().
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept the handshake for *websocket* and start tracking it."""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        """Stop tracking *websocket*; a no-op if it is not currently tracked.

        Tolerating unknown sockets lets both an endpoint's cleanup path and
        broadcast() drop the same connection without raising ValueError.
        """
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)

    async def broadcast(self, message: dict):
        """Send *message* as JSON to every tracked connection.

        A connection whose send fails is reported and dropped, so it is not
        retried on every later broadcast. Iteration runs over a snapshot
        because each ``await`` yields to other tasks that may connect or
        disconnect in the meantime. ``Exception`` (not a bare ``except``) is
        caught so task cancellation still propagates.
        """
        for connection in list(self.active_connections):
            try:
                await connection.send_json(message)
            except Exception as e:
                print(f"WebSocket broadcast error: {e}")
                self.disconnect(connection)


# Module-level registry shared by the WebSocket handlers.
manager = ConnectionManager()
|
|
|
|
|
|
def _default_trading_state() -> Dict[str, Any]:
    """Return a fresh starting state: 10000.0 balance, no position, no history."""
    return {
        'balance': 10000.0,
        'position': None,
        'trades': [],
        'stats': {
            'total_trades': 0,
            'winning_trades': 0,
            'losing_trades': 0,
            'total_pnl': 0.0,
            'max_drawdown': 0.0,
            'peak_balance': 10000.0,
            'win_rate': 0.0,
        },
        'equity_curve': [],
    }


def load_trading_state() -> Dict[str, Any]:
    """Load the paper-trading state snapshot from ``STATE_FILE``.

    Returns:
        The parsed JSON object. A fresh default state is returned instead when
        the file is missing, unreadable, not valid UTF-8 JSON, or does not
        contain a JSON object. Every case except a missing file is reported
        via ``print``.
    """
    try:
        # JSON is UTF-8 by specification; don't depend on the locale encoding.
        with open(STATE_FILE, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError:
        # Nothing written yet: not an error, start from the defaults.
        return _default_trading_state()
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError
        # (possibly a partially written file).
        print(f"Error loading state: {e}")
        return _default_trading_state()

    if not isinstance(data, dict):
        # Callers use .get() on the result; a list or scalar would crash them.
        print(f"Error loading state: expected a JSON object, got {type(data).__name__}")
        return _default_trading_state()
    return data
|
|
|
|
|
|
def load_latest_signal() -> Dict[str, Any]:
    """Load the latest signal snapshot from ``SIGNAL_FILE``.

    Returns:
        The parsed JSON object. An empty dict is returned instead when the
        file is missing, unreadable, not valid UTF-8 JSON, or does not contain
        a JSON object. Every case except a missing file is reported via
        ``print``.
    """
    try:
        # JSON is UTF-8 by specification; don't depend on the locale encoding.
        with open(SIGNAL_FILE, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError:
        # No signal produced yet.
        return {}
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Error loading signal: {e}")
        return {}

    if not isinstance(data, dict):
        # Callers use .get() on the result; a list or scalar would crash them.
        print(f"Error loading signal: expected a JSON object, got {type(data).__name__}")
        return {}
    return data
|
|
|
|
|
|
@app.get("/")
async def root():
    """Serve the dashboard page.

    Returns ``static/index.html`` (next to this module) when it exists,
    otherwise a minimal inline placeholder page.
    """
    index_page = Path(__file__).parent.joinpath('static', 'index.html')
    if not index_page.exists():
        return HTMLResponse("<h1>Paper Trading Dashboard</h1><p>Static files not found</p>")
    return FileResponse(index_page)
|
|
|
|
|
|
@app.get("/api/status")
async def get_status():
    """Return the account summary for the dashboard.

    The response contains the current server timestamp, the balance, the
    fixed starting balance (10000.0), the total return in percent relative
    to it, the stored position, the stats block and the state's
    ``last_updated`` value.
    """
    state = load_trading_state()

    # Starting capital; keep in sync with load_trading_state()'s default balance.
    initial_balance = 10000.0
    balance = state.get('balance', initial_balance)

    # Total return in percent, computed from the 'balance' field alone.
    total_return = (balance - initial_balance) / initial_balance * 100

    return {
        'timestamp': datetime.now().isoformat(),
        'balance': balance,
        'initial_balance': initial_balance,
        'total_return': total_return,
        'position': state.get('position'),
        'stats': state.get('stats', {}),
        'last_updated': state.get('last_updated'),
    }
|
|
|
|
|
|
@app.get("/api/trades")
async def get_trades(limit: int = 50):
    """Return the trade log trimmed to the newest *limit* entries.

    ``total`` always reports the full number of recorded trades. A *limit*
    of zero or less returns every trade.
    """
    history = load_trading_state().get('trades', [])
    if limit <= 0:
        page = history
    else:
        page = history[-limit:]
    return {'total': len(history), 'trades': page}
|
|
|
|
|
|
@app.get("/api/equity")
async def get_equity_curve(limit: int = 500):
    """Return equity-curve points, keeping only the newest *limit* of them.

    ``total`` is the full number of recorded points. A *limit* of zero or
    less disables trimming.
    """
    snapshot = load_trading_state()
    curve = snapshot.get('equity_curve', [])

    response = {'total': len(curve)}
    response['data'] = curve if limit <= 0 else curve[-limit:]
    return response
|
|
|
|
|
|
def _as_dict(value: Any) -> Dict[str, Any]:
    """Return *value* unchanged if it is a dict, otherwise an empty dict."""
    return value if isinstance(value, dict) else {}


@app.get("/api/signal")
async def get_signal():
    """Return the key fields of the latest signal for the dashboard.

    The response groups the aggregated verdict (signal, confidence,
    consensus, current price) with the LLM sub-signal, the quantitative
    sub-signal and the market-analysis summary. Sections that are missing,
    null or not JSON objects are treated as empty. Their fields then come
    back as ``None``, or as the stated ``{}`` default, instead of the request
    failing with an AttributeError.
    """
    signal = _as_dict(load_latest_signal())

    # Extract the sections the dashboard renders.
    agg = _as_dict(signal.get('aggregated_signal'))
    llm = _as_dict(agg.get('llm_signal'))
    quant = _as_dict(agg.get('quantitative_signal'))
    levels = _as_dict(agg.get('levels'))
    market = _as_dict(signal.get('market_analysis'))

    return {
        'timestamp': agg.get('timestamp'),
        'final_signal': agg.get('final_signal'),
        'final_confidence': agg.get('final_confidence'),
        'consensus': agg.get('consensus'),
        'current_price': levels.get('current_price'),
        'llm': {
            'signal': llm.get('signal_type'),
            'confidence': llm.get('confidence'),
            'reasoning': llm.get('reasoning'),
            'opportunities': llm.get('opportunities', {}),
            'recommendations': llm.get('recommendations_by_timeframe', {}),
        },
        'quantitative': {
            'signal': quant.get('signal_type'),
            'confidence': quant.get('confidence'),
            'composite_score': quant.get('composite_score'),
            'scores': quant.get('scores', {}),
        },
        'market': {
            'price': market.get('price'),
            'trend': market.get('trend', {}),
            'momentum': market.get('momentum', {}),
        },
    }
|
|
|
|
|
|
@app.get("/api/position")
async def get_position():
    """Describe the current position.

    ``has_position`` is True only when a non-empty position record exists,
    its ``side`` is not ``'FLAT'`` and its ``total_size`` is positive. A
    missing or empty record yields ``{'has_position': False, 'position': None}``.
    """
    current = load_trading_state().get('position')
    if current:
        # Side is checked first so total_size is only compared for non-FLAT records.
        is_open = current.get('side') != 'FLAT' and current.get('total_size', 0) > 0
        return {'has_position': is_open, 'position': current}
    return {'has_position': False, 'position': None}
|
|
|
|
|
|
def _file_mtime(path: Path) -> float:
    """Return *path*'s modification time, or 0 if it cannot be stat'ed.

    This replaces an ``exists()``-then-``stat()`` pair, which raises if the
    file disappears between the two calls.
    """
    try:
        return path.stat().st_mtime
    except OSError:
        return 0


async def _wait_for_client_disconnect(websocket: WebSocket) -> None:
    """Consume incoming frames until the client disconnects.

    The push loop never reads from the socket itself. Without this, a closed
    connection would only be noticed on the next failed send. Incoming
    payloads are ignored.
    """
    try:
        while True:
            message = await websocket.receive()
            if message.get('type') == 'websocket.disconnect':
                return
    except Exception:
        # Any receive error means the connection is unusable; treat it as a
        # disconnect rather than leaving an unretrieved task exception.
        return


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Push state and signal snapshots to a dashboard client in real time.

    The handler first sends an ``init`` message carrying both snapshots.
    It then polls the two JSON files about once per second and sends
    ``state_update`` / ``signal_update`` whenever a file's modification
    time advances. It stops when the client disconnects or a send fails.
    The connection is always unregistered from ``manager``, including on
    task cancellation.
    """
    await manager.connect(websocket)
    watcher = asyncio.create_task(_wait_for_client_disconnect(websocket))

    try:
        # Sample mtimes *before* reading. A write landing in between is then
        # re-sent by the next poll instead of being silently missed.
        last_state_mtime = _file_mtime(STATE_FILE)
        last_signal_mtime = _file_mtime(SIGNAL_FILE)

        # Send the initial snapshot.
        await websocket.send_json({
            'type': 'init',
            'state': load_trading_state(),
            'signal': load_latest_signal(),
        })

        while True:
            # Wait up to 1 s between polls, waking early if the client leaves.
            done, _ = await asyncio.wait({watcher}, timeout=1)
            if done:
                break

            current_state_mtime = _file_mtime(STATE_FILE)
            current_signal_mtime = _file_mtime(SIGNAL_FILE)

            if current_state_mtime > last_state_mtime:
                last_state_mtime = current_state_mtime
                await websocket.send_json({
                    'type': 'state_update',
                    'state': load_trading_state(),
                })

            if current_signal_mtime > last_signal_mtime:
                last_signal_mtime = current_signal_mtime
                await websocket.send_json({
                    'type': 'signal_update',
                    'signal': load_latest_signal(),
                })

    except WebSocketDisconnect:
        pass
    except Exception as e:
        print(f"WebSocket error: {e}")
    finally:
        # Runs on normal exit, errors and cancellation (e.g. server shutdown).
        watcher.cancel()
        manager.disconnect(websocket)
|
|
|
|
|
|
# Static files: serve the sibling "static" directory under the /static prefix.
static_dir = Path(__file__).parent / 'static'
# StaticFiles raises RuntimeError at construction if the directory does not
# exist. exists() would also accept a regular file named "static", which would
# crash the app at import time, so require an actual directory.
if static_dir.is_dir():
    app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")
|
|
|
|
|
|
if __name__ == "__main__":
    # Imported lazily: uvicorn is only needed when this module is run directly.
    import uvicorn

    # NOTE(review): 0.0.0.0 listens on every network interface, exposing the
    # dashboard beyond localhost; confirm that is intended.
    bind_host, bind_port = "0.0.0.0", 8080
    uvicorn.run(app, host=bind_host, port=bind_port)
|