#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
FastAPI应用程序入口

为CryptoAI系统提供web API接口层

FastAPI application entry point: the web API layer of the CryptoAI system.
"""

import json
import logging
import os
import time
from typing import Dict, Any
from urllib.parse import parse_qsl

import uvicorn
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse

from cryptoai.utils.db_manager import init_db

from cryptoai.routes.user import router as user_router
from cryptoai.routes.adata import router as adata_router
from cryptoai.routes.crypto import router as crypto_router
from cryptoai.routes.platform import router as platform_router
from cryptoai.routes.analysis import router as analysis_router
from cryptoai.routes.alltick import router as alltick_router
from cryptoai.routes.payment import router as payment_router

# Logging setup: each record is written both to "api_server.log" (a path
# relative to the process's current working directory) and to a console
# stream. Note that logging.basicConfig() does nothing if the root logger
# already has handlers configured.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    handlers=[logging.FileHandler("api_server.log"), logging.StreamHandler()],
)

# Module logger. NOTE(review): "fastapi" appears to be the same logger name the
# FastAPI library uses internally, so records from both share it — confirm
# this is intended.
logger = logging.getLogger("fastapi")

# Initialize the database at import time using the project helper
# cryptoai.utils.db_manager.init_db (what it sets up is defined there).
init_db()

# Settings
# Request logging stays enabled unless LOG_REQUEST_BODY holds something other
# than "true" (compared case-insensitively), e.g. LOG_REQUEST_BODY=false.
LOG_REQUEST_BODY = "true" == os.getenv("LOG_REQUEST_BODY", "true").lower()

# The FastAPI application object. Its title, description and version are the
# metadata FastAPI publishes in the generated OpenAPI schema.
app = FastAPI(
    version="0.1.0",
    title="CryptoAI API",
    description="加密货币AI分析系统API接口",
)

# CORS middleware.
# Allowed origins come from the comma-separated CORS_ALLOW_ORIGINS environment
# variable, e.g. CORS_ALLOW_ORIGINS="http://localhost:3000,https://app.example.com".
# When it is unset or blank, the wildcard "*" is used (the previous
# hard-coded behaviour).
# NOTE: with allow_credentials=True and a wildcard origin, Starlette's
# CORSMiddleware answers credentialed requests by echoing the caller's Origin.
# In effect any website can then make cookie-authenticated calls, so set an
# explicit origin list in production.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Register the API routers as (router, URL prefix, OpenAPI tag) entries,
# keeping the original registration order.
for _router, _prefix, _tag in (
    (platform_router, "/platform", "平台信息"),
    (user_router, "/user", "用户管理"),
    (adata_router, "/adata", "A股数据"),
    (crypto_router, "/crypto", "加密货币数据"),
    (analysis_router, "/analysis", "分析历史"),
    (alltick_router, "/alltick", "AllTick数据"),
    (payment_router, "/payment", "支付"),
):
    app.include_router(_router, prefix=_prefix, tags=[_tag])
# Don't leave the loop variables behind in the module namespace.
del _router, _prefix, _tag

# Request-logging middleware and its helpers.

# Substrings that mark a field name as sensitive (case-insensitive match).
# Deliberately broad: "key" also matches e.g. "api_key" or "keyword", and
# "auth" also matches "author".
_SENSITIVE_FIELDS = ("password", "token", "secret", "key", "auth")
_HIDDEN = "***HIDDEN***"
# Maximum number of characters of a non-JSON body written to the log.
_MAX_RAW_BODY_CHARS = 1000
# Methods whose request body is logged; other methods only get a summary.
_BODY_METHODS = frozenset({"POST", "PUT", "PATCH", "DELETE"})


def _is_sensitive(name: str) -> bool:
    """Return True if *name* contains any of the sensitive substrings."""
    lowered = name.lower()
    return any(marker in lowered for marker in _SENSITIVE_FIELDS)


def _mask_sensitive(data: Any) -> Any:
    """Return a copy of decoded JSON *data* with sensitive values hidden.

    Values stored under a sensitive dict key are replaced by ``_HIDDEN``.
    Dicts and lists are traversed at any depth, including a top-level JSON
    array and lists nested inside lists. Scalars are returned unchanged, and
    the input itself is never modified.
    """
    if isinstance(data, dict):
        return {
            key: _HIDDEN if _is_sensitive(key) else _mask_sensitive(value)
            for key, value in data.items()
        }
    if isinstance(data, list):
        return [_mask_sensitive(item) for item in data]
    return data


def _format_pairs(pairs) -> str:
    """Join (name, value) pairs as ``name=value&...``, hiding sensitive values."""
    return "&".join(
        f"{name}={_HIDDEN if _is_sensitive(name) else value}"
        for name, value in pairs
    )


def _format_body(body: bytes, content_type: str) -> str:
    """Render a raw request body for the log, with secrets masked.

    The body is rendered by the first rule that applies:
    - JSON (any top-level type) is masked, then pretty-printed.
    - URL-encoded form data (e.g. OAuth2 password logins) has its sensitive
      fields masked.
    - Anything else is decoded leniently as UTF-8.
    Non-JSON output is cut to ``_MAX_RAW_BODY_CHARS`` characters.
    """
    if not body:
        return "(empty)"
    try:
        parsed = json.loads(body.decode("utf-8"))
    except (json.JSONDecodeError, UnicodeDecodeError):
        pass
    else:
        return json.dumps(_mask_sensitive(parsed), ensure_ascii=False, indent=2)

    text = body.decode("utf-8", errors="ignore")
    if "application/x-www-form-urlencoded" in content_type.lower():
        text = _format_pairs(parse_qsl(text, keep_blank_values=True))
    # Compare the decoded character count (not the byte count) so multi-byte
    # text is only marked as truncated when characters were actually dropped.
    if len(text) > _MAX_RAW_BODY_CHARS:
        text = text[:_MAX_RAW_BODY_CHARS] + "... (truncated)"
    return text


@app.middleware("http")
async def log_request_body(request: Request, call_next):
    """Log a summary of every request, including a masked body for writes.

    Set the environment variable LOG_REQUEST_BODY=false to disable it.
    - POST/PUT/PATCH/DELETE: the body is read, masked and logged.
    - Any other method: only the (masked) query string is logged.
    Sensitive query parameters (e.g. ``?token=...``) are masked in both cases.

    NOTE(review): the body is read inside a BaseHTTPMiddleware. Older Starlette
    releases could stall the endpoint's own body read after that — confirm the
    pinned Starlette version supports it.
    """
    if not LOG_REQUEST_BODY:
        return await call_next(request)

    method = request.method
    client_ip = request.client.host if request.client else "unknown"
    user_agent = request.headers.get("user-agent", "N/A")[:100]
    query = _format_pairs(request.query_params.multi_items())
    url = str(request.url.replace(query=query)) if query else str(request.url)

    if method in _BODY_METHODS:
        content_type = request.headers.get("content-type", "N/A")
        try:
            body_content = _format_body(await request.body(), content_type)
        except Exception as e:
            body_content = f"(error reading body: {str(e)})"
        lines = [
            "╭─── REQUEST LOG ───────────────────────────────────────────",
            f"│ Method: {method}",
            f"│ URL: {url}",
            f"│ Client IP: {client_ip}",
            f"│ Content-Type: {content_type}",
            f"│ User-Agent: {user_agent}",
            "│ Body:",
            body_content,
            "╰─────────────────────────────────────────────────────────",
        ]
    else:
        lines = [
            "╭─── REQUEST LOG ───────────────────────────────────────────",
            f"│ Method: {method}",
            f"│ URL: {url}",
            f"│ Client IP: {client_ip}",
            f"│ Query Params: {query or '(none)'}",
            f"│ User-Agent: {user_agent}",
            "╰─────────────────────────────────────────────────────────",
        ]
    logger.info("\n%s\n", "\n".join(lines))

    return await call_next(request)

# Request-timing middleware
@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    """Add an ``X-Process-Time`` response header: handling time in seconds.

    The duration is measured with time.perf_counter(), a monotonic
    high-resolution clock. time.time() is wall-clock time and can jump when
    the system clock is adjusted, which could yield wrong or negative values.
    """
    start = time.perf_counter()
    response = await call_next(request)
    response.headers["X-Process-Time"] = str(time.perf_counter() - start)
    return response

# Root route
@app.get("/", tags=["信息"])
async def root() -> Dict[str, Any]:
    """Describe the API: name, version, description, docs path and status."""
    info: Dict[str, Any] = dict(
        name="CryptoAI API",
        version="0.1.0",
        description="加密货币AI分析系统API接口",
        documentation="/docs",
        status="running",
    )
    return info

# Health check
@app.get("/health", tags=["信息"])
async def health_check() -> Dict[str, Any]:
    """Report that the API process is up, with the current Unix timestamp.

    No dependencies (e.g. the database) are checked; a successful response
    only shows the server is handling requests.
    """
    now = time.time()
    return {"status": "healthy", "timestamp": now}

# Global exception handling
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Log any unhandled exception and return a generic HTTP 500 response.

    The exception text is deliberately not sent to the client, because it can
    expose internal details such as SQL, file paths or upstream error
    messages. The full traceback goes to the server log, together with the
    request method and path for correlation.
    """
    # Pass the exception explicitly instead of relying on sys.exc_info().
    logger.error(
        "全局异常: %s %s -> %s",
        request.method,
        request.url.path,
        exc,
        exc_info=exc,
    )
    return JSONResponse(
        status_code=500,
        content={"detail": "服务器内部错误"},
    )

def start():
    """Start the API server with uvicorn.

    Settings are read from environment variables:
        API_HOST   -- bind address, default "127.0.0.1".
        API_PORT   -- bind port, default 8000.
        API_RELOAD -- auto-reload on code changes. Enabled when unset or
                      "true" (case-insensitive); set API_RELOAD=false in
                      production.

    The app is passed to uvicorn as an import string, which uvicorn requires
    when reload is enabled.
    """
    host = os.environ.get("API_HOST", "127.0.0.1")
    port = int(os.environ.get("API_PORT", 8000))
    reload = os.environ.get("API_RELOAD", "true").lower() == "true"

    uvicorn.run(
        "cryptoai.routes.fastapi_app:app",
        host=host,
        port=port,
        reload=reload,
    )

if __name__ == "__main__":
    # init_db() has already run once at module import (module-level call
    # earlier in this file), so it is not called a second time here. start()
    # then has uvicorn import the app by its import string.
    start()