376 lines
13 KiB
Python
376 lines
13 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
AI 智能选股大师 MySQL版本 Web 展示界面
|
||
使用Flask框架展示策略筛选结果,支持MySQL数据库
|
||
"""
|
||
|
||
import sys
|
||
from pathlib import Path
|
||
from flask import Flask, render_template, jsonify, request
|
||
from datetime import datetime, date, timedelta
|
||
import pandas as pd
|
||
|
||
# 添加项目根目录到路径
|
||
current_dir = Path(__file__).parent
|
||
project_root = current_dir.parent
|
||
sys.path.insert(0, str(project_root))
|
||
|
||
from src.database.mysql_database_manager import MySQLDatabaseManager
|
||
from src.utils.config_loader import ConfigLoader
|
||
from loguru import logger
|
||
|
||
app = Flask(__name__)
|
||
app.secret_key = 'trading_ai_mysql_secret_key_2023'
|
||
|
||
# 初始化组件
|
||
db_manager = MySQLDatabaseManager()
|
||
config_loader = ConfigLoader()
|
||
|
||
# 全局分析状态跟踪
|
||
analysis_status = {
|
||
'is_running': False,
|
||
'start_time': None,
|
||
'stock_count': 0,
|
||
'progress': 0,
|
||
'current_stock': '',
|
||
'estimated_completion': None
|
||
}
|
||
|
||
|
||
@app.route('/')
|
||
def index():
|
||
"""首页 - 重定向到信号页面"""
|
||
from flask import redirect, url_for
|
||
return redirect(url_for('signals'))
|
||
|
||
|
||
@app.route('/signals')
|
||
def signals():
|
||
"""信号页面 - 详细的信号列表"""
|
||
try:
|
||
# 获取查询参数
|
||
strategy_name = request.args.get('strategy', '')
|
||
timeframe = request.args.get('timeframe', '')
|
||
days = int(request.args.get('days', 7)) # 默认显示7天内的信号
|
||
page = int(request.args.get('page', 1))
|
||
per_page = int(request.args.get('per_page', 20))
|
||
|
||
# 计算日期范围
|
||
end_date = date.today()
|
||
start_date = end_date - timedelta(days=days)
|
||
|
||
# 获取信号数据
|
||
signals_df = db_manager.get_signals_by_date_range(
|
||
start_date=start_date,
|
||
end_date=end_date,
|
||
strategy_name=strategy_name if strategy_name else None,
|
||
timeframe=timeframe if timeframe else None
|
||
)
|
||
|
||
# 按扫描时间(精确到小时)分组,每组内按信号日期排序
|
||
signals_grouped = {}
|
||
if not signals_df.empty:
|
||
# 确保scan_time是datetime类型
|
||
signals_df['scan_time'] = pd.to_datetime(signals_df['scan_time'])
|
||
|
||
# Docker容器中已设置正确时区,直接使用数据库时间
|
||
# 创建小时级别的分组键
|
||
signals_df['scan_hour'] = signals_df['scan_time'].dt.floor('h')
|
||
|
||
# 按扫描小时分组
|
||
for scan_hour, group in signals_df.groupby('scan_hour'):
|
||
# 每组内按信号日期排序(降序)
|
||
group_sorted = group.sort_values('signal_date', ascending=False)
|
||
signals_grouped[scan_hour] = group_sorted
|
||
|
||
# 按扫描小时排序(最新的在前)
|
||
signals_grouped = dict(sorted(signals_grouped.items(), key=lambda x: x[0], reverse=True))
|
||
|
||
# 分页处理
|
||
total_records = len(signals_df)
|
||
start_idx = (page - 1) * per_page
|
||
end_idx = start_idx + per_page
|
||
|
||
# 将分组数据展平用于分页
|
||
flattened_signals = []
|
||
for scan_hour, group in signals_grouped.items():
|
||
flattened_signals.extend(group.to_dict('records'))
|
||
|
||
paginated_signals = flattened_signals[start_idx:end_idx]
|
||
|
||
# 重新按扫描小时分组分页后的数据
|
||
paginated_grouped = {}
|
||
for signal in paginated_signals:
|
||
# Docker容器中已设置正确时区,直接使用
|
||
scan_time = pd.to_datetime(signal['scan_time'])
|
||
scan_hour = scan_time.floor('h')
|
||
|
||
if scan_hour not in paginated_grouped:
|
||
paginated_grouped[scan_hour] = []
|
||
paginated_grouped[scan_hour].append(signal)
|
||
|
||
# 计算分页信息
|
||
total_pages = (total_records + per_page - 1) // per_page
|
||
has_prev = page > 1
|
||
has_next = page < total_pages
|
||
|
||
return render_template('signals.html',
|
||
signals_grouped=paginated_grouped,
|
||
current_page=page,
|
||
total_pages=total_pages,
|
||
has_prev=has_prev,
|
||
has_next=has_next,
|
||
total_records=total_records,
|
||
strategy_name=strategy_name,
|
||
timeframe=timeframe,
|
||
days=days,
|
||
per_page=per_page)
|
||
except Exception as e:
|
||
logger.error(f"信号页面数据加载失败: {e}")
|
||
return render_template('error.html', error=str(e))
|
||
|
||
|
||
|
||
@app.route('/api/signals')
|
||
def api_signals():
|
||
"""API接口 - 获取信号数据"""
|
||
try:
|
||
strategy_name = request.args.get('strategy', '')
|
||
limit = int(request.args.get('limit', 100))
|
||
|
||
signals_df = db_manager.get_latest_signals(
|
||
strategy_name=strategy_name if strategy_name else None,
|
||
limit=limit
|
||
)
|
||
|
||
return jsonify({
|
||
'success': True,
|
||
'data': signals_df.to_dict('records') if not signals_df.empty else [],
|
||
'total': len(signals_df)
|
||
})
|
||
except Exception as e:
|
||
logger.error(f"API获取信号失败: {e}")
|
||
return jsonify({'success': False, 'error': str(e)})
|
||
|
||
|
||
@app.route('/api/stats')
|
||
def api_stats():
|
||
"""API接口 - 获取策略统计"""
|
||
try:
|
||
strategy_stats = db_manager.get_strategy_stats()
|
||
|
||
return jsonify({
|
||
'success': True,
|
||
'data': strategy_stats.to_dict('records') if not strategy_stats.empty else []
|
||
})
|
||
except Exception as e:
|
||
logger.error(f"API获取统计失败: {e}")
|
||
return jsonify({'success': False, 'error': str(e)})
|
||
|
||
|
||
@app.route('/api/clear_signals', methods=['POST'])
|
||
def api_clear_signals():
|
||
"""API接口 - 清空信号数据"""
|
||
try:
|
||
# 获取清空范围参数
|
||
days = request.json.get('days', 7) if request.is_json else int(request.form.get('days', 7))
|
||
strategy_name = request.json.get('strategy_name', '') if request.is_json else request.form.get('strategy_name', '')
|
||
|
||
# 调用数据库管理器的清空方法
|
||
deleted_count = db_manager.clear_signals(
|
||
days=days,
|
||
strategy_name=strategy_name if strategy_name else None
|
||
)
|
||
|
||
logger.info(f"清空信号完成: 删除了 {deleted_count} 条信号记录 (范围: {days}天, 策略: {strategy_name or '全部'})")
|
||
|
||
return jsonify({
|
||
'success': True,
|
||
'message': f'成功清空 {deleted_count} 条信号记录',
|
||
'deleted_count': deleted_count
|
||
})
|
||
except Exception as e:
|
||
logger.error(f"清空信号失败: {e}")
|
||
return jsonify({'success': False, 'error': str(e)})
|
||
|
||
|
||
@app.route('/api/run_analysis', methods=['POST'])
|
||
def api_run_analysis():
|
||
"""API接口 - 立即运行市场分析"""
|
||
global analysis_status
|
||
|
||
try:
|
||
import subprocess
|
||
import threading
|
||
from datetime import datetime
|
||
|
||
# 检查是否已有分析在运行
|
||
if analysis_status['is_running']:
|
||
return jsonify({
|
||
'success': False,
|
||
'error': '已有分析任务正在运行,请等待完成后再试'
|
||
})
|
||
|
||
# 获取分析参数
|
||
stock_count = request.json.get('stock_count', 200) if request.is_json else int(request.form.get('stock_count', 200))
|
||
|
||
def run_analysis_background():
|
||
"""后台运行分析任务"""
|
||
global analysis_status
|
||
|
||
try:
|
||
# 更新分析状态
|
||
analysis_status.update({
|
||
'is_running': True,
|
||
'start_time': datetime.now(),
|
||
'stock_count': stock_count,
|
||
'progress': 0,
|
||
'current_stock': '准备中...',
|
||
'estimated_completion': None
|
||
})
|
||
|
||
logger.info(f"开始后台市场分析: 扫描 {stock_count} 只股票")
|
||
|
||
# 执行市场分析命令
|
||
result = subprocess.run([
|
||
sys.executable, 'market_scanner.py', str(stock_count)
|
||
],
|
||
cwd=str(project_root),
|
||
capture_output=True,
|
||
text=True,
|
||
timeout=1800 # 30分钟超时
|
||
)
|
||
|
||
if result.returncode == 0:
|
||
logger.info(f"后台市场分析完成: 扫描 {stock_count} 只股票")
|
||
analysis_status['progress'] = 100
|
||
analysis_status['current_stock'] = '分析完成'
|
||
else:
|
||
logger.error(f"后台市场分析失败: {result.stderr}")
|
||
|
||
except subprocess.TimeoutExpired:
|
||
logger.error("后台市场分析超时 (30分钟)")
|
||
except Exception as e:
|
||
logger.error(f"后台市场分析异常: {e}")
|
||
finally:
|
||
# 重置分析状态
|
||
analysis_status.update({
|
||
'is_running': False,
|
||
'progress': 0,
|
||
'current_stock': '',
|
||
'start_time': None,
|
||
'stock_count': 0,
|
||
'estimated_completion': None
|
||
})
|
||
|
||
# 启动后台线程执行分析
|
||
analysis_thread = threading.Thread(target=run_analysis_background)
|
||
analysis_thread.daemon = True
|
||
analysis_thread.start()
|
||
|
||
logger.info(f"市场分析任务已启动: 扫描 {stock_count} 只股票 (后台执行)")
|
||
|
||
return jsonify({
|
||
'success': True,
|
||
'message': f'市场分析任务已启动,正在后台扫描 {stock_count} 只股票',
|
||
'stock_count': stock_count,
|
||
'start_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||
})
|
||
|
||
except Exception as e:
|
||
logger.error(f"启动市场分析失败: {e}")
|
||
analysis_status['is_running'] = False
|
||
return jsonify({'success': False, 'error': str(e)})
|
||
|
||
|
||
@app.route('/api/analysis_status')
|
||
def api_analysis_status():
|
||
"""API接口 - 获取市场分析状态"""
|
||
global analysis_status
|
||
|
||
try:
|
||
# 计算运行时间
|
||
running_time = None
|
||
if analysis_status['is_running'] and analysis_status['start_time']:
|
||
running_time = (datetime.now() - analysis_status['start_time']).total_seconds()
|
||
|
||
# 估算进度(基于时间的粗略估算)
|
||
estimated_progress = 0
|
||
if analysis_status['is_running'] and running_time:
|
||
# 假设分析200只股票需要约5分钟,根据时间估算进度
|
||
stock_count = analysis_status.get('stock_count', 200)
|
||
estimated_time = (stock_count / 200) * 300 # 每200只股票约300秒
|
||
estimated_progress = min(95, int((running_time / estimated_time) * 100))
|
||
analysis_status['progress'] = estimated_progress
|
||
|
||
return jsonify({
|
||
'success': True,
|
||
'data': {
|
||
'is_running': analysis_status['is_running'],
|
||
'start_time': analysis_status['start_time'].strftime('%Y-%m-%d %H:%M:%S') if analysis_status['start_time'] else None,
|
||
'stock_count': analysis_status['stock_count'],
|
||
'progress': analysis_status['progress'],
|
||
'current_stock': analysis_status['current_stock'],
|
||
'running_time': int(running_time) if running_time else 0,
|
||
'estimated_completion': analysis_status['estimated_completion']
|
||
}
|
||
})
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取分析状态失败: {e}")
|
||
return jsonify({'success': False, 'error': str(e)})
|
||
|
||
|
||
|
||
|
||
@app.template_filter('datetime_format')
|
||
def datetime_format(value, format='%Y-%m-%d %H:%M'):
|
||
"""日期时间格式化过滤器 - Docker容器时区已设置"""
|
||
if value is None:
|
||
return ''
|
||
|
||
if isinstance(value, str):
|
||
try:
|
||
value = datetime.fromisoformat(value.replace('Z', '+00:00') if 'Z' in value else value)
|
||
except:
|
||
return value
|
||
|
||
# Docker容器中已设置正确时区,直接格式化
|
||
if isinstance(value, datetime):
|
||
return value.strftime(format)
|
||
|
||
return str(value)
|
||
|
||
|
||
@app.template_filter('percentage')
|
||
def percentage_format(value, precision=2):
|
||
"""百分比格式化过滤器"""
|
||
if value is None:
|
||
return '0.00%'
|
||
return f"{float(value):.{precision}f}%"
|
||
|
||
|
||
@app.template_filter('currency')
|
||
def currency_format(value, precision=2):
|
||
"""货币格式化过滤器"""
|
||
if value is None:
|
||
return '0.00'
|
||
return f"{float(value):.{precision}f}"
|
||
|
||
|
||
if __name__ == '__main__':
|
||
# 设置日志
|
||
logger.remove()
|
||
logger.add(sys.stdout, level="INFO", format="{time:HH:mm:ss} | {level} | {message}")
|
||
|
||
print("=" * 60)
|
||
print("🌐 A股量化交易系统 Web 界面 (MySQL版)")
|
||
print("=" * 60)
|
||
print("🚀 启动 Flask 服务器...")
|
||
print("📊 访问地址: http://localhost:8080")
|
||
print("🗄️ 数据库: MySQL")
|
||
print(f"📡 主机: {db_manager.config.host}")
|
||
print(f"📋 数据库: {db_manager.config.database}")
|
||
print("=" * 60)
|
||
|
||
app.run(host='0.0.0.0', port=8081, debug=True) |