#!/usr/bin/env python # -*- coding: utf-8 -*- from typing import Dict, Any, List, Optional, Union from datetime import datetime from sqlalchemy import Column, Integer, String, DateTime, Index from cryptoai.models.base import Base, logger from cryptoai.utils.db_utils import convert_timestamp_to_datetime # 定义 A 股数据模型 class AStock(Base): """A股股票基本信息表模型""" __tablename__ = 'astock' stock_code = Column(String(10), primary_key=True, comment='股票代码') short_name = Column(String(50), nullable=False, comment='股票简称') exchange = Column(String(20), nullable=True, comment='交易所') list_date = Column(DateTime, nullable=True, comment='上市日期') created_at = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间') # 索引 __table_args__ = ( Index('idx_stock_code', 'stock_code', unique=True), {'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'} ) class AStockManager: """A股管理类""" def __init__(self, db_session): self.session = db_session def create_stocks(self, stocks: List[Dict[str, Any]]) -> bool: """ 批量创建股票信息 Args: stocks: 股票信息列表,每个元素为包含以下键的字典: - stock_code: 股票代码 - short_name: 股票简称 - exchange: 交易所(可选) - list_date: 上市日期(可选,Unix时间戳,毫秒级) Returns: 创建是否成功 """ try: for stock_data in stocks: # 检查股票代码是否已存在 existing_stock = self.session.query(AStock).filter( AStock.stock_code == stock_data['stock_code'] ).first() if existing_stock: logger.warning(f"股票代码 {stock_data['stock_code']} 已存在,跳过") continue # 创建新股票记录 new_stock = AStock( stock_code=stock_data['stock_code'], short_name=stock_data['short_name'], exchange=stock_data.get('exchange') ) # 处理上市日期 if 'list_date' in stock_data and stock_data['list_date']: list_date = convert_timestamp_to_datetime(stock_data['list_date']) if list_date: new_stock.list_date = list_date self.session.add(new_stock) # 批量提交 self.session.commit() logger.info(f"成功批量创建股票信息,共 {len(stocks)} 条记录") return True except Exception as e: self.session.rollback() logger.error(f"批量创建股票信息失败: {e}") return False def create_stock(self, stock_code: str, short_name: str, exchange: Optional[str] = None, list_date: Optional[Union[int, float, str]] = None) -> bool: """ 创建新的股票信息 Args: stock_code: 股票代码 short_name: 股票简称 exchange: 交易所(可选) list_date: 上市日期(可选,Unix时间戳,毫秒级) Returns: 创建是否成功 """ try: # 检查股票代码是否已存在 existing_stock = self.session.query(AStock).filter(AStock.stock_code == stock_code).first() if existing_stock: logger.warning(f"股票代码 {stock_code} 已存在") return False # 创建新股票记录 new_stock = AStock( stock_code=stock_code, short_name=short_name, exchange=exchange ) # 处理上市日期 if list_date: converted_date = convert_timestamp_to_datetime(list_date) if converted_date: new_stock.list_date = converted_date self.session.add(new_stock) self.session.commit() logger.info(f"成功创建股票信息: {stock_code} - {short_name}") return True except Exception as e: self.session.rollback() logger.error(f"创建股票信息失败: {e}") return False def update_stock(self, stock_code: str, short_name: Optional[str] = None, exchange: Optional[str] = None, list_date: Optional[Union[int, float, str, datetime]] = None) -> bool: """ 更新股票信息 Args: stock_code: 股票代码 short_name: 股票简称 exchange: 交易所 list_date: 上市日期(可以是datetime对象或时间戳) Returns: 更新是否成功 """ try: # 查询股票 stock = self.session.query(AStock).filter(AStock.stock_code == stock_code).first() if not stock: logger.warning(f"股票代码 {stock_code} 不存在") return False # 更新字段 if short_name is not None: stock.short_name = short_name if exchange is not None: stock.exchange = exchange if list_date is not None: if isinstance(list_date, datetime): stock.list_date = list_date else: converted_date = convert_timestamp_to_datetime(list_date) if converted_date: stock.list_date = converted_date self.session.commit() logger.info(f"成功更新股票信息: {stock_code}") return True except Exception as e: self.session.rollback() logger.error(f"更新股票信息失败: {e}") return False def delete_stock(self, stock_code: str) -> bool: """ 删除股票信息 Args: stock_code: 股票代码 Returns: 删除是否成功 """ try: # 查询股票 stock = self.session.query(AStock).filter(AStock.stock_code == stock_code).first() if not stock: logger.warning(f"股票代码 {stock_code} 不存在") return False # 删除股票 self.session.delete(stock) self.session.commit() logger.info(f"成功删除股票信息: {stock_code}") return True except Exception as e: self.session.rollback() logger.error(f"删除股票信息失败: {e}") return False def get_stock_by_code(self, stock_code: str) -> Optional[Dict[str, Any]]: """ 通过股票代码获取股票信息 Args: stock_code: 股票代码 Returns: 股票信息,如果不存在则返回None """ try: # 查询股票 stock = self.session.query(AStock).filter(AStock.stock_code == stock_code).first() if stock: return { 'stock_code': stock.stock_code, 'short_name': stock.short_name, 'exchange': stock.exchange, 'list_date': stock.list_date, 'created_at': stock.created_at } else: return None except Exception as e: logger.error(f"获取股票信息失败: {e}") return None def search_stock(self, key: str, limit: int = 10) -> List[Dict[str, Any]]: """ 搜索股票 Args: key: 搜索关键词 limit: 最大返回数量 Returns: 股票信息列表 """ try: # 查询股票 stocks = self.session.query(AStock).filter( AStock.short_name.like(f"{key}%") | AStock.stock_code.like(f"{key}%") ).limit(limit).all() return [{ 'stock_code': stock.stock_code, 'short_name': stock.short_name, 'exchange': stock.exchange, 'list_date': stock.list_date, 'created_at': stock.created_at } for stock in stocks] except Exception as e: logger.error(f"获取股票信息失败: {e}") return [] def get_stock_by_name(self, short_name: str) -> List[Dict[str, Any]]: """ 通过股票简称获取股票信息(可能有多个结果) Args: short_name: 股票简称 Returns: 股票信息列表 """ try: # 查询股票 stocks = self.session.query(AStock).filter(AStock.short_name == short_name).all() return [{ 'stock_code': stock.stock_code, 'short_name': stock.short_name, 'exchange': stock.exchange, 'list_date': stock.list_date, 'created_at': stock.created_at } for stock in stocks] except Exception as e: logger.error(f"获取股票信息失败: {e}") return [] def list_stocks(self, limit: int = 100, skip: int = 0) -> List[Dict[str, Any]]: """ 获取股票列表 Args: limit: 返回的最大数量 skip: 跳过的数量 Returns: 股票信息列表 """ try: # 查询股票列表 stocks = self.session.query(AStock).order_by(AStock.stock_code).offset(skip).limit(limit).all() return [{ 'stock_code': stock.stock_code, 'short_name': stock.short_name, 'exchange': stock.exchange, 'list_date': stock.list_date, 'created_at': stock.created_at } for stock in stocks] except Exception as e: logger.error(f"获取股票列表失败: {e}") return []