333 lines
13 KiB
Python
333 lines
13 KiB
Python
from fastapi import FastAPI, HTTPException
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from pydantic import BaseModel
|
||
from typing import List, Dict, Optional
|
||
import uvicorn
|
||
import json
|
||
import numpy as np
|
||
|
||
from src.data import TushareClient, Database, RedisCache
|
||
from src.analysis import FinancialIndicators, TechnicalIndicators
|
||
from src.strategies import ValueStrategy, GrowthStrategy, TechnicalStrategy
|
||
from src.utils import setup_logger, validate_ts_code
|
||
from config.config import Config
|
||
|
||
def format_stock_code(code: str) -> str:
    """Normalize a user-supplied A-share code to Tushare ``ts_code`` form.

    Examples: ``"600000"`` -> ``"600000.SH"``, ``" 000001.sz "`` -> ``"000001.SZ"``,
    ``"sh600519"`` -> ``"600519.SH"``, full-width ``"６００５１９"`` -> ``"600519.SH"``.

    Rules, in order:

    * Falsy input (``""`` / ``None``) is returned unchanged.
    * The code is NFKC-normalized (full-width digits, letters and dots become
      ASCII), all whitespace is removed and letters are upper-cased.
    * A code already ending in ``.SZ``, ``.SH`` or ``.BJ`` is returned as-is.
    * An explicit ``SZ`` / ``SH`` / ``BJ`` prefix selects the exchange.
    * Otherwise a 6-digit code is mapped by its leading digits:
      ``43`` / ``83`` / ``87`` / ``92`` -> Beijing (``.BJ``);
      ``5`` / ``6`` / ``900`` -> Shanghai (``.SH``; includes 600/601/603/605
      main board and 688/689 STAR Market);
      everything else (000/001/002/003 main board, 300/301 ChiNext, ...) -> ``.SZ``.
    * Digit strings of any other length fall back to ``.SZ``. Input with no
      digits at all is returned unchanged. ``validate_ts_code`` is left to
      reject both.

    Args:
        code: Raw stock code as entered by a client.

    Returns:
        The normalized ``ts_code``, or ``code`` itself when it has no digits.
    """
    import unicodedata

    if not code:
        return code

    # NFKC folds full-width characters (common with Chinese IMEs) into ASCII.
    clean_code = ''.join(unicodedata.normalize('NFKC', code).split()).upper()

    # Already carries an exchange suffix: trust it.
    if clean_code.endswith(('.SZ', '.SH', '.BJ')):
        return clean_code

    # Keep ASCII digits only; str.isdigit alone would also accept other
    # Unicode digit characters that Tushare cannot resolve.
    numeric_code = ''.join(ch for ch in clean_code if ch in '0123456789')
    if not numeric_code:
        return code

    if clean_code.startswith(('SZ', 'SH', 'BJ')):
        # An explicit prefix wins, e.g. "SH000001" is the SSE index, not 000001.SZ.
        exchange = clean_code[:2]
    elif len(numeric_code) == 6 and numeric_code.startswith(('43', '83', '87', '92')):
        exchange = 'BJ'
    elif len(numeric_code) == 6 and numeric_code.startswith(('5', '6', '900')):
        exchange = 'SH'
    else:
        exchange = 'SZ'
    return f'{numeric_code}.{exchange}'
|
||
|
||
def clean_data_for_json(data):
    """Recursively convert ``data`` into values that serialize as strict JSON.

    * dicts are rebuilt with cleaned values (keys unchanged);
    * lists, tuples and numpy arrays become lists of cleaned items;
    * numpy scalars become their Python counterparts (``np.int64`` -> ``int``,
      ``np.float32`` -> ``float``, ``np.bool_`` -> ``bool``) instead of strings;
    * NaN and +/-inf floats become ``0``, because strict JSON cannot represent them;
    * ``int``, ``str``, ``bool`` and ``None`` pass through unchanged;
    * anything else (timestamps, Decimals, custom objects) is converted with ``str()``.
    """
    if isinstance(data, dict):
        return {key: clean_data_for_json(value) for key, value in data.items()}
    if isinstance(data, (list, tuple)):
        return [clean_data_for_json(item) for item in data]
    if isinstance(data, np.ndarray):
        # tolist() yields native Python scalars / nested lists.
        return clean_data_for_json(data.tolist())
    # numpy integer and bool scalars are not int/bool subclasses; without these
    # checks they would fall through to str() and reach clients as strings.
    if isinstance(data, np.bool_):
        return bool(data)
    if isinstance(data, np.integer):
        return int(data)
    if isinstance(data, (float, np.floating)):
        value = float(data)
        return value if np.isfinite(value) else 0
    if isinstance(data, (int, str, bool)) or data is None:
        return data
    return str(data)
|
||
|
||
app = FastAPI(title="Stock Agent API", description="优质股票分析筛选 AI Agent", version="1.0.0")

# Allowed CORS origins. Config may supply CORS_ORIGINS as a list or a
# comma-separated string; when it is absent, any origin is allowed.
_configured_origins = getattr(Config, "CORS_ORIGINS", None)
if isinstance(_configured_origins, str):
    _configured_origins = [o.strip() for o in _configured_origins.split(",") if o.strip()]
CORS_ORIGINS = list(_configured_origins or ["*"])

app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ORIGINS,
    # FastAPI/Starlette document that credentials cannot be combined with a
    # "*" origin. Starlette then reflects the caller's Origin instead, which
    # lets any site make credentialed requests. Enable credentials only for
    # an explicit allow-list.
    allow_credentials="*" not in CORS_ORIGINS,
    allow_methods=["*"],
    allow_headers=["*"],
)

logger = setup_logger(__name__)

# Module-level service clients shared by all request handlers.
tushare_client = TushareClient()
database = Database()  # NOTE(review): not referenced in the code visible here
cache = RedisCache()
|
||
|
||
# Request body for POST /api/analyze.
class AnalysisRequest(BaseModel):
    # Stock code as typed by the client (e.g. "600000" or "600000.SH"); the
    # endpoint normalizes it with format_stock_code before validating it.
    ts_code: str
    # Scoring strategy: "value", "growth" or "technical". The endpoint
    # rejects any other value with an "Invalid strategy" error.
    strategy: str = "value"
|
||
|
||
# Request body for POST /api/screen.
class ScreenRequest(BaseModel):
    # "comprehensive" (weighted blend of the value/growth/technical scores) or
    # a single strategy name: "value", "growth" or "technical".
    strategy: str = "value"
    # Minimum score a stock needs to appear in the results.
    min_score: float = 60
    # Screening stops once this many qualifying stocks have been collected.
    # NOTE(review): no bounds are validated on this model.
    limit: int = 50
|
||
|
||
@app.get("/")
async def root():
    """Return a fixed status message showing the API process is serving requests."""
    status_payload = {"message": "Stock Agent API is running"}
    return status_payload
|
||
|
||
def _load_stock_list_response():
    """Blocking part of get_stock_list: read the cached list or fetch it from Tushare.

    The cache entry under ``"stock_list"`` keeps its existing format (the bare
    list of records). The ``{"data", "count"}`` envelope is built on every
    call, so cached and fresh responses have the same shape.

    Raises:
        HTTPException: 404 when Tushare returns no stocks.
    """
    cache_key = "stock_list"
    stock_list = cache.get(cache_key)
    if not stock_list:
        stocks = tushare_client.get_stock_list()
        if stocks.empty:
            raise HTTPException(status_code=404, detail="No stock data found")
        stock_list = stocks.to_dict('records')
        cache.set(cache_key, stock_list, 3600)

    # NaN/inf are not valid JSON and would fail response rendering.
    stock_list = clean_data_for_json(stock_list)
    return {"data": stock_list, "count": len(stock_list)}


@app.get("/api/stocks")
async def get_stock_list():
    """List all stocks as ``{"data": [...], "count": n}``.

    The blocking cache/Tushare work runs in a worker thread so the event loop
    keeps serving other requests.

    Raises:
        HTTPException: 404 when no stock data is available, 500 on any other failure.
    """
    from fastapi.concurrency import run_in_threadpool

    try:
        return await run_in_threadpool(_load_stock_list_response)
    except HTTPException:
        # Pass deliberate 4xx responses through instead of rewrapping them as 500.
        raise
    except Exception as e:
        logger.error(f"Error getting stock list: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
def _make_analysis_strategy(strategy_name):
    """Instantiate the scoring strategy called ``strategy_name``.

    Raises:
        HTTPException: 400 when the name is not "value", "growth" or "technical".
    """
    if strategy_name == "value":
        return ValueStrategy()
    if strategy_name == "growth":
        return GrowthStrategy()
    if strategy_name == "technical":
        return TechnicalStrategy()
    raise HTTPException(status_code=400, detail="Invalid strategy")


def _run_stock_analysis(request):
    """Blocking body of analyze_stock: cache lookup, Tushare fetches and scoring.

    Returns the JSON-safe result dict; successful results are cached for 1800 s.
    """
    formatted_ts_code = format_stock_code(request.ts_code)
    if not validate_ts_code(formatted_ts_code):
        raise HTTPException(status_code=400, detail="Invalid stock code")

    # Reject unknown strategies before spending any cache/Tushare round-trips.
    strategy = _make_analysis_strategy(request.strategy)

    cache_key = f"analysis_{formatted_ts_code}_{request.strategy}"
    cached_data = cache.get(cache_key)
    if cached_data:
        return cached_data

    stock_basic_info = tushare_client.get_stock_basic(formatted_ts_code)
    if not stock_basic_info:
        raise HTTPException(status_code=404, detail="Stock not found")

    daily_data = tushare_client.get_stock_daily(formatted_ts_code)
    if daily_data.empty:
        raise HTTPException(status_code=404, detail="No daily data found")

    financial_data = tushare_client.get_financial_data(formatted_ts_code)
    financial_ratios = FinancialIndicators.calculate_all_ratios(financial_data)
    daily_with_indicators = TechnicalIndicators.calculate_all_indicators(daily_data)

    stock_data = {
        'ts_code': formatted_ts_code,
        'financial_ratios': financial_ratios,
        # The last row holds the most recent indicator values.
        'technical_indicators': (
            daily_with_indicators.iloc[-1].to_dict() if not daily_with_indicators.empty else {}
        ),
    }

    score = strategy.calculate_score(stock_data)

    # Trading signals are only requested from the technical strategy.
    trading_signals = (
        strategy.get_trading_signals(stock_data) if request.strategy == "technical" else None
    )

    result = {
        # Report the normalized code. The cache key is built from it, so echoing
        # the raw input would serve one caller's spelling to everyone else.
        'ts_code': formatted_ts_code,
        'strategy': request.strategy,
        'score': round(score, 2),
        'financial_ratios': financial_ratios,
        'recommendation': 'BUY' if score >= 65 else 'HOLD' if score >= 45 else 'SELL',
        # Basic stock information
        'name': stock_basic_info['name'],
        'industry': stock_basic_info['industry'],
        'current_price': stock_basic_info['current_price'],
        'market_cap': stock_basic_info['market_cap'],
        'list_date': stock_basic_info['list_date'],
    }
    if trading_signals:
        result['trading_signals'] = trading_signals

    # A single recursive pass makes the whole payload JSON-compatible.
    result = clean_data_for_json(result)
    cache.set(cache_key, result, 1800)
    return result


@app.post("/api/analyze")
async def analyze_stock(request: AnalysisRequest):
    """Score one stock with the requested strategy.

    The blocking Tushare and cache work runs in a worker thread so the event
    loop keeps serving other requests.

    Raises:
        HTTPException: 400 for an invalid code or strategy, 404 when no data
            exists for the stock, 500 on any other failure.
    """
    from fastapi.concurrency import run_in_threadpool

    try:
        return await run_in_threadpool(_run_stock_analysis, request)
    except HTTPException:
        # Pass deliberate 4xx responses through instead of rewrapping them as 500.
        raise
    except Exception as e:
        logger.error(f"Error analyzing stock {request.ts_code}: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
def _build_screen_candidate(stock):
    """Fetch data for one HS300 row and assemble the dict the strategies score."""
    ts_code = stock['ts_code']
    logger.debug(f"Processing stock: {ts_code}")

    financial_data = tushare_client.get_financial_data(ts_code)
    financial_ratios = FinancialIndicators.calculate_all_ratios(financial_data)

    technical_indicators = {}
    daily_data = tushare_client.get_stock_daily(ts_code)
    if not daily_data.empty:
        daily_with_indicators = TechnicalIndicators.calculate_all_indicators(daily_data)
        # An empty indicator frame means "no indicators", not an IndexError
        # that would silently drop the stock from the screen.
        if not daily_with_indicators.empty:
            technical_indicators = daily_with_indicators.iloc[-1].to_dict()

    return {
        'ts_code': ts_code,
        'name': stock['name'],
        'industry': stock.get('industry', ''),
        'financial_ratios': financial_ratios,
        'technical_indicators': technical_indicators,
    }


def _score_screen_candidate(stock_data, strategy_name, min_score, strategies):
    """Score ``stock_data`` in place; return True when it reaches ``min_score``.

    When it qualifies, adds ``score`` and ``recommendation``. For
    "comprehensive" it also adds the per-strategy ``value_score``,
    ``growth_score`` and ``technical_score``.
    """
    if strategy_name == "comprehensive":
        value_score = strategies["value"].calculate_score(stock_data)
        growth_score = strategies["growth"].calculate_score(stock_data)
        technical_score = strategies["technical"].calculate_score(stock_data)

        # Weighted blend: value 40%, growth 30%, technical 30%.
        comprehensive_score = value_score * 0.4 + growth_score * 0.3 + technical_score * 0.3
        # "not >=" so a NaN score never qualifies.
        if not comprehensive_score >= min_score:
            return False

        stock_data['score'] = round(comprehensive_score, 2)
        stock_data['value_score'] = round(value_score, 2)
        stock_data['growth_score'] = round(growth_score, 2)
        stock_data['technical_score'] = round(technical_score, 2)

        high_scores = sum(s >= 70 for s in (value_score, growth_score, technical_score))
        if comprehensive_score >= 75 or high_scores >= 2:
            stock_data['recommendation'] = 'BUY'
        elif comprehensive_score >= 65:
            stock_data['recommendation'] = 'HOLD'
        else:
            stock_data['recommendation'] = 'WATCH'

        logger.debug(f"Adding stock {stock_data['ts_code']} with comprehensive score {comprehensive_score}")
        return True

    score = strategies[strategy_name].calculate_score(stock_data)
    if not score >= min_score:
        return False
    stock_data['score'] = round(score, 2)
    stock_data['recommendation'] = 'BUY' if score >= 65 else 'HOLD'
    logger.debug(f"Adding stock {stock_data['ts_code']} with {strategy_name} score {score}")
    return True


def _run_screening(request, request_id):
    """Blocking body of screen_stocks: cache lookup, Tushare fetches and scoring."""
    cache_key = f"screen_hs300_{request.strategy}_{request.min_score}_{request.limit}"
    cached_data = cache.get(cache_key)
    if cached_data:
        logger.info(f"[{request_id}] Returning cached data")
        return cached_data

    # HS300 constituents
    hs300_stocks = tushare_client.get_hs300_stocks()
    if hs300_stocks.empty:
        raise HTTPException(status_code=404, detail="No HS300 stock data found")

    # Cap the number of constituents processed to keep request time bounded.
    max_process = min(300, len(hs300_stocks))
    hs300_stocks = hs300_stocks.head(max_process)
    logger.info(f"[{request_id}] Limited processing to {max_process} stocks to avoid timeout")

    strategies = {
        "value": ValueStrategy(),
        "growth": GrowthStrategy(),
        "technical": TechnicalStrategy(),
    }
    if request.strategy == "comprehensive":
        logger.info(f"[{request_id}] Starting comprehensive screening of {len(hs300_stocks)} HS300 stocks")
    else:
        logger.info(f"[{request_id}] Starting to screen {len(hs300_stocks)} HS300 stocks with {request.strategy} strategy")

    results = []
    processed_count = 0  # counts stocks processed without an exception
    for _, stock in hs300_stocks.iterrows():
        try:
            stock_data = _build_screen_candidate(stock)
            if _score_screen_candidate(stock_data, request.strategy, request.min_score, strategies):
                results.append(stock_data)
            processed_count += 1

            # NOTE: stopping early bounds latency, but the response then holds the
            # first `limit` qualifying stocks in index order, not the best overall.
            if len(results) >= request.limit:
                logger.info(f"[{request_id}] Reached limit of {request.limit} results, stopping processing")
                break
        except Exception as e:
            # Best effort: one bad stock must not abort the whole screen.
            logger.error(f"[{request_id}] Error processing stock {stock.get('ts_code', 'unknown')}: {e}")

    results.sort(key=lambda item: item['score'], reverse=True)
    logger.info(f"[{request_id}] Screening completed. Found {len(results)} qualifying stocks from {processed_count} processed")

    response = {
        'strategy': request.strategy,
        'min_score': request.min_score,
        'results': clean_data_for_json(results),
        'count': len(results),
        'processed_count': processed_count,
        'request_id': request_id,
    }
    cache.set(cache_key, response, 1800)
    logger.info(f"[{request_id}] Returning {len(results)} results")
    return response


@app.post("/api/screen")
async def screen_stocks(request: ScreenRequest):
    """Screen HS300 constituents with one strategy or the weighted "comprehensive" blend.

    Input validation happens here. The blocking cache/Tushare/scoring work
    runs in a worker thread, so screening hundreds of stocks does not freeze
    the event loop for other requests.

    Raises:
        HTTPException: 400 for an unknown strategy or ``limit < 1``, 404 when no
            HS300 constituents are available, 500 on any other failure.
    """
    import uuid
    from fastapi.concurrency import run_in_threadpool

    request_id = str(uuid.uuid4())[:8]

    try:
        logger.info(f"[{request_id}] Starting screening request with strategy={request.strategy}")

        # Validate cheap inputs before any network or cache work.
        if request.strategy not in ("comprehensive", "value", "growth", "technical"):
            raise HTTPException(status_code=400, detail="Invalid strategy")
        if request.limit < 1:
            raise HTTPException(status_code=400, detail="limit must be at least 1")

        return await run_in_threadpool(_run_screening, request, request_id)
    except HTTPException:
        # Pass deliberate 4xx responses through instead of rewrapping them as 500.
        raise
    except Exception as e:
        logger.error(f"[{request_id}] Error screening stocks: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
if __name__ == "__main__":
    # Script entry point: serve the app named by the import string below with
    # auto-reload switched on.
    # NOTE(review): "main:app" assumes this file is importable as module `main`.
    server_options = dict(host=Config.API_HOST, port=Config.API_PORT, reload=True)
    uvicorn.run("main:app", **server_options)