大版本更新
This commit is contained in:
parent
665fa4782e
commit
8320cb0d69
344
README.md
344
README.md
@ -1,328 +1,56 @@
|
||||
# A股AI分析Agent系统
|
||||
# Crypto Agent
|
||||
|
||||
基于AI Agent的股票智能分析系统,提供自然语言对话界面,支持实时行情查询、技术分析、基本面分析等功能。
|
||||
聚焦加密货币合约交易的智能交易系统,包含:
|
||||
|
||||
## 功能特性
|
||||
|
||||
- **自然语言对话**:通过对话方式查询股票信息
|
||||
- **实时行情查询**:获取股票实时价格、涨跌幅等数据
|
||||
- **技术指标分析**:计算MA、MACD、RSI、KDJ、BOLL等技术指标
|
||||
- **基本面信息**:查询公司概况、行业、上市日期等
|
||||
- **数据可视化**:生成专业的K线图和技术指标图表
|
||||
- **技能插件系统**:可扩展的技能架构,支持动态启用/禁用
|
||||
- **对话历史**:保存和查看历史分析记录
|
||||
|
||||
## 技术栈
|
||||
|
||||
### 后端
|
||||
- **框架**:FastAPI
|
||||
- **AI Agent**:LangChain + 智谱AI GLM-4
|
||||
- **数据源**:Tushare
|
||||
- **缓存**:内存缓存(无需Redis)
|
||||
- **数据库**:SQLite
|
||||
- **语言**:Python 3.11+ (推荐 3.11 或 3.12)
|
||||
|
||||
### 前端
|
||||
- **框架**:Vue 3 (CDN版本)
|
||||
- **UI**:Bootstrap 5
|
||||
- **图表**:Lightweight Charts
|
||||
- **通信**:Fetch API
|
||||
- 行情采集与特征分析
|
||||
- LLM 信号分析与分流
|
||||
- 模拟盘执行
|
||||
- Bitget U 本位合约执行
|
||||
- 风控、停机保护、执行总控
|
||||
- Web 总控台、交易页、信号页
|
||||
|
||||
## 项目结构
|
||||
|
||||
```
|
||||
Stock_Agent/
|
||||
├── backend/ # 后端代码
|
||||
│ ├── app/
|
||||
│ │ ├── agent/ # AI Agent核心
|
||||
│ │ ├── api/ # API路由
|
||||
│ │ ├── models/ # 数据模型
|
||||
│ │ ├── services/ # 数据服务
|
||||
│ │ ├── skills/ # 技能插件
|
||||
│ │ ├── utils/ # 工具函数
|
||||
│ │ ├── config.py # 配置管理
|
||||
│ │ └── main.py # 应用入口
|
||||
│ └── requirements.txt # Python依赖
|
||||
├── frontend/ # 前端代码
|
||||
│ ├── css/ # 样式文件
|
||||
│ ├── js/ # JavaScript文件
|
||||
│ └── index.html # 主页面
|
||||
├── .env.example # 环境变量示例
|
||||
├── .gitignore
|
||||
└── README.md
|
||||
```
|
||||
- `backend/`: FastAPI 后端与交易执行逻辑
|
||||
- `frontend/`: 控制台、交易页、信号页等静态页面
|
||||
|
||||
## 快速开始
|
||||
## 运行方式
|
||||
|
||||
### ⚠️ 重要提示:Python 版本
|
||||
|
||||
**推荐使用 Python 3.11 或 3.12**。如果您使用 Python 3.13,可能会遇到依赖安装问题。
|
||||
|
||||
详细的安装问题解决方案请查看:[安装指南](docs/INSTALL_GUIDE.md)
|
||||
|
||||
### 1. 环境准备
|
||||
|
||||
**系统要求**:
|
||||
- Python 3.11 或 3.12(推荐)
|
||||
- 无需 Redis(使用内存缓存)
|
||||
|
||||
**获取API密钥**:
|
||||
- [Tushare](https://tushare.pro/):注册并获取Token
|
||||
- [智谱AI](https://open.bigmodel.cn/):注册并获取API Key
|
||||
|
||||
### 2. 安装依赖
|
||||
项目根目录:
|
||||
|
||||
```bash
|
||||
./start.sh
|
||||
```
|
||||
|
||||
或后端目录:
|
||||
|
||||
```bash
|
||||
# 进入后端目录
|
||||
cd backend
|
||||
|
||||
# 创建虚拟环境(使用 Python 3.11)
|
||||
python3.11 -m venv venv
|
||||
|
||||
# 激活虚拟环境
|
||||
# Windows:
|
||||
venv\Scripts\activate
|
||||
# macOS/Linux:
|
||||
source venv/bin/activate
|
||||
|
||||
# 安装依赖
|
||||
pip install --upgrade pip
|
||||
pip install -r requirements.txt
|
||||
./start.sh
|
||||
```
|
||||
|
||||
**如果遇到安装错误**,请查看 [安装指南](docs/INSTALL_GUIDE.md) 获取详细解决方案。
|
||||
## 核心页面
|
||||
|
||||
### 3. 配置环境变量
|
||||
- `/console`: 总控台
|
||||
- `/trading`: 模拟盘交易页
|
||||
- `/bitget-trading`: Bitget 实盘页
|
||||
- `/signals`: 信号页
|
||||
- `/docs`: FastAPI 文档
|
||||
|
||||
复制 `.env.example` 为 `.env` 并填写配置:
|
||||
## 环境变量
|
||||
|
||||
```bash
|
||||
cp .env.example .env
|
||||
```
|
||||
|
||||
编辑 `.env` 文件:
|
||||
至少建议配置:
|
||||
|
||||
```env
|
||||
# Tushare API
|
||||
TUSHARE_TOKEN=your_tushare_token_here
|
||||
|
||||
# 智谱AI GLM-4 API
|
||||
ZHIPUAI_API_KEY=your_zhipuai_key_here
|
||||
|
||||
# 其他配置保持默认即可
|
||||
REDIS_HOST=localhost
|
||||
REDIS_PORT=6379
|
||||
REDIS_DB=0
|
||||
|
||||
# 其他配置保持默认即可
|
||||
DEEPSEEK_API_KEY=
|
||||
ZHIPUAI_API_KEY=
|
||||
BITGET_API_KEY=
|
||||
BITGET_API_SECRET=
|
||||
BITGET_PASSPHRASE=
|
||||
BITGET_TRADING_ENABLED=false
|
||||
```
|
||||
|
||||
### 4. 启动Redis(可选)
|
||||
## 说明
|
||||
|
||||
如果要使用缓存功能,请先启动Redis:
|
||||
|
||||
```bash
|
||||
# macOS (使用Homebrew)
|
||||
brew services start redis
|
||||
|
||||
# Linux
|
||||
sudo systemctl start redis
|
||||
|
||||
# Windows
|
||||
# 下载并运行Redis for Windows
|
||||
```
|
||||
|
||||
### 5. 启动后端服务
|
||||
|
||||
```bash
|
||||
# 在backend目录下
|
||||
cd backend
|
||||
|
||||
# 启动服务
|
||||
python -m app.main
|
||||
|
||||
# 或使用uvicorn
|
||||
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
服务启动后,访问:
|
||||
- 前端界面:http://localhost:8000
|
||||
- API文档:http://localhost:8000/docs
|
||||
|
||||
## 使用指南
|
||||
|
||||
### 基本查询
|
||||
|
||||
1. **查询实时行情**
|
||||
```
|
||||
查询600519的实时行情
|
||||
贵州茅台的价格
|
||||
000001现在多少钱
|
||||
```
|
||||
|
||||
2. **查看K线图**
|
||||
```
|
||||
600519的K线图
|
||||
贵州茅台的走势
|
||||
```
|
||||
|
||||
3. **技术指标分析**
|
||||
```
|
||||
600519的技术指标
|
||||
分析贵州茅台的MACD
|
||||
```
|
||||
|
||||
4. **基本面信息**
|
||||
```
|
||||
600519的基本信息
|
||||
贵州茅台是什么行业
|
||||
```
|
||||
|
||||
### 技能管理
|
||||
|
||||
点击右上角"技能管理"按钮,可以:
|
||||
- 查看所有可用技能
|
||||
- 启用/禁用特定技能
|
||||
- 查看技能描述
|
||||
|
||||
## API文档
|
||||
|
||||
启动服务后,访问 http://localhost:8000/docs 查看完整的API文档。
|
||||
|
||||
### 主要接口
|
||||
|
||||
#### 1. 发送消息
|
||||
```http
|
||||
POST /api/chat/message
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"message": "查询600519的实时行情",
|
||||
"session_id": "optional_session_id"
|
||||
}
|
||||
```
|
||||
|
||||
#### 2. 获取对话历史
|
||||
```http
|
||||
GET /api/chat/history/{session_id}
|
||||
```
|
||||
|
||||
#### 3. 获取股票行情
|
||||
```http
|
||||
GET /api/stock/quote/{stock_code}
|
||||
```
|
||||
|
||||
#### 4. 获取K线数据
|
||||
```http
|
||||
GET /api/stock/kline/{stock_code}?start_date=20240101&end_date=20240201
|
||||
```
|
||||
|
||||
#### 5. 获取技能列表
|
||||
```http
|
||||
GET /api/skills/
|
||||
```
|
||||
|
||||
## 开发指南
|
||||
|
||||
### 添加新技能
|
||||
|
||||
1. 在 `backend/app/skills/` 目录下创建新的技能文件
|
||||
2. 继承 `BaseSkill` 类并实现 `execute` 方法
|
||||
3. 在 `backend/app/agent/core.py` 中注册新技能
|
||||
|
||||
示例:
|
||||
|
||||
```python
|
||||
from app.skills.base import BaseSkill, SkillParameter
|
||||
|
||||
class MyNewSkill(BaseSkill):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.name = "my_skill"
|
||||
self.description = "我的新技能"
|
||||
self.parameters = [
|
||||
SkillParameter(
|
||||
name="param1",
|
||||
type="string",
|
||||
description="参数1",
|
||||
required=True
|
||||
)
|
||||
]
|
||||
|
||||
async def execute(self, **kwargs) -> Dict[str, Any]:
|
||||
# 实现技能逻辑
|
||||
return {"result": "success"}
|
||||
```
|
||||
|
||||
### 扩展数据源
|
||||
|
||||
在 `backend/app/services/` 目录下添加新的数据服务类,参考 `tushare_service.py` 的实现。
|
||||
|
||||
## 常见问题
|
||||
|
||||
### 1. Redis连接失败
|
||||
|
||||
如果Redis未安装或未启动,系统会自动降级,不影响核心功能,但会失去缓存能力。
|
||||
|
||||
### 2. Tushare API限制
|
||||
|
||||
免费版Tushare有调用频率限制(120次/分钟)。如果遇到限制,可以:
|
||||
- 等待一段时间后重试
|
||||
- 考虑升级到付费版
|
||||
- 使用Redis缓存减少API调用
|
||||
|
||||
### 3. 股票代码格式
|
||||
|
||||
支持的股票代码格式:
|
||||
- 6位数字:600000、000001
|
||||
- 带后缀:600000.SH、000001.SZ
|
||||
- 股票名称:贵州茅台、中国平安
|
||||
|
||||
### 4. 端口被占用
|
||||
|
||||
如果8000端口被占用,可以修改 `.env` 文件中的 `API_PORT` 配置。
|
||||
|
||||
## 性能优化
|
||||
|
||||
1. **启用Redis缓存**:显著减少API调用和响应时间
|
||||
2. **调整缓存TTL**:在 `cache_service.py` 中根据需求调整缓存时间
|
||||
3. **限制历史消息数**:在 `context.py` 中调整 `max_history` 参数
|
||||
|
||||
## 安全建议
|
||||
|
||||
1. **生产环境**:
|
||||
- 修改 `.env` 中的 `SECRET_KEY`
|
||||
- 设置 `DEBUG=False`
|
||||
- 配置严格的CORS策略
|
||||
- 使用HTTPS
|
||||
|
||||
2. **API密钥**:
|
||||
- 不要将 `.env` 文件提交到版本控制
|
||||
- 定期更换API密钥
|
||||
- 使用环境变量或密钥管理服务
|
||||
|
||||
## 贡献指南
|
||||
|
||||
欢迎贡献代码!请遵循以下步骤:
|
||||
|
||||
1. Fork本项目
|
||||
2. 创建特性分支 (`git checkout -b feature/AmazingFeature`)
|
||||
3. 提交更改 (`git commit -m 'Add some AmazingFeature'`)
|
||||
4. 推送到分支 (`git push origin feature/AmazingFeature`)
|
||||
5. 开启Pull Request
|
||||
|
||||
## 许可证
|
||||
|
||||
本项目采用 MIT 许可证。
|
||||
|
||||
## 联系方式
|
||||
|
||||
如有问题或建议,请提交Issue。
|
||||
|
||||
## 致谢
|
||||
|
||||
- [Tushare](https://tushare.pro/) - 金融数据接口
|
||||
- [智谱AI](https://open.bigmodel.cn/) - AI模型服务
|
||||
- [FastAPI](https://fastapi.tiangolo.com/) - Web框架
|
||||
- [LangChain](https://www.langchain.com/) - AI Agent框架
|
||||
- [Lightweight Charts](https://tradingview.github.io/lightweight-charts/) - 图表库
|
||||
- 当前项目已清理股票、A 股、旧聊天智能体相关运行链。
|
||||
- 数据库 schema 暂未迁移,保留现有兼容性配置。
|
||||
|
||||
@ -1,197 +0,0 @@
|
||||
"""
|
||||
上下文管理器
|
||||
管理对话历史和上下文
|
||||
"""
|
||||
from typing import List, Dict, Optional
|
||||
from app.services.db_service import db_service
|
||||
from app.utils.logger import logger
|
||||
|
||||
|
||||
class ContextManager:
|
||||
"""上下文管理器"""
|
||||
|
||||
def __init__(self, max_history: int = 10):
|
||||
"""
|
||||
初始化上下文管理器
|
||||
|
||||
Args:
|
||||
max_history: 最大历史消息数
|
||||
"""
|
||||
self.max_history = max_history
|
||||
|
||||
def get_context(self, session_id: str) -> List[Dict[str, str]]:
|
||||
"""
|
||||
获取对话上下文
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
消息列表
|
||||
"""
|
||||
messages = db_service.get_conversation_history(session_id, limit=self.max_history)
|
||||
|
||||
context = []
|
||||
for msg in messages:
|
||||
context.append({
|
||||
"role": msg.role,
|
||||
"content": msg.content,
|
||||
"metadata": msg.metadata if hasattr(msg, 'metadata') else {}
|
||||
})
|
||||
|
||||
return context
|
||||
|
||||
def add_message(
|
||||
self,
|
||||
session_id: str,
|
||||
role: str,
|
||||
content: str,
|
||||
metadata: Optional[dict] = None,
|
||||
user_id: Optional[int] = None
|
||||
):
|
||||
"""
|
||||
添加消息到上下文
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
role: 角色(user/assistant)
|
||||
content: 消息内容
|
||||
metadata: 元数据
|
||||
user_id: 用户ID(创建新对话时需要)
|
||||
"""
|
||||
db_service.add_message(session_id, role, content, metadata, user_id)
|
||||
logger.info(f"添加消息到上下文: {session_id}, {role}")
|
||||
|
||||
def clear_context(self, session_id: str):
|
||||
"""
|
||||
清除上下文(暂不实现删除,保留历史)
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
"""
|
||||
logger.info(f"清除上下文请求: {session_id}")
|
||||
# 实际不删除,只是标记
|
||||
pass
|
||||
|
||||
def format_context_for_llm(self, session_id: str) -> str:
|
||||
"""
|
||||
格式化上下文供LLM使用
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
格式化的上下文字符串
|
||||
"""
|
||||
context = self.get_context(session_id)
|
||||
|
||||
if not context:
|
||||
return ""
|
||||
|
||||
formatted = []
|
||||
for msg in context:
|
||||
role = "用户" if msg["role"] == "user" else "助手"
|
||||
formatted.append(f"{role}: {msg['content']}")
|
||||
|
||||
return "\n".join(formatted)
|
||||
|
||||
def extract_context_info(self, session_id: str) -> Dict:
|
||||
"""
|
||||
提取上下文信息
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
ContextInfo: {
|
||||
'last_stock': str | None, # 上次讨论的股票
|
||||
'last_topic': str | None, # 上次的话题
|
||||
'user_preferences': dict # 用户偏好
|
||||
}
|
||||
"""
|
||||
history = self.get_context(session_id)
|
||||
|
||||
return {
|
||||
'last_stock': self._extract_last_stock(history),
|
||||
'last_topic': self._extract_last_topic(history),
|
||||
'user_preferences': self._analyze_user_preferences(history)
|
||||
}
|
||||
|
||||
def _extract_last_stock(self, history: List[Dict]) -> Optional[str]:
|
||||
"""
|
||||
从历史对话中提取最后讨论的股票
|
||||
|
||||
Args:
|
||||
history: 对话历史
|
||||
|
||||
Returns:
|
||||
股票代码或None
|
||||
"""
|
||||
# 从后往前查找
|
||||
for msg in reversed(history):
|
||||
if msg['role'] == 'assistant':
|
||||
metadata = msg.get('metadata', {})
|
||||
if isinstance(metadata, dict):
|
||||
# 尝试从不同位置提取股票代码
|
||||
if 'data' in metadata:
|
||||
data = metadata['data']
|
||||
if isinstance(data, dict):
|
||||
if 'stock_code' in data:
|
||||
return data['stock_code']
|
||||
if 'ts_code' in data:
|
||||
return data['ts_code']
|
||||
|
||||
# 尝试从intent中提取
|
||||
if 'intent' in metadata:
|
||||
intent = metadata['intent']
|
||||
if isinstance(intent, dict) and 'target' in intent:
|
||||
target = intent['target']
|
||||
if isinstance(target, dict) and 'stock_code' in target:
|
||||
return target['stock_code']
|
||||
|
||||
return None
|
||||
|
||||
def _extract_last_topic(self, history: List[Dict]) -> Optional[str]:
|
||||
"""
|
||||
从历史对话中提取最后的话题
|
||||
|
||||
Args:
|
||||
history: 对话历史
|
||||
|
||||
Returns:
|
||||
话题或None
|
||||
"""
|
||||
if not history:
|
||||
return None
|
||||
|
||||
# 获取最后一条用户消息
|
||||
for msg in reversed(history):
|
||||
if msg['role'] == 'user':
|
||||
content = msg['content']
|
||||
# 简单提取话题(前50个字符)
|
||||
return content[:50] if len(content) > 50 else content
|
||||
|
||||
return None
|
||||
|
||||
def _analyze_user_preferences(self, history: List[Dict]) -> Dict:
|
||||
"""
|
||||
分析用户偏好
|
||||
|
||||
Args:
|
||||
history: 对话历史
|
||||
|
||||
Returns:
|
||||
用户偏好字典
|
||||
"""
|
||||
preferences = {
|
||||
'preferred_style': 'casual',
|
||||
'typical_time_scope': 'short_term',
|
||||
'frequent_dimensions': []
|
||||
}
|
||||
|
||||
# 简单的偏好分析(可以后续扩展)
|
||||
if len(history) > 5:
|
||||
# 如果对话较多,可能是专业用户
|
||||
preferences['preferred_style'] = 'professional'
|
||||
|
||||
return preferences
|
||||
@ -1,340 +0,0 @@
|
||||
"""
|
||||
问题分析器 - 使用LLM深度理解用户意图
|
||||
"""
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Dict, Any, Optional, List
|
||||
from app.services.llm_service import llm_service
|
||||
from app.utils.logger import logger
|
||||
|
||||
|
||||
class QuestionAnalyzer:
|
||||
"""智能问题分析器 - 使用LLM深度理解用户意图"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化问题分析器"""
|
||||
self.use_llm = llm_service.client is not None
|
||||
if not self.use_llm:
|
||||
logger.warning("LLM未配置,QuestionAnalyzer将使用降级模式")
|
||||
|
||||
async def analyze_question(
|
||||
self,
|
||||
question: str,
|
||||
context: List[Dict],
|
||||
session_id: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
深度分析用户问题
|
||||
|
||||
Args:
|
||||
question: 用户问题
|
||||
context: 对话历史上下文
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
QuestionIntent: {
|
||||
'type': 'stock_analysis' | 'market_overview' | 'knowledge' | 'chat',
|
||||
'target': {
|
||||
'stock_code': str,
|
||||
'stock_name': str,
|
||||
'market': 'A股' | '美股'
|
||||
},
|
||||
'dimensions': {
|
||||
'price_trend': bool, # 价格走势
|
||||
'technical': bool, # 技术指标
|
||||
'fundamental': bool, # 基本面
|
||||
'valuation': bool, # 估值
|
||||
'money_flow': bool, # 资金流向
|
||||
'risk': bool # 风险分析
|
||||
},
|
||||
'time_scope': {
|
||||
'short_term': bool, # 短期(1-2周)
|
||||
'medium_term': bool, # 中期(1-3月)
|
||||
'long_term': bool # 长期(半年+)
|
||||
},
|
||||
'analysis_depth': 'quick' | 'standard' | 'deep',
|
||||
'specific_concerns': List[str], # 特定关注点
|
||||
'context_references': {
|
||||
'refers_to_previous': bool,
|
||||
'comparison_target': str | None
|
||||
},
|
||||
'user_style': {
|
||||
'tone': 'professional' | 'casual',
|
||||
'detail_level': 'brief' | 'detailed'
|
||||
}
|
||||
}
|
||||
"""
|
||||
if not self.use_llm:
|
||||
# 降级模式:返回基本的意图分析
|
||||
return self._fallback_analysis(question)
|
||||
|
||||
# 构建上下文字符串
|
||||
context_str = self._format_context(context)
|
||||
|
||||
# 构建LLM prompt
|
||||
prompt = self._build_analysis_prompt(question, context_str)
|
||||
|
||||
try:
|
||||
# 异步调用LLM
|
||||
result = await self._call_llm_async(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=0.3,
|
||||
max_tokens=800
|
||||
)
|
||||
|
||||
if not result:
|
||||
logger.warning("LLM返回空结果,使用降级模式")
|
||||
return self._fallback_analysis(question)
|
||||
|
||||
# 清理和解析JSON
|
||||
intent = self._parse_llm_response(result)
|
||||
|
||||
if intent:
|
||||
logger.info(f"问题分析成功: type={intent.get('type')}, dimensions={intent.get('dimensions')}")
|
||||
return intent
|
||||
else:
|
||||
logger.warning("JSON解析失败,使用降级模式")
|
||||
return self._fallback_analysis(question)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"问题分析失败: {e}")
|
||||
return self._fallback_analysis(question)
|
||||
|
||||
async def _call_llm_async(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
temperature: float = 0.3,
|
||||
max_tokens: int = 800
|
||||
) -> Optional[str]:
|
||||
"""异步调用LLM"""
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
None,
|
||||
lambda: llm_service.chat(messages, temperature, max_tokens)
|
||||
)
|
||||
|
||||
def _format_context(self, context: List[Dict]) -> str:
|
||||
"""格式化对话历史上下文"""
|
||||
if not context:
|
||||
return ""
|
||||
|
||||
context_str = "\n\n【对话历史】\n"
|
||||
# 只取最近4条消息
|
||||
for msg in context[-4:]:
|
||||
role = "用户" if msg["role"] == "user" else "助手"
|
||||
content = msg['content'][:100] # 限制长度
|
||||
context_str += f"{role}: {content}\n"
|
||||
|
||||
return context_str
|
||||
|
||||
def _build_analysis_prompt(self, question: str, context_str: str) -> str:
|
||||
"""构建问题分析的LLM prompt"""
|
||||
prompt = f"""你是一个专业的金融问题分析专家。请深度分析用户的问题,提取结构化信息。
|
||||
|
||||
{context_str}
|
||||
|
||||
【当前问题】
|
||||
用户: {question}
|
||||
|
||||
请分析以下维度:
|
||||
|
||||
1. **问题类型**
|
||||
- stock_analysis: 针对**特定单只股票**的分析(如"贵州茅台怎么样"、"分析比亚迪"、"AAPL走势"、"阿里巴巴美股")
|
||||
**注意**:如果用户问的是"板块"、"行业"、"概念股"等,这不是stock_analysis,而是market_overview
|
||||
- market_overview: 市场整体分析、行业板块分析、投资机会(如"最近有什么投资机会"、"商业航天板块怎么样"、"新能源行业走势"、"现在适合买股票吗")
|
||||
- knowledge: 金融知识问答(如"什么是MACD"、"如何看K线图")
|
||||
- chat: 一般对话(如"你好"、"在吗")
|
||||
|
||||
**重要**:判断是stock_analysis还是market_overview的关键:
|
||||
- 如果提到具体的公司名称或股票代码 → stock_analysis
|
||||
- 如果提到"板块"、"行业"、"概念"、"赛道"、"领域" → market_overview
|
||||
- 如果问"哪些股票"、"什么机会" → market_overview
|
||||
|
||||
2. **股票识别**(如果是stock_analysis,这是最重要的部分)
|
||||
请识别用户提到的股票,并返回准确的股票代码:
|
||||
|
||||
**重要**:如果用户提到多只股票(如"TSLA和NVDA"、"特斯拉、英伟达"),请返回所有股票代码的列表。
|
||||
|
||||
**A股代码格式**:6位数字
|
||||
- 上海主板:600xxx、601xxx、603xxx、605xxx
|
||||
- 深圳主板:000xxx、001xxx
|
||||
- 创业板:300xxx、301xxx
|
||||
- 科创板:688xxx
|
||||
- 常见示例:贵州茅台→600519,比亚迪→002594,宁德时代→300750
|
||||
|
||||
**美股代码格式**:1-5位大写字母
|
||||
- 常见示例:苹果→AAPL,特斯拉→TSLA,微软→MSFT,谷歌→GOOGL,英伟达→NVDA
|
||||
- 中概股美股:阿里巴巴美股→BABA,京东美股→JD,拼多多→PDD,百度美股→BIDU,网易美股→NTES,哔哩哔哩美股→BILI
|
||||
|
||||
**港股代码格式**:4-5位数字加.HK后缀
|
||||
- 常见示例:腾讯→0700.HK,阿里巴巴港股→9988.HK,美团→3690.HK,小米→1810.HK,京东港股→9618.HK,百度港股→9888.HK,网易港股→9999.HK,哔哩哔哩港股→9626.HK
|
||||
- 注意:港股代码需要包含.HK后缀
|
||||
|
||||
**市场判断**:
|
||||
- 如果用户明确说"美股"、"纳斯达克"、"纽交所" → 美股
|
||||
- 如果用户明确说"港股"、"香港"、"恒生" → 港股
|
||||
- 对于同时在多地上市的公司(如阿里巴巴、京东、百度等):
|
||||
- 用户说"美股"或没有明确指定 → 返回美股代码(如BABA)
|
||||
- 用户说"港股" → 返回港股代码(如9988.HK)
|
||||
- 纯港股公司(如腾讯、美团、小米)→ 港股
|
||||
- 默认情况下,中国公司优先考虑A股市场
|
||||
|
||||
3. **用户关注维度**(如果是stock_analysis)
|
||||
分析用户想了解哪些方面:
|
||||
- price_trend: 价格走势、涨跌情况、最新价格
|
||||
- technical: 技术指标(MACD、RSI、均线、KDJ等)
|
||||
- fundamental: 基本面(公司业务、行业地位、财务状况)
|
||||
- valuation: 估值水平(PE、PB、市值、估值是否合理)
|
||||
- money_flow: 资金流向、主力动向、大单流入流出
|
||||
- risk: 风险分析、风险提示、投资风险
|
||||
|
||||
4. **时间范围**
|
||||
- short_term: 短期(1-2周)- 如"短期走势"、"近期表现"
|
||||
- medium_term: 中期(1-3月)- 如"中期趋势"、"未来一个月"
|
||||
- long_term: 长期(半年以上)- 如"长期投资"、"适合长期持有吗"
|
||||
|
||||
5. **分析深度**
|
||||
- quick: 快速查看(只需要基本信息,如"价格多少")
|
||||
- standard: 标准分析(常规分析,如"怎么样"、"分析一下")
|
||||
- deep: 深度分析(全面详细,如"全面分析"、"深度研究")
|
||||
|
||||
6. **特定关注点**
|
||||
提取用户明确提到的关注点,如:
|
||||
- "支撑位在哪"
|
||||
- "盈利能力如何"
|
||||
- "适合长期持有吗"
|
||||
- "有没有金叉"
|
||||
|
||||
7. **上下文引用**
|
||||
- 是否引用了之前的对话("这只股票"、"它"、"那技术面呢")
|
||||
- 是否要求对比分析("和上次相比"、"对比一下")
|
||||
|
||||
8. **用户风格**
|
||||
- tone: professional(专业,使用专业术语)/ casual(随意,通俗易懂)
|
||||
- detail_level: brief(简洁,简短回答)/ detailed(详细,详细分析)
|
||||
|
||||
请以JSON格式返回分析结果:
|
||||
{{
|
||||
"type": "问题类型",
|
||||
"target": {{
|
||||
"stock_code": "单只股票时为字符串(如'AAPL'),多只股票时为列表(如['TSLA', 'NVDA'])",
|
||||
"stock_name": "单只股票时为字符串(如'苹果'),多只股票时为列表(如['特斯拉', '英伟达'])",
|
||||
"market": "A股/美股/港股"
|
||||
}},
|
||||
"dimensions": {{
|
||||
"price_trend": true/false,
|
||||
"technical": true/false,
|
||||
"fundamental": true/false,
|
||||
"valuation": true/false,
|
||||
"money_flow": true/false,
|
||||
"risk": true/false
|
||||
}},
|
||||
"time_scope": {{
|
||||
"short_term": true/false,
|
||||
"medium_term": true/false,
|
||||
"long_term": true/false
|
||||
}},
|
||||
"analysis_depth": "quick/standard/deep",
|
||||
"specific_concerns": ["关注点1", "关注点2"],
|
||||
"context_references": {{
|
||||
"refers_to_previous": true/false,
|
||||
"comparison_target": "对比目标(如有)"
|
||||
}},
|
||||
"user_style": {{
|
||||
"tone": "professional/casual",
|
||||
"detail_level": "brief/detailed"
|
||||
}}
|
||||
}}
|
||||
|
||||
只返回JSON,不要有任何其他内容。"""
|
||||
|
||||
return prompt
|
||||
|
||||
def _parse_llm_response(self, response: str) -> Optional[Dict[str, Any]]:
|
||||
"""解析LLM返回的JSON响应"""
|
||||
try:
|
||||
# 清理结果,移除可能的markdown代码块标记
|
||||
result = response.strip()
|
||||
if result.startswith("```json"):
|
||||
result = result[7:]
|
||||
if result.startswith("```"):
|
||||
result = result[3:]
|
||||
if result.endswith("```"):
|
||||
result = result[:-3]
|
||||
result = result.strip()
|
||||
|
||||
# 检查是否为空
|
||||
if not result:
|
||||
return None
|
||||
|
||||
# 解析JSON
|
||||
intent = json.loads(result)
|
||||
return intent
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"JSON解析失败: {e}, 原始响应: {response[:200]}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"解析LLM响应失败: {e}")
|
||||
return None
|
||||
|
||||
def _fallback_analysis(self, question: str) -> Dict[str, Any]:
|
||||
"""降级模式:基于规则的简单分析"""
|
||||
question_lower = question.lower()
|
||||
|
||||
# 简单的关键词匹配
|
||||
is_stock_query = any(kw in question for kw in [
|
||||
"股票", "分析", "怎么样", "如何", "走势", "价格", "涨", "跌"
|
||||
])
|
||||
|
||||
if is_stock_query:
|
||||
# 尝试提取股票名称(简单规则)
|
||||
return {
|
||||
'type': 'stock_analysis',
|
||||
'target': {
|
||||
'stock_code': '',
|
||||
'stock_name': '',
|
||||
'market': 'A股'
|
||||
},
|
||||
'dimensions': {
|
||||
'price_trend': True,
|
||||
'technical': True,
|
||||
'fundamental': True,
|
||||
'valuation': False,
|
||||
'money_flow': False,
|
||||
'risk': False
|
||||
},
|
||||
'time_scope': {
|
||||
'short_term': True,
|
||||
'medium_term': True,
|
||||
'long_term': False
|
||||
},
|
||||
'analysis_depth': 'standard',
|
||||
'specific_concerns': [],
|
||||
'context_references': {
|
||||
'refers_to_previous': False,
|
||||
'comparison_target': None
|
||||
},
|
||||
'user_style': {
|
||||
'tone': 'casual',
|
||||
'detail_level': 'detailed'
|
||||
}
|
||||
}
|
||||
else:
|
||||
# 默认为一般对话
|
||||
return {
|
||||
'type': 'chat',
|
||||
'target': {},
|
||||
'dimensions': {},
|
||||
'time_scope': {},
|
||||
'analysis_depth': 'quick',
|
||||
'specific_concerns': [],
|
||||
'context_references': {
|
||||
'refers_to_previous': False,
|
||||
'comparison_target': None
|
||||
},
|
||||
'user_style': {
|
||||
'tone': 'casual',
|
||||
'detail_level': 'brief'
|
||||
}
|
||||
}
|
||||
@ -1,339 +0,0 @@
|
||||
"""
|
||||
技能管理器
|
||||
管理所有技能的注册、发现和调用
|
||||
"""
|
||||
import asyncio
|
||||
from typing import Dict, Optional, List, Type, Any
|
||||
from app.skills.base import BaseSkill
|
||||
from app.utils.logger import logger
|
||||
|
||||
|
||||
class SkillManager:
|
||||
"""技能管理器"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化技能管理器"""
|
||||
self._skills: Dict[str, BaseSkill] = {}
|
||||
logger.info("技能管理器初始化")
|
||||
|
||||
def register(self, skill: BaseSkill) -> bool:
|
||||
"""
|
||||
注册技能
|
||||
|
||||
Args:
|
||||
skill: 技能实例
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
if not skill.name:
|
||||
logger.error("技能名称不能为空")
|
||||
return False
|
||||
|
||||
if skill.name in self._skills:
|
||||
logger.warning(f"技能已存在,将被覆盖: {skill.name}")
|
||||
|
||||
self._skills[skill.name] = skill
|
||||
logger.info(f"技能注册成功: {skill.name}")
|
||||
return True
|
||||
|
||||
def unregister(self, skill_name: str) -> bool:
|
||||
"""
|
||||
注销技能
|
||||
|
||||
Args:
|
||||
skill_name: 技能名称
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
if skill_name in self._skills:
|
||||
del self._skills[skill_name]
|
||||
logger.info(f"技能注销成功: {skill_name}")
|
||||
return True
|
||||
|
||||
logger.warning(f"技能不存在: {skill_name}")
|
||||
return False
|
||||
|
||||
def get_skill(self, skill_name: str) -> Optional[BaseSkill]:
|
||||
"""
|
||||
获取技能
|
||||
|
||||
Args:
|
||||
skill_name: 技能名称
|
||||
|
||||
Returns:
|
||||
技能实例或None
|
||||
"""
|
||||
return self._skills.get(skill_name)
|
||||
|
||||
def get_all_skills(self) -> List[BaseSkill]:
|
||||
"""
|
||||
获取所有技能
|
||||
|
||||
Returns:
|
||||
技能列表
|
||||
"""
|
||||
return list(self._skills.values())
|
||||
|
||||
def get_enabled_skills(self) -> List[BaseSkill]:
|
||||
"""
|
||||
获取所有启用的技能
|
||||
|
||||
Returns:
|
||||
启用的技能列表
|
||||
"""
|
||||
return [skill for skill in self._skills.values() if skill.enabled]
|
||||
|
||||
async def execute_skill(self, skill_name: str, **kwargs) -> Dict:
|
||||
"""
|
||||
执行技能
|
||||
|
||||
Args:
|
||||
skill_name: 技能名称
|
||||
**kwargs: 技能参数
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
skill = self.get_skill(skill_name)
|
||||
|
||||
if not skill:
|
||||
logger.error(f"❌ 技能不存在: {skill_name}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"技能不存在: {skill_name}"
|
||||
}
|
||||
|
||||
if not skill.enabled:
|
||||
logger.warning(f"⚠️ 技能已禁用: {skill_name}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"技能已禁用: {skill_name}"
|
||||
}
|
||||
|
||||
# 验证参数
|
||||
valid, error = skill.validate_params(**kwargs)
|
||||
if not valid:
|
||||
logger.error(f"❌ 技能参数验证失败 {skill_name}: {error}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": error
|
||||
}
|
||||
|
||||
# 执行技能
|
||||
try:
|
||||
logger.info(f"🚀 开始执行技能: {skill_name}, 参数: {kwargs}")
|
||||
result = await skill.execute(**kwargs)
|
||||
logger.info(f"✅ 技能执行成功: {skill_name}")
|
||||
return {
|
||||
"success": True,
|
||||
"data": result
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 技能执行失败 {skill_name}: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
def enable_skill(self, skill_name: str) -> bool:
|
||||
"""
|
||||
启用技能
|
||||
|
||||
Args:
|
||||
skill_name: 技能名称
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
skill = self.get_skill(skill_name)
|
||||
if skill:
|
||||
skill.enable()
|
||||
logger.info(f"技能已启用: {skill_name}")
|
||||
return True
|
||||
return False
|
||||
|
||||
def disable_skill(self, skill_name: str) -> bool:
|
||||
"""
|
||||
禁用技能
|
||||
|
||||
Args:
|
||||
skill_name: 技能名称
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
skill = self.get_skill(skill_name)
|
||||
if skill:
|
||||
skill.disable()
|
||||
logger.info(f"技能已禁用: {skill_name}")
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_skills_info(self) -> List[Dict]:
|
||||
"""
|
||||
获取所有技能信息
|
||||
|
||||
Returns:
|
||||
技能信息列表
|
||||
"""
|
||||
return [skill.get_info() for skill in self._skills.values()]
|
||||
|
||||
async def execute_plan(
|
||||
self,
|
||||
plan: Dict[str, Any],
|
||||
stock_code: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
执行技能规划
|
||||
|
||||
Args:
|
||||
plan: 技能执行计划(来自SkillPlanner)
|
||||
stock_code: 股票代码
|
||||
|
||||
Returns:
|
||||
{
|
||||
'results': {
|
||||
'market_data': {...},
|
||||
'technical_analysis': {...},
|
||||
...
|
||||
},
|
||||
'execution_time': float,
|
||||
'errors': List[str]
|
||||
}
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
skills = plan.get('skills', [])
|
||||
strategy = plan.get('execution_strategy', 'parallel')
|
||||
|
||||
logger.info(f"开始执行技能规划: {len(skills)}个技能, 策略: {strategy}")
|
||||
|
||||
if strategy == 'parallel':
|
||||
results = await self._execute_parallel(skills, stock_code)
|
||||
else:
|
||||
results = await self._execute_sequential(skills, stock_code)
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
logger.info(f"技能规划执行完成,耗时: {execution_time:.2f}秒")
|
||||
|
||||
return {
|
||||
'results': results['results'],
|
||||
'execution_time': execution_time,
|
||||
'errors': results['errors']
|
||||
}
|
||||
|
||||
async def _execute_parallel(
|
||||
self,
|
||||
skills: List[Dict[str, Any]],
|
||||
stock_code: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
并行执行技能(按优先级分组)
|
||||
|
||||
Args:
|
||||
skills: 技能列表
|
||||
stock_code: 股票代码
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
# 按优先级分组
|
||||
priority_groups = {}
|
||||
for skill_info in skills:
|
||||
priority = skill_info['priority']
|
||||
if priority not in priority_groups:
|
||||
priority_groups[priority] = []
|
||||
priority_groups[priority].append(skill_info)
|
||||
|
||||
all_results = {}
|
||||
all_errors = []
|
||||
|
||||
# 按优先级顺序执行
|
||||
for priority in sorted(priority_groups.keys()):
|
||||
skill_group = priority_groups[priority]
|
||||
logger.info(f"📋 执行优先级 {priority} 的技能: {[s['name'] for s in skill_group]}")
|
||||
|
||||
# 同一优先级的技能并行执行
|
||||
tasks = []
|
||||
for skill_info in skill_group:
|
||||
params = skill_info['params'].copy()
|
||||
params['stock_code'] = stock_code
|
||||
logger.info(f" ➡️ 准备执行技能: {skill_info['name']}, 原因: {skill_info.get('reason', '未知')}")
|
||||
task = self.execute_skill(skill_info['name'], **params)
|
||||
tasks.append((skill_info['name'], task))
|
||||
|
||||
# 等待所有任务完成
|
||||
results = await asyncio.gather(*[task for _, task in tasks], return_exceptions=True)
|
||||
|
||||
# 处理结果
|
||||
for (skill_name, _), result in zip(tasks, results):
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"❌ 技能执行异常: {skill_name}, {result}")
|
||||
all_errors.append(f"{skill_name}: {str(result)}")
|
||||
all_results[skill_name] = {'error': str(result)}
|
||||
elif result.get('success'):
|
||||
# 不在这里记录成功日志,因为 execute_skill 已经记录了
|
||||
all_results[skill_name] = result.get('data', {})
|
||||
else:
|
||||
error_msg = result.get('error', '未知错误')
|
||||
logger.error(f"❌ 技能执行失败: {skill_name}, {error_msg}")
|
||||
all_errors.append(f"{skill_name}: {error_msg}")
|
||||
all_results[skill_name] = {'error': error_msg}
|
||||
|
||||
return {
|
||||
'results': all_results,
|
||||
'errors': all_errors
|
||||
}
|
||||
|
||||
async def _execute_sequential(
|
||||
self,
|
||||
skills: List[Dict[str, Any]],
|
||||
stock_code: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
串行执行技能
|
||||
|
||||
Args:
|
||||
skills: 技能列表
|
||||
stock_code: 股票代码
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
all_results = {}
|
||||
all_errors = []
|
||||
|
||||
for skill_info in skills:
|
||||
skill_name = skill_info['name']
|
||||
params = skill_info['params'].copy()
|
||||
params['stock_code'] = stock_code
|
||||
|
||||
logger.info(f"执行技能: {skill_name}")
|
||||
|
||||
try:
|
||||
result = await self.execute_skill(skill_name, **params)
|
||||
|
||||
if result.get('success'):
|
||||
all_results[skill_name] = result.get('data', {})
|
||||
else:
|
||||
error_msg = result.get('error', '未知错误')
|
||||
logger.error(f"技能执行失败: {skill_name}, {error_msg}")
|
||||
all_errors.append(f"{skill_name}: {error_msg}")
|
||||
all_results[skill_name] = {'error': error_msg}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"技能执行异常: {skill_name}, {e}")
|
||||
all_errors.append(f"{skill_name}: {str(e)}")
|
||||
all_results[skill_name] = {'error': str(e)}
|
||||
|
||||
return {
|
||||
'results': all_results,
|
||||
'errors': all_errors
|
||||
}
|
||||
|
||||
|
||||
# 创建全局技能管理器实例
|
||||
skill_manager = SkillManager()
|
||||
@ -1,396 +0,0 @@
|
||||
"""
|
||||
技能规划器 - 根据用户意图智能选择技能组合
|
||||
"""
|
||||
from typing import Dict, Any, List, Set
|
||||
from app.utils.logger import logger
|
||||
|
||||
|
||||
class SkillPlanner:
|
||||
"""智能技能规划器 - 根据问题意图动态选择技能"""
|
||||
|
||||
# A股维度到技能的映射
|
||||
A_STOCK_DIMENSION_SKILL_MAP = {
|
||||
'price_trend': {
|
||||
'required': ['market_data', 'brave_search'],
|
||||
'optional': []
|
||||
},
|
||||
'technical': {
|
||||
'required': ['market_data', 'technical_analysis', 'brave_search'],
|
||||
'optional': ['visualization']
|
||||
},
|
||||
'fundamental': {
|
||||
'required': ['fundamental', 'brave_search'],
|
||||
'optional': []
|
||||
},
|
||||
'valuation': {
|
||||
'required': ['advanced_data', 'brave_search'],
|
||||
'optional': []
|
||||
},
|
||||
'money_flow': {
|
||||
'required': ['advanced_data', 'brave_search'],
|
||||
'optional': []
|
||||
},
|
||||
'risk': {
|
||||
'required': ['technical_analysis', 'advanced_data', 'brave_search'],
|
||||
'optional': []
|
||||
},
|
||||
'news': {
|
||||
'required': ['brave_search'],
|
||||
'optional': []
|
||||
}
|
||||
}
|
||||
|
||||
# 美股/港股维度到技能的映射(使用 yfinance)
|
||||
INTL_STOCK_DIMENSION_SKILL_MAP = {
|
||||
'price_trend': {
|
||||
'required': ['us_stock_analysis', 'brave_search'],
|
||||
'optional': []
|
||||
},
|
||||
'technical': {
|
||||
'required': ['us_stock_analysis', 'brave_search'],
|
||||
'optional': []
|
||||
},
|
||||
'fundamental': {
|
||||
'required': ['us_stock_analysis', 'brave_search'],
|
||||
'optional': []
|
||||
},
|
||||
'valuation': {
|
||||
'required': ['us_stock_analysis', 'brave_search'],
|
||||
'optional': []
|
||||
},
|
||||
'money_flow': {
|
||||
'required': ['us_stock_analysis', 'brave_search'],
|
||||
'optional': []
|
||||
},
|
||||
'risk': {
|
||||
'required': ['us_stock_analysis', 'brave_search'],
|
||||
'optional': []
|
||||
},
|
||||
'news': {
|
||||
'required': ['brave_search'],
|
||||
'optional': []
|
||||
}
|
||||
}
|
||||
|
||||
# 技能依赖关系(仅 A 股)
|
||||
SKILL_DEPENDENCIES = {
|
||||
'technical_analysis': ['market_data'],
|
||||
'visualization': ['market_data'],
|
||||
}
|
||||
|
||||
# 技能优先级(数字越小优先级越高)
|
||||
SKILL_PRIORITY = {
|
||||
'market_data': 1,
|
||||
'fundamental': 1,
|
||||
'brave_search': 1,
|
||||
'us_stock_analysis': 1,
|
||||
'technical_analysis': 2,
|
||||
'advanced_data': 2,
|
||||
'visualization': 3,
|
||||
}
|
||||
|
||||
# 分析深度策略
|
||||
DEPTH_STRATEGY = {
|
||||
'quick': {
|
||||
'max_skills': 2,
|
||||
'include_optional': False,
|
||||
'use_cache': True
|
||||
},
|
||||
'standard': {
|
||||
'max_skills': 4,
|
||||
'include_optional': True,
|
||||
'use_cache': True
|
||||
},
|
||||
'deep': {
|
||||
'max_skills': None,
|
||||
'include_optional': True,
|
||||
'use_cache': False
|
||||
}
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
"""初始化技能规划器"""
|
||||
logger.info("技能规划器初始化")
|
||||
|
||||
def plan_skills(self, intent: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
根据意图规划技能执行
|
||||
|
||||
Args:
|
||||
intent: 问题意图(来自QuestionAnalyzer)
|
||||
|
||||
Returns:
|
||||
SkillExecutionPlan
|
||||
"""
|
||||
# 获取市场类型
|
||||
target = intent.get('target', {})
|
||||
market = target.get('market', 'A股')
|
||||
stock_code = target.get('stock_code', '')
|
||||
stock_name = target.get('stock_name', '')
|
||||
|
||||
# 根据市场类型选择不同的技能映射
|
||||
if market in ('美股', '港股'):
|
||||
return self._plan_intl_stock_skills(intent, market, stock_code, stock_name)
|
||||
else:
|
||||
return self._plan_a_stock_skills(intent)
|
||||
|
||||
def _plan_a_stock_skills(self, intent: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""规划 A 股技能"""
|
||||
# 1. 根据维度映射技能
|
||||
skills = self._map_dimensions_to_skills(
|
||||
intent.get('dimensions', {}),
|
||||
self.A_STOCK_DIMENSION_SKILL_MAP
|
||||
)
|
||||
|
||||
# 2. 根据分析深度调整
|
||||
depth = intent.get('analysis_depth', 'standard')
|
||||
skills = self._apply_depth_strategy(skills, depth)
|
||||
|
||||
# 3. 解析依赖关系
|
||||
skills = self._resolve_dependencies(skills)
|
||||
|
||||
# 4. 去重并排序
|
||||
skills = list(set(skills))
|
||||
sorted_skills = self._sort_by_priority(skills)
|
||||
|
||||
# 5. 构建执行计划
|
||||
plan = {
|
||||
'skills': [
|
||||
{
|
||||
'name': skill,
|
||||
'params': self._get_skill_params(skill, intent),
|
||||
'priority': self.SKILL_PRIORITY.get(skill, 5),
|
||||
'required': True,
|
||||
'reason': self._get_skill_reason(skill, intent)
|
||||
}
|
||||
for skill in sorted_skills
|
||||
],
|
||||
'execution_strategy': self._determine_strategy(sorted_skills),
|
||||
'cache_strategy': 'use' if self.DEPTH_STRATEGY[depth]['use_cache'] else 'bypass'
|
||||
}
|
||||
|
||||
logger.info(f"[A股] 技能规划完成: {[s['name'] for s in plan['skills']]}, 策略: {plan['execution_strategy']}")
|
||||
return plan
|
||||
|
||||
def _plan_intl_stock_skills(self, intent: Dict[str, Any], market: str, stock_code: str, stock_name: str) -> Dict[str, Any]:
|
||||
"""规划美股/港股技能"""
|
||||
# 1. 根据维度映射技能
|
||||
skills = self._map_dimensions_to_skills(
|
||||
intent.get('dimensions', {}),
|
||||
self.INTL_STOCK_DIMENSION_SKILL_MAP
|
||||
)
|
||||
|
||||
# 2. 确保至少有 us_stock_analysis
|
||||
if 'us_stock_analysis' not in skills:
|
||||
skills.append('us_stock_analysis')
|
||||
|
||||
# 3. 去重并排序
|
||||
skills = list(set(skills))
|
||||
sorted_skills = self._sort_by_priority(skills)
|
||||
|
||||
# 4. 构建执行计划
|
||||
depth = intent.get('analysis_depth', 'standard')
|
||||
plan = {
|
||||
'skills': [
|
||||
{
|
||||
'name': skill,
|
||||
'params': self._get_intl_skill_params(skill, stock_code, stock_name),
|
||||
'priority': self.SKILL_PRIORITY.get(skill, 5),
|
||||
'required': skill == 'us_stock_analysis',
|
||||
'reason': self._get_intl_skill_reason(skill, market)
|
||||
}
|
||||
for skill in sorted_skills
|
||||
],
|
||||
'execution_strategy': 'parallel',
|
||||
'cache_strategy': 'use' if self.DEPTH_STRATEGY[depth]['use_cache'] else 'bypass'
|
||||
}
|
||||
|
||||
logger.info(f"[{market}] 技能规划完成: {[s['name'] for s in plan['skills']]}, 策略: {plan['execution_strategy']}")
|
||||
return plan
|
||||
|
||||
def _get_intl_skill_params(self, skill_name: str, stock_code: str, stock_name: str) -> Dict[str, Any]:
|
||||
"""获取美股/港股技能参数"""
|
||||
if skill_name == 'us_stock_analysis':
|
||||
return {
|
||||
'symbol': stock_code,
|
||||
'analysis_type': 'comprehensive'
|
||||
}
|
||||
elif skill_name == 'brave_search':
|
||||
return {
|
||||
'query': f'{stock_name} 最新动态 财报',
|
||||
'search_type': 'news',
|
||||
'count': 5,
|
||||
'freshness': 'pw'
|
||||
}
|
||||
return {}
|
||||
|
||||
def _get_intl_skill_reason(self, skill_name: str, market: str) -> str:
|
||||
"""获取美股/港股技能调用原因"""
|
||||
if skill_name == 'us_stock_analysis':
|
||||
return f'获取{market}基础数据和技术指标'
|
||||
elif skill_name == 'brave_search':
|
||||
return '获取最新市场资讯和舆情'
|
||||
return '提供分析数据'
|
||||
|
||||
def _map_dimensions_to_skills(self, dimensions: Dict[str, bool], skill_map: Dict) -> List[str]:
|
||||
"""将用户关注维度映射到技能"""
|
||||
skills = []
|
||||
|
||||
for dimension, enabled in dimensions.items():
|
||||
if enabled and dimension in skill_map:
|
||||
mapping = skill_map[dimension]
|
||||
skills.extend(mapping['required'])
|
||||
skills.extend(mapping['optional'])
|
||||
|
||||
return skills
|
||||
|
||||
def _apply_depth_strategy(self, skills: List[str], depth: str) -> List[str]:
|
||||
"""根据分析深度调整技能列表"""
|
||||
strategy = self.DEPTH_STRATEGY.get(depth, self.DEPTH_STRATEGY['standard'])
|
||||
|
||||
# 如果有最大技能数限制
|
||||
if strategy['max_skills'] is not None and len(skills) > strategy['max_skills']:
|
||||
# 按优先级保留前N个
|
||||
sorted_skills = self._sort_by_priority(skills)
|
||||
skills = sorted_skills[:strategy['max_skills']]
|
||||
|
||||
return skills
|
||||
|
||||
def _resolve_dependencies(self, skills: List[str]) -> List[str]:
|
||||
"""解析技能依赖关系,自动添加依赖的技能"""
|
||||
resolved_skills = set(skills)
|
||||
|
||||
for skill in skills:
|
||||
if skill in self.SKILL_DEPENDENCIES:
|
||||
dependencies = self.SKILL_DEPENDENCIES[skill]
|
||||
resolved_skills.update(dependencies)
|
||||
|
||||
return list(resolved_skills)
|
||||
|
||||
def _sort_by_priority(self, skills: List[str]) -> List[str]:
|
||||
"""按优先级排序技能"""
|
||||
return sorted(skills, key=lambda s: self.SKILL_PRIORITY.get(s, 999))
|
||||
|
||||
def _determine_strategy(self, skills: List[str]) -> str:
|
||||
"""确定执行策略(并行/串行)"""
|
||||
# 如果技能数量少于等于3,使用并行
|
||||
if len(skills) <= 3:
|
||||
return 'parallel'
|
||||
else:
|
||||
# 技能较多时,按优先级分组并行
|
||||
return 'parallel'
|
||||
|
||||
def _get_skill_params(self, skill_name: str, intent: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""获取技能执行参数"""
|
||||
params = {}
|
||||
|
||||
if skill_name == 'market_data':
|
||||
params['data_type'] = 'quote'
|
||||
|
||||
elif skill_name == 'technical_analysis':
|
||||
# 根据用户关注点决定指标
|
||||
indicators = ['ma', 'macd']
|
||||
specific_concerns = intent.get('specific_concerns', [])
|
||||
|
||||
if any('rsi' in concern.lower() for concern in specific_concerns):
|
||||
indicators.append('rsi')
|
||||
if any('kdj' in concern.lower() for concern in specific_concerns):
|
||||
indicators.append('kdj')
|
||||
if any('布林' in concern or 'boll' in concern.lower() for concern in specific_concerns):
|
||||
indicators.append('boll')
|
||||
|
||||
params['indicators'] = indicators
|
||||
|
||||
elif skill_name == 'advanced_data':
|
||||
# 根据维度决定数据类型
|
||||
data_types = []
|
||||
dimensions = intent.get('dimensions', {})
|
||||
|
||||
if dimensions.get('valuation'):
|
||||
data_types.append('valuation')
|
||||
if dimensions.get('money_flow'):
|
||||
data_types.append('money_flow')
|
||||
|
||||
if not data_types:
|
||||
data_types = ['valuation', 'money_flow']
|
||||
|
||||
params['data_types'] = data_types
|
||||
|
||||
elif skill_name == 'brave_search':
|
||||
# 构建搜索查询
|
||||
target = intent.get('target', {})
|
||||
stock_name = target.get('stock_name', '')
|
||||
stock_code = target.get('stock_code', '')
|
||||
dimensions = intent.get('dimensions', {})
|
||||
|
||||
# 根据维度构建搜索关键词
|
||||
search_keywords = []
|
||||
if stock_name:
|
||||
search_keywords.append(stock_name)
|
||||
elif stock_code:
|
||||
search_keywords.append(stock_code)
|
||||
|
||||
# 添加维度相关关键词
|
||||
if dimensions.get('fundamental'):
|
||||
search_keywords.append('财报 业绩')
|
||||
if dimensions.get('news'):
|
||||
search_keywords.append('最新消息')
|
||||
if dimensions.get('risk'):
|
||||
search_keywords.append('风险 预警')
|
||||
|
||||
# 如果没有特定维度,搜索一般新闻
|
||||
if not any(dimensions.values()):
|
||||
search_keywords.append('最新动态')
|
||||
|
||||
params['query'] = ' '.join(search_keywords)
|
||||
params['search_type'] = 'news' # 默认搜索新闻
|
||||
params['count'] = 5
|
||||
params['freshness'] = 'pw' # 过去一周
|
||||
|
||||
return params
|
||||
|
||||
def _get_skill_reason(self, skill_name: str, intent: Dict[str, Any]) -> str:
|
||||
"""获取调用该技能的原因"""
|
||||
dimensions = intent.get('dimensions', {})
|
||||
reasons = []
|
||||
|
||||
if skill_name == 'market_data':
|
||||
if dimensions.get('price_trend'):
|
||||
reasons.append('用户关注价格走势')
|
||||
else:
|
||||
reasons.append('获取基础行情数据')
|
||||
|
||||
elif skill_name == 'technical_analysis':
|
||||
if dimensions.get('technical'):
|
||||
reasons.append('用户关注技术指标')
|
||||
else:
|
||||
reasons.append('提供技术面分析')
|
||||
|
||||
elif skill_name == 'fundamental':
|
||||
if dimensions.get('fundamental'):
|
||||
reasons.append('用户关注基本面')
|
||||
else:
|
||||
reasons.append('提供公司基本信息')
|
||||
|
||||
elif skill_name == 'advanced_data':
|
||||
if dimensions.get('valuation'):
|
||||
reasons.append('用户关注估值')
|
||||
if dimensions.get('money_flow'):
|
||||
reasons.append('用户关注资金流向')
|
||||
if not reasons:
|
||||
reasons.append('提供高级财务数据')
|
||||
|
||||
elif skill_name == 'visualization':
|
||||
reasons.append('生成K线图表')
|
||||
|
||||
elif skill_name == 'brave_search':
|
||||
if dimensions.get('news'):
|
||||
reasons.append('用户关注最新新闻')
|
||||
elif dimensions.get('fundamental'):
|
||||
reasons.append('搜索公司最新动态和财报信息')
|
||||
elif dimensions.get('risk'):
|
||||
reasons.append('搜索风险预警信息')
|
||||
else:
|
||||
reasons.append('获取最新市场资讯和舆情')
|
||||
|
||||
return ', '.join(reasons) if reasons else '提供分析数据'
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,212 +0,0 @@
|
||||
"""
|
||||
A股相关 API 路由
|
||||
"""
|
||||
from fastapi import APIRouter, HTTPException, BackgroundTasks
|
||||
from typing import Dict, Any
|
||||
from app.utils.logger import logger
|
||||
from app.config import get_settings
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# 全局变量,用于访问智能体实例
|
||||
_astock_agent_instance = None
|
||||
|
||||
|
||||
def set_astock_agent(agent):
|
||||
"""设置智能体实例(由 main.py 调用)"""
|
||||
global _astock_agent_instance
|
||||
_astock_agent_instance = agent
|
||||
|
||||
|
||||
@router.get("/status")
|
||||
async def get_astock_status() -> Dict[str, Any]:
|
||||
"""
|
||||
获取A股智能体状态
|
||||
|
||||
Returns:
|
||||
智能体状态信息
|
||||
"""
|
||||
try:
|
||||
if _astock_agent_instance is None:
|
||||
return {
|
||||
"enabled": False,
|
||||
"message": "A股智能体未启用"
|
||||
}
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
return {
|
||||
"enabled": True,
|
||||
"running": _astock_agent_instance.running,
|
||||
"selector_type": "short_term_thematic",
|
||||
"description": "短期题材选股器(题材轮动 + 技术面确认 + 风险控制)",
|
||||
"config": {
|
||||
"min_market_cap": settings.astock_min_market_cap if hasattr(settings, 'astock_min_market_cap') else 50,
|
||||
"max_market_cap": settings.astock_max_market_cap if hasattr(settings, 'astock_max_market_cap') else 500,
|
||||
"change_threshold": settings.astock_change_threshold,
|
||||
"top_n": settings.astock_top_n
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取A股状态失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/select")
|
||||
async def trigger_selection(background_tasks: BackgroundTasks) -> Dict[str, Any]:
|
||||
"""
|
||||
手动触发选股
|
||||
|
||||
Returns:
|
||||
选股任务状态
|
||||
"""
|
||||
try:
|
||||
if _astock_agent_instance is None:
|
||||
raise HTTPException(status_code=400, detail="A股智能体未启用")
|
||||
|
||||
# 在后台执行选股任务
|
||||
background_tasks.add_task(_astock_agent_instance.run_once)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "选股任务已提交,正在后台执行",
|
||||
"note": "请查看通知或日志获取结果"
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"触发选股失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/select/sync")
|
||||
async def trigger_selection_sync() -> Dict[str, Any]:
|
||||
"""
|
||||
手动触发选股(同步执行)
|
||||
|
||||
Returns:
|
||||
选股结果
|
||||
"""
|
||||
try:
|
||||
if _astock_agent_instance is None:
|
||||
raise HTTPException(status_code=400, detail="A股智能体未启用")
|
||||
|
||||
# 同步执行选股
|
||||
result = await _astock_agent_instance.run_once()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"result": result
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"触发选股失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/config")
|
||||
async def get_astock_config() -> Dict[str, Any]:
|
||||
"""
|
||||
获取A股选股配置
|
||||
|
||||
Returns:
|
||||
配置信息
|
||||
"""
|
||||
try:
|
||||
settings = get_settings()
|
||||
|
||||
return {
|
||||
"selector": {
|
||||
"type": "short_term_thematic",
|
||||
"description": "短期题材选股器",
|
||||
"strategy": "题材轮动 + 技术面确认 + 风险控制"
|
||||
},
|
||||
"screening": {
|
||||
"min_market_cap": 50, # 最小市值(亿)
|
||||
"max_market_cap": 500, # 最大市值(亿)
|
||||
"min_turnover": 3.0, # 最小换手率(%)
|
||||
"max_turnover": 15.0, # 最大换手率(%)
|
||||
"sector_change_threshold": 2.0, # 板块涨幅阈值(%)
|
||||
"volume_ratio_threshold": 1.2 # 量比阈值
|
||||
},
|
||||
"risk_control": {
|
||||
"max_drawdown": 10.0, # 最大回撤(%)
|
||||
"hard_stop_loss": -7.0, # 硬止损(%)
|
||||
"max_single_position": 20, # 单票最大仓位(%)
|
||||
"max_sector_position": 40, # 单行业最大仓位(%)
|
||||
"max_total_position": 80 # 总仓位最大值(%)
|
||||
},
|
||||
"schedule": {
|
||||
"enabled": settings.astock_monitor_enabled,
|
||||
"time": "15:30", # 盘后运行
|
||||
"timezone": "Asia/Shanghai"
|
||||
},
|
||||
"notifications": {
|
||||
"dingtalk": settings.dingtalk_enabled,
|
||||
"telegram": settings.telegram_enabled
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取配置失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/sectors")
|
||||
async def get_hot_sectors(limit: int = 10) -> Dict[str, Any]:
|
||||
"""
|
||||
获取当前异动板块
|
||||
|
||||
Args:
|
||||
limit: 返回板块数量
|
||||
|
||||
Returns:
|
||||
异动板块列表
|
||||
"""
|
||||
try:
|
||||
from app.astock_agent.tushare_client import get_tushare_client
|
||||
from app.config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
ts_client = get_tushare_client(settings.tushare_token)
|
||||
|
||||
if not ts_client:
|
||||
raise HTTPException(status_code=400, detail="Tushare客户端未初始化")
|
||||
|
||||
# 获取异动板块
|
||||
sectors_df = ts_client.get_hot_sectors(threshold=2.0)
|
||||
|
||||
if sectors_df.empty:
|
||||
return {
|
||||
"success": True,
|
||||
"count": 0,
|
||||
"sectors": []
|
||||
}
|
||||
|
||||
# 转换为列表格式
|
||||
sectors = []
|
||||
for _, row in sectors_df.head(limit).iterrows():
|
||||
sectors.append({
|
||||
"code": row['ts_code'],
|
||||
"name": row['name'],
|
||||
"change_pct": float(row['change_pct']),
|
||||
"amount": float(row['amount']),
|
||||
"amount_yi": float(row['amount']) / 100000000, # 转换为亿
|
||||
"close": float(row['close'])
|
||||
})
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"count": len(sectors),
|
||||
"sectors": sectors
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取异动板块失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@ -5,15 +5,15 @@ Bitget 实盘交易 API
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from typing import Optional
|
||||
|
||||
from app.services.bitget_live_trading_service import get_bitget_live_service
|
||||
from app.services.bitget_live_trading_service import get_all_bitget_live_services, get_bitget_live_service
|
||||
from app.utils.logger import logger
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/bitget", tags=["Bitget"])
|
||||
|
||||
|
||||
def _get_service():
|
||||
service = get_bitget_live_service()
|
||||
def _get_service(account_id: str = "default"):
|
||||
service = get_bitget_live_service(account_id)
|
||||
if service is None:
|
||||
return None
|
||||
return service
|
||||
@ -65,11 +65,12 @@ async def get_account():
|
||||
|
||||
@router.get("/positions")
|
||||
async def get_positions(
|
||||
symbol: Optional[str] = Query(None, description="币种筛选,如 BTC")
|
||||
symbol: Optional[str] = Query(None, description="币种筛选,如 BTC"),
|
||||
account_id: str = Query("default", description="Bitget 账号 ID")
|
||||
):
|
||||
"""获取 Bitget 持仓信息"""
|
||||
try:
|
||||
service = _get_service()
|
||||
service = _get_service(account_id)
|
||||
if service is None:
|
||||
return {"success": True, "enabled": False, "positions": []}
|
||||
|
||||
@ -103,11 +104,12 @@ async def get_positions(
|
||||
|
||||
@router.get("/orders")
|
||||
async def get_orders(
|
||||
symbol: Optional[str] = Query(None, description="币种筛选,如 BTC")
|
||||
symbol: Optional[str] = Query(None, description="币种筛选,如 BTC"),
|
||||
account_id: str = Query("default", description="Bitget 账号 ID")
|
||||
):
|
||||
"""获取 Bitget 挂单信息"""
|
||||
try:
|
||||
service = _get_service()
|
||||
service = _get_service(account_id)
|
||||
if service is None:
|
||||
return {"success": True, "enabled": False, "orders": []}
|
||||
|
||||
@ -137,24 +139,25 @@ async def get_orders(
|
||||
async def get_summary():
|
||||
"""获取 Bitget 交易摘要"""
|
||||
try:
|
||||
services = get_all_bitget_live_services()
|
||||
if not services:
|
||||
service = _get_service()
|
||||
if service is None:
|
||||
return {"success": True, "enabled": False, "message": "Bitget 服务未启用"}
|
||||
services = {"default": service} if service else {}
|
||||
|
||||
if not services:
|
||||
return {"success": True, "enabled": False, "message": "Bitget 服务未启用"}
|
||||
accounts = []
|
||||
for account_id, service in services.items():
|
||||
state = service.get_account_state()
|
||||
positions = service.get_open_positions()
|
||||
orders = service.get_open_orders()
|
||||
total_position_value = sum(abs(p["size"]) * p["entry_price"] for p in positions)
|
||||
|
||||
current_leverage = total_position_value / state["account_value"] if state["account_value"] > 0 else 0
|
||||
drawdown = 0
|
||||
if service.initial_balance and service.initial_balance > 0:
|
||||
drawdown = (service.initial_balance - state["account_value"]) / service.initial_balance
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"enabled": True,
|
||||
"data": {
|
||||
accounts.append({
|
||||
"account_id": account_id,
|
||||
"account": {
|
||||
"account_value": state["account_value"],
|
||||
"available_balance": state["available_balance"],
|
||||
@ -173,6 +176,37 @@ async def get_summary():
|
||||
"drawdown": drawdown * 100,
|
||||
"circuit_breaker_threshold": service.circuit_breaker_drawdown * 100,
|
||||
},
|
||||
})
|
||||
|
||||
total_account_value = sum(item["account"]["account_value"] for item in accounts)
|
||||
total_available = sum(item["account"]["available_balance"] for item in accounts)
|
||||
total_margin = sum(item["account"]["total_margin_used"] for item in accounts)
|
||||
total_positions = sum(item["positions"]["count"] for item in accounts)
|
||||
total_position_value = sum(item["positions"]["total_value"] for item in accounts)
|
||||
total_orders = sum(item["orders"]["count"] for item in accounts)
|
||||
return {
|
||||
"success": True,
|
||||
"enabled": True,
|
||||
"data": {
|
||||
"account": {
|
||||
"account_value": total_account_value,
|
||||
"available_balance": total_available,
|
||||
"total_margin_used": total_margin,
|
||||
},
|
||||
"positions": {"count": total_positions, "total_value": total_position_value},
|
||||
"orders": {
|
||||
"count": total_orders,
|
||||
"entry_orders": sum(item["orders"]["entry_orders"] for item in accounts),
|
||||
"tp_sl_orders": sum(item["orders"]["tp_sl_orders"] for item in accounts),
|
||||
},
|
||||
"risk": {
|
||||
"current_leverage": total_position_value / total_account_value if total_account_value > 0 else 0,
|
||||
"max_leverage": max((item["risk"]["max_leverage"] for item in accounts), default=0),
|
||||
"leverage_utilization": 0,
|
||||
"drawdown": max((item["risk"]["drawdown"] for item in accounts), default=0),
|
||||
"circuit_breaker_threshold": max((item["risk"]["circuit_breaker_threshold"] for item in accounts), default=0),
|
||||
},
|
||||
"accounts": accounts,
|
||||
}
|
||||
}
|
||||
except Exception as e:
|
||||
|
||||
@ -1,133 +0,0 @@
|
||||
"""
|
||||
对话API路由
|
||||
"""
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from fastapi.responses import StreamingResponse
|
||||
from typing import Optional
|
||||
import uuid
|
||||
import json
|
||||
import asyncio
|
||||
from app.models.chat import ChatRequest, ChatResponse
|
||||
from app.models.database import User
|
||||
from app.agent.smart_agent import smart_agent # 使用智能Agent
|
||||
from app.middleware.auth_middleware import get_current_user
|
||||
from app.utils.logger import logger
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/message", response_model=ChatResponse)
|
||||
async def send_message(request: ChatRequest):
|
||||
"""
|
||||
发送消息给Agent
|
||||
|
||||
Args:
|
||||
request: 聊天请求
|
||||
|
||||
Returns:
|
||||
Agent响应
|
||||
"""
|
||||
try:
|
||||
# 生成或使用现有session_id
|
||||
session_id = request.session_id or str(uuid.uuid4())
|
||||
|
||||
# 处理消息(使用智能Agent)
|
||||
response = await smart_agent.process_message(
|
||||
message=request.message,
|
||||
session_id=session_id,
|
||||
user_id=request.user_id
|
||||
)
|
||||
|
||||
return ChatResponse(
|
||||
message=response["message"],
|
||||
session_id=session_id,
|
||||
metadata=response.get("metadata")
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/history/{session_id}")
|
||||
async def get_history(session_id: str, limit: int = 50):
|
||||
"""
|
||||
获取对话历史
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
limit: 最大消息数
|
||||
|
||||
Returns:
|
||||
对话历史
|
||||
"""
|
||||
try:
|
||||
context = smart_agent.context_manager.get_context(session_id)
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"messages": context
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取历史失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/message/stream")
|
||||
async def send_message_stream(
|
||||
request: ChatRequest,
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
流式发送消息给Agent
|
||||
|
||||
Args:
|
||||
request: 聊天请求
|
||||
current_user: 当前登录用户
|
||||
|
||||
Returns:
|
||||
Server-Sent Events 流式响应
|
||||
"""
|
||||
try:
|
||||
# 生成或使用现有session_id
|
||||
session_id = request.session_id or str(uuid.uuid4())
|
||||
|
||||
async def event_generator():
|
||||
"""生成SSE事件流"""
|
||||
try:
|
||||
# 发送session_id
|
||||
yield f"data: {json.dumps({'type': 'session_id', 'session_id': session_id})}\n\n"
|
||||
|
||||
# 添加小延迟确保数据被发送
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# 处理消息并流式返回(使用真实用户ID)
|
||||
async for chunk in smart_agent.process_message_stream(
|
||||
message=request.message,
|
||||
session_id=session_id,
|
||||
user_id=str(current_user.id)
|
||||
):
|
||||
yield f"data: {json.dumps({'type': 'content', 'content': chunk})}\n\n"
|
||||
# 添加小延迟,让浏览器有机会接收数据
|
||||
await asyncio.sleep(0.001)
|
||||
|
||||
# 发送完成信号
|
||||
yield f"data: {json.dumps({'type': 'done'})}\n\n"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"流式处理消息失败: {e}")
|
||||
yield f"data: {json.dumps({'type': 'error', 'error': str(e)})}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
"Transfer-Encoding": "chunked"
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建流式响应失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@ -1,179 +0,0 @@
|
||||
"""
|
||||
新闻 API - 提供新闻查询接口
|
||||
"""
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
from app.news_agent.news_agent import get_news_agent
|
||||
from app.news_agent.news_db_service import get_news_db_service
|
||||
from app.utils.logger import logger
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/news", tags=["新闻管理"])
|
||||
|
||||
|
||||
@router.get("/articles")
|
||||
async def get_articles(
|
||||
category: Optional[str] = Query(None, description="分类过滤 (crypto/stock)"),
|
||||
limit: int = Query(50, ge=1, le=200, description="返回数量限制"),
|
||||
hours: int = Query(24, ge=1, le=168, description="查询最近多少小时")
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取新闻文章列表
|
||||
|
||||
Args:
|
||||
category: 分类过滤
|
||||
limit: 返回数量限制
|
||||
hours: 查询最近多少小时
|
||||
|
||||
Returns:
|
||||
文章列表
|
||||
"""
|
||||
try:
|
||||
db_service = get_news_db_service()
|
||||
articles = db_service.get_latest_articles(
|
||||
category=category,
|
||||
limit=limit,
|
||||
hours=hours
|
||||
)
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'articles': articles,
|
||||
'count': len(articles)
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取新闻文章失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_news_stats(
|
||||
hours: int = Query(24, ge=1, le=168, description="统计最近多少小时")
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取新闻统计信息
|
||||
|
||||
Args:
|
||||
hours: 统计最近多少小时
|
||||
|
||||
Returns:
|
||||
统计数据
|
||||
"""
|
||||
try:
|
||||
news_agent = get_news_agent()
|
||||
agent_stats = news_agent.get_stats()
|
||||
|
||||
db_service = get_news_db_service()
|
||||
db_stats = db_service.get_stats(hours=hours)
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'agent': agent_stats,
|
||||
'database': db_stats
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取新闻统计失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/fetch")
|
||||
async def manual_fetch(
|
||||
category: Optional[str] = Query(None, description="分类过滤 (crypto/stock)")
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
手动触发新闻抓取
|
||||
|
||||
Args:
|
||||
category: 分类过滤
|
||||
|
||||
Returns:
|
||||
抓取结果
|
||||
"""
|
||||
try:
|
||||
news_agent = get_news_agent()
|
||||
result = await news_agent.manual_fetch(category)
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
**result
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"手动抓取新闻失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/high-priority")
|
||||
async def get_high_priority_articles(
|
||||
limit: int = Query(20, ge=1, le=100, description="返回数量限制"),
|
||||
min_priority: float = Query(40.0, description="最低优先级分数"),
|
||||
hours: int = Query(24, ge=1, le=168, description="查询最近多少小时")
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取高优先级文章
|
||||
|
||||
Args:
|
||||
limit: 返回数量限制
|
||||
min_priority: 最低优先级分数
|
||||
hours: 查询最近多少小时
|
||||
|
||||
Returns:
|
||||
高优先级文章列表
|
||||
"""
|
||||
try:
|
||||
db_service = get_news_db_service()
|
||||
articles = db_service.get_high_priority_articles(
|
||||
limit=limit,
|
||||
min_priority=min_priority,
|
||||
hours=hours
|
||||
)
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'articles': [article.to_dict() for article in articles],
|
||||
'count': len(articles)
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取高优先级文章失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/sources")
|
||||
async def get_news_sources() -> Dict[str, Any]:
|
||||
"""
|
||||
获取新闻源配置
|
||||
|
||||
Returns:
|
||||
新闻源列表
|
||||
"""
|
||||
try:
|
||||
from app.news_agent.sources import CRYPTO_NEWS_SOURCES, STOCK_NEWS_SOURCES
|
||||
|
||||
# 只返回基本信息,隐藏敏感配置
|
||||
crypto_sources = [
|
||||
{
|
||||
'name': s['name'],
|
||||
'category': s['category'],
|
||||
'enabled': s['enabled']
|
||||
}
|
||||
for s in CRYPTO_NEWS_SOURCES
|
||||
]
|
||||
|
||||
stock_sources = [
|
||||
{
|
||||
'name': s['name'],
|
||||
'category': s['category'],
|
||||
'enabled': s['enabled']
|
||||
}
|
||||
for s in STOCK_NEWS_SOURCES
|
||||
]
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'crypto': crypto_sources,
|
||||
'stock': stock_sources,
|
||||
'total': len(crypto_sources) + len(stock_sources)
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取新闻源失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@ -31,6 +31,14 @@ class DeleteOrdersRequest(BaseModel):
|
||||
class ResumePlatformRequest(BaseModel):
|
||||
"""恢复平台执行请求"""
|
||||
platform: str
|
||||
target_key: Optional[str] = None
|
||||
|
||||
|
||||
class ExecutionControlRequest(BaseModel):
|
||||
"""执行目标自动交易开关请求"""
|
||||
target_key: str
|
||||
enabled: bool
|
||||
reason: Optional[str] = None
|
||||
|
||||
|
||||
class OrderResponse(BaseModel):
|
||||
@ -239,16 +247,55 @@ async def get_platform_halts():
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/execution-controls")
|
||||
async def get_execution_controls():
|
||||
"""获取目标级自动交易开关状态"""
|
||||
try:
|
||||
agent = get_crypto_agent()
|
||||
return {
|
||||
"success": True,
|
||||
"execution_controls": agent.get_target_execution_status(),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取自动交易控制状态失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/execution-controls")
|
||||
async def set_execution_controls(request: ExecutionControlRequest):
|
||||
"""设置目标级自动交易开关"""
|
||||
try:
|
||||
agent = get_crypto_agent()
|
||||
result = agent.set_target_execution_enabled(
|
||||
target_key=request.target_key,
|
||||
enabled=request.enabled,
|
||||
reason=request.reason or "",
|
||||
)
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"{request.target_key} 自动交易已{'开启' if request.enabled else '关闭'}",
|
||||
"target_key": request.target_key,
|
||||
"status": result,
|
||||
}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"设置自动交易控制状态失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/platform-halts/resume")
|
||||
async def resume_platform(request: ResumePlatformRequest):
|
||||
"""手动恢复指定平台执行"""
|
||||
try:
|
||||
agent = get_crypto_agent()
|
||||
result = agent.resume_platform(request.platform)
|
||||
target = request.target_key or request.platform
|
||||
result = agent.resume_platform(target)
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"{request.platform} 已恢复执行",
|
||||
"message": f"{target} 已恢复执行",
|
||||
"platform": request.platform,
|
||||
"target_key": target,
|
||||
"status": result,
|
||||
}
|
||||
except ValueError as e:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
"""
|
||||
信号 API - 提供加密货币和美股信号查询接口(数据库版本)
|
||||
信号 API - 提供加密货币信号查询接口
|
||||
"""
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from typing import Dict, List, Optional, Any
|
||||
@ -53,91 +53,13 @@ async def get_crypto_signals(
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/stock")
|
||||
async def get_stock_signals(
|
||||
limit: int = Query(50, ge=1, le=200, description="返回数量限制"),
|
||||
symbol: Optional[str] = Query(None, description="过滤指定股票"),
|
||||
days: int = Query(7, ge=1, le=30, description="查询最近多少天的信号")
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取美股信号列表
|
||||
|
||||
Args:
|
||||
limit: 返回数量限制(默认50)
|
||||
symbol: 过滤指定股票
|
||||
days: 查询最近多少天的信号(默认7天)
|
||||
|
||||
Returns:
|
||||
信号列表
|
||||
"""
|
||||
try:
|
||||
service = get_signal_db_service()
|
||||
|
||||
if symbol:
|
||||
# 获取指定股票的最新信号
|
||||
signal = service.get_latest_signal('stock', symbol)
|
||||
return {
|
||||
'success': True,
|
||||
'symbol': symbol,
|
||||
'signal': signal,
|
||||
'count': 1 if signal else 0
|
||||
}
|
||||
else:
|
||||
# 获取所有信号
|
||||
signals = service.get_stock_signals(limit=limit, days=days)
|
||||
return {
|
||||
'success': True,
|
||||
'signals': signals,
|
||||
'count': len(signals)
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取美股信号失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/all")
|
||||
async def get_all_signals(
|
||||
limit: int = Query(50, ge=1, le=200, description="每种类型返回数量限制"),
|
||||
days: int = Query(7, ge=1, le=30, description="查询最近多少天的信号")
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取所有信号(加密货币 + 美股)
|
||||
|
||||
Args:
|
||||
limit: 每种类型返回数量限制(默认50)
|
||||
days: 查询最近多少天的信号(默认7天)
|
||||
|
||||
Returns:
|
||||
所有信号
|
||||
"""
|
||||
try:
|
||||
service = get_signal_db_service()
|
||||
signals = service.get_all_signals(limit=limit, days=days)
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'crypto': {
|
||||
'signals': signals['crypto'],
|
||||
'count': len(signals['crypto'])
|
||||
},
|
||||
'stock': {
|
||||
'signals': signals['stock'],
|
||||
'count': len(signals['stock'])
|
||||
},
|
||||
'total_count': len(signals['crypto']) + len(signals['stock'])
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取所有信号失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/latest")
|
||||
async def get_latest_signals(
|
||||
limit: int = Query(20, ge=1, le=100, description="返回数量限制"),
|
||||
days: int = Query(7, ge=1, le=30, description="查询最近多少天的信号")
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取最新的所有信号(按时间排序)
|
||||
获取最新的加密货币信号(按时间排序)
|
||||
|
||||
Args:
|
||||
limit: 返回数量限制(默认20)
|
||||
@ -148,7 +70,7 @@ async def get_latest_signals(
|
||||
"""
|
||||
try:
|
||||
service = get_signal_db_service()
|
||||
signals = service.get_latest_signals(limit=limit, days=days)
|
||||
signals = service.get_crypto_signals(limit=limit, days=days)
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
@ -179,7 +101,11 @@ async def get_signal_stats(
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
**stats
|
||||
**{
|
||||
'crypto': stats.get('crypto', {'total': 0, 'buy': 0, 'sell': 0, 'recent_24h': 0}),
|
||||
'grades': stats.get('grades', {}),
|
||||
'total': stats.get('crypto', {}).get('total', 0),
|
||||
}
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取信号统计失败: {e}")
|
||||
|
||||
@ -1,99 +0,0 @@
|
||||
"""
|
||||
技能管理API路由
|
||||
"""
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from app.agent.skill_manager import skill_manager
|
||||
from app.utils.logger import logger
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class ToggleRequest(BaseModel):
|
||||
enabled: bool
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def list_skills():
|
||||
"""
|
||||
获取所有技能列表
|
||||
|
||||
Returns:
|
||||
技能信息列表
|
||||
"""
|
||||
try:
|
||||
skills_info = skill_manager.get_skills_info()
|
||||
return skills_info # 直接返回数组
|
||||
except Exception as e:
|
||||
logger.error(f"获取技能列表失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/{skill_name}/toggle")
|
||||
async def toggle_skill(skill_name: str, request: ToggleRequest):
|
||||
"""
|
||||
切换技能状态
|
||||
|
||||
Args:
|
||||
skill_name: 技能名称
|
||||
request: 包含enabled字段的请求体
|
||||
|
||||
Returns:
|
||||
操作结果
|
||||
"""
|
||||
try:
|
||||
if request.enabled:
|
||||
success = skill_manager.enable_skill(skill_name)
|
||||
message = f"技能 {skill_name} 已启用"
|
||||
else:
|
||||
success = skill_manager.disable_skill(skill_name)
|
||||
message = f"技能 {skill_name} 已禁用"
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="技能不存在")
|
||||
return {"message": message, "success": True}
|
||||
except Exception as e:
|
||||
logger.error(f"切换技能失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/{skill_name}/enable")
|
||||
async def enable_skill(skill_name: str):
|
||||
"""
|
||||
启用技能
|
||||
|
||||
Args:
|
||||
skill_name: 技能名称
|
||||
|
||||
Returns:
|
||||
操作结果
|
||||
"""
|
||||
try:
|
||||
success = skill_manager.enable_skill(skill_name)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="技能不存在")
|
||||
return {"message": f"技能 {skill_name} 已启用"}
|
||||
except Exception as e:
|
||||
logger.error(f"启用技能失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/{skill_name}/disable")
|
||||
async def disable_skill(skill_name: str):
|
||||
"""
|
||||
禁用技能
|
||||
|
||||
Args:
|
||||
skill_name: 技能名称
|
||||
|
||||
Returns:
|
||||
操作结果
|
||||
"""
|
||||
try:
|
||||
success = skill_manager.disable_skill(skill_name)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="技能不存在")
|
||||
return {"message": f"技能 {skill_name} 已禁用"}
|
||||
except Exception as e:
|
||||
logger.error(f"禁用技能失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@ -1,132 +0,0 @@
|
||||
"""
|
||||
股票数据API路由
|
||||
"""
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from typing import Optional
|
||||
from app.services.tushare_service import tushare_service
|
||||
from app.utils.logger import logger
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/sector/check")
|
||||
async def trigger_sector_check():
|
||||
"""
|
||||
手动触发板块异动检查
|
||||
|
||||
Returns:
|
||||
检查结果
|
||||
"""
|
||||
try:
|
||||
from app.main import _astock_monitor_instance
|
||||
from app.config import get_settings
|
||||
|
||||
if not _astock_monitor_instance:
|
||||
# 创建临时监控实例
|
||||
from app.astock_agent import SectorMonitor
|
||||
settings = get_settings()
|
||||
monitor = SectorMonitor(
|
||||
change_threshold=settings.astock_change_threshold,
|
||||
top_n=settings.astock_top_n,
|
||||
enable_notifier=False # 手动触发不发送通知
|
||||
)
|
||||
result = await monitor.check_once()
|
||||
return result
|
||||
else:
|
||||
result = await _astock_monitor_instance.check_once()
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"手动触发板块检查失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/sector/stats")
|
||||
async def get_sector_stats():
|
||||
"""
|
||||
获取板块监控统计信息
|
||||
|
||||
Returns:
|
||||
统计信息
|
||||
"""
|
||||
try:
|
||||
from app.main import _astock_monitor_instance
|
||||
|
||||
if not _astock_monitor_instance:
|
||||
return {"error": "板块监控未运行"}
|
||||
|
||||
stats = _astock_monitor_instance.get_stats()
|
||||
return stats
|
||||
except Exception as e:
|
||||
logger.error(f"获取板块统计失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/quote/{stock_code}")
|
||||
async def get_quote(stock_code: str):
|
||||
"""
|
||||
获取股票实时行情
|
||||
|
||||
Args:
|
||||
stock_code: 股票代码
|
||||
|
||||
Returns:
|
||||
行情数据
|
||||
"""
|
||||
try:
|
||||
quote = tushare_service.get_realtime_quote(stock_code)
|
||||
if not quote:
|
||||
raise HTTPException(status_code=404, detail="未找到股票数据")
|
||||
return quote
|
||||
except Exception as e:
|
||||
logger.error(f"获取行情失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/kline/{stock_code}")
|
||||
async def get_kline(
|
||||
stock_code: str,
|
||||
start_date: Optional[str] = Query(None, description="开始日期YYYYMMDD"),
|
||||
end_date: Optional[str] = Query(None, description="结束日期YYYYMMDD"),
|
||||
period: str = Query("D", description="周期D/W/M")
|
||||
):
|
||||
"""
|
||||
获取K线数据
|
||||
|
||||
Args:
|
||||
stock_code: 股票代码
|
||||
start_date: 开始日期
|
||||
end_date: 结束日期
|
||||
period: 周期
|
||||
|
||||
Returns:
|
||||
K线数据
|
||||
"""
|
||||
try:
|
||||
kline = tushare_service.get_kline_data(stock_code, start_date, end_date, period)
|
||||
if not kline:
|
||||
raise HTTPException(status_code=404, detail="未找到K线数据")
|
||||
return {"kline_data": kline}
|
||||
except Exception as e:
|
||||
logger.error(f"获取K线失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/basic/{stock_code}")
|
||||
async def get_basic(stock_code: str):
|
||||
"""
|
||||
获取股票基本信息
|
||||
|
||||
Args:
|
||||
stock_code: 股票代码
|
||||
|
||||
Returns:
|
||||
基本信息
|
||||
"""
|
||||
try:
|
||||
basic = tushare_service.get_stock_basic(stock_code)
|
||||
if not basic:
|
||||
raise HTTPException(status_code=404, detail="未找到股票信息")
|
||||
return basic
|
||||
except Exception as e:
|
||||
logger.error(f"获取基本信息失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@ -1,131 +0,0 @@
|
||||
"""
|
||||
美股相关 API 路由
|
||||
"""
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from typing import Dict, Any, List
|
||||
from app.utils.logger import logger
|
||||
from app.config import get_settings
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# 全局变量,用于访问智能体实例
|
||||
_stock_agent_instance = None
|
||||
|
||||
|
||||
def set_stock_agent(agent):
|
||||
"""设置智能体实例(由 main.py 调用)"""
|
||||
global _stock_agent_instance
|
||||
_stock_agent_instance = agent
|
||||
|
||||
|
||||
@router.get("/status")
|
||||
async def get_stock_status() -> Dict[str, Any]:
|
||||
"""
|
||||
获取美股智能体状态
|
||||
|
||||
Returns:
|
||||
智能体状态信息
|
||||
"""
|
||||
try:
|
||||
if _stock_agent_instance is None:
|
||||
return {
|
||||
"enabled": False,
|
||||
"message": "美股智能体未启用"
|
||||
}
|
||||
|
||||
status = _stock_agent_instance.get_status()
|
||||
status["enabled"] = True
|
||||
return status
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取美股状态失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/symbols")
|
||||
async def get_stock_symbols() -> Dict[str, Any]:
|
||||
"""
|
||||
获取当前监控的股票列表
|
||||
|
||||
Returns:
|
||||
股票列表
|
||||
"""
|
||||
try:
|
||||
settings = get_settings()
|
||||
symbols = settings.stock_symbols.split(',') if settings.stock_symbols else []
|
||||
|
||||
return {
|
||||
"symbols": symbols,
|
||||
"count": len(symbols)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取股票列表失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/analyze/{symbol}")
|
||||
async def analyze_stock(symbol: str) -> Dict[str, Any]:
|
||||
"""
|
||||
手动触发分析指定股票
|
||||
|
||||
Args:
|
||||
symbol: 股票代码,如 'AAPL'
|
||||
|
||||
Returns:
|
||||
分析结果
|
||||
"""
|
||||
try:
|
||||
if _stock_agent_instance is None:
|
||||
raise HTTPException(status_code=400, detail="美股智能体未启用")
|
||||
|
||||
# 执行单次分析
|
||||
result = await _stock_agent_instance.analyze_once(symbol)
|
||||
|
||||
if "error" in result:
|
||||
raise HTTPException(status_code=400, detail=result["error"])
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"symbol": symbol,
|
||||
"result": result
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"分析 {symbol} 失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/quote/{symbol}")
|
||||
async def get_stock_quote(symbol: str) -> Dict[str, Any]:
|
||||
"""
|
||||
获取股票实时行情
|
||||
|
||||
Args:
|
||||
symbol: 股票代码
|
||||
|
||||
Returns:
|
||||
行情数据
|
||||
"""
|
||||
try:
|
||||
from app.services.yfinance_service import get_yfinance_service
|
||||
|
||||
yf_service = get_yfinance_service()
|
||||
quote = yf_service.get_ticker(symbol.upper())
|
||||
|
||||
if quote is None:
|
||||
raise HTTPException(status_code=404, detail=f"无法获取 {symbol} 的行情")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"symbol": symbol.upper(),
|
||||
"quote": quote
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取 {symbol} 行情失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@ -10,7 +10,7 @@ from app.utils.system_status import get_system_monitor
|
||||
from app.crypto_agent.crypto_agent import get_crypto_agent
|
||||
from app.services.signal_database_service import get_signal_db_service
|
||||
from app.services.paper_trading_service import get_paper_trading_service
|
||||
from app.services.bitget_live_trading_service import get_bitget_live_service
|
||||
from app.services.bitget_live_trading_service import get_all_bitget_live_services, get_bitget_live_service
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@ -114,6 +114,7 @@ def _normalize_platform_order(platform: str, order: Dict[str, Any]) -> Dict[str,
|
||||
|
||||
return {
|
||||
"platform": platform,
|
||||
"account_id": order.get("account_id"),
|
||||
"symbol": symbol,
|
||||
"side": side,
|
||||
"category": category,
|
||||
@ -132,6 +133,44 @@ def _normalize_platform_order(platform: str, order: Dict[str, Any]) -> Dict[str,
|
||||
}
|
||||
|
||||
|
||||
def _build_bitget_account_summary(account_id: str, service: Any) -> Dict[str, Any]:
|
||||
bg_account = service.get_account_state()
|
||||
bg_positions = service.get_open_positions()
|
||||
bg_orders = service.get_open_orders()
|
||||
bg_total_position_value = sum(abs(p["size"]) * p["entry_price"] for p in bg_positions)
|
||||
bg_drawdown = 0.0
|
||||
if service.initial_balance and service.initial_balance > 0:
|
||||
bg_drawdown = (service.initial_balance - bg_account["account_value"]) / service.initial_balance * 100
|
||||
|
||||
return {
|
||||
"account_id": account_id,
|
||||
"enabled": True,
|
||||
"account": {
|
||||
"account_value": bg_account.get("account_value", 0),
|
||||
"available_balance": bg_account.get("available_balance", 0),
|
||||
"total_margin_used": bg_account.get("total_margin_used", 0),
|
||||
"initial_balance": service.initial_balance,
|
||||
},
|
||||
"positions": {
|
||||
"count": len(bg_positions),
|
||||
"total_value": bg_total_position_value,
|
||||
"items": bg_positions[:8],
|
||||
},
|
||||
"orders": {
|
||||
"count": len(bg_orders),
|
||||
"entry_orders": len([o for o in bg_orders if not o.get("is_reduce_only")]),
|
||||
"tp_sl_orders": len([o for o in bg_orders if o.get("is_reduce_only")]),
|
||||
"items": bg_orders[:8],
|
||||
},
|
||||
"risk": {
|
||||
"current_leverage": bg_total_position_value / bg_account["account_value"] if bg_account.get("account_value", 0) > 0 else 0,
|
||||
"max_leverage": service.max_total_leverage,
|
||||
"drawdown_percent": bg_drawdown,
|
||||
"circuit_breaker_threshold": service.circuit_breaker_drawdown * 100,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _build_attention_items(
|
||||
platform_halts: Dict[str, Any],
|
||||
platforms: Dict[str, Any],
|
||||
@ -269,7 +308,7 @@ async def get_agent_status(agent_id: str):
|
||||
获取指定 Agent 的详细状态
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID (如: crypto_agent, stock_agent)
|
||||
agent_id: Agent ID (如: crypto_agent)
|
||||
"""
|
||||
try:
|
||||
monitor = get_system_monitor()
|
||||
@ -304,7 +343,7 @@ async def get_console_snapshot():
|
||||
|
||||
signal_db = get_signal_db_service()
|
||||
signal_stats = signal_db.get_signal_stats(days=7)
|
||||
latest_signals = signal_db.get_latest_signals(limit=12, days=3)
|
||||
latest_signals = signal_db.get_crypto_signals(limit=12, days=3)
|
||||
|
||||
crypto_agent = get_crypto_agent()
|
||||
crypto_status = crypto_agent.get_status()
|
||||
@ -316,43 +355,66 @@ async def get_console_snapshot():
|
||||
paper_pending = [o for o in paper_orders if o.get('status') == 'pending']
|
||||
paper_stats = paper_service.calculate_statistics()
|
||||
|
||||
bitget_service = get_bitget_live_service()
|
||||
bitget_summary = {"enabled": False}
|
||||
if bitget_service is not None:
|
||||
bg_account = bitget_service.get_account_state()
|
||||
bg_positions = bitget_service.get_open_positions()
|
||||
bg_orders = bitget_service.get_open_orders()
|
||||
bg_total_position_value = sum(abs(p["size"]) * p["entry_price"] for p in bg_positions)
|
||||
bg_drawdown = 0.0
|
||||
if bitget_service.initial_balance and bitget_service.initial_balance > 0:
|
||||
bg_drawdown = (bitget_service.initial_balance - bg_account["account_value"]) / bitget_service.initial_balance * 100
|
||||
bitget_services = get_all_bitget_live_services()
|
||||
if not bitget_services:
|
||||
default_bitget = get_bitget_live_service()
|
||||
if default_bitget:
|
||||
bitget_services = {"default": default_bitget}
|
||||
|
||||
bitget_accounts = []
|
||||
for account_id, service in bitget_services.items():
|
||||
try:
|
||||
bitget_accounts.append(_build_bitget_account_summary(account_id, service))
|
||||
except Exception as exc:
|
||||
logger.error(f"获取 Bitget 账号摘要失败: account={account_id} error={exc}")
|
||||
|
||||
if bitget_accounts:
|
||||
total_account_value = sum(item["account"]["account_value"] for item in bitget_accounts)
|
||||
total_available_balance = sum(item["account"]["available_balance"] for item in bitget_accounts)
|
||||
total_margin_used = sum(item["account"]["total_margin_used"] for item in bitget_accounts)
|
||||
total_positions_count = sum(item["positions"]["count"] for item in bitget_accounts)
|
||||
total_position_value = sum(item["positions"]["total_value"] for item in bitget_accounts)
|
||||
total_orders_count = sum(item["orders"]["count"] for item in bitget_accounts)
|
||||
total_entry_orders = sum(item["orders"]["entry_orders"] for item in bitget_accounts)
|
||||
total_tp_sl_orders = sum(item["orders"]["tp_sl_orders"] for item in bitget_accounts)
|
||||
leverage_weight = total_account_value if total_account_value > 0 else len(bitget_accounts)
|
||||
weighted_drawdown = sum(
|
||||
item["risk"]["drawdown_percent"] * (
|
||||
item["account"]["account_value"] if total_account_value > 0 else 1
|
||||
)
|
||||
for item in bitget_accounts
|
||||
) / leverage_weight if leverage_weight > 0 else 0
|
||||
max_leverage = max((item["risk"]["max_leverage"] for item in bitget_accounts), default=0)
|
||||
breaker_threshold = max((item["risk"]["circuit_breaker_threshold"] for item in bitget_accounts), default=0)
|
||||
|
||||
bitget_summary = {
|
||||
"enabled": True,
|
||||
"accounts": bitget_accounts,
|
||||
"account": {
|
||||
"account_value": bg_account.get("account_value", 0),
|
||||
"available_balance": bg_account.get("available_balance", 0),
|
||||
"total_margin_used": bg_account.get("total_margin_used", 0),
|
||||
"initial_balance": bitget_service.initial_balance,
|
||||
"account_value": total_account_value,
|
||||
"available_balance": total_available_balance,
|
||||
"total_margin_used": total_margin_used,
|
||||
},
|
||||
"positions": {
|
||||
"count": len(bg_positions),
|
||||
"total_value": bg_total_position_value,
|
||||
"items": bg_positions[:8],
|
||||
"count": total_positions_count,
|
||||
"total_value": total_position_value,
|
||||
"items": [item for account in bitget_accounts for item in account["positions"]["items"]][:12],
|
||||
},
|
||||
"orders": {
|
||||
"count": len(bg_orders),
|
||||
"entry_orders": len([o for o in bg_orders if not o.get("is_reduce_only")]),
|
||||
"tp_sl_orders": len([o for o in bg_orders if o.get("is_reduce_only")]),
|
||||
"items": bg_orders[:8],
|
||||
"count": total_orders_count,
|
||||
"entry_orders": total_entry_orders,
|
||||
"tp_sl_orders": total_tp_sl_orders,
|
||||
"items": [item for account in bitget_accounts for item in account["orders"]["items"]][:12],
|
||||
},
|
||||
"risk": {
|
||||
"current_leverage": bg_total_position_value / bg_account["account_value"] if bg_account.get("account_value", 0) > 0 else 0,
|
||||
"max_leverage": bitget_service.max_total_leverage,
|
||||
"drawdown_percent": bg_drawdown,
|
||||
"circuit_breaker_threshold": bitget_service.circuit_breaker_drawdown * 100,
|
||||
"current_leverage": total_position_value / total_account_value if total_account_value > 0 else 0,
|
||||
"max_leverage": max_leverage,
|
||||
"drawdown_percent": weighted_drawdown,
|
||||
"circuit_breaker_threshold": breaker_threshold,
|
||||
},
|
||||
}
|
||||
else:
|
||||
bitget_summary = {"enabled": False, "accounts": []}
|
||||
|
||||
recent_cutoff = now - timedelta(minutes=30)
|
||||
recent_signal_count = sum(
|
||||
@ -370,14 +432,17 @@ async def get_console_snapshot():
|
||||
for order in paper_pending[:12]
|
||||
]
|
||||
|
||||
bitget_position_items = [
|
||||
_normalize_platform_position("bitget", pos)
|
||||
for pos in (bg_positions[:12] if bitget_service is not None else [])
|
||||
]
|
||||
bitget_order_items = [
|
||||
_normalize_platform_order("bitget", order)
|
||||
for order in (bg_orders[:12] if bitget_service is not None else [])
|
||||
]
|
||||
bitget_position_items = []
|
||||
bitget_order_items = []
|
||||
for account in bitget_accounts:
|
||||
for pos in account["positions"]["items"][:12]:
|
||||
normalized = _normalize_platform_position("bitget", pos)
|
||||
normalized["account_id"] = account["account_id"]
|
||||
bitget_position_items.append(normalized)
|
||||
for order in account["orders"]["items"][:12]:
|
||||
enriched_order = dict(order)
|
||||
enriched_order["account_id"] = account["account_id"]
|
||||
bitget_order_items.append(_normalize_platform_order("bitget", enriched_order))
|
||||
|
||||
unified_positions = sorted(
|
||||
paper_position_items + bitget_position_items,
|
||||
@ -437,7 +502,11 @@ async def get_console_snapshot():
|
||||
"crypto_agent": crypto_status,
|
||||
"execution_events": execution_events,
|
||||
"signals": {
|
||||
"stats_7d": signal_stats,
|
||||
"stats_7d": {
|
||||
"crypto": signal_stats.get("crypto", {"total": 0, "buy": 0, "sell": 0, "recent_24h": 0}),
|
||||
"grades": signal_stats.get("grades", {}),
|
||||
"total": signal_stats.get("crypto", {}).get("total", 0),
|
||||
},
|
||||
"latest": latest_signals,
|
||||
"recent_30m_count": recent_signal_count,
|
||||
},
|
||||
|
||||
@ -1,22 +0,0 @@
|
||||
"""
|
||||
A 股板块异动监控 Agent
|
||||
提供 Tushare 数据源版本
|
||||
"""
|
||||
from .sector_monitor import SectorMonitor
|
||||
from .tushare_client import TushareClient, get_tushare_client
|
||||
from .tushare_sector_analyzer import TushareSectorAnalyzer
|
||||
from .tushare_stock_selector import TushareStockSelector
|
||||
from .short_term_thematic_selector import ShortTermThematicSelector, get_thematic_selector
|
||||
from .astock_agent import AStockAgent, get_astock_agent
|
||||
|
||||
__all__ = [
|
||||
'SectorMonitor',
|
||||
'TushareClient',
|
||||
'get_tushare_client',
|
||||
'TushareSectorAnalyzer',
|
||||
'TushareStockSelector',
|
||||
'ShortTermThematicSelector',
|
||||
'get_thematic_selector',
|
||||
'AStockAgent',
|
||||
'get_astock_agent',
|
||||
]
|
||||
@ -1,234 +0,0 @@
|
||||
"""
|
||||
Akshare 数据封装
|
||||
提供 A 股板块、个股行情数据获取接口
|
||||
支持概念板块和行业板块
|
||||
"""
|
||||
import os
|
||||
import time
|
||||
import akshare as ak
|
||||
import pandas as pd
|
||||
from typing import Dict, List, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from app.utils.logger import logger
|
||||
|
||||
|
||||
# 禁用全局代理设置
|
||||
os.environ.pop('HTTP_PROXY', None)
|
||||
os.environ.pop('HTTPS_PROXY', None)
|
||||
os.environ.pop('http_proxy', None)
|
||||
os.environ.pop('https_proxy', None)
|
||||
|
||||
# Monkey patch requests 以禁用代理
|
||||
import requests
|
||||
_original_session_init = requests.Session.__init__
|
||||
|
||||
|
||||
def _patched_session_init(self, *args, **kwargs):
|
||||
_original_session_init(self, *args, **kwargs)
|
||||
self.trust_env = False
|
||||
self.proxies = {'http': None, 'https': None}
|
||||
|
||||
|
||||
requests.Session.__init__ = _patched_session_init
|
||||
|
||||
|
||||
class AkshareClient:
|
||||
"""Akshare 数据客户端"""
|
||||
|
||||
# 缓存数据,避免频繁请求
|
||||
_cache = {}
|
||||
_cache_time = {}
|
||||
_last_request_time = 0
|
||||
|
||||
def __init__(self):
|
||||
"""初始化客户端"""
|
||||
self.cache_ttl = 60 # 缓存60秒
|
||||
self.request_delay = 1.0 # 请求间隔(秒)
|
||||
self.max_retries = 3 # 最大重试次数
|
||||
|
||||
def _get_cached(self, key: str, fetch_func) -> pd.DataFrame:
|
||||
"""获取缓存数据,支持重试"""
|
||||
now = datetime.now()
|
||||
|
||||
# 检查缓存
|
||||
if key in self._cache:
|
||||
cache_time = self._cache_time.get(key)
|
||||
if cache_time and (now - cache_time).seconds < self.cache_ttl:
|
||||
logger.debug(f"使用缓存数据: {key}")
|
||||
return self._cache[key]
|
||||
|
||||
# 请求限流
|
||||
elapsed = now.timestamp() - self._last_request_time
|
||||
if elapsed < self.request_delay:
|
||||
time.sleep(self.request_delay - elapsed)
|
||||
|
||||
# 重试逻辑
|
||||
last_error = None
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
self._last_request_time = time.time()
|
||||
df = fetch_func()
|
||||
|
||||
if df is not None and not df.empty:
|
||||
self._cache[key] = df
|
||||
self._cache_time[key] = now
|
||||
logger.debug(f"获取数据成功: {key}")
|
||||
return df
|
||||
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
error_msg = str(e)
|
||||
|
||||
# 判断错误类型
|
||||
if 'Connection' in error_msg or 'RemoteDisconnected' in error_msg:
|
||||
# 连接错误,指数退避重试
|
||||
if attempt < self.max_retries - 1:
|
||||
wait_time = (2 ** attempt) * 2 # 2, 4, 8秒
|
||||
logger.warning(
|
||||
f"获取数据失败 {key} (尝试 {attempt + 1}/{self.max_retries}): {e},"
|
||||
f"等待 {wait_time}秒后重试..."
|
||||
)
|
||||
time.sleep(wait_time)
|
||||
continue
|
||||
|
||||
# 其他错误或重试次数用尽
|
||||
logger.error(f"获取数据失败 {key}: {e}")
|
||||
break
|
||||
|
||||
return pd.DataFrame()
|
||||
|
||||
def get_concept_spot(self) -> pd.DataFrame:
|
||||
"""
|
||||
获取概念板块行情(实时)
|
||||
|
||||
Returns:
|
||||
概念板块行情数据
|
||||
"""
|
||||
def fetch():
|
||||
# stock_board_concept_name_em - 东方财富概念板块行情
|
||||
return ak.stock_board_concept_name_em()
|
||||
|
||||
return self._get_cached('concept_spot', fetch)
|
||||
|
||||
def get_industry_spot(self) -> pd.DataFrame:
|
||||
"""
|
||||
获取行业板块行情(实时)
|
||||
|
||||
Returns:
|
||||
行业板块行情数据
|
||||
"""
|
||||
def fetch():
|
||||
# stock_board_industry_name_em - 东方财富行业板块行情
|
||||
return ak.stock_board_industry_name_em()
|
||||
|
||||
return self._get_cached('industry_spot', fetch)
|
||||
|
||||
def get_concept_stocks(self, sector_name: str) -> pd.DataFrame:
|
||||
"""
|
||||
获取概念板块成分股
|
||||
|
||||
Args:
|
||||
sector_name: 板块名称
|
||||
|
||||
Returns:
|
||||
成分股数据
|
||||
"""
|
||||
def fetch():
|
||||
# stock_board_concept_cons_em - 概念板块成分股
|
||||
df = ak.stock_board_concept_cons_em(symbol=sector_name)
|
||||
return df if df is not None else pd.DataFrame()
|
||||
|
||||
return self._get_cached(f'concept_stocks_{sector_name}', fetch)
|
||||
|
||||
def get_industry_stocks(self, sector_name: str) -> pd.DataFrame:
|
||||
"""
|
||||
获取行业板块成分股
|
||||
|
||||
Args:
|
||||
sector_name: 板块名称
|
||||
|
||||
Returns:
|
||||
成分股数据
|
||||
"""
|
||||
def fetch():
|
||||
# stock_board_industry_cons_em - 行业板块成分股
|
||||
df = ak.stock_board_industry_cons_em(symbol=sector_name)
|
||||
return df if df is not None else pd.DataFrame()
|
||||
|
||||
return self._get_cached(f'industry_stocks_{sector_name}', fetch)
|
||||
|
||||
def get_stock_spot(self) -> pd.DataFrame:
|
||||
"""
|
||||
获取 A 股实时行情
|
||||
|
||||
Returns:
|
||||
A 股实时行情数据
|
||||
"""
|
||||
def fetch():
|
||||
return ak.stock_zh_a_spot_em()
|
||||
|
||||
return self._get_cached('stock_spot', fetch)
|
||||
|
||||
def get_stock_fund_flow(self, symbol: str) -> pd.DataFrame:
|
||||
"""
|
||||
获取个股资金流向
|
||||
|
||||
Args:
|
||||
symbol: 股票代码
|
||||
|
||||
Returns:
|
||||
资金流向数据
|
||||
"""
|
||||
def fetch():
|
||||
return ak.stock_individual_fund_flow(
|
||||
stock=symbol,
|
||||
market="sh" if symbol.startswith('6') else "sz"
|
||||
)
|
||||
|
||||
return self._get_cached(f'fund_flow_{symbol}', fetch)
|
||||
|
||||
def get_stock_info(self, symbol: str) -> Dict:
|
||||
"""
|
||||
获取个股基本信息
|
||||
|
||||
Args:
|
||||
symbol: 股票代码
|
||||
|
||||
Returns:
|
||||
股票信息字典
|
||||
"""
|
||||
try:
|
||||
info = ak.stock_individual_info_em(symbol=symbol)
|
||||
return {
|
||||
'name': info.get('股票简称', ''),
|
||||
'industry': info.get('行业', ''),
|
||||
'market_cap': info.get('总市值', ''),
|
||||
'float_cap': info.get('流通市值', ''),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取股票信息失败 {symbol}: {e}")
|
||||
return {}
|
||||
|
||||
def get_limit_list_stocks(self) -> pd.DataFrame:
|
||||
"""
|
||||
获取涨停板股票
|
||||
|
||||
Returns:
|
||||
涨停板股票列表
|
||||
"""
|
||||
def fetch():
|
||||
return ak.stock_zt_pool_em(date=datetime.now().strftime('%Y%m%d'))
|
||||
|
||||
return self._get_cached('limit_list', fetch)
|
||||
|
||||
|
||||
# 全局单例
|
||||
_akshare_client: Optional[AkshareClient] = None
|
||||
|
||||
|
||||
def get_akshare_client() -> AkshareClient:
|
||||
"""获取 Akshare 客户端单例"""
|
||||
global _akshare_client
|
||||
if _akshare_client is None:
|
||||
_akshare_client = AkshareClient()
|
||||
return _akshare_client
|
||||
@ -1,206 +0,0 @@
|
||||
"""
|
||||
A股智能体 - 主控制器
|
||||
负责执行每日选股并发送通知
|
||||
"""
|
||||
import asyncio
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime, time
|
||||
from app.utils.logger import logger
|
||||
from app.config import get_settings
|
||||
from app.services.dingtalk_service import get_dingtalk_service
|
||||
from app.services.telegram_service import get_telegram_service
|
||||
from app.astock_agent.tushare_client import get_tushare_client
|
||||
from app.astock_agent.short_term_thematic_selector import get_thematic_selector
|
||||
|
||||
|
||||
class AStockAgent:
|
||||
"""A股智能体"""
|
||||
|
||||
_instance = None
|
||||
_initialized = False
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
"""单例模式"""
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
"""初始化智能体"""
|
||||
if AStockAgent._initialized:
|
||||
return
|
||||
|
||||
AStockAgent._initialized = True
|
||||
self.settings = get_settings()
|
||||
|
||||
# 初始化Tushare客户端
|
||||
self.ts_client = get_tushare_client(self.settings.tushare_token)
|
||||
if not self.ts_client:
|
||||
logger.error("Tushare客户端初始化失败,请检查配置")
|
||||
raise Exception("Tushare客户端初始化失败")
|
||||
|
||||
# 初始化选股器
|
||||
self.selector = get_thematic_selector(self.ts_client)
|
||||
|
||||
# 初始化通知服务
|
||||
self.dingtalk = get_dingtalk_service()
|
||||
self.telegram = get_telegram_service()
|
||||
|
||||
# 运行状态
|
||||
self.running = False
|
||||
self._task = None
|
||||
|
||||
logger.info("A股智能体初始化完成")
|
||||
|
||||
async def run_once(self) -> Dict[str, Any]:
|
||||
"""
|
||||
执行一次选股
|
||||
|
||||
Returns:
|
||||
选股结果
|
||||
"""
|
||||
try:
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("📊 开始执行短期题材选股")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# 执行选股
|
||||
result = self.selector.select_stocks(max_stocks=10)
|
||||
|
||||
# 输出日志
|
||||
self._log_result(result)
|
||||
|
||||
# 发送通知
|
||||
await self._send_notifications(result)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"选股执行失败: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return {}
|
||||
|
||||
def _log_result(self, result: Dict[str, Any]):
|
||||
"""输出选股结果到日志"""
|
||||
if not result or result.get('total_stocks', 0) == 0:
|
||||
logger.info("\n📊 今日未选出符合条件的股票")
|
||||
return
|
||||
|
||||
logger.info(f"\n📊 选股完成,共选出 {result['total_stocks']} 只股票")
|
||||
|
||||
if result.get('summary'):
|
||||
summary = result['summary']
|
||||
logger.info(f" - 总仓位: {summary.get('position_percent', 0):.1f}%")
|
||||
logger.info(f" - 涉及板块: {summary.get('sector_count', 0)} 个")
|
||||
|
||||
for stock in result.get('stocks', []):
|
||||
logger.info(f" - {stock['name']}({stock['ts_code']}): {stock['close']:.2f}元, "
|
||||
f"仓位:{stock['position']*100:.1f}%, 评分:{stock['score']:.1f}分")
|
||||
|
||||
async def _send_notifications(self, result: Dict[str, Any]):
|
||||
"""发送选股通知"""
|
||||
try:
|
||||
# 格式化输出文本
|
||||
text = self.selector.format_output_text(result)
|
||||
|
||||
# 发送到钉钉
|
||||
if self.settings.dingtalk_enabled:
|
||||
await self.dingtalk.send_markdown(
|
||||
"📊 短期题材选股结果",
|
||||
text
|
||||
)
|
||||
logger.info("✅ 钉钉通知已发送")
|
||||
|
||||
# 发送到Telegram
|
||||
if self.settings.telegram_enabled:
|
||||
await self.telegram.send_message(text)
|
||||
logger.info("✅ Telegram通知已发送")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发送通知失败: {e}")
|
||||
|
||||
async def run_daily(self, run_time: str = "15:30"):
|
||||
"""
|
||||
每日定时运行
|
||||
|
||||
Args:
|
||||
run_time: 运行时间(HH:MM格式,24小时制)
|
||||
"""
|
||||
self.running = True
|
||||
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("🚀 A股智能体已启动")
|
||||
logger.info(f"⏰ 运行时间: 每天 {run_time}(盘后)")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# 解析运行时间
|
||||
hour, minute = map(int, run_time.split(':'))
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
# 计算下次运行时间
|
||||
now = datetime.now()
|
||||
next_run = now.replace(
|
||||
hour=hour,
|
||||
minute=minute,
|
||||
second=0,
|
||||
microsecond=0
|
||||
)
|
||||
|
||||
# 如果今天的运行时间已过,设置为明天
|
||||
if now >= next_run:
|
||||
from datetime import timedelta
|
||||
next_run = next_run + timedelta(days=1)
|
||||
|
||||
wait_seconds = (next_run - now).total_seconds()
|
||||
|
||||
logger.info(f"⏳ 等待下次运行: {next_run.strftime('%Y-%m-%d %H:%M:%S')} "
|
||||
f"(等待 {wait_seconds/3600:.1f} 小时)")
|
||||
|
||||
# 等待到运行时间
|
||||
await asyncio.sleep(wait_seconds)
|
||||
|
||||
# 执行选股
|
||||
await self.run_once()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"定时运行出错: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
# 等待1小时后重试
|
||||
await asyncio.sleep(3600)
|
||||
|
||||
def stop(self):
|
||||
"""停止运行"""
|
||||
self.running = False
|
||||
logger.info("A股智能体已停止")
|
||||
|
||||
|
||||
# 全局单例
|
||||
_astock_agent: Optional[AStockAgent] = None
|
||||
|
||||
|
||||
def get_astock_agent() -> AStockAgent:
|
||||
"""获取A股智能体单例"""
|
||||
global _astock_agent
|
||||
if _astock_agent is None:
|
||||
_astock_agent = AStockAgent()
|
||||
return _astock_agent
|
||||
|
||||
|
||||
async def main():
|
||||
"""测试入口"""
|
||||
agent = get_astock_agent()
|
||||
|
||||
# 执行一次选股
|
||||
result = await agent.run_once()
|
||||
|
||||
# 输出结果
|
||||
print("\n" + "=" * 60)
|
||||
print(agent.selector.format_output_text(result))
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@ -1,338 +0,0 @@
|
||||
"""
|
||||
钉钉通知模块
|
||||
格式化并发送板块异动通知
|
||||
"""
|
||||
import json
|
||||
import hmac
|
||||
import hashlib
|
||||
import base64
|
||||
import time
|
||||
import requests
|
||||
from typing import Dict, List
|
||||
from datetime import datetime
|
||||
from urllib.parse import quote
|
||||
from app.utils.logger import logger
|
||||
|
||||
|
||||
class DingTalkNotifier:
|
||||
"""钉钉通知器"""
|
||||
|
||||
def __init__(self, webhook: str, secret: str = None):
|
||||
"""
|
||||
初始化通知器
|
||||
|
||||
Args:
|
||||
webhook: 钉钉机器人 Webhook URL
|
||||
secret: 加签密钥(可选)
|
||||
"""
|
||||
self.webhook = webhook
|
||||
self.secret = secret
|
||||
|
||||
def _sign(self, timestamp: int) -> str:
|
||||
"""
|
||||
生成签名
|
||||
|
||||
Args:
|
||||
timestamp: 时间戳(毫秒)
|
||||
|
||||
Returns:
|
||||
签名字符串
|
||||
"""
|
||||
if not self.secret:
|
||||
return ""
|
||||
|
||||
secret_enc = self.secret.encode('utf-8')
|
||||
string_to_sign = f'{timestamp}\n{self.secret}'
|
||||
string_to_sign_enc = string_to_sign.encode('utf-8')
|
||||
hmac_code = hmac.new(secret_enc, string_to_sign_enc, digestmod=hashlib.sha256).digest()
|
||||
sign = base64.b64encode(hmac_code).decode('utf-8')
|
||||
return sign
|
||||
|
||||
def _build_url(self) -> str:
|
||||
"""
|
||||
构建带签名的 Webhook URL
|
||||
|
||||
Returns:
|
||||
完整的 Webhook URL
|
||||
"""
|
||||
if not self.secret:
|
||||
return self.webhook
|
||||
|
||||
timestamp = int(time.time() * 1000)
|
||||
sign = self._sign(timestamp)
|
||||
sign_encoded = quote(sign, safe='')
|
||||
|
||||
return f"{self.webhook}×tamp={timestamp}&sign={sign_encoded}"
|
||||
|
||||
def send_sector_alert(self, sector_data: Dict, top_stocks: List[Dict], reason: str = "") -> bool:
|
||||
"""
|
||||
发送板块异动提醒
|
||||
|
||||
Args:
|
||||
sector_data: 板块数据
|
||||
top_stocks: 龙头股列表
|
||||
reason: 异动原因
|
||||
|
||||
Returns:
|
||||
是否发送成功
|
||||
"""
|
||||
try:
|
||||
# 构建消息卡片
|
||||
card = self._format_sector_card(sector_data, top_stocks, reason)
|
||||
|
||||
# 构建请求数据
|
||||
data = {
|
||||
"msgtype": "markdown",
|
||||
"markdown": {
|
||||
"title": f"🔥 {sector_data['name']} 异动提醒",
|
||||
"text": card
|
||||
}
|
||||
}
|
||||
|
||||
# 构建带签名的 URL
|
||||
url = self._build_url()
|
||||
|
||||
# 发送请求
|
||||
headers = {"Content-Type": "application/json;charset=utf-8"}
|
||||
response = requests.post(
|
||||
url,
|
||||
data=json.dumps(data),
|
||||
headers=headers,
|
||||
timeout=10
|
||||
)
|
||||
|
||||
result = response.json()
|
||||
if result.get("errcode") == 0:
|
||||
logger.info(f"钉钉通知发送成功: {sector_data['name']}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"钉钉通知发送失败: {result}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发送钉钉通知异常: {e}")
|
||||
return False
|
||||
|
||||
def _format_sector_card(self, sector_data: Dict, top_stocks: List[Dict], reason: str) -> str:
|
||||
"""
|
||||
格式化板块异动卡片
|
||||
|
||||
Args:
|
||||
sector_data: 板块数据
|
||||
top_stocks: 龙头股列表
|
||||
reason: 异动原因
|
||||
|
||||
Returns:
|
||||
Markdown 格式的消息内容
|
||||
"""
|
||||
lines = []
|
||||
|
||||
# 标题
|
||||
lines.append("### 🔥 A股板块异动提醒")
|
||||
lines.append("")
|
||||
|
||||
# 基本信息
|
||||
change_pct = sector_data['change_pct']
|
||||
change_icon = "📈" if change_pct > 0 else "📉"
|
||||
lines.append(f"**异动板块**: {sector_data['name']} {change_icon} {change_pct:+.2f}%")
|
||||
lines.append(f"**异动时间**: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
lines.append(f"**异动类型**: 涨幅突增 | {reason if reason else '资金集中流入'}")
|
||||
lines.append("")
|
||||
|
||||
# 板块概况
|
||||
lines.append("#### 📊 板块概况")
|
||||
lines.append(f"- 涨幅: {change_pct:+.2f}%")
|
||||
lines.append(f"- 涨跌额: {sector_data.get('change_amount', 0):+.2f}")
|
||||
|
||||
if sector_data.get('amount', 0) > 0:
|
||||
amount = sector_data['amount']
|
||||
if amount >= 100000:
|
||||
amount_str = f"{amount/100000:.1f}亿"
|
||||
else:
|
||||
amount_str = f"{amount/10000:.1f}万"
|
||||
lines.append(f"- 成交额: {amount_str}")
|
||||
|
||||
if sector_data.get('leading_stock'):
|
||||
lines.append(f"- 领涨股: {sector_data['leading_stock']}")
|
||||
|
||||
lines.append("")
|
||||
|
||||
# 龙头股
|
||||
if top_stocks:
|
||||
lines.append("#### 🏆 龙头股 Top " + str(len(top_stocks)))
|
||||
lines.append("")
|
||||
|
||||
for idx, stock in enumerate(top_stocks, 1):
|
||||
# 价格格式化
|
||||
price = stock['price']
|
||||
change_pct = stock['change_pct']
|
||||
|
||||
# 涨跌幅图标
|
||||
if change_pct >= 9.9:
|
||||
change_icon = "🚀"
|
||||
elif change_pct >= 5:
|
||||
change_icon = "⚡"
|
||||
elif change_pct > 0:
|
||||
change_icon = "📈"
|
||||
elif change_pct > -3:
|
||||
change_icon = "➖"
|
||||
else:
|
||||
change_icon = "📉"
|
||||
|
||||
lines.append(f"**{idx}. {stock['name']} ({stock['code']})**")
|
||||
lines.append(f" 现价: ¥{price:.2f} ({change_icon} {change_pct:+.2f}%)")
|
||||
lines.append(f" 成交额: {self._format_amount(stock['amount'])}")
|
||||
lines.append(f" 换手率: {stock['turnover']:.2f}%")
|
||||
lines.append(f" 涨速: {stock['speed_level']}")
|
||||
|
||||
if stock.get('volume_ratio', 1) > 2:
|
||||
lines.append(f" 量比: {stock['volume_ratio']:.2f} 🔥")
|
||||
|
||||
lines.append("")
|
||||
|
||||
lines.append("---")
|
||||
lines.append(f"📊 综合评分: {top_stocks[0]['score']:.1f}分")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _format_amount(self, amount: float) -> str:
|
||||
"""
|
||||
格式化成交额
|
||||
|
||||
Args:
|
||||
amount: 成交额(元)
|
||||
|
||||
Returns:
|
||||
格式化后的字符串
|
||||
"""
|
||||
if amount >= 100000000:
|
||||
return f"{amount/100000000:.2f}亿"
|
||||
elif amount >= 10000:
|
||||
return f"{amount/10000:.2f}万"
|
||||
else:
|
||||
return f"{amount:.0f}元"
|
||||
|
||||
def send_summary(self, total_sectors: int, total_stocks: int) -> bool:
|
||||
"""
|
||||
发送监控汇总
|
||||
|
||||
Args:
|
||||
total_sectors: 异动板块总数
|
||||
total_stocks: 龙头股总数
|
||||
|
||||
Returns:
|
||||
是否发送成功
|
||||
"""
|
||||
try:
|
||||
lines = []
|
||||
lines.append("### 📋 A股板块监控汇总")
|
||||
lines.append("")
|
||||
lines.append(f"**监控时间**: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
lines.append("")
|
||||
lines.append("#### 📊 今日统计")
|
||||
lines.append(f"- 异动板块: {total_sectors} 个")
|
||||
lines.append(f"- 龙头股: {total_stocks} 只")
|
||||
lines.append("")
|
||||
lines.append("---")
|
||||
lines.append(f"⏰ 下次更新: {datetime.now().strftime('%H:%M')}")
|
||||
|
||||
card = "\n".join(lines)
|
||||
|
||||
# 构建请求数据
|
||||
data = {
|
||||
"msgtype": "markdown",
|
||||
"markdown": {
|
||||
"title": "📋 A股板块监控汇总",
|
||||
"text": card
|
||||
}
|
||||
}
|
||||
|
||||
# 构建带签名的 URL
|
||||
url = self._build_url()
|
||||
|
||||
# 发送请求
|
||||
headers = {"Content-Type": "application/json;charset=utf-8"}
|
||||
response = requests.post(
|
||||
url,
|
||||
data=json.dumps(data),
|
||||
headers=headers,
|
||||
timeout=10
|
||||
)
|
||||
|
||||
result = response.json()
|
||||
if result.get("errcode") == 0:
|
||||
logger.info(f"钉钉汇总发送成功")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"钉钉汇总发送失败: {result}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发送钉钉汇总异常: {e}")
|
||||
return False
|
||||
|
||||
def send_error(self, error_msg: str) -> bool:
|
||||
"""
|
||||
发送错误通知
|
||||
|
||||
Args:
|
||||
error_msg: 错误信息
|
||||
|
||||
Returns:
|
||||
是否发送成功
|
||||
"""
|
||||
try:
|
||||
lines = []
|
||||
lines.append("### ❌ A股板块监控异常")
|
||||
lines.append("")
|
||||
lines.append(f"**时间**: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
lines.append("")
|
||||
lines.append(f"```\n{error_msg}\n```")
|
||||
|
||||
card = "\n".join(lines)
|
||||
|
||||
# 构建请求数据
|
||||
data = {
|
||||
"msgtype": "markdown",
|
||||
"markdown": {
|
||||
"title": "❌ A股板块监控异常",
|
||||
"text": card
|
||||
}
|
||||
}
|
||||
|
||||
# 构建带签名的 URL
|
||||
url = self._build_url()
|
||||
|
||||
# 发送请求
|
||||
headers = {"Content-Type": "application/json;charset=utf-8"}
|
||||
response = requests.post(
|
||||
url,
|
||||
data=json.dumps(data),
|
||||
headers=headers,
|
||||
timeout=10
|
||||
)
|
||||
|
||||
result = response.json()
|
||||
return result.get("errcode") == 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发送错误通知异常: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# 全局单例
|
||||
_notifier: DingTalkNotifier = None
|
||||
|
||||
|
||||
def get_dingtalk_notifier() -> DingTalkNotifier:
|
||||
"""获取钉钉通知器单例"""
|
||||
global _notifier
|
||||
if _notifier is None:
|
||||
from app.config import get_settings
|
||||
settings = get_settings()
|
||||
# 优先使用A股专用配置,否则使用通用配置
|
||||
webhook = settings.dingtalk_astock_webhook or settings.dingtalk_webhook_url
|
||||
secret = settings.dingtalk_astock_secret or settings.dingtalk_secret
|
||||
if webhook:
|
||||
_notifier = DingTalkNotifier(webhook, secret)
|
||||
return _notifier
|
||||
@ -1,373 +0,0 @@
|
||||
"""
|
||||
A股龙回头选股器
|
||||
策略:热门板块 + MA多头排列 + 量价配合龙回头
|
||||
执行时间:每天盘前 9:00
|
||||
|
||||
【选股条件】
|
||||
1. MA 多头排列:MA5 > MA10 > MA30
|
||||
2. 从近期高点回调 2-20%
|
||||
3. 回调期间缩量(成交量 < 上涨期的 80%)
|
||||
4. 接近 MA10/30/60 支撑位(±10%以内)
|
||||
5. 近期再度放量(最近2天 > 回调期的 1.2倍)
|
||||
"""
|
||||
import pandas as pd
|
||||
from typing import Dict, List, Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from app.utils.logger import logger
|
||||
|
||||
|
||||
class PullbackStockSelector:
|
||||
"""龙回头选股器"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化选股器"""
|
||||
try:
|
||||
import tushare as ts
|
||||
self.ts = ts
|
||||
# 从配置获取 token
|
||||
from app.config import get_settings
|
||||
self.settings = get_settings()
|
||||
self.pro = ts.pro_api(self.settings.tushare_token)
|
||||
logger.info("龙回头选股器初始化成功")
|
||||
except ImportError:
|
||||
logger.error("tushare 未安装")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"龙回头选股器初始化失败: {e}")
|
||||
raise
|
||||
|
||||
def get_hot_concepts(self, limit: int = 10) -> List[tuple]:
|
||||
"""
|
||||
从 Tushare 获取热门概念板块
|
||||
|
||||
使用 ths_index 接口获取同花顺概念板块列表
|
||||
|
||||
Args:
|
||||
limit: 返回板块数量
|
||||
|
||||
Returns:
|
||||
[(概念代码, 概念名称), ...]
|
||||
"""
|
||||
try:
|
||||
# 使用 ths_index 获取同花顺概念板块
|
||||
index_df = self.pro.ths_index(
|
||||
market='A', # A股市场
|
||||
fields='ts_code,name'
|
||||
)
|
||||
|
||||
if index_df.empty:
|
||||
logger.warning("未能获取概念板块列表,使用备用列表")
|
||||
return self._get_fallback_sectors()
|
||||
|
||||
# 筛选科技相关的热门概念(关键词匹配)
|
||||
tech_keywords = ['人工智能', '芯片', '半导体', '新能源', '汽车', '云计算',
|
||||
'网络安全', '软件', '数字', '5G', '锂电', '光伏', 'AI', '科技']
|
||||
|
||||
filtered = []
|
||||
for _, row in index_df.iterrows():
|
||||
concept_name = row['name']
|
||||
concept_code = row['ts_code']
|
||||
|
||||
# 检查是否包含关键词
|
||||
for keyword in tech_keywords:
|
||||
if keyword in concept_name:
|
||||
filtered.append((concept_code, concept_name))
|
||||
break
|
||||
|
||||
if len(filtered) >= limit * 2: # 多获取一些,后续筛选
|
||||
break
|
||||
|
||||
return filtered[:limit] if filtered else self._get_fallback_sectors()
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"获取热门概念失败: {e},使用备用列表")
|
||||
return self._get_fallback_sectors()
|
||||
|
||||
def _get_fallback_sectors(self) -> List[tuple]:
|
||||
"""备用板块列表(使用已验证的板块代码)"""
|
||||
return [
|
||||
('884031.TI', '人工智能'),
|
||||
('884065.TI', '新能源汽车'),
|
||||
('884039.TI', '云计算'),
|
||||
('884145.TI', '国产软件'),
|
||||
('884192.TI', '5G概念'),
|
||||
]
|
||||
|
||||
def select_from_sector(self, sector_code: str, sector_name: str, max_stocks: int = 5) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
从指定板块选出龙回头股票
|
||||
|
||||
Args:
|
||||
sector_code: 板块代码
|
||||
sector_name: 板块名称
|
||||
max_stocks: 最多返回股票数
|
||||
|
||||
Returns:
|
||||
符合条件的股票列表
|
||||
"""
|
||||
try:
|
||||
# 1. 获取板块成分股
|
||||
members_df = self.pro.ths_member(
|
||||
ts_code=sector_code,
|
||||
fields='ts_code,name,con_code,name'
|
||||
)
|
||||
if members_df.empty:
|
||||
logger.warning(f"板块 {sector_name}({sector_code}) 无成分股数据,可能是板块代码不正确或该板块已下线")
|
||||
return []
|
||||
|
||||
stock_codes = members_df['con_code'].tolist()
|
||||
logger.info(f"板块 {sector_name} 共 {len(stock_codes)} 只成分股")
|
||||
|
||||
# 2. 逐个检查股票
|
||||
selected_stocks = []
|
||||
for i, stock_code in enumerate(stock_codes[:50]): # 最多检查前50只
|
||||
result = self._check_stock(stock_code)
|
||||
if result:
|
||||
result['sector_name'] = sector_name
|
||||
selected_stocks.append(result)
|
||||
logger.info(f" ✓ 找到: {result['name']}({stock_code}) - 回踩 {result['pullback_pct']:.2f}%")
|
||||
|
||||
if len(selected_stocks) >= max_stocks:
|
||||
break
|
||||
|
||||
return selected_stocks
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"从板块选股失败 {sector_name}({sector_code}): {e}")
|
||||
return []
|
||||
|
||||
def _check_stock(self, stock_code: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
检查单只股票是否符合量价配合龙回头条件
|
||||
|
||||
【新策略】量价配合龙回头:
|
||||
1. MA 多头排列:MA5 > MA10 > MA30(上涨趋势)
|
||||
2. 从近期高点回调(寻找龙回头机会)
|
||||
3. 回调期间缩量(成交量明显萎缩,主力未离场)
|
||||
4. 接近 MA10/30/60 支撑位
|
||||
5. 近期有再度放量迹象(资金重新入场)
|
||||
|
||||
Returns:
|
||||
符合条件返回股票信息,否则返回 None
|
||||
"""
|
||||
try:
|
||||
# 获取最近150天的日线数据
|
||||
end_date = datetime.now().strftime('%Y%m%d')
|
||||
start_date = (datetime.now() - timedelta(days=150)).strftime('%Y%m%d')
|
||||
|
||||
df = self.pro.daily(
|
||||
ts_code=stock_code,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
|
||||
if df.empty or len(df) < 70:
|
||||
return None
|
||||
|
||||
df = df.sort_values('trade_date').reset_index(drop=True)
|
||||
df = df.tail(90).reset_index(drop=True) # 取最近90天
|
||||
|
||||
close = df['close']
|
||||
volume = df['vol'] # 成交量(手)
|
||||
|
||||
# ========== 1. 计算 MA 均线 ==========
|
||||
ma5 = close.rolling(window=5).mean()
|
||||
ma10 = close.rolling(window=10).mean()
|
||||
ma30 = close.rolling(window=30).mean()
|
||||
ma60 = close.rolling(window=60).mean()
|
||||
|
||||
latest = df.iloc[-1]
|
||||
latest_close = latest['close']
|
||||
latest_ma5 = ma5.iloc[-1]
|
||||
latest_ma10 = ma10.iloc[-1]
|
||||
latest_ma30 = ma30.iloc[-1]
|
||||
latest_ma60 = ma60.iloc[-1]
|
||||
|
||||
# 条件1:MA 多头排列 MA5 > MA10 > MA30
|
||||
if not (latest_ma5 > latest_ma10 > latest_ma30):
|
||||
logger.debug(f" ✗ {stock_code}: MA非多头排列 (MA5:{latest_ma5:.2f}, MA10:{latest_ma10:.2f}, MA30:{latest_ma30:.2f})")
|
||||
return None
|
||||
|
||||
# ========== 2. 找近期高点(最近20天内最高价) ==========
|
||||
lookback_high = 20
|
||||
recent_df = df.tail(lookback_high).reset_index(drop=True)
|
||||
high_idx_in_recent = recent_df['close'].idxmax()
|
||||
high_price = recent_df.loc[high_idx_in_recent, 'close']
|
||||
high_idx_in_df = df.index[-lookback_high] + high_idx_in_recent
|
||||
|
||||
# 计算从高点的回踩幅度
|
||||
pullback_pct = (high_price - latest_close) / high_price * 100
|
||||
|
||||
# 条件2:回踩幅度 2-20%(放宽)
|
||||
if not (2 <= pullback_pct <= 20):
|
||||
logger.debug(f" ✗ {stock_code}: 回踩幅度不符合 (回踩:{pullback_pct:.2f}%, 需要2-20%)")
|
||||
return None
|
||||
|
||||
# ========== 3. 量能形态分析:放量上涨→缩量回调→再度放量 ==========
|
||||
|
||||
# 上涨期间的成交量(从低点到高点)
|
||||
rise_period = df.loc[:high_idx_in_df]
|
||||
rise_volume_avg = rise_period['vol'].mean()
|
||||
|
||||
# 回调期间的成交量(从高点到现在,最近5天)
|
||||
pullback_start_idx = min(high_idx_in_df + 1, len(df) - 1)
|
||||
pullback_period = df.loc[pullback_start_idx:][-5:] # 回调期间最近5天
|
||||
pullback_volume_avg = pullback_period['vol'].mean()
|
||||
|
||||
# 条件3:回调期间缩量(成交量 < 上涨期间的 80%,放宽)
|
||||
volume_shrink_ratio = pullback_volume_avg / rise_volume_avg if rise_volume_avg > 0 else 1
|
||||
if volume_shrink_ratio >= 0.8:
|
||||
logger.debug(f" ✗ {stock_code}: 回调未缩量 (缩量比:{volume_shrink_ratio:.2%}, 需要<80%)")
|
||||
return None
|
||||
|
||||
# 条件4:最近2天再度放量(比回调期间平均成交量增加 20%+,放宽)
|
||||
recent_2_days_volume = df.tail(2)['vol'].mean()
|
||||
if recent_2_days_volume < pullback_volume_avg * 1.2:
|
||||
logger.debug(f" ✗ {stock_code}: 未再度放量 (最近2天/回调期:{recent_2_days_volume/pullback_volume_avg:.2f}, 需要>1.2)")
|
||||
return None
|
||||
|
||||
# ========== 4. 接近 MA 支撑位 ==========
|
||||
ma10_diff = abs(latest_close - latest_ma10) / latest_ma10 * 100
|
||||
ma30_diff = abs(latest_close - latest_ma30) / latest_ma30 * 100
|
||||
ma60_diff = abs(latest_close - latest_ma60) / latest_ma60 * 100
|
||||
|
||||
# 接近任意一条 MA 线(±10%以内,放宽)
|
||||
near_ma = (ma10_diff <= 10 or ma30_diff <= 10 or ma60_diff <= 10)
|
||||
|
||||
# 未跌破 MA60(允许跌破 10%,放宽)
|
||||
above_ma60 = latest_close >= latest_ma60 * 0.90
|
||||
|
||||
if not (near_ma and above_ma60):
|
||||
logger.debug(f" ✗ {stock_code}: 未接近MA支撑或已跌破MA60 (MA10差:{ma10_diff:.2f}%, MA30:{ma30_diff:.2f}%, MA60:{ma60_diff:.2f}%)")
|
||||
return None
|
||||
|
||||
# ========== 5. 计算前期涨幅 ==========
|
||||
rise_data = df.loc[:high_idx_in_df]
|
||||
low_price = rise_data['low'].min()
|
||||
rise_pct = (high_price - low_price) / low_price * 100
|
||||
|
||||
# 涨幅需 > 10%
|
||||
if rise_pct < 10:
|
||||
logger.debug(f" ✗ {stock_code}: 涨幅不足 (涨幅:{rise_pct:.2f}%, 需要>10%)")
|
||||
return None
|
||||
|
||||
# 获取股票名称
|
||||
stock_info = self.pro.stock_basic(ts_code=stock_code, fields='ts_code,name')
|
||||
stock_name = stock_info.iloc[0]['name'] if not stock_info.empty else stock_code
|
||||
|
||||
return {
|
||||
'ts_code': stock_code,
|
||||
'name': stock_name,
|
||||
'close': latest_close,
|
||||
'high': high_price,
|
||||
'rise_pct': rise_pct,
|
||||
'pullback_pct': pullback_pct,
|
||||
'ma5': latest_ma5,
|
||||
'ma10': latest_ma10,
|
||||
'ma30': latest_ma30,
|
||||
'ma60': latest_ma60,
|
||||
'volume_shrink_ratio': volume_shrink_ratio,
|
||||
'recent_volume_ratio': recent_2_days_volume / pullback_volume_avg,
|
||||
'trade_date': latest['trade_date']
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"检查股票 {stock_code} 失败: {e}")
|
||||
return None
|
||||
|
||||
def select_from_hot_sectors(self, top_n: int = 5) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
从热门板块选股
|
||||
|
||||
Args:
|
||||
top_n: 选择前N个热门板块
|
||||
|
||||
Returns:
|
||||
{板块名称: [股票列表]}
|
||||
"""
|
||||
try:
|
||||
# 使用 Tushare API 动态获取热门板块
|
||||
hot_sectors = self.get_hot_concepts(limit=top_n * 2) # 多获取一些以备筛选
|
||||
|
||||
logger.info(f"开始龙回头选股,共 {len(hot_sectors)} 个热门板块")
|
||||
|
||||
results = {}
|
||||
checked_count = 0
|
||||
skipped_count = 0
|
||||
|
||||
for sector_code, sector_name in hot_sectors:
|
||||
if checked_count >= top_n:
|
||||
break
|
||||
|
||||
logger.info(f"检查板块: {sector_name}")
|
||||
|
||||
selected = self.select_from_sector(sector_code, sector_name, max_stocks=2)
|
||||
|
||||
if selected:
|
||||
results[sector_name] = selected
|
||||
logger.info(f" ✓ 板块 {sector_name} 选出 {len(selected)} 只")
|
||||
checked_count += 1
|
||||
else:
|
||||
logger.info(f" - 板块 {sector_name} 未选出")
|
||||
skipped_count += 1
|
||||
# 跳过无数据的板块,继续检查下一个
|
||||
if skipped_count >= 5: # 如果连续5个板块都没有数据,就停止
|
||||
logger.warning("连续多个板块无数据,停止检查")
|
||||
break
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"从热门板块选股失败: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return {}
|
||||
|
||||
def format_result(self, results: Dict[str, List[Dict[str, Any]]]) -> str:
|
||||
"""
|
||||
格式化选股结果
|
||||
|
||||
Args:
|
||||
results: 选股结果
|
||||
|
||||
Returns:
|
||||
格式化的文本
|
||||
"""
|
||||
if not results:
|
||||
return "今日未选出符合条件的龙回头股票"
|
||||
|
||||
lines = [
|
||||
"📊 **龙回头选股结果**",
|
||||
f"",
|
||||
f"选股时间: {datetime.now().strftime('%Y-%m-%d %H:%M')}",
|
||||
f"",
|
||||
]
|
||||
|
||||
total_stocks = 0
|
||||
for sector_name, stocks in results.items():
|
||||
lines.append(f"**🏢 {sector_name}**")
|
||||
for stock in stocks:
|
||||
total_stocks += 1
|
||||
lines.append(f" • {stock['name']}({stock['ts_code']})")
|
||||
lines.append(f" 现价: ¥{stock['close']:.2f} | 回踩: {stock['pullback_pct']:.2f}% | 涨幅: {stock['rise_pct']:.2f}%")
|
||||
lines.append(f" MA5: ¥{stock['ma5']:.2f} | MA10: ¥{stock['ma10']:.2f} | MA30: ¥{stock['ma30']:.2f} | MA60: ¥{stock['ma60']:.2f}")
|
||||
lines.append(f" 缩量比: {stock['volume_shrink_ratio']:.0%} | 再放量: {stock['recent_volume_ratio']:.2f}x")
|
||||
|
||||
lines.append("")
|
||||
lines.append(f"**共选出 {total_stocks} 只股票**")
|
||||
lines.append("")
|
||||
lines.append("*⚠️ 仅供参考,不构成投资建议*")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
# 全局单例
|
||||
_pullback_selector: Optional[PullbackStockSelector] = None
|
||||
|
||||
|
||||
def get_pullback_selector() -> PullbackStockSelector:
|
||||
"""获取龙回头选股器单例"""
|
||||
global _pullback_selector
|
||||
if _pullback_selector is None:
|
||||
_pullback_selector = PullbackStockSelector()
|
||||
return _pullback_selector
|
||||
@ -1,200 +0,0 @@
|
||||
"""
|
||||
板块异动分析
|
||||
检测板块涨跌幅、量能、资金流向异动
|
||||
"""
|
||||
import pandas as pd
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from datetime import datetime
|
||||
from app.utils.logger import logger
|
||||
from app.utils.error_handler import notify_error
|
||||
from .akshare_client import get_akshare_client
|
||||
|
||||
|
||||
class SectorChangeAnalyzer:
|
||||
"""板块异动分析器"""
|
||||
|
||||
def __init__(self, change_threshold: float = 2.0):
|
||||
"""
|
||||
初始化异动分析器
|
||||
|
||||
Args:
|
||||
change_threshold: 涨跌幅阈值(%)
|
||||
"""
|
||||
self.change_threshold = change_threshold
|
||||
self.akshare = get_akshare_client()
|
||||
|
||||
def detect_sector_changes(self) -> List[Dict]:
|
||||
"""
|
||||
检测异动板块(使用概念板块)
|
||||
|
||||
Returns:
|
||||
异动板块列表
|
||||
"""
|
||||
try:
|
||||
# 获取概念板块行情
|
||||
df = self.akshare.get_concept_spot()
|
||||
if df.empty:
|
||||
logger.warning("概念板块行情数据为空")
|
||||
return []
|
||||
|
||||
# 转换数据类型(概念板块返回的列名)
|
||||
df['涨跌幅'] = pd.to_numeric(df['涨跌幅'], errors='coerce')
|
||||
df['涨跌额'] = pd.to_numeric(df['涨跌额'], errors='coerce')
|
||||
df['最新价'] = pd.to_numeric(df['最新价'], errors='coerce')
|
||||
df['成交额'] = pd.to_numeric(df['成交额'], errors='coerce')
|
||||
|
||||
# 筛选异动板块
|
||||
hot_sectors = df[df['涨跌幅'] >= self.change_threshold].copy()
|
||||
|
||||
if hot_sectors.empty:
|
||||
return []
|
||||
|
||||
# 排序:涨幅优先,然后成交额
|
||||
hot_sectors = hot_sectors.sort_values(
|
||||
by=['涨跌幅', '成交额'],
|
||||
ascending=[False, False]
|
||||
)
|
||||
|
||||
# 转换为结果列表
|
||||
results = []
|
||||
for _, row in hot_sectors.iterrows():
|
||||
results.append({
|
||||
'name': row['板块名称'],
|
||||
'change_pct': float(row['涨跌幅']),
|
||||
'change_amount': float(row.get('涨跌额', 0)),
|
||||
'volume': float(row.get('成交量', 0)) if '成交量' in row else 0.0,
|
||||
'amount': float(row.get('成交额', 0)),
|
||||
'leading_stock': row.get('领涨股', ''),
|
||||
'ups': int(row.get('上涨家数', 0)),
|
||||
'downs': int(row.get('下跌家数', 0)),
|
||||
'timestamp': datetime.now()
|
||||
})
|
||||
|
||||
logger.info(f"检测到 {len(results)} 个异动概念板块")
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"检测板块异动失败: {e}"
|
||||
logger.error(error_msg)
|
||||
|
||||
# 如果是连接错误,发送通知
|
||||
if 'Connection' in str(e) or 'RemoteDisconnected' in str(e):
|
||||
notify_error(
|
||||
title="A股板块监控 - 数据源连接失败",
|
||||
message=f"akshare 概念板块 API 连接失败\n\n错误: {e}\n\n可能原因:\n- eastmoney API 服务不稳定\n- 网络连接问题\n- 建议:稍后自动重试或考虑使用 tushare",
|
||||
level="warning"
|
||||
)
|
||||
|
||||
return []
|
||||
|
||||
def analyze_sector_momentum(self, sector_name: str) -> Dict:
|
||||
"""
|
||||
分析板块动能
|
||||
|
||||
Args:
|
||||
sector_name: 板块名称
|
||||
|
||||
Returns:
|
||||
板块动能分析结果
|
||||
"""
|
||||
try:
|
||||
# 获取板块成分股
|
||||
stocks_df = self.akshare.get_concept_stocks(sector_name)
|
||||
if stocks_df.empty:
|
||||
return {}
|
||||
|
||||
# 获取实时行情
|
||||
spot_df = self.akshare.get_stock_spot()
|
||||
if spot_df.empty:
|
||||
return {}
|
||||
|
||||
# 合并数据
|
||||
merged = pd.merge(
|
||||
stocks_df,
|
||||
spot_df,
|
||||
on='代码',
|
||||
how='inner'
|
||||
)
|
||||
|
||||
if merged.empty:
|
||||
return {}
|
||||
|
||||
# 计算统计
|
||||
total_stocks = len(merged)
|
||||
up_stocks = len(merged[merged['涨跌幅'] > 0])
|
||||
down_stocks = len(merged[merged['涨跌幅'] < 0])
|
||||
avg_change = merged['涨跌幅'].mean()
|
||||
max_change = merged['涨跌幅'].max()
|
||||
|
||||
# 计算总成交额
|
||||
total_amount = merged['成交额'].sum() if '成交额' in merged.columns else 0
|
||||
|
||||
# 找出涨幅最大的股票
|
||||
if not merged.empty:
|
||||
top_stock = merged.loc[merged['涨跌幅'].idxmax()]
|
||||
else:
|
||||
top_stock = None
|
||||
|
||||
return {
|
||||
'sector_name': sector_name,
|
||||
'total_stocks': total_stocks,
|
||||
'up_stocks': up_stocks,
|
||||
'down_stocks': down_stocks,
|
||||
'up_down_ratio': f"{up_stocks}:{down_stocks}",
|
||||
'avg_change': float(avg_change) if pd.notna(avg_change) else 0,
|
||||
'max_change': float(max_change) if pd.notna(max_change) else 0,
|
||||
'total_amount': float(total_amount),
|
||||
'top_stock': {
|
||||
'code': top_stock['代码'] if top_stock is not None else '',
|
||||
'name': top_stock['名称'] if top_stock is not None else '',
|
||||
'change': float(top_stock['涨跌幅']) if top_stock is not None else 0,
|
||||
} if top_stock is not None else None
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"分析板块动能失败 {sector_name}: {e}")
|
||||
return {}
|
||||
|
||||
def get_hot_reason(self, sector_name: str, top_stocks: List[Dict]) -> str:
|
||||
"""
|
||||
推测异动原因(基于龙头股分析)
|
||||
|
||||
Args:
|
||||
sector_name: 板块名称
|
||||
top_stocks: 龙头股列表
|
||||
|
||||
Returns:
|
||||
异动原因描述
|
||||
"""
|
||||
try:
|
||||
if not top_stocks:
|
||||
return "板块整体异动"
|
||||
|
||||
# 简单的原因分析
|
||||
reasons = []
|
||||
|
||||
# 检查是否有涨停股
|
||||
limit_up_count = sum(1 for s in top_stocks if s.get('change_pct', 0) >= 9.9)
|
||||
if limit_up_count > 0:
|
||||
reasons.append(f"{limit_up_count}只个股涨停")
|
||||
|
||||
# 检查平均涨幅
|
||||
avg_change = sum(s.get('change_pct', 0) for s in top_stocks) / len(top_stocks)
|
||||
if avg_change >= 7:
|
||||
reasons.append("板块全线爆发")
|
||||
|
||||
# 检查是否集中在某个龙头
|
||||
if len(top_stocks) >= 2:
|
||||
top1_change = top_stocks[0].get('change_pct', 0)
|
||||
top2_change = top_stocks[1].get('change_pct', 0)
|
||||
if top1_change - top2_change > 3:
|
||||
reasons.append(f"{top_stocks[0].get('name', '')}龙头领涨")
|
||||
|
||||
if reasons:
|
||||
return ",".join(reasons)
|
||||
else:
|
||||
return "资金集中流入"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"推测异动原因失败: {e}")
|
||||
return "板块异动"
|
||||
@ -1,250 +0,0 @@
|
||||
"""
|
||||
板块异动监控主程序
|
||||
协调各个模块,实现监控流程
|
||||
"""
|
||||
import asyncio
|
||||
from typing import List, Dict, Optional
|
||||
from datetime import datetime
|
||||
from app.utils.logger import logger
|
||||
from app.config import get_settings
|
||||
from .tushare_client import get_tushare_client
|
||||
from .tushare_sector_analyzer import TushareSectorAnalyzer
|
||||
from .tushare_stock_selector import TushareStockSelector
|
||||
from .notifier import get_dingtalk_notifier
|
||||
|
||||
|
||||
class SectorMonitor:
|
||||
"""板块异动监控器"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
change_threshold: float = 2.0,
|
||||
top_n: int = 3,
|
||||
enable_notifier: bool = True
|
||||
):
|
||||
"""
|
||||
初始化监控器
|
||||
|
||||
Args:
|
||||
change_threshold: 涨跌幅阈值(%)
|
||||
top_n: 每个板块返回前N只龙头股
|
||||
enable_notifier: 是否启用钉钉通知
|
||||
"""
|
||||
self.change_threshold = change_threshold
|
||||
self.top_n = top_n
|
||||
self.enable_notifier = enable_notifier
|
||||
|
||||
# 获取 Tushare 客户端
|
||||
settings = get_settings()
|
||||
ts_client = get_tushare_client(settings.tushare_token)
|
||||
if not ts_client:
|
||||
logger.warning("Tushare token 未配置,板块监控可能无法正常工作")
|
||||
|
||||
# 初始化各个模块
|
||||
self.analyzer = TushareSectorAnalyzer(ts_client, change_threshold=change_threshold)
|
||||
self.selector = TushareStockSelector(ts_client, top_n=top_n)
|
||||
self.notifier = get_dingtalk_notifier() if enable_notifier else None
|
||||
|
||||
# 统计信息
|
||||
self.stats = {
|
||||
'total_checks': 0,
|
||||
'total_hot_sectors': 0,
|
||||
'total_stocks': 0,
|
||||
'last_check_time': None,
|
||||
'last_hot_count': 0
|
||||
}
|
||||
|
||||
async def check_once(self) -> Dict:
|
||||
"""
|
||||
执行一次检查
|
||||
|
||||
Returns:
|
||||
检查结果统计
|
||||
"""
|
||||
try:
|
||||
logger.info("开始板块异动检查...")
|
||||
start_time = datetime.now()
|
||||
|
||||
# 1. 检测异动板块
|
||||
hot_sectors = self.analyzer.detect_sector_changes()
|
||||
|
||||
if not hot_sectors:
|
||||
logger.info("未检测到异动板块")
|
||||
self.stats['total_checks'] += 1
|
||||
self.stats['last_check_time'] = datetime.now()
|
||||
self.stats['last_hot_count'] = 0
|
||||
return {
|
||||
'hot_sectors': 0,
|
||||
'stocks': 0,
|
||||
'notified': 0
|
||||
}
|
||||
|
||||
logger.info(f"检测到 {len(hot_sectors)} 个异动板块")
|
||||
|
||||
# 2. 对每个异动板块进行深度分析
|
||||
results = []
|
||||
total_stocks = 0
|
||||
|
||||
for sector in hot_sectors:
|
||||
sector_name = sector['name']
|
||||
ts_code = sector['ts_code']
|
||||
|
||||
# 筛选龙头股(Tushare 版本需要 ts_code)
|
||||
top_stocks = self.selector.select_leading_stocks(ts_code, sector_name)
|
||||
|
||||
if not top_stocks:
|
||||
logger.warning(f"板块 {sector_name} 未找到龙头股")
|
||||
continue
|
||||
|
||||
# 分析异动原因
|
||||
reason = self.analyzer.get_hot_reason(sector_name, top_stocks)
|
||||
|
||||
# 发送钉钉通知
|
||||
notified = False
|
||||
if self.notifier:
|
||||
notified = self.notifier.send_sector_alert(
|
||||
sector_data=sector,
|
||||
top_stocks=top_stocks,
|
||||
reason=reason
|
||||
)
|
||||
|
||||
results.append({
|
||||
'sector': sector,
|
||||
'stocks': top_stocks,
|
||||
'reason': reason,
|
||||
'notified': notified
|
||||
})
|
||||
|
||||
total_stocks += len(top_stocks)
|
||||
logger.info(
|
||||
f"板块 {sector_name}: {len(top_stocks)} 只龙头股, "
|
||||
f"原因: {reason}, 通知: {'成功' if notified else '失败'}"
|
||||
)
|
||||
|
||||
# 更新统计
|
||||
self.stats['total_checks'] += 1
|
||||
self.stats['total_hot_sectors'] += len(hot_sectors)
|
||||
self.stats['total_stocks'] += total_stocks
|
||||
self.stats['last_check_time'] = datetime.now()
|
||||
self.stats['last_hot_count'] = len(hot_sectors)
|
||||
|
||||
elapsed = (datetime.now() - start_time).total_seconds()
|
||||
logger.info(
|
||||
f"检查完成: {len(hot_sectors)} 个异动板块, "
|
||||
f"{total_stocks} 只龙头股, 耗时 {elapsed:.2f}秒"
|
||||
)
|
||||
|
||||
return {
|
||||
'hot_sectors': len(hot_sectors),
|
||||
'stocks': total_stocks,
|
||||
'notified': sum(1 for r in results if r['notified']),
|
||||
'results': results
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"板块异动检查失败: {e}")
|
||||
# 发送错误通知
|
||||
if self.notifier:
|
||||
self.notifier.send_error(str(e))
|
||||
return {
|
||||
'hot_sectors': 0,
|
||||
'stocks': 0,
|
||||
'notified': 0,
|
||||
'error': str(e)
|
||||
}
|
||||
|
||||
async def run_periodic(self, interval_minutes: int = 30, max_runs: int = None):
|
||||
"""
|
||||
周期性运行监控
|
||||
|
||||
Args:
|
||||
interval_minutes: 检查间隔(分钟)
|
||||
max_runs: 最大运行次数(None表示无限运行)
|
||||
"""
|
||||
logger.info(
|
||||
f"启动周期性监控: 间隔 {interval_minutes}分钟, "
|
||||
f"阈值 {self.change_threshold}%, Top{self.top_n}"
|
||||
)
|
||||
|
||||
run_count = 0
|
||||
|
||||
try:
|
||||
while True:
|
||||
# 检查是否达到最大运行次数
|
||||
if max_runs and run_count >= max_runs:
|
||||
logger.info(f"已达到最大运行次数 {max_runs},停止监控")
|
||||
break
|
||||
|
||||
# 执行检查
|
||||
await self.check_once()
|
||||
run_count += 1
|
||||
|
||||
# 等待下一次检查
|
||||
if interval_minutes > 0:
|
||||
logger.info(f"等待 {interval_minutes} 分钟后进行下次检查...")
|
||||
await asyncio.sleep(interval_minutes * 60)
|
||||
else:
|
||||
break
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("监控任务被取消")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"周期性监控异常: {e}")
|
||||
if self.notifier:
|
||||
self.notifier.send_error(f"周期性监控异常: {e}")
|
||||
|
||||
def get_stats(self) -> Dict:
|
||||
"""
|
||||
获取统计信息
|
||||
|
||||
Returns:
|
||||
统计信息字典
|
||||
"""
|
||||
return {
|
||||
**self.stats,
|
||||
'avg_stocks_per_check': (
|
||||
self.stats['total_stocks'] / self.stats['total_checks']
|
||||
if self.stats['total_checks'] > 0 else 0
|
||||
)
|
||||
}
|
||||
|
||||
def send_summary_report(self) -> bool:
|
||||
"""
|
||||
发送汇总报告
|
||||
|
||||
Returns:
|
||||
是否发送成功
|
||||
"""
|
||||
if not self.notifier:
|
||||
return False
|
||||
|
||||
return self.notifier.send_summary(
|
||||
total_sectors=self.stats['total_hot_sectors'],
|
||||
total_stocks=self.stats['total_stocks']
|
||||
)
|
||||
|
||||
|
||||
# 快捷函数
|
||||
async def quick_check(
|
||||
change_threshold: float = 2.0,
|
||||
top_n: int = 3,
|
||||
enable_notifier: bool = True
|
||||
) -> Dict:
|
||||
"""
|
||||
快捷检查函数
|
||||
|
||||
Args:
|
||||
change_threshold: 涨跌幅阈值(%)
|
||||
top_n: 每个板块返回前N只龙头股
|
||||
enable_notifier: 是否启用钉钉通知
|
||||
|
||||
Returns:
|
||||
检查结果
|
||||
"""
|
||||
monitor = SectorMonitor(
|
||||
change_threshold=change_threshold,
|
||||
top_n=top_n,
|
||||
enable_notifier=enable_notifier
|
||||
)
|
||||
return await monitor.check_once()
|
||||
@ -1,802 +0,0 @@
|
||||
"""
|
||||
A股短期题材选股器
|
||||
策略:题材轮动 + 资金异动 + MA多头排列 + 量能配合
|
||||
执行时间:每天盘后输出
|
||||
|
||||
【选股策略】
|
||||
1. 题材筛选:资金异动板块(成交量放大、成交额增加)
|
||||
2. 个股筛选:
|
||||
- 市值 30-1000亿(流动性好,有炒作空间)
|
||||
- 换手率 1%-20%(资金活跃)
|
||||
- 排除ST、退市风险股
|
||||
- MA趋势向上(MA5 > MA20,适合震荡市场)
|
||||
- 量能配合(量比≥1.0)
|
||||
|
||||
【风险控制】(最大回撤10%)
|
||||
- 硬止损:-7%(单只股票最大损失)
|
||||
- 技术止损:跌破20日均线
|
||||
- 时间止损:持仓>30天未启动
|
||||
- 仓位管理:
|
||||
* 单票最大20%
|
||||
* 单行业最大40%
|
||||
* 总仓位最大80%
|
||||
|
||||
【数据源】
|
||||
- Tushare API(行情、基本面、资金流)
|
||||
"""
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from typing import Dict, List, Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from app.utils.logger import logger
|
||||
from app.config import get_settings
|
||||
|
||||
|
||||
class ShortTermThematicSelector:
|
||||
"""短期题材选股器"""
|
||||
|
||||
def __init__(self, tushare_client, strict_mode: bool = False):
|
||||
"""
|
||||
初始化选股器
|
||||
|
||||
Args:
|
||||
tushare_client: TushareClient实例
|
||||
strict_mode: 严格模式(True使用原策略,False放宽条件)
|
||||
"""
|
||||
self.ts_client = tushare_client
|
||||
self.strict_mode = strict_mode
|
||||
|
||||
# 选股参数(严格模式 vs 宽松模式)
|
||||
if strict_mode:
|
||||
# 严格模式:原策略
|
||||
self.min_market_cap = 50
|
||||
self.max_market_cap = 500
|
||||
self.min_turnover = 3.0
|
||||
self.max_turnover = 15.0
|
||||
self.sector_change_threshold = 2.0
|
||||
self.volume_ratio_threshold = 1.2
|
||||
else:
|
||||
# 宽松模式:适应当前市场
|
||||
self.min_market_cap = 30 # 降低市值下限
|
||||
self.max_market_cap = 1000 # 提高市值上限
|
||||
self.min_turnover = 1.0 # 降低换手率下限
|
||||
self.max_turnover = 20.0 # 提高换手率上限
|
||||
self.sector_change_threshold = 1.5 # 降低板块涨幅要求
|
||||
self.volume_ratio_threshold = 0.6 # 放宽量比要求(原1.0,现0.6)
|
||||
|
||||
# 风险控制参数
|
||||
self.max_drawdown = 10.0 # 最大回撤(%)
|
||||
self.hard_stop_loss = -7.0 # 硬止损(%)
|
||||
self.max_single_position = 0.20 # 单票最大仓位
|
||||
self.max_sector_position = 0.40 # 单行业最大仓位
|
||||
self.max_total_position = 0.80 # 总仓位最大值
|
||||
|
||||
def select_stocks(self, max_stocks: int = 10) -> Dict[str, Any]:
|
||||
"""
|
||||
执行选股
|
||||
|
||||
Args:
|
||||
max_stocks: 最多返回股票数
|
||||
|
||||
Returns:
|
||||
选股结果字典
|
||||
"""
|
||||
try:
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"📊 短期题材选股开始 ({'严格模式' if self.strict_mode else '宽松模式'})")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# 1. 获取异动板块
|
||||
logger.info("\n【第一步】筛选异动板块...")
|
||||
hot_sectors = self._get_hot_sectors()
|
||||
|
||||
if hot_sectors.empty:
|
||||
logger.warning("未找到异动板块")
|
||||
return self._empty_result()
|
||||
|
||||
logger.info(f"找到 {len(hot_sectors)} 个异动板块")
|
||||
for _, sector in hot_sectors.head(5).iterrows():
|
||||
logger.info(f" - {sector['name']}: {sector['change_pct']:+.2f}%, 成交额: {sector['amount']/100000000:.2f}亿")
|
||||
|
||||
# 2. 从异动板块中筛选个股
|
||||
logger.info("\n【第二步】从异动板块中筛选个股...")
|
||||
all_selected = []
|
||||
|
||||
for idx, sector in hot_sectors.iterrows():
|
||||
sector_code = sector['ts_code']
|
||||
sector_name = sector['name']
|
||||
sector_change = sector['change_pct']
|
||||
|
||||
logger.info(f"\n检查板块: {sector_name} ({sector_code})")
|
||||
|
||||
# 获取该板块的成分股
|
||||
members_df = self.ts_client.get_sector_members(sector_code)
|
||||
if members_df.empty:
|
||||
logger.warning(f" 无法获取板块成分股")
|
||||
continue
|
||||
|
||||
stock_codes = members_df['con_code'].tolist()
|
||||
logger.info(f" 板块成分股: {len(stock_codes)} 只")
|
||||
|
||||
# 筛选该板块的个股
|
||||
sector_stocks = self._select_stocks_from_sector(
|
||||
stock_codes, sector_name, sector_change
|
||||
)
|
||||
|
||||
if sector_stocks:
|
||||
all_selected.extend(sector_stocks)
|
||||
logger.info(f" ✓ 选出 {len(sector_stocks)} 只")
|
||||
|
||||
if len(all_selected) >= max_stocks * 2: # 多选一些备用
|
||||
break
|
||||
|
||||
if not all_selected:
|
||||
logger.warning("未选出符合条件的股票")
|
||||
return self._empty_result()
|
||||
|
||||
# 3. 综合评分和排序
|
||||
logger.info("\n【第三步】综合评分和排序...")
|
||||
all_selected = self._rank_stocks(all_selected)
|
||||
|
||||
# 4. 应用仓位管理
|
||||
logger.info("\n【第四步】计算仓位配置...")
|
||||
final_stocks = self._allocate_positions(all_selected[:max_stocks])
|
||||
|
||||
# 5. 生成输出
|
||||
result = self._format_result(final_stocks, hot_sectors)
|
||||
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info(f"✅ 选股完成,共选出 {len(final_stocks)} 只股票")
|
||||
logger.info("=" * 60)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"选股失败: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return self._empty_result()
|
||||
|
||||
def _get_hot_sectors(self) -> pd.DataFrame:
|
||||
"""
|
||||
获取异动板块(基于成交量和资金异动)
|
||||
|
||||
策略:
|
||||
1. 优先选择热门概念板块(AI、新能源、芯片等)
|
||||
2. 关键指标:成交量放大、成交额增加
|
||||
3. 辅助指标:涨幅(可选)
|
||||
|
||||
Returns:
|
||||
异动板块列表
|
||||
"""
|
||||
try:
|
||||
sectors_df = self.ts_client.get_concept_sectors()
|
||||
if sectors_df.empty:
|
||||
return pd.DataFrame()
|
||||
|
||||
today = datetime.now().strftime('%Y%m%d')
|
||||
yesterday = (datetime.now() - timedelta(days=10)).strftime('%Y%m%d')
|
||||
|
||||
# 热门板块关键词(优先选择这些)
|
||||
hot_keywords = [
|
||||
'人工智能', 'AI', '算力', 'CPO', 'AIGC',
|
||||
'新能源汽车', '锂电', '储能', '充电桩', '汽车',
|
||||
'半导体', '芯片', '集成电路',
|
||||
'机器人', '工业4.0',
|
||||
'5G', '6G', '通信',
|
||||
'数字经济', '云计算', '大数据', '物联网',
|
||||
'军工', '航空',
|
||||
'生物医药', '医药', '医疗',
|
||||
'消费电子',
|
||||
'光伏', '风电', '氢能',
|
||||
'智能电网', '电力',
|
||||
'元宇宙', '虚拟现实',
|
||||
]
|
||||
|
||||
hot_sectors = []
|
||||
checked_codes = set()
|
||||
|
||||
# 1. 优先检查热门概念板块
|
||||
logger.info("优先检查热门概念板块的资金异动...")
|
||||
for keyword in hot_keywords:
|
||||
# 查找包含关键词的板块
|
||||
matching_sectors = sectors_df[sectors_df['name'].str.contains(keyword, na=False)]
|
||||
|
||||
for _, row in matching_sectors.iterrows():
|
||||
ts_code = row['ts_code']
|
||||
name = row['name']
|
||||
|
||||
if ts_code in checked_codes:
|
||||
continue
|
||||
checked_codes.add(ts_code)
|
||||
|
||||
try:
|
||||
# 获取板块行情(最近10天)
|
||||
daily_df = self.ts_client.pro.ths_daily(
|
||||
ts_code=ts_code,
|
||||
start_date=yesterday,
|
||||
end_date=today
|
||||
)
|
||||
|
||||
if daily_df.empty or len(daily_df) < 5:
|
||||
continue
|
||||
|
||||
daily_df = daily_df.sort_values('trade_date')
|
||||
|
||||
# 获取最新2天数据
|
||||
latest = daily_df.iloc[-1]
|
||||
prev = daily_df.iloc[-2]
|
||||
|
||||
# 计算成交量和成交额
|
||||
# ths_daily API 返回: vol(手), avg_price(元/股)
|
||||
# 成交额(元) = vol * avg_price * 100
|
||||
latest_vol = float(latest.get('vol', 0))
|
||||
latest_avg_price = float(latest.get('avg_price', 0))
|
||||
latest_amount = latest_vol * latest_avg_price * 100 # 转换为元
|
||||
|
||||
prev_vol = float(prev.get('vol', 0))
|
||||
prev_avg_price = float(prev.get('avg_price', 0))
|
||||
prev_amount = prev_vol * prev_avg_price * 100
|
||||
|
||||
# 计算成交量放大倍数
|
||||
vol_ratio = latest_vol / prev_vol if prev_vol > 0 else 1
|
||||
|
||||
# 计算成交额放大倍数
|
||||
amount_ratio = latest_amount / prev_amount if prev_amount > 0 else 1
|
||||
|
||||
# 涨跌幅
|
||||
change_pct = float(latest.get('pct_change', 0))
|
||||
|
||||
# 判断资金异动:
|
||||
# 1. 成交量放大 >= 1.2倍(宽松)或 2倍(严格)
|
||||
# 2. 成交额明显增加(>= 10%)
|
||||
# 3. 有一定涨幅辅助判断(可选)
|
||||
vol_threshold = 1.2 if not self.strict_mode else 2.0
|
||||
amount_threshold = 1.1 if not self.strict_mode else 1.5
|
||||
|
||||
is_volume_surge = vol_ratio >= vol_threshold
|
||||
is_amount_surge = amount_ratio >= amount_threshold
|
||||
has_min_change = change_pct >= 0.5 # 至少有一点涨幅
|
||||
|
||||
if (is_volume_surge or is_amount_surge) and has_min_change:
|
||||
hot_sectors.append({
|
||||
'ts_code': ts_code,
|
||||
'name': name,
|
||||
'change_pct': change_pct,
|
||||
'amount': latest_amount,
|
||||
'close': float(latest.get('close', 0)),
|
||||
'vol_ratio': vol_ratio,
|
||||
'amount_ratio': amount_ratio,
|
||||
'is_hot_sector': True
|
||||
})
|
||||
logger.info(f" ✓ {name}: 涨{change_pct:+.2f}%, 量比{vol_ratio:.2f}x, 额比{amount_ratio:.2f}x")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"获取板块 {name} 行情失败: {e}")
|
||||
continue
|
||||
|
||||
# 2. 如果热门板块不够,继续检查其他板块
|
||||
if len(hot_sectors) < 5:
|
||||
logger.info("热门板块数量不足,继续检查其他板块的资金异动...")
|
||||
|
||||
max_check = 200
|
||||
|
||||
for idx, row in sectors_df.iterrows():
|
||||
ts_code = row['ts_code']
|
||||
name = row.get('name', '')
|
||||
|
||||
if ts_code in checked_codes:
|
||||
continue
|
||||
checked_codes.add(ts_code)
|
||||
|
||||
try:
|
||||
daily_df = self.ts_client.pro.ths_daily(
|
||||
ts_code=ts_code,
|
||||
start_date=yesterday,
|
||||
end_date=today
|
||||
)
|
||||
|
||||
if daily_df.empty or len(daily_df) < 5:
|
||||
continue
|
||||
|
||||
daily_df = daily_df.sort_values('trade_date')
|
||||
latest = daily_df.iloc[-1]
|
||||
prev = daily_df.iloc[-2]
|
||||
|
||||
latest_vol = float(latest.get('vol', 0))
|
||||
latest_avg_price = float(latest.get('avg_price', 0))
|
||||
latest_amount = latest_vol * latest_avg_price * 100
|
||||
|
||||
prev_vol = float(prev.get('vol', 0))
|
||||
prev_avg_price = float(prev.get('avg_price', 0))
|
||||
prev_amount = prev_vol * prev_avg_price * 100
|
||||
|
||||
vol_ratio = latest_vol / prev_vol if prev_vol > 0 else 1
|
||||
amount_ratio = latest_amount / prev_amount if prev_amount > 0 else 1
|
||||
change_pct = float(latest.get('pct_change', 0))
|
||||
|
||||
# 非热门板块需要更强的异动信号
|
||||
if vol_ratio >= 2.0 and amount_ratio >= 1.5 and change_pct >= 1.0:
|
||||
hot_sectors.append({
|
||||
'ts_code': ts_code,
|
||||
'name': name,
|
||||
'change_pct': change_pct,
|
||||
'amount': latest_amount,
|
||||
'close': float(latest.get('close', 0)),
|
||||
'vol_ratio': vol_ratio,
|
||||
'amount_ratio': amount_ratio,
|
||||
'is_hot_sector': False
|
||||
})
|
||||
logger.info(f" ✓ {name}: 涨{change_pct:+.2f}%, 量比{vol_ratio:.2f}x, 额比{amount_ratio:.2f}x")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"获取板块 {name} 行情失败: {e}")
|
||||
continue
|
||||
|
||||
if len(hot_sectors) >= 10:
|
||||
break
|
||||
|
||||
result_df = pd.DataFrame(hot_sectors)
|
||||
if not result_df.empty:
|
||||
# 热门板块排在前面,按成交额放大倍数排序
|
||||
result_df = result_df.sort_values(['is_hot_sector', 'amount_ratio'], ascending=[False, False])
|
||||
logger.info(f"共找到 {len(result_df)} 个资金异动板块(热门: {result_df['is_hot_sector'].sum()} 个)")
|
||||
logger.info(f"平均量比: {result_df['vol_ratio'].mean():.2f}x, 平均额比: {result_df['amount_ratio'].mean():.2f}x")
|
||||
|
||||
return result_df
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取异动板块失败: {e}")
|
||||
import traceback
|
||||
logger.debug(traceback.format_exc())
|
||||
return pd.DataFrame()
|
||||
|
||||
def _select_stocks_from_sector(
|
||||
self,
|
||||
stock_codes: List[str],
|
||||
sector_name: str,
|
||||
sector_change: float
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
从板块中筛选个股
|
||||
|
||||
Args:
|
||||
stock_codes: 股票代码列表
|
||||
sector_name: 板块名称
|
||||
sector_change: 板块涨跌幅
|
||||
|
||||
Returns:
|
||||
符合条件的股票列表
|
||||
"""
|
||||
selected = []
|
||||
|
||||
# 批量获取行情数据
|
||||
realtime_df = self.ts_client.get_realtime_data(stock_codes)
|
||||
if realtime_df.empty:
|
||||
logger.warning(f"板块 {sector_name} 无法获取实时行情数据")
|
||||
return []
|
||||
|
||||
logger.info(f" 获取到 {len(realtime_df)} 只股票的行情数据(请求了 {len(stock_codes)} 只)")
|
||||
|
||||
# 获取每日指标
|
||||
from datetime import datetime
|
||||
trade_date = datetime.now().strftime('%Y%m%d')
|
||||
basic_df = self.ts_client.get_stock_daily_basic(stock_codes, trade_date)
|
||||
|
||||
# 获取历史数据(计算技术指标)
|
||||
logger.debug(f" 开始检查 {len(stock_codes)} 只成分股...")
|
||||
checked_count = 0
|
||||
passed_count = 0
|
||||
|
||||
for stock_code in stock_codes: # 检查所有成分股
|
||||
try:
|
||||
checked_count += 1
|
||||
if checked_count % 10 == 0:
|
||||
logger.debug(f" 进度: {checked_count}/{len(stock_codes)}, 已通过: {passed_count}")
|
||||
|
||||
result = self._check_single_stock(
|
||||
stock_code, sector_name, sector_change,
|
||||
realtime_df, basic_df
|
||||
)
|
||||
if result:
|
||||
selected.append(result)
|
||||
passed_count += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"检查股票 {stock_code} 失败: {e}")
|
||||
continue
|
||||
|
||||
logger.info(f" 检查完成: {checked_count} 只,通过筛选: {passed_count} 只")
|
||||
return selected
|
||||
|
||||
def _check_single_stock(
|
||||
self,
|
||||
stock_code: str,
|
||||
sector_name: str,
|
||||
sector_change: float,
|
||||
realtime_df: pd.DataFrame,
|
||||
basic_df: pd.DataFrame
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
检查单只股票是否符合条件
|
||||
|
||||
Args:
|
||||
stock_code: 股票代码
|
||||
sector_name: 所属板块
|
||||
sector_change: 板块涨跌幅
|
||||
realtime_df: 实时行情数据
|
||||
basic_df: 每日指标数据
|
||||
|
||||
Returns:
|
||||
符合条件返回股票信息,否则返回None
|
||||
"""
|
||||
# 获取实时行情
|
||||
stock_data = realtime_df[realtime_df['ts_code'] == stock_code]
|
||||
if stock_data.empty:
|
||||
logger.debug(f" ⚠️ {stock_code}: 无实时行情数据")
|
||||
return None
|
||||
|
||||
row = stock_data.iloc[0]
|
||||
|
||||
# 基本数据
|
||||
close = float(row['close'])
|
||||
pct_chg = float(row['pct_chg'])
|
||||
amount = float(row['amount']) * 1000 # 转换为元
|
||||
vol = float(row['vol'])
|
||||
|
||||
# 获取股票名称
|
||||
name = row.get('name', '')
|
||||
|
||||
logger.debug(f" 🔍 {name}({stock_code}): 价格={close:.2f}, 涨跌幅={pct_chg:+.2f}%")
|
||||
|
||||
# 过滤ST股票
|
||||
if 'ST' in name or '退' in name:
|
||||
logger.debug(f" ✗ {name}({stock_code}): ST/退市股,跳过")
|
||||
return None
|
||||
|
||||
# 获取每日指标
|
||||
basic_data = basic_df[basic_df['ts_code'] == stock_code]
|
||||
if not basic_data.empty:
|
||||
turnover = float(basic_data.iloc[0].get('turnover_rate', 0))
|
||||
|
||||
# 换手率过滤(只有有数据时才检查)
|
||||
if turnover > 0 and not (self.min_turnover <= turnover <= self.max_turnover):
|
||||
logger.debug(f" ✗ {name}({stock_code}): 换手率不符合 ({turnover:.2f}%)")
|
||||
return None
|
||||
else:
|
||||
turnover = 0.0
|
||||
|
||||
# 获取历史数据计算技术指标
|
||||
end_date = datetime.now().strftime('%Y%m%d')
|
||||
start_date = (datetime.now() - timedelta(days=60)).strftime('%Y%m%d')
|
||||
|
||||
try:
|
||||
daily_df = self.ts_client.pro.daily(
|
||||
ts_code=stock_code,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
|
||||
if daily_df.empty or len(daily_df) < 30:
|
||||
return None
|
||||
|
||||
daily_df = daily_df.sort_values('trade_date').reset_index(drop=True)
|
||||
close_series = daily_df['close']
|
||||
vol_series = daily_df['vol']
|
||||
|
||||
# 计算均线
|
||||
ma5 = close_series.rolling(window=5).mean().iloc[-1]
|
||||
ma10 = close_series.rolling(window=10).mean().iloc[-1]
|
||||
ma20 = close_series.rolling(window=20).mean().iloc[-1]
|
||||
ma5_vol = vol_series.rolling(window=5).mean().iloc[-1]
|
||||
|
||||
# MA趋势检查:MA5 > MA20(要求短期在长期趋势之上)
|
||||
# 在震荡修复阶段,允许MA5略低于MA10,但必须高于MA20
|
||||
if not (ma5 > ma20):
|
||||
logger.debug(f" ✗ {name}({stock_code}): MA5不在MA20之上 (MA5={ma5:.2f}, MA10={ma10:.2f}, MA20={ma20:.2f})")
|
||||
return None
|
||||
|
||||
# 量能检查
|
||||
volume_ratio = vol / ma5_vol if ma5_vol > 0 else 0
|
||||
if volume_ratio < self.volume_ratio_threshold:
|
||||
logger.debug(f" ✗ {name}({stock_code}): 量能不足 (量比: {volume_ratio:.2f})")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f" ✗ {name}({stock_code}): 计算技术指标失败: {e}")
|
||||
return None
|
||||
|
||||
# 估算市值(使用成交额和换手率)
|
||||
if turnover > 0:
|
||||
market_cap = amount / (turnover / 100) # 元
|
||||
market_cap_yi = market_cap / 100000000 # 转换为亿
|
||||
|
||||
# 市值过滤
|
||||
if not (self.min_market_cap <= market_cap_yi <= self.max_market_cap):
|
||||
logger.debug(f" ✗ {name}({stock_code}): 市值不符合 ({market_cap_yi:.2f}亿)")
|
||||
return None
|
||||
else:
|
||||
market_cap_yi = 0
|
||||
|
||||
# 通过所有筛选条件
|
||||
logger.info(f" ✓ {name}({stock_code}): 符合条件")
|
||||
|
||||
return {
|
||||
'ts_code': stock_code,
|
||||
'name': name,
|
||||
'close': close,
|
||||
'pct_chg': pct_chg,
|
||||
'amount': amount,
|
||||
'turnover': turnover,
|
||||
'volume_ratio': volume_ratio,
|
||||
'market_cap_yi': market_cap_yi,
|
||||
'sector': sector_name,
|
||||
'sector_change': sector_change,
|
||||
'ma5': ma5,
|
||||
'ma10': ma10,
|
||||
'ma20': ma20,
|
||||
}
|
||||
|
||||
def _rank_stocks(self, stocks: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
综合评分和排序
|
||||
|
||||
评分维度:
|
||||
- 板块强度 (40%)
|
||||
- 个股涨幅 (30%)
|
||||
- 量能表现 (30%)
|
||||
|
||||
Args:
|
||||
stocks: 股票列表
|
||||
|
||||
Returns:
|
||||
排序后的股票列表
|
||||
"""
|
||||
for stock in stocks:
|
||||
score = 0.0
|
||||
|
||||
# 1. 板块强度 (40分)
|
||||
sector_change = stock['sector_change']
|
||||
if sector_change >= 5:
|
||||
score += 40
|
||||
elif sector_change >= 3:
|
||||
score += 35
|
||||
elif sector_change >= 2:
|
||||
score += 30
|
||||
elif sector_change >= 1:
|
||||
score += 25
|
||||
elif sector_change > 0:
|
||||
score += 20
|
||||
else:
|
||||
score += 10
|
||||
|
||||
# 2. 个股涨幅 (30分)
|
||||
pct_chg = stock['pct_chg']
|
||||
if pct_chg >= 7:
|
||||
score += 30
|
||||
elif pct_chg >= 5:
|
||||
score += 26
|
||||
elif pct_chg >= 3:
|
||||
score += 22
|
||||
elif pct_chg >= 1:
|
||||
score += 18
|
||||
elif pct_chg > 0:
|
||||
score += 12
|
||||
else:
|
||||
score += 5
|
||||
|
||||
# 3. 量能表现 (30分)
|
||||
volume_ratio = stock['volume_ratio']
|
||||
if volume_ratio >= 2.5:
|
||||
score += 30
|
||||
elif volume_ratio >= 2.0:
|
||||
score += 26
|
||||
elif volume_ratio >= 1.5:
|
||||
score += 22
|
||||
elif volume_ratio >= 1.2:
|
||||
score += 18
|
||||
else:
|
||||
score += 10
|
||||
|
||||
# 4. 换手率 (10分)
|
||||
turnover = stock.get('turnover', 0)
|
||||
if 8 <= turnover <= 12: # 最理想的换手率范围
|
||||
score += 10
|
||||
elif 5 <= turnover < 8 or 12 < turnover <= 15:
|
||||
score += 8
|
||||
elif 3 <= turnover < 5:
|
||||
score += 6
|
||||
else:
|
||||
score += 4
|
||||
|
||||
stock['score'] = score
|
||||
|
||||
# 按得分排序
|
||||
return sorted(stocks, key=lambda x: x['score'], reverse=True)
|
||||
|
||||
def _allocate_positions(self, stocks: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
仓位分配
|
||||
|
||||
策略:
|
||||
- 优先级高的股票获得更大仓位
|
||||
- 根据得分动态分配仓位
|
||||
- 确保风险分散
|
||||
|
||||
Args:
|
||||
stocks: 股票列表
|
||||
|
||||
Returns:
|
||||
添加了仓位信息的股票列表
|
||||
"""
|
||||
if not stocks:
|
||||
return []
|
||||
|
||||
total_score = sum(s['score'] for s in stocks)
|
||||
|
||||
for stock in stocks:
|
||||
# 根据得分比例分配仓位
|
||||
score_ratio = stock['score'] / total_score if total_score > 0 else 1.0 / len(stocks)
|
||||
|
||||
# 基础仓位(按得分比例)
|
||||
base_position = score_ratio * self.max_total_position
|
||||
|
||||
# 调整:最高得分股票仓位不超过最大单票仓位
|
||||
if base_position > self.max_single_position:
|
||||
base_position = self.max_single_position
|
||||
|
||||
# 仓位范围:5% - 20%
|
||||
position = max(0.05, min(base_position, self.max_single_position))
|
||||
|
||||
stock['position'] = position
|
||||
stock['stop_loss'] = close * (1 - 0.07) # 硬止损-7%
|
||||
stock['target_profit'] = close * (1 + 0.15) # 目标止盈+15%
|
||||
|
||||
return stocks
|
||||
|
||||
def _format_result(self, stocks: List[Dict[str, Any]], sectors: pd.DataFrame) -> Dict[str, Any]:
|
||||
"""
|
||||
格式化选股结果
|
||||
|
||||
Args:
|
||||
stocks: 股票列表
|
||||
sectors: 异动板块列表
|
||||
|
||||
Returns:
|
||||
格式化的结果字典
|
||||
"""
|
||||
return {
|
||||
'date': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
|
||||
'total_stocks': len(stocks),
|
||||
'total_sectors': len(sectors),
|
||||
'stocks': stocks,
|
||||
'sectors': sectors.head(10).to_dict('records'),
|
||||
'summary': self._generate_summary(stocks)
|
||||
}
|
||||
|
||||
def _generate_summary(self, stocks: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""
|
||||
生成汇总信息
|
||||
|
||||
Args:
|
||||
stocks: 股票列表
|
||||
|
||||
Returns:
|
||||
汇总信息字典
|
||||
"""
|
||||
if not stocks:
|
||||
return {}
|
||||
|
||||
total_position = sum(s['position'] for s in stocks)
|
||||
sectors = list(set(s['sector'] for s in stocks))
|
||||
|
||||
return {
|
||||
'total_position': total_position,
|
||||
'position_percent': total_position * 100,
|
||||
'sector_count': len(sectors),
|
||||
'sectors': sectors,
|
||||
'avg_score': sum(s['score'] for s in stocks) / len(stocks),
|
||||
}
|
||||
|
||||
def _empty_result(self) -> Dict[str, Any]:
|
||||
"""返回空结果"""
|
||||
return {
|
||||
'date': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
|
||||
'total_stocks': 0,
|
||||
'total_sectors': 0,
|
||||
'stocks': [],
|
||||
'sectors': [],
|
||||
'summary': {}
|
||||
}
|
||||
|
||||
def format_output_text(self, result: Dict[str, Any]) -> str:
|
||||
"""
|
||||
格式化输出文本
|
||||
|
||||
Args:
|
||||
result: 选股结果
|
||||
|
||||
Returns:
|
||||
格式化的文本
|
||||
"""
|
||||
if not result or result['total_stocks'] == 0:
|
||||
return "📊 **短期题材选股结果**\n\n今日未选出符合条件的股票\n\n*⚠️ 仅供参考,不构成投资建议*"
|
||||
|
||||
lines = [
|
||||
"📊 **短期题材选股结果**",
|
||||
"",
|
||||
f"选股时间: {result['date']}",
|
||||
f"选出股票: {result['total_stocks']} 只",
|
||||
f"异动板块: {result['total_sectors']} 个",
|
||||
"",
|
||||
]
|
||||
|
||||
# 汇总信息
|
||||
if result.get('summary'):
|
||||
summary = result['summary']
|
||||
lines.extend([
|
||||
"**💼 仓位配置**",
|
||||
f"总仓位: {summary['position_percent']:.1f}%",
|
||||
f"涉及板块: {summary['sector_count']} 个",
|
||||
f"平均得分: {summary['avg_score']:.1f}分",
|
||||
"",
|
||||
])
|
||||
|
||||
# 异动板块
|
||||
if result.get('sectors'):
|
||||
lines.append("**🔥 异动板块 Top5**")
|
||||
for sector in result['sectors'][:5]:
|
||||
vol_ratio = sector.get('vol_ratio', 0)
|
||||
amount_ratio = sector.get('amount_ratio', 0)
|
||||
vol_icon = "🔥" if vol_ratio >= 2.0 else "📊"
|
||||
lines.append(f"- {sector['name']}: {sector['change_pct']:+.2f}% | 量比{vol_ratio:.2f}x {vol_icon} | 额比{amount_ratio:.2f}x")
|
||||
lines.append("")
|
||||
|
||||
# 选出股票
|
||||
lines.append("**🏆 选出股票**")
|
||||
for idx, stock in enumerate(result['stocks'], 1):
|
||||
lines.extend([
|
||||
f"",
|
||||
f"**{idx}. {stock['name']} ({stock['ts_code']})**",
|
||||
f" 现价: ¥{stock['close']:.2f} ({stock['pct_chg']:+.2f}%)",
|
||||
f" 板块: {stock['sector']} ({stock['sector_change']:+.2f}%)",
|
||||
f" 换手率: {stock['turnover']:.2f}% | 量比: {stock['volume_ratio']:.2f}",
|
||||
f" 市值: {stock['market_cap_yi']:.2f}亿 | 评分: {stock['score']:.1f}分",
|
||||
f" MA5: ¥{stock['ma5']:.2f} | MA10: ¥{stock['ma10']:.2f} | MA20: ¥{stock['ma20']:.2f}",
|
||||
f" 建议仓位: {stock['position']*100:.1f}%",
|
||||
f" 止损价: ¥{stock['stop_loss']:.2f} (-7%)",
|
||||
f" 目标价: ¥{stock['target_profit']:.2f} (+15%)",
|
||||
])
|
||||
|
||||
lines.extend([
|
||||
"",
|
||||
"---",
|
||||
"",
|
||||
"**⚠️ 风险提示**",
|
||||
f"- 硬止损: {self.hard_stop_loss}%(单只股票最大损失)",
|
||||
f"- 技术止损: 跌破20日均线",
|
||||
f"- 时间止损: 持仓>30天未启动",
|
||||
f"- 单票最大: {self.max_single_position*100}%",
|
||||
f"- 单行业最大: {self.max_sector_position*100}%",
|
||||
"",
|
||||
"*⚠️ 仅供参考,不构成投资建议*"
|
||||
])
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
# 全局单例
|
||||
_thematic_selector: Optional[ShortTermThematicSelector] = None
|
||||
|
||||
|
||||
def get_thematic_selector(tushare_client=None) -> ShortTermThematicSelector:
|
||||
"""获取短期题材选股器单例"""
|
||||
global _thematic_selector
|
||||
if _thematic_selector is None:
|
||||
if tushare_client is None:
|
||||
from app.astock_agent.tushare_client import get_tushare_client
|
||||
from app.config import get_settings
|
||||
settings = get_settings()
|
||||
tushare_client = get_tushare_client(settings.tushare_token)
|
||||
_thematic_selector = ShortTermThematicSelector(tushare_client)
|
||||
return _thematic_selector
|
||||
@ -1,192 +0,0 @@
|
||||
"""
|
||||
龙头股筛选
|
||||
从异动板块中筛选出龙头股
|
||||
"""
|
||||
import pandas as pd
|
||||
from typing import Dict, List
|
||||
from datetime import datetime
|
||||
from app.utils.logger import logger
|
||||
from .akshare_client import get_akshare_client
|
||||
|
||||
|
||||
class StockSelector:
|
||||
"""龙头股筛选器"""
|
||||
|
||||
def __init__(self, top_n: int = 3):
|
||||
"""
|
||||
初始化筛选器
|
||||
|
||||
Args:
|
||||
top_n: 返回前 N 只龙头股
|
||||
"""
|
||||
self.top_n = top_n
|
||||
self.akshare = get_akshare_client()
|
||||
|
||||
def select_leading_stocks(self, sector_name: str) -> List[Dict]:
|
||||
"""
|
||||
筛选板块龙头股
|
||||
|
||||
Args:
|
||||
sector_name: 板块名称
|
||||
|
||||
Returns:
|
||||
龙头股列表(已排序)
|
||||
"""
|
||||
try:
|
||||
# 获取成分股
|
||||
stocks_df = self.akshare.get_concept_stocks(sector_name)
|
||||
if stocks_df.empty:
|
||||
logger.warning(f"获取板块 {sector_name} 成分股失败")
|
||||
return []
|
||||
|
||||
# 获取实时行情
|
||||
spot_df = self.akshare.get_stock_spot()
|
||||
if spot_df.empty:
|
||||
logger.warning("获取实时行情失败")
|
||||
return []
|
||||
|
||||
# 合并数据
|
||||
merged = pd.merge(
|
||||
stocks_df[['代码', '名称']],
|
||||
spot_df,
|
||||
on='代码',
|
||||
how='inner'
|
||||
)
|
||||
|
||||
if merged.empty:
|
||||
return []
|
||||
|
||||
# 数据类型转换
|
||||
merged['最新价'] = pd.to_numeric(merged['最新价'], errors='coerce')
|
||||
merged['涨跌幅'] = pd.to_numeric(merged['涨跌幅'], errors='coerce')
|
||||
merged['涨跌额'] = pd.to_numeric(merged['涨跌额'], errors='coerce')
|
||||
merged['成交量'] = pd.to_numeric(merged['成交量'], errors='coerce')
|
||||
merged['成交额'] = pd.to_numeric(merged['成交额'], errors='coerce')
|
||||
merged['换手率'] = pd.to_numeric(merged['换手率'], errors='coerce')
|
||||
merged['振幅'] = pd.to_numeric(merged['振幅'], errors='coerce')
|
||||
merged['量比'] = pd.to_numeric(merged['量比'], errors='coerce')
|
||||
|
||||
# 过滤:只保留有成交额的股票
|
||||
merged = merged[merged['成交额'] > 0].copy()
|
||||
|
||||
if merged.empty:
|
||||
return []
|
||||
|
||||
# 计算综合评分
|
||||
merged['score'] = merged.apply(self._calculate_score, axis=1)
|
||||
|
||||
# 排序:按综合得分
|
||||
merged = merged.sort_values('score', ascending=False)
|
||||
|
||||
# 取前 N 只
|
||||
top_stocks = merged.head(self.top_n)
|
||||
|
||||
# 转换结果
|
||||
results = []
|
||||
for _, row in top_stocks.iterrows():
|
||||
# 计算涨速等级
|
||||
change_pct = row['涨跌幅']
|
||||
if change_pct >= 5:
|
||||
speed_level = "⚡⚡⚡ 极快"
|
||||
elif change_pct >= 3:
|
||||
speed_level = "⚡⚡ 快速"
|
||||
elif change_pct >= 1:
|
||||
speed_level = "⚡ 较快"
|
||||
else:
|
||||
speed_level = "🐌 平稳"
|
||||
|
||||
results.append({
|
||||
'code': row['代码'],
|
||||
'name': row['名称'],
|
||||
'price': float(row['最新价']),
|
||||
'change_pct': float(row['涨跌幅']),
|
||||
'change_amount': float(row['涨跌额']),
|
||||
'amount': float(row['成交额']),
|
||||
'turnover': float(row['换手率']),
|
||||
'volume_ratio': float(row.get('量比', 1)),
|
||||
'amplitude': float(row.get('振幅', 0)),
|
||||
'score': float(row['score']),
|
||||
'speed_level': speed_level,
|
||||
})
|
||||
|
||||
logger.info(f"板块 {sector_name} 龙头股筛选完成,Top {len(results)}")
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"筛选龙头股失败 {sector_name}: {e}")
|
||||
return []
|
||||
|
||||
def _calculate_score(self, row: pd.Series) -> float:
|
||||
"""
|
||||
计算综合得分
|
||||
|
||||
评分维度:
|
||||
- 涨跌幅 (40%)
|
||||
- 成交额 (30%)
|
||||
- 涨速 (20%)
|
||||
- 换手率 (10%)
|
||||
|
||||
Args:
|
||||
row: 股票数据行
|
||||
|
||||
Returns:
|
||||
综合得分
|
||||
"""
|
||||
score = 0.0
|
||||
|
||||
# 1. 涨跌幅得分 (40分) - 涨幅越高得分越高
|
||||
change_pct = row['涨跌幅']
|
||||
if change_pct >= 7:
|
||||
score += 40 # 涨停级别
|
||||
elif change_pct >= 5:
|
||||
score += 35
|
||||
elif change_pct >= 3:
|
||||
score += 30
|
||||
elif change_pct >= 2:
|
||||
score += 25
|
||||
elif change_pct >= 1:
|
||||
score += 20
|
||||
elif change_pct > 0:
|
||||
score += 15
|
||||
else:
|
||||
score += max(0, 10 + change_pct * 5) # 下跌也有基础分
|
||||
|
||||
# 2. 成交额得分 (30分) - 成交额越大得分越高
|
||||
amount = row['成交额']
|
||||
if amount >= 100000: # 10亿以上
|
||||
score += 30
|
||||
elif amount >= 50000: # 5亿以上
|
||||
score += 25
|
||||
elif amount >= 10000: # 1亿以上
|
||||
score += 20
|
||||
elif amount >= 5000: # 5000万以上
|
||||
score += 15
|
||||
elif amount >= 1000: # 1000万以上
|
||||
score += 10
|
||||
else:
|
||||
score += 5
|
||||
|
||||
# 3. 涨速得分 (20分) - 简化用涨幅代替
|
||||
if change_pct >= 5:
|
||||
score += 20
|
||||
elif change_pct >= 3:
|
||||
score += 15
|
||||
elif change_pct >= 1:
|
||||
score += 10
|
||||
else:
|
||||
score += 5
|
||||
|
||||
# 4. 换手率得分 (10分) - 适中换手率加分
|
||||
turnover = row['换手率']
|
||||
if 5 <= turnover <= 15:
|
||||
score += 10 # 适中换手率
|
||||
elif 15 < turnover <= 25:
|
||||
score += 8 # 活跃但不过热
|
||||
elif turnover > 25:
|
||||
score += 5 # 过热可能回调
|
||||
elif turnover > 0:
|
||||
score += 3 # 有成交即可
|
||||
else:
|
||||
score += 0
|
||||
|
||||
return score
|
||||
@ -1,384 +0,0 @@
|
||||
"""
|
||||
Tushare 数据封装
|
||||
提供 A 股板块、个股行情数据获取接口(使用同花顺系列接口)
|
||||
"""
|
||||
import time
|
||||
import tushare as ts
|
||||
import pandas as pd
|
||||
from typing import Dict, List, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from app.utils.logger import logger
|
||||
|
||||
|
||||
class TushareClient:
|
||||
"""Tushare 数据客户端(同花顺系列接口)"""
|
||||
|
||||
# 缓存数据,避免频繁请求
|
||||
_cache = {}
|
||||
_cache_time = {}
|
||||
_last_request_time = 0
|
||||
|
||||
def __init__(self, token: str):
|
||||
"""
|
||||
初始化客户端
|
||||
|
||||
Args:
|
||||
token: Tushare token
|
||||
"""
|
||||
self.token = token
|
||||
ts.set_token(token)
|
||||
self.pro = ts.pro_api()
|
||||
self.cache_ttl = 300 # 缓存5分钟
|
||||
self.request_delay = 0.5 # 请求间隔(秒)- tushare 有频率限制
|
||||
|
||||
def _get_cached(self, key: str, fetch_func) -> pd.DataFrame:
|
||||
"""获取缓存数据,支持重试"""
|
||||
now = datetime.now()
|
||||
|
||||
# 检查缓存
|
||||
if key in self._cache:
|
||||
cache_time = self._cache_time.get(key)
|
||||
if cache_time and (now - cache_time).seconds < self.cache_ttl:
|
||||
logger.debug(f"使用缓存数据: {key}")
|
||||
return self._cache[key]
|
||||
|
||||
# 请求限流
|
||||
elapsed = now.timestamp() - self._last_request_time
|
||||
if elapsed < self.request_delay:
|
||||
time.sleep(self.request_delay - elapsed)
|
||||
|
||||
# 重试逻辑
|
||||
max_retries = 3
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
self._last_request_time = time.time()
|
||||
df = fetch_func()
|
||||
|
||||
if df is not None and not df.empty:
|
||||
self._cache[key] = df
|
||||
self._cache_time[key] = now
|
||||
logger.debug(f"获取数据成功: {key}")
|
||||
return df
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
# 指数退避重试
|
||||
if attempt < max_retries - 1:
|
||||
wait_time = (2 ** attempt) * 2
|
||||
logger.warning(
|
||||
f"获取数据失败 {key} (尝试 {attempt + 1}/{max_retries}): {e},"
|
||||
f"等待 {wait_time}秒后重试..."
|
||||
)
|
||||
time.sleep(wait_time)
|
||||
continue
|
||||
|
||||
logger.error(f"获取数据失败 {key}: {e}")
|
||||
break
|
||||
|
||||
return pd.DataFrame()
|
||||
|
||||
def get_concept_sectors(self) -> pd.DataFrame:
|
||||
"""
|
||||
获取概念板块列表
|
||||
|
||||
使用 ths_index 接口,type="N" 代表概念板块
|
||||
|
||||
Returns:
|
||||
概念板块列表
|
||||
"""
|
||||
def fetch():
|
||||
# ths_index - 获取同花顺概念指数列表
|
||||
return self.pro.ths_index(type='N')
|
||||
|
||||
return self._get_cached('concept_sectors', fetch)
|
||||
|
||||
def get_sector_daily(self, ts_code: str, start_date: str = None, end_date: str = None) -> pd.DataFrame:
|
||||
"""
|
||||
获取板块日线行情
|
||||
|
||||
Args:
|
||||
ts_code: 板块指数代码(如 885823.TI)
|
||||
start_date: 开始日期 (YYYYMMDD)
|
||||
end_date: 结束日期 (YYYYMMDD)
|
||||
|
||||
Returns:
|
||||
板块日线数据
|
||||
"""
|
||||
if not start_date:
|
||||
start_date = (datetime.now() - timedelta(days=30)).strftime('%Y%m%d')
|
||||
if not end_date:
|
||||
end_date = datetime.now().strftime('%Y%m%d')
|
||||
|
||||
def fetch():
|
||||
# ths_daily - 获取板块指数历史行情
|
||||
return self.pro.ths_daily(
|
||||
ts_code=ts_code,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
|
||||
return self._get_cached(f'sector_daily_{ts_code}_{end_date}', fetch)
|
||||
|
||||
def get_sector_members(self, ts_code: str) -> pd.DataFrame:
|
||||
"""
|
||||
获取板块成分股
|
||||
|
||||
Args:
|
||||
ts_code: 板块指数代码(如 885823.TI)
|
||||
|
||||
Returns:
|
||||
成分股列表
|
||||
"""
|
||||
def fetch():
|
||||
# ths_member - 获取板块成分股
|
||||
return self.pro.ths_member(ts_code=ts_code)
|
||||
|
||||
return self._get_cached(f'sector_members_{ts_code}', fetch)
|
||||
|
||||
def get_stock_daily(self, ts_code: str, start_date: str = None, end_date: str = None) -> pd.DataFrame:
|
||||
"""
|
||||
获取个股日线行情
|
||||
|
||||
Args:
|
||||
ts_code: 股票代码(如 000001.SZ)
|
||||
start_date: 开始日期 (YYYYMMDD)
|
||||
end_date: 结束日期 (YYYYMMDD)
|
||||
|
||||
Returns:
|
||||
日线数据
|
||||
"""
|
||||
if not start_date:
|
||||
start_date = (datetime.now() - timedelta(days=30)).strftime('%Y%m%d')
|
||||
if not end_date:
|
||||
end_date = datetime.now().strftime('%Y%m%d')
|
||||
|
||||
def fetch():
|
||||
# daily - 获取日线行情
|
||||
return self.pro.daily(
|
||||
ts_code=ts_code,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
|
||||
return self._get_cached(f'stock_daily_{ts_code}_{end_date}', fetch)
|
||||
|
||||
def get_stock_daily_basic(self, ts_codes: List[str], trade_date: str = None) -> pd.DataFrame:
|
||||
"""
|
||||
获取个股每日指标(包含换手率、量比等)
|
||||
|
||||
Args:
|
||||
ts_codes: 股票代码列表
|
||||
trade_date: 交易日期 (YYYYMMDD)
|
||||
|
||||
Returns:
|
||||
每日指标数据
|
||||
"""
|
||||
if not ts_codes:
|
||||
return pd.DataFrame()
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
if not trade_date:
|
||||
trade_date = datetime.now().strftime('%Y%m%d')
|
||||
|
||||
def fetch():
|
||||
# daily_basic - 获取每日指标
|
||||
# 分批处理以支持超过300只股票的情况
|
||||
all_data = []
|
||||
batch_size = 300
|
||||
for i in range(0, min(len(ts_codes), 900), batch_size): # 最多处理900只
|
||||
batch_codes = ts_codes[i:i+batch_size]
|
||||
|
||||
# 尝试获取最近3天的数据(以防当天数据未更新)
|
||||
for j in range(3):
|
||||
try_date = (datetime.now() - timedelta(days=j)).strftime('%Y%m%d')
|
||||
df = self.pro.daily_basic(
|
||||
ts_code=','.join(batch_codes),
|
||||
trade_date=try_date,
|
||||
fields='ts_code,trade_date,turnover_rate,volume_ratio,pe,pb'
|
||||
)
|
||||
if not df.empty:
|
||||
all_data.append(df)
|
||||
# 如果找到数据就不再尝试更早的日期
|
||||
break
|
||||
|
||||
if all_data:
|
||||
return pd.concat(all_data, ignore_index=True)
|
||||
return pd.DataFrame()
|
||||
|
||||
# 创建包含股票代码的缓存键
|
||||
codes_key = '_'.join(sorted(ts_codes[:20]))
|
||||
cache_key = f'stock_daily_basic_{trade_date}_{codes_key}'
|
||||
|
||||
return self._get_cached(cache_key, fetch)
|
||||
|
||||
def get_stock_basic(self) -> pd.DataFrame:
|
||||
"""
|
||||
获取股票基本信息列表
|
||||
|
||||
Returns:
|
||||
股票基本信息
|
||||
"""
|
||||
def fetch():
|
||||
# stock_basic - 获取股票基本信息
|
||||
return self.pro.stock_basic(
|
||||
exchange='',
|
||||
list_status='L',
|
||||
fields='ts_code,symbol,name,area,industry,list_date'
|
||||
)
|
||||
|
||||
return self._get_cached('stock_basic', fetch)
|
||||
|
||||
def get_realtime_data(self, ts_codes: List[str]) -> pd.DataFrame:
|
||||
"""
|
||||
获取实时行情数据(使用最新的日线数据)
|
||||
|
||||
注意:tushare 不提供真正的实时数据,这里返回最新的日线数据
|
||||
注意:amount 字段单位是千元,需要 * 1000 转换为元
|
||||
|
||||
Args:
|
||||
ts_codes: 股票代码列表
|
||||
|
||||
Returns:
|
||||
实时行情数据(amount 单位为千元)
|
||||
"""
|
||||
if not ts_codes:
|
||||
return pd.DataFrame()
|
||||
|
||||
# 获取今天的日期
|
||||
today = datetime.now().strftime('%Y%m%d')
|
||||
yesterday = (datetime.now() - timedelta(days=10)).strftime('%Y%m%d')
|
||||
|
||||
# 创建包含股票代码的缓存键
|
||||
codes_key = '_'.join(sorted(ts_codes[:20])) # 使用前20只代码创建唯一键
|
||||
cache_key = f'realtime_{today}_{codes_key}'
|
||||
|
||||
def fetch():
|
||||
# 使用 daily 接口获取最近数据
|
||||
# 分批处理以支持超过100只股票的情况
|
||||
all_dfs = []
|
||||
batch_size = 100
|
||||
for i in range(0, min(len(ts_codes), 500), batch_size): # 最多处理500只
|
||||
batch_codes = ts_codes[i:i+batch_size]
|
||||
codes_str = ','.join(batch_codes)
|
||||
df = self.pro.daily(
|
||||
ts_code=codes_str,
|
||||
start_date=yesterday,
|
||||
end_date=today
|
||||
)
|
||||
if not df.empty:
|
||||
all_dfs.append(df)
|
||||
|
||||
# 合并所有批次的数据
|
||||
if all_dfs:
|
||||
combined_df = pd.concat(all_dfs, ignore_index=True)
|
||||
# 只返回每个股票的最新一天数据
|
||||
combined_df = combined_df.sort_values('trade_date').groupby('ts_code').tail(1)
|
||||
|
||||
# 获取股票基本信息(包含股票名称)
|
||||
stock_basic = self.pro.stock_basic(
|
||||
exchange='',
|
||||
list_status='L',
|
||||
fields='ts_code,symbol,name,area,industry,list_date'
|
||||
)
|
||||
|
||||
# 合并股票名称
|
||||
if not stock_basic.empty:
|
||||
combined_df = combined_df.merge(
|
||||
stock_basic[['ts_code', 'name']],
|
||||
on='ts_code',
|
||||
how='left'
|
||||
)
|
||||
|
||||
return combined_df
|
||||
return pd.DataFrame()
|
||||
|
||||
return self._get_cached(cache_key, fetch)
|
||||
|
||||
def get_hot_sectors(self, threshold: float = 2.0) -> pd.DataFrame:
|
||||
"""
|
||||
获取异动板块(一次性获取所有板块的最新行情)
|
||||
|
||||
Args:
|
||||
threshold: 涨跌幅阈值(%)
|
||||
|
||||
Returns:
|
||||
异动板块数据
|
||||
"""
|
||||
try:
|
||||
# 1. 获取所有概念板块
|
||||
sectors_df = self.get_concept_sectors()
|
||||
if sectors_df.empty:
|
||||
logger.warning("获取概念板块列表失败")
|
||||
return pd.DataFrame()
|
||||
|
||||
logger.info(f"获取到 {len(sectors_df)} 个概念板块")
|
||||
|
||||
# 2. 获取今天的日期
|
||||
today = datetime.now().strftime('%Y%m%d')
|
||||
yesterday = (datetime.now() - timedelta(days=10)).strftime('%Y%m%d')
|
||||
|
||||
# 3. 批量获取板块行情(为了效率,限制数量)
|
||||
hot_sectors = []
|
||||
max_sectors = 100 # 最多检查100个板块
|
||||
|
||||
for idx, row in sectors_df.head(max_sectors).iterrows():
|
||||
ts_code = row['ts_code']
|
||||
name = row.get('name', '')
|
||||
|
||||
try:
|
||||
# 获取板块最新行情
|
||||
daily_df = self.pro.ths_daily(
|
||||
ts_code=ts_code,
|
||||
start_date=yesterday,
|
||||
end_date=today
|
||||
)
|
||||
|
||||
if daily_df.empty:
|
||||
continue
|
||||
|
||||
# 获取最新一天的数据
|
||||
latest = daily_df.sort_values('trade_date').iloc[-1]
|
||||
|
||||
# 检查涨跌幅 - 注意列名是 pct_change 不是 pct_chg
|
||||
change_pct = float(latest.get('pct_change', 0))
|
||||
if change_pct >= threshold:
|
||||
hot_sectors.append({
|
||||
'ts_code': ts_code,
|
||||
'name': name,
|
||||
'change_pct': change_pct,
|
||||
'change': float(latest.get('change', 0)), # 涨跌额
|
||||
'close': float(latest.get('close', 0)),
|
||||
'amount': float(latest.get('amount', 0)), # 成交额(元)
|
||||
'volume': float(latest.get('vol', 0)), # 成交量(手)
|
||||
'turnover_rate': float(latest.get('turnover_rate', 0)), # 换手率
|
||||
'trade_date': str(latest.get('trade_date', ''))
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"获取板块 {name} 行情失败: {e}")
|
||||
continue
|
||||
|
||||
result_df = pd.DataFrame(hot_sectors)
|
||||
if not result_df.empty:
|
||||
result_df = result_df.sort_values('change_pct', ascending=False)
|
||||
|
||||
return result_df
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取异动板块失败: {e}")
|
||||
return pd.DataFrame()
|
||||
|
||||
|
||||
# 全局单例
|
||||
_tushare_client: Optional[TushareClient] = None
|
||||
|
||||
|
||||
def get_tushare_client(token: str = None) -> Optional[TushareClient]:
|
||||
"""获取 Tushare 客户端单例"""
|
||||
global _tushare_client
|
||||
if _tushare_client is None:
|
||||
if not token:
|
||||
return None
|
||||
_tushare_client = TushareClient(token)
|
||||
return _tushare_client
|
||||
@ -1,189 +0,0 @@
|
||||
"""
|
||||
板块异动分析(Tushare 版本)
|
||||
检测板块涨跌幅、量能、资金流向异动
|
||||
"""
|
||||
import pandas as pd
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from datetime import datetime
|
||||
from app.utils.logger import logger
|
||||
from app.utils.error_handler import notify_error
|
||||
|
||||
|
||||
class TushareSectorAnalyzer:
|
||||
"""板块异动分析器(使用 Tushare 同花顺接口)"""
|
||||
|
||||
def __init__(self, tushare_client, change_threshold: float = 2.0):
|
||||
"""
|
||||
初始化异动分析器
|
||||
|
||||
Args:
|
||||
tushare_client: TushareClient 实例
|
||||
change_threshold: 涨跌幅阈值(%)
|
||||
"""
|
||||
self.change_threshold = change_threshold
|
||||
self.ts_client = tushare_client
|
||||
|
||||
def detect_sector_changes(self) -> List[Dict]:
|
||||
"""
|
||||
检测异动板块
|
||||
|
||||
Returns:
|
||||
异动板块列表
|
||||
"""
|
||||
try:
|
||||
# 使用 tushare 获取异动板块(一次性获取)
|
||||
df = self.ts_client.get_hot_sectors(threshold=self.change_threshold)
|
||||
|
||||
if df.empty:
|
||||
logger.info("未检测到异动板块")
|
||||
return []
|
||||
|
||||
# 转换为结果列表
|
||||
results = []
|
||||
for _, row in df.iterrows():
|
||||
# 成交额转换为万元
|
||||
amount_wan = row['amount'] / 10000 if row['amount'] > 0 else 0
|
||||
|
||||
# 格式化成交额显示
|
||||
if amount_wan >= 100000:
|
||||
amount_str = f"{amount_wan/100000:.1f}亿"
|
||||
elif amount_wan >= 10000:
|
||||
amount_str = f"{amount_wan/10000:.1f}万"
|
||||
else:
|
||||
amount_str = f"{amount_wan:.0f}元"
|
||||
|
||||
results.append({
|
||||
'name': row['name'],
|
||||
'ts_code': row['ts_code'],
|
||||
'change_pct': float(row['change_pct']),
|
||||
'change': float(row.get('change', 0)), # 涨跌额
|
||||
'close': float(row['close']),
|
||||
'amount': float(row['amount']),
|
||||
'amount_str': amount_str,
|
||||
'volume': float(row['volume']),
|
||||
'turnover_rate': float(row.get('turnover_rate', 0)), # 换手率
|
||||
'trade_date': row['trade_date'],
|
||||
'timestamp': datetime.now()
|
||||
})
|
||||
|
||||
logger.info(f"检测到 {len(results)} 个异动概念板块(Tushare)")
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Tushare 检测板块异动失败: {e}"
|
||||
logger.error(error_msg)
|
||||
|
||||
# 发送通知
|
||||
notify_error(
|
||||
title="A股板块监控 - Tushare 数据获取失败",
|
||||
message=f"错误: {e}\n\n可能原因:\n- Tushare token 未配置或无效\n- API 频率限制\n- 网络连接问题",
|
||||
level="warning"
|
||||
)
|
||||
|
||||
return []
|
||||
|
||||
def get_sector_stocks(self, ts_code: str, sector_name: str) -> List[Dict]:
|
||||
"""
|
||||
获取板块成分股
|
||||
|
||||
Args:
|
||||
ts_code: 板块指数代码
|
||||
sector_name: 板块名称
|
||||
|
||||
Returns:
|
||||
成分股列表
|
||||
"""
|
||||
try:
|
||||
# 获取成分股
|
||||
members_df = self.ts_client.get_sector_members(ts_code)
|
||||
|
||||
if members_df.empty:
|
||||
logger.warning(f"板块 {sector_name} 成分股数据为空")
|
||||
return []
|
||||
|
||||
# 获取成分股的行情数据
|
||||
stock_codes = members_df['ts_code'].tolist()
|
||||
|
||||
# 限制数量,避免请求过多
|
||||
if len(stock_codes) > 50:
|
||||
stock_codes = stock_codes[:50]
|
||||
|
||||
# 获取实时行情
|
||||
realtime_df = self.ts_client.get_realtime_data(stock_codes)
|
||||
|
||||
if realtime_df.empty:
|
||||
logger.warning(f"板块 {sector_name} 成分股行情为空")
|
||||
return []
|
||||
|
||||
# 合并数据
|
||||
merged = pd.merge(
|
||||
members_df,
|
||||
realtime_df,
|
||||
on='ts_code',
|
||||
how='inner'
|
||||
)
|
||||
|
||||
if merged.empty:
|
||||
return []
|
||||
|
||||
# 转换结果
|
||||
results = []
|
||||
for _, row in merged.iterrows():
|
||||
results.append({
|
||||
'code': row['ts_code'],
|
||||
'name': row.get('name', row.get('member_name', '')),
|
||||
'price': float(row.get('close', 0)),
|
||||
'change_pct': float(row.get('pct_chg', 0)),
|
||||
'change_amount': float(row.get('change', 0)),
|
||||
'amount': float(row.get('amount', 0)),
|
||||
'volume': float(row.get('vol', 0)),
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取板块 {sector_name} 成分股失败: {e}")
|
||||
return []
|
||||
|
||||
def get_hot_reason(self, sector_name: str, top_stocks: List[Dict]) -> str:
|
||||
"""
|
||||
推测异动原因(基于龙头股分析)
|
||||
|
||||
Args:
|
||||
sector_name: 板块名称
|
||||
top_stocks: 龙头股列表
|
||||
|
||||
Returns:
|
||||
异动原因描述
|
||||
"""
|
||||
try:
|
||||
if not top_stocks:
|
||||
return "板块整体异动"
|
||||
|
||||
reasons = []
|
||||
|
||||
# 检查是否有涨停股
|
||||
limit_up_count = sum(1 for s in top_stocks if s.get('change_pct', 0) >= 9.9)
|
||||
if limit_up_count > 0:
|
||||
reasons.append(f"{limit_up_count}只个股涨停")
|
||||
|
||||
# 检查平均涨幅
|
||||
avg_change = sum(s.get('change_pct', 0) for s in top_stocks) / len(top_stocks)
|
||||
if avg_change >= 7:
|
||||
reasons.append("板块全线爆发")
|
||||
|
||||
# 检查是否集中在某个龙头
|
||||
if len(top_stocks) >= 2:
|
||||
top1_change = top_stocks[0].get('change_pct', 0)
|
||||
top2_change = top_stocks[1].get('change_pct', 0)
|
||||
if top1_change - top2_change > 3:
|
||||
reasons.append(f"{top_stocks[0].get('name', '')}龙头领涨")
|
||||
|
||||
if reasons:
|
||||
return ",".join(reasons)
|
||||
else:
|
||||
return "资金集中流入"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"推测异动原因失败: {e}")
|
||||
return "板块异动"
|
||||
@ -1,244 +0,0 @@
|
||||
"""
|
||||
龙头股筛选(Tushare 版本)
|
||||
从异动板块中筛选出龙头股
|
||||
"""
|
||||
import pandas as pd
|
||||
from typing import Dict, List
|
||||
from datetime import datetime
|
||||
from app.utils.logger import logger
|
||||
|
||||
|
||||
class TushareStockSelector:
|
||||
"""龙头股筛选器(使用 Tushare)"""
|
||||
|
||||
def __init__(self, tushare_client, top_n: int = 3):
|
||||
"""
|
||||
初始化筛选器
|
||||
|
||||
Args:
|
||||
tushare_client: TushareClient 实例
|
||||
top_n: 返回前 N 只龙头股
|
||||
"""
|
||||
self.top_n = top_n
|
||||
self.ts_client = tushare_client
|
||||
|
||||
def select_leading_stocks(self, ts_code: str, sector_name: str) -> List[Dict]:
|
||||
"""
|
||||
筛选板块龙头股
|
||||
|
||||
Args:
|
||||
ts_code: 板块指数代码
|
||||
sector_name: 板块名称
|
||||
|
||||
Returns:
|
||||
龙头股列表(已排序)
|
||||
"""
|
||||
try:
|
||||
# 获取成分股
|
||||
members_df = self.ts_client.get_sector_members(ts_code)
|
||||
if members_df.empty:
|
||||
logger.warning(f"获取板块 {sector_name} 成分股失败")
|
||||
return []
|
||||
|
||||
# ths_member 返回的是 con_code(成分股代码),需要用这个来查行情
|
||||
stock_codes = members_df['con_code'].tolist()
|
||||
|
||||
# 限制数量,避免请求过多
|
||||
if len(stock_codes) > 50:
|
||||
stock_codes = stock_codes[:50]
|
||||
|
||||
# 获取实时行情
|
||||
realtime_df = self.ts_client.get_realtime_data(stock_codes)
|
||||
if realtime_df.empty:
|
||||
logger.warning(f"获取板块 {sector_name} 成分股行情失败")
|
||||
return []
|
||||
|
||||
# 获取每日指标(换手率、量比)
|
||||
from datetime import datetime
|
||||
trade_date = datetime.now().strftime('%Y%m%d')
|
||||
basic_df = self.ts_client.get_stock_daily_basic(stock_codes, trade_date)
|
||||
|
||||
# 合并数据 - 注意:ths_member 的 con_code 对应 daily 的 ts_code
|
||||
members_df = members_df.rename(columns={'con_code': 'stock_code'})
|
||||
realtime_df = realtime_df.rename(columns={'ts_code': 'stock_code'})
|
||||
|
||||
if not basic_df.empty:
|
||||
basic_df = basic_df.rename(columns={'ts_code': 'stock_code'})
|
||||
merged = pd.merge(
|
||||
members_df[['stock_code', 'con_name']],
|
||||
realtime_df,
|
||||
on='stock_code',
|
||||
how='inner'
|
||||
)
|
||||
merged = pd.merge(
|
||||
merged,
|
||||
basic_df[['stock_code', 'turnover_rate', 'volume_ratio']],
|
||||
on='stock_code',
|
||||
how='left'
|
||||
)
|
||||
else:
|
||||
merged = pd.merge(
|
||||
members_df[['stock_code', 'con_name']],
|
||||
realtime_df,
|
||||
on='stock_code',
|
||||
how='inner'
|
||||
)
|
||||
|
||||
if merged.empty:
|
||||
return []
|
||||
|
||||
# 数据类型转换 - daily 接口返回 pct_chg 不是 pct_change
|
||||
merged['close'] = pd.to_numeric(merged['close'], errors='coerce')
|
||||
merged['pct_chg'] = pd.to_numeric(merged['pct_chg'], errors='coerce')
|
||||
merged['change'] = pd.to_numeric(merged['change'], errors='coerce')
|
||||
merged['vol'] = pd.to_numeric(merged['vol'], errors='coerce')
|
||||
# 注意:daily 接口的 amount 单位是千元,需要转换为元
|
||||
merged['amount'] = pd.to_numeric(merged['amount'], errors='coerce') * 1000
|
||||
|
||||
# 换手率和量比填充默认值
|
||||
if 'turnover_rate' in merged.columns:
|
||||
merged['turnover_rate'] = pd.to_numeric(merged['turnover_rate'], errors='coerce').fillna(0)
|
||||
else:
|
||||
merged['turnover_rate'] = 0.0
|
||||
|
||||
if 'volume_ratio' in merged.columns:
|
||||
merged['volume_ratio'] = pd.to_numeric(merged['volume_ratio'], errors='coerce').fillna(1.0)
|
||||
else:
|
||||
merged['volume_ratio'] = 1.0
|
||||
|
||||
# 过滤:只保留有成交额的股票
|
||||
merged = merged[merged['amount'] > 0].copy()
|
||||
|
||||
if merged.empty:
|
||||
return []
|
||||
|
||||
# 计算综合评分
|
||||
merged['score'] = merged.apply(self._calculate_score, axis=1)
|
||||
|
||||
# 排序:按综合得分
|
||||
merged = merged.sort_values('score', ascending=False)
|
||||
|
||||
# 取前 N 只
|
||||
top_stocks = merged.head(self.top_n)
|
||||
|
||||
# 转换结果
|
||||
results = []
|
||||
for _, row in top_stocks.iterrows():
|
||||
# 计算涨速等级
|
||||
change_pct = row['pct_chg']
|
||||
if change_pct >= 5:
|
||||
speed_level = "⚡⚡⚡ 极快"
|
||||
elif change_pct >= 3:
|
||||
speed_level = "⚡⚡ 快速"
|
||||
elif change_pct >= 1:
|
||||
speed_level = "⚡ 较快"
|
||||
else:
|
||||
speed_level = "🐌 平稳"
|
||||
|
||||
# 计算振幅
|
||||
amplitude = 0.0
|
||||
if 'high' in row and 'low' in row and row['low'] > 0:
|
||||
amplitude = (row['high'] - row['low']) / row['low'] * 100
|
||||
|
||||
results.append({
|
||||
'code': row['stock_code'],
|
||||
'name': row['con_name'],
|
||||
'price': float(row['close']),
|
||||
'change_pct': float(row['pct_chg']),
|
||||
'change_amount': float(row['change']),
|
||||
'amount': float(row['amount']),
|
||||
'turnover': float(row.get('turnover_rate', 0)),
|
||||
'volume_ratio': float(row.get('volume_ratio', 1.0)),
|
||||
'amplitude': amplitude,
|
||||
'score': float(row['score']),
|
||||
'speed_level': speed_level,
|
||||
})
|
||||
|
||||
logger.info(f"板块 {sector_name} 龙头股筛选完成,Top {len(results)}")
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"筛选龙头股失败 {sector_name}: {e}")
|
||||
return []
|
||||
|
||||
def _calculate_score(self, row: pd.Series) -> float:
|
||||
"""
|
||||
计算综合得分
|
||||
|
||||
评分维度:
|
||||
- 涨跌幅 (40%)
|
||||
- 成交额 (30%)
|
||||
- 涨速 (20%)
|
||||
- 换手率 (10%)
|
||||
|
||||
Args:
|
||||
row: 股票数据行
|
||||
|
||||
Returns:
|
||||
综合得分
|
||||
"""
|
||||
score = 0.0
|
||||
|
||||
# 1. 涨跌幅得分 (40分) - 涨幅越高得分越高
|
||||
change_pct = row['pct_chg']
|
||||
if change_pct >= 7:
|
||||
score += 40 # 涨停级别
|
||||
elif change_pct >= 5:
|
||||
score += 35
|
||||
elif change_pct >= 3:
|
||||
score += 30
|
||||
elif change_pct >= 2:
|
||||
score += 25
|
||||
elif change_pct >= 1:
|
||||
score += 20
|
||||
elif change_pct > 0:
|
||||
score += 15
|
||||
else:
|
||||
score += max(0, 10 + change_pct * 5) # 下跌也有基础分
|
||||
|
||||
# 2. 成交额得分 (30分) - 成交额越大得分越高
|
||||
# 注意:amount 已在 select_leading_stocks 中从千元转换为元
|
||||
amount = row['amount'] # 单位是元
|
||||
if amount >= 1000000000: # 10亿以上
|
||||
score += 30
|
||||
elif amount >= 500000000: # 5亿以上
|
||||
score += 25
|
||||
elif amount >= 100000000: # 1亿以上
|
||||
score += 20
|
||||
elif amount >= 50000000: # 5000万以上
|
||||
score += 15
|
||||
elif amount >= 10000000: # 1000万以上
|
||||
score += 10
|
||||
else:
|
||||
score += 5
|
||||
|
||||
# 3. 涨速得分 (20分) - 简化用涨幅代替
|
||||
if change_pct >= 5:
|
||||
score += 20
|
||||
elif change_pct >= 3:
|
||||
score += 15
|
||||
elif change_pct >= 1:
|
||||
score += 10
|
||||
else:
|
||||
score += 5
|
||||
|
||||
# 4. 换手率得分 (10分) - 使用真实换手率数据
|
||||
turnover_rate = row.get('turnover_rate', 0)
|
||||
if turnover_rate >= 15:
|
||||
score += 10 # 换手率极高,资金活跃
|
||||
elif turnover_rate >= 10:
|
||||
score += 9
|
||||
elif turnover_rate >= 7:
|
||||
score += 8
|
||||
elif turnover_rate >= 5:
|
||||
score += 7
|
||||
elif turnover_rate >= 3:
|
||||
score += 6
|
||||
elif turnover_rate >= 1:
|
||||
score += 4
|
||||
elif turnover_rate >= 0.5:
|
||||
score += 2
|
||||
else:
|
||||
score += 1 # 换手率较低
|
||||
|
||||
return score
|
||||
@ -3,10 +3,11 @@
|
||||
从环境变量加载配置
|
||||
"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from pydantic_settings import BaseSettings
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
# 查找.env文件的位置
|
||||
@ -48,9 +49,6 @@ def find_env_file():
|
||||
class Settings(BaseSettings):
|
||||
"""应用配置"""
|
||||
|
||||
# Tushare配置
|
||||
tushare_token: str = ""
|
||||
|
||||
# LLM配置
|
||||
zhipuai_api_key: str = ""
|
||||
deepseek_api_key: str = ""
|
||||
@ -94,9 +92,7 @@ class Settings(BaseSettings):
|
||||
binance_api_secret: str = ""
|
||||
|
||||
# 飞书机器人配置
|
||||
feishu_crypto_webhook_url: str = "https://open.feishu.cn/open-apis/bot/v2/hook/8a1dcf69-6753-41e2-a393-edc4f7822db0" # 加密货币通知
|
||||
feishu_stock_webhook_url: str = "https://open.feishu.cn/open-apis/bot/v2/hook/408ab727-0dcd-4c7a-bde7-4aad38cbf807" # 股票通知
|
||||
feishu_news_webhook_url: str = "https://open.feishu.cn/open-apis/bot/v2/hook/c7fd0db7-d295-451c-b943-130278a6cd9d" # 新闻智能体通知
|
||||
feishu_crypto_webhook_url: str = "https://open.feishu.cn/open-apis/bot/v2/hook/8a1dcf69-6753-41e2-a393-edc4f7822db0"
|
||||
feishu_paper_trading_webhook_url: str = "https://open.feishu.cn/open-apis/bot/v2/hook/3f5642e7-420b-45f7-8f88-fff92bb98c69" # 模拟交易通知(交易信号+决策+执行)
|
||||
feishu_error_webhook_url: str = "https://open.feishu.cn/open-apis/bot/v2/hook/ba6952c9-3b0c-4bc1-8a43-ceaacb27b043" # 系统异常通知
|
||||
feishu_enabled: bool = True # 是否启用飞书通知
|
||||
@ -179,43 +175,61 @@ class Settings(BaseSettings):
|
||||
account_drawdown_alert: float = 0.15 # 回撤警告阈值(15%),触发告警通知
|
||||
|
||||
# Agent 模型配置 (可选值: zhipu, deepseek)
|
||||
smart_agent_model: str = "deepseek" # SmartAgent 使用的模型
|
||||
crypto_agent_model: str = "deepseek" # CryptoAgent 使用的模型
|
||||
stock_agent_model: str = "deepseek" # StockAgent 使用的模型
|
||||
|
||||
# 股票智能体配置
|
||||
stock_symbols_us: str = "" # 美股代码,逗号分隔
|
||||
# 港股代码:科技+新能源+芯片+AI+金融+汽车+医药+消费+能源(统一格式:去掉前导零)
|
||||
# 科技:腾讯控股/阿里巴巴/美团/小米集团/京东集团/网易/百度/快手/知乎/B站
|
||||
# 新能源:比亚迪/理想汽车/小鹏汽车/赣锋锂业/龙源电力/信义能源
|
||||
# 芯片:中芯国际/华虹半导体/上海复旦
|
||||
# AI:商汤/第四范式/创新奇智/美图/联易融/百融云
|
||||
# 金融:汇控/建行/工行/农行/中行/友邦/平安/国寿/中金/中信
|
||||
# 汽车:蔚来/长城汽车/吉利汽车
|
||||
# 医药:药明康德/药明生物/百济神州/信达生物/石药集团
|
||||
# 消费:名创优品/泡泡玛特/安踏体育
|
||||
# 能源:中海油/中石油/中国神华
|
||||
stock_symbols_hk: str = "700.HK,9988.HK,3690.HK,1810.HK,9618.HK,9999.HK,9888.HK,1024.HK,2390.HK,9626.HK,1211.HK,2015.HK,9868.HK,1772.HK,916.HK,3868.HK,981.HK,1347.HK,1385.HK,20.HK,6682.HK,2121.HK,1357.HK,9959.HK,6608.HK,5.HK,939.HK,1398.HK,1288.HK,3988.HK,1299.HK,2318.HK,2628.HK,3908.HK,6030.HK,9866.HK,2333.HK,175.HK,2359.HK,2269.HK,6160.HK,1801.HK,1093.HK,9896.HK,9992.HK,2020.HK,883.HK,857.HK,1088.HK"
|
||||
# 注意:实际执行为每小时整点,此配置已废弃
|
||||
stock_analysis_interval: int = 3600 # 分析间隔(秒,整点执行)
|
||||
stock_llm_threshold: float = 0.70 # 触发 LLM 分析的置信度阈值
|
||||
|
||||
# A股智能体配置
|
||||
astock_monitor_enabled: bool = True # 是否启用A股智能体
|
||||
astock_change_threshold: float = 2.0 # 涨跌幅阈值(%),超过此值触发异动
|
||||
astock_top_n: int = 3 # 每个板块返回前N只龙头股
|
||||
astock_check_interval: int = 30 # 检查间隔(分钟)
|
||||
# 钉钉通知配置(A股专用)
|
||||
dingtalk_astock_webhook: str = "" # A股钉钉通知 Webhook
|
||||
dingtalk_astock_secret: str = "" # A股钉钉通知加签密钥
|
||||
|
||||
# A股龙回头选股配置
|
||||
pullback_selector_enabled: bool = True # 是否启用龙回头选股
|
||||
pullback_select_time: str = "09:00" # 选股时间(24小时制)
|
||||
pullback_sectors_to_check: int = 5 # 检查板块数量
|
||||
crypto_agent_model: str = "deepseek"
|
||||
|
||||
# ========== Bitget 实盘交易配置 ==========
|
||||
bitget_trading_enabled: bool = False # Bitget 实盘交易开关(默认关闭)
|
||||
bitget_accounts: str = "" # 多账号列表,例如: "main,sub1"
|
||||
|
||||
def get_bitget_account_ids(self) -> List[str]:
|
||||
"""返回已启用的 Bitget 账号列表,未配置时兼容 default 单账号。"""
|
||||
raw = str(self.bitget_accounts or "").strip()
|
||||
if raw:
|
||||
account_ids = [item.strip() for item in raw.split(',') if item.strip()]
|
||||
if account_ids:
|
||||
return list(dict.fromkeys(account_ids))
|
||||
return ['default']
|
||||
|
||||
def get_bitget_account_config(self, account_id: str = "default") -> Dict[str, Any]:
|
||||
"""获取指定 Bitget 账号配置,兼容单账号与多账号命名。"""
|
||||
normalized = (account_id or "default").strip() or "default"
|
||||
if normalized == "default":
|
||||
return {
|
||||
"account_id": "default",
|
||||
"api_key": self.bitget_api_key,
|
||||
"api_secret": self.bitget_api_secret,
|
||||
"passphrase": self.bitget_passphrase,
|
||||
"enabled": bool(self.bitget_trading_enabled and self.bitget_api_key and self.bitget_api_secret),
|
||||
"use_testnet": self.bitget_use_testnet,
|
||||
"use_unified_account": self.bitget_use_unified_account,
|
||||
}
|
||||
|
||||
prefix = f"bitget_{normalized}_"
|
||||
api_key = getattr(self, f"{prefix}api_key", "")
|
||||
api_secret = getattr(self, f"{prefix}api_secret", "")
|
||||
passphrase = getattr(self, f"{prefix}passphrase", "")
|
||||
enabled_value = getattr(self, f"{prefix}enabled", True)
|
||||
use_testnet = getattr(self, f"{prefix}use_testnet", self.bitget_use_testnet)
|
||||
use_unified = getattr(self, f"{prefix}use_unified_account", self.bitget_use_unified_account)
|
||||
|
||||
return {
|
||||
"account_id": normalized,
|
||||
"api_key": api_key,
|
||||
"api_secret": api_secret,
|
||||
"passphrase": passphrase,
|
||||
"enabled": bool(self.bitget_trading_enabled and enabled_value and api_key and api_secret),
|
||||
"use_testnet": use_testnet,
|
||||
"use_unified_account": use_unified,
|
||||
}
|
||||
|
||||
def get_enabled_bitget_accounts(self) -> List[Dict[str, Any]]:
|
||||
"""返回所有已启用且凭证完整的 Bitget 账号配置。"""
|
||||
configs: List[Dict[str, Any]] = []
|
||||
for account_id in self.get_bitget_account_ids():
|
||||
config = self.get_bitget_account_config(account_id)
|
||||
if config.get("enabled"):
|
||||
configs.append(config)
|
||||
return configs
|
||||
|
||||
class Config:
|
||||
env_file = find_env_file()
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
加密货币交易智能体模块
|
||||
"""
|
||||
from app.crypto_agent.crypto_agent import CryptoAgent
|
||||
from app.crypto_agent.execution_guardian import ExecutionGuardian
|
||||
from app.crypto_agent.strategy import TrendFollowingStrategy
|
||||
|
||||
__all__ = ['CryptoAgent', 'TrendFollowingStrategy']
|
||||
__all__ = ['CryptoAgent', 'ExecutionGuardian', 'TrendFollowingStrategy']
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
361
backend/app/crypto_agent/execution_guardian.py
Normal file
361
backend/app/crypto_agent/execution_guardian.py
Normal file
@ -0,0 +1,361 @@
|
||||
"""
|
||||
执行监管器
|
||||
|
||||
从 CryptoAgent 主循环中拆分执行后监管职责,负责:
|
||||
- 挂单超时清理
|
||||
- 持仓管理(止盈 / 超时退出 / 移动止损)
|
||||
- Bitget 挂单成交后的 TP/SL 补设
|
||||
- Bitget 持仓保护单缺失补救
|
||||
|
||||
第一版先作为确定性协调器运行,不引入新的 LLM 决策。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from app.crypto_agent.execution_targets import ExecutionTarget
|
||||
from app.utils.logger import logger
|
||||
|
||||
|
||||
class ExecutionGuardian:
|
||||
"""执行监管协调器。"""
|
||||
|
||||
def __init__(self, agent: Any):
|
||||
self.agent = agent
|
||||
self._state: Dict[str, Any] = {
|
||||
"last_run_at": None,
|
||||
"last_status": "idle",
|
||||
"last_error": "",
|
||||
"last_actions": [],
|
||||
}
|
||||
|
||||
def get_status(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"last_run_at": self._state.get("last_run_at"),
|
||||
"last_status": self._state.get("last_status", "idle"),
|
||||
"last_error": self._state.get("last_error", ""),
|
||||
"targets": [self._serialize_target(target) for target in self._iter_targets()],
|
||||
"last_actions": list(self._state.get("last_actions", []))[:20],
|
||||
}
|
||||
|
||||
def _serialize_target(self, target: ExecutionTarget) -> Dict[str, Any]:
|
||||
return {
|
||||
"target_key": target.target_key,
|
||||
"platform": target.platform,
|
||||
"account_id": target.account_id,
|
||||
"supports_pending_timeout": target.supports_pending_timeout,
|
||||
"supports_position_management": target.supports_position_management,
|
||||
"supports_tpsl_repair": target.supports_tpsl_repair,
|
||||
}
|
||||
|
||||
def _iter_targets(self) -> List[ExecutionTarget]:
|
||||
targets = self.agent.get_execution_targets()
|
||||
if not isinstance(targets, list):
|
||||
return []
|
||||
return targets
|
||||
|
||||
async def run_cycle(self):
|
||||
"""执行一轮监管扫描。"""
|
||||
self._state["last_run_at"] = datetime.now().isoformat()
|
||||
self._state["last_status"] = "running"
|
||||
self._state["last_error"] = ""
|
||||
self._state["last_actions"] = []
|
||||
|
||||
try:
|
||||
for target in self._iter_targets():
|
||||
if self.agent._is_platform_halted(target.target_key):
|
||||
continue
|
||||
|
||||
if target.supports_pending_timeout:
|
||||
await self._check_pending_order_timeouts(target)
|
||||
if target.supports_position_management:
|
||||
await self._check_position_management(target)
|
||||
if target.supports_tpsl_repair:
|
||||
await self._check_and_set_pending_tp_sl(target)
|
||||
await self._check_missing_tp_sl(target)
|
||||
|
||||
self._state["last_status"] = "completed"
|
||||
except Exception as e:
|
||||
self._state["last_status"] = "error"
|
||||
self._state["last_error"] = str(e)
|
||||
logger.error(f"ExecutionGuardian 运行异常: {e}")
|
||||
raise
|
||||
|
||||
def _record_action(self, action_type: str, platform: str, symbol: str = "", detail: str = ""):
|
||||
self._state.setdefault("last_actions", []).insert(0, {
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"action_type": action_type,
|
||||
"platform": platform,
|
||||
"symbol": symbol,
|
||||
"detail": detail,
|
||||
})
|
||||
self._state["last_actions"] = self._state["last_actions"][:20]
|
||||
|
||||
async def _check_pending_order_timeouts(self, target: ExecutionTarget):
|
||||
"""检查各平台挂单超时。"""
|
||||
pending_orders = []
|
||||
if target.platform == 'PaperTrading':
|
||||
pending_orders = target.service.get_open_orders()
|
||||
elif target.platform == 'Bitget':
|
||||
pending_orders = target.service.get_open_orders() if target.service else []
|
||||
|
||||
if not pending_orders:
|
||||
return
|
||||
|
||||
timeout_orders = target.executor.check_pending_order_timeout(pending_orders)
|
||||
for order_info in timeout_orders:
|
||||
order_id = order_info.get('order_id')
|
||||
symbol = order_info.get('symbol', '')
|
||||
reason = order_info.get('reason', '')
|
||||
|
||||
logger.info(f" ⏰ [{target.target_key}] {symbol} {reason}")
|
||||
result = await target.executor.execute_cancel(order_id, symbol)
|
||||
if result.get('success'):
|
||||
self._record_action("cancel_timeout", target.target_key, symbol, reason)
|
||||
logger.info(f" ✅ 已取消超时挂单: {order_id}")
|
||||
message = (
|
||||
f"⏰ 挂单超时自动取消\n\n"
|
||||
f"平台: {target.platform}\n"
|
||||
f"账户: {target.account_id}\n"
|
||||
f"交易对: {symbol}\n"
|
||||
f"订单ID: {order_id}\n"
|
||||
f"原因: {reason}"
|
||||
)
|
||||
await self.agent._send_alert_notification(f"⏰ [{target.target_key}] 挂单超时", message)
|
||||
else:
|
||||
error = result.get('error', '未知错误')
|
||||
logger.error(f" ❌ 取消失败: {error}")
|
||||
|
||||
async def _check_position_management(self, target: ExecutionTarget):
|
||||
"""检查各平台持仓管理(止盈/止损/移动止损)。"""
|
||||
current_prices = {}
|
||||
volatility_data = {}
|
||||
for symbol in self.agent.symbols:
|
||||
try:
|
||||
data = self.agent.exchange.get_multi_timeframe_data(symbol)
|
||||
current_prices[symbol] = float(data['5m'].iloc[-1]['close'])
|
||||
if '1h' in data and 'atr' in data['1h'].columns:
|
||||
atr_value = data['1h']['atr'].iloc[-1]
|
||||
price_1h = data['1h']['close'].iloc[-1]
|
||||
if atr_value and price_1h > 0:
|
||||
volatility_data[symbol] = float(atr_value) / float(price_1h)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if target.platform == 'PaperTrading':
|
||||
positions = target.service.get_open_positions()
|
||||
elif target.platform == 'Bitget':
|
||||
positions = target.service.get_open_positions() if target.service else []
|
||||
else:
|
||||
positions = []
|
||||
|
||||
if not positions:
|
||||
return
|
||||
|
||||
actions = target.executor.check_position_management(positions, current_prices, volatility_data)
|
||||
for action_info in actions:
|
||||
symbol = action_info.get('symbol')
|
||||
action = action_info.get('action')
|
||||
reason = action_info.get('reason', '')
|
||||
|
||||
logger.info(f" 📊 [{target.target_key}] {symbol} {reason}")
|
||||
|
||||
if action in {'TAKE_PROFIT', 'TIME_EXIT'}:
|
||||
normalized_symbol = self.agent._normalize_symbol(symbol)
|
||||
close_order_ids = [
|
||||
p.get('order_id') for p in positions
|
||||
if self.agent._normalize_symbol(p.get('symbol', '')) == normalized_symbol and p.get('order_id')
|
||||
]
|
||||
decision = {
|
||||
'decision': 'CLOSE',
|
||||
'symbol': normalized_symbol,
|
||||
'orders_to_close': close_order_ids,
|
||||
'reason': reason,
|
||||
}
|
||||
result = await target.executor.execute_close(decision, current_prices.get(symbol, 0))
|
||||
if result.get('success'):
|
||||
self._record_action(action.lower(), target.target_key, normalized_symbol, reason)
|
||||
title = "💰" if action == 'TAKE_PROFIT' else "⏰"
|
||||
text = "自动止盈" if action == 'TAKE_PROFIT' else "持仓超时平仓"
|
||||
await self.agent._send_alert_notification(
|
||||
f"{title} [{target.target_key}] {text}",
|
||||
f"交易对: {symbol}\n原因: {reason}"
|
||||
)
|
||||
|
||||
elif action == 'MOVE_SL':
|
||||
new_sl = action_info.get('new_sl')
|
||||
pnl_pct = action_info.get('pnl_pct', 0)
|
||||
if new_sl:
|
||||
move_result = await target.executor.move_stop_loss(symbol=symbol, new_stop_loss=new_sl)
|
||||
if move_result.get('success'):
|
||||
self._record_action("move_sl", target.target_key, symbol, f"new_sl={new_sl}")
|
||||
await self.agent._send_alert_notification(
|
||||
f"🔒 [{target.target_key}] 移动止损",
|
||||
f"交易对: {symbol}\n新止损: ${new_sl:.2f}\n原因: {reason}"
|
||||
)
|
||||
await target.executor.send_execution_notification(
|
||||
operation='POSITION_MANAGEMENT',
|
||||
symbol=symbol,
|
||||
result={'success': True, 'action': 'MOVE_SL', 'reason': reason},
|
||||
details={
|
||||
'new_sl': new_sl,
|
||||
'pnl_percent': pnl_pct,
|
||||
'account_id': target.account_id,
|
||||
'target_key': target.target_key,
|
||||
}
|
||||
)
|
||||
|
||||
async def _check_and_set_pending_tp_sl(self, target: ExecutionTarget):
|
||||
"""检查 Bitget 挂单是否已成交,若成交则补设止盈止损。"""
|
||||
if target.platform != 'Bitget':
|
||||
return
|
||||
pending_state = self.agent._get_pending_tp_sl_state(target.pending_tpsl_state_key or target.target_key)
|
||||
if not pending_state:
|
||||
return
|
||||
|
||||
for order_id, info in list(pending_state.items()):
|
||||
symbol = self.agent._normalize_symbol(info['symbol'])
|
||||
coin = symbol.replace('USDT', '')
|
||||
open_orders = target.service.get_open_orders(symbol)
|
||||
still_open = any(str(o.get('order_id')) == order_id for o in open_orders)
|
||||
if still_open:
|
||||
continue
|
||||
|
||||
position = target.service.get_position_for_symbol(coin)
|
||||
if not position:
|
||||
logger.info(f"[{target.target_key}] 挂单追踪 {order_id} 已结束:{symbol} 无持仓,移除待补设任务")
|
||||
self._record_action("cleanup_pending_tpsl", target.target_key, symbol, f"order_id={order_id}")
|
||||
del pending_state[order_id]
|
||||
continue
|
||||
|
||||
tp_price = info.get('tp_price')
|
||||
sl_price = info.get('sl_price')
|
||||
logger.info(f"[{target.target_key}] 挂单 {order_id} ({symbol}) 已成交,补设 TP/SL...")
|
||||
tp_sl_result = target.service.set_tp_sl(
|
||||
symbol=coin,
|
||||
is_long=position.get('size', 0) > 0,
|
||||
size=abs(position.get('size', 0)),
|
||||
tp_price=tp_price,
|
||||
sl_price=sl_price,
|
||||
)
|
||||
info['retry_count'] = int(info.get('retry_count', 0)) + 1
|
||||
tp_set = tp_sl_result.get('tp_set', False)
|
||||
sl_set = tp_sl_result.get('sl_set', False)
|
||||
|
||||
if tp_set and sl_set:
|
||||
self._record_action("repair_tpsl", target.target_key, symbol, f"order_id={order_id}")
|
||||
logger.info(f"[{target.target_key}] ✅ TP/SL 补设成功: {symbol} TP={tp_price} SL={sl_price}")
|
||||
del pending_state[order_id]
|
||||
continue
|
||||
|
||||
if tp_set or sl_set:
|
||||
missing_tp = tp_price if not tp_set else None
|
||||
missing_sl = sl_price if not sl_set else None
|
||||
pending_state[order_id] = self.agent._build_pending_tp_sl_task(
|
||||
symbol=symbol,
|
||||
is_long=position.get('size', 0) > 0,
|
||||
size=abs(position.get('size', 0)),
|
||||
tp_price=missing_tp,
|
||||
sl_price=missing_sl,
|
||||
retry_count=info.get('retry_count', 0),
|
||||
first_seen_at=info.get('first_seen_at'),
|
||||
last_alert_at=info.get('last_alert_at'),
|
||||
)
|
||||
set_text = "TP" if tp_set else "SL"
|
||||
fail_text = "TP" if not tp_set else "SL"
|
||||
await self.agent._maybe_alert_tp_sl_incomplete(
|
||||
target.target_key,
|
||||
order_id,
|
||||
pending_state[order_id],
|
||||
f"{set_text}已设,{fail_text}补设失败",
|
||||
)
|
||||
continue
|
||||
|
||||
await self.agent._maybe_alert_tp_sl_incomplete(
|
||||
target.target_key,
|
||||
order_id,
|
||||
info,
|
||||
str(tp_sl_result.get('errors') or 'TP/SL补设失败'),
|
||||
)
|
||||
|
||||
async def _check_missing_tp_sl(self, target: ExecutionTarget):
|
||||
"""定时检查 Bitget 持仓是否缺少止盈止损,缺少则从信号补救。"""
|
||||
if target.platform != 'Bitget' or not target.service:
|
||||
return
|
||||
|
||||
positions = target.service.get_open_positions()
|
||||
if not positions:
|
||||
return
|
||||
|
||||
for pos in positions:
|
||||
symbol = pos.get('symbol', '')
|
||||
if not symbol:
|
||||
continue
|
||||
|
||||
coin = symbol.replace('USDT', '')
|
||||
tp_sl = target.service.get_tp_sl_prices(coin)
|
||||
has_tp = tp_sl.get('take_profit') is not None
|
||||
has_sl = tp_sl.get('stop_loss') is not None
|
||||
if has_tp and has_sl:
|
||||
continue
|
||||
|
||||
latest_signal = self.agent.signal_db.get_latest_signal('crypto', symbol)
|
||||
if not latest_signal:
|
||||
missing = ('止盈' if not has_tp else '') + ('/' if not has_tp and not has_sl else '') + ('止损' if not has_sl else '')
|
||||
logger.warning(f"[{target.target_key}] ⚠️ {symbol} 缺少{missing},且无历史信号可补救")
|
||||
continue
|
||||
|
||||
tp_price = latest_signal.get('take_profit')
|
||||
sl_price = latest_signal.get('stop_loss')
|
||||
if not tp_price and not sl_price:
|
||||
logger.warning(f"[{target.target_key}] ⚠️ {symbol} 缺少止盈止损,最近信号也无 TP/SL")
|
||||
continue
|
||||
|
||||
set_tp = tp_price if not has_tp else None
|
||||
set_sl = sl_price if not has_sl else None
|
||||
missing_parts = []
|
||||
if not has_tp:
|
||||
missing_parts.append(f"TP={set_tp}")
|
||||
if not has_sl:
|
||||
missing_parts.append(f"SL={set_sl}")
|
||||
missing_desc = ' & '.join(missing_parts)
|
||||
|
||||
logger.warning(f"[{target.target_key}] 🔧 {symbol} 缺少 {missing_desc},从信号补救...")
|
||||
size = abs(pos.get('size', 0))
|
||||
if size <= 0:
|
||||
continue
|
||||
|
||||
tp_sl_result = target.service.set_tp_sl(
|
||||
symbol=coin,
|
||||
is_long=pos.get('size', 0) > 0,
|
||||
size=size,
|
||||
tp_price=set_tp,
|
||||
sl_price=set_sl,
|
||||
)
|
||||
|
||||
tp_set = tp_sl_result.get('tp_set', False)
|
||||
sl_set = tp_sl_result.get('sl_set', False)
|
||||
if tp_set or sl_set:
|
||||
self._record_action("fallback_tpsl", target.target_key, symbol, missing_desc)
|
||||
set_parts = []
|
||||
if tp_set:
|
||||
set_parts.append(f"TP={set_tp}")
|
||||
if sl_set:
|
||||
set_parts.append(f"SL={set_sl}")
|
||||
logger.info(f"[{target.target_key}] ✅ 补救成功: {symbol} {' & '.join(set_parts)}")
|
||||
else:
|
||||
await self.agent._maybe_alert_tp_sl_incomplete(
|
||||
target.target_key,
|
||||
f"{target.target_key}:fallback:{symbol}",
|
||||
self.agent._build_pending_tp_sl_task(
|
||||
symbol=coin,
|
||||
is_long=pos.get('size', 0) > 0,
|
||||
size=size,
|
||||
tp_price=set_tp,
|
||||
sl_price=set_sl,
|
||||
retry_count=self.agent.TP_SL_RETRY_ALERT_THRESHOLD,
|
||||
),
|
||||
str(tp_sl_result.get('errors') or '兜底补设失败'),
|
||||
force=True,
|
||||
)
|
||||
60
backend/app/crypto_agent/execution_targets.py
Normal file
60
backend/app/crypto_agent/execution_targets.py
Normal file
@ -0,0 +1,60 @@
|
||||
"""
|
||||
执行监管目标定义与默认注册工厂。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutionTarget:
|
||||
"""执行监管目标:平台 + 账户。"""
|
||||
target_key: str
|
||||
platform: str
|
||||
account_id: str
|
||||
service: Any
|
||||
executor: Any
|
||||
supports_pending_timeout: bool = True
|
||||
supports_position_management: bool = True
|
||||
supports_tpsl_repair: bool = False
|
||||
pending_tpsl_state_key: Optional[str] = None
|
||||
|
||||
|
||||
def build_default_execution_targets(agent: Any) -> List[ExecutionTarget]:
|
||||
"""根据当前 agent 已启用的平台生成默认执行监管目标。"""
|
||||
targets: List[ExecutionTarget] = []
|
||||
|
||||
paper_executor = agent.executors.get('PaperTrading')
|
||||
if getattr(agent, 'paper_trading', None) and paper_executor:
|
||||
targets.append(ExecutionTarget(
|
||||
target_key="PaperTrading",
|
||||
platform="PaperTrading",
|
||||
account_id="default",
|
||||
service=agent.paper_trading,
|
||||
executor=paper_executor,
|
||||
supports_pending_timeout=True,
|
||||
supports_position_management=True,
|
||||
supports_tpsl_repair=False,
|
||||
))
|
||||
|
||||
bitget_services = getattr(agent, 'bitget_services', {}) or {}
|
||||
bitget_executors = getattr(agent, 'bitget_executors', {}) or {}
|
||||
for account_id, service in bitget_services.items():
|
||||
executor = bitget_executors.get(account_id)
|
||||
if not service or not executor:
|
||||
continue
|
||||
target_key = f"Bitget:{account_id}"
|
||||
targets.append(ExecutionTarget(
|
||||
target_key=target_key,
|
||||
platform="Bitget",
|
||||
account_id=account_id,
|
||||
service=service,
|
||||
executor=executor,
|
||||
supports_pending_timeout=True,
|
||||
supports_position_management=True,
|
||||
supports_tpsl_repair=True,
|
||||
pending_tpsl_state_key=target_key,
|
||||
))
|
||||
|
||||
return targets
|
||||
@ -20,6 +20,7 @@ class BaseExecutor(ABC):
|
||||
|
||||
def __init__(self, platform_name: str):
|
||||
self.platform_name = platform_name
|
||||
self.account_id = "default"
|
||||
|
||||
# 初始化飞书通知服务
|
||||
try:
|
||||
@ -523,6 +524,36 @@ class BaseExecutor(ABC):
|
||||
|
||||
# ==================== 飞书通知 ====================
|
||||
|
||||
def _normalize_notification_context(self,
|
||||
details: Optional[Dict[str, Any]] = None,
|
||||
account_id: str = "default",
|
||||
target_key: str = "") -> Dict[str, str]:
|
||||
resolved_account_id = str(account_id or getattr(self, 'account_id', 'default') or 'default')
|
||||
resolved_target_key = target_key or (
|
||||
f"{self.platform_name}:{resolved_account_id}" if resolved_account_id and resolved_account_id != "default" else self.platform_name
|
||||
)
|
||||
return {
|
||||
"account_id": resolved_account_id,
|
||||
"target_key": resolved_target_key,
|
||||
"platform_label": self.platform_name,
|
||||
}
|
||||
|
||||
def _append_notification_detail(self, content_parts: List[str], label: str, value: Any):
|
||||
if value is None or value == "":
|
||||
return
|
||||
content_parts.append(f"**{label}**: {value}")
|
||||
|
||||
def _build_notification_header(self,
|
||||
symbol: str,
|
||||
account_id: str,
|
||||
target_key: str) -> List[str]:
|
||||
return [
|
||||
f"**执行目标**: {target_key}",
|
||||
f"**平台**: {self.platform_name}",
|
||||
f"**账号**: {account_id}",
|
||||
f"**交易对**: {symbol}",
|
||||
]
|
||||
|
||||
async def send_execution_notification(self,
|
||||
operation: str,
|
||||
symbol: str,
|
||||
@ -541,24 +572,26 @@ class BaseExecutor(ABC):
|
||||
return
|
||||
|
||||
try:
|
||||
success = result.get('success', False)
|
||||
order_id = result.get('order_id', '')
|
||||
error_msg = result.get('error', result.get('message', ''))
|
||||
details = dict(details or {})
|
||||
account_id = details.get('account_id') or getattr(self, 'account_id', 'default')
|
||||
target_key = details.get('target_key') or (
|
||||
f"{self.platform_name}:{account_id}" if account_id and account_id != "default" else self.platform_name
|
||||
)
|
||||
|
||||
# 根据操作类型选择通知方法
|
||||
if operation == 'OPEN':
|
||||
await self._send_open_notification(symbol, result, details)
|
||||
await self._send_open_notification(symbol, result, details, account_id, target_key)
|
||||
elif operation == 'CLOSE':
|
||||
await self._send_close_notification(symbol, result, details)
|
||||
await self._send_close_notification(symbol, result, details, account_id, target_key)
|
||||
elif operation == 'CANCEL':
|
||||
await self._send_cancel_notification(symbol, result, details)
|
||||
await self._send_cancel_notification(symbol, result, details, account_id, target_key)
|
||||
elif operation == 'TP_SL':
|
||||
await self._send_tp_sl_notification(symbol, result, details)
|
||||
await self._send_tp_sl_notification(symbol, result, details, account_id, target_key)
|
||||
elif operation == 'POSITION_MANAGEMENT':
|
||||
await self._send_position_management_notification(symbol, result, details)
|
||||
await self._send_position_management_notification(symbol, result, details, account_id, target_key)
|
||||
else:
|
||||
# 通用通知
|
||||
await self._send_generic_notification(operation, symbol, result, details)
|
||||
await self._send_generic_notification(operation, symbol, result, details, account_id, target_key)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.platform_name}] 发送执行通知失败: {e}")
|
||||
@ -566,7 +599,9 @@ class BaseExecutor(ABC):
|
||||
async def _send_open_notification(self,
|
||||
symbol: str,
|
||||
result: Dict[str, Any],
|
||||
details: Optional[Dict[str, Any]] = None):
|
||||
details: Optional[Dict[str, Any]] = None,
|
||||
account_id: str = "default",
|
||||
target_key: str = ""):
|
||||
"""发送开仓通知"""
|
||||
success = result.get('success', False)
|
||||
order_id = result.get('order_id', '')
|
||||
@ -574,45 +609,39 @@ class BaseExecutor(ABC):
|
||||
|
||||
if success:
|
||||
# 成功开仓
|
||||
title = f"✅ [{self.platform_name}] 开仓成功 - {symbol}"
|
||||
title = f"✅ [{target_key or self.platform_name}] 开仓成功 - {symbol}"
|
||||
|
||||
content_parts = [
|
||||
f"**平台**: {self.platform_name}",
|
||||
f"**交易对**: {symbol}",
|
||||
f"**订单ID**: {order_id}",
|
||||
]
|
||||
content_parts = self._build_notification_header(symbol, account_id, target_key)
|
||||
self._append_notification_detail(content_parts, "订单ID", order_id)
|
||||
|
||||
# 添加详情
|
||||
if details:
|
||||
if 'size' in details:
|
||||
content_parts.append(f"**数量**: {details['size']}")
|
||||
if 'price' in details:
|
||||
content_parts.append(f"**价格**: ${details['price']:,.2f}")
|
||||
if 'margin' in details:
|
||||
content_parts.append(f"**保证金**: ${details['margin']:,.2f}")
|
||||
if 'leverage' in details:
|
||||
content_parts.append(f"**杠杆**: {details['leverage']}x")
|
||||
if 'stop_loss' in details and details['stop_loss']:
|
||||
content_parts.append(f"**止损**: ${details['stop_loss']:,.2f}")
|
||||
if 'take_profit' in details and details['take_profit']:
|
||||
content_parts.append(f"**止盈**: ${details['take_profit']:,.2f}")
|
||||
if 'order_type' in details:
|
||||
content_parts.append(f"**订单类型**: {details['order_type']}")
|
||||
self._append_notification_detail(content_parts, "数量", details.get('size'))
|
||||
if details.get('price') is not None:
|
||||
self._append_notification_detail(content_parts, "价格", f"${details['price']:,.2f}")
|
||||
if details.get('margin') is not None:
|
||||
self._append_notification_detail(content_parts, "保证金", f"${details['margin']:,.2f}")
|
||||
if details.get('notional') is not None:
|
||||
self._append_notification_detail(content_parts, "名义仓位", f"${details['notional']:,.2f}")
|
||||
if details.get('leverage') is not None:
|
||||
self._append_notification_detail(content_parts, "杠杆", f"{details['leverage']}x")
|
||||
if details.get('stop_loss') is not None:
|
||||
self._append_notification_detail(content_parts, "止损", f"${details['stop_loss']:,.2f}")
|
||||
if details.get('take_profit') is not None:
|
||||
self._append_notification_detail(content_parts, "止盈", f"${details['take_profit']:,.2f}")
|
||||
self._append_notification_detail(content_parts, "订单类型", details.get('order_type'))
|
||||
|
||||
content = "\n".join(content_parts)
|
||||
color = "green"
|
||||
else:
|
||||
# 开仓失败
|
||||
title = f"❌ [{self.platform_name}] 开仓失败 - {symbol}"
|
||||
title = f"❌ [{target_key or self.platform_name}] 开仓失败 - {symbol}"
|
||||
|
||||
content_parts = [
|
||||
f"**平台**: {self.platform_name}",
|
||||
f"**交易对**: {symbol}",
|
||||
f"**错误**: {error_msg}",
|
||||
]
|
||||
content_parts = self._build_notification_header(symbol, account_id, target_key)
|
||||
self._append_notification_detail(content_parts, "错误", error_msg)
|
||||
|
||||
if details and 'reason' in details:
|
||||
content_parts.append(f"**原因**: {details['reason']}")
|
||||
self._append_notification_detail(content_parts, "原因", details['reason'])
|
||||
|
||||
content = "\n".join(content_parts)
|
||||
color = "red"
|
||||
@ -622,39 +651,35 @@ class BaseExecutor(ABC):
|
||||
async def _send_close_notification(self,
|
||||
symbol: str,
|
||||
result: Dict[str, Any],
|
||||
details: Optional[Dict[str, Any]] = None):
|
||||
details: Optional[Dict[str, Any]] = None,
|
||||
account_id: str = "default",
|
||||
target_key: str = ""):
|
||||
"""发送平仓通知"""
|
||||
success = result.get('success', False)
|
||||
error_msg = result.get('error', result.get('message', ''))
|
||||
|
||||
if success:
|
||||
title = f"✅ [{self.platform_name}] 平仓成功 - {symbol}"
|
||||
title = f"✅ [{target_key or self.platform_name}] 平仓成功 - {symbol}"
|
||||
|
||||
content_parts = [
|
||||
f"**平台**: {self.platform_name}",
|
||||
f"**交易对**: {symbol}",
|
||||
]
|
||||
content_parts = self._build_notification_header(symbol, account_id, target_key)
|
||||
|
||||
if details:
|
||||
if 'pnl' in details:
|
||||
pnl = details['pnl']
|
||||
pnl_color = "盈利" if pnl >= 0 else "亏损"
|
||||
content_parts.append(f"**{pnl_color}**: ${pnl:,.2f}")
|
||||
self._append_notification_detail(content_parts, pnl_color, f"${pnl:,.2f}")
|
||||
if 'pnl_percent' in details:
|
||||
content_parts.append(f"**收益率**: {details['pnl_percent']:.2f}%")
|
||||
self._append_notification_detail(content_parts, "收益率", f"{details['pnl_percent']:.2f}%")
|
||||
if 'exit_reason' in details:
|
||||
content_parts.append(f"**平仓原因**: {details['exit_reason']}")
|
||||
self._append_notification_detail(content_parts, "平仓原因", details['exit_reason'])
|
||||
|
||||
content = "\n".join(content_parts)
|
||||
color = "green"
|
||||
else:
|
||||
title = f"❌ [{self.platform_name}] 平仓失败 - {symbol}"
|
||||
title = f"❌ [{target_key or self.platform_name}] 平仓失败 - {symbol}"
|
||||
|
||||
content_parts = [
|
||||
f"**平台**: {self.platform_name}",
|
||||
f"**交易对**: {symbol}",
|
||||
f"**错误**: {error_msg}",
|
||||
]
|
||||
content_parts = self._build_notification_header(symbol, account_id, target_key)
|
||||
self._append_notification_detail(content_parts, "错误", error_msg)
|
||||
|
||||
content = "\n".join(content_parts)
|
||||
color = "red"
|
||||
@ -664,35 +689,31 @@ class BaseExecutor(ABC):
|
||||
async def _send_cancel_notification(self,
|
||||
symbol: str,
|
||||
result: Dict[str, Any],
|
||||
details: Optional[Dict[str, Any]] = None):
|
||||
details: Optional[Dict[str, Any]] = None,
|
||||
account_id: str = "default",
|
||||
target_key: str = ""):
|
||||
"""发送撤单通知"""
|
||||
success = result.get('success', False)
|
||||
order_id = result.get('order_id', '')
|
||||
error_msg = result.get('error', result.get('message', ''))
|
||||
|
||||
if success:
|
||||
title = f"✅ [{self.platform_name}] 撤单成功 - {symbol}"
|
||||
title = f"✅ [{target_key or self.platform_name}] 撤单成功 - {symbol}"
|
||||
|
||||
content_parts = [
|
||||
f"**平台**: {self.platform_name}",
|
||||
f"**交易对**: {symbol}",
|
||||
f"**订单ID**: {order_id}",
|
||||
]
|
||||
content_parts = self._build_notification_header(symbol, account_id, target_key)
|
||||
self._append_notification_detail(content_parts, "订单ID", order_id)
|
||||
|
||||
if details and 'reason' in details:
|
||||
content_parts.append(f"**撤单原因**: {details['reason']}")
|
||||
self._append_notification_detail(content_parts, "撤单原因", details['reason'])
|
||||
|
||||
content = "\n".join(content_parts)
|
||||
color = "green"
|
||||
else:
|
||||
title = f"❌ [{self.platform_name}] 撤单失败 - {symbol}"
|
||||
title = f"❌ [{target_key or self.platform_name}] 撤单失败 - {symbol}"
|
||||
|
||||
content_parts = [
|
||||
f"**平台**: {self.platform_name}",
|
||||
f"**交易对**: {symbol}",
|
||||
f"**订单ID**: {order_id}",
|
||||
f"**错误**: {error_msg}",
|
||||
]
|
||||
content_parts = self._build_notification_header(symbol, account_id, target_key)
|
||||
self._append_notification_detail(content_parts, "订单ID", order_id)
|
||||
self._append_notification_detail(content_parts, "错误", error_msg)
|
||||
|
||||
content = "\n".join(content_parts)
|
||||
color = "red"
|
||||
@ -702,37 +723,33 @@ class BaseExecutor(ABC):
|
||||
async def _send_tp_sl_notification(self,
|
||||
symbol: str,
|
||||
result: Dict[str, Any],
|
||||
details: Optional[Dict[str, Any]] = None):
|
||||
details: Optional[Dict[str, Any]] = None,
|
||||
account_id: str = "default",
|
||||
target_key: str = ""):
|
||||
"""发送止盈止损设置通知"""
|
||||
success = result.get('success', False)
|
||||
message = result.get('message', '')
|
||||
|
||||
if success:
|
||||
title = f"✅ [{self.platform_name}] 止盈止损设置成功 - {symbol}"
|
||||
title = f"✅ [{target_key or self.platform_name}] 止盈止损设置成功 - {symbol}"
|
||||
|
||||
content_parts = [
|
||||
f"**平台**: {self.platform_name}",
|
||||
f"**交易对**: {symbol}",
|
||||
]
|
||||
content_parts = self._build_notification_header(symbol, account_id, target_key)
|
||||
|
||||
if details:
|
||||
if 'stop_loss' in details and details['stop_loss']:
|
||||
content_parts.append(f"**止损**: ${details['stop_loss']:,.2f}")
|
||||
if 'take_profit' in details and details['take_profit']:
|
||||
content_parts.append(f"**止盈**: ${details['take_profit']:,.2f}")
|
||||
if 'stop_loss' in details and details['stop_loss'] is not None:
|
||||
self._append_notification_detail(content_parts, "止损", f"${details['stop_loss']:,.2f}")
|
||||
if 'take_profit' in details and details['take_profit'] is not None:
|
||||
self._append_notification_detail(content_parts, "止盈", f"${details['take_profit']:,.2f}")
|
||||
if 'move_sl_reason' in details:
|
||||
content_parts.append(f"**移动止损**: {details['move_sl_reason']}")
|
||||
self._append_notification_detail(content_parts, "移动止损", details['move_sl_reason'])
|
||||
|
||||
content = "\n".join(content_parts)
|
||||
color = "green"
|
||||
else:
|
||||
title = f"⚠️ [{self.platform_name}] 止盈止损设置失败 - {symbol}"
|
||||
title = f"⚠️ [{target_key or self.platform_name}] 止盈止损设置失败 - {symbol}"
|
||||
|
||||
content_parts = [
|
||||
f"**平台**: {self.platform_name}",
|
||||
f"**交易对**: {symbol}",
|
||||
f"**错误**: {message}",
|
||||
]
|
||||
content_parts = self._build_notification_header(symbol, account_id, target_key)
|
||||
self._append_notification_detail(content_parts, "错误", message)
|
||||
|
||||
content = "\n".join(content_parts)
|
||||
color = "orange"
|
||||
@ -742,25 +759,24 @@ class BaseExecutor(ABC):
|
||||
async def _send_position_management_notification(self,
|
||||
symbol: str,
|
||||
result: Dict[str, Any],
|
||||
details: Optional[Dict[str, Any]] = None):
|
||||
details: Optional[Dict[str, Any]] = None,
|
||||
account_id: str = "default",
|
||||
target_key: str = ""):
|
||||
"""发送持仓管理通知"""
|
||||
action = result.get('action', '')
|
||||
reason = result.get('reason', '')
|
||||
|
||||
title = f"📊 [{self.platform_name}] 持仓管理 - {symbol}"
|
||||
title = f"📊 [{target_key or self.platform_name}] 持仓管理 - {symbol}"
|
||||
|
||||
content_parts = [
|
||||
f"**平台**: {self.platform_name}",
|
||||
f"**交易对**: {symbol}",
|
||||
f"**操作**: {action}",
|
||||
f"**原因**: {reason}",
|
||||
]
|
||||
content_parts = self._build_notification_header(symbol, account_id, target_key)
|
||||
self._append_notification_detail(content_parts, "操作", action)
|
||||
self._append_notification_detail(content_parts, "原因", reason)
|
||||
|
||||
if details:
|
||||
if 'pnl_percent' in details:
|
||||
content_parts.append(f"**盈亏**: {details['pnl_percent']:.2f}%")
|
||||
self._append_notification_detail(content_parts, "盈亏", f"{details['pnl_percent']:.2f}%")
|
||||
if 'hold_hours' in details:
|
||||
content_parts.append(f"**持仓时长**: {details['hold_hours']:.1f}h")
|
||||
self._append_notification_detail(content_parts, "持仓时长", f"{details['hold_hours']:.1f}h")
|
||||
|
||||
content = "\n".join(content_parts)
|
||||
|
||||
@ -780,26 +796,27 @@ class BaseExecutor(ABC):
|
||||
operation: str,
|
||||
symbol: str,
|
||||
result: Dict[str, Any],
|
||||
details: Optional[Dict[str, Any]] = None):
|
||||
details: Optional[Dict[str, Any]] = None,
|
||||
account_id: str = "default",
|
||||
target_key: str = ""):
|
||||
"""发送通用通知"""
|
||||
success = result.get('success', False)
|
||||
message = result.get('message', result.get('error', ''))
|
||||
|
||||
title = f"[{self.platform_name}] {operation} - {symbol}"
|
||||
title = f"[{target_key or self.platform_name}] {operation} - {symbol}"
|
||||
|
||||
content_parts = [
|
||||
f"**平台**: {self.platform_name}",
|
||||
f"**操作**: {operation}",
|
||||
f"**交易对**: {symbol}",
|
||||
f"**状态**: {'成功' if success else '失败'}",
|
||||
]
|
||||
content_parts = self._build_notification_header(symbol, account_id, target_key)
|
||||
self._append_notification_detail(content_parts, "操作", operation)
|
||||
self._append_notification_detail(content_parts, "状态", '成功' if success else '失败')
|
||||
|
||||
if message:
|
||||
content_parts.append(f"**信息**: {message}")
|
||||
self._append_notification_detail(content_parts, "信息", message)
|
||||
|
||||
if details:
|
||||
for key, value in details.items():
|
||||
content_parts.append(f"**{key}**: {value}")
|
||||
if key in {'account_id', 'target_key'}:
|
||||
continue
|
||||
self._append_notification_detail(content_parts, key, value)
|
||||
|
||||
content = "\n".join(content_parts)
|
||||
color = "green" if success else "red"
|
||||
|
||||
@ -11,9 +11,17 @@ import re
|
||||
class BitgetExecutor(BaseExecutor):
|
||||
"""Bitget 实盘交易执行器"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, service=None, account_id: str = "default"):
|
||||
super().__init__("Bitget")
|
||||
self.bitget = get_bitget_live_service()
|
||||
self.account_id = (account_id or "default").strip() or "default"
|
||||
self.bitget = service or get_bitget_live_service(self.account_id)
|
||||
|
||||
def _notification_context(self) -> Dict[str, str]:
|
||||
account_id = getattr(self, 'account_id', 'default') or 'default'
|
||||
return {
|
||||
'account_id': account_id,
|
||||
'target_key': f'Bitget:{account_id}',
|
||||
}
|
||||
|
||||
# ==================== 核心执行方法 ====================
|
||||
|
||||
@ -78,6 +86,13 @@ class BitgetExecutor(BaseExecutor):
|
||||
|
||||
order_id = result.get('order_id')
|
||||
order_status = result.get('order_status', 'filled')
|
||||
result['contracts'] = contracts
|
||||
result['margin'] = adjusted_margin
|
||||
result['leverage'] = leverage
|
||||
result['order_type'] = order_type
|
||||
result['entry_price'] = entry_price
|
||||
result['actual_position_value'] = actual_position_value
|
||||
result['effective_leverage'] = effective_leverage
|
||||
|
||||
# 设置止盈止损
|
||||
if stop_loss or take_profit:
|
||||
@ -152,7 +167,8 @@ class BitgetExecutor(BaseExecutor):
|
||||
await self.send_execution_notification(
|
||||
operation='OPEN',
|
||||
symbol=decision.get('symbol', ''),
|
||||
result=error_result
|
||||
result=error_result,
|
||||
details=self._notification_context()
|
||||
)
|
||||
|
||||
return error_result
|
||||
@ -178,7 +194,8 @@ class BitgetExecutor(BaseExecutor):
|
||||
await self.send_execution_notification(
|
||||
operation='CLOSE',
|
||||
symbol=symbol,
|
||||
result=result
|
||||
result=result,
|
||||
details=self._notification_context()
|
||||
)
|
||||
|
||||
return result
|
||||
@ -191,7 +208,8 @@ class BitgetExecutor(BaseExecutor):
|
||||
await self.send_execution_notification(
|
||||
operation='CLOSE',
|
||||
symbol=decision.get('symbol', ''),
|
||||
result=error_result
|
||||
result=error_result,
|
||||
details=self._notification_context()
|
||||
)
|
||||
|
||||
return error_result
|
||||
@ -208,7 +226,7 @@ class BitgetExecutor(BaseExecutor):
|
||||
operation='CANCEL',
|
||||
symbol=symbol,
|
||||
result=result,
|
||||
details={'order_id': order_id}
|
||||
details={'order_id': order_id, **self._notification_context()}
|
||||
)
|
||||
|
||||
return result
|
||||
@ -221,7 +239,7 @@ class BitgetExecutor(BaseExecutor):
|
||||
operation='CANCEL',
|
||||
symbol=symbol,
|
||||
result=error_result,
|
||||
details={'order_id': order_id}
|
||||
details={'order_id': order_id, **self._notification_context()}
|
||||
)
|
||||
|
||||
return error_result
|
||||
|
||||
@ -102,9 +102,12 @@ class PaperTradingExecutor(BaseExecutor):
|
||||
symbol=symbol,
|
||||
result=success_result,
|
||||
details={
|
||||
'account_id': 'default',
|
||||
'target_key': 'PaperTrading',
|
||||
'size': adjusted_margin * self.paper_trading.leverage / current_price,
|
||||
'price': entry_price if order_type == 'limit' else current_price,
|
||||
'margin': adjusted_margin,
|
||||
'notional': actual_position_value,
|
||||
'leverage': self.paper_trading.leverage,
|
||||
'stop_loss': stop_loss,
|
||||
'take_profit': take_profit,
|
||||
@ -168,6 +171,8 @@ class PaperTradingExecutor(BaseExecutor):
|
||||
symbol=symbol,
|
||||
result=result,
|
||||
details={
|
||||
'account_id': 'default',
|
||||
'target_key': 'PaperTrading',
|
||||
'pnl': total_pnl,
|
||||
'pnl_percent': (total_pnl / (success_count * decision.get('margin', 100))) * 100 if success_count > 0 else 0,
|
||||
'exit_reason': '手动平仓'
|
||||
|
||||
@ -9,7 +9,7 @@ from fastapi.responses import FileResponse
|
||||
from contextlib import asynccontextmanager
|
||||
from app.config import get_settings
|
||||
from app.utils.logger import logger
|
||||
from app.api import chat, stock, skills, llm, auth, admin, paper_trading, stocks, signals, system, news, astock, bitget_live
|
||||
from app.api import llm, auth, admin, paper_trading, signals, system, bitget_live
|
||||
from app.utils.error_handler import setup_global_exception_handler, init_error_notifier
|
||||
from app.utils.system_status import get_system_monitor
|
||||
import os
|
||||
@ -17,107 +17,7 @@ import os
|
||||
|
||||
# 后台任务
|
||||
_price_monitor_task = None
|
||||
_stock_agent_task = None
|
||||
_crypto_agent_task = None
|
||||
_news_agent_task = None
|
||||
_astock_monitor_task = None
|
||||
_astock_scheduler = None
|
||||
_astock_monitor_instance = None
|
||||
|
||||
|
||||
async def is_trading_day() -> bool:
|
||||
"""检查今天是否为A股交易日"""
|
||||
try:
|
||||
from datetime import datetime
|
||||
from app.config import get_settings
|
||||
from app.astock_agent.tushare_client import TushareClient
|
||||
|
||||
settings = get_settings()
|
||||
token = settings.tushare_token
|
||||
if not token:
|
||||
logger.warning("Tushare token 未配置,使用简单的周末判断")
|
||||
# 简单判断:周一到周五是交易日(不包含节假日)
|
||||
return datetime.now().weekday() < 5
|
||||
|
||||
client = TushareClient(token=token)
|
||||
pro = client.pro
|
||||
|
||||
# 获取今天的日期
|
||||
today = datetime.now().strftime("%Y%m%d")
|
||||
|
||||
# 查询交易日历(最近3天)
|
||||
df = pro.trade_cal(
|
||||
exchange='SSE',
|
||||
start_date=(datetime.now().replace(day=datetime.now().day-2)).strftime("%Y%m%d") if datetime.now().day > 2 else today,
|
||||
end_date=today
|
||||
)
|
||||
|
||||
if df is not None and not df.empty:
|
||||
# 检查今天是否为交易日
|
||||
today_cal = df[df['cal_date'] == today]
|
||||
if not today_cal.empty:
|
||||
is_open = today_cal.iloc[0]['is_open']
|
||||
logger.info(f"交易日历查询: 今天 {today} {'是' if is_open == 1 else '不是'}交易日")
|
||||
return is_open == 1
|
||||
|
||||
# Fallback: 简单周末判断
|
||||
is_weekday = datetime.now().weekday() < 5
|
||||
logger.warning(f"交易日历查询失败,使用简单判断: 今天 {'是' if is_weekday else '不是'}工作日")
|
||||
return is_weekday
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"检查交易日失败: {e}")
|
||||
# Fallback: 简单周末判断
|
||||
return datetime.now().weekday() < 5
|
||||
|
||||
|
||||
async def run_scheduled_astock_monitor():
|
||||
"""定时运行A股板块异动监控(每天 15:30)"""
|
||||
global _astock_monitor_instance
|
||||
if not _astock_monitor_instance:
|
||||
logger.warning("A股监控实例未初始化")
|
||||
return
|
||||
|
||||
try:
|
||||
# 检查今天是否为交易日
|
||||
if not await is_trading_day():
|
||||
logger.info("📅 今天不是交易日,跳过板块异动分析")
|
||||
return
|
||||
|
||||
logger.info("🔔 开始执行定时板块异动分析...")
|
||||
result = await _astock_monitor_instance.check_once()
|
||||
|
||||
hot_sectors = result.get('hot_sectors', 0)
|
||||
stocks = result.get('stocks', 0)
|
||||
notified = result.get('notified', 0)
|
||||
|
||||
logger.info(f"✅ 定时板块分析完成: {hot_sectors}个异动板块, {stocks}只龙头股, {notified}条通知")
|
||||
except Exception as e:
|
||||
logger.error(f"定时板块分析失败: {e}")
|
||||
|
||||
|
||||
async def start_scheduler():
|
||||
"""启动定时任务调度器"""
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
|
||||
global _astock_scheduler
|
||||
|
||||
# 创建调度器
|
||||
_astock_scheduler = AsyncIOScheduler(timezone='Asia/Shanghai')
|
||||
|
||||
# 添加定时任务:每天 15:30 运行板块异动分析
|
||||
_astock_scheduler.add_job(
|
||||
run_scheduled_astock_monitor,
|
||||
trigger=CronTrigger(hour=15, minute=30),
|
||||
id='daily_astock_monitor',
|
||||
name='A股板块异动分析',
|
||||
replace_existing=True
|
||||
)
|
||||
|
||||
_astock_scheduler.start()
|
||||
logger.info("📅 定时任务调度器已启动:")
|
||||
logger.info(" - 每天 15:30 (A股板块异动分析)")
|
||||
|
||||
|
||||
async def price_monitor_loop():
|
||||
@ -171,7 +71,7 @@ async def price_monitor_loop():
|
||||
f"⭐ **信号等级**: {grade}",
|
||||
f"💰 **挂单价**: ${price_fmt.format(entry_price)}",
|
||||
f"🎯 **成交价**: ${price_fmt.format(filled_price)}",
|
||||
f"💵 **仓位**: ${result.get('quantity', 0):,.0f}",
|
||||
f"💵 **仓位**: ${result.get('notional', result.get('quantity', 0)):,.0f}",
|
||||
]
|
||||
if stop_loss:
|
||||
content_parts.append(f"🛑 **止损**: ${price_fmt.format(stop_loss)}")
|
||||
@ -352,7 +252,7 @@ async def price_monitor_loop():
|
||||
f"",
|
||||
f"💰 **挂单价**: ${price_fmt.format(entry_price)}",
|
||||
f"🎯 **成交价**: ${price_fmt.format(filled_price)}",
|
||||
f"📊 **持仓价值**: ${result.get('quantity', 0):,.0f}",
|
||||
f"📊 **名义仓位**: ${result.get('notional', result.get('quantity', 0)):,.0f}",
|
||||
f"",
|
||||
f"🛑 **止损价**: ${price_fmt.format(stop_loss)}",
|
||||
f"🎯 **止盈价**: ${price_fmt.format(take_profit)}"
|
||||
@ -464,7 +364,7 @@ async def _print_system_status():
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期管理"""
|
||||
global _price_monitor_task, _stock_agent_task, _crypto_agent_task, _news_agent_task, _astock_monitor_task
|
||||
global _price_monitor_task, _crypto_agent_task
|
||||
|
||||
# 启动时执行
|
||||
logger.info("应用启动")
|
||||
@ -512,70 +412,6 @@ async def lifespan(app: FastAPI):
|
||||
except Exception as e:
|
||||
logger.error(f"加密货币智能体启动失败: {e}")
|
||||
|
||||
# 启动股票智能体(美股 + 港股)
|
||||
us_symbols = getattr(settings, 'stock_symbols_us', '') or ''
|
||||
hk_symbols = getattr(settings, 'stock_symbols_hk', '') or ''
|
||||
|
||||
if (us_symbols.strip() or hk_symbols.strip()):
|
||||
try:
|
||||
from app.stock_agent.stock_agent import get_stock_agent
|
||||
stock_agent = get_stock_agent()
|
||||
_stock_agent_task = asyncio.create_task(stock_agent.start())
|
||||
# 设置智能体实例到 API 模块
|
||||
stocks.set_stock_agent(stock_agent)
|
||||
|
||||
symbols_list = []
|
||||
if us_symbols:
|
||||
symbols_list.append(f"美股({len(us_symbols.split(','))}只)")
|
||||
if hk_symbols:
|
||||
symbols_list.append(f"港股({len(hk_symbols.split(','))}只)")
|
||||
|
||||
logger.info(f"股票智能体已启动,监控: {', '.join(symbols_list)}")
|
||||
except Exception as e:
|
||||
logger.error(f"股票智能体启动失败: {e}")
|
||||
logger.error(f"提示: 请确保已安装 yfinance (pip install yfinance)")
|
||||
else:
|
||||
logger.info("股票智能体未启动(未配置股票代码)")
|
||||
|
||||
# 启动新闻智能体
|
||||
# try:
|
||||
# from app.news_agent.news_agent import get_news_agent
|
||||
# news_agent = get_news_agent()
|
||||
# _news_agent_task = asyncio.create_task(news_agent.start())
|
||||
# logger.info("新闻智能体已启动")
|
||||
# except Exception as e:
|
||||
# logger.error(f"新闻智能体启动失败: {e}")
|
||||
# logger.error(f"提示: 请确保已安装 feedparser 和 beautifulsoup4 (pip install feedparser beautifulsoup4)")
|
||||
|
||||
# 启动A股智能体
|
||||
if getattr(settings, 'astock_monitor_enabled', True):
|
||||
try:
|
||||
from app.astock_agent import SectorMonitor, AStockAgent
|
||||
# 初始化板块监控(保留原有功能)
|
||||
sector_monitor = SectorMonitor(
|
||||
change_threshold=settings.astock_change_threshold,
|
||||
top_n=settings.astock_top_n,
|
||||
enable_notifier=bool(settings.dingtalk_astock_webhook)
|
||||
)
|
||||
# 保存实例供定时任务使用
|
||||
_astock_monitor_instance = sector_monitor
|
||||
|
||||
# 初始化短期题材选股器(新功能)
|
||||
try:
|
||||
astock_agent = AStockAgent()
|
||||
# 设置智能体实例到 API 模块
|
||||
astock.set_astock_agent(astock_agent)
|
||||
logger.info(f"A股智能体已初始化(短期题材选股器)")
|
||||
except Exception as e:
|
||||
logger.warning(f"A股短期题材选股器初始化失败: {e}(可能缺少Tushare配置)")
|
||||
|
||||
logger.info(f"A股智能体已初始化")
|
||||
except Exception as e:
|
||||
logger.error(f"A股智能体初始化失败: {e}")
|
||||
|
||||
# 启动定时任务调度器
|
||||
await start_scheduler()
|
||||
|
||||
# 显示系统状态摘要
|
||||
await _print_system_status()
|
||||
|
||||
@ -599,57 +435,13 @@ async def lifespan(app: FastAPI):
|
||||
pass
|
||||
logger.info("加密货币智能体已停止")
|
||||
|
||||
# 停止美股智能体
|
||||
if _stock_agent_task:
|
||||
_stock_agent_task.cancel()
|
||||
try:
|
||||
await _stock_agent_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("美股智能体已停止")
|
||||
|
||||
# 停止新闻智能体
|
||||
if _news_agent_task:
|
||||
try:
|
||||
from app.news_agent.news_agent import get_news_agent
|
||||
news_agent = get_news_agent()
|
||||
await news_agent.stop()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"停止新闻智能体失败: {e}")
|
||||
logger.info("新闻智能体已停止")
|
||||
|
||||
# 停止A股智能体
|
||||
global _astock_scheduler
|
||||
if _astock_scheduler:
|
||||
_astock_scheduler.shutdown(wait=False)
|
||||
logger.info("A股定时任务已停止")
|
||||
|
||||
if _astock_monitor_task:
|
||||
_astock_monitor_task.cancel()
|
||||
try:
|
||||
await _astock_monitor_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("A股智能体已停止")
|
||||
|
||||
# 停止A股短期题材选股器
|
||||
try:
|
||||
from app.astock_agent import get_astock_agent
|
||||
astock_agent = get_astock_agent()
|
||||
astock_agent.stop()
|
||||
logger.info("A股短期题材选股器已停止")
|
||||
except:
|
||||
pass
|
||||
|
||||
logger.info("应用关闭")
|
||||
|
||||
|
||||
# 创建FastAPI应用
|
||||
app = FastAPI(
|
||||
title="A股AI分析Agent系统",
|
||||
description="基于AI Agent的股票智能分析系统",
|
||||
title="Crypto Trading Agent",
|
||||
description="基于 AI 的加密货币交易分析与执行系统",
|
||||
version="1.0.0",
|
||||
lifespan=lifespan
|
||||
)
|
||||
@ -667,16 +459,10 @@ app.add_middleware(
|
||||
# 注册路由
|
||||
app.include_router(auth.router, tags=["认证"])
|
||||
app.include_router(admin.router, tags=["后台管理"])
|
||||
app.include_router(chat.router, prefix="/api/chat", tags=["对话"])
|
||||
app.include_router(stock.router, prefix="/api/stock", tags=["股票数据"])
|
||||
app.include_router(skills.router, prefix="/api/skills", tags=["技能管理"])
|
||||
app.include_router(llm.router, tags=["LLM模型"])
|
||||
app.include_router(paper_trading.router, tags=["交易"])
|
||||
app.include_router(bitget_live.router, tags=["Bitget"])
|
||||
app.include_router(stocks.router, prefix="/api/stocks", tags=["美股分析"])
|
||||
app.include_router(astock.router, prefix="/api/astock", tags=["A股分析"])
|
||||
app.include_router(signals.router, tags=["信号管理"])
|
||||
app.include_router(news.router, tags=["新闻管理"])
|
||||
app.include_router(system.router, prefix="/api/system", tags=["系统状态"])
|
||||
|
||||
# 挂载静态文件
|
||||
@ -715,8 +501,8 @@ async def trading_page():
|
||||
|
||||
@app.get("/bitget-trading")
|
||||
async def bitget_trading_page():
|
||||
"""Bitget 实盘交易页面"""
|
||||
page_path = os.path.join(frontend_path, "real-trading.html")
|
||||
"""Bitget 交易页面兼容入口,统一跳转到当前 trading 页面"""
|
||||
page_path = os.path.join(frontend_path, "trading.html")
|
||||
if os.path.exists(page_path):
|
||||
return FileResponse(page_path)
|
||||
return {"message": "页面不存在"}
|
||||
@ -729,14 +515,6 @@ async def signals_page():
|
||||
return FileResponse(page_path)
|
||||
return {"message": "页面不存在"}
|
||||
|
||||
@app.get("/status")
|
||||
async def status_page():
|
||||
"""系统状态监控页面"""
|
||||
page_path = os.path.join(frontend_path, "status.html")
|
||||
if os.path.exists(page_path):
|
||||
return FileResponse(page_path)
|
||||
return {"message": "页面不存在"}
|
||||
|
||||
@app.get("/console")
|
||||
async def console_page():
|
||||
"""系统总控台页面"""
|
||||
|
||||
@ -1,37 +0,0 @@
|
||||
"""
|
||||
对话相关的Pydantic模型
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
"""聊天消息"""
|
||||
role: str = Field(..., description="角色:user或assistant")
|
||||
content: str = Field(..., description="消息内容")
|
||||
metadata: Optional[Dict[str, Any]] = Field(None, description="元数据")
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
"""聊天请求"""
|
||||
message: str = Field(..., description="用户消息", min_length=1)
|
||||
session_id: Optional[str] = Field(None, description="会话ID")
|
||||
user_id: Optional[str] = Field(None, description="用户ID")
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
"""聊天响应"""
|
||||
message: str = Field(..., description="助手回复")
|
||||
session_id: str = Field(..., description="会话ID")
|
||||
metadata: Optional[Dict[str, Any]] = Field(None, description="元数据")
|
||||
|
||||
|
||||
class ConversationHistory(BaseModel):
|
||||
"""对话历史"""
|
||||
session_id: str
|
||||
messages: list[ChatMessage]
|
||||
created_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
@ -1,99 +0,0 @@
|
||||
"""
|
||||
新闻文章数据库模型
|
||||
"""
|
||||
from datetime import datetime
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, JSON, Boolean, Float
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.models.database import Base
|
||||
|
||||
|
||||
class NewsArticle(Base):
|
||||
"""新闻文章表"""
|
||||
__tablename__ = "news_articles"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
|
||||
# 新闻基本信息
|
||||
title = Column(String(500), nullable=False)
|
||||
content = Column(Text, nullable=True) # 完整内容或摘要
|
||||
content_hash = Column(String(64), nullable=False, index=True) # 内容哈希,用于去重
|
||||
url = Column(String(1000), nullable=False, unique=True) # 原文链接
|
||||
source = Column(String(100), nullable=False, index=True) # 来源网站
|
||||
author = Column(String(200), nullable=True) # 作者
|
||||
|
||||
# 新闻分类
|
||||
category = Column(String(50), nullable=False, index=True) # 'crypto', 'stock', 'forex', 'commodity'
|
||||
tags = Column(JSON, nullable=True) # 标签列表
|
||||
|
||||
# 时间信息
|
||||
published_at = Column(DateTime, nullable=True, index=True) # 发布时间
|
||||
crawled_at = Column(DateTime, default=datetime.utcnow, index=True) # 爬取时间
|
||||
|
||||
# LLM 分析结果
|
||||
llm_analyzed = Column(Boolean, default=False, index=True) # 是否已分析
|
||||
market_impact = Column(String(20), nullable=True, index=True) # 'high', 'medium', 'low'
|
||||
impact_type = Column(String(50), nullable=True) # 'bullish', 'bearish', 'neutral'
|
||||
relevant_symbols = Column(JSON, nullable=True) # 相关的币种/股票代码
|
||||
|
||||
# LLM 分析详情
|
||||
sentiment = Column(String(20), nullable=True) # 'positive', 'negative', 'neutral'
|
||||
summary = Column(Text, nullable=True) # LLM 生成的摘要
|
||||
key_points = Column(JSON, nullable=True) # 关键点列表
|
||||
trading_advice = Column(Text, nullable=True) # 交易建议
|
||||
|
||||
# 优先级队列
|
||||
priority = Column(Float, default=0.0, index=True) # 优先级分数
|
||||
priority_reason = Column(Text, nullable=True) # 优先级原因
|
||||
|
||||
# 通知状态
|
||||
notified = Column(Boolean, default=False, index=True) # 是否已发送通知
|
||||
notification_sent_at = Column(DateTime, nullable=True)
|
||||
notification_channel = Column(String(50), nullable=True) # 'feishu', 'telegram', etc.
|
||||
|
||||
# 质量控制
|
||||
quality_score = Column(Float, nullable=True) # 质量分数 0-1
|
||||
duplicate_of = Column(Integer, nullable=True) # 如果是重复,指向原始文章ID
|
||||
|
||||
# 状态
|
||||
is_active = Column(Boolean, default=True, index=True) # 是否有效
|
||||
|
||||
# 时间戳
|
||||
created_at = Column(DateTime, default=datetime.utcnow, index=True)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<NewsArticle({self.category} {self.source} {self.title[:50]}...)>"
|
||||
|
||||
def to_dict(self):
|
||||
"""转换为字典"""
|
||||
return {
|
||||
'id': self.id,
|
||||
'title': self.title,
|
||||
'content': self.content,
|
||||
'url': self.url,
|
||||
'source': self.source,
|
||||
'author': self.author,
|
||||
'category': self.category,
|
||||
'tags': self.tags,
|
||||
'published_at': self.published_at.isoformat() if self.published_at else None,
|
||||
'crawled_at': self.crawled_at.isoformat() if self.crawled_at else None,
|
||||
'llm_analyzed': self.llm_analyzed,
|
||||
'market_impact': self.market_impact,
|
||||
'impact_type': self.impact_type,
|
||||
'relevant_symbols': self.relevant_symbols,
|
||||
'sentiment': self.sentiment,
|
||||
'summary': self.summary,
|
||||
'key_points': self.key_points,
|
||||
'trading_advice': self.trading_advice,
|
||||
'priority': self.priority,
|
||||
'priority_reason': self.priority_reason,
|
||||
'notified': self.notified,
|
||||
'notification_sent_at': self.notification_sent_at.isoformat() if self.notification_sent_at else None,
|
||||
'notification_channel': self.notification_channel,
|
||||
'quality_score': self.quality_score,
|
||||
'duplicate_of': self.duplicate_of,
|
||||
'is_active': self.is_active,
|
||||
'created_at': self.created_at.isoformat() if self.created_at else None,
|
||||
'updated_at': self.updated_at.isoformat() if self.updated_at else None,
|
||||
}
|
||||
@ -61,7 +61,7 @@ class PaperOrder(Base):
|
||||
exit_price = Column(Float, nullable=True) # 出场价
|
||||
|
||||
# 仓位信息
|
||||
quantity = Column(Float, default=1000) # 持仓价值 (USDT)
|
||||
quantity = Column(Float, default=1000) # 兼容旧字段:名义仓位 (USDT)
|
||||
margin = Column(Float, default=50) # 保证金 (USDT)
|
||||
leverage = Column(Integer, default=10) # 杠杆倍数
|
||||
|
||||
@ -116,7 +116,8 @@ class PaperOrder(Base):
|
||||
'take_profit': self.take_profit,
|
||||
'filled_price': self.filled_price,
|
||||
'exit_price': self.exit_price,
|
||||
'quantity': self.quantity, # 持仓价值
|
||||
'quantity': self.quantity, # 兼容旧字段
|
||||
'notional': self.quantity, # 标准字段:名义仓位
|
||||
'margin': getattr(self, 'margin', self.quantity / 10), # 保证金(回退值:10倍杠杆)
|
||||
'leverage': getattr(self, 'leverage', 10), # 杠杆倍数(回退值:10倍)
|
||||
'signal_grade': self.signal_grade.value if self.signal_grade else None,
|
||||
|
||||
@ -15,8 +15,8 @@ class TradingSignal(Base):
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
|
||||
# 信号基本信息
|
||||
signal_type = Column(String(20), nullable=False, index=True) # 'crypto' or 'stock'
|
||||
symbol = Column(String(50), nullable=False, index=True) # 交易对或股票代码
|
||||
signal_type = Column(String(20), nullable=False, index=True) # 当前仅使用 'crypto'
|
||||
symbol = Column(String(50), nullable=False, index=True) # 交易对
|
||||
|
||||
# 信号方向和评级
|
||||
action = Column(String(10), nullable=False) # 'buy', 'sell', 'hold'
|
||||
|
||||
@ -1,48 +0,0 @@
|
||||
"""
|
||||
股票相关的Pydantic模型
|
||||
"""
|
||||
from datetime import date
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class StockQuote(BaseModel):
|
||||
"""股票行情"""
|
||||
ts_code: str = Field(..., description="股票代码")
|
||||
name: Optional[str] = Field(None, description="股票名称")
|
||||
trade_date: Optional[str] = Field(None, description="交易日期")
|
||||
open: Optional[float] = Field(None, description="开盘价")
|
||||
high: Optional[float] = Field(None, description="最高价")
|
||||
low: Optional[float] = Field(None, description="最低价")
|
||||
close: Optional[float] = Field(None, description="收盘价")
|
||||
pre_close: Optional[float] = Field(None, description="昨收价")
|
||||
change: Optional[float] = Field(None, description="涨跌额")
|
||||
pct_chg: Optional[float] = Field(None, description="涨跌幅%")
|
||||
vol: Optional[float] = Field(None, description="成交量(手)")
|
||||
amount: Optional[float] = Field(None, description="成交额(千元)")
|
||||
|
||||
|
||||
class KLineData(BaseModel):
|
||||
"""K线数据"""
|
||||
ts_code: str = Field(..., description="股票代码")
|
||||
trade_date: str = Field(..., description="交易日期")
|
||||
open: float = Field(..., description="开盘价")
|
||||
high: float = Field(..., description="最高价")
|
||||
low: float = Field(..., description="最低价")
|
||||
close: float = Field(..., description="收盘价")
|
||||
vol: float = Field(..., description="成交量")
|
||||
amount: Optional[float] = Field(None, description="成交额")
|
||||
|
||||
|
||||
class TechnicalIndicators(BaseModel):
|
||||
"""技术指标"""
|
||||
ma5: Optional[List[float]] = Field(None, description="5日均线")
|
||||
ma10: Optional[List[float]] = Field(None, description="10日均线")
|
||||
ma20: Optional[List[float]] = Field(None, description="20日均线")
|
||||
macd_dif: Optional[List[float]] = Field(None, description="MACD DIF")
|
||||
macd_dea: Optional[List[float]] = Field(None, description="MACD DEA")
|
||||
macd: Optional[List[float]] = Field(None, description="MACD柱")
|
||||
rsi: Optional[List[float]] = Field(None, description="RSI")
|
||||
kdj_k: Optional[List[float]] = Field(None, description="KDJ K值")
|
||||
kdj_d: Optional[List[float]] = Field(None, description="KDJ D值")
|
||||
kdj_j: Optional[List[float]] = Field(None, description="KDJ J值")
|
||||
@ -1,38 +0,0 @@
|
||||
"""
|
||||
新闻智能体模块
|
||||
"""
|
||||
from app.news_agent.news_agent import NewsAgent, get_news_agent
|
||||
from app.news_agent.fetcher import NewsFetcher, NewsItem
|
||||
from app.news_agent.filter import NewsDeduplicator, NewsFilter
|
||||
from app.news_agent.analyzer import NewsAnalyzer, NewsAnalyzerSimple
|
||||
from app.news_agent.notifier import NewsNotifier, get_news_notifier
|
||||
from app.news_agent.news_db_service import NewsDatabaseService, get_news_db_service
|
||||
from app.news_agent.sources import (
|
||||
get_enabled_sources,
|
||||
CRYPTO_NEWS_SOURCES,
|
||||
STOCK_NEWS_SOURCES,
|
||||
CRYPTO_KEYWORDS,
|
||||
STOCK_KEYWORDS,
|
||||
SYMBOL_MAPPINGS
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'NewsAgent',
|
||||
'get_news_agent',
|
||||
'NewsFetcher',
|
||||
'NewsItem',
|
||||
'NewsDeduplicator',
|
||||
'NewsFilter',
|
||||
'NewsAnalyzer',
|
||||
'NewsAnalyzerSimple',
|
||||
'NewsNotifier',
|
||||
'get_news_notifier',
|
||||
'NewsDatabaseService',
|
||||
'get_news_db_service',
|
||||
'get_enabled_sources',
|
||||
'CRYPTO_NEWS_SOURCES',
|
||||
'STOCK_NEWS_SOURCES',
|
||||
'CRYPTO_KEYWORDS',
|
||||
'STOCK_KEYWORDS',
|
||||
'SYMBOL_MAPPINGS',
|
||||
]
|
||||
@ -1,527 +0,0 @@
|
||||
"""
|
||||
新闻 LLM 分析模块
|
||||
使用 LLM 分析新闻内容并生成交易建议
|
||||
"""
|
||||
import json
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from app.utils.logger import logger
|
||||
from app.news_agent.fetcher import NewsItem
|
||||
from app.config import get_settings
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
|
||||
class NewsAnalyzer:
|
||||
"""新闻 LLM 分析器 (DeepSeek) - 异步版本"""
|
||||
|
||||
def __init__(self):
|
||||
self.settings = get_settings()
|
||||
self.client = None
|
||||
|
||||
try:
|
||||
# 使用 DeepSeek API (异步客户端)
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=self.settings.deepseek_api_key,
|
||||
base_url="https://api.deepseek.com"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"LLM 客户端初始化失败: {e}")
|
||||
|
||||
# 批量分析配置
|
||||
self.batch_size = 10 # 每次最多分析 10 条新闻(只传标题,可以增加数量)
|
||||
self.max_retries = 2
|
||||
|
||||
# 余额错误通知冷却时间(秒)
|
||||
self._balance_error_cooldown = 3600 # 1小时内只通知一次
|
||||
self._balance_error_last_notified = None
|
||||
|
||||
def _build_analysis_prompt(self, news_item: NewsItem) -> str:
|
||||
"""构建单条新闻的分析提示词"""
|
||||
|
||||
prompt = f"""你是一名专业的金融新闻分析师。请分析以下新闻标题,并以 JSON 格式输出结果。
|
||||
|
||||
**新闻标题**: {news_item.title}
|
||||
|
||||
**新闻来源**: {news_item.source}
|
||||
|
||||
**新闻分类**: {news_item.category}
|
||||
|
||||
请按以下 JSON 格式输出(不要包含其他内容):
|
||||
|
||||
```json
|
||||
{{
|
||||
"market_impact": "high/medium/low",
|
||||
"impact_type": "bullish/bearish/neutral",
|
||||
"sentiment": "positive/negative/neutral",
|
||||
"summary": "简洁的新闻摘要(1句话,不超过50字)",
|
||||
"key_points": ["关键点1", "关键点2", "关键点3"],
|
||||
"trading_advice": "简洁的交易建议(1句话,不超过30字)",
|
||||
"relevant_symbols": ["相关的币种或股票代码"],
|
||||
"confidence": 85
|
||||
}}
|
||||
```
|
||||
|
||||
**分析要求**:
|
||||
1. market_impact: 对市场的潜在影响(high/medium/low)
|
||||
|
||||
⚠️ **high(重大影响)- 请严格判断,只有以下情况才标记为 high**:
|
||||
- 监管层面:ETF批准/拒绝、交易所封禁/解禁、央行政策重大变化
|
||||
- 企业层面:破产/退市/重大并购(>100亿美元)、财务造假
|
||||
- 技术层面:严重安全漏洞(被盗>1亿美元)、网络暂停
|
||||
- 宏观层面:重大地缘政治事件、经济数据远超预期
|
||||
|
||||
❌ **以下情况不应该标记为 high**:
|
||||
- 普通价格波动(涨跌<10%)
|
||||
- 分析师观点/评级调整
|
||||
- CEO发表常规评论
|
||||
- 一般业务合作/投资
|
||||
- 常规财报发布(非意外业绩)
|
||||
|
||||
- **medium**: 对价格有**短期影响**但不会改变长期趋势的事件
|
||||
* 财报业绩、管理层变动、一般并购、机构评级调整
|
||||
* 业务合作、技术升级、普通投资新闻
|
||||
|
||||
- **low**: 常规信息,影响有限
|
||||
* 分析师观点、一般评论、价格波动、市场常规动态
|
||||
|
||||
**判断原则**:
|
||||
1. 问自己"这条新闻会改变市场/公司的长期格局吗?"
|
||||
2. 如果会→high,如果只是短期波动→medium,如果无关紧要→low
|
||||
3. 宁可判断为 medium,也不要过度判断为 high
|
||||
4. 价格波动类新闻,除非涨跌>15%,否则不应是 high
|
||||
|
||||
2. impact_type: 对价格的影响方向(bullish=利好, bearish=利空, neutral=中性)
|
||||
3. sentiment: 新闻情绪(positive=正面, negative=负面, neutral=中性)
|
||||
4. summary: 根据标题推断并总结新闻核心内容
|
||||
5. key_points: 基于标题推断3-5个关键信息点
|
||||
6. trading_advice: 给出简明的交易建议
|
||||
7. relevant_symbols: 根据标题列出相关的交易代码(如 BTC, ETH, NVDA, TSLA 等)
|
||||
8. confidence: 分析置信度(0-100)
|
||||
|
||||
请只输出 JSON,不要包含其他解释。
|
||||
"""
|
||||
|
||||
return prompt
|
||||
|
||||
def _build_batch_analysis_prompt(self, news_items: List[NewsItem]) -> str:
|
||||
"""构建批量分析提示词"""
|
||||
|
||||
news_text = ""
|
||||
for i, item in enumerate(news_items, 1):
|
||||
news_text += f"""
|
||||
--- 新闻 {i} ---
|
||||
标题: {item.title}
|
||||
来源: {item.source}
|
||||
分类: {item.category}
|
||||
---
|
||||
"""
|
||||
|
||||
prompt = f"""你是一名专业的金融新闻分析师。请分析以下 {len(news_items)} 条新闻标题,并以 JSON 数组格式输出结果。
|
||||
|
||||
{news_text}
|
||||
|
||||
请按以下 JSON 格式输出(不要包含其他内容):
|
||||
|
||||
```json
|
||||
[
|
||||
{{
|
||||
"title": "新闻标题",
|
||||
"market_impact": "high/medium/low",
|
||||
"impact_type": "bullish/bearish/neutral",
|
||||
"sentiment": "positive/negative/neutral",
|
||||
"summary": "简洁的新闻摘要(1句话,不超过50字)",
|
||||
"key_points": ["关键点1", "关键点2"],
|
||||
"trading_advice": "简洁的交易建议(1句话,不超过30字)",
|
||||
"relevant_symbols": ["相关代码"],
|
||||
"confidence": 85
|
||||
}}
|
||||
]
|
||||
```
|
||||
|
||||
**market_impact 判断标准(严格)**:
|
||||
|
||||
⚠️ **high(重大影响)- 请严格判断**:
|
||||
- 监管:ETF批准/拒绝、交易所封禁/解禁、央行政策重大变化
|
||||
- 企业:破产/退市、重大并购(>100亿美元)、财务造假
|
||||
- 技术:严重安全漏洞(被盗>1亿美元)、网络暂停
|
||||
- 宏观:重大地缘政治事件、经济数据远超预期
|
||||
|
||||
❌ **以下情况不应该标记为 high**:
|
||||
- 普通价格波动(涨跌<10%)
|
||||
- 分析师观点/评级调整
|
||||
- CEO发表常规评论
|
||||
- 一般业务合作/投资
|
||||
- 常规财报发布(非意外业绩)
|
||||
|
||||
- **medium**: 对价格有**短期影响**但不会改变长期趋势
|
||||
- **low**: 常规信息,影响有限
|
||||
|
||||
**判断原则**: 问自己"这条新闻会改变市场/公司的长期格局吗?" 如果会→high,否则→medium/low
|
||||
|
||||
请只输出 JSON 数组,不要包含其他解释。
|
||||
"""
|
||||
|
||||
return prompt
|
||||
|
||||
def _parse_llm_response(self, response: str) -> Optional[Dict[str, Any]]:
|
||||
"""解析 LLM 响应"""
|
||||
try:
|
||||
# 尝试提取 JSON
|
||||
response = response.strip()
|
||||
|
||||
# 移除可能的 markdown 代码块标记
|
||||
if response.startswith("```json"):
|
||||
response = response[7:]
|
||||
if response.startswith("```"):
|
||||
response = response[3:]
|
||||
if response.endswith("```"):
|
||||
response = response[:-3]
|
||||
|
||||
response = response.strip()
|
||||
|
||||
# 解析 JSON
|
||||
return json.loads(response)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
# 尝试修复截断的 JSON
|
||||
logger.warning(f"JSON 解析失败,尝试修复: {e}")
|
||||
try:
|
||||
# 查找最后一个完整的对象
|
||||
response = response.strip()
|
||||
|
||||
# 如果是数组,找到最后一个完整的对象
|
||||
if response.startswith('['):
|
||||
# 找到每个完整对象的结束位置
|
||||
brace_count = 0
|
||||
last_complete = 0
|
||||
for i, char in enumerate(response):
|
||||
if char == '{':
|
||||
brace_count += 1
|
||||
elif char == '}':
|
||||
brace_count -= 1
|
||||
if brace_count == 0:
|
||||
last_complete = i + 1
|
||||
break
|
||||
|
||||
if last_complete > 0:
|
||||
# 提取完整的数组
|
||||
fixed = response[:last_complete]
|
||||
if not fixed.endswith(']'):
|
||||
fixed += ']'
|
||||
if not fixed.endswith('}'):
|
||||
fixed += '}'
|
||||
return json.loads(fixed)
|
||||
except:
|
||||
pass
|
||||
|
||||
logger.error(f"JSON 解析失败: {e}, 响应: {response[:500]}")
|
||||
return None
|
||||
|
||||
def _parse_llm_array_response(self, response: str) -> Optional[List[Dict[str, Any]]]:
|
||||
"""解析 LLM 数组响应"""
|
||||
try:
|
||||
# 尝试提取 JSON
|
||||
response = response.strip()
|
||||
|
||||
# 移除可能的 markdown 代码块标记
|
||||
if response.startswith("```json"):
|
||||
response = response[7:]
|
||||
if response.startswith("```"):
|
||||
response = response[3:]
|
||||
if response.endswith("```"):
|
||||
response = response[:-3]
|
||||
|
||||
response = response.strip()
|
||||
|
||||
# 解析 JSON 数组
|
||||
result = json.loads(response)
|
||||
if isinstance(result, list):
|
||||
return result
|
||||
elif isinstance(result, dict) and 'title' in result:
|
||||
# 如果返回单个对象,包装成数组
|
||||
return [result]
|
||||
return None
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
# 尝试修复截断的 JSON 数组
|
||||
logger.warning(f"JSON 数组解析失败,尝试修复: {e}")
|
||||
try:
|
||||
response = response.strip()
|
||||
|
||||
if response.startswith('['):
|
||||
# 找到每个完整对象
|
||||
objects = []
|
||||
brace_count = 0
|
||||
obj_start = -1
|
||||
|
||||
for i, char in enumerate(response):
|
||||
if char == '{':
|
||||
if obj_start == -1:
|
||||
obj_start = i
|
||||
brace_count += 1
|
||||
elif char == '}':
|
||||
brace_count -= 1
|
||||
if brace_count == 0 and obj_start >= 0:
|
||||
# 提取完整对象
|
||||
obj_str = response[obj_start:i + 1]
|
||||
try:
|
||||
obj = json.loads(obj_str)
|
||||
if isinstance(obj, dict) and 'title' in obj:
|
||||
objects.append(obj)
|
||||
except:
|
||||
pass
|
||||
obj_start = -1
|
||||
|
||||
if objects:
|
||||
return objects
|
||||
except:
|
||||
pass
|
||||
|
||||
logger.error(f"JSON 数组解析失败: {e}, 响应: {response[:500]}")
|
||||
return None
|
||||
|
||||
async def analyze_single(self, news_item: NewsItem) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
分析单条新闻 (异步)
|
||||
|
||||
Args:
|
||||
news_item: 新闻项
|
||||
|
||||
Returns:
|
||||
分析结果字典或 None
|
||||
"""
|
||||
if not self.client:
|
||||
logger.warning("LLM 客户端未初始化")
|
||||
return None
|
||||
|
||||
try:
|
||||
prompt = self._build_analysis_prompt(news_item)
|
||||
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
response = await self.client.chat.completions.create(
|
||||
model="deepseek-chat",
|
||||
messages=[
|
||||
{"role": "system", "content": "你是一名专业的金融新闻分析师,擅长分析新闻标题对市场的影响。"},
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
temperature=0.3,
|
||||
max_tokens=1000 # 只传标题,减少输出token
|
||||
)
|
||||
|
||||
result = self._parse_llm_response(response.choices[0].message.content)
|
||||
|
||||
if result:
|
||||
logger.info(f"新闻分析成功: {news_item.title[:50]}... -> {result.get('market_impact')}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"分析失败 (尝试 {attempt + 1}/{self.max_retries}): {e}")
|
||||
|
||||
# 检查是否是余额不足错误 (402)
|
||||
error_str = str(e)
|
||||
error_code = str(e).split('Error code: ')[1].split(' -')[0] if 'Error code:' in error_str else ''
|
||||
if error_code == '402' or ('402' in error_str and 'insufficient balance' in error_str.lower()):
|
||||
await self._notify_balance_error(e)
|
||||
break # 余额不足不再重试
|
||||
|
||||
logger.error(f"新闻分析失败,已达最大重试次数: {news_item.title[:50]}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"分析新闻时出错: {e}")
|
||||
return None
|
||||
|
||||
async def analyze_batch(self, news_items: List[NewsItem]) -> List[Optional[Dict[str, Any]]]:
|
||||
"""
|
||||
批量分析新闻 (异步)
|
||||
|
||||
Args:
|
||||
news_items: 新闻项列表
|
||||
|
||||
Returns:
|
||||
分析结果列表(与输入顺序一致)
|
||||
"""
|
||||
if not self.client:
|
||||
logger.warning("LLM 客户端未初始化")
|
||||
return [None] * len(news_items)
|
||||
|
||||
results = []
|
||||
|
||||
# 分批处理
|
||||
for i in range(0, len(news_items), self.batch_size):
|
||||
batch = news_items[i:i + self.batch_size]
|
||||
|
||||
try:
|
||||
prompt = self._build_batch_analysis_prompt(batch)
|
||||
|
||||
response = await self.client.chat.completions.create(
|
||||
model="deepseek-chat",
|
||||
messages=[
|
||||
{"role": "system", "content": "你是一名专业的金融新闻分析师,擅长分析新闻标题对市场的影响。"},
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
temperature=0.3,
|
||||
max_tokens=2000 # 批量分析需要更多 token
|
||||
)
|
||||
|
||||
batch_results = self._parse_llm_array_response(response.choices[0].message.content)
|
||||
|
||||
if batch_results:
|
||||
# 按标题匹配结果
|
||||
title_to_result = {r.get('title'): r for r in batch_results if r and isinstance(r, dict)}
|
||||
for item in batch:
|
||||
result = title_to_result.get(item.title)
|
||||
results.append(result)
|
||||
if result:
|
||||
logger.info(f"新闻分析成功: {item.title[:50]}... -> {result.get('market_impact')}")
|
||||
else:
|
||||
results.extend([None] * len(batch))
|
||||
|
||||
except Exception as e:
|
||||
error_str = str(e)
|
||||
error_code = str(e).split('Error code: ')[1].split(' -')[0] if 'Error code:' in error_str else ''
|
||||
|
||||
logger.error(f"批量分析失败: {e}")
|
||||
|
||||
# 检查是否是余额不足错误 (402)
|
||||
if error_code == '402' or ('402' in error_str and 'insufficient balance' in error_str.lower()):
|
||||
await self._notify_balance_error(e)
|
||||
|
||||
results.extend([None] * len(batch))
|
||||
|
||||
return results
|
||||
|
||||
async def _notify_balance_error(self, error: Exception):
|
||||
"""
|
||||
发送余额不足的飞书通知
|
||||
|
||||
Args:
|
||||
error: 异常对象
|
||||
"""
|
||||
# 检查冷却时间
|
||||
now = datetime.now()
|
||||
if self._balance_error_last_notified:
|
||||
time_since_last = (now - self._balance_error_last_notified).total_seconds()
|
||||
if time_since_last < self._balance_error_cooldown:
|
||||
logger.info(f"余额错误通知冷却中,剩余 {int(self._balance_error_cooldown - time_since_last)} 秒")
|
||||
return
|
||||
|
||||
# 发送通知
|
||||
try:
|
||||
from app.services.feishu_service import get_feishu_service
|
||||
feishu = get_feishu_service()
|
||||
|
||||
message = f"""🚨 **新闻分析 LLM API 余额不足警告**
|
||||
|
||||
**服务商**: DeepSeek
|
||||
**错误类型**: 余额不足 (Insufficient Balance)
|
||||
**错误信息**: {str(error)[:200]}
|
||||
**时间**: {now.strftime('%Y-%m-%d %H:%M:%S')}
|
||||
|
||||
⚠️ 请及时充值,否则新闻智能体将无法正常工作"""
|
||||
|
||||
await feishu.send_text(message)
|
||||
logger.warning("已发送 DeepSeek 余额不足飞书通知")
|
||||
|
||||
# 记录通知时间
|
||||
self._balance_error_last_notified = now
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发送余额不足通知失败: {e}")
|
||||
|
||||
def calculate_priority(self, analysis: Dict[str, Any], quality_score: float = 0.5) -> float:
|
||||
"""
|
||||
根据分析结果计算优先级
|
||||
|
||||
Args:
|
||||
analysis: LLM 分析结果
|
||||
quality_score: 质量分数
|
||||
|
||||
Returns:
|
||||
优先级分数
|
||||
"""
|
||||
score = 0.0
|
||||
|
||||
# 市场影响(更严格的权重)
|
||||
impact_weights = {'high': 50, 'medium': 25, 'low': 5} # 降低 low 和 medium 的权重
|
||||
score += impact_weights.get(analysis.get('market_impact', 'low'), 5)
|
||||
|
||||
# 方向性(利空利好比中性重要)
|
||||
if analysis.get('impact_type') in ['bullish', 'bearish']:
|
||||
score += 10 # 从 15 降低到 10
|
||||
|
||||
# 置信度(降低权重)
|
||||
score += (analysis.get('confidence', 50) / 100) * 8 # 从 10 降低到 8
|
||||
|
||||
# 质量分数(保持)
|
||||
score += quality_score * 15 # 从 20 降低到 15
|
||||
|
||||
# 是否有相关代码(提高重要性)
|
||||
if analysis.get('relevant_symbols'):
|
||||
score += 12 # 从 5 提高到 12
|
||||
|
||||
return score
|
||||
|
||||
|
||||
class NewsAnalyzerSimple:
|
||||
"""简化版新闻分析器(仅关键词规则,不使用 LLM)"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def analyze_single(self, news_item: NewsItem) -> Dict[str, Any]:
|
||||
"""
|
||||
基于规则分析新闻
|
||||
|
||||
Args:
|
||||
news_item: 新闻项
|
||||
|
||||
Returns:
|
||||
分析结果字典
|
||||
"""
|
||||
# 使用已有的影响评分
|
||||
impact_score = getattr(news_item, 'impact_score', 0.0)
|
||||
|
||||
# 根据 impact_score 确定市场影响
|
||||
if impact_score >= 1.0:
|
||||
market_impact = 'high'
|
||||
elif impact_score >= 0.7:
|
||||
market_impact = 'medium'
|
||||
else:
|
||||
market_impact = 'low'
|
||||
|
||||
# 检查关键词确定方向
|
||||
text = f"{news_item.title} {news_item.content}".lower()
|
||||
|
||||
bullish_keywords = ['上涨', '增长', '突破', '新高', 'bullish', 'surge', 'rally', 'gain', '批准', '合作']
|
||||
bearish_keywords = ['下跌', '暴跌', '崩盘', 'ban', 'bearish', 'crash', 'plunge', 'fall', '禁令', '风险']
|
||||
|
||||
bullish_count = sum(1 for k in bullish_keywords if k in text)
|
||||
bearish_count = sum(1 for k in bearish_keywords if k in text)
|
||||
|
||||
if bullish_count > bearish_count:
|
||||
impact_type = 'bullish'
|
||||
sentiment = 'positive'
|
||||
elif bearish_count > bullish_count:
|
||||
impact_type = 'bearish'
|
||||
sentiment = 'negative'
|
||||
else:
|
||||
impact_type = 'neutral'
|
||||
sentiment = 'neutral'
|
||||
|
||||
# 获取相关代码
|
||||
relevant_symbols = list(set(getattr(news_item, 'relevant_symbols', [])))
|
||||
|
||||
return {
|
||||
'market_impact': market_impact,
|
||||
'impact_type': impact_type,
|
||||
'sentiment': sentiment,
|
||||
'summary': news_item.title,
|
||||
'key_points': [news_item.title[:100]],
|
||||
'trading_advice': getattr(news_item, 'impact_reason', '关注市场动态'),
|
||||
'relevant_symbols': relevant_symbols,
|
||||
'confidence': 70,
|
||||
'analyzed_by': 'rules'
|
||||
}
|
||||
@ -1,271 +0,0 @@
|
||||
"""
|
||||
新闻获取模块 - 从 RSS 源获取新闻
|
||||
"""
|
||||
import asyncio
|
||||
import hashlib
|
||||
import feedparser
|
||||
import httpx
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Any, Optional
|
||||
from dataclasses import dataclass
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from app.utils.logger import logger
|
||||
from app.news_agent.sources import get_enabled_sources
|
||||
|
||||
|
||||
@dataclass
|
||||
class NewsItem:
|
||||
"""新闻项数据类"""
|
||||
title: str
|
||||
content: str
|
||||
url: str
|
||||
source: str
|
||||
category: str
|
||||
published_at: Optional[datetime]
|
||||
crawled_at: datetime
|
||||
content_hash: str
|
||||
author: Optional[str] = None
|
||||
tags: Optional[List[str]] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
'title': self.title,
|
||||
'content': self.content,
|
||||
'url': self.url,
|
||||
'source': self.source,
|
||||
'category': self.category,
|
||||
'published_at': self.published_at.isoformat() if self.published_at else None,
|
||||
'crawled_at': self.crawled_at.isoformat(),
|
||||
'content_hash': self.content_hash,
|
||||
'author': self.author,
|
||||
'tags': self.tags,
|
||||
}
|
||||
|
||||
|
||||
class NewsFetcher:
|
||||
"""新闻获取器"""
|
||||
|
||||
def __init__(self):
|
||||
self.sources = get_enabled_sources()
|
||||
self.client = httpx.AsyncClient(
|
||||
timeout=30.0,
|
||||
headers={
|
||||
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
|
||||
}
|
||||
)
|
||||
|
||||
async def close(self):
|
||||
"""关闭 HTTP 客户端"""
|
||||
await self.client.aclose()
|
||||
|
||||
def _generate_content_hash(self, title: str, content: str) -> str:
|
||||
"""生成内容哈希用于去重"""
|
||||
combined = f"{title}{content}"
|
||||
return hashlib.sha256(combined.encode()).hexdigest()
|
||||
|
||||
def _clean_html(self, html: str) -> str:
|
||||
"""清理 HTML,提取纯文本"""
|
||||
if not html:
|
||||
return ""
|
||||
|
||||
soup = BeautifulSoup(html, 'html.parser')
|
||||
|
||||
# 移除脚本和样式
|
||||
for script in soup(['script', 'style']):
|
||||
script.decompose()
|
||||
|
||||
# 获取文本
|
||||
text = soup.get_text()
|
||||
|
||||
# 清理空白
|
||||
lines = (line.strip() for line in text.splitlines())
|
||||
chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
|
||||
text = ' '.join(chunk for chunk in chunks if chunk)
|
||||
|
||||
return text[:5000] # 限制长度
|
||||
|
||||
def _parse_rss_date(self, date_str: str) -> Optional[datetime]:
|
||||
"""解析 RSS 日期"""
|
||||
if not date_str:
|
||||
return None
|
||||
|
||||
try:
|
||||
# feedparser 会解析日期
|
||||
parsed = feedparser.parse(date_str)
|
||||
if hasattr(parsed, 'updated_parsed'):
|
||||
return datetime(*parsed.updated_parsed[:6])
|
||||
except Exception as e:
|
||||
logger.debug(f"日期解析失败: {date_str}, 错误: {e}")
|
||||
|
||||
return None
|
||||
|
||||
async def fetch_rss_feed(self, source: Dict[str, Any]) -> List[NewsItem]:
|
||||
"""
|
||||
获取单个 RSS 源的新闻
|
||||
|
||||
Args:
|
||||
source: 新闻源配置
|
||||
|
||||
Returns:
|
||||
新闻项列表
|
||||
"""
|
||||
items = []
|
||||
|
||||
try:
|
||||
logger.debug(f"正在获取 {source['name']} 的 RSS...")
|
||||
|
||||
# 使用 feedparser 解析 RSS
|
||||
feed = feedparser.parse(source['url'])
|
||||
|
||||
if feed.bozo: # RSS 解析错误
|
||||
logger.warning(f"{source['name']} RSS 解析警告: {feed.bozo_exception}")
|
||||
|
||||
# 解析每个条目
|
||||
for entry in feed.entries[:50]: # 每次最多取 50 条
|
||||
try:
|
||||
# 提取标题
|
||||
title = entry.get('title', '')
|
||||
|
||||
# 提取内容
|
||||
content = ''
|
||||
if hasattr(entry, 'content'):
|
||||
content = entry.content[0].value if entry.content else ''
|
||||
elif hasattr(entry, 'summary'):
|
||||
content = entry.summary
|
||||
elif hasattr(entry, 'description'):
|
||||
content = entry.description
|
||||
|
||||
# 清理 HTML
|
||||
content = self._clean_html(content)
|
||||
|
||||
# 提取链接
|
||||
url = entry.get('link', '')
|
||||
|
||||
# 提取作者
|
||||
author = entry.get('author', None)
|
||||
|
||||
# 提取标签
|
||||
tags = []
|
||||
if hasattr(entry, 'tags'):
|
||||
tags = [tag.term for tag in entry.tags]
|
||||
|
||||
# 解析发布时间
|
||||
published_at = None
|
||||
if hasattr(entry, 'published_parsed'):
|
||||
published_at = datetime(*entry.published_parsed[:6])
|
||||
elif hasattr(entry, 'updated_parsed'):
|
||||
published_at = datetime(*entry.updated_parsed[:6])
|
||||
|
||||
# 只处理最近 24 小时的新闻
|
||||
if published_at:
|
||||
time_diff = datetime.utcnow() - published_at
|
||||
if time_diff > timedelta(hours=24):
|
||||
continue
|
||||
|
||||
# 生成内容哈希
|
||||
content_hash = self._generate_content_hash(title, content)
|
||||
|
||||
news_item = NewsItem(
|
||||
title=title,
|
||||
content=content,
|
||||
url=url,
|
||||
source=source['name'],
|
||||
category=source['category'],
|
||||
published_at=published_at,
|
||||
crawled_at=datetime.utcnow(),
|
||||
content_hash=content_hash,
|
||||
author=author,
|
||||
tags=tags if tags else None
|
||||
)
|
||||
|
||||
items.append(news_item)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"解析新闻条目失败: {e}")
|
||||
continue
|
||||
|
||||
logger.info(f"从 {source['name']} 获取到 {len(items)} 条新闻")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取 {source['name']} 失败: {e}")
|
||||
|
||||
return items
|
||||
|
||||
async def fetch_all_news(self, category: str = None) -> List[NewsItem]:
|
||||
"""
|
||||
获取所有新闻源的新闻
|
||||
|
||||
Args:
|
||||
category: 分类过滤 ('crypto', 'stock', None 表示全部)
|
||||
|
||||
Returns:
|
||||
所有新闻项列表
|
||||
"""
|
||||
sources = get_enabled_sources(category)
|
||||
|
||||
if not sources:
|
||||
logger.warning("没有启用的新闻源")
|
||||
return []
|
||||
|
||||
logger.info(f"开始从 {len(sources)} 个新闻源获取新闻...")
|
||||
|
||||
# 并发获取所有源
|
||||
tasks = [self.fetch_rss_feed(source) for source in sources]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 合并结果
|
||||
all_items = []
|
||||
for result in results:
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"获取新闻时出错: {result}")
|
||||
continue
|
||||
all_items.extend(result)
|
||||
|
||||
logger.info(f"总共获取到 {len(all_items)} 条新闻")
|
||||
|
||||
return all_items
|
||||
|
||||
async def fetch_single_url(self, url: str, source: str = "manual") -> Optional[NewsItem]:
|
||||
"""
|
||||
获取单个 URL 的新闻内容
|
||||
|
||||
Args:
|
||||
url: 新闻 URL
|
||||
source: 新闻来源名称
|
||||
|
||||
Returns:
|
||||
新闻项或 None
|
||||
"""
|
||||
try:
|
||||
response = await self.client.get(url)
|
||||
response.raise_for_status()
|
||||
|
||||
# 使用 BeautifulSoup 解析
|
||||
soup = BeautifulSoup(response.text, 'html.parser')
|
||||
|
||||
# 尝试提取标题
|
||||
title_tag = soup.find(['h1', 'title'])
|
||||
title = title_tag.get_text().strip() if title_tag else url
|
||||
|
||||
# 提取正文(简单处理,实际需要针对不同网站调整)
|
||||
content = self._clean_html(response.text)
|
||||
|
||||
# 生成哈希
|
||||
content_hash = self._generate_content_hash(title, content)
|
||||
|
||||
return NewsItem(
|
||||
title=title,
|
||||
content=content,
|
||||
url=url,
|
||||
source=source,
|
||||
category="manual",
|
||||
published_at=datetime.utcnow(),
|
||||
crawled_at=datetime.utcnow(),
|
||||
content_hash=content_hash
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取 URL {url} 失败: {e}")
|
||||
return None
|
||||
@ -1,267 +0,0 @@
|
||||
"""
|
||||
新闻去重和过滤模块
|
||||
"""
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict, Any, Set, Tuple
|
||||
from difflib import SequenceMatcher
|
||||
|
||||
from app.utils.logger import logger
|
||||
from app.news_agent.fetcher import NewsItem
|
||||
from app.news_agent.sources import CRYPTO_KEYWORDS, STOCK_KEYWORDS, SYMBOL_MAPPINGS
|
||||
|
||||
|
||||
class NewsDeduplicator:
|
||||
"""新闻去重器"""
|
||||
|
||||
def __init__(self):
|
||||
self.recent_hashes: Set[str] = set()
|
||||
self.hash_expiry: datetime = None
|
||||
self.expiry_hours = 24
|
||||
|
||||
def _clean_hash_cache(self):
|
||||
"""清理过期的哈希缓存"""
|
||||
now = datetime.utcnow()
|
||||
if self.hash_expiry is None or now > self.hash_expiry:
|
||||
self.recent_hashes.clear()
|
||||
self.hash_expiry = now + timedelta(hours=self.expiry_hours)
|
||||
logger.debug("哈希缓存已清理")
|
||||
|
||||
def check_duplicate(self, item: NewsItem) -> bool:
|
||||
"""
|
||||
检查新闻是否重复
|
||||
|
||||
Args:
|
||||
item: 新闻项
|
||||
|
||||
Returns:
|
||||
True 如果重复
|
||||
"""
|
||||
self._clean_hash_cache()
|
||||
|
||||
# 检查内容哈希
|
||||
if item.content_hash in self.recent_hashes:
|
||||
return True
|
||||
|
||||
# 添加到缓存
|
||||
self.recent_hashes.add(item.content_hash)
|
||||
return False
|
||||
|
||||
def deduplicate_list(self, items: List[NewsItem]) -> List[NewsItem]:
|
||||
"""
|
||||
对新闻列表进行去重
|
||||
|
||||
Args:
|
||||
items: 新闻项列表
|
||||
|
||||
Returns:
|
||||
去重后的新闻列表
|
||||
"""
|
||||
seen_hashes = set()
|
||||
unique_items = []
|
||||
|
||||
for item in items:
|
||||
if item.content_hash not in seen_hashes:
|
||||
seen_hashes.add(item.content_hash)
|
||||
unique_items.append(item)
|
||||
|
||||
removed = len(items) - len(unique_items)
|
||||
if removed > 0:
|
||||
logger.info(f"去重: 移除了 {removed} 条重复新闻")
|
||||
|
||||
return unique_items
|
||||
|
||||
def find_similar(self, item: NewsItem, existing_items: List[NewsItem], threshold: float = 0.85) -> List[NewsItem]:
|
||||
"""
|
||||
查找相似新闻(基于标题相似度)
|
||||
|
||||
Args:
|
||||
item: 待检查的新闻项
|
||||
existing_items: 已存在的新闻列表
|
||||
threshold: 相似度阈值
|
||||
|
||||
Returns:
|
||||
相似新闻列表
|
||||
"""
|
||||
similar = []
|
||||
|
||||
for existing in existing_items:
|
||||
# 只比较同类新闻
|
||||
if existing.category != item.category:
|
||||
continue
|
||||
|
||||
# 标题相似度
|
||||
similarity = SequenceMatcher(None, item.title.lower(), existing.title.lower()).ratio()
|
||||
|
||||
if similarity >= threshold:
|
||||
similar.append((existing, similarity))
|
||||
|
||||
# 按相似度排序
|
||||
similar.sort(key=lambda x: x[1], reverse=True)
|
||||
return [s[0] for s in similar]
|
||||
|
||||
|
||||
class NewsFilter:
|
||||
"""新闻过滤器 - 关键词和质量过滤"""
|
||||
|
||||
def __init__(self):
|
||||
self.crypto_keywords = CRYPTO_KEYWORDS
|
||||
self.stock_keywords = STOCK_KEYWORDS
|
||||
self.symbol_mappings = SYMBOL_MAPPINGS
|
||||
|
||||
def _extract_symbols(self, text: str, category: str) -> List[str]:
|
||||
"""
|
||||
从文本中提取相关的币种或股票代码
|
||||
|
||||
Args:
|
||||
text: 输入文本
|
||||
category: 分类 ('crypto', 'stock')
|
||||
|
||||
Returns:
|
||||
相关代码列表
|
||||
"""
|
||||
text_lower = text.lower()
|
||||
found_symbols = []
|
||||
|
||||
mappings = self.symbol_mappings
|
||||
for symbol, keywords in mappings.items():
|
||||
# 检查是否匹配
|
||||
for keyword in keywords:
|
||||
if keyword.lower() in text_lower:
|
||||
found_symbols.append(symbol)
|
||||
break
|
||||
|
||||
return found_symbols
|
||||
|
||||
def _check_keywords(self, text: str, category: str) -> Tuple[float, str]:
|
||||
"""
|
||||
检查关键词并返回影响评分
|
||||
|
||||
Args:
|
||||
text: 输入文本
|
||||
category: 分类
|
||||
|
||||
Returns:
|
||||
(影响评分, 原因)
|
||||
"""
|
||||
text_lower = text.lower()
|
||||
keywords_config = self.crypto_keywords if category == 'crypto' else self.stock_keywords
|
||||
|
||||
# 检查高影响关键词
|
||||
for keyword in keywords_config['high_impact']:
|
||||
if keyword.lower() in text_lower:
|
||||
return 1.0, f"匹配高影响关键词: {keyword}"
|
||||
|
||||
# 检查中等影响关键词
|
||||
for keyword in keywords_config['medium_impact']:
|
||||
if keyword.lower() in text_lower:
|
||||
return 0.7, f"匹配中等影响关键词: {keyword}"
|
||||
|
||||
return 0.0, "未匹配关键词"
|
||||
|
||||
def _calculate_quality_score(self, item: NewsItem) -> float:
|
||||
"""
|
||||
计算新闻质量分数
|
||||
|
||||
Args:
|
||||
item: 新闻项
|
||||
|
||||
Returns:
|
||||
质量分数 0-1
|
||||
"""
|
||||
score = 0.5 # 基础分
|
||||
|
||||
# 内容长度
|
||||
if len(item.content) > 500:
|
||||
score += 0.1
|
||||
if len(item.content) > 1000:
|
||||
score += 0.1
|
||||
|
||||
# 标题长度
|
||||
if 20 <= len(item.title) <= 150:
|
||||
score += 0.1
|
||||
|
||||
# 有作者
|
||||
if item.author:
|
||||
score += 0.1
|
||||
|
||||
# 有标签
|
||||
if item.tags and len(item.tags) > 0:
|
||||
score += 0.1
|
||||
|
||||
return min(score, 1.0)
|
||||
|
||||
def filter_news(self, items: List[NewsItem], min_quality: float = 0.3) -> List[NewsItem]:
|
||||
"""
|
||||
过滤新闻列表
|
||||
|
||||
Args:
|
||||
items: 新闻项列表
|
||||
min_quality: 最低质量分数
|
||||
|
||||
Returns:
|
||||
过滤后的新闻列表,附带影响评分
|
||||
"""
|
||||
filtered = []
|
||||
low_quality_count = 0
|
||||
no_keywords_count = 0
|
||||
|
||||
for item in items:
|
||||
# 计算质量分数
|
||||
quality_score = self._calculate_quality_score(item)
|
||||
|
||||
# 质量过滤
|
||||
if quality_score < min_quality:
|
||||
low_quality_count += 1
|
||||
continue
|
||||
|
||||
# 关键词检查
|
||||
text_to_check = f"{item.title} {item.content[:500]}"
|
||||
impact_score, impact_reason = self._check_keywords(text_to_check, item.category)
|
||||
|
||||
# 提取相关代码
|
||||
symbols = self._extract_symbols(text_to_check, item.category)
|
||||
|
||||
# 附加属性
|
||||
item.quality_score = quality_score
|
||||
item.impact_score = impact_score
|
||||
item.impact_reason = impact_reason
|
||||
item.relevant_symbols = symbols
|
||||
|
||||
# 至少匹配关键词
|
||||
if impact_score > 0:
|
||||
filtered.append(item)
|
||||
else:
|
||||
no_keywords_count += 1
|
||||
|
||||
logger.info(f"过滤结果: {len(filtered)} 条通过, {low_quality_count} 条低质量, {no_keywords_count} 条无关键词")
|
||||
|
||||
return filtered
|
||||
|
||||
def get_priority_score(self, item: NewsItem) -> float:
|
||||
"""
|
||||
计算优先级分数
|
||||
|
||||
Args:
|
||||
item: 新闻项
|
||||
|
||||
Returns:
|
||||
优先级分数
|
||||
"""
|
||||
score = 0.0
|
||||
|
||||
# 影响分数
|
||||
score += getattr(item, 'impact_score', 0.0) * 50
|
||||
|
||||
# 质量分数
|
||||
score += getattr(item, 'quality_score', 0.5) * 20
|
||||
|
||||
# 是否有相关代码
|
||||
if hasattr(item, 'relevant_symbols') and item.relevant_symbols:
|
||||
score += 10
|
||||
|
||||
# 新闻新鲜度(最近发布的优先)
|
||||
if item.published_at:
|
||||
hours_ago = (datetime.utcnow() - item.published_at).total_seconds() / 3600
|
||||
score += max(0, 20 - hours_ago)
|
||||
|
||||
return score
|
||||
@ -1,338 +0,0 @@
|
||||
"""
|
||||
新闻智能体 - 主控制器
|
||||
实时抓取、分析、通知重要新闻
|
||||
"""
|
||||
import asyncio
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from app.utils.logger import logger
|
||||
from app.config import get_settings
|
||||
from app.news_agent.sources import get_enabled_sources
|
||||
from app.news_agent.fetcher import NewsFetcher, NewsItem
|
||||
from app.news_agent.filter import NewsDeduplicator, NewsFilter
|
||||
from app.news_agent.analyzer import NewsAnalyzer, NewsAnalyzerSimple
|
||||
from app.news_agent.news_db_service import get_news_db_service
|
||||
from app.news_agent.notifier import get_news_notifier
|
||||
|
||||
|
||||
class NewsAgent:
|
||||
"""新闻智能体 - 主控制器"""
|
||||
|
||||
_instance = None
|
||||
_initialized = False
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
"""单例模式"""
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
"""初始化新闻智能体"""
|
||||
if NewsAgent._initialized:
|
||||
return
|
||||
|
||||
NewsAgent._initialized = True
|
||||
self.settings = get_settings()
|
||||
|
||||
# 核心组件
|
||||
self.fetcher = NewsFetcher()
|
||||
self.deduplicator = NewsDeduplicator()
|
||||
self.filter = NewsFilter()
|
||||
self.analyzer = NewsAnalyzer() # LLM 分析器
|
||||
self.simple_analyzer = NewsAnalyzerSimple() # 规则分析器(备用)
|
||||
self.db_service = get_news_db_service()
|
||||
self.notifier = get_news_notifier()
|
||||
|
||||
# 配置
|
||||
self.fetch_interval = 300 # 抓取间隔(秒)= 5分钟
|
||||
self.min_priority = 40.0 # 最低通知优先级
|
||||
self.use_llm = True # 使用 LLM 批量分析
|
||||
|
||||
# 统计数据
|
||||
self.stats = {
|
||||
'total_fetched': 0,
|
||||
'total_saved': 0,
|
||||
'total_analyzed': 0,
|
||||
'total_notified': 0,
|
||||
'last_fetch_time': None,
|
||||
'last_notify_time': None
|
||||
}
|
||||
|
||||
# 运行状态
|
||||
self.running = False
|
||||
self._task = None
|
||||
|
||||
logger.info("新闻智能体初始化完成")
|
||||
|
||||
async def start(self):
|
||||
"""启动新闻智能体"""
|
||||
if self.running:
|
||||
logger.warning("新闻智能体已在运行")
|
||||
return
|
||||
|
||||
self.running = True
|
||||
|
||||
# 发送启动通知
|
||||
sources = get_enabled_sources()
|
||||
crypto_count = sum(1 for s in sources if s['category'] == 'crypto')
|
||||
stock_count = sum(1 for s in sources if s['category'] == 'stock')
|
||||
|
||||
await self.notifier.notify_startup({
|
||||
'crypto_sources': crypto_count,
|
||||
'stock_sources': stock_count,
|
||||
'fetch_interval': self.fetch_interval
|
||||
})
|
||||
|
||||
# 启动后台任务
|
||||
self._task = asyncio.create_task(self._run_loop())
|
||||
|
||||
logger.info("新闻智能体已启动")
|
||||
|
||||
async def stop(self):
|
||||
"""停止新闻智能体"""
|
||||
if not self.running:
|
||||
return
|
||||
|
||||
self.running = False
|
||||
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
try:
|
||||
await self._task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
await self.fetcher.close()
|
||||
|
||||
logger.info("新闻智能体已停止")
|
||||
|
||||
async def _run_loop(self):
|
||||
"""主循环"""
|
||||
while self.running:
|
||||
try:
|
||||
await self._fetch_and_process_news()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"新闻处理循环出错: {e}")
|
||||
await self.notifier.notify_error(str(e))
|
||||
|
||||
# 等待下一次抓取
|
||||
await asyncio.sleep(self.fetch_interval)
|
||||
|
||||
async def _fetch_and_process_news(self):
|
||||
"""抓取并处理新闻"""
|
||||
logger.info("=" * 60)
|
||||
logger.info("开始新闻处理周期")
|
||||
|
||||
# 1. 抓取新闻
|
||||
items = await self.fetcher.fetch_all_news()
|
||||
self.stats['total_fetched'] += len(items)
|
||||
self.stats['last_fetch_time'] = datetime.utcnow().isoformat()
|
||||
|
||||
if not items:
|
||||
logger.info("没有获取到新新闻")
|
||||
return
|
||||
|
||||
logger.info(f"获取到 {len(items)} 条新闻")
|
||||
|
||||
# 2. 去重
|
||||
items = self.deduplicator.deduplicate_list(items)
|
||||
logger.info(f"去重后剩余 {len(items)} 条")
|
||||
|
||||
# 3. 过滤
|
||||
filtered_items = self.filter.filter_news(items)
|
||||
logger.info(f"过滤后剩余 {len(filtered_items)} 条")
|
||||
|
||||
if not filtered_items:
|
||||
logger.info("没有符合条件的新闻")
|
||||
return
|
||||
|
||||
# 4. 保存到数据库
|
||||
saved_articles = []
|
||||
for item in filtered_items:
|
||||
# 检查数据库中是否已存在
|
||||
if self.db_service.check_duplicate_by_hash(item.content_hash):
|
||||
continue
|
||||
|
||||
# 保存
|
||||
article_data = {
|
||||
'title': item.title,
|
||||
'content': item.content,
|
||||
'url': item.url,
|
||||
'source': item.source,
|
||||
'author': item.author,
|
||||
'category': item.category,
|
||||
'tags': item.tags,
|
||||
'published_at': item.published_at,
|
||||
'crawled_at': item.crawled_at,
|
||||
'content_hash': item.content_hash,
|
||||
'quality_score': getattr(item, 'quality_score', 0.5),
|
||||
}
|
||||
|
||||
article = self.db_service.save_article(article_data)
|
||||
if article:
|
||||
saved_articles.append((article, item))
|
||||
|
||||
self.stats['total_saved'] += len(saved_articles)
|
||||
logger.info(f"保存了 {len(saved_articles)} 条新文章")
|
||||
|
||||
if not saved_articles:
|
||||
return
|
||||
|
||||
# 5. LLM 分析(仅批量分析)
|
||||
analyzed_count = 0
|
||||
high_priority_articles = []
|
||||
|
||||
if self.use_llm:
|
||||
# 只使用批量分析 (异步)
|
||||
items_to_analyze = [item for _, item in saved_articles]
|
||||
results = await self.analyzer.analyze_batch(items_to_analyze)
|
||||
|
||||
for (article, _), result in zip(saved_articles, results):
|
||||
if result:
|
||||
priority = self.analyzer.calculate_priority(
|
||||
result,
|
||||
getattr(article, 'quality_score', 0.5)
|
||||
)
|
||||
self.db_service.mark_as_analyzed(article.id, result, priority)
|
||||
|
||||
analyzed_count += 1
|
||||
# 只发送重大影响(high)的新闻
|
||||
if result.get('market_impact') == 'high':
|
||||
article_dict = article.to_dict()
|
||||
article_dict.update({
|
||||
'llm_analyzed': True,
|
||||
'market_impact': result.get('market_impact'),
|
||||
'impact_type': result.get('impact_type'),
|
||||
'sentiment': result.get('sentiment'),
|
||||
'summary': result.get('summary'),
|
||||
'key_points': result.get('key_points'),
|
||||
'trading_advice': result.get('trading_advice'),
|
||||
'relevant_symbols': result.get('relevant_symbols'),
|
||||
'priority': priority,
|
||||
})
|
||||
high_priority_articles.append(article_dict)
|
||||
|
||||
else:
|
||||
# 使用规则分析
|
||||
for article, item in saved_articles:
|
||||
result = self.simple_analyzer.analyze_single(item)
|
||||
priority = result.get('confidence', 50)
|
||||
|
||||
self.db_service.mark_as_analyzed(article.id, result, priority)
|
||||
analyzed_count += 1
|
||||
|
||||
# 只发送重大影响(high)的新闻
|
||||
if result.get('market_impact') == 'high':
|
||||
article_dict = article.to_dict()
|
||||
article_dict.update({
|
||||
'llm_analyzed': True,
|
||||
'market_impact': result.get('market_impact'),
|
||||
'impact_type': result.get('impact_type'),
|
||||
'sentiment': result.get('sentiment'),
|
||||
'summary': result.get('summary'),
|
||||
'key_points': result.get('key_points'),
|
||||
'trading_advice': result.get('trading_advice'),
|
||||
'relevant_symbols': result.get('relevant_symbols'),
|
||||
'priority': priority,
|
||||
})
|
||||
high_priority_articles.append(article_dict)
|
||||
|
||||
self.stats['total_analyzed'] += analyzed_count
|
||||
logger.info(f"分析了 {analyzed_count} 条文章")
|
||||
|
||||
# 6. 发送通知(仅批量发送)- 增加过滤条件
|
||||
if high_priority_articles:
|
||||
# 按优先级排序
|
||||
high_priority_articles.sort(
|
||||
key=lambda x: x.get('priority', 0),
|
||||
reverse=True
|
||||
)
|
||||
|
||||
# 额外过滤:只推送真正重要的新闻
|
||||
truly_important_articles = []
|
||||
for article in high_priority_articles:
|
||||
impact = article.get('market_impact', 'low')
|
||||
priority = article.get('priority', 0)
|
||||
confidence = article.get('llm_analyzed', False) and article.get('relevant_symbols')
|
||||
|
||||
# 推送条件(满足其一即可):
|
||||
# 1. high 影响 + 优先级 >= 55
|
||||
# 2. high 影响 + 有明确相关代码
|
||||
# 3. 优先级 >= 60(特别重要)
|
||||
should_notify = (
|
||||
(impact == 'high' and priority >= 55) or
|
||||
(impact == 'high' and confidence) or
|
||||
(priority >= 60)
|
||||
)
|
||||
|
||||
if should_notify:
|
||||
truly_important_articles.append(article)
|
||||
|
||||
# 批量发送最多10条
|
||||
if truly_important_articles:
|
||||
await self.notifier.notify_news_batch(truly_important_articles[:10])
|
||||
for article in truly_important_articles[:10]:
|
||||
self.db_service.mark_as_notified(article['id'])
|
||||
self.stats['total_notified'] += 1
|
||||
|
||||
logger.info(f"推送了 {len(truly_important_articles)} 条真正重要的新闻(从 {len(high_priority_articles)} 条 high 中筛选)")
|
||||
else:
|
||||
logger.info(f"没有达到推送标准的新闻({len(high_priority_articles)} 条 high 但不够重要)")
|
||||
|
||||
self.stats['last_notify_time'] = datetime.utcnow().isoformat()
|
||||
|
||||
logger.info("=" * 60)
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""获取统计数据"""
|
||||
stats = self.stats.copy()
|
||||
stats['running'] = self.running
|
||||
stats['fetch_interval'] = self.fetch_interval
|
||||
stats['use_llm'] = self.use_llm
|
||||
|
||||
# 从数据库获取更多统计
|
||||
db_stats = self.db_service.get_stats(hours=24)
|
||||
stats['db_stats'] = db_stats
|
||||
|
||||
return stats
|
||||
|
||||
async def manual_fetch(self, category: str = None) -> Dict[str, Any]:
|
||||
"""
|
||||
手动触发新闻抓取
|
||||
|
||||
Args:
|
||||
category: 分类过滤
|
||||
|
||||
Returns:
|
||||
处理结果
|
||||
"""
|
||||
logger.info(f"手动触发新闻抓取: category={category}")
|
||||
|
||||
items = await self.fetcher.fetch_all_news(category)
|
||||
|
||||
result = {
|
||||
'fetched': len(items),
|
||||
'timestamp': datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
if items:
|
||||
# 这里可以触发处理流程
|
||||
# 为简化,只返回抓取结果
|
||||
result['items'] = [item.to_dict() for item in items[:5]]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# 全局实例
|
||||
_news_agent = None
|
||||
|
||||
|
||||
def get_news_agent() -> NewsAgent:
|
||||
"""获取新闻智能体单例"""
|
||||
global _news_agent
|
||||
if _news_agent is None:
|
||||
_news_agent = NewsAgent()
|
||||
return _news_agent
|
||||
@ -1,406 +0,0 @@
|
||||
"""
|
||||
新闻数据库服务
|
||||
"""
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict, Any, Optional
|
||||
from sqlalchemy import create_engine, and_, or_
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from app.models.news import NewsArticle
|
||||
from app.models.database import Base
|
||||
from app.config import get_settings
|
||||
from app.utils.logger import logger
|
||||
|
||||
|
||||
class NewsDatabaseService:
|
||||
"""新闻数据库服务"""
|
||||
|
||||
def __init__(self):
|
||||
self.settings = get_settings()
|
||||
self.engine = None
|
||||
self.SessionLocal = None
|
||||
self._init_db()
|
||||
|
||||
def _init_db(self):
|
||||
"""初始化数据库连接"""
|
||||
try:
|
||||
# 使用 settings.database_url 或构建路径
|
||||
if hasattr(self.settings, 'database_url'):
|
||||
database_url = self.settings.database_url
|
||||
elif hasattr(self.settings, 'database_path'):
|
||||
database_url = f"sqlite:///{self.settings.database_path}"
|
||||
else:
|
||||
# 默认路径
|
||||
database_url = "sqlite:///./backend/stock_agent.db"
|
||||
|
||||
self.engine = create_engine(
|
||||
database_url,
|
||||
connect_args={"check_same_thread": False},
|
||||
echo=False
|
||||
)
|
||||
|
||||
self.SessionLocal = sessionmaker(
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
bind=self.engine
|
||||
)
|
||||
|
||||
# 创建表(如果不存在)
|
||||
from app.models.news import NewsArticle
|
||||
NewsArticle.metadata.create_all(self.engine, checkfirst=True)
|
||||
|
||||
logger.info("新闻数据库服务初始化完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"新闻数据库初始化失败: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
# 重新抛出异常,避免 SessionLocal 为 None
|
||||
raise
|
||||
|
||||
def get_session(self) -> Session:
|
||||
"""获取数据库会话"""
|
||||
return self.SessionLocal()
|
||||
|
||||
def save_article(self, article_data: Dict[str, Any]) -> Optional[NewsArticle]:
|
||||
"""
|
||||
保存单篇文章
|
||||
|
||||
Args:
|
||||
article_data: 文章数据字典
|
||||
|
||||
Returns:
|
||||
保存的文章对象或 None
|
||||
"""
|
||||
session = self.get_session()
|
||||
try:
|
||||
article = NewsArticle(**article_data)
|
||||
session.add(article)
|
||||
session.commit()
|
||||
session.refresh(article)
|
||||
|
||||
logger.debug(f"文章保存成功: {article.title[:50]}...")
|
||||
return article
|
||||
|
||||
except IntegrityError as e:
|
||||
session.rollback()
|
||||
logger.debug(f"文章已存在(URL 重复): {article_data.get('url', '')}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"保存文章失败: {e}")
|
||||
return None
|
||||
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def check_duplicate_by_hash(self, content_hash: str, hours: int = 24) -> bool:
|
||||
"""
|
||||
检查内容哈希是否重复
|
||||
|
||||
Args:
|
||||
content_hash: 内容哈希
|
||||
hours: 检查最近多少小时
|
||||
|
||||
Returns:
|
||||
True 如果重复
|
||||
"""
|
||||
session = self.get_session()
|
||||
try:
|
||||
since = datetime.utcnow() - timedelta(hours=hours)
|
||||
|
||||
count = session.query(NewsArticle).filter(
|
||||
and_(
|
||||
NewsArticle.content_hash == content_hash,
|
||||
NewsArticle.created_at >= since
|
||||
)
|
||||
).count()
|
||||
|
||||
return count > 0
|
||||
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def mark_as_analyzed(
|
||||
self,
|
||||
article_id: int,
|
||||
analysis: Dict[str, Any],
|
||||
priority: float
|
||||
) -> bool:
|
||||
"""
|
||||
标记文章已分析
|
||||
|
||||
Args:
|
||||
article_id: 文章 ID
|
||||
analysis: LLM 分析结果
|
||||
priority: 优先级分数
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
session = self.get_session()
|
||||
try:
|
||||
article = session.query(NewsArticle).filter(
|
||||
NewsArticle.id == article_id
|
||||
).first()
|
||||
|
||||
if not article:
|
||||
logger.warning(f"文章不存在: {article_id}")
|
||||
return False
|
||||
|
||||
article.llm_analyzed = True
|
||||
article.market_impact = analysis.get('market_impact')
|
||||
article.impact_type = analysis.get('impact_type')
|
||||
article.sentiment = analysis.get('sentiment')
|
||||
article.summary = analysis.get('summary')
|
||||
article.key_points = analysis.get('key_points')
|
||||
article.trading_advice = analysis.get('trading_advice')
|
||||
article.relevant_symbols = analysis.get('relevant_symbols')
|
||||
article.quality_score = analysis.get('confidence', 70) / 100
|
||||
article.priority = priority
|
||||
|
||||
session.commit()
|
||||
|
||||
logger.debug(f"文章分析结果已保存: {article.title[:50]}...")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"保存分析结果失败: {e}")
|
||||
return False
|
||||
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def mark_as_notified(self, article_id: int, channel: str = 'feishu') -> bool:
|
||||
"""
|
||||
标记文章已发送通知
|
||||
|
||||
Args:
|
||||
article_id: 文章 ID
|
||||
channel: 通知渠道
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
session = self.get_session()
|
||||
try:
|
||||
article = session.query(NewsArticle).filter(
|
||||
NewsArticle.id == article_id
|
||||
).first()
|
||||
|
||||
if not article:
|
||||
return False
|
||||
|
||||
article.notified = True
|
||||
article.notification_sent_at = datetime.utcnow()
|
||||
article.notification_channel = channel
|
||||
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"标记通知状态失败: {e}")
|
||||
return False
|
||||
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def get_high_priority_articles(
|
||||
self,
|
||||
limit: int = 20,
|
||||
min_priority: float = 40.0,
|
||||
hours: int = 24
|
||||
) -> List[NewsArticle]:
|
||||
"""
|
||||
获取高优先级文章
|
||||
|
||||
Args:
|
||||
limit: 返回数量限制
|
||||
min_priority: 最低优先级分数
|
||||
hours: 查询最近多少小时
|
||||
|
||||
Returns:
|
||||
文章列表
|
||||
"""
|
||||
session = self.get_session()
|
||||
try:
|
||||
since = datetime.utcnow() - timedelta(hours=hours)
|
||||
|
||||
articles = session.query(NewsArticle).filter(
|
||||
and_(
|
||||
NewsArticle.llm_analyzed == True,
|
||||
NewsArticle.priority >= min_priority,
|
||||
NewsArticle.created_at >= since,
|
||||
NewsArticle.notified == False
|
||||
)
|
||||
).order_by(NewsArticle.priority.desc()).limit(limit).all()
|
||||
|
||||
return articles
|
||||
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def get_latest_articles(
|
||||
self,
|
||||
category: str = None,
|
||||
limit: int = 50,
|
||||
hours: int = 24
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取最新文章
|
||||
|
||||
Args:
|
||||
category: 分类过滤
|
||||
limit: 返回数量限制
|
||||
hours: 查询最近多少小时
|
||||
|
||||
Returns:
|
||||
文章字典列表
|
||||
"""
|
||||
session = self.get_session()
|
||||
try:
|
||||
since = datetime.utcnow() - timedelta(hours=hours)
|
||||
|
||||
query = session.query(NewsArticle).filter(
|
||||
NewsArticle.created_at >= since
|
||||
)
|
||||
|
||||
if category:
|
||||
query = query.filter(NewsArticle.category == category)
|
||||
|
||||
articles = query.order_by(
|
||||
NewsArticle.created_at.desc()
|
||||
).limit(limit).all()
|
||||
|
||||
return [article.to_dict() for article in articles]
|
||||
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def get_stats(self, hours: int = 24) -> Dict[str, Any]:
|
||||
"""
|
||||
获取统计数据
|
||||
|
||||
Args:
|
||||
hours: 统计最近多少小时
|
||||
|
||||
Returns:
|
||||
统计数据
|
||||
"""
|
||||
session = self.get_session()
|
||||
try:
|
||||
since = datetime.utcnow() - timedelta(hours=hours)
|
||||
|
||||
total = session.query(NewsArticle).filter(
|
||||
NewsArticle.created_at >= since
|
||||
).count()
|
||||
|
||||
analyzed = session.query(NewsArticle).filter(
|
||||
and_(
|
||||
NewsArticle.created_at >= since,
|
||||
NewsArticle.llm_analyzed == True
|
||||
)
|
||||
).count()
|
||||
|
||||
high_impact = session.query(NewsArticle).filter(
|
||||
and_(
|
||||
NewsArticle.created_at >= since,
|
||||
NewsArticle.market_impact == 'high'
|
||||
)
|
||||
).count()
|
||||
|
||||
notified = session.query(NewsArticle).filter(
|
||||
and_(
|
||||
NewsArticle.created_at >= since,
|
||||
NewsArticle.notified == True
|
||||
)
|
||||
).count()
|
||||
|
||||
return {
|
||||
'total_articles': total,
|
||||
'analyzed': analyzed,
|
||||
'high_impact': high_impact,
|
||||
'notified': notified,
|
||||
'hours': hours
|
||||
}
|
||||
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def get_unanalyzed_articles(self, limit: int = 50, hours: int = 24) -> List[NewsArticle]:
|
||||
"""
|
||||
获取未分析的文章
|
||||
|
||||
Args:
|
||||
limit: 返回数量限制
|
||||
hours: 查询最近多少小时
|
||||
|
||||
Returns:
|
||||
未分析的文章列表
|
||||
"""
|
||||
session = self.get_session()
|
||||
try:
|
||||
since = datetime.utcnow() - timedelta(hours=hours)
|
||||
|
||||
articles = session.query(NewsArticle).filter(
|
||||
and_(
|
||||
NewsArticle.llm_analyzed == False,
|
||||
NewsArticle.created_at >= since
|
||||
)
|
||||
).order_by(NewsArticle.created_at.desc()).limit(limit).all()
|
||||
|
||||
return articles
|
||||
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def clean_old_articles(self, days: int = 7) -> int:
|
||||
"""
|
||||
清理旧文章(设置为不活跃)
|
||||
|
||||
Args:
|
||||
days: 保留多少天的文章
|
||||
|
||||
Returns:
|
||||
清理的数量
|
||||
"""
|
||||
session = self.get_session()
|
||||
try:
|
||||
before = datetime.utcnow() - timedelta(days=days)
|
||||
|
||||
count = session.query(NewsArticle).filter(
|
||||
NewsArticle.created_at < before
|
||||
).update({
|
||||
'is_active': False
|
||||
})
|
||||
|
||||
session.commit()
|
||||
|
||||
if count > 0:
|
||||
logger.info(f"清理了 {count} 条旧文章")
|
||||
|
||||
return count
|
||||
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"清理旧文章失败: {e}")
|
||||
return 0
|
||||
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
# 全局实例
|
||||
_news_db_service = None
|
||||
|
||||
|
||||
def get_news_db_service() -> NewsDatabaseService:
|
||||
"""获取新闻数据库服务单例"""
|
||||
global _news_db_service
|
||||
if _news_db_service is None:
|
||||
_news_db_service = NewsDatabaseService()
|
||||
return _news_db_service
|
||||
@ -1,307 +0,0 @@
|
||||
"""
|
||||
新闻通知模块 - 发送飞书卡片通知
|
||||
"""
|
||||
from typing import Dict, Any, List
|
||||
|
||||
from app.utils.logger import logger
|
||||
from app.services.feishu_service import get_feishu_news_service
|
||||
|
||||
|
||||
class NewsNotifier:
|
||||
"""新闻通知器"""
|
||||
|
||||
def __init__(self):
|
||||
self.feishu = get_feishu_news_service()
|
||||
|
||||
def _get_emoji_for_impact(self, impact: str) -> str:
|
||||
"""根据影响级别获取表情符号"""
|
||||
emoji_map = {
|
||||
'high': '🔴',
|
||||
'medium': '🟡',
|
||||
'low': '🟢'
|
||||
}
|
||||
return emoji_map.get(impact, '📰')
|
||||
|
||||
def _get_emoji_for_impact_type(self, impact_type: str) -> str:
|
||||
"""根据影响类型获取表情符号"""
|
||||
emoji_map = {
|
||||
'bullish': '📈',
|
||||
'bearish': '📉',
|
||||
'neutral': '➡️'
|
||||
}
|
||||
return emoji_map.get(impact_type, '📊')
|
||||
|
||||
def _get_color_for_impact(self, impact: str) -> str:
|
||||
"""根据影响级别获取颜色"""
|
||||
color_map = {
|
||||
'high': 'red',
|
||||
'medium': 'orange',
|
||||
'low': 'blue'
|
||||
}
|
||||
return color_map.get(impact, 'grey')
|
||||
|
||||
async def notify_single_news(self, article: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
发送单条新闻通知
|
||||
|
||||
Args:
|
||||
article: 文章数据(包含分析结果)
|
||||
|
||||
Returns:
|
||||
是否发送成功
|
||||
"""
|
||||
try:
|
||||
impact = article.get('market_impact', 'low')
|
||||
impact_type = article.get('impact_type', 'neutral')
|
||||
title = article.get('title', '')
|
||||
summary = article.get('summary', '')
|
||||
source = article.get('source', '')
|
||||
category = article.get('category', '')
|
||||
url = article.get('url', '')
|
||||
trading_advice = article.get('trading_advice', '')
|
||||
relevant_symbols = article.get('relevant_symbols', [])
|
||||
key_points = article.get('key_points', [])
|
||||
|
||||
# 标题
|
||||
impact_emoji = self._get_emoji_for_impact(impact)
|
||||
type_emoji = self._get_emoji_for_impact_type(impact_type)
|
||||
category_text = '加密货币' if category == 'crypto' else '股票'
|
||||
|
||||
card_title = f"{impact_emoji} {type_emoji} 市场快讯 - {category_text}"
|
||||
|
||||
# 内容
|
||||
content_parts = [
|
||||
f"**来源**: {source}",
|
||||
f"**标题**: {title}",
|
||||
"",
|
||||
f"**摘要**: {summary}",
|
||||
]
|
||||
|
||||
# 关键点
|
||||
if key_points:
|
||||
content_parts.append("")
|
||||
content_parts.append("**关键点**:")
|
||||
for point in key_points[:3]:
|
||||
content_parts.append(f"• {point}")
|
||||
|
||||
# 交易建议
|
||||
if trading_advice:
|
||||
content_parts.append("")
|
||||
content_parts.append(f"**交易建议**: {trading_advice}")
|
||||
|
||||
# 相关代码
|
||||
if relevant_symbols:
|
||||
symbols_text = " ".join(relevant_symbols)
|
||||
content_parts.append("")
|
||||
content_parts.append(f"**相关**: {symbols_text}")
|
||||
|
||||
# 链接
|
||||
if url:
|
||||
content_parts.append("")
|
||||
content_parts.append(f"[查看原文]({url})")
|
||||
|
||||
# 影响
|
||||
impact_map = {'high': '重大影响', 'medium': '中等影响', 'low': '轻微影响'}
|
||||
content_parts.append("")
|
||||
content_parts.append(f"**影响**: {impact_map.get(impact, '未知')}")
|
||||
|
||||
# 颜色
|
||||
color = self._get_color_for_impact(impact)
|
||||
|
||||
# 发送
|
||||
content = "\n".join(content_parts)
|
||||
await self.feishu.send_card(card_title, content, color)
|
||||
|
||||
logger.info(f"新闻通知已发送: {title[:50]}...")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发送新闻通知失败: {e}")
|
||||
return False
|
||||
|
||||
async def notify_news_batch(self, articles: List[Dict[str, Any]]) -> bool:
|
||||
"""
|
||||
发送批量新闻通知(详细模式)
|
||||
|
||||
Args:
|
||||
articles: 文章列表
|
||||
|
||||
Returns:
|
||||
是否发送成功
|
||||
"""
|
||||
try:
|
||||
if not articles:
|
||||
return False
|
||||
|
||||
# 只显示重大影响新闻
|
||||
high_impact = [a for a in articles if a.get('market_impact') == 'high']
|
||||
|
||||
if not high_impact:
|
||||
logger.info("没有重大影响新闻,跳过通知")
|
||||
return False
|
||||
|
||||
title = f"🔴 重大市场新闻 ({len(high_impact)} 条)"
|
||||
|
||||
content_parts = []
|
||||
|
||||
# 获取时间(显示时分)
|
||||
created_time = high_impact[0].get('created_at', '')
|
||||
if created_time:
|
||||
# 格式: 2026-02-25T12:30:45 -> 02-25 12:30
|
||||
try:
|
||||
dt = created_time[:16].replace('T', ' ')
|
||||
content_parts.append(f"**时间**: {dt}")
|
||||
except:
|
||||
content_parts.append(f"**时间**: {created_time[:10]}")
|
||||
|
||||
# 只显示重大影响新闻
|
||||
for i, article in enumerate(high_impact[:5]):
|
||||
impact_type = article.get('impact_type', 'neutral')
|
||||
emoji = self._get_emoji_for_impact_type(impact_type)
|
||||
|
||||
# 每条新闻之间空一行
|
||||
if i > 0:
|
||||
content_parts.append("")
|
||||
|
||||
# 构建单条新闻的所有内容
|
||||
article_lines = []
|
||||
|
||||
# 标题
|
||||
title_text = article.get('title', '')
|
||||
article_lines.append(f"{emoji} **{title_text}**")
|
||||
|
||||
# 来源
|
||||
source = article.get('source', '')
|
||||
if source:
|
||||
article_lines.append(f"📰 来源: {source}")
|
||||
|
||||
# 新闻内容(摘要)
|
||||
summary = article.get('summary', '')
|
||||
content = article.get('content', '')
|
||||
if summary:
|
||||
article_lines.append(f"📝 {summary[:100]}")
|
||||
elif content:
|
||||
article_lines.append(f"📝 {content[:100]}")
|
||||
|
||||
# 影响和建议
|
||||
impact_desc = {
|
||||
'bullish': '📈 利好',
|
||||
'bearish': '📉 利空',
|
||||
'neutral': '➡️ 中性'
|
||||
}.get(impact_type, '➡️ 中性')
|
||||
|
||||
advice = article.get('trading_advice', '')
|
||||
if advice:
|
||||
article_lines.append(f"{impact_desc} | 💡 {advice}")
|
||||
|
||||
# 相关代码和链接
|
||||
extra_info = []
|
||||
symbols = article.get('relevant_symbols', [])
|
||||
if symbols and isinstance(symbols, list):
|
||||
extra_info.append(f"🔗 {' '.join(symbols[:4])}")
|
||||
|
||||
url = article.get('url', '')
|
||||
if url:
|
||||
extra_info.append(f"🔎 [查看原文]({url})")
|
||||
|
||||
if extra_info:
|
||||
article_lines.append(" ".join(extra_info))
|
||||
|
||||
# 将这条新闻的所有内容合并为一行
|
||||
content_parts.append(" | ".join(article_lines))
|
||||
|
||||
content = "\n".join(content_parts)
|
||||
await self.feishu.send_card(title, content, "red")
|
||||
|
||||
logger.info(f"重大新闻通知已发送: {len(high_impact)} 条")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发送批量新闻通知失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
async def notify_startup(self, config: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
发送启动通知
|
||||
|
||||
Args:
|
||||
config: 配置信息
|
||||
|
||||
Returns:
|
||||
是否发送成功
|
||||
"""
|
||||
try:
|
||||
crypto_sources = config.get('crypto_sources', 0)
|
||||
stock_sources = config.get('stock_sources', 0)
|
||||
interval = config.get('fetch_interval', 30)
|
||||
|
||||
title = "📰 新闻智能体已启动"
|
||||
|
||||
content_parts = [
|
||||
f"🤖 **功能**: 实时新闻监控与分析",
|
||||
f"",
|
||||
f"📊 **监控来源**:",
|
||||
f" • 加密货币: {crypto_sources} 个",
|
||||
f" • 股票: {stock_sources} 个",
|
||||
f"",
|
||||
f"⏱️ **抓取频率**: 每 {interval} 秒",
|
||||
f"",
|
||||
f"🎯 **分析能力**:",
|
||||
f" • LLM 智能分析",
|
||||
f" • 市场影响评估",
|
||||
f" • 交易建议生成",
|
||||
f"",
|
||||
f"📢 **通知策略**: 仅推送高影响新闻"
|
||||
]
|
||||
|
||||
content = "\n".join(content_parts)
|
||||
await self.feishu.send_card(title, content, "green")
|
||||
|
||||
logger.info("新闻智能体启动通知已发送")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发送启动通知失败: {e}")
|
||||
return False
|
||||
|
||||
async def notify_error(self, error_message: str) -> bool:
|
||||
"""
|
||||
发送错误通知
|
||||
|
||||
Args:
|
||||
error_message: 错误信息
|
||||
|
||||
Returns:
|
||||
是否发送成功
|
||||
"""
|
||||
try:
|
||||
title = "⚠️ 新闻智能体异常"
|
||||
|
||||
content = f"""
|
||||
**错误信息**: {error_message}
|
||||
|
||||
**建议操作**:
|
||||
1. 检查网络连接
|
||||
2. 查看日志文件
|
||||
3. 必要时重启服务
|
||||
"""
|
||||
await self.feishu.send_card(title, content, "red")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发送错误通知失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# 全局实例
|
||||
_news_notifier = None
|
||||
|
||||
|
||||
def get_news_notifier() -> NewsNotifier:
|
||||
"""获取新闻通知器单例"""
|
||||
global _news_notifier
|
||||
if _news_notifier is None:
|
||||
_news_notifier = NewsNotifier()
|
||||
return _news_notifier
|
||||
@ -1,314 +0,0 @@
|
||||
"""
|
||||
新闻源配置
|
||||
定义各类新闻的 RSS 源
|
||||
"""
|
||||
|
||||
# 加密货币新闻源
|
||||
CRYPTO_NEWS_SOURCES = [
|
||||
{
|
||||
"name": "Cointelegraph",
|
||||
"url": "https://cointelegraph.com/rss",
|
||||
"category": "crypto",
|
||||
"language": "en",
|
||||
"priority": 1.0, # 权重
|
||||
"enabled": True
|
||||
},
|
||||
{
|
||||
"name": "CoinDesk",
|
||||
"url": "https://www.coindesk.com/arc/outboundfeeds/rss/",
|
||||
"category": "crypto",
|
||||
"language": "en",
|
||||
"priority": 1.0,
|
||||
"enabled": True
|
||||
},
|
||||
{
|
||||
"name": "Decrypt",
|
||||
"url": "https://decrypt.co/feed",
|
||||
"category": "crypto",
|
||||
"language": "en",
|
||||
"priority": 0.9,
|
||||
"enabled": True
|
||||
},
|
||||
{
|
||||
"name": "The Block",
|
||||
"url": "https://www.theblock.co/rss.xml",
|
||||
"category": "crypto",
|
||||
"language": "en",
|
||||
"priority": 0.9,
|
||||
"enabled": True
|
||||
},
|
||||
{
|
||||
"name": "律动 BlockBeats",
|
||||
"url": "https://www.theblockbeats.info/feed",
|
||||
"category": "crypto",
|
||||
"language": "zh",
|
||||
"priority": 1.0,
|
||||
"enabled": False # RSS 格式问题,暂时禁用
|
||||
},
|
||||
{
|
||||
"name": "巴比特",
|
||||
"url": "https://www.8btc.com/feed",
|
||||
"category": "crypto",
|
||||
"language": "zh",
|
||||
"priority": 0.8,
|
||||
"enabled": False # 连接不稳定,暂时禁用
|
||||
},
|
||||
{
|
||||
"name": "CoinGlass",
|
||||
"url": "https://coinglass.com/news/rss",
|
||||
"category": "crypto",
|
||||
"language": "en",
|
||||
"priority": 0.8,
|
||||
"enabled": False # 返回 HTML 而非 RSS,暂时禁用
|
||||
},
|
||||
{
|
||||
"name": "CryptoSlate",
|
||||
"url": "https://cryptoslate.com/news/feed",
|
||||
"category": "crypto",
|
||||
"language": "en",
|
||||
"priority": 0.8,
|
||||
"enabled": False # RSS 格式问题,暂时禁用
|
||||
},
|
||||
{
|
||||
"name": "AMBCrypto",
|
||||
"url": "https://ambcrypto.com/feed",
|
||||
"category": "crypto",
|
||||
"language": "en",
|
||||
"priority": 0.7,
|
||||
"enabled": True
|
||||
},
|
||||
{
|
||||
"name": "Whale Alert",
|
||||
"url": "https://whale-alert.io/rss",
|
||||
"category": "crypto",
|
||||
"language": "en",
|
||||
"priority": 0.7,
|
||||
"enabled": False # 大额转账,可选择性开启
|
||||
},
|
||||
]
|
||||
|
||||
# 股票新闻源
|
||||
STOCK_NEWS_SOURCES = [
|
||||
{
|
||||
"name": "Reuters Business",
|
||||
"url": "https://www.reuters.com/finance/rss",
|
||||
"category": "stock",
|
||||
"language": "en",
|
||||
"priority": 1.0,
|
||||
"enabled": False # 返回 HTML 而非 RSS,暂时禁用
|
||||
},
|
||||
{
|
||||
"name": "CNBC",
|
||||
"url": "https://www.cnbc.com/id/100003114/device/rss/rss.html",
|
||||
"category": "stock",
|
||||
"language": "en",
|
||||
"priority": 1.0,
|
||||
"enabled": True
|
||||
},
|
||||
{
|
||||
"name": "Bloomberg Markets",
|
||||
"url": "https://feeds.bloomberg.com/markets/news.rss",
|
||||
"category": "stock",
|
||||
"language": "en",
|
||||
"priority": 1.0,
|
||||
"enabled": True
|
||||
},
|
||||
{
|
||||
"name": "Yahoo Finance",
|
||||
"url": "https://finance.yahoo.com/news/rssindex",
|
||||
"category": "stock",
|
||||
"language": "en",
|
||||
"priority": 0.8,
|
||||
"enabled": True
|
||||
},
|
||||
{
|
||||
"name": "MarketWatch",
|
||||
"url": "https://www.marketwatch.com/rss/topstories",
|
||||
"category": "stock",
|
||||
"language": "en",
|
||||
"priority": 0.9,
|
||||
"enabled": True
|
||||
},
|
||||
{
|
||||
"name": "Seeking Alpha",
|
||||
"url": "https://seekingalpha.com/article/rss",
|
||||
"category": "stock",
|
||||
"language": "en",
|
||||
"priority": 0.9,
|
||||
"enabled": False # RSS 格式问题,暂时禁用
|
||||
},
|
||||
{
|
||||
"name": "华尔街见闻",
|
||||
"url": "https://wallstreetcn.com/rss",
|
||||
"category": "stock",
|
||||
"language": "zh",
|
||||
"priority": 0.9,
|
||||
"enabled": False # RSS 格式问题,暂时禁用
|
||||
},
|
||||
{
|
||||
"name": "雪球",
|
||||
"url": "https://xueqiu.com/statuses/hot_stock.xml",
|
||||
"category": "stock",
|
||||
"language": "zh",
|
||||
"priority": 0.8,
|
||||
"enabled": False # 需要认证,暂时禁用
|
||||
},
|
||||
{
|
||||
"name": "Investing.com",
|
||||
"url": "https://www.investing.com/rss/news.rss",
|
||||
"category": "stock",
|
||||
"language": "en",
|
||||
"priority": 0.8,
|
||||
"enabled": True
|
||||
},
|
||||
{
|
||||
"name": "Business Insider",
|
||||
"url": "https://markets.businessinsider.com/rss/news",
|
||||
"category": "stock",
|
||||
"language": "en",
|
||||
"priority": 0.7,
|
||||
"enabled": True
|
||||
},
|
||||
]
|
||||
|
||||
# 获取所有启用的新闻源
|
||||
def get_enabled_sources(category: str = None) -> list:
|
||||
"""
|
||||
获取启用的新闻源
|
||||
|
||||
Args:
|
||||
category: 分类过滤 ('crypto', 'stock', None 表示全部)
|
||||
|
||||
Returns:
|
||||
启用的新闻源列表
|
||||
"""
|
||||
all_sources = CRYPTO_NEWS_SOURCES + STOCK_NEWS_SOURCES
|
||||
|
||||
if category:
|
||||
return [s for s in all_sources if s['enabled'] and s['category'] == category]
|
||||
|
||||
return [s for s in all_sources if s['enabled']]
|
||||
|
||||
|
||||
# 关键词配置 - 用于第一级过滤
|
||||
CRYPTO_KEYWORDS = {
|
||||
'high_impact': [
|
||||
# 监管相关(只保留真正重大的)
|
||||
'SEC ETF', 'ETF approved', 'ETF rejected', 'regulation ban',
|
||||
'监管禁令', 'ETF批准', 'ETF拒绝',
|
||||
|
||||
# 重大事件(只保留真正重大的)
|
||||
'hack $', 'exploit $', ' $ billion hack', # 需要金额上下文,避免普通新闻
|
||||
'bankruptcy', '破产', 'shut down', '暂停交易',
|
||||
'exchange collapse', '交易所倒闭',
|
||||
|
||||
# 超级机构(重大新闻才推送)
|
||||
'BlackRock ETF', 'Grayscale ETF', 'Fidelity ETF',
|
||||
'贝莱德ETF', '灰度ETF',
|
||||
|
||||
# 重大安全事故
|
||||
'bridge exploit', 'smart contract hack', '$ million stolen',
|
||||
'跨链桥攻击', '智能合约漏洞',
|
||||
],
|
||||
'medium_impact': [
|
||||
# 一般监管
|
||||
'SEC', 'regulation', 'legal', '合规', '监管',
|
||||
'approve', 'ban', '禁令',
|
||||
|
||||
# 市场动态(常见的价格波动)
|
||||
'ATH', 'all-time high', 'crash', 'surge', 'plunge',
|
||||
'历史新高', '暴跌', '暴涨', '突破',
|
||||
|
||||
# 技术更新
|
||||
'upgrade', 'fork', 'airdrop', 'launch',
|
||||
'升级', '分叉', '空投', '上线',
|
||||
|
||||
# 并购/合作
|
||||
'partnership', 'acquisition', 'merger',
|
||||
'合作', '并购', '收购',
|
||||
|
||||
# 宏观经济
|
||||
'fed', 'inflation', 'recession', 'interest rate',
|
||||
'美联储', '通胀', '加息', '降息',
|
||||
|
||||
# 机构和钱包
|
||||
'whale', 'wallet', 'exchange',
|
||||
'巨鲸', '钱包', '交易所',
|
||||
]
|
||||
}
|
||||
|
||||
STOCK_KEYWORDS = {
|
||||
'high_impact': [
|
||||
# 只保留真正重大、罕见的事件
|
||||
# 破产/退市级别
|
||||
'bankruptcy', 'delisting', 'fraud', 'scandal',
|
||||
'破产', '退市', '欺诈', '丑闻',
|
||||
|
||||
# 重大监管事件
|
||||
'antitrust', 'DOJ ', 'SEC investigation', 'sanction',
|
||||
'反垄断', '司法部', '证监会调查', '制裁',
|
||||
|
||||
# 超级并购/分拆
|
||||
'mega merger', 'mega acquisition', 'breakup', 'spinoff',
|
||||
'巨型并购', '分拆',
|
||||
|
||||
# 重大安全事故/风险
|
||||
'data breach', 'cyber attack', 'massive layoff', 'shutdown',
|
||||
'数据泄露', '网络攻击', '大规模裁员', '停产',
|
||||
],
|
||||
'medium_impact': [
|
||||
# 财报相关(移到这里,因为太常见)
|
||||
'earnings', 'revenue', 'profit', 'loss', 'guidance',
|
||||
'财报', '营收', '利润', '业绩预告',
|
||||
'beat', 'miss', 'surprise',
|
||||
'超预期', '不及预期',
|
||||
|
||||
# 一般事件
|
||||
'FDA', 'approval', 'recall', 'lawsuit', 'IPO',
|
||||
'批准', '召回', '诉讼', '上市',
|
||||
|
||||
# 并购重组(一般规模)
|
||||
'merger', 'acquisition', 'buyout',
|
||||
'并购', '收购', '重组',
|
||||
|
||||
# 市场动态
|
||||
'surge', 'plunge', 'rally', 'crash',
|
||||
'暴涨', '暴跌', '反弹', '崩盘',
|
||||
|
||||
# 管理层变动
|
||||
'CEO', 'CFO', 'resign', 'appoint', 'executive',
|
||||
'辞职', '任命',
|
||||
|
||||
# 评级相关
|
||||
'upgrade', 'downgrade', 'rating', 'target price',
|
||||
'评级', '目标价', '上调', '下调',
|
||||
'dividend', 'buyback', 'split',
|
||||
'分红', '回购', '拆股',
|
||||
]
|
||||
}
|
||||
|
||||
# 常见的币种和股票代码映射
|
||||
SYMBOL_MAPPINGS = {
|
||||
# 加密货币
|
||||
'BTC': ['bitcoin', 'btc', '比特币'],
|
||||
'ETH': ['ethereum', 'eth', '以太坊'],
|
||||
'BNB': ['binance', 'bnb', '币安'],
|
||||
'SOL': ['solana', 'sol'],
|
||||
'XRP': ['ripple', 'xrp'],
|
||||
'ADA': ['cardano', 'ada'],
|
||||
'DOGE': ['dogecoin', 'doge', '狗狗币'],
|
||||
'AVAX': ['avalanche', 'avax'],
|
||||
'DOT': ['polkadot', 'dot'],
|
||||
'MATIC': ['polygon', 'matic'],
|
||||
|
||||
# 美股
|
||||
'AAPL': ['apple', 'aapl', '苹果'],
|
||||
'NVDA': ['nvidia', 'nvda', '英伟达'],
|
||||
'MSFT': ['microsoft', 'msft', '微软'],
|
||||
'GOOGL': ['google', 'alphabet', 'googl', '谷歌'],
|
||||
'AMZN': ['amazon', 'amzn', '亚马逊'],
|
||||
'TSLA': ['tesla', 'tsla', '特斯拉'],
|
||||
'META': ['meta', 'facebook', 'meta'],
|
||||
'BRK.B': ['berkshire', 'buffett', '伯克希尔'],
|
||||
'JPM': ['jpmorgan', 'jpm', '摩根大通'],
|
||||
}
|
||||
@ -50,15 +50,16 @@ class BitgetLiveTradingService:
|
||||
continue
|
||||
return 0.0
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, account_id: str = "default", trading_api: Any = None):
|
||||
self.account_id = (account_id or "default").strip() or "default"
|
||||
self.settings = get_settings()
|
||||
self.max_total_leverage: float = self.settings.bitget_max_total_leverage
|
||||
self.max_single_position: float = self.settings.bitget_max_single_position
|
||||
self.circuit_breaker_drawdown: float = self.settings.account_max_drawdown
|
||||
|
||||
self.trading_api = get_bitget_trading_api()
|
||||
self.trading_api = trading_api or get_bitget_trading_api(self.account_id)
|
||||
if not self.trading_api:
|
||||
raise RuntimeError("Bitget 交易 API 初始化失败,请检查 API Key 配置")
|
||||
raise RuntimeError(f"Bitget 交易 API 初始化失败,请检查账号 {self.account_id} 的 API Key 配置")
|
||||
|
||||
# 初始余额(用于回撤计算)
|
||||
self.initial_balance: Optional[float] = None
|
||||
@ -66,6 +67,7 @@ class BitgetLiveTradingService:
|
||||
|
||||
logger.info(
|
||||
f"✅ BitgetLiveTradingService 初始化完成 "
|
||||
f"(account={self.account_id}, "
|
||||
f"(最大总杠杆: {self.max_total_leverage}x, "
|
||||
f"单笔上限: ${self.max_single_position}, "
|
||||
f"熔断阈值: {self.circuit_breaker_drawdown * 100:.0f}%)"
|
||||
@ -94,10 +96,11 @@ class BitgetLiveTradingService:
|
||||
}
|
||||
"""
|
||||
balance = self.trading_api.get_balance()
|
||||
account_tag = f"[Bitget:{getattr(self, 'account_id', 'default')}]"
|
||||
if not balance:
|
||||
logger.warning("[Bitget] get_balance() 返回空,API 调用可能失败")
|
||||
logger.warning(f"{account_tag} get_balance() 返回空,API 调用可能失败")
|
||||
else:
|
||||
logger.debug(f"[Bitget] get_balance 原始返回: {balance}")
|
||||
logger.debug(f"{account_tag} get_balance 原始返回: {balance}")
|
||||
|
||||
usdt = balance.get('USDT') or balance.get('usdt') or {}
|
||||
if not usdt:
|
||||
@ -133,13 +136,13 @@ class BitgetLiveTradingService:
|
||||
inferred_available = max(account_value - frozen, 0.0)
|
||||
if inferred_available > 0:
|
||||
logger.warning(
|
||||
f"[Bitget] 可用余额字段缺失,使用 account_value - frozen 回退: "
|
||||
f"{account_tag} 可用余额字段缺失,使用 account_value - frozen 回退: "
|
||||
f"${account_value:.2f} - ${frozen:.2f} = ${inferred_available:.2f}"
|
||||
)
|
||||
available = inferred_available
|
||||
|
||||
logger.info(
|
||||
f"[Bitget] 账户状态: available=${available:.2f}, "
|
||||
f"{account_tag} 账户状态: available=${available:.2f}, "
|
||||
f"frozen=${frozen:.2f}, equity=${equity:.2f}, account_value=${account_value:.2f}"
|
||||
)
|
||||
|
||||
@ -561,10 +564,23 @@ class BitgetLiveTradingService:
|
||||
{"success": bool, "order_id": str, "error"?: str}
|
||||
"""
|
||||
try:
|
||||
open_orders = self.get_open_orders(symbol)
|
||||
normalized_order_id = str(order_id)
|
||||
matched_order = next((o for o in open_orders if str(o.get('order_id', '')) == normalized_order_id), None)
|
||||
if not matched_order:
|
||||
logger.info(f"ℹ️ Bitget 挂单已不存在,视为撤单完成: {symbol} #{order_id}")
|
||||
return {
|
||||
"success": True,
|
||||
"order_id": normalized_order_id,
|
||||
"symbol": symbol,
|
||||
"already_closed": True,
|
||||
"message": "订单已不在挂单列表,可能已成交、已撤销或已失效",
|
||||
}
|
||||
|
||||
success = self.trading_api.cancel_order(symbol=symbol, order_id=order_id)
|
||||
if success:
|
||||
logger.info(f"✅ Bitget 单笔撤单成功: {symbol} #{order_id}")
|
||||
return {"success": True, "order_id": str(order_id), "symbol": symbol}
|
||||
return {"success": True, "order_id": normalized_order_id, "symbol": symbol}
|
||||
return {"success": False, "order_id": str(order_id), "error": "cancel_order 返回 False"}
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Bitget 单笔撤单失败: {symbol} #{order_id} {e}")
|
||||
@ -750,32 +766,54 @@ class BitgetLiveTradingService:
|
||||
|
||||
# ==================== 单例工厂 ====================
|
||||
|
||||
_bitget_live_service: Optional[BitgetLiveTradingService] = None
|
||||
_bitget_live_services: Dict[str, BitgetLiveTradingService] = {}
|
||||
|
||||
|
||||
def get_bitget_live_service() -> Optional[BitgetLiveTradingService]:
|
||||
def get_bitget_live_service(account_id: str = "default") -> Optional[BitgetLiveTradingService]:
|
||||
"""
|
||||
获取 BitgetLiveTradingService 单例。
|
||||
获取 BitgetLiveTradingService 单例(按账号)。
|
||||
|
||||
bitget_trading_enabled=False 时返回 None(功能关闭)。
|
||||
"""
|
||||
global _bitget_live_service
|
||||
global _bitget_live_services
|
||||
|
||||
settings = get_settings()
|
||||
if not settings.bitget_trading_enabled:
|
||||
return None
|
||||
|
||||
if _bitget_live_service is None:
|
||||
normalized_account_id = (account_id or "default").strip() or "default"
|
||||
existing = _bitget_live_services.get(normalized_account_id)
|
||||
if existing is not None:
|
||||
return existing
|
||||
|
||||
try:
|
||||
_bitget_live_service = BitgetLiveTradingService()
|
||||
service = BitgetLiveTradingService(account_id=normalized_account_id)
|
||||
_bitget_live_services[normalized_account_id] = service
|
||||
return service
|
||||
except Exception as e:
|
||||
logger.error(f"❌ BitgetLiveTradingService 初始化失败: {e}")
|
||||
logger.error(f"❌ BitgetLiveTradingService 初始化失败: account={normalized_account_id} error={e}")
|
||||
return None
|
||||
|
||||
return _bitget_live_service
|
||||
|
||||
def get_all_bitget_live_services() -> Dict[str, BitgetLiveTradingService]:
|
||||
"""获取所有已启用 Bitget 账号的服务实例。"""
|
||||
settings = get_settings()
|
||||
services: Dict[str, BitgetLiveTradingService] = {}
|
||||
for account in settings.get_enabled_bitget_accounts():
|
||||
account_id = account.get("account_id") or "default"
|
||||
service = get_bitget_live_service(account_id)
|
||||
if service:
|
||||
services[account_id] = service
|
||||
return services
|
||||
|
||||
|
||||
def reset_bitget_live_service():
|
||||
def reset_bitget_live_service(account_id: Optional[str] = None):
|
||||
"""重置单例(测试用)"""
|
||||
global _bitget_live_service
|
||||
_bitget_live_service = None
|
||||
global _bitget_live_services
|
||||
|
||||
if account_id is None:
|
||||
_bitget_live_services = {}
|
||||
return
|
||||
|
||||
normalized_account_id = (account_id or "default").strip() or "default"
|
||||
_bitget_live_services.pop(normalized_account_id, None)
|
||||
|
||||
@ -53,7 +53,12 @@ class BitgetTradingAPI:
|
||||
continue
|
||||
return default
|
||||
|
||||
def __init__(self, api_key: str, api_secret: str, passphrase: str = "", use_testnet: bool = True):
|
||||
def __init__(self,
|
||||
api_key: str,
|
||||
api_secret: str,
|
||||
passphrase: str = "",
|
||||
use_testnet: bool = True,
|
||||
account_id: str = "default"):
|
||||
"""
|
||||
初始化 Bitget 交易 API
|
||||
|
||||
@ -66,6 +71,7 @@ class BitgetTradingAPI:
|
||||
self.api_key = api_key
|
||||
self.api_secret = api_secret
|
||||
self.use_testnet = use_testnet
|
||||
self.account_id = account_id or "default"
|
||||
from app.config import get_settings
|
||||
self.settings = get_settings()
|
||||
self.use_unified_account = getattr(self.settings, 'bitget_use_unified_account', True)
|
||||
@ -93,7 +99,10 @@ class BitgetTradingAPI:
|
||||
if use_testnet:
|
||||
logger.info("✅ Bitget 测试网模式(使用相同端点,由 API key 区分)")
|
||||
|
||||
logger.info(f"Bitget 交易 API 初始化完成 ({'测试网' if use_testnet else '生产网'})")
|
||||
logger.info(
|
||||
f"Bitget 交易 API 初始化完成 "
|
||||
f"(account={self.account_id}, {'测试网' if use_testnet else '生产网'})"
|
||||
)
|
||||
|
||||
# ==================== 订单操作 ====================
|
||||
|
||||
@ -1209,46 +1218,74 @@ class BitgetTradingAPI:
|
||||
logger.info("Bitget API 连接已关闭")
|
||||
|
||||
|
||||
# 全局实例(延迟初始化)
|
||||
_trading_api: Optional[BitgetTradingAPI] = None
|
||||
# 全局实例(按账号延迟初始化)
|
||||
_trading_api_instances: Dict[str, BitgetTradingAPI] = {}
|
||||
|
||||
|
||||
def get_bitget_trading_api() -> Optional[BitgetTradingAPI]:
|
||||
def get_bitget_trading_api(account_id: str = "default") -> Optional[BitgetTradingAPI]:
|
||||
"""
|
||||
获取 Bitget 交易 API 实例(单例)
|
||||
获取 Bitget 交易 API 实例(按账号缓存)。
|
||||
|
||||
Returns:
|
||||
BitgetTradingAPI 实例或 None(如果未配置)
|
||||
BitgetTradingAPI 实例或 None(如果未配置/未启用)
|
||||
"""
|
||||
global _trading_api
|
||||
|
||||
if _trading_api:
|
||||
return _trading_api
|
||||
normalized_account_id = (account_id or "default").strip() or "default"
|
||||
existing = _trading_api_instances.get(normalized_account_id)
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
from app.config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
# 检查是否配置了 API Key
|
||||
if not settings.bitget_api_key or not settings.bitget_api_secret:
|
||||
logger.warning("Bitget API Key 未配置<E9858D><E7BDAE>实盘交易功能不可用")
|
||||
account_config = settings.get_bitget_account_config(normalized_account_id)
|
||||
if not account_config.get("enabled"):
|
||||
logger.warning(f"Bitget 账号未启用或凭证不完整: account={normalized_account_id}")
|
||||
return None
|
||||
|
||||
# 创建实例
|
||||
_trading_api = BitgetTradingAPI(
|
||||
api_key=settings.bitget_api_key,
|
||||
api_secret=settings.bitget_api_secret,
|
||||
passphrase=settings.bitget_passphrase,
|
||||
use_testnet=settings.bitget_use_testnet
|
||||
instance = BitgetTradingAPI(
|
||||
api_key=account_config.get("api_key", ""),
|
||||
api_secret=account_config.get("api_secret", ""),
|
||||
passphrase=account_config.get("passphrase", ""),
|
||||
use_testnet=bool(account_config.get("use_testnet", settings.bitget_use_testnet)),
|
||||
account_id=normalized_account_id,
|
||||
)
|
||||
|
||||
return _trading_api
|
||||
_trading_api_instances[normalized_account_id] = instance
|
||||
return instance
|
||||
|
||||
|
||||
def reset_bitget_trading_api():
|
||||
"""重置全局实例(用于测试或配置更新)"""
|
||||
global _trading_api
|
||||
if _trading_api:
|
||||
_trading_api.close()
|
||||
_trading_api = None
|
||||
logger.info("Bitget API 实例已重置")
|
||||
def get_all_bitget_trading_apis() -> Dict[str, BitgetTradingAPI]:
|
||||
"""返回所有已启用账号的 API 实例。"""
|
||||
from app.config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
instances: Dict[str, BitgetTradingAPI] = {}
|
||||
for account in settings.get_enabled_bitget_accounts():
|
||||
account_id = account.get("account_id") or "default"
|
||||
instance = get_bitget_trading_api(account_id)
|
||||
if instance:
|
||||
instances[account_id] = instance
|
||||
return instances
|
||||
|
||||
|
||||
def reset_bitget_trading_api(account_id: Optional[str] = None):
|
||||
"""重置全局实例(用于测试或配置更新)。"""
|
||||
global _trading_api_instances
|
||||
|
||||
if account_id is None:
|
||||
for instance in _trading_api_instances.values():
|
||||
try:
|
||||
instance.close()
|
||||
except Exception:
|
||||
pass
|
||||
_trading_api_instances = {}
|
||||
logger.info("Bitget API 实例已全部重置")
|
||||
return
|
||||
|
||||
normalized_account_id = (account_id or "default").strip() or "default"
|
||||
instance = _trading_api_instances.pop(normalized_account_id, None)
|
||||
if instance:
|
||||
try:
|
||||
instance.close()
|
||||
except Exception:
|
||||
pass
|
||||
logger.info(f"Bitget API 实例已重置: account={normalized_account_id}")
|
||||
|
||||
@ -104,7 +104,7 @@ class CacheService:
|
||||
清除匹配模式的所有缓存
|
||||
|
||||
Args:
|
||||
pattern: 键模式(如 "stock:*")
|
||||
pattern: 键模式(如 "crypto:*")
|
||||
|
||||
Returns:
|
||||
删除的键数量
|
||||
|
||||
@ -258,8 +258,7 @@ class DingTalkService:
|
||||
|
||||
# 市场类型映射
|
||||
market_map = {
|
||||
'crypto': '[加密货币]',
|
||||
'stock': '[股票]'
|
||||
'crypto': '[加密货币]'
|
||||
}
|
||||
|
||||
action_text = action_map.get(action, action)
|
||||
@ -273,7 +272,7 @@ class DingTalkService:
|
||||
> **趋势**: {trend}
|
||||
> **信心度**: {confidence}%
|
||||
|
||||
*信号来源: Stock Agent*
|
||||
*信号来源: Crypto Agent*
|
||||
"""
|
||||
|
||||
return await self.send_markdown(title, content)
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
"""
|
||||
飞书通知服务 - 通过 Webhook 发送交易信号通知
|
||||
支持加密货币和股票两个独立的 webhook
|
||||
飞书通知服务 - 通过 Webhook 发送加密货币交易通知
|
||||
"""
|
||||
import json
|
||||
import httpx
|
||||
@ -18,7 +17,7 @@ class FeishuService:
|
||||
|
||||
Args:
|
||||
webhook_url: 飞书机器人 Webhook URL(如果为空,则根据 service_type 从配置读取)
|
||||
service_type: 服务类型 ("crypto", "stock", "news", "paper_trading", "error")
|
||||
service_type: 服务类型 ("crypto", "paper_trading", "error")
|
||||
"""
|
||||
settings = get_settings()
|
||||
|
||||
@ -29,10 +28,6 @@ class FeishuService:
|
||||
# 否则根据服务类型从配置读取
|
||||
if service_type == "crypto":
|
||||
self.webhook_url = getattr(settings, 'feishu_crypto_webhook_url', '')
|
||||
elif service_type == "stock":
|
||||
self.webhook_url = getattr(settings, 'feishu_stock_webhook_url', '')
|
||||
elif service_type == "news":
|
||||
self.webhook_url = getattr(settings, 'feishu_news_webhook_url', '')
|
||||
elif service_type == "paper_trading":
|
||||
self.webhook_url = getattr(settings, 'feishu_paper_trading_webhook_url', '')
|
||||
elif service_type == "error":
|
||||
@ -289,10 +284,8 @@ class FeishuService:
|
||||
|
||||
|
||||
|
||||
# 全局实例(延迟初始化)- 分别用于加密货币、股票、新闻和模拟交易
|
||||
# 全局实例(延迟初始化)
|
||||
_feishu_crypto_service: Optional[FeishuService] = None
|
||||
_feishu_stock_service: Optional[FeishuService] = None
|
||||
_feishu_news_service: Optional[FeishuService] = None
|
||||
_feishu_paper_trading_service: Optional[FeishuService] = None
|
||||
_feishu_error_service: Optional[FeishuService] = None
|
||||
|
||||
@ -310,22 +303,6 @@ def get_feishu_crypto_service() -> FeishuService:
|
||||
return _feishu_crypto_service
|
||||
|
||||
|
||||
def get_feishu_stock_service() -> FeishuService:
|
||||
"""获取股票飞书服务实例"""
|
||||
global _feishu_stock_service
|
||||
if _feishu_stock_service is None:
|
||||
_feishu_stock_service = FeishuService(service_type="stock")
|
||||
return _feishu_stock_service
|
||||
|
||||
|
||||
def get_feishu_news_service() -> FeishuService:
|
||||
"""获取新闻智能体飞书服务实例"""
|
||||
global _feishu_news_service
|
||||
if _feishu_news_service is None:
|
||||
_feishu_news_service = FeishuService(service_type="news")
|
||||
return _feishu_news_service
|
||||
|
||||
|
||||
def get_feishu_paper_trading_service() -> FeishuService:
|
||||
"""获取模拟交易飞书服务实例"""
|
||||
global _feishu_paper_trading_service
|
||||
|
||||
@ -1,550 +0,0 @@
|
||||
"""
|
||||
基本面因子数据服务
|
||||
获取美股和港股的基本面数据,包括估值、盈利能力、成长性等指标
|
||||
"""
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime
|
||||
import pandas as pd
|
||||
|
||||
try:
|
||||
import yfinance as yf
|
||||
YFINANCE_AVAILABLE = True
|
||||
except ImportError:
|
||||
YFINANCE_AVAILABLE = False
|
||||
|
||||
from app.utils.logger import logger
|
||||
|
||||
|
||||
class FundamentalService:
|
||||
"""基本面因子数据服务"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化服务"""
|
||||
if not YFINANCE_AVAILABLE:
|
||||
logger.warning("yfinance 未安装,基本面数据功能将不可用")
|
||||
return
|
||||
|
||||
self._cache = {} # 数据缓存
|
||||
self._cache_time = {} # 缓存时间
|
||||
self._cache_ttl = 3600 # 缓存有效期1小时
|
||||
|
||||
logger.info("基本面数据服务初始化成功")
|
||||
|
||||
@staticmethod
|
||||
def _normalize_hk_symbol(symbol: str) -> str:
|
||||
"""
|
||||
标准化港股代码格式为 yfinance 要求的格式
|
||||
- 4位及以下:左侧补零到4位,如 700.HK → 0700.HK, 5.HK → 0005.HK
|
||||
- 5位及以上:去掉前导零,如 09618.HK → 9618.HK
|
||||
"""
|
||||
if not symbol.endswith('.HK'):
|
||||
return symbol
|
||||
|
||||
# 分离代码和后缀
|
||||
code_part = symbol[:-3] # 去掉 .HK
|
||||
suffix = '.HK'
|
||||
|
||||
# 如果是纯数字代码
|
||||
if code_part.isdigit():
|
||||
# 4位及以下:补零到4位
|
||||
if len(code_part) <= 4:
|
||||
normalized_code = code_part.zfill(4)
|
||||
# 5位及以上:去掉前导零
|
||||
else:
|
||||
normalized_code = code_part.lstrip('0') or '0'
|
||||
else:
|
||||
normalized_code = code_part
|
||||
|
||||
return normalized_code + suffix
|
||||
|
||||
def get_fundamental_data(self, symbol: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取股票的基本面数据
|
||||
|
||||
Args:
|
||||
symbol: 股票代码,如 'AAPL', '0700.HK'
|
||||
|
||||
Returns:
|
||||
基本面数据字典,包含估值、盈利、成长等指标
|
||||
"""
|
||||
if not YFINANCE_AVAILABLE:
|
||||
return None
|
||||
|
||||
try:
|
||||
# 标准化港股代码格式
|
||||
normalized_symbol = self._normalize_hk_symbol(symbol)
|
||||
|
||||
ticker = yf.Ticker(normalized_symbol)
|
||||
|
||||
# 获取股票信息
|
||||
info = ticker.info
|
||||
|
||||
if not info:
|
||||
logger.warning(f"无法获取 {symbol} 的基本面数据")
|
||||
return None
|
||||
|
||||
# 提取关键指标
|
||||
fundamental_data = {
|
||||
'symbol': symbol,
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
|
||||
# 基本信息
|
||||
'company_name': info.get('longName', info.get('shortName', 'N/A')),
|
||||
'sector': info.get('sector', 'N/A'),
|
||||
'industry': info.get('industry', 'N/A'),
|
||||
'market_cap': info.get('marketCap'),
|
||||
'shares_outstanding': info.get('sharesOutstanding'),
|
||||
|
||||
# 估值指标
|
||||
'valuation': self._extract_valuation_metrics(info),
|
||||
|
||||
# 盈利能力
|
||||
'profitability': self._extract_profitability_metrics(info),
|
||||
|
||||
# 成长性
|
||||
'growth': self._extract_growth_metrics(info),
|
||||
|
||||
# 财务健康
|
||||
'financial_health': self._extract_financial_health_metrics(info),
|
||||
|
||||
# 股票回报
|
||||
'returns': self._extract_return_metrics(info),
|
||||
|
||||
# 分析师建议
|
||||
'analyst': self._extract_analyst_metrics(info),
|
||||
}
|
||||
|
||||
# 计算综合评分
|
||||
fundamental_data['score'] = self._calculate_fundamental_score(fundamental_data)
|
||||
|
||||
# 输出基本面关键指标
|
||||
score = fundamental_data.get('score', {})
|
||||
logger.info(f"✓ {symbol} 基本面数据获取成功")
|
||||
logger.info(f" 【公司】{fundamental_data.get('company_name', 'N/A')} | {fundamental_data.get('sector', 'N/A')}")
|
||||
logger.info(f" 【评分】总分: {score.get('total', 0):.0f}/100 ({score.get('rating', 'N/A')}级) | "
|
||||
f"估值:{score.get('valuation', 0)} 盈利:{score.get('profitability', 0)} "
|
||||
f"成长:{score.get('growth', 0)} 财务:{score.get('financial_health', 0)}")
|
||||
|
||||
# 估值指标
|
||||
val = fundamental_data.get('valuation', {})
|
||||
if val.get('pe_ratio'):
|
||||
pe = val['pe_ratio']
|
||||
pb = val.get('pb_ratio')
|
||||
ps = val.get('ps_ratio')
|
||||
peg = val.get('peg_ratio')
|
||||
pb_str = f"{pb:.2f}" if pb is not None else "N/A"
|
||||
ps_str = f"{ps:.2f}" if ps is not None else "N/A"
|
||||
peg_str = f"{peg:.2f}" if peg is not None else "N/A"
|
||||
logger.info(f" 【估值】PE:{pe:.2f} | PB:{pb_str} | PS:{ps_str} | PEG:{peg_str}")
|
||||
|
||||
# 盈利能力
|
||||
prof = fundamental_data.get('profitability', {})
|
||||
if prof.get('return_on_equity'):
|
||||
roe = prof['return_on_equity']
|
||||
pm = prof.get('profit_margin')
|
||||
gm = prof.get('gross_margin')
|
||||
pm_str = f"{pm:.1f}" if pm is not None else "N/A"
|
||||
gm_str = f"{gm:.1f}" if gm is not None else "N/A"
|
||||
logger.info(f" 【盈利】ROE:{roe:.2f}% | 净利率:{pm_str}% | 毛利率:{gm_str}%")
|
||||
|
||||
# 成长性
|
||||
growth = fundamental_data.get('growth', {})
|
||||
rg = growth.get('revenue_growth')
|
||||
eg = growth.get('earnings_growth')
|
||||
if rg is not None or eg is not None:
|
||||
rg_str = f"{rg:.1f}" if rg is not None else "N/A"
|
||||
eg_str = f"{eg:.1f}" if eg is not None else "N/A"
|
||||
logger.info(f" 【成长】营收增长:{rg_str}% | 盈利增长:{eg_str}%")
|
||||
|
||||
# 财务健康
|
||||
fin = fundamental_data.get('financial_health', {})
|
||||
if fin.get('debt_to_equity'):
|
||||
de = fin['debt_to_equity']
|
||||
cr = fin.get('current_ratio')
|
||||
cr_str = f"{cr:.2f}" if cr is not None else "N/A"
|
||||
logger.info(f" 【财务】债务股本比:{de:.2f} | 流动比率:{cr_str}")
|
||||
|
||||
# 分析师建议
|
||||
analyst = fundamental_data.get('analyst', {})
|
||||
tp = analyst.get('target_price')
|
||||
if tp:
|
||||
rec = analyst.get('recommendation', 'N/A')
|
||||
logger.info(f" 【分析师】目标价:${tp:.2f} | 评级:{rec}")
|
||||
|
||||
return fundamental_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取 {symbol} 基本面数据失败: {e}")
|
||||
return None
|
||||
|
||||
def _extract_valuation_metrics(self, info: Dict) -> Dict[str, Any]:
|
||||
"""提取估值指标"""
|
||||
return {
|
||||
'pe_ratio': info.get('trailingPE'), # 市盈率
|
||||
'forward_pe': info.get('forwardPE'), # 远期市盈率
|
||||
'peg_ratio': info.get('pegRatio'), # PEG
|
||||
'pb_ratio': info.get('priceToBook'), # 市净率
|
||||
'ps_ratio': info.get('priceToSalesTrailing12M'), # 市销率
|
||||
'ev_to_ebitda': info.get('enterpriseToEbitda'), # EV/EBITDA
|
||||
'enterprise_value': info.get('enterpriseValue'), # 企业价值
|
||||
}
|
||||
|
||||
def _extract_profitability_metrics(self, info: Dict) -> Dict[str, Any]:
|
||||
"""提取盈利能力指标"""
|
||||
return {
|
||||
'eps': info.get('trailingEps'), # 每股收益
|
||||
'forward_eps': info.get('forwardEps'), # 预期每股收益
|
||||
'revenue': info.get('totalRevenue'), # 总收入
|
||||
'net_income': info.get('netIncomeToCommon'), # 净收入
|
||||
'profit_margin': info.get('profitMargins'), # 利润率
|
||||
'operating_margin': info.get('operatingMargins'), # 营业利润率
|
||||
'gross_margin': info.get('grossMargins'), # 毛利率
|
||||
'ebitda': info.get('ebitda'), # EBITDA
|
||||
'ebitda_margins': info.get('ebitdaMargins'), # EBITDA利润率
|
||||
}
|
||||
|
||||
def _extract_growth_metrics(self, info: Dict) -> Dict[str, Any]:
|
||||
"""提取成长性指标"""
|
||||
return {
|
||||
'revenue_growth': info.get('revenueGrowth'), # 营收增长率
|
||||
'earnings_growth': info.get('earningsGrowth'), # 盈利增长
|
||||
'earnings_quarterly_growth': info.get('earningsQuarterlyGrowth'), # 季度盈利增长
|
||||
'revenue_quarterly_growth': info.get('revenueQuarterlyGrowth'), # 季度营收增长
|
||||
}
|
||||
|
||||
def _extract_financial_health_metrics(self, info: Dict) -> Dict[str, Any]:
|
||||
"""提取财务健康指标"""
|
||||
return {
|
||||
'debt_to_equity': info.get('debtToEquity'), # 债务股本比
|
||||
'current_ratio': info.get('currentRatio'), # 流动比率
|
||||
'quick_ratio': info.get('quickRatio'), # 速动比率
|
||||
'total_cash': info.get('totalCash'), # 总现金
|
||||
'total_debt': info.get('totalDebt'), # 总债务
|
||||
'operating_cashflow': info.get('operatingCashflow'), # 经营现金流
|
||||
'free_cashflow': info.get('freeCashflow'), # 自由现金流
|
||||
}
|
||||
|
||||
def _extract_return_metrics(self, info: Dict) -> Dict[str, Any]:
|
||||
"""提取股票回报指标"""
|
||||
return {
|
||||
'dividend_rate': info.get('dividendRate'), # 股息率
|
||||
'dividend_yield': info.get('dividendYield'), # 股息收益率
|
||||
'payout_ratio': info.get('payoutRatio'), # 派息比率
|
||||
'five_year_avg_dividend_yield': info.get('fiveYearAvgDividendYield'), # 5年平均股息率
|
||||
'return_on_equity': info.get('returnOnEquity'), # ROE
|
||||
'return_on_assets': info.get('returnOnAssets'), # ROA
|
||||
}
|
||||
|
||||
def _extract_analyst_metrics(self, info: Dict) -> Dict[str, Any]:
|
||||
"""提取分析师建议"""
|
||||
return {
|
||||
'target_price': info.get('targetMeanPrice'), # 目标价
|
||||
'target_high': info.get('targetHighPrice'), # 目标价上限
|
||||
'target_low': info.get('targetLowPrice'), # 目标价下限
|
||||
'recommendation': info.get('recommendationKey'), # 分析师建议
|
||||
'number_of_analysts': info.get('numberOfAnalystOpinions'), # 分析师数量
|
||||
}
|
||||
|
||||
def _calculate_fundamental_score(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
计算基本面综合评分(0-100分)
|
||||
|
||||
评分维度:
|
||||
1. 估值合理性 (0-25分)
|
||||
2. 盈利能力 (0-25分)
|
||||
3. 成长性 (0-25分)
|
||||
4. 财务健康 (0-25分)
|
||||
"""
|
||||
scores = {
|
||||
'valuation': 0,
|
||||
'profitability': 0,
|
||||
'growth': 0,
|
||||
'financial_health': 0,
|
||||
'total': 0
|
||||
}
|
||||
|
||||
try:
|
||||
# 1. 估值评分 (0-25分)
|
||||
valuation = data.get('valuation', {})
|
||||
if valuation.get('pe_ratio'):
|
||||
pe = valuation['pe_ratio']
|
||||
# PE < 15: 优秀,15-25: 良好,25-40: 一般,>40: 偏高
|
||||
if pe < 15:
|
||||
scores['valuation'] = 25
|
||||
elif pe < 25:
|
||||
scores['valuation'] = 20
|
||||
elif pe < 40:
|
||||
scores['valuation'] = 10
|
||||
else:
|
||||
scores['valuation'] = 5
|
||||
|
||||
# 2. 盈利能力评分 (0-25分)
|
||||
profitability = data.get('profitability', {})
|
||||
roe = profitability.get('return_on_equity')
|
||||
profit_margin = profitability.get('profit_margin')
|
||||
|
||||
# 处理 None 值
|
||||
if roe is None:
|
||||
roe = 0
|
||||
if profit_margin is None:
|
||||
profit_margin = 0
|
||||
|
||||
if roe > 0:
|
||||
# ROE > 20%: 优秀,15-20%: 良好,10-15%: 一般,< 10%: 较差
|
||||
if roe > 20:
|
||||
scores['profitability'] += 15
|
||||
elif roe > 15:
|
||||
scores['profitability'] += 12
|
||||
elif roe > 10:
|
||||
scores['profitability'] += 8
|
||||
else:
|
||||
scores['profitability'] += 4
|
||||
|
||||
if profit_margin > 0:
|
||||
# 净利率 > 20%: 优秀,10-20%: 良好,5-10%: 一般
|
||||
if profit_margin > 20:
|
||||
scores['profitability'] += 10
|
||||
elif profit_margin > 10:
|
||||
scores['profitability'] += 7
|
||||
else:
|
||||
scores['profitability'] += 4
|
||||
|
||||
# 3. 成长性评分 (0-25分)
|
||||
growth = data.get('growth', {})
|
||||
revenue_growth = growth.get('revenue_growth')
|
||||
earnings_growth = growth.get('earnings_growth')
|
||||
|
||||
# 处理 None 值
|
||||
if revenue_growth is None:
|
||||
revenue_growth = 0
|
||||
if earnings_growth is None:
|
||||
earnings_growth = 0
|
||||
|
||||
if revenue_growth > 0:
|
||||
# 营收增长 > 30%: 优秀,20-30%: 良好,10-20%: 一般,< 10%: 较差
|
||||
if revenue_growth > 30:
|
||||
scores['growth'] += 12
|
||||
elif revenue_growth > 20:
|
||||
scores['growth'] += 10
|
||||
elif revenue_growth > 10:
|
||||
scores['growth'] += 6
|
||||
else:
|
||||
scores['growth'] += 3
|
||||
|
||||
if earnings_growth > 0:
|
||||
# 盈利增长 > 30%: 优秀,20-30%: 良好,10-20%: 一般
|
||||
if earnings_growth > 30:
|
||||
scores['growth'] += 13
|
||||
elif earnings_growth > 20:
|
||||
scores['growth'] += 10
|
||||
elif earnings_growth > 10:
|
||||
scores['growth'] += 6
|
||||
else:
|
||||
scores['growth'] += 3
|
||||
|
||||
# 4. 财务健康评分 (0-25分)
|
||||
financial = data.get('financial_health', {})
|
||||
debt_to_equity = financial.get('debt_to_equity')
|
||||
current_ratio = financial.get('current_ratio')
|
||||
|
||||
# 处理 None 值
|
||||
if debt_to_equity is None:
|
||||
debt_to_equity = 0
|
||||
if current_ratio is None:
|
||||
current_ratio = 0
|
||||
|
||||
# 债务股本比 < 1: 优秀,1-2: 良好,2-3: 一般,> 3: 风险高
|
||||
if debt_to_equity < 1:
|
||||
scores['financial_health'] += 12
|
||||
elif debt_to_equity < 2:
|
||||
scores['financial_health'] += 10
|
||||
elif debt_to_equity < 3:
|
||||
scores['financial_health'] += 5
|
||||
else:
|
||||
scores['financial_health'] += 2
|
||||
|
||||
# 流动比率 > 2: 优秀,1.5-2: 良好,1-1.5: 一般,< 1: 风险
|
||||
if current_ratio > 2:
|
||||
scores['financial_health'] += 13
|
||||
elif current_ratio > 1.5:
|
||||
scores['financial_health'] += 10
|
||||
elif current_ratio > 1:
|
||||
scores['financial_health'] += 5
|
||||
else:
|
||||
scores['financial_health'] += 0
|
||||
|
||||
# 现金流评分
|
||||
fc = financial.get('free_cashflow')
|
||||
if fc is not None and fc > 0:
|
||||
scores['financial_health'] += 0 # 已在盈利能力中考虑
|
||||
|
||||
# 计算总分
|
||||
scores['total'] = sum([scores['valuation'], scores['profitability'],
|
||||
scores['growth'], scores['financial_health']])
|
||||
|
||||
# 添加评级
|
||||
if scores['total'] >= 80:
|
||||
scores['rating'] = 'A'
|
||||
elif scores['total'] >= 60:
|
||||
scores['rating'] = 'B'
|
||||
elif scores['total'] >= 40:
|
||||
scores['rating'] = 'C'
|
||||
else:
|
||||
scores['rating'] = 'D'
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"计算基本面评分失败: {e}")
|
||||
|
||||
return scores
|
||||
|
||||
def get_fundamental_summary(self, symbol: str, data: Dict[str, Any] = None) -> str:
|
||||
"""
|
||||
生成基本面数据摘要文本,用于 LLM 分析
|
||||
|
||||
Args:
|
||||
symbol: 股票代码
|
||||
data: 可选,已获取的基本面数据。如果为None,则自动获取
|
||||
|
||||
Returns:
|
||||
基本面摘要文本
|
||||
"""
|
||||
if data is None:
|
||||
data = self.get_fundamental_data(symbol)
|
||||
if not data:
|
||||
return f"{symbol}: 暂无基本面数据"
|
||||
|
||||
summary_parts = []
|
||||
|
||||
# 基本信息
|
||||
summary_parts.append(f"【公司信息】{data.get('company_name', 'N/A')} | "
|
||||
f"行业: {data.get('sector', 'N/A')}")
|
||||
|
||||
# 估值情况
|
||||
val = data.get('valuation', {})
|
||||
if val.get('pe_ratio'):
|
||||
summary_parts.append(f"【估值】PE: {val['pe_ratio']:.2f} | "
|
||||
f"PB: {val.get('pb_ratio', 'N/A')} | "
|
||||
f"PS: {val.get('ps_ratio', 'N/A')}")
|
||||
|
||||
# 盈利能力
|
||||
prof = data.get('profitability', {})
|
||||
if prof.get('return_on_equity'):
|
||||
pm = prof.get('profit_margin')
|
||||
gm = prof.get('gross_margin')
|
||||
pm_str = f"{pm:.1f}" if pm is not None else "N/A"
|
||||
gm_str = f"{gm:.1f}" if gm is not None else "N/A"
|
||||
summary_parts.append(f"【盈利】ROE: {prof['return_on_equity']:.2f}% | "
|
||||
f"净利率: {pm_str}% | "
|
||||
f"毛利率: {gm_str}%")
|
||||
|
||||
# 成长性
|
||||
growth = data.get('growth', {})
|
||||
rg = growth.get('revenue_growth')
|
||||
eg = growth.get('earnings_growth')
|
||||
if rg is not None or eg is not None:
|
||||
rg_str = f"{rg:.1f}" if rg is not None else "N/A"
|
||||
eg_str = f"{eg:.1f}" if eg is not None else "N/A"
|
||||
summary_parts.append(f"【成长】营收增长: {rg_str}% | "
|
||||
f"盈利增长: {eg_str}%")
|
||||
|
||||
# 财务健康
|
||||
fin = data.get('financial_health', {})
|
||||
if fin.get('debt_to_equity'):
|
||||
cr = fin.get('current_ratio')
|
||||
cr_str = f"{cr:.2f}" if cr is not None else "N/A"
|
||||
summary_parts.append(f"【财务】债务股本比: {fin['debt_to_equity']:.2f} | "
|
||||
f"流动比率: {cr_str}")
|
||||
|
||||
# 分析师建议
|
||||
analyst = data.get('analyst', {})
|
||||
if analyst.get('target_price'):
|
||||
summary_parts.append(f"【分析师建议】目标价: ${analyst['target_price']:.2f} | "
|
||||
f"评级: {analyst.get('recommendation', 'N/A')}")
|
||||
|
||||
# 基本面评分
|
||||
score = data.get('score', {})
|
||||
summary_parts.append(f"【基本面评分】{score.get('total', 0):.0f}/100 ({score.get('rating', 'N/A')}级)")
|
||||
|
||||
return "\n".join(summary_parts)
|
||||
|
||||
def batch_get_fundamentals(self, symbols: List[str]) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
批量获取多只股票的基本面数据
|
||||
|
||||
Args:
|
||||
symbols: 股票代码列表
|
||||
|
||||
Returns:
|
||||
股票代码到基本面数据的映射
|
||||
"""
|
||||
results = {}
|
||||
for symbol in symbols:
|
||||
data = self.get_fundamental_data(symbol)
|
||||
if data:
|
||||
results[symbol] = data
|
||||
|
||||
logger.info(f"批量获取基本面数据完成: {len(results)}/{len(symbols)} 只股票")
|
||||
return results
|
||||
|
||||
def compare_stocks(self, symbols: List[str]) -> Dict[str, Any]:
|
||||
"""
|
||||
比较多只股票的基本面指标
|
||||
|
||||
Args:
|
||||
symbols: 股票代码列表
|
||||
|
||||
Returns:
|
||||
比较结果
|
||||
"""
|
||||
fundamentals = self.batch_get_fundamentals(symbols)
|
||||
|
||||
comparison = {
|
||||
'symbols': symbols,
|
||||
'metrics': {}
|
||||
}
|
||||
|
||||
# 提取可比较的指标
|
||||
metrics_to_compare = [
|
||||
('valuation', ['pe_ratio', 'pb_ratio']),
|
||||
('profitability', ['return_on_equity', 'profit_margin']),
|
||||
('growth', ['revenue_growth', 'earnings_growth']),
|
||||
('financial_health', ['debt_to_equity', 'current_ratio']),
|
||||
]
|
||||
|
||||
for category, metric_names in metrics_to_compare:
|
||||
comparison['metrics'][category] = {}
|
||||
for metric in metric_names:
|
||||
values = {}
|
||||
for symbol in symbols:
|
||||
if symbol in fundamentals:
|
||||
category_data = fundamentals[symbol].get(category, {})
|
||||
value = category_data.get(metric)
|
||||
if value is not None:
|
||||
values[symbol] = value
|
||||
|
||||
if values:
|
||||
comparison['metrics'][category][metric] = values
|
||||
|
||||
# 计算排名
|
||||
comparison['rankings'] = {}
|
||||
if 'valuation' in comparison['metrics']:
|
||||
pe_ratios = {s: v.get('valuation', {}).get('pe_ratio')
|
||||
for s, v in fundamentals.items() if v.get('valuation', {}).get('pe_ratio')}
|
||||
if pe_ratios:
|
||||
# PE 越低越好
|
||||
sorted_pe = sorted(pe_ratios.items(), key=lambda x: x[1])
|
||||
comparison['rankings']['pe_low_to_high'] = [s[0] for s in sorted_pe]
|
||||
|
||||
return comparison
|
||||
|
||||
|
||||
# 全局单例
|
||||
_fundamental_service: Optional[FundamentalService] = None
|
||||
|
||||
|
||||
def get_fundamental_service() -> FundamentalService:
|
||||
"""获取基本面数据服务单例"""
|
||||
global _fundamental_service
|
||||
if _fundamental_service is None:
|
||||
_fundamental_service = FundamentalService()
|
||||
return _fundamental_service
|
||||
@ -76,10 +76,6 @@ class LLMService:
|
||||
)
|
||||
)
|
||||
|
||||
def analyze_intent(self, user_message: str) -> Dict[str, Any]:
|
||||
"""使用LLM分析用户意图"""
|
||||
return self.multi_service.analyze_intent(user_message)
|
||||
|
||||
def chat_stream(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
|
||||
@ -75,7 +75,7 @@ class MultiLLMService:
|
||||
logger.error(f"DeepSeek初始化失败: {e}")
|
||||
|
||||
# 设置默认模型(优先使用配置文件中的设置)
|
||||
preferred_model = getattr(settings, 'smart_agent_model', None)
|
||||
preferred_model = getattr(settings, 'crypto_agent_model', None)
|
||||
if preferred_model and preferred_model in self.clients:
|
||||
self.current_model = preferred_model
|
||||
logger.info(f"使用配置的模型: {preferred_model}")
|
||||
@ -426,41 +426,5 @@ class MultiLLMService:
|
||||
|
||||
return
|
||||
|
||||
def analyze_intent(self, user_message: str) -> Dict[str, Any]:
|
||||
"""使用LLM分析用户意图"""
|
||||
if not self.current_model:
|
||||
return {"type": "unknown", "confidence": 0}
|
||||
|
||||
prompt = f"""你是一个股票分析助手的意图识别模块。请分析用户的查询意图。
|
||||
|
||||
用户消息:{user_message}
|
||||
|
||||
请识别以下意图类型之一:
|
||||
1. market_data - 查询实时行情、价格
|
||||
2. technical_analysis - 技术分析、技术指标
|
||||
3. fundamental - 基本面信息、公司信息
|
||||
4. visualization - K线图、图表
|
||||
5. unknown - 无法识别
|
||||
|
||||
请以JSON格式返回:
|
||||
{{
|
||||
"type": "意图类型",
|
||||
"confidence": 0.0-1.0,
|
||||
"stock_name": "提取的股票名称(如果有)"
|
||||
}}
|
||||
"""
|
||||
|
||||
try:
|
||||
response = self.chat([{"role": "user", "content": prompt}], temperature=0.3)
|
||||
if response:
|
||||
import json
|
||||
result = json.loads(response)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"意图分析失败: {e}")
|
||||
|
||||
return {"type": "unknown", "confidence": 0}
|
||||
|
||||
|
||||
# 创建全局实例
|
||||
multi_llm_service = MultiLLMService()
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
"""
|
||||
新闻舆情服务 - 获取加密货币和股票相关新闻
|
||||
新闻舆情服务 - 获取加密货币相关新闻
|
||||
"""
|
||||
import re
|
||||
import html
|
||||
@ -29,7 +29,7 @@ class NewsService:
|
||||
|
||||
def __init__(self):
|
||||
"""初始化新闻服务"""
|
||||
self._cache: Dict[str, List[Dict[str, Any]]] = {'crypto': [], 'stock': {}}
|
||||
self._cache: List[Dict[str, Any]] = []
|
||||
self._cache_time: Optional[datetime] = None
|
||||
self._cache_duration = timedelta(minutes=5) # 缓存5分钟
|
||||
self.settings = get_settings()
|
||||
@ -53,7 +53,7 @@ class NewsService:
|
||||
# 检查缓存
|
||||
if self._cache and self._cache_time:
|
||||
if datetime.now() - self._cache_time < self._cache_duration:
|
||||
return self._cache['crypto'][:limit] if isinstance(self._cache, dict) else self._cache[:limit]
|
||||
return self._cache[:limit]
|
||||
|
||||
try:
|
||||
# 并发获取所有源的新闻
|
||||
@ -77,7 +77,7 @@ class NewsService:
|
||||
all_news.sort(key=lambda x: x.get('time') or datetime.min, reverse=True)
|
||||
|
||||
# 更新缓存
|
||||
self._cache = {'crypto': all_news, 'stock': self._cache.get('stock', {}) if isinstance(self._cache, dict) else {}}
|
||||
self._cache = all_news
|
||||
self._cache_time = datetime.now()
|
||||
|
||||
logger.info(f"获取到 {len(all_news)} 条加密货币新闻(律动+Cointelegraph+CoinDesk)")
|
||||
@ -86,8 +86,6 @@ class NewsService:
|
||||
except Exception as e:
|
||||
logger.error(f"获取新闻失败: {e}")
|
||||
# 返回缓存
|
||||
if isinstance(self._cache, dict):
|
||||
return self._cache.get('crypto', [])[:limit]
|
||||
return self._cache[:limit] if self._cache else []
|
||||
|
||||
async def _fetch_blockbeats_news(self) -> List[Dict[str, Any]]:
|
||||
@ -407,114 +405,6 @@ class NewsService:
|
||||
|
||||
return filtered
|
||||
|
||||
async def search_stock_news(self, symbol: str, stock_name: str = '',
|
||||
max_results: int = 10) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
使用 Brave Search API 搜索股票相关新闻
|
||||
|
||||
Args:
|
||||
symbol: 股票代码(如 AAPL, 0700.HK)
|
||||
stock_name: 股票中文名称(可选)
|
||||
max_results: 最大结果数
|
||||
|
||||
Returns:
|
||||
新闻列表
|
||||
"""
|
||||
api_key = self.settings.brave_api_key
|
||||
if not api_key:
|
||||
logger.warning("未配置 Brave API Key,跳过新闻搜索")
|
||||
return []
|
||||
|
||||
# 检查缓存
|
||||
cache_key = f"{symbol}_{stock_name}"
|
||||
if self._cache_time and cache_key in self._cache.get('stock', {}):
|
||||
if datetime.now() - self._cache_time < self._cache_duration:
|
||||
return self._cache['stock'][cache_key][:max_results]
|
||||
|
||||
# 构建搜索查询
|
||||
# 根据股票类型构建不同的搜索词
|
||||
if symbol.endswith('.HK'):
|
||||
# 港股
|
||||
if stock_name:
|
||||
query = f"{stock_name} 港股 新闻 最新"
|
||||
else:
|
||||
query = f"{symbol.replace('.HK', '')} 港股 新闻 最新"
|
||||
else:
|
||||
# 美股
|
||||
if stock_name:
|
||||
query = f"{stock_name} 股票 {symbol} news latest"
|
||||
else:
|
||||
query = f"{symbol} stock news latest"
|
||||
|
||||
try:
|
||||
headers = {
|
||||
'Accept': 'application/json',
|
||||
'Accept-Encoding': 'gzip',
|
||||
'X-Subscription-Token': api_key
|
||||
}
|
||||
|
||||
params = {
|
||||
'q': query,
|
||||
'count': max_results,
|
||||
'text_decorations': 'false', # 改为字符串
|
||||
'search_lang': 'zh-hans', # Brave Search 使用 zh-hans 而非 zh-CN
|
||||
# 'result_filter': 'news', # 免费计划不支持,移除此参数
|
||||
'freshness': 'pd' # 过去24小时
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
self.BRAVE_SEARCH_API,
|
||||
headers=headers,
|
||||
params=params,
|
||||
timeout=10
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
logger.error(f"Brave Search API 请求失败: HTTP {response.status}")
|
||||
return []
|
||||
|
||||
data = await response.json()
|
||||
|
||||
# 解析搜索结果
|
||||
news_list = []
|
||||
web_results = data.get('web', {}).get('results', [])
|
||||
|
||||
for item in web_results:
|
||||
title = item.get('title', '')
|
||||
url = item.get('url', '')
|
||||
description = item.get('description', '')
|
||||
|
||||
# 清理描述
|
||||
description = self._clean_html(description)
|
||||
|
||||
news_list.append({
|
||||
'title': title,
|
||||
'description': description[:500],
|
||||
'time': datetime.now(), # Brave Search 不返回精确时间
|
||||
'time_str': datetime.now().strftime('%m-%d %H:%M'),
|
||||
'link': url,
|
||||
'source': 'Brave Search'
|
||||
})
|
||||
|
||||
logger.info(f"Brave Search 搜索 {symbol} 获取到 {len(news_list)} 条新闻")
|
||||
|
||||
# 更新缓存
|
||||
if 'stock' not in self._cache:
|
||||
self._cache['stock'] = {}
|
||||
self._cache['stock'][cache_key] = news_list
|
||||
self._cache_time = datetime.now()
|
||||
|
||||
return news_list[:max_results]
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(f"Brave Search API 请求失败: {e}")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"搜索股票新闻失败: {e}")
|
||||
import traceback
|
||||
logger.debug(traceback.format_exc())
|
||||
return []
|
||||
|
||||
async def get_crypto_news(self, symbol: str, limit: int = 10) -> Dict[str, Any]:
|
||||
"""
|
||||
获取加密货币相关新闻
|
||||
|
||||
@ -249,14 +249,12 @@ class PaperTradingService:
|
||||
return result
|
||||
|
||||
# === 动态仓位计算 ===
|
||||
# 优先使用信号中的 quantity(LLM 决策的保证金金额)
|
||||
quantity_from_signal = signal.get('quantity')
|
||||
if quantity_from_signal is not None and quantity_from_signal > 0:
|
||||
# LLM 决策的 quantity 是保证金金额
|
||||
margin = float(quantity_from_signal)
|
||||
# 计算持仓价值(保证金 × 杠杆),保留2位小数
|
||||
# 优先使用信号中的 margin;旧路径 quantity 仍兼容为“保证金金额”
|
||||
margin_from_signal = signal.get('margin', signal.get('quantity'))
|
||||
if margin_from_signal is not None and margin_from_signal > 0:
|
||||
margin = float(margin_from_signal)
|
||||
position_value = round(margin * self.leverage, 2)
|
||||
logger.debug(f"使用 LLM 决策保证金: ${margin:.2f}, 持仓价值: ${position_value:.2f}")
|
||||
logger.debug(f"使用 LLM 决策保证金: ${margin:.2f}, 名义仓位: ${position_value:.2f}")
|
||||
else:
|
||||
# 回退到动态仓位计算
|
||||
position_size = signal.get('position_size', 'light')
|
||||
@ -276,15 +274,15 @@ class PaperTradingService:
|
||||
return result
|
||||
|
||||
# === 检查总杠杆是否超限 ===
|
||||
# 计算当前实际的已用保证金和持仓价值(使用订单实际值)
|
||||
# 计算当前实际的已用保证金和名义仓位(使用订单实际值)
|
||||
current_used_margin = sum(order.margin for order in self.active_orders.values())
|
||||
current_total_position_value = sum(order.quantity for order in self.active_orders.values())
|
||||
|
||||
# 新订单增加的保证金和持仓价值
|
||||
# 新订单增加的保证金和名义仓位
|
||||
new_margin = margin
|
||||
new_position_value = position_value
|
||||
|
||||
# 新增后的总持仓价值和总杠杆
|
||||
# 新增后的总名义仓位和总杠杆
|
||||
new_total_position_value = current_total_position_value + new_position_value
|
||||
|
||||
# 获取当前账户余额
|
||||
@ -310,14 +308,14 @@ class PaperTradingService:
|
||||
|
||||
if new_total_leverage > self.max_total_leverage:
|
||||
msg = f"总杠杆超限!当前 {new_total_leverage:.1f}x,上限 {self.max_total_leverage}x"
|
||||
logger.info(f"{msg}: {symbol} | 当前持仓价值: ${current_total_position_value:,.0f} | "
|
||||
f"新订单持仓价值: ${new_position_value:,.0f} | 总持仓价值: ${new_total_position_value:,.0f} | "
|
||||
logger.info(f"{msg}: {symbol} | 当前名义仓位: ${current_total_position_value:,.0f} | "
|
||||
f"新订单名义仓位: ${new_position_value:,.0f} | 总名义仓位: ${new_total_position_value:,.0f} | "
|
||||
f"账户余额: ${current_balance:,.0f}")
|
||||
result['message'] = msg
|
||||
return result
|
||||
|
||||
logger.debug(f"总杠杆检查通过: {new_total_leverage:.1f}x / {self.max_total_leverage}x")
|
||||
quantity = round(position_value, 2) # 确保持仓价值保留2位小数
|
||||
quantity = round(position_value, 2) # 兼容旧字段:名义仓位
|
||||
|
||||
# 确定入场类型
|
||||
entry_type_str = signal.get('entry_type', 'market')
|
||||
@ -348,7 +346,7 @@ class PaperTradingService:
|
||||
stop_loss=signal.get('stop_loss', 0),
|
||||
take_profit=signal.get('take_profit', 0),
|
||||
filled_price=filled_price,
|
||||
quantity=quantity, # 持仓价值
|
||||
quantity=quantity, # 兼容旧字段:名义仓位
|
||||
margin=margin, # 保证金
|
||||
leverage=self.leverage, # 杠杆倍数
|
||||
signal_grade=SignalGrade(grade),
|
||||
@ -381,7 +379,7 @@ class PaperTradingService:
|
||||
entry_type_text = "现价" if entry_type == EntryType.MARKET else "挂单"
|
||||
status_text = "已开仓" if status == OrderStatus.OPEN else "等待触发"
|
||||
logger.info(f"✅ 创建订单成功: {order_id} | {symbol} {side.value} [{entry_type_text}] @ ${entry_price:,.2f} | {status_text}")
|
||||
logger.info(f" 保证金: ${margin:,.0f} | 杠杆: {self.leverage}x | 持仓价值: ${position_value:,.0f} | 当前订单数: {len(self.active_orders)}/{self.max_orders}")
|
||||
logger.info(f" 保证金: ${margin:,.0f} | 杠杆: {self.leverage}x | 名义仓位: ${position_value:,.0f} | 当前订单数: {len(self.active_orders)}/{self.max_orders}")
|
||||
result['order'] = order
|
||||
return result
|
||||
|
||||
@ -447,7 +445,7 @@ class PaperTradingService:
|
||||
|
||||
logger.info(
|
||||
f"动态仓位计算: {symbol} | {sizing_reason} | "
|
||||
f"保证金 ${margin:.2f} | 持仓价值 ${position_value:.2f} | {budget_reason}"
|
||||
f"保证金 ${margin:.2f} | 名义仓位 ${position_value:.2f} | {budget_reason}"
|
||||
)
|
||||
return margin, position_value
|
||||
|
||||
@ -475,6 +473,7 @@ class PaperTradingService:
|
||||
'status': order.get('status'),
|
||||
'entry_price': order.get('filled_price') or order.get('entry_price'),
|
||||
'quantity': order.get('quantity'),
|
||||
'notional': order.get('notional', order.get('quantity')),
|
||||
'pnl_percent': order.get('pnl_percent', 0)
|
||||
})
|
||||
|
||||
@ -598,6 +597,7 @@ class PaperTradingService:
|
||||
'entry_price': order.entry_price, # 挂单价
|
||||
'filled_price': filled_price,
|
||||
'quantity': order.quantity,
|
||||
'notional': order.quantity,
|
||||
'signal_grade': order.signal_grade.value if order.signal_grade else None,
|
||||
'stop_loss': order.stop_loss,
|
||||
'take_profit': order.take_profit
|
||||
@ -732,6 +732,7 @@ class PaperTradingService:
|
||||
'entry_price': db_order.filled_price,
|
||||
'exit_price': exit_price,
|
||||
'quantity': db_order.quantity,
|
||||
'notional': db_order.quantity,
|
||||
'pnl_amount': db_order.pnl_amount,
|
||||
'pnl_percent': db_order.pnl_percent,
|
||||
'is_win': pnl_amount > 0,
|
||||
@ -1614,6 +1615,7 @@ class PaperTradingService:
|
||||
'entry_price': order.get('filled_price') or order.get('entry_price') or 0,
|
||||
'filled_price': order.get('filled_price'),
|
||||
'quantity': order.get('quantity', 0),
|
||||
'notional': order.get('notional', order.get('quantity', 0)),
|
||||
'margin': order.get('margin', 0),
|
||||
'stop_loss': order.get('stop_loss'),
|
||||
'take_profit': order.get('take_profit'),
|
||||
@ -1813,7 +1815,7 @@ class PaperTradingService:
|
||||
|
||||
for order in self.active_orders.values():
|
||||
used_margin += order.margin # 订单实际保证金
|
||||
total_position_value += order.quantity # 订单实际持仓价值
|
||||
total_position_value += order.quantity # 订单实际名义仓位
|
||||
|
||||
# 计算已实现盈亏(从历史订单)
|
||||
db = db_service.get_session()
|
||||
@ -1839,7 +1841,7 @@ class PaperTradingService:
|
||||
# 计算当前余额
|
||||
current_balance = self.initial_balance + realized_pnl
|
||||
|
||||
# 计算当前总杠杆(持仓价值 / 账户余额)
|
||||
# 计算当前总杠杆(名义仓位 / 账户余额)
|
||||
current_total_leverage = total_position_value / current_balance if current_balance > 0 else 0
|
||||
|
||||
# 计算可用保证金
|
||||
@ -2128,6 +2130,8 @@ class PaperTradingService:
|
||||
'close_percent': close_percent,
|
||||
'close_quantity': close_quantity,
|
||||
'remaining_quantity': remaining_quantity,
|
||||
'close_notional': close_quantity,
|
||||
'remaining_notional': remaining_quantity,
|
||||
'pnl': pnl,
|
||||
'order': order.to_dict()
|
||||
}
|
||||
|
||||
@ -120,66 +120,6 @@ class SignalDatabaseService:
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def get_stock_signals(
|
||||
self,
|
||||
limit: int = 50,
|
||||
symbol: Optional[str] = None,
|
||||
days: int = 7
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取美股信号"""
|
||||
db = self.db_service.get_session()
|
||||
try:
|
||||
cutoff_time = datetime.utcnow() - timedelta(days=days)
|
||||
|
||||
query = db.query(TradingSignal).filter(
|
||||
TradingSignal.signal_type == 'stock',
|
||||
TradingSignal.created_at >= cutoff_time
|
||||
)
|
||||
|
||||
if symbol:
|
||||
query = query.filter(TradingSignal.symbol == symbol.upper())
|
||||
|
||||
signals = query.order_by(desc(TradingSignal.created_at)).limit(limit).all()
|
||||
|
||||
return [signal.to_dict() for signal in signals]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取美股信号失败: {e}")
|
||||
return []
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def get_all_signals(self, limit: int = 100, days: int = 7) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""获取所有信号"""
|
||||
db = self.db_service.get_session()
|
||||
try:
|
||||
cutoff_time = datetime.utcnow() - timedelta(days=days)
|
||||
|
||||
signals = db.query(TradingSignal).filter(
|
||||
TradingSignal.created_at >= cutoff_time
|
||||
).order_by(desc(TradingSignal.created_at)).limit(limit).all()
|
||||
|
||||
crypto_signals = []
|
||||
stock_signals = []
|
||||
|
||||
for signal in signals:
|
||||
signal_dict = signal.to_dict()
|
||||
if signal.signal_type == 'crypto':
|
||||
crypto_signals.append(signal_dict)
|
||||
else:
|
||||
stock_signals.append(signal_dict)
|
||||
|
||||
return {
|
||||
'crypto': crypto_signals,
|
||||
'stock': stock_signals
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取所有信号失败: {e}")
|
||||
return {'crypto': [], 'stock': []}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def get_latest_signals(self, limit: int = 20, days: int = 7) -> List[Dict[str, Any]]:
|
||||
"""获取最新信号(混合)"""
|
||||
db = self.db_service.get_session()
|
||||
@ -214,11 +154,6 @@ class SignalDatabaseService:
|
||||
crypto_buy = sum(1 for s in crypto_signals if s.action == 'buy')
|
||||
crypto_sell = sum(1 for s in crypto_signals if s.action == 'sell')
|
||||
|
||||
# 统计美股信号
|
||||
stock_signals = [s for s in all_signals if s.signal_type == 'stock']
|
||||
stock_buy = sum(1 for s in stock_signals if s.action == 'buy')
|
||||
stock_sell = sum(1 for s in stock_signals if s.action == 'sell')
|
||||
|
||||
# 按等级统计
|
||||
grade_stats = {}
|
||||
for signal in all_signals:
|
||||
@ -227,7 +162,6 @@ class SignalDatabaseService:
|
||||
# 最近24小时信号
|
||||
recent_cutoff = datetime.utcnow() - timedelta(hours=24)
|
||||
recent_crypto = sum(1 for s in crypto_signals if s.created_at >= recent_cutoff)
|
||||
recent_stock = sum(1 for s in stock_signals if s.created_at >= recent_cutoff)
|
||||
|
||||
return {
|
||||
'crypto': {
|
||||
@ -236,14 +170,8 @@ class SignalDatabaseService:
|
||||
'sell': crypto_sell,
|
||||
'recent_24h': recent_crypto
|
||||
},
|
||||
'stock': {
|
||||
'total': len(stock_signals),
|
||||
'buy': stock_buy,
|
||||
'sell': stock_sell,
|
||||
'recent_24h': recent_stock
|
||||
},
|
||||
'grades': grade_stats,
|
||||
'total': len(all_signals)
|
||||
'total': len(crypto_signals)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@ -301,4 +229,3 @@ def get_signal_db_service() -> SignalDatabaseService:
|
||||
if _signal_db_service is None:
|
||||
_signal_db_service = SignalDatabaseService()
|
||||
return _signal_db_service
|
||||
|
||||
|
||||
@ -1,147 +0,0 @@
|
||||
"""
|
||||
信号存储服务 - 保存加密货币和美股的交易信号
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Any
|
||||
from app.utils.logger import logger
|
||||
|
||||
|
||||
class SignalStorageService:
|
||||
"""信号存储服务"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化服务"""
|
||||
self.storage_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'data', 'signals')
|
||||
os.makedirs(self.storage_dir, exist_ok=True)
|
||||
|
||||
# 信号文件
|
||||
self.crypto_file = os.path.join(self.storage_dir, 'crypto_signals.json')
|
||||
self.stock_file = os.path.join(self.storage_dir, 'stock_signals.json')
|
||||
|
||||
# 加载现有信号
|
||||
self._crypto_signals = self._load_signals(self.crypto_file)
|
||||
self._stock_signals = self._load_signals(self.stock_file)
|
||||
|
||||
logger.info(f"信号存储服务初始化完成,加密货币信号: {len(self._crypto_signals)},美股信号: {len(self._stock_signals)}")
|
||||
|
||||
def _load_signals(self, file_path: str) -> List[Dict[str, Any]]:
|
||||
"""从文件加载信号"""
|
||||
if not os.path.exists(file_path):
|
||||
return []
|
||||
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"加载信号失败 {file_path}: {e}")
|
||||
return []
|
||||
|
||||
def _save_signals(self, file_path: str, signals: List[Dict[str, Any]]):
|
||||
"""保存信号到文件"""
|
||||
try:
|
||||
with open(file_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(signals, f, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
logger.error(f"保存信号失败 {file_path}: {e}")
|
||||
|
||||
def add_crypto_signal(self, signal: Dict[str, Any]):
|
||||
"""添加加密货币信号"""
|
||||
# 添加时间戳和类型
|
||||
signal['timestamp'] = datetime.now().isoformat()
|
||||
signal['signal_type'] = 'crypto'
|
||||
|
||||
# 保存到内存
|
||||
self._crypto_signals.insert(0, signal)
|
||||
|
||||
# 只保留最近 100 条
|
||||
if len(self._crypto_signals) > 100:
|
||||
self._crypto_signals = self._crypto_signals[:100]
|
||||
|
||||
# 持久化
|
||||
self._save_signals(self.crypto_file, self._crypto_signals)
|
||||
|
||||
logger.info(f"添加加密货币信号: {signal.get('symbol', 'N/A')} - {signal.get('action', 'N/A')}")
|
||||
|
||||
def add_stock_signal(self, signal: Dict[str, Any]):
|
||||
"""添加美股信号"""
|
||||
# 添加时间戳和类型
|
||||
signal['timestamp'] = datetime.now().isoformat()
|
||||
signal['signal_type'] = 'stock'
|
||||
|
||||
# 保存到内存
|
||||
self._stock_signals.insert(0, signal)
|
||||
|
||||
# 只保留最近 100 条
|
||||
if len(self._stock_signals) > 100:
|
||||
self._stock_signals = self._stock_signals[:100]
|
||||
|
||||
# 持久化
|
||||
self._save_signals(self.stock_file, self._stock_signals)
|
||||
|
||||
logger.info(f"添加美股信号: {signal.get('symbol', 'N/A')} - {signal.get('action', 'N/A')}")
|
||||
|
||||
def get_crypto_signals(self, limit: int = 50) -> List[Dict[str, Any]]:
|
||||
"""获取加密货币信号列表"""
|
||||
return self._crypto_signals[:limit]
|
||||
|
||||
def get_stock_signals(self, limit: int = 50) -> List[Dict[str, Any]]:
|
||||
"""获取美股信号列表"""
|
||||
return self._stock_signals[:limit]
|
||||
|
||||
def get_all_signals(self, limit: int = 100) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""获取所有信号"""
|
||||
return {
|
||||
'crypto': self._crypto_signals[:limit],
|
||||
'stock': self._stock_signals[:limit]
|
||||
}
|
||||
|
||||
def get_latest_signal(self, signal_type: str, symbol: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取指定交易对的最新信号"""
|
||||
if signal_type == 'crypto':
|
||||
signals = self._crypto_signals
|
||||
elif signal_type == 'stock':
|
||||
signals = self._stock_signals
|
||||
else:
|
||||
return None
|
||||
|
||||
for signal in signals:
|
||||
if signal.get('symbol') == symbol:
|
||||
return signal
|
||||
|
||||
return None
|
||||
|
||||
def clear_old_signals(self, days: int = 7):
|
||||
"""清理旧信号"""
|
||||
from datetime import timedelta
|
||||
|
||||
cutoff_time = (datetime.now() - timedelta(days=days)).isoformat()
|
||||
|
||||
# 清理加密货币信号
|
||||
self._crypto_signals = [
|
||||
s for s in self._crypto_signals
|
||||
if s.get('timestamp', '') >= cutoff_time
|
||||
]
|
||||
self._save_signals(self.crypto_file, self._crypto_signals)
|
||||
|
||||
# 清理美股信号
|
||||
self._stock_signals = [
|
||||
s for s in self._stock_signals
|
||||
if s.get('timestamp', '') >= cutoff_time
|
||||
]
|
||||
self._save_signals(self.stock_file, self._stock_signals)
|
||||
|
||||
logger.info(f"清理旧信号完成,保留 {days} 天内的信号")
|
||||
|
||||
|
||||
# 全局单例
|
||||
_signal_storage: Optional[SignalStorageService] = None
|
||||
|
||||
|
||||
def get_signal_storage() -> SignalStorageService:
|
||||
"""获取信号存储服务单例"""
|
||||
global _signal_storage
|
||||
if _signal_storage is None:
|
||||
_signal_storage = SignalStorageService()
|
||||
return _signal_storage
|
||||
@ -17,7 +17,7 @@ class TelegramService:
|
||||
Args:
|
||||
bot_token: Telegram Bot Token (从 @BotFather 获取)
|
||||
channel_id: 频道 ID (如 @your_channel 或 -1001234567890)
|
||||
service_type: 服务类型 (crypto/stock/news/default)
|
||||
service_type: 服务类型 (crypto/default)
|
||||
"""
|
||||
settings = get_settings()
|
||||
self.bot_token = bot_token or getattr(settings, 'telegram_bot_token', '')
|
||||
@ -30,8 +30,6 @@ class TelegramService:
|
||||
# 根据service_type选择对应的频道
|
||||
if service_type == "crypto":
|
||||
self.channel_id = getattr(settings, 'telegram_crypto_channel_id', '') or getattr(settings, 'telegram_channel_id', '')
|
||||
elif service_type == "stock":
|
||||
self.channel_id = getattr(settings, 'telegram_stock_channel_id', '') or getattr(settings, 'telegram_channel_id', '')
|
||||
else:
|
||||
self.channel_id = getattr(settings, 'telegram_channel_id', '')
|
||||
|
||||
@ -260,7 +258,6 @@ class TelegramService:
|
||||
# 全局实例(延迟初始化)
|
||||
_telegram_service: Optional[TelegramService] = None
|
||||
_telegram_crypto_service: Optional[TelegramService] = None
|
||||
_telegram_stock_service: Optional[TelegramService] = None
|
||||
|
||||
|
||||
def get_telegram_service() -> TelegramService:
|
||||
@ -277,11 +274,3 @@ def get_telegram_crypto_service() -> TelegramService:
|
||||
if _telegram_crypto_service is None:
|
||||
_telegram_crypto_service = TelegramService(service_type="crypto")
|
||||
return _telegram_crypto_service
|
||||
|
||||
|
||||
def get_telegram_stock_service() -> TelegramService:
|
||||
"""获取股票 Telegram 服务实例"""
|
||||
global _telegram_stock_service
|
||||
if _telegram_stock_service is None:
|
||||
_telegram_stock_service = TelegramService(service_type="stock")
|
||||
return _telegram_stock_service
|
||||
|
||||
@ -1,556 +0,0 @@
|
||||
"""
|
||||
Tushare高级数据服务
|
||||
充分利用5000+积分,获取财务数据、资金流向、新闻公告等
|
||||
"""
|
||||
import tushare as ts
|
||||
import pandas as pd
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime, timedelta
|
||||
from app.config import get_settings
|
||||
from app.utils.logger import logger
|
||||
from app.utils.validators import normalize_stock_code
|
||||
|
||||
|
||||
class TushareAdvancedService:
|
||||
"""Tushare高级数据服务类(需要5000+积分)"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化Tushare服务"""
|
||||
settings = get_settings()
|
||||
if not settings.tushare_token:
|
||||
logger.warning("Tushare token未配置")
|
||||
self.pro = None
|
||||
else:
|
||||
ts.set_token(settings.tushare_token)
|
||||
self.pro = ts.pro_api()
|
||||
logger.info("Tushare高级服务初始化成功")
|
||||
|
||||
# ==================== 财务数据 ====================
|
||||
|
||||
def get_income_statement(
|
||||
self,
|
||||
stock_code: str,
|
||||
period: str = None,
|
||||
start_date: str = None,
|
||||
end_date: str = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取利润表数据
|
||||
|
||||
Args:
|
||||
stock_code: 股票代码
|
||||
period: 报告期(YYYYMMDD),如20231231
|
||||
start_date: 开始日期
|
||||
end_date: 结束日期
|
||||
|
||||
Returns:
|
||||
利润表数据
|
||||
"""
|
||||
if not self.pro:
|
||||
logger.error("Tushare服务未初始化")
|
||||
return None
|
||||
|
||||
try:
|
||||
ts_code = normalize_stock_code(stock_code)
|
||||
if not ts_code:
|
||||
return None
|
||||
|
||||
# 默认获取最近4个季度的数据
|
||||
if not period and not start_date:
|
||||
end_date = datetime.now().strftime('%Y%m%d')
|
||||
start_date = (datetime.now() - timedelta(days=400)).strftime('%Y%m%d')
|
||||
|
||||
df = self.pro.income(
|
||||
ts_code=ts_code,
|
||||
period=period,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
fields='ts_code,ann_date,f_ann_date,end_date,report_type,comp_type,'
|
||||
'total_revenue,revenue,operating_profit,total_profit,n_income,'
|
||||
'n_income_attr_p,basic_eps,diluted_eps'
|
||||
)
|
||||
|
||||
if df.empty:
|
||||
logger.warning(f"未找到利润表数据: {ts_code}")
|
||||
return None
|
||||
|
||||
# 转换为字典列表,按日期降序
|
||||
df = df.sort_values('end_date', ascending=False)
|
||||
return {
|
||||
'ts_code': ts_code,
|
||||
'data': df.to_dict('records')
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取利润表失败: {e}")
|
||||
return None
|
||||
|
||||
def get_balance_sheet(
|
||||
self,
|
||||
stock_code: str,
|
||||
period: str = None,
|
||||
start_date: str = None,
|
||||
end_date: str = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取资产负债表数据
|
||||
|
||||
Args:
|
||||
stock_code: 股票代码
|
||||
period: 报告期
|
||||
start_date: 开始日期
|
||||
end_date: 结束日期
|
||||
|
||||
Returns:
|
||||
资产负债表数据
|
||||
"""
|
||||
if not self.pro:
|
||||
return None
|
||||
|
||||
try:
|
||||
ts_code = normalize_stock_code(stock_code)
|
||||
if not ts_code:
|
||||
return None
|
||||
|
||||
if not period and not start_date:
|
||||
end_date = datetime.now().strftime('%Y%m%d')
|
||||
start_date = (datetime.now() - timedelta(days=400)).strftime('%Y%m%d')
|
||||
|
||||
df = self.pro.balancesheet(
|
||||
ts_code=ts_code,
|
||||
period=period,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
fields='ts_code,ann_date,f_ann_date,end_date,report_type,'
|
||||
'total_assets,total_liab,total_hldr_eqy_exc_min_int,'
|
||||
'total_cur_assets,total_cur_liab,money_cap'
|
||||
)
|
||||
|
||||
if df.empty:
|
||||
return None
|
||||
|
||||
df = df.sort_values('end_date', ascending=False)
|
||||
return {
|
||||
'ts_code': ts_code,
|
||||
'data': df.to_dict('records')
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取资产负债表失败: {e}")
|
||||
return None
|
||||
|
||||
def get_financial_indicators(
|
||||
self,
|
||||
stock_code: str,
|
||||
period: str = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取财务指标数据(ROE、ROA、毛利率等)
|
||||
|
||||
Args:
|
||||
stock_code: 股票代码
|
||||
period: 报告期
|
||||
|
||||
Returns:
|
||||
财务指标数据
|
||||
"""
|
||||
if not self.pro:
|
||||
return None
|
||||
|
||||
try:
|
||||
ts_code = normalize_stock_code(stock_code)
|
||||
if not ts_code:
|
||||
return None
|
||||
|
||||
# 获取最近一期的财务指标
|
||||
df = self.pro.fina_indicator(
|
||||
ts_code=ts_code,
|
||||
period=period,
|
||||
fields='ts_code,end_date,eps,dt_eps,total_revenue_ps,revenue_ps,'
|
||||
'capital_rese_ps,undist_profit_ps,extra_item,profit_dedt,'
|
||||
'gross_margin,current_ratio,quick_ratio,roe,roe_waa,'
|
||||
'roe_dt,roa,npta,roic,debt_to_assets,assets_to_eqt'
|
||||
)
|
||||
|
||||
if df.empty:
|
||||
return None
|
||||
|
||||
# 取最新一期
|
||||
latest = df.sort_values('end_date', ascending=False).iloc[0]
|
||||
return {
|
||||
'ts_code': ts_code,
|
||||
'end_date': latest['end_date'],
|
||||
'indicators': latest.to_dict()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取财务指标失败: {e}")
|
||||
return None
|
||||
|
||||
# ==================== 估值数据 ====================
|
||||
|
||||
def get_daily_basic(
|
||||
self,
|
||||
stock_code: str,
|
||||
trade_date: str = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取每日指标(PE、PB、PS、市值、换手率等)
|
||||
|
||||
Args:
|
||||
stock_code: 股票代码
|
||||
trade_date: 交易日期(YYYYMMDD)
|
||||
|
||||
Returns:
|
||||
每日指标数据
|
||||
"""
|
||||
if not self.pro:
|
||||
return None
|
||||
|
||||
try:
|
||||
ts_code = normalize_stock_code(stock_code)
|
||||
if not ts_code:
|
||||
return None
|
||||
|
||||
if not trade_date:
|
||||
trade_date = datetime.now().strftime('%Y%m%d')
|
||||
|
||||
df = self.pro.daily_basic(
|
||||
ts_code=ts_code,
|
||||
trade_date=trade_date,
|
||||
fields='ts_code,trade_date,close,turnover_rate,turnover_rate_f,'
|
||||
'volume_ratio,pe,pe_ttm,pb,ps,ps_ttm,'
|
||||
'dv_ratio,dv_ttm,total_share,float_share,free_share,'
|
||||
'total_mv,circ_mv'
|
||||
)
|
||||
|
||||
if df.empty:
|
||||
return None
|
||||
|
||||
return {
|
||||
'ts_code': ts_code,
|
||||
'data': df.iloc[0].to_dict()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取每日指标失败: {e}")
|
||||
return None
|
||||
|
||||
# ==================== 资金流向 ====================
|
||||
|
||||
def get_money_flow(
|
||||
self,
|
||||
stock_code: str,
|
||||
start_date: str = None,
|
||||
end_date: str = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取资金流向数据
|
||||
|
||||
Args:
|
||||
stock_code: 股票代码
|
||||
start_date: 开始日期
|
||||
end_date: 结束日期
|
||||
|
||||
Returns:
|
||||
资金流向数据
|
||||
"""
|
||||
if not self.pro:
|
||||
return None
|
||||
|
||||
try:
|
||||
ts_code = normalize_stock_code(stock_code)
|
||||
if not ts_code:
|
||||
return None
|
||||
|
||||
if not start_date:
|
||||
start_date = (datetime.now() - timedelta(days=30)).strftime('%Y%m%d')
|
||||
if not end_date:
|
||||
end_date = datetime.now().strftime('%Y%m%d')
|
||||
|
||||
df = self.pro.moneyflow(
|
||||
ts_code=ts_code,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
fields='ts_code,trade_date,buy_sm_vol,buy_sm_amount,'
|
||||
'sell_sm_vol,sell_sm_amount,buy_md_vol,buy_md_amount,'
|
||||
'sell_md_vol,sell_md_amount,buy_lg_vol,buy_lg_amount,'
|
||||
'sell_lg_vol,sell_lg_amount,buy_elg_vol,buy_elg_amount,'
|
||||
'sell_elg_vol,sell_elg_amount,net_mf_vol,net_mf_amount'
|
||||
)
|
||||
|
||||
if df.empty:
|
||||
return None
|
||||
|
||||
df = df.sort_values('trade_date', ascending=False)
|
||||
return {
|
||||
'ts_code': ts_code,
|
||||
'data': df.to_dict('records')
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取资金流向失败: {e}")
|
||||
return None
|
||||
|
||||
# ==================== 新闻公告 ====================
|
||||
|
||||
def get_news(
|
||||
self,
|
||||
stock_code: str = None,
|
||||
start_date: str = None,
|
||||
end_date: str = None,
|
||||
src: str = None
|
||||
) -> Optional[List[Dict[str, Any]]]:
|
||||
"""
|
||||
获取新闻资讯
|
||||
|
||||
Args:
|
||||
stock_code: 股票代码(可选)
|
||||
start_date: 开始日期
|
||||
end_date: 结束日期
|
||||
src: 新闻来源
|
||||
|
||||
Returns:
|
||||
新闻列表
|
||||
"""
|
||||
if not self.pro:
|
||||
return None
|
||||
|
||||
try:
|
||||
ts_code = None
|
||||
if stock_code:
|
||||
ts_code = normalize_stock_code(stock_code)
|
||||
|
||||
if not start_date:
|
||||
start_date = (datetime.now() - timedelta(days=7)).strftime('%Y%m%d')
|
||||
if not end_date:
|
||||
end_date = datetime.now().strftime('%Y%m%d')
|
||||
|
||||
# 使用news接口(需要5000积分)
|
||||
try:
|
||||
df = self.pro.query('news',
|
||||
src=src,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
fields='datetime,content,title,channels,score'
|
||||
)
|
||||
except Exception as api_error:
|
||||
# 如果接口不可用(积分不足或接口名称问题),返回None
|
||||
logger.warning(f"新闻接口不可用(可能需要更高积分权限): {api_error}")
|
||||
return None
|
||||
|
||||
if df is None or df.empty:
|
||||
return None
|
||||
|
||||
# 如果指定了股票代码,过滤相关新闻
|
||||
if ts_code:
|
||||
try:
|
||||
# 简单的关键词过滤
|
||||
stock_info = self.pro.stock_basic(ts_code=ts_code, fields='name,symbol')
|
||||
if not stock_info.empty:
|
||||
name = stock_info.iloc[0]['name']
|
||||
symbol = stock_info.iloc[0]['symbol']
|
||||
df = df[
|
||||
df['title'].str.contains(name, na=False) |
|
||||
df['content'].str.contains(name, na=False) |
|
||||
df['title'].str.contains(symbol, na=False)
|
||||
]
|
||||
except Exception as filter_error:
|
||||
logger.warning(f"新闻过滤失败: {filter_error}")
|
||||
# 继续返回未过滤的新闻
|
||||
|
||||
df = df.sort_values('datetime', ascending=False)
|
||||
return df.head(10).to_dict('records')
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"获取新闻失败: {e}")
|
||||
return None
|
||||
|
||||
# ==================== 市场特色数据 ====================
|
||||
|
||||
def get_margin_detail(
|
||||
self,
|
||||
stock_code: str,
|
||||
start_date: str = None,
|
||||
end_date: str = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取融资融券详情
|
||||
|
||||
Args:
|
||||
stock_code: 股票代码
|
||||
start_date: 开始日期
|
||||
end_date: 结束日期
|
||||
|
||||
Returns:
|
||||
融资融券数据
|
||||
"""
|
||||
if not self.pro:
|
||||
return None
|
||||
|
||||
try:
|
||||
ts_code = normalize_stock_code(stock_code)
|
||||
if not ts_code:
|
||||
return None
|
||||
|
||||
if not start_date:
|
||||
start_date = (datetime.now() - timedelta(days=30)).strftime('%Y%m%d')
|
||||
if not end_date:
|
||||
end_date = datetime.now().strftime('%Y%m%d')
|
||||
|
||||
df = self.pro.margin_detail(
|
||||
ts_code=ts_code,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
fields='ts_code,trade_date,rzye,rqye,rzmre,rqyl,'
|
||||
'rzche,rqchl,rqmcl,rzrqye'
|
||||
)
|
||||
|
||||
if df.empty:
|
||||
return None
|
||||
|
||||
df = df.sort_values('trade_date', ascending=False)
|
||||
return {
|
||||
'ts_code': ts_code,
|
||||
'data': df.to_dict('records')
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取融资融券失败: {e}")
|
||||
return None
|
||||
|
||||
def get_block_trade(
|
||||
self,
|
||||
stock_code: str,
|
||||
start_date: str = None,
|
||||
end_date: str = None
|
||||
) -> Optional[List[Dict[str, Any]]]:
|
||||
"""
|
||||
获取大宗交易数据
|
||||
|
||||
Args:
|
||||
stock_code: 股票代码
|
||||
start_date: 开始日期
|
||||
end_date: 结束日期
|
||||
|
||||
Returns:
|
||||
大宗交易列表
|
||||
"""
|
||||
if not self.pro:
|
||||
return None
|
||||
|
||||
try:
|
||||
ts_code = normalize_stock_code(stock_code)
|
||||
if not ts_code:
|
||||
return None
|
||||
|
||||
if not start_date:
|
||||
start_date = (datetime.now() - timedelta(days=90)).strftime('%Y%m%d')
|
||||
if not end_date:
|
||||
end_date = datetime.now().strftime('%Y%m%d')
|
||||
|
||||
df = self.pro.block_trade(
|
||||
ts_code=ts_code,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
fields='ts_code,trade_date,price,vol,amount,buyer,seller'
|
||||
)
|
||||
|
||||
if df.empty:
|
||||
return None
|
||||
|
||||
df = df.sort_values('trade_date', ascending=False)
|
||||
return df.to_dict('records')
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取大宗交易失败: {e}")
|
||||
return None
|
||||
|
||||
def get_top_list(
|
||||
self,
|
||||
trade_date: str = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取龙虎榜数据
|
||||
|
||||
Args:
|
||||
trade_date: 交易日期
|
||||
|
||||
Returns:
|
||||
龙虎榜数据
|
||||
"""
|
||||
if not self.pro:
|
||||
return None
|
||||
|
||||
try:
|
||||
if not trade_date:
|
||||
trade_date = datetime.now().strftime('%Y%m%d')
|
||||
|
||||
df = self.pro.top_list(
|
||||
trade_date=trade_date,
|
||||
fields='trade_date,ts_code,name,close,pct_change,turnover_rate,'
|
||||
'amount,l_sell,l_buy,l_amount,net_amount,net_rate,'
|
||||
'amount_rate,float_values,reason'
|
||||
)
|
||||
|
||||
if df.empty:
|
||||
return None
|
||||
|
||||
return {
|
||||
'trade_date': trade_date,
|
||||
'data': df.to_dict('records')
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取龙虎榜失败: {e}")
|
||||
return None
|
||||
|
||||
# ==================== 指数数据 ====================
|
||||
|
||||
def get_index_daily(
|
||||
self,
|
||||
ts_code: str,
|
||||
start_date: str = None,
|
||||
end_date: str = None
|
||||
) -> Optional[List[Dict[str, Any]]]:
|
||||
"""
|
||||
获取指数日线行情
|
||||
|
||||
Args:
|
||||
ts_code: 指数代码(如000001.SH=上证指数)
|
||||
start_date: 开始日期
|
||||
end_date: 结束日期
|
||||
|
||||
Returns:
|
||||
指数行情数据
|
||||
"""
|
||||
if not self.pro:
|
||||
return None
|
||||
|
||||
try:
|
||||
if not start_date:
|
||||
start_date = (datetime.now() - timedelta(days=180)).strftime('%Y%m%d')
|
||||
if not end_date:
|
||||
end_date = datetime.now().strftime('%Y%m%d')
|
||||
|
||||
df = self.pro.index_daily(
|
||||
ts_code=ts_code,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
fields='ts_code,trade_date,close,open,high,low,pre_close,'
|
||||
'change,pct_chg,vol,amount'
|
||||
)
|
||||
|
||||
if df.empty:
|
||||
return None
|
||||
|
||||
df = df.sort_values('trade_date')
|
||||
return df.to_dict('records')
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取指数数据失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# 创建全局实例
|
||||
tushare_advanced_service = TushareAdvancedService()
|
||||
@ -1,266 +0,0 @@
|
||||
"""
|
||||
Tushare数据服务
|
||||
封装Tushare API调用
|
||||
"""
|
||||
import tushare as ts
|
||||
import pandas as pd
|
||||
from typing import Optional, List
|
||||
from datetime import datetime, timedelta
|
||||
from app.config import get_settings
|
||||
from app.utils.logger import logger
|
||||
from app.utils.validators import normalize_stock_code
|
||||
|
||||
|
||||
class TushareService:
|
||||
"""Tushare数据服务类"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化Tushare服务"""
|
||||
settings = get_settings()
|
||||
if not settings.tushare_token:
|
||||
logger.warning("Tushare token未配置")
|
||||
self.pro = None
|
||||
else:
|
||||
ts.set_token(settings.tushare_token)
|
||||
self.pro = ts.pro_api()
|
||||
logger.info("Tushare服务初始化成功")
|
||||
|
||||
def get_realtime_quote(self, stock_code: str) -> Optional[dict]:
|
||||
"""
|
||||
获取实时行情
|
||||
|
||||
Args:
|
||||
stock_code: 股票代码
|
||||
|
||||
Returns:
|
||||
行情数据字典
|
||||
"""
|
||||
if not self.pro:
|
||||
logger.error("Tushare服务未初始化")
|
||||
return None
|
||||
|
||||
try:
|
||||
# 标准化股票代码
|
||||
ts_code = normalize_stock_code(stock_code)
|
||||
if not ts_code:
|
||||
logger.error(f"无效的股票代码: {stock_code}")
|
||||
return None
|
||||
|
||||
# 获取最新交易日数据
|
||||
df = self.pro.daily(ts_code=ts_code, start_date='', end_date='')
|
||||
|
||||
if df.empty:
|
||||
logger.warning(f"未找到股票数据: {ts_code}")
|
||||
return None
|
||||
|
||||
# 取最新一条
|
||||
latest = df.iloc[0]
|
||||
|
||||
# 获取股票名称
|
||||
stock_info = self.pro.stock_basic(ts_code=ts_code, fields='ts_code,name')
|
||||
name = stock_info.iloc[0]['name'] if not stock_info.empty else None
|
||||
|
||||
return {
|
||||
'ts_code': ts_code,
|
||||
'name': name,
|
||||
'trade_date': latest['trade_date'],
|
||||
'open': float(latest['open']),
|
||||
'high': float(latest['high']),
|
||||
'low': float(latest['low']),
|
||||
'close': float(latest['close']),
|
||||
'pre_close': float(latest['pre_close']),
|
||||
'change': float(latest['change']),
|
||||
'pct_chg': float(latest['pct_chg']),
|
||||
'vol': float(latest['vol']),
|
||||
'amount': float(latest['amount'])
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取实时行情失败: {e}")
|
||||
return None
|
||||
|
||||
def get_kline_data(
|
||||
self,
|
||||
stock_code: str,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
period: str = 'D'
|
||||
) -> Optional[List[dict]]:
|
||||
"""
|
||||
获取K线数据
|
||||
|
||||
Args:
|
||||
stock_code: 股票代码
|
||||
start_date: 开始日期(YYYYMMDD)
|
||||
end_date: 结束日期(YYYYMMDD)
|
||||
period: 周期(D=日,W=周,M=月)
|
||||
|
||||
Returns:
|
||||
K线数据列表
|
||||
"""
|
||||
if not self.pro:
|
||||
logger.error("Tushare服务未初始化")
|
||||
return None
|
||||
|
||||
try:
|
||||
# 标准化股票代码
|
||||
ts_code = normalize_stock_code(stock_code)
|
||||
if not ts_code:
|
||||
logger.error(f"无效的股票代码: {stock_code}")
|
||||
return None
|
||||
|
||||
# 默认获取最近180个交易日(约6个月),确保技术指标计算准确
|
||||
# MA60需要至少60个交易日,加上缓冲期,180天可以覆盖约120个交易日
|
||||
if not start_date:
|
||||
start_date = (datetime.now() - timedelta(days=180)).strftime('%Y%m%d')
|
||||
if not end_date:
|
||||
end_date = datetime.now().strftime('%Y%m%d')
|
||||
|
||||
logger.info(f"获取K线数据: {stock_code}, 时间范围: {start_date} - {end_date}")
|
||||
|
||||
# 获取日线数据
|
||||
if period == 'D':
|
||||
df = self.pro.daily(
|
||||
ts_code=ts_code,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
elif period == 'W':
|
||||
df = self.pro.weekly(
|
||||
ts_code=ts_code,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
elif period == 'M':
|
||||
df = self.pro.monthly(
|
||||
ts_code=ts_code,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
else:
|
||||
logger.error(f"不支持的周期: {period}")
|
||||
return None
|
||||
|
||||
if df.empty:
|
||||
logger.warning(f"未找到K线数据: {ts_code}")
|
||||
return None
|
||||
|
||||
# 按日期升序排列
|
||||
df = df.sort_values('trade_date')
|
||||
|
||||
# 转换为字典列表
|
||||
kline_data = []
|
||||
for _, row in df.iterrows():
|
||||
kline_data.append({
|
||||
'ts_code': ts_code,
|
||||
'trade_date': row['trade_date'],
|
||||
'open': float(row['open']),
|
||||
'high': float(row['high']),
|
||||
'low': float(row['low']),
|
||||
'close': float(row['close']),
|
||||
'vol': float(row['vol']),
|
||||
'amount': float(row['amount']) if pd.notna(row['amount']) else None
|
||||
})
|
||||
|
||||
return kline_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取K线数据失败: {e}")
|
||||
return None
|
||||
|
||||
def get_stock_basic(self, stock_code: str) -> Optional[dict]:
|
||||
"""
|
||||
获取股票基本信息
|
||||
|
||||
Args:
|
||||
stock_code: 股票代码
|
||||
|
||||
Returns:
|
||||
基本信息字典
|
||||
"""
|
||||
if not self.pro:
|
||||
logger.error("Tushare服务未初始化")
|
||||
return None
|
||||
|
||||
try:
|
||||
ts_code = normalize_stock_code(stock_code)
|
||||
if not ts_code:
|
||||
return None
|
||||
|
||||
df = self.pro.stock_basic(
|
||||
ts_code=ts_code,
|
||||
fields='ts_code,symbol,name,area,industry,market,list_date'
|
||||
)
|
||||
|
||||
if df.empty:
|
||||
return None
|
||||
|
||||
info = df.iloc[0]
|
||||
return {
|
||||
'ts_code': info['ts_code'],
|
||||
'symbol': info['symbol'],
|
||||
'name': info['name'],
|
||||
'area': info['area'],
|
||||
'industry': info['industry'],
|
||||
'market': info['market'],
|
||||
'list_date': info['list_date']
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取股票基本信息失败: {e}")
|
||||
return None
|
||||
|
||||
def search_stock(self, keyword: str) -> Optional[List[dict]]:
|
||||
"""
|
||||
搜索股票(通过名称或代码)
|
||||
|
||||
Args:
|
||||
keyword: 搜索关键词(股票名称或代码)
|
||||
|
||||
Returns:
|
||||
匹配的股票列表
|
||||
"""
|
||||
if not self.pro:
|
||||
logger.error("Tushare服务未初始化")
|
||||
return None
|
||||
|
||||
try:
|
||||
# 获取所有股票列表
|
||||
df = self.pro.stock_basic(
|
||||
fields='ts_code,symbol,name,area,industry,market,list_date'
|
||||
)
|
||||
|
||||
if df.empty:
|
||||
return None
|
||||
|
||||
# 搜索匹配的股票
|
||||
# 1. 精确匹配代码
|
||||
exact_match = df[df['symbol'] == keyword]
|
||||
if not exact_match.empty:
|
||||
return [exact_match.iloc[0].to_dict()]
|
||||
|
||||
# 2. 模糊匹配名称
|
||||
name_match = df[df['name'].str.contains(keyword, na=False)]
|
||||
if not name_match.empty:
|
||||
results = []
|
||||
for _, row in name_match.iterrows():
|
||||
results.append(row.to_dict())
|
||||
return results[:5] # 最多返回5个结果
|
||||
|
||||
# 3. 模糊匹配代码
|
||||
code_match = df[df['symbol'].str.contains(keyword, na=False)]
|
||||
if not code_match.empty:
|
||||
results = []
|
||||
for _, row in code_match.iterrows():
|
||||
results.append(row.to_dict())
|
||||
return results[:5]
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"搜索股票失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# 创建全局实例
|
||||
tushare_service = TushareService()
|
||||
@ -1,321 +0,0 @@
|
||||
"""
|
||||
美股数据服务 - 使用 yfinance 获取美股数据
|
||||
"""
|
||||
from typing import Optional, Dict, Any, List
|
||||
import yfinance as yf
|
||||
from datetime import datetime, timedelta
|
||||
import pandas as pd
|
||||
from app.utils.logger import logger
|
||||
|
||||
|
||||
class USStockService:
|
||||
"""美股数据服务类(支持美股和港股)"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化美股数据服务"""
|
||||
self.cache = {} # 简单的内存缓存
|
||||
|
||||
@staticmethod
|
||||
def _normalize_hk_symbol(symbol: str) -> str:
|
||||
"""
|
||||
标准化港股代码格式为 yfinance 要求的格式
|
||||
- 4位及以下:左侧补零到4位,如 700.HK → 0700.HK, 5.HK → 0005.HK
|
||||
- 5位及以上:去掉前导零,如 09618.HK → 9618.HK
|
||||
"""
|
||||
if not symbol.endswith('.HK'):
|
||||
return symbol
|
||||
|
||||
# 分离代码和后缀
|
||||
code_part = symbol[:-3] # 去掉 .HK
|
||||
suffix = '.HK'
|
||||
|
||||
# 如果是纯数字代码
|
||||
if code_part.isdigit():
|
||||
# 4位及以下:补零到4位
|
||||
if len(code_part) <= 4:
|
||||
normalized_code = code_part.zfill(4)
|
||||
# 5位及以上:去掉前导零
|
||||
else:
|
||||
normalized_code = code_part.lstrip('0') or '0'
|
||||
else:
|
||||
normalized_code = code_part
|
||||
|
||||
return normalized_code + suffix
|
||||
|
||||
def get_stock_info(self, symbol: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取美股基本信息
|
||||
|
||||
Args:
|
||||
symbol: 股票代码(如 AAPL, TSLA 或 0700.HK)
|
||||
|
||||
Returns:
|
||||
股票基本信息字典
|
||||
"""
|
||||
try:
|
||||
# 标准化港股代码格式
|
||||
normalized_symbol = self._normalize_hk_symbol(symbol)
|
||||
|
||||
stock = yf.Ticker(normalized_symbol)
|
||||
info = stock.info
|
||||
|
||||
if not info or 'symbol' not in info:
|
||||
logger.warning(f"未找到股票: {symbol}")
|
||||
return None
|
||||
|
||||
# 提取关键信息
|
||||
result = {
|
||||
"symbol": symbol,
|
||||
"name": info.get("longName", info.get("shortName", symbol)),
|
||||
"sector": info.get("sector", "未知"),
|
||||
"industry": info.get("industry", "未知"),
|
||||
"market_cap": info.get("marketCap", 0),
|
||||
"current_price": info.get("currentPrice", info.get("regularMarketPrice", 0)),
|
||||
"previous_close": info.get("previousClose", 0),
|
||||
"open": info.get("open", 0),
|
||||
"day_high": info.get("dayHigh", 0),
|
||||
"day_low": info.get("dayLow", 0),
|
||||
"volume": info.get("volume", 0),
|
||||
"avg_volume": info.get("averageVolume", 0),
|
||||
"pe_ratio": info.get("trailingPE", 0),
|
||||
"forward_pe": info.get("forwardPE", 0),
|
||||
"pb_ratio": info.get("priceToBook", 0),
|
||||
"dividend_yield": info.get("dividendYield", 0),
|
||||
"52_week_high": info.get("fiftyTwoWeekHigh", 0),
|
||||
"52_week_low": info.get("fiftyTwoWeekLow", 0),
|
||||
"50_day_avg": info.get("fiftyDayAverage", 0),
|
||||
"200_day_avg": info.get("twoHundredDayAverage", 0),
|
||||
"beta": info.get("beta", 0),
|
||||
"eps": info.get("trailingEps", 0),
|
||||
"description": info.get("longBusinessSummary", ""),
|
||||
}
|
||||
|
||||
logger.info(f"获取美股信息成功: {symbol}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取美股信息失败 {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def get_historical_data(
|
||||
self,
|
||||
symbol: str,
|
||||
period: str = "1mo",
|
||||
interval: str = "1d"
|
||||
) -> Optional[pd.DataFrame]:
|
||||
"""
|
||||
获取美股历史K线数据
|
||||
|
||||
Args:
|
||||
symbol: 股票代码
|
||||
period: 时间周期 (1d, 5d, 1mo, 3mo, 6mo, 1y, 2y, 5y, 10y, ytd, max)
|
||||
interval: K线间隔 (1m, 2m, 5m, 15m, 30m, 60m, 90m, 1h, 1d, 5d, 1wk, 1mo, 3mo)
|
||||
|
||||
Returns:
|
||||
包含OHLCV数据的DataFrame
|
||||
"""
|
||||
try:
|
||||
# 标准化港股代码格式
|
||||
normalized_symbol = self._normalize_hk_symbol(symbol)
|
||||
|
||||
stock = yf.Ticker(normalized_symbol)
|
||||
hist = stock.history(period=period, interval=interval)
|
||||
|
||||
if hist.empty:
|
||||
logger.warning(f"未找到历史数据: {symbol}")
|
||||
return None
|
||||
|
||||
logger.info(f"获取美股历史数据成功: {symbol}, 周期: {period}")
|
||||
return hist
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取美股历史数据失败 {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def get_financial_data(self, symbol: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取美股财务数据
|
||||
|
||||
Args:
|
||||
symbol: 股票代码
|
||||
|
||||
Returns:
|
||||
财务数据字典
|
||||
"""
|
||||
try:
|
||||
# 标准化港股代码格式
|
||||
normalized_symbol = self._normalize_hk_symbol(symbol)
|
||||
|
||||
stock = yf.Ticker(normalized_symbol)
|
||||
|
||||
# 获取财务报表
|
||||
financials = stock.financials
|
||||
balance_sheet = stock.balance_sheet
|
||||
cashflow = stock.cashflow
|
||||
|
||||
result = {
|
||||
"symbol": symbol,
|
||||
"income_statement": financials.to_dict() if not financials.empty else {},
|
||||
"balance_sheet": balance_sheet.to_dict() if not balance_sheet.empty else {},
|
||||
"cash_flow": cashflow.to_dict() if not cashflow.empty else {},
|
||||
}
|
||||
|
||||
# 获取关键财务指标
|
||||
info = stock.info
|
||||
result["key_metrics"] = {
|
||||
"revenue": info.get("totalRevenue", 0),
|
||||
"gross_profit": info.get("grossProfits", 0),
|
||||
"ebitda": info.get("ebitda", 0),
|
||||
"net_income": info.get("netIncomeToCommon", 0),
|
||||
"total_assets": info.get("totalAssets", 0),
|
||||
"total_debt": info.get("totalDebt", 0),
|
||||
"total_cash": info.get("totalCash", 0),
|
||||
"operating_cash_flow": info.get("operatingCashflow", 0),
|
||||
"free_cash_flow": info.get("freeCashflow", 0),
|
||||
"roe": info.get("returnOnEquity", 0),
|
||||
"roa": info.get("returnOnAssets", 0),
|
||||
"profit_margin": info.get("profitMargins", 0),
|
||||
"operating_margin": info.get("operatingMargins", 0),
|
||||
}
|
||||
|
||||
logger.info(f"获取美股财务数据成功: {symbol}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取美股财务数据失败 {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def calculate_technical_indicators(self, hist: pd.DataFrame) -> Dict[str, Any]:
|
||||
"""
|
||||
计算技术指标
|
||||
|
||||
Args:
|
||||
hist: 历史数据DataFrame
|
||||
|
||||
Returns:
|
||||
技术指标字典
|
||||
"""
|
||||
try:
|
||||
if hist.empty or len(hist) < 20:
|
||||
return {}
|
||||
|
||||
close = hist['Close']
|
||||
|
||||
# 计算移动平均线
|
||||
ma5 = close.rolling(window=5).mean().iloc[-1] if len(close) >= 5 else None
|
||||
ma10 = close.rolling(window=10).mean().iloc[-1] if len(close) >= 10 else None
|
||||
ma20 = close.rolling(window=20).mean().iloc[-1] if len(close) >= 20 else None
|
||||
ma60 = close.rolling(window=60).mean().iloc[-1] if len(close) >= 60 else None
|
||||
|
||||
# 计算RSI(使用 Wilder's Smoothing 方法)
|
||||
delta = close.diff()
|
||||
gain = delta.where(delta > 0, 0)
|
||||
loss = -delta.where(delta < 0, 0)
|
||||
# 使用 EMA (Wilder's Smoothing) 而不是简单平均
|
||||
avg_gain = gain.ewm(alpha=1/14, adjust=False).mean()
|
||||
avg_loss = loss.ewm(alpha=1/14, adjust=False).mean()
|
||||
rs = avg_gain / avg_loss
|
||||
rsi = 100 - (100 / (1 + rs))
|
||||
rsi_value = rsi.iloc[-1] if len(rsi) >= 14 else None
|
||||
|
||||
# 计算MACD
|
||||
exp1 = close.ewm(span=12, adjust=False).mean()
|
||||
exp2 = close.ewm(span=26, adjust=False).mean()
|
||||
macd = exp1 - exp2
|
||||
signal = macd.ewm(span=9, adjust=False).mean()
|
||||
macd_value = macd.iloc[-1] if len(macd) >= 26 else None
|
||||
signal_value = signal.iloc[-1] if len(signal) >= 26 else None
|
||||
|
||||
# 计算布林带
|
||||
bb_middle = close.rolling(window=20).mean()
|
||||
bb_std = close.rolling(window=20).std()
|
||||
bb_upper = bb_middle + (bb_std * 2)
|
||||
bb_lower = bb_middle - (bb_std * 2)
|
||||
|
||||
result = {
|
||||
"ma5": float(ma5) if ma5 and not pd.isna(ma5) else None,
|
||||
"ma10": float(ma10) if ma10 and not pd.isna(ma10) else None,
|
||||
"ma20": float(ma20) if ma20 and not pd.isna(ma20) else None,
|
||||
"ma60": float(ma60) if ma60 and not pd.isna(ma60) else None,
|
||||
"rsi": float(rsi_value) if rsi_value and not pd.isna(rsi_value) else None,
|
||||
"macd": float(macd_value) if macd_value and not pd.isna(macd_value) else None,
|
||||
"macd_signal": float(signal_value) if signal_value and not pd.isna(signal_value) else None,
|
||||
"bb_upper": float(bb_upper.iloc[-1]) if len(bb_upper) >= 20 and not pd.isna(bb_upper.iloc[-1]) else None,
|
||||
"bb_middle": float(bb_middle.iloc[-1]) if len(bb_middle) >= 20 and not pd.isna(bb_middle.iloc[-1]) else None,
|
||||
"bb_lower": float(bb_lower.iloc[-1]) if len(bb_lower) >= 20 and not pd.isna(bb_lower.iloc[-1]) else None,
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"计算技术指标失败: {e}")
|
||||
return {}
|
||||
|
||||
def get_comprehensive_analysis(self, symbol: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取美股综合分析数据
|
||||
|
||||
Args:
|
||||
symbol: 股票代码
|
||||
|
||||
Returns:
|
||||
综合分析数据字典
|
||||
"""
|
||||
try:
|
||||
# 获取基本信息
|
||||
info = self.get_stock_info(symbol)
|
||||
if not info:
|
||||
return None
|
||||
|
||||
# 获取历史数据
|
||||
hist = self.get_historical_data(symbol, period="6mo", interval="1d")
|
||||
if hist is None or hist.empty:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "无法获取历史数据"
|
||||
}
|
||||
|
||||
# 计算技术指标
|
||||
technical = self.calculate_technical_indicators(hist)
|
||||
|
||||
# 获取最近的价格数据
|
||||
latest = hist.iloc[-1]
|
||||
prev = hist.iloc[-2] if len(hist) > 1 else latest
|
||||
|
||||
# 计算涨跌幅
|
||||
change = latest['Close'] - prev['Close']
|
||||
change_pct = (change / prev['Close'] * 100) if prev['Close'] != 0 else 0
|
||||
|
||||
result = {
|
||||
"success": True,
|
||||
"symbol": symbol,
|
||||
"name": info["name"],
|
||||
"sector": info["sector"],
|
||||
"industry": info["industry"],
|
||||
"current_price": float(latest['Close']),
|
||||
"change": float(change),
|
||||
"change_percent": float(change_pct),
|
||||
"volume": int(latest['Volume']),
|
||||
"market_cap": info["market_cap"],
|
||||
"pe_ratio": info["pe_ratio"],
|
||||
"pb_ratio": info["pb_ratio"],
|
||||
"dividend_yield": info["dividend_yield"],
|
||||
"52_week_high": info["52_week_high"],
|
||||
"52_week_low": info["52_week_low"],
|
||||
"technical_indicators": technical,
|
||||
"description": info["description"][:500] if info["description"] else "",
|
||||
}
|
||||
|
||||
logger.info(f"获取美股综合分析成功: {symbol}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取美股综合分析失败 {symbol}: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
|
||||
# 创建全局实例
|
||||
us_stock_service = USStockService()
|
||||
@ -1,438 +0,0 @@
|
||||
"""
|
||||
YFinance 服务 - 美股港股数据获取
|
||||
支持获取美股的实时行情和历史 K 线数据
|
||||
备用数据源:Stooq
|
||||
"""
|
||||
import pandas as pd
|
||||
from typing import Dict, List, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from app.utils.logger import logger
|
||||
import time
|
||||
|
||||
|
||||
class YFinanceService:
|
||||
"""YFinance 服务类(支持 Stooq 备用)"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化服务"""
|
||||
# 初始化 YFinance
|
||||
try:
|
||||
import yfinance as yf
|
||||
self.yf = yf
|
||||
self._yf_available = True
|
||||
logger.info("YFinance 服务初始化成功")
|
||||
except ImportError:
|
||||
logger.warning("yfinance 未安装")
|
||||
self._yf_available = False
|
||||
|
||||
# 初始化 Stooq(备用)
|
||||
try:
|
||||
import pandas_datareader.data as web
|
||||
self.web = web
|
||||
self._stooq_available = True
|
||||
logger.info("Stooq 备用数据源初始化成功")
|
||||
except ImportError:
|
||||
logger.warning("pandas_datareader 未安装,Stooq 备用不可用")
|
||||
self._stooq_available = False
|
||||
|
||||
if not self._yf_available and not self._stooq_available:
|
||||
raise Exception("没有可用的数据源,请安装 yfinance 或 pandas_datareader")
|
||||
|
||||
self._cache = {} # 数据缓存
|
||||
self._cache_time = {} # 缓存时间
|
||||
self._cache_ttl = 300 # 缓存有效期(秒)
|
||||
|
||||
def _normalize_hk_symbol(self, symbol: str) -> str:
|
||||
"""
|
||||
标准化港股代码格式为 yfinance 要求的格式
|
||||
- 4位及以下:左侧补零到4位,如 700.HK → 0700.HK, 5.HK → 0005.HK
|
||||
- 5位及以上:去掉前导零,如 09618.HK → 9618.HK
|
||||
"""
|
||||
if not symbol.endswith('.HK'):
|
||||
return symbol
|
||||
|
||||
# 分离代码和后缀
|
||||
code_part = symbol[:-3] # 去掉 .HK
|
||||
suffix = '.HK'
|
||||
|
||||
# 如果是纯数字代码
|
||||
if code_part.isdigit():
|
||||
# 4位及以下:补零到4位
|
||||
if len(code_part) <= 4:
|
||||
normalized_code = code_part.zfill(4)
|
||||
# 5位及以上:去掉前导零
|
||||
else:
|
||||
normalized_code = code_part.lstrip('0') or '0'
|
||||
else:
|
||||
normalized_code = code_part
|
||||
|
||||
return normalized_code + suffix
|
||||
|
||||
def get_ticker(self, symbol: str) -> Optional[Dict]:
|
||||
"""
|
||||
获取股票实时行情(优先使用 YFinance,失败则使用 Stooq)
|
||||
|
||||
Args:
|
||||
symbol: 股票代码,如 'AAPL' 或 '0700.HK'
|
||||
|
||||
Returns:
|
||||
行情数据字典
|
||||
"""
|
||||
# 优先使用 YFinance
|
||||
if self._yf_available:
|
||||
result = self._get_yf_ticker(symbol)
|
||||
if result:
|
||||
return result
|
||||
logger.info(f"YFinance 获取失败,尝试使用 Stooq 备用数据源 ({symbol})")
|
||||
|
||||
# 备用使用 Stooq
|
||||
if self._stooq_available:
|
||||
result = self._get_stooq_ticker(symbol)
|
||||
if result:
|
||||
return result
|
||||
|
||||
return None
|
||||
|
||||
def _get_yf_ticker(self, symbol: str) -> Optional[Dict]:
|
||||
"""使用 YFinance 获取行情"""
|
||||
try:
|
||||
normalized_symbol = self._normalize_hk_symbol(symbol)
|
||||
ticker = self.yf.Ticker(normalized_symbol)
|
||||
hist = ticker.history(period="2d", interval="1h")
|
||||
|
||||
if hist.empty:
|
||||
logger.warning(f"YFinance 无法获取 {symbol} 的数据")
|
||||
return None
|
||||
|
||||
latest = hist.iloc[-1]
|
||||
return {
|
||||
'symbol': symbol,
|
||||
'lastPrice': float(latest['Close']),
|
||||
'priceChange': float(latest['Close'] - latest['Open']),
|
||||
'priceChangePercent': float((latest['Close'] - latest['Open']) / latest['Open'] * 100) if latest['Open'] > 0 else 0,
|
||||
'volume': int(latest['Volume']),
|
||||
'high': float(latest['High']),
|
||||
'low': float(latest['Low']),
|
||||
'open': float(latest['Open']),
|
||||
'prevClose': float(latest['Close']),
|
||||
'timestamp': datetime.now().isoformat()
|
||||
}
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
if "429" in error_msg or "Too Many Requests" in error_msg:
|
||||
logger.warning(f"YFinance API 限流 ({symbol})")
|
||||
else:
|
||||
logger.debug(f"YFinance 获取失败 ({symbol}): {error_msg}")
|
||||
return None
|
||||
|
||||
def _get_stooq_ticker(self, symbol: str) -> Optional[Dict]:
|
||||
"""使用 Stooq 获取行情(备用)"""
|
||||
try:
|
||||
# Stooq 使用的港股格式
|
||||
stooq_symbol = self._convert_to_stooq_symbol(symbol)
|
||||
|
||||
# 获取最近几天的数据
|
||||
start_date = (datetime.now() - timedelta(days=5)).strftime('%Y-%m-%d')
|
||||
df = self.web.DataReader(stooq_symbol, 'stooq', start=start_date)
|
||||
|
||||
if df.empty:
|
||||
logger.warning(f"Stooq 无法获取 {symbol} 的数据")
|
||||
return None
|
||||
|
||||
# Stooq 返回的数据是倒序的,取第一行(最新)
|
||||
latest = df.iloc[-1]
|
||||
|
||||
return {
|
||||
'symbol': symbol,
|
||||
'lastPrice': float(latest['Close']),
|
||||
'priceChange': float(latest['Close'] - latest['Open']),
|
||||
'priceChangePercent': float((latest['Close'] - latest['Open']) / latest['Open'] * 100) if latest['Open'] > 0 else 0,
|
||||
'volume': int(latest['Volume']),
|
||||
'high': float(latest['High']),
|
||||
'low': float(latest['Low']),
|
||||
'open': float(latest['Open']),
|
||||
'prevClose': float(latest['Close']),
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
'source': 'stooq' # 标记数据来源
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Stooq 获取 {symbol} 行情失败: {e}")
|
||||
return None
|
||||
|
||||
def _convert_to_stooq_symbol(self, symbol: str) -> str:
|
||||
"""
|
||||
转换股票代码为 Stooq 格式
|
||||
|
||||
美股:AAPL -> AAPL.US
|
||||
港股:0700.HK -> 0700.HK
|
||||
"""
|
||||
if symbol.endswith('.HK'):
|
||||
return symbol
|
||||
elif '.' in symbol:
|
||||
# 其他格式保持不变
|
||||
return symbol
|
||||
else:
|
||||
# 美股添加 .US 后缀
|
||||
return f"{symbol}.US"
|
||||
|
||||
def get_multi_timeframe_data(
|
||||
self,
|
||||
symbol: str,
|
||||
timeframes: Optional[Dict[str, tuple]] = None
|
||||
) -> Dict[str, pd.DataFrame]:
|
||||
"""
|
||||
获取多时间周期的 K 线数据
|
||||
|
||||
Args:
|
||||
symbol: 股票代码
|
||||
timeframes: 时间周期配置 {'1d': ('1d', '3mo'), ...}
|
||||
|
||||
Returns:
|
||||
多时间周期数据字典 {'1d': df, '1h': df, ...}
|
||||
"""
|
||||
if timeframes is None:
|
||||
# 技术面分析时间周期:1h、1d、1w
|
||||
timeframes = {
|
||||
'1w': ('1wk', '2y'), # 周级别,2年 - 长期趋势
|
||||
'1d': ('1d', '6mo'), # 日级别,6个月 - 中期趋势
|
||||
'1h': ('1h', '1mo'), # 小时级别,1个月 - 短期趋势
|
||||
}
|
||||
|
||||
result = {}
|
||||
|
||||
for tf_name, (interval, period) in timeframes.items():
|
||||
try:
|
||||
df = self._get_cached_data(symbol, interval, period)
|
||||
if df is not None and not df.empty:
|
||||
result[tf_name] = df
|
||||
logger.debug(f"获取 {symbol} {tf_name} 数据成功: {len(df)} 条")
|
||||
else:
|
||||
logger.warning(f"获取 {symbol} {tf_name} 数据失败或为空")
|
||||
except Exception as e:
|
||||
logger.error(f"获取 {symbol} {tf_name} 数据出错: {e}")
|
||||
|
||||
return result
|
||||
|
||||
def _get_cached_data(
|
||||
self,
|
||||
symbol: str,
|
||||
interval: str,
|
||||
period: str
|
||||
) -> Optional[pd.DataFrame]:
|
||||
"""获取带缓存的数据(优先 YFinance,失败则使用 Stooq)"""
|
||||
# 标准化港股代码格式
|
||||
normalized_symbol = self._normalize_hk_symbol(symbol)
|
||||
cache_key = f"{normalized_symbol}_{interval}_{period}"
|
||||
now = datetime.now()
|
||||
|
||||
# 检查缓存
|
||||
if cache_key in self._cache:
|
||||
cache_time = self._cache_time.get(cache_key)
|
||||
if cache_time and (now - cache_time).total_seconds() < self._cache_ttl:
|
||||
logger.debug(f"使用缓存数据: {cache_key}")
|
||||
return self._cache[cache_key]
|
||||
|
||||
# 优先使用 YFinance
|
||||
if self._yf_available:
|
||||
df = self._get_yf_data(symbol, interval, period, cache_key, now)
|
||||
if df is not None:
|
||||
return df
|
||||
logger.info(f"YFinance 获取历史数据失败,尝试 Stooq ({symbol})")
|
||||
|
||||
# 备用使用 Stooq
|
||||
if self._stooq_available:
|
||||
df = self._get_stooq_data(symbol, interval, period, cache_key, now)
|
||||
if df is not None:
|
||||
logger.info(f"✓ 使用 Stooq 数据源 ({symbol})")
|
||||
return df
|
||||
|
||||
return None
|
||||
|
||||
def _get_yf_data(
|
||||
self,
|
||||
symbol: str,
|
||||
interval: str,
|
||||
period: str,
|
||||
cache_key: str,
|
||||
now: datetime
|
||||
) -> Optional[pd.DataFrame]:
|
||||
"""使用 YFinance 获取历史数据"""
|
||||
try:
|
||||
normalized_symbol = self._normalize_hk_symbol(symbol)
|
||||
ticker = self.yf.Ticker(normalized_symbol)
|
||||
df = ticker.history(period=period, interval=interval)
|
||||
|
||||
if df.empty:
|
||||
return None
|
||||
|
||||
# 转换数据格式
|
||||
df = self._format_dataframe(df)
|
||||
|
||||
# 更新缓存
|
||||
self._cache[cache_key] = df
|
||||
self._cache_time[cache_key] = now
|
||||
|
||||
return df
|
||||
except Exception as e:
|
||||
logger.debug(f"YFinance 获取历史数据失败: {e}")
|
||||
return None
|
||||
|
||||
def _get_stooq_data(
|
||||
self,
|
||||
symbol: str,
|
||||
interval: str,
|
||||
period: str,
|
||||
cache_key: str,
|
||||
now: datetime
|
||||
) -> Optional[pd.DataFrame]:
|
||||
"""使用 Stooq 获取历史数据(备用)"""
|
||||
try:
|
||||
# 转换为 Stooq 格式
|
||||
stooq_symbol = self._convert_to_stooq_symbol(symbol)
|
||||
|
||||
# 将 period 转换为天数
|
||||
period_days = self._period_to_days(period)
|
||||
start_date = (datetime.now() - timedelta(days=period_days)).strftime('%Y-%m-%d')
|
||||
|
||||
# 获取数据
|
||||
df = self.web.DataReader(stooq_symbol, 'stooq', start=start_date)
|
||||
|
||||
if df.empty:
|
||||
return None
|
||||
|
||||
# Stooq 数据是倒序的,需要反转
|
||||
df = df.iloc[::-1]
|
||||
|
||||
# 转换数据格式
|
||||
df = self._format_dataframe(df)
|
||||
|
||||
# 更新缓存
|
||||
self._cache[cache_key] = df
|
||||
self._cache_time[cache_key] = now
|
||||
|
||||
return df
|
||||
except Exception as e:
|
||||
logger.debug(f"Stooq 获取历史数据失败: {e}")
|
||||
return None
|
||||
|
||||
def _period_to_days(self, period: str) -> int:
|
||||
"""将 YFinance period 格式转换为天数"""
|
||||
period_map = {
|
||||
'1mo': 30,
|
||||
'3mo': 90,
|
||||
'6mo': 180,
|
||||
'1y': 365,
|
||||
'2y': 730,
|
||||
}
|
||||
return period_map.get(period, 180) # 默认6个月
|
||||
|
||||
def _format_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
格式化 DataFrame 以兼容现有代码
|
||||
|
||||
yfinance 原始格式:
|
||||
- 列名大写: Open, High, Low, Close, Volume
|
||||
- 索引是 Datetime
|
||||
|
||||
转换后格式:
|
||||
- 列名小写: open, high, low, close, volume
|
||||
- 重置索引,time 作为一列
|
||||
- 添加技术指标
|
||||
"""
|
||||
df = df.copy()
|
||||
|
||||
# 列名转为小写
|
||||
df.columns = [col.lower() for col in df.columns]
|
||||
|
||||
# 重置索引
|
||||
df = df.reset_index()
|
||||
|
||||
# 重命名日期列
|
||||
if 'date' in df.columns:
|
||||
df = df.rename(columns={'date': 'time'})
|
||||
elif 'datetime' in df.columns:
|
||||
df = df.rename(columns={'datetime': 'time'})
|
||||
|
||||
# 删除不需要的列
|
||||
cols_to_keep = ['time', 'open', 'high', 'low', 'close', 'volume']
|
||||
df = df[[col for col in cols_to_keep if col in df.columns]]
|
||||
|
||||
# 添加技术指标(与 binance_service 一致)
|
||||
df = self._add_indicators(df)
|
||||
|
||||
return df
|
||||
|
||||
def _add_indicators(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
添加技术指标到 DataFrame
|
||||
|
||||
Args:
|
||||
df: 原始数据
|
||||
|
||||
Returns:
|
||||
添加了技术指标的 DataFrame
|
||||
"""
|
||||
df = df.copy()
|
||||
|
||||
# 移动平均线(简单移动平均 MA)
|
||||
df['ma5'] = df['close'].rolling(window=5).mean()
|
||||
df['ma10'] = df['close'].rolling(window=10).mean()
|
||||
df['ma20'] = df['close'].rolling(window=20).mean()
|
||||
df['ma50'] = df['close'].rolling(window=50).mean()
|
||||
|
||||
# 指数移动平均线(EMA)- 用于趋势判断
|
||||
df['ema20'] = df['close'].ewm(span=20, adjust=False).mean()
|
||||
df['ema50'] = df['close'].ewm(span=50, adjust=False).mean()
|
||||
df['ema200'] = df['close'].ewm(span=200, adjust=False).mean()
|
||||
|
||||
# RSI(使用 Wilder's Smoothing 方法)
|
||||
delta = df['close'].diff()
|
||||
gain = delta.where(delta > 0, 0)
|
||||
loss = -delta.where(delta < 0, 0)
|
||||
# 使用 EMA (Wilder's Smoothing) 而不是简单平均
|
||||
avg_gain = gain.ewm(alpha=1/14, adjust=False).mean()
|
||||
avg_loss = loss.ewm(alpha=1/14, adjust=False).mean()
|
||||
rs = avg_gain / avg_loss
|
||||
df['rsi'] = 100 - (100 / (1 + rs))
|
||||
|
||||
# MACD (使用与 binance_service 相同的计算方法)
|
||||
ema_fast = df['close'].ewm(span=12, adjust=False).mean()
|
||||
ema_slow = df['close'].ewm(span=26, adjust=False).mean()
|
||||
df['macd'] = ema_fast - ema_slow
|
||||
df['macd_signal'] = df['macd'].ewm(span=9, adjust=False).mean()
|
||||
df['macd_hist'] = df['macd'] - df['macd_signal']
|
||||
|
||||
# ATR
|
||||
high_low = df['high'] - df['low']
|
||||
high_close = abs(df['high'] - df['close'].shift())
|
||||
low_close = abs(df['low'] - df['close'].shift())
|
||||
true_range = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
|
||||
df['atr'] = true_range.rolling(window=14).mean()
|
||||
|
||||
# KDJ 指标
|
||||
low_min = df['low'].rolling(window=9).min()
|
||||
high_max = df['high'].rolling(window=9).max()
|
||||
rsv = (df['close'] - low_min) / (high_max - low_min) * 100
|
||||
df['k'] = rsv.ewm(com=2, adjust=False).mean()
|
||||
df['d'] = df['k'].ewm(com=2, adjust=False).mean()
|
||||
df['j'] = 3 * df['k'] - 2 * df['d']
|
||||
|
||||
return df
|
||||
|
||||
def clear_cache(self):
|
||||
"""清空缓存"""
|
||||
self._cache.clear()
|
||||
self._cache_time.clear()
|
||||
logger.info("YFinance 缓存已清空")
|
||||
|
||||
|
||||
# 全局单例
|
||||
_yfinance_service: Optional[YFinanceService] = None
|
||||
|
||||
|
||||
def get_yfinance_service() -> YFinanceService:
|
||||
"""获取 YFinance 服务单例"""
|
||||
global _yfinance_service
|
||||
if _yfinance_service is None:
|
||||
_yfinance_service = YFinanceService()
|
||||
return _yfinance_service
|
||||
@ -1,120 +0,0 @@
|
||||
"""
|
||||
高级数据技能
|
||||
封装Tushare Pro高级数据接口(需要5000+积分)
|
||||
"""
|
||||
from typing import Dict, Any
|
||||
from app.skills.base import BaseSkill, SkillParameter
|
||||
from app.services.tushare_advanced_service import tushare_advanced_service
|
||||
from app.utils.logger import logger
|
||||
|
||||
|
||||
class AdvancedDataSkill(BaseSkill):
|
||||
"""高级数据技能"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.name = "advanced_data"
|
||||
self.description = "获取高级财务数据、估值数据、资金流向等(Tushare Pro 5000+积分)"
|
||||
self.parameters = [
|
||||
SkillParameter(
|
||||
name="stock_code",
|
||||
type="string",
|
||||
description="股票代码",
|
||||
required=True
|
||||
),
|
||||
SkillParameter(
|
||||
name="data_type",
|
||||
type="string",
|
||||
description="数据类型:financial(财务)、valuation(估值)、money_flow(资金流向)、all(全部)",
|
||||
required=False,
|
||||
default="all"
|
||||
)
|
||||
]
|
||||
|
||||
async def execute(self, **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
执行高级数据查询
|
||||
|
||||
支持的数据类型:
|
||||
- financial: 财务数据(利润表、资产负债表、财务指标)
|
||||
- valuation: 估值数据(PE、PB、PS、市值等)
|
||||
- money_flow: 资金流向
|
||||
- margin: 融资融券
|
||||
- block_trade: 大宗交易
|
||||
- all: 全部数据
|
||||
"""
|
||||
stock_code = kwargs.get('stock_code')
|
||||
data_type = kwargs.get('data_type', 'all') # 默认获取所有数据
|
||||
|
||||
if not stock_code:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "缺少股票代码"
|
||||
}
|
||||
|
||||
try:
|
||||
result = {
|
||||
"success": True,
|
||||
"data": {}
|
||||
}
|
||||
|
||||
# 财务数据(利润表、资产负债表、财务指标)
|
||||
if data_type in ['financial', 'all']:
|
||||
financial_data = {}
|
||||
|
||||
# 获取财务指标(最重要)
|
||||
indicators = tushare_advanced_service.get_financial_indicators(stock_code)
|
||||
if indicators:
|
||||
financial_data['indicators'] = indicators
|
||||
|
||||
# 获取利润表(最近一期)
|
||||
income = tushare_advanced_service.get_income_statement(stock_code)
|
||||
if income and income.get('data'):
|
||||
financial_data['income'] = income['data'][0] if income['data'] else None
|
||||
|
||||
# 获取资产负债表(最近一期)
|
||||
balance = tushare_advanced_service.get_balance_sheet(stock_code)
|
||||
if balance and balance.get('data'):
|
||||
financial_data['balance'] = balance['data'][0] if balance['data'] else None
|
||||
|
||||
if financial_data:
|
||||
result['data']['financial'] = financial_data
|
||||
|
||||
# 估值数据
|
||||
if data_type in ['valuation', 'all']:
|
||||
valuation = tushare_advanced_service.get_daily_basic(stock_code)
|
||||
if valuation:
|
||||
result['data']['valuation'] = valuation.get('data')
|
||||
|
||||
# 资金流向
|
||||
if data_type in ['money_flow', 'all']:
|
||||
money_flow = tushare_advanced_service.get_money_flow(stock_code)
|
||||
if money_flow:
|
||||
# 只取最近5天的数据
|
||||
result['data']['money_flow'] = money_flow.get('data', [])[:5]
|
||||
|
||||
# 融资融券
|
||||
if data_type in ['margin', 'all']:
|
||||
margin = tushare_advanced_service.get_margin_detail(stock_code)
|
||||
if margin:
|
||||
# 只取最近5天的数据
|
||||
result['data']['margin'] = margin.get('data', [])[:5]
|
||||
|
||||
# 大宗交易
|
||||
if data_type in ['block_trade', 'all']:
|
||||
block_trade = tushare_advanced_service.get_block_trade(stock_code)
|
||||
if block_trade:
|
||||
# 只取最近10条
|
||||
result['data']['block_trade'] = block_trade[:10]
|
||||
|
||||
# 注意:重大公告功能已移除(需要特殊权限)
|
||||
|
||||
logger.info(f"获取高级数据成功: {stock_code}, 类型: {data_type}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取高级数据失败: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
@ -1,78 +0,0 @@
|
||||
"""
|
||||
技能基类
|
||||
所有技能插件的基类
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SkillParameter(BaseModel):
|
||||
"""技能参数定义"""
|
||||
name: str = Field(..., description="参数名称")
|
||||
type: str = Field(..., description="参数类型")
|
||||
description: str = Field(..., description="参数描述")
|
||||
required: bool = Field(True, description="是否必需")
|
||||
default: Optional[Any] = Field(None, description="默认值")
|
||||
|
||||
|
||||
class BaseSkill(ABC):
|
||||
"""技能基类"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化技能"""
|
||||
self.name: str = ""
|
||||
self.description: str = ""
|
||||
self.parameters: list[SkillParameter] = []
|
||||
self.enabled: bool = True
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
执行技能
|
||||
|
||||
Args:
|
||||
**kwargs: 技能参数
|
||||
|
||||
Returns:
|
||||
执行结果字典
|
||||
"""
|
||||
pass
|
||||
|
||||
def validate_params(self, **kwargs) -> tuple[bool, Optional[str]]:
|
||||
"""
|
||||
验证参数
|
||||
|
||||
Args:
|
||||
**kwargs: 参数字典
|
||||
|
||||
Returns:
|
||||
(是否有效, 错误信息)
|
||||
"""
|
||||
for param in self.parameters:
|
||||
if param.required and param.name not in kwargs:
|
||||
return False, f"缺少必需参数: {param.name}"
|
||||
|
||||
return True, None
|
||||
|
||||
def get_info(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取技能信息
|
||||
|
||||
Returns:
|
||||
技能信息字典
|
||||
"""
|
||||
return {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": [p.dict() for p in self.parameters],
|
||||
"enabled": self.enabled
|
||||
}
|
||||
|
||||
def enable(self):
|
||||
"""启用技能"""
|
||||
self.enabled = True
|
||||
|
||||
def disable(self):
|
||||
"""禁用技能"""
|
||||
self.enabled = False
|
||||
@ -1,180 +0,0 @@
|
||||
"""
|
||||
Brave搜索技能
|
||||
提供网页搜索、新闻搜索等能力
|
||||
"""
|
||||
import aiohttp
|
||||
from typing import Dict, Any, List, Optional
|
||||
from app.skills.base import BaseSkill, SkillParameter
|
||||
from app.utils.logger import logger
|
||||
|
||||
|
||||
class BraveSearchSkill(BaseSkill):
|
||||
"""Brave搜索技能"""
|
||||
|
||||
def __init__(self, api_key: str = "BSAcaROCUmCAI0XsQWzxooWT74LFFX_"):
|
||||
super().__init__()
|
||||
self.name = "brave_search"
|
||||
self.description = "使用Brave搜索引擎搜索网页、新闻、公司公告等实时信息"
|
||||
self.api_key = api_key
|
||||
self.base_url = "https://api.search.brave.com/res/v1"
|
||||
|
||||
self.parameters = [
|
||||
SkillParameter(
|
||||
name="query",
|
||||
type="string",
|
||||
description="搜索关键词",
|
||||
required=True
|
||||
),
|
||||
SkillParameter(
|
||||
name="search_type",
|
||||
type="string",
|
||||
description="搜索类型:web(网页)、news(新闻)",
|
||||
required=False,
|
||||
default="web"
|
||||
),
|
||||
SkillParameter(
|
||||
name="count",
|
||||
type="integer",
|
||||
description="返回结果数量(1-20)",
|
||||
required=False,
|
||||
default=5
|
||||
),
|
||||
SkillParameter(
|
||||
name="freshness",
|
||||
type="string",
|
||||
description="时效性:pd(过去一天)、pw(过去一周)、pm(过去一月)、py(过去一年)",
|
||||
required=False,
|
||||
default=None
|
||||
)
|
||||
]
|
||||
|
||||
async def execute(self, **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
执行Brave搜索
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
search_type: 搜索类型(web/news)
|
||||
count: 结果数量
|
||||
freshness: 时效性过滤
|
||||
|
||||
Returns:
|
||||
搜索结果
|
||||
"""
|
||||
query = kwargs.get("query")
|
||||
search_type = kwargs.get("search_type", "web")
|
||||
count = kwargs.get("count", 5)
|
||||
freshness = kwargs.get("freshness")
|
||||
|
||||
logger.info(f"Brave搜索: {query}, 类型: {search_type}")
|
||||
|
||||
try:
|
||||
if search_type == "news":
|
||||
results = await self._search_news(query, count, freshness)
|
||||
else:
|
||||
results = await self._search_web(query, count, freshness)
|
||||
|
||||
return {
|
||||
"query": query,
|
||||
"search_type": search_type,
|
||||
"results": results,
|
||||
"count": len(results)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Brave搜索失败: {e}")
|
||||
return {
|
||||
"error": f"搜索失败: {str(e)}"
|
||||
}
|
||||
|
||||
async def _search_web(
|
||||
self,
|
||||
query: str,
|
||||
count: int = 5,
|
||||
freshness: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""网页搜索"""
|
||||
url = f"{self.base_url}/web/search"
|
||||
|
||||
params = {
|
||||
"q": query,
|
||||
"count": min(count, 20),
|
||||
"text_decorations": False,
|
||||
"search_lang": "zh-hans"
|
||||
}
|
||||
|
||||
if freshness:
|
||||
params["freshness"] = freshness
|
||||
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
"X-Subscription-Token": self.api_key
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, params=params, headers=headers) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise Exception(f"API请求失败: {response.status}, {error_text}")
|
||||
|
||||
data = await response.json()
|
||||
|
||||
# 解析结果
|
||||
results = []
|
||||
web_results = data.get("web", {}).get("results", [])
|
||||
|
||||
for item in web_results[:count]:
|
||||
results.append({
|
||||
"title": item.get("title", ""),
|
||||
"url": item.get("url", ""),
|
||||
"description": item.get("description", ""),
|
||||
"published": item.get("age", "")
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
async def _search_news(
|
||||
self,
|
||||
query: str,
|
||||
count: int = 5,
|
||||
freshness: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""新闻搜索"""
|
||||
url = f"{self.base_url}/news/search"
|
||||
|
||||
params = {
|
||||
"q": query,
|
||||
"count": min(count, 20),
|
||||
"search_lang": "zh-hans"
|
||||
}
|
||||
|
||||
if freshness:
|
||||
params["freshness"] = freshness
|
||||
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
"X-Subscription-Token": self.api_key
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, params=params, headers=headers) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise Exception(f"API请求失败: {response.status}, {error_text}")
|
||||
|
||||
data = await response.json()
|
||||
|
||||
# 解析结果
|
||||
results = []
|
||||
news_results = data.get("results", [])
|
||||
|
||||
for item in news_results[:count]:
|
||||
results.append({
|
||||
"title": item.get("title", ""),
|
||||
"url": item.get("url", ""),
|
||||
"description": item.get("description", ""),
|
||||
"published": item.get("age", ""),
|
||||
"source": item.get("meta_url", {}).get("hostname", "")
|
||||
})
|
||||
|
||||
return results
|
||||
@ -1,61 +0,0 @@
|
||||
"""
|
||||
基本面分析技能
|
||||
提供股票基本信息查询
|
||||
"""
|
||||
from typing import Dict, Any
|
||||
from app.skills.base import BaseSkill, SkillParameter
|
||||
from app.services.tushare_service import tushare_service
|
||||
from app.services.cache_service import cache_service
|
||||
from app.utils.logger import logger
|
||||
|
||||
|
||||
class FundamentalSkill(BaseSkill):
|
||||
"""基本面分析技能"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.name = "fundamental"
|
||||
self.description = "查询股票基本面信息(公司概况、行业、上市日期等)"
|
||||
self.parameters = [
|
||||
SkillParameter(
|
||||
name="stock_code",
|
||||
type="string",
|
||||
description="股票代码",
|
||||
required=True
|
||||
)
|
||||
]
|
||||
|
||||
async def execute(self, **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
执行基本面查询
|
||||
|
||||
Args:
|
||||
stock_code: 股票代码
|
||||
|
||||
Returns:
|
||||
基本面信息
|
||||
"""
|
||||
stock_code = kwargs.get("stock_code")
|
||||
|
||||
logger.info(f"查询基本面信息: {stock_code}")
|
||||
|
||||
# 尝试从缓存获取
|
||||
cache_key = f"fundamental:{stock_code}"
|
||||
cached_data = cache_service.get(cache_key)
|
||||
|
||||
if cached_data:
|
||||
logger.info(f"从缓存获取基本面信息: {stock_code}")
|
||||
return cached_data
|
||||
|
||||
# 从Tushare获取
|
||||
basic_info = tushare_service.get_stock_basic(stock_code)
|
||||
|
||||
if not basic_info:
|
||||
return {
|
||||
"error": f"未找到股票基本信息: {stock_code}"
|
||||
}
|
||||
|
||||
# 缓存1天
|
||||
cache_service.set(cache_key, basic_info, ttl=86400)
|
||||
|
||||
return basic_info
|
||||
@ -1,140 +0,0 @@
|
||||
"""
|
||||
行情查询技能
|
||||
提供股票实时行情和K线数据查询
|
||||
"""
|
||||
from typing import Dict, Any
|
||||
from app.skills.base import BaseSkill, SkillParameter
|
||||
from app.services.tushare_service import tushare_service
|
||||
from app.services.cache_service import cache_service
|
||||
from app.utils.logger import logger
|
||||
|
||||
|
||||
class MarketDataSkill(BaseSkill):
|
||||
"""行情查询技能"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.name = "market_data"
|
||||
self.description = "查询股票实时行情和历史K线数据"
|
||||
self.parameters = [
|
||||
SkillParameter(
|
||||
name="stock_code",
|
||||
type="string",
|
||||
description="股票代码(如600000、000001)",
|
||||
required=True
|
||||
),
|
||||
SkillParameter(
|
||||
name="data_type",
|
||||
type="string",
|
||||
description="数据类型:quote(实时行情)或kline(K线数据)",
|
||||
required=False,
|
||||
default="quote"
|
||||
),
|
||||
SkillParameter(
|
||||
name="start_date",
|
||||
type="string",
|
||||
description="开始日期(YYYYMMDD格式,仅K线数据需要)",
|
||||
required=False
|
||||
),
|
||||
SkillParameter(
|
||||
name="end_date",
|
||||
type="string",
|
||||
description="结束日期(YYYYMMDD格式,仅K线数据需要)",
|
||||
required=False
|
||||
),
|
||||
SkillParameter(
|
||||
name="period",
|
||||
type="string",
|
||||
description="K线周期:D(日线)、W(周线)、M(月线)",
|
||||
required=False,
|
||||
default="D"
|
||||
)
|
||||
]
|
||||
|
||||
async def execute(self, **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
执行行情查询
|
||||
|
||||
Args:
|
||||
stock_code: 股票代码
|
||||
data_type: 数据类型(quote/kline)
|
||||
start_date: 开始日期(可选)
|
||||
end_date: 结束日期(可选)
|
||||
period: K线周期(可选)
|
||||
|
||||
Returns:
|
||||
查询结果
|
||||
"""
|
||||
stock_code = kwargs.get("stock_code")
|
||||
data_type = kwargs.get("data_type", "quote")
|
||||
|
||||
logger.info(f"查询行情数据: {stock_code}, 类型: {data_type}")
|
||||
|
||||
if data_type == "quote":
|
||||
return await self._get_quote(stock_code)
|
||||
elif data_type == "kline":
|
||||
start_date = kwargs.get("start_date")
|
||||
end_date = kwargs.get("end_date")
|
||||
period = kwargs.get("period", "D")
|
||||
return await self._get_kline(stock_code, start_date, end_date, period)
|
||||
else:
|
||||
return {
|
||||
"error": f"不支持的数据类型: {data_type}"
|
||||
}
|
||||
|
||||
async def _get_quote(self, stock_code: str) -> Dict[str, Any]:
|
||||
"""获取实时行情"""
|
||||
# 尝试从缓存获取
|
||||
cache_key = f"quote:{stock_code}"
|
||||
cached_data = cache_service.get(cache_key)
|
||||
|
||||
if cached_data:
|
||||
logger.info(f"从缓存获取行情: {stock_code}")
|
||||
return cached_data
|
||||
|
||||
# 从Tushare获取
|
||||
quote_data = tushare_service.get_realtime_quote(stock_code)
|
||||
|
||||
if not quote_data:
|
||||
return {
|
||||
"error": f"未找到股票数据: {stock_code}"
|
||||
}
|
||||
|
||||
# 缓存30秒
|
||||
cache_service.set(cache_key, quote_data, ttl=30)
|
||||
|
||||
return quote_data
|
||||
|
||||
async def _get_kline(
|
||||
self,
|
||||
stock_code: str,
|
||||
start_date: str = None,
|
||||
end_date: str = None,
|
||||
period: str = "D"
|
||||
) -> Dict[str, Any]:
|
||||
"""获取K线数据"""
|
||||
# 尝试从缓存获取
|
||||
cache_key = f"kline:{stock_code}:{start_date}:{end_date}:{period}"
|
||||
cached_data = cache_service.get(cache_key)
|
||||
|
||||
if cached_data:
|
||||
logger.info(f"从缓存获取K线: {stock_code}")
|
||||
return {"kline_data": cached_data}
|
||||
|
||||
# 从Tushare获取
|
||||
kline_data = tushare_service.get_kline_data(
|
||||
stock_code,
|
||||
start_date,
|
||||
end_date,
|
||||
period
|
||||
)
|
||||
|
||||
if not kline_data:
|
||||
return {
|
||||
"error": f"未找到K线数据: {stock_code}"
|
||||
}
|
||||
|
||||
# 缓存1小时
|
||||
cache_service.set(cache_key, kline_data, ttl=3600)
|
||||
|
||||
return {"kline_data": kline_data}
|
||||
@ -1,202 +0,0 @@
|
||||
"""
|
||||
技术分析技能
|
||||
提供技术指标计算和分析
|
||||
"""
|
||||
import pandas as pd
|
||||
from typing import Dict, Any
|
||||
from app.skills.base import BaseSkill, SkillParameter
|
||||
from app.services.tushare_service import tushare_service
|
||||
from app.utils.indicators import (
|
||||
calculate_ma, calculate_macd, calculate_rsi,
|
||||
calculate_kdj, calculate_boll
|
||||
)
|
||||
from app.utils.logger import logger
|
||||
|
||||
|
||||
class TechnicalAnalysisSkill(BaseSkill):
|
||||
"""技术分析技能"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.name = "technical_analysis"
|
||||
self.description = "计算股票技术指标(MA、MACD、RSI、KDJ、BOLL等)"
|
||||
self.parameters = [
|
||||
SkillParameter(
|
||||
name="stock_code",
|
||||
type="string",
|
||||
description="股票代码",
|
||||
required=True
|
||||
),
|
||||
SkillParameter(
|
||||
name="indicators",
|
||||
type="array",
|
||||
description="要计算的指标列表(ma、macd、rsi、kdj、boll)",
|
||||
required=False,
|
||||
default=["ma", "macd"]
|
||||
),
|
||||
SkillParameter(
|
||||
name="period",
|
||||
type="integer",
|
||||
description="数据周期(天数)",
|
||||
required=False,
|
||||
default=60
|
||||
)
|
||||
]
|
||||
|
||||
async def execute(self, **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
执行技术分析
|
||||
|
||||
Args:
|
||||
stock_code: 股票代码
|
||||
indicators: 指标列表
|
||||
period: 数据周期
|
||||
|
||||
Returns:
|
||||
技术指标结果
|
||||
"""
|
||||
stock_code = kwargs.get("stock_code")
|
||||
indicators = kwargs.get("indicators", ["ma", "macd"])
|
||||
period = kwargs.get("period", 60)
|
||||
|
||||
logger.info(f"技术分析: {stock_code}, 指标: {indicators}")
|
||||
|
||||
# 获取K线数据
|
||||
kline_data = tushare_service.get_kline_data(stock_code)
|
||||
|
||||
if not kline_data:
|
||||
return {
|
||||
"error": f"未找到K线数据: {stock_code}"
|
||||
}
|
||||
|
||||
# 转换为DataFrame
|
||||
df = pd.DataFrame(kline_data)
|
||||
|
||||
# 计算指标
|
||||
result = {
|
||||
"stock_code": stock_code,
|
||||
"indicators": {}
|
||||
}
|
||||
|
||||
try:
|
||||
if "ma" in indicators:
|
||||
result["indicators"]["ma"] = self._calculate_ma(df)
|
||||
|
||||
if "macd" in indicators:
|
||||
result["indicators"]["macd"] = self._calculate_macd(df)
|
||||
|
||||
if "rsi" in indicators:
|
||||
result["indicators"]["rsi"] = self._calculate_rsi(df)
|
||||
|
||||
if "kdj" in indicators:
|
||||
result["indicators"]["kdj"] = self._calculate_kdj(df)
|
||||
|
||||
if "boll" in indicators:
|
||||
result["indicators"]["boll"] = self._calculate_boll(df)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"技术指标计算失败: {e}")
|
||||
return {
|
||||
"error": f"技术指标计算失败: {str(e)}"
|
||||
}
|
||||
|
||||
def _calculate_ma(self, df: pd.DataFrame) -> Dict[str, Any]:
|
||||
"""计算均线"""
|
||||
close = df['close']
|
||||
|
||||
ma5 = calculate_ma(close, 5)
|
||||
ma10 = calculate_ma(close, 10)
|
||||
ma20 = calculate_ma(close, 20)
|
||||
ma60 = calculate_ma(close, 60)
|
||||
|
||||
# 获取最新值
|
||||
latest_ma5 = ma5.iloc[-1] if not ma5.empty else None
|
||||
latest_ma10 = ma10.iloc[-1] if not ma10.empty else None
|
||||
latest_ma20 = ma20.iloc[-1] if not ma20.empty else None
|
||||
latest_ma60 = ma60.iloc[-1] if not ma60.empty else None
|
||||
|
||||
return {
|
||||
"ma5": round(latest_ma5, 2) if latest_ma5 else None,
|
||||
"ma10": round(latest_ma10, 2) if latest_ma10 else None,
|
||||
"ma20": round(latest_ma20, 2) if latest_ma20 else None,
|
||||
"ma60": round(latest_ma60, 2) if latest_ma60 else None,
|
||||
"description": "移动平均线"
|
||||
}
|
||||
|
||||
def _calculate_macd(self, df: pd.DataFrame) -> Dict[str, Any]:
|
||||
"""计算MACD"""
|
||||
close = df['close']
|
||||
|
||||
dif, dea, macd = calculate_macd(close)
|
||||
|
||||
# 获取最新值
|
||||
latest_dif = dif.iloc[-1] if not dif.empty else None
|
||||
latest_dea = dea.iloc[-1] if not dea.empty else None
|
||||
latest_macd = macd.iloc[-1] if not macd.empty else None
|
||||
|
||||
return {
|
||||
"dif": round(latest_dif, 2) if latest_dif else None,
|
||||
"dea": round(latest_dea, 2) if latest_dea else None,
|
||||
"macd": round(latest_macd, 2) if latest_macd else None,
|
||||
"description": "MACD指标"
|
||||
}
|
||||
|
||||
def _calculate_rsi(self, df: pd.DataFrame) -> Dict[str, Any]:
|
||||
"""计算RSI"""
|
||||
close = df['close']
|
||||
|
||||
rsi6 = calculate_rsi(close, 6)
|
||||
rsi12 = calculate_rsi(close, 12)
|
||||
rsi24 = calculate_rsi(close, 24)
|
||||
|
||||
# 获取最新值
|
||||
latest_rsi6 = rsi6.iloc[-1] if not rsi6.empty else None
|
||||
latest_rsi12 = rsi12.iloc[-1] if not rsi12.empty else None
|
||||
latest_rsi24 = rsi24.iloc[-1] if not rsi24.empty else None
|
||||
|
||||
return {
|
||||
"rsi6": round(latest_rsi6, 2) if latest_rsi6 else None,
|
||||
"rsi12": round(latest_rsi12, 2) if latest_rsi12 else None,
|
||||
"rsi24": round(latest_rsi24, 2) if latest_rsi24 else None,
|
||||
"description": "相对强弱指标"
|
||||
}
|
||||
|
||||
def _calculate_kdj(self, df: pd.DataFrame) -> Dict[str, Any]:
|
||||
"""计算KDJ"""
|
||||
high = df['high']
|
||||
low = df['low']
|
||||
close = df['close']
|
||||
|
||||
k, d, j = calculate_kdj(high, low, close)
|
||||
|
||||
# 获取最新值
|
||||
latest_k = k.iloc[-1] if not k.empty else None
|
||||
latest_d = d.iloc[-1] if not d.empty else None
|
||||
latest_j = j.iloc[-1] if not j.empty else None
|
||||
|
||||
return {
|
||||
"k": round(latest_k, 2) if latest_k else None,
|
||||
"d": round(latest_d, 2) if latest_d else None,
|
||||
"j": round(latest_j, 2) if latest_j else None,
|
||||
"description": "KDJ指标"
|
||||
}
|
||||
|
||||
def _calculate_boll(self, df: pd.DataFrame) -> Dict[str, Any]:
|
||||
"""计算布林带"""
|
||||
close = df['close']
|
||||
|
||||
upper, middle, lower = calculate_boll(close)
|
||||
|
||||
# 获取最新值
|
||||
latest_upper = upper.iloc[-1] if not upper.empty else None
|
||||
latest_middle = middle.iloc[-1] if not middle.empty else None
|
||||
latest_lower = lower.iloc[-1] if not lower.empty else None
|
||||
|
||||
return {
|
||||
"upper": round(latest_upper, 2) if latest_upper else None,
|
||||
"middle": round(latest_middle, 2) if latest_middle else None,
|
||||
"lower": round(latest_lower, 2) if latest_lower else None,
|
||||
"description": "布林带"
|
||||
}
|
||||
@ -1,118 +0,0 @@
|
||||
"""
|
||||
美股/港股分析技能
|
||||
"""
|
||||
from typing import Dict, Any
|
||||
from app.skills.base import BaseSkill, SkillParameter
|
||||
from app.services.us_stock_service import us_stock_service
|
||||
from app.utils.logger import logger
|
||||
|
||||
|
||||
class USStockSkill(BaseSkill):
|
||||
"""美股/港股分析技能(使用 yfinance)"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.name = "us_stock_analysis"
|
||||
self.description = "分析美股(如 AAPL, TSLA)和港股(如 0700.HK, 9988.HK),获取实时行情、技术指标、基本面数据"
|
||||
self.parameters = [
|
||||
SkillParameter(
|
||||
name="symbol",
|
||||
type="string",
|
||||
description="股票代码(美股如 AAPL, TSLA;港股如 0700.HK, 9988.HK)",
|
||||
required=True
|
||||
),
|
||||
SkillParameter(
|
||||
name="analysis_type",
|
||||
type="string",
|
||||
description="分析类型:basic(基本信息)、technical(技术分析)、fundamental(基本面)、comprehensive(综合分析)",
|
||||
required=False,
|
||||
default="comprehensive"
|
||||
)
|
||||
]
|
||||
|
||||
async def execute(self, **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
执行美股分析
|
||||
|
||||
Args:
|
||||
symbol: 美股代码
|
||||
analysis_type: 分析类型
|
||||
|
||||
Returns:
|
||||
分析结果字典
|
||||
"""
|
||||
try:
|
||||
symbol = kwargs.get("symbol", "").upper()
|
||||
analysis_type = kwargs.get("analysis_type", "comprehensive")
|
||||
|
||||
if not symbol:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "请提供美股代码"
|
||||
}
|
||||
|
||||
logger.info(f"开始分析股票: {symbol}, 类型: {analysis_type}")
|
||||
|
||||
if analysis_type == "basic":
|
||||
# 基本信息
|
||||
info = us_stock_service.get_stock_info(symbol)
|
||||
if not info:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"未找到股票 {symbol}"
|
||||
}
|
||||
return {
|
||||
"success": True,
|
||||
"data": info
|
||||
}
|
||||
|
||||
elif analysis_type == "technical":
|
||||
# 技术分析
|
||||
hist = us_stock_service.get_historical_data(symbol, period="6mo")
|
||||
if hist is None or hist.empty:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "无法获取历史数据"
|
||||
}
|
||||
|
||||
technical = us_stock_service.calculate_technical_indicators(hist)
|
||||
latest = hist.iloc[-1]
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"symbol": symbol,
|
||||
"current_price": float(latest['Close']),
|
||||
"volume": int(latest['Volume']),
|
||||
"technical_indicators": technical
|
||||
}
|
||||
}
|
||||
|
||||
elif analysis_type == "fundamental":
|
||||
# 基本面分析
|
||||
financial = us_stock_service.get_financial_data(symbol)
|
||||
if not financial:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "无法获取财务数据"
|
||||
}
|
||||
return {
|
||||
"success": True,
|
||||
"data": financial
|
||||
}
|
||||
|
||||
else:
|
||||
# 综合分析(默认)
|
||||
result = us_stock_service.get_comprehensive_analysis(symbol)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"美股分析失败: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
|
||||
# 创建全局实例
|
||||
us_stock_skill = USStockSkill()
|
||||
@ -1,118 +0,0 @@
|
||||
"""
|
||||
数据可视化技能
|
||||
生成图表配置数据
|
||||
"""
|
||||
from typing import Dict, Any, List
|
||||
from app.skills.base import BaseSkill, SkillParameter
|
||||
from app.services.tushare_service import tushare_service
|
||||
from app.utils.logger import logger
|
||||
|
||||
|
||||
class VisualizationSkill(BaseSkill):
|
||||
"""数据可视化技能"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.name = "visualization"
|
||||
self.description = "生成K线图和技术指标图表配置"
|
||||
self.parameters = [
|
||||
SkillParameter(
|
||||
name="stock_code",
|
||||
type="string",
|
||||
description="股票代码",
|
||||
required=True
|
||||
),
|
||||
SkillParameter(
|
||||
name="chart_type",
|
||||
type="string",
|
||||
description="图表类型:candlestick(K线图)",
|
||||
required=False,
|
||||
default="candlestick"
|
||||
),
|
||||
SkillParameter(
|
||||
name="period",
|
||||
type="integer",
|
||||
description="数据周期(天数)",
|
||||
required=False,
|
||||
default=60
|
||||
)
|
||||
]
|
||||
|
||||
async def execute(self, **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
生成图表配置
|
||||
|
||||
Args:
|
||||
stock_code: 股票代码
|
||||
chart_type: 图表类型
|
||||
period: 数据周期
|
||||
|
||||
Returns:
|
||||
图表配置数据
|
||||
"""
|
||||
stock_code = kwargs.get("stock_code")
|
||||
chart_type = kwargs.get("chart_type", "candlestick")
|
||||
period = kwargs.get("period", 60)
|
||||
|
||||
logger.info(f"生成图表配置: {stock_code}, 类型: {chart_type}")
|
||||
|
||||
# 获取K线数据
|
||||
kline_data = tushare_service.get_kline_data(stock_code)
|
||||
|
||||
if not kline_data:
|
||||
return {
|
||||
"error": f"未找到K线数据: {stock_code}"
|
||||
}
|
||||
|
||||
# 限制数据量
|
||||
if len(kline_data) > period:
|
||||
kline_data = kline_data[-period:]
|
||||
|
||||
if chart_type == "candlestick":
|
||||
return self._generate_candlestick_config(kline_data)
|
||||
else:
|
||||
return {
|
||||
"error": f"不支持的图表类型: {chart_type}"
|
||||
}
|
||||
|
||||
def _generate_candlestick_config(self, kline_data: List[dict]) -> Dict[str, Any]:
|
||||
"""
|
||||
生成K线图配置(Lightweight Charts格式)
|
||||
|
||||
Args:
|
||||
kline_data: K线数据列表
|
||||
|
||||
Returns:
|
||||
图表配置
|
||||
"""
|
||||
# 转换为Lightweight Charts格式
|
||||
candlestick_data = []
|
||||
volume_data = []
|
||||
|
||||
for item in kline_data:
|
||||
# 转换日期格式 YYYYMMDD -> YYYY-MM-DD
|
||||
date_str = item['trade_date']
|
||||
formatted_date = f"{date_str[:4]}-{date_str[4:6]}-{date_str[6:8]}"
|
||||
|
||||
# K线数据
|
||||
candlestick_data.append({
|
||||
"time": formatted_date,
|
||||
"open": item['open'],
|
||||
"high": item['high'],
|
||||
"low": item['low'],
|
||||
"close": item['close']
|
||||
})
|
||||
|
||||
# 成交量数据
|
||||
volume_data.append({
|
||||
"time": formatted_date,
|
||||
"value": item['vol'],
|
||||
"color": "rgba(0, 150, 136, 0.8)" if item['close'] >= item['open'] else "rgba(255, 82, 82, 0.8)"
|
||||
})
|
||||
|
||||
return {
|
||||
"chart_type": "candlestick",
|
||||
"candlestick_data": candlestick_data,
|
||||
"volume_data": volume_data,
|
||||
"stock_code": kline_data[0]['ts_code'] if kline_data else None
|
||||
}
|
||||
@ -1,6 +0,0 @@
|
||||
"""
|
||||
美股交易智能体包
|
||||
"""
|
||||
from app.stock_agent.stock_agent import StockAgent, get_stock_agent
|
||||
|
||||
__all__ = ['StockAgent', 'get_stock_agent']
|
||||
@ -1,569 +0,0 @@
|
||||
"""
|
||||
股票分析工具模块
|
||||
|
||||
提供各种辅助计算功能:
|
||||
- ATR计算
|
||||
- 多周期共振分析
|
||||
- 基本面评分
|
||||
- 支撑阻力位识别
|
||||
"""
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from typing import Dict, Any, List, Tuple, Optional
|
||||
from app.utils.logger import logger
|
||||
|
||||
|
||||
class StockAnalysisTools:
|
||||
"""股票分析工具类"""
|
||||
|
||||
@staticmethod
|
||||
def calculate_atr(df: pd.DataFrame, period: int = 14) -> float:
|
||||
"""
|
||||
计算ATR (真实波动幅度)
|
||||
|
||||
Args:
|
||||
df: 包含 high, low, close 的数据
|
||||
period: ATR周期
|
||||
|
||||
Returns:
|
||||
ATR值
|
||||
"""
|
||||
if df is None or len(df) < period + 1:
|
||||
return 0.0
|
||||
|
||||
try:
|
||||
high = df['high'].values
|
||||
low = df['low'].values
|
||||
close = df['close'].values
|
||||
|
||||
tr_list = []
|
||||
for i in range(1, len(df)):
|
||||
tr1 = high[i] - low[i]
|
||||
tr2 = abs(high[i] - close[i-1])
|
||||
tr3 = abs(low[i] - close[i-1])
|
||||
tr = max(tr1, tr2, tr3)
|
||||
tr_list.append(tr)
|
||||
|
||||
if len(tr_list) < period:
|
||||
return 0.0
|
||||
|
||||
atr = pd.Series(tr_list).rolling(window=period).mean().iloc[-1]
|
||||
return float(atr) if not np.isnan(atr) else 0.0
|
||||
except Exception as e:
|
||||
logger.warning(f"ATR计算失败: {e}")
|
||||
return 0.0
|
||||
|
||||
@staticmethod
|
||||
def calculate_volume_ratio(df: pd.DataFrame, period: int = 20) -> float:
|
||||
"""
|
||||
计算量比(当前成交量 / 过去N周期平均成交量)
|
||||
|
||||
Args:
|
||||
df: 包含 volume 的数据
|
||||
period: 均量周期
|
||||
|
||||
Returns:
|
||||
量比值
|
||||
"""
|
||||
if df is None or len(df) < period + 1:
|
||||
return 1.0
|
||||
|
||||
try:
|
||||
current_vol = df['volume'].iloc[-1]
|
||||
avg_vol = df['volume'].iloc[-period-1:-1].mean()
|
||||
|
||||
if avg_vol > 0:
|
||||
return float(current_vol / avg_vol)
|
||||
return 1.0
|
||||
except Exception as e:
|
||||
logger.warning(f"量比计算失败: {e}")
|
||||
return 1.0
|
||||
|
||||
@staticmethod
|
||||
def detect_trend(df: pd.DataFrame) -> Dict[str, Any]:
|
||||
"""
|
||||
检测趋势方向和强度
|
||||
|
||||
使用EMA系统判断趋势
|
||||
|
||||
Returns:
|
||||
{
|
||||
'direction': 'uptrend'/'downtrend'/'neutral',
|
||||
'strength': 'strong'/'medium'/'weak',
|
||||
'ema_alignment': 'bullish'/'bearish'/'mixed',
|
||||
'price_vs_ema20': float, # 百分比
|
||||
'signals': List[str]
|
||||
}
|
||||
"""
|
||||
if df is None or len(df) < 200:
|
||||
return {
|
||||
'direction': 'neutral',
|
||||
'strength': 'weak',
|
||||
'ema_alignment': 'mixed',
|
||||
'price_vs_ema20': 0.0,
|
||||
'signals': []
|
||||
}
|
||||
|
||||
try:
|
||||
# 计算EMA
|
||||
close = df['close']
|
||||
ema20 = close.ewm(span=20, adjust=False).mean().iloc[-1]
|
||||
ema50 = close.ewm(span=50, adjust=False).mean().iloc[-1]
|
||||
ema200 = close.ewm(span=200, adjust=False).mean().iloc[-1]
|
||||
|
||||
current_price = close.iloc[-1]
|
||||
|
||||
# EMA排列判断
|
||||
if ema20 > ema50 > ema200:
|
||||
ema_alignment = 'bullish'
|
||||
direction = 'uptrend'
|
||||
elif ema20 < ema50 < ema200:
|
||||
ema_alignment = 'bearish'
|
||||
direction = 'downtrend'
|
||||
else:
|
||||
ema_alignment = 'mixed'
|
||||
direction = 'neutral'
|
||||
|
||||
# 价格与EMA20的关系
|
||||
price_vs_ema20 = ((current_price - ema20) / ema20 * 100) if ema20 > 0 else 0
|
||||
|
||||
# 趋势强度判断
|
||||
strength = 'weak'
|
||||
signals = []
|
||||
|
||||
if ema_alignment == 'bullish':
|
||||
if price_vs_ema20 > 1:
|
||||
strength = 'strong'
|
||||
signals.append("强势上涨:价格站稳 EMA20 之上")
|
||||
elif price_vs_ema20 > 0:
|
||||
strength = 'medium'
|
||||
signals.append("上涨趋势:价格在 EMA20 附近")
|
||||
else:
|
||||
strength = 'weak'
|
||||
signals.append("上涨趋势减弱:价格跌破 EMA20")
|
||||
|
||||
elif ema_alignment == 'bearish':
|
||||
if price_vs_ema20 < -1:
|
||||
strength = 'strong'
|
||||
signals.append("强势下跌:价格跌破 EMA20 之下")
|
||||
elif price_vs_ema20 < 0:
|
||||
strength = 'medium'
|
||||
signals.append("下跌趋势:价格在 EMA20 附近")
|
||||
else:
|
||||
strength = 'weak'
|
||||
signals.append("下跌趋势减弱:价格站上 EMA20")
|
||||
else:
|
||||
signals.append("震荡市:EMA 交织")
|
||||
|
||||
return {
|
||||
'direction': direction,
|
||||
'strength': strength,
|
||||
'ema_alignment': ema_alignment,
|
||||
'price_vs_ema20': round(price_vs_ema20, 2),
|
||||
'signals': signals
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning(f"趋势检测失败: {e}")
|
||||
return {
|
||||
'direction': 'neutral',
|
||||
'strength': 'weak',
|
||||
'ema_alignment': 'mixed',
|
||||
'price_vs_ema20': 0.0,
|
||||
'signals': []
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def calculate_multi_timeframe_resonance(data: Dict[str, pd.DataFrame]) -> Dict[str, Any]:
|
||||
"""
|
||||
计算多周期共振强度
|
||||
|
||||
Args:
|
||||
data: 包含多个周期的数据 {'1w': df, '1d': df, '1h': df}
|
||||
|
||||
Returns:
|
||||
{
|
||||
'score': 0-100,
|
||||
'level': 'strong'/'medium'/'weak',
|
||||
'weekly_trend': str,
|
||||
'daily_trend': str,
|
||||
'hourly_trend': str,
|
||||
'resonance_type': str,
|
||||
'analysis': str
|
||||
}
|
||||
"""
|
||||
trends = {}
|
||||
|
||||
# 获取各周期趋势
|
||||
for tf_name in ['1w', '1d', '1h']:
|
||||
df = data.get(tf_name)
|
||||
if df is not None and len(df) > 0:
|
||||
trend_info = StockAnalysisTools.detect_trend(df)
|
||||
trends[tf_name] = trend_info['direction']
|
||||
else:
|
||||
trends[tf_name] = 'neutral'
|
||||
|
||||
weekly_trend = trends.get('1w', 'neutral')
|
||||
daily_trend = trends.get('1d', 'neutral')
|
||||
hourly_trend = trends.get('1h', 'neutral')
|
||||
|
||||
# 计算共振得分
|
||||
score = 0
|
||||
analysis_parts = []
|
||||
|
||||
# 大周期共振(周线+日线)
|
||||
if weekly_trend == daily_trend and weekly_trend != 'neutral':
|
||||
score += 40
|
||||
analysis_parts.append(f"✅ 大周期共振({weekly_trend})")
|
||||
elif weekly_trend != 'neutral' and daily_trend != 'neutral':
|
||||
analysis_parts.append(f"⚠️ 大周期分歧(周线{weekly_trend} vs 日线{daily_trend})")
|
||||
else:
|
||||
analysis_parts.append(f"➖ 大周期不明确")
|
||||
|
||||
# 主周期共振(日线+1h)
|
||||
if daily_trend == hourly_trend and daily_trend != 'neutral':
|
||||
score += 35
|
||||
analysis_parts.append(f"✅ 主周期共振({daily_trend})")
|
||||
elif daily_trend != 'neutral' and hourly_trend != 'neutral':
|
||||
analysis_parts.append(f"⚠️ 主周期分歧(日线{daily_trend} vs 1h{hourly_trend})")
|
||||
else:
|
||||
analysis_parts.append(f"➖ 主周期不明确")
|
||||
|
||||
# 全周期共振
|
||||
if weekly_trend == daily_trend == hourly_trend and weekly_trend != 'neutral':
|
||||
score += 25
|
||||
analysis_parts.append(f"🔥 全周期共振({weekly_trend})")
|
||||
|
||||
# 确定共振等级
|
||||
if score >= 70:
|
||||
level = 'strong'
|
||||
resonance_type = 'all_timeframe_aligned' if score >= 90 else 'strong'
|
||||
elif score >= 40:
|
||||
level = 'medium'
|
||||
resonance_type = 'large_timeframe_aligned'
|
||||
elif score >= 20:
|
||||
level = 'weak'
|
||||
resonance_type = 'partial_alignment'
|
||||
else:
|
||||
level = 'weak'
|
||||
resonance_type = 'no_resonance'
|
||||
|
||||
return {
|
||||
'score': score,
|
||||
'level': level,
|
||||
'weekly_trend': weekly_trend,
|
||||
'daily_trend': daily_trend,
|
||||
'hourly_trend': hourly_trend,
|
||||
'resonance_type': resonance_type,
|
||||
'analysis': ' | '.join(analysis_parts)
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def calculate_fundamental_score(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
计算基本面评分(0-100)
|
||||
|
||||
Args:
|
||||
data: 基本面数据
|
||||
|
||||
Returns:
|
||||
{
|
||||
'score': 0-100,
|
||||
'grade': 'A'/'B'/'C'/'D',
|
||||
'valuation': {'score': 0-25, 'grade': 'A'/'B'/'C'/'D'},
|
||||
'profitability': {'score': 0-35, 'grade': 'A'/'B'/'C'/'D'},
|
||||
'growth': {'score': 0-25, 'grade': 'A'/'B'/'C'/'D'},
|
||||
'financial_health': {'score': 0-15, 'grade': 'A'/'B'/'C'/'D'},
|
||||
'summary': str
|
||||
}
|
||||
"""
|
||||
if not data:
|
||||
return {
|
||||
'score': 0,
|
||||
'grade': 'D',
|
||||
'summary': '无基本面数据'
|
||||
}
|
||||
|
||||
score = 0
|
||||
breakdown = {}
|
||||
|
||||
# 估值评分 (25分)
|
||||
val_score = 0
|
||||
pe = data.get('pe_ratio', 0)
|
||||
pb = data.get('pb_ratio', 0)
|
||||
peg = data.get('peg_ratio', 0)
|
||||
|
||||
if 0 < pe < 15:
|
||||
val_score += 10
|
||||
elif 15 <= pe <= 25:
|
||||
val_score += 7
|
||||
elif pe > 40:
|
||||
val_score += 0
|
||||
else:
|
||||
val_score += 5
|
||||
|
||||
if 0 < pb < 1:
|
||||
val_score += 10
|
||||
elif 1 <= pb <= 3:
|
||||
val_score += 7
|
||||
elif pb > 5:
|
||||
val_score += 0
|
||||
else:
|
||||
val_score += 5
|
||||
|
||||
if peg and 0 < peg < 1:
|
||||
val_score += 5
|
||||
elif peg and 1 <= peg <= 2:
|
||||
val_score += 3
|
||||
elif peg and peg > 2:
|
||||
val_score += 0
|
||||
else:
|
||||
val_score += 2
|
||||
|
||||
breakdown['valuation'] = {
|
||||
'score': val_score,
|
||||
'grade': 'A' if val_score >= 20 else 'B' if val_score >= 15 else 'C' if val_score >= 10 else 'D'
|
||||
}
|
||||
score += val_score
|
||||
|
||||
# 盈利能力评分 (35分)
|
||||
prof_score = 0
|
||||
roe = data.get('roe', 0)
|
||||
net_margin = data.get('profit_margin', 0)
|
||||
|
||||
if roe > 20:
|
||||
prof_score += 20
|
||||
elif roe > 15:
|
||||
prof_score += 15
|
||||
elif roe > 10:
|
||||
prof_score += 10
|
||||
elif roe > 0:
|
||||
prof_score += 5
|
||||
|
||||
if net_margin > 20:
|
||||
prof_score += 15
|
||||
elif net_margin > 10:
|
||||
prof_score += 10
|
||||
elif net_margin > 5:
|
||||
prof_score += 5
|
||||
elif net_margin > 0:
|
||||
prof_score += 2
|
||||
|
||||
breakdown['profitability'] = {
|
||||
'score': prof_score,
|
||||
'grade': 'A' if prof_score >= 30 else 'B' if prof_score >= 20 else 'C' if prof_score >= 10 else 'D'
|
||||
}
|
||||
score += prof_score
|
||||
|
||||
# 成长性评分 (25分)
|
||||
growth_score = 0
|
||||
revenue_growth = data.get('revenue_growth', 0)
|
||||
earnings_growth = data.get('earnings_growth', 0)
|
||||
|
||||
if revenue_growth > 30:
|
||||
growth_score += 13
|
||||
elif revenue_growth > 20:
|
||||
growth_score += 10
|
||||
elif revenue_growth > 10:
|
||||
growth_score += 5
|
||||
elif revenue_growth > 0:
|
||||
growth_score += 2
|
||||
|
||||
if earnings_growth > 30:
|
||||
growth_score += 12
|
||||
elif earnings_growth > 20:
|
||||
growth_score += 8
|
||||
elif earnings_growth > 10:
|
||||
growth_score += 5
|
||||
elif earnings_growth > 0:
|
||||
growth_score += 2
|
||||
elif earnings_growth < 0:
|
||||
growth_score -= 5 # 负增长扣分
|
||||
|
||||
breakdown['growth'] = {
|
||||
'score': max(0, growth_score),
|
||||
'grade': 'A' if growth_score >= 20 else 'B' if growth_score >= 15 else 'C' if growth_score >= 10 else 'D'
|
||||
}
|
||||
score += max(0, growth_score)
|
||||
|
||||
# 财务健康评分 (15分)
|
||||
health_score = 0
|
||||
debt_ratio = data.get('debt_to_equity', 0)
|
||||
current_ratio = data.get('current_ratio', 0)
|
||||
|
||||
if debt_ratio < 1:
|
||||
health_score += 8
|
||||
elif debt_ratio < 2:
|
||||
health_score += 5
|
||||
elif debt_ratio < 3:
|
||||
health_score += 2
|
||||
else:
|
||||
health_score += 0
|
||||
|
||||
if current_ratio > 2:
|
||||
health_score += 7
|
||||
elif current_ratio > 1.5:
|
||||
health_score += 5
|
||||
elif current_ratio > 1:
|
||||
health_score += 2
|
||||
else:
|
||||
health_score += 0
|
||||
|
||||
breakdown['financial_health'] = {
|
||||
'score': health_score,
|
||||
'grade': 'A' if health_score >= 12 else 'B' if health_score >= 8 else 'C' if health_score >= 5 else 'D'
|
||||
}
|
||||
score += health_score
|
||||
|
||||
# 确定总等级
|
||||
grade = 'A' if score >= 80 else 'B' if score >= 60 else 'C' if score >= 40 else 'D'
|
||||
|
||||
# 生成摘要
|
||||
summary_parts = []
|
||||
if breakdown['valuation']['grade'] == 'A':
|
||||
summary_parts.append("估值低估")
|
||||
elif breakdown['valuation']['grade'] == 'D':
|
||||
summary_parts.append("估值高估")
|
||||
|
||||
if breakdown['profitability']['grade'] == 'A':
|
||||
summary_parts.append("盈利优秀")
|
||||
elif breakdown['profitability']['grade'] == 'D':
|
||||
summary_parts.append("盈利较差")
|
||||
|
||||
if breakdown['growth']['grade'] == 'A':
|
||||
summary_parts.append("高成长")
|
||||
elif breakdown['growth']['grade'] == 'D':
|
||||
summary_parts.append("低成长")
|
||||
|
||||
if breakdown['financial_health']['grade'] == 'A':
|
||||
summary_parts.append("财务健康")
|
||||
elif breakdown['financial_health']['grade'] == 'D':
|
||||
summary_parts.append("财务风险")
|
||||
|
||||
return {
|
||||
'score': score,
|
||||
'grade': grade,
|
||||
'breakdown': breakdown,
|
||||
'summary': ' | '.join(summary_parts) if summary_parts else '一般'
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def identify_key_levels(df: pd.DataFrame, lookback: int = 60) -> Dict[str, List[float]]:
|
||||
"""
|
||||
识别关键支撑位和阻力位
|
||||
|
||||
Args:
|
||||
df: K线数据
|
||||
lookback: 回看周期
|
||||
|
||||
Returns:
|
||||
{
|
||||
'support': [支撑位1, 支撑位2, ...],
|
||||
'resistance': [阻力位1, 阻力位2, ...]
|
||||
}
|
||||
"""
|
||||
if df is None or len(df) < lookback:
|
||||
return {'support': [], 'resistance': []}
|
||||
|
||||
try:
|
||||
recent = df.tail(lookback)
|
||||
highs = recent['high'].values
|
||||
lows = recent['low'].values
|
||||
|
||||
# 找局部高点和低点
|
||||
from scipy.signal import argrelextrema
|
||||
from numpy import array
|
||||
|
||||
high_indices = argrelextrema(array(highs), np.greater, order=5)
|
||||
low_indices = argrelextrema(array(lows), np.less, order=5)
|
||||
|
||||
resistance_levels = sorted([highs[i] for i in high_indices], reverse=True)[:3]
|
||||
support_levels = sorted([lows[i] for i in low_indices])[:3]
|
||||
|
||||
return {
|
||||
'resistance': resistance_levels,
|
||||
'support': support_levels
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning(f"关键位识别失败: {e}")
|
||||
return {'support': [], 'resistance': []}
|
||||
|
||||
@staticmethod
|
||||
def calculate_stop_loss_take_profit(
|
||||
entry_price: float,
|
||||
atr: float,
|
||||
direction: str,
|
||||
key_levels: Dict[str, List[float]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
计算止损止盈价格
|
||||
|
||||
Args:
|
||||
entry_price: 入场价格
|
||||
atr: ATR值
|
||||
direction: 'long'/'short'
|
||||
key_levels: 支撑阻力位
|
||||
|
||||
Returns:
|
||||
{
|
||||
'stop_loss': float,
|
||||
'take_profit': float,
|
||||
'method': str,
|
||||
'risk_reward_ratio': float
|
||||
}
|
||||
"""
|
||||
if direction == 'long':
|
||||
# 做多
|
||||
# 止损:入场价 - 1.5×ATR,或设在前支撑位下方
|
||||
sl_atr = entry_price - 1.5 * atr
|
||||
sl_support = None
|
||||
|
||||
if key_levels and key_levels.get('support'):
|
||||
closest_support = [s for s in key_levels['support'] if s < entry_price]
|
||||
if closest_support:
|
||||
sl_support = min(closest_support) * 0.995 # 支撑位下方0.5%
|
||||
|
||||
# 选择更保守的止损(更远的)
|
||||
if sl_support and sl_support < sl_atr:
|
||||
stop_loss = sl_atr
|
||||
else:
|
||||
stop_loss = sl_atr
|
||||
|
||||
# 止盈:入场价 + 3×ATR(风险收益比1:2)
|
||||
take_profit = entry_price + 3 * atr
|
||||
|
||||
else:
|
||||
# 做空
|
||||
# 止损:入场价 + 1.5×ATR,或设在前阻力位上方
|
||||
sl_atr = entry_price + 1.5 * atr
|
||||
sl_resistance = None
|
||||
|
||||
if key_levels and key_levels.get('resistance'):
|
||||
closest_resistance = [r for r in key_levels['resistance'] if r > entry_price]
|
||||
if closest_resistance:
|
||||
sl_resistance = min(closest_resistance) * 1.005 # 阻力位上方0.5%
|
||||
|
||||
# 选择更保守的止损
|
||||
if sl_resistance and sl_resistance > sl_atr:
|
||||
stop_loss = sl_atr
|
||||
else:
|
||||
stop_loss = sl_atr
|
||||
|
||||
# 止盈:入场价 - 3×ATR
|
||||
take_profit = entry_price - 3 * atr
|
||||
|
||||
# 计算风险收益比
|
||||
if direction == 'long':
|
||||
risk = entry_price - stop_loss
|
||||
reward = take_profit - entry_price
|
||||
else:
|
||||
risk = stop_loss - entry_price
|
||||
reward = entry_price - take_profit
|
||||
|
||||
risk_reward_ratio = reward / risk if risk > 0 else 0
|
||||
|
||||
return {
|
||||
'stop_loss': round(stop_loss, 2),
|
||||
'take_profit': round(take_profit, 2),
|
||||
'method': 'ATR-based (1:2 risk-reward)',
|
||||
'risk_reward_ratio': round(risk_reward_ratio, 2)
|
||||
}
|
||||
@ -1,871 +0,0 @@
|
||||
"""
|
||||
股票市场信号分析器 - 纯市场分析,不包含任何仓位信息
|
||||
|
||||
职责:
|
||||
1. 分析K线、量价、技术指标
|
||||
2. 分析新闻舆情
|
||||
3. 输出纯市场信号(buy/sell/hold + confidence + reasoning)
|
||||
|
||||
不负责:
|
||||
- 仓位管理
|
||||
- 风险控制
|
||||
- 具体下单决策
|
||||
"""
|
||||
import json
|
||||
import re
|
||||
import pandas as pd
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime
|
||||
from app.utils.logger import logger
|
||||
from app.services.llm_service import llm_service
|
||||
|
||||
|
||||
class StockMarketSignalAnalyzer:
|
||||
"""股票市场信号分析器 - 只关注市场,输出客观信号"""
|
||||
|
||||
# 股票市场分析系统提示词
|
||||
MARKET_ANALYSIS_PROMPT = """你是一位专业的股票交易员和技术分析师。你的任务是综合分析**趋势方向、技术面(K线、量价、技术指标)、基本面(估值、盈利、成长)、新闻舆情**,给出交易信号。
|
||||
|
||||
## 核心理念
|
||||
**趋势是你的朋友,顺势交易是稳定盈利的关键。**
|
||||
|
||||
### 🚨 铁律(必须遵守)
|
||||
1. **先判断趋势,再寻找信号** - 趋势方向错误,信号再强也不做
|
||||
2. **顺势交易为主** - 上涨趋势只做多或观望,下跌趋势只做空或观望
|
||||
3. **逆势交易极其谨慎** - 必须有多重反转信号才能考虑逆势
|
||||
4. **单边行情不逆势** - 强趋势中(日线连续3根以上同向K线)严禁逆势开仓
|
||||
|
||||
### 交易目标
|
||||
- **稳健为主**:宁可错过,不做错
|
||||
- **顺势而为**:在大趋势方向上寻找入场点
|
||||
- **严控风险**:每次交易风险不超过本金的2%
|
||||
|
||||
## 零、趋势方向判断(第一步,最重要!)
|
||||
**在分析任何信号之前,先判断当前趋势方向和强度。**
|
||||
|
||||
### 趋势判断标准(使用 EMA 和均线系统)
|
||||
**上升趋势(多头市场)**:
|
||||
- EMA20 > EMA50 > EMA200(短中长期均线多头排列)
|
||||
- 价格站稳在 EMA20 之上
|
||||
- MA5 > MA10 > MA20 > MA50
|
||||
- 最近高点逐步抬高,低点也逐步抬高
|
||||
|
||||
**下降趋势(空头市场)**:
|
||||
- EMA20 < EMA50 < EMA200(短中长期均线空头排列)
|
||||
- 价格持续在 EMA20 之下
|
||||
- MA5 < MA10 < MA20 < MA50
|
||||
- 最近高点逐步降低,低点也逐步降低
|
||||
|
||||
**震荡市(无明确趋势)**:
|
||||
- 均线纠缠,无明显排列
|
||||
- 价格在 EMA20 上下波动
|
||||
- 高点低点无规律
|
||||
- 此时可双向交易,但降低仓位
|
||||
|
||||
### 趋势强度判断
|
||||
- **强趋势**:均线完美排列 + 价格远离均线 + 成交量配合
|
||||
- **中等趋势**:均线有排列 + 价格偶尔回踩均线
|
||||
- **弱趋势/震荡**:均线纠缠 + 价格在均线上下反复
|
||||
|
||||
### 顺势交易规则(必须执行)
|
||||
| 当前趋势 | 允许操作 | 条件 |
|
||||
|---------|---------|------|
|
||||
| **强上升趋势** | ✅ 只做多 | 回调到支撑位、RSI超卖区、金叉 |
|
||||
| **强上升趋势** | ❌ 严禁做空 | 除非出现明确的顶背离+放量反转信号 |
|
||||
| **强下降趋势** | ✅ 只做空 | 反弹到阻力位、RSI超买区、死叉 |
|
||||
| **强下降趋势** | ❌ 严禁做多 | 除非出现明确的底背离+放量反转信号 |
|
||||
| **震荡市** | ✅ 双向交易 | 但降低仓位(轻仓),提高止损要求 |
|
||||
| **趋势不明确** | ⚠️ 观望为主 | 等待趋势明确后再入场 |
|
||||
|
||||
### 逆势交易的条件(极其严格)
|
||||
**只有在满足以下全部条件时,才允许考虑逆势交易:**
|
||||
1. **多重反转信号**:
|
||||
- 明确的背离(顶背离或底背离)
|
||||
- 关键形态反转(头肩顶/底、双顶/底、吞没形态)
|
||||
- 放量突破关键位
|
||||
2. **多周期确认**:周线、日线、1h 三个周期同时出现反转信号
|
||||
3. **风险收益比合理**:潜在盈利至少是风险的3倍以上
|
||||
4. **基本面支持**:重大利好/利空改变趋势
|
||||
5. **降低仓位**:逆势交易必须轻仓(不超过顺势仓位的50%)
|
||||
|
||||
**如果不符合上述条件,即使有买入/卖出信号,也必须选择 hold(观望)。**
|
||||
|
||||
## 数据说明
|
||||
你将获得三个维度的数据:
|
||||
1. **技术面数据**:K线、量价、技术指标(RSI、MACD、布林带、均线)
|
||||
2. **基本面数据**:估值指标(PE、PB)、盈利能力(ROE、净利率)、成长性(营收增长、盈利增长)、财务健康度
|
||||
3. **新闻舆情**:最新相关新闻
|
||||
|
||||
## 分析框架(重要!)
|
||||
### 优先级排序:
|
||||
1. **技术面** = 40%:K线、量价、技术指标决定入场时机
|
||||
2. **基本面** = 35%:估值和盈利能力决定信号的长期有效性
|
||||
3. **新闻** = 25%:重大新闻可能改变短期趋势
|
||||
|
||||
### 综合判断规则:
|
||||
- **技术面强 + 基本面好 + 无负面新闻** → A级信号,高置信度
|
||||
- **技术面强 + 基本面一般** → B级信号,中等置信度
|
||||
- **技术面一般 + 基本面好** → C级信号,低置信度,观望为主
|
||||
- **技术面强 + 基本面差 + 有负面新闻** → D级信号,不推荐交易
|
||||
- **技术面弱** → 无论基本面如何,不推荐交易(观望)
|
||||
|
||||
## 一、量价分析(最重要)
|
||||
量价关系是判断趋势真假的核心:
|
||||
|
||||
### 1. 健康上涨信号
|
||||
- **放量上涨**:价格上涨 + 成交量放大(量比>1.5)= 上涨有效,可追多
|
||||
- **缩量回调**:上涨后回调 + 成交量萎缩(量比<0.7)= 回调健康,可低吸
|
||||
- **温和放量**:量比在1.2-1.5之间,价格稳步上涨 = 最健康的上涨
|
||||
|
||||
### 2. 健康下跌信号
|
||||
- **放量下跌**:价格下跌 + 成交量放大 = 下跌有效,暂不抄底
|
||||
- **缩量阴跌**:下跌 + 成交量萎缩 = 抛压逐渐枯竭,关注反弹
|
||||
- **地量企稳**:极端缩量后价格横盘 = 可能见底
|
||||
|
||||
### 3. 量价背离(重要反转信号)
|
||||
- **顶背离**:价格创新高,但成交量未创新高 → 上涨动能衰竭
|
||||
- **底背离**:价格创新低,但成交量未创新低 → 下跌动能衰竭
|
||||
- **天量见顶**:单日成交量突然放大2-3倍后价格滞涨 → 主力出货
|
||||
- **地量见底**:成交量创阶段新低后价格企稳 → 抛压枯竭
|
||||
|
||||
### 4. 突破确认
|
||||
- **有效突破**:突破关键位 + 放量确认(量比>1.3)= 真突破
|
||||
- **假突破**:突破关键位 + 缩量 = 假突破,可能回落
|
||||
|
||||
## 二、K线形态分析
|
||||
### 反转形态
|
||||
- **锤子线/倒锤子**:下跌趋势中出现,下影线长 = 底部信号
|
||||
- **吞没形态**:大阳吞没前一根阴线 = 看涨;大阴吞没前一根阳线 = 看跌
|
||||
- **十字星**:在高位/低位出现 = 变盘信号
|
||||
- **早晨之星/黄昏之星**:三根K线组合的反转信号
|
||||
|
||||
### 持续形态
|
||||
- **三连阳/三连阴**:趋势延续信号
|
||||
- **旗形整理**:趋势中的健康回调
|
||||
|
||||
## 三、技术指标分析
|
||||
### RSI(相对强弱指标)
|
||||
**RSI 是最重要的超买超卖指标:**
|
||||
- **RSI < 30**:超卖区,关注反弹机会
|
||||
- RSI 从 30 以下回升,交叉上穿 30:买入信号
|
||||
- RSI 底背离(价格新低但 RSI 未创新低):强买入信号
|
||||
- **RSI > 70**:超买区,关注回落风险
|
||||
- RSI 从 70 以上回落,交叉下穿 70:卖出信号
|
||||
- RSI 顶背离(价格新高但 RSI 未创新高):强卖出信号
|
||||
- **RSI 40-60**:震荡区,观望为主
|
||||
|
||||
### MACD
|
||||
- 金叉(DIF 上穿 DEA):做多信号
|
||||
- 死叉(DIF 下穿 DEA):做空信号
|
||||
- 零轴上方金叉:强势做多
|
||||
- 零轴下方死叉:强势做空
|
||||
- MACD 柱状图背离:重要反转信号
|
||||
|
||||
### 布林带
|
||||
- 触及下轨 + 企稳:反弹做多
|
||||
- 触及上轨 + 受阻:回落做空
|
||||
- 布林带收口:即将变盘
|
||||
- 布林带开口:趋势启动
|
||||
|
||||
### 均线系统(重要)
|
||||
**均线系统是趋势判断的核心:**
|
||||
- **多头排列**(MA5 > MA10 > MA20 > MA50):强势上涨趋势,回调做多
|
||||
- **空头排列**(MA5 < MA10 < MA20 < MA50):强势下跌趋势,反弹做空
|
||||
- **EMA 趋势判断**(比 MA 更平滑,更适合判断长期趋势):
|
||||
- **多头排列**(EMA20 > EMA50 > EMA200):长期上涨趋势确立
|
||||
- **空头排列**(EMA20 < EMA50 < EMA200):长期下跌趋势确立
|
||||
- 价格站稳 EMA20 上方:中期上涨趋势
|
||||
- 价格跌破 EMA20:中期转为下跌趋势
|
||||
- EMA50 是长期趋势的生命线
|
||||
- **价格与 MA/EMA 的关系**:
|
||||
- 价格站稳 MA5/MA10 上方:短线上涨
|
||||
- 价格突破 MA20/EMA20:中线转多
|
||||
- 价格跌破 MA20/EMA20:中线转空
|
||||
- MA50/EMA50 是中期趋势的分水岭
|
||||
- **均线金叉死叉**:
|
||||
- MA5 上穿 MA10:短线买入信号
|
||||
- MA5 下穿 MA10:短线卖出信号
|
||||
- EMA20 上穿 EMA50:中线买入信号(重要)
|
||||
- EMA20 下穿 EMA50:中线卖出信号(重要)
|
||||
|
||||
## 四、多周期共振(关键分析框架)
|
||||
**多周期共振是提高信号质量的核心方法:**
|
||||
|
||||
### 周期层级关系
|
||||
- **周线(趋势层)**:决定长期大方向
|
||||
- **日线(主周期)**:主要交易周期
|
||||
- **1h(入场层)**:寻找入场时机
|
||||
|
||||
### 共振判断标准
|
||||
**强共振(A级信号)**:
|
||||
- 所有周期趋势同向(如周线多 + 日线多 + 1h多)
|
||||
- 多周期 RSI 同时超买/超卖后出现背离
|
||||
- 多周期 MA 同时金叉/死叉
|
||||
|
||||
**中等共振(B级信号)**:
|
||||
- 大周期(周线+日线)同向
|
||||
- 主周期(日线)技术指标明确
|
||||
|
||||
**弱共振(C级信号)**:
|
||||
- 只有单一周期信号
|
||||
- 多周期方向不一致
|
||||
|
||||
### 实战策略
|
||||
- **顺势交易**:周线和日线同向时,在 1h 寻找入场点
|
||||
- **逆势谨慎**:只有日线信号但周线反向时,降低置信度
|
||||
- **突破交易**:多周期同时突破关键位,信号最强
|
||||
|
||||
## 五、基本面分析(重要)
|
||||
**基本面是判断信号长期有效性的关键:**
|
||||
|
||||
### 估值指标
|
||||
- **PE(市盈率)**:
|
||||
- PE < 15:低估,安全边际高
|
||||
- PE 15-25:合理估值
|
||||
- PE > 40:高估,风险较大
|
||||
- **PB(市净率)**:
|
||||
- PB < 1:低于净资产,价值投资机会
|
||||
- PB 1-3:合理区间
|
||||
- PB > 5:高估
|
||||
- **PEG(市盈率增长率)**:
|
||||
- PEG < 1:低估,成长性好
|
||||
- PEG 1-2:合理
|
||||
- PEG > 2:高估
|
||||
|
||||
### 盈利能力
|
||||
- **ROE(净资产收益率)**:
|
||||
- ROE > 20%:优秀
|
||||
- ROE 15-20%:良好
|
||||
- ROE < 10%:较差
|
||||
- **净利率**:
|
||||
- 净利率 > 20%:优秀(通常是科技、消费品牌)
|
||||
- 净利率 10-20%:良好
|
||||
- 净利率 < 5%:较低(通常是零售、制造业)
|
||||
|
||||
### 成长性
|
||||
- **营收增长**:
|
||||
- > 30%:高成长
|
||||
- 20-30%:稳健成长
|
||||
- < 10%:低成长
|
||||
- **盈利增长**:
|
||||
- > 30%:高成长
|
||||
- 10-30%:稳健成长
|
||||
- < 0%:负增长,警惕
|
||||
|
||||
### 财务健康
|
||||
- **债务股本比**:
|
||||
- < 1:健康
|
||||
- 1-2:可控
|
||||
- > 3:高风险
|
||||
- **流动比率**:
|
||||
- > 2:健康
|
||||
- 1.5-2:良好
|
||||
- < 1:流动性风险
|
||||
|
||||
### 基本面综合判断
|
||||
- **基本面优秀**(ROE>15%, 营收增长>20%, 财务健康)+ 技术面信号 = 提高置信度
|
||||
- **基本面一般**(ROE 10-15%, 营收增长 10-20%)+ 技术面信号 = 正常置信度
|
||||
- **基本面较差**(ROE<10%, 营收增长<10% 或负增长, 高负债)+ 技术面信号 = 降低置信度
|
||||
- **基本面差**(连续亏损, 高负债, 负增长)= 不建议交易,无论技术面如何
|
||||
|
||||
## 六、新闻舆情分析
|
||||
**新闻会改变短期趋势,需要重点关注:**
|
||||
|
||||
### 正面新闻(提高做多置信度)
|
||||
- 财报超预期
|
||||
- 重大产品发布
|
||||
- 业务扩张/并购
|
||||
- 分析师上调评级
|
||||
- 行业利好政策
|
||||
|
||||
### 负面新闻(提高做空置信度或降低做多置信度)
|
||||
- 财报不及预期
|
||||
- 监管调查/处罚
|
||||
- 管理层变动
|
||||
- 分析师下调评级
|
||||
- 行业监管收紧
|
||||
- 重大安全事故/质量问题
|
||||
|
||||
### 新闻综合判断
|
||||
- **重大正面新闻** + 技术面做多信号 = 提高置信度 10-20%
|
||||
- **重大负面新闻** + 技术面做多信号 = 降低置信度或转为观望
|
||||
- **无重大新闻** = 技术面 + 基本面分析为主
|
||||
|
||||
## 七、入场方式
|
||||
根据市场分析综合判断入场方式:
|
||||
- **market**:现价立即入场
|
||||
- 信号强烈且明确(A级或高置信度B级)
|
||||
- 放量突破关键位,趋势明确
|
||||
- 多周期共振,等待可能错过机会
|
||||
- 市场波动大,等待可能价格变化太快
|
||||
- **limit**:挂单等待入场
|
||||
- 信号强度中等(B级或C级)
|
||||
- 当前价格距离理想入场位有一定距离
|
||||
- 判断市场可能回调到更好位置
|
||||
- 希望获得更优成交价格,愿意承担可能无法成交的风险
|
||||
|
||||
**重要**:
|
||||
- 必须同时输出 `entry_price`(建议入场价)和 `entry_type`(入场方式)
|
||||
- 入场方式由你的市场分析判断,不是简单的价格距离计算
|
||||
|
||||
## 止损止盈计算规则
|
||||
使用ATR(真实波动幅度)动态计算止损止盈:
|
||||
|
||||
**做多**:
|
||||
- 止损 = 入场价 - 1.5 × ATR(14日)
|
||||
- 止盈 = 入场价 + 3 × ATR(14日) (风险收益比 1:2)
|
||||
- 如果附近有明显支撑位,可将止损设在支撑位下方
|
||||
|
||||
**做空**:
|
||||
- 止损 = 入场价 + 1.5 × ATR(14日)
|
||||
- 止盈 = 入场价 - 3 × ATR(14日) (风险收益比 1:2)
|
||||
- 如果附近有明显阻力位,可将止损设在阻力位上方
|
||||
|
||||
## 信号输出条件(严格遵守)
|
||||
**满足以下条件才输出信号(否则 signals 返回空数组)**:
|
||||
|
||||
### 做多信号条件:
|
||||
1. ✅ 趋势:EMA20 > EMA50 > EMA200(或处于回调中的上升趋势)
|
||||
2. ✅ 价格:站稳 EMA20 之上或回调到 EMA20 附近
|
||||
3. ✅ 量价:放量上涨或缩量回调后重新放量
|
||||
4. ✅ 共振:多周期共振得分 >= 40(至少大周期一致)
|
||||
5. ✅ 基本面:评分 >= C级(40分以上)
|
||||
6. ✅ 新闻:无重大负面消息(财报暴雷、监管处罚等)
|
||||
|
||||
### 做空信号条件:
|
||||
1. ✅ 趋势:EMA20 < EMA50 < EMA200(或处于反弹中的下降趋势)
|
||||
2. ✅ 价格:跌破 EMA20 或反弹到 EMA20 附近
|
||||
3. ✅ 量价:放量下跌或缩量反弹后继续下跌
|
||||
4. ✅ 共振:多周期共振得分 >= 40(至少大周期一致)
|
||||
5. ✅ 基本面:评分 >= C级(40分以上)
|
||||
6. ✅ 新闻:无重大正面消息
|
||||
|
||||
### 禁止输出信号的情况:
|
||||
- ❌ 趋势不明确(EMA 交织,震荡市)
|
||||
- ❌ 多周期方向完全相反
|
||||
- ❌ 基本面差(D级,<40分)
|
||||
- ❌ 有重大负面新闻(做多时)或重大正面新闻(做空时)
|
||||
- ❌ 技术指标矛盾(如RSI超买但MACD死叉)
|
||||
|
||||
## 输出格式
|
||||
请严格按照以下 JSON 格式输出:
|
||||
|
||||
```json
|
||||
{
|
||||
"trend_direction": "uptrend/downtrend/neutral",
|
||||
"trend_strength": "strong/medium/weak",
|
||||
"analysis_summary": "简要描述当前市场状态(50字以内)",
|
||||
"volume_analysis": "量价分析结论(30字以内)",
|
||||
"news_sentiment": "positive/negative/neutral",
|
||||
"news_impact": "新闻对市场的影响分析(30字以内)",
|
||||
"signals": [
|
||||
{
|
||||
"type": "short_term/medium_term/long_term",
|
||||
"action": "buy/sell",
|
||||
"entry_type": "market/limit",
|
||||
"confidence": 0-100,
|
||||
"grade": "A/B/C/D",
|
||||
"entry_price": 150.50,
|
||||
"stop_loss": 148.00,
|
||||
"take_profit": 155.00,
|
||||
"reasoning": "详细的入场理由(必须包含趋势判断和量价分析)",
|
||||
"key_factors": ["关键因素1", "关键因素2"]
|
||||
}
|
||||
],
|
||||
"key_levels": {
|
||||
"support": [148, 145],
|
||||
"resistance": [152, 155]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 重要说明
|
||||
- **所有价格必须是纯数字**,不要加 $ 符号、逗号或其他格式
|
||||
- `entry_price`、`stop_loss`、`take_profit` 必须是数字类型,不要是字符串
|
||||
- `key_levels` 中的支撑位和阻力位也必须是数字数组
|
||||
|
||||
## 信号等级与置信度(综合技术面 + 基本面 + 新闻)
|
||||
- **A级**(80-100):量价配合 + 多指标共振 + 多周期确认 + 基本面优秀 + 无负面新闻
|
||||
- **B级**(60-79):量价配合 + 主要指标确认 + 基本面良好/一般
|
||||
- **C级**(40-59):技术面有机会但基本面一般,或基本面好但技术面不够明确
|
||||
- **D级**(<40):量价背离或信号矛盾,或基本面差,或有重大负面新闻
|
||||
|
||||
## 注意事项
|
||||
1. **只在有明确的做多或做空机会时才输出信号**(action 为 buy 或 sell)
|
||||
2. 如果市场不明朗,没有明确交易机会,**不要输出任何信号**(signals 为空数组 [])
|
||||
3. 信号强度(confidence)要合理,不要随意给高分
|
||||
4. 60-70分:一般信号,可轻仓试探
|
||||
5. 75-85分:较强信号,可正常仓位
|
||||
6. 90+分:强信号,但也要控制风险
|
||||
7. 不要输出 action 为 "wait" 的信号,如果没有交易机会就不输出
|
||||
8. **必须综合考虑技术面、基本面、新闻三个维度**,不能只看技术面
|
||||
|
||||
记住:你只负责分析市场,输出客观的交易信号,不需要考虑仓位管理和风险控制!
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
async def analyze(self, symbol: str, data: Dict[str, Any],
|
||||
symbols: List[str] = None,
|
||||
fundamental_data: Dict[str, Any] = None,
|
||||
news_data: List[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
分析市场并生成信号
|
||||
|
||||
Args:
|
||||
symbol: 股票代码
|
||||
data: 多周期K线数据
|
||||
symbols: 所有监控的股票(用于市场对比)
|
||||
fundamental_data: 基本面数据
|
||||
news_data: 新闻数据列表
|
||||
|
||||
Returns:
|
||||
市场信号字典
|
||||
"""
|
||||
try:
|
||||
# 1. 准备市场数据(技术面 + 基本面 + 新闻)
|
||||
market_context = self._prepare_market_context(
|
||||
symbol, data, symbols,
|
||||
fundamental_data, news_data
|
||||
)
|
||||
|
||||
# 2. 构建 LLM 提示词
|
||||
prompt = self._build_analysis_prompt(symbol, market_context)
|
||||
|
||||
# 3. 调用 LLM 分析
|
||||
messages = [
|
||||
{"role": "system", "content": self.MARKET_ANALYSIS_PROMPT},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
response = await llm_service.achat(messages)
|
||||
|
||||
# 4. 解析结果
|
||||
result = self._parse_llm_response(response, symbol)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"市场信号分析失败: {e}")
|
||||
import traceback
|
||||
logger.debug(traceback.format_exc())
|
||||
return self._get_empty_signal(symbol)
|
||||
|
||||
def _prepare_market_context(self, symbol: str, data: Dict,
|
||||
symbols: List[str] = None,
|
||||
fundamental_data: Dict[str, Any] = None,
|
||||
news_data: List[Dict[str, Any]] = None) -> str:
|
||||
"""准备市场上下文信息(结构化版本,减少token消耗)"""
|
||||
from app.stock_agent.analysis_tools import StockAnalysisTools
|
||||
|
||||
context_parts = []
|
||||
|
||||
# 当前价格和24h变化
|
||||
df_1d = data.get('1d')
|
||||
if df_1d is None or len(df_1d) == 0:
|
||||
return ""
|
||||
|
||||
current_price = float(df_1d.iloc[-1]['close'])
|
||||
price_change_24h = self._calculate_price_change_24h(df_1d)
|
||||
context_parts.append(f"**当前价格**: ${current_price:,.2f} ({price_change_24h})")
|
||||
|
||||
# ===== 趋势分析(预计算) =====
|
||||
context_parts.append(f"\n**## 趋势分析**")
|
||||
trend_info = StockAnalysisTools.detect_trend(df_1d)
|
||||
context_parts.append(f"- 方向: {trend_info['direction']} ({trend_info['strength']})")
|
||||
context_parts.append(f"- EMA排列: {trend_info['ema_alignment']}")
|
||||
context_parts.append(f"- 价格相对EMA20: {trend_info['price_vs_ema20']:+.2f}%")
|
||||
|
||||
# 多周期共振(预计算)
|
||||
context_parts.append(f"\n**## 多周期共振**")
|
||||
resonance = StockAnalysisTools.calculate_multi_timeframe_resonance(data)
|
||||
context_parts.append(f"- 共振等级: {resonance['level'].upper()} (得分: {resonance['score']}/100)")
|
||||
context_parts.append(f"- 共振分析: {resonance['analysis']}")
|
||||
context_parts.append(f"- 周线: {resonance['weekly_trend']} | 日线: {resonance['daily_trend']} | 1h: {resonance['hourly_trend']}")
|
||||
|
||||
# ===== 技术指标(只显示日线关键指标) =====
|
||||
context_parts.append(f"\n**## 技术指标(日线)**")
|
||||
latest = df_1d.iloc[-1]
|
||||
|
||||
# RSI
|
||||
if 'rsi' in df_1d.columns:
|
||||
rsi = df_1d['rsi'].iloc[-1]
|
||||
context_parts.append(f"- RSI(14): {rsi:.1f} {'(超买⚠️)' if rsi > 70 else '(超卖💡)' if rsi < 30 else '(正常)'}")
|
||||
|
||||
# MACD
|
||||
if 'macd' in df_1d.columns:
|
||||
macd = df_1d['macd'].iloc[-1]
|
||||
signal = df_1d['macd_signal'].iloc[-1]
|
||||
context_parts.append(f"- MACD: {macd:.4f} (信号: {signal:.4f})")
|
||||
|
||||
# 布林带
|
||||
if 'bb_upper' in df_1d.columns:
|
||||
bb_upper = df_1d['bb_upper'].iloc[-1]
|
||||
bb_lower = df_1d['bb_lower'].iloc[-1]
|
||||
bb_position = (current_price - bb_lower) / (bb_upper - bb_lower) * 100 if bb_upper != bb_lower else 50
|
||||
context_parts.append(f"- 布林带: [{bb_lower:.2f}, {bb_upper:.2f}] (位置: {bb_position:.0f}%)")
|
||||
|
||||
# 量价分析
|
||||
volume_ratio = StockAnalysisTools.calculate_volume_ratio(df_1d, 20)
|
||||
context_parts.append(f"- 量比: {volume_ratio:.2f}x {'放量📊' if volume_ratio > 1.5 else '缩量📉' if volume_ratio < 0.7 else '平量'}")
|
||||
|
||||
# ATR(用于止损止盈)
|
||||
atr = StockAnalysisTools.calculate_atr(df_1d, 14)
|
||||
atr_pct = (atr / current_price * 100) if current_price > 0 else 0
|
||||
context_parts.append(f"- ATR(14): ${atr:.2f} ({atr_pct:.2f}%)")
|
||||
|
||||
# ===== 关键支撑阻力位 =====
|
||||
context_parts.append(f"\n**## 关键位**")
|
||||
key_levels = StockAnalysisTools.identify_key_levels(df_1d, lookback=60)
|
||||
if key_levels['resistance']:
|
||||
context_parts.append(f"- 阻力位: ${', '.join(f'{r:.2f}' for r in key_levels['resistance'][:3])}")
|
||||
else:
|
||||
context_parts.append(f"- 阻力位: 未明确")
|
||||
if key_levels['support']:
|
||||
context_parts.append(f"- 支撑位: ${', '.join(f'{s:.2f}' for s in key_levels['support'][:3])}")
|
||||
else:
|
||||
context_parts.append(f"- 支撑位: 未明确")
|
||||
|
||||
# ===== 基本面分析(评分) =====
|
||||
if fundamental_data:
|
||||
context_parts.append(f"\n**## 基本面分析**")
|
||||
fund_score = StockAnalysisTools.calculate_fundamental_score(fundamental_data)
|
||||
context_parts.append(f"- 综合评分: {fund_score['grade']}级 ({fund_score['score']}/100)")
|
||||
context_parts.append(f"- 估值: {fund_score['breakdown']['valuation']['grade']}级 | 盈利: {fund_score['breakdown']['profitability']['grade']}级")
|
||||
context_parts.append(f"- 成长: {fund_score['breakdown']['growth']['grade']}级 | 财务: {fund_score['breakdown']['financial_health']['grade']}级")
|
||||
context_parts.append(f"- 摘要: {fund_score['summary']}")
|
||||
|
||||
# ===== 新闻舆情 =====
|
||||
if news_data:
|
||||
context_parts.append(f"\n**## 最新新闻**")
|
||||
context_parts.append(self._format_news_data(news_data))
|
||||
|
||||
return "\n".join(context_parts)
|
||||
|
||||
def _build_analysis_prompt(self, symbol: str, market_context: str) -> str:
|
||||
"""构建分析提示词"""
|
||||
return f"""请分析 {symbol} 的市场情况:
|
||||
|
||||
{market_context}
|
||||
|
||||
请根据以上数据,给出你的市场判断和交易信号。
|
||||
"""
|
||||
|
||||
def _parse_llm_response(self, response: str, symbol: str) -> Dict[str, Any]:
|
||||
"""解析 LLM 响应"""
|
||||
try:
|
||||
# 尝试提取 JSON
|
||||
json_match = re.search(r'```json\s*([\s\S]*?)\s*```', response)
|
||||
if json_match:
|
||||
json_str = json_match.group(1)
|
||||
else:
|
||||
json_match = re.search(r'\{[\s\S]*\}', response)
|
||||
if json_match:
|
||||
json_str = json_match.group(0)
|
||||
else:
|
||||
raise ValueError("无法找到 JSON 响应")
|
||||
|
||||
# 清理 JSON 字符串
|
||||
json_str = self._clean_json_string(json_str)
|
||||
|
||||
result = json.loads(json_str)
|
||||
|
||||
# 向后兼容:确保新字段存在
|
||||
if 'trend_direction' not in result:
|
||||
result['trend_direction'] = 'neutral'
|
||||
if 'trend_strength' not in result:
|
||||
result['trend_strength'] = 'weak'
|
||||
if 'news_sentiment' not in result:
|
||||
result['news_sentiment'] = 'neutral'
|
||||
if 'news_impact' not in result:
|
||||
result['news_impact'] = ''
|
||||
|
||||
# 清理价格字段 - 转换为 float
|
||||
result = self._clean_price_fields(result)
|
||||
|
||||
# 添加元数据
|
||||
result['symbol'] = symbol
|
||||
result['timestamp'] = datetime.now().isoformat()
|
||||
result['raw_response'] = response
|
||||
|
||||
# 兼容处理:确保 signals 中的字段与旧格式一致
|
||||
if 'signals' in result:
|
||||
for sig in result['signals']:
|
||||
if 'type' in sig:
|
||||
if sig['type'] in ['short_term', 'medium_term', 'long_term']:
|
||||
sig['timeframe'] = sig.pop('type')
|
||||
elif sig['type'] in ['buy', 'sell', 'wait']:
|
||||
sig['action'] = sig.pop('type')
|
||||
|
||||
if 'action' not in sig and 'timeframe' in sig:
|
||||
sig['action'] = 'wait'
|
||||
|
||||
if 'grade' not in sig:
|
||||
confidence = sig.get('confidence', 0)
|
||||
if confidence >= 80:
|
||||
sig['grade'] = 'A'
|
||||
elif confidence >= 60:
|
||||
sig['grade'] = 'B'
|
||||
elif confidence >= 40:
|
||||
sig['grade'] = 'C'
|
||||
else:
|
||||
sig['grade'] = 'D'
|
||||
|
||||
logger.info(f"✅ 市场信号分析完成: {symbol}")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"解析 LLM 响应失败: {e}")
|
||||
logger.warning(f"原始响应: {response[:1000]}...")
|
||||
return self._get_empty_signal(symbol)
|
||||
|
||||
def _clean_json_string(self, json_str: str) -> str:
|
||||
"""清理 JSON 字符串,移除可能导致解析错误的内容"""
|
||||
import re
|
||||
json_str = re.sub(r'//.*?(?=\n|$)', '', json_str)
|
||||
json_str = re.sub(r'/\*[\s\S]*?\*/', '', json_str)
|
||||
json_str = re.sub(r',\s*([}\]])', r'\1', json_str)
|
||||
return json_str
|
||||
|
||||
def _clean_price_fields(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""清理价格字段,转换为 float"""
|
||||
def clean_price(price_value):
|
||||
if price_value is None:
|
||||
return None
|
||||
if isinstance(price_value, (int, float)):
|
||||
return float(price_value)
|
||||
if isinstance(price_value, str):
|
||||
cleaned = price_value.replace('$', '').replace(',', '').strip()
|
||||
if cleaned:
|
||||
try:
|
||||
return float(cleaned)
|
||||
except ValueError:
|
||||
return None
|
||||
return None
|
||||
|
||||
if 'key_levels' in data and data['key_levels']:
|
||||
key_levels = data['key_levels']
|
||||
if 'support' in key_levels:
|
||||
data['key_levels']['support'] = [clean_price(s) for s in key_levels['support']]
|
||||
if 'resistance' in key_levels:
|
||||
data['key_levels']['resistance'] = [clean_price(r) for r in key_levels['resistance']]
|
||||
|
||||
if 'signals' in data:
|
||||
for sig in data['signals']:
|
||||
price_fields = ['entry_price', 'stop_loss', 'take_profit']
|
||||
for field in price_fields:
|
||||
if field in sig:
|
||||
sig[field] = clean_price(sig[field])
|
||||
|
||||
# 验证止损止盈价格的合理性
|
||||
entry_price = sig.get('entry_price')
|
||||
stop_loss = sig.get('stop_loss')
|
||||
take_profit = sig.get('take_profit')
|
||||
action = sig.get('action', '')
|
||||
|
||||
if entry_price and entry_price > 0:
|
||||
MAX_REASONABLE_DEVIATION = 0.50 # 50%
|
||||
has_invalid_price = False
|
||||
|
||||
# 检查止损
|
||||
if stop_loss is not None:
|
||||
deviation = abs(stop_loss - entry_price) / entry_price
|
||||
if deviation > MAX_REASONABLE_DEVIATION:
|
||||
logger.warning(f"⚠️ [{data.get('symbol', '')}] 信号止损价格不合理: entry={entry_price}, stop_loss={stop_loss}, 偏离={deviation*100:.1f}%")
|
||||
has_invalid_price = True
|
||||
elif action == 'buy' and stop_loss >= entry_price:
|
||||
logger.warning(f"⚠️ [{data.get('symbol', '')}] 做多止损错误: entry={entry_price}, stop_loss={stop_loss} 应该 < entry")
|
||||
has_invalid_price = True
|
||||
elif action == 'sell' and stop_loss <= entry_price:
|
||||
logger.warning(f"⚠️ [{data.get('symbol', '')}] 做空止损错误: entry={entry_price}, stop_loss={stop_loss} 应该 > entry")
|
||||
has_invalid_price = True
|
||||
|
||||
# 检查止盈
|
||||
if take_profit is not None:
|
||||
deviation = abs(take_profit - entry_price) / entry_price
|
||||
if deviation > MAX_REASONABLE_DEVIATION:
|
||||
logger.warning(f"⚠️ [{data.get('symbol', '')}] 信号止盈价格不合理: entry={entry_price}, take_profit={take_profit}, 偏离={deviation*100:.1f}%")
|
||||
has_invalid_price = True
|
||||
elif action == 'buy' and take_profit <= entry_price:
|
||||
logger.warning(f"⚠️ [{data.get('symbol', '')}] 做多止盈错误: entry={entry_price}, take_profit={take_profit} 应该 > entry")
|
||||
has_invalid_price = True
|
||||
elif action == 'sell' and take_profit >= entry_price:
|
||||
logger.warning(f"⚠️ [{data.get('symbol', '')}] 做空止盈错误: entry={entry_price}, take_profit={take_profit} 应该 < entry")
|
||||
has_invalid_price = True
|
||||
|
||||
# 如果价格不合理,降低等级为 D
|
||||
if has_invalid_price:
|
||||
original_grade = sig.get('grade', 'C')
|
||||
sig['grade'] = 'D'
|
||||
sig['confidence'] = 0
|
||||
# 添加错误说明
|
||||
if 'reasoning' in sig:
|
||||
sig['reasoning'] = f"[价格异常] {sig['reasoning']}"
|
||||
logger.error(f"❌ [{data.get('symbol', '')}] 信号价格异常,等级从 {original_grade} 降为 D,止损止盈已清空")
|
||||
|
||||
# 清空不合理的价格
|
||||
sig['stop_loss'] = None
|
||||
sig['take_profit'] = None
|
||||
|
||||
return data
|
||||
|
||||
def _calculate_price_change_24h(self, df) -> str:
|
||||
"""计算24小时涨跌幅"""
|
||||
try:
|
||||
if df is None or len(df) < 24:
|
||||
return "N/A"
|
||||
|
||||
current_price = float(df['close'].iloc[-1])
|
||||
price_24h_ago = float(df['close'].iloc[-24])
|
||||
change = ((current_price - price_24h_ago) / price_24h_ago) * 100
|
||||
|
||||
sign = "+" if change >= 0 else ""
|
||||
return f"{sign}{change:.2f}%"
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"计算24h涨跌失败: {e}")
|
||||
return "N/A"
|
||||
|
||||
def _analyze_volatility(self, data: Dict[str, pd.DataFrame]) -> str:
|
||||
"""分析波动率变化(使用日线数据)"""
|
||||
df = data.get('1d')
|
||||
if df is None or len(df) < 24 or 'atr' not in df.columns:
|
||||
return ""
|
||||
|
||||
lines = []
|
||||
|
||||
recent_atr = df['atr'].iloc[-6:].mean()
|
||||
older_atr = df['atr'].iloc[-12:-6].mean()
|
||||
|
||||
if pd.isna(recent_atr) or pd.isna(older_atr) or older_atr == 0:
|
||||
return ""
|
||||
|
||||
atr_change = (recent_atr - older_atr) / older_atr * 100
|
||||
|
||||
current_atr = float(df['atr'].iloc[-1])
|
||||
current_price = float(df['close'].iloc[-1])
|
||||
atr_percent = current_atr / current_price * 100
|
||||
|
||||
lines.append(f"当前 ATR: ${current_atr:.2f} ({atr_percent:.2f}%)")
|
||||
|
||||
if atr_change > 20:
|
||||
lines.append(f"**波动率扩张**: ATR 上升 {atr_change:.0f}%,趋势可能启动")
|
||||
elif atr_change < -20:
|
||||
lines.append(f"**波动率收缩**: ATR 下降 {abs(atr_change):.0f}%,可能即将突破")
|
||||
else:
|
||||
lines.append(f"波动率稳定: ATR 变化 {atr_change:+.0f}%")
|
||||
|
||||
if 'bb_upper' in df.columns and 'bb_lower' in df.columns:
|
||||
bb_width = (float(df['bb_upper'].iloc[-1]) - float(df['bb_lower'].iloc[-1])) / current_price * 100
|
||||
bb_width_prev = (float(df['bb_upper'].iloc[-6]) - float(df['bb_lower'].iloc[-6])) / float(df['close'].iloc[-6]) * 100
|
||||
|
||||
if bb_width < bb_width_prev * 0.8:
|
||||
lines.append(f"**布林带收口**: 宽度 {bb_width:.1f}%,变盘信号")
|
||||
elif bb_width > bb_width_prev * 1.2:
|
||||
lines.append(f"**布林带开口**: 宽度 {bb_width:.1f}%,趋势延续")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _format_fundamental_data(self, data: Dict[str, Any]) -> str:
|
||||
"""格式化基本面数据"""
|
||||
if not data:
|
||||
return "暂无基本面数据"
|
||||
|
||||
lines = []
|
||||
|
||||
# 基本信息
|
||||
company_name = data.get('company_name', 'N/A')
|
||||
sector = data.get('sector', 'N/A')
|
||||
lines.append(f"公司: {company_name}")
|
||||
lines.append(f"行业: {sector}")
|
||||
|
||||
# 估值指标
|
||||
val = data.get('valuation', {})
|
||||
if val.get('pe_ratio'):
|
||||
pe = val['pe_ratio']
|
||||
pb = val.get('pb_ratio')
|
||||
ps = val.get('ps_ratio')
|
||||
peg = val.get('peg_ratio')
|
||||
pb_str = f"{pb:.2f}" if pb is not None else "N/A"
|
||||
ps_str = f"{ps:.2f}" if ps is not None else "N/A"
|
||||
peg_str = f"{peg:.2f}" if peg is not None else "N/A"
|
||||
lines.append(f"估值: PE={pe:.2f} | PB={pb_str} | PS={ps_str} | PEG={peg_str}")
|
||||
|
||||
# 盈利能力
|
||||
prof = data.get('profitability', {})
|
||||
if prof.get('return_on_equity'):
|
||||
roe = prof['return_on_equity']
|
||||
pm = prof.get('profit_margin')
|
||||
gm = prof.get('gross_margin')
|
||||
pm_str = f"{pm:.1f}" if pm is not None else "N/A"
|
||||
gm_str = f"{gm:.1f}" if gm is not None else "N/A"
|
||||
lines.append(f"盈利: ROE={roe:.2f}% | 净利率={pm_str}% | 毛利率={gm_str}%")
|
||||
|
||||
# 成长性
|
||||
growth = data.get('growth', {})
|
||||
rg = growth.get('revenue_growth')
|
||||
eg = growth.get('earnings_growth')
|
||||
if rg is not None or eg is not None:
|
||||
rg_str = f"{rg:.1f}" if rg is not None else "N/A"
|
||||
eg_str = f"{eg:.1f}" if eg is not None else "N/A"
|
||||
lines.append(f"成长: 营收增长={rg_str}% | 盈利增长={eg_str}%")
|
||||
|
||||
# 财务健康
|
||||
fin = data.get('financial_health', {})
|
||||
if fin.get('debt_to_equity'):
|
||||
de = fin['debt_to_equity']
|
||||
cr = fin.get('current_ratio')
|
||||
cr_str = f"{cr:.2f}" if cr is not None else "N/A"
|
||||
lines.append(f"财务: 债务股本比={de:.2f} | 流动比率={cr_str}")
|
||||
|
||||
# 分析师建议
|
||||
analyst = data.get('analyst', {})
|
||||
if analyst.get('target_price'):
|
||||
tp = analyst['target_price']
|
||||
rec = analyst.get('recommendation', 'N/A')
|
||||
lines.append(f"分析师: 目标价=${tp:.2f} | 评级={rec}")
|
||||
|
||||
# 基本面评分
|
||||
score = data.get('score', {})
|
||||
if score.get('total'):
|
||||
lines.append(f"基本面评分: {score['total']:.0f}/100 ({score.get('rating', 'N/A')}级)")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _format_news_data(self, news_list: List[Dict[str, Any]]) -> str:
|
||||
"""格式化新闻数据"""
|
||||
if not news_list:
|
||||
return "暂无相关新闻"
|
||||
|
||||
lines = []
|
||||
for i, news in enumerate(news_list[:5], 1): # 最多5条
|
||||
title = news.get('title', '')
|
||||
desc = news.get('description', '')[:150] # 限制描述长度
|
||||
source = news.get('source', '')
|
||||
time_str = news.get('time_str', '')
|
||||
|
||||
lines.append(f"{i}. [{time_str}] {title}")
|
||||
if desc:
|
||||
lines.append(f" {desc}")
|
||||
if source:
|
||||
lines.append(f" 来源: {source}")
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _get_empty_signal(self, symbol: str) -> Dict[str, Any]:
|
||||
"""返回空信号"""
|
||||
return {
|
||||
'symbol': symbol,
|
||||
'trend_direction': 'neutral',
|
||||
'trend_strength': 'weak',
|
||||
'analysis_summary': '分析失败',
|
||||
'volume_analysis': '',
|
||||
'news_sentiment': 'neutral',
|
||||
'news_impact': '',
|
||||
'signals': [],
|
||||
'key_levels': {},
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
'error': '信号分析失败'
|
||||
}
|
||||
@ -1,834 +0,0 @@
|
||||
"""
|
||||
美股交易智能体 - 主控制器(新架构版)
|
||||
只进行市场分析和通知,不执行模拟交易
|
||||
"""
|
||||
import asyncio
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime, timedelta
|
||||
import pandas as pd
|
||||
|
||||
from app.utils.logger import logger
|
||||
from app.config import get_settings
|
||||
from app.services.yfinance_service import get_yfinance_service
|
||||
from app.services.feishu_service import get_feishu_stock_service
|
||||
from app.services.telegram_service import get_telegram_service
|
||||
from app.services.signal_database_service import get_signal_db_service
|
||||
from app.services.fundamental_service import get_fundamental_service
|
||||
from app.services.news_service import get_news_service
|
||||
from app.stock_agent.market_signal_analyzer import StockMarketSignalAnalyzer
|
||||
from app.utils.system_status import get_system_monitor, AgentStatus
|
||||
|
||||
|
||||
class StockAgent:
|
||||
"""美股交易信号智能体(LLM 驱动,仅分析通知)"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化智能体"""
|
||||
self.settings = get_settings()
|
||||
self.yfinance = get_yfinance_service()
|
||||
self.feishu = get_feishu_stock_service()
|
||||
self.telegram = get_telegram_service()
|
||||
self.market_analyzer = StockMarketSignalAnalyzer() # 使用新的市场信号分析器
|
||||
self.signal_db = get_signal_db_service() # 信号数据库服务
|
||||
self.fundamental = get_fundamental_service() # 基本面数据服务
|
||||
self.news = get_news_service() # 新闻服务
|
||||
|
||||
# 状态管理
|
||||
self.last_signals: Dict[str, Dict[str, Any]] = {}
|
||||
self.signal_cooldown: Dict[str, datetime] = {}
|
||||
|
||||
# 配置 - 分别读取美股和港股
|
||||
us_symbols = self.settings.stock_symbols_us.split(',') if self.settings.stock_symbols_us else []
|
||||
hk_symbols = self.settings.stock_symbols_hk.split(',') if self.settings.stock_symbols_hk else []
|
||||
self.symbols = us_symbols + hk_symbols
|
||||
|
||||
# 运行状态
|
||||
self.running = False
|
||||
self._event_loop = None
|
||||
self._task = None
|
||||
|
||||
# 注册到系统监控
|
||||
monitor = get_system_monitor()
|
||||
self._monitor_info = monitor.register_agent(
|
||||
agent_id="stock_agent",
|
||||
name="股票智能体",
|
||||
agent_type="stock"
|
||||
)
|
||||
|
||||
# 分类美股和港股数量
|
||||
us_count = len([s for s in self.symbols if not s.endswith('.HK')])
|
||||
hk_count = len([s for s in self.symbols if s.endswith('.HK')])
|
||||
|
||||
monitor.update_config("stock_agent", {
|
||||
"us_symbols": us_symbols,
|
||||
"hk_symbols": hk_symbols,
|
||||
"total_symbols": len(self.symbols),
|
||||
"us_count": us_count,
|
||||
"hk_count": hk_count,
|
||||
"analysis_interval": f"{self.settings.stock_analysis_interval}秒"
|
||||
})
|
||||
|
||||
logger.info(f"股票智能体初始化完成 - 美股: {us_count}只, 港股: {hk_count}只, 总计: {len(self.symbols)}只")
|
||||
|
||||
async def start(self):
|
||||
"""启动智能体"""
|
||||
if self.running:
|
||||
logger.warning("美股智能体已在运行中")
|
||||
return
|
||||
|
||||
self.running = True
|
||||
self._event_loop = asyncio.get_event_loop()
|
||||
|
||||
# 更新状态为启动中
|
||||
monitor = get_system_monitor()
|
||||
monitor.update_status("stock_agent", AgentStatus.STARTING)
|
||||
|
||||
logger.info("美股智能体已启动")
|
||||
|
||||
# 启动分析任务
|
||||
self._task = asyncio.create_task(self._analysis_loop())
|
||||
|
||||
# 更新状态为运行中
|
||||
monitor.update_status("stock_agent", AgentStatus.RUNNING)
|
||||
|
||||
# 发送启动通知(卡片格式)
|
||||
us_stocks = [s for s in self.symbols if not s.endswith('.HK')]
|
||||
hk_stocks = [s for s in self.symbols if s.endswith('.HK')]
|
||||
|
||||
title = "📈 股票智能体已启动"
|
||||
|
||||
content_parts = [
|
||||
f"🤖 **驱动引擎**: LLM 三维分析",
|
||||
f"📊 **监控股票**: {len(self.symbols)} 只",
|
||||
]
|
||||
|
||||
if us_stocks:
|
||||
content_parts.append(f" 🇺🇸 美股 ({len(us_stocks)}): {', '.join(us_stocks[:3])}{'...' if len(us_stocks) > 3 else ''}")
|
||||
if hk_stocks:
|
||||
content_parts.append(f" 🇭🇰 港股 ({len(hk_stocks)}): {', '.join(hk_stocks[:3])}{'...' if len(hk_stocks) > 3 else ''}")
|
||||
|
||||
content_parts.extend([
|
||||
f"⏰ **运行频率**: 每小时整点",
|
||||
f"🎯 **分析维度**: 技术面(40%) + 基本面(35%) + 新闻(25%)",
|
||||
f"📢 **当前模式**: 仅市场分析",
|
||||
])
|
||||
|
||||
content = "\n".join(content_parts)
|
||||
await self.feishu.send_card(title, content, "green")
|
||||
|
||||
async def stop(self):
|
||||
"""停止智能体"""
|
||||
self.running = False
|
||||
|
||||
# 更新状态为已停止
|
||||
monitor = get_system_monitor()
|
||||
monitor.update_status("stock_agent", AgentStatus.STOPPED)
|
||||
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
try:
|
||||
await self._task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
logger.info("美股智能体已停止")
|
||||
|
||||
async def _analysis_loop(self):
|
||||
"""分析循环 - 根据交易时间分析对应市场的股票"""
|
||||
while self.running:
|
||||
try:
|
||||
# 计算距离下一个整点的时间
|
||||
now = datetime.now()
|
||||
next_hour = now.replace(minute=0, second=0, microsecond=0) + timedelta(hours=1)
|
||||
wait_seconds = (next_hour - now).total_seconds()
|
||||
|
||||
logger.info(f"等待到下一个整点: {next_hour.strftime('%H:%M')} (等待 {int(wait_seconds)} 秒)")
|
||||
|
||||
# 等待到整点
|
||||
await asyncio.sleep(wait_seconds)
|
||||
|
||||
# 分类股票:美股和港股
|
||||
us_stocks = [s for s in self.symbols if not s.endswith('.HK')]
|
||||
hk_stocks = [s for s in self.symbols if s.endswith('.HK')]
|
||||
|
||||
# 检查各市场交易时间
|
||||
us_market_open = self._is_market_hours('US')
|
||||
hk_market_open = self._is_market_hours('0700.HK')
|
||||
|
||||
# 检查是否是盘后分析时间
|
||||
us_after_hours = self._is_after_hours('US')
|
||||
hk_after_hours = self._is_after_hours('0700.HK')
|
||||
|
||||
# 确定要分析的股票列表
|
||||
stocks_to_analyze = []
|
||||
analysis_type = "盘中" # 默认为盘中分析
|
||||
|
||||
# 盘后分析:优先级更高,用于日线级别分析
|
||||
if us_after_hours or hk_after_hours:
|
||||
analysis_type = "盘后"
|
||||
if us_after_hours:
|
||||
stocks_to_analyze.extend(us_stocks)
|
||||
logger.info(f"美股盘后分析,分析 {len(us_stocks)} 只美股(日线级别)")
|
||||
if hk_after_hours:
|
||||
stocks_to_analyze.extend(hk_stocks)
|
||||
logger.info(f"港股盘后分析,分析 {len(hk_stocks)} 只港股(日线级别)")
|
||||
else:
|
||||
# 盘中分析
|
||||
if us_market_open:
|
||||
stocks_to_analyze.extend(us_stocks)
|
||||
logger.info(f"美股交易时间,分析 {len(us_stocks)} 只美股")
|
||||
if hk_market_open:
|
||||
stocks_to_analyze.extend(hk_stocks)
|
||||
logger.info(f"港股交易时间,分析 {len(hk_stocks)} 只港股")
|
||||
|
||||
# 如果没有需要分析的股票
|
||||
if not stocks_to_analyze:
|
||||
logger.debug("没有需要分析的股票")
|
||||
continue
|
||||
|
||||
# 分析股票并收集结果
|
||||
logger.info(f"开始{analysis_type}分析 {len(stocks_to_analyze)} 只股票")
|
||||
analysis_results = []
|
||||
|
||||
for symbol in stocks_to_analyze:
|
||||
if not self.running:
|
||||
break
|
||||
result = await self.analyze_symbol(symbol, is_after_hours=(analysis_type == "盘后"))
|
||||
if result:
|
||||
analysis_results.append(result)
|
||||
|
||||
# 生成并发送汇总报告
|
||||
await self._send_summary_report(analysis_results, analysis_type)
|
||||
|
||||
logger.info(f"本次{analysis_type}分析完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"分析循环出错: {e}")
|
||||
await asyncio.sleep(60) # 出错后等待 1 分钟再重试
|
||||
|
||||
def _is_market_hours(self, symbol: str = None) -> bool:
|
||||
"""
|
||||
判断当前是否在交易时间
|
||||
|
||||
美股交易时间: 周一至周五 9:30-16:00 (EST)
|
||||
北京时间:
|
||||
- 冬令时 (11月-3月): 22:30-05:00 (次日)
|
||||
- 夏令时 (3月-11月): 21:30-04:00 (次日)
|
||||
|
||||
港股交易时间: 周一至周五
|
||||
北京时间:
|
||||
- 上午: 09:30-12:00
|
||||
- 下午: 13:00-16:00
|
||||
|
||||
Args:
|
||||
symbol: 股票代码(用于判断是美股还是港股)
|
||||
|
||||
Returns:
|
||||
是否在交易时间
|
||||
"""
|
||||
from datetime import datetime
|
||||
|
||||
# 获取当前时间
|
||||
now = datetime.now()
|
||||
|
||||
# 检查是否为周末
|
||||
if now.weekday() >= 5: # 5=周六, 6=周日
|
||||
return False
|
||||
|
||||
# 判断是港股还是美股
|
||||
is_hk_stock = symbol and symbol.endswith('.HK') if symbol else False
|
||||
|
||||
# 获取当前小时和分钟
|
||||
hour = now.hour
|
||||
minute = now.minute
|
||||
current_time = hour * 100 + minute # 转换为数字,如 2130 表示 21:30
|
||||
|
||||
if is_hk_stock:
|
||||
# 港股交易时间: 09:30-12:00 或 13:00-16:00
|
||||
return (930 <= current_time < 1200) or (1300 <= current_time < 1600)
|
||||
else:
|
||||
# 美股交易时间
|
||||
# 判断夏令时/冬令时(简单判断:3-11月为夏令时)
|
||||
is_summer = 3 <= now.month <= 11
|
||||
|
||||
if is_summer:
|
||||
# 夏令时: 21:30-04:00 (次日)
|
||||
# 即 2130-2359 或 0000-0400
|
||||
if current_time >= 2130 or current_time < 400:
|
||||
return True
|
||||
else:
|
||||
# 冬令时: 22:30-05:00 (次日)
|
||||
# 即 2230-2359 或 0000-0500
|
||||
if current_time >= 2230 or current_time < 500:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _is_any_market_hours(self) -> bool:
|
||||
"""判断当前是否在任一市场的交易时间(美股或港股)"""
|
||||
return self._is_market_hours('US') or self._is_market_hours('0700.HK')
|
||||
|
||||
def _is_after_hours(self, symbol: str) -> bool:
|
||||
"""
|
||||
判断当前是否是盘后分析时间(收盘后2小时内)
|
||||
|
||||
美股收盘时间:
|
||||
- 夏令时: 北京时间 04:00 收盘
|
||||
- 冬令时: 北京时间 05:00 收盘
|
||||
|
||||
港股收盘时间: 北京时间 16:00 收盘
|
||||
|
||||
盘后分析时间: 收盘后 2 小时内
|
||||
|
||||
Args:
|
||||
symbol: 股票代码(用于判断是美股还是港股)
|
||||
|
||||
Returns:
|
||||
是否是盘后分析时间
|
||||
"""
|
||||
from datetime import datetime
|
||||
|
||||
# 获取当前时间
|
||||
now = datetime.now()
|
||||
|
||||
# 检查是否为周末
|
||||
if now.weekday() >= 5: # 5=周六, 6=周日
|
||||
return False
|
||||
|
||||
# 判断是港股还是美股
|
||||
is_hk_stock = symbol and symbol.endswith('.HK') if symbol else False
|
||||
|
||||
# 获取当前小时和分钟
|
||||
hour = now.hour
|
||||
minute = now.minute
|
||||
current_time = hour * 100 + minute # 转换为数字,如 1630 表示 16:30
|
||||
|
||||
if is_hk_stock:
|
||||
# 港股盘后: 16:00-18:00 (收盘后2小时)
|
||||
return 1600 <= current_time < 1800
|
||||
else:
|
||||
# 美股盘后
|
||||
# 判断夏令时/冬令时(简单判断:3-11月为夏令时)
|
||||
is_summer = 3 <= now.month <= 11
|
||||
|
||||
if is_summer:
|
||||
# 夏令时: 04:00-06:00 (收盘后2小时)
|
||||
return 400 <= current_time < 600
|
||||
else:
|
||||
# 冬令时: 05:00-07:00 (收盘后2小时)
|
||||
return 500 <= current_time < 700
|
||||
|
||||
return False
|
||||
|
||||
async def analyze_symbol(self, symbol: str, is_after_hours: bool = False) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
分析单个股票
|
||||
|
||||
Args:
|
||||
symbol: 股票代码
|
||||
is_after_hours: 是否是盘后分析(盘后会更关注日线级别机会)
|
||||
|
||||
Returns:
|
||||
分析结果字典,包含股票信息和信号
|
||||
"""
|
||||
# 更新活动时间
|
||||
monitor = get_system_monitor()
|
||||
monitor.update_activity("stock_agent")
|
||||
|
||||
result = {
|
||||
'symbol': symbol,
|
||||
'stock_name': '', # 从基本面数据获取的公司名称
|
||||
'current_price': 0,
|
||||
'signals': [],
|
||||
'analysis_summary': '',
|
||||
'notified': False
|
||||
}
|
||||
|
||||
try:
|
||||
# 1. 获取多时间周期数据
|
||||
data = self.yfinance.get_multi_timeframe_data(symbol)
|
||||
|
||||
# 2. 验证数据完整性
|
||||
if not self._validate_data(data):
|
||||
logger.warning(f"{symbol} 数据不完整,跳过本次分析")
|
||||
return result
|
||||
|
||||
# 3. 获取当前价格
|
||||
ticker = self.yfinance.get_ticker(symbol)
|
||||
if not ticker:
|
||||
logger.warning(f"无法获取 {symbol} 当前价格")
|
||||
return result
|
||||
current_price = ticker['lastPrice']
|
||||
result['current_price'] = current_price
|
||||
|
||||
# 4. 获取基本面数据(包含公司名称)
|
||||
logger.info(f"\n📈 【基本面分析】")
|
||||
fundamental_data = None
|
||||
fundamental_summary = ""
|
||||
stock_name = "" # 从基本面数据获取公司名称
|
||||
try:
|
||||
fundamental_data = self.fundamental.get_fundamental_data(symbol)
|
||||
if fundamental_data:
|
||||
# 传递已获取的数据,避免重复调用
|
||||
fundamental_summary = self.fundamental.get_fundamental_summary(symbol, fundamental_data)
|
||||
# 从基本面数据获取公司名称
|
||||
stock_name = fundamental_data.get('company_name', '')
|
||||
result['stock_name'] = stock_name # 保存到结果中
|
||||
# 基本面评分已经在 fundamental_service 中输出
|
||||
else:
|
||||
logger.warning(f" ⚠️ 无法获取基本面数据")
|
||||
except Exception as e:
|
||||
logger.warning(f" ⚠️ 获取基本面数据失败: {e}")
|
||||
|
||||
symbol_display = f"{stock_name}({symbol})" if stock_name else symbol
|
||||
logger.info(f"\n{'='*60}")
|
||||
logger.info(f"📊 分析 {symbol_display} @ ${current_price:,.2f}")
|
||||
logger.info(f"{'='*60}")
|
||||
|
||||
# 5. 获取新闻数据
|
||||
logger.info(f"\n📰 【新闻分析】")
|
||||
news_data = None
|
||||
try:
|
||||
news_data = await self.news.search_stock_news(symbol, stock_name, max_results=5)
|
||||
if news_data:
|
||||
logger.info(f" 获取到 {len(news_data)} 条相关新闻")
|
||||
else:
|
||||
logger.info(f" 暂无相关新闻")
|
||||
except Exception as e:
|
||||
logger.warning(f" ⚠️ 获取新闻数据失败: {e}")
|
||||
|
||||
# 6. 市场信号分析(使用新架构 - 技术面 + 基本面 + 新闻)
|
||||
logger.info(f"\n🤖 【市场信号分析中...】")
|
||||
market_signal = await self.market_analyzer.analyze(
|
||||
symbol, data,
|
||||
symbols=self.symbols,
|
||||
fundamental_data=fundamental_data,
|
||||
news_data=news_data
|
||||
)
|
||||
|
||||
# 输出分析摘要
|
||||
summary = market_signal.get('analysis_summary', '无')
|
||||
result['analysis_summary'] = summary
|
||||
logger.info(f" 市场状态: {summary}")
|
||||
|
||||
# 输出新闻情绪(如果有)
|
||||
# 注:新的分析器不包含新闻分析,可以跳过或从其他地方获取
|
||||
|
||||
# 输出关键价位
|
||||
levels = market_signal.get('key_levels', {})
|
||||
if levels.get('support') or levels.get('resistance'):
|
||||
support_str = ', '.join([f"${s:,.2f}" for s in levels.get('support', [])[:2]])
|
||||
resistance_str = ', '.join([f"${r:,.2f}" for r in levels.get('resistance', [])[:2]])
|
||||
logger.info(f" 支撑位: {support_str or '-'}")
|
||||
logger.info(f" 阻力位: {resistance_str or '-'}")
|
||||
|
||||
# 5. 处理信号
|
||||
signals = market_signal.get('signals', [])
|
||||
result['signals'] = signals
|
||||
|
||||
if not signals:
|
||||
logger.info(f"\n⏸️ 结论: 无交易信号,继续观望")
|
||||
return result
|
||||
|
||||
# 输出所有信号
|
||||
logger.info(f"\n🎯 【发现 {len(signals)} 个信号】")
|
||||
|
||||
for sig in signals:
|
||||
signal_type = sig.get('type', 'unknown')
|
||||
type_map = {'short_term': '短线', 'medium_term': '中线', 'long_term': '长线'}
|
||||
type_text = type_map.get(signal_type, signal_type)
|
||||
|
||||
action = sig.get('action', 'wait')
|
||||
action_map = {'buy': '🟢 做多', 'sell': '🔴 做空'}
|
||||
action_text = action_map.get(action, action)
|
||||
|
||||
grade = sig.get('grade', 'D')
|
||||
confidence = sig.get('confidence', 0)
|
||||
grade_icon = {'A': '⭐⭐⭐', 'B': '⭐⭐', 'C': '⭐', 'D': ''}.get(grade, '')
|
||||
|
||||
logger.info(f"\n {type_text} {action_text} [{grade}级{grade_icon}] {confidence}%")
|
||||
|
||||
# 6. 过滤并通知最佳信号
|
||||
best_signal = self._get_best_signal(signals)
|
||||
|
||||
if not best_signal:
|
||||
logger.info(f"\n⏸️ 信号质量不高,不发送通知")
|
||||
return result
|
||||
|
||||
logger.info(f"\n📢 【最佳信号】{best_signal.get('action')} {best_signal.get('grade')}级 {best_signal.get('confidence')}%")
|
||||
|
||||
# 检查置信度阈值
|
||||
threshold = self.settings.stock_llm_threshold * 100
|
||||
if best_signal.get('confidence', 0) < threshold:
|
||||
logger.info(f"\n⏸️ 置信度不足 ({best_signal.get('confidence', 0)}% < {threshold}%)")
|
||||
return result
|
||||
|
||||
# 检查冷却时间
|
||||
if not self._should_send_signal(symbol, best_signal):
|
||||
logger.info(f"\n⏸️ 信号冷却中,不发送通知")
|
||||
return result
|
||||
|
||||
logger.info(f"\n✅ 满足所有条件,准备发送通知...")
|
||||
|
||||
# 发送通知
|
||||
try:
|
||||
await self._send_signal_notification(symbol, best_signal, current_price)
|
||||
result['notified'] = True
|
||||
result['best_signal'] = best_signal
|
||||
|
||||
# 更新状态
|
||||
self.last_signals[symbol] = best_signal
|
||||
self.signal_cooldown[symbol] = datetime.now()
|
||||
except Exception as notify_error:
|
||||
logger.error(f"❌ 发送 {symbol} 通知失败: {notify_error}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
result['notified'] = False
|
||||
result['notify_error'] = str(notify_error)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 分析 {symbol} 出错: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return result
|
||||
|
||||
def _get_best_signal(self, signals: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
|
||||
"""获取最佳信号"""
|
||||
# 过滤掉 D 级信号
|
||||
valid_signals = [s for s in signals if s.get('grade', 'D') != 'D']
|
||||
|
||||
if not valid_signals:
|
||||
return None
|
||||
|
||||
# 按等级和置信度排序
|
||||
grade_order = {'A': 0, 'B': 1, 'C': 2}
|
||||
valid_signals.sort(key=lambda x: (
|
||||
grade_order.get(x.get('grade', 'C'), 3),
|
||||
-x.get('confidence', 0)
|
||||
))
|
||||
|
||||
return valid_signals[0]
|
||||
|
||||
def _should_send_signal(self, symbol: str, signal: Dict[str, Any]) -> bool:
|
||||
"""判断是否应该发送信号"""
|
||||
action = signal.get('action', 'wait')
|
||||
if action == 'wait':
|
||||
return False
|
||||
|
||||
# 检查冷却时间(60分钟内不重复发送相同方向的信号)
|
||||
if symbol in self.signal_cooldown:
|
||||
cooldown_end = self.signal_cooldown[symbol] + timedelta(minutes=60)
|
||||
if datetime.now() < cooldown_end:
|
||||
if symbol in self.last_signals:
|
||||
if self.last_signals[symbol].get('action') == action:
|
||||
logger.debug(f"{symbol} 信号冷却中,跳过")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def _send_signal_notification(
|
||||
self,
|
||||
symbol: str,
|
||||
signal: Dict[str, Any],
|
||||
current_price: float
|
||||
):
|
||||
"""发送信号通知"""
|
||||
try:
|
||||
logger.info(f"📤 准备发送 {symbol} 信号通知...")
|
||||
|
||||
from app.utils.signal_formatter import get_signal_formatter
|
||||
formatter = get_signal_formatter()
|
||||
|
||||
# 获取股票名称
|
||||
stock_name = signal.get('stock_name', '')
|
||||
|
||||
# 使用格式化工具格式化信号
|
||||
card = formatter.format_feishu_card(signal, symbol, agent_type='stock', stock_name=stock_name)
|
||||
title = card['title']
|
||||
content = card['content']
|
||||
|
||||
logger.info(f" 标题: {title}")
|
||||
logger.info(f" 内容长度: {len(content)} 字符")
|
||||
|
||||
# 根据信号方向选择颜色
|
||||
color = "green" if signal.get('action') == 'buy' else "red"
|
||||
logger.info(f" 颜色: {color}")
|
||||
|
||||
# 检查飞书服务
|
||||
logger.info(f" 飞书服务: {type(self.feishu).__name__}")
|
||||
logger.info(f" Webhook URL: {self.feishu.webhook_url[:50]}...")
|
||||
|
||||
# 发送到飞书
|
||||
feishu_success = await self.feishu.send_card(title, content, color)
|
||||
if feishu_success:
|
||||
logger.info(f" ✅ 飞书通知发送成功")
|
||||
else:
|
||||
logger.warning(f" ⚠️ 飞书通知发送失败(但Telegram会发送)")
|
||||
|
||||
# 发送到 Telegram(也传递 stock_name)
|
||||
await self.telegram.send_message(formatter.format_signal_message(signal, symbol, agent_type='stock', stock_name=stock_name))
|
||||
|
||||
logger.info(f"✅ 信号通知已发送: {title}")
|
||||
|
||||
# 保存信号到数据库
|
||||
signal_to_save = signal.copy()
|
||||
signal_to_save['signal_type'] = 'stock'
|
||||
signal_to_save['symbol'] = symbol
|
||||
signal_to_save['current_price'] = current_price
|
||||
self.signal_db.add_signal(signal_to_save)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 发送通知失败: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
# 重新抛出异常,让上层能够捕获
|
||||
raise
|
||||
|
||||
def _validate_data(self, data: Dict[str, pd.DataFrame]) -> bool:
|
||||
"""验证数据完整性"""
|
||||
required_intervals = ['1d', '1h']
|
||||
for interval in required_intervals:
|
||||
if interval not in data or data[interval].empty:
|
||||
return False
|
||||
if len(data[interval]) < 20:
|
||||
return False
|
||||
return True
|
||||
|
||||
async def analyze_once(self, symbol: str) -> Dict[str, Any]:
|
||||
"""单次分析(用于测试或手动触发)"""
|
||||
data = self.yfinance.get_multi_timeframe_data(symbol)
|
||||
|
||||
if not self._validate_data(data):
|
||||
return {'error': '数据不完整'}
|
||||
|
||||
# 获取基本面数据
|
||||
fundamental_data = None
|
||||
fundamental_summary = ""
|
||||
try:
|
||||
fundamental_data = self.fundamental.get_fundamental_data(symbol)
|
||||
if fundamental_data:
|
||||
# 传递已获取的数据,避免重复调用
|
||||
fundamental_summary = self.fundamental.get_fundamental_summary(symbol, fundamental_data)
|
||||
except Exception as e:
|
||||
logger.warning(f"获取基本面数据失败: {e}")
|
||||
|
||||
result = await self.market_analyzer.analyze(
|
||||
symbol, data,
|
||||
symbols=self.symbols
|
||||
)
|
||||
return result
|
||||
|
||||
def get_status(self) -> Dict[str, Any]:
|
||||
"""获取智能体状态"""
|
||||
return {
|
||||
'running': self.running,
|
||||
'symbols': self.symbols,
|
||||
'mode': 'LLM 驱动(仅分析通知)',
|
||||
'last_signals': {
|
||||
symbol: {
|
||||
'type': sig.get('type'),
|
||||
'action': sig.get('action'),
|
||||
'confidence': sig.get('confidence'),
|
||||
'grade': sig.get('grade')
|
||||
}
|
||||
for symbol, sig in self.last_signals.items()
|
||||
}
|
||||
}
|
||||
|
||||
async def _send_summary_report(self, results: List[Dict[str, Any]], analysis_type: str = "盘中"):
|
||||
"""
|
||||
生成并发送分析汇总报告
|
||||
|
||||
Args:
|
||||
results: 所有股票的分析结果列表
|
||||
analysis_type: 分析类型 ("盘中" 或 "盘后")
|
||||
"""
|
||||
try:
|
||||
now = datetime.now()
|
||||
total = len(results)
|
||||
with_signals = [r for r in results if r.get('signals')]
|
||||
notified = [r for r in results if r.get('notified')]
|
||||
|
||||
# 区分美股和港股
|
||||
us_results = [r for r in results if not r['symbol'].endswith('.HK')]
|
||||
hk_results = [r for r in results if r['symbol'].endswith('.HK')]
|
||||
us_with_signals = [r for r in us_results if r.get('signals')]
|
||||
hk_with_signals = [r for r in hk_results if r.get('signals')]
|
||||
|
||||
# 统计信号
|
||||
buy_signals = []
|
||||
sell_signals = []
|
||||
high_quality_signals = [] # A/B级信号
|
||||
|
||||
for r in with_signals:
|
||||
for sig in r.get('signals', []):
|
||||
sig['symbol'] = r['symbol']
|
||||
sig['current_price'] = r.get('current_price', 0)
|
||||
sig['is_hk'] = r['symbol'].endswith('.HK')
|
||||
sig['stock_name'] = r.get('stock_name', '')
|
||||
|
||||
if sig.get('action') == 'buy':
|
||||
buy_signals.append(sig)
|
||||
elif sig.get('action') == 'sell':
|
||||
sell_signals.append(sig)
|
||||
|
||||
if sig.get('grade') in ['A', 'B']:
|
||||
high_quality_signals.append(sig)
|
||||
|
||||
# 按置信度排序
|
||||
high_quality_signals.sort(key=lambda x: x.get('confidence', 0), reverse=True)
|
||||
|
||||
# 构建汇总报告
|
||||
analysis_tag = f"【{analysis_type}分析】"
|
||||
logger.info(f"\n{'='*80}")
|
||||
logger.info(f"📊 股票分析汇总报告 {analysis_tag}")
|
||||
logger.info(f"{'='*80}")
|
||||
logger.info(f"时间: {now.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
logger.info(f"分析总数: {total} 只 (美股: {len(us_results)}, 港股: {len(hk_results)})")
|
||||
logger.info(f"有信号: {len(with_signals)} 只 (美股: {len(us_with_signals)}, 港股: {len(hk_with_signals)})")
|
||||
logger.info(f"已通知: {len(notified)} 只")
|
||||
logger.info(f"")
|
||||
|
||||
# 显示高等级信号
|
||||
if high_quality_signals:
|
||||
logger.info(f"⭐ 高等级信号 (A/B级): {len(high_quality_signals)} 个")
|
||||
for sig in high_quality_signals[:10]: # 最多显示10个
|
||||
symbol = sig['symbol']
|
||||
stock_name = sig.get('stock_name', '')
|
||||
market_tag = '[港股]' if sig.get('is_hk') else '[美股]'
|
||||
action = '🟢 做多' if sig.get('action') == 'buy' else '🔴 做空'
|
||||
grade = sig.get('grade', 'D')
|
||||
confidence = sig.get('confidence', 0)
|
||||
price = sig.get('current_price', 0)
|
||||
entry = sig.get('entry_price', 0)
|
||||
|
||||
# 构建带名称的股票显示
|
||||
symbol_display = f"{stock_name}({symbol})" if stock_name else symbol
|
||||
|
||||
logger.info(f" {market_tag} {symbol_display} {action} [{grade}级] {confidence}% @ ${price:,.2f}")
|
||||
if entry > 0:
|
||||
logger.info(f" 入场: ${entry:,.2f}")
|
||||
logger.info(f"")
|
||||
|
||||
# 统计汇总
|
||||
logger.info(f"📈 做多信号: {len(buy_signals)} 个")
|
||||
logger.info(f"📉 做空信号: {len(sell_signals)} 个")
|
||||
logger.info(f"{'='*80}\n")
|
||||
|
||||
# 发送飞书汇总
|
||||
await self._send_feishu_summary(
|
||||
now, total, with_signals, notified,
|
||||
buy_signals, sell_signals, high_quality_signals,
|
||||
len(us_results), len(hk_results),
|
||||
analysis_type
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成汇总报告失败: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def _send_feishu_summary(
|
||||
self,
|
||||
now: datetime,
|
||||
total: int,
|
||||
with_signals: List,
|
||||
notified: List,
|
||||
buy_signals: List,
|
||||
sell_signals: List,
|
||||
high_quality_signals: List,
|
||||
us_count: int = 0,
|
||||
hk_count: int = 0,
|
||||
analysis_type: str = "盘中"
|
||||
):
|
||||
"""发送飞书汇总报告"""
|
||||
try:
|
||||
# 构建内容
|
||||
analysis_tag = f"【{analysis_type}分析】"
|
||||
content_parts = [
|
||||
f"**📊 股票分析汇总报告 {analysis_tag}**",
|
||||
f"",
|
||||
f"⏰ 时间: {now.strftime('%Y-%m-%d %H:%M')}",
|
||||
f"",
|
||||
f"📊 **分析概况**",
|
||||
f"• 美股: {us_count} 只 | 港股: {hk_count} 只",
|
||||
f"• 发现信号: {len(with_signals)} 只",
|
||||
f"• 已发通知: {len(notified)} 只",
|
||||
f"",
|
||||
]
|
||||
|
||||
# 所有信号(按等级分组)
|
||||
all_signals = buy_signals + sell_signals
|
||||
|
||||
# 高等级信号 (A/B级)
|
||||
if high_quality_signals:
|
||||
content_parts.append(f"⭐ **高等级信号 (A/B级)**")
|
||||
for sig in high_quality_signals[:5]:
|
||||
symbol = sig['symbol']
|
||||
stock_name = sig.get('stock_name', '')
|
||||
market_tag = '[港股]' if sig.get('is_hk') else '[美股]'
|
||||
action = '🟢 做多' if sig.get('action') == 'buy' else '🔴 做空'
|
||||
grade = sig.get('grade', 'D')
|
||||
confidence = sig.get('confidence', 0)
|
||||
|
||||
# 构建带名称的股票显示
|
||||
symbol_display = f"{stock_name}({symbol})" if stock_name else symbol
|
||||
|
||||
content_parts.append(f"• {market_tag} {symbol_display} {action} {grade}级 {confidence}%")
|
||||
content_parts.append(f"")
|
||||
|
||||
# 其他等级信号 (C/D级)
|
||||
other_signals = [s for s in all_signals if s.get('grade', 'D') not in ['A', 'B']]
|
||||
if other_signals:
|
||||
content_parts.append(f"📋 **其他信号 (C/D级)**")
|
||||
for sig in other_signals[:10]: # 最多显示10个
|
||||
symbol = sig['symbol']
|
||||
stock_name = sig.get('stock_name', '')
|
||||
market_tag = '[港股]' if sig.get('is_hk') else '[美股]'
|
||||
action = '🟢 做多' if sig.get('action') == 'buy' else '🔴 做空'
|
||||
grade = sig.get('grade', 'D')
|
||||
confidence = sig.get('confidence', 0)
|
||||
|
||||
# 构建带名称的股票显示
|
||||
symbol_display = f"{stock_name}({symbol})" if stock_name else symbol
|
||||
|
||||
content_parts.append(f"• {market_tag} {symbol_display} {action} {grade}级 {confidence}%")
|
||||
|
||||
if len(other_signals) > 10:
|
||||
content_parts.append(f" *...还有 {len(other_signals) - 10} 个信号*")
|
||||
content_parts.append(f"")
|
||||
|
||||
# 信号统计
|
||||
content_parts.extend([
|
||||
f"📈 做多信号: {len(buy_signals)} 个",
|
||||
f"📉 做空信号: {len(sell_signals)} 个",
|
||||
f"",
|
||||
f"*⚠️ 仅供参考,不构成投资建议*"
|
||||
])
|
||||
|
||||
content = "\n".join(content_parts)
|
||||
|
||||
# 发送飞书 - 标题包含分析类型
|
||||
type_tag = "盘后" if analysis_type == "盘后" else "分析"
|
||||
title = f"📊 股票{type_tag}汇总 ({now.strftime('%H:%M')})"
|
||||
color = "blue"
|
||||
|
||||
await self.feishu.send_card(title, content, color)
|
||||
logger.info("✅ 汇总报告已发送到飞书")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发送飞书汇总失败: {e}")
|
||||
|
||||
|
||||
# 全局单例
|
||||
_stock_agent: Optional[StockAgent] = None
|
||||
|
||||
|
||||
def get_stock_agent() -> StockAgent:
|
||||
"""获取美股智能体单例"""
|
||||
global _stock_agent
|
||||
if _stock_agent is None:
|
||||
_stock_agent = StockAgent()
|
||||
return _stock_agent
|
||||
@ -9,7 +9,7 @@ from typing import Optional
|
||||
|
||||
|
||||
def setup_logger(
|
||||
name: str = "stock_agent",
|
||||
name: str = "crypto_agent",
|
||||
level: int = logging.INFO,
|
||||
log_file: Optional[str] = None
|
||||
) -> logging.Logger:
|
||||
|
||||
@ -1,221 +0,0 @@
|
||||
"""
|
||||
信号格式化工具
|
||||
|
||||
用于格式化交易信号通知,支持:
|
||||
- 飞书卡片格式
|
||||
- Telegram 文本格式
|
||||
- 支持加密货币、美股、港股
|
||||
"""
|
||||
from typing import Dict, Any
|
||||
|
||||
|
||||
class SignalFormatter:
|
||||
"""信号格式化工具"""
|
||||
|
||||
@staticmethod
|
||||
def format_signal_message(signal: Dict[str, Any], symbol: str, agent_type: str = 'crypto', stock_name: str = '') -> str:
|
||||
"""
|
||||
格式化信号消息(用于 Telegram 通知)
|
||||
|
||||
Args:
|
||||
signal: 信号数据
|
||||
symbol: 交易对
|
||||
agent_type: 智能体类型 (crypto/stock)
|
||||
stock_name: 股票名称(可选,从基本面数据获取)
|
||||
|
||||
Returns:
|
||||
格式化的消息文本
|
||||
"""
|
||||
type_map = {
|
||||
'short_term': '短线',
|
||||
'medium_term': '中线',
|
||||
'long_term': '长线'
|
||||
}
|
||||
action_map = {
|
||||
'buy': '做多',
|
||||
'sell': '做空'
|
||||
}
|
||||
|
||||
# 兼容 timeframe 和 type 字段
|
||||
signal_type_key = 'timeframe' if 'timeframe' in signal else 'type'
|
||||
signal_type = type_map.get(signal.get(signal_type_key), signal.get(signal_type_key))
|
||||
action = action_map.get(signal['action'], signal['action'])
|
||||
grade = signal.get('grade', 'C')
|
||||
confidence = signal.get('confidence', 0)
|
||||
entry_type = signal.get('entry_type', 'market')
|
||||
|
||||
# 等级图标
|
||||
grade_icon = {'A': '⭐⭐⭐', 'B': '⭐⭐', 'C': '⭐', 'D': ''}.get(grade, '')
|
||||
|
||||
# 方向图标
|
||||
action_icon = '🟢' if signal['action'] == 'buy' else '🔴'
|
||||
|
||||
# 入场类型
|
||||
entry_type_text = '现价入场' if entry_type == 'market' else '挂单等待'
|
||||
entry_type_icon = '⚡' if entry_type == 'market' else '⏳'
|
||||
|
||||
# 仓位大小
|
||||
position_size = signal.get('position_size', 'light')
|
||||
position_map = {'heavy': '重仓', 'medium': '中仓', 'light': '轻仓'}
|
||||
position_icon = {'heavy': '🔥', 'medium': '📊', 'light': '🌱'}.get(position_size, '🌱')
|
||||
position_text = position_map.get(position_size, '轻仓')
|
||||
|
||||
# 计算风险收益比
|
||||
entry = signal.get('entry_price', 0)
|
||||
sl = signal.get('stop_loss', 0)
|
||||
tp = signal.get('take_profit', 0)
|
||||
sl_percent = ((sl - entry) / entry * 100) if entry else 0
|
||||
tp_percent = ((tp - entry) / entry * 100) if entry else 0
|
||||
|
||||
# 识别市场类型
|
||||
if agent_type == 'crypto':
|
||||
market_tag = '[加密货币] '
|
||||
elif symbol.endswith('.HK'):
|
||||
market_tag = '[港股] '
|
||||
else:
|
||||
market_tag = '[美股] '
|
||||
|
||||
# 构建标题(带股票名称和市场类型)
|
||||
symbol_display = f"{stock_name}({symbol})" if stock_name else symbol
|
||||
|
||||
message = f"""📊 {market_tag}{symbol_display} {signal_type}信号
|
||||
|
||||
{action_icon} **方向**: {action}
|
||||
{entry_type_icon} **入场**: {entry_type_text}
|
||||
{position_icon} **仓位**: {position_text}
|
||||
⭐ **等级**: {grade} {grade_icon}
|
||||
📈 **置信度**: {confidence}%
|
||||
|
||||
💰 **入场价**: ${entry:,.2f}
|
||||
🛑 **止损价**: ${sl:,.2f} ({sl_percent:+.1f}%)
|
||||
🎯 **止盈价**: ${tp:,.2f} ({tp_percent:+.1f}%)
|
||||
|
||||
📝 **分析理由**:
|
||||
{signal.get('reasoning') or signal.get('reason', '无')}
|
||||
|
||||
⚠️ **风险提示**:
|
||||
{signal.get('risk_warning', '请注意风险控制')}"""
|
||||
|
||||
return message
|
||||
|
||||
@staticmethod
|
||||
def format_feishu_card(signal: Dict[str, Any], symbol: str, agent_type: str = 'crypto', stock_name: str = '') -> Dict[str, Any]:
|
||||
"""
|
||||
格式化飞书卡片消息
|
||||
|
||||
Args:
|
||||
signal: 信号数据
|
||||
symbol: 交易对
|
||||
agent_type: 智能体类型 (crypto/stock)
|
||||
stock_name: 股票名称(可选,从基本面数据获取)
|
||||
|
||||
Returns:
|
||||
包含 title, content, color 的字典
|
||||
"""
|
||||
type_map = {
|
||||
'short_term': '短线',
|
||||
'medium_term': '中线',
|
||||
'long_term': '长线'
|
||||
}
|
||||
action_map = {
|
||||
'buy': '做多',
|
||||
'sell': '做空'
|
||||
}
|
||||
|
||||
# 兼容 timeframe 和 type 字段
|
||||
signal_type_key = 'timeframe' if 'timeframe' in signal else 'type'
|
||||
signal_type = type_map.get(signal.get(signal_type_key), signal.get(signal_type_key))
|
||||
action = action_map.get(signal['action'], signal['action'])
|
||||
action_icon = '🟢' if signal['action'] == 'buy' else '🔴'
|
||||
grade = signal.get('grade', 'C')
|
||||
confidence = signal.get('confidence', 0)
|
||||
entry_type = signal.get('entry_type', 'market')
|
||||
|
||||
# 等级图标
|
||||
grade_icon = {'A': '⭐⭐⭐', 'B': '⭐⭐', 'C': '⭐', 'D': ''}.get(grade, '')
|
||||
|
||||
# 入场类型
|
||||
entry_type_text = '现价入场' if entry_type == 'market' else '挂单等待'
|
||||
entry_type_icon = '⚡' if entry_type == 'market' else '⏳'
|
||||
|
||||
# 仓位大小
|
||||
position_size = signal.get('position_size', 'light')
|
||||
position_map = {'heavy': '重仓', 'medium': '中仓', 'light': '轻仓'}
|
||||
position_icon = {'heavy': '🔥', 'medium': '📊', 'light': '🌱'}.get(position_size, '🌱')
|
||||
position_text = position_map.get(position_size, '轻仓')
|
||||
|
||||
# 标题和颜色 - 区分加密货币/美股/港股
|
||||
is_market_order = entry_type == 'market'
|
||||
market_badge = '【现价】' if is_market_order else ''
|
||||
|
||||
# 识别市场类型
|
||||
if agent_type == 'crypto':
|
||||
market_tag = '[加密货币] '
|
||||
elif symbol.endswith('.HK'):
|
||||
market_tag = '[港股] '
|
||||
else:
|
||||
market_tag = '[美股] '
|
||||
|
||||
# 构建带名称的股票显示
|
||||
symbol_display = f"{stock_name}({symbol})" if stock_name else symbol
|
||||
|
||||
if signal['action'] == 'buy':
|
||||
title = f"🟢 {market_tag}{symbol_display} {signal_type}做多信号 {market_badge}"
|
||||
color = "green"
|
||||
else:
|
||||
title = f"🔴 {market_tag}{symbol_display} {signal_type}做空信号 {market_badge}"
|
||||
color = "red"
|
||||
|
||||
# 计算风险收益比
|
||||
entry = signal.get('entry_price', 0)
|
||||
sl = signal.get('stop_loss', 0)
|
||||
tp = signal.get('take_profit', 0)
|
||||
sl_percent = ((sl - entry) / entry * 100) if entry else 0
|
||||
tp_percent = ((tp - entry) / entry * 100) if entry else 0
|
||||
|
||||
# 构建内容
|
||||
content_lines = [
|
||||
f"{action_icon} **操作**: {action}",
|
||||
f"{entry_type_icon} **入场方式**: {entry_type_text}",
|
||||
f"{position_icon} **仓位**: {position_text} | 📈 信心度: **{confidence}%**",
|
||||
f"⭐ **等级**: {grade} {grade_icon}",
|
||||
f"",
|
||||
f"💰 **入场价**: ${entry:,.2f}",
|
||||
f"🛑 **止损价**: ${sl:,.2f} ({sl_percent:+.1f}%)",
|
||||
f"🎯 **止盈价**: ${tp:,.2f} ({tp_percent:+.1f}%)",
|
||||
f"",
|
||||
f"📝 **分析理由**:",
|
||||
f"{signal.get('reason', '无')}",
|
||||
]
|
||||
|
||||
# 添加关键因素(如果有)
|
||||
key_factors = signal.get('key_factors')
|
||||
if key_factors and isinstance(key_factors, list):
|
||||
content_lines.append("")
|
||||
content_lines.append("**关键因素**:")
|
||||
for factor in key_factors[:5]:
|
||||
content_lines.append(f"- {factor}")
|
||||
|
||||
# 添加风险提示(如果有)
|
||||
risk_warning = signal.get('risk_warning')
|
||||
if risk_warning:
|
||||
content_lines.append("")
|
||||
content_lines.append(f"⚠️ **风险提示**:")
|
||||
content_lines.append(risk_warning)
|
||||
|
||||
content = "\n".join(content_lines)
|
||||
|
||||
return {
|
||||
'title': title,
|
||||
'content': content,
|
||||
'color': color
|
||||
}
|
||||
|
||||
|
||||
# 全局单例
|
||||
_signal_formatter = SignalFormatter()
|
||||
|
||||
|
||||
def get_signal_formatter() -> SignalFormatter:
|
||||
"""获取信号格式化工具单例"""
|
||||
return _signal_formatter
|
||||
@ -23,7 +23,7 @@ class AgentInfo:
|
||||
|
||||
def __init__(self, name: str, agent_type: str):
|
||||
self.name = name # Agent 名称
|
||||
self.agent_type = agent_type # Agent 类型 (crypto/stock/smart)
|
||||
self.agent_type = agent_type # Agent 类型
|
||||
self.status = AgentStatus.NOT_INITIALIZED
|
||||
self.start_time: Optional[datetime] = None
|
||||
self.last_activity: Optional[datetime] = None
|
||||
|
||||
@ -1,103 +0,0 @@
|
||||
"""
|
||||
验证工具模块
|
||||
提供各种数据验证功能
|
||||
"""
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def validate_stock_code(code: str) -> bool:
|
||||
"""
|
||||
验证股票代码格式
|
||||
|
||||
A股代码格式:
|
||||
- 上海:6开头,6位数字
|
||||
- 深圳:0/3开头,6位数字
|
||||
- 创业板:3开头,6位数字
|
||||
- 科创板:688开头,6位数字
|
||||
|
||||
Args:
|
||||
code: 股票代码
|
||||
|
||||
Returns:
|
||||
是否有效
|
||||
"""
|
||||
if not code:
|
||||
return False
|
||||
|
||||
# 移除可能的后缀(如.SH, .SZ)
|
||||
code = code.split('.')[0]
|
||||
|
||||
# 检查是否为6位数字
|
||||
if not re.match(r'^\d{6}$', code):
|
||||
return False
|
||||
|
||||
# 检查首位数字
|
||||
first_digit = code[0]
|
||||
if first_digit in ['0', '3', '6']:
|
||||
return True
|
||||
|
||||
# 检查科创板
|
||||
if code.startswith('688'):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def normalize_stock_code(code: str) -> Optional[str]:
|
||||
"""
|
||||
标准化股票代码,添加市场后缀
|
||||
|
||||
Args:
|
||||
code: 股票代码
|
||||
|
||||
Returns:
|
||||
标准化后的代码(如600000.SH)或None
|
||||
"""
|
||||
if not validate_stock_code(code):
|
||||
return None
|
||||
|
||||
# 移除已有后缀
|
||||
code = code.split('.')[0]
|
||||
|
||||
# 添加市场后缀
|
||||
if code.startswith('6'):
|
||||
return f"{code}.SH" # 上海
|
||||
elif code.startswith(('0', '3')):
|
||||
return f"{code}.SZ" # 深圳
|
||||
elif code.startswith('688'):
|
||||
return f"{code}.SH" # 科创板
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def validate_date_format(date_str: str) -> bool:
|
||||
"""
|
||||
验证日期格式(YYYYMMDD)
|
||||
|
||||
Args:
|
||||
date_str: 日期字符串
|
||||
|
||||
Returns:
|
||||
是否有效
|
||||
"""
|
||||
if not date_str:
|
||||
return False
|
||||
|
||||
# 检查格式
|
||||
if not re.match(r'^\d{8}$', date_str):
|
||||
return False
|
||||
|
||||
# 简单的日期范围检查
|
||||
year = int(date_str[:4])
|
||||
month = int(date_str[4:6])
|
||||
day = int(date_str[6:8])
|
||||
|
||||
if year < 1990 or year > 2100:
|
||||
return False
|
||||
if month < 1 or month > 12:
|
||||
return False
|
||||
if day < 1 or day > 31:
|
||||
return False
|
||||
|
||||
return True
|
||||
@ -1,70 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
检查 Tushare ths_daily API 返回的数据字段
|
||||
"""
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from app.config import get_settings
|
||||
from app.astock_agent.tushare_client import get_tushare_client
|
||||
|
||||
|
||||
async def check_api_fields():
|
||||
"""检查API字段"""
|
||||
print("\n" + "=" * 80)
|
||||
print("🔍 检查 ths_daily API 返回字段")
|
||||
print("=" * 80)
|
||||
|
||||
settings = get_settings()
|
||||
ts_client = get_tushare_client(settings.tushare_token)
|
||||
|
||||
# 获取智能电网板块
|
||||
sectors_df = ts_client.get_concept_sectors()
|
||||
smart_grid = sectors_df[sectors_df['name'] == '智能电网']
|
||||
|
||||
if smart_grid.empty:
|
||||
print("未找到智能电网板块")
|
||||
return
|
||||
|
||||
ts_code = smart_grid.iloc[0]['ts_code']
|
||||
print(f"\n板块代码: {ts_code}")
|
||||
|
||||
today = datetime.now().strftime('%Y%m%d')
|
||||
yesterday = (datetime.now() - timedelta(days=10)).strftime('%Y%m%d')
|
||||
|
||||
daily_df = ts_client.pro.ths_daily(
|
||||
ts_code=ts_code,
|
||||
start_date=yesterday,
|
||||
end_date=today
|
||||
)
|
||||
|
||||
if daily_df.empty:
|
||||
print("未获取到数据")
|
||||
return
|
||||
|
||||
print(f"\n获取到 {len(daily_df)} 条数据")
|
||||
print("\n数据列:")
|
||||
print(daily_df.columns.tolist())
|
||||
|
||||
print("\n最近3天的数据:")
|
||||
print(daily_df.tail(3).to_string())
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("\n字段分析:")
|
||||
for col in daily_df.columns:
|
||||
print(f" {col}: {daily_df[col].dtype}")
|
||||
if col in ['volume', 'amount', 'vol', 'amt']:
|
||||
print(f" 最新值: {daily_df[col].iloc[-1]}")
|
||||
print(f" 前一日: {daily_df[col].iloc[-2] if len(daily_df) > 1 else 'N/A'}")
|
||||
|
||||
|
||||
async def main():
|
||||
await check_api_fields()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@ -1,225 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
A股短期题材选股 - 调试版本
|
||||
用于诊断为什么没有选出股票
|
||||
"""
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# 添加项目根目录到Python路径
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from app.utils.logger import logger
|
||||
from app.config import get_settings
|
||||
from app.astock_agent.tushare_client import get_tushare_client
|
||||
|
||||
|
||||
async def debug_tushare_connection():
|
||||
"""测试Tushare连接"""
|
||||
print("\n" + "=" * 60)
|
||||
print("📊 测试1: Tushare连接测试")
|
||||
print("=" * 60)
|
||||
|
||||
try:
|
||||
settings = get_settings()
|
||||
print(f"Token配置: {'已配置' if settings.tushare_token else '未配置'}")
|
||||
|
||||
if not settings.tushare_token:
|
||||
print("❌ Tushare Token未配置,请在.env文件中设置TUSHARE_TOKEN")
|
||||
return False
|
||||
|
||||
ts_client = get_tushare_client(settings.tushare_token)
|
||||
print(f"✅ Tushare客户端初始化成功")
|
||||
|
||||
# 测试基本API调用
|
||||
print("\n测试API调用...")
|
||||
|
||||
# 测试获取概念板块
|
||||
sectors_df = ts_client.get_concept_sectors()
|
||||
print(f"概念板块数量: {len(sectors_df)}")
|
||||
if not sectors_df.empty:
|
||||
print(f"示例板块: {sectors_df.head(3)['name'].tolist()}")
|
||||
else:
|
||||
print("❌ 无法获取概念板块列表")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Tushare连接失败: {e}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
return False
|
||||
|
||||
|
||||
async def debug_hot_sectors():
|
||||
"""测试异动板块获取"""
|
||||
print("\n" + "=" * 60)
|
||||
print("📊 测试2: 异动板块检测")
|
||||
print("=" * 60)
|
||||
|
||||
try:
|
||||
settings = get_settings()
|
||||
ts_client = get_tushare_client(settings.tushare_token)
|
||||
|
||||
sectors_df = ts_client.get_concept_sectors()
|
||||
print(f"总概念板块数: {len(sectors_df)}")
|
||||
|
||||
# 检查异动板块
|
||||
today = datetime.now().strftime('%Y%m%d')
|
||||
print(f"当前日期: {today}")
|
||||
|
||||
# 检查是否是交易日
|
||||
weekday = datetime.now().weekday()
|
||||
if weekday >= 5:
|
||||
print(f"⚠️ 当前是周末(周{weekday}),可能没有最新数据")
|
||||
else:
|
||||
print(f"✅ 当前是工作日(周{weekday})")
|
||||
|
||||
# 手动检查几个热门板块
|
||||
print("\n检查前10个板块的行情...")
|
||||
check_count = min(10, len(sectors_df))
|
||||
|
||||
hot_count = 0
|
||||
for idx, row in sectors_df.head(check_count).iterrows():
|
||||
ts_code = row['ts_code']
|
||||
name = row['name']
|
||||
|
||||
try:
|
||||
yesterday = (datetime.now() - timedelta(days=10)).strftime('%Y%m%d')
|
||||
daily_df = ts_client.pro.ths_daily(
|
||||
ts_code=ts_code,
|
||||
start_date=yesterday,
|
||||
end_date=today
|
||||
)
|
||||
|
||||
if not daily_df.empty:
|
||||
latest = daily_df.sort_values('trade_date').iloc[-1]
|
||||
change_pct = float(latest.get('pct_change', 0))
|
||||
trade_date = str(latest.get('trade_date', ''))
|
||||
|
||||
status = "🔥" if change_pct >= 2.0 else "📊"
|
||||
print(f" {status} {name}: {change_pct:+.2f}% (日期: {trade_date})")
|
||||
|
||||
if change_pct >= 2.0:
|
||||
hot_count += 1
|
||||
else:
|
||||
print(f" ⚠️ {name}: 无数据")
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ {name}: 查询失败 ({e})")
|
||||
|
||||
print(f"\n找到 {hot_count} 个涨幅≥2%的板块")
|
||||
|
||||
if hot_count == 0:
|
||||
print("\n⚠️ 没有找到符合条件的异动板块,可能原因:")
|
||||
print(" 1. 当前不是交易日(周末或节假日)")
|
||||
print(" 2. 盘中时段数据未更新")
|
||||
print(" 3. 市场整体表现平淡")
|
||||
|
||||
return hot_count > 0
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 异动板块检测失败: {e}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
return False
|
||||
|
||||
|
||||
async def debug_stock_screening():
|
||||
"""测试个股筛选"""
|
||||
print("\n" + "=" * 60)
|
||||
print("📊 测试3: 个股筛选条件分析")
|
||||
print("=" * 60)
|
||||
|
||||
print("\n筛选条件回顾:")
|
||||
print(" 1. 市值: 50-500亿")
|
||||
print(" 2. 换手率: 3%-15%")
|
||||
print(" 3. 涨跌幅: -5% 到 +8%")
|
||||
print(" 4. MA多头排列: MA5 > MA10 > MA20")
|
||||
print(" 5. 量能配合: 量比 > 1.2")
|
||||
print(" 6. 20日动量 > 0")
|
||||
print(" 7. 距离高点回撤 < 15%")
|
||||
|
||||
print("\n⚠️ 如果没有选出股票,可能是因为:")
|
||||
print(" 1. 市场整体不符合技术形态(没有MA多头排列的股票)")
|
||||
print(" 2. 筛选条件较严格(可以尝试放宽参数)")
|
||||
print(" 3. 数据时间窗口问题(需要30天以上历史数据)")
|
||||
|
||||
# 建议放宽的条件
|
||||
print("\n建议放宽的参数(在当前市场环境下):")
|
||||
print(" - 换手率: 1%-15% (降低下限)")
|
||||
print(" - 涨跌幅: -7% 到 +10% (扩大范围)")
|
||||
print(" - 市值: 30-500亿 (降低下限)")
|
||||
|
||||
|
||||
async def debug_selector_run():
|
||||
"""尝试运行选股器并显示详细信息"""
|
||||
print("\n" + "=" * 60)
|
||||
print("📊 测试4: 运行选股器(详细日志)")
|
||||
print("=" * 60)
|
||||
|
||||
try:
|
||||
from app.astock_agent.short_term_thematic_selector import get_thematic_selector
|
||||
|
||||
settings = get_settings()
|
||||
ts_client = get_tushare_client(settings.tushare_token)
|
||||
selector = get_thematic_selector(ts_client)
|
||||
|
||||
# 运行选股,启用详细日志
|
||||
import logging
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
result = selector.select_stocks(max_stocks=10)
|
||||
|
||||
print(f"\n选股结果: {result['total_stocks']} 只")
|
||||
|
||||
if result['total_stocks'] == 0:
|
||||
print("\n❌ 未选出股票")
|
||||
print("\n请查看上方详细日志,分析哪个环节过滤掉了股票")
|
||||
else:
|
||||
print("\n✅ 选股成功!")
|
||||
print(selector.format_output_text(result))
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 选股器运行失败: {e}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
|
||||
|
||||
async def main():
|
||||
"""主函数"""
|
||||
print("\n" + "=" * 60)
|
||||
print("🔍 A股选股器诊断工具")
|
||||
print("=" * 60)
|
||||
|
||||
# 运行所有测试
|
||||
step1_ok = await debug_tushare_connection()
|
||||
if not step1_ok:
|
||||
print("\n❌ Tushare连接失败,请检查配置")
|
||||
return 1
|
||||
|
||||
step2_ok = await debug_hot_sectors()
|
||||
if not step2_ok:
|
||||
print("\n⚠️ 没有找到异动板块,这是正常的(取决于市场情况)")
|
||||
|
||||
await debug_stock_screening()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("📋 诊断完成")
|
||||
print("=" * 60)
|
||||
print("\n如果所有测试通过但未选出股票,说明当前市场条件不符合策略要求。")
|
||||
print("这是正常的,策略不会在市场条件不符合时强行选股。")
|
||||
print("\n建议:")
|
||||
print(" 1. 在交易日15:00后运行(确保有完整数据)")
|
||||
print(" 2. 或者放宽筛选条件以适应当前市场环境")
|
||||
print("")
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit_code = asyncio.run(main())
|
||||
sys.exit(exit_code)
|
||||
62
backend/diagnose.sh
Executable file → Normal file
62
backend/diagnose.sh
Executable file → Normal file
@ -1,100 +1,70 @@
|
||||
#!/bin/bash
|
||||
|
||||
# 诊断脚本 - 检查系统配置
|
||||
set -e
|
||||
|
||||
echo "================================"
|
||||
echo "系统诊断"
|
||||
echo "Crypto Agent 诊断"
|
||||
echo "================================"
|
||||
echo ""
|
||||
|
||||
cd /Users/aaron/source_code/Stock_Agent/backend
|
||||
|
||||
# 1. 检查虚拟环境
|
||||
echo "1. 检查虚拟环境..."
|
||||
if [ -d "venv" ]; then
|
||||
echo " ✓ 虚拟环境存在"
|
||||
source venv/bin/activate
|
||||
python_version=$(python --version 2>&1)
|
||||
echo " ✓ $python_version"
|
||||
echo " ✓ 虚拟环境存在"
|
||||
echo " ✓ $(python --version 2>&1)"
|
||||
else
|
||||
echo " ❌ 虚拟环境不存在"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 2. 检查.env文件
|
||||
echo ""
|
||||
echo "2. 检查配置文件..."
|
||||
if [ -f "../.env" ]; then
|
||||
echo " ✓ .env文件存在(项目根目录)"
|
||||
elif [ -f ".env" ]; then
|
||||
echo " ✓ .env文件存在(backend目录)"
|
||||
if [ -f "../.env" ] || [ -f ".env" ]; then
|
||||
echo " ✓ .env 文件存在"
|
||||
else
|
||||
echo " ❌ .env文件不存在"
|
||||
echo " ❌ .env 文件不存在"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 3. 检查依赖包
|
||||
echo ""
|
||||
echo "3. 检查依赖包..."
|
||||
packages=("fastapi" "uvicorn" "tushare" "pandas" "numpy" "sqlalchemy" "pydantic")
|
||||
all_installed=true
|
||||
|
||||
echo "3. 检查关键依赖..."
|
||||
packages=("fastapi" "uvicorn" "pandas" "numpy" "sqlalchemy" "pydantic" "ccxt" "httpx" "aiohttp")
|
||||
for pkg in "${packages[@]}"; do
|
||||
if python -c "import $pkg" 2>/dev/null; then
|
||||
version=$(python -c "import $pkg; print($pkg.__version__)" 2>/dev/null || echo "unknown")
|
||||
echo " ✓ $pkg ($version)"
|
||||
echo " ✓ $pkg"
|
||||
else
|
||||
echo " ❌ $pkg 未安装"
|
||||
all_installed=false
|
||||
fi
|
||||
done
|
||||
|
||||
if [ "$all_installed" = false ]; then
|
||||
echo ""
|
||||
echo "请运行: pip install -r requirements.txt"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 4. 测试配置加载
|
||||
echo ""
|
||||
echo "4. 测试配置加载..."
|
||||
python -c "
|
||||
from app.config import get_settings
|
||||
settings = get_settings()
|
||||
print(f' ✓ 配置加载成功')
|
||||
print(f' - Tushare Token: {'已配置 (' + settings.tushare_token[:10] + '...)' if settings.tushare_token else '❌ 未配置'}')
|
||||
print(f' - 智谱AI Key: {'已配置 (' + settings.zhipuai_api_key[:10] + '...)' if settings.zhipuai_api_key else '❌ 未配置'}')
|
||||
" 2>&1
|
||||
print(' ✓ 配置加载成功')
|
||||
print(f' - DeepSeek Key: {'已配置' if settings.deepseek_api_key else '未配置'}')
|
||||
print(f' - 智谱AI Key: {'已配置' if settings.zhipuai_api_key else '未配置'}')
|
||||
print(f' - Bitget 实盘: {'开启' if settings.bitget_trading_enabled else '关闭'}')
|
||||
"
|
||||
|
||||
# 5. 测试模块导入
|
||||
echo ""
|
||||
echo "5. 测试模块导入..."
|
||||
modules=("app.models.database" "app.services.cache_service" "app.services.tushare_service" "app.agent.core")
|
||||
|
||||
modules=(\"app.models.database\" \"app.services.cache_service\" \"app.services.bitget_trading_api_sdk\" \"app.crypto_agent.crypto_agent\")
|
||||
for module in "${modules[@]}"; do
|
||||
if python -c "import $module" 2>/dev/null; then
|
||||
echo " ✓ $module"
|
||||
else
|
||||
echo " ❌ $module 导入失败"
|
||||
python -c "import $module" 2>&1 | head -5
|
||||
fi
|
||||
done
|
||||
|
||||
# 6. 检查端口占用
|
||||
echo ""
|
||||
echo "6. 检查端口占用..."
|
||||
if lsof -i :8000 >/dev/null 2>&1; then
|
||||
echo " ⚠ 端口8000已被占用"
|
||||
lsof -i :8000
|
||||
else
|
||||
echo " ✓ 端口8000可用"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "================================"
|
||||
echo "诊断完成"
|
||||
echo "================================"
|
||||
echo ""
|
||||
echo "如果所有检查都通过,可以运行:"
|
||||
echo " ./start.sh"
|
||||
echo ""
|
||||
|
||||
@ -1,203 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
A股短期题材选股 - 详细诊断版本
|
||||
显示每个股票的筛选过程
|
||||
"""
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from app.utils.logger import logger
|
||||
from app.config import get_settings
|
||||
from app.astock_agent.tushare_client import get_tushare_client
|
||||
|
||||
|
||||
async def diagnose_sector_stocks():
|
||||
"""诊断板块成分股的筛选过程"""
|
||||
print("\n" + "=" * 60)
|
||||
print("🔍 A股选股详细诊断")
|
||||
print("=" * 60)
|
||||
|
||||
settings = get_settings()
|
||||
ts_client = get_tushare_client(settings.tushare_token)
|
||||
|
||||
# 1. 测试获取一个板块的成分股
|
||||
print("\n【测试】获取智能电网板块成分股...")
|
||||
try:
|
||||
# 使用智能电网板块代码(与选股器一致)
|
||||
sector_code = "885311.TI"
|
||||
members_df = ts_client.get_sector_members(sector_code)
|
||||
|
||||
if members_df.empty:
|
||||
print("❌ 无法获取板块成分股")
|
||||
return
|
||||
|
||||
stock_codes = members_df['con_code'].tolist()[:10] # 只测试前10只
|
||||
print(f"✓ 获取到 {len(stock_codes)} 只成分股(测试前10只)")
|
||||
print(f"股票代码: {stock_codes}")
|
||||
|
||||
# 2. 获取这些股票的实时行情
|
||||
print("\n【测试】获取实时行情...")
|
||||
today = datetime.now().strftime('%Y%m%d')
|
||||
yesterday = (datetime.now() - timedelta(days=10)).strftime('%Y%m%d')
|
||||
|
||||
all_stocks_data = []
|
||||
for stock_code in stock_codes:
|
||||
try:
|
||||
daily_df = ts_client.pro.daily(
|
||||
ts_code=stock_code,
|
||||
start_date=yesterday,
|
||||
end_date=today
|
||||
)
|
||||
|
||||
if daily_df.empty:
|
||||
print(f" ⚠️ {stock_code}: 无历史数据")
|
||||
continue
|
||||
|
||||
daily_df = daily_df.sort_values('trade_date')
|
||||
latest = daily_df.iloc[-1]
|
||||
|
||||
stock_info = {
|
||||
'ts_code': stock_code,
|
||||
'name': latest.get('name', ''),
|
||||
'close': float(latest['close']),
|
||||
'pct_chg': float(latest['pct_chg']),
|
||||
'vol': float(latest['vol']),
|
||||
'amount': float(latest['amount']) * 1000,
|
||||
'trade_date': str(latest['trade_date'])
|
||||
}
|
||||
all_stocks_data.append(stock_info)
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ {stock_code}: 获取失败 - {e}")
|
||||
continue
|
||||
|
||||
print(f"\n✓ 成功获取 {len(all_stocks_data)} 只股票的行情")
|
||||
|
||||
# 3. 获取每日指标
|
||||
print("\n【测试】获取每日指标(换手率等)...")
|
||||
basic_df = ts_client.pro.daily_basic(
|
||||
ts_code=','.join(stock_codes),
|
||||
trade_date=all_stocks_data[0]['trade_date'],
|
||||
fields='ts_code,trade_date,turnover_rate,pe,pb'
|
||||
)
|
||||
|
||||
if basic_df.empty:
|
||||
print("⚠️ 无法获取每日指标数据")
|
||||
else:
|
||||
print(f"✓ 获取到 {len(basic_df)} 只股票的每日指标")
|
||||
|
||||
# 4. 逐个检查筛选条件
|
||||
print("\n【测试】逐个检查筛选条件...")
|
||||
print("=" * 80)
|
||||
|
||||
for stock_info in all_stocks_data:
|
||||
stock_code = stock_info['ts_code']
|
||||
name = stock_info['name']
|
||||
close = stock_info['close']
|
||||
pct_chg = stock_info['pct_chg']
|
||||
vol = stock_info['vol']
|
||||
amount = stock_info['amount']
|
||||
|
||||
print(f"\n🔍 {name}({stock_code}):")
|
||||
print(f" 日期: {stock_info['trade_date']}")
|
||||
print(f" 现价: ¥{close:.2f}, 涨跌幅: {pct_chg:+.2f}%")
|
||||
|
||||
# 检查1: ST股票
|
||||
if 'ST' in name or '退' in name:
|
||||
print(f" ❌ ST/退市股,被过滤")
|
||||
continue
|
||||
print(f" ✓ 不是ST/退市股")
|
||||
|
||||
# 检查2: 换手率
|
||||
basic_row = basic_df[basic_df['ts_code'] == stock_code]
|
||||
if not basic_row.empty:
|
||||
turnover = float(basic_row.iloc[0].get('turnover_rate', 0))
|
||||
print(f" 换手率: {turnover:.2f}%")
|
||||
if 1.0 <= turnover <= 20.0:
|
||||
print(f" ✓ 换手率符合")
|
||||
else:
|
||||
print(f" ❌ 换手率不符合(需要1%-20%)")
|
||||
continue
|
||||
else:
|
||||
print(f" ⚠️ 无换手率数据")
|
||||
turnover = 0
|
||||
|
||||
# 检查3: MA多头排列
|
||||
try:
|
||||
start_date = (datetime.now() - timedelta(days=60)).strftime('%Y%m%d')
|
||||
daily_df = ts_client.pro.daily(
|
||||
ts_code=stock_code,
|
||||
start_date=start_date,
|
||||
end_date=today
|
||||
)
|
||||
|
||||
if daily_df.empty or len(daily_df) < 30:
|
||||
print(f" ❌ 历史数据不足(需要30天以上),无法计算MA")
|
||||
continue
|
||||
|
||||
daily_df = daily_df.sort_values('trade_date').reset_index(drop=True)
|
||||
close_series = daily_df['close']
|
||||
vol_series = daily_df['vol']
|
||||
|
||||
ma5 = close_series.rolling(window=5).mean().iloc[-1]
|
||||
ma10 = close_series.rolling(window=10).mean().iloc[-1]
|
||||
ma20 = close_series.rolling(window=20).mean().iloc[-1]
|
||||
|
||||
print(f" MA5: ¥{ma5:.2f}, MA10: ¥{ma10:.2f}, MA20: ¥{ma20:.2f}")
|
||||
|
||||
if ma5 > ma20:
|
||||
print(f" ✓ MA趋势符合(MA5 > MA20)")
|
||||
else:
|
||||
print(f" ❌ MA趋势不符合(需要 MA5 > MA20)")
|
||||
continue
|
||||
|
||||
# 检查4: 量能
|
||||
ma5_vol = vol_series.rolling(window=5).mean().iloc[-1]
|
||||
volume_ratio = vol / ma5_vol if ma5_vol > 0 else 1
|
||||
print(f" 量比: {volume_ratio:.2f}")
|
||||
|
||||
if volume_ratio >= 0.7:
|
||||
print(f" ✓ 量能符合(≥0.7)")
|
||||
else:
|
||||
print(f" ❌ 量能不足(量比需要≥0.7)")
|
||||
continue
|
||||
|
||||
# 检查5: 市值
|
||||
if turnover > 0:
|
||||
market_cap = amount / (turnover / 100)
|
||||
market_cap_yi = market_cap / 100000000
|
||||
print(f" 市值: {market_cap_yi:.2f}亿")
|
||||
|
||||
if 30 <= market_cap_yi <= 1000:
|
||||
print(f" ✓ 市值符合")
|
||||
else:
|
||||
print(f" ❌ 市值不符合(需要30-1000亿)")
|
||||
continue
|
||||
|
||||
print(f" ✅✅✅ {name}({stock_code}) 通过所有筛选条件!")
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ 计算技术指标失败: {e}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 诊断失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
async def main():
|
||||
"""主函数"""
|
||||
await diagnose_sector_stocks()
|
||||
print("\n" + "=" * 60)
|
||||
print("诊断完成")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@ -1,150 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
诊断板块资金异动检测
|
||||
检查为什么没有找到异动板块
|
||||
"""
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from app.utils.logger import logger
|
||||
from app.config import get_settings
|
||||
from app.astock_agent.tushare_client import get_tushare_client
|
||||
|
||||
|
||||
async def diagnose_sector_detection():
|
||||
"""诊断板块检测"""
|
||||
print("\n" + "=" * 80)
|
||||
print("🔍 板块资金异动诊断")
|
||||
print("=" * 80)
|
||||
|
||||
settings = get_settings()
|
||||
ts_client = get_tushare_client(settings.tushare_token)
|
||||
|
||||
# 获取热门概念板块
|
||||
print("\n【第一步】获取热门概念板块...")
|
||||
hot_concept_sectors = [
|
||||
'人工智能', '新能源汽车', '芯片', '半导体', '5G',
|
||||
'智能电网', '物联网', '云计算', '大数据', '区块链'
|
||||
]
|
||||
|
||||
sectors_df = ts_client.get_concept_sectors()
|
||||
|
||||
# 找到热门板块
|
||||
hot_sectors_codes = []
|
||||
for hot_name in hot_concept_sectors:
|
||||
matches = sectors_df[sectors_df['name'].str.contains(hot_name, na=False)]
|
||||
if not matches.empty:
|
||||
for _, row in matches.iterrows():
|
||||
hot_sectors_codes.append({
|
||||
'ts_code': row['ts_code'],
|
||||
'name': row['name']
|
||||
})
|
||||
|
||||
print(f"✓ 找到 {len(hot_sectors_codes)} 个热门板块")
|
||||
|
||||
# 检查这些板块的资金异动
|
||||
print("\n【第二步】检查板块资金异动(量比、额比、涨幅)...")
|
||||
print("=" * 80)
|
||||
|
||||
today = datetime.now().strftime('%Y%m%d')
|
||||
yesterday = (datetime.now() - timedelta(days=10)).strftime('%Y%m%d')
|
||||
|
||||
# 宽松模式的阈值
|
||||
vol_threshold = 1.5
|
||||
amount_threshold = 1.3
|
||||
min_change = 0.5
|
||||
|
||||
qualified_sectors = []
|
||||
|
||||
for sector_info in hot_sectors_codes[:15]: # 只检查前15个
|
||||
ts_code = sector_info['ts_code']
|
||||
name = sector_info['name']
|
||||
|
||||
try:
|
||||
daily_df = ts_client.pro.ths_daily(
|
||||
ts_code=ts_code,
|
||||
start_date=yesterday,
|
||||
end_date=today
|
||||
)
|
||||
|
||||
if daily_df.empty or len(daily_df) < 2:
|
||||
print(f" ⚠️ {name}: 数据不足")
|
||||
continue
|
||||
|
||||
daily_df = daily_df.sort_values('trade_date')
|
||||
latest = daily_df.iloc[-1]
|
||||
prev = daily_df.iloc[-2]
|
||||
|
||||
latest_vol = float(latest.get('vol', 0))
|
||||
latest_avg_price = float(latest.get('avg_price', 0))
|
||||
latest_amount = latest_vol * latest_avg_price * 100 # 估算成交额
|
||||
|
||||
prev_vol = float(prev.get('vol', 0))
|
||||
prev_avg_price = float(prev.get('avg_price', 0))
|
||||
prev_amount = prev_vol * prev_avg_price * 100
|
||||
change_pct = float(latest.get('pct_change', 0))
|
||||
|
||||
# 计算量比和额比
|
||||
vol_ratio = latest_vol / prev_vol if prev_vol > 0 else 0
|
||||
amount_ratio = latest_amount / prev_amount if prev_amount > 0 else 0
|
||||
|
||||
# 判断是否符合条件
|
||||
is_volume_surge = vol_ratio >= vol_threshold
|
||||
is_amount_surge = amount_ratio >= amount_threshold
|
||||
has_min_change = change_pct >= min_change
|
||||
is_qualified = (is_volume_surge or is_amount_surge) and has_min_change
|
||||
|
||||
# 显示结果
|
||||
status = "✅" if is_qualified else "❌"
|
||||
vol_status = "🔥" if is_volume_surge else "📊"
|
||||
amount_status = "🔥" if is_amount_surge else "📊"
|
||||
change_status = "✓" if has_min_change else "✗"
|
||||
|
||||
print(f" {status} {name}")
|
||||
print(f" 涨跌幅: {change_pct:+.2f}% {change_status}")
|
||||
print(f" 量比: {vol_ratio:.2f}x {vol_status} (需要≥{vol_threshold})")
|
||||
print(f" 额比: {amount_ratio:.2f}x {amount_status} (需要≥{amount_threshold})")
|
||||
|
||||
if is_qualified:
|
||||
qualified_sectors.append({
|
||||
'name': name,
|
||||
'change_pct': change_pct,
|
||||
'vol_ratio': vol_ratio,
|
||||
'amount_ratio': amount_ratio
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ {name}: 查询失败 ({e})")
|
||||
continue
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print(f"【结果】找到 {len(qualified_sectors)} 个符合资金异动条件的板块")
|
||||
|
||||
if len(qualified_sectors) == 0:
|
||||
print("\n⚠️ 没有板块符合条件,可能原因:")
|
||||
print(" 1. 市场整体资金流入不足(量比、额比都未达标)")
|
||||
print(" 2. 板块涨幅不够(需要≥0.5%)")
|
||||
print(" 3. 阈值设置过高(当前:量比≥1.5,额比≥1.3)")
|
||||
print("\n建议放宽阈值:")
|
||||
print(" - 量比阈值: 1.5 → 1.2")
|
||||
print(" - 额比阈值: 1.3 → 1.1")
|
||||
print(" - 最小涨幅: 0.5% → 0.3%")
|
||||
else:
|
||||
print("\n✅ 符合条件的板块:")
|
||||
for idx, sector in enumerate(qualified_sectors, 1):
|
||||
print(f" {idx}. {sector['name']}: {sector['change_pct']:+.2f}%, "
|
||||
f"量比{sector['vol_ratio']:.2f}x, 额比{sector['amount_ratio']:.2f}x")
|
||||
|
||||
print("=" * 80)
|
||||
|
||||
|
||||
async def main():
|
||||
await diagnose_sector_detection()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@ -1,6 +1,5 @@
|
||||
"""
|
||||
数据库迁移脚本 - 添加移动止损字段
|
||||
用于为已有的 paper_trading 表添加新字段
|
||||
数据库迁移脚本 - 添加模拟盘移动止损字段
|
||||
"""
|
||||
import sqlite3
|
||||
import os
|
||||
@ -23,7 +22,7 @@ def migrate_database():
|
||||
break
|
||||
|
||||
if not db_path:
|
||||
print("❌ 未找到数据库文件 stock_agent.db")
|
||||
print("❌ 未找到数据库文件")
|
||||
print("请确保在项目根目录或 backend 目录下运行此脚本")
|
||||
return False
|
||||
|
||||
@ -134,7 +133,7 @@ def verify_migration():
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("=" * 60)
|
||||
print("🔄 Stock Agent 数据库迁移工具")
|
||||
print("🔄 Crypto Agent 数据库迁移工具")
|
||||
print("=" * 60)
|
||||
print()
|
||||
|
||||
|
||||
@ -1,21 +1,15 @@
|
||||
fastapi==0.109.0
|
||||
uvicorn[standard]==0.27.0
|
||||
langchain==0.1.0
|
||||
langchain-community==0.0.20
|
||||
zhipuai==2.0.1
|
||||
openai>=1.0.0
|
||||
tushare>=1.4.0
|
||||
sqlalchemy==2.0.25
|
||||
pydantic==2.5.3
|
||||
pydantic-settings==2.1.0
|
||||
python-dotenv==1.0.0
|
||||
slowapi==0.1.9
|
||||
pandas>=2.2.0
|
||||
numpy>=1.26.0
|
||||
python-multipart==0.0.6
|
||||
aiohttp==3.9.1
|
||||
yfinance>=0.2.36
|
||||
pandas-datareader>=0.10.0 # Stooq 数据源支持(美股港股备用)
|
||||
PyJWT==2.8.0
|
||||
tencentcloud-sdk-python==3.0.1100
|
||||
python-jose[cryptography]==3.3.0
|
||||
@ -27,12 +21,4 @@ ccxt>=4.5.45 # 统一交易所API接口,Bitget UTA V3 需要 4.5.45+
|
||||
websockets>=12.0 # WebSocket 支持,用于实时价格更新
|
||||
|
||||
# 新闻智能体依赖
|
||||
feedparser>=6.0.10
|
||||
beautifulsoup4>=4.12.0
|
||||
lxml>=4.9.0
|
||||
|
||||
# A股板块监控依赖
|
||||
akshare>=1.12.0
|
||||
apscheduler>=3.10.0 # 定时任务
|
||||
|
||||
eth-account>=0.10.0
|
||||
|
||||
76
backend/run.sh
Executable file → Normal file
76
backend/run.sh
Executable file → Normal file
@ -1,89 +1,47 @@
|
||||
#!/bin/bash
|
||||
|
||||
# 最终启动检查和启动脚本
|
||||
set -e
|
||||
|
||||
echo "================================"
|
||||
echo "A股AI分析Agent - 最终检查"
|
||||
echo "Crypto Agent - 运行前检查"
|
||||
echo "================================"
|
||||
echo ""
|
||||
|
||||
cd /Users/aaron/source_code/Stock_Agent/backend
|
||||
|
||||
# 激活虚拟环境
|
||||
if [ ! -d "venv" ]; then
|
||||
echo "❌ 虚拟环境不存在,请先运行 ../install.sh"
|
||||
echo "❌ 虚拟环境不存在"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
source venv/bin/activate
|
||||
|
||||
# 快速导入测试
|
||||
echo "1. 测试模块导入..."
|
||||
echo "1. 测试核心模块导入..."
|
||||
python3 << 'EOF'
|
||||
try:
|
||||
# 测试基础模块
|
||||
from app.config import get_settings
|
||||
print(" ✓ 配置模块")
|
||||
|
||||
from app.models.database import Base, Message
|
||||
print(" ✓ 数据库模型")
|
||||
|
||||
from app.services.cache_service import cache_service
|
||||
print(" ✓ 缓存服务")
|
||||
|
||||
from app.services.tushare_service import tushare_service
|
||||
print(" ✓ Tushare服务")
|
||||
|
||||
from app.services.llm_service import llm_service
|
||||
print(" ✓ LLM服务")
|
||||
|
||||
from app.agent.smart_agent import smart_agent
|
||||
print(" ✓ 智能Agent")
|
||||
|
||||
print("\n所有模块导入成功!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ 导入失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
exit(1)
|
||||
from app.config import get_settings
|
||||
from app.models.database import Base, Message
|
||||
from app.services.cache_service import cache_service
|
||||
from app.services.llm_service import llm_service
|
||||
from app.crypto_agent.crypto_agent import get_crypto_agent
|
||||
print(" ✓ 配置模块")
|
||||
print(" ✓ 数据库模型")
|
||||
print(" ✓ 缓存服务")
|
||||
print(" ✓ LLM 服务")
|
||||
print(" ✓ Crypto Agent")
|
||||
EOF
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo ""
|
||||
echo "模块导入失败,请检查错误信息"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 检查配置
|
||||
echo ""
|
||||
echo "2. 检查配置..."
|
||||
python3 << 'EOF'
|
||||
from app.config import get_settings
|
||||
settings = get_settings()
|
||||
|
||||
print(f" Tushare Token: {'✓ 已配置' if settings.tushare_token else '❌ 未配置'}")
|
||||
print(f" DeepSeek Key: {'✓ 已配置' if settings.deepseek_api_key else '❌ 未配置'}")
|
||||
print(f" 智谱AI Key: {'✓ 已配置' if settings.zhipuai_api_key else '❌ 未配置'}")
|
||||
print(f" 数据库: {settings.database_url}")
|
||||
print(f" 监听: {settings.api_host}:{settings.api_port}")
|
||||
|
||||
if not settings.tushare_token:
|
||||
print("\n⚠️ 警告: Tushare Token未配置,数据查询功能将不可用")
|
||||
if not settings.zhipuai_api_key:
|
||||
print("⚠️ 警告: 智谱AI Key未配置,将使用规则模式(无AI分析)")
|
||||
print(f" Bitget 实盘: {'开启' if settings.bitget_trading_enabled else '关闭'}")
|
||||
EOF
|
||||
|
||||
echo ""
|
||||
echo "================================"
|
||||
echo "检查完成!准备启动..."
|
||||
echo "================================"
|
||||
echo ""
|
||||
echo "访问地址:"
|
||||
echo " 前端: http://localhost:8000"
|
||||
echo " API: http://localhost:8000/docs"
|
||||
echo ""
|
||||
echo "按 Ctrl+C 停止服务"
|
||||
echo ""
|
||||
|
||||
# 启动应用
|
||||
echo "启动应用..."
|
||||
python3 -m app.main
|
||||
|
||||
@ -1,60 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
A股短期题材选股 - 手动执行脚本
|
||||
"""
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 添加项目根目录到Python路径
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from app.utils.logger import logger
|
||||
from app.astock_agent.astock_agent import get_astock_agent
|
||||
|
||||
|
||||
async def main():
|
||||
"""手动执行选股"""
|
||||
# 解析命令行参数
|
||||
strict_mode = '--strict' in sys.argv or '-s' in sys.argv
|
||||
|
||||
try:
|
||||
print("\n" + "=" * 60)
|
||||
mode_text = "严格模式" if strict_mode else "宽松模式(适应当前市场)"
|
||||
print(f"📊 A股短期题材选股 - 手动执行 [{mode_text}]")
|
||||
print("=" * 60)
|
||||
|
||||
if not strict_mode:
|
||||
print("\n💡 使用宽松模式:")
|
||||
print(" - 市值: 30-1000亿(原50-500亿)")
|
||||
print(" - 换手率: 1%-20%(原3%-15%)")
|
||||
print(" - 板块涨幅: ≥1.5%(原2%)")
|
||||
print(" - 量比: ≥1.0(原1.2)")
|
||||
print("\n使用 --strict 或 -s 参数启用严格模式")
|
||||
|
||||
# 获取智能体实例
|
||||
agent = get_astock_agent()
|
||||
|
||||
# 设置模式
|
||||
agent.selector.strict_mode = strict_mode
|
||||
|
||||
# 执行选股
|
||||
result = await agent.run_once()
|
||||
|
||||
# 输出结果
|
||||
print("\n" + "=" * 60)
|
||||
print(agent.selector.format_output_text(result))
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
return 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"选股执行失败: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit_code = asyncio.run(main())
|
||||
sys.exit(exit_code)
|
||||
41
backend/start.sh
Executable file → Normal file
41
backend/start.sh
Executable file → Normal file
@ -1,64 +1,33 @@
|
||||
#!/bin/bash
|
||||
|
||||
# A股AI分析Agent系统 - 启动脚本(改进版)
|
||||
set -e
|
||||
|
||||
echo "================================"
|
||||
echo "A股AI分析Agent系统"
|
||||
echo "Crypto Agent Backend"
|
||||
echo "================================"
|
||||
echo ""
|
||||
|
||||
# 检查.env文件
|
||||
if [ ! -f "../.env" ] && [ ! -f ".env" ]; then
|
||||
echo "❌ 错误: 未找到.env配置文件"
|
||||
echo ""
|
||||
echo "请先配置环境变量:"
|
||||
echo " cd .."
|
||||
echo " cp .env.example .env"
|
||||
echo " # 编辑.env文件,填写API密钥"
|
||||
echo "❌ 未找到 .env 配置文件"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 检查虚拟环境
|
||||
if [ ! -d "venv" ]; then
|
||||
echo "❌ 错误: 虚拟环境不存在"
|
||||
echo ""
|
||||
echo "请先运行安装脚本:"
|
||||
echo " cd .."
|
||||
echo " ./install.sh"
|
||||
echo "❌ 未找到 backend/venv"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 激活虚拟环境
|
||||
echo "激活虚拟环境..."
|
||||
source venv/bin/activate
|
||||
|
||||
# 检查Python版本
|
||||
python_version=$(python --version 2>&1 | awk '{print $2}')
|
||||
echo "Python版本: $python_version"
|
||||
|
||||
# 显示配置信息
|
||||
echo ""
|
||||
echo "配置信息:"
|
||||
python -c "
|
||||
from app.config import get_settings
|
||||
settings = get_settings()
|
||||
print(f' Tushare Token: {'已配置' if settings.tushare_token else '未配置'}')
|
||||
print(f' DeepSeek Key: {'已配置' if settings.deepseek_api_key else '未配置'}')
|
||||
print(f' 智谱AI Key: {'已配置' if settings.zhipuai_api_key else '未配置'}')
|
||||
print(f' 数据库: {settings.database_url}')
|
||||
print(f' 监听地址: {settings.api_host}:{settings.api_port}')
|
||||
"
|
||||
|
||||
echo ""
|
||||
echo "================================"
|
||||
echo "启动服务..."
|
||||
echo "================================"
|
||||
echo ""
|
||||
echo "访问地址:"
|
||||
echo " 前端界面: http://localhost:8000"
|
||||
echo " API文档: http://localhost:8000/docs"
|
||||
echo ""
|
||||
echo "按 Ctrl+C 停止服务"
|
||||
echo ""
|
||||
|
||||
# 启动应用
|
||||
python -m app.main
|
||||
|
||||
12
backend/test_import.sh
Executable file → Normal file
12
backend/test_import.sh
Executable file → Normal file
@ -1,12 +1,10 @@
|
||||
#!/bin/bash
|
||||
# 测试应用启动
|
||||
|
||||
set -e
|
||||
|
||||
cd /Users/aaron/source_code/Stock_Agent/backend
|
||||
|
||||
# 激活虚拟环境
|
||||
source venv/bin/activate
|
||||
|
||||
# 测试导入
|
||||
echo "测试数据库模型..."
|
||||
python3 -c "from app.models.database import Base, Message; print('✓ 数据库模型导入成功')"
|
||||
|
||||
@ -19,8 +17,8 @@ echo "测试服务..."
|
||||
python3 -c "from app.services.cache_service import cache_service; print('✓ 缓存服务初始化成功')"
|
||||
|
||||
echo ""
|
||||
echo "测试Agent..."
|
||||
python3 -c "from app.agent.smart_agent import smart_agent; print('✓ Agent初始化成功')"
|
||||
echo "测试 Crypto Agent..."
|
||||
python3 -c "from app.crypto_agent.crypto_agent import get_crypto_agent; print('✓ Crypto Agent 导入成功')"
|
||||
|
||||
echo ""
|
||||
echo "所有测试通过!可以启动应用了。"
|
||||
echo "所有测试通过。"
|
||||
|
||||
@ -1,130 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
测试单个股票的筛选逻辑
|
||||
"""
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from app.utils.logger import logger
|
||||
from app.config import get_settings
|
||||
from app.astock_agent.tushare_client import get_tushare_client
|
||||
from app.astock_agent.short_term_thematic_selector import get_thematic_selector
|
||||
|
||||
|
||||
async def test_single_stock():
|
||||
"""测试单个股票"""
|
||||
print("\n" + "=" * 80)
|
||||
print("🔍 测试单个股票筛选")
|
||||
print("=" * 80)
|
||||
|
||||
settings = get_settings()
|
||||
ts_client = get_tushare_client(settings.tushare_token)
|
||||
selector = get_thematic_selector(ts_client)
|
||||
|
||||
# 测试股票代码(从诊断脚本中找到的通过股票)
|
||||
test_stock = "000682.SZ"
|
||||
|
||||
# 获取该股票所属的板块
|
||||
sectors_df = ts_client.get_concept_sectors()
|
||||
smart_grid = sectors_df[sectors_df['name'] == '智能电网']
|
||||
|
||||
if smart_grid.empty:
|
||||
print("未找到智能电网板块")
|
||||
return
|
||||
|
||||
sector_code = smart_grid.iloc[0]['ts_code']
|
||||
sector_name = smart_grid.iloc[0]['name']
|
||||
|
||||
# 获取板块成分股
|
||||
members_df = ts_client.get_sector_members(sector_code)
|
||||
stock_codes = members_df['con_code'].tolist()
|
||||
|
||||
if test_stock not in stock_codes:
|
||||
print(f"{test_stock} 不在智能电网板块中")
|
||||
return
|
||||
|
||||
print(f"\n测试股票: {test_stock}")
|
||||
print(f"所属板块: {sector_name} ({sector_code})")
|
||||
print(f"板块成分股数量: {len(stock_codes)}")
|
||||
print(f"测试股票在板块中的位置: {stock_codes.index(test_stock) + 1}/{len(stock_codes)}")
|
||||
|
||||
# 获取实时行情 - 检查更多股票
|
||||
check_count = min(200, len(stock_codes))
|
||||
print(f"\n获取前 {check_count} 只股票的实时行情...")
|
||||
|
||||
realtime_df = ts_client.get_realtime_data(stock_codes[:check_count])
|
||||
|
||||
if realtime_df.empty:
|
||||
print("无法获取实时行情")
|
||||
return
|
||||
|
||||
print(f"实时行情数据获取成功,共 {len(realtime_df)} 只股票")
|
||||
|
||||
# 检查目标股票是否在行情数据中
|
||||
if test_stock not in realtime_df['ts_code'].values:
|
||||
print(f"❌ {test_stock} 不在行情数据中")
|
||||
print(f"行情数据中的股票: {realtime_df['ts_code'].tolist()[:10]}")
|
||||
return
|
||||
|
||||
stock_row = realtime_df[realtime_df['ts_code'] == test_stock].iloc[0]
|
||||
print(f"\n✓ {test_stock} 行情数据:")
|
||||
print(f" 现价: {stock_row['close']}")
|
||||
print(f" 涨跌幅: {stock_row['pct_chg']}%")
|
||||
print(f" 成交量: {stock_row['vol']}")
|
||||
print(f" 成交额: {stock_row['amount']}千元")
|
||||
|
||||
# 获取每日指标
|
||||
trade_date = realtime_df.iloc[0]['trade_date']
|
||||
basic_df = ts_client.get_stock_daily_basic([test_stock], str(trade_date))
|
||||
|
||||
print(f"\n每日指标数据: {'有' if not basic_df.empty else '无'}")
|
||||
if not basic_df.empty:
|
||||
basic_row = basic_df[basic_df['ts_code'] == test_stock]
|
||||
if not basic_row.empty:
|
||||
print(f" 换手率: {basic_row.iloc[0]['turnover_rate']}%")
|
||||
|
||||
# 调用选股器的内部检查函数
|
||||
print("\n开始筛选检查...")
|
||||
print("=" * 80)
|
||||
|
||||
# 检查所有股票
|
||||
passed_stocks = []
|
||||
for idx, stock_code in enumerate(stock_codes[:check_count], 1):
|
||||
try:
|
||||
result = selector._check_single_stock(
|
||||
stock_code=stock_code,
|
||||
sector_name=sector_name,
|
||||
sector_change=2.77,
|
||||
realtime_df=realtime_df,
|
||||
basic_df=basic_df
|
||||
)
|
||||
|
||||
if result:
|
||||
passed_stocks.append((stock_code, result.get('name', '')))
|
||||
print(f" ✓ [{idx}] {stock_code}: {result.get('name', '')}")
|
||||
|
||||
except Exception as e:
|
||||
print(f" ✗ [{idx}] {stock_code}: 检查失败 - {e}")
|
||||
|
||||
print("=" * 80)
|
||||
print(f"\n检查了 {check_count} 只股票,通过筛选: {len(passed_stocks)} 只")
|
||||
|
||||
if passed_stocks:
|
||||
print("\n✅ 通过的股票:")
|
||||
for stock_code, name in passed_stocks[:20]: # 只显示前20只
|
||||
print(f" - {stock_code}: {name}")
|
||||
if len(passed_stocks) > 20:
|
||||
print(f" ... 还有 {len(passed_stocks) - 20} 只")
|
||||
else:
|
||||
print("\n❌ 没有股票通过筛选")
|
||||
|
||||
|
||||
async def main():
|
||||
await test_single_stock()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@ -14,6 +14,17 @@ def _mock_settings():
|
||||
s.bitget_max_single_position = 1000.0
|
||||
s.account_max_drawdown = 0.25
|
||||
s.bitget_trading_enabled = False
|
||||
s.bitget_default_leverage = 10
|
||||
s.get_bitget_account_config = MagicMock(return_value={
|
||||
"account_id": "default",
|
||||
"api_key": "",
|
||||
"api_secret": "",
|
||||
"passphrase": "",
|
||||
"enabled": False,
|
||||
"use_testnet": True,
|
||||
"use_unified_account": True,
|
||||
})
|
||||
s.get_enabled_bitget_accounts = MagicMock(return_value=[])
|
||||
return s
|
||||
|
||||
|
||||
@ -31,4 +42,5 @@ sys.modules['app.utils.logger'] = mock_logger_module
|
||||
# ---- mock app.services.bitget_trading_api_sdk (避免 ccxt import) ----
|
||||
mock_sdk_module = MagicMock()
|
||||
mock_sdk_module.get_bitget_trading_api = MagicMock(return_value=MagicMock())
|
||||
mock_sdk_module.get_all_bitget_trading_apis = MagicMock(return_value={})
|
||||
sys.modules['app.services.bitget_trading_api_sdk'] = mock_sdk_module
|
||||
|
||||
@ -53,6 +53,7 @@ def make_service(settings_overrides=None):
|
||||
|
||||
# 用 __new__ 跳过 __init__(避免真实 API/数据库调用),手动设置所有属性
|
||||
service = BitgetLiveTradingService.__new__(BitgetLiveTradingService)
|
||||
service.account_id = "default"
|
||||
service.settings = mock_settings
|
||||
service.max_total_leverage = mock_settings.bitget_max_total_leverage
|
||||
service.max_single_position = mock_settings.bitget_max_single_position
|
||||
@ -673,6 +674,10 @@ class TestGetBitgetLiveServiceFactory:
|
||||
mock_settings.bitget_max_total_leverage = 10.0
|
||||
mock_settings.bitget_max_single_position = 1000.0
|
||||
mock_settings.account_max_drawdown = 0.25
|
||||
mock_settings.get_enabled_bitget_accounts.return_value = [{
|
||||
"account_id": "default",
|
||||
"enabled": True,
|
||||
}]
|
||||
|
||||
mock_api = MagicMock()
|
||||
mock_api._standardize_symbol = lambda s: f"{s}/USDT:USDT"
|
||||
@ -694,6 +699,10 @@ class TestGetBitgetLiveServiceFactory:
|
||||
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.bitget_trading_enabled = True
|
||||
mock_settings.get_enabled_bitget_accounts.return_value = [{
|
||||
"account_id": "default",
|
||||
"enabled": True,
|
||||
}]
|
||||
|
||||
with patch('app.services.bitget_live_trading_service.get_settings', return_value=mock_settings), \
|
||||
patch('app.services.bitget_live_trading_service.get_bitget_trading_api', return_value=None):
|
||||
@ -1045,6 +1054,7 @@ class TestCancelOrder:
|
||||
|
||||
def test_cancel_success(self):
|
||||
service, mock_api = make_service()
|
||||
service.get_open_orders = MagicMock(return_value=[{'order_id': 'ord123'}])
|
||||
mock_api.cancel_order.return_value = True
|
||||
result = service.cancel_order('BTC', 'ord123')
|
||||
assert result['success'] is True
|
||||
@ -1052,12 +1062,14 @@ class TestCancelOrder:
|
||||
|
||||
def test_cancel_failure(self):
|
||||
service, mock_api = make_service()
|
||||
service.get_open_orders = MagicMock(return_value=[{'order_id': 'ord456'}])
|
||||
mock_api.cancel_order.return_value = False
|
||||
result = service.cancel_order('BTC', 'ord456')
|
||||
assert result['success'] is False
|
||||
|
||||
def test_cancel_exception(self):
|
||||
service, mock_api = make_service()
|
||||
service.get_open_orders = MagicMock(return_value=[{'order_id': 'ord789'}])
|
||||
mock_api.cancel_order.side_effect = Exception("order not found")
|
||||
result = service.cancel_order('BTC', 'ord789')
|
||||
assert result['success'] is False
|
||||
|
||||
@ -108,9 +108,13 @@ def make_agent():
|
||||
crypto_event_analysis_window_minutes=5,
|
||||
crypto_event_analysis_price_change_percent=1.0,
|
||||
crypto_event_analysis_cooldown_minutes=10,
|
||||
paper_trading_enabled=True,
|
||||
bitget_trading_enabled=True,
|
||||
)
|
||||
agent.paper_trading = None
|
||||
agent.bitget = None
|
||||
agent.bitget_services = {}
|
||||
agent.bitget_executors = {}
|
||||
agent.symbols = ['BTCUSDT']
|
||||
agent.executors = {}
|
||||
agent._platform_halts = {}
|
||||
@ -121,9 +125,13 @@ def make_agent():
|
||||
agent._analysis_notification_state = {}
|
||||
agent._lane_analysis_state = {}
|
||||
agent._event_analysis_state = {}
|
||||
agent.execution_guardian = MagicMock()
|
||||
agent.execution_guardian.get_status.return_value = {"last_status": "idle", "targets": [], "last_actions": []}
|
||||
agent._initial_balances = {}
|
||||
agent._target_execution_controls = {}
|
||||
agent._save_platform_halts = MagicMock()
|
||||
agent._save_initial_balances = MagicMock()
|
||||
agent._save_target_execution_controls = MagicMock()
|
||||
agent._send_alert_notification = AsyncMock()
|
||||
agent._emergency_close_all_positions = AsyncMock()
|
||||
return agent
|
||||
@ -137,14 +145,15 @@ def test_account_stop_loss_halts_only_triggered_platform():
|
||||
'current_balance': 700.0,
|
||||
}
|
||||
agent.bitget = bitget
|
||||
agent._get_risk_platforms = MagicMock(return_value=[('Bitget', bitget)])
|
||||
agent._get_risk_platforms = MagicMock(return_value=[('Bitget:default', bitget)])
|
||||
agent._get_initial_balance = MagicMock(return_value=1000.0)
|
||||
agent._get_bitget_target_key = MagicMock(return_value='Bitget:default')
|
||||
|
||||
should_stop, reason = asyncio.run(agent._check_account_level_stop_loss())
|
||||
|
||||
assert should_stop is True
|
||||
assert 'Bitget' in reason
|
||||
assert agent._platform_halts['Bitget']['halted'] is True
|
||||
assert 'Bitget:default' in reason
|
||||
assert agent._platform_halts['Bitget:default']['halted'] is True
|
||||
agent._emergency_close_all_positions.assert_awaited_once()
|
||||
|
||||
|
||||
@ -156,8 +165,10 @@ def test_resume_platform_resets_initial_balance_and_clears_halt():
|
||||
'current_balance': 888.0,
|
||||
}
|
||||
agent.bitget = bitget
|
||||
agent.bitget_services = {'default': bitget}
|
||||
agent._get_bitget_target_key = MagicMock(return_value='Bitget:default')
|
||||
agent._platform_halts = {
|
||||
'Bitget': {
|
||||
'Bitget:default': {
|
||||
'halted': True,
|
||||
'reason': 'drawdown',
|
||||
'drawdown_pct': 25.1,
|
||||
@ -167,7 +178,7 @@ def test_resume_platform_resets_initial_balance_and_clears_halt():
|
||||
result = agent.resume_platform('Bitget')
|
||||
|
||||
assert result['halted'] is False
|
||||
assert agent._initial_balances['Bitget'] == 888.0
|
||||
assert agent._initial_balances['Bitget:default'] == 888.0
|
||||
assert result['initial_balance'] == 888.0
|
||||
assert result['current_balance'] == 888.0
|
||||
|
||||
@ -207,3 +218,32 @@ def test_get_status_contains_last_execution_preview():
|
||||
|
||||
assert status['last_execution_preview']['BTCUSDT']['paper']['decision'] == 'OPEN'
|
||||
assert status['last_execution_preview']['BTCUSDT']['bitget']['reason'] == '替换旧挂单'
|
||||
|
||||
|
||||
def test_target_execution_status_uses_settings_defaults_until_overridden():
|
||||
agent = make_agent()
|
||||
agent.bitget_services = {'default': MagicMock()}
|
||||
agent._get_bitget_target_key = MagicMock(return_value='Bitget:default')
|
||||
agent._iter_bitget_accounts = MagicMock(return_value=['default'])
|
||||
|
||||
status = agent.get_target_execution_status()
|
||||
|
||||
assert status['PaperTrading']['enabled'] is True
|
||||
assert status['PaperTrading']['source'] == 'default'
|
||||
assert status['Bitget:default']['enabled'] is True
|
||||
assert status['Bitget:default']['source'] == 'default'
|
||||
|
||||
|
||||
def test_set_target_execution_enabled_persists_manual_override():
|
||||
agent = make_agent()
|
||||
agent.bitget_services = {'default': MagicMock()}
|
||||
agent._iter_bitget_accounts = MagicMock(return_value=['default'])
|
||||
agent._get_bitget_target_key = MagicMock(return_value='Bitget:default')
|
||||
|
||||
result = agent.set_target_execution_enabled('Bitget', False, 'manual off')
|
||||
|
||||
assert result['enabled'] is False
|
||||
assert result['source'] == 'manual'
|
||||
assert result['reason'] == 'manual off'
|
||||
assert agent._target_execution_controls['Bitget:default']['enabled'] is False
|
||||
agent._save_target_execution_controls.assert_called_once()
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user