update
This commit is contained in:
parent
36c662f3d9
commit
d6c92117e9
Binary file not shown.
@ -243,21 +243,21 @@ class CryptoAgent:
|
||||
"current_price": float(processed_data['close'].iloc[-1]),
|
||||
# "price_change_24h": float(processed_data['close'].iloc[-1] - processed_data['close'].iloc[-24]),
|
||||
# "price_change_percentage_24h": float((processed_data['close'].iloc[-1] - processed_data['close'].iloc[-24]) / processed_data['close'].iloc[-24] * 100),
|
||||
# "historical_prices": processed_data['close'].tail(100).tolist(),
|
||||
# "volumes": processed_data['volume'].tail(100).tolist(),
|
||||
# "technical_indicators": {
|
||||
# "rsi": float(processed_data['RSI'].iloc[-1]),
|
||||
# "macd": float(processed_data['MACD'].iloc[-1]),
|
||||
# "macd_signal": float(processed_data['MACD_Signal'].iloc[-1]),
|
||||
# "bollinger_upper": float(processed_data['Bollinger_Upper'].iloc[-1]),
|
||||
# "bollinger_lower": float(processed_data['Bollinger_Lower'].iloc[-1]),
|
||||
# "ma5": float(processed_data['MA5'].iloc[-1]),
|
||||
# "ma10": float(processed_data['MA10'].iloc[-1]),
|
||||
# "ma20": float(processed_data['MA20'].iloc[-1]),
|
||||
# "ma50": float(processed_data['MA50'].iloc[-1]),
|
||||
# "atr": float(processed_data['ATR'].iloc[-1])
|
||||
# },
|
||||
"klines": processed_data[['open', 'high', 'low', 'close', 'volume', 'RSI', 'MACD', 'MACD_Signal', 'Bollinger_Upper', 'Bollinger_Lower', 'MA5', 'MA10', 'MA20', 'MA50', 'ATR']].tail(100).to_dict('records')
|
||||
"historical_prices": processed_data['close'].tail(100).tolist(),
|
||||
"volumes": processed_data['volume'].tail(100).tolist(),
|
||||
"technical_indicators": {
|
||||
"rsi": float(processed_data['RSI'].iloc[-1]),
|
||||
"macd": float(processed_data['MACD'].iloc[-1]),
|
||||
"macd_signal": float(processed_data['MACD_Signal'].iloc[-1]),
|
||||
"bollinger_upper": float(processed_data['Bollinger_Upper'].iloc[-1]),
|
||||
"bollinger_lower": float(processed_data['Bollinger_Lower'].iloc[-1]),
|
||||
"ma5": float(processed_data['MA5'].iloc[-1]),
|
||||
"ma10": float(processed_data['MA10'].iloc[-1]),
|
||||
"ma20": float(processed_data['MA20'].iloc[-1]),
|
||||
"ma50": float(processed_data['MA50'].iloc[-1]),
|
||||
"atr": float(processed_data['ATR'].iloc[-1])
|
||||
},
|
||||
"klines": processed_data[['open', 'high', 'low', 'close', 'volume']].tail(100).to_dict('records')
|
||||
}
|
||||
|
||||
# 将市场数据格式化为适合大模型的格式
|
||||
@ -364,7 +364,7 @@ class CryptoAgent:
|
||||
# 保存交易建议到数据库
|
||||
try:
|
||||
saved = self.db_manager.save_agent_feed(
|
||||
agent_name="加密货币AI助理",
|
||||
agent_name="Crypto Agent",
|
||||
content=message
|
||||
)
|
||||
except Exception as e:
|
||||
@ -383,7 +383,6 @@ class CryptoAgent:
|
||||
请对以下加密货币市场分析的JSON结果进行归纳总结:
|
||||
|
||||
需要输出的内容包括:
|
||||
标题:AI Agent 加密货币分析报告
|
||||
1. 对交易对给出操作建议:
|
||||
1.1 操作方向(做多、做空、观望)
|
||||
1.2 操作价位
|
||||
@ -394,7 +393,7 @@ class CryptoAgent:
|
||||
以下是每个交易对的分析结果:
|
||||
{results}
|
||||
|
||||
请以优美的Markdown格式输出,不要使用表格,通过 emoji 标签来增加可读性。
|
||||
请以优美的Markdown格式输出,不宜用过大的标题,通过 emoji 标签来增加可读性。
|
||||
"""
|
||||
|
||||
system_prompt = """
|
||||
|
||||
@ -81,4 +81,12 @@ database:
|
||||
port: 27469
|
||||
user: "root"
|
||||
password: "Aa#223388"
|
||||
db_name: "cryptoai"
|
||||
db_name: "cryptoai"
|
||||
|
||||
# 腾讯云SES邮件服务配置
|
||||
ses:
|
||||
secret_id: "AKIDxnbGj281iHtKallqqzvlV5YxBCrPltnS" # 腾讯云API密钥ID
|
||||
secret_key: "ta6PXTMBsX7dzA7IN6uYUFn8F9uTovoU" # 腾讯云API密钥
|
||||
region: "ap-guangzhou" # 地域,如ap-guangzhou, ap-hongkong等
|
||||
from_email: "system@mail.ibtc.work" # 发件人邮箱,需要在腾讯云SES控制台中验证
|
||||
template_id: 31670 # 可选,邮件模板ID,如果不使用模板,可以删除此行
|
||||
@ -17,6 +17,7 @@ from typing import Dict, Any
|
||||
|
||||
from cryptoai.routes.agent import router as agent_router
|
||||
from cryptoai.routes.feed import router as feed_router
|
||||
from cryptoai.routes.user import router as user_router
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
@ -48,6 +49,7 @@ app.add_middleware(
|
||||
# 添加API路由
|
||||
app.include_router(agent_router, prefix="/agent")
|
||||
app.include_router(feed_router, prefix="/feed", tags=["AI Agent信息流"])
|
||||
app.include_router(user_router, prefix="/user", tags=["用户管理"])
|
||||
|
||||
# 请求计时中间件
|
||||
@app.middleware("http")
|
||||
|
||||
@ -6,11 +6,12 @@ AI Agent信息流API路由模块,提供信息流的增删改查功能
|
||||
"""
|
||||
|
||||
import logging
|
||||
from fastapi import APIRouter, HTTPException, status, Query
|
||||
from fastapi import APIRouter, HTTPException, status, Query, Depends
|
||||
from pydantic import BaseModel
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from cryptoai.routes.user import get_current_user
|
||||
from cryptoai.utils.db_manager import get_db_manager
|
||||
|
||||
# 配置日志
|
||||
@ -79,7 +80,8 @@ async def create_feed(feed: AgentFeedCreate) -> Dict[str, Any]:
|
||||
async def get_feeds(
|
||||
agent_name: Optional[str] = Query(None, description="AI Agent名称,可选"),
|
||||
limit: int = Query(20, description="返回的最大记录数,默认20条"),
|
||||
skip: int = Query(0, description="跳过的记录数,默认0条")
|
||||
skip: int = Query(0, description="跳过的记录数,默认0条"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> List[AgentFeedResponse]:
|
||||
"""
|
||||
获取AI Agent信息流列表
|
||||
|
||||
354
cryptoai/routes/user.py
Normal file
354
cryptoai/routes/user.py
Normal file
@ -0,0 +1,354 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
用户API路由模块,提供用户注册、登录和信息获取功能
|
||||
"""
|
||||
|
||||
import logging
|
||||
import hashlib
|
||||
import secrets
|
||||
from fastapi import APIRouter, HTTPException, status, Depends, Query
|
||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||
from pydantic import BaseModel, EmailStr
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime, timedelta
|
||||
import jwt
|
||||
from jwt.exceptions import PyJWTError
|
||||
from fastapi import Request
|
||||
|
||||
from cryptoai.utils.db_manager import get_db_manager
|
||||
from cryptoai.utils.email_service import get_email_service
|
||||
|
||||
# 配置日志
|
||||
logger = logging.getLogger("user_router")
|
||||
|
||||
# 创建路由
|
||||
router = APIRouter()
|
||||
|
||||
# JWT配置
|
||||
JWT_SECRET_KEY = "FX9Rf7YpvTRdXWmj0Osx8P9smSkUh6fW" # 实际应用中应该从环境变量或配置文件中获取
|
||||
JWT_ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7 # 7天过期
|
||||
|
||||
# 请求模型
|
||||
class UserRegister(BaseModel):
|
||||
"""用户注册请求模型"""
|
||||
mail: EmailStr
|
||||
nickname: str
|
||||
password: str
|
||||
verification_code: str
|
||||
|
||||
class UserLogin(BaseModel):
|
||||
"""用户登录请求模型"""
|
||||
mail: EmailStr
|
||||
password: str
|
||||
|
||||
class SendVerificationCodeRequest(BaseModel):
|
||||
"""发送验证码请求模型"""
|
||||
mail: EmailStr
|
||||
|
||||
# 响应模型
|
||||
class UserResponse(BaseModel):
|
||||
"""用户信息响应模型"""
|
||||
id: int
|
||||
mail: str
|
||||
nickname: str
|
||||
level: int
|
||||
create_time: datetime
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
"""令牌响应模型"""
|
||||
access_token: str
|
||||
token_type: str
|
||||
expires_in: int
|
||||
user_info: UserResponse
|
||||
|
||||
# 工具函数
|
||||
def hash_password(password: str) -> str:
|
||||
"""对密码进行哈希处理"""
|
||||
return hashlib.sha256(password.encode()).hexdigest()
|
||||
|
||||
def create_access_token(data: Dict[str, Any], expires_delta: timedelta = None) -> str:
|
||||
"""创建访问令牌"""
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.now() + expires_delta
|
||||
else:
|
||||
expire = datetime.now() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
async def get_current_user(request: Request) -> Dict[str, Any]:
|
||||
"""获取当前用户"""
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="无效的身份验证凭据",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
token = request.headers.get("Authorization")
|
||||
if not token:
|
||||
raise credentials_exception
|
||||
token = token.split(" ")[1]
|
||||
print(f"token:{token}")
|
||||
payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
|
||||
mail = payload.get("sub")
|
||||
print(f"mail:{mail}")
|
||||
if mail is None:
|
||||
raise credentials_exception
|
||||
except PyJWTError as e:
|
||||
print(f"PyJWTError: {e}")
|
||||
raise credentials_exception
|
||||
|
||||
db_manager = get_db_manager()
|
||||
user = db_manager.get_user_by_mail(mail)
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
return user
|
||||
|
||||
@router.post("/send-verification-code", response_model=Dict[str, Any])
|
||||
async def send_verification_code(request: SendVerificationCodeRequest) -> Dict[str, Any]:
|
||||
"""
|
||||
发送邮箱验证码
|
||||
|
||||
Args:
|
||||
request: 发送验证码请求
|
||||
|
||||
Returns:
|
||||
发送结果
|
||||
"""
|
||||
try:
|
||||
# 获取数据库管理器
|
||||
db_manager = get_db_manager()
|
||||
|
||||
# 检查邮箱是否已被注册
|
||||
user = db_manager.get_user_by_mail(request.mail)
|
||||
if user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="该邮箱已被注册"
|
||||
)
|
||||
|
||||
# 获取邮件服务
|
||||
email_service = get_email_service()
|
||||
|
||||
# 发送验证码邮件
|
||||
result = email_service.send_verification_email(request.mail)
|
||||
|
||||
if not result['success']:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=result['message']
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": "验证码已发送到您的邮箱"
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"发送验证码失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"发送验证码失败: {str(e)}"
|
||||
)
|
||||
|
||||
@router.post("/register", response_model=Dict[str, Any], status_code=status.HTTP_201_CREATED)
|
||||
async def register_user(user: UserRegister) -> Dict[str, Any]:
|
||||
"""
|
||||
注册新用户
|
||||
|
||||
Args:
|
||||
user: 用户注册信息
|
||||
|
||||
Returns:
|
||||
注册成功的状态信息
|
||||
"""
|
||||
try:
|
||||
# 获取邮件服务
|
||||
email_service = get_email_service()
|
||||
|
||||
# 验证验证码
|
||||
if not email_service.verify_code(user.mail, user.verification_code):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="验证码错误或已过期"
|
||||
)
|
||||
|
||||
# 获取数据库管理器
|
||||
db_manager = get_db_manager()
|
||||
|
||||
# 对密码进行哈希处理
|
||||
hashed_password = hash_password(user.password)
|
||||
|
||||
# 注册用户
|
||||
success = db_manager.register_user(
|
||||
mail=user.mail,
|
||||
nickname=user.nickname,
|
||||
password=hashed_password,
|
||||
level=0 # 默认为普通用户
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="注册失败,邮箱可能已被使用"
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": "用户注册成功"
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"注册用户失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"注册用户失败: {str(e)}"
|
||||
)
|
||||
|
||||
@router.post("/login", response_model=TokenResponse)
|
||||
async def login(loginData: UserLogin) -> TokenResponse:
|
||||
"""
|
||||
用户登录
|
||||
|
||||
Args:
|
||||
form_data: 表单数据,包含用户名(邮箱)和密码
|
||||
|
||||
Returns:
|
||||
访问令牌和用户信息
|
||||
"""
|
||||
try:
|
||||
# 获取数据库管理器
|
||||
db_manager = get_db_manager()
|
||||
|
||||
# 获取用户信息
|
||||
user = db_manager.get_user_by_mail(loginData.mail)
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="邮箱或密码错误",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# 验证密码
|
||||
hashed_password = hash_password(loginData.password)
|
||||
|
||||
# 查询用户的密码哈希
|
||||
session = db_manager.Session()
|
||||
try:
|
||||
from cryptoai.utils.db_manager import User
|
||||
db_user = session.query(User).filter(User.mail == loginData.mail).first()
|
||||
if not db_user or db_user.password != hashed_password:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="邮箱或密码错误",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
# 创建访问令牌,不过期
|
||||
access_token_expires = None
|
||||
access_token = create_access_token(
|
||||
data={"sub": user["mail"]}, expires_delta=access_token_expires
|
||||
)
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
token_type="bearer",
|
||||
expires_in=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
user_info=UserResponse(
|
||||
id=user["id"],
|
||||
mail=user["mail"],
|
||||
nickname=user["nickname"],
|
||||
level=user["level"],
|
||||
create_time=user["create_time"]
|
||||
)
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"用户登录失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"用户登录失败: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
async def get_user_info(current_user: Dict[str, Any] = Depends(get_current_user)) -> UserResponse:
|
||||
"""
|
||||
获取当前登录用户信息
|
||||
|
||||
Args:
|
||||
current_user: 当前用户信息,由依赖项提供
|
||||
|
||||
Returns:
|
||||
用户信息
|
||||
"""
|
||||
return UserResponse(
|
||||
id=current_user["id"],
|
||||
mail=current_user["mail"],
|
||||
nickname=current_user["nickname"],
|
||||
level=current_user["level"],
|
||||
create_time=current_user["create_time"]
|
||||
)
|
||||
|
||||
@router.put("/level/{user_id}", response_model=Dict[str, Any])
|
||||
async def update_user_level(
|
||||
user_id: int,
|
||||
level: int = Query(..., description="用户级别(0=普通用户,1=VIP,2=SVIP)"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
更新用户级别(需要管理员权限)
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
level: 新的用户级别
|
||||
current_user: 当前用户信息,由依赖项提供
|
||||
|
||||
Returns:
|
||||
更新成功的状态信息
|
||||
"""
|
||||
# 简单的权限检查(实际应用中应该有更完善的权限管理)
|
||||
if current_user["level"] < 2: # 假设SVIP用户有管理权限
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="没有足够的权限执行此操作"
|
||||
)
|
||||
|
||||
try:
|
||||
# 获取数据库管理器
|
||||
db_manager = get_db_manager()
|
||||
|
||||
# 更新用户级别
|
||||
success = db_manager.update_user_level(user_id, level)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"用户ID {user_id} 不存在"
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"成功更新用户级别为 {level}"
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"更新用户级别失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"更新用户级别失败: {str(e)}"
|
||||
)
|
||||
Binary file not shown.
@ -106,5 +106,8 @@ class ConfigLoader:
|
||||
|
||||
def get_database_config(self) -> Dict[str, Any]:
|
||||
"""获取数据库配置"""
|
||||
|
||||
return self.get_config('database')
|
||||
return self.get_config('database')
|
||||
|
||||
def get_ses_config(self) -> Dict[str, Any]:
|
||||
"""获取腾讯云SES配置"""
|
||||
return self.get_config('ses')
|
||||
@ -64,6 +64,26 @@ class AgentFeed(Base):
|
||||
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
|
||||
)
|
||||
|
||||
# 定义用户数据模型
|
||||
class User(Base):
|
||||
"""用户数据表模型"""
|
||||
__tablename__ = 'users'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
mail = Column(String(100), nullable=False, unique=True, comment='邮箱')
|
||||
nickname = Column(String(50), nullable=False, comment='昵称')
|
||||
password = Column(String(100), nullable=False, comment='密码')
|
||||
level = Column(Integer, nullable=False, default=0, comment='用户级别(0=普通用户,1=VIP,2=SVIP)')
|
||||
create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
|
||||
|
||||
# 索引和表属性
|
||||
__table_args__ = (
|
||||
Index('idx_mail', 'mail'),
|
||||
Index('idx_level', 'level'),
|
||||
Index('idx_create_time', 'create_time'),
|
||||
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
|
||||
)
|
||||
|
||||
class DBManager:
|
||||
"""数据库管理工具,用于连接MySQL数据库并保存智能体分析结果"""
|
||||
|
||||
@ -234,6 +254,207 @@ class DBManager:
|
||||
pass
|
||||
return False
|
||||
|
||||
def register_user(self, mail: str, nickname: str, password: str, level: int = 0) -> bool:
|
||||
"""
|
||||
注册新用户
|
||||
|
||||
Args:
|
||||
mail: 邮箱
|
||||
nickname: 昵称
|
||||
password: 密码
|
||||
level: 用户级别,默认为0(普通用户)
|
||||
|
||||
Returns:
|
||||
注册是否成功
|
||||
"""
|
||||
if not self.engine:
|
||||
try:
|
||||
self._init_db()
|
||||
except Exception as e:
|
||||
logger.error(f"重新连接数据库失败: {e}")
|
||||
return False
|
||||
|
||||
try:
|
||||
# 创建会话
|
||||
session = self.Session()
|
||||
|
||||
try:
|
||||
# 检查邮箱是否已存在
|
||||
existing_user = session.query(User).filter(User.mail == mail).first()
|
||||
if existing_user:
|
||||
logger.warning(f"邮箱 {mail} 已被注册")
|
||||
return False
|
||||
|
||||
# 创建新用户
|
||||
new_user = User(
|
||||
mail=mail,
|
||||
nickname=nickname,
|
||||
password=password, # 实际应用中应该对密码进行哈希处理
|
||||
level=level,
|
||||
create_time=datetime.now()
|
||||
)
|
||||
|
||||
# 添加并提交
|
||||
session.add(new_user)
|
||||
session.commit()
|
||||
|
||||
logger.info(f"成功注册用户: {mail}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"注册用户失败: {e}")
|
||||
return False
|
||||
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建数据库会话失败: {e}")
|
||||
# 如果是连接错误,尝试重新初始化
|
||||
try:
|
||||
self._init_db()
|
||||
except:
|
||||
pass
|
||||
return False
|
||||
|
||||
def get_user_by_mail(self, mail: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
通过邮箱获取用户信息
|
||||
|
||||
Args:
|
||||
mail: 邮箱
|
||||
|
||||
Returns:
|
||||
用户信息,如果用户不存在则返回None
|
||||
"""
|
||||
if not self.engine:
|
||||
try:
|
||||
self._init_db()
|
||||
except Exception as e:
|
||||
logger.error(f"重新连接数据库失败: {e}")
|
||||
return None
|
||||
|
||||
try:
|
||||
# 创建会话
|
||||
session = self.Session()
|
||||
|
||||
try:
|
||||
# 查询用户
|
||||
user = session.query(User).filter(User.mail == mail).first()
|
||||
|
||||
if user:
|
||||
# 转换为字典
|
||||
return {
|
||||
'id': user.id,
|
||||
'mail': user.mail,
|
||||
'nickname': user.nickname,
|
||||
'level': user.level,
|
||||
'create_time': user.create_time
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户信息失败: {e}")
|
||||
return None
|
||||
|
||||
def get_user_by_id(self, user_id: int) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
通过ID获取用户信息
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
用户信息,如果用户不存在则返回None
|
||||
"""
|
||||
if not self.engine:
|
||||
try:
|
||||
self._init_db()
|
||||
except Exception as e:
|
||||
logger.error(f"重新连接数据库失败: {e}")
|
||||
return None
|
||||
|
||||
try:
|
||||
# 创建会话
|
||||
session = self.Session()
|
||||
|
||||
try:
|
||||
# 查询用户
|
||||
user = session.query(User).filter(User.id == user_id).first()
|
||||
|
||||
if user:
|
||||
# 转换为字典
|
||||
return {
|
||||
'id': user.id,
|
||||
'mail': user.mail,
|
||||
'nickname': user.nickname,
|
||||
'level': user.level,
|
||||
'create_time': user.create_time
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户信息失败: {e}")
|
||||
return None
|
||||
|
||||
def update_user_level(self, user_id: int, level: int) -> bool:
|
||||
"""
|
||||
更新用户级别
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
level: 新的用户级别
|
||||
|
||||
Returns:
|
||||
更新是否成功
|
||||
"""
|
||||
if not self.engine:
|
||||
try:
|
||||
self._init_db()
|
||||
except Exception as e:
|
||||
logger.error(f"重新连接数据库失败: {e}")
|
||||
return False
|
||||
|
||||
try:
|
||||
# 创建会话
|
||||
session = self.Session()
|
||||
|
||||
try:
|
||||
# 查询用户
|
||||
user = session.query(User).filter(User.id == user_id).first()
|
||||
|
||||
if not user:
|
||||
logger.warning(f"用户ID {user_id} 不存在")
|
||||
return False
|
||||
|
||||
# 更新级别
|
||||
user.level = level
|
||||
session.commit()
|
||||
|
||||
logger.info(f"成功更新用户 {user.mail} 的级别为 {level}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"更新用户级别失败: {e}")
|
||||
return False
|
||||
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建数据库会话失败: {e}")
|
||||
return False
|
||||
|
||||
def get_agent_feeds(self, agent_name: Optional[str] = None, limit: int = 20, skip: int = 0) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取AI Agent信息流
|
||||
|
||||
252
cryptoai/utils/email_service.py
Normal file
252
cryptoai/utils/email_service.py
Normal file
@ -0,0 +1,252 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
邮件服务工具类,使用腾讯云SES服务发送邮件
|
||||
"""
|
||||
|
||||
import logging
|
||||
import random
|
||||
import string
|
||||
import json
|
||||
import time
|
||||
from typing import Dict, Any, Optional
|
||||
import requests
|
||||
from tencentcloud.common import credential
|
||||
from tencentcloud.common.profile.client_profile import ClientProfile
|
||||
from tencentcloud.common.profile.http_profile import HttpProfile
|
||||
from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
|
||||
from tencentcloud.ses.v20201002 import ses_client, models
|
||||
|
||||
from cryptoai.utils.config_loader import ConfigLoader
|
||||
|
||||
# 配置日志
|
||||
logger = logging.getLogger('email_service')
|
||||
|
||||
# 验证码缓存,格式: {email: {'code': '123456', 'expire_time': timestamp}}
|
||||
verification_codes = {}
|
||||
|
||||
class EmailService:
|
||||
"""邮件服务类,用于发送验证码邮件"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化邮件服务"""
|
||||
# 加载腾讯云SES配置
|
||||
config_loader = ConfigLoader()
|
||||
self.ses_config = config_loader.get_ses_config()
|
||||
|
||||
# 初始化SES客户端
|
||||
self.client = None
|
||||
self._init_client()
|
||||
|
||||
def _init_client(self) -> None:
|
||||
"""初始化腾讯云SES客户端"""
|
||||
try:
|
||||
# 创建认证对象
|
||||
cred = credential.Credential(
|
||||
self.ses_config.get('secret_id'),
|
||||
self.ses_config.get('secret_key')
|
||||
)
|
||||
|
||||
# 创建HTTP配置
|
||||
http_profile = HttpProfile()
|
||||
http_profile.endpoint = "ses.tencentcloudapi.com"
|
||||
|
||||
# 创建客户端配置
|
||||
client_profile = ClientProfile()
|
||||
client_profile.httpProfile = http_profile
|
||||
|
||||
# 创建SES客户端
|
||||
self.client = ses_client.SesClient(cred, self.ses_config.get('region'), client_profile)
|
||||
|
||||
logger.info("成功初始化腾讯云SES客户端")
|
||||
|
||||
except TencentCloudSDKException as e:
|
||||
logger.error(f"初始化腾讯云SES客户端失败: {e}")
|
||||
self.client = None
|
||||
|
||||
def generate_verification_code(self, length: int = 6) -> str:
|
||||
"""
|
||||
生成数字验证码
|
||||
|
||||
Args:
|
||||
length: 验证码长度,默认6位
|
||||
|
||||
Returns:
|
||||
生成的验证码
|
||||
"""
|
||||
return ''.join(random.choices(string.digits, k=length))
|
||||
|
||||
def save_verification_code(self, email: str, code: str, expire_minutes: int = 10) -> None:
|
||||
"""
|
||||
保存验证码到缓存
|
||||
|
||||
Args:
|
||||
email: 邮箱地址
|
||||
code: 验证码
|
||||
expire_minutes: 过期时间(分钟),默认10分钟
|
||||
"""
|
||||
expire_time = time.time() + expire_minutes * 60
|
||||
verification_codes[email] = {
|
||||
'code': code,
|
||||
'expire_time': expire_time
|
||||
}
|
||||
|
||||
def verify_code(self, email: str, code: str) -> bool:
|
||||
"""
|
||||
验证邮箱验证码
|
||||
|
||||
Args:
|
||||
email: 邮箱地址
|
||||
code: 验证码
|
||||
|
||||
Returns:
|
||||
验证是否成功
|
||||
"""
|
||||
# 检查验证码是否存在
|
||||
if email not in verification_codes:
|
||||
return False
|
||||
|
||||
# 获取验证码信息
|
||||
code_info = verification_codes[email]
|
||||
|
||||
# 检查验证码是否过期
|
||||
if time.time() > code_info['expire_time']:
|
||||
# 删除过期验证码
|
||||
del verification_codes[email]
|
||||
return False
|
||||
|
||||
# 验证验证码
|
||||
if code_info['code'] == code:
|
||||
# 验证成功后删除验证码
|
||||
del verification_codes[email]
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def send_verification_email(self, email: str) -> Dict[str, Any]:
|
||||
"""
|
||||
发送验证码邮件
|
||||
|
||||
Args:
|
||||
email: 收件人邮箱
|
||||
|
||||
Returns:
|
||||
发送结果
|
||||
"""
|
||||
if not self.client:
|
||||
try:
|
||||
self._init_client()
|
||||
except Exception as e:
|
||||
logger.error(f"重新初始化SES客户端失败: {e}")
|
||||
return {
|
||||
'success': False,
|
||||
'message': '邮件服务初始化失败'
|
||||
}
|
||||
|
||||
try:
|
||||
# 生成验证码
|
||||
code = self.generate_verification_code()
|
||||
|
||||
# 保存验证码
|
||||
self.save_verification_code(email, code)
|
||||
|
||||
# 构建邮件内容
|
||||
subject = "CryptoAI - 您的验证码"
|
||||
html_content = f"""
|
||||
<html>
|
||||
<head>
|
||||
<style>
|
||||
body {{ font-family: Arial, sans-serif; line-height: 1.6; }}
|
||||
.container {{ max-width: 600px; margin: 0 auto; padding: 20px; }}
|
||||
.header {{ background-color: #4A90E2; color: white; padding: 10px; text-align: center; }}
|
||||
.content {{ padding: 20px; }}
|
||||
.code {{ font-size: 24px; font-weight: bold; text-align: center;
|
||||
color: #4A90E2; padding: 10px; margin: 20px 0; }}
|
||||
.footer {{ font-size: 12px; color: #999; text-align: center; margin-top: 30px; }}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="header">
|
||||
<h2>CryptoAI 验证码</h2>
|
||||
</div>
|
||||
<div class="content">
|
||||
<p>您好,</p>
|
||||
<p>您正在注册 CryptoAI 账号,请使用以下验证码完成注册:</p>
|
||||
<div class="code">{code}</div>
|
||||
<p>验证码有效期为10分钟,请勿将验证码泄露给他人。</p>
|
||||
<p>如果您没有进行此操作,请忽略此邮件。</p>
|
||||
</div>
|
||||
<div class="footer">
|
||||
<p>此邮件由系统自动发送,请勿回复。</p>
|
||||
<p>© {time.strftime('%Y')} CryptoAI. 保留所有权利。</p>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
# 创建发送邮件请求
|
||||
req = models.SendEmailRequest()
|
||||
req.FromEmailAddress = self.ses_config.get('from_email')
|
||||
req.Destination = [email]
|
||||
req.Subject = subject
|
||||
req.Template = {
|
||||
"TemplateID": self.ses_config.get('template_id', 1), # 使用默认模板ID或配置的模板ID
|
||||
"TemplateData": json.dumps({
|
||||
"code": code
|
||||
})
|
||||
}
|
||||
|
||||
# 如果没有配置模板ID,则使用HTML内容
|
||||
if not self.ses_config.get('template_id'):
|
||||
req.Template = None
|
||||
req.Simple = {
|
||||
"Html": html_content
|
||||
}
|
||||
|
||||
# 发送邮件
|
||||
resp = self.client.SendEmail(req)
|
||||
|
||||
logger.info(f"成功发送验证码邮件到 {email}")
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'message': '验证码已发送',
|
||||
'request_id': resp.RequestId
|
||||
}
|
||||
|
||||
except TencentCloudSDKException as e:
|
||||
logger.error(f"发送验证码邮件失败: {e}")
|
||||
return {
|
||||
'success': False,
|
||||
'message': f'发送验证码失败: {e}'
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"发送验证码邮件出现未知错误: {e}")
|
||||
return {
|
||||
'success': False,
|
||||
'message': f'发送验证码失败: {str(e)}'
|
||||
}
|
||||
|
||||
# 单例模式
|
||||
_email_service_instance = None
|
||||
|
||||
def get_email_service() -> EmailService:
|
||||
"""
|
||||
获取邮件服务实例(单例模式)
|
||||
|
||||
Returns:
|
||||
邮件服务实例
|
||||
"""
|
||||
global _email_service_instance
|
||||
|
||||
# 如果已经初始化过,直接返回
|
||||
if _email_service_instance is not None:
|
||||
return _email_service_instance
|
||||
|
||||
# 创建实例
|
||||
_email_service_instance = EmailService()
|
||||
|
||||
return _email_service_instance
|
||||
@ -45,6 +45,36 @@ def update_table_charset():
|
||||
MODIFY content TEXT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
|
||||
"""))
|
||||
|
||||
# 检查users表是否存在
|
||||
result = session.execute(text("""
|
||||
SELECT COUNT(*)
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_name = 'users';
|
||||
"""))
|
||||
|
||||
table_exists = result.scalar() > 0
|
||||
|
||||
# 如果users表存在,更新其字符集
|
||||
if table_exists:
|
||||
session.execute(text("""
|
||||
ALTER TABLE users
|
||||
CONVERT TO CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
|
||||
"""))
|
||||
|
||||
# 特别更新nickname和mail列的字符集
|
||||
session.execute(text("""
|
||||
ALTER TABLE users
|
||||
MODIFY nickname VARCHAR(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
|
||||
"""))
|
||||
|
||||
session.execute(text("""
|
||||
ALTER TABLE users
|
||||
MODIFY mail VARCHAR(100) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
|
||||
"""))
|
||||
|
||||
logger.info("成功更新users表字符集为utf8mb4")
|
||||
|
||||
session.commit()
|
||||
logger.info("成功更新数据库表字符集为utf8mb4")
|
||||
return True
|
||||
|
||||
@ -4,7 +4,7 @@ services:
|
||||
cryptoai-task:
|
||||
build: .
|
||||
container_name: cryptoai-task
|
||||
image: cryptoai:0.0.15
|
||||
image: cryptoai:0.0.16
|
||||
restart: always
|
||||
volumes:
|
||||
- ./cryptoai/data:/app/cryptoai/data
|
||||
@ -29,7 +29,7 @@ services:
|
||||
cryptoai-api:
|
||||
build: .
|
||||
container_name: cryptoai-api
|
||||
image: cryptoai-api:0.0.3
|
||||
image: cryptoai-api:0.0.4
|
||||
restart: always
|
||||
ports:
|
||||
- "8000:8000"
|
||||
|
||||
@ -10,6 +10,10 @@ pyyaml==6.0.1
|
||||
fastapi==0.110.0
|
||||
uvicorn==0.27.1
|
||||
python-dotenv==1.0.0
|
||||
pyjwt==2.8.0
|
||||
python-multipart==0.0.9
|
||||
email-validator==2.1.0
|
||||
tencentcloud-sdk-python==3.0.1030
|
||||
# # 日志相关
|
||||
# logging==0.4.9.6
|
||||
# # 数据处理相关
|
||||
|
||||
Loading…
Reference in New Issue
Block a user