267 lines
8.9 KiB
Python
267 lines
8.9 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
AI 智能选股大师 Web 展示界面
|
||
使用Flask框架展示策略筛选结果
|
||
"""
|
||
|
||
import math
import os
import sys
from datetime import datetime, date, timedelta, timezone
from pathlib import Path

import pandas as pd
from flask import Flask, render_template, jsonify, request
|
||
|
||
# Put the project root (the parent of this file's directory) first on sys.path
# so the `src.*` imports below resolve regardless of the working directory.
# resolve() makes the path absolute: before Python 3.9 the __main__ script's
# __file__ could be relative (e.g. "app.py"), and Path('.').parent is still
# Path('.'), so project_root silently became the CWD instead of the root.
current_dir = Path(__file__).resolve().parent
project_root = current_dir.parent
sys.path.insert(0, str(project_root))
|
||
|
||
from src.database.database_manager import DatabaseManager
|
||
from src.utils.config_loader import ConfigLoader
|
||
from loguru import logger
|
||
|
||
app = Flask(__name__)
# Session-signing key. It is read from the environment so deployments can supply
# a real secret; the legacy hard-coded value is kept only as a fallback and
# should not be relied on outside local development.
app.secret_key = os.environ.get('FLASK_SECRET_KEY', 'trading_ai_secret_key_2023')

# Shared service objects; db_manager is used by the route handlers.
# NOTE(review): config_loader is not referenced by any handler visible in this
# file - confirm it is still needed.
db_manager = DatabaseManager()
config_loader = ConfigLoader()
|
||
|
||
|
||
@app.route('/')
def index():
    """Home page: redirect straight to the trading-signals page."""
    from flask import redirect, url_for

    signals_url = url_for('signals')
    return redirect(signals_url)
|
||
|
||
|
||
def _group_signals_by_scan_hour(signals_df):
    """Order signal rows newest-scan-hour first and return them as records.

    Adds three columns to *signals_df* (mutated in place):
        scan_time        -- parsed to a tz-aware datetime; naive values are
                            assumed to be UTC.
        scan_time_china  -- scan_time converted to UTC+8.
        scan_hour        -- scan_time_china floored to the hour (grouping key).

    Rows whose scan_time is missing (NaT) are dropped, so the length of the
    returned list is the true number of displayable signals.

    Returns:
        list[dict]: records sorted by scan_hour descending, then by
        signal_date descending within each hour.
    """
    if signals_df.empty:
        return []

    china_tz = timezone(timedelta(hours=8))

    scan_time = pd.to_datetime(signals_df['scan_time'])
    if scan_time.dt.tz is None:
        # Naive timestamps are assumed to be stored in UTC.
        scan_time = scan_time.dt.tz_localize('UTC')
    signals_df['scan_time'] = scan_time
    signals_df['scan_time_china'] = scan_time.dt.tz_convert(china_tz)
    signals_df['scan_hour'] = signals_df['scan_time_china'].dt.floor('h')

    # groupby() used to drop NaT keys silently while the total still counted
    # them; drop them explicitly so pagination totals match what is shown.
    valid = signals_df.dropna(subset=['scan_hour'])
    ordered = valid.sort_values(['scan_hour', 'signal_date'], ascending=[False, False])
    return ordered.to_dict('records')


def _paginate_by_scan_hour(records, page, per_page):
    """Slice page *page* (1-based) out of *records* and regroup it by 'scan_hour'.

    Returns:
        dict: scan_hour -> list of records, keyed in the order the hours first
        appear on the page (newest first when *records* is sorted newest first).
    """
    start = (page - 1) * per_page
    grouped = {}
    for record in records[start:start + per_page]:
        # The scan_hour column is already the UTC+8 hour bucket; no need to
        # re-parse and re-convert scan_time for every record.
        grouped.setdefault(record['scan_hour'], []).append(record)
    return grouped


@app.route('/signals')
def signals():
    """Signals page: paginated trading signals grouped by scan hour (UTC+8).

    Query parameters (a non-integer value falls back to the default):
        strategy:  strategy-name filter; empty string means all strategies.
        timeframe: timeframe filter; empty string means all timeframes.
        days:      look-back window in days (default 30, minimum 0).
        page:      1-based page number (default 1, minimum 1).
        per_page:  signals per page (default 20, minimum 1).

    Renders ``signals.html``; on any error, logs the traceback and renders
    ``error.html``.
    """
    try:
        strategy_name = request.args.get('strategy', '')
        timeframe = request.args.get('timeframe', '')
        # Clamp: page <= 0 produced negative slice indices (rows from the end
        # of the list) and per_page <= 0 raised ZeroDivisionError below.
        days = max(0, request.args.get('days', 30, type=int))
        page = max(1, request.args.get('page', 1, type=int))
        per_page = max(1, request.args.get('per_page', 20, type=int))

        # Date window ends on the server's local "today".
        end_date = date.today()
        start_date = end_date - timedelta(days=days)

        signals_df = db_manager.get_signals_by_date_range(
            start_date=start_date,
            end_date=end_date,
            strategy_name=strategy_name or None,
            timeframe=timeframe or None,
        )

        records = _group_signals_by_scan_hour(signals_df)
        total_records = len(records)
        total_pages = (total_records + per_page - 1) // per_page  # ceiling division

        return render_template('signals.html',
                               signals_grouped=_paginate_by_scan_hour(records, page, per_page),
                               current_page=page,
                               total_pages=total_pages,
                               has_prev=page > 1,
                               has_next=page < total_pages,
                               total_records=total_records,
                               strategy_name=strategy_name,
                               timeframe=timeframe,
                               days=days,
                               per_page=per_page)
    except Exception as e:
        # logger.exception also records the traceback.
        logger.exception(f"信号页面数据加载失败: {e}")
        return render_template('error.html', error=str(e))
|
||
|
||
|
||
@app.route('/pullbacks')
def pullbacks():
    """Pullback-monitoring page.

    Query parameters:
        days: look-back window passed to ``db_manager.get_pullback_alerts``
            (integer, default 30; a non-integer value falls back to the default
            instead of producing an error page).

    Renders ``pullbacks.html``; on any error, logs the traceback and renders
    ``error.html``.
    """
    try:
        days = request.args.get('days', 30, type=int)
        pullback_alerts = db_manager.get_pullback_alerts(days=days)
        alerts = pullback_alerts.to_dict('records') if not pullback_alerts.empty else []

        return render_template('pullbacks.html',
                               pullback_alerts=alerts,
                               days=days)
    except Exception as e:
        # logger.exception also records the traceback.
        logger.exception(f"回踩监控页面数据加载失败: {e}")
        return render_template('error.html', error=str(e))
|
||
|
||
|
||
@app.route('/api/signals')
def api_signals():
    """JSON API: latest trading signals.

    Query parameters:
        strategy: optional strategy-name filter (empty string means all).
        limit: row limit passed to ``db_manager.get_latest_signals``
            (integer, default 100; a non-integer value falls back to the default).

    Returns:
        HTTP 200 with ``{'success': True, 'data': [...], 'total': n}`` on success;
        HTTP 500 with ``{'success': False, 'error': msg}`` on failure.
    """
    try:
        strategy_name = request.args.get('strategy', '')
        limit = request.args.get('limit', 100, type=int)

        signals_df = db_manager.get_latest_signals(
            strategy_name=strategy_name or None,
            limit=limit,
        )

        return jsonify({
            'success': True,
            'data': signals_df.to_dict('records') if not signals_df.empty else [],
            'total': len(signals_df),
        })
    except Exception as e:
        # logger.exception also records the traceback.
        logger.exception(f"API获取信号失败: {e}")
        # A failure must not be reported with HTTP 200.
        return jsonify({'success': False, 'error': str(e)}), 500
|
||
|
||
|
||
@app.route('/api/stats')
def api_stats():
    """JSON API: strategy statistics from ``db_manager.get_strategy_stats()``.

    Returns:
        HTTP 200 with ``{'success': True, 'data': [...]}`` on success;
        HTTP 500 with ``{'success': False, 'error': msg}`` on failure.
    """
    try:
        strategy_stats = db_manager.get_strategy_stats()

        return jsonify({
            'success': True,
            'data': strategy_stats.to_dict('records') if not strategy_stats.empty else [],
        })
    except Exception as e:
        # logger.exception also records the traceback.
        logger.exception(f"API获取统计失败: {e}")
        # A failure must not be reported with HTTP 200.
        return jsonify({'success': False, 'error': str(e)}), 500
|
||
|
||
|
||
@app.route('/api/pullbacks')
def api_pullbacks():
    """JSON API: pullback alerts.

    Query parameters:
        days: look-back window passed to ``db_manager.get_pullback_alerts``
            (integer, default 7; a non-integer value falls back to the default).

    Returns:
        HTTP 200 with ``{'success': True, 'data': [...]}`` on success;
        HTTP 500 with ``{'success': False, 'error': msg}`` on failure.
    """
    try:
        days = request.args.get('days', 7, type=int)
        pullback_alerts = db_manager.get_pullback_alerts(days=days)

        return jsonify({
            'success': True,
            'data': pullback_alerts.to_dict('records') if not pullback_alerts.empty else [],
        })
    except Exception as e:
        # logger.exception also records the traceback.
        logger.exception(f"API获取回踩提醒失败: {e}")
        # A failure must not be reported with HTTP 200.
        return jsonify({'success': False, 'error': str(e)}), 500
|
||
|
||
|
||
@app.template_filter('datetime_format')
|
||
def datetime_format(value, format='%Y-%m-%d %H:%M'):
    """Template filter: render a datetime (or ISO-8601 string) in UTC+8.

    Naive datetimes, and ISO strings without a UTC offset, are assumed to be
    UTC. Strings that carry an explicit offset keep it (previously a negative
    offset such as ``-05:00`` was overwritten with UTC).

    Args:
        value: a ``datetime``/``date`` object, an ISO-8601 string, or ``None``.
        format: ``strftime`` pattern used for the output.

    Returns:
        The formatted string; ``''`` for ``None``. Strings that cannot be
        parsed, and objects without ``strftime``, are returned unchanged.
    """
    if value is None:
        return ''

    if isinstance(value, str):
        text = value.strip()
        # datetime.fromisoformat() only accepts a trailing 'Z' from Python 3.11.
        if text.endswith('Z'):
            text = text[:-1] + '+00:00'
        try:
            value = datetime.fromisoformat(text)
        except ValueError:
            return value

    if not hasattr(value, 'strftime'):
        # Not a date-like object: leave it for the template to render as-is.
        return value

    if isinstance(value, datetime):
        if value.tzinfo is None:
            value = value.replace(tzinfo=timezone.utc)
        value = value.astimezone(timezone(timedelta(hours=8)))

    return value.strftime(format)
|
||
|
||
|
||
@app.template_filter('percentage')
|
||
def percentage_format(value, precision=2):
    """Template filter: format *value* as a percentage with *precision* decimals.

    The number is used as-is (``12.3`` -> ``'12.30%'``); it is not multiplied
    by 100. Missing values (``None`` or NaN, as produced by pandas records)
    render as zero at the requested precision.
    """
    number = 0.0 if value is None else float(value)
    if math.isnan(number):
        number = 0.0
    return f"{number:.{precision}f}%"
|
||
|
||
|
||
@app.template_filter('currency')
|
||
def currency_format(value, precision=2):
    """Template filter: format *value* as a fixed-point amount with *precision* decimals.

    Missing values (``None`` or NaN, as produced by pandas records) render as
    zero at the requested precision.
    """
    number = 0.0 if value is None else float(value)
    if math.isnan(number):
        number = 0.0
    return f"{number:.{precision}f}"
|
||
|
||
|
||
if __name__ == '__main__':
    # Replace loguru's default handler with a compact stdout format.
    logger.remove()
    logger.add(sys.stdout, level="INFO", format="{time:HH:mm:ss} | {level} | {message}")

    # Single source of truth for the bind address: the banner previously
    # advertised port 8080 while the server actually listened on 8081.
    # Each value can be overridden via the environment.
    host = os.environ.get('WEB_HOST', '0.0.0.0')
    port = int(os.environ.get('WEB_PORT', '8081'))
    # NOTE(review): debug mode on 0.0.0.0 exposes the Werkzeug interactive
    # debugger to the network; set WEB_DEBUG=0 outside local development.
    debug = os.environ.get('WEB_DEBUG', '1') != '0'

    print("=" * 60)
    print("🌐 A股量化交易系统 Web 界面")
    print("=" * 60)
    print("🚀 启动 Flask 服务器...")
    print(f"📊 访问地址: http://localhost:{port}")
    print("=" * 60)

    app.run(host=host, port=port, debug=debug)