226 lines
7.3 KiB
Python
226 lines
7.3 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
A股量化交易系统 Web 展示界面
|
|
使用Flask框架展示策略筛选结果
|
|
"""
|
|
|
|
import math
import os
import sys
from datetime import datetime, date, timedelta
from pathlib import Path

import pandas as pd
from flask import Flask, render_template, jsonify, request
|
|
|
|
# Put the project root (the parent of this file's directory) at the front of
# sys.path so the ``src.*`` imports that follow can be resolved.
current_dir = Path(__file__).parent
project_root = current_dir.parent
sys.path[:0] = [str(project_root)]
|
|
|
|
from src.database.database_manager import DatabaseManager
|
|
from src.utils.config_loader import ConfigLoader
|
|
from loguru import logger
|
|
|
|
app = Flask(__name__)

# Read the session-signing secret from the environment. The previously
# hard-coded value is kept only as a fallback so existing setups keep working.
# NOTE(review): a secret committed to source control is public; set
# FLASK_SECRET_KEY in any non-local deployment.
app.secret_key = os.environ.get('FLASK_SECRET_KEY', 'trading_ai_secret_key_2023')

# Initialise shared components used by the route handlers.
db_manager = DatabaseManager()
config_loader = ConfigLoader()
|
|
|
|
|
|
@app.route('/')
def index():
    """Home page: send the visitor straight on to the trading-signals page."""
    from flask import redirect, url_for

    signals_url = url_for('signals')
    return redirect(signals_url)
|
|
|
|
|
|
@app.route('/signals')
def signals():
    """Signal list page: signals grouped by scan date, with pagination.

    Query parameters (all optional; non-numeric numbers fall back to the default):
        strategy:  strategy-name filter; empty means no filter.
        timeframe: timeframe filter; empty means no filter.
        days:      look-back window in days ending today (default 30, clamped to >= 0).
        page:      1-based page number (default 1, clamped to >= 1).
        per_page:  rows per page (default 20, clamped to >= 1).

    Rows are ordered newest scan date first and, within one scan date, newest
    signal date first. The page slice is taken from that flat ordering and is
    then regrouped by scan date for the template. Any exception renders
    ``error.html`` instead.
    """
    try:
        strategy_name = request.args.get('strategy', '')
        timeframe = request.args.get('timeframe', '')
        # ``type=int`` returns the default instead of raising on bad input.
        # The clamps reject negative windows/pages and a zero divisor below.
        days = max(request.args.get('days', 30, type=int), 0)
        page = max(request.args.get('page', 1, type=int), 1)
        per_page = max(request.args.get('per_page', 20, type=int), 1)

        end_date = date.today()
        start_date = end_date - timedelta(days=days)

        signals_df = db_manager.get_signals_by_date_range(
            start_date=start_date,
            end_date=end_date,
            strategy_name=strategy_name or None,
            timeframe=timeframe or None,
        )

        records = _sorted_signal_records(signals_df)

        # Count only rows that can actually be displayed so the page count matches.
        total_records = len(records)
        total_pages = (total_records + per_page - 1) // per_page
        start_idx = (page - 1) * per_page
        paginated_grouped = _group_by_scan_date(records[start_idx:start_idx + per_page])

        return render_template('signals.html',
                               signals_grouped=paginated_grouped,
                               current_page=page,
                               total_pages=total_pages,
                               has_prev=page > 1,
                               has_next=page < total_pages,
                               total_records=total_records,
                               strategy_name=strategy_name,
                               timeframe=timeframe,
                               days=days,
                               per_page=per_page)
    except Exception as e:
        # logger.exception also records the traceback.
        logger.exception(f"信号页面数据加载失败: {e}")
        return render_template('error.html', error=str(e))


def _sorted_signal_records(signals_df):
    """Return signal rows as dicts, sorted by scan date then signal date (both newest first).

    Each dict gets a ``scan_date`` key (``datetime.date``) derived from ``scan_time``.
    """
    if signals_df.empty:
        return []
    df = signals_df.assign(scan_time=pd.to_datetime(signals_df['scan_time']))
    # A row without a scan time has no scan-date group to be shown under,
    # so it is excluded (and therefore not counted either).
    df = df[df['scan_time'].notna()]
    df = df.assign(scan_date=df['scan_time'].dt.date)
    df = df.sort_values(['scan_date', 'signal_date'], ascending=False)
    return df.to_dict('records')


def _group_by_scan_date(records):
    """Group record dicts by their ``scan_date`` key, preserving input order."""
    grouped = {}
    for record in records:
        grouped.setdefault(record['scan_date'], []).append(record)
    return grouped
|
|
|
|
|
|
@app.route('/pullbacks')
def pullbacks():
    """Pullback-monitor page.

    Query parameters:
        days: look-back window passed to ``db_manager.get_pullback_alerts``
              (default 30; a non-numeric value falls back to the default).

    Any exception renders ``error.html`` instead.
    """
    try:
        # ``type=int`` returns the default instead of raising ValueError.
        days = request.args.get('days', 30, type=int)
        pullback_alerts = db_manager.get_pullback_alerts(days=days)
        alert_records = [] if pullback_alerts.empty else pullback_alerts.to_dict('records')

        return render_template('pullbacks.html',
                               pullback_alerts=alert_records,
                               days=days)
    except Exception as e:
        # logger.exception also records the traceback.
        logger.exception(f"回踩监控页面数据加载失败: {e}")
        return render_template('error.html', error=str(e))
|
|
|
|
|
|
@app.route('/api/signals')
def api_signals():
    """JSON API: latest signals from ``db_manager.get_latest_signals``.

    Query parameters:
        strategy: strategy-name filter; empty means no filter.
        limit:    maximum rows requested (default 100; a non-numeric value
                  falls back to the default, values below 1 are raised to 1).

    Returns ``{'success': True, 'data': [...], 'total': n}`` on success, or
    ``{'success': False, 'error': message}`` (HTTP 200) on failure.
    """
    try:
        strategy_name = request.args.get('strategy', '')
        # Keep a negative/zero limit from being passed through to the query.
        limit = max(request.args.get('limit', 100, type=int), 1)

        signals_df = db_manager.get_latest_signals(
            strategy_name=strategy_name or None,
            limit=limit,
        )

        return jsonify({
            'success': True,
            'data': [] if signals_df.empty else signals_df.to_dict('records'),
            'total': len(signals_df),
        })
    except Exception as e:
        # logger.exception also records the traceback.
        logger.exception(f"API获取信号失败: {e}")
        return jsonify({'success': False, 'error': str(e)})
|
|
|
|
|
|
@app.route('/api/stats')
def api_stats():
    """JSON API: per-strategy statistics from ``db_manager.get_strategy_stats()``.

    Returns ``{'success': True, 'data': [...]}`` on success, or
    ``{'success': False, 'error': message}`` (HTTP 200) on failure.
    """
    try:
        strategy_stats = db_manager.get_strategy_stats()
        stats_records = [] if strategy_stats.empty else strategy_stats.to_dict('records')
        return jsonify({'success': True, 'data': stats_records})
    except Exception as e:
        # logger.exception also records the traceback, unlike logger.error.
        logger.exception(f"API获取统计失败: {e}")
        return jsonify({'success': False, 'error': str(e)})
|
|
|
|
|
|
@app.route('/api/pullbacks')
def api_pullbacks():
    """JSON API: pullback alerts from ``db_manager.get_pullback_alerts``.

    Query parameters:
        days: look-back window (default 7; a non-numeric value falls back to
              the default).

    Returns ``{'success': True, 'data': [...]}`` on success, or
    ``{'success': False, 'error': message}`` (HTTP 200) on failure.
    """
    try:
        # ``type=int`` returns the default instead of raising ValueError.
        days = request.args.get('days', 7, type=int)
        pullback_alerts = db_manager.get_pullback_alerts(days=days)
        alert_records = [] if pullback_alerts.empty else pullback_alerts.to_dict('records')
        return jsonify({'success': True, 'data': alert_records})
    except Exception as e:
        # logger.exception also records the traceback.
        logger.exception(f"API获取回踩提醒失败: {e}")
        return jsonify({'success': False, 'error': str(e)})
|
|
|
|
|
|
@app.template_filter('datetime_format')
|
|
def datetime_format(value, format='%Y-%m-%d %H:%M'):
    """Template filter: format a date/time value with ``strftime``.

    Args:
        value: an object with ``strftime`` (``datetime``, ``date``,
            ``pd.Timestamp``), an ISO-8601 string (a trailing ``Z`` is read
            as UTC), ``None`` or ``pd.NaT``.
        format: ``strftime`` pattern. The name shadows the builtin but is kept
            for templates that pass it by keyword.

    Returns:
        The formatted string; ``''`` for ``None``/``NaT``. A string that is not
        valid ISO-8601 is returned unchanged.
    """
    # NaT (pandas' missing timestamp) is not None, and its strftime raises.
    if value is None or value is pd.NaT:
        return ''
    if isinstance(value, str):
        try:
            value = datetime.fromisoformat(value.replace('Z', '+00:00'))
        except ValueError:
            # Not a parseable timestamp: show the raw text rather than fail.
            return value
    return value.strftime(format)
|
|
|
|
|
|
@app.template_filter('percentage')
|
|
def percentage_format(value, precision=2):
    """Template filter: render a number followed by ``%``.

    The value is printed as given (it is not multiplied by 100). Missing
    values, i.e. ``None`` or NaN (pandas uses NaN for empty numeric cells),
    render as zero at the requested ``precision``.
    """
    number = 0.0 if value is None else float(value)
    if math.isnan(number):
        number = 0.0
    return f"{number:.{precision}f}%"
|
|
|
|
|
|
@app.template_filter('currency')
|
|
def currency_format(value, precision=2):
    """Template filter: render a number with a fixed number of decimals.

    Missing values, i.e. ``None`` or NaN (pandas uses NaN for empty numeric
    cells), render as zero at the requested ``precision``.
    """
    amount = 0.0 if value is None else float(value)
    if math.isnan(amount):
        amount = 0.0
    return f"{amount:.{precision}f}"
|
|
|
|
|
|
if __name__ == '__main__':
    # Log to stdout with a compact format.
    logger.remove()
    logger.add(sys.stdout, level="INFO", format="{time:HH:mm:ss} | {level} | {message}")

    # Host and port keep their previous defaults but can be overridden.
    host = os.environ.get('WEB_HOST', '0.0.0.0')
    port = int(os.environ.get('WEB_PORT', '8080'))
    # The Werkzeug debugger lets anyone who can reach the server run arbitrary
    # Python code. Since the default host listens on all interfaces, debug mode
    # is opt-in: set WEB_DEBUG=1 (or true/yes) to enable it.
    debug = os.environ.get('WEB_DEBUG', '').strip().lower() in ('1', 'true', 'yes')

    print("=" * 60)
    print("🌐 A股量化交易系统 Web 界面")
    print("=" * 60)
    print("🚀 启动 Flask 服务器...")
    print(f"📊 访问地址: http://localhost:{port}")
    print("=" * 60)

    app.run(host=host, port=port, debug=debug)