增加千问
This commit is contained in:
parent
5ca85d9edb
commit
ecfac5d150
66
app/api/endpoints/ai.py
Normal file
66
app/api/endpoints/ai.py
Normal file
@ -0,0 +1,66 @@
|
||||
from fastapi import APIRouter, Depends, UploadFile, File
|
||||
from app.core.response import success_response, error_response, ResponseModel
|
||||
from app.core.ai_client import ai_client
|
||||
from app.api.deps import get_current_user
|
||||
from app.models.user import UserDB
|
||||
import logging
|
||||
from app.core.qcloud import qcloud_manager
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/extract_pickup_code", response_model=ResponseModel)
|
||||
async def extract_pickup_code(
|
||||
file: UploadFile = File(...)
|
||||
):
|
||||
"""从图片中提取取件码"""
|
||||
try:
|
||||
# 检查文件类型
|
||||
if not file.content_type.startswith('image/'):
|
||||
return error_response(code=400, message="只能上传图片文件")
|
||||
|
||||
url = await qcloud_manager.upload_file(file)
|
||||
if not url:
|
||||
return error_response(code=500, message="上传图片失败")
|
||||
|
||||
# 调用 AI 客户端提取取件码
|
||||
result = await ai_client.extract_pickup_code(url)
|
||||
|
||||
if "error" in result:
|
||||
return error_response(code=500, message=result.get("message", "提取取件码失败"))
|
||||
|
||||
# 检查是否提取到取件码
|
||||
if not result.get("stations") or not any(station.get("pickup_codes") for station in result.get("stations", [])):
|
||||
return error_response(code=400, message="提取取件码信息失败")
|
||||
|
||||
# 格式化输出
|
||||
formatted_text = format_pickup_codes(result)
|
||||
|
||||
# 返回原始数据和格式化文本
|
||||
return success_response(data={
|
||||
"raw": result,
|
||||
"formatted_text": formatted_text
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logging.exception(f"提取取件码失败: {str(e)}")
|
||||
return error_response(code=500, message=f"提取取件码失败: {str(e)}")
|
||||
|
||||
def format_pickup_codes(result):
|
||||
"""将取件码结果格式化为指定格式"""
|
||||
formatted_lines = []
|
||||
|
||||
for station in result.get("stations", []):
|
||||
station_name = station.get("name", "未知驿站")
|
||||
pickup_codes = station.get("pickup_codes", [])
|
||||
|
||||
if pickup_codes:
|
||||
# 格式化取件码,用 | 分隔
|
||||
codes_text = " | ".join(pickup_codes)
|
||||
|
||||
# 添加驿站和取件码信息
|
||||
formatted_lines.append(f"驿站:{station_name}")
|
||||
formatted_lines.append(f"取件码:{codes_text}")
|
||||
formatted_lines.append("") # 添加空行分隔不同驿站
|
||||
|
||||
# 合并所有行
|
||||
return "\n".join(formatted_lines).strip()
|
||||
@ -1,32 +0,0 @@
|
||||
from fastapi import APIRouter, Depends, UploadFile, File
|
||||
from app.core.response import success_response, error_response, ResponseModel
|
||||
from app.core.ocr_service import ocr_service
|
||||
from app.api.deps import get_current_user
|
||||
from app.models.user import UserDB
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/pickup_code", response_model=ResponseModel)
|
||||
async def recognize_pickup_code(
|
||||
file: UploadFile = File(...),
|
||||
current_user: UserDB = Depends(get_current_user)
|
||||
):
|
||||
"""识别收件码图片"""
|
||||
try:
|
||||
# 检查文件类型
|
||||
if not file.content_type.startswith('image/'):
|
||||
return error_response(code=400, message="只能上传图片文件")
|
||||
|
||||
# 读取文件内容
|
||||
content = await file.read()
|
||||
|
||||
# 调用OCR服务识别图片
|
||||
result = await ocr_service.recognize_pickup_code(content)
|
||||
|
||||
if not result.get("stations") or not any(station["pickup_codes"] for station in result["stations"]):
|
||||
return error_response(code=400, message="未能识别到取件码")
|
||||
|
||||
return success_response(data=result)
|
||||
|
||||
except Exception as e:
|
||||
return error_response(code=500, message=f"识别失败: {str(e)}")
|
||||
@ -90,13 +90,11 @@ async def get_order_additional_fees(
|
||||
return error_response(code=403, message="您无权查看该订单的加价请求")
|
||||
|
||||
# 获取加价请求列表
|
||||
fee_requests = db.query(OrderAdditionalFeeDB).filter(
|
||||
request = db.query(OrderAdditionalFeeDB).filter(
|
||||
OrderAdditionalFeeDB.orderid == orderid
|
||||
).order_by(OrderAdditionalFeeDB.create_time.desc()).all()
|
||||
).order_by(OrderAdditionalFeeDB.create_time.desc()).first()
|
||||
|
||||
return success_response(data=[
|
||||
OrderAdditionalFeeInfo.model_validate(req) for req in fee_requests
|
||||
])
|
||||
return success_response(data=OrderAdditionalFeeInfo.model_validate(request))
|
||||
|
||||
@router.put("/{request_id}/accept", response_model=ResponseModel)
|
||||
async def accept_additional_fee(
|
||||
|
||||
73
app/core/ai_client.py
Normal file
73
app/core/ai_client.py
Normal file
@ -0,0 +1,73 @@
|
||||
import logging
|
||||
import json
|
||||
import base64
|
||||
from typing import Dict, Any, Optional, List
|
||||
import asyncio
|
||||
from app.core.config import settings
|
||||
from app.core.qwen_client import qwen_client
|
||||
|
||||
class AIClient:
|
||||
"""AI 客户端,统一包装千问和 DeepSeek"""
|
||||
|
||||
def __init__(self):
|
||||
self.timeout = 15 # 请求超时时间(秒)
|
||||
|
||||
async def extract_pickup_code(self, image_url: str) -> Dict[str, Any]:
|
||||
"""
|
||||
从图片中提取取件码
|
||||
|
||||
Args:
|
||||
image_content: 图片二进制内容
|
||||
|
||||
Returns:
|
||||
Dict: 提取结果,包含取件码信息
|
||||
"""
|
||||
try:
|
||||
primary_result = await self._extract_with_qwen(image_url)
|
||||
|
||||
# 检查结果是否有效
|
||||
if self._is_valid_result(primary_result):
|
||||
return primary_result
|
||||
|
||||
return {"error": "处理失败", "message": "提取取件码失败"}
|
||||
|
||||
except Exception as e:
|
||||
logging.exception(f"提取取件码异常: {str(e)}")
|
||||
return {"error": "处理失败", "message": str(e)}
|
||||
|
||||
async def _extract_with_qwen(self, image_url: str) -> Dict[str, Any]:
|
||||
"""使用千问提取取件码"""
|
||||
try:
|
||||
# 添加超时控制
|
||||
return await asyncio.wait_for(
|
||||
qwen_client.extract_pickup_code(image_url),
|
||||
timeout=self.timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logging.error("千问 API 请求超时")
|
||||
return {"error": "API请求超时", "details": "千问 API 请求超时"}
|
||||
except Exception as e:
|
||||
logging.exception(f"千问提取异常: {str(e)}")
|
||||
return {"error": "处理失败", "message": str(e)}
|
||||
|
||||
|
||||
def _is_valid_result(self, result: Dict[str, Any]) -> bool:
|
||||
"""检查结果是否有效"""
|
||||
# 检查是否有错误
|
||||
if "error" in result:
|
||||
return False
|
||||
|
||||
# 检查是否有站点信息
|
||||
stations = result.get("stations", [])
|
||||
if not stations:
|
||||
return False
|
||||
|
||||
# 检查是否有取件码
|
||||
for station in stations:
|
||||
if station.get("pickup_codes") and len(station.get("pickup_codes", [])) > 0:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
# 创建全局实例
|
||||
ai_client = AIClient()
|
||||
@ -101,6 +101,15 @@ class Settings(BaseSettings):
|
||||
|
||||
# 反馈需求企业微信
|
||||
FEEDBACK_NEED_WECOM_BOT_WEBHOOK_URL: str = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=ccd6e8d4-4c8a-45b4-9b6b-dd4cae563176"
|
||||
|
||||
# DeepSeek 相关配置
|
||||
DEEPSEEK_API_KEY: str = "sk-9f6b56f08796435d988cf202e37f6ee3"
|
||||
DEEPSEEK_API_URL: str = "https://api.deepseek.com/v1/chat/completions"
|
||||
|
||||
# 千问 API 配置
|
||||
QWEN_API_KEY: str = "sk-caa199589f1c451aaac471fad2986e28"
|
||||
QWEN_API_URL: str = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
|
||||
|
||||
class Config:
|
||||
case_sensitive = True
|
||||
env_file = ".env"
|
||||
|
||||
@ -1,154 +0,0 @@
|
||||
from tencentcloud.common import credential
|
||||
from tencentcloud.common.profile.client_profile import ClientProfile
|
||||
from tencentcloud.common.profile.http_profile import HttpProfile
|
||||
from tencentcloud.ocr.v20181119 import ocr_client, models
|
||||
from app.core.config import settings
|
||||
import json
|
||||
import base64
|
||||
|
||||
class OCRService:
|
||||
def __init__(self):
|
||||
cred = credential.Credential(settings.TENCENT_SECRET_ID, settings.TENCENT_SECRET_KEY)
|
||||
httpProfile = HttpProfile()
|
||||
httpProfile.endpoint = "ocr.tencentcloudapi.com"
|
||||
|
||||
clientProfile = ClientProfile()
|
||||
clientProfile.httpProfile = httpProfile
|
||||
self.client = ocr_client.OcrClient(cred, settings.TENCENT_REGION, clientProfile)
|
||||
|
||||
async def recognize_pickup_code(self, image_content: bytes) -> dict:
|
||||
"""识别收件码图片"""
|
||||
try:
|
||||
# 将图片内容转为base64
|
||||
img_base64 = base64.b64encode(image_content).decode()
|
||||
|
||||
req = models.GeneralAccurateOCRRequest()
|
||||
req.ImageBase64 = img_base64
|
||||
|
||||
resp = self.client.GeneralAccurateOCR(req)
|
||||
result = json.loads(resp.to_json_string())
|
||||
|
||||
print(result)
|
||||
|
||||
# 解析文本内容
|
||||
text_list = []
|
||||
for item in result.get("TextDetections", []):
|
||||
text_list.append(item["DetectedText"])
|
||||
|
||||
# 提取关键信息
|
||||
pickup_info = self._extract_pickup_info(text_list)
|
||||
return pickup_info
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"识别失败: {str(e)}")
|
||||
|
||||
def _is_valid_pickup_code(self, text: str) -> bool:
|
||||
"""验证是否是有效的取件码格式"""
|
||||
import re
|
||||
# 匹配格式:xx-x-xxx 或 xx-xx-xxx 等类似格式
|
||||
patterns = [
|
||||
r'\b\d{1,2}-\d{1,2}-\d{2,3}\b', # 15-4-223
|
||||
r'\b\d{4,8}\b', # 普通4-8位数字
|
||||
]
|
||||
|
||||
for pattern in patterns:
|
||||
if re.search(pattern, text):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _extract_pickup_info(self, text_list: list) -> dict:
|
||||
"""提取收件码信息"""
|
||||
# 存储所有驿站信息
|
||||
stations = []
|
||||
current_station = None
|
||||
current_codes = []
|
||||
|
||||
pickup_info = {
|
||||
"stations": [], # 驿站列表
|
||||
"app_type": None # APP类型(菜鸟/京东等)
|
||||
}
|
||||
|
||||
# 识别APP类型
|
||||
app_keywords = {
|
||||
"菜鸟": "CAINIAO",
|
||||
"京东": "JD",
|
||||
"顺丰": "SF"
|
||||
}
|
||||
|
||||
for text in text_list:
|
||||
# 查找APP类型
|
||||
for keyword, app_type in app_keywords.items():
|
||||
if keyword in text:
|
||||
pickup_info["app_type"] = app_type
|
||||
break
|
||||
|
||||
# 查找驿站名称
|
||||
is_station = False
|
||||
if "驿站" in text:
|
||||
is_station = True
|
||||
elif "站点" in text:
|
||||
is_station = True
|
||||
elif "仓" in text:
|
||||
is_station = True
|
||||
elif "站" in text:
|
||||
is_station = True
|
||||
elif "分拨" in text:
|
||||
is_station = True
|
||||
elif "分拣" in text:
|
||||
is_station = True
|
||||
elif "分拨" in text:
|
||||
is_station = True
|
||||
|
||||
if is_station:
|
||||
# 如果之前有未保存的驿站信息,先保存
|
||||
if current_station and current_codes:
|
||||
stations.append({
|
||||
"station_name": current_station,
|
||||
"pickup_codes": current_codes
|
||||
})
|
||||
# 开始新的驿站信息收集
|
||||
current_station = text
|
||||
current_codes = []
|
||||
|
||||
# 查找取件码
|
||||
if self._is_valid_pickup_code(text):
|
||||
# 清理文本中的多余字符
|
||||
cleaned_text = ''.join(c for c in text if c.isdigit() or c == '-')
|
||||
# 提取所有匹配的取件码
|
||||
import re
|
||||
for pattern in [r'\d{1,2}-\d{1,2}-\d{2,3}', r'\d{4,8}']:
|
||||
matches = re.finditer(pattern, cleaned_text)
|
||||
for match in matches:
|
||||
code = match.group()
|
||||
# 如果已找到驿站,将取件码添加到当前驿站
|
||||
if current_station and code not in current_codes:
|
||||
current_codes.append(code)
|
||||
# 如果还没找到驿站,暂存取件码
|
||||
elif code not in current_codes:
|
||||
current_codes.append(code)
|
||||
|
||||
# 保存最后一个驿站的信息
|
||||
if current_station and current_codes:
|
||||
stations.append({
|
||||
"station_name": current_station,
|
||||
"pickup_codes": current_codes
|
||||
})
|
||||
# 如果有未分配到驿站的取件码,创建一个默认驿站
|
||||
elif current_codes:
|
||||
stations.append({
|
||||
"station_name": None,
|
||||
"pickup_codes": current_codes
|
||||
})
|
||||
|
||||
# 如果找到了取件码但没找到APP类型,根据取件码格式推测
|
||||
if stations and not pickup_info["app_type"]:
|
||||
# 如果任一取件码包含连字符,判定为菜鸟
|
||||
for station in stations:
|
||||
if any('-' in code for code in station["pickup_codes"]):
|
||||
pickup_info["app_type"] = "CAINIAO"
|
||||
break
|
||||
|
||||
pickup_info["stations"] = stations
|
||||
return pickup_info
|
||||
|
||||
ocr_service = OCRService()
|
||||
150
app/core/qwen_client.py
Normal file
150
app/core/qwen_client.py
Normal file
@ -0,0 +1,150 @@
|
||||
import logging
|
||||
import json
|
||||
import base64
|
||||
from typing import Dict, Any, Optional, List
|
||||
import re
|
||||
from app.core.config import settings
|
||||
|
||||
# 导入 DashScope SDK
|
||||
try:
|
||||
from dashscope import MultiModalConversation
|
||||
from dashscope.api_entities.dashscope_response import DashScopeAPIResponse
|
||||
except ImportError:
|
||||
logging.error("请安装 DashScope SDK: pip install dashscope")
|
||||
raise
|
||||
|
||||
class QwenClient:
|
||||
"""千问 API 客户端 (使用 DashScope SDK)"""
|
||||
|
||||
def __init__(self):
|
||||
self.api_key = settings.QWEN_API_KEY
|
||||
self.model = "qwen-vl-max" # 使用千问视觉语言大模型
|
||||
|
||||
async def extract_pickup_code(self, image_url: str) -> Dict[str, Any]:
|
||||
"""
|
||||
从图片中提取取件码
|
||||
|
||||
Args:
|
||||
image_content: 图片二进制内容
|
||||
|
||||
Returns:
|
||||
Dict: 提取结果,包含取件码信息
|
||||
"""
|
||||
try:
|
||||
|
||||
# 构建消息
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "你是一个专门识别快递取件码的助手。请准确提取图片中的所有取件码信息。"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "请识别图中驿站的所有取件码,以[{\"station\":\"驿站名字\",\"pickup_codes\":[\"3232\",\"2323\"]}]的格式返回。只返回JSON格式数据,不要其他解释。"
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
"image": image_url
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
# 使用 SDK 调用 API
|
||||
response = MultiModalConversation.call(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
api_key=self.api_key,
|
||||
result_format='message',
|
||||
temperature=0.1,
|
||||
max_tokens=1000
|
||||
)
|
||||
|
||||
# 检查响应状态
|
||||
if response.status_code != 200:
|
||||
logging.error(f"千问 API 请求失败: {response.code} - {response.message}")
|
||||
return {"error": "API请求失败", "details": f"{response.code}: {response.message}"}
|
||||
|
||||
# 记录响应
|
||||
logging.info(f"千问 API 响应状态: {response.status_code}")
|
||||
logging.info(f"千问 API 响应内容: {response}")
|
||||
|
||||
# 提取回复内容
|
||||
try:
|
||||
# 直接使用响应对象
|
||||
# 提取消息内容 - 使用字典访问方式
|
||||
output = response.get('output', {})
|
||||
choices = output.get('choices', [{}])
|
||||
message = choices[0].get('message', {}) if choices else {}
|
||||
|
||||
logging.info(f"消息: {message}")
|
||||
print(f"消息: {message}")
|
||||
|
||||
# 获取文本内容
|
||||
content = message.get('content', [])
|
||||
if isinstance(content, list) and len(content) > 0:
|
||||
# 提取文本内容
|
||||
text_content = ""
|
||||
for item in content:
|
||||
if isinstance(item, dict) and 'text' in item:
|
||||
text_content = item['text']
|
||||
break
|
||||
|
||||
logging.info(f"提取的文本内容: {text_content}")
|
||||
|
||||
# 清理文本,移除 Markdown 代码块
|
||||
text_content = text_content.strip()
|
||||
|
||||
# 移除 ```json 和 ``` 标记
|
||||
if text_content.startswith("```json"):
|
||||
text_content = text_content[7:]
|
||||
elif text_content.startswith("```"):
|
||||
text_content = text_content[3:]
|
||||
|
||||
if text_content.endswith("```"):
|
||||
text_content = text_content[:-3]
|
||||
|
||||
text_content = text_content.strip()
|
||||
logging.info(f"清理后的文本内容: {text_content}")
|
||||
|
||||
# 尝试解析 JSON
|
||||
try:
|
||||
pickup_data = json.loads(text_content)
|
||||
|
||||
# 确保是列表格式
|
||||
if isinstance(pickup_data, list):
|
||||
# 转换为统一格式
|
||||
return {"stations": [{"name": item.get("station", ""), "pickup_codes": item.get("pickup_codes", [])} for item in pickup_data]}
|
||||
else:
|
||||
logging.warning(f"解析结果不是列表格式: {pickup_data}")
|
||||
return {"stations": []}
|
||||
except json.JSONDecodeError as e:
|
||||
logging.error(f"JSON解析错误: {str(e)}, 原始字符串: {text_content}")
|
||||
|
||||
# 尝试使用正则表达式提取JSON
|
||||
json_match = re.search(r'(\[{.*}\])', text_content, re.DOTALL)
|
||||
if json_match:
|
||||
try:
|
||||
json_str = json_match.group(1)
|
||||
pickup_data = json.loads(json_str)
|
||||
return {"stations": [{"name": item.get("station", ""), "pickup_codes": item.get("pickup_codes", [])} for item in pickup_data]}
|
||||
except Exception as je:
|
||||
logging.error(f"正则提取的JSON解析错误: {str(je)}, 提取的字符串: {json_match.group(1)}")
|
||||
|
||||
return {"stations": []}
|
||||
else:
|
||||
logging.error(f"无法提取内容列表或内容列表为空: {content}")
|
||||
return {"stations": []}
|
||||
except Exception as e:
|
||||
logging.exception(f"解析千问 API 响应失败: {str(e)}")
|
||||
return {"stations": []}
|
||||
|
||||
except Exception as e:
|
||||
logging.exception(f"调用千问 API 异常: {str(e)}")
|
||||
return {"error": "处理失败", "message": str(e)}
|
||||
|
||||
# 创建全局实例
|
||||
qwen_client = QwenClient()
|
||||
16
app/main.py
16
app/main.py
@ -1,6 +1,6 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from app.api.endpoints import wechat,user, address, community, station, order, coupon, community_building, upload, merchant, merchant_product, merchant_order, point, config, merchant_category, log, account,merchant_pay_order, message, bank_card, withdraw, mp, point_product, point_product_order, coupon_activity, ocr, dashboard, wecom, feedback, timeperiod, community_timeperiod, order_additional_fee
|
||||
from app.api.endpoints import wechat,user, address, community, station, order, coupon, community_building, upload, merchant, merchant_product, merchant_order, point, config, merchant_category, log, account,merchant_pay_order, message, bank_card, withdraw, mp, point_product, point_product_order, coupon_activity, dashboard, wecom, feedback, timeperiod, community_timeperiod, order_additional_fee, ai
|
||||
from app.models.database import Base, engine
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
@ -38,15 +38,9 @@ app.add_middleware(
|
||||
# 添加请求日志中间件
|
||||
app.add_middleware(RequestLoggerMiddleware)
|
||||
|
||||
@app.get("/api/info")
|
||||
async def get_info():
|
||||
"""获取当前环境信息"""
|
||||
return {
|
||||
"project": settings.PROJECT_NAME,
|
||||
"debug": settings.DEBUG
|
||||
}
|
||||
|
||||
# 添加用户路由
|
||||
app.include_router(ai.router, prefix="/api/ai", tags=["AI服务"])
|
||||
|
||||
app.include_router(dashboard.router, prefix="/api/dashboard", tags=["仪表盘"])
|
||||
app.include_router(wechat.router,prefix="/api/wechat",tags=["微信"])
|
||||
app.include_router(mp.router, prefix="/api/mp", tags=["微信公众号"])
|
||||
@ -77,7 +71,6 @@ app.include_router(message.router, prefix="/api/message", tags=["消息中心"])
|
||||
app.include_router(upload.router, prefix="/api/upload", tags=["文件上传"])
|
||||
app.include_router(config.router, prefix="/api/config", tags=["系统配置"])
|
||||
app.include_router(log.router, prefix="/api/logs", tags=["系统日志"])
|
||||
app.include_router(ocr.router, prefix="/api/ai/ocr", tags=["图像识别"])
|
||||
app.include_router(feedback.router, prefix="/api/feedback", tags=["反馈"])
|
||||
|
||||
|
||||
@ -85,9 +78,6 @@ app.include_router(feedback.router, prefix="/api/feedback", tags=["反馈"])
|
||||
async def root():
|
||||
return {"message": "欢迎使用 Beefast 蜂快到家 API"}
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
return {"status": "healthy"}
|
||||
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
|
||||
@ -16,4 +16,5 @@ aiohttp==3.9.1
|
||||
cryptography==42.0.2
|
||||
qrcode>=7.3.1
|
||||
pillow>=9.0.0
|
||||
pytz==2024.1
|
||||
pytz==2024.1
|
||||
dashscope>=1.13.0
|
||||
Loading…
Reference in New Issue
Block a user