trading.ai/web/app.py
2025-09-18 20:45:01 +08:00

226 lines
7.3 KiB
Python

#!/usr/bin/env python3
"""
A股量化交易系统 Web 展示界面
使用Flask框架展示策略筛选结果
"""
import os
import sys
from datetime import datetime, date, timedelta
from pathlib import Path

import pandas as pd
from flask import Flask, render_template, jsonify, request
# Put the project root (the parent of this file's directory) on sys.path so the
# `src.*` imports below resolve. resolve() makes this independent of the CWD and of
# how the script path was given on the command line.
current_dir = Path(__file__).resolve().parent
project_root = current_dir.parent
sys.path.insert(0, str(project_root))
from src.database.database_manager import DatabaseManager
from src.utils.config_loader import ConfigLoader
from loguru import logger

app = Flask(__name__)
# The signing key must not live in source control: read it from the environment.
# The previous hard-coded value is kept only as a development fallback so existing
# setups keep working; set FLASK_SECRET_KEY in any shared/deployed environment.
app.secret_key = os.environ.get('FLASK_SECRET_KEY', 'trading_ai_secret_key_2023')

# Module-level component instances; db_manager is used by the route handlers.
db_manager = DatabaseManager()
config_loader = ConfigLoader()
@app.route('/')
def index():
    """Home page: redirect the visitor to the trading-signals page."""
    from flask import redirect, url_for

    signals_url = url_for('signals')
    return redirect(signals_url)
def _flatten_signals_by_scan_date(signals_df):
    """Return signal records ordered by scan date, then signal date (both newest first).

    Mutates *signals_df*: ``scan_time`` is parsed with ``pd.to_datetime`` and a
    ``scan_date`` column (the calendar date of ``scan_time``) is added, so every
    returned record carries a ``scan_date`` key. Rows whose ``scan_time`` is null are
    dropped, because ``groupby`` excludes null keys by default.
    """
    if signals_df.empty:
        return []
    signals_df['scan_time'] = pd.to_datetime(signals_df['scan_time'])
    signals_df['scan_date'] = signals_df['scan_time'].dt.date
    groups = sorted(signals_df.groupby('scan_date'), key=lambda item: item[0], reverse=True)
    records = []
    for _scan_date, group in groups:
        records.extend(group.sort_values('signal_date', ascending=False).to_dict('records'))
    return records


@app.route('/signals')
def signals():
    """Signals page: a paginated list of signals, grouped by scan date.

    Query parameters:
        strategy:  optional strategy-name filter ('' = all).
        timeframe: optional timeframe filter ('' = all).
        days:      size of the look-back window ending today (default 30, min 0).
        page:      1-based page number (default 1, min 1).
        per_page:  records per page (default 20, min 1).

    Non-numeric numeric parameters fall back to their defaults. On any failure the
    error is logged with its traceback and ``error.html`` is rendered.
    """
    try:
        strategy_name = request.args.get('strategy', '')
        timeframe = request.args.get('timeframe', '')
        # type=int returns the default instead of raising on non-numeric input; the
        # clamps prevent negative slicing (page<=0) and ZeroDivisionError (per_page=0).
        days = max(0, request.args.get('days', 30, type=int))
        page = max(1, request.args.get('page', 1, type=int))
        per_page = max(1, request.args.get('per_page', 20, type=int))

        end_date = date.today()
        start_date = end_date - timedelta(days=days)
        signals_df = db_manager.get_signals_by_date_range(
            start_date=start_date,
            end_date=end_date,
            strategy_name=strategy_name or None,
            timeframe=timeframe or None,
        )

        ordered_signals = _flatten_signals_by_scan_date(signals_df)
        # Count what is actually paginated, so the total agrees with the pages shown
        # (rows dropped during grouping are excluded).
        total_records = len(ordered_signals)
        start_idx = (page - 1) * per_page
        page_signals = ordered_signals[start_idx:start_idx + per_page]

        # Re-group just this page's records. Insertion order keeps newest scan dates first.
        paginated_grouped = {}
        for signal in page_signals:
            paginated_grouped.setdefault(signal['scan_date'], []).append(signal)

        total_pages = (total_records + per_page - 1) // per_page
        return render_template(
            'signals.html',
            signals_grouped=paginated_grouped,
            current_page=page,
            total_pages=total_pages,
            has_prev=page > 1,
            has_next=page < total_pages,
            total_records=total_records,
            strategy_name=strategy_name,
            timeframe=timeframe,
            days=days,
            per_page=per_page,
        )
    except Exception as e:
        logger.exception(f"信号页面数据加载失败: {e}")
        return render_template('error.html', error=str(e))
@app.route('/pullbacks')
def pullbacks():
    """Pullback-monitoring page.

    Renders ``pullbacks.html`` with the records returned by
    ``db_manager.get_pullback_alerts(days=days)``. The ``days`` query parameter
    defaults to 30, falls back to 30 on non-numeric input, and is clamped to >= 0.
    On failure the error is logged with its traceback and ``error.html`` is rendered.
    """
    try:
        days = max(0, request.args.get('days', 30, type=int))
        pullback_alerts = db_manager.get_pullback_alerts(days=days)
        alerts = [] if pullback_alerts.empty else pullback_alerts.to_dict('records')
        return render_template('pullbacks.html', pullback_alerts=alerts, days=days)
    except Exception as e:
        logger.exception(f"回踩监控页面数据加载失败: {e}")
        return render_template('error.html', error=str(e))
@app.route('/api/signals')
def api_signals():
    """JSON API: the latest signals, optionally filtered by strategy.

    Query parameters:
        strategy: optional strategy-name filter ('' = all).
        limit:    maximum rows requested (default 100). Non-numeric input falls back
                  to 100, and the value is clamped to >= 1.

    Returns ``{'success': True, 'data': [...], 'total': n}`` on success, or
    ``{'success': False, 'error': msg}`` on any failure (HTTP 200 in both cases).
    """
    try:
        strategy_name = request.args.get('strategy', '')
        # Clamp: a zero/negative limit is never meaningful here, and some SQL engines
        # treat a negative LIMIT as "no limit".
        limit = max(1, request.args.get('limit', 100, type=int))
        signals_df = db_manager.get_latest_signals(
            strategy_name=strategy_name or None,
            limit=limit,
        )
        if signals_df.empty:
            data = []
        else:
            # NaN/NaT are not valid JSON values; map missing cells to None (-> null).
            data = (
                signals_df.astype(object)
                .where(signals_df.notna(), None)
                .to_dict('records')
            )
        return jsonify({
            'success': True,
            'data': data,
            'total': len(signals_df),
        })
    except Exception as e:
        logger.exception(f"API获取信号失败: {e}")
        return jsonify({'success': False, 'error': str(e)})
@app.route('/api/stats')
def api_stats():
    """JSON API: strategy statistics as returned by ``db_manager.get_strategy_stats()``.

    Returns ``{'success': True, 'data': [...]}`` on success, or
    ``{'success': False, 'error': msg}`` on any failure (HTTP 200 in both cases).
    """
    try:
        strategy_stats = db_manager.get_strategy_stats()
        if strategy_stats.empty:
            data = []
        else:
            # NaN/NaT are not valid JSON values; map missing cells to None (-> null).
            data = (
                strategy_stats.astype(object)
                .where(strategy_stats.notna(), None)
                .to_dict('records')
            )
        return jsonify({'success': True, 'data': data})
    except Exception as e:
        logger.exception(f"API获取统计失败: {e}")
        return jsonify({'success': False, 'error': str(e)})
@app.route('/api/pullbacks')
def api_pullbacks():
    """JSON API: pullback alerts from ``db_manager.get_pullback_alerts(days=days)``.

    The ``days`` query parameter defaults to 7, falls back to 7 on non-numeric input,
    and is clamped to >= 0. Returns ``{'success': True, 'data': [...]}`` on success, or
    ``{'success': False, 'error': msg}`` on any failure (HTTP 200 in both cases).
    """
    try:
        days = max(0, request.args.get('days', 7, type=int))
        pullback_alerts = db_manager.get_pullback_alerts(days=days)
        if pullback_alerts.empty:
            data = []
        else:
            # NaN/NaT are not valid JSON values; map missing cells to None (-> null).
            data = (
                pullback_alerts.astype(object)
                .where(pullback_alerts.notna(), None)
                .to_dict('records')
            )
        return jsonify({'success': True, 'data': data})
    except Exception as e:
        logger.exception(f"API获取回踩提醒失败: {e}")
        return jsonify({'success': False, 'error': str(e)})
@app.template_filter('datetime_format')
def datetime_format(value, format='%Y-%m-%d %H:%M'):
    """Jinja filter: render a datetime, or an ISO-8601 string, with ``strftime(format)``.

    Args:
        value: a datetime/date-like object, an ISO-8601 string (a trailing 'Z' is
            accepted as UTC), None, or a pandas missing value (NaT/NaN).
        format: ``strftime`` pattern (default '%Y-%m-%d %H:%M').

    Returns:
        The formatted string. Returns '' for None or missing values, and returns the
        input unchanged for strings that are not valid ISO-8601.
    """
    # Values coming from DataFrame records can be NaT/NaN, which cannot be strftime'd.
    if value is None or (pd.api.types.is_scalar(value) and pd.isna(value)):
        return ''
    if isinstance(value, str):
        try:
            value = datetime.fromisoformat(value.replace('Z', '+00:00'))
        except ValueError:
            return value
    return value.strftime(format)
@app.template_filter('percentage')
def percentage_format(value, precision=2):
    """Jinja filter: format *value* with *precision* decimals followed by '%'.

    The value is not scaled (5 -> '5.00%'). Missing values (None or NaN) render as
    zero using the requested precision, e.g. '0.00%' with the default precision.
    """
    number = 0.0 if value is None else float(value)
    if pd.isna(number):
        number = 0.0
    return f"{number:.{precision}f}%"
@app.template_filter('currency')
def currency_format(value, precision=2):
    """Jinja filter: format *value* as a plain number with *precision* decimals.

    No currency symbol or thousands separator is added. Missing values (None or NaN)
    render as zero using the requested precision, e.g. '0.00' with the default.
    """
    number = 0.0 if value is None else float(value)
    if pd.isna(number):
        number = 0.0
    return f"{number:.{precision}f}"
if __name__ == '__main__':
    # Console logging for interactive runs.
    logger.remove()
    logger.add(sys.stdout, level="INFO", format="{time:HH:mm:ss} | {level} | {message}")

    # Host/port can be overridden via the environment. Debug mode is opt-in
    # (FLASK_DEBUG=1) because the Werkzeug debugger allows arbitrary code execution
    # from the browser, and the server listens on all interfaces by default.
    host = os.environ.get('WEB_HOST', '0.0.0.0')
    port = int(os.environ.get('WEB_PORT', '8080'))
    debug = os.environ.get('FLASK_DEBUG', '').strip().lower() not in ('', '0', 'false', 'no')

    print("=" * 60)
    print("🌐 A股量化交易系统 Web 界面")
    print("=" * 60)
    print("🚀 启动 Flask 服务器...")
    print(f"📊 访问地址: http://localhost:{port}")
    print("=" * 60)
    app.run(host=host, port=port, debug=debug)