314 lines
10 KiB
Python
314 lines
10 KiB
Python
"""
新闻智能体 - 主控制器

实时抓取、分析、通知重要新闻
(News agent - main controller: fetches, analyzes, and sends notifications
about important news in real time.)
"""
|
||
import asyncio
|
||
from typing import Dict, Any, List, Optional
|
||
from datetime import datetime, timedelta
|
||
|
||
from app.utils.logger import logger
|
||
from app.config import get_settings
|
||
from app.news_agent.sources import get_enabled_sources
|
||
from app.news_agent.fetcher import NewsFetcher, NewsItem
|
||
from app.news_agent.filter import NewsDeduplicator, NewsFilter
|
||
from app.news_agent.analyzer import NewsAnalyzer, NewsAnalyzerSimple
|
||
from app.news_agent.news_db_service import get_news_db_service
|
||
from app.news_agent.notifier import get_news_notifier
|
||
|
||
|
||
class NewsAgent:
    """News agent - main controller.

    Process-wide singleton that periodically fetches news, de-duplicates and
    filters it, persists new articles, analyzes them (LLM batch analysis or a
    rule-based analyzer) and pushes batch notifications for items whose
    analysis reports ``market_impact == 'high'``.
    """

    # The single shared instance handed out by every ``NewsAgent()`` call.
    _instance = None
    # Flipped to True only at the END of a successful ``__init__`` so that a
    # failed construction can be retried instead of leaving a half-built
    # singleton behind.
    _initialized = False

    def __new__(cls, *args, **kwargs):
        """Singleton: return the shared instance, creating it on first call."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """Build components, default configuration and counters.

        ``__new__`` always returns the same object, so ``__init__`` runs on
        every ``NewsAgent()`` call; the ``_initialized`` guard turns every call
        after the first successful one into a no-op.
        """
        if NewsAgent._initialized:
            return

        self.settings = get_settings()

        # Core components
        self.fetcher = NewsFetcher()
        self.deduplicator = NewsDeduplicator()
        self.filter = NewsFilter()
        self.analyzer = NewsAnalyzer()  # LLM analyzer
        self.simple_analyzer = NewsAnalyzerSimple()  # rule-based analyzer (fallback)
        self.db_service = get_news_db_service()
        self.notifier = get_news_notifier()

        # Configuration
        self.fetch_interval = 300  # seconds between fetch cycles (5 minutes)
        # NOTE(review): min_priority is not read anywhere in this class;
        # notifications are currently selected by market_impact == 'high' only.
        self.min_priority = 40.0
        self.use_llm = True  # True: LLM batch analysis; False: rule-based analysis
        self.max_notify_batch = 10  # max number of articles pushed per cycle

        # Counters exposed via get_stats()
        self.stats = {
            'total_fetched': 0,
            'total_saved': 0,
            'total_analyzed': 0,
            'total_notified': 0,
            'last_fetch_time': None,
            'last_notify_time': None
        }

        # Runtime state
        self.running = False
        self._task = None

        # Only now is every attribute above guaranteed to exist.
        NewsAgent._initialized = True
        logger.info("新闻智能体初始化完成")

    async def start(self):
        """Send the startup notification and launch the background loop.

        Logs a warning and returns if already running. If anything fails
        before the background task is created, ``running`` is reset so that
        ``start()`` can be retried, and the exception is re-raised.
        """
        if self.running:
            logger.warning("新闻智能体已在运行")
            return

        self.running = True
        try:
            # Startup notification with per-category source counts
            sources = get_enabled_sources()
            crypto_count = sum(1 for s in sources if s['category'] == 'crypto')
            stock_count = sum(1 for s in sources if s['category'] == 'stock')

            await self.notifier.notify_startup({
                'crypto_sources': crypto_count,
                'stock_sources': stock_count,
                'fetch_interval': self.fetch_interval
            })

            # Launch the background loop
            self._task = asyncio.create_task(self._run_loop())
        except BaseException:
            # Roll back. Otherwise the agent stays flagged as running with no
            # loop behind it, and every later start() is refused.
            self.running = False
            raise

        logger.info("新闻智能体已启动")

    async def stop(self):
        """Cancel the background loop and close the fetcher (no-op if not running)."""
        if not self.running:
            return

        self.running = False

        if self._task:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass
            except Exception as e:
                # The task may already have died with an error; awaiting it
                # re-raises that error. Log it and still close the fetcher.
                logger.error(f"新闻智能体后台任务异常退出: {e}")
            self._task = None

        await self.fetcher.close()

        logger.info("新闻智能体已停止")

    async def _run_loop(self):
        """Main loop: run one processing cycle, then sleep ``fetch_interval`` seconds."""
        while self.running:
            try:
                await self._fetch_and_process_news()
            except Exception as e:
                logger.error(f"新闻处理循环出错: {e}")
                try:
                    await self.notifier.notify_error(str(e))
                except Exception as notify_err:
                    # The error notification is best-effort; a failure here
                    # must not terminate the background loop.
                    logger.error(f"发送错误通知失败: {notify_err}")

            # Wait for the next fetch
            await asyncio.sleep(self.fetch_interval)

    async def _fetch_and_process_news(self):
        """Run one fetch -> dedupe -> filter -> save -> analyze -> notify cycle."""
        logger.info("=" * 60)
        logger.info("开始新闻处理周期")

        # 1-3. Fetch, de-duplicate, filter
        filtered_items = await self._fetch_and_filter()
        if not filtered_items:
            return

        # 4. Persist articles not already stored
        saved_articles = self._save_new_articles(filtered_items)
        if not saved_articles:
            return

        # 5. Analyze; collect notification payloads for high-impact articles
        high_priority_articles = await self._analyze_articles(saved_articles)

        # 6. Batch notification
        if high_priority_articles:
            await self._notify_high_impact(high_priority_articles)

        logger.info("=" * 60)

    async def _fetch_and_filter(self) -> List[Any]:
        """Fetch all news, update fetch counters, then de-duplicate and filter.

        Returns the filtered items, or an empty/falsy value when nothing
        survived.
        """
        items = await self.fetcher.fetch_all_news()
        self.stats['total_fetched'] += len(items)
        self.stats['last_fetch_time'] = datetime.utcnow().isoformat()

        if not items:
            logger.info("没有获取到新新闻")
            return []

        logger.info(f"获取到 {len(items)} 条新闻")

        items = self.deduplicator.deduplicate_list(items)
        logger.info(f"去重后剩余 {len(items)} 条")

        filtered_items = self.filter.filter_news(items)
        logger.info(f"过滤后剩余 {len(filtered_items)} 条")

        if not filtered_items:
            logger.info("没有符合条件的新闻")
        return filtered_items

    def _save_new_articles(self, items) -> List[Any]:
        """Save items whose content hash is not in the DB yet.

        Returns a list of ``(article, item)`` pairs for rows that were saved.
        """
        saved_articles = []
        for item in items:
            # Skip anything already stored (matched by content hash)
            if self.db_service.check_duplicate_by_hash(item.content_hash):
                continue

            article_data = {
                'title': item.title,
                'content': item.content,
                'url': item.url,
                'source': item.source,
                'author': item.author,
                'category': item.category,
                'tags': item.tags,
                'published_at': item.published_at,
                'crawled_at': item.crawled_at,
                'content_hash': item.content_hash,
                'quality_score': getattr(item, 'quality_score', 0.5),
            }

            article = self.db_service.save_article(article_data)
            if article:
                saved_articles.append((article, item))

        self.stats['total_saved'] += len(saved_articles)
        logger.info(f"保存了 {len(saved_articles)} 条新文章")
        return saved_articles

    async def _analyze_articles(self, saved_articles) -> List[Dict[str, Any]]:
        """Analyze saved articles and mark them analyzed in the DB.

        Uses the LLM batch analyzer when ``use_llm`` is set, otherwise the
        rule-based analyzer. Returns notification payloads for articles whose
        result has ``market_impact == 'high'``.
        """
        analyzed_count = 0
        high_priority_articles = []

        if self.use_llm:
            # Batch analysis only (async)
            items_to_analyze = [item for _, item in saved_articles]
            results = list(await self.analyzer.analyze_batch(items_to_analyze) or [])
            if len(results) != len(items_to_analyze):
                # zip() below silently drops the unmatched tail; make it visible.
                logger.warning(
                    f"LLM 分析结果数量({len(results)})与文章数量({len(items_to_analyze)})不一致"
                )

            for (article, _), result in zip(saved_articles, results):
                if not result:
                    continue
                priority = self.analyzer.calculate_priority(
                    result,
                    getattr(article, 'quality_score', 0.5)
                )
                self.db_service.mark_as_analyzed(article.id, result, priority)
                analyzed_count += 1

                # Only high-impact news is notified
                if result.get('market_impact') == 'high':
                    high_priority_articles.append(
                        self._build_notification(article, result, priority, llm_analyzed=True)
                    )
        else:
            # Rule-based analysis
            for article, item in saved_articles:
                result = self.simple_analyzer.analyze_single(item)
                if not result:
                    continue
                priority = result.get('confidence', 50)

                self.db_service.mark_as_analyzed(article.id, result, priority)
                analyzed_count += 1

                # Only high-impact news is notified. This path did NOT use
                # the LLM, so the payload must say so.
                if result.get('market_impact') == 'high':
                    high_priority_articles.append(
                        self._build_notification(article, result, priority, llm_analyzed=False)
                    )

        self.stats['total_analyzed'] += analyzed_count
        logger.info(f"分析了 {analyzed_count} 条文章")
        return high_priority_articles

    @staticmethod
    def _build_notification(article, result: Dict[str, Any], priority: Any,
                            llm_analyzed: bool) -> Dict[str, Any]:
        """Merge an article's ``to_dict()`` with its analysis fields for the notifier."""
        article_dict = article.to_dict()
        article_dict.update({
            'llm_analyzed': llm_analyzed,
            'market_impact': result.get('market_impact'),
            'impact_type': result.get('impact_type'),
            'sentiment': result.get('sentiment'),
            'summary': result.get('summary'),
            'key_points': result.get('key_points'),
            'trading_advice': result.get('trading_advice'),
            'relevant_symbols': result.get('relevant_symbols'),
            'priority': priority,
        })
        return article_dict

    async def _notify_high_impact(self, articles: List[Dict[str, Any]]):
        """Send the top ``max_notify_batch`` articles by priority in one batch.

        Marks each sent article as notified. Articles beyond the batch limit
        are not sent and not marked.
        """
        # Highest priority first; a missing/None priority sorts as 0.
        articles.sort(key=lambda x: x.get('priority') or 0, reverse=True)
        batch = articles[:self.max_notify_batch]

        await self.notifier.notify_news_batch(batch)
        for article in batch:
            self.db_service.mark_as_notified(article['id'])
            self.stats['total_notified'] += 1

        self.stats['last_notify_time'] = datetime.utcnow().isoformat()

    def get_stats(self) -> Dict[str, Any]:
        """Return a copy of the counters plus runtime config and 24h DB stats."""
        stats = self.stats.copy()
        stats['running'] = self.running
        stats['fetch_interval'] = self.fetch_interval
        stats['use_llm'] = self.use_llm

        # Additional statistics from the database (last 24 hours)
        db_stats = self.db_service.get_stats(hours=24)
        stats['db_stats'] = db_stats

        return stats

    async def manual_fetch(self, category: Optional[str] = None) -> Dict[str, Any]:
        """Manually trigger a news fetch.

        Only fetches; it does not run dedupe/save/analyze/notify and does not
        update ``stats``.

        Args:
            category: optional category filter passed to the fetcher.

        Returns:
            ``{'fetched': count, 'timestamp': iso_utc}`` plus ``'items'`` (the
            first 5 items as dicts) when anything was fetched.
        """
        logger.info(f"手动触发新闻抓取: category={category}")

        items = await self.fetcher.fetch_all_news(category)

        result = {
            'fetched': len(items),
            'timestamp': datetime.utcnow().isoformat()
        }

        if items:
            # The full processing pipeline could be triggered here;
            # for simplicity only a preview of the fetched items is returned.
            result['items'] = [item.to_dict() for item in items[:5]]

        return result
||
|
||
|
||
# Module-level handle to the shared agent, created lazily on first access.
_news_agent = None


def get_news_agent() -> NewsAgent:
    """Return the process-wide NewsAgent, constructing it on the first call."""
    global _news_agent
    if _news_agent is not None:
        return _news_agent
    _news_agent = NewsAgent()
    return _news_agent
|