增加千问

This commit is contained in:
aaron 2025-03-07 14:11:29 +08:00
parent 5ca85d9edb
commit ecfac5d150
9 changed files with 306 additions and 205 deletions

66
app/api/endpoints/ai.py Normal file
View File

@ -0,0 +1,66 @@
from fastapi import APIRouter, Depends, UploadFile, File
from app.core.response import success_response, error_response, ResponseModel
from app.core.ai_client import ai_client
from app.api.deps import get_current_user
from app.models.user import UserDB
import logging
from app.core.qcloud import qcloud_manager
router = APIRouter()
@router.post("/extract_pickup_code", response_model=ResponseModel)
async def extract_pickup_code(
file: UploadFile = File(...)
):
"""从图片中提取取件码"""
try:
# 检查文件类型
if not file.content_type.startswith('image/'):
return error_response(code=400, message="只能上传图片文件")
url = await qcloud_manager.upload_file(file)
if not url:
return error_response(code=500, message="上传图片失败")
# 调用 AI 客户端提取取件码
result = await ai_client.extract_pickup_code(url)
if "error" in result:
return error_response(code=500, message=result.get("message", "提取取件码失败"))
# 检查是否提取到取件码
if not result.get("stations") or not any(station.get("pickup_codes") for station in result.get("stations", [])):
return error_response(code=400, message="提取取件码信息失败")
# 格式化输出
formatted_text = format_pickup_codes(result)
# 返回原始数据和格式化文本
return success_response(data={
"raw": result,
"formatted_text": formatted_text
})
except Exception as e:
logging.exception(f"提取取件码失败: {str(e)}")
return error_response(code=500, message=f"提取取件码失败: {str(e)}")
def format_pickup_codes(result):
"""将取件码结果格式化为指定格式"""
formatted_lines = []
for station in result.get("stations", []):
station_name = station.get("name", "未知驿站")
pickup_codes = station.get("pickup_codes", [])
if pickup_codes:
# 格式化取件码,用 分隔
codes_text = " ".join(pickup_codes)
# 添加驿站和取件码信息
formatted_lines.append(f"驿站:{station_name}")
formatted_lines.append(f"取件码:{codes_text}")
formatted_lines.append("") # 添加空行分隔不同驿站
# 合并所有行
return "\n".join(formatted_lines).strip()

View File

@ -1,32 +0,0 @@
from fastapi import APIRouter, Depends, UploadFile, File
from app.core.response import success_response, error_response, ResponseModel
from app.core.ocr_service import ocr_service
from app.api.deps import get_current_user
from app.models.user import UserDB
router = APIRouter()
@router.post("/pickup_code", response_model=ResponseModel)
async def recognize_pickup_code(
file: UploadFile = File(...),
current_user: UserDB = Depends(get_current_user)
):
"""识别收件码图片"""
try:
# 检查文件类型
if not file.content_type.startswith('image/'):
return error_response(code=400, message="只能上传图片文件")
# 读取文件内容
content = await file.read()
# 调用OCR服务识别图片
result = await ocr_service.recognize_pickup_code(content)
if not result.get("stations") or not any(station["pickup_codes"] for station in result["stations"]):
return error_response(code=400, message="未能识别到取件码")
return success_response(data=result)
except Exception as e:
return error_response(code=500, message=f"识别失败: {str(e)}")

View File

@ -90,13 +90,11 @@ async def get_order_additional_fees(
return error_response(code=403, message="您无权查看该订单的加价请求")
# 获取加价请求列表
fee_requests = db.query(OrderAdditionalFeeDB).filter(
request = db.query(OrderAdditionalFeeDB).filter(
OrderAdditionalFeeDB.orderid == orderid
).order_by(OrderAdditionalFeeDB.create_time.desc()).all()
).order_by(OrderAdditionalFeeDB.create_time.desc()).first()
return success_response(data=[
OrderAdditionalFeeInfo.model_validate(req) for req in fee_requests
])
return success_response(data=OrderAdditionalFeeInfo.model_validate(request))
@router.put("/{request_id}/accept", response_model=ResponseModel)
async def accept_additional_fee(

73
app/core/ai_client.py Normal file
View File

@ -0,0 +1,73 @@
import logging
import json
import base64
from typing import Dict, Any, Optional, List
import asyncio
from app.core.config import settings
from app.core.qwen_client import qwen_client
class AIClient:
"""AI 客户端,统一包装千问和 DeepSeek"""
def __init__(self):
self.timeout = 15 # 请求超时时间(秒)
async def extract_pickup_code(self, image_url: str) -> Dict[str, Any]:
"""
从图片中提取取件码
Args:
image_content: 图片二进制内容
Returns:
Dict: 提取结果包含取件码信息
"""
try:
primary_result = await self._extract_with_qwen(image_url)
# 检查结果是否有效
if self._is_valid_result(primary_result):
return primary_result
return {"error": "处理失败", "message": "提取取件码失败"}
except Exception as e:
logging.exception(f"提取取件码异常: {str(e)}")
return {"error": "处理失败", "message": str(e)}
async def _extract_with_qwen(self, image_url: str) -> Dict[str, Any]:
"""使用千问提取取件码"""
try:
# 添加超时控制
return await asyncio.wait_for(
qwen_client.extract_pickup_code(image_url),
timeout=self.timeout
)
except asyncio.TimeoutError:
logging.error("千问 API 请求超时")
return {"error": "API请求超时", "details": "千问 API 请求超时"}
except Exception as e:
logging.exception(f"千问提取异常: {str(e)}")
return {"error": "处理失败", "message": str(e)}
def _is_valid_result(self, result: Dict[str, Any]) -> bool:
"""检查结果是否有效"""
# 检查是否有错误
if "error" in result:
return False
# 检查是否有站点信息
stations = result.get("stations", [])
if not stations:
return False
# 检查是否有取件码
for station in stations:
if station.get("pickup_codes") and len(station.get("pickup_codes", [])) > 0:
return True
return False
# 创建全局实例
ai_client = AIClient()

View File

@ -101,6 +101,15 @@ class Settings(BaseSettings):
# 反馈需求企业微信
FEEDBACK_NEED_WECOM_BOT_WEBHOOK_URL: str = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=ccd6e8d4-4c8a-45b4-9b6b-dd4cae563176"
# DeepSeek 相关配置
DEEPSEEK_API_KEY: str = "sk-9f6b56f08796435d988cf202e37f6ee3"
DEEPSEEK_API_URL: str = "https://api.deepseek.com/v1/chat/completions"
# 千问 API 配置
QWEN_API_KEY: str = "sk-caa199589f1c451aaac471fad2986e28"
QWEN_API_URL: str = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
class Config:
case_sensitive = True
env_file = ".env"

View File

@ -1,154 +0,0 @@
from tencentcloud.common import credential
from tencentcloud.common.profile.client_profile import ClientProfile
from tencentcloud.common.profile.http_profile import HttpProfile
from tencentcloud.ocr.v20181119 import ocr_client, models
from app.core.config import settings
import json
import base64
class OCRService:
def __init__(self):
cred = credential.Credential(settings.TENCENT_SECRET_ID, settings.TENCENT_SECRET_KEY)
httpProfile = HttpProfile()
httpProfile.endpoint = "ocr.tencentcloudapi.com"
clientProfile = ClientProfile()
clientProfile.httpProfile = httpProfile
self.client = ocr_client.OcrClient(cred, settings.TENCENT_REGION, clientProfile)
async def recognize_pickup_code(self, image_content: bytes) -> dict:
"""识别收件码图片"""
try:
# 将图片内容转为base64
img_base64 = base64.b64encode(image_content).decode()
req = models.GeneralAccurateOCRRequest()
req.ImageBase64 = img_base64
resp = self.client.GeneralAccurateOCR(req)
result = json.loads(resp.to_json_string())
print(result)
# 解析文本内容
text_list = []
for item in result.get("TextDetections", []):
text_list.append(item["DetectedText"])
# 提取关键信息
pickup_info = self._extract_pickup_info(text_list)
return pickup_info
except Exception as e:
raise Exception(f"识别失败: {str(e)}")
def _is_valid_pickup_code(self, text: str) -> bool:
"""验证是否是有效的取件码格式"""
import re
# 匹配格式xx-x-xxx 或 xx-xx-xxx 等类似格式
patterns = [
r'\b\d{1,2}-\d{1,2}-\d{2,3}\b', # 15-4-223
r'\b\d{4,8}\b', # 普通4-8位数字
]
for pattern in patterns:
if re.search(pattern, text):
return True
return False
def _extract_pickup_info(self, text_list: list) -> dict:
"""提取收件码信息"""
# 存储所有驿站信息
stations = []
current_station = None
current_codes = []
pickup_info = {
"stations": [], # 驿站列表
"app_type": None # APP类型(菜鸟/京东等)
}
# 识别APP类型
app_keywords = {
"菜鸟": "CAINIAO",
"京东": "JD",
"顺丰": "SF"
}
for text in text_list:
# 查找APP类型
for keyword, app_type in app_keywords.items():
if keyword in text:
pickup_info["app_type"] = app_type
break
# 查找驿站名称
is_station = False
if "驿站" in text:
is_station = True
elif "站点" in text:
is_station = True
elif "" in text:
is_station = True
elif "" in text:
is_station = True
elif "分拨" in text:
is_station = True
elif "分拣" in text:
is_station = True
elif "分拨" in text:
is_station = True
if is_station:
# 如果之前有未保存的驿站信息,先保存
if current_station and current_codes:
stations.append({
"station_name": current_station,
"pickup_codes": current_codes
})
# 开始新的驿站信息收集
current_station = text
current_codes = []
# 查找取件码
if self._is_valid_pickup_code(text):
# 清理文本中的多余字符
cleaned_text = ''.join(c for c in text if c.isdigit() or c == '-')
# 提取所有匹配的取件码
import re
for pattern in [r'\d{1,2}-\d{1,2}-\d{2,3}', r'\d{4,8}']:
matches = re.finditer(pattern, cleaned_text)
for match in matches:
code = match.group()
# 如果已找到驿站,将取件码添加到当前驿站
if current_station and code not in current_codes:
current_codes.append(code)
# 如果还没找到驿站,暂存取件码
elif code not in current_codes:
current_codes.append(code)
# 保存最后一个驿站的信息
if current_station and current_codes:
stations.append({
"station_name": current_station,
"pickup_codes": current_codes
})
# 如果有未分配到驿站的取件码,创建一个默认驿站
elif current_codes:
stations.append({
"station_name": None,
"pickup_codes": current_codes
})
# 如果找到了取件码但没找到APP类型根据取件码格式推测
if stations and not pickup_info["app_type"]:
# 如果任一取件码包含连字符,判定为菜鸟
for station in stations:
if any('-' in code for code in station["pickup_codes"]):
pickup_info["app_type"] = "CAINIAO"
break
pickup_info["stations"] = stations
return pickup_info
ocr_service = OCRService()

150
app/core/qwen_client.py Normal file
View File

@ -0,0 +1,150 @@
import logging
import json
import base64
from typing import Dict, Any, Optional, List
import re
from app.core.config import settings
# 导入 DashScope SDK
try:
from dashscope import MultiModalConversation
from dashscope.api_entities.dashscope_response import DashScopeAPIResponse
except ImportError:
logging.error("请安装 DashScope SDK: pip install dashscope")
raise
class QwenClient:
"""千问 API 客户端 (使用 DashScope SDK)"""
def __init__(self):
self.api_key = settings.QWEN_API_KEY
self.model = "qwen-vl-max" # 使用千问视觉语言大模型
async def extract_pickup_code(self, image_url: str) -> Dict[str, Any]:
"""
从图片中提取取件码
Args:
image_content: 图片二进制内容
Returns:
Dict: 提取结果包含取件码信息
"""
try:
# 构建消息
messages = [
{
"role": "system",
"content": "你是一个专门识别快递取件码的助手。请准确提取图片中的所有取件码信息。"
},
{
"role": "user",
"content": [
{
"type": "text",
"text": "请识别图中驿站的所有取件码,以[{\"station\":\"驿站名字\",\"pickup_codes\":[\"3232\",\"2323\"]}]的格式返回。只返回JSON格式数据不要其他解释。"
},
{
"type": "image",
"image": image_url
}
]
}
]
# 使用 SDK 调用 API
response = MultiModalConversation.call(
model=self.model,
messages=messages,
api_key=self.api_key,
result_format='message',
temperature=0.1,
max_tokens=1000
)
# 检查响应状态
if response.status_code != 200:
logging.error(f"千问 API 请求失败: {response.code} - {response.message}")
return {"error": "API请求失败", "details": f"{response.code}: {response.message}"}
# 记录响应
logging.info(f"千问 API 响应状态: {response.status_code}")
logging.info(f"千问 API 响应内容: {response}")
# 提取回复内容
try:
# 直接使用响应对象
# 提取消息内容 - 使用字典访问方式
output = response.get('output', {})
choices = output.get('choices', [{}])
message = choices[0].get('message', {}) if choices else {}
logging.info(f"消息: {message}")
print(f"消息: {message}")
# 获取文本内容
content = message.get('content', [])
if isinstance(content, list) and len(content) > 0:
# 提取文本内容
text_content = ""
for item in content:
if isinstance(item, dict) and 'text' in item:
text_content = item['text']
break
logging.info(f"提取的文本内容: {text_content}")
# 清理文本,移除 Markdown 代码块
text_content = text_content.strip()
# 移除 ```json 和 ``` 标记
if text_content.startswith("```json"):
text_content = text_content[7:]
elif text_content.startswith("```"):
text_content = text_content[3:]
if text_content.endswith("```"):
text_content = text_content[:-3]
text_content = text_content.strip()
logging.info(f"清理后的文本内容: {text_content}")
# 尝试解析 JSON
try:
pickup_data = json.loads(text_content)
# 确保是列表格式
if isinstance(pickup_data, list):
# 转换为统一格式
return {"stations": [{"name": item.get("station", ""), "pickup_codes": item.get("pickup_codes", [])} for item in pickup_data]}
else:
logging.warning(f"解析结果不是列表格式: {pickup_data}")
return {"stations": []}
except json.JSONDecodeError as e:
logging.error(f"JSON解析错误: {str(e)}, 原始字符串: {text_content}")
# 尝试使用正则表达式提取JSON
json_match = re.search(r'(\[{.*}\])', text_content, re.DOTALL)
if json_match:
try:
json_str = json_match.group(1)
pickup_data = json.loads(json_str)
return {"stations": [{"name": item.get("station", ""), "pickup_codes": item.get("pickup_codes", [])} for item in pickup_data]}
except Exception as je:
logging.error(f"正则提取的JSON解析错误: {str(je)}, 提取的字符串: {json_match.group(1)}")
return {"stations": []}
else:
logging.error(f"无法提取内容列表或内容列表为空: {content}")
return {"stations": []}
except Exception as e:
logging.exception(f"解析千问 API 响应失败: {str(e)}")
return {"stations": []}
except Exception as e:
logging.exception(f"调用千问 API 异常: {str(e)}")
return {"error": "处理失败", "message": str(e)}
# 创建全局实例
qwen_client = QwenClient()

View File

@ -1,6 +1,6 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.api.endpoints import wechat,user, address, community, station, order, coupon, community_building, upload, merchant, merchant_product, merchant_order, point, config, merchant_category, log, account,merchant_pay_order, message, bank_card, withdraw, mp, point_product, point_product_order, coupon_activity, ocr, dashboard, wecom, feedback, timeperiod, community_timeperiod, order_additional_fee
from app.api.endpoints import wechat,user, address, community, station, order, coupon, community_building, upload, merchant, merchant_product, merchant_order, point, config, merchant_category, log, account,merchant_pay_order, message, bank_card, withdraw, mp, point_product, point_product_order, coupon_activity, dashboard, wecom, feedback, timeperiod, community_timeperiod, order_additional_fee, ai
from app.models.database import Base, engine
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
@ -38,15 +38,9 @@ app.add_middleware(
# 添加请求日志中间件
app.add_middleware(RequestLoggerMiddleware)
@app.get("/api/info")
async def get_info():
"""获取当前环境信息"""
return {
"project": settings.PROJECT_NAME,
"debug": settings.DEBUG
}
# 添加用户路由
app.include_router(ai.router, prefix="/api/ai", tags=["AI服务"])
app.include_router(dashboard.router, prefix="/api/dashboard", tags=["仪表盘"])
app.include_router(wechat.router,prefix="/api/wechat",tags=["微信"])
app.include_router(mp.router, prefix="/api/mp", tags=["微信公众号"])
@ -77,7 +71,6 @@ app.include_router(message.router, prefix="/api/message", tags=["消息中心"])
app.include_router(upload.router, prefix="/api/upload", tags=["文件上传"])
app.include_router(config.router, prefix="/api/config", tags=["系统配置"])
app.include_router(log.router, prefix="/api/logs", tags=["系统日志"])
app.include_router(ocr.router, prefix="/api/ai/ocr", tags=["图像识别"])
app.include_router(feedback.router, prefix="/api/feedback", tags=["反馈"])
@ -85,9 +78,6 @@ app.include_router(feedback.router, prefix="/api/feedback", tags=["反馈"])
async def root():
return {"message": "欢迎使用 Beefast 蜂快到家 API"}
@app.get("/health")
async def health_check():
return {"status": "healthy"}
@app.exception_handler(Exception)

View File

@ -16,4 +16,5 @@ aiohttp==3.9.1
cryptography==42.0.2
qrcode>=7.3.1
pillow>=9.0.0
pytz==2024.1
pytz==2024.1
dashscope>=1.13.0