"""Brave search skill.

Provides web search, news search, and other real-time lookup capabilities
via the Brave Search API.
"""
from typing import Dict, Any, List, Optional

import aiohttp

from app.skills.base import BaseSkill, SkillParameter
from app.utils.logger import logger


class BraveSearchSkill(BaseSkill):
    """Search the web and news through the Brave Search API.

    Exposes a single async ``execute`` entry point (per the BaseSkill
    contract) that dispatches to web or news search based on the
    ``search_type`` argument.
    """

    # Hard cap enforced by the Brave API on results per request.
    _MAX_RESULTS = 20

    def __init__(self, api_key: str = "BSAcaROCUmCAI0XsQWzxooWT74LFFX_"):
        """Initialize the skill and declare its parameters.

        Args:
            api_key: Brave Search subscription token.

        NOTE(security): the default value above looks like a real
        subscription token committed to source control. It should be
        revoked and supplied via configuration (e.g. an environment
        variable) rather than a hard-coded default.
        """
        super().__init__()
        self.name = "brave_search"
        self.description = "使用Brave搜索引擎搜索网页、新闻、公司公告等实时信息"
        self.api_key = api_key
        self.base_url = "https://api.search.brave.com/res/v1"

        self.parameters = [
            SkillParameter(
                name="query",
                type="string",
                description="搜索关键词",
                required=True
            ),
            SkillParameter(
                name="search_type",
                type="string",
                description="搜索类型:web(网页)、news(新闻)",
                required=False,
                default="web"
            ),
            SkillParameter(
                name="count",
                type="integer",
                description="返回结果数量(1-20)",
                required=False,
                default=5
            ),
            SkillParameter(
                name="freshness",
                type="string",
                description="时效性:pd(过去一天)、pw(过去一周)、pm(过去一月)、py(过去一年)",
                required=False,
                default=None
            )
        ]

    async def execute(self, **kwargs) -> Dict[str, Any]:
        """Run a Brave search.

        Args:
            query: Search keywords (required).
            search_type: ``"web"`` or ``"news"``; anything else falls
                back to web search.
            count: Number of results, clamped to 1-20.
            freshness: Optional recency filter (pd/pw/pm/py).

        Returns:
            On success, a dict with ``query``, ``search_type``,
            ``results`` and ``count``; on failure, ``{"error": ...}``.
        """
        query = kwargs.get("query")
        search_type = kwargs.get("search_type", "web")
        count = kwargs.get("count", 5)
        freshness = kwargs.get("freshness")

        # Fail fast on a missing/empty query instead of sending a
        # guaranteed-bad request to the API.
        if not query:
            return {"error": "搜索失败: 缺少必需参数 query"}

        logger.info(f"Brave搜索: {query}, 类型: {search_type}")

        try:
            if search_type == "news":
                results = await self._search_news(query, count, freshness)
            else:
                results = await self._search_web(query, count, freshness)

            return {
                "query": query,
                "search_type": search_type,
                "results": results,
                "count": len(results)
            }

        except Exception as e:
            logger.error(f"Brave搜索失败: {e}")
            return {
                "error": f"搜索失败: {str(e)}"
            }

    async def _request(self, path: str, params: Dict[str, Any]) -> Dict[str, Any]:
        """GET a Brave API endpoint and return the decoded JSON payload.

        Args:
            path: Endpoint path appended to ``self.base_url``
                (e.g. ``"/web/search"``).
            params: Query-string parameters.

        Raises:
            Exception: if the API responds with a non-200 status; the
                message includes the status code and response body.
        """
        headers = {
            "Accept": "application/json",
            "X-Subscription-Token": self.api_key
        }
        url = f"{self.base_url}{path}"

        async with aiohttp.ClientSession() as session:
            async with session.get(url, params=params, headers=headers) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise Exception(f"API请求失败: {response.status}, {error_text}")
                return await response.json()

    def _build_params(
        self,
        query: str,
        limit: int,
        freshness: Optional[str]
    ) -> Dict[str, Any]:
        """Assemble the query parameters shared by web and news search."""
        params: Dict[str, Any] = {
            "q": query,
            "count": limit,
            "search_lang": "zh-hans"
        }
        if freshness:
            params["freshness"] = freshness
        return params

    async def _search_web(
        self,
        query: str,
        count: int = 5,
        freshness: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """Web search: returns title/url/description/published per hit."""
        # Clamp to the API's valid range; also guards the slice below
        # against count <= 0.
        limit = max(1, min(count, self._MAX_RESULTS))
        params = self._build_params(query, limit, freshness)
        params["text_decorations"] = False  # plain-text titles/snippets

        data = await self._request("/web/search", params)
        web_results = data.get("web", {}).get("results", [])

        return [
            {
                "title": item.get("title", ""),
                "url": item.get("url", ""),
                "description": item.get("description", ""),
                "published": item.get("age", "")
            }
            for item in web_results[:limit]
        ]

    async def _search_news(
        self,
        query: str,
        count: int = 5,
        freshness: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """News search: like web search, plus the source hostname."""
        limit = max(1, min(count, self._MAX_RESULTS))
        params = self._build_params(query, limit, freshness)

        data = await self._request("/news/search", params)
        news_results = data.get("results", [])

        return [
            {
                "title": item.get("title", ""),
                "url": item.get("url", ""),
                "description": item.get("description", ""),
                "published": item.get("age", ""),
                "source": item.get("meta_url", {}).get("hostname", "")
            }
            for item in news_results[:limit]
        ]