From 88060474d18ad2b5f36ebba1fa18e71a7f835e44 Mon Sep 17 00:00:00 2001 From: aaron <> Date: Wed, 4 Feb 2026 11:18:19 +0800 Subject: [PATCH] update --- backend/app/api/auth.py | 158 +++++++++++ backend/app/api/chat.py | 14 +- backend/app/config.py | 16 ++ backend/app/main.py | 3 +- backend/app/middleware/auth_middleware.py | 101 +++++++ backend/app/models/auth.py | 46 +++ backend/app/models/database.py | 45 ++- backend/app/services/auth_service.py | 182 ++++++++++++ backend/app/services/jwt_service.py | 93 ++++++ backend/app/services/sms_service.py | 100 +++++++ backend/requirements.txt | 3 + frontend/css/style.css | 21 ++ frontend/index.html | 8 + frontend/js/app.js | 36 ++- frontend/login.html | 330 ++++++++++++++++++++++ 15 files changed, 1146 insertions(+), 10 deletions(-) create mode 100644 backend/app/api/auth.py create mode 100644 backend/app/middleware/auth_middleware.py create mode 100644 backend/app/models/auth.py create mode 100644 backend/app/services/auth_service.py create mode 100644 backend/app/services/jwt_service.py create mode 100644 backend/app/services/sms_service.py create mode 100644 frontend/login.html diff --git a/backend/app/api/auth.py b/backend/app/api/auth.py new file mode 100644 index 0000000..2319bab --- /dev/null +++ b/backend/app/api/auth.py @@ -0,0 +1,158 @@ +""" +认证API +""" +from fastapi import APIRouter, HTTPException, Depends, Request +from app.models.auth import ( + SendCodeRequest, SendCodeResponse, + LoginRequest, LoginResponse, + RefreshTokenResponse, UserInfo +) +from app.models.database import User +from app.services.auth_service import auth_service +from app.services.jwt_service import jwt_service +from app.services.db_service import db_service +from app.middleware.auth_middleware import get_current_user, get_client_ip +from app.utils.logger import logger + +router = APIRouter(prefix="/api/auth", tags=["认证"]) + + +@router.post("/send-code", response_model=SendCodeResponse) +async def send_verification_code( + request_data: SendCodeRequest, + request: Request +): + """ + 发送验证码 + + - 同一手机号60秒内只能发送一次 + - 同一IP每小时最多发送10次 + - 验证码5分钟有效期 + """ + try: + client_ip = get_client_ip(request) + db = db_service.get_session() + + try: + result = await auth_service.send_verification_code( + db=db, + phone=request_data.phone, + ip_address=client_ip + ) + + return SendCodeResponse(**result) + + finally: + db.close() + + except Exception as e: + logger.error(f"发送验证码失败: {e}") + raise HTTPException(status_code=500, detail="发送验证码失败") + + +@router.post("/login", response_model=LoginResponse) +async def login_with_code( + request_data: LoginRequest, + request: Request +): + """ + 验证码登录 + + - 验证验证码是否正确且未过期 + - 用户不存在则自动注册 + - 生成JWT token(7天有效期) + - 更新最后登录时间 + """ + try: + client_ip = get_client_ip(request) + db = db_service.get_session() + + try: + result = await auth_service.login_with_code( + db=db, + phone=request_data.phone, + code=request_data.code, + ip_address=client_ip + ) + + if not result["success"]: + return LoginResponse( + success=False, + message=result["message"] + ) + + return LoginResponse( + success=True, + token=result["token"], + user=UserInfo(**result["user"]) + ) + + finally: + db.close() + + except Exception as e: + logger.error(f"登录失败: {e}") + raise HTTPException(status_code=500, detail="登录失败") + + +@router.post("/refresh", response_model=RefreshTokenResponse) +async def refresh_token( + current_user: User = Depends(get_current_user) +): + """ + 刷新Token + + 需要提供有效的JWT token + """ + try: + # 生成新的token + new_token = jwt_service.create_access_token( + current_user.id, + current_user.phone + ) + + return RefreshTokenResponse( + success=True, + token=new_token + ) + + except Exception as e: + logger.error(f"刷新token失败: {e}") + raise HTTPException(status_code=500, detail="刷新token失败") + + +@router.get("/me", response_model=UserInfo) +async def get_current_user_info( + current_user: User = Depends(get_current_user) +): + """ + 获取当前用户信息 + + 需要提供有效的JWT token + """ + # 手机号脱敏 + masked_phone = f"{current_user.phone[:3]}****{current_user.phone[-4:]}" + + return UserInfo( + id=current_user.id, + phone=masked_phone, + created_at=current_user.created_at, + last_login_at=current_user.last_login_at + ) + + +@router.post("/logout") +async def logout( + current_user: User = Depends(get_current_user) +): + """ + 登出 + + 主要在前端清除token,后端记录日志 + """ + logger.info(f"用户登出: {current_user.phone}") + + return { + "success": True, + "message": "已登出" + } diff --git a/backend/app/api/chat.py b/backend/app/api/chat.py index 5a6da66..851064e 100644 --- a/backend/app/api/chat.py +++ b/backend/app/api/chat.py @@ -1,14 +1,16 @@ """ 对话API路由 """ -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter, HTTPException, Depends from fastapi.responses import StreamingResponse from typing import Optional import uuid import json import asyncio from app.models.chat import ChatRequest, ChatResponse +from app.models.database import User from app.agent.smart_agent import smart_agent # 使用智能Agent +from app.middleware.auth_middleware import get_current_user from app.utils.logger import logger router = APIRouter() @@ -71,12 +73,16 @@ async def get_history(session_id: str, limit: int = 50): @router.post("/message/stream") -async def send_message_stream(request: ChatRequest): +async def send_message_stream( + request: ChatRequest, + current_user: User = Depends(get_current_user) +): """ 流式发送消息给Agent Args: request: 聊天请求 + current_user: 当前登录用户 Returns: Server-Sent Events 流式响应 @@ -94,11 +100,11 @@ async def send_message_stream(request: ChatRequest): # 添加小延迟确保数据被发送 await asyncio.sleep(0.01) - # 处理消息并流式返回 + # 处理消息并流式返回(使用真实用户ID) async for chunk in smart_agent.process_message_stream( message=request.message, session_id=session_id, - user_id=request.user_id + user_id=str(current_user.id) ): yield f"data: {json.dumps({'type': 'content', 'content': chunk})}\n\n" # 添加小延迟,让浏览器有机会接收数据 diff --git a/backend/app/config.py b/backend/app/config.py index 4dd5595..0f7b532 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -67,6 +67,22 @@ class Settings(BaseSettings): secret_key: str = "change-this-secret-key-in-production" rate_limit: str = "100/minute" + # JWT配置 + jwt_algorithm: str = "HS256" + jwt_expire_days: int = 7 + + # 腾讯云短信配置 + tencent_sms_app_id: str = "1400961527" + tencent_sms_secret_id: str = "AKIDxnbGj281iHtKallqqzvlV5YxBCrPltnS" # 腾讯云SecretId + tencent_sms_secret_key: str = "ta6PXTMBsX7dzA7IN6uYUFn8F9uTovoU" # 腾讯云SecretKey + tencent_sms_sign_id: str = "629073" + tencent_sms_template_id: str = "2353142" + + # 验证码配置 + code_expire_minutes: int = 5 + code_resend_seconds: int = 60 + code_max_per_hour: int = 10 + # CORS配置 cors_origins: str = "http://localhost:8000,http://127.0.0.1:8000" diff --git a/backend/app/main.py b/backend/app/main.py index cfd9b8f..f355b14 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -7,7 +7,7 @@ from fastapi.staticfiles import StaticFiles from fastapi.responses import FileResponse from app.config import get_settings from app.utils.logger import logger -from app.api import chat, stock, skills, llm +from app.api import chat, stock, skills, llm, auth import os # 创建FastAPI应用 @@ -28,6 +28,7 @@ app.add_middleware( ) # 注册路由 +app.include_router(auth.router, tags=["认证"]) app.include_router(chat.router, prefix="/api/chat", tags=["对话"]) app.include_router(stock.router, prefix="/api/stock", tags=["股票数据"]) app.include_router(skills.router, prefix="/api/skills", tags=["技能管理"]) diff --git a/backend/app/middleware/auth_middleware.py b/backend/app/middleware/auth_middleware.py new file mode 100644 index 0000000..3d1df6e --- /dev/null +++ b/backend/app/middleware/auth_middleware.py @@ -0,0 +1,101 @@ +""" +JWT认证中间件 +""" +from fastapi import Request, HTTPException, Depends +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from sqlalchemy.orm import Session +from app.services.jwt_service import jwt_service +from app.models.database import User +from app.services.db_service import db_service +from app.utils.logger import logger + +security = HTTPBearer() + + +async def get_current_user( + credentials: HTTPAuthorizationCredentials = Depends(security) +) -> User: + """ + 获取当前登录用户 + + 从JWT token中解析用户信息并返回User对象 + + Args: + credentials: HTTP Bearer认证凭据 + + Returns: + User对象 + + Raises: + HTTPException: 认证失败时抛出401异常 + """ + try: + token = credentials.credentials + + # 验证token + payload = jwt_service.verify_token(token) + user_id = int(payload.get("sub")) + + # 从数据库查询用户 + db = db_service.get_session() + try: + user = db.query(User).filter( + User.id == user_id, + User.is_active == True + ).first() + + if not user: + logger.warning(f"用户不存在或已禁用: user_id={user_id}") + raise HTTPException( + status_code=401, + detail="用户不存在或已禁用", + headers={"WWW-Authenticate": "Bearer"} + ) + + return user + + finally: + db.close() + + except ValueError as e: + logger.warning(f"Token验证失败: {e}") + raise HTTPException( + status_code=401, + detail=str(e), + headers={"WWW-Authenticate": "Bearer"} + ) + except Exception as e: + logger.error(f"认证异常: {e}") + raise HTTPException( + status_code=401, + detail="认证失败", + headers={"WWW-Authenticate": "Bearer"} + ) + + +def get_client_ip(request: Request) -> str: + """ + 获取客户端IP地址 + + Args: + request: FastAPI请求对象 + + Returns: + 客户端IP地址 + """ + # 优先从X-Forwarded-For获取(代理/负载均衡场景) + forwarded = request.headers.get("X-Forwarded-For") + if forwarded: + # X-Forwarded-For可能包含多个IP,取第一个 + return forwarded.split(",")[0].strip() + + # 从X-Real-IP获取 + real_ip = request.headers.get("X-Real-IP") + if real_ip: + return real_ip + + # 直接从client获取 + if request.client: + return request.client.host + + return "unknown" diff --git a/backend/app/models/auth.py b/backend/app/models/auth.py new file mode 100644 index 0000000..12beb13 --- /dev/null +++ b/backend/app/models/auth.py @@ -0,0 +1,46 @@ +""" +认证相关的Pydantic模型 +""" +from pydantic import BaseModel, Field +from typing import Optional +from datetime import datetime + + +class SendCodeRequest(BaseModel): + """发送验证码请求""" + phone: str = Field(..., min_length=11, max_length=11, description="手机号") + + +class SendCodeResponse(BaseModel): + """发送验证码响应""" + success: bool + message: str + expires_in: Optional[int] = None + + +class LoginRequest(BaseModel): + """登录请求""" + phone: str = Field(..., min_length=11, max_length=11, description="手机号") + code: str = Field(..., min_length=6, max_length=6, description="验证码") + + +class UserInfo(BaseModel): + """用户信息""" + id: int + phone: str + created_at: datetime + last_login_at: Optional[datetime] = None + + +class LoginResponse(BaseModel): + """登录响应""" + success: bool + token: Optional[str] = None + user: Optional[UserInfo] = None + message: Optional[str] = None + + +class RefreshTokenResponse(BaseModel): + """刷新Token响应""" + success: bool + token: str diff --git a/backend/app/models/database.py b/backend/app/models/database.py index a201259..1b58677 100644 --- a/backend/app/models/database.py +++ b/backend/app/models/database.py @@ -2,23 +2,57 @@ 数据库模型定义 """ from datetime import datetime -from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey, JSON +from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey, JSON, Boolean from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import relationship Base = declarative_base() +class User(Base): + """用户表""" + __tablename__ = "users" + + id = Column(Integer, primary_key=True, index=True) + phone = Column(String(11), unique=True, nullable=False, index=True) + created_at = Column(DateTime, default=datetime.utcnow) + last_login_at = Column(DateTime, nullable=True) + is_active = Column(Boolean, default=True) + + # 关联 + conversations = relationship("Conversation", back_populates="user") + verification_codes = relationship("VerificationCode", back_populates="user") + + +class VerificationCode(Base): + """验证码表""" + __tablename__ = "verification_codes" + + id = Column(Integer, primary_key=True, index=True) + phone = Column(String(11), nullable=False, index=True) + code = Column(String(6), nullable=False) + created_at = Column(DateTime, default=datetime.utcnow) + expires_at = Column(DateTime, nullable=False) + is_used = Column(Boolean, default=False) + used_at = Column(DateTime, nullable=True) + ip_address = Column(String(45), nullable=True) + user_id = Column(Integer, ForeignKey("users.id"), nullable=True) + + # 关联 + user = relationship("User", back_populates="verification_codes") + + class Conversation(Base): """对话记录表""" __tablename__ = "conversations" id = Column(Integer, primary_key=True, index=True) session_id = Column(String(64), nullable=False, index=True) - user_id = Column(String(64), nullable=True) + user_id = Column(Integer, ForeignKey("users.id"), nullable=False) created_at = Column(DateTime, default=datetime.utcnow) - # 关联消息 + # 关联 + user = relationship("User", back_populates="conversations") messages = relationship("Message", back_populates="conversation", cascade="all, delete-orphan") @@ -42,6 +76,9 @@ class UserPreference(Base): __tablename__ = "user_preferences" id = Column(Integer, primary_key=True, index=True) - user_id = Column(String(64), unique=True, nullable=False, index=True) + user_id = Column(Integer, ForeignKey("users.id"), unique=True, nullable=False, index=True) preferences = Column(JSON, nullable=True) updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + + # 关联 + user = relationship("User") diff --git a/backend/app/services/auth_service.py b/backend/app/services/auth_service.py new file mode 100644 index 0000000..873ef3e --- /dev/null +++ b/backend/app/services/auth_service.py @@ -0,0 +1,182 @@ +""" +认证服务 +""" +from datetime import datetime, timedelta +from typing import Dict, Optional +from sqlalchemy.orm import Session +from app.models.database import User, VerificationCode +from app.services.sms_service import sms_service +from app.services.jwt_service import jwt_service +from app.utils.logger import logger + + +class AuthService: + """认证服务类""" + + def __init__(self): + """初始化认证服务""" + self.code_expire_minutes = 5 + self.code_resend_seconds = 60 + self.code_max_per_hour = 10 + + async def send_verification_code( + self, + db: Session, + phone: str, + ip_address: str + ) -> Dict: + """ + 发送验证码 + + Args: + db: 数据库会话 + phone: 手机号 + ip_address: IP地址 + + Returns: + {"success": bool, "message": str, "expires_in": int} + """ + # 1. 检查发送频率限制(60秒) + last_code = db.query(VerificationCode).filter( + VerificationCode.phone == phone, + VerificationCode.created_at > datetime.utcnow() - timedelta(seconds=self.code_resend_seconds) + ).first() + + if last_code: + remaining = self.code_resend_seconds - int((datetime.utcnow() - last_code.created_at).total_seconds()) + return { + "success": False, + "message": f"请{remaining}秒后再试" + } + + # 2. 检查IP限制(每小时10次) + ip_count = db.query(VerificationCode).filter( + VerificationCode.ip_address == ip_address, + VerificationCode.created_at > datetime.utcnow() - timedelta(hours=1) + ).count() + + if ip_count >= self.code_max_per_hour: + return { + "success": False, + "message": "发送次数过多,请稍后再试" + } + + # 3. 生成验证码 + code = sms_service.generate_code() + + # 4. 发送短信 + success = await sms_service.send_code(phone, code) + + if not success: + return { + "success": False, + "message": "发送失败,请稍后重试" + } + + # 5. 保存验证码记录 + verification = VerificationCode( + phone=phone, + code=code, + expires_at=datetime.utcnow() + timedelta(minutes=self.code_expire_minutes), + ip_address=ip_address + ) + db.add(verification) + db.commit() + + logger.info(f"验证码已发送: {phone}") + + return { + "success": True, + "message": "验证码已发送", + "expires_in": self.code_expire_minutes * 60 + } + + async def login_with_code( + self, + db: Session, + phone: str, + code: str, + ip_address: str + ) -> Dict: + """ + 验证码登录 + + Args: + db: 数据库会话 + phone: 手机号 + code: 验证码 + ip_address: IP地址 + + Returns: + {"success": bool, "token": str, "user": dict, "message": str} + """ + # 1. 查找验证码 + verification = db.query(VerificationCode).filter( + VerificationCode.phone == phone, + VerificationCode.code == code, + VerificationCode.is_used == False, + VerificationCode.expires_at > datetime.utcnow() + ).order_by(VerificationCode.created_at.desc()).first() + + if not verification: + return { + "success": False, + "message": "验证码错误或已过期" + } + + # 2. 标记验证码已使用 + verification.is_used = True + verification.used_at = datetime.utcnow() + + # 3. 查找或创建用户 + user = db.query(User).filter(User.phone == phone).first() + + if not user: + # 自动注册 + user = User(phone=phone) + db.add(user) + db.flush() + logger.info(f"新用户注册: {phone}") + + # 4. 更新最后登录时间 + user.last_login_at = datetime.utcnow() + + # 5. 关联验证码到用户 + verification.user_id = user.id + + db.commit() + db.refresh(user) + + # 6. 生成JWT token + token = jwt_service.create_access_token(user.id, user.phone) + + logger.info(f"用户登录成功: {phone}") + + return { + "success": True, + "token": token, + "user": { + "id": user.id, + "phone": self._mask_phone(user.phone), + "created_at": user.created_at.isoformat(), + "last_login_at": user.last_login_at.isoformat() if user.last_login_at else None + } + } + + def _mask_phone(self, phone: str) -> str: + """ + 手机号脱敏 + + Args: + phone: 手机号 + + Returns: + 脱敏后的手机号 + """ + if len(phone) == 11: + return f"{phone[:3]}****{phone[-4:]}" + return phone + + +# 创建全局实例 +auth_service = AuthService() diff --git a/backend/app/services/jwt_service.py b/backend/app/services/jwt_service.py new file mode 100644 index 0000000..53694d7 --- /dev/null +++ b/backend/app/services/jwt_service.py @@ -0,0 +1,93 @@ +""" +JWT令牌服务 +""" +from datetime import datetime, timedelta +from typing import Optional, Dict +from jose import JWTError, jwt +from app.config import get_settings +from app.utils.logger import logger + + +class JWTService: + """JWT服务类""" + + def __init__(self): + """初始化JWT服务""" + settings = get_settings() + self.secret_key = settings.secret_key + self.algorithm = settings.jwt_algorithm + self.expire_days = settings.jwt_expire_days + + def create_access_token(self, user_id: int, phone: str) -> str: + """ + 创建访问令牌 + + Args: + user_id: 用户ID + phone: 手机号 + + Returns: + JWT token + """ + expire = datetime.utcnow() + timedelta(days=self.expire_days) + to_encode = { + "sub": str(user_id), + "phone": phone, + "exp": expire, + "iat": datetime.utcnow() + } + + try: + encoded_jwt = jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm) + logger.info(f"创建JWT token成功: user_id={user_id}") + return encoded_jwt + except Exception as e: + logger.error(f"创建JWT token失败: {e}") + raise + + def verify_token(self, token: str) -> Dict: + """ + 验证令牌 + + Args: + token: JWT token + + Returns: + 解码后的payload + + Raises: + ValueError: token无效或过期 + """ + try: + payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm]) + user_id = payload.get("sub") + if user_id is None: + raise ValueError("Token中缺少用户ID") + return payload + except JWTError as e: + logger.warning(f"Token验证失败: {e}") + raise ValueError(f"Token验证失败: {str(e)}") + except Exception as e: + logger.error(f"Token解析异常: {e}") + raise ValueError(f"Token解析异常: {str(e)}") + + def decode_token_without_verification(self, token: str) -> Optional[Dict]: + """ + 不验证签名解码token(用于调试) + + Args: + token: JWT token + + Returns: + 解码后的payload或None + """ + try: + payload = jwt.decode(token, options={"verify_signature": False}) + return payload + except Exception as e: + logger.error(f"Token解码失败: {e}") + return None + + +# 创建全局实例 +jwt_service = JWTService() diff --git a/backend/app/services/sms_service.py b/backend/app/services/sms_service.py new file mode 100644 index 0000000..de6cebf --- /dev/null +++ b/backend/app/services/sms_service.py @@ -0,0 +1,100 @@ +""" +腾讯云短信服务 +""" +import random +import string +from typing import Optional +from tencentcloud.common import credential +from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException +from tencentcloud.sms.v20210111 import sms_client, models +from app.config import get_settings +from app.utils.logger import logger + + +class SMSService: + """腾讯云短信服务类""" + + def __init__(self): + """初始化短信服务""" + settings = get_settings() + self.app_id = settings.tencent_sms_app_id + self.secret_id = settings.tencent_sms_secret_id + self.secret_key = settings.tencent_sms_secret_key + self.sign_id = settings.tencent_sms_sign_id + self.template_id = settings.tencent_sms_template_id + + # 初始化客户端 + try: + # 使用密钥认证 + cred = credential.Credential(self.secret_id, self.secret_key) + self.client = sms_client.SmsClient(cred, "ap-guangzhou") + logger.info("腾讯云短信服务初始化成功") + except Exception as e: + logger.error(f"腾讯云短信服务初始化失败: {e}") + self.client = None + + def generate_code(self, length: int = 6) -> str: + """ + 生成验证码 + + Args: + length: 验证码长度,默认6位 + + Returns: + 验证码字符串 + """ + return ''.join(random.choices(string.digits, k=length)) + + async def send_code(self, phone: str, code: str) -> bool: + """ + 发送验证码短信 + + Args: + phone: 手机号 + code: 验证码 + + Returns: + 是否发送成功 + """ + if not self.client: + logger.error("短信客户端未初始化") + return False + + if not self.template_id: + logger.warning("短信模板ID未配置,跳过发送(开发模式)") + logger.info(f"【开发模式】验证码: {code} (手机号: {phone})") + return True + + try: + req = models.SendSmsRequest() + req.SmsSdkAppId = self.app_id + req.SignName = "成都爱嘉辰科技" + req.TemplateId = self.template_id + req.TemplateParamSet = [code] # 只传递验证码参数 + req.PhoneNumberSet = [f"+86{phone}"] + + resp = self.client.SendSms(req) + + # 检查发送结果 + if resp.SendStatusSet and len(resp.SendStatusSet) > 0: + status = resp.SendStatusSet[0] + if status.Code == "Ok": + logger.info(f"短信发送成功: {phone}") + return True + else: + logger.error(f"短信发送失败: {status.Code} - {status.Message}") + return False + else: + logger.error("短信发送失败: 无响应状态") + return False + + except TencentCloudSDKException as e: + logger.error(f"腾讯云SDK异常: {e}") + return False + except Exception as e: + logger.error(f"发送短信异常: {e}") + return False + + +# 创建全局实例 +sms_service = SMSService() diff --git a/backend/requirements.txt b/backend/requirements.txt index db0f982..bba367f 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -16,3 +16,6 @@ numpy>=1.26.0 python-multipart==0.0.6 aiohttp==3.9.1 yfinance>=0.2.36 +PyJWT==2.8.0 +tencentcloud-sdk-python==3.0.1100 +python-jose[cryptography]==3.3.0 diff --git a/frontend/css/style.css b/frontend/css/style.css index 0fef21c..5df3595 100644 --- a/frontend/css/style.css +++ b/frontend/css/style.css @@ -98,6 +98,27 @@ html, body { flex-shrink: 0; } +.logout-btn { + display: flex; + align-items: center; + justify-content: center; + width: 36px; + height: 36px; + margin-left: 12px; + background: transparent; + border: 1px solid var(--border-bright); + border-radius: 2px; + color: var(--text-secondary); + cursor: pointer; + transition: all 0.2s; +} + +.logout-btn:hover { + background: var(--accent-dim); + border-color: var(--accent); + color: var(--accent); +} + .model-select { background: transparent; border: none; diff --git a/frontend/index.html b/frontend/index.html index b061f56..b3b2f92 100644 --- a/frontend/index.html +++ b/frontend/index.html @@ -41,6 +41,14 @@ + + diff --git a/frontend/js/app.js b/frontend/js/app.js index 663a843..9940111 100644 --- a/frontend/js/app.js +++ b/frontend/js/app.js @@ -18,11 +18,35 @@ createApp({ }; }, mounted() { + // 检查登录状态 + if (!this.checkAuth()) { + window.location.href = '/static/login.html'; + return; + } + this.sessionId = this.generateSessionId(); this.autoResizeTextarea(); this.loadModels(); }, methods: { + checkAuth() { + const token = localStorage.getItem('token'); + if (!token) return false; + + // 验证token是否过期(简单检查) + try { + const payload = JSON.parse(atob(token.split('.')[1])); + return payload.exp * 1000 > Date.now(); + } catch { + return false; + } + }, + + logout() { + localStorage.removeItem('token'); + window.location.href = '/static/login.html'; + }, + async sendMessage() { if (!this.userInput.trim() || this.loading) return; @@ -55,13 +79,16 @@ createApp({ const messageIndex = this.messages.length - 1; try { + const token = localStorage.getItem('token'); + // 使用流式API const response = await fetch('/api/chat/message/stream', { method: 'POST', headers: { 'Content-Type': 'application/json', 'Accept': 'text/event-stream', - 'Cache-Control': 'no-cache' + 'Cache-Control': 'no-cache', + 'Authorization': `Bearer ${token}` }, body: JSON.stringify({ message: message, @@ -69,6 +96,13 @@ createApp({ }) }); + if (response.status === 401) { + // Token过期或无效,跳转登录页 + localStorage.removeItem('token'); + window.location.href = '/static/login.html'; + return; + } + if (!response.ok) { throw new Error('请求失败'); } diff --git a/frontend/login.html b/frontend/login.html new file mode 100644 index 0000000..a35fd08 --- /dev/null +++ b/frontend/login.html @@ -0,0 +1,330 @@ + + + + + + 登录 - Tradus|AI 金融智能体 + + + + +
+
+ + + + + +
+
+ + + + +