From 95d9548bb9bbbf85390910eef982284621a76fce Mon Sep 17 00:00:00 2001 From: aaron <> Date: Tue, 13 May 2025 23:50:54 +0800 Subject: [PATCH] update --- cryptoai/api/adata_api.py | 58 +++-- cryptoai/routes/adata.py | 104 +++++++-- cryptoai/utils/db_manager.py | 433 +++++++++++++++++++++++++++++++++++ docker-compose.yml | 2 +- test.py | 23 +- 5 files changed, 568 insertions(+), 52 deletions(-) diff --git a/cryptoai/api/adata_api.py b/cryptoai/api/adata_api.py index 53ea5e9..9a5e20d 100644 --- a/cryptoai/api/adata_api.py +++ b/cryptoai/api/adata_api.py @@ -128,31 +128,43 @@ if __name__ == "__main__": print("开始获取A股数据") api = AStockAPI() - + + stock_codes = api.get_all_stock_codes() + list = stock_codes.to_json(orient="records") + print(list) + + # 保存到数据库 + import cryptoai.utils.db_manager as db_manager + db_manager.get_db_manager().create_stocks(list) + + + # 获取所有股票代码 - stock_codes = ["688552", "688648", "688165","688552"] - print(f"获取到 {len(stock_codes)} 个股票代码") - # 获取指定股票的市场数据 - if stock_codes: - sample_codes = stock_codes[:3] # 取前3个股票作为示例 - print(f"获取到 {len(sample_codes)} 个股票代码") - today = datetime.now().strftime("%Y-%m-%d") - week_ago = (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d") + # # 获取所有股票代码 + # stock_codes = ["688552", "688648", "688165","688552"] + # print(f"获取到 {len(stock_codes)} 个股票代码") + + # # 获取指定股票的市场数据 + # if stock_codes: + # sample_codes = stock_codes[:3] # 取前3个股票作为示例 + # print(f"获取到 {len(sample_codes)} 个股票代码") + # today = datetime.now().strftime("%Y-%m-%d") + # week_ago = (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d") - # 获取日线数据 - market_data = api.get_market_data(stock_code=sample_codes[0], - start_date=week_ago, - end_date=today) - print("\n日线数据示例:") - print(market_data.head()) + # # 获取日线数据 + # market_data = api.get_market_data(stock_code=sample_codes[0], + # start_date=week_ago, + # end_date=today) + # print("\n日线数据示例:") + # print(market_data.head()) - # 获取分钟线数据 - min_data = api.get_market_min_data(stock_code=sample_codes[0]) - print("\n分钟线数据示例:") - print(min_data.head()) + # # 获取分钟线数据 + # min_data = api.get_market_min_data(stock_code=sample_codes[0]) + # print("\n分钟线数据示例:") + # print(min_data.head()) - # 获取资金流向数据 - flow_data = api.get_capital_flow_min(stock_code=sample_codes[0]) - print("\n资金流向数据示例:") - print(flow_data.head()) \ No newline at end of file + # # 获取资金流向数据 + # flow_data = api.get_capital_flow_min(stock_code=sample_codes[0]) + # print("\n资金流向数据示例:") + # print(flow_data.head()) \ No newline at end of file diff --git a/cryptoai/routes/adata.py b/cryptoai/routes/adata.py index e975c8f..50d9a55 100644 --- a/cryptoai/routes/adata.py +++ b/cryptoai/routes/adata.py @@ -1,41 +1,103 @@ import json - +import logging from fastapi import APIRouter, Depends, HTTPException, status, Body, Query, Path from cryptoai.api.adata_api import AStockAPI +from cryptoai.utils.db_manager import get_db_manager # 创建路由 router = APIRouter() +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) -@router.get("/stock/data") -async def get_stock_data(stock_code: str): +@router.get("/stock/search") +async def search_stock(key: str, limit: int = 10): + manager = get_db_manager() + result = manager.search_stock(key, limit) + + return result + + +@router.get("/stock/base", summary="获取股票基础信息") +async def get_stock_base(stock_code: str): + api = AStockAPI() result = {} + try: + # 获取股本信息 + stock_shares = api.get_stock_shares(stock_code) + result["stock_shares"] = json.loads(stock_shares.to_json(orient="records")) + + # 获取概念板块 + concept_east = api.get_concept_east(stock_code) + result["concept_east"] = json.loads(concept_east.to_json(orient="records")) + + # 获取板块 + plate_east = api.get_plate_east(stock_code) + result["plate_east"] = json.loads(plate_east.to_json(orient="records")) + + except Exception as e: + logger.error(f"获取股票基础信息失败: {e}") + return {} + + return result + +@router.get("/stock/data", summary="获取股票数据") +async def get_stock_data(stock_code: str): api = AStockAPI() - # 获取股本信息 - # stock_shares = api.get_stock_shares(stock_code) - # result["stock_shares"] = json.loads(stock_shares.to_json(orient="records")) + result = {} - # # 获取概念板块 - # concept_east = api.get_concept_east(stock_code) - # result["concept_east"] = json.loads(concept_east.to_json(orient="records")) + try: + # 获取市场数据 + market_data = api.get_market_data(stock_code) + result["market_data"] = json.loads(market_data.to_json(orient="records")) - # 获取板块 - plate_east = api.get_plate_east(stock_code) - result["plate_east"] = json.loads(plate_east.to_json(orient="records")) + # 获取资金流向数据 + flow_data = api.get_capital_flow(stock_code) + result["flow_data"] = json.loads(flow_data.to_json(orient="records")) - # 获取市场数据 - market_data = api.get_market_data(stock_code) - result["market_data"] = json.loads(market_data.to_json(orient="records")) + except Exception as e: + logger.error(f"获取股票数据失败: {e}") + return {} - # # 获取分钟线数据 - # min_data = api.get_market_min_data(stock_code) - # result["min_data"] = json.loads(min_data.to_json(orient="records")) - # 获取资金流向数据 - flow_data = api.get_capital_flow(stock_code) - result["flow_data"] = json.loads(flow_data.to_json(orient="records")) +@router.get("/stock/data/all", summary="获取所有股票数据") +async def get_stock_data_all(stock_code: str): + result = {} + + + try: + api = AStockAPI() + + #获取股本信息 + stock_shares = api.get_stock_shares(stock_code) + result["stock_shares"] = json.loads(stock_shares.to_json(orient="records")) + + # 获取概念板块 + concept_east = api.get_concept_east(stock_code) + result["concept_east"] = json.loads(concept_east.to_json(orient="records")) + + # 获取板块 + plate_east = api.get_plate_east(stock_code) + result["plate_east"] = json.loads(plate_east.to_json(orient="records")) + + # 获取市场数据 + market_data = api.get_market_data(stock_code) + result["market_data"] = json.loads(market_data.to_json(orient="records")) + + # 获取分钟线数据 + min_data = api.get_market_min_data(stock_code) + result["min_data"] = json.loads(min_data.to_json(orient="records")) + + # 获取资金流向数据 + flow_data = api.get_capital_flow(stock_code) + result["flow_data"] = json.loads(flow_data.to_json(orient="records")) + + except Exception as e: + logger.error(f"获取股票数据失败: {e}") + return {} + return result diff --git a/cryptoai/utils/db_manager.py b/cryptoai/utils/db_manager.py index 29864ba..f3db435 100644 --- a/cryptoai/utils/db_manager.py +++ b/cryptoai/utils/db_manager.py @@ -131,6 +131,50 @@ class Agent(Base): {'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'} ) +# 定义 A 股数据模型 +class AStock(Base): + """A股股票基本信息表模型""" + __tablename__ = 'astock' + + stock_code = Column(String(10), primary_key=True, comment='股票代码') + short_name = Column(String(50), nullable=False, comment='股票简称') + exchange = Column(String(20), nullable=True, comment='交易所') + list_date = Column(DateTime, nullable=True, comment='上市日期') + created_at = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间') + + # 索引 + __table_args__ = ( + Index('idx_stock_code', 'stock_code', unique=True), + {'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'} + ) + +def _convert_timestamp_to_datetime(timestamp: Union[int, float, str, None]) -> Optional[datetime]: + """ + 将时间戳转换为datetime对象 + + Args: + timestamp: Unix时间戳(毫秒级) + + Returns: + datetime对象或None + """ + if timestamp is None: + return None + + try: + # 转换为整数 + if isinstance(timestamp, str): + timestamp = int(timestamp) + + # 如果是毫秒级时间戳,转换为秒级 + if timestamp > 1e11: # 判断是否为毫秒级时间戳 + timestamp = timestamp / 1000 + + return datetime.fromtimestamp(timestamp) + except (ValueError, TypeError) as e: + logger.error(f"时间戳转换失败: {timestamp}, 错误: {e}") + return None + class DBManager: """数据库管理工具,用于连接MySQL数据库并保存智能体分析结果""" @@ -1215,6 +1259,395 @@ class DBManager: except Exception as e: logger.error(f"创建数据库会话失败: {e}") return False + + def create_stocks(self, stocks: List[Dict[str, Any]]) -> bool: + """ + 批量创建股票信息 + + Args: + stocks: 股票信息列表,每个元素为包含以下键的字典: + - stock_code: 股票代码 + - short_name: 股票简称 + - exchange: 交易所(可选) + - list_date: 上市日期(可选,Unix时间戳,毫秒级) + + Returns: + 创建是否成功 + """ + if not self.engine: + try: + self._init_db() + except Exception as e: + logger.error(f"重新连接数据库失败: {e}") + return False + + try: + session = self.Session() + + try: + for stock_data in stocks: + # 检查股票代码是否已存在 + existing_stock = session.query(AStock).filter( + AStock.stock_code == stock_data['stock_code'] + ).first() + + if existing_stock: + logger.warning(f"股票代码 {stock_data['stock_code']} 已存在,跳过") + continue + + # 创建新股票记录 + new_stock = AStock( + stock_code=stock_data['stock_code'], + short_name=stock_data['short_name'], + exchange=stock_data.get('exchange') + ) + + # 处理上市日期 + if 'list_date' in stock_data and stock_data['list_date']: + list_date = _convert_timestamp_to_datetime(stock_data['list_date']) + if list_date: + new_stock.list_date = list_date + + session.add(new_stock) + + # 批量提交 + session.commit() + logger.info(f"成功批量创建股票信息,共 {len(stocks)} 条记录") + return True + + except Exception as e: + session.rollback() + logger.error(f"批量创建股票信息失败: {e}") + return False + + finally: + session.close() + + except Exception as e: + logger.error(f"创建数据库会话失败: {e}") + return False + + def create_stock(self, stock_code: str, short_name: str, + exchange: Optional[str] = None, + list_date: Optional[Union[int, float, str]] = None) -> bool: + """ + 创建新的股票信息 + + Args: + stock_code: 股票代码 + short_name: 股票简称 + exchange: 交易所(可选) + list_date: 上市日期(可选,Unix时间戳,毫秒级) + + Returns: + 创建是否成功 + """ + if not self.engine: + try: + self._init_db() + except Exception as e: + logger.error(f"重新连接数据库失败: {e}") + return False + + try: + session = self.Session() + + try: + # 检查股票代码是否已存在 + existing_stock = session.query(AStock).filter(AStock.stock_code == stock_code).first() + if existing_stock: + logger.warning(f"股票代码 {stock_code} 已存在") + return False + + # 创建新股票记录 + new_stock = AStock( + stock_code=stock_code, + short_name=short_name, + exchange=exchange + ) + + # 处理上市日期 + if list_date: + converted_date = _convert_timestamp_to_datetime(list_date) + if converted_date: + new_stock.list_date = converted_date + + session.add(new_stock) + session.commit() + + logger.info(f"成功创建股票信息: {stock_code} - {short_name}") + return True + + except Exception as e: + session.rollback() + logger.error(f"创建股票信息失败: {e}") + return False + + finally: + session.close() + + except Exception as e: + logger.error(f"创建数据库会话失败: {e}") + return False + + def update_stock(self, stock_code: str, short_name: Optional[str] = None, + exchange: Optional[str] = None, list_date: Optional[datetime] = None) -> bool: + """ + 更新股票信息 + + Args: + stock_code: 股票代码 + short_name: 股票简称 + exchange: 交易所 + list_date: 上市日期 + + Returns: + 更新是否成功 + """ + if not self.engine: + try: + self._init_db() + except Exception as e: + logger.error(f"重新连接数据库失败: {e}") + return False + + try: + session = self.Session() + + try: + # 查询股票 + stock = session.query(AStock).filter(AStock.stock_code == stock_code).first() + if not stock: + logger.warning(f"股票代码 {stock_code} 不存在") + return False + + # 更新字段 + if short_name is not None: + stock.short_name = short_name + if exchange is not None: + stock.exchange = exchange + if list_date is not None: + stock.list_date = list_date + + stock.updated_at = datetime.now() + + session.commit() + + logger.info(f"成功更新股票信息: {stock_code}") + return True + + except Exception as e: + session.rollback() + logger.error(f"更新股票信息失败: {e}") + return False + + finally: + session.close() + + except Exception as e: + logger.error(f"创建数据库会话失败: {e}") + return False + + def delete_stock(self, stock_code: str) -> bool: + """ + 删除股票信息 + + Args: + stock_code: 股票代码 + + Returns: + 删除是否成功 + """ + if not self.engine: + try: + self._init_db() + except Exception as e: + logger.error(f"重新连接数据库失败: {e}") + return False + + try: + session = self.Session() + + try: + # 查询股票 + stock = session.query(AStock).filter(AStock.stock_code == stock_code).first() + if not stock: + logger.warning(f"股票代码 {stock_code} 不存在") + return False + + # 删除股票 + session.delete(stock) + session.commit() + + logger.info(f"成功删除股票信息: {stock_code}") + return True + + except Exception as e: + session.rollback() + logger.error(f"删除股票信息失败: {e}") + return False + + finally: + session.close() + + except Exception as e: + logger.error(f"创建数据库会话失败: {e}") + return False + + def get_stock_by_code(self, stock_code: str) -> Optional[Dict[str, Any]]: + """ + 通过股票代码获取股票信息 + + Args: + stock_code: 股票代码 + + Returns: + 股票信息,如果不存在则返回None + """ + if not self.engine: + try: + self._init_db() + except Exception as e: + logger.error(f"重新连接数据库失败: {e}") + return None + + try: + session = self.Session() + + try: + # 查询股票 + stock = session.query(AStock).filter(AStock.stock_code == stock_code).first() + + if stock: + return { + 'stock_code': stock.stock_code, + 'short_name': stock.short_name, + 'exchange': stock.exchange, + 'list_date': stock.list_date, + 'created_at': stock.created_at, + 'updated_at': stock.updated_at + } + else: + return None + + finally: + session.close() + + except Exception as e: + logger.error(f"获取股票信息失败: {e}") + return None + + def search_stock(self, key: str, limit: int = 10) -> List[Dict[str, Any]]: + """ + 搜索股票 + """ + if not self.engine: + try: + self._init_db() + except Exception as e: + logger.error(f"重新连接数据库失败: {e}") + return [] + + try: + session = self.Session() + + try: + # 查询股票 + stocks = session.query(AStock).filter(AStock.short_name.like(f"{key}%") | AStock.stock_code.like(f"{key}%")).limit(limit).all() + + return [{ + 'stock_code': stock.stock_code, + 'short_name': stock.short_name, + 'exchange': stock.exchange, + 'list_date': stock.list_date, + 'created_at': stock.created_at + } for stock in stocks] + + finally: + session.close() + + except Exception as e: + logger.error(f"获取股票信息失败: {e}") + return [] + + def get_stock_by_name(self, short_name: str) -> List[Dict[str, Any]]: + """ + 通过股票简称获取股票信息(可能有多个结果) + + Args: + short_name: 股票简称 + + Returns: + 股票信息列表 + """ + if not self.engine: + try: + self._init_db() + except Exception as e: + logger.error(f"重新连接数据库失败: {e}") + return [] + + try: + session = self.Session() + + try: + # 查询股票 + stocks = session.query(AStock).filter(AStock.short_name == short_name).all() + + return [{ + 'stock_code': stock.stock_code, + 'short_name': stock.short_name, + 'exchange': stock.exchange, + 'list_date': stock.list_date, + 'created_at': stock.created_at + } for stock in stocks] + + finally: + session.close() + + except Exception as e: + logger.error(f"获取股票信息失败: {e}") + return [] + + def list_stocks(self, limit: int = 100, skip: int = 0) -> List[Dict[str, Any]]: + """ + 获取股票列表 + + Args: + limit: 返回的最大数量 + skip: 跳过的数量 + + Returns: + 股票信息列表 + """ + if not self.engine: + try: + self._init_db() + except Exception as e: + logger.error(f"重新连接数据库失败: {e}") + return [] + + try: + session = self.Session() + + try: + # 查询股票列表 + stocks = session.query(AStock).order_by(AStock.stock_code).offset(skip).limit(limit).all() + + return [{ + 'stock_code': stock.stock_code, + 'short_name': stock.short_name, + 'exchange': stock.exchange, + 'list_date': stock.list_date, + 'created_at': stock.created_at + } for stock in stocks] + + finally: + session.close() + + except Exception as e: + logger.error(f"获取股票列表失败: {e}") + return [] # 单例模式 diff --git a/docker-compose.yml b/docker-compose.yml index 20a1dbf..defc159 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -29,7 +29,7 @@ services: cryptoai-api: build: . container_name: cryptoai-api - image: cryptoai-api:0.0.18 + image: cryptoai-api:0.0.19 restart: always ports: - "8000:8000" diff --git a/test.py b/test.py index 80c57dc..0f907e0 100644 --- a/test.py +++ b/test.py @@ -1,8 +1,17 @@ -from cryptoai.api.binance_api import BinanceAPI -from cryptoai.utils.config_loader import ConfigLoader +from cryptoai.utils.db_manager import get_db_manager +from cryptoai.api.adata_api import AStockAPI +import json +from time import sleep +if __name__ == "__main__": + print("开始获取A股数据") + api = AStockAPI() + stock_codes = api.get_all_stock_codes() + list = json.loads(stock_codes.to_json(orient="records")) + # print(list[0]) -config = ConfigLoader().get_binance_config() - -binance_api = BinanceAPI(config['api_key'], config['api_secret']) - -print(binance_api.get_top_longshort_position_ratio("BTCUSDT", "1h")) \ No newline at end of file + # 保存到数据库 + for stock in list: + print(f"创建股票: {stock['stock_code']} - {stock['short_name']}") + get_db_manager().create_stock(stock["stock_code"], stock["short_name"], stock["exchange"], stock["list_date"]) + print(f"创建股票: {stock['stock_code']} - {stock['short_name']} 完成") + sleep(1) \ No newline at end of file