stock-ai-agent/backend/app/news_agent/news_agent.py
2026-02-26 19:48:57 +08:00

314 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
新闻智能体 - 主控制器
实时抓取、分析、通知重要新闻
"""
import asyncio
from typing import Dict, Any, List, Optional
from datetime import datetime, timedelta
from app.utils.logger import logger
from app.config import get_settings
from app.news_agent.sources import get_enabled_sources
from app.news_agent.fetcher import NewsFetcher, NewsItem
from app.news_agent.filter import NewsDeduplicator, NewsFilter
from app.news_agent.analyzer import NewsAnalyzer, NewsAnalyzerSimple
from app.news_agent.news_db_service import get_news_db_service
from app.news_agent.notifier import get_news_notifier
class NewsAgent:
    """News agent - main controller (process-wide singleton).

    Runs a periodic background loop that fetches news from all enabled
    sources, deduplicates and filters the items, persists new articles,
    analyzes them (LLM batch analysis with a rule-based fallback), and
    sends batched notifications for high market-impact articles.
    """

    _instance = None      # singleton instance holder
    _initialized = False  # guards against repeated __init__ on the singleton

    def __new__(cls, *args, **kwargs):
        """Singleton: every construction returns the same instance."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """Initialize components, configuration and counters (runs only once)."""
        if NewsAgent._initialized:
            return
        NewsAgent._initialized = True
        self.settings = get_settings()
        # Core components
        self.fetcher = NewsFetcher()
        self.deduplicator = NewsDeduplicator()
        self.filter = NewsFilter()
        self.analyzer = NewsAnalyzer()               # LLM analyzer
        self.simple_analyzer = NewsAnalyzerSimple()  # rule-based analyzer (fallback)
        self.db_service = get_news_db_service()
        self.notifier = get_news_notifier()
        # Configuration
        self.fetch_interval = 300  # fetch interval in seconds (5 minutes)
        # NOTE(review): min_priority is currently unused -- notification is
        # gated solely on market_impact == 'high'; confirm before removing.
        self.min_priority = 40.0
        self.use_llm = True  # prefer LLM batch analysis over rule-based analysis
        # Runtime counters exposed via get_stats()
        self.stats = {
            'total_fetched': 0,
            'total_saved': 0,
            'total_analyzed': 0,
            'total_notified': 0,
            'last_fetch_time': None,
            'last_notify_time': None
        }
        # Runtime state
        self.running = False
        self._task: Optional[asyncio.Task] = None
        logger.info("新闻智能体初始化完成")

    async def start(self):
        """Start the agent: send a startup notification and spawn the loop."""
        if self.running:
            logger.warning("新闻智能体已在运行")
            return
        self.running = True
        # Summarize enabled sources for the startup notification.
        sources = get_enabled_sources()
        crypto_count = sum(1 for s in sources if s['category'] == 'crypto')
        stock_count = sum(1 for s in sources if s['category'] == 'stock')
        await self.notifier.notify_startup({
            'crypto_sources': crypto_count,
            'stock_sources': stock_count,
            'fetch_interval': self.fetch_interval
        })
        # Launch the background processing loop.
        self._task = asyncio.create_task(self._run_loop())
        logger.info("新闻智能体已启动")

    async def stop(self):
        """Stop the agent: cancel the loop and release fetcher resources."""
        if not self.running:
            return
        self.running = False
        if self._task:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass  # expected when the loop task is cancelled
        await self.fetcher.close()
        logger.info("新闻智能体已停止")

    async def _run_loop(self):
        """Main loop: run one processing cycle, then sleep for fetch_interval."""
        while self.running:
            try:
                await self._fetch_and_process_news()
            except Exception as e:
                # Keep the loop alive on any single-cycle failure; report it.
                logger.error(f"新闻处理循环出错: {e}")
                await self.notifier.notify_error(str(e))
            # Wait until the next fetch cycle.
            await asyncio.sleep(self.fetch_interval)

    async def _fetch_and_process_news(self):
        """One full cycle: fetch -> dedup -> filter -> save -> analyze -> notify."""
        logger.info("=" * 60)
        logger.info("开始新闻处理周期")
        # 1. Fetch
        items = await self.fetcher.fetch_all_news()
        self.stats['total_fetched'] += len(items)
        self.stats['last_fetch_time'] = datetime.utcnow().isoformat()
        if not items:
            logger.info("没有获取到新新闻")
            return
        logger.info(f"获取到 {len(items)} 条新闻")
        # 2. In-memory deduplication
        items = self.deduplicator.deduplicate_list(items)
        logger.info(f"去重后剩余 {len(items)}")
        # 3. Relevance filtering
        filtered_items = self.filter.filter_news(items)
        logger.info(f"过滤后剩余 {len(filtered_items)}")
        if not filtered_items:
            logger.info("没有符合条件的新闻")
            return
        # 4. Persist articles that are not already in the database
        saved_articles = self._save_new_articles(filtered_items)
        self.stats['total_saved'] += len(saved_articles)
        logger.info(f"保存了 {len(saved_articles)} 条新文章")
        if not saved_articles:
            return
        # 5. Analysis (LLM batch, or rule-based fallback)
        if self.use_llm:
            analyzed_count, high_priority = await self._analyze_with_llm(saved_articles)
        else:
            analyzed_count, high_priority = self._analyze_with_rules(saved_articles)
        self.stats['total_analyzed'] += analyzed_count
        logger.info(f"分析了 {analyzed_count} 条文章")
        # 6. Batched notification of high-impact articles
        await self._notify_high_priority(high_priority)
        logger.info("=" * 60)

    def _save_new_articles(self, filtered_items):
        """Persist items not yet in the DB; return (db_article, item) pairs."""
        saved = []
        for item in filtered_items:
            # Skip items already stored (content-hash duplicate check).
            if self.db_service.check_duplicate_by_hash(item.content_hash):
                continue
            article_data = {
                'title': item.title,
                'content': item.content,
                'url': item.url,
                'source': item.source,
                'author': item.author,
                'category': item.category,
                'tags': item.tags,
                'published_at': item.published_at,
                'crawled_at': item.crawled_at,
                'content_hash': item.content_hash,
                'quality_score': getattr(item, 'quality_score', 0.5),
            }
            article = self.db_service.save_article(article_data)
            if article:
                saved.append((article, item))
        return saved

    @staticmethod
    def _build_notify_payload(article, result, priority, llm_analyzed):
        """Merge a DB article with its analysis result into a notification dict.

        Args:
            article: persisted article object exposing to_dict().
            result: analysis result mapping (market_impact, summary, ...).
            priority: computed notification priority.
            llm_analyzed: whether the result came from the LLM analyzer.
        """
        payload = article.to_dict()
        payload.update({
            'llm_analyzed': llm_analyzed,
            'market_impact': result.get('market_impact'),
            'impact_type': result.get('impact_type'),
            'sentiment': result.get('sentiment'),
            'summary': result.get('summary'),
            'key_points': result.get('key_points'),
            'trading_advice': result.get('trading_advice'),
            'relevant_symbols': result.get('relevant_symbols'),
            'priority': priority,
        })
        return payload

    async def _analyze_with_llm(self, saved_articles):
        """Batch-analyze saved items with the LLM analyzer.

        Returns:
            (analyzed_count, high_priority_payloads)
        """
        analyzed = 0
        high_priority = []
        items_to_analyze = [item for _, item in saved_articles]
        results = await self.analyzer.analyze_batch(items_to_analyze)
        for (article, _), result in zip(saved_articles, results):
            if not result:
                continue  # analysis failed for this item; leave it unmarked
            priority = self.analyzer.calculate_priority(
                result,
                getattr(article, 'quality_score', 0.5)
            )
            self.db_service.mark_as_analyzed(article.id, result, priority)
            analyzed += 1
            # Only high market-impact news is queued for notification.
            if result.get('market_impact') == 'high':
                high_priority.append(
                    self._build_notify_payload(article, result, priority, llm_analyzed=True)
                )
        return analyzed, high_priority

    def _analyze_with_rules(self, saved_articles):
        """Analyze saved items with the rule-based fallback analyzer.

        Returns:
            (analyzed_count, high_priority_payloads)
        """
        analyzed = 0
        high_priority = []
        for article, item in saved_articles:
            result = self.simple_analyzer.analyze_single(item)
            priority = result.get('confidence', 50)
            self.db_service.mark_as_analyzed(article.id, result, priority)
            analyzed += 1
            if result.get('market_impact') == 'high':
                # Fix: rule-based results were previously mislabeled llm_analyzed=True.
                high_priority.append(
                    self._build_notify_payload(article, result, priority, llm_analyzed=False)
                )
        return analyzed, high_priority

    async def _notify_high_priority(self, payloads):
        """Send one batched notification for up to 10 highest-priority articles."""
        if not payloads:
            return
        payloads.sort(key=lambda p: p.get('priority', 0), reverse=True)
        top = payloads[:10]
        await self.notifier.notify_news_batch(top)
        for payload in top:
            self.db_service.mark_as_notified(payload['id'])
            self.stats['total_notified'] += 1
        self.stats['last_notify_time'] = datetime.utcnow().isoformat()

    def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot of runtime counters plus 24h database statistics."""
        stats = self.stats.copy()
        stats['running'] = self.running
        stats['fetch_interval'] = self.fetch_interval
        stats['use_llm'] = self.use_llm
        # Enrich with aggregate statistics from the database.
        stats['db_stats'] = self.db_service.get_stats(hours=24)
        return stats

    async def manual_fetch(self, category: Optional[str] = None) -> Dict[str, Any]:
        """Manually trigger a one-off news fetch.

        Args:
            category: optional category filter passed through to the fetcher.

        Returns:
            Dict with the fetched count, a timestamp, and (when any items
            were fetched) a preview of up to 5 items.
        """
        logger.info(f"手动触发新闻抓取: category={category}")
        items = await self.fetcher.fetch_all_news(category)
        result = {
            'fetched': len(items),
            'timestamp': datetime.utcnow().isoformat()
        }
        if items:
            # Preview only; the full processing pipeline is not triggered here.
            result['items'] = [item.to_dict() for item in items[:5]]
        return result
# Module-level handle for the lazily created agent singleton.
_news_agent = None


def get_news_agent() -> NewsAgent:
    """Return the shared NewsAgent instance, constructing it on first use."""
    global _news_agent
    if _news_agent is None:
        _news_agent = NewsAgent()
    return _news_agent