315 lines
11 KiB
Python
315 lines
11 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
|
||
from typing import Dict, Any, List, Optional, Union
|
||
from datetime import datetime
|
||
|
||
from sqlalchemy import Column, Integer, String, DateTime, Index
|
||
from sqlalchemy.orm import Session
|
||
from cryptoai.models.base import Base, logger
|
||
from cryptoai.utils.db_utils import convert_timestamp_to_datetime
|
||
from cryptoai.utils.db_manager import get_db_context
|
||
|
||
# 定义 A 股数据模型
|
||
class AStock(Base):
|
||
"""A股股票基本信息表模型"""
|
||
__tablename__ = 'astock'
|
||
|
||
stock_code = Column(String(10), primary_key=True, comment='股票代码')
|
||
short_name = Column(String(50), nullable=False, comment='股票简称')
|
||
exchange = Column(String(20), nullable=True, comment='交易所')
|
||
list_date = Column(DateTime, nullable=True, comment='上市日期')
|
||
created_at = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
|
||
|
||
# 索引
|
||
__table_args__ = (
|
||
Index('idx_stock_code', 'stock_code', unique=True),
|
||
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
|
||
)
|
||
|
||
class AStockManager:
|
||
"""A股管理类"""
|
||
|
||
def __init__(self, session: Session = None):
|
||
self.session = session
|
||
|
||
def create_stocks(self, stocks: List[Dict[str, Any]]) -> bool:
|
||
"""
|
||
批量创建股票信息
|
||
|
||
Args:
|
||
stocks: 股票信息列表,每个元素为包含以下键的字典:
|
||
- stock_code: 股票代码
|
||
- short_name: 股票简称
|
||
- exchange: 交易所(可选)
|
||
- list_date: 上市日期(可选,Unix时间戳,毫秒级)
|
||
|
||
Returns:
|
||
创建是否成功
|
||
"""
|
||
try:
|
||
for stock_data in stocks:
|
||
# 检查股票代码是否已存在
|
||
existing_stock = self.session.query(AStock).filter(
|
||
AStock.stock_code == stock_data['stock_code']
|
||
).first()
|
||
|
||
if existing_stock:
|
||
logger.warning(f"股票代码 {stock_data['stock_code']} 已存在,跳过")
|
||
continue
|
||
|
||
# 创建新股票记录
|
||
new_stock = AStock(
|
||
stock_code=stock_data['stock_code'],
|
||
short_name=stock_data['short_name'],
|
||
exchange=stock_data.get('exchange')
|
||
)
|
||
|
||
# 处理上市日期
|
||
if 'list_date' in stock_data and stock_data['list_date']:
|
||
list_date = convert_timestamp_to_datetime(stock_data['list_date'])
|
||
if list_date:
|
||
new_stock.list_date = list_date
|
||
|
||
self.session.add(new_stock)
|
||
|
||
# 批量提交
|
||
self.session.commit()
|
||
logger.info(f"成功批量创建股票信息,共 {len(stocks)} 条记录")
|
||
return True
|
||
|
||
except Exception as e:
|
||
self.session.rollback()
|
||
logger.error(f"批量创建股票信息失败: {e}")
|
||
return False
|
||
|
||
def create_stock(self, stock_code: str, short_name: str,
|
||
exchange: Optional[str] = None,
|
||
list_date: Optional[Union[int, float, str]] = None) -> bool:
|
||
"""
|
||
创建新的股票信息
|
||
|
||
Args:
|
||
stock_code: 股票代码
|
||
short_name: 股票简称
|
||
exchange: 交易所(可选)
|
||
list_date: 上市日期(可选,Unix时间戳,毫秒级)
|
||
|
||
Returns:
|
||
创建是否成功
|
||
"""
|
||
try:
|
||
# 检查股票代码是否已存在
|
||
existing_stock = self.session.query(AStock).filter(AStock.stock_code == stock_code).first()
|
||
if existing_stock:
|
||
logger.warning(f"股票代码 {stock_code} 已存在")
|
||
return False
|
||
|
||
# 创建新股票记录
|
||
new_stock = AStock(
|
||
stock_code=stock_code,
|
||
short_name=short_name,
|
||
exchange=exchange
|
||
)
|
||
|
||
# 处理上市日期
|
||
if list_date:
|
||
converted_date = convert_timestamp_to_datetime(list_date)
|
||
if converted_date:
|
||
new_stock.list_date = converted_date
|
||
|
||
self.session.add(new_stock)
|
||
self.session.commit()
|
||
|
||
logger.info(f"成功创建股票信息: {stock_code} - {short_name}")
|
||
return True
|
||
|
||
except Exception as e:
|
||
self.session.rollback()
|
||
logger.error(f"创建股票信息失败: {e}")
|
||
return False
|
||
|
||
def update_stock(self, stock_code: str, short_name: Optional[str] = None,
|
||
exchange: Optional[str] = None,
|
||
list_date: Optional[Union[int, float, str, datetime]] = None) -> bool:
|
||
"""
|
||
更新股票信息
|
||
|
||
Args:
|
||
stock_code: 股票代码
|
||
short_name: 股票简称
|
||
exchange: 交易所
|
||
list_date: 上市日期(可以是datetime对象或时间戳)
|
||
|
||
Returns:
|
||
更新是否成功
|
||
"""
|
||
try:
|
||
# 查询股票
|
||
stock = self.session.query(AStock).filter(AStock.stock_code == stock_code).first()
|
||
if not stock:
|
||
logger.warning(f"股票代码 {stock_code} 不存在")
|
||
return False
|
||
|
||
# 更新字段
|
||
if short_name is not None:
|
||
stock.short_name = short_name
|
||
if exchange is not None:
|
||
stock.exchange = exchange
|
||
if list_date is not None:
|
||
if isinstance(list_date, datetime):
|
||
stock.list_date = list_date
|
||
else:
|
||
converted_date = convert_timestamp_to_datetime(list_date)
|
||
if converted_date:
|
||
stock.list_date = converted_date
|
||
|
||
self.session.commit()
|
||
|
||
logger.info(f"成功更新股票信息: {stock_code}")
|
||
return True
|
||
|
||
except Exception as e:
|
||
self.session.rollback()
|
||
logger.error(f"更新股票信息失败: {e}")
|
||
return False
|
||
|
||
def delete_stock(self, stock_code: str) -> bool:
|
||
"""
|
||
删除股票信息
|
||
|
||
Args:
|
||
stock_code: 股票代码
|
||
|
||
Returns:
|
||
删除是否成功
|
||
"""
|
||
try:
|
||
# 查询股票
|
||
stock = self.session.query(AStock).filter(AStock.stock_code == stock_code).first()
|
||
if not stock:
|
||
logger.warning(f"股票代码 {stock_code} 不存在")
|
||
return False
|
||
|
||
# 删除股票
|
||
self.session.delete(stock)
|
||
self.session.commit()
|
||
|
||
logger.info(f"成功删除股票信息: {stock_code}")
|
||
return True
|
||
|
||
except Exception as e:
|
||
self.session.rollback()
|
||
logger.error(f"删除股票信息失败: {e}")
|
||
return False
|
||
|
||
def get_stock_by_code(self, stock_code: str) -> Optional[Dict[str, Any]]:
|
||
"""
|
||
通过股票代码获取股票信息
|
||
|
||
Args:
|
||
stock_code: 股票代码
|
||
|
||
Returns:
|
||
股票信息,如果不存在则返回None
|
||
"""
|
||
try:
|
||
# 查询股票
|
||
stock = self.session.query(AStock).filter(AStock.stock_code == stock_code).first()
|
||
|
||
if stock:
|
||
return {
|
||
'stock_code': stock.stock_code,
|
||
'short_name': stock.short_name,
|
||
'exchange': stock.exchange,
|
||
'list_date': stock.list_date,
|
||
'created_at': stock.created_at
|
||
}
|
||
else:
|
||
return None
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取股票信息失败: {e}")
|
||
return None
|
||
|
||
def search_stock(self, key: str, limit: int = 10) -> List[Dict[str, Any]]:
|
||
"""
|
||
搜索股票
|
||
|
||
Args:
|
||
key: 搜索关键词
|
||
limit: 最大返回数量
|
||
|
||
Returns:
|
||
股票信息列表
|
||
"""
|
||
try:
|
||
# 查询股票
|
||
stocks = self.session.query(AStock).filter(
|
||
AStock.short_name.like(f"{key}%") | AStock.stock_code.like(f"{key}%")
|
||
).limit(limit).all()
|
||
|
||
return [{
|
||
'stock_code': stock.stock_code,
|
||
'short_name': stock.short_name,
|
||
'exchange': stock.exchange,
|
||
'list_date': stock.list_date,
|
||
'created_at': stock.created_at
|
||
} for stock in stocks]
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取股票信息失败: {e}")
|
||
return []
|
||
|
||
def get_stock_by_name(self, short_name: str) -> List[Dict[str, Any]]:
|
||
"""
|
||
通过股票简称获取股票信息(可能有多个结果)
|
||
|
||
Args:
|
||
short_name: 股票简称
|
||
|
||
Returns:
|
||
股票信息列表
|
||
"""
|
||
try:
|
||
# 查询股票
|
||
stocks = self.session.query(AStock).filter(AStock.short_name == short_name).all()
|
||
|
||
return [{
|
||
'stock_code': stock.stock_code,
|
||
'short_name': stock.short_name,
|
||
'exchange': stock.exchange,
|
||
'list_date': stock.list_date,
|
||
'created_at': stock.created_at
|
||
} for stock in stocks]
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取股票信息失败: {e}")
|
||
return []
|
||
|
||
def list_stocks(self, limit: int = 100, skip: int = 0) -> List[Dict[str, Any]]:
|
||
"""
|
||
获取股票列表
|
||
|
||
Args:
|
||
limit: 返回的最大数量
|
||
skip: 跳过的数量
|
||
|
||
Returns:
|
||
股票信息列表
|
||
"""
|
||
try:
|
||
# 查询股票列表
|
||
stocks = self.session.query(AStock).order_by(AStock.stock_code).offset(skip).limit(limit).all()
|
||
|
||
return [{
|
||
'stock_code': stock.stock_code,
|
||
'short_name': stock.short_name,
|
||
'exchange': stock.exchange,
|
||
'list_date': stock.list_date,
|
||
'created_at': stock.created_at
|
||
} for stock in stocks]
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取股票列表失败: {e}")
|
||
return [] |