crypto.ai/cryptoai/models/astock.py
2025-05-30 22:09:45 +08:00

315 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
from typing import Dict, Any, List, Optional, Union
from datetime import datetime
from sqlalchemy import Column, Integer, String, DateTime, Index
from sqlalchemy.orm import Session
from cryptoai.models.base import Base, logger
from cryptoai.utils.db_utils import convert_timestamp_to_datetime
from cryptoai.utils.db_manager import get_db_context
# 定义 A 股数据模型
class AStock(Base):
"""A股股票基本信息表模型"""
__tablename__ = 'astock'
stock_code = Column(String(10), primary_key=True, comment='股票代码')
short_name = Column(String(50), nullable=False, comment='股票简称')
exchange = Column(String(20), nullable=True, comment='交易所')
list_date = Column(DateTime, nullable=True, comment='上市日期')
created_at = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
# 索引
__table_args__ = (
Index('idx_stock_code', 'stock_code', unique=True),
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
)
class AStockManager:
"""A股管理类"""
def __init__(self, session: Session = None):
self.session = session
def create_stocks(self, stocks: List[Dict[str, Any]]) -> bool:
"""
批量创建股票信息
Args:
stocks: 股票信息列表,每个元素为包含以下键的字典:
- stock_code: 股票代码
- short_name: 股票简称
- exchange: 交易所(可选)
- list_date: 上市日期可选Unix时间戳毫秒级
Returns:
创建是否成功
"""
try:
for stock_data in stocks:
# 检查股票代码是否已存在
existing_stock = self.session.query(AStock).filter(
AStock.stock_code == stock_data['stock_code']
).first()
if existing_stock:
logger.warning(f"股票代码 {stock_data['stock_code']} 已存在,跳过")
continue
# 创建新股票记录
new_stock = AStock(
stock_code=stock_data['stock_code'],
short_name=stock_data['short_name'],
exchange=stock_data.get('exchange')
)
# 处理上市日期
if 'list_date' in stock_data and stock_data['list_date']:
list_date = convert_timestamp_to_datetime(stock_data['list_date'])
if list_date:
new_stock.list_date = list_date
self.session.add(new_stock)
# 批量提交
self.session.commit()
logger.info(f"成功批量创建股票信息,共 {len(stocks)} 条记录")
return True
except Exception as e:
self.session.rollback()
logger.error(f"批量创建股票信息失败: {e}")
return False
def create_stock(self, stock_code: str, short_name: str,
exchange: Optional[str] = None,
list_date: Optional[Union[int, float, str]] = None) -> bool:
"""
创建新的股票信息
Args:
stock_code: 股票代码
short_name: 股票简称
exchange: 交易所(可选)
list_date: 上市日期可选Unix时间戳毫秒级
Returns:
创建是否成功
"""
try:
# 检查股票代码是否已存在
existing_stock = self.session.query(AStock).filter(AStock.stock_code == stock_code).first()
if existing_stock:
logger.warning(f"股票代码 {stock_code} 已存在")
return False
# 创建新股票记录
new_stock = AStock(
stock_code=stock_code,
short_name=short_name,
exchange=exchange
)
# 处理上市日期
if list_date:
converted_date = convert_timestamp_to_datetime(list_date)
if converted_date:
new_stock.list_date = converted_date
self.session.add(new_stock)
self.session.commit()
logger.info(f"成功创建股票信息: {stock_code} - {short_name}")
return True
except Exception as e:
self.session.rollback()
logger.error(f"创建股票信息失败: {e}")
return False
def update_stock(self, stock_code: str, short_name: Optional[str] = None,
exchange: Optional[str] = None,
list_date: Optional[Union[int, float, str, datetime]] = None) -> bool:
"""
更新股票信息
Args:
stock_code: 股票代码
short_name: 股票简称
exchange: 交易所
list_date: 上市日期可以是datetime对象或时间戳
Returns:
更新是否成功
"""
try:
# 查询股票
stock = self.session.query(AStock).filter(AStock.stock_code == stock_code).first()
if not stock:
logger.warning(f"股票代码 {stock_code} 不存在")
return False
# 更新字段
if short_name is not None:
stock.short_name = short_name
if exchange is not None:
stock.exchange = exchange
if list_date is not None:
if isinstance(list_date, datetime):
stock.list_date = list_date
else:
converted_date = convert_timestamp_to_datetime(list_date)
if converted_date:
stock.list_date = converted_date
self.session.commit()
logger.info(f"成功更新股票信息: {stock_code}")
return True
except Exception as e:
self.session.rollback()
logger.error(f"更新股票信息失败: {e}")
return False
def delete_stock(self, stock_code: str) -> bool:
"""
删除股票信息
Args:
stock_code: 股票代码
Returns:
删除是否成功
"""
try:
# 查询股票
stock = self.session.query(AStock).filter(AStock.stock_code == stock_code).first()
if not stock:
logger.warning(f"股票代码 {stock_code} 不存在")
return False
# 删除股票
self.session.delete(stock)
self.session.commit()
logger.info(f"成功删除股票信息: {stock_code}")
return True
except Exception as e:
self.session.rollback()
logger.error(f"删除股票信息失败: {e}")
return False
def get_stock_by_code(self, stock_code: str) -> Optional[Dict[str, Any]]:
"""
通过股票代码获取股票信息
Args:
stock_code: 股票代码
Returns:
股票信息如果不存在则返回None
"""
try:
# 查询股票
stock = self.session.query(AStock).filter(AStock.stock_code == stock_code).first()
if stock:
return {
'stock_code': stock.stock_code,
'short_name': stock.short_name,
'exchange': stock.exchange,
'list_date': stock.list_date,
'created_at': stock.created_at
}
else:
return None
except Exception as e:
logger.error(f"获取股票信息失败: {e}")
return None
def search_stock(self, key: str, limit: int = 10) -> List[Dict[str, Any]]:
"""
搜索股票
Args:
key: 搜索关键词
limit: 最大返回数量
Returns:
股票信息列表
"""
try:
# 查询股票
stocks = self.session.query(AStock).filter(
AStock.short_name.like(f"{key}%") | AStock.stock_code.like(f"{key}%")
).limit(limit).all()
return [{
'stock_code': stock.stock_code,
'short_name': stock.short_name,
'exchange': stock.exchange,
'list_date': stock.list_date,
'created_at': stock.created_at
} for stock in stocks]
except Exception as e:
logger.error(f"获取股票信息失败: {e}")
return []
def get_stock_by_name(self, short_name: str) -> List[Dict[str, Any]]:
"""
通过股票简称获取股票信息(可能有多个结果)
Args:
short_name: 股票简称
Returns:
股票信息列表
"""
try:
# 查询股票
stocks = self.session.query(AStock).filter(AStock.short_name == short_name).all()
return [{
'stock_code': stock.stock_code,
'short_name': stock.short_name,
'exchange': stock.exchange,
'list_date': stock.list_date,
'created_at': stock.created_at
} for stock in stocks]
except Exception as e:
logger.error(f"获取股票信息失败: {e}")
return []
def list_stocks(self, limit: int = 100, skip: int = 0) -> List[Dict[str, Any]]:
"""
获取股票列表
Args:
limit: 返回的最大数量
skip: 跳过的数量
Returns:
股票信息列表
"""
try:
# 查询股票列表
stocks = self.session.query(AStock).order_by(AStock.stock_code).offset(skip).limit(limit).all()
return [{
'stock_code': stock.stock_code,
'short_name': stock.short_name,
'exchange': stock.exchange,
'list_date': stock.list_date,
'created_at': stock.created_at
} for stock in stocks]
except Exception as e:
logger.error(f"获取股票列表失败: {e}")
return []