From c73341b95048984deceebc83fb1611cb04c06e32 Mon Sep 17 00:00:00 2001 From: aaron <> Date: Tue, 3 Feb 2026 10:08:15 +0800 Subject: [PATCH] first commit --- .env.example | 20 + .gitignore | 62 ++ PROJECT_SUMMARY.md | 348 ++++++++ README.md | 328 ++++++++ backend/app/__init__.py | 0 backend/app/agent/__init__.py | 0 backend/app/agent/context.py | 93 +++ backend/app/agent/core.py | 378 +++++++++ backend/app/agent/enhanced_agent.py | 377 +++++++++ backend/app/agent/skill_manager.py | 179 +++++ backend/app/agent/smart_agent.py | 966 +++++++++++++++++++++++ backend/app/api/__init__.py | 0 backend/app/api/chat.py | 67 ++ backend/app/api/skills.py | 99 +++ backend/app/api/stock.py | 80 ++ backend/app/config.py | 69 ++ backend/app/main.py | 70 ++ backend/app/models/__init__.py | 0 backend/app/models/chat.py | 37 + backend/app/models/database.py | 47 ++ backend/app/models/stock.py | 48 ++ backend/app/services/__init__.py | 0 backend/app/services/cache_service.py | 144 ++++ backend/app/services/db_service.py | 210 +++++ backend/app/services/llm_service.py | 175 ++++ backend/app/services/tushare_service.py | 263 ++++++ backend/app/skills/__init__.py | 0 backend/app/skills/base.py | 78 ++ backend/app/skills/brave_search.py | 180 +++++ backend/app/skills/fundamental.py | 61 ++ backend/app/skills/market_data.py | 140 ++++ backend/app/skills/technical_analysis.py | 202 +++++ backend/app/skills/visualization.py | 118 +++ backend/app/utils/__init__.py | 0 backend/app/utils/indicators.py | 158 ++++ backend/app/utils/logger.py | 60 ++ backend/app/utils/stock_names.py | 254 ++++++ backend/app/utils/validators.py | 103 +++ backend/diagnose.sh | 100 +++ backend/requirements.txt | 16 + backend/run.sh | 110 +++ backend/start.sh | 64 ++ backend/test_import.sh | 26 + backend/tests/__init__.py | 0 docs/DEPLOYMENT.md | 491 ++++++++++++ docs/INSTALL_GUIDE.md | 262 ++++++ docs/USER_GUIDE.md | 343 ++++++++ frontend/css/style.css | 551 +++++++++++++ frontend/css/style.css.backup | 196 +++++ frontend/index.html | 121 +++ frontend/index.html.backup | 133 ++++ frontend/js/app.js | 269 +++++++ frontend/js/app.js.backup | 219 +++++ install.sh | 149 ++++ start.sh | 78 ++ 55 files changed, 8542 insertions(+) create mode 100644 .env.example create mode 100644 .gitignore create mode 100644 PROJECT_SUMMARY.md create mode 100644 README.md create mode 100644 backend/app/__init__.py create mode 100644 backend/app/agent/__init__.py create mode 100644 backend/app/agent/context.py create mode 100644 backend/app/agent/core.py create mode 100644 backend/app/agent/enhanced_agent.py create mode 100644 backend/app/agent/skill_manager.py create mode 100644 backend/app/agent/smart_agent.py create mode 100644 backend/app/api/__init__.py create mode 100644 backend/app/api/chat.py create mode 100644 backend/app/api/skills.py create mode 100644 backend/app/api/stock.py create mode 100644 backend/app/config.py create mode 100644 backend/app/main.py create mode 100644 backend/app/models/__init__.py create mode 100644 backend/app/models/chat.py create mode 100644 backend/app/models/database.py create mode 100644 backend/app/models/stock.py create mode 100644 backend/app/services/__init__.py create mode 100644 backend/app/services/cache_service.py create mode 100644 backend/app/services/db_service.py create mode 100644 backend/app/services/llm_service.py create mode 100644 backend/app/services/tushare_service.py create mode 100644 backend/app/skills/__init__.py create mode 100644 backend/app/skills/base.py create mode 100644 backend/app/skills/brave_search.py create mode 100644 backend/app/skills/fundamental.py create mode 100644 backend/app/skills/market_data.py create mode 100644 backend/app/skills/technical_analysis.py create mode 100644 backend/app/skills/visualization.py create mode 100644 backend/app/utils/__init__.py create mode 100644 backend/app/utils/indicators.py create mode 100644 backend/app/utils/logger.py create mode 100644 backend/app/utils/stock_names.py create mode 100644 backend/app/utils/validators.py create mode 100755 backend/diagnose.sh create mode 100644 backend/requirements.txt create mode 100755 backend/run.sh create mode 100755 backend/start.sh create mode 100755 backend/test_import.sh create mode 100644 backend/tests/__init__.py create mode 100644 docs/DEPLOYMENT.md create mode 100644 docs/INSTALL_GUIDE.md create mode 100644 docs/USER_GUIDE.md create mode 100644 frontend/css/style.css create mode 100644 frontend/css/style.css.backup create mode 100644 frontend/index.html create mode 100644 frontend/index.html.backup create mode 100644 frontend/js/app.js create mode 100644 frontend/js/app.js.backup create mode 100755 install.sh create mode 100755 start.sh diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..e1d9ebc --- /dev/null +++ b/.env.example @@ -0,0 +1,20 @@ +# Tushare API +TUSHARE_TOKEN=your_tushare_token_here + +# 智谱AI GLM-4 API +ZHIPUAI_API_KEY=your_zhipuai_key_here + +# Database (使用SQLite,无需额外配置) +DATABASE_URL=sqlite:///./stock_agent.db + +# API Settings +API_HOST=0.0.0.0 +API_PORT=8000 +DEBUG=True + +# Security +SECRET_KEY=your_secret_key_here_change_in_production +RATE_LIMIT=100/minute + +# CORS +CORS_ORIGINS=http://localhost:8000,http://127.0.0.1:8000 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b394b5c --- /dev/null +++ b/.gitignore @@ -0,0 +1,62 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# Virtual Environment +venv/ +ENV/ +env/ +.venv + +# Environment variables +.env + +# IDE +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# Database +*.db +*.sqlite +*.sqlite3 + +# Logs +*.log +logs/ + +# OS +.DS_Store +Thumbs.db + +# Redis +dump.rdb + +# Testing +.pytest_cache/ +.coverage +htmlcov/ + +# Claude Code +.claude/ diff --git a/PROJECT_SUMMARY.md b/PROJECT_SUMMARY.md new file mode 100644 index 0000000..b46c089 --- /dev/null +++ b/PROJECT_SUMMARY.md @@ -0,0 +1,348 @@ +# 项目完成总结 + +## 🎉 A股AI分析Agent系统 - 开发完成! + +### 项目概述 + +成功开发了一个功能完整的A股智能分析系统,集成了AI大模型、实时数据查询、技术分析等功能。 + +--- + +## ✅ 已完成功能 + +### 1. 核心功能 + +#### 🤖 AI Agent系统 +- ✅ 增强版Agent(集成智谱AI GLM-4) +- ✅ 规则模式(无需LLM也可运行) +- ✅ 智能意图识别 +- ✅ 上下文管理 +- ✅ 对话历史保存 + +#### 📊 数据查询 +- ✅ 实时行情查询(Tushare) +- ✅ 历史K线数据 +- ✅ 技术指标计算(MA、MACD、RSI、KDJ、BOLL) +- ✅ 基本面信息查询 +- ✅ 内存缓存(无需Redis) + +#### 🎨 数据可视化 +- ✅ 专业K线图(Lightweight Charts) +- ✅ 成交量柱状图 +- ✅ 技术指标图表 +- ✅ 交互式图表操作 + +#### 🔌 技能插件系统 +- ✅ 插件化架构 +- ✅ 动态启用/禁用 +- ✅ 4个核心技能: + - market_data(行情查询) + - technical_analysis(技术分析) + - fundamental(基本面) + - visualization(可视化) + +#### 🧠 智能识别 +- ✅ 200+股票名称数据库 +- ✅ 支持中文名称识别 +- ✅ 支持简称识别 +- ✅ 模糊匹配 + +### 2. 技术实现 + +#### 后端(Python) +- ✅ FastAPI框架 +- ✅ SQLAlchemy ORM(SQLite) +- ✅ 智谱AI GLM-4集成 +- ✅ Tushare数据接口 +- ✅ 内存缓存系统 +- ✅ 异步处理 + +#### 前端(轻量级) +- ✅ Vue 3(CDN版本) +- ✅ Bootstrap 5 +- ✅ Lightweight Charts +- ✅ 响应式设计 +- ✅ 实时对话界面 + +#### 数据库 +- ✅ SQLite(轻量级) +- ✅ 对话历史存储 +- ✅ 用户偏好管理 + +--- + +## 📁 项目结构 + +``` +Stock_Agent/ +├── backend/ # 后端(35个文件) +│ ├── app/ +│ │ ├── agent/ # AI Agent核心 +│ │ │ ├── core.py # 原始Agent +│ │ │ ├── enhanced_agent.py # 增强版Agent(LLM) +│ │ │ ├── context.py # 上下文管理 +│ │ │ └── skill_manager.py # 技能管理 +│ │ ├── api/ # API路由 +│ │ │ ├── chat.py # 对话接口 +│ │ │ ├── stock.py # 股票数据 +│ │ │ └── skills.py # 技能管理 +│ │ ├── models/ # 数据模型 +│ │ │ ├── database.py # SQLAlchemy模型 +│ │ │ ├── chat.py # Pydantic模型 +│ │ │ └── stock.py # 股票模型 +│ │ ├── services/ # 服务层 +│ │ │ ├── tushare_service.py # Tushare数据 +│ │ │ ├── cache_service.py # 内存缓存 +│ │ │ ├── db_service.py # 数据库 +│ │ │ └── llm_service.py # LLM服务 +│ │ ├── skills/ # 技能插件 +│ │ │ ├── base.py # 基类 +│ │ │ ├── market_data.py # 行情查询 +│ │ │ ├── technical_analysis.py # 技术分析 +│ │ │ ├── fundamental.py # 基本面 +│ │ │ └── visualization.py # 可视化 +│ │ ├── utils/ # 工具函数 +│ │ │ ├── logger.py # 日志 +│ │ │ ├── validators.py # 验证 +│ │ │ ├── indicators.py # 技术指标 +│ │ │ └── stock_names.py # 股票名称库 +│ │ ├── config.py # 配置管理 +│ │ └── main.py # 应用入口 +│ ├── requirements.txt # 依赖 +│ ├── start.sh # 启动脚本 +│ ├── run.sh # 检查并启动 +│ └── diagnose.sh # 诊断脚本 +├── frontend/ # 前端(3个文件) +│ ├── index.html # 主页面 +│ ├── css/style.css # 样式 +│ └── js/app.js # Vue应用 +├── docs/ # 文档(4个文件) +│ ├── API.md +│ ├── DEPLOYMENT.md +│ ├── USER_GUIDE.md +│ └── INSTALL_GUIDE.md +├── .env.example # 配置模板 +├── .gitignore +├── README.md +└── install.sh # 安装脚本 +``` + +--- + +## 🚀 快速启动 + +### 方法1:一键启动(推荐) + +```bash +cd /Users/aaron/source_code/Stock_Agent/backend +./run.sh +``` + +### 方法2:手动启动 + +```bash +cd /Users/aaron/source_code/Stock_Agent/backend +source venv/bin/activate +python -m app.main +``` + +### 访问系统 + +- 🌐 前端界面: http://localhost:8000 +- 📚 API文档: http://localhost:8000/docs + +--- + +## 💡 使用示例 + +### 支持的查询方式 + +``` +✅ "中国卫通的技术分析" +✅ "贵州茅台的实时行情" +✅ "比亚迪的K线图" +✅ "宁德时代的基本信息" +✅ "查询600519" +✅ "分析000001的技术指标" +``` + +### AI分析示例 + +``` +用户:对中国卫通进行技术分析 + +系统: +【601698】技术指标: +均线:MA5=15.23, MA10=15.10, MA20=14.95 +MACD:DIF=0.12, DEA=0.08, MACD=0.08 +RSI:RSI6=58.3, RSI12=55.2, RSI24=52.1 + +【AI分析】 +中国卫通(601698)当前技术面表现中性偏多... +(智能分析总结) +``` + +--- + +## 🔧 配置说明 + +### 必需配置 + +在`.env`文件中配置: + +```env +# Tushare数据源(必需) +TUSHARE_TOKEN=your_token_here + +# 智谱AI(可选,不配置则使用规则模式) +ZHIPUAI_API_KEY=your_key_here +``` + +### 运行模式 + +1. **完整模式**(推荐) + - 配置Tushare + 智谱AI + - 支持所有功能 + AI分析 + +2. **规则模式** + - 仅配置Tushare + - 支持数据查询,无AI分析 + +3. **演示模式** + - 不配置任何API + - 仅展示界面和架构 + +--- + +## 📊 技术亮点 + +### 1. 双模式Agent +- LLM模式:智能意图识别 + AI分析 +- 规则模式:快速响应 + 稳定可靠 +- 自动切换:LLM失败时回退 + +### 2. 智能股票识别 +- 200+股票名称数据库 +- 支持全称、简称、模糊匹配 +- 自动提取股票代码 + +### 3. 轻量级架构 +- 无需Redis(内存缓存) +- 无需PostgreSQL(SQLite) +- 无需构建工具(CDN) +- 一键启动 + +### 4. 专业图表 +- TradingView开源图表库 +- 金融级K线渲染 +- 交互式操作 + +--- + +## 📝 文档清单 + +1. **README.md** - 项目说明和快速开始 +2. **docs/INSTALL_GUIDE.md** - 详细安装指南 +3. **docs/USER_GUIDE.md** - 用户使用手册 +4. **docs/DEPLOYMENT.md** - 部署文档 +5. **本文档** - 项目完成总结 + +--- + +## 🎯 已解决的问题 + +### 问题1:Python 3.13兼容性 +- ✅ 更新依赖版本 +- ✅ 创建安装指南 +- ✅ 提供多种解决方案 + +### 问题2:SQLAlchemy保留字冲突 +- ✅ 修改字段名(metadata → msg_metadata) +- ✅ 更新所有引用 + +### 问题3:配置文件加载 +- ✅ 智能查找.env文件 +- ✅ 支持多目录启动 + +### 问题4:股票名称识别 +- ✅ 创建200+股票名称库 +- ✅ 支持中文名称和简称 +- ✅ 模糊匹配算法 + +### 问题5:缺少LLM分析 +- ✅ 集成智谱AI GLM-4 +- ✅ 智能意图识别 +- ✅ AI分析总结 +- ✅ 自动回退机制 + +--- + +## 🎊 项目特色 + +1. **开箱即用** + - 一键安装脚本 + - 自动检查脚本 + - 详细错误提示 + +2. **智能分析** + - LLM驱动的意图识别 + - 专业的技术分析 + - 自然语言总结 + +3. **易于扩展** + - 插件化技能系统 + - 清晰的代码结构 + - 完善的文档 + +4. **生产就绪** + - 错误处理 + - 日志系统 + - 缓存优化 + - 数据验证 + +--- + +## 📈 下一步建议 + +### 短期优化 +1. 添加更多股票名称 +2. 优化LLM提示词 +3. 添加更多技术指标 +4. 改进图表交互 + +### 中期扩展 +1. 支持港股、美股 +2. 添加实时预警 +3. 用户认证系统 +4. 自选股管理 + +### 长期规划 +1. 移动端适配 +2. 多语言支持 +3. 社区功能 +4. 量化策略 + +--- + +## 🙏 致谢 + +- **Tushare** - 金融数据接口 +- **智谱AI** - GLM-4大模型 +- **FastAPI** - 高性能Web框架 +- **LangChain** - AI应用框架 +- **Lightweight Charts** - 专业图表库 + +--- + +## 📞 支持 + +如有问题,请查看: +1. [安装指南](docs/INSTALL_GUIDE.md) +2. [用户手册](docs/USER_GUIDE.md) +3. [部署文档](docs/DEPLOYMENT.md) + +--- + +**项目状态:✅ 完成并可用** + +**最后更新:2026-02-03** diff --git a/README.md b/README.md new file mode 100644 index 0000000..a6fbfbb --- /dev/null +++ b/README.md @@ -0,0 +1,328 @@ +# A股AI分析Agent系统 + +基于AI Agent的股票智能分析系统,提供自然语言对话界面,支持实时行情查询、技术分析、基本面分析等功能。 + +## 功能特性 + +- **自然语言对话**:通过对话方式查询股票信息 +- **实时行情查询**:获取股票实时价格、涨跌幅等数据 +- **技术指标分析**:计算MA、MACD、RSI、KDJ、BOLL等技术指标 +- **基本面信息**:查询公司概况、行业、上市日期等 +- **数据可视化**:生成专业的K线图和技术指标图表 +- **技能插件系统**:可扩展的技能架构,支持动态启用/禁用 +- **对话历史**:保存和查看历史分析记录 + +## 技术栈 + +### 后端 +- **框架**:FastAPI +- **AI Agent**:LangChain + 智谱AI GLM-4 +- **数据源**:Tushare +- **缓存**:内存缓存(无需Redis) +- **数据库**:SQLite +- **语言**:Python 3.11+ (推荐 3.11 或 3.12) + +### 前端 +- **框架**:Vue 3 (CDN版本) +- **UI**:Bootstrap 5 +- **图表**:Lightweight Charts +- **通信**:Fetch API + +## 项目结构 + +``` +Stock_Agent/ +├── backend/ # 后端代码 +│ ├── app/ +│ │ ├── agent/ # AI Agent核心 +│ │ ├── api/ # API路由 +│ │ ├── models/ # 数据模型 +│ │ ├── services/ # 数据服务 +│ │ ├── skills/ # 技能插件 +│ │ ├── utils/ # 工具函数 +│ │ ├── config.py # 配置管理 +│ │ └── main.py # 应用入口 +│ └── requirements.txt # Python依赖 +├── frontend/ # 前端代码 +│ ├── css/ # 样式文件 +│ ├── js/ # JavaScript文件 +│ └── index.html # 主页面 +├── .env.example # 环境变量示例 +├── .gitignore +└── README.md +``` + +## 快速开始 + +### ⚠️ 重要提示:Python 版本 + +**推荐使用 Python 3.11 或 3.12**。如果您使用 Python 3.13,可能会遇到依赖安装问题。 + +详细的安装问题解决方案请查看:[安装指南](docs/INSTALL_GUIDE.md) + +### 1. 环境准备 + +**系统要求**: +- Python 3.11 或 3.12(推荐) +- 无需 Redis(使用内存缓存) + +**获取API密钥**: +- [Tushare](https://tushare.pro/):注册并获取Token +- [智谱AI](https://open.bigmodel.cn/):注册并获取API Key + +### 2. 安装依赖 + +```bash +# 进入后端目录 +cd backend + +# 创建虚拟环境(使用 Python 3.11) +python3.11 -m venv venv + +# 激活虚拟环境 +# Windows: +venv\Scripts\activate +# macOS/Linux: +source venv/bin/activate + +# 安装依赖 +pip install --upgrade pip +pip install -r requirements.txt +``` + +**如果遇到安装错误**,请查看 [安装指南](docs/INSTALL_GUIDE.md) 获取详细解决方案。 + +### 3. 配置环境变量 + +复制 `.env.example` 为 `.env` 并填写配置: + +```bash +cp .env.example .env +``` + +编辑 `.env` 文件: + +```env +# Tushare API +TUSHARE_TOKEN=your_tushare_token_here + +# 智谱AI GLM-4 API +ZHIPUAI_API_KEY=your_zhipuai_key_here + +# 其他配置保持默认即可 +REDIS_HOST=localhost +REDIS_PORT=6379 +REDIS_DB=0 + +# 其他配置保持默认即可 +``` + +### 4. 启动Redis(可选) + +如果要使用缓存功能,请先启动Redis: + +```bash +# macOS (使用Homebrew) +brew services start redis + +# Linux +sudo systemctl start redis + +# Windows +# 下载并运行Redis for Windows +``` + +### 5. 启动后端服务 + +```bash +# 在backend目录下 +cd backend + +# 启动服务 +python -m app.main + +# 或使用uvicorn +uvicorn app.main:app --reload --host 0.0.0.0 --port 8000 +``` + +服务启动后,访问: +- 前端界面:http://localhost:8000 +- API文档:http://localhost:8000/docs + +## 使用指南 + +### 基本查询 + +1. **查询实时行情** + ``` + 查询600519的实时行情 + 贵州茅台的价格 + 000001现在多少钱 + ``` + +2. **查看K线图** + ``` + 600519的K线图 + 贵州茅台的走势 + ``` + +3. **技术指标分析** + ``` + 600519的技术指标 + 分析贵州茅台的MACD + ``` + +4. **基本面信息** + ``` + 600519的基本信息 + 贵州茅台是什么行业 + ``` + +### 技能管理 + +点击右上角"技能管理"按钮,可以: +- 查看所有可用技能 +- 启用/禁用特定技能 +- 查看技能描述 + +## API文档 + +启动服务后,访问 http://localhost:8000/docs 查看完整的API文档。 + +### 主要接口 + +#### 1. 发送消息 +```http +POST /api/chat/message +Content-Type: application/json + +{ + "message": "查询600519的实时行情", + "session_id": "optional_session_id" +} +``` + +#### 2. 获取对话历史 +```http +GET /api/chat/history/{session_id} +``` + +#### 3. 获取股票行情 +```http +GET /api/stock/quote/{stock_code} +``` + +#### 4. 获取K线数据 +```http +GET /api/stock/kline/{stock_code}?start_date=20240101&end_date=20240201 +``` + +#### 5. 获取技能列表 +```http +GET /api/skills/ +``` + +## 开发指南 + +### 添加新技能 + +1. 在 `backend/app/skills/` 目录下创建新的技能文件 +2. 继承 `BaseSkill` 类并实现 `execute` 方法 +3. 在 `backend/app/agent/core.py` 中注册新技能 + +示例: + +```python +from app.skills.base import BaseSkill, SkillParameter + +class MyNewSkill(BaseSkill): + def __init__(self): + super().__init__() + self.name = "my_skill" + self.description = "我的新技能" + self.parameters = [ + SkillParameter( + name="param1", + type="string", + description="参数1", + required=True + ) + ] + + async def execute(self, **kwargs) -> Dict[str, Any]: + # 实现技能逻辑 + return {"result": "success"} +``` + +### 扩展数据源 + +在 `backend/app/services/` 目录下添加新的数据服务类,参考 `tushare_service.py` 的实现。 + +## 常见问题 + +### 1. Redis连接失败 + +如果Redis未安装或未启动,系统会自动降级,不影响核心功能,但会失去缓存能力。 + +### 2. Tushare API限制 + +免费版Tushare有调用频率限制(120次/分钟)。如果遇到限制,可以: +- 等待一段时间后重试 +- 考虑升级到付费版 +- 使用Redis缓存减少API调用 + +### 3. 股票代码格式 + +支持的股票代码格式: +- 6位数字:600000、000001 +- 带后缀:600000.SH、000001.SZ +- 股票名称:贵州茅台、中国平安 + +### 4. 端口被占用 + +如果8000端口被占用,可以修改 `.env` 文件中的 `API_PORT` 配置。 + +## 性能优化 + +1. **启用Redis缓存**:显著减少API调用和响应时间 +2. **调整缓存TTL**:在 `cache_service.py` 中根据需求调整缓存时间 +3. **限制历史消息数**:在 `context.py` 中调整 `max_history` 参数 + +## 安全建议 + +1. **生产环境**: + - 修改 `.env` 中的 `SECRET_KEY` + - 设置 `DEBUG=False` + - 配置严格的CORS策略 + - 使用HTTPS + +2. **API密钥**: + - 不要将 `.env` 文件提交到版本控制 + - 定期更换API密钥 + - 使用环境变量或密钥管理服务 + +## 贡献指南 + +欢迎贡献代码!请遵循以下步骤: + +1. Fork本项目 +2. 创建特性分支 (`git checkout -b feature/AmazingFeature`) +3. 提交更改 (`git commit -m 'Add some AmazingFeature'`) +4. 推送到分支 (`git push origin feature/AmazingFeature`) +5. 开启Pull Request + +## 许可证 + +本项目采用 MIT 许可证。 + +## 联系方式 + +如有问题或建议,请提交Issue。 + +## 致谢 + +- [Tushare](https://tushare.pro/) - 金融数据接口 +- [智谱AI](https://open.bigmodel.cn/) - AI模型服务 +- [FastAPI](https://fastapi.tiangolo.com/) - Web框架 +- [LangChain](https://www.langchain.com/) - AI Agent框架 +- [Lightweight Charts](https://tradingview.github.io/lightweight-charts/) - 图表库 diff --git a/backend/app/__init__.py b/backend/app/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/agent/__init__.py b/backend/app/agent/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/agent/context.py b/backend/app/agent/context.py new file mode 100644 index 0000000..616e842 --- /dev/null +++ b/backend/app/agent/context.py @@ -0,0 +1,93 @@ +""" +上下文管理器 +管理对话历史和上下文 +""" +from typing import List, Dict, Optional +from app.services.db_service import db_service +from app.utils.logger import logger + + +class ContextManager: + """上下文管理器""" + + def __init__(self, max_history: int = 10): + """ + 初始化上下文管理器 + + Args: + max_history: 最大历史消息数 + """ + self.max_history = max_history + + def get_context(self, session_id: str) -> List[Dict[str, str]]: + """ + 获取对话上下文 + + Args: + session_id: 会话ID + + Returns: + 消息列表 + """ + messages = db_service.get_conversation_history(session_id, limit=self.max_history) + + context = [] + for msg in messages: + context.append({ + "role": msg.role, + "content": msg.content + }) + + return context + + def add_message( + self, + session_id: str, + role: str, + content: str, + metadata: Optional[dict] = None + ): + """ + 添加消息到上下文 + + Args: + session_id: 会话ID + role: 角色(user/assistant) + content: 消息内容 + metadata: 元数据 + """ + db_service.add_message(session_id, role, content, metadata) + logger.info(f"添加消息到上下文: {session_id}, {role}") + + def clear_context(self, session_id: str): + """ + 清除上下文(暂不实现删除,保留历史) + + Args: + session_id: 会话ID + """ + logger.info(f"清除上下文请求: {session_id}") + # 实际不删除,只是标记 + pass + + def format_context_for_llm(self, session_id: str) -> str: + """ + 格式化上下文供LLM使用 + + Args: + session_id: 会话ID + + Returns: + 格式化的上下文字符串 + """ + context = self.get_context(session_id) + + if not context: + return "" + + formatted = [] + for msg in context: + role = "用户" if msg["role"] == "user" else "助手" + formatted.append(f"{role}: {msg['content']}") + + return "\n".join(formatted) diff --git a/backend/app/agent/core.py b/backend/app/agent/core.py new file mode 100644 index 0000000..d301468 --- /dev/null +++ b/backend/app/agent/core.py @@ -0,0 +1,378 @@ +""" +AI Agent核心 +基于LangChain的股票分析Agent +""" +import re +import json +from typing import Dict, Any, Optional +from app.config import get_settings +from app.agent.context import ContextManager +from app.agent.skill_manager import skill_manager +from app.skills.market_data import MarketDataSkill +from app.skills.technical_analysis import TechnicalAnalysisSkill +from app.skills.fundamental import FundamentalSkill +from app.skills.visualization import VisualizationSkill +from app.utils.logger import logger + + +class StockAnalysisAgent: + """股票分析Agent""" + + def __init__(self): + """初始化Agent""" + self.context_manager = ContextManager() + self.settings = get_settings() + + # 注册技能 + self._register_skills() + + # 初始化LLM(简化版,使用规则匹配) + # 在实际部署时,这里应该集成智谱AI GLM-4 + self.use_llm = bool(self.settings.zhipuai_api_key) + + logger.info("Stock Analysis Agent初始化完成") + + def _register_skills(self): + """注册所有技能""" + skill_manager.register(MarketDataSkill()) + skill_manager.register(TechnicalAnalysisSkill()) + skill_manager.register(FundamentalSkill()) + skill_manager.register(VisualizationSkill()) + logger.info("技能注册完成") + + async def process_message( + self, + message: str, + session_id: str, + user_id: Optional[str] = None + ) -> Dict[str, Any]: + """ + 处理用户消息 + + Args: + message: 用户消息 + session_id: 会话ID + user_id: 用户ID + + Returns: + 响应结果 + """ + logger.info(f"处理消息: {message[:50]}...") + + # 保存用户消息 + self.context_manager.add_message(session_id, "user", message) + + # 意图识别和技能调用 + intent = self._recognize_intent(message) + logger.info(f"识别意图: {intent}") + + # 执行技能 + result = await self._execute_intent(intent, message) + + # 生成响应 + response = self._generate_response(intent, result) + + # 保存助手响应 + self.context_manager.add_message( + session_id, + "assistant", + response["message"], + metadata=response.get("metadata") + ) + + return response + + def _recognize_intent(self, message: str) -> Dict[str, Any]: + """ + 识别用户意图(简化版规则匹配) + + Args: + message: 用户消息 + + Returns: + 意图字典 + """ + message_lower = message.lower() + + # 提取股票代码 + stock_code = self._extract_stock_code(message) + + # 行情查询 + if any(keyword in message_lower for keyword in ["行情", "价格", "涨跌", "实时", "quote"]): + return { + "type": "market_data", + "skill": "market_data", + "params": { + "stock_code": stock_code, + "data_type": "quote" + } + } + + # K线查询 + if any(keyword in message_lower for keyword in ["k线", "kline", "走势", "图表"]): + return { + "type": "visualization", + "skill": "visualization", + "params": { + "stock_code": stock_code, + "chart_type": "candlestick" + } + } + + # 技术分析 + if any(keyword in message_lower for keyword in ["技术", "指标", "macd", "rsi", "kdj", "均线", "ma"]): + return { + "type": "technical_analysis", + "skill": "technical_analysis", + "params": { + "stock_code": stock_code, + "indicators": ["ma", "macd", "rsi"] + } + } + + # 基本面 + if any(keyword in message_lower for keyword in ["基本面", "公司", "行业", "信息"]): + return { + "type": "fundamental", + "skill": "fundamental", + "params": { + "stock_code": stock_code + } + } + + # 默认:行情查询 + if stock_code: + return { + "type": "market_data", + "skill": "market_data", + "params": { + "stock_code": stock_code, + "data_type": "quote" + } + } + + # 无法识别 + return { + "type": "unknown", + "skill": None, + "params": {} + } + + def _extract_stock_code(self, message: str) -> Optional[str]: + """ + 从消息中提取股票代码 + + Args: + message: 用户消息 + + Returns: + 股票代码或None + """ + from app.utils.stock_names import search_stock_by_name + + # 匹配6位数字 + pattern = r'\b\d{6}\b' + matches = re.findall(pattern, message) + + if matches: + return matches[0] + + # 使用股票名称数据库搜索 + # 提取可能的股票名称(2-6个汉字) + chinese_pattern = r'[\u4e00-\u9fa5]{2,6}' + chinese_words = re.findall(chinese_pattern, message) + + for word in chinese_words: + code = search_stock_by_name(word) + if code: + logger.info(f"识别股票名称: {word} -> {code}") + return code + + return None + + async def _execute_intent(self, intent: Dict[str, Any], message: str) -> Dict[str, Any]: + """ + 执行意图对应的技能 + + Args: + intent: 意图字典 + message: 原始消息 + + Returns: + 执行结果 + """ + if intent["type"] == "unknown": + return { + "success": False, + "error": "无法理解您的问题,请提供股票代码或明确的查询意图" + } + + skill_name = intent["skill"] + params = intent["params"] + + if not params.get("stock_code"): + return { + "success": False, + "error": "请提供股票代码(6位数字)" + } + + # 执行技能 + result = await skill_manager.execute_skill(skill_name, **params) + + return result + + def _generate_response(self, intent: Dict[str, Any], result: Dict[str, Any]) -> Dict[str, Any]: + """ + 生成响应消息 + + Args: + intent: 意图 + result: 执行结果 + + Returns: + 响应字典 + """ + if not result.get("success", True): + return { + "message": f"抱歉,{result.get('error', '处理失败')}", + "metadata": { + "type": "error" + } + } + + data = result.get("data", result) + + # 根据意图类型生成不同响应 + if intent["type"] == "market_data": + return self._format_market_data_response(data) + elif intent["type"] == "technical_analysis": + return self._format_technical_response(data) + elif intent["type"] == "fundamental": + return self._format_fundamental_response(data) + elif intent["type"] == "visualization": + return self._format_visualization_response(data) + else: + return { + "message": "查询完成", + "metadata": { + "type": "data", + "data": data + } + } + + def _format_market_data_response(self, data: Dict[str, Any]) -> Dict[str, Any]: + """格式化行情数据响应""" + if "error" in data: + return { + "message": f"查询失败:{data['error']}", + "metadata": {"type": "error"} + } + + if "kline_data" in data: + kline_data = data["kline_data"] + message = f"已获取K线数据,共{len(kline_data)}条记录" + return { + "message": message, + "metadata": { + "type": "kline", + "data": kline_data + } + } + + # 实时行情 + message = f""" +【{data.get('name', '股票')}】({data.get('ts_code', '')}) +交易日期:{data.get('trade_date', '')} +最新价:{data.get('close', 0):.2f} +涨跌额:{data.get('change', 0):.2f} +涨跌幅:{data.get('pct_chg', 0):.2f}% +开盘价:{data.get('open', 0):.2f} +最高价:{data.get('high', 0):.2f} +最低价:{data.get('low', 0):.2f} +成交量:{data.get('vol', 0):.0f}手 +成交额:{data.get('amount', 0):.0f}千元 + """.strip() + + return { + "message": message, + "metadata": { + "type": "quote", + "data": data + } + } + + def _format_technical_response(self, data: Dict[str, Any]) -> Dict[str, Any]: + """格式化技术分析响应""" + if "error" in data: + return { + "message": f"分析失败:{data['error']}", + "metadata": {"type": "error"} + } + + indicators = data.get("indicators", {}) + message_parts = [f"【{data.get('stock_code', '')}】技术指标:\n"] + + if "ma" in indicators: + ma = indicators["ma"] + message_parts.append(f"均线:MA5={ma.get('ma5')}, MA10={ma.get('ma10')}, MA20={ma.get('ma20')}") + + if "macd" in indicators: + macd = indicators["macd"] + message_parts.append(f"MACD:DIF={macd.get('dif')}, DEA={macd.get('dea')}, MACD={macd.get('macd')}") + + if "rsi" in indicators: + rsi = indicators["rsi"] + message_parts.append(f"RSI:RSI6={rsi.get('rsi6')}, RSI12={rsi.get('rsi12')}, RSI24={rsi.get('rsi24')}") + + return { + "message": "\n".join(message_parts), + "metadata": { + "type": "technical", + "data": data + } + } + + def _format_fundamental_response(self, data: Dict[str, Any]) -> Dict[str, Any]: + """格式化基本面响应""" + if "error" in data: + return { + "message": f"查询失败:{data['error']}", + "metadata": {"type": "error"} + } + + message = f""" +【{data.get('name', '股票')}】基本信息 +股票代码:{data.get('ts_code', '')} +所属地域:{data.get('area', '')} +所属行业:{data.get('industry', '')} +上市市场:{data.get('market', '')} +上市日期:{data.get('list_date', '')} + """.strip() + + return { + "message": message, + "metadata": { + "type": "fundamental", + "data": data + } + } + + def _format_visualization_response(self, data: Dict[str, Any]) -> Dict[str, Any]: + """格式化可视化响应""" + if "error" in data: + return { + "message": f"生成图表失败:{data['error']}", + "metadata": {"type": "error"} + } + + return { + "message": f"已生成{data.get('stock_code', '')}的K线图", + "metadata": { + "type": "chart", + "data": data + } + } + + +# 创建全局Agent实例 +stock_agent = StockAnalysisAgent() diff --git a/backend/app/agent/enhanced_agent.py b/backend/app/agent/enhanced_agent.py new file mode 100644 index 0000000..bbd90f0 --- /dev/null +++ b/backend/app/agent/enhanced_agent.py @@ -0,0 +1,377 @@ +""" +增强版Agent - 集成LLM智能分析 +""" +import re +import json +from typing import Dict, Any, Optional +from app.config import get_settings +from app.agent.context import ContextManager +from app.agent.skill_manager import skill_manager +from app.skills.market_data import MarketDataSkill +from app.skills.technical_analysis import TechnicalAnalysisSkill +from app.skills.fundamental import FundamentalSkill +from app.skills.visualization import VisualizationSkill +from app.services.llm_service import llm_service +from app.utils.logger import logger +from app.utils.stock_names import search_stock_by_name, get_stock_name + + +class EnhancedStockAgent: + """增强版股票分析Agent(集成LLM)""" + + def __init__(self): + """初始化Agent""" + self.context_manager = ContextManager() + self.settings = get_settings() + + # 注册技能 + self._register_skills() + + # 检查LLM是否可用 + self.use_llm = bool(self.settings.zhipuai_api_key) and llm_service.client is not None + + if self.use_llm: + logger.info("Enhanced Agent初始化完成(LLM模式)") + else: + logger.info("Enhanced Agent初始化完成(规则模式)") + + def _register_skills(self): + """注册所有技能""" + skill_manager.register(MarketDataSkill()) + skill_manager.register(TechnicalAnalysisSkill()) + skill_manager.register(FundamentalSkill()) + skill_manager.register(VisualizationSkill()) + logger.info("技能注册完成") + + async def process_message( + self, + message: str, + session_id: str, + user_id: Optional[str] = None + ) -> Dict[str, Any]: + """ + 处理用户消息(增强版) + + Args: + message: 用户消息 + session_id: 会话ID + user_id: 用户ID + + Returns: + 响应结果 + """ + logger.info(f"处理消息: {message[:50]}...") + + # 保存用户消息 + self.context_manager.add_message(session_id, "user", message) + + # 提取股票代码 + stock_code = self._extract_stock_code(message) + + # 使用LLM或规则识别意图 + if self.use_llm: + intent = await self._recognize_intent_with_llm(message, stock_code) + else: + intent = self._recognize_intent_with_rules(message, stock_code) + + logger.info(f"识别意图: {intent}") + + # 执行技能 + result = await self._execute_intent(intent, message) + + # 生成响应(使用LLM增强) + response = await self._generate_response(intent, result, stock_code) + + # 保存助手响应 + self.context_manager.add_message( + session_id, + "assistant", + response["message"], + metadata=response.get("metadata") + ) + + return response + + async def _recognize_intent_with_llm( + self, + message: str, + stock_code: Optional[str] + ) -> Dict[str, Any]: + """使用LLM识别意图""" + try: + llm_result = llm_service.analyze_intent(message) + + intent_type = llm_result.get("type", "unknown") + confidence = llm_result.get("confidence", 0) + + # 如果置信度太低,回退到规则模式 + if confidence < 0.5: + logger.info("LLM置信度低,回退到规则模式") + return self._recognize_intent_with_rules(message, stock_code) + + # 构建意图 + intent = { + "type": intent_type, + "confidence": confidence, + "skill": self._map_intent_to_skill(intent_type), + "params": {"stock_code": stock_code} if stock_code else {} + } + + return intent + + except Exception as e: + logger.error(f"LLM意图识别失败: {e}") + return self._recognize_intent_with_rules(message, stock_code) + + def _recognize_intent_with_rules( + self, + message: str, + stock_code: Optional[str] + ) -> Dict[str, Any]: + """使用规则识别意图(原有逻辑)""" + message_lower = message.lower() + + # 行情查询 + if any(keyword in message_lower for keyword in ["行情", "价格", "涨跌", "实时", "quote"]): + return { + "type": "market_data", + "skill": "market_data", + "params": { + "stock_code": stock_code, + "data_type": "quote" + } + } + + # K线查询 + if any(keyword in message_lower for keyword in ["k线", "kline", "走势", "图表"]): + return { + "type": "visualization", + "skill": "visualization", + "params": { + "stock_code": stock_code, + "chart_type": "candlestick" + } + } + + # 技术分析 + if any(keyword in message_lower for keyword in ["技术", "指标", "macd", "rsi", "kdj", "均线", "ma"]): + return { + "type": "technical_analysis", + "skill": "technical_analysis", + "params": { + "stock_code": stock_code, + "indicators": ["ma", "macd", "rsi"] + } + } + + # 基本面 + if any(keyword in message_lower for keyword in ["基本面", "公司", "行业", "信息"]): + return { + "type": "fundamental", + "skill": "fundamental", + "params": { + "stock_code": stock_code + } + } + + # 默认:行情查询 + if stock_code: + return { + "type": "market_data", + "skill": "market_data", + "params": { + "stock_code": stock_code, + "data_type": "quote" + } + } + + # 无法识别 + return { + "type": "unknown", + "skill": None, + "params": {} + } + + def _map_intent_to_skill(self, intent_type: str) -> Optional[str]: + """将意图类型映射到技能名称""" + mapping = { + "market_data": "market_data", + "technical_analysis": "technical_analysis", + "fundamental": "fundamental", + "visualization": "visualization" + } + return mapping.get(intent_type) + + def _extract_stock_code(self, message: str) -> Optional[str]: + """从消息中提取股票代码""" + # 匹配6位数字 + pattern = r'\b\d{6}\b' + matches = re.findall(pattern, message) + + if matches: + return matches[0] + + # 使用股票名称数据库搜索 + chinese_pattern = r'[\u4e00-\u9fa5]{2,6}' + chinese_words = re.findall(chinese_pattern, message) + + for word in chinese_words: + code = search_stock_by_name(word) + if code: + logger.info(f"识别股票名称: {word} -> {code}") + return code + + return None + + async def _execute_intent(self, intent: Dict[str, Any], message: str) -> Dict[str, Any]: + """执行意图对应的技能""" + if intent["type"] == "unknown": + return { + "success": False, + "error": "无法理解您的问题,请提供股票代码或明确的查询意图" + } + + skill_name = intent["skill"] + params = intent["params"] + + if not params.get("stock_code"): + return { + "success": False, + "error": "请提供股票代码或股票名称" + } + + # 执行技能 + result = await skill_manager.execute_skill(skill_name, **params) + return result + + async def _generate_response( + self, + intent: Dict[str, Any], + result: Dict[str, Any], + stock_code: Optional[str] + ) -> Dict[str, Any]: + """生成响应消息(使用LLM增强)""" + if not result.get("success", True): + return { + "message": f"抱歉,{result.get('error', '处理失败')}", + "metadata": {"type": "error"} + } + + data = result.get("data", result) + + # 基础格式化 + base_response = self._format_response_basic(intent, data) + + # 如果启用LLM,添加智能分析 + if self.use_llm and stock_code and intent["type"] == "technical_analysis": + try: + stock_name = get_stock_name(stock_code) or stock_code + llm_summary = llm_service.generate_analysis_summary( + stock_code, stock_name, data + ) + base_response["message"] += f"\n\n【AI分析】\n{llm_summary}" + except Exception as e: + logger.error(f"LLM分析生成失败: {e}") + + return base_response + + def _format_response_basic(self, intent: Dict[str, Any], data: Dict[str, Any]) -> Dict[str, Any]: + """基础响应格式化(原有逻辑)""" + if "error" in data: + return { + "message": f"查询失败:{data['error']}", + "metadata": {"type": "error"} + } + + intent_type = intent["type"] + + if intent_type == "market_data": + return self._format_market_data(data) + elif intent_type == "technical_analysis": + return self._format_technical(data) + elif intent_type == "fundamental": + return self._format_fundamental(data) + elif intent_type == "visualization": + return self._format_visualization(data) + else: + return { + "message": "查询完成", + "metadata": {"type": "data", "data": data} + } + + def _format_market_data(self, data: Dict[str, Any]) -> Dict[str, Any]: + """格式化行情数据""" + if "kline_data" in data: + kline_data = data["kline_data"] + message = f"已获取K线数据,共{len(kline_data)}条记录" + return { + "message": message, + "metadata": {"type": "kline", "data": kline_data} + } + + message = f""" +【{data.get('name', '股票')}】({data.get('ts_code', '')}) +交易日期:{data.get('trade_date', '')} +最新价:{data.get('close', 0):.2f} +涨跌额:{data.get('change', 0):.2f} +涨跌幅:{data.get('pct_chg', 0):.2f}% +开盘价:{data.get('open', 0):.2f} +最高价:{data.get('high', 0):.2f} +最低价:{data.get('low', 0):.2f} +成交量:{data.get('vol', 0):.0f}手 +成交额:{data.get('amount', 0):.0f}千元 + """.strip() + + return { + "message": message, + "metadata": {"type": "quote", "data": data} + } + + def _format_technical(self, data: Dict[str, Any]) -> Dict[str, Any]: + """格式化技术分析""" + indicators = data.get("indicators", {}) + message_parts = [f"【{data.get('stock_code', '')}】技术指标:\n"] + + if "ma" in indicators: + ma = indicators["ma"] + message_parts.append(f"均线:MA5={ma.get('ma5')}, MA10={ma.get('ma10')}, MA20={ma.get('ma20')}") + + if "macd" in indicators: + macd = indicators["macd"] + message_parts.append(f"MACD:DIF={macd.get('dif')}, DEA={macd.get('dea')}, MACD={macd.get('macd')}") + + if "rsi" in indicators: + rsi = indicators["rsi"] + message_parts.append(f"RSI:RSI6={rsi.get('rsi6')}, RSI12={rsi.get('rsi12')}, RSI24={rsi.get('rsi24')}") + + return { + "message": "\n".join(message_parts), + "metadata": {"type": "technical", "data": data} + } + + def _format_fundamental(self, data: Dict[str, Any]) -> Dict[str, Any]: + """格式化基本面""" + message = f""" +【{data.get('name', '股票')}】基本信息 +股票代码:{data.get('ts_code', '')} +所属地域:{data.get('area', '')} +所属行业:{data.get('industry', '')} +上市市场:{data.get('market', '')} +上市日期:{data.get('list_date', '')} + """.strip() + + return { + "message": message, + "metadata": {"type": "fundamental", "data": data} + } + + def _format_visualization(self, data: Dict[str, Any]) -> Dict[str, Any]: + """格式化可视化""" + return { + "message": f"已生成{data.get('stock_code', '')}的K线图", + "metadata": {"type": "chart", "data": data} + } + + +# 创建全局Agent实例 +enhanced_agent = EnhancedStockAgent() diff --git a/backend/app/agent/skill_manager.py b/backend/app/agent/skill_manager.py new file mode 100644 index 0000000..759b512 --- /dev/null +++ b/backend/app/agent/skill_manager.py @@ -0,0 +1,179 @@ +""" +技能管理器 +管理所有技能的注册、发现和调用 +""" +from typing import Dict, Optional, List, Type +from app.skills.base import BaseSkill +from app.utils.logger import logger + + +class SkillManager: + """技能管理器""" + + def __init__(self): + """初始化技能管理器""" + self._skills: Dict[str, BaseSkill] = {} + logger.info("技能管理器初始化") + + def register(self, skill: BaseSkill) -> bool: + """ + 注册技能 + + Args: + skill: 技能实例 + + Returns: + 是否成功 + """ + if not skill.name: + logger.error("技能名称不能为空") + return False + + if skill.name in self._skills: + logger.warning(f"技能已存在,将被覆盖: {skill.name}") + + self._skills[skill.name] = skill + logger.info(f"技能注册成功: {skill.name}") + return True + + def unregister(self, skill_name: str) -> bool: + """ + 注销技能 + + Args: + skill_name: 技能名称 + + Returns: + 是否成功 + """ + if skill_name in self._skills: + del self._skills[skill_name] + logger.info(f"技能注销成功: {skill_name}") + return True + + logger.warning(f"技能不存在: {skill_name}") + return False + + def get_skill(self, skill_name: str) -> Optional[BaseSkill]: + """ + 获取技能 + + Args: + skill_name: 技能名称 + + Returns: + 技能实例或None + """ + return self._skills.get(skill_name) + + def get_all_skills(self) -> List[BaseSkill]: + """ + 获取所有技能 + + Returns: + 技能列表 + """ + return list(self._skills.values()) + + def get_enabled_skills(self) -> List[BaseSkill]: + """ + 获取所有启用的技能 + + Returns: + 启用的技能列表 + """ + return [skill for skill in self._skills.values() if skill.enabled] + + async def execute_skill(self, skill_name: str, **kwargs) -> Dict: + """ + 执行技能 + + Args: + skill_name: 技能名称 + **kwargs: 技能参数 + + Returns: + 执行结果 + """ + skill = self.get_skill(skill_name) + + if not skill: + return { + "success": False, + "error": f"技能不存在: {skill_name}" + } + + if not skill.enabled: + return { + "success": False, + "error": f"技能已禁用: {skill_name}" + } + + # 验证参数 + valid, error = skill.validate_params(**kwargs) + if not valid: + return { + "success": False, + "error": error + } + + # 执行技能 + try: + result = await skill.execute(**kwargs) + return { + "success": True, + "data": result + } + except Exception as e: + logger.error(f"技能执行失败 {skill_name}: {e}") + return { + "success": False, + "error": str(e) + } + + def enable_skill(self, skill_name: str) -> bool: + """ + 启用技能 + + Args: + skill_name: 技能名称 + + Returns: + 是否成功 + """ + skill = self.get_skill(skill_name) + if skill: + skill.enable() + logger.info(f"技能已启用: {skill_name}") + return True + return False + + def disable_skill(self, skill_name: str) -> bool: + """ + 禁用技能 + + Args: + skill_name: 技能名称 + + Returns: + 是否成功 + """ + skill = self.get_skill(skill_name) + if skill: + skill.disable() + logger.info(f"技能已禁用: {skill_name}") + return True + return False + + def get_skills_info(self) -> List[Dict]: + """ + 获取所有技能信息 + + Returns: + 技能信息列表 + """ + return [skill.get_info() for skill in self._skills.values()] + + +# 创建全局技能管理器实例 +skill_manager = SkillManager() diff --git a/backend/app/agent/smart_agent.py b/backend/app/agent/smart_agent.py new file mode 100644 index 0000000..2a47f9c --- /dev/null +++ b/backend/app/agent/smart_agent.py @@ -0,0 +1,966 @@ +""" +智能Agent - 真正使用LLM进行全面分析 +""" +import re +import json +from typing import Dict, Any, Optional, List +from app.config import get_settings +from app.agent.context import ContextManager +from app.agent.skill_manager import skill_manager +from app.skills.market_data import MarketDataSkill +from app.skills.technical_analysis import TechnicalAnalysisSkill +from app.skills.fundamental import FundamentalSkill +from app.skills.visualization import VisualizationSkill +from app.skills.brave_search import BraveSearchSkill +from app.services.llm_service import llm_service +from app.services.tushare_service import tushare_service +from app.utils.logger import logger + + +class SmartStockAgent: + """智能股票分析Agent - 深度集成LLM""" + + def __init__(self): + """初始化Agent""" + self.context_manager = ContextManager() + self.settings = get_settings() + + # 注册技能 + self._register_skills() + + # 检查LLM是否可用 + self.use_llm = bool(self.settings.zhipuai_api_key) and llm_service.client is not None + + if self.use_llm: + logger.info("Smart Agent初始化完成(LLM深度集成模式 + Brave搜索)") + else: + logger.warning("Smart Agent初始化完成(规则模式,建议配置LLM)") + + def _register_skills(self): + """注册所有技能""" + skill_manager.register(MarketDataSkill()) + skill_manager.register(TechnicalAnalysisSkill()) + skill_manager.register(FundamentalSkill()) + skill_manager.register(VisualizationSkill()) + skill_manager.register(BraveSearchSkill()) + logger.info("技能注册完成(包含Brave搜索)") + + async def process_message( + self, + message: str, + session_id: str, + user_id: Optional[str] = None + ) -> Dict[str, Any]: + """ + 处理用户消息(智能版) + + Args: + message: 用户消息 + session_id: 会话ID + user_id: 用户ID + + Returns: + 响应结果 + """ + logger.info(f"处理消息: {message[:50]}...") + + # 保存用户消息 + self.context_manager.add_message(session_id, "user", message) + + # 第一步:使用LLM理解问题意图 + intent_analysis = await self._analyze_question_intent(message) + + if not intent_analysis: + response = { + "message": "抱歉,我无法理解您的问题。请重新描述您的需求。", + "metadata": {"type": "error"} + } + self.context_manager.add_message(session_id, "assistant", response["message"]) + return response + + # 第二步:根据意图类型处理 + question_type = intent_analysis['type'] + + if question_type == 'stock_specific': + # 针对特定股票的问题 + response = await self._handle_stock_question(intent_analysis, message) + elif question_type == 'macro_finance': + # 宏观金融问题 + response = await self._handle_macro_question(intent_analysis, message) + elif question_type == 'knowledge': + # 金融知识问答 + response = await self._handle_knowledge_question(intent_analysis, message) + else: + response = { + "message": "抱歉,我暂时无法处理这类问题。", + "metadata": {"type": "error"} + } + + # 保存助手响应 + self.context_manager.add_message( + session_id, + "assistant", + response["message"], + metadata=response.get("metadata") + ) + + return response + + def _is_comprehensive_analysis(self, message: str) -> bool: + """ + 判断是否需要全面分析 + + 默认情况下,如果用户只是简单提到股票名称或代码,就进行全面分析 + 只有明确要求特定信息时(如"技术指标"、"K线图"等),才做单一查询 + """ + # 明确要求单一查询的关键词 + single_query_keywords = [ + "k线", "图表", "走势图", "kline", + "技术指标", "macd", "rsi", "均线", "kdj", + "基本面", "公司信息", "行业", + "实时行情", "价格", "涨跌" + ] + + # 如果明确要求单一查询,返回False + if any(keyword in message.lower() for keyword in single_query_keywords): + return False + + # 默认进行全面分析 + return True + + async def _comprehensive_analysis( + self, + stock_code: str, + stock_name: Optional[str], + message: str + ) -> Dict[str, Any]: + """ + 全面分析:整合多个数据源 + LLM深度分析 + + Args: + stock_code: 股票代码 + stock_name: 股票名称 + message: 用户消息 + + Returns: + 综合分析结果 + """ + logger.info(f"执行全面分析: {stock_code}") + + display_name = stock_name or stock_code + + # 1. 并行获取所有数据 + try: + # 获取实时行情 + quote_result = await skill_manager.execute_skill( + "market_data", + stock_code=stock_code, + data_type="quote" + ) + + # 获取技术指标 + technical_result = await skill_manager.execute_skill( + "technical_analysis", + stock_code=stock_code, + indicators=["ma", "macd", "rsi", "kdj"] + ) + + # 获取基本面 + fundamental_result = await skill_manager.execute_skill( + "fundamental", + stock_code=stock_code + ) + + # 获取最新新闻(Brave搜索) + search_query = f"{display_name} {stock_code} 股票 最新消息" + news_result = await skill_manager.execute_skill( + "brave_search", + query=search_query, + search_type="news", + count=5, + freshness="pw" # 过去一周 + ) + + # 整合数据 + all_data = { + "stock_code": stock_code, + "stock_name": display_name, + "quote": quote_result.get("data") if quote_result.get("success") else None, + "technical": technical_result.get("data") if technical_result.get("success") else None, + "fundamental": fundamental_result.get("data") if fundamental_result.get("success") else None, + "news": news_result.get("results") if news_result and not news_result.get("error") else None + } + + # 2. 使用LLM进行深度分析 + if self.use_llm: + analysis = await self._llm_comprehensive_analysis(all_data, message) + else: + analysis = self._rule_based_analysis(all_data) + + return { + "message": analysis, + "metadata": { + "type": "comprehensive", + "data": all_data + } + } + + except Exception as e: + logger.error(f"全面分析失败: {e}") + return { + "message": f"分析{display_name}时出错:{str(e)}", + "metadata": {"type": "error"} + } + + async def _llm_comprehensive_analysis( + self, + data: Dict[str, Any], + user_message: str + ) -> str: + """使用LLM进行深度综合分析""" + + # 获取当前时间 + from datetime import datetime + current_time = datetime.now().strftime("%Y-%m-%d %H:%M") + + # 获取行情数据的交易日期 + quote_date = "未知" + if data.get('quote') and data['quote'].get('trade_date'): + quote_date = data['quote']['trade_date'] + + # 构建新闻摘要 + news_summary = "" + news_source_info = "" + if data.get('news'): + news_summary = "\n【消息面分析】\n" + news_summary += f"数据来源:Brave Search API\n" + news_summary += f"搜索时间:{current_time}\n" + news_summary += f"新闻范围:过去一周内相关新闻\n\n" + for idx, news_item in enumerate(data['news'][:5], 1): + news_summary += f"{idx}. {news_item.get('title', '无标题')}\n" + news_summary += f" 来源: {news_item.get('source', '未知')}\n" + news_summary += f" 摘要: {news_item.get('description', '无描述')}\n" + news_summary += f" 发布时间: {news_item.get('published', '未知')}\n\n" + news_source_info = "(消息来源:Brave搜索引擎,数据可能存在延迟)" + else: + news_summary = "\n【消息面分析】\n暂无最新新闻数据\n" + + # 构建详细的分析提示 + prompt = f"""你是一位专业的股票分析师。请对{data['stock_name']}({data['stock_code']})进行全面分析,用简洁专业但易懂的语言回答。 + +用户问题:{user_message} + +【实时行情数据】 +数据来源:Tushare Pro API +交易日期:{quote_date} +{json.dumps(data.get('quote'), ensure_ascii=False, indent=2) if data.get('quote') else '数据获取失败'} + +【技术指标数据】 +数据来源:Tushare Pro API(基于历史K线数据计算) +计算截止日期:{quote_date} +{json.dumps(data.get('technical'), ensure_ascii=False, indent=2) if data.get('technical') else '数据获取失败'} + +【基本面数据】 +数据来源:Tushare Pro API +{json.dumps(data.get('fundamental'), ensure_ascii=False, indent=2) if data.get('fundamental') else '数据获取失败'} +{news_summary} + +请按以下结构进行分析,并在每个部分明确标注数据来源和时效性: + +## 一、基本面分析 +分段说明公司情况,每个要点独立成段: +- 第一段:公司主营业务和行业地位 +- 第二段:所属行业发展前景 +- 第三段:如果有新闻,简要分析对公司的影响{news_source_info} + +## 二、技术面分析(数据截止:{quote_date}) +使用清晰的分段结构,每个技术指标独立成段: + +**价格走势** +当前价格走势特征(上涨/下跌/震荡),结合成交量分析。 + +**均线系统** +短期均线(MA5、MA10)与长期均线(MA20、MA60)的位置关系,判断当前趋势(多头/空头/震荡)。 + +**MACD指标** +DIF和DEA的位置关系,MACD柱状图变化,判断动能强弱和买卖信号。 + +**RSI指标** +当前RSI值的位置,是否超买(>70)或超卖(<30),短期走势预判。 + +**支撑与压力** +关键支撑位和压力位的具体价格区间。 + +## 三、市场情绪分析 +分段分析市场情绪: +- 第一段:当前市场情绪(乐观/谨慎/悲观)及原因 +- 第二段:如果有新闻,分析是利好还是利空 +- 第三段:短期可能的催化因素 + +## 四、投资建议 +清晰分段,每个时间维度独立: + +**短期(1-2周)** +明确的操作建议(买入/持有/观望/减仓)及理由。 + +**中期(1-3个月)** +趋势判断和策略建议。 + +**长期(半年以上)** +投资价值评估。 + +**风险提示** +主要风险点和注意事项。 + +## 五、总结 +用一句话概括核心观点。 + +--- +**数据说明** +- 行情数据来源:Tushare Pro(截止{quote_date}) +- 技术指标:基于历史K线数据计算(截止{quote_date}) +- 新闻数据:Brave搜索(搜索时间{current_time},范围:过去一周) + +写作要求: +1. 语言简洁专业,避免过度修饰和比喻 +2. 专业术语后用括号简单解释,例如"RSI超买(指标>70,股价可能回调)" +3. **重要:每个分析点必须独立成段,段落之间用空行分隔** +4. **技术面分析部分,每个指标必须使用加粗标题(**标题**)并独立成段** +5. 分析要客观理性,基于数据而非情绪 +3. 分析要客观理性,基于数据而非情绪 +4. 结论要明确,不要模棱两可 +5. 控制在500-600字 +6. 最后必须声明:"以上分析仅供参考,不构成投资建议。股市有风险,投资需谨慎。" +""" + + try: + analysis = llm_service.chat( + messages=[{"role": "user", "content": prompt}], + temperature=0.7, + max_tokens=2000 + ) + + if analysis: + return f"【{data['stock_name']}({data['stock_code']}) - AI深度分析】\n\n{analysis}" + else: + return self._rule_based_analysis(data) + + except Exception as e: + logger.error(f"LLM分析失败: {e}") + return self._rule_based_analysis(data) + + async def _llm_single_analysis( + self, + intent: Dict[str, Any], + result: Dict[str, Any], + stock_code: str, + stock_name: Optional[str], + user_message: str + ) -> Dict[str, Any]: + """使用LLM对单一查询进行分析""" + data = result.get("data", result) + display_name = stock_name or stock_code + + # 根据查询类型构建不同的prompt + if intent["type"] == "technical": + prompt = f"""你是一位专业的股票分析师。用户询问了{display_name}({stock_code})的技术指标。 + +用户问题:{user_message} + +【技术指标数据】 +{json.dumps(data, ensure_ascii=False, indent=2)} + +请进行专业的技术分析: + +## 技术指标解读 +1. 均线系统分析: + - 短期均线(MA5、MA10)与长期均线(MA20、MA60)的位置关系 + - 判断当前趋势(多头/空头/震荡) + +2. MACD指标分析: + - DIF和DEA的位置关系 + - MACD柱状图的变化趋势 + - 判断动能强弱 + +3. RSI指标分析: + - 当前RSI值的位置(超买/超卖/中性) + - 短期可能的走势 + +4. KDJ指标分析(如有): + - K、D、J值的位置关系 + - 金叉/死叉信号 + +## 综合判断 +- 短期走势预判(1-2周) +- 关键支撑位和压力位 +- 操作建议(买入/持有/观望/减仓) + +## 风险提示 +- 主要技术风险点 + +写作要求: +1. 语言简洁专业,直接给出分析结论 +2. 基于数据进行分析,不要编造 +3. 控制在300-400字 +4. 最后声明:"以上分析仅供参考,不构成投资建议。股市有风险,投资需谨慎。" +""" + + elif intent["type"] == "quote": + prompt = f"""你是一位专业的股票分析师。用户询问了{display_name}({stock_code})的实时行情。 + +用户问题:{user_message} + +【实时行情数据】 +{json.dumps(data, ensure_ascii=False, indent=2)} + +请进行专业的行情分析: + +## 行情解读 +1. 当日表现: + - 涨跌幅分析 + - 成交量分析 + - 振幅分析 + +2. 价格位置: + - 当前价格相对开盘价、最高价、最低价的位置 + - 判断多空力量对比 + +3. 短期判断: + - 当日走势特征 + - 短期可能的走势 + - 操作建议 + +写作要求: +1. 语言简洁专业,直接给出分析结论 +2. 基于数据进行分析,不要编造 +3. 控制在200-300字 +4. 最后声明:"以上分析仅供参考,不构成投资建议。股市有风险,投资需谨慎。" +""" + + elif intent["type"] == "fundamental": + prompt = f"""你是一位专业的股票分析师。用户询问了{display_name}({stock_code})的基本面信息。 + +用户问题:{user_message} + +【基本面数据】 +{json.dumps(data, ensure_ascii=False, indent=2)} + +请进行专业的基本面分析: + +## 公司概况 +- 公司主营业务 +- 所属行业和地域 +- 上市时间和市场 + +## 行业分析 +- 所属行业的发展前景 +- 行业地位和竞争优势 + +## 投资价值 +- 基本面评估 +- 长期投资价值 +- 关注要点 + +写作要求: +1. 语言简洁专业,直接给出分析结论 +2. 基于数据进行分析,不要编造 +3. 控制在200-300字 +4. 最后声明:"以上分析仅供参考,不构成投资建议。股市有风险,投资需谨慎。" +""" + + else: + # 其他类型,使用通用分析 + prompt = f"""你是一位专业的股票分析师。用户询问了{display_name}({stock_code})的相关信息。 + +用户问题:{user_message} + +【数据】 +{json.dumps(data, ensure_ascii=False, indent=2)} + +请基于提供的数据进行专业分析,给出有价值的见解和建议。 + +写作要求: +1. 语言简洁专业,直接给出分析结论 +2. 基于数据进行分析,不要编造 +3. 控制在200-300字 +4. 最后声明:"以上分析仅供参考,不构成投资建议。股市有风险,投资需谨慎。" +""" + + try: + analysis = llm_service.chat( + messages=[{"role": "user", "content": prompt}], + temperature=0.7, + max_tokens=1500 + ) + + if analysis: + return { + "message": f"【{display_name}({stock_code}) - AI分析】\n\n{analysis}", + "metadata": {"type": intent["type"], "data": data} + } + else: + # LLM失败,使用原始格式化 + return self._format_response(intent, result, stock_code, stock_name) + + except Exception as e: + logger.error(f"LLM单一分析失败: {e}") + return self._format_response(intent, result, stock_code, stock_name) + + def _rule_based_analysis(self, data: Dict[str, Any]) -> str: + """基于规则的分析(LLM不可用时的备选方案)""" + parts = [f"【{data['stock_name']}({data['stock_code']}) - 综合分析】\n"] + + # 行情信息 + if data.get('quote'): + quote = data['quote'] + parts.append("## 一、实时行情") + parts.append(f"最新价:{quote.get('close', 0):.2f}元") + parts.append(f"涨跌幅:{quote.get('pct_chg', 0):.2f}%") + parts.append(f"成交量:{quote.get('vol', 0):.0f}手") + parts.append("") + + # 技术分析 + if data.get('technical'): + tech = data['technical'].get('indicators', {}) + parts.append("## 二、技术指标") + + if 'ma' in tech: + ma = tech['ma'] + parts.append(f"均线系统:MA5={ma.get('ma5')}, MA10={ma.get('ma10')}, MA20={ma.get('ma20')}") + + if 'macd' in tech: + macd = tech['macd'] + parts.append(f"MACD:DIF={macd.get('dif')}, DEA={macd.get('dea')}") + + if 'rsi' in tech: + rsi = tech['rsi'] + rsi6 = rsi.get('rsi6', 50) + if rsi6 > 70: + parts.append(f"RSI:{rsi6:.1f}(超买区域,注意回调风险)") + elif rsi6 < 30: + parts.append(f"RSI:{rsi6:.1f}(超卖区域,可能存在反弹机会)") + else: + parts.append(f"RSI:{rsi6:.1f}(中性区域)") + + parts.append("") + + # 基本面 + if data.get('fundamental'): + fund = data['fundamental'] + parts.append("## 三、基本信息") + parts.append(f"所属行业:{fund.get('industry', '未知')}") + parts.append(f"上市日期:{fund.get('list_date', '未知')}") + parts.append("") + + # 简单建议 + parts.append("## 四、参考建议") + parts.append("建议结合更多信息进行综合判断。") + parts.append("") + parts.append("⚠️ 以上分析仅供参考,不构成投资建议。股市有风险,投资需谨慎。") + + return "\n".join(parts) + + async def _single_query( + self, + stock_code: str, + stock_name: Optional[str], + message: str + ) -> Dict[str, Any]: + """单一查询处理 - 使用LLM进行分析""" + # 识别意图 + intent = self._recognize_intent(message, stock_code) + + # 执行技能 + result = await skill_manager.execute_skill( + intent["skill"], + **intent["params"] + ) + + # 格式化响应 + if not result.get("success", True): + return { + "message": f"查询失败:{result.get('error', '未知错误')}", + "metadata": {"type": "error"} + } + + # 所有查询都使用LLM进行分析(除了可视化) + if intent["type"] != "visualization" and self.use_llm: + return await self._llm_single_analysis(intent, result, stock_code, stock_name, message) + else: + return self._format_response(intent, result, stock_code, stock_name) + + def _recognize_intent(self, message: str, stock_code: str) -> Dict[str, Any]: + """识别查询意图""" + message_lower = message.lower() + + # K线图 + if any(kw in message_lower for kw in ["k线", "图表", "走势图", "kline"]): + return { + "type": "visualization", + "skill": "visualization", + "params": {"stock_code": stock_code} + } + + # 技术分析 + if any(kw in message_lower for kw in ["技术", "指标", "macd", "rsi", "均线"]): + return { + "type": "technical", + "skill": "technical_analysis", + "params": {"stock_code": stock_code, "indicators": ["ma", "macd", "rsi"]} + } + + # 基本面 + if any(kw in message_lower for kw in ["基本面", "公司", "行业", "信息"]): + return { + "type": "fundamental", + "skill": "fundamental", + "params": {"stock_code": stock_code} + } + + # 默认:实时行情 + return { + "type": "quote", + "skill": "market_data", + "params": {"stock_code": stock_code, "data_type": "quote"} + } + + def _format_response( + self, + intent: Dict[str, Any], + result: Dict[str, Any], + stock_code: str, + stock_name: Optional[str] + ) -> Dict[str, Any]: + """格式化响应""" + data = result.get("data", result) + display_name = stock_name or stock_code + + if intent["type"] == "quote": + message = f"""【{display_name}】实时行情 + +交易日期:{data.get('trade_date', '')} +最新价:{data.get('close', 0):.2f}元 +涨跌幅:{data.get('pct_chg', 0):+.2f}% +涨跌额:{data.get('change', 0):+.2f}元 +开盘价:{data.get('open', 0):.2f}元 +最高价:{data.get('high', 0):.2f}元 +最低价:{data.get('low', 0):.2f}元 +成交量:{data.get('vol', 0):.0f}手 +成交额:{data.get('amount', 0):.0f}千元""" + + return { + "message": message, + "metadata": {"type": "quote", "data": data} + } + + elif intent["type"] == "technical": + indicators = data.get("indicators", {}) + parts = [f"【{display_name}】技术指标\n"] + + if "ma" in indicators: + ma = indicators["ma"] + parts.append(f"均线:MA5={ma.get('ma5')}, MA10={ma.get('ma10')}, MA20={ma.get('ma20')}") + + if "macd" in indicators: + macd = indicators["macd"] + parts.append(f"MACD:DIF={macd.get('dif')}, DEA={macd.get('dea')}, MACD={macd.get('macd')}") + + if "rsi" in indicators: + rsi = indicators["rsi"] + parts.append(f"RSI:RSI6={rsi.get('rsi6')}, RSI12={rsi.get('rsi12')}") + + return { + "message": "\n".join(parts), + "metadata": {"type": "technical", "data": data} + } + + elif intent["type"] == "visualization": + return { + "message": f"已生成【{display_name}】的K线图", + "metadata": {"type": "chart", "data": data} + } + + elif intent["type"] == "fundamental": + message = f"""【{display_name}】基本信息 + +股票代码:{data.get('ts_code', '')} +所属地域:{data.get('area', '')} +所属行业:{data.get('industry', '')} +上市市场:{data.get('market', '')} +上市日期:{data.get('list_date', '')}""" + + return { + "message": message, + "metadata": {"type": "fundamental", "data": data} + } + + return { + "message": "查询完成", + "metadata": {"type": "data", "data": data} + } + + async def _analyze_question_intent(self, message: str) -> Optional[Dict[str, Any]]: + """ + 使用LLM分析问题意图 + + Args: + message: 用户消息 + + Returns: + 意图分析结果: { + 'type': 'stock_specific' | 'macro_finance' | 'knowledge', + 'description': '问题描述', + 'keywords': ['关键词列表'], + 'stock_names': ['股票名称'] (如果是stock_specific类型) + } + """ + if not self.use_llm: + logger.warning("LLM未配置,无法分析意图") + return None + + prompt = f"""分析用户的金融问题,判断问题类型和关键信息。 + +用户问题:{message} + +请分析这个问题属于以下哪一类: + +1. **stock_specific** - 针对特定股票的问题 + 例如:"贵州茅台怎么样"、"分析一下比亚迪"、"600519的技术指标" + +2. **macro_finance** - 宏观金融问题(不针对特定股票) + 例如:"现在A股市场怎么样"、"最近有什么投资机会"、"如何看待当前经济形势" + +3. **knowledge** - 金融知识问答 + 例如:"什么是MACD"、"如何看K线图"、"价值投资是什么" + +请以JSON格式返回分析结果: +{{ + "type": "问题类型", + "description": "问题的简要描述", + "keywords": ["关键词1", "关键词2"], + "stock_names": ["股票名称"] (仅当type为stock_specific时) +}} + +只返回JSON,不要有任何其他内容。""" + + try: + result = llm_service.chat( + messages=[{"role": "user", "content": prompt}], + temperature=0.3, + max_tokens=300 + ) + + if not result: + logger.warning("LLM返回空结果") + return None + + # 清理结果,移除可能的markdown代码块标记 + result = result.strip() + if result.startswith("```json"): + result = result[7:] + if result.startswith("```"): + result = result[3:] + if result.endswith("```"): + result = result[:-3] + result = result.strip() + + # 检查是否为空 + if not result: + logger.warning("LLM返回内容为空") + return None + + # 解析JSON + intent = json.loads(result) + logger.info(f"意图分析结果: {intent}") + return intent + + except json.JSONDecodeError as e: + logger.error(f"意图分析JSON解析失败: {e}, 原始响应: {result[:200] if result else 'None'}") + return None + except Exception as e: + logger.error(f"意图分析失败: {e}") + return None + + async def _handle_stock_question( + self, + intent_analysis: Dict[str, Any], + message: str + ) -> Dict[str, Any]: + """处理针对特定股票的问题""" + stock_names = intent_analysis.get('stock_names', []) + + if not stock_names: + return { + "message": "抱歉,我没有识别到您提到的股票。请提供更明确的股票代码或名称。", + "metadata": {"type": "error"} + } + + # 提取第一个股票(暂时只处理单只股票) + stock_keyword = stock_names[0] + + # 使用Tushare搜索股票 + search_results = tushare_service.search_stock(stock_keyword) + + if not search_results: + return { + "message": f"抱歉,未找到股票\"{stock_keyword}\"。请确认股票名称或代码是否正确。", + "metadata": {"type": "error"} + } + + stock = search_results[0] + stock_code = stock['symbol'] + stock_name = stock['name'] + + logger.info(f"处理股票问题: {stock_name}({stock_code})") + + # 判断是否需要全面分析 + is_comprehensive = self._is_comprehensive_analysis(message) + + if is_comprehensive: + return await self._comprehensive_analysis(stock_code, stock_name, message) + else: + return await self._single_query(stock_code, stock_name, message) + + async def _handle_macro_question( + self, + intent_analysis: Dict[str, Any], + message: str + ) -> Dict[str, Any]: + """处理宏观金融问题""" + keywords = intent_analysis.get('keywords', []) + description = intent_analysis.get('description', '') + + logger.info(f"处理宏观问题: {description}") + + # 使用Brave搜索获取最新信息 + search_query = f"A股市场 {' '.join(keywords)} 最新分析" + + try: + news_result = await skill_manager.execute_skill( + "brave_search", + query=search_query, + search_type="news", + count=5, + freshness="pw" + ) + + # 构建新闻摘要 + news_summary = "" + if news_result and not news_result.get("error"): + results = news_result.get("results", []) + if results: + news_summary = "\n【最新市场动态】\n" + for idx, news_item in enumerate(results[:5], 1): + news_summary += f"{idx}. {news_item.get('title', '无标题')}\n" + news_summary += f" 来源: {news_item.get('source', '未知')}\n" + news_summary += f" 时间: {news_item.get('published', '未知')}\n\n" + + # 使用LLM进行分析 + prompt = f"""你是一位专业的金融分析师。用户询问了宏观金融问题。 + +用户问题:{message} + +问题分析:{description} +关键词:{', '.join(keywords)} +{news_summary} + +请基于当前市场情况和最新动态,给出专业的分析和建议: + +## 市场现状分析 +- 当前市场整体情况 +- 主要影响因素 + +## 趋势判断 +- 短期趋势 +- 中长期展望 + +## 投资建议 +- 投资策略建议 +- 风险提示 + +写作要求: +1. 语言简洁专业,避免过度修饰 +2. 分析要客观理性,基于事实 +3. 控制在400-500字 +4. 最后声明:"以上分析仅供参考,不构成投资建议。股市有风险,投资需谨慎。" +""" + + analysis = llm_service.chat( + messages=[{"role": "user", "content": prompt}], + temperature=0.7, + max_tokens=1500 + ) + + if analysis: + return { + "message": f"【宏观市场分析】\n\n{analysis}", + "metadata": {"type": "macro_analysis"} + } + + except Exception as e: + logger.error(f"宏观问题处理失败: {e}") + + return { + "message": "抱歉,暂时无法获取相关信息。请稍后再试。", + "metadata": {"type": "error"} + } + + async def _handle_knowledge_question( + self, + intent_analysis: Dict[str, Any], + message: str + ) -> Dict[str, Any]: + """处理金融知识问答""" + description = intent_analysis.get('description', '') + keywords = intent_analysis.get('keywords', []) + + logger.info(f"处理知识问答: {description}") + + # 直接使用LLM回答 + prompt = f"""你是一位专业的金融教育专家。用户询问了金融知识问题。 + +用户问题:{message} + +请用通俗易懂的语言解释这个概念或回答这个问题: + +## 核心概念 +- 清晰定义和解释 + +## 实际应用 +- 如何在投资中应用 +- 注意事项 + +## 举例说明 +- 用简单的例子帮助理解 + +写作要求: +1. 语言通俗易懂,避免过多专业术语 +2. 如果使用专业术语,要简单解释 +3. 控制在300-400字 +4. 重点是帮助用户理解,而不是炫耀知识 +""" + + try: + answer = llm_service.chat( + messages=[{"role": "user", "content": prompt}], + temperature=0.7, + max_tokens=1200 + ) + + if answer: + return { + "message": f"【金融知识解答】\n\n{answer}", + "metadata": {"type": "knowledge"} + } + + except Exception as e: + logger.error(f"知识问答处理失败: {e}") + + return { + "message": "抱歉,暂时无法回答您的问题。请稍后再试。", + "metadata": {"type": "error"} + } + + +# 创建全局实例 +smart_agent = SmartStockAgent() diff --git a/backend/app/api/__init__.py b/backend/app/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/api/chat.py b/backend/app/api/chat.py new file mode 100644 index 0000000..3ab1c9e --- /dev/null +++ b/backend/app/api/chat.py @@ -0,0 +1,67 @@ +""" +对话API路由 +""" +from fastapi import APIRouter, HTTPException +from typing import Optional +import uuid +from app.models.chat import ChatRequest, ChatResponse +from app.agent.smart_agent import smart_agent # 使用智能Agent +from app.utils.logger import logger + +router = APIRouter() + + +@router.post("/message", response_model=ChatResponse) +async def send_message(request: ChatRequest): + """ + 发送消息给Agent + + Args: + request: 聊天请求 + + Returns: + Agent响应 + """ + try: + # 生成或使用现有session_id + session_id = request.session_id or str(uuid.uuid4()) + + # 处理消息(使用智能Agent) + response = await smart_agent.process_message( + message=request.message, + session_id=session_id, + user_id=request.user_id + ) + + return ChatResponse( + message=response["message"], + session_id=session_id, + metadata=response.get("metadata") + ) + + except Exception as e: + logger.error(f"处理消息失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/history/{session_id}") +async def get_history(session_id: str, limit: int = 50): + """ + 获取对话历史 + + Args: + session_id: 会话ID + limit: 最大消息数 + + Returns: + 对话历史 + """ + try: + context = smart_agent.context_manager.get_context(session_id) + return { + "session_id": session_id, + "messages": context + } + except Exception as e: + logger.error(f"获取历史失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) diff --git a/backend/app/api/skills.py b/backend/app/api/skills.py new file mode 100644 index 0000000..2f89890 --- /dev/null +++ b/backend/app/api/skills.py @@ -0,0 +1,99 @@ +""" +技能管理API路由 +""" +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel +from app.agent.skill_manager import skill_manager +from app.utils.logger import logger + +router = APIRouter() + + +class ToggleRequest(BaseModel): + enabled: bool + + +@router.get("/") +async def list_skills(): + """ + 获取所有技能列表 + + Returns: + 技能信息列表 + """ + try: + skills_info = skill_manager.get_skills_info() + return skills_info # 直接返回数组 + except Exception as e: + logger.error(f"获取技能列表失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/{skill_name}/toggle") +async def toggle_skill(skill_name: str, request: ToggleRequest): + """ + 切换技能状态 + + Args: + skill_name: 技能名称 + request: 包含enabled字段的请求体 + + Returns: + 操作结果 + """ + try: + if request.enabled: + success = skill_manager.enable_skill(skill_name) + message = f"技能 {skill_name} 已启用" + else: + success = skill_manager.disable_skill(skill_name) + message = f"技能 {skill_name} 已禁用" + + if not success: + raise HTTPException(status_code=404, detail="技能不存在") + return {"message": message, "success": True} + except Exception as e: + logger.error(f"切换技能失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/{skill_name}/enable") +async def enable_skill(skill_name: str): + """ + 启用技能 + + Args: + skill_name: 技能名称 + + Returns: + 操作结果 + """ + try: + success = skill_manager.enable_skill(skill_name) + if not success: + raise HTTPException(status_code=404, detail="技能不存在") + return {"message": f"技能 {skill_name} 已启用"} + except Exception as e: + logger.error(f"启用技能失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/{skill_name}/disable") +async def disable_skill(skill_name: str): + """ + 禁用技能 + + Args: + skill_name: 技能名称 + + Returns: + 操作结果 + """ + try: + success = skill_manager.disable_skill(skill_name) + if not success: + raise HTTPException(status_code=404, detail="技能不存在") + return {"message": f"技能 {skill_name} 已禁用"} + except Exception as e: + logger.error(f"禁用技能失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) diff --git a/backend/app/api/stock.py b/backend/app/api/stock.py new file mode 100644 index 0000000..4533117 --- /dev/null +++ b/backend/app/api/stock.py @@ -0,0 +1,80 @@ +""" +股票数据API路由 +""" +from fastapi import APIRouter, HTTPException, Query +from typing import Optional +from app.services.tushare_service import tushare_service +from app.utils.logger import logger + +router = APIRouter() + + +@router.get("/quote/{stock_code}") +async def get_quote(stock_code: str): + """ + 获取股票实时行情 + + Args: + stock_code: 股票代码 + + Returns: + 行情数据 + """ + try: + quote = tushare_service.get_realtime_quote(stock_code) + if not quote: + raise HTTPException(status_code=404, detail="未找到股票数据") + return quote + except Exception as e: + logger.error(f"获取行情失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/kline/{stock_code}") +async def get_kline( + stock_code: str, + start_date: Optional[str] = Query(None, description="开始日期YYYYMMDD"), + end_date: Optional[str] = Query(None, description="结束日期YYYYMMDD"), + period: str = Query("D", description="周期D/W/M") +): + """ + 获取K线数据 + + Args: + stock_code: 股票代码 + start_date: 开始日期 + end_date: 结束日期 + period: 周期 + + Returns: + K线数据 + """ + try: + kline = tushare_service.get_kline_data(stock_code, start_date, end_date, period) + if not kline: + raise HTTPException(status_code=404, detail="未找到K线数据") + return {"kline_data": kline} + except Exception as e: + logger.error(f"获取K线失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/basic/{stock_code}") +async def get_basic(stock_code: str): + """ + 获取股票基本信息 + + Args: + stock_code: 股票代码 + + Returns: + 基本信息 + """ + try: + basic = tushare_service.get_stock_basic(stock_code) + if not basic: + raise HTTPException(status_code=404, detail="未找到股票信息") + return basic + except Exception as e: + logger.error(f"获取基本信息失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) diff --git a/backend/app/config.py b/backend/app/config.py new file mode 100644 index 0000000..22baedc --- /dev/null +++ b/backend/app/config.py @@ -0,0 +1,69 @@ +""" +配置管理模块 +从环境变量加载配置 +""" +import os +from pathlib import Path +from typing import Optional +from pydantic_settings import BaseSettings +from functools import lru_cache + + +# 查找.env文件的位置 +def find_env_file(): + """查找.env文件,支持从backend目录或项目根目录启动""" + current_dir = Path.cwd() + + # 尝试当前目录 + env_path = current_dir / ".env" + if env_path.exists(): + return str(env_path) + + # 尝试父目录(项目根目录) + env_path = current_dir.parent / ".env" + if env_path.exists(): + return str(env_path) + + # 尝试backend的父目录 + if current_dir.name == "backend": + env_path = current_dir.parent / ".env" + if env_path.exists(): + return str(env_path) + + # 默认返回当前目录的.env + return ".env" + + +class Settings(BaseSettings): + """应用配置""" + + # Tushare配置 + tushare_token: str = "" + + # 智谱AI配置 + zhipuai_api_key: str = "" + + # 数据库配置 + database_url: str = "sqlite:///./stock_agent.db" + + # API配置 + api_host: str = "0.0.0.0" + api_port: int = 8000 + debug: bool = True + + # 安全配置 + secret_key: str = "change-this-secret-key-in-production" + rate_limit: str = "100/minute" + + # CORS配置 + cors_origins: str = "http://localhost:8000,http://127.0.0.1:8000" + + class Config: + env_file = find_env_file() + case_sensitive = False + + +@lru_cache() +def get_settings() -> Settings: + """获取配置单例""" + return Settings() diff --git a/backend/app/main.py b/backend/app/main.py new file mode 100644 index 0000000..4de4af7 --- /dev/null +++ b/backend/app/main.py @@ -0,0 +1,70 @@ +""" +FastAPI主应用 +""" +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from fastapi.staticfiles import StaticFiles +from fastapi.responses import FileResponse +from app.config import get_settings +from app.utils.logger import logger +from app.api import chat, stock, skills +import os + +# 创建FastAPI应用 +app = FastAPI( + title="A股AI分析Agent系统", + description="基于AI Agent的股票智能分析系统", + version="1.0.0" +) + +# 配置CORS +settings = get_settings() +app.add_middleware( + CORSMiddleware, + allow_origins=settings.cors_origins.split(","), + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# 注册路由 +app.include_router(chat.router, prefix="/api/chat", tags=["对话"]) +app.include_router(stock.router, prefix="/api/stock", tags=["股票数据"]) +app.include_router(skills.router, prefix="/api/skills", tags=["技能管理"]) + +# 挂载静态文件 +frontend_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "frontend") +if os.path.exists(frontend_path): + app.mount("/static", StaticFiles(directory=frontend_path), name="static") + +@app.get("/") +async def root(): + """根路径,返回前端页面""" + index_path = os.path.join(frontend_path, "index.html") + if os.path.exists(index_path): + return FileResponse(index_path) + return {"message": "A股AI分析Agent系统API"} + +@app.get("/health") +async def health_check(): + """健康检查""" + return {"status": "healthy"} + +@app.on_event("startup") +async def startup_event(): + """启动事件""" + logger.info("应用启动") + +@app.on_event("shutdown") +async def shutdown_event(): + """关闭事件""" + logger.info("应用关闭") + +if __name__ == "__main__": + import uvicorn + uvicorn.run( + "app.main:app", + host=settings.api_host, + port=settings.api_port, + reload=settings.debug + ) diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/models/chat.py b/backend/app/models/chat.py new file mode 100644 index 0000000..923dd22 --- /dev/null +++ b/backend/app/models/chat.py @@ -0,0 +1,37 @@ +""" +对话相关的Pydantic模型 +""" +from datetime import datetime +from typing import Optional, Dict, Any +from pydantic import BaseModel, Field + + +class ChatMessage(BaseModel): + """聊天消息""" + role: str = Field(..., description="角色:user或assistant") + content: str = Field(..., description="消息内容") + metadata: Optional[Dict[str, Any]] = Field(None, description="元数据") + + +class ChatRequest(BaseModel): + """聊天请求""" + message: str = Field(..., description="用户消息", min_length=1) + session_id: Optional[str] = Field(None, description="会话ID") + user_id: Optional[str] = Field(None, description="用户ID") + + +class ChatResponse(BaseModel): + """聊天响应""" + message: str = Field(..., description="助手回复") + session_id: str = Field(..., description="会话ID") + metadata: Optional[Dict[str, Any]] = Field(None, description="元数据") + + +class ConversationHistory(BaseModel): + """对话历史""" + session_id: str + messages: list[ChatMessage] + created_at: datetime + + class Config: + from_attributes = True diff --git a/backend/app/models/database.py b/backend/app/models/database.py new file mode 100644 index 0000000..a201259 --- /dev/null +++ b/backend/app/models/database.py @@ -0,0 +1,47 @@ +""" +数据库模型定义 +""" +from datetime import datetime +from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey, JSON +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import relationship + +Base = declarative_base() + + +class Conversation(Base): + """对话记录表""" + __tablename__ = "conversations" + + id = Column(Integer, primary_key=True, index=True) + session_id = Column(String(64), nullable=False, index=True) + user_id = Column(String(64), nullable=True) + created_at = Column(DateTime, default=datetime.utcnow) + + # 关联消息 + messages = relationship("Message", back_populates="conversation", cascade="all, delete-orphan") + + +class Message(Base): + """消息记录表""" + __tablename__ = "messages" + + id = Column(Integer, primary_key=True, index=True) + conversation_id = Column(Integer, ForeignKey("conversations.id"), nullable=False) + role = Column(String(20), nullable=False) # 'user' or 'assistant' + content = Column(Text, nullable=False) + msg_metadata = Column(JSON, nullable=True) # 改名避免与SQLAlchemy保留字冲突 + created_at = Column(DateTime, default=datetime.utcnow) + + # 关联对话 + conversation = relationship("Conversation", back_populates="messages") + + +class UserPreference(Base): + """用户偏好表""" + __tablename__ = "user_preferences" + + id = Column(Integer, primary_key=True, index=True) + user_id = Column(String(64), unique=True, nullable=False, index=True) + preferences = Column(JSON, nullable=True) + updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) diff --git a/backend/app/models/stock.py b/backend/app/models/stock.py new file mode 100644 index 0000000..bf99c6e --- /dev/null +++ b/backend/app/models/stock.py @@ -0,0 +1,48 @@ +""" +股票相关的Pydantic模型 +""" +from datetime import date +from typing import Optional, List +from pydantic import BaseModel, Field + + +class StockQuote(BaseModel): + """股票行情""" + ts_code: str = Field(..., description="股票代码") + name: Optional[str] = Field(None, description="股票名称") + trade_date: Optional[str] = Field(None, description="交易日期") + open: Optional[float] = Field(None, description="开盘价") + high: Optional[float] = Field(None, description="最高价") + low: Optional[float] = Field(None, description="最低价") + close: Optional[float] = Field(None, description="收盘价") + pre_close: Optional[float] = Field(None, description="昨收价") + change: Optional[float] = Field(None, description="涨跌额") + pct_chg: Optional[float] = Field(None, description="涨跌幅%") + vol: Optional[float] = Field(None, description="成交量(手)") + amount: Optional[float] = Field(None, description="成交额(千元)") + + +class KLineData(BaseModel): + """K线数据""" + ts_code: str = Field(..., description="股票代码") + trade_date: str = Field(..., description="交易日期") + open: float = Field(..., description="开盘价") + high: float = Field(..., description="最高价") + low: float = Field(..., description="最低价") + close: float = Field(..., description="收盘价") + vol: float = Field(..., description="成交量") + amount: Optional[float] = Field(None, description="成交额") + + +class TechnicalIndicators(BaseModel): + """技术指标""" + ma5: Optional[List[float]] = Field(None, description="5日均线") + ma10: Optional[List[float]] = Field(None, description="10日均线") + ma20: Optional[List[float]] = Field(None, description="20日均线") + macd_dif: Optional[List[float]] = Field(None, description="MACD DIF") + macd_dea: Optional[List[float]] = Field(None, description="MACD DEA") + macd: Optional[List[float]] = Field(None, description="MACD柱") + rsi: Optional[List[float]] = Field(None, description="RSI") + kdj_k: Optional[List[float]] = Field(None, description="KDJ K值") + kdj_d: Optional[List[float]] = Field(None, description="KDJ D值") + kdj_j: Optional[List[float]] = Field(None, description="KDJ J值") diff --git a/backend/app/services/__init__.py b/backend/app/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/services/cache_service.py b/backend/app/services/cache_service.py new file mode 100644 index 0000000..482908b --- /dev/null +++ b/backend/app/services/cache_service.py @@ -0,0 +1,144 @@ +""" +缓存服务 +提供数据缓存功能(使用内存缓存) +""" +import time +from typing import Optional, Any, Dict +from app.utils.logger import logger + + +class CacheService: + """内存缓存服务类""" + + def __init__(self): + """初始化内存缓存""" + self._cache: Dict[str, tuple[Any, float]] = {} # key: (value, expire_time) + logger.info("内存缓存初始化成功") + + def get(self, key: str) -> Optional[Any]: + """ + 获取缓存数据 + + Args: + key: 缓存键 + + Returns: + 缓存的数据,不存在或过期返回None + """ + try: + if key in self._cache: + value, expire_time = self._cache[key] + # 检查是否过期 + if time.time() < expire_time: + return value + else: + # 删除过期数据 + del self._cache[key] + return None + except Exception as e: + logger.error(f"获取缓存失败: {e}") + return None + + def set(self, key: str, value: Any, ttl: int = 3600) -> bool: + """ + 设置缓存数据 + + Args: + key: 缓存键 + value: 要缓存的数据 + ttl: 过期时间(秒) + + Returns: + 是否成功 + """ + try: + expire_time = time.time() + ttl + self._cache[key] = (value, expire_time) + return True + except Exception as e: + logger.error(f"设置缓存失败: {e}") + return False + + def delete(self, key: str) -> bool: + """ + 删除缓存 + + Args: + key: 缓存键 + + Returns: + 是否成功 + """ + try: + if key in self._cache: + del self._cache[key] + return True + except Exception as e: + logger.error(f"删除缓存失败: {e}") + return False + + def exists(self, key: str) -> bool: + """ + 检查缓存是否存在 + + Args: + key: 缓存键 + + Returns: + 是否存在 + """ + try: + if key in self._cache: + _, expire_time = self._cache[key] + if time.time() < expire_time: + return True + else: + del self._cache[key] + return False + except Exception as e: + logger.error(f"检查缓存失败: {e}") + return False + + def clear_pattern(self, pattern: str) -> int: + """ + 清除匹配模式的所有缓存 + + Args: + pattern: 键模式(如 "stock:*") + + Returns: + 删除的键数量 + """ + try: + # 简单的模式匹配(支持*通配符) + pattern = pattern.replace('*', '') + keys_to_delete = [k for k in self._cache.keys() if pattern in k] + + for key in keys_to_delete: + del self._cache[key] + + return len(keys_to_delete) + except Exception as e: + logger.error(f"清除缓存失败: {e}") + return 0 + + def clear_expired(self): + """清除所有过期的缓存""" + try: + current_time = time.time() + expired_keys = [ + key for key, (_, expire_time) in self._cache.items() + if current_time >= expire_time + ] + + for key in expired_keys: + del self._cache[key] + + if expired_keys: + logger.info(f"清除了{len(expired_keys)}个过期缓存") + except Exception as e: + logger.error(f"清除过期缓存失败: {e}") + + +# 创建全局实例 +cache_service = CacheService() diff --git a/backend/app/services/db_service.py b/backend/app/services/db_service.py new file mode 100644 index 0000000..f26a0ca --- /dev/null +++ b/backend/app/services/db_service.py @@ -0,0 +1,210 @@ +""" +数据库服务 +提供数据库操作功能 +""" +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker, Session +from typing import Optional, List +from datetime import datetime +import uuid + +from app.config import get_settings +from app.models.database import Base, Conversation, Message, UserPreference +from app.utils.logger import logger + + +class DatabaseService: + """数据库服务类""" + + def __init__(self): + """初始化数据库连接""" + settings = get_settings() + self.engine = create_engine( + settings.database_url, + connect_args={"check_same_thread": False} if "sqlite" in settings.database_url else {} + ) + self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self.engine) + + # 创建表 + Base.metadata.create_all(bind=self.engine) + logger.info("数据库初始化成功") + + def get_session(self) -> Session: + """获取数据库会话""" + return self.SessionLocal() + + def create_conversation(self, session_id: Optional[str] = None, user_id: Optional[str] = None) -> Conversation: + """ + 创建新对话 + + Args: + session_id: 会话ID(可选,自动生成) + user_id: 用户ID + + Returns: + 对话对象 + """ + db = self.get_session() + try: + if not session_id: + session_id = str(uuid.uuid4()) + + conversation = Conversation( + session_id=session_id, + user_id=user_id + ) + db.add(conversation) + db.commit() + db.refresh(conversation) + return conversation + finally: + db.close() + + def get_conversation(self, session_id: str) -> Optional[Conversation]: + """ + 获取对话 + + Args: + session_id: 会话ID + + Returns: + 对话对象或None + """ + db = self.get_session() + try: + return db.query(Conversation).filter(Conversation.session_id == session_id).first() + finally: + db.close() + + def add_message( + self, + session_id: str, + role: str, + content: str, + metadata: Optional[dict] = None + ) -> Message: + """ + 添加消息 + + Args: + session_id: 会话ID + role: 角色(user/assistant) + content: 消息内容 + metadata: 元数据 + + Returns: + 消息对象 + """ + db = self.get_session() + try: + # 获取或创建对话 + conversation = db.query(Conversation).filter( + Conversation.session_id == session_id + ).first() + + if not conversation: + conversation = Conversation(session_id=session_id) + db.add(conversation) + db.commit() + db.refresh(conversation) + + # 创建消息 + message = Message( + conversation_id=conversation.id, + role=role, + content=content, + msg_metadata=metadata + ) + db.add(message) + db.commit() + db.refresh(message) + return message + finally: + db.close() + + def get_conversation_history(self, session_id: str, limit: int = 50) -> List[Message]: + """ + 获取对话历史 + + Args: + session_id: 会话ID + limit: 最大消息数 + + Returns: + 消息列表 + """ + db = self.get_session() + try: + conversation = db.query(Conversation).filter( + Conversation.session_id == session_id + ).first() + + if not conversation: + return [] + + messages = db.query(Message).filter( + Message.conversation_id == conversation.id + ).order_by(Message.created_at.desc()).limit(limit).all() + + return list(reversed(messages)) + finally: + db.close() + + def get_user_preference(self, user_id: str) -> Optional[dict]: + """ + 获取用户偏好 + + Args: + user_id: 用户ID + + Returns: + 偏好字典或None + """ + db = self.get_session() + try: + pref = db.query(UserPreference).filter( + UserPreference.user_id == user_id + ).first() + return pref.preferences if pref else None + finally: + db.close() + + def set_user_preference(self, user_id: str, preferences: dict) -> bool: + """ + 设置用户偏好 + + Args: + user_id: 用户ID + preferences: 偏好字典 + + Returns: + 是否成功 + """ + db = self.get_session() + try: + pref = db.query(UserPreference).filter( + UserPreference.user_id == user_id + ).first() + + if pref: + pref.preferences = preferences + pref.updated_at = datetime.utcnow() + else: + pref = UserPreference( + user_id=user_id, + preferences=preferences + ) + db.add(pref) + + db.commit() + return True + except Exception as e: + logger.error(f"设置用户偏好失败: {e}") + db.rollback() + return False + finally: + db.close() + + +# 创建全局实例 +db_service = DatabaseService() diff --git a/backend/app/services/llm_service.py b/backend/app/services/llm_service.py new file mode 100644 index 0000000..03faaf2 --- /dev/null +++ b/backend/app/services/llm_service.py @@ -0,0 +1,175 @@ +""" +LLM服务 - 智谱AI GLM-4集成 +""" +from typing import Optional, List, Dict, Any +from app.config import get_settings +from app.utils.logger import logger + +try: + from zhipuai import ZhipuAI + ZHIPUAI_AVAILABLE = True +except ImportError: + ZHIPUAI_AVAILABLE = False + logger.warning("zhipuai包未安装,LLM功能将不可用") + + +class LLMService: + """LLM服务类""" + + def __init__(self): + """初始化LLM服务""" + settings = get_settings() + + if not ZHIPUAI_AVAILABLE: + logger.warning("智谱AI SDK未安装") + self.client = None + return + + if not settings.zhipuai_api_key: + logger.warning("智谱AI API Key未配置") + self.client = None + return + + try: + self.client = ZhipuAI(api_key=settings.zhipuai_api_key) + logger.info("智谱AI LLM服务初始化成功") + except Exception as e: + logger.error(f"智谱AI初始化失败: {e}") + self.client = None + + def chat( + self, + messages: List[Dict[str, str]], + model: str = "glm-4", + temperature: float = 0.7, + max_tokens: int = 2000 + ) -> Optional[str]: + """ + 调用LLM进行对话 + + Args: + messages: 消息列表 [{"role": "user", "content": "..."}] + model: 模型名称 + temperature: 温度参数 + max_tokens: 最大token数 + + Returns: + LLM响应文本 + """ + if not self.client: + logger.error("LLM客户端未初始化") + return None + + try: + logger.info(f"调用LLM: model={model}, messages={len(messages)}条") + response = self.client.chat.completions.create( + model=model, + messages=messages, + temperature=temperature, + max_tokens=max_tokens + ) + + if response.choices: + content = response.choices[0].message.content + logger.info(f"LLM响应成功,长度: {len(content) if content else 0}") + return content + else: + logger.warning("LLM响应中没有choices") + return None + + except Exception as e: + logger.error(f"LLM调用失败: {type(e).__name__}: {e}") + import traceback + logger.error(f"详细错误: {traceback.format_exc()}") + return None + + def analyze_intent(self, user_message: str) -> Dict[str, Any]: + """ + 使用LLM分析用户意图 + + Args: + user_message: 用户消息 + + Returns: + 意图分析结果 + """ + if not self.client: + return {"type": "unknown", "confidence": 0} + + prompt = f"""你是一个股票分析助手的意图识别模块。请分析用户的查询意图。 + +用户消息:{user_message} + +请识别以下意图类型之一: +1. market_data - 查询实时行情、价格 +2. technical_analysis - 技术分析、技术指标 +3. fundamental - 基本面信息、公司信息 +4. visualization - K线图、图表 +5. unknown - 无法识别 + +请以JSON格式返回: +{{ + "type": "意图类型", + "confidence": 0.0-1.0, + "stock_name": "提取的股票名称(如果有)" +}} +""" + + try: + response = self.chat([{"role": "user", "content": prompt}], temperature=0.3) + if response: + import json + # 尝试解析JSON + result = json.loads(response) + return result + except Exception as e: + logger.error(f"意图分析失败: {e}") + + return {"type": "unknown", "confidence": 0} + + def generate_analysis_summary( + self, + stock_code: str, + stock_name: str, + data: Dict[str, Any] + ) -> str: + """ + 使用LLM生成分析总结 + + Args: + stock_code: 股票代码 + stock_name: 股票名称 + data: 分析数据 + + Returns: + 分析总结文本 + """ + if not self.client: + return "LLM服务不可用,无法生成智能分析" + + prompt = f"""你是一个专业的股票分析师。请根据以下数据对{stock_name}({stock_code})进行分析总结。 + +数据: +{data} + +请提供: +1. 当前状态评估 +2. 技术指标解读 +3. 投资建议(仅供参考) + +注意: +- 使用专业但易懂的语言 +- 控制在200字以内 +- 必须声明"仅供参考,不构成投资建议" +""" + + try: + response = self.chat([{"role": "user", "content": prompt}], temperature=0.7) + return response or "分析生成失败" + except Exception as e: + logger.error(f"分析总结生成失败: {e}") + return "分析生成失败" + + +# 创建全局实例 +llm_service = LLMService() diff --git a/backend/app/services/tushare_service.py b/backend/app/services/tushare_service.py new file mode 100644 index 0000000..8df573e --- /dev/null +++ b/backend/app/services/tushare_service.py @@ -0,0 +1,263 @@ +""" +Tushare数据服务 +封装Tushare API调用 +""" +import tushare as ts +import pandas as pd +from typing import Optional, List +from datetime import datetime, timedelta +from app.config import get_settings +from app.utils.logger import logger +from app.utils.validators import normalize_stock_code + + +class TushareService: + """Tushare数据服务类""" + + def __init__(self): + """初始化Tushare服务""" + settings = get_settings() + if not settings.tushare_token: + logger.warning("Tushare token未配置") + self.pro = None + else: + ts.set_token(settings.tushare_token) + self.pro = ts.pro_api() + logger.info("Tushare服务初始化成功") + + def get_realtime_quote(self, stock_code: str) -> Optional[dict]: + """ + 获取实时行情 + + Args: + stock_code: 股票代码 + + Returns: + 行情数据字典 + """ + if not self.pro: + logger.error("Tushare服务未初始化") + return None + + try: + # 标准化股票代码 + ts_code = normalize_stock_code(stock_code) + if not ts_code: + logger.error(f"无效的股票代码: {stock_code}") + return None + + # 获取最新交易日数据 + df = self.pro.daily(ts_code=ts_code, start_date='', end_date='') + + if df.empty: + logger.warning(f"未找到股票数据: {ts_code}") + return None + + # 取最新一条 + latest = df.iloc[0] + + # 获取股票名称 + stock_info = self.pro.stock_basic(ts_code=ts_code, fields='ts_code,name') + name = stock_info.iloc[0]['name'] if not stock_info.empty else None + + return { + 'ts_code': ts_code, + 'name': name, + 'trade_date': latest['trade_date'], + 'open': float(latest['open']), + 'high': float(latest['high']), + 'low': float(latest['low']), + 'close': float(latest['close']), + 'pre_close': float(latest['pre_close']), + 'change': float(latest['change']), + 'pct_chg': float(latest['pct_chg']), + 'vol': float(latest['vol']), + 'amount': float(latest['amount']) + } + + except Exception as e: + logger.error(f"获取实时行情失败: {e}") + return None + + def get_kline_data( + self, + stock_code: str, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + period: str = 'D' + ) -> Optional[List[dict]]: + """ + 获取K线数据 + + Args: + stock_code: 股票代码 + start_date: 开始日期(YYYYMMDD) + end_date: 结束日期(YYYYMMDD) + period: 周期(D=日,W=周,M=月) + + Returns: + K线数据列表 + """ + if not self.pro: + logger.error("Tushare服务未初始化") + return None + + try: + # 标准化股票代码 + ts_code = normalize_stock_code(stock_code) + if not ts_code: + logger.error(f"无效的股票代码: {stock_code}") + return None + + # 默认获取最近60个交易日 + if not start_date: + start_date = (datetime.now() - timedelta(days=90)).strftime('%Y%m%d') + if not end_date: + end_date = datetime.now().strftime('%Y%m%d') + + # 获取日线数据 + if period == 'D': + df = self.pro.daily( + ts_code=ts_code, + start_date=start_date, + end_date=end_date + ) + elif period == 'W': + df = self.pro.weekly( + ts_code=ts_code, + start_date=start_date, + end_date=end_date + ) + elif period == 'M': + df = self.pro.monthly( + ts_code=ts_code, + start_date=start_date, + end_date=end_date + ) + else: + logger.error(f"不支持的周期: {period}") + return None + + if df.empty: + logger.warning(f"未找到K线数据: {ts_code}") + return None + + # 按日期升序排列 + df = df.sort_values('trade_date') + + # 转换为字典列表 + kline_data = [] + for _, row in df.iterrows(): + kline_data.append({ + 'ts_code': ts_code, + 'trade_date': row['trade_date'], + 'open': float(row['open']), + 'high': float(row['high']), + 'low': float(row['low']), + 'close': float(row['close']), + 'vol': float(row['vol']), + 'amount': float(row['amount']) if pd.notna(row['amount']) else None + }) + + return kline_data + + except Exception as e: + logger.error(f"获取K线数据失败: {e}") + return None + + def get_stock_basic(self, stock_code: str) -> Optional[dict]: + """ + 获取股票基本信息 + + Args: + stock_code: 股票代码 + + Returns: + 基本信息字典 + """ + if not self.pro: + logger.error("Tushare服务未初始化") + return None + + try: + ts_code = normalize_stock_code(stock_code) + if not ts_code: + return None + + df = self.pro.stock_basic( + ts_code=ts_code, + fields='ts_code,symbol,name,area,industry,market,list_date' + ) + + if df.empty: + return None + + info = df.iloc[0] + return { + 'ts_code': info['ts_code'], + 'symbol': info['symbol'], + 'name': info['name'], + 'area': info['area'], + 'industry': info['industry'], + 'market': info['market'], + 'list_date': info['list_date'] + } + + except Exception as e: + logger.error(f"获取股票基本信息失败: {e}") + return None + + def search_stock(self, keyword: str) -> Optional[List[dict]]: + """ + 搜索股票(通过名称或代码) + + Args: + keyword: 搜索关键词(股票名称或代码) + + Returns: + 匹配的股票列表 + """ + if not self.pro: + logger.error("Tushare服务未初始化") + return None + + try: + # 获取所有股票列表 + df = self.pro.stock_basic( + fields='ts_code,symbol,name,area,industry,market,list_date' + ) + + if df.empty: + return None + + # 搜索匹配的股票 + # 1. 精确匹配代码 + exact_match = df[df['symbol'] == keyword] + if not exact_match.empty: + return [exact_match.iloc[0].to_dict()] + + # 2. 模糊匹配名称 + name_match = df[df['name'].str.contains(keyword, na=False)] + if not name_match.empty: + results = [] + for _, row in name_match.iterrows(): + results.append(row.to_dict()) + return results[:5] # 最多返回5个结果 + + # 3. 模糊匹配代码 + code_match = df[df['symbol'].str.contains(keyword, na=False)] + if not code_match.empty: + results = [] + for _, row in code_match.iterrows(): + results.append(row.to_dict()) + return results[:5] + + return None + + except Exception as e: + logger.error(f"搜索股票失败: {e}") + return None + + +# 创建全局实例 +tushare_service = TushareService() diff --git a/backend/app/skills/__init__.py b/backend/app/skills/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/skills/base.py b/backend/app/skills/base.py new file mode 100644 index 0000000..90ea223 --- /dev/null +++ b/backend/app/skills/base.py @@ -0,0 +1,78 @@ +""" +技能基类 +所有技能插件的基类 +""" +from abc import ABC, abstractmethod +from typing import Dict, Any, Optional +from pydantic import BaseModel, Field + + +class SkillParameter(BaseModel): + """技能参数定义""" + name: str = Field(..., description="参数名称") + type: str = Field(..., description="参数类型") + description: str = Field(..., description="参数描述") + required: bool = Field(True, description="是否必需") + default: Optional[Any] = Field(None, description="默认值") + + +class BaseSkill(ABC): + """技能基类""" + + def __init__(self): + """初始化技能""" + self.name: str = "" + self.description: str = "" + self.parameters: list[SkillParameter] = [] + self.enabled: bool = True + + @abstractmethod + async def execute(self, **kwargs) -> Dict[str, Any]: + """ + 执行技能 + + Args: + **kwargs: 技能参数 + + Returns: + 执行结果字典 + """ + pass + + def validate_params(self, **kwargs) -> tuple[bool, Optional[str]]: + """ + 验证参数 + + Args: + **kwargs: 参数字典 + + Returns: + (是否有效, 错误信息) + """ + for param in self.parameters: + if param.required and param.name not in kwargs: + return False, f"缺少必需参数: {param.name}" + + return True, None + + def get_info(self) -> Dict[str, Any]: + """ + 获取技能信息 + + Returns: + 技能信息字典 + """ + return { + "name": self.name, + "description": self.description, + "parameters": [p.dict() for p in self.parameters], + "enabled": self.enabled + } + + def enable(self): + """启用技能""" + self.enabled = True + + def disable(self): + """禁用技能""" + self.enabled = False diff --git a/backend/app/skills/brave_search.py b/backend/app/skills/brave_search.py new file mode 100644 index 0000000..4255bc2 --- /dev/null +++ b/backend/app/skills/brave_search.py @@ -0,0 +1,180 @@ +""" +Brave搜索技能 +提供网页搜索、新闻搜索等能力 +""" +import aiohttp +from typing import Dict, Any, List, Optional +from app.skills.base import BaseSkill, SkillParameter +from app.utils.logger import logger + + +class BraveSearchSkill(BaseSkill): + """Brave搜索技能""" + + def __init__(self, api_key: str = "BSAcaROCUmCAI0XsQWzxooWT74LFFX_"): + super().__init__() + self.name = "brave_search" + self.description = "使用Brave搜索引擎搜索网页、新闻、公司公告等实时信息" + self.api_key = api_key + self.base_url = "https://api.search.brave.com/res/v1" + + self.parameters = [ + SkillParameter( + name="query", + type="string", + description="搜索关键词", + required=True + ), + SkillParameter( + name="search_type", + type="string", + description="搜索类型:web(网页)、news(新闻)", + required=False, + default="web" + ), + SkillParameter( + name="count", + type="integer", + description="返回结果数量(1-20)", + required=False, + default=5 + ), + SkillParameter( + name="freshness", + type="string", + description="时效性:pd(过去一天)、pw(过去一周)、pm(过去一月)、py(过去一年)", + required=False, + default=None + ) + ] + + async def execute(self, **kwargs) -> Dict[str, Any]: + """ + 执行Brave搜索 + + Args: + query: 搜索关键词 + search_type: 搜索类型(web/news) + count: 结果数量 + freshness: 时效性过滤 + + Returns: + 搜索结果 + """ + query = kwargs.get("query") + search_type = kwargs.get("search_type", "web") + count = kwargs.get("count", 5) + freshness = kwargs.get("freshness") + + logger.info(f"Brave搜索: {query}, 类型: {search_type}") + + try: + if search_type == "news": + results = await self._search_news(query, count, freshness) + else: + results = await self._search_web(query, count, freshness) + + return { + "query": query, + "search_type": search_type, + "results": results, + "count": len(results) + } + + except Exception as e: + logger.error(f"Brave搜索失败: {e}") + return { + "error": f"搜索失败: {str(e)}" + } + + async def _search_web( + self, + query: str, + count: int = 5, + freshness: Optional[str] = None + ) -> List[Dict[str, Any]]: + """网页搜索""" + url = f"{self.base_url}/web/search" + + params = { + "q": query, + "count": min(count, 20), + "text_decorations": False, + "search_lang": "zh-hans" + } + + if freshness: + params["freshness"] = freshness + + headers = { + "Accept": "application/json", + "X-Subscription-Token": self.api_key + } + + async with aiohttp.ClientSession() as session: + async with session.get(url, params=params, headers=headers) as response: + if response.status != 200: + error_text = await response.text() + raise Exception(f"API请求失败: {response.status}, {error_text}") + + data = await response.json() + + # 解析结果 + results = [] + web_results = data.get("web", {}).get("results", []) + + for item in web_results[:count]: + results.append({ + "title": item.get("title", ""), + "url": item.get("url", ""), + "description": item.get("description", ""), + "published": item.get("age", "") + }) + + return results + + async def _search_news( + self, + query: str, + count: int = 5, + freshness: Optional[str] = None + ) -> List[Dict[str, Any]]: + """新闻搜索""" + url = f"{self.base_url}/news/search" + + params = { + "q": query, + "count": min(count, 20), + "search_lang": "zh-hans" + } + + if freshness: + params["freshness"] = freshness + + headers = { + "Accept": "application/json", + "X-Subscription-Token": self.api_key + } + + async with aiohttp.ClientSession() as session: + async with session.get(url, params=params, headers=headers) as response: + if response.status != 200: + error_text = await response.text() + raise Exception(f"API请求失败: {response.status}, {error_text}") + + data = await response.json() + + # 解析结果 + results = [] + news_results = data.get("results", []) + + for item in news_results[:count]: + results.append({ + "title": item.get("title", ""), + "url": item.get("url", ""), + "description": item.get("description", ""), + "published": item.get("age", ""), + "source": item.get("meta_url", {}).get("hostname", "") + }) + + return results diff --git a/backend/app/skills/fundamental.py b/backend/app/skills/fundamental.py new file mode 100644 index 0000000..715b58d --- /dev/null +++ b/backend/app/skills/fundamental.py @@ -0,0 +1,61 @@ +""" +基本面分析技能 +提供股票基本信息查询 +""" +from typing import Dict, Any +from app.skills.base import BaseSkill, SkillParameter +from app.services.tushare_service import tushare_service +from app.services.cache_service import cache_service +from app.utils.logger import logger + + +class FundamentalSkill(BaseSkill): + """基本面分析技能""" + + def __init__(self): + super().__init__() + self.name = "fundamental" + self.description = "查询股票基本面信息(公司概况、行业、上市日期等)" + self.parameters = [ + SkillParameter( + name="stock_code", + type="string", + description="股票代码", + required=True + ) + ] + + async def execute(self, **kwargs) -> Dict[str, Any]: + """ + 执行基本面查询 + + Args: + stock_code: 股票代码 + + Returns: + 基本面信息 + """ + stock_code = kwargs.get("stock_code") + + logger.info(f"查询基本面信息: {stock_code}") + + # 尝试从缓存获取 + cache_key = f"fundamental:{stock_code}" + cached_data = cache_service.get(cache_key) + + if cached_data: + logger.info(f"从缓存获取基本面信息: {stock_code}") + return cached_data + + # 从Tushare获取 + basic_info = tushare_service.get_stock_basic(stock_code) + + if not basic_info: + return { + "error": f"未找到股票基本信息: {stock_code}" + } + + # 缓存1天 + cache_service.set(cache_key, basic_info, ttl=86400) + + return basic_info diff --git a/backend/app/skills/market_data.py b/backend/app/skills/market_data.py new file mode 100644 index 0000000..4263018 --- /dev/null +++ b/backend/app/skills/market_data.py @@ -0,0 +1,140 @@ +""" +行情查询技能 +提供股票实时行情和K线数据查询 +""" +from typing import Dict, Any +from app.skills.base import BaseSkill, SkillParameter +from app.services.tushare_service import tushare_service +from app.services.cache_service import cache_service +from app.utils.logger import logger + + +class MarketDataSkill(BaseSkill): + """行情查询技能""" + + def __init__(self): + super().__init__() + self.name = "market_data" + self.description = "查询股票实时行情和历史K线数据" + self.parameters = [ + SkillParameter( + name="stock_code", + type="string", + description="股票代码(如600000、000001)", + required=True + ), + SkillParameter( + name="data_type", + type="string", + description="数据类型:quote(实时行情)或kline(K线数据)", + required=False, + default="quote" + ), + SkillParameter( + name="start_date", + type="string", + description="开始日期(YYYYMMDD格式,仅K线数据需要)", + required=False + ), + SkillParameter( + name="end_date", + type="string", + description="结束日期(YYYYMMDD格式,仅K线数据需要)", + required=False + ), + SkillParameter( + name="period", + type="string", + description="K线周期:D(日线)、W(周线)、M(月线)", + required=False, + default="D" + ) + ] + + async def execute(self, **kwargs) -> Dict[str, Any]: + """ + 执行行情查询 + + Args: + stock_code: 股票代码 + data_type: 数据类型(quote/kline) + start_date: 开始日期(可选) + end_date: 结束日期(可选) + period: K线周期(可选) + + Returns: + 查询结果 + """ + stock_code = kwargs.get("stock_code") + data_type = kwargs.get("data_type", "quote") + + logger.info(f"查询行情数据: {stock_code}, 类型: {data_type}") + + if data_type == "quote": + return await self._get_quote(stock_code) + elif data_type == "kline": + start_date = kwargs.get("start_date") + end_date = kwargs.get("end_date") + period = kwargs.get("period", "D") + return await self._get_kline(stock_code, start_date, end_date, period) + else: + return { + "error": f"不支持的数据类型: {data_type}" + } + + async def _get_quote(self, stock_code: str) -> Dict[str, Any]: + """获取实时行情""" + # 尝试从缓存获取 + cache_key = f"quote:{stock_code}" + cached_data = cache_service.get(cache_key) + + if cached_data: + logger.info(f"从缓存获取行情: {stock_code}") + return cached_data + + # 从Tushare获取 + quote_data = tushare_service.get_realtime_quote(stock_code) + + if not quote_data: + return { + "error": f"未找到股票数据: {stock_code}" + } + + # 缓存30秒 + cache_service.set(cache_key, quote_data, ttl=30) + + return quote_data + + async def _get_kline( + self, + stock_code: str, + start_date: str = None, + end_date: str = None, + period: str = "D" + ) -> Dict[str, Any]: + """获取K线数据""" + # 尝试从缓存获取 + cache_key = f"kline:{stock_code}:{start_date}:{end_date}:{period}" + cached_data = cache_service.get(cache_key) + + if cached_data: + logger.info(f"从缓存获取K线: {stock_code}") + return {"kline_data": cached_data} + + # 从Tushare获取 + kline_data = tushare_service.get_kline_data( + stock_code, + start_date, + end_date, + period + ) + + if not kline_data: + return { + "error": f"未找到K线数据: {stock_code}" + } + + # 缓存1小时 + cache_service.set(cache_key, kline_data, ttl=3600) + + return {"kline_data": kline_data} diff --git a/backend/app/skills/technical_analysis.py b/backend/app/skills/technical_analysis.py new file mode 100644 index 0000000..438b530 --- /dev/null +++ b/backend/app/skills/technical_analysis.py @@ -0,0 +1,202 @@ +""" +技术分析技能 +提供技术指标计算和分析 +""" +import pandas as pd +from typing import Dict, Any +from app.skills.base import BaseSkill, SkillParameter +from app.services.tushare_service import tushare_service +from app.utils.indicators import ( + calculate_ma, calculate_macd, calculate_rsi, + calculate_kdj, calculate_boll +) +from app.utils.logger import logger + + +class TechnicalAnalysisSkill(BaseSkill): + """技术分析技能""" + + def __init__(self): + super().__init__() + self.name = "technical_analysis" + self.description = "计算股票技术指标(MA、MACD、RSI、KDJ、BOLL等)" + self.parameters = [ + SkillParameter( + name="stock_code", + type="string", + description="股票代码", + required=True + ), + SkillParameter( + name="indicators", + type="array", + description="要计算的指标列表(ma、macd、rsi、kdj、boll)", + required=False, + default=["ma", "macd"] + ), + SkillParameter( + name="period", + type="integer", + description="数据周期(天数)", + required=False, + default=60 + ) + ] + + async def execute(self, **kwargs) -> Dict[str, Any]: + """ + 执行技术分析 + + Args: + stock_code: 股票代码 + indicators: 指标列表 + period: 数据周期 + + Returns: + 技术指标结果 + """ + stock_code = kwargs.get("stock_code") + indicators = kwargs.get("indicators", ["ma", "macd"]) + period = kwargs.get("period", 60) + + logger.info(f"技术分析: {stock_code}, 指标: {indicators}") + + # 获取K线数据 + kline_data = tushare_service.get_kline_data(stock_code) + + if not kline_data: + return { + "error": f"未找到K线数据: {stock_code}" + } + + # 转换为DataFrame + df = pd.DataFrame(kline_data) + + # 计算指标 + result = { + "stock_code": stock_code, + "indicators": {} + } + + try: + if "ma" in indicators: + result["indicators"]["ma"] = self._calculate_ma(df) + + if "macd" in indicators: + result["indicators"]["macd"] = self._calculate_macd(df) + + if "rsi" in indicators: + result["indicators"]["rsi"] = self._calculate_rsi(df) + + if "kdj" in indicators: + result["indicators"]["kdj"] = self._calculate_kdj(df) + + if "boll" in indicators: + result["indicators"]["boll"] = self._calculate_boll(df) + + return result + + except Exception as e: + logger.error(f"技术指标计算失败: {e}") + return { + "error": f"技术指标计算失败: {str(e)}" + } + + def _calculate_ma(self, df: pd.DataFrame) -> Dict[str, Any]: + """计算均线""" + close = df['close'] + + ma5 = calculate_ma(close, 5) + ma10 = calculate_ma(close, 10) + ma20 = calculate_ma(close, 20) + ma60 = calculate_ma(close, 60) + + # 获取最新值 + latest_ma5 = ma5.iloc[-1] if not ma5.empty else None + latest_ma10 = ma10.iloc[-1] if not ma10.empty else None + latest_ma20 = ma20.iloc[-1] if not ma20.empty else None + latest_ma60 = ma60.iloc[-1] if not ma60.empty else None + + return { + "ma5": round(latest_ma5, 2) if latest_ma5 else None, + "ma10": round(latest_ma10, 2) if latest_ma10 else None, + "ma20": round(latest_ma20, 2) if latest_ma20 else None, + "ma60": round(latest_ma60, 2) if latest_ma60 else None, + "description": "移动平均线" + } + + def _calculate_macd(self, df: pd.DataFrame) -> Dict[str, Any]: + """计算MACD""" + close = df['close'] + + dif, dea, macd = calculate_macd(close) + + # 获取最新值 + latest_dif = dif.iloc[-1] if not dif.empty else None + latest_dea = dea.iloc[-1] if not dea.empty else None + latest_macd = macd.iloc[-1] if not macd.empty else None + + return { + "dif": round(latest_dif, 2) if latest_dif else None, + "dea": round(latest_dea, 2) if latest_dea else None, + "macd": round(latest_macd, 2) if latest_macd else None, + "description": "MACD指标" + } + + def _calculate_rsi(self, df: pd.DataFrame) -> Dict[str, Any]: + """计算RSI""" + close = df['close'] + + rsi6 = calculate_rsi(close, 6) + rsi12 = calculate_rsi(close, 12) + rsi24 = calculate_rsi(close, 24) + + # 获取最新值 + latest_rsi6 = rsi6.iloc[-1] if not rsi6.empty else None + latest_rsi12 = rsi12.iloc[-1] if not rsi12.empty else None + latest_rsi24 = rsi24.iloc[-1] if not rsi24.empty else None + + return { + "rsi6": round(latest_rsi6, 2) if latest_rsi6 else None, + "rsi12": round(latest_rsi12, 2) if latest_rsi12 else None, + "rsi24": round(latest_rsi24, 2) if latest_rsi24 else None, + "description": "相对强弱指标" + } + + def _calculate_kdj(self, df: pd.DataFrame) -> Dict[str, Any]: + """计算KDJ""" + high = df['high'] + low = df['low'] + close = df['close'] + + k, d, j = calculate_kdj(high, low, close) + + # 获取最新值 + latest_k = k.iloc[-1] if not k.empty else None + latest_d = d.iloc[-1] if not d.empty else None + latest_j = j.iloc[-1] if not j.empty else None + + return { + "k": round(latest_k, 2) if latest_k else None, + "d": round(latest_d, 2) if latest_d else None, + "j": round(latest_j, 2) if latest_j else None, + "description": "KDJ指标" + } + + def _calculate_boll(self, df: pd.DataFrame) -> Dict[str, Any]: + """计算布林带""" + close = df['close'] + + upper, middle, lower = calculate_boll(close) + + # 获取最新值 + latest_upper = upper.iloc[-1] if not upper.empty else None + latest_middle = middle.iloc[-1] if not middle.empty else None + latest_lower = lower.iloc[-1] if not lower.empty else None + + return { + "upper": round(latest_upper, 2) if latest_upper else None, + "middle": round(latest_middle, 2) if latest_middle else None, + "lower": round(latest_lower, 2) if latest_lower else None, + "description": "布林带" + } diff --git a/backend/app/skills/visualization.py b/backend/app/skills/visualization.py new file mode 100644 index 0000000..13bb88a --- /dev/null +++ b/backend/app/skills/visualization.py @@ -0,0 +1,118 @@ +""" +数据可视化技能 +生成图表配置数据 +""" +from typing import Dict, Any, List +from app.skills.base import BaseSkill, SkillParameter +from app.services.tushare_service import tushare_service +from app.utils.logger import logger + + +class VisualizationSkill(BaseSkill): + """数据可视化技能""" + + def __init__(self): + super().__init__() + self.name = "visualization" + self.description = "生成K线图和技术指标图表配置" + self.parameters = [ + SkillParameter( + name="stock_code", + type="string", + description="股票代码", + required=True + ), + SkillParameter( + name="chart_type", + type="string", + description="图表类型:candlestick(K线图)", + required=False, + default="candlestick" + ), + SkillParameter( + name="period", + type="integer", + description="数据周期(天数)", + required=False, + default=60 + ) + ] + + async def execute(self, **kwargs) -> Dict[str, Any]: + """ + 生成图表配置 + + Args: + stock_code: 股票代码 + chart_type: 图表类型 + period: 数据周期 + + Returns: + 图表配置数据 + """ + stock_code = kwargs.get("stock_code") + chart_type = kwargs.get("chart_type", "candlestick") + period = kwargs.get("period", 60) + + logger.info(f"生成图表配置: {stock_code}, 类型: {chart_type}") + + # 获取K线数据 + kline_data = tushare_service.get_kline_data(stock_code) + + if not kline_data: + return { + "error": f"未找到K线数据: {stock_code}" + } + + # 限制数据量 + if len(kline_data) > period: + kline_data = kline_data[-period:] + + if chart_type == "candlestick": + return self._generate_candlestick_config(kline_data) + else: + return { + "error": f"不支持的图表类型: {chart_type}" + } + + def _generate_candlestick_config(self, kline_data: List[dict]) -> Dict[str, Any]: + """ + 生成K线图配置(Lightweight Charts格式) + + Args: + kline_data: K线数据列表 + + Returns: + 图表配置 + """ + # 转换为Lightweight Charts格式 + candlestick_data = [] + volume_data = [] + + for item in kline_data: + # 转换日期格式 YYYYMMDD -> YYYY-MM-DD + date_str = item['trade_date'] + formatted_date = f"{date_str[:4]}-{date_str[4:6]}-{date_str[6:8]}" + + # K线数据 + candlestick_data.append({ + "time": formatted_date, + "open": item['open'], + "high": item['high'], + "low": item['low'], + "close": item['close'] + }) + + # 成交量数据 + volume_data.append({ + "time": formatted_date, + "value": item['vol'], + "color": "rgba(0, 150, 136, 0.8)" if item['close'] >= item['open'] else "rgba(255, 82, 82, 0.8)" + }) + + return { + "chart_type": "candlestick", + "candlestick_data": candlestick_data, + "volume_data": volume_data, + "stock_code": kline_data[0]['ts_code'] if kline_data else None + } diff --git a/backend/app/utils/__init__.py b/backend/app/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/utils/indicators.py b/backend/app/utils/indicators.py new file mode 100644 index 0000000..06c0346 --- /dev/null +++ b/backend/app/utils/indicators.py @@ -0,0 +1,158 @@ +""" +技术指标计算模块 +提供常用技术指标的计算功能 +""" +import pandas as pd +import numpy as np +from typing import Tuple + + +def calculate_ma(data: pd.Series, period: int = 5) -> pd.Series: + """ + 计算移动平均线(MA) + + Args: + data: 价格数据 + period: 周期 + + Returns: + MA值 + """ + return data.rolling(window=period).mean() + + +def calculate_ema(data: pd.Series, period: int = 12) -> pd.Series: + """ + 计算指数移动平均线(EMA) + + Args: + data: 价格数据 + period: 周期 + + Returns: + EMA值 + """ + return data.ewm(span=period, adjust=False).mean() + + +def calculate_macd( + data: pd.Series, + fast_period: int = 12, + slow_period: int = 26, + signal_period: int = 9 +) -> Tuple[pd.Series, pd.Series, pd.Series]: + """ + 计算MACD指标 + + Args: + data: 价格数据 + fast_period: 快线周期 + slow_period: 慢线周期 + signal_period: 信号线周期 + + Returns: + (DIF, DEA, MACD柱) + """ + ema_fast = calculate_ema(data, fast_period) + ema_slow = calculate_ema(data, slow_period) + + dif = ema_fast - ema_slow + dea = dif.ewm(span=signal_period, adjust=False).mean() + macd = (dif - dea) * 2 + + return dif, dea, macd + + +def calculate_rsi(data: pd.Series, period: int = 14) -> pd.Series: + """ + 计算相对强弱指标(RSI) + + Args: + data: 价格数据 + period: 周期 + + Returns: + RSI值 + """ + delta = data.diff() + + gain = (delta.where(delta > 0, 0)).rolling(window=period).mean() + loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean() + + rs = gain / loss + rsi = 100 - (100 / (1 + rs)) + + return rsi + + +def calculate_kdj( + high: pd.Series, + low: pd.Series, + close: pd.Series, + period: int = 9, + m1: int = 3, + m2: int = 3 +) -> Tuple[pd.Series, pd.Series, pd.Series]: + """ + 计算KDJ指标 + + Args: + high: 最高价 + low: 最低价 + close: 收盘价 + period: 周期 + m1: K值平滑参数 + m2: D值平滑参数 + + Returns: + (K, D, J) + """ + low_min = low.rolling(window=period).min() + high_max = high.rolling(window=period).max() + + rsv = (close - low_min) / (high_max - low_min) * 100 + + k = rsv.ewm(com=m1 - 1, adjust=False).mean() + d = k.ewm(com=m2 - 1, adjust=False).mean() + j = 3 * k - 2 * d + + return k, d, j + + +def calculate_boll( + data: pd.Series, + period: int = 20, + std_dev: float = 2.0 +) -> Tuple[pd.Series, pd.Series, pd.Series]: + """ + 计算布林带(BOLL) + + Args: + data: 价格数据 + period: 周期 + std_dev: 标准差倍数 + + Returns: + (上轨, 中轨, 下轨) + """ + middle = data.rolling(window=period).mean() + std = data.rolling(window=period).std() + + upper = middle + (std * std_dev) + lower = middle - (std * std_dev) + + return upper, middle, lower + + +def calculate_volume_ma(volume: pd.Series, period: int = 5) -> pd.Series: + """ + 计算成交量移动平均 + + Args: + volume: 成交量数据 + period: 周期 + + Returns: + 成交量MA + """ + return volume.rolling(window=period).mean() diff --git a/backend/app/utils/logger.py b/backend/app/utils/logger.py new file mode 100644 index 0000000..4a6b1d9 --- /dev/null +++ b/backend/app/utils/logger.py @@ -0,0 +1,60 @@ +""" +日志工具模块 +提供统一的日志配置和记录功能 +""" +import logging +import sys +from pathlib import Path +from typing import Optional + + +def setup_logger( + name: str = "stock_agent", + level: int = logging.INFO, + log_file: Optional[str] = None +) -> logging.Logger: + """ + 配置并返回logger实例 + + Args: + name: logger名称 + level: 日志级别 + log_file: 日志文件路径(可选) + + Returns: + 配置好的logger实例 + """ + logger = logging.getLogger(name) + logger.setLevel(level) + + # 避免重复添加handler + if logger.handlers: + return logger + + # 日志格式 + formatter = logging.Formatter( + '%(asctime)s - %(name)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + ) + + # 控制台handler + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setLevel(level) + console_handler.setFormatter(formatter) + logger.addHandler(console_handler) + + # 文件handler(如果指定) + if log_file: + log_path = Path(log_file) + log_path.parent.mkdir(parents=True, exist_ok=True) + + file_handler = logging.FileHandler(log_file, encoding='utf-8') + file_handler.setLevel(level) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + return logger + + +# 创建默认logger +logger = setup_logger() diff --git a/backend/app/utils/stock_names.py b/backend/app/utils/stock_names.py new file mode 100644 index 0000000..81a2963 --- /dev/null +++ b/backend/app/utils/stock_names.py @@ -0,0 +1,254 @@ +""" +股票名称映射数据库 +包含常见A股股票的名称到代码的映射 +""" +from typing import Optional + +# 常见A股股票名称映射(按行业分类) +STOCK_NAME_MAP = { + # 白酒 + "贵州茅台": "600519", + "茅台": "600519", + "五粮液": "000858", + "泸州老窖": "000568", + "山西汾酒": "600809", + "洋河股份": "002304", + + # 银行 + "工商银行": "601398", + "工行": "601398", + "建设银行": "601939", + "建行": "601939", + "农业银行": "601288", + "农行": "601288", + "中国银行": "601988", + "中行": "601988", + "交通银行": "601328", + "交行": "601328", + "招商银行": "600036", + "招行": "600036", + "兴业银行": "601166", + "浦发银行": "600000", + "民生银行": "600016", + "光大银行": "601818", + "平安银行": "000001", + "宁波银行": "002142", + + # 保险 + "中国平安": "601318", + "平安": "601318", + "中国人寿": "601628", + "中国太保": "601601", + "新华保险": "601336", + + # 证券 + "中信证券": "600030", + "中信": "600030", + "海通证券": "600837", + "国泰君安": "601211", + "华泰证券": "601688", + "广发证券": "000776", + "招商证券": "600999", + "东方证券": "600958", + + # 科技 + "中兴通讯": "000063", + "中兴": "000063", + "立讯精密": "002475", + "京东方A": "000725", + "京东方": "000725", + "TCL科技": "000100", + "海康威视": "002415", + "大华股份": "002236", + "科大讯飞": "002230", + "讯飞": "002230", + "紫光国微": "002049", + "中芯国际": "688981", + "韦尔股份": "603501", + + # 新能源汽车 + "比亚迪": "002594", + "宁德时代": "300750", + "宁德": "300750", + "长城汽车": "601633", + "长城": "601633", + "上汽集团": "600104", + "上汽": "600104", + "广汽集团": "601238", + "广汽": "601238", + "吉利汽车": "00175", # 港股 + "理想汽车": "02015", # 港股 + "小鹏汽车": "09868", # 港股 + "蔚来": "09866", # 港股 + + # 医药 + "恒瑞医药": "600276", + "恒瑞": "600276", + "药明康德": "603259", + "迈瑞医疗": "300760", + "迈瑞": "300760", + "片仔癀": "600436", + "云南白药": "000538", + "白药": "000538", + "爱尔眼科": "300015", + "智飞生物": "300122", + + # 消费 + "伊利股份": "600887", + "伊利": "600887", + "海天味业": "603288", + "海天": "603288", + "格力电器": "000651", + "格力": "000651", + "美的集团": "000333", + "美的": "000333", + "海尔智家": "600690", + "海尔": "600690", + "老板电器": "002508", + + # 地产 + "万科A": "000002", + "万科": "000002", + "保利发展": "600048", + "保利": "600048", + "招商蛇口": "001979", + "金地集团": "600383", + "金地": "600383", + + # 能源 + "中国石油": "601857", + "中石油": "601857", + "中国石化": "600028", + "中石化": "600028", + "中国神华": "601088", + "神华": "601088", + "陕西煤业": "601225", + "长江电力": "600900", + "三峡能源": "600905", + + # 通信 + "中国移动": "600941", + "移动": "600941", + "中国电信": "601728", + "电信": "601728", + "中国联通": "600050", + "联通": "600050", + "中国卫通": "601698", + "卫通": "601698", + + # 航空航天 + "中国国航": "601111", + "国航": "601111", + "南方航空": "600029", + "南航": "600029", + "东方航空": "600115", + "东航": "600115", + "中国卫星": "600118", + "航天科技": "000901", + + # 钢铁 + "宝钢股份": "600019", + "宝钢": "600019", + "河钢股份": "000709", + "河钢": "000709", + "鞍钢股份": "000898", + "鞍钢": "000898", + + # 有色金属 + "紫金矿业": "601899", + "紫金": "601899", + "中国铝业": "601600", + "中铝": "601600", + "江西铜业": "600362", + "江铜": "600362", + "洛阳钼业": "603993", + + # 化工 + "万华化学": "600309", + "万华": "600309", + "华鲁恒升": "600426", + "恒力石化": "600346", + "荣盛石化": "002493", + + # 电力设备 + "隆基绿能": "601012", + "隆基": "601012", + "阳光电源": "300274", + "通威股份": "600438", + "通威": "600438", + "特变电工": "600089", + + # 军工 + "中航沈飞": "600760", + "沈飞": "600760", + "中航西飞": "000768", + "西飞": "000768", + "中国船舶": "600150", + "中船": "600150", + "航发动力": "600893", + "航天发展": "000547", + + # 互联网 + "腾讯控股": "00700", # 港股 + "腾讯": "00700", + "阿里巴巴": "09988", # 港股 + "阿里": "09988", + "美团": "03690", # 港股 + "京东": "09618", # 港股 + "拼多多": "PDD", # 美股 + "百度": "09888", # 港股 + "网易": "09999", # 港股 + "小米集团": "01810", # 港股 + "小米": "01810", + + # 指数 + "上证指数": "000001", + "上证": "000001", + "沪指": "000001", + "深证成指": "399001", + "深成指": "399001", + "创业板指": "399006", + "创业板": "399006", + "科创50": "000688", + "沪深300": "000300", + "中证500": "000905", + "中证1000": "000852", +} + + +def search_stock_by_name(name: str) -> Optional[str]: + """ + 根据股票名称搜索代码 + + Args: + name: 股票名称或简称 + + Returns: + 股票代码或None + """ + # 精确匹配 + if name in STOCK_NAME_MAP: + return STOCK_NAME_MAP[name] + + # 模糊匹配(包含关系) + for stock_name, code in STOCK_NAME_MAP.items(): + if name in stock_name or stock_name in name: + return code + + return None + + +def get_stock_name(code: str) -> Optional[str]: + """ + 根据代码获取股票名称 + + Args: + code: 股票代码 + + Returns: + 股票名称或None + """ + for name, stock_code in STOCK_NAME_MAP.items(): + if stock_code == code: + return name + return None diff --git a/backend/app/utils/validators.py b/backend/app/utils/validators.py new file mode 100644 index 0000000..ce22d1b --- /dev/null +++ b/backend/app/utils/validators.py @@ -0,0 +1,103 @@ +""" +验证工具模块 +提供各种数据验证功能 +""" +import re +from typing import Optional + + +def validate_stock_code(code: str) -> bool: + """ + 验证股票代码格式 + + A股代码格式: + - 上海:6开头,6位数字 + - 深圳:0/3开头,6位数字 + - 创业板:3开头,6位数字 + - 科创板:688开头,6位数字 + + Args: + code: 股票代码 + + Returns: + 是否有效 + """ + if not code: + return False + + # 移除可能的后缀(如.SH, .SZ) + code = code.split('.')[0] + + # 检查是否为6位数字 + if not re.match(r'^\d{6}$', code): + return False + + # 检查首位数字 + first_digit = code[0] + if first_digit in ['0', '3', '6']: + return True + + # 检查科创板 + if code.startswith('688'): + return True + + return False + + +def normalize_stock_code(code: str) -> Optional[str]: + """ + 标准化股票代码,添加市场后缀 + + Args: + code: 股票代码 + + Returns: + 标准化后的代码(如600000.SH)或None + """ + if not validate_stock_code(code): + return None + + # 移除已有后缀 + code = code.split('.')[0] + + # 添加市场后缀 + if code.startswith('6'): + return f"{code}.SH" # 上海 + elif code.startswith(('0', '3')): + return f"{code}.SZ" # 深圳 + elif code.startswith('688'): + return f"{code}.SH" # 科创板 + + return None + + +def validate_date_format(date_str: str) -> bool: + """ + 验证日期格式(YYYYMMDD) + + Args: + date_str: 日期字符串 + + Returns: + 是否有效 + """ + if not date_str: + return False + + # 检查格式 + if not re.match(r'^\d{8}$', date_str): + return False + + # 简单的日期范围检查 + year = int(date_str[:4]) + month = int(date_str[4:6]) + day = int(date_str[6:8]) + + if year < 1990 or year > 2100: + return False + if month < 1 or month > 12: + return False + if day < 1 or day > 31: + return False + + return True diff --git a/backend/diagnose.sh b/backend/diagnose.sh new file mode 100755 index 0000000..b5418eb --- /dev/null +++ b/backend/diagnose.sh @@ -0,0 +1,100 @@ +#!/bin/bash + +# 诊断脚本 - 检查系统配置 + +echo "================================" +echo "系统诊断" +echo "================================" +echo "" + +cd /Users/aaron/source_code/Stock_Agent/backend + +# 1. 检查虚拟环境 +echo "1. 检查虚拟环境..." +if [ -d "venv" ]; then + echo " ✓ 虚拟环境存在" + source venv/bin/activate + python_version=$(python --version 2>&1) + echo " ✓ $python_version" +else + echo " ❌ 虚拟环境不存在" + exit 1 +fi + +# 2. 检查.env文件 +echo "" +echo "2. 检查配置文件..." +if [ -f "../.env" ]; then + echo " ✓ .env文件存在(项目根目录)" +elif [ -f ".env" ]; then + echo " ✓ .env文件存在(backend目录)" +else + echo " ❌ .env文件不存在" + exit 1 +fi + +# 3. 检查依赖包 +echo "" +echo "3. 检查依赖包..." +packages=("fastapi" "uvicorn" "tushare" "pandas" "numpy" "sqlalchemy" "pydantic") +all_installed=true + +for pkg in "${packages[@]}"; do + if python -c "import $pkg" 2>/dev/null; then + version=$(python -c "import $pkg; print($pkg.__version__)" 2>/dev/null || echo "unknown") + echo " ✓ $pkg ($version)" + else + echo " ❌ $pkg 未安装" + all_installed=false + fi +done + +if [ "$all_installed" = false ]; then + echo "" + echo "请运行: pip install -r requirements.txt" + exit 1 +fi + +# 4. 测试配置加载 +echo "" +echo "4. 测试配置加载..." +python -c " +from app.config import get_settings +settings = get_settings() +print(f' ✓ 配置加载成功') +print(f' - Tushare Token: {'已配置 (' + settings.tushare_token[:10] + '...)' if settings.tushare_token else '❌ 未配置'}') +print(f' - 智谱AI Key: {'已配置 (' + settings.zhipuai_api_key[:10] + '...)' if settings.zhipuai_api_key else '❌ 未配置'}') +" 2>&1 + +# 5. 测试模块导入 +echo "" +echo "5. 测试模块导入..." +modules=("app.models.database" "app.services.cache_service" "app.services.tushare_service" "app.agent.core") + +for module in "${modules[@]}"; do + if python -c "import $module" 2>/dev/null; then + echo " ✓ $module" + else + echo " ❌ $module 导入失败" + python -c "import $module" 2>&1 | head -5 + fi +done + +# 6. 检查端口占用 +echo "" +echo "6. 检查端口占用..." +if lsof -i :8000 >/dev/null 2>&1; then + echo " ⚠ 端口8000已被占用" + lsof -i :8000 +else + echo " ✓ 端口8000可用" +fi + +echo "" +echo "================================" +echo "诊断完成" +echo "================================" +echo "" +echo "如果所有检查都通过,可以运行:" +echo " ./start.sh" +echo "" diff --git a/backend/requirements.txt b/backend/requirements.txt new file mode 100644 index 0000000..28e0024 --- /dev/null +++ b/backend/requirements.txt @@ -0,0 +1,16 @@ +fastapi==0.109.0 +uvicorn[standard]==0.27.0 +langchain==0.1.0 +langchain-community==0.0.20 +zhipuai==2.0.1 +tushare==1.3.8 +sqlalchemy==2.0.25 +pydantic==2.5.3 +pydantic-settings==2.1.0 +python-dotenv==1.0.0 +slowapi==0.1.9 +websockets==12.0 +pandas>=2.2.0 +numpy>=1.26.0 +python-multipart==0.0.6 +aiohttp==3.9.1 diff --git a/backend/run.sh b/backend/run.sh new file mode 100755 index 0000000..7ed4eb0 --- /dev/null +++ b/backend/run.sh @@ -0,0 +1,110 @@ +#!/bin/bash + +# 最终启动检查和启动脚本 + +echo "================================" +echo "A股AI分析Agent - 最终检查" +echo "================================" +echo "" + +cd /Users/aaron/source_code/Stock_Agent/backend + +# 激活虚拟环境 +if [ ! -d "venv" ]; then + echo "❌ 虚拟环境不存在,请先运行 ../install.sh" + exit 1 +fi + +source venv/bin/activate + +# 快速导入测试 +echo "1. 测试模块导入..." +python3 << 'EOF' +try: + # 测试基础模块 + from app.config import get_settings + print(" ✓ 配置模块") + + from app.models.database import Base, Message + print(" ✓ 数据库模型") + + from app.services.cache_service import cache_service + print(" ✓ 缓存服务") + + from app.services.tushare_service import tushare_service + print(" ✓ Tushare服务") + + from app.utils.stock_names import search_stock_by_name + print(" ✓ 股票名称库") + + from app.services.llm_service import llm_service + print(" ✓ LLM服务") + + from app.agent.enhanced_agent import enhanced_agent + print(" ✓ 增强版Agent") + + print("\n所有模块导入成功!") + +except Exception as e: + print(f"\n❌ 导入失败: {e}") + import traceback + traceback.print_exc() + exit(1) +EOF + +if [ $? -ne 0 ]; then + echo "" + echo "模块导入失败,请检查错误信息" + exit 1 +fi + +# 检查配置 +echo "" +echo "2. 检查配置..." +python3 << 'EOF' +from app.config import get_settings +settings = get_settings() + +print(f" Tushare Token: {'✓ 已配置' if settings.tushare_token else '❌ 未配置'}") +print(f" 智谱AI Key: {'✓ 已配置' if settings.zhipuai_api_key else '❌ 未配置'}") +print(f" 数据库: {settings.database_url}") +print(f" 监听: {settings.api_host}:{settings.api_port}") + +if not settings.tushare_token: + print("\n⚠️ 警告: Tushare Token未配置,数据查询功能将不可用") +if not settings.zhipuai_api_key: + print("⚠️ 警告: 智谱AI Key未配置,将使用规则模式(无AI分析)") +EOF + +echo "" +echo "3. 测试股票名称识别..." +python3 << 'EOF' +from app.utils.stock_names import search_stock_by_name + +test_cases = [ + ("中国卫通", "601698"), + ("贵州茅台", "600519"), + ("比亚迪", "002594"), + ("宁德时代", "300750") +] + +for name, expected in test_cases: + result = search_stock_by_name(name) + status = "✓" if result == expected else "❌" + print(f" {status} {name} -> {result}") +EOF + +echo "" +echo "================================" +echo "检查完成!准备启动..." +echo "================================" +echo "" +echo "访问地址:" +echo " 前端: http://localhost:8000" +echo " API: http://localhost:8000/docs" +echo "" +echo "按 Ctrl+C 停止服务" +echo "" + +# 启动应用 +python3 -m app.main diff --git a/backend/start.sh b/backend/start.sh new file mode 100755 index 0000000..71ba758 --- /dev/null +++ b/backend/start.sh @@ -0,0 +1,64 @@ +#!/bin/bash + +# A股AI分析Agent系统 - 启动脚本(改进版) + +echo "================================" +echo "A股AI分析Agent系统" +echo "================================" +echo "" + +# 检查.env文件 +if [ ! -f "../.env" ] && [ ! -f ".env" ]; then + echo "❌ 错误: 未找到.env配置文件" + echo "" + echo "请先配置环境变量:" + echo " cd .." + echo " cp .env.example .env" + echo " # 编辑.env文件,填写API密钥" + exit 1 +fi + +# 检查虚拟环境 +if [ ! -d "venv" ]; then + echo "❌ 错误: 虚拟环境不存在" + echo "" + echo "请先运行安装脚本:" + echo " cd .." + echo " ./install.sh" + exit 1 +fi + +# 激活虚拟环境 +echo "激活虚拟环境..." +source venv/bin/activate + +# 检查Python版本 +python_version=$(python --version 2>&1 | awk '{print $2}') +echo "Python版本: $python_version" + +# 显示配置信息 +echo "" +echo "配置信息:" +python -c " +from app.config import get_settings +settings = get_settings() +print(f' Tushare Token: {'已配置' if settings.tushare_token else '未配置'}') +print(f' 智谱AI Key: {'已配置' if settings.zhipuai_api_key else '未配置'}') +print(f' 数据库: {settings.database_url}') +print(f' 监听地址: {settings.api_host}:{settings.api_port}') +" + +echo "" +echo "================================" +echo "启动服务..." +echo "================================" +echo "" +echo "访问地址:" +echo " 前端界面: http://localhost:8000" +echo " API文档: http://localhost:8000/docs" +echo "" +echo "按 Ctrl+C 停止服务" +echo "" + +# 启动应用 +python -m app.main diff --git a/backend/test_import.sh b/backend/test_import.sh new file mode 100755 index 0000000..f3cbaf5 --- /dev/null +++ b/backend/test_import.sh @@ -0,0 +1,26 @@ +#!/bin/bash +# 测试应用启动 + +cd /Users/aaron/source_code/Stock_Agent/backend + +# 激活虚拟环境 +source venv/bin/activate + +# 测试导入 +echo "测试数据库模型..." +python3 -c "from app.models.database import Base, Message; print('✓ 数据库模型导入成功')" + +echo "" +echo "测试配置..." +python3 -c "from app.config import get_settings; print('✓ 配置加载成功')" + +echo "" +echo "测试服务..." +python3 -c "from app.services.cache_service import cache_service; print('✓ 缓存服务初始化成功')" + +echo "" +echo "测试Agent..." +python3 -c "from app.agent.core import stock_agent; print('✓ Agent初始化成功')" + +echo "" +echo "所有测试通过!可以启动应用了。" diff --git a/backend/tests/__init__.py b/backend/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docs/DEPLOYMENT.md b/docs/DEPLOYMENT.md new file mode 100644 index 0000000..374e41c --- /dev/null +++ b/docs/DEPLOYMENT.md @@ -0,0 +1,491 @@ +# 部署文档 + +本文档介绍如何部署A股AI分析Agent系统到生产环境。 + +## 部署方式 + +### 方式一:本地部署 + +#### 1. 系统要求 + +- 操作系统:Linux/macOS/Windows +- Python 3.9+ +- Redis 6.0+(可选) +- 内存:至少2GB +- 磁盘:至少1GB + +#### 2. 安装步骤 + +```bash +# 1. 克隆代码 +git clone +cd Stock_Agent + +# 2. 创建虚拟环境 +python -m venv venv +source venv/bin/activate # Windows: venv\Scripts\activate + +# 3. 安装依赖 +cd backend +pip install -r requirements.txt + +# 4. 配置环境变量 +cp ../.env.example ../.env +# 编辑.env文件,填写必要的配置 + +# 5. 启动Redis(可选) +redis-server + +# 6. 启动应用 +python -m app.main +``` + +#### 3. 使用进程管理器 + +**使用Supervisor(推荐)** + +创建配置文件 `/etc/supervisor/conf.d/stock_agent.conf`: + +```ini +[program:stock_agent] +directory=/path/to/Stock_Agent/backend +command=/path/to/venv/bin/python -m app.main +user=your_user +autostart=true +autorestart=true +redirect_stderr=true +stdout_logfile=/var/log/stock_agent.log +``` + +启动服务: + +```bash +sudo supervisorctl reread +sudo supervisorctl update +sudo supervisorctl start stock_agent +``` + +**使用systemd** + +创建服务文件 `/etc/systemd/system/stock_agent.service`: + +```ini +[Unit] +Description=Stock Agent Service +After=network.target + +[Service] +Type=simple +User=your_user +WorkingDirectory=/path/to/Stock_Agent/backend +Environment="PATH=/path/to/venv/bin" +ExecStart=/path/to/venv/bin/python -m app.main +Restart=always + +[Install] +WantedBy=multi-user.target +``` + +启动服务: + +```bash +sudo systemctl daemon-reload +sudo systemctl enable stock_agent +sudo systemctl start stock_agent +``` + +### 方式二:Docker部署 + +#### 1. 创建Dockerfile + +在项目根目录创建 `Dockerfile`: + +```dockerfile +FROM python:3.9-slim + +WORKDIR /app + +# 安装依赖 +COPY backend/requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +# 复制代码 +COPY backend/ ./backend/ +COPY frontend/ ./frontend/ +COPY .env .env + +# 暴露端口 +EXPOSE 8000 + +# 启动应用 +CMD ["python", "-m", "backend.app.main"] +``` + +#### 2. 创建docker-compose.yml + +```yaml +version: '3.8' + +services: + redis: + image: redis:6-alpine + ports: + - "6379:6379" + volumes: + - redis_data:/data + + stock_agent: + build: . + ports: + - "8000:8000" + environment: + - REDIS_HOST=redis + - REDIS_PORT=6379 + env_file: + - .env + depends_on: + - redis + volumes: + - ./backend:/app/backend + - ./frontend:/app/frontend + - ./stock_agent.db:/app/stock_agent.db + +volumes: + redis_data: +``` + +#### 3. 启动服务 + +```bash +# 构建镜像 +docker-compose build + +# 启动服务 +docker-compose up -d + +# 查看日志 +docker-compose logs -f + +# 停止服务 +docker-compose down +``` + +### 方式三:云服务器部署 + +#### 阿里云/腾讯云部署 + +1. **购买云服务器** + - 配置:2核4GB内存 + - 系统:Ubuntu 20.04 LTS + +2. **安全组配置** + - 开放8000端口(HTTP) + - 开放22端口(SSH) + +3. **安装环境** + +```bash +# 更新系统 +sudo apt update && sudo apt upgrade -y + +# 安装Python +sudo apt install python3.9 python3.9-venv python3-pip -y + +# 安装Redis +sudo apt install redis-server -y +sudo systemctl enable redis-server +sudo systemctl start redis-server + +# 安装Nginx(可选,用于反向代理) +sudo apt install nginx -y +``` + +4. **部署应用** + +按照"本地部署"步骤进行。 + +5. **配置Nginx反向代理** + +创建配置文件 `/etc/nginx/sites-available/stock_agent`: + +```nginx +server { + listen 80; + server_name your_domain.com; + + location / { + proxy_pass http://127.0.0.1:8000; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + proxy_set_header X-Forwarded-Proto $scheme; + } + + location /static { + alias /path/to/Stock_Agent/frontend; + } +} +``` + +启用配置: + +```bash +sudo ln -s /etc/nginx/sites-available/stock_agent /etc/nginx/sites-enabled/ +sudo nginx -t +sudo systemctl reload nginx +``` + +6. **配置HTTPS(推荐)** + +使用Let's Encrypt免费证书: + +```bash +sudo apt install certbot python3-certbot-nginx -y +sudo certbot --nginx -d your_domain.com +``` + +## 生产环境配置 + +### 1. 环境变量配置 + +生产环境的 `.env` 配置: + +```env +# API密钥 +TUSHARE_TOKEN=your_production_token +ZHIPUAI_API_KEY=your_production_key + +# Redis +REDIS_HOST=localhost +REDIS_PORT=6379 +REDIS_PASSWORD=your_redis_password + +# 数据库(生产环境建议使用PostgreSQL) +DATABASE_URL=postgresql://user:password@localhost/stock_agent + +# API设置 +API_HOST=0.0.0.0 +API_PORT=8000 +DEBUG=False + +# 安全 +SECRET_KEY=your_very_long_random_secret_key_here +RATE_LIMIT=100/minute + +# CORS(根据实际域名配置) +CORS_ORIGINS=https://your_domain.com +``` + +### 2. 数据库迁移到PostgreSQL + +安装PostgreSQL: + +```bash +sudo apt install postgresql postgresql-contrib -y +``` + +创建数据库: + +```sql +sudo -u postgres psql +CREATE DATABASE stock_agent; +CREATE USER stock_user WITH PASSWORD 'your_password'; +GRANT ALL PRIVILEGES ON DATABASE stock_agent TO stock_user; +\q +``` + +更新 `.env` 中的 `DATABASE_URL`。 + +### 3. 性能优化 + +**Redis配置优化** + +编辑 `/etc/redis/redis.conf`: + +```conf +maxmemory 256mb +maxmemory-policy allkeys-lru +save 900 1 +save 300 10 +save 60 10000 +``` + +**应用配置优化** + +在 `config.py` 中调整: + +```python +# 增加工作进程数 +workers = multiprocessing.cpu_count() * 2 + 1 + +# 调整超时时间 +timeout = 120 +``` + +### 4. 监控和日志 + +**日志配置** + +在 `utils/logger.py` 中配置日志文件: + +```python +logger = setup_logger( + name="stock_agent", + level=logging.INFO, + log_file="/var/log/stock_agent/app.log" +) +``` + +**日志轮转** + +创建 `/etc/logrotate.d/stock_agent`: + +``` +/var/log/stock_agent/*.log { + daily + rotate 7 + compress + delaycompress + notifempty + create 0640 your_user your_user + sharedscripts +} +``` + +**监控工具** + +推荐使用: +- Prometheus + Grafana(系统监控) +- Sentry(错误追踪) +- ELK Stack(日志分析) + +## 备份和恢复 + +### 数据库备份 + +**SQLite备份** + +```bash +# 备份 +cp stock_agent.db stock_agent_backup_$(date +%Y%m%d).db + +# 恢复 +cp stock_agent_backup_20240101.db stock_agent.db +``` + +**PostgreSQL备份** + +```bash +# 备份 +pg_dump -U stock_user stock_agent > backup_$(date +%Y%m%d).sql + +# 恢复 +psql -U stock_user stock_agent < backup_20240101.sql +``` + +### Redis备份 + +```bash +# 备份 +redis-cli SAVE +cp /var/lib/redis/dump.rdb /backup/redis_$(date +%Y%m%d).rdb + +# 恢复 +sudo systemctl stop redis +cp /backup/redis_20240101.rdb /var/lib/redis/dump.rdb +sudo systemctl start redis +``` + +## 安全加固 + +### 1. 防火墙配置 + +```bash +# 使用ufw +sudo ufw allow 22/tcp +sudo ufw allow 80/tcp +sudo ufw allow 443/tcp +sudo ufw enable +``` + +### 2. 限制API访问 + +在Nginx中配置限流: + +```nginx +limit_req_zone $binary_remote_addr zone=api_limit:10m rate=10r/s; + +location /api/ { + limit_req zone=api_limit burst=20 nodelay; + proxy_pass http://127.0.0.1:8000; +} +``` + +### 3. 定期更新 + +```bash +# 更新系统 +sudo apt update && sudo apt upgrade -y + +# 更新Python依赖 +pip install --upgrade -r requirements.txt +``` + +## 故障排查 + +### 常见问题 + +1. **服务无法启动** + - 检查端口是否被占用:`lsof -i :8000` + - 查看日志:`tail -f /var/log/stock_agent.log` + +2. **Redis连接失败** + - 检查Redis状态:`sudo systemctl status redis` + - 测试连接:`redis-cli ping` + +3. **数据库错误** + - 检查数据库连接:`psql -U stock_user -d stock_agent` + - 查看数据库日志:`sudo tail -f /var/log/postgresql/postgresql-*.log` + +4. **API响应慢** + - 检查Redis缓存是否正常 + - 查看系统资源:`htop` + - 分析慢查询日志 + +## 性能测试 + +使用Apache Bench进行压力测试: + +```bash +# 安装ab +sudo apt install apache2-utils -y + +# 测试API +ab -n 1000 -c 50 http://localhost:8000/api/stock/quote/600519 +``` + +## 更新部署 + +### 零停机更新 + +使用蓝绿部署或滚动更新: + +```bash +# 1. 拉取最新代码 +git pull origin main + +# 2. 安装新依赖 +pip install -r requirements.txt + +# 3. 运行数据库迁移(如有) +# python manage.py migrate + +# 4. 重启服务 +sudo supervisorctl restart stock_agent +# 或 +sudo systemctl restart stock_agent +``` + +## 联系支持 + +如遇到部署问题,请提交Issue或联系技术支持。 diff --git a/docs/INSTALL_GUIDE.md b/docs/INSTALL_GUIDE.md new file mode 100644 index 0000000..4e9477b --- /dev/null +++ b/docs/INSTALL_GUIDE.md @@ -0,0 +1,262 @@ +# 安装问题解决指南 + +## 问题:Python 3.13 兼容性问题 + +如果您在安装依赖时遇到 numpy/pandas 编译错误,这是因为 Python 3.13 是最新版本,部分科学计算库还未完全适配。 + +### 错误信息示例 +``` +fatal error: 'type_traits' file not found +ERROR: Failed to build 'pandas' when installing build dependencies +``` + +## 解决方案 + +### 方案1:使用 Python 3.11 或 3.12(强烈推荐) + +这是最简单可靠的方法。 + +#### macOS (使用 Homebrew) + +```bash +# 1. 安装 Python 3.11 +brew install python@3.11 + +# 2. 进入项目目录 +cd /Users/aaron/source_code/Stock_Agent/backend + +# 3. 删除旧的虚拟环境(如果存在) +rm -rf venv + +# 4. 使用 Python 3.11 创建新的虚拟环境 +python3.11 -m venv venv + +# 5. 激活虚拟环境 +source venv/bin/activate + +# 6. 升级 pip +pip install --upgrade pip + +# 7. 安装依赖 +pip install -r requirements.txt +``` + +#### Linux (Ubuntu/Debian) + +```bash +# 1. 安装 Python 3.11 +sudo apt update +sudo apt install python3.11 python3.11-venv python3.11-dev + +# 2. 创建虚拟环境 +python3.11 -m venv venv +source venv/bin/activate + +# 3. 安装依赖 +pip install --upgrade pip +pip install -r requirements.txt +``` + +#### Windows + +```powershell +# 1. 从 python.org 下载并安装 Python 3.11 +# https://www.python.org/downloads/ + +# 2. 创建虚拟环境 +py -3.11 -m venv venv + +# 3. 激活虚拟环境 +venv\Scripts\activate + +# 4. 安装依赖 +pip install --upgrade pip +pip install -r requirements.txt +``` + +### 方案2:使用预编译的 wheel 包(Python 3.13) + +如果必须使用 Python 3.13,可以尝试安装预编译的包: + +```bash +# 激活虚拟环境 +source venv/bin/activate # macOS/Linux +# 或 +venv\Scripts\activate # Windows + +# 先单独安装 numpy 和 pandas +pip install --upgrade pip +pip install numpy --only-binary :all: +pip install pandas --only-binary :all: + +# 然后安装其他依赖 +pip install -r requirements.txt +``` + +### 方案3:使用 Conda(推荐用于数据科学项目) + +Conda 提供预编译的包,避免编译问题: + +```bash +# 1. 安装 Miniconda 或 Anaconda +# https://docs.conda.io/en/latest/miniconda.html + +# 2. 创建环境 +conda create -n stock_agent python=3.11 + +# 3. 激活环境 +conda activate stock_agent + +# 4. 安装依赖 +pip install -r requirements.txt +``` + +## 验证安装 + +安装完成后,验证是否成功: + +```bash +# 检查 Python 版本 +python --version +# 应该显示 Python 3.11.x 或 3.12.x + +# 检查关键包 +python -c "import numpy; print('numpy:', numpy.__version__)" +python -c "import pandas; print('pandas:', pandas.__version__)" +python -c "import fastapi; print('fastapi:', fastapi.__version__)" +python -c "import tushare; print('tushare:', tushare.__version__)" +``` + +## 启动应用 + +安装成功后,按以下步骤启动: + +```bash +# 1. 确保在虚拟环境中 +source venv/bin/activate # macOS/Linux +# 或 +venv\Scripts\activate # Windows + +# 2. 配置环境变量 +cd /Users/aaron/source_code/Stock_Agent +cp .env.example .env +# 编辑 .env 文件,填写 API 密钥 + +# 3. 启动应用 +cd backend +python -m app.main +``` + +## 常见问题 + +### Q1: 如何检查当前 Python 版本? + +```bash +python --version +python3 --version +python3.11 --version +``` + +### Q2: 如何切换 Python 版本? + +macOS/Linux: +```bash +# 使用特定版本创建虚拟环境 +python3.11 -m venv venv +``` + +Windows: +```powershell +py -3.11 -m venv venv +``` + +### Q3: 虚拟环境激活失败? + +确保在正确的目录: +```bash +cd /Users/aaron/source_code/Stock_Agent/backend +ls venv # 应该能看到 bin 或 Scripts 目录 +``` + +### Q4: pip 安装很慢? + +使用国内镜像源: +```bash +pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple +``` + +### Q5: 权限错误? + +不要使用 sudo,确保在虚拟环境中: +```bash +which python # 应该显示虚拟环境路径 +``` + +## 推荐的完整安装流程 + +```bash +# 1. 安装 Python 3.11 +brew install python@3.11 # macOS + +# 2. 进入项目目录 +cd /Users/aaron/source_code/Stock_Agent + +# 3. 创建虚拟环境 +python3.11 -m venv backend/venv + +# 4. 激活虚拟环境 +source backend/venv/bin/activate + +# 5. 升级 pip +pip install --upgrade pip setuptools wheel + +# 6. 安装依赖 +cd backend +pip install -r requirements.txt + +# 7. 配置环境变量 +cd .. +cp .env.example .env +nano .env # 或使用其他编辑器 + +# 8. 启动应用 +cd backend +python -m app.main +``` + +## 获取帮助 + +如果仍然遇到问题: + +1. 检查 Python 版本:`python --version` +2. 检查虚拟环境:`which python` +3. 查看完整错误信息 +4. 提交 Issue 到项目仓库 + +## 最小依赖版本 + +如果遇到版本冲突,可以尝试最小版本: + +```txt +fastapi>=0.100.0 +uvicorn>=0.23.0 +langchain>=0.1.0 +tushare>=1.3.0 +sqlalchemy>=2.0.0 +pydantic>=2.0.0 +pandas>=2.0.0 +numpy>=1.24.0 +``` + +## 成功标志 + +当您看到以下输出时,说明安装成功: + +``` +INFO: Started server process [xxxxx] +INFO: Waiting for application startup. +INFO: Application startup complete. +INFO: Uvicorn running on http://0.0.0.0:8000 +``` + +然后访问 http://localhost:8000 即可使用系统! diff --git a/docs/USER_GUIDE.md b/docs/USER_GUIDE.md new file mode 100644 index 0000000..16d84db --- /dev/null +++ b/docs/USER_GUIDE.md @@ -0,0 +1,343 @@ +# 用户使用手册 + +欢迎使用A股AI分析Agent系统!本手册将帮助您快速上手。 + +## 目录 + +1. [系统介绍](#系统介绍) +2. [快速开始](#快速开始) +3. [功能使用](#功能使用) +4. [常见问题](#常见问题) +5. [技巧和建议](#技巧和建议) + +## 系统介绍 + +A股AI分析Agent系统是一个智能股票分析助手,通过自然语言对话方式,帮助您: + +- 查询股票实时行情 +- 分析技术指标 +- 查看K线图表 +- 了解公司基本信息 + +### 核心特性 + +- **自然对话**:像聊天一样查询股票信息 +- **智能理解**:自动识别股票代码和查询意图 +- **实时数据**:获取最新的市场数据 +- **专业图表**:生成专业的K线图和技术指标图 +- **历史记录**:保存对话历史,方便回顾 + +## 快速开始 + +### 1. 访问系统 + +打开浏览器,访问:http://localhost:8000 + +### 2. 界面介绍 + +系统界面分为三个主要区域: + +``` +┌─────────────────────────────────────────────┐ +│ 📈 A股AI分析Agent [技能管理] │ +├─────────────────────────────────────────────┤ +│ │ +│ 消息显示区域 │ +│ - 显示对话历史 │ +│ - 展示图表和数据 │ +│ │ +├─────────────────────────────────────────────┤ +│ [输入框] [发送] │ +└─────────────────────────────────────────────┘ +``` + +### 3. 第一次查询 + +在输入框中输入: + +``` +查询600519的实时行情 +``` + +按回车或点击"发送"按钮,系统会返回贵州茅台的实时行情数据。 + +## 功能使用 + +### 1. 查询实时行情 + +**支持的查询方式**: + +``` +查询600519的实时行情 +贵州茅台的价格 +000001现在多少钱 +中国平安的行情 +``` + +**返回信息**: +- 股票名称和代码 +- 最新价格 +- 涨跌额和涨跌幅 +- 开盘价、最高价、最低价 +- 成交量和成交额 + +**示例**: + +``` +输入:查询600519的实时行情 + +输出: +【贵州茅台】(600519.SH) +交易日期:20240201 +最新价:1650.00 +涨跌额:15.50 +涨跌幅:0.95% +开盘价:1640.00 +最高价:1655.00 +最低价:1638.00 +成交量:125000手 +成交额:206250千元 +``` + +### 2. 查看K线图 + +**支持的查询方式**: + +``` +600519的K线图 +贵州茅台的走势 +000001的图表 +``` + +**功能特点**: +- 显示最近60个交易日的K线 +- 包含成交量柱状图 +- 支持缩放和拖动 +- 红色表示上涨,绿色表示下跌 + +**操作技巧**: +- 鼠标滚轮:缩放图表 +- 鼠标拖动:移动时间轴 +- 双击:重置视图 + +### 3. 技术指标分析 + +**支持的查询方式**: + +``` +600519的技术指标 +分析贵州茅台的MACD +000001的RSI +``` + +**支持的指标**: + +1. **均线(MA)** + - MA5:5日均线 + - MA10:10日均线 + - MA20:20日均线 + - MA60:60日均线 + +2. **MACD** + - DIF:快线 + - DEA:慢线 + - MACD:柱状图 + +3. **RSI(相对强弱指标)** + - RSI6:6日RSI + - RSI12:12日RSI + - RSI24:24日RSI + +4. **KDJ** + - K值 + - D值 + - J值 + +5. **布林带(BOLL)** + - 上轨 + - 中轨 + - 下轨 + +**示例**: + +``` +输入:600519的技术指标 + +输出: +【600519.SH】技术指标: +均线:MA5=1645.20, MA10=1638.50, MA20=1625.30 +MACD:DIF=12.50, DEA=10.20, MACD=4.60 +RSI:RSI6=65.20, RSI12=58.30, RSI24=52.10 +``` + +### 4. 基本面信息 + +**支持的查询方式**: + +``` +600519的基本信息 +贵州茅台是什么行业 +000001的公司信息 +``` + +**返回信息**: +- 股票代码和名称 +- 所属地域 +- 所属行业 +- 上市市场 +- 上市日期 + +**示例**: + +``` +输入:600519的基本信息 + +输出: +【贵州茅台】基本信息 +股票代码:600519.SH +所属地域:贵州 +所属行业:白酒 +上市市场:主板 +上市日期:20010827 +``` + +### 5. 技能管理 + +点击右上角"技能管理"按钮,可以: + +- **查看所有技能**:显示系统支持的所有分析技能 +- **启用/禁用技能**:通过开关控制技能的启用状态 +- **查看技能说明**:了解每个技能的功能 + +**可用技能**: + +1. **market_data**:行情查询技能 +2. **technical_analysis**:技术分析技能 +3. **fundamental**:基本面分析技能 +4. **visualization**:数据可视化技能 + +## 常见问题 + +### Q1: 如何输入股票代码? + +**A**: 支持多种格式: + +- 6位数字:`600519`、`000001` +- 带后缀:`600519.SH`、`000001.SZ` +- 股票名称:`贵州茅台`、`中国平安` + +### Q2: 为什么查询失败? + +**可能原因**: + +1. **股票代码错误** + - 检查代码是否正确 + - 确认是A股代码 + +2. **数据源问题** + - Tushare API可能暂时不可用 + - 检查网络连接 + +3. **非交易时间** + - 某些数据仅在交易时间更新 + +### Q3: 数据更新频率? + +- **实时行情**:缓存30秒 +- **K线数据**:缓存1小时 +- **基本面信息**:缓存1天 + +### Q4: 如何清除对话历史? + +目前系统会自动保存对话历史。如需清除,可以: + +1. 刷新页面(会生成新的会话ID) +2. 联系管理员清理数据库 + +### Q5: 支持哪些股票市场? + +当前版本支持: +- 上海证券交易所(沪市) +- 深圳证券交易所(深市) +- 科创板 + +未来将支持: +- 港股 +- 美股 + +## 技巧和建议 + +### 查询技巧 + +1. **明确查询意图** + ``` + 好:查询600519的实时行情 + 差:600519 + ``` + +2. **使用完整股票代码** + ``` + 好:600519 + 差:6005(不完整) + ``` + +3. **一次查询一个股票** + ``` + 好:查询600519的行情 + 差:查询600519和000001的行情(暂不支持) + ``` + +### 分析建议 + +1. **结合多个指标** + - 先查看K线图,了解整体趋势 + - 再查看技术指标,确认信号 + - 最后查看基本面,评估价值 + +2. **关注关键指标** + - 均线:判断趋势方向 + - MACD:捕捉买卖信号 + - RSI:识别超买超卖 + - 成交量:确认趋势强度 + +3. **定期跟踪** + - 建立自选股列表 + - 定期查询关注的股票 + - 记录重要的分析结果 + +### 使用限制 + +1. **API调用限制** + - Tushare免费版:120次/分钟 + - 建议合理安排查询频率 + +2. **数据准确性** + - 数据来源于Tushare + - 仅供参考,不构成投资建议 + +3. **系统性能** + - 首次查询可能较慢(需要获取数据) + - 后续查询会使用缓存,速度更快 + +## 投资风险提示 + +⚠️ **重要提示**: + +1. 本系统提供的数据和分析仅供参考 +2. 不构成任何投资建议 +3. 股市有风险,投资需谨慎 +4. 请根据自身情况做出投资决策 +5. 建议咨询专业投资顾问 + +## 反馈和支持 + +如有问题或建议,请: + +1. 查看[README.md](../README.md) +2. 查看[部署文档](DEPLOYMENT.md) +3. 提交Issue到项目仓库 +4. 联系技术支持 + +--- + +感谢使用A股AI分析Agent系统!祝您投资顺利! diff --git a/frontend/css/style.css b/frontend/css/style.css new file mode 100644 index 0000000..264a81c --- /dev/null +++ b/frontend/css/style.css @@ -0,0 +1,551 @@ +/* Tesla-inspired Cyberpunk Minimal Design */ + +:root { + --bg-primary: #000000; + --bg-secondary: #0a0a0a; + --bg-tertiary: #141414; + --text-primary: #ffffff; + --text-secondary: #a0a0a0; + --text-tertiary: #666666; + --accent: #00ff41; + --accent-dim: #00ff4120; + --border: #1a1a1a; + --border-bright: #333333; +} + +* { + margin: 0; + padding: 0; + box-sizing: border-box; +} + +html, body { + height: 100%; + font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif; + background: var(--bg-primary); + color: var(--text-primary); + font-size: 15px; + line-height: 1.6; + -webkit-font-smoothing: antialiased; + -moz-osx-font-smoothing: grayscale; +} + +#app { + height: 100vh; + display: flex; + align-items: center; + justify-content: center; + padding: 20px; +} + +/* Container */ +.container { + width: 100%; + max-width: 900px; + height: 100%; + max-height: 900px; + display: flex; + flex-direction: column; + background: var(--bg-secondary); + border: 1px solid var(--border); + border-radius: 0; + overflow: hidden; +} + +/* Header */ +.header { + display: flex; + align-items: center; + justify-content: space-between; + padding: 20px 24px; + border-bottom: 1px solid var(--border); + background: var(--bg-primary); +} + +.logo { + display: flex; + align-items: center; + gap: 12px; + font-size: 15px; + font-weight: 500; + letter-spacing: 0.5px; + color: var(--text-primary); +} + +.logo svg { + color: var(--accent); +} + +.status { + display: flex; + align-items: center; + gap: 8px; + font-size: 13px; + color: var(--text-secondary); + text-transform: uppercase; + letter-spacing: 1px; +} + +.status-dot { + width: 6px; + height: 6px; + border-radius: 50%; + background: var(--accent); + box-shadow: 0 0 8px var(--accent); + animation: pulse 2s ease-in-out infinite; +} + +@keyframes pulse { + 0%, 100% { opacity: 1; } + 50% { opacity: 0.4; } +} + +/* Chat Container */ +.chat-container { + flex: 1; + overflow-y: auto; + overflow-x: hidden; +} + +/* Welcome Screen */ +.welcome { + height: 100%; + display: flex; + flex-direction: column; + align-items: center; + justify-content: center; + text-align: center; + padding: 40px; +} + +.welcome-icon { + margin-bottom: 32px; + opacity: 0.3; +} + +.welcome-icon svg { + stroke: var(--accent); +} + +.welcome h1 { + font-size: 28px; + font-weight: 300; + letter-spacing: 1px; + margin-bottom: 12px; + color: var(--text-primary); +} + +.welcome p { + font-size: 14px; + color: var(--text-secondary); + letter-spacing: 0.5px; +} + +/* Messages */ +.messages { + padding: 32px 24px; + display: flex; + flex-direction: column; + gap: 24px; +} + +.message { + display: flex; + animation: fadeIn 0.3s ease; +} + +@keyframes fadeIn { + from { + opacity: 0; + transform: translateY(10px); + } + to { + opacity: 1; + transform: translateY(0); + } +} + +.message.user { + justify-content: flex-end; +} + +.message.assistant { + justify-content: flex-start; +} + +.message-content { + max-width: 85%; +} + +.message.user .message-content { + background: var(--bg-tertiary); + border: 1px solid var(--border-bright); + border-radius: 2px; + padding: 16px 20px; +} + +.message.assistant .message-content { + background: transparent; + border-left: 2px solid var(--accent); + padding: 0 0 0 20px; +} + +.text { + font-size: 14px; + line-height: 1.7; + color: var(--text-primary); + word-wrap: break-word; +} + +.message.user .text { + color: var(--text-primary); +} + +/* Markdown Styles */ +.markdown h1, +.markdown h2, +.markdown h3, +.markdown h4 { + color: var(--text-primary); + font-weight: 500; + margin: 20px 0 12px 0; + letter-spacing: 0.5px; +} + +.markdown h1 { + font-size: 20px; + border-bottom: 1px solid var(--border); + padding-bottom: 8px; +} + +.markdown h2 { + font-size: 18px; + color: var(--accent); +} + +.markdown h3 { + font-size: 16px; +} + +.markdown p { + margin: 12px 0; + color: var(--text-primary); +} + +.markdown ul, +.markdown ol { + margin: 12px 0; + padding-left: 24px; +} + +.markdown li { + margin: 6px 0; + color: var(--text-primary); +} + +.markdown code { + background: var(--bg-tertiary); + border: 1px solid var(--border); + padding: 2px 6px; + border-radius: 2px; + font-family: 'Monaco', 'Courier New', monospace; + font-size: 13px; + color: var(--accent); +} + +.markdown pre { + background: var(--bg-tertiary); + border: 1px solid var(--border); + border-radius: 2px; + padding: 16px; + overflow-x: auto; + margin: 12px 0; +} + +.markdown pre code { + background: none; + border: none; + padding: 0; + color: var(--text-primary); +} + +.markdown blockquote { + border-left: 3px solid var(--accent); + padding-left: 16px; + margin: 12px 0; + color: var(--text-secondary); + font-style: italic; +} + +.markdown a { + color: var(--accent); + text-decoration: none; + border-bottom: 1px solid transparent; + transition: border-color 0.2s; +} + +.markdown a:hover { + border-bottom-color: var(--accent); +} + +.markdown strong { + color: var(--text-primary); + font-weight: 600; +} + +.markdown table { + width: 100%; + border-collapse: collapse; + margin: 12px 0; +} + +.markdown th, +.markdown td { + border: 1px solid var(--border); + padding: 8px 12px; + text-align: left; +} + +.markdown th { + background: var(--bg-tertiary); + font-weight: 500; +} + +/* Chart */ +.chart-box { + margin-top: 16px; + border: 1px solid var(--border); + border-radius: 2px; + overflow: hidden; + background: var(--bg-primary); +} + +.chart { + width: 100%; + height: 400px; +} + +/* Typing Indicator */ +.typing { + display: flex; + gap: 6px; + padding: 8px 0; +} + +.typing span { + width: 6px; + height: 6px; + border-radius: 50%; + background: var(--accent); + animation: bounce 1.4s infinite ease-in-out; +} + +.typing span:nth-child(1) { + animation-delay: 0s; +} + +.typing span:nth-child(2) { + animation-delay: 0.2s; +} + +.typing span:nth-child(3) { + animation-delay: 0.4s; +} + +@keyframes bounce { + 0%, 60%, 100% { + transform: translateY(0); + opacity: 0.4; + } + 30% { + transform: translateY(-8px); + opacity: 1; + } +} + +/* Input Container */ +.input-container { + padding: 20px 24px; + border-top: 1px solid var(--border); + background: var(--bg-primary); +} + +.input-wrapper { + display: flex; + align-items: center; + gap: 12px; + background: var(--bg-secondary); + border: 1px solid var(--border); + border-radius: 2px; + padding: 12px; + transition: border-color 0.2s; +} + +.input-wrapper:focus-within { + border-color: var(--accent); + box-shadow: 0 0 0 1px var(--accent); +} + +/* Author Info */ +.author-info { + display: flex; + align-items: center; + justify-content: center; + gap: 8px; + margin-top: 12px; + font-size: 12px; + color: var(--text-tertiary); + letter-spacing: 0.5px; +} + +.author-label { + color: var(--text-secondary); + text-transform: uppercase; +} + +.author-divider { + color: var(--border-bright); +} + +.author-contact { + color: var(--accent); + font-family: 'Monaco', 'Courier New', monospace; + cursor: pointer; + transition: all 0.2s; + padding: 4px 8px; + border-radius: 2px; +} + +.author-contact:hover { + background: var(--accent-dim); + box-shadow: 0 0 8px var(--accent-dim); +} + +.input-wrapper textarea { + flex: 1; + background: transparent; + border: none; + outline: none; + color: var(--text-primary); + font-size: 14px; + font-family: inherit; + resize: none; + max-height: 120px; + line-height: 1.5; + padding: 0; + vertical-align: middle; +} + +.input-wrapper textarea::placeholder { + color: var(--text-tertiary); +} + +.send-btn { + width: 36px; + height: 36px; + background: transparent; + border: 1px solid var(--border-bright); + border-radius: 2px; + color: var(--accent); + cursor: pointer; + display: flex; + align-items: center; + justify-content: center; + transition: all 0.2s; + flex-shrink: 0; +} + +.send-btn:hover:not(:disabled) { + background: var(--accent-dim); + border-color: var(--accent); +} + +.send-btn:disabled { + opacity: 0.3; + cursor: not-allowed; +} + +/* Spinner */ +.spinner { + width: 16px; + height: 16px; + border: 2px solid var(--border-bright); + border-top-color: var(--accent); + border-radius: 50%; + animation: spin 0.8s linear infinite; +} + +@keyframes spin { + to { transform: rotate(360deg); } +} + +/* Scrollbar */ +::-webkit-scrollbar { + width: 6px; + height: 6px; +} + +::-webkit-scrollbar-track { + background: var(--bg-primary); +} + +::-webkit-scrollbar-thumb { + background: var(--border-bright); + border-radius: 3px; +} + +::-webkit-scrollbar-thumb:hover { + background: var(--text-tertiary); +} + +/* Responsive */ +@media (max-width: 768px) { + #app { + padding: 0; + } + + .container { + max-width: 100%; + max-height: 100%; + border-radius: 0; + border: none; + } + + .header { + padding: 16px 20px; + } + + .messages { + padding: 24px 20px; + } + + .message-content { + max-width: 90%; + } + + .input-container { + padding: 16px 20px; + } +} + +/* Selection */ +::selection { + background: var(--accent-dim); + color: var(--text-primary); +} + +/* Copy Notification Animation */ +@keyframes fadeInOut { + 0% { + opacity: 0; + transform: translateX(-50%) translateY(10px); + } + 10%, 90% { + opacity: 1; + transform: translateX(-50%) translateY(0); + } + 100% { + opacity: 0; + transform: translateX(-50%) translateY(-10px); + } +} diff --git a/frontend/css/style.css.backup b/frontend/css/style.css.backup new file mode 100644 index 0000000..00da11d --- /dev/null +++ b/frontend/css/style.css.backup @@ -0,0 +1,196 @@ +/* 全局样式 */ +* { + margin: 0; + padding: 0; + box-sizing: border-box; +} + +html, body { + height: 100%; + font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif; +} + +#app { + height: 100vh; + display: flex; + flex-direction: column; +} + +.container-fluid { + flex: 1; + overflow: hidden; +} + +.row { + height: calc(100vh - 56px); +} + +/* 聊天容器 */ +.chat-container { + display: flex; + flex-direction: column; + height: 100%; + padding: 0; +} + +/* 消息列表 */ +.messages-container { + flex: 1; + overflow-y: auto; + padding: 20px; + background-color: #f8f9fa; +} + +/* 消息样式 */ +.message { + margin-bottom: 20px; + animation: fadeIn 0.3s ease-in; +} + +@keyframes fadeIn { + from { + opacity: 0; + transform: translateY(10px); + } + to { + opacity: 1; + transform: translateY(0); + } +} + +.message-content { + max-width: 80%; + padding: 12px 16px; + border-radius: 8px; + box-shadow: 0 1px 2px rgba(0, 0, 0, 0.1); +} + +.user-message .message-content { + background-color: #007bff; + color: white; + margin-left: auto; +} + +.assistant-message .message-content { + background-color: white; + color: #333; +} + +.message-header { + display: flex; + justify-content: space-between; + margin-bottom: 8px; + font-size: 0.9em; +} + +.user-message .message-header { + color: rgba(255, 255, 255, 0.9); +} + +.assistant-message .message-header { + color: #666; +} + +.message-time { + font-size: 0.85em; + opacity: 0.7; +} + +.message-body pre { + white-space: pre-wrap; + word-wrap: break-word; + font-family: inherit; + font-size: 0.95em; + line-height: 1.6; +} + +/* 图表容器 */ +.chart-container { + width: 100%; + height: 400px; + background-color: #fff; + border-radius: 4px; + border: 1px solid #dee2e6; +} + +/* 输入框容器 */ +.input-container { + padding: 20px; + background-color: white; + border-top: 1px solid #dee2e6; +} + +.input-group input { + border-radius: 20px 0 0 20px; + padding: 12px 20px; +} + +.input-group button { + border-radius: 0 20px 20px 0; + padding: 12px 30px; +} + +/* 技能面板 */ +.skill-panel { + background-color: #f8f9fa; + border-left: 1px solid #dee2e6; + padding: 20px; + overflow-y: auto; +} + +.skill-item { + padding: 12px; + background-color: white; + border-radius: 8px; + border: 1px solid #dee2e6; +} + +.skill-item:hover { + box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1); +} + +/* 滚动条样式 */ +.messages-container::-webkit-scrollbar, +.skill-panel::-webkit-scrollbar { + width: 8px; +} + +.messages-container::-webkit-scrollbar-track, +.skill-panel::-webkit-scrollbar-track { + background: #f1f1f1; +} + +.messages-container::-webkit-scrollbar-thumb, +.skill-panel::-webkit-scrollbar-thumb { + background: #888; + border-radius: 4px; +} + +.messages-container::-webkit-scrollbar-thumb:hover, +.skill-panel::-webkit-scrollbar-thumb:hover { + background: #555; +} + +/* 响应式设计 */ +@media (max-width: 768px) { + .message-content { + max-width: 90%; + } + + .skill-panel { + position: fixed; + top: 56px; + right: 0; + width: 300px; + height: calc(100vh - 56px); + z-index: 1000; + box-shadow: -2px 0 8px rgba(0, 0, 0, 0.1); + } +} + +/* 加载动画 */ +.spinner-border-sm { + width: 1rem; + height: 1rem; + border-width: 0.15em; +} diff --git a/frontend/index.html b/frontend/index.html new file mode 100644 index 0000000..ebc36a7 --- /dev/null +++ b/frontend/index.html @@ -0,0 +1,121 @@ + + + + + + 龙哥的 AI 金融智能体 + + + + + + + + + + + + + +
+ +
+ +
+ +
+
+ 在线 +
+
+ + +
+ +
+
+ + + +
+

开始对话

+

输入股票代码或名称,获取实时分析

+
+ + +
+
+
+
{{ msg.content }}
+
+ + +
+
+
+
+
+ + +
+
+
+ + + +
+
+
+
+
+ + +
+
+ + +
+ + +
+ 点击联系作者 + | + 微信:aaronlzhou +
+
+
+
+ + + + + + + + + + + diff --git a/frontend/index.html.backup b/frontend/index.html.backup new file mode 100644 index 0000000..e3255e5 --- /dev/null +++ b/frontend/index.html.backup @@ -0,0 +1,133 @@ + + + + + + A股AI分析Agent系统 + + + + + + + + +
+ + + +
+
+ +
+ +
+
+

欢迎使用A股AI分析Agent

+

请输入股票代码或问题,例如:

+
    +
  • 查询600519的实时行情
  • +
  • 贵州茅台的技术指标
  • +
  • 000001的K线图
  • +
+
+ +
+
+
+ {{ msg.role === 'user' ? '您' : 'AI助手' }} + {{ formatTime(msg.timestamp) }} +
+
+
{{ msg.content }}
+

{{ msg.content }}

+ + +
+
+
+
+
+
+ +
+
+
+ 加载中... +
+ AI正在思考... +
+
+
+ + +
+
+ + +
+
+
+ + +
+
+
+
技能列表
+
+
+
+ 加载中... +
+
+
+ {{ skill.name }} +
+ +
+
+ {{ skill.description }} +
+
+
+
+
+
+
+ + + + + + + + + + + diff --git a/frontend/js/app.js b/frontend/js/app.js new file mode 100644 index 0000000..aada4b1 --- /dev/null +++ b/frontend/js/app.js @@ -0,0 +1,269 @@ +// Vue 3 Application +const { createApp } = Vue; + +createApp({ + data() { + return { + messages: [], + userInput: '', + loading: false, + sessionId: null, + charts: {} + }; + }, + mounted() { + this.sessionId = this.generateSessionId(); + this.autoResizeTextarea(); + }, + methods: { + async sendMessage() { + if (!this.userInput.trim() || this.loading) return; + + const message = this.userInput.trim(); + this.userInput = ''; + + // Add user message + this.messages.push({ + role: 'user', + content: message, + timestamp: new Date() + }); + + this.$nextTick(() => { + this.scrollToBottom(); + this.autoResizeTextarea(); + }); + + this.loading = true; + + try { + const response = await fetch('/api/chat/message', { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + message: message, + session_id: this.sessionId + }) + }); + + if (!response.ok) { + throw new Error('请求失败'); + } + + const data = await response.json(); + + // Add assistant message + const assistantMessage = { + role: 'assistant', + content: data.message, + timestamp: new Date(), + metadata: data.metadata + }; + + this.messages.push(assistantMessage); + + // Render chart if needed + if (data.metadata && data.metadata.type === 'chart') { + this.$nextTick(() => { + const index = this.messages.length - 1; + this.renderChart(index, data.metadata.data); + }); + } + + this.$nextTick(() => { + this.scrollToBottom(); + }); + + } catch (error) { + console.error('发送消息失败:', error); + this.messages.push({ + role: 'assistant', + content: '抱歉,发送消息失败,请稍后重试。', + timestamp: new Date() + }); + } finally { + this.loading = false; + } + }, + + renderMarkdown(content) { + if (!content) return ''; + + // Configure marked options + marked.setOptions({ + breaks: true, + gfm: true, + headerIds: false, + mangle: false + }); + + return marked.parse(content); + }, + + renderChart(index, data) { + const chartId = `chart-${index}`; + const container = document.getElementById(chartId); + + if (!container || !data.kline_data) return; + + const chart = LightweightCharts.createChart(container, { + width: container.clientWidth, + height: 400, + layout: { + background: { color: '#000000' }, + textColor: '#a0a0a0' + }, + grid: { + vertLines: { color: '#1a1a1a' }, + horzLines: { color: '#1a1a1a' } + }, + timeScale: { + borderColor: '#333333', + timeVisible: true + }, + rightPriceScale: { + borderColor: '#333333' + } + }); + + const candlestickSeries = chart.addCandlestickSeries({ + upColor: '#00ff41', + downColor: '#ff0040', + borderVisible: false, + wickUpColor: '#00ff41', + wickDownColor: '#ff0040' + }); + + const klineData = data.kline_data.map(item => ({ + time: item.trade_date, + open: item.open, + high: item.high, + low: item.low, + close: item.close + })); + + candlestickSeries.setData(klineData); + + if (data.volume_data) { + const volumeSeries = chart.addHistogramSeries({ + color: '#00ff4140', + priceFormat: { + type: 'volume' + }, + priceScaleId: '' + }); + + const volumeData = data.volume_data.map(item => ({ + time: item.trade_date, + value: item.vol, + color: item.close >= item.open ? '#00ff4140' : '#ff004040' + })); + + volumeSeries.setData(volumeData); + } + + chart.timeScale().fitContent(); + + this.charts[chartId] = chart; + + // Handle resize + window.addEventListener('resize', () => { + if (this.charts[chartId]) { + chart.applyOptions({ width: container.clientWidth }); + } + }); + }, + + scrollToBottom() { + const container = this.$refs.chatContainer; + if (container) { + setTimeout(() => { + container.scrollTop = container.scrollHeight; + }, 100); + } + }, + + autoResizeTextarea() { + const textarea = this.$refs.textarea; + if (textarea) { + textarea.style.height = 'auto'; + textarea.style.height = Math.min(textarea.scrollHeight, 120) + 'px'; + } + }, + + generateSessionId() { + return 'session_' + Date.now() + '_' + Math.random().toString(36).substr(2, 9); + }, + + copyWechat() { + const wechatId = 'aaronlzhou'; + + // 使用现代的 Clipboard API + if (navigator.clipboard && navigator.clipboard.writeText) { + navigator.clipboard.writeText(wechatId).then(() => { + this.showCopyNotification(); + }).catch(err => { + console.error('复制失败:', err); + this.fallbackCopy(wechatId); + }); + } else { + this.fallbackCopy(wechatId); + } + }, + + fallbackCopy(text) { + // 降级方案:使用传统方法 + const textarea = document.createElement('textarea'); + textarea.value = text; + textarea.style.position = 'fixed'; + textarea.style.opacity = '0'; + document.body.appendChild(textarea); + textarea.select(); + + try { + document.execCommand('copy'); + this.showCopyNotification(); + } catch (err) { + console.error('复制失败:', err); + } + + document.body.removeChild(textarea); + }, + + showCopyNotification() { + // 创建临时提示 + const notification = document.createElement('div'); + notification.textContent = '已复制微信号'; + notification.style.cssText = ` + position: fixed; + bottom: 80px; + left: 50%; + transform: translateX(-50%); + background: #00ff41; + color: #000000; + padding: 8px 16px; + border-radius: 2px; + font-size: 13px; + font-weight: 500; + z-index: 10000; + animation: fadeInOut 2s ease; + `; + + document.body.appendChild(notification); + + setTimeout(() => { + document.body.removeChild(notification); + }, 2000); + } + }, + + watch: { + userInput() { + this.$nextTick(() => { + this.autoResizeTextarea(); + }); + } + } +}).mount('#app'); diff --git a/frontend/js/app.js.backup b/frontend/js/app.js.backup new file mode 100644 index 0000000..6c122a6 --- /dev/null +++ b/frontend/js/app.js.backup @@ -0,0 +1,219 @@ +// Vue 3 应用 +const { createApp } = Vue; + +createApp({ + data() { + return { + messages: [], + userInput: '', + loading: false, + sessionId: null, + showSkillPanel: false, + skills: [], + charts: {} + }; + }, + mounted() { + this.loadSkills(); + // 生成会话ID + this.sessionId = this.generateSessionId(); + }, + methods: { + async sendMessage() { + if (!this.userInput.trim() || this.loading) return; + + const message = this.userInput.trim(); + this.userInput = ''; + + // 添加用户消息 + this.messages.push({ + role: 'user', + content: message, + timestamp: new Date() + }); + + // 滚动到底部 + this.$nextTick(() => { + this.scrollToBottom(); + }); + + // 发送请求 + this.loading = true; + + try { + const response = await fetch('/api/chat/message', { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + message: message, + session_id: this.sessionId + }) + }); + + if (!response.ok) { + throw new Error('请求失败'); + } + + const data = await response.json(); + + // 添加助手消息 + const assistantMessage = { + role: 'assistant', + content: data.message, + timestamp: new Date(), + metadata: data.metadata + }; + + this.messages.push(assistantMessage); + + // 如果有图表数据,渲染图表 + if (data.metadata && data.metadata.type === 'chart') { + this.$nextTick(() => { + const index = this.messages.length - 1; + this.renderChart(index, data.metadata.data); + }); + } + + // 滚动到底部 + this.$nextTick(() => { + this.scrollToBottom(); + }); + + } catch (error) { + console.error('发送消息失败:', error); + this.messages.push({ + role: 'assistant', + content: '抱歉,发送消息失败,请稍后重试。', + timestamp: new Date() + }); + } finally { + this.loading = false; + } + }, + + async loadSkills() { + try { + const response = await fetch('/api/skills/'); + if (!response.ok) { + throw new Error('加载技能失败'); + } + const data = await response.json(); + this.skills = data.skills; + } catch (error) { + console.error('加载技能失败:', error); + } + }, + + async toggleSkill(skillName, enabled) { + try { + const endpoint = enabled ? 'enable' : 'disable'; + const response = await fetch(`/api/skills/${skillName}/${endpoint}`, { + method: 'POST' + }); + + if (!response.ok) { + throw new Error('切换技能失败'); + } + + // 重新加载技能列表 + await this.loadSkills(); + } catch (error) { + console.error('切换技能失败:', error); + // 恢复原状态 + await this.loadSkills(); + } + }, + + renderChart(index, chartData) { + const containerId = `chart-${index}`; + const container = document.getElementById(containerId); + + if (!container || !chartData) return; + + try { + // 创建图表 + const chart = LightweightCharts.createChart(container, { + width: container.clientWidth, + height: 400, + layout: { + background: { color: '#ffffff' }, + textColor: '#333', + }, + grid: { + vertLines: { color: '#f0f0f0' }, + horzLines: { color: '#f0f0f0' }, + }, + timeScale: { + borderColor: '#cccccc', + }, + }); + + // 添加K线图 + if (chartData.candlestick_data) { + const candlestickSeries = chart.addCandlestickSeries({ + upColor: '#26a69a', + downColor: '#ef5350', + borderVisible: false, + wickUpColor: '#26a69a', + wickDownColor: '#ef5350', + }); + candlestickSeries.setData(chartData.candlestick_data); + } + + // 添加成交量 + if (chartData.volume_data) { + const volumeSeries = chart.addHistogramSeries({ + color: '#26a69a', + priceFormat: { + type: 'volume', + }, + priceScaleId: '', + scaleMargins: { + top: 0.8, + bottom: 0, + }, + }); + volumeSeries.setData(chartData.volume_data); + } + + // 自适应大小 + chart.timeScale().fitContent(); + + // 保存图表实例 + this.charts[containerId] = chart; + + // 窗口大小改变时调整图表 + window.addEventListener('resize', () => { + if (this.charts[containerId]) { + this.charts[containerId].applyOptions({ + width: container.clientWidth + }); + } + }); + + } catch (error) { + console.error('渲染图表失败:', error); + } + }, + + scrollToBottom() { + const container = this.$refs.messagesContainer; + if (container) { + container.scrollTop = container.scrollHeight; + } + }, + + formatTime(timestamp) { + const date = new Date(timestamp); + const hours = date.getHours().toString().padStart(2, '0'); + const minutes = date.getMinutes().toString().padStart(2, '0'); + return `${hours}:${minutes}`; + }, + + generateSessionId() { + return 'session_' + Date.now() + '_' + Math.random().toString(36).substr(2, 9); + } + } +}).mount('#app'); diff --git a/install.sh b/install.sh new file mode 100755 index 0000000..2d82c44 --- /dev/null +++ b/install.sh @@ -0,0 +1,149 @@ +#!/bin/bash + +# A股AI分析Agent系统 - 快速安装脚本 + +echo "================================" +echo "A股AI分析Agent系统 - 安装脚本" +echo "================================" +echo "" + +# 检查Python版本 +echo "检查Python版本..." + +# 尝试找到合适的Python版本 +PYTHON_CMD="" + +for cmd in python3.11 python3.12 python3.10 python3 python; do + if command -v $cmd &> /dev/null; then + version=$($cmd --version 2>&1 | awk '{print $2}') + major=$(echo $version | cut -d. -f1) + minor=$(echo $version | cut -d. -f2) + + if [ "$major" = "3" ] && [ "$minor" -ge "10" ] && [ "$minor" -le "12" ]; then + PYTHON_CMD=$cmd + echo "✓ 找到合适的Python版本: $version ($cmd)" + break + fi + fi +done + +if [ -z "$PYTHON_CMD" ]; then + echo "❌ 错误: 未找到合适的Python版本" + echo "" + echo "请安装 Python 3.11 或 3.12:" + echo " macOS: brew install python@3.11" + echo " Ubuntu: sudo apt install python3.11" + echo " Windows: 从 python.org 下载安装" + echo "" + exit 1 +fi + +# 检查是否在项目根目录 +if [ ! -f "README.md" ] || [ ! -d "backend" ]; then + echo "❌ 错误: 请在项目根目录运行此脚本" + echo "当前目录: $(pwd)" + exit 1 +fi + +# 创建虚拟环境 +echo "" +echo "创建虚拟环境..." +cd backend + +if [ -d "venv" ]; then + echo "⚠ 虚拟环境已存在,将删除并重新创建" + rm -rf venv +fi + +$PYTHON_CMD -m venv venv + +if [ $? -ne 0 ]; then + echo "❌ 创建虚拟环境失败" + exit 1 +fi + +echo "✓ 虚拟环境创建成功" + +# 激活虚拟环境 +echo "" +echo "激活虚拟环境..." +source venv/bin/activate + +# 升级pip +echo "" +echo "升级pip..." +pip install --upgrade pip setuptools wheel + +# 安装依赖 +echo "" +echo "安装依赖包..." +echo "这可能需要几分钟时间..." +echo "" + +pip install -r requirements.txt + +if [ $? -ne 0 ]; then + echo "" + echo "❌ 依赖安装失败" + echo "" + echo "可能的原因:" + echo "1. Python版本不兼容(推荐使用3.11或3.12)" + echo "2. 网络问题" + echo "3. 缺少编译工具" + echo "" + echo "解决方案请查看: docs/INSTALL_GUIDE.md" + exit 1 +fi + +echo "" +echo "✓ 依赖安装成功" + +# 检查配置文件 +echo "" +echo "检查配置文件..." +cd .. + +if [ ! -f ".env" ]; then + echo "⚠ 未找到.env文件,从模板创建..." + cp .env.example .env + echo "✓ 已创建.env文件" + echo "" + echo "⚠ 重要: 请编辑.env文件,填写以下配置:" + echo " - TUSHARE_TOKEN (从 https://tushare.pro/ 获取)" + echo " - ZHIPUAI_API_KEY (从 https://open.bigmodel.cn/ 获取)" + echo "" +else + echo "✓ .env文件已存在" +fi + +# 验证安装 +echo "" +echo "验证安装..." +cd backend +source venv/bin/activate + +python -c "import fastapi; print('✓ FastAPI:', fastapi.__version__)" 2>/dev/null || echo "❌ FastAPI 安装失败" +python -c "import pandas; print('✓ Pandas:', pandas.__version__)" 2>/dev/null || echo "❌ Pandas 安装失败" +python -c "import numpy; print('✓ NumPy:', numpy.__version__)" 2>/dev/null || echo "❌ NumPy 安装失败" +python -c "import tushare; print('✓ Tushare:', tushare.__version__)" 2>/dev/null || echo "❌ Tushare 安装失败" + +echo "" +echo "================================" +echo "安装完成!" +echo "================================" +echo "" +echo "下一步:" +echo "1. 编辑 .env 文件,填写API密钥" +echo "2. 启动应用:" +echo " cd backend" +echo " source venv/bin/activate" +echo " python -m app.main" +echo "" +echo "3. 访问系统:" +echo " 前端界面: http://localhost:8000" +echo " API文档: http://localhost:8000/docs" +echo "" +echo "如有问题,请查看:" +echo " - 安装指南: docs/INSTALL_GUIDE.md" +echo " - 用户手册: docs/USER_GUIDE.md" +echo "" diff --git a/start.sh b/start.sh new file mode 100755 index 0000000..f68cdee --- /dev/null +++ b/start.sh @@ -0,0 +1,78 @@ +#!/bin/bash + +# A股AI分析Agent系统 - 快速启动脚本 + +echo "================================" +echo "A股AI分析Agent系统 - 启动脚本" +echo "================================" +echo "" + +# 检查Python版本 +echo "检查Python版本..." +python_version=$(python3 --version 2>&1 | awk '{print $2}') +echo "Python版本: $python_version" + +# 检查是否在虚拟环境中 +if [[ "$VIRTUAL_ENV" == "" ]]; then + echo "" + echo "警告: 未检测到虚拟环境" + echo "建议创建虚拟环境:" + echo " python3 -m venv venv" + echo " source venv/bin/activate # macOS/Linux" + echo " venv\\Scripts\\activate # Windows" + echo "" + read -p "是否继续?(y/n) " -n 1 -r + echo + if [[ ! $REPLY =~ ^[Yy]$ ]]; then + exit 1 + fi +fi + +# 检查.env文件 +if [ ! -f ".env" ]; then + echo "" + echo "错误: 未找到.env文件" + echo "请复制.env.example为.env并配置:" + echo " cp .env.example .env" + echo " 然后编辑.env文件,填写必要的配置" + exit 1 +fi + +# 检查依赖 +echo "" +echo "检查依赖..." +cd backend + +if [ ! -d "venv" ] && [[ "$VIRTUAL_ENV" == "" ]]; then + echo "安装依赖..." + pip3 install -r requirements.txt +fi + +# 检查Redis(可选) +echo "" +echo "检查Redis..." +if command -v redis-cli &> /dev/null; then + if redis-cli ping &> /dev/null; then + echo "✓ Redis运行正常" + else + echo "⚠ Redis未运行(可选,系统会自动降级)" + echo " 启动Redis: redis-server" + fi +else + echo "⚠ Redis未安装(可选,系统会自动降级)" +fi + +# 启动应用 +echo "" +echo "================================" +echo "启动应用..." +echo "================================" +echo "" +echo "访问地址:" +echo " 前端界面: http://localhost:8000" +echo " API文档: http://localhost:8000/docs" +echo "" +echo "按 Ctrl+C 停止服务" +echo "" + +python3 -m app.main