This commit is contained in:
aaron 2026-02-04 14:56:03 +08:00
parent 88060474d1
commit 637d3e93b4
8 changed files with 61 additions and 29 deletions

View File

@ -46,7 +46,8 @@ class ContextManager:
session_id: str, session_id: str,
role: str, role: str,
content: str, content: str,
metadata: Optional[dict] = None metadata: Optional[dict] = None,
user_id: Optional[int] = None
): ):
""" """
添加消息到上下文 添加消息到上下文
@ -56,8 +57,9 @@ class ContextManager:
role: 角色user/assistant role: 角色user/assistant
content: 消息内容 content: 消息内容
metadata: 元数据 metadata: 元数据
user_id: 用户ID创建新对话时需要
""" """
db_service.add_message(session_id, role, content, metadata) db_service.add_message(session_id, role, content, metadata, user_id)
logger.info(f"添加消息到上下文: {session_id}, {role}") logger.info(f"添加消息到上下文: {session_id}, {role}")
def clear_context(self, session_id: str): def clear_context(self, session_id: str):

View File

@ -1641,8 +1641,11 @@ RSI{technical.get('rsi', 0):.2f if technical.get('rsi') else '计算中'}
""" """
logger.info(f"[智能模式-流式] 处理消息: {message[:50]}...") logger.info(f"[智能模式-流式] 处理消息: {message[:50]}...")
# 转换 user_id 为整数(如果是字符串)
user_id_int = int(user_id) if user_id else None
# 1. 保存用户消息 # 1. 保存用户消息
self.context_manager.add_message(session_id, "user", message) self.context_manager.add_message(session_id, "user", message, user_id=user_id_int)
# 2. 提取上下文信息 # 2. 提取上下文信息
context_info = self.context_manager.extract_context_info(session_id) context_info = self.context_manager.extract_context_info(session_id)
@ -1685,7 +1688,7 @@ RSI{technical.get('rsi', 0):.2f if technical.get('rsi') else '计算中'}
yield char yield char
# 6. 保存助手响应 # 6. 保存助手响应
self.context_manager.add_message(session_id, "assistant", full_response) self.context_manager.add_message(session_id, "assistant", full_response, user_id=user_id_int)
async def _handle_other_question( async def _handle_other_question(
self, self,

View File

@ -83,6 +83,9 @@ class Settings(BaseSettings):
code_resend_seconds: int = 60 code_resend_seconds: int = 60
code_max_per_hour: int = 10 code_max_per_hour: int = 10
# 白名单手机号(无需验证码即可登录)
whitelist_phones: str = "18583366860,18583926860"
# CORS配置 # CORS配置
cors_origins: str = "http://localhost:8000,http://127.0.0.1:8000" cors_origins: str = "http://localhost:8000,http://127.0.0.1:8000"

View File

@ -41,11 +41,17 @@ if os.path.exists(frontend_path):
@app.get("/") @app.get("/")
async def root(): async def root():
"""根路径,返回前端页面""" """根路径,重定向到登录页"""
from fastapi.responses import RedirectResponse
return RedirectResponse(url="/static/login.html")
@app.get("/app")
async def app_page():
"""主应用页面"""
index_path = os.path.join(frontend_path, "index.html") index_path = os.path.join(frontend_path, "index.html")
if os.path.exists(index_path): if os.path.exists(index_path):
return FileResponse(index_path) return FileResponse(index_path)
return {"message": "A股AI分析Agent系统API"} return {"message": "页面不存在"}
@app.get("/health") @app.get("/health")
async def health_check(): async def health_check():

View File

@ -31,6 +31,7 @@ async def get_current_user(
""" """
try: try:
token = credentials.credentials token = credentials.credentials
logger.info(f"收到认证请求token前10位: {token[:10] if token else 'None'}")
# 验证token # 验证token
payload = jwt_service.verify_token(token) payload = jwt_service.verify_token(token)
@ -52,6 +53,7 @@ async def get_current_user(
headers={"WWW-Authenticate": "Bearer"} headers={"WWW-Authenticate": "Bearer"}
) )
logger.info(f"认证成功: user_id={user.id}, phone={user.phone}")
return user return user
finally: finally:

View File

@ -7,6 +7,7 @@ from sqlalchemy.orm import Session
from app.models.database import User, VerificationCode from app.models.database import User, VerificationCode
from app.services.sms_service import sms_service from app.services.sms_service import sms_service
from app.services.jwt_service import jwt_service from app.services.jwt_service import jwt_service
from app.config import get_settings
from app.utils.logger import logger from app.utils.logger import logger
@ -15,9 +16,13 @@ class AuthService:
def __init__(self): def __init__(self):
"""初始化认证服务""" """初始化认证服务"""
self.code_expire_minutes = 5 settings = get_settings()
self.code_resend_seconds = 60 self.code_expire_minutes = settings.code_expire_minutes
self.code_max_per_hour = 10 self.code_resend_seconds = settings.code_resend_seconds
self.code_max_per_hour = settings.code_max_per_hour
# 白名单手机号列表
self.whitelist_phones = [p.strip() for p in settings.whitelist_phones.split(",") if p.strip()]
logger.info(f"白名单手机号: {self.whitelist_phones}")
async def send_verification_code( async def send_verification_code(
self, self,
@ -110,7 +115,13 @@ class AuthService:
Returns: Returns:
{"success": bool, "token": str, "user": dict, "message": str} {"success": bool, "token": str, "user": dict, "message": str}
""" """
# 1. 查找验证码 # 检查是否为白名单手机号
is_whitelist = phone in self.whitelist_phones
if is_whitelist:
logger.info(f"白名单手机号登录: {phone}")
else:
# 1. 查找验证码(非白名单需要验证)
verification = db.query(VerificationCode).filter( verification = db.query(VerificationCode).filter(
VerificationCode.phone == phone, VerificationCode.phone == phone,
VerificationCode.code == code, VerificationCode.code == code,
@ -141,7 +152,8 @@ class AuthService:
# 4. 更新最后登录时间 # 4. 更新最后登录时间
user.last_login_at = datetime.utcnow() user.last_login_at = datetime.utcnow()
# 5. 关联验证码到用户 # 5. 关联验证码到用户(如果不是白名单)
if not is_whitelist:
verification.user_id = user.id verification.user_id = user.id
db.commit() db.commit()

View File

@ -81,7 +81,8 @@ class DatabaseService:
session_id: str, session_id: str,
role: str, role: str,
content: str, content: str,
metadata: Optional[dict] = None metadata: Optional[dict] = None,
user_id: Optional[int] = None
) -> Message: ) -> Message:
""" """
添加消息 添加消息
@ -91,6 +92,7 @@ class DatabaseService:
role: 角色user/assistant role: 角色user/assistant
content: 消息内容 content: 消息内容
metadata: 元数据 metadata: 元数据
user_id: 用户ID创建新对话时需要
Returns: Returns:
消息对象 消息对象
@ -103,7 +105,9 @@ class DatabaseService:
).first() ).first()
if not conversation: if not conversation:
conversation = Conversation(session_id=session_id) if not user_id:
raise ValueError("创建新对话时必须提供 user_id")
conversation = Conversation(session_id=session_id, user_id=user_id)
db.add(conversation) db.add(conversation)
db.commit() db.commit()
db.refresh(conversation) db.refresh(conversation)

View File

@ -312,7 +312,7 @@
// 保存token // 保存token
localStorage.setItem('token', data.token); localStorage.setItem('token', data.token);
// 跳转到主页 // 跳转到主页
window.location.href = '/'; window.location.href = '/app';
} else { } else {
this.errorMessage = data.message || '登录失败'; this.errorMessage = data.message || '登录失败';
} }