trading.ai/web/mysql_app.py

#!/usr/bin/env python3
"""
AI 智能选股大师 MySQL版本 Web 展示界面
使用Flask框架展示策略筛选结果支持MySQL数据库
"""
import sys
import os
from pathlib import Path
from flask import Flask, render_template, jsonify, request
from datetime import datetime, date, timedelta
import pandas as pd

# Add the project root directory to sys.path so the src.* packages can be imported
current_dir = Path(__file__).parent
project_root = current_dir.parent
sys.path.insert(0, str(project_root))

from src.database.mysql_database_manager import MySQLDatabaseManager
from src.utils.config_loader import ConfigLoader
from loguru import logger

app = Flask(__name__)
app.secret_key = 'trading_ai_mysql_secret_key_2023'

# Operation verification key (defined immediately after the imports)
OPERATION_KEY = os.environ.get('OPERATION_KEY', '9257')

# Initialize components
db_manager = MySQLDatabaseManager()
config_loader = ConfigLoader()

# Global analysis-status tracking: module-level mutable state shared between the Flask
# request handlers and the background analysis thread; it resets whenever the process restarts.
analysis_status = {
    'is_running': False,
    'start_time': None,
    'stock_count': 0,
    'progress': 0,
    'current_stock': '',
    'estimated_completion': None
}


@app.route('/')
def index():
    """Home page - redirect to the signals page."""
    from flask import redirect, url_for
    return redirect(url_for('signals'))


@app.route('/signals')
def signals():
    """Signals page - detailed list of signals."""
    try:
        # Read the query parameters
        strategy_name = request.args.get('strategy', '')
        timeframe = request.args.get('timeframe', '')
        asset_type = request.args.get('asset_type', '')  # new: asset-type filter
        days = int(request.args.get('days', 7))  # show signals from the last 7 days by default
        page = int(request.args.get('page', 1))
        per_page = int(request.args.get('per_page', 20))

        # Compute the date range
        end_date = date.today()
        start_date = end_date - timedelta(days=days)

        # Fetch the signal data
        signals_df = db_manager.get_signals_by_date_range(
            start_date=start_date,
            end_date=end_date,
            strategy_name=strategy_name if strategy_name else None,
            asset_type=asset_type if asset_type else None,  # pass the asset-type filter through
            timeframe=timeframe if timeframe else None
        )

        # Group by scan time (truncated to the hour); sort each group by signal date
        signals_grouped = {}
        if not signals_df.empty:
            # Make sure scan_time is a datetime
            signals_df['scan_time'] = pd.to_datetime(signals_df['scan_time'])
            # The Docker container already has the correct timezone set, so use the database time directly
            # Create an hour-level grouping key
            signals_df['scan_hour'] = signals_df['scan_time'].dt.floor('h')
            # Group by scan hour
            for scan_hour, group in signals_df.groupby('scan_hour'):
                # Sort each group by signal date (descending)
                group_sorted = group.sort_values('signal_date', ascending=False)
                signals_grouped[scan_hour] = group_sorted
            # Sort the groups by scan hour (newest first)
            signals_grouped = dict(sorted(signals_grouped.items(), key=lambda x: x[0], reverse=True))

        # Pagination
        total_records = len(signals_df)
        start_idx = (page - 1) * per_page
        end_idx = start_idx + per_page

        # Flatten the grouped data for pagination
        flattened_signals = []
        for scan_hour, group in signals_grouped.items():
            flattened_signals.extend(group.to_dict('records'))
        paginated_signals = flattened_signals[start_idx:end_idx]

        # Regroup the paginated rows by scan hour
        paginated_grouped = {}
        for signal in paginated_signals:
            # Timezone is already correct inside the Docker container, use as-is
            scan_time = pd.to_datetime(signal['scan_time'])
            scan_hour = scan_time.floor('h')
            if scan_hour not in paginated_grouped:
                paginated_grouped[scan_hour] = []
            paginated_grouped[scan_hour].append(signal)

        # Compute pagination metadata
        total_pages = (total_records + per_page - 1) // per_page
        has_prev = page > 1
        has_next = page < total_pages

        return render_template('signals.html',
                               signals_grouped=paginated_grouped,
                               current_page=page,
                               total_pages=total_pages,
                               has_prev=has_prev,
                               has_next=has_next,
                               total_records=total_records,
                               strategy_name=strategy_name,
                               asset_type=asset_type,  # pass the asset type to the template
                               timeframe=timeframe,
                               days=days,
                               per_page=per_page)
    except Exception as e:
        logger.error(f"Failed to load data for the signals page: {e}")
        return render_template('error.html', error=str(e))
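
# Illustrative request for the route above (the strategy and timeframe values here are made-up
# examples, not names defined by this project):
#   GET /signals?strategy=macd_cross&asset_type=stock&timeframe=daily&days=30&page=2&per_page=50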


@app.route('/api/signals')
def api_signals():
    """API endpoint - fetch signal data."""
    try:
        strategy_name = request.args.get('strategy', '')
        limit = int(request.args.get('limit', 100))
        signals_df = db_manager.get_latest_signals(
            strategy_name=strategy_name if strategy_name else None,
            limit=limit
        )
        return jsonify({
            'success': True,
            'data': signals_df.to_dict('records') if not signals_df.empty else [],
            'total': len(signals_df)
        })
    except Exception as e:
        logger.error(f"API failed to fetch signals: {e}")
        return jsonify({'success': False, 'error': str(e)})
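
# Illustrative call (the host is an assumption; port 8081 matches the app.run() call at the
# bottom of this file):
#   curl 'http://localhost:8081/api/signals?strategy=&limit=50'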


@app.route('/api/stats')
def api_stats():
    """API endpoint - fetch strategy statistics."""
    try:
        strategy_stats = db_manager.get_strategy_stats()
        return jsonify({
            'success': True,
            'data': strategy_stats.to_dict('records') if not strategy_stats.empty else []
        })
    except Exception as e:
        logger.error(f"API failed to fetch statistics: {e}")
        return jsonify({'success': False, 'error': str(e)})
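
# Illustrative call (host and port are assumptions):
#   curl 'http://localhost:8081/api/stats'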


@app.route('/api/clear_signals', methods=['POST'])
def api_clear_signals():
    """API endpoint - clear signal data."""
    try:
        # Read the operation key
        operation_key = request.json.get('operation_key', '') if request.is_json else request.form.get('operation_key', '')
        # Verify the key
        if operation_key != OPERATION_KEY:
            logger.warning(f"Key verification failed for clear-signals operation: {operation_key}")
            return jsonify({
                'success': False,
                'error': 'Operation key verification failed; please enter the correct key'
            })
        # Read the clearing-scope parameters
        days = request.json.get('days', 7) if request.is_json else int(request.form.get('days', 7))
        strategy_name = request.json.get('strategy_name', '') if request.is_json else request.form.get('strategy_name', '')
        # Delegate to the database manager's clear method
        deleted_count = db_manager.clear_signals(
            days=days,
            strategy_name=strategy_name if strategy_name else None
        )
        logger.info(f"Clearing signals finished: deleted {deleted_count} signal records "
                    f"(range: {days} days, strategy: {strategy_name or 'all'})")
        return jsonify({
            'success': True,
            'message': f'Successfully cleared {deleted_count} signal records',
            'deleted_count': deleted_count
        })
    except Exception as e:
        logger.error(f"Failed to clear signals: {e}")
        return jsonify({'success': False, 'error': str(e)})
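
# Illustrative call (host and port are assumptions; the key shown is the default OPERATION_KEY,
# which should be overridden via the environment in any real deployment):
#   curl -X POST http://localhost:8081/api/clear_signals \
#        -H 'Content-Type: application/json' \
#        -d '{"operation_key": "9257", "days": 7, "strategy_name": ""}'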


@app.route('/api/run_analysis', methods=['POST'])
def api_run_analysis():
    """API endpoint - run a market analysis immediately."""
    global analysis_status
    try:
        import subprocess
        import threading
        from datetime import datetime

        # Read the operation key
        operation_key = request.json.get('operation_key', '') if request.is_json else request.form.get('operation_key', '')
        # Verify the key
        if operation_key != OPERATION_KEY:
            logger.warning(f"Key verification failed for run-analysis operation: {operation_key}")
            return jsonify({
                'success': False,
                'error': 'Operation key verification failed; please enter the correct key'
            })
        # Refuse to start if an analysis is already running
        if analysis_status['is_running']:
            return jsonify({
                'success': False,
                'error': 'An analysis task is already running; please wait for it to finish and try again'
            })
        # Read the analysis parameters
        stock_count = request.json.get('stock_count', 200) if request.is_json else int(request.form.get('stock_count', 200))

        def run_analysis_background():
            """Run the analysis task in the background."""
            global analysis_status
            try:
                # Update the analysis status
                analysis_status.update({
                    'is_running': True,
                    'start_time': datetime.now(),
                    'stock_count': stock_count,
                    'progress': 0,
                    'current_stock': 'Preparing...',
                    'estimated_completion': None
                })
                logger.info(f"Starting background market analysis: scanning {stock_count} stocks")
                # Run the market analysis command
                result = subprocess.run(
                    [sys.executable, 'market_scanner.py', str(stock_count)],
                    cwd=str(project_root),
                    capture_output=True,
                    text=True,
                    timeout=1800  # 30-minute timeout
                )
                if result.returncode == 0:
                    logger.info(f"Background market analysis finished: scanned {stock_count} stocks")
                    analysis_status['progress'] = 100
                    analysis_status['current_stock'] = 'Analysis complete'
                else:
                    logger.error(f"Background market analysis failed: {result.stderr}")
            except subprocess.TimeoutExpired:
                logger.error("Background market analysis timed out (30 minutes)")
            except Exception as e:
                logger.error(f"Background market analysis raised an exception: {e}")
            finally:
                # Reset the analysis status
                analysis_status.update({
                    'is_running': False,
                    'progress': 0,
                    'current_stock': '',
                    'start_time': None,
                    'stock_count': 0,
                    'estimated_completion': None
                })

        # Start a background thread to run the analysis
        analysis_thread = threading.Thread(target=run_analysis_background)
        analysis_thread.daemon = True
        analysis_thread.start()
        logger.info(f"Market analysis task started: scanning {stock_count} stocks (running in the background)")
        return jsonify({
            'success': True,
            'message': f'Market analysis task started; scanning {stock_count} stocks in the background',
            'stock_count': stock_count,
            'start_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        })
    except Exception as e:
        logger.error(f"Failed to start market analysis: {e}")
        analysis_status['is_running'] = False
        return jsonify({'success': False, 'error': str(e)})
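
# Illustrative call (host and port are assumptions; stock_count falls back to 200 when omitted):
#   curl -X POST http://localhost:8081/api/run_analysis \
#        -H 'Content-Type: application/json' \
#        -d '{"operation_key": "9257", "stock_count": 200}'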


@app.route('/api/analysis_status')
def api_analysis_status():
    """API endpoint - get the current market analysis status."""
    global analysis_status
    try:
        # Compute the elapsed running time
        running_time = None
        if analysis_status['is_running'] and analysis_status['start_time']:
            running_time = (datetime.now() - analysis_status['start_time']).total_seconds()

        # Estimate progress (a rough, purely time-based estimate)
        estimated_progress = 0
        if analysis_status['is_running'] and running_time:
            # Assume analysing 200 stocks takes roughly 5 minutes and extrapolate from elapsed time
            stock_count = analysis_status.get('stock_count', 200)
            estimated_time = (stock_count / 200) * 300  # about 300 seconds per 200 stocks
            estimated_progress = min(95, int((running_time / estimated_time) * 100))
            analysis_status['progress'] = estimated_progress

        return jsonify({
            'success': True,
            'data': {
                'is_running': analysis_status['is_running'],
                'start_time': analysis_status['start_time'].strftime('%Y-%m-%d %H:%M:%S') if analysis_status['start_time'] else None,
                'stock_count': analysis_status['stock_count'],
                'progress': analysis_status['progress'],
                'current_stock': analysis_status['current_stock'],
                'running_time': int(running_time) if running_time else 0,
                'estimated_completion': analysis_status['estimated_completion']
            }
        })
    except Exception as e:
        logger.error(f"Failed to get the analysis status: {e}")
        return jsonify({'success': False, 'error': str(e)})
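
# A minimal client-side polling sketch (assumes the third-party `requests` package and that the
# server is reachable at http://localhost:8081; both are assumptions, not provided by this app):
#
#   import requests, time
#
#   def wait_for_analysis(base_url='http://localhost:8081', interval=10):
#       """Poll /api/analysis_status until the background task reports it has stopped."""
#       while True:
#           status = requests.get(f'{base_url}/api/analysis_status').json()
#           if not status.get('success') or not status['data']['is_running']:
#               return status
#           print(f"progress: {status['data']['progress']}% ({status['data']['current_stock']})")
#           time.sleep(interval)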


@app.template_filter('datetime_format')
def datetime_format(value, format='%Y-%m-%d %H:%M'):
    """Datetime formatting filter - the Docker container's timezone is already set."""
    if value is None:
        return ''
    if isinstance(value, str):
        try:
            value = datetime.fromisoformat(value.replace('Z', '+00:00') if 'Z' in value else value)
        except ValueError:
            return value
    # Timezone is already correct inside the Docker container, so format directly
    if isinstance(value, datetime):
        return value.strftime(format)
    return str(value)


@app.template_filter('percentage')
def percentage_format(value, precision=2):
    """Percentage formatting filter."""
    if value is None:
        return '0.00%'
    return f"{float(value):.{precision}f}%"


@app.template_filter('currency')
def currency_format(value, precision=2):
    """Currency formatting filter."""
    if value is None:
        return '0.00'
    return f"{float(value):.{precision}f}"


if __name__ == '__main__':
    # Configure logging
    logger.remove()
    logger.add(sys.stdout, level="INFO", format="{time:HH:mm:ss} | {level} | {message}")

    print("=" * 60)
    print("🌐 A-share Quant Trading System Web UI (MySQL edition)")
    print("=" * 60)
    print("🚀 Starting the Flask server...")
    print("📊 URL: http://localhost:8081")
    print("🗄️ Database engine: MySQL")
    print(f"📡 Host: {db_manager.config.host}")
    print(f"📋 Database: {db_manager.config.database}")
    print("=" * 60)
    app.run(host='0.0.0.0', port=8081, debug=True)