tradusai/web/api.py
2025-12-09 21:51:09 +08:00

376 lines
12 KiB
Python

"""
FastAPI Web Service - 多周期交易状态展示 API
"""
import json
import asyncio
import urllib.request
import ssl
from datetime import datetime
from pathlib import Path
from typing import Dict, Any, List, Optional
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
# Input JSON files, located in ../output relative to this module's directory.
STATE_FILE = Path(__file__).parent.parent / 'output' / 'paper_trading_state.json'
SIGNAL_FILE = Path(__file__).parent.parent / 'output' / 'latest_signal.json'
# Binance ticker endpoint polled for the live BTCUSDT price.
BINANCE_PRICE_URL = "https://fapi.binance.com/fapi/v1/ticker/price?symbol=BTCUSDT"
app = FastAPI(title="Trading Dashboard", version="2.0.0")
# Module-level price cache, written by fetch_binance_price().
# 0.0 / None mean "no price has been fetched successfully yet".
_current_price: float = 0.0
_price_update_time: Optional[datetime] = None
async def fetch_binance_price() -> Optional[float]:
    """Fetch the latest BTCUSDT price from Binance, falling back to the cache.

    The blocking HTTP request (``_fetch_price_sync``) runs in the default
    thread-pool executor so the event loop is not stalled. A successful fetch
    updates the module-level cache (``_current_price`` / ``_price_update_time``).

    Returns:
        The freshly fetched price; otherwise the last cached price; or
        ``None`` if no price has ever been fetched successfully. ``0.0`` is
        never returned as a "price".
    """
    global _current_price, _price_update_time
    try:
        # get_running_loop() is the supported API inside a coroutine;
        # get_event_loop() is deprecated for this use.
        loop = asyncio.get_running_loop()
        price = await loop.run_in_executor(None, _fetch_price_sync)
    except Exception as e:
        print(f"Error fetching Binance price: {type(e).__name__}: {e}")
        price = None

    if price:
        _current_price = price
        _price_update_time = datetime.now()

    # Same "no price yet" signal (None) on every failure path; previously a
    # first-call fetch failure returned the 0.0 cache sentinel instead.
    return _current_price if _current_price > 0 else None
def _fetch_price_sync() -> Optional[float]:
    """Blocking helper: GET the Binance ticker endpoint and parse its price.

    Any failure (TLS setup, network, HTTP, JSON decoding, missing or
    non-numeric ``price`` field) is printed and reported as ``None``.
    """
    try:
        tls_context = ssl.create_default_context()
        request = urllib.request.Request(
            BINANCE_PRICE_URL,
            headers={'User-Agent': 'Mozilla/5.0'},
        )
        with urllib.request.urlopen(request, timeout=5, context=tls_context) as resp:
            raw_body = resp.read()
        payload = json.loads(raw_body.decode('utf-8'))
        return float(payload['price'])
    except Exception as exc:
        print(f"Sync fetch error: {type(exc).__name__}: {exc}")
        return None
class ConnectionManager:
    """Track open WebSocket connections and fan messages out to them."""

    def __init__(self):
        # Sockets that have been accepted and not yet disconnected.
        self.active_connections: List["WebSocket"] = []

    async def connect(self, websocket: "WebSocket"):
        """Accept the handshake and start tracking *websocket*."""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: "WebSocket"):
        """Stop tracking *websocket*; a no-op if it is not tracked."""
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)

    async def broadcast(self, message: dict):
        """Send *message* as JSON to every tracked connection.

        Iterates over a snapshot: while this coroutine is suspended in
        ``await``, other tasks may call ``disconnect`` and mutate the list,
        which would otherwise make the loop skip connections. Connections
        whose send fails are dropped instead of being retried forever.
        """
        failed = []
        for connection in list(self.active_connections):
            try:
                await connection.send_json(message)
            except Exception:
                # Not a bare ``except``: that would also swallow
                # asyncio.CancelledError and make this task uncancellable.
                failed.append(connection)
        for connection in failed:
            self.disconnect(connection)


manager = ConnectionManager()
def load_trading_state() -> Dict[str, Any]:
    """Load the paper-trading state from ``STATE_FILE``.

    Returns:
        The decoded JSON object. If the file is missing, unreadable, not
        valid JSON, or does not hold a JSON object, a default state is
        returned instead: three fresh 10000-balance accounts ('short',
        'medium', 'long') and ``last_updated`` of ``None``. Callers can
        therefore always use ``state.get(...)`` on the result.
    """
    try:
        # Explicit UTF-8 rather than the platform default encoding, which
        # may differ (e.g. on Windows) if the file holds non-ASCII text.
        with open(STATE_FILE, 'r', encoding='utf-8') as f:
            state = json.load(f)
        if isinstance(state, dict):
            return state
        print(f"Error loading state: expected a JSON object, got {type(state).__name__}")
    except FileNotFoundError:
        pass  # Nothing written yet: fall through to the default state.
    except Exception as e:
        print(f"Error loading state: {e}")

    # Default state
    return {
        'accounts': {
            'short': _default_account('short', 10000),
            'medium': _default_account('medium', 10000),
            'long': _default_account('long', 10000),
        },
        'last_updated': None,
    }
def _default_account(timeframe: str, initial_balance: float) -> Dict:
return {
'timeframe': timeframe,
'initial_balance': initial_balance,
'realized_pnl': 0.0,
'leverage': 10, # 所有周期统一 10 倍杠杆
'position': None,
'trades': [],
'stats': {
'total_trades': 0,
'winning_trades': 0,
'losing_trades': 0,
'total_pnl': 0.0,
'max_drawdown': 0.0,
'peak_balance': initial_balance,
'win_rate': 0.0,
},
'equity_curve': [],
}
def load_latest_signal() -> Dict[str, Any]:
    """Load the most recent signal snapshot from ``SIGNAL_FILE``.

    Returns:
        The decoded JSON object, or ``{}`` if the file is missing,
        unreadable, not valid JSON, or does not hold a JSON object. The
        result is therefore always safe for ``.get(...)``.
    """
    try:
        # Explicit UTF-8 rather than the platform default encoding, which
        # may differ (e.g. on Windows) if the file holds non-ASCII text.
        with open(SIGNAL_FILE, 'r', encoding='utf-8') as f:
            signal = json.load(f)
    except FileNotFoundError:
        return {}  # No signal written yet.
    except Exception as e:
        print(f"Error loading signal: {e}")
        return {}

    if not isinstance(signal, dict):
        print(f"Error loading signal: expected a JSON object, got {type(signal).__name__}")
        return {}
    return signal
@app.get("/")
async def root():
    """Serve the dashboard front end (static/index.html next to this module).

    Falls back to a small JSON error body when the page is not bundled.
    """
    index_page = Path(__file__).parent.joinpath('static', 'index.html')
    if not index_page.exists():
        return {"error": "Static files not found"}
    return FileResponse(index_page)
@app.get("/api/status")
async def get_status():
    """Summarise every timeframe account plus portfolio-wide totals.

    Equity is ``initial_balance + realized_pnl`` only. Unrealized PnL is not
    included because this endpoint does not fetch a live price. For legacy
    records that carry ``balance`` but no ``realized_pnl``, realized PnL is
    derived as ``balance - initial_balance``.
    """
    # Display names (Chinese, English); any other key is labelled "long".
    labels = {
        'short': ('短周期', 'Short-term'),
        'medium': ('中周期', 'Medium-term'),
    }
    fallback_label = ('长周期', 'Long-term')

    state = load_trading_state()
    total_initial = total_realized_pnl = total_equity = 0
    timeframes = {}

    for tf_key, acc in state.get('accounts', {}).items():
        initial = acc.get('initial_balance', 0)
        if 'realized_pnl' not in acc and 'balance' in acc:
            # Legacy record: only the running balance was stored.
            realized_pnl = acc['balance'] - initial
        else:
            realized_pnl = acc.get('realized_pnl', 0)

        equity = initial + realized_pnl
        position = acc.get('position')
        used_margin = 0 if not position else position.get('margin', 0)

        total_initial += initial
        total_realized_pnl += realized_pnl
        total_equity += equity

        name_zh, name_en = labels.get(tf_key, fallback_label)
        timeframes[tf_key] = {
            'name': name_zh,
            'name_en': name_en,
            'initial_balance': initial,
            'realized_pnl': realized_pnl,
            'equity': equity,
            'available_balance': equity - used_margin,
            'used_margin': used_margin,
            'return_pct': (equity - initial) / initial * 100 if initial > 0 else 0,
            'leverage': acc.get('leverage', 10),
            'position': position,
            'stats': acc.get('stats', {}),
        }

    if total_initial > 0:
        total_return = (total_equity - total_initial) / total_initial * 100
    else:
        total_return = 0

    return {
        'timestamp': datetime.now().isoformat(),
        'total_initial_balance': total_initial,
        'total_realized_pnl': total_realized_pnl,
        'total_equity': total_equity,
        'total_return': total_return,
        'timeframes': timeframes,
        'last_updated': state.get('last_updated'),
    }
@app.get("/api/trades")
async def get_trades(timeframe: Optional[str] = None, limit: int = 50):
    """Return trade records across accounts, most recent ``exit_time`` first.

    Args:
        timeframe: Restrict to one account key (e.g. 'short'). ``None`` or
            an empty string means all accounts.
        limit: Maximum number of trades returned; ``<= 0`` returns all.

    Returns:
        ``{'total': <count before limiting>, 'trades': [...]}``.
    """
    accounts = load_trading_state().get('accounts', {})
    all_trades = [
        trade
        for tf_key, acc in accounts.items()
        if not timeframe or tf_key == timeframe
        for trade in acc.get('trades', [])
    ]
    # A missing *or null* exit_time sorts as oldest. ``.get('exit_time', '')``
    # alone returns None for an explicit null, and sorting then raises
    # TypeError comparing None with str. NOTE(review): assumes exit_time
    # values are lexicographically sortable strings (e.g. ISO-8601); confirm.
    all_trades.sort(key=lambda t: t.get('exit_time') or '', reverse=True)
    return {
        'total': len(all_trades),
        'trades': all_trades[:limit] if limit > 0 else all_trades,
    }
@app.get("/api/equity")
async def get_equity_curve(timeframe: str = None, limit: int = 500):
    """Return equity-curve points keyed by account.

    ``timeframe`` (if truthy) restricts the result to that account key.
    ``limit`` keeps only the last *limit* points per account, and ``<= 0``
    returns each curve unchanged.
    """
    accounts = load_trading_state().get('accounts', {})
    selected = {
        tf_key: acc.get('equity_curve', [])
        for tf_key, acc in accounts.items()
        if not timeframe or tf_key == timeframe
    }
    if limit > 0:
        selected = {tf_key: curve[-limit:] for tf_key, curve in selected.items()}
    return {'data': selected}
@app.get("/api/signal")
async def get_signal():
    """Flatten the latest aggregated signal into a dashboard-friendly shape.

    Nested sections of the signal file may be absent or ``null``. Both cases
    are treated as empty, so a partial signal never crashes the endpoint.
    Each timeframe's opportunity is read from a primary key with a fallback:
    ``short_term_5m_15m_1h`` then ``intraday``, and ``medium_term_4h_1d``
    then ``swing``. ``long`` uses ``long_term_1d_1w`` only.
    """
    signal = load_latest_signal()
    # ``or {}`` (not just a .get default) also covers keys present as null.
    agg = signal.get('aggregated_signal') or {}
    llm = agg.get('llm_signal') or {}
    market = signal.get('market_analysis') or {}
    levels = agg.get('levels') or {}
    opportunities = llm.get('opportunities') or {}

    return {
        'timestamp': agg.get('timestamp'),
        'final_signal': agg.get('final_signal'),
        'final_confidence': agg.get('final_confidence'),
        'current_price': levels.get('current_price') or market.get('price'),
        'opportunities': {
            'short': opportunities.get('short_term_5m_15m_1h') or opportunities.get('intraday'),
            'medium': opportunities.get('medium_term_4h_1d') or opportunities.get('swing'),
            'long': opportunities.get('long_term_1d_1w'),
        },
        'reasoning': llm.get('reasoning'),
        'recommendations': llm.get('recommendations_by_timeframe', {}),
    }
@app.get("/api/timeframe/{timeframe}")
async def get_timeframe_detail(timeframe: str):
    """Return one timeframe account with its last 20 trades and 200 equity points.

    ``balance`` is ``initial_balance + realized_pnl``, matching ``equity`` in
    ``/api/status``. Legacy records that carry ``balance`` but no
    ``realized_pnl`` use the stored balance. Reading ``acc['balance']`` alone
    reported 0 (and a -100% return) for current-format accounts, which have
    no ``balance`` key.

    Returns an ``{"error": ...}`` body if *timeframe* is not a known account.
    """
    accounts = load_trading_state().get('accounts', {})
    if timeframe not in accounts:
        return {"error": f"Timeframe '{timeframe}' not found"}

    acc = accounts[timeframe]
    initial = acc.get('initial_balance', 0)
    if 'realized_pnl' not in acc and 'balance' in acc:
        balance = acc['balance']  # legacy format stores the balance directly
    else:
        balance = initial + acc.get('realized_pnl', 0)

    return {
        'timeframe': timeframe,
        'balance': balance,
        'initial_balance': initial,
        'return_pct': (balance - initial) / initial * 100 if initial > 0 else 0,
        # Same default as /api/status and the default account template.
        'leverage': acc.get('leverage', 10),
        'position': acc.get('position'),
        'stats': acc.get('stats', {}),
        'recent_trades': acc.get('trades', [])[-20:],
        'equity_curve': acc.get('equity_curve', [])[-200:],
    }
def _file_mtime(path: Path) -> float:
    """Return *path*'s modification time, or 0 if it does not exist.

    Uses EAFP instead of ``exists()`` + ``stat()``. The file can vanish
    between those two calls if it is deleted or re-created concurrently,
    which previously raised out of the polling loop and dropped the client.
    """
    try:
        return path.stat().st_mtime
    except FileNotFoundError:
        return 0


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Push live dashboard updates to one WebSocket client.

    First sends an ``init`` message (full state, latest signal, current
    price). Then, once per second:

    * ``price_update`` when the fetched Binance price is non-empty and changed;
    * ``state_update`` / ``signal_update`` when the corresponding JSON file's
      modification time advanced.

    The connection is removed from ``manager`` on every exit path, including
    task cancellation. ``CancelledError`` is not an ``Exception`` subclass,
    so except-clause-only cleanup missed it and leaked the socket.
    """
    await manager.connect(websocket)
    try:
        current_price = await fetch_binance_price()

        # Record mtimes *before* reading the files: a write landing between
        # the read and the stat is then re-sent on the next poll, not lost.
        last_state_mtime = _file_mtime(STATE_FILE)
        last_signal_mtime = _file_mtime(SIGNAL_FILE)
        await websocket.send_json({
            'type': 'init',
            'state': load_trading_state(),
            'signal': load_latest_signal(),
            'current_price': current_price,
        })
        last_price = current_price

        while True:
            await asyncio.sleep(1)

            current_price = await fetch_binance_price()
            if current_price and current_price != last_price:
                last_price = current_price
                await websocket.send_json({
                    'type': 'price_update',
                    'current_price': current_price,
                    'timestamp': datetime.now().isoformat(),
                })

            state_mtime = _file_mtime(STATE_FILE)
            if state_mtime > last_state_mtime:
                last_state_mtime = state_mtime
                await websocket.send_json({
                    'type': 'state_update',
                    'state': load_trading_state(),
                })

            signal_mtime = _file_mtime(SIGNAL_FILE)
            if signal_mtime > last_signal_mtime:
                last_signal_mtime = signal_mtime
                await websocket.send_json({
                    'type': 'signal_update',
                    'signal': load_latest_signal(),
                })
    except WebSocketDisconnect:
        pass  # Normal client disconnect.
    except Exception as e:
        print(f"WebSocket error: {e}")
    finally:
        manager.disconnect(websocket)
# Serve the bundled front-end assets under /static when the directory exists.
static_dir = Path(__file__).parent.joinpath('static')
if static_dir.exists():
    app.mount(
        "/static",
        StaticFiles(directory=str(static_dir)),
        name="static",
    )
if __name__ == "__main__":
    # Development entry point: serve the app on all interfaces, port 8000.
    from uvicorn import run as run_server

    run_server(app, host="0.0.0.0", port=8000)