stock-agent/main.py
2025-12-28 10:12:30 +08:00

333 lines
13 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Dict, Optional
import uvicorn
import json
import numpy as np
from src.data import TushareClient, Database, RedisCache
from src.analysis import FinancialIndicators, TechnicalIndicators
from src.strategies import ValueStrategy, GrowthStrategy, TechnicalStrategy
from src.utils import setup_logger, validate_ts_code
from config.config import Config
def format_stock_code(code: str) -> str:
    """Normalise a user-supplied A-share code to Tushare's ``NNNNNN.XX`` form.

    The input is NFKC-normalised first, so full-width characters typed with a
    Chinese IME (``６０００００``, ``．ＳＨ``) become ASCII. All whitespace is then
    removed and the result is upper-cased.

    * A code that already ends in ``.SZ``, ``.SH`` or ``.BJ`` is returned as-is.
    * Otherwise the ASCII digits are extracted. A 6-digit code that starts with
      ``60`` (main board), ``688``/``689`` (STAR Market) or ``900`` (B shares)
      gets ``.SH``. Every other digit string gets ``.SZ``.

    Args:
        code: Raw code such as ``"600000"``, ``"sh600000"`` or ``"000001.sz"``.

    Returns:
        The normalised code, or ``code`` unchanged when it is empty or
        contains no digits. Digit strings that are not 6 characters long still
        receive ``.SZ``; callers are expected to validate the result
        (``analyze_stock`` passes it to ``validate_ts_code``).
    """
    import unicodedata

    if not code:
        return code
    # NFKC maps full-width digits/letters/dots to ASCII. str.split() with no
    # argument drops every kind of whitespace, not only ASCII spaces.
    clean_code = ''.join(unicodedata.normalize('NFKC', code).split()).upper()
    # Already exchange-qualified: trust the caller's suffix.
    if clean_code.endswith(('.SZ', '.SH', '.BJ')):
        return clean_code
    # Keep ASCII digits only: str.isdigit() also accepts other scripts' digits.
    numeric_code = ''.join(ch for ch in clean_code if ch in '0123456789')
    if not numeric_code:
        return code
    # Shanghai: 60xxxx main board, 688/689 STAR Market, 900 B shares.
    if len(numeric_code) == 6 and numeric_code.startswith(('60', '688', '689', '900')):
        return numeric_code + '.SH'
    # Shenzhen is the default (000/001/002/003 main board, 300/301 ChiNext, ...).
    return numeric_code + '.SZ'
def clean_data_for_json(data):
    """Recursively convert ``data`` into JSON-serialisable Python builtins.

    Conversion rules:

    * dict: a new dict with every value cleaned (keys are left untouched).
    * list / tuple / numpy ndarray: a list of cleaned elements.
    * None, str, bool, int: returned unchanged.
    * numpy bool / integer scalars: converted to ``bool`` / ``int``. They are
      not subclasses of the builtins, so they would otherwise be stringified.
    * float / numpy floating scalars: converted to ``float``. NaN and +/-inf
      become ``0`` because strict JSON cannot represent them.
    * anything else (e.g. timestamps): ``str(data)``.

    Args:
        data: Arbitrary nested structure, typically built from pandas/numpy output.

    Returns:
        A structure that ``json.dumps(..., allow_nan=False)`` can encode.
    """
    if isinstance(data, dict):
        return {k: clean_data_for_json(v) for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        return [clean_data_for_json(item) for item in data]
    if isinstance(data, np.ndarray):
        return clean_data_for_json(data.tolist())
    if data is None or isinstance(data, (bool, int, str)):
        return data
    if isinstance(data, np.bool_):
        return bool(data)
    if isinstance(data, np.integer):
        return int(data)
    if isinstance(data, (float, np.floating)):
        value = float(data)
        return value if np.isfinite(value) else 0
    return str(data)
# FastAPI application. Title, description and version are shown in the
# generated OpenAPI docs (the description reads roughly "AI agent for
# analysing and screening high-quality stocks").
app = FastAPI(title="Stock Agent API", description="优质股票分析筛选 AI Agent", version="1.0.0")
# CORS: every origin, method and header is allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is
# forbidden by the CORS spec for literal "*". Starlette's CORSMiddleware appears
# to work around this by echoing the request's Origin, which would let any
# site make credentialed calls — verify against the installed Starlette
# version and restrict allow_origins to the real front-end hosts (TODO confirm
# the intended policy).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
logger = setup_logger(__name__)
# Module-level singletons shared by every request handler in this module.
# They are constructed at import time, so any exception raised by their
# constructors propagates when this module is imported.
tushare_client = TushareClient()
database = Database()
cache = RedisCache()
class AnalysisRequest(BaseModel):
    """Request body for POST /api/analyze."""

    # Stock code. The handler normalises it with format_stock_code,
    # e.g. "600000" -> "600000.SH".
    ts_code: str
    # Scoring strategy: "value" (default), "growth" or "technical".
    # The handler rejects any other value.
    strategy: str = "value"
class ScreenRequest(BaseModel):
    """Request body for POST /api/screen."""

    # "value" (default), "growth", "technical", or "comprehensive"
    # (a weighted blend of the three).
    strategy: str = "value"
    # Only stocks whose score is >= this threshold are returned.
    min_score: float = 60
    # Scanning stops once this many qualifying stocks have been collected.
    limit: int = 50
@app.get("/")
async def root():
    """Liveness endpoint: report that the API process is up."""
    status_message = "Stock Agent API is running"
    payload = {"message": status_message}
    return payload
@app.get("/api/stocks")
async def get_stock_list():
    """Return the full stock list as ``{"data": [...], "count": N}``.

    The cleaned record list is cached under ``"stock_list"`` for 3600 seconds.
    Cache hits are wrapped in the same envelope as fresh results, so clients
    always see one response shape.

    Raises:
        HTTPException: 404 when the client returns no rows, 500 for any other
            failure.
    """
    cache_key = "stock_list"
    try:
        cached_data = cache.get(cache_key)
        if cached_data:
            # The cache stores the bare record list. Wrap it so hits and
            # misses return the same {"data", "count"} envelope.
            if isinstance(cached_data, list):
                return {"data": cached_data, "count": len(cached_data)}
            return cached_data
        stocks = tushare_client.get_stock_list()
        if stocks.empty:
            raise HTTPException(status_code=404, detail="No stock data found")
        # Records from pandas can carry NaN / numpy scalars. Starlette renders
        # JSON with allow_nan=False, so clean them before caching/returning.
        stock_list = clean_data_for_json(stocks.to_dict('records'))
        cache.set(cache_key, stock_list, 3600)
        return {"data": stock_list, "count": len(stock_list)}
    except HTTPException:
        # Keep deliberate 4xx responses instead of rewrapping them as 500.
        raise
    except Exception as e:
        logger.exception(f"Error getting stock list: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/api/analyze")
async def analyze_stock(request: AnalysisRequest):
    """Score one stock with the requested strategy and return a recommendation.

    Processing steps:

    * ``request.ts_code`` is normalised with ``format_stock_code`` and checked
      with ``validate_ts_code``.
    * The stock's basic info, daily bars and financial data are fetched.
    * They are turned into financial ratios plus the last row of technical
      indicators, which the chosen strategy then scores.
    * Results are cached for 1800 seconds under ``analysis_<code>_<strategy>``.

    Returns:
        A JSON-safe dict containing:

        * ``ts_code``: the normalised code.
        * ``strategy`` and ``score``.
        * ``financial_ratios``.
        * ``recommendation``: BUY when score >= 65, HOLD when >= 45,
          otherwise SELL.
        * ``name``, ``industry``, ``current_price``, ``market_cap``,
          ``list_date``.
        * ``trading_signals``: only for the technical strategy, and only
          when it yields any.

    Raises:
        HTTPException: 400 for an invalid code or strategy, 404 when the stock
            or its daily data is missing, 500 for any other failure.
    """
    # Supported strategy names. Unknown names are rejected before any remote
    # data is fetched.
    strategy_classes = {
        "value": ValueStrategy,
        "growth": GrowthStrategy,
        "technical": TechnicalStrategy,
    }
    try:
        formatted_ts_code = format_stock_code(request.ts_code)
        if not validate_ts_code(formatted_ts_code):
            raise HTTPException(status_code=400, detail="Invalid stock code")
        strategy_cls = strategy_classes.get(request.strategy)
        if strategy_cls is None:
            raise HTTPException(status_code=400, detail="Invalid strategy")

        cache_key = f"analysis_{formatted_ts_code}_{request.strategy}"
        cached_data = cache.get(cache_key)
        if cached_data:
            return cached_data

        stock_basic_info = tushare_client.get_stock_basic(formatted_ts_code)
        if not stock_basic_info:
            raise HTTPException(status_code=404, detail="Stock not found")
        daily_data = tushare_client.get_stock_daily(formatted_ts_code)
        if daily_data.empty:
            raise HTTPException(status_code=404, detail="No daily data found")
        financial_data = tushare_client.get_financial_data(formatted_ts_code)
        financial_ratios = FinancialIndicators.calculate_all_ratios(financial_data)
        daily_with_indicators = TechnicalIndicators.calculate_all_indicators(daily_data)
        stock_data = {
            'ts_code': formatted_ts_code,
            'financial_ratios': financial_ratios,
            # Only the last indicator row is scored (presumably the latest bar).
            'technical_indicators': daily_with_indicators.iloc[-1].to_dict() if not daily_with_indicators.empty else {},
        }

        strategy = strategy_cls()
        score = strategy.calculate_score(stock_data)
        trading_signals = None
        if request.strategy == "technical":
            trading_signals = strategy.get_trading_signals(stock_data)

        result = {
            # Echo the normalised code: the cache key is built from it, so a
            # cached entry must not carry one caller's raw spelling to another.
            'ts_code': formatted_ts_code,
            'strategy': request.strategy,
            'score': round(score, 2),
            'financial_ratios': financial_ratios,
            'recommendation': 'BUY' if score >= 65 else 'HOLD' if score >= 45 else 'SELL',
            'name': stock_basic_info['name'],
            'industry': stock_basic_info['industry'],
            'current_price': stock_basic_info['current_price'],
            'market_cap': stock_basic_info['market_cap'],
            'list_date': stock_basic_info['list_date'],
        }
        if trading_signals:
            result['trading_signals'] = trading_signals
        # A single recursive pass makes the whole payload JSON-safe
        # (numpy scalars, NaN/inf), including the nested ratios and signals.
        result = clean_data_for_json(result)
        cache.set(cache_key, result, 1800)
        return result
    except HTTPException:
        # Keep deliberate 4xx responses instead of rewrapping them as 500.
        raise
    except Exception as e:
        logger.exception(f"Error analyzing stock {request.ts_code}: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
def _collect_screen_stock_data(stock):
    """Fetch and derive the scoring inputs for one HS300 constituent row.

    ``stock`` is a row yielded by ``hs300_stocks.iterrows()``. It must provide
    ``ts_code`` and ``name`` and may provide ``industry``.

    Returns:
        The dict the strategies score: code, name, industry, financial ratios
        and the last row of technical indicators. The indicators are ``{}``
        when no daily data or no indicator rows are available.
    """
    ts_code = stock['ts_code']
    logger.debug(f"Processing stock: {ts_code}")
    financial_data = tushare_client.get_financial_data(ts_code)
    financial_ratios = FinancialIndicators.calculate_all_ratios(financial_data)
    daily_data = tushare_client.get_stock_daily(ts_code)
    technical_indicators = {}
    if not daily_data.empty:
        daily_with_indicators = TechnicalIndicators.calculate_all_indicators(daily_data)
        # An empty indicator frame would make iloc[-1] raise IndexError and
        # silently drop the stock from the screen.
        if not daily_with_indicators.empty:
            technical_indicators = daily_with_indicators.iloc[-1].to_dict()
    return {
        'ts_code': ts_code,
        'name': stock['name'],
        'industry': stock.get('industry', ''),
        'financial_ratios': financial_ratios,
        'technical_indicators': technical_indicators,
    }


def _comprehensive_recommendation(comprehensive_score, component_scores):
    """Map a blended score and its component scores to BUY / HOLD / WATCH.

    Returns BUY when the blend is >= 75 or at least two components are >= 70,
    HOLD when the blend is >= 65, and WATCH otherwise.
    """
    high_scores = sum(score >= 70 for score in component_scores)
    if comprehensive_score >= 75 or high_scores >= 2:
        return 'BUY'
    if comprehensive_score >= 65:
        return 'HOLD'
    return 'WATCH'


@app.post("/api/screen")
async def screen_stocks(request: ScreenRequest):
    """Screen HS300 constituents with one strategy or a weighted blend.

    For each constituent (at most 300), financial and daily data are fetched
    and scored. A stock is kept when its score is >= ``request.min_score``.

    With ``strategy="comprehensive"`` the score is
    ``0.4*value + 0.3*growth + 0.3*technical``, and the component scores are
    returned too. Single-strategy mode ("value" / "growth" / "technical") is
    kept for backward compatibility.

    NOTE(review): scanning stops as soon as ``request.limit`` stocks qualify.
    ``results`` are therefore the first qualifying stocks in constituent
    order, sorted by score — not necessarily the top-``limit`` scorers
    across the index. The early exit presumably bounds request latency;
    revisit if a true top-N is wanted.

    Responses are cached for 1800 seconds per (strategy, min_score, limit).
    Per-stock failures are logged and skipped.

    Raises:
        HTTPException: 400 for an unknown strategy, 404 when no HS300 data is
            returned, 500 for any other failure.
    """
    import uuid

    request_id = str(uuid.uuid4())[:8]
    try:
        logger.info(f"[{request_id}] Starting screening request with strategy={request.strategy}")
        cache_key = f"screen_hs300_{request.strategy}_{request.min_score}_{request.limit}"
        cached_data = cache.get(cache_key)
        if cached_data:
            logger.info(f"[{request_id}] Returning cached data")
            return cached_data

        value_strategy = ValueStrategy()
        growth_strategy = GrowthStrategy()
        technical_strategy = TechnicalStrategy()
        is_comprehensive = request.strategy == "comprehensive"
        strategy = None
        if not is_comprehensive:
            # Validate before the HS300 fetch so bad input costs no API calls.
            strategy = {
                "value": value_strategy,
                "growth": growth_strategy,
                "technical": technical_strategy,
            }.get(request.strategy)
            if strategy is None:
                raise HTTPException(status_code=400, detail="Invalid strategy")

        hs300_stocks = tushare_client.get_hs300_stocks()
        if hs300_stocks.empty:
            raise HTTPException(status_code=404, detail="No HS300 stock data found")
        # Cap the number of constituents processed to avoid request timeouts.
        max_process = min(300, len(hs300_stocks))
        hs300_stocks = hs300_stocks.head(max_process)
        logger.info(f"[{request_id}] Limited processing to {max_process} stocks to avoid timeout")
        if is_comprehensive:
            logger.info(f"[{request_id}] Starting comprehensive screening of {len(hs300_stocks)} HS300 stocks")
        else:
            logger.info(f"[{request_id}] Starting to screen {len(hs300_stocks)} HS300 stocks with {request.strategy} strategy")

        results = []
        processed_count = 0
        for _, stock in hs300_stocks.iterrows():
            try:
                stock_data = _collect_screen_stock_data(stock)
                ts_code = stock_data['ts_code']
                if is_comprehensive:
                    value_score = value_strategy.calculate_score(stock_data)
                    growth_score = growth_strategy.calculate_score(stock_data)
                    technical_score = technical_strategy.calculate_score(stock_data)
                    # Weighted blend: value 40%, growth 30%, technical 30%.
                    comprehensive_score = (value_score * 0.4 + growth_score * 0.3 + technical_score * 0.3)
                    if comprehensive_score >= request.min_score:
                        stock_data['score'] = round(comprehensive_score, 2)
                        stock_data['value_score'] = round(value_score, 2)
                        stock_data['growth_score'] = round(growth_score, 2)
                        stock_data['technical_score'] = round(technical_score, 2)
                        stock_data['recommendation'] = _comprehensive_recommendation(
                            comprehensive_score, (value_score, growth_score, technical_score)
                        )
                        logger.debug(f"Adding stock {ts_code} with comprehensive score {comprehensive_score}")
                        results.append(stock_data)
                else:
                    score = strategy.calculate_score(stock_data)
                    if score >= request.min_score:
                        stock_data['score'] = round(score, 2)
                        stock_data['recommendation'] = 'BUY' if score >= 65 else 'HOLD'
                        logger.debug(f"Adding stock {ts_code} with {request.strategy} score {score}")
                        results.append(stock_data)
                processed_count += 1
                if len(results) >= request.limit:
                    logger.info(f"[{request_id}] Reached limit of {request.limit} results, stopping processing")
                    break
            except Exception as e:
                # Best effort: one bad constituent must not abort the whole screen.
                logger.error(f"[{request_id}] Error processing stock {stock.get('ts_code', 'unknown')}: {e}")
                continue

        results.sort(key=lambda item: item['score'], reverse=True)
        logger.info(f"[{request_id}] Screening completed. Found {len(results)} qualifying stocks from {processed_count} processed")
        response = {
            'strategy': request.strategy,
            'min_score': request.min_score,
            'results': clean_data_for_json(results),
            'count': len(results),
            'processed_count': processed_count,
            'request_id': request_id,
        }
        cache.set(cache_key, response, 1800)
        logger.info(f"[{request_id}] Returning {len(results)} results")
        return response
    except HTTPException:
        # Keep deliberate 4xx responses instead of rewrapping them as 500.
        raise
    except Exception as e:
        logger.exception(f"[{request_id}] Error screening stocks: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Launch the development server for "main:app" using the configured
    # host/port. reload=True restarts the server when source files change.
    server_options = {
        "host": Config.API_HOST,
        "port": Config.API_PORT,
        "reload": True,
    }
    uvicorn.run("main:app", **server_options)