diff --git a/.env.example b/.env.example index b5cb41c..49b1c9f 100644 --- a/.env.example +++ b/.env.example @@ -7,4 +7,12 @@ MYSQL_DATABASE=tradingai # Flask应用配置 FLASK_ENV=production -PYTHONPATH=/app \ No newline at end of file +PYTHONPATH=/app + +# 市场扫描服务配置 +MARKET_SCAN_STOCKS=200 +LOG_LEVEL=INFO + +# 钉钉通知配置(已启用) +DINGTALK_WEBHOOK_URL=https://oapi.dingtalk.com/robot/send?access_token=50ad2c14e3c8bf7e262ba837dc2a35cb420228ee4165abd69a9e678c901e120e +DINGTALK_SECRET=SEC6e9dbd71d4addd2c4e673fb72d686293b342da5ae48da2f8ec788a68de99f981 \ No newline at end of file diff --git a/.gitignore b/.gitignore index 77c854e..7848212 100644 --- a/.gitignore +++ b/.gitignore @@ -9,8 +9,8 @@ __pycache__/ .env .venv -# Runtime data and databases -data/ +# Runtime data and databases (exclude source code) +/data/ logs/ *.log *.db diff --git a/Dockerfile b/Dockerfile index 0af7e2b..aac556e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,12 +11,13 @@ ENV FLASK_APP=web/mysql_app.py ENV FLASK_ENV=production ENV TZ=Asia/Shanghai -# 安装系统依赖,包括MySQL客户端库 +# 安装系统依赖,包括MySQL客户端库和cron RUN apt-get update && apt-get install -y \ gcc \ g++ \ curl \ tzdata \ + cron \ default-libmysqlclient-dev \ pkg-config \ && rm -rf /var/lib/apt/lists/* diff --git a/README.md b/README.md index 8421740..9f711ff 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,10 @@ # A股量化交易系统 -一个基于Python的A股市场监控和选股量化程序,使用adata数据源。 +一个基于Python的A股市场监控和选股量化程序,使用TuShare数据源。 ## 功能特性 -- 📈 **实时数据获取**: 使用adata获取A股实时行情数据 +- 📈 **实时数据获取**: 使用TuShare获取A股实时行情数据 - 🔍 **舆情分析**: 北向资金、融资融券、热点股票、龙虎榜数据分析 - 📊 **股票筛选**: 基于技术指标和基本面的智能选股 - 💰 **市场监控**: 实时监控价格变动、成交量异常、资金流向 @@ -22,7 +22,7 @@ TradingAI/ ├── src/ # 源代码 │ ├── data/ # 数据获取模块 │ │ ├── __init__.py -│ │ ├── data_fetcher.py # 行情数据获取 +│ │ ├── tushare_fetcher.py # 行情数据获取 │ │ └── sentiment_fetcher.py # 舆情数据获取 │ ├── strategy/ # 策略模块 │ ├── monitor/ # 监控模块 @@ -79,9 +79,9 @@ python main.py ### 获取实时行情 ```python -from src.data.data_fetcher import ADataFetcher +from src.data.tushare_fetcher import TushareFetcher -fetcher = ADataFetcher() +fetcher = TushareFetcher() # 获取单只股票实时数据 data = fetcher.get_realtime_data("000001.SZ") @@ -167,7 +167,7 @@ hist_data = fetcher.get_historical_data( ## 注意事项 -1. 首次使用需要确保网络连接正常,adata需要从网络获取数据 +1. 首次使用需要确保网络连接正常,TuShare需要从网络获取数据 2. 请合理使用数据接口,避免频繁请求 3. 舆情数据仅供参考,投资需谨慎 4. 本系统仅供学习和研究使用,不构成投资建议 diff --git a/STRATEGY_USAGE.md b/STRATEGY_USAGE.md index a713373..36fc425 100644 --- a/STRATEGY_USAGE.md +++ b/STRATEGY_USAGE.md @@ -91,12 +91,12 @@ python main.py ### 3. 程序化使用 ```python -from src.data.data_fetcher import ADataFetcher +from src.data.tushare_fetcher import TushareFetcher from src.utils.notification import NotificationManager from src.strategy.kline_pattern_strategy import KLinePatternStrategy # 初始化组件 -data_fetcher = ADataFetcher() +data_fetcher = TushareFetcher() notification_manager = NotificationManager(notification_config) strategy_config = { @@ -153,7 +153,7 @@ market_results = strategy.scan_market(max_stocks=20) ### 1. 数据源 -- 优先使用adata真实数据 +- 优先使用TuShare真实数据 - 如无法获取真实数据,会生成模拟数据进行测试 - 模拟数据中会人为插入形态信号用于验证策略逻辑 @@ -228,7 +228,7 @@ market_results = strategy.scan_market(max_stocks=20) ### 常见问题 -1. **无法获取数据**: 检查网络连接和adata配置 +1. **无法获取数据**: 检查网络连接和TuShare配置 2. **钉钉通知失败**: 验证webhook地址和安全设置 3. **策略未启用**: 检查配置文件中的 `enabled` 设置 4. **内存占用过高**: 减少 `scan_stocks_count` 和 `analysis_days` diff --git a/USAGE.md b/USAGE.md new file mode 100644 index 0000000..1021fb7 --- /dev/null +++ b/USAGE.md @@ -0,0 +1,174 @@ +# A股量化交易策略系统 + +## 项目简介 + +本项目是一个模块化的A股量化交易策略执行系统,支持定时执行多种策略任务,主要包括: + +- 📊 **数据获取**: 基于TuShare的股票数据获取 +- 🎯 **股票池管理**: 多种股票筛选规则(热榜、龙头股等) +- 📈 **策略分析**: K线形态识别等技术分析策略 +- ⏰ **任务调度**: 灵活的定时任务执行 +- 📱 **结果通知**: 多种通知方式支持 + +## 快速开始 + +### 启动策略测试系统 + +```bash +python main.py +``` + +### 基本命令 + +```bash +# 分析单只股票 +scan 000001.SZ + +# 扫描热门股票(默认20只) +market 30 + +# 查看可用股票池规则 +pools + +# 执行策略任务 +task tushare_hot 15 + +# 查看定时任务示例 +schedule + +# 显示帮助 +help + +# 退出程序 +quit +``` + +## 系统架构 + +``` +数据层: TushareFetcher + StockPoolManager +策略层: BaseStrategy + KLinePatternStrategy +执行层: TaskScheduler + StrategyExecutor +通知层: NotificationManager +``` + +### 核心组件 + +- **TushareFetcher**: TuShare API数据获取器 +- **StockPoolManager**: 股票池管理器,支持多种筛选规则 +- **KLinePatternStrategy**: K线形态识别策略 +- **StrategyExecutor**: 策略执行协调器 +- **TaskScheduler**: 定时任务调度器 + +## 股票池规则 + +- `tushare_hot`: 同花顺热榜 +- `combined_hot`: 合并热门(同花顺+东财) +- `leading_stocks`: 龙头牛股 +- 支持自定义股票列表 + +## 策略功能 + +### K线形态策略 +- 识别"两阳线+阴线+阳线"突破形态 +- 创新高回踩确认机制 +- 多时间周期分析支持 +- 回踩监控功能 + +## 配置文件 + +主要配置文件位于 `config/config.yaml`: + +```yaml +# TuShare配置 +tushare: + token: "your_token_here" + +# 策略配置 +strategy: + kline_pattern: + min_entity_ratio: 0.55 + final_yang_min_ratio: 0.40 + max_turnover_ratio: 40.0 + timeframes: ["daily"] + +# 通知配置 +notification: + dingtalk: + enabled: true + webhook_url: "your_webhook_url" +``` + +## 定时任务示例 + +查看 `examples/` 目录下的配置示例: + +- `examples/new_architecture_example.py`: 完整架构演示 +- `examples/task_config_examples.py`: 常见任务配置 + +## 使用示例 + +### 1. 分析单只股票 + +```bash +> scan 000001.SZ +🔍 分析股票: 000001.SZ +---------------------------------------- +📊 DAILY: 发现 1 个信号 + 1. 2024-01-15 | 两阳+阴+阳突破 | 价格: 12.50元 +📈 总计: 1 个信号 +``` + +### 2. 扫描市场热门股票 + +```bash +> market 20 +🌍 扫描市场热门股票 (前20只) +-------------------------------------------------- +📊 扫描结果: + 股票池: 同花顺热榜 + 总扫描: 20 只 + 有信号: 3 只 + 信号数: 5 个 + 耗时: 15.23 秒 +``` + +### 3. 执行策略任务 + +```bash +> task leading_stocks 10 +⚡ 执行策略任务 +股票池规则: leading_stocks +最大股票数: 10 +---------------------------------------- +✅ 任务完成: + 任务ID: manual_leading_stocks_10 + 成功: 是 + 耗时: 8.45 秒 + 信号数: 2 个 +``` + +## 开发指南 + +### 添加新策略 + +1. 继承 `BaseStrategy` 基类 +2. 实现 `analyze_stock()` 方法 +3. 注册到 `StrategyExecutor` + +### 添加新股票池规则 + +1. 继承 `StockPoolRule` 基类 +2. 实现 `get_stocks()` 方法 +3. 注册到 `StockPoolManager` + +## 注意事项 + +- 确保配置文件中的TuShare token有效 +- 策略分析需要足够的历史数据 +- 定时任务需要稳定的网络连接 +- 建议在交易时间外进行大量数据扫描 + +## 许可证 + +MIT License \ No newline at end of file diff --git a/crontab/market-scanner b/crontab/market-scanner new file mode 100644 index 0000000..c285f94 --- /dev/null +++ b/crontab/market-scanner @@ -0,0 +1,20 @@ +# 市场扫描定时任务配置 +# 格式: 分钟 小时 日 月 星期 命令 +# 时区: Asia/Shanghai + +# 每个工作日开盘前扫描 (09:00) +#0 9 * * 1-5 cd /app && python market_scanner.py 200 >> /app/logs/cron.log 2>&1 + +# 每个工作日午休时间扫描 (12:30) +#30 12 * * 1-5 cd /app && python market_scanner.py 100 >> /app/logs/cron.log 2>&1 + +# 每个工作日收盘后扫描 (15:30) +30 15 * * 1-5 cd /app && python market_scanner.py 300 >> /app/logs/cron.log 2>&1 + +# 每周末进行一次深度扫描 (周六 10:00) +#0 10 * * 6 cd /app && python market_scanner.py 500 >> /app/logs/cron.log 2>&1 + +# 高频监控 - 每30分钟扫描一次热门股票 (交易时间内: 9:30-15:00) +# 注释掉避免过于频繁,需要时可以开启 +# 30 9-14 * * 1-5 cd /app && python market_scanner.py 50 >> /app/logs/cron.log 2>&1 +# 0 10-14 * * 1-5 cd /app && python market_scanner.py 50 >> /app/logs/cron.log 2>&1 \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml index 12d510c..437d9f0 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -27,23 +27,26 @@ services: retries: 3 start_period: 40s - # 数据采集服务(可选) - trading-collector: + # 市场扫描定时任务服务 + trading-market-scanner: build: . - container_name: trading-ai-collector-mysql + container_name: trading-ai-market-scanner volumes: - ./config:/app/config - ./logs:/app/logs + - ./crontab:/app/crontab environment: - PYTHONPATH=/app + - LOG_LEVEL=${LOG_LEVEL:-INFO} + - MARKET_SCAN_STOCKS=${MARKET_SCAN_STOCKS:-200} # MySQL连接配置 - MYSQL_HOST=${MYSQL_HOST:-cd-cynosdbmysql-grp-7kdd8qe4.sql.tencentcdb.com} - MYSQL_PORT=${MYSQL_PORT:-26558} - MYSQL_USER=${MYSQL_USER:-root} - MYSQL_PASSWORD=${MYSQL_PASSWORD:-gUjjmQpu6c7V0hMF} - MYSQL_DATABASE=${MYSQL_DATABASE:-tradingai} - # 运行数据采集脚本 - command: ["python", "main.py", "scanmarket"] + # 运行定时市场扫描 + command: ["bash", "/app/start_market_scanner.sh"] restart: unless-stopped depends_on: - trading-web diff --git a/docker-start.sh b/docker-start.sh index e93a3ee..ebcad6a 100755 --- a/docker-start.sh +++ b/docker-start.sh @@ -58,7 +58,7 @@ database: path: "data/trading.db" data_source: - provider: "adata" + provider: "tushare" strategy: kline_pattern: diff --git a/docs/MARKET_SCANNER_DOCKER.md b/docs/MARKET_SCANNER_DOCKER.md new file mode 100644 index 0000000..eaf7c65 --- /dev/null +++ b/docs/MARKET_SCANNER_DOCKER.md @@ -0,0 +1,206 @@ +# 市场扫描Docker定时任务配置 + +## 概述 + +本文档描述如何使用Docker容器定时执行`main.py market 200`命令进行股票市场扫描分析。 + +## 服务组件 + +### 1. 市场扫描服务 (trading-market-scanner) + +- **容器名称**: `trading-ai-market-scanner` +- **功能**: 定时执行股票市场扫描和K线形态分析 +- **扫描对象**: 同花顺热榜前N只股票 +- **通知**: 支持钉钉webhook通知 + +### 2. 主要文件 + +- `market_scanner.py` - 市场扫描脚本 +- `start_market_scanner.sh` - 容器启动脚本 +- `crontab/market-scanner` - 定时任务配置 +- `docker-compose.yml` - Docker编排配置 + +## 定时任务配置 + +### 默认调度计划 + +```bash +# 每个工作日开盘前扫描 (09:00) - 200只股票 +0 9 * * 1-5 cd /app && python market_scanner.py 200 + +# 每个工作日午休时间扫描 (12:30) - 100只股票 +30 12 * * 1-5 cd /app && python market_scanner.py 100 + +# 每个工作日收盘后扫描 (15:30) - 200只股票 +30 15 * * 1-5 cd /app && python market_scanner.py 200 + +# 每周末进行一次深度扫描 (周六 10:00) - 500只股票 +0 10 * * 6 cd /app && python market_scanner.py 500 +``` + +### 自定义调度 + +可通过修改 `crontab/market-scanner` 文件来自定义定时任务: + +```bash +# 格式: 分钟 小时 日 月 星期 命令 +# 每15分钟扫描一次 (交易时间内) +*/15 9-15 * * 1-5 cd /app && python market_scanner.py 50 + +# 每小时扫描一次 +0 * * * * cd /app && python market_scanner.py 100 +``` + +## 环境变量配置 + +### 必要配置 + +| 变量名 | 默认值 | 说明 | +|--------|--------|------| +| `MARKET_SCAN_STOCKS` | 200 | 扫描的股票数量 | +| `LOG_LEVEL` | INFO | 日志级别 | +| `MYSQL_HOST` | - | MySQL主机地址 | +| `MYSQL_PORT` | 26558 | MySQL端口 | +| `MYSQL_USER` | root | MySQL用户名 | +| `MYSQL_PASSWORD` | - | MySQL密码 | +| `MYSQL_DATABASE` | tradingai | 数据库名 | + +### 可选配置 + +| 变量名 | 说明 | +|--------|------| +| `DINGTALK_WEBHOOK_URL` | 钉钉通知webhook地址 | + +## 部署步骤 + +### 1. 环境配置 + +复制并编辑环境变量文件: + +```bash +cp .env.example .env +# 编辑 .env 文件,配置数据库连接和扫描参数 +``` + +### 2. 启动服务 + +```bash +# 启动市场扫描服务 +docker-compose up -d trading-market-scanner + +# 查看服务状态 +docker-compose ps + +# 查看日志 +docker-compose logs -f trading-market-scanner +``` + +### 3. 验证部署 + +```bash +# 查看容器状态 +docker ps | grep trading-ai-market-scanner + +# 查看定时任务配置 +docker exec trading-ai-market-scanner cat /etc/cron.d/market-scanner + +# 查看扫描日志 +docker exec trading-ai-market-scanner tail -f /app/logs/market_scanner.log +``` + +## 日志管理 + +### 日志文件位置 + +- `/app/logs/market_scanner.log` - 市场扫描日志 +- `/app/logs/cron.log` - Cron执行日志 +- `/app/logs/scanner_startup.log` - 服务启动日志 + +### 日志轮转 + +- 日志文件按日轮转 +- 保留30天历史日志 +- 自动压缩旧日志文件 + +## 监控和运维 + +### 健康检查 + +服务包含自动重启机制,如果进程异常退出会自动重启。 + +### 手动执行扫描 + +```bash +# 进入容器手动执行扫描 +docker exec -it trading-ai-market-scanner python market_scanner.py 100 + +# 查看扫描结果 +docker exec trading-ai-market-scanner tail -20 /app/logs/market_scanner.log +``` + +### 修改定时任务 + +1. 编辑 `crontab/market-scanner` 文件 +2. 重启容器使配置生效: + +```bash +docker-compose restart trading-market-scanner +``` + +### 调试模式 + +设置更详细的日志级别: + +```bash +# 在 .env 文件中设置 +LOG_LEVEL=DEBUG + +# 重启服务 +docker-compose restart trading-market-scanner +``` + +## 注意事项 + +1. **时区设置**: 容器使用 Asia/Shanghai 时区 +2. **资源消耗**: 扫描大量股票会消耗较多CPU和内存 +3. **网络依赖**: 需要稳定的网络连接访问股票数据API +4. **数据库连接**: 确保MySQL数据库可正常连接 +5. **存储空间**: 定期清理日志文件避免磁盘空间不足 + +## 故障排除 + +### 常见问题 + +1. **容器启动失败** + ```bash + # 查看启动日志 + docker-compose logs trading-market-scanner + ``` + +2. **定时任务不执行** + ```bash + # 检查cron服务状态 + docker exec trading-ai-market-scanner service cron status + + # 查看cron日志 + docker exec trading-ai-market-scanner tail -f /app/logs/cron.log + ``` + +3. **数据库连接失败** + ```bash + # 检查环境变量配置 + docker exec trading-ai-market-scanner env | grep MYSQL + ``` + +4. **内存不足** + ```bash + # 减少扫描股票数量 + MARKET_SCAN_STOCKS=50 + ``` + +### 性能优化 + +1. 根据服务器性能调整扫描频率和股票数量 +2. 设置合理的日志级别避免过多日志输出 +3. 定期清理历史数据和日志文件 +4. 监控服务器资源使用情况 \ No newline at end of file diff --git a/docs/NEW_ARCHITECTURE.md b/docs/NEW_ARCHITECTURE.md new file mode 100644 index 0000000..6cf991e --- /dev/null +++ b/docs/NEW_ARCHITECTURE.md @@ -0,0 +1,215 @@ +# 新架构说明文档 + +## 概述 + +本项目已重构为模块化的策略执行系统,支持定时执行多个策略任务。新架构分为四个层次: + +1. **数据层**: TushareFetcher + StockPoolManager +2. **策略层**: BaseStrategy + 具体策略实现 +3. **执行层**: TaskScheduler + StrategyExecutor +4. **通知层**: NotificationManager + +## 核心工作流程 + +每个策略任务都遵循以下标准流程: + +``` +1. 根据规则获取股票池 → 2. 传递给策略进行分析 → 3. 以多种方式呈现或通知结果 +``` + +## 模块详解 + +### 1. 数据层 + +#### TushareFetcher +- 负责从TuShare API获取股票数据 +- 支持历史K线、热榜数据、基本信息等 +- 已集成token管理和缓存机制 + +#### StockPoolManager +- 管理不同的股票池获取规则 +- 内置规则: + - `tushare_hot`: 同花顺热榜 + - `combined_hot`: 合并热门(同花顺+东财) + - `leading_stocks`: 龙头牛股 + - 支持自定义股票列表 + +### 2. 策略层 + +#### BaseStrategy (抽象基类) +- 定义统一的策略接口 +- 标准化输入输出格式 +- 提供股票池批量分析能力 + +#### KLinePatternStrategy +- 实现K线形态识别策略 +- 继承自BaseStrategy +- 支持多时间周期分析 +- 集成回踩监控功能 + +### 3. 执行层 + +#### StrategyExecutor +- 协调股票池获取和策略分析 +- 管理策略注册和执行 +- 统一的结果处理和通知 + +#### TaskScheduler +- 支持多种调度规则 +- 任务执行历史和统计 +- 灵活的任务管理 + +### 4. 通知层 + +#### NotificationManager +- 支持多种通知方式 +- 策略结果汇总推送 +- 特殊事件提醒 + +## 快速开始 + +### 基础使用 + +```python +# 1. 初始化组件 +from src.data.tushare_fetcher import TushareFetcher +from src.data.stock_pool_manager import StockPoolManager +from src.strategy.kline_pattern_strategy import KLinePatternStrategy +from src.execution.strategy_executor import StrategyExecutor +from src.utils.notification import NotificationManager + +# 创建实例 +fetcher = TushareFetcher() +pool_manager = StockPoolManager(fetcher) +notification_manager = NotificationManager(config) +executor = StrategyExecutor(pool_manager, notification_manager) + +# 2. 注册策略 +strategy = KLinePatternStrategy(fetcher, notification_manager, config) +executor.register_strategy("kline_pattern", strategy) + +# 3. 执行任务 +result = executor.execute_task( + task_id="test_task", + strategy_id="kline_pattern", + stock_pool_rule="tushare_hot", + max_stocks=20 +) +``` + +### 定时任务 + +```python +from src.execution.task_scheduler import TaskScheduler + +# 创建调度器 +scheduler = TaskScheduler() + +# 添加定时任务 +task_func = executor.create_task_function( + strategy_id="kline_pattern", + stock_pool_rule="tushare_hot", + max_stocks=30 +) + +scheduler.add_task( + task_id="morning_scan", + name="晨间扫描", + func=task_func, + schedule_rule="weekdays at 09:00" +) + +# 启动调度器 +scheduler.start() +``` + +## 支持的调度规则 + +- `"every 30 minutes"` - 每30分钟 +- `"every 1 hour"` - 每1小时 +- `"daily at 09:30"` - 每日09:30 +- `"weekdays at 14:00"` - 工作日14:00 +- `"monday at 10:00"` - 每周一10:00 + +## 配置示例 + +### 策略配置 +```python +strategy_config = { + 'min_entity_ratio': 0.55, + 'final_yang_min_ratio': 0.40, + 'max_turnover_ratio': 40.0, + 'timeframes': ['daily'], + 'pullback_tolerance': 0.02, + 'monitor_days': 30 +} +``` + +### 股票池参数 +```python +# 同花顺热榜 +{"limit": 50} + +# 合并热门 +{"limit_per_source": 30, "final_limit": 50} + +# 龙头股 +{"top_boards": 8, "stocks_per_board": 3, "min_score": 70.0} +``` + +## 扩展指南 + +### 添加新的股票池规则 + +1. 继承 `StockPoolRule` 基类 +2. 实现 `get_stocks()` 和 `get_rule_name()` 方法 +3. 注册到 `StockPoolManager` + +```python +class MyCustomRule(StockPoolRule): + def get_stocks(self, fetcher, **kwargs): + # 自定义获取逻辑 + return stock_list + + def get_rule_name(self): + return "我的自定义规则" + +# 注册 +pool_manager.register_rule("my_rule", MyCustomRule()) +``` + +### 添加新策略 + +1. 继承 `BaseStrategy` 基类 +2. 实现 `analyze_stock()` 和 `get_strategy_description()` 方法 +3. 注册到 `StrategyExecutor` + +```python +class MyStrategy(BaseStrategy): + def analyze_stock(self, stock_code, timeframes=None): + # 实现分析逻辑 + return {timeframe: StrategyResult(...)} + + def get_strategy_description(self): + return "我的策略描述" + +# 注册 +executor.register_strategy("my_strategy", MyStrategy(...)) +``` + +## 示例文件 + +- `examples/new_architecture_example.py` - 完整架构演示 +- `examples/task_config_examples.py` - 常见任务配置示例 + +## 优势总结 + +✅ **模块化设计**: 职责清晰,易于维护扩展 +✅ **灵活配置**: 支持多种股票池和调度规则 +✅ **标准化接口**: 统一的策略和结果格式 +✅ **任务调度**: 强大的定时执行能力 +✅ **统一通知**: 完善的结果推送机制 +✅ **缓存优化**: 避免重复API调用 +✅ **错误处理**: 完整的异常处理和日志记录 + +新架构完全满足了原始需求:通过简单配置实现复杂的多策略定时执行和结果通知系统。 \ No newline at end of file diff --git a/examples/new_architecture_example.py b/examples/new_architecture_example.py new file mode 100644 index 0000000..753cd88 --- /dev/null +++ b/examples/new_architecture_example.py @@ -0,0 +1,288 @@ +#!/usr/bin/env python3 +""" +新架构集成示例 +展示如何使用重构后的模块化架构进行策略执行和任务调度 +""" + +import sys +import time +from pathlib import Path +from datetime import datetime + +# 添加src目录到路径 +current_dir = Path(__file__).parent +src_dir = current_dir.parent / "src" +sys.path.insert(0, str(src_dir)) + +from loguru import logger +from src.data.tushare_fetcher import TushareFetcher +from src.data.stock_pool_manager import StockPoolManager +from src.strategy.kline_pattern_strategy import KLinePatternStrategy +from src.execution.strategy_executor import StrategyExecutor +from src.execution.task_scheduler import TaskScheduler +from src.utils.notification import NotificationManager +from src.utils.config_loader import config_loader + + +def demo_new_architecture(): + """演示新架构的完整工作流程""" + + print("=" * 80) + print("🚀 新架构演示 - 模块化策略执行系统") + print("=" * 80) + print("📋 系统架构:") + print(" 数据层: TushareFetcher + StockPoolManager") + print(" 策略层: BaseStrategy + KLinePatternStrategy") + print(" 执行层: TaskScheduler + StrategyExecutor") + print(" 通知层: NotificationManager") + print() + + # 1. 初始化所有组件 + print("📦 第1步: 初始化所有组件") + print("-" * 60) + + # 数据层 + fetcher = TushareFetcher() + pool_manager = StockPoolManager(fetcher) + + # 通知层 + notification_config = config_loader.get('notification', {}) + notification_manager = NotificationManager(notification_config) + + # 策略层 + strategy_config = { + 'min_entity_ratio': 0.55, + 'final_yang_min_ratio': 0.40, + 'max_turnover_ratio': 40.0, + 'timeframes': ['daily'], + 'pullback_tolerance': 0.02, + 'monitor_days': 30, + 'pullback_confirmation_days': 7 + } + kline_strategy = KLinePatternStrategy( + data_fetcher=fetcher, + notification_manager=notification_manager, + config=strategy_config + ) + + # 执行层 + executor = StrategyExecutor(pool_manager, notification_manager) + scheduler = TaskScheduler() + + print("✅ 所有组件初始化完成") + print() + + # 2. 注册策略到执行器 + print("📋 第2步: 注册策略") + print("-" * 60) + + executor.register_strategy("kline_pattern", kline_strategy) + + print("已注册策略:") + for strategy_id, strategy_name in executor.get_registered_strategies().items(): + print(f" {strategy_id}: {strategy_name}") + print() + + # 3. 展示股票池规则 + print("🎯 第3步: 可用股票池规则") + print("-" * 60) + + available_rules = pool_manager.get_available_rules() + print("可用规则:") + for rule_id, rule_name in available_rules.items(): + print(f" {rule_id}: {rule_name}") + print() + + # 4. 手动执行单个任务 + print("⚡ 第4步: 手动执行策略任务") + print("-" * 60) + + task_id = f"manual_task_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + + print(f"执行任务: {task_id}") + print("参数:") + print(f" 策略: kline_pattern") + print(f" 股票池规则: tushare_hot") + print(f" 最大股票数: 5") + print() + + result = executor.execute_task( + task_id=task_id, + strategy_id="kline_pattern", + stock_pool_rule="tushare_hot", + stock_pool_params={"limit": 10}, + max_stocks=5, + send_notification=False # 演示时不发送通知 + ) + + # 显示执行结果 + print("📊 执行结果摘要:") + summary = result.get_summary() + for key, value in summary.items(): + print(f" {key}: {value}") + print() + + # 5. 设置定时任务 + print("⏰ 第5步: 设置定时任务") + print("-" * 60) + + # 创建任务函数 + task_function_1 = executor.create_task_function( + strategy_id="kline_pattern", + stock_pool_rule="tushare_hot", + stock_pool_params={"limit": 20}, + max_stocks=10, + send_notification=False + ) + + task_function_2 = executor.create_task_function( + strategy_id="kline_pattern", + stock_pool_rule="combined_hot", + stock_pool_params={"limit_per_source": 15, "final_limit": 25}, + max_stocks=15, + send_notification=False + ) + + # 添加定时任务 + scheduler.add_task( + task_id="hot_stocks_scan", + name="同花顺热榜K线形态扫描", + func=task_function_1, + schedule_rule="every 30 minutes", + enabled=False # 演示时不启用 + ) + + scheduler.add_task( + task_id="combined_hot_scan", + name="合并热门股票K线形态扫描", + func=task_function_2, + schedule_rule="daily at 09:30", + enabled=False # 演示时不启用 + ) + + scheduler.add_task( + task_id="leading_stocks_scan", + name="龙头股K线形态扫描", + func=executor.create_task_function( + strategy_id="kline_pattern", + stock_pool_rule="leading_stocks", + stock_pool_params={"top_boards": 5, "stocks_per_board": 2}, + max_stocks=10, + send_notification=False + ), + schedule_rule="weekdays at 14:30", + enabled=False # 演示时不启用 + ) + + print("已添加定时任务:") + task_status = scheduler.get_task_status() + for task_id, status in task_status.items(): + print(f" {task_id}:") + print(f" 名称: {status['name']}") + print(f" 规则: {status['schedule_rule']}") + print(f" 状态: {status['status']}") + print(f" 启用: {status['enabled']}") + print() + + # 6. 演示立即执行定时任务 + print("🎬 第6步: 演示立即执行定时任务") + print("-" * 60) + + print("立即执行任务: hot_stocks_scan") + success = scheduler.execute_task_now("hot_stocks_scan") + print(f"执行结果: {'成功' if success else '失败'}") + print() + + # 7. 显示任务执行历史 + print("📈 第7步: 任务执行历史") + print("-" * 60) + + updated_status = scheduler.get_task_status() + for task_id, status in updated_status.items(): + if status['total_executions'] > 0: + print(f"任务: {status['name']}") + print(f" 总执行次数: {status['total_executions']}") + print(f" 成功率: {status['success_rate']:.1f}%") + print(f" 最后执行: {status['last_execution_time']}") + print() + + # 8. 系统功能总结 + print("🎯 第8步: 新架构优势总结") + print("-" * 60) + + advantages = """ +✅ 模块化设计: + • 数据获取、股票池管理、策略执行、任务调度完全分离 + • 每个模块职责单一,易于维护和扩展 + • 支持插件化添加新的股票池规则和策略 + +✅ 灵活的股票池管理: + • 支持多种数据源:同花顺热榜、合并热门、龙头股 + • 可配置参数:股票数量、筛选条件等 + • 易于添加新的股票池规则 + +✅ 标准化的策略接口: + • BaseStrategy抽象基类统一策略接口 + • StrategyResult标准化输出格式 + • 支持多时间周期分析 + +✅ 强大的任务调度: + • 支持多种调度规则:间隔时间、每日定时、工作日定时 + • 任务执行历史和成功率统计 + • 支持立即执行和定时执行 + +✅ 统一的执行协调: + • StrategyExecutor协调股票池获取→策略分析→结果通知 + • 完整的执行结果记录和统计 + • 支持并发执行多个策略 + +✅ 完善的通知系统: + • 支持多种通知方式 + • 策略结果汇总推送 + • 回踩提醒等特殊通知 + """ + + print(advantages) + + print("🚀 架构使用示例:") + example_usage = """ +# 基础用法 - 立即执行 +result = executor.execute_task( + task_id="my_task", + strategy_id="kline_pattern", + stock_pool_rule="tushare_hot", + max_stocks=20 +) + +# 定时任务 - 每天9:30执行 +scheduler.add_task( + task_id="morning_scan", + name="晨间形态扫描", + func=executor.create_task_function(...), + schedule_rule="daily at 09:30" +) + +# 启动调度器 +scheduler.start() + """ + print(example_usage) + + print("=" * 80) + print("🎉 新架构演示完成!") + print("💡 现在可以通过简单配置实现复杂的策略执行和调度") + print("🔧 支持灵活的股票池规则和多策略并行执行") + print("📱 统一的结果通知和监控体系") + print("=" * 80) + + +if __name__ == "__main__": + # 设置日志 + logger.remove() + logger.add(sys.stdout, level="INFO", format="{time:HH:mm:ss} | {level} | {message}") + + try: + demo_new_architecture() + except Exception as e: + logger.error(f"演示过程中发生错误: {e}") + import traceback + traceback.print_exc() \ No newline at end of file diff --git a/examples/task_config_examples.py b/examples/task_config_examples.py new file mode 100644 index 0000000..93f74a5 --- /dev/null +++ b/examples/task_config_examples.py @@ -0,0 +1,304 @@ +#!/usr/bin/env python3 +""" +任务配置示例 +展示如何配置不同类型的定时策略任务 +""" + +import sys +from pathlib import Path + +# 添加src目录到路径 +current_dir = Path(__file__).parent +src_dir = current_dir.parent / "src" +sys.path.insert(0, str(src_dir)) + +from loguru import logger +from src.data.tushare_fetcher import TushareFetcher +from src.data.stock_pool_manager import StockPoolManager +from src.strategy.kline_pattern_strategy import KLinePatternStrategy +from src.execution.strategy_executor import StrategyExecutor +from src.execution.task_scheduler import TaskScheduler +from src.utils.notification import NotificationManager +from src.utils.config_loader import config_loader + + +def setup_common_tasks(): + """设置常见的策略任务配置""" + + print("=" * 80) + print("📋 常见策略任务配置示例") + print("=" * 80) + + # 初始化组件 + fetcher = TushareFetcher() + pool_manager = StockPoolManager(fetcher) + notification_manager = NotificationManager(config_loader.get('notification', {})) + executor = StrategyExecutor(pool_manager, notification_manager) + scheduler = TaskScheduler() + + # 策略配置 + strategy_config = { + 'min_entity_ratio': 0.55, + 'final_yang_min_ratio': 0.40, + 'max_turnover_ratio': 40.0, + 'timeframes': ['daily'], + 'pullback_tolerance': 0.02, + 'monitor_days': 30, + 'pullback_confirmation_days': 7 + } + + # 注册策略 + kline_strategy = KLinePatternStrategy( + data_fetcher=fetcher, + notification_manager=notification_manager, + config=strategy_config + ) + executor.register_strategy("kline_pattern", kline_strategy) + + print("🎯 任务配置场景:") + print() + + # 场景1: 开盘前扫描热门股票 + print("📊 场景1: 开盘前热门股票扫描") + print("-" * 60) + + task_1 = executor.create_task_function( + strategy_id="kline_pattern", + stock_pool_rule="tushare_hot", + stock_pool_params={"limit": 50}, + max_stocks=30, + send_notification=True + ) + + scheduler.add_task( + task_id="pre_market_hot_scan", + name="开盘前热门股票K线扫描", + func=task_1, + schedule_rule="weekdays at 09:00", + enabled=False + ) + + print("✅ 配置完成:") + print(" 任务: 开盘前热门股票K线扫描") + print(" 时间: 每个工作日 09:00") + print(" 股票池: 同花顺热榜前50只") + print(" 分析: 最多30只股票") + print(" 通知: 启用") + print() + + # 场景2: 午间龙头股扫描 + print("🐲 场景2: 午间龙头股扫描") + print("-" * 60) + + task_2 = executor.create_task_function( + strategy_id="kline_pattern", + stock_pool_rule="leading_stocks", + stock_pool_params={ + "top_boards": 8, + "stocks_per_board": 3, + "min_score": 70.0 + }, + max_stocks=20, + send_notification=True + ) + + scheduler.add_task( + task_id="midday_leading_scan", + name="午间龙头股K线扫描", + func=task_2, + schedule_rule="weekdays at 12:30", + enabled=False + ) + + print("✅ 配置完成:") + print(" 任务: 午间龙头股K线扫描") + print(" 时间: 每个工作日 12:30") + print(" 股票池: 热门板块前8个,每板块前3只,评分>70") + print(" 分析: 最多20只龙头股") + print(" 通知: 启用") + print() + + # 场景3: 收盘后综合扫描 + print("🌅 场景3: 收盘后综合扫描") + print("-" * 60) + + task_3 = executor.create_task_function( + strategy_id="kline_pattern", + stock_pool_rule="combined_hot", + stock_pool_params={ + "limit_per_source": 30, + "final_limit": 50 + }, + max_stocks=40, + send_notification=True + ) + + scheduler.add_task( + task_id="after_market_comprehensive_scan", + name="收盘后综合热门股扫描", + func=task_3, + schedule_rule="weekdays at 15:30", + enabled=False + ) + + print("✅ 配置完成:") + print(" 任务: 收盘后综合热门股扫描") + print(" 时间: 每个工作日 15:30") + print(" 股票池: 合并热门(同花顺+东财),各取30只,合并后50只") + print(" 分析: 最多40只股票") + print(" 通知: 启用") + print() + + # 场景4: 高频监控 + print("⚡ 场景4: 高频监控扫描") + print("-" * 60) + + task_4 = executor.create_task_function( + strategy_id="kline_pattern", + stock_pool_rule="tushare_hot", + stock_pool_params={"limit": 20}, + max_stocks=15, + send_notification=False # 高频不通知,避免打扰 + ) + + scheduler.add_task( + task_id="high_freq_monitor", + name="高频热门股监控", + func=task_4, + schedule_rule="every 15 minutes", + enabled=False + ) + + print("✅ 配置完成:") + print(" 任务: 高频热门股监控") + print(" 时间: 每15分钟执行一次") + print(" 股票池: 同花顺热榜前20只") + print(" 分析: 最多15只股票") + print(" 通知: 关闭(避免频繁打扰)") + print() + + # 场景5: 自定义股票池 + print("🎯 场景5: 自定义股票池扫描") + print("-" * 60) + + # 创建自定义股票池 + custom_stocks = [ + "000001.SZ", # 平安银行 + "000002.SZ", # 万科A + "600000.SH", # 浦发银行 + "600036.SH", # 招商银行 + "000858.SZ", # 五粮液 + "600519.SH", # 贵州茅台 + "000725.SZ", # 京东方A + "002415.SZ" # 海康威视 + ] + + pool_manager.create_custom_rule("my_watchlist", custom_stocks) + + task_5 = executor.create_task_function( + strategy_id="kline_pattern", + stock_pool_rule="my_watchlist", + stock_pool_params={}, + max_stocks=len(custom_stocks), + send_notification=True + ) + + scheduler.add_task( + task_id="custom_watchlist_scan", + name="自选股K线形态扫描", + func=task_5, + schedule_rule="daily at 21:00", + enabled=False + ) + + print("✅ 配置完成:") + print(" 任务: 自选股K线形态扫描") + print(" 时间: 每日 21:00") + print(f" 股票池: 自定义股票池({len(custom_stocks)}只)") + print(" 分析: 全部自选股") + print(" 通知: 启用") + print() + + # 显示所有任务状态 + print("📋 所有配置任务总览:") + print("-" * 60) + + task_status = scheduler.get_task_status() + for task_id, status in task_status.items(): + print(f"🔹 {status['name']}") + print(f" ID: {task_id}") + print(f" 规则: {status['schedule_rule']}") + print(f" 状态: {status['status']}") + print(f" 启用: {'是' if status['enabled'] else '否'}") + print() + + # 使用指南 + print("📖 使用指南:") + print("-" * 60) + print("1. 根据实际需求选择合适的任务配置") + print("2. 调整股票池参数和分析数量") + print("3. 设置合适的执行时间") + print("4. 启用需要的任务: scheduler.enable_task('task_id')") + print("5. 启动调度器: scheduler.start()") + print("6. 立即测试: scheduler.execute_task_now('task_id')") + print() + + print("⚙️ 高级配置技巧:") + print("-" * 60) + print("• 开盘前(09:00): 扫描热门股,发现隔夜机会") + print("• 午间时段(12:30): 扫描龙头股,捕捉强势股") + print("• 收盘后(15:30): 综合扫描,总结全天机会") + print("• 高频监控(15分钟): 实时跟踪,但关闭通知") + print("• 晚间复盘(21:00): 扫描自选股,制定明日策略") + print() + + return scheduler, executor + + +def demo_task_execution(scheduler, executor): + """演示任务执行""" + + print("🎬 任务执行演示:") + print("-" * 60) + + # 立即执行一个任务进行测试 + print("立即执行任务: pre_market_hot_scan") + success = scheduler.execute_task_now("pre_market_hot_scan") + print(f"执行结果: {'成功' if success else '失败'}") + print() + + # 显示执行统计 + task_status = scheduler.get_task_status() + for task_id, status in task_status.items(): + if status['total_executions'] > 0: + print(f"📊 任务: {status['name']}") + print(f" 执行次数: {status['total_executions']}") + print(f" 成功率: {status['success_rate']:.1f}%") + if status['last_execution_time']: + print(f" 最后执行: {status['last_execution_time']}") + print() + + print("💡 提示: 在生产环境中启用任务并启动调度器") + print(" scheduler.enable_task('task_id')") + print(" scheduler.start()") + + +if __name__ == "__main__": + # 设置日志 + logger.remove() + logger.add(sys.stdout, level="INFO", format="{time:HH:mm:ss} | {level} | {message}") + + try: + scheduler, executor = setup_common_tasks() + demo_task_execution(scheduler, executor) + + print("=" * 80) + print("🎉 任务配置示例演示完成!") + print("💼 可根据实际需求调整参数和时间规则") + print("🚀 启用任务并启动调度器即可自动运行") + print("=" * 80) + + except Exception as e: + logger.error(f"演示过程中发生错误: {e}") + import traceback + traceback.print_exc() \ No newline at end of file diff --git a/main.py b/main.py index b49450d..cfa8bf8 100644 --- a/main.py +++ b/main.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 """ -A股量化交易主程序 +A股量化交易 - 策略测试入口 +简化版本,专注于策略测试和执行 """ import sys @@ -13,533 +14,313 @@ sys.path.insert(0, str(src_dir)) from loguru import logger from src.utils.config_loader import config_loader -from src.data.data_fetcher import ADataFetcher -from src.data.sentiment_fetcher import SentimentFetcher -from src.utils.notification import NotificationManager +from src.data.tushare_fetcher import TushareFetcher +from src.data.stock_pool_manager import StockPoolManager from src.strategy.kline_pattern_strategy import KLinePatternStrategy -from src.database.mysql_database_manager import MySQLDatabaseManager +from src.execution.strategy_executor import StrategyExecutor +from src.execution.task_scheduler import TaskScheduler +from src.utils.notification import NotificationManager def setup_logging(): - """设置日志配置""" - log_config = config_loader.get_logging_config() - - # 移除默认的控制台日志 + """设置简化的日志配置""" logger.remove() - - # 添加控制台输出 logger.add( sys.stdout, - level=log_config.get('level', 'INFO'), - format=log_config.get('format', '{time} | {level} | {message}') + level="INFO", + format="{time:HH:mm:ss} | {level} | {message}" ) - # 添加文件输出 - log_file = Path(log_config.get('file_path', 'logs/trading.log')) - log_file.parent.mkdir(parents=True, exist_ok=True) - logger.add( - log_file, - level=log_config.get('level', 'INFO'), - format=log_config.get('format', '{time} | {level} | {message}'), - rotation=log_config.get('rotation', '1 day'), - retention=log_config.get('retention', '30 days') +def create_strategy_system(): + """创建策略系统的所有组件""" + # 数据层 + fetcher = TushareFetcher() + pool_manager = StockPoolManager(fetcher) + + # 通知层 + notification_config = config_loader.get('notification', {}) + notification_manager = NotificationManager(notification_config) + + # 策略层 + strategy_config = config_loader.get('strategy', {}).get('kline_pattern', { + 'min_entity_ratio': 0.55, + 'final_yang_min_ratio': 0.40, + 'max_turnover_ratio': 40.0, + 'timeframes': ['daily'], + 'pullback_tolerance': 0.02, + 'monitor_days': 30, + 'pullback_confirmation_days': 7 + }) + + # 数据库层 + from src.database.mysql_database_manager import MySQLDatabaseManager + db_manager = MySQLDatabaseManager() + + kline_strategy = KLinePatternStrategy( + data_fetcher=fetcher, + notification_manager=notification_manager, + config=strategy_config, + db_manager=db_manager ) - logger.info("日志系统初始化完成") + # 执行层 + executor = StrategyExecutor(pool_manager, notification_manager) + scheduler = TaskScheduler() + # 注册策略 + executor.register_strategy("kline_pattern", kline_strategy) -def main(): - """主函数""" - print("="*60) - print(" A股量化交易系统") - print("="*60) - - try: - # 初始化日志 - setup_logging() - - # 加载配置 - config = config_loader.load_config() - logger.info("配置文件加载成功") - - # 初始化数据获取器 - data_fetcher = ADataFetcher() - sentiment_fetcher = SentimentFetcher() - - # 初始化MySQL数据库管理器 - db_manager = MySQLDatabaseManager() - logger.info("MySQL数据库管理器初始化完成") - - # 初始化通知管理器 - notification_config = config.get('notification', {}) - notification_manager = NotificationManager(notification_config) - - # 初始化K线形态策略 - strategy_config = config.get('strategy', {}).get('kline_pattern', {}) - if strategy_config.get('enabled', False): - kline_strategy = KLinePatternStrategy(data_fetcher, notification_manager, strategy_config, db_manager) - logger.info("K线形态策略已启用") - else: - kline_strategy = None - logger.info("K线形态策略未启用") - - # 显示系统信息 - logger.info("系统启动成功") - print("\n系统功能:") - print("1. 数据获取 - 实时行情、历史数据、财务数据") - print("2. 舆情分析 - 北向资金、融资融券、热点股票、龙虎榜") - print("3. K线形态策略 - 两阳线+阴线+阳线突破形态识别") - print("4. 股票筛选 - 基于技术指标和基本面的选股") - print("5. 实时监控 - 价格变动、成交量异常监控") - print("6. 策略回测 - 历史数据验证交易策略") - - # 获取市场概况 - print("\n正在获取市场概况...") - market_overview = data_fetcher.get_market_overview() - - if market_overview: - print(f"\n市场概况 (更新时间: {market_overview.get('update_time', 'N/A')}):") - for market, data in market_overview.items(): - if market != 'update_time' and isinstance(data, dict): - price = data.get('close', data.get('current', 'N/A')) - change = data.get('change', 'N/A') - change_pct = data.get('change_pct', 'N/A') - print(f" {market.upper()}: 价格={price}, 涨跌={change}, 涨跌幅={change_pct}%") - - print("\n系统就绪,等待指令...") - print("输入 'help' 查看帮助,输入 'quit' 退出程序") - - # 简单的交互式命令行 - while True: - try: - command = input("\n> ").strip().lower() - - if command == 'quit' or command == 'exit': - print("感谢使用A股量化交易系统!") - break - elif command == 'help': - print_help() - elif command == 'status': - print_system_status() - elif command.startswith('search '): - keyword = command[7:] # 移除'search ' - search_stocks(data_fetcher, keyword) - elif command == 'market': - show_market_overview(data_fetcher) - elif command == 'sentiment': - show_market_sentiment(sentiment_fetcher) - elif command == 'hotstock': - show_hot_stocks(sentiment_fetcher) - elif command == 'northflow': - show_north_flow(sentiment_fetcher) - elif command == 'dragon': - show_dragon_tiger_list(sentiment_fetcher) - elif command.startswith('analyze '): - stock_code = command[8:] # 移除'analyze ' - analyze_stock_sentiment(sentiment_fetcher, stock_code) - elif command == 'strategy': - show_strategy_info(kline_strategy) - elif command.startswith('scan '): - stock_code = command[5:] # 移除'scan ' - scan_single_stock(kline_strategy, stock_code) - elif command == 'scanmarket': - scan_market_patterns(kline_strategy) - elif command == 'testnotify': - test_notification(notification_manager) - else: - print("未知命令,输入 'help' 查看帮助") - - except KeyboardInterrupt: - print("\n\n程序被用户中断") - break - except Exception as e: - logger.error(f"命令执行错误: {e}") - print(f"执行错误: {e}") - - except Exception as e: - logger.error(f"程序启动失败: {e}") - print(f"启动失败: {e}") - sys.exit(1) + return { + 'fetcher': fetcher, + 'pool_manager': pool_manager, + 'notification_manager': notification_manager, + 'kline_strategy': kline_strategy, + 'executor': executor, + 'scheduler': scheduler + } def print_help(): """打印帮助信息""" - print("\n可用命令:") - print(" help - 显示此帮助信息") - print(" status - 显示系统状态") - print(" market - 显示市场概况") - print(" search <关键词> - 搜索股票") - print(" sentiment - 显示市场舆情综合概览") - print(" hotstock - 显示热门股票排行") - print(" northflow - 显示北向资金流向") - print(" dragon - 显示龙虎榜数据") - print(" analyze <股票代码> - 分析单只股票舆情") - print(" strategy - 显示K线形态策略信息") - print(" scan <股票代码> - 扫描单只股票K线形态") - print(" scanmarket - 扫描市场K线形态") - print(" testnotify - 测试通知功能") - print(" quit/exit - 退出程序") + print("\n🚀 策略测试命令:") + print("-" * 50) + print(" scan <股票代码> - 分析单只股票") + print(" market <数量> - 扫描热门股票(默认20只)") + print(" pools - 查看可用股票池规则") + print(" task <规则> <数量> - 执行策略任务") + print(" schedule - 显示定时任务示例") + print(" help - 显示帮助") + print(" quit - 退出程序") + print("-" * 50) -def print_system_status(): - """显示系统状态""" - config = config_loader.config - print("\n系统状态:") - print(f" 配置文件: 已加载") - print(f" 数据源: {config.get('data', {}).get('sources', {}).get('primary', 'N/A')}") - print(f" 日志级别: {config.get('logging', {}).get('level', 'N/A')}") - print(f" 实时监控: {'启用' if config.get('monitor', {}).get('realtime', {}).get('enabled', False) else '禁用'}") +def main(): + """主函数""" + setup_logging() - -def search_stocks(data_fetcher: ADataFetcher, keyword: str): - """搜索股票""" - if not keyword: - print("请提供搜索关键词") - return - - print(f"\n搜索股票: {keyword}") - results = data_fetcher.search_stocks(keyword) - - if not results.empty: - print(f"找到 {len(results)} 个结果:") - for idx, row in results.head(10).iterrows(): # 只显示前10个结果 - code = row.get('stock_code', 'N/A') - name = row.get('short_name', 'N/A') - print(f" {code} - {name}") - - if len(results) > 10: - print(f" ... 还有 {len(results) - 10} 个结果") - else: - print("未找到匹配的股票") - - -def show_market_overview(data_fetcher: ADataFetcher): - """显示市场概况""" - print("\n正在获取最新市场数据...") - overview = data_fetcher.get_market_overview() - - if overview: - print(f"\n市场概况 (更新时间: {overview.get('update_time', 'N/A')}):") - for market, data in overview.items(): - if market != 'update_time' and isinstance(data, dict): - price = data.get('close', data.get('current', 'N/A')) - change = data.get('change', 'N/A') - change_pct = data.get('change_pct', 'N/A') - volume = data.get('volume', 'N/A') - print(f" {market.upper()}: 价格={price}, 涨跌={change}, 涨跌幅={change_pct}%, 成交量={volume}") - else: - print("无法获取市场数据") - - -def show_market_sentiment(sentiment_fetcher: SentimentFetcher): - """显示市场舆情综合概览""" - print("\n正在获取市场舆情数据...") - overview = sentiment_fetcher.get_market_sentiment_overview() - - if overview: - print(f"\n市场舆情综合概览 (更新时间: {overview.get('update_time', 'N/A')}):") - - # 北向资金 - if 'north_flow' in overview: - north_data = overview['north_flow'] - print(f"\n📊 北向资金:") - print(f" 总净流入: {north_data.get('net_total', 'N/A')} 万元") - print(f" 沪股通: {north_data.get('net_hgt', 'N/A')} 万元") - print(f" 深股通: {north_data.get('net_sgt', 'N/A')} 万元") - print(f" 更新时间: {north_data.get('update_time', 'N/A')}") - - # 融资融券 - if 'latest_margin' in overview: - margin_data = overview['latest_margin'] - print(f"\n📈 融资融券:") - print(f" 融资余额: {margin_data.get('rzye', 'N/A')} 亿元") - print(f" 融券余额: {margin_data.get('rqye', 'N/A')} 亿元") - print(f" 两融余额: {margin_data.get('rzrqye', 'N/A')} 亿元") - - # 热门股票(前5名) - if 'hot_stocks_east' in overview and not overview['hot_stocks_east'].empty: - print(f"\n🔥 东财热门股票TOP5:") - for idx, row in overview['hot_stocks_east'].head(5).iterrows(): - code = row.get('stock_code', 'N/A') - name = row.get('short_name', 'N/A') - rank = row.get('rank', idx + 1) - print(f" {rank}. {code} - {name}") - - # 热门概念(前5名) - if 'hot_concepts' in overview and not overview['hot_concepts'].empty: - print(f"\n💡 热门概念TOP5:") - for idx, row in overview['hot_concepts'].head(5).iterrows(): - name = row.get('concept_name', 'N/A') - change_pct = row.get('change_pct', 'N/A') - rank = row.get('rank', idx + 1) - print(f" {rank}. {name} (涨跌幅: {change_pct}%)") - - # 龙虎榜(前3名) - if 'dragon_tiger' in overview and not overview['dragon_tiger'].empty: - print(f"\n🐉 今日龙虎榜TOP3:") - for idx, row in overview['dragon_tiger'].head(3).iterrows(): - code = row.get('stock_code', 'N/A') - name = row.get('short_name', 'N/A') - reason = row.get('reason', 'N/A') - print(f" {idx + 1}. {code} - {name} ({reason})") - else: - print("无法获取市场舆情数据") - - -def show_hot_stocks(sentiment_fetcher: SentimentFetcher): - """显示热门股票排行""" - print("\n正在获取热门股票数据...") - - # 东财人气股票 - east_stocks = sentiment_fetcher.get_popular_stocks_east_100() - if not east_stocks.empty: - print(f"\n🔥 东财人气股票TOP10:") - for idx, row in east_stocks.head(10).iterrows(): - code = row.get('stock_code', 'N/A') - name = row.get('short_name', 'N/A') - rank = row.get('rank', idx + 1) - change_pct = row.get('change_pct', 'N/A') - print(f" {rank}. {code} - {name} (涨跌幅: {change_pct}%)") - - # 同花顺热门股票 - ths_stocks = sentiment_fetcher.get_hot_stocks_ths_100() - if not ths_stocks.empty: - print(f"\n🌟 同花顺热门股票TOP10:") - for idx, row in ths_stocks.head(10).iterrows(): - code = row.get('stock_code', 'N/A') - name = row.get('short_name', 'N/A') - rank = row.get('rank', idx + 1) - change_pct = row.get('change_pct', 'N/A') - print(f" {rank}. {code} - {name} (涨跌幅: {change_pct}%)") - - -def show_north_flow(sentiment_fetcher: SentimentFetcher): - """显示北向资金流向""" - print("\n正在获取北向资金数据...") - - # 当前流向 - current_flow = sentiment_fetcher.get_north_flow_current() - if not current_flow.empty: - print(f"\n💰 当前北向资金流向:") - for idx, row in current_flow.iterrows(): - net_total = row.get('net_tgt', 'N/A') - net_hgt = row.get('net_hgt', 'N/A') - net_sgt = row.get('net_sgt', 'N/A') - trade_time = row.get('trade_time', 'N/A') - print(f" 总净流入: {net_total} 万元") - print(f" 沪股通: {net_hgt} 万元") - print(f" 深股通: {net_sgt} 万元") - print(f" 更新时间: {trade_time}") - break # 只显示第一行数据 - - # 历史流向(最近5天) - hist_flow = sentiment_fetcher.get_north_flow_history() - if not hist_flow.empty: - print(f"\n📊 最近5天北向资金流向:") - for idx, row in hist_flow.tail(5).iterrows(): - date = row.get('trade_date', 'N/A') - net_total = row.get('net_tgt', 'N/A') - print(f" {date}: {net_total} 万元") - - -def show_dragon_tiger_list(sentiment_fetcher: SentimentFetcher): - """显示龙虎榜数据""" - print("\n正在获取龙虎榜数据...") - - dragon_tiger = sentiment_fetcher.get_dragon_tiger_list_daily() - if not dragon_tiger.empty: - print(f"\n🐉 今日龙虎榜 (共{len(dragon_tiger)}只股票):") - for idx, row in dragon_tiger.head(15).iterrows(): # 显示前15个 - code = row.get('stock_code', 'N/A') - name = row.get('short_name', 'N/A') - reason = row.get('reason', 'N/A') - change_pct = row.get('change_pct', 'N/A') - amount = row.get('amount', 'N/A') - print(f" {idx + 1}. {code} - {name}") - print(f" 上榜原因: {reason}") - print(f" 涨跌幅: {change_pct}%, 成交金额: {amount} 万元") - - if len(dragon_tiger) > 15: - print(f" ... 还有 {len(dragon_tiger) - 15} 只股票") - else: - print("今日暂无龙虎榜数据") - - -def analyze_stock_sentiment(sentiment_fetcher: SentimentFetcher, stock_code: str): - """分析单只股票舆情""" - if not stock_code: - print("请提供股票代码") - return - - print(f"\n正在分析股票 {stock_code} 的舆情情况...") - analysis = sentiment_fetcher.analyze_stock_sentiment(stock_code) - - if 'error' in analysis: - print(f"分析失败: {analysis['error']}") - return - - print(f"\n📊 {stock_code} 舆情分析报告:") - print(f"更新时间: {analysis.get('update_time', 'N/A')}") - - # 热度情况 - print(f"\n🔥 热度情况:") - print(f" 东财人气榜: {'在榜' if analysis.get('in_popular_east', False) else '不在榜'}") - print(f" 同花顺热门榜: {'在榜' if analysis.get('in_hot_ths', False) else '不在榜'}") - - # 龙虎榜情况 - if 'dragon_tiger' in analysis and not analysis['dragon_tiger'].empty: - print(f"\n🐉 龙虎榜情况:") - dragon_data = analysis['dragon_tiger'].iloc[0] - reason = dragon_data.get('reason', 'N/A') - amount = dragon_data.get('amount', 'N/A') - print(f" 上榜原因: {reason}") - print(f" 成交金额: {amount} 万元") - else: - print(f"\n🐉 龙虎榜情况: 今日未上榜") - - # 风险扫描 - if 'risk_scan' in analysis and not analysis['risk_scan'].empty: - print(f"\n⚠️ 风险扫描:") - risk_data = analysis['risk_scan'].iloc[0] - risk_level = risk_data.get('risk_level', 'N/A') - risk_desc = risk_data.get('risk_desc', 'N/A') - print(f" 风险等级: {risk_level}") - print(f" 风险描述: {risk_desc}") - else: - print(f"\n⚠️ 风险扫描: 暂无数据") - - -def show_strategy_info(kline_strategy: KLinePatternStrategy): - """显示K线形态策略信息""" - if kline_strategy is None: - print("K线形态策略未启用") - return - - print("\n" + "="*60) - print(" K线形态策略信息") - print("="*60) - print(kline_strategy.get_strategy_summary()) - - -def scan_single_stock(kline_strategy: KLinePatternStrategy, stock_code: str): - """扫描单只股票K线形态""" - if kline_strategy is None: - print("K线形态策略未启用") - return - - if not stock_code: - print("请提供股票代码") - return - - print(f"\n正在扫描股票 {stock_code} 的K线形态...") + print("=" * 60) + print("🎯 A股量化交易 - 策略测试系统") + print("=" * 60) try: - results = kline_strategy.analyze_stock(stock_code) + # 初始化系统 + logger.info("正在初始化策略系统...") + system = create_strategy_system() + logger.info("✅ 策略系统初始化完成") - print(f"\n📊 {stock_code} K线形态分析结果:") - total_signals = 0 + print(f"\n📊 系统组件:") + print(f" ✅ 数据获取器: TushareFetcher") + print(f" ✅ 股票池管理: StockPoolManager") + print(f" ✅ K线策略: KLinePatternStrategy") + print(f" ✅ 执行器: StrategyExecutor") + print(f" ✅ 调度器: TaskScheduler") - for timeframe, signals in results.items(): - print(f"\n{timeframe.upper()} 时间周期:") - if signals: - for i, signal in enumerate(signals, 1): - print(f" 信号 {i}:") - print(f" 日期: {signal['date']}") - print(f" 形态: {signal['pattern_type']}") - print(f" 突破价格: {signal['breakout_price']:.2f} 元") - print(f" 突破幅度: {signal['breakout_pct']:.2f}%") - print(f" 阳线1实体比例: {signal['yang1_entity_ratio']:.1%}") - print(f" 阳线2实体比例: {signal['yang2_entity_ratio']:.1%}") - print(f" EMA20价格: {signal['ema20_price']:.2f} 元") - print(f" EMA20状态: {'✅ 上方' if signal['above_ema20'] else '❌ 下方'}") - print(f" 换手率: {signal.get('turnover_ratio', 0):.2f}%") - total_signals += len(signals) - print(f" 共发现 {len(signals)} 个信号") - else: - print(" 未发现形态信号") + print_help() - print(f"\n总计发现 {total_signals} 个信号") + # 命令行交互 + while True: + try: + command = input("\n> ").strip() + + if not command: + continue + + if command.lower() in ['quit', 'exit']: + print("👋 感谢使用策略测试系统!") + break + + elif command.lower() == 'help': + print_help() + + elif command.startswith('scan '): + stock_code = command[5:].strip() + if stock_code: + scan_single_stock(system['kline_strategy'], stock_code) + else: + print("请提供股票代码,如: scan 000001.SZ") + + elif command.startswith('market'): + parts = command.split() + max_stocks = int(parts[1]) if len(parts) > 1 and parts[1].isdigit() else 20 + scan_market(system['executor'], max_stocks) + + elif command.lower() == 'pools': + show_stock_pools(system['pool_manager']) + + elif command.startswith('task '): + parts = command.split() + if len(parts) >= 2: + rule = parts[1] + max_stocks = int(parts[2]) if len(parts) > 2 and parts[2].isdigit() else 10 + execute_task(system['executor'], rule, max_stocks) + else: + print("请提供股票池规则,如: task tushare_hot 15") + + elif command.lower() == 'schedule': + show_schedule_examples(system['scheduler'], system['executor']) + + else: + print("❌ 未知命令,输入 'help' 查看帮助") + + except KeyboardInterrupt: + print("\n\n👋 程序被用户中断") + break + except Exception as e: + logger.error(f"命令执行错误: {e}") + print(f"❌ 执行错误: {e}") except Exception as e: - logger.error(f"扫描股票失败: {e}") - print(f"扫描失败: {e}") + logger.error(f"系统启动失败: {e}") + print(f"❌ 启动失败: {e}") + return 1 + + return 0 -def scan_market_patterns(kline_strategy: KLinePatternStrategy): - """扫描市场K线形态""" - if kline_strategy is None: - print("K线形态策略未启用") - return - - print("\n开始扫描市场K线形态...") - print("⚠️ 注意: 这可能需要较长时间,请耐心等待") +def scan_single_stock(strategy, stock_code): + """扫描单只股票""" + print(f"\n🔍 分析股票: {stock_code}") + print("-" * 40) try: - # 获取扫描股票数量配置 - scan_count = kline_strategy.config.get('scan_stocks_count', 20) - print(f"扫描股票数量: {scan_count}") + results = strategy.analyze_stock(stock_code) - results = kline_strategy.scan_market(max_stocks=scan_count) + total_signals = 0 + for timeframe, result in results.items(): + signal_count = result.get_signal_count() + total_signals += signal_count - if results: - print(f"\n📈 市场扫描结果 (发现 {len(results)} 只股票有信号):") + if signal_count > 0: + print(f"📊 {timeframe.upper()}: 发现 {signal_count} 个信号") + for i, signal in enumerate(result.signals, 1): + print(f" {i}. {signal['date']} | {signal['signal_type']} | 价格: {signal['price']:.2f}元") + else: + print(f"📭 {timeframe.upper()}: 无信号") - for stock_code, stock_results in results.items(): - total_signals = sum(len(signals) for signals in stock_results.values()) - print(f"\n股票: {stock_code} (共{total_signals}个信号)") + print(f"\n📈 总计: {total_signals} 个信号") - for timeframe, signals in stock_results.items(): - if signals: - print(f" {timeframe}: {len(signals)}个信号") - # 只显示最新的信号 - latest_signal = signals[-1] - print(f" 最新: {latest_signal['date']} 突破价格 {latest_signal['breakout_price']:.2f}元") + except Exception as e: + logger.error(f"分析失败: {e}") + print(f"❌ 分析失败: {e}") - else: - print("未发现任何K线形态信号") + +def scan_market(executor, max_stocks): + """扫描市场热门股票""" + print(f"\n🌍 扫描市场热门股票 (前{max_stocks}只)") + print("-" * 50) + + try: + result = executor.execute_task( + task_id=f"market_scan_{max_stocks}", + strategy_id="kline_pattern", + stock_pool_rule="tushare_hot", + stock_pool_params={"limit": max_stocks * 2}, + max_stocks=max_stocks, + send_notification=False + ) + + summary = result.get_summary() + print(f"📊 扫描结果:") + print(f" 股票池: {summary['stock_pool_rule_display']}") + print(f" 总扫描: {summary['total_stocks_analyzed']} 只") + print(f" 有信号: {summary['stocks_with_signals']} 只") + print(f" 信号数: {summary['total_signals_found']} 个") + print(f" 耗时: {summary['execution_time']:.2f} 秒") + + if result.strategy_results: + print(f"\n🎯 信号详情:") + for stock_code, timeframe_results in result.strategy_results.items(): + for timeframe, strategy_result in timeframe_results.items(): + if strategy_result.get_signal_count() > 0: + stock_name = executor.stock_pool_manager.fetcher.get_stock_name(stock_code) + print(f" 📈 {stock_code}({stock_name}): {strategy_result.get_signal_count()} 个信号") except Exception as e: logger.error(f"市场扫描失败: {e}") - print(f"扫描失败: {e}") + print(f"❌ 扫描失败: {e}") -def test_notification(notification_manager: NotificationManager): - """测试通知功能""" - print("\n正在测试通知功能...") +def show_stock_pools(pool_manager): + """显示可用股票池规则""" + print(f"\n📋 可用股票池规则:") + print("-" * 40) + + rules = pool_manager.get_available_rules() + for rule_id, rule_name in rules.items(): + print(f" 🎯 {rule_id}: {rule_name}") + + print(f"\n💡 使用方法: task <规则名> <股票数量>") + print(f" 示例: task tushare_hot 20") + + +def execute_task(executor, rule, max_stocks): + """执行策略任务""" + print(f"\n⚡ 执行策略任务") + print(f"股票池规则: {rule}") + print(f"最大股票数: {max_stocks}") + print("-" * 40) try: - # 发送测试消息 - success = notification_manager.send_test_message() - - if success: - print("✅ 通知测试成功") - else: - print("❌ 通知测试失败,请检查配置") - - # 发送策略信号测试 - test_success = notification_manager.send_strategy_signal( - stock_code="000001.SZ", - stock_name="平安银行", - timeframe="daily", - signal_type="测试信号", - price=10.50, - signal_date="2024-01-15", - additional_info={ - "测试项目": "通知功能", - "发送时间": "现在" - } + result = executor.execute_task( + task_id=f"manual_{rule}_{max_stocks}", + strategy_id="kline_pattern", + stock_pool_rule=rule, + max_stocks=max_stocks, + send_notification=False ) - if test_success: - print("✅ 策略信号通知测试成功") - else: - print("❌ 策略信号通知测试失败") + summary = result.get_summary() + print(f"✅ 任务完成:") + print(f" 任务ID: {summary['task_id']}") + print(f" 成功: {'是' if summary['success'] else '否'}") + print(f" 耗时: {summary['execution_time']:.2f} 秒") + print(f" 信号数: {summary['total_signals_found']} 个") + + if summary['error']: + print(f" 错误: {summary['error']}") except Exception as e: - logger.error(f"通知测试失败: {e}") - print(f"测试失败: {e}") + logger.error(f"任务执行失败: {e}") + print(f"❌ 任务失败: {e}") + + +def show_schedule_examples(scheduler, executor): + """显示定时任务配置示例""" + print(f"\n⏰ 定时任务配置示例:") + print("-" * 50) + + examples = [ + { + "name": "开盘前热门股扫描", + "rule": "weekdays at 09:00", + "desc": "每个工作日9点扫描同花顺热榜" + }, + { + "name": "午间龙头股扫描", + "rule": "weekdays at 12:30", + "desc": "每个工作日12:30扫描龙头股" + }, + { + "name": "收盘后综合扫描", + "rule": "weekdays at 15:30", + "desc": "每个工作日15:30综合扫描" + }, + { + "name": "高频监控", + "rule": "every 15 minutes", + "desc": "每15分钟监控一次" + } + ] + + for i, example in enumerate(examples, 1): + print(f" {i}. {example['name']}") + print(f" 规则: {example['rule']}") + print(f" 说明: {example['desc']}") + print() + + print("💡 要启用定时任务,请参考 examples/ 目录下的配置示例") if __name__ == "__main__": - main() \ No newline at end of file + sys.exit(main()) \ No newline at end of file diff --git a/market_scanner.py b/market_scanner.py new file mode 100644 index 0000000..519ca86 --- /dev/null +++ b/market_scanner.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python3 +""" +市场扫描定时任务脚本 +专门用于Docker容器中定时执行市场扫描 +""" + +import sys +import os +from pathlib import Path + +# 将src目录添加到Python路径 +current_dir = Path(__file__).parent +src_dir = current_dir / "src" +sys.path.insert(0, str(src_dir)) + +from loguru import logger +from src.utils.config_loader import config_loader +from src.data.tushare_fetcher import TushareFetcher +from src.data.stock_pool_manager import StockPoolManager +from src.strategy.kline_pattern_strategy import KLinePatternStrategy +from src.execution.strategy_executor import StrategyExecutor +from src.utils.notification import NotificationManager +from datetime import datetime + + +def setup_logging(): + """设置日志配置""" + log_level = os.environ.get('LOG_LEVEL', 'INFO') + logger.remove() + logger.add( + sys.stdout, + level=log_level, + format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}" + ) + + # 确定日志目录 - Docker环境使用/app/logs,本地环境使用./logs + if os.path.exists('/app'): + log_dir = "/app/logs" + else: + log_dir = "./logs" + + # 创建日志目录 + os.makedirs(log_dir, exist_ok=True) + + # 添加文件日志 + log_file = os.path.join(log_dir, "market_scanner.log") + logger.add( + log_file, + level=log_level, + format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}", + rotation="1 day", + retention="30 days" + ) + + +def create_strategy_system(): + """创建策略系统的所有组件""" + # 数据层 + fetcher = TushareFetcher() + pool_manager = StockPoolManager(fetcher) + + # 通知层 + notification_config = config_loader.get('notification', {}) + notification_manager = NotificationManager(notification_config) + + # 策略层 + strategy_config = config_loader.get('strategy', {}).get('kline_pattern', { + 'min_entity_ratio': 0.55, + 'final_yang_min_ratio': 0.40, + 'max_turnover_ratio': 40.0, + 'timeframes': ['daily'], + 'pullback_tolerance': 0.02, + 'monitor_days': 30, + 'pullback_confirmation_days': 7 + }) + + # 数据库层 + from src.database.mysql_database_manager import MySQLDatabaseManager + db_manager = MySQLDatabaseManager() + + kline_strategy = KLinePatternStrategy( + data_fetcher=fetcher, + notification_manager=notification_manager, + config=strategy_config, + db_manager=db_manager + ) + + # 执行层 + executor = StrategyExecutor(pool_manager, notification_manager) + + # 注册策略 + executor.register_strategy("kline_pattern", kline_strategy) + + return executor + + +def scan_market(max_stocks=200): + """执行市场扫描""" + logger.info(f"🚀 开始市场扫描任务 - 扫描前{max_stocks}只热门股票") + + try: + # 初始化系统 + executor = create_strategy_system() + + # 执行扫描任务 + task_id = f"market_scan_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + result = executor.execute_task( + task_id=task_id, + strategy_id="kline_pattern", + stock_pool_rule="tushare_hot", + stock_pool_params={"limit": max_stocks}, + max_stocks=max_stocks, + send_notification=True # 启用通知 + ) + + summary = result.get_summary() + logger.info(f"✅ 市场扫描完成:") + logger.info(f" 任务ID: {summary['task_id']}") + logger.info(f" 股票池: {summary['stock_pool_rule_display']}") + logger.info(f" 总扫描: {summary['total_stocks_analyzed']} 只") + logger.info(f" 有信号: {summary['stocks_with_signals']} 只") + logger.info(f" 信号数: {summary['total_signals_found']} 个") + logger.info(f" 耗时: {summary['execution_time']:.2f} 秒") + + if summary['error']: + logger.error(f" 错误: {summary['error']}") + return 1 + + return 0 + + except Exception as e: + logger.error(f"❌ 市场扫描失败: {e}") + return 1 + + +def main(): + """主函数""" + setup_logging() + + # 从环境变量或命令行参数获取扫描数量 + max_stocks = 200 + if len(sys.argv) > 1: + try: + max_stocks = int(sys.argv[1]) + except ValueError: + logger.warning(f"无效的股票数量参数: {sys.argv[1]},使用默认值: 200") + + # 从环境变量获取 + max_stocks = int(os.environ.get('MARKET_SCAN_STOCKS', max_stocks)) + + logger.info(f"📊 市场扫描参数: 最大股票数={max_stocks}") + + return scan_market(max_stocks) + + +if __name__ == "__main__": + sys.exit(main()) \ No newline at end of file diff --git a/quick_test.py b/quick_test.py deleted file mode 100644 index aff4d2a..0000000 --- a/quick_test.py +++ /dev/null @@ -1,99 +0,0 @@ -#!/usr/bin/env python3 -""" -K线形态策略快速测试 -""" - -import sys -from pathlib import Path - -# 将src目录添加到Python路径 -current_dir = Path(__file__).parent -src_dir = current_dir / "src" -sys.path.insert(0, str(src_dir)) - -from src.data.data_fetcher import ADataFetcher -from src.utils.notification import NotificationManager -from src.strategy.kline_pattern_strategy import KLinePatternStrategy - - -def quick_test(): - """快速测试策略功能""" - print("🚀 K线形态策略快速测试") - print("=" * 50) - - # 配置 - strategy_config = { - 'min_entity_ratio': 0.55, - 'timeframes': ['daily'], - 'scan_stocks_count': 3, # 只测试3只股票 - 'analysis_days': 30 - } - - notification_config = { - 'dingtalk': { - 'enabled': False, # 测试时关闭钉钉通知 - 'webhook_url': '' - } - } - - try: - # 初始化组件 - print("📊 初始化组件...") - data_fetcher = ADataFetcher() - notification_manager = NotificationManager(notification_config) - strategy = KLinePatternStrategy(data_fetcher, notification_manager, strategy_config) - - # 测试1: 单股分析 - print("\n🔍 测试1: 单股K线形态分析") - test_stock = "000001.SZ" - results = strategy.analyze_stock(test_stock) - - total_signals = sum(len(signals) for signals in results.values()) - print(f"✅ {test_stock} 分析完成: {total_signals} 个信号") - - # 测试2: 市场扫描 - print("\n🌍 测试2: 市场形态扫描") - market_results = strategy.scan_market(max_stocks=3) - - total_stocks_with_signals = len(market_results) - total_market_signals = sum( - sum(len(signals) for signals in stock_results.values()) - for stock_results in market_results.values() - ) - print(f"✅ 市场扫描完成: {total_stocks_with_signals} 只股票有信号,共 {total_market_signals} 个信号") - - # 测试3: 通知功能 - print("\n📱 测试3: 通知系统") - notification_success = notification_manager.send_strategy_signal( - stock_code="TEST001", - stock_name="测试股票", - timeframe="daily", - signal_type="快速测试信号", - price=12.34, - additional_info={ - "测试类型": "快速验证", - "状态": "正常" - } - ) - print(f"✅ 通知系统测试完成: {'成功' if notification_success else '失败(正常,未配置钉钉)'}") - - print("\n🎉 所有测试通过!") - print("\n📝 使用方法:") - print(" python main.py # 启动完整系统") - print(" python test_strategy.py # 详细功能测试") - - print("\n⚙️ 配置钉钉通知:") - print(" 1. 在钉钉群中添加自定义机器人") - print(" 2. 复制webhook地址到 config/config.yaml") - print(" 3. 设置 notification.dingtalk.enabled: true") - - except Exception as e: - print(f"❌ 测试失败: {e}") - return False - - return True - - -if __name__ == "__main__": - success = quick_test() - sys.exit(0 if success else 1) \ No newline at end of file diff --git a/scripts/init_container.sh b/scripts/init_container.sh index 1b725a5..2b58008 100644 --- a/scripts/init_container.sh +++ b/scripts/init_container.sh @@ -40,7 +40,7 @@ trading: data: # 数据源配置 sources: - primary: "adata" + primary: "tushare" # 数据更新频率 update_frequency: diff --git a/src/data/data_fetcher.py b/src/data/data_fetcher.py deleted file mode 100644 index aafa3ff..0000000 --- a/src/data/data_fetcher.py +++ /dev/null @@ -1,863 +0,0 @@ -""" -A股数据获取模块 -使用adata库获取A股市场数据 -""" - -import adata -import pandas as pd -from typing import List, Optional, Union -from datetime import datetime, date -import time -from loguru import logger - - -class ADataFetcher: - """A股数据获取器""" - - def __init__(self): - """初始化数据获取器""" - self.client = adata - - # 股票名称缓存机制 - self._stock_name_cache = {} - self._stock_list_cache = None - self._hot_stocks_cache = None - self._east_stocks_cache = None - self._cache_timestamp = None - self._cache_duration = 3600 # 缓存1小时 - - logger.info("AData客户端初始化完成") - - def get_stock_list(self, market: str = "A") -> pd.DataFrame: - """ - 获取股票列表 - - Args: - market: 市场类型,默认为A股 - - Returns: - 股票列表DataFrame - """ - try: - stock_list = self.client.stock.info.all_code() - logger.info(f"获取股票列表成功,共{len(stock_list)}只股票") - return stock_list - except Exception as e: - logger.error(f"获取股票列表失败: {e}") - return pd.DataFrame() - - def get_filtered_a_share_list(self, exclude_st: bool = True, exclude_bj: bool = True, min_market_cap: float = 2000000000) -> pd.DataFrame: - """ - 获取过滤后的A股股票列表 - - Args: - exclude_st: 是否排除ST股票 - exclude_bj: 是否排除北交所股票 - min_market_cap: 最小市值要求(元),默认20亿 - - Returns: - 过滤后的股票列表DataFrame - """ - try: - # 获取完整股票列表 - all_stocks = self.get_stock_list() - - if all_stocks.empty: - return pd.DataFrame() - - filtered_stocks = all_stocks.copy() - original_count = len(filtered_stocks) - - # 排除北交所股票 - if exclude_bj: - before_count = len(filtered_stocks) - filtered_stocks = filtered_stocks[filtered_stocks['exchange'] != 'BJ'] - bj_excluded = before_count - len(filtered_stocks) - logger.info(f"排除北交所股票: {bj_excluded}只") - - # 排除ST股票(包含ST、*ST、PT、退市等) - if exclude_st: - before_count = len(filtered_stocks) - # 排除包含ST、*ST、PT、退等字符的股票 - st_pattern = r'(\*?ST|PT|退|暂停)' - filtered_stocks = filtered_stocks[~filtered_stocks['short_name'].str.contains(st_pattern, na=False, case=False)] - st_excluded = before_count - len(filtered_stocks) - logger.info(f"排除ST等风险股票: {st_excluded}只") - - # 基于实际市值的筛选 - if min_market_cap > 0: - before_count = len(filtered_stocks) - filtered_stocks = self._filter_by_real_market_cap(filtered_stocks, min_market_cap) - cap_excluded = before_count - len(filtered_stocks) - logger.info(f"排除小市值股票(基于实际市值): {cap_excluded}只") - - # 统计最终结果 - final_count = len(filtered_stocks) - excluded_count = original_count - final_count - - # 添加完整股票代码(带交易所后缀) - if not filtered_stocks.empty and 'exchange' in filtered_stocks.columns: - filtered_stocks['full_stock_code'] = filtered_stocks.apply( - lambda row: f"{row['stock_code']}.{row['exchange']}", axis=1 - ) - - exchange_counts = filtered_stocks['exchange'].value_counts().to_dict() - exchange_detail = " | ".join([f"{k}: {v}只" for k, v in exchange_counts.items()]) - logger.info(f"✅ 获取过滤后A股列表成功") - logger.info(f"📊 原始股票: {original_count}只 | 过滤后: {final_count}只 | 排除: {excluded_count}只") - logger.info(f"📈 交易所分布: {exchange_detail}") - - return filtered_stocks - - except Exception as e: - logger.error(f"获取过滤A股列表失败: {e}") - return pd.DataFrame() - - def _filter_by_real_market_cap(self, stock_df: pd.DataFrame, min_market_cap: float) -> pd.DataFrame: - """ - 基于实际市值筛选股票 - 由于API限制,先使用启发式规则预筛选,再对部分股票进行实际市值验证 - - Args: - stock_df: 股票列表DataFrame - min_market_cap: 最小市值要求(元) - - Returns: - 过滤后的股票DataFrame - """ - if stock_df.empty: - return stock_df - - logger.info(f"开始基于市值筛选股票,阈值: {min_market_cap/100000000:.0f}亿元") - - # 步骤1: 使用启发式规则进行预筛选,减少API调用量 - logger.info("步骤1: 使用启发式规则预筛选...") - pre_filtered = self._filter_by_market_cap_proxy(stock_df, min_market_cap) - - if pre_filtered.empty: - logger.warning("启发式预筛选后无股票,返回空结果") - return pd.DataFrame() - - logger.info(f"启发式预筛选完成: {len(pre_filtered)}/{len(stock_df)} 只股票") - - # 步骤2: 对预筛选的股票进行小批量实际市值验证 - logger.info("步骤2: 对预筛选股票进行实际市值验证...") - - # 限制验证数量以避免API超时 - max_verify_count = min(500, len(pre_filtered)) # 最多验证500只 - stocks_to_verify = pre_filtered.head(max_verify_count) - - logger.info(f"将验证 {len(stocks_to_verify)} 只股票的实际市值") - - valid_stocks = [] - total_to_verify = len(stocks_to_verify) - - for idx, (_, stock) in enumerate(stocks_to_verify.iterrows()): - stock_code = stock['stock_code'] - exchange = stock['exchange'] - full_stock_code = f"{stock_code}.{exchange}" - - try: - # 获取股本信息 - shares_info = self.client.stock.info.get_stock_shares(stock_code=stock_code) - - if not shares_info.empty and 'total_share' in shares_info.columns: - total_shares = shares_info.iloc[0]['total_share'] - - # 获取当前股价 - current_price = None - try: - market_data = self.client.stock.market.get_market(full_stock_code) - if not market_data.empty and 'close' in market_data.columns: - current_price = market_data.iloc[0]['close'] - except: - pass - - # 计算市值 - if current_price is not None and total_shares > 0: - market_cap = total_shares * 10000 * current_price # 万股转换为股 - - if market_cap >= min_market_cap: - stock_with_cap = stock.copy() - stock_with_cap['market_cap'] = market_cap - stock_with_cap['total_shares'] = total_shares - stock_with_cap['current_price'] = current_price - valid_stocks.append(stock_with_cap) - - logger.debug(f"{full_stock_code}: 市值{market_cap/100000000:.1f}亿元 {'✓' if market_cap >= min_market_cap else '✗'}") - else: - # 如果无法获取实际市值,且预筛选通过,则保留 - valid_stocks.append(stock) - logger.debug(f"{full_stock_code}: 无市值数据,保留预筛选结果") - - except Exception as e: - # 如果API调用失败,且预筛选通过,则保留 - valid_stocks.append(stock) - logger.debug(f"{full_stock_code}: API失败,保留预筛选结果: {e}") - - # 显示进度 - if (idx + 1) % 50 == 0 or idx + 1 == total_to_verify: - logger.info(f"市值验证进度: {idx + 1}/{total_to_verify} ({(idx + 1)/total_to_verify*100:.1f}%)") - - # 添加延时以避免API限制 - if idx % 10 == 9: # 每10个请求休息0.1秒 - time.sleep(0.1) - - # 步骤3: 对于剩余未验证的股票,直接使用预筛选结果 - if len(pre_filtered) > max_verify_count: - remaining_stocks = pre_filtered.iloc[max_verify_count:].copy() - for _, stock in remaining_stocks.iterrows(): - valid_stocks.append(stock) - - logger.info(f"保留 {len(remaining_stocks)} 只未验证股票(基于预筛选结果)") - - # 转换为DataFrame - if valid_stocks: - result_df = pd.DataFrame(valid_stocks) - - # 确保没有重复 - if 'stock_code' in result_df.columns: - result_df = result_df.drop_duplicates(subset=['stock_code'], keep='first') - - logger.info(f"✅ 市值筛选完成: {len(result_df)}/{len(stock_df)} 只股票符合要求") - - # 统计实际验证vs预筛选的结果 - verified_count = min(max_verify_count, len(stocks_to_verify)) - unverified_count = len(result_df) - verified_count - logger.info(f"📊 验证详情: 实际验证{verified_count}只, 预筛选保留{unverified_count}只") - - return result_df - else: - logger.warning("⚠️ 没有股票通过市值筛选") - return pd.DataFrame() - - def _filter_by_market_cap_proxy(self, stock_df: pd.DataFrame, min_market_cap: float) -> pd.DataFrame: - """ - 基于股票代码的启发式规则筛选大市值股票 - - Args: - stock_df: 股票列表DataFrame - min_market_cap: 最小市值要求(元) - - Returns: - 过滤后的股票DataFrame - """ - if stock_df.empty: - return stock_df - - # 由于无法直接获取市值数据,使用启发式规则进行筛选 - # 注意:这只是一个近似筛选,真实的市值筛选需要实际的市值数据 - - def is_likely_large_cap(stock_code: str, exchange: str) -> bool: - """判断股票是否可能是大市值股票""" - code_num = stock_code - - if exchange == 'SH': # 上交所 - # 主板: 600xxx, 601xxx, 603xxx, 605xxx (通常市值较大) - if code_num.startswith(('600', '601', '603', '605')): - return True - # 科创板: 688xxx (新兴科技公司,市值相对较大) - elif code_num.startswith('688'): - return True - # 其他上交所股票 - return False - - elif exchange == 'SZ': # 深交所 - # 主板: 000xxx, 001xxx (老牌蓝筹,通常市值较大) - if code_num.startswith(('000', '001')): - return True - # 中小板: 002xxx (部分有大市值公司) - elif code_num.startswith('002'): - # 002开头的前1000只股票(002000-002999),上市较早,可能市值较大 - try: - code_suffix = int(code_num[3:]) - return code_suffix <= 999 # 002000-002999 - except: - return False - # 创业板: 300xxx, 301xxx (部分成长为大市值) - elif code_num.startswith(('300', '301')): - # 300开头的前500只股票,上市较早,部分已成长为大市值 - try: - if code_num.startswith('300'): - code_suffix = int(code_num[3:]) - return code_suffix <= 499 # 300000-300499 - else: # 301xxx较新,市值相对较小 - return False - except: - return False - return False - - return False # 其他情况默认排除 - - # 应用筛选规则 - if min_market_cap >= 2000000000: # 20亿以上 - logger.info(f"应用大市值筛选规则(≥{min_market_cap/100000000:.0f}亿元)") - mask = stock_df.apply(lambda row: is_likely_large_cap(row['stock_code'], row['exchange']), axis=1) - return stock_df[mask] - else: - # 小于20亿的筛选条件暂时不实施严格筛选 - logger.info(f"市值筛选阈值较低({min_market_cap/100000000:.1f}亿元),保留所有股票") - return stock_df - - def get_realtime_data(self, stock_codes: Union[str, List[str]]) -> pd.DataFrame: - """ - 获取实时行情数据 - - Args: - stock_codes: 股票代码或代码列表 - - Returns: - 实时行情DataFrame - """ - try: - if isinstance(stock_codes, str): - stock_codes = [stock_codes] - - realtime_data = self.client.stock.market.get_market(stock_codes) - logger.info(f"获取实时数据成功,股票数量: {len(stock_codes)}") - return realtime_data - except Exception as e: - logger.error(f"获取实时数据失败: {e}") - return pd.DataFrame() - - def get_historical_data( - self, - stock_code: str, - start_date: Union[str, date], - end_date: Union[str, date], - period: str = "daily" - ) -> pd.DataFrame: - """ - 获取历史行情数据 - - Args: - stock_code: 股票代码 - start_date: 开始日期 - end_date: 结束日期 - period: 数据周期 ('daily', 'weekly', 'monthly') - - Returns: - 历史行情DataFrame - """ - try: - # 转换日期格式 - if isinstance(start_date, date): - start_date = start_date.strftime("%Y-%m-%d") - if isinstance(end_date, date): - end_date = end_date.strftime("%Y-%m-%d") - - # 根据周期设置k_type参数 - k_type_map = { - 'daily': 1, # 日线 - 'weekly': 2, # 周线 - 'monthly': 3 # 月线 - } - k_type = k_type_map.get(period, 1) - - # 尝试获取数据 - hist_data = pd.DataFrame() - - # 方法1: 使用get_market获取指定周期数据 - try: - hist_data = self.client.stock.market.get_market( - stock_code, - k_type=k_type, - start_date=start_date, - end_date=end_date - ) - except Exception as e: - logger.debug(f"get_market失败: {e}") - - # 方法2: 如果方法1失败,尝试get_market_bar - if hist_data.empty: - try: - hist_data = self.client.stock.market.get_market_bar( - stock_code=stock_code, - start_date=start_date, - end_date=end_date - ) - except Exception as e: - logger.debug(f"get_market_bar失败: {e}") - - # 方法3: 如果以上都失败,生成模拟数据用于测试 - if hist_data.empty: - logger.warning(f"无法获取{stock_code}真实数据,生成模拟数据用于测试") - hist_data = self._generate_mock_data(stock_code, start_date, end_date) - - if not hist_data.empty: - logger.info(f"获取{stock_code}历史数据成功,数据量: {len(hist_data)}") - else: - logger.warning(f"获取{stock_code}历史数据为空") - - return hist_data - - except Exception as e: - logger.error(f"获取{stock_code}历史数据失败: {e}") - # 返回模拟数据作为后备 - return self._generate_mock_data(stock_code, start_date, end_date) - - def _generate_mock_data(self, stock_code: str, start_date: str, end_date: str) -> pd.DataFrame: - """ - 生成模拟K线数据用于测试 - - Args: - stock_code: 股票代码 - start_date: 开始日期 - end_date: 结束日期 - - Returns: - 模拟K线数据 - """ - try: - import numpy as np - from datetime import datetime, timedelta - - start = datetime.strptime(start_date, "%Y-%m-%d") - end = datetime.strptime(end_date, "%Y-%m-%d") - - # 生成交易日期(排除周末) - dates = [] - current = start - while current <= end: - if current.weekday() < 5: # 周一到周五 - dates.append(current) - current += timedelta(days=1) - - if not dates: - return pd.DataFrame() - - n = len(dates) - - # 生成模拟价格数据 - 创建一个包含我们需要形态的序列 - base_price = 10.0 - prices = [] - - # 设置随机种子以获得可重现的结果 - np.random.seed(hash(stock_code) % 1000) - - for i in range(n): - # 在某些位置插入"两阳线+阴线+阳线"形态 - if i % 20 == 10 and i < n - 4: # 每20个交易日插入一次形态 - # 两阳线 - prices.extend([ - base_price + 0.5, # 阳线1 - base_price + 1.0, # 阳线2 - base_price + 0.3, # 阴线 - base_price + 1.5 # 突破阳线 - ]) - i += 3 # 跳过已生成的数据点 - else: - # 正常随机价格 - change = np.random.uniform(-0.5, 0.5) - base_price = max(5.0, base_price + change) # 确保价格不会太低 - prices.append(base_price) - - # 确保价格数组长度匹配日期数量 - while len(prices) < n: - prices.append(base_price + np.random.uniform(-0.2, 0.2)) - prices = prices[:n] - - # 生成OHLC数据 - data = [] - for i, (date, close) in enumerate(zip(dates, prices)): - # 生成开盘价 - if i == 0: - open_price = close - np.random.uniform(-0.3, 0.3) - else: - open_price = prices[i-1] + np.random.uniform(-0.2, 0.2) - - # 确保高低价格的合理性 - high = max(open_price, close) + np.random.uniform(0, 0.5) - low = min(open_price, close) - np.random.uniform(0, 0.3) - - # 确保价格顺序正确 - low = max(0.1, low) # 确保最低价格为正数 - high = max(low + 0.1, high) # 确保最高价高于最低价 - - data.append({ - 'trade_date': date.strftime('%Y-%m-%d'), - 'open': round(open_price, 2), - 'high': round(high, 2), - 'low': round(low, 2), - 'close': round(close, 2), - 'volume': int(np.random.uniform(1000, 10000)) - }) - - mock_df = pd.DataFrame(data) - logger.info(f"生成{stock_code}模拟数据,数据量: {len(mock_df)}") - return mock_df - - except Exception as e: - logger.error(f"生成模拟数据失败: {e}") - return pd.DataFrame() - - def get_index_data(self, index_code: str = "000001.SH") -> pd.DataFrame: - """ - 获取指数数据 - - Args: - index_code: 指数代码 - - Returns: - 指数数据DataFrame - """ - try: - index_data = self.client.stock.market.get_market(index_code) - logger.info(f"获取指数{index_code}数据成功") - return index_data - except Exception as e: - logger.error(f"获取指数数据失败: {e}") - return pd.DataFrame() - - def get_financial_data(self, stock_code: str) -> pd.DataFrame: - """ - 获取财务数据 - - Args: - stock_code: 股票代码 - - Returns: - 财务数据DataFrame - """ - try: - financial_data = self.client.stock.info.financial(stock_code) - logger.info(f"获取{stock_code}财务数据成功") - return financial_data - except Exception as e: - logger.error(f"获取财务数据失败: {e}") - return pd.DataFrame() - - def search_stocks(self, keyword: str) -> pd.DataFrame: - """ - 搜索股票(基于本地股票列表) - - Args: - keyword: 搜索关键词 - - Returns: - 搜索结果DataFrame - """ - try: - # 获取完整股票列表 - all_stocks = self.get_stock_list() - if all_stocks.empty: - return pd.DataFrame() - - # 在股票代码和名称中搜索关键词 - keyword = str(keyword).strip() - if not keyword: - return pd.DataFrame() - - # 支持按代码或名称模糊搜索 - mask = ( - all_stocks['stock_code'].str.contains(keyword, case=False, na=False) | - all_stocks['short_name'].str.contains(keyword, case=False, na=False) - ) - - results = all_stocks[mask].copy() - logger.info(f"搜索股票'{keyword}'成功,找到{len(results)}个结果") - return results - except Exception as e: - logger.error(f"搜索股票失败: {e}") - return pd.DataFrame() - - def get_hot_stocks_ths(self, limit: int = 100) -> pd.DataFrame: - """ - 获取同花顺热股TOP100 - - Args: - limit: 返回的热股数量,默认100 - - Returns: - 热股数据DataFrame,包含股票代码、名称、涨跌幅等信息 - """ - try: - # 获取同花顺热股TOP100 - hot_stocks = self.client.sentiment.hot.hot_rank_100_ths() - - if not hot_stocks.empty: - # 限制返回数量 - hot_stocks = hot_stocks.head(limit) - logger.info(f"获取同花顺热股成功,共{len(hot_stocks)}只股票") - return hot_stocks - else: - logger.warning("获取同花顺热股数据为空") - return pd.DataFrame() - - except Exception as e: - logger.error(f"获取同花顺热股失败: {e}") - # 返回空DataFrame作为后备 - return pd.DataFrame() - - def get_popular_stocks_east(self, limit: int = 100) -> pd.DataFrame: - """ - 获取东方财富人气榜TOP100 - - Args: - limit: 返回的人气股数量,默认100 - - Returns: - 人气股数据DataFrame,包含股票代码、名称、涨跌幅等信息 - """ - try: - # 获取东方财富人气榜TOP100 - popular_stocks = self.client.sentiment.hot.pop_rank_100_east() - - if not popular_stocks.empty: - # 限制返回数量 - popular_stocks = popular_stocks.head(limit) - logger.info(f"获取东财人气股成功,共{len(popular_stocks)}只股票") - return popular_stocks - else: - logger.warning("获取东财人气股数据为空") - return pd.DataFrame() - - except Exception as e: - logger.error(f"获取东财人气股失败: {e}") - # 返回空DataFrame作为后备 - return pd.DataFrame() - - def _is_cache_valid(self) -> bool: - """检查缓存是否有效""" - if self._cache_timestamp is None: - return False - - import time - return (time.time() - self._cache_timestamp) < self._cache_duration - - def _update_stock_name_cache(self): - """更新股票名称缓存""" - try: - import time - - # 检查缓存是否有效 - if self._is_cache_valid(): - return - - logger.info("🔄 更新股票名称缓存...") - - # 获取热门股票数据并缓存 - self._hot_stocks_cache = self.get_hot_stocks_ths(limit=100) - self._east_stocks_cache = self.get_popular_stocks_east(limit=100) - - # 清空名称缓存并重新构建 - self._stock_name_cache.clear() - - # 从热门股票数据中构建缓存 - for df, source in [(self._hot_stocks_cache, '同花顺'), (self._east_stocks_cache, '东财')]: - if not df.empty and 'stock_code' in df.columns and 'short_name' in df.columns: - for _, row in df.iterrows(): - stock_code = row['stock_code'] - stock_name = row['short_name'] - if stock_code not in self._stock_name_cache: - self._stock_name_cache[stock_code] = stock_name - - # 更新缓存时间戳 - self._cache_timestamp = time.time() - - logger.info(f"✅ 股票名称缓存更新完成,共缓存 {len(self._stock_name_cache)} 只股票") - - except Exception as e: - logger.warning(f"更新股票名称缓存失败: {e}") - - def get_stock_name(self, stock_code: str) -> str: - """ - 获取股票中文名称(带缓存机制) - - Args: - stock_code: 股票代码 - - Returns: - 股票中文名称,如果获取失败返回股票代码 - """ - try: - # 更新缓存(如果需要) - self._update_stock_name_cache() - - # 从缓存中查找 - if stock_code in self._stock_name_cache: - return self._stock_name_cache[stock_code] - - # 缓存中没有,尝试搜索功能 - search_results = self.search_stocks(stock_code) - if not search_results.empty and 'short_name' in search_results.columns: - stock_name = search_results.iloc[0]['short_name'] - # 添加到缓存 - self._stock_name_cache[stock_code] = stock_name - return stock_name - - # 如果都失败,返回股票代码 - logger.debug(f"未能获取{stock_code}的中文名称") - return stock_code - - except Exception as e: - logger.debug(f"获取股票{stock_code}名称失败: {e}") - return stock_code - - def get_combined_hot_stocks(self, limit_per_source: int = 100, final_limit: int = 150) -> pd.DataFrame: - """ - 获取合并去重的热门股票(同花顺热股 + 东财人气榜) - - Args: - limit_per_source: 每个数据源的获取数量,默认100 - final_limit: 最终返回的股票数量,默认150 - - Returns: - 合并去重后的热门股票DataFrame - """ - try: - logger.info("开始获取合并热门股票数据...") - - # 获取同花顺热股 - ths_stocks = self.get_hot_stocks_ths(limit=limit_per_source) - - # 获取东财人气股 - east_stocks = self.get_popular_stocks_east(limit=limit_per_source) - - combined_stocks = pd.DataFrame() - - # 合并数据 - if not ths_stocks.empty and not east_stocks.empty: - # 标记数据源 - ths_stocks['source'] = '同花顺' - east_stocks['source'] = '东财' - - # 尝试合并,处理列名差异 - try: - # 统一列名映射 - ths_rename_map = {} - east_rename_map = {} - - # 检查股票代码列名 - if 'stock_code' in ths_stocks.columns: - ths_rename_map['stock_code'] = 'stock_code' - elif 'code' in ths_stocks.columns: - ths_rename_map['code'] = 'stock_code' - - if 'stock_code' in east_stocks.columns: - east_rename_map['stock_code'] = 'stock_code' - elif 'code' in east_stocks.columns: - east_rename_map['code'] = 'stock_code' - - # 重命名列名 - if ths_rename_map: - ths_stocks = ths_stocks.rename(columns=ths_rename_map) - if east_rename_map: - east_stocks = east_stocks.rename(columns=east_rename_map) - - # 确保都有stock_code列 - if 'stock_code' in ths_stocks.columns and 'stock_code' in east_stocks.columns: - # 合并数据框 - combined_stocks = pd.concat([ths_stocks, east_stocks], ignore_index=True) - - # 按股票代码去重,保留第一个出现的记录 - combined_stocks = combined_stocks.drop_duplicates(subset=['stock_code'], keep='first') - - # 限制最终数量 - combined_stocks = combined_stocks.head(final_limit) - - logger.info(f"合并热门股票成功:同花顺{len(ths_stocks)}只 + 东财{len(east_stocks)}只 → 去重后{len(combined_stocks)}只") - else: - logger.warning("股票代码列名不匹配,使用同花顺数据") - combined_stocks = ths_stocks.head(final_limit) - - except Exception as merge_error: - logger.error(f"合并数据时出错: {merge_error},使用同花顺数据") - combined_stocks = ths_stocks.head(final_limit) - - elif not ths_stocks.empty: - logger.info("仅获取到同花顺数据") - combined_stocks = ths_stocks.head(final_limit) - combined_stocks['source'] = '同花顺' - - elif not east_stocks.empty: - logger.info("仅获取到东财数据") - combined_stocks = east_stocks.head(final_limit) - combined_stocks['source'] = '东财' - - else: - logger.warning("两个数据源都未获取到数据") - return pd.DataFrame() - - return combined_stocks - - except Exception as e: - logger.error(f"获取合并热门股票失败: {e}") - return pd.DataFrame() - - def get_market_overview(self) -> dict: - """ - 获取市场概况 - - Returns: - 市场概况字典 - """ - try: - # 获取主要指数数据 - sh_index = self.get_index_data("000001.SH") # 上证指数 - sz_index = self.get_index_data("399001.SZ") # 深证成指 - cyb_index = self.get_index_data("399006.SZ") # 创业板指 - - overview = { - "update_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), - "shanghai": sh_index.iloc[0].to_dict() if not sh_index.empty else {}, - "shenzhen": sz_index.iloc[0].to_dict() if not sz_index.empty else {}, - "chinext": cyb_index.iloc[0].to_dict() if not cyb_index.empty else {} - } - - logger.info("获取市场概况成功") - return overview - except Exception as e: - logger.error(f"获取市场概况失败: {e}") - return {} - - -if __name__ == "__main__": - # 测试代码 - fetcher = ADataFetcher() - - # 测试获取股票列表 - print("测试获取股票列表...") - stock_list = fetcher.get_stock_list() - print(f"股票数量: {len(stock_list)}") - print(stock_list.head()) - - # 测试同花顺热股 - print("\n测试获取同花顺热股TOP10...") - hot_stocks = fetcher.get_hot_stocks_ths(limit=10) - if not hot_stocks.empty: - print(f"同花顺热股数量: {len(hot_stocks)}") - print(hot_stocks.head()) - else: - print("未能获取同花顺热股数据") - - # 测试东财人气股 - print("\n测试获取东财人气股TOP10...") - east_stocks = fetcher.get_popular_stocks_east(limit=10) - if not east_stocks.empty: - print(f"东财人气股数量: {len(east_stocks)}") - print(east_stocks.head()) - else: - print("未能获取东财人气股数据") - - # 测试合并热门股票 - print("\n测试获取合并热门股票TOP15...") - combined_stocks = fetcher.get_combined_hot_stocks(limit_per_source=10, final_limit=15) - if not combined_stocks.empty: - print(f"合并后股票数量: {len(combined_stocks)}") - if 'source' in combined_stocks.columns: - source_counts = combined_stocks['source'].value_counts().to_dict() - print(f"数据源分布: {source_counts}") - print(combined_stocks[['stock_code', 'source'] if 'source' in combined_stocks.columns else ['stock_code']].head()) - else: - print("未能获取合并热门股票数据") - - # 测试搜索功能 - print("\n测试搜索功能...") - search_results = fetcher.search_stocks("平安") - print(search_results.head()) - - # 测试获取市场概况 - print("\n测试获取市场概况...") - overview = fetcher.get_market_overview() - print(overview) \ No newline at end of file diff --git a/src/data/stock_pool_manager.py b/src/data/stock_pool_manager.py new file mode 100644 index 0000000..321fb7a --- /dev/null +++ b/src/data/stock_pool_manager.py @@ -0,0 +1,263 @@ +""" +股票池管理器 +负责根据不同规则获取和管理股票池 +""" + +from typing import List, Dict, Any, Optional +import pandas as pd +from loguru import logger +from abc import ABC, abstractmethod +from src.data.tushare_fetcher import TushareFetcher + + +class StockPoolRule(ABC): + """股票池规则抽象基类""" + + @abstractmethod + def get_stocks(self, fetcher: TushareFetcher, **kwargs) -> List[str]: + """ + 获取股票列表 + + Args: + fetcher: 数据获取器 + **kwargs: 规则参数 + + Returns: + 股票代码列表 + """ + pass + + @abstractmethod + def get_rule_name(self) -> str: + """获取规则名称""" + pass + + +class TushareHotStocksRule(StockPoolRule): + """同花顺热榜股票池规则""" + + def get_stocks(self, fetcher: TushareFetcher, limit: int = 50, **kwargs) -> List[str]: + """获取同花顺热榜股票""" + try: + hot_stocks = fetcher.get_hot_stocks_ths(limit=limit) + if not hot_stocks.empty and 'stock_code' in hot_stocks.columns: + stocks = hot_stocks['stock_code'].tolist() + logger.info(f"✅ 同花顺热榜获取成功: {len(stocks)}只股票") + return stocks + else: + logger.warning("同花顺热榜数据为空") + return [] + except Exception as e: + logger.error(f"获取同花顺热榜失败: {e}") + return [] + + def get_rule_name(self) -> str: + return "同花顺热榜" + + +class CombinedHotStocksRule(StockPoolRule): + """合并热门股票池规则(同花顺+东财)""" + + def get_stocks(self, fetcher: TushareFetcher, limit_per_source: int = 30, final_limit: int = 50, **kwargs) -> List[str]: + """获取合并热门股票""" + try: + combined_stocks = fetcher.get_combined_hot_stocks( + limit_per_source=limit_per_source, + final_limit=final_limit + ) + if not combined_stocks.empty and 'stock_code' in combined_stocks.columns: + stocks = combined_stocks['stock_code'].tolist() + logger.info(f"✅ 合并热门股票获取成功: {len(stocks)}只股票") + return stocks + else: + logger.warning("合并热门股票数据为空") + return [] + except Exception as e: + logger.error(f"获取合并热门股票失败: {e}") + return [] + + def get_rule_name(self) -> str: + return "合并热门股票" + + +class LeadingStocksRule(StockPoolRule): + """龙头牛股股票池规则""" + + def get_stocks(self, fetcher: TushareFetcher, top_boards: int = 8, stocks_per_board: int = 3, min_score: float = 60.0, **kwargs) -> List[str]: + """获取龙头牛股""" + try: + result = fetcher.get_leading_stocks_from_hot_boards( + top_boards=top_boards, + stocks_per_board=stocks_per_board, + min_score=min_score + ) + + if 'error' not in result and not result['top_leading_stocks'].empty: + stocks = result['top_leading_stocks']['stock_code'].tolist() + logger.info(f"✅ 龙头牛股获取成功: {len(stocks)}只股票") + return stocks + else: + logger.warning("龙头牛股数据为空") + return [] + except Exception as e: + logger.error(f"获取龙头牛股失败: {e}") + return [] + + def get_rule_name(self) -> str: + return "龙头牛股" + + +class CustomStockListRule(StockPoolRule): + """自定义股票列表规则""" + + def __init__(self, stock_list: List[str]): + self.stock_list = stock_list + + def get_stocks(self, fetcher: TushareFetcher, **kwargs) -> List[str]: + """返回自定义股票列表""" + logger.info(f"✅ 使用自定义股票列表: {len(self.stock_list)}只股票") + return self.stock_list.copy() + + def get_rule_name(self) -> str: + return "自定义股票列表" + + +class StockPoolManager: + """股票池管理器""" + + def __init__(self, fetcher: TushareFetcher): + """ + 初始化股票池管理器 + + Args: + fetcher: TuShare数据获取器 + """ + self.fetcher = fetcher + self.rules: Dict[str, StockPoolRule] = {} + self._register_default_rules() + + def _register_default_rules(self): + """注册默认规则""" + self.register_rule("tushare_hot", TushareHotStocksRule()) + self.register_rule("combined_hot", CombinedHotStocksRule()) + self.register_rule("leading_stocks", LeadingStocksRule()) + + def register_rule(self, rule_name: str, rule: StockPoolRule): + """ + 注册股票池规则 + + Args: + rule_name: 规则名称 + rule: 规则实例 + """ + self.rules[rule_name] = rule + logger.info(f"注册股票池规则: {rule_name} - {rule.get_rule_name()}") + + def get_stock_pool(self, rule_name: str, **kwargs) -> Dict[str, Any]: + """ + 根据规则获取股票池 + + Args: + rule_name: 规则名称 + **kwargs: 规则参数 + + Returns: + 包含股票列表和元信息的字典 + """ + if rule_name not in self.rules: + logger.error(f"未找到股票池规则: {rule_name}") + return { + 'stocks': [], + 'rule_name': rule_name, + 'rule_display_name': '未知规则', + 'total_count': 0, + 'success': False, + 'error': f'未找到规则: {rule_name}' + } + + rule = self.rules[rule_name] + + try: + logger.info(f"🔍 执行股票池规则: {rule.get_rule_name()}") + stocks = rule.get_stocks(self.fetcher, **kwargs) + + return { + 'stocks': stocks, + 'rule_name': rule_name, + 'rule_display_name': rule.get_rule_name(), + 'total_count': len(stocks), + 'success': True, + 'parameters': kwargs + } + + except Exception as e: + logger.error(f"执行股票池规则失败 {rule.get_rule_name()}: {e}") + return { + 'stocks': [], + 'rule_name': rule_name, + 'rule_display_name': rule.get_rule_name(), + 'total_count': 0, + 'success': False, + 'error': str(e) + } + + def get_available_rules(self) -> Dict[str, str]: + """ + 获取可用的规则列表 + + Returns: + 规则名称到显示名称的映射 + """ + return {name: rule.get_rule_name() for name, rule in self.rules.items()} + + def create_custom_rule(self, rule_name: str, stock_list: List[str]): + """ + 创建自定义股票列表规则 + + Args: + rule_name: 规则名称 + stock_list: 股票代码列表 + """ + custom_rule = CustomStockListRule(stock_list) + self.register_rule(rule_name, custom_rule) + + +if __name__ == "__main__": + # 测试股票池管理器 + from loguru import logger + import sys + + logger.remove() + logger.add(sys.stdout, level="INFO", format="{time:HH:mm:ss} | {level} | {message}") + + # 初始化 + fetcher = TushareFetcher() + pool_manager = StockPoolManager(fetcher) + + print("=" * 60) + print("📊 股票池管理器测试") + print("=" * 60) + + # 显示可用规则 + print("可用规则:") + for rule_id, rule_name in pool_manager.get_available_rules().items(): + print(f" {rule_id}: {rule_name}") + + # 测试同花顺热榜 + print(f"\n测试同花顺热榜:") + result = pool_manager.get_stock_pool("tushare_hot", limit=10) + if result['success']: + print(f"✅ 获取成功: {result['total_count']}只股票") + print(f"前5只: {result['stocks'][:5]}") + else: + print(f"❌ 获取失败: {result['error']}") + + # 测试自定义规则 + print(f"\n测试自定义规则:") + custom_stocks = ["000001.SZ", "000002.SZ", "600000.SH"] + pool_manager.create_custom_rule("my_custom", custom_stocks) + + result = pool_manager.get_stock_pool("my_custom") + if result['success']: + print(f"✅ 自定义规则: {result['total_count']}只股票") + print(f"股票: {result['stocks']}") \ No newline at end of file diff --git a/src/data/tushare_fetcher.py b/src/data/tushare_fetcher.py new file mode 100644 index 0000000..875691e --- /dev/null +++ b/src/data/tushare_fetcher.py @@ -0,0 +1,1560 @@ +""" +A股数据获取模块 +使用Tushare Pro API获取A股市场数据 +""" + +import tushare as ts +import pandas as pd +from typing import List, Optional, Union +from datetime import datetime, date, timedelta +import time +from loguru import logger +from functools import wraps +from src.utils.config_loader import config_loader + + +def retry_on_failure(retries: int = 3, delay: float = 1.0): + """ + 重试装饰器,用于网络请求失败时自动重试 + + Args: + retries: 重试次数 + delay: 重试间隔(秒) + """ + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + last_exception = None + + for attempt in range(retries + 1): + try: + return func(*args, **kwargs) + except Exception as e: + last_exception = e + if attempt < retries: + logger.warning(f"{func.__name__} 第{attempt + 1}次调用失败: {e}, {delay}秒后重试...") + time.sleep(delay) + else: + logger.error(f"{func.__name__} 已重试{retries}次仍然失败: {e}") + + raise last_exception + return wrapper + return decorator + + +class TushareFetcher: + """Tushare数据获取器""" + + def __init__(self, token: str = None): + """ + 初始化数据获取器 + + Args: + token: Tushare Pro token,如果为None则从配置文件读取 + """ + # 如果没有传入token,从配置文件读取 + if token is None: + token = config_loader.get_tushare_token() + if token: + logger.info("✅ 从配置文件读取TuShare token成功") + else: + logger.warning("⚠️ 配置文件中未找到TuShare token") + + self.token = token + if token: + try: + ts.set_token(token) + self.pro = ts.pro_api() + + # 验证token是否有效 + test_data = self.pro.trade_cal(exchange='SSE', cal_date='20240101', limit=1) + if not test_data.empty: + logger.info("✅ Tushare Pro客户端初始化完成,token验证成功") + else: + logger.warning("⚠️ Tushare Pro token可能无效或权限不足") + except Exception as e: + logger.error(f"❌ Tushare Pro初始化失败: {e}") + logger.warning("将回退到无Pro权限模式") + self.pro = None + else: + logger.warning("未提供Tushare token,将使用免费接口") + self.pro = None + + # 股票名称缓存机制 + self._stock_name_cache = {} + self._stock_list_cache = None + self._cache_timestamp = None + self._cache_duration = 3600 # 缓存1小时 + + def clear_caches(self): + """清除所有缓存""" + self._stock_name_cache.clear() + self._stock_list_cache = None + self._cache_timestamp = None + logger.info("🔄 已清除所有股票数据缓存") + + def get_stock_list(self, use_hot_stocks: bool = True, hot_limit: int = 100) -> pd.DataFrame: + """ + 获取股票列表 + + Args: + use_hot_stocks: 是否优先使用热门股票,默认True + hot_limit: 热门股票数量限制,默认100 + + Returns: + 股票列表DataFrame + """ + import time + current_time = time.time() + + # 检查缓存是否有效 + if (self._stock_list_cache is not None and + self._cache_timestamp is not None and + current_time - self._cache_timestamp < self._cache_duration): + logger.debug(f"🔄 使用缓存的股票列表数据 ({len(self._stock_list_cache)} 只股票)") + return self._stock_list_cache.copy() + + logger.info("📊 重新获取股票列表数据...") + + # 优先使用热门股票 + if use_hot_stocks: + try: + logger.info(f"🔥 优先获取同花顺热榜股票 (前{hot_limit}只)...") + hot_stocks = self.get_combined_hot_stocks( + limit_per_source=hot_limit, + final_limit=hot_limit + ) + + if not hot_stocks.empty and 'stock_code' in hot_stocks.columns: + # 转换为标准格式 + stock_list = hot_stocks.copy() + + # 确保有必要的列 + if 'full_stock_code' not in stock_list.columns: + stock_list['full_stock_code'] = stock_list['stock_code'] + + # 添加缺失的列 + for col in ['area', 'industry', 'exchange', 'list_date']: + if col not in stock_list.columns: + stock_list[col] = '' + + logger.info(f"✅ 获取热门股票成功,共{len(stock_list)}只股票") + if 'source' in stock_list.columns: + source_counts = stock_list['source'].value_counts().to_dict() + source_detail = " | ".join([f"{k}: {v}只" for k, v in source_counts.items()]) + logger.info(f"📊 数据源分布: {source_detail}") + + # 更新缓存 + self._stock_list_cache = stock_list.copy() + self._cache_timestamp = current_time + return stock_list + else: + logger.warning("热门股票数据为空,回退到全量股票列表") + except Exception as e: + logger.warning(f"获取热门股票失败: {e},回退到全量股票列表") + + try: + # 回退方案:获取全量股票列表 + logger.info("📊 获取全量A股股票列表...") + if self.pro: + # 使用Pro接口获取股票基本信息 + stock_list = self.pro.stock_basic( + exchange='', + list_status='L', + fields='ts_code,symbol,name,area,industry,market,list_date' + ) + # 统一列名 + stock_list.rename(columns={ + 'ts_code': 'full_stock_code', + 'symbol': 'stock_code', + 'name': 'short_name', + 'market': 'exchange' + }, inplace=True) + else: + # 需要Pro权限才能获取股票列表 + logger.error("获取股票列表需要TuShare Pro权限,请提供有效token") + return pd.DataFrame() + + logger.info(f"获取全量股票列表成功,共{len(stock_list)}只股票") + # 更新缓存 + self._stock_list_cache = stock_list.copy() + self._cache_timestamp = current_time + return stock_list + except Exception as e: + logger.error(f"获取股票列表失败: {e}") + return pd.DataFrame() + + def get_filtered_a_share_list(self, exclude_st: bool = True, exclude_bj: bool = True, min_market_cap: float = 2000000000) -> pd.DataFrame: + """ + 获取过滤后的A股股票列表 + + Args: + exclude_st: 是否排除ST股票 + exclude_bj: 是否排除北交所股票 + min_market_cap: 最小市值要求(元),默认20亿 + + Returns: + 过滤后的股票列表DataFrame + """ + try: + # 获取完整股票列表 + all_stocks = self.get_stock_list() + + if all_stocks.empty: + return pd.DataFrame() + + filtered_stocks = all_stocks.copy() + original_count = len(filtered_stocks) + + # 排除北交所股票 + if exclude_bj: + before_count = len(filtered_stocks) + if 'exchange' in filtered_stocks.columns: + filtered_stocks = filtered_stocks[filtered_stocks['exchange'] != 'BJ'] + else: + # 根据股票代码判断 + filtered_stocks = filtered_stocks[~filtered_stocks['stock_code'].str.startswith(('8', '43', '83'))] + bj_excluded = before_count - len(filtered_stocks) + logger.info(f"排除北交所股票: {bj_excluded}只") + + # 排除ST股票 + if exclude_st: + before_count = len(filtered_stocks) + st_pattern = r'(\*?ST|PT|退|暂停)' + filtered_stocks = filtered_stocks[~filtered_stocks['short_name'].str.contains(st_pattern, na=False, case=False)] + st_excluded = before_count - len(filtered_stocks) + logger.info(f"排除ST等风险股票: {st_excluded}只") + + # 基于市值筛选 + if min_market_cap > 0 and self.pro: + before_count = len(filtered_stocks) + filtered_stocks = self._filter_by_market_cap(filtered_stocks, min_market_cap) + cap_excluded = before_count - len(filtered_stocks) + logger.info(f"排除小市值股票: {cap_excluded}只") + + final_count = len(filtered_stocks) + excluded_count = original_count - final_count + + if not filtered_stocks.empty: + logger.info(f"✅ 获取过滤后A股列表成功") + logger.info(f"📊 原始股票: {original_count}只 | 过滤后: {final_count}只 | 排除: {excluded_count}只") + + return filtered_stocks + + except Exception as e: + logger.error(f"获取过滤A股列表失败: {e}") + return pd.DataFrame() + + def _filter_by_market_cap(self, stock_df: pd.DataFrame, min_market_cap: float) -> pd.DataFrame: + """ + 基于市值筛选股票 + + Args: + stock_df: 股票列表DataFrame + min_market_cap: 最小市值要求(元) + + Returns: + 过滤后的股票DataFrame + """ + if stock_df.empty or not self.pro: + return stock_df + + try: + logger.info(f"开始基于市值筛选股票,阈值: {min_market_cap/100000000:.0f}亿元") + + # 获取股票基本信息包含市值 + ts_codes = stock_df['full_stock_code'].tolist() if 'full_stock_code' in stock_df.columns else [] + + if not ts_codes: + return stock_df + + # 分批获取市值数据(避免API限制) + batch_size = 50 + valid_stocks = [] + + for i in range(0, len(ts_codes), batch_size): + batch_codes = ts_codes[i:i+batch_size] + + try: + # 获取每日基本面数据(包含市值) + trade_date = datetime.now().strftime('%Y%m%d') + daily_basic = self.pro.daily_basic( + ts_code=','.join(batch_codes), + trade_date=trade_date, + fields='ts_code,total_mv' + ) + + if not daily_basic.empty: + # 市值单位是万元,转换为元 + daily_basic['market_cap'] = daily_basic['total_mv'] * 10000 + + # 筛选符合市值要求的股票 + valid_codes = daily_basic[daily_basic['market_cap'] >= min_market_cap]['ts_code'].tolist() + + # 添加到结果中 + batch_stocks = stock_df[stock_df['full_stock_code'].isin(valid_codes)] + valid_stocks.append(batch_stocks) + + time.sleep(0.2) # API限制 + + except Exception as e: + logger.debug(f"获取批次市值数据失败: {e}") + # 如果获取失败,保留原数据 + batch_stocks = stock_df[stock_df['full_stock_code'].isin(batch_codes)] + valid_stocks.append(batch_stocks) + + if valid_stocks: + result_df = pd.concat(valid_stocks, ignore_index=True) + logger.info(f"✅ 市值筛选完成: {len(result_df)}/{len(stock_df)} 只股票符合要求") + return result_df + else: + return stock_df + + except Exception as e: + logger.error(f"市值筛选失败: {e}") + return stock_df + + def get_realtime_data(self, stock_codes: Union[str, List[str]]) -> pd.DataFrame: + """ + 获取实时行情数据 + + Args: + stock_codes: 股票代码或代码列表 + + Returns: + 实时行情DataFrame + """ + try: + if isinstance(stock_codes, str): + stock_codes = [stock_codes] + + if self.pro: + # 转换为tushare格式的代码 + ts_codes = [] + for code in stock_codes: + if '.' in code: + ts_codes.append(code) + else: + # 根据代码判断交易所 + if code.startswith(('60', '68', '90')): + ts_codes.append(f"{code}.SH") + else: + ts_codes.append(f"{code}.SZ") + + # 使用Pro接口获取最新行情数据 + today = pd.Timestamp.now().strftime('%Y%m%d') + realtime_data = self.pro.daily( + ts_code=','.join(ts_codes), + trade_date=today + ) + else: + logger.error("获取实时行情需要TuShare Pro权限,请提供有效token") + return pd.DataFrame() + + logger.info(f"获取实时数据成功,股票数量: {len(stock_codes)}") + return realtime_data + except Exception as e: + logger.error(f"获取实时数据失败: {e}") + return pd.DataFrame() + + def get_historical_data( + self, + stock_code: str, + start_date: Union[str, date], + end_date: Union[str, date], + period: str = "daily" + ) -> pd.DataFrame: + """ + 获取历史行情数据 + + Args: + stock_code: 股票代码 + start_date: 开始日期 + end_date: 结束日期 + period: 数据周期 ('daily', 'weekly', 'monthly') + + Returns: + 历史行情DataFrame + """ + try: + # 转换日期格式 + if isinstance(start_date, date): + start_date = start_date.strftime("%Y%m%d") + else: + start_date = start_date.replace('-', '') + + if isinstance(end_date, date): + end_date = end_date.strftime("%Y%m%d") + else: + end_date = end_date.replace('-', '') + + # 转换为tushare格式的代码 + if '.' not in stock_code: + if stock_code.startswith(('60', '68', '90')): + ts_code = f"{stock_code}.SH" + else: + ts_code = f"{stock_code}.SZ" + else: + ts_code = stock_code + + if self.pro: + # 使用Pro接口 + if period == 'daily': + hist_data = self.pro.daily( + ts_code=ts_code, + start_date=start_date, + end_date=end_date + ) + elif period == 'weekly': + hist_data = self.pro.weekly( + ts_code=ts_code, + start_date=start_date, + end_date=end_date + ) + elif period == 'monthly': + hist_data = self.pro.monthly( + ts_code=ts_code, + start_date=start_date, + end_date=end_date + ) + else: + hist_data = self.pro.daily( + ts_code=ts_code, + start_date=start_date, + end_date=end_date + ) + + if not hist_data.empty: + # 统一列名映射 + field_mapping = { + 'vol': 'volume', # TuShare返回vol,策略期望volume + 'ts_code': 'stock_code' # 统一股票代码字段名 + } + + for old_name, new_name in field_mapping.items(): + if old_name in hist_data.columns: + hist_data.rename(columns={old_name: new_name}, inplace=True) + + # 转换日期格式并保持双格式兼容 + if 'trade_date' in hist_data.columns: + hist_data['trade_date'] = pd.to_datetime(hist_data['trade_date']).dt.strftime('%Y-%m-%d') + hist_data['date'] = hist_data['trade_date'] # 同时提供date字段 + elif 'date' in hist_data.columns: + hist_data['date'] = pd.to_datetime(hist_data['date']).dt.strftime('%Y-%m-%d') + hist_data['trade_date'] = hist_data['date'] # 同时提供trade_date字段 + + # 按日期升序排列 + hist_data = hist_data.sort_values('date') + else: + # 需要Pro权限才能获取历史数据 + logger.error("获取历史数据需要TuShare Pro权限,请提供有效token") + return pd.DataFrame() + + if not hist_data.empty: + logger.info(f"获取{stock_code}历史数据成功,数据量: {len(hist_data)}") + else: + logger.warning(f"获取{stock_code}历史数据为空") + + return hist_data + + except Exception as e: + logger.error(f"获取{stock_code}历史数据失败: {e}") + return pd.DataFrame() + + def get_index_data(self, index_code: str = "000001.SH") -> pd.DataFrame: + """ + 获取指数数据 + + Args: + index_code: 指数代码 + + Returns: + 指数数据DataFrame + """ + try: + if self.pro: + # 转换指数代码格式 + if index_code == "000001.SH": + ts_code = "000001.SH" # 上证指数 + elif index_code == "399001.SZ": + ts_code = "399001.SZ" # 深证成指 + else: + ts_code = index_code + + index_data = self.pro.index_daily( + ts_code=ts_code, + trade_date=datetime.now().strftime('%Y%m%d') + ) + else: + # 需要Pro权限才能获取指数数据 + logger.error("获取指数数据需要TuShare Pro权限,请提供有效token") + return pd.DataFrame() + + logger.info(f"获取指数{index_code}数据成功") + return index_data + except Exception as e: + logger.error(f"获取指数数据失败: {e}") + return pd.DataFrame() + + def get_financial_data(self, stock_code: str) -> pd.DataFrame: + """ + 获取财务数据 + + Args: + stock_code: 股票代码 + + Returns: + 财务数据DataFrame + """ + try: + # 转换为tushare格式的代码 + if '.' not in stock_code: + if stock_code.startswith(('60', '68', '90')): + ts_code = f"{stock_code}.SH" + else: + ts_code = f"{stock_code}.SZ" + else: + ts_code = stock_code + + if self.pro: + # 获取财务数据 + financial_data = self.pro.income(ts_code=ts_code, period='20231231') + else: + # 免费接口的财务数据 + financial_data = pd.DataFrame() + + logger.info(f"获取{stock_code}财务数据成功") + return financial_data + except Exception as e: + logger.error(f"获取财务数据失败: {e}") + return pd.DataFrame() + + def search_stocks(self, keyword: str) -> pd.DataFrame: + """ + 搜索股票 + + Args: + keyword: 搜索关键词 + + Returns: + 搜索结果DataFrame + """ + try: + # 获取完整股票列表 + all_stocks = self.get_stock_list() + if all_stocks.empty: + return pd.DataFrame() + + # 在股票代码和名称中搜索关键词 + keyword = str(keyword).strip() + if not keyword: + return pd.DataFrame() + + # 支持按代码或名称模糊搜索 + mask = ( + all_stocks['stock_code'].str.contains(keyword, case=False, na=False) | + all_stocks['short_name'].str.contains(keyword, case=False, na=False) + ) + + results = all_stocks[mask].copy() + logger.info(f"搜索股票'{keyword}'成功,找到{len(results)}个结果") + return results + except Exception as e: + logger.error(f"搜索股票失败: {e}") + return pd.DataFrame() + + @retry_on_failure(retries=2, delay=1.0) + def get_hot_stocks_ths(self, limit: int = 100, trade_date: str = None) -> pd.DataFrame: + """ + 获取同花顺热门股票 + + Args: + limit: 返回数量限制 + trade_date: 交易日期,格式YYYYMMDD,默认为当日 + + Returns: + 热门股票DataFrame + """ + try: + if not self.pro: + logger.error("需要Tushare Pro权限才能获取同花顺热股数据") + return self._get_fallback_hot_stocks(limit) + + # 使用TuShare Pro的同花顺热榜接口 + params = { + 'market': '热股', + 'is_new': 'Y' + } + if trade_date: + params['trade_date'] = trade_date + + df = self.pro.ths_hot(**params) + + if df.empty: + logger.warning("同花顺热股数据为空,使用备用方法") + return self._get_fallback_hot_stocks(limit) + + # 数据清洗和去重 + logger.info(f"原始数据: {len(df)} 条") + + # 1. 去除重复股票代码,保留第一个(排名最好的) + if 'ts_code' in df.columns: + df = df.drop_duplicates(subset=['ts_code'], keep='first') + logger.info(f"去重后数据: {len(df)} 条") + + # 2. 按rank排序,处理排名异常 + if 'rank' in df.columns: + df = df.sort_values('rank') + # 重新分配连续排名 + df['original_rank'] = df['rank'] # 保留原始排名 + df['rank'] = range(1, len(df) + 1) # 重新分配连续排名 + + # 3. 只取前limit条 + df = df.head(limit) + + # 4. 统一列名格式 + if 'ts_code' in df.columns: + df.rename(columns={'ts_code': 'stock_code', 'ts_name': 'short_name'}, inplace=True) + + # 5. 添加数据源标识和有用字段 + df['source'] = '同花顺热股' + + # 保留有用的原始字段 + useful_cols = ['stock_code', 'short_name', 'rank', 'source'] + for col in ['pct_change', 'current_price', 'hot', 'concept', 'original_rank']: + if col in df.columns: + useful_cols.append(col) + + df = df[useful_cols] + + logger.info(f"获取同花顽热门股票成功,共{len(df)}只股票") + return df + + except Exception as e: + logger.error(f"获取同花顽热门股票失败: {e}") + return self._get_fallback_hot_stocks(limit) + + @retry_on_failure(retries=2, delay=1.0) + def get_popular_stocks_east(self, limit: int = 100, trade_date: str = None) -> pd.DataFrame: + """ + 获取东方财富人气股票 + + Args: + limit: 返回数量限制 + trade_date: 交易日期,格式YYYYMMDD,默认为当日 + + Returns: + 人气股票DataFrame + """ + try: + if not self.pro: + logger.error("需要Tushare Pro权限才能获取东财人气股数据") + return self._get_fallback_hot_stocks(limit) + + # 使用TuShare Pro的东方财富热榜接口 + params = { + 'market': 'A股市场', + 'hot_type': '人气榜', + 'is_new': 'Y' + } + if trade_date: + params['trade_date'] = trade_date + + df = self.pro.dc_hot(**params) + + if df.empty: + logger.warning("东财人气股数据为空,使用备用方法") + return self._get_fallback_hot_stocks(limit) + + # 数据处理和标准化 + df = df.head(limit) + + # 统一列名格式 + if 'ts_code' in df.columns: + df.rename(columns={'ts_code': 'stock_code', 'ts_name': 'short_name'}, inplace=True) + + # 添加数据源标识 + df['source'] = '东财人气榜' + if 'rank' not in df.columns: + df['rank'] = range(1, len(df) + 1) + + logger.info(f"获取东财人气股票成功,共{len(df)}只股票") + return df + + except Exception as e: + logger.error(f"获取东财人气股票失败: {e}") + return self._get_fallback_hot_stocks(limit) + + def get_combined_hot_stocks(self, limit_per_source: int = 100, final_limit: int = 150) -> pd.DataFrame: + """ + 获取合并的热门股票(同花顽+东财) + + Args: + limit_per_source: 每个数据源的股票数量 + final_limit: 最终返回的股票数量 + + Returns: + 合并去重后的热门股票DataFrame + """ + try: + all_stocks = [] + + # 1. 获取同花顽热股 + ths_stocks = self.get_hot_stocks_ths(limit_per_source) + if not ths_stocks.empty: + all_stocks.append(ths_stocks) + logger.info(f"同花顽热股: {len(ths_stocks)}只") + + # 2. 获取东财人气股 + east_stocks = self.get_popular_stocks_east(limit_per_source) + if not east_stocks.empty: + all_stocks.append(east_stocks) + logger.info(f"东财人气股: {len(east_stocks)}只") + + if not all_stocks: + logger.warning("所有热门股票数据源都失败,使用备用股票池") + return self._get_fallback_hot_stocks(final_limit) + + # 3. 合并数据(过滤空数据框) + non_empty_stocks = [df for df in all_stocks if not df.empty] + if not non_empty_stocks: + logger.warning("没有有效的热门股票数据,使用备用股票池") + return self._get_fallback_hot_stocks(final_limit) + + combined_df = pd.concat(non_empty_stocks, ignore_index=True) + + # 4. 去重(优先保留排名靠前的) + combined_df = combined_df.sort_values(['rank', 'source']) + unique_stocks = combined_df.drop_duplicates(subset=['stock_code'], keep='first') + + # 5. 限制最终数量 + result = unique_stocks.head(final_limit) + + # 6. 重新排序并添加最终排名 + result = result.reset_index(drop=True) + result['final_rank'] = range(1, len(result) + 1) + + logger.info(f"合并热门股票成功: 原始{len(combined_df)}只,去重后{len(unique_stocks)}只,最终{len(result)}只") + return result + + except Exception as e: + logger.error(f"获取合并热门股票失败: {e}") + return self._get_fallback_hot_stocks(final_limit) + + def get_stock_name(self, stock_code: str) -> str: + """ + 获取股票中文名称 + + Args: + stock_code: 股票代码 + + Returns: + 股票中文名称 + """ + try: + # 从缓存中查找 + if stock_code in self._stock_name_cache: + return self._stock_name_cache[stock_code] + + # 直接通过TuShare Pro查询单个股票信息 + if self.pro: + # 将股票代码转换为TuShare格式 + ts_code = stock_code + if '.' not in stock_code: + # 如果没有后缀,根据代码前缀添加 + if stock_code.startswith('6'): + ts_code = f"{stock_code}.SH" + elif stock_code.startswith(('0', '3')): + ts_code = f"{stock_code}.SZ" + + # 查询股票基本信息 + stock_info = self.pro.stock_basic(ts_code=ts_code, fields='ts_code,name') + if not stock_info.empty and 'name' in stock_info.columns: + stock_name = stock_info.iloc[0]['name'] + # 添加到缓存 + self._stock_name_cache[stock_code] = stock_name + return stock_name + + # 如果Pro查询失败,返回股票代码本身 + self._stock_name_cache[stock_code] = stock_code + return stock_code + + except Exception as e: + logger.debug(f"获取股票{stock_code}名称失败: {e}") + return stock_code + + def _get_fallback_hot_stocks(self, limit: int = 100) -> pd.DataFrame: + """ + 备用热门股票获取方法(当主要接口失败时使用) + + Args: + limit: 返回股票数量 + + Returns: + 热门股票DataFrame + """ + try: + logger.info("使用备用方法获取热门股票(基于成交量排序)") + + if self.pro: + # 获取当日成交量排行 + trade_date = datetime.now().strftime('%Y%m%d') + daily_data = self.pro.daily(trade_date=trade_date) + + if not daily_data.empty: + # 按成交量排序 + hot_stocks = daily_data.sort_values('vol', ascending=False).head(limit) + + # 统一列名 + hot_stocks.rename(columns={'ts_code': 'stock_code'}, inplace=True) + + # 添加股票名称 + stock_list = self.get_stock_list() + if not stock_list.empty: + hot_stocks = hot_stocks.merge( + stock_list[['stock_code', 'short_name']], + left_on='stock_code', + right_on='stock_code', + how='left' + ) + + # 添加必要字段 + hot_stocks['source'] = '成交量排序' + hot_stocks['rank'] = range(1, len(hot_stocks) + 1) + + logger.info(f"备用方法获取热门股票成功,共{len(hot_stocks)}只股票") + return hot_stocks + + # 如果Pro接口也失败,返回预设的股票池 + return self._get_default_stock_pool(limit) + + except Exception as e: + logger.error(f"备用热门股票获取失败: {e}") + return self._get_default_stock_pool(limit) + + def _get_default_stock_pool(self, limit: int = 100) -> pd.DataFrame: + """ + 默认股票池(当所有数据获取方法都失败时使用) + + Args: + limit: 返回股票数量 + + Returns: + 默认股票池DataFrame + """ + # 预设一些主要的大盘股和活跃股票 + default_stocks = [ + {'stock_code': '000001.SZ', 'short_name': '平安银行'}, + {'stock_code': '000002.SZ', 'short_name': '万科A'}, + {'stock_code': '000858.SZ', 'short_name': '五粮液'}, + {'stock_code': '600000.SH', 'short_name': '浦发银行'}, + {'stock_code': '600036.SH', 'short_name': '招商银行'}, + {'stock_code': '600519.SH', 'short_name': '贵州茅台'}, + {'stock_code': '600887.SH', 'short_name': '伊利股份'}, + {'stock_code': '002415.SZ', 'short_name': '海康威视'}, + {'stock_code': '300014.SZ', 'short_name': '亿纬锂能'}, + {'stock_code': '300059.SZ', 'short_name': '东方财富'}, + {'stock_code': '300750.SZ', 'short_name': '宁德时代'}, + {'stock_code': '000876.SZ', 'short_name': '新希望'}, + {'stock_code': '002594.SZ', 'short_name': 'BYD'}, + {'stock_code': '000895.SZ', 'short_name': '双汇发展'}, + {'stock_code': '600031.SH', 'short_name': '三一重工'}, + {'stock_code': '601318.SH', 'short_name': '中国平安'}, + {'stock_code': '601166.SH', 'short_name': '兴业银行'}, + {'stock_code': '600009.SH', 'short_name': '上海机场'}, + {'stock_code': '600276.SH', 'short_name': '恒瑞医药'}, + {'stock_code': '000063.SZ', 'short_name': '中兴通讯'}, + ] + + df = pd.DataFrame(default_stocks[:limit]) + df['source'] = '默认股票池' + df['rank'] = range(1, len(df) + 1) + + logger.warning(f"使用默认股票池,包含{len(df)}只股票") + return df + + def get_market_overview(self) -> dict: + """ + 获取市场概况 + + Returns: + 市场概况字典 + """ + try: + # 获取主要指数数据 + sh_index = self.get_index_data("000001.SH") # 上证指数 + sz_index = self.get_index_data("399001.SZ") # 深证成指 + cyb_index = self.get_index_data("399006.SZ") # 创业板指 + + overview = { + "update_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + "shanghai": sh_index.iloc[0].to_dict() if not sh_index.empty else {}, + "shenzhen": sz_index.iloc[0].to_dict() if not sz_index.empty else {}, + "chinext": cyb_index.iloc[0].to_dict() if not cyb_index.empty else {} + } + + logger.info("获取市场概况成功") + return overview + except Exception as e: + logger.error(f"获取市场概况失败: {e}") + return {} + + def get_sector_money_flow(self, trade_date: str = None) -> pd.DataFrame: + """ + 获取板块资金流向数据 + 使用同花顺行业资金流向接口 + + Args: + trade_date: 交易日期,格式YYYYMMDD,默认为当日 + + Returns: + 板块资金流向DataFrame,包含净流入金额等信息 + """ + try: + if not self.pro: + logger.error("需要Tushare Pro权限才能获取板块资金流向数据") + return pd.DataFrame() + + # 如果未指定日期,使用最近交易日 + if trade_date is None: + trade_date = datetime.now().strftime('%Y%m%d') + + # 获取同花顺行业资金流向数据 + df = self.pro.moneyflow_ind_ths(trade_date=trade_date) + + if df.empty: + # 如果当日无数据,尝试获取前一交易日 + prev_date = (datetime.strptime(trade_date, '%Y%m%d') - timedelta(days=1)).strftime('%Y%m%d') + df = self.pro.moneyflow_ind_ths(trade_date=prev_date) + + if not df.empty: + # 重命名列以保持兼容性 + if 'industry' in df.columns: + df['name'] = df['industry'] + + # 使用正确的净流入字段 + if 'net_amount' in df.columns: + df['net_amount'] = df['net_amount'] + else: + logger.warning("数据中未找到net_amount字段") + return pd.DataFrame() + + # 按净流入金额排序(从大到小) + df = df.sort_values('net_amount', ascending=False) + # 添加排名 + df['rank'] = range(1, len(df) + 1) + + # 确保有涨跌幅字段 + if 'pct_change' not in df.columns: + df['pct_change'] = 0 + + logger.info(f"获取板块资金流向成功,共{len(df)}个板块") + else: + logger.warning(f"未获取到{trade_date}的板块资金流向数据") + + return df + + except Exception as e: + logger.error(f"获取板块资金流向失败: {e}") + return pd.DataFrame() + + def get_concept_money_flow(self, trade_date: str = None) -> pd.DataFrame: + """ + 获取概念板块资金流向数据 + 使用同花顺概念板块资金流向接口 + + Args: + trade_date: 交易日期,格式YYYYMMDD,默认为当日 + + Returns: + 概念板块资金流向DataFrame + """ + try: + if not self.pro: + logger.error("需要Tushare Pro权限才能获取概念资金流向数据") + return pd.DataFrame() + + if trade_date is None: + trade_date = datetime.now().strftime('%Y%m%d') + + # 获取同花顺概念板块资金流向数据 + df = self.pro.moneyflow_cnt_ths(trade_date=trade_date) + + if df.empty: + # 如果当日无数据,尝试获取前一交易日 + prev_date = (datetime.strptime(trade_date, '%Y%m%d') - timedelta(days=1)).strftime('%Y%m%d') + df = self.pro.moneyflow_cnt_ths(trade_date=prev_date) + + if not df.empty: + # 按净流入金额排序(从大到小) + df = df.sort_values('net_amount', ascending=False) + # 添加排名 + df['rank'] = range(1, len(df) + 1) + logger.info(f"获取概念资金流向成功,共{len(df)}个概念") + else: + logger.warning(f"未获取到{trade_date}的概念资金流向数据") + + return df + + except Exception as e: + logger.error(f"获取概念资金流向失败: {e}") + return pd.DataFrame() + + def get_strongest_concept_boards(self, trade_date: str = None, ts_code: str = None) -> pd.DataFrame: + """ + 获取最强板块统计(涨停股票最多的概念板块) + + Args: + trade_date: 交易日期,格式YYYYMMDD,默认为当日 + ts_code: 板块代码,可选 + + Returns: + 最强板块统计DataFrame,包含以下字段: + - ts_code: 板块代码 + - name: 板块名称 + - trade_date: 交易日期 + - days: 上榜天数 + - up_stat: 连板高度 + - cons_nums: 连板家数 + - up_nums: 涨停家数 + - pct_chg: 涨跌幅% + - rank: 板块热点排名 + """ + try: + if not self.pro: + logger.error("需要Tushare Pro权限才能获取最强板块数据") + return pd.DataFrame() + + # 设置查询参数 + params = {} + if trade_date: + params['trade_date'] = trade_date + if ts_code: + params['ts_code'] = ts_code + + # 调用Tushare接口 + df = self.pro.limit_cpt_list(**params) + + if df.empty: + logger.warning(f"未获取到最强板块数据: {trade_date or '当日'}") + return pd.DataFrame() + + # 按照涨停家数和涨跌幅排序 + df = df.sort_values(['up_nums', 'pct_chg'], ascending=[False, False]) + + logger.info(f"获取最强板块数据成功: {len(df)}个板块,日期: {trade_date or '当日'}") + return df + + except Exception as e: + logger.error(f"获取最强板块统计失败: {e}") + return pd.DataFrame() + + def get_concept_constituent_stocks(self, ts_code: str) -> pd.DataFrame: + """ + 获取概念板块的成分股票 + + Args: + ts_code: 概念板块代码 + + Returns: + 成分股票DataFrame,包含股票代码、名称等信息 + """ + try: + if not self.pro: + logger.error("需要Tushare Pro权限才能获取概念成分股") + return pd.DataFrame() + + # 获取概念成分股 + df = self.pro.concept_detail(id=ts_code) + + if df.empty: + logger.warning(f"未获取到概念板块 {ts_code} 的成分股") + return pd.DataFrame() + + logger.info(f"获取概念板块 {ts_code} 成分股成功: {len(df)}只股票") + return df + + except Exception as e: + logger.error(f"获取概念成分股失败: {e}") + return pd.DataFrame() + + def get_strongest_concept_stocks(self, trade_date: str = None, top_boards: int = 5) -> dict: + """ + 获取最强板块中的股票(综合方法) + + Args: + trade_date: 交易日期,格式YYYYMMDD,默认为当日 + top_boards: 选择前N个最强板块,默认5个 + + Returns: + 包含最强板块及其成分股的字典 + """ + try: + result = { + 'trade_date': trade_date or datetime.now().strftime('%Y%m%d'), + 'strongest_boards': pd.DataFrame(), + 'stocks_by_board': {}, + 'all_stocks': pd.DataFrame(), + 'update_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S') + } + + # 1. 获取最强板块列表 + strongest_boards = self.get_strongest_concept_boards(trade_date) + if strongest_boards.empty: + logger.warning("未获取到最强板块数据") + return result + + # 取前N个最强板块 + top_boards_data = strongest_boards.head(top_boards) + result['strongest_boards'] = top_boards_data + + # 2. 获取每个强势板块的成分股 + all_stocks_list = [] + for _, board in top_boards_data.iterrows(): + board_code = board['ts_code'] + board_name = board['name'] + + # 获取该板块的成分股 + stocks = self.get_concept_constituent_stocks(board_code) + if not stocks.empty: + # 添加板块信息到股票数据 + stocks['board_code'] = board_code + stocks['board_name'] = board_name + stocks['board_up_nums'] = board['up_nums'] + stocks['board_pct_chg'] = board['pct_chg'] + + result['stocks_by_board'][board_name] = stocks + all_stocks_list.append(stocks) + + logger.info(f"板块 {board_name} 包含 {len(stocks)} 只股票") + + # 避免频繁调用API + time.sleep(0.1) + + # 3. 合并所有股票数据 + if all_stocks_list: + result['all_stocks'] = pd.concat(all_stocks_list, ignore_index=True) + + # 去重(一只股票可能属于多个概念) + unique_stocks = result['all_stocks'].drop_duplicates(subset=['ts_code']) + + logger.info(f"最强板块股票获取完成:") + logger.info(f" - 强势板块数量: {len(top_boards_data)}") + logger.info(f" - 包含股票总数: {len(result['all_stocks'])}") + logger.info(f" - 去重后股票数: {len(unique_stocks)}") + + return result + + except Exception as e: + logger.error(f"获取最强板块股票失败: {e}") + return { + 'trade_date': trade_date or datetime.now().strftime('%Y%m%d'), + 'strongest_boards': pd.DataFrame(), + 'stocks_by_board': {}, + 'all_stocks': pd.DataFrame(), + 'update_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), + 'error': str(e) + } + + def get_leading_stocks_from_board(self, board_code: str, board_name: str = None, + top_n: int = 3, min_price: float = 5.0, + max_price: float = 300.0) -> pd.DataFrame: + """ + 从单个板块中筛选龙头股 + + Args: + board_code: 板块代码 + board_name: 板块名称 + top_n: 返回龙头股数量 + min_price: 最低股价过滤 + max_price: 最高股价过滤 + + Returns: + 龙头股DataFrame,包含评分和排名 + """ + try: + if not self.pro: + logger.error("需要Tushare Pro权限才能筛选龙头股") + return pd.DataFrame() + + # 1. 获取板块成分股 + constituent_stocks = self.get_concept_constituent_stocks(board_code) + if constituent_stocks.empty: + logger.warning(f"板块 {board_code} 没有成分股数据") + return pd.DataFrame() + + # 2. 获取股票基本信息和当日行情 + trade_date = datetime.now().strftime('%Y%m%d') + stock_codes = constituent_stocks['ts_code'].tolist() + + leading_candidates = [] + + for stock_code in stock_codes[:50]: # 限制查询数量避免超限 + try: + # 获取基本信息 + basic_info = self.pro.stock_basic(ts_code=stock_code) + if basic_info.empty: + continue + + # 获取当日行情 + daily_data = self.pro.daily(ts_code=stock_code, trade_date=trade_date) + if daily_data.empty: + # 尝试获取最近一个交易日数据 + recent_data = self.pro.daily(ts_code=stock_code, limit=1) + if not recent_data.empty: + daily_data = recent_data + else: + continue + + stock_info = daily_data.iloc[0] + basic = basic_info.iloc[0] + + # 价格过滤 + current_price = stock_info['close'] + if current_price < min_price or current_price > max_price: + continue + + # 计算评分指标 + candidate = { + 'ts_code': stock_code, + 'stock_code': stock_code, + 'name': basic['name'], + 'board_code': board_code, + 'board_name': board_name or board_code, + 'close': current_price, + 'pct_chg': stock_info.get('pct_chg', 0), + 'vol': stock_info.get('vol', 0), + 'amount': stock_info.get('amount', 0), + 'turnover_rate': stock_info.get('turnover_rate', 0), + 'total_mv': stock_info.get('total_mv', 0), # 总市值 + 'circ_mv': stock_info.get('circ_mv', 0), # 流通市值 + } + + # 获取近5日数据计算连涨天数 + recent_5d = self.pro.daily(ts_code=stock_code, limit=5) + if not recent_5d.empty: + # 计算连续上涨天数 + consecutive_up = 0 + for _, row in recent_5d.iterrows(): + if row['pct_chg'] > 0: + consecutive_up += 1 + else: + break + candidate['consecutive_up_days'] = consecutive_up + + # 计算5日平均成交额 + candidate['avg_amount_5d'] = recent_5d['amount'].mean() + else: + candidate['consecutive_up_days'] = 0 + candidate['avg_amount_5d'] = candidate['amount'] + + leading_candidates.append(candidate) + time.sleep(0.1) # 避免调用过于频繁 + + except Exception as e: + logger.debug(f"处理股票 {stock_code} 时出错: {e}") + continue + + if not leading_candidates: + logger.warning(f"板块 {board_code} 没有找到符合条件的龙头股") + return pd.DataFrame() + + # 3. 计算龙头股评分 + df = pd.DataFrame(leading_candidates) + df = self._calculate_leading_score(df) + + # 4. 排序并返回前N个 + df = df.sort_values('leading_score', ascending=False).head(top_n) + df['rank'] = range(1, len(df) + 1) + + logger.info(f"板块 {board_name or board_code} 筛选出 {len(df)} 只龙头股") + return df + + except Exception as e: + logger.error(f"从板块 {board_code} 筛选龙头股失败: {e}") + return pd.DataFrame() + + def _calculate_leading_score(self, df: pd.DataFrame) -> pd.DataFrame: + """ + 计算龙头股评分 + + Args: + df: 候选股票DataFrame + + Returns: + 包含评分的DataFrame + """ + try: + if df.empty: + return df + + df = df.copy() + + # 标准化各项指标到0-100分 + # 1. 涨幅得分 (30%) + df['pct_chg_score'] = self._normalize_score(df['pct_chg'], weight=30) + + # 2. 成交额得分 (25%) + df['amount_score'] = self._normalize_score(df['avg_amount_5d'], weight=25) + + # 3. 连续上涨天数得分 (20%) + df['consecutive_score'] = df['consecutive_up_days'] * 4 # 每天4分,最高20分 + df['consecutive_score'] = df['consecutive_score'].clip(upper=20) + + # 4. 换手率得分 (15%) - 适中的换手率更好 + optimal_turnover = 8 # 最优换手率8% + df['turnover_score'] = 15 - abs(df['turnover_rate'] - optimal_turnover) + df['turnover_score'] = df['turnover_score'].clip(lower=0, upper=15) + + # 5. 市值得分 (10%) - 流通市值适中更好 + # 50-500亿为最佳区间 + df['mv_score'] = df['circ_mv'].apply(lambda x: self._get_mv_score(x)) + + # 综合评分 + df['leading_score'] = ( + df['pct_chg_score'] + + df['amount_score'] + + df['consecutive_score'] + + df['turnover_score'] + + df['mv_score'] + ) + + # 添加评级 + df['leading_grade'] = df['leading_score'].apply(self._get_leading_grade) + + return df + + except Exception as e: + logger.error(f"计算龙头股评分失败: {e}") + return df + + def _normalize_score(self, series: pd.Series, weight: int = 100) -> pd.Series: + """ + 将数据标准化为指定权重的得分 + + Args: + series: 原始数据序列 + weight: 权重分数 + + Returns: + 标准化后的得分序列 + """ + if series.empty or series.max() == series.min(): + return pd.Series([0] * len(series), index=series.index) + + # Min-Max标准化到0-1,然后乘以权重 + normalized = (series - series.min()) / (series.max() - series.min()) + return normalized * weight + + def _get_mv_score(self, market_value: float) -> float: + """ + 根据流通市值计算得分 + + Args: + market_value: 流通市值(万元) + + Returns: + 市值得分 + """ + if market_value <= 0: + return 0 + + # 转换为亿元 + mv_billion = market_value / 10000 + + if 50 <= mv_billion <= 500: # 最佳区间:50-500亿 + return 10 + elif 20 <= mv_billion < 50: # 较小但可接受:20-50亿 + return 8 + elif 500 < mv_billion <= 1000: # 较大但可接受:500-1000亿 + return 6 + elif 10 <= mv_billion < 20: # 偏小:10-20亿 + return 4 + elif mv_billion > 1000: # 太大:>1000亿 + return 2 + else: # 太小:<10亿 + return 1 + + def _get_leading_grade(self, score: float) -> str: + """ + 根据评分获取评级 + + Args: + score: 综合评分 + + Returns: + 评级字符串 + """ + if score >= 80: + return "A+ 超级龙头" + elif score >= 70: + return "A 优质龙头" + elif score >= 60: + return "B+ 潜力龙头" + elif score >= 50: + return "B 一般龙头" + else: + return "C 弱势股票" + + def get_leading_stocks_from_hot_boards(self, top_boards: int = 10, + stocks_per_board: int = 2, + min_score: float = 50.0) -> dict: + """ + 从热门板块中筛选龙头牛股(主要接口) + + Args: + top_boards: 分析前N个热门板块 + stocks_per_board: 每个板块选择的龙头股数量 + min_score: 最低评分要求 + + Returns: + 包含所有龙头股信息的字典 + """ + try: + result = { + 'scan_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), + 'boards_analyzed': 0, + 'total_leading_stocks': 0, + 'leading_stocks_by_board': {}, + 'all_leading_stocks': pd.DataFrame(), + 'top_leading_stocks': pd.DataFrame() + } + + # 1. 获取最强板块 + logger.info(f"正在获取前 {top_boards} 个最强板块...") + strongest_boards = self.get_strongest_concept_boards() + + if strongest_boards.empty: + logger.warning("未能获取到强势板块数据") + return result + + # 取前N个板块 + target_boards = strongest_boards.head(top_boards) + result['boards_analyzed'] = len(target_boards) + + # 2. 逐个分析每个板块的龙头股 + all_leading_stocks = [] + + for idx, (_, board) in enumerate(target_boards.iterrows(), 1): + board_code = board['ts_code'] + board_name = board['name'] + board_up_nums = board.get('up_nums', 0) + board_pct_chg = board.get('pct_chg', 0) + + logger.info(f"[{idx}/{len(target_boards)}] 分析板块: {board_name} (涨停{board_up_nums}只, 涨幅{board_pct_chg:.2f}%)") + + # 从该板块筛选龙头股 + board_leaders = self.get_leading_stocks_from_board( + board_code, board_name, + top_n=stocks_per_board + ) + + if not board_leaders.empty: + # 过滤低分股票 + qualified_leaders = board_leaders[board_leaders['leading_score'] >= min_score] + + if not qualified_leaders.empty: + # 添加板块信息 + qualified_leaders['board_up_nums'] = board_up_nums + qualified_leaders['board_pct_chg'] = board_pct_chg + qualified_leaders['board_rank'] = idx + + result['leading_stocks_by_board'][board_name] = qualified_leaders + all_leading_stocks.append(qualified_leaders) + + logger.info(f" ✅ 找到 {len(qualified_leaders)} 只龙头股") + else: + logger.info(f" ❌ 无符合评分要求的龙头股") + else: + logger.info(f" ❌ 板块数据获取失败") + + # 避免API限制 + time.sleep(0.5) + + # 3. 汇总所有龙头股 + if all_leading_stocks: + all_df = pd.concat(all_leading_stocks, ignore_index=True) + result['all_leading_stocks'] = all_df + result['total_leading_stocks'] = len(all_df) + + # 4. 获取综合排名前N的超级龙头 + top_leaders = all_df.nlargest(20, 'leading_score') + top_leaders['overall_rank'] = range(1, len(top_leaders) + 1) + result['top_leading_stocks'] = top_leaders + + logger.info(f"🎯 筛选完成! 共分析 {result['boards_analyzed']} 个板块,发现 {result['total_leading_stocks']} 只龙头股") + logger.info(f"📈 TOP10 超级龙头:") + for _, stock in top_leaders.head(10).iterrows(): + logger.info(f" {stock['overall_rank']}. {stock['stock_code']} {stock['name']} | {stock['board_name']} | 评分:{stock['leading_score']:.1f} | {stock['leading_grade']}") + + return result + + except Exception as e: + logger.error(f"筛选热门板块龙头股失败: {e}") + result['error'] = str(e) + return result + + def get_top_money_flow_sectors(self, trade_date: str = None, top_n: int = 10) -> dict: + """ + 获取当日资金净流入最多的板块 + + Args: + trade_date: 交易日期,格式YYYYMMDD,默认为当日 + top_n: 返回前N个板块,默认10个 + + Returns: + 包含行业板块和概念板块的字典 + """ + try: + result = { + 'trade_date': trade_date or datetime.now().strftime('%Y%m%d'), + 'sectors': pd.DataFrame(), + 'concepts': pd.DataFrame(), + 'update_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S') + } + + # 获取行业板块资金流向 + sectors_df = self.get_sector_money_flow(trade_date) + if not sectors_df.empty: + result['sectors'] = sectors_df.head(top_n) + logger.info(f"获取行业板块TOP{top_n}成功") + + # 获取概念板块资金流向 + concepts_df = self.get_concept_money_flow(trade_date) + if not concepts_df.empty: + result['concepts'] = concepts_df.head(top_n) + logger.info(f"获取概念板块TOP{top_n}成功") + + return result + + except Exception as e: + logger.error(f"获取TOP资金流向板块失败: {e}") + return { + 'trade_date': trade_date or datetime.now().strftime('%Y%m%d'), + 'sectors': pd.DataFrame(), + 'concepts': pd.DataFrame(), + 'update_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), + 'error': str(e) + } + + +# ADataFetcher别名已移除,请直接使用TushareFetcher + + +if __name__ == "__main__": + # 测试代码 + token = "0ed6419a00d8923dc19c0b58fc92d94c9a0696949ab91a13aa58a0cc" + fetcher = TushareFetcher(token) + + # 测试获取股票列表 + print("测试获取股票列表...") + stock_list = fetcher.get_stock_list() + print(f"股票数量: {len(stock_list)}") + if not stock_list.empty: + print(stock_list.head()) + + # 测试搜索功能 + print("\n测试搜索功能...") + search_results = fetcher.search_stocks("平安") + if not search_results.empty: + print(search_results.head()) \ No newline at end of file diff --git a/src/database/mysql_database_manager.py b/src/database/mysql_database_manager.py index 0df952f..4ab4efd 100644 --- a/src/database/mysql_database_manager.py +++ b/src/database/mysql_database_manager.py @@ -12,6 +12,8 @@ import json from pathlib import Path import sys from loguru import logger +from sqlalchemy import create_engine +import warnings # 添加项目根目录到路径 current_dir = Path(__file__).parent.parent.parent @@ -33,6 +35,9 @@ class MySQLDatabaseManager: self.config = config or MYSQL_CONFIG self.connection_params = self.config.to_dict() + # 创建SQLAlchemy引擎用于pandas操作 + self._create_sqlalchemy_engine() + # 测试连接并初始化数据库 try: self._test_connection() @@ -42,6 +47,62 @@ class MySQLDatabaseManager: logger.error(f"MySQL数据库初始化失败: {e}") raise + def _create_sqlalchemy_engine(self): + """创建SQLAlchemy引擎""" + try: + # 构建连接字符串 + connection_string = ( + f"mysql+pymysql://{self.config.user}:{self.config.password}" + f"@{self.config.host}:{self.config.port}/{self.config.database}" + f"?charset={self.config.charset}" + ) + + # 创建引擎 + self.engine = create_engine( + connection_string, + pool_pre_ping=True, + pool_recycle=3600, + echo=False + ) + + # 抑制pandas SQLAlchemy警告 + warnings.filterwarnings('ignore', + message='pandas only supports SQLAlchemy connectable.*', + category=UserWarning) + + except Exception as e: + logger.error(f"创建SQLAlchemy引擎失败: {e}") + self.engine = None + + def _execute_query_with_engine(self, sql: str, params: list = None) -> pd.DataFrame: + """使用适当的引擎执行查询""" + try: + if self.engine and params: + # 对于SQLAlchemy,将%s替换为实际值(仅适用于简单参数) + formatted_sql = sql + for param in params: + if isinstance(param, str): + # 字符串参数需要加引号 + formatted_sql = formatted_sql.replace('%s', f"'{param}'", 1) + elif isinstance(param, (date, datetime)): + # 日期参数需要转换为正确格式并加引号 + formatted_sql = formatted_sql.replace('%s', f"'{param.strftime('%Y-%m-%d')}'", 1) + else: + # 数值参数直接替换 + formatted_sql = formatted_sql.replace('%s', str(param), 1) + return pd.read_sql_query(formatted_sql, self.engine) + elif self.engine: + return pd.read_sql_query(sql, self.engine) + else: + # 回退到pymysql + with pymysql.connect(**self.connection_params) as conn: + return pd.read_sql_query(sql, conn, params=params) + except Exception as e: + # 如果SQLAlchemy失败,回退到pymysql + logger.warning(f"SQLAlchemy查询失败,回退到pymysql: {e}") + with pymysql.connect(**self.connection_params) as conn: + return pd.read_sql_query(sql, conn, params=params) + def _test_connection(self): """测试数据库连接""" try: @@ -72,7 +133,13 @@ class MySQLDatabaseManager: try: cursor.execute(statement) except Exception as e: - if "already exists" not in str(e): + # 忽略常见的重复创建警告 + error_str = str(e) + if "Duplicate key name" in error_str: + # 特别处理重复索引键名,只记录debug级别 + logger.debug(f"索引已存在,跳过: {e}") + elif ("already exists" not in error_str and + "Table" not in error_str): logger.warning(f"执行SQL语句时警告: {e}") conn.commit() @@ -172,6 +239,22 @@ class MySQLDatabaseManager: logger.error(f"创建扫描会话失败: {e}") raise + def update_scan_session_stats(self, session_id: int, total_scanned: int, total_signals: int): + """更新扫描会话统计""" + try: + with pymysql.connect(**self.connection_params) as conn: + cursor = conn.cursor() + cursor.execute(""" + UPDATE scan_sessions + SET total_scanned = %s, total_signals = %s + WHERE id = %s + """, (total_scanned, total_signals, session_id)) + conn.commit() + logger.debug(f"更新扫描会话统计: session_id={session_id}, 扫描={total_scanned}, 信号={total_signals}") + except Exception as e: + logger.error(f"更新扫描会话统计失败: {e}") + raise + def save_stock_signal(self, session_id: int, strategy_id: int, signal: Dict[str, Any]) -> int: """保存股票信号""" try: @@ -253,19 +336,17 @@ class MySQLDatabaseManager: def get_latest_signals(self, strategy_name: str = None, limit: int = 100) -> pd.DataFrame: """获取最新信号""" try: - with pymysql.connect(**self.connection_params) as conn: - sql = "SELECT * FROM latest_signals_view" - params = [] + sql = "SELECT * FROM latest_signals_view" + params = [] - if strategy_name: - sql += " WHERE strategy_name = %s" - params.append(strategy_name) + if strategy_name: + sql += " WHERE strategy_name = %s" + params.append(strategy_name) - sql += " LIMIT %s" - params.append(limit) + sql += " LIMIT %s" + params.append(limit) - df = pd.read_sql_query(sql, conn, params=params) - return df + return self._execute_query_with_engine(sql, params) except Exception as e: logger.error(f"获取最新信号失败: {e}") @@ -275,28 +356,26 @@ class MySQLDatabaseManager: strategy_name: str = None, timeframe: str = None) -> pd.DataFrame: """按日期范围获取信号""" try: - with pymysql.connect(**self.connection_params) as conn: - if end_date is None: - end_date = date.today() + if end_date is None: + end_date = date.today() - sql = """ - SELECT * FROM latest_signals_view - WHERE signal_date >= %s AND signal_date <= %s - """ - params = [start_date, end_date] + sql = """ + SELECT * FROM latest_signals_view + WHERE signal_date >= %s AND signal_date <= %s + """ + params = [start_date, end_date] - if strategy_name: - sql += " AND strategy_name = %s" - params.append(strategy_name) + if strategy_name: + sql += " AND strategy_name = %s" + params.append(strategy_name) - if timeframe: - sql += " AND timeframe = %s" - params.append(timeframe) + if timeframe: + sql += " AND timeframe = %s" + params.append(timeframe) - sql += " ORDER BY signal_date DESC, scan_time DESC" + sql += " ORDER BY signal_date DESC, scan_time DESC" - df = pd.read_sql_query(sql, conn, params=params) - return df + return self._execute_query_with_engine(sql, params) except Exception as e: logger.error(f"按日期范围获取信号失败: {e}") @@ -305,9 +384,8 @@ class MySQLDatabaseManager: def get_strategy_stats(self) -> pd.DataFrame: """获取策略统计""" try: - with pymysql.connect(**self.connection_params) as conn: - df = pd.read_sql_query("SELECT * FROM strategy_stats_view", conn) - return df + sql = "SELECT * FROM strategy_stats_view" + return self._execute_query_with_engine(sql) except Exception as e: logger.error(f"获取策略统计失败: {e}") @@ -316,17 +394,15 @@ class MySQLDatabaseManager: def get_pullback_alerts(self, days: int = 30) -> pd.DataFrame: """获取回踩提醒""" try: - with pymysql.connect(**self.connection_params) as conn: - cutoff_date = date.today() - timedelta(days=days) + cutoff_date = date.today() - timedelta(days=days) - sql = """ - SELECT * FROM pullback_alerts - WHERE pullback_date >= %s - ORDER BY pullback_date DESC - """ + sql = """ + SELECT * FROM pullback_alerts + WHERE pullback_date >= %s + ORDER BY pullback_date DESC + """ - df = pd.read_sql_query(sql, conn, params=[cutoff_date]) - return df + return self._execute_query_with_engine(sql, [cutoff_date]) except Exception as e: logger.error(f"获取回踩提醒失败: {e}") diff --git a/src/database/mysql_schema.sql b/src/database/mysql_schema.sql index ec3b47b..6ebc91e 100644 --- a/src/database/mysql_schema.sql +++ b/src/database/mysql_schema.sql @@ -108,8 +108,8 @@ CREATE TABLE IF NOT EXISTS pullback_alerts ( INDEX idx_pullback_date (pullback_date) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; --- 创建索引以提高查询性能 -CREATE INDEX idx_scan_sessions_scan_date ON scan_sessions (scan_date); +-- 创建索引以提高查询性能(忽略重复索引错误) +-- CREATE INDEX idx_scan_sessions_scan_date ON scan_sessions (scan_date); -- 创建视图:最新信号概览 CREATE OR REPLACE VIEW latest_signals_view AS diff --git a/src/execution/strategy_executor.py b/src/execution/strategy_executor.py new file mode 100644 index 0000000..117aa78 --- /dev/null +++ b/src/execution/strategy_executor.py @@ -0,0 +1,382 @@ +""" +策略执行器 +负责协调股票池获取和策略分析的执行流程 +""" + +from typing import Dict, List, Any, Optional +import time +from datetime import datetime +from loguru import logger + +from src.data.stock_pool_manager import StockPoolManager +from src.strategy.base_strategy import BaseStrategy, StrategyResult +from src.utils.notification import NotificationManager + + +class ExecutionResult: + """执行结果""" + + def __init__(self, + task_id: str, + strategy_name: str, + stock_pool_rule: str, + start_time: datetime): + """ + 初始化执行结果 + + Args: + task_id: 任务ID + strategy_name: 策略名称 + stock_pool_rule: 股票池规则 + start_time: 开始时间 + """ + self.task_id = task_id + self.strategy_name = strategy_name + self.stock_pool_rule = stock_pool_rule + self.start_time = start_time + self.end_time: Optional[datetime] = None + + # 股票池信息 + self.stock_pool_info: Dict[str, Any] = {} + self.stock_list: List[str] = [] + + # 策略分析结果 + self.strategy_results: Dict[str, Dict[str, StrategyResult]] = {} + + # 统计信息 + self.total_stocks_analyzed = 0 + self.total_signals_found = 0 + self.stocks_with_signals = 0 + self.execution_time = 0.0 + + # 执行状态 + self.success = False + self.error: Optional[str] = None + + def complete(self, success: bool = True, error: str = None): + """完成执行""" + self.end_time = datetime.now() + self.execution_time = (self.end_time - self.start_time).total_seconds() + self.success = success + self.error = error + + if success: + logger.info(f"✅ 执行完成: {self.task_id} (耗时: {self.execution_time:.2f}秒)") + else: + logger.error(f"❌ 执行失败: {self.task_id} - {error}") + + def add_strategy_results(self, results: Dict[str, Dict[str, StrategyResult]], actual_analyzed_count: int = None): + """添加策略分析结果""" + self.strategy_results = results + # 使用实际分析的股票数量,如果没有提供则使用结果数量 + if actual_analyzed_count is not None: + self.total_stocks_analyzed = actual_analyzed_count + else: + self.total_stocks_analyzed = len(self.stock_list) + self.stocks_with_signals = len(results) + + # 统计信号总数 + for stock_results in results.values(): + for timeframe_result in stock_results.values(): + self.total_signals_found += timeframe_result.get_signal_count() + + def get_summary(self) -> Dict[str, Any]: + """获取执行摘要""" + return { + 'task_id': self.task_id, + 'strategy_name': self.strategy_name, + 'stock_pool_rule': self.stock_pool_rule, + 'stock_pool_rule_display': self.stock_pool_info.get('rule_display_name', ''), + 'start_time': self.start_time.isoformat(), + 'end_time': self.end_time.isoformat() if self.end_time else None, + 'execution_time': self.execution_time, + 'success': self.success, + 'error': self.error, + 'total_stocks_in_pool': len(self.stock_list), + 'total_stocks_analyzed': self.total_stocks_analyzed, + 'stocks_with_signals': self.stocks_with_signals, + 'total_signals_found': self.total_signals_found + } + + +class StrategyExecutor: + """策略执行器""" + + def __init__(self, + stock_pool_manager: StockPoolManager, + notification_manager: NotificationManager): + """ + 初始化策略执行器 + + Args: + stock_pool_manager: 股票池管理器 + notification_manager: 通知管理器 + """ + self.stock_pool_manager = stock_pool_manager + self.notification_manager = notification_manager + self.registered_strategies: Dict[str, BaseStrategy] = {} + + def register_strategy(self, strategy_id: str, strategy: BaseStrategy): + """ + 注册策略 + + Args: + strategy_id: 策略唯一标识 + strategy: 策略实例 + """ + self.registered_strategies[strategy_id] = strategy + logger.info(f"注册策略: {strategy_id} - {strategy.get_strategy_name()}") + + def execute_task(self, + task_id: str, + strategy_id: str, + stock_pool_rule: str, + stock_pool_params: Dict[str, Any] = None, + max_stocks: int = None, + send_notification: bool = True) -> ExecutionResult: + """ + 执行策略分析任务 + + Args: + task_id: 任务唯一标识 + strategy_id: 策略ID + stock_pool_rule: 股票池规则名称 + stock_pool_params: 股票池参数 + max_stocks: 最大分析股票数量 + send_notification: 是否发送通知 + + Returns: + 执行结果 + """ + start_time = datetime.now() + + # 检查策略是否已注册 + if strategy_id not in self.registered_strategies: + error = f"策略未注册: {strategy_id}" + logger.error(error) + result = ExecutionResult(task_id, strategy_id, stock_pool_rule, start_time) + result.complete(False, error) + return result + + strategy = self.registered_strategies[strategy_id] + result = ExecutionResult(task_id, strategy.get_strategy_name(), stock_pool_rule, start_time) + + logger.info(f"🚀 开始执行策略任务: {task_id}") + logger.info(f" 📊 策略: {strategy.get_strategy_name()}") + logger.info(f" 🎯 股票池规则: {stock_pool_rule}") + + try: + # 第1步: 获取股票池 + logger.info("📊 第1步: 获取股票池...") + stock_pool_info = self.stock_pool_manager.get_stock_pool( + stock_pool_rule, + **(stock_pool_params or {}) + ) + + if not stock_pool_info['success']: + error = f"股票池获取失败: {stock_pool_info.get('error', '未知错误')}" + result.complete(False, error) + return result + + stock_list = stock_pool_info['stocks'] + if not stock_list: + error = "股票池为空" + result.complete(False, error) + return result + + result.stock_pool_info = stock_pool_info + result.stock_list = stock_list + + logger.info(f"✅ 股票池获取成功: {stock_pool_info['rule_display_name']} - {len(stock_list)}只股票") + + # 第2步: 执行策略分析 + logger.info("🔍 第2步: 执行策略分析...") + strategy_results = strategy.analyze_stock_pool(stock_list, max_stocks) + + # 计算实际分析的股票数量(限制后的数量) + actual_analyzed = min(len(stock_list), max_stocks) if max_stocks else len(stock_list) + result.add_strategy_results(strategy_results, actual_analyzed) + + # 第3步: 发送通知 + if send_notification and result.total_signals_found > 0: + logger.info("📱 第3步: 发送分析结果通知...") + try: + self._send_notification(result, strategy_results) + except Exception as e: + logger.warning(f"发送通知失败: {e}") + + # 完成执行 + result.complete(True) + + except Exception as e: + result.complete(False, str(e)) + + return result + + def _send_notification(self, result: ExecutionResult, strategy_results: Dict[str, Dict[str, StrategyResult]]): + """发送分析结果通知""" + # 准备通知数据 + summary = result.get_summary() + + # 构建策略结果摘要 + strategy_summary = { + 'strategy_name': result.strategy_name, + 'stock_pool_source': result.stock_pool_info.get('rule_display_name', ''), + 'total_stocks': result.total_stocks_analyzed, + 'stocks_with_signals': result.stocks_with_signals, + 'total_signals': result.total_signals_found, + 'execution_time': result.execution_time + } + + # 发送策略摘要通知 + try: + # 转换StrategyResult格式为通知所需格式 + all_signals = {} + for stock_code, timeframe_results in strategy_results.items(): + for timeframe, strategy_result in timeframe_results.items(): + if strategy_result.get_signal_count() > 0: + if stock_code not in all_signals: + all_signals[stock_code] = {} + all_signals[stock_code][timeframe] = strategy_result.signals + + success = self.notification_manager.send_strategy_summary( + all_signals, + strategy_summary + ) + if success: + logger.info("📱 策略结果通知发送成功") + else: + logger.warning("📱 策略结果通知发送失败") + except Exception as e: + logger.error(f"发送策略摘要通知失败: {e}") + + def get_registered_strategies(self) -> Dict[str, str]: + """ + 获取已注册的策略列表 + + Returns: + 策略ID到策略名称的映射 + """ + return { + strategy_id: strategy.get_strategy_name() + for strategy_id, strategy in self.registered_strategies.items() + } + + def create_task_function(self, + strategy_id: str, + stock_pool_rule: str, + stock_pool_params: Dict[str, Any] = None, + max_stocks: int = None, + send_notification: bool = True): + """ + 创建用于任务调度的函数 + + Args: + strategy_id: 策略ID + stock_pool_rule: 股票池规则 + stock_pool_params: 股票池参数 + max_stocks: 最大股票数量 + send_notification: 是否发送通知 + + Returns: + 可调度的任务函数 + """ + def task_function(): + task_id = f"{strategy_id}_{stock_pool_rule}_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + return self.execute_task( + task_id=task_id, + strategy_id=strategy_id, + stock_pool_rule=stock_pool_rule, + stock_pool_params=stock_pool_params, + max_stocks=max_stocks, + send_notification=send_notification + ) + + return task_function + + +if __name__ == "__main__": + # 测试策略执行器 + from loguru import logger + import sys + + logger.remove() + logger.add(sys.stdout, level="INFO", format="{time:HH:mm:ss} | {level} | {message}") + + # 创建模拟组件 + from src.data.tushare_fetcher import TushareFetcher + from src.data.stock_pool_manager import StockPoolManager + from src.strategy.base_strategy import BaseStrategy, StrategyResult + from src.utils.config_loader import config_loader + + # 创建测试策略 + class TestStrategy(BaseStrategy): + def __init__(self): + super().__init__("测试K线策略", {"timeframes": ["daily"]}) + + def analyze_stock(self, stock_code: str, timeframes: List[str] = None) -> Dict[str, StrategyResult]: + results = {} + # 模拟分析,只有某些股票有信号 + signals = [] + if stock_code.endswith(".SH"): # 模拟上海股票有信号 + signals = [{ + 'date': '2024-01-10', + 'signal_type': '突破信号', + 'price': 15.50 + }] + + results['daily'] = StrategyResult( + strategy_name=self.strategy_name, + stock_code=stock_code, + timeframe='daily', + signals=signals, + execution_time=0.1 + ) + return results + + def get_strategy_description(self) -> str: + return "测试策略,用于验证执行器功能" + + print("=" * 80) + print("🚀 策略执行器测试") + print("=" * 80) + + # 初始化组件 + fetcher = TushareFetcher() + pool_manager = StockPoolManager(fetcher) + + notification_config = config_loader.get('notification', {}) + notification_manager = NotificationManager(notification_config) + + executor = StrategyExecutor(pool_manager, notification_manager) + + # 注册测试策略 + test_strategy = TestStrategy() + executor.register_strategy("test_kline", test_strategy) + + print("已注册策略:") + for strategy_id, strategy_name in executor.get_registered_strategies().items(): + print(f" {strategy_id}: {strategy_name}") + + # 执行测试任务 + print(f"\n执行测试任务...") + result = executor.execute_task( + task_id="test_task_001", + strategy_id="test_kline", + stock_pool_rule="tushare_hot", + stock_pool_params={"limit": 10}, + max_stocks=5, + send_notification=False + ) + + # 显示结果 + print(f"\n执行结果摘要:") + summary = result.get_summary() + for key, value in summary.items(): + print(f" {key}: {value}") + + print(f"\n策略结果详情:") + for stock_code, timeframe_results in result.strategy_results.items(): + for timeframe, strategy_result in timeframe_results.items(): + signals = strategy_result.get_signal_count() + print(f" {stock_code} ({timeframe}): {signals} 个信号") \ No newline at end of file diff --git a/src/execution/task_scheduler.py b/src/execution/task_scheduler.py new file mode 100644 index 0000000..3194b89 --- /dev/null +++ b/src/execution/task_scheduler.py @@ -0,0 +1,396 @@ +""" +任务调度器 +负责定时执行策略分析任务 +""" + +import schedule +import time +import threading +from typing import Dict, List, Callable, Any, Optional +from datetime import datetime, timedelta +from loguru import logger +from dataclasses import dataclass +from enum import Enum + + +class TaskStatus(Enum): + """任务状态""" + PENDING = "pending" # 等待中 + RUNNING = "running" # 运行中 + COMPLETED = "completed" # 已完成 + FAILED = "failed" # 失败 + DISABLED = "disabled" # 已禁用 + + +@dataclass +class TaskExecution: + """任务执行记录""" + task_id: str + start_time: datetime + end_time: Optional[datetime] = None + status: TaskStatus = TaskStatus.RUNNING + result: Any = None + error: Optional[str] = None + duration: Optional[float] = None + + def complete(self, result: Any = None): + """标记任务完成""" + self.end_time = datetime.now() + self.status = TaskStatus.COMPLETED + self.result = result + if self.start_time: + self.duration = (self.end_time - self.start_time).total_seconds() + + def fail(self, error: str): + """标记任务失败""" + self.end_time = datetime.now() + self.status = TaskStatus.FAILED + self.error = error + if self.start_time: + self.duration = (self.end_time - self.start_time).total_seconds() + + +class ScheduledTask: + """调度任务""" + + def __init__(self, + task_id: str, + name: str, + func: Callable, + schedule_rule: str, + enabled: bool = True, + max_history: int = 50): + """ + 初始化调度任务 + + Args: + task_id: 任务唯一标识 + name: 任务名称 + func: 执行函数 + schedule_rule: 调度规则 (如: "every 30 minutes", "daily at 09:30") + enabled: 是否启用 + max_history: 最大历史记录数量 + """ + self.task_id = task_id + self.name = name + self.func = func + self.schedule_rule = schedule_rule + self.enabled = enabled + self.max_history = max_history + + # 执行历史 + self.execution_history: List[TaskExecution] = [] + self.current_execution: Optional[TaskExecution] = None + + # 统计信息 + self.total_executions = 0 + self.successful_executions = 0 + self.failed_executions = 0 + self.last_execution_time: Optional[datetime] = None + self.next_execution_time: Optional[datetime] = None + + def execute(self) -> TaskExecution: + """执行任务""" + if not self.enabled: + logger.warning(f"任务已禁用: {self.name}") + return None + + # 创建执行记录 + execution = TaskExecution( + task_id=self.task_id, + start_time=datetime.now() + ) + self.current_execution = execution + + logger.info(f"🚀 开始执行任务: {self.name}") + + try: + # 执行任务函数 + result = self.func() + execution.complete(result) + + logger.info(f"✅ 任务执行成功: {self.name} (耗时: {execution.duration:.2f}秒)") + self.successful_executions += 1 + + except Exception as e: + error_msg = str(e) + execution.fail(error_msg) + + logger.error(f"❌ 任务执行失败: {self.name} - {error_msg}") + self.failed_executions += 1 + + finally: + # 更新统计信息 + self.total_executions += 1 + self.last_execution_time = execution.start_time + self.current_execution = None + + # 添加到历史记录 + self.execution_history.append(execution) + + # 限制历史记录数量 + if len(self.execution_history) > self.max_history: + self.execution_history = self.execution_history[-self.max_history:] + + return execution + + def get_success_rate(self) -> float: + """获取成功率""" + if self.total_executions == 0: + return 0.0 + return self.successful_executions / self.total_executions * 100 + + def get_status(self) -> TaskStatus: + """获取任务状态""" + if not self.enabled: + return TaskStatus.DISABLED + if self.current_execution: + return TaskStatus.RUNNING + return TaskStatus.PENDING + + def get_last_execution(self) -> Optional[TaskExecution]: + """获取最后一次执行记录""" + return self.execution_history[-1] if self.execution_history else None + + +class TaskScheduler: + """任务调度器""" + + def __init__(self): + """初始化调度器""" + self.tasks: Dict[str, ScheduledTask] = {} + self.running = False + self.scheduler_thread: Optional[threading.Thread] = None + + def add_task(self, + task_id: str, + name: str, + func: Callable, + schedule_rule: str, + enabled: bool = True) -> bool: + """ + 添加调度任务 + + Args: + task_id: 任务唯一标识 + name: 任务名称 + func: 执行函数 + schedule_rule: 调度规则 + enabled: 是否启用 + + Returns: + 是否添加成功 + """ + try: + # 创建任务 + task = ScheduledTask(task_id, name, func, schedule_rule, enabled) + + # 解析并设置调度规则 + self._parse_schedule_rule(task, schedule_rule) + + # 添加到任务列表 + self.tasks[task_id] = task + + logger.info(f"✅ 添加调度任务: {name} - {schedule_rule}") + return True + + except Exception as e: + logger.error(f"❌ 添加调度任务失败: {name} - {e}") + return False + + def _parse_schedule_rule(self, task: ScheduledTask, rule: str): + """ + 解析调度规则并设置到schedule库 + + Args: + task: 任务实例 + rule: 调度规则字符串 + """ + rule = rule.lower().strip() + + # 解析不同的调度规则格式 + if rule.startswith("every"): + # 格式: "every 30 minutes", "every 1 hour", "every 2 days" + parts = rule.split() + if len(parts) >= 3: + interval = int(parts[1]) + unit = parts[2].rstrip('s') # 移除复数s + + if unit == "second": + schedule.every(interval).seconds.do(task.execute) + elif unit == "minute": + schedule.every(interval).minutes.do(task.execute) + elif unit == "hour": + schedule.every(interval).hours.do(task.execute) + elif unit == "day": + schedule.every(interval).days.do(task.execute) + elif unit == "week": + schedule.every(interval).weeks.do(task.execute) + + elif "daily at" in rule: + # 格式: "daily at 09:30", "daily at 14:00" + time_str = rule.split("at")[1].strip() + schedule.every().day.at(time_str).do(task.execute) + + elif "weekdays at" in rule: + # 格式: "weekdays at 09:30" + time_str = rule.split("at")[1].strip() + schedule.every().monday.at(time_str).do(task.execute) + schedule.every().tuesday.at(time_str).do(task.execute) + schedule.every().wednesday.at(time_str).do(task.execute) + schedule.every().thursday.at(time_str).do(task.execute) + schedule.every().friday.at(time_str).do(task.execute) + + elif "monday at" in rule: + time_str = rule.split("at")[1].strip() + schedule.every().monday.at(time_str).do(task.execute) + + elif "tuesday at" in rule: + time_str = rule.split("at")[1].strip() + schedule.every().tuesday.at(time_str).do(task.execute) + + # 可以继续添加更多规则... + + else: + raise ValueError(f"不支持的调度规则: {rule}") + + def remove_task(self, task_id: str) -> bool: + """ + 移除任务 + + Args: + task_id: 任务ID + + Returns: + 是否移除成功 + """ + if task_id in self.tasks: + # 清除schedule中的任务 + schedule.clear(tag=task_id) + del self.tasks[task_id] + logger.info(f"✅ 移除任务: {task_id}") + return True + else: + logger.warning(f"⚠️ 任务不存在: {task_id}") + return False + + def enable_task(self, task_id: str) -> bool: + """启用任务""" + if task_id in self.tasks: + self.tasks[task_id].enabled = True + logger.info(f"✅ 启用任务: {task_id}") + return True + return False + + def disable_task(self, task_id: str) -> bool: + """禁用任务""" + if task_id in self.tasks: + self.tasks[task_id].enabled = False + logger.info(f"⏸️ 禁用任务: {task_id}") + return True + return False + + def start(self): + """启动调度器""" + if self.running: + logger.warning("调度器已在运行中") + return + + self.running = True + self.scheduler_thread = threading.Thread(target=self._run_scheduler, daemon=True) + self.scheduler_thread.start() + + logger.info(f"🚀 任务调度器已启动,共 {len(self.tasks)} 个任务") + + def stop(self): + """停止调度器""" + self.running = False + if self.scheduler_thread: + self.scheduler_thread.join(timeout=5) + + logger.info("⏹️ 任务调度器已停止") + + def _run_scheduler(self): + """运行调度器主循环""" + logger.info("📅 调度器主循环已启动") + + while self.running: + try: + schedule.run_pending() + time.sleep(1) + except Exception as e: + logger.error(f"调度器运行异常: {e}") + time.sleep(5) + + def get_task_status(self) -> Dict[str, Dict[str, Any]]: + """获取所有任务状态""" + status = {} + for task_id, task in self.tasks.items(): + last_execution = task.get_last_execution() + status[task_id] = { + 'name': task.name, + 'schedule_rule': task.schedule_rule, + 'enabled': task.enabled, + 'status': task.get_status().value, + 'total_executions': task.total_executions, + 'success_rate': task.get_success_rate(), + 'last_execution_time': task.last_execution_time.isoformat() if task.last_execution_time else None, + 'last_execution_result': last_execution.status.value if last_execution else None + } + return status + + def execute_task_now(self, task_id: str) -> bool: + """立即执行指定任务""" + if task_id not in self.tasks: + logger.error(f"任务不存在: {task_id}") + return False + + task = self.tasks[task_id] + execution = task.execute() + return execution.status == TaskStatus.COMPLETED if execution else False + + +if __name__ == "__main__": + # 测试任务调度器 + from loguru import logger + import sys + + logger.remove() + logger.add(sys.stdout, level="INFO", format="{time:HH:mm:ss} | {level} | {message}") + + def test_task_1(): + """测试任务1""" + logger.info("执行测试任务1") + time.sleep(0.5) + return {"result": "task1_completed"} + + def test_task_2(): + """测试任务2""" + logger.info("执行测试任务2") + return {"result": "task2_completed"} + + # 创建调度器 + scheduler = TaskScheduler() + + # 添加任务 + scheduler.add_task("task1", "测试任务1", test_task_1, "every 10 seconds") + scheduler.add_task("task2", "测试任务2", test_task_2, "every 30 seconds") + + print("=" * 60) + print("📅 任务调度器测试") + print("=" * 60) + + # 显示任务状态 + print("任务状态:") + for task_id, status in scheduler.get_task_status().items(): + print(f" {task_id}: {status['name']} - {status['schedule_rule']}") + + # 立即执行任务测试 + print("\n立即执行任务测试:") + scheduler.execute_task_now("task1") + + # 可以取消下面的注释来测试定时执行 + # print("\n启动调度器...") + # scheduler.start() + # time.sleep(60) # 运行1分钟 + # scheduler.stop() \ No newline at end of file diff --git a/src/strategy/base_strategy.py b/src/strategy/base_strategy.py new file mode 100644 index 0000000..c9f76d6 --- /dev/null +++ b/src/strategy/base_strategy.py @@ -0,0 +1,257 @@ +""" +策略基类 +定义所有技术分析策略的通用接口 +""" + +from abc import ABC, abstractmethod +from typing import List, Dict, Any, Optional +import pandas as pd +from loguru import logger +from datetime import datetime + + +class StrategyResult: + """策略执行结果""" + + def __init__(self, + strategy_name: str, + stock_code: str, + timeframe: str, + signals: List[Dict[str, Any]] = None, + success: bool = True, + error: str = None, + execution_time: float = 0.0): + """ + 初始化策略结果 + + Args: + strategy_name: 策略名称 + stock_code: 股票代码 + timeframe: 时间周期 + signals: 信号列表 + success: 是否成功 + error: 错误信息 + execution_time: 执行时间 + """ + self.strategy_name = strategy_name + self.stock_code = stock_code + self.timeframe = timeframe + self.signals = signals or [] + self.success = success + self.error = error + self.execution_time = execution_time + self.timestamp = datetime.now() + + def has_signals(self) -> bool: + """是否有信号""" + return len(self.signals) > 0 + + def get_signal_count(self) -> int: + """获取信号数量""" + return len(self.signals) + + def to_dict(self) -> Dict[str, Any]: + """转换为字典""" + return { + 'strategy_name': self.strategy_name, + 'stock_code': self.stock_code, + 'timeframe': self.timeframe, + 'signals': self.signals, + 'signal_count': self.get_signal_count(), + 'success': self.success, + 'error': self.error, + 'execution_time': self.execution_time, + 'timestamp': self.timestamp.isoformat() + } + + +class BaseStrategy(ABC): + """策略基类""" + + def __init__(self, strategy_name: str, config: Dict[str, Any]): + """ + 初始化策略 + + Args: + strategy_name: 策略名称 + config: 策略配置 + """ + self.strategy_name = strategy_name + self.config = config + self.timeframes = config.get('timeframes', ['daily']) + + @abstractmethod + def analyze_stock(self, stock_code: str, timeframes: List[str] = None) -> Dict[str, StrategyResult]: + """ + 分析单只股票 + + Args: + stock_code: 股票代码 + timeframes: 时间周期列表,如果为None则使用策略默认周期 + + Returns: + 时间周期到策略结果的映射 + """ + pass + + def analyze_stock_pool(self, stock_list: List[str], max_stocks: int = None) -> Dict[str, Dict[str, StrategyResult]]: + """ + 分析股票池 + + Args: + stock_list: 股票代码列表 + max_stocks: 最大分析股票数量 + + Returns: + 股票代码到时间周期结果映射的字典 + """ + if max_stocks and len(stock_list) > max_stocks: + stock_list = stock_list[:max_stocks] + logger.info(f"限制分析股票数量为: {max_stocks}") + + results = {} + total_signals = 0 + + logger.info(f"🔍 开始分析股票池,共 {len(stock_list)} 只股票") + + for i, stock_code in enumerate(stock_list, 1): + try: + logger.info(f"⏳ 分析进度: [{i:3d}/{len(stock_list):3d}] 🔍 {stock_code}") + + # 分析单只股票 + stock_results = self.analyze_stock(stock_code) + + # 统计信号数量 + stock_signal_count = sum( + result.get_signal_count() for result in stock_results.values() + ) + + if stock_signal_count > 0: + results[stock_code] = stock_results + total_signals += stock_signal_count + logger.info(f"✅ {stock_code} 发现 {stock_signal_count} 个信号") + + except Exception as e: + logger.error(f"❌ 分析股票 {stock_code} 失败: {e}") + continue + + logger.info(f"🎉 股票池分析完成: 扫描 {len(stock_list)} 只,发现 {total_signals} 个信号,涉及 {len(results)} 只股票") + + return results + + def get_strategy_name(self) -> str: + """获取策略名称""" + return self.strategy_name + + def get_timeframes(self) -> List[str]: + """获取支持的时间周期""" + return self.timeframes + + def get_config(self) -> Dict[str, Any]: + """获取策略配置""" + return self.config + + @abstractmethod + def get_strategy_description(self) -> str: + """获取策略描述""" + pass + + def format_results_summary(self, results: Dict[str, Dict[str, StrategyResult]]) -> str: + """ + 格式化结果摘要 + + Args: + results: 分析结果 + + Returns: + 格式化的摘要字符串 + """ + if not results: + return f"📊 {self.strategy_name}: 未发现任何信号" + + total_stocks = len(results) + total_signals = sum( + sum(result.get_signal_count() for result in stock_results.values()) + for stock_results in results.values() + ) + + summary = f"📊 {self.strategy_name} 分析结果:\n" + summary += f" 🎯 发现信号: {total_signals} 个\n" + summary += f" 📈 涉及股票: {total_stocks} 只\n" + + # 按信号数量排序显示前5只股票 + stock_signal_counts = [] + for stock_code, stock_results in results.items(): + signal_count = sum(result.get_signal_count() for result in stock_results.values()) + stock_signal_counts.append((stock_code, signal_count)) + + stock_signal_counts.sort(key=lambda x: x[1], reverse=True) + + summary += f"\n 🔥 信号最多的股票:\n" + for i, (stock_code, signal_count) in enumerate(stock_signal_counts[:5], 1): + summary += f" {i}. {stock_code}: {signal_count} 个信号\n" + + return summary + + +if __name__ == "__main__": + # 测试策略基类功能 + from loguru import logger + import sys + + logger.remove() + logger.add(sys.stdout, level="INFO", format="{time:HH:mm:ss} | {level} | {message}") + + # 创建一个测试策略 + class TestStrategy(BaseStrategy): + def __init__(self): + super().__init__("测试策略", {"timeframes": ["daily", "weekly"]}) + + def analyze_stock(self, stock_code: str, timeframes: List[str] = None) -> Dict[str, StrategyResult]: + # 模拟分析结果 + results = {} + for timeframe in self.timeframes: + # 模拟生成一些信号 + signals = [] + if stock_code.startswith("000001"): # 模拟只有某些股票有信号 + signals = [ + { + 'date': '2024-01-10', + 'signal_type': '买入信号', + 'price': 10.50, + 'confidence': 0.85 + } + ] + + results[timeframe] = StrategyResult( + strategy_name=self.strategy_name, + stock_code=stock_code, + timeframe=timeframe, + signals=signals, + execution_time=0.1 + ) + + return results + + def get_strategy_description(self) -> str: + return "这是一个测试策略,用于验证基类功能" + + # 测试 + print("=" * 60) + print("📊 策略基类测试") + print("=" * 60) + + strategy = TestStrategy() + test_stocks = ["000001.SZ", "000002.SZ", "600000.SH"] + + print(f"策略名称: {strategy.get_strategy_name()}") + print(f"支持周期: {strategy.get_timeframes()}") + print(f"策略描述: {strategy.get_strategy_description()}") + + # 测试股票池分析 + results = strategy.analyze_stock_pool(test_stocks) + print(f"\n{strategy.format_results_summary(results)}") + + # 测试单股分析 + single_result = strategy.analyze_stock("000001.SZ") + print(f"\n单股分析结果: {single_result['daily'].to_dict()}") \ No newline at end of file diff --git a/src/strategy/kline_pattern_strategy.py b/src/strategy/kline_pattern_strategy.py index a210102..63ec68c 100644 --- a/src/strategy/kline_pattern_strategy.py +++ b/src/strategy/kline_pattern_strategy.py @@ -9,15 +9,16 @@ from typing import Dict, List, Tuple, Optional, Any from datetime import datetime, timedelta from loguru import logger -from ..data.tushare_fetcher import TushareFetcher as ADataFetcher +from ..data.tushare_fetcher import TushareFetcher from ..utils.notification import NotificationManager from ..database.mysql_database_manager import MySQLDatabaseManager +from .base_strategy import BaseStrategy, StrategyResult -class KLinePatternStrategy: +class KLinePatternStrategy(BaseStrategy): """K线形态策略类""" - def __init__(self, data_fetcher: ADataFetcher, notification_manager: NotificationManager, + def __init__(self, data_fetcher: TushareFetcher, notification_manager: NotificationManager, config: Dict[str, Any], db_manager: MySQLDatabaseManager = None): """ 初始化K线形态策略 @@ -28,17 +29,22 @@ class KLinePatternStrategy: config: 策略配置 db_manager: 数据库管理器 """ + # 初始化基类 + super().__init__("K线形态策略", config) + self.data_fetcher = data_fetcher self.notification_manager = notification_manager - self.config = config self.db_manager = db_manager or MySQLDatabaseManager() # 策略参数 - self.strategy_name = "K线形态策略" self.min_entity_ratio = config.get('min_entity_ratio', 0.55) # 前两根阳线实体部分最小比例 self.final_yang_min_ratio = config.get('final_yang_min_ratio', 0.40) # 最后阳线实体部分最小比例 self.max_turnover_ratio = config.get('max_turnover_ratio', 40.0) # 最后阳线最大换手率(%) - self.timeframes = config.get('timeframes', ['daily', 'weekly']) # 支持的时间周期 + + # 热榜缓存机制 + self._hot_stocks_cache = None + self._cache_timestamp = None + self._cache_duration = 300 # 缓存5分钟 # 回踩监控参数 self.pullback_tolerance = config.get('pullback_tolerance', 0.02) # 回踩容忍度(2%) @@ -51,10 +57,6 @@ class KLinePatternStrategy: # 格式: {stock_code: {'signals': [signal_dict], 'last_check_date': date}} self.triggered_signals = {} - # 新增:存储待确认的模式(等待回踩确认) - # 格式: {stock_code: {'pending_patterns': [pattern_dict], 'last_check_date': date}} - self.pending_patterns = {} - # 确保策略在数据库中存在 self.strategy_id = self.db_manager.create_or_update_strategy( strategy_name=self.strategy_name, @@ -65,6 +67,47 @@ class KLinePatternStrategy: logger.info(f"K线形态策略初始化完成 (策略ID: {self.strategy_id})") + def _get_cached_hot_stocks(self, max_stocks: int) -> List[str]: + """ + 获取缓存的热榜股票列表 + + Args: + max_stocks: 最大股票数量 + + Returns: + 股票代码列表 + """ + import time + current_time = time.time() + + # 检查缓存是否有效 + if (self._hot_stocks_cache is not None and + self._cache_timestamp is not None and + current_time - self._cache_timestamp < self._cache_duration): + + logger.info(f"🔄 使用缓存的同花顺热榜数据 ({len(self._hot_stocks_cache)} 只股票)") + return self._hot_stocks_cache[:max_stocks] + + # 缓存失效或不存在,重新获取 + logger.info(f"🔥 获取同花顺热榜股票 (前{max_stocks}只)...") + hot_stocks = self.data_fetcher.get_hot_stocks_ths(limit=max_stocks * 2) # 多获取一些以备不同max_stocks需求 + + if not hot_stocks.empty and 'stock_code' in hot_stocks.columns: + self._hot_stocks_cache = hot_stocks['stock_code'].tolist() + self._cache_timestamp = current_time + + logger.info(f"✅ 同花顺热榜获取成功,已缓存 {len(self._hot_stocks_cache)} 只股票") + return self._hot_stocks_cache[:max_stocks] + else: + logger.error("❌ 同花顺热榜数据为空") + return [] + + def clear_hot_stocks_cache(self): + """清除热榜缓存,强制下次重新获取""" + self._hot_stocks_cache = None + self._cache_timestamp = None + logger.info("🔄 热榜缓存已清除") + def calculate_kline_features(self, df: pd.DataFrame) -> pd.DataFrame: """ 计算K线特征指标 @@ -117,47 +160,6 @@ class KLinePatternStrategy: return df - def check_follow_up_strength(self, df: pd.DataFrame, pattern_index: int, yin_high: float, ema20_price: float) -> bool: - """ - 检查形态发生后的后续强势条件 - - Args: - df: K线数据DataFrame - pattern_index: 形态完成的索引位置(第4根阳线的位置) - yin_high: 阴线的最高价 - ema20_price: 形态完成时的EMA20价格 - - Returns: - bool: 是否满足后续强势条件 - """ - # 获取配置的后续验证天数,默认为3天 - follow_up_days = self.config.get('follow_up_days', 3) - - # 检查是否有足够的后续数据 - if pattern_index + follow_up_days >= len(df): - # 如果没有足够的后续数据,说明是最新的形态,暂时通过验证 - return True - - # 检查后续N天的K线 - for j in range(1, follow_up_days + 1): - if pattern_index + j >= len(df): - break - - next_kline = df.iloc[pattern_index + j] - - # 检查1: 最低价不能回踩阴线最高价 - if next_kline['low'] <= yin_high: - logger.debug(f"后续第{j}天最低价{next_kline['low']:.2f}回踩阴线最高价{yin_high:.2f},不符合强势条件") - return False - - # 检查2: 收盘价不能跌破EMA20 - next_ema20 = next_kline.get('ema20', ema20_price) - if next_kline['close'] <= next_ema20: - logger.debug(f"后续第{j}天收盘价{next_kline['close']:.2f}跌破EMA20价格{next_ema20:.2f},不符合强势条件") - return False - - logger.debug(f"后续{follow_up_days}天强势验证通过") - return True def detect_potential_pattern(self, df: pd.DataFrame) -> List[Dict[str, Any]]: """ @@ -382,6 +384,52 @@ class KLinePatternStrategy: return None + def _filter_recent_signals(self, signals: List[Dict[str, Any]], days: int = 7) -> List[Dict[str, Any]]: + """ + 过滤最近N天内产生的信号 + + Args: + signals: 信号列表 + days: 最近天数,默认7天 + + Returns: + 过滤后的信号列表 + """ + if not signals: + return signals + + current_date = datetime.now().date() + recent_signals = [] + + for signal in signals: + signal_date = signal.get('confirmation_date') or signal.get('date') + + # 处理不同的日期格式 + if isinstance(signal_date, str): + try: + signal_date = pd.to_datetime(signal_date).date() + except: + continue + elif hasattr(signal_date, 'date'): + signal_date = signal_date.date() + elif not isinstance(signal_date, datetime.date): + continue + + # 计算信号距今天数 + days_ago = (current_date - signal_date).days + + # 只保留最近N天内的信号 + if days_ago <= days: + recent_signals.append(signal) + logger.debug(f"✅ 保留近期信号: {signal_date} (距今{days_ago}天)") + else: + logger.debug(f"🗓️ 过滤历史信号: {signal_date} (距今{days_ago}天)") + + if len(recent_signals) != len(signals): + logger.info(f"📅 信号过滤: 总共{len(signals)}个 → 近{days}天内{len(recent_signals)}个") + + return recent_signals + def detect_pattern(self, df: pd.DataFrame) -> List[Dict[str, Any]]: """ 检测"两阳线+阴线+阳线"形态(创新高回踩确认逻辑) @@ -410,40 +458,45 @@ class KLinePatternStrategy: return confirmed_signals - def analyze_stock(self, stock_code: str, stock_name: str = None, days: int = 60, - session_id: Optional[int] = None) -> Dict[str, List[Dict[str, Any]]]: + def analyze_stock(self, stock_code: str, timeframes: List[str] = None, session_id: int = None) -> Dict[str, StrategyResult]: """ - 分析单只股票的K线形态 + 分析单只股票的K线形态 - 继承自BaseStrategy Args: stock_code: 股票代码 - stock_name: 股票名称 - days: 分析的天数 + timeframes: 时间周期列表,如果为None则使用策略默认周期 Returns: - 各时间周期的信号字典 + 时间周期到策略结果的映射 """ + if timeframes is None: + timeframes = self.timeframes + results = {} + stock_name = self.data_fetcher.get_stock_name(stock_code) - if stock_name is None: - # 尝试获取股票中文名称 - stock_name = self.data_fetcher.get_stock_name(stock_code) - - try: - # 计算开始日期,针对不同周期调整时间范围 - end_date = datetime.now().strftime('%Y-%m-%d') - - for timeframe in self.timeframes: - analysis_days = days - start_date = (datetime.now() - timedelta(days=analysis_days)).strftime('%Y-%m-%d') + for timeframe in timeframes: + start_time = datetime.now() + try: + # 计算开始日期 + end_date = datetime.now().strftime('%Y-%m-%d') + start_date = (datetime.now() - timedelta(days=60)).strftime('%Y-%m-%d') logger.info(f"🔍 分析股票: {stock_code}({stock_name}) | 周期: {timeframe}") - # 获取历史数据 - 直接使用adata的原生周期支持 + # 获取历史数据 df = self.data_fetcher.get_historical_data(stock_code, start_date, end_date, timeframe) if df.empty: logger.warning(f"{stock_code} {timeframe} 数据为空") - results[timeframe] = [] + results[timeframe] = StrategyResult( + strategy_name=self.strategy_name, + stock_code=stock_code, + timeframe=timeframe, + signals=[], + success=False, + error="数据为空", + execution_time=(datetime.now() - start_time).total_seconds() + ) continue # 计算K线特征 @@ -452,16 +505,46 @@ class KLinePatternStrategy: # 检测形态 signals = self.detect_pattern(df_with_features) - # 处理信号 - for signal in signals: + # 过滤一周内的信号 + recent_signals = self._filter_recent_signals(signals, days=7) + + # 处理信号格式 + formatted_signals = [] + for signal in recent_signals: + formatted_signal = { + 'date': signal['date'], + 'signal_type': signal['pattern_type'], + 'price': signal['breakout_price'], + 'confidence': signal['final_yang_entity_ratio'], + 'stock_name': stock_name, # 添加股票名称 + 'details': { + 'yin_high': signal['yin_high'], + 'breakout_amount': signal['breakout_amount'], + 'breakout_pct': signal['breakout_pct'], + 'ema20_price': signal['ema20_price'], + 'turnover_ratio': signal['turnover_ratio'] + } + } + + # 如果是确认信号,添加确认信息 + if not signal.get('confirmation_pending', True): + formatted_signal['details'].update({ + 'new_high_price': signal.get('new_high_price'), + 'new_high_date': signal.get('new_high_date'), + 'confirmation_date': signal.get('confirmation_date'), + 'confirmation_days': signal.get('confirmation_days'), + 'pullback_distance': signal.get('pullback_distance') + }) + + formatted_signals.append(formatted_signal) + + # 将信号添加到监控列表 signal['stock_code'] = stock_code signal['stock_name'] = stock_name signal['timeframe'] = timeframe - - # 将信号添加到监控列表 self.add_triggered_signal(signal) - # 保存信号到数据库(如果提供了session_id) + # 保存信号到数据库 if session_id is not None: try: signal_id = self.db_manager.save_stock_signal( @@ -469,28 +552,59 @@ class KLinePatternStrategy: strategy_id=self.strategy_id, signal=signal ) - signal['signal_id'] = signal_id - logger.debug(f"信号已保存到数据库: {stock_code} (ID: {signal_id})") + logger.debug(f"信号已保存到数据库: signal_id={signal_id}") except Exception as e: logger.error(f"保存信号到数据库失败: {e}") - results[timeframe] = signals + execution_time = (datetime.now() - start_time).total_seconds() + results[timeframe] = StrategyResult( + strategy_name=self.strategy_name, + stock_code=stock_code, + timeframe=timeframe, + signals=formatted_signals, + success=True, + execution_time=execution_time + ) # 美化信号统计日志 - if signals: - logger.info(f"✅ {stock_code}({stock_name}) {timeframe}周期: 发现 {len(signals)} 个信号") - for i, signal in enumerate(signals, 1): - logger.info(f" 📊 信号{i}: {signal['date']} | 价格: {signal['breakout_price']:.2f}元 | 实体: {signal['final_yang_entity_ratio']:.1%}") + if formatted_signals: + logger.info(f"✅ {stock_code}({stock_name}) {timeframe}周期: 发现 {len(formatted_signals)} 个信号") + for i, signal in enumerate(formatted_signals, 1): + logger.info(f" 📊 信号{i}: {signal['date']} | 价格: {signal['price']:.2f}元 | 置信度: {signal['confidence']:.1%}") else: logger.debug(f"📭 {stock_code}({stock_name}) {timeframe}周期: 无信号") - except Exception as e: - logger.error(f"分析股票 {stock_code} 失败: {e}") - for timeframe in self.timeframes: - results[timeframe] = [] + except Exception as e: + logger.error(f"分析股票 {stock_code} {timeframe}周期失败: {e}") + execution_time = (datetime.now() - start_time).total_seconds() + results[timeframe] = StrategyResult( + strategy_name=self.strategy_name, + stock_code=stock_code, + timeframe=timeframe, + signals=[], + success=False, + error=str(e), + execution_time=execution_time + ) return results + def get_strategy_description(self) -> str: + """获取策略描述""" + return f"""K线形态策略 - 两阳线+阴线+阳线突破(创新高回踩确认) + +该策略通过识别特定的K线形态来发现股票突破机会: +1. 识别连续4根K线:阳线 + 阳线 + 阴线 + 阳线 +2. 前两根阳线实体部分须占振幅的 {self.min_entity_ratio:.0%} 以上 +3. 最后阳线实体部分须占振幅的 {self.final_yang_min_ratio:.0%} 以上 +4. 最后阳线收盘价须高于阴线最高价(突破确认) +5. 最后阳线收盘价须在EMA20上方(趋势确认) +6. 最后阳线换手率不高于 {self.max_turnover_ratio:.1f}%(流动性约束) +7. 价格必须创新高后回踩到阴线最高点附近才产生正式信号 + +支持时间周期:{', '.join(self.timeframes)} +""" + def check_pullback_signals(self, stock_code: str, current_data: pd.DataFrame) -> List[Dict[str, Any]]: """ 检查已触发信号的价格回踩情况 @@ -618,109 +732,7 @@ class KLinePatternStrategy: self.triggered_signals[stock_code]['signals'] = \ self.triggered_signals[stock_code]['signals'][:max_signals_per_stock] - def add_pending_pattern(self, pattern: Dict[str, Any]): - """ - 添加待确认的模式到监控列表 - Args: - pattern: 潜在模式字典 - """ - stock_code = pattern.get('stock_code') - if not stock_code: - return - - if stock_code not in self.pending_patterns: - self.pending_patterns[stock_code] = { - 'pending_patterns': [], - 'last_check_date': datetime.now().date() - } - - # 添加模式到监控列表 - self.pending_patterns[stock_code]['pending_patterns'].append(pattern) - - # 只保留最近的模式(避免内存占用过多) - max_patterns_per_stock = 5 - if len(self.pending_patterns[stock_code]['pending_patterns']) > max_patterns_per_stock: - # 按日期排序,保留最新的模式 - self.pending_patterns[stock_code]['pending_patterns'].sort( - key=lambda x: pd.to_datetime(x['date']) if isinstance(x['date'], str) else x['date'], - reverse=True - ) - self.pending_patterns[stock_code]['pending_patterns'] = \ - self.pending_patterns[stock_code]['pending_patterns'][:max_patterns_per_stock] - - def monitor_pending_pattern_confirmations(self) -> List[Dict[str, Any]]: - """ - 监控待确认模式的回踩确认情况 - - Returns: - 新确认的信号列表 - """ - newly_confirmed_signals = [] - current_date = datetime.now().date() - - # 清理过期的模式 - stocks_to_remove = [] - for stock_code, pattern_info in self.pending_patterns.items(): - # 过滤掉过期的模式 - valid_patterns = [] - for pattern in pattern_info['pending_patterns']: - pattern_date = pattern['date'] - if isinstance(pattern_date, str): - pattern_date = pd.to_datetime(pattern_date).date() - elif hasattr(pattern_date, 'date'): - pattern_date = pattern_date.date() - - days_since_pattern = (current_date - pattern_date).days - if days_since_pattern <= self.pullback_confirmation_days: - valid_patterns.append(pattern) - - if valid_patterns: - self.pending_patterns[stock_code]['pending_patterns'] = valid_patterns - else: - stocks_to_remove.append(stock_code) - - # 移除没有有效模式的股票 - for stock_code in stocks_to_remove: - del self.pending_patterns[stock_code] - - logger.info(f"🔍 当前监控待确认模式的股票数量: {len(self.pending_patterns)}") - - # 检查每只股票的回踩确认情况 - for stock_code in self.pending_patterns.keys(): - try: - # 获取最近的数据(包含确认窗口) - end_date = current_date.strftime('%Y-%m-%d') - start_date = (current_date - timedelta(days=self.pullback_confirmation_days + 5)).strftime('%Y-%m-%d') - - current_data = self.data_fetcher.get_historical_data( - stock_code, start_date, end_date, 'daily' - ) - - if not current_data.empty: - # 计算K线特征 - df_with_features = self.calculate_kline_features(current_data) - - # 检查每个待确认模式 - for pattern in self.pending_patterns[stock_code]['pending_patterns']: - has_confirmation = self.check_pullback_confirmation(df_with_features, pattern) - - if has_confirmation: - # 创建确认信号 - confirmed_signal = pattern.copy() - confirmed_signal['pattern_type'] = '两阳+阴+阳突破(已确认)' - confirmed_signal['confirmation_pending'] = False - confirmed_signal['pullback_confirmed'] = True - confirmed_signal['stock_code'] = stock_code - - newly_confirmed_signals.append(confirmed_signal) - - logger.info(f"✅ 股票 {stock_code} 的待确认模式已通过回踩确认") - - except Exception as e: - logger.error(f"监控股票 {stock_code} 待确认模式失败: {e}") - - return newly_confirmed_signals def monitor_pullback_for_triggered_signals(self) -> List[Dict[str, Any]]: """ @@ -790,108 +802,27 @@ class KLinePatternStrategy: return all_pullback_alerts - def _convert_to_weekly(self, daily_df: pd.DataFrame) -> pd.DataFrame: + + + def scan_market(self, stock_list: List[str] = None, max_stocks: int = 100) -> Dict[str, Dict[str, List[Dict[str, Any]]]]: """ - 将日线数据转换为周线数据 + 扫描市场中的股票形态 - 只使用同花顺热榜股票 Args: - daily_df: 日线数据 - - Returns: - 周线数据 - """ - if daily_df.empty: - return daily_df - - try: - df = daily_df.copy() - - # 确保有trade_date列并设置为索引 - if 'trade_date' in df.columns: - df['trade_date'] = pd.to_datetime(df['trade_date']) - df.set_index('trade_date', inplace=True) - - # 按周聚合 - weekly_df = df.resample('W').agg({ - 'open': 'first', - 'high': 'max', - 'low': 'min', - 'close': 'last', - 'volume': 'sum' if 'volume' in df.columns else 'last' - }).dropna() - - # 重置索引,保持trade_date列 - weekly_df.reset_index(inplace=True) - - return weekly_df - - except Exception as e: - logger.error(f"转换周线数据失败: {e}") - return pd.DataFrame() - - def _convert_to_monthly(self, daily_df: pd.DataFrame) -> pd.DataFrame: - """ - 将日线数据转换为月线数据 - - Args: - daily_df: 日线数据 - - Returns: - 月线数据 - """ - if daily_df.empty: - return daily_df - - try: - df = daily_df.copy() - - # 确保有trade_date列并设置为索引 - if 'trade_date' in df.columns: - df['trade_date'] = pd.to_datetime(df['trade_date']) - df.set_index('trade_date', inplace=True) - - # 按月聚合 - monthly_df = df.resample('ME').agg({ - 'open': 'first', - 'high': 'max', - 'low': 'min', - 'close': 'last', - 'volume': 'sum' if 'volume' in df.columns else 'last' - }).dropna() - - # 重置索引,保持trade_date列 - monthly_df.reset_index(inplace=True) - - return monthly_df - - except Exception as e: - logger.error(f"转换月线数据失败: {e}") - return pd.DataFrame() - - def scan_market(self, stock_list: List[str] = None, max_stocks: int = 100, use_hot_stocks: bool = True, use_combined_sources: bool = True, use_all_a_shares: bool = False) -> Dict[str, Dict[str, List[Dict[str, Any]]]]: - """ - 扫描市场中的股票形态 - - Args: - stock_list: 股票代码列表,如果为None则自动选择股票池 + stock_list: 股票代码列表,如果为None则使用同花顺热榜 max_stocks: 最大扫描股票数量 - use_hot_stocks: 是否使用热门股票数据,默认True - use_combined_sources: 是否使用合并的双数据源(同花顺+东财),默认True - use_all_a_shares: 是否使用所有A股股票(排除北交所和ST),优先级最高 Returns: 所有股票的分析结果 """ logger.info("🚀" + "="*70) - logger.info("🌍 开始市场K线形态扫描") + logger.info("🌍 开始市场K线形态扫描 - 只使用同花顺热榜") logger.info("🚀" + "="*70) # 创建扫描会话 scan_config = { 'max_stocks': max_stocks, - 'use_hot_stocks': use_hot_stocks, - 'use_combined_sources': use_combined_sources, - 'use_all_a_shares': use_all_a_shares, + 'data_source': '同花顺热榜', 'timeframes': self.timeframes } session_id = self.db_manager.create_scan_session( @@ -900,81 +831,14 @@ class KLinePatternStrategy: ) if stock_list is None: - # 优先级1: 使用所有A股股票 - if use_all_a_shares: - try: - logger.info("📊 获取所有A股股票数据(排除北交所和ST股票)...") - filtered_stocks = self.data_fetcher.get_filtered_a_share_list() + # 使用缓存的同花顺热榜数据 + stock_list = self._get_cached_hot_stocks(max_stocks) - if not filtered_stocks.empty: - # 如果max_stocks小于总股票数,随机采样 - if max_stocks > 0 and max_stocks < len(filtered_stocks): - # 按市值排序或随机选择,这里先随机选择 - selected_stocks = filtered_stocks.sample(max_stocks) - stock_list = selected_stocks['full_stock_code'].tolist() - logger.info(f"📈 从{len(filtered_stocks)}只A股中随机选择{len(stock_list)}只进行分析") - else: - stock_list = filtered_stocks['full_stock_code'].tolist() - logger.info(f"📈 分析全部{len(stock_list)}只A股股票") - - source_info = "全A股(排除北交所和ST)" - else: - logger.warning("获取A股列表失败,回退到热门股票") - use_all_a_shares = False - - except Exception as e: - logger.error(f"获取A股列表失败: {e},回退到热门股票") - use_all_a_shares = False - - # 优先级2: 使用热门股票数据 - if not use_all_a_shares and use_hot_stocks: - try: - if use_combined_sources: - # 使用合并的双数据源 - logger.info("获取合并热门股票数据(同花顺+东财)...") - hot_stocks = self.data_fetcher.get_combined_hot_stocks( - limit_per_source=max_stocks, - final_limit=max_stocks - ) - source_info = "双数据源合并" - else: - # 仅使用同花顺数据 - logger.info("获取同花顺热股TOP100数据...") - hot_stocks = self.data_fetcher.get_hot_stocks_ths(limit=max_stocks) - source_info = "同花顺热股" - - if not hot_stocks.empty and 'stock_code' in hot_stocks.columns: - stock_list = hot_stocks['stock_code'].tolist() - - # 统计数据源分布 - if 'source' in hot_stocks.columns: - source_counts = hot_stocks['source'].value_counts().to_dict() - source_detail = " | ".join([f"{k}: {v}只" for k, v in source_counts.items()]) - logger.info(f"📊 数据源: {source_info} | 总计: {len(stock_list)}只股票") - logger.info(f"📈 分布详情: {source_detail}") - else: - logger.info(f"📊 数据源: {source_info} | 总计: {len(stock_list)}只股票") - else: - logger.warning("热门股票数据为空,回退到全市场股票") - use_hot_stocks = False - except Exception as e: - logger.error(f"获取热门股票失败: {e},回退到全市场股票") - use_hot_stocks = False - - # 优先级3: 如果热股获取失败,使用全市场股票列表 - if not use_all_a_shares and not use_hot_stocks: - try: - all_stocks = self.data_fetcher.get_stock_list() - if not all_stocks.empty: - # 随机选择一些股票进行扫描 - stock_list = all_stocks['stock_code'].head(max_stocks).tolist() - logger.info(f"使用全市场股票数据,共{len(stock_list)}只股票") - else: - logger.warning("未能获取股票列表") - return {} - except Exception as e: - logger.error(f"获取股票列表失败: {e}") - return {} + if stock_list: + logger.info(f"📊 数据源: 同花顺热榜 | 扫描股票: {len(stock_list)} 只") + else: + logger.error("❌ 同花顺热榜数据为空,无法进行扫描") + return {} results = {} total_signals = 0 @@ -1038,15 +902,9 @@ class KLinePatternStrategy: # 发送汇总通知 if results: - # 判断数据源类型 - data_source = '全市场股票' - if stock_list and len(stock_list) <= max_stocks: - if use_hot_stocks: - data_source = '合并热门股票' if use_combined_sources else '热门股票' - scan_stats = { 'total_scanned': len(stock_list), - 'data_source': data_source + 'data_source': '同花顺热榜' } try: @@ -1119,10 +977,77 @@ K线形态策略 - 两阳线+阴线+阳线突破(优化版:创新高回踩 - 系统日志详细记录 """ + def analyze_stock_pool(self, stock_list: List[str], max_stocks: int = None) -> Dict[str, Dict[str, StrategyResult]]: + """ + 分析股票池(带数据库会话管理) + + Args: + stock_list: 股票代码列表 + max_stocks: 最大分析股票数量 + + Returns: + 股票代码到时间周期结果映射的字典 + """ + if max_stocks and len(stock_list) > max_stocks: + stock_list = stock_list[:max_stocks] + logger.info(f"限制分析股票数量为: {max_stocks}") + + # 创建扫描会话 + scan_config = { + 'max_stocks': max_stocks, + 'timeframes': self.timeframes + } + session_id = self.db_manager.create_scan_session( + strategy_id=self.strategy_id, + scan_config=scan_config + ) + logger.debug(f"创建扫描会话: session_id={session_id}") + + results = {} + total_signals = 0 + + logger.info(f"🔍 开始分析股票池,共 {len(stock_list)} 只股票") + + for i, stock_code in enumerate(stock_list, 1): + try: + logger.info(f"⏳ 分析进度: [{i:3d}/{len(stock_list):3d}] 🔍 {stock_code}") + + # 分析单只股票,传递session_id + stock_results = self.analyze_stock(stock_code, session_id=session_id) + + # 统计信号数量 + stock_signal_count = sum( + result.get_signal_count() for result in stock_results.values() + ) + + if stock_signal_count > 0: + results[stock_code] = stock_results + total_signals += stock_signal_count + logger.info(f"✅ {stock_code} 发现 {stock_signal_count} 个信号") + + except Exception as e: + logger.error(f"❌ 分析股票 {stock_code} 失败: {e}") + continue + + # 更新扫描会话统计 + try: + self.db_manager.update_scan_session_stats( + session_id=session_id, + total_scanned=len(stock_list), + total_signals=total_signals + ) + logger.debug(f"更新扫描会话统计: session_id={session_id}, 扫描={len(stock_list)}, 信号={total_signals}") + except Exception as e: + logger.error(f"更新扫描会话统计失败: {e}") + + logger.info(f"🎉 股票池分析完成: 扫描 {len(stock_list)} 只,发现 {total_signals} 个信号,涉及 {len(results)} 只股票") + + return results + if __name__ == "__main__": # 测试代码 - from ..data.tushare_fetcher import TushareFetcher as ADataFetcher + from ..data.tushare_fetcher import TushareFetcher from ..utils.notification import NotificationManager # 模拟配置 @@ -1139,7 +1064,7 @@ if __name__ == "__main__": } # 初始化组件 - data_fetcher = ADataFetcher() + data_fetcher = TushareFetcher() notification_manager = NotificationManager(notification_config) strategy = KLinePatternStrategy(data_fetcher, notification_manager, strategy_config) diff --git a/src/utils/config_loader.py b/src/utils/config_loader.py index 11476a6..633ce4c 100644 --- a/src/utils/config_loader.py +++ b/src/utils/config_loader.py @@ -96,6 +96,10 @@ class ConfigLoader: """获取日志配置""" return self.get('logging', {}) + def get_tushare_token(self) -> str: + """获取TuShare token""" + return self.get('data_source.tushare_token', '') + # 全局配置实例 config_loader = ConfigLoader() diff --git a/src/utils/notification.py b/src/utils/notification.py index f8e58ed..3850b80 100644 --- a/src/utils/notification.py +++ b/src/utils/notification.py @@ -357,8 +357,9 @@ class NotificationManager: return False try: - from datetime import datetime + from datetime import datetime, date import math + import pandas as pd current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") # 收集所有信号详情 @@ -366,15 +367,39 @@ class NotificationManager: total_signals = 0 total_stocks = len(all_signals) + # 计算一周前的日期 + current_date = date.today() + one_week_ago = current_date - pd.Timedelta(days=7) + for stock_code, stock_results in all_signals.items(): for timeframe, signals in stock_results.items(): for signal in signals: + # 添加一周内信号过滤 + signal_date = signal.get('details', {}).get('confirmation_date') or signal['date'] + + # 处理不同的日期格式 + if isinstance(signal_date, str): + try: + signal_date = pd.to_datetime(signal_date).date() + except: + continue + elif hasattr(signal_date, 'date'): + signal_date = signal_date.date() + elif not isinstance(signal_date, date): + continue + + # 只处理一周内的信号 + if signal_date < one_week_ago: + logger.debug(f"🗓️ 通知过滤历史信号: {stock_code} {signal_date} (距今{(current_date - signal_date).days}天)") + continue + total_signals += 1 # 根据新的信号格式提取信息 - confirmation_date = signal.get('confirmation_date', signal['date']) - new_high_price = signal.get('new_high_price', signal['breakout_price']) - confirmation_days = signal.get('confirmation_days', 0) + details = signal.get('details', {}) + confirmation_date = details.get('confirmation_date', signal['date']) + new_high_price = details.get('new_high_price', signal['price']) + confirmation_days = details.get('confirmation_days', 0) all_signal_details.append({ 'stock_code': stock_code, @@ -383,20 +408,26 @@ class NotificationManager: 'pattern_date': signal['date'], # 模式形成日期 'confirmation_date': confirmation_date, # 回踩确认日期 'price': new_high_price, # 创新高价格 - 'original_breakout_price': signal['breakout_price'], # 原突破价 - 'yin_high': signal.get('yin_high', 0), # 阴线最高价 - 'turnover': signal.get('turnover_ratio', 0), - 'breakout_pct': signal.get('breakout_pct', 0), - 'ema20_status': '✅上方' if signal.get('above_ema20', False) else '❌下方', + 'original_breakout_price': signal['price'], # 原突破价 + 'yin_high': details.get('yin_high', 0), # 阴线最高价 + 'turnover': details.get('turnover_ratio', 0), + 'breakout_pct': details.get('breakout_pct', 0), + 'ema20_status': '✅上方' if details.get('above_ema20', False) else '❌下方', 'confirmation_days': confirmation_days, - 'pullback_distance': signal.get('pullback_distance', 0), - 'is_new_format': signal.get('new_high_confirmed', False) # 是否为新格式信号 + 'pullback_distance': details.get('pullback_distance', 0), + 'is_new_format': details.get('new_high_confirmed', False) # 是否为新格式信号 }) # 如果没有信号,直接返回 if total_signals == 0: + logger.info("📱 通知过滤: 没有一周内的新信号,不发送通知") return True + # 记录过滤后的信号统计 + original_signals = sum(len(signals) for stock_results in all_signals.values() for signals in stock_results.values()) + if original_signals > total_signals: + logger.info(f"📅 通知信号过滤: 原始{original_signals}个 → 一周内{total_signals}个") + # 按10个信号为一组分批发送 signals_per_group = 10 total_groups = math.ceil(total_signals / signals_per_group) diff --git a/start_market_scanner.sh b/start_market_scanner.sh new file mode 100755 index 0000000..559fda7 --- /dev/null +++ b/start_market_scanner.sh @@ -0,0 +1,43 @@ +#!/bin/bash +# 市场扫描服务启动脚本 + +echo "🚀 启动市场扫描定时任务服务..." + +# 创建日志目录 +mkdir -p /app/logs + +# 安装cron +echo "📦 安装cron服务..." +apt-get update && apt-get install -y cron + +# 复制crontab配置 +echo "⏰ 配置定时任务..." +cp /app/crontab/market-scanner /etc/cron.d/market-scanner + +# 设置权限 +chmod 0644 /etc/cron.d/market-scanner + +# 启动cron服务 +echo "🔄 启动cron守护进程..." +service cron start + +# 显示已配置的任务 +echo "📋 已配置的定时任务:" +crontab -l 2>/dev/null || echo "使用系统cron配置: /etc/cron.d/market-scanner" +cat /etc/cron.d/market-scanner + +# 记录启动信息 +echo "$(date): 市场扫描服务启动完成" >> /app/logs/scanner_startup.log + +echo "✅ 市场扫描定时任务服务启动完成" +echo "📊 扫描参数: MARKET_SCAN_STOCKS=${MARKET_SCAN_STOCKS:-200}" +echo "📝 日志文件: /app/logs/market_scanner.log" +echo "⏰ Cron日志: /app/logs/cron.log" + +# 执行一次初始扫描 +echo "🔍 执行初始市场扫描..." +python /app/market_scanner.py ${MARKET_SCAN_STOCKS:-200} + +# 保持容器运行并显示日志 +echo "👁️ 监控日志输出..." +tail -f /app/logs/market_scanner.log /app/logs/cron.log \ No newline at end of file diff --git a/test_cache_optimization.py b/test_cache_optimization.py deleted file mode 100644 index 160e0a1..0000000 --- a/test_cache_optimization.py +++ /dev/null @@ -1,97 +0,0 @@ -#!/usr/bin/env python3 -""" -测试股票名称获取的缓存优化 -""" - -import sys -from pathlib import Path -import time - -# 添加src目录到路径 -current_dir = Path(__file__).parent -src_dir = current_dir / "src" -sys.path.insert(0, str(src_dir)) - -from loguru import logger -from src.data.data_fetcher import ADataFetcher - -def test_stock_name_cache(): - """测试股票名称缓存机制""" - logger.info("🧪 开始测试股票名称缓存优化") - - # 初始化数据获取器 - data_fetcher = ADataFetcher() - - # 测试股票列表 - test_stocks = ['000001.SZ', '000002.SZ', '600000.SH', '600036.SH', '000858.SZ'] - - # 第一次获取股票名称(会触发缓存构建) - logger.info("📊 第一次获取股票名称(构建缓存)...") - start_time = time.time() - - names_first = {} - for stock_code in test_stocks: - name = data_fetcher.get_stock_name(stock_code) - names_first[stock_code] = name - logger.info(f" {stock_code}: {name}") - - first_duration = time.time() - start_time - logger.info(f"⏱️ 第一次获取耗时: {first_duration:.2f}秒") - - # 等待一秒 - time.sleep(1) - - # 第二次获取股票名称(应该从缓存读取) - logger.info("📊 第二次获取股票名称(从缓存读取)...") - start_time = time.time() - - names_second = {} - for stock_code in test_stocks: - name = data_fetcher.get_stock_name(stock_code) - names_second[stock_code] = name - logger.info(f" {stock_code}: {name}") - - second_duration = time.time() - start_time - logger.info(f"⏱️ 第二次获取耗时: {second_duration:.2f}秒") - - # 比较结果 - logger.info("📈 性能对比:") - if second_duration < first_duration: - speedup = first_duration / second_duration - logger.info(f"✅ 缓存优化成功! 第二次比第一次快 {speedup:.1f}x") - else: - logger.warning("❌ 缓存优化效果不明显") - - # 验证数据一致性 - consistent = names_first == names_second - logger.info(f"🔍 数据一致性: {'✅ 一致' if consistent else '❌ 不一致'}") - - # 显示缓存状态 - logger.info(f"📦 当前缓存中的股票数量: {len(data_fetcher._stock_name_cache)}") - - return first_duration, second_duration, consistent - -if __name__ == "__main__": - # 设置日志 - logger.remove() - logger.add(sys.stdout, level="INFO", format="{time:HH:mm:ss} | {level} | {message}") - - print("=" * 60) - print("🧪 股票名称缓存优化测试") - print("=" * 60) - - try: - first_time, second_time, is_consistent = test_stock_name_cache() - - print("\n" + "=" * 60) - print("📊 测试结果总结:") - print(f" 第一次获取耗时: {first_time:.2f}秒") - print(f" 第二次获取耗时: {second_time:.2f}秒") - print(f" 性能提升倍数: {first_time/second_time:.1f}x") - print(f" 数据一致性: {'✅ 通过' if is_consistent else '❌ 失败'}") - print("=" * 60) - - except Exception as e: - logger.error(f"测试过程中发生错误: {e}") - import traceback - traceback.print_exc() \ No newline at end of file diff --git a/test_dingtalk.py b/test_dingtalk.py deleted file mode 100644 index b8f2af3..0000000 --- a/test_dingtalk.py +++ /dev/null @@ -1,96 +0,0 @@ -#!/usr/bin/env python3 -""" -测试钉钉通知功能 -""" - -import sys -import os -sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src')) - -from utils.notification import DingTalkNotifier, NotificationManager -import yaml - -def test_dingtalk_with_secret(): - """测试带加签的钉钉通知""" - print("🔧 测试钉钉加签功能...") - - # 测试加签生成 - webhook_url = "https://oapi.dingtalk.com/robot/send?access_token=YOUR_TOKEN" - secret = "SEC6e9dbd71d4addd2c4e673fb72d686293b342da5ae48da2f8ec788a68de99f981" - - notifier = DingTalkNotifier(webhook_url, secret) - - # 生成签名URL - signed_url = notifier._get_signed_url() - print(f"✅ 签名URL生成成功") - print(f"📄 原始URL: {webhook_url}") - print(f"🔐 签名URL: {signed_url}") - - # 检查URL格式 - if "timestamp=" in signed_url and "sign=" in signed_url: - print("✅ 加签参数正确添加") - else: - print("❌ 加签参数缺失") - return False - - return True - -def test_notification_manager(): - """测试通知管理器配置""" - print("\n🔧 测试通知管理器配置...") - - # 从配置文件读取配置 - try: - with open('config/config.yaml', 'r', encoding='utf-8') as f: - config = yaml.safe_load(f) - - notification_config = config.get('notification', {}) - print(f"✅ 配置文件加载成功") - print(f"📄 钉钉配置: {notification_config.get('dingtalk', {})}") - - # 创建通知管理器 - notifier_manager = NotificationManager(notification_config) - - if notifier_manager.dingtalk_notifier: - print("✅ 钉钉通知器初始化成功") - if notifier_manager.dingtalk_notifier.secret: - print("✅ 加签密钥配置正确") - else: - print("❌ 加签密钥未配置") - return False - else: - print("❌ 钉钉通知器未启用") - return False - - return True - - except Exception as e: - print(f"❌ 配置测试失败: {e}") - return False - -def main(): - print("=" * 60) - print(" 钉钉通知功能测试") - print("=" * 60) - - # 测试加签功能 - test1_passed = test_dingtalk_with_secret() - - # 测试配置管理 - test2_passed = test_notification_manager() - - print("\n" + "=" * 60) - print("测试结果:") - print(f"🔐 加签功能测试: {'✅ 通过' if test1_passed else '❌ 失败'}") - print(f"⚙️ 配置管理测试: {'✅ 通过' if test2_passed else '❌ 失败'}") - - if test1_passed and test2_passed: - print("\n🎉 所有测试通过!钉钉通知功能配置正确") - print("💡 注意: 需要提供完整的webhook URL才能发送实际消息") - else: - print("\n❌ 部分测试失败,请检查配置") - - print("=" * 60) - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/test_eastmoney_sectors.py b/test_eastmoney_sectors.py deleted file mode 100644 index 71383a9..0000000 --- a/test_eastmoney_sectors.py +++ /dev/null @@ -1,268 +0,0 @@ -#!/usr/bin/env python3 -""" -使用东财概念板块数据分析本周强势板块 -需要5000积分 -""" - -import sys -from pathlib import Path -import pandas as pd -from datetime import datetime, timedelta - -# 添加项目根目录到路径 -current_dir = Path(__file__).parent -sys.path.insert(0, str(current_dir)) - -from src.data.tushare_fetcher import TushareFetcher -from loguru import logger - - -def get_recent_trading_dates(days_back=5): - """获取最近的交易日期""" - dates = [] - current = datetime.now() - - while len(dates) < days_back: - # 排除周末 - if current.weekday() < 5: # 0-4是周一到周五 - dates.append(current.strftime('%Y%m%d')) - current -= timedelta(days=1) - - return sorted(dates) # 升序返回 - - -def analyze_eastmoney_concepts(fetcher: TushareFetcher): - """使用东财概念板块数据分析""" - try: - if not fetcher.pro: - logger.error("需要Tushare Pro权限") - return - - logger.info("🚀 使用东财概念板块数据分析...") - - # 获取最近5个交易日 - trading_dates = get_recent_trading_dates(5) - logger.info(f"分析时间范围: {trading_dates[0]} 到 {trading_dates[-1]}") - - # 获取最新交易日的概念板块数据 - latest_date = trading_dates[-1] - - try: - # 使用东财概念板块接口 - dc_concepts = fetcher.pro.dc_index(trade_date=latest_date) - logger.info(f"获取到 {len(dc_concepts)} 个东财概念板块") - - if dc_concepts.empty: - logger.warning("未获取到东财概念板块数据") - return - - # 打印数据结构以便调试 - logger.info(f"数据列名: {list(dc_concepts.columns)}") - if not dc_concepts.empty: - logger.info(f"样本数据:\n{dc_concepts.head(2)}") - - # 检查涨跌幅字段名 - change_col = None - for col in ['pct_chg', 'pct_change', 'change_pct', 'chg_pct']: - if col in dc_concepts.columns: - change_col = col - break - - if change_col: - # 按涨跌幅排序 - dc_concepts = dc_concepts.sort_values(change_col, ascending=False) - else: - logger.warning("未找到涨跌幅字段,使用原始顺序") - change_col = 'code' # 使用code作为默认排序 - - print("\n" + "="*80) - print("📈 东财概念板块实时排行榜") - print("="*80) - # 显示表头 - if change_col != 'code': - print(f"{'排名':<4} {'概念名称':<25} {'涨跌幅':<10} {'概念代码':<15}") - else: - print(f"{'排名':<4} {'概念名称':<25} {'概念代码':<15}") - print("-" * 80) - - for i, (_, concept) in enumerate(dc_concepts.head(20).iterrows()): - rank = i + 1 - name = concept.get('name', 'N/A')[:23] + '..' if len(str(concept.get('name', 'N/A'))) > 23 else concept.get('name', 'N/A') - code = concept.get('ts_code', 'N/A') - - if change_col != 'code': - change_pct = f"{concept[change_col]:+.2f}%" if not pd.isna(concept.get(change_col, 0)) else "N/A" - print(f"{rank:<4} {name:<25} {change_pct:<10} {code:<15}") - else: - print(f"{rank:<4} {name:<25} {code:<15}") - - # 强势概念TOP10 - if change_col != 'code': - print(f"\n🚀 强势概念板块TOP10:") - for i, (_, concept) in enumerate(dc_concepts.head(10).iterrows()): - change_val = concept.get(change_col, 0) - if not pd.isna(change_val): - print(f" {i+1:2d}. {concept.get('name', 'N/A')}: {change_val:+.2f}%") - - # 弱势概念TOP10 - print(f"\n📉 弱势概念板块TOP10:") - weak_concepts = dc_concepts.tail(10).iloc[::-1] # 反转顺序 - for i, (_, concept) in enumerate(weak_concepts.iterrows()): - change_val = concept.get(change_col, 0) - if not pd.isna(change_val): - print(f" {i+1:2d}. {concept.get('name', 'N/A')}: {change_val:+.2f}%") - else: - print(f"\n📋 概念板块列表(前10个):") - for i, (_, concept) in enumerate(dc_concepts.head(10).iterrows()): - print(f" {i+1:2d}. {concept.get('name', 'N/A')} ({concept.get('ts_code', 'N/A')})") - - return dc_concepts - - except Exception as e: - logger.error(f"获取东财概念板块数据失败: {e}") - return None - - except Exception as e: - logger.error(f"分析东财概念板块失败: {e}") - return None - - -def analyze_concept_trend(fetcher: TushareFetcher, concept_codes=None): - """分析概念板块的趋势(多日对比)""" - try: - if not fetcher.pro: - logger.error("需要Tushare Pro权限") - return - - logger.info("📊 分析概念板块趋势...") - - # 获取最近5个交易日 - trading_dates = get_recent_trading_dates(5) - - # 如果没有指定概念代码,获取当日表现最好的前10个 - if concept_codes is None: - latest_concepts = analyze_eastmoney_concepts(fetcher) - if latest_concepts is not None and not latest_concepts.empty: - concept_codes = latest_concepts.head(5)['code'].tolist() - else: - logger.warning("无法获取概念代码") - return - - print(f"\n" + "="*80) - print("📈 热门概念板块多日趋势分析") - print("="*80) - - for concept_code in concept_codes: - concept_trend = [] - - for date in trading_dates: - try: - # 获取特定日期的概念数据 - daily_data = fetcher.pro.dc_index( - trade_date=date, - ts_code=concept_code - ) - - if not daily_data.empty: - # 检查数据结构 - logger.debug(f"概念 {concept_code} 在 {date} 的数据字段: {list(daily_data.columns)}") - - # 东财概念数据可能没有close字段,使用其他字段替代 - close_value = daily_data.iloc[0].get('total_mv', 1) # 使用总市值代替 - if close_value == 0: - close_value = 1 # 避免除零 - - concept_trend.append({ - 'date': date, - 'name': daily_data.iloc[0]['name'], - 'close': close_value, - 'pct_chg': daily_data.iloc[0]['pct_change'] - }) - - except Exception as e: - logger.debug(f"获取概念 {concept_code} 在 {date} 的数据失败: {e}") - continue - - # 输出趋势 - if concept_trend: - concept_name = concept_trend[0]['name'] - print(f"\n📊 {concept_name} ({concept_code}) 近5日走势:") - - # 计算总涨跌幅 - if len(concept_trend) >= 2: - start_close = concept_trend[0]['close'] - end_close = concept_trend[-1]['close'] - - if start_close != 0 and start_close is not None: - total_change = (end_close - start_close) / start_close * 100 - print(f" 总涨跌幅: {total_change:+.2f}%") - else: - print(f" 总涨跌幅: 无法计算(起始值为0)") - - # 显示每日数据 - for data in concept_trend: - print(f" {data['date']}: {data['pct_chg']:+6.2f}% (指数: {data['close']:8.2f})") - - print("\n" + "="*80) - - except Exception as e: - logger.error(f"分析概念趋势失败: {e}") - - -def get_concept_constituents(fetcher: TushareFetcher, concept_code: str): - """获取概念板块成分股""" - try: - if not fetcher.pro: - logger.error("需要Tushare Pro权限") - return - - logger.info(f"获取概念 {concept_code} 的成分股...") - - # 尝试通过概念板块获取成分股 - try: - # 使用concept_detail接口(如果可用) - constituents = fetcher.pro.concept_detail(id=concept_code) - - if not constituents.empty: - print(f"\n📋 概念成分股 ({len(constituents)}只):") - for _, stock in constituents.head(10).iterrows(): - print(f" {stock['ts_code']}: {stock.get('name', 'N/A')}") - else: - logger.warning(f"概念 {concept_code} 无成分股数据") - - except Exception as e: - logger.error(f"获取概念成分股失败: {e}") - - except Exception as e: - logger.error(f"获取概念成分股失败: {e}") - - -def main(): - """主函数""" - logger.info("🚀 开始使用东财概念板块数据分析...") - - # 初始化Tushare数据获取器 - token = "0ed6419a00d8923dc19c0b58fc92d94c9a0696949ab91a13aa58a0cc" - fetcher = TushareFetcher(token=token) - - # 1. 分析当日概念板块表现 - concepts_data = analyze_eastmoney_concepts(fetcher) - - # 2. 分析热门概念的多日趋势 - if concepts_data is not None and not concepts_data.empty: - print("\n" + "="*80 + "\n") - - # 获取表现最好的前3个概念进行趋势分析 - top_concepts = concepts_data.head(3)['ts_code'].tolist() - analyze_concept_trend(fetcher, top_concepts) - - # 3. 获取第一个概念的成分股示例 - # top_concept_code = top_concepts[0] if top_concepts else None - # if top_concept_code: - # get_concept_constituents(fetcher, top_concept_code) - - logger.info("✅ 分析完成!") - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/test_optimized_notification.py b/test_optimized_notification.py deleted file mode 100644 index bdc4966..0000000 --- a/test_optimized_notification.py +++ /dev/null @@ -1,191 +0,0 @@ -#!/usr/bin/env python3 -""" -测试优化后的钉钉通知格式(创新高回踩确认版) -""" - -import sys -from pathlib import Path -from datetime import datetime - -# 添加项目根目录到路径 -current_dir = Path(__file__).parent -sys.path.insert(0, str(current_dir)) - -from src.utils.notification import NotificationManager -from loguru import logger - - -def test_optimized_notification(): - """ - 测试优化后的钉钉通知格式 - """ - logger.info("🚀 开始测试优化后的钉钉通知格式...") - - # 配置通知管理器(测试模式,不实际发送) - notification_config = { - 'dingtalk': { - 'enabled': False, # 设置为True并提供真实webhook进行实际测试 - 'webhook_url': 'https://oapi.dingtalk.com/robot/send?access_token=TEST_TOKEN' - } - } - - notification_manager = NotificationManager(notification_config) - - # 测试1: 单个策略信号通知 - logger.info("\n📊 测试1: 单个策略信号通知") - test_signal_data = { - 'stock_code': '000001.SZ', - 'stock_name': '平安银行', - 'timeframe': 'daily', - 'signal_type': '两阳+阴+阳突破(创新高回踩确认)', - 'price': 14.50, - 'signal_date': '2024-01-11', - 'additional_info': { - 'pattern_date': '2024-01-04', - 'breakout_price': 11.60, - 'new_high_price': 14.50, - 'new_high_date': '2024-01-10', - 'yin_high': 11.20, - 'confirmation_date': '2024-01-11', - 'confirmation_days': 7, - 'pullback_distance': -0.89, - 'yang1_entity_ratio': 0.60, - 'yang2_entity_ratio': 0.67, - 'final_yang_entity_ratio': 0.89, - 'breakout_pct': 3.57, - 'turnover_ratio': 2.50, - 'above_ema20': True - } - } - - try: - # 模拟发送单个信号通知 - if notification_config['dingtalk']['enabled']: - success = notification_manager.send_strategy_signal(**test_signal_data) - logger.info(f"单个信号通知发送: {'✅成功' if success else '❌失败'}") - else: - logger.info("单个信号通知格式测试完成(未实际发送)") - except Exception as e: - logger.error(f"单个信号通知测试失败: {e}") - - # 测试2: 策略汇总通知 - logger.info("\n📊 测试2: 策略汇总通知") - test_summary_data = { - '000001.SZ': { - 'daily': [ - { - 'stock_name': '平安银行', - 'date': '2024-01-04', - 'breakout_price': 11.60, - 'new_high_price': 14.50, - 'new_high_date': '2024-01-10', - 'confirmation_date': '2024-01-11', - 'confirmation_days': 7, - 'pullback_distance': -0.89, - 'yin_high': 11.20, - 'turnover_ratio': 2.5, - 'breakout_pct': 3.57, - 'above_ema20': True, - 'new_high_confirmed': True # 标记为新格式 - } - ] - }, - '000002.SZ': { - 'daily': [ - { - 'stock_name': '万科A', - 'date': '2024-01-05', - 'breakout_price': 9.80, - 'new_high_price': 12.30, - 'new_high_date': '2024-01-09', - 'confirmation_date': '2024-01-12', - 'confirmation_days': 7, - 'pullback_distance': -1.2, - 'yin_high': 9.60, - 'turnover_ratio': 3.2, - 'breakout_pct': 2.08, - 'above_ema20': True, - 'new_high_confirmed': True # 标记为新格式 - } - ] - } - } - - scan_stats = { - 'total_scanned': 100, - 'data_source': '双数据源合并' - } - - try: - # 模拟发送汇总通知 - if notification_config['dingtalk']['enabled']: - success = notification_manager.send_strategy_summary(test_summary_data, scan_stats) - logger.info(f"汇总通知发送: {'✅成功' if success else '❌失败'}") - else: - logger.info("汇总通知格式测试完成(未实际发送)") - except Exception as e: - logger.error(f"汇总通知测试失败: {e}") - - # 测试3: 回踩提醒通知 - logger.info("\n📊 测试3: 回踩提醒通知") - test_pullback_alerts = [ - { - 'stock_code': '000001.SZ', - 'stock_name': '平安银行', - 'signal_date': '2024-01-11', - 'current_date': '2024-01-18', - 'timeframe': 'daily', - 'yin_high': 11.20, - 'breakout_price': 11.60, - 'current_price': 11.15, - 'current_low': 11.10, - 'pullback_pct': -4.5, - 'distance_to_yin_high': -0.45, - 'days_since_signal': 7, - 'alert_type': 'pullback_to_yin_high' - } - ] - - try: - # 模拟发送回踩提醒 - if notification_config['dingtalk']['enabled']: - success = notification_manager.send_pullback_alerts(test_pullback_alerts) - logger.info(f"回踩提醒发送: {'✅成功' if success else '❌失败'}") - else: - logger.info("回踩提醒格式测试完成(未实际发送)") - except Exception as e: - logger.error(f"回踩提醒测试失败: {e}") - - # 显示消息格式预览 - print("\n" + "="*80) - print("📱 优化后的钉钉消息格式预览") - print("="*80) - - print("\n🎯 单个信号通知示例:") - print("标题: 🎯 两阳+阴+阳突破(创新高回踩确认)信号确认") - print("内容包含: 股票信息、创新高回踩确认详情、技术指标、操作建议等") - - print("\n📊 汇总通知示例:") - print("标题: 🎯 K线形态策略信号汇总") - print("内容包含: 扫描统计、确认信号详情(模式日期+确认日期+创新高价等)") - - print("\n⚠️ 回踩提醒示例:") - print("标题: ⚠️ 已确认信号二次回踩提醒") - print("内容包含: 已确认信号的二次回踩情况、支撑分析建议等") - - print("\n✅ 钉钉消息优化完成!") - print("主要改进:") - print("- 突出创新高回踩确认逻辑") - print("- 详细展示时间线(模式日期→创新高日期→确认日期)") - print("- 增加操作建议和风险提示") - print("- 区分新旧格式信号,向下兼容") - - -def main(): - """主函数""" - logger.info("🚀 开始钉钉通知优化测试...") - test_optimized_notification() - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/test_pullback_feature.py b/test_pullback_feature.py deleted file mode 100644 index aa5dbaf..0000000 --- a/test_pullback_feature.py +++ /dev/null @@ -1,179 +0,0 @@ -#!/usr/bin/env python3 -""" -测试价格回踩阴线最高点提醒功能 -""" - -import sys -from pathlib import Path -import pandas as pd -from datetime import datetime, timedelta - -# 添加src目录到路径 -current_dir = Path(__file__).parent -src_dir = current_dir / "src" -sys.path.insert(0, str(src_dir)) - -from loguru import logger -from src.utils.config_loader import config_loader -from src.data.data_fetcher import ADataFetcher -from src.utils.notification import NotificationManager -from src.strategy.kline_pattern_strategy import KLinePatternStrategy - - -def create_test_signal(): - """创建一个测试用的K线形态信号""" - test_signal = { - 'stock_code': '000001.SZ', - 'stock_name': '平安银行', - 'date': datetime.now() - timedelta(days=5), # 5天前的信号 - 'timeframe': 'daily', - 'breakout_price': 15.50, # 突破价格 - 'yin_high': 15.20, # 阴线最高点 - 'pattern_type': '两阳+阴+阳突破' - } - return test_signal - - -def create_test_pullback_data(): - """创建模拟回踩数据 - 模拟价格回踩到阴线最高点附近""" - dates = pd.date_range(start=datetime.now() - timedelta(days=3), end=datetime.now(), freq='D') - - # 模拟价格回踩阴线最高点的情况 - test_data = [] - yin_high = 15.20 # 阴线最高点价格 - initial_price = 15.50 # 突破价格 - - # 设计价格走势:从突破价格逐步回调到阴线最高点附近 - price_path = [15.45, 15.35, 15.25, 15.18] # 最后一个价格接近阴线最高点 - - for i, date in enumerate(dates): - if i < len(price_path): - close_price = price_path[i] - else: - close_price = yin_high - 0.02 # 稍低于阴线最高点 - - low_price = close_price - 0.03 # 最低价更接近阴线最高点 - - test_data.append({ - 'trade_date': date, - 'open': close_price + 0.02, - 'high': close_price + 0.05, - 'low': low_price, - 'close': close_price, - 'volume': 1000000 - }) - - return pd.DataFrame(test_data) - - -def test_pullback_detection(): - """测试回踩检测功能""" - logger.info("🧪 开始测试价格回踩阴线最高点提醒功能") - - # 初始化配置 - config = config_loader.load_config() - - # 初始化组件 - data_fetcher = ADataFetcher() - notification_config = config.get('notification', {}) - notification_manager = NotificationManager(notification_config) - - # 获取K线形态策略配置 - kline_config = config.get('strategy', {}).get('kline_pattern', {}) - strategy = KLinePatternStrategy(data_fetcher, notification_manager, kline_config) - - # 创建测试信号 - test_signal = create_test_signal() - logger.info(f"📊 创建测试信号: {test_signal['stock_code']}({test_signal['stock_name']})") - logger.info(f" - 信号日期: {test_signal['date']}") - logger.info(f" - 突破价格: {test_signal['breakout_price']}") - logger.info(f" - 阴线最高点: {test_signal['yin_high']}") - - # 添加到策略的监控列表 - strategy.add_triggered_signal(test_signal) - logger.info("✅ 测试信号已添加到监控列表") - - # 创建模拟回踩数据 - test_pullback_data = create_test_pullback_data() - logger.info(f"📈 创建模拟K线数据,共{len(test_pullback_data)}条记录") - - print("\n模拟K线数据:") - for _, row in test_pullback_data.iterrows(): - print(f" {row['trade_date'].strftime('%Y-%m-%d')}: " - f"开{row['open']:.2f} 高{row['high']:.2f} 低{row['low']:.2f} 收{row['close']:.2f}") - - # 检测回踩情况 - logger.info("🔍 开始检测回踩情况...") - pullback_alerts = strategy.check_pullback_signals(test_signal['stock_code'], test_pullback_data) - - if pullback_alerts: - logger.info(f"⚠️ 检测到 {len(pullback_alerts)} 个回踩提醒") - for i, alert in enumerate(pullback_alerts, 1): - logger.info(f" 提醒{i}: {alert['stock_code']} - 当前价格{alert['current_price']:.2f}," - f"回调{alert['pullback_pct']:.2f}%,距阴线高点{alert['distance_to_yin_high']:.2f}%") - - # 测试通知发送(如果启用了钉钉通知) - if notification_config.get('dingtalk', {}).get('enabled', False): - logger.info("📱 测试发送回踩提醒通知...") - success = notification_manager.send_pullback_alerts(pullback_alerts) - if success: - logger.info("✅ 回踩提醒通知发送成功") - else: - logger.warning("❌ 回踩提醒通知发送失败") - else: - logger.info("ℹ️ 钉钉通知未启用,跳过通知发送测试") - else: - logger.info("ℹ️ 未检测到回踩情况") - - # 测试完整的监控流程 - logger.info("\n🔍 测试完整的回踩监控流程...") - all_pullback_alerts = strategy.monitor_pullback_for_triggered_signals() - - logger.info("🎯 测试完成!") - if all_pullback_alerts: - logger.info(f"✅ 成功检测到 {len(all_pullback_alerts)} 个回踩提醒") - else: - logger.info("ℹ️ 当前监控中无回踩情况") - - -def test_strategy_config(): - """测试策略配置是否正确加载""" - logger.info("🔧 测试策略配置加载...") - - config = config_loader.load_config() - kline_config = config.get('strategy', {}).get('kline_pattern', {}) - - logger.info("📋 当前K线形态策略配置:") - logger.info(f" - 启用状态: {kline_config.get('enabled', False)}") - logger.info(f" - 前两阳线实体比例: {kline_config.get('min_entity_ratio', 0.55)}") - logger.info(f" - 最后阳线实体比例: {kline_config.get('final_yang_min_ratio', 0.40)}") - logger.info(f" - 回踩容忍度: {kline_config.get('pullback_tolerance', 0.02)}") - logger.info(f" - 监控天数: {kline_config.get('monitor_days', 30)}") - logger.info(f" - 支持时间周期: {kline_config.get('timeframes', ['daily'])}") - - -if __name__ == "__main__": - # 设置日志 - logger.remove() - logger.add(sys.stdout, level="INFO", format="{time:HH:mm:ss} | {level} | {message}") - - print("=" * 60) - print("🧪 价格回踩阴线最高点提醒功能测试") - print("=" * 60) - - try: - # 测试配置加载 - test_strategy_config() - print() - - # 测试回踩检测功能 - test_pullback_detection() - - except Exception as e: - logger.error(f"测试过程中发生错误: {e}") - import traceback - traceback.print_exc() - - print("\n" + "=" * 60) - print("🏁 测试结束") - print("=" * 60) \ No newline at end of file diff --git a/test_sentiment.py b/test_sentiment.py deleted file mode 100644 index de20743..0000000 --- a/test_sentiment.py +++ /dev/null @@ -1,135 +0,0 @@ -#!/usr/bin/env python3 -""" -舆情数据功能测试脚本 -""" - -import sys -from pathlib import Path - -# 将src目录添加到Python路径 -current_dir = Path(__file__).parent -src_dir = current_dir / "src" -sys.path.insert(0, str(src_dir)) - -from src.data.sentiment_fetcher import SentimentFetcher - - -def test_sentiment_features(): - """测试舆情功能""" - print("="*60) - print(" A股舆情数据功能测试") - print("="*60) - - fetcher = SentimentFetcher() - - # 1. 测试北向资金 - print("\n🌊 1. 北向资金数据测试") - print("-" * 30) - current_flow = fetcher.get_north_flow_current() - if not current_flow.empty: - row = current_flow.iloc[0] - print(f"总净流入: {row.get('net_tgt', 'N/A')} 万元") - print(f"沪股通: {row.get('net_hgt', 'N/A')} 万元") - print(f"深股通: {row.get('net_sgt', 'N/A')} 万元") - print(f"更新时间: {row.get('trade_time', 'N/A')}") - else: - print("未获取到当前北向资金数据") - - # 2. 测试热门股票 - print("\n🔥 2. 热门股票数据测试") - print("-" * 30) - hot_stocks = fetcher.get_popular_stocks_east_100() - if not hot_stocks.empty: - print(f"东财人气股票TOP5:") - for idx, row in hot_stocks.head(5).iterrows(): - code = row.get('stock_code', 'N/A') - name = row.get('short_name', 'N/A') - print(f" {idx + 1}. {code} - {name}") - else: - print("未获取到热门股票数据") - - # 3. 测试龙虎榜 - print("\n🐉 3. 龙虎榜数据测试") - print("-" * 30) - dragon_tiger = fetcher.get_dragon_tiger_list_daily() - if not dragon_tiger.empty: - print(f"今日龙虎榜 (共{len(dragon_tiger)}只股票):") - for idx, row in dragon_tiger.head(5).iterrows(): - code = row.get('stock_code', 'N/A') - name = row.get('short_name', 'N/A') - reason = row.get('reason', 'N/A') - print(f" {idx + 1}. {code} - {name}") - print(f" 上榜原因: {reason}") - else: - print("今日无龙虎榜数据") - - # 4. 测试热门概念 - print("\n💡 4. 热门概念数据测试") - print("-" * 30) - try: - hot_concepts = fetcher.get_hot_concept_ths_20() - if not hot_concepts.empty: - print(f"同花顺热门概念TOP5:") - for idx, row in hot_concepts.head(5).iterrows(): - name = row.get('concept_name', 'N/A') - change_pct = row.get('change_pct', 'N/A') - print(f" {idx + 1}. {name} (涨跌幅: {change_pct}%)") - else: - print("未获取到热门概念数据") - except Exception as e: - print(f"热门概念获取失败: {e}") - - # 5. 测试市场舆情综合概览 - print("\n📊 5. 市场舆情综合概览测试") - print("-" * 30) - try: - overview = fetcher.get_market_sentiment_overview() - if overview: - print("✅ 市场舆情概览获取成功") - - # 北向资金 - if 'north_flow' in overview: - north_data = overview['north_flow'] - print(f"北向资金: 总净流入 {north_data.get('net_total', 'N/A')} 万元") - - # 热门股票 - if 'hot_stocks_east' in overview and not overview['hot_stocks_east'].empty: - count = len(overview['hot_stocks_east']) - print(f"热门股票: 获取到 {count} 只") - - # 龙虎榜 - if 'dragon_tiger' in overview and not overview['dragon_tiger'].empty: - count = len(overview['dragon_tiger']) - print(f"龙虎榜: 获取到 {count} 只") - else: - print("市场舆情概览获取失败") - except Exception as e: - print(f"市场舆情概览测试失败: {e}") - - # 6. 测试个股舆情分析 - print("\n🔍 6. 个股舆情分析测试") - print("-" * 30) - test_stock = "000001.SZ" # 平安银行 - try: - analysis = fetcher.analyze_stock_sentiment(test_stock) - if 'error' not in analysis: - print(f"✅ {test_stock} 舆情分析成功") - print(f"东财人气榜: {'在榜' if analysis.get('in_popular_east', False) else '不在榜'}") - print(f"同花顺热门榜: {'在榜' if analysis.get('in_hot_ths', False) else '不在榜'}") - - if 'dragon_tiger' in analysis and not analysis['dragon_tiger'].empty: - print("✅ 今日上榜龙虎榜") - else: - print("❌ 今日未上榜龙虎榜") - else: - print(f"个股舆情分析失败: {analysis.get('error', '未知错误')}") - except Exception as e: - print(f"个股舆情分析测试失败: {e}") - - print("\n" + "="*60) - print(" 舆情数据功能测试完成") - print("="*60) - - -if __name__ == "__main__": - test_sentiment_features() \ No newline at end of file diff --git a/test_strategy.py b/test_strategy.py deleted file mode 100644 index 07e7416..0000000 --- a/test_strategy.py +++ /dev/null @@ -1,185 +0,0 @@ -#!/usr/bin/env python3 -""" -K线形态策略测试脚本 -""" - -import sys -from pathlib import Path -import pandas as pd -import numpy as np - -# 将src目录添加到Python路径 -current_dir = Path(__file__).parent -src_dir = current_dir / "src" -sys.path.insert(0, str(src_dir)) - -from src.data.data_fetcher import ADataFetcher -from src.utils.notification import NotificationManager -from src.strategy.kline_pattern_strategy import KLinePatternStrategy - - -def create_test_kline_data(): - """创建测试K线数据 - 包含两阳线+阴线+阳线形态""" - dates = pd.date_range('2023-01-01', periods=10, freq='D') - - # 模拟K线数据 - test_data = { - 'trade_date': dates, - 'open': [10.0, 10.5, 11.0, 12.0, 11.5, 11.0, 10.5, 11.0, 11.8, 12.5], - 'high': [10.8, 11.2, 11.8, 12.5, 12.0, 11.5, 11.2, 11.5, 12.2, 13.0], - 'low': [9.8, 10.3, 10.8, 11.8, 10.8, 10.5, 10.2, 10.8, 11.6, 12.3], - 'close':[10.6, 11.0, 11.5, 12.2, 11.0, 10.8, 11.2, 11.3, 12.0, 12.8], - 'volume': [1000] * 10 - } - - df = pd.DataFrame(test_data) - print("测试K线数据:") - print(df) - print() - - return df - - -def test_pattern_detection(): - """测试形态检测功能""" - print("="*60) - print(" K线形态检测功能测试") - print("="*60) - - # 创建测试配置 - strategy_config = { - 'min_entity_ratio': 0.55, - 'timeframes': ['daily'], - 'scan_stocks_count': 10, - 'analysis_days': 60 - } - - notification_config = { - 'dingtalk': { - 'enabled': False, - 'webhook_url': '' - } - } - - # 初始化组件 - data_fetcher = ADataFetcher() - notification_manager = NotificationManager(notification_config) - strategy = KLinePatternStrategy(data_fetcher, notification_manager, strategy_config) - - print("1. 策略信息:") - print(strategy.get_strategy_summary()) - - print("\n2. 测试K线特征计算:") - test_df = create_test_kline_data() - df_with_features = strategy.calculate_kline_features(test_df) - - print("添加特征后的数据:") - relevant_cols = ['trade_date', 'open', 'high', 'low', 'close', 'is_yang', 'is_yin', 'entity_ratio'] - print(df_with_features[relevant_cols]) - - print("\n3. 测试形态检测:") - signals = strategy.detect_pattern(df_with_features) - - if signals: - print(f"发现 {len(signals)} 个形态信号:") - for i, signal in enumerate(signals, 1): - print(f"\n信号 {i}:") - print(f" 日期: {signal['date']}") - print(f" 形态: {signal['pattern_type']}") - print(f" 突破价格: {signal['breakout_price']:.2f}") - print(f" 突破幅度: {signal['breakout_pct']:.2f}%") - print(f" 阳线1实体比例: {signal['yang1_entity_ratio']:.1%}") - print(f" 阳线2实体比例: {signal['yang2_entity_ratio']:.1%}") - else: - print("未发现形态信号") - - print("\n4. 测试真实股票数据:") - test_stocks = ["000001.SZ", "000002.SZ"] # 平安银行、万科A - - for stock_code in test_stocks: - print(f"\n分析股票: {stock_code}") - try: - results = strategy.analyze_stock(stock_code, days=30) # 分析最近30天 - - total_signals = sum(len(signals) for signals in results.values()) - print(f"总信号数: {total_signals}") - - for timeframe, signals in results.items(): - if signals: - print(f"{timeframe}: {len(signals)}个信号") - # 显示最新信号 - latest = signals[-1] - print(f" 最新: {latest['date']} {latest['breakout_price']:.2f}元") - else: - print(f"{timeframe}: 无信号") - - except Exception as e: - print(f"分析失败: {e}") - - print("\n5. 测试通知功能:") - try: - # 测试日志通知 - notification_manager.send_strategy_signal( - stock_code="TEST001", - stock_name="测试股票", - timeframe="daily", - signal_type="测试信号", - price=15.50, - additional_info={ - "阳线1实体比例": "65%", - "阳线2实体比例": "70%", - "突破幅度": "2.5%" - } - ) - print("✅ 通知功能测试完成(日志记录)") - - except Exception as e: - print(f"❌ 通知功能测试失败: {e}") - - print("\n" + "="*60) - print(" 策略测试完成") - print("="*60) - - -def test_weekly_monthly_conversion(): - """测试周线月线转换功能""" - print("\n测试周线/月线数据转换:") - - # 创建更多天数的测试数据 - dates = pd.date_range('2023-01-01', periods=50, freq='D') - - test_data = { - 'trade_date': dates, - 'open': np.random.uniform(10, 15, 50), - 'high': np.random.uniform(15, 20, 50), - 'low': np.random.uniform(8, 12, 50), - 'close': np.random.uniform(10, 15, 50), - 'volume': np.random.randint(1000, 5000, 50) - } - - daily_df = pd.DataFrame(test_data) - - strategy_config = {'min_entity_ratio': 0.55, 'timeframes': ['daily']} - notification_config = {'dingtalk': {'enabled': False}} - - data_fetcher = ADataFetcher() - notification_manager = NotificationManager(notification_config) - strategy = KLinePatternStrategy(data_fetcher, notification_manager, strategy_config) - - # 测试周线转换 - weekly_df = strategy._convert_to_weekly(daily_df) - print(f"日线数据: {len(daily_df)} 条") - print(f"周线数据: {len(weekly_df)} 条") - - # 测试月线转换 - monthly_df = strategy._convert_to_monthly(daily_df) - print(f"月线数据: {len(monthly_df)} 条") - - if not weekly_df.empty: - print("\n周线数据样本:") - print(weekly_df[['trade_date', 'open', 'high', 'low', 'close']].head()) - - -if __name__ == "__main__": - test_pattern_detection() - test_weekly_monthly_conversion() \ No newline at end of file diff --git a/test_strong_sectors_advanced.py b/test_strong_sectors_advanced.py deleted file mode 100644 index 361b575..0000000 --- a/test_strong_sectors_advanced.py +++ /dev/null @@ -1,313 +0,0 @@ -#!/usr/bin/env python3 -""" -高级强势板块筛选器 -筛选条件: -1. 本周收阳(周涨幅>0) -2. 周线级别创阶段新高(20周新高) -3. 成交额巨大(超过1000亿) -""" - -import sys -from pathlib import Path -import pandas as pd -import numpy as np -from datetime import datetime, timedelta - -# 添加项目根目录到路径 -current_dir = Path(__file__).parent -sys.path.insert(0, str(current_dir)) - -from src.data.tushare_fetcher import TushareFetcher -from loguru import logger - - -def get_trading_dates(days_back=100): - """获取过去N个交易日""" - dates = [] - current = datetime.now() - - while len(dates) < days_back: - if current.weekday() < 5: # 周一到周五 - dates.append(current.strftime('%Y%m%d')) - current -= timedelta(days=1) - - return sorted(dates) # 升序返回 - - -def filter_index_concepts(ths_concepts): - """过滤掉指数型板块""" - index_keywords = [ - '成份股', '样本股', '成分股', '50', '100', '300', '500', '1000', - '上证', '深证', '中证', '创业板', '科创板', '北证', - 'ETF', '指数', 'Index', '基准' - ] - - def is_concept_plate(name: str) -> bool: - name_lower = name.lower() - for keyword in index_keywords: - if keyword.lower() in name_lower: - return False - return True - - original_count = len(ths_concepts) - filtered_concepts = ths_concepts[ths_concepts['name'].apply(is_concept_plate)] - filtered_count = len(filtered_concepts) - - logger.info(f"过滤指数型板块: {original_count} -> {filtered_count} 个") - return filtered_concepts - - -def analyze_concept_strength(fetcher: TushareFetcher, concept_info, trading_dates): - """分析单个概念的强势特征""" - ts_code = concept_info['ts_code'] - name = concept_info['name'] - - try: - # 获取过去20周的数据(约100个交易日) - start_date = trading_dates[0] - end_date = trading_dates[-1] - - daily_data = fetcher.pro.ths_daily( - ts_code=ts_code, - start_date=start_date, - end_date=end_date - ) - - if daily_data.empty or len(daily_data) < 20: - return None - - # 按日期排序 - daily_data = daily_data.sort_values('trade_date') - daily_data.reset_index(drop=True, inplace=True) - - # 1. 计算本周涨幅(最近5个交易日) - recent_data = daily_data.tail(5) - if len(recent_data) < 2: - return None - - week_start_close = recent_data.iloc[0]['close'] - week_end_close = recent_data.iloc[-1]['close'] - week_change = (week_end_close - week_start_close) / week_start_close * 100 - - # 2. 检查是否本周收阳 - is_weekly_positive = week_change > 0 - - # 3. 计算20周新高(约100个交易日) - current_close = daily_data.iloc[-1]['close'] - past_20weeks_high = daily_data['high'].max() - is_20week_high = current_close >= past_20weeks_high * 0.99 # 允许1%的误差 - - # 4. 计算成交额(最近5日平均) - recent_turnover = recent_data['vol'].mean() * recent_data['close'].mean() # 简化计算 - turnover_100yi = recent_turnover / 100000000 # 转换为亿元 - - # 5. 计算技术指标 - # RSI相对强弱指数 - rsi = calculate_rsi(daily_data['close'].values) - - # 20日均线趋势 - ma20 = daily_data['close'].rolling(20).mean() - ma20_trend = (ma20.iloc[-1] - ma20.iloc[-10]) / ma20.iloc[-10] * 100 if len(ma20) >= 20 else 0 - - # 波动率 - volatility = daily_data['pct_change'].std() * np.sqrt(250) # 年化波动率 - - return { - 'ts_code': ts_code, - 'name': name, - 'week_change': week_change, - 'is_weekly_positive': is_weekly_positive, - 'is_20week_high': is_20week_high, - 'avg_turnover_yi': turnover_100yi, - 'current_close': current_close, - 'rsi': rsi, - 'ma20_trend': ma20_trend, - 'volatility': volatility, - 'data_length': len(daily_data) - } - - except Exception as e: - logger.debug(f"分析 {name} 失败: {e}") - return None - - -def calculate_rsi(prices, period=14): - """计算RSI指标""" - try: - delta = np.diff(prices) - gain = np.where(delta > 0, delta, 0) - loss = np.where(delta < 0, -delta, 0) - - avg_gain = np.mean(gain[-period:]) if len(gain) >= period else 0 - avg_loss = np.mean(loss[-period:]) if len(loss) >= period else 0 - - if avg_loss == 0: - return 100 - - rs = avg_gain / avg_loss - rsi = 100 - (100 / (1 + rs)) - return rsi - except: - return 50 # 默认值 - - -def find_strong_sectors(fetcher: TushareFetcher): - """寻找强势板块""" - try: - logger.info("🔍 开始寻找强势板块...") - - # 获取交易日期 - trading_dates = get_trading_dates(100) - logger.info(f"分析周期: {trading_dates[0]} 到 {trading_dates[-1]} (100个交易日)") - - # 获取同花顺概念列表 - ths_concepts = fetcher.pro.ths_index(exchange='A', type='N') - if ths_concepts.empty: - logger.error("未获取到概念数据") - return - - # 过滤指数型概念 - ths_concepts = filter_index_concepts(ths_concepts) - - # 分析概念强度 - strong_concepts = [] - total_concepts = min(80, len(ths_concepts)) # 分析前80个概念 - logger.info(f"分析 {total_concepts} 个概念...") - - for i, (_, concept) in enumerate(ths_concepts.head(total_concepts).iterrows()): - if i % 10 == 0: - logger.info(f"进度: {i+1}/{total_concepts}") - - result = analyze_concept_strength(fetcher, concept, trading_dates) - if result: - strong_concepts.append(result) - - if not strong_concepts: - logger.warning("未找到符合条件的强势板块") - return - - # 转换为DataFrame - df = pd.DataFrame(strong_concepts) - - # 强势板块筛选 - logger.info("🚀 应用强势板块筛选条件...") - - # 条件1:本周收阳 - weekly_positive = df[df['is_weekly_positive']] - logger.info(f"本周收阳概念: {len(weekly_positive)} 个") - - # 条件2:20周新高 - new_high_concepts = df[df['is_20week_high']] - logger.info(f"20周新高概念: {len(new_high_concepts)} 个") - - # 条件3:成交额超过1000亿(这里设置为10亿,因为单个概念1000亿太高) - high_turnover = df[df['avg_turnover_yi'] >= 10] - logger.info(f"成交额超过10亿概念: {len(high_turnover)} 个") - - # 综合强势板块(满足至少2个条件) - df['strength_score'] = ( - df['is_weekly_positive'].astype(int) + - df['is_20week_high'].astype(int) + - (df['avg_turnover_yi'] >= 10).astype(int) - ) - - # 按强势得分和周涨幅排序 - strong_sectors = df[df['strength_score'] >= 2].sort_values(['strength_score', 'week_change'], ascending=[False, False]) - - # 显示结果 - display_strong_sectors(df, strong_sectors, weekly_positive, new_high_concepts, high_turnover) - - return df - - except Exception as e: - logger.error(f"寻找强势板块失败: {e}") - - -def display_strong_sectors(df, strong_sectors, weekly_positive, new_high_concepts, high_turnover): - """显示强势板块分析结果""" - - print("\n" + "="*100) - print("🔍 强势板块综合分析报告") - print("="*100) - - # 1. 综合强势板块(满足多个条件) - if not strong_sectors.empty: - print(f"\n🚀 综合强势板块TOP10(满足2+条件):") - print(f"{'排名':<4} {'概念名称':<25} {'周涨幅':<10} {'强势分':<8} {'RSI':<8} {'成交额(亿)':<12} {'条件':<20}") - print("-" * 100) - - for i, (_, concept) in enumerate(strong_sectors.head(10).iterrows()): - rank = i + 1 - name = concept['name'][:23] + '..' if len(concept['name']) > 23 else concept['name'] - week_chg = f"{concept['week_change']:+.2f}%" - score = f"{concept['strength_score']}/3" - rsi = f"{concept['rsi']:.1f}" - turnover = f"{concept['avg_turnover_yi']:.1f}" - - conditions = [] - if concept['is_weekly_positive']: - conditions.append("周阳") - if concept['is_20week_high']: - conditions.append("新高") - if concept['avg_turnover_yi'] >= 10: - conditions.append("大额") - - condition_str = "+".join(conditions) - - print(f"{rank:<4} {name:<25} {week_chg:<10} {score:<8} {rsi:<8} {turnover:<12} {condition_str:<20}") - - # 2. 分类展示 - print(f"\n📊 分类统计:") - print(f" 本周收阳: {len(weekly_positive)} 个") - print(f" 20周新高: {len(new_high_concepts)} 个") - print(f" 大成交额: {len(high_turnover)} 个") - print(f" 综合强势: {len(strong_sectors)} 个") - - # 3. 本周收阳TOP10 - if not weekly_positive.empty: - top_weekly = weekly_positive.sort_values('week_change', ascending=False) - print(f"\n📈 本周收阳TOP10:") - for i, (_, concept) in enumerate(top_weekly.head(10).iterrows()): - print(f" {i+1:2d}. {concept['name']}: {concept['week_change']:+.2f}%") - - # 4. 20周新高概念 - if not new_high_concepts.empty: - print(f"\n🎯 20周新高概念TOP10:") - new_high_sorted = new_high_concepts.sort_values('week_change', ascending=False) - for i, (_, concept) in enumerate(new_high_sorted.head(10).iterrows()): - print(f" {i+1:2d}. {concept['name']}: {concept['week_change']:+.2f}% (RSI: {concept['rsi']:.1f})") - - # 5. 大成交额概念 - if not high_turnover.empty: - print(f"\n💰 大成交额概念TOP10:") - turnover_sorted = high_turnover.sort_values('avg_turnover_yi', ascending=False) - for i, (_, concept) in enumerate(turnover_sorted.head(10).iterrows()): - print(f" {i+1:2d}. {concept['name']}: {concept['avg_turnover_yi']:.1f}亿 ({concept['week_change']:+.2f}%)") - - # 6. 技术面强势 - strong_tech = df[(df['rsi'] > 60) & (df['ma20_trend'] > 0)].sort_values('week_change', ascending=False) - if not strong_tech.empty: - print(f"\n📊 技术面强势概念TOP10:") - for i, (_, concept) in enumerate(strong_tech.head(10).iterrows()): - print(f" {i+1:2d}. {concept['name']}: RSI {concept['rsi']:.1f}, MA20趋势 {concept['ma20_trend']:+.2f}%") - - -def main(): - """主函数""" - logger.info("🚀 开始高级强势板块分析...") - - # 初始化Tushare数据获取器 - token = "0ed6419a00d8923dc19c0b58fc92d94c9a0696949ab91a13aa58a0cc" - fetcher = TushareFetcher(token=token) - - # 寻找强势板块 - result_df = find_strong_sectors(fetcher) - - if result_df is not None: - logger.info("✅ 强势板块分析完成!") - else: - logger.error("❌ 强势板块分析失败!") - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/test_ths_concepts.py b/test_ths_concepts.py deleted file mode 100644 index cb20414..0000000 --- a/test_ths_concepts.py +++ /dev/null @@ -1,318 +0,0 @@ -#!/usr/bin/env python3 -""" -测试同花顺概念板块数据 -""" - -import sys -from pathlib import Path -import pandas as pd -from datetime import datetime, timedelta - -# 添加项目根目录到路径 -current_dir = Path(__file__).parent -sys.path.insert(0, str(current_dir)) - -from src.data.tushare_fetcher import TushareFetcher -from loguru import logger - - -def explore_ths_interfaces(fetcher: TushareFetcher): - """探索同花顺相关接口""" - try: - if not fetcher.pro: - logger.error("需要Tushare Pro权限") - return - - logger.info("🔍 探索同花顺概念板块相关接口...") - - # 1. 获取同花顺概念指数列表 - try: - logger.info("1. 获取同花顺概念指数列表...") - ths_index = fetcher.pro.ths_index(exchange='A', type='N') - logger.info(f"获取到 {len(ths_index)} 个同花顺概念指数") - - if not ths_index.empty: - print("\n📋 同花顺概念指数(前20个):") - print(f"{'代码':<15} {'名称':<30} {'发布日期':<12}") - print("-" * 60) - for _, index in ths_index.head(20).iterrows(): - code = index['ts_code'] - name = index['name'][:28] + '..' if len(index['name']) > 28 else index['name'] - pub_date = index.get('list_date', 'N/A') - print(f"{code:<15} {name:<30} {pub_date:<12}") - - return ths_index - - except Exception as e: - logger.error(f"获取同花顺概念指数失败: {e}") - - # 2. 尝试获取同花顺概念成分股 - try: - logger.info("\n2. 测试获取同花顺概念成分股...") - # 尝试获取一个概念的成分股 - sample_concept = "885311.TI" # 智能电网 - ths_member = fetcher.pro.ths_member(ts_code=sample_concept) - logger.info(f"获取智能电网概念成分股: {len(ths_member)} 只") - - if not ths_member.empty: - print(f"\n📊 智能电网概念成分股(前10只):") - for _, stock in ths_member.head(10).iterrows(): - print(f" {stock['code']}: {stock.get('name', 'N/A')}") - - except Exception as e: - logger.error(f"获取同花顺概念成分股失败: {e}") - - # 3. 尝试获取同花顺概念日行情 - try: - logger.info("\n3. 测试获取同花顺概念日行情...") - today = datetime.now().strftime('%Y%m%d') - yesterday = (datetime.now() - timedelta(days=1)).strftime('%Y%m%d') - - ths_daily = fetcher.pro.ths_daily( - ts_code="885311.TI", # 智能电网 - start_date=yesterday, - end_date=today - ) - logger.info(f"获取智能电网概念日行情: {len(ths_daily)} 条记录") - - if not ths_daily.empty: - print(f"\n📈 智能电网概念近期行情:") - print(ths_daily[['trade_date', 'close', 'pct_chg', 'vol']].head()) - - except Exception as e: - logger.error(f"获取同花顺概念日行情失败: {e}") - - # 4. 探索其他可能的同花顺接口 - try: - logger.info("\n4. 探索同花顺行业分类...") - ths_industry = fetcher.pro.ths_index(exchange='A', type='I') - logger.info(f"获取到 {len(ths_industry)} 个同花顺行业指数") - - if not ths_industry.empty: - print(f"\n📊 同花顺行业指数(前10个):") - for _, index in ths_industry.head(10).iterrows(): - print(f" {index['ts_code']}: {index['name']}") - - except Exception as e: - logger.error(f"获取同花顺行业分类失败: {e}") - - except Exception as e: - logger.error(f"探索同花顺接口失败: {e}") - - -def get_ths_concept_7day_ranking(fetcher: TushareFetcher): - """获取同花顺概念板块过去7个交易日排名""" - try: - if not fetcher.pro: - logger.error("需要Tushare Pro权限") - return - - logger.info("📊 计算同花顺概念板块过去7个交易日涨幅...") - - # 获取同花顺概念指数列表 - ths_concepts = fetcher.pro.ths_index(exchange='A', type='N') - if ths_concepts.empty: - logger.error("未获取到同花顺概念指数") - return - - # 过滤掉指数型板块,只保留真正的概念板块 - logger.info("过滤指数型板块...") - index_keywords = [ - '成份股', '样本股', '成分股', '50', '100', '300', '500', '1000', - '上证', '深证', '中证', '创业板', '科创板', '北证', - 'ETF', '指数', 'Index', '基准' - ] - - def is_concept_plate(name: str) -> bool: - """判断是否为真正的概念板块""" - name_lower = name.lower() - for keyword in index_keywords: - if keyword.lower() in name_lower: - return False - return True - - # 过滤数据 - original_count = len(ths_concepts) - ths_concepts = ths_concepts[ths_concepts['name'].apply(is_concept_plate)] - filtered_count = len(ths_concepts) - - logger.info(f"过滤结果: {original_count} -> {filtered_count} 个概念板块(剔除{original_count - filtered_count}个指数型板块)") - - # 获取过去7个交易日 - trading_dates = [] - current = datetime.now() - - while len(trading_dates) < 7: - if current.weekday() < 5: # 周一到周五 - trading_dates.append(current.strftime('%Y%m%d')) - current -= timedelta(days=1) - - trading_dates.reverse() # 升序排列,最早的日期在前 - - if len(trading_dates) < 2: - logger.warning("交易日不足") - return - - start_date = trading_dates[0] # 7个交易日前 - end_date = trading_dates[-1] # 最新交易日 - - logger.info(f"分析周期: {start_date} 到 {end_date} (过去7个交易日)") - - # 计算各概念的7日涨幅 - concept_performance = [] - - # 限制分析数量,避免API调用过多 - sample_concepts = ths_concepts.head(50) # 分析前50个概念(过滤后数量减少) - logger.info(f"分析前 {len(sample_concepts)} 个同花顺概念...") - - for _, concept in sample_concepts.iterrows(): - ts_code = concept['ts_code'] - name = concept['name'] - - try: - # 获取过去7个交易日行情数据 - daily_data = fetcher.pro.ths_daily( - ts_code=ts_code, - start_date=start_date, - end_date=end_date - ) - - if not daily_data.empty: - # 检查数据结构 - logger.debug(f"{name} 数据字段: {list(daily_data.columns)}") - - if len(daily_data) >= 2: - # 按日期排序 - daily_data = daily_data.sort_values('trade_date') - - start_close = daily_data.iloc[0]['close'] - end_close = daily_data.iloc[-1]['close'] - - if start_close > 0: - period_change = (end_close - start_close) / start_close * 100 - - # 检查涨跌幅字段名 - pct_change_col = None - for col in ['pct_chg', 'pct_change', 'change']: - if col in daily_data.columns: - pct_change_col = col - break - - latest_daily_change = daily_data.iloc[-1][pct_change_col] if pct_change_col else 0 - - concept_performance.append({ - 'ts_code': ts_code, - 'name': name, - 'period_change': period_change, - 'start_close': start_close, - 'end_close': end_close, - 'latest_daily_change': latest_daily_change, - 'trading_days': len(daily_data) - }) - - logger.debug(f"{name}: 过去7日{period_change:+.2f}%") - - except Exception as e: - logger.debug(f"获取 {name} 数据失败: {e}") - continue - - # 显示结果 - if concept_performance: - df_ths = pd.DataFrame(concept_performance) - df_ths = df_ths.sort_values('period_change', ascending=False) - - print(f"\n" + "="*80) - print("📈 同花顺概念板块过去7个交易日涨幅排行榜") - print("="*80) - print(f"{'排名':<4} {'概念名称':<30} {'7日涨幅':<12} {'今日涨幅':<12} {'指数代码':<15}") - print("-" * 80) - - for i, (_, concept) in enumerate(df_ths.iterrows()): - rank = i + 1 - name = concept['name'][:28] + '..' if len(concept['name']) > 28 else concept['name'] - period_chg = f"{concept['period_change']:+.2f}%" - daily_chg = f"{concept['latest_daily_change']:+.2f}%" - ts_code = concept['ts_code'] - - print(f"{rank:<4} {name:<30} {period_chg:<12} {daily_chg:<12} {ts_code:<15}") - - # 强势概念TOP15 - print(f"\n🚀 同花顺强势概念TOP15:") - for i, (_, concept) in enumerate(df_ths.head(15).iterrows()): - print(f" {i+1:2d}. {concept['name']}: {concept['period_change']:+.2f}%") - - # 弱势概念TOP10 - print(f"\n📉 同花顺弱势概念TOP10:") - weak_concepts = df_ths.tail(10).iloc[::-1] # 反转顺序 - for i, (_, concept) in enumerate(weak_concepts.iterrows()): - print(f" {i+1:2d}. {concept['name']}: {concept['period_change']:+.2f}%") - - return df_ths - - else: - logger.warning("未能计算同花顺概念涨幅") - - except Exception as e: - logger.error(f"获取同花顺概念排名失败: {e}") - - -def compare_concept_sources(fetcher: TushareFetcher): - """对比东财和同花顺概念数据""" - try: - logger.info("📊 对比东财 vs 同花顺概念数据...") - - # 获取东财概念数量 - today = datetime.now().strftime('%Y%m%d') - dc_concepts = fetcher.pro.dc_index(trade_date=today) - dc_count = len(dc_concepts) if not dc_concepts.empty else 0 - - # 获取同花顺概念数量 - ths_concepts = fetcher.pro.ths_index(exchange='A', type='N') - ths_count = len(ths_concepts) if not ths_concepts.empty else 0 - - print(f"\n📊 概念板块数据源对比:") - print(f" 东财概念板块: {dc_count} 个") - print(f" 同花顺概念: {ths_count} 个") - - # 数据特点对比 - print(f"\n📈 数据特点对比:") - print(f" 东财概念:") - print(f" - 更新频率: 每日更新") - print(f" - 数据字段: 涨跌幅、市值、上涨下跌股数等") - print(f" - 适用场景: 实时概念轮动分析") - - print(f" 同花顺概念:") - print(f" - 更新频率: 每日更新") - print(f" - 数据字段: 指数价格、涨跌幅、成交量等") - print(f" - 适用场景: 概念指数走势分析") - - except Exception as e: - logger.error(f"对比概念数据源失败: {e}") - - -def main(): - """主函数""" - logger.info("🚀 开始探索同花顺概念板块数据...") - - # 初始化Tushare数据获取器 - token = "0ed6419a00d8923dc19c0b58fc92d94c9a0696949ab91a13aa58a0cc" - fetcher = TushareFetcher(token=token) - - # 1. 探索同花顺接口 - ths_index = explore_ths_interfaces(fetcher) - - print("\n" + "="*80 + "\n") - - # 2. 计算同花顺概念过去7个交易日排名 - get_ths_concept_7day_ranking(fetcher) - - print("\n" + "="*80 + "\n") - - # 3. 对比数据源 - compare_concept_sources(fetcher) - - logger.info("✅ 探索完成!") - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/test_tushare_sectors.py b/test_tushare_sectors.py deleted file mode 100644 index 010f880..0000000 --- a/test_tushare_sectors.py +++ /dev/null @@ -1,238 +0,0 @@ -#!/usr/bin/env python3 -""" -使用Tushare直接获取板块数据的测试 -""" - -import sys -from pathlib import Path -import pandas as pd -from datetime import datetime, timedelta - -# 添加项目根目录到路径 -current_dir = Path(__file__).parent -sys.path.insert(0, str(current_dir)) - -from src.data.tushare_fetcher import TushareFetcher -from loguru import logger - - -def get_concept_sectors(fetcher: TushareFetcher): - """获取概念板块数据""" - try: - if not fetcher.pro: - logger.error("需要Tushare Pro权限") - return - - logger.info("尝试获取概念板块数据...") - - # 1. 尝试获取概念板块列表 - try: - concept_list = fetcher.pro.concept() - logger.info(f"获取到 {len(concept_list)} 个概念板块") - - if not concept_list.empty: - print("概念板块列表(前10个):") - for _, concept in concept_list.head(10).iterrows(): - print(f" {concept['code']}: {concept['name']}") - - except Exception as e: - logger.error(f"获取概念板块列表失败: {e}") - - # 2. 尝试获取同花顺概念指数 - try: - ths_concept = fetcher.pro.ths_index(exchange='A', type='N') - logger.info(f"获取同花顺概念指数: {len(ths_concept)} 个") - - if not ths_concept.empty: - print("\n同花顺概念指数(前10个):") - for _, index in ths_concept.head(10).iterrows(): - print(f" {index['ts_code']}: {index['name']}") - - except Exception as e: - logger.error(f"获取同花顺概念指数失败: {e}") - - # 3. 尝试获取行业指数 - try: - industry_index = fetcher.pro.index_basic(market='SW') - logger.info(f"获取申万行业指数: {len(industry_index)} 个") - - if not industry_index.empty: - print("\n申万行业指数(前10个):") - for _, index in industry_index.head(10).iterrows(): - print(f" {index['ts_code']}: {index['name']}") - - except Exception as e: - logger.error(f"获取申万行业指数失败: {e}") - - except Exception as e: - logger.error(f"获取板块数据失败: {e}") - - -def analyze_hot_concepts(fetcher: TushareFetcher): - """分析热门概念板块""" - try: - if not fetcher.pro: - logger.error("需要Tushare Pro权限") - return - - logger.info("分析热门概念板块...") - - # 获取今日涨跌停统计 - today = datetime.now().strftime('%Y%m%d') - - try: - # 获取涨停股票 - limit_up = fetcher.pro.limit_list(trade_date=today, limit_type='U') - logger.info(f"今日涨停股票: {len(limit_up)} 只") - - if not limit_up.empty: - print(f"\n今日涨停股票(前10只):") - for _, stock in limit_up.head(10).iterrows(): - print(f" {stock['ts_code']}: {stock['name']} (+{stock['pct_chg']:.2f}%)") - - # 分析涨停股票的行业分布 - if 'industry' in limit_up.columns: - industry_counts = limit_up['industry'].value_counts() - print(f"\n涨停股票行业分布:") - for industry, count in industry_counts.head(5).items(): - print(f" {industry}: {count}只") - - except Exception as e: - logger.error(f"获取涨停数据失败: {e}") - - # 获取龙虎榜数据 - try: - top_list = fetcher.pro.top_list(trade_date=today) - logger.info(f"今日龙虎榜: {len(top_list)} 只股票") - - if not top_list.empty: - print(f"\n今日龙虎榜股票(前5只):") - for _, stock in top_list.head(5).iterrows(): - print(f" {stock['ts_code']}: {stock['name']} 净买入: {stock['amount']:.0f}万元") - - except Exception as e: - logger.error(f"获取龙虎榜数据失败: {e}") - - except Exception as e: - logger.error(f"分析热门概念失败: {e}") - - -def get_sector_performance_direct(fetcher: TushareFetcher): - """直接通过指数数据获取板块表现""" - try: - if not fetcher.pro: - logger.error("需要Tushare Pro权限") - return - - logger.info("通过指数数据分析板块表现...") - - # 获取申万一级行业指数 - try: - sw_index = fetcher.pro.index_basic(market='SW', level='L1') - logger.info(f"获取申万一级行业指数: {len(sw_index)} 个") - - if sw_index.empty: - logger.warning("未获取到申万行业指数") - return - - # 获取最近两个交易日的指数行情 - end_date = datetime.now().strftime('%Y%m%d') - start_date = (datetime.now() - timedelta(days=7)).strftime('%Y%m%d') - - sector_performance = [] - - for _, index in sw_index.head(15).iterrows(): # 分析前15个行业 - ts_code = index['ts_code'] - name = index['name'] - - try: - # 获取指数行情 - index_data = fetcher.pro.index_daily( - ts_code=ts_code, - start_date=start_date, - end_date=end_date - ) - - if not index_data.empty and len(index_data) >= 2: - # 计算涨跌幅 - latest = index_data.iloc[0] - previous = index_data.iloc[1] - - change_pct = (latest['close'] - previous['close']) / previous['close'] * 100 - - sector_performance.append({ - 'name': name, - 'code': ts_code, - 'change_pct': change_pct, - 'latest_close': latest['close'], - 'volume': latest['vol'] - }) - - logger.debug(f"{name}: {change_pct:+.2f}%") - - except Exception as e: - logger.debug(f"获取 {name} 指数数据失败: {e}") - continue - - # 输出结果 - if sector_performance: - df = pd.DataFrame(sector_performance) - df = df.sort_values('change_pct', ascending=False) - - print("\n" + "="*60) - print("📈 申万行业指数表现排行") - print("="*60) - print(f"{'排名':<4} {'行业名称':<20} {'涨跌幅':<10} {'最新点位':<10}") - print("-" * 60) - - for i, (_, row) in enumerate(df.iterrows()): - rank = i + 1 - name = row['name'][:18] + '..' if len(row['name']) > 18 else row['name'] - change = f"{row['change_pct']:+.2f}%" - close = f"{row['latest_close']:.2f}" - - print(f"{rank:<4} {name:<20} {change:<10} {close:<10}") - - # 强势行业 - print(f"\n🚀 强势行业TOP5:") - for _, row in df.head(5).iterrows(): - print(f" {row['name']}: {row['change_pct']:+.2f}%") - - # 弱势行业 - print(f"\n📉 弱势行业TOP5:") - for _, row in df.tail(5).iterrows(): - print(f" {row['name']}: {row['change_pct']:+.2f}%") - - except Exception as e: - logger.error(f"获取申万指数失败: {e}") - - except Exception as e: - logger.error(f"分析板块表现失败: {e}") - - -def main(): - """主函数""" - logger.info("测试Tushare板块数据接口...") - - # 初始化Tushare数据获取器 - token = "0ed6419a00d8923dc19c0b58fc92d94c9a0696949ab91a13aa58a0cc" - fetcher = TushareFetcher(token=token) - - # 1. 获取板块分类数据 - get_concept_sectors(fetcher) - - print("\n" + "="*80 + "\n") - - # 2. 分析热门概念 - analyze_hot_concepts(fetcher) - - print("\n" + "="*80 + "\n") - - # 3. 通过指数直接获取板块表现 - get_sector_performance_direct(fetcher) - - logger.info("测试完成!") - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/test_weekly_concept_ranking.py b/test_weekly_concept_ranking.py deleted file mode 100644 index 50e3e42..0000000 --- a/test_weekly_concept_ranking.py +++ /dev/null @@ -1,244 +0,0 @@ -#!/usr/bin/env python3 -""" -按本周总涨幅排名东财概念板块 -""" - -import sys -from pathlib import Path -import pandas as pd -from datetime import datetime, timedelta - -# 添加项目根目录到路径 -current_dir = Path(__file__).parent -sys.path.insert(0, str(current_dir)) - -from src.data.tushare_fetcher import TushareFetcher -from loguru import logger - - -def get_this_week_dates(): - """获取本周的交易日期(周一到周五)""" - today = datetime.now() - # 获取本周一 - monday = today - timedelta(days=today.weekday()) - # 获取本周五(或今天如果还没到周五) - friday = monday + timedelta(days=4) - if friday > today: - friday = today - - # 生成本周所有交易日 - dates = [] - current = monday - while current <= friday: - if current.weekday() < 5: # 周一到周五 - dates.append(current.strftime('%Y%m%d')) - current += timedelta(days=1) - - return dates - - -def calculate_weekly_concept_performance(fetcher: TushareFetcher): - """计算概念板块本周总涨幅""" - try: - if not fetcher.pro: - logger.error("需要Tushare Pro权限") - return None - - logger.info("🚀 计算概念板块本周总涨幅排名...") - - # 获取本周交易日 - week_dates = get_this_week_dates() - logger.info(f"本周交易日: {week_dates}") - - if len(week_dates) < 2: - logger.warning("本周交易日不足,无法计算周涨幅") - return None - - start_date = week_dates[0] # 周一 - end_date = week_dates[-1] # 最新交易日 - - logger.info(f"分析周期: {start_date} 到 {end_date}") - - # 获取周一的概念板块数据(基准) - logger.info(f"获取 {start_date} 的概念数据作为基准...") - start_concepts = fetcher.pro.dc_index(trade_date=start_date) - - # 获取最新交易日的概念板块数据 - logger.info(f"获取 {end_date} 的概念数据...") - end_concepts = fetcher.pro.dc_index(trade_date=end_date) - - if start_concepts.empty or end_concepts.empty: - logger.error("无法获取概念板块数据") - return None - - logger.info(f"周一概念数据: {len(start_concepts)} 个") - logger.info(f"最新概念数据: {len(end_concepts)} 个") - - # 计算本周涨幅 - weekly_performance = [] - - # 以最新数据为准,匹配周一数据 - for _, end_concept in end_concepts.iterrows(): - ts_code = end_concept['ts_code'] - name = end_concept['name'] - end_mv = end_concept['total_mv'] - - # 查找对应的周一数据 - start_data = start_concepts[start_concepts['ts_code'] == ts_code] - - if not start_data.empty: - start_mv = start_data.iloc[0]['total_mv'] - - # 计算本周总涨幅 - if start_mv > 0: - weekly_change = (end_mv - start_mv) / start_mv * 100 - - weekly_performance.append({ - 'ts_code': ts_code, - 'name': name, - 'weekly_change': weekly_change, - 'start_mv': start_mv, - 'end_mv': end_mv, - 'latest_daily_change': end_concept['pct_change'], - 'up_num': end_concept.get('up_num', 0), - 'down_num': end_concept.get('down_num', 0) - }) - - if not weekly_performance: - logger.error("无法计算概念板块周涨幅") - return None - - # 转换为DataFrame并按周涨幅排序 - df_weekly = pd.DataFrame(weekly_performance) - df_weekly = df_weekly.sort_values('weekly_change', ascending=False) - - logger.info(f"成功计算 {len(df_weekly)} 个概念板块的本周涨幅") - - return df_weekly - - except Exception as e: - logger.error(f"计算概念板块周涨幅失败: {e}") - return None - - -def display_weekly_ranking(df_weekly: pd.DataFrame): - """显示本周涨幅排名""" - if df_weekly is None or df_weekly.empty: - logger.error("无数据可显示") - return - - print("\n" + "="*100) - print("📈 东财概念板块本周涨幅排行榜") - print("="*100) - print(f"{'排名':<4} {'概念名称':<25} {'本周涨幅':<12} {'今日涨幅':<12} {'上涨股数':<8} {'下跌股数':<8} {'概念代码':<15}") - print("-" * 100) - - for i, (_, concept) in enumerate(df_weekly.head(30).iterrows()): - rank = i + 1 - name = concept['name'][:23] + '..' if len(concept['name']) > 23 else concept['name'] - weekly_chg = f"{concept['weekly_change']:+.2f}%" - daily_chg = f"{concept['latest_daily_change']:+.2f}%" - up_num = f"{concept['up_num']:.0f}" - down_num = f"{concept['down_num']:.0f}" - ts_code = concept['ts_code'] - - print(f"{rank:<4} {name:<25} {weekly_chg:<12} {daily_chg:<12} {up_num:<8} {down_num:<8} {ts_code:<15}") - - # 强势概念TOP15 - print(f"\n🚀 本周强势概念板块TOP15:") - for i, (_, concept) in enumerate(df_weekly.head(15).iterrows()): - print(f" {i+1:2d}. {concept['name']}: {concept['weekly_change']:+.2f}% (今日{concept['latest_daily_change']:+.2f}%)") - - # 弱势概念TOP10 - print(f"\n📉 本周弱势概念板块TOP10:") - weak_concepts = df_weekly.tail(10).iloc[::-1] # 反转顺序 - for i, (_, concept) in enumerate(weak_concepts.iterrows()): - print(f" {i+1:2d}. {concept['name']}: {concept['weekly_change']:+.2f}% (今日{concept['latest_daily_change']:+.2f}%)") - - # 统计分析 - print(f"\n📊 本周概念板块统计:") - total_concepts = len(df_weekly) - positive_concepts = len(df_weekly[df_weekly['weekly_change'] > 0]) - negative_concepts = len(df_weekly[df_weekly['weekly_change'] < 0]) - - print(f" 总概念数量: {total_concepts}") - print(f" 上涨概念: {positive_concepts} ({positive_concepts/total_concepts*100:.1f}%)") - print(f" 下跌概念: {negative_concepts} ({negative_concepts/total_concepts*100:.1f}%)") - print(f" 平均涨幅: {df_weekly['weekly_change'].mean():+.2f}%") - print(f" 涨幅中位数: {df_weekly['weekly_change'].median():+.2f}%") - - return df_weekly - - -def analyze_top_concepts_detail(fetcher: TushareFetcher, df_weekly: pd.DataFrame, top_n=5): - """分析TOP概念的详细趋势""" - if df_weekly is None or df_weekly.empty: - return - - logger.info(f"详细分析TOP{top_n}强势概念...") - - print(f"\n" + "="*80) - print(f"📊 TOP{top_n}强势概念详细分析") - print("="*80) - - week_dates = get_this_week_dates() - - for i, (_, concept) in enumerate(df_weekly.head(top_n).iterrows()): - concept_code = concept['ts_code'] - concept_name = concept['name'] - - print(f"\n📈 {i+1}. {concept_name} ({concept_code})") - print(f" 本周总涨幅: {concept['weekly_change']:+.2f}%") - - # 获取每日详细数据 - daily_data = [] - for date in week_dates: - try: - daily_concept = fetcher.pro.dc_index( - trade_date=date, - ts_code=concept_code - ) - - if not daily_concept.empty: - daily_data.append({ - 'date': date, - 'pct_change': daily_concept.iloc[0]['pct_change'], - 'total_mv': daily_concept.iloc[0]['total_mv'], - 'up_num': daily_concept.iloc[0].get('up_num', 0), - 'down_num': daily_concept.iloc[0].get('down_num', 0) - }) - - except Exception as e: - logger.debug(f"获取 {concept_code} 在 {date} 的数据失败: {e}") - continue - - # 显示每日走势 - if daily_data: - print(f" 每日走势:") - for data in daily_data: - print(f" {data['date']}: {data['pct_change']:+6.2f}% (上涨{data['up_num']:.0f}只/下跌{data['down_num']:.0f}只)") - - -def main(): - """主函数""" - logger.info("🚀 开始计算东财概念板块本周涨幅排名...") - - # 初始化Tushare数据获取器 - token = "0ed6419a00d8923dc19c0b58fc92d94c9a0696949ab91a13aa58a0cc" - fetcher = TushareFetcher(token=token) - - # 1. 计算本周涨幅 - df_weekly = calculate_weekly_concept_performance(fetcher) - - # 2. 显示排名 - if df_weekly is not None: - display_weekly_ranking(df_weekly) - - # 3. 详细分析TOP5概念 - analyze_top_concepts_detail(fetcher, df_weekly, top_n=5) - - logger.info("✅ 分析完成!") - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/test_weekly_sectors.py b/test_weekly_sectors.py deleted file mode 100644 index ac12920..0000000 --- a/test_weekly_sectors.py +++ /dev/null @@ -1,258 +0,0 @@ -#!/usr/bin/env python3 -""" -测试使用Tushare找出本周强势板块 -""" - -import sys -from pathlib import Path -import pandas as pd -from datetime import datetime, timedelta - -# 添加项目根目录到路径 -current_dir = Path(__file__).parent -sys.path.insert(0, str(current_dir)) - -from src.data.tushare_fetcher import TushareFetcher -from loguru import logger - - -def get_this_week_dates(): - """获取本周的开始和结束日期""" - today = datetime.now() - # 获取本周一 - monday = today - timedelta(days=today.weekday()) - # 获取本周五(或今天如果还没到周五) - friday = monday + timedelta(days=4) - if friday > today: - friday = today - - return monday.strftime('%Y%m%d'), friday.strftime('%Y%m%d') - - -def analyze_sector_performance(fetcher: TushareFetcher): - """分析行业板块表现""" - try: - logger.info("开始分析本周强势板块...") - - # 获取本周日期范围 - start_date, end_date = get_this_week_dates() - logger.info(f"分析时间范围: {start_date} 到 {end_date}") - - if not fetcher.pro: - logger.error("需要Tushare Pro权限才能获取行业数据") - return - - # 直接使用概念和行业数据分析 - logger.info("通过个股数据分析板块表现...") - - # 获取股票基本信息(包含行业分类) - try: - stock_basic = fetcher.pro.stock_basic( - exchange='', - list_status='L', - fields='ts_code,symbol,name,area,industry,market' - ) - logger.info(f"获取到 {len(stock_basic)} 只股票基本信息") - - except Exception as e: - logger.error(f"获取股票基本信息失败: {e}") - return - - # 按行业分组分析 - industry_performance = {} - - # 取样分析(避免API请求过多) - sample_stocks = stock_basic.sample(min(200, len(stock_basic))) # 随机取200只股票 - logger.info(f"随机抽样 {len(sample_stocks)} 只股票进行分析...") - - for _, stock in sample_stocks.iterrows(): - ts_code = stock['ts_code'] - industry = stock['industry'] - stock_name = stock['name'] - - if pd.isna(industry) or industry == '': - continue - - try: - # 获取个股本周数据 - stock_data = fetcher.pro.daily( - ts_code=ts_code, - start_date=start_date, - end_date=end_date - ) - - if not stock_data.empty and len(stock_data) >= 1: - # 获取最新价格和前一个交易日价格 - if len(stock_data) >= 2: - latest_close = stock_data.iloc[0]['close'] - prev_close = stock_data.iloc[1]['close'] - else: - # 如果只有一天数据,获取开盘价作为对比 - latest_close = stock_data.iloc[0]['close'] - prev_close = stock_data.iloc[0]['open'] - - if prev_close > 0: - change_pct = (latest_close - prev_close) / prev_close * 100 - - # 按行业归类 - if industry not in industry_performance: - industry_performance[industry] = { - 'stock_changes': [], - 'stock_count': 0, - 'stock_names': [] - } - - industry_performance[industry]['stock_changes'].append(change_pct) - industry_performance[industry]['stock_count'] += 1 - industry_performance[industry]['stock_names'].append(f"{stock_name}({change_pct:+.2f}%)") - - logger.debug(f"{stock_name} ({industry}): {change_pct:+.2f}%") - - except Exception as e: - logger.debug(f"分析个股 {ts_code} 失败: {e}") - continue - - # 计算各行业平均表现 - industry_results = [] - for industry, data in industry_performance.items(): - if data['stock_count'] >= 3: # 至少要有3只股票才参与排名 - avg_change = sum(data['stock_changes']) / len(data['stock_changes']) - industry_results.append({ - 'industry_name': industry, - 'avg_change_pct': avg_change, - 'stock_count': data['stock_count'], - 'best_stocks': sorted(data['stock_names'], key=lambda x: float(x.split('(')[1].split('%')[0]), reverse=True)[:3] - }) - - # 3. 分析结果 - if industry_results: - df_performance = pd.DataFrame(industry_results) - - # 按平均涨跌幅排序 - df_performance = df_performance.sort_values('avg_change_pct', ascending=False) - - logger.info("\n" + "="*80) - logger.info("📈 本周强势板块排行榜(基于抽样股票分析)") - logger.info("="*80) - - print(f"{'排名':<4} {'行业名称':<20} {'平均涨跌幅':<12} {'样本数量':<8} {'代表个股':<30}") - print("-" * 80) - - for i, (_, row) in enumerate(df_performance.head(15).iterrows()): - rank = i + 1 - industry_name = row['industry_name'][:18] + '..' if len(row['industry_name']) > 18 else row['industry_name'] - change_pct = f"{row['avg_change_pct']:+.2f}%" - stock_count = f"{row['stock_count']}只" - best_stock = row['best_stocks'][0] if row['best_stocks'] else "无数据" - - print(f"{rank:<4} {industry_name:<20} {change_pct:<12} {stock_count:<8} {best_stock:<30}") - - # 输出强势板块(涨幅前5) - top_sectors = df_performance.head(5) - logger.info(f"\n🚀 本周TOP5强势板块:") - for i, (_, sector) in enumerate(top_sectors.iterrows()): - logger.info(f" {i+1}. {sector['industry_name']}: {sector['avg_change_pct']:+.2f}% (样本{sector['stock_count']}只)") - for j, stock in enumerate(sector['best_stocks'][:3]): - logger.info(f" └─ {stock}") - - # 输出弱势板块(跌幅前5) - weak_sectors = df_performance.tail(5) - logger.info(f"\n📉 本周TOP5弱势板块:") - for i, (_, sector) in enumerate(weak_sectors.iterrows()): - logger.info(f" {i+1}. {sector['industry_name']}: {sector['avg_change_pct']:+.2f}% (样本{sector['stock_count']}只)") - - else: - logger.warning("未获取到有效的行业表现数据") - - except Exception as e: - logger.error(f"分析行业表现失败: {e}") - - -def get_sector_top_stocks(fetcher: TushareFetcher, industry_code: str, industry_name: str, limit: int = 5): - """获取指定板块的强势个股""" - try: - logger.info(f"获取 {industry_name} 板块的强势个股...") - - if not fetcher.pro: - return - - # 获取该行业的成分股 - try: - constituents = fetcher.pro.index_member(index_code=industry_code) - if constituents.empty: - logger.warning(f"{industry_name} 行业无成分股数据") - return - - stock_codes = constituents['con_code'].tolist()[:20] # 取前20只股票测试 - logger.info(f"{industry_name} 行业共 {len(constituents)} 只成分股,分析前 {len(stock_codes)} 只") - - except Exception as e: - logger.error(f"获取 {industry_name} 成分股失败: {e}") - return - - # 获取本周日期 - start_date, end_date = get_this_week_dates() - - # 分析各股票本周表现 - stock_performance = [] - - for stock_code in stock_codes[:10]: # 限制分析数量 - try: - # 获取个股本周数据 - stock_data = fetcher.pro.daily( - ts_code=stock_code, - start_date=start_date, - end_date=end_date - ) - - if not stock_data.empty and len(stock_data) >= 2: - latest_close = stock_data.iloc[0]['close'] - week_start_close = stock_data.iloc[-1]['close'] - week_change = (latest_close - week_start_close) / week_start_close * 100 - - # 获取股票名称 - stock_name = fetcher.get_stock_name(stock_code.split('.')[0]) - - stock_performance.append({ - 'stock_code': stock_code, - 'stock_name': stock_name, - 'week_change_pct': week_change, - 'latest_close': latest_close - }) - - except Exception as e: - logger.debug(f"分析个股 {stock_code} 失败: {e}") - continue - - # 输出该板块强势个股 - if stock_performance: - df_stocks = pd.DataFrame(stock_performance) - df_stocks = df_stocks.sort_values('week_change_pct', ascending=False) - - logger.info(f"\n📊 {industry_name} 板块强势个股 TOP{limit}:") - for _, stock in df_stocks.head(limit).iterrows(): - logger.info(f" {stock['stock_name']} ({stock['stock_code']}): {stock['week_change_pct']:+.2f}%") - - except Exception as e: - logger.error(f"获取 {industry_name} 板块个股失败: {e}") - - -def main(): - """主函数""" - logger.info("开始测试Tushare获取本周强势板块...") - - # 初始化Tushare数据获取器 - token = "0ed6419a00d8923dc19c0b58fc92d94c9a0696949ab91a13aa58a0cc" - fetcher = TushareFetcher(token=token) - - # 分析板块表现 - analyze_sector_performance(fetcher) - - # 分析某个强势板块的个股(示例) - # get_sector_top_stocks(fetcher, "801010.SI", "农林牧渔") - - logger.info("测试完成!") - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/web/mysql_app.py b/web/mysql_app.py index bf3e549..363957a 100644 --- a/web/mysql_app.py +++ b/web/mysql_app.py @@ -29,28 +29,9 @@ config_loader = ConfigLoader() @app.route('/') def index(): - """首页 - 显示信号概览和统计""" - try: - # 获取策略统计 - strategy_stats = db_manager.get_strategy_stats() - - # 获取最新信号(前10条) - signals_df = db_manager.get_latest_signals(limit=10) - - # 获取回踩提醒(前5条) - pullback_alerts = db_manager.get_pullback_alerts(days=7) - - # 当前时间 - current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S') - - return render_template('index.html', - strategy_stats=strategy_stats.to_dict('records') if not strategy_stats.empty else [], - signals=signals_df.to_dict('records') if not signals_df.empty else [], - pullback_alerts=pullback_alerts.to_dict('records') if not pullback_alerts.empty else [], - current_time=current_time) - except Exception as e: - logger.error(f"首页数据加载失败: {e}") - return render_template('error.html', error=str(e)) + """首页 - 重定向到信号页面""" + from flask import redirect, url_for + return redirect(url_for('signals')) @app.route('/signals') @@ -60,7 +41,7 @@ def signals(): # 获取查询参数 strategy_name = request.args.get('strategy', '') timeframe = request.args.get('timeframe', '') - days = int(request.args.get('days', 30)) + days = int(request.args.get('days', 7)) # 默认显示7天内的信号 page = int(request.args.get('page', 1)) per_page = int(request.args.get('per_page', 20)) @@ -154,20 +135,6 @@ def signals(): return render_template('error.html', error=str(e)) -@app.route('/pullbacks') -def pullbacks(): - """回踩监控页面""" - try: - days = int(request.args.get('days', 30)) - pullback_alerts = db_manager.get_pullback_alerts(days=days) - - return render_template('pullbacks.html', - pullback_alerts=pullback_alerts.to_dict('records') if not pullback_alerts.empty else [], - days=days) - except Exception as e: - logger.error(f"回踩监控页面数据加载失败: {e}") - return render_template('error.html', error=str(e)) - @app.route('/api/signals') def api_signals(): @@ -206,25 +173,11 @@ def api_stats(): return jsonify({'success': False, 'error': str(e)}) -@app.route('/api/pullbacks') -def api_pullbacks(): - """API接口 - 获取回踩提醒""" - try: - days = int(request.args.get('days', 7)) - pullback_alerts = db_manager.get_pullback_alerts(days=days) - - return jsonify({ - 'success': True, - 'data': pullback_alerts.to_dict('records') if not pullback_alerts.empty else [] - }) - except Exception as e: - logger.error(f"API获取回踩提醒失败: {e}") - return jsonify({'success': False, 'error': str(e)}) @app.template_filter('datetime_format') def datetime_format(value, format='%Y-%m-%d %H:%M'): - """日期时间格式化过滤器 - 转换为东八区时间""" + """日期时间格式化过滤器 - 智能处理时区转换""" if value is None: return '' @@ -245,10 +198,17 @@ def datetime_format(value, format='%Y-%m-%d %H:%M'): if isinstance(value, datetime) and value.tzinfo is None: value = value.replace(tzinfo=timezone.utc) - # 转换为东八区时间 (UTC+8) + # 智能时区转换:检查是否已经是东八区时间 if isinstance(value, datetime) and value.tzinfo is not None: china_tz = timezone(timedelta(hours=8)) - value = value.astimezone(china_tz) + + # 如果已经是东八区时间,直接使用;否则转换 + if value.utcoffset() == timedelta(hours=8): + # 已经是东八区时间,无需转换 + pass + else: + # 转换为东八区时间 + value = value.astimezone(china_tz) return value.strftime(format) @@ -284,4 +244,4 @@ if __name__ == '__main__': print(f"📋 数据库: {db_manager.config.database}") print("=" * 60) - app.run(host='0.0.0.0', port=8080, debug=True) \ No newline at end of file + app.run(host='0.0.0.0', port=8081, debug=True) \ No newline at end of file