This commit is contained in:
aaron 2025-05-13 23:50:54 +08:00
parent 12ab5f0ecc
commit 95d9548bb9
5 changed files with 568 additions and 52 deletions

View File

@ -129,30 +129,42 @@ if __name__ == "__main__":
print("开始获取A股数据")
api = AStockAPI()
stock_codes = api.get_all_stock_codes()
list = stock_codes.to_json(orient="records")
print(list)
# 保存到数据库
import cryptoai.utils.db_manager as db_manager
db_manager.get_db_manager().create_stocks(list)
# 获取所有股票代码
stock_codes = ["688552", "688648", "688165","688552"]
print(f"获取到 {len(stock_codes)} 个股票代码")
# 获取指定股票的市场数据
if stock_codes:
sample_codes = stock_codes[:3] # 取前3个股票作为示例
print(f"获取到 {len(sample_codes)} 个股票代码")
today = datetime.now().strftime("%Y-%m-%d")
week_ago = (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d")
# # 获取所有股票代码
# stock_codes = ["688552", "688648", "688165","688552"]
# print(f"获取到 {len(stock_codes)} 个股票代码")
# 获取日线数据
market_data = api.get_market_data(stock_code=sample_codes[0],
start_date=week_ago,
end_date=today)
print("\n日线数据示例:")
print(market_data.head())
# # 获取指定股票的市场数据
# if stock_codes:
# sample_codes = stock_codes[:3] # 取前3个股票作为示例
# print(f"获取到 {len(sample_codes)} 个股票代码")
# today = datetime.now().strftime("%Y-%m-%d")
# week_ago = (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d")
# 获取分钟线数据
min_data = api.get_market_min_data(stock_code=sample_codes[0])
print("\n分钟线数据示例:")
print(min_data.head())
# # 获取日线数据
# market_data = api.get_market_data(stock_code=sample_codes[0],
# start_date=week_ago,
# end_date=today)
# print("\n日线数据示例:")
# print(market_data.head())
# 获取资金流向数据
flow_data = api.get_capital_flow_min(stock_code=sample_codes[0])
print("\n资金流向数据示例:")
print(flow_data.head())
# # 获取分钟线数据
# min_data = api.get_market_min_data(stock_code=sample_codes[0])
# print("\n分钟线数据示例:")
# print(min_data.head())
# # 获取资金流向数据
# flow_data = api.get_capital_flow_min(stock_code=sample_codes[0])
# print("\n资金流向数据示例:")
# print(flow_data.head())

View File

@ -1,26 +1,84 @@
import json
import logging
from fastapi import APIRouter, Depends, HTTPException, status, Body, Query, Path
from cryptoai.api.adata_api import AStockAPI
from cryptoai.utils.db_manager import get_db_manager
# 创建路由
router = APIRouter()
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@router.get("/stock/data")
async def get_stock_data(stock_code: str):
@router.get("/stock/search")
async def search_stock(key: str, limit: int = 10):
manager = get_db_manager()
result = manager.search_stock(key, limit)
return result
@router.get("/stock/base", summary="获取股票基础信息")
async def get_stock_base(stock_code: str):
api = AStockAPI()
result = {}
try:
# 获取股本信息
stock_shares = api.get_stock_shares(stock_code)
result["stock_shares"] = json.loads(stock_shares.to_json(orient="records"))
# 获取概念板块
concept_east = api.get_concept_east(stock_code)
result["concept_east"] = json.loads(concept_east.to_json(orient="records"))
# 获取板块
plate_east = api.get_plate_east(stock_code)
result["plate_east"] = json.loads(plate_east.to_json(orient="records"))
except Exception as e:
logger.error(f"获取股票基础信息失败: {e}")
return {}
return result
@router.get("/stock/data", summary="获取股票数据")
async def get_stock_data(stock_code: str):
api = AStockAPI()
result = {}
try:
# 获取市场数据
market_data = api.get_market_data(stock_code)
result["market_data"] = json.loads(market_data.to_json(orient="records"))
# 获取资金流向数据
flow_data = api.get_capital_flow(stock_code)
result["flow_data"] = json.loads(flow_data.to_json(orient="records"))
except Exception as e:
logger.error(f"获取股票数据失败: {e}")
return {}
@router.get("/stock/data/all", summary="获取所有股票数据")
async def get_stock_data_all(stock_code: str):
result = {}
try:
api = AStockAPI()
#获取股本信息
# stock_shares = api.get_stock_shares(stock_code)
# result["stock_shares"] = json.loads(stock_shares.to_json(orient="records"))
stock_shares = api.get_stock_shares(stock_code)
result["stock_shares"] = json.loads(stock_shares.to_json(orient="records"))
# # 获取概念板块
# concept_east = api.get_concept_east(stock_code)
# result["concept_east"] = json.loads(concept_east.to_json(orient="records"))
# 获取概念板块
concept_east = api.get_concept_east(stock_code)
result["concept_east"] = json.loads(concept_east.to_json(orient="records"))
# 获取板块
plate_east = api.get_plate_east(stock_code)
@ -30,12 +88,16 @@ async def get_stock_data(stock_code: str):
market_data = api.get_market_data(stock_code)
result["market_data"] = json.loads(market_data.to_json(orient="records"))
# # 获取分钟线数据
# min_data = api.get_market_min_data(stock_code)
# result["min_data"] = json.loads(min_data.to_json(orient="records"))
# 获取分钟线数据
min_data = api.get_market_min_data(stock_code)
result["min_data"] = json.loads(min_data.to_json(orient="records"))
# 获取资金流向数据
flow_data = api.get_capital_flow(stock_code)
result["flow_data"] = json.loads(flow_data.to_json(orient="records"))
except Exception as e:
logger.error(f"获取股票数据失败: {e}")
return {}
return result

View File

@ -131,6 +131,50 @@ class Agent(Base):
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
)
# 定义 A 股数据模型
class AStock(Base):
"""A股股票基本信息表模型"""
__tablename__ = 'astock'
stock_code = Column(String(10), primary_key=True, comment='股票代码')
short_name = Column(String(50), nullable=False, comment='股票简称')
exchange = Column(String(20), nullable=True, comment='交易所')
list_date = Column(DateTime, nullable=True, comment='上市日期')
created_at = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
# 索引
__table_args__ = (
Index('idx_stock_code', 'stock_code', unique=True),
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
)
def _convert_timestamp_to_datetime(timestamp: Union[int, float, str, None]) -> Optional[datetime]:
"""
将时间戳转换为datetime对象
Args:
timestamp: Unix时间戳毫秒级
Returns:
datetime对象或None
"""
if timestamp is None:
return None
try:
# 转换为整数
if isinstance(timestamp, str):
timestamp = int(timestamp)
# 如果是毫秒级时间戳,转换为秒级
if timestamp > 1e11: # 判断是否为毫秒级时间戳
timestamp = timestamp / 1000
return datetime.fromtimestamp(timestamp)
except (ValueError, TypeError) as e:
logger.error(f"时间戳转换失败: {timestamp}, 错误: {e}")
return None
class DBManager:
"""数据库管理工具用于连接MySQL数据库并保存智能体分析结果"""
@ -1216,6 +1260,395 @@ class DBManager:
logger.error(f"创建数据库会话失败: {e}")
return False
def create_stocks(self, stocks: List[Dict[str, Any]]) -> bool:
"""
批量创建股票信息
Args:
stocks: 股票信息列表每个元素为包含以下键的字典
- stock_code: 股票代码
- short_name: 股票简称
- exchange: 交易所可选
- list_date: 上市日期可选Unix时间戳毫秒级
Returns:
创建是否成功
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return False
try:
session = self.Session()
try:
for stock_data in stocks:
# 检查股票代码是否已存在
existing_stock = session.query(AStock).filter(
AStock.stock_code == stock_data['stock_code']
).first()
if existing_stock:
logger.warning(f"股票代码 {stock_data['stock_code']} 已存在,跳过")
continue
# 创建新股票记录
new_stock = AStock(
stock_code=stock_data['stock_code'],
short_name=stock_data['short_name'],
exchange=stock_data.get('exchange')
)
# 处理上市日期
if 'list_date' in stock_data and stock_data['list_date']:
list_date = _convert_timestamp_to_datetime(stock_data['list_date'])
if list_date:
new_stock.list_date = list_date
session.add(new_stock)
# 批量提交
session.commit()
logger.info(f"成功批量创建股票信息,共 {len(stocks)} 条记录")
return True
except Exception as e:
session.rollback()
logger.error(f"批量创建股票信息失败: {e}")
return False
finally:
session.close()
except Exception as e:
logger.error(f"创建数据库会话失败: {e}")
return False
def create_stock(self, stock_code: str, short_name: str,
exchange: Optional[str] = None,
list_date: Optional[Union[int, float, str]] = None) -> bool:
"""
创建新的股票信息
Args:
stock_code: 股票代码
short_name: 股票简称
exchange: 交易所可选
list_date: 上市日期可选Unix时间戳毫秒级
Returns:
创建是否成功
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return False
try:
session = self.Session()
try:
# 检查股票代码是否已存在
existing_stock = session.query(AStock).filter(AStock.stock_code == stock_code).first()
if existing_stock:
logger.warning(f"股票代码 {stock_code} 已存在")
return False
# 创建新股票记录
new_stock = AStock(
stock_code=stock_code,
short_name=short_name,
exchange=exchange
)
# 处理上市日期
if list_date:
converted_date = _convert_timestamp_to_datetime(list_date)
if converted_date:
new_stock.list_date = converted_date
session.add(new_stock)
session.commit()
logger.info(f"成功创建股票信息: {stock_code} - {short_name}")
return True
except Exception as e:
session.rollback()
logger.error(f"创建股票信息失败: {e}")
return False
finally:
session.close()
except Exception as e:
logger.error(f"创建数据库会话失败: {e}")
return False
def update_stock(self, stock_code: str, short_name: Optional[str] = None,
exchange: Optional[str] = None, list_date: Optional[datetime] = None) -> bool:
"""
更新股票信息
Args:
stock_code: 股票代码
short_name: 股票简称
exchange: 交易所
list_date: 上市日期
Returns:
更新是否成功
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return False
try:
session = self.Session()
try:
# 查询股票
stock = session.query(AStock).filter(AStock.stock_code == stock_code).first()
if not stock:
logger.warning(f"股票代码 {stock_code} 不存在")
return False
# 更新字段
if short_name is not None:
stock.short_name = short_name
if exchange is not None:
stock.exchange = exchange
if list_date is not None:
stock.list_date = list_date
stock.updated_at = datetime.now()
session.commit()
logger.info(f"成功更新股票信息: {stock_code}")
return True
except Exception as e:
session.rollback()
logger.error(f"更新股票信息失败: {e}")
return False
finally:
session.close()
except Exception as e:
logger.error(f"创建数据库会话失败: {e}")
return False
def delete_stock(self, stock_code: str) -> bool:
"""
删除股票信息
Args:
stock_code: 股票代码
Returns:
删除是否成功
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return False
try:
session = self.Session()
try:
# 查询股票
stock = session.query(AStock).filter(AStock.stock_code == stock_code).first()
if not stock:
logger.warning(f"股票代码 {stock_code} 不存在")
return False
# 删除股票
session.delete(stock)
session.commit()
logger.info(f"成功删除股票信息: {stock_code}")
return True
except Exception as e:
session.rollback()
logger.error(f"删除股票信息失败: {e}")
return False
finally:
session.close()
except Exception as e:
logger.error(f"创建数据库会话失败: {e}")
return False
def get_stock_by_code(self, stock_code: str) -> Optional[Dict[str, Any]]:
"""
通过股票代码获取股票信息
Args:
stock_code: 股票代码
Returns:
股票信息如果不存在则返回None
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return None
try:
session = self.Session()
try:
# 查询股票
stock = session.query(AStock).filter(AStock.stock_code == stock_code).first()
if stock:
return {
'stock_code': stock.stock_code,
'short_name': stock.short_name,
'exchange': stock.exchange,
'list_date': stock.list_date,
'created_at': stock.created_at,
'updated_at': stock.updated_at
}
else:
return None
finally:
session.close()
except Exception as e:
logger.error(f"获取股票信息失败: {e}")
return None
def search_stock(self, key: str, limit: int = 10) -> List[Dict[str, Any]]:
"""
搜索股票
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return []
try:
session = self.Session()
try:
# 查询股票
stocks = session.query(AStock).filter(AStock.short_name.like(f"{key}%") | AStock.stock_code.like(f"{key}%")).limit(limit).all()
return [{
'stock_code': stock.stock_code,
'short_name': stock.short_name,
'exchange': stock.exchange,
'list_date': stock.list_date,
'created_at': stock.created_at
} for stock in stocks]
finally:
session.close()
except Exception as e:
logger.error(f"获取股票信息失败: {e}")
return []
def get_stock_by_name(self, short_name: str) -> List[Dict[str, Any]]:
"""
通过股票简称获取股票信息可能有多个结果
Args:
short_name: 股票简称
Returns:
股票信息列表
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return []
try:
session = self.Session()
try:
# 查询股票
stocks = session.query(AStock).filter(AStock.short_name == short_name).all()
return [{
'stock_code': stock.stock_code,
'short_name': stock.short_name,
'exchange': stock.exchange,
'list_date': stock.list_date,
'created_at': stock.created_at
} for stock in stocks]
finally:
session.close()
except Exception as e:
logger.error(f"获取股票信息失败: {e}")
return []
def list_stocks(self, limit: int = 100, skip: int = 0) -> List[Dict[str, Any]]:
"""
获取股票列表
Args:
limit: 返回的最大数量
skip: 跳过的数量
Returns:
股票信息列表
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return []
try:
session = self.Session()
try:
# 查询股票列表
stocks = session.query(AStock).order_by(AStock.stock_code).offset(skip).limit(limit).all()
return [{
'stock_code': stock.stock_code,
'short_name': stock.short_name,
'exchange': stock.exchange,
'list_date': stock.list_date,
'created_at': stock.created_at
} for stock in stocks]
finally:
session.close()
except Exception as e:
logger.error(f"获取股票列表失败: {e}")
return []
# 单例模式
_db_instance = None

View File

@ -29,7 +29,7 @@ services:
cryptoai-api:
build: .
container_name: cryptoai-api
image: cryptoai-api:0.0.18
image: cryptoai-api:0.0.19
restart: always
ports:
- "8000:8000"

23
test.py
View File

@ -1,8 +1,17 @@
from cryptoai.api.binance_api import BinanceAPI
from cryptoai.utils.config_loader import ConfigLoader
from cryptoai.utils.db_manager import get_db_manager
from cryptoai.api.adata_api import AStockAPI
import json
from time import sleep
if __name__ == "__main__":
print("开始获取A股数据")
api = AStockAPI()
stock_codes = api.get_all_stock_codes()
list = json.loads(stock_codes.to_json(orient="records"))
# print(list[0])
config = ConfigLoader().get_binance_config()
binance_api = BinanceAPI(config['api_key'], config['api_secret'])
print(binance_api.get_top_longshort_position_ratio("BTCUSDT", "1h"))
# 保存到数据库
for stock in list:
print(f"创建股票: {stock['stock_code']} - {stock['short_name']}")
get_db_manager().create_stock(stock["stock_code"], stock["short_name"], stock["exchange"], stock["list_date"])
print(f"创建股票: {stock['stock_code']} - {stock['short_name']} 完成")
sleep(1)