tradusai/web/api.py
2025-12-09 12:27:47 +08:00

243 lines
7.1 KiB
Python

"""
FastAPI Web Service - 模拟盘状态展示 API
"""
import json
import asyncio
from datetime import datetime
from pathlib import Path
from typing import Dict, Any, Optional, List
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse, FileResponse
from pydantic import BaseModel
# Locations of the JSON files this dashboard reads.
_OUTPUT_DIR = Path(__file__).parent.parent / 'output'
STATE_FILE = _OUTPUT_DIR / 'paper_trading_state.json'
SIGNAL_FILE = _OUTPUT_DIR / 'latest_signal.json'

app = FastAPI(title="Paper Trading Dashboard", version="1.0.0")
# WebSocket connection management
class ConnectionManager:
    """Track open WebSocket connections and broadcast messages to them."""

    def __init__(self):
        # Connections accepted via `connect` and not yet removed.
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept the handshake for *websocket* and start tracking it."""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        """Stop tracking *websocket*; a no-op if it is not tracked.

        Being tolerant of unknown sockets lets cleanup paths call this more
        than once (e.g. after `broadcast` has already dropped a dead
        connection) without raising ``ValueError``.
        """
        # Compare by identity so only this exact connection object is removed.
        for index, conn in enumerate(self.active_connections):
            if conn is websocket:
                del self.active_connections[index]
                return

    async def broadcast(self, message: dict):
        """Send *message* as JSON to every tracked connection.

        A connection whose send fails is logged and dropped, so later
        broadcasts do not keep retrying it. Iterates over a snapshot because
        other coroutines may connect/disconnect while this one is suspended in
        ``send_json``.
        """
        for connection in list(self.active_connections):
            try:
                await connection.send_json(message)
            except Exception as e:
                print(f"Broadcast failed, dropping connection: {e}")
                self.disconnect(connection)


manager = ConnectionManager()
def load_trading_state(path: Optional[Path] = None) -> Dict[str, Any]:
    """Load the paper-trading state from disk.

    Args:
        path: JSON file to read. Defaults to ``STATE_FILE`` (looked up at call
            time), so existing zero-argument callers are unaffected.

    Returns:
        The parsed JSON object. When the file is missing, unreadable, not
        valid JSON, or holds something other than a JSON object, a fresh
        default state is returned instead: balance 10000.0, no position, no
        trades, zeroed stats and an empty equity curve.
    """
    state_path = STATE_FILE if path is None else path
    try:
        # JSON interchange is UTF-8 (RFC 8259); without an explicit encoding
        # open() uses the locale's, which can mis-decode non-ASCII text.
        with open(state_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError:
        pass  # nothing written yet: fall back to the default silently
    except Exception as e:
        print(f"Error loading state: {e}")
    else:
        if isinstance(data, dict):
            return data
        # Callers use .get() on the result, so a non-object would crash them.
        print(f"Error loading state: expected a JSON object, got {type(data).__name__}")
    # A new dict is built on every call, so callers may mutate it freely.
    return {
        'balance': 10000.0,
        'position': None,
        'trades': [],
        'stats': {
            'total_trades': 0,
            'winning_trades': 0,
            'losing_trades': 0,
            'total_pnl': 0.0,
            'max_drawdown': 0.0,
            'peak_balance': 10000.0,
            'win_rate': 0.0,
        },
        'equity_curve': [],
    }
def load_latest_signal(path: Optional[Path] = None) -> Dict[str, Any]:
    """Load the most recent trading signal from disk.

    Args:
        path: JSON file to read. Defaults to ``SIGNAL_FILE`` (looked up at
            call time), so existing zero-argument callers are unaffected.

    Returns:
        The parsed JSON object, or an empty dict when the file is missing,
        unreadable, not valid JSON, or holds something other than a JSON
        object.
    """
    signal_path = SIGNAL_FILE if path is None else path
    try:
        # JSON interchange is UTF-8 (RFC 8259); without an explicit encoding
        # open() uses the locale's, which can mis-decode non-ASCII text.
        with open(signal_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError:
        pass  # no signal produced yet: return {} silently
    except Exception as e:
        print(f"Error loading signal: {e}")
    else:
        if isinstance(data, dict):
            return data
        # Callers use .get() on the result, so a non-object would crash them.
        print(f"Error loading signal: expected a JSON object, got {type(data).__name__}")
    return {}
@app.get("/")
async def root():
"""返回前端页面"""
html_file = Path(__file__).parent / 'static' / 'index.html'
if html_file.exists():
return FileResponse(html_file)
return HTMLResponse("<h1>Paper Trading Dashboard</h1><p>Static files not found</p>")
@app.get("/api/status")
async def get_status():
"""获取模拟盘状态"""
state = load_trading_state()
signal = load_latest_signal()
# 计算总收益率
initial_balance = 10000.0
total_return = (state.get('balance', initial_balance) - initial_balance) / initial_balance * 100
return {
'timestamp': datetime.now().isoformat(),
'balance': state.get('balance', initial_balance),
'initial_balance': initial_balance,
'total_return': total_return,
'position': state.get('position'),
'stats': state.get('stats', {}),
'last_updated': state.get('last_updated'),
}
@app.get("/api/trades")
async def get_trades(limit: int = 50):
"""获取交易记录"""
state = load_trading_state()
trades = state.get('trades', [])
return {
'total': len(trades),
'trades': trades[-limit:] if limit > 0 else trades,
}
@app.get("/api/equity")
async def get_equity_curve(limit: int = 500):
"""获取权益曲线"""
state = load_trading_state()
equity_curve = state.get('equity_curve', [])
return {
'total': len(equity_curve),
'data': equity_curve[-limit:] if limit > 0 else equity_curve,
}
@app.get("/api/signal")
async def get_signal():
"""获取最新信号"""
signal = load_latest_signal()
# 提取关键信息
agg = signal.get('aggregated_signal', {})
llm = agg.get('llm_signal', {})
quant = agg.get('quantitative_signal', {})
market = signal.get('market_analysis', {})
return {
'timestamp': agg.get('timestamp'),
'final_signal': agg.get('final_signal'),
'final_confidence': agg.get('final_confidence'),
'consensus': agg.get('consensus'),
'current_price': agg.get('levels', {}).get('current_price'),
'llm': {
'signal': llm.get('signal_type'),
'confidence': llm.get('confidence'),
'reasoning': llm.get('reasoning'),
'opportunities': llm.get('opportunities', {}),
'recommendations': llm.get('recommendations_by_timeframe', {}),
},
'quantitative': {
'signal': quant.get('signal_type'),
'confidence': quant.get('confidence'),
'composite_score': quant.get('composite_score'),
'scores': quant.get('scores', {}),
},
'market': {
'price': market.get('price'),
'trend': market.get('trend', {}),
'momentum': market.get('momentum', {}),
},
}
@app.get("/api/position")
async def get_position():
"""获取当前持仓详情"""
state = load_trading_state()
position = state.get('position')
if not position:
return {'has_position': False, 'position': None}
return {
'has_position': position.get('side') != 'FLAT' and position.get('total_size', 0) > 0,
'position': position,
}
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
"""WebSocket 实时推送"""
await manager.connect(websocket)
try:
# 发送初始状态
state = load_trading_state()
signal = load_latest_signal()
await websocket.send_json({
'type': 'init',
'state': state,
'signal': signal,
})
# 持续推送更新
last_state_mtime = STATE_FILE.stat().st_mtime if STATE_FILE.exists() else 0
last_signal_mtime = SIGNAL_FILE.stat().st_mtime if SIGNAL_FILE.exists() else 0
while True:
await asyncio.sleep(1) # 每秒检查一次
# 检查状态文件更新
current_state_mtime = STATE_FILE.stat().st_mtime if STATE_FILE.exists() else 0
current_signal_mtime = SIGNAL_FILE.stat().st_mtime if SIGNAL_FILE.exists() else 0
if current_state_mtime > last_state_mtime:
last_state_mtime = current_state_mtime
state = load_trading_state()
await websocket.send_json({
'type': 'state_update',
'state': state,
})
if current_signal_mtime > last_signal_mtime:
last_signal_mtime = current_signal_mtime
signal = load_latest_signal()
await websocket.send_json({
'type': 'signal_update',
'signal': signal,
})
except WebSocketDisconnect:
manager.disconnect(websocket)
except Exception as e:
print(f"WebSocket error: {e}")
manager.disconnect(websocket)
# Static files: mount the `static` directory next to this module at /static.
static_dir = Path(__file__).parent / 'static'
# Check is_dir() rather than exists(): a regular file named `static` would
# pass exists(), and StaticFiles would then fail at startup.
if static_dir.is_dir():
    app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")

if __name__ == "__main__":
    import uvicorn
    # Binds to all interfaces on port 8080, so the dashboard is reachable
    # from other hosts.
    uvicorn.run(app, host="0.0.0.0", port=8080)