tradusai/web/api.py
2025-12-09 22:46:04 +08:00

663 lines
22 KiB
Python

"""
FastAPI Web Service - 多币种多周期交易状态展示 API
"""
import json
import asyncio
import urllib.request
import ssl
import sys
from datetime import datetime
from pathlib import Path
from typing import Dict, Any, List, Optional
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
# Make the directory two levels above this file (where `config/` and
# `output/` live) importable so that `config.settings` below resolves.
sys.path.insert(0, str(Path(__file__).parent.parent))
from config.settings import settings

# JSON files under <that directory>/output. This module only reads them.
#   STATE_FILE   - paper-trading account state (multi-symbol 'symbols' layout
#                  or legacy single-symbol 'accounts' layout).
#   SIGNAL_FILE  - legacy single-signal file, treated as the signal of SYMBOLS[0].
#   SIGNALS_FILE - multi-symbol signals: {'timestamp': ..., 'symbols': {symbol: signal}}.
STATE_FILE = Path(__file__).parent.parent / 'output' / 'paper_trading_state.json'
SIGNAL_FILE = Path(__file__).parent.parent / 'output' / 'latest_signal.json'
SIGNALS_FILE = Path(__file__).parent.parent / 'output' / 'latest_signals.json'

# Supported symbols, from settings. SYMBOLS[0] doubles as the default /
# backward-compatible symbol throughout this module.
SYMBOLS = settings.symbols_list

# Binance futures ticker-price endpoint; queried with ?symbol=<SYMBOL>.
BINANCE_PRICE_BASE_URL = "https://fapi.binance.com/fapi/v1/ticker/price"

app = FastAPI(title="Trading Dashboard", version="2.0.0")

# Process-wide price cache shared by every request and WebSocket client:
# symbol -> last successfully fetched price, plus the time of the last refresh
# that returned any price (None until the first one).
_current_prices: Dict[str, float] = {}
_price_update_time: Optional[datetime] = None
async def fetch_binance_prices() -> Dict[str, float]:
    """Refresh the module-level price cache from Binance and return a snapshot.

    The blocking HTTP work (``_fetch_prices_sync``) runs in the default
    thread-pool executor so the event loop is not blocked. Successfully fetched
    prices overwrite cached ones; symbols whose fetch failed keep their last
    known price. On any error the cache is left unchanged and still returned.

    Returns:
        A copy of the cache (symbol -> last known price). It is a copy so
        callers cannot accidentally mutate the shared cache.
    """
    # Only the timestamp is rebound; the price dict is mutated in place.
    global _price_update_time
    try:
        # get_running_loop() is the correct call inside a coroutine;
        # get_event_loop() is deprecated for this use.
        loop = asyncio.get_running_loop()
        prices = await loop.run_in_executor(None, _fetch_prices_sync)
        if prices:
            _current_prices.update(prices)
            _price_update_time = datetime.now()
    except Exception as e:
        print(f"Error fetching Binance prices: {type(e).__name__}: {e}")
    return dict(_current_prices)
async def fetch_binance_price(symbol: str = 'BTCUSDT') -> Optional[float]:
    """Return the live price of one symbol (kept for backward compatibility).

    Refreshes all symbols through fetch_binance_prices() and picks out
    ``symbol``; returns None when no price is known for it.
    """
    return (await fetch_binance_prices()).get(symbol)
def _fetch_prices_sync(symbols: Optional[List[str]] = None, timeout: float = 5) -> Dict[str, float]:
    """Fetch the latest price of each symbol from BINANCE_PRICE_BASE_URL (blocking).

    Requests are made one after another, so the worst-case duration is roughly
    ``len(symbols) * timeout`` seconds.

    Args:
        symbols: Symbols to query; defaults to SYMBOLS.
        timeout: Per-request timeout in seconds (default 5, as before).

    Returns:
        Mapping symbol -> price for every symbol fetched successfully. A failure
        for one symbol is printed and skipped, so the result may be partial or empty.
    """
    if symbols is None:
        symbols = SYMBOLS
    prices = {}
    try:
        # One TLS context is reused for all requests in this call.
        ctx = ssl.create_default_context()
        for symbol in symbols:
            try:
                url = f"{BINANCE_PRICE_BASE_URL}?symbol={symbol}"
                req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'})
                with urllib.request.urlopen(req, timeout=timeout, context=ctx) as response:
                    data = json.loads(response.read().decode('utf-8'))
                prices[symbol] = float(data['price'])
            except Exception as e:
                print(f"Fetch {symbol} price error: {type(e).__name__}: {e}")
    except Exception as e:
        print(f"Sync fetch error: {type(e).__name__}: {e}")
    return prices
def _fetch_price_sync(symbol: str = 'BTCUSDT') -> Optional[float]:
    """Fetch one symbol's price synchronously (kept for backward compatibility).

    Returns the price as a float, or None (after printing the error) if the
    request, decoding or parsing fails for any reason.
    """
    try:
        request = urllib.request.Request(
            f"{BINANCE_PRICE_BASE_URL}?symbol={symbol}",
            headers={'User-Agent': 'Mozilla/5.0'},
        )
        tls = ssl.create_default_context()
        with urllib.request.urlopen(request, timeout=5, context=tls) as resp:
            payload = json.loads(resp.read().decode('utf-8'))
        return float(payload['price'])
    except Exception as e:
        print(f"Sync fetch error: {type(e).__name__}: {e}")
        return None
class ConnectionManager:
    """Track open WebSocket connections and fan messages out to them."""

    def __init__(self):
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept the WebSocket handshake and start tracking the connection."""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        """Stop tracking ``websocket``; a no-op if it is not tracked."""
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)

    async def broadcast(self, message: dict):
        """Send ``message`` as JSON to every tracked connection.

        A connection whose send raises is treated as dead and dropped, so one
        broken client neither aborts the broadcast nor is retried forever.
        Only ``Exception`` is caught, so task cancellation still propagates.
        """
        dead = []
        # Iterate over a snapshot: other coroutines may connect/disconnect
        # while this one is suspended in send_json, which would otherwise
        # make the loop skip connections.
        for connection in list(self.active_connections):
            try:
                await connection.send_json(message)
            except Exception as e:
                print(f"Broadcast error: {type(e).__name__}: {e}")
                dead.append(connection)
        for connection in dead:
            self.disconnect(connection)


manager = ConnectionManager()
def load_trading_state() -> Dict[str, Any]:
    """Load the paper-trading state from STATE_FILE.

    Returns:
        The parsed JSON object if STATE_FILE exists and holds a JSON object.
        Otherwise (missing file, read/parse error, or a top-level value that is
        not an object) a default multi-symbol state: short/medium/long accounts
        with a 10000 initial balance for every symbol in SYMBOLS, plus a
        backward-compatible ``accounts`` view of the first symbol. Errors are
        printed, never raised.
    """
    try:
        if STATE_FILE.exists():
            with open(STATE_FILE, 'r', encoding='utf-8') as f:
                data = json.load(f)
            if isinstance(data, dict):
                return data
            # Every caller treats the state as a dict (`'symbols' in state`,
            # `state.get(...)`), so a list/null/scalar would crash them.
            print(f"Error loading state: expected a JSON object, got {type(data).__name__}")
    except Exception as e:
        print(f"Error loading state: {e}")

    # Default state in the multi-symbol format.
    default_symbols = {
        symbol: {
            timeframe: _default_account(timeframe, 10000, symbol)
            for timeframe in ('short', 'medium', 'long')
        }
        for symbol in SYMBOLS
    }
    return {
        'symbols': default_symbols,
        'accounts': default_symbols.get(SYMBOLS[0], {}) if SYMBOLS else {},  # backward compatibility
        'last_updated': None,
    }
def _default_account(timeframe: str, initial_balance: float, symbol: str = 'BTCUSDT') -> Dict:
    """Build a fresh paper account with no position, no trades and zeroed stats.

    Every call creates new list/dict objects, so accounts never share state.
    """
    empty_stats = dict(
        total_trades=0,
        winning_trades=0,
        losing_trades=0,
        total_pnl=0.0,
        max_drawdown=0.0,
        peak_balance=initial_balance,
        win_rate=0.0,
    )
    account = dict(
        timeframe=timeframe,
        symbol=symbol,
        initial_balance=initial_balance,
        realized_pnl=0.0,
        leverage=10,  # every timeframe uses the same 10x leverage
        position=None,
        trades=[],
        stats=empty_stats,
        equity_curve=[],
    )
    return account
def _read_json_file(path: Path) -> Any:
    """Read and parse the UTF-8 JSON document at ``path``."""
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def load_latest_signal(symbol: str = None) -> Dict[str, Any]:
    """Load the most recent trading signal(s).

    Args:
        symbol: When given, load only that symbol's signal, trying in order:
            ``output/signal_<symbol lowercased>.json``; the legacy SIGNAL_FILE,
            but only when ``symbol`` is the legacy symbol (first entry of
            SYMBOLS, or 'BTCUSDT'); then ``symbol``'s entry in SIGNALS_FILE.
            When empty, load all symbols: SIGNALS_FILE as-is, or else
            SIGNAL_FILE wrapped as ``{'timestamp': ..., 'symbols': {legacy: data}}``.

    Returns:
        The signal data, or ``{}`` when nothing matched or a read failed
        (errors are printed, not raised).
    """
    # The legacy single-signal file belongs to this symbol (it is mapped to it
    # in the all-symbols branch below).
    legacy_symbol = SYMBOLS[0] if SYMBOLS else 'BTCUSDT'
    try:
        if symbol:
            filename = f'signal_{symbol.lower()}.json'
            # `symbol` can arrive straight from a query string; refuse anything
            # containing a path separator so it cannot point outside output/.
            if Path(filename).name != filename:
                print(f"Error loading signal: invalid symbol {symbol!r}")
                return {}
            symbol_file = Path(__file__).parent.parent / 'output' / filename
            if symbol_file.exists():
                return _read_json_file(symbol_file)
            # Serving the legacy file for any other symbol would return the
            # legacy symbol's signal under the wrong name.
            if symbol.upper() == legacy_symbol.upper() and SIGNAL_FILE.exists():
                return _read_json_file(SIGNAL_FILE)
            if SIGNALS_FILE.exists():
                aggregated = _read_json_file(SIGNALS_FILE)
                if isinstance(aggregated, dict):
                    entry = (aggregated.get('symbols') or {}).get(symbol)
                    if entry:
                        return entry
        else:
            if SIGNALS_FILE.exists():
                return _read_json_file(SIGNALS_FILE)
            if SIGNAL_FILE.exists():
                data = _read_json_file(SIGNAL_FILE)
                # Convert the legacy single signal to the multi-symbol format.
                return {
                    'timestamp': data.get('timestamp'),
                    'symbols': {legacy_symbol: data},
                }
    except Exception as e:
        print(f"Error loading signal: {e}")
    return {}
@app.get("/")
async def root():
    """Serve the dashboard's index.html, or a JSON error when it is missing."""
    index_page = Path(__file__).parent / 'static' / 'index.html'
    if not index_page.exists():
        return {"error": "Static files not found"}
    return FileResponse(index_page)
@app.get("/api/status")
async def get_status(symbol: str = None):
    """Return the trading status for all symbols and timeframes.

    Args:
        symbol: Optional symbol to restrict the status to. When empty, a
            summary across all symbols is returned.
    """
    state = load_trading_state()
    if 'symbols' not in state:
        # Legacy single-symbol layout.
        return _get_legacy_status(state)
    return _get_multi_symbol_status(state, symbol)
def _get_multi_symbol_status(state: Dict, symbol: str = None) -> Dict:
    """Summarise a multi-symbol state.

    With ``symbol``, return that symbol's status (or an error dict if it is not
    in the state). Without it, return every symbol's status plus grand totals.
    The response also repeats the totals and the first symbol's timeframes
    under the older field names.
    """
    symbols_data = state.get('symbols', {})

    if symbol:
        if symbol not in symbols_data:
            return {"error": f"Symbol '{symbol}' not found"}
        return _build_symbol_status(symbol, symbols_data[symbol], state.get('last_updated'))

    per_symbol = {
        sym: _build_symbol_status(sym, accounts, None)
        for sym, accounts in symbols_data.items()
    }
    initial_sum = sum(s.get('total_initial_balance', 0) for s in per_symbol.values())
    pnl_sum = sum(s.get('total_realized_pnl', 0) for s in per_symbol.values())
    equity_sum = sum(s.get('total_equity', 0) for s in per_symbol.values())
    overall_return = (equity_sum - initial_sum) / initial_sum * 100 if initial_sum > 0 else 0

    # Backward compatibility: expose the first configured symbol's timeframes.
    legacy_timeframes = {}
    if SYMBOLS and SYMBOLS[0]:
        legacy_timeframes = per_symbol.get(SYMBOLS[0], {}).get('timeframes', {})

    return {
        'timestamp': datetime.now().isoformat(),
        'symbols': per_symbol,
        'supported_symbols': SYMBOLS,
        'timeframes': legacy_timeframes,  # backward compatibility
        'grand_total_initial_balance': initial_sum,
        'grand_total_realized_pnl': pnl_sum,
        'grand_total_equity': equity_sum,
        'grand_total_return': overall_return,
        # Backward-compatible aliases of the grand totals.
        'total_initial_balance': initial_sum,
        'total_realized_pnl': pnl_sum,
        'total_equity': equity_sum,
        'total_return': overall_return,
        'last_updated': state.get('last_updated'),
    }
def _build_symbol_status(symbol: str, accounts: Dict, last_updated: str = None) -> Dict:
    """Summarise one symbol's timeframe accounts.

    For each account, equity is ``initial_balance + realized_pnl``. Accounts
    that have a ``balance`` but no ``realized_pnl`` derive it as
    ``balance - initial_balance``. Available balance is equity minus the open
    position's margin.
    Any timeframe key other than 'short'/'medium' is labelled as long-term.
    """
    labels = {
        'short': ('短周期', 'Short-term'),
        'medium': ('中周期', 'Medium-term'),
    }
    fallback_label = ('长周期', 'Long-term')

    sum_initial = 0
    sum_pnl = 0
    sum_equity = 0
    per_tf = {}
    for tf_key, acc in accounts.items():
        initial = acc.get('initial_balance', 0)
        if 'realized_pnl' not in acc and 'balance' in acc:
            pnl = acc['balance'] - initial
        else:
            pnl = acc.get('realized_pnl', 0)
        equity = initial + pnl
        position = acc.get('position')
        margin = position.get('margin', 0) if position else 0

        sum_initial += initial
        sum_pnl += pnl
        sum_equity += equity

        name_zh, name_en = labels.get(tf_key, fallback_label)
        per_tf[tf_key] = {
            'name': name_zh,
            'name_en': name_en,
            'symbol': symbol,
            'initial_balance': initial,
            'realized_pnl': pnl,
            'equity': equity,
            'available_balance': equity - margin,
            'used_margin': margin,
            'return_pct': (equity - initial) / initial * 100 if initial > 0 else 0,
            'leverage': acc.get('leverage', 10),
            'position': position,
            'stats': acc.get('stats', {}),
        }

    return {
        'timestamp': datetime.now().isoformat(),
        'symbol': symbol,
        'total_initial_balance': sum_initial,
        'total_realized_pnl': sum_pnl,
        'total_equity': sum_equity,
        'total_return': (sum_equity - sum_initial) / sum_initial * 100 if sum_initial > 0 else 0,
        'timeframes': per_tf,
        'last_updated': last_updated,
    }
def _get_legacy_status(state: Dict) -> Dict:
    """Summarise a legacy single-symbol state (``{'accounts': {...}}``).

    Produces the same per-timeframe figures as the multi-symbol view, but
    without any ``symbol`` fields.
    """
    def label(tf):
        if tf == 'short':
            return '短周期', 'Short-term'
        if tf == 'medium':
            return '中周期', 'Medium-term'
        return '长周期', 'Long-term'

    initial_total = 0
    pnl_total = 0
    equity_total = 0
    views = {}
    for key, account in state.get('accounts', {}).items():
        base = account.get('initial_balance', 0)
        pnl = account.get('realized_pnl', 0)
        # Accounts with only a `balance` field: derive P&L from it.
        if 'balance' in account and 'realized_pnl' not in account:
            pnl = account['balance'] - base
        equity = base + pnl
        pos = account.get('position')
        margin = pos.get('margin', 0) if pos else 0

        initial_total += base
        pnl_total += pnl
        equity_total += equity

        zh, en = label(key)
        views[key] = {
            'name': zh,
            'name_en': en,
            'initial_balance': base,
            'realized_pnl': pnl,
            'equity': equity,
            'available_balance': equity - margin,
            'used_margin': margin,
            'return_pct': (equity - base) / base * 100 if base > 0 else 0,
            'leverage': account.get('leverage', 10),
            'position': pos,
            'stats': account.get('stats', {}),
        }

    overall = (equity_total - initial_total) / initial_total * 100 if initial_total > 0 else 0
    return {
        'timestamp': datetime.now().isoformat(),
        'total_initial_balance': initial_total,
        'total_realized_pnl': pnl_total,
        'total_equity': equity_total,
        'total_return': overall,
        'timeframes': views,
        'last_updated': state.get('last_updated'),
    }
@app.get("/api/trades")
async def get_trades(symbol: str = None, timeframe: str = None, limit: int = 50):
    """Return trade records, newest first.

    Args:
        symbol: Optional symbol filter (multi-symbol state only).
        timeframe: Optional timeframe filter (e.g. short/medium/long).
        limit: Maximum number of trades returned; ``<= 0`` returns all.

    Returns:
        ``{'total': number of matching trades, 'trades': [...]}``. In the
        multi-symbol layout every trade is tagged with its ``symbol``.
    """
    state = load_trading_state()
    all_trades = []

    if 'symbols' in state:
        for sym, accounts in state.get('symbols', {}).items():
            if symbol and sym != symbol:
                continue
            for tf_key, acc in accounts.items():
                if timeframe and tf_key != timeframe:
                    continue
                trades = acc.get('trades', [])
                # Make sure every trade says which symbol it belongs to.
                for trade in trades:
                    trade.setdefault('symbol', sym)
                all_trades.extend(trades)
    else:
        # Legacy single-symbol layout.
        for tf_key, acc in state.get('accounts', {}).items():
            if timeframe and tf_key != timeframe:
                continue
            all_trades.extend(acc.get('trades', []))

    # `or ''` (not a .get default) so a trade whose exit_time is null sorts
    # last instead of raising TypeError when None is compared with a string.
    all_trades.sort(key=lambda t: t.get('exit_time') or '', reverse=True)
    return {
        'total': len(all_trades),
        'trades': all_trades[:limit] if limit > 0 else all_trades,
    }
@app.get("/api/equity")
async def get_equity_curve(symbol: str = None, timeframe: str = None, limit: int = 500):
    """Return equity curves.

    Args:
        symbol: Optional symbol filter (multi-symbol state only).
        timeframe: Optional timeframe filter.
        limit: Keep only the last ``limit`` points of each curve; ``<= 0`` keeps all.

    Returns:
        ``{'data': {symbol: {timeframe: curve}}}`` for the multi-symbol layout
        (symbols with no matching timeframe are omitted), or
        ``{'data': {timeframe: curve}}`` for the legacy layout.
    """
    state = load_trading_state()

    def tail(curve):
        return curve[-limit:] if limit > 0 else curve

    def select(accounts):
        return {
            tf: tail(acc.get('equity_curve', []))
            for tf, acc in accounts.items()
            if not timeframe or tf == timeframe
        }

    if 'symbols' not in state:
        # Legacy single-symbol layout.
        return {'data': select(state.get('accounts', {}))}

    data = {}
    for sym, accounts in state.get('symbols', {}).items():
        if symbol and sym != symbol:
            continue
        picked = select(accounts)
        if picked:
            data[sym] = picked
    return {'data': data}
@app.get("/api/signal")
async def get_signal(symbol: str = None):
    """Return the latest signal(s).

    Args:
        symbol: Optional symbol; when empty, every symbol's signal is returned.
    """
    if symbol:
        return _format_signal_response(load_latest_signal(symbol), symbol)

    all_signals = load_latest_signal()
    if 'symbols' not in all_signals:
        # Legacy single-signal layout.
        default_symbol = SYMBOLS[0] if SYMBOLS else 'BTCUSDT'
        return _format_signal_response(all_signals, default_symbol)

    formatted = {
        sym: _format_signal_response(sig_data, sym)
        for sym, sig_data in all_signals.get('symbols', {}).items()
    }
    return {
        'timestamp': all_signals.get('timestamp'),
        'symbols': formatted,
        'supported_symbols': SYMBOLS,
    }
def _format_signal_response(signal: Dict, symbol: str) -> Dict:
    """Flatten one symbol's raw signal document into the API response shape.

    Args:
        signal: Raw signal document (may be ``None`` or partially filled).
        symbol: Symbol echoed back in the response.

    Returns:
        Dict with the final signal/confidence, current price, per-timeframe
        opportunities (new key names first, older ``intraday``/``swing`` keys
        as fallbacks), reasoning and recommendations. Missing data is ``None``.
    """
    # Sections may be absent OR explicitly null in the JSON; `or {}` treats
    # both the same so a partial signal never raises AttributeError.
    signal = signal or {}
    agg = signal.get('aggregated_signal') or {}
    llm = agg.get('llm_signal') or {}
    market = signal.get('market_analysis') or {}
    opportunities = llm.get('opportunities') or {}
    levels = agg.get('levels') or {}
    return {
        'symbol': symbol,
        'timestamp': signal.get('timestamp') or agg.get('timestamp'),
        'final_signal': agg.get('final_signal'),
        'final_confidence': agg.get('final_confidence'),
        'current_price': levels.get('current_price') or market.get('price'),
        'opportunities': {
            'short': opportunities.get('short_term_5m_15m_1h') or opportunities.get('intraday'),
            'medium': opportunities.get('medium_term_4h_1d') or opportunities.get('swing'),
            'long': opportunities.get('long_term_1d_1w'),
        },
        'reasoning': llm.get('reasoning'),
        'recommendations': llm.get('recommendations_by_timeframe', {}),
    }
@app.get("/api/timeframe/{timeframe}")
async def get_timeframe_detail(timeframe: str, symbol: str = None):
    """Return the detail view of one timeframe account.

    Args:
        timeframe: Account key (e.g. short, medium, long).
        symbol: Symbol to look up in the multi-symbol state; defaults to the
            first entry of SYMBOLS (or 'BTCUSDT'). For the legacy
            single-symbol layout it is only echoed back.

    Returns:
        Account summary with its last 20 trades and last 200 equity points,
        or ``{"error": ...}`` when the symbol or timeframe is unknown.
    """
    state = load_trading_state()
    symbol = symbol or (SYMBOLS[0] if SYMBOLS else 'BTCUSDT')

    if 'symbols' in state:
        symbols_data = state.get('symbols', {})
        if symbol not in symbols_data:
            return {"error": f"Symbol '{symbol}' not found"}
        accounts = symbols_data[symbol]
    else:
        accounts = state.get('accounts', {})

    if timeframe not in accounts:
        return {"error": f"Timeframe '{timeframe}' not found"}

    acc = accounts[timeframe]
    initial = acc.get('initial_balance', 0)
    realized_pnl = acc.get('realized_pnl', 0)
    # Accounts that carry `balance` but no `realized_pnl`: derive the P&L the
    # same way /api/status does, so both endpoints report the same equity.
    if 'realized_pnl' not in acc and 'balance' in acc:
        realized_pnl = acc['balance'] - initial
    equity = initial + realized_pnl

    return {
        'symbol': symbol,
        'timeframe': timeframe,
        'equity': equity,
        'initial_balance': initial,
        'realized_pnl': realized_pnl,
        'return_pct': (equity - initial) / initial * 100 if initial > 0 else 0,
        'leverage': acc.get('leverage', 10),
        'position': acc.get('position'),
        'stats': acc.get('stats', {}),
        'recent_trades': acc.get('trades', [])[-20:],
        'equity_curve': acc.get('equity_curve', [])[-200:],
    }
@app.get("/api/prices")
async def get_prices():
    """Return the latest price of every supported symbol."""
    latest = await fetch_binance_prices()
    payload = {
        'timestamp': datetime.now().isoformat(),
        'prices': latest,
        'supported_symbols': SYMBOLS,
    }
    return payload
def _file_mtime(path: Path) -> float:
    """Return ``path``'s modification time, or 0 when it cannot be stat'ed.

    A single stat() call (rather than exists() followed by stat()) means a file
    that is deleted or replaced in between cannot raise FileNotFoundError.
    """
    try:
        return path.stat().st_mtime
    except OSError:
        return 0


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Push live dashboard data to one WebSocket client.

    Sends one ``init`` message (full state, all signals, prices). It then polls
    once per second and sends:
      * ``price_update`` when any fetched price differs from the last one sent;
      * ``state_update`` when STATE_FILE's mtime increases;
      * ``signal_update`` when SIGNAL_FILE's or SIGNALS_FILE's mtime increases.

    The connection is always removed from ``manager`` when the handler exits:
    on client disconnect, on error, or on task cancellation.
    """
    await manager.connect(websocket)
    try:
        # Initial snapshot: prices, trading state and signals.
        current_prices = await fetch_binance_prices()
        state = load_trading_state()
        signal = load_latest_signal()
        await websocket.send_json({
            'type': 'init',
            'state': state,
            'signal': signal,
            'prices': current_prices,
            'current_price': current_prices.get(SYMBOLS[0]) if SYMBOLS else None,  # backward compatibility
            'supported_symbols': SYMBOLS,
        })

        last_state_mtime = _file_mtime(STATE_FILE)
        last_signal_mtime = _file_mtime(SIGNAL_FILE)
        last_signals_mtime = _file_mtime(SIGNALS_FILE)
        last_prices = current_prices.copy()

        while True:
            await asyncio.sleep(1)

            # NOTE(review): every connected client triggers its own Binance
            # fetch each second; a shared background poller would scale better.
            current_prices = await fetch_binance_prices()
            price_changed = any(
                price and price != last_prices.get(sym)
                for sym, price in current_prices.items()
            )
            if price_changed:
                last_prices = current_prices.copy()
                await websocket.send_json({
                    'type': 'price_update',
                    'prices': current_prices,
                    'current_price': current_prices.get(SYMBOLS[0]) if SYMBOLS else None,  # backward compatibility
                    'timestamp': datetime.now().isoformat(),
                })

            # Detect updates to the state and signal files.
            current_state_mtime = _file_mtime(STATE_FILE)
            current_signal_mtime = _file_mtime(SIGNAL_FILE)
            current_signals_mtime = _file_mtime(SIGNALS_FILE)

            if current_state_mtime > last_state_mtime:
                last_state_mtime = current_state_mtime
                state = load_trading_state()
                await websocket.send_json({
                    'type': 'state_update',
                    'state': state,
                })

            # Either the multi-symbol or the legacy signal file changed.
            signal_updated = (current_signal_mtime > last_signal_mtime or
                              current_signals_mtime > last_signals_mtime)
            if signal_updated:
                last_signal_mtime = current_signal_mtime
                last_signals_mtime = current_signals_mtime
                signal = load_latest_signal()
                await websocket.send_json({
                    'type': 'signal_update',
                    'signal': signal,
                })
    except WebSocketDisconnect:
        pass
    except Exception as e:
        print(f"WebSocket error: {e}")
    finally:
        # Runs on cancellation too, so dead sockets never linger in `manager`.
        manager.disconnect(websocket)
# Serve frontend assets under /static. Skipped when the directory is absent,
# so the API still starts without a built frontend.
static_dir = Path(__file__).parent / 'static'
# is_dir() rather than exists(): StaticFiles rejects a path that is not a
# directory, so a stray file named "static" would otherwise crash startup.
if static_dir.is_dir():
    app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")
if __name__ == "__main__":
    # Run the dashboard directly with uvicorn on all interfaces, port 8000.
    from uvicorn import run
    run(app, host="0.0.0.0", port=8000)