stock-ai-agent/backend/app/astock_agent/akshare_client.py
2026-02-27 09:54:17 +08:00

235 lines
6.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Akshare 数据封装
提供 A 股板块、个股行情数据获取接口
支持概念板块和行业板块
"""
import os
import time
import akshare as ak
import pandas as pd
from typing import Dict, List, Optional
from datetime import datetime, timedelta
from app.utils.logger import logger
# 禁用全局代理设置
os.environ.pop('HTTP_PROXY', None)
os.environ.pop('HTTPS_PROXY', None)
os.environ.pop('http_proxy', None)
os.environ.pop('https_proxy', None)
# Monkey patch requests 以禁用代理
import requests
_original_session_init = requests.Session.__init__
def _patched_session_init(self, *args, **kwargs):
_original_session_init(self, *args, **kwargs)
self.trust_env = False
self.proxies = {'http': None, 'https': None}
requests.Session.__init__ = _patched_session_init
class AkshareClient:
"""Akshare 数据客户端"""
# 缓存数据,避免频繁请求
_cache = {}
_cache_time = {}
_last_request_time = 0
def __init__(self):
"""初始化客户端"""
self.cache_ttl = 60 # 缓存60秒
self.request_delay = 1.0 # 请求间隔(秒)
self.max_retries = 3 # 最大重试次数
def _get_cached(self, key: str, fetch_func) -> pd.DataFrame:
"""获取缓存数据,支持重试"""
now = datetime.now()
# 检查缓存
if key in self._cache:
cache_time = self._cache_time.get(key)
if cache_time and (now - cache_time).seconds < self.cache_ttl:
logger.debug(f"使用缓存数据: {key}")
return self._cache[key]
# 请求限流
elapsed = now.timestamp() - self._last_request_time
if elapsed < self.request_delay:
time.sleep(self.request_delay - elapsed)
# 重试逻辑
last_error = None
for attempt in range(self.max_retries):
try:
self._last_request_time = time.time()
df = fetch_func()
if df is not None and not df.empty:
self._cache[key] = df
self._cache_time[key] = now
logger.debug(f"获取数据成功: {key}")
return df
except Exception as e:
last_error = e
error_msg = str(e)
# 判断错误类型
if 'Connection' in error_msg or 'RemoteDisconnected' in error_msg:
# 连接错误,指数退避重试
if attempt < self.max_retries - 1:
wait_time = (2 ** attempt) * 2 # 2, 4, 8秒
logger.warning(
f"获取数据失败 {key} (尝试 {attempt + 1}/{self.max_retries}): {e}"
f"等待 {wait_time}秒后重试..."
)
time.sleep(wait_time)
continue
# 其他错误或重试次数用尽
logger.error(f"获取数据失败 {key}: {e}")
break
return pd.DataFrame()
def get_concept_spot(self) -> pd.DataFrame:
"""
获取概念板块行情(实时)
Returns:
概念板块行情数据
"""
def fetch():
# stock_board_concept_name_em - 东方财富概念板块行情
return ak.stock_board_concept_name_em()
return self._get_cached('concept_spot', fetch)
def get_industry_spot(self) -> pd.DataFrame:
"""
获取行业板块行情(实时)
Returns:
行业板块行情数据
"""
def fetch():
# stock_board_industry_name_em - 东方财富行业板块行情
return ak.stock_board_industry_name_em()
return self._get_cached('industry_spot', fetch)
def get_concept_stocks(self, sector_name: str) -> pd.DataFrame:
"""
获取概念板块成分股
Args:
sector_name: 板块名称
Returns:
成分股数据
"""
def fetch():
# stock_board_concept_cons_em - 概念板块成分股
df = ak.stock_board_concept_cons_em(symbol=sector_name)
return df if df is not None else pd.DataFrame()
return self._get_cached(f'concept_stocks_{sector_name}', fetch)
def get_industry_stocks(self, sector_name: str) -> pd.DataFrame:
"""
获取行业板块成分股
Args:
sector_name: 板块名称
Returns:
成分股数据
"""
def fetch():
# stock_board_industry_cons_em - 行业板块成分股
df = ak.stock_board_industry_cons_em(symbol=sector_name)
return df if df is not None else pd.DataFrame()
return self._get_cached(f'industry_stocks_{sector_name}', fetch)
def get_stock_spot(self) -> pd.DataFrame:
"""
获取 A 股实时行情
Returns:
A 股实时行情数据
"""
def fetch():
return ak.stock_zh_a_spot_em()
return self._get_cached('stock_spot', fetch)
def get_stock_fund_flow(self, symbol: str) -> pd.DataFrame:
"""
获取个股资金流向
Args:
symbol: 股票代码
Returns:
资金流向数据
"""
def fetch():
return ak.stock_individual_fund_flow(
stock=symbol,
market="sh" if symbol.startswith('6') else "sz"
)
return self._get_cached(f'fund_flow_{symbol}', fetch)
def get_stock_info(self, symbol: str) -> Dict:
"""
获取个股基本信息
Args:
symbol: 股票代码
Returns:
股票信息字典
"""
try:
info = ak.stock_individual_info_em(symbol=symbol)
return {
'name': info.get('股票简称', ''),
'industry': info.get('行业', ''),
'market_cap': info.get('总市值', ''),
'float_cap': info.get('流通市值', ''),
}
except Exception as e:
logger.error(f"获取股票信息失败 {symbol}: {e}")
return {}
def get_limit_list_stocks(self) -> pd.DataFrame:
"""
获取涨停板股票
Returns:
涨停板股票列表
"""
def fetch():
return ak.stock_zt_pool_em(date=datetime.now().strftime('%Y%m%d'))
return self._get_cached('limit_list', fetch)
# 全局单例
_akshare_client: Optional[AkshareClient] = None
def get_akshare_client() -> AkshareClient:
"""获取 Akshare 客户端单例"""
global _akshare_client
if _akshare_client is None:
_akshare_client = AkshareClient()
return _akshare_client