diff --git a/cryptoai/api/adata_api.py b/cryptoai/api/adata_api.py new file mode 100644 index 0000000..39ff75e --- /dev/null +++ b/cryptoai/api/adata_api.py @@ -0,0 +1,147 @@ +""" +封装adata库的A股数据获取方法 +提供简单易用的接口来获取股票市场数据 +""" + +import adata +from typing import Dict, List, Optional +import pandas as pd +from datetime import datetime, timedelta + +class AStockAPI: + @staticmethod + def get_all_stock_codes() -> List[str]: + """ + 获取所有A股代码 + Returns: + List[str]: 股票代码列表 + """ + try: + return adata.stock.info.all_code() + except Exception as e: + print(f"获取股票代码失败: {str(e)}") + return [] + + + @staticmethod + def get_concept_east(stock_code: str) -> pd.DataFrame: + """ + 获取股票所属的概念板块 + """ + try: + return adata.stock.info.get_concept_east(stock_code) + except Exception as e: + print(f"获取概念板块失败: {str(e)}") + return pd.DataFrame() + + + @staticmethod + def get_plate_east(stock_code: str) -> pd.DataFrame: + """ + 获取股票所属的板块 + """ + try: + return adata.stock.info.get_plate_east(stock_code) + except Exception as e: + print(f"获取板块失败: {str(e)}") + return pd.DataFrame() + + @staticmethod + def get_stock_shares(stock_code: str) -> pd.DataFrame: + """ + 获取股本信息 + """ + try: + return adata.stock.info.get_stock_shares(stock_code) + except Exception as e: + print(f"获取股本信息失败: {str(e)}") + return pd.DataFrame() + + @staticmethod + def get_market_data(stock_code: str, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + k_type: int = 1, + adjust_type: int = 1) -> pd.DataFrame: + """ + 获取股票市场日线数据 + Args: + codes: 股票代码列表,默认为None表示获取所有股票 + start_date: 开始日期,格式:YYYY-MM-DD + end_date: 结束日期,格式:YYYY-MM-DD + Returns: + pd.DataFrame: 市场数据 + """ + try: + return adata.stock.market.get_market(stock_code=stock_code, + start_date=start_date, + end_date=end_date, + k_type=k_type, + adjust_type=adjust_type) + except Exception as e: + print(f"获取市场数据失败: {str(e)}") + return pd.DataFrame() + + @staticmethod + def get_market_min_data(stock_code: str) -> pd.DataFrame: + """ + 获取股票市场分钟线数据 + Args: + stock_code: 股票代码 + Returns: + pd.DataFrame: 分钟级市场数据 + """ + try: + return adata.stock.market.get_market_min(stock_code=stock_code) + except Exception as e: + print(f"获取分钟线数据失败: {str(e)}") + return pd.DataFrame() + + @staticmethod + def get_capital_flow_min(stock_code: str) -> pd.DataFrame: + """ + 获取股票资金流向分钟数据 + Args: + stock_code: 股票代码 + Returns: + pd.DataFrame: 资金流向数据 + """ + try: + return adata.stock.market.get_capital_flow_min(stock_code=stock_code) + except Exception as e: + print(f"获取资金流向数据失败: {str(e)}") + return pd.DataFrame() + +# 使用示例 +if __name__ == "__main__": + + print("开始获取A股数据") + api = AStockAPI() + + # 获取所有股票代码 + stock_codes = ["688552", "688648", "688165","688552"] + print(f"获取到 {len(stock_codes)} 个股票代码") + + # 获取指定股票的市场数据 + if stock_codes: + sample_codes = stock_codes[:3] # 取前3个股票作为示例 + print(f"获取到 {len(sample_codes)} 个股票代码") + today = datetime.now().strftime("%Y-%m-%d") + week_ago = (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d") + + # 获取日线数据 + market_data = api.get_market_data(stock_code=sample_codes[0], + start_date=week_ago, + end_date=today) + print("\n日线数据示例:") + print(market_data.head()) + + # 获取分钟线数据 + min_data = api.get_market_min_data(stock_code=sample_codes[0]) + print("\n分钟线数据示例:") + print(min_data.head()) + + # 获取资金流向数据 + flow_data = api.get_capital_flow_min(stock_code=sample_codes[0]) + print("\n资金流向数据示例:") + print(flow_data.head()) \ No newline at end of file diff --git a/cryptoai/routes/adata.py b/cryptoai/routes/adata.py new file mode 100644 index 0000000..d31a0a2 --- /dev/null +++ b/cryptoai/routes/adata.py @@ -0,0 +1,41 @@ +import json + +from fastapi import APIRouter, Depends, HTTPException, status, Body, Query, Path +from cryptoai.api.adata_api import AStockAPI + +# 创建路由 +router = APIRouter() + + +@router.get("/stock/data") +async def get_stock_data(stock_code: str): + + result = {} + + api = AStockAPI() + + # 获取股本信息 + stock_shares = api.get_stock_shares(stock_code) + result["stock_shares"] = json.loads(stock_shares.to_json(orient="records")) + + # 获取概念板块 + concept_east = api.get_concept_east(stock_code) + result["concept_east"] = json.loads(concept_east.to_json(orient="records")) + + # 获取板块 + plate_east = api.get_plate_east(stock_code) + result["plate_east"] = json.loads(plate_east.to_json(orient="records")) + + # 获取市场数据 + market_data = api.get_market_data(stock_code) + result["market_data"] = json.loads(market_data.to_json(orient="records")) + + # 获取分钟线数据 + min_data = api.get_market_min_data(stock_code) + result["min_data"] = json.loads(min_data.to_json(orient="records")) + + # 获取资金流向数据 + flow_data = api.get_capital_flow_min(stock_code) + result["flow_data"] = json.loads(flow_data.to_json(orient="records")) + + return result diff --git a/cryptoai/routes/fastapi_app.py b/cryptoai/routes/fastapi_app.py index 6d5a461..6f7ffc0 100644 --- a/cryptoai/routes/fastapi_app.py +++ b/cryptoai/routes/fastapi_app.py @@ -18,6 +18,7 @@ from typing import Dict, Any from cryptoai.routes.agent import router as agent_router from cryptoai.routes.feed import router as feed_router from cryptoai.routes.user import router as user_router +from cryptoai.routes.adata import router as adata_router from cryptoai.routes.question import router as question_router # 配置日志 @@ -52,6 +53,7 @@ app.include_router(agent_router, prefix="/agent") app.include_router(feed_router, prefix="/feed", tags=["AI Agent信息流"]) app.include_router(user_router, prefix="/user", tags=["用户管理"]) app.include_router(question_router, prefix="/question", tags=["用户提问"]) +app.include_router(adata_router, prefix="/adata", tags=["A股数据"]) # 请求计时中间件 @app.middleware("http") diff --git a/docker-compose.yml b/docker-compose.yml index b14baeb..409dc6b 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -29,7 +29,7 @@ services: cryptoai-api: build: . container_name: cryptoai-api - image: cryptoai-api:0.0.14 + image: cryptoai-api:0.0.15 restart: always ports: - "8000:8000" diff --git a/requirements.txt b/requirements.txt index 4d1bc9d..4a88e74 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ python-binance==1.0.16 -pandas==2.0.3 -numpy==1.24.3 +pandas>=1.3.0 +numpy>=1.20.0 sqlalchemy==2.0.19 pymysql==1.1.0 requests==2.31.0 @@ -14,6 +14,7 @@ pyjwt==2.8.0 python-multipart==0.0.9 email-validator==2.1.0 tencentcloud-sdk-python==3.0.1030 +adata # # 日志相关 # logging==0.4.9.6 # # 数据处理相关