"""News database service."""
from datetime import datetime, timedelta
from typing import List, Dict, Any, Optional

from sqlalchemy import create_engine, and_, or_
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.exc import IntegrityError

from app.models.news import NewsArticle
from app.models.database import Base
from app.config import get_settings
from app.utils.logger import logger


class NewsDatabaseService:
    """Persistence layer for news articles.

    Handles saving, de-duplication, analysis/notification flags,
    querying and simple statistics over the ``NewsArticle`` table.
    """

    def __init__(self):
        self.settings = get_settings()
        self.engine = None
        self.SessionLocal = None
        self._init_db()

    def _init_db(self):
        """Initialize the engine, session factory and tables.

        Raises:
            Exception: any initialization failure is logged and re-raised
                so ``SessionLocal`` is never silently left as ``None``.
        """
        try:
            # Resolve the database URL from settings, falling back to a
            # local SQLite file when neither attribute is configured.
            if hasattr(self.settings, 'database_url'):
                database_url = self.settings.database_url
            elif hasattr(self.settings, 'database_path'):
                database_url = f"sqlite:///{self.settings.database_path}"
            else:
                # Default path
                database_url = "sqlite:///./backend/stock_agent.db"

            self.engine = create_engine(
                database_url,
                # SQLite refuses cross-thread use of a connection by default;
                # this service may be called from multiple threads.
                connect_args={"check_same_thread": False},
                echo=False
            )
            self.SessionLocal = sessionmaker(
                autocommit=False, autoflush=False, bind=self.engine
            )

            # Create tables if they do not exist yet.
            # (NewsArticle is already imported at module level.)
            NewsArticle.metadata.create_all(self.engine, checkfirst=True)

            logger.info("新闻数据库服务初始化完成")
        except Exception as e:
            logger.error(f"新闻数据库初始化失败: {e}")
            import traceback
            logger.error(traceback.format_exc())
            # Re-raise so callers never operate with SessionLocal == None.
            raise

    def get_session(self) -> Session:
        """Return a new database session; the caller must close it."""
        return self.SessionLocal()

    def save_article(self, article_data: Dict[str, Any]) -> Optional[NewsArticle]:
        """Save a single article.

        Args:
            article_data: Article fields, passed directly to ``NewsArticle``.

        Returns:
            The persisted article, or ``None`` when the URL already exists
            (unique-constraint violation) or any other error occurred.
        """
        session = self.get_session()
        try:
            article = NewsArticle(**article_data)
            session.add(article)
            session.commit()
            session.refresh(article)
            logger.debug(f"文章保存成功: {article.title[:50]}...")
            return article
        except IntegrityError:
            # Duplicate URL (unique constraint) — expected, not an error.
            session.rollback()
            logger.debug(f"文章已存在(URL 重复): {article_data.get('url', '')}")
            return None
        except Exception as e:
            session.rollback()
            logger.error(f"保存文章失败: {e}")
            return None
        finally:
            session.close()

    def check_duplicate_by_hash(self, content_hash: str, hours: int = 24) -> bool:
        """Check whether a content hash was seen recently.

        Args:
            content_hash: Hash of the article content.
            hours: Look-back window in hours.

        Returns:
            ``True`` if a matching article exists within the window.
        """
        session = self.get_session()
        try:
            since = datetime.utcnow() - timedelta(hours=hours)
            count = session.query(NewsArticle).filter(
                and_(
                    NewsArticle.content_hash == content_hash,
                    NewsArticle.created_at >= since
                )
            ).count()
            return count > 0
        finally:
            session.close()

    def mark_as_analyzed(
        self,
        article_id: int,
        analysis: Dict[str, Any],
        priority: float
    ) -> bool:
        """Store LLM analysis results on an article and mark it analyzed.

        Args:
            article_id: Article primary key.
            analysis: LLM analysis result dict.
            priority: Priority score to assign.

        Returns:
            ``True`` on success, ``False`` if the article is missing or the
            update failed.
        """
        session = self.get_session()
        try:
            article = session.query(NewsArticle).filter(
                NewsArticle.id == article_id
            ).first()

            if not article:
                logger.warning(f"文章不存在: {article_id}")
                return False

            article.llm_analyzed = True
            article.market_impact = analysis.get('market_impact')
            article.impact_type = analysis.get('impact_type')
            article.sentiment = analysis.get('sentiment')
            article.summary = analysis.get('summary')
            article.key_points = analysis.get('key_points')
            article.trading_advice = analysis.get('trading_advice')
            article.relevant_symbols = analysis.get('relevant_symbols')
            # LLM output may carry an explicit null confidence; .get() with a
            # default does not cover that case, so guard against None to avoid
            # a TypeError on division. Fall back to 70 (same default as before).
            confidence = analysis.get('confidence')
            article.quality_score = (70 if confidence is None else confidence) / 100
            article.priority = priority

            session.commit()
            logger.debug(f"文章分析结果已保存: {article.title[:50]}...")
            return True
        except Exception as e:
            session.rollback()
            logger.error(f"保存分析结果失败: {e}")
            return False
        finally:
            session.close()

    def mark_as_notified(self, article_id: int, channel: str = 'feishu') -> bool:
        """Mark an article as having had its notification sent.

        Args:
            article_id: Article primary key.
            channel: Notification channel name.

        Returns:
            ``True`` on success, ``False`` if the article is missing or the
            update failed.
        """
        session = self.get_session()
        try:
            article = session.query(NewsArticle).filter(
                NewsArticle.id == article_id
            ).first()

            if not article:
                return False

            article.notified = True
            article.notification_sent_at = datetime.utcnow()
            article.notification_channel = channel

            session.commit()
            return True
        except Exception as e:
            session.rollback()
            logger.error(f"标记通知状态失败: {e}")
            return False
        finally:
            session.close()

    def get_high_priority_articles(
        self,
        limit: int = 20,
        min_priority: float = 40.0,
        hours: int = 24
    ) -> List[NewsArticle]:
        """Return analyzed, not-yet-notified articles above a priority score.

        Args:
            limit: Maximum number of articles to return.
            min_priority: Minimum priority score (inclusive).
            hours: Look-back window in hours.

        Returns:
            Articles ordered by priority, highest first.
        """
        session = self.get_session()
        try:
            since = datetime.utcnow() - timedelta(hours=hours)
            articles = session.query(NewsArticle).filter(
                and_(
                    NewsArticle.llm_analyzed == True,
                    NewsArticle.priority >= min_priority,
                    NewsArticle.created_at >= since,
                    NewsArticle.notified == False
                )
            ).order_by(NewsArticle.priority.desc()).limit(limit).all()
            return articles
        finally:
            session.close()

    def get_latest_articles(
        self,
        category: Optional[str] = None,
        limit: int = 50,
        hours: int = 24
    ) -> List[Dict[str, Any]]:
        """Return the latest articles as dicts.

        Args:
            category: Optional category filter.
            limit: Maximum number of articles to return.
            hours: Look-back window in hours.

        Returns:
            Article dicts (via ``NewsArticle.to_dict``), newest first.
        """
        session = self.get_session()
        try:
            since = datetime.utcnow() - timedelta(hours=hours)
            query = session.query(NewsArticle).filter(
                NewsArticle.created_at >= since
            )
            if category:
                query = query.filter(NewsArticle.category == category)

            articles = query.order_by(
                NewsArticle.created_at.desc()
            ).limit(limit).all()
            return [article.to_dict() for article in articles]
        finally:
            session.close()

    def get_stats(self, hours: int = 24) -> Dict[str, Any]:
        """Return simple counts over the recent window.

        Args:
            hours: Look-back window in hours.

        Returns:
            Dict with ``total_articles``, ``analyzed``, ``high_impact``,
            ``notified`` and ``hours``.
        """
        session = self.get_session()
        try:
            since = datetime.utcnow() - timedelta(hours=hours)

            total = session.query(NewsArticle).filter(
                NewsArticle.created_at >= since
            ).count()

            analyzed = session.query(NewsArticle).filter(
                and_(
                    NewsArticle.created_at >= since,
                    NewsArticle.llm_analyzed == True
                )
            ).count()

            high_impact = session.query(NewsArticle).filter(
                and_(
                    NewsArticle.created_at >= since,
                    NewsArticle.market_impact == 'high'
                )
            ).count()

            notified = session.query(NewsArticle).filter(
                and_(
                    NewsArticle.created_at >= since,
                    NewsArticle.notified == True
                )
            ).count()

            return {
                'total_articles': total,
                'analyzed': analyzed,
                'high_impact': high_impact,
                'notified': notified,
                'hours': hours
            }
        finally:
            session.close()
def get_unanalyzed_articles(self, limit: int = 50, hours: int = 24) -> List[NewsArticle]: """ 获取未分析的文章 Args: limit: 返回数量限制 hours: 查询最近多少小时 Returns: 未分析的文章列表 """ session = self.get_session() try: since = datetime.utcnow() - timedelta(hours=hours) articles = session.query(NewsArticle).filter( and_( NewsArticle.llm_analyzed == False, NewsArticle.created_at >= since ) ).order_by(NewsArticle.created_at.desc()).limit(limit).all() return articles finally: session.close() def clean_old_articles(self, days: int = 7) -> int: """ 清理旧文章(设置为不活跃) Args: days: 保留多少天的文章 Returns: 清理的数量 """ session = self.get_session() try: before = datetime.utcnow() - timedelta(days=days) count = session.query(NewsArticle).filter( NewsArticle.created_at < before ).update({ 'is_active': False }) session.commit() if count > 0: logger.info(f"清理了 {count} 条旧文章") return count except Exception as e: session.rollback() logger.error(f"清理旧文章失败: {e}") return 0 finally: session.close() # 全局实例 _news_db_service = None def get_news_db_service() -> NewsDatabaseService: """获取新闻数据库服务单例""" global _news_db_service if _news_db_service is None: _news_db_service = NewsDatabaseService() return _news_db_service