"""News database service."""
|
||
from datetime import datetime, timedelta
|
||
from typing import List, Dict, Any, Optional
|
||
from sqlalchemy import create_engine, and_, or_
|
||
from sqlalchemy.orm import sessionmaker, Session
|
||
from sqlalchemy.exc import IntegrityError
|
||
|
||
from app.models.news import NewsArticle
|
||
from app.models.database import Base
|
||
from app.config import get_settings
|
||
from app.utils.logger import logger
|
||
|
||
|
||
class NewsDatabaseService:
    """Database access layer for news articles."""

    def __init__(self):
        """Load settings, then build the engine and session factory."""
        self.settings = get_settings()
        # Both attributes are populated by _init_db(); declared here so the
        # instance shape is obvious even if initialisation raises.
        self.engine = None
        self.SessionLocal = None
        self._init_db()
||
def _init_db(self):
|
||
"""初始化数据库连接"""
|
||
try:
|
||
# 使用 settings.database_url 或构建路径
|
||
if hasattr(self.settings, 'database_url'):
|
||
database_url = self.settings.database_url
|
||
elif hasattr(self.settings, 'database_path'):
|
||
database_url = f"sqlite:///{self.settings.database_path}"
|
||
else:
|
||
# 默认路径
|
||
database_url = "sqlite:///./backend/stock_agent.db"
|
||
|
||
self.engine = create_engine(
|
||
database_url,
|
||
connect_args={"check_same_thread": False},
|
||
echo=False
|
||
)
|
||
|
||
self.SessionLocal = sessionmaker(
|
||
autocommit=False,
|
||
autoflush=False,
|
||
bind=self.engine
|
||
)
|
||
|
||
# 创建表(如果不存在)
|
||
from app.models.news import NewsArticle
|
||
NewsArticle.metadata.create_all(self.engine, checkfirst=True)
|
||
|
||
logger.info("新闻数据库服务初始化完成")
|
||
|
||
except Exception as e:
|
||
logger.error(f"新闻数据库初始化失败: {e}")
|
||
import traceback
|
||
logger.error(traceback.format_exc())
|
||
# 重新抛出异常,避免 SessionLocal 为 None
|
||
raise
|
||
|
||
def get_session(self) -> Session:
|
||
"""获取数据库会话"""
|
||
return self.SessionLocal()
|
||
|
||
def save_article(self, article_data: Dict[str, Any]) -> Optional[NewsArticle]:
|
||
"""
|
||
保存单篇文章
|
||
|
||
Args:
|
||
article_data: 文章数据字典
|
||
|
||
Returns:
|
||
保存的文章对象或 None
|
||
"""
|
||
session = self.get_session()
|
||
try:
|
||
article = NewsArticle(**article_data)
|
||
session.add(article)
|
||
session.commit()
|
||
session.refresh(article)
|
||
|
||
logger.debug(f"文章保存成功: {article.title[:50]}...")
|
||
return article
|
||
|
||
except IntegrityError as e:
|
||
session.rollback()
|
||
logger.debug(f"文章已存在(URL 重复): {article_data.get('url', '')}")
|
||
return None
|
||
|
||
except Exception as e:
|
||
session.rollback()
|
||
logger.error(f"保存文章失败: {e}")
|
||
return None
|
||
|
||
finally:
|
||
session.close()
|
||
|
||
def check_duplicate_by_hash(self, content_hash: str, hours: int = 24) -> bool:
|
||
"""
|
||
检查内容哈希是否重复
|
||
|
||
Args:
|
||
content_hash: 内容哈希
|
||
hours: 检查最近多少小时
|
||
|
||
Returns:
|
||
True 如果重复
|
||
"""
|
||
session = self.get_session()
|
||
try:
|
||
since = datetime.utcnow() - timedelta(hours=hours)
|
||
|
||
count = session.query(NewsArticle).filter(
|
||
and_(
|
||
NewsArticle.content_hash == content_hash,
|
||
NewsArticle.created_at >= since
|
||
)
|
||
).count()
|
||
|
||
return count > 0
|
||
|
||
finally:
|
||
session.close()
|
||
|
||
def mark_as_analyzed(
|
||
self,
|
||
article_id: int,
|
||
analysis: Dict[str, Any],
|
||
priority: float
|
||
) -> bool:
|
||
"""
|
||
标记文章已分析
|
||
|
||
Args:
|
||
article_id: 文章 ID
|
||
analysis: LLM 分析结果
|
||
priority: 优先级分数
|
||
|
||
Returns:
|
||
是否成功
|
||
"""
|
||
session = self.get_session()
|
||
try:
|
||
article = session.query(NewsArticle).filter(
|
||
NewsArticle.id == article_id
|
||
).first()
|
||
|
||
if not article:
|
||
logger.warning(f"文章不存在: {article_id}")
|
||
return False
|
||
|
||
article.llm_analyzed = True
|
||
article.market_impact = analysis.get('market_impact')
|
||
article.impact_type = analysis.get('impact_type')
|
||
article.sentiment = analysis.get('sentiment')
|
||
article.summary = analysis.get('summary')
|
||
article.key_points = analysis.get('key_points')
|
||
article.trading_advice = analysis.get('trading_advice')
|
||
article.relevant_symbols = analysis.get('relevant_symbols')
|
||
article.quality_score = analysis.get('confidence', 70) / 100
|
||
article.priority = priority
|
||
|
||
session.commit()
|
||
|
||
logger.debug(f"文章分析结果已保存: {article.title[:50]}...")
|
||
return True
|
||
|
||
except Exception as e:
|
||
session.rollback()
|
||
logger.error(f"保存分析结果失败: {e}")
|
||
return False
|
||
|
||
finally:
|
||
session.close()
|
||
|
||
def mark_as_notified(self, article_id: int, channel: str = 'feishu') -> bool:
|
||
"""
|
||
标记文章已发送通知
|
||
|
||
Args:
|
||
article_id: 文章 ID
|
||
channel: 通知渠道
|
||
|
||
Returns:
|
||
是否成功
|
||
"""
|
||
session = self.get_session()
|
||
try:
|
||
article = session.query(NewsArticle).filter(
|
||
NewsArticle.id == article_id
|
||
).first()
|
||
|
||
if not article:
|
||
return False
|
||
|
||
article.notified = True
|
||
article.notification_sent_at = datetime.utcnow()
|
||
article.notification_channel = channel
|
||
|
||
session.commit()
|
||
return True
|
||
|
||
except Exception as e:
|
||
session.rollback()
|
||
logger.error(f"标记通知状态失败: {e}")
|
||
return False
|
||
|
||
finally:
|
||
session.close()
|
||
|
||
def get_high_priority_articles(
|
||
self,
|
||
limit: int = 20,
|
||
min_priority: float = 40.0,
|
||
hours: int = 24
|
||
) -> List[NewsArticle]:
|
||
"""
|
||
获取高优先级文章
|
||
|
||
Args:
|
||
limit: 返回数量限制
|
||
min_priority: 最低优先级分数
|
||
hours: 查询最近多少小时
|
||
|
||
Returns:
|
||
文章列表
|
||
"""
|
||
session = self.get_session()
|
||
try:
|
||
since = datetime.utcnow() - timedelta(hours=hours)
|
||
|
||
articles = session.query(NewsArticle).filter(
|
||
and_(
|
||
NewsArticle.llm_analyzed == True,
|
||
NewsArticle.priority >= min_priority,
|
||
NewsArticle.created_at >= since,
|
||
NewsArticle.notified == False
|
||
)
|
||
).order_by(NewsArticle.priority.desc()).limit(limit).all()
|
||
|
||
return articles
|
||
|
||
finally:
|
||
session.close()
|
||
|
||
def get_latest_articles(
|
||
self,
|
||
category: str = None,
|
||
limit: int = 50,
|
||
hours: int = 24
|
||
) -> List[Dict[str, Any]]:
|
||
"""
|
||
获取最新文章
|
||
|
||
Args:
|
||
category: 分类过滤
|
||
limit: 返回数量限制
|
||
hours: 查询最近多少小时
|
||
|
||
Returns:
|
||
文章字典列表
|
||
"""
|
||
session = self.get_session()
|
||
try:
|
||
since = datetime.utcnow() - timedelta(hours=hours)
|
||
|
||
query = session.query(NewsArticle).filter(
|
||
NewsArticle.created_at >= since
|
||
)
|
||
|
||
if category:
|
||
query = query.filter(NewsArticle.category == category)
|
||
|
||
articles = query.order_by(
|
||
NewsArticle.created_at.desc()
|
||
).limit(limit).all()
|
||
|
||
return [article.to_dict() for article in articles]
|
||
|
||
finally:
|
||
session.close()
|
||
|
||
def get_stats(self, hours: int = 24) -> Dict[str, Any]:
|
||
"""
|
||
获取统计数据
|
||
|
||
Args:
|
||
hours: 统计最近多少小时
|
||
|
||
Returns:
|
||
统计数据
|
||
"""
|
||
session = self.get_session()
|
||
try:
|
||
since = datetime.utcnow() - timedelta(hours=hours)
|
||
|
||
total = session.query(NewsArticle).filter(
|
||
NewsArticle.created_at >= since
|
||
).count()
|
||
|
||
analyzed = session.query(NewsArticle).filter(
|
||
and_(
|
||
NewsArticle.created_at >= since,
|
||
NewsArticle.llm_analyzed == True
|
||
)
|
||
).count()
|
||
|
||
high_impact = session.query(NewsArticle).filter(
|
||
and_(
|
||
NewsArticle.created_at >= since,
|
||
NewsArticle.market_impact == 'high'
|
||
)
|
||
).count()
|
||
|
||
notified = session.query(NewsArticle).filter(
|
||
and_(
|
||
NewsArticle.created_at >= since,
|
||
NewsArticle.notified == True
|
||
)
|
||
).count()
|
||
|
||
return {
|
||
'total_articles': total,
|
||
'analyzed': analyzed,
|
||
'high_impact': high_impact,
|
||
'notified': notified,
|
||
'hours': hours
|
||
}
|
||
|
||
finally:
|
||
session.close()
|
||
|
||
def get_unanalyzed_articles(self, limit: int = 50, hours: int = 24) -> List[NewsArticle]:
|
||
"""
|
||
获取未分析的文章
|
||
|
||
Args:
|
||
limit: 返回数量限制
|
||
hours: 查询最近多少小时
|
||
|
||
Returns:
|
||
未分析的文章列表
|
||
"""
|
||
session = self.get_session()
|
||
try:
|
||
since = datetime.utcnow() - timedelta(hours=hours)
|
||
|
||
articles = session.query(NewsArticle).filter(
|
||
and_(
|
||
NewsArticle.llm_analyzed == False,
|
||
NewsArticle.created_at >= since
|
||
)
|
||
).order_by(NewsArticle.created_at.desc()).limit(limit).all()
|
||
|
||
return articles
|
||
|
||
finally:
|
||
session.close()
|
||
|
||
def clean_old_articles(self, days: int = 7) -> int:
|
||
"""
|
||
清理旧文章(设置为不活跃)
|
||
|
||
Args:
|
||
days: 保留多少天的文章
|
||
|
||
Returns:
|
||
清理的数量
|
||
"""
|
||
session = self.get_session()
|
||
try:
|
||
before = datetime.utcnow() - timedelta(days=days)
|
||
|
||
count = session.query(NewsArticle).filter(
|
||
NewsArticle.created_at < before
|
||
).update({
|
||
'is_active': False
|
||
})
|
||
|
||
session.commit()
|
||
|
||
if count > 0:
|
||
logger.info(f"清理了 {count} 条旧文章")
|
||
|
||
return count
|
||
|
||
except Exception as e:
|
||
session.rollback()
|
||
logger.error(f"清理旧文章失败: {e}")
|
||
return 0
|
||
|
||
finally:
|
||
session.close()
|
||
|
||
|
||
# Process-wide singleton, created lazily on first access.
_news_db_service = None


def get_news_db_service() -> NewsDatabaseService:
    """Return the shared NewsDatabaseService instance, creating it on demand."""
    global _news_db_service
    if _news_db_service is None:
        _news_db_service = NewsDatabaseService()
    return _news_db_service