#!/usr/bin/env python # -*- coding: utf-8 -*- from typing import Dict, Any, List, Optional, Union from datetime import datetime from sqlalchemy import Column, Integer, String, DateTime, Index from sqlalchemy.orm import Session from cryptoai.models.base import Base, logger from cryptoai.utils.db_utils import convert_timestamp_to_datetime from cryptoai.utils.db_manager import get_db_context # 定义 A 股数据模型 class AStock(Base): """A股股票基本信息表模型""" __tablename__ = 'astock' stock_code = Column(String(10), primary_key=True, comment='股票代码') short_name = Column(String(50), nullable=False, comment='股票简称') exchange = Column(String(20), nullable=True, comment='交易所') list_date = Column(DateTime, nullable=True, comment='上市日期') created_at = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间') # 索引 __table_args__ = ( Index('idx_stock_code', 'stock_code', unique=True), {'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'} ) class AStockManager: """A股管理类""" def __init__(self, session: Session = None): self.session = session def create_stocks(self, stocks: List[Dict[str, Any]]) -> bool: """ 批量创建股票信息 Args: stocks: 股票信息列表,每个元素为包含以下键的字典: - stock_code: 股票代码 - short_name: 股票简称 - exchange: 交易所(可选) - list_date: 上市日期(可选,Unix时间戳,毫秒级) Returns: 创建是否成功 """ try: for stock_data in stocks: # 检查股票代码是否已存在 existing_stock = self.session.query(AStock).filter( AStock.stock_code == stock_data['stock_code'] ).first() if existing_stock: logger.warning(f"股票代码 {stock_data['stock_code']} 已存在,跳过") continue # 创建新股票记录 new_stock = AStock( stock_code=stock_data['stock_code'], short_name=stock_data['short_name'], exchange=stock_data.get('exchange') ) # 处理上市日期 if 'list_date' in stock_data and stock_data['list_date']: list_date = convert_timestamp_to_datetime(stock_data['list_date']) if list_date: new_stock.list_date = list_date self.session.add(new_stock) # 批量提交 self.session.commit() logger.info(f"成功批量创建股票信息,共 {len(stocks)} 条记录") return True except Exception as e: self.session.rollback() logger.error(f"批量创建股票信息失败: {e}") return False def create_stock(self, stock_code: str, short_name: str, exchange: Optional[str] = None, list_date: Optional[Union[int, float, str]] = None) -> bool: """ 创建新的股票信息 Args: stock_code: 股票代码 short_name: 股票简称 exchange: 交易所(可选) list_date: 上市日期(可选,Unix时间戳,毫秒级) Returns: 创建是否成功 """ try: # 检查股票代码是否已存在 existing_stock = self.session.query(AStock).filter(AStock.stock_code == stock_code).first() if existing_stock: logger.warning(f"股票代码 {stock_code} 已存在") return False # 创建新股票记录 new_stock = AStock( stock_code=stock_code, short_name=short_name, exchange=exchange ) # 处理上市日期 if list_date: converted_date = convert_timestamp_to_datetime(list_date) if converted_date: new_stock.list_date = converted_date self.session.add(new_stock) self.session.commit() logger.info(f"成功创建股票信息: {stock_code} - {short_name}") return True except Exception as e: self.session.rollback() logger.error(f"创建股票信息失败: {e}") return False def update_stock(self, stock_code: str, short_name: Optional[str] = None, exchange: Optional[str] = None, list_date: Optional[Union[int, float, str, datetime]] = None) -> bool: """ 更新股票信息 Args: stock_code: 股票代码 short_name: 股票简称 exchange: 交易所 list_date: 上市日期(可以是datetime对象或时间戳) Returns: 更新是否成功 """ try: # 查询股票 stock = self.session.query(AStock).filter(AStock.stock_code == stock_code).first() if not stock: logger.warning(f"股票代码 {stock_code} 不存在") return False # 更新字段 if short_name is not None: stock.short_name = short_name if exchange is not None: stock.exchange = exchange if list_date is not None: if isinstance(list_date, datetime): stock.list_date = list_date else: converted_date = convert_timestamp_to_datetime(list_date) if converted_date: stock.list_date = converted_date self.session.commit() logger.info(f"成功更新股票信息: {stock_code}") return True except Exception as e: self.session.rollback() logger.error(f"更新股票信息失败: {e}") return False def delete_stock(self, stock_code: str) -> bool: """ 删除股票信息 Args: stock_code: 股票代码 Returns: 删除是否成功 """ try: # 查询股票 stock = self.session.query(AStock).filter(AStock.stock_code == stock_code).first() if not stock: logger.warning(f"股票代码 {stock_code} 不存在") return False # 删除股票 self.session.delete(stock) self.session.commit() logger.info(f"成功删除股票信息: {stock_code}") return True except Exception as e: self.session.rollback() logger.error(f"删除股票信息失败: {e}") return False def get_stock_by_code(self, stock_code: str) -> Optional[Dict[str, Any]]: """ 通过股票代码获取股票信息 Args: stock_code: 股票代码 Returns: 股票信息,如果不存在则返回None """ try: # 查询股票 stock = self.session.query(AStock).filter(AStock.stock_code == stock_code).first() if stock: return { 'stock_code': stock.stock_code, 'short_name': stock.short_name, 'exchange': stock.exchange, 'list_date': stock.list_date, 'created_at': stock.created_at } else: return None except Exception as e: logger.error(f"获取股票信息失败: {e}") return None def search_stock(self, key: str, limit: int = 10) -> List[Dict[str, Any]]: """ 搜索股票 Args: key: 搜索关键词 limit: 最大返回数量 Returns: 股票信息列表 """ try: # 查询股票 stocks = self.session.query(AStock).filter( AStock.short_name.like(f"{key}%") | AStock.stock_code.like(f"{key}%") ).limit(limit).all() return [{ 'stock_code': stock.stock_code, 'short_name': stock.short_name, 'exchange': stock.exchange, 'list_date': stock.list_date, 'created_at': stock.created_at } for stock in stocks] except Exception as e: logger.error(f"获取股票信息失败: {e}") return [] def get_stock_by_name(self, short_name: str) -> List[Dict[str, Any]]: """ 通过股票简称获取股票信息(可能有多个结果) Args: short_name: 股票简称 Returns: 股票信息列表 """ try: # 查询股票 stocks = self.session.query(AStock).filter(AStock.short_name == short_name).all() return [{ 'stock_code': stock.stock_code, 'short_name': stock.short_name, 'exchange': stock.exchange, 'list_date': stock.list_date, 'created_at': stock.created_at } for stock in stocks] except Exception as e: logger.error(f"获取股票信息失败: {e}") return [] def list_stocks(self, limit: int = 100, skip: int = 0) -> List[Dict[str, Any]]: """ 获取股票列表 Args: limit: 返回的最大数量 skip: 跳过的数量 Returns: 股票信息列表 """ try: # 查询股票列表 stocks = self.session.query(AStock).order_by(AStock.stock_code).offset(skip).limit(limit).all() return [{ 'stock_code': stock.stock_code, 'short_name': stock.short_name, 'exchange': stock.exchange, 'list_date': stock.list_date, 'created_at': stock.created_at } for stock in stocks] except Exception as e: logger.error(f"获取股票列表失败: {e}") return []