stock-ai-agent/backend/app/news_agent/news_db_service.py
"""
新闻数据库服务
"""
from datetime import datetime, timedelta
from typing import List, Dict, Any, Optional
from sqlalchemy import create_engine, and_, or_
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.exc import IntegrityError
from app.models.news import NewsArticle
from app.models.database import Base
from app.config import get_settings
from app.utils.logger import logger
class NewsDatabaseService:
    """News database service."""

    def __init__(self):
        self.settings = get_settings()
        self.engine = None
        self.SessionLocal = None
        self._init_db()

    def _init_db(self):
        """Initialize the database connection."""
        try:
            # Use settings.database_url if present, otherwise build a SQLite path
            if hasattr(self.settings, 'database_url'):
                database_url = self.settings.database_url
            elif hasattr(self.settings, 'database_path'):
                database_url = f"sqlite:///{self.settings.database_path}"
            else:
                # Default path
                database_url = "sqlite:///./backend/stock_agent.db"
            self.engine = create_engine(
                database_url,
                # check_same_thread=False is SQLite-specific; it lets the
                # engine be used from multiple threads
                connect_args={"check_same_thread": False},
                echo=False
            )
            self.SessionLocal = sessionmaker(
                autocommit=False,
                autoflush=False,
                bind=self.engine
            )
            # Create tables if they do not exist
            NewsArticle.metadata.create_all(self.engine, checkfirst=True)
            logger.info("News database service initialized")
        except Exception as e:
            logger.error(f"News database initialization failed: {e}")
            import traceback
            logger.error(traceback.format_exc())
            # Re-raise so SessionLocal is never left as None
            raise

    def get_session(self) -> Session:
        """Get a database session."""
        return self.SessionLocal()

    def save_article(self, article_data: Dict[str, Any]) -> Optional[NewsArticle]:
        """
        Save a single article.

        Args:
            article_data: Article data dictionary

        Returns:
            The saved article object, or None
        """
        session = self.get_session()
        try:
            article = NewsArticle(**article_data)
            session.add(article)
            session.commit()
            session.refresh(article)
            logger.debug(f"Article saved: {article.title[:50]}...")
            return article
        except IntegrityError:
            session.rollback()
            logger.debug(f"Article already exists (duplicate URL): {article_data.get('url', '')}")
            return None
        except Exception as e:
            session.rollback()
            logger.error(f"Failed to save article: {e}")
            return None
        finally:
            session.close()

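    # Sketch of the shape save_article() expects. Only the fields referenced
    # elsewhere in this service (title, url, content_hash, category) are known
    # here; any further fields depend on the NewsArticle model definition:
    #
    #   service.save_article({
    #       "title": "Fed holds rates steady",
    #       "url": "https://example.com/news/123",  # duplicates roll back via IntegrityError
    #       "content_hash": "ab12...",              # used for dedup, see below
    #       "category": "macro",
    #   })
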
    def check_duplicate_by_hash(self, content_hash: str, hours: int = 24) -> bool:
        """
        Check whether a content hash is a duplicate.

        Args:
            content_hash: Content hash
            hours: How many recent hours to check

        Returns:
            True if a duplicate exists
        """
        session = self.get_session()
        try:
            since = datetime.utcnow() - timedelta(hours=hours)
            count = session.query(NewsArticle).filter(
                and_(
                    NewsArticle.content_hash == content_hash,
                    NewsArticle.created_at >= since
                )
            ).count()
            return count > 0
        finally:
            session.close()

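    # How content_hash is produced is up to the caller; this service only
    # compares stored hashes. A minimal sketch, assuming a SHA-256 digest of
    # the normalized article text:
    #
    #   import hashlib
    #   content_hash = hashlib.sha256(text.strip().encode("utf-8")).hexdigest()
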
    def mark_as_analyzed(
        self,
        article_id: int,
        analysis: Dict[str, Any],
        priority: float
    ) -> bool:
        """
        Mark an article as analyzed.

        Args:
            article_id: Article ID
            analysis: LLM analysis result
            priority: Priority score

        Returns:
            Whether the update succeeded
        """
        session = self.get_session()
        try:
            article = session.query(NewsArticle).filter(
                NewsArticle.id == article_id
            ).first()
            if not article:
                logger.warning(f"Article not found: {article_id}")
                return False
            article.llm_analyzed = True
            article.market_impact = analysis.get('market_impact')
            article.impact_type = analysis.get('impact_type')
            article.sentiment = analysis.get('sentiment')
            article.summary = analysis.get('summary')
            article.key_points = analysis.get('key_points')
            article.trading_advice = analysis.get('trading_advice')
            article.relevant_symbols = analysis.get('relevant_symbols')
            # Confidence arrives as a 0-100 score; store it as 0-1
            article.quality_score = analysis.get('confidence', 70) / 100
            article.priority = priority
            session.commit()
            logger.debug(f"Analysis result saved for article: {article.title[:50]}...")
            return True
        except Exception as e:
            session.rollback()
            logger.error(f"Failed to save analysis result: {e}")
            return False
        finally:
            session.close()

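    # The keys mark_as_analyzed() reads from `analysis`; the values shown are
    # illustrative assumptions, except 'high', which matches the market_impact
    # filter used in get_stats():
    #
    #   analysis = {
    #       "market_impact": "high",
    #       "impact_type": "monetary_policy",
    #       "sentiment": "bearish",
    #       "summary": "...",
    #       "key_points": ["..."],
    #       "trading_advice": "...",
    #       "relevant_symbols": ["SPY"],
    #       "confidence": 85,          # 0-100; stored as quality_score 0-1
    #   }
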
    def mark_as_notified(self, article_id: int, channel: str = 'feishu') -> bool:
        """
        Mark an article as notified.

        Args:
            article_id: Article ID
            channel: Notification channel

        Returns:
            Whether the update succeeded
        """
        session = self.get_session()
        try:
            article = session.query(NewsArticle).filter(
                NewsArticle.id == article_id
            ).first()
            if not article:
                return False
            article.notified = True
            article.notification_sent_at = datetime.utcnow()
            article.notification_channel = channel
            session.commit()
            return True
        except Exception as e:
            session.rollback()
            logger.error(f"Failed to mark notification status: {e}")
            return False
        finally:
            session.close()

    def get_high_priority_articles(
        self,
        limit: int = 20,
        min_priority: float = 40.0,
        hours: int = 24
    ) -> List[NewsArticle]:
        """
        Get high-priority articles that have not been notified yet.

        Args:
            limit: Maximum number of results
            min_priority: Minimum priority score
            hours: How many recent hours to query

        Returns:
            List of articles
        """
        session = self.get_session()
        try:
            since = datetime.utcnow() - timedelta(hours=hours)
            articles = session.query(NewsArticle).filter(
                and_(
                    NewsArticle.llm_analyzed == True,
                    NewsArticle.priority >= min_priority,
                    NewsArticle.created_at >= since,
                    NewsArticle.notified == False
                )
            ).order_by(NewsArticle.priority.desc()).limit(limit).all()
            return articles
        finally:
            session.close()

    def get_latest_articles(
        self,
        category: Optional[str] = None,
        limit: int = 50,
        hours: int = 24
    ) -> List[Dict[str, Any]]:
        """
        Get the latest articles.

        Args:
            category: Category filter
            limit: Maximum number of results
            hours: How many recent hours to query

        Returns:
            List of article dictionaries
        """
        session = self.get_session()
        try:
            since = datetime.utcnow() - timedelta(hours=hours)
            query = session.query(NewsArticle).filter(
                NewsArticle.created_at >= since
            )
            if category:
                query = query.filter(NewsArticle.category == category)
            articles = query.order_by(
                NewsArticle.created_at.desc()
            ).limit(limit).all()
            return [article.to_dict() for article in articles]
        finally:
            session.close()

    def get_stats(self, hours: int = 24) -> Dict[str, Any]:
        """
        Get statistics.

        Args:
            hours: How many recent hours to aggregate

        Returns:
            Statistics dictionary
        """
        session = self.get_session()
        try:
            since = datetime.utcnow() - timedelta(hours=hours)
            total = session.query(NewsArticle).filter(
                NewsArticle.created_at >= since
            ).count()
            analyzed = session.query(NewsArticle).filter(
                and_(
                    NewsArticle.created_at >= since,
                    NewsArticle.llm_analyzed == True
                )
            ).count()
            high_impact = session.query(NewsArticle).filter(
                and_(
                    NewsArticle.created_at >= since,
                    NewsArticle.market_impact == 'high'
                )
            ).count()
            notified = session.query(NewsArticle).filter(
                and_(
                    NewsArticle.created_at >= since,
                    NewsArticle.notified == True
                )
            ).count()
            return {
                'total_articles': total,
                'analyzed': analyzed,
                'high_impact': high_impact,
                'notified': notified,
                'hours': hours
            }
        finally:
            session.close()

    def get_unanalyzed_articles(self, limit: int = 50, hours: int = 24) -> List[NewsArticle]:
        """
        Get articles that have not been analyzed yet.

        Args:
            limit: Maximum number of results
            hours: How many recent hours to query

        Returns:
            List of unanalyzed articles
        """
        session = self.get_session()
        try:
            since = datetime.utcnow() - timedelta(hours=hours)
            articles = session.query(NewsArticle).filter(
                and_(
                    NewsArticle.llm_analyzed == False,
                    NewsArticle.created_at >= since
                )
            ).order_by(NewsArticle.created_at.desc()).limit(limit).all()
            return articles
        finally:
            session.close()

    def clean_old_articles(self, days: int = 7) -> int:
        """
        Clean up old articles (mark them as inactive).

        Args:
            days: How many days of articles to keep

        Returns:
            Number of articles cleaned up
        """
        session = self.get_session()
        try:
            before = datetime.utcnow() - timedelta(days=days)
            count = session.query(NewsArticle).filter(
                NewsArticle.created_at < before
            ).update({
                'is_active': False
            })
            session.commit()
            if count > 0:
                logger.info(f"Cleaned up {count} old articles")
            return count
        except Exception as e:
            session.rollback()
            logger.error(f"Failed to clean up old articles: {e}")
            return 0
        finally:
            session.close()

# Global instance
_news_db_service = None


def get_news_db_service() -> NewsDatabaseService:
    """Get the news database service singleton."""
    global _news_db_service
    if _news_db_service is None:
        _news_db_service = NewsDatabaseService()
    return _news_db_service
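

# Typical pipeline usage, as a minimal smoke-test sketch. The article fields
# beyond title/url/content_hash/category are assumptions about the NewsArticle
# model, and the analysis values are illustrative only:
if __name__ == "__main__":
    service = get_news_db_service()
    service.save_article({
        "title": "Example headline",
        "url": "https://example.com/news/1",
        "content_hash": "deadbeef",
        "category": "macro",
    })
    for pending in service.get_unanalyzed_articles(limit=10):
        service.mark_as_analyzed(
            pending.id,
            analysis={"market_impact": "low", "sentiment": "neutral", "confidence": 60},
            priority=10.0,
        )
    print(service.get_stats(hours=24))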