trading.ai/web/app.py
2025-09-19 20:20:18 +08:00

267 lines
8.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
AI 智能选股大师 Web 展示界面
使用Flask框架展示策略筛选结果
"""
import os
import sys
from datetime import datetime, date, timedelta, timezone
from pathlib import Path

import pandas as pd
from flask import Flask, render_template, jsonify, request
from loguru import logger

# 添加项目根目录到路径
current_dir = Path(__file__).parent
project_root = current_dir.parent
sys.path.insert(0, str(project_root))

from src.database.database_manager import DatabaseManager
from src.utils.config_loader import ConfigLoader
app = Flask(__name__)
# Key used to sign session cookies. Read it from the SECRET_KEY environment
# variable so a deployment does not share the key committed to source; the
# literal is only a development fallback.
app.secret_key = os.environ.get('SECRET_KEY', 'trading_ai_secret_key_2023')

# Shared components, constructed once at import time. db_manager backs the
# route handlers below; config_loader is not referenced elsewhere in this file.
db_manager = DatabaseManager()
config_loader = ConfigLoader()
@app.route('/')
def index():
    """Home page: immediately redirect the visitor to the signals page."""
    from flask import redirect, url_for

    signals_url = url_for('signals')
    return redirect(signals_url)
def _order_signals_by_scan_hour(signals_df):
    """Flatten signals into records ordered newest scan hour first, then newest signal_date.

    Each returned record gains a tz-aware ``scan_time``, plus ``scan_time_china``
    (scan time converted to UTC+8) and ``scan_hour`` (``scan_time_china`` floored
    to the hour). Naive ``scan_time`` values are assumed to be UTC. Rows whose
    ``scan_time`` is missing (NaT) are dropped, because ``groupby`` skips NaN keys.
    """
    if signals_df.empty:
        return []

    china_tz = timezone(timedelta(hours=8))
    df = signals_df.copy()
    df['scan_time'] = pd.to_datetime(df['scan_time'])
    if df['scan_time'].dt.tz is None:
        # Naive timestamps are treated as UTC before converting to UTC+8.
        df['scan_time'] = df['scan_time'].dt.tz_localize('UTC')
    df['scan_time_china'] = df['scan_time'].dt.tz_convert(china_tz)
    df['scan_hour'] = df['scan_time_china'].dt.floor('h')

    records = []
    hour_groups = sorted(df.groupby('scan_hour'), key=lambda item: item[0], reverse=True)
    for _, group in hour_groups:
        records.extend(group.sort_values('signal_date', ascending=False).to_dict('records'))
    return records


def _group_records_by_scan_hour(records):
    """Group already-ordered records into ``{scan_hour: [record, ...]}``, preserving order."""
    grouped = {}
    for record in records:
        grouped.setdefault(record['scan_hour'], []).append(record)
    return grouped


@app.route('/signals')
def signals():
    """Signals page: trading signals grouped by scan hour (UTC+8), paginated.

    Query parameters:
        strategy:  strategy-name filter; empty means all strategies.
        timeframe: timeframe filter; empty means all timeframes.
        days:      look-back window in days ending today (default 30, min 0).
        page:      1-based page number (default 1, min 1).
        per_page:  records per page (default 20, min 1).

    Renders ``signals.html`` on success, or ``error.html`` if anything fails.
    """
    try:
        strategy_name = request.args.get('strategy', '')
        timeframe = request.args.get('timeframe', '')
        # type=int falls back to the default on non-numeric input; clamping keeps
        # page/per_page from producing negative slice offsets or a zero divisor.
        days = max(request.args.get('days', 30, type=int), 0)
        page = max(request.args.get('page', 1, type=int), 1)
        per_page = max(request.args.get('per_page', 20, type=int), 1)

        end_date = date.today()
        start_date = end_date - timedelta(days=days)
        signals_df = db_manager.get_signals_by_date_range(
            start_date=start_date,
            end_date=end_date,
            strategy_name=strategy_name or None,
            timeframe=timeframe or None,
        )

        ordered_signals = _order_signals_by_scan_hour(signals_df)
        # Count the records actually paginated so page numbers match what is shown.
        total_records = len(ordered_signals)
        start_idx = (page - 1) * per_page
        page_signals = ordered_signals[start_idx:start_idx + per_page]
        total_pages = (total_records + per_page - 1) // per_page

        return render_template('signals.html',
                               signals_grouped=_group_records_by_scan_hour(page_signals),
                               current_page=page,
                               total_pages=total_pages,
                               has_prev=page > 1,
                               has_next=page < total_pages,
                               total_records=total_records,
                               strategy_name=strategy_name,
                               timeframe=timeframe,
                               days=days,
                               per_page=per_page)
    except Exception as e:
        # logger.exception also records the traceback, not just the message.
        logger.exception(f"信号页面数据加载失败: {e}")
        return render_template('error.html', error=str(e))
@app.route('/pullbacks')
def pullbacks():
    """Pullback-monitoring page.

    Reads the look-back window from the ``days`` query parameter (default 30),
    loads alerts via ``db_manager.get_pullback_alerts`` and renders
    ``pullbacks.html``; any failure renders ``error.html`` instead.
    """
    try:
        window_days = int(request.args.get('days', 30))
        alerts_df = db_manager.get_pullback_alerts(days=window_days)
        if alerts_df.empty:
            alert_rows = []
        else:
            alert_rows = alerts_df.to_dict('records')
        return render_template('pullbacks.html', pullback_alerts=alert_rows, days=window_days)
    except Exception as exc:
        logger.error(f"回踩监控页面数据加载失败: {exc}")
        return render_template('error.html', error=str(exc))
@app.route('/api/signals')
def api_signals():
    """JSON API: the latest signals, optionally filtered by strategy.

    Query parameters: ``strategy`` (empty means all strategies) and ``limit``
    (default 100). Responds with ``{'success': True, 'data': [...], 'total': n}``,
    or ``{'success': False, 'error': <message>}`` if anything fails.
    """
    try:
        strategy_filter = request.args.get('strategy', '') or None
        max_rows = int(request.args.get('limit', 100))
        latest_df = db_manager.get_latest_signals(strategy_name=strategy_filter, limit=max_rows)
        rows = [] if latest_df.empty else latest_df.to_dict('records')
        payload = {'success': True, 'data': rows, 'total': len(latest_df)}
        return jsonify(payload)
    except Exception as err:
        logger.error(f"API获取信号失败: {err}")
        return jsonify({'success': False, 'error': str(err)})
@app.route('/api/stats')
def api_stats():
    """JSON API: per-strategy statistics.

    Responds with ``{'success': True, 'data': [...]}`` built from
    ``db_manager.get_strategy_stats()``, or ``{'success': False, 'error': <message>}``
    if anything fails.
    """
    try:
        stats_df = db_manager.get_strategy_stats()
        if stats_df.empty:
            stats_rows = []
        else:
            stats_rows = stats_df.to_dict('records')
        return jsonify({'success': True, 'data': stats_rows})
    except Exception as exc:
        logger.error(f"API获取统计失败: {exc}")
        return jsonify({'success': False, 'error': str(exc)})
@app.route('/api/pullbacks')
def api_pullbacks():
    """JSON API: pullback alerts over the last ``days`` days (default 7).

    Responds with ``{'success': True, 'data': [...]}``, or
    ``{'success': False, 'error': <message>}`` if anything fails.
    """
    try:
        lookback = int(request.args.get('days', 7))
        alerts_df = db_manager.get_pullback_alerts(days=lookback)
        alert_rows = [] if alerts_df.empty else alerts_df.to_dict('records')
        response = {'success': True, 'data': alert_rows}
        return jsonify(response)
    except Exception as exc:
        logger.error(f"API获取回踩提醒失败: {exc}")
        return jsonify({'success': False, 'error': str(exc)})
@app.template_filter('datetime_format')
def datetime_format(value, format='%Y-%m-%d %H:%M'):
    """Jinja filter: render a timestamp in China Standard Time (UTC+8).

    Args:
        value: a ``datetime``/``date``/``pd.Timestamp`` or an ISO-8601 string.
            Naive datetimes and strings without a UTC offset are assumed to be
            UTC; strings carrying an explicit offset (``Z``, ``+hh:mm``,
            ``-hh:mm``) keep that offset.
        format: ``strftime`` pattern for the output.

    Returns:
        The formatted string; ``''`` for ``None`` or ``NaT``. Strings that are
        not valid ISO-8601, and values without ``strftime``, are returned unchanged.
    """
    if value is None or value is pd.NaT:
        return ''

    if isinstance(value, str):
        try:
            # datetime.fromisoformat() before Python 3.11 rejects the 'Z' suffix.
            value = datetime.fromisoformat(value.replace('Z', '+00:00'))
        except ValueError:
            return value

    if isinstance(value, datetime):
        if value.tzinfo is None:
            value = value.replace(tzinfo=timezone.utc)
        value = value.astimezone(timezone(timedelta(hours=8)))
    elif not hasattr(value, 'strftime'):
        # Not a date-like value; let the template show it as-is.
        return value

    return value.strftime(format)
@app.template_filter('percentage')
def percentage_format(value, precision=2):
    """Jinja filter: format a number as a percentage with ``precision`` decimals.

    The value is printed as-is (no scaling by 100), e.g. ``12.345 -> '12.35%'``
    at the default precision. Missing values (``None`` or NaN, the pandas
    representation of a missing float) are shown as zero at the requested precision.
    """
    if value is None or pd.isna(value):
        value = 0
    return f"{float(value):.{precision}f}%"
@app.template_filter('currency')
def currency_format(value, precision=2):
    """Jinja filter: format a number with ``precision`` decimals (no currency symbol).

    Missing values (``None`` or NaN, the pandas representation of a missing
    float) are shown as zero at the requested precision.
    """
    if value is None or pd.isna(value):
        value = 0
    return f"{float(value):.{precision}f}"
if __name__ == '__main__':
    # Send loguru output to stdout at INFO level as "HH:MM:SS | LEVEL | message".
    logger.remove()
    logger.add(sys.stdout, level="INFO", format="{time:HH:mm:ss} | {level} | {message}")

    # Host and port can be overridden from the environment; defaults match the
    # previous hard-coded values.
    host = os.environ.get('WEB_HOST', '0.0.0.0')
    port = int(os.environ.get('WEB_PORT', '8080'))
    # The Werkzeug debugger lets a browser execute Python on the server, so it
    # stays off by default while listening on all interfaces. Set FLASK_DEBUG=1
    # for local development (also enables the auto-reloader).
    debug = os.environ.get('FLASK_DEBUG', '0').lower() in ('1', 'true', 'yes')

    print("=" * 60)
    print("🌐 A股量化交易系统 Web 界面")
    print("=" * 60)
    print("🚀 启动 Flask 服务器...")
    print(f"📊 访问地址: http://localhost:{port}")
    print("=" * 60)
    app.run(host=host, port=port, debug=debug)