235 lines
6.7 KiB
Python
235 lines
6.7 KiB
Python
"""
|
||
Akshare 数据封装
|
||
提供 A 股板块、个股行情数据获取接口
|
||
支持概念板块和行业板块
|
||
"""
|
||
import os
|
||
import time
|
||
import akshare as ak
|
||
import pandas as pd
|
||
from typing import Dict, List, Optional
|
||
from datetime import datetime, timedelta
|
||
from app.utils.logger import logger
|
||
|
||
|
||
# 禁用全局代理设置
|
||
os.environ.pop('HTTP_PROXY', None)
|
||
os.environ.pop('HTTPS_PROXY', None)
|
||
os.environ.pop('http_proxy', None)
|
||
os.environ.pop('https_proxy', None)
|
||
|
||
# Monkey patch requests 以禁用代理
|
||
import requests
|
||
_original_session_init = requests.Session.__init__
|
||
|
||
|
||
def _patched_session_init(self, *args, **kwargs):
|
||
_original_session_init(self, *args, **kwargs)
|
||
self.trust_env = False
|
||
self.proxies = {'http': None, 'https': None}
|
||
|
||
|
||
requests.Session.__init__ = _patched_session_init
|
||
|
||
|
||
class AkshareClient:
|
||
"""Akshare 数据客户端"""
|
||
|
||
# 缓存数据,避免频繁请求
|
||
_cache = {}
|
||
_cache_time = {}
|
||
_last_request_time = 0
|
||
|
||
def __init__(self):
|
||
"""初始化客户端"""
|
||
self.cache_ttl = 60 # 缓存60秒
|
||
self.request_delay = 1.0 # 请求间隔(秒)
|
||
self.max_retries = 3 # 最大重试次数
|
||
|
||
def _get_cached(self, key: str, fetch_func) -> pd.DataFrame:
|
||
"""获取缓存数据,支持重试"""
|
||
now = datetime.now()
|
||
|
||
# 检查缓存
|
||
if key in self._cache:
|
||
cache_time = self._cache_time.get(key)
|
||
if cache_time and (now - cache_time).seconds < self.cache_ttl:
|
||
logger.debug(f"使用缓存数据: {key}")
|
||
return self._cache[key]
|
||
|
||
# 请求限流
|
||
elapsed = now.timestamp() - self._last_request_time
|
||
if elapsed < self.request_delay:
|
||
time.sleep(self.request_delay - elapsed)
|
||
|
||
# 重试逻辑
|
||
last_error = None
|
||
for attempt in range(self.max_retries):
|
||
try:
|
||
self._last_request_time = time.time()
|
||
df = fetch_func()
|
||
|
||
if df is not None and not df.empty:
|
||
self._cache[key] = df
|
||
self._cache_time[key] = now
|
||
logger.debug(f"获取数据成功: {key}")
|
||
return df
|
||
|
||
except Exception as e:
|
||
last_error = e
|
||
error_msg = str(e)
|
||
|
||
# 判断错误类型
|
||
if 'Connection' in error_msg or 'RemoteDisconnected' in error_msg:
|
||
# 连接错误,指数退避重试
|
||
if attempt < self.max_retries - 1:
|
||
wait_time = (2 ** attempt) * 2 # 2, 4, 8秒
|
||
logger.warning(
|
||
f"获取数据失败 {key} (尝试 {attempt + 1}/{self.max_retries}): {e},"
|
||
f"等待 {wait_time}秒后重试..."
|
||
)
|
||
time.sleep(wait_time)
|
||
continue
|
||
|
||
# 其他错误或重试次数用尽
|
||
logger.error(f"获取数据失败 {key}: {e}")
|
||
break
|
||
|
||
return pd.DataFrame()
|
||
|
||
def get_concept_spot(self) -> pd.DataFrame:
|
||
"""
|
||
获取概念板块行情(实时)
|
||
|
||
Returns:
|
||
概念板块行情数据
|
||
"""
|
||
def fetch():
|
||
# stock_board_concept_name_em - 东方财富概念板块行情
|
||
return ak.stock_board_concept_name_em()
|
||
|
||
return self._get_cached('concept_spot', fetch)
|
||
|
||
def get_industry_spot(self) -> pd.DataFrame:
|
||
"""
|
||
获取行业板块行情(实时)
|
||
|
||
Returns:
|
||
行业板块行情数据
|
||
"""
|
||
def fetch():
|
||
# stock_board_industry_name_em - 东方财富行业板块行情
|
||
return ak.stock_board_industry_name_em()
|
||
|
||
return self._get_cached('industry_spot', fetch)
|
||
|
||
def get_concept_stocks(self, sector_name: str) -> pd.DataFrame:
|
||
"""
|
||
获取概念板块成分股
|
||
|
||
Args:
|
||
sector_name: 板块名称
|
||
|
||
Returns:
|
||
成分股数据
|
||
"""
|
||
def fetch():
|
||
# stock_board_concept_cons_em - 概念板块成分股
|
||
df = ak.stock_board_concept_cons_em(symbol=sector_name)
|
||
return df if df is not None else pd.DataFrame()
|
||
|
||
return self._get_cached(f'concept_stocks_{sector_name}', fetch)
|
||
|
||
def get_industry_stocks(self, sector_name: str) -> pd.DataFrame:
|
||
"""
|
||
获取行业板块成分股
|
||
|
||
Args:
|
||
sector_name: 板块名称
|
||
|
||
Returns:
|
||
成分股数据
|
||
"""
|
||
def fetch():
|
||
# stock_board_industry_cons_em - 行业板块成分股
|
||
df = ak.stock_board_industry_cons_em(symbol=sector_name)
|
||
return df if df is not None else pd.DataFrame()
|
||
|
||
return self._get_cached(f'industry_stocks_{sector_name}', fetch)
|
||
|
||
def get_stock_spot(self) -> pd.DataFrame:
|
||
"""
|
||
获取 A 股实时行情
|
||
|
||
Returns:
|
||
A 股实时行情数据
|
||
"""
|
||
def fetch():
|
||
return ak.stock_zh_a_spot_em()
|
||
|
||
return self._get_cached('stock_spot', fetch)
|
||
|
||
def get_stock_fund_flow(self, symbol: str) -> pd.DataFrame:
|
||
"""
|
||
获取个股资金流向
|
||
|
||
Args:
|
||
symbol: 股票代码
|
||
|
||
Returns:
|
||
资金流向数据
|
||
"""
|
||
def fetch():
|
||
return ak.stock_individual_fund_flow(
|
||
stock=symbol,
|
||
market="sh" if symbol.startswith('6') else "sz"
|
||
)
|
||
|
||
return self._get_cached(f'fund_flow_{symbol}', fetch)
|
||
|
||
def get_stock_info(self, symbol: str) -> Dict:
|
||
"""
|
||
获取个股基本信息
|
||
|
||
Args:
|
||
symbol: 股票代码
|
||
|
||
Returns:
|
||
股票信息字典
|
||
"""
|
||
try:
|
||
info = ak.stock_individual_info_em(symbol=symbol)
|
||
return {
|
||
'name': info.get('股票简称', ''),
|
||
'industry': info.get('行业', ''),
|
||
'market_cap': info.get('总市值', ''),
|
||
'float_cap': info.get('流通市值', ''),
|
||
}
|
||
except Exception as e:
|
||
logger.error(f"获取股票信息失败 {symbol}: {e}")
|
||
return {}
|
||
|
||
def get_limit_list_stocks(self) -> pd.DataFrame:
|
||
"""
|
||
获取涨停板股票
|
||
|
||
Returns:
|
||
涨停板股票列表
|
||
"""
|
||
def fetch():
|
||
return ak.stock_zt_pool_em(date=datetime.now().strftime('%Y%m%d'))
|
||
|
||
return self._get_cached('limit_list', fetch)
|
||
|
||
|
||
# 全局单例
|
||
_akshare_client: Optional[AkshareClient] = None
|
||
|
||
|
||
def get_akshare_client() -> AkshareClient:
|
||
"""获取 Akshare 客户端单例"""
|
||
global _akshare_client
|
||
if _akshare_client is None:
|
||
_akshare_client = AkshareClient()
|
||
return _akshare_client
|